Files
elmprodvpn/selective-vpn-api/app/transport_policy_apply_kernel.go

271 lines
7.0 KiB
Go

package app
import (
"context"
"fmt"
"net/netip"
"os"
"sort"
"strconv"
"strings"
"time"
)
// Environment switches consulted by the transport policy kernel stage.
// A switch is on when its value is a recognised truthy spelling.
const (
	// transportPolicyKernelEnvEnable gates the whole kernel apply stage.
	transportPolicyKernelEnvEnable = "SVPN_TRANSPORT_POLICY_KERNEL_APPLY"
	// transportPolicyKernelEnvIPRules additionally gates the ip-rule sub-stage.
	transportPolicyKernelEnvIPRules = "SVPN_TRANSPORT_POLICY_KERNEL_IPRULES"
)
// Kernel-side hooks used by the transport policy stage. They are package
// variables rather than direct calls, presumably so tests can swap in fakes
// (NOTE(review): confirm against the package tests).
var (
	// transportPolicyKernelRunCommand executes external commands (nft, ip)
	// with a per-call timeout.
	transportPolicyKernelRunCommand = runCommandTimeout

	// transportPolicyKernelUpdateSet loads the given elements into an nft set
	// by delegating to nftUpdateSetIPsSmart with nil as its final argument.
	transportPolicyKernelUpdateSet = func(ctx context.Context, name string, addrs []string) error {
		return nftUpdateSetIPsSmart(ctx, name, addrs, nil)
	}
)
// applyTransportPolicyKernelStage pushes the staged transport policy into the
// kernel. It does nothing unless the kernel apply switch is on.
//
// The stages run in order:
//   - nft CIDR sets;
//   - ip rules, only when the ip-rule switch is also on;
//   - the conntrack sticky refresh, only when enabled.
//
// An error from the nft or ip-rule stage aborts and is returned. The conntrack
// refresh is best-effort: its failure is only traced (rate-limited) and never
// fails the stage.
func applyTransportPolicyKernelStage(current, staged transportPolicyRuntimeState) error {
	if !transportPolicyKernelApplyEnabled() {
		return nil
	}

	err := applyTransportPolicyKernelNftSets(current, staged)
	if err == nil && transportPolicyKernelIPRulesEnabled() {
		err = applyTransportPolicyKernelIPRules(staged)
	}
	if err != nil {
		return err
	}

	if !transportPolicyKernelConntrackStickyEnabled() {
		return nil
	}
	if refreshErr := refreshTransportPolicyOwnerLocksFromConntrack(staged); refreshErr != nil {
		appendTraceLineRateLimited(
			"transport",
			fmt.Sprintf("policy conntrack sticky refresh warning: %v", refreshErr),
			5*time.Second,
		)
	}
	return nil
}
// transportPolicyKernelApplyEnabled reports whether kernel-level application
// of the transport policy is switched on via the environment variable named
// by transportPolicyKernelEnvEnable.
func transportPolicyKernelApplyEnabled() bool {
	return transportPolicyKernelEnvFlag(transportPolicyKernelEnvEnable)
}

// transportPolicyKernelIPRulesEnabled reports whether the ip-rule sub-stage is
// switched on via the environment variable named by
// transportPolicyKernelEnvIPRules.
func transportPolicyKernelIPRulesEnabled() bool {
	return transportPolicyKernelEnvFlag(transportPolicyKernelEnvIPRules)
}

