Files
elmprodvpn/selective-vpn-api/app/traffic_appmarks.go

727 lines
19 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package app
import (
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"sync"
"syscall"
"time"
)
// ---------------------------------------------------------------------
// traffic app marks (per-app routing via cgroupv2 path -> fwmark)
// ---------------------------------------------------------------------
//
// This module manages runtime per-app routing marks.
// Rules in the nftables chain `inet agvpn output_apps` match a cgroupv2 path
// with `socket cgroupv2` and set the fwmark:
//   - MARK_APP for target "vpn", MARK_DIRECT for target "direct".
//
// Each mark has a TTL recorded in a JSON state file (trafficAppMarksPath);
// expired entries are pruned from the file and their nft rules deleted.
const (
	// appMarksTable is the nftables table (family inet) holding the app-mark chains.
	appMarksTable = "agvpn"
	// appMarksChain is the regular chain, jumped to from the table's "output"
	// chain, that holds one mark rule per marked cgroup.
	appMarksChain = "output_apps"
	// appMarkCommentPrefix starts every mark rule's comment
	// ("<prefix>:<target>:<id>"), which is how a rule is found again for deletion.
	appMarkCommentPrefix = "svpn_appmark"
	// defaultAppMarkTTLSeconds (24h) is used when no positive TTL is given.
	defaultAppMarkTTLSeconds = 24 * 60 * 60
)

