Files

338 lines
9.4 KiB
Go

package trafficappmarks
import (
"fmt"
"net/netip"
"sort"
"strconv"
"strings"
"time"
)
// RunCommandFunc runs the external command name with args and returns its
// standard output, standard error and exit code, plus an error value.
// NOTE(review): timeout presumably bounds how long the command may run, and
// err presumably reports a failure to run it at all — confirm against the
// implementation. Functions in this package call it with "nft" as name; where
// they check the result, a non-nil err or a non-zero code is treated as failure.
type RunCommandFunc func(timeout time.Duration, name string, args ...string) (stdout string, stderr string, code int, err error)
// NFTConfig names the nftables objects, marks and comment prefixes used by
// this package. Every chain and set lives in the "inet" family table Table.
type NFTConfig struct {
// Table is the inet table holding every chain and set below.
Table string
// Chain is the regular chain holding the per-app mark rules; EnsureBase makes
// the table's "output" route chain jump to it.
Chain string
// GuardChain is the filter-type, output-hook base chain holding the app guard
// drop rules.
GuardChain string
// LocalBypassSet is the ipv4_addr interval set of destinations exempt from the
// IPv4 app guard rule.
LocalBypassSet string
// MarkApp is the packet mark set for the "vpn" target and matched by the guard
// rules. It is passed to nft verbatim.
MarkApp string
// MarkDirect is the packet mark set for every target other than "vpn".
MarkDirect string
// MarkCommentPrefix is the first component of mark-rule comments
// (see AppMarkComment).
MarkCommentPrefix string
// GuardCommentPrefix is the first component of guard-rule comments
// (see AppGuardComment).
GuardCommentPrefix string
// GuardEnabled makes InsertAppMarkRule also install guard rules for the "vpn"
// target.
GuardEnabled bool
}
// EnsureBase creates the nftables objects the app-mark rules rely on: the inet
// table cfg.Table, a route-type "output" base chain, the filter-type guard base
// chain cfg.GuardChain (when named), the regular chain cfg.Chain and the IPv4
// interval set cfg.LocalBypassSet (when named). It then makes sure the "output"
// chain jumps to cfg.Chain, inserting that jump at the head of the chain if it
// is missing.
//
// Object creation is best-effort and its failures are ignored ("nft add" is a
// no-op for objects that already exist). A failure to install the missing jump
// is returned: without it no rule in cfg.Chain is ever evaluated.
func EnsureBase(cfg NFTConfig, run RunCommandFunc) error {
	if run == nil {
		return fmt.Errorf("run command func is nil")
	}
	if strings.TrimSpace(cfg.Table) == "" || strings.TrimSpace(cfg.Chain) == "" {
		return fmt.Errorf("invalid nft config for base chains")
	}
	const timeout = 5 * time.Second

	_, _, _, _ = run(timeout, "nft", "add", "table", "inet", cfg.Table)
	_, _, _, _ = run(timeout, "nft", "add", "chain", "inet", cfg.Table, "output", "{", "type", "route", "hook", "output", "priority", "mangle;", "policy", "accept;", "}")
	if strings.TrimSpace(cfg.GuardChain) != "" {
		_, _, _, _ = run(timeout, "nft", "add", "chain", "inet", cfg.Table, cfg.GuardChain, "{", "type", "filter", "hook", "output", "priority", "filter;", "policy", "accept;", "}")
	}
	_, _, _, _ = run(timeout, "nft", "add", "chain", "inet", cfg.Table, cfg.Chain)
	if strings.TrimSpace(cfg.LocalBypassSet) != "" {
		_, _, _, _ = run(timeout, "nft", "add", "set", "inet", cfg.Table, cfg.LocalBypassSet, "{", "type", "ipv4_addr", ";", "flags", "interval", ";", "}")
	}

	listing, _, _, _ := run(timeout, "nft", "list", "chain", "inet", cfg.Table, "output")
	if outputChainJumpsTo(listing, cfg.Chain) {
		return nil
	}
	_, stderr, code, err := run(timeout, "nft", "insert", "rule", "inet", cfg.Table, "output", "jump", cfg.Chain)
	if err != nil || code != 0 {
		if err == nil {
			err = fmt.Errorf("nft insert jump exited with %d", code)
		}
		return fmt.Errorf("nft insert jump to %s failed: %w (%s)", cfg.Chain, err, strings.TrimSpace(stderr))
	}
	return nil
}

