package app

import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"syscall"
	"time"
)

// ---------------------------------------------------------------------
// traffic app marks (per-app routing via cgroup -> fwmark)
// ---------------------------------------------------------------------
//
// This module manages the runtime cgroup-id sets referenced by the nftables
// rules of routes_update.go (output_apps chain). GUI/clients add or remove
// cgroup IDs to force a cgroup's traffic through the VPN (MARK_APP) or
// direct (MARK_DIRECT).

const (
	nftSetCgroupVPN    = "svpn_cg_vpn"
	nftSetCgroupDirect = "svpn_cg_direct"
	cgroupRootFS       = "/sys/fs/cgroup"
)

// handleTrafficAppMarks serves the per-app mark API.
//
// GET reports how many elements the VPN and direct cgroup sets hold.
// POST adds, deletes or clears entries of one set (see trafficAppMarksUpdate).
// Any other method gets 405.
func handleTrafficAppMarks(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodGet:
		trafficAppMarksStatus(w)
	case http.MethodPost:
		trafficAppMarksUpdate(w, r)
	default:
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
	}
}

// trafficAppMarksStatus writes the element counts of both cgroup sets.
// Reads are best-effort: a set that cannot be read is reported with count 0.
func trafficAppMarksStatus(w http.ResponseWriter) {
	vpnElems, _ := readNftSetElements(nftSetCgroupVPN)
	directElems, _ := readNftSetElements(nftSetCgroupDirect)
	writeJSON(w, http.StatusOK, TrafficAppMarksStatusResponse{
		VPNCount:    len(vpnElems),
		DirectCount: len(directElems),
		Message:     "ok",
	})
}

// trafficAppMarksUpdate handles a POST with a TrafficAppMarksRequest body.
//
// op and target are trimmed and lower-cased. target must be "vpn" or
// "direct". add/del require a cgroup (numeric id or path under
// /sys/fs/cgroup). For add, a timeout_sec of 0 means the 24h default.
//
// Malformed requests are answered with HTTP 400 and a plain-text message.
// Operational failures are answered with HTTP 200 and OK=false in the JSON
// body.
func trafficAppMarksUpdate(w http.ResponseWriter, r *http.Request) {
	var body TrafficAppMarksRequest
	if r.Body != nil {
		defer r.Body.Close()
		// Cap the body at 1 MiB. An empty body (io.EOF) is tolerated here and
		// then rejected by validation as "missing op".
		if err := json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&body); err != nil && err != io.EOF {
			http.Error(w, "bad json", http.StatusBadRequest)
			return
		}
	}

	op := TrafficAppMarksOp(strings.ToLower(strings.TrimSpace(string(body.Op))))
	target := strings.ToLower(strings.TrimSpace(body.Target))
	cgroup := strings.TrimSpace(body.Cgroup)
	timeoutSec := body.TimeoutSec

	// Validate everything before touching nftables. ensureAppMarksNft
	// rewrites chains, and an invalid request (e.g. unknown op) must not
	// trigger that.
	if msg := validateTrafficAppMarksRequest(op, target, cgroup, timeoutSec); msg != "" {
		http.Error(w, msg, http.StatusBadRequest)
		return
	}

	fail := func(cg string, id uint64, msg string) {
		writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
			OK:       false,
			Op:       string(op),
			Target:   target,
			Cgroup:   cg,
			CgroupID: id,
			Message:  msg,
		})
	}

	// Ensure nft objects exist even if routes-update hasn't run yet.
	if err := ensureAppMarksNft(); err != nil {
		fail(cgroup, 0, "nft init failed: "+err.Error())
		return
	}

	// Only add/del operate on a single cgroup. clear ignores any cgroup sent.
	var cgID uint64
	if op == TrafficAppMarksAdd || op == TrafficAppMarksDel {
		id, resolved, err := resolveCgroupIDForNft(cgroup)
		if err != nil {
			fail(body.Cgroup, 0, err.Error())
			return
		}
		cgID, cgroup = id, resolved
	}

	if op == TrafficAppMarksAdd && target == "vpn" {
		// The VPN policy table needs a base route, which matters when the
		// current traffic-mode is direct.
		traffic := loadTrafficModeState()
		iface, _ := resolveTrafficIface(traffic.PreferredIface)
		if strings.TrimSpace(iface) == "" {
			fail(cgroup, cgID, "vpn interface not found (set preferred iface or bring VPN up)")
			return
		}
		if err := ensureTrafficRouteBase(iface, traffic.AutoLocalBypass); err != nil {
			fail(cgroup, cgID, "ensure vpn route base failed: "+err.Error())
			return
		}
	}

	setName := nftSetCgroupDirect
	if target == "vpn" {
		setName = nftSetCgroupVPN
	}

	switch op {
	case TrafficAppMarksAdd:
		ttl := timeoutSec
		if ttl == 0 {
			ttl = 24 * 60 * 60 // 24h default if client didn't specify
		}
		if err := nftAddCgroupElement(setName, cgID, ttl); err != nil {
			fail(cgroup, cgID, err.Error())
			return
		}
		appendTraceLine("traffic", fmt.Sprintf("appmarks add target=%s cgroup=%s id=%d ttl=%ds", target, cgroup, cgID, ttl))
		writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
			OK:         true,
			Op:         string(op),
			Target:     target,
			Cgroup:     cgroup,
			CgroupID:   cgID,
			TimeoutSec: ttl,
			Message:    "added",
		})
	case TrafficAppMarksDel:
		if err := nftDelCgroupElement(setName, cgID); err != nil {
			fail(cgroup, cgID, err.Error())
			return
		}
		appendTraceLine("traffic", fmt.Sprintf("appmarks del target=%s cgroup=%s id=%d", target, cgroup, cgID))
		writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
			OK:       true,
			Op:       string(op),
			Target:   target,
			Cgroup:   cgroup,
			CgroupID: cgID,
			Message:  "deleted",
		})
	case TrafficAppMarksClear:
		if err := nftFlushSet(setName); err != nil {
			fail("", 0, err.Error())
			return
		}
		appendTraceLine("traffic", fmt.Sprintf("appmarks clear target=%s", target))
		writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
			OK:      true,
			Op:      string(op),
			Target:  target,
			Message: "cleared",
		})
	default:
		// Unreachable after validation; kept as a defensive fallback.
		http.Error(w, "unknown op", http.StatusBadRequest)
	}
}