// transportPolicyKernelEnvFlag reports whether the environment variable name
// holds a truthy value. Truthy values are "1", "true", "yes" and "on",
// compared case-insensitively after trimming surrounding whitespace. An unset
// variable or any other value yields false. Both kernel switches share this
// helper so their accepted spellings cannot drift apart.
func transportPolicyKernelEnvFlag(name string) bool {
	switch strings.ToLower(strings.TrimSpace(os.Getenv(name))) {
	case "1", "true", "yes", "on":
		return true
	default:
		return false
	}
}
// applyTransportPolicyKernelNftSets reconciles the nft CIDR sets in table
// "inet agvpn" with the staged policy.
//
// The desired sets are those referenced by staged "cidr" rules plus those
// declared as CIDR sets on staged interfaces. Each desired set is first created
// best-effort (the "nft add" result is ignored). Then:
//   - it is flushed when the staged policy gives it no valid IPv4 element;
//   - otherwise its contents are replaced via transportPolicyKernelUpdateSet.
//
// Sets declared in the current state that are not desired any more are flushed
// and deleted, best-effort.
//
// The first flush or update failure is returned. A flush failure whose stderr
// mentions "no such file" is tolerated.
func applyTransportPolicyKernelNftSets(current, staged transportPolicyRuntimeState) error {
	desired := buildTransportPolicyCIDRSetElements(staged.Interfaces)
	// buildTransportPolicyCIDRSetElements only returns sets with at least one
	// valid IPv4 element. A set the staged policy still declares, but whose
	// selectors are all invalid or non-IPv4, would otherwise never reach the
	// flush branch below. It would keep stale elements, or be deleted as stale.
	for _, name := range collectTransportPolicyCIDRSetNames(staged.Interfaces) {
		if _, ok := desired[name]; !ok {
			desired[name] = nil
		}
	}
	desiredSets := collectMapKeys(desired)
	staleSets := diffStringSet(collectTransportPolicyCIDRSetNames(current.Interfaces), desiredSets)

	// Best-effort table creation; result intentionally ignored.
	_, _, _, _ = transportPolicyKernelRunCommand(5*time.Second, "nft", "add", "table", "inet", "agvpn")
	for _, setName := range desiredSets {
		if strings.TrimSpace(setName) == "" {
			continue
		}
		// Best-effort set creation; result intentionally ignored.
		_, _, _, _ = transportPolicyKernelRunCommand(
			5*time.Second,
			"nft", "add", "set", "inet", "agvpn", setName,
			"{", "type", "ipv4_addr", ";", "flags", "interval", ";", "}",
		)
		ips := desired[setName]
		if len(ips) == 0 {
			_, stderr, code, err := transportPolicyKernelRunCommand(10*time.Second, "nft", "flush", "set", "inet", "agvpn", setName)
			// Treat a non-zero exit like an error, as the ip-rule stage does.
			if err == nil && code != 0 {
				err = fmt.Errorf("exit code %d", code)
			}
			if err != nil && !strings.Contains(strings.ToLower(stderr), "no such file") {
				return fmt.Errorf("nft flush %s failed: %w (%s)", setName, err, strings.TrimSpace(stderr))
			}
			continue
		}
		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		err := transportPolicyKernelUpdateSet(ctx, setName, ips)
		cancel()
		if err != nil {
			return fmt.Errorf("nft update set %s failed: %w", setName, err)
		}
	}
	for _, stale := range staleSets {
		if strings.TrimSpace(stale) == "" {
			continue
		}
		// Best-effort cleanup; results intentionally ignored.
		_, _, _, _ = transportPolicyKernelRunCommand(5*time.Second, "nft", "flush", "set", "inet", "agvpn", stale)
		_, _, _, _ = transportPolicyKernelRunCommand(5*time.Second, "nft", "delete", "set", "inet", "agvpn", stale)
	}
	appendTraceLineRateLimited(
		"transport",
		fmt.Sprintf("policy kernel nft stage: sets=%d stale=%d", len(desiredSets), len(staleSets)),
		5*time.Second,
	)
	return nil
}
// applyTransportPolicyKernelIPRules installs one "ip -4 rule ... fwmark MARK
// lookup TABLE" per distinct (pref, mark, table) triple. Triples come from the
// rules of staged interfaces other than the default one. The following are
// skipped:
//   - interfaces without a routing table;
//   - rules whose mark or pref does not parse.
//
// For each distinct pref, one "ip rule del pref N" is issued (best-effort)
// before that pref's triples are added. Deleting once per pref, rather than
// once per triple, stops a second triple sharing the pref from deleting the
// rule just installed for the first. An add that fails with "File exists"
// means an identical rule is already installed and counts as success. Any
// other add failure is returned.
func applyTransportPolicyKernelIPRules(staged transportPolicyRuntimeState) error {
	type tuple struct {
		pref  int
		mark  string
		table string
	}
	seen := map[string]struct{}{}
	rules := make([]tuple, 0, 16)
	for _, iface := range staged.Interfaces {
		if normalizeTransportIfaceID(iface.IfaceID) == transportDefaultIfaceID {
			continue
		}
		table := strings.TrimSpace(iface.RoutingTable)
		if table == "" {
			continue
		}
		for _, r := range iface.Rules {
			mark := strings.ToLower(strings.TrimSpace(r.MarkHex))
			if _, ok := parseTransportMarkHex(mark); !ok {
				continue
			}
			if _, ok := parseTransportPref(r.PriorityBase); !ok {
				continue
			}
			key := strconv.Itoa(r.PriorityBase) + "|" + mark + "|" + table
			if _, ok := seen[key]; ok {
				continue
			}
			seen[key] = struct{}{}
			rules = append(rules, tuple{pref: r.PriorityBase, mark: mark, table: table})
		}
	}
	// Total order (pref, mark, table) keeps the command sequence deterministic
	// even when several triples share a pref.
	sort.Slice(rules, func(i, j int) bool {
		a, b := rules[i], rules[j]
		if a.pref != b.pref {
			return a.pref < b.pref
		}
		if a.mark != b.mark {
			return a.mark < b.mark
		}
		return a.table < b.table
	})
	cleared := make(map[int]struct{}, len(rules))
	for _, rule := range rules {
		pref := strconv.Itoa(rule.pref)
		if _, done := cleared[rule.pref]; !done {
			cleared[rule.pref] = struct{}{}
			// Best-effort; result ignored (e.g. nothing installed at this pref yet).
			_, _, _, _ = transportPolicyKernelRunCommand(4*time.Second, "ip", "-4", "rule", "del", "pref", pref)
		}
		stdout, stderr, code, err := transportPolicyKernelRunCommand(
			5*time.Second,
			"ip", "-4", "rule", "add",
			"pref", pref,
			"fwmark", rule.mark,
			"lookup", rule.table,
		)
		if err == nil && code == 0 {
			continue
		}
		if strings.Contains(strings.ToLower(stderr), "file exists") {
			// An identical rule is already present: desired state reached.
			continue
		}
		return fmt.Errorf("ip rule add pref=%d mark=%s table=%s failed: %v (stdout=%s stderr=%s code=%d)",
			rule.pref, rule.mark, rule.table, err, strings.TrimSpace(stdout), strings.TrimSpace(stderr), code)
	}
	appendTraceLineRateLimited(
		"transport",
		fmt.Sprintf("policy kernel iprule stage: rules=%d", len(rules)),
		5*time.Second,
	)
	return nil
}
// buildTransportPolicyCIDRSetElements maps nft set names to the sorted,
// de-duplicated IPv4 elements referenced by the "cidr" rules (selector type
// compared case-insensitively) of the given interfaces.
//
// Rules are skipped when they have a blank set name, an unparsable selector,
// or a non-IPv4 prefix. Every returned set therefore has at least one element.
// A /32 prefix yields a bare address; other prefixes are emitted masked.
func buildTransportPolicyCIDRSetElements(interfaces []TransportPolicyCompileInterface) map[string][]string {
	out := map[string][]string{}
	// Keyed by (set, element) so de-duplication stays O(1) per rule instead of
	// a linear scan of the set's elements.
	seen := map[[2]string]struct{}{}
	for _, iface := range interfaces {
		for _, rule := range iface.Rules {
			if strings.ToLower(strings.TrimSpace(rule.SelectorType)) != "cidr" {
				continue
			}
			setName := strings.TrimSpace(rule.NftSet)
			if setName == "" {
				continue
			}
			pfx, err := parseIntentCIDR(rule.SelectorValue)
			if err != nil {
				continue
			}
			elem := cidrToNftElement(pfx)
			if elem == "" {
				continue
			}
			key := [2]string{setName, elem}
			if _, dup := seen[key]; dup {
				continue
			}
			seen[key] = struct{}{}
			out[setName] = append(out[setName], elem)
		}
	}
	for _, elems := range out {
		sort.Strings(elems)
	}
	return out
}
// cidrToNftElement renders an IPv4 prefix as an nft set element. A /32
// becomes the bare address; any other length becomes the masked network in
// CIDR notation. Invalid or non-IPv4 prefixes yield "".
func cidrToNftElement(pfx netip.Prefix) string {
	switch {
	case !pfx.IsValid(), !pfx.Addr().Is4():
		return ""
	case pfx.Bits() == 32:
		return pfx.Addr().String()
	default:
		return pfx.Masked().String()
	}
}
// collectTransportPolicyCIDRSetNames returns the sorted, de-duplicated names
// of every set on the given interfaces whose selector type is "cidr" (trimmed,
// case-insensitive). Blank names are ignored. The result is never nil.
func collectTransportPolicyCIDRSetNames(interfaces []TransportPolicyCompileInterface) []string {
	unique := map[string]struct{}{}
	for i := range interfaces {
		for _, set := range interfaces[i].Sets {
			isCIDR := strings.ToLower(strings.TrimSpace(set.SelectorType)) == "cidr"
			name := strings.TrimSpace(set.Name)
			if isCIDR && name != "" {
				unique[name] = struct{}{}
			}
		}
	}
	names := make([]string, 0, len(unique))
	for name := range unique {
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}
// collectMapKeys returns the keys of m in ascending order. The result is a
// non-nil slice even when m is empty or nil.
func collectMapKeys(m map[string][]string) []string {
	keys := make([]string, len(m))
	i := 0
	for key := range m {
		keys[i] = key
		i++
	}
	sort.Strings(keys)
	return keys
}
// diffStringSet returns the elements of a that do not appear in b, sorted
// ascending. Duplicates in a are kept. It returns nil when a is empty.
// Otherwise the result is non-nil, possibly empty.
func diffStringSet(a, b []string) []string {
	if len(a) == 0 {
		return nil
	}
	excluded := make(map[string]bool, len(b))
	for _, s := range b {
		excluded[s] = true
	}
	result := make([]string, 0, len(a))
	for _, s := range a {
		if !excluded[s] {
			result = append(result, s)
		}
	}
	sort.Strings(result)
	return result
}
// appendUniqueString appends the whitespace-trimmed v to in unless it is
// empty after trimming or already present. The membership check compares
// against the existing elements as-is. The input slice is returned unchanged
// when nothing is appended.
func appendUniqueString(in []string, v string) []string {
	trimmed := strings.TrimSpace(v)
	skip := trimmed == ""
	for i := 0; !skip && i < len(in); i++ {
		skip = in[i] == trimmed
	}
	if skip {
		return in
	}
	return append(in, trimmed)
}