// outputChainJumpsTo reports whether an "nft list chain" listing contains a
// "jump <chain>" statement naming exactly chain, rather than merely a chain
// whose name starts with it. Surrounding double quotes on the name are tolerated.
func outputChainJumpsTo(listing, chain string) bool {
	fields := strings.Fields(listing)
	for i := 0; i+1 < len(fields); i++ {
		if fields[i] == "jump" && strings.Trim(fields[i+1], `"`) == chain {
			return true
		}
	}
	return false
}
// AppMarkComment builds the comment that tags the mark rule of one
// (target, id) pair: "<prefix>:<target>:<id>", with id in decimal.
func AppMarkComment(prefix string, target string, id uint64) string {
	return prefix + ":" + target + ":" + strconv.FormatUint(id, 10)
}
// AppGuardComment builds the comment that tags the guard rules of one
// (target, id) pair. The layout matches AppMarkComment,
// "<prefix>:<target>:<id>"; only the prefix supplied by callers differs.
func AppGuardComment(prefix string, target string, id uint64) string {
	parts := []string{prefix, target, strconv.FormatUint(id, 10)}
	return strings.Join(parts, ":")
}
// UpdateLocalBypassSet replaces the contents of the IPv4 interval set
// cfg.LocalBypassSet with 127.0.0.0/8 plus bypassCIDRs.
//
// Entries of bypassCIDRs are trimmed; blank entries and the literal "default"
// are skipped. The resulting list is normalized by CompactIPv4IntervalElements,
// which also drops anything that is not an IPv4 address or prefix.
//
// The flush and the refill are sent as a single nft invocation (commands
// separated by ";"), which nft submits as one transaction. The set is therefore
// never observed empty or half-filled, which would let the app guard drop
// traffic to bypassed destinations. If the update fails, the previous contents
// are kept.
//
// vpnIface is currently unused and is kept for API compatibility.
func UpdateLocalBypassSet(cfg NFTConfig, vpnIface string, bypassCIDRs []string, run RunCommandFunc) error {
	if run == nil {
		return fmt.Errorf("run command func is nil")
	}
	if strings.TrimSpace(cfg.Table) == "" || strings.TrimSpace(cfg.LocalBypassSet) == "" {
		return fmt.Errorf("invalid nft config for local bypass set")
	}

	elems := make([]string, 0, len(bypassCIDRs)+1)
	elems = append(elems, "127.0.0.0/8")
	for _, dst := range bypassCIDRs {
		val := strings.TrimSpace(dst)
		if val == "" || val == "default" {
			continue
		}
		elems = append(elems, val)
	}
	elems = CompactIPv4IntervalElements(elems)

	args := []string{"flush", "set", "inet", cfg.Table, cfg.LocalBypassSet}
	if len(elems) > 0 {
		args = append(args,
			";", "add", "element", "inet", cfg.Table, cfg.LocalBypassSet,
			"{", strings.Join(elems, ", "), "}",
		)
	}
	_, stderr, code, err := run(5*time.Second, "nft", args...)
	if err != nil || code != 0 {
		if err == nil {
			err = fmt.Errorf("nft add element exited with %d", code)
		}
		return fmt.Errorf("failed to update %s: %w (%s)", cfg.LocalBypassSet, err, strings.TrimSpace(stderr))
	}
	return nil
}
// InsertAppMarkRule installs the nftables rules that steer traffic from one
// cgroup v2 path onto the mark selected by target.
//
// A rule is inserted at the head of cfg.Chain. It matches sockets in cgroup rel
// at the given cgroupv2 level, sets the packet mark and accepts. The mark is
// cfg.MarkApp when target is "vpn" and cfg.MarkDirect otherwise. The rule is
// tagged with AppMarkComment(cfg.MarkCommentPrefix, target, id).
//
// For the "vpn" target with cfg.GuardEnabled, cfg.LocalBypassSet is refreshed
// first and two guard rules are inserted into cfg.GuardChain. They drop packets
// from the same cgroup carrying cfg.MarkApp whose output interface is not
// vpnIface: for IPv4, unless the destination is in the bypass set; for IPv6,
// unconditionally. Both are tagged with
// AppGuardComment(cfg.GuardCommentPrefix, target, id).
//
// target is trimmed and lower-cased before use. If an insert fails after a rule
// from this call is already in place, DeleteAppMarkRule removes the (target, id)
// rules so no half-installed set is left behind, and the error is returned.
func InsertAppMarkRule(cfg NFTConfig, target string, rel string, level int, id uint64, vpnIface string, bypassCIDRs []string, run RunCommandFunc) error {
	if run == nil {
		return fmt.Errorf("run command func is nil")
	}
	target = strings.ToLower(strings.TrimSpace(target))
	mark := cfg.MarkDirect
	if target == "vpn" {
		mark = cfg.MarkApp
	}
	lvl := strconv.Itoa(level)
	pathLit := fmt.Sprintf("\"%s\"", rel)

	if target == "vpn" && cfg.GuardEnabled {
		iface := strings.TrimSpace(vpnIface)
		if iface == "" {
			return fmt.Errorf("vpn interface required for app guard")
		}
		if err := UpdateLocalBypassSet(cfg, iface, bypassCIDRs, run); err != nil {
			return err
		}
		guardCommentLit := fmt.Sprintf("\"%s\"", AppGuardComment(cfg.GuardCommentPrefix, target, id))
		if err := runAppMarkInsert(run, "nft insert guard(v4)", "nft insert app guard(v4)",
			"insert", "rule", "inet", cfg.Table, cfg.GuardChain,
			"socket", "cgroupv2", "level", lvl, pathLit,
			"meta", "mark", cfg.MarkApp,
			"oifname", "!=", iface,
			"ip", "daddr", "!=", "@"+cfg.LocalBypassSet,
			"drop",
			"comment", guardCommentLit,
		); err != nil {
			return err
		}
		if err := runAppMarkInsert(run, "nft insert guard(v6)", "nft insert app guard(v6)",
			"insert", "rule", "inet", cfg.Table, cfg.GuardChain,
			"socket", "cgroupv2", "level", lvl, pathLit,
			"meta", "mark", cfg.MarkApp,
			"oifname", "!=", iface,
			"meta", "nfproto", "ipv6",
			"drop",
			"comment", guardCommentLit,
		); err != nil {
			// The v4 guard is already installed; do not leave it orphaned.
			_ = DeleteAppMarkRule(cfg, target, id, run)
			return err
		}
	}

	commentLit := fmt.Sprintf("\"%s\"", AppMarkComment(cfg.MarkCommentPrefix, target, id))
	if err := runAppMarkInsert(run, "nft insert rule", "nft insert appmark rule",
		"insert", "rule", "inet", cfg.Table, cfg.Chain,
		"socket", "cgroupv2", "level", lvl, pathLit,
		"meta", "mark", "set", mark,
		"accept",
		"comment", commentLit,
	); err != nil {
		_ = DeleteAppMarkRule(cfg, target, id, run)
		return err
	}
	return nil
}