// validateTrafficAppMarksRequest checks an already-normalized POST request.
// It returns the message for the first problem found, or "" if the request
// is acceptable.
func validateTrafficAppMarksRequest(op TrafficAppMarksOp, target, cgroup string, timeoutSec int) string {
	needsCgroup := op == TrafficAppMarksAdd || op == TrafficAppMarksDel
	switch {
	case op == "":
		return "missing op"
	case target == "":
		return "missing target"
	case target != "vpn" && target != "direct":
		return "target must be vpn|direct"
	case !needsCgroup && op != TrafficAppMarksClear:
		return "unknown op"
	case needsCgroup && cgroup == "":
		return "missing cgroup"
	case timeoutSec < 0:
		return "timeout_sec must be >= 0"
	}
	return ""
}

// ensureAppMarksNft makes sure table "inet agvpn" contains the objects used
// for per-app marks:
//   - the VPN and direct cgroup sets (with per-element timeouts),
//   - the "output" route chain, which jumps into "output_apps",
//   - the "output_apps" chain, which marks packets whose cgroup is in the
//     direct set (MARK_DIRECT) or the VPN set (MARK_APP) and accepts them.
//
// Creation of the table, sets and chains is best-effort: those commands fail
// harmlessly when the object already exists, so their errors are ignored.
// output_apps is rewritten on every call so it stays deterministic (no
// duplicate rules). This is safe because the chain is dedicated to app marks.
// An error is returned when that rewrite fails (for example because a set is
// missing), when the output chain cannot be listed, or when the jump cannot
// be added.
func ensureAppMarksNft() error {
	const timeout = 5 * time.Second

	_, _, _, _ = runCommandTimeout(timeout, "nft", "add", "table", "inet", "agvpn")
	// NOTE(review): nftables documents `meta cgroup` as a 32-bit id taken from
	// the net_cls classid (cgroup v1), not the cgroup v2 id that
	// resolveCgroupIDForNft stores. Confirm these sets actually match on the
	// target kernel, or consider `socket cgroupv2 level N` instead.
	_, _, _, _ = runCommandTimeout(timeout, "nft", "add", "set", "inet", "agvpn", nftSetCgroupVPN, "{", "typeof", "meta", "cgroup", ";", "flags", "timeout", ";", "}")
	_, _, _, _ = runCommandTimeout(timeout, "nft", "add", "set", "inet", "agvpn", nftSetCgroupDirect, "{", "typeof", "meta", "cgroup", ";", "flags", "timeout", ";", "}")
	_, _, _, _ = runCommandTimeout(timeout, "nft", "add", "chain", "inet", "agvpn", "output", "{", "type", "route", "hook", "output", "priority", "mangle;", "policy", "accept;", "}")
	_, _, _, _ = runCommandTimeout(timeout, "nft", "add", "chain", "inet", "agvpn", "output_apps")

	// Flush and re-add in ONE nft invocation. nft submits the ";"-separated
	// commands of a single run as one transaction, so no packet can observe
	// an empty output_apps. Separate invocations would let marked apps fall
	// through to the default route between the flush and the re-add.
	_, stderr, code, err := runCommandTimeout(timeout, "nft",
		"flush", "chain", "inet", "agvpn", "output_apps", ";",
		"add", "rule", "inet", "agvpn", "output_apps", "meta", "cgroup", "@"+nftSetCgroupDirect, "meta", "mark", "set", MARK_DIRECT, "accept", ";",
		"add", "rule", "inet", "agvpn", "output_apps", "meta", "cgroup", "@"+nftSetCgroupVPN, "meta", "mark", "set", MARK_APP, "accept",
	)
	if err != nil || code != 0 {
		if err == nil {
			err = fmt.Errorf("nft rewrite output_apps exited with %d", code)
		}
		return appMarksNftFailure("rewrite output_apps", stderr, err)
	}

	// Ensure the output chain jumps into output_apps. routes-update may also
	// manage this.
	out, stderr, code, err := runCommandTimeout(timeout, "nft", "list", "chain", "inet", "agvpn", "output")
	if err != nil || code != 0 {
		if err == nil {
			err = fmt.Errorf("nft list chain output exited with %d", code)
		}
		return appMarksNftFailure("list chain output", stderr, err)
	}
	if !strings.Contains(out, "jump output_apps") {
		_, stderr, code, err = runCommandTimeout(timeout, "nft", "add", "rule", "inet", "agvpn", "output", "jump", "output_apps")
		if err != nil || code != 0 {
			if err == nil {
				err = fmt.Errorf("nft add jump output_apps exited with %d", code)
			}
			return appMarksNftFailure("add jump output_apps", stderr, err)
		}
	}
	return nil
}

