Files
elmprodvpn/selective-vpn-api/app/transport_policy_apply_kernel_conntrack.go

198 lines
5.2 KiB
Go

package app
import (
"fmt"
"net/netip"
"os"
"sort"
"strconv"
"strings"
"time"
)
// transportPolicyKernelEnvConntrackSticky names the environment variable that
// toggles conntrack-based sticky owner locks.
const transportPolicyKernelEnvConntrackSticky = "SVPN_TRANSPORT_POLICY_CONNTRACK_STICKY"

// Kernel-facing hooks. They are declared as package-level variables rather than
// functions, so they can be reassigned (for example, stubbed out).
var (
	// transportPolicyKernelConntrackOutput runs `conntrack -L -f ipv4` with the
	// given timeout and returns its stdout.
	//
	// A failure to run the command is wrapped with %w so callers can inspect the
	// cause. A non-zero exit code is reported separately: in that case there is
	// no underlying error value to wrap. Both messages include the trimmed stderr.
	transportPolicyKernelConntrackOutput = func(timeout time.Duration) (string, error) {
		stdout, stderr, code, err := transportPolicyKernelRunCommand(timeout, "conntrack", "-L", "-f", "ipv4")
		if err != nil {
			return "", fmt.Errorf("conntrack list failed: %w (stderr=%s code=%d)", err, strings.TrimSpace(stderr), code)
		}
		if code != 0 {
			return "", fmt.Errorf("conntrack list failed: exit code %d (stderr=%s)", code, strings.TrimSpace(stderr))
		}
		return stdout, nil
	}
	// transportPolicyKernelSaveOwnerLocksState persists an owner-lock snapshot.
	// It defaults to saveTransportOwnerLocksState.
	transportPolicyKernelSaveOwnerLocksState = saveTransportOwnerLocksState
)
// transportPolicyMarkOwner identifies the policy client that owns a packet
// mark. Values are built by collectTransportPolicyMarkOwners from the compiled
// policy rules.
type transportPolicyMarkOwner struct {
// ClientID is the rule's client identifier, whitespace-trimmed and non-empty.
ClientID string
// ClientKind is the rule's client kind, whitespace-trimmed (may be empty).
ClientKind string
// IfaceID is the owning interface ID after normalizeTransportIfaceID.
IfaceID string
// MarkHex is the rule's mark string, trimmed and lower-cased.
MarkHex string
}
// transportPolicyKernelConntrackStickyEnabled reports whether conntrack-based
// sticky owner locks are switched on.
//
// It reads the environment variable named by
// transportPolicyKernelEnvConntrackSticky, trims surrounding whitespace and
// lower-cases it. Only "1", "true", "yes" and "on" enable the feature. Any other
// value, including an unset variable, leaves it disabled.
func transportPolicyKernelConntrackStickyEnabled() bool {
	normalized := strings.ToLower(strings.TrimSpace(os.Getenv(transportPolicyKernelEnvConntrackSticky)))
	for _, truthy := range [...]string{"1", "true", "yes", "on"} {
		if normalized == truthy {
			return true
		}
	}
	return false
}
// refreshTransportPolicyOwnerLocksFromConntrack rebuilds the saved owner-lock
// state for the staged policy from the live IPv4 conntrack table.
//
// If no staged rule yields a usable mark-to-client mapping, the function saves
// an empty lock state for staged.PolicyRevision and does not query conntrack.
//
// Otherwise it lists conntrack with an 8s timeout and converts flows whose mark
// belongs to a policy client into locks. It then saves the result and emits a
// trace line through appendTraceLineRateLimited with a 5s window.
//
// It returns the conntrack listing error unchanged. If saving fails on either
// path, it returns a wrapped "save owner locks failed" error.
func refreshTransportPolicyOwnerLocksFromConntrack(staged transportPolicyRuntimeState) error {
	const conntrackListTimeout = 8 * time.Second

	markOwner := collectTransportPolicyMarkOwners(staged.Interfaces)
	if len(markOwner) == 0 {
		// Nothing can be locked, so record an empty state for this revision.
		// NOTE(review): assumes the saver replaces the previous state.
		// If this save fails, earlier locks stay on disk. The error is therefore
		// surfaced the same way as on the non-empty path instead of discarded.
		empty := TransportOwnerLockState{
			Version:        transportStateVersion,
			UpdatedAt:      time.Now().UTC().Format(time.RFC3339),
			PolicyRevision: staged.PolicyRevision,
		}
		if err := transportPolicyKernelSaveOwnerLocksState(empty); err != nil {
			return fmt.Errorf("save owner locks failed: %w", err)
		}
		return nil
	}
	output, err := transportPolicyKernelConntrackOutput(conntrackListTimeout)
	if err != nil {
		return err
	}
	locks := parseTransportOwnerLocksFromConntrack(output, markOwner, staged.PolicyRevision)
	if err := transportPolicyKernelSaveOwnerLocksState(locks); err != nil {
		return fmt.Errorf("save owner locks failed: %w", err)
	}
	appendTraceLineRateLimited(
		"transport",
		fmt.Sprintf("policy conntrack sticky refresh: locks=%d revision=%d", len(locks.Items), staged.PolicyRevision),
		5*time.Second,
	)
	return nil
}
// collectTransportPolicyMarkOwners maps each packet mark used by the compiled
// policy rules to the client that owns it.
//
// A rule is skipped when any of the following holds:
//   - its ClientID is blank;
//   - its MarkHex does not parse;
//   - the mark does not fit in 32 bits;
//   - the mark is 0.
//
// Mark 0 is excluded because conntrack reports mark=0 for every connection
// whose ctmark was never set. An owner keyed on 0 would capture all unmarked
// flows.
//
// When several rules share a mark, the first one in interface/rule order wins.
func collectTransportPolicyMarkOwners(interfaces []TransportPolicyCompileInterface) map[uint32]transportPolicyMarkOwner {
	out := map[uint32]transportPolicyMarkOwner{}
	for _, iface := range interfaces {
		ifaceID := normalizeTransportIfaceID(iface.IfaceID)
		for _, rule := range iface.Rules {
			clientID := strings.TrimSpace(rule.ClientID)
			if clientID == "" {
				continue
			}
			markHex := strings.ToLower(strings.TrimSpace(rule.MarkHex))
			markRaw, ok := parseTransportMarkHex(markHex)
			if !ok || markRaw == 0 || markRaw > uint64(^uint32(0)) {
				continue
			}
			mark := uint32(markRaw)
			if _, exists := out[mark]; exists {
				continue
			}
			out[mark] = transportPolicyMarkOwner{
				ClientID:   clientID,
				ClientKind: strings.TrimSpace(rule.ClientKind),
				IfaceID:    ifaceID,
				MarkHex:    markHex,
			}
		}
	}
	return out
}
// parseTransportOwnerLocksFromConntrack turns raw `conntrack -L` output into an
// owner-lock snapshot for policyRevision.
//
// Lines are processed as follows:
//   - Each non-blank line goes through parseTransportConntrackLockLine.
//   - Unparsable lines are skipped.
//   - Lines whose mark is not a key of markOwner are skipped.
//   - Only the first qualifying line per destination IP produces a record.
//
// Records are ordered by DestinationIP (plain string comparison). Every record
// shares the snapshot's UTC RFC3339 timestamp.
func parseTransportOwnerLocksFromConntrack(raw string, markOwner map[uint32]transportPolicyMarkOwner, policyRevision int64) TransportOwnerLockState {
	stamp := time.Now().UTC().Format(time.RFC3339)
	recordByDst := make(map[string]TransportOwnerLockRecord)

	for _, rawLine := range strings.Split(raw, "\n") {
		trimmed := strings.TrimSpace(rawLine)
		if trimmed == "" {
			continue
		}
		addr, mark, proto, parsed := parseTransportConntrackLockLine(trimmed)
		if !parsed {
			continue
		}
		owner, known := markOwner[mark]
		if !known {
			continue
		}
		dstKey := addr.String()
		if _, taken := recordByDst[dstKey]; taken {
			// The first flow seen for a destination decides its owner.
			continue
		}
		recordByDst[dstKey] = TransportOwnerLockRecord{
			DestinationIP: dstKey,
			ClientID:      owner.ClientID,
			ClientKind:    owner.ClientKind,
			IfaceID:       owner.IfaceID,
			MarkHex:       owner.MarkHex,
			Proto:         proto,
			UpdatedAt:     stamp,
		}
	}

	records := make([]TransportOwnerLockRecord, 0, len(recordByDst))
	for _, rec := range recordByDst {
		records = append(records, rec)
	}
	sort.Slice(records, func(a, b int) bool {
		left, right := records[a], records[b]
		if left.DestinationIP == right.DestinationIP {
			return left.ClientID < right.ClientID
		}
		return left.DestinationIP < right.DestinationIP
	})

	return TransportOwnerLockState{
		Version:        transportStateVersion,
		UpdatedAt:      stamp,
		PolicyRevision: policyRevision,
		Count:          len(records),
		Items:          records,
	}
}
// parseTransportConntrackLockLine extracts the original-direction destination,
// the connection mark and the protocol name from one `conntrack -L` line.
//
// A conntrack entry lists the original tuple before the reply tuple, and each
// tuple carries its own dst= field. Only the first dst= is the address the
// client connected to. The reply tuple's dst= is the original source (or its
// NAT address), so it must never serve as a fallback.
//
// For that reason the first dst= and the first mark= token are decisive. If
// either one is malformed, or dst= is not an IPv4 address, the line is rejected.
//
// proto is the lower-cased first field (e.g. "tcp"). ok is false when the line
// is empty or lacks a valid dst= or mark=.
func parseTransportConntrackLockLine(line string) (dst netip.Addr, mark uint32, proto string, ok bool) {
	tokens := strings.Fields(line)
	if len(tokens) == 0 {
		return netip.Addr{}, 0, "", false
	}
	proto = strings.ToLower(tokens[0])
	var sawDst, sawMark bool
	for _, tok := range tokens {
		switch {
		case !sawDst && strings.HasPrefix(tok, "dst="):
			sawDst = true
			addr, err := netip.ParseAddr(strings.TrimPrefix(tok, "dst="))
			if err != nil || !addr.Is4() {
				return netip.Addr{}, 0, "", false
			}
			dst = addr
		case !sawMark && strings.HasPrefix(tok, "mark="):
			sawMark = true
			parsed, valid := parseTransportConntrackMark(strings.TrimPrefix(tok, "mark="))
			if !valid {
				return netip.Addr{}, 0, "", false
			}
			mark = parsed
		}
	}
	if !sawDst || !sawMark {
		return netip.Addr{}, 0, "", false
	}
	return dst, mark, proto, true
}
// parseTransportConntrackMark parses a conntrack mark value as an unsigned
// 32-bit integer.
//
// Surrounding whitespace is ignored. A "0x" or "0X" prefix selects
// hexadecimal; otherwise the value is read as decimal. It returns (0, false)
// when the value is empty, malformed, signed, or exceeds 32 bits.
func parseTransportConntrackMark(raw string) (uint32, bool) {
	text := strings.TrimSpace(raw)
	if len(text) == 0 {
		return 0, false
	}
	digits, radix := text, 10
	if len(text) >= 2 && text[0] == '0' && (text[1] == 'x' || text[1] == 'X') {
		digits, radix = text[2:], 16
	}
	value, err := strconv.ParseUint(digits, radix, 32)
	if err != nil {
		return 0, false
	}
	return uint32(value), true
}