271 lines
7.0 KiB
Go
271 lines
7.0 KiB
Go
package app
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/netip"
|
|
"os"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// transportPolicyKernelEnvEnable names the environment variable that opts in to
// the kernel stage of transport policy apply; when it is not truthy,
// applyTransportPolicyKernelStage does nothing.
const transportPolicyKernelEnvEnable = "SVPN_TRANSPORT_POLICY_KERNEL_APPLY"

// transportPolicyKernelEnvIPRules names the environment variable that
// additionally enables installing ip rules during the kernel stage.
const transportPolicyKernelEnvIPRules = "SVPN_TRANSPORT_POLICY_KERNEL_IPRULES"
|
|
|
|
var (
|
|
transportPolicyKernelRunCommand = runCommandTimeout
|
|
transportPolicyKernelUpdateSet = func(ctx context.Context, setName string, ips []string) error {
|
|
return nftUpdateSetIPsSmart(ctx, setName, ips, nil)
|
|
}
|
|
)
|
|
|
|
// applyTransportPolicyKernelStage pushes the staged transport policy into the
// kernel. It is a no-op unless SVPN_TRANSPORT_POLICY_KERNEL_APPLY is truthy.
//
// Stages run in order:
//  1. nft sets (current -> staged); an error aborts and is returned.
//  2. ip rules, only when SVPN_TRANSPORT_POLICY_KERNEL_IPRULES is truthy;
//     an error aborts and is returned.
//  3. conntrack-based owner-lock refresh, only when conntrack stickiness is
//     enabled; a failure here is only traced (rate-limited) and never returned.
func applyTransportPolicyKernelStage(current, staged transportPolicyRuntimeState) error {
	if !transportPolicyKernelApplyEnabled() {
		return nil
	}

	if nftErr := applyTransportPolicyKernelNftSets(current, staged); nftErr != nil {
		return nftErr
	}

	withIPRules := transportPolicyKernelIPRulesEnabled()
	if withIPRules {
		if ruleErr := applyTransportPolicyKernelIPRules(staged); ruleErr != nil {
			return ruleErr
		}
	}

	if !transportPolicyKernelConntrackStickyEnabled() {
		return nil
	}
	// Best effort: the sticky refresh must not fail the whole apply.
	if refreshErr := refreshTransportPolicyOwnerLocksFromConntrack(staged); refreshErr != nil {
		warning := fmt.Sprintf("policy conntrack sticky refresh warning: %v", refreshErr)
		appendTraceLineRateLimited("transport", warning, 5*time.Second)
	}
	return nil
}
|
|
|
|
// transportPolicyKernelApplyEnabled reports whether the kernel stage is turned
// on. The environment variable's value is whitespace-trimmed and compared
// case-insensitively against "1", "true", "yes" and "on"; anything else,
// including an unset variable, means disabled.
func transportPolicyKernelApplyEnabled() bool {
	v := strings.ToLower(strings.TrimSpace(os.Getenv(transportPolicyKernelEnvEnable)))
	return v == "1" || v == "true" || v == "yes" || v == "on"
}
|
|
|
|
// transportPolicyKernelIPRulesEnabled reports whether the ip-rule part of the
// kernel stage is turned on. The environment variable's value is
// whitespace-trimmed and compared case-insensitively against "1", "true", "yes"
// and "on"; anything else, including an unset variable, means disabled.
func transportPolicyKernelIPRulesEnabled() bool {
	v := strings.ToLower(strings.TrimSpace(os.Getenv(transportPolicyKernelEnvIPRules)))
	return v == "1" || v == "true" || v == "yes" || v == "on"
}
|
|
|
|
// applyTransportPolicyKernelNftSets reconciles the IPv4 interval sets in the nft
// table "inet agvpn" from the current runtime state to the staged one.
//
// A set is desired when a staged "cidr" rule targets it, or when a staged
// interface declares it as a "cidr" set. Each desired set is created if
// missing. It is then either filled with its elements through
// transportPolicyKernelUpdateSet, or flushed when it has no usable elements.
// Sets the current state declared that are no longer desired are flushed and
// deleted on a best-effort basis.
//
// It returns an error when flushing or updating a desired set fails.
func applyTransportPolicyKernelNftSets(current, staged transportPolicyRuntimeState) error {
	desired := buildTransportPolicyCIDRSetElements(staged.Interfaces)
	currentSets := collectTransportPolicyCIDRSetNames(current.Interfaces)

	// Also keep staged sets that are declared but received no usable CIDR
	// element. buildTransportPolicyCIDRSetElements only creates map entries for
	// sets that got at least one element. Without this union, such a set would
	// be treated as stale and deleted instead of flushed, and the empty-set
	// branch below could never run.
	desiredSets := collectMapKeys(desired)
	for _, name := range collectTransportPolicyCIDRSetNames(staged.Interfaces) {
		if _, ok := desired[name]; !ok {
			desiredSets = append(desiredSets, name)
		}
	}
	sort.Strings(desiredSets)
	stale := diffStringSet(currentSets, desiredSets)

	// "nft add" (unlike "nft create") does not fail when the table or set
	// already exists, so these creates are idempotent. Their errors are
	// ignored; a real failure surfaces from the flush/update calls below.
	_, _, _, _ = transportPolicyKernelRunCommand(5*time.Second, "nft", "add", "table", "inet", "agvpn")

	for _, setName := range desiredSets {
		if strings.TrimSpace(setName) == "" {
			continue
		}
		_, _, _, _ = transportPolicyKernelRunCommand(
			5*time.Second,
			"nft", "add", "set", "inet", "agvpn", setName,
			"{", "type", "ipv4_addr", ";", "flags", "interval", ";", "}",
		)

		ips := desired[setName]
		if len(ips) == 0 {
			_, stderr, code, err := transportPolicyKernelRunCommand(10*time.Second, "nft", "flush", "set", "inet", "agvpn", setName)
			// A non-zero exit code counts as a failure even when err is nil,
			// matching how the ip-rule stage checks command results.
			if (err != nil || code != 0) && !strings.Contains(strings.ToLower(stderr), "no such file") {
				return fmt.Errorf("nft flush %s failed: %v (code=%d stderr=%s)", setName, err, code, strings.TrimSpace(stderr))
			}
			continue
		}

		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		err := transportPolicyKernelUpdateSet(ctx, setName, ips)
		cancel()
		if err != nil {
			return fmt.Errorf("nft update set %s failed: %w", setName, err)
		}
	}

	// Best-effort cleanup; errors are ignored. Flushing before deleting means
	// the set ends up empty even if nft refuses the delete (for example,
	// presumably, while a rule still references it).
	for _, name := range stale {
		if strings.TrimSpace(name) == "" {
			continue
		}
		_, _, _, _ = transportPolicyKernelRunCommand(5*time.Second, "nft", "flush", "set", "inet", "agvpn", name)
		_, _, _, _ = transportPolicyKernelRunCommand(5*time.Second, "nft", "delete", "set", "inet", "agvpn", name)
	}

	appendTraceLineRateLimited(
		"transport",
		fmt.Sprintf("policy kernel nft stage: sets=%d stale=%d", len(desiredSets), len(stale)),
		5*time.Second,
	)
	return nil
}
|
|
|
|
// applyTransportPolicyKernelIPRules installs one IPv4 policy rule per distinct
// (preference, fwmark, routing table) triple found in the staged state:
//
//	ip -4 rule add pref <PriorityBase> fwmark <MarkHex> lookup <RoutingTable>
//
// Some interfaces and rules are skipped:
//   - interfaces whose normalized ID is the default transport interface;
//   - interfaces with a blank routing table;
//   - rules whose mark or preference is rejected by parseTransportMarkHex or
//     parseTransportPref.
//
// Rules are applied in ascending preference order. Ties are broken by mark and
// then table, so the order is deterministic. Before the first add at a given
// preference, one existing rule at that preference is deleted on a best-effort
// basis. This happens only once per preference, so rules that share a
// preference do not delete each other. Rules for preferences that no longer
// appear in staged are not removed here.
//
// It returns an error if an "ip rule add" fails. The exception is when ip
// reports that an identical rule already exists: the desired rule is then
// present.
func applyTransportPolicyKernelIPRules(staged transportPolicyRuntimeState) error {
	type ipRule struct {
		pref  int
		mark  string
		table string
	}

	seen := make(map[ipRule]struct{})
	rules := make([]ipRule, 0, 16)
	for _, iface := range staged.Interfaces {
		table := strings.TrimSpace(iface.RoutingTable)
		if normalizeTransportIfaceID(iface.IfaceID) == transportDefaultIfaceID {
			continue
		}
		if table == "" {
			continue
		}
		for _, r := range iface.Rules {
			mark := strings.ToLower(strings.TrimSpace(r.MarkHex))
			if _, ok := parseTransportMarkHex(mark); !ok {
				continue
			}
			if _, ok := parseTransportPref(r.PriorityBase); !ok {
				continue
			}
			rule := ipRule{pref: r.PriorityBase, mark: mark, table: table}
			if _, dup := seen[rule]; dup {
				continue
			}
			seen[rule] = struct{}{}
			rules = append(rules, rule)
		}
	}

	// Total order: sort.Slice is not stable, so without tie-breakers the order
	// of rules sharing a preference would be nondeterministic.
	sort.Slice(rules, func(i, j int) bool {
		a, b := rules[i], rules[j]
		if a.pref != b.pref {
			return a.pref < b.pref
		}
		if a.mark != b.mark {
			return a.mark < b.mark
		}
		return a.table < b.table
	})

	cleared := make(map[int]struct{}, len(rules))
	for _, rule := range rules {
		pref := strconv.Itoa(rule.pref)
		// Delete at most once per preference. Deleting again for a second rule
		// with the same preference would remove the rule just added for the
		// first one.
		if _, done := cleared[rule.pref]; !done {
			cleared[rule.pref] = struct{}{}
			_, _, _, _ = transportPolicyKernelRunCommand(4*time.Second, "ip", "-4", "rule", "del", "pref", pref)
		}

		stdout, stderr, code, err := transportPolicyKernelRunCommand(
			5*time.Second,
			"ip", "-4", "rule", "add",
			"pref", pref,
			"fwmark", rule.mark,
			"lookup", rule.table,
		)
		if err == nil && code == 0 {
			continue
		}
		// An identical rule is already installed. This can happen on re-apply
		// when several rules share a preference, because only one of them is
		// deleted above. The desired state is reached.
		if strings.Contains(strings.ToLower(stderr), "file exists") {
			continue
		}
		return fmt.Errorf("ip rule add pref=%d mark=%s table=%s failed: %v (stdout=%s stderr=%s code=%d)",
			rule.pref, rule.mark, rule.table, err, strings.TrimSpace(stdout), strings.TrimSpace(stderr), code)
	}

	appendTraceLineRateLimited(
		"transport",
		fmt.Sprintf("policy kernel iprule stage: rules=%d", len(rules)),
		5*time.Second,
	)
	return nil
}
|
|
|
|
// buildTransportPolicyCIDRSetElements groups the IPv4 CIDR elements of every
// "cidr" rule (selector type compared case-insensitively after trimming) by
// the rule's target nft set name.
//
// Rules with a blank set name, a selector value that parseIntentCIDR rejects,
// or a prefix that cidrToNftElement cannot express (e.g. non-IPv4) are
// skipped. Each set's elements are de-duplicated and sorted. A set appears in
// the result only if it received at least one element.
func buildTransportPolicyCIDRSetElements(interfaces []TransportPolicyCompileInterface) map[string][]string {
	out := map[string][]string{}
	// Per-set membership index. It keeps de-duplication O(1) per element
	// instead of rescanning the set's slice on every insert.
	seen := map[string]map[string]struct{}{}
	for _, iface := range interfaces {
		for _, rule := range iface.Rules {
			if strings.ToLower(strings.TrimSpace(rule.SelectorType)) != "cidr" {
				continue
			}
			setName := strings.TrimSpace(rule.NftSet)
			if setName == "" {
				continue
			}
			pfx, err := parseIntentCIDR(rule.SelectorValue)
			if err != nil {
				continue
			}
			elem := cidrToNftElement(pfx)
			if elem == "" {
				continue
			}
			members := seen[setName]
			if members == nil {
				members = map[string]struct{}{}
				seen[setName] = members
			}
			if _, dup := members[elem]; dup {
				continue
			}
			members[elem] = struct{}{}
			out[setName] = append(out[setName], elem)
		}
	}
	for _, elems := range out {
		sort.Strings(elems)
	}
	return out
}
|
|
|
|
// cidrToNftElement renders an IPv4 prefix as an nft set element:
//   - a bare address for a /32;
//   - the canonical masked "a.b.c.d/n" form otherwise.
//
// It returns "" for an invalid prefix or any non-IPv4 address, including
// IPv4-mapped IPv6.
func cidrToNftElement(pfx netip.Prefix) string {
	switch {
	case !pfx.IsValid(), !pfx.Addr().Is4():
		return ""
	case pfx.Bits() == 32:
		return pfx.Addr().String()
	default:
		return pfx.Masked().String()
	}
}
|
|
|
|
// collectTransportPolicyCIDRSetNames returns the sorted, de-duplicated names of
// every set declared with selector type "cidr" across the given interfaces.
// The selector type is compared case-insensitively after trimming. Names are
// trimmed, and blank names are ignored. The result is never nil.
func collectTransportPolicyCIDRSetNames(interfaces []TransportPolicyCompileInterface) []string {
	names := make(map[string]bool)
	for i := range interfaces {
		for _, entry := range interfaces[i].Sets {
			kind := strings.ToLower(strings.TrimSpace(entry.SelectorType))
			name := strings.TrimSpace(entry.Name)
			if kind == "cidr" && name != "" {
				names[name] = true
			}
		}
	}

	result := make([]string, 0, len(names))
	for name := range names {
		result = append(result, name)
	}
	sort.Strings(result)
	return result
}
|
|
|
|
// collectMapKeys returns the keys of m in ascending order. For an empty or nil
// map it returns an empty, non-nil slice.
func collectMapKeys(m map[string][]string) []string {
	keys := make([]string, len(m))
	i := 0
	for k := range m {
		keys[i] = k
		i++
	}
	sort.Strings(keys)
	return keys
}
|
|
|
|
// diffStringSet returns the elements of a that do not appear in b, sorted
// ascending. Duplicates in a are kept. It returns nil when a is empty, and an
// empty non-nil slice when every element of a is also in b.
func diffStringSet(a, b []string) []string {
	if len(a) == 0 {
		return nil
	}

	exclude := make(map[string]struct{}, len(b))
	for _, s := range b {
		exclude[s] = struct{}{}
	}

	kept := make([]string, 0, len(a))
	for _, s := range a {
		if _, skip := exclude[s]; !skip {
			kept = append(kept, s)
		}
	}
	sort.Strings(kept)
	return kept
}
|
|
|
|
// appendUniqueString appends the whitespace-trimmed v to in, unless the
// trimmed value is empty or already present (compared against in's elements
// as stored). It returns the possibly-extended slice.
func appendUniqueString(in []string, v string) []string {
	if candidate := strings.TrimSpace(v); candidate != "" {
		for i := range in {
			if in[i] == candidate {
				return in
			}
		}
		in = append(in, candidate)
	}
	return in
}
|