fix(appmarks): use nft socket cgroupv2 rules for per-app routing

This commit is contained in:
beckline
2026-02-15 14:43:13 +03:00
parent 4b99057adb
commit b77adb153a
4 changed files with 468 additions and 198 deletions

View File

@@ -12,10 +12,11 @@ import "embed"
// --------------------------------------------------------------------- // ---------------------------------------------------------------------
const ( const (
stateDir = "/var/lib/selective-vpn" stateDir = "/var/lib/selective-vpn"
statusFilePath = stateDir + "/status.json" statusFilePath = stateDir + "/status.json"
dnsModePath = stateDir + "/dns-mode.json" dnsModePath = stateDir + "/dns-mode.json"
trafficModePath = stateDir + "/traffic-mode.json" trafficModePath = stateDir + "/traffic-mode.json"
trafficAppMarksPath = stateDir + "/traffic-appmarks.json"
traceLogPath = stateDir + "/trace.log" traceLogPath = stateDir + "/trace.log"
smartdnsLogPath = stateDir + "/smartdns.log" smartdnsLogPath = stateDir + "/smartdns.log"
@@ -80,6 +81,7 @@ const (
defaultPollAutoloopMs = 2500 defaultPollAutoloopMs = 2500
defaultPollSystemdMs = 3000 defaultPollSystemdMs = 3000
defaultPollTraceMs = 1500 defaultPollTraceMs = 1500
defaultPollAppMarksMs = 15000
defaultHeartbeatSeconds = 15 defaultHeartbeatSeconds = 15
) )

View File

@@ -151,14 +151,12 @@ func routesUpdate(iface string) cmdResult {
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "set", "inet", "agvpn", "agvpn4", "{", "type", "ipv4_addr", ";", "flags", "interval", ";", "}") _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "set", "inet", "agvpn", "agvpn4", "{", "type", "ipv4_addr", ";", "flags", "interval", ";", "}")
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "set", "inet", "agvpn", "agvpn_dyn4", "{", "type", "ipv4_addr", ";", "flags", "interval", ";", "}") _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "set", "inet", "agvpn", "agvpn_dyn4", "{", "type", "ipv4_addr", ";", "flags", "interval", ";", "}")
// EN: Per-app routing support (cgroup-mark sets). Output chain jumps into: // EN: Output chain jumps into:
// EN: - output_apps: app-scoped marks (MARK_DIRECT / MARK_APP) // EN: - output_apps: runtime per-app marks (MARK_DIRECT / MARK_APP)
// EN: - output_ips: selective domain IP sets (MARK) // EN: - output_ips: selective domain IP sets (MARK)
// RU: Поддержка per-app (cgroup-mark sets). Output chain прыгает в: // RU: Output chain прыгает в:
// RU: - output_apps: per-app marks (MARK_DIRECT / MARK_APP) // RU: - output_apps: runtime per-app marks (MARK_DIRECT / MARK_APP)
// RU: - output_ips: селективные доменные IP сеты (MARK) // RU: - output_ips: селективные доменные IP сеты (MARK)
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "set", "inet", "agvpn", "svpn_cg_vpn", "{", "typeof", "meta", "cgroup", ";", "flags", "timeout", ";", "}")
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "set", "inet", "agvpn", "svpn_cg_direct", "{", "typeof", "meta", "cgroup", ";", "flags", "timeout", ";", "}")
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "chain", "inet", "agvpn", "output", "{", "type", "route", "hook", "output", "priority", "mangle;", "policy", "accept;", "}") _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "chain", "inet", "agvpn", "output", "{", "type", "route", "hook", "output", "priority", "mangle;", "policy", "accept;", "}")
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "chain", "inet", "agvpn", "output_apps") _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "chain", "inet", "agvpn", "output_apps")
@@ -169,10 +167,7 @@ func routesUpdate(iface string) cmdResult {
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "rule", "inet", "agvpn", "output", "jump", "output_apps") _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "rule", "inet", "agvpn", "output", "jump", "output_apps")
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "rule", "inet", "agvpn", "output", "jump", "output_ips") _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "rule", "inet", "agvpn", "output", "jump", "output_ips")
// App chain: mark + accept to stop further evaluation in this base chain. // App chain: runtime rules are managed by traffic_appmarks.go (do not flush here).
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "flush", "chain", "inet", "agvpn", "output_apps")
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "rule", "inet", "agvpn", "output_apps", "meta", "cgroup", "@svpn_cg_direct", "meta", "mark", "set", MARK_DIRECT, "accept")
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "rule", "inet", "agvpn", "output_apps", "meta", "cgroup", "@svpn_cg_vpn", "meta", "mark", "set", MARK_APP, "accept")
// Domain chain: selective IP sets (resolver output). // Domain chain: selective IP sets (resolver output).
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "flush", "chain", "inet", "agvpn", "output_ips") _, _, _, _ = runCommandTimeout(5*time.Second, "nft", "flush", "chain", "inet", "agvpn", "output_ips")