// resolveCgroupIDForNft turns a client-supplied cgroup reference into the
// numeric id stored in the nft sets.
//
// input is either a decimal cgroup id (must be non-zero; returned as given)
// or a cgroup path relative to /sys/fs/cgroup (leading slashes optional).
// For a path, the id is the inode number of the cgroup directory and the
// returned string is the cleaned path with exactly one leading "/". On error
// the returned string is the trimmed input.
func resolveCgroupIDForNft(input string) (uint64, string, error) {
	raw := strings.TrimSpace(input)
	if raw == "" {
		return 0, "", fmt.Errorf("empty cgroup")
	}

	// Allow numeric cgroup id input.
	if isAllDigits(raw) {
		id, err := strconv.ParseUint(raw, 10, 64)
		if err != nil || id == 0 {
			return 0, raw, fmt.Errorf("invalid cgroup id: %s", raw)
		}
		return id, raw, nil
	}

	// Strip every leading slash so Clean yields a relative path. After Clean,
	// any remaining ".." element can only be at the front, so checking the
	// first element is enough. Names like "..foo" are legitimate and allowed.
	rel := filepath.Clean(strings.TrimLeft(raw, "/"))
	if rel == "." {
		return 0, raw, fmt.Errorf("invalid cgroup path: %s", raw)
	}
	if rel == ".." || strings.HasPrefix(rel, "../") {
		return 0, raw, fmt.Errorf("invalid cgroup path (traversal): %s", raw)
	}

	full := filepath.Join(cgroupRootFS, rel)
	fi, err := os.Stat(full)
	if err != nil || fi == nil || !fi.IsDir() {
		return 0, raw, fmt.Errorf("cgroup not found: %s", raw)
	}
	st, ok := fi.Sys().(*syscall.Stat_t)
	if !ok || st == nil {
		return 0, raw, fmt.Errorf("cannot stat cgroup: %s", raw)
	}
	if st.Ino == 0 {
		return 0, raw, fmt.Errorf("invalid cgroup inode id: %s", raw)
	}
	// On cgroup v2 the directory inode number is used as the cgroup id (as
	// returned by bpf_get_current_cgroup_id on 64-bit kernels).
	// NOTE(review): nft `meta cgroup` compares against the net_cls classid,
	// not this id — confirm the sets match as intended.
	return st.Ino, "/" + rel, nil
}

