Files
elmprodvpn/selective-vpn-api/app/transport_policy_compile.go

304 lines
9.2 KiB
Go

package app
import (
"crypto/sha1"
"encoding/hex"
"fmt"
"sort"
"strconv"
"strings"
"time"
)
// transportPolicyCompileIfaceBucket accumulates the compile output for one
// iface_id while compileTransportPolicyPlan walks the policy intents.
type transportPolicyCompileIfaceBucket struct {
// Iface is the interface entry that is appended to the plan once its
// client/mark/pref lists, rules and sets have been sorted.
Iface TransportPolicyCompileInterface
// sets indexes nft sets by "<ownerScope>|<selectorType>|<setName>" so each
// set is created once and its RuleCount can be bumped in place; it is
// flattened into Iface.Sets at the end of compilation.
sets map[string]*TransportPolicyCompileSet
}
// compileTransportPolicyPlan compiles policy intents into a per-interface plan.
//
// Each intent is matched to its client by ClientID. Intents whose client is
// not in clients produce an "unknown_client" conflict and are skipped. For the
// remaining intents the client's interface binding is resolved against the
// stored interfaces state, and the intent is appended as a rule to the bucket
// for that iface_id. The rule is also counted against the nft set it belongs
// to (one set per owner scope and selector type).
//
// Allocator resources are checked while walking the intents:
//   - routing tables (non-default interfaces only), mark_hex and
//     priority_base must not be claimed by two different iface_id values
//     ("allocator_collision");
//   - mark_hex and priority_base must parse ("allocator_invalid").
//
// These conflicts are reported but do not remove the rule from the plan.
//
// Output ordering is deterministic for a given input:
//   - interfaces are sorted by iface_id;
//   - client IDs, marks and priority bases are sorted;
//   - rules are sorted by selector type, selector value and client ID, and
//     ties keep intent order;
//   - sets are sorted by selector type, owner scope and name.
//
// The returned conflicts are passed through dedupeTransportConflicts.
// policyRevision is copied into the plan unchanged. GeneratedAt is the current
// UTC time in RFC 3339 format.
func compileTransportPolicyPlan(intents []TransportPolicyIntent, clients []TransportClient, policyRevision int64) (TransportPolicyCompilePlan, []TransportConflictRecord) {
	clientByID := make(map[string]TransportClient, len(clients))
	for _, it := range clients {
		clientByID[it.ID] = it
	}
	// NOTE(review): the second result of normalizeTransportInterfacesState is
	// discarded here; confirm it is not an error that should be surfaced.
	ifaces, _ := normalizeTransportInterfacesState(loadTransportInterfacesState(), clients)
	plan := TransportPolicyCompilePlan{
		GeneratedAt:    time.Now().UTC().Format(time.RFC3339),
		PolicyRevision: policyRevision,
	}
	buckets := map[string]*transportPolicyCompileIfaceBucket{}
	// Allocator ownership: resource value -> iface_id that first claimed it.
	markOwner := map[string]string{}
	prefOwner := map[int]string{}
	tableOwner := map[string]string{}
	conflicts := make([]TransportConflictRecord, 0)
	for _, it := range intents {
		client, ok := clientByID[it.ClientID]
		if !ok {
			conflicts = append(conflicts, TransportConflictRecord{
				Key:                 "compile:client:" + it.ClientID,
				Type:                "unknown_client",
				Severity:            "block",
				Owners:              []string{it.ClientID},
				Reason:              "client not found during compile",
				SuggestedResolution: "refresh clients and re-validate policy",
			})
			continue
		}
		binding := resolveTransportIfaceBinding(client, ifaces)
		ifaceID := normalizeTransportIfaceID(binding.IfaceID)
		routingTable := normalizeTransportRoutingTable(binding.RoutingTable, transportRoutingTableForID(client.ID))

		// Routing tables are only tracked for non-default interfaces. Several
		// clients bound to the same iface_id may share one table.
		if ifaceID != transportDefaultIfaceID && strings.TrimSpace(routingTable) != "" {
			if owner, exists := tableOwner[routingTable]; exists && owner != ifaceID {
				conflicts = append(conflicts, TransportConflictRecord{
					Key:                 "allocator:table:" + routingTable,
					Type:                "allocator_collision",
					Severity:            "block",
					Owners:              []string{owner, ifaceID},
					Reason:              "routing table is shared across different iface_id",
					SuggestedResolution: "assign unique routing_table per interface",
				})
			} else {
				tableOwner[routingTable] = ifaceID
			}
		}

		// mark_hex is lowercased before validation and ownership tracking.
		// NOTE(review): marks are keyed by their lowercased text, so two
		// spellings of the same numeric mark would not be reported as a
		// collision unless parseTransportMarkHex enforces a canonical form.
		markHex := strings.TrimSpace(client.MarkHex)
		if markHex != "" {
			markHex = strings.ToLower(markHex)
			if _, ok := parseTransportMarkHex(markHex); !ok {
				conflicts = append(conflicts, TransportConflictRecord{
					Key:                 "allocator:mark:" + client.ID,
					Type:                "allocator_invalid",
					Severity:            "block",
					Owners:              []string{client.ID},
					Reason:              "invalid mark_hex",
					SuggestedResolution: "reconcile clients allocation state",
				})
			} else if owner, exists := markOwner[markHex]; exists && owner != ifaceID {
				conflicts = append(conflicts, TransportConflictRecord{
					Key:                 "allocator:mark:" + markHex,
					Type:                "allocator_collision",
					Severity:            "block",
					Owners:              []string{owner, ifaceID},
					Reason:              "mark_hex is shared across different iface_id",
					SuggestedResolution: "assign unique mark pool per interface",
				})
			} else {
				markOwner[markHex] = ifaceID
			}
		}

		// A non-positive priority_base means "not allocated" and is skipped.
		prefBase := client.PriorityBase
		if prefBase > 0 {
			if _, ok := parseTransportPref(prefBase); !ok {
				conflicts = append(conflicts, TransportConflictRecord{
					Key:                 "allocator:pref:" + client.ID,
					Type:                "allocator_invalid",
					Severity:            "block",
					Owners:              []string{client.ID},
					Reason:              "invalid priority_base",
					SuggestedResolution: "reconcile clients allocation state",
				})
			} else if owner, exists := prefOwner[prefBase]; exists && owner != ifaceID {
				conflicts = append(conflicts, TransportConflictRecord{
					Key:                 "allocator:pref:" + strconv.Itoa(prefBase),
					Type:                "allocator_collision",
					Severity:            "block",
					Owners:              []string{owner, ifaceID},
					Reason:              "priority_base is shared across different iface_id",
					SuggestedResolution: "assign unique pref pool per interface",
				})
			} else {
				prefOwner[prefBase] = ifaceID
			}
		}

		// The first intent seen for an iface_id fixes the bucket's runtime
		// iface and netns. The mode is derived from the iface_id alone (no
		// explicit mode is passed).
		b, ok := buckets[ifaceID]
		if !ok {
			mode := normalizeTransportInterfaceMode("", ifaceID)
			b = &transportPolicyCompileIfaceBucket{
				Iface: TransportPolicyCompileInterface{
					IfaceID:      ifaceID,
					Mode:         string(mode),
					RuntimeIface: strings.TrimSpace(binding.RuntimeIface),
					NetnsName:    strings.TrimSpace(binding.NetnsName),
					RoutingTable: routingTable,
					ClientIDs:    nil,
					MarkHexes:    nil,
					PriorityBase: nil,
					Sets:         nil,
					Rules:        nil,
				},
				sets: map[string]*TransportPolicyCompileSet{},
			}
			buckets[ifaceID] = b
		}
		// Fill in the routing table from a later client if the first had none.
		if strings.TrimSpace(b.Iface.RoutingTable) == "" {
			b.Iface.RoutingTable = routingTable
		}
		addUniqueString(&b.Iface.ClientIDs, client.ID)
		if markHex != "" {
			addUniqueString(&b.Iface.MarkHexes, markHex)
		}
		if prefBase > 0 {
			addUniqueInt(&b.Iface.PriorityBase, prefBase)
		}

		ownerScope := transportPolicyNftOwnerScope(ifaceID, client.ID)
		setName := transportPolicyNftSetName(ownerScope, it.SelectorType)
		setKey := ownerScope + "|" + it.SelectorType + "|" + setName
		if _, exists := b.sets[setKey]; !exists {
			b.sets[setKey] = &TransportPolicyCompileSet{
				SelectorType: it.SelectorType,
				OwnerScope:   ownerScope,
				Name:         setName,
			}
		}
		b.sets[setKey].RuleCount++
		b.Iface.RuleCount++
		b.Iface.Rules = append(b.Iface.Rules, TransportPolicyCompileRule{
			SelectorType:  it.SelectorType,
			SelectorValue: it.SelectorValue,
			ClientID:      client.ID,
			ClientKind:    string(client.Kind),
			OwnerScope:    ownerScope,
			Mode:          it.Mode,
			Priority:      it.Priority,
			MarkHex:       markHex,
			PriorityBase:  prefBase,
			NftSet:        setName,
		})
		plan.RuleCount++
	}

	// Emit buckets in iface_id order so the plan layout is reproducible.
	keys := make([]string, 0, len(buckets))
	for ifaceID := range buckets {
		keys = append(keys, ifaceID)
	}
	sort.Strings(keys)
	for _, ifaceID := range keys {
		b := buckets[ifaceID]
		sort.Strings(b.Iface.ClientIDs)
		sort.Strings(b.Iface.MarkHexes)
		sort.Ints(b.Iface.PriorityBase)
		// SliceStable keeps intent order for rules that tie on selector type,
		// value and client ID (e.g. the same selector with a different Mode or
		// Priority). Otherwise the rule order could differ between compiles.
		sort.SliceStable(b.Iface.Rules, func(i, j int) bool {
			a := b.Iface.Rules[i]
			c := b.Iface.Rules[j]
			if a.SelectorType != c.SelectorType {
				return a.SelectorType < c.SelectorType
			}
			if a.SelectorValue != c.SelectorValue {
				return a.SelectorValue < c.SelectorValue
			}
			return a.ClientID < c.ClientID
		})
		// Set keys are unique per (owner scope, selector type, name), so this
		// ordering is total.
		sets := make([]TransportPolicyCompileSet, 0, len(b.sets))
		for _, it := range b.sets {
			sets = append(sets, *it)
		}
		sort.Slice(sets, func(i, j int) bool {
			if sets[i].SelectorType != sets[j].SelectorType {
				return sets[i].SelectorType < sets[j].SelectorType
			}
			if sets[i].OwnerScope != sets[j].OwnerScope {
				return sets[i].OwnerScope < sets[j].OwnerScope
			}
			return sets[i].Name < sets[j].Name
		})
		b.Iface.Sets = sets
		plan.Interfaces = append(plan.Interfaces, b.Iface)
	}
	plan.InterfaceCount = len(plan.Interfaces)
	conflicts = dedupeTransportConflicts(conflicts)
	return plan, conflicts
}
// transportPolicyNftSetName builds the nft set name
// "agvpn_pi_<scope>_<selector>" for one owner scope and selector type.
//
// The scope is trimmed and passed through sanitizeID, and dashes become
// underscores. If nothing is left, "shared_client" is used. Selector types
// other than domain, cidr, app_key, cgroup and uid collapse to "misc".
//
// When the name would exceed 63 characters, the scope is truncated and a
// 10-character hash of the untruncated name is appended:
// "agvpn_pi_<scope>_<selector>_<hash>". That result is capped at 63
// characters. Leading and trailing underscores are always trimmed from the
// returned name.
func transportPolicyNftSetName(ownerScope, selectorType string) string {
	const (
		namePrefix  = "agvpn_pi_"
		maxNameLen  = 63
		minScopeLen = 6
	)

	scope := strings.ReplaceAll(sanitizeID(strings.TrimSpace(ownerScope)), "-", "_")
	if scope == "" {
		scope = "shared_client"
	}

	var selector string
	switch s := strings.ToLower(strings.TrimSpace(selectorType)); s {
	case "domain", "cidr", "app_key", "cgroup", "uid":
		selector = s
	default:
		selector = "misc"
	}

	name := fmt.Sprintf("agvpn_pi_%s_%s", scope, selector)
	if len(name) <= maxNameLen {
		return strings.Trim(name, "_")
	}

	// Too long: keep as much scope as fits alongside the prefix, the selector,
	// the hash and their two separators.
	hash := transportPolicyShortHash(name)
	room := maxNameLen - len(namePrefix) - len(selector) - len(hash) - 2
	if room < minScopeLen {
		room = minScopeLen
	}
	if len(scope) > room {
		scope = scope[:room]
	}
	shortened := fmt.Sprintf("agvpn_pi_%s_%s_%s", strings.Trim(scope, "_"), selector, hash)
	if len(shortened) > maxNameLen {
		shortened = shortened[:maxNameLen]
	}
	return strings.Trim(shortened, "_")
}
func transportPolicyNftOwnerScope(ifaceID, clientID string) string {
iface := strings.TrimSpace(ifaceID)
if iface == "" {
iface = transportDefaultIfaceID
}
iface = strings.ReplaceAll(sanitizeID(iface), "-", "_")
if iface == "" {
iface = transportDefaultIfaceID
}
client := strings.ReplaceAll(sanitizeID(clientID), "-", "_")
if client == "" {
client = "client"
}
scope := strings.Trim(iface+"_"+client, "_")
if len(scope) <= 32 {
return scope
}
hash := transportPolicyShortHash(scope)
if len(scope) > 20 {
scope = scope[:20]
}
scope = strings.Trim(scope, "_") + "_" + hash
if len(scope) > 32 {
scope = scope[:32]
}
return strings.Trim(scope, "_")
}
// transportPolicyShortHash returns the first 10 lowercase hex characters
// (5 bytes) of the SHA-1 digest of raw. Callers in this file use it to
// disambiguate identifiers after truncation.
func transportPolicyShortHash(raw string) string {
	digest := sha1.Sum([]byte(raw))
	// 5 bytes encode to exactly 10 hex characters.
	return hex.EncodeToString(digest[:5])
}
// addUniqueString trims surrounding whitespace from v and appends the result
// to *dst, unless it is empty or an identical (case-sensitive) entry is
// already present. dst must be non-nil; *dst may be a nil slice.
func addUniqueString(dst *[]string, v string) {
	trimmed := strings.TrimSpace(v)
	if trimmed == "" {
		return
	}
	existing := *dst
	for i := range existing {
		if existing[i] == trimmed {
			return
		}
	}
	*dst = append(existing, trimmed)
}
// addUniqueInt appends v to *dst when v is positive and not already present.
// Zero and negative values are ignored. dst must be non-nil; *dst may be a
// nil slice.
func addUniqueInt(dst *[]int, v int) {
	if v < 1 {
		return
	}
	existing := *dst
	for i := range existing {
		if existing[i] == v {
			return
		}
	}
	*dst = append(existing, v)
}