View File

@@ -9,35 +9,56 @@ import (
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
"sync"
"syscall" "syscall"
"time" "time"
) )
// --------------------------------------------------------------------- // ---------------------------------------------------------------------
// traffic app marks (per-app routing via cgroup -> fwmark) // traffic app marks (per-app routing via cgroupv2 path -> fwmark)
// --------------------------------------------------------------------- // ---------------------------------------------------------------------
// //
// EN: This module manages runtime cgroup-id sets used by nftables rules in // EN: This module manages runtime per-app routing marks.
// EN: routes_update.go (output_apps chain). GUI/clients can add/remove cgroup IDs // EN: We match cgroupv2 paths using nftables `socket cgroupv2` and set fwmark:
// EN: to force traffic through VPN (MARK_APP) or force direct (MARK_DIRECT). // EN: - MARK_APP (VPN) or MARK_DIRECT (direct).
// RU: Этот модуль управляет runtime cgroup-id сетами для nftables правил из // EN: TTL is kept in a JSON state file; expired entries are pruned.
// RU: routes_update.go (цепочка output_apps). GUI/клиенты могут добавлять/удалять // RU: Этот модуль управляет runtime per-app маршрутизацией.
// RU: cgroup IDs, чтобы форсировать трафик через VPN (MARK_APP) или в direct (MARK_DIRECT). // RU: Мы матчим cgroupv2 path через nftables `socket cgroupv2` и ставим fwmark:
// RU: - MARK_APP (VPN) или MARK_DIRECT (direct).
// RU: TTL хранится в JSON состоянии; просроченные записи удаляются.
const ( const (
nftSetCgroupVPN = "svpn_cg_vpn" appMarksTable = "agvpn"
nftSetCgroupDirect = "svpn_cg_direct" appMarksChain = "output_apps"
cgroupRootFS = "/sys/fs/cgroup" appMarkCommentPrefix = "svpn_appmark"
defaultAppMarkTTLSeconds = 24 * 60 * 60
) )
var appMarksMu sync.Mutex
type appMarksState struct {
Version int `json:"version"`
UpdatedAt string `json:"updated_at"`
Items []appMarkItem `json:"items,omitempty"`
}
type appMarkItem struct {
ID uint64 `json:"id"`
Target string `json:"target"` // vpn|direct
Cgroup string `json:"cgroup"` // absolute path ("/user.slice/..."), informational
CgroupRel string `json:"cgroup_rel"`
Level int `json:"level"`
AddedAt string `json:"added_at"`
ExpiresAt string `json:"expires_at"`
}
func handleTrafficAppMarks(w http.ResponseWriter, r *http.Request) { func handleTrafficAppMarks(w http.ResponseWriter, r *http.Request) {
switch r.Method { switch r.Method {
case http.MethodGet: case http.MethodGet:
vpnElems, _ := readNftSetElements(nftSetCgroupVPN) vpnCount, directCount := appMarksGetStatus()
directElems, _ := readNftSetElements(nftSetCgroupDirect)
writeJSON(w, http.StatusOK, TrafficAppMarksStatusResponse{ writeJSON(w, http.StatusOK, TrafficAppMarksStatusResponse{
VPNCount: len(vpnElems), VPNCount: vpnCount,
DirectCount: len(directElems), DirectCount: directCount,
Message: "ok", Message: "ok",
}) })
case http.MethodPost: case http.MethodPost:
@@ -76,7 +97,6 @@ func handleTrafficAppMarks(w http.ResponseWriter, r *http.Request) {
return return
} }
// Ensure nft objects exist even if routes-update hasn't run yet.
if err := ensureAppMarksNft(); err != nil { if err := ensureAppMarksNft(); err != nil {
writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
OK: false, OK: false,
@@ -88,12 +108,25 @@ func handleTrafficAppMarks(w http.ResponseWriter, r *http.Request) {
return return
} }
var ( switch op {
cgID uint64 case TrafficAppMarksAdd:
err error if isAllDigits(cgroup) {
) writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
if cgroup != "" { OK: false,
cgID, cgroup, err = resolveCgroupIDForNft(cgroup) Op: string(op),
Target: target,
Cgroup: cgroup,
Message: "cgroup must be a cgroupv2 path (ControlGroup), not a numeric id",
})
return
}
ttl := timeoutSec
if ttl == 0 {
ttl = defaultAppMarkTTLSeconds
}
rel, level, inodeID, cgAbs, err := resolveCgroupV2PathForNft(cgroup)
if err != nil { if err != nil {
writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
OK: false, OK: false,
@@ -104,91 +137,78 @@ func handleTrafficAppMarks(w http.ResponseWriter, r *http.Request) {
}) })
return return
} }
}
if op == TrafficAppMarksAdd && target == "vpn" { if target == "vpn" {
// Ensure VPN policy table has a base route. This matters when current traffic-mode=direct. traffic := loadTrafficModeState()
traffic := loadTrafficModeState() iface, _ := resolveTrafficIface(traffic.PreferredIface)
iface, _ := resolveTrafficIface(traffic.PreferredIface) if strings.TrimSpace(iface) == "" {
if strings.TrimSpace(iface) == "" { writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
OK: false,
Op: string(op),
Target: target,
Cgroup: cgAbs,
CgroupID: inodeID,
Message: "vpn interface not found (set preferred iface or bring VPN up)",
})
return
}
if err := ensureTrafficRouteBase(iface, traffic.AutoLocalBypass); err != nil {
writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
OK: false,
Op: string(op),
Target: target,
Cgroup: cgAbs,
CgroupID: inodeID,
Message: "ensure vpn route base failed: " + err.Error(),
})
return
}
}
if err := appMarksAdd(target, inodeID, cgAbs, rel, level, ttl); err != nil {
writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
OK: false, OK: false,
Op: string(op), Op: string(op),
Target: target, Target: target,
Cgroup: cgroup, Cgroup: cgAbs,
CgroupID: cgID, CgroupID: inodeID,
Message: "vpn interface not found (set preferred iface or bring VPN up)", TimeoutSec: ttl,
Message: err.Error(),
}) })
return return
} }
if err := ensureTrafficRouteBase(iface, traffic.AutoLocalBypass); err != nil {
writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
OK: false,
Op: string(op),
Target: target,
Cgroup: cgroup,
CgroupID: cgID,
Message: "ensure vpn route base failed: " + err.Error(),
})
return
}
}
setName := nftSetCgroupDirect appendTraceLine("traffic", fmt.Sprintf("appmarks add target=%s cgroup=%s id=%d ttl=%ds", target, cgAbs, inodeID, ttl))
if target == "vpn" {
setName = nftSetCgroupVPN
}
switch op {
case TrafficAppMarksAdd:
ttl := timeoutSec
if ttl == 0 {
ttl = 24 * 60 * 60 // 24h default if client didn't specify
}
if err := nftAddCgroupElement(setName, cgID, ttl); err != nil {
writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
OK: false,
Op: string(op),
Target: target,
Cgroup: cgroup,
CgroupID: cgID,
Message: err.Error(),
})
return
}
appendTraceLine("traffic", fmt.Sprintf("appmarks add target=%s cgroup=%s id=%d ttl=%ds", target, cgroup, cgID, ttl))
writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
OK: true, OK: true,
Op: string(op), Op: string(op),
Target: target, Target: target,
Cgroup: cgroup, Cgroup: cgAbs,
CgroupID: cgID, CgroupID: inodeID,
TimeoutSec: ttl, TimeoutSec: ttl,
Message: "added", Message: "added",
}) })
case TrafficAppMarksDel: case TrafficAppMarksDel:
if err := nftDelCgroupElement(setName, cgID); err != nil { if err := appMarksDel(target, cgroup); err != nil {
writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
OK: false, OK: false,
Op: string(op), Op: string(op),
Target: target, Target: target,
Cgroup: cgroup, Cgroup: cgroup,
CgroupID: cgID, Message: err.Error(),
Message: err.Error(),
}) })
return return
} }
appendTraceLine("traffic", fmt.Sprintf("appmarks del target=%s cgroup=%s id=%d", target, cgroup, cgID)) appendTraceLine("traffic", fmt.Sprintf("appmarks del target=%s cgroup=%s", target, cgroup))
writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
OK: true, OK: true,
Op: string(op), Op: string(op),
Target: target, Target: target,
Cgroup: cgroup, Cgroup: cgroup,
CgroupID: cgID, Message: "deleted",
Message: "deleted",
}) })
case TrafficAppMarksClear: case TrafficAppMarksClear:
if err := nftFlushSet(setName); err != nil { if err := appMarksClear(target); err != nil {
writeJSON(w, http.StatusOK, TrafficAppMarksResponse{ writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
OK: false, OK: false,
Op: string(op), Op: string(op),
@@ -212,145 +232,385 @@ func handleTrafficAppMarks(w http.ResponseWriter, r *http.Request) {
} }
} }
func ensureAppMarksNft() error { func appMarksGetStatus() (vpnCount int, directCount int) {
// Best-effort "ensure": ignore "exists" errors and proceed. _ = pruneExpiredAppMarks()
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "table", "inet", "agvpn")
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "set", "inet", "agvpn", nftSetCgroupVPN, "{", "typeof", "meta", "cgroup", ";", "flags", "timeout", ";", "}")
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "set", "inet", "agvpn", nftSetCgroupDirect, "{", "typeof", "meta", "cgroup", ";", "flags", "timeout", ";", "}")
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "chain", "inet", "agvpn", "output", "{", "type", "route", "hook", "output", "priority", "mangle;", "policy", "accept;", "}")
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "chain", "inet", "agvpn", "output_apps")
// Keep output_apps deterministic (no duplicates). Safe because this chain is dedicated. appMarksMu.Lock()
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "flush", "chain", "inet", "agvpn", "output_apps") defer appMarksMu.Unlock()
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "rule", "inet", "agvpn", "output_apps", "meta", "cgroup", "@"+nftSetCgroupDirect, "meta", "mark", "set", MARK_DIRECT, "accept")
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "rule", "inet", "agvpn", "output_apps", "meta", "cgroup", "@"+nftSetCgroupVPN, "meta", "mark", "set", MARK_APP, "accept")
// Ensure output chain has a jump into output_apps (routes-update may also manage this). st := loadAppMarksState()
out, _, _, _ := runCommandTimeout(5*time.Second, "nft", "list", "chain", "inet", "agvpn", "output") for _, it := range st.Items {
if !strings.Contains(out, "jump output_apps") { switch strings.ToLower(strings.TrimSpace(it.Target)) {
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "rule", "inet", "agvpn", "output", "jump", "output_apps") case "vpn":
vpnCount++
case "direct":
directCount++
}
}
return vpnCount, directCount
}
func appMarksAdd(target string, id uint64, cgAbs string, rel string, level int, ttlSec int) error {
target = strings.ToLower(strings.TrimSpace(target))
if target != "vpn" && target != "direct" {
return fmt.Errorf("invalid target")
}
if id == 0 {
return fmt.Errorf("invalid cgroup id")
}
if strings.TrimSpace(rel) == "" || level <= 0 {
return fmt.Errorf("invalid cgroup path")
}
if ttlSec <= 0 {
ttlSec = defaultAppMarkTTLSeconds
}
appMarksMu.Lock()
defer appMarksMu.Unlock()
st := loadAppMarksState()
changed := pruneExpiredAppMarksLocked(&st, time.Now().UTC())
// Replace any existing rule/state for this (target,id).
_ = nftDeleteAppMarkRule(target, id)
if err := nftInsertAppMarkRule(target, rel, level, id); err != nil {
return err
}
now := time.Now().UTC()
item := appMarkItem{
ID: id,
Target: target,
Cgroup: cgAbs,
CgroupRel: rel,
Level: level,
AddedAt: now.Format(time.RFC3339),
ExpiresAt: now.Add(time.Duration(ttlSec) * time.Second).Format(time.RFC3339),
}
st.Items = upsertAppMarkItem(st.Items, item)
changed = true
if changed {
if err := saveAppMarksState(st); err != nil {
return err
}
} }
return nil return nil
} }
func resolveCgroupIDForNft(input string) (uint64, string, error) { func appMarksDel(target string, cgroup string) error {
target = strings.ToLower(strings.TrimSpace(target))
if target != "vpn" && target != "direct" {
return fmt.Errorf("invalid target")
}
cgroup = strings.TrimSpace(cgroup)
if cgroup == "" {
return fmt.Errorf("empty cgroup")
}
appMarksMu.Lock()
defer appMarksMu.Unlock()
st := loadAppMarksState()
changed := pruneExpiredAppMarksLocked(&st, time.Now().UTC())
var id uint64
var cgAbs string
if isAllDigits(cgroup) {
v, err := strconv.ParseUint(cgroup, 10, 64)
if err == nil {
id = v
}
} else {
rel := normalizeCgroupRelOnly(cgroup)
if rel != "" {
cgAbs = "/" + rel
// Try to resolve inode id if directory still exists.
if inode, err := cgroupDirInode(rel); err == nil {
id = inode
}
}
}
// Fallback to state lookup by cgroup string.
idx := -1
for i, it := range st.Items {
if strings.ToLower(strings.TrimSpace(it.Target)) != target {
continue
}
if id != 0 && it.ID == id {
idx = i
break
}
if id == 0 && cgAbs != "" && strings.TrimSpace(it.Cgroup) == cgAbs {
id = it.ID
idx = i
break
}
}
if id != 0 {
_ = nftDeleteAppMarkRule(target, id)
}
if idx >= 0 {
st.Items = append(st.Items[:idx], st.Items[idx+1:]...)
changed = true
}
if changed {
return saveAppMarksState(st)
}
return nil
}
func appMarksClear(target string) error {
target = strings.ToLower(strings.TrimSpace(target))
if target != "vpn" && target != "direct" {
return fmt.Errorf("invalid target")
}
appMarksMu.Lock()
defer appMarksMu.Unlock()
st := loadAppMarksState()
changed := pruneExpiredAppMarksLocked(&st, time.Now().UTC())
kept := st.Items[:0]
for _, it := range st.Items {
if strings.ToLower(strings.TrimSpace(it.Target)) == target {
_ = nftDeleteAppMarkRule(target, it.ID)
changed = true
continue
}
kept = append(kept, it)
}
st.Items = kept
if changed {
return saveAppMarksState(st)
}
return nil
}
func ensureAppMarksNft() error {
// Best-effort "ensure": ignore "exists" errors and proceed.
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "table", "inet", appMarksTable)
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "chain", "inet", appMarksTable, "output", "{", "type", "route", "hook", "output", "priority", "mangle;", "policy", "accept;", "}")
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "chain", "inet", appMarksTable, appMarksChain)
out, _, _, _ := runCommandTimeout(5*time.Second, "nft", "list", "chain", "inet", appMarksTable, "output")
if !strings.Contains(out, "jump "+appMarksChain) {
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "insert", "rule", "inet", appMarksTable, "output", "jump", appMarksChain)
}
// Remove legacy rules that relied on `meta cgroup @svpn_cg_*` (broken on some kernels).
_ = cleanupLegacyAppMarksRules()
return nil
}
func cleanupLegacyAppMarksRules() error {
out, _, _, _ := runCommandTimeout(5*time.Second, "nft", "-a", "list", "chain", "inet", appMarksTable, appMarksChain)
for _, line := range strings.Split(out, "\n") {
l := strings.ToLower(line)
if !strings.Contains(l, "meta cgroup") {
continue
}
if !strings.Contains(l, "svpn_cg_") {
continue
}
h := parseNftHandle(line)
if h <= 0 {
continue
}
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "delete", "rule", "inet", appMarksTable, appMarksChain, "handle", strconv.Itoa(h))
}
return nil
}
func appMarkComment(target string, id uint64) string {
return fmt.Sprintf("%s:%s:%d", appMarkCommentPrefix, target, id)
}
func nftInsertAppMarkRule(target string, rel string, level int, id uint64) error {
mark := MARK_DIRECT
if target == "vpn" {
mark = MARK_APP
}
comment := appMarkComment(target, id)
// EN: nft requires a *string literal* for cgroupv2 path; paths with '@' (user@1000.service)
// EN: break tokenization unless we pass quotes as part of nft language input.
// RU: nft ожидает *строку* для cgroupv2 пути; пути с '@' (user@1000.service)
// RU: ломают токенизацию, поэтому передаем кавычки как часть nft-выражения.
pathLit := fmt.Sprintf("\"%s\"", rel)
commentLit := fmt.Sprintf("\"%s\"", comment)
_, out, code, err := runCommandTimeout(
5*time.Second,
"nft", "insert", "rule", "inet", appMarksTable, appMarksChain,
"socket", "cgroupv2", "level", strconv.Itoa(level), pathLit,
"meta", "mark", "set", mark,
"accept",
"comment", commentLit,
)
if err != nil || code != 0 {
if err == nil {
err = fmt.Errorf("nft insert rule exited with %d", code)
}
return fmt.Errorf("nft insert appmark rule failed: %w (%s)", err, strings.TrimSpace(out))
}
return nil
}
func nftDeleteAppMarkRule(target string, id uint64) error {
comment := appMarkComment(target, id)
out, _, _, _ := runCommandTimeout(5*time.Second, "nft", "-a", "list", "chain", "inet", appMarksTable, appMarksChain)
for _, line := range strings.Split(out, "\n") {
if !strings.Contains(line, comment) {
continue
}
h := parseNftHandle(line)
if h <= 0 {
continue
}
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "delete", "rule", "inet", appMarksTable, appMarksChain, "handle", strconv.Itoa(h))
}
return nil
}
func parseNftHandle(line string) int {
fields := strings.Fields(line)
for i := 0; i < len(fields)-1; i++ {
if fields[i] == "handle" {
n, _ := strconv.Atoi(fields[i+1])
return n
}
}
return 0
}
func resolveCgroupV2PathForNft(input string) (rel string, level int, inodeID uint64, abs string, err error) {
raw := strings.TrimSpace(input) raw := strings.TrimSpace(input)
if raw == "" { if raw == "" {
return 0, "", fmt.Errorf("empty cgroup") return "", 0, 0, "", fmt.Errorf("empty cgroup")
} }
// Allow numeric cgroup id input. rel = normalizeCgroupRelOnly(raw)
if isAllDigits(raw) { if rel == "" {
id, err := strconv.ParseUint(raw, 10, 64) return "", 0, 0, raw, fmt.Errorf("invalid cgroup path: %s", raw)
if err != nil || id == 0 {
return 0, raw, fmt.Errorf("invalid cgroup id: %s", raw)
}
return id, raw, nil
} }
// Normalize into a safe relative path under /sys/fs/cgroup. inodeID, err = cgroupDirInode(rel)
rel := strings.TrimPrefix(raw, "/") if err != nil {
return "", 0, 0, raw, err
}
level = strings.Count(rel, "/") + 1
abs = "/" + rel
return rel, level, inodeID, abs, nil
}
func normalizeCgroupRelOnly(raw string) string {
rel := strings.TrimSpace(raw)
rel = strings.TrimPrefix(rel, "/")
rel = filepath.Clean(rel) rel = filepath.Clean(rel)
if rel == "." || rel == "" { if rel == "." || rel == "" {
return 0, raw, fmt.Errorf("invalid cgroup path: %s", raw) return ""
} }
if strings.HasPrefix(rel, "..") || strings.Contains(rel, "../") { if strings.HasPrefix(rel, "..") || strings.Contains(rel, "../") {
return 0, raw, fmt.Errorf("invalid cgroup path (traversal): %s", raw) return ""
} }
return rel
}
full := filepath.Join(cgroupRootFS, rel) func cgroupDirInode(rel string) (uint64, error) {
full := filepath.Join(cgroupRootPath, strings.TrimPrefix(rel, "/"))
fi, err := os.Stat(full) fi, err := os.Stat(full)
if err != nil || fi == nil || !fi.IsDir() { if err != nil || fi == nil || !fi.IsDir() {
return 0, raw, fmt.Errorf("cgroup not found: %s", raw) return 0, fmt.Errorf("cgroup not found: %s", "/"+strings.TrimPrefix(rel, "/"))
} }
st, ok := fi.Sys().(*syscall.Stat_t) st, ok := fi.Sys().(*syscall.Stat_t)
if !ok || st == nil { if !ok || st == nil {
return 0, raw, fmt.Errorf("cannot stat cgroup: %s", raw) return 0, fmt.Errorf("cannot stat cgroup: %s", "/"+strings.TrimPrefix(rel, "/"))
} }
if st.Ino == 0 { if st.Ino == 0 {
return 0, raw, fmt.Errorf("invalid cgroup inode id: %s", raw) return 0, fmt.Errorf("invalid cgroup inode id: %s", "/"+strings.TrimPrefix(rel, "/"))
} }
// EN: For cgroup v2, the directory inode is used as cgroup id (matches meta cgroup / bpf_get_current_cgroup_id). return st.Ino, nil
// RU: Для cgroup v2 inode директории используется как cgroup id (соответствует meta cgroup / bфункции bpf_get_current_cgroup_id).
return st.Ino, "/" + rel, nil
} }
func nftAddCgroupElement(setName string, cgroupID uint64, timeoutSec int) error { func pruneExpiredAppMarks() error {
if strings.TrimSpace(setName) == "" { appMarksMu.Lock()
return fmt.Errorf("empty setName") defer appMarksMu.Unlock()
}
if cgroupID == 0 {
return fmt.Errorf("invalid cgroup id")
}
if timeoutSec < 0 {
return fmt.Errorf("invalid timeout_sec")
}
// NOTE: set has flags timeout; element can include timeout. st := loadAppMarksState()
ttl := fmt.Sprintf("%ds", timeoutSec) if pruneExpiredAppMarksLocked(&st, time.Now().UTC()) {
_, out, code, err := runCommandTimeout( return saveAppMarksState(st)
5*time.Second,
"nft", "add", "element", "inet", "agvpn", setName,
"{", fmt.Sprintf("%d", cgroupID), "timeout", ttl, "}",
)
if err != nil || code != 0 {
msg := strings.ToLower(out)
if strings.Contains(msg, "file exists") || strings.Contains(msg, "exists") {
return nil
}
if err == nil {
err = fmt.Errorf("nft add element exited with %d", code)
}
return fmt.Errorf("nft add element failed: %w", err)
} }
return nil return nil
} }
func nftDelCgroupElement(setName string, cgroupID uint64) error { func pruneExpiredAppMarksLocked(st *appMarksState, now time.Time) (changed bool) {
if strings.TrimSpace(setName) == "" { if st == nil {
return fmt.Errorf("empty setName") return false
} }
if cgroupID == 0 { kept := st.Items[:0]
return fmt.Errorf("invalid cgroup id") for _, it := range st.Items {
} exp, err := time.Parse(time.RFC3339, strings.TrimSpace(it.ExpiresAt))
_, out, code, err := runCommandTimeout( if err != nil || !exp.After(now) {
5*time.Second, _ = nftDeleteAppMarkRule(strings.ToLower(strings.TrimSpace(it.Target)), it.ID)
"nft", "delete", "element", "inet", "agvpn", setName, changed = true
"{", fmt.Sprintf("%d", cgroupID), "}", continue
)
if err != nil || code != 0 {
msg := strings.ToLower(out)
if strings.Contains(msg, "no such file") ||
strings.Contains(msg, "not found") ||
strings.Contains(msg, "does not exist") {
return nil
} }
if err == nil { kept = append(kept, it)
err = fmt.Errorf("nft delete element exited with %d", code)
}
return fmt.Errorf("nft delete element failed: %w", err)
} }
return nil st.Items = kept
return changed
} }
func nftFlushSet(setName string) error { func upsertAppMarkItem(items []appMarkItem, next appMarkItem) []appMarkItem {
if strings.TrimSpace(setName) == "" { out := items[:0]
return fmt.Errorf("empty setName") for _, it := range items {
} if strings.ToLower(strings.TrimSpace(it.Target)) == strings.ToLower(strings.TrimSpace(next.Target)) && it.ID == next.ID {
_, out, code, err := runCommandTimeout(5*time.Second, "nft", "flush", "set", "inet", "agvpn", setName) continue
if err != nil || code != 0 {
msg := strings.ToLower(out)
if strings.Contains(msg, "no such file") ||
strings.Contains(msg, "not found") ||
strings.Contains(msg, "does not exist") {
return nil
} }
if err == nil { out = append(out, it)
err = fmt.Errorf("nft flush set exited with %d", code)
}
return fmt.Errorf("nft flush set failed: %w", err)
} }
return nil out = append(out, next)
return out
}
func loadAppMarksState() appMarksState {
st := appMarksState{Version: 1}
data, err := os.ReadFile(trafficAppMarksPath)
if err != nil {
return st
}
if err := json.Unmarshal(data, &st); err != nil {
return appMarksState{Version: 1}
}
if st.Version == 0 {
st.Version = 1
}
return st
}
func saveAppMarksState(st appMarksState) error {
st.Version = 1
st.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
data, err := json.MarshalIndent(st, "", " ")
if err != nil {
return err
}
if err := os.MkdirAll(stateDir, 0o755); err != nil {
return err
}
tmp := trafficAppMarksPath + ".tmp"
if err := os.WriteFile(tmp, data, 0o644); err != nil {
return err
}
return os.Rename(tmp, trafficAppMarksPath)
} }
func isAllDigits(s string) bool { func isAllDigits(s string) bool {

View File

@@ -24,12 +24,14 @@ func startWatchers(ctx context.Context) {
autoEvery := time.Duration(envInt("SVPN_POLL_AUTOLOOP_MS", defaultPollAutoloopMs)) * time.Millisecond autoEvery := time.Duration(envInt("SVPN_POLL_AUTOLOOP_MS", defaultPollAutoloopMs)) * time.Millisecond
systemdEvery := time.Duration(envInt("SVPN_POLL_SYSTEMD_MS", defaultPollSystemdMs)) * time.Millisecond systemdEvery := time.Duration(envInt("SVPN_POLL_SYSTEMD_MS", defaultPollSystemdMs)) * time.Millisecond
traceEvery := time.Duration(envInt("SVPN_POLL_TRACE_MS", defaultPollTraceMs)) * time.Millisecond traceEvery := time.Duration(envInt("SVPN_POLL_TRACE_MS", defaultPollTraceMs)) * time.Millisecond
appMarksEvery := time.Duration(envInt("SVPN_POLL_APPMARKS_MS", defaultPollAppMarksMs)) * time.Millisecond
go watchStatusFile(ctx, statusEvery) go watchStatusFile(ctx, statusEvery)
go watchLoginFile(ctx, loginEvery) go watchLoginFile(ctx, loginEvery)
go watchAutoloop(ctx, autoEvery) go watchAutoloop(ctx, autoEvery)
go watchFileChange(ctx, traceLogPath, "trace_changed", "full", traceEvery) go watchFileChange(ctx, traceLogPath, "trace_changed", "full", traceEvery)
go watchFileChange(ctx, smartdnsLogPath, "trace_changed", "smartdns", traceEvery) go watchFileChange(ctx, smartdnsLogPath, "trace_changed", "smartdns", traceEvery)
go watchTrafficAppMarksTTL(ctx, appMarksEvery)
go watchSystemdUnitDynamic(ctx, routesServiceUnitName, "routes_service", systemdEvery) go watchSystemdUnitDynamic(ctx, routesServiceUnitName, "routes_service", systemdEvery)
go watchSystemdUnitDynamic(ctx, routesTimerUnitName, "routes_timer", systemdEvery) go watchSystemdUnitDynamic(ctx, routesTimerUnitName, "routes_timer", systemdEvery)
@@ -37,6 +39,17 @@ func startWatchers(ctx context.Context) {
go watchSystemdUnit(ctx, "smartdns-local.service", "smartdns_unit", systemdEvery) go watchSystemdUnit(ctx, "smartdns-local.service", "smartdns_unit", systemdEvery)
} }
func watchTrafficAppMarksTTL(ctx context.Context, every time.Duration) {
for {
select {
case <-ctx.Done():
return
case <-time.After(every):
}
_ = pruneExpiredAppMarks()
}
}
// --------------------------------------------------------------------- // ---------------------------------------------------------------------
// status file watcher // status file watcher
// --------------------------------------------------------------------- // ---------------------------------------------------------------------