package app import ( "encoding/json" "fmt" "io" "net/http" "os" "path/filepath" "strconv" "strings" "sync" "syscall" "time" ) // --------------------------------------------------------------------- // traffic app marks (per-app routing via cgroupv2 path -> fwmark) // --------------------------------------------------------------------- // // EN: This module manages runtime per-app routing marks. // EN: We match cgroupv2 paths using nftables `socket cgroupv2` and set fwmark: // EN: - MARK_APP (VPN) or MARK_DIRECT (direct). // EN: TTL is kept in a JSON state file; expired entries are pruned. // RU: Этот модуль управляет runtime per-app маршрутизацией. // RU: Мы матчим cgroupv2 path через nftables `socket cgroupv2` и ставим fwmark: // RU: - MARK_APP (VPN) или MARK_DIRECT (direct). // RU: TTL хранится в JSON состоянии; просроченные записи удаляются. const ( appMarksTable = "agvpn" appMarksChain = "output_apps" appMarkCommentPrefix = "svpn_appmark" defaultAppMarkTTLSeconds = 24 * 60 * 60 ) var appMarksMu sync.Mutex type appMarksState struct { Version int `json:"version"` UpdatedAt string `json:"updated_at"` Items []appMarkItem `json:"items,omitempty"` } type appMarkItem struct { ID uint64 `json:"id"` Target string `json:"target"` // vpn|direct Cgroup string `json:"cgroup"` // absolute path ("/user.slice/..."), informational CgroupRel string `json:"cgroup_rel"` Level int `json:"level"` Unit string `json:"unit,omitempty"` Command string `json:"command,omitempty"` AppKey string `json:"app_key,omitempty"` AddedAt string `json:"added_at"` ExpiresAt string `json:"expires_at"` } func handleTrafficAppMarks(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: vpnCount, directCount := appMarksGetStatus() writeJSON(w, http.StatusOK, TrafficAppMarksStatusResponse{ VPNCount: vpnCount, DirectCount: directCount, Message: "ok", }) case http.MethodPost: var body TrafficAppMarksRequest if r.Body != nil { defer r.Body.Close() if err := json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&body); err != nil && err != io.EOF { http.Error(w, "bad json", http.StatusBadRequest) return } } op := TrafficAppMarksOp(strings.ToLower(strings.TrimSpace(string(body.Op)))) target := strings.ToLower(strings.TrimSpace(body.Target)) cgroup := strings.TrimSpace(body.Cgroup) unit := strings.TrimSpace(body.Unit) command := strings.TrimSpace(body.Command) appKey := strings.TrimSpace(body.AppKey) timeoutSec := body.TimeoutSec if op == "" { http.Error(w, "missing op", http.StatusBadRequest) return } if target == "" { http.Error(w, "missing target", http.StatusBadRequest) return } if target != "vpn" && target != "direct" { http.Error(w, "target must be vpn|direct", http.StatusBadRequest) return } if (op == TrafficAppMarksAdd || op == TrafficAppMarksDel) && cgroup == "" { http.Error(w, "missing cgroup", http.StatusBadRequest) return } if timeoutSec < 0 { http.Error(w, "timeout_sec must be >= 0", http.StatusBadRequest) return } if err := ensureAppMarksNft(); err != nil { writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ OK: false, Op: string(op), Target: target, Cgroup: cgroup, Message: "nft init failed: " + err.Error(), }) return } switch op { case TrafficAppMarksAdd: if isAllDigits(cgroup) { writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ OK: false, Op: string(op), Target: target, Cgroup: cgroup, Message: "cgroup must be a cgroupv2 path (ControlGroup), not a numeric id", }) return } ttl := timeoutSec if ttl == 0 { ttl = defaultAppMarkTTLSeconds } rel, level, inodeID, cgAbs, err := resolveCgroupV2PathForNft(cgroup) if err != nil { writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ OK: false, Op: string(op), Target: target, Cgroup: body.Cgroup, Message: err.Error(), }) return } if target == "vpn" { traffic := loadTrafficModeState() iface, _ := resolveTrafficIface(traffic.PreferredIface) if strings.TrimSpace(iface) == "" { writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ OK: false, Op: string(op), Target: target, Cgroup: cgAbs, CgroupID: inodeID, Message: "vpn interface not found (set preferred iface or bring VPN up)", }) return } if err := ensureTrafficRouteBase(iface, traffic.AutoLocalBypass); err != nil { writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ OK: false, Op: string(op), Target: target, Cgroup: cgAbs, CgroupID: inodeID, Message: "ensure vpn route base failed: " + err.Error(), }) return } } if err := appMarksAdd(target, inodeID, cgAbs, rel, level, unit, command, appKey, ttl); err != nil { writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ OK: false, Op: string(op), Target: target, Cgroup: cgAbs, CgroupID: inodeID, TimeoutSec: ttl, Message: err.Error(), }) return } appendTraceLine("traffic", fmt.Sprintf("appmarks add target=%s cgroup=%s id=%d ttl=%ds", target, cgAbs, inodeID, ttl)) writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ OK: true, Op: string(op), Target: target, Cgroup: cgAbs, CgroupID: inodeID, TimeoutSec: ttl, Message: "added", }) case TrafficAppMarksDel: if err := appMarksDel(target, cgroup); err != nil { writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ OK: false, Op: string(op), Target: target, Cgroup: cgroup, Message: err.Error(), }) return } appendTraceLine("traffic", fmt.Sprintf("appmarks del target=%s cgroup=%s", target, cgroup)) writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ OK: true, Op: string(op), Target: target, Cgroup: cgroup, Message: "deleted", }) case TrafficAppMarksClear: if err := appMarksClear(target); err != nil { writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ OK: false, Op: string(op), Target: target, Message: err.Error(), }) return } appendTraceLine("traffic", fmt.Sprintf("appmarks clear target=%s", target)) writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ OK: true, Op: string(op), Target: target, Message: "cleared", }) default: http.Error(w, "unknown op", http.StatusBadRequest) } default: http.Error(w, "method not allowed", http.StatusMethodNotAllowed) } } func appMarksGetStatus() (vpnCount int, directCount int) { _ = pruneExpiredAppMarks() appMarksMu.Lock() defer appMarksMu.Unlock() st := loadAppMarksState() for _, it := range st.Items { switch strings.ToLower(strings.TrimSpace(it.Target)) { case "vpn": vpnCount++ case "direct": directCount++ } } return vpnCount, directCount } func appMarksAdd(target string, id uint64, cgAbs string, rel string, level int, unit string, command string, appKey string, ttlSec int) error { target = strings.ToLower(strings.TrimSpace(target)) if target != "vpn" && target != "direct" { return fmt.Errorf("invalid target") } if id == 0 { return fmt.Errorf("invalid cgroup id") } if strings.TrimSpace(rel) == "" || level <= 0 { return fmt.Errorf("invalid cgroup path") } if ttlSec <= 0 { ttlSec = defaultAppMarkTTLSeconds } appMarksMu.Lock() defer appMarksMu.Unlock() st := loadAppMarksState() changed := pruneExpiredAppMarksLocked(&st, time.Now().UTC()) unit = strings.TrimSpace(unit) command = strings.TrimSpace(command) appKey = normalizeAppKey(appKey, command) // EN: Avoid unbounded growth of marks for the same app. // RU: Не даём бесконечно плодить метки на одно и то же приложение. if appKey != "" { kept := st.Items[:0] for _, it := range st.Items { if strings.ToLower(strings.TrimSpace(it.Target)) == target && strings.TrimSpace(it.AppKey) == appKey && it.ID != id { _ = nftDeleteAppMarkRule(target, it.ID) changed = true continue } kept = append(kept, it) } st.Items = kept } // Replace any existing rule/state for this (target,id). _ = nftDeleteAppMarkRule(target, id) if err := nftInsertAppMarkRule(target, rel, level, id); err != nil { return err } now := time.Now().UTC() item := appMarkItem{ ID: id, Target: target, Cgroup: cgAbs, CgroupRel: rel, Level: level, Unit: unit, Command: command, AppKey: appKey, AddedAt: now.Format(time.RFC3339), ExpiresAt: now.Add(time.Duration(ttlSec) * time.Second).Format(time.RFC3339), } st.Items = upsertAppMarkItem(st.Items, item) changed = true if changed { if err := saveAppMarksState(st); err != nil { return err } } return nil } func appMarksDel(target string, cgroup string) error { target = strings.ToLower(strings.TrimSpace(target)) if target != "vpn" && target != "direct" { return fmt.Errorf("invalid target") } cgroup = strings.TrimSpace(cgroup) if cgroup == "" { return fmt.Errorf("empty cgroup") } appMarksMu.Lock() defer appMarksMu.Unlock() st := loadAppMarksState() changed := pruneExpiredAppMarksLocked(&st, time.Now().UTC()) var id uint64 var cgAbs string if isAllDigits(cgroup) { v, err := strconv.ParseUint(cgroup, 10, 64) if err == nil { id = v } } else { rel := normalizeCgroupRelOnly(cgroup) if rel != "" { cgAbs = "/" + rel // Try to resolve inode id if directory still exists. if inode, err := cgroupDirInode(rel); err == nil { id = inode } } } // Fallback to state lookup by cgroup string. idx := -1 for i, it := range st.Items { if strings.ToLower(strings.TrimSpace(it.Target)) != target { continue } if id != 0 && it.ID == id { idx = i break } if id == 0 && cgAbs != "" && strings.TrimSpace(it.Cgroup) == cgAbs { id = it.ID idx = i break } } if id != 0 { _ = nftDeleteAppMarkRule(target, id) } if idx >= 0 { st.Items = append(st.Items[:idx], st.Items[idx+1:]...) changed = true } if changed { return saveAppMarksState(st) } return nil } func appMarksClear(target string) error { target = strings.ToLower(strings.TrimSpace(target)) if target != "vpn" && target != "direct" { return fmt.Errorf("invalid target") } appMarksMu.Lock() defer appMarksMu.Unlock() st := loadAppMarksState() changed := pruneExpiredAppMarksLocked(&st, time.Now().UTC()) kept := st.Items[:0] for _, it := range st.Items { if strings.ToLower(strings.TrimSpace(it.Target)) == target { _ = nftDeleteAppMarkRule(target, it.ID) changed = true continue } kept = append(kept, it) } st.Items = kept if changed { return saveAppMarksState(st) } return nil } func ensureAppMarksNft() error { // Best-effort "ensure": ignore "exists" errors and proceed. _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "table", "inet", appMarksTable) _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "chain", "inet", appMarksTable, "output", "{", "type", "route", "hook", "output", "priority", "mangle;", "policy", "accept;", "}") _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "chain", "inet", appMarksTable, appMarksChain) out, _, _, _ := runCommandTimeout(5*time.Second, "nft", "list", "chain", "inet", appMarksTable, "output") if !strings.Contains(out, "jump "+appMarksChain) { _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "insert", "rule", "inet", appMarksTable, "output", "jump", appMarksChain) } // Remove legacy rules that relied on `meta cgroup @svpn_cg_*` (broken on some kernels). _ = cleanupLegacyAppMarksRules() return nil } func cleanupLegacyAppMarksRules() error { out, _, _, _ := runCommandTimeout(5*time.Second, "nft", "-a", "list", "chain", "inet", appMarksTable, appMarksChain) for _, line := range strings.Split(out, "\n") { l := strings.ToLower(line) if !strings.Contains(l, "meta cgroup") { continue } if !strings.Contains(l, "svpn_cg_") { continue } h := parseNftHandle(line) if h <= 0 { continue } _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "delete", "rule", "inet", appMarksTable, appMarksChain, "handle", strconv.Itoa(h)) } return nil } func appMarkComment(target string, id uint64) string { return fmt.Sprintf("%s:%s:%d", appMarkCommentPrefix, target, id) } func nftInsertAppMarkRule(target string, rel string, level int, id uint64) error { mark := MARK_DIRECT if target == "vpn" { mark = MARK_APP } comment := appMarkComment(target, id) // EN: nft requires a *string literal* for cgroupv2 path; paths with '@' (user@1000.service) // EN: break tokenization unless we pass quotes as part of nft language input. // RU: nft ожидает *строку* для cgroupv2 пути; пути с '@' (user@1000.service) // RU: ломают токенизацию, поэтому передаем кавычки как часть nft-выражения. pathLit := fmt.Sprintf("\"%s\"", rel) commentLit := fmt.Sprintf("\"%s\"", comment) _, out, code, err := runCommandTimeout( 5*time.Second, "nft", "insert", "rule", "inet", appMarksTable, appMarksChain, "socket", "cgroupv2", "level", strconv.Itoa(level), pathLit, "meta", "mark", "set", mark, "accept", "comment", commentLit, ) if err != nil || code != 0 { if err == nil { err = fmt.Errorf("nft insert rule exited with %d", code) } return fmt.Errorf("nft insert appmark rule failed: %w (%s)", err, strings.TrimSpace(out)) } return nil } func nftDeleteAppMarkRule(target string, id uint64) error { comment := appMarkComment(target, id) out, _, _, _ := runCommandTimeout(5*time.Second, "nft", "-a", "list", "chain", "inet", appMarksTable, appMarksChain) for _, line := range strings.Split(out, "\n") { if !strings.Contains(line, comment) { continue } h := parseNftHandle(line) if h <= 0 { continue } _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "delete", "rule", "inet", appMarksTable, appMarksChain, "handle", strconv.Itoa(h)) } return nil } func parseNftHandle(line string) int { fields := strings.Fields(line) for i := 0; i < len(fields)-1; i++ { if fields[i] == "handle" { n, _ := strconv.Atoi(fields[i+1]) return n } } return 0 } func resolveCgroupV2PathForNft(input string) (rel string, level int, inodeID uint64, abs string, err error) { raw := strings.TrimSpace(input) if raw == "" { return "", 0, 0, "", fmt.Errorf("empty cgroup") } rel = normalizeCgroupRelOnly(raw) if rel == "" { return "", 0, 0, raw, fmt.Errorf("invalid cgroup path: %s", raw) } inodeID, err = cgroupDirInode(rel) if err != nil { return "", 0, 0, raw, err } level = strings.Count(rel, "/") + 1 abs = "/" + rel return rel, level, inodeID, abs, nil } func normalizeCgroupRelOnly(raw string) string { rel := strings.TrimSpace(raw) rel = strings.TrimPrefix(rel, "/") rel = filepath.Clean(rel) if rel == "." || rel == "" { return "" } if strings.HasPrefix(rel, "..") || strings.Contains(rel, "../") { return "" } return rel } func cgroupDirInode(rel string) (uint64, error) { full := filepath.Join(cgroupRootPath, strings.TrimPrefix(rel, "/")) fi, err := os.Stat(full) if err != nil || fi == nil || !fi.IsDir() { return 0, fmt.Errorf("cgroup not found: %s", "/"+strings.TrimPrefix(rel, "/")) } st, ok := fi.Sys().(*syscall.Stat_t) if !ok || st == nil { return 0, fmt.Errorf("cannot stat cgroup: %s", "/"+strings.TrimPrefix(rel, "/")) } if st.Ino == 0 { return 0, fmt.Errorf("invalid cgroup inode id: %s", "/"+strings.TrimPrefix(rel, "/")) } return st.Ino, nil } func pruneExpiredAppMarks() error { appMarksMu.Lock() defer appMarksMu.Unlock() st := loadAppMarksState() if pruneExpiredAppMarksLocked(&st, time.Now().UTC()) { return saveAppMarksState(st) } return nil } func pruneExpiredAppMarksLocked(st *appMarksState, now time.Time) (changed bool) { if st == nil { return false } kept := st.Items[:0] for _, it := range st.Items { exp, err := time.Parse(time.RFC3339, strings.TrimSpace(it.ExpiresAt)) if err != nil || !exp.After(now) { _ = nftDeleteAppMarkRule(strings.ToLower(strings.TrimSpace(it.Target)), it.ID) changed = true continue } kept = append(kept, it) } st.Items = kept return changed } func upsertAppMarkItem(items []appMarkItem, next appMarkItem) []appMarkItem { out := items[:0] for _, it := range items { if strings.ToLower(strings.TrimSpace(it.Target)) == strings.ToLower(strings.TrimSpace(next.Target)) && it.ID == next.ID { continue } out = append(out, it) } out = append(out, next) return out } func loadAppMarksState() appMarksState { st := appMarksState{Version: 1} data, err := os.ReadFile(trafficAppMarksPath) if err != nil { return st } if err := json.Unmarshal(data, &st); err != nil { return appMarksState{Version: 1} } if st.Version == 0 { st.Version = 1 } return st } func saveAppMarksState(st appMarksState) error { st.Version = 1 st.UpdatedAt = time.Now().UTC().Format(time.RFC3339) data, err := json.MarshalIndent(st, "", " ") if err != nil { return err } if err := os.MkdirAll(stateDir, 0o755); err != nil { return err } tmp := trafficAppMarksPath + ".tmp" if err := os.WriteFile(tmp, data, 0o644); err != nil { return err } return os.Rename(tmp, trafficAppMarksPath) } func normalizeAppKey(appKey string, command string) string { key := strings.TrimSpace(appKey) if key != "" { return key } cmd := strings.TrimSpace(command) if cmd == "" { return "" } fields := strings.Fields(cmd) if len(fields) > 0 { return strings.TrimSpace(fields[0]) } return cmd } func isAllDigits(s string) bool { s = strings.TrimSpace(s) if s == "" { return false } for i := 0; i < len(s); i++ { ch := s[i] if ch < '0' || ch > '9' { return false } } return true }