Files
elmprodvpn/selective-vpn-api/app/traffic_appmarks.go

369 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package app
import (
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"syscall"
"time"
)
// ---------------------------------------------------------------------
// traffic app marks (per-app routing via cgroup -> fwmark)
// ---------------------------------------------------------------------
//
// This module manages the runtime cgroup-id sets used by the nftables rules in
// routes_update.go (output_apps chain). GUI/clients can add/remove cgroup IDs
// to force their traffic through the VPN (MARK_APP) or force it direct
// (MARK_DIRECT).
const (
// nftSetCgroupVPN is the set in table inet agvpn whose members get MARK_APP.
nftSetCgroupVPN = "svpn_cg_vpn"
// nftSetCgroupDirect is the set in table inet agvpn whose members get MARK_DIRECT.
nftSetCgroupDirect = "svpn_cg_direct"
// cgroupRootFS is the directory under which client-supplied cgroup paths are resolved.
cgroupRootFS = "/sys/fs/cgroup"
)
// handleTrafficAppMarks serves the traffic app-marks API.
//
// GET reports how many elements are currently in the VPN and direct cgroup
// sets. Errors reading the sets are ignored (best-effort status).
//
// POST decodes a TrafficAppMarksRequest (body capped at 1 MiB):
//   - Op: TrafficAppMarksAdd, TrafficAppMarksDel or TrafficAppMarksClear,
//     compared after trimming and lowercasing.
//   - Target: "vpn" or "direct", selecting nftSetCgroupVPN / nftSetCgroupDirect.
//   - Cgroup: required for add/del; a cgroup path or numeric id, resolved by
//     resolveCgroupIDForNft. Ignored by clear.
//   - TimeoutSec: element lifetime for add; 0 selects a 24h default.
//
// Malformed requests are answered with 400 via http.Error. Operational
// failures (nft, routing, cgroup lookup) are answered with 200 and a
// TrafficAppMarksResponse whose OK is false and Message explains the cause.
func handleTrafficAppMarks(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodGet:
		// Best-effort status: read errors are ignored.
		vpnElems, _ := readNftSetElements(nftSetCgroupVPN)
		directElems, _ := readNftSetElements(nftSetCgroupDirect)
		writeJSON(w, http.StatusOK, TrafficAppMarksStatusResponse{
			VPNCount:    len(vpnElems),
			DirectCount: len(directElems),
			Message:     "ok",
		})

	case http.MethodPost:
		var body TrafficAppMarksRequest
		if r.Body != nil {
			defer r.Body.Close()
			// An empty body (io.EOF) decodes as the zero request and is then
			// rejected by the field checks below.
			if err := json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&body); err != nil && err != io.EOF {
				http.Error(w, "bad json", http.StatusBadRequest)
				return
			}
		}
		op := TrafficAppMarksOp(strings.ToLower(strings.TrimSpace(string(body.Op))))
		target := strings.ToLower(strings.TrimSpace(body.Target))
		cgroup := strings.TrimSpace(body.Cgroup)
		timeoutSec := body.TimeoutSec

		// Validate the whole request before any side effect. In particular an
		// unknown op must be rejected here, not after the nft objects have been
		// rebuilt and VPN routes touched.
		switch op {
		case "":
			http.Error(w, "missing op", http.StatusBadRequest)
			return
		case TrafficAppMarksAdd, TrafficAppMarksDel, TrafficAppMarksClear:
			// supported
		default:
			http.Error(w, "unknown op", http.StatusBadRequest)
			return
		}
		switch target {
		case "":
			http.Error(w, "missing target", http.StatusBadRequest)
			return
		case "vpn", "direct":
			// supported
		default:
			http.Error(w, "target must be vpn|direct", http.StatusBadRequest)
			return
		}
		perCgroup := op == TrafficAppMarksAdd || op == TrafficAppMarksDel
		if perCgroup && cgroup == "" {
			http.Error(w, "missing cgroup", http.StatusBadRequest)
			return
		}
		if timeoutSec < 0 {
			http.Error(w, "timeout_sec must be >= 0", http.StatusBadRequest)
			return
		}

		// fail reports an operational failure (HTTP 200, OK=false).
		fail := func(cg string, id uint64, msg string) {
			writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
				OK:       false,
				Op:       string(op),
				Target:   target,
				Cgroup:   cg,
				CgroupID: id,
				Message:  msg,
			})
		}

		// Resolve the cgroup (a read-only lookup) before changing anything.
		// clear operates on the whole set and does not need a cgroup id.
		var cgID uint64
		if perCgroup {
			id, normalized, err := resolveCgroupIDForNft(cgroup)
			if err != nil {
				fail(body.Cgroup, 0, err.Error())
				return
			}
			cgID, cgroup = id, normalized
		}

		// Ensure nft objects exist even if routes-update hasn't run yet.
		if err := ensureAppMarksNft(); err != nil {
			fail(cgroup, cgID, "nft init failed: "+err.Error())
			return
		}

		if op == TrafficAppMarksAdd && target == "vpn" {
			// Ensure the VPN policy table has a base route. This matters when the
			// current traffic-mode is direct.
			traffic := loadTrafficModeState()
			iface, _ := resolveTrafficIface(traffic.PreferredIface)
			if strings.TrimSpace(iface) == "" {
				fail(cgroup, cgID, "vpn interface not found (set preferred iface or bring VPN up)")
				return
			}
			if err := ensureTrafficRouteBase(iface, traffic.AutoLocalBypass); err != nil {
				fail(cgroup, cgID, "ensure vpn route base failed: "+err.Error())
				return
			}
		}

		setName := nftSetCgroupDirect
		if target == "vpn" {
			setName = nftSetCgroupVPN
		}

		// op is guaranteed to be one of these three by the validation above.
		switch op {
		case TrafficAppMarksAdd:
			ttl := timeoutSec
			if ttl == 0 {
				ttl = 24 * 60 * 60 // 24h default if client didn't specify
			}
			if err := nftAddCgroupElement(setName, cgID, ttl); err != nil {
				fail(cgroup, cgID, err.Error())
				return
			}
			appendTraceLine("traffic", fmt.Sprintf("appmarks add target=%s cgroup=%s id=%d ttl=%ds", target, cgroup, cgID, ttl))
			writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
				OK:         true,
				Op:         string(op),
				Target:     target,
				Cgroup:     cgroup,
				CgroupID:   cgID,
				TimeoutSec: ttl,
				Message:    "added",
			})
		case TrafficAppMarksDel:
			if err := nftDelCgroupElement(setName, cgID); err != nil {
				fail(cgroup, cgID, err.Error())
				return
			}
			appendTraceLine("traffic", fmt.Sprintf("appmarks del target=%s cgroup=%s id=%d", target, cgroup, cgID))
			writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
				OK:       true,
				Op:       string(op),
				Target:   target,
				Cgroup:   cgroup,
				CgroupID: cgID,
				Message:  "deleted",
			})
		case TrafficAppMarksClear:
			if err := nftFlushSet(setName); err != nil {
				fail("", 0, err.Error())
				return
			}
			appendTraceLine("traffic", fmt.Sprintf("appmarks clear target=%s", target))
			writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
				OK:      true,
				Op:      string(op),
				Target:  target,
				Message: "cleared",
			})
		}

	default:
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
	}
}
// ensureAppMarksNft makes sure the nftables objects used for per-app marks
// exist in table inet agvpn, even if routes-update has not run yet:
//   - sets nftSetCgroupVPN / nftSetCgroupDirect (typeof meta cgroup, flags timeout);
//   - base chain "output" (type route, hook output, priority mangle, policy accept);
//   - chain "output_apps", rebuilt to hold exactly two rules that set
//     MARK_DIRECT / MARK_APP for packets whose meta cgroup is in the
//     respective set;
//   - a "jump output_apps" rule in "output", added only if not already listed.
//
// Creating the table, sets and chains is best-effort: errors there are ignored
// (e.g. objects already created by routes-update). The flush and rule
// installation are checked, because they cannot succeed unless those objects
// are usable. A non-nil error therefore means per-app marking is not in effect.
//
// NOTE(review): nft's `meta cgroup` is documented as the 32-bit net_cls
// classid, while resolveCgroupIDForNft produces cgroup v2 ids (inode numbers).
// Confirm these rules actually match on cgroup v2 hosts (nft's v2 matcher is
// `socket cgroupv2 level N`).
func ensureAppMarksNft() error {
	const timeout = 5 * time.Second

	// nftErr converts the outcome of an nft invocation into an error that
	// includes nft's stderr; it returns nil when the command succeeded.
	nftErr := func(what, stderr string, nonZeroExit bool, err error) error {
		if err == nil && !nonZeroExit {
			return nil
		}
		if err == nil {
			err = fmt.Errorf("non-zero exit status")
		}
		if detail := strings.TrimSpace(stderr); detail != "" {
			return fmt.Errorf("nft %s: %s: %w", what, detail, err)
		}
		return fmt.Errorf("nft %s: %w", what, err)
	}

	// Best-effort object creation; failures surface through the checked steps below.
	_, _, _, _ = runCommandTimeout(timeout, "nft", "add", "table", "inet", "agvpn")
	_, _, _, _ = runCommandTimeout(timeout, "nft", "add", "set", "inet", "agvpn", nftSetCgroupVPN, "{", "typeof", "meta", "cgroup", ";", "flags", "timeout", ";", "}")
	_, _, _, _ = runCommandTimeout(timeout, "nft", "add", "set", "inet", "agvpn", nftSetCgroupDirect, "{", "typeof", "meta", "cgroup", ";", "flags", "timeout", ";", "}")
	_, _, _, _ = runCommandTimeout(timeout, "nft", "add", "chain", "inet", "agvpn", "output", "{", "type", "route", "hook", "output", "priority", "mangle;", "policy", "accept;", "}")
	_, _, _, _ = runCommandTimeout(timeout, "nft", "add", "chain", "inet", "agvpn", "output_apps")

	// Keep output_apps deterministic (no duplicates). Safe because this chain is
	// dedicated; it is empty only until the two rules below are re-added.
	_, stderr, code, err := runCommandTimeout(timeout, "nft", "flush", "chain", "inet", "agvpn", "output_apps")
	if e := nftErr("flush chain output_apps", stderr, code != 0, err); e != nil {
		return e
	}
	_, stderr, code, err = runCommandTimeout(timeout, "nft", "add", "rule", "inet", "agvpn", "output_apps", "meta", "cgroup", "@"+nftSetCgroupDirect, "meta", "mark", "set", MARK_DIRECT, "accept")
	if e := nftErr("add direct rule to output_apps", stderr, code != 0, err); e != nil {
		return e
	}
	_, stderr, code, err = runCommandTimeout(timeout, "nft", "add", "rule", "inet", "agvpn", "output_apps", "meta", "cgroup", "@"+nftSetCgroupVPN, "meta", "mark", "set", MARK_APP, "accept")
	if e := nftErr("add vpn rule to output_apps", stderr, code != 0, err); e != nil {
		return e
	}

	// Ensure output chain has a jump into output_apps (routes-update may also manage this).
	out, _, _, _ := runCommandTimeout(timeout, "nft", "list", "chain", "inet", "agvpn", "output")
	if !strings.Contains(out, "jump output_apps") {
		_, stderr, code, err = runCommandTimeout(timeout, "nft", "add", "rule", "inet", "agvpn", "output", "jump", "output_apps")
		if e := nftErr("add jump to output_apps", stderr, code != 0, err); e != nil {
			return e
		}
	}
	return nil
}
// resolveCgroupIDForNft converts client input into a cgroup id for the
// app-mark nft sets.
//
// input (surrounding whitespace ignored) is either:
//   - a decimal cgroup id, which must be non-zero and is returned unchanged; or
//   - a cgroup path relative to cgroupRootFS (leading slashes optional), which
//     must name an existing directory and may not escape cgroupRootFS via "..".
//
// It returns the id, the normalized input (the digits for numeric input,
// "/"+cleaned relative path for path input) and an error. Error messages are
// passed back to API clients by the HTTP handler.
//
// For path input the id is the directory's inode number, intended to be the
// cgroup v2 id (what bpf_get_current_cgroup_id reports).
// NOTE(review): nft's `meta cgroup` is documented as the 32-bit net_cls
// classid, not the cgroup v2 id. Verify that set elements built from this id
// actually match traffic on the target system.
func resolveCgroupIDForNft(input string) (uint64, string, error) {
	raw := strings.TrimSpace(input)
	if raw == "" {
		return 0, "", fmt.Errorf("empty cgroup")
	}

	// Allow numeric cgroup id input.
	if isAllDigits(raw) {
		id, err := strconv.ParseUint(raw, 10, 64)
		if err != nil || id == 0 {
			return 0, raw, fmt.Errorf("invalid cgroup id: %s", raw)
		}
		return id, raw, nil
	}

	// Strip every leading "/" (a single TrimPrefix left "//x" rooted), then
	// clean. After Clean, a relative path can contain ".." only as leading
	// elements, so this prefix check is exact: real names like "..foo" pass,
	// anything escaping the root is rejected.
	rel := filepath.Clean(strings.TrimLeft(raw, "/"))
	if rel == "." {
		return 0, raw, fmt.Errorf("invalid cgroup path: %s", raw)
	}
	if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return 0, raw, fmt.Errorf("invalid cgroup path (traversal): %s", raw)
	}

	full := filepath.Join(cgroupRootFS, rel)
	fi, err := os.Stat(full)
	if err != nil || !fi.IsDir() {
		return 0, raw, fmt.Errorf("cgroup not found: %s", raw)
	}
	st, ok := fi.Sys().(*syscall.Stat_t)
	if !ok || st == nil {
		return 0, raw, fmt.Errorf("cannot stat cgroup: %s", raw)
	}
	if st.Ino == 0 {
		return 0, raw, fmt.Errorf("invalid cgroup inode id: %s", raw)
	}
	return st.Ino, "/" + rel, nil
}
// nftAddCgroupElement adds cgroupID to the set setName in table inet agvpn
// with a per-element timeout of timeoutSec seconds.
//
// It returns an error for an empty setName, a zero cgroupID or a negative
// timeoutSec. If nft reports that the element already exists, the call
// succeeds without further action. The existing element is not re-added, so
// its timeout is not refreshed by this function. Other nft failures are
// returned with nft's stderr included so callers can show the real cause.
func nftAddCgroupElement(setName string, cgroupID uint64, timeoutSec int) error {
	if strings.TrimSpace(setName) == "" {
		return fmt.Errorf("empty setName")
	}
	if cgroupID == 0 {
		return fmt.Errorf("invalid cgroup id")
	}
	if timeoutSec < 0 {
		return fmt.Errorf("invalid timeout_sec")
	}
	// The set is declared with "flags timeout", so elements may carry a timeout.
	_, stderr, code, err := runCommandTimeout(
		5*time.Second,
		"nft", "add", "element", "inet", "agvpn", setName,
		"{", strconv.FormatUint(cgroupID, 10), "timeout", strconv.Itoa(timeoutSec)+"s", "}",
	)
	if err == nil && code == 0 {
		return nil
	}
	// "exists" also covers nft's "File exists" wording.
	if strings.Contains(strings.ToLower(stderr), "exists") {
		return nil
	}
	if err == nil {
		err = fmt.Errorf("nft add element exited with %d", code)
	}
	if detail := strings.TrimSpace(stderr); detail != "" {
		return fmt.Errorf("nft add element failed: %s: %w", detail, err)
	}
	return fmt.Errorf("nft add element failed: %w", err)
}
// nftDelCgroupElement removes cgroupID from the set setName in table inet agvpn.
//
// It returns an error for an empty setName or a zero cgroupID. nft failures
// whose message reports a missing object ("no such file", "not found",
// "does not exist") are treated as success, so deleting an absent element is
// a no-op. Other failures are returned with nft's stderr included.
func nftDelCgroupElement(setName string, cgroupID uint64) error {
	if strings.TrimSpace(setName) == "" {
		return fmt.Errorf("empty setName")
	}
	if cgroupID == 0 {
		return fmt.Errorf("invalid cgroup id")
	}
	_, stderr, code, err := runCommandTimeout(
		5*time.Second,
		"nft", "delete", "element", "inet", "agvpn", setName,
		"{", strconv.FormatUint(cgroupID, 10), "}",
	)
	if err == nil && code == 0 {
		return nil
	}
	msg := strings.ToLower(stderr)
	if strings.Contains(msg, "no such file") ||
		strings.Contains(msg, "not found") ||
		strings.Contains(msg, "does not exist") {
		return nil
	}
	if err == nil {
		err = fmt.Errorf("nft delete element exited with %d", code)
	}
	if detail := strings.TrimSpace(stderr); detail != "" {
		return fmt.Errorf("nft delete element failed: %s: %w", detail, err)
	}
	return fmt.Errorf("nft delete element failed: %w", err)
}
// nftFlushSet removes all elements from the set setName in table inet agvpn.
//
// It returns an error for an empty setName. nft failures whose message
// reports a missing object ("no such file", "not found", "does not exist")
// are treated as success. Other failures are returned with nft's stderr
// included.
func nftFlushSet(setName string) error {
	if strings.TrimSpace(setName) == "" {
		return fmt.Errorf("empty setName")
	}
	_, stderr, code, err := runCommandTimeout(5*time.Second, "nft", "flush", "set", "inet", "agvpn", setName)
	if err == nil && code == 0 {
		return nil
	}
	msg := strings.ToLower(stderr)
	if strings.Contains(msg, "no such file") ||
		strings.Contains(msg, "not found") ||
		strings.Contains(msg, "does not exist") {
		return nil
	}
	if err == nil {
		err = fmt.Errorf("nft flush set exited with %d", code)
	}
	if detail := strings.TrimSpace(stderr); detail != "" {
		return fmt.Errorf("nft flush set failed: %s: %w", detail, err)
	}
	return fmt.Errorf("nft flush set failed: %w", err)
}
// isAllDigits reports whether s, after trimming surrounding whitespace, is a
// non-empty string made only of ASCII decimal digits '0'..'9'.
func isAllDigits(s string) bool {
	trimmed := strings.TrimSpace(s)
	notDigit := func(r rune) bool { return r < '0' || r > '9' }
	return trimmed != "" && strings.IndexFunc(trimmed, notDigit) < 0
}