Files
elmprodvpn/selective-vpn-api/app/transport_policy_apply_kernel_test.go

106 lines
3.1 KiB
Go

package app
import (
"context"
"os"
"strings"
"testing"
"time"
)
// TestBuildTransportPolicyCIDRSetElements feeds one interface with two "cidr"
// rules targeting agvpn_pi_edge_a_cidr and one "domain" rule targeting a
// different set. It expects the returned map to hold exactly two elements
// for agvpn_pi_edge_a_cidr.
func TestBuildTransportPolicyCIDRSetElements(t *testing.T) {
	const cidrSet = "agvpn_pi_edge_a_cidr"

	ifaces := []TransportPolicyCompileInterface{{
		IfaceID: "edge-a",
		Rules: []TransportPolicyCompileRule{
			{SelectorType: "cidr", SelectorValue: "10.1.0.0/24", NftSet: cidrSet},
			{SelectorType: "cidr", SelectorValue: "10.1.0.5", NftSet: cidrSet},
			{SelectorType: "domain", SelectorValue: "example.com", NftSet: "agvpn_pi_edge_a_domain"},
		},
	}}

	elements := buildTransportPolicyCIDRSetElements(ifaces)
	if n := len(elements[cidrSet]); n != 2 {
		t.Fatalf("unexpected cidr elements: %#v", elements)
	}
}
// TestApplyTransportPolicyKernelNftSets checks two outcomes of
// applyTransportPolicyKernelNftSets when moving from a current state to a
// staged state:
//   - the staged CIDR set receives exactly one element via the set-update hook;
//   - an nft command deleting the set referenced only by the current state
//     (agvpn_pi_edge_old_cidr) is issued through the run-command hook.
//
// Both kernel hooks are swapped for in-memory recorders, so no real commands
// run. The originals are restored when the test finishes.
func TestApplyTransportPolicyKernelNftSets(t *testing.T) {
	origRun, origUpdate := transportPolicyKernelRunCommand, transportPolicyKernelUpdateSet
	t.Cleanup(func() {
		transportPolicyKernelRunCommand = origRun
		transportPolicyKernelUpdateSet = origUpdate
	})

	// Record every command line as "name arg1 arg2 ...".
	var issued []string
	transportPolicyKernelRunCommand = func(timeout time.Duration, name string, args ...string) (string, string, int, error) {
		issued = append(issued, name+" "+strings.Join(args, " "))
		return "", "", 0, nil
	}

	// Record the last element list written to each set. The slice is copied
	// so later mutation by the caller cannot alter what we assert on.
	setWrites := make(map[string][]string)
	transportPolicyKernelUpdateSet = func(_ context.Context, setName string, ips []string) error {
		setWrites[setName] = append([]string(nil), ips...)
		return nil
	}

	current := transportPolicyRuntimeState{
		Interfaces: []TransportPolicyCompileInterface{{
			IfaceID: "edge-old",
			Sets: []TransportPolicyCompileSet{
				{SelectorType: "cidr", Name: "agvpn_pi_edge_old_cidr"},
			},
		}},
	}
	staged := transportPolicyRuntimeState{
		Interfaces: []TransportPolicyCompileInterface{{
			IfaceID: "edge-a",
			Rules: []TransportPolicyCompileRule{
				{SelectorType: "cidr", SelectorValue: "10.1.0.0/24", NftSet: "agvpn_pi_edge_a_cidr"},
			},
		}},
	}

	if err := applyTransportPolicyKernelNftSets(current, staged); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if len(setWrites["agvpn_pi_edge_a_cidr"]) != 1 {
		t.Fatalf("expected update for desired set, got: %#v", setWrites)
	}

	const staleDelete = "delete set inet agvpn agvpn_pi_edge_old_cidr"
	found := false
	for i := 0; i < len(issued) && !found; i++ {
		found = strings.Contains(issued[i], staleDelete)
	}
	if !found {
		t.Fatalf("expected stale set cleanup command, got: %#v", issued)
	}
}
// TestApplyTransportPolicyKernelStageDisabled verifies that when the
// kernel-enable environment variable is set to "0",
// applyTransportPolicyKernelStage returns nil without invoking the kernel
// run-command hook.
//
// The test mutates process-wide state (a package hook and an environment
// variable), so it must not run in parallel with other tests. The environment
// variable is restored exactly as found. If it was unset before the test, it is
// unset again afterwards rather than left set to "", since an empty value may be
// interpreted differently from an absent one.
func TestApplyTransportPolicyKernelStageDisabled(t *testing.T) {
	prevRun := transportPolicyKernelRunCommand
	t.Cleanup(func() { transportPolicyKernelRunCommand = prevRun })

	called := false
	transportPolicyKernelRunCommand = func(timeout time.Duration, name string, args ...string) (string, string, int, error) {
		called = true
		return "", "", 0, nil
	}

	// LookupEnv distinguishes "unset" from "set to empty", which Getenv cannot.
	prevEnv, hadEnv := os.LookupEnv(transportPolicyKernelEnvEnable)
	if err := os.Setenv(transportPolicyKernelEnvEnable, "0"); err != nil {
		t.Fatalf("set %s: %v", transportPolicyKernelEnvEnable, err)
	}
	t.Cleanup(func() {
		var err error
		if hadEnv {
			err = os.Setenv(transportPolicyKernelEnvEnable, prevEnv)
		} else {
			err = os.Unsetenv(transportPolicyKernelEnvEnable)
		}
		if err != nil {
			t.Errorf("restore %s: %v", transportPolicyKernelEnvEnable, err)
		}
	})

	if err := applyTransportPolicyKernelStage(transportPolicyRuntimeState{}, transportPolicyRuntimeState{}); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if called {
		t.Fatalf("kernel stage must be skipped when disabled")
	}
}