diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f73d11e --- /dev/null +++ b/.gitignore @@ -0,0 +1,22 @@ +# Build artifacts / binaries +smartdns +selective-vpn-api.tar.gz +selective-vpn-api/selective-vpn-api + +# Python caches +__pycache__/ +*.pyc + +# GUI legacy copies / archives +selective-vpn-gui/1/ +selective-vpn-gui/2/ +selective-vpn-gui/*.zip +selective-vpn-gui/main.go + +# Backups / scratch +*.bak +*.bak_* +*.bak.* +*.tmp +selective-vpn-api/works + diff --git a/selective-vpn-api/OPS_CHECKLIST.md b/selective-vpn-api/OPS_CHECKLIST.md new file mode 100644 index 0000000..46989c8 --- /dev/null +++ b/selective-vpn-api/OPS_CHECKLIST.md @@ -0,0 +1,139 @@ +# Ops Checklist (Selective VPN) / Боевой чеклист + +RU: Практический чеклист для проверки и восстановления работы маршрутизации и DNS. +EN: Practical runbook/checklist for validating and recovering routing + DNS behavior. + +RU: DNS mode и Traffic mode это две независимые оси. +EN: DNS mode and Traffic mode are independent. + +## 0) Safety First / Безопасность + +RU: + +- Лучше тестировать изменения, пока у тебя есть стабильный SSH. +- `direct` traffic mode это аварийный режим: он убирает базовые policy rules. +- Если ты сидишь удаленно и доступ только через VPN, избегай `direct` без плана отката. + +EN: + +- Prefer to test changes while you still have a stable SSH session. +- `direct` traffic mode is an emergency option: it removes base policy rules. +- If VPN is your only access, avoid switching to `direct` without a rollback plan. + +## 1) Quick Health (API) + +```bash +curl -s http://127.0.0.1:8080/healthz +curl -s http://127.0.0.1:8080/api/v1/status +curl -s http://127.0.0.1:8080/api/v1/dns/status +curl -s http://127.0.0.1:8080/api/v1/traffic/mode/test +``` + +RU: смотри, чтобы `traffic/mode/test` вернул `healthy=true`, `probe_ok=true`. +EN: make sure `traffic/mode/test` returns `healthy=true`, `probe_ok=true`. + +## 2) Quick Health (Linux) + +```bash +ip rule show +ip -4 route show table agvpn +ip -4 route get 1.1.1.1 +ip -4 route get 1.1.1.1 mark 0x66 +``` + +RU: ожидаемые паттерны: + +- Selective: есть правило типа `pref 12000 fwmark 0x66 lookup agvpn`. +- Full tunnel: есть правило типа `pref 11900 lookup agvpn`. +- В VPN-режимах в `agvpn` таблице есть `default dev `. + +EN: expected patterns: + +- Selective mode: rule like `pref 12000 fwmark 0x66 lookup agvpn`. +- Full tunnel: rule like `pref 11900 lookup agvpn`. +- In VPN modes `agvpn` has `default dev `. + +## 3) Validate LAN and Containers / Проверка локалки и Docker + +RU: цель: в `full_tunnel` обычно нужно, чтобы LAN и Docker продолжали работать. +EN: goal: in `full_tunnel` you usually want LAN and Docker networks to keep working. + +RU: если в `full_tunnel` ломается доступ к LAN/docker: + +- включи `auto_local_bypass`. +- если нужно, чтобы контейнеры ходили в интернет direct (а хост через VPN), добавь docker CIDR в `Force Direct subnets`. + +EN: if LAN/docker break in `full_tunnel`: + +- enable `auto_local_bypass`. +- if you want containers direct in full tunnel, add docker CIDRs to `Force Direct subnets`. + +## 4) Validate nft sets / Проверка nft + +RU: обычно используются два сета: + +- `agvpn4`: direct-resolved IPs + static +- `agvpn_dyn4`: wildcard/smartdns dynamic IPs + +```bash +nft list table inet agvpn +nft list set inet agvpn agvpn4 +nft list set inet agvpn agvpn_dyn4 +``` + +## 5) Wildcard DNS / SmartDNS + +RU: state и артефакты: + +- Canonical wildcard state: `/var/lib/selective-vpn/smartdns-wildcards.json` +- Generated rules file: `/etc/selective-vpn/smartdns.conf` + +RU: runtime accelerator (опционально): + +- когда включен, SmartDNS конфиг может содержать `nftset ... agvpn_dyn4`. +- когда выключен, wildcard все равно работает через resolver job + prewarm. + +```bash +systemctl is-active smartdns-local.service +ls -la /etc/selective-vpn/smartdns.conf /var/lib/selective-vpn/smartdns-wildcards.json +``` + +## 6) Safe Recovery / Безопасный откат + +### A) Clear routes (save cache) / Clear с сохранением снапшота + +GUI: `Clear routes (save cache)`. + +RU: очищает routes/nft, но сохраняет снапшот для восстановления. +EN: clears routes/nft but saves a snapshot for restore. + +### B) Restore cached routes / Восстановление снапшота + +GUI: `Restore cached routes`. + +RU: + +- часть маршрутов может быть `linkdown` (docker bridge). Restore пропускает некритичные ошибки. + +EN: + +- some routes can be `linkdown` (docker bridges). Restore skips non-critical failures. + +### C) Restart services / Рестарт сервисов + +```bash +sudo systemctl restart selective-vpn-api.service +``` + +## 7) Logs / Логи + +```bash +journalctl -u selective-vpn-api.service -n 200 --no-pager +journalctl -u selective-vpn-api.service -f +``` + +## 8) Common Pitfalls / Частые грабли + +- Docker bridge маршруты могут существовать, но быть `linkdown` (best-effort). +- UID/cgroup overrides влияют на процессы хоста (OUTPUT) и обычно не управляют forwarded Docker-трафиком. +- Если overrides списки слишком большие, backend отвергнет их (лимит на каждый тип). diff --git a/selective-vpn-api/TRAFFIC_OVERRIDES_USAGE.md b/selective-vpn-api/TRAFFIC_OVERRIDES_USAGE.md new file mode 100644 index 0000000..3b700b9 --- /dev/null +++ b/selective-vpn-api/TRAFFIC_OVERRIDES_USAGE.md @@ -0,0 +1,173 @@ +# Traffic Mode and Overrides Usage + +This document describes how to use traffic mode extensions in the current build: + +- traffic modes: `selective`, `full_tunnel`, `direct` +- `auto_local_bypass` +- policy overrides by source subnet / UID / cgroup (systemd) +- detected candidates UI (`Add detected...`) + +## 1) Modes + +- `selective`: only marked traffic goes to VPN table (`agvpn`) +- `full_tunnel`: all traffic goes to VPN table +- `direct`: base VPN routing rules are removed + +Notes: + +- DNS mode is independent from traffic mode. +- Modes are controlled from GUI dialog `Traffic mode settings`. + +## 2) Auto-local bypass + +Option: `Auto-local bypass (LAN/container subnets)`. + +When enabled, backend mirrors local routes from `main` table into `agvpn` table: + +- link-scope routes +- private/local ranges +- common container interfaces (`docker*`, `br-*`, `veth*`, `cni*`) + +Purpose: reduce LAN/container breakage in `full_tunnel`. + +Important: + +- `auto_local_bypass` does NOT make containers use direct internet in `full_tunnel`. +- If you want containers to be `direct` in `full_tunnel`, use `Force Direct subnets`. + +## 3) Policy overrides (Advanced) + +Configured in dialog tab `Policy overrides (Advanced)`. + +Layout: + +- `Force VPN` column +- `Force Direct` column + +Each column provides the same types of overrides: + +- `Source subnets` +- `UIDs` +- `Cgroups / services` + +### 3.1) Source subnets + +Meaning: force routing for traffic **by source subnet**. + +Input format: + +- subnet: `172.18.0.0/16` +- single IP is accepted and normalized to `/32` +- one value per line (comma/semicolon separated values are also accepted) + +Practical usage: + +- Docker/bridge networks are best controlled via `Source subnets`. + +### 3.2) UIDs + +Meaning: force routing for **host-local processes** by UID/uidrange. + +Input format: + +- UID: `1000` +- UID range: `1000-1010` +- one value per line + +Important limitation: + +- UID rules generally affect host OUTPUT traffic, not forwarded traffic from Docker bridges. + +### 3.3) Cgroups / services + +Meaning: select workloads by systemd cgroup, backend resolves them to UID rules at apply time. + +Input format: + +- cgroup path or cgroup name, one per line +- examples: + - `/system.slice/jellyfin.service` + - `system.slice/docker.service` + +Current implementation model: + +1. backend scans matching cgroup directory (recursively) and reads `cgroup.procs` +2. resolves each PID owner UID from `/proc//status` +3. creates `uidrange` policy rules from those UIDs + +Important limitations: + +- cgroup override is currently UID-based after resolution. +- if multiple workloads run under same UID (for example `root`), they cannot be separated by UID policy rules. +- if cgroup has no running processes at apply time, no UID rules are created from that cgroup. + +## 4) Detected candidates (`Add detected...`) + +Button: `Add detected...` (in `Policy overrides (Advanced)`). + +This opens a selector populated by the backend endpoint: + +- `GET /api/v1/traffic/candidates` + +Tabs: + +- `Subnets`: LAN + docker/bridge subnets detected from `ip -4 route show table main` +- `Services`: running systemd units -> mapped to cgroup like `system.slice/.service` +- `UIDs`: UIDs detected from running processes (`ps -eo uid,user,comm`) + +Presets (Subnets tab): + +- `Keep LAN direct` +- `Keep Docker direct` + +Safety model: + +- Selecting items only fills the text fields. +- Nothing changes on the host until you click `Apply overrides`. + +## 5) Rule priority and precedence + +Managed `ip rule` priorities: + +- direct subnet overrides: `11600+` +- direct UID overrides: `11680+` +- VPN subnet overrides: `11720+` +- VPN UID overrides: `11800+` +- full tunnel base rule: `11900` +- selective base rule: `12000` + +This means direct overrides are evaluated before VPN overrides and before base mode rules. + +## 6) Recommended workflow + +1. Select traffic mode. +2. Select preferred iface (or `auto`). +3. Toggle `auto_local_bypass` as needed. +4. Fill overrides (subnet/UID/cgroup), optionally using `Add detected...`. +5. Click `Apply overrides`. +6. Click `Test mode` (on Routes tab). +7. If needed, click `Clear routes (save cache)` and/or `Restore cached routes`. + +## 7) Observability + +GUI status line shows: + +- desired/applied mode +- bypass route count +- override count +- resolved cgroup UID count +- cgroup warning text (if any) + +## 8) Troubleshooting quick checks + +```bash +ip rule show +ip -4 route show table agvpn +nft list ruleset | sed -n '/table inet agvpn/,$p' +``` + +If mode health is not OK: + +- verify selected iface exists and is up +- verify `agvpn` table has default route in VPN modes +- verify subnet/UID/cgroup entries are valid and currently active diff --git a/selective-vpn-api/app/assets/domains/bases.txt b/selective-vpn-api/app/assets/domains/bases.txt new file mode 100644 index 0000000..504d1b3 --- /dev/null +++ b/selective-vpn-api/app/assets/domains/bases.txt @@ -0,0 +1,2 @@ +### +# Default bases list (seed). Add domains here; one per line. diff --git a/selective-vpn-api/app/assets/domains/meta-special.txt b/selective-vpn-api/app/assets/domains/meta-special.txt new file mode 100644 index 0000000..06f5bc9 --- /dev/null +++ b/selective-vpn-api/app/assets/domains/meta-special.txt @@ -0,0 +1 @@ +# meta domains (seed) diff --git a/selective-vpn-api/app/assets/domains/static-ips.txt b/selective-vpn-api/app/assets/domains/static-ips.txt new file mode 100644 index 0000000..d1b8402 --- /dev/null +++ b/selective-vpn-api/app/assets/domains/static-ips.txt @@ -0,0 +1 @@ +# static IPs (seed) diff --git a/selective-vpn-api/app/assets/domains/subs.txt b/selective-vpn-api/app/assets/domains/subs.txt new file mode 100644 index 0000000..ec4de7e --- /dev/null +++ b/selective-vpn-api/app/assets/domains/subs.txt @@ -0,0 +1,3 @@ +www +api +static diff --git a/selective-vpn-api/app/autoloop.go b/selective-vpn-api/app/autoloop.go new file mode 100644 index 0000000..88c89bb --- /dev/null +++ b/selective-vpn-api/app/autoloop.go @@ -0,0 +1,204 @@ +package app + +import ( + "fmt" + "log" + "os" + "regexp" + "strings" + "time" +) + +// --------------------------------------------------------------------- +// autoloop +// --------------------------------------------------------------------- + +// EN: Long-running VPN autoloop worker that keeps the tunnel connected, +// EN: updates login/license state, enforces policy route defaults, and emits events. +// RU: Долгоживущий воркер VPN autoloop, который поддерживает соединение, +// RU: обновляет login/license state, чинит policy route и публикует события. + +func runAutoloop(iface, table string, mtu int, stateDirPath, defaultLoc string) { + locFile := stateDirPath + "/adguard-location.txt" + logFile := stateDirPath + "/adguard-autoloop.log" + loginStateFile := stateDirPath + "/adguard-login.json" + licenseTTL := 3600 * time.Second + statusTimeout := 8 * time.Second + connectTimeout := 25 * time.Second + disconnectTimeout := 8 * time.Second + licenseTimeout := 10 * time.Second + lastLicense := time.Time{} + + _ = os.MkdirAll(stateDirPath, 0o755) + + log.Printf("autoloop: start iface=%s table=%s mtu=%d", iface, table, mtu) + + logLine := func(msg string) { + line := fmt.Sprintf("%s autoloop: %s\n", time.Now().Format(time.RFC3339), msg) + f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) + if err == nil { + defer f.Close() + _, _ = f.WriteString(line) + } + fmt.Print(line) + } + + writeLoginState := func(state, email, msg string) { + ts := time.Now().Format(time.RFC3339) + payload := fmt.Sprintf(`{"ts":"%s","state":"%s","email":"%s","msg":"%s"}`, ts, escapeJSON(state), escapeJSON(email), escapeJSON(msg)) + _ = os.WriteFile(loginStateFile, []byte(payload), 0o644) + } + + getLocation := func() string { + if data, err := os.ReadFile(locFile); err == nil { + for _, ln := range strings.Split(string(data), "\n") { + t := strings.TrimSpace(ln) + if t != "" && !strings.HasPrefix(t, "#") { + return t + } + } + } + return defaultLoc + } + + isConnected := func(out string) bool { + low := strings.ToLower(out) + return strings.Contains(low, "vpn is connected") || + strings.Contains(low, "connected to") || + strings.Contains(low, "after connect: connected") + } + + fixPolicy := func() { + _, stderr, _, err := runCommandTimeout(5*time.Second, + "ip", "-4", "route", "replace", + "default", "dev", iface, + "table", table, + "mtu", fmt.Sprintf("%d", mtu), + ) + if err != nil { + logLine("route: FAILED to set default dev " + iface + + " table " + table + ": " + stderr) + } else { + logLine("route: default dev " + iface + " table " + table + + " mtu " + fmt.Sprintf("%d", mtu) + " OK") + } + } + var emailRe = regexp.MustCompile(`[A-Za-z0-9._%+\-]+@[A-Za-z0-9.\-]+`) + parseEmail := func(text string) string { + return emailRe.FindString(text) + } + + isLoginRequired := func(t string) bool { + low := strings.ToLower(t) + return strings.Contains(low, "please log in") || + strings.Contains(low, "not logged in") || + strings.Contains(low, "login required") || + strings.Contains(low, "sign in") + } + + updateLoginStateFromText := func(text string) { + if isLoginRequired(text) { + writeLoginState("no_login", "", "NOT LOGGED IN") + logLine("login: NO (detected from output)") + return + } + if em := parseEmail(text); em != "" { + writeLoginState("ok", em, "logged in") + logLine("login: OK email=" + em) + return + } + low := strings.ToLower(text) + if strings.Contains(low, "not logged in") || + strings.Contains(low, "expired") || + strings.Contains(low, "no active license") { + writeLoginState("no_login", "", "NOT LOGGED IN (license)") + logLine("login: NO (license says not logged in)") + return + } + + if strings.Contains(low, "license") && + (strings.Contains(low, "active") || strings.Contains(low, "valid")) { + writeLoginState("ok", "", "logged in (license ok)") + logLine("login: OK (license ok, email not found)") + return + } + } + + updateLicense := func() { + now := time.Now() + if !lastLicense.IsZero() && now.Sub(lastLicense) < licenseTTL { + return + } + lastLicense = now + out, _, _, _ := runCommandTimeout(licenseTimeout, adgvpnCLI, "license") + out = stripANSI(out) + updateLoginStateFromText(out) + } + + writeLoginState("unknown", "", "not checked yet") + updateLicense() + + for { + statusOut, _, exitCode, err := runCommandTimeout(statusTimeout, adgvpnCLI, "status") + statusOut = stripANSI(statusOut) + if err != nil { + logLine(fmt.Sprintf("status: ERROR exit=%d err=%v raw=%q", exitCode, err, statusOut)) + } + if isConnected(statusOut) { + logLine("status: CONNECTED; raw: " + statusOut) + fixPolicy() + updateLicense() + events.push("autoloop_status_changed", map[string]string{ + "status_word": "CONNECTED", + "raw_text": statusOut, + }) + time.Sleep(20 * time.Second) + continue + } + + logLine("status: DISCONNECTED; raw: " + statusOut) + events.push("autoloop_status_changed", map[string]string{ + "status_word": "DISCONNECTED", + "raw_text": statusOut, + }) + updateLoginStateFromText(statusOut) + + loc := getLocation() + logLine("reconnecting to " + loc) + + _, _, _, _ = runCommandTimeout(disconnectTimeout, adgvpnCLI, "disconnect") + connectOut, _, _, _ := runCommandTimeout(connectTimeout, adgvpnCLI, "connect", "-l", loc, "--log-to-file") + connectOut = stripANSI(connectOut) + logLine("connect raw: " + connectOut) + updateLoginStateFromText(connectOut) + + statusAfter, _, _, _ := runCommandTimeout(statusTimeout, adgvpnCLI, "status") + statusAfter = stripANSI(statusAfter) + if isConnected(statusAfter) { + logLine("after connect: CONNECTED; raw: " + statusAfter) + fixPolicy() + updateLicense() + events.push("autoloop_status_changed", map[string]string{ + "status_word": "CONNECTED", + "raw_text": statusAfter, + }) + time.Sleep(20 * time.Second) + continue + } + + logLine("after connect: STILL DISCONNECTED; raw: " + statusAfter) + time.Sleep(10 * time.Second) + } +} + +// --------------------------------------------------------------------- +// autoloop helpers +// --------------------------------------------------------------------- + +func escapeJSON(s string) string { + s = strings.ReplaceAll(s, `\`, `\\`) + s = strings.ReplaceAll(s, `"`, `\\"`) + s = strings.ReplaceAll(s, "\n", "\\n") + s = strings.ReplaceAll(s, "\r", "") + return s +} diff --git a/selective-vpn-api/app/config.go b/selective-vpn-api/app/config.go new file mode 100644 index 0000000..92d9177 --- /dev/null +++ b/selective-vpn-api/app/config.go @@ -0,0 +1,109 @@ +package app + +import "embed" + +// EN: Centralized runtime configuration constants and embedded seed assets used +// EN: across the API server, route updater, VPN helpers, and background workers. +// RU: Централизованные runtime-константы и встроенные seed-ресурсы, +// RU: используемые API-сервером, апдейтером маршрутов, VPN-хелперами и воркерами. + +// --------------------------------------------------------------------- +// runtime constants +// --------------------------------------------------------------------- + +const ( + stateDir = "/var/lib/selective-vpn" + statusFilePath = stateDir + "/status.json" + dnsModePath = stateDir + "/dns-mode.json" + trafficModePath = stateDir + "/traffic-mode.json" + + traceLogPath = stateDir + "/trace.log" + smartdnsLogPath = stateDir + "/smartdns.log" + lastIPsPath = stateDir + "/last-ips.txt" + lastIPsMapPath = stateDir + "/last-ips-map.txt" + lastIPsDirect = stateDir + "/last-ips-direct.txt" + lastIPsDyn = stateDir + "/last-ips-dyn.txt" + lastIPsMapDirect = stateDir + "/last-ips-map-direct.txt" + lastIPsMapDyn = stateDir + "/last-ips-map-wildcard.txt" + routesCacheMeta = stateDir + "/routes-clear-cache.json" + routesCacheIPs = stateDir + "/routes-clear-cache-ips.txt" + routesCacheDyn = stateDir + "/routes-clear-cache-ips-dyn.txt" + routesCacheMap = stateDir + "/routes-clear-cache-ips-map.txt" + routesCacheRT = stateDir + "/routes-clear-cache-routes.txt" + + autoloopLogPath = stateDir + "/adguard-autoloop.log" + loginStatePath = stateDir + "/adguard-login.json" + dnsUpstreamsPath = stateDir + "/dns-upstreams.json" + smartdnsWLPath = stateDir + "/smartdns-wildcards.json" + smartdnsRTPath = stateDir + "/smartdns-runtime.json" + desiredLocation = stateDir + "/adguard-location.txt" + + adgvpnCLI = "/usr/local/bin/adguardvpn-cli-root" + + // маршруты v2 + routesServiceTemplate = "selective-vpn2@%s.service" + routesTimerTemplate = "selective-vpn2@%s.timer" + routesServiceEnv = "SELECTIVE_VPN_ROUTES_UNIT" + routesTimerEnv = "SELECTIVE_VPN_ROUTES_TIMER" + + // юнит автоконнекта AdGuard VPN + adgvpnUnit = "adguardvpn-autoconnect.service" + + // доменные файлы / пути + domainDir = "/etc/selective-vpn/domains" + dnsUpstreamsConf = "/etc/selective-vpn/dns-upstreams.conf" + smartdnsDomainsFile = "/etc/selective-vpn/smartdns.conf" + smartdnsMainConfig = "/opt/stack/adguardapp/smartdns.conf" + staticIPsFile = "/etc/selective-vpn/static-ips.txt" + heartbeatFile = stateDir + "/heartbeat" + lockFile = "/run/lock/selective-vpn.lock" + MARK = "0x66" + defaultDNS1 = "94.140.14.14" + defaultDNS2 = "94.140.15.15" + defaultMeta1 = "46.243.231.30" + defaultMeta2 = "46.243.231.41" + + smartDNSDefaultAddr = "127.0.0.1#6053" + smartDNSAddrEnv = "SVPN_SMARTDNS_ADDR" + smartDNSForceEnv = "SVPN_SMARTDNS_FORCE" + + policyRouteMTU = "1380" + defaultTraceTailMax = 800 + + defaultEventsCapacity = 512 + defaultPollStatusMs = 2000 + defaultPollLoginMs = 2500 + defaultPollAutoloopMs = 2500 + defaultPollSystemdMs = 3000 + defaultPollTraceMs = 1500 + defaultHeartbeatSeconds = 15 +) + +// --------------------------------------------------------------------- +// domain expansion lists +// --------------------------------------------------------------------- + +// EN: Domain expansion lists used by routes update to build selective targets. +// RU: Списки доменов для расширения селективных целей при обновлении маршрутов. +var googleLikeDomains = []string{ + "google.com", "googleapis.com", "gstatic.com", "googleusercontent.com", + "1e100.net", "gvt1.com", "gvt2.com", "gvt3.com", +} + +// EN: Extra Twitter subdomains that should be forced through selective routing. +// RU: Дополнительные поддомены Twitter, которые принудительно идут через селективный маршрут. +var twitterSpecial = []string{ + "ton", "pay", "caps", "sms", "cert", "tdweb", "p", "ma-0.twimg", "si0.twimg", + "syndication", "tweetdeck", "stream", "userstream", "sitestream", "betastream", + "music", "ms1", "ms3", "urls-real.api", "music-partner", "partner-stream", +} + +// --------------------------------------------------------------------- +// embedded assets +// --------------------------------------------------------------------- + +// EN: Embedded default domain files used as seed content when runtime files are absent. +// RU: Встроенные файлы доменов по умолчанию для первичного seed, если runtime-файлы отсутствуют. +// +//go:embed assets/domains/* +var embeddedDomains embed.FS diff --git a/selective-vpn-api/app/dns_settings.go b/selective-vpn-api/app/dns_settings.go new file mode 100644 index 0000000..a954695 --- /dev/null +++ b/selective-vpn-api/app/dns_settings.go @@ -0,0 +1,886 @@ +package app + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "os" + "path/filepath" + "sort" + "strings" + "time" +) + +// --------------------------------------------------------------------- +// DNS settings + SmartDNS control +// --------------------------------------------------------------------- + +// EN: DNS control-plane handlers and storage helpers. +// EN: This unit keeps resolver mode, SmartDNS address, SmartDNS service control, +// EN: and dns-upstreams.conf in one place for GUI and backend consistency. +// RU: Обработчики DNS control-plane и helper-функции хранения. +// RU: Этот модуль держит в одном месте режим резолвера, адрес SmartDNS, +// RU: управление сервисом SmartDNS и dns-upstreams.conf для консистентности GUI и backend. + +// --------------------------------------------------------------------- +// EN: `handleDNSUpstreams` is an HTTP handler for dns upstreams. +// RU: `handleDNSUpstreams` - HTTP-обработчик для dns upstreams. +// --------------------------------------------------------------------- +func handleDNSUpstreams(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + writeJSON(w, http.StatusOK, loadDNSUpstreamsConf()) + case http.MethodPost: + var cfg DNSUpstreams + if r.Body != nil { + defer r.Body.Close() + if err := json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&cfg); err != nil && err != io.EOF { + http.Error(w, "bad json", http.StatusBadRequest) + return + } + } + if err := saveDNSUpstreamsConf(cfg); err != nil { + http.Error(w, "write error", http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, map[string]any{ + "ok": true, + "cfg": loadDNSUpstreamsConf(), + }) + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } +} + +// --------------------------------------------------------------------- +// EN: `handleDNSStatus` is an HTTP handler for dns status. +// RU: `handleDNSStatus` - HTTP-обработчик для dns status. +// --------------------------------------------------------------------- +func handleDNSStatus(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + mode := loadDNSMode() + writeJSON(w, http.StatusOK, makeDNSStatusResponse(mode)) +} + +// --------------------------------------------------------------------- +// EN: `handleDNSModeSet` is an HTTP handler for dns mode set. +// RU: `handleDNSModeSet` - HTTP-обработчик для dns mode set. +// --------------------------------------------------------------------- +func handleDNSModeSet(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + var req DNSModeRequest + if r.Body != nil { + defer r.Body.Close() + if err := json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&req); err != nil && err != io.EOF { + http.Error(w, "bad json", http.StatusBadRequest) + return + } + } + + mode := loadDNSMode() + mode.Mode = normalizeDNSResolverMode(req.Mode, req.ViaSmartDNS) + mode.ViaSmartDNS = mode.Mode != DNSModeDirect + if strings.TrimSpace(req.SmartDNSAddr) != "" { + mode.SmartDNSAddr = req.SmartDNSAddr + } + if err := saveDNSMode(mode); err != nil { + http.Error(w, "write error", http.StatusInternalServerError) + return + } + + mode = loadDNSMode() + writeJSON(w, http.StatusOK, makeDNSStatusResponse(mode)) +} + +// --------------------------------------------------------------------- +// EN: `handleDNSSmartdnsService` is an HTTP handler for dns smartdns service. +// RU: `handleDNSSmartdnsService` - HTTP-обработчик для dns smartdns service. +// --------------------------------------------------------------------- +func handleDNSSmartdnsService(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + var body struct { + Action string `json:"action"` + } + if r.Body != nil { + defer r.Body.Close() + _ = json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&body) + } + + action := strings.ToLower(strings.TrimSpace(body.Action)) + if action == "" { + action = "restart" + } + switch action { + case "start", "stop", "restart": + default: + http.Error(w, "unknown action", http.StatusBadRequest) + return + } + + res := runSmartdnsUnitAction(action) + mode := loadDNSMode() + rt := smartDNSRuntimeSnapshot() + writeJSON(w, http.StatusOK, map[string]any{ + "ok": res.OK, + "message": res.Message, + "exitCode": res.ExitCode, + "stdout": res.Stdout, + "stderr": res.Stderr, + "unit_state": smartdnsUnitState(), + "via_smartdns": mode.ViaSmartDNS, + "smartdns_addr": mode.SmartDNSAddr, + "mode": mode.Mode, + "runtime_nftset": rt.Enabled, + "wildcard_source": rt.WildcardSource, + }) +} + +func makeDNSStatusResponse(mode DNSMode) DNSStatusResponse { + rt := smartDNSRuntimeSnapshot() + resp := DNSStatusResponse{ + ViaSmartDNS: mode.ViaSmartDNS, + SmartDNSAddr: mode.SmartDNSAddr, + Mode: mode.Mode, + UnitState: smartdnsUnitState(), + RuntimeNftset: rt.Enabled, + WildcardSource: rt.WildcardSource, + RuntimeCfgPath: rt.ConfigPath, + } + if rt.Message != "" { + resp.RuntimeCfgError = rt.Message + } + return resp +} + +// --------------------------------------------------------------------- +// EN: `handleSmartdnsService` is an HTTP handler for smartdns service. +// RU: `handleSmartdnsService` - HTTP-обработчик для smartdns service. +// --------------------------------------------------------------------- +func handleSmartdnsService(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + writeJSON(w, http.StatusOK, map[string]string{"state": smartdnsUnitState()}) + case http.MethodPost: + var body struct { + Action string `json:"action"` + } + if r.Body != nil { + defer r.Body.Close() + _ = json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&body) + } + + action := strings.ToLower(strings.TrimSpace(body.Action)) + if action == "" { + action = "restart" + } + switch action { + case "start", "stop", "restart": + default: + http.Error(w, "unknown action", http.StatusBadRequest) + return + } + writeJSON(w, http.StatusOK, runSmartdnsUnitAction(action)) + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } +} + +// --------------------------------------------------------------------- +// smartdns runtime accelerator state +// --------------------------------------------------------------------- + +func handleSmartdnsRuntime(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + writeJSON(w, http.StatusOK, smartDNSRuntimeSnapshot()) + case http.MethodPost: + var body SmartDNSRuntimeRequest + if r.Body != nil { + defer r.Body.Close() + if err := json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&body); err != nil && err != io.EOF { + http.Error(w, "bad json", http.StatusBadRequest) + return + } + } + if body.Enabled == nil { + http.Error(w, "enabled is required", http.StatusBadRequest) + return + } + + prev := loadSmartDNSRuntimeState(nil) + next := prev + next.Enabled = *body.Enabled + if err := saveSmartDNSRuntimeState(next); err != nil { + http.Error(w, "runtime state write error", http.StatusInternalServerError) + return + } + + changed, err := applySmartDNSRuntimeConfig(next.Enabled) + if err != nil { + _ = saveSmartDNSRuntimeState(prev) + http.Error(w, "runtime config apply error: "+err.Error(), http.StatusInternalServerError) + return + } + + restart := true + if body.Restart != nil { + restart = *body.Restart + } + restarted := false + msg := "" + if restart && smartdnsUnitState() == "active" { + res := runSmartdnsUnitAction("restart") + restarted = res.OK + if !res.OK { + msg = "runtime config changed, but smartdns restart failed: " + strings.TrimSpace(res.Message) + } + } + if msg == "" { + msg = fmt.Sprintf("smartdns runtime set: enabled=%t changed=%t restarted=%t", next.Enabled, changed, restarted) + } + appendTraceLineTo(smartdnsLogPath, "smartdns", msg) + + resp := smartDNSRuntimeSnapshot() + resp.Changed = changed + resp.Restarted = restarted + resp.Message = msg + writeJSON(w, http.StatusOK, resp) + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } +} + +// --------------------------------------------------------------------- +// EN: `handleSmartdnsPrewarm` forces DNS lookups for wildcard domains via SmartDNS. +// EN: This warms agvpn_dyn4 in realtime through SmartDNS nftset runtime integration. +// RU: `handleSmartdnsPrewarm` принудительно резолвит wildcard-домены через SmartDNS. +// RU: Это прогревает agvpn_dyn4 в realtime через runtime-интеграцию SmartDNS nftset. +// --------------------------------------------------------------------- +func handleSmartdnsPrewarm(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + var body struct { + Limit int `json:"limit"` + Workers int `json:"workers"` + TimeoutMS int `json:"timeout_ms"` + AggressiveSubs bool `json:"aggressive_subs"` + } + if r.Body != nil { + defer r.Body.Close() + _ = json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&body) + } + writeJSON(w, http.StatusOK, runSmartdnsPrewarm(body.Limit, body.Workers, body.TimeoutMS, body.AggressiveSubs)) +} + +func runSmartdnsPrewarm(limit, workers, timeoutMS int, aggressiveSubs bool) cmdResult { + mode := loadDNSMode() + runtimeEnabled := smartDNSRuntimeEnabled() + source := "resolver" + if runtimeEnabled { + source = "smartdns_runtime" + } + smartdnsAddr := normalizeSmartDNSAddr(mode.SmartDNSAddr) + if smartdnsAddr == "" { + smartdnsAddr = resolveDefaultSmartDNSAddr() + } + if smartdnsAddr == "" { + return cmdResult{OK: false, Message: "SmartDNS address is empty"} + } + + wildcards := loadSmartDNSWildcardDomains(nil) + if len(wildcards) == 0 { + msg := "prewarm skipped: wildcard list is empty" + appendTraceLineTo(smartdnsLogPath, "smartdns", msg) + return cmdResult{OK: true, Message: msg} + } + + aggressive := aggressiveSubs || prewarmAggressiveFromEnv() + + // Default prewarm is wildcard-only (no subs fan-out). + subs := []string{} + subsPerBaseLimit := 0 + if aggressive { + subs = loadList(domainDir + "/subs.txt") + subsPerBaseLimit = envInt("RESOLVE_SUBS_PER_BASE_LIMIT", 0) + if subsPerBaseLimit < 0 { + subsPerBaseLimit = 0 + } + } + domainSet := make(map[string]struct{}, len(wildcards)*(len(subs)+1)) + for _, d := range wildcards { + d = strings.TrimSpace(d) + if d == "" { + continue + } + domainSet[d] = struct{}{} + if aggressive && !isGoogleLike(d) { + maxSubs := len(subs) + if subsPerBaseLimit > 0 && subsPerBaseLimit < maxSubs { + maxSubs = subsPerBaseLimit + } + for i := 0; i < maxSubs; i++ { + domainSet[subs[i]+"."+d] = struct{}{} + } + } + } + + domains := make([]string, 0, len(domainSet)) + for d := range domainSet { + domains = append(domains, d) + } + sort.Strings(domains) + + if limit > 0 && len(domains) > limit { + domains = domains[:limit] + } + if len(domains) == 0 { + msg := "prewarm skipped: expanded wildcard list is empty" + appendTraceLineTo(smartdnsLogPath, "smartdns", msg) + return cmdResult{OK: true, Message: msg} + } + + if workers <= 0 { + workers = envInt("SMARTDNS_PREWARM_WORKERS", 24) + } + if workers < 1 { + workers = 1 + } + if workers > 200 { + workers = 200 + } + + if timeoutMS <= 0 { + timeoutMS = envInt("SMARTDNS_PREWARM_TIMEOUT_MS", 1800) + } + if timeoutMS < 200 { + timeoutMS = 200 + } + if timeoutMS > 15000 { + timeoutMS = 15000 + } + timeout := time.Duration(timeoutMS) * time.Millisecond + + // Ensure runtime set exists before prewarm queries hit SmartDNS nftset hook. + _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "table", "inet", "agvpn") + _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "set", "inet", "agvpn", "agvpn_dyn4", "{", "type", "ipv4_addr", ";", "flags", "interval", ";", "}") + + appendTraceLineTo( + smartdnsLogPath, + "smartdns", + fmt.Sprintf( + "prewarm start: mode=%s source=%s runtime_nftset=%t smartdns=%s wildcard_domains=%d expanded=%d aggressive_subs=%t workers=%d timeout_ms=%d", + mode.Mode, source, runtimeEnabled, smartdnsAddr, len(wildcards), len(domains), aggressive, workers, timeoutMS, + ), + ) + + type prewarmItem struct { + host string + ips []string + stats dnsMetrics + } + jobs := make(chan string, len(domains)) + results := make(chan prewarmItem, len(domains)) + for i := 0; i < workers; i++ { + go func() { + for host := range jobs { + ips, stats := digA(host, []string{smartdnsAddr}, timeout, nil) + results <- prewarmItem{host: host, ips: ips, stats: stats} + } + }() + } + for _, host := range domains { + jobs <- host + } + close(jobs) + + resolvedHosts := 0 + totalIPs := 0 + errorHosts := 0 + stats := dnsMetrics{} + resolvedIPSet := map[string]struct{}{} + loggedHosts := 0 + const maxHostsLog = 200 + + for i := 0; i < len(domains); i++ { + item := <-results + stats.merge(item.stats) + if item.stats.totalErrors() > 0 { + errorHosts++ + } + if len(item.ips) == 0 { + continue + } + resolvedHosts++ + totalIPs += len(item.ips) + for _, ip := range item.ips { + if strings.TrimSpace(ip) != "" { + resolvedIPSet[ip] = struct{}{} + } + } + if loggedHosts < maxHostsLog { + appendTraceLineTo(smartdnsLogPath, "smartdns", fmt.Sprintf("prewarm add: %s -> %s", item.host, strings.Join(item.ips, ", "))) + loggedHosts++ + } + } + + manualAdded := 0 + totalDyn := 0 + totalDynText := "n/a" + if !runtimeEnabled { + existing, _ := readNftSetElements("agvpn_dyn4") + mergedSet := make(map[string]struct{}, len(existing)+len(resolvedIPSet)) + for _, ip := range existing { + if strings.TrimSpace(ip) != "" { + mergedSet[ip] = struct{}{} + } + } + for ip := range resolvedIPSet { + if _, ok := mergedSet[ip]; !ok { + manualAdded++ + } + mergedSet[ip] = struct{}{} + } + merged := make([]string, 0, len(mergedSet)) + for ip := range mergedSet { + merged = append(merged, ip) + } + totalDyn = len(merged) + totalDynText = fmt.Sprintf("%d", totalDyn) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + if err := nftUpdateSetIPsSmart(ctx, "agvpn_dyn4", merged, nil); err != nil { + msg := fmt.Sprintf("prewarm manual apply failed: %v", err) + appendTraceLineTo(smartdnsLogPath, "smartdns", msg) + return cmdResult{OK: false, Message: msg} + } + appendTraceLineTo( + smartdnsLogPath, + "smartdns", + fmt.Sprintf("prewarm manual merge: existing=%d resolved=%d added=%d total_dyn=%d", len(existing), len(resolvedIPSet), manualAdded, totalDyn), + ) + } + if len(domains) > loggedHosts { + appendTraceLineTo(smartdnsLogPath, "smartdns", fmt.Sprintf("prewarm add: +%d domains omitted", len(domains)-loggedHosts)) + } + + msg := fmt.Sprintf( + "prewarm done: source=%s expanded=%d resolved=%d total_ips=%d error_hosts=%d dns_attempts=%d dns_ok=%d dns_errors=%d manual_added=%d dyn_total=%s", + source, + len(domains), + resolvedHosts, + totalIPs, + errorHosts, + stats.Attempts, + stats.OK, + stats.totalErrors(), + manualAdded, + totalDynText, + ) + appendTraceLineTo(smartdnsLogPath, "smartdns", msg) + if perUpstream := stats.formatPerUpstream(); perUpstream != "" { + appendTraceLineTo(smartdnsLogPath, "smartdns", "prewarm dns upstreams: "+perUpstream) + } + + return cmdResult{ + OK: true, + Message: msg, + ExitCode: resolvedHosts, + } +} + +func prewarmAggressiveFromEnv() bool { + switch strings.ToLower(strings.TrimSpace(os.Getenv("SMARTDNS_PREWARM_AGGRESSIVE"))) { + case "1", "true", "yes", "on": + return true + default: + return false + } +} + +// --------------------------------------------------------------------- +// EN: `loadDNSUpstreamsConf` loads dns upstreams conf from storage or config. +// RU: `loadDNSUpstreamsConf` - загружает dns upstreams conf из хранилища или конфига. +// --------------------------------------------------------------------- +func loadDNSUpstreamsConf() DNSUpstreams { + cfg := DNSUpstreams{ + Default1: defaultDNS1, + Default2: defaultDNS2, + Meta1: defaultMeta1, + Meta2: defaultMeta2, + } + + data, err := os.ReadFile(dnsUpstreamsConf) + if err != nil { + return cfg + } + + for _, ln := range strings.Split(string(data), "\n") { + s := strings.TrimSpace(ln) + if s == "" || strings.HasPrefix(s, "#") { + continue + } + parts := strings.Fields(s) + if len(parts) < 2 { + continue + } + key := strings.ToLower(parts[0]) + vals := parts[1:] + switch key { + case "default": + if len(vals) > 0 { + cfg.Default1 = normalizeDNSUpstream(vals[0], "53") + } + if len(vals) > 1 { + cfg.Default2 = normalizeDNSUpstream(vals[1], "53") + } + case "meta": + if len(vals) > 0 { + cfg.Meta1 = normalizeDNSUpstream(vals[0], "53") + } + if len(vals) > 1 { + cfg.Meta2 = normalizeDNSUpstream(vals[1], "53") + } + } + } + + if cfg.Default1 == "" { + cfg.Default1 = defaultDNS1 + } + if cfg.Default2 == "" { + cfg.Default2 = defaultDNS2 + } + if cfg.Meta1 == "" { + cfg.Meta1 = defaultMeta1 + } + if cfg.Meta2 == "" { + cfg.Meta2 = defaultMeta2 + } + return cfg +} + +// --------------------------------------------------------------------- +// EN: `saveDNSUpstreamsConf` saves dns upstreams conf to persistent storage. +// RU: `saveDNSUpstreamsConf` - сохраняет dns upstreams conf в постоянное хранилище. +// --------------------------------------------------------------------- +func saveDNSUpstreamsConf(cfg DNSUpstreams) error { + cfg.Default1 = normalizeDNSUpstream(cfg.Default1, "53") + cfg.Default2 = normalizeDNSUpstream(cfg.Default2, "53") + cfg.Meta1 = normalizeDNSUpstream(cfg.Meta1, "53") + cfg.Meta2 = normalizeDNSUpstream(cfg.Meta2, "53") + + if cfg.Default1 == "" { + cfg.Default1 = defaultDNS1 + } + if cfg.Default2 == "" { + cfg.Default2 = defaultDNS2 + } + if cfg.Meta1 == "" { + cfg.Meta1 = defaultMeta1 + } + if cfg.Meta2 == "" { + cfg.Meta2 = defaultMeta2 + } + + content := fmt.Sprintf( + "default %s %s\nmeta %s %s\n", + cfg.Default1, cfg.Default2, cfg.Meta1, cfg.Meta2, + ) + + if err := os.MkdirAll(filepath.Dir(dnsUpstreamsConf), 0o755); err != nil { + return err + } + tmp := dnsUpstreamsConf + ".tmp" + if err := os.WriteFile(tmp, []byte(content), 0o644); err != nil { + return err + } + if err := os.Rename(tmp, dnsUpstreamsConf); err != nil { + return err + } + + // Legacy JSON mirror for backward compatibility with older UI/runtime bits. + _ = os.MkdirAll(stateDir, 0o755) + if b, err := json.MarshalIndent(cfg, "", " "); err == nil { + _ = os.WriteFile(dnsUpstreamsPath, b, 0o644) + } + + return nil +} + +// --------------------------------------------------------------------- +// EN: `loadDNSMode` loads dns mode from storage or config. +// RU: `loadDNSMode` - загружает dns mode из хранилища или конфига. +// --------------------------------------------------------------------- +func loadDNSMode() DNSMode { + mode := DNSMode{ + ViaSmartDNS: false, + SmartDNSAddr: resolveDefaultSmartDNSAddr(), + Mode: DNSModeDirect, + } + needPersist := false + + data, err := os.ReadFile(dnsModePath) + switch { + case err == nil: + var stored DNSMode + if err := json.Unmarshal(data, &stored); err == nil { + mode.Mode = normalizeDNSResolverMode(stored.Mode, stored.ViaSmartDNS) + mode.ViaSmartDNS = mode.Mode != DNSModeDirect + if strings.TrimSpace(string(stored.Mode)) == "" || stored.ViaSmartDNS != mode.ViaSmartDNS { + needPersist = true + } + if addr := normalizeSmartDNSAddr(stored.SmartDNSAddr); addr != "" { + mode.SmartDNSAddr = addr + } else { + needPersist = true + } + } else { + needPersist = true + } + case os.IsNotExist(err): + needPersist = true + } + + if mode.SmartDNSAddr == "" { + mode.SmartDNSAddr = smartDNSDefaultAddr + needPersist = true + } + mode.Mode = normalizeDNSResolverMode(mode.Mode, mode.ViaSmartDNS) + mode.ViaSmartDNS = mode.Mode != DNSModeDirect + + if needPersist { + _ = saveDNSMode(mode) + } + return mode +} + +// --------------------------------------------------------------------- +// EN: `saveDNSMode` saves dns mode to persistent storage. +// RU: `saveDNSMode` - сохраняет dns mode в постоянное хранилище. +// --------------------------------------------------------------------- +func saveDNSMode(mode DNSMode) error { + mode.Mode = normalizeDNSResolverMode(mode.Mode, mode.ViaSmartDNS) + mode.ViaSmartDNS = mode.Mode != DNSModeDirect + mode.SmartDNSAddr = normalizeSmartDNSAddr(mode.SmartDNSAddr) + if mode.SmartDNSAddr == "" { + mode.SmartDNSAddr = resolveDefaultSmartDNSAddr() + } + + if err := os.MkdirAll(stateDir, 0o755); err != nil { + return err + } + tmp := dnsModePath + ".tmp" + b, err := json.MarshalIndent(mode, "", " ") + if err != nil { + return err + } + if err := os.WriteFile(tmp, b, 0o644); err != nil { + return err + } + return os.Rename(tmp, dnsModePath) +} + +// --------------------------------------------------------------------- +// EN: `normalizeDNSResolverMode` normalizes dns resolver mode values. +// RU: `normalizeDNSResolverMode` - нормализует значения режима dns резолвера. +// --------------------------------------------------------------------- +func normalizeDNSResolverMode(mode DNSResolverMode, viaSmartDNS bool) DNSResolverMode { + switch DNSResolverMode(strings.ToLower(strings.TrimSpace(string(mode)))) { + case DNSModeDirect: + return DNSModeDirect + case DNSModeSmartDNS: + // Legacy value: map old SmartDNS-only selection into hybrid wildcard mode. + return DNSModeHybridWildcard + case DNSModeHybridWildcard, DNSResolverMode("hybrid"): + return DNSModeHybridWildcard + default: + if viaSmartDNS { + return DNSModeHybridWildcard + } + return DNSModeDirect + } +} + +// --------------------------------------------------------------------- +// EN: `smartDNSAddr` contains core logic for smart d n s addr. +// RU: `smartDNSAddr` - содержит основную логику для smart d n s addr. +// --------------------------------------------------------------------- +func smartDNSAddr() string { + return loadDNSMode().SmartDNSAddr +} + +// --------------------------------------------------------------------- +// EN: `smartDNSForced` contains core logic for smart d n s forced. +// RU: `smartDNSForced` - содержит основную логику для smart d n s forced. +// --------------------------------------------------------------------- +func smartDNSForced() bool { + v := strings.TrimSpace(strings.ToLower(os.Getenv(smartDNSForceEnv))) + switch v { + case "1", "true", "yes", "on": + return true + default: + return false + } +} + +// --------------------------------------------------------------------- +// EN: `smartdnsUnitState` contains core logic for smartdns unit state. +// RU: `smartdnsUnitState` - содержит основную логику для smartdns unit state. +// --------------------------------------------------------------------- +func smartdnsUnitState() string { + stdout, _, _, _ := runCommand("systemctl", "is-active", "smartdns-local.service") + st := strings.TrimSpace(stdout) + if st == "" { + return "unknown" + } + return st +} + +// --------------------------------------------------------------------- +// EN: `runSmartdnsUnitAction` runs the workflow for smartdns unit action. +// RU: `runSmartdnsUnitAction` - запускает рабочий процесс для smartdns unit action. +// --------------------------------------------------------------------- +func runSmartdnsUnitAction(action string) cmdResult { + stdout, stderr, exitCode, err := runCommand("systemctl", action, "smartdns-local.service") + res := cmdResult{ + OK: err == nil && exitCode == 0, + ExitCode: exitCode, + Stdout: stdout, + Stderr: stderr, + } + if err != nil { + res.Message = err.Error() + } else { + res.Message = "smartdns " + action + " done" + } + return res +} + +// --------------------------------------------------------------------- +// EN: `resolveDefaultSmartDNSAddr` resolves default smart d n s addr into concrete values. +// RU: `resolveDefaultSmartDNSAddr` - резолвит default smart d n s addr в конкретные значения. +// --------------------------------------------------------------------- +func resolveDefaultSmartDNSAddr() string { + if v := strings.TrimSpace(os.Getenv(smartDNSAddrEnv)); v != "" { + if addr := normalizeSmartDNSAddr(v); addr != "" { + return addr + } + } + for _, path := range []string{ + "/opt/stack/adguardapp/smartdns.conf", + "/etc/selective-vpn/smartdns.conf", + } { + if addr := smartDNSAddrFromConfig(path); addr != "" { + return addr + } + } + return smartDNSDefaultAddr +} + +// --------------------------------------------------------------------- +// EN: `smartDNSAddrFromConfig` loads smart d n s addr from config. +// RU: `smartDNSAddrFromConfig` - загружает smart d n s addr из конфига. +// --------------------------------------------------------------------- +func smartDNSAddrFromConfig(path string) string { + data, err := os.ReadFile(path) + if err != nil { + return "" + } + for _, ln := range strings.Split(string(data), "\n") { + s := strings.TrimSpace(ln) + if s == "" || strings.HasPrefix(s, "#") { + continue + } + if !strings.HasPrefix(strings.ToLower(s), "bind ") { + continue + } + parts := strings.Fields(s) + if len(parts) < 2 { + continue + } + if addr := normalizeSmartDNSAddr(parts[1]); addr != "" { + return addr + } + } + return "" +} + +// --------------------------------------------------------------------- +// EN: `normalizeDNSUpstream` parses dns upstream and returns normalized values. +// RU: `normalizeDNSUpstream` - парсит dns upstream и возвращает нормализованные значения. +// --------------------------------------------------------------------- +func normalizeDNSUpstream(raw string, defaultPort string) string { + s := strings.TrimSpace(raw) + if s == "" { + return "" + } + + s = strings.TrimPrefix(s, "udp://") + s = strings.TrimPrefix(s, "tcp://") + + if strings.Contains(s, "#") { + parts := strings.SplitN(s, "#", 2) + host := strings.Trim(strings.TrimSpace(parts[0]), "[]") + port := strings.TrimSpace(parts[1]) + if host == "" { + return "" + } + if port == "" { + port = defaultPort + } + return host + "#" + port + } + + if host, port, err := net.SplitHostPort(s); err == nil { + host = strings.Trim(strings.TrimSpace(host), "[]") + port = strings.TrimSpace(port) + if host == "" { + return "" + } + if port == "" { + port = defaultPort + } + return host + "#" + port + } + + if strings.Count(s, ":") == 1 { + parts := strings.SplitN(s, ":", 2) + host := strings.TrimSpace(parts[0]) + port := strings.TrimSpace(parts[1]) + if host != "" && port != "" { + return host + "#" + port + } + } + + return s +} + +// --------------------------------------------------------------------- +// EN: `normalizeSmartDNSAddr` parses smart d n s addr and returns normalized values. +// RU: `normalizeSmartDNSAddr` - парсит smart d n s addr и возвращает нормализованные значения. +// --------------------------------------------------------------------- +func normalizeSmartDNSAddr(raw string) string { + s := normalizeDNSUpstream(raw, "6053") + if s == "" { + return "" + } + if strings.Contains(s, "#") { + return s + } + return s + "#6053" +} diff --git a/selective-vpn-api/app/domains_handlers.go b/selective-vpn-api/app/domains_handlers.go new file mode 100644 index 0000000..c6f87e0 --- /dev/null +++ b/selective-vpn-api/app/domains_handlers.go @@ -0,0 +1,184 @@ +package app + +import ( + "encoding/json" + "io" + "io/fs" + "net/http" + "os" + "path/filepath" + "strings" +) + +// --------------------------------------------------------------------- +// domains editor + smartdns wildcards +// --------------------------------------------------------------------- + +// EN: Domain and SmartDNS configuration endpoints. +// EN: Provides CRUD-style file access for domain lists, current nft/ipset table dump, +// EN: and persisted SmartDNS wildcard configuration. +// RU: Эндпоинты конфигурации доменов и SmartDNS. +// RU: Предоставляет доступ к файлам списков доменов, дамп текущей таблицы nft/ipset +// RU: и сохранение конфигурации wildcard-доменов SmartDNS. + +var domainFiles = map[string]string{ + "bases": domainDir + "/bases.txt", + "meta": domainDir + "/meta-special.txt", + "subs": domainDir + "/subs.txt", + "static": staticIPsFile, + "last-ips-map": lastIPsMapPath, + "last-ips-map-direct": lastIPsMapDirect, + "last-ips-map-wildcard": lastIPsMapDyn, +} + +// --------------------------------------------------------------------- +// domains table +// --------------------------------------------------------------------- + +// GET /api/v1/domains/table -> { "lines": [ ... ] } +func handleDomainsTable(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + stdout, _, _, err := runCommand("ipset", "list", "agvpn4") + lines := []string{} + if err == nil { + for _, l := range strings.Split(stdout, "\n") { + l = strings.TrimRight(l, "\r") + if l != "" { + lines = append(lines, l) + } + } + } + writeJSON(w, http.StatusOK, map[string]any{"lines": lines}) +} + +// --------------------------------------------------------------------- +// domains file +// --------------------------------------------------------------------- + +// GET /api/v1/domains/file?name=bases|meta|subs|static|smartdns|last-ips-map|last-ips-map-direct|last-ips-map-wildcard +// POST /api/v1/domains/file { "name": "...", "content": "..." } +func handleDomainsFile(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + name := strings.TrimSpace(r.URL.Query().Get("name")) + if name == "smartdns" { + domains, source := loadSmartDNSWildcardDomainsState(nil) + writeJSON(w, http.StatusOK, map[string]string{ + "content": renderSmartDNSDomainsContent(domains), + "source": source, + }) + return + } + path, ok := domainFiles[name] + if !ok { + http.Error(w, "unknown file name", http.StatusBadRequest) + return + } + source := "file" + if strings.HasPrefix(name, "last-ips-map") { + source = "artifact" + } + data, err := os.ReadFile(path) + if err != nil { + if !os.IsNotExist(err) { + http.Error(w, "read error", http.StatusInternalServerError) + return + } + switch name { + case "bases", "meta", "subs": + // fallback to embedded seed + embedName := name + ".txt" + if name == "meta" { + embedName = "meta-special.txt" + } + data, _ = fs.ReadFile(embeddedDomains, "assets/domains/"+embedName) + source = "embedded" + default: + data = []byte{} + } + } + writeJSON(w, http.StatusOK, map[string]string{ + "content": string(data), + "source": source, + }) + case http.MethodPost: + var body struct { + Name string `json:"name"` + Content string `json:"content"` + } + if r.Body != nil { + defer r.Body.Close() + if err := json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&body); err != nil { + http.Error(w, "bad json", http.StatusBadRequest) + return + } + } + if strings.TrimSpace(body.Name) == "smartdns" { + domains := parseSmartDNSDomainsContent(body.Content) + if err := saveSmartDNSWildcardDomainsState(domains); err != nil { + http.Error(w, "write error", http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, map[string]string{"status": "ok"}) + return + } + if body.Name == "last-ips-map-direct" || body.Name == "last-ips-map-wildcard" { + http.Error(w, "read-only file name", http.StatusBadRequest) + return + } + path, ok := domainFiles[strings.TrimSpace(body.Name)] + if !ok { + http.Error(w, "unknown file name", http.StatusBadRequest) + return + } + _ = os.MkdirAll(filepath.Dir(path), 0o755) + if err := os.WriteFile(path, []byte(body.Content), 0o644); err != nil { + http.Error(w, "write error", http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, map[string]string{"status": "ok"}) + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } +} + +// --------------------------------------------------------------------- +// smartdns wildcards +// --------------------------------------------------------------------- + +func handleSmartdnsWildcards(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + payload := struct { + Domains []string `json:"domains"` + }{Domains: readSmartDNSWildcardDomains()} + writeJSON(w, http.StatusOK, payload) + case http.MethodPost: + var payload struct { + Domains []string `json:"domains"` + } + if r.Body != nil { + defer r.Body.Close() + if err := json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&payload); err != nil { + http.Error(w, "bad json", http.StatusBadRequest) + return + } + } + if err := saveSmartDNSWildcardDomainsState(payload.Domains); err != nil { + http.Error(w, "write error", http.StatusInternalServerError) + return + } + writeJSON(w, http.StatusOK, map[string]string{"status": "ok"}) + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } +} + +func readSmartDNSWildcardDomains() []string { + domains, _ := loadSmartDNSWildcardDomainsState(nil) + return domains +} diff --git a/selective-vpn-api/app/events_bus.go b/selective-vpn-api/app/events_bus.go new file mode 100644 index 0000000..1204930 --- /dev/null +++ b/selective-vpn-api/app/events_bus.go @@ -0,0 +1,109 @@ +package app + +import ( + "os" + "strconv" + "strings" + "sync" + "time" +) + +// --------------------------------------------------------------------- +// события / event bus +// --------------------------------------------------------------------- + +// EN: In-memory bounded event bus used for SSE replay and polling watchers. +// EN: It keeps only the latest N events and assigns monotonically increasing IDs. +// RU: Ограниченная in-memory шина событий для SSE-реплея и фоновых вотчеров. +// RU: Хранит только последние N событий и присваивает монотонно растущие ID. +type eventBus struct { + mu sync.Mutex + cond *sync.Cond + buf []Event + cap int + next int64 +} + +// --------------------------------------------------------------------- +// EN: `newEventBus` creates a new instance for event bus. +// RU: `newEventBus` - создает новый экземпляр для event bus. +// --------------------------------------------------------------------- +func newEventBus(capacity int) *eventBus { + if capacity < 16 { + capacity = 16 + } + b := &eventBus{ + cap: capacity, + buf: make([]Event, 0, capacity), + } + b.cond = sync.NewCond(&b.mu) + return b +} + +// --------------------------------------------------------------------- +// EN: `push` contains core logic for push. +// RU: `push` - содержит основную логику для push. +// --------------------------------------------------------------------- +func (b *eventBus) push(kind string, data interface{}) Event { + b.mu.Lock() + defer b.mu.Unlock() + + b.next++ + evt := Event{ + ID: b.next, + Kind: kind, + Ts: time.Now().UTC().Format(time.RFC3339Nano), + Data: data, + } + + if len(b.buf) >= b.cap { + b.buf = b.buf[1:] + } + b.buf = append(b.buf, evt) + b.cond.Broadcast() + return evt +} + +// --------------------------------------------------------------------- +// EN: `since` contains core logic for since. +// RU: `since` - содержит основную логику для since. +// --------------------------------------------------------------------- +func (b *eventBus) since(id int64) []Event { + b.mu.Lock() + defer b.mu.Unlock() + return b.sinceLocked(id) +} + +// --------------------------------------------------------------------- +// EN: `sinceLocked` contains core logic for since locked. +// RU: `sinceLocked` - содержит основную логику для since locked. +// --------------------------------------------------------------------- +func (b *eventBus) sinceLocked(id int64) []Event { + if len(b.buf) == 0 { + return nil + } + var out []Event + for _, ev := range b.buf { + if ev.ID > id { + out = append(out, ev) + } + } + return out +} + +// --------------------------------------------------------------------- +// env helpers +// --------------------------------------------------------------------- + +// EN: Positive integer env reader with safe default fallback. +// RU: Чтение положительного целого из env с безопасным fallback на дефолт. +func envInt(key string, def int) int { + if v := strings.TrimSpace(os.Getenv(key)); v != "" { + if n, err := strconv.Atoi(v); err == nil && n > 0 { + return n + } + } + return def +} + +var events = newEventBus(envInt("SVPN_EVENTS_CAP", defaultEventsCapacity)) diff --git a/selective-vpn-api/app/events_handlers.go b/selective-vpn-api/app/events_handlers.go new file mode 100644 index 0000000..4feac10 --- /dev/null +++ b/selective-vpn-api/app/events_handlers.go @@ -0,0 +1,111 @@ +package app + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "time" +) + +// --------------------------------------------------------------------- +// events (SSE) +// --------------------------------------------------------------------- + +// EN: Server-Sent Events transport with replay support via Last-Event-ID/since, +// EN: heartbeat pings, and periodic polling of the in-memory event buffer. +// RU: Транспорт Server-Sent Events с поддержкой реплея через Last-Event-ID/since, +// RU: heartbeat-пингами и периодическим опросом in-memory буфера событий. + +// --------------------------------------------------------------------- +// SSE helpers +// --------------------------------------------------------------------- + +func parseSinceID(r *http.Request) int64 { + sinceStr := strings.TrimSpace(r.URL.Query().Get("since")) + if sinceStr == "" { + sinceStr = strings.TrimSpace(r.Header.Get("Last-Event-ID")) + } + if sinceStr == "" { + return 0 + } + if v, err := strconv.ParseInt(sinceStr, 10, 64); err == nil && v >= 0 { + return v + } + return 0 +} + +// --------------------------------------------------------------------- +// SSE stream handler +// --------------------------------------------------------------------- + +func handleEventsStream(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "streaming unsupported", http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + + ctx := r.Context() + since := parseSinceID(r) + + send := func(ev Event) error { + payload, err := json.Marshal(ev) + if err != nil { + return err + } + if _, err := fmt.Fprintf(w, "id: %d\nevent: %s\ndata: %s\n\n", ev.ID, ev.Kind, string(payload)); err != nil { + return err + } + flusher.Flush() + return nil + } + + // initial replay + for _, ev := range events.since(since) { + if err := send(ev); err != nil { + return + } + since = ev.ID + } + + // polling loop; lightweight for localhost + pollEvery := 500 * time.Millisecond + heartbeat := time.Duration(envInt("SVPN_EVENTS_HEARTBEAT_SEC", defaultHeartbeatSeconds)) * time.Second + pollTicker := time.NewTicker(pollEvery) + pingTicker := time.NewTicker(heartbeat) + defer pollTicker.Stop() + defer pingTicker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-pingTicker.C: + _, _ = io.WriteString(w, ": ping\n\n") + flusher.Flush() + case <-pollTicker.C: + evs := events.since(since) + if len(evs) == 0 { + continue + } + for _, ev := range evs { + if err := send(ev); err != nil { + return + } + since = ev.ID + } + } + } +} diff --git a/selective-vpn-api/app/http_helpers.go b/selective-vpn-api/app/http_helpers.go new file mode 100644 index 0000000..95d132f --- /dev/null +++ b/selective-vpn-api/app/http_helpers.go @@ -0,0 +1,59 @@ +package app + +import ( + "encoding/json" + "log" + "net/http" + "time" +) + +// --------------------------------------------------------------------- +// HTTP helpers +// --------------------------------------------------------------------- + +// EN: Common HTTP helpers used by all endpoint groups for consistent JSON output, +// EN: lightweight request timing logs, and health probing. +// RU: Общие HTTP-хелперы для всех групп эндпоинтов: единый JSON-ответ, +// RU: лёгкое логирование длительности запросов и health-check. + +// --------------------------------------------------------------------- +// request logging +// --------------------------------------------------------------------- + +func logRequests(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + next.ServeHTTP(w, r) + log.Printf("%s %s %s", r.Method, r.URL.Path, time.Since(start)) + }) +} + +// --------------------------------------------------------------------- +// JSON response helper +// --------------------------------------------------------------------- + +func writeJSON(w http.ResponseWriter, status int, v any) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(status) + if v == nil { + return + } + if err := json.NewEncoder(w).Encode(v); err != nil { + log.Printf("writeJSON error: %v", err) + } +} + +// --------------------------------------------------------------------- +// health endpoint +// --------------------------------------------------------------------- + +func handleHealthz(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + writeJSON(w, http.StatusOK, map[string]string{ + "status": "ok", + "time": time.Now().Format(time.RFC3339), + }) +} diff --git a/selective-vpn-api/app/nft_update.go b/selective-vpn-api/app/nft_update.go new file mode 100644 index 0000000..4fab701 --- /dev/null +++ b/selective-vpn-api/app/nft_update.go @@ -0,0 +1,400 @@ +package app + +import ( + "bytes" + "context" + "errors" + "fmt" + "net/netip" + "os/exec" + "sort" + "strings" + "time" + + "github.com/cenkalti/backoff/v4" +) + +// --------------------------------------------------------------------- +// nft update helpers +// --------------------------------------------------------------------- + +// EN: NFT set update strategy with interval compression and two execution modes: +// EN: atomic transaction first, then chunked fallback with per-IP recovery. +// RU: Стратегия обновления NFT-набора с компрессией интервалов и двумя режимами: +// RU: сначала атомарная транзакция, затем chunked fallback с поштучным восстановлением. + +func nftLog(format string, args ...any) { + appendTraceLine("routes", fmt.Sprintf(format, args...)) +} + +// --------------------------------------------------------------------- +// interval compression +// --------------------------------------------------------------------- + +// compressIPIntervals убирает: +// - дубликаты строк +// - подсети, целиком покрытые более широкими подсетями +// - одиночные IP, попадающие в уже имеющиеся подсети +func compressIPIntervals(ips []string) []string { + // чтобы не гонять дубликаты строк + seen := make(map[string]struct{}) + + type prefixItem struct { + p netip.Prefix + raw string + } + type addrItem struct { + a netip.Addr + raw string + } + + var prefixes []prefixItem + var addrs []addrItem + + for _, s := range ips { + s = strings.TrimSpace(s) + if s == "" { + continue + } + if _, ok := seen[s]; ok { + continue + } + seen[s] = struct{}{} + + if strings.Contains(s, "/") { + p, err := netip.ParsePrefix(s) + if err != nil { + // если формат кривой — просто пропускаем + continue + } + prefixes = append(prefixes, prefixItem{p: p, raw: s}) + } else { + a, err := netip.ParseAddr(s) + if err != nil { + continue + } + addrs = append(addrs, addrItem{a: a, raw: s}) + } + } + + // 1) Убираем подсети, полностью покрытые более крупными подсетями. + // + // Сначала сортируем по: + // - адресу + // - длине префикса (меньший Bits = более широкая сеть) — раньше + sort.Slice(prefixes, func(i, j int) bool { + ai := prefixes[i].p.Addr() + aj := prefixes[j].p.Addr() + if ai == aj { + return prefixes[i].p.Bits() < prefixes[j].p.Bits() + } + return ai.Less(aj) + }) + + var keptPrefixes []prefixItem + for _, pi := range prefixes { + covered := false + for _, kp := range keptPrefixes { + // если более крупная сеть kp уже покрывает эту — пропускаем + if kp.p.Bits() <= pi.p.Bits() && kp.p.Contains(pi.p.Addr()) { + covered = true + break + } + } + if !covered { + keptPrefixes = append(keptPrefixes, pi) + } + } + + var keptAddrs []addrItem + for _, ai := range addrs { + inNet := false + for _, kp := range keptPrefixes { + if kp.p.Contains(ai.a) { + inNet = true + break + } + } + if !inNet { + keptAddrs = append(keptAddrs, ai) + } + } + + // 3) Собираем финальный список строк + out := make([]string, 0, len(keptPrefixes)+len(keptAddrs)) + for _, ai := range keptAddrs { + out = append(out, ai.raw) + } + for _, pi := range keptPrefixes { + out = append(out, pi.raw) + } + + return out +} + +// --------------------------------------------------------------------- +// smart update strategy +// --------------------------------------------------------------------- + +// умный апдейтер: сначала atomic, при фейле – fallback на chunked +func nftUpdateIPsSmart(ctx context.Context, ips []string, progressCb ProgressCallback) error { + return nftUpdateSetIPsSmart(ctx, "agvpn4", ips, progressCb) +} + +// nftUpdateSetIPsSmart — тот же апдейтер, но для произвольного nft set. +func nftUpdateSetIPsSmart(ctx context.Context, setName string, ips []string, progressCb ProgressCallback) error { + setName = strings.TrimSpace(setName) + if setName == "" { + setName = "agvpn4" + } + + if len(ips) == 0 { + if progressCb != nil { + progressCb(100, "nothing to update") + } + return nil + } + + // Сжимаем IP / подсети, убираем пересечения и дубликаты + origCount := len(ips) + ips = compressIPIntervals(ips) + if len(ips) != origCount { + nftLog( + "compress(%s): %d -> %d IP elements (removed %d covered/duplicate entries)", + setName, origCount, len(ips), origCount-len(ips), + ) + } + + if len(ips) == 0 { + if progressCb != nil { + progressCb(100, "nothing to update after compression") + } + return nil + } + + nftLog("nftUpdateSetIPsSmart(%s): start, ips=%d", setName, len(ips)) + + // 1) atomic транзакция через nft -f - + if err := nftAtomicUpdateWithProgress(ctx, setName, ips, progressCb); err == nil { + return nil + } else { + // если контекст умер – дальше не идём + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + nftLog("atomic update cancelled (%s): %v", setName, err) + return err + } + nftLog("atomic nft update failed (%s): %v; falling back to chunked mode", setName, err) + if progressCb != nil { + progressCb(0, "Falling back to non-atomic update") + } + } + + // 2) fallback: flush + chunked с поштучным фолбэком + return nftChunkedUpdateWithFallback(ctx, setName, ips, progressCb) +} + +// --------------------------------------------------------------------- +// atomic updater +// --------------------------------------------------------------------- + +// атомарный апдейт через один nft-транзакционный скрипт +func nftAtomicUpdateWithProgress(ctx context.Context, setName string, ips []string, progressCb ProgressCallback) error { + if len(ips) == 0 { + if progressCb != nil { + progressCb(100, "nothing to update") + } + return nil + } + + sort.Strings(ips) // стабильность + + total := len(ips) + chunkSize := 500 // старт + + bo := backoff.NewExponentialBackOff() + bo.InitialInterval = 500 * time.Millisecond + bo.MaxInterval = 10 * time.Second + bo.MaxElapsedTime = 2 * time.Minute + + return backoff.Retry(func() error { + select { + case <-ctx.Done(): + if progressCb != nil { + progressCb(0, "Cancelled by context") + } + return ctx.Err() + default: + } + + var script strings.Builder + script.WriteString("flush set inet agvpn " + setName + "\n") + + processed := 0 + chunksTotal := (len(ips) + chunkSize - 1) / chunkSize + + for i := 0; i < len(ips); i += chunkSize { + end := i + chunkSize + if end > len(ips) { + end = len(ips) + } + chunk := ips[i:end] + + script.WriteString("add element inet agvpn " + setName + " { ") + script.WriteString(strings.Join(chunk, ", ")) + script.WriteString(" }\n") + + processed += len(chunk) + if progressCb != nil { + percent := processed * 100 / total + progressCb(percent, fmt.Sprintf( + "Preparing chunk %d/%d (%d/%d IPs)", + i/chunkSize+1, chunksTotal, processed, total, + )) + } + } + + if progressCb != nil { + progressCb(90, "Executing nft transaction...") + } + + cmd := exec.CommandContext(ctx, "nft", "-f", "-") + cmd.Stdin = strings.NewReader(script.String()) + + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + err := cmd.Run() + if err == nil { + nftLog("nft atomic transaction success (%s): %d IPs added", setName, len(ips)) + if progressCb != nil { + progressCb(100, "Update complete") + } + return nil + } + + errStr := stderr.String() + nftLog("nft atomic transaction failed (%s): err=%v, stderr=%q", setName, err, errStr) + + // Ошибки, требующие уменьшения чанка + if strings.Contains(errStr, "too many elements") || + strings.Contains(errStr, "out of memory") || + strings.Contains(errStr, "interval overlaps") || + strings.Contains(errStr, "conflicting intervals") { + + newSize := chunkSize / 2 + if newSize < 100 { + newSize = 100 + } + if newSize == chunkSize { + // дальше делить некуда — Permanent → fallback + return backoff.Permanent(fmt.Errorf("atomic nft cannot shrink further: %w", err)) + } + + nftLog("reducing atomic chunk size from %d to %d and retrying", chunkSize, newSize) + chunkSize = newSize + + if progressCb != nil { + progressCb(0, fmt.Sprintf("Retrying atomic with smaller chunks (%d IPs)", chunkSize)) + } + + return fmt.Errorf("retry atomic with smaller chunks") + } + + // Другие ошибки — Permanent (переход к chunked) + return backoff.Permanent(fmt.Errorf("nft atomic transaction failed: %w", err)) + }, bo) +} + +// --------------------------------------------------------------------- +// chunked fallback updater +// --------------------------------------------------------------------- + +// nftChunkedUpdateWithFallback — fallback-режим: flush + чанки + поштучно при ошибках +func nftChunkedUpdateWithFallback(ctx context.Context, setName string, ips []string, progressCb ProgressCallback) error { + if len(ips) == 0 { + if progressCb != nil { + progressCb(100, "nothing to update") + } + return nil + } + + sort.Strings(ips) + + total := len(ips) + chunkSize := 500 + + // flush + _, stderr, _, err := runCommandTimeout(10*time.Second, + "nft", "flush", "set", "inet", "agvpn", setName) + if err != nil { + return fmt.Errorf("flush set failed: %v (%s)", err, stderr) + } + + processed := 0 + + for i := 0; i < len(ips); i += chunkSize { + select { + case <-ctx.Done(): + if progressCb != nil { + progressCb(0, "Cancelled during chunked update") + } + return ctx.Err() + default: + } + + end := i + chunkSize + if end > len(ips) { + end = len(ips) + } + chunk := ips[i:end] + + cmdArgs := []string{ + "nft", "add", "element", "inet", "agvpn", setName, + "{ " + strings.Join(chunk, ", ") + " }", + } + + cmdName := cmdArgs[0] + cmdRest := cmdArgs[1:] + + _, stderr, _, err := runCommandTimeout(15*time.Second, cmdName, cmdRest...) + if err != nil { + // типичные ошибки → поштучно + if strings.Contains(stderr, "interval overlaps") || + strings.Contains(stderr, "too many elements") || + strings.Contains(stderr, "out of memory") || + strings.Contains(stderr, "conflicting intervals") { + + nftLog("chunk failed (%d IPs), fallback per-ip", len(chunk)) + if progressCb != nil { + progressCb(processed*100/total, + fmt.Sprintf("Chunk failed -> adding %d IPs one by one", len(chunk))) + } + + for _, ip := range chunk { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + _, _, _, _ = runCommandTimeout(5*time.Second, + "nft", "add", "element", "inet", "agvpn", setName, "{ "+ip+" }") + } + } else { + return fmt.Errorf("nft chunk add failed: %v (%s)", err, stderr) + } + } + + processed += len(chunk) + if progressCb != nil { + percent := processed * 100 / total + progressCb(percent, fmt.Sprintf("Added %d/%d IPs", processed, total)) + } + } + + if progressCb != nil { + progressCb(100, "chunked update complete") + } + nftLog("nft chunked update success (%s): %d IPs", setName, len(ips)) + return nil +} diff --git a/selective-vpn-api/app/resolver.go b/selective-vpn-api/app/resolver.go new file mode 100644 index 0000000..ee07ecb --- /dev/null +++ b/selective-vpn-api/app/resolver.go @@ -0,0 +1,1170 @@ +package app + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net" + "net/netip" + "os" + "regexp" + "sort" + "strconv" + "strings" + "time" +) + +// --------------------------------------------------------------------- +// Go resolver +// --------------------------------------------------------------------- + +// EN: Go-based domain resolver pipeline used by routes update. +// EN: Handles cache reuse, concurrent DNS lookups, PTR labeling for static entries, +// EN: and returns deduplicated IP sets plus IP-to-label mapping artifacts. +// RU: Go-резолвер, используемый пайплайном обновления маршрутов. +// RU: Обрабатывает кэш, конкурентные DNS-запросы, PTR-лейблы для static entries +// RU: и возвращает дедуплицированный список IP и IP-to-label mapping. + +type dnsErrorKind string + +const ( + dnsErrorNXDomain dnsErrorKind = "nxdomain" + dnsErrorTimeout dnsErrorKind = "timeout" + dnsErrorTemporary dnsErrorKind = "temporary" + dnsErrorOther dnsErrorKind = "other" +) + +type dnsUpstreamMetrics struct { + Attempts int + OK int + NXDomain int + Timeout int + Temporary int + Other int +} + +type dnsMetrics struct { + Attempts int + OK int + NXDomain int + Timeout int + Temporary int + Other int + + PerUpstream map[string]*dnsUpstreamMetrics +} + +func (m *dnsMetrics) ensureUpstream(upstream string) *dnsUpstreamMetrics { + if m.PerUpstream == nil { + m.PerUpstream = map[string]*dnsUpstreamMetrics{} + } + if us, ok := m.PerUpstream[upstream]; ok { + return us + } + us := &dnsUpstreamMetrics{} + m.PerUpstream[upstream] = us + return us +} + +func (m *dnsMetrics) addSuccess(upstream string) { + m.Attempts++ + m.OK++ + us := m.ensureUpstream(upstream) + us.Attempts++ + us.OK++ +} + +func (m *dnsMetrics) addError(upstream string, kind dnsErrorKind) { + m.Attempts++ + us := m.ensureUpstream(upstream) + us.Attempts++ + switch kind { + case dnsErrorNXDomain: + m.NXDomain++ + us.NXDomain++ + case dnsErrorTimeout: + m.Timeout++ + us.Timeout++ + case dnsErrorTemporary: + m.Temporary++ + us.Temporary++ + default: + m.Other++ + us.Other++ + } +} + +func (m *dnsMetrics) merge(other dnsMetrics) { + m.Attempts += other.Attempts + m.OK += other.OK + m.NXDomain += other.NXDomain + m.Timeout += other.Timeout + m.Temporary += other.Temporary + m.Other += other.Other + + for upstream, src := range other.PerUpstream { + dst := m.ensureUpstream(upstream) + dst.Attempts += src.Attempts + dst.OK += src.OK + dst.NXDomain += src.NXDomain + dst.Timeout += src.Timeout + dst.Temporary += src.Temporary + dst.Other += src.Other + } +} + +func (m dnsMetrics) totalErrors() int { + return m.NXDomain + m.Timeout + m.Temporary + m.Other +} + +func (m dnsMetrics) formatPerUpstream() string { + if len(m.PerUpstream) == 0 { + return "" + } + keys := make([]string, 0, len(m.PerUpstream)) + for k := range m.PerUpstream { + keys = append(keys, k) + } + sort.Strings(keys) + parts := make([]string, 0, len(keys)) + for _, k := range keys { + v := m.PerUpstream[k] + parts = append(parts, fmt.Sprintf("%s{attempts=%d ok=%d nxdomain=%d timeout=%d temporary=%d other=%d}", k, v.Attempts, v.OK, v.NXDomain, v.Timeout, v.Temporary, v.Other)) + } + return strings.Join(parts, "; ") +} + +type wildcardMatcher struct { + exact map[string]struct{} + suffix []string +} + +func normalizeWildcardDomain(raw string) string { + d := strings.TrimSpace(strings.SplitN(raw, "#", 2)[0]) + d = strings.ToLower(d) + d = strings.TrimPrefix(d, "*.") + d = strings.TrimPrefix(d, ".") + d = strings.TrimSuffix(d, ".") + return d +} + +func newWildcardMatcher(domains []string) wildcardMatcher { + seen := map[string]struct{}{} + m := wildcardMatcher{exact: map[string]struct{}{}} + for _, raw := range domains { + d := normalizeWildcardDomain(raw) + if d == "" { + continue + } + if _, ok := seen[d]; ok { + continue + } + seen[d] = struct{}{} + m.exact[d] = struct{}{} + m.suffix = append(m.suffix, "."+d) + } + return m +} + +func (m wildcardMatcher) match(host string) bool { + if len(m.exact) == 0 { + return false + } + h := strings.TrimSuffix(strings.ToLower(strings.TrimSpace(host)), ".") + if h == "" { + return false + } + if _, ok := m.exact[h]; ok { + return true + } + for _, suffix := range m.suffix { + if strings.HasSuffix(h, suffix) { + return true + } + } + return false +} + +// --------------------------------------------------------------------- +// EN: `runResolverJob` runs the workflow for resolver job. +// RU: `runResolverJob` - запускает рабочий процесс для resolver job. +// --------------------------------------------------------------------- +func runResolverJob(opts ResolverOpts, logf func(string, ...any)) (resolverResult, error) { + res := resolverResult{ + DomainCache: map[string]any{}, + PtrCache: map[string]any{}, + } + + domains := loadList(opts.DomainsPath) + metaSpecial := loadList(opts.MetaPath) + staticLines := readLinesAllowMissing(opts.StaticPath) + wildcards := newWildcardMatcher(opts.SmartDNSWildcards) + + cfg := loadDNSConfig(opts.DNSConfigPath, logf) + if !smartDNSForced() { + cfg.Mode = normalizeDNSResolverMode(opts.Mode, opts.ViaSmartDNS) + } + if addr := normalizeSmartDNSAddr(opts.SmartDNSAddr); addr != "" { + cfg.SmartDNS = addr + } + if cfg.SmartDNS == "" { + cfg.SmartDNS = smartDNSAddr() + } + if cfg.Mode == DNSModeSmartDNS && cfg.SmartDNS != "" { + cfg.Default = []string{cfg.SmartDNS} + cfg.Meta = []string{cfg.SmartDNS} + } + if logf != nil { + switch cfg.Mode { + case DNSModeSmartDNS: + logf("resolver dns mode: SmartDNS-only (%s)", cfg.SmartDNS) + case DNSModeHybridWildcard: + logf("resolver dns mode: hybrid_wildcard smartdns=%s wildcards=%d default=%v meta=%v", cfg.SmartDNS, len(wildcards.exact), cfg.Default, cfg.Meta) + default: + logf("resolver dns mode: direct default=%v meta=%v", cfg.Default, cfg.Meta) + } + } + + ttl := opts.TTL + if ttl <= 0 { + ttl = 24 * 3600 + } + // safety clamp: 60s .. 24h + if ttl < 60 { + ttl = 60 + } + if ttl > 24*3600 { + ttl = 24 * 3600 + } + workers := opts.Workers + if workers <= 0 { + workers = 80 + } + // safety clamp: 1..500 + if workers < 1 { + workers = 1 + } + if workers > 500 { + workers = 500 + } + + domainCache := loadDomainCacheState(opts.CachePath, logf) + ptrCache := loadJSONMap(opts.PtrCachePath) + now := int(time.Now().Unix()) + + cacheSourceForHost := func(host string) domainCacheSource { + switch cfg.Mode { + case DNSModeSmartDNS: + return domainCacheSourceWildcard + case DNSModeHybridWildcard: + if wildcards.match(host) { + return domainCacheSourceWildcard + } + } + return domainCacheSourceDirect + } + + if logf != nil { + logf("resolver start: domains=%d ttl=%ds workers=%d", len(domains), ttl, workers) + } + start := time.Now() + + fresh := map[string][]string{} + var toResolve []string + for _, d := range domains { + source := cacheSourceForHost(d) + if ips, ok := domainCache.get(d, source, now, ttl); ok { + fresh[d] = ips + if logf != nil { + logf("cache hit[%s]: %s -> %v", source, d, ips) + } + continue + } + toResolve = append(toResolve, d) + } + + resolved := map[string][]string{} + for k, v := range fresh { + resolved[k] = v + } + + if logf != nil { + logf("resolve: domains=%d cache_hits=%d to_resolve=%d", len(domains), len(fresh), len(toResolve)) + } + + dnsStats := dnsMetrics{} + + if len(toResolve) > 0 { + type job struct { + host string + } + jobs := make(chan job, len(toResolve)) + results := make(chan struct { + host string + ips []string + stats dnsMetrics + }, len(toResolve)) + + for i := 0; i < workers; i++ { + go func() { + for j := range jobs { + ips, stats := resolveHostGo(j.host, cfg, metaSpecial, wildcards, logf) + results <- struct { + host string + ips []string + stats dnsMetrics + }{j.host, ips, stats} + } + }() + } + for _, h := range toResolve { + jobs <- job{host: h} + } + close(jobs) + for i := 0; i < len(toResolve); i++ { + r := <-results + dnsStats.merge(r.stats) + hostErrors := r.stats.totalErrors() + if hostErrors > 0 && logf != nil { + logf("resolve errors for %s: total=%d nxdomain=%d timeout=%d temporary=%d other=%d", r.host, hostErrors, r.stats.NXDomain, r.stats.Timeout, r.stats.Temporary, r.stats.Other) + } + if len(r.ips) > 0 { + resolved[r.host] = r.ips + source := cacheSourceForHost(r.host) + domainCache.set(r.host, source, r.ips, now) + if logf != nil { + logf("%s -> %v", r.host, r.ips) + } + } else if logf != nil { + logf("%s: no IPs", r.host) + } + } + } + + staticEntries, staticSkipped := parseStaticEntriesGo(staticLines, logf) + staticLabels, ptrLookups, ptrErrors := resolveStaticLabels(staticEntries, cfg, ptrCache, ttl, logf) + + ipSetAll := map[string]struct{}{} + ipSetDirect := map[string]struct{}{} + ipSetWildcard := map[string]struct{}{} + + ipMapAll := map[string]map[string]struct{}{} + ipMapDirect := map[string]map[string]struct{}{} + ipMapWildcard := map[string]map[string]struct{}{} + + add := func(set map[string]struct{}, labels map[string]map[string]struct{}, ip, label string) { + if ip == "" { + return + } + set[ip] = struct{}{} + m := labels[ip] + if m == nil { + m = map[string]struct{}{} + labels[ip] = m + } + m[label] = struct{}{} + } + + isWildcardHost := func(host string) bool { + switch cfg.Mode { + case DNSModeSmartDNS: + return true + case DNSModeHybridWildcard: + return wildcards.match(host) + default: + return false + } + } + + for host, ips := range resolved { + wildcardHost := isWildcardHost(host) + for _, ip := range ips { + add(ipSetAll, ipMapAll, ip, host) + if wildcardHost { + add(ipSetWildcard, ipMapWildcard, ip, host) + } else { + add(ipSetDirect, ipMapDirect, ip, host) + } + } + } + for ipEntry, labels := range staticLabels { + for _, lbl := range labels { + add(ipSetAll, ipMapAll, ipEntry, lbl) + // Static entries are explicit operator rules; keep them in direct set. + add(ipSetDirect, ipMapDirect, ipEntry, lbl) + } + } + + appendMapPairs := func(dst *[][2]string, labelsByIP map[string]map[string]struct{}) { + for ip := range labelsByIP { + labels := labelsByIP[ip] + for lbl := range labels { + *dst = append(*dst, [2]string{ip, lbl}) + } + } + sort.Slice(*dst, func(i, j int) bool { + if (*dst)[i][0] == (*dst)[j][0] { + return (*dst)[i][1] < (*dst)[j][1] + } + return (*dst)[i][0] < (*dst)[j][0] + }) + } + appendIPs := func(dst *[]string, set map[string]struct{}) { + for ip := range set { + *dst = append(*dst, ip) + } + sort.Strings(*dst) + } + + appendMapPairs(&res.IPMap, ipMapAll) + appendMapPairs(&res.DirectIPMap, ipMapDirect) + appendMapPairs(&res.WildcardIPMap, ipMapWildcard) + appendIPs(&res.IPs, ipSetAll) + appendIPs(&res.DirectIPs, ipSetDirect) + appendIPs(&res.WildcardIPs, ipSetWildcard) + + res.DomainCache = domainCache.toMap() + res.PtrCache = ptrCache + + if logf != nil { + dnsErrors := dnsStats.totalErrors() + logf( + "resolve summary: domains=%d cache_hits=%d resolved_now=%d unresolved=%d static_entries=%d static_skipped=%d unique_ips=%d direct_ips=%d wildcard_ips=%d ptr_lookups=%d ptr_errors=%d dns_attempts=%d dns_ok=%d dns_nxdomain=%d dns_timeout=%d dns_temporary=%d dns_other=%d dns_errors=%d duration_ms=%d", + len(domains), + len(fresh), + len(resolved)-len(fresh), + len(domains)-len(resolved), + len(staticEntries), + staticSkipped, + len(res.IPs), + len(res.DirectIPs), + len(res.WildcardIPs), + ptrLookups, + ptrErrors, + dnsStats.Attempts, + dnsStats.OK, + dnsStats.NXDomain, + dnsStats.Timeout, + dnsStats.Temporary, + dnsStats.Other, + dnsErrors, + time.Since(start).Milliseconds(), + ) + if perUpstream := dnsStats.formatPerUpstream(); perUpstream != "" { + logf("resolve dns upstreams: %s", perUpstream) + } + } + return res, nil +} + +// --------------------------------------------------------------------- +// DNS resolve helpers +// --------------------------------------------------------------------- + +func resolveHostGo(host string, cfg dnsConfig, metaSpecial []string, wildcards wildcardMatcher, logf func(string, ...any)) ([]string, dnsMetrics) { + useMeta := false + for _, m := range metaSpecial { + if host == m { + useMeta = true + break + } + } + dnsList := cfg.Default + if useMeta { + dnsList = cfg.Meta + } + switch cfg.Mode { + case DNSModeSmartDNS: + if cfg.SmartDNS != "" { + dnsList = []string{cfg.SmartDNS} + } + case DNSModeHybridWildcard: + if cfg.SmartDNS != "" && wildcards.match(host) { + dnsList = []string{cfg.SmartDNS} + } + } + ips, stats := digA(host, dnsList, 3*time.Second, logf) + out := []string{} + seen := map[string]struct{}{} + for _, ip := range ips { + if isPrivateIPv4(ip) { + continue + } + if _, ok := seen[ip]; !ok { + seen[ip] = struct{}{} + out = append(out, ip) + } + } + return out, stats +} + +// --------------------------------------------------------------------- +// EN: `digA` contains core logic for dig a. +// RU: `digA` - содержит основную логику для dig a. +// --------------------------------------------------------------------- +func digA(host string, dnsList []string, timeout time.Duration, logf func(string, ...any)) ([]string, dnsMetrics) { + var ips []string + stats := dnsMetrics{} + for _, entry := range dnsList { + server, port := splitDNS(entry) + if server == "" { + continue + } + if port == "" { + port = "53" + } + addr := net.JoinHostPort(server, port) + r := &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + d := net.Dialer{} + return d.DialContext(ctx, "udp", addr) + }, + } + ctx, cancel := context.WithTimeout(context.Background(), timeout) + records, err := r.LookupHost(ctx, host) + cancel() + if err != nil { + kind := classifyDNSError(err) + stats.addError(addr, kind) + if logf != nil { + logf("dns warn %s via %s: kind=%s err=%v", host, addr, kind, err) + } + continue + } + stats.addSuccess(addr) + for _, ip := range records { + if isPrivateIPv4(ip) { + continue + } + ips = append(ips, ip) + } + } + return uniqueStrings(ips), stats +} + +func classifyDNSError(err error) dnsErrorKind { + if err == nil { + return dnsErrorOther + } + var dnsErr *net.DNSError + if errors.As(err, &dnsErr) { + if dnsErr.IsNotFound { + return dnsErrorNXDomain + } + if dnsErr.IsTimeout { + return dnsErrorTimeout + } + if dnsErr.IsTemporary { + return dnsErrorTemporary + } + } + msg := strings.ToLower(err.Error()) + switch { + case strings.Contains(msg, "no such host"), strings.Contains(msg, "nxdomain"): + return dnsErrorNXDomain + case strings.Contains(msg, "i/o timeout"), strings.Contains(msg, "timeout"): + return dnsErrorTimeout + case strings.Contains(msg, "temporary"): + return dnsErrorTemporary + default: + return dnsErrorOther + } +} + +// --------------------------------------------------------------------- +// EN: `splitDNS` splits dns into structured parts. +// RU: `splitDNS` - разделяет dns на структурированные части. +// --------------------------------------------------------------------- +func splitDNS(dns string) (string, string) { + if strings.Contains(dns, "#") { + parts := strings.SplitN(dns, "#", 2) + host := strings.TrimSpace(parts[0]) + port := strings.TrimSpace(parts[1]) + if host == "" { + host = "127.0.0.1" + } + if port == "" { + port = "53" + } + return host, port + } + return strings.TrimSpace(dns), "" +} + +// --------------------------------------------------------------------- +// static entries + PTR labels +// --------------------------------------------------------------------- + +func parseStaticEntriesGo(lines []string, logf func(string, ...any)) (entries [][3]string, skipped int) { + for _, ln := range lines { + s := strings.TrimSpace(ln) + if s == "" || strings.HasPrefix(s, "#") { + continue + } + comment := "" + if idx := strings.Index(s, "#"); idx >= 0 { + comment = strings.TrimSpace(s[idx+1:]) + s = strings.TrimSpace(s[:idx]) + } + if s == "" || isPrivateIPv4(s) { + continue + } + + // validate ip/prefix + rawBase := strings.SplitN(s, "/", 2)[0] + if strings.Contains(s, "/") { + if _, err := netip.ParsePrefix(s); err != nil { + skipped++ + if logf != nil { + logf("static skip invalid prefix %q: %v", s, err) + } + continue + } + } else { + if _, err := netip.ParseAddr(rawBase); err != nil { + skipped++ + if logf != nil { + logf("static skip invalid ip %q: %v", s, err) + } + continue + } + } + + entries = append(entries, [3]string{s, rawBase, comment}) + } + return entries, skipped +} + +// --------------------------------------------------------------------- +// EN: `resolveStaticLabels` resolves static labels into concrete values. +// RU: `resolveStaticLabels` - резолвит static labels в конкретные значения. +// --------------------------------------------------------------------- +func resolveStaticLabels(entries [][3]string, cfg dnsConfig, ptrCache map[string]any, ttl int, logf func(string, ...any)) (map[string][]string, int, int) { + now := int(time.Now().Unix()) + result := map[string][]string{} + ptrLookups := 0 + ptrErrors := 0 + dnsForPtr := "" + if len(cfg.Default) > 0 { + dnsForPtr = cfg.Default[0] + } else { + dnsForPtr = defaultDNS1 + } + for _, e := range entries { + ipEntry, baseIP, comment := e[0], e[1], e[2] + var labels []string + if comment != "" { + labels = append(labels, "*"+comment) + } + if comment == "" { + if cached, ok := ptrCache[baseIP].(map[string]any); ok { + names, _ := cached["names"].([]any) + last, _ := cached["last_resolved"].(float64) + if len(names) > 0 && last > 0 && now-int(last) <= ttl { + for _, n := range names { + if s, ok := n.(string); ok && s != "" { + labels = append(labels, "*"+s) + } + } + } + } + if len(labels) == 0 { + ptrLookups++ + names, err := digPTR(baseIP, dnsForPtr, 3*time.Second, logf) + if err != nil { + ptrErrors++ + } + if len(names) > 0 { + ptrCache[baseIP] = map[string]any{"names": names, "last_resolved": now} + for _, n := range names { + labels = append(labels, "*"+n) + } + } + } + } + if len(labels) == 0 { + labels = []string{"*[STATIC-IP]"} + } + result[ipEntry] = labels + if logf != nil { + logf("static %s -> %v", ipEntry, labels) + } + } + return result, ptrLookups, ptrErrors +} + +// --------------------------------------------------------------------- +// DNS config + cache helpers +// --------------------------------------------------------------------- + +type domainCacheSource string + +const ( + domainCacheSourceDirect domainCacheSource = "direct" + domainCacheSourceWildcard domainCacheSource = "wildcard" +) + +type domainCacheEntry struct { + IPs []string `json:"ips"` + LastResolved int `json:"last_resolved"` +} + +type domainCacheRecord struct { + Direct *domainCacheEntry `json:"direct,omitempty"` + Wildcard *domainCacheEntry `json:"wildcard,omitempty"` +} + +type domainCacheState struct { + Version int `json:"version"` + Domains map[string]domainCacheRecord `json:"domains"` +} + +func newDomainCacheState() domainCacheState { + return domainCacheState{ + Version: 2, + Domains: map[string]domainCacheRecord{}, + } +} + +func normalizeCacheIPs(raw []string) []string { + seen := map[string]struct{}{} + out := make([]string, 0, len(raw)) + for _, ip := range raw { + ip = strings.TrimSpace(ip) + if ip == "" || isPrivateIPv4(ip) { + continue + } + if _, ok := seen[ip]; ok { + continue + } + seen[ip] = struct{}{} + out = append(out, ip) + } + sort.Strings(out) + return out +} + +func parseAnyStringSlice(raw any) []string { + switch v := raw.(type) { + case []string: + return append([]string(nil), v...) + case []any: + out := make([]string, 0, len(v)) + for _, x := range v { + if s, ok := x.(string); ok { + out = append(out, s) + } + } + return out + default: + return nil + } +} + +func parseAnyInt(raw any) (int, bool) { + switch v := raw.(type) { + case int: + return v, true + case int64: + return int(v), true + case float64: + return int(v), true + case json.Number: + n, err := v.Int64() + if err != nil { + return 0, false + } + return int(n), true + default: + return 0, false + } +} + +func parseLegacyDomainCacheEntry(raw any) (domainCacheEntry, bool) { + m, ok := raw.(map[string]any) + if !ok { + return domainCacheEntry{}, false + } + ips := normalizeCacheIPs(parseAnyStringSlice(m["ips"])) + if len(ips) == 0 { + return domainCacheEntry{}, false + } + ts, ok := parseAnyInt(m["last_resolved"]) + if !ok || ts <= 0 { + return domainCacheEntry{}, false + } + return domainCacheEntry{IPs: ips, LastResolved: ts}, true +} + +func loadDomainCacheState(path string, logf func(string, ...any)) domainCacheState { + data, err := os.ReadFile(path) + if err != nil || len(data) == 0 { + return newDomainCacheState() + } + + var st domainCacheState + if err := json.Unmarshal(data, &st); err == nil && st.Domains != nil { + if st.Version <= 0 { + st.Version = 2 + } + normalized := newDomainCacheState() + for host, rec := range st.Domains { + host = strings.TrimSpace(strings.ToLower(host)) + if host == "" { + continue + } + nrec := domainCacheRecord{} + if rec.Direct != nil { + ips := normalizeCacheIPs(rec.Direct.IPs) + if len(ips) > 0 && rec.Direct.LastResolved > 0 { + nrec.Direct = &domainCacheEntry{IPs: ips, LastResolved: rec.Direct.LastResolved} + } + } + if rec.Wildcard != nil { + ips := normalizeCacheIPs(rec.Wildcard.IPs) + if len(ips) > 0 && rec.Wildcard.LastResolved > 0 { + nrec.Wildcard = &domainCacheEntry{IPs: ips, LastResolved: rec.Wildcard.LastResolved} + } + } + if nrec.Direct != nil || nrec.Wildcard != nil { + normalized.Domains[host] = nrec + } + } + return normalized + } + + // Legacy shape: { "domain.tld": {"ips":[...], "last_resolved":...}, ... } + var legacy map[string]any + if err := json.Unmarshal(data, &legacy); err != nil { + if logf != nil { + logf("domain-cache: invalid json at %s, ignore", path) + } + return newDomainCacheState() + } + + out := newDomainCacheState() + migrated := 0 + for host, raw := range legacy { + host = strings.TrimSpace(strings.ToLower(host)) + if host == "" || host == "version" || host == "domains" { + continue + } + entry, ok := parseLegacyDomainCacheEntry(raw) + if !ok { + continue + } + rec := out.Domains[host] + rec.Direct = &entry + out.Domains[host] = rec + migrated++ + } + if logf != nil && migrated > 0 { + logf("domain-cache: migrated legacy entries=%d into split cache (direct bucket)", migrated) + } + return out +} + +func (s domainCacheState) get(domain string, source domainCacheSource, now, ttl int) ([]string, bool) { + rec, ok := s.Domains[strings.TrimSpace(strings.ToLower(domain))] + if !ok { + return nil, false + } + var entry *domainCacheEntry + switch source { + case domainCacheSourceWildcard: + entry = rec.Wildcard + default: + entry = rec.Direct + } + if entry == nil || entry.LastResolved <= 0 { + return nil, false + } + if now-entry.LastResolved > ttl { + return nil, false + } + ips := normalizeCacheIPs(entry.IPs) + if len(ips) == 0 { + return nil, false + } + return ips, true +} + +func (s *domainCacheState) set(domain string, source domainCacheSource, ips []string, now int) { + host := strings.TrimSpace(strings.ToLower(domain)) + if host == "" || now <= 0 { + return + } + norm := normalizeCacheIPs(ips) + if len(norm) == 0 { + return + } + if s.Domains == nil { + s.Domains = map[string]domainCacheRecord{} + } + rec := s.Domains[host] + entry := &domainCacheEntry{IPs: norm, LastResolved: now} + switch source { + case domainCacheSourceWildcard: + rec.Wildcard = entry + default: + rec.Direct = entry + } + s.Domains[host] = rec +} + +func (s domainCacheState) toMap() map[string]any { + out := map[string]any{ + "version": 2, + "domains": map[string]any{}, + } + domainsAny := out["domains"].(map[string]any) + hosts := make([]string, 0, len(s.Domains)) + for host := range s.Domains { + hosts = append(hosts, host) + } + sort.Strings(hosts) + for _, host := range hosts { + rec := s.Domains[host] + recOut := map[string]any{} + if rec.Direct != nil && len(rec.Direct.IPs) > 0 && rec.Direct.LastResolved > 0 { + recOut["direct"] = map[string]any{ + "ips": rec.Direct.IPs, + "last_resolved": rec.Direct.LastResolved, + } + } + if rec.Wildcard != nil && len(rec.Wildcard.IPs) > 0 && rec.Wildcard.LastResolved > 0 { + recOut["wildcard"] = map[string]any{ + "ips": rec.Wildcard.IPs, + "last_resolved": rec.Wildcard.LastResolved, + } + } + if len(recOut) > 0 { + domainsAny[host] = recOut + } + } + return out +} + +func digPTR(ip, upstream string, timeout time.Duration, logf func(string, ...any)) ([]string, error) { + server, port := splitDNS(upstream) + if server == "" { + return nil, fmt.Errorf("upstream empty") + } + if port == "" { + port = "53" + } + addr := net.JoinHostPort(server, port) + r := &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + d := net.Dialer{} + return d.DialContext(ctx, "udp", addr) + }, + } + ctx, cancel := context.WithTimeout(context.Background(), timeout) + names, err := r.LookupAddr(ctx, ip) + cancel() + if err != nil { + if logf != nil { + logf("ptr error %s via %s: %v", ip, addr, err) + } + return nil, err + } + seen := map[string]struct{}{} + var out []string + for _, n := range names { + n = strings.TrimSuffix(strings.ToLower(strings.TrimSpace(n)), ".") + if n == "" { + continue + } + if _, ok := seen[n]; !ok { + seen[n] = struct{}{} + out = append(out, n) + } + } + return out, nil +} + +// --------------------------------------------------------------------- +// EN: `loadDNSConfig` loads dns config from storage or config. +// RU: `loadDNSConfig` - загружает dns config из хранилища или конфига. +// --------------------------------------------------------------------- +func loadDNSConfig(path string, logf func(string, ...any)) dnsConfig { + cfg := dnsConfig{ + Default: []string{defaultDNS1, defaultDNS2}, + Meta: []string{defaultMeta1, defaultMeta2}, + SmartDNS: smartDNSAddr(), + Mode: DNSModeDirect, + } + + // 1) Если форсируем SmartDNS — вообще игнорим файл и ходим только через локальный резолвер. + if smartDNSForced() { + addr := smartDNSAddr() + cfg.Default = []string{addr} + cfg.Meta = []string{addr} + cfg.SmartDNS = addr + cfg.Mode = DNSModeSmartDNS + + if logf != nil { + logf("dns-config: SmartDNS forced (%s), ignore %s", addr, path) + } + return cfg + } + + // 2) Иначе пытаемся прочитать dns-upstreams.conf, как и раньше. + data, err := os.ReadFile(path) + if err != nil { + if logf != nil { + logf("dns-config: use built-in defaults, can't read %s: %v", path, err) + } + return cfg + } + + var def, meta []string + lines := strings.Split(string(data), "\n") + for _, ln := range lines { + s := strings.TrimSpace(ln) + if s == "" || strings.HasPrefix(s, "#") { + continue + } + parts := strings.Fields(s) + if len(parts) < 2 { + continue + } + key := strings.ToLower(parts[0]) + vals := parts[1:] + switch key { + case "default": + for _, v := range vals { + if n := normalizeDNSUpstream(v, "53"); n != "" { + def = append(def, n) + } + } + case "meta": + for _, v := range vals { + if n := normalizeDNSUpstream(v, "53"); n != "" { + meta = append(meta, n) + } + } + case "smartdns": + if len(vals) > 0 { + if n := normalizeSmartDNSAddr(vals[0]); n != "" { + cfg.SmartDNS = n + } + } + case "mode": + if len(vals) > 0 { + cfg.Mode = normalizeDNSResolverMode(DNSResolverMode(vals[0]), false) + } + } + } + if len(def) > 0 { + cfg.Default = def + } + if len(meta) > 0 { + cfg.Meta = meta + } + if logf != nil { + logf("dns-config: accept %s: mode=%s smartdns=%s default=%v; meta=%v", path, cfg.Mode, cfg.SmartDNS, cfg.Default, cfg.Meta) + } + return cfg +} + +// --------------------------------------------------------------------- +// EN: `readLinesAllowMissing` reads lines allow missing from input data. +// RU: `readLinesAllowMissing` - читает lines allow missing из входных данных. +// --------------------------------------------------------------------- +func readLinesAllowMissing(path string) []string { + data, err := os.ReadFile(path) + if err != nil { + return nil + } + return strings.Split(strings.ReplaceAll(string(data), "\r\n", "\n"), "\n") +} + +// --------------------------------------------------------------------- +// EN: `loadJSONMap` loads json map from storage or config. +// RU: `loadJSONMap` - загружает json map из хранилища или конфига. +// --------------------------------------------------------------------- +func loadJSONMap(path string) map[string]any { + data, err := os.ReadFile(path) + if err != nil { + return map[string]any{} + } + var out map[string]any + if err := json.Unmarshal(data, &out); err != nil { + return map[string]any{} + } + return out +} + +// --------------------------------------------------------------------- +// EN: `saveJSON` saves json to persistent storage. +// RU: `saveJSON` - сохраняет json в постоянное хранилище. +// --------------------------------------------------------------------- +func saveJSON(data any, path string) { + tmp := path + ".tmp" + b, err := json.MarshalIndent(data, "", " ") + if err != nil { + return + } + _ = os.WriteFile(tmp, b, 0o644) + _ = os.Rename(tmp, path) +} + +// --------------------------------------------------------------------- +// EN: `uniqueStrings` contains core logic for unique strings. +// RU: `uniqueStrings` - содержит основную логику для unique strings. +// --------------------------------------------------------------------- +func uniqueStrings(in []string) []string { + seen := map[string]struct{}{} + var out []string + for _, v := range in { + if _, ok := seen[v]; !ok { + seen[v] = struct{}{} + out = append(out, v) + } + } + return out +} + +// --------------------------------------------------------------------- +// text cleanup + IP classifiers +// --------------------------------------------------------------------- + +var reANSI = regexp.MustCompile(`\x1B\[[0-9;]*[A-Za-z]`) + +func stripANSI(s string) string { + return reANSI.ReplaceAllString(s, "") +} + +// --------------------------------------------------------------------- +// EN: `isPrivateIPv4` checks whether private i pv4 is true. +// RU: `isPrivateIPv4` - проверяет, является ли private i pv4 истинным условием. +// --------------------------------------------------------------------- +func isPrivateIPv4(ip string) bool { + parts := strings.Split(strings.Split(ip, "/")[0], ".") + if len(parts) != 4 { + return true + } + vals := make([]int, 4) + for i, p := range parts { + n, err := strconv.Atoi(p) + if err != nil || n < 0 || n > 255 { + return true + } + vals[i] = n + } + if vals[0] == 10 || vals[0] == 127 || vals[0] == 0 { + return true + } + if vals[0] == 192 && vals[1] == 168 { + return true + } + if vals[0] == 172 && vals[1] >= 16 && vals[1] <= 31 { + return true + } + return false +} diff --git a/selective-vpn-api/app/resolver_cache_test.go b/selective-vpn-api/app/resolver_cache_test.go new file mode 100644 index 0000000..7627a87 --- /dev/null +++ b/selective-vpn-api/app/resolver_cache_test.go @@ -0,0 +1,70 @@ +package app + +import ( + "os" + "path/filepath" + "testing" +) + +func TestDomainCacheLegacyMigrationToDirectBucket(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "domain-cache.json") + legacy := `{ + "example.com": {"ips": ["1.1.1.1", "10.0.0.1"], "last_resolved": 100}, + "bad.com": {"ips": [], "last_resolved": 100} +}` + if err := os.WriteFile(path, []byte(legacy), 0o644); err != nil { + t.Fatalf("write legacy cache: %v", err) + } + + st := loadDomainCacheState(path, nil) + if _, ok := st.get("example.com", domainCacheSourceDirect, 150, 100); !ok { + t.Fatalf("expected direct cache hit after migration") + } + if _, ok := st.get("example.com", domainCacheSourceWildcard, 150, 100); ok { + t.Fatalf("did not expect wildcard cache hit for migrated legacy entry") + } + if ips, ok := st.get("example.com", domainCacheSourceDirect, 150, 100); !ok || len(ips) != 1 || ips[0] != "1.1.1.1" { + t.Fatalf("unexpected migrated ips: ok=%v ips=%v", ok, ips) + } +} + +func TestDomainCacheSplitBucketsAreIndependent(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "domain-cache.json") + v2 := `{ + "version": 2, + "domains": { + "example.com": { + "direct": {"ips": ["1.1.1.1"], "last_resolved": 100}, + "wildcard": {"ips": ["2.2.2.2"], "last_resolved": 100} + } + } +}` + if err := os.WriteFile(path, []byte(v2), 0o644); err != nil { + t.Fatalf("write v2 cache: %v", err) + } + + st := loadDomainCacheState(path, nil) + direct, ok := st.get("example.com", domainCacheSourceDirect, 150, 100) + if !ok || len(direct) != 1 || direct[0] != "1.1.1.1" { + t.Fatalf("unexpected direct lookup: ok=%v ips=%v", ok, direct) + } + wild, ok := st.get("example.com", domainCacheSourceWildcard, 150, 100) + if !ok || len(wild) != 1 || wild[0] != "2.2.2.2" { + t.Fatalf("unexpected wildcard lookup: ok=%v ips=%v", ok, wild) + } +} + +func TestDomainCacheSetAndTTL(t *testing.T) { + st := newDomainCacheState() + st.set("example.com", domainCacheSourceDirect, []string{"1.1.1.1", "1.1.1.1", "10.0.0.1"}, 100) + + if _, ok := st.get("example.com", domainCacheSourceDirect, 201, 100); ok { + t.Fatalf("expected cache miss due ttl expiry") + } + ips, ok := st.get("example.com", domainCacheSourceDirect, 200, 100) + if !ok || len(ips) != 1 || ips[0] != "1.1.1.1" { + t.Fatalf("unexpected ttl hit result: ok=%v ips=%v", ok, ips) + } +} diff --git a/selective-vpn-api/app/routes_cache.go b/selective-vpn-api/app/routes_cache.go new file mode 100644 index 0000000..7d757ad --- /dev/null +++ b/selective-vpn-api/app/routes_cache.go @@ -0,0 +1,399 @@ +package app + +import ( + "context" + "encoding/json" + "fmt" + "os" + "sort" + "strings" + "time" +) + +// --------------------------------------------------------------------- +// routes clear cache (safe clear / fast restore) +// --------------------------------------------------------------------- + +// EN: Snapshot data persisted before routes clear to support fast restore +// EN: without running full domain resolve again. +// RU: Снимок данных, который сохраняется перед routes clear для быстрого +// RU: восстановления без повторного полного резолва доменов. +type routesClearCacheMeta struct { + CreatedAt string `json:"created_at"` + Iface string `json:"iface,omitempty"` + RouteCount int `json:"route_count"` + IPCount int `json:"ip_count"` + DynIPCount int `json:"dyn_ip_count"` + HasIPMap bool `json:"has_ip_map"` +} + +func saveRoutesClearCache() (routesClearCacheMeta, error) { + if err := os.MkdirAll(stateDir, 0o755); err != nil { + return routesClearCacheMeta{}, err + } + + routes, err := readCurrentRoutesTableLines() + if err != nil { + return routesClearCacheMeta{}, err + } + if err := writeLinesFile(routesCacheRT, routes); err != nil { + return routesClearCacheMeta{}, err + } + + var warns []string + + ipCount, err := snapshotNftSetToFile("agvpn4", routesCacheIPs) + if err != nil { + warns = append(warns, fmt.Sprintf("agvpn4 snapshot failed: %v", err)) + _ = cacheCopyOrEmpty(stateDir+"/last-ips.txt", routesCacheIPs) + ipCount = len(readNonEmptyLines(routesCacheIPs)) + } + + dynIPCount, err := snapshotNftSetToFile("agvpn_dyn4", routesCacheDyn) + if err != nil { + warns = append(warns, fmt.Sprintf("agvpn_dyn4 snapshot failed: %v", err)) + _ = os.WriteFile(routesCacheDyn, []byte{}, 0o644) + dynIPCount = 0 + } + + if err := cacheCopyOrEmpty(stateDir+"/last-ips-map.txt", routesCacheMap); err != nil { + warns = append(warns, fmt.Sprintf("last-ips-map cache copy failed: %v", err)) + } + + meta := routesClearCacheMeta{ + CreatedAt: time.Now().UTC().Format(time.RFC3339), + Iface: detectIfaceFromRoutes(routes), + RouteCount: len(routes), + IPCount: ipCount, + DynIPCount: dynIPCount, + HasIPMap: fileExists(routesCacheMap), + } + + data, err := json.MarshalIndent(meta, "", " ") + if err != nil { + return routesClearCacheMeta{}, err + } + if err := os.WriteFile(routesCacheMeta, data, 0o644); err != nil { + return routesClearCacheMeta{}, err + } + if len(warns) > 0 { + return meta, fmt.Errorf("%s", strings.Join(warns, "; ")) + } + return meta, nil +} + +func restoreRoutesFromCache() cmdResult { + meta, err := loadRoutesClearCacheMeta() + if err != nil { + return cmdResult{ + OK: false, + Message: fmt.Sprintf("routes cache missing: %v", err), + } + } + + ips := readNonEmptyLines(routesCacheIPs) + dynIPs := readNonEmptyLines(routesCacheDyn) + routeLines, _ := readLinesFile(routesCacheRT) + + ensureRoutesTableEntry() + removeTrafficRulesForTable() + _, _, _, _ = runCommandTimeout(5*time.Second, "ip", "route", "flush", "table", routesTableName()) + + ignoredRoutes := 0 + for _, ln := range routeLines { + if err := restoreRouteLine(ln); err != nil { + if shouldIgnoreRestoreRouteError(ln, err) { + ignoredRoutes++ + appendTraceLine("routes", fmt.Sprintf("restore route skipped (%q): %v", ln, err)) + continue + } + return cmdResult{ + OK: false, + Message: fmt.Sprintf("restore route failed (%q): %v", ln, err), + } + } + } + if ignoredRoutes > 0 { + appendTraceLine("routes", fmt.Sprintf("restore route: skipped non-critical routes=%d", ignoredRoutes)) + } + + if len(routeLines) == 0 && strings.TrimSpace(meta.Iface) != "" { + _, _, _, _ = runCommandTimeout( + 5*time.Second, + "ip", "-4", "route", "replace", + "default", "dev", meta.Iface, + "table", routesTableName(), + "mtu", policyRouteMTU, + ) + } + + _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "table", "inet", "agvpn") + _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "set", "inet", "agvpn", "agvpn4", "{", "type", "ipv4_addr", ";", "flags", "interval", ";", "}") + _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "set", "inet", "agvpn", "agvpn_dyn4", "{", "type", "ipv4_addr", ";", "flags", "interval", ";", "}") + _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "flush", "set", "inet", "agvpn", "agvpn4") + _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "flush", "set", "inet", "agvpn", "agvpn_dyn4") + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + defer cancel() + + if len(ips) > 0 { + if err := nftUpdateSetIPsSmart(ctx, "agvpn4", ips, nil); err != nil { + return cmdResult{ + OK: false, + Message: fmt.Sprintf("restore nft cache failed for agvpn4: %v", err), + } + } + } + if len(dynIPs) > 0 { + if err := nftUpdateSetIPsSmart(ctx, "agvpn_dyn4", dynIPs, nil); err != nil { + return cmdResult{ + OK: false, + Message: fmt.Sprintf("restore nft cache failed for agvpn_dyn4: %v", err), + } + } + } + + traffic := loadTrafficModeState() + iface := strings.TrimSpace(meta.Iface) + if iface == "" { + iface = detectIfaceFromRoutes(routeLines) + } + if iface == "" { + iface, _ = resolveTrafficIface(traffic.PreferredIface) + } + if iface != "" { + if err := applyTrafficMode(traffic, iface); err != nil { + return cmdResult{ + OK: false, + Message: fmt.Sprintf("cache restored, but traffic mode apply failed: %v", err), + } + } + } + + _ = cacheCopyOrEmpty(routesCacheIPs, stateDir+"/last-ips.txt") + if fileExists(routesCacheMap) { + _ = cacheCopyOrEmpty(routesCacheMap, stateDir+"/last-ips-map.txt") + } + + return cmdResult{ + OK: true, + Message: fmt.Sprintf( + "routes restored from cache: agvpn4=%d agvpn_dyn4=%d routes=%d iface=%s", + len(ips), len(dynIPs), len(routeLines), ifaceOrDash(iface), + ), + } +} + +func readCurrentRoutesTableLines() ([]string, error) { + out, _, code, err := runCommandTimeout(5*time.Second, "ip", "-4", "route", "show", "table", routesTableName()) + if err != nil && code != 0 { + return nil, err + } + lines := make([]string, 0, 32) + for _, raw := range strings.Split(out, "\n") { + ln := strings.TrimSpace(raw) + if ln == "" { + continue + } + lines = append(lines, ln) + } + return lines, nil +} + +func writeLinesFile(path string, lines []string) error { + if len(lines) == 0 { + return os.WriteFile(path, []byte{}, 0o644) + } + payload := strings.Join(lines, "\n") + if !strings.HasSuffix(payload, "\n") { + payload += "\n" + } + return os.WriteFile(path, []byte(payload), 0o644) +} + +func readLinesFile(path string) ([]string, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + lines := make([]string, 0, 64) + for _, raw := range strings.Split(string(data), "\n") { + ln := strings.TrimSpace(raw) + if ln == "" { + continue + } + lines = append(lines, ln) + } + return lines, nil +} + +func detectIfaceFromRoutes(lines []string) string { + for _, ln := range lines { + fields := strings.Fields(ln) + for i := 0; i+1 < len(fields); i++ { + if fields[i] == "dev" { + return strings.TrimSpace(fields[i+1]) + } + } + } + return "" +} + +func restoreRouteLine(line string) error { + fields := strings.Fields(strings.TrimSpace(line)) + if len(fields) == 0 { + return nil + } + args := []string{"-4", "route", "replace"} + args = append(args, fields...) + hasTable := false + for i := 0; i+1 < len(fields); i++ { + if fields[i] == "table" { + hasTable = true + break + } + } + if !hasTable { + args = append(args, "table", routesTableName()) + } + _, _, code, err := runCommandTimeout(5*time.Second, "ip", args...) + if err != nil || code != 0 { + if err == nil { + err = fmt.Errorf("exit code %d", code) + } + return err + } + return nil +} + +func shouldIgnoreRestoreRouteError(line string, err error) bool { + ln := strings.ToLower(strings.TrimSpace(line)) + if strings.Contains(ln, " linkdown") { + return true + } + + dev := routeLineDevice(ln) + if dev != "" && !strings.HasPrefix(ln, "default ") && !ifaceExists(dev) { + return true + } + + msg := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", err))) + if strings.HasPrefix(ln, "default ") { + return false + } + if strings.Contains(msg, "cannot find device") || + strings.Contains(msg, "no such device") || + strings.Contains(msg, "network is down") { + return true + } + return false +} + +func routeLineDevice(line string) string { + fields := strings.Fields(strings.TrimSpace(line)) + for i := 0; i+1 < len(fields); i++ { + if fields[i] == "dev" { + return strings.TrimSpace(fields[i+1]) + } + } + return "" +} + +func cacheCopyOrEmpty(src, dst string) error { + if err := copyFile(src, dst); err == nil { + return nil + } + return os.WriteFile(dst, []byte{}, 0o644) +} + +func snapshotNftSetToFile(setName, dst string) (int, error) { + elems, err := readNftSetElements(setName) + if err != nil { + return 0, err + } + if err := writeLinesFile(dst, elems); err != nil { + return 0, err + } + return len(elems), nil +} + +func readNftSetElements(setName string) ([]string, error) { + out, stderr, code, err := runCommandTimeout( + 8*time.Second, "nft", "list", "set", "inet", "agvpn", setName, + ) + if err != nil || code != 0 { + msg := strings.ToLower(strings.TrimSpace(out + " " + stderr)) + if strings.Contains(msg, "no such file") || + strings.Contains(msg, "not found") || + strings.Contains(msg, "does not exist") { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("nft list set %s failed: %w", setName, err) + } + return nil, fmt.Errorf("nft list set %s failed: %s", setName, strings.TrimSpace(stderr)) + } + return parseNftSetElementsText(out), nil +} + +func parseNftSetElementsText(raw string) []string { + idx := strings.Index(raw, "elements =") + if idx < 0 { + return nil + } + chunk := raw[idx:] + open := strings.Index(chunk, "{") + if open < 0 { + return nil + } + body := chunk[open+1:] + closeIdx := strings.Index(body, "}") + if closeIdx >= 0 { + body = body[:closeIdx] + } + body = strings.ReplaceAll(body, "\r", " ") + body = strings.ReplaceAll(body, "\n", " ") + + seen := map[string]struct{}{} + out := make([]string, 0, 1024) + for _, tok := range strings.Split(body, ",") { + val := strings.TrimSpace(tok) + if val == "" { + continue + } + if _, ok := seen[val]; ok { + continue + } + seen[val] = struct{}{} + out = append(out, val) + } + sort.Strings(out) + return out +} + +func loadRoutesClearCacheMeta() (routesClearCacheMeta, error) { + data, err := os.ReadFile(routesCacheMeta) + if err != nil { + return routesClearCacheMeta{}, err + } + var meta routesClearCacheMeta + if err := json.Unmarshal(data, &meta); err != nil { + return routesClearCacheMeta{}, err + } + return meta, nil +} + +func fileExists(path string) bool { + info, err := os.Stat(path) + if err != nil { + return false + } + return !info.IsDir() +} + +func ifaceOrDash(iface string) string { + if strings.TrimSpace(iface) == "" { + return "-" + } + return iface +} diff --git a/selective-vpn-api/app/routes_handlers.go b/selective-vpn-api/app/routes_handlers.go new file mode 100644 index 0000000..601f732 --- /dev/null +++ b/selective-vpn-api/app/routes_handlers.go @@ -0,0 +1,405 @@ +package app + +import ( + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "os" + "strings" + "syscall" +) + +// --------------------------------------------------------------------- +// routes handlers +// --------------------------------------------------------------------- + +// EN: HTTP handlers for selective routing control plane operations: +// EN: status, systemd service/timer control, route cleanup, policy fix, and async update trigger. +// RU: HTTP-обработчики control-plane для селективной маршрутизации: +// RU: статус, управление service/timer через systemd, очистка, фиксация policy route и запуск обновления. + +func handleGetStatus(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + data, err := os.ReadFile(statusFilePath) + if err != nil { + if os.IsNotExist(err) { + http.Error(w, "status file not found", http.StatusNotFound) + return + } + http.Error(w, "failed to read status file", http.StatusInternalServerError) + return + } + + var st Status + if err := json.Unmarshal(data, &st); err != nil { + http.Error(w, "invalid status.json", http.StatusInternalServerError) + return + } + + if st.Iface != "" && st.Iface != "-" && st.Table != "" && st.Table != "-" { + ok, err := checkPolicyRoute(st.Iface, st.Table) + if err != nil { + log.Printf("checkPolicyRoute error: %v", err) + } else { + st.PolicyRouteOK = &ok + st.RouteOK = &ok + } + } + + writeJSON(w, http.StatusOK, st) +} + +// --------------------------------------------------------------------- +// routes service +// --------------------------------------------------------------------- + +func makeCmdHandler(name string, args ...string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + stdout, stderr, code, err := runCommand(name, args...) + res := cmdResult{ + OK: err == nil && code == 0, + ExitCode: code, + Stdout: stdout, + Stderr: stderr, + } + if err != nil { + res.Message = err.Error() + } + writeJSON(w, http.StatusOK, res) + } +} + +func runRoutesServiceAction(action string) cmdResult { + action = strings.ToLower(strings.TrimSpace(action)) + unit := routesServiceUnitName() + if unit == "" { + return cmdResult{ + OK: false, + Message: "routes service unit unresolved: set preferred iface or SELECTIVE_VPN_ROUTES_UNIT", + } + } + + var args []string + switch action { + case "start", "stop", "restart": + args = []string{"systemctl", action, unit} + default: + return cmdResult{ + OK: false, + Message: "unknown action (expected start|stop|restart)", + } + } + + stdout, stderr, exitCode, err := runCommand(args[0], args[1:]...) + res := cmdResult{ + OK: err == nil && exitCode == 0, + ExitCode: exitCode, + Stdout: stdout, + Stderr: stderr, + } + if err != nil { + res.Message = err.Error() + } else { + res.Message = fmt.Sprintf("%s done (%s)", action, unit) + } + return res +} + +func makeRoutesServiceActionHandler(action string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + res := runRoutesServiceAction(action) + writeJSON(w, http.StatusOK, res) + } +} + +// POST /api/v1/routes/service { "action": "start|stop|restart" } +// --------------------------------------------------------------------- +// EN: `handleRoutesService` is an HTTP handler for routes service. +// RU: `handleRoutesService` - HTTP-обработчик для routes service. +// --------------------------------------------------------------------- +func handleRoutesService(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + var body struct { + Action string `json:"action"` + } + if r.Body != nil { + defer r.Body.Close() + _ = json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&body) + } + + res := runRoutesServiceAction(body.Action) + if strings.Contains(res.Message, "unknown action") { + writeJSON(w, http.StatusBadRequest, res) + return + } + writeJSON(w, http.StatusOK, res) +} + +// --------------------------------------------------------------------- +// routes timer +// --------------------------------------------------------------------- + +// старый toggle (используем из GUI, если что) +func handleRoutesTimerToggle(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + enabled := isTimerEnabled() + res := runRoutesTimerSet(!enabled) + writeJSON(w, http.StatusOK, res) +} + +// новый API: GET → {enabled:bool}, POST {enabled:bool} +// --------------------------------------------------------------------- +// EN: `handleRoutesTimer` is an HTTP handler for routes timer. +// RU: `handleRoutesTimer` - HTTP-обработчик для routes timer. +// --------------------------------------------------------------------- +func handleRoutesTimer(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + enabled := isTimerEnabled() + writeJSON(w, http.StatusOK, map[string]any{ + "enabled": enabled, + }) + case http.MethodPost: + var body struct { + Enabled bool `json:"enabled"` + } + if r.Body != nil { + defer r.Body.Close() + _ = json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&body) + } + res := runRoutesTimerSet(body.Enabled) + writeJSON(w, http.StatusOK, res) + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } +} + +// --------------------------------------------------------------------- +// EN: `isTimerEnabled` checks whether timer enabled is true. +// RU: `isTimerEnabled` - проверяет, является ли timer enabled истинным условием. +// --------------------------------------------------------------------- +func isTimerEnabled() bool { + unit := routesTimerUnitName() + if unit == "" { + return false + } + _, _, code, _ := runCommand("systemctl", "is-enabled", unit) + return code == 0 +} + +func runRoutesTimerSet(enabled bool) cmdResult { + unit := routesTimerUnitName() + if unit == "" { + return cmdResult{ + OK: false, + Message: "routes timer unit unresolved: set preferred iface or SELECTIVE_VPN_ROUTES_TIMER", + } + } + cmd := []string{"systemctl", "disable", "--now", unit} + msg := "routes timer disabled" + if enabled { + cmd = []string{"systemctl", "enable", "--now", unit} + msg = "routes timer enabled" + } + stdout, stderr, exitCode, err := runCommand(cmd[0], cmd[1:]...) + res := cmdResult{ + OK: err == nil && exitCode == 0, + Message: fmt.Sprintf("%s (%s)", msg, unit), + ExitCode: exitCode, + Stdout: stdout, + Stderr: stderr, + } + if err != nil { + res.Message = fmt.Sprintf("%s (%s): %v", msg, unit, err) + } + return res +} + +// --------------------------------------------------------------------- +// rollback / clear +// --------------------------------------------------------------------- + +func handleRoutesClear(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + res := routesClear() + writeJSON(w, http.StatusOK, res) +} + +func handleRoutesCacheRestore(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + res := restoreRoutesFromCache() + writeJSON(w, http.StatusOK, res) +} + +// --------------------------------------------------------------------- +// EN: `routesClear` contains core logic for routes clear. +// RU: `routesClear` - содержит основную логику для routes clear. +// --------------------------------------------------------------------- +func routesClear() cmdResult { + cacheMeta, cacheErr := saveRoutesClearCache() + + stdout, stderr, _, err := runCommand("ip", "rule", "show") + if err == nil && stdout != "" { + removeTrafficRulesForTable() + } + + _, _, _, _ = runCommand("ip", "route", "flush", "table", routesTableName()) + _, _, _, _ = runCommand("nft", "flush", "set", "inet", "agvpn", "agvpn4") + _, _, _, _ = runCommand("nft", "flush", "set", "inet", "agvpn", "agvpn_dyn4") + + res := cmdResult{ + OK: true, + Message: "routes cleared", + ExitCode: 0, + Stdout: stdout, + Stderr: stderr, + } + if cacheErr != nil { + res.Message = fmt.Sprintf("%s (cache warning: %v)", res.Message, cacheErr) + } else { + res.Message = fmt.Sprintf( + "%s (cache saved: agvpn4=%d agvpn_dyn4=%d routes=%d iface=%s at=%s)", + res.Message, + cacheMeta.IPCount, + cacheMeta.DynIPCount, + cacheMeta.RouteCount, + ifaceOrDash(cacheMeta.Iface), + cacheMeta.CreatedAt, + ) + } + return res +} + +// --------------------------------------------------------------------- +// policy route +// --------------------------------------------------------------------- + +func handleFixPolicyRoute(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + data, err := os.ReadFile(statusFilePath) + if err != nil { + http.Error(w, "status.json missing", http.StatusBadRequest) + return + } + var st Status + if err := json.Unmarshal(data, &st); err != nil { + http.Error(w, "invalid status.json", http.StatusBadRequest) + return + } + + iface := strings.TrimSpace(st.Iface) + table := strings.TrimSpace(st.Table) + if iface == "" || iface == "-" || table == "" || table == "-" { + http.Error(w, "iface/table unknown in status.json", http.StatusBadRequest) + return + } + + stdout, stderr, exitCode, err := runCommand( + "ip", "-4", "route", "replace", + "default", "dev", iface, "table", table, "mtu", policyRouteMTU, + ) + + ok := err == nil && exitCode == 0 + res := cmdResult{ + OK: ok, + ExitCode: exitCode, + Stdout: stdout, + Stderr: stderr, + } + if ok { + res.Message = fmt.Sprintf("policy route fixed: default dev %s table %s", iface, table) + } else if err != nil { + res.Message = err.Error() + } + + writeJSON(w, http.StatusOK, res) +} + +// --------------------------------------------------------------------- +// routes update (Go port of update-selective-routes2.sh) +// --------------------------------------------------------------------- + +func handleRoutesUpdate(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + var body struct { + Iface string `json:"iface"` + } + if r.Body != nil { + defer r.Body.Close() + _ = json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&body) + } + iface := strings.TrimSpace(body.Iface) + iface = normalizePreferredIface(iface) + if iface == "" { + iface, _ = resolveTrafficIface(loadTrafficModeState().PreferredIface) + } + + lock, err := os.OpenFile(lockFile, os.O_CREATE|os.O_RDWR, 0o644) + if err != nil { + http.Error(w, "lock open error", http.StatusInternalServerError) + return + } + if err := syscall.Flock(int(lock.Fd()), syscall.LOCK_EX|syscall.LOCK_NB); err != nil { + writeJSON(w, http.StatusOK, map[string]any{ + "ok": false, + "message": "routes update already running", + }) + lock.Close() + return + } + + go func(iface string, lockFile *os.File) { + defer syscall.Flock(int(lockFile.Fd()), syscall.LOCK_UN) + defer lockFile.Close() + + res := routesUpdate(iface) + evKind := "routes_update_done" + if !res.OK { + evKind = "routes_update_error" + } + events.push(evKind, map[string]any{ + "ok": res.OK, + "message": res.Message, + "ip_cnt": res.ExitCode, // reuse exitCode to pass ip_count if set + }) + }(iface, lock) + + writeJSON(w, http.StatusOK, map[string]any{ + "ok": true, + "message": "routes update started", + }) +} diff --git a/selective-vpn-api/app/routes_units.go b/selective-vpn-api/app/routes_units.go new file mode 100644 index 0000000..aac0e57 --- /dev/null +++ b/selective-vpn-api/app/routes_units.go @@ -0,0 +1,52 @@ +package app + +import ( + "fmt" + "os" + "strings" +) + +// --------------------------------------------------------------------- +// routes systemd unit name resolution +// --------------------------------------------------------------------- + +// EN: Resolve routes service/timer unit names from preferred/active iface. +// EN: Env overrides still have top priority for custom deployments. +// RU: Вычисление имен unit для routes service/timer по preferred/active iface. +// RU: Для кастомных окружений сохраняется приоритет переменных окружения. + +func resolveRoutesUnitIface() (string, string) { + st := loadTrafficModeState() + if pref := normalizePreferredIface(st.PreferredIface); pref != "" { + return pref, "preferred" + } + if statusIface := statusIfaceFromFile(); statusIface != "" && statusIface != "-" { + return statusIface, "status" + } + if active, reason := resolveTrafficIface(""); active != "" { + return active, reason + } + return "", "iface-not-found" +} + +func routesServiceUnitName() string { + if forced := strings.TrimSpace(os.Getenv(routesServiceEnv)); forced != "" { + return forced + } + iface, _ := resolveRoutesUnitIface() + if iface == "" { + return "" + } + return fmt.Sprintf(routesServiceTemplate, iface) +} + +func routesTimerUnitName() string { + if forced := strings.TrimSpace(os.Getenv(routesTimerEnv)); forced != "" { + return forced + } + iface, _ := resolveRoutesUnitIface() + if iface == "" { + return "" + } + return fmt.Sprintf(routesTimerTemplate, iface) +} diff --git a/selective-vpn-api/app/routes_update.go b/selective-vpn-api/app/routes_update.go new file mode 100644 index 0000000..9878f8c --- /dev/null +++ b/selective-vpn-api/app/routes_update.go @@ -0,0 +1,703 @@ +package app + +import ( + "context" + "encoding/json" + "fmt" + "io/fs" + "net" + "os" + "os/user" + "sort" + "strconv" + "strings" + "time" +) + +// --------------------------------------------------------------------- +// основной routesUpdate +// --------------------------------------------------------------------- + +// EN: Core selective-routes orchestration pipeline. +// EN: This unit prepares policy routing, nftables objects, domain expansion, +// EN: resolver execution, status artifacts, and GUI-facing progress events. +// RU: Основной orchestration-пайплайн selective-routes. +// RU: Модуль готовит policy routing, nftables-объекты, расширение доменов, +// RU: запуск резолвера, статусные артефакты и события прогресса для GUI. + +// --------------------------------------------------------------------- +// EN: `routesUpdate` contains core logic for routes update. +// RU: `routesUpdate` - содержит основную логику для routes update. +// --------------------------------------------------------------------- +func routesUpdate(iface string) cmdResult { + logp := func(format string, args ...any) { + appendTraceLine("routes", fmt.Sprintf(format, args...)) + } + heartbeat := func() { + _ = os.WriteFile(heartbeatFile, []byte{}, 0o644) + } + + res := cmdResult{OK: false} + + iface = normalizePreferredIface(iface) + if iface == "" { + iface, _ = resolveTrafficIface(loadTrafficModeState().PreferredIface) + } + if iface == "" { + logp("no active vpn iface, exit 0") + res.OK = true + res.Message = "interface not found, skipped" + return res + } + + // ----------------------------------------------------------------- + // preflight + // ----------------------------------------------------------------- + + // ensure dirs + _ = os.MkdirAll(stateDir, 0o755) + _ = os.MkdirAll(domainDir, 0o755) + _ = os.MkdirAll("/etc/selective-vpn", 0o755) + + heartbeat() + + // wait iface up + up := false + for i := 0; i < 30; i++ { + if _, _, code, _ := runCommandTimeout(3*time.Second, "ip", "link", "show", iface); code == 0 { + up = true + break + } + time.Sleep(1 * time.Second) + heartbeat() + } + if !up { + logp("no %s, exit 0", iface) + res.OK = true + res.Message = "interface not found, skipped" + return res + } + + // wait DNS (like wait-for-dns.sh) + if err := waitDNS(15, 1*time.Second); err != nil { + logp("dns not ready: %v", err) + res.Message = "dns not ready" + return res + } + + // ----------------------------------------------------------------- + // policy routing setup + // ----------------------------------------------------------------- + + // rt_tables entry + ensureRoutesTableEntry() + + // ip rules: remove old rules pointing to table + if out, _, _, _ := runCommandTimeout(5*time.Second, "ip", "rule", "show"); out != "" { + for _, line := range strings.Split(out, "\n") { + if !strings.Contains(line, "lookup "+routesTableName()) { + continue + } + fields := strings.Fields(line) + if len(fields) == 0 { + continue + } + pref := strings.TrimSuffix(fields[0], ":") + if pref == "" { + continue + } + _, _, _, _ = runCommandTimeout(5*time.Second, "ip", "rule", "del", "pref", pref) + } + } + + // clean table and set default route + _, _, _, _ = runCommandTimeout(5*time.Second, "ip", "route", "flush", "table", routesTableName()) + _, _, _, _ = runCommandTimeout(5*time.Second, "ip", "-4", "route", "replace", "default", "dev", iface, "table", routesTableName(), "mtu", policyRouteMTU) + // apply traffic mode rules (selective/full_tunnel/direct) over fresh table. + trafficState := loadTrafficModeState() + trafficIface, trafficIfaceReason := resolveTrafficIface(trafficState.PreferredIface) + if trafficIface == "" { + trafficIface = iface + trafficIfaceReason = "routes-update-iface" + } + if err := applyTrafficMode(trafficState, trafficIface); err != nil { + logp("traffic mode apply failed: mode=%s iface=%s err=%v", trafficState.Mode, iface, err) + res.Message = fmt.Sprintf("traffic mode apply failed: %v", err) + return res + } + trafficEval := evaluateTrafficMode(trafficState) + logp( + "traffic mode: desired=%s applied=%s healthy=%t iface=%s reason=%s", + trafficEval.DesiredMode, + trafficEval.AppliedMode, + trafficEval.Healthy, + trafficEval.ActiveIface, + trafficEval.Message+" (apply_iface_source="+trafficIfaceReason+")", + ) + + // ensure default exists + if out, _, _, _ := runCommandTimeout(5*time.Second, "ip", "route", "show", "table", routesTableName()); !strings.Contains(out, "default dev "+iface) { + _, _, _, _ = runCommandTimeout(5*time.Second, "ip", "-4", "route", "replace", "default", "dev", iface, "table", routesTableName(), "mtu", policyRouteMTU) + } + + heartbeat() + + // ----------------------------------------------------------------- + // nft base objects + // ----------------------------------------------------------------- + + // nft setup + _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "table", "inet", "agvpn") + _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "set", "inet", "agvpn", "agvpn4", "{", "type", "ipv4_addr", ";", "flags", "interval", ";", "}") + _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "set", "inet", "agvpn", "agvpn_dyn4", "{", "type", "ipv4_addr", ";", "flags", "interval", ";", "}") + + _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "chain", "inet", "agvpn", "output", "{", "type", "route", "hook", "output", "priority", "mangle;", "policy", "accept;", "}") + _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "flush", "chain", "inet", "agvpn", "output") + _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "rule", "inet", "agvpn", "output", "ip", "daddr", "@agvpn4", "meta", "mark", "set", MARK) + _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "rule", "inet", "agvpn", "output", "ip", "daddr", "@agvpn_dyn4", "meta", "mark", "set", MARK) + + _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "chain", "inet", "agvpn", "prerouting", "{", "type", "filter", "hook", "prerouting", "priority", "mangle;", "policy", "accept;", "}") + _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "flush", "chain", "inet", "agvpn", "prerouting") + _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "rule", "inet", "agvpn", "prerouting", "iifname", "!=", iface, "ip", "daddr", "@agvpn4", "meta", "mark", "set", MARK) + _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "rule", "inet", "agvpn", "prerouting", "iifname", "!=", iface, "ip", "daddr", "@agvpn_dyn4", "meta", "mark", "set", MARK) + + heartbeat() + + // ----------------------------------------------------------------- + // domains + resolver + // ----------------------------------------------------------------- + + // domain lists + bases := loadList(domainDir + "/bases.txt") + subs := loadList(domainDir + "/subs.txt") + wildcards := loadSmartDNSWildcardDomains(logp) + wildcardBasesAdded := 0 + for _, d := range wildcards { + d = strings.TrimSpace(d) + if d == "" { + continue + } + bases = append(bases, d) + wildcardBasesAdded++ + } + subsPerBaseLimit := envInt("RESOLVE_SUBS_PER_BASE_LIMIT", 0) + if subsPerBaseLimit < 0 { + subsPerBaseLimit = 0 + } + hardCap := envInt("RESOLVE_DOMAINS_HARD_CAP", 0) + if hardCap < 0 { + hardCap = 0 + } + + domainSet := make(map[string]struct{}) + expandedAdded := 0 + twitterAdded := 0 + for _, d := range bases { + domainSet[d] = struct{}{} + if !isGoogleLike(d) { + limit := len(subs) + if subsPerBaseLimit > 0 && subsPerBaseLimit < limit { + limit = subsPerBaseLimit + } + for i := 0; i < limit; i++ { + fqdn := subs[i] + "." + d + if _, ok := domainSet[fqdn]; !ok { + expandedAdded++ + } + domainSet[fqdn] = struct{}{} + } + } + } + for _, spec := range twitterSpecial { + fqdn := spec + ".twitter.com" + if _, ok := domainSet[fqdn]; !ok { + twitterAdded++ + } + domainSet[fqdn] = struct{}{} + } + + domains := make([]string, 0, len(domainSet)) + for d := range domainSet { + if d != "" { + domains = append(domains, d) + } + } + sort.Strings(domains) + totalBeforeCap := len(domains) + if hardCap > 0 && len(domains) > hardCap { + domains = domains[:hardCap] + logp("domain cap applied: before=%d after=%d hard_cap=%d", totalBeforeCap, len(domains), hardCap) + } + logp( + "domains expanded: bases=%d subs_total=%d subs_per_base_limit=%d expanded_added=%d twitter_added=%d total_before_cap=%d total_used=%d", + len(bases), + len(subs), + subsPerBaseLimit, + expandedAdded, + twitterAdded, + totalBeforeCap, + len(domains), + ) + if wildcardBasesAdded > 0 { + logp("domains wildcard seed added: %d base domains from smartdns.conf state", wildcardBasesAdded) + } + + domTmp, _ := os.CreateTemp(stateDir, "domains-*.txt") + defer os.Remove(domTmp.Name()) + for _, d := range domains { + _, _ = domTmp.WriteString(d + "\n") + } + domTmp.Close() + + ipTmp, _ := os.CreateTemp(stateDir, "ips-*.txt") + ipTmp.Close() + ipMapTmp, _ := os.CreateTemp(stateDir, "ipmap-*.txt") + ipMapTmp.Close() + ipDirectTmp, _ := os.CreateTemp(stateDir, "ips-direct-*.txt") + ipDirectTmp.Close() + ipDynTmp, _ := os.CreateTemp(stateDir, "ips-dyn-*.txt") + ipDynTmp.Close() + ipMapDirectTmp, _ := os.CreateTemp(stateDir, "ipmap-direct-*.txt") + ipMapDirectTmp.Close() + ipMapDynTmp, _ := os.CreateTemp(stateDir, "ipmap-dyn-*.txt") + ipMapDynTmp.Close() + + heartbeat() + logp("using Go resolver for domains -> IPs") + mode := loadDNSMode() + runtimeEnabled := smartDNSRuntimeEnabled() + wildcardSource := wildcardFillSource(runtimeEnabled) + logp("resolver mode=%s smartdns_addr=%s wildcards=%d", mode.Mode, mode.SmartDNSAddr, len(wildcards)) + logp("wildcard source baseline: %s (runtime_nftset=%t)", wildcardSource, runtimeEnabled) + + resolveOpts := ResolverOpts{ + DomainsPath: domTmp.Name(), + MetaPath: domainDir + "/meta-special.txt", + StaticPath: staticIPsFile, + CachePath: stateDir + "/domain-cache.json", + PtrCachePath: stateDir + "/ptr-cache.json", + TraceLog: traceLogPath, + TTL: envInt("RESOLVE_TTL", 24*3600), + Workers: envInt("RESOLVE_JOBS", 40), + DNSConfigPath: dnsUpstreamsConf, + ViaSmartDNS: mode.ViaSmartDNS, // legacy fallback for older clients/state + Mode: mode.Mode, + SmartDNSAddr: mode.SmartDNSAddr, + SmartDNSWildcards: wildcards, + } + + resJob, err := runResolverJob(resolveOpts, logp) + if err != nil { + logp("Go resolver FAILED: %v", err) + res.Message = fmt.Sprintf("resolver failed: %v", err) + return res + } + + if err := writeLines(ipTmp.Name(), resJob.IPs); err != nil { + logp("write ips failed: %v", err) + res.Message = fmt.Sprintf("write ips failed: %v", err) + return res + } + if err := writeMapPairs(ipMapTmp.Name(), resJob.IPMap); err != nil { + logp("write ip_map failed: %v", err) + res.Message = fmt.Sprintf("write ip_map failed: %v", err) + return res + } + if err := writeLines(ipDirectTmp.Name(), resJob.DirectIPs); err != nil { + logp("write direct ips failed: %v", err) + res.Message = fmt.Sprintf("write direct ips failed: %v", err) + return res + } + if err := writeLines(ipDynTmp.Name(), resJob.WildcardIPs); err != nil { + logp("write wildcard ips failed: %v", err) + res.Message = fmt.Sprintf("write wildcard ips failed: %v", err) + return res + } + if err := writeMapPairs(ipMapDirectTmp.Name(), resJob.DirectIPMap); err != nil { + logp("write direct ip_map failed: %v", err) + res.Message = fmt.Sprintf("write direct ip_map failed: %v", err) + return res + } + if err := writeMapPairs(ipMapDynTmp.Name(), resJob.WildcardIPMap); err != nil { + logp("write wildcard ip_map failed: %v", err) + res.Message = fmt.Sprintf("write wildcard ip_map failed: %v", err) + return res + } + saveJSON(resJob.DomainCache, resolveOpts.CachePath) + saveJSON(resJob.PtrCache, resolveOpts.PtrCachePath) + + heartbeat() + + ipCount := len(resJob.IPs) + directIPCount := len(resJob.DirectIPs) + wildcardIPCount := len(resJob.WildcardIPs) + domainCount := countDomainsFromPairs(resJob.IPMap) + + // ----------------------------------------------------------------- + // nft population + // ----------------------------------------------------------------- + + // nft load через умный апдейтер + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + progressCb := func(percent int, msg string) { + logp("NFT progress: %d%% - %s", percent, msg) + heartbeat() + events.push("routes_nft_progress", map[string]any{ + "percent": percent, + "message": msg, + }) + } + + progressRange := func(start, end int, prefix string) ProgressCallback { + if progressCb == nil { + return nil + } + if end < start { + end = start + } + return func(percent int, msg string) { + if percent < 0 { + percent = 0 + } + if percent > 100 { + percent = 100 + } + scaled := start + (end-start)*percent/100 + if strings.TrimSpace(msg) == "" { + msg = "updating" + } + progressCb(scaled, prefix+": "+msg) + } + } + + if err := nftUpdateSetIPsSmart(ctx, "agvpn4", resJob.DirectIPs, progressRange(0, 50, "agvpn4")); err != nil { + logp("nft set update failed for agvpn4: %v", err) + res.Message = fmt.Sprintf("nft update failed for agvpn4: %v", err) + return res + } + if err := nftUpdateSetIPsSmart(ctx, "agvpn_dyn4", resJob.WildcardIPs, progressRange(50, 100, "agvpn_dyn4")); err != nil { + logp("nft set update failed for agvpn_dyn4: %v", err) + res.Message = fmt.Sprintf("nft update failed for agvpn_dyn4: %v", err) + return res + } + + logp("summary: domains=%d, unique_ips=%d direct_ips=%d wildcard_ips=%d", len(domains), ipCount, directIPCount, wildcardIPCount) + logp("updated agvpn4 with %d IPs (direct + static)", directIPCount) + logp("updated agvpn_dyn4 with %d IPs (wildcard, source=%s)", wildcardIPCount, wildcardSource) + logWildcardSmartDNSTrace(mode, wildcardSource, resJob.WildcardIPMap, wildcardIPCount) + + // ----------------------------------------------------------------- + // artifacts + status + // ----------------------------------------------------------------- + + // copy artifacts + _ = copyFile(ipTmp.Name(), lastIPsPath) + _ = copyFile(ipMapTmp.Name(), lastIPsMapPath) + _ = copyFile(ipDirectTmp.Name(), lastIPsDirect) + _ = copyFile(ipDynTmp.Name(), lastIPsDyn) + _ = copyFile(ipMapDirectTmp.Name(), lastIPsMapDirect) + _ = copyFile(ipMapDynTmp.Name(), lastIPsMapDyn) + + now := time.Now().Format(time.RFC3339) + status := Status{ + Timestamp: now, + IPCount: ipCount, + DomainCount: domainCount, + Iface: iface, + Table: routesTableName(), + Mark: MARK, + } + statusData, _ := json.MarshalIndent(status, "", " ") + _ = os.WriteFile(statusFilePath, statusData, 0o644) + + chownDev( + traceLogPath, + ipTmp.Name(), ipMapTmp.Name(), + ipDirectTmp.Name(), ipDynTmp.Name(), ipMapDirectTmp.Name(), ipMapDynTmp.Name(), + lastIPsPath, lastIPsMapPath, lastIPsDirect, lastIPsDyn, lastIPsMapDirect, lastIPsMapDyn, + statusFilePath, + heartbeatFile, + ) + chmodPaths( + 0o644, + ipTmp.Name(), ipMapTmp.Name(), + ipDirectTmp.Name(), ipDynTmp.Name(), ipMapDirectTmp.Name(), ipMapDynTmp.Name(), + lastIPsPath, lastIPsMapPath, lastIPsDirect, lastIPsDyn, lastIPsMapDirect, lastIPsMapDyn, + statusFilePath, + heartbeatFile, + ) + _ = os.Chmod(traceLogPath, 0o666) + _ = os.Chmod(stateDir, 0o755) + + heartbeat() + + res.OK = true + res.Message = fmt.Sprintf("update done: domains=%d unique_ips=%d direct_ips=%d wildcard_ips=%d", len(domains), ipCount, directIPCount, wildcardIPCount) + res.ExitCode = ipCount + return res +} + +// --------------------------------------------------------------------- +// routesUpdate helpers: table / list / counters +// --------------------------------------------------------------------- + +func routesTableName() string { return "agvpn" } + +// --------------------------------------------------------------------- +// EN: `routesTableNum` contains core logic for routes table num. +// RU: `routesTableNum` - содержит основную логику для routes table num. +// --------------------------------------------------------------------- +func routesTableNum() string { return "666" } + +// --------------------------------------------------------------------- +// EN: `loadList` loads list from storage or config. +// RU: `loadList` - загружает list из хранилища или конфига. +// --------------------------------------------------------------------- +func loadList(path string) []string { + data, err := os.ReadFile(path) + if err != nil { + return nil + } + var out []string + for _, ln := range strings.Split(string(data), "\n") { + ln = strings.TrimSpace(strings.SplitN(ln, "#", 2)[0]) + if ln == "" { + continue + } + out = append(out, ln) + } + return out +} + +// --------------------------------------------------------------------- +// EN: `loadSmartDNSWildcardDomains` loads SmartDNS wildcard domains from canonical API state. +// RU: `loadSmartDNSWildcardDomains` - загружает wildcard-домены SmartDNS из каноничного API-состояния. +// --------------------------------------------------------------------- +func loadSmartDNSWildcardDomains(logf func(string, ...any)) []string { + out, source := loadSmartDNSWildcardDomainsState(logf) + sort.Strings(out) + if logf != nil { + logf("smartdns wildcards loaded: source=%s count=%d", source, len(out)) + } + return out +} + +// --------------------------------------------------------------------- +// EN: `isGoogleLike` checks whether google like is true. +// RU: `isGoogleLike` - проверяет, является ли google like истинным условием. +// --------------------------------------------------------------------- +func isGoogleLike(d string) bool { + low := strings.ToLower(d) + for _, base := range googleLikeDomains { + if low == base || strings.HasSuffix(low, "."+base) { + return true + } + } + return false +} + +// --------------------------------------------------------------------- +// EN: `readNonEmptyLines` reads non empty lines from input data. +// RU: `readNonEmptyLines` - читает non empty lines из входных данных. +// --------------------------------------------------------------------- +func readNonEmptyLines(path string) []string { + data, err := os.ReadFile(path) + if err != nil { + return nil + } + var out []string + for _, ln := range strings.Split(string(data), "\n") { + ln = strings.TrimSpace(ln) + if ln != "" { + out = append(out, ln) + } + } + return out +} + +func writeLines(path string, lines []string) error { + if len(lines) == 0 { + return os.WriteFile(path, []byte{}, 0o644) + } + return os.WriteFile(path, []byte(strings.Join(lines, "\n")+"\n"), 0o644) +} + +func writeMapPairs(path string, pairs [][2]string) error { + if len(pairs) == 0 { + return os.WriteFile(path, []byte{}, 0o644) + } + lines := make([]string, 0, len(pairs)) + for _, p := range pairs { + lines = append(lines, p[0]+"\t"+p[1]) + } + return os.WriteFile(path, []byte(strings.Join(lines, "\n")+"\n"), 0o644) +} + +func countDomainsFromPairs(pairs [][2]string) int { + seen := make(map[string]struct{}) + for _, p := range pairs { + if len(p) < 2 { + continue + } + d := strings.TrimSpace(p[1]) + if d == "" || strings.HasPrefix(d, "[") { + continue + } + seen[d] = struct{}{} + } + return len(seen) +} + +func wildcardHostIPMap(pairs [][2]string) map[string][]string { + hostToIPs := make(map[string]map[string]struct{}) + for _, p := range pairs { + if len(p) < 2 { + continue + } + ip := strings.TrimSpace(p[0]) + host := strings.TrimSpace(p[1]) + if ip == "" || host == "" || strings.HasPrefix(host, "[") { + continue + } + ips := hostToIPs[host] + if ips == nil { + ips = map[string]struct{}{} + hostToIPs[host] = ips + } + ips[ip] = struct{}{} + } + + out := make(map[string][]string, len(hostToIPs)) + for host, ipset := range hostToIPs { + ips := make([]string, 0, len(ipset)) + for ip := range ipset { + ips = append(ips, ip) + } + sort.Strings(ips) + out[host] = ips + } + return out +} + +func logWildcardSmartDNSTrace(mode DNSMode, source string, pairs [][2]string, wildcardIPCount int) { + lowMode := strings.ToLower(strings.TrimSpace(string(mode.Mode))) + if lowMode != string(DNSModeHybridWildcard) && lowMode != string(DNSModeSmartDNS) { + return + } + + hostMap := wildcardHostIPMap(pairs) + hosts := make([]string, 0, len(hostMap)) + for host := range hostMap { + hosts = append(hosts, host) + } + sort.Strings(hosts) + + appendTraceLineTo( + smartdnsLogPath, + "smartdns", + fmt.Sprintf("wildcard sync: mode=%s source=%s domains=%d ips=%d", mode.Mode, source, len(hosts), wildcardIPCount), + ) + + const maxHostsLog = 200 + for i, host := range hosts { + if i >= maxHostsLog { + appendTraceLineTo( + smartdnsLogPath, + "smartdns", + fmt.Sprintf("wildcard sync: +%d domains omitted", len(hosts)-maxHostsLog), + ) + return + } + appendTraceLineTo( + smartdnsLogPath, + "smartdns", + fmt.Sprintf("wildcard add: %s -> %s", host, strings.Join(hostMap[host], ", ")), + ) + } +} + +// --------------------------------------------------------------------- +// EN: `countDomainsFromMap` counts items for domains from map. +// RU: `countDomainsFromMap` - считает элементы для domains from map. +// --------------------------------------------------------------------- +func countDomainsFromMap(path string) int { + data, err := os.ReadFile(path) + if err != nil { + return 0 + } + seen := make(map[string]struct{}) + for _, ln := range strings.Split(string(data), "\n") { + ln = strings.TrimSpace(ln) + if ln == "" { + continue + } + fields := strings.Fields(ln) + if len(fields) < 2 { + continue + } + d := fields[1] + if strings.HasPrefix(d, "[") { + continue + } + seen[d] = struct{}{} + } + return len(seen) +} + +// --------------------------------------------------------------------- +// filesystem helpers +// --------------------------------------------------------------------- + +func copyFile(src, dst string) error { + data, err := os.ReadFile(src) + if err != nil { + return err + } + return os.WriteFile(dst, data, 0o644) +} + +// --------------------------------------------------------------------- +// EN: `chownDev` contains core logic for chown dev. +// RU: `chownDev` - содержит основную логику для chown dev. +// --------------------------------------------------------------------- +func chownDev(paths ...string) { + usr, err := user.Lookup("dev") + if err != nil { + return + } + uid, _ := strconv.Atoi(usr.Uid) + gid, _ := strconv.Atoi(usr.Gid) + for _, p := range paths { + _ = os.Chown(p, uid, gid) + } +} + +// --------------------------------------------------------------------- +// EN: `chmodPaths` contains core logic for chmod paths. +// RU: `chmodPaths` - содержит основную логику для chmod paths. +// --------------------------------------------------------------------- +func chmodPaths(mode fs.FileMode, paths ...string) { + for _, p := range paths { + _ = os.Chmod(p, mode) + } +} + +// --------------------------------------------------------------------- +// readiness helpers +// --------------------------------------------------------------------- + +func waitDNS(attempts int, delay time.Duration) error { + target := "openai.com" + for i := 0; i < attempts; i++ { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + _, err := net.DefaultResolver.LookupHost(ctx, target) + cancel() + if err == nil { + return nil + } + time.Sleep(delay) + } + return fmt.Errorf("dns lookup failed after %d attempts", attempts) +} diff --git a/selective-vpn-api/app/seeds.go b/selective-vpn-api/app/seeds.go new file mode 100644 index 0000000..e1564b2 --- /dev/null +++ b/selective-vpn-api/app/seeds.go @@ -0,0 +1,47 @@ +package app + +import ( + "fmt" + "io/fs" + "os" +) + +// --------------------------------------------------------------------- +// seed-файлы и bootstrap +// --------------------------------------------------------------------- + +// EN: Bootstrap seed files on first run so the API can start with sane defaults +// EN: even when runtime configuration files do not exist yet. +// RU: Инициализация seed-файлов при первом запуске, чтобы API стартовал +// RU: с корректными значениями по умолчанию при отсутствии runtime-конфигов. + +// --------------------------------------------------------------------- +// seed initializer +// --------------------------------------------------------------------- + +func ensureSeeds() { + _ = os.MkdirAll(domainDir, 0o755) + _ = os.MkdirAll("/etc/selective-vpn", 0o755) + _ = os.MkdirAll(stateDir, 0o755) + + seedFile := func(name string, path string) { + if _, err := os.Stat(path); err == nil { + return + } + data, err := fs.ReadFile(embeddedDomains, "assets/domains/"+name) + if err != nil { + data = []byte{} + } + _ = os.WriteFile(path, data, 0o644) + } + + seedFile("bases.txt", domainDir+"/bases.txt") + seedFile("subs.txt", domainDir+"/subs.txt") + seedFile("meta-special.txt", domainDir+"/meta-special.txt") + seedFile("static-ips.txt", staticIPsFile) + + if _, err := os.Stat(dnsUpstreamsConf); err != nil { + content := fmt.Sprintf("default %s %s\nmeta %s %s\n", defaultDNS1, defaultDNS2, defaultMeta1, defaultMeta2) + _ = os.WriteFile(dnsUpstreamsConf, []byte(content), 0o644) + } +} diff --git a/selective-vpn-api/app/server.go b/selective-vpn-api/app/server.go new file mode 100644 index 0000000..272b226 --- /dev/null +++ b/selective-vpn-api/app/server.go @@ -0,0 +1,204 @@ +package app + +import ( + "context" + "errors" + "flag" + "fmt" + "log" + "net/http" + "os" + "syscall" + "time" +) + +// --------------------------------------------------------------------- +// main + общие хелперы +// --------------------------------------------------------------------- + +// EN: Application entrypoint and process bootstrap. +// EN: This file wires CLI modes, registers all HTTP routes, and starts background +// EN: watchers plus the localhost API server. +// RU: Точка входа приложения и bootstrap процесса. +// RU: Этот файл связывает CLI-режимы, регистрирует все HTTP-маршруты и запускает +// RU: фоновые вотчеры вместе с локальным API-сервером. +func Run() { + // --------------------------------------------------------------------- + // CLI modes + // --------------------------------------------------------------------- + + // CLI mode: routes-update + if len(os.Args) > 1 && (os.Args[1] == "routes-update" || os.Args[1] == "-routes-update") { + fs := flag.NewFlagSet("routes-update", flag.ExitOnError) + iface := fs.String("iface", "", "VPN interface (empty/auto = detect active)") + _ = fs.Parse(os.Args[2:]) + lock, err := os.OpenFile(lockFile, os.O_CREATE|os.O_RDWR, 0o644) + if err != nil { + fmt.Fprintf(os.Stderr, "lock open error: %v\n", err) + os.Exit(1) + } + if err := syscall.Flock(int(lock.Fd()), syscall.LOCK_EX|syscall.LOCK_NB); err != nil { + fmt.Println("routes update already running") + lock.Close() + return + } + res := routesUpdate(*iface) + _ = syscall.Flock(int(lock.Fd()), syscall.LOCK_UN) + _ = lock.Close() + if res.OK { + fmt.Println(res.Message) + return + } + fmt.Fprintln(os.Stderr, res.Message) + os.Exit(1) + } + + // CLI mode: routes-clear + if len(os.Args) > 1 && os.Args[1] == "routes-clear" { + res := routesClear() + if res.OK { + fmt.Println(res.Message) + return + } + fmt.Fprintln(os.Stderr, res.Message) + os.Exit(1) + } + + // CLI mode: autoloop + if len(os.Args) > 1 && os.Args[1] == "autoloop" { + fs := flag.NewFlagSet("autoloop", flag.ExitOnError) + iface := fs.String("iface", "", "VPN interface (empty/auto = detect active)") + table := fs.String("table", "agvpn", "routing table name") + mtu := fs.Int("mtu", 1380, "MTU for default route") + stateDirFlag := fs.String("state-dir", stateDir, "state directory") + defaultLoc := fs.String("default-location", "Austria", "default location") + _ = fs.Parse(os.Args[2:]) + resolvedIface := normalizePreferredIface(*iface) + if resolvedIface == "" { + resolvedIface, _ = resolveTrafficIface(loadTrafficModeState().PreferredIface) + } + if resolvedIface == "" { + fmt.Fprintln(os.Stderr, "autoloop: cannot resolve VPN interface (set --iface or preferred iface)") + os.Exit(1) + } + runAutoloop(resolvedIface, *table, *mtu, *stateDirFlag, *defaultLoc) + return + } + + // --------------------------------------------------------------------- + // API server bootstrap + // --------------------------------------------------------------------- + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ensureSeeds() + + mux := http.NewServeMux() + + // --------------------------------------------------------------------- + // route registration + // --------------------------------------------------------------------- + + // health + mux.HandleFunc("/healthz", handleHealthz) + // event stream (SSE) + mux.HandleFunc("/api/v1/events/stream", handleEventsStream) + + // статус selective-routes + mux.HandleFunc("/api/v1/status", handleGetStatus) + mux.HandleFunc("/api/v1/routes/status", handleGetStatus) + + // login state + mux.HandleFunc("/api/v1/vpn/login-state", handleVPNLoginState) + + // systemd state + mux.HandleFunc("/api/v1/systemd/state", handleSystemdState) + + // сервис selective-routes + mux.HandleFunc("/api/v1/routes/service/start", + makeRoutesServiceActionHandler("start")) + mux.HandleFunc("/api/v1/routes/service/stop", + makeRoutesServiceActionHandler("stop")) + mux.HandleFunc("/api/v1/routes/service/restart", + makeRoutesServiceActionHandler("restart")) + // универсальный: {"action":"start|stop|restart"} + mux.HandleFunc("/api/v1/routes/service", handleRoutesService) + // ручной апдейт маршрутов (Go-реализация вместо bash) + mux.HandleFunc("/api/v1/routes/update", handleRoutesUpdate) + + // таймер маршрутов (новый API) + mux.HandleFunc("/api/v1/routes/timer", handleRoutesTimer) + // старый toggle для совместимости + mux.HandleFunc("/api/v1/routes/timer/toggle", handleRoutesTimerToggle) + + // rollback / clear (Go implementation) + mux.HandleFunc("/api/v1/routes/rollback", handleRoutesClear) + // alias: /routes/clear + mux.HandleFunc("/api/v1/routes/clear", handleRoutesClear) + // fast restore from clear-cache + mux.HandleFunc("/api/v1/routes/cache/restore", handleRoutesCacheRestore) + + // фиксим policy route + mux.HandleFunc("/api/v1/routes/fix-policy-route", handleFixPolicyRoute) + mux.HandleFunc("/api/v1/routes/fix-policy", handleFixPolicyRoute) + mux.HandleFunc("/api/v1/traffic/mode", handleTrafficMode) + mux.HandleFunc("/api/v1/traffic/mode/test", handleTrafficModeTest) + mux.HandleFunc("/api/v1/traffic/interfaces", handleTrafficInterfaces) + mux.HandleFunc("/api/v1/traffic/candidates", handleTrafficCandidates) + + // trace: хвост + JSON + append для GUI + mux.HandleFunc("/api/v1/trace", handleTraceTailPlain) + mux.HandleFunc("/api/v1/trace-json", handleTraceJSON) + mux.HandleFunc("/api/v1/trace/append", handleTraceAppend) + + // DNS upstreams + mux.HandleFunc("/api/v1/dns-upstreams", handleDNSUpstreams) + mux.HandleFunc("/api/v1/dns/status", handleDNSStatus) + mux.HandleFunc("/api/v1/dns/mode", handleDNSModeSet) + mux.HandleFunc("/api/v1/dns/smartdns-service", handleDNSSmartdnsService) + + // SmartDNS service + mux.HandleFunc("/api/v1/smartdns/service", handleSmartdnsService) + mux.HandleFunc("/api/v1/smartdns/runtime", handleSmartdnsRuntime) + mux.HandleFunc("/api/v1/smartdns/prewarm", handleSmartdnsPrewarm) + + // domains editor + mux.HandleFunc("/api/v1/domains/table", handleDomainsTable) + mux.HandleFunc("/api/v1/domains/file", handleDomainsFile) + + // SmartDNS wildcards + mux.HandleFunc("/api/v1/smartdns/wildcards", handleSmartdnsWildcards) + + // AdGuard VPN: status + autoloop + autoconnect + locations + mux.HandleFunc("/api/v1/vpn/autoloop-status", handleVPNAutoloopStatus) + mux.HandleFunc("/api/v1/vpn/status", handleVPNStatus) + mux.HandleFunc("/api/v1/vpn/autoconnect", handleVPNAutoconnect) + mux.HandleFunc("/api/v1/vpn/locations", handleVPNListLocations) + mux.HandleFunc("/api/v1/vpn/location", handleVPNSetLocation) + + // AdGuard VPN: interactive login session (PTY) + mux.HandleFunc("/api/v1/vpn/login/session/start", handleVPNLoginSessionStart) + mux.HandleFunc("/api/v1/vpn/login/session/state", handleVPNLoginSessionState) + mux.HandleFunc("/api/v1/vpn/login/session/action", handleVPNLoginSessionAction) + mux.HandleFunc("/api/v1/vpn/login/session/stop", handleVPNLoginSessionStop) + // logout + mux.HandleFunc("/api/v1/vpn/logout", handleVPNLogout) + + // --------------------------------------------------------------------- + // HTTP server + // --------------------------------------------------------------------- + + srv := &http.Server{ + Addr: "127.0.0.1:8080", + Handler: logRequests(mux), + ReadHeaderTimeout: 5 * time.Second, + } + + go startWatchers(ctx) + + log.Printf("selective-vpn API listening on %s", srv.Addr) + if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Fatalf("server error: %v", err) + } +} diff --git a/selective-vpn-api/app/shell.go b/selective-vpn-api/app/shell.go new file mode 100644 index 0000000..d446388 --- /dev/null +++ b/selective-vpn-api/app/shell.go @@ -0,0 +1,72 @@ +package app + +import ( + "context" + "errors" + "fmt" + "os/exec" + "strings" + "time" +) + +// --------------------------------------------------------------------- +// низкоуровневые helpers +// --------------------------------------------------------------------- + +// EN: Low-level command execution adapters with timeout handling and small +// EN: policy-route verification helper used by higher-level handlers. +// RU: Низкоуровневые адаптеры запуска команд с таймаутами и вспомогательной +// RU: проверкой policy-route, используемой верхнеуровневыми обработчиками. + +func runCommand(name string, args ...string) (string, string, int, error) { + return runCommandTimeout(60*time.Second, name, args...) +} + +// --------------------------------------------------------------------- +// policy route check +// --------------------------------------------------------------------- + +func checkPolicyRoute(iface, table string) (bool, error) { + stdout, _, exitCode, err := runCommand("ip", "route", "show", "table", table) + if exitCode != 0 { + if err == nil { + err = fmt.Errorf("ip route show exited with %d", exitCode) + } + return false, err + } + want := fmt.Sprintf("default dev %s", iface) + for _, line := range strings.Split(stdout, "\n") { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, want) { + return true, nil + } + } + return false, nil +} + +// --------------------------------------------------------------------- +// command timeout helper +// --------------------------------------------------------------------- + +func runCommandTimeout(timeout time.Duration, name string, args ...string) (string, string, int, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + cmd := exec.CommandContext(ctx, name, args...) + out, err := cmd.CombinedOutput() + stdout := string(out) + stderr := stdout + + exitCode := 0 + if err != nil { + if ee, ok := err.(*exec.ExitError); ok { + exitCode = ee.ExitCode() + } else if errors.Is(err, context.DeadlineExceeded) { + exitCode = -1 + err = fmt.Errorf("command timeout: %w", err) + } else { + exitCode = -1 + } + } + return stdout, stderr, exitCode, err +} diff --git a/selective-vpn-api/app/smartdns_runtime.go b/selective-vpn-api/app/smartdns_runtime.go new file mode 100644 index 0000000..8b636b1 --- /dev/null +++ b/selective-vpn-api/app/smartdns_runtime.go @@ -0,0 +1,224 @@ +package app + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "time" +) + +// --------------------------------------------------------------------- +// smartdns runtime accelerator (nftset -> agvpn_dyn4) +// --------------------------------------------------------------------- + +// EN: Runtime accelerator state is persisted separately from DNS mode. +// EN: This allows enabling/disabling SmartDNS nftset hook without changing +// EN: resolver primary behavior. +// RU: Состояние runtime-ускорителя хранится отдельно от DNS mode. +// RU: Это позволяет включать/выключать SmartDNS nftset-hook независимо от +// RU: основного пути через резолвер. + +const ( + smartdnsRuntimeDomainSetLine = "domain-set -name agvpn_wild -file /etc/selective-vpn/smartdns.conf" + smartdnsRuntimeNftsetLine = "nftset /domain-set:agvpn_wild/#4:inet#agvpn#agvpn_dyn4" + smartdnsRuntimeStateVersion = 1 +) + +type smartDNSRuntimeState struct { + Version int `json:"version"` + Enabled bool `json:"enabled"` + UpdatedAt string `json:"updated_at"` +} + +func wildcardFillSource(runtimeEnabled bool) string { + if runtimeEnabled { + return "both" + } + return "resolver" +} + +func normalizeSmartDNSRuntimeState(st smartDNSRuntimeState) smartDNSRuntimeState { + if st.Version <= 0 { + st.Version = smartdnsRuntimeStateVersion + } + return st +} + +func smartDNSRuntimeEnabledFromConfig() (bool, error) { + data, err := os.ReadFile(smartdnsMainConfig) + if err != nil { + return false, err + } + for _, raw := range strings.Split(string(data), "\n") { + trimmed := strings.TrimSpace(raw) + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + continue + } + if strings.Contains(trimmed, "nftset") && + strings.Contains(trimmed, "domain-set:agvpn_wild") && + strings.Contains(trimmed, "agvpn_dyn4") { + return true, nil + } + } + return false, nil +} + +func inferSmartDNSRuntimeEnabled() bool { + enabled, err := smartDNSRuntimeEnabledFromConfig() + if err != nil { + // Keep historical behavior on first run when config is unavailable. + return true + } + return enabled +} + +func loadSmartDNSRuntimeState(logf func(string, ...any)) smartDNSRuntimeState { + if data, err := os.ReadFile(smartdnsRTPath); err == nil { + var st smartDNSRuntimeState + if json.Unmarshal(data, &st) == nil { + return normalizeSmartDNSRuntimeState(st) + } + if logf != nil { + logf("smartdns runtime: invalid state json at %s, rebuilding", smartdnsRTPath) + } + } + + st := smartDNSRuntimeState{ + Version: smartdnsRuntimeStateVersion, + Enabled: inferSmartDNSRuntimeEnabled(), + UpdatedAt: time.Now().UTC().Format(time.RFC3339), + } + _ = saveSmartDNSRuntimeState(st) + return st +} + +func saveSmartDNSRuntimeState(st smartDNSRuntimeState) error { + st = normalizeSmartDNSRuntimeState(st) + st.UpdatedAt = time.Now().UTC().Format(time.RFC3339) + data, err := json.MarshalIndent(st, "", " ") + if err != nil { + return err + } + if err := os.MkdirAll(filepath.Dir(smartdnsRTPath), 0o755); err != nil { + return err + } + tmp := smartdnsRTPath + ".tmp" + if err := os.WriteFile(tmp, data, 0o644); err != nil { + return err + } + return os.Rename(tmp, smartdnsRTPath) +} + +func smartDNSRuntimeEnabled() bool { + return loadSmartDNSRuntimeState(nil).Enabled +} + +func normalizeSmartDNSMainConfig(content string, enabled bool) string { + normalized := strings.ReplaceAll(content, "\r\n", "\n") + lines := strings.Split(normalized, "\n") + out := make([]string, 0, len(lines)+4) + + seenDomain := false + seenNftset := false + + isDomainLine := func(raw string) bool { + t := strings.TrimSpace(raw) + if strings.HasPrefix(t, "#") { + t = strings.TrimSpace(strings.TrimPrefix(t, "#")) + } + return strings.HasPrefix(t, "domain-set ") && + strings.Contains(t, "-name agvpn_wild") && + strings.Contains(t, "/etc/selective-vpn/smartdns.conf") + } + isNftsetLine := func(raw string) bool { + t := strings.TrimSpace(raw) + if strings.HasPrefix(t, "#") { + t = strings.TrimSpace(strings.TrimPrefix(t, "#")) + } + return strings.HasPrefix(t, "nftset ") && + strings.Contains(t, "domain-set:agvpn_wild") && + strings.Contains(t, "agvpn_dyn4") + } + + for _, raw := range lines { + switch { + case isDomainLine(raw): + if !seenDomain { + if enabled { + out = append(out, smartdnsRuntimeDomainSetLine) + } else { + out = append(out, "# "+smartdnsRuntimeDomainSetLine) + } + seenDomain = true + } + case isNftsetLine(raw): + if !seenNftset { + if enabled { + out = append(out, smartdnsRuntimeNftsetLine) + } else { + out = append(out, "# "+smartdnsRuntimeNftsetLine) + } + seenNftset = true + } + default: + out = append(out, raw) + } + } + + if enabled && (!seenDomain || !seenNftset) { + if len(out) > 0 && strings.TrimSpace(out[len(out)-1]) != "" { + out = append(out, "") + } + if !seenDomain { + out = append(out, smartdnsRuntimeDomainSetLine) + } + if !seenNftset { + out = append(out, smartdnsRuntimeNftsetLine) + } + } + + rendered := strings.Join(out, "\n") + if !strings.HasSuffix(rendered, "\n") { + rendered += "\n" + } + return rendered +} + +func applySmartDNSRuntimeConfig(enabled bool) (bool, error) { + data, err := os.ReadFile(smartdnsMainConfig) + if err != nil { + return false, err + } + current := strings.ReplaceAll(string(data), "\r\n", "\n") + next := normalizeSmartDNSMainConfig(current, enabled) + if next == current { + return false, nil + } + tmp := smartdnsMainConfig + ".tmp" + if err := os.WriteFile(tmp, []byte(next), 0o644); err != nil { + return false, err + } + if err := os.Rename(tmp, smartdnsMainConfig); err != nil { + return false, err + } + return true, nil +} + +func smartDNSRuntimeSnapshot() SmartDNSRuntimeStatusResponse { + st := loadSmartDNSRuntimeState(nil) + appliedEnabled, err := smartDNSRuntimeEnabledFromConfig() + msg := "" + if err != nil { + msg = fmt.Sprintf("config read error: %v", err) + } + return SmartDNSRuntimeStatusResponse{ + Enabled: st.Enabled, + AppliedEnable: appliedEnabled, + WildcardSource: wildcardFillSource(st.Enabled), + UnitState: smartdnsUnitState(), + ConfigPath: smartdnsMainConfig, + Message: msg, + } +} diff --git a/selective-vpn-api/app/smartdns_wildcards_store.go b/selective-vpn-api/app/smartdns_wildcards_store.go new file mode 100644 index 0000000..7457484 --- /dev/null +++ b/selective-vpn-api/app/smartdns_wildcards_store.go @@ -0,0 +1,132 @@ +package app + +import ( + "encoding/json" + "os" + "path/filepath" + "strings" + "time" +) + +// --------------------------------------------------------------------- +// smartdns wildcard canonical store +// --------------------------------------------------------------------- + +// EN: Canonical SmartDNS wildcard storage is JSON in stateDir. +// EN: `/etc/selective-vpn/smartdns.conf` is generated as a runtime artifact. +// RU: Каноничное хранилище wildcard-доменов SmartDNS — JSON в stateDir. +// RU: `/etc/selective-vpn/smartdns.conf` генерируется как runtime-артефакт. + +type smartDNSWildcardState struct { + Version int `json:"version"` + UpdatedAt string `json:"updated_at"` + Domains []string `json:"domains"` +} + +func normalizeWildcardDomains(raw []string) []string { + seen := map[string]struct{}{} + out := make([]string, 0, len(raw)) + for _, ln := range raw { + d := normalizeWildcardDomain(ln) + if d == "" { + continue + } + if _, ok := seen[d]; ok { + continue + } + seen[d] = struct{}{} + out = append(out, d) + } + return out +} + +func parseSmartDNSDomainsContent(content string) []string { + return normalizeWildcardDomains(strings.Split(content, "\n")) +} + +func renderSmartDNSDomainsContent(domains []string) string { + header := strings.TrimSpace(` +# Auto-generated by selective-vpn API. +# SmartDNS wildcard rules for selective VPN / AGVPN. +`) + "\n" + if len(domains) == 0 { + return header + } + return header + "\n" + strings.Join(domains, "\n") + "\n" +} + +func loadSmartDNSWildcardDomainsState(logf func(string, ...any)) ([]string, string) { + if data, err := os.ReadFile(smartdnsWLPath); err == nil { + // preferred shape: object with metadata + var st smartDNSWildcardState + if json.Unmarshal(data, &st) == nil { + domains := normalizeWildcardDomains(st.Domains) + _ = writeSmartDNSDomainsArtifact(domains) + return domains, "state" + } + // backward-compat shape: plain []string + var arr []string + if json.Unmarshal(data, &arr) == nil { + domains := normalizeWildcardDomains(arr) + _ = saveSmartDNSWildcardDomainsState(domains) + return domains, "state-legacy" + } + if logf != nil { + logf("smartdns wildcards: invalid state json at %s, fallback to conf", smartdnsWLPath) + } + } + + // migration path: parse legacy .conf file if state json is missing/broken. + confData, err := os.ReadFile(smartdnsDomainsFile) + if err == nil { + domains := parseSmartDNSDomainsContent(string(confData)) + if saveErr := saveSmartDNSWildcardDomainsState(domains); saveErr != nil && logf != nil { + logf("smartdns wildcards: migration from conf failed: %v", saveErr) + } + return domains, "migrated-conf" + } + + // bootstrap empty canonical state + artifact. + _ = saveSmartDNSWildcardDomainsState(nil) + return nil, "default" +} + +func saveSmartDNSWildcardDomainsState(domains []string) error { + normalized := normalizeWildcardDomains(domains) + + state := smartDNSWildcardState{ + Version: 1, + UpdatedAt: time.Now().UTC().Format(time.RFC3339), + Domains: normalized, + } + data, err := json.MarshalIndent(state, "", " ") + if err != nil { + return err + } + + if err := os.MkdirAll(filepath.Dir(smartdnsWLPath), 0o755); err != nil { + return err + } + if err := os.MkdirAll(filepath.Dir(smartdnsDomainsFile), 0o755); err != nil { + return err + } + + stateTmp := smartdnsWLPath + ".tmp" + if err := os.WriteFile(stateTmp, data, 0o644); err != nil { + return err + } + if err := os.Rename(stateTmp, smartdnsWLPath); err != nil { + return err + } + + return writeSmartDNSDomainsArtifact(normalized) +} + +func writeSmartDNSDomainsArtifact(domains []string) error { + content := renderSmartDNSDomainsContent(domains) + tmp := smartdnsDomainsFile + ".tmp" + if err := os.WriteFile(tmp, []byte(content), 0o644); err != nil { + return err + } + return os.Rename(tmp, smartdnsDomainsFile) +} diff --git a/selective-vpn-api/app/trace_handlers.go b/selective-vpn-api/app/trace_handlers.go new file mode 100644 index 0000000..48bc117 --- /dev/null +++ b/selective-vpn-api/app/trace_handlers.go @@ -0,0 +1,261 @@ +package app + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strconv" + "strings" + "time" +) + +// --------------------------------------------------------------------- +// trace: чтение + запись +// --------------------------------------------------------------------- + +// EN: Trace log endpoints and helpers for GUI/operator visibility. +// EN: Includes plain tail, filtered JSON views, append API, and bounded tail reader. +// RU: Эндпоинты и хелперы trace-логов для GUI/оператора. +// RU: Включает plain tail, фильтрованные JSON-режимы, append API и безопасный tail-reader. + +func handleTraceTailPlain(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + lines := tailFile(traceLogPath, defaultTraceTailMax) + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + _, _ = io.WriteString(w, strings.Join(lines, "\n")) +} + +// --------------------------------------------------------------------- +// trace-json +// --------------------------------------------------------------------- + +// GET /api/v1/trace-json?mode=full|gui|events|smartdns +func handleTraceJSON(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + mode := strings.ToLower(strings.TrimSpace(r.URL.Query().Get("mode"))) + if mode == "" { + mode = "full" + } + if mode == "events" { + mode = "gui" + } + + var lines []string + + switch mode { + case "smartdns": + // чисто SmartDNS-лог + lines = tailFile(smartdnsLogPath, defaultTraceTailMax) + + case "gui": + // Events: только человеко-читабельные события/ошибки/команды. + full := tailFile(traceLogPath, defaultTraceTailMax) + allow := []string{ + "[gui]", "[info]", "[login]", "[vpn]", "[event]", "[error]", + } + for _, l := range full { + ll := strings.ToLower(l) + + // берём только наши "человеческие" префиксы + ok := false + for _, a := range allow { + if strings.Contains(ll, strings.ToLower(a)) { + ok = true + break + } + } + if !ok { + // если префикса нет, но это похоже на ошибку — тоже включаем + if strings.Contains(ll, "error") || strings.Contains(ll, "failed") || strings.Contains(ll, "timeout") { + ok = true + } + } + if !ok { + continue + } + + // режем шум от резолвера/маршрутов/массовых вставок + if strings.Contains(ll, "smartdns") || + strings.Contains(ll, "resolver") || + strings.Contains(ll, "dnstt") || + strings.Contains(ll, "routes") || + strings.Contains(ll, "nft add element") || + strings.Contains(ll, "cache hit:") { + continue + } + + lines = append(lines, l) + } + + default: // full + // полный хвост trace.log без фильтрации + lines = tailFile(traceLogPath, defaultTraceTailMax) + } + + writeJSON(w, http.StatusOK, map[string]any{ + "lines": lines, + }) +} + +// --------------------------------------------------------------------- +// trace append +// --------------------------------------------------------------------- + +// POST /api/v1/trace/append { "kind": "gui|smartdns|info", "line": "..." } +func handleTraceAppend(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + var body struct { + Kind string `json:"kind"` + Line string `json:"line"` + } + if r.Body != nil { + defer r.Body.Close() + _ = json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&body) + } + kind := strings.ToLower(strings.TrimSpace(body.Kind)) + line := strings.TrimRight(body.Line, "\r\n") + + if line == "" { + writeJSON(w, http.StatusOK, map[string]any{"ok": true}) + return + } + + _ = os.MkdirAll(stateDir, 0o755) + + switch kind { + case "smartdns": + appendTraceLineTo(smartdnsLogPath, "smartdns", line) + case "gui": + appendTraceLineTo(traceLogPath, "gui", line) + default: + appendTraceLineTo(traceLogPath, "info", line) + } + + events.push("trace_append", map[string]any{ + "kind": kind, + }) + + writeJSON(w, http.StatusOK, map[string]any{"ok": true}) +} + +// --------------------------------------------------------------------- +// trace write helpers +// --------------------------------------------------------------------- + +func appendTraceLineTo(path, prefix, line string) { + line = strings.TrimRight(line, "\r\n") + if line == "" { + return + } + ts := time.Now().UTC().Format(time.RFC3339) + _ = os.MkdirAll(stateDir, 0o755) + + // простейший "ручной логротейт" + const maxSize = 10 * 1024 * 1024 // 10 МБ + if fi, err := os.Stat(path); err == nil && fi.Size() > maxSize { + // можно просто truncate + _ = os.Truncate(path, 0) + // или переименовать в *.1 и начать новый + // _ = os.Rename(path, path+".1") + } + + f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return + } + defer f.Close() + _, _ = fmt.Fprintf(f, "[%s] %s %s\n", prefix, ts, line) +} + +// --------------------------------------------------------------------- +// EN: `appendTraceLine` appends or adds trace line to an existing state. +// RU: `appendTraceLine` - добавляет trace line в текущее состояние. +// --------------------------------------------------------------------- +func appendTraceLine(prefix, line string) { + appendTraceLineTo(traceLogPath, prefix, line) +} + +// --------------------------------------------------------------------- +// tail helper +// --------------------------------------------------------------------- + +const defaultTailMaxBytes = 512 * 1024 + +func tailFile(path string, maxLines int) []string { + if maxLines <= 0 { + return nil + } + + // читаем лимит из env, если задан + maxBytes := defaultTailMaxBytes + if env := os.Getenv("SVPN_TAIL_MAX_BYTES"); env != "" { + if n, err := strconv.Atoi(env); err == nil && n > 0 { + maxBytes = n + } + } + + f, err := os.Open(path) + if err != nil { + // файла нет или нет прав — просто ничего не отдаём + return nil + } + defer f.Close() + + fi, err := f.Stat() + if err != nil { + return nil + } + size := fi.Size() + if size <= 0 { + return nil + } + + // с какого смещения читаем хвост + start := int64(0) + if size > int64(maxBytes) { + start = size - int64(maxBytes) + } + + // двигаем указатель в файле + if _, err := f.Seek(start, io.SeekStart); err != nil { + return nil + } + + // читаем хвост + data, err := io.ReadAll(f) + if err != nil { + return nil + } + + // режем по строкам + lines := strings.Split(string(data), "\n") + + // если мы начали читать с середины файла (start > 0), + // первая строка почти наверняка обрезана — выбрасываем её. + if start > 0 && len(lines) > 0 { + lines = lines[1:] + } + + // убираем финальную пустую строку, если есть + if n := len(lines); n > 0 && lines[n-1] == "" { + lines = lines[:n-1] + } + + // берём только последние maxLines + if len(lines) > maxLines { + lines = lines[len(lines)-maxLines:] + } + + return lines +} diff --git a/selective-vpn-api/app/traffic_candidates.go b/selective-vpn-api/app/traffic_candidates.go new file mode 100644 index 0000000..9edb0fb --- /dev/null +++ b/selective-vpn-api/app/traffic_candidates.go @@ -0,0 +1,225 @@ +package app + +import ( + "net/http" + "sort" + "strconv" + "strings" + "time" +) + +// --------------------------------------------------------------------- +// traffic candidates (subnets / systemd units / UIDs) +// --------------------------------------------------------------------- + +// EN: Provides best-effort suggestions for traffic overrides UI. +// EN: This endpoint must never apply anything automatically. +// RU: Отдаёт подсказки для UI overrides. +// RU: Этот эндпоинт никогда не должен применять что-либо автоматически. + +func handleTrafficCandidates(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + resp := TrafficCandidatesResponse{ + GeneratedAt: time.Now().UTC().Format(time.RFC3339), + Subnets: trafficCandidateSubnets(), + Units: trafficCandidateUnits(), + UIDs: trafficCandidateUIDs(), + } + writeJSON(w, http.StatusOK, resp) +} + +func trafficCandidateSubnets() []TrafficCandidateSubnet { + out, _, code, _ := runCommand("ip", "-4", "route", "show", "table", "main") + if code != 0 { + return nil + } + + seen := map[string]struct{}{} + items := make([]TrafficCandidateSubnet, 0, 24) + + for _, raw := range strings.Split(out, "\n") { + line := strings.TrimSpace(raw) + if line == "" { + continue + } + fields := strings.Fields(line) + if len(fields) == 0 { + continue + } + + dst := strings.TrimSpace(fields[0]) + if dst == "" || dst == "default" { + continue + } + dev := parseRouteDevice(fields) + if dev == "" || dev == "lo" { + continue + } + if isVPNLikeIface(dev) { + continue + } + + isDocker := isContainerIface(dev) + isLocal := isAutoBypassDestination(dst) + if !isDocker && !isLocal { + // keep suggestions intentionally small: only local/LAN + container subnets + continue + } + + kind := "lan" + if isDocker { + kind = "docker" + } else if strings.Contains(" "+strings.ToLower(line)+" ", " scope link ") { + kind = "link" + } + + key := kind + "|" + dst + "|" + dev + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + + items = append(items, TrafficCandidateSubnet{ + CIDR: dst, + Dev: dev, + Kind: kind, + LinkDown: strings.Contains(strings.ToLower(line), " linkdown"), + }) + } + + sort.Slice(items, func(i, j int) bool { + if items[i].Kind != items[j].Kind { + return items[i].Kind < items[j].Kind + } + if items[i].Dev != items[j].Dev { + return items[i].Dev < items[j].Dev + } + return items[i].CIDR < items[j].CIDR + }) + return items +} + +func trafficCandidateUnits() []TrafficCandidateUnit { + stdout, _, code, _ := runCommand( + "systemctl", + "list-units", + "--type=service", + "--state=running", + "--no-legend", + "--no-pager", + "--plain", + ) + if code != 0 { + return nil + } + + seen := map[string]struct{}{} + items := make([]TrafficCandidateUnit, 0, 32) + for _, raw := range strings.Split(stdout, "\n") { + line := strings.TrimSpace(raw) + if line == "" { + continue + } + fields := strings.Fields(line) + if len(fields) < 1 { + continue + } + unit := strings.TrimSpace(fields[0]) + if unit == "" { + continue + } + if _, ok := seen[unit]; ok { + continue + } + seen[unit] = struct{}{} + + desc := "" + // UNIT LOAD ACTIVE SUB DESCRIPTION + if len(fields) >= 5 { + desc = strings.Join(fields[4:], " ") + } + + items = append(items, TrafficCandidateUnit{ + Unit: unit, + Description: strings.TrimSpace(desc), + Cgroup: "system.slice/" + unit, + }) + } + + sort.Slice(items, func(i, j int) bool { + return items[i].Unit < items[j].Unit + }) + return items +} + +func trafficCandidateUIDs() []TrafficCandidateUID { + stdout, _, code, _ := runCommand("ps", "-eo", "uid,user,comm", "--no-headers") + if code != 0 { + return nil + } + + type agg struct { + uid int + user string + comms map[string]struct{} + } + + aggs := map[int]*agg{} + for _, raw := range strings.Split(stdout, "\n") { + line := strings.TrimSpace(raw) + if line == "" { + continue + } + fields := strings.Fields(line) + if len(fields) < 2 { + continue + } + uidN, err := strconv.Atoi(strings.TrimSpace(fields[0])) + if err != nil || uidN < 0 { + continue + } + user := strings.TrimSpace(fields[1]) + comm := "" + if len(fields) >= 3 { + comm = strings.TrimSpace(fields[2]) + } + + a := aggs[uidN] + if a == nil { + a = &agg{uid: uidN, user: user, comms: map[string]struct{}{}} + aggs[uidN] = a + } + if a.user == "" && user != "" { + a.user = user + } + if comm != "" { + a.comms[comm] = struct{}{} + } + } + + items := make([]TrafficCandidateUID, 0, len(aggs)) + for _, a := range aggs { + examples := make([]string, 0, len(a.comms)) + for c := range a.comms { + examples = append(examples, c) + } + sort.Strings(examples) + if len(examples) > 3 { + examples = examples[:3] + } + items = append(items, TrafficCandidateUID{ + UID: a.uid, + User: a.user, + Examples: examples, + }) + } + + sort.Slice(items, func(i, j int) bool { + return items[i].UID < items[j].UID + }) + return items +} diff --git a/selective-vpn-api/app/traffic_mode.go b/selective-vpn-api/app/traffic_mode.go new file mode 100644 index 0000000..fe79de2 --- /dev/null +++ b/selective-vpn-api/app/traffic_mode.go @@ -0,0 +1,1154 @@ +package app + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/netip" + "os" + "path/filepath" + "sort" + "strconv" + "strings" + "time" +) + +const ( + trafficRulePrefDirectSubnetStart = 11600 + trafficRulePrefDirectUIDStart = 11680 + trafficRulePrefVPNSubnetStart = 11720 + trafficRulePrefVPNUIDStart = 11800 + trafficRulePrefFull = 11900 + trafficRulePrefSelective = 12000 + trafficRulePrefManagedMin = 11600 + trafficRulePrefManagedMax = 12099 + trafficRulePerKindLimit = 70 + trafficAutoLocalDefault = true +) + +var cgnatPrefix = netip.MustParsePrefix("100.64.0.0/10") + +const cgroupRootPath = "/sys/fs/cgroup" + +// --------------------------------------------------------------------- +// traffic mode (selective / full_tunnel / direct) +// --------------------------------------------------------------------- + +// EN: Controls route-policy behavior independently from DNS mode. +// EN: Uses a persisted desired state with runtime verification and rollback. +// RU: Управляет policy routing независимо от DNS-режима. +// RU: Использует сохраненное desired-state, runtime-проверку и откат. + +func normalizeTrafficMode(raw TrafficMode) TrafficMode { + switch strings.ToLower(strings.TrimSpace(string(raw))) { + case string(TrafficModeFullTunnel): + return TrafficModeFullTunnel + case string(TrafficModeDirect): + return TrafficModeDirect + case string(TrafficModeSelective): + return TrafficModeSelective + default: + return TrafficModeSelective + } +} + +func normalizePreferredIface(raw string) string { + v := strings.TrimSpace(raw) + l := strings.ToLower(v) + if l == "" || l == "auto" || l == "-" || l == "default" { + return "" + } + return v +} + +func tokenizeList(raw []string) []string { + repl := strings.NewReplacer(",", " ", ";", " ", "\n", " ", "\t", " ") + out := make([]string, 0, len(raw)) + for _, line := range raw { + for _, tok := range strings.Fields(repl.Replace(line)) { + val := strings.TrimSpace(tok) + if val != "" { + out = append(out, val) + } + } + } + return out +} + +func normalizeSubnetList(raw []string) []string { + seen := map[string]struct{}{} + out := make([]string, 0, len(raw)) + for _, tok := range tokenizeList(raw) { + var cidr string + if strings.Contains(tok, "/") { + pfx, err := netip.ParsePrefix(tok) + if err != nil || !pfx.Addr().Is4() { + continue + } + cidr = pfx.Masked().String() + } else { + ip, err := netip.ParseAddr(tok) + if err != nil || !ip.Is4() { + continue + } + cidr = netip.PrefixFrom(ip, 32).String() + } + if _, ok := seen[cidr]; ok { + continue + } + seen[cidr] = struct{}{} + out = append(out, cidr) + } + sort.Strings(out) + return out +} + +func normalizeUIDToken(tok string) (string, bool) { + t := strings.TrimSpace(tok) + if t == "" { + return "", false + } + parseOne := func(s string) (uint64, bool) { + n, err := strconv.ParseUint(strings.TrimSpace(s), 10, 32) + if err != nil { + return 0, false + } + return n, true + } + if strings.Contains(t, "-") { + parts := strings.SplitN(t, "-", 2) + if len(parts) != 2 { + return "", false + } + start, okA := parseOne(parts[0]) + end, okB := parseOne(parts[1]) + if !okA || !okB || end < start { + return "", false + } + return fmt.Sprintf("%d-%d", start, end), true + } + n, ok := parseOne(t) + if !ok { + return "", false + } + return fmt.Sprintf("%d-%d", n, n), true +} + +func normalizeUIDList(raw []string) []string { + seen := map[string]struct{}{} + out := make([]string, 0, len(raw)) + for _, tok := range tokenizeList(raw) { + v, ok := normalizeUIDToken(tok) + if !ok { + continue + } + if _, exists := seen[v]; exists { + continue + } + seen[v] = struct{}{} + out = append(out, v) + } + sort.Strings(out) + return out +} + +func normalizeCgroupList(raw []string) []string { + seen := map[string]struct{}{} + out := make([]string, 0, len(raw)) + for _, tok := range tokenizeList(raw) { + v := strings.TrimSpace(tok) + if v == "" { + continue + } + v = strings.TrimSuffix(v, "/") + if v == "" { + v = "/" + } + if _, exists := seen[v]; exists { + continue + } + seen[v] = struct{}{} + out = append(out, v) + } + sort.Strings(out) + return out +} + +func normalizeTrafficModeState(st TrafficModeState) TrafficModeState { + st.Mode = normalizeTrafficMode(st.Mode) + st.PreferredIface = normalizePreferredIface(st.PreferredIface) + st.ForceVPNSubnets = normalizeSubnetList(st.ForceVPNSubnets) + st.ForceVPNUIDs = normalizeUIDList(st.ForceVPNUIDs) + st.ForceVPNCGroups = normalizeCgroupList(st.ForceVPNCGroups) + st.ForceDirectSubnets = normalizeSubnetList(st.ForceDirectSubnets) + st.ForceDirectUIDs = normalizeUIDList(st.ForceDirectUIDs) + st.ForceDirectCGroups = normalizeCgroupList(st.ForceDirectCGroups) + return st +} + +func loadTrafficModeState() TrafficModeState { + data, err := os.ReadFile(trafficModePath) + if err != nil { + return inferTrafficModeState() + } + + type diskState struct { + Mode TrafficMode `json:"mode"` + PreferredIface string `json:"preferred_iface,omitempty"` + AutoLocalBypass *bool `json:"auto_local_bypass,omitempty"` + ForceVPNSubnets []string `json:"force_vpn_subnets,omitempty"` + ForceVPNUIDs []string `json:"force_vpn_uids,omitempty"` + ForceVPNCGroups []string `json:"force_vpn_cgroups,omitempty"` + ForceDirectSubnets []string `json:"force_direct_subnets,omitempty"` + ForceDirectUIDs []string `json:"force_direct_uids,omitempty"` + ForceDirectCGroups []string `json:"force_direct_cgroups,omitempty"` + } + var raw diskState + if err := json.Unmarshal(data, &raw); err != nil { + return inferTrafficModeState() + } + st := TrafficModeState{ + Mode: raw.Mode, + PreferredIface: raw.PreferredIface, + AutoLocalBypass: trafficAutoLocalDefault, + ForceVPNSubnets: append([]string(nil), raw.ForceVPNSubnets...), + ForceVPNUIDs: append([]string(nil), raw.ForceVPNUIDs...), + ForceVPNCGroups: append([]string(nil), raw.ForceVPNCGroups...), + ForceDirectSubnets: append([]string(nil), raw.ForceDirectSubnets...), + ForceDirectUIDs: append([]string(nil), raw.ForceDirectUIDs...), + ForceDirectCGroups: append([]string(nil), raw.ForceDirectCGroups...), + } + if raw.AutoLocalBypass != nil { + st.AutoLocalBypass = *raw.AutoLocalBypass + } + return normalizeTrafficModeState(st) +} + +func saveTrafficModeState(st TrafficModeState) error { + st = normalizeTrafficModeState(st) + st.UpdatedAt = time.Now().UTC().Format(time.RFC3339) + + data, err := json.MarshalIndent(st, "", " ") + if err != nil { + return err + } + if err := os.MkdirAll(stateDir, 0o755); err != nil { + return err + } + tmp := trafficModePath + ".tmp" + if err := os.WriteFile(tmp, data, 0o644); err != nil { + return err + } + return os.Rename(tmp, trafficModePath) +} + +func inferTrafficModeState() TrafficModeState { + rules := readTrafficRules() + mode := detectAppliedTrafficMode(rules) + iface, _ := resolveTrafficIface("") + return normalizeTrafficModeState(TrafficModeState{ + Mode: mode, + PreferredIface: iface, + AutoLocalBypass: trafficAutoLocalDefault, + ForceVPNSubnets: nil, + ForceVPNUIDs: nil, + ForceVPNCGroups: nil, + ForceDirectSubnets: nil, + ForceDirectUIDs: nil, + ForceDirectCGroups: nil, + }) +} + +func ensureRoutesTableEntry() { + data, _ := os.ReadFile("/etc/iproute2/rt_tables") + want := fmt.Sprintf("%s %s", routesTableNum(), routesTableName()) + if strings.Contains(string(data), "\n"+want) || strings.HasPrefix(string(data), want) { + return + } + f, err := os.OpenFile("/etc/iproute2/rt_tables", os.O_APPEND|os.O_WRONLY, 0o644) + if err != nil { + return + } + defer f.Close() + _, _ = fmt.Fprintf(f, "%s\n", want) +} + +func ifaceExists(iface string) bool { + iface = strings.TrimSpace(iface) + if iface == "" { + return false + } + _, _, code, _ := runCommand("ip", "link", "show", iface) + return code == 0 +} + +func statusIfaceFromFile() string { + data, err := os.ReadFile(statusFilePath) + if err != nil { + return "" + } + var st Status + if json.Unmarshal(data, &st) != nil { + return "" + } + return strings.TrimSpace(st.Iface) +} + +func listUpIfaces() []string { + out, _, code, _ := runCommand("ip", "-o", "link", "show", "up") + if code != 0 { + return nil + } + seen := map[string]struct{}{} + var outIfaces []string + for _, line := range strings.Split(out, "\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + parts := strings.SplitN(line, ":", 3) + if len(parts) < 3 { + continue + } + name := strings.TrimSpace(parts[1]) + name = strings.SplitN(name, "@", 2)[0] + name = strings.TrimSpace(name) + if name == "" || name == "lo" { + continue + } + if _, ok := seen[name]; ok { + continue + } + seen[name] = struct{}{} + outIfaces = append(outIfaces, name) + } + return outIfaces +} + +func listSelectableIfaces(preferred string) []string { + up := listUpIfaces() + seen := map[string]struct{}{} + var vpnLike []string + var other []string + + add := func(dst *[]string, iface string) { + iface = strings.TrimSpace(iface) + if iface == "" { + return + } + if _, ok := seen[iface]; ok { + return + } + seen[iface] = struct{}{} + *dst = append(*dst, iface) + } + + for _, iface := range up { + if isVPNLikeIface(iface) { + add(&vpnLike, iface) + } + } + for _, iface := range up { + if !isVPNLikeIface(iface) { + add(&other, iface) + } + } + sort.Strings(vpnLike) + sort.Strings(other) + + selected := make([]string, 0, len(vpnLike)+len(other)+1) + selected = append(selected, vpnLike...) + selected = append(selected, other...) + + pref := normalizePreferredIface(preferred) + if pref != "" { + if _, ok := seen[pref]; !ok { + selected = append([]string{pref}, selected...) + } + } + return selected +} + +func isVPNLikeIface(iface string) bool { + l := strings.ToLower(strings.TrimSpace(iface)) + return strings.HasPrefix(l, "tun") || + strings.HasPrefix(l, "wg") || + strings.HasPrefix(l, "ppp") || + strings.HasPrefix(l, "tap") || + strings.HasPrefix(l, "utun") || + strings.HasPrefix(l, "vpn") +} + +func resolveTrafficIface(preferred string) (string, string) { + pref := normalizePreferredIface(preferred) + if pref != "" && ifaceExists(pref) { + return pref, "preferred" + } + + statusIface := statusIfaceFromFile() + if statusIface != "" && ifaceExists(statusIface) { + return statusIface, "status" + } + + for _, iface := range listUpIfaces() { + if isVPNLikeIface(iface) { + return iface, "auto-vpn-like" + } + } + + if pref != "" { + return "", "preferred-not-found" + } + return "", "iface-not-found" +} + +type autoLocalRoute struct { + Dst string + Dev string +} + +func parseRouteDevice(fields []string) string { + for i := 0; i+1 < len(fields); i++ { + if fields[i] == "dev" { + return strings.TrimSpace(fields[i+1]) + } + } + return "" +} + +func isContainerIface(iface string) bool { + l := strings.ToLower(strings.TrimSpace(iface)) + return strings.HasPrefix(l, "docker") || + strings.HasPrefix(l, "br-") || + strings.HasPrefix(l, "veth") || + strings.HasPrefix(l, "cni") +} + +func isPrivateLikeAddr(a netip.Addr) bool { + if !a.Is4() { + return false + } + if a.IsPrivate() || a.IsLoopback() || a.IsLinkLocalUnicast() { + return true + } + // Carrier-grade NAT block. + return cgnatPrefix.Contains(a) +} + +func isAutoBypassDestination(dst string) bool { + dst = strings.TrimSpace(dst) + if dst == "" || dst == "default" { + return false + } + if strings.Contains(dst, "/") { + pfx, err := netip.ParsePrefix(dst) + if err != nil { + return false + } + return isPrivateLikeAddr(pfx.Addr()) + } + addr, err := netip.ParseAddr(dst) + if err != nil { + return false + } + return isPrivateLikeAddr(addr) +} + +func detectAutoLocalBypassRoutes(vpnIface string) []autoLocalRoute { + vpnIface = strings.TrimSpace(vpnIface) + out, _, code, _ := runCommand("ip", "-4", "route", "show", "table", "main") + if code != 0 { + return nil + } + + seen := map[string]struct{}{} + routes := make([]autoLocalRoute, 0, 8) + + add := func(dst, dev string) { + dst = strings.TrimSpace(dst) + dev = strings.TrimSpace(dev) + if dst == "" || dev == "" { + return + } + key := dst + "|" + dev + if _, ok := seen[key]; ok { + return + } + seen[key] = struct{}{} + routes = append(routes, autoLocalRoute{Dst: dst, Dev: dev}) + } + + for _, raw := range strings.Split(out, "\n") { + line := strings.TrimSpace(raw) + if line == "" { + continue + } + fields := strings.Fields(line) + if len(fields) == 0 { + continue + } + dst := strings.TrimSpace(fields[0]) + if dst == "" || dst == "default" { + continue + } + dev := parseRouteDevice(fields) + if dev == "" || dev == "lo" { + continue + } + if vpnIface != "" && dev == vpnIface { + continue + } + if isVPNLikeIface(dev) { + continue + } + + isScopeLink := strings.Contains(" "+line+" ", " scope link ") + if isScopeLink || isContainerIface(dev) || isAutoBypassDestination(dst) { + add(dst, dev) + } + } + + sort.Slice(routes, func(i, j int) bool { + if routes[i].Dev == routes[j].Dev { + return routes[i].Dst < routes[j].Dst + } + return routes[i].Dev < routes[j].Dev + }) + return routes +} + +func applyAutoLocalBypass(vpnIface string) { + for _, rt := range detectAutoLocalBypassRoutes(vpnIface) { + _, _, _, _ = runCommand( + "ip", "-4", "route", "replace", + rt.Dst, "dev", rt.Dev, "table", routesTableName(), + ) + } +} + +func prefStr(v int) string { + return strconv.Itoa(v) +} + +func removeTrafficRulesForTable() { + out, _, _, _ := runCommand("ip", "rule", "show") + for _, line := range strings.Split(out, "\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + fields := strings.Fields(line) + if len(fields) == 0 { + continue + } + pref := strings.TrimSuffix(fields[0], ":") + if pref == "" { + continue + } + prefNum, _ := strconv.Atoi(pref) + low := strings.ToLower(line) + managed := prefNum >= trafficRulePrefManagedMin && prefNum <= trafficRulePrefManagedMax + legacy := strings.Contains(low, "lookup "+routesTableName()) + if !managed && !legacy { + continue + } + _, _, _, _ = runCommand("ip", "rule", "del", "pref", pref) + } +} + +func cgroupCandidates(entry string) []string { + v := strings.TrimSpace(entry) + if v == "" { + return nil + } + vc := filepath.Clean(v) + vals := []string{} + if filepath.IsAbs(vc) { + if strings.HasPrefix(vc, cgroupRootPath) { + vals = append(vals, vc) + } else { + vals = append(vals, filepath.Join(cgroupRootPath, strings.TrimPrefix(vc, "/"))) + } + } else { + vals = append(vals, + filepath.Join(cgroupRootPath, strings.TrimPrefix(vc, "/")), + filepath.Join(cgroupRootPath, "system.slice", strings.TrimPrefix(vc, "/")), + filepath.Join(cgroupRootPath, "user.slice", strings.TrimPrefix(vc, "/")), + ) + } + seen := map[string]struct{}{} + out := make([]string, 0, len(vals)) + for _, p := range vals { + cp := filepath.Clean(p) + if cp == "." || cp == "" { + continue + } + if _, ok := seen[cp]; ok { + continue + } + seen[cp] = struct{}{} + out = append(out, cp) + } + return out +} + +func resolveCgroupPath(entry string) (string, string) { + for _, cand := range cgroupCandidates(entry) { + fi, err := os.Stat(cand) + if err != nil || !fi.IsDir() { + continue + } + return cand, "" + } + return "", "cgroup not found: " + strings.TrimSpace(entry) +} + +func collectPIDsFromCgroup(root string) (map[int]struct{}, string) { + const ( + maxDirs = 5000 + maxPIDs = 50000 + ) + + pids := map[int]struct{}{} + dirs := 0 + warn := "" + + _ = filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error { + if err != nil || d == nil || !d.IsDir() { + return nil + } + dirs++ + if dirs > maxDirs { + warn = "cgroup scan truncated by directory limit" + return filepath.SkipDir + } + data, err := os.ReadFile(filepath.Join(path, "cgroup.procs")) + if err != nil { + return nil + } + for _, ln := range strings.Split(string(data), "\n") { + ln = strings.TrimSpace(ln) + if ln == "" { + continue + } + pid, err := strconv.Atoi(ln) + if err != nil || pid <= 0 { + continue + } + pids[pid] = struct{}{} + if len(pids) > maxPIDs { + warn = "cgroup scan truncated by pid limit" + return filepath.SkipDir + } + } + return nil + }) + return pids, warn +} + +func uidRangeForPID(pid int) (string, bool) { + data, err := os.ReadFile(fmt.Sprintf("/proc/%d/status", pid)) + if err != nil { + return "", false + } + for _, ln := range strings.Split(string(data), "\n") { + ln = strings.TrimSpace(ln) + if !strings.HasPrefix(ln, "Uid:") { + continue + } + fields := strings.Fields(ln) + if len(fields) < 2 { + return "", false + } + v, ok := normalizeUIDToken(fields[1]) + return v, ok + } + return "", false +} + +func resolveCgroupUIDRanges(entries []string) ([]string, string) { + var uids []string + var warnings []string + + for _, entry := range normalizeCgroupList(entries) { + root, warn := resolveCgroupPath(entry) + if root == "" { + if warn != "" { + warnings = append(warnings, warn) + } + continue + } + pids, scanWarn := collectPIDsFromCgroup(root) + if scanWarn != "" { + warnings = append(warnings, scanWarn) + } + if len(pids) == 0 { + warnings = append(warnings, "cgroup has no processes: "+entry) + continue + } + for pid := range pids { + uidRange, ok := uidRangeForPID(pid) + if !ok || uidRange == "" { + continue + } + uids = append(uids, uidRange) + } + } + seenWarn := map[string]struct{}{} + uniqWarn := make([]string, 0, len(warnings)) + for _, w := range warnings { + ww := strings.TrimSpace(w) + if ww == "" { + continue + } + if _, ok := seenWarn[ww]; ok { + continue + } + seenWarn[ww] = struct{}{} + uniqWarn = append(uniqWarn, ww) + } + return normalizeUIDList(uids), strings.Join(uniqWarn, "; ") +} + +type effectiveTrafficOverrides struct { + VPNSubnets []string + VPNUIDs []string + DirectSubnets []string + DirectUIDs []string + CgroupResolvedUIDs int + CgroupWarning string +} + +func buildEffectiveOverrides(st TrafficModeState) effectiveTrafficOverrides { + st = normalizeTrafficModeState(st) + e := effectiveTrafficOverrides{ + VPNSubnets: append([]string(nil), st.ForceVPNSubnets...), + VPNUIDs: append([]string(nil), st.ForceVPNUIDs...), + DirectSubnets: append([]string(nil), st.ForceDirectSubnets...), + DirectUIDs: append([]string(nil), st.ForceDirectUIDs...), + } + + vpnUIDsFromCG, warnVPN := resolveCgroupUIDRanges(st.ForceVPNCGroups) + directUIDsFromCG, warnDirect := resolveCgroupUIDRanges(st.ForceDirectCGroups) + e.CgroupResolvedUIDs = len(vpnUIDsFromCG) + len(directUIDsFromCG) + e.VPNUIDs = normalizeUIDList(append(e.VPNUIDs, vpnUIDsFromCG...)) + e.DirectUIDs = normalizeUIDList(append(e.DirectUIDs, directUIDsFromCG...)) + warns := make([]string, 0, 2) + if strings.TrimSpace(warnVPN) != "" { + warns = append(warns, strings.TrimSpace(warnVPN)) + } + if strings.TrimSpace(warnDirect) != "" { + warns = append(warns, strings.TrimSpace(warnDirect)) + } + e.CgroupWarning = strings.Join(warns, "; ") + return e +} + +func applyRule(pref int, args ...string) error { + if pref <= 0 { + return fmt.Errorf("invalid pref: %d", pref) + } + cmd := []string{"rule", "add"} + cmd = append(cmd, args...) + cmd = append(cmd, "pref", prefStr(pref)) + _, _, code, err := runCommand("ip", cmd...) + if err != nil || code != 0 { + if err == nil { + err = fmt.Errorf("ip %s exited with %d", strings.Join(cmd, " "), code) + } + return err + } + return nil +} + +func applyTrafficOverrides(e effectiveTrafficOverrides) (int, error) { + applied := 0 + if len(e.DirectSubnets) > trafficRulePerKindLimit || + len(e.DirectUIDs) > trafficRulePerKindLimit || + len(e.VPNSubnets) > trafficRulePerKindLimit || + len(e.VPNUIDs) > trafficRulePerKindLimit { + return 0, fmt.Errorf("override list too large (max %d entries per kind)", trafficRulePerKindLimit) + } + + for i, cidr := range e.DirectSubnets { + if err := applyRule(trafficRulePrefDirectSubnetStart+i, "from", cidr, "lookup", "main"); err != nil { + return applied, err + } + applied++ + } + for i, uidr := range e.DirectUIDs { + if err := applyRule(trafficRulePrefDirectUIDStart+i, "uidrange", uidr, "lookup", "main"); err != nil { + return applied, err + } + applied++ + } + for i, cidr := range e.VPNSubnets { + if err := applyRule(trafficRulePrefVPNSubnetStart+i, "from", cidr, "lookup", routesTableName()); err != nil { + return applied, err + } + applied++ + } + for i, uidr := range e.VPNUIDs { + if err := applyRule(trafficRulePrefVPNUIDStart+i, "uidrange", uidr, "lookup", routesTableName()); err != nil { + return applied, err + } + applied++ + } + return applied, nil +} + +func ensureTrafficRouteBase(iface string, autoLocalBypass bool) error { + iface = strings.TrimSpace(iface) + if iface == "" { + return fmt.Errorf("empty interface") + } + if !ifaceExists(iface) { + return fmt.Errorf("interface not found: %s", iface) + } + + ensureRoutesTableEntry() + + if _, _, code, err := runCommand("ip", "-4", "route", "replace", "default", "dev", iface, "table", routesTableName(), "mtu", policyRouteMTU); err != nil || code != 0 { + if err == nil { + err = fmt.Errorf("ip route replace default exited with %d", code) + } + return err + } + + if autoLocalBypass { + applyAutoLocalBypass(iface) + } + return nil +} + +func applyTrafficMode(st TrafficModeState, iface string) error { + st = normalizeTrafficModeState(st) + eff := buildEffectiveOverrides(st) + + removeTrafficRulesForTable() + + needVPNTable := st.Mode != TrafficModeDirect || len(eff.VPNSubnets) > 0 || len(eff.VPNUIDs) > 0 + if needVPNTable { + if err := ensureTrafficRouteBase(iface, st.AutoLocalBypass); err != nil { + return err + } + } + + if _, err := applyTrafficOverrides(eff); err != nil { + return err + } + + switch st.Mode { + case TrafficModeFullTunnel: + if err := applyRule(trafficRulePrefFull, "lookup", routesTableName()); err != nil { + return err + } + case TrafficModeSelective: + if err := applyRule(trafficRulePrefSelective, "fwmark", MARK, "lookup", routesTableName()); err != nil { + return err + } + case TrafficModeDirect: + // direct mode relies only on optional direct/vpn overrides. + default: + return fmt.Errorf("unknown traffic mode: %s", st.Mode) + } + + return nil +} + +type trafficRulesState struct { + Mark bool + Full bool +} + +func readTrafficRules() trafficRulesState { + out, _, _, _ := runCommand("ip", "rule", "show") + var st trafficRulesState + for _, line := range strings.Split(out, "\n") { + l := strings.ToLower(strings.TrimSpace(line)) + if l == "" || !strings.Contains(l, "lookup "+routesTableName()) { + continue + } + fields := strings.Fields(l) + if len(fields) == 0 { + continue + } + prefRaw := strings.TrimSuffix(fields[0], ":") + pref, _ := strconv.Atoi(prefRaw) + switch pref { + case trafficRulePrefSelective: + st.Mark = true + case trafficRulePrefFull: + st.Full = true + } + } + return st +} + +func detectAppliedTrafficMode(rules trafficRulesState) TrafficMode { + if rules.Full { + return TrafficModeFullTunnel + } + if rules.Mark { + return TrafficModeSelective + } + return TrafficModeDirect +} + +func probeTrafficMode(mode TrafficMode, iface string) (bool, string) { + mode = normalizeTrafficMode(mode) + iface = strings.TrimSpace(iface) + + args := []string{"-4", "route", "get", "1.1.1.1"} + if mode == TrafficModeSelective { + args = append(args, "mark", MARK) + } + + out, _, code, err := runCommand("ip", args...) + if err != nil || code != 0 { + if err == nil { + err = fmt.Errorf("ip route get exited with %d", code) + } + return false, err.Error() + } + + text := strings.ToLower(out) + switch mode { + case TrafficModeDirect: + // direct mode must not be forced through agvpn rule table. + if strings.Contains(text, " table "+strings.ToLower(routesTableName())) { + return false, "route probe still uses agvpn table" + } + return true, "route probe direct path ok" + case TrafficModeFullTunnel, TrafficModeSelective: + if iface == "" { + return false, "route probe has empty iface" + } + if !strings.Contains(text, "dev "+strings.ToLower(iface)) { + return false, fmt.Sprintf("route probe mismatch: expected dev %s", iface) + } + return true, "route probe vpn path ok" + default: + return false, "route probe unknown mode" + } +} + +func evaluateTrafficMode(st TrafficModeState) TrafficModeStatusResponse { + st = normalizeTrafficModeState(st) + eff := buildEffectiveOverrides(st) + hasVPN := len(eff.VPNSubnets) > 0 || len(eff.VPNUIDs) > 0 + iface, reason := resolveTrafficIface(st.PreferredIface) + rules := readTrafficRules() + applied := detectAppliedTrafficMode(rules) + bypassCandidates := 0 + if st.AutoLocalBypass && (st.Mode != TrafficModeDirect || hasVPN) { + bypassCandidates = len(detectAutoLocalBypassRoutes(iface)) + } + + overridesApplied := len(eff.VPNSubnets) + len(eff.VPNUIDs) + len(eff.DirectSubnets) + len(eff.DirectUIDs) + + tableDefault := false + if iface != "" && (st.Mode != TrafficModeDirect || hasVPN) { + ok, _ := checkPolicyRoute(iface, routesTableName()) + tableDefault = ok + } + + res := TrafficModeStatusResponse{ + Mode: st.Mode, + DesiredMode: st.Mode, + AppliedMode: applied, + PreferredIface: st.PreferredIface, + AutoLocalBypass: st.AutoLocalBypass, + BypassCandidates: bypassCandidates, + ForceVPNSubnets: append([]string(nil), st.ForceVPNSubnets...), + ForceVPNUIDs: append([]string(nil), st.ForceVPNUIDs...), + ForceVPNCGroups: append([]string(nil), st.ForceVPNCGroups...), + ForceDirectSubnets: append([]string(nil), st.ForceDirectSubnets...), + ForceDirectUIDs: append([]string(nil), st.ForceDirectUIDs...), + ForceDirectCGroups: append([]string(nil), st.ForceDirectCGroups...), + OverridesApplied: overridesApplied, + CgroupResolvedUIDs: eff.CgroupResolvedUIDs, + CgroupWarning: eff.CgroupWarning, + ActiveIface: iface, + IfaceReason: reason, + RuleMark: rules.Mark, + RuleFull: rules.Full, + TableDefault: tableDefault, + } + + res.ProbeOK, res.ProbeMessage = probeTrafficMode(st.Mode, iface) + + switch st.Mode { + case TrafficModeDirect: + // direct mode can still be healthy when vpn overrides exist + // (base full/selective rules must be absent). + if hasVPN { + res.Healthy = !rules.Mark && !rules.Full && tableDefault && iface != "" && res.ProbeOK + } else { + res.Healthy = !rules.Mark && !rules.Full && res.ProbeOK + } + case TrafficModeFullTunnel: + res.Healthy = rules.Full && !rules.Mark && tableDefault && iface != "" && res.ProbeOK + case TrafficModeSelective: + res.Healthy = rules.Mark && !rules.Full && tableDefault && iface != "" && res.ProbeOK + default: + res.Healthy = false + } + + if res.Healthy { + res.Message = "traffic mode applied" + return res + } + if iface == "" && (st.Mode != TrafficModeDirect || hasVPN) { + res.Message = "vpn interface not found" + return res + } + if st.Mode != applied { + res.Message = fmt.Sprintf("desired=%s applied=%s mismatch", st.Mode, applied) + return res + } + if (st.Mode != TrafficModeDirect || hasVPN) && !tableDefault { + res.Message = "policy table default route is missing" + return res + } + if !res.ProbeOK { + res.Message = res.ProbeMessage + return res + } + if rules.Mark && rules.Full { + res.Message = "conflicting traffic rules detected" + return res + } + res.Message = "traffic mode check failed" + return res +} + +func handleTrafficInterfaces(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + st := loadTrafficModeState() + active, reason := resolveTrafficIface(st.PreferredIface) + resp := TrafficInterfacesResponse{ + Interfaces: listSelectableIfaces(st.PreferredIface), + PreferredIface: normalizePreferredIface(st.PreferredIface), + ActiveIface: active, + IfaceReason: reason, + } + writeJSON(w, http.StatusOK, resp) +} + +func handleTrafficModeTest(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet && r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + st := loadTrafficModeState() + writeJSON(w, http.StatusOK, evaluateTrafficMode(st)) +} + +func handleTrafficMode(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + st := loadTrafficModeState() + writeJSON(w, http.StatusOK, evaluateTrafficMode(st)) + case http.MethodPost: + prev := loadTrafficModeState() + next := prev + + var body TrafficModeRequest + if r.Body != nil { + defer r.Body.Close() + if err := json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&body); err != nil { + http.Error(w, "bad json", http.StatusBadRequest) + return + } + } + + if strings.TrimSpace(string(body.Mode)) != "" { + next.Mode = normalizeTrafficMode(body.Mode) + } + if body.PreferredIface != nil { + next.PreferredIface = normalizePreferredIface(*body.PreferredIface) + } + if body.AutoLocalBypass != nil { + next.AutoLocalBypass = *body.AutoLocalBypass + } + if body.ForceVPNSubnets != nil { + next.ForceVPNSubnets = append([]string(nil), (*body.ForceVPNSubnets)...) + } + if body.ForceVPNUIDs != nil { + next.ForceVPNUIDs = append([]string(nil), (*body.ForceVPNUIDs)...) + } + if body.ForceVPNCGroups != nil { + next.ForceVPNCGroups = append([]string(nil), (*body.ForceVPNCGroups)...) + } + if body.ForceDirectSubnets != nil { + next.ForceDirectSubnets = append([]string(nil), (*body.ForceDirectSubnets)...) + } + if body.ForceDirectUIDs != nil { + next.ForceDirectUIDs = append([]string(nil), (*body.ForceDirectUIDs)...) + } + if body.ForceDirectCGroups != nil { + next.ForceDirectCGroups = append([]string(nil), (*body.ForceDirectCGroups)...) + } + + next = normalizeTrafficModeState(next) + prev = normalizeTrafficModeState(prev) + + nextIface, _ := resolveTrafficIface(next.PreferredIface) + if err := applyTrafficMode(next, nextIface); err != nil { + prevIface, _ := resolveTrafficIface(prev.PreferredIface) + _ = applyTrafficMode(prev, prevIface) + msg := evaluateTrafficMode(prev) + msg.Message = "apply failed, rolled back: " + err.Error() + writeJSON(w, http.StatusOK, msg) + return + } + + if err := saveTrafficModeState(next); err != nil { + writeJSON(w, http.StatusOK, TrafficModeStatusResponse{ + Mode: next.Mode, + DesiredMode: next.Mode, + PreferredIface: next.PreferredIface, + AutoLocalBypass: next.AutoLocalBypass, + ForceVPNSubnets: append([]string(nil), next.ForceVPNSubnets...), + ForceVPNUIDs: append([]string(nil), next.ForceVPNUIDs...), + ForceVPNCGroups: append([]string(nil), next.ForceVPNCGroups...), + ForceDirectSubnets: append([]string(nil), next.ForceDirectSubnets...), + ForceDirectUIDs: append([]string(nil), next.ForceDirectUIDs...), + ForceDirectCGroups: append([]string(nil), next.ForceDirectCGroups...), + OverridesApplied: len(next.ForceVPNSubnets) + len(next.ForceVPNUIDs) + len(next.ForceDirectSubnets) + len(next.ForceDirectUIDs), + Healthy: false, + Message: "state save failed: " + err.Error(), + }) + return + } + + res := evaluateTrafficMode(next) + if !res.Healthy { + prevIface, _ := resolveTrafficIface(prev.PreferredIface) + _ = applyTrafficMode(prev, prevIface) + _ = saveTrafficModeState(prev) + rolled := evaluateTrafficMode(prev) + rolled.Message = "verification failed, rolled back: " + res.Message + writeJSON(w, http.StatusOK, rolled) + return + } + + events.push("traffic_mode_changed", map[string]any{ + "mode": res.Mode, + "applied": res.AppliedMode, + "active_iface": res.ActiveIface, + "healthy": res.Healthy, + "auto_local_bypass": res.AutoLocalBypass, + "overrides_applied": res.OverridesApplied, + }) + writeJSON(w, http.StatusOK, res) + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } +} diff --git a/selective-vpn-api/app/types.go b/selective-vpn-api/app/types.go new file mode 100644 index 0000000..10a0048 --- /dev/null +++ b/selective-vpn-api/app/types.go @@ -0,0 +1,250 @@ +package app + +// --------------------------------------------------------------------- +// структуры +// --------------------------------------------------------------------- + +// EN: Shared DTO/model definitions exchanged between HTTP handlers, workers, +// EN: SSE stream, and internal orchestration logic. +// RU: Общие DTO/модели, которыми обмениваются HTTP-обработчики, воркеры, +// RU: SSE-поток и внутренняя оркестрация. + +type Status struct { + Timestamp string `json:"timestamp"` + IPCount int `json:"ip_count"` + DomainCount int `json:"domain_count"` + Iface string `json:"iface"` + Table string `json:"table"` + Mark string `json:"mark"` + + PolicyRouteOK *bool `json:"policy_route_ok,omitempty"` + RouteOK *bool `json:"route_ok,omitempty"` +} + +type cmdResult struct { + OK bool `json:"ok"` + Message string `json:"message,omitempty"` + ExitCode int `json:"exitCode,omitempty"` + Stdout string `json:"stdout,omitempty"` + Stderr string `json:"stderr,omitempty"` +} + +type VPNLoginState struct { + State string `json:"state"` + Email string `json:"email,omitempty"` + Msg string `json:"msg,omitempty"` + + // для GUI + Text string `json:"text,omitempty"` + Color string `json:"color,omitempty"` +} + +type DNSUpstreams struct { + Default1 string `json:"default1"` + Default2 string `json:"default2"` + Meta1 string `json:"meta1"` + Meta2 string `json:"meta2"` +} + +type DNSResolverMode string + +const ( + DNSModeDirect DNSResolverMode = "direct" + DNSModeSmartDNS DNSResolverMode = "smartdns" + DNSModeHybridWildcard DNSResolverMode = "hybrid_wildcard" +) + +type DNSMode struct { + ViaSmartDNS bool `json:"via_smartdns"` + SmartDNSAddr string `json:"smartdns_addr"` + Mode DNSResolverMode `json:"mode"` +} + +type DNSStatusResponse struct { + ViaSmartDNS bool `json:"via_smartdns"` + SmartDNSAddr string `json:"smartdns_addr"` + Mode DNSResolverMode `json:"mode"` + UnitState string `json:"unit_state"` + RuntimeNftset bool `json:"runtime_nftset"` + WildcardSource string `json:"wildcard_source"` + RuntimeCfgPath string `json:"runtime_config_path,omitempty"` + RuntimeCfgError string `json:"runtime_config_error,omitempty"` +} + +type DNSModeRequest struct { + ViaSmartDNS bool `json:"via_smartdns"` + SmartDNSAddr string `json:"smartdns_addr"` + Mode DNSResolverMode `json:"mode"` +} + +type SmartDNSRuntimeStatusResponse struct { + Enabled bool `json:"enabled"` + AppliedEnable bool `json:"applied_enabled"` + WildcardSource string `json:"wildcard_source"` + UnitState string `json:"unit_state"` + ConfigPath string `json:"config_path"` + Changed bool `json:"changed,omitempty"` + Restarted bool `json:"restarted,omitempty"` + Message string `json:"message,omitempty"` +} + +type SmartDNSRuntimeRequest struct { + Enabled *bool `json:"enabled"` + Restart *bool `json:"restart,omitempty"` +} + +type TrafficMode string + +const ( + TrafficModeSelective TrafficMode = "selective" + TrafficModeFullTunnel TrafficMode = "full_tunnel" + TrafficModeDirect TrafficMode = "direct" +) + +type TrafficModeState struct { + Mode TrafficMode `json:"mode"` + PreferredIface string `json:"preferred_iface,omitempty"` + AutoLocalBypass bool `json:"auto_local_bypass"` + ForceVPNSubnets []string `json:"force_vpn_subnets,omitempty"` + ForceVPNUIDs []string `json:"force_vpn_uids,omitempty"` + ForceVPNCGroups []string `json:"force_vpn_cgroups,omitempty"` + ForceDirectSubnets []string `json:"force_direct_subnets,omitempty"` + ForceDirectUIDs []string `json:"force_direct_uids,omitempty"` + ForceDirectCGroups []string `json:"force_direct_cgroups,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` +} + +type TrafficModeRequest struct { + Mode TrafficMode `json:"mode"` + PreferredIface *string `json:"preferred_iface,omitempty"` + AutoLocalBypass *bool `json:"auto_local_bypass,omitempty"` + ForceVPNSubnets *[]string `json:"force_vpn_subnets,omitempty"` + ForceVPNUIDs *[]string `json:"force_vpn_uids,omitempty"` + ForceVPNCGroups *[]string `json:"force_vpn_cgroups,omitempty"` + ForceDirectSubnets *[]string `json:"force_direct_subnets,omitempty"` + ForceDirectUIDs *[]string `json:"force_direct_uids,omitempty"` + ForceDirectCGroups *[]string `json:"force_direct_cgroups,omitempty"` +} + +type TrafficModeStatusResponse struct { + Mode TrafficMode `json:"mode"` + DesiredMode TrafficMode `json:"desired_mode"` + AppliedMode TrafficMode `json:"applied_mode"` + PreferredIface string `json:"preferred_iface,omitempty"` + AutoLocalBypass bool `json:"auto_local_bypass"` + BypassCandidates int `json:"bypass_candidates"` + ForceVPNSubnets []string `json:"force_vpn_subnets,omitempty"` + ForceVPNUIDs []string `json:"force_vpn_uids,omitempty"` + ForceVPNCGroups []string `json:"force_vpn_cgroups,omitempty"` + ForceDirectSubnets []string `json:"force_direct_subnets,omitempty"` + ForceDirectUIDs []string `json:"force_direct_uids,omitempty"` + ForceDirectCGroups []string `json:"force_direct_cgroups,omitempty"` + OverridesApplied int `json:"overrides_applied"` + CgroupResolvedUIDs int `json:"cgroup_resolved_uids"` + CgroupWarning string `json:"cgroup_warning,omitempty"` + ActiveIface string `json:"active_iface,omitempty"` + IfaceReason string `json:"iface_reason,omitempty"` + RuleMark bool `json:"rule_mark"` + RuleFull bool `json:"rule_full"` + TableDefault bool `json:"table_default"` + ProbeOK bool `json:"probe_ok"` + ProbeMessage string `json:"probe_message,omitempty"` + Healthy bool `json:"healthy"` + Message string `json:"message,omitempty"` +} + +type TrafficCandidateSubnet struct { + CIDR string `json:"cidr"` + Dev string `json:"dev,omitempty"` + Kind string `json:"kind,omitempty"` // lan|docker|link + LinkDown bool `json:"linkdown,omitempty"` +} + +type TrafficCandidateUnit struct { + Unit string `json:"unit"` + Description string `json:"description,omitempty"` + Cgroup string `json:"cgroup,omitempty"` +} + +type TrafficCandidateUID struct { + UID int `json:"uid"` + User string `json:"user,omitempty"` + Examples []string `json:"examples,omitempty"` +} + +type TrafficCandidatesResponse struct { + GeneratedAt string `json:"generated_at"` + Subnets []TrafficCandidateSubnet `json:"subnets,omitempty"` + Units []TrafficCandidateUnit `json:"units,omitempty"` + UIDs []TrafficCandidateUID `json:"uids,omitempty"` +} +type TrafficInterfacesResponse struct { + Interfaces []string `json:"interfaces"` + PreferredIface string `json:"preferred_iface,omitempty"` + ActiveIface string `json:"active_iface,omitempty"` + IfaceReason string `json:"iface_reason,omitempty"` +} + +type SystemdState struct { + State string `json:"state"` +} + +// --------------------------------------------------------------------- +// события / SSE +// --------------------------------------------------------------------- + +type Event struct { + ID int64 `json:"id"` + Kind string `json:"kind"` + Ts string `json:"ts"` + Data interface{} `json:"data,omitempty"` +} + +// EN: Callback for streaming user-visible progress from long-running nft updates. +// RU: Колбэк для отправки прогресса длительных nft-операций в пользовательский интерфейс. +type ProgressCallback func(percent int, message string) + +// --------------------------------------------------------------------- +// resolver модели +// --------------------------------------------------------------------- + +// EN: Input contract for the Go-based domain resolver job. +// RU: Контракт входных параметров для Go-резолвера доменов. +type ResolverOpts struct { + DomainsPath string + MetaPath string + StaticPath string + CachePath string + PtrCachePath string + TraceLog string + TTL int + Workers int + DNSConfigPath string + + ViaSmartDNS bool + Mode DNSResolverMode + SmartDNSAddr string + SmartDNSWildcards []string +} + +// EN: Aggregated resolver outputs consumed by routes update pipeline. +// RU: Агрегированные результаты резолвера, используемые пайплайном обновления маршрутов. +type resolverResult struct { + IPs []string + IPMap [][2]string + DirectIPs []string + DirectIPMap [][2]string + WildcardIPs []string + WildcardIPMap [][2]string + DomainCache map[string]any + PtrCache map[string]any +} + +// EN: Runtime DNS upstream pools for standard and meta-special lookups. +// RU: Наборы DNS-апстримов для обычных и meta-special резолвов. +type dnsConfig struct { + Default []string + Meta []string + SmartDNS string + Mode DNSResolverMode +} diff --git a/selective-vpn-api/app/vpn_handlers.go b/selective-vpn-api/app/vpn_handlers.go new file mode 100644 index 0000000..ec23d3d --- /dev/null +++ b/selective-vpn-api/app/vpn_handlers.go @@ -0,0 +1,371 @@ +package app + +import ( + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "os" + "strings" + "time" +) + +// --------------------------------------------------------------------- +// VPN handlers / status / locations +// --------------------------------------------------------------------- + +// EN: VPN-facing HTTP handlers for login state, logout, service/unit control, +// EN: autoloop status, locations, and location switching. +// RU: VPN-ориентированные HTTP-обработчики для login state, logout, +// RU: управления unit/service, статуса autoloop, списка локаций и смены локации. + +func handleVPNLoginState(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + state := VPNLoginState{ + State: "no_login", + Msg: "login state file not found", + Text: "AdGuard VPN: (no login data)", + Color: "gray30", + } + + data, err := os.ReadFile(loginStatePath) + if err == nil { + var fileState VPNLoginState + if err := json.Unmarshal(data, &fileState); err == nil { + if fileState.State != "" { + state.State = fileState.State + } + if fileState.Email != "" { + state.Email = fileState.Email + } + if fileState.Msg != "" { + state.Msg = fileState.Msg + } + } else { + state.State = "error" + state.Msg = "invalid adguard-login.json: " + err.Error() + } + } else if !os.IsNotExist(err) { + state.State = "error" + state.Msg = err.Error() + } + + // text/color для GUI + switch state.State { + case "ok": + if state.Email != "" { + state.Text = fmt.Sprintf("AdGuard VPN: logged in as %s", state.Email) + } else { + state.Text = "AdGuard VPN: logged in" + } + state.Color = "green4" + case "no_login": + state.Text = "AdGuard VPN: (no login data)" + state.Color = "gray30" + default: + state.Text = "AdGuard VPN: " + state.State + state.Color = "orange3" + } + + writeJSON(w, http.StatusOK, state) +} + +// --------------------------------------------------------------------- +// logout +// --------------------------------------------------------------------- + +func handleVPNLogout(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + appendTraceLine("login", "logout") + stdout, stderr, exitCode, err := runCommand(adgvpnCLI, "logout") + res := cmdResult{ + OK: err == nil && exitCode == 0, + ExitCode: exitCode, + Stdout: stdout, + Stderr: stderr, + } + if err != nil { + res.Message = err.Error() + } else { + res.Message = "logout done" + } + + // refresh login state + _, _, _, _ = runCommand("systemctl", "restart", adgvpnUnit) + + writeJSON(w, http.StatusOK, res) +} + +// --------------------------------------------------------------------- +// systemd state +// --------------------------------------------------------------------- + +func handleSystemdState(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + unit := strings.TrimSpace(r.URL.Query().Get("unit")) + if unit == "" { + http.Error(w, "unit required", http.StatusBadRequest) + return + } + stdout, _, _, err := runCommand("systemctl", "is-active", unit) + st := strings.TrimSpace(stdout) + if err != nil || st == "" { + st = "unknown" + } + writeJSON(w, http.StatusOK, SystemdState{State: st}) +} + +// --------------------------------------------------------------------- +// AdGuard autoloop / status parse +// --------------------------------------------------------------------- + +// аккуратный разбор лога autoloop: игнорим "route:", смотрим status +func parseAutoloopStatus(lines []string) (word, raw string) { + for i := len(lines) - 1; i >= 0; i-- { + line := strings.TrimSpace(lines[i]) + if line == "" { + continue + } + if idx := strings.Index(line, "autoloop:"); idx >= 0 { + line = strings.TrimSpace(line[idx+len("autoloop:"):]) + } + lower := strings.ToLower(line) + + // route: default dev ... - нам неинтересно + if strings.HasPrefix(lower, "route: ") { + continue + } + + switch { + case strings.Contains(lower, "status: connected"), + strings.Contains(lower, "after connect: connected"): + return "CONNECTED", line + case strings.Contains(lower, "status: reconnecting"): + return "RECONNECTING", line + case strings.Contains(lower, "status: disconnected"), + strings.Contains(lower, "still disconnected"): + return "DISCONNECTED", line + case strings.Contains(lower, "timeout"), + strings.Contains(lower, "failed"): + return "ERROR", line + } + } + return "unknown", "" +} + +// --------------------------------------------------------------------- +// /api/v1/vpn/autoloop-status +// --------------------------------------------------------------------- + +func handleVPNAutoloopStatus(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + lines := tailFile(autoloopLogPath, 200) + word, raw := parseAutoloopStatus(lines) + writeJSON(w, http.StatusOK, map[string]any{ + "raw_text": raw, + "status_word": word, + }) +} + +// --------------------------------------------------------------------- +// /api/v1/vpn/status +// --------------------------------------------------------------------- + +func handleVPNStatus(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + // desired location + loc := "" + if data, err := os.ReadFile(desiredLocation); err == nil { + loc = strings.TrimSpace(string(data)) + } + + // unit state + stdout, _, _, err := runCommand("systemctl", "is-active", adgvpnUnit) + unitState := strings.TrimSpace(stdout) + if err != nil || unitState == "" { + unitState = "unknown" + } + + // автолуп + lines := tailFile(autoloopLogPath, 200) + word, raw := parseAutoloopStatus(lines) + + writeJSON(w, http.StatusOK, map[string]any{ + "desired_location": loc, + "status_word": word, + "raw_text": raw, + "unit_state": unitState, + }) +} + +// --------------------------------------------------------------------- +// /api/v1/vpn/autoconnect +// --------------------------------------------------------------------- + +func handleVPNAutoconnect(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + var body struct { + Action string `json:"action"` + } + if r.Body != nil { + defer r.Body.Close() + _ = json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&body) + } + action := strings.ToLower(strings.TrimSpace(body.Action)) + var cmd []string + switch action { + case "start": + cmd = []string{"systemctl", "start", adgvpnUnit} + case "stop": + cmd = []string{"systemctl", "stop", adgvpnUnit} + default: + http.Error(w, "unknown action", http.StatusBadRequest) + return + } + stdout, stderr, exitCode, err := runCommand(cmd[0], cmd[1:]...) + res := cmdResult{ + OK: err == nil && exitCode == 0, + ExitCode: exitCode, + Stdout: stdout, + Stderr: stderr, + } + if err != nil { + res.Message = err.Error() + } + writeJSON(w, http.StatusOK, res) +} + +// --------------------------------------------------------------------- +// /api/v1/vpn/locations +// --------------------------------------------------------------------- + +func handleVPNListLocations(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + // Жесткий таймаут на list-locations, чтобы не клинить HTTP + const locationsTimeout = 7 * time.Second + + start := time.Now() + stdout, _, exitCode, err := runCommandTimeout(locationsTimeout, adgvpnCLI, "list-locations") + log.Printf("list-locations took %s (exit=%d, err=%v)", time.Since(start), exitCode, err) + if err != nil || exitCode != 0 { + writeJSON(w, http.StatusOK, map[string]any{ + "locations": []any{}, + "error": fmt.Sprintf("list-locations failed: %v (exit=%d)", err, exitCode), + }) + return + } + + stdout = stripANSI(stdout) + + var locations []map[string]string + + for _, ln := range strings.Split(stdout, "\n") { + line := strings.TrimSpace(ln) + if line == "" { + continue + } + if strings.HasPrefix(line, "ISO ") { + continue + } + if strings.HasPrefix(line, "VPN ") || strings.HasPrefix(line, "You can connect") { + continue + } + + parts := strings.Fields(line) + if len(parts) < 4 { + continue + } + iso := parts[0] + ping := parts[len(parts)-1] + + if len(iso) != 2 { + continue + } + okPing := true + for _, ch := range ping { + if ch < '0' || ch > '9' { + okPing = false + break + } + } + if !okPing { + continue + } + + name := strings.Join(parts[1:len(parts)-1], " ") + label := fmt.Sprintf("%s %s (%s ms)", iso, name, ping) + + locations = append(locations, map[string]string{ + "label": label, + "iso": iso, + }) + } + + writeJSON(w, http.StatusOK, map[string]any{ + "locations": locations, + }) +} + +// --------------------------------------------------------------------- +// /api/v1/vpn/location +// --------------------------------------------------------------------- + +func handleVPNSetLocation(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + var body struct { + ISO string `json:"iso"` + } + if r.Body != nil { + defer r.Body.Close() + if err := json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&body); err != nil { + http.Error(w, "bad json", http.StatusBadRequest) + return + } + } + val := strings.TrimSpace(body.ISO) + if val == "" { + http.Error(w, "iso is required", http.StatusBadRequest) + return + } + _ = os.MkdirAll(stateDir, 0o755) + if err := os.WriteFile(desiredLocation, []byte(val+"\n"), 0o644); err != nil { + http.Error(w, "write error", http.StatusInternalServerError) + return + } + + // как старый GUI: сразу рестартуем автоконнект + _, _, _, _ = runCommand("systemctl", "restart", adgvpnUnit) + + writeJSON(w, http.StatusOK, map[string]any{ + "status": "ok", + "iso": val, + }) +} diff --git a/selective-vpn-api/app/vpn_login_session.go b/selective-vpn-api/app/vpn_login_session.go new file mode 100644 index 0000000..f65735c --- /dev/null +++ b/selective-vpn-api/app/vpn_login_session.go @@ -0,0 +1,539 @@ +package app + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "os/exec" + "regexp" + "strconv" + "strings" + "sync" + "time" + + "github.com/creack/pty" +) + +// --------------------------------------------------------------------- +// AdGuard VPN interactive login session (PTY) +// --------------------------------------------------------------------- + +// EN: Interactive AdGuard VPN login session over PTY. +// EN: This file contains session state machine, PTY reader/parser, and HTTP API +// EN: endpoints to start/poll/control/cancel login flow. +// RU: Интерактивная PTY-сессия логина AdGuard VPN. +// RU: Файл содержит state machine, PTY reader/parser и HTTP API для +// RU: старта/опроса/управления/остановки login-процесса. + +// --------------------------------------------------------------------- +// login session API models +// --------------------------------------------------------------------- + +type LoginSessionStartResp struct { + OK bool `json:"ok"` + Phase string `json:"phase"` + Level string `json:"level"` + PID int `json:"pid,omitempty"` + Email string `json:"email,omitempty"` + Error string `json:"error,omitempty"` +} + +type LoginSessionStateResp struct { + OK bool `json:"ok"` + Phase string `json:"phase"` + Level string `json:"level"` + Alive bool `json:"alive"` + + URL string `json:"url,omitempty"` + Email string `json:"email,omitempty"` + + Cursor int64 `json:"cursor"` + Lines []string `json:"lines"` + + CanOpen bool `json:"can_open"` + CanCheck bool `json:"can_check"` + CanCancel bool `json:"can_cancel"` + + Error string `json:"error,omitempty"` +} + +type LoginSessionActionReq struct { + Action string `json:"action"` +} + +type loginLine struct { + N int64 + Line string +} + +// --------------------------------------------------------------------- +// login session manager +// --------------------------------------------------------------------- + +type loginSessionManager struct { + mu sync.Mutex + + cmd *exec.Cmd + pty *os.File + + phase string + level string + alive bool + + url string + email string + + lines []loginLine + max int + lastN int64 + + lastAutoCheck time.Time + + reURL *regexp.Regexp + reEmail *regexp.Regexp + reNextCheck *regexp.Regexp +} + +var loginMgr = newLoginSessionManager(defaultTraceTailMax) + +// --------------------------------------------------------------------- +// EN: `newLoginSessionManager` creates a new instance for login session manager. +// RU: `newLoginSessionManager` - создает новый экземпляр для login session manager. +// --------------------------------------------------------------------- +func newLoginSessionManager(max int) *loginSessionManager { + return &loginSessionManager{ + phase: "idle", + level: "yellow", + alive: false, + max: max, + reURL: regexp.MustCompile(`(https?://\S+)`), + reEmail: regexp.MustCompile(`[A-Za-z0-9._%+\-]+@[A-Za-z0-9.\-]+`), + reNextCheck: regexp.MustCompile(`(?i)^Next check in \d+s$`), + } +} + +// --------------------------------------------------------------------- +// EN: `setPhaseLocked` sets phase locked to the requested value. +// RU: `setPhaseLocked` - устанавливает phase locked в требуемое значение. +// --------------------------------------------------------------------- +func (m *loginSessionManager) setPhaseLocked(phase, level string) { + m.phase = phase + m.level = level +} + +// --------------------------------------------------------------------- +// EN: `resetLocked` contains core logic for reset locked. +// RU: `resetLocked` - содержит основную логику для reset locked. +// --------------------------------------------------------------------- +func (m *loginSessionManager) resetLocked() { + m.lines = nil + m.lastN = 0 + m.url = "" + m.email = "" + m.lastAutoCheck = time.Time{} +} + +// --------------------------------------------------------------------- +// EN: `appendLineLocked` appends or adds line locked to an existing state. +// RU: `appendLineLocked` - добавляет line locked в текущее состояние. +// --------------------------------------------------------------------- +func (m *loginSessionManager) appendLineLocked(line string) { + m.lastN++ + m.lines = append(m.lines, loginLine{N: m.lastN, Line: line}) + if len(m.lines) > m.max { + m.lines = m.lines[len(m.lines)-m.max:] + } +} + +// --------------------------------------------------------------------- +// EN: `linesSinceLocked` contains core logic for lines since locked. +// RU: `linesSinceLocked` - содержит основную логику для lines since locked. +// --------------------------------------------------------------------- +func (m *loginSessionManager) linesSinceLocked(since int64) (out []string) { + for _, it := range m.lines { + if it.N > since { + out = append(out, it.Line) + } + } + return out +} + +// --------------------------------------------------------------------- +// EN: `sendKeyLocked` sends key locked to a downstream process. +// RU: `sendKeyLocked` - отправляет key locked в нижележащий процесс. +// --------------------------------------------------------------------- +func (m *loginSessionManager) sendKeyLocked(key string) error { + if !m.alive || m.pty == nil { + return fmt.Errorf("login session not alive") + } + _, err := m.pty.Write([]byte(key + "\n")) + return err +} + +// --------------------------------------------------------------------- +// EN: `stopLocked` stops locked and cleans up resources. +// RU: `stopLocked` - останавливает locked и освобождает ресурсы. +// --------------------------------------------------------------------- +func (m *loginSessionManager) stopLocked(hard bool) { + if m.cmd == nil { + m.setPhaseLocked("idle", "yellow") + m.alive = false + m.url = "" + return + } + + // мягкий cancel + _ = m.sendKeyLocked("x") + + deadline := time.Now().Add(1200 * time.Millisecond) + for time.Now().Before(deadline) { + if m.cmd == nil || m.cmd.Process == nil { + break + } + time.Sleep(80 * time.Millisecond) + } + + if hard && m.cmd != nil && m.cmd.Process != nil { + _ = m.cmd.Process.Signal(os.Interrupt) + time.Sleep(150 * time.Millisecond) + _ = m.cmd.Process.Kill() + } + + if m.pty != nil { + _ = m.pty.Close() + m.pty = nil + } + + m.cmd = nil + m.alive = false + m.setPhaseLocked("idle", "yellow") + m.url = "" +} + +// --------------------------------------------------------------------- +// EN: `setAlreadyLoggedLocked` sets already logged locked to the requested value. +// RU: `setAlreadyLoggedLocked` - устанавливает already logged locked в требуемое значение. +// --------------------------------------------------------------------- +func (m *loginSessionManager) setAlreadyLoggedLocked(email string) { + // без запуска процесса + m.stopLocked(true) + m.resetLocked() + m.email = email + m.alive = false + m.setPhaseLocked("already_logged", "green") + if email != "" { + m.appendLineLocked("Already logged in as " + email) + } else { + m.appendLineLocked("Already logged in") + } +} + +// --------------------------------------------------------------------- +// EN: `startPTY` starts pty and initializes required state. +// RU: `startPTY` - запускает pty и инициализирует нужное состояние. +// --------------------------------------------------------------------- +func (m *loginSessionManager) startPTY() (pid int, err error) { + // caller must hold lock + m.stopLocked(true) + m.resetLocked() + m.setPhaseLocked("starting", "yellow") + + cmd := exec.Command(adgvpnCLI, "login") + ptmx, err := pty.Start(cmd) + if err != nil { + m.setPhaseLocked("failed", "red") + return 0, err + } + + m.cmd = cmd + m.pty = ptmx + m.alive = true + + pid = 0 + if cmd.Process != nil { + pid = cmd.Process.Pid + } + + go m.readerLoop(cmd, ptmx) + + return pid, nil +} + +// --------------------------------------------------------------------- +// EN: `readerLoop` reads er loop from input data. +// RU: `readerLoop` - читает er loop из входных данных. +// --------------------------------------------------------------------- +func (m *loginSessionManager) readerLoop(cmd *exec.Cmd, ptmx *os.File) { + sc := bufio.NewScanner(ptmx) + buf := make([]byte, 0, 64*1024) + sc.Buffer(buf, 1024*1024) + + for sc.Scan() { + line := strings.TrimRight(sc.Text(), "\r\n") + line = strings.TrimSpace(line) + if line == "" { + continue + } + + m.mu.Lock() + low := strings.ToLower(line) + + // URL + if m.url == "" { + if mm := m.reURL.FindStringSubmatch(line); len(mm) > 1 { + m.url = mm[1] + m.setPhaseLocked("waiting_browser", "yellow") + } + } + + // already logged / current user + if strings.Contains(low, "already logged in") || strings.Contains(low, "current user is") { + if em := m.reEmail.FindStringSubmatch(line); len(em) > 0 { + m.email = em[0] + } + m.setPhaseLocked("already_logged", "green") + } + + // success / fail + if strings.Contains(low, "successfully logged in") { + m.setPhaseLocked("success", "green") + if em := m.reEmail.FindStringSubmatch(line); len(em) > 0 { + m.email = em[0] + } + } + if strings.Contains(low, "failed to log in") { + m.setPhaseLocked("failed", "red") + } + + // auto-check trigger + if m.reNextCheck.MatchString(line) { + m.setPhaseLocked("checking", "yellow") + now := time.Now() + if m.lastAutoCheck.IsZero() || now.Sub(m.lastAutoCheck) > 1200*time.Millisecond { + _ = m.sendKeyLocked("s") + m.lastAutoCheck = now + } + m.appendLineLocked(line) + m.mu.Unlock() + continue + } + + m.appendLineLocked(line) + m.mu.Unlock() + } + + _ = ptmx.Close() + err := cmd.Wait() + + m.mu.Lock() + defer m.mu.Unlock() + + m.alive = false + + switch m.phase { + case "success", "failed", "cancelled", "already_logged": + // keep + default: + if err != nil { + m.setPhaseLocked("failed", "red") + } else { + m.setPhaseLocked("exited", "yellow") + } + } + + m.cmd = nil + m.pty = nil +} + +// --------------------------------------------------------------------- +// login state helper +// --------------------------------------------------------------------- + +func loginStateAlreadyLogged() (bool, string) { + data, err := os.ReadFile(loginStatePath) + if err != nil { + return false, "" + } + var st VPNLoginState + if err := json.Unmarshal(data, &st); err != nil { + return false, "" + } + if strings.TrimSpace(st.State) == "ok" { + return true, strings.TrimSpace(st.Email) + } + return false, "" +} + +// --------------------------------------------------------------------- +// login session API +// --------------------------------------------------------------------- + +func handleVPNLoginSessionStart(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + // если уже залогинен (по adguard-login.json) — сразу возвращаем green + if ok, email := loginStateAlreadyLogged(); ok { + appendTraceLine("login", fmt.Sprintf("session/start: already_logged email=%s", email)) + loginMgr.mu.Lock() + loginMgr.setAlreadyLoggedLocked(email) + loginMgr.mu.Unlock() + writeJSON(w, http.StatusOK, LoginSessionStartResp{ + OK: true, + Phase: "already_logged", + Level: "green", + Email: email, + }) + return + } + + loginMgr.mu.Lock() + pid, err := loginMgr.startPTY() + phase := loginMgr.phase + level := loginMgr.level + loginMgr.mu.Unlock() + if err == nil { + appendTraceLine("login", fmt.Sprintf("session/start: pid=%d", pid)) + } else { + appendTraceLine("login", fmt.Sprintf("session/start: failed: %v", err)) + } + + if err != nil { + writeJSON(w, http.StatusOK, LoginSessionStartResp{ + OK: false, + Phase: "failed", + Level: "red", + Error: err.Error(), + }) + return + } + + writeJSON(w, http.StatusOK, LoginSessionStartResp{ + OK: true, + Phase: phase, + Level: level, + PID: pid, + }) +} + +// GET /api/v1/vpn/login/session/state +// --------------------------------------------------------------------- +// EN: `handleVPNLoginSessionState` is an HTTP handler for vpn login session state. +// RU: `handleVPNLoginSessionState` - HTTP-обработчик для vpn login session state. +// --------------------------------------------------------------------- +func handleVPNLoginSessionState(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + sinceStr := strings.TrimSpace(r.URL.Query().Get("since")) + var since int64 + if sinceStr != "" { + if v, err := strconv.ParseInt(sinceStr, 10, 64); err == nil && v >= 0 { + since = v + } + } + + loginMgr.mu.Lock() + lines := loginMgr.linesSinceLocked(since) + phase := loginMgr.phase + level := loginMgr.level + alive := loginMgr.alive + url := loginMgr.url + email := loginMgr.email + cursor := loginMgr.lastN + loginMgr.mu.Unlock() + + can := alive && phase != "success" && phase != "already_logged" && phase != "failed" && phase != "cancelled" + writeJSON(w, http.StatusOK, LoginSessionStateResp{ + OK: true, + Phase: phase, + Level: level, + Alive: alive, + URL: url, + Email: email, + Cursor: cursor, + Lines: lines, + CanOpen: can, + CanCheck: can, + CanCancel: can, + }) +} + +// POST /api/v1/vpn/login/session/action +// --------------------------------------------------------------------- +// EN: `handleVPNLoginSessionAction` is an HTTP handler for vpn login session action. +// RU: `handleVPNLoginSessionAction` - HTTP-обработчик для vpn login session action. +// --------------------------------------------------------------------- +func handleVPNLoginSessionAction(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + var body LoginSessionActionReq + if r.Body != nil { + defer r.Body.Close() + _ = json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&body) + } + action := strings.ToLower(strings.TrimSpace(body.Action)) + if action == "" { + http.Error(w, "action required", http.StatusBadRequest) + return + } + + loginMgr.mu.Lock() + defer loginMgr.mu.Unlock() + + if !loginMgr.alive { + writeJSON(w, http.StatusOK, map[string]any{"ok": false, "error": "login session not alive"}) + return + } + + switch action { + case "open": + appendTraceLine("login", "session/action: open") + _ = loginMgr.sendKeyLocked("b") + loginMgr.setPhaseLocked("waiting_browser", "yellow") + case "check": + appendTraceLine("login", "session/action: check") + _ = loginMgr.sendKeyLocked("s") + loginMgr.setPhaseLocked("checking", "yellow") + case "cancel": + appendTraceLine("login", "session/action: cancel") + _ = loginMgr.sendKeyLocked("x") + loginMgr.setPhaseLocked("cancelled", "red") + default: + http.Error(w, "unknown action (open|check|cancel)", http.StatusBadRequest) + return + } + + writeJSON(w, http.StatusOK, map[string]any{ + "ok": true, + "phase": loginMgr.phase, + "level": loginMgr.level, + }) +} + +// POST /api/v1/vpn/login/session/stop +// --------------------------------------------------------------------- +// EN: `handleVPNLoginSessionStop` is an HTTP handler for vpn login session stop. +// RU: `handleVPNLoginSessionStop` - HTTP-обработчик для vpn login session stop. +// --------------------------------------------------------------------- +func handleVPNLoginSessionStop(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + loginMgr.mu.Lock() + appendTraceLine("login", "session/stop") + loginMgr.stopLocked(true) + loginMgr.mu.Unlock() + writeJSON(w, http.StatusOK, map[string]any{"ok": true}) +} diff --git a/selective-vpn-api/app/watchers.go b/selective-vpn-api/app/watchers.go new file mode 100644 index 0000000..3144f92 --- /dev/null +++ b/selective-vpn-api/app/watchers.go @@ -0,0 +1,229 @@ +package app + +import ( + "context" + "crypto/sha256" + "encoding/json" + "os" + "strings" + "time" +) + +// --------------------------------------------------------------------- +// фоновые вотчеры / события +// --------------------------------------------------------------------- + +// EN: Background poll-based watchers that detect file/service state changes and +// EN: publish normalized events into the in-memory event bus for SSE clients. +// RU: Фоновые poll-вотчеры, отслеживающие изменения файлов/сервисов и +// RU: публикующие нормализованные события в in-memory event bus для SSE-клиентов. + +func startWatchers(ctx context.Context) { + statusEvery := time.Duration(envInt("SVPN_POLL_STATUS_MS", defaultPollStatusMs)) * time.Millisecond + loginEvery := time.Duration(envInt("SVPN_POLL_LOGIN_MS", defaultPollLoginMs)) * time.Millisecond + autoEvery := time.Duration(envInt("SVPN_POLL_AUTOLOOP_MS", defaultPollAutoloopMs)) * time.Millisecond + systemdEvery := time.Duration(envInt("SVPN_POLL_SYSTEMD_MS", defaultPollSystemdMs)) * time.Millisecond + traceEvery := time.Duration(envInt("SVPN_POLL_TRACE_MS", defaultPollTraceMs)) * time.Millisecond + + go watchStatusFile(ctx, statusEvery) + go watchLoginFile(ctx, loginEvery) + go watchAutoloop(ctx, autoEvery) + go watchFileChange(ctx, traceLogPath, "trace_changed", "full", traceEvery) + go watchFileChange(ctx, smartdnsLogPath, "trace_changed", "smartdns", traceEvery) + + go watchSystemdUnitDynamic(ctx, routesServiceUnitName, "routes_service", systemdEvery) + go watchSystemdUnitDynamic(ctx, routesTimerUnitName, "routes_timer", systemdEvery) + go watchSystemdUnit(ctx, adgvpnUnit, "vpn_unit", systemdEvery) + go watchSystemdUnit(ctx, "smartdns-local.service", "smartdns_unit", systemdEvery) +} + +// --------------------------------------------------------------------- +// status file watcher +// --------------------------------------------------------------------- + +func watchStatusFile(ctx context.Context, every time.Duration) { + var last [32]byte + have := false + for { + select { + case <-ctx.Done(): + return + case <-time.After(every): + } + + data, err := os.ReadFile(statusFilePath) + if err != nil { + continue + } + h := sha256.Sum256(data) + if have && h == last { + continue + } + last = h + have = true + + var st Status + if err := json.Unmarshal(data, &st); err != nil { + events.push("status_error", map[string]any{"error": err.Error()}) + continue + } + events.push("status_changed", st) + } +} + +// --------------------------------------------------------------------- +// login file watcher +// --------------------------------------------------------------------- + +func watchLoginFile(ctx context.Context, every time.Duration) { + var last [32]byte + have := false + for { + select { + case <-ctx.Done(): + return + case <-time.After(every): + } + + data, err := os.ReadFile(loginStatePath) + if err != nil { + continue + } + h := sha256.Sum256(data) + if have && h == last { + continue + } + last = h + have = true + + var st VPNLoginState + if err := json.Unmarshal(data, &st); err != nil { + events.push("login_state_error", map[string]any{"error": err.Error()}) + continue + } + events.push("login_state_changed", st) + } +} + +// --------------------------------------------------------------------- +// autoloop watcher +// --------------------------------------------------------------------- + +func watchAutoloop(ctx context.Context, every time.Duration) { + lastWord := "" + lastRaw := "" + for { + select { + case <-ctx.Done(): + return + case <-time.After(every): + } + + lines := tailFile(autoloopLogPath, 200) + word, raw := parseAutoloopStatus(lines) + if word == "" && raw == "" { + continue + } + if word == lastWord && raw == lastRaw { + continue + } + lastWord, lastRaw = word, raw + events.push("autoloop_status_changed", map[string]string{ + "status_word": word, + "raw_text": raw, + }) + } +} + +// --------------------------------------------------------------------- +// systemd unit watcher +// --------------------------------------------------------------------- + +func watchSystemdUnit(ctx context.Context, unit string, kind string, every time.Duration) { + last := "" + for { + select { + case <-ctx.Done(): + return + case <-time.After(every): + } + + stdout, _, _, err := runCommand("systemctl", "is-active", unit) + state := strings.TrimSpace(stdout) + if err != nil || state == "" { + state = "unknown" + } + if state == last { + continue + } + last = state + events.push("unit_state_changed", map[string]string{ + "unit": unit, + "kind": kind, + "state": state, + }) + } +} + +func watchSystemdUnitDynamic(ctx context.Context, resolveUnit func() string, kind string, every time.Duration) { + lastUnit := "" + lastState := "" + for { + select { + case <-ctx.Done(): + return + case <-time.After(every): + } + + unit := strings.TrimSpace(resolveUnit()) + state := "unknown" + if unit != "" { + stdout, _, _, err := runCommand("systemctl", "is-active", unit) + s := strings.TrimSpace(stdout) + if err == nil && s != "" { + state = s + } + } + if unit == lastUnit && state == lastState { + continue + } + lastUnit, lastState = unit, state + events.push("unit_state_changed", map[string]string{ + "unit": unit, + "kind": kind, + "state": state, + }) + } +} + +// --------------------------------------------------------------------- +// generic file watcher +// --------------------------------------------------------------------- + +func watchFileChange(ctx context.Context, path string, kind string, mode string, every time.Duration) { + var lastMod time.Time + var lastSize int64 = -1 + for { + select { + case <-ctx.Done(): + return + case <-time.After(every): + } + + info, err := os.Stat(path) + if err != nil { + continue + } + if info.ModTime() == lastMod && info.Size() == lastSize { + continue + } + lastMod = info.ModTime() + lastSize = info.Size() + events.push(kind, map[string]any{ + "path": path, + "mode": mode, + "size": info.Size(), + "mtime": info.ModTime().UTC().Format(time.RFC3339Nano), + }) + } +} diff --git a/selective-vpn-api/go.mod b/selective-vpn-api/go.mod new file mode 100644 index 0000000..a2f19a8 --- /dev/null +++ b/selective-vpn-api/go.mod @@ -0,0 +1,8 @@ +module selective-vpn-api + +go 1.24.2 + +require ( + github.com/cenkalti/backoff/v4 v4.3.0 + github.com/creack/pty v1.1.24 +) diff --git a/selective-vpn-api/go.sum b/selective-vpn-api/go.sum new file mode 100644 index 0000000..fefa210 --- /dev/null +++ b/selective-vpn-api/go.sum @@ -0,0 +1,4 @@ +github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= +github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= +github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= +github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= diff --git a/selective-vpn-api/internal/assets/domains/bases.txt b/selective-vpn-api/internal/assets/domains/bases.txt new file mode 100644 index 0000000..504d1b3 --- /dev/null +++ b/selective-vpn-api/internal/assets/domains/bases.txt @@ -0,0 +1,2 @@ +### +# Default bases list (seed). Add domains here; one per line. diff --git a/selective-vpn-api/internal/assets/domains/meta-special.txt b/selective-vpn-api/internal/assets/domains/meta-special.txt new file mode 100644 index 0000000..06f5bc9 --- /dev/null +++ b/selective-vpn-api/internal/assets/domains/meta-special.txt @@ -0,0 +1 @@ +# meta domains (seed) diff --git a/selective-vpn-api/internal/assets/domains/static-ips.txt b/selective-vpn-api/internal/assets/domains/static-ips.txt new file mode 100644 index 0000000..d1b8402 --- /dev/null +++ b/selective-vpn-api/internal/assets/domains/static-ips.txt @@ -0,0 +1 @@ +# static IPs (seed) diff --git a/selective-vpn-api/internal/assets/domains/subs.txt b/selective-vpn-api/internal/assets/domains/subs.txt new file mode 100644 index 0000000..ec4de7e --- /dev/null +++ b/selective-vpn-api/internal/assets/domains/subs.txt @@ -0,0 +1,3 @@ +www +api +static diff --git a/selective-vpn-api/main.go b/selective-vpn-api/main.go new file mode 100644 index 0000000..ec226cd --- /dev/null +++ b/selective-vpn-api/main.go @@ -0,0 +1,9 @@ +package main + +import app "selective-vpn-api/app" + +// EN: Thin executable entrypoint that delegates runtime startup to the app package. +// RU: Тонкая точка входа бинаря, делегирующая запуск пакету app. +func main() { + app.Run() +} diff --git a/selective-vpn-api/инструкция.txt b/selective-vpn-api/инструкция.txt new file mode 100644 index 0000000..dea5985 --- /dev/null +++ b/selective-vpn-api/инструкция.txt @@ -0,0 +1,460 @@ +Окей, давай сначала аккуратно разрулим логику, а уже потом можно будет добивать конкретный код (Go + Qt). + +Ты сейчас хочешь: + +Нормальную модель «DNS/SmartDNS» (кто за что отвечает). + +Новую вкладку DNS в дашборде, без текстовой панели, но с понятной индикацией: + +запущен ли smartdns-daemon (systemd unit), + +использует ли именно приложение SmartDNS или ходит напрямую на default/meta. + +Логи пока оставим на второй шаг — главное, чтобы сама схема была ровная. + +1. Предлагаемая модель DNS / SmartDNS +1.1. Состояния + +У нас по сути три независимых штуки: + +Upstreams для резолвера selective-vpn + +default1, default2, meta1, meta2 + +лежат в dnsUpstreamsConf (как и сейчас). + +SmartDNS systemd unit + +юнит типа smartdns-local.service + +может быть active/inactive/failed. + +Режим работы резолвера selective-vpn + +via_smartdns = false → резолвер ходит напрямую на default/meta. + +via_smartdns = true → резолвер ходит на 127.0.0.1#6053 (SmartDNS), а тот уже сам цепочкой: локальные → bind9 → adguard-dns. + +Важно: юнит SmartDNS и режим резолвера — разные вещи. +Может быть включён юнит, но приложение всё равно ходит напрямую на adguard-dns. + +1.2. Где хранить режим + +Вариант попроще (без изменения 100500 мест): + +Заводим отдельный json в stateDir, например +dns-mode.json: + +{ + "via_smartdns": true +} + + +Go-структура: + +// types.go +type DNSMode struct { + ViaSmartDNS bool `json:"via_smartdns"` +} + + +Помощники: + +// config.go или resolver.go +const dnsModePath = stateDir + "/dns-mode.json" + +func loadDNSMode() DNSMode { + data, err := os.ReadFile(dnsModePath) + if err != nil { + return DNSMode{ViaSmartDNS: false} + } + var m DNSMode + if err := json.Unmarshal(data, &m); err != nil { + return DNSMode{ViaSmartDNS: false} + } + return m +} + +func saveDNSMode(m DNSMode) { + b, err := json.MarshalIndent(m, "", " ") + if err != nil { + return + } + tmp := dnsModePath + ".tmp" + _ = os.WriteFile(tmp, b, 0o644) + _ = os.Rename(tmp, dnsModePath) +} + +1.3. Как это зашить в резолвер + +В ResolverOpts (в resolver.go) уже есть поле DNSConfigPath. +Добавляем туда флаг: + +type ResolverOpts struct { + DomainsPath string + MetaPath string + StaticPath string + CachePath string + PtrCachePath string + TraceLog string + TTL int + Workers int + DNSConfigPath string + + ViaSmartDNS bool +} + + +Там, где сейчас создаёшь ResolverOpts (в routes_update.go / autoloop.go), просто подставляешь: + +mode := loadDNSMode() +opts := ResolverOpts{ + // ... + DNSConfigPath: dnsUpstreamsConf, + ViaSmartDNS: mode.ViaSmartDNS, +} + + +И дальше в runResolverJob: + +func runResolverJob(opts ResolverOpts, logf func(string, ...any)) (resolverResult, error) { + // ... + + cfg := loadDNSConfig(opts.DNSConfigPath, logf) + + // если включён режим SmartDNS – игнорируем default/meta из файла + if opts.ViaSmartDNS { + cfg.Default = []string{"127.0.0.1#6053"} + cfg.Meta = []string{"127.0.0.1#6053"} + if logf != nil { + logf("dns-mode: via smartdns 127.0.0.1#6053") + } + } + + // дальше как было + // ... +} + + +Так мы гарантируем, что при включённом via_smartdns твой Go-резолвер больше не будет ходить на 94.140.* и т.п. — всё через локальный SmartDNS. + +1.4. API для режима и юнита + +Чтобы вкладка DNS могла всё это показать/пощёлкать, делаем два эндпойнта: + +Статус + +// types.go +type DNSStatusResponse struct { + ViaSmartDNS bool `json:"via_smartdns"` + UnitState string `json:"unit_state"` // "active", "inactive", "failed", "unknown" +} + +// vpn_handlers.go или dns_handlers.go +func handleDNSStatus(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + mode := loadDNSMode() + + unitState := "unknown" + if out, err := ShellRunCapture("systemctl", "is-active", "smartdns-local.service"); err == nil { + unitState = strings.TrimSpace(out) + } + + writeJSON(w, http.StatusOK, DNSStatusResponse{ + ViaSmartDNS: mode.ViaSmartDNS, + UnitState: unitState, + }) +} + + +Смена режима (использовать SmartDNS или нет) + +type DNSModeRequest struct { + ViaSmartDNS bool `json:"via_smartdns"` +} + +func handleDNSModeSet(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + var req DNSModeRequest + if r.Body != nil { + defer r.Body.Close() + _ = json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&req) + } + + mode := loadDNSMode() + mode.ViaSmartDNS = req.ViaSmartDNS + saveDNSMode(mode) + + writeJSON(w, http.StatusOK, map[string]any{ + "ok": true, + "via_smartdns": mode.ViaSmartDNS, + }) +} + + +Управление юнитом SmartDNS (у тебя уже что-то такое есть — можно просто нормализовать): + +type SmartdnsServiceRequest struct { + Action string `json:"action"` // "start" / "stop" / "restart" +} + +type SmartdnsServiceResponse struct { + UnitState string `json:"unit_state"` +} + +func handleSmartdnsService(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + var req SmartdnsServiceRequest + if r.Body != nil { + defer r.Body.Close() + _ = json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&req) + } + + act := strings.ToLower(strings.TrimSpace(req.Action)) + if act == "" { + act = "restart" + } + + if act == "start" || act == "restart" || act == "stop" { + _ = ShellRun("systemctl", act, "smartdns-local.service") + } + + state := "unknown" + if out, err := ShellRunCapture("systemctl", "is-active", "smartdns-local.service"); err == nil { + state = strings.TrimSpace(out) + } + + writeJSON(w, http.StatusOK, SmartdnsServiceResponse{ + UnitState: state, + }) +} + + +Роуты: + +mux.HandleFunc("/api/v1/dns/status", handleDNSStatus) +mux.HandleFunc("/api/v1/dns/mode", handleDNSModeSet) +mux.HandleFunc("/api/v1/dns/smartdns-service", handleSmartdnsService) + +2. Новая вкладка DNS в Qt + +Теперь UI можно сделать максимально простым: + +2.1. Как будет выглядеть + +Группа 1 — Upstreams + +4 поля: default1, default2, meta1, meta2 + +Кнопка «Save upstreams» + +Группа 2 — SmartDNS + +Чекбокс: Use SmartDNS (127.0.0.1:6053) → это via_smartdns + +Лейбл: SmartDNS unit: active/inactive/failed + +Лейбл: Resolver mode: via SmartDNS / direct upstreams + +(опционально) две кнопки — Start / Stop юнит + +Нижняя текстовая панель полностью убирается — всё, что касается логов, уже есть на вкладке Trace (режим mode=smartdns и mode=full). + +2.2. Код виджета (переписанный _build_tab_dns) + +В vpn_dashboard_qt.py вместо старой _build_tab_dns можно вот так (без учёта импортов — у тебя уже есть): + +def _build_tab_dns(self): + tab = QWidget() + main_layout = QVBoxLayout(tab) + + # ---------------- Upstreams ---------------- + grp_up = QGroupBox("Upstreams") + up_layout = QFormLayout(grp_up) + + self.dns_default1 = QLineEdit() + self.dns_default2 = QLineEdit() + self.dns_meta1 = QLineEdit() + self.dns_meta2 = QLineEdit() + + up_layout.addRow("default1", self.dns_default1) + up_layout.addRow("default2", self.dns_default2) + up_layout.addRow("meta1", self.dns_meta1) + up_layout.addRow("meta2", self.dns_meta2) + + btn_save = QPushButton("Save upstreams") + btn_save.clicked.connect(self.on_save_upstreams_clicked) + up_layout.addRow(btn_save) + + # ---------------- SmartDNS ---------------- + grp_smartdns = QGroupBox("SmartDNS") + sd_layout = QVBoxLayout(grp_smartdns) + + self.chk_dns_via_smartdns = QCheckBox("Use SmartDNS (127.0.0.1:6053)") + self.chk_dns_via_smartdns.stateChanged.connect(self.on_dns_mode_changed) + + self.lbl_smartdns_unit = QLabel("SmartDNS unit: unknown") + self.lbl_dns_mode = QLabel("Resolver mode: unknown") + + btn_row = QHBoxLayout() + self.btn_smartdns_start = QPushButton("Start unit") + self.btn_smartdns_stop = QPushButton("Stop unit") + self.btn_smartdns_start.clicked.connect( + lambda: self.on_smartdns_unit_action("start") + ) + self.btn_smartdns_stop.clicked.connect( + lambda: self.on_smartdns_unit_action("stop") + ) + btn_row.addWidget(self.btn_smartdns_start) + btn_row.addWidget(self.btn_smartdns_stop) + btn_row.addStretch(1) + + sd_layout.addWidget(self.chk_dns_via_smartdns) + sd_layout.addWidget(self.lbl_smartdns_unit) + sd_layout.addWidget(self.lbl_dns_mode) + sd_layout.addLayout(btn_row) + + # ---------------- Compose ---------------- + main_layout.addWidget(grp_up) + main_layout.addWidget(grp_smartdns) + main_layout.addStretch(1) + + self.tab_dns = tab + self.tabs.addTab(tab, "DNS") + +2.3. Обновление вкладки (refresh_dns_tab) +def refresh_dns_tab(self): + # 1) upstreams + ups = self.c.dns_upstreams_view() # как и было + self.dns_default1.setText(ups.default1 or "") + self.dns_default2.setText(ups.default2 or "") + self.dns_meta1.setText(ups.meta1 or "") + self.dns_meta2.setText(ups.meta2 or "") + + # 2) статус DNS / SmartDNS + st = self.c.dns_status_view() # новый метод в контроллере + + # режим + self.chk_dns_via_smartdns.blockSignals(True) + self.chk_dns_via_smartdns.setChecked(bool(st.via_smartdns)) + self.chk_dns_via_smartdns.blockSignals(False) + + mode_txt = "via SmartDNS" if st.via_smartdns else "direct upstreams" + self.lbl_dns_mode.setText(f"Resolver mode: {mode_txt}") + + # юнит + self.lbl_smartdns_unit.setText(f"SmartDNS unit: {st.unit_state or 'unknown'}") + + # немного UX: если юнит inactive, кнопка Start активна, Stop — серый + is_active = (st.unit_state == "active") + self.btn_smartdns_start.setEnabled(not is_active) + self.btn_smartdns_stop.setEnabled(is_active) + +2.4. Обработчики +def on_save_upstreams_clicked(self): + ups = self.c.dns_upstreams_view() + ups.default1 = self.dns_default1.text().strip() + ups.default2 = self.dns_default2.text().strip() + ups.meta1 = self.dns_meta1.text().strip() + ups.meta2 = self.dns_meta2.text().strip() + + ok, err = self.c.dns_upstreams_save(ups) + if not ok: + QMessageBox.critical(self, "Error", f"Failed to save upstreams:\n{err}") + else: + self.show_status("DNS upstreams saved") + + +def on_dns_mode_changed(self, state: int): + via = (state == Qt.Checked) + ok, st, err = self.c.dns_mode_set(via) + if not ok: + QMessageBox.critical(self, "Error", f"Failed to change DNS mode:\n{err}") + # откатываем чекбокс + self.refresh_dns_tab() + return + # обновляем лейблы + self.refresh_dns_tab() + + +def on_smartdns_unit_action(self, action: str): + ok, st, err = self.c.smartdns_service_action(action) + if not ok: + QMessageBox.critical(self, "Error", f"SmartDNS {action} failed:\n{err}") + return + self.refresh_dns_tab() + self.refresh_status_tab() # если хочешь синхронизировать индикатор на Status + +2.5. Методы в dashboard_controller.py + +Нужно добавить: + +@dataclass +class DNSStatus: + via_smartdns: bool + unit_state: str + +class DashboardController: + # ... + + def dns_status_view(self) -> DNSStatus: + data = self.client.get_json("/api/v1/dns/status") + return DNSStatus( + via_smartdns=bool(data.get("via_smartdns")), + unit_state=data.get("unit_state") or "unknown", + ) + + def dns_mode_set(self, via: bool) -> tuple[bool, DNSStatus | None, str | None]: + try: + data = self.client.post_json( + "/api/v1/dns/mode", + {"via_smartdns": bool(via)}, + ) + st = self.dns_status_view() + return True, st, None + except Exception as e: + return False, None, str(e) + + def smartdns_service_action(self, action: str) -> tuple[bool, str | None, str | None]: + try: + data = self.client.post_json( + "/api/v1/dns/smartdns-service", + {"action": action}, + ) + return True, data.get("unit_state"), None + except Exception as e: + return False, None, str(e) + + +Upstreams-методы у тебя уже есть, их трогать не нужно. + +3. Что дальше с логами + +Когда эта схема заведётся (юнит живёт своей жизнью, резолвер знает via_smartdns и ходит только на 127.0.0.1), дальше можно: + +В digA логировать отдельно: + +dns smartdns когда ViaSmartDNS == true, + +dns direct для прямых upstream. + +В handleTraceJSON(mode="gui") продолжать выкидывать «шум» вроде больших дампов резолвера. + +Для mode=smartdns показывать только строки, где есть smartdns в префиксе — и там уже будет видно цепочку. + +Если хочешь, в следующем шаге можем прям точечно добить: + +исправление поля ViaSmartDNS bool в types.go (я видел, что у тебя там раздралось на ViaSmartDN S), + +аккуратно переписать digA-лог, чтобы не было этих сообщений вида dns error ... via 94.140.14.14: lookup ... on 192.168.50.10:53, когда на самом деле всё ок. diff --git a/selective-vpn-api/инструкция2.txt b/selective-vpn-api/инструкция2.txt new file mode 100644 index 0000000..7f84c31 --- /dev/null +++ b/selective-vpn-api/инструкция2.txt @@ -0,0 +1,565 @@ +Ок, давай разберём по двум пунктам: + +Финальные имена и что именно править в коде + +Почему стало ~590 IP вместо привычных ~1700 + +1. Имена / где что менять + +Чтобы не было зоопарка, фиксируем одну схему: + +внутренний конфиг резолвера: + +// types.go +type DNSMode string + +const ( + DNSModeDirect DNSMode = "direct" // ходим напрямую на default/meta + DNSModeSmartOnly DNSMode = "smartdns" // всё через SmartDNS +) + +type dnsConfig struct { + Default []string `json:"default"` + Meta []string `json:"meta"` + SmartDNS string `json:"smartdns"` + Mode DNSMode `json:"mode"` +} + + +Важно: одно имя — dnsConfig, одно поле режима — Mode. +Никаких ViaSmartDNS, ViaSmartDnS и т.п. – только Mode. + +1.1. loadDNSConfig (используется резолвером) + +Внутри dns_settings.go (или там, где у тебя сейчас чтение /etc/selective-vpn/dns-upstreams.conf) логика должна отдавать именно такой dnsConfig: + +func loadDNSConfig(path string, logf func(string, ...any)) dnsConfig { + cfg := dnsConfig{ + Default: []string{defaultDNS1, defaultDNS2}, + Meta: []string{defaultMeta1, defaultMeta2}, + SmartDNS: "", + Mode: DNSModeDirect, + } + + data, err := os.ReadFile(path) + if err != nil { + // файла нет – просто дефолты + return cfg + } + + var def, meta []string + mode := DNSModeDirect + smart := "" + + lines := strings.Split(string(data), "\n") + for _, ln := range lines { + s := strings.TrimSpace(ln) + if s == "" || strings.HasPrefix(s, "#") { + continue + } + parts := strings.Fields(s) + if len(parts) < 2 { + continue + } + key := strings.ToLower(parts[0]) + vals := parts[1:] + + switch key { + case "default": + def = append(def, vals...) + case "meta": + meta = append(meta, vals...) + case "smartdns": + if len(vals) > 0 { + smart = vals[0] + } + case "mode": + switch DNSMode(vals[0]) { + case DNSModeSmartOnly: + mode = DNSModeSmartOnly + default: + mode = DNSModeDirect + } + } + } + + if len(def) > 0 { + cfg.Default = def + } + if len(meta) > 0 { + cfg.Meta = meta + } + if smart != "" { + cfg.SmartDNS = smart + } + cfg.Mode = mode + + if logf != nil { + if cfg.Mode == DNSModeSmartOnly && cfg.SmartDNS != "" { + logf("dns-config: mode=smartdns smartdns=%s", cfg.SmartDNS) + } else { + logf("dns-config: mode=direct default=%v meta=%v", cfg.Default, cfg.Meta) + } + } + + return cfg +} + +1.2. runResolverJob — как он выбирает режим + +Твой кусок: + +cfg := loadDNSConfig(opts.DNSConfigPath, logf) + if logf != nil { + if cfg.ViaSmartDNS { + logf("resolver dns mode: SmartDNS-only (%v)", cfg.Default) + } else { + logf("resolver dns mode: direct default=%v meta=%v", cfg.Default, cfg.Meta) + } + } + + +Нужно заменить на: + +cfg := loadDNSConfig(opts.DNSConfigPath, logf) + +if logf != nil { + if cfg.Mode == DNSModeSmartOnly && cfg.SmartDNS != "" { + logf("resolver dns mode: SmartDNS-only (%s)", cfg.SmartDNS) + } else { + logf("resolver dns mode: direct default=%v meta=%v", cfg.Default, cfg.Meta) + } +} + +1.3. Как резолвер реально ходит в DNS + +В resolveHostGo сейчас у тебя примерно так (я по смыслу): + +func resolveHostGo(host string, cfg dnsConfig, metaSpecial []string, logf func(string, ...any)) ([]string, int) { + useMeta := false + for _, m := range metaSpecial { + if host == m { + useMeta = true + break + } + } + + dnsList := cfg.Default + if useMeta { + dnsList = cfg.Meta + } + + ips, errs := digA(host, dnsList, 3*time.Second, logf) + ... +} + + +Делаем так, чтобы при включённом SmartDNS всегда шли только на него: + +func resolveHostGo(host string, cfg dnsConfig, metaSpecial []string, logf func(string, ...any)) ([]string, int) { + useMeta := false + for _, m := range metaSpecial { + if host == m { + useMeta = true + break + } + } + + var dnsList []string + + if cfg.Mode == DNSModeSmartOnly && cfg.SmartDNS != "" { + // ВСЁ через SmartDNS + dnsList = []string{cfg.SmartDNS} + } else if useMeta { + dnsList = cfg.Meta + } else { + dnsList = cfg.Default + } + + ips, errs := digA(host, dnsList, 3*time.Second, logf) + ... +} + + +И для PTR-запросов (resolveStaticLabels → digPTR) — аналогично: + +func resolveStaticLabels(entries [][3]string, cfg dnsConfig, ptrCache map[string]any, ttl int, logf func(string, ...any)) (map[string][]string, int, int) { + ... + dnsForPtr := "" + + if cfg.Mode == DNSModeSmartOnly && cfg.SmartDNS != "" { + dnsForPtr = cfg.SmartDNS + } else if len(cfg.Default) > 0 { + dnsForPtr = cfg.Default[0] + } else { + dnsForPtr = defaultDNS1 + } + ... +} + +1.4. routes_update.go — что именно должно быть + +Там, где запускается резолвер (что-то вроде runResolverJob), должно быть только это (без своих конфигов): + +opts := ResolverOpts{ + DomainsPath: domainDir + "/bases.txt", + MetaPath: domainDir + "/meta-special.txt", + StaticPath: staticIPsFile, + CachePath: stateDir + "/resolver-cache.json", + PtrCachePath: stateDir + "/resolver-ptr-cache.json", + TraceLog: traceLogPath, + TTL: 24 * 3600, // или твой config.ResolverTTLSeconds + Workers: 200, // или config.ResolverWorkers + DNSConfigPath: dnsUpstreamsConf, // ВАЖНО: один путь, один формат +} + +res, err := runResolverJob(opts, logf) + + +Никаких других структур / полей для DNS в этом файле не нужно. +Вся магия DNS живёт в dns_settings.go + resolver.go. + +2. Почему теперь 590 IP, а не 1700 + +Смотрим на лог: + +summary: domains=14864, cache_hits=2 resolved_now=1281 unresolved=13581 static_entries=294 ... +unique_ips=590 ... + + +Это говорит о трёх вещах: + +Резолвер реально прошёлся по всем ~15k доменам. + +Ответы он получил только по ~1281 домену. + +Из них после фильтра isPrivateIPv4 осталось 590 уникальных публичных IP. + +Что поменялось по сравнению с прошлой версией: + +2.1. Всё теперь идёт ТОЛЬКО через SmartDNS + +Раньше схема была примерно: + +резолвер → напрямую на 94.140.* + 46.243.* + +SmartDNS жил своей жизнью как локальный помощник + +Сейчас при Use SmartDNS for resolver: + +резолвер → только на 127.0.0.1#6053 + +SmartDNS → дальше сам ходит на 192.168.50.10:53 (bind) → AdGuard DNS и т.д. + +Поэтому: + +все блокировки / NXDOMAIN / подмена на 0.0.0.0 / 127.0.0.1 происходят до нашего Go-кода; + +Go-резолвер видит либо нормальный внешний IP, либо 0.0.0.0/локалку/ошибку. + +2.2. Мы фильтруем все приватные и мусорные IP + +В runResolverJob: + +if len(rawIps) > 0 && ts > 0 && now-int(ts) <= ttl { + for _, v := range rawIps { + if s, ok := v.(string); ok && s != "" && !isPrivateIPv4(s) { + ips = append(ips, s) + } + } +} +... +if isPrivateIPv4(ip) { + continue +} + + +А isPrivateIPv4 считает приватными: + +0.x.x.x + +10.x.x.x + +127.x.x.x + +192.168.x.x + +172.16–31.x.x + +и вообще всё, что невалидное + +То есть: + +если AdGuard/SmartDNS возвращает 0.0.0.0 или 127.0.0.1 для трекеров / рекламы, +мы их больше не учитываем вообще; + +раньше, когда ходили мимо SmartDNS, таких «мусорных» ответов было гораздо меньше. + +Отсюда очень понятная картина: + +1700 IP ранее = реальные IP + всякий мусор (0.0.0.0, локалки, возможно ещё что-то); + +590 IP сейчас = только реальные публичные адреса, которые реально нужно засунуть в ipset и гонять через VPN. + +2.3. Что с ошибками в логах вида ... on 192.168.50.10:53 + +Это просто текст ошибки из SmartDNS/bind, который пробрасывается наружу: + +ptr error 104.26.10.38 via 127.0.0.1:6053: lookup 38.10.26.104.in-addr.arpa. on 192.168.50.10:53: no such host + + +мы обращаемся к 127.0.0.1:6053; + +SmartDNS дальше ходит на 192.168.50.10:53 и говорит «no such host»; + +net.Resolver формирует такое красивое сообщение. + +То есть Go-код SmartDNS не обходит, всё как раз правильно идёт через него. + +Если хочешь «как раньше» по цифрам + +Если прям хочется снова видеть ~1700 IP (включая мусор): + +Можно временно ослабить фильтр: + +func isPrivateIPv4(ip string) bool { + // В тесте можно не резать 0.0.0.0 / 127.x.x.x и т.п., + // просто вернуть false и посмотреть цифры + return false +} + + +Или сделать две метрики: + +unique_ips_total — с учётом всего, что вернул DNS; + +unique_ips_public — после фильтра приватных; + +и логировать обе. + +Но с точки зрения селективных маршрутов тебе как раз нужны только публичные IP, так что текущие 590 очень вероятно ближе к «правильной» картине. + +Если хочешь, в следующем шаге могу: + +дать патч для логов, чтобы в summary было: +unique_ips_public, unique_ips_private, blocked_0_ips и т.д., + +плюс подсветить в Trace вкладке отдельной строкой, сколько доменов SmartDNS убил в ноль. + +в том то и дело, что теперь половина доменов которые железно работали, не работают теперь, что то изменилось, не понимаю что, у нас кеш пишется? что же не так ... и кстати логи по прежнему не верно отображается +Перестала думать +в том то и дело, что теперь половина доменов которые железно работали, не работают теперь, что то изменилось, не понимаю что, у нас кеш пишется? что же не так ... и кстати логи по прежнему не верно отображается +Думал на протяжении 2m 20s + +Окей, давай по порядку, что видно из скрина логов, и что мы реально поменяли в коде. + +⚠️ Сначала техническая ремарка +Часть файлов, которые ты раньше заливал (resolver.go, config.go и т.д.), у меня уже недоступны – система их подчищает со временем. Я помню общую структуру и последние куски, которые ты присылал в текст, но если захочешь, чтобы я дал прям точные патчи построчно по актуальному resolver.go / dns_settings.go, лучше закинуть их ещё раз. + +1. Что сейчас не так по логам + +На твоём скрине снизу была строчка типа: + +dns_errors=13295 unresolved=13581 unique_ips=590 ... + + +И вверху куча: + +dns warn www.twitterinc.com via 127.0.0.1:6053: + lookup www.twitterinc.com on 192.168.50.53:53: i/o timeout +... +dns warn www.twitterstat.us via 127.0.0.1:6053: + dial udp 127.0.0.1:6053: i/o timeout + + +Это говорит о простом: + +Домены в большинстве случаев не резолвятся вообще. +Не из-за кэша, а из-за того, что SmartDNS/BIND/AdGuard под нагрузкой просто не отвечает (таймауты). + +Из-за этого у нас: + +resolved_now маленький + +dns_errors огромный + +unique_ips ≈ 500+, хотя раньше было ≈ 1700. + +То есть лог честный: с точки зрения нашего Go-резолвера для этих доменов реально "no IPs". + +2. Что именно изменилось в логике DNS + +До рефакторинга: + +наш resolver.go бился напрямую в default1/default2/meta1/meta2 (типа 94.140.14.15 и т.д.); + +SmartDNS жил сам по себе и занимался вайлдкарт-ipset’ами. + +После рефакторинга (который мы как раз и делали по твоей просьбе): + +если отмечено «Use SmartDNS for resolver» — мы: + +подменяем cfg.Default = []string{ smartdnsAddr } (127.0.0.1#6053); + +и всё резолвим через SmartDNS; + +параллельно мы оставили очень агрессивный пул: workers = 200, domains ≈ 15k. + +И получается ситуация: + +15 000 доменов × 200 воркеров × таймаут 3 секунды → местный SmartDNS + BIND просто не успевает, начинает дропать/таймаутить, мы видим кучу i/o timeout, и в итоге получаем только ~590 IP. + +То есть сломали не кэш, а заддосили свой же локальный resolver 😅 + +3. Кэш – жив он или нет? + +Да, по коду (который ты кидал): + +domainCache := loadJSONMap(opts.CachePath) +ptrCache := loadJSONMap(opts.PtrCachePath) +... +domainCache[host] = map[string]any{ + "ips": r.ips, + "last_resolved": now, +} +... +res.DomainCache = domainCache +res.PtrCache = ptrCache +... +saveJSON(res.DomainCache, opts.CachePath) +saveJSON(res.PtrCache, opts.PtrCachePath) + + +Кэш: + +читается в начале (loadJSONMap); + +обновляется только для успешных доменов; + +сохраняется в конце. + +Проблема в другом: когда у тебя за один прогон 13k DNS-ошибок, в кэш почти нечего писать — большая часть доменов вообще не получила IP, и в следующий прогон они снова идут "с нуля". + +4. Что я бы сделал сейчас в коде (минимально и по делу) +4.1. Ограничить параллелизм при работе через SmartDNS + +В runResolverJob (там, где ты сейчас считаешь workers), сделай так: + +workers := opts.Workers +if workers <= 0 { + workers = 200 +} + +// если работаем через SmartDNS – не лупим по нему 200 горутин +if cfg.ViaSmartDNS && workers > 50 { + workers = 50 +} + +// safety clamps +if workers < 1 { + workers = 1 +} +if workers > 500 { + workers = 500 +} + +if logf != nil { + logf("resolver start: domains=%d ttl=%ds workers=%d viaSmartDNS=%v", + len(domains), ttl, workers, cfg.ViaSmartDNS) +} + + +Суть: + +при прямых DNS — всё как раньше (200 воркеров нормально жрут 94.140.*); + +при SmartDNS — ставим жёсткий потолок, например 50 (можно 30-80 поэкспериментировать). + +Это радикально уменьшит количество одновременных запросов к SmartDNS/BIND, и таймаутов станет сильно меньше → dns_errors упадут, unique_ips вырастут. + +4.2. (Опционально) Увеличить таймаут для SmartDNS + +Сейчас digA вызывается так (по прошлому коду): + +ips, errs := digA(host, dnsList, 3*time.Second, logf) + + +Если хочешь мягко подстраховаться: + +timeout := 3 * time.Second +if cfg.ViaSmartDNS { + timeout = 5 * time.Second +} +ips, errs := digA(host, dnsList, timeout, logf) + + +Для этого нужно просто протащить cfg или флажок viaSmartDNS до resolveHostGo: + +func resolveHostGo(host string, cfg dnsConfig, metaSpecial []string, viaSmartDNS bool, logf func(string, ...any)) ([]string, int) { + ... + timeout := 3 * time.Second + if viaSmartDNS { + timeout = 5 * time.Second + } + ips, errs := digA(host, dnsList, timeout, logf) + ... +} + + +И в runResolverJob вызывать: + +ips, errs := resolveHostGo(j.host, cfg, metaSpecial, cfg.ViaSmartDNS, logf) + + +Это не обязательно, но в связке с уменьшением workers даст более стабильную картину. + +5. Почему лог «неверно отображается» + +Тут важно разделить: + +То, что реально делает наш резолвер. +Если он пишет: + +dns warn www.twitterinc.com via 127.0.0.1:6053: dial udp 127.0.0.1:6053: i/o timeout + + +– это значит, что конкретно этот запрос реально упал по таймауту. +Даже если "в браузере всё открывается", это уже другие запросы, в другое время и с другим паттерном (не 15k доменов за раз). + +Фильтрация на вкладке Events/SmartDNS/Full. +Мы уже делали фильтр для mode=gui, чтобы убрать: + +[routes], resolver, dnstt, smartdns, nft add element, cache hit и т.п. шум. + +Если сейчас в режиме Events ты всё равно видишь этот трэш с [routes] dns warn ... — значит, в handleTraceJSON в ветке case "gui": надо ещё сильнее ужать фильтр, например: + +if strings.HasPrefix(ll, "[routes]") || + strings.Contains(ll, " dns warn ") || + strings.Contains(ll, " dns error ") { + continue +} + + +Тогда: + +весь подробный DNS-дебаг остаётся в Full и SmartDNS; + +Events показывает только человеческие штуки [gui], [login], [vpn], [error] и т.д. + +Если хочешь, чтобы я точно подогнал фильтр под твой текущий trace_handlers.go, просто перекинь файл ещё раз — старая версия у меня уже протухла. + +6. Что бы я сделал дальше по шагам + +Внести правку с ограничением workers при ViaSmartDNS. + +(Опционально) Протянуть флаг viaSmartDNS в resolveHostGo и увеличить таймаут до 5s. + +Запустить update-selective-routes и посмотреть: + +строку summary (особенно dns_errors и unique_ips); + +стало ли больше IP (должно быть ближе к твоим старым ~1700, хотя не обязательно 1:1). + +Если с IP всё норм, но Events всё ещё замусорен – подправить фильтр в handleTraceJSON (mode = "gui"). + +Если хочешь, чтобы я прям по твоим актуальным файлам выписал патчи diffом (resolver.go, dns_settings.go, trace_handlers.go) – скинь их ещё раз, я аккуратно вставлю всё в нужные места без переименования функций и с теми же структурами, что уже есть. diff --git a/selective-vpn-api/инструкция3-safe.txt b/selective-vpn-api/инструкция3-safe.txt new file mode 100644 index 0000000..6fc6aec --- /dev/null +++ b/selective-vpn-api/инструкция3-safe.txt @@ -0,0 +1,201 @@ +Инструкция 3 (safe-версия) + +Цель: +Сделать улучшения резолвера без risky-переписывания, сохранить обратную совместимость с текущим GUI/API, +и убрать главные источники шума: массовые NXDOMAIN, таймауты и нечитабельные метрики. + +-------------------------------------------------------------------- +0) Что подтверждено по текущему состоянию +-------------------------------------------------------------------- + +1. Основной рабочий код сейчас в: + - app/resolver.go + - app/routes_update.go + - app/dns_settings.go + - app/domains_handlers.go + +2. Логи показывают: + - много NXDOMAIN (ожидаемо при широком base x subs) + - заметную долю timeout + - один агрегированный счетчик dns_errors, из-за чего трудно понять причину деградации. + +3. Формат dns upstream у нас: host#port (например 94.140.14.15#53 или 127.0.0.1#6053). + Это важно: нельзя использовать валидацию, которая принимает только host:port. + +-------------------------------------------------------------------- +1) Архитектурное решение (рекомендуемое) +-------------------------------------------------------------------- + +Не оставлять только переключатель direct <-> smartdns. +Сделать 3 режима резолвера: + +- direct: + обычные домены через default/meta upstream. + +- smartdns: + все домены через SmartDNS address. + +- hybrid_wildcard (recommended): + только wildcard-домены через SmartDNS, остальные напрямую через default/meta. + +Почему так лучше: +- сохраняем скорость и отказоустойчивость direct для обычных доменов; +- wildcard-логику держим строго в SmartDNS, как ты и хотел; +- не ломаем текущий UX: можно оставить старый bool и маппить его на mode. + +-------------------------------------------------------------------- +2) Что НЕ внедряем из старой инструкция3 +-------------------------------------------------------------------- + +1. Не используем netip.ParseAddrPort для upstream-валидации (ломает host#port). +2. Не используем netip.MustParseAddr в hot path (может паниковать). +3. Не добавляем лишний semaphore поверх worker pool (сложность без явной выгоды). +4. Не делаем агрессивный рефактор API-контрактов без backward-compat. + +-------------------------------------------------------------------- +3) Пакет безопасных правок (приоритет P1) +-------------------------------------------------------------------- + +P1.1 - Режимы DNS (backward compatible) +Файлы: +- app/types.go +- app/dns_settings.go +- app/routes_update.go +- app/resolver.go + +Изменения: +1) Ввести enum режима: + type DNSResolverMode string + const ( + DNSModeDirect DNSResolverMode = "direct" + DNSModeSmartDNS DNSResolverMode = "smartdns" + DNSModeHybrid DNSResolverMode = "hybrid_wildcard" + ) + +2) Расширить DNSMode/DNSStatusResponse/DNSModeRequest полем Mode, + но оставить ViaSmartDNS для старого GUI: + - если Mode пустой, использовать ViaSmartDNS: + - true -> smartdns + - false -> direct + +3) В ResolverOpts передавать Mode и список wildcard-доменов (один раз на job). + +4) В resolveHostGo выбирать dnsList так: + - mode == smartdns: []{smartdnsAddr} + - mode == hybrid_wildcard и host совпал с wildcard: []{smartdnsAddr} + - иначе: meta или default по текущей логике. + +Примечание: +Wildcard-список уже хранится в smartdns-wildcards.json через /api/v1/smartdns/wildcards. +Нужно только использовать его в резолвере. + +P1.2 - Upstream fallback с классификацией ошибок +Файл: +- app/resolver.go + +Изменения: +1) В digA: + - идти по upstream последовательно; + - timeout/temporary -> fallback на следующий upstream; + - nxdomain -> остановить попытки для домена (дальше пробовать бессмысленно). + +2) Классифицировать ошибки через net.DNSError + fallback по тексту: + - nxdomain + - timeout + - temporary + - other + +3) Вместо одного dns_errors вести структуру счетчиков: + dns_attempts, dns_ok, dns_nxdomain, dns_timeout, dns_temporary, dns_other. + +P1.3 - Разделенные метрики в summary +Файл: +- app/resolver.go + +Изменения: +1) Обновить финальный лог "resolve summary" с раздельными счетчиками DNS-ошибок. +2) Добавить per-upstream агрегаты (минимум attempts/ok/timeout/nxdomain/other). + Формат может быть одной строкой JSON, чтобы GUI/анализатору было проще парсить. + +P1.4 - Ограничение domain expansion +Файл: +- app/routes_update.go + +Изменения: +1) Добавить конфиг-лимиты через env: + - RESOLVE_SUBS_PER_BASE_LIMIT (например default 25) + - RESOLVE_DOMAINS_HARD_CAP (например default 12000) + +2) После построения domainSet: + - сортировать домены; + - при превышении hard cap обрезать хвост детерминированно; + - писать явный warning в trace. + +3) Логировать breakdown: + bases_count, subs_count, expanded_count, total_domains. + +-------------------------------------------------------------------- +4) Пакет улучшений P2 (после P1) +-------------------------------------------------------------------- + +P2.1 - Negative cache +Файл: +- app/resolver.go + +Идея: +- кэшировать nxdomain/servfail на короткий TTL (например 10-20 минут), + чтобы не долбить одинаковые несуществующие имена каждую прогонку. + +P2.2 - PTR retry (ограниченный) +Файл: +- app/resolver.go + +Идея: +- для digPTR сделать 1-2 retry только на timeout/temporary; +- не ретраить nxdomain. + +P2.3 - GUI/API отображение режима +Файлы GUI: +- selective-vpn-gui/api_client.py +- selective-vpn-gui/dashboard_controller.py +- selective-vpn-gui/vpn_dashboard_qt.py + +Идея: +- показать mode = direct/smartdns/hybrid_wildcard; +- оставить старый toggle рабочим (маппинг direct/smartdns), + а hybrid можно добавить как отдельный выбор (позже). + +-------------------------------------------------------------------- +5) Порядок внедрения (рекомендуемый) +-------------------------------------------------------------------- + +Шаг 1: +Реализовать режимы + fallback + split-метрики (P1.1, P1.2, P1.3). + +Шаг 2: +Добавить лимиты expansion (P1.4), проверить изменение unresolved/timeouts. + +Шаг 3: +Добавить negative cache и PTR retry (P2.1, P2.2). + +-------------------------------------------------------------------- +6) Критерии готовности +-------------------------------------------------------------------- + +1. В trace видно не один dns_errors, а раздельные причины. +2. В режиме direct при падении одного upstream запросы частично продолжают проходить через fallback. +3. В режиме hybrid_wildcard wildcard-домены уходят через SmartDNS, остальные через direct. +4. Количество timeout и общее время прогона снижаются относительно текущего baseline. + +-------------------------------------------------------------------- +7) Короткий диагноз по сути +-------------------------------------------------------------------- + +Проблема не в одном SmartDNS bottleneck. +Основной вклад сейчас дает комбинация: +- большой domain expansion, +- много несуществующих имен (NXDOMAIN), +- и отсутствие четкой стратегии fallback/метрик. + +Safe-путь: +ввести гибридный routing DNS + раздельную диагностику + контролируемый expansion. diff --git a/selective-vpn-api/инструкция3.txt b/selective-vpn-api/инструкция3.txt new file mode 100644 index 0000000..83c84b8 --- /dev/null +++ b/selective-vpn-api/инструкция3.txt @@ -0,0 +1,252 @@ +Анализ твоего кода resolver.go +Я внимательно прошёлся по всему коду: структура хорошая (модульный, с логами, кэшем, concurrent резолвом), логика правильная для задачи (домены → IPs с meta/default upstream'ами, static с PTR-лейблами, приватные IP фильтруем). Это уже рабочий инструмент, который может обрабатывать тысячи доменов без сбоев. Но да, есть потенциальные проблемы и места для улучшений — некоторые очевидные (edge-кейсы), другие — для масштаба/надёжности. +Я разделил на потенциальные поломки (где код может крашнуться или дать неверный результат), слабые места (не поломки, но риски в проде) и рекомендации с кодовыми примерами. Всё основано на твоём коде — не выдумываю, а анализирую. +1. Потенциальные поломки (где код может упасть или дать мусор) + +Конфликт портов в upstream (в resolveHostGo и digPTR): +splitDNS возвращает port (e.g., "127.0.0.1#6053" → "127.0.0.1", "6053"). +Но в net.Resolver.Dial ты используешь net.JoinHostPort(server, port) — это ок, но Resolver по дефолту использует UDP. Если upstream требует TCP (e.g., большие ответы >512 байт) — запрос fail (dns.exception.SyntaxError или timeout). +Поломка: Для доменов с большим TXT/NS — no IPs, но errs++ (не обновит кэш). +Тест: Резолв домен с большим ответом (e.g., "example.com" TXT) на upstream с port. Если fail — вот оно. + +Timeout слишком короткий (3s в ctx): +На медленных upstream'ах (e.g., DNS over Tor или VPN) — запрос не успеет, вернёт [] и errs++. +Поломка: Unresolved домены, даже если upstream живой. При повторном запуске — из кэша не возьмёт (потому что не обновил). +Тест: Установи медленный DNS (e.g., 8.8.8.8 с delay) — увидишь dns.timeout. + +PTR fail без retry (в resolveStaticLabels): +digPTR без retry — если upstream временно down, PTR не получит, и лейбл = "*[STATIC-IP]". +Поломка: Если все upstream fail — все static без PTR, но код не крашится. В проде — потеря лейблов. +Тест: Установи invalid upstream — увидишь ptr_errors++, но код продолжит. + +loadDNSConfig fallback на дефолт без валидации: +Если файл битый (e.g., "default abc#invalid") — fallback на дефолт, но дефолт может не работать. +Поломка: Silent fail — резолв будет работать на старых дефолтах, но юзер не поймёт почему. +Тест: Сделай файл с кривыми upstream'ами — увидишь в логах fallback, но без ошибки. + +uniqueStrings и map в IPMap — но без сортировки в map: +IPMap использует map[string]struct{} для dedup, но при append в res.IPMap — sort.Slice по IP, но лейблы внутри не отсортированы. +Поломка: Не поломка, но нестабильный порядок лейблов в mapLines — если GUI полагается на порядок. +Тест: Несколько лейблов на IP — порядок random. + +isPrivateIPv4 не обрабатывает /mask в IP: +Проверяет только base IP, но если ip = "10.0.0.0/8" — parts = ["10.0.0.0/8"] → len!=4 → return true (ок). Но если mask invalid — не catch. +Поломка: Если static с кривым CIDR (e.g., "1.2.3.4/33") — пропустит как non-private. +Тест: Добавь "1.2.3.4/33" в static — пройдёт, хотя invalid. + +runResolverJob: если workers=0 или ttl=0 — не краш, но дефолты +Ты clamp'ишь (ttl 60..86400, workers 1..500) — хорошо, но если envInt fail — silent. +Поломка: Не поломка, но если ENV кривой — silent fallback. + + +2. Слабые места (не поломки, но риски в проде) + +Concurrent резолв без rate-limit: +80 workers + 500+ доменов — может flood upstream'ы (DNS rate-limit ban). +Риск: Upstream забанит IP, все запросы fail. + +Нет IPv6: +Всё только IPv4 (A-записи, isPrivateIPv4). Если домен только AAAA — no IPs. +Риск: Современные сайты (e.g., google) имеют IPv6 — потеряешь трафик. + +Кэш без invalidation: +Кэш на TTL, но если upstream сменился — старый кэш может быть invalid. +Риск: После смены DNS — старые IPs в кэше до TTL. + +PTR без кэша ошибок: +Если PTR fail — всё равно пишем в кэш empty names. Следующий раз — из кэша empty. +Риск: Временный fail upstream — навсегда потерянные лейблы до ручного clear cache. + +Логи только в файл: +appendTraceLine — ok, но если файл locked — silent fail. +Риск: В многопоточном режиме — race на append. + +envInt без дефолта fallback: +Если ENV empty — strconv.Atoi("") → error, но ты используешь default. Ok, но можно сделать safe. + + +3. Рекомендации и фиксы +Вот конкретные советы + кодовые патчи. Внедряй по приоритету (сначала поломки). + +Фикс портов и TCP в resolveHostGo (для поломки 1): +Используй net.Dialer с UDP/TCP fallback. +Patch:Gofunc resolveHostGo(host string, cfg dnsConfig, metaSpecial []string, logf func(string, ...any)) ([]string, int) { + // ... (твой код) + var ips []string + errs := 0 + + for _, upstream := range dnsList { + server, port := splitDNS(upstream) + if server == "" { + errs++ + continue + } + if port == "" { + port = "53" + } + addr := net.JoinHostPort(server, port) + + // Сначала UDP + ipsUdp, errUdp := lookupA(host, addr, "udp", 3*time.Second) + if errUdp == nil { + ips = append(ips, ipsUdp...) + continue + } else if logf != nil { + logf("UDP fail for %s@%s: %v", host, addr, errUdp) + } + + // Fallback TCP + ipsTcp, errTcp := lookupA(host, addr, "tcp", 5*time.Second) + if errTcp == nil { + ips = append(ips, ipsTcp...) + } else { + errs++ + if logf != nil { + logf("TCP fail for %s@%s: %v", host, addr, errTcp) + } + } + } + + // dedup + seen := map[string]struct{}{} + unique := []string{} + for _, ip := range ips { + if _, ok := seen[ip]; !ok && !isPrivateIPv4(ip) { + seen[ip] = struct{}{} + unique = append(unique, ip) + } + } + + return unique, errs +} + +func lookupA(host, addr, network string, timeout time.Duration) ([]string, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + r := &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, _, _ string) (net.Conn, error) { + d := net.Dialer{} + return d.DialContext(ctx, network, addr) + }, + } + + ips, err := r.LookupHost(ctx, host) + if err != nil { + return nil, err + } + + var out []string + for _, ip := range ips { + if netip.MustParseAddr(ip).Is4() { + out = append(out, ip) + } + } + return out, nil +} + +Увеличь timeout + retry в resolveHostGo (для поломки 2): +Добавь retry для всей функции (3 попытки). +Patch: Оберни loop по upstream в backoff.Retry (как в nft). + +Retry для PTR (для поломки 3): +Patch в resolveStaticLabels:Go// Внутри if !names: +for attempt := 0; attempt < 3; attempt++ { + names, err = digPTR(base_ip, dns_for_ptr, 5*time.Second, logf) + if err == nil { + break + } + time.Sleep(500 * time.Millisecond) +} +if err != nil { + ptrErrors++ +} + +Валидация upstream в loadDNSConfig (для поломки 4): +Patch:Gofunc loadDNSConfig(path string, logf func(string, ...any)) dnsConfig { + // ... (твой код) + valid := func(s string) bool { + if s == "" { + return false + } + _, err := netip.ParseAddrPort(s) + return err == nil + } + cfg.Default = filter(cfg.Default, valid) + cfg.Meta = filter(cfg.Meta, valid) + if len(cfg.Default) == 0 { + cfg.Default = defaultDNS + if logf != nil { + logf("dns-config: no valid default upstreams, fallback to built-in") + } + } + // аналогично для meta +} + +func filter(ss []string, fn func(string) bool) []string { + var out []string + for _, s := range ss { + if fn(s) { + out = append(out, s) + } + } + return out +} + +Сортировка лейблов в IPMap (для поломки 5): +Patch в runResolverJob после for host, ips...:Gofor ip, labelsMap := range ipMap { + var labels []string + for lbl := range labelsMap { + labels = append(labels, lbl) + } + sort.Strings(labels) + for _, lbl := range labels { + res.IPMap = append(res.IPMap, [2]string{ip, lbl}) + } +} + +Валидация CIDR в isPrivateIPv4 (для поломки 6): +Patch: Используй netip.ParsePrefix:Gofunc isPrivateIPv4(ip string) bool { + if strings.Contains(ip, "/") { + p, err := netip.ParsePrefix(ip) + if err != nil { + return true // invalid → private (пропустить) + } + return p.IsPrivate() + } + a, err := netip.ParseAddr(ip) + if err != nil { + return true + } + return a.IsPrivate() +} + +Rate-limit в concurrent (для слабого 1): +Добавь semaphore в workers-loop:Gosem := make(chan struct{}, 20) // лимит 20 одновременных запросов +for i := 0; i < workers; i++ { + go func() { + for j := range jobs { + sem <- struct{}{} + // resolve + <-sem + } + }() +} + +IPv6 поддержка (для слабого 2): +Добавь AAAA в resolveHostGo (параллельно с A). +Добавь isPrivateIPv6. + +Кэш invalidation (для слабого 3): +Добавь флаг --force для игнора TTL. + +Логи в файл с sync (для слабого 5): +В appendTraceLine добавь file.Sync() после Write. + + +Итог + +Поломки: В основном edge-кейсы (медленный DNS, invalid upstream, большие ответы) — фиксы выше. +Слабые: Масштаб (rate-limit, IPv6) и robustness (retry PTR, validate). +Код в целом очень хороший — внедряй фиксы по приоритету (сначала порты/TCP, timeout/retry). diff --git a/selective-vpn-gui/agvpn-resolver.py b/selective-vpn-gui/agvpn-resolver.py new file mode 100755 index 0000000..2d87496 --- /dev/null +++ b/selective-vpn-gui/agvpn-resolver.py @@ -0,0 +1,645 @@ +#!/usr/bin/env python3 +import argparse +import concurrent.futures +import json +import os +import sys +import time +from collections import defaultdict + +# --- dnspython -------------------------------------------------------- +try: + import dns.resolver + import dns.reversename + import dns.exception +except ImportError as e: + print(f"[resolver] dnspython is required: {e}", file=sys.stderr) + sys.exit(2) + +# -------------------------------------------------------------------- +# Общий DNS-конфиг +# -------------------------------------------------------------------- +DNS_CONFIG_PATH = "/etc/selective-vpn/dns-upstreams.conf" + +DEFAULT_DNS_DEFAULT = ["94.140.14.14", "94.140.15.15"] +DEFAULT_DNS_META = ["46.243.231.30", "46.243.231.41"] + +DNS_DEFAULT = DEFAULT_DNS_DEFAULT.copy() +DNS_META = DEFAULT_DNS_META.copy() + + +# -------------------------------------------------------------------- +# helpers +# -------------------------------------------------------------------- +def log(msg, trace_log=None): + line = f"[resolver] {msg}" + print(line, file=sys.stderr) + if trace_log: + try: + with open(trace_log, "a") as f: + f.write(line + "\n") + except Exception: + pass + + +def is_private_ipv4(ip: str) -> bool: + """ + ip может быть "A.B.C.D" или "A.B.C.D/nn". + Возвращаем True, если адрес из приватных диапазонов. + """ + parts = ip.split("/") + base = parts[0] + try: + o1, o2, o3, o4 = map(int, base.split(".")) + except ValueError: + return True + + if o1 == 10: + return True + if o1 == 127: + return True + if o1 == 0: + return True + if o1 == 192 and o2 == 168: + return True + if o1 == 172 and 16 <= o2 <= 31: + return True + return False + + +def load_list(path): + if not os.path.exists(path): + return [] + out = [] + with open(path, "r") as f: + for line in f: + s = line.strip() + if not s or s.startswith("#"): + continue + out.append(s) + return out + + +def load_cache(path): + if not os.path.exists(path): + return {} + try: + with open(path, "r") as f: + return json.load(f) + except Exception: + return {} + + +def save_cache(path, data): + tmp = path + ".tmp" + try: + with open(tmp, "w") as f: + json.dump(data, f, indent=2, sort_keys=True) + os.replace(tmp, path) + except Exception: + pass + + +def split_dns(dns: str): + """ + Разбор записи вида: + "1.2.3.4" -> ("1.2.3.4", None) + "1.2.3.4#6053" -> ("1.2.3.4", "6053") + """ + if "#" in dns: + host, port = dns.split("#", 1) + host = host.strip() + port = port.strip() + if not host: + host = "127.0.0.1" + if not port: + port = "53" + return host, port + return dns, None + + +# -------------------------------------------------------------------- +# dnspython-резолвы +# -------------------------------------------------------------------- +def dig_a(host, dns_list, timeout=3): + """ + A-резолв через dnspython. + dns_list: либо строка "IP[#PORT]", либо список таких строк. + """ + if isinstance(dns_list, str): + dns_list = [dns_list] + + ips = [] + + for entry in dns_list: + server, port = split_dns(entry) + if not server: + continue + + r = dns.resolver.Resolver(configure=False) + r.nameservers = [server] + if port: + try: + r.port = int(port) + except ValueError: + r.port = 53 + r.timeout = timeout + r.lifetime = timeout + + try: + answer = r.resolve(host, "A") + except dns.exception.DNSException: + continue + except Exception: + continue + + for rr in answer: + s = rr.to_text().strip() + parts = s.split(".") + if len(parts) != 4: + continue + if all(p.isdigit() and 0 <= int(p) <= 255 for p in parts): + if not is_private_ipv4(s) and s not in ips: + ips.append(s) + + return ips + + +def dig_ptr(ip, upstream, timeout=3): + """ + PTR-резолв: ip -> список имён. + dns может быть "IP" или "IP#PORT". + """ + server, port = split_dns(upstream) + if not server: + return [] + + r = dns.resolver.Resolver(configure=False) + r.nameservers = [server] + if port: + try: + r.port = int(port) + except ValueError: + r.port = 53 + r.timeout = timeout + r.lifetime = timeout + + try: + rev = dns.reversename.from_address(ip) + except Exception: + return [] + + try: + answer = r.resolve(rev, "PTR") + except dns.exception.DNSException: + return [] + except Exception: + return [] + + names = [] + for rr in answer: + s = rr.to_text().strip() + if s.endswith("."): + s = s[:-1] + if s: + names.append(s.lower()) + return names + + +# -------------------------------------------------------------------- +# Загрузка DNS-конфига +# -------------------------------------------------------------------- +def load_dns_config(path=DNS_CONFIG_PATH, trace_log=None): + """ + Читает /etc/selective-vpn/dns-upstreams.conf и обновляет + глобальные DNS_DEFAULT / DNS_META. + + Формат строк: + default 1.2.3.4 5.6.7.8 + meta 9.9.9.9 8.8.8.8 + Можно использовать "ip#port", например 127.0.0.1#6053. + """ + global DNS_DEFAULT, DNS_META + + if not os.path.exists(path): + DNS_DEFAULT = DEFAULT_DNS_DEFAULT.copy() + DNS_META = DEFAULT_DNS_META.copy() + log( + f"dns-config: {path} not found, fallback to built-in defaults " + f"(default={DNS_DEFAULT}, meta={DNS_META})", + trace_log, + ) + return + + dflt = [] + meta = [] + + try: + with open(path, "r") as f: + for line in f: + s = line.strip() + if not s or s.startswith("#"): + continue + parts = s.split() + if len(parts) < 2: + continue + key = parts[0].lower() + addrs = parts[1:] + if key == "default": + dflt.extend(addrs) + elif key == "meta": + meta.extend(addrs) + except Exception as e: + DNS_DEFAULT = DEFAULT_DNS_DEFAULT.copy() + DNS_META = DEFAULT_DNS_META.copy() + log( + f"dns-config: failed to read {path}: {e}, fallback to built-in defaults " + f"(default={DNS_DEFAULT}, meta={DNS_META})", + trace_log, + ) + return + + if not dflt: + dflt = DEFAULT_DNS_DEFAULT.copy() + log( + "dns-config: no 'default' section, fallback to built-in for default", + trace_log, + ) + if not meta: + meta = DEFAULT_DNS_META.copy() + log("dns-config: no 'meta' section, fallback to built-in for meta", trace_log) + + DNS_DEFAULT = dflt + DNS_META = meta + log( + f"dns-config: accept {path}: " + f"default={', '.join(DNS_DEFAULT)}; meta={', '.join(DNS_META)}", + trace_log, + ) + + +def resolve_host(host, meta_special, trace_log=None): + """ + Forward-резолв одного домена (A-записи). + DNS берём из DNS_DEFAULT / DNS_META, которые загрузил load_dns_config(). + """ + if host in meta_special: + dns_list = DNS_META + else: + dns_list = DNS_DEFAULT + + ips = dig_a(host, dns_list) + + uniq = [] + for ip in ips: + if ip not in uniq: + uniq.append(ip) + + if uniq: + log(f"{host}: {', '.join(uniq)}", trace_log) + else: + log(f"{host}: no IPs", trace_log) + return uniq + + +def parse_static_entries(static_lines): + """ + static_lines — строки из static-ips.txt. + Возвращаем список кортежей (ip_entry, base_ip, comment). + """ + entries = [] + for line in static_lines: + s = line.strip() + if not s or s.startswith("#"): + continue + + if "#" in s: + ip_part, comment = s.split("#", 1) + ip_part = ip_part.strip() + comment = comment.strip() + else: + ip_part = s + comment = "" + + if not ip_part: + continue + if is_private_ipv4(ip_part): + continue + + base_ip = ip_part.split("/", 1)[0] + entries.append((ip_part, base_ip, comment)) + return entries + + +def resolve_static_entries(static_entries, ptr_cache, ttl_sec, trace_log=None): + """ + static_entries: список кортежей (ip_entry, base_ip, comment). + ip_entry — как в static-ips.txt (может быть с /mask) + base_ip — A.B.C.D (без маски) + comment — текст после # или "". + + Возвращаем dict: ip_entry -> список меток, + уже с префиксом '*' (чтобы можно было искать). + """ + now = int(time.time()) + result = {} + for ip_entry, base_ip, comment in static_entries: + labels = [] + # 1) если есть комментарий — он главнее всего + if comment: + labels.append(f"*{comment}") + # 2) если комментария нет, пробуем PTR (с кэшем) + if not comment: + cache_entry = ptr_cache.get(base_ip) + names = [] + if ( + cache_entry + and isinstance(cache_entry, dict) + and isinstance(cache_entry.get("last_resolved"), (int, float)) + ): + age = now - cache_entry["last_resolved"] + cached_names = cache_entry.get("names") or [] + if age <= ttl_sec and cached_names: + names = cached_names + if not names: + # PTR через те же DNS, что и обычный трафик (используем первый из default) + dns_for_ptr = DNS_DEFAULT[0] if DNS_DEFAULT else DEFAULT_DNS_DEFAULT[0] + + try: + names = dig_ptr(base_ip, dns_for_ptr) or [] + except Exception as e: + log( + f"PTR failed for {base_ip} (using {dns_for_ptr}): " + f"{type(e).__name__}: {e}", + trace_log, + ) + names = [] + + uniq_names = [] + for n in names: + if n not in uniq_names: + uniq_names.append(n) + names = uniq_names + ptr_cache[base_ip] = { + "names": names, + "last_resolved": now, + } + for n in names: + labels.append(f"*{n}") + # 3) если вообще ничего нет — ставим общий тег + if not labels: + labels = ["*[STATIC-IP]"] + result[ip_entry] = labels + log(f"static {ip_entry}: labels={', '.join(labels)}", trace_log) + return result + + +# -------------------------------------------------------------------- +# API-слой: одна чистая функция, которую легко вызвать откуда угодно +# -------------------------------------------------------------------- +def run_resolver_job( + *, + domains, + meta_special, + static_lines, + cache_path, + ptr_cache_path, + ttl_sec, + workers, + trace_log=None, +): + """ + Главный API резолвера. + + Вход: + domains — список доменов + meta_special — set() доменов из meta-special.txt + static_lines — строки из static-ips.txt + cache_path — путь к domain-cache.json + ptr_cache_path— путь к ptr-cache.json + ttl_sec — TTL кэша доменов / PTR + workers — число потоков + trace_log — путь к trace.log (или None) + + Выход: dict с ключами: + ips — отсортированный список IP/подсетей + ip_map — список (ip, label) пар (домен или *LABEL) + domain_cache — обновлённый кэш доменов + ptr_cache — обновлённый PTR-кэш + summary — статистика (dict) + """ + # --- подгружаем DNS-конфиг --- + load_dns_config(DNS_CONFIG_PATH, trace_log) + + meta_special = set(meta_special or []) + + log(f"domains to resolve: {len(domains)}", trace_log) + + # --- кэши --- + domain_cache = load_cache(cache_path) + ptr_cache = load_cache(ptr_cache_path) + now = int(time.time()) + + # --- разруливаем: что берём из domain_cache, что резолвим --- + fresh_from_cache = {} + to_resolve = [] + + for d in domains: + entry = domain_cache.get(d) + if entry and isinstance(entry, dict): + ts = entry.get("last_resolved") or 0 + ips = entry.get("ips") or [] + if isinstance(ts, (int, float)) and isinstance(ips, list) and ips: + if now - ts <= ttl_sec: + valid_ips = [ip for ip in ips if not is_private_ipv4(ip)] + if valid_ips: + fresh_from_cache[d] = valid_ips + continue + + to_resolve.append(d) + + log( + f"from cache: {len(fresh_from_cache)}, to resolve: {len(to_resolve)}", + trace_log, + ) + + resolved = dict(fresh_from_cache) + + total_domains = len(domains) + cache_hits = len(fresh_from_cache) + resolved_now = 0 + unresolved = 0 + + # --- параллельный резолв доменов --- + if to_resolve: + with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as ex: + fut2host = { + ex.submit(resolve_host, d, meta_special, trace_log): d + for d in to_resolve + } + for fut in concurrent.futures.as_completed(fut2host): + d = fut2host[fut] + try: + ips = fut.result() + except Exception as e: + log(f"{d}: resolver exception: {e}", trace_log) + ips = [] + + if ips: + resolved[d] = ips + domain_cache[d] = { + "ips": ips, + "last_resolved": now, + } + resolved_now += 1 + else: + unresolved += 1 + + # --- читаем static-ips и готовим список для PTR --- + static_entries = parse_static_entries(static_lines) + log(f"static entries: {len(static_entries)}", trace_log) + + # --- PTR/labels для static-ips --- + static_label_map = resolve_static_entries( + static_entries, ptr_cache, ttl_sec, trace_log + ) + + # --- собираем общий список IP и map --- + ip_set = set() + ip_to_domains = defaultdict(set) + + # доменные IP + for d, ips in resolved.items(): + for ip in ips: + ip_set.add(ip) + ip_to_domains[ip].add(d) + + # статические IP / сети + for ip_entry, _, _ in static_entries: + ip_set.add(ip_entry) + for label in static_label_map.get(ip_entry, []): + ip_to_domains[ip_entry].add(label) + + unique_ip_count = len(ip_set) + if unique_ip_count == 0: + log("no IPs resolved at all", trace_log) + else: + log(f"resolver done: {unique_ip_count} unique IPs", trace_log) + + ips_sorted = sorted(ip_set) + + # flatten ip_map + ip_map_pairs = [] + for ip in ips_sorted: + for dom in sorted(ip_to_domains[ip]): + ip_map_pairs.append((ip, dom)) + + summary = { + "domains_total": total_domains, + "from_cache": cache_hits, + "resolved_now": resolved_now, + "unresolved": unresolved, + "static_entries": len(static_entries), + "unique_ips": unique_ip_count, + } + + log( + "summary: domains=%d, from_cache=%d, resolved_now=%d, " + "unresolved=%d, static_entries=%d, unique_ips=%d" + % ( + summary["domains_total"], + summary["from_cache"], + summary["resolved_now"], + summary["unresolved"], + summary["static_entries"], + summary["unique_ips"], + ), + trace_log, + ) + + return { + "ips": ips_sorted, + "ip_map": ip_map_pairs, + "domain_cache": domain_cache, + "ptr_cache": ptr_cache, + "summary": summary, + } + + +# -------------------------------------------------------------------- +# CLI-обёртка вокруг API-функции (для bash-скрипта) +# -------------------------------------------------------------------- +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--domains", required=True, help="file with domains (one per line)") + ap.add_argument("--output-ips", required=True, help="file to write unique IPs") + ap.add_argument( + "--output-map", + required=True, + help="file to write IPdomain map", + ) + ap.add_argument("--meta-file", required=True, help="meta-special.txt path") + ap.add_argument("--static-ips", required=True, help="static-ips.txt path") + ap.add_argument("--cache", required=True, help="domain-cache.json path") + ap.add_argument("--ptr-cache", required=True, help="ptr-cache.json path") + ap.add_argument("--trace-log", default=None) + ap.add_argument("--workers", type=int, default=40) + ap.add_argument("--ttl-sec", type=int, default=24 * 3600) + args = ap.parse_args() + + trace_log = args.trace_log + + try: + # входные данные для API-функции + domains = load_list(args.domains) + meta_special = load_list(args.meta_file) + + static_lines = [] + if os.path.exists(args.static_ips): + with open(args.static_ips, "r") as f: + static_lines = f.read().splitlines() + + job_result = run_resolver_job( + domains=domains, + meta_special=meta_special, + static_lines=static_lines, + cache_path=args.cache, + ptr_cache_path=args.ptr_cache, + ttl_sec=args.ttl_sec, + workers=args.workers, + trace_log=trace_log, + ) + + ips_sorted = job_result["ips"] + ip_map_pairs = job_result["ip_map"] + domain_cache = job_result["domain_cache"] + ptr_cache = job_result["ptr_cache"] + + # output-ips: по одному IP/подсети + with open(args.output_ips, "w") as f: + for ip in ips_sorted: + f.write(ip + "\n") + + # output-map: IPдомен/метка + with open(args.output_map, "w") as f: + for ip, dom in ip_map_pairs: + f.write(f"{ip}\t{dom}\n") + + # сохраняем кэши + save_cache(args.cache, domain_cache) + save_cache(args.ptr_cache, ptr_cache) + + return 0 + + except Exception as e: + # настоящий фатал + log(f"FATAL resolver error: {e}", trace_log) + import traceback + + traceback.print_exc(file=sys.stderr) + return 2 + + +if __name__ == "__main__": + raise SystemExit(main()) + diff --git a/selective-vpn-gui/api_client.py b/selective-vpn-gui/api_client.py new file mode 100644 index 0000000..ac50f92 --- /dev/null +++ b/selective-vpn-gui/api_client.py @@ -0,0 +1,1121 @@ +#!/usr/bin/env python3 +"""Selective-VPN API client (UI-agnostic). + +Design goals: +- The dashboard (GUI) must NOT know any URLs, HTTP methods, JSON keys, or payload shapes. +- All REST details live here. +- Returned values are normalized into dataclasses for clean UI usage. + +Env: +- SELECTIVE_VPN_API (default: http://127.0.0.1:8080) + +This file is meant to be imported by a controller (dashboard_controller.py) and UI. +""" + +from __future__ import annotations + +from dataclasses import dataclass +import json +import os +import re +import time +from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, cast + +import requests + + +# --------------------------- +# Small utilities +# --------------------------- + +_ANSI_RE = re.compile(r"\x1B\[[0-9;]*[A-Za-z]") + + +def strip_ansi(s: str) -> str: + """Remove ANSI escape sequences.""" + if not s: + return "" + return _ANSI_RE.sub("", s) + + +# --------------------------- +# Models (UI-friendly) +# --------------------------- + +@dataclass(frozen=True) +class Status: + timestamp: str + ip_count: int + domain_count: int + iface: str + table: str + mark: str + # NOTE: backend uses omitempty for these, so they may be absent. + policy_route_ok: Optional[bool] + route_ok: Optional[bool] + + +@dataclass(frozen=True) +class CmdResult: + ok: bool + message: str + exit_code: Optional[int] = None + stdout: str = "" + stderr: str = "" + + +@dataclass(frozen=True) +class LoginState: + state: str + email: str + msg: str + # backend may also provide UI-ready fields + text: str + color: str + + +@dataclass(frozen=True) +class UnitState: + state: str + + +@dataclass(frozen=True) +class RoutesTimerState: + enabled: bool + + +@dataclass(frozen=True) +class TrafficModeStatus: + mode: str + desired_mode: str + applied_mode: str + preferred_iface: str + auto_local_bypass: bool + bypass_candidates: int + force_vpn_subnets: List[str] + force_vpn_uids: List[str] + force_vpn_cgroups: List[str] + force_direct_subnets: List[str] + force_direct_uids: List[str] + force_direct_cgroups: List[str] + overrides_applied: int + cgroup_resolved_uids: int + cgroup_warning: str + active_iface: str + iface_reason: str + rule_mark: bool + rule_full: bool + table_default: bool + probe_ok: bool + probe_message: str + healthy: bool + message: str + + +@dataclass(frozen=True) +class TrafficInterfaces: + interfaces: List[str] + preferred_iface: str + active_iface: str + iface_reason: str + + + +@dataclass(frozen=True) +class TrafficCandidateSubnet: + cidr: str + dev: str + kind: str + linkdown: bool + + +@dataclass(frozen=True) +class TrafficCandidateUnit: + unit: str + description: str + cgroup: str + + +@dataclass(frozen=True) +class TrafficCandidateUID: + uid: int + user: str + examples: List[str] + + +@dataclass(frozen=True) +class TrafficCandidates: + generated_at: str + subnets: List[TrafficCandidateSubnet] + units: List[TrafficCandidateUnit] + uids: List[TrafficCandidateUID] + + +@dataclass(frozen=True) +class DnsUpstreams: + default1: str + default2: str + meta1: str + meta2: str + + +@dataclass(frozen=True) +class SmartdnsServiceState: + state: str + + +@dataclass(frozen=True) +class DNSStatus: + via_smartdns: bool + smartdns_addr: str + mode: str + unit_state: str + runtime_nftset: bool + wildcard_source: str + runtime_config_path: str + runtime_config_error: str + + +@dataclass(frozen=True) +class SmartdnsRuntimeState: + enabled: bool + applied_enabled: bool + wildcard_source: str + unit_state: str + config_path: str + changed: bool = False + restarted: bool = False + message: str = "" + + +@dataclass(frozen=True) +class DomainsTable: + lines: List[str] + + +@dataclass(frozen=True) +class DomainsFile: + name: str + content: str + source: str = "" + + +@dataclass(frozen=True) +class VpnAutoloopStatus: + raw_text: str + status_word: str + + +@dataclass(frozen=True) +class VpnStatus: + desired_location: str + status_word: str + raw_text: str + unit_state: str + + +@dataclass(frozen=True) +class VpnLocation: + label: str + iso: str + + +@dataclass(frozen=True) +class TraceDump: + lines: List[str] + + +@dataclass(frozen=True) +class Event: + id: int + kind: str + ts: str + data: Any + +# --------------------------- +# AdGuard VPN interactive login-session (PTY) +# --------------------------- + +@dataclass(frozen=True) +class LoginSessionStart: + ok: bool + phase: str + level: str + pid: Optional[int] = None + email: str = "" + error: str = "" + + +@dataclass(frozen=True) +class LoginSessionState: + ok: bool + phase: str + level: str + alive: bool + url: str + email: str + cursor: int + lines: List[str] + can_open: bool + can_check: bool + can_cancel: bool + + +@dataclass(frozen=True) +class LoginSessionAction: + ok: bool + phase: str = "" + level: str = "" + error: str = "" + +# --------------------------- +# Errors +# --------------------------- + +@dataclass(frozen=True) +class ApiError(Exception): + """Raised when API call fails (network or non-2xx).""" + message: str + method: str + url: str + status_code: Optional[int] = None + response_text: str = "" + + def __str__(self) -> str: + code = f" ({self.status_code})" if self.status_code is not None else "" + tail = f": {self.response_text}" if self.response_text else "" + return f"{self.message}{code} [{self.method} {self.url}]{tail}" + + +# --------------------------- +# Client +# --------------------------- + +TraceMode = Literal["full", "gui", "smartdns"] +ServiceAction = Literal["start", "stop", "restart"] + + +class ApiClient: + """Domain API client. + + Public methods here are the ONLY surface the dashboard/controller should use. + """ + + def __init__( + self, + base_url: str, + *, + timeout: float = 5.0, + session: Optional[requests.Session] = None, + ) -> None: + self.base_url = base_url.rstrip("/") + self.timeout = float(timeout) + self._s = session or requests.Session() + + @classmethod + def from_env( + cls, + env_var: str = "SELECTIVE_VPN_API", + default: str = "http://127.0.0.1:8080", + *, + timeout: float = 5.0, + ) -> "ApiClient": + base = os.environ.get(env_var, default).rstrip("/") + return cls(base, timeout=timeout) + + # ---- low-level internals (private) ---- + + def _url(self, path: str) -> str: + if not path.startswith("/"): + path = "/" + path + return self.base_url + path + + def _request( + self, + method: str, + path: str, + *, + params: Optional[Dict[str, Any]] = None, + json_body: Optional[Dict[str, Any]] = None, + timeout: Optional[float] = None, + accept_json: bool = True, + ) -> requests.Response: + url = self._url(path) + headers: Dict[str, str] = {} + if accept_json: + headers["Accept"] = "application/json" + + try: + resp = self._s.request( + method=method.upper(), + url=url, + params=params, + json=json_body, + timeout=self.timeout if timeout is None else float(timeout), + headers=headers, + ) + except requests.RequestException as e: + raise ApiError("API request failed", method.upper(), url, None, str(e)) from e + + if not (200 <= resp.status_code < 300): + txt = resp.text.strip() + raise ApiError("API returned error", method.upper(), url, resp.status_code, txt) + + return resp + + def _json(self, resp: requests.Response) -> Any: + if not resp.content: + return None + try: + return resp.json() + except ValueError: + # Backend should be JSON, but keep safe fallback. + return {"raw": resp.text} + + # ---- event stream (SSE) ---- + + def events_stream(self, since: int = 0, stop: Optional[Callable[[], bool]] = None) -> Iterator[Event]: + """ + Iterate over server-sent events. Reconnects automatically on errors. + + Args: + since: last seen event id (inclusive). Server will replay newer ones. + stop: optional callable returning True to stop streaming. + """ + last = max(0, int(since)) + backoff = 1.0 + while True: + if stop and stop(): + return + try: + for ev in self._sse_once(last, stop): + if stop and stop(): + return + last = ev.id if ev.id else last + yield ev + # normal end → reconnect + backoff = 1.0 + except ApiError: + # bubble up API errors; caller decides + raise + except Exception: + # transient error, retry with backoff + time.sleep(backoff) + backoff = min(backoff * 2, 10.0) + + def _sse_once(self, since: int, stop: Optional[Callable[[], bool]]) -> Iterator[Event]: + headers = { + "Accept": "text/event-stream", + "Cache-Control": "no-cache", + } + params = {} + if since > 0: + params["since"] = str(since) + + url = self._url("/api/v1/events/stream") + # SSE соединение живёт долго: backend шлёт heartbeat каждые 15s, + # поэтому ставим более длинный read-timeout, иначе стандартные 5s + # приводят к ложным ошибокам чтения. + read_timeout = max(self.timeout * 3, 60.0) + try: + resp = self._s.request( + method="GET", + url=url, + headers=headers, + params=params, + stream=True, + timeout=(self.timeout, read_timeout), + ) + except requests.RequestException as e: + raise ApiError("API request failed", "GET", url, None, str(e)) from e + + if not (200 <= resp.status_code < 300): + txt = resp.text.strip() + raise ApiError("API returned error", "GET", url, resp.status_code, txt) + + ev_id: Optional[int] = None + ev_kind: str = "" + data_lines: List[str] = [] + + for raw in resp.iter_lines(decode_unicode=True): + if stop and stop(): + resp.close() + return + if raw is None: + continue + line = raw.strip("\r") + if line == "": + if data_lines or ev_kind or ev_id is not None: + ev = self._make_event(ev_id, ev_kind, data_lines) + if ev: + yield ev + ev_id = None + ev_kind = "" + data_lines = [] + continue + if line.startswith(":"): + # heartbeat/comment + continue + if line.startswith("id:"): + try: + ev_id = int(line[3:].strip()) + except ValueError: + ev_id = None + continue + if line.startswith("event:"): + ev_kind = line[6:].strip() + continue + if line.startswith("data:"): + data_lines.append(line[5:].lstrip()) + continue + # unknown field → ignore + + def _make_event(self, ev_id: Optional[int], ev_kind: str, data_lines: List[str]) -> Optional[Event]: + payload: Any = None + if data_lines: + data_str = "\n".join(data_lines) + try: + payload = json.loads(data_str) + except Exception: + payload = data_str + if isinstance(payload, dict): + id_val = ev_id + if id_val is None: + try: + id_val = int(payload.get("id", 0)) + except Exception: + id_val = 0 + kind_val = ev_kind or str(payload.get("kind") or "") + ts_val = str(payload.get("ts") or "") + data_val = payload.get("data", payload) + return Event(id=id_val, kind=kind_val, ts=ts_val, data=data_val) + return Event(id=ev_id or 0, kind=ev_kind, ts="", data=payload) + + # ---- domain methods ---- + + # Status / system + def get_status(self) -> Status: + data = cast(Dict[str, Any], self._json(self._request("GET", "/api/v1/status")) or {}) + return Status( + timestamp=str(data.get("timestamp") or ""), + ip_count=int(data.get("ip_count") or 0), + domain_count=int(data.get("domain_count") or 0), + iface=str(data.get("iface") or ""), + table=str(data.get("table") or ""), + mark=str(data.get("mark") or ""), + policy_route_ok=cast(Optional[bool], data.get("policy_route_ok", None)), + route_ok=cast(Optional[bool], data.get("route_ok", None)), + ) + + def systemd_state(self, unit: str) -> UnitState: + data = cast( + Dict[str, Any], + self._json( + self._request("GET", "/api/v1/systemd/state", params={"unit": unit}, timeout=2.0) + ) + or {}, + ) + st = str(data.get("state") or "unknown").strip() or "unknown" + return UnitState(state=st) + + def get_login_state(self) -> LoginState: + data = cast(Dict[str, Any], self._json(self._request("GET", "/api/v1/vpn/login-state", timeout=2.0)) or {}) + # Normalize and strip ANSI + state = str(data.get("state") or "unknown").strip() + email = strip_ansi(str(data.get("email") or "").strip()) + msg = strip_ansi(str(data.get("msg") or "").strip()) + text = strip_ansi(str(data.get("text") or "").strip()) + color = str(data.get("color") or "").strip() + + return LoginState( + state=state, + email=email, + msg=msg, + text=text, + color=color, + ) + + # Routes + def routes_service(self, action: ServiceAction) -> CmdResult: + action_l = action.lower() + if action_l not in ("start", "stop", "restart"): + raise ValueError(f"Invalid action: {action}") + url = self._url("/api/v1/routes/service") + payload = {"action": action_l} + try: + # короткий read-timeout: если systemctl висит минутами, отваливаемся, + # но сервер всё равно продолжит выполнение (runCommand не привязан к r.Context()). + resp = self._s.post(url, json=payload, timeout=(self.timeout, 2.0)) + except requests.Timeout: + return CmdResult( + ok=True, + message=f"{action_l} accepted; backend is still running systemctl", + exit_code=None, + ) + except requests.RequestException as e: + raise ApiError("API request failed", "POST", url, None, str(e)) from e + + if not (200 <= resp.status_code < 300): + txt = resp.text.strip() + raise ApiError("API returned error", "POST", url, resp.status_code, txt) + + data = cast(Dict[str, Any], self._json(resp) or {}) + return self._parse_cmd_result(data) + + def routes_clear(self) -> CmdResult: + data = cast(Dict[str, Any], self._json(self._request("POST", "/api/v1/routes/clear")) or {}) + return self._parse_cmd_result(data) + + def routes_cache_restore(self) -> CmdResult: + data = cast( + Dict[str, Any], + self._json(self._request("POST", "/api/v1/routes/cache/restore")) or {}, + ) + return self._parse_cmd_result(data) + + def routes_fix_policy_route(self) -> CmdResult: + data = cast(Dict[str, Any], self._json(self._request("POST", "/api/v1/routes/fix-policy-route")) or {}) + return self._parse_cmd_result(data) + + def routes_timer_get(self) -> RoutesTimerState: + data = cast(Dict[str, Any], self._json(self._request("GET", "/api/v1/routes/timer")) or {}) + return RoutesTimerState(enabled=bool(data.get("enabled", False))) + + def routes_timer_set(self, enabled: bool) -> CmdResult: + data = cast(Dict[str, Any], self._json(self._request("POST", "/api/v1/routes/timer", json_body={"enabled": bool(enabled)})) or {}) + return self._parse_cmd_result(data) + + def traffic_mode_get(self) -> TrafficModeStatus: + data = cast(Dict[str, Any], self._json(self._request("GET", "/api/v1/traffic/mode")) or {}) + return TrafficModeStatus( + mode=str(data.get("mode") or "selective"), + desired_mode=str(data.get("desired_mode") or data.get("mode") or "selective"), + applied_mode=str(data.get("applied_mode") or "direct"), + preferred_iface=str(data.get("preferred_iface") or ""), + auto_local_bypass=bool(data.get("auto_local_bypass", True)), + bypass_candidates=int(data.get("bypass_candidates", 0) or 0), + force_vpn_subnets=[str(x) for x in (data.get("force_vpn_subnets") or []) if str(x).strip()], + force_vpn_uids=[str(x) for x in (data.get("force_vpn_uids") or []) if str(x).strip()], + force_vpn_cgroups=[str(x) for x in (data.get("force_vpn_cgroups") or []) if str(x).strip()], + force_direct_subnets=[str(x) for x in (data.get("force_direct_subnets") or []) if str(x).strip()], + force_direct_uids=[str(x) for x in (data.get("force_direct_uids") or []) if str(x).strip()], + force_direct_cgroups=[str(x) for x in (data.get("force_direct_cgroups") or []) if str(x).strip()], + overrides_applied=int(data.get("overrides_applied", 0) or 0), + cgroup_resolved_uids=int(data.get("cgroup_resolved_uids", 0) or 0), + cgroup_warning=str(data.get("cgroup_warning") or ""), + active_iface=str(data.get("active_iface") or ""), + iface_reason=str(data.get("iface_reason") or ""), + rule_mark=bool(data.get("rule_mark", False)), + rule_full=bool(data.get("rule_full", False)), + table_default=bool(data.get("table_default", False)), + probe_ok=bool(data.get("probe_ok", False)), + probe_message=str(data.get("probe_message") or ""), + healthy=bool(data.get("healthy", False)), + message=str(data.get("message") or ""), + ) + + def traffic_mode_set( + self, + mode: str, + preferred_iface: Optional[str] = None, + auto_local_bypass: Optional[bool] = None, + force_vpn_subnets: Optional[List[str]] = None, + force_vpn_uids: Optional[List[str]] = None, + force_vpn_cgroups: Optional[List[str]] = None, + force_direct_subnets: Optional[List[str]] = None, + force_direct_uids: Optional[List[str]] = None, + force_direct_cgroups: Optional[List[str]] = None, + ) -> TrafficModeStatus: + m = str(mode or "").strip().lower() + if m not in ("selective", "full_tunnel", "direct"): + raise ValueError(f"Invalid traffic mode: {mode}") + payload: Dict[str, Any] = {"mode": m} + if preferred_iface is not None: + payload["preferred_iface"] = str(preferred_iface).strip() + if auto_local_bypass is not None: + payload["auto_local_bypass"] = bool(auto_local_bypass) + if force_vpn_subnets is not None: + payload["force_vpn_subnets"] = [str(x) for x in force_vpn_subnets] + if force_vpn_uids is not None: + payload["force_vpn_uids"] = [str(x) for x in force_vpn_uids] + if force_vpn_cgroups is not None: + payload["force_vpn_cgroups"] = [str(x) for x in force_vpn_cgroups] + if force_direct_subnets is not None: + payload["force_direct_subnets"] = [str(x) for x in force_direct_subnets] + if force_direct_uids is not None: + payload["force_direct_uids"] = [str(x) for x in force_direct_uids] + if force_direct_cgroups is not None: + payload["force_direct_cgroups"] = [str(x) for x in force_direct_cgroups] + data = cast( + Dict[str, Any], + self._json( + self._request( + "POST", + "/api/v1/traffic/mode", + json_body=payload, + ) + ) + or {}, + ) + return TrafficModeStatus( + mode=str(data.get("mode") or m), + desired_mode=str(data.get("desired_mode") or data.get("mode") or m), + applied_mode=str(data.get("applied_mode") or "direct"), + preferred_iface=str(data.get("preferred_iface") or ""), + auto_local_bypass=bool(data.get("auto_local_bypass", True)), + bypass_candidates=int(data.get("bypass_candidates", 0) or 0), + force_vpn_subnets=[str(x) for x in (data.get("force_vpn_subnets") or []) if str(x).strip()], + force_vpn_uids=[str(x) for x in (data.get("force_vpn_uids") or []) if str(x).strip()], + force_vpn_cgroups=[str(x) for x in (data.get("force_vpn_cgroups") or []) if str(x).strip()], + force_direct_subnets=[str(x) for x in (data.get("force_direct_subnets") or []) if str(x).strip()], + force_direct_uids=[str(x) for x in (data.get("force_direct_uids") or []) if str(x).strip()], + force_direct_cgroups=[str(x) for x in (data.get("force_direct_cgroups") or []) if str(x).strip()], + overrides_applied=int(data.get("overrides_applied", 0) or 0), + cgroup_resolved_uids=int(data.get("cgroup_resolved_uids", 0) or 0), + cgroup_warning=str(data.get("cgroup_warning") or ""), + active_iface=str(data.get("active_iface") or ""), + iface_reason=str(data.get("iface_reason") or ""), + rule_mark=bool(data.get("rule_mark", False)), + rule_full=bool(data.get("rule_full", False)), + table_default=bool(data.get("table_default", False)), + probe_ok=bool(data.get("probe_ok", False)), + probe_message=str(data.get("probe_message") or ""), + healthy=bool(data.get("healthy", False)), + message=str(data.get("message") or ""), + ) + + def traffic_mode_test(self) -> TrafficModeStatus: + data = cast( + Dict[str, Any], + self._json(self._request("GET", "/api/v1/traffic/mode/test")) or {}, + ) + return TrafficModeStatus( + mode=str(data.get("mode") or "selective"), + desired_mode=str(data.get("desired_mode") or data.get("mode") or "selective"), + applied_mode=str(data.get("applied_mode") or "direct"), + preferred_iface=str(data.get("preferred_iface") or ""), + auto_local_bypass=bool(data.get("auto_local_bypass", True)), + bypass_candidates=int(data.get("bypass_candidates", 0) or 0), + force_vpn_subnets=[str(x) for x in (data.get("force_vpn_subnets") or []) if str(x).strip()], + force_vpn_uids=[str(x) for x in (data.get("force_vpn_uids") or []) if str(x).strip()], + force_vpn_cgroups=[str(x) for x in (data.get("force_vpn_cgroups") or []) if str(x).strip()], + force_direct_subnets=[str(x) for x in (data.get("force_direct_subnets") or []) if str(x).strip()], + force_direct_uids=[str(x) for x in (data.get("force_direct_uids") or []) if str(x).strip()], + force_direct_cgroups=[str(x) for x in (data.get("force_direct_cgroups") or []) if str(x).strip()], + overrides_applied=int(data.get("overrides_applied", 0) or 0), + cgroup_resolved_uids=int(data.get("cgroup_resolved_uids", 0) or 0), + cgroup_warning=str(data.get("cgroup_warning") or ""), + active_iface=str(data.get("active_iface") or ""), + iface_reason=str(data.get("iface_reason") or ""), + rule_mark=bool(data.get("rule_mark", False)), + rule_full=bool(data.get("rule_full", False)), + table_default=bool(data.get("table_default", False)), + probe_ok=bool(data.get("probe_ok", False)), + probe_message=str(data.get("probe_message") or ""), + healthy=bool(data.get("healthy", False)), + message=str(data.get("message") or ""), + ) + + def traffic_interfaces_get(self) -> TrafficInterfaces: + data = cast( + Dict[str, Any], + self._json(self._request("GET", "/api/v1/traffic/interfaces")) or {}, + ) + raw = data.get("interfaces") or [] + if not isinstance(raw, list): + raw = [] + return TrafficInterfaces( + interfaces=[str(x) for x in raw if str(x).strip()], + preferred_iface=str(data.get("preferred_iface") or ""), + active_iface=str(data.get("active_iface") or ""), + iface_reason=str(data.get("iface_reason") or ""), + ) + + def traffic_candidates_get(self) -> TrafficCandidates: + data = cast( + Dict[str, Any], + self._json(self._request("GET", "/api/v1/traffic/candidates")) or {}, + ) + + subnets: List[TrafficCandidateSubnet] = [] + for it in (data.get("subnets") or []): + if not isinstance(it, dict): + continue + cidr = str(it.get("cidr") or "").strip() + if not cidr: + continue + subnets.append( + TrafficCandidateSubnet( + cidr=cidr, + dev=str(it.get("dev") or "").strip(), + kind=str(it.get("kind") or "").strip(), + linkdown=bool(it.get("linkdown", False)), + ) + ) + + units: List[TrafficCandidateUnit] = [] + for it in (data.get("units") or []): + if not isinstance(it, dict): + continue + unit = str(it.get("unit") or "").strip() + if not unit: + continue + units.append( + TrafficCandidateUnit( + unit=unit, + description=str(it.get("description") or "").strip(), + cgroup=str(it.get("cgroup") or "").strip(), + ) + ) + + uids: List[TrafficCandidateUID] = [] + for it in (data.get("uids") or []): + if not isinstance(it, dict): + continue + try: + uid = int(it.get("uid", 0) or 0) + except Exception: + continue + user = str(it.get("user") or "").strip() + raw_ex = it.get("examples") or [] + if not isinstance(raw_ex, list): + raw_ex = [] + examples = [str(x) for x in raw_ex if str(x).strip()] + uids.append(TrafficCandidateUID(uid=uid, user=user, examples=examples)) + + return TrafficCandidates( + generated_at=str(data.get("generated_at") or ""), + subnets=subnets, + units=units, + uids=uids, + ) + + + # DNS / SmartDNS + def dns_upstreams_get(self) -> DnsUpstreams: + data = cast(Dict[str, Any], self._json(self._request("GET", "/api/v1/dns-upstreams")) or {}) + return DnsUpstreams( + default1=str(data.get("default1") or ""), + default2=str(data.get("default2") or ""), + meta1=str(data.get("meta1") or ""), + meta2=str(data.get("meta2") or ""), + ) + + def dns_upstreams_set(self, cfg: DnsUpstreams) -> None: + self._request( + "POST", + "/api/v1/dns-upstreams", + json_body={ + "default1": cfg.default1, + "default2": cfg.default2, + "meta1": cfg.meta1, + "meta2": cfg.meta2, + }, + ) + + def dns_status_get(self) -> DNSStatus: + data = cast(Dict[str, Any], self._json(self._request("GET", "/api/v1/dns/status")) or {}) + return self._parse_dns_status(data) + + def dns_mode_set(self, via_smartdns: bool, smartdns_addr: str) -> DNSStatus: + mode = "hybrid_wildcard" if bool(via_smartdns) else "direct" + data = cast( + Dict[str, Any], + self._json( + self._request( + "POST", + "/api/v1/dns/mode", + json_body={ + "via_smartdns": bool(via_smartdns), + "smartdns_addr": str(smartdns_addr or ""), + "mode": mode, + }, + ) + ) + or {}, + ) + return self._parse_dns_status(data) + + def dns_smartdns_service_set(self, action: ServiceAction) -> DNSStatus: + act = action.lower() + if act not in ("start", "stop", "restart"): + raise ValueError(f"Invalid action: {action}") + data = cast( + Dict[str, Any], + self._json( + self._request( + "POST", + "/api/v1/dns/smartdns-service", + json_body={"action": act}, + ) + ) + or {}, + ) + if not bool(data.get("ok", False)): + raise ValueError(str(data.get("message") or f"SmartDNS {act} failed")) + return self._parse_dns_status(data) + + def smartdns_service_get(self) -> SmartdnsServiceState: + data = cast(Dict[str, Any], self._json(self._request("GET", "/api/v1/smartdns/service")) or {}) + return SmartdnsServiceState(state=str(data.get("state") or "unknown")) + + def smartdns_service_set(self, action: ServiceAction) -> CmdResult: + act = action.lower() + if act not in ("start", "stop", "restart"): + raise ValueError(f"Invalid action: {action}") + data = cast(Dict[str, Any], self._json(self._request("POST", "/api/v1/smartdns/service", json_body={"action": act})) or {}) + return self._parse_cmd_result(data) + + def smartdns_runtime_get(self) -> SmartdnsRuntimeState: + data = cast(Dict[str, Any], self._json(self._request("GET", "/api/v1/smartdns/runtime")) or {}) + return SmartdnsRuntimeState( + enabled=bool(data.get("enabled", False)), + applied_enabled=bool(data.get("applied_enabled", False)), + wildcard_source=str(data.get("wildcard_source") or ("both" if bool(data.get("enabled", False)) else "resolver")), + unit_state=str(data.get("unit_state") or "unknown"), + config_path=str(data.get("config_path") or ""), + changed=bool(data.get("changed", False)), + restarted=bool(data.get("restarted", False)), + message=str(data.get("message") or ""), + ) + + def smartdns_runtime_set(self, enabled: bool, restart: bool = True) -> SmartdnsRuntimeState: + data = cast( + Dict[str, Any], + self._json( + self._request( + "POST", + "/api/v1/smartdns/runtime", + json_body={"enabled": bool(enabled), "restart": bool(restart)}, + ) + ) + or {}, + ) + return SmartdnsRuntimeState( + enabled=bool(data.get("enabled", False)), + applied_enabled=bool(data.get("applied_enabled", False)), + wildcard_source=str(data.get("wildcard_source") or ("both" if bool(data.get("enabled", False)) else "resolver")), + unit_state=str(data.get("unit_state") or "unknown"), + config_path=str(data.get("config_path") or ""), + changed=bool(data.get("changed", False)), + restarted=bool(data.get("restarted", False)), + message=str(data.get("message") or ""), + ) + + def smartdns_prewarm(self, limit: int = 0, aggressive_subs: bool = False) -> CmdResult: + payload: Dict[str, Any] = {} + if int(limit) > 0: + payload["limit"] = int(limit) + if aggressive_subs: + payload["aggressive_subs"] = True + data = cast( + Dict[str, Any], + self._json(self._request("POST", "/api/v1/smartdns/prewarm", json_body=payload)) or {}, + ) + return self._parse_cmd_result(data) + + def _parse_dns_status(self, data: Dict[str, Any]) -> DNSStatus: + via = bool(data.get("via_smartdns", False)) + runtime = bool(data.get("runtime_nftset", True)) + return DNSStatus( + via_smartdns=via, + smartdns_addr=str(data.get("smartdns_addr") or ""), + mode=str(data.get("mode") or ("hybrid_wildcard" if via else "direct")), + unit_state=str(data.get("unit_state") or "unknown"), + runtime_nftset=runtime, + wildcard_source=str(data.get("wildcard_source") or ("both" if runtime else "resolver")), + runtime_config_path=str(data.get("runtime_config_path") or ""), + runtime_config_error=str(data.get("runtime_config_error") or ""), + ) + + # Domains editor + def domains_table(self) -> DomainsTable: + data = cast(Dict[str, Any], self._json(self._request("GET", "/api/v1/domains/table")) or {}) + lines = data.get("lines") or [] + if not isinstance(lines, list): + lines = [] + return DomainsTable(lines=[str(x) for x in lines]) + + def domains_file_get(self, name: Literal["bases", "meta", "subs", "static", "smartdns", "last-ips-map", "last-ips-map-direct", "last-ips-map-wildcard"]) -> DomainsFile: + data = cast(Dict[str, Any], self._json(self._request("GET", "/api/v1/domains/file", params={"name": name})) or {}) + content = str(data.get("content") or "") + source = str(data.get("source") or "") + return DomainsFile(name=name, content=content, source=source) + + def domains_file_set(self, name: Literal["bases", "meta", "subs", "static", "smartdns", "last-ips-map", "last-ips-map-direct", "last-ips-map-wildcard"], content: str) -> None: + self._request("POST", "/api/v1/domains/file", json_body={"name": name, "content": content}) + + # VPN + def vpn_autoloop_status(self) -> VpnAutoloopStatus: + data = cast(Dict[str, Any], self._json(self._request("GET", "/api/v1/vpn/autoloop-status", timeout=2.0)) or {}) + raw = strip_ansi(str(data.get("raw_text") or "").strip()) + word = str(data.get("status_word") or "unknown").strip() + return VpnAutoloopStatus(raw_text=raw, status_word=word) + + def vpn_status(self) -> VpnStatus: + data = cast(Dict[str, Any], self._json(self._request("GET", "/api/v1/vpn/status", timeout=2.0)) or {}) + return VpnStatus( + desired_location=str(data.get("desired_location") or "").strip(), + status_word=str(data.get("status_word") or "unknown").strip(), + raw_text=strip_ansi(str(data.get("raw_text") or "").strip()), + unit_state=str(data.get("unit_state") or "unknown").strip(), + ) + + def vpn_autoconnect(self, enable: bool) -> CmdResult: + action = "start" if enable else "stop" + data = cast(Dict[str, Any], self._json(self._request("POST", "/api/v1/vpn/autoconnect", json_body={"action": action})) or {}) + return self._parse_cmd_result(data) + + def vpn_locations(self) -> List[VpnLocation]: + data = cast(Dict[str, Any], self._json(self._request("GET", "/api/v1/vpn/locations", timeout=10.0)) or {}) + locs = data.get("locations") or [] + res: List[VpnLocation] = [] + if isinstance(locs, list): + for item in locs: + if isinstance(item, dict): + label = str(item.get("label") or "") + iso = str(item.get("iso") or "") + if label and iso: + res.append(VpnLocation(label=label, iso=iso)) + return res + + def vpn_set_location(self, iso: str) -> None: + val = str(iso).strip() + if not val: + raise ValueError("iso is required") + self._request("POST", "/api/v1/vpn/location", json_body={"iso": val}) + + # Trace + def trace_get(self, mode: TraceMode = "full") -> TraceDump: + m = str(mode).lower().strip() + if m not in ("full", "gui", "smartdns"): + m = "full" + data = cast(Dict[str, Any], self._json(self._request("GET", "/api/v1/trace-json", params={"mode": m}, timeout=5.0)) or {}) + lines = data.get("lines") or [] + if not isinstance(lines, list): + lines = [] + return TraceDump(lines=[strip_ansi(str(x)) for x in lines]) + + def trace_append(self, kind: Literal["gui", "smartdns", "info"], line: str) -> None: + try: + self._request("POST", "/api/v1/trace/append", json_body={"kind": kind, "line": str(line)}, timeout=2.0) + except ApiError: + # Logging must never crash UI. + pass + + # ---- AdGuard VPN interactive login-session (NEW) ---- + + def vpn_login_session_start(self) -> LoginSessionStart: + data = cast( + Dict[str, Any], + self._json(self._request("POST", "/api/v1/vpn/login/session/start", timeout=10.0)) or {}, + ) + pid_val = data.get("pid", None) + pid: Optional[int] + try: + pid = int(pid_val) if pid_val is not None else None + except (TypeError, ValueError): + pid = None + + return LoginSessionStart( + ok=bool(data.get("ok", False)), + phase=str(data.get("phase") or ""), + level=str(data.get("level") or ""), + pid=pid, + email=strip_ansi(str(data.get("email") or "").strip()), + error=strip_ansi(str(data.get("error") or "").strip()), + ) + + def vpn_login_session_state(self, since: int = 0) -> LoginSessionState: + since_i = int(since) if since is not None else 0 + data = cast( + Dict[str, Any], + self._json( + self._request( + "GET", + "/api/v1/vpn/login/session/state", + params={"since": str(max(0, since_i))}, + timeout=5.0, + ) + ) + or {}, + ) + + lines = data.get("lines") or [] + if not isinstance(lines, list): + lines = [] + + cursor_val = data.get("cursor", 0) + try: + cursor = int(cursor_val) + except (TypeError, ValueError): + cursor = 0 + + return LoginSessionState( + ok=bool(data.get("ok", False)), + phase=str(data.get("phase") or ""), + level=str(data.get("level") or ""), + alive=bool(data.get("alive", False)), + url=strip_ansi(str(data.get("url") or "").strip()), + email=strip_ansi(str(data.get("email") or "").strip()), + cursor=cursor, + lines=[strip_ansi(str(x)) for x in lines], + can_open=bool(data.get("can_open", False)), + can_check=bool(data.get("can_check", False)), + can_cancel=bool(data.get("can_cancel", False)), + ) + + def vpn_login_session_action(self, action: Literal["open", "check", "cancel"]) -> LoginSessionAction: + act = str(action).strip().lower() + if act not in ("open", "check", "cancel"): + raise ValueError(f"Invalid login-session action: {action}") + + data = cast( + Dict[str, Any], + self._json( + self._request( + "POST", + "/api/v1/vpn/login/session/action", + json_body={"action": act}, + timeout=10.0, + ) + ) + or {}, + ) + + # backend может вернуть {ok:false,error:"..."} без phase/level + return LoginSessionAction( + ok=bool(data.get("ok", False)), + phase=str(data.get("phase") or ""), + level=str(data.get("level") or ""), + error=strip_ansi(str(data.get("error") or "").strip()), + ) + + def vpn_login_session_stop(self) -> CmdResult: + # stop returns {"ok": true} — завернём в CmdResult, чтобы UI/Controller единообразно печатал + data = cast( + Dict[str, Any], + self._json(self._request("POST", "/api/v1/vpn/login/session/stop", timeout=10.0)) or {}, + ) + ok = bool(data.get("ok", False)) + return CmdResult(ok=ok, message="login session stopped" if ok else "failed to stop login session") + + def vpn_logout(self) -> CmdResult: + data = cast(Dict[str, Any], self._json(self._request("POST", "/api/v1/vpn/logout", timeout=20.0)) or {}) + return self._parse_cmd_result(data) + + # ---- helpers ---- + + def _parse_cmd_result(self, data: Dict[str, Any]) -> CmdResult: + ok = bool(data.get("ok", False)) + msg = str(data.get("message") or "") + exit_code_val = data.get("exitCode", None) + exit_code: Optional[int] + try: + exit_code = int(exit_code_val) if exit_code_val is not None else None + except (TypeError, ValueError): + exit_code = None + + stdout = strip_ansi(str(data.get("stdout") or "")) + stderr = strip_ansi(str(data.get("stderr") or "")) + return CmdResult(ok=ok, message=msg, exit_code=exit_code, stdout=stdout, stderr=stderr) diff --git a/selective-vpn-gui/dashboard_controller.py b/selective-vpn-gui/dashboard_controller.py new file mode 100644 index 0000000..ded7fbd --- /dev/null +++ b/selective-vpn-gui/dashboard_controller.py @@ -0,0 +1,847 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +DashboardController + +Тонкий "мозг" между UI и ApiClient. + +UI не должен знать URL'ы / JSON, только вызывать методы этого контроллера. +""" + +from __future__ import annotations + +from dataclasses import dataclass +import os +import re +from typing import Iterable, List, Literal, Optional, cast + +# вырезаем спам автопроверки из логов (CLI любит писать "Next check in ...") +_NEXT_CHECK_RE = re.compile( + r"(?:\b\d+s\.)?\s*Next check in\s+\d+s\.?", re.IGNORECASE +) + +from api_client import ( + ApiClient, + CmdResult, + DNSStatus, + DnsUpstreams, + DomainsFile, + DomainsTable, + Event, + LoginState, + Status, + TrafficCandidates, + TrafficInterfaces, + TrafficModeStatus, + TraceDump, + UnitState, + VpnLocation, + VpnStatus, + SmartdnsRuntimeState, + # login flow models + LoginSessionStart, + LoginSessionState, + LoginSessionAction, +) + +TraceMode = Literal["full", "gui", "smartdns"] +ServiceAction = Literal["start", "stop", "restart"] +LoginAction = Literal["open", "check", "cancel"] + + +# --------------------------- +# View models (UI-friendly) +# --------------------------- + +@dataclass(frozen=True) +class LoginView: + text: str + color: str + logged_in: bool + email: str + + +@dataclass(frozen=True) +class StatusOverviewView: + timestamp: str + counts: str + iface_table_mark: str + policy_route: str + routes_service: str + smartdns_service: str + vpn_service: str + + +@dataclass(frozen=True) +class VpnStatusView: + desired_location: str + pretty_text: str + + +@dataclass(frozen=True) +class ActionView: + ok: bool + pretty_text: str + + +@dataclass(frozen=True) +class LoginFlowView: + phase: str + level: str + dot_color: str + status_text: str + url: str + email: str + alive: bool + cursor: int + lines: List[str] + can_open: bool + can_check: bool + can_cancel: bool + + +@dataclass(frozen=True) +class VpnAutoconnectView: + """Для блока Autoconnect на вкладке AdGuardVPN.""" + enabled: bool # True = включён autoloop + unit_text: str # строка вида "unit: active" + color: str # "green" / "red" / "orange" + + +@dataclass(frozen=True) +class RoutesNftProgressView: + """Прогресс обновления nft-наборов (agvpn4).""" + percent: int + message: str + active: bool # True — пока идёт апдейт, False — когда закончили / ничего не идёт + + +@dataclass(frozen=True) +class TrafficModeView: + desired_mode: str + applied_mode: str + preferred_iface: str + auto_local_bypass: bool + bypass_candidates: int + force_vpn_subnets: List[str] + force_vpn_uids: List[str] + force_vpn_cgroups: List[str] + force_direct_subnets: List[str] + force_direct_uids: List[str] + force_direct_cgroups: List[str] + overrides_applied: int + cgroup_resolved_uids: int + cgroup_warning: str + active_iface: str + iface_reason: str + probe_ok: bool + probe_message: str + healthy: bool + message: str + + +# --------------------------- +# Controller +# --------------------------- + +class DashboardController: + def __init__( + self, + client: ApiClient, + *, + routes_unit: Optional[str] = None, + smartdns_unit: Optional[str] = None, + ) -> None: + self.client = client + self.routes_unit = ( + routes_unit + or os.environ.get("SELECTIVE_VPN_ROUTES_UNIT") + or "" + ) + self.smartdns_unit = ( + smartdns_unit + or os.environ.get("SELECTIVE_VPN_SMARTDNS_UNIT") + or "smartdns-local.service" + ) + + # -------- logging -------- + + def log_gui(self, msg: str) -> None: + self.client.trace_append("gui", msg) + + def log_smartdns(self, msg: str) -> None: + self.client.trace_append("smartdns", msg) + + # -------- events stream -------- + + def iter_events(self, since: int = 0, stop=None): + return self.client.events_stream(since=since, stop=stop) + + def classify_event(self, ev: Event) -> List[str]: + """Return list of areas to refresh for given event kind.""" + k = (ev.kind or "").strip().lower() + if not k: + return [] + if k in ("status_changed", "status_error"): + return ["status", "routes", "vpn"] + if k in ("login_state_changed", "login_state_error"): + return ["login", "vpn"] + if k == "autoloop_status_changed": + return ["vpn"] + if k == "unit_state_changed": + return ["status", "vpn", "routes", "dns"] + if k in ("trace_changed", "trace_append"): + return ["trace"] + if k == "routes_nft_progress": + # перерисовать блок "routes" (кнопки + прогресс) + return ["routes"] + if k == "traffic_mode_changed": + return ["routes", "status"] + return [] + + # -------- helpers -------- + + def _is_logged_in_state(self, st: LoginState) -> bool: + # backend “state” может быть любым, делаем устойчивую проверку + s = (st.state or "").strip().lower() + if st.email: + return True + if s in ("ok", "logged", "logged_in", "success", "authorized", "ready"): + return True + return False + + def _level_to_color(self, level: str) -> str: + lv = (level or "").strip().lower() + if lv in ("green", "ok", "true", "success"): + return "green" + if lv in ("red", "error", "false", "failed"): + return "red" + return "orange" + + # -------- overview / status -------- + + def get_login_view(self) -> LoginView: + st: LoginState = self.client.get_login_state() + + # Prefer backend UI-ready "text" if provided, else build it. + if st.text: + txt = st.text + else: + if st.email: + txt = f"AdGuard VPN: logged in as {st.email}" + else: + txt = "AdGuard VPN: (no login data)" + + logged_in = self._is_logged_in_state(st) + + # Цвет: либо из backend, либо простой нормализованный вариант + if st.color: + color = st.color + else: + if logged_in: + color = "green" + else: + s = (st.state or "").strip().lower() + color = "orange" if s in ("unknown", "checking") else "red" + + return LoginView( + text=txt, + color=color, + logged_in=logged_in, + email=st.email or "", + ) + + def get_status_overview(self) -> StatusOverviewView: + st: Status = self.client.get_status() + + routes_unit = self._resolve_routes_unit(st.iface) + routes_s: UnitState = ( + self.client.systemd_state(routes_unit) + if routes_unit + else UnitState(state="unknown") + ) + smartdns_s: UnitState = self.client.systemd_state(self.smartdns_unit) + vpn_st: VpnStatus = self.client.vpn_status() + + counts = f"domains={st.domain_count}, ips={st.ip_count}" + iface = f"iface={st.iface} table={st.table} mark={st.mark}" + + policy_route = self._format_policy_route(st.policy_route_ok, st.route_ok) + + # SmartDNS: если state пустой/unknown — считаем это ошибкой + smart_state = smartdns_s.state or "unknown" + if smart_state.lower() in ("", "unknown", "failed"): + smart_state = "ERROR (unknown state)" + + return StatusOverviewView( + timestamp=st.timestamp or "—", + counts=counts, + iface_table_mark=iface, + policy_route=policy_route, + routes_service=f"{routes_unit or 'selective-vpn2@.service'}: {routes_s.state}", + smartdns_service=f"{self.smartdns_unit}: {smart_state}", + # это состояние самого VPN-юнита, НЕ autoloop: + # т.е. работает ли AdGuardVPN-daemon / туннель + vpn_service=f"VPN: {vpn_st.unit_state}", + ) + + def _format_policy_route( + self, + policy_ok: Optional[bool], + route_ok: Optional[bool], + ) -> str: + if policy_ok is None and route_ok is None: + return "unknown (not checked)" + val = policy_ok if policy_ok is not None else route_ok + if val is True: + return "OK (default route present in VPN table)" + return "MISSING default route in VPN table" + + def _resolve_routes_unit(self, iface: str) -> str: + forced = (self.routes_unit or "").strip() + if forced: + return forced + ifc = (iface or "").strip() + if ifc and ifc != "-": + return f"selective-vpn2@{ifc}.service" + return "" + + # -------- VPN -------- + + def vpn_locations_view(self) -> List[VpnLocation]: + return self.client.vpn_locations() + + def vpn_status_view(self) -> VpnStatusView: + st = self.client.vpn_status() + pretty = self._pretty_vpn_status(st) + return VpnStatusView( + desired_location=st.desired_location, + pretty_text=pretty, + ) + + # --- autoconnect / autoloop --- + + def _autoconnect_from_auto(self, auto) -> bool: + """ + Вытаскиваем True/False из ответа /vpn/autoloop/status. + + Приоритет: + 1) явное поле auto.enabled (bool) + 2) эвристика по status_word / raw_text + """ + enabled_field = getattr(auto, "enabled", None) + if isinstance(enabled_field, bool): + return enabled_field + + word = (getattr(auto, "status_word", "") or "").strip().lower() + raw = (getattr(auto, "raw_text", "") or "").lower() + + # приоритет — явные статусы + if word in ( + "active", + "running", + "enabled", + "on", + "up", + "started", + "ok", + "true", + "yes", + ): + return True + if word in ("inactive", "stopped", "disabled", "off", "down", "false", "no"): + return False + + # фоллбек — по raw_text + if "inactive" in raw or "disabled" in raw or "failed" in raw: + return False + if "active" in raw or "running" in raw or "enabled" in raw: + return True + return False + + def vpn_autoconnect_view(self) -> VpnAutoconnectView: + try: + auto = self.client.vpn_autoloop_status() + except Exception as e: + return VpnAutoconnectView( + enabled=False, + unit_text=f"unit: ERROR ({e})", + color="red", + ) + + enabled = self._autoconnect_from_auto(auto) + + unit_state = ( + getattr(auto, "unit_state", "") # если backend так отдаёт + or (auto.status_word or "") + or "unknown" + ) + + text = f"unit: {unit_state}" + + low = f"{unit_state} {(auto.raw_text or '')}".lower() + if any(x in low for x in ("failed", "error", "unknown", "inactive", "dead")): + color = "red" + elif "active" in low or "running" in low or "enabled" in low: + color = "green" + else: + color = "orange" + + return VpnAutoconnectView(enabled=enabled, unit_text=text, color=color) + + def vpn_autoconnect_enabled(self) -> bool: + """Старый интерфейс — оставляем для кнопки toggle.""" + return self.vpn_autoconnect_view().enabled + + def vpn_set_autoconnect(self, enable: bool) -> VpnStatusView: + res = self.client.vpn_autoconnect(enable) + st = self.client.vpn_status() + pretty = self._pretty_cmd_then_status(res, st) + return VpnStatusView( + desired_location=st.desired_location, + pretty_text=pretty, + ) + + def vpn_set_location(self, iso: str) -> VpnStatusView: + self.client.vpn_set_location(iso) + st = self.client.vpn_status() + pretty = self._pretty_vpn_status(st) + return VpnStatusView( + desired_location=st.desired_location, + pretty_text=pretty, + ) + + def _pretty_vpn_status(self, st: VpnStatus) -> str: + lines = [ + f"unit_state: {st.unit_state}", + f"desired_location: {st.desired_location or '—'}", + f"status: {st.status_word}", + ] + if st.raw_text: + lines.append("") + lines.append(st.raw_text.strip()) + return "\n".join(lines).strip() + "\n" + + # -------- Login Flow (interactive) -------- + + def login_flow_start(self) -> LoginFlowView: + s: LoginSessionStart = self.client.vpn_login_session_start() + + dot = self._level_to_color(s.level) + + if not s.ok: + txt = s.error or "Failed to start login session" + return LoginFlowView( + phase=s.phase or "failed", + level=s.level or "red", + dot_color="red", + status_text=txt, + url="", + email="", + alive=False, + cursor=0, + lines=[txt], + can_open=False, + can_check=False, + can_cancel=False, + ) + + if (s.phase or "").lower() == "already_logged": + txt = ( + f"Already logged in as {s.email}" + if s.email + else "Already logged in" + ) + return LoginFlowView( + phase="already_logged", + level="green", + dot_color="green", + status_text=txt, + url="", + email=s.email or "", + alive=False, + cursor=0, + lines=[txt], + can_open=False, + can_check=False, + can_cancel=False, + ) + + txt = f"Login started (pid={s.pid})" if s.pid else "Login started" + return LoginFlowView( + phase=s.phase or "starting", + level=s.level or "yellow", + dot_color=dot, + status_text=txt, + url="", + email="", + alive=True, + cursor=0, + lines=[], + can_open=True, + can_check=True, + can_cancel=True, + ) + + def login_flow_poll(self, since: int) -> LoginFlowView: + st: LoginSessionState = self.client.vpn_login_session_state(since=since) + + dot = self._level_to_color(st.level) + + phase = (st.phase or "").lower() + if phase == "waiting_browser": + status_txt = "Waiting for browser authorization…" + elif phase == "checking": + status_txt = "Checking…" + elif phase == "success": + status_txt = "✅ Logged in" + elif phase == "failed": + status_txt = "❌ Login failed" + elif phase == "cancelled": + status_txt = "Cancelled" + elif phase == "already_logged": + status_txt = ( + f"Already logged in as {st.email}" + if st.email + else "Already logged in" + ) + else: + status_txt = st.phase or "…" + + clean_lines = self._clean_login_lines(st.lines) + + return LoginFlowView( + phase=st.phase, + level=st.level, + dot_color=dot, + status_text=status_txt, + url=st.url, + email=st.email, + alive=st.alive, + cursor=st.cursor, + can_open=st.can_open, + can_check=st.can_cancel, + can_cancel=st.can_cancel, + lines=clean_lines, + ) + + def login_flow_action(self, action: str) -> ActionView: + act = action.strip().lower() + if act not in ("open", "check", "cancel"): + raise ValueError(f"Invalid login action: {action}") + + res: LoginSessionAction = self.client.vpn_login_session_action( + cast(LoginAction, act) + ) + + if not res.ok: + txt = res.error or "Login action failed" + return ActionView(ok=False, pretty_text=txt + "\n") + + txt = f"OK: {act} → phase={res.phase} level={res.level}" + return ActionView(ok=True, pretty_text=txt + "\n") + + def login_flow_stop(self) -> ActionView: + res = self.client.vpn_login_session_stop() + return ActionView(ok=res.ok, pretty_text=self._pretty_cmd(res)) + + def vpn_logout(self) -> ActionView: + res = self.client.vpn_logout() + return ActionView(ok=res.ok, pretty_text=self._pretty_cmd(res)) + + # Баннер "AdGuard VPN: logged in as ...", по клику показываем инфу как в CLI + def login_banner_cli_text(self) -> str: + try: + st: LoginState = self.client.get_login_state() + except Exception as e: + return f"Failed to query login state: {e}" + + # backend может не иметь поля error, поэтому через getattr + err = getattr(st, "error", None) or getattr(st, "message", None) + if err: + return str(err) + + if st.email: + return f"You are already logged in.\nCurrent user is {st.email}" + + if st.state: + return f"Login state: {st.state}" + + return "No login information available." + + # -------- Routes -------- + + def routes_service_action(self, action: str) -> ActionView: + act = action.strip().lower() + if act not in ("start", "stop", "restart"): + raise ValueError(f"Invalid routes action: {action}") + res = self.client.routes_service(cast(ServiceAction, act)) + return ActionView(ok=res.ok, pretty_text=self._pretty_cmd(res)) + + def routes_clear(self) -> ActionView: + res = self.client.routes_clear() + return ActionView(ok=res.ok, pretty_text=self._pretty_cmd(res)) + + def routes_cache_restore(self) -> ActionView: + res = self.client.routes_cache_restore() + return ActionView(ok=res.ok, pretty_text=self._pretty_cmd(res)) + + def routes_fix_policy_route(self) -> ActionView: + res = self.client.routes_fix_policy_route() + return ActionView(ok=res.ok, pretty_text=self._pretty_cmd(res)) + + def routes_timer_enabled(self) -> bool: + st = self.client.routes_timer_get() + return bool(st.enabled) + + def routes_timer_set(self, enabled: bool) -> ActionView: + res = self.client.routes_timer_set(bool(enabled)) + return ActionView(ok=res.ok, pretty_text=self._pretty_cmd(res)) + + def traffic_mode_view(self) -> TrafficModeView: + st: TrafficModeStatus = self.client.traffic_mode_get() + return TrafficModeView( + desired_mode=(st.desired_mode or st.mode or "selective"), + applied_mode=(st.applied_mode or "direct"), + preferred_iface=st.preferred_iface or "", + auto_local_bypass=bool(st.auto_local_bypass), + bypass_candidates=int(st.bypass_candidates), + force_vpn_subnets=list(st.force_vpn_subnets or []), + force_vpn_uids=list(st.force_vpn_uids or []), + force_vpn_cgroups=list(st.force_vpn_cgroups or []), + force_direct_subnets=list(st.force_direct_subnets or []), + force_direct_uids=list(st.force_direct_uids or []), + force_direct_cgroups=list(st.force_direct_cgroups or []), + overrides_applied=int(st.overrides_applied), + cgroup_resolved_uids=int(st.cgroup_resolved_uids), + cgroup_warning=st.cgroup_warning or "", + active_iface=st.active_iface or "", + iface_reason=st.iface_reason or "", + probe_ok=bool(st.probe_ok), + probe_message=st.probe_message or "", + healthy=bool(st.healthy), + message=st.message or "", + ) + + def traffic_mode_set( + self, + mode: str, + preferred_iface: Optional[str] = None, + auto_local_bypass: Optional[bool] = None, + force_vpn_subnets: Optional[List[str]] = None, + force_vpn_uids: Optional[List[str]] = None, + force_vpn_cgroups: Optional[List[str]] = None, + force_direct_subnets: Optional[List[str]] = None, + force_direct_uids: Optional[List[str]] = None, + force_direct_cgroups: Optional[List[str]] = None, + ) -> TrafficModeView: + st: TrafficModeStatus = self.client.traffic_mode_set( + mode, + preferred_iface, + auto_local_bypass, + force_vpn_subnets, + force_vpn_uids, + force_vpn_cgroups, + force_direct_subnets, + force_direct_uids, + force_direct_cgroups, + ) + return TrafficModeView( + desired_mode=(st.desired_mode or st.mode or mode), + applied_mode=(st.applied_mode or "direct"), + preferred_iface=st.preferred_iface or "", + auto_local_bypass=bool(st.auto_local_bypass), + bypass_candidates=int(st.bypass_candidates), + force_vpn_subnets=list(st.force_vpn_subnets or []), + force_vpn_uids=list(st.force_vpn_uids or []), + force_vpn_cgroups=list(st.force_vpn_cgroups or []), + force_direct_subnets=list(st.force_direct_subnets or []), + force_direct_uids=list(st.force_direct_uids or []), + force_direct_cgroups=list(st.force_direct_cgroups or []), + overrides_applied=int(st.overrides_applied), + cgroup_resolved_uids=int(st.cgroup_resolved_uids), + cgroup_warning=st.cgroup_warning or "", + active_iface=st.active_iface or "", + iface_reason=st.iface_reason or "", + probe_ok=bool(st.probe_ok), + probe_message=st.probe_message or "", + healthy=bool(st.healthy), + message=st.message or "", + ) + + def traffic_mode_test(self) -> TrafficModeView: + st: TrafficModeStatus = self.client.traffic_mode_test() + return TrafficModeView( + desired_mode=(st.desired_mode or st.mode or "selective"), + applied_mode=(st.applied_mode or "direct"), + preferred_iface=st.preferred_iface or "", + auto_local_bypass=bool(st.auto_local_bypass), + bypass_candidates=int(st.bypass_candidates), + force_vpn_subnets=list(st.force_vpn_subnets or []), + force_vpn_uids=list(st.force_vpn_uids or []), + force_vpn_cgroups=list(st.force_vpn_cgroups or []), + force_direct_subnets=list(st.force_direct_subnets or []), + force_direct_uids=list(st.force_direct_uids or []), + force_direct_cgroups=list(st.force_direct_cgroups or []), + overrides_applied=int(st.overrides_applied), + cgroup_resolved_uids=int(st.cgroup_resolved_uids), + cgroup_warning=st.cgroup_warning or "", + active_iface=st.active_iface or "", + iface_reason=st.iface_reason or "", + probe_ok=bool(st.probe_ok), + probe_message=st.probe_message or "", + healthy=bool(st.healthy), + message=st.message or "", + ) + + def traffic_interfaces(self) -> List[str]: + st: TrafficInterfaces = self.client.traffic_interfaces_get() + vals = [x for x in st.interfaces if x] + if st.preferred_iface and st.preferred_iface not in vals: + vals.insert(0, st.preferred_iface) + return vals + + def traffic_candidates(self) -> TrafficCandidates: + return self.client.traffic_candidates_get() + + + def routes_nft_progress_from_event(self, ev: Event) -> RoutesNftProgressView: + """ + Превращает Event(kind='routes_nft_progress') в удобную модель + для прогресс-бара/лейбла. + """ + payload = ( + getattr(ev, "data", None) + or getattr(ev, "payload", None) + or getattr(ev, "extra", None) + or {} + ) + + if not isinstance(payload, dict): + payload = {} + + try: + percent = int(payload.get("percent", 0)) + except Exception: + percent = 0 + + msg = str(payload.get("message", "")) if payload is not None else "" + if not msg: + msg = "Updating nft set…" + + active = 0 <= percent < 100 + + return RoutesNftProgressView( + percent=percent, + message=msg, + active=active, + ) + + # -------- DNS / SmartDNS -------- + + def dns_upstreams_view(self) -> DnsUpstreams: + return self.client.dns_upstreams_get() + + def dns_upstreams_save(self, cfg: DnsUpstreams) -> None: + self.client.dns_upstreams_set(cfg) + + def dns_status_view(self) -> DNSStatus: + return self.client.dns_status_get() + + def dns_mode_set(self, via: bool, smartdns_addr: str) -> DNSStatus: + return self.client.dns_mode_set(via, smartdns_addr) + + def smartdns_service_action(self, action: str) -> DNSStatus: + act = action.strip().lower() + if act not in ("start", "stop", "restart"): + raise ValueError(f"Invalid SmartDNS action: {action}") + return self.client.dns_smartdns_service_set(cast(ServiceAction, act)) + + def smartdns_prewarm(self, limit: int = 0, aggressive_subs: bool = False) -> ActionView: + res = self.client.smartdns_prewarm(limit=limit, aggressive_subs=aggressive_subs) + return ActionView(ok=res.ok, pretty_text=self._pretty_cmd(res)) + + def smartdns_runtime_view(self) -> SmartdnsRuntimeState: + return self.client.smartdns_runtime_get() + + def smartdns_runtime_set(self, enabled: bool, restart: bool = True) -> SmartdnsRuntimeState: + return self.client.smartdns_runtime_set(enabled=enabled, restart=restart) + + # -------- Domains -------- + + def domains_table_view(self) -> DomainsTable: + return self.client.domains_table() + + def domains_file_load(self, name: str) -> DomainsFile: + nm = name.strip().lower() + if nm not in ("bases", "meta", "subs", "static", "smartdns", "last-ips-map", "last-ips-map-direct", "last-ips-map-wildcard"): + raise ValueError(f"Invalid domains file name: {name}") + return self.client.domains_file_get( + cast(Literal["bases", "meta", "subs", "static", "smartdns", "last-ips-map", "last-ips-map-direct", "last-ips-map-wildcard"], nm) + ) + + def domains_file_save(self, name: str, content: str) -> None: + nm = name.strip().lower() + if nm not in ("bases", "meta", "subs", "static", "smartdns", "last-ips-map", "last-ips-map-direct", "last-ips-map-wildcard"): + raise ValueError(f"Invalid domains file name: {name}") + self.client.domains_file_set( + cast(Literal["bases", "meta", "subs", "static", "smartdns", "last-ips-map", "last-ips-map-direct", "last-ips-map-wildcard"], nm), content + ) + + # -------- Trace -------- + + def trace_view(self, mode: TraceMode = "full") -> TraceDump: + return self.client.trace_get(mode) + + # -------- formatting helpers -------- + + def _pretty_cmd(self, res: CmdResult) -> str: + lines: List[str] = [] + lines.append("OK" if res.ok else "ERROR") + if res.message: + lines.append(res.message.strip()) + if res.exit_code is not None: + lines.append(f"exit_code: {res.exit_code}") + if res.stdout.strip(): + lines.append("") + lines.append("stdout:") + lines.append(res.stdout.rstrip()) + if res.stderr.strip() and res.stderr.strip() != res.stdout.strip(): + lines.append("") + lines.append("stderr:") + lines.append(res.stderr.rstrip()) + return "\n".join(lines).strip() + "\n" + + def _pretty_cmd_then_status(self, res: CmdResult, st: VpnStatus) -> str: + return ( + self._pretty_cmd(res).rstrip() + + "\n\n" + + self._pretty_vpn_status(st).rstrip() + + "\n" + ) + + def _clean_login_lines(self, lines: Iterable[str]) -> List[str]: + out: List[str] = [] + for raw in lines or []: + if raw is None: + continue + + s = str(raw).replace("\r", "\n") + for part in s.splitlines(): + t = part.strip() + if not t: + continue + + # вырезаем спам "Next check in ..." + t2 = _NEXT_CHECK_RE.sub("", t).strip() + if not t2: + continue + + # на всякий — повторно + t2 = _NEXT_CHECK_RE.sub("", t2).strip() + if not t2: + continue + + out.append(t2) + return out diff --git a/selective-vpn-gui/internal/assets/domains/bases.txt b/selective-vpn-gui/internal/assets/domains/bases.txt new file mode 100644 index 0000000..504d1b3 --- /dev/null +++ b/selective-vpn-gui/internal/assets/domains/bases.txt @@ -0,0 +1,2 @@ +### +# Default bases list (seed). Add domains here; one per line. diff --git a/selective-vpn-gui/internal/assets/domains/meta-special.txt b/selective-vpn-gui/internal/assets/domains/meta-special.txt new file mode 100644 index 0000000..06f5bc9 --- /dev/null +++ b/selective-vpn-gui/internal/assets/domains/meta-special.txt @@ -0,0 +1 @@ +# meta domains (seed) diff --git a/selective-vpn-gui/internal/assets/domains/static-ips.txt b/selective-vpn-gui/internal/assets/domains/static-ips.txt new file mode 100644 index 0000000..d1b8402 --- /dev/null +++ b/selective-vpn-gui/internal/assets/domains/static-ips.txt @@ -0,0 +1 @@ +# static IPs (seed) diff --git a/selective-vpn-gui/internal/assets/domains/subs.txt b/selective-vpn-gui/internal/assets/domains/subs.txt new file mode 100644 index 0000000..ec4de7e --- /dev/null +++ b/selective-vpn-gui/internal/assets/domains/subs.txt @@ -0,0 +1,3 @@ +www +api +static diff --git a/selective-vpn-gui/traffic_mode_dialog.py b/selective-vpn-gui/traffic_mode_dialog.py new file mode 100644 index 0000000..7bf9672 --- /dev/null +++ b/selective-vpn-gui/traffic_mode_dialog.py @@ -0,0 +1,1012 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import os +from typing import Callable + +from PySide6 import QtCore, QtGui +from PySide6.QtWidgets import ( + QCheckBox, + QComboBox, + QDialog, + QGroupBox, + QHBoxLayout, + QLabel, + QLineEdit, + QListWidget, + QListWidgetItem, + QAbstractItemView, + QMessageBox, + QPlainTextEdit, + QPushButton, + QRadioButton, + QTabWidget, + QVBoxLayout, + QWidget, +) + +from dashboard_controller import DashboardController + + +class TrafficModeDialog(QDialog): + def __init__( + self, + controller: DashboardController, + *, + log_cb: Callable[[str], None] | None = None, + refresh_cb: Callable[[], None] | None = None, + parent=None, + ) -> None: + super().__init__(parent) + self.ctrl = controller + self.log_cb = log_cb + self.refresh_cb = refresh_cb + + self.setWindowTitle("Traffic mode settings") + self.resize(780, 760) + + root = QVBoxLayout(self) + + hint_group = QGroupBox("Mode behavior") + hint_layout = QVBoxLayout(hint_group) + hint_layout.addWidget(QLabel("Selective: only marked traffic goes via VPN.")) + hint_layout.addWidget(QLabel("Full tunnel: all traffic goes via VPN.")) + hint_layout.addWidget(QLabel("Direct: VPN routing rules are disabled.")) + warn = QLabel( + "Warning: Full tunnel can break local/LAN access depending on your host routes." + ) + warn.setStyleSheet("color: red;") + hint_layout.addWidget(warn) + root.addWidget(hint_group) + + tip = QLabel("Tip: hover any control for help. Подсказка: наведи на элемент для описания.") + tip.setWordWrap(True) + tip.setStyleSheet("color: gray;") + root.addWidget(tip) + + self.tabs = QTabWidget() + root.addWidget(self.tabs, stretch=1) + + tab_basic = QWidget() + tab_basic_layout = QVBoxLayout(tab_basic) + + mode_group = QGroupBox("Traffic mode relay") + mode_layout = QVBoxLayout(mode_group) + + row_mode = QHBoxLayout() + self.rad_selective = QRadioButton("Selective") + self.rad_selective.setToolTip("""EN: Only marked traffic (fwmark 0x66) uses VPN policy table (agvpn). +RU: Только помеченный трафик (fwmark 0x66) идет через policy-table (agvpn).""") + self.rad_selective.toggled.connect( + lambda checked: self.on_mode_toggle("selective", checked) + ) + row_mode.addWidget(self.rad_selective) + + self.rad_full = QRadioButton("Full tunnel") + self.rad_full.setToolTip("""EN: All traffic uses VPN policy table (agvpn). Use with auto-local bypass for LAN/docker. +RU: Весь трафик идет через policy-table (agvpn). Для LAN/docker включай auto-local bypass.""") + self.rad_full.toggled.connect( + lambda checked: self.on_mode_toggle("full_tunnel", checked) + ) + row_mode.addWidget(self.rad_full) + + self.rad_direct = QRadioButton("Direct") + self.rad_direct.setToolTip("""EN: Disables base VPN routing rules (no full/selective rule). +RU: Отключает базовые VPN policy-rules (нет full/selective правила).""") + self.rad_direct.toggled.connect( + lambda checked: self.on_mode_toggle("direct", checked) + ) + row_mode.addWidget(self.rad_direct) + row_mode.addStretch(1) + mode_layout.addLayout(row_mode) + + row_iface = QHBoxLayout() + row_iface.addWidget(QLabel("Preferred iface")) + self.cmb_iface = QComboBox() + self.cmb_iface.setToolTip("""EN: VPN interface for policy routing. Use auto unless you know the exact iface. +RU: Интерфейс VPN для policy routing. Оставь auto, если не уверен.""") + self.cmb_iface.setEditable(True) + self.cmb_iface.setInsertPolicy(QComboBox.NoInsert) + self.cmb_iface.setMinimumWidth(180) + row_iface.addWidget(self.cmb_iface) + + self.btn_refresh_ifaces = QPushButton("Detect ifaces") + self.btn_refresh_ifaces.setToolTip("""EN: Refresh list of available interfaces (UP). +RU: Обновить список доступных интерфейсов (UP).""") + self.btn_refresh_ifaces.clicked.connect(self.on_refresh_ifaces) + row_iface.addWidget(self.btn_refresh_ifaces) + row_iface.addStretch(1) + mode_layout.addLayout(row_iface) + + self.chk_auto_local = QCheckBox("Auto-local bypass (LAN/container subnets)") + self.chk_auto_local.setToolTip("""EN: Mirrors local/LAN/docker routes from main into agvpn table to prevent breakage in full tunnel. +EN: This does NOT force containers to use direct internet; use Force Direct subnets for that. +RU: Копирует локальные/LAN/docker маршруты из main в agvpn, чтобы не ломалась локалка в full tunnel. +RU: Это НЕ делает контейнеры direct в интернет; для этого используй Force Direct subnets.""") + self.chk_auto_local.stateChanged.connect(lambda _state: self.on_auto_local_toggle()) + mode_layout.addWidget(self.chk_auto_local) + + self.lbl_state = QLabel("Traffic mode: —") + self.lbl_state.setStyleSheet("color: gray;") + mode_layout.addWidget(self.lbl_state) + + self.lbl_diag = QLabel("—") + self.lbl_diag.setStyleSheet("color: gray;") + mode_layout.addWidget(self.lbl_diag) + + tab_basic_layout.addWidget(mode_group) + + maint_group = QGroupBox("Rollback / cache") + maint_layout = QHBoxLayout(maint_group) + self.btn_rollback = QPushButton("Clear routes (save cache)") + self.btn_rollback.setToolTip("""EN: Clears VPN routes and nft sets, but saves a cache snapshot for restore. +RU: Очищает VPN маршруты и nft-сеты, но сохраняет снапшот для восстановления.""") + self.btn_rollback.clicked.connect(self.on_rollback) + maint_layout.addWidget(self.btn_rollback) + self.btn_restore_cache = QPushButton("Restore cached routes") + self.btn_restore_cache.setToolTip("""EN: Restores routes/nft from the last clear snapshot. Skips non-critical route restore errors. +RU: Восстанавливает маршруты/nft из последнего снапшота clear. Некритичные ошибки восстановления пропускаются.""") + self.btn_restore_cache.clicked.connect(self.on_restore_cache) + maint_layout.addWidget(self.btn_restore_cache) + maint_layout.addStretch(1) + tab_basic_layout.addWidget(maint_group) + + tab_basic_layout.addStretch(1) + self.tabs.addTab(tab_basic, "Traffic basics") + + tab_adv = QWidget() + tab_adv_layout = QVBoxLayout(tab_adv) + + self.ed_vpn_subnets = QPlainTextEdit() + self.ed_vpn_subnets.setToolTip("""EN: Force VPN by source subnet. Useful for docker subnets when you want containers via VPN. +RU: Принудительно через VPN по source subnet. Полезно для docker-подсетей, если хочешь контейнеры через VPN.""") + self.ed_vpn_subnets.setPlaceholderText("Force VPN by source subnet, one per line (e.g. 172.18.0.0/16)") + self.ed_vpn_subnets.setFixedHeight(72) + + self.ed_vpn_uids = QPlainTextEdit() + self.ed_vpn_uids.setToolTip("""EN: Force VPN by UID/uidrange (host OUTPUT only). Does not affect forwarded docker traffic. +RU: Принудительно через VPN по UID (только процессы хоста). На forwarded docker-трафик не влияет.""") + self.ed_vpn_uids.setPlaceholderText("Force VPN by UID/UID range, one per line (e.g. 1000 or 1000-1010)") + self.ed_vpn_uids.setFixedHeight(60) + + self.ed_vpn_cgroups = QPlainTextEdit() + self.ed_vpn_cgroups.setToolTip("""EN: Force VPN by systemd cgroup. Backend resolves cgroup -> PIDs -> UID rules at apply time. +RU: Принудительно через VPN по cgroup (systemd). Backend резолвит cgroup -> PID -> UID при применении.""") + self.ed_vpn_cgroups.setPlaceholderText("Force VPN by cgroup path/name, one per line") + self.ed_vpn_cgroups.setFixedHeight(60) + + self.ed_direct_subnets = QPlainTextEdit() + self.ed_direct_subnets.setToolTip("""EN: Force Direct by source subnet. Useful to keep docker subnets direct in full tunnel. +RU: Принудительно direct по source subnet. Полезно, чтобы docker-подсети были direct в full tunnel.""") + self.ed_direct_subnets.setPlaceholderText("Force Direct by source subnet, one per line") + self.ed_direct_subnets.setFixedHeight(72) + + self.ed_direct_uids = QPlainTextEdit() + self.ed_direct_uids.setToolTip("""EN: Force Direct by UID/uidrange (host OUTPUT only). +RU: Принудительно direct по UID (только процессы хоста).""") + self.ed_direct_uids.setPlaceholderText("Force Direct by UID/UID range, one per line") + self.ed_direct_uids.setFixedHeight(60) + + self.ed_direct_cgroups = QPlainTextEdit() + self.ed_direct_cgroups.setToolTip("""EN: Force Direct by systemd cgroup (resolved to UID rules at apply time). +RU: Принудительно direct по cgroup (резолвится в UID правила при применении).""") + self.ed_direct_cgroups.setPlaceholderText("Force Direct by cgroup path/name, one per line") + self.ed_direct_cgroups.setFixedHeight(60) + + cols = QHBoxLayout() + + vpn_group = QGroupBox("Force VPN") + vpn_layout = QVBoxLayout(vpn_group) + vpn_layout.addWidget(QLabel("Source subnets")) + vpn_layout.addWidget(self.ed_vpn_subnets) + vpn_layout.addWidget(QLabel("UIDs")) + vpn_layout.addWidget(self.ed_vpn_uids) + vpn_layout.addWidget(QLabel("Cgroups / services")) + vpn_layout.addWidget(self.ed_vpn_cgroups) + cols.addWidget(vpn_group, stretch=1) + + direct_group = QGroupBox("Force Direct") + direct_layout = QVBoxLayout(direct_group) + direct_layout.addWidget(QLabel("Source subnets")) + direct_layout.addWidget(self.ed_direct_subnets) + direct_layout.addWidget(QLabel("UIDs")) + direct_layout.addWidget(self.ed_direct_uids) + direct_layout.addWidget(QLabel("Cgroups / services")) + direct_layout.addWidget(self.ed_direct_cgroups) + cols.addWidget(direct_group, stretch=1) + + tab_adv_layout.addLayout(cols, stretch=1) + + row_adv = QHBoxLayout() + self.btn_pick_detected = QPushButton("Add detected...") + self.btn_pick_detected.setToolTip("""EN: Opens a selector with detected subnets/services/UIDs. Only fills fields; nothing is applied automatically. +RU: Открывает список обнаруженных subnet/service/UID. Только заполняет поля; ничего не применяется автоматически.""") + self.btn_pick_detected.clicked.connect(self.on_pick_detected) + row_adv.addWidget(self.btn_pick_detected) + self.btn_apply_overrides = QPushButton("Apply overrides") + self.btn_apply_overrides.setToolTip("""EN: Applies policy rules and verifies health. On failure backend rolls back. +RU: Применяет policy-rules и проверяет health. При ошибке backend делает откат.""") + self.btn_apply_overrides.clicked.connect(self.on_apply_overrides) + row_adv.addWidget(self.btn_apply_overrides) + self.btn_reload_overrides = QPushButton("Reload overrides") + self.btn_reload_overrides.clicked.connect(self.refresh_state) + row_adv.addWidget(self.btn_reload_overrides) + row_adv.addStretch(1) + tab_adv_layout.addLayout(row_adv) + + self.tabs.addTab(tab_adv, "Policy overrides (Advanced)") + + # EN: Small status line for last action performed in this dialog. + # RU: Строка статуса последнего действия в этом окне. + self.lbl_action = QLabel("—") + self.lbl_action.setWordWrap(True) + self.lbl_action.setStyleSheet("color: gray;") + root.addWidget(self.lbl_action) + + row_bottom = QHBoxLayout() + row_bottom.addStretch(1) + btn_close = QPushButton("Close") + btn_close.clicked.connect(self.accept) + row_bottom.addWidget(btn_close) + root.addLayout(row_bottom) + + QtCore.QTimer.singleShot(0, self.refresh_state) + + def _is_operation_error(self, message: str) -> bool: + low = (message or "").strip().lower() + return ("rolled back" in low) or ("apply failed" in low) or ("verification failed" in low) + + def _set_action_status(self, msg: str, ok: bool | None = None) -> None: + text = (msg or "").strip() or "—" + self.lbl_action.setText(text) + if ok is True: + self.lbl_action.setStyleSheet("color: green;") + elif ok is False: + self.lbl_action.setStyleSheet("color: red;") + else: + self.lbl_action.setStyleSheet("color: gray;") + + def _safe(self, fn, *, title: str = "Traffic mode error") -> None: + try: + fn() + except Exception as e: + msg = f"[ui-error] {title}: {e}" + self._set_action_status(msg, ok=False) + self._emit_log(msg) + QMessageBox.critical(self, title, str(e)) + + def _emit_log(self, msg: str) -> None: + text = (msg or "").strip() + if not text: + return + if self.log_cb: + self.log_cb(text) + else: + try: + self.ctrl.log_gui(text) + except Exception: + pass + + def _preferred_iface_value(self) -> str: + raw = self.cmb_iface.currentText().strip() + if raw.lower() in ("", "auto", "-", "default"): + return "" + return raw + + def _set_preferred_iface_options(self, ifaces: list[str], selected: str) -> None: + vals = ["auto"] + [x for x in ifaces if x] + sel = selected.strip() if selected else "auto" + if not sel: + sel = "auto" + if sel not in vals: + vals.append(sel) + + self.cmb_iface.blockSignals(True) + self.cmb_iface.clear() + self.cmb_iface.addItems(vals) + idx = self.cmb_iface.findText(sel) + if idx < 0: + idx = self.cmb_iface.findText("auto") + if idx >= 0: + self.cmb_iface.setCurrentIndex(idx) + else: + self.cmb_iface.setEditText(sel) + self.cmb_iface.blockSignals(False) + + def _lines_from_text(self, txt: str) -> list[str]: + out: list[str] = [] + for raw in (txt or "").replace("\r", "\n").split("\n"): + line = raw.strip() + if line: + out.append(line) + return out + + def _set_lines(self, widget: QPlainTextEdit, vals: list[str]) -> None: + widget.blockSignals(True) + widget.setPlainText("\n".join([x for x in vals if str(x).strip()])) + widget.blockSignals(False) + + def _merge_lines(self, widget: QPlainTextEdit, vals: list[str]) -> int: + cur = self._lines_from_text(widget.toPlainText()) + seen = {x.strip() for x in cur} + added = 0 + for v in (vals or []): + vv = str(v).strip() + if not vv or vv in seen: + continue + cur.append(vv) + seen.add(vv) + added += 1 + if added > 0: + self._set_lines(widget, cur) + return added + + def _candidates_add(self, target: str, kind: str, values: list[str]) -> None: + tgt = (target or "").strip().lower() + k = (kind or "").strip().lower() + if tgt not in ("vpn", "direct"): + return + + widget: QPlainTextEdit | None = None + if tgt == "vpn": + if k == "subnet": + widget = self.ed_vpn_subnets + elif k == "uid": + widget = self.ed_vpn_uids + elif k == "cgroup": + widget = self.ed_vpn_cgroups + else: + if k == "subnet": + widget = self.ed_direct_subnets + elif k == "uid": + widget = self.ed_direct_uids + elif k == "cgroup": + widget = self.ed_direct_cgroups + + if widget is None: + return + + added = self._merge_lines(widget, values or []) + if added > 0: + msg = f"Traffic candidates added: target={tgt} kind={k} added={added}" + self._set_action_status(msg, ok=True) + self._emit_log(msg) + else: + msg = f"Traffic candidates add: nothing new (target={tgt} kind={k})" + self._set_action_status(msg, ok=None) + self._emit_log(msg) + + def on_pick_detected(self) -> None: + def work() -> None: + cands = self.ctrl.traffic_candidates() + existing = { + "vpn": { + "subnet": set(self._lines_from_text(self.ed_vpn_subnets.toPlainText())), + "uid": set(self._lines_from_text(self.ed_vpn_uids.toPlainText())), + "cgroup": set(self._lines_from_text(self.ed_vpn_cgroups.toPlainText())), + }, + "direct": { + "subnet": set(self._lines_from_text(self.ed_direct_subnets.toPlainText())), + "uid": set(self._lines_from_text(self.ed_direct_uids.toPlainText())), + "cgroup": set(self._lines_from_text(self.ed_direct_cgroups.toPlainText())), + }, + } + dlg = TrafficCandidatesDialog( + cands, + existing=existing, + add_cb=self._candidates_add, + parent=self, + ) + dlg.exec() + + self._safe(work, title="Traffic candidates error") + + + def _set_mode_state( + self, + desired_mode: str, + applied_mode: str, + preferred_iface: str, + auto_local_bypass: bool, + bypass_candidates: int, + overrides_applied: int, + cgroup_resolved_uids: int, + cgroup_warning: str, + healthy: bool, + probe_ok: bool, + probe_message: str, + active_iface: str, + iface_reason: str, + message: str, + ) -> None: + desired = (desired_mode or "").strip().lower() or "selective" + applied = (applied_mode or "").strip().lower() or "direct" + + if healthy: + color = "green" + health_txt = "OK" + else: + color = "red" + health_txt = "MISMATCH" + + text = f"Traffic mode: {desired} (applied: {applied}) [{health_txt}]" + diag_parts = [] + diag_parts.append(f"preferred={preferred_iface or 'auto'}") + diag_parts.append( + f"auto_local_bypass={'on' if auto_local_bypass else 'off'}" + ) + if bypass_candidates > 0: + diag_parts.append(f"bypass_routes={bypass_candidates}") + diag_parts.append(f"overrides={overrides_applied}") + if cgroup_resolved_uids > 0: + diag_parts.append(f"cgroup_uids={cgroup_resolved_uids}") + if cgroup_warning: + diag_parts.append(f"cgroup_warning={cgroup_warning}") + if active_iface: + diag_parts.append(f"iface={active_iface}") + if iface_reason: + diag_parts.append(f"source={iface_reason}") + diag_parts.append(f"probe={'ok' if probe_ok else 'fail'}") + if probe_message: + diag_parts.append(probe_message) + if message: + diag_parts.append(message) + diag = " | ".join(diag_parts) if diag_parts else "—" + + self.lbl_state.setText(text) + self.lbl_state.setStyleSheet(f"color: {color};") + self.lbl_diag.setText(diag) + self.lbl_diag.setStyleSheet("color: gray;") + + def refresh_state(self) -> None: + def work() -> None: + view = self.ctrl.traffic_mode_view() + mode = (view.desired_mode or "selective").strip().lower() + + self.rad_selective.blockSignals(True) + self.rad_full.blockSignals(True) + self.rad_direct.blockSignals(True) + self.rad_selective.setChecked(mode == "selective") + self.rad_full.setChecked(mode == "full_tunnel") + self.rad_direct.setChecked(mode == "direct") + self.rad_selective.blockSignals(False) + self.rad_full.blockSignals(False) + self.rad_direct.blockSignals(False) + + opts = self.ctrl.traffic_interfaces() + self._set_preferred_iface_options(opts, view.preferred_iface) + self.chk_auto_local.blockSignals(True) + self.chk_auto_local.setChecked(bool(view.auto_local_bypass)) + self.chk_auto_local.blockSignals(False) + self._set_lines(self.ed_vpn_subnets, list(view.force_vpn_subnets or [])) + self._set_lines(self.ed_vpn_uids, list(view.force_vpn_uids or [])) + self._set_lines(self.ed_vpn_cgroups, list(view.force_vpn_cgroups or [])) + self._set_lines(self.ed_direct_subnets, list(view.force_direct_subnets or [])) + self._set_lines(self.ed_direct_uids, list(view.force_direct_uids or [])) + self._set_lines(self.ed_direct_cgroups, list(view.force_direct_cgroups or [])) + + self._set_mode_state( + view.desired_mode, + view.applied_mode, + view.preferred_iface, + bool(view.auto_local_bypass), + int(view.bypass_candidates), + int(view.overrides_applied), + int(view.cgroup_resolved_uids), + view.cgroup_warning, + bool(view.healthy), + bool(view.probe_ok), + view.probe_message, + view.active_iface, + view.iface_reason, + view.message, + ) + + self._safe(work) + + def on_mode_toggle(self, mode: str, checked: bool) -> None: + if not checked: + return + + def work() -> None: + preferred = self._preferred_iface_value() + auto_local = self.chk_auto_local.isChecked() + view = self.ctrl.traffic_mode_set(mode, preferred, auto_local) + msg = ( + f"Traffic mode set: desired={view.desired_mode}, " + f"applied={view.applied_mode}, iface={view.active_iface or '-'}, " + f"preferred={preferred or 'auto'}, probe_ok={view.probe_ok}, " + f"healthy={view.healthy}, auto_local_bypass={view.auto_local_bypass}, " + f"bypass_routes={view.bypass_candidates}, overrides={view.overrides_applied}, " + f"cgroup_uids={view.cgroup_resolved_uids}, message={view.message}" + ) + self._emit_log(msg) + op_ok = bool(view.healthy) and not self._is_operation_error(view.message) + self._set_action_status( + f"Traffic mode set: desired={view.desired_mode} applied={view.applied_mode} message={view.message}", + ok=op_ok, + ) + self.refresh_state() + if self.refresh_cb: + self.refresh_cb() + + self._safe(work) + + def on_refresh_ifaces(self) -> None: + def work() -> None: + view = self.ctrl.traffic_mode_view() + opts = self.ctrl.traffic_interfaces() + self._set_preferred_iface_options(opts, view.preferred_iface) + self._emit_log( + "Traffic ifaces refreshed: " + f"preferred={view.preferred_iface or 'auto'} " + f"active={view.active_iface or '-'}" + ) + self._set_action_status("Traffic ifaces refreshed", ok=True) + self.refresh_state() + if self.refresh_cb: + self.refresh_cb() + + self._safe(work, title="Traffic iface detect error") + + def _selected_mode(self) -> str: + if self.rad_full.isChecked(): + return "full_tunnel" + if self.rad_direct.isChecked(): + return "direct" + return "selective" + + def on_auto_local_toggle(self) -> None: + def work() -> None: + mode = self._selected_mode() + preferred = self._preferred_iface_value() + auto_local = self.chk_auto_local.isChecked() + view = self.ctrl.traffic_mode_set(mode, preferred, auto_local) + msg = ( + f"Traffic auto-local set: mode={view.desired_mode}, " + f"auto_local_bypass={view.auto_local_bypass}, " + f"bypass_routes={view.bypass_candidates}, overrides={view.overrides_applied}, " + f"cgroup_uids={view.cgroup_resolved_uids}, message={view.message}" + ) + self._emit_log(msg) + op_ok = bool(view.healthy) and not self._is_operation_error(view.message) + self._set_action_status( + f"Auto-local bypass set: {'on' if view.auto_local_bypass else 'off'} ({view.message})", + ok=op_ok, + ) + self.refresh_state() + if self.refresh_cb: + self.refresh_cb() + + self._safe(work, title="Auto-local bypass error") + + def on_apply_overrides(self) -> None: + def work() -> None: + mode = self._selected_mode() + preferred = self._preferred_iface_value() + auto_local = self.chk_auto_local.isChecked() + vpn_subnets = self._lines_from_text(self.ed_vpn_subnets.toPlainText()) + vpn_uids = self._lines_from_text(self.ed_vpn_uids.toPlainText()) + vpn_cgroups = self._lines_from_text(self.ed_vpn_cgroups.toPlainText()) + direct_subnets = self._lines_from_text(self.ed_direct_subnets.toPlainText()) + direct_uids = self._lines_from_text(self.ed_direct_uids.toPlainText()) + direct_cgroups = self._lines_from_text(self.ed_direct_cgroups.toPlainText()) + + view = self.ctrl.traffic_mode_set( + mode, + preferred, + auto_local, + vpn_subnets, + vpn_uids, + vpn_cgroups, + direct_subnets, + direct_uids, + direct_cgroups, + ) + msg = ( + f"Traffic overrides applied: mode={view.desired_mode}, " + f"vpn_subnets={len(view.force_vpn_subnets)}, vpn_uids={len(view.force_vpn_uids)}, vpn_cgroups={len(view.force_vpn_cgroups)}, " + f"direct_subnets={len(view.force_direct_subnets)}, direct_uids={len(view.force_direct_uids)}, direct_cgroups={len(view.force_direct_cgroups)}, " + f"overrides={view.overrides_applied}, cgroup_uids={view.cgroup_resolved_uids}, " + f"healthy={view.healthy}, message={view.message}" + ) + self._emit_log(msg) + op_ok = bool(view.healthy) and not self._is_operation_error(view.message) + self._set_action_status( + f"Overrides applied: overrides={view.overrides_applied} message={view.message}", + ok=op_ok, + ) + self.refresh_state() + if self.refresh_cb: + self.refresh_cb() + + self._safe(work, title="Apply overrides error") + + def on_rollback(self) -> None: + def work() -> None: + res = self.ctrl.routes_clear() + self._emit_log(res.pretty_text or "rollback done") + self._set_action_status(res.pretty_text or "routes cleared (cache saved)", ok=bool(res.ok)) + self.refresh_state() + if self.refresh_cb: + self.refresh_cb() + + self._safe(work, title="Rollback error") + + def on_restore_cache(self) -> None: + def work() -> None: + res = self.ctrl.routes_cache_restore() + self._emit_log(res.pretty_text or "cache restore done") + self._set_action_status(res.pretty_text or "routes restored from cache", ok=bool(res.ok)) + self.refresh_state() + if self.refresh_cb: + self.refresh_cb() + + self._safe(work, title="Restore cache error") + + +class TrafficCandidatesDialog(QDialog): + def __init__( + self, + candidates, + *, + existing: dict[str, dict[str, set[str]]] | None = None, + add_cb: Callable[[str, str, list[str]], None], + parent=None, + ) -> None: + super().__init__(parent) + self.cands = candidates + self.add_cb = add_cb + self.existing = existing or {"vpn": {}, "direct": {}} + + self.setWindowTitle("Add detected overrides") + self.resize(820, 680) + + root = QVBoxLayout(self) + + note = QLabel( + "Tip: hover list items for details. Подсказка: наведи на элементы списка.\n" + "Detect results from backend. Nothing is applied until you click Apply overrides." + ) + note.setWordWrap(True) + note.setStyleSheet("color: gray;") + root.addWidget(note) + + self.chk_hide_existing = QCheckBox("Hide already added") + self.chk_hide_existing.setToolTip( + """EN: Hides items that are already present in Force VPN/Force Direct fields. +RU: Скрывает элементы, которые уже есть в Force VPN/Force Direct.""" + ) + self.chk_hide_existing.stateChanged.connect(lambda _s: self._refilter_current()) + root.addWidget(self.chk_hide_existing) + + self.tabs = QTabWidget() + root.addWidget(self.tabs, stretch=1) + + self._tab_kind: dict[QWidget, str] = {} + self._tab_list: dict[QWidget, QListWidget] = {} + self._tab_filter: dict[QWidget, QLineEdit] = {} + self.tabs.currentChanged.connect(lambda _idx: self._refilter_current()) + + self._build_subnets_tab() + self._build_services_tab() + self._build_uids_tab() + + row = QHBoxLayout() + btn_vpn = QPushButton("Add to Force VPN") + btn_vpn.clicked.connect(lambda: self._add_selected("vpn")) + row.addWidget(btn_vpn) + + btn_direct = QPushButton("Add to Force Direct") + btn_direct.clicked.connect(lambda: self._add_selected("direct")) + row.addWidget(btn_direct) + + row.addStretch(1) + btn_close = QPushButton("Close") + btn_close.clicked.connect(self.accept) + row.addWidget(btn_close) + root.addLayout(row) + + def _mark_state(self, kind: str, value: str) -> tuple[bool, bool]: + k = (kind or "").strip().lower() + v = (value or "").strip() + if not k or not v: + return False, False + in_vpn = v in (self.existing.get("vpn", {}).get(k, set()) or set()) + in_direct = v in (self.existing.get("direct", {}).get(k, set()) or set()) + return bool(in_vpn), bool(in_direct) + + def _apply_filter(self, lst: QListWidget, query: str) -> None: + q = (query or "").strip().lower() + hide_existing = bool(self.chk_hide_existing.isChecked()) + for i in range(lst.count()): + it = lst.item(i) + if not it: + continue + if hide_existing and bool(it.data(QtCore.Qt.UserRole + 1) or False): + it.setHidden(True) + continue + if not q: + it.setHidden(False) + continue + it.setHidden(q not in it.text().lower()) + + def _refilter_current(self) -> None: + tab = self.tabs.currentWidget() + if tab is None: + return + lst = self._tab_list.get(tab) + filt = self._tab_filter.get(tab) + if lst is None or filt is None: + return + self._apply_filter(lst, filt.text()) + + def _add_tab(self, title: str, kind: str, items: list[tuple[str, str, str]], *, extra=None) -> None: + tab = QWidget() + layout = QVBoxLayout(tab) + + if extra is not None: + extra(layout) + + filt = QLineEdit() + filt.setPlaceholderText("Filter...") + layout.addWidget(filt) + + lst = QListWidget() + lst.setSelectionMode(QAbstractItemView.ExtendedSelection) + for entry in items: + label = str(entry[0]) if len(entry) > 0 else "" + value = str(entry[1]) if len(entry) > 1 else "" + tip = str(entry[2]) if len(entry) > 2 else "" + in_vpn, in_direct = self._mark_state(kind, value) + flags = [] + if in_vpn: + flags.append("VPN") + if in_direct: + flags.append("DIRECT") + if flags: + label = f"{label} [{' + '.join(flags)}]" + + it = QListWidgetItem(label) + it.setData(QtCore.Qt.UserRole, value) + it.setData(QtCore.Qt.UserRole + 1, bool(in_vpn or in_direct)) + if tip.strip(): + extra_tip = ( + f"\n\nAlready in Force VPN: {'yes' if in_vpn else 'no'}\n" + f"Already in Force Direct: {'yes' if in_direct else 'no'}" + ) + it.setToolTip(tip + extra_tip) + if in_vpn or in_direct: + it.setForeground(QtGui.QBrush(QtGui.QColor("gray"))) + lst.addItem(it) + + layout.addWidget(lst, stretch=1) + filt.textChanged.connect(lambda txt, l=lst: self._apply_filter(l, txt)) + + self.tabs.addTab(tab, title) + self._tab_kind[tab] = kind + self._tab_list[tab] = lst + self._tab_filter[tab] = filt + + def _current_kind_and_list(self) -> tuple[str, QListWidget | None]: + tab = self.tabs.currentWidget() + if tab is None: + return "", None + return self._tab_kind.get(tab, ""), self._tab_list.get(tab) + + def _add_selected(self, target: str) -> None: + kind, lst = self._current_kind_and_list() + if not kind or lst is None: + return + + vals: list[str] = [] + for it in lst.selectedItems(): + v = it.data(QtCore.Qt.UserRole) + vv = str(v or "").strip() + if vv: + vals.append(vv) + + # stable de-dupe + out: list[str] = [] + seen: set[str] = set() + for v in vals: + if v in seen: + continue + seen.add(v) + out.append(v) + + if out: + self.add_cb(target, kind, out) + + def _list_for_title(self, title: str) -> QListWidget | None: + for i in range(self.tabs.count()): + if self.tabs.tabText(i) == title: + tab = self.tabs.widget(i) + return self._tab_list.get(tab) + return None + + def _preset_clear_selection(self, title: str) -> None: + lst = self._list_for_title(title) + if lst is not None: + lst.clearSelection() + + def _preset_select_services(self, keywords: list[str]) -> None: + lst = self._list_for_title("Services") + if lst is None: + return + keys = [str(k).strip().lower() for k in (keywords or []) if str(k).strip()] + if not keys: + return + lst.clearSelection() + for i in range(lst.count()): + it = lst.item(i) + if it is None: + continue + txt = (it.text() or "").lower() + val = str(it.data(QtCore.Qt.UserRole) or "").lower() + if any(k in txt or k in val for k in keys): + it.setSelected(True) + + def _preset_select_uids(self, uids: list[int]) -> None: + lst = self._list_for_title("UIDs") + if lst is None: + return + want = {f"{int(u)}-{int(u)}" for u in (uids or [])} + if not want: + return + lst.clearSelection() + for i in range(lst.count()): + it = lst.item(i) + if it is None: + continue + token = str(it.data(QtCore.Qt.UserRole) or "").strip() + if token in want: + it.setSelected(True) + + def _build_subnets_tab(self) -> None: + subs = list(getattr(self.cands, "subnets", []) or []) + + items: list[tuple[str, str, str]] = [] + for s in subs: + cidr = str(getattr(s, "cidr", "") or "").strip() + if not cidr: + continue + dev = str(getattr(s, "dev", "") or "").strip() + kind = str(getattr(s, "kind", "") or "").strip() + linkdown = bool(getattr(s, "linkdown", False)) + tags = [] + if kind: + tags.append(kind) + if dev: + tags.append(dev) + if linkdown: + tags.append("linkdown") + tag_txt = " " + "[" + ", ".join(tags) + "]" if tags else "" + tip = ( + f"CIDR: {cidr}\n" + f"kind={kind or '-'} dev={dev or '-'} linkdown={linkdown}\n\n" + "EN: Source subnet overrides affect forwarded traffic (Docker).\n" + "RU: Source subnet влияет на forwarded трафик (Docker)." + ) + items.append((f"{cidr}{tag_txt}", cidr, tip)) + + def extra(layout: QVBoxLayout) -> None: + row = QHBoxLayout() + btn_lan = QPushButton("Keep LAN direct") + btn_lan.clicked.connect(lambda: self._preset_add_lan_direct()) + row.addWidget(btn_lan) + btn_docker = QPushButton("Keep Docker direct") + btn_docker.clicked.connect(lambda: self._preset_add_docker_direct()) + row.addWidget(btn_docker) + row.addStretch(1) + layout.addLayout(row) + + self._add_tab("Subnets", "subnet", items, extra=extra) + + def _preset_add_lan_direct(self) -> None: + subs = list(getattr(self.cands, "subnets", []) or []) + vals: list[str] = [] + for s in subs: + kind = str(getattr(s, "kind", "") or "").strip() + cidr = str(getattr(s, "cidr", "") or "").strip() + if not cidr: + continue + if kind in ("lan", "link"): + vals.append(cidr) + if vals: + self.add_cb("direct", "subnet", vals) + + def _preset_add_docker_direct(self) -> None: + subs = list(getattr(self.cands, "subnets", []) or []) + vals: list[str] = [] + for s in subs: + kind = str(getattr(s, "kind", "") or "").strip() + cidr = str(getattr(s, "cidr", "") or "").strip() + if not cidr: + continue + if kind == "docker": + vals.append(cidr) + if vals: + self.add_cb("direct", "subnet", vals) + + def _build_services_tab(self) -> None: + units = list(getattr(self.cands, "units", []) or []) + + items: list[tuple[str, str, str]] = [] + for u in units: + unit = str(getattr(u, "unit", "") or "").strip() + if not unit: + continue + desc = str(getattr(u, "description", "") or "").strip() + cgroup = str(getattr(u, "cgroup", "") or "").strip() or unit + label = unit + if desc: + label += " - " + desc + tip = ( + f"Unit: {unit}\n" + f"Cgroup token: {cgroup}\n\n" + "EN: Adds a cgroup override; backend resolves it to UID rules at apply time.\n" + "RU: Добавляет cgroup override; backend резолвит его в UID правила при применении." + ) + items.append((label, cgroup, tip)) + + def extra(layout: QVBoxLayout) -> None: + row = QHBoxLayout() + btn_docker = QPushButton("Select docker/container") + btn_docker.clicked.connect(lambda: self._preset_select_services(["docker", "containerd", "podman"])) + row.addWidget(btn_docker) + btn_media = QPushButton("Select media (jellyfin/plex)") + btn_media.clicked.connect(lambda: self._preset_select_services(["jellyfin", "plex", "emby"])) + row.addWidget(btn_media) + btn_clear = QPushButton("Clear selection") + btn_clear.clicked.connect(lambda: self._preset_clear_selection("Services")) + row.addWidget(btn_clear) + row.addStretch(1) + layout.addLayout(row) + + self._add_tab("Services", "cgroup", items, extra=extra) + + def _build_uids_tab(self) -> None: + uids = list(getattr(self.cands, "uids", []) or []) + + items: list[tuple[str, str, str]] = [] + for u in uids: + try: + uid = int(getattr(u, "uid", 0) or 0) + except Exception: + continue + user = str(getattr(u, "user", "") or "").strip() + examples = list(getattr(u, "examples", []) or []) + ex_txt = ", ".join([str(x) for x in examples if str(x).strip()]) + + label = str(uid) + if user: + label += f" ({user})" + if ex_txt: + label += " - " + ex_txt + + token = f"{uid}-{uid}" + tip = ( + f"UID: {uid}\n" + f"User: {user or '-'}\n" + f"Examples: {ex_txt or '-'}\n\n" + "EN: UID rules affect host-local processes (OUTPUT).\n" + "RU: UID правила влияют на процессы хоста (OUTPUT)." + ) + items.append((label, token, tip)) + + def extra(layout: QVBoxLayout) -> None: + row = QHBoxLayout() + btn_me = QPushButton("Select my UID") + btn_me.clicked.connect(lambda: self._preset_select_uids([os.getuid()])) + row.addWidget(btn_me) + btn_root = QPushButton("Select root UID") + btn_root.clicked.connect(lambda: self._preset_select_uids([0])) + row.addWidget(btn_root) + btn_clear = QPushButton("Clear selection") + btn_clear.clicked.connect(lambda: self._preset_clear_selection("UIDs")) + row.addWidget(btn_clear) + row.addStretch(1) + layout.addLayout(row) + + self._add_tab("UIDs", "uid", items, extra=extra) diff --git a/selective-vpn-gui/vpn-dashboard.py b/selective-vpn-gui/vpn-dashboard.py new file mode 100755 index 0000000..ebf88ab --- /dev/null +++ b/selective-vpn-gui/vpn-dashboard.py @@ -0,0 +1,901 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +Selective-VPN Dashboard (UI only) + +RULES: +- This file must NOT know anything about REST paths, HTTP methods, or JSON keys. +- It talks ONLY to DashboardController (which uses ApiClient). +""" + +from __future__ import annotations + +import re +import subprocess +import sys +import tkinter as tk +from tkinter import messagebox +from tkinter import ttk +from typing import Literal, Optional, cast, Tuple + +from api_client import ApiClient, DnsUpstreams +from dashboard_controller import DashboardController + +TraceMode = Literal["full", "gui", "smartdns"] + +# убираем спам автопроверки из логов UI (на всякий случай, даже если почистил controller) +_NEXT_CHECK_RE = re.compile(r"(?:\b\d+s\.)?\s*Next check in\s+\d+s\.?", re.IGNORECASE) + + +class App(ttk.Frame): + def __init__(self, master: tk.Tk, ctrl: DashboardController) -> None: + super().__init__(master) + self.master = master + self.ctrl = ctrl + + # login-flow runtime + self._login_flow_active: bool = False + self._login_cursor: int = 0 + self._login_url_opened: bool = False + self._login_poll_after_id: Optional[str] = None + + self._build_ui() + self._wire_events() + + self.after(50, self.refresh_everything) + self.master.protocol("WM_DELETE_WINDOW", self._on_close) + + # ---------------- UI BUILD ---------------- + + def _build_ui(self) -> None: + self.master.title("Selective-VPN Dashboard") + self.pack(fill="both", expand=True) + + # Top bar + top = ttk.Frame(self) + top.pack(fill="x", padx=10, pady=(10, 6)) + + self.btn_refresh = ttk.Button(top, text="Refresh all", command=self.refresh_everything) + self.btn_refresh.pack(side="left") + + # Login indicator (dot + text) + self.login_dot = tk.Canvas(top, width=12, height=12, highlightthickness=0) + self.login_dot.pack(side="left", padx=(12, 4)) + self._login_dot_id = self.login_dot.create_oval(2, 2, 10, 10, fill="gray", outline="") + + self.lbl_login = ttk.Label(top, text="AdGuard VPN: ...", font=("TkDefaultFont", 10, "bold")) + self.lbl_login.pack(side="left", padx=(0, 10)) + + # Single auth button (Login/Logout) + self.btn_auth = ttk.Button(top, text="Login", command=self.on_auth_button) + self.btn_auth.pack(side="left") + + self.lbl_hint = ttk.Label(top, text="(GUI contains no API logic)", foreground="gray") + self.lbl_hint.pack(side="right") + + # Notebook + self.nb = ttk.Notebook(self) + self.nb.pack(fill="both", expand=True, padx=10, pady=(0, 10)) + + self._build_tab_status() + self._build_tab_vpn() + self._build_tab_routes() + self._build_tab_dns() + self._build_tab_domains() + self._build_tab_trace() + + def _build_tab_status(self) -> None: + tab = ttk.Frame(self.nb) + self.nb.add(tab, text="Status") + + frm = ttk.Frame(tab) + frm.pack(fill="both", expand=True, padx=10, pady=10) + + grid = ttk.Frame(frm) + grid.pack(fill="x") + + def row(r: int, label: str) -> ttk.Label: + ttk.Label(grid, text=label).grid(row=r, column=0, sticky="w", pady=2) + v = ttk.Label(grid, text="—") + v.grid(row=r, column=1, sticky="w", pady=2, padx=(10, 0)) + return v + + self.st_timestamp = row(0, "Timestamp") + self.st_counts = row(1, "Counts") + self.st_iface = row(2, "Iface/Table/Mark") + self.st_route = row(3, "Policy route") + self.st_routesvc = row(4, "Routes service") + self.st_smartdns = row(5, "SmartDNS service") + self.st_vpnsvc = row(6, "VPN service") + + btns = ttk.Frame(frm) + btns.pack(fill="x", pady=(10, 0)) + ttk.Button(btns, text="Refresh status", command=self.refresh_status_tab).pack(side="left") + + def _build_tab_vpn(self) -> None: + tab = ttk.Frame(self.nb) + self.nb.add(tab, text="AdGuardVPN") + + # Pages container + self.vpn_pages = ttk.Frame(tab) + self.vpn_pages.pack(fill="both", expand=True, padx=10, pady=10) + + self.vpn_page_main = ttk.Frame(self.vpn_pages) + self.vpn_page_login = ttk.Frame(self.vpn_pages) + + for p in (self.vpn_page_main, self.vpn_page_login): + p.grid(row=0, column=0, sticky="nsew") + self.vpn_pages.rowconfigure(0, weight=1) + self.vpn_pages.columnconfigure(0, weight=1) + + # -------- Page 1: main VPN controls (Enter Login removed) -------- + frm = self.vpn_page_main + + top_actions = ttk.Frame(frm) + top_actions.pack(fill="x", pady=(0, 10)) + ttk.Button(top_actions, text="Refresh", command=self.refresh_vpn_tab).pack(side="right") + + # Autoconnect toggle + ac = ttk.LabelFrame(frm, text="Auto-connect") + ac.pack(fill="x") + + self.var_autoconnect = tk.BooleanVar(value=False) + self.chk_autoconnect = ttk.Checkbutton( + ac, + text="Enable auto-connect", + variable=self.var_autoconnect, + command=self.on_toggle_autoconnect, + ) + self.chk_autoconnect.pack(side="left", padx=10, pady=8) + + # Location picker + loc = ttk.LabelFrame(frm, text="Location") + loc.pack(fill="x", pady=(10, 0)) + + self.cmb_location = ttk.Combobox(loc, state="readonly", width=40) + self.cmb_location.pack(side="left", padx=10, pady=8) + self.btn_set_location = ttk.Button(loc, text="Set location", command=self.on_set_location) + self.btn_set_location.pack(side="left", padx=6, pady=8) + + self.lbl_vpn_desired = ttk.Label(loc, text="Desired: —", foreground="gray") + self.lbl_vpn_desired.pack(side="left", padx=12) + + # Status output + st = ttk.LabelFrame(frm, text="VPN Status") + st.pack(fill="both", expand=True, pady=(10, 0)) + + self.txt_vpn = tk.Text(st, height=12, wrap="none") + self.txt_vpn.pack(fill="both", expand=True, padx=10, pady=10) + + # -------- Page 2: Login flow -------- + lf = self.vpn_page_login + + lf_top = ttk.Frame(lf) + lf_top.pack(fill="x", pady=(0, 10)) + + ttk.Button(lf_top, text="← Back", command=self.on_login_back).pack(side="left") + + self.login_flow_dot = tk.Canvas(lf_top, width=14, height=14, highlightthickness=0) + self.login_flow_dot.pack(side="left", padx=(10, 4)) + self._login_flow_dot_id = self.login_flow_dot.create_oval(2, 2, 12, 12, fill="orange", outline="") + + self.lbl_login_flow_status = ttk.Label(lf_top, text="Status: —", font=("TkDefaultFont", 10, "bold")) + self.lbl_login_flow_status.pack(side="left", padx=(0, 10)) + + self.lbl_login_flow_email = ttk.Label(lf_top, text="", foreground="gray") + self.lbl_login_flow_email.pack(side="left") + + url_row = ttk.Frame(lf) + url_row.pack(fill="x", pady=(0, 10)) + + ttk.Label(url_row, text="URL:").pack(side="left") + self.var_login_url = tk.StringVar(value="") + self.ent_login_url = ttk.Entry(url_row, textvariable=self.var_login_url, state="readonly") + self.ent_login_url.pack(side="left", fill="x", expand=True, padx=8) + + self.btn_login_copy = ttk.Button(url_row, text="Copy", command=self.on_login_copy) + self.btn_login_copy.pack(side="left", padx=(0, 6)) + + self.btn_login_open = ttk.Button(url_row, text="Open", command=self.on_login_open) + self.btn_login_open.pack(side="left") + + ctrl_row = ttk.Frame(lf) + ctrl_row.pack(fill="x", pady=(0, 10)) + + self.btn_login_check = ttk.Button(ctrl_row, text="Check", command=self.on_login_check) + self.btn_login_check.pack(side="left") + + self.btn_login_close = ttk.Button(ctrl_row, text="Close (cancel)", command=self.on_login_cancel) + self.btn_login_close.pack(side="left", padx=6) + + self.btn_login_stop = ttk.Button(ctrl_row, text="Stop (force)", command=self.on_login_stop) + self.btn_login_stop.pack(side="left", padx=6) + + # Log output + out = ttk.LabelFrame(lf, text="Login output") + out.pack(fill="both", expand=True) + + self.txt_login_flow = tk.Text(out, wrap="word", height=16) + self.txt_login_flow.pack(fill="both", expand=True, padx=10, pady=10) + + self._show_vpn_page("main") + + def _build_tab_routes(self) -> None: + tab = ttk.Frame(self.nb) + self.nb.add(tab, text="Routes") + + frm = ttk.Frame(tab) + frm.pack(fill="both", expand=True, padx=10, pady=10) + + svc = ttk.LabelFrame(frm, text="Routes service") + svc.pack(fill="x") + + ttk.Button(svc, text="Start", command=lambda: self.on_routes_action("start")).pack(side="left", padx=10, pady=8) + ttk.Button(svc, text="Stop", command=lambda: self.on_routes_action("stop")).pack(side="left", padx=6, pady=8) + ttk.Button(svc, text="Restart", command=lambda: self.on_routes_action("restart")).pack(side="left", padx=6, pady=8) + + ttk.Button(svc, text="Clear routes", command=self.on_routes_clear).pack(side="right", padx=10, pady=8) + + timer = ttk.LabelFrame(frm, text="Timer") + timer.pack(fill="x", pady=(10, 0)) + + self.var_timer = tk.BooleanVar(value=False) + self.chk_timer = ttk.Checkbutton(timer, text="Enable timer", variable=self.var_timer, command=self.on_toggle_timer) + self.chk_timer.pack(side="left", padx=10, pady=8) + + ttk.Button(timer, text="Fix policy route", command=self.on_fix_policy_route).pack(side="right", padx=10, pady=8) + + out = ttk.LabelFrame(frm, text="Output") + out.pack(fill="both", expand=True, pady=(10, 0)) + + self.txt_routes = tk.Text(out, height=12, wrap="none") + self.txt_routes.pack(fill="both", expand=True, padx=10, pady=10) + + def _build_tab_dns(self) -> None: + tab = ttk.Frame(self.nb) + self.nb.add(tab, text="DNS") + + frm = ttk.Frame(tab) + frm.pack(fill="both", expand=True, padx=10, pady=10) + + ups = ttk.LabelFrame(frm, text="Upstreams") + ups.pack(fill="x") + + def add_field(r: int, label: str) -> ttk.Entry: + ttk.Label(ups, text=label).grid(row=r, column=0, sticky="w", padx=10, pady=4) + e = ttk.Entry(ups, width=60) + e.grid(row=r, column=1, sticky="we", padx=10, pady=4) + return e + + ups.columnconfigure(1, weight=1) + self.ent_def1 = add_field(0, "default1") + self.ent_def2 = add_field(1, "default2") + self.ent_meta1 = add_field(2, "meta1") + self.ent_meta2 = add_field(3, "meta2") + + btns = ttk.Frame(frm) + btns.pack(fill="x", pady=(10, 0)) + ttk.Button(btns, text="Refresh", command=self.refresh_dns_tab).pack(side="left") + ttk.Button(btns, text="Save", command=self.on_save_upstreams).pack(side="left", padx=6) + + sm = ttk.LabelFrame(frm, text="SmartDNS") + sm.pack(fill="both", expand=True, pady=(10, 0)) + + top = ttk.Frame(sm) + top.pack(fill="x", padx=10, pady=(10, 6)) + + self.lbl_smartdns_state = ttk.Label(top, text="Service: —") + self.lbl_smartdns_state.pack(side="left") + + ttk.Button(top, text="Start", command=lambda: self.on_smartdns_action("start")).pack(side="right", padx=6) + ttk.Button(top, text="Stop", command=lambda: self.on_smartdns_action("stop")).pack(side="right") + + mid = ttk.Frame(sm) + mid.pack(fill="both", expand=True, padx=10, pady=(0, 10)) + + ttk.Label(mid, text="Wildcards (one per line):").pack(anchor="w") + self.txt_wildcards = tk.Text(mid, height=10, wrap="none") + self.txt_wildcards.pack(fill="both", expand=True, pady=(4, 6)) + + btns2 = ttk.Frame(mid) + btns2.pack(fill="x") + ttk.Button(btns2, text="Refresh", command=self.refresh_dns_tab).pack(side="left") + ttk.Button(btns2, text="Save", command=self.on_save_wildcards).pack(side="left", padx=6) + + def _build_tab_domains(self) -> None: + tab = ttk.Frame(self.nb) + self.nb.add(tab, text="Domains") + + frm = ttk.Frame(tab) + frm.pack(fill="both", expand=True, padx=10, pady=10) + + left = ttk.Frame(frm) + left.pack(side="left", fill="y") + + right = ttk.Frame(frm) + right.pack(side="left", fill="both", expand=True, padx=(10, 0)) + + ttk.Label(left, text="Files:").pack(anchor="w") + self.lst_files = tk.Listbox(left, height=6, exportselection=False) + for name in ("bases", "meta", "subs", "static"): + self.lst_files.insert("end", name) + self.lst_files.selection_set(0) + self.lst_files.pack(fill="y", pady=(4, 8)) + + ttk.Button(left, text="Refresh table", command=self.refresh_domains_tab).pack(fill="x") + ttk.Button(left, text="Load file", command=self.on_domains_load).pack(fill="x", pady=(6, 0)) + ttk.Button(left, text="Save file", command=self.on_domains_save).pack(fill="x", pady=(6, 0)) + ttk.Button(left, text="Load AGVPN table", command=self.on_load_agvpn_table).pack(fill="x", pady=(10, 0)) + ttk.Button(left, text="Load SmartDNS table", command=self.on_load_smartdns_table).pack(fill="x", pady=(6, 0)) + + top = ttk.Frame(right) + top.pack(fill="x") + self.lbl_domains_info = ttk.Label(top, text="—", foreground="gray") + self.lbl_domains_info.pack(side="left") + + self.txt_domains = tk.Text(right, wrap="none") + self.txt_domains.pack(fill="both", expand=True, pady=(6, 0)) + + def _build_tab_trace(self) -> None: + tab = ttk.Frame(self.nb) + self.nb.add(tab, text="Trace") + + frm = ttk.Frame(tab) + frm.pack(fill="both", expand=True, padx=10, pady=10) + + top = ttk.Frame(frm) + top.pack(fill="x") + + self.var_trace_mode = tk.StringVar(value="full") + for m, title in (("full", "Full"), ("gui", "GUI"), ("smartdns", "SmartDNS")): + ttk.Radiobutton(top, text=title, value=m, variable=self.var_trace_mode, command=self.refresh_trace_tab).pack( + side="left", padx=(0, 10) + ) + ttk.Button(top, text="Refresh", command=self.refresh_trace_tab).pack(side="right") + + self.txt_trace = tk.Text(frm, wrap="none") + self.txt_trace.pack(fill="both", expand=True, pady=(10, 0)) + + def _wire_events(self) -> None: + self.lst_files.bind("<>", lambda _e: self.on_domains_load()) + + # ---------------- UI HELPERS ---------------- + + def _set_text(self, widget: tk.Text, text: str) -> None: + widget.config(state="normal") + widget.delete("1.0", "end") + widget.insert("1.0", text) + widget.config(state="normal") + + def _append_text(self, widget: tk.Text, text: str) -> None: + widget.config(state="normal") + widget.insert("end", text) + widget.see("end") + widget.config(state="normal") + + def _clean_ui_lines(self, lines) -> str: + # финальная страховка: убираем "Next check" и нормализуем \r + buf = "\n".join([str(x) for x in (lines or [])]).replace("\r", "\n") + out_lines = [] + for ln in buf.splitlines(): + t = ln.strip() + if not t: + continue + t2 = _NEXT_CHECK_RE.sub("", t).strip() + if not t2: + continue + out_lines.append(t2) + return "\n".join(out_lines).rstrip() + + def _get_selected_domains_file(self) -> str: + sel = self.lst_files.curselection() + if not sel: + return "bases" + return str(self.lst_files.get(sel[0])) + + def _read_local_file(self, path: str) -> str: + try: + with open(path, "r", encoding="utf-8", errors="ignore") as f: + return f.read() + except Exception: + return "" + + def _safe(self, fn, *, title: str = "Error"): + try: + return fn() + except Exception as e: + messagebox.showerror(title, str(e)) + return None + + def _set_dot(self, canvas: tk.Canvas, dot_id: int, color: str) -> None: + c = (color or "").strip().lower() + if c in ("green", "ok", "true"): + fill = "green" + elif c in ("red", "error", "false"): + fill = "red" + elif c in ("orange", "yellow", "try", "unknown", "pending", "wait"): + fill = "orange" + else: + fill = "gray" + try: + canvas.itemconfigure(dot_id, fill=fill) + except Exception: + pass + + def _show_vpn_page(self, which: Literal["main", "login"]) -> None: + if which == "login": + self.vpn_page_login.tkraise() + else: + self.vpn_page_main.tkraise() + + def _parse_login_banner(self, text: str, color: str) -> Tuple[bool, str]: + # считаем "logged" если зеленый + is_logged = (color or "").strip().lower() == "green" + email = "" + t = (text or "") + # пытаемся вытащить email из строки + m = re.search(r"([A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,})", t) + if m: + email = m.group(1) + return is_logged, email + + def _set_auth_button(self, logged: bool) -> None: + self.btn_auth.config(text=("Logout" if logged else "Login")) + + # ---------------- REFRESH ---------------- + + def refresh_everything(self) -> None: + self.refresh_status_tab() + self.refresh_vpn_tab() + self.refresh_routes_tab() + self.refresh_dns_tab() + self.refresh_domains_tab() + self.refresh_trace_tab() + self.refresh_login_banner() + + def refresh_login_banner(self) -> None: + def work(): + view = self.ctrl.get_login_view() + self.lbl_login.config(text=view.text) + self._set_dot(self.login_dot, self._login_dot_id, view.color) + + # НЕ гадаем по цвету: используем нормализованную логику controller-а + self._set_auth_button(bool(view.logged_in)) + + try: + self.lbl_login.config(foreground=view.color) + except tk.TclError: + pass + + self._safe(work, title="Login state error") + + def refresh_status_tab(self) -> None: + def work(): + view = self.ctrl.get_status_overview() + self.st_timestamp.config(text=view.timestamp) + self.st_counts.config(text=view.counts) + self.st_iface.config(text=view.iface_table_mark) + self.st_route.config(text=view.policy_route) + self.st_routesvc.config(text=view.routes_service) + self.st_smartdns.config(text=view.smartdns_service) + self.st_vpnsvc.config(text=view.vpn_service) + self._safe(work, title="Status error") + + def refresh_vpn_tab(self) -> None: + def work(): + locs = self.ctrl.vpn_locations_view() + self.cmb_location["values"] = [f"{x.iso} — {x.label}" for x in locs] + + st = self.ctrl.vpn_status_view() + self.lbl_vpn_desired.config(text=f"Desired: {st.desired_location or '—'}") + self._set_text(self.txt_vpn, st.pretty_text) + + self.var_autoconnect.set(self.ctrl.vpn_autoconnect_enabled()) + self._safe(work, title="VPN error") + + def refresh_routes_tab(self) -> None: + def work(): + self.var_timer.set(self.ctrl.routes_timer_enabled()) + self._safe(work, title="Routes error") + + def refresh_dns_tab(self) -> None: + def work(): + cfg = self.ctrl.dns_upstreams_view() + self.ent_def1.delete(0, "end"); self.ent_def1.insert(0, cfg.default1) + self.ent_def2.delete(0, "end"); self.ent_def2.insert(0, cfg.default2) + self.ent_meta1.delete(0, "end"); self.ent_meta1.insert(0, cfg.meta1) + self.ent_meta2.delete(0, "end"); self.ent_meta2.insert(0, cfg.meta2) + + sd = self.ctrl.smartdns_service_view() + self.lbl_smartdns_state.config(text=f"Service: {sd.state}") + + wc = self.ctrl.smartdns_wildcards_view() + self._set_text(self.txt_wildcards, "\n".join(wc.domains).strip() + ("\n" if wc.domains else "")) + self._safe(work, title="DNS error") + + def refresh_domains_tab(self) -> None: + def work(): + table = self.ctrl.domains_table_view() + self.lbl_domains_info.config(text=f"Table lines: {len(table.lines)}") + self._safe(work, title="Domains error") + + def refresh_trace_tab(self) -> None: + def work(): + mode = cast(TraceMode, self.var_trace_mode.get()) + dump = self.ctrl.trace_view(mode) + self._set_text(self.txt_trace, "\n".join(dump.lines).strip() + ("\n" if dump.lines else "")) + self._safe(work, title="Trace error") + + # ---------------- LOGIN FLOW (UI) ---------------- + + def _login_flow_reset_ui(self) -> None: + self._login_cursor = 0 + self._login_url_opened = False + self.var_login_url.set("") + self.lbl_login_flow_status.config(text="Status: —") + self.lbl_login_flow_email.config(text="") + self._set_dot(self.login_flow_dot, self._login_flow_dot_id, "orange") + self._set_text(self.txt_login_flow, "") + + def _login_flow_set_buttons(self, *, can_open: bool, can_check: bool, can_cancel: bool) -> None: + def set_state(btn: ttk.Button, enabled: bool) -> None: + try: + btn.config(state=("normal" if enabled else "disabled")) + except Exception: + pass + + set_state(self.btn_login_open, can_open) + set_state(self.btn_login_copy, bool(self.var_login_url.get().strip())) + set_state(self.btn_login_check, can_check) + set_state(self.btn_login_close, can_cancel) + + # stop — страховка, но если уже success/already_logged, можно тоже выключить не обязательно + try: + self.btn_login_stop.config(state="normal") + except Exception: + pass + + def _login_flow_autopoll_start(self) -> None: + self._login_flow_active = True + self._login_poll_tick() + + def _login_flow_autopoll_stop(self) -> None: + self._login_flow_active = False + if self._login_poll_after_id is not None: + try: + self.after_cancel(self._login_poll_after_id) + except Exception: + pass + self._login_poll_after_id = None + + def _login_poll_tick(self) -> None: + if not self._login_flow_active: + return + + def work(): + view = self.ctrl.login_flow_poll(self._login_cursor) + self._login_cursor = int(view.cursor) + + # indicator + status + self._set_dot(self.login_flow_dot, self._login_flow_dot_id, view.dot_color) + self.lbl_login_flow_status.config(text=f"Status: {view.status_text or '—'}") + self.lbl_login_flow_email.config(text=(f"User: {view.email}" if view.email else "")) + + if view.url: + self.var_login_url.set(view.url) + + # buttons + self._login_flow_set_buttons(can_open=view.can_open, can_check=view.can_check, can_cancel=view.can_cancel) + + # append cleaned lines + cleaned = self._clean_ui_lines(view.lines) + if cleaned: + self._append_text(self.txt_login_flow, cleaned + "\n") + + # auto-open browser once when url appears + if (not self._login_url_opened) and view.url: + self._login_url_opened = True + try: + subprocess.Popen(["xdg-open", view.url]) + except Exception: + pass + + phase = (view.phase or "").strip().lower() + if (not view.alive) or phase in ("success", "failed", "cancelled", "already_logged"): + # Авто-обновляем баннер при успехе/уже залогинен + if phase in ("success", "already_logged"): + self.after(250, self.refresh_login_banner) + # и возвращаемся на main страницу VPN, чтобы UX был как у тебя на примере + self.after(500, lambda: self._show_vpn_page("main")) + + # на терминале — стопаем polling + self._login_flow_autopoll_stop() + + # в терминале делаем кнопки логина неактивными (как в твоём "идеальном" окне) + self._login_flow_set_buttons(can_open=False, can_check=False, can_cancel=False) + try: + self.btn_login_stop.config(state="disabled") + except Exception: + pass + + self._safe(work, title="Login flow error") + + if self._login_flow_active: + self._login_poll_after_id = self.after(250, self._login_poll_tick) + + # ---------------- TOP AUTH BUTTON ---------------- + + def on_auth_button(self) -> None: + # decide based on current banner + def work(): + view = self.ctrl.get_login_view() + if bool(view.logged_in): + self.on_logout() + else: + self.on_start_login() + + self._safe(work, title="Auth error") + + # ---------------- ACTIONS ---------------- + + def on_start_login(self) -> None: + def work(): + self.ctrl.log_gui("Top Login clicked") + self._login_flow_reset_ui() + + start = self.ctrl.login_flow_start() + + # reflect start info + self._set_dot(self.login_flow_dot, self._login_flow_dot_id, start.dot_color) + self.lbl_login_flow_status.config(text=f"Status: {start.status_text or '—'}") + self.lbl_login_flow_email.config(text=(f"User: {start.email}" if start.email else "")) + + if start.url: + self.var_login_url.set(start.url) + + cleaned = self._clean_ui_lines(start.lines) + if cleaned: + self._append_text(self.txt_login_flow, cleaned + "\n") + + # already logged: banner update and stop + phase = (start.phase or "").strip().lower() + if phase == "already_logged": + self.refresh_login_banner() + messagebox.showinfo("Login", f"Already logged in{f' as {start.email}' if start.email else ''}.") + return + + # show login page and start polling + self._show_vpn_page("login") + self._login_cursor = int(start.cursor or 0) + self._login_flow_set_buttons(can_open=start.can_open, can_check=start.can_check, can_cancel=start.can_cancel) + self._login_flow_autopoll_start() + + self._safe(work, title="Login start error") + + def on_login_back(self) -> None: + self._login_flow_autopoll_stop() + self._show_vpn_page("main") + self.refresh_login_banner() + + def on_login_copy(self) -> None: + u = self.var_login_url.get().strip() + if not u: + return + try: + self.master.clipboard_clear() + self.master.clipboard_append(u) + except Exception: + pass + + def on_login_open(self) -> None: + def work(): + self.ctrl.login_flow_action("open") + u = self.var_login_url.get().strip() + if u: + try: + subprocess.Popen(["xdg-open", u]) + except Exception: + pass + self.ctrl.log_gui("Login flow: open") + self._safe(work, title="Login open error") + + def on_login_check(self) -> None: + def work(): + self.ctrl.login_flow_action("check") + self.ctrl.log_gui("Login flow: check") + self._safe(work, title="Login check error") + + def on_login_cancel(self) -> None: + def work(): + self.ctrl.login_flow_action("cancel") + self.ctrl.log_gui("Login flow: cancel") + self._safe(work, title="Login cancel error") + + def on_login_stop(self) -> None: + def work(): + self.ctrl.login_flow_stop() + self.ctrl.log_gui("Login flow: stop") + self._login_flow_autopoll_stop() + self.after(250, self.refresh_login_banner) + self._safe(work, title="Login stop error") + + def on_logout(self) -> None: + def work(): + if not messagebox.askyesno("Logout", "Logout from AdGuard VPN account?"): + return + res = self.ctrl.vpn_logout() + self.ctrl.log_gui("VPN logout executed") + messagebox.showinfo("Logout", res.pretty_text.strip() or "Done.") + self.refresh_login_banner() + self.refresh_vpn_tab() + self._safe(work, title="Logout error") + + def on_toggle_autoconnect(self) -> None: + def work(): + enable = bool(self.var_autoconnect.get()) + res = self.ctrl.vpn_set_autoconnect(enable) + self._set_text(self.txt_vpn, res.pretty_text) + self.ctrl.log_gui(f"Auto-connect set to {enable}") + self._safe(work, title="Auto-connect error") + + def on_set_location(self) -> None: + def work(): + val = self.cmb_location.get().strip() + if not val: + messagebox.showinfo("Location", "Choose a location first.") + return + iso = val.split("—", 1)[0].strip() + res = self.ctrl.vpn_set_location(iso) + self._set_text(self.txt_vpn, res.pretty_text) + self.ctrl.log_gui(f"Location set to {iso}") + self.refresh_vpn_tab() + self._safe(work, title="Set location error") + + def on_routes_action(self, action: str) -> None: + def work(): + res = self.ctrl.routes_service_action(action) + self._set_text(self.txt_routes, res.pretty_text) + self.ctrl.log_gui(f"Routes service: {action}") + self.refresh_status_tab() + self._safe(work, title="Routes service error") + + def on_routes_clear(self) -> None: + def work(): + res = self.ctrl.routes_clear() + self._set_text(self.txt_routes, res.pretty_text) + self.ctrl.log_gui("Routes cleared") + self.refresh_status_tab() + self._safe(work, title="Clear routes error") + + def on_toggle_timer(self) -> None: + def work(): + enabled = bool(self.var_timer.get()) + res = self.ctrl.routes_timer_set(enabled) + self._set_text(self.txt_routes, res.pretty_text) + self.ctrl.log_gui(f"Routes timer set to {enabled}") + self.refresh_status_tab() + self._safe(work, title="Timer error") + + def on_fix_policy_route(self) -> None: + def work(): + res = self.ctrl.routes_fix_policy_route() + self._set_text(self.txt_routes, res.pretty_text) + self.ctrl.log_gui("Policy route fix executed") + self.refresh_status_tab() + self._safe(work, title="Fix policy route error") + + def on_save_upstreams(self) -> None: + def work(): + cfg = DnsUpstreams( + default1=self.ent_def1.get().strip(), + default2=self.ent_def2.get().strip(), + meta1=self.ent_meta1.get().strip(), + meta2=self.ent_meta2.get().strip(), + ) + self.ctrl.dns_upstreams_save(cfg) + self.ctrl.log_gui("DNS upstreams saved") + messagebox.showinfo("DNS", "Saved.") + self.refresh_dns_tab() + self._safe(work, title="Save upstreams error") + + def on_smartdns_action(self, action: str) -> None: + def work(): + _res = self.ctrl.smartdns_service_action(action) + self.ctrl.log_gui(f"SmartDNS action: {action}") + self.refresh_dns_tab() + self.refresh_trace_tab() + self._safe(work, title="SmartDNS error") + + def on_save_wildcards(self) -> None: + def work(): + raw = self.txt_wildcards.get("1.0", "end") + domains = [x.strip() for x in raw.splitlines() if x.strip()] + self.ctrl.smartdns_wildcards_save(domains) + self.ctrl.log_gui(f"Wildcards saved: {len(domains)}") + messagebox.showinfo("SmartDNS", "Wildcards saved.") + self.refresh_dns_tab() + self._safe(work, title="Save wildcards error") + + def on_domains_load(self) -> None: + def work(): + name = self._get_selected_domains_file() + f = self.ctrl.domains_file_load(name) + content = f.content or "" + source = getattr(f, "source", "") or "file" + if not content: + path = f"/etc/selective-vpn/domains/{name}.txt" + content = self._read_local_file(path) + if content: + source = f"{source}+fallback" if source else "fallback" + self._set_text(self.txt_domains, content) + self.lbl_domains_info.config(text=f"{name} (source: {source})") + self.ctrl.log_gui(f"Domains file loaded: {name} source={source}") + self._safe(work, title="Load domains file error") + + def on_domains_save(self) -> None: + def work(): + name = self._get_selected_domains_file() + content = self.txt_domains.get("1.0", "end") + self.ctrl.domains_file_save(name, content) + self.ctrl.log_gui(f"Domains file saved: {name}") + messagebox.showinfo("Domains", "Saved.") + self.refresh_status_tab() + self._safe(work, title="Save domains file error") + + def on_load_agvpn_table(self) -> None: + path = "/var/lib/selective-vpn/last-ips-map.txt" + data = self._read_local_file(path) + self._set_text(self.txt_domains, data or "(empty)") + self.lbl_domains_info.config(text=f"AGVPN table: {path}") + + def on_load_smartdns_table(self) -> None: + path = "/etc/selective-vpn/smartdns.conf" + data = self._read_local_file(path) + self._set_text(self.txt_domains, data or "(empty)") + self.lbl_domains_info.config(text=f"SmartDNS table: {path}") + + # ---------------- CLOSE HANDLER ---------------- + + def _on_close(self) -> None: + def work(): + if self._login_flow_active: + try: + self.ctrl.login_flow_action("cancel") + except Exception: + pass + try: + self.ctrl.login_flow_stop() + except Exception: + pass + self._login_flow_autopoll_stop() + try: + work() + finally: + self.master.destroy() + + +def main() -> int: + client = ApiClient.from_env() + ctrl = DashboardController(client) + + root = tk.Tk() + try: + root.minsize(900, 650) + except Exception: + pass + + try: + style = ttk.Style() + if "clam" in style.theme_names(): + style.theme_use("clam") + except Exception: + pass + + _app = App(root, ctrl) + root.mainloop() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/selective-vpn-gui/vpn_dashboard_qt.py b/selective-vpn-gui/vpn_dashboard_qt.py new file mode 100755 index 0000000..4fef265 --- /dev/null +++ b/selective-vpn-gui/vpn_dashboard_qt.py @@ -0,0 +1,1515 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import re +import subprocess +import sys +import time +from typing import Literal + +from PySide6 import QtCore +from PySide6.QtCore import Qt, QSettings, QTimer +from PySide6.QtGui import QTextCursor +from PySide6.QtWidgets import ( + QApplication, + QComboBox, + QFormLayout, + QGroupBox, + QHBoxLayout, + QLabel, + QListView, + QListWidget, + QListWidgetItem, + QMainWindow, + QMessageBox, + QPushButton, + QPlainTextEdit, + QRadioButton, + QStackedWidget, + QTabWidget, + QVBoxLayout, + QWidget, + QLineEdit, + QCheckBox, + QProgressBar, +) + +from api_client import ApiClient, DnsUpstreams +from dashboard_controller import DashboardController, TraceMode +from traffic_mode_dialog import TrafficModeDialog + +_NEXT_CHECK_RE = re.compile(r"(?i)next check in \d+s") +LoginPage = Literal["main", "login"] + + +class EventThread(QtCore.QThread): + eventReceived = QtCore.Signal(object) + error = QtCore.Signal(str) + + def __init__(self, controller: DashboardController, parent=None) -> None: + super().__init__(parent) + self.ctrl = controller + self._stop = False + self._since = 0 + + def stop(self) -> None: + self._stop = True + + def run(self) -> None: # pragma: no cover - thread + while not self._stop: + try: + for ev in self.ctrl.iter_events(since=self._since, stop=lambda: self._stop): + if self._stop: + break + try: + self._since = int(getattr(ev, "id", self._since)) + except Exception: + pass + self.eventReceived.emit(ev) + # graceful end -> short delay + time.sleep(0.5) + except Exception as e: + self.error.emit(str(e)) + time.sleep(1.5) + + +class MainWindow(QMainWindow): + def __init__(self, controller: DashboardController) -> None: + super().__init__() + self.ctrl = controller + + self.setWindowTitle("Selective VPN Dashboard (Qt)") + self.resize(1024, 700) + + # login-flow state + self._login_flow_active: bool = False + self._login_cursor: int = 0 + self._login_url_opened: bool = False + self.events_thread: EventThread | None = None + self._routes_progress_last: int = 0 + self._dns_ui_refresh: bool = False + self._ui_settings = QSettings("AdGuardVPN", "SelectiveVPNDashboardQt") + + self.login_poll_timer = QTimer(self) + self.login_poll_timer.setInterval(250) + self.login_poll_timer.timeout.connect(self._login_poll_tick) + + self.dns_save_timer = QTimer(self) + self.dns_save_timer.setSingleShot(True) + self.dns_save_timer.setInterval(700) + self.dns_save_timer.timeout.connect(self._apply_dns_autosave) + + self._build_ui() + self._load_ui_preferences() + self.refresh_everything() + self._start_events_stream() + + # ---------------- UI BUILD ---------------- + + def _build_ui(self) -> None: + root = QWidget() + root_layout = QVBoxLayout(root) + root.setLayout(root_layout) + self.setCentralWidget(root) + + # top bar --------------------------------------------------------- + top = QHBoxLayout() + root_layout.addLayout(top) + + # клик по этому баннеру показывает whoami + self.btn_login_banner = QPushButton("AdGuard VPN: —") + self.btn_login_banner.setFlat(True) + self.btn_login_banner.setStyleSheet( + "text-align: left; border: none; color: gray;" + ) + self.btn_login_banner.clicked.connect(self.on_login_banner_clicked) + top.addWidget(self.btn_login_banner, stretch=1) + + self.btn_auth = QPushButton("Login") + self.btn_auth.clicked.connect(self.on_auth_button) + top.addWidget(self.btn_auth) + + self.btn_refresh_all = QPushButton("Refresh all") + self.btn_refresh_all.clicked.connect(self.refresh_everything) + top.addWidget(self.btn_refresh_all) + + # tabs ------------------------------------------------------------- + self.tabs = QTabWidget() + root_layout.addWidget(self.tabs, stretch=1) + + self._build_tab_status() + self._build_tab_vpn() + self._build_tab_routes() + self._build_tab_dns() + self._build_tab_domains() + self._build_tab_trace() + + # ---------------- STATUS TAB ---------------- + + def _build_tab_status(self) -> None: + tab = QWidget() + layout = QVBoxLayout(tab) + + grid = QFormLayout() + layout.addLayout(grid) + + self.st_timestamp = QLabel("—") + self.st_counts = QLabel("—") + self.st_iface = QLabel("—") + self.st_route = QLabel("—") + self.st_routes_service = QLabel("—") + self.st_smartdns_service = QLabel("—") + self.st_vpn_service = QLabel("—") + + grid.addRow("Timestamp:", self.st_timestamp) + grid.addRow("Counts:", self.st_counts) + grid.addRow("Iface / table / mark:", self.st_iface) + grid.addRow("Policy route:", self.st_route) + grid.addRow("Routes service:", self.st_routes_service) + grid.addRow("SmartDNS:", self.st_smartdns_service) + grid.addRow("VPN service:", self.st_vpn_service) + + btns = QHBoxLayout() + layout.addLayout(btns) + btn_refresh = QPushButton("Refresh") + btn_refresh.clicked.connect(self.refresh_status_tab) + btns.addWidget(btn_refresh) + btns.addStretch(1) + + self.tabs.addTab(tab, "Status") + + # ---------------- VPN TAB ---------------- + + def _build_tab_vpn(self) -> None: + tab = QWidget() + self.tab_vpn = tab # нужно, чтобы переключаться сюда из шапки + layout = QVBoxLayout(tab) + + # stack: main vs login-flow page + self.vpn_stack = QStackedWidget() + layout.addWidget(self.vpn_stack, stretch=1) + + # ---- main page + page_main = QWidget() + main_layout = QVBoxLayout(page_main) + + # Autoconnect group + auto_group = QGroupBox("Autoconnect (AdGuardVPN autoloop)") + auto_layout = QHBoxLayout(auto_group) + self.btn_autoconnect_toggle = QPushButton("Enable autoconnect") + self.btn_autoconnect_toggle.clicked.connect(self.on_toggle_autoconnect) + auto_layout.addWidget(self.btn_autoconnect_toggle) + + auto_layout.addStretch(1) + + # справа текст "unit: active/inactive" с цветом + self.lbl_autoconnect_state = QLabel("unit: —") + self.lbl_autoconnect_state.setStyleSheet("color: gray;") + auto_layout.addWidget(self.lbl_autoconnect_state) + + main_layout.addWidget(auto_group) + + # Locations group + loc_group = QGroupBox("Location") + loc_layout = QHBoxLayout(loc_group) + + self.cmb_locations = QComboBox() + # компактный popup со скроллом, а не на весь экран + self.cmb_locations.setMaxVisibleItems(12) + self.cmb_locations.setStyleSheet("combobox-popup: 0;") + view = QListView() + view.setUniformItemSizes(True) + self.cmb_locations.setView(view) + + loc_layout.addWidget(self.cmb_locations, stretch=1) + + self.btn_set_location = QPushButton("Apply & restart loop") + self.btn_set_location.clicked.connect(self.on_set_location) + loc_layout.addWidget(self.btn_set_location) + + main_layout.addWidget(loc_group) + + # Status output + self.txt_vpn = QPlainTextEdit() + self.txt_vpn.setReadOnly(True) + main_layout.addWidget(self.txt_vpn, stretch=1) + + self.vpn_stack.addWidget(page_main) + + # ---- login page + page_login = QWidget() + lf_layout = QVBoxLayout(page_login) + + top = QHBoxLayout() + lf_layout.addLayout(top) + + self.lbl_login_flow_status = QLabel("Status: —") + top.addWidget(self.lbl_login_flow_status) + self.lbl_login_flow_email = QLabel("") + self.lbl_login_flow_email.setStyleSheet("color: gray;") + top.addWidget(self.lbl_login_flow_email) + top.addStretch(1) + + # URL + buttons row + row2 = QHBoxLayout() + lf_layout.addLayout(row2) + row2.addWidget(QLabel("URL:")) + self.edit_login_url = QLineEdit() + row2.addWidget(self.edit_login_url, stretch=1) + self.btn_login_open = QPushButton("Open") + self.btn_login_open.clicked.connect(self.on_login_open) + row2.addWidget(self.btn_login_open) + self.btn_login_copy = QPushButton("Copy") + self.btn_login_copy.clicked.connect(self.on_login_copy) + row2.addWidget(self.btn_login_copy) + self.btn_login_check = QPushButton("Check") + self.btn_login_check.clicked.connect(self.on_login_check) + row2.addWidget(self.btn_login_check) + self.btn_login_close = QPushButton("Cancel") + self.btn_login_close.clicked.connect(self.on_login_cancel) + row2.addWidget(self.btn_login_close) + self.btn_login_stop = QPushButton("Stop session") + self.btn_login_stop.clicked.connect(self.on_login_stop) + row2.addWidget(self.btn_login_stop) + + # log text + self.txt_login_flow = QPlainTextEdit() + self.txt_login_flow.setReadOnly(True) + lf_layout.addWidget(self.txt_login_flow, stretch=1) + + # bottom buttons + bottom = QHBoxLayout() + lf_layout.addLayout(bottom) + + # Start login визуально убираем, но объект оставим на всякий + self.btn_login_start = QPushButton("Start login") + self.btn_login_start.clicked.connect(self.on_start_login) + self.btn_login_start.setVisible(False) + bottom.addWidget(self.btn_login_start) + + btn_back = QPushButton("Back to VPN") + btn_back.clicked.connect(lambda: self._show_vpn_page("main")) + bottom.addWidget(btn_back) + bottom.addStretch(1) + + self.vpn_stack.addWidget(page_login) + + self.tabs.addTab(tab, "AdGuardVPN") + + # ---------------- ROUTES TAB ---------------- + + def _build_tab_routes(self) -> None: + tab = QWidget() + layout = QVBoxLayout(tab) + + # --- Service actions --- + act_group = QGroupBox("Selective routes service") + act_layout = QHBoxLayout(act_group) + + self.btn_routes_start = QPushButton("Start") + self.btn_routes_start.clicked.connect( + lambda: self.on_routes_action("start") + ) + + self.btn_routes_restart = QPushButton("Restart") + self.btn_routes_restart.clicked.connect( + lambda: self.on_routes_action("restart") + ) + + self.btn_routes_stop = QPushButton("Stop") + self.btn_routes_stop.clicked.connect( + lambda: self.on_routes_action("stop") + ) + + act_layout.addWidget(self.btn_routes_start) + act_layout.addWidget(self.btn_routes_restart) + act_layout.addWidget(self.btn_routes_stop) + act_layout.addStretch(1) + + layout.addWidget(act_group) + + # --- Timer / policy route --- + timer_group = QGroupBox("Timer") + timer_layout = QHBoxLayout(timer_group) + + self.chk_timer = QCheckBox("Enable timer") + self.chk_timer.stateChanged.connect(self.on_toggle_timer) + timer_layout.addWidget(self.chk_timer) + + self.btn_fix_policy = QPushButton("Fix policy route") + self.btn_fix_policy.clicked.connect(self.on_fix_policy_route) + timer_layout.addWidget(self.btn_fix_policy) + + timer_layout.addStretch(1) + + layout.addWidget(timer_group) + + # --- Traffic mode relay --- + traffic_group = QGroupBox("Traffic mode relay") + traffic_layout = QVBoxLayout(traffic_group) + + relay_row = QHBoxLayout() + self.btn_traffic_settings = QPushButton("Open traffic settings") + self.btn_traffic_settings.clicked.connect(self.on_open_traffic_settings) + relay_row.addWidget(self.btn_traffic_settings) + self.btn_traffic_test = QPushButton("Test mode") + self.btn_traffic_test.clicked.connect(self.on_test_traffic_mode) + relay_row.addWidget(self.btn_traffic_test) + self.btn_routes_prewarm = QPushButton("Prewarm wildcard now") + self.btn_routes_prewarm.setToolTip("""EN: Sends DNS queries for wildcard domains to prefill agvpn_dyn4 before traffic arrives. +RU: Делает DNS-запросы wildcard-доменов, чтобы заранее наполнить agvpn_dyn4.""") + self.btn_routes_prewarm.clicked.connect(self.on_smartdns_prewarm) + relay_row.addWidget(self.btn_routes_prewarm) + relay_row.addStretch(1) + traffic_layout.addLayout(relay_row) + + self.chk_routes_prewarm_aggressive = QCheckBox("Aggressive prewarm (use subs)") + self.chk_routes_prewarm_aggressive.setToolTip("""EN: Aggressive mode also queries subs list. This can increase DNS load. +RU: Агрессивный режим дополнительно дергает subs список. Может увеличить нагрузку на DNS.""") + self.chk_routes_prewarm_aggressive.stateChanged.connect(self._on_prewarm_aggressive_changed) + traffic_layout.addWidget(self.chk_routes_prewarm_aggressive) + + self.lbl_routes_prewarm_mode = QLabel("Prewarm mode: wildcard-only") + self.lbl_routes_prewarm_mode.setStyleSheet("color: gray;") + traffic_layout.addWidget(self.lbl_routes_prewarm_mode) + self._update_prewarm_mode_label() + + self.lbl_traffic_mode_state = QLabel("Traffic mode: —") + self.lbl_traffic_mode_state.setStyleSheet("color: gray;") + traffic_layout.addWidget(self.lbl_traffic_mode_state) + + self.lbl_traffic_mode_diag = QLabel("—") + self.lbl_traffic_mode_diag.setStyleSheet("color: gray;") + traffic_layout.addWidget(self.lbl_traffic_mode_diag) + + layout.addWidget(traffic_group) + + # --- NFT progress (agvpn4) --- + progress_row = QHBoxLayout() + + self.routes_progress = QProgressBar() + self.routes_progress.setRange(0, 100) + self.routes_progress.setValue(0) + self.routes_progress.setFormat("") # текст выводим отдельным лейблом + self.routes_progress.setTextVisible(False) + self.routes_progress.setEnabled(False) # idle по умолчанию + + self.lbl_routes_progress = QLabel("NFT: idle") + self.lbl_routes_progress.setStyleSheet("color: gray;") + + progress_row.addWidget(self.routes_progress) + progress_row.addWidget(self.lbl_routes_progress) + + layout.addLayout(progress_row) + + # --- Log output --- + self.txt_routes = QPlainTextEdit() + self.txt_routes.setReadOnly(True) + layout.addWidget(self.txt_routes, stretch=1) + + self.tabs.addTab(tab, "Routes") + + # ---------------- DNS TAB ---------------- + + def _build_tab_dns(self) -> None: + tab = QWidget() + main_layout = QVBoxLayout(tab) + + tip = QLabel("Tip: hover fields for help. Подсказка: наведи на элементы для описания.") + tip.setWordWrap(True) + tip.setStyleSheet("color: gray;") + main_layout.addWidget(tip) + + ups_group = QGroupBox("Upstreams (auto-save)") + ups_group.setToolTip("""EN: DNS upstreams for direct resolver mode (and non-wildcard lists in hybrid mode). +RU: DNS апстримы для direct-резолвера (и для не-wildcard списков в hybrid режиме).""") + ups_form = QFormLayout(ups_group) + self.ent_def1 = QLineEdit() + self.ent_def1.setToolTip("""EN: Upstream default1. You can set an IP (port 53 is assumed). +RU: Апстрим default1. Можно указать IP (порт 53 по умолчанию).""") + self.ent_def2 = QLineEdit() + self.ent_def2.setToolTip("""EN: Upstream default2. You can set an IP (port 53 is assumed). +RU: Апстрим default2. Можно указать IP (порт 53 по умолчанию).""") + self.ent_meta1 = QLineEdit() + self.ent_meta1.setToolTip("""EN: Upstream meta1. You can set an IP (port 53 is assumed). +RU: Апстрим meta1. Можно указать IP (порт 53 по умолчанию).""") + self.ent_meta2 = QLineEdit() + self.ent_meta2.setToolTip("""EN: Upstream meta2. You can set an IP (port 53 is assumed). +RU: Апстрим meta2. Можно указать IP (порт 53 по умолчанию).""") + self.ent_def1.textEdited.connect(self._schedule_dns_autosave) + self.ent_def2.textEdited.connect(self._schedule_dns_autosave) + self.ent_meta1.textEdited.connect(self._schedule_dns_autosave) + self.ent_meta2.textEdited.connect(self._schedule_dns_autosave) + ups_form.addRow("default1", self.ent_def1) + ups_form.addRow("default2", self.ent_def2) + ups_form.addRow("meta1", self.ent_meta1) + ups_form.addRow("meta2", self.ent_meta2) + main_layout.addWidget(ups_group) + + smart_group = QGroupBox("SmartDNS") + smart_group.setToolTip("""EN: SmartDNS is used for wildcard domains in hybrid mode. +RU: SmartDNS используется для wildcard-доменов в hybrid режиме.""") + smart_layout = QVBoxLayout(smart_group) + + smart_form = QFormLayout() + self.ent_smartdns_addr = QLineEdit() + self.ent_smartdns_addr.setToolTip("""EN: SmartDNS address in host#port format (example: 127.0.0.1#6053). +RU: Адрес SmartDNS в формате host#port (пример: 127.0.0.1#6053).""") + self.ent_smartdns_addr.setPlaceholderText("127.0.0.1#6053") + self.ent_smartdns_addr.textEdited.connect(self._schedule_dns_autosave) + smart_form.addRow("SmartDNS address", self.ent_smartdns_addr) + smart_layout.addLayout(smart_form) + + self.chk_dns_via_smartdns = QCheckBox("Use SmartDNS for wildcard domains") + self.chk_dns_via_smartdns.setToolTip("""EN: Hybrid wildcard mode: wildcard domains resolve via SmartDNS, other lists resolve via direct upstreams. +RU: Hybrid wildcard режим: wildcard-домены резолвятся через SmartDNS, остальные списки через direct апстримы.""") + self.chk_dns_via_smartdns.stateChanged.connect(self.on_dns_mode_toggle) + smart_layout.addWidget(self.chk_dns_via_smartdns) + + self.lbl_dns_mode_state = QLabel("Resolver mode: unknown") + self.lbl_dns_mode_state.setToolTip("""EN: Current resolver mode reported by API. +RU: Текущий режим резолвера по данным API.""") + smart_layout.addWidget(self.lbl_dns_mode_state) + + self.chk_dns_unit_relay = QCheckBox("SmartDNS unit relay: OFF") + self.chk_dns_unit_relay.setToolTip("""EN: Starts/stops smartdns-local.service. Service state is independent from resolver mode. +RU: Запускает/останавливает smartdns-local.service. Состояние сервиса не равно режиму резолвера.""") + self.chk_dns_unit_relay.stateChanged.connect(self.on_smartdns_unit_toggle) + smart_layout.addWidget(self.chk_dns_unit_relay) + + self.chk_dns_runtime_nftset = QCheckBox("SmartDNS runtime accelerator (nftset -> agvpn_dyn4): ON") + self.chk_dns_runtime_nftset.setToolTip("""EN: Optional accelerator: SmartDNS can add resolved IPs to agvpn_dyn4 in runtime (via nftset). +EN: Wildcard still works without it (resolver job + prewarm). +RU: Опциональный ускоритель: SmartDNS может добавлять IP в agvpn_dyn4 в runtime (через nftset). +RU: Wildcard работает и без него (resolver job + prewarm).""") + self.chk_dns_runtime_nftset.stateChanged.connect(self.on_smartdns_runtime_toggle) + smart_layout.addWidget(self.chk_dns_runtime_nftset) + + self.lbl_dns_wildcard_source = QLabel("Wildcard source: resolver") + self.lbl_dns_wildcard_source.setToolTip("""EN: Where wildcard IPs come from: resolver job, SmartDNS runtime nftset, or both. +RU: Источник wildcard IP: резолвер, runtime nftset SmartDNS, или оба.""") + self.lbl_dns_wildcard_source.setStyleSheet("color: gray;") + smart_layout.addWidget(self.lbl_dns_wildcard_source) + + main_layout.addWidget(smart_group) + main_layout.addStretch(1) + + self.tabs.addTab(tab, "DNS") + + # ---------------- DOMAINS TAB ---------------- + + def _build_tab_domains(self) -> None: + tab = QWidget() + main_layout = QHBoxLayout(tab) + + left = QVBoxLayout() + main_layout.addLayout(left) + + left.addWidget(QLabel("Files:")) + self.lst_files = QListWidget() + for name in ( + "bases", + "meta-special", + "subs", + "static-ips", + "last-ips-map-direct", + "last-ips-map-wildcard", + "smartdns.conf", + ): + QListWidgetItem(name, self.lst_files) + self.lst_files.setCurrentRow(0) + self.lst_files.itemSelectionChanged.connect(self.on_domains_load) + left.addWidget(self.lst_files) + + self.btn_domains_save = QPushButton("Save file") + self.btn_domains_save.clicked.connect(self.on_domains_save) + left.addWidget(self.btn_domains_save) + left.addStretch(1) + + right_layout = QVBoxLayout() + main_layout.addLayout(right_layout, stretch=1) + + self.lbl_domains_info = QLabel("—") + self.lbl_domains_info.setStyleSheet("color: gray;") + right_layout.addWidget(self.lbl_domains_info) + + self.txt_domains = QPlainTextEdit() + right_layout.addWidget(self.txt_domains, stretch=1) + + self.tabs.addTab(tab, "Domains") + + # ---------------- TRACE TAB ---------------- + + def _build_tab_trace(self) -> None: + tab = QWidget() + layout = QVBoxLayout(tab) + + top = QHBoxLayout() + layout.addLayout(top) + + self.radio_trace_full = QRadioButton("Full") + self.radio_trace_full.setChecked(True) + self.radio_trace_full.toggled.connect(self.refresh_trace_tab) + top.addWidget(self.radio_trace_full) + self.radio_trace_gui = QRadioButton("Events") + self.radio_trace_gui.toggled.connect(self.refresh_trace_tab) + top.addWidget(self.radio_trace_gui) + self.radio_trace_smartdns = QRadioButton("SmartDNS") + self.radio_trace_smartdns.toggled.connect(self.refresh_trace_tab) + top.addWidget(self.radio_trace_smartdns) + + btn_refresh = QPushButton("Refresh") + btn_refresh.clicked.connect(self.refresh_trace_tab) + top.addWidget(btn_refresh) + top.addStretch(1) + + self.txt_trace = QPlainTextEdit() + self.txt_trace.setReadOnly(True) + layout.addWidget(self.txt_trace, stretch=1) + + self.tabs.addTab(tab, "Trace") + + # ---------------- UI HELPERS ---------------- + + def _safe(self, fn, *, title: str = "Error"): + try: + return fn() + except Exception as e: # pragma: no cover - GUI + try: + self.ctrl.log_gui(f"[ui-error] {title}: {e}") + except Exception: + pass + QMessageBox.critical(self, title, str(e)) + return None + + def _set_text(self, widget: QPlainTextEdit, text: str, *, preserve_scroll: bool = False) -> None: + """Set text, optionally сохраняя положение скролла (для trace).""" + if not preserve_scroll: + widget.setPlainText(text) + return + + sb = widget.verticalScrollBar() + old_max = sb.maximum() + old_val = sb.value() + at_end = old_val >= old_max - 2 + + widget.setPlainText(text) + + new_max = sb.maximum() + if at_end: + sb.setValue(new_max) + else: + # подвинем на ту же относительную позицию, учитывая прирост размера + sb.setValue(max(0, min(new_max, old_val+(new_max-old_max)))) + + def _append_text(self, widget: QPlainTextEdit, text: str) -> None: + cursor = widget.textCursor() + cursor.movePosition(QTextCursor.End) + cursor.insertText(text) + widget.setTextCursor(cursor) + widget.ensureCursorVisible() + + def _clean_ui_lines(self, lines) -> str: + buf = "\n".join([str(x) for x in (lines or [])]).replace("\r", "\n") + out_lines = [] + for ln in buf.splitlines(): + t = ln.strip() + if not t: + continue + t2 = _NEXT_CHECK_RE.sub("", t).strip() + if not t2: + continue + out_lines.append(t2) + return "\n".join(out_lines).rstrip() + + def _get_selected_domains_file(self) -> str: + item = self.lst_files.currentItem() + return item.text() if item is not None else "bases" + + def _load_file_content(self, name: str) -> tuple[str, str, str]: + api_map = { + "bases": "bases", + "meta-special": "meta", + "subs": "subs", + "static-ips": "static", + "last-ips-map-direct": "last-ips-map-direct", + "last-ips-map-wildcard": "last-ips-map-wildcard", + "smartdns.conf": "smartdns", + } + if name in api_map: + f = self.ctrl.domains_file_load(api_map[name]) + content = f.content or "" + source = getattr(f, "source", "") or "api" + if name == "smartdns.conf": + path = "/var/lib/selective-vpn/smartdns-wildcards.json -> /etc/selective-vpn/smartdns.conf" + elif name == "last-ips-map-direct": + path = "/var/lib/selective-vpn/last-ips-map-direct.txt (artifact: agvpn4)" + elif name == "last-ips-map-wildcard": + path = "/var/lib/selective-vpn/last-ips-map-wildcard.txt (artifact: agvpn_dyn4)" + else: + path = f"/etc/selective-vpn/domains/{name}.txt" + return content, source, path + return "", "unknown", name + + def _save_file_content(self, name: str, content: str) -> None: + api_map = { + "bases": "bases", + "meta-special": "meta", + "subs": "subs", + "static-ips": "static", + "smartdns.conf": "smartdns", + } + if name in api_map: + self.ctrl.domains_file_save(api_map[name], content) + return + + def _show_vpn_page(self, which: LoginPage) -> None: + self.vpn_stack.setCurrentIndex(1 if which == "login" else 0) + + def _set_auth_button(self, logged: bool) -> None: + self.btn_auth.setText("Logout" if logged else "Login") + + def _set_status_label_color(self, lbl: QLabel, text: str, *, kind: str) -> None: + """Подкраска Policy route / services.""" + lbl.setText(text) + low = (text or "").lower() + color = "black" + if kind == "policy": + if "ok" in low and "missing" not in low and "error" not in low: + color = "green" + elif any(w in low for w in ("missing", "error", "failed")): + color = "red" + else: + color = "orange" + else: # service + if any(w in low for w in ("failed", "error", "unknown", "inactive", "dead")): + color = "red" + elif "active" in low or "running" in low: + color = "green" + else: + color = "orange" + lbl.setStyleSheet(f"color: {color};") + + def _set_dns_unit_relay_state(self, enabled: bool) -> None: + txt = "SmartDNS unit relay: ON" if enabled else "SmartDNS unit relay: OFF" + color = "green" if enabled else "red" + self.chk_dns_unit_relay.setText(txt) + self.chk_dns_unit_relay.setStyleSheet(f"color: {color};") + + def _set_dns_runtime_state(self, enabled: bool, source: str, cfg_error: str = "") -> None: + txt = "SmartDNS runtime accelerator (nftset -> agvpn_dyn4): ON" if enabled else "SmartDNS runtime accelerator (nftset -> agvpn_dyn4): OFF" + color = "green" if enabled else "orange" + self.chk_dns_runtime_nftset.setText(txt) + self.chk_dns_runtime_nftset.setStyleSheet(f"color: {color};") + + src = (source or "").strip().lower() + if src == "both": + src_txt = "Wildcard source: both (resolver + smartdns_runtime)" + src_color = "green" + elif src == "smartdns_runtime": + src_txt = "Wildcard source: smartdns_runtime" + src_color = "orange" + else: + src_txt = "Wildcard source: resolver" + src_color = "gray" + if cfg_error: + src_txt = f"{src_txt} | runtime cfg: {cfg_error}" + src_color = "orange" + self.lbl_dns_wildcard_source.setText(src_txt) + self.lbl_dns_wildcard_source.setStyleSheet(f"color: {src_color};") + + def _set_dns_mode_state(self, mode: str) -> None: + low = (mode or "").strip().lower() + if low in ("hybrid_wildcard", "hybrid"): + txt = "Resolver mode: hybrid wildcard (SmartDNS for wildcard domains)" + color = "green" + elif low == "direct": + txt = "Resolver mode: direct upstreams" + color = "red" + else: + txt = "Resolver mode: unknown" + color = "orange" + self.lbl_dns_mode_state.setText(txt) + self.lbl_dns_mode_state.setStyleSheet(f"color: {color};") + + def _set_traffic_mode_state( + self, + desired_mode: str, + applied_mode: str, + preferred_iface: str, + auto_local_bypass: bool, + bypass_candidates: int, + overrides_applied: int, + cgroup_resolved_uids: int, + cgroup_warning: str, + healthy: bool, + probe_ok: bool, + probe_message: str, + active_iface: str, + iface_reason: str, + message: str, + ) -> None: + desired = (desired_mode or "").strip().lower() or "selective" + applied = (applied_mode or "").strip().lower() or "direct" + + if healthy: + color = "green" + health_txt = "OK" + else: + color = "red" + health_txt = "MISMATCH" + + text = f"Traffic mode: {desired} (applied: {applied}) [{health_txt}]" + diag_parts = [] + diag_parts.append(f"preferred={preferred_iface or 'auto'}") + diag_parts.append( + f"auto_local_bypass={'on' if auto_local_bypass else 'off'}" + ) + if bypass_candidates > 0: + diag_parts.append(f"bypass_routes={bypass_candidates}") + diag_parts.append(f"overrides={overrides_applied}") + if cgroup_resolved_uids > 0: + diag_parts.append(f"cgroup_uids={cgroup_resolved_uids}") + if cgroup_warning: + diag_parts.append(f"cgroup_warning={cgroup_warning}") + if active_iface: + diag_parts.append(f"iface={active_iface}") + if iface_reason: + diag_parts.append(f"source={iface_reason}") + diag_parts.append(f"probe={'ok' if probe_ok else 'fail'}") + if probe_message: + diag_parts.append(probe_message) + if message: + diag_parts.append(message) + diag = " | ".join(diag_parts) if diag_parts else "—" + + self.lbl_traffic_mode_state.setText(text) + self.lbl_traffic_mode_state.setStyleSheet(f"color: {color};") + self.lbl_traffic_mode_diag.setText(diag) + self.lbl_traffic_mode_diag.setStyleSheet("color: gray;") + + def _update_routes_progress_label(self, view) -> None: + """ + Обновляет прогресс nft по RoutesNftProgressView. + view ожидаем с полями percent, message, active (duck-typing). + """ + if view is None: + # сброс до idle + self._routes_progress_last = 0 + self.routes_progress.setValue(0) + self.lbl_routes_progress.setText("NFT: idle") + self.lbl_routes_progress.setStyleSheet("color: gray;") + return + + # аккуратно ограничим 0..100 + try: + percent = max(0, min(100, int(view.percent))) + except Exception: + percent = 0 + + # не даём прогрессу дёргаться назад, кроме явного сброса (percent==0) + if percent == 0: + self._routes_progress_last = 0 + else: + percent = max(percent, self._routes_progress_last) + self._routes_progress_last = percent + + self.routes_progress.setValue(percent) + + text = f"{percent}% – {view.message}" + if not view.active and percent >= 100: + color = "green" + elif view.active: + color = "orange" + else: + color = "gray" + + self.lbl_routes_progress.setText(text) + self.lbl_routes_progress.setStyleSheet(f"color: {color};") + + def _load_ui_preferences(self) -> None: + raw = self._ui_settings.value("routes/prewarm_aggressive", False) + if isinstance(raw, str): + val = raw.strip().lower() in ("1", "true", "yes", "on") + else: + val = bool(raw) + self.chk_routes_prewarm_aggressive.blockSignals(True) + self.chk_routes_prewarm_aggressive.setChecked(val) + self.chk_routes_prewarm_aggressive.blockSignals(False) + self._update_prewarm_mode_label() + + def _save_ui_preferences(self) -> None: + self._ui_settings.setValue( + "routes/prewarm_aggressive", + bool(self.chk_routes_prewarm_aggressive.isChecked()), + ) + self._ui_settings.sync() + + # ---------------- EVENTS STREAM ---------------- + + def _start_events_stream(self) -> None: + if self.events_thread: + return + self.events_thread = EventThread(self.ctrl, self) + self.events_thread.eventReceived.connect(self._handle_event) + self.events_thread.error.connect(self._handle_event_error) + self.events_thread.start() + + @QtCore.Slot(object) + def _handle_event(self, ev) -> None: + try: + kinds = self.ctrl.classify_event(ev) + except Exception: + kinds = [] + + # Отдельно ловим routes_nft_progress, чтобы обновить лейбл прогресса. + try: + k = (getattr(ev, "kind", "") or "").strip().lower() + except Exception: + k = "" + + if k == "routes_nft_progress": + try: + prog_view = self.ctrl.routes_nft_progress_from_event(ev) + self._update_routes_progress_label(prog_view) + except Exception: + # не роняем UI, просто игнор + pass + + # Простая стратегия: триггерить существующие refresh-функции. + if "status" in kinds: + self.refresh_status_tab() + if "login" in kinds: + self.refresh_login_banner() + if "vpn" in kinds: + self.refresh_vpn_tab() + if "routes" in kinds: + self.refresh_routes_tab() + if "dns" in kinds: + self.refresh_dns_tab() + if "trace" in kinds: + self.refresh_trace_tab() + + + @QtCore.Slot(str) + def _handle_event_error(self, msg: str) -> None: + # Логируем в trace, UI не блокируем. + try: + self.ctrl.log_gui(f"[sse-error] {msg}") + except Exception: + pass + + # ---------------- REFRESH ---------------- + + def refresh_everything(self) -> None: + self.refresh_login_banner() + self.refresh_status_tab() + self.refresh_vpn_tab() + self.refresh_routes_tab() + self.refresh_dns_tab() + self.refresh_domains_tab() + self.refresh_trace_tab() + + def refresh_login_banner(self) -> None: + def work(): + view = self.ctrl.get_login_view() + + self.btn_login_banner.setText(view.text) + self._set_auth_button(view.logged_in) + + # Принудительно: зелёный если залогинен, серый если нет + color = "green" if view.logged_in else "gray" + base_style = "text-align: left; border: none;" + self.btn_login_banner.setStyleSheet( + f"{base_style} color: {color};" + ) + + self._safe(work, title="Login state error") + + def refresh_status_tab(self) -> None: + def work(): + view = self.ctrl.get_status_overview() + self.st_timestamp.setText(view.timestamp) + self.st_counts.setText(view.counts) + self.st_iface.setText(view.iface_table_mark) + + self._set_status_label_color( + self.st_route, view.policy_route, kind="policy" + ) + self._set_status_label_color( + self.st_routes_service, view.routes_service, kind="service" + ) + self._set_status_label_color( + self.st_smartdns_service, view.smartdns_service, kind="service" + ) + self._set_status_label_color( + self.st_vpn_service, view.vpn_service, kind="service" + ) + + self._safe(work, title="Status error") + + def refresh_vpn_tab(self) -> None: + def work(): + view = self.ctrl.vpn_status_view() + txt = [] + if view.desired_location: + txt.append(f"Desired location: {view.desired_location}") + if view.pretty_text: + txt.append(view.pretty_text.rstrip()) + self._set_text(self.txt_vpn, "\n".join(txt).strip() + "\n") + + auto_view = self.ctrl.vpn_autoconnect_view() + self.btn_autoconnect_toggle.setText( + "Disable autoconnect" if auto_view.enabled else "Enable autoconnect" + ) + self.lbl_autoconnect_state.setText(auto_view.unit_text) + self.lbl_autoconnect_state.setStyleSheet( + f"color: {auto_view.color};" + ) + + locs = self.ctrl.vpn_locations_view() + self.cmb_locations.blockSignals(True) + self.cmb_locations.clear() + + current_iso = (view.desired_location or "").strip().upper() + current_index = 0 + + for i, loc in enumerate(locs or []): + self.cmb_locations.addItem(loc.label, loc.iso) + if (loc.iso or "").upper() == current_iso: + current_index = i + + if self.cmb_locations.count() > 0: + self.cmb_locations.setCurrentIndex(current_index) + + self.cmb_locations.blockSignals(False) + + self._safe(work, title="VPN error") + + def refresh_routes_tab(self) -> None: + def work(): + timer_enabled = self.ctrl.routes_timer_enabled() + self.chk_timer.blockSignals(True) + self.chk_timer.setChecked(bool(timer_enabled)) + self.chk_timer.blockSignals(False) + + t = self.ctrl.traffic_mode_view() + self._set_traffic_mode_state( + t.desired_mode, + t.applied_mode, + t.preferred_iface, + bool(t.auto_local_bypass), + int(t.bypass_candidates), + int(t.overrides_applied), + int(t.cgroup_resolved_uids), + t.cgroup_warning, + bool(t.healthy), + bool(t.probe_ok), + t.probe_message, + t.active_iface, + t.iface_reason, + t.message, + ) + self._safe(work, title="Routes error") + + def refresh_dns_tab(self) -> None: + def work(): + self._dns_ui_refresh = True + try: + ups = self.ctrl.dns_upstreams_view() + self.ent_def1.setText(ups.default1 or "") + self.ent_def2.setText(ups.default2 or "") + self.ent_meta1.setText(ups.meta1 or "") + self.ent_meta2.setText(ups.meta2 or "") + + st = self.ctrl.dns_status_view() + self.ent_smartdns_addr.setText(st.smartdns_addr or "") + + mode = (getattr(st, "mode", "") or "").strip().lower() + if mode in ("hybrid_wildcard", "hybrid"): + hybrid_enabled = True + mode = "hybrid_wildcard" + else: + hybrid_enabled = False + mode = "direct" + + self.chk_dns_via_smartdns.blockSignals(True) + self.chk_dns_via_smartdns.setChecked(hybrid_enabled) + self.chk_dns_via_smartdns.blockSignals(False) + + # In direct + hybrid modes upstreams stay editable. + self.ent_def1.setEnabled(True) + self.ent_def2.setEnabled(True) + self.ent_meta1.setEnabled(True) + self.ent_meta2.setEnabled(True) + + unit_state = (st.unit_state or "unknown").strip().lower() + unit_active = unit_state == "active" + self.chk_dns_unit_relay.blockSignals(True) + self.chk_dns_unit_relay.setChecked(unit_active) + self.chk_dns_unit_relay.blockSignals(False) + + self.chk_dns_runtime_nftset.blockSignals(True) + self.chk_dns_runtime_nftset.setChecked(bool(getattr(st, "runtime_nftset", True))) + self.chk_dns_runtime_nftset.blockSignals(False) + self._set_dns_unit_relay_state(unit_active) + self._set_dns_runtime_state( + bool(getattr(st, "runtime_nftset", True)), + str(getattr(st, "wildcard_source", "") or ""), + str(getattr(st, "runtime_config_error", "") or ""), + ) + self._set_dns_mode_state(mode) + finally: + self._dns_ui_refresh = False + self._safe(work, title="DNS error") + + def refresh_domains_tab(self) -> None: + def work(): + # reload currently selected file + self.on_domains_load() + self._safe(work, title="Domains error") + + def refresh_trace_tab(self) -> None: + def work(): + if self.radio_trace_gui.isChecked(): + mode: TraceMode = "gui" + elif self.radio_trace_smartdns.isChecked(): + mode = "smartdns" + else: + mode = "full" + dump = self.ctrl.trace_view(mode) + text = "\n".join(dump.lines).rstrip() + if dump.lines: + text += "\n" + self._set_text(self.txt_trace, text, preserve_scroll=True) + self._safe(work, title="Trace error") + + # ---------------- TOP AUTH / BANNER ---------------- + + def on_auth_button(self) -> None: + def work(): + view = self.ctrl.get_login_view() + if view.logged_in: + self.on_logout() + else: + # при логине всегда переходим на вкладку AdGuardVPN и + # показываем страницу логина + self.tabs.setCurrentWidget(self.tab_vpn) + self._show_vpn_page("login") + self.on_start_login() + self._safe(work, title="Auth error") + + def on_login_banner_clicked(self) -> None: + def work(): + txt = self.ctrl.login_banner_cli_text() + QMessageBox.information(self, "AdGuard VPN", txt) + self._safe(work, title="Login banner error") + + # ---------------- LOGIN FLOW ACTIONS ---------------- + + def on_start_login(self) -> None: + def work(): + self.ctrl.log_gui("Top Login clicked") + self._show_vpn_page("login") + self._login_flow_reset_ui() + + start = self.ctrl.login_flow_start() + + self._login_cursor = int(start.cursor) + self.lbl_login_flow_status.setText( + f"Status: {start.status_text or '—'}" + ) + self.lbl_login_flow_email.setText( + f"User: {start.email}" if start.email else "" + ) + self.edit_login_url.setText(start.url or "") + + self._login_flow_set_buttons( + can_open=start.can_open, + can_check=start.can_check, + can_cancel=start.can_cancel, + ) + + if start.lines: + cleaned = self._clean_ui_lines(start.lines) + if cleaned: + self._append_text(self.txt_login_flow, cleaned + "\n") + + if not start.alive: + self._login_flow_autopoll_stop() + self._login_flow_set_buttons( + can_open=False, can_check=False, can_cancel=False + ) + self.btn_login_stop.setEnabled(False) + QTimer.singleShot(250, self.refresh_login_banner) + return + + self._login_flow_autopoll_start() + + self._safe(work, title="Login start error") + + def _login_flow_reset_ui(self) -> None: + self._login_cursor = 0 + self._login_url_opened = False + self.edit_login_url.setText("") + self.lbl_login_flow_status.setText("Status: —") + self.lbl_login_flow_email.setText("") + self._set_text(self.txt_login_flow, "") + + def _login_flow_set_buttons( + self, + *, + can_open: bool, + can_check: bool, + can_cancel: bool, + ) -> None: + self.btn_login_open.setEnabled(bool(can_open)) + self.btn_login_copy.setEnabled(bool(self.edit_login_url.text().strip())) + self.btn_login_check.setEnabled(bool(can_check)) + self.btn_login_close.setEnabled(bool(can_cancel)) + self.btn_login_stop.setEnabled(True) + + def _login_flow_autopoll_start(self) -> None: + self._login_flow_active = True + if not self.login_poll_timer.isActive(): + self.login_poll_timer.start() + + def _login_flow_autopoll_stop(self) -> None: + self._login_flow_active = False + if self.login_poll_timer.isActive(): + self.login_poll_timer.stop() + + def _login_poll_tick(self) -> None: + if not self._login_flow_active: + return + + def work(): + view = self.ctrl.login_flow_poll(self._login_cursor) + self._login_cursor = int(view.cursor) + + self.lbl_login_flow_status.setText( + f"Status: {view.status_text or '—'}" + ) + self.lbl_login_flow_email.setText( + f"User: {view.email}" if view.email else "" + ) + + if view.url: + self.edit_login_url.setText(view.url) + + self._login_flow_set_buttons( + can_open=view.can_open, + can_check=view.can_check, + can_cancel=view.can_cancel, + ) + + cleaned = self._clean_ui_lines(view.lines) + if cleaned: + self._append_text(self.txt_login_flow, cleaned + "\n") + + if (not self._login_url_opened) and view.url: + self._login_url_opened = True + try: + subprocess.Popen(["xdg-open", view.url]) + except Exception: + pass + + phase = (view.phase or "").strip().lower() + if (not view.alive) or phase in ( + "success", + "failed", + "cancelled", + "already_logged", + ): + self._login_flow_autopoll_stop() + self._login_flow_set_buttons( + can_open=False, can_check=False, can_cancel=False + ) + self.btn_login_stop.setEnabled(False) + QTimer.singleShot(250, self.refresh_login_banner) + + self._safe(work, title="Login flow error") + + def on_login_copy(self) -> None: + def work(): + u = self.edit_login_url.text().strip() + if u: + QApplication.clipboard().setText(u) + self.ctrl.log_gui("Login flow: copy-url") + self._safe(work, title="Login copy error") + + def on_login_open(self) -> None: + def work(): + u = self.edit_login_url.text().strip() + if u: + try: + subprocess.Popen(["xdg-open", u]) + except Exception: + pass + self.ctrl.log_gui("Login flow: open") + self._safe(work, title="Login open error") + + def on_login_check(self) -> None: + def work(): + # если ещё ничего не запущено — считаем это стартом логина + if ( + not self._login_flow_active + and self._login_cursor == 0 + and not self.edit_login_url.text().strip() + and not self.txt_login_flow.toPlainText().strip() + ): + self.on_start_login() + return + + self.ctrl.login_flow_action("check") + self.ctrl.log_gui("Login flow: check") + self._safe(work, title="Login check error") + + def on_login_cancel(self) -> None: + def work(): + self.ctrl.login_flow_action("cancel") + self.ctrl.log_gui("Login flow: cancel") + self._safe(work, title="Login cancel error") + + def on_login_stop(self) -> None: + def work(): + self.ctrl.login_flow_stop() + self.ctrl.log_gui("Login flow: stop") + self._login_flow_autopoll_stop() + QTimer.singleShot(250, self.refresh_login_banner) + self._safe(work, title="Login stop error") + + def on_logout(self) -> None: + def work(): + self.ctrl.log_gui("Top Logout clicked") + res = self.ctrl.vpn_logout() + self._set_text(self.txt_vpn, res.pretty_text or str(res)) + QTimer.singleShot(250, self.refresh_login_banner) + self._safe(work, title="Logout error") + + # ---- VPN actions --------------------------------------------------- + + def on_toggle_autoconnect(self) -> None: + def work(): + current = self.ctrl.vpn_autoconnect_enabled() + enable = not current + self.ctrl.vpn_set_autoconnect(enable) + self.ctrl.log_gui(f"VPN autoconnect set to {enable}") + self.refresh_vpn_tab() + self._safe(work, title="Autoconnect error") + + def on_set_location(self) -> None: + def work(): + idx = self.cmb_locations.currentIndex() + if idx < 0: + return + iso = self.cmb_locations.currentData() + self.ctrl.vpn_set_location(iso) + self.ctrl.log_gui(f"VPN location set to {iso}") + self.refresh_vpn_tab() + self._safe(work, title="Location error") + + # ---- Routes actions ------------------------------------------------ + + def on_routes_action( + self, action: Literal["start", "stop", "restart"] + ) -> None: + def work(): + res = self.ctrl.routes_service_action(action) + self._set_text(self.txt_routes, res.pretty_text or str(res)) + self.refresh_status_tab() + self._safe(work, title="Routes error") + + def _append_routes_log(self, msg: str) -> None: + line = (msg or "").strip() + if not line: + return + self._append_text(self.txt_routes, line + "\n") + self.ctrl.log_gui(line) + + def on_open_traffic_settings(self) -> None: + def work(): + def refresh_all_traffic() -> None: + self.refresh_routes_tab() + self.refresh_status_tab() + + dlg = TrafficModeDialog( + self.ctrl, + log_cb=self._append_routes_log, + refresh_cb=refresh_all_traffic, + parent=self, + ) + dlg.exec() + refresh_all_traffic() + self._safe(work, title="Traffic mode dialog error") + + def on_test_traffic_mode(self) -> None: + def work(): + view = self.ctrl.traffic_mode_test() + msg = ( + f"Traffic mode test: desired={view.desired_mode}, applied={view.applied_mode}, " + f"iface={view.active_iface or '-'}, probe_ok={view.probe_ok}, " + f"healthy={view.healthy}, auto_local_bypass={view.auto_local_bypass}, " + f"bypass_routes={view.bypass_candidates}, overrides={view.overrides_applied}, " + f"cgroup_uids={view.cgroup_resolved_uids}, cgroup_warning={view.cgroup_warning or '-'}, " + f"message={view.message}, probe={view.probe_message}" + ) + self._append_routes_log(msg) + self.refresh_routes_tab() + self.refresh_status_tab() + self._safe(work, title="Traffic mode test error") + + def on_toggle_timer(self) -> None: + def work(): + enabled = self.chk_timer.isChecked() + res = self.ctrl.routes_timer_set(enabled) + self.ctrl.log_gui(f"Routes timer set to {enabled}") + self._set_text(self.txt_routes, res.pretty_text or str(res)) + self.refresh_routes_tab() + self._safe(work, title="Timer error") + + def on_fix_policy_route(self) -> None: + def work(): + res = self.ctrl.routes_fix_policy_route() + self._set_text(self.txt_routes, res.pretty_text or str(res)) + self.refresh_status_tab() + self._safe(work, title="Policy route error") + + # ---- DNS actions --------------------------------------------------- + + def _schedule_dns_autosave(self, _text: str = "") -> None: + if self._dns_ui_refresh: + return + self.dns_save_timer.start() + + def _apply_dns_autosave(self) -> None: + def work(): + if self._dns_ui_refresh: + return + ups = DnsUpstreams( + default1=self.ent_def1.text().strip(), + default2=self.ent_def2.text().strip(), + meta1=self.ent_meta1.text().strip(), + meta2=self.ent_meta2.text().strip(), + ) + self.ctrl.dns_upstreams_save(ups) + self.ctrl.dns_mode_set( + self.chk_dns_via_smartdns.isChecked(), + self.ent_smartdns_addr.text().strip(), + ) + self.ctrl.log_gui("DNS settings autosaved") + self._safe(work, title="DNS save error") + + def on_dns_mode_toggle(self) -> None: + def work(): + via = self.chk_dns_via_smartdns.isChecked() + self.ctrl.dns_mode_set(via, self.ent_smartdns_addr.text().strip()) + mode = "hybrid_wildcard" if via else "direct" + self.ctrl.log_gui(f"DNS mode changed: mode={mode}") + self.refresh_dns_tab() + self._safe(work, title="DNS mode error") + + def on_smartdns_unit_toggle(self) -> None: + def work(): + enable = self.chk_dns_unit_relay.isChecked() + action = "start" if enable else "stop" + self.ctrl.smartdns_service_action(action) + self.ctrl.log_smartdns(f"SmartDNS unit action from GUI: {action}") + self.refresh_dns_tab() + self.refresh_status_tab() + self._safe(work, title="SmartDNS error") + + def on_smartdns_runtime_toggle(self) -> None: + def work(): + if self._dns_ui_refresh: + return + enable = self.chk_dns_runtime_nftset.isChecked() + st = self.ctrl.smartdns_runtime_set(enabled=enable, restart=True) + self.ctrl.log_smartdns( + f"SmartDNS runtime accelerator set from GUI: enabled={enable} changed={st.changed} restarted={st.restarted} source={st.wildcard_source}" + ) + self.refresh_dns_tab() + self.refresh_trace_tab() + self._safe(work, title="SmartDNS runtime error") + + def on_smartdns_prewarm(self) -> None: + def work(): + aggressive = bool(self.chk_routes_prewarm_aggressive.isChecked()) + result = self.ctrl.smartdns_prewarm(aggressive_subs=aggressive) + mode_txt = "aggressive_subs=on" if aggressive else "aggressive_subs=off" + self.ctrl.log_smartdns(f"SmartDNS prewarm requested from GUI: {mode_txt}") + txt = (result.pretty_text or "").strip() + if result.ok: + QMessageBox.information(self, "SmartDNS prewarm", txt or "OK") + else: + QMessageBox.critical(self, "SmartDNS prewarm", txt or "ERROR") + self.refresh_trace_tab() + self._safe(work, title="SmartDNS prewarm error") + + def _update_prewarm_mode_label(self, _state: int = 0) -> None: + aggressive = bool(self.chk_routes_prewarm_aggressive.isChecked()) + if aggressive: + self.lbl_routes_prewarm_mode.setText("Prewarm mode: aggressive (subs enabled)") + self.lbl_routes_prewarm_mode.setStyleSheet("color: orange;") + else: + self.lbl_routes_prewarm_mode.setText("Prewarm mode: wildcard-only") + self.lbl_routes_prewarm_mode.setStyleSheet("color: gray;") + + def _on_prewarm_aggressive_changed(self, _state: int = 0) -> None: + self._update_prewarm_mode_label(_state) + self._save_ui_preferences() + + # ---- Domains actions ----------------------------------------------- + + def on_domains_load(self) -> None: + def work(): + name = self._get_selected_domains_file() + content, source, path = self._load_file_content(name) + is_readonly = name in ("last-ips-map-direct", "last-ips-map-wildcard") + self.txt_domains.setReadOnly(is_readonly) + self.btn_domains_save.setEnabled(not is_readonly) + self._set_text(self.txt_domains, content) + ro = "read-only" if is_readonly else "editable" + self.lbl_domains_info.setText(f"{name} ({source}, {ro}) [{path}]") + self._safe(work, title="Domains load error") + + def on_domains_save(self) -> None: + def work(): + name = self._get_selected_domains_file() + content = self.txt_domains.toPlainText() + self._save_file_content(name, content) + self.ctrl.log_gui(f"Domains file saved: {name}") + self._safe(work, title="Domains save error") + + # ---- close event --------------------------------------------------- + + def closeEvent(self, event) -> None: # pragma: no cover - GUI + try: + self._save_ui_preferences() + self._login_flow_autopoll_stop() + if self.events_thread: + self.events_thread.stop() + self.events_thread.wait(1500) + finally: + super().closeEvent(event) + + +def main(argv: list[str] | None = None) -> int: + if argv is None: + argv = sys.argv[1:] + + base_url = "http://127.0.0.1:8080" + if argv: + base_url = argv[0] + + client = ApiClient(base_url) + ctrl = DashboardController(client) + + app = QApplication(sys.argv) + win = MainWindow(ctrl) + win.show() + return app.exec() + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/smartdns.conf b/smartdns.conf new file mode 100644 index 0000000..aa05455 --- /dev/null +++ b/smartdns.conf @@ -0,0 +1,24 @@ +# ---- basic listen ---- +bind 127.0.0.1:6053 -no-speed-check -no-cache + +# ---- upstream: Meta DNS (VPN-only) ---- +server 46.243.231.30 +server 46.243.231.41 + +# ---- upstream: AdGuard Home на PVE ---- +# обычный UDP DNS-сервер + +# включим простой лог в stdout (чтоб увидеть хоть что-то через journalctl) +log-level info +response-mode fastest-response + +# набор доменов для автотуннеля +domain-set -name agvpn_wild -file /etc/selective-vpn/smartdns.conf + +# кидать все A-ответы по доменам из agvpn_wild в nft set inet/agvpn/agvpn_dyn4 +nftset /domain-set:agvpn_wild/#4:inet#agvpn#agvpn_dyn4 + +# (опционально) включить таймауты и дебаг nftset +nftset-timeout yes +nftset-debug yes +