Files
elmprodvpn/selective-vpn-api/app/transport_policy_apply_executor.go

119 lines
3.6 KiB
Go

package app
import (
"fmt"
"strings"
"time"
)
// applyTransportPolicyDataPlaneAtomicLocked applies plan to the data plane and
// persists the result, rolling back to a snapshot of the prior state if any
// step after the snapshot fails.
//
// NOTE(review): the "Locked" suffix presumably means the caller already holds
// the transport policy lock — confirm against callers.
//
// Steps:
//  1. Load the current runtime state and save it as the rollback snapshot.
//  2. Build the staged state from plan; ApplyID is whitespace-trimmed and the
//     interfaces are copied via cloneTransportPolicyCompileInterfaces.
//  3. Execute current -> staged, then persist staged as the runtime state.
//
// On success it returns (staged, nil). On any failure it returns the state
// loaded in step 1 together with a non-nil error. If step 3 fails, a rollback
// to the snapshot is attempted; should that rollback also fail, its error is
// appended to the message (the original failure stays the %w-wrapped cause, so
// errors.Is/errors.As on it keep working) instead of being silently dropped.
func applyTransportPolicyDataPlaneAtomicLocked(plan TransportPolicyCompilePlan, applyID string) (transportPolicyRuntimeState, error) {
	current := loadTransportPolicyRuntimeState()
	if err := saveTransportPolicyRuntimeSnapshot(current); err != nil {
		return current, fmt.Errorf("runtime snapshot save failed: %w", err)
	}
	staged := transportPolicyRuntimeState{
		Version:        transportStateVersion,
		PolicyRevision: plan.PolicyRevision,
		ApplyID:        strings.TrimSpace(applyID),
		InterfaceCount: plan.InterfaceCount,
		RuleCount:      plan.RuleCount,
		Interfaces:     cloneTransportPolicyCompileInterfaces(plan.Interfaces),
	}
	if err := executeTransportPolicyCompilePlan(current, staged); err != nil {
		// The kernel stage may have partially applied staged, so roll back
		// from staged to the snapshot. NOTE(review): if the failure came from
		// interface validation, no kernel stage ran yet; this assumes
		// applyTransportPolicyKernelStage tolerates replaying snapshot over
		// staged in that case — confirm.
		return current, rollbackTransportPolicyApplyFailure(err, staged)
	}
	if err := saveTransportPolicyRuntimeState(staged); err != nil {
		// The data plane already reflects staged but it could not be
		// persisted; revert the data plane so it matches the stored state.
		return current, rollbackTransportPolicyApplyFailure(
			fmt.Errorf("runtime state save failed: %w", err), staged)
	}
	return staged, nil
}