// runAppMarkInsert runs "nft args..." and turns a run error or a non-zero exit
// code into "<failLabel> failed: <cause> (<stderr>)". Here <cause> is the run
// error, or "<exitLabel> exited with <code>" when the command merely exited
// non-zero.
func runAppMarkInsert(run RunCommandFunc, exitLabel, failLabel string, args ...string) error {
	_, stderr, code, err := run(5*time.Second, "nft", args...)
	if err == nil && code == 0 {
		return nil
	}
	if err == nil {
		err = fmt.Errorf("%s exited with %d", exitLabel, code)
	}
	return fmt.Errorf("%s failed: %w (%s)", failLabel, err, strings.TrimSpace(stderr))
}
// DeleteAppMarkRule removes every rule in cfg.Chain and cfg.GuardChain whose
// comment is exactly AppMarkComment(cfg.MarkCommentPrefix, target, id) or
// AppGuardComment(cfg.GuardCommentPrefix, target, id). target is trimmed and
// lower-cased first, matching how InsertAppMarkRule builds those comments.
// Unnamed chains are skipped.
//
// Removal is best-effort: listing and delete failures are ignored, and the only
// error returned is for a nil run.
func DeleteAppMarkRule(cfg NFTConfig, target string, id uint64, run RunCommandFunc) error {
	if run == nil {
		return fmt.Errorf("run command func is nil")
	}
	target = strings.ToLower(strings.TrimSpace(target))
	// Match the quoted literal as printed by nft (comment "..."). Otherwise the
	// comment for id 1 would also match the comments for ids 10, 11, 100, ...
	needles := []string{
		`"` + AppMarkComment(cfg.MarkCommentPrefix, target, id) + `"`,
		`"` + AppGuardComment(cfg.GuardCommentPrefix, target, id) + `"`,
	}
	for _, chain := range []string{cfg.Chain, cfg.GuardChain} {
		if strings.TrimSpace(chain) == "" {
			continue
		}
		out, _, _, _ := run(5*time.Second, "nft", "-a", "list", "chain", "inet", cfg.Table, chain)
		for _, line := range strings.Split(out, "\n") {
			match := false
			for _, needle := range needles {
				if strings.Contains(line, needle) {
					match = true
					break
				}
			}
			if !match {
				continue
			}
			if h := ParseNftHandle(line); h > 0 {
				_, _, _, _ = run(5*time.Second, "nft", "delete", "rule", "inet", cfg.Table, chain, "handle", strconv.Itoa(h))
			}
		}
	}
	return nil
}
// HasAppMarkRule reports whether the rules InsertAppMarkRule installs for
// (target, id) are present. This means the mark rule in cfg.Chain and, for the
// "vpn" target with cfg.GuardEnabled, at least one guard rule in
// cfg.GuardChain. Rules are recognized by their exact, quoted comment. target
// is trimmed and lower-cased first, as InsertAppMarkRule does. A nil run, or a
// listing that lacks the comment, yields false.
func HasAppMarkRule(cfg NFTConfig, target string, id uint64, run RunCommandFunc) bool {
	if run == nil {
		return false
	}
	target = strings.ToLower(strings.TrimSpace(target))
	// Quoting keeps id 1 from matching "prefix:vpn:10" and similar longer ids.
	markNeedle := `"` + AppMarkComment(cfg.MarkCommentPrefix, target, id) + `"`
	out, _, _, _ := run(5*time.Second, "nft", "-a", "list", "chain", "inet", cfg.Table, cfg.Chain)
	if !strings.Contains(out, markNeedle) {
		return false
	}
	if target != "vpn" || !cfg.GuardEnabled {
		return true
	}
	guardNeedle := `"` + AppGuardComment(cfg.GuardCommentPrefix, target, id) + `"`
	out, _, _, _ = run(5*time.Second, "nft", "-a", "list", "chain", "inet", cfg.Table, cfg.GuardChain)
	return strings.Contains(out, guardNeedle)
}
// CleanupLegacyRules deletes from cfg.Chain every rule whose listing, compared
// case-insensitively, contains both "meta cgroup" and "svpn_cg_". Going by the
// name, these are presumably rules left by an older marking scheme — confirm.
// Listing and delete failures are ignored; the only error returned is for a
// nil run.
func CleanupLegacyRules(cfg NFTConfig, run RunCommandFunc) error {
	if run == nil {
		return fmt.Errorf("run command func is nil")
	}
	isLegacy := func(line string) bool {
		lower := strings.ToLower(line)
		return strings.Contains(lower, "meta cgroup") && strings.Contains(lower, "svpn_cg_")
	}
	listing, _, _, _ := run(5*time.Second, "nft", "-a", "list", "chain", "inet", cfg.Table, cfg.Chain)
	for _, line := range strings.Split(listing, "\n") {
		if !isLegacy(line) {
			continue
		}
		if handle := ParseNftHandle(line); handle > 0 {
			_, _, _, _ = run(5*time.Second, "nft", "delete", "rule", "inet", cfg.Table, cfg.Chain, "handle", strconv.Itoa(handle))
		}
	}
	return nil
}
// ClearManagedRules deletes from the given chain of cfg.Table every rule whose
// listing contains cfg.MarkCommentPrefix or cfg.GuardCommentPrefix, compared
// case-insensitively.
//
// Empty or blank prefixes are ignored. An empty string is contained in every
// line and would otherwise wipe every rule in the chain, including rules this
// package did not tag (such as the jump from "output"). If neither prefix is
// usable, nothing is deleted. The function is best-effort: all nft failures are
// ignored.
func ClearManagedRules(cfg NFTConfig, chain string, run RunCommandFunc) {
	if run == nil {
		return
	}
	prefixes := make([]string, 0, 2)
	for _, p := range []string{cfg.MarkCommentPrefix, cfg.GuardCommentPrefix} {
		if strings.TrimSpace(p) == "" {
			continue
		}
		prefixes = append(prefixes, strings.ToLower(p))
	}
	if len(prefixes) == 0 {
		return
	}
	out, _, _, _ := run(5*time.Second, "nft", "-a", "list", "chain", "inet", cfg.Table, chain)
	for _, line := range strings.Split(out, "\n") {
		l := strings.ToLower(line)
		managed := false
		for _, p := range prefixes {
			if strings.Contains(l, p) {
				managed = true
				break
			}
		}
		if !managed {
			continue
		}
		if h := ParseNftHandle(line); h > 0 {
			_, _, _, _ = run(5*time.Second, "nft", "delete", "rule", "inet", cfg.Table, chain, "handle", strconv.Itoa(h))
		}
	}
}
// ParseNftHandle extracts the rule handle from one line of "nft -a list"
// output, where nft appends "# handle <n>" to each rule. It returns 0 when the
// line carries no parsable handle.
//
// The line is scanned from the end. A "handle" word appearing earlier (e.g.
// inside a rule comment such as `comment "a handle b"`) therefore cannot
// shadow the trailing annotation, and unparsable values are skipped rather
// than ending the search.
func ParseNftHandle(line string) int {
	fields := strings.Fields(line)
	for i := len(fields) - 2; i >= 0; i-- {
		if fields[i] != "handle" {
			continue
		}
		if n, err := strconv.Atoi(fields[i+1]); err == nil {
			return n
		}
	}
	return 0
}
// CompactIPv4IntervalElements normalizes IPv4 addresses and CIDR prefixes into
// a list of non-overlapping prefixes suitable for an nftables interval set.
//
// Each entry is trimmed. A bare address becomes a /32, and a prefix is masked
// to its network address. Empty, unparsable and non-IPv4 entries are dropped.
// A prefix contained in another listed prefix (including an exact duplicate) is
// removed. Adjacent prefixes are not merged. The result is ordered by prefix
// length, shortest first, then by address, and is never nil.
func CompactIPv4IntervalElements(raw []string) []string {
	toPrefix := func(s string) (netip.Prefix, bool) {
		if !strings.Contains(s, "/") {
			addr, err := netip.ParseAddr(s)
			if err != nil || !addr.Is4() {
				return netip.Prefix{}, false
			}
			return netip.PrefixFrom(addr, 32), true
		}
		p, err := netip.ParsePrefix(s)
		if err != nil || !p.Addr().Is4() {
			return netip.Prefix{}, false
		}
		return p.Masked(), true
	}

	candidates := make([]netip.Prefix, 0, len(raw))
	for _, entry := range raw {
		trimmed := strings.TrimSpace(entry)
		if trimmed == "" {
			continue
		}
		if p, ok := toPrefix(trimmed); ok {
			candidates = append(candidates, p)
		}
	}

	// Shortest prefixes first, so a covering prefix is always kept before
	// anything it covers.
	sort.Slice(candidates, func(a, b int) bool {
		pa, pb := candidates[a], candidates[b]
		if pa.Bits() == pb.Bits() {
			return pa.Addr().Less(pb.Addr())
		}
		return pa.Bits() < pb.Bits()
	})

	kept := make([]netip.Prefix, 0, len(candidates))
	isCovered := func(p netip.Prefix) bool {
		for _, k := range kept {
			if k.Contains(p.Addr()) {
				return true
			}
		}
		return false
	}
	for _, p := range candidates {
		if !isCovered(p) {
			kept = append(kept, p)
		}
	}

	result := make([]string, len(kept))
	for i := range kept {
		result[i] = kept[i].String()
	}
	return result
}