package trafficmode

import (
	"fmt"
	"strings"
	"time"
)

// ingressBypassNftTimeout is the timeout handed to run for every nft
// invocation in this file.
const ingressBypassNftTimeout = 5 * time.Second

// RunCommandTimeoutFunc executes the external command name with args.
// timeout is the caller's limit for the command; stdout, stderr and code are
// the command's output and exit status, and err reports a failure to run it.
// NOTE(review): whether a timeout surfaces as err or as a non-zero code is up
// to the implementation — confirm before relying on either.
type RunCommandTimeoutFunc func(timeout time.Duration, name string, args ...string) (stdout string, stderr string, code int, err error)

// IngressBypassConfig names the nftables objects and values used by the
// ingress reply bypass rules.
type IngressBypassConfig struct {
	// TableName is the table, in the inet family, that holds both chains.
	TableName string
	// PreroutingChain is a filter base chain on the prerouting hook.
	PreroutingChain string
	// OutputChain is a route base chain on the output hook.
	OutputChain string
	// MarkIngress is written to the conntrack mark and the packet mark; it is
	// passed to nft verbatim (for example a hex literal such as "0x66").
	MarkIngress string
	// CaptureComment tags the rule that marks new inbound connections. It
	// must not contain a double quote.
	CaptureComment string
	// RestoreComment tags the rules that copy the conntrack mark to the
	// packet mark. It must not contain a double quote.
	RestoreComment string
}

// NftObjectMissing reports whether nft's combined output says the object it
// was asked about does not exist. The check is a case-insensitive heuristic:
// any output containing "no such file" or "not found" counts.
func NftObjectMissing(stdout, stderr string) bool {
	text := strings.ToLower(strings.TrimSpace(stdout + " " + stderr))
	return strings.Contains(text, "no such file") || strings.Contains(text, "not found")
}

// EnsureIngressReplyBypassChains creates cfg's inet table, the prerouting
// filter chain and the output route chain (both at mangle priority with an
// accept policy). It is best-effort and a nil run is a no-op: results are
// ignored because nft "add" (unlike "create") does not fail on objects that
// already exist, and EnableIngressReplyBypass / DisableIngressReplyBypass
// both flush the chains right afterwards, which reports any real problem.
func EnsureIngressReplyBypassChains(cfg IngressBypassConfig, run RunCommandTimeoutFunc) {
	if run == nil {
		return
	}
	// Results deliberately ignored; see the doc comment. The chain bodies are
	// split across argv elements because nft joins its arguments with spaces.
	_, _, _, _ = run(ingressBypassNftTimeout, "nft", "add", "table", "inet", cfg.TableName)
	_, _, _, _ = run(
		ingressBypassNftTimeout,
		"nft", "add", "chain", "inet", cfg.TableName, cfg.PreroutingChain,
		"{", "type", "filter", "hook", "prerouting", "priority", "mangle;", "policy", "accept;", "}",
	)
	_, _, _, _ = run(
		ingressBypassNftTimeout,
		"nft", "add", "chain", "inet", cfg.TableName, cfg.OutputChain,
		"{", "type", "route", "hook", "output", "priority", "mangle;", "policy", "accept;", "}",
	)
}

// FlushIngressReplyBypassChains removes every rule from cfg's prerouting and
// output chains. A flush failure whose output NftObjectMissing recognises as
// a missing object is ignored, so this is safe to call before the chains
// exist. Any other failure is returned with nft's output attached.
func FlushIngressReplyBypassChains(cfg IngressBypassConfig, run RunCommandTimeoutFunc) error {
	if run == nil {
		return fmt.Errorf("run command func is nil")
	}
	for _, chain := range []string{cfg.PreroutingChain, cfg.OutputChain} {
		out, errOut, code, err := run(ingressBypassNftTimeout, "nft", "flush", "chain", "inet", cfg.TableName, chain)
		if err == nil && code == 0 {
			continue
		}
		if NftObjectMissing(out, errOut) {
			continue
		}
		if err == nil {
			err = fmt.Errorf("nft flush chain exited with %d", code)
		}
		return fmt.Errorf("flush %s failed: %w (%s %s)", chain, err, strings.TrimSpace(out), strings.TrimSpace(errOut))
	}
	return nil
}

// EnableIngressReplyBypass (re)installs the bypass rules, replacing whatever
// the two chains held before:
//
//   - prerouting: new connections to a local address arriving on any
//     interface other than lo and vpnIface get ct mark cfg.MarkIngress;
//   - prerouting and output: packets whose ct mark is cfg.MarkIngress get the
//     same meta (packet) mark.
//
// Presumably the packet mark feeds policy routing so replies to those
// connections bypass the VPN — confirm against the routing setup.
//
// vpnIface is trimmed and must be non-empty and free of whitespace and double
// quotes. If adding any rule fails, both chains are flushed again
// (best-effort) and the rule error is returned.
func EnableIngressReplyBypass(cfg IngressBypassConfig, vpnIface string, run RunCommandTimeoutFunc) error {
	if run == nil {
		return fmt.Errorf("run command func is nil")
	}
	vpnIface = strings.TrimSpace(vpnIface)
	if vpnIface == "" {
		return fmt.Errorf("empty vpn iface for ingress bypass")
	}
	// nft re-joins its argv with spaces and parses the result, so a name with
	// whitespace or a double quote could not be quoted safely and could splice
	// extra tokens into the rule.
	if strings.ContainsAny(vpnIface, " \t\r\n\v\f\"") {
		return fmt.Errorf("invalid vpn iface %q for ingress bypass", vpnIface)
	}

	EnsureIngressReplyBypassChains(cfg, run)
	if err := FlushIngressReplyBypassChains(cfg, run); err != nil {
		return err
	}

	addRule := func(chain string, args ...string) error {
		out, errOut, code, err := run(ingressBypassNftTimeout, "nft", append([]string{"add", "rule", "inet", cfg.TableName, chain}, args...)...)
		if err != nil || code != 0 {
			if err == nil {
				err = fmt.Errorf("nft add rule exited with %d", code)
			}
			return fmt.Errorf("nft add rule %s failed: %w (%s %s)", chain, err, strings.TrimSpace(out), strings.TrimSpace(errOut))
		}
		return nil
	}

	captureArgs := []string{
		"iifname", "!=", "lo",
		"iifname", "!=", ingressBypassQuote(vpnIface),
		"fib", "daddr", "type", "local",
		"ct", "state", "new",
		"ct", "mark", "set", cfg.MarkIngress,
		"comment", ingressBypassQuote(cfg.CaptureComment),
	}
	restoreArgs := []string{
		"ct", "mark", cfg.MarkIngress,
		"meta", "mark", "set", cfg.MarkIngress,
		"comment", ingressBypassQuote(cfg.RestoreComment),
	}
	steps := []struct {
		chain string
		args  []string
	}{
		{cfg.PreroutingChain, captureArgs},
		{cfg.PreroutingChain, restoreArgs},
		{cfg.OutputChain, restoreArgs},
	}
	for _, step := range steps {
		if err := addRule(step.chain, step.args...); err != nil {
			// Roll back so a half-installed rule set is not left behind. The
			// rule error is what the caller needs, so a rollback failure is
			// deliberately not reported.
			_ = FlushIngressReplyBypassChains(cfg, run)
			return err
		}
	}
	return nil
}

// DisableIngressReplyBypass removes the bypass rules by flushing both chains.
// The table and chains are first created if absent (best-effort, via
// EnsureIngressReplyBypassChains), so they may exist afterwards even if they
// did not before.
func DisableIngressReplyBypass(cfg IngressBypassConfig, run RunCommandTimeoutFunc) error {
	EnsureIngressReplyBypassChains(cfg, run)
	return FlushIngressReplyBypassChains(cfg, run)
}

// IngressReplyNftActive reports whether the bypass rules are installed: the
// prerouting chain must carry rules commented with both CaptureComment and
// RestoreComment, and the output chain one commented with RestoreComment. It
// returns false when run is nil or either chain cannot be listed.
func IngressReplyNftActive(cfg IngressBypassConfig, run RunCommandTimeoutFunc) bool {
	if run == nil {
		return false
	}
	listChain := func(chain string) (string, bool) {
		out, _, code, err := run(ingressBypassNftTimeout, "nft", "-a", "list", "chain", "inet", cfg.TableName, chain)
		return out, err == nil && code == 0
	}
	pre, ok := listChain(cfg.PreroutingChain)
	if !ok {
		return false
	}
	output, ok := listChain(cfg.OutputChain)
	if !ok {
		return false
	}
	return ingressBypassHasComment(pre, cfg.CaptureComment) &&
		ingressBypassHasComment(pre, cfg.RestoreComment) &&
		ingressBypassHasComment(output, cfg.RestoreComment)
}

// ingressBypassQuote wraps s in double quotes so nft parses it as one string
// token even though nft joins its arguments with spaces before parsing.
// s must not contain a double quote.
func ingressBypassQuote(s string) string {
	return `"` + s + `"`
}

// ingressBypassHasComment reports whether an nft chain listing contains a
// rule whose comment is exactly comment. Matching the listed
// `comment "<text>"` form prevents a comment that is a substring of another
// from producing a false positive.
func ingressBypassHasComment(listing, comment string) bool {
	return strings.Contains(listing, "comment "+ingressBypassQuote(comment))
}