// rollbackTransportPolicyApplyFailure attempts to restore the runtime snapshot
// after a failed apply whose data plane may be at applied, and returns the
// error to report. It returns cause unchanged when the rollback succeeds;
// otherwise it returns cause (still %w-wrapped) annotated with the rollback
// error so the caller can tell the data plane may be inconsistent.
func rollbackTransportPolicyApplyFailure(cause error, applied transportPolicyRuntimeState) error {
	if rbErr := rollbackTransportPolicyRuntimeToSnapshot(applied); rbErr != nil {
		return fmt.Errorf("%w (rollback failed: %v)", cause, rbErr)
	}
	return cause
}
// rollbackTransportPolicyRuntimeToSnapshot restores the runtime state returned
// by loadTransportPolicyRuntimeSnapshot.
//
// current is the state the data plane is believed to be in right now. It is
// used as the "from" side of executeTransportPolicyCompilePlan, with the
// snapshot as the target. The snapshot is persisted as the runtime state only
// after that execution succeeds.
//
// It returns an error if no snapshot exists, if executing the snapshot fails,
// or if persisting it fails.
func rollbackTransportPolicyRuntimeToSnapshot(current transportPolicyRuntimeState) error {
	snapshot, found := loadTransportPolicyRuntimeSnapshot()
	if !found {
		return fmt.Errorf("runtime snapshot not found")
	}

	// Revert the data plane first; save is attempted only once that succeeds.
	err := executeTransportPolicyCompilePlan(current, snapshot)
	if err == nil {
		err = saveTransportPolicyRuntimeState(snapshot)
	}
	return err
}
// executeTransportPolicyCompilePlan validates each interface of staged, writes
// one "transport" trace line per interface through appendTraceLineRateLimited
// (passing a 3s interval), and then calls applyTransportPolicyKernelStage with
// current and staged.
//
// Validation and tracing are interleaved per interface: interfaces that pass
// validation before a failing one have already been traced. The first
// validation error is returned as-is and the kernel stage is not run.
func executeTransportPolicyCompilePlan(current, staged transportPolicyRuntimeState) error {
	for i := range staged.Interfaces {
		// Index instead of ranging by value to avoid copying each struct.
		iface := &staged.Interfaces[i]
		if err := validateTransportPolicyCompileInterface(*iface); err != nil {
			return err
		}
		line := fmt.Sprintf(
			"policy runtime stage: iface=%s table=%s rules=%d sets=%d",
			iface.IfaceID, iface.RoutingTable, iface.RuleCount, len(iface.Sets),
		)
		appendTraceLineRateLimited("transport", line, 3*time.Second)
	}
	return applyTransportPolicyKernelStage(current, staged)
}
// validateTransportPolicyCompileInterface checks the structural invariants of a
// single compiled interface and returns the first violation found, or nil:
//   - the iface_id, after normalizeTransportIfaceID, is not blank;
//   - RuleCount is non-negative and equals len(Rules);
//   - any interface other than transportDefaultIfaceID has a non-blank
//     RoutingTable;
//   - every rule has a non-blank ClientID, SelectorType and SelectorValue.
func validateTransportPolicyCompileInterface(iface TransportPolicyCompileInterface) error {
	id := normalizeTransportIfaceID(iface.IfaceID)

	switch {
	case strings.TrimSpace(id) == "":
		return fmt.Errorf("compile interface has empty iface_id")
	case iface.RuleCount < 0:
		return fmt.Errorf("compile interface %s has invalid rule_count", id)
	case iface.RuleCount != len(iface.Rules):
		return fmt.Errorf("compile interface %s rule_count mismatch", id)
	case id != transportDefaultIfaceID && strings.TrimSpace(iface.RoutingTable) == "":
		// Only the default interface may omit a routing table.
		return fmt.Errorf("compile interface %s has empty routing_table", id)
	}

	for i := range iface.Rules {
		r := &iface.Rules[i]
		if strings.TrimSpace(r.ClientID) == "" {
			return fmt.Errorf("compile interface %s has rule without client_id", id)
		}
		missingType := strings.TrimSpace(r.SelectorType) == ""
		if missingType || strings.TrimSpace(r.SelectorValue) == "" {
			return fmt.Errorf("compile interface %s has invalid selector", id)
		}
	}
	return nil
}
// cloneTransportPolicyCompileInterfaces returns a copy of in in which the outer
// slice and each interface's ClientIDs, MarkHexes, PriorityBase, Sets and
// Rules slices have their own backing arrays. Appending to or overwriting
// elements of those slices in the copy therefore does not affect in.
//
// A nil or empty input yields nil, and every empty slice field in the result
// is nil. Elements of Sets and Rules are copied by value. NOTE(review): any
// slice/map/pointer fields inside TransportPolicyCompileSet or
// TransportPolicyCompileRule (not visible here) would still be shared.
func cloneTransportPolicyCompileInterfaces(in []TransportPolicyCompileInterface) []TransportPolicyCompileInterface {
	if len(in) == 0 {
		return nil
	}
	out := make([]TransportPolicyCompileInterface, 0, len(in))
	for _, src := range in {
		dup := src
		// Appending onto a nil slice allocates fresh storage, and returns nil
		// when there is nothing to append, which normalizes empty fields to nil.
		dup.ClientIDs = append([]string(nil), src.ClientIDs...)
		dup.MarkHexes = append([]string(nil), src.MarkHexes...)
		dup.PriorityBase = append([]int(nil), src.PriorityBase...)
		dup.Sets = append([]TransportPolicyCompileSet(nil), src.Sets...)
		dup.Rules = append([]TransportPolicyCompileRule(nil), src.Rules...)
		out = append(out, dup)
	}
	return out
}