traffic: add per-app runtime app routing via cgroup marks
This commit is contained in:
368
selective-vpn-api/app/traffic_appmarks.go
Normal file
368
selective-vpn-api/app/traffic_appmarks.go
Normal file
@@ -0,0 +1,368 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------
|
||||
// traffic app marks (per-app routing via cgroup -> fwmark)
|
||||
// ---------------------------------------------------------------------
|
||||
//
|
||||
// EN: This module manages runtime cgroup-id sets used by nftables rules in
|
||||
// EN: routes_update.go (output_apps chain). GUI/clients can add/remove cgroup IDs
|
||||
// EN: to force traffic through VPN (MARK_APP) or force direct (MARK_DIRECT).
|
||||
// RU: Этот модуль управляет runtime cgroup-id сетами для nftables правил из
|
||||
// RU: routes_update.go (цепочка output_apps). GUI/клиенты могут добавлять/удалять
|
||||
// RU: cgroup IDs, чтобы форсировать трафик через VPN (MARK_APP) или в direct (MARK_DIRECT).
|
||||
|
||||
const (
|
||||
nftSetCgroupVPN = "svpn_cg_vpn"
|
||||
nftSetCgroupDirect = "svpn_cg_direct"
|
||||
cgroupRootFS = "/sys/fs/cgroup"
|
||||
)
|
||||
|
||||
func handleTrafficAppMarks(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
vpnElems, _ := readNftSetElements(nftSetCgroupVPN)
|
||||
directElems, _ := readNftSetElements(nftSetCgroupDirect)
|
||||
writeJSON(w, http.StatusOK, TrafficAppMarksStatusResponse{
|
||||
VPNCount: len(vpnElems),
|
||||
DirectCount: len(directElems),
|
||||
Message: "ok",
|
||||
})
|
||||
case http.MethodPost:
|
||||
var body TrafficAppMarksRequest
|
||||
if r.Body != nil {
|
||||
defer r.Body.Close()
|
||||
if err := json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&body); err != nil && err != io.EOF {
|
||||
http.Error(w, "bad json", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
op := TrafficAppMarksOp(strings.ToLower(strings.TrimSpace(string(body.Op))))
|
||||
target := strings.ToLower(strings.TrimSpace(body.Target))
|
||||
cgroup := strings.TrimSpace(body.Cgroup)
|
||||
timeoutSec := body.TimeoutSec
|
||||
|
||||
if op == "" {
|
||||
http.Error(w, "missing op", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if target == "" {
|
||||
http.Error(w, "missing target", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if target != "vpn" && target != "direct" {
|
||||
http.Error(w, "target must be vpn|direct", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if (op == TrafficAppMarksAdd || op == TrafficAppMarksDel) && cgroup == "" {
|
||||
http.Error(w, "missing cgroup", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if timeoutSec < 0 {
|
||||
http.Error(w, "timeout_sec must be >= 0", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure nft objects exist even if routes-update hasn't run yet.
|
||||
if err := ensureAppMarksNft(); err != nil {
|
||||
writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
|
||||
OK: false,
|
||||
Op: string(op),
|
||||
Target: target,
|
||||
Cgroup: cgroup,
|
||||
Message: "nft init failed: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
cgID uint64
|
||||
err error
|
||||
)
|
||||
if cgroup != "" {
|
||||
cgID, cgroup, err = resolveCgroupIDForNft(cgroup)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
|
||||
OK: false,
|
||||
Op: string(op),
|
||||
Target: target,
|
||||
Cgroup: body.Cgroup,
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if op == TrafficAppMarksAdd && target == "vpn" {
|
||||
// Ensure VPN policy table has a base route. This matters when current traffic-mode=direct.
|
||||
traffic := loadTrafficModeState()
|
||||
iface, _ := resolveTrafficIface(traffic.PreferredIface)
|
||||
if strings.TrimSpace(iface) == "" {
|
||||
writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
|
||||
OK: false,
|
||||
Op: string(op),
|
||||
Target: target,
|
||||
Cgroup: cgroup,
|
||||
CgroupID: cgID,
|
||||
Message: "vpn interface not found (set preferred iface or bring VPN up)",
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := ensureTrafficRouteBase(iface, traffic.AutoLocalBypass); err != nil {
|
||||
writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
|
||||
OK: false,
|
||||
Op: string(op),
|
||||
Target: target,
|
||||
Cgroup: cgroup,
|
||||
CgroupID: cgID,
|
||||
Message: "ensure vpn route base failed: " + err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
setName := nftSetCgroupDirect
|
||||
if target == "vpn" {
|
||||
setName = nftSetCgroupVPN
|
||||
}
|
||||
|
||||
switch op {
|
||||
case TrafficAppMarksAdd:
|
||||
ttl := timeoutSec
|
||||
if ttl == 0 {
|
||||
ttl = 24 * 60 * 60 // 24h default if client didn't specify
|
||||
}
|
||||
if err := nftAddCgroupElement(setName, cgID, ttl); err != nil {
|
||||
writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
|
||||
OK: false,
|
||||
Op: string(op),
|
||||
Target: target,
|
||||
Cgroup: cgroup,
|
||||
CgroupID: cgID,
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
appendTraceLine("traffic", fmt.Sprintf("appmarks add target=%s cgroup=%s id=%d ttl=%ds", target, cgroup, cgID, ttl))
|
||||
writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
|
||||
OK: true,
|
||||
Op: string(op),
|
||||
Target: target,
|
||||
Cgroup: cgroup,
|
||||
CgroupID: cgID,
|
||||
TimeoutSec: ttl,
|
||||
Message: "added",
|
||||
})
|
||||
case TrafficAppMarksDel:
|
||||
if err := nftDelCgroupElement(setName, cgID); err != nil {
|
||||
writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
|
||||
OK: false,
|
||||
Op: string(op),
|
||||
Target: target,
|
||||
Cgroup: cgroup,
|
||||
CgroupID: cgID,
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
appendTraceLine("traffic", fmt.Sprintf("appmarks del target=%s cgroup=%s id=%d", target, cgroup, cgID))
|
||||
writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
|
||||
OK: true,
|
||||
Op: string(op),
|
||||
Target: target,
|
||||
Cgroup: cgroup,
|
||||
CgroupID: cgID,
|
||||
Message: "deleted",
|
||||
})
|
||||
case TrafficAppMarksClear:
|
||||
if err := nftFlushSet(setName); err != nil {
|
||||
writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
|
||||
OK: false,
|
||||
Op: string(op),
|
||||
Target: target,
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
appendTraceLine("traffic", fmt.Sprintf("appmarks clear target=%s", target))
|
||||
writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
|
||||
OK: true,
|
||||
Op: string(op),
|
||||
Target: target,
|
||||
Message: "cleared",
|
||||
})
|
||||
default:
|
||||
http.Error(w, "unknown op", http.StatusBadRequest)
|
||||
}
|
||||
default:
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
func ensureAppMarksNft() error {
|
||||
// Best-effort "ensure": ignore "exists" errors and proceed.
|
||||
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "table", "inet", "agvpn")
|
||||
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "set", "inet", "agvpn", nftSetCgroupVPN, "{", "typeof", "meta", "cgroup", ";", "flags", "timeout", ";", "}")
|
||||
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "set", "inet", "agvpn", nftSetCgroupDirect, "{", "typeof", "meta", "cgroup", ";", "flags", "timeout", ";", "}")
|
||||
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "chain", "inet", "agvpn", "output", "{", "type", "route", "hook", "output", "priority", "mangle;", "policy", "accept;", "}")
|
||||
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "chain", "inet", "agvpn", "output_apps")
|
||||
|
||||
// Keep output_apps deterministic (no duplicates). Safe because this chain is dedicated.
|
||||
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "flush", "chain", "inet", "agvpn", "output_apps")
|
||||
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "rule", "inet", "agvpn", "output_apps", "meta", "cgroup", "@"+nftSetCgroupDirect, "meta", "mark", "set", MARK_DIRECT, "accept")
|
||||
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "rule", "inet", "agvpn", "output_apps", "meta", "cgroup", "@"+nftSetCgroupVPN, "meta", "mark", "set", MARK_APP, "accept")
|
||||
|
||||
// Ensure output chain has a jump into output_apps (routes-update may also manage this).
|
||||
out, _, _, _ := runCommandTimeout(5*time.Second, "nft", "list", "chain", "inet", "agvpn", "output")
|
||||
if !strings.Contains(out, "jump output_apps") {
|
||||
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "rule", "inet", "agvpn", "output", "jump", "output_apps")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func resolveCgroupIDForNft(input string) (uint64, string, error) {
|
||||
raw := strings.TrimSpace(input)
|
||||
if raw == "" {
|
||||
return 0, "", fmt.Errorf("empty cgroup")
|
||||
}
|
||||
|
||||
// Allow numeric cgroup id input.
|
||||
if isAllDigits(raw) {
|
||||
id, err := strconv.ParseUint(raw, 10, 64)
|
||||
if err != nil || id == 0 {
|
||||
return 0, raw, fmt.Errorf("invalid cgroup id: %s", raw)
|
||||
}
|
||||
return id, raw, nil
|
||||
}
|
||||
|
||||
// Normalize into a safe relative path under /sys/fs/cgroup.
|
||||
rel := strings.TrimPrefix(raw, "/")
|
||||
rel = filepath.Clean(rel)
|
||||
if rel == "." || rel == "" {
|
||||
return 0, raw, fmt.Errorf("invalid cgroup path: %s", raw)
|
||||
}
|
||||
if strings.HasPrefix(rel, "..") || strings.Contains(rel, "../") {
|
||||
return 0, raw, fmt.Errorf("invalid cgroup path (traversal): %s", raw)
|
||||
}
|
||||
|
||||
full := filepath.Join(cgroupRootFS, rel)
|
||||
fi, err := os.Stat(full)
|
||||
if err != nil || fi == nil || !fi.IsDir() {
|
||||
return 0, raw, fmt.Errorf("cgroup not found: %s", raw)
|
||||
}
|
||||
st, ok := fi.Sys().(*syscall.Stat_t)
|
||||
if !ok || st == nil {
|
||||
return 0, raw, fmt.Errorf("cannot stat cgroup: %s", raw)
|
||||
}
|
||||
if st.Ino == 0 {
|
||||
return 0, raw, fmt.Errorf("invalid cgroup inode id: %s", raw)
|
||||
}
|
||||
// EN: For cgroup v2, the directory inode is used as cgroup id (matches meta cgroup / bpf_get_current_cgroup_id).
|
||||
// RU: Для cgroup v2 inode директории используется как cgroup id (соответствует meta cgroup / bфункции bpf_get_current_cgroup_id).
|
||||
return st.Ino, "/" + rel, nil
|
||||
}
|
||||
|
||||
func nftAddCgroupElement(setName string, cgroupID uint64, timeoutSec int) error {
|
||||
if strings.TrimSpace(setName) == "" {
|
||||
return fmt.Errorf("empty setName")
|
||||
}
|
||||
if cgroupID == 0 {
|
||||
return fmt.Errorf("invalid cgroup id")
|
||||
}
|
||||
if timeoutSec < 0 {
|
||||
return fmt.Errorf("invalid timeout_sec")
|
||||
}
|
||||
|
||||
// NOTE: set has flags timeout; element can include timeout.
|
||||
ttl := fmt.Sprintf("%ds", timeoutSec)
|
||||
_, out, code, err := runCommandTimeout(
|
||||
5*time.Second,
|
||||
"nft", "add", "element", "inet", "agvpn", setName,
|
||||
"{", fmt.Sprintf("%d", cgroupID), "timeout", ttl, "}",
|
||||
)
|
||||
if err != nil || code != 0 {
|
||||
msg := strings.ToLower(out)
|
||||
if strings.Contains(msg, "file exists") || strings.Contains(msg, "exists") {
|
||||
return nil
|
||||
}
|
||||
if err == nil {
|
||||
err = fmt.Errorf("nft add element exited with %d", code)
|
||||
}
|
||||
return fmt.Errorf("nft add element failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func nftDelCgroupElement(setName string, cgroupID uint64) error {
|
||||
if strings.TrimSpace(setName) == "" {
|
||||
return fmt.Errorf("empty setName")
|
||||
}
|
||||
if cgroupID == 0 {
|
||||
return fmt.Errorf("invalid cgroup id")
|
||||
}
|
||||
_, out, code, err := runCommandTimeout(
|
||||
5*time.Second,
|
||||
"nft", "delete", "element", "inet", "agvpn", setName,
|
||||
"{", fmt.Sprintf("%d", cgroupID), "}",
|
||||
)
|
||||
if err != nil || code != 0 {
|
||||
msg := strings.ToLower(out)
|
||||
if strings.Contains(msg, "no such file") ||
|
||||
strings.Contains(msg, "not found") ||
|
||||
strings.Contains(msg, "does not exist") {
|
||||
return nil
|
||||
}
|
||||
if err == nil {
|
||||
err = fmt.Errorf("nft delete element exited with %d", code)
|
||||
}
|
||||
return fmt.Errorf("nft delete element failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func nftFlushSet(setName string) error {
|
||||
if strings.TrimSpace(setName) == "" {
|
||||
return fmt.Errorf("empty setName")
|
||||
}
|
||||
_, out, code, err := runCommandTimeout(5*time.Second, "nft", "flush", "set", "inet", "agvpn", setName)
|
||||
if err != nil || code != 0 {
|
||||
msg := strings.ToLower(out)
|
||||
if strings.Contains(msg, "no such file") ||
|
||||
strings.Contains(msg, "not found") ||
|
||||
strings.Contains(msg, "does not exist") {
|
||||
return nil
|
||||
}
|
||||
if err == nil {
|
||||
err = fmt.Errorf("nft flush set exited with %d", code)
|
||||
}
|
||||
return fmt.Errorf("nft flush set failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isAllDigits(s string) bool {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < len(s); i++ {
|
||||
ch := s[i]
|
||||
if ch < '0' || ch > '9' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
Reference in New Issue
Block a user