// appMarksMu serializes load/modify/save cycles of the app-marks state file
// together with the nft rule changes made inside those cycles.
var appMarksMu sync.Mutex
// appMarksState is the on-disk JSON document stored at trafficAppMarksPath.
type appMarksState struct {
	// Version is the state format version; saveAppMarksState always writes 1.
	Version int `json:"version"`
	// UpdatedAt is the RFC3339 UTC time of the last save.
	UpdatedAt string `json:"updated_at"`
	// Items are the currently active marks (expired ones are pruned).
	Items []appMarkItem `json:"items,omitempty"`
}
// appMarkItem is one persisted app mark: a cgroupv2 path routed to a target.
// Within the state, (Target, ID) identifies an item.
type appMarkItem struct {
	// ID is the inode number of the cgroup directory; it is also embedded in
	// the nft rule comment used to locate the rule later.
	ID     uint64 `json:"id"`
	Target string `json:"target"` // vpn|direct
	// Cgroup is the absolute cgroup path ("/user.slice/..."); besides being
	// returned to clients, appMarksDel uses it to find an item by path.
	Cgroup string `json:"cgroup"`
	// CgroupRel is the same path relative to the cgroup root (no leading "/"),
	// as written into the nft `socket cgroupv2` match.
	CgroupRel string `json:"cgroup_rel"`
	// Level is the number of path components in CgroupRel, passed to
	// `socket cgroupv2 level`.
	Level int `json:"level"`
	// Unit and Command are stored as supplied by the caller (trimmed).
	Unit    string `json:"unit,omitempty"`
	Command string `json:"command,omitempty"`
	// AppKey groups marks of the same application (caller-provided, or the
	// first word of Command); adding a mark replaces older ones with the
	// same target and AppKey.
	AppKey string `json:"app_key,omitempty"`
	// AddedAt and ExpiresAt are RFC3339 UTC timestamps.
	AddedAt   string `json:"added_at"`
	ExpiresAt string `json:"expires_at"`
}
// handleTrafficAppMarks serves the per-app routing marks endpoint.
//
// GET returns how many "vpn" and "direct" marks are active (expired ones are
// pruned first). POST applies one TrafficAppMarksRequest operation:
//   - add:   mark the cgroupv2 path `cgroup` for `target`; timeout_sec 0 means
//     defaultAppMarkTTLSeconds. For target "vpn" the VPN interface must be
//     resolvable and its route base is ensured first.
//   - del:   remove the mark for `cgroup` (a path or a numeric cgroup id).
//   - clear: remove every mark for `target`.
//
// Malformed requests (bad JSON, missing/invalid fields, unknown op) are
// answered with HTTP 400 before any nftables change is made. Operational
// failures are reported as HTTP 200 with OK=false and a message.
func handleTrafficAppMarks(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodGet:
		vpnCount, directCount := appMarksGetStatus()
		writeJSON(w, http.StatusOK, TrafficAppMarksStatusResponse{
			VPNCount:    vpnCount,
			DirectCount: directCount,
			Message:     "ok",
		})
	case http.MethodPost:
		var body TrafficAppMarksRequest
		if r.Body != nil {
			defer r.Body.Close()
			// An empty body (io.EOF) is not a JSON error; it is reported as "missing op" below.
			if err := json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&body); err != nil && err != io.EOF {
				http.Error(w, "bad json", http.StatusBadRequest)
				return
			}
		}
		op := TrafficAppMarksOp(strings.ToLower(strings.TrimSpace(string(body.Op))))
		target := strings.ToLower(strings.TrimSpace(body.Target))
		cgroup := strings.TrimSpace(body.Cgroup)
		unit := strings.TrimSpace(body.Unit)
		command := strings.TrimSpace(body.Command)
		appKey := strings.TrimSpace(body.AppKey)
		timeoutSec := body.TimeoutSec
		if op == "" {
			http.Error(w, "missing op", http.StatusBadRequest)
			return
		}
		if target == "" {
			http.Error(w, "missing target", http.StatusBadRequest)
			return
		}
		if target != "vpn" && target != "direct" {
			http.Error(w, "target must be vpn|direct", http.StatusBadRequest)
			return
		}
		if (op == TrafficAppMarksAdd || op == TrafficAppMarksDel) && cgroup == "" {
			http.Error(w, "missing cgroup", http.StatusBadRequest)
			return
		}
		if timeoutSec < 0 {
			http.Error(w, "timeout_sec must be >= 0", http.StatusBadRequest)
			return
		}
		switch op {
		case TrafficAppMarksAdd, TrafficAppMarksDel, TrafficAppMarksClear:
		default:
			// Reject before ensureAppMarksNft so an invalid request never touches nftables.
			http.Error(w, "unknown op", http.StatusBadRequest)
			return
		}

		// reply fills in the fields shared by every POST response and writes it.
		reply := func(resp TrafficAppMarksResponse) {
			resp.Op = string(op)
			resp.Target = target
			writeJSON(w, http.StatusOK, resp)
		}

		if err := ensureAppMarksNft(); err != nil {
			reply(TrafficAppMarksResponse{
				OK:      false,
				Cgroup:  cgroup,
				Message: "nft init failed: " + err.Error(),
			})
			return
		}

		switch op {
		case TrafficAppMarksAdd:
			// Numeric ids are accepted by del, but add needs a path to build the
			// `socket cgroupv2` match.
			if isAllDigits(cgroup) {
				reply(TrafficAppMarksResponse{
					OK:      false,
					Cgroup:  cgroup,
					Message: "cgroup must be a cgroupv2 path (ControlGroup), not a numeric id",
				})
				return
			}
			ttl := timeoutSec
			if ttl == 0 {
				ttl = defaultAppMarkTTLSeconds
			}
			rel, level, inodeID, cgAbs, err := resolveCgroupV2PathForNft(cgroup)
			if err != nil {
				reply(TrafficAppMarksResponse{
					OK:      false,
					Cgroup:  cgroup,
					Message: err.Error(),
				})
				return
			}
			if target == "vpn" {
				traffic := loadTrafficModeState()
				iface, _ := resolveTrafficIface(traffic.PreferredIface)
				if strings.TrimSpace(iface) == "" {
					reply(TrafficAppMarksResponse{
						OK:       false,
						Cgroup:   cgAbs,
						CgroupID: inodeID,
						Message:  "vpn interface not found (set preferred iface or bring VPN up)",
					})
					return
				}
				if err := ensureTrafficRouteBase(iface, traffic.AutoLocalBypass); err != nil {
					reply(TrafficAppMarksResponse{
						OK:       false,
						Cgroup:   cgAbs,
						CgroupID: inodeID,
						Message:  "ensure vpn route base failed: " + err.Error(),
					})
					return
				}
			}
			if err := appMarksAdd(target, inodeID, cgAbs, rel, level, unit, command, appKey, ttl); err != nil {
				reply(TrafficAppMarksResponse{
					OK:         false,
					Cgroup:     cgAbs,
					CgroupID:   inodeID,
					TimeoutSec: ttl,
					Message:    err.Error(),
				})
				return
			}
			appendTraceLine("traffic", fmt.Sprintf("appmarks add target=%s cgroup=%s id=%d ttl=%ds", target, cgAbs, inodeID, ttl))
			reply(TrafficAppMarksResponse{
				OK:         true,
				Cgroup:     cgAbs,
				CgroupID:   inodeID,
				TimeoutSec: ttl,
				Message:    "added",
			})
		case TrafficAppMarksDel:
			if err := appMarksDel(target, cgroup); err != nil {
				reply(TrafficAppMarksResponse{
					OK:      false,
					Cgroup:  cgroup,
					Message: err.Error(),
				})
				return
			}
			appendTraceLine("traffic", fmt.Sprintf("appmarks del target=%s cgroup=%s", target, cgroup))
			reply(TrafficAppMarksResponse{
				OK:      true,
				Cgroup:  cgroup,
				Message: "deleted",
			})
		case TrafficAppMarksClear:
			if err := appMarksClear(target); err != nil {
				reply(TrafficAppMarksResponse{
					OK:      false,
					Message: err.Error(),
				})
				return
			}
			appendTraceLine("traffic", fmt.Sprintf("appmarks clear target=%s", target))
			reply(TrafficAppMarksResponse{
				OK:      true,
				Message: "cleared",
			})
		}
	default:
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
	}
}
func handleTrafficAppMarksItems(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
_ = pruneExpiredAppMarks()
appMarksMu.Lock()
st := loadAppMarksState()
appMarksMu.Unlock()
now := time.Now().UTC()
items := make([]TrafficAppMarkItemView, 0, len(st.Items))
for _, it := range st.Items {
rem := 0
exp, err := time.Parse(time.RFC3339, strings.TrimSpace(it.ExpiresAt))
if err == nil {
rem = int(exp.Sub(now).Seconds())
if rem < 0 {
rem = 0
}
}
items = append(items, TrafficAppMarkItemView{
ID: it.ID,
Target: it.Target,
Cgroup: it.Cgroup,
CgroupRel: it.CgroupRel,
Level: it.Level,
Unit: it.Unit,
Command: it.Command,
AppKey: it.AppKey,
AddedAt: it.AddedAt,
ExpiresAt: it.ExpiresAt,
RemainingSec: rem,
})
}
// Sort: target -> app_key -> remaining desc.
sort.Slice(items, func(i, j int) bool {
if items[i].Target != items[j].Target {
return items[i].Target < items[j].Target
}
if items[i].AppKey != items[j].AppKey {
return items[i].AppKey < items[j].AppKey
}
return items[i].RemainingSec > items[j].RemainingSec
})
writeJSON(w, http.StatusOK, TrafficAppMarksItemsResponse{Items: items, Message: "ok"})
}
func appMarksGetStatus() (vpnCount int, directCount int) {
_ = pruneExpiredAppMarks()
appMarksMu.Lock()
defer appMarksMu.Unlock()
st := loadAppMarksState()
for _, it := range st.Items {
switch strings.ToLower(strings.TrimSpace(it.Target)) {
case "vpn":
vpnCount++
case "direct":
directCount++
}
}
return vpnCount, directCount
}
// appMarksAdd installs (or refreshes) the mark rule for cgroup `id` and
// records it in the state file with a TTL of ttlSec seconds
// (<= 0 means defaultAppMarkTTLSeconds).
//
// target must be "vpn" or "direct"; rel/level describe the cgroupv2 path as
// produced by resolveCgroupV2PathForNft. When an app key is known (explicit,
// or derived from command), older marks of the same app and target are
// removed first so repeated launches do not accumulate rules.
//
// The state file is kept in sync with nft even on failure: if the new rule
// cannot be inserted, the entries whose rules were already deleted are also
// dropped from the state before the error is returned.
func appMarksAdd(target string, id uint64, cgAbs string, rel string, level int, unit string, command string, appKey string, ttlSec int) error {
	target = strings.ToLower(strings.TrimSpace(target))
	if target != "vpn" && target != "direct" {
		return fmt.Errorf("invalid target")
	}
	if id == 0 {
		return fmt.Errorf("invalid cgroup id")
	}
	if strings.TrimSpace(rel) == "" || level <= 0 {
		return fmt.Errorf("invalid cgroup path")
	}
	if ttlSec <= 0 {
		ttlSec = defaultAppMarkTTLSeconds
	}
	ttl := time.Duration(ttlSec) * time.Second
	if ttl/time.Second != time.Duration(ttlSec) {
		// The nanosecond count overflowed int64 (which would yield an expiry in
		// the past); clamp to the longest representable duration instead.
		ttl = time.Duration(1<<63 - 1)
	}

	appMarksMu.Lock()
	defer appMarksMu.Unlock()

	now := time.Now().UTC()
	st := loadAppMarksState()
	pruneExpiredAppMarksLocked(&st, now)

	unit = strings.TrimSpace(unit)
	command = strings.TrimSpace(command)
	appKey = normalizeAppKey(appKey, command)

	// Avoid unbounded growth of marks for the same app.
	if appKey != "" {
		kept := st.Items[:0]
		for _, it := range st.Items {
			if strings.ToLower(strings.TrimSpace(it.Target)) == target &&
				strings.TrimSpace(it.AppKey) == appKey &&
				it.ID != id {
				_ = nftDeleteAppMarkRule(target, it.ID)
				continue
			}
			kept = append(kept, it)
		}
		st.Items = kept
	}

	// Replace any existing rule/state for this (target, id).
	_ = nftDeleteAppMarkRule(target, id)
	if err := nftInsertAppMarkRule(target, rel, level, id); err != nil {
		// The previous rule for this id (and any same-app rules) are already
		// gone from nft, so drop the matching entry from the state too.
		kept := st.Items[:0]
		for _, it := range st.Items {
			if strings.ToLower(strings.TrimSpace(it.Target)) == target && it.ID == id {
				continue
			}
			kept = append(kept, it)
		}
		st.Items = kept
		if saveErr := saveAppMarksState(st); saveErr != nil {
			return fmt.Errorf("%w; saving app marks state: %v", err, saveErr)
		}
		return err
	}

	st.Items = upsertAppMarkItem(st.Items, appMarkItem{
		ID:        id,
		Target:    target,
		Cgroup:    cgAbs,
		CgroupRel: rel,
		Level:     level,
		Unit:      unit,
		Command:   command,
		AppKey:    appKey,
		AddedAt:   now.Format(time.RFC3339),
		ExpiresAt: now.Add(ttl).Format(time.RFC3339),
	})
	return saveAppMarksState(st)
}
// appMarksDel removes the mark(s) of `target` for a cgroup given either as a
// numeric cgroup id or as a cgroupv2 path.
//
// For a path, the current inode id is resolved when the directory still
// exists. Every state entry of the target whose ID matches that id, or whose
// stored Cgroup equals the normalized path, is removed together with its nft
// rule. Matching by path as well as by id handles cgroups that were deleted
// or recreated (new inode) since the mark was added. A rule for the resolved
// id is also deleted even if no state entry refers to it.
func appMarksDel(target string, cgroup string) error {
	target = strings.ToLower(strings.TrimSpace(target))
	if target != "vpn" && target != "direct" {
		return fmt.Errorf("invalid target")
	}
	cgroup = strings.TrimSpace(cgroup)
	if cgroup == "" {
		return fmt.Errorf("empty cgroup")
	}
	appMarksMu.Lock()
	defer appMarksMu.Unlock()
	st := loadAppMarksState()
	changed := pruneExpiredAppMarksLocked(&st, time.Now().UTC())

	var id uint64
	var cgAbs string
	if isAllDigits(cgroup) {
		// Out-of-range values leave id == 0, i.e. nothing matches.
		if v, err := strconv.ParseUint(cgroup, 10, 64); err == nil {
			id = v
		}
	} else if rel := normalizeCgroupRelOnly(cgroup); rel != "" {
		cgAbs = "/" + rel
		// Try to resolve the inode id if the directory still exists.
		if inode, err := cgroupDirInode(rel); err == nil {
			id = inode
		}
	}

	idHandled := false
	kept := st.Items[:0]
	for _, it := range st.Items {
		sameTarget := strings.ToLower(strings.TrimSpace(it.Target)) == target
		byID := id != 0 && it.ID == id
		byPath := cgAbs != "" && strings.TrimSpace(it.Cgroup) == cgAbs
		if sameTarget && (byID || byPath) {
			_ = nftDeleteAppMarkRule(target, it.ID)
			if it.ID == id {
				idHandled = true
			}
			changed = true
			continue
		}
		kept = append(kept, it)
	}
	st.Items = kept

	if id != 0 && !idHandled {
		// A rule may exist in nft without a state entry; remove it as well.
		_ = nftDeleteAppMarkRule(target, id)
	}
	if changed {
		return saveAppMarksState(st)
	}
	return nil
}
// appMarksClear deletes every mark (nft rule and state entry) whose target
// is `target` ("vpn" or "direct"), after pruning expired marks. The state
// file is rewritten only when something was removed.
func appMarksClear(target string) error {
	wanted := strings.ToLower(strings.TrimSpace(target))
	if wanted != "vpn" && wanted != "direct" {
		return fmt.Errorf("invalid target")
	}
	appMarksMu.Lock()
	defer appMarksMu.Unlock()

	st := loadAppMarksState()
	dirty := pruneExpiredAppMarksLocked(&st, time.Now().UTC())
	remaining := st.Items[:0]
	for _, item := range st.Items {
		if strings.ToLower(strings.TrimSpace(item.Target)) != wanted {
			remaining = append(remaining, item)
			continue
		}
		_ = nftDeleteAppMarkRule(wanted, item.ID)
		dirty = true
	}
	st.Items = remaining

	if !dirty {
		return nil
	}
	return saveAppMarksState(st)
}
// ensureAppMarksNft makes sure the nftables objects used for app marks exist:
//   - table inet agvpn;
//   - its "output" chain (type route, hook output, priority mangle);
//   - the regular chain output_apps;
//   - a "jump output_apps" rule in "output".
//
// It also removes legacy `meta cgroup @svpn_cg_*` rules.
//
// The "add" commands are best-effort (they are expected to succeed when the
// object already exists). The outcome is then verified: an error is
// returned if the output chain cannot be listed or the jump rule cannot be
// inserted, so callers can report "nft init failed" instead of silently
// proceeding.
func ensureAppMarksNft() error {
	const timeout = 5 * time.Second
	_, _, _, _ = runCommandTimeout(timeout, "nft", "add", "table", "inet", appMarksTable)
	_, _, _, _ = runCommandTimeout(timeout, "nft", "add", "chain", "inet", appMarksTable, "output", "{", "type", "route", "hook", "output", "priority", "mangle;", "policy", "accept;", "}")
	_, _, _, _ = runCommandTimeout(timeout, "nft", "add", "chain", "inet", appMarksTable, appMarksChain)

	// cmdErr describes a failed nft invocation, including its stderr.
	cmdErr := func(what, stderr string, code int, err error) error {
		if err == nil {
			err = fmt.Errorf("exit code %d", code)
		}
		return fmt.Errorf("nft %s: %w (%s)", what, err, strings.TrimSpace(stderr))
	}

	out, stderr, code, err := runCommandTimeout(timeout, "nft", "list", "chain", "inet", appMarksTable, "output")
	if err != nil || code != 0 {
		return cmdErr("list chain inet "+appMarksTable+" output", stderr, code, err)
	}
	if !strings.Contains(out, "jump "+appMarksChain) {
		_, insStderr, insCode, insErr := runCommandTimeout(timeout, "nft", "insert", "rule", "inet", appMarksTable, "output", "jump", appMarksChain)
		if insErr != nil || insCode != 0 {
			return cmdErr("insert jump "+appMarksChain, insStderr, insCode, insErr)
		}
	}
	// Remove legacy rules that relied on `meta cgroup @svpn_cg_*` (broken on some kernels).
	_ = cleanupLegacyAppMarksRules()
	return nil
}
// cleanupLegacyAppMarksRules deletes, from the output_apps chain, rules that
// match `meta cgroup` against an svpn_cg_* set (the older matching scheme).
// Listing and deletion failures are ignored; it always returns nil.
func cleanupLegacyAppMarksRules() error {
	listing, _, _, _ := runCommandTimeout(5*time.Second, "nft", "-a", "list", "chain", "inet", appMarksTable, appMarksChain)
	for _, raw := range strings.Split(listing, "\n") {
		lower := strings.ToLower(raw)
		isLegacy := strings.Contains(lower, "meta cgroup") && strings.Contains(lower, "svpn_cg_")
		if !isLegacy {
			continue
		}
		if handle := parseNftHandle(raw); handle > 0 {
			_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "delete", "rule", "inet", appMarksTable, appMarksChain, "handle", strconv.Itoa(handle))
		}
	}
	return nil
}
func appMarkComment(target string, id uint64) string {
return fmt.Sprintf("%s:%s:%d", appMarkCommentPrefix, target, id)
}
// nftInsertAppMarkRule inserts, at the head of the output_apps chain, a rule
// that matches sockets of the cgroupv2 path `rel` at depth `level`, sets the
// fwmark (MARK_APP for "vpn", MARK_DIRECT otherwise), accepts, and carries
// the comment from appMarkComment(target, id).
//
// rel ends up inside an nft string literal and ultimately comes from the API
// caller. Paths containing a double quote or ASCII control characters are
// rejected, because they could terminate the literal and inject extra rule
// syntax.
func nftInsertAppMarkRule(target string, rel string, level int, id uint64) error {
	isControl := func(r rune) bool { return r < 0x20 || r == 0x7f }
	if strings.ContainsRune(rel, '"') || strings.IndexFunc(rel, isControl) >= 0 {
		return fmt.Errorf("nft insert appmark rule failed: unsupported character in cgroup path %q", rel)
	}
	mark := MARK_DIRECT
	if target == "vpn" {
		mark = MARK_APP
	}
	comment := appMarkComment(target, id)
	// nft requires a *string literal* for the cgroupv2 path; paths with '@'
	// (user@1000.service) break tokenization unless the quotes are passed as
	// part of the nft-language input.
	pathLit := fmt.Sprintf("\"%s\"", rel)
	commentLit := fmt.Sprintf("\"%s\"", comment)
	_, stderr, code, err := runCommandTimeout(
		5*time.Second,
		"nft", "insert", "rule", "inet", appMarksTable, appMarksChain,
		"socket", "cgroupv2", "level", strconv.Itoa(level), pathLit,
		"meta", "mark", "set", mark,
		"accept",
		"comment", commentLit,
	)
	if err != nil || code != 0 {
		if err == nil {
			err = fmt.Errorf("nft insert rule exited with %d", code)
		}
		return fmt.Errorf("nft insert appmark rule failed: %w (%s)", err, strings.TrimSpace(stderr))
	}
	return nil
}
// nftDeleteAppMarkRule deletes every rule in the output_apps chain tagged
// with the comment appMarkComment(target, id).
//
// The comment is matched including its surrounding quotes, as nft prints it
// in listings. A bare substring match would let the tag for id 12 also match
// (and delete) the rule for id 123.
//
// It returns an error if the chain cannot be listed or a deletion fails;
// all current callers treat it as best-effort and ignore the result.
func nftDeleteAppMarkRule(target string, id uint64) error {
	needle := "\"" + appMarkComment(target, id) + "\""
	out, stderr, code, err := runCommandTimeout(5*time.Second, "nft", "-a", "list", "chain", "inet", appMarksTable, appMarksChain)
	if err != nil || code != 0 {
		if err == nil {
			err = fmt.Errorf("exit code %d", code)
		}
		return fmt.Errorf("nft list chain %s: %w (%s)", appMarksChain, err, strings.TrimSpace(stderr))
	}
	var firstErr error
	for _, line := range strings.Split(out, "\n") {
		if !strings.Contains(line, needle) {
			continue
		}
		h := parseNftHandle(line)
		if h <= 0 {
			continue
		}
		_, delStderr, delCode, delErr := runCommandTimeout(5*time.Second, "nft", "delete", "rule", "inet", appMarksTable, appMarksChain, "handle", strconv.Itoa(h))
		if (delErr != nil || delCode != 0) && firstErr == nil {
			if delErr == nil {
				delErr = fmt.Errorf("exit code %d", delCode)
			}
			firstErr = fmt.Errorf("nft delete rule handle %d: %w (%s)", h, delErr, strings.TrimSpace(delStderr))
		}
	}
	return firstErr
}
// parseNftHandle extracts the rule handle from one line of `nft -a list`
// output (which ends with "# handle N"). It returns 0 if no numeric handle
// is found.
//
// The line is scanned from the right so that a "handle" word appearing
// earlier, e.g. inside a quoted comment, cannot shadow the real handle.
func parseNftHandle(line string) int {
	fields := strings.Fields(line)
	for i := len(fields) - 2; i >= 0; i-- {
		if fields[i] != "handle" {
			continue
		}
		if n, err := strconv.Atoi(fields[i+1]); err == nil {
			return n
		}
	}
	return 0
}
// resolveCgroupV2PathForNft validates a cgroupv2 path and returns:
//   - rel: the path relative to the cgroup root (no leading "/");
//   - level: its number of path components (for `socket cgroupv2 level`);
//   - inodeID: the inode number of the cgroup directory;
//   - abs: the absolute form "/"+rel.
//
// On failure the error is set; abs carries the trimmed input when the input
// was non-empty.
func resolveCgroupV2PathForNft(input string) (rel string, level int, inodeID uint64, abs string, err error) {
	candidate := strings.TrimSpace(input)
	if candidate == "" {
		return "", 0, 0, "", fmt.Errorf("empty cgroup")
	}
	normalized := normalizeCgroupRelOnly(candidate)
	if normalized == "" {
		return "", 0, 0, candidate, fmt.Errorf("invalid cgroup path: %s", candidate)
	}
	ino, statErr := cgroupDirInode(normalized)
	if statErr != nil {
		return "", 0, 0, candidate, statErr
	}
	depth := 1 + strings.Count(normalized, "/")
	return normalized, depth, ino, "/" + normalized, nil
}
// normalizeCgroupRelOnly turns a user-supplied cgroupv2 path into a clean
// path relative to the cgroup root, without a leading "/", e.g.
// "/user.slice/user-1000.slice/" -> "user.slice/user-1000.slice".
// It returns "" for empty input, for the root itself, and for any path that
// would escape the root via "..".
func normalizeCgroupRelOnly(raw string) string {
	// Strip *all* leading slashes: stripping only one would leave "//a" as the
	// absolute "/a" after Clean, producing a wrong nft level and abs path.
	rel := strings.TrimLeft(strings.TrimSpace(raw), "/")
	rel = filepath.Clean(rel)
	if rel == "." {
		return ""
	}
	// After Clean, ".." can only remain as leading path elements. Compare whole
	// elements so that legitimate names such as "..foo" are not rejected.
	if rel == ".." || strings.HasPrefix(rel, "../") {
		return ""
	}
	return rel
}
// cgroupDirInode returns the inode number of the directory
// cgroupRootPath/rel. It fails when the path is missing or not a directory,
// when stat data is unavailable, or when the inode number is 0. Error
// messages show the path in absolute form ("/"+rel).
func cgroupDirInode(rel string) (uint64, error) {
	trimmed := strings.TrimPrefix(rel, "/")
	display := "/" + trimmed
	info, statErr := os.Stat(filepath.Join(cgroupRootPath, trimmed))
	if statErr != nil || info == nil || !info.IsDir() {
		return 0, fmt.Errorf("cgroup not found: %s", display)
	}
	sys, ok := info.Sys().(*syscall.Stat_t)
	switch {
	case !ok || sys == nil:
		return 0, fmt.Errorf("cannot stat cgroup: %s", display)
	case sys.Ino == 0:
		return 0, fmt.Errorf("invalid cgroup inode id: %s", display)
	}
	return sys.Ino, nil
}
// pruneExpiredAppMarks loads the state under appMarksMu, removes expired
// marks (and their nft rules), and saves the state only if anything was
// removed.
func pruneExpiredAppMarks() error {
	appMarksMu.Lock()
	defer appMarksMu.Unlock()
	st := loadAppMarksState()
	if changed := pruneExpiredAppMarksLocked(&st, time.Now().UTC()); !changed {
		return nil
	}
	return saveAppMarksState(st)
}
// pruneExpiredAppMarksLocked removes from st every item whose ExpiresAt is
// unparseable or not after `now`, deleting the item's nft rule as it goes.
// It filters st.Items in place and reports whether anything was removed.
// The caller must hold appMarksMu; a nil st is a no-op.
func pruneExpiredAppMarksLocked(st *appMarksState, now time.Time) (changed bool) {
	if st == nil {
		return false
	}
	live := st.Items[:0]
	for _, item := range st.Items {
		expiry, parseErr := time.Parse(time.RFC3339, strings.TrimSpace(item.ExpiresAt))
		if parseErr == nil && expiry.After(now) {
			live = append(live, item)
			continue
		}
		_ = nftDeleteAppMarkRule(strings.ToLower(strings.TrimSpace(item.Target)), item.ID)
		changed = true
	}
	st.Items = live
	return changed
}
// upsertAppMarkItem removes any item with the same (case-insensitive,
// trimmed) target and ID as next, then appends next. It reuses items'
// backing array, so callers must use the returned slice.
func upsertAppMarkItem(items []appMarkItem, next appMarkItem) []appMarkItem {
	nextTarget := strings.ToLower(strings.TrimSpace(next.Target))
	result := items[:0]
	for i := range items {
		sameKey := items[i].ID == next.ID && strings.ToLower(strings.TrimSpace(items[i].Target)) == nextTarget
		if !sameKey {
			result = append(result, items[i])
		}
	}
	return append(result, next)
}
// loadAppMarksState reads the state file at trafficAppMarksPath. A missing,
// unreadable, or malformed file yields an empty state with Version 1; a
// stored Version of 0 is upgraded to 1.
func loadAppMarksState() appMarksState {
	fresh := appMarksState{Version: 1}
	raw, readErr := os.ReadFile(trafficAppMarksPath)
	if readErr != nil {
		return fresh
	}
	decoded := appMarksState{Version: 1}
	if json.Unmarshal(raw, &decoded) != nil {
		return fresh
	}
	if decoded.Version == 0 {
		decoded.Version = 1
	}
	return decoded
}
// saveAppMarksState stamps st with Version 1 and the current UTC time, then
// writes it as indented JSON to trafficAppMarksPath (creating stateDir if
// needed).
//
// The write is atomic: data goes to "<path>.tmp", is fsynced, and is then
// renamed over the real file. A crash therefore leaves either the old or the
// new state, never a truncated file. The temp file is removed if any step
// fails.
func saveAppMarksState(st appMarksState) error {
	st.Version = 1
	st.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
	data, err := json.MarshalIndent(st, "", " ")
	if err != nil {
		return err
	}
	if err := os.MkdirAll(stateDir, 0o755); err != nil {
		return err
	}
	tmp := trafficAppMarksPath + ".tmp"
	f, err := os.OpenFile(tmp, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644)
	if err != nil {
		return err
	}
	_, werr := f.Write(data)
	if werr == nil {
		// Flush to stable storage before the rename makes the file visible.
		werr = f.Sync()
	}
	if cerr := f.Close(); werr == nil {
		werr = cerr
	}
	if werr == nil {
		werr = os.Rename(tmp, trafficAppMarksPath)
	}
	if werr != nil {
		_ = os.Remove(tmp)
		return werr
	}
	return nil
}
// normalizeAppKey returns the trimmed appKey when it is non-empty. Otherwise
// it falls back to the first whitespace-separated word of command (typically
// the executable), or "" when command is blank.
func normalizeAppKey(appKey string, command string) string {
	if explicit := strings.TrimSpace(appKey); explicit != "" {
		return explicit
	}
	if words := strings.Fields(command); len(words) > 0 {
		return words[0]
	}
	return ""
}
// isAllDigits reports whether s, ignoring surrounding whitespace, is a
// non-empty run of ASCII decimal digits ('0'-'9' only; no sign, separators,
// or non-ASCII digits).
func isAllDigits(s string) bool {
	trimmed := strings.TrimSpace(s)
	if len(trimmed) == 0 {
		return false
	}
	notDigit := func(r rune) bool { return r < '0' || r > '9' }
	return strings.IndexFunc(trimmed, notDigit) == -1
}