338 lines
9.4 KiB
Go
338 lines
9.4 KiB
Go
package trafficappmarks
|
|
|
|
import (
|
|
"fmt"
|
|
"net/netip"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// RunCommandFunc runs the external program name with args and reports the
// command's stdout and stderr text, its exit code, and an error from running
// it. timeout is the time budget handed to the executor; every call in this
// file passes 5s. How the budget is enforced is up to the implementation.
// Every nft invocation in this file goes through a RunCommandFunc.
type RunCommandFunc func(timeout time.Duration, name string, args ...string) (stdout, stderr string, code int, err error)
|
|
|
|
// NFTConfig names the nftables objects this package manages and the values it
// writes into rules. Every object lives in the "inet" family table Table.
type NFTConfig struct {
	// Table is the inet table holding every chain and set below.
	Table string
	// Chain is a regular (hook-less) chain reached via "jump" from the
	// table's "output" chain; per-app mark rules are inserted here.
	Chain string
	// GuardChain is a filter chain hooked at output; per-app guard drop
	// rules are inserted here.
	GuardChain string
	// LocalBypassSet is an ipv4_addr interval set; IPv4 destinations in it
	// are not dropped by the IPv4 guard rule.
	LocalBypassSet string
	// MarkApp is the packet mark set for apps whose target is "vpn"; guard
	// rules match on it.
	MarkApp string
	// MarkDirect is the packet mark set for apps with any other target.
	MarkDirect string
	// MarkCommentPrefix starts the comment attached to mark rules.
	MarkCommentPrefix string
	// GuardCommentPrefix starts the comment attached to guard rules.
	GuardCommentPrefix string
	// GuardEnabled makes "vpn" apps also get guard rules.
	GuardEnabled bool
}
|
|
|
|
func EnsureBase(cfg NFTConfig, run RunCommandFunc) error {
|
|
if run == nil {
|
|
return fmt.Errorf("run command func is nil")
|
|
}
|
|
_, _, _, _ = run(5*time.Second, "nft", "add", "table", "inet", cfg.Table)
|
|
_, _, _, _ = run(5*time.Second, "nft", "add", "chain", "inet", cfg.Table, "output", "{", "type", "route", "hook", "output", "priority", "mangle;", "policy", "accept;", "}")
|
|
_, _, _, _ = run(5*time.Second, "nft", "add", "chain", "inet", cfg.Table, cfg.GuardChain, "{", "type", "filter", "hook", "output", "priority", "filter;", "policy", "accept;", "}")
|
|
_, _, _, _ = run(5*time.Second, "nft", "add", "chain", "inet", cfg.Table, cfg.Chain)
|
|
_, _, _, _ = run(5*time.Second, "nft", "add", "set", "inet", cfg.Table, cfg.LocalBypassSet, "{", "type", "ipv4_addr", ";", "flags", "interval", ";", "}")
|
|
|
|
out, _, _, _ := run(5*time.Second, "nft", "list", "chain", "inet", cfg.Table, "output")
|
|
if !strings.Contains(out, "jump "+cfg.Chain) {
|
|
_, _, _, _ = run(5*time.Second, "nft", "insert", "rule", "inet", cfg.Table, "output", "jump", cfg.Chain)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// AppMarkComment returns the comment text attached to an app mark rule in
// the form "<prefix>:<target>:<id>", with id written in decimal.
func AppMarkComment(prefix string, target string, id uint64) string {
	return prefix + ":" + target + ":" + strconv.FormatUint(id, 10)
}
|
|
|
|
// AppGuardComment returns the comment text attached to an app guard rule:
// prefix, target and the decimal id joined by ':'.
func AppGuardComment(prefix string, target string, id uint64) string {
	parts := []string{prefix, target, strconv.FormatUint(id, 10)}
	return strings.Join(parts, ":")
}
|
|
|
|
func UpdateLocalBypassSet(cfg NFTConfig, vpnIface string, bypassCIDRs []string, run RunCommandFunc) error {
|
|
if run == nil {
|
|
return fmt.Errorf("run command func is nil")
|
|
}
|
|
if strings.TrimSpace(cfg.Table) == "" || strings.TrimSpace(cfg.LocalBypassSet) == "" {
|
|
return fmt.Errorf("invalid nft config for local bypass set")
|
|
}
|
|
|
|
_, _, _, _ = run(5*time.Second, "nft", "flush", "set", "inet", cfg.Table, cfg.LocalBypassSet)
|
|
|
|
elems := []string{"127.0.0.0/8"}
|
|
for _, dst := range bypassCIDRs {
|
|
val := strings.TrimSpace(dst)
|
|
if val == "" || val == "default" {
|
|
continue
|
|
}
|
|
elems = append(elems, val)
|
|
}
|
|
elems = CompactIPv4IntervalElements(elems)
|
|
|
|
for _, e := range elems {
|
|
_, out, code, err := run(
|
|
5*time.Second,
|
|
"nft", "add", "element", "inet", cfg.Table, cfg.LocalBypassSet,
|
|
"{", e, "}",
|
|
)
|
|
if err != nil || code != 0 {
|
|
if err == nil {
|
|
err = fmt.Errorf("nft add element exited with %d", code)
|
|
}
|
|
return fmt.Errorf("failed to update %s: %w (%s)", cfg.LocalBypassSet, err, strings.TrimSpace(out))
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func InsertAppMarkRule(cfg NFTConfig, target string, rel string, level int, id uint64, vpnIface string, bypassCIDRs []string, run RunCommandFunc) error {
|
|
if run == nil {
|
|
return fmt.Errorf("run command func is nil")
|
|
}
|
|
|
|
target = strings.ToLower(strings.TrimSpace(target))
|
|
mark := cfg.MarkDirect
|
|
if target == "vpn" {
|
|
mark = cfg.MarkApp
|
|
}
|
|
|
|
comment := AppMarkComment(cfg.MarkCommentPrefix, target, id)
|
|
pathLit := fmt.Sprintf("\"%s\"", rel)
|
|
commentLit := fmt.Sprintf("\"%s\"", comment)
|
|
|
|
if target == "vpn" && cfg.GuardEnabled {
|
|
iface := strings.TrimSpace(vpnIface)
|
|
if iface == "" {
|
|
return fmt.Errorf("vpn interface required for app guard")
|
|
}
|
|
if err := UpdateLocalBypassSet(cfg, iface, bypassCIDRs, run); err != nil {
|
|
return err
|
|
}
|
|
|
|
guardComment := AppGuardComment(cfg.GuardCommentPrefix, target, id)
|
|
guardCommentLit := fmt.Sprintf("\"%s\"", guardComment)
|
|
|
|
_, out, code, err := run(
|
|
5*time.Second,
|
|
"nft", "insert", "rule", "inet", cfg.Table, cfg.GuardChain,
|
|
"socket", "cgroupv2", "level", strconv.Itoa(level), pathLit,
|
|
"meta", "mark", cfg.MarkApp,
|
|
"oifname", "!=", iface,
|
|
"ip", "daddr", "!=", "@"+cfg.LocalBypassSet,
|
|
"drop",
|
|
"comment", guardCommentLit,
|
|
)
|
|
if err != nil || code != 0 {
|
|
if err == nil {
|
|
err = fmt.Errorf("nft insert guard(v4) exited with %d", code)
|
|
}
|
|
return fmt.Errorf("nft insert app guard(v4) failed: %w (%s)", err, strings.TrimSpace(out))
|
|
}
|
|
|
|
_, out, code, err = run(
|
|
5*time.Second,
|
|
"nft", "insert", "rule", "inet", cfg.Table, cfg.GuardChain,
|
|
"socket", "cgroupv2", "level", strconv.Itoa(level), pathLit,
|
|
"meta", "mark", cfg.MarkApp,
|
|
"oifname", "!=", iface,
|
|
"meta", "nfproto", "ipv6",
|
|
"drop",
|
|
"comment", guardCommentLit,
|
|
)
|
|
if err != nil || code != 0 {
|
|
if err == nil {
|
|
err = fmt.Errorf("nft insert guard(v6) exited with %d", code)
|
|
}
|
|
return fmt.Errorf("nft insert app guard(v6) failed: %w (%s)", err, strings.TrimSpace(out))
|
|
}
|
|
}
|
|
|
|
_, out, code, err := run(
|
|
5*time.Second,
|
|
"nft", "insert", "rule", "inet", cfg.Table, cfg.Chain,
|
|
"socket", "cgroupv2", "level", strconv.Itoa(level), pathLit,
|
|
"meta", "mark", "set", mark,
|
|
"accept",
|
|
"comment", commentLit,
|
|
)
|
|
if err != nil || code != 0 {
|
|
if err == nil {
|
|
err = fmt.Errorf("nft insert rule exited with %d", code)
|
|
}
|
|
_ = DeleteAppMarkRule(cfg, target, id, run)
|
|
return fmt.Errorf("nft insert appmark rule failed: %w (%s)", err, strings.TrimSpace(out))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func DeleteAppMarkRule(cfg NFTConfig, target string, id uint64, run RunCommandFunc) error {
|
|
if run == nil {
|
|
return fmt.Errorf("run command func is nil")
|
|
}
|
|
comments := []string{
|
|
AppMarkComment(cfg.MarkCommentPrefix, target, id),
|
|
AppGuardComment(cfg.GuardCommentPrefix, target, id),
|
|
}
|
|
chains := []string{cfg.Chain, cfg.GuardChain}
|
|
for _, chain := range chains {
|
|
if strings.TrimSpace(chain) == "" {
|
|
continue
|
|
}
|
|
out, _, _, _ := run(5*time.Second, "nft", "-a", "list", "chain", "inet", cfg.Table, chain)
|
|
for _, line := range strings.Split(out, "\n") {
|
|
match := false
|
|
for _, comment := range comments {
|
|
if strings.Contains(line, comment) {
|
|
match = true
|
|
break
|
|
}
|
|
}
|
|
if !match {
|
|
continue
|
|
}
|
|
h := ParseNftHandle(line)
|
|
if h <= 0 {
|
|
continue
|
|
}
|
|
_, _, _, _ = run(5*time.Second, "nft", "delete", "rule", "inet", cfg.Table, chain, "handle", strconv.Itoa(h))
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func HasAppMarkRule(cfg NFTConfig, target string, id uint64, run RunCommandFunc) bool {
|
|
if run == nil {
|
|
return false
|
|
}
|
|
markComment := AppMarkComment(cfg.MarkCommentPrefix, target, id)
|
|
guardComment := AppGuardComment(cfg.GuardCommentPrefix, target, id)
|
|
|
|
hasMark := false
|
|
out, _, _, _ := run(5*time.Second, "nft", "-a", "list", "chain", "inet", cfg.Table, cfg.Chain)
|
|
for _, line := range strings.Split(out, "\n") {
|
|
if strings.Contains(line, markComment) {
|
|
hasMark = true
|
|
break
|
|
}
|
|
}
|
|
if !hasMark {
|
|
return false
|
|
}
|
|
if strings.EqualFold(strings.TrimSpace(target), "vpn") {
|
|
if !cfg.GuardEnabled {
|
|
return true
|
|
}
|
|
out, _, _, _ = run(5*time.Second, "nft", "-a", "list", "chain", "inet", cfg.Table, cfg.GuardChain)
|
|
for _, line := range strings.Split(out, "\n") {
|
|
if strings.Contains(line, guardComment) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
func CleanupLegacyRules(cfg NFTConfig, run RunCommandFunc) error {
|
|
if run == nil {
|
|
return fmt.Errorf("run command func is nil")
|
|
}
|
|
out, _, _, _ := run(5*time.Second, "nft", "-a", "list", "chain", "inet", cfg.Table, cfg.Chain)
|
|
for _, line := range strings.Split(out, "\n") {
|
|
l := strings.ToLower(line)
|
|
if !strings.Contains(l, "meta cgroup") {
|
|
continue
|
|
}
|
|
if !strings.Contains(l, "svpn_cg_") {
|
|
continue
|
|
}
|
|
h := ParseNftHandle(line)
|
|
if h <= 0 {
|
|
continue
|
|
}
|
|
_, _, _, _ = run(5*time.Second, "nft", "delete", "rule", "inet", cfg.Table, cfg.Chain, "handle", strconv.Itoa(h))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func ClearManagedRules(cfg NFTConfig, chain string, run RunCommandFunc) {
|
|
if run == nil {
|
|
return
|
|
}
|
|
out, _, _, _ := run(5*time.Second, "nft", "-a", "list", "chain", "inet", cfg.Table, chain)
|
|
for _, line := range strings.Split(out, "\n") {
|
|
l := strings.ToLower(line)
|
|
if !strings.Contains(l, strings.ToLower(cfg.MarkCommentPrefix)) &&
|
|
!strings.Contains(l, strings.ToLower(cfg.GuardCommentPrefix)) {
|
|
continue
|
|
}
|
|
h := ParseNftHandle(line)
|
|
if h <= 0 {
|
|
continue
|
|
}
|
|
_, _, _, _ = run(5*time.Second, "nft", "delete", "rule", "inet", cfg.Table, chain, "handle", strconv.Itoa(h))
|
|
}
|
|
}
|
|
|
|
// ParseNftHandle extracts the object handle from one line of `nft -a list`
// output, where nft appends "# handle N" to each listed object. It returns
// the integer following the last "handle" token whose next field parses as an
// integer, or 0 if there is none.
//
// The scan runs from the end of the line because the annotation nft appends
// comes last. Earlier tokens, such as words inside a rule's quoted comment,
// may also read "handle".
func ParseNftHandle(line string) int {
	fields := strings.Fields(line)
	for i := len(fields) - 2; i >= 0; i-- {
		if fields[i] != "handle" {
			continue
		}
		if n, err := strconv.Atoi(fields[i+1]); err == nil {
			return n
		}
	}
	return 0
}
|
|
|
|
// CompactIPv4IntervalElements turns raw, a list of IPv4 addresses and CIDR
// prefixes, into a reduced list of prefixes for an nft interval set.
//
// Parsing rules:
//   - entries are trimmed; empty, unparsable and non-IPv4 entries are dropped;
//   - a bare address becomes a /32;
//   - a prefix is masked to its network address.
//
// The prefixes are ordered from broadest to narrowest (ties by address). Any
// prefix already covered by an earlier kept prefix is discarded; this also
// removes duplicates. The result is never nil.
func CompactIPv4IntervalElements(raw []string) []string {
	parsed := make([]netip.Prefix, 0, len(raw))
	for _, item := range raw {
		text := strings.TrimSpace(item)
		var pfx netip.Prefix
		switch {
		case text == "":
			continue
		case strings.Contains(text, "/"):
			p, err := netip.ParsePrefix(text)
			if err != nil || !p.Addr().Is4() {
				continue
			}
			pfx = p.Masked()
		default:
			addr, err := netip.ParseAddr(text)
			if err != nil || !addr.Is4() {
				continue
			}
			pfx = netip.PrefixFrom(addr, 32)
		}
		parsed = append(parsed, pfx)
	}

	sort.Slice(parsed, func(a, b int) bool {
		pa, pb := parsed[a], parsed[b]
		if pa.Bits() == pb.Bits() {
			return pa.Addr().Less(pb.Addr())
		}
		return pa.Bits() < pb.Bits()
	})

	kept := make([]netip.Prefix, 0, len(parsed))
	result := make([]string, 0, len(parsed))
next:
	for _, cand := range parsed {
		// Every kept prefix is at least as broad as cand, so containing
		// cand's network address means containing all of cand.
		for _, k := range kept {
			if k.Contains(cand.Addr()) {
				continue next
			}
		}
		kept = append(kept, cand)
		result = append(result, cand.String())
	}
	return result
}
|