// nftAddCgroupElement adds cgroupID to set setName in table inet agvpn, with
// an element timeout of timeoutSec seconds. An "exists" reply from nft is
// treated as success.
// NOTE(review): whether re-adding an existing element refreshes its timeout
// depends on the kernel/nft version — confirm if TTL renewal matters.
func nftAddCgroupElement(setName string, cgroupID uint64, timeoutSec int) error {
	if strings.TrimSpace(setName) == "" {
		return fmt.Errorf("empty setName")
	}
	if cgroupID == 0 {
		return fmt.Errorf("invalid cgroup id")
	}
	if timeoutSec < 0 {
		return fmt.Errorf("invalid timeout_sec")
	}

	ttl := strconv.Itoa(timeoutSec) + "s"
	_, out, code, err := runCommandTimeout(
		5*time.Second,
		"nft", "add", "element", "inet", "agvpn", setName,
		"{", strconv.FormatUint(cgroupID, 10), "timeout", ttl, "}",
	)
	if err == nil && code == 0 {
		return nil
	}
	if strings.Contains(strings.ToLower(out), "exists") {
		return nil
	}
	if err == nil {
		err = fmt.Errorf("nft add element exited with %d", code)
	}
	return appMarksNftFailure("add element", out, err)
}

// nftDelCgroupElement removes cgroupID from set setName in table inet agvpn.
// Deleting an element (or set) that does not exist is treated as success.
func nftDelCgroupElement(setName string, cgroupID uint64) error {
	if strings.TrimSpace(setName) == "" {
		return fmt.Errorf("empty setName")
	}
	if cgroupID == 0 {
		return fmt.Errorf("invalid cgroup id")
	}

	_, out, code, err := runCommandTimeout(
		5*time.Second,
		"nft", "delete", "element", "inet", "agvpn", setName,
		"{", strconv.FormatUint(cgroupID, 10), "}",
	)
	if err == nil && code == 0 {
		return nil
	}
	if appMarksNftMissing(out) {
		return nil
	}
	if err == nil {
		err = fmt.Errorf("nft delete element exited with %d", code)
	}
	return appMarksNftFailure("delete element", out, err)
}

// nftFlushSet removes all elements from set setName in table inet agvpn.
// Flushing a set that does not exist is treated as success.
func nftFlushSet(setName string) error {
	if strings.TrimSpace(setName) == "" {
		return fmt.Errorf("empty setName")
	}

	_, out, code, err := runCommandTimeout(5*time.Second, "nft", "flush", "set", "inet", "agvpn", setName)
	if err == nil && code == 0 {
		return nil
	}
	if appMarksNftMissing(out) {
		return nil
	}
	if err == nil {
		err = fmt.Errorf("nft flush set exited with %d", code)
	}
	return appMarksNftFailure("flush set", out, err)
}

// appMarksNftFailure wraps err as "nft <action> failed: ...". When nft printed
// a diagnostic, it appends the first line of that diagnostic so callers and
// API clients can see why nft rejected the command. err must be non-nil.
func appMarksNftFailure(action, output string, err error) error {
	detail := strings.TrimSpace(output)
	if i := strings.IndexByte(detail, '\n'); i >= 0 {
		detail = strings.TrimSpace(detail[:i])
	}
	if detail == "" {
		return fmt.Errorf("nft %s failed: %w", action, err)
	}
	return fmt.Errorf("nft %s failed: %w (%s)", action, err, detail)
}

// appMarksNftMissing reports whether nft's output says the referenced object
// does not exist.
func appMarksNftMissing(output string) bool {
	msg := strings.ToLower(output)
	return strings.Contains(msg, "no such file") ||
		strings.Contains(msg, "not found") ||
		strings.Contains(msg, "does not exist")
}

// isAllDigits reports whether s, after trimming surrounding whitespace, is a
// non-empty run of ASCII decimal digits.
func isAllDigits(s string) bool {
	s = strings.TrimSpace(s)
	if s == "" {
		return false
	}
	for i := 0; i < len(s); i++ {
		if s[i] < '0' || s[i] > '9' {
			return false
		}
	}
	return true
}