package trafficappmarks import ( "fmt" "net/netip" "sort" "strconv" "strings" "time" ) type RunCommandFunc func(timeout time.Duration, name string, args ...string) (stdout string, stderr string, code int, err error) type NFTConfig struct { Table string Chain string GuardChain string LocalBypassSet string MarkApp string MarkDirect string MarkCommentPrefix string GuardCommentPrefix string GuardEnabled bool } func EnsureBase(cfg NFTConfig, run RunCommandFunc) error { if run == nil { return fmt.Errorf("run command func is nil") } _, _, _, _ = run(5*time.Second, "nft", "add", "table", "inet", cfg.Table) _, _, _, _ = run(5*time.Second, "nft", "add", "chain", "inet", cfg.Table, "output", "{", "type", "route", "hook", "output", "priority", "mangle;", "policy", "accept;", "}") _, _, _, _ = run(5*time.Second, "nft", "add", "chain", "inet", cfg.Table, cfg.GuardChain, "{", "type", "filter", "hook", "output", "priority", "filter;", "policy", "accept;", "}") _, _, _, _ = run(5*time.Second, "nft", "add", "chain", "inet", cfg.Table, cfg.Chain) _, _, _, _ = run(5*time.Second, "nft", "add", "set", "inet", cfg.Table, cfg.LocalBypassSet, "{", "type", "ipv4_addr", ";", "flags", "interval", ";", "}") out, _, _, _ := run(5*time.Second, "nft", "list", "chain", "inet", cfg.Table, "output") if !strings.Contains(out, "jump "+cfg.Chain) { _, _, _, _ = run(5*time.Second, "nft", "insert", "rule", "inet", cfg.Table, "output", "jump", cfg.Chain) } return nil } func AppMarkComment(prefix string, target string, id uint64) string { return fmt.Sprintf("%s:%s:%d", prefix, target, id) } func AppGuardComment(prefix string, target string, id uint64) string { return fmt.Sprintf("%s:%s:%d", prefix, target, id) } func UpdateLocalBypassSet(cfg NFTConfig, vpnIface string, bypassCIDRs []string, run RunCommandFunc) error { if run == nil { return fmt.Errorf("run command func is nil") } if strings.TrimSpace(cfg.Table) == "" || strings.TrimSpace(cfg.LocalBypassSet) == "" { return fmt.Errorf("invalid nft config for local bypass set") } _, _, _, _ = run(5*time.Second, "nft", "flush", "set", "inet", cfg.Table, cfg.LocalBypassSet) elems := []string{"127.0.0.0/8"} for _, dst := range bypassCIDRs { val := strings.TrimSpace(dst) if val == "" || val == "default" { continue } elems = append(elems, val) } elems = CompactIPv4IntervalElements(elems) for _, e := range elems { _, out, code, err := run( 5*time.Second, "nft", "add", "element", "inet", cfg.Table, cfg.LocalBypassSet, "{", e, "}", ) if err != nil || code != 0 { if err == nil { err = fmt.Errorf("nft add element exited with %d", code) } return fmt.Errorf("failed to update %s: %w (%s)", cfg.LocalBypassSet, err, strings.TrimSpace(out)) } } return nil } func InsertAppMarkRule(cfg NFTConfig, target string, rel string, level int, id uint64, vpnIface string, bypassCIDRs []string, run RunCommandFunc) error { if run == nil { return fmt.Errorf("run command func is nil") } target = strings.ToLower(strings.TrimSpace(target)) mark := cfg.MarkDirect if target == "vpn" { mark = cfg.MarkApp } comment := AppMarkComment(cfg.MarkCommentPrefix, target, id) pathLit := fmt.Sprintf("\"%s\"", rel) commentLit := fmt.Sprintf("\"%s\"", comment) if target == "vpn" && cfg.GuardEnabled { iface := strings.TrimSpace(vpnIface) if iface == "" { return fmt.Errorf("vpn interface required for app guard") } if err := UpdateLocalBypassSet(cfg, iface, bypassCIDRs, run); err != nil { return err } guardComment := AppGuardComment(cfg.GuardCommentPrefix, target, id) guardCommentLit := fmt.Sprintf("\"%s\"", guardComment) _, out, code, err := run( 5*time.Second, "nft", "insert", "rule", "inet", cfg.Table, cfg.GuardChain, "socket", "cgroupv2", "level", strconv.Itoa(level), pathLit, "meta", "mark", cfg.MarkApp, "oifname", "!=", iface, "ip", "daddr", "!=", "@"+cfg.LocalBypassSet, "drop", "comment", guardCommentLit, ) if err != nil || code != 0 { if err == nil { err = fmt.Errorf("nft insert guard(v4) exited with %d", code) } return fmt.Errorf("nft insert app guard(v4) failed: %w (%s)", err, strings.TrimSpace(out)) } _, out, code, err = run( 5*time.Second, "nft", "insert", "rule", "inet", cfg.Table, cfg.GuardChain, "socket", "cgroupv2", "level", strconv.Itoa(level), pathLit, "meta", "mark", cfg.MarkApp, "oifname", "!=", iface, "meta", "nfproto", "ipv6", "drop", "comment", guardCommentLit, ) if err != nil || code != 0 { if err == nil { err = fmt.Errorf("nft insert guard(v6) exited with %d", code) } return fmt.Errorf("nft insert app guard(v6) failed: %w (%s)", err, strings.TrimSpace(out)) } } _, out, code, err := run( 5*time.Second, "nft", "insert", "rule", "inet", cfg.Table, cfg.Chain, "socket", "cgroupv2", "level", strconv.Itoa(level), pathLit, "meta", "mark", "set", mark, "accept", "comment", commentLit, ) if err != nil || code != 0 { if err == nil { err = fmt.Errorf("nft insert rule exited with %d", code) } _ = DeleteAppMarkRule(cfg, target, id, run) return fmt.Errorf("nft insert appmark rule failed: %w (%s)", err, strings.TrimSpace(out)) } return nil } func DeleteAppMarkRule(cfg NFTConfig, target string, id uint64, run RunCommandFunc) error { if run == nil { return fmt.Errorf("run command func is nil") } comments := []string{ AppMarkComment(cfg.MarkCommentPrefix, target, id), AppGuardComment(cfg.GuardCommentPrefix, target, id), } chains := []string{cfg.Chain, cfg.GuardChain} for _, chain := range chains { if strings.TrimSpace(chain) == "" { continue } out, _, _, _ := run(5*time.Second, "nft", "-a", "list", "chain", "inet", cfg.Table, chain) for _, line := range strings.Split(out, "\n") { match := false for _, comment := range comments { if strings.Contains(line, comment) { match = true break } } if !match { continue } h := ParseNftHandle(line) if h <= 0 { continue } _, _, _, _ = run(5*time.Second, "nft", "delete", "rule", "inet", cfg.Table, chain, "handle", strconv.Itoa(h)) } } return nil } func HasAppMarkRule(cfg NFTConfig, target string, id uint64, run RunCommandFunc) bool { if run == nil { return false } markComment := AppMarkComment(cfg.MarkCommentPrefix, target, id) guardComment := AppGuardComment(cfg.GuardCommentPrefix, target, id) hasMark := false out, _, _, _ := run(5*time.Second, "nft", "-a", "list", "chain", "inet", cfg.Table, cfg.Chain) for _, line := range strings.Split(out, "\n") { if strings.Contains(line, markComment) { hasMark = true break } } if !hasMark { return false } if strings.EqualFold(strings.TrimSpace(target), "vpn") { if !cfg.GuardEnabled { return true } out, _, _, _ = run(5*time.Second, "nft", "-a", "list", "chain", "inet", cfg.Table, cfg.GuardChain) for _, line := range strings.Split(out, "\n") { if strings.Contains(line, guardComment) { return true } } return false } return true } func CleanupLegacyRules(cfg NFTConfig, run RunCommandFunc) error { if run == nil { return fmt.Errorf("run command func is nil") } out, _, _, _ := run(5*time.Second, "nft", "-a", "list", "chain", "inet", cfg.Table, cfg.Chain) for _, line := range strings.Split(out, "\n") { l := strings.ToLower(line) if !strings.Contains(l, "meta cgroup") { continue } if !strings.Contains(l, "svpn_cg_") { continue } h := ParseNftHandle(line) if h <= 0 { continue } _, _, _, _ = run(5*time.Second, "nft", "delete", "rule", "inet", cfg.Table, cfg.Chain, "handle", strconv.Itoa(h)) } return nil } func ClearManagedRules(cfg NFTConfig, chain string, run RunCommandFunc) { if run == nil { return } out, _, _, _ := run(5*time.Second, "nft", "-a", "list", "chain", "inet", cfg.Table, chain) for _, line := range strings.Split(out, "\n") { l := strings.ToLower(line) if !strings.Contains(l, strings.ToLower(cfg.MarkCommentPrefix)) && !strings.Contains(l, strings.ToLower(cfg.GuardCommentPrefix)) { continue } h := ParseNftHandle(line) if h <= 0 { continue } _, _, _, _ = run(5*time.Second, "nft", "delete", "rule", "inet", cfg.Table, chain, "handle", strconv.Itoa(h)) } } func ParseNftHandle(line string) int { fields := strings.Fields(line) for i := 0; i < len(fields)-1; i++ { if fields[i] == "handle" { n, _ := strconv.Atoi(fields[i+1]) return n } } return 0 } func CompactIPv4IntervalElements(raw []string) []string { pfxs := make([]netip.Prefix, 0, len(raw)) for _, v := range raw { s := strings.TrimSpace(v) if s == "" { continue } if strings.Contains(s, "/") { p, err := netip.ParsePrefix(s) if err != nil || !p.Addr().Is4() { continue } pfxs = append(pfxs, p.Masked()) continue } a, err := netip.ParseAddr(s) if err != nil || !a.Is4() { continue } pfxs = append(pfxs, netip.PrefixFrom(a, 32)) } sort.Slice(pfxs, func(i, j int) bool { ib, jb := pfxs[i].Bits(), pfxs[j].Bits() if ib != jb { return ib < jb } return pfxs[i].Addr().Less(pfxs[j].Addr()) }) out := make([]netip.Prefix, 0, len(pfxs)) for _, p := range pfxs { covered := false for _, ex := range out { if ex.Contains(p.Addr()) { covered = true break } } if covered { continue } out = append(out, p) } res := make([]string, 0, len(out)) for _, p := range out { res = append(res, p.String()) } return res }