diff --git a/selective-vpn-api/app/config.go b/selective-vpn-api/app/config.go index 729d0f5..1872787 100644 --- a/selective-vpn-api/app/config.go +++ b/selective-vpn-api/app/config.go @@ -12,10 +12,11 @@ import "embed" // --------------------------------------------------------------------- const ( - stateDir = "/var/lib/selective-vpn" - statusFilePath = stateDir + "/status.json" - dnsModePath = stateDir + "/dns-mode.json" - trafficModePath = stateDir + "/traffic-mode.json" + stateDir = "/var/lib/selective-vpn" + statusFilePath = stateDir + "/status.json" + dnsModePath = stateDir + "/dns-mode.json" + trafficModePath = stateDir + "/traffic-mode.json" + trafficAppMarksPath = stateDir + "/traffic-appmarks.json" traceLogPath = stateDir + "/trace.log" smartdnsLogPath = stateDir + "/smartdns.log" @@ -80,6 +81,7 @@ const ( defaultPollAutoloopMs = 2500 defaultPollSystemdMs = 3000 defaultPollTraceMs = 1500 + defaultPollAppMarksMs = 15000 defaultHeartbeatSeconds = 15 ) diff --git a/selective-vpn-api/app/routes_update.go b/selective-vpn-api/app/routes_update.go index 8df1ed5..c63240d 100644 --- a/selective-vpn-api/app/routes_update.go +++ b/selective-vpn-api/app/routes_update.go @@ -151,14 +151,12 @@ func routesUpdate(iface string) cmdResult { _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "set", "inet", "agvpn", "agvpn4", "{", "type", "ipv4_addr", ";", "flags", "interval", ";", "}") _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "set", "inet", "agvpn", "agvpn_dyn4", "{", "type", "ipv4_addr", ";", "flags", "interval", ";", "}") - // EN: Per-app routing support (cgroup-mark sets). Output chain jumps into: - // EN: - output_apps: app-scoped marks (MARK_DIRECT / MARK_APP) + // EN: Output chain jumps into: + // EN: - output_apps: runtime per-app marks (MARK_DIRECT / MARK_APP) // EN: - output_ips: selective domain IP sets (MARK) - // RU: Поддержка per-app (cgroup-mark sets). Output chain прыгает в: - // RU: - output_apps: per-app marks (MARK_DIRECT / MARK_APP) + // RU: Output chain прыгает в: + // RU: - output_apps: runtime per-app marks (MARK_DIRECT / MARK_APP) // RU: - output_ips: селективные доменные IP сеты (MARK) - _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "set", "inet", "agvpn", "svpn_cg_vpn", "{", "typeof", "meta", "cgroup", ";", "flags", "timeout", ";", "}") - _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "set", "inet", "agvpn", "svpn_cg_direct", "{", "typeof", "meta", "cgroup", ";", "flags", "timeout", ";", "}") _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "chain", "inet", "agvpn", "output", "{", "type", "route", "hook", "output", "priority", "mangle;", "policy", "accept;", "}") _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "chain", "inet", "agvpn", "output_apps") @@ -169,10 +167,7 @@ func routesUpdate(iface string) cmdResult { _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "rule", "inet", "agvpn", "output", "jump", "output_apps") _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "rule", "inet", "agvpn", "output", "jump", "output_ips") - // App chain: mark + accept to stop further evaluation in this base chain. - _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "flush", "chain", "inet", "agvpn", "output_apps") - _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "rule", "inet", "agvpn", "output_apps", "meta", "cgroup", "@svpn_cg_direct", "meta", "mark", "set", MARK_DIRECT, "accept") - _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "rule", "inet", "agvpn", "output_apps", "meta", "cgroup", "@svpn_cg_vpn", "meta", "mark", "set", MARK_APP, "accept") + // App chain: runtime rules are managed by traffic_appmarks.go (do not flush here). // Domain chain: selective IP sets (resolver output). _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "flush", "chain", "inet", "agvpn", "output_ips") diff --git a/selective-vpn-api/app/traffic_appmarks.go b/selective-vpn-api/app/traffic_appmarks.go index aa90d79..d491cb7 100644 --- a/selective-vpn-api/app/traffic_appmarks.go +++ b/selective-vpn-api/app/traffic_appmarks.go @@ -9,35 +9,56 @@ import ( "path/filepath" "strconv" "strings" + "sync" "syscall" "time" ) // --------------------------------------------------------------------- -// traffic app marks (per-app routing via cgroup -> fwmark) +// traffic app marks (per-app routing via cgroupv2 path -> fwmark) // --------------------------------------------------------------------- // -// EN: This module manages runtime cgroup-id sets used by nftables rules in -// EN: routes_update.go (output_apps chain). GUI/clients can add/remove cgroup IDs -// EN: to force traffic through VPN (MARK_APP) or force direct (MARK_DIRECT). -// RU: Этот модуль управляет runtime cgroup-id сетами для nftables правил из -// RU: routes_update.go (цепочка output_apps). GUI/клиенты могут добавлять/удалять -// RU: cgroup IDs, чтобы форсировать трафик через VPN (MARK_APP) или в direct (MARK_DIRECT). +// EN: This module manages runtime per-app routing marks. +// EN: We match cgroupv2 paths using nftables `socket cgroupv2` and set fwmark: +// EN: - MARK_APP (VPN) or MARK_DIRECT (direct). +// EN: TTL is kept in a JSON state file; expired entries are pruned. +// RU: Этот модуль управляет runtime per-app маршрутизацией. +// RU: Мы матчим cgroupv2 path через nftables `socket cgroupv2` и ставим fwmark: +// RU: - MARK_APP (VPN) или MARK_DIRECT (direct). +// RU: TTL хранится в JSON состоянии; просроченные записи удаляются. const ( - nftSetCgroupVPN = "svpn_cg_vpn" - nftSetCgroupDirect = "svpn_cg_direct" - cgroupRootFS = "/sys/fs/cgroup" + appMarksTable = "agvpn" + appMarksChain = "output_apps" + appMarkCommentPrefix = "svpn_appmark" + defaultAppMarkTTLSeconds = 24 * 60 * 60 ) +var appMarksMu sync.Mutex + +type appMarksState struct { + Version int `json:"version"` + UpdatedAt string `json:"updated_at"` + Items []appMarkItem `json:"items,omitempty"` +} + +type appMarkItem struct { + ID uint64 `json:"id"` + Target string `json:"target"` // vpn|direct + Cgroup string `json:"cgroup"` // absolute path ("/user.slice/..."), informational + CgroupRel string `json:"cgroup_rel"` + Level int `json:"level"` + AddedAt string `json:"added_at"` + ExpiresAt string `json:"expires_at"` +} + func handleTrafficAppMarks(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: - vpnElems, _ := readNftSetElements(nftSetCgroupVPN) - directElems, _ := readNftSetElements(nftSetCgroupDirect) + vpnCount, directCount := appMarksGetStatus() writeJSON(w, http.StatusOK, TrafficAppMarksStatusResponse{ - VPNCount: len(vpnElems), - DirectCount: len(directElems), + VPNCount: vpnCount, + DirectCount: directCount, Message: "ok", }) case http.MethodPost: @@ -76,7 +97,6 @@ func handleTrafficAppMarks(w http.ResponseWriter, r *http.Request) { return } - // Ensure nft objects exist even if routes-update hasn't run yet. if err := ensureAppMarksNft(); err != nil { writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ OK: false, @@ -88,12 +108,25 @@ func handleTrafficAppMarks(w http.ResponseWriter, r *http.Request) { return } - var ( - cgID uint64 - err error - ) - if cgroup != "" { - cgID, cgroup, err = resolveCgroupIDForNft(cgroup) + switch op { + case TrafficAppMarksAdd: + if isAllDigits(cgroup) { + writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ + OK: false, + Op: string(op), + Target: target, + Cgroup: cgroup, + Message: "cgroup must be a cgroupv2 path (ControlGroup), not a numeric id", + }) + return + } + + ttl := timeoutSec + if ttl == 0 { + ttl = defaultAppMarkTTLSeconds + } + + rel, level, inodeID, cgAbs, err := resolveCgroupV2PathForNft(cgroup) if err != nil { writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ OK: false, @@ -104,91 +137,78 @@ func handleTrafficAppMarks(w http.ResponseWriter, r *http.Request) { }) return } - } - if op == TrafficAppMarksAdd && target == "vpn" { - // Ensure VPN policy table has a base route. This matters when current traffic-mode=direct. - traffic := loadTrafficModeState() - iface, _ := resolveTrafficIface(traffic.PreferredIface) - if strings.TrimSpace(iface) == "" { + if target == "vpn" { + traffic := loadTrafficModeState() + iface, _ := resolveTrafficIface(traffic.PreferredIface) + if strings.TrimSpace(iface) == "" { + writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ + OK: false, + Op: string(op), + Target: target, + Cgroup: cgAbs, + CgroupID: inodeID, + Message: "vpn interface not found (set preferred iface or bring VPN up)", + }) + return + } + if err := ensureTrafficRouteBase(iface, traffic.AutoLocalBypass); err != nil { + writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ + OK: false, + Op: string(op), + Target: target, + Cgroup: cgAbs, + CgroupID: inodeID, + Message: "ensure vpn route base failed: " + err.Error(), + }) + return + } + } + + if err := appMarksAdd(target, inodeID, cgAbs, rel, level, ttl); err != nil { writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ - OK: false, - Op: string(op), - Target: target, - Cgroup: cgroup, - CgroupID: cgID, - Message: "vpn interface not found (set preferred iface or bring VPN up)", + OK: false, + Op: string(op), + Target: target, + Cgroup: cgAbs, + CgroupID: inodeID, + TimeoutSec: ttl, + Message: err.Error(), }) return } - if err := ensureTrafficRouteBase(iface, traffic.AutoLocalBypass); err != nil { - writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ - OK: false, - Op: string(op), - Target: target, - Cgroup: cgroup, - CgroupID: cgID, - Message: "ensure vpn route base failed: " + err.Error(), - }) - return - } - } - setName := nftSetCgroupDirect - if target == "vpn" { - setName = nftSetCgroupVPN - } - - switch op { - case TrafficAppMarksAdd: - ttl := timeoutSec - if ttl == 0 { - ttl = 24 * 60 * 60 // 24h default if client didn't specify - } - if err := nftAddCgroupElement(setName, cgID, ttl); err != nil { - writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ - OK: false, - Op: string(op), - Target: target, - Cgroup: cgroup, - CgroupID: cgID, - Message: err.Error(), - }) - return - } - appendTraceLine("traffic", fmt.Sprintf("appmarks add target=%s cgroup=%s id=%d ttl=%ds", target, cgroup, cgID, ttl)) + appendTraceLine("traffic", fmt.Sprintf("appmarks add target=%s cgroup=%s id=%d ttl=%ds", target, cgAbs, inodeID, ttl)) writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ OK: true, Op: string(op), Target: target, - Cgroup: cgroup, - CgroupID: cgID, + Cgroup: cgAbs, + CgroupID: inodeID, TimeoutSec: ttl, Message: "added", }) case TrafficAppMarksDel: - if err := nftDelCgroupElement(setName, cgID); err != nil { + if err := appMarksDel(target, cgroup); err != nil { writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ - OK: false, - Op: string(op), - Target: target, - Cgroup: cgroup, - CgroupID: cgID, - Message: err.Error(), + OK: false, + Op: string(op), + Target: target, + Cgroup: cgroup, + Message: err.Error(), }) return } - appendTraceLine("traffic", fmt.Sprintf("appmarks del target=%s cgroup=%s id=%d", target, cgroup, cgID)) + appendTraceLine("traffic", fmt.Sprintf("appmarks del target=%s cgroup=%s", target, cgroup)) writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ - OK: true, - Op: string(op), - Target: target, - Cgroup: cgroup, - CgroupID: cgID, - Message: "deleted", + OK: true, + Op: string(op), + Target: target, + Cgroup: cgroup, + Message: "deleted", }) case TrafficAppMarksClear: - if err := nftFlushSet(setName); err != nil { + if err := appMarksClear(target); err != nil { writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ OK: false, Op: string(op), @@ -212,145 +232,385 @@ func handleTrafficAppMarks(w http.ResponseWriter, r *http.Request) { } } -func ensureAppMarksNft() error { - // Best-effort "ensure": ignore "exists" errors and proceed. - _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "table", "inet", "agvpn") - _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "set", "inet", "agvpn", nftSetCgroupVPN, "{", "typeof", "meta", "cgroup", ";", "flags", "timeout", ";", "}") - _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "set", "inet", "agvpn", nftSetCgroupDirect, "{", "typeof", "meta", "cgroup", ";", "flags", "timeout", ";", "}") - _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "chain", "inet", "agvpn", "output", "{", "type", "route", "hook", "output", "priority", "mangle;", "policy", "accept;", "}") - _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "chain", "inet", "agvpn", "output_apps") +func appMarksGetStatus() (vpnCount int, directCount int) { + _ = pruneExpiredAppMarks() - // Keep output_apps deterministic (no duplicates). Safe because this chain is dedicated. - _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "flush", "chain", "inet", "agvpn", "output_apps") - _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "rule", "inet", "agvpn", "output_apps", "meta", "cgroup", "@"+nftSetCgroupDirect, "meta", "mark", "set", MARK_DIRECT, "accept") - _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "rule", "inet", "agvpn", "output_apps", "meta", "cgroup", "@"+nftSetCgroupVPN, "meta", "mark", "set", MARK_APP, "accept") + appMarksMu.Lock() + defer appMarksMu.Unlock() - // Ensure output chain has a jump into output_apps (routes-update may also manage this). - out, _, _, _ := runCommandTimeout(5*time.Second, "nft", "list", "chain", "inet", "agvpn", "output") - if !strings.Contains(out, "jump output_apps") { - _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "rule", "inet", "agvpn", "output", "jump", "output_apps") + st := loadAppMarksState() + for _, it := range st.Items { + switch strings.ToLower(strings.TrimSpace(it.Target)) { + case "vpn": + vpnCount++ + case "direct": + directCount++ + } + } + return vpnCount, directCount +} + +func appMarksAdd(target string, id uint64, cgAbs string, rel string, level int, ttlSec int) error { + target = strings.ToLower(strings.TrimSpace(target)) + if target != "vpn" && target != "direct" { + return fmt.Errorf("invalid target") + } + if id == 0 { + return fmt.Errorf("invalid cgroup id") + } + if strings.TrimSpace(rel) == "" || level <= 0 { + return fmt.Errorf("invalid cgroup path") + } + if ttlSec <= 0 { + ttlSec = defaultAppMarkTTLSeconds + } + + appMarksMu.Lock() + defer appMarksMu.Unlock() + + st := loadAppMarksState() + changed := pruneExpiredAppMarksLocked(&st, time.Now().UTC()) + + // Replace any existing rule/state for this (target,id). + _ = nftDeleteAppMarkRule(target, id) + if err := nftInsertAppMarkRule(target, rel, level, id); err != nil { + return err + } + + now := time.Now().UTC() + item := appMarkItem{ + ID: id, + Target: target, + Cgroup: cgAbs, + CgroupRel: rel, + Level: level, + AddedAt: now.Format(time.RFC3339), + ExpiresAt: now.Add(time.Duration(ttlSec) * time.Second).Format(time.RFC3339), + } + st.Items = upsertAppMarkItem(st.Items, item) + changed = true + + if changed { + if err := saveAppMarksState(st); err != nil { + return err + } } return nil } -func resolveCgroupIDForNft(input string) (uint64, string, error) { +func appMarksDel(target string, cgroup string) error { + target = strings.ToLower(strings.TrimSpace(target)) + if target != "vpn" && target != "direct" { + return fmt.Errorf("invalid target") + } + cgroup = strings.TrimSpace(cgroup) + if cgroup == "" { + return fmt.Errorf("empty cgroup") + } + + appMarksMu.Lock() + defer appMarksMu.Unlock() + + st := loadAppMarksState() + changed := pruneExpiredAppMarksLocked(&st, time.Now().UTC()) + + var id uint64 + var cgAbs string + + if isAllDigits(cgroup) { + v, err := strconv.ParseUint(cgroup, 10, 64) + if err == nil { + id = v + } + } else { + rel := normalizeCgroupRelOnly(cgroup) + if rel != "" { + cgAbs = "/" + rel + // Try to resolve inode id if directory still exists. + if inode, err := cgroupDirInode(rel); err == nil { + id = inode + } + } + } + + // Fallback to state lookup by cgroup string. + idx := -1 + for i, it := range st.Items { + if strings.ToLower(strings.TrimSpace(it.Target)) != target { + continue + } + if id != 0 && it.ID == id { + idx = i + break + } + if id == 0 && cgAbs != "" && strings.TrimSpace(it.Cgroup) == cgAbs { + id = it.ID + idx = i + break + } + } + + if id != 0 { + _ = nftDeleteAppMarkRule(target, id) + } + if idx >= 0 { + st.Items = append(st.Items[:idx], st.Items[idx+1:]...) + changed = true + } + + if changed { + return saveAppMarksState(st) + } + return nil +} + +func appMarksClear(target string) error { + target = strings.ToLower(strings.TrimSpace(target)) + if target != "vpn" && target != "direct" { + return fmt.Errorf("invalid target") + } + + appMarksMu.Lock() + defer appMarksMu.Unlock() + + st := loadAppMarksState() + changed := pruneExpiredAppMarksLocked(&st, time.Now().UTC()) + + kept := st.Items[:0] + for _, it := range st.Items { + if strings.ToLower(strings.TrimSpace(it.Target)) == target { + _ = nftDeleteAppMarkRule(target, it.ID) + changed = true + continue + } + kept = append(kept, it) + } + st.Items = kept + + if changed { + return saveAppMarksState(st) + } + return nil +} + +func ensureAppMarksNft() error { + // Best-effort "ensure": ignore "exists" errors and proceed. + _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "table", "inet", appMarksTable) + _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "chain", "inet", appMarksTable, "output", "{", "type", "route", "hook", "output", "priority", "mangle;", "policy", "accept;", "}") + _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "chain", "inet", appMarksTable, appMarksChain) + + out, _, _, _ := runCommandTimeout(5*time.Second, "nft", "list", "chain", "inet", appMarksTable, "output") + if !strings.Contains(out, "jump "+appMarksChain) { + _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "insert", "rule", "inet", appMarksTable, "output", "jump", appMarksChain) + } + + // Remove legacy rules that relied on `meta cgroup @svpn_cg_*` (broken on some kernels). + _ = cleanupLegacyAppMarksRules() + return nil +} + +func cleanupLegacyAppMarksRules() error { + out, _, _, _ := runCommandTimeout(5*time.Second, "nft", "-a", "list", "chain", "inet", appMarksTable, appMarksChain) + for _, line := range strings.Split(out, "\n") { + l := strings.ToLower(line) + if !strings.Contains(l, "meta cgroup") { + continue + } + if !strings.Contains(l, "svpn_cg_") { + continue + } + h := parseNftHandle(line) + if h <= 0 { + continue + } + _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "delete", "rule", "inet", appMarksTable, appMarksChain, "handle", strconv.Itoa(h)) + } + return nil +} + +func appMarkComment(target string, id uint64) string { + return fmt.Sprintf("%s:%s:%d", appMarkCommentPrefix, target, id) +} + +func nftInsertAppMarkRule(target string, rel string, level int, id uint64) error { + mark := MARK_DIRECT + if target == "vpn" { + mark = MARK_APP + } + comment := appMarkComment(target, id) + // EN: nft requires a *string literal* for cgroupv2 path; paths with '@' (user@1000.service) + // EN: break tokenization unless we pass quotes as part of nft language input. + // RU: nft ожидает *строку* для cgroupv2 пути; пути с '@' (user@1000.service) + // RU: ломают токенизацию, поэтому передаем кавычки как часть nft-выражения. + pathLit := fmt.Sprintf("\"%s\"", rel) + commentLit := fmt.Sprintf("\"%s\"", comment) + + _, out, code, err := runCommandTimeout( + 5*time.Second, + "nft", "insert", "rule", "inet", appMarksTable, appMarksChain, + "socket", "cgroupv2", "level", strconv.Itoa(level), pathLit, + "meta", "mark", "set", mark, + "accept", + "comment", commentLit, + ) + if err != nil || code != 0 { + if err == nil { + err = fmt.Errorf("nft insert rule exited with %d", code) + } + return fmt.Errorf("nft insert appmark rule failed: %w (%s)", err, strings.TrimSpace(out)) + } + return nil +} + +func nftDeleteAppMarkRule(target string, id uint64) error { + comment := appMarkComment(target, id) + out, _, _, _ := runCommandTimeout(5*time.Second, "nft", "-a", "list", "chain", "inet", appMarksTable, appMarksChain) + for _, line := range strings.Split(out, "\n") { + if !strings.Contains(line, comment) { + continue + } + h := parseNftHandle(line) + if h <= 0 { + continue + } + _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "delete", "rule", "inet", appMarksTable, appMarksChain, "handle", strconv.Itoa(h)) + } + return nil +} + +func parseNftHandle(line string) int { + fields := strings.Fields(line) + for i := 0; i < len(fields)-1; i++ { + if fields[i] == "handle" { + n, _ := strconv.Atoi(fields[i+1]) + return n + } + } + return 0 +} + +func resolveCgroupV2PathForNft(input string) (rel string, level int, inodeID uint64, abs string, err error) { raw := strings.TrimSpace(input) if raw == "" { - return 0, "", fmt.Errorf("empty cgroup") + return "", 0, 0, "", fmt.Errorf("empty cgroup") } - // Allow numeric cgroup id input. - if isAllDigits(raw) { - id, err := strconv.ParseUint(raw, 10, 64) - if err != nil || id == 0 { - return 0, raw, fmt.Errorf("invalid cgroup id: %s", raw) - } - return id, raw, nil + rel = normalizeCgroupRelOnly(raw) + if rel == "" { + return "", 0, 0, raw, fmt.Errorf("invalid cgroup path: %s", raw) } - // Normalize into a safe relative path under /sys/fs/cgroup. - rel := strings.TrimPrefix(raw, "/") + inodeID, err = cgroupDirInode(rel) + if err != nil { + return "", 0, 0, raw, err + } + + level = strings.Count(rel, "/") + 1 + abs = "/" + rel + return rel, level, inodeID, abs, nil +} + +func normalizeCgroupRelOnly(raw string) string { + rel := strings.TrimSpace(raw) + rel = strings.TrimPrefix(rel, "/") rel = filepath.Clean(rel) if rel == "." || rel == "" { - return 0, raw, fmt.Errorf("invalid cgroup path: %s", raw) + return "" } if strings.HasPrefix(rel, "..") || strings.Contains(rel, "../") { - return 0, raw, fmt.Errorf("invalid cgroup path (traversal): %s", raw) + return "" } + return rel +} - full := filepath.Join(cgroupRootFS, rel) +func cgroupDirInode(rel string) (uint64, error) { + full := filepath.Join(cgroupRootPath, strings.TrimPrefix(rel, "/")) fi, err := os.Stat(full) if err != nil || fi == nil || !fi.IsDir() { - return 0, raw, fmt.Errorf("cgroup not found: %s", raw) + return 0, fmt.Errorf("cgroup not found: %s", "/"+strings.TrimPrefix(rel, "/")) } st, ok := fi.Sys().(*syscall.Stat_t) if !ok || st == nil { - return 0, raw, fmt.Errorf("cannot stat cgroup: %s", raw) + return 0, fmt.Errorf("cannot stat cgroup: %s", "/"+strings.TrimPrefix(rel, "/")) } if st.Ino == 0 { - return 0, raw, fmt.Errorf("invalid cgroup inode id: %s", raw) + return 0, fmt.Errorf("invalid cgroup inode id: %s", "/"+strings.TrimPrefix(rel, "/")) } - // EN: For cgroup v2, the directory inode is used as cgroup id (matches meta cgroup / bpf_get_current_cgroup_id). - // RU: Для cgroup v2 inode директории используется как cgroup id (соответствует meta cgroup / bфункции bpf_get_current_cgroup_id). - return st.Ino, "/" + rel, nil + return st.Ino, nil } -func nftAddCgroupElement(setName string, cgroupID uint64, timeoutSec int) error { - if strings.TrimSpace(setName) == "" { - return fmt.Errorf("empty setName") - } - if cgroupID == 0 { - return fmt.Errorf("invalid cgroup id") - } - if timeoutSec < 0 { - return fmt.Errorf("invalid timeout_sec") - } +func pruneExpiredAppMarks() error { + appMarksMu.Lock() + defer appMarksMu.Unlock() - // NOTE: set has flags timeout; element can include timeout. - ttl := fmt.Sprintf("%ds", timeoutSec) - _, out, code, err := runCommandTimeout( - 5*time.Second, - "nft", "add", "element", "inet", "agvpn", setName, - "{", fmt.Sprintf("%d", cgroupID), "timeout", ttl, "}", - ) - if err != nil || code != 0 { - msg := strings.ToLower(out) - if strings.Contains(msg, "file exists") || strings.Contains(msg, "exists") { - return nil - } - if err == nil { - err = fmt.Errorf("nft add element exited with %d", code) - } - return fmt.Errorf("nft add element failed: %w", err) + st := loadAppMarksState() + if pruneExpiredAppMarksLocked(&st, time.Now().UTC()) { + return saveAppMarksState(st) } return nil } -func nftDelCgroupElement(setName string, cgroupID uint64) error { - if strings.TrimSpace(setName) == "" { - return fmt.Errorf("empty setName") +func pruneExpiredAppMarksLocked(st *appMarksState, now time.Time) (changed bool) { + if st == nil { + return false } - if cgroupID == 0 { - return fmt.Errorf("invalid cgroup id") - } - _, out, code, err := runCommandTimeout( - 5*time.Second, - "nft", "delete", "element", "inet", "agvpn", setName, - "{", fmt.Sprintf("%d", cgroupID), "}", - ) - if err != nil || code != 0 { - msg := strings.ToLower(out) - if strings.Contains(msg, "no such file") || - strings.Contains(msg, "not found") || - strings.Contains(msg, "does not exist") { - return nil + kept := st.Items[:0] + for _, it := range st.Items { + exp, err := time.Parse(time.RFC3339, strings.TrimSpace(it.ExpiresAt)) + if err != nil || !exp.After(now) { + _ = nftDeleteAppMarkRule(strings.ToLower(strings.TrimSpace(it.Target)), it.ID) + changed = true + continue } - if err == nil { - err = fmt.Errorf("nft delete element exited with %d", code) - } - return fmt.Errorf("nft delete element failed: %w", err) + kept = append(kept, it) } - return nil + st.Items = kept + return changed } -func nftFlushSet(setName string) error { - if strings.TrimSpace(setName) == "" { - return fmt.Errorf("empty setName") - } - _, out, code, err := runCommandTimeout(5*time.Second, "nft", "flush", "set", "inet", "agvpn", setName) - if err != nil || code != 0 { - msg := strings.ToLower(out) - if strings.Contains(msg, "no such file") || - strings.Contains(msg, "not found") || - strings.Contains(msg, "does not exist") { - return nil +func upsertAppMarkItem(items []appMarkItem, next appMarkItem) []appMarkItem { + out := items[:0] + for _, it := range items { + if strings.ToLower(strings.TrimSpace(it.Target)) == strings.ToLower(strings.TrimSpace(next.Target)) && it.ID == next.ID { + continue } - if err == nil { - err = fmt.Errorf("nft flush set exited with %d", code) - } - return fmt.Errorf("nft flush set failed: %w", err) + out = append(out, it) } - return nil + out = append(out, next) + return out +} + +func loadAppMarksState() appMarksState { + st := appMarksState{Version: 1} + data, err := os.ReadFile(trafficAppMarksPath) + if err != nil { + return st + } + if err := json.Unmarshal(data, &st); err != nil { + return appMarksState{Version: 1} + } + if st.Version == 0 { + st.Version = 1 + } + return st +} + +func saveAppMarksState(st appMarksState) error { + st.Version = 1 + st.UpdatedAt = time.Now().UTC().Format(time.RFC3339) + + data, err := json.MarshalIndent(st, "", " ") + if err != nil { + return err + } + if err := os.MkdirAll(stateDir, 0o755); err != nil { + return err + } + tmp := trafficAppMarksPath + ".tmp" + if err := os.WriteFile(tmp, data, 0o644); err != nil { + return err + } + return os.Rename(tmp, trafficAppMarksPath) } func isAllDigits(s string) bool { diff --git a/selective-vpn-api/app/watchers.go b/selective-vpn-api/app/watchers.go index 3144f92..ef40a18 100644 --- a/selective-vpn-api/app/watchers.go +++ b/selective-vpn-api/app/watchers.go @@ -24,12 +24,14 @@ func startWatchers(ctx context.Context) { autoEvery := time.Duration(envInt("SVPN_POLL_AUTOLOOP_MS", defaultPollAutoloopMs)) * time.Millisecond systemdEvery := time.Duration(envInt("SVPN_POLL_SYSTEMD_MS", defaultPollSystemdMs)) * time.Millisecond traceEvery := time.Duration(envInt("SVPN_POLL_TRACE_MS", defaultPollTraceMs)) * time.Millisecond + appMarksEvery := time.Duration(envInt("SVPN_POLL_APPMARKS_MS", defaultPollAppMarksMs)) * time.Millisecond go watchStatusFile(ctx, statusEvery) go watchLoginFile(ctx, loginEvery) go watchAutoloop(ctx, autoEvery) go watchFileChange(ctx, traceLogPath, "trace_changed", "full", traceEvery) go watchFileChange(ctx, smartdnsLogPath, "trace_changed", "smartdns", traceEvery) + go watchTrafficAppMarksTTL(ctx, appMarksEvery) go watchSystemdUnitDynamic(ctx, routesServiceUnitName, "routes_service", systemdEvery) go watchSystemdUnitDynamic(ctx, routesTimerUnitName, "routes_timer", systemdEvery) @@ -37,6 +39,17 @@ func startWatchers(ctx context.Context) { go watchSystemdUnit(ctx, "smartdns-local.service", "smartdns_unit", systemdEvery) } +func watchTrafficAppMarksTTL(ctx context.Context, every time.Duration) { + for { + select { + case <-ctx.Done(): + return + case <-time.After(every): + } + _ = pruneExpiredAppMarks() + } +} + // --------------------------------------------------------------------- // status file watcher // ---------------------------------------------------------------------