package app import ( "context" "fmt" "net/netip" "os" "sort" "strconv" "strings" "time" ) const ( transportPolicyKernelEnvEnable = "SVPN_TRANSPORT_POLICY_KERNEL_APPLY" transportPolicyKernelEnvIPRules = "SVPN_TRANSPORT_POLICY_KERNEL_IPRULES" ) var ( transportPolicyKernelRunCommand = runCommandTimeout transportPolicyKernelUpdateSet = func(ctx context.Context, setName string, ips []string) error { return nftUpdateSetIPsSmart(ctx, setName, ips, nil) } ) func applyTransportPolicyKernelStage(current, staged transportPolicyRuntimeState) error { if !transportPolicyKernelApplyEnabled() { return nil } if err := applyTransportPolicyKernelNftSets(current, staged); err != nil { return err } if transportPolicyKernelIPRulesEnabled() { if err := applyTransportPolicyKernelIPRules(staged); err != nil { return err } } if transportPolicyKernelConntrackStickyEnabled() { if err := refreshTransportPolicyOwnerLocksFromConntrack(staged); err != nil { appendTraceLineRateLimited( "transport", fmt.Sprintf("policy conntrack sticky refresh warning: %v", err), 5*time.Second, ) } } return nil } func transportPolicyKernelApplyEnabled() bool { switch strings.ToLower(strings.TrimSpace(os.Getenv(transportPolicyKernelEnvEnable))) { case "1", "true", "yes", "on": return true default: return false } } func transportPolicyKernelIPRulesEnabled() bool { switch strings.ToLower(strings.TrimSpace(os.Getenv(transportPolicyKernelEnvIPRules))) { case "1", "true", "yes", "on": return true default: return false } } func applyTransportPolicyKernelNftSets(current, staged transportPolicyRuntimeState) error { desired := buildTransportPolicyCIDRSetElements(staged.Interfaces) currentSets := collectTransportPolicyCIDRSetNames(current.Interfaces) desiredSets := collectMapKeys(desired) _, _, _, _ = transportPolicyKernelRunCommand(5*time.Second, "nft", "add", "table", "inet", "agvpn") for _, setName := range desiredSets { if strings.TrimSpace(setName) == "" { continue } _, _, _, _ = transportPolicyKernelRunCommand( 5*time.Second, "nft", "add", "set", "inet", "agvpn", setName, "{", "type", "ipv4_addr", ";", "flags", "interval", ";", "}", ) ips := desired[setName] if len(ips) == 0 { _, stderr, _, err := transportPolicyKernelRunCommand(10*time.Second, "nft", "flush", "set", "inet", "agvpn", setName) if err != nil && !strings.Contains(strings.ToLower(stderr), "no such file") { return fmt.Errorf("nft flush %s failed: %v (%s)", setName, err, strings.TrimSpace(stderr)) } continue } ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) err := transportPolicyKernelUpdateSet(ctx, setName, ips) cancel() if err != nil { return fmt.Errorf("nft update set %s failed: %w", setName, err) } } for _, stale := range diffStringSet(currentSets, desiredSets) { if strings.TrimSpace(stale) == "" { continue } _, _, _, _ = transportPolicyKernelRunCommand(5*time.Second, "nft", "flush", "set", "inet", "agvpn", stale) _, _, _, _ = transportPolicyKernelRunCommand(5*time.Second, "nft", "delete", "set", "inet", "agvpn", stale) } appendTraceLineRateLimited( "transport", fmt.Sprintf("policy kernel nft stage: sets=%d stale=%d", len(desiredSets), len(diffStringSet(currentSets, desiredSets))), 5*time.Second, ) return nil } func applyTransportPolicyKernelIPRules(staged transportPolicyRuntimeState) error { type tuple struct { pref int mark string table string } seen := map[string]struct{}{} rules := make([]tuple, 0, 16) for _, iface := range staged.Interfaces { table := strings.TrimSpace(iface.RoutingTable) if normalizeTransportIfaceID(iface.IfaceID) == transportDefaultIfaceID { continue } if table == "" { continue } for _, r := range iface.Rules { mark := strings.ToLower(strings.TrimSpace(r.MarkHex)) if _, ok := parseTransportMarkHex(mark); !ok { continue } if _, ok := parseTransportPref(r.PriorityBase); !ok { continue } key := strconv.Itoa(r.PriorityBase) + "|" + mark + "|" + table if _, ok := seen[key]; ok { continue } seen[key] = struct{}{} rules = append(rules, tuple{pref: r.PriorityBase, mark: mark, table: table}) } } sort.Slice(rules, func(i, j int) bool { return rules[i].pref < rules[j].pref }) for _, rule := range rules { _, _, _, _ = transportPolicyKernelRunCommand(4*time.Second, "ip", "-4", "rule", "del", "pref", strconv.Itoa(rule.pref)) stdout, stderr, code, err := transportPolicyKernelRunCommand( 5*time.Second, "ip", "-4", "rule", "add", "pref", strconv.Itoa(rule.pref), "fwmark", rule.mark, "lookup", rule.table, ) if err != nil || code != 0 { return fmt.Errorf("ip rule add pref=%d mark=%s table=%s failed: %v (stdout=%s stderr=%s code=%d)", rule.pref, rule.mark, rule.table, err, strings.TrimSpace(stdout), strings.TrimSpace(stderr), code) } } appendTraceLineRateLimited( "transport", fmt.Sprintf("policy kernel iprule stage: rules=%d", len(rules)), 5*time.Second, ) return nil } func buildTransportPolicyCIDRSetElements(interfaces []TransportPolicyCompileInterface) map[string][]string { out := map[string][]string{} for _, iface := range interfaces { for _, rule := range iface.Rules { if strings.ToLower(strings.TrimSpace(rule.SelectorType)) != "cidr" { continue } setName := strings.TrimSpace(rule.NftSet) if setName == "" { continue } pfx, err := parseIntentCIDR(rule.SelectorValue) if err != nil { continue } ip := cidrToNftElement(pfx) if ip == "" { continue } out[setName] = appendUniqueString(out[setName], ip) } } for k := range out { sort.Strings(out[k]) } return out } func cidrToNftElement(pfx netip.Prefix) string { if !pfx.IsValid() || !pfx.Addr().Is4() { return "" } if pfx.Bits() == 32 { return pfx.Addr().String() } return pfx.Masked().String() } func collectTransportPolicyCIDRSetNames(interfaces []TransportPolicyCompileInterface) []string { seen := map[string]struct{}{} for _, iface := range interfaces { for _, it := range iface.Sets { if strings.ToLower(strings.TrimSpace(it.SelectorType)) != "cidr" { continue } name := strings.TrimSpace(it.Name) if name == "" { continue } seen[name] = struct{}{} } } out := make([]string, 0, len(seen)) for k := range seen { out = append(out, k) } sort.Strings(out) return out } func collectMapKeys(m map[string][]string) []string { out := make([]string, 0, len(m)) for k := range m { out = append(out, k) } sort.Strings(out) return out } func diffStringSet(a, b []string) []string { if len(a) == 0 { return nil } inB := map[string]struct{}{} for _, it := range b { inB[it] = struct{}{} } out := make([]string, 0, len(a)) for _, it := range a { if _, ok := inB[it]; ok { continue } out = append(out, it) } sort.Strings(out) return out } func appendUniqueString(in []string, v string) []string { val := strings.TrimSpace(v) if val == "" { return in } for _, it := range in { if it == val { return in } } return append(in, val) }