Files

135 lines
4.0 KiB
Go

package trafficmode
import (
"fmt"
"strings"
"time"
)
type RunCommandTimeoutFunc func(timeout time.Duration, name string, args ...string) (stdout string, stderr string, code int, err error)
// IngressBypassConfig names the nftables objects and values used by the ingress
// reply-bypass rules managed in this package.
type IngressBypassConfig struct {
	// TableName is the inet-family table that holds both chains. PreroutingChain
	// is created as a filter chain on the prerouting hook, and OutputChain as a
	// route chain on the output hook, both at mangle priority.
	TableName, PreroutingChain, OutputChain string

	// MarkIngress is the value that is both set on and matched against the
	// conntrack mark, and set as the packet (meta) mark.
	MarkIngress string

	// CaptureComment tags the prerouting rule that sets the conntrack mark on new
	// connections. RestoreComment tags the rules (one per chain) that set the
	// packet mark when the conntrack mark matches. Both are also used to detect
	// whether the rules are installed.
	CaptureComment, RestoreComment string
}
// NftObjectMissing reports whether nft's output says the referenced object
// does not exist. stdout and stderr are joined with a space, lower-cased, and
// searched for "no such file" or "not found".
func NftObjectMissing(stdout, stderr string) bool {
	combined := strings.ToLower(stdout + " " + stderr)
	for _, marker := range []string{"no such file", "not found"} {
		if strings.Contains(combined, marker) {
			return true
		}
	}
	return false
}
// EnsureIngressReplyBypassChains makes a best-effort attempt to create the inet
// table cfg.TableName and its two base chains, in this order:
//   - cfg.PreroutingChain: type filter, hook prerouting, priority mangle, policy accept;
//   - cfg.OutputChain: type route, hook output, priority mangle, policy accept.
//
// Each "nft add" is run with a 5s timeout. Its results are intentionally
// discarded, so this function never reports failure. NOTE(review): this
// presumably relies on "nft add" being a no-op when the object already
// exists — confirm. A nil run makes this a no-op.
func EnsureIngressReplyBypassChains(cfg IngressBypassConfig, run RunCommandTimeoutFunc) {
	if run == nil {
		return
	}
	const timeout = 5 * time.Second
	commands := [][]string{
		{"add", "table", "inet", cfg.TableName},
		{
			"add", "chain", "inet", cfg.TableName, cfg.PreroutingChain,
			"{", "type", "filter", "hook", "prerouting", "priority", "mangle;", "policy", "accept;", "}",
		},
		{
			"add", "chain", "inet", cfg.TableName, cfg.OutputChain,
			"{", "type", "route", "hook", "output", "priority", "mangle;", "policy", "accept;", "}",
		},
	}
	for _, args := range commands {
		// Best-effort by design: results are ignored (see doc comment).
		_, _, _, _ = run(timeout, "nft", args...)
	}
}
// FlushIngressReplyBypassChains removes every rule from cfg.PreroutingChain and
// then from cfg.OutputChain in the inet table cfg.TableName. Each flush uses a
// 5s timeout.
//
// A chain that nft reports as missing (per NftObjectMissing) is skipped rather
// than treated as an error. Any other failure aborts immediately. The returned
// error names the chain, wraps the run error (or a synthesized exit-code error),
// and appends nft's trimmed stdout/stderr. A nil run yields an error.
func FlushIngressReplyBypassChains(cfg IngressBypassConfig, run RunCommandTimeoutFunc) error {
	if run == nil {
		return fmt.Errorf("run command func is nil")
	}
	chains := [...]string{cfg.PreroutingChain, cfg.OutputChain}
	for i := range chains {
		name := chains[i]
		stdout, stderr, exitCode, runErr := run(5*time.Second, "nft", "flush", "chain", "inet", cfg.TableName, name)
		succeeded := runErr == nil && exitCode == 0
		if succeeded || NftObjectMissing(stdout, stderr) {
			continue
		}
		if runErr == nil {
			runErr = fmt.Errorf("nft flush chain exited with %d", exitCode)
		}
		return fmt.Errorf("flush %s failed: %w (%s %s)", name, runErr, strings.TrimSpace(stdout), strings.TrimSpace(stderr))
	}
	return nil
}
// EnableIngressReplyBypass (re)installs the ingress reply-bypass rules for vpnIface.
//
// Steps:
//  1. Best-effort creation of the table and chains (EnsureIngressReplyBypassChains).
//  2. Flush both chains (FlushIngressReplyBypassChains); a flush error is returned as-is.
//  3. Add, in order:
//     - to cfg.PreroutingChain: for packets arriving on an interface other than "lo"
//     and vpnIface, whose destination is local (fib daddr type local) and whose
//     conntrack state is new, set the conntrack mark to cfg.MarkIngress
//     (tagged cfg.CaptureComment);
//     - to cfg.PreroutingChain and then cfg.OutputChain: for packets whose conntrack
//     mark equals cfg.MarkIngress, set the packet mark to cfg.MarkIngress
//     (tagged cfg.RestoreComment).
//
// If any rule fails to add, both chains are flushed again so that a partial rule
// set (e.g. the capture rule without its restore rules) is not left behind. The
// returned error describes the add failure, plus the rollback failure if that also
// failed.
//
// vpnIface is trimmed. It is rejected if it is empty, or if it contains whitespace
// or one of ';', '"', '#'. Those characters would be re-parsed by nft as separate
// tokens or syntax once it joins its arguments. A nil run yields an error.
func EnableIngressReplyBypass(cfg IngressBypassConfig, vpnIface string, run RunCommandTimeoutFunc) error {
	if run == nil {
		return fmt.Errorf("run command func is nil")
	}
	vpnIface = strings.TrimSpace(vpnIface)
	if vpnIface == "" {
		return fmt.Errorf("empty vpn iface for ingress bypass")
	}
	// nft concatenates its argv and parses the result, so interior whitespace or
	// metacharacters in the interface name would change the rule's meaning.
	if strings.ContainsAny(vpnIface, " \t\r\n\v\f;\"#") {
		return fmt.Errorf("invalid vpn iface %q for ingress bypass", vpnIface)
	}

	EnsureIngressReplyBypassChains(cfg, run)
	if err := FlushIngressReplyBypassChains(cfg, run); err != nil {
		return err
	}

	const nftTimeout = 5 * time.Second
	addRule := func(chain string, expr []string) error {
		argv := append([]string{"add", "rule", "inet", cfg.TableName, chain}, expr...)
		out, errOut, code, err := run(nftTimeout, "nft", argv...)
		if err == nil && code == 0 {
			return nil
		}
		if err == nil {
			err = fmt.Errorf("nft add rule exited with %d", code)
		}
		return fmt.Errorf("nft add rule %s failed: %w (%s %s)", chain, err, strings.TrimSpace(out), strings.TrimSpace(errOut))
	}

	// The same restore expression is installed in both chains.
	restore := []string{
		"ct", "mark", cfg.MarkIngress,
		"meta", "mark", "set", cfg.MarkIngress,
		"comment", cfg.RestoreComment,
	}
	rules := []struct {
		chain string
		expr  []string
	}{
		{cfg.PreroutingChain, []string{
			"iifname", "!=", "lo",
			"iifname", "!=", vpnIface,
			"fib", "daddr", "type", "local",
			"ct", "state", "new",
			"ct", "mark", "set", cfg.MarkIngress,
			"comment", cfg.CaptureComment,
		}},
		{cfg.PreroutingChain, restore},
		{cfg.OutputChain, restore},
	}

	for _, r := range rules {
		if err := addRule(r.chain, r.expr); err != nil {
			// Roll back so the chains never hold a half-installed rule set.
			if flushErr := FlushIngressReplyBypassChains(cfg, run); flushErr != nil {
				return fmt.Errorf("%w; rollback flush also failed: %v", err, flushErr)
			}
			return err
		}
	}
	return nil
}
// DisableIngressReplyBypass removes the ingress reply-bypass rules by flushing
// both chains.
//
// EnsureIngressReplyBypassChains is called first (best-effort), so this also
// attempts to create the table and chains if they are absent. The result is
// exactly what FlushIngressReplyBypassChains returns, including its error for a
// nil run.
func DisableIngressReplyBypass(cfg IngressBypassConfig, run RunCommandTimeoutFunc) error {
	EnsureIngressReplyBypassChains(cfg, run)
	if err := FlushIngressReplyBypassChains(cfg, run); err != nil {
		return err
	}
	return nil
}
// IngressReplyNftActive reports whether the ingress reply-bypass rules appear to
// be installed. It lists each chain with "nft -a list chain" (5s timeout) and
// requires:
//   - the prerouting listing to contain both cfg.CaptureComment and cfg.RestoreComment;
//   - the output listing to contain cfg.RestoreComment.
//
// It returns false if run is nil, if either comment is empty (an empty string
// would trivially match any listing), or if a listing fails with a non-nil error
// or a non-zero exit code. The output chain is not listed once the prerouting
// listing has failed.
func IngressReplyNftActive(cfg IngressBypassConfig, run RunCommandTimeoutFunc) bool {
	if run == nil || cfg.CaptureComment == "" || cfg.RestoreComment == "" {
		return false
	}
	list := func(chain string) (string, bool) {
		out, _, code, err := run(5*time.Second, "nft", "-a", "list", "chain", "inet", cfg.TableName, chain)
		return out, err == nil && code == 0
	}
	outPre, ok := list(cfg.PreroutingChain)
	if !ok {
		return false
	}
	outOut, ok := list(cfg.OutputChain)
	if !ok {
		return false
	}
	return strings.Contains(outPre, cfg.CaptureComment) &&
		strings.Contains(outPre, cfg.RestoreComment) &&
		strings.Contains(outOut, cfg.RestoreComment)
}