135 lines
4.0 KiB
Go
135 lines
4.0 KiB
Go
package trafficmode
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// RunCommandTimeoutFunc is the signature of the command runner injected into
// the functions in this file. Every call made here passes "nft" as name and the
// nft arguments as separate args elements. The results are named stdout,
// stderr, code and err; callers in this file treat err == nil together with
// code == 0 as success.
// NOTE(review): how the timeout is enforced and what code/err look like when it
// expires depend on the implementation supplied by the caller, which is not
// visible here.
type RunCommandTimeoutFunc func(timeout time.Duration, name string, args ...string) (stdout string, stderr string, code int, err error)
|
|
|
|
// IngressBypassConfig carries the nftables object names, the mark value and the
// rule comments used by the ingress reply bypass helpers in this file.
type IngressBypassConfig struct {
	// TableName is the inet-family table that holds both chains.
	TableName string
	// PreroutingChain is the chain created on the prerouting hook.
	PreroutingChain string
	// OutputChain is the chain created on the output hook.
	OutputChain string
	// MarkIngress is the value written as conntrack mark and as packet mark.
	MarkIngress string
	// CaptureComment is the comment attached to the rule that sets the
	// conntrack mark on new connections.
	CaptureComment string
	// RestoreComment is the comment attached to the rules that copy the
	// conntrack mark onto the packet mark.
	RestoreComment string
}
|
|
|
|
// NftObjectMissing reports whether the output of an nft command contains one of
// the phrases "no such file" or "not found". stdout and stderr are joined with
// a single space and compared case-insensitively.
func NftObjectMissing(stdout, stderr string) bool {
	combined := strings.ToLower(strings.TrimSpace(stdout + " " + stderr))
	for _, marker := range []string{"no such file", "not found"} {
		if strings.Contains(combined, marker) {
			return true
		}
	}
	return false
}
|
|
|
|
func EnsureIngressReplyBypassChains(cfg IngressBypassConfig, run RunCommandTimeoutFunc) {
|
|
if run == nil {
|
|
return
|
|
}
|
|
_, _, _, _ = run(5*time.Second, "nft", "add", "table", "inet", cfg.TableName)
|
|
_, _, _, _ = run(
|
|
5*time.Second,
|
|
"nft", "add", "chain", "inet", cfg.TableName, cfg.PreroutingChain,
|
|
"{", "type", "filter", "hook", "prerouting", "priority", "mangle;", "policy", "accept;", "}",
|
|
)
|
|
_, _, _, _ = run(
|
|
5*time.Second,
|
|
"nft", "add", "chain", "inet", cfg.TableName, cfg.OutputChain,
|
|
"{", "type", "route", "hook", "output", "priority", "mangle;", "policy", "accept;", "}",
|
|
)
|
|
}
|
|
|
|
func FlushIngressReplyBypassChains(cfg IngressBypassConfig, run RunCommandTimeoutFunc) error {
|
|
if run == nil {
|
|
return fmt.Errorf("run command func is nil")
|
|
}
|
|
for _, chain := range []string{cfg.PreroutingChain, cfg.OutputChain} {
|
|
out, errOut, code, err := run(5*time.Second, "nft", "flush", "chain", "inet", cfg.TableName, chain)
|
|
if err == nil && code == 0 {
|
|
continue
|
|
}
|
|
if NftObjectMissing(out, errOut) {
|
|
continue
|
|
}
|
|
if err == nil {
|
|
err = fmt.Errorf("nft flush chain exited with %d", code)
|
|
}
|
|
return fmt.Errorf("flush %s failed: %w (%s %s)", chain, err, strings.TrimSpace(out), strings.TrimSpace(errOut))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func EnableIngressReplyBypass(cfg IngressBypassConfig, vpnIface string, run RunCommandTimeoutFunc) error {
|
|
if run == nil {
|
|
return fmt.Errorf("run command func is nil")
|
|
}
|
|
vpnIface = strings.TrimSpace(vpnIface)
|
|
if vpnIface == "" {
|
|
return fmt.Errorf("empty vpn iface for ingress bypass")
|
|
}
|
|
|
|
EnsureIngressReplyBypassChains(cfg, run)
|
|
if err := FlushIngressReplyBypassChains(cfg, run); err != nil {
|
|
return err
|
|
}
|
|
|
|
addRule := func(chain string, args ...string) error {
|
|
out, errOut, code, err := run(5*time.Second, "nft", append([]string{"add", "rule", "inet", cfg.TableName, chain}, args...)...)
|
|
if err != nil || code != 0 {
|
|
if err == nil {
|
|
err = fmt.Errorf("nft add rule exited with %d", code)
|
|
}
|
|
return fmt.Errorf("nft add rule %s failed: %w (%s %s)", chain, err, strings.TrimSpace(out), strings.TrimSpace(errOut))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
if err := addRule(
|
|
cfg.PreroutingChain,
|
|
"iifname", "!=", "lo",
|
|
"iifname", "!=", vpnIface,
|
|
"fib", "daddr", "type", "local",
|
|
"ct", "state", "new",
|
|
"ct", "mark", "set", cfg.MarkIngress,
|
|
"comment", cfg.CaptureComment,
|
|
); err != nil {
|
|
return err
|
|
}
|
|
if err := addRule(
|
|
cfg.PreroutingChain,
|
|
"ct", "mark", cfg.MarkIngress,
|
|
"meta", "mark", "set", cfg.MarkIngress,
|
|
"comment", cfg.RestoreComment,
|
|
); err != nil {
|
|
return err
|
|
}
|
|
if err := addRule(
|
|
cfg.OutputChain,
|
|
"ct", "mark", cfg.MarkIngress,
|
|
"meta", "mark", "set", cfg.MarkIngress,
|
|
"comment", cfg.RestoreComment,
|
|
); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func DisableIngressReplyBypass(cfg IngressBypassConfig, run RunCommandTimeoutFunc) error {
|
|
EnsureIngressReplyBypassChains(cfg, run)
|
|
return FlushIngressReplyBypassChains(cfg, run)
|
|
}
|
|
|
|
func IngressReplyNftActive(cfg IngressBypassConfig, run RunCommandTimeoutFunc) bool {
|
|
if run == nil {
|
|
return false
|
|
}
|
|
outPre, _, codePre, _ := run(5*time.Second, "nft", "-a", "list", "chain", "inet", cfg.TableName, cfg.PreroutingChain)
|
|
outOut, _, codeOut, _ := run(5*time.Second, "nft", "-a", "list", "chain", "inet", cfg.TableName, cfg.OutputChain)
|
|
if codePre != 0 || codeOut != 0 {
|
|
return false
|
|
}
|
|
return strings.Contains(outPre, cfg.CaptureComment) &&
|
|
strings.Contains(outPre, cfg.RestoreComment) &&
|
|
strings.Contains(outOut, cfg.RestoreComment)
|
|
}
|