106 lines
3.1 KiB
Go
106 lines
3.1 KiB
Go
package app
|
|
|
|
import (
|
|
"context"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestBuildTransportPolicyCIDRSetElements(t *testing.T) {
|
|
in := []TransportPolicyCompileInterface{
|
|
{
|
|
IfaceID: "edge-a",
|
|
Rules: []TransportPolicyCompileRule{
|
|
{SelectorType: "cidr", SelectorValue: "10.1.0.0/24", NftSet: "agvpn_pi_edge_a_cidr"},
|
|
{SelectorType: "cidr", SelectorValue: "10.1.0.5", NftSet: "agvpn_pi_edge_a_cidr"},
|
|
{SelectorType: "domain", SelectorValue: "example.com", NftSet: "agvpn_pi_edge_a_domain"},
|
|
},
|
|
},
|
|
}
|
|
got := buildTransportPolicyCIDRSetElements(in)
|
|
ips := got["agvpn_pi_edge_a_cidr"]
|
|
if len(ips) != 2 {
|
|
t.Fatalf("unexpected cidr elements: %#v", got)
|
|
}
|
|
}
|
|
|
|
func TestApplyTransportPolicyKernelNftSets(t *testing.T) {
|
|
prevRun := transportPolicyKernelRunCommand
|
|
prevUpdate := transportPolicyKernelUpdateSet
|
|
defer func() {
|
|
transportPolicyKernelRunCommand = prevRun
|
|
transportPolicyKernelUpdateSet = prevUpdate
|
|
}()
|
|
|
|
var cmds []string
|
|
transportPolicyKernelRunCommand = func(timeout time.Duration, name string, args ...string) (string, string, int, error) {
|
|
cmds = append(cmds, name+" "+strings.Join(args, " "))
|
|
return "", "", 0, nil
|
|
}
|
|
updates := map[string][]string{}
|
|
transportPolicyKernelUpdateSet = func(_ context.Context, setName string, ips []string) error {
|
|
updates[setName] = append([]string(nil), ips...)
|
|
return nil
|
|
}
|
|
|
|
current := transportPolicyRuntimeState{
|
|
Interfaces: []TransportPolicyCompileInterface{
|
|
{
|
|
IfaceID: "edge-old",
|
|
Sets: []TransportPolicyCompileSet{
|
|
{SelectorType: "cidr", Name: "agvpn_pi_edge_old_cidr"},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
staged := transportPolicyRuntimeState{
|
|
Interfaces: []TransportPolicyCompileInterface{
|
|
{
|
|
IfaceID: "edge-a",
|
|
Rules: []TransportPolicyCompileRule{
|
|
{SelectorType: "cidr", SelectorValue: "10.1.0.0/24", NftSet: "agvpn_pi_edge_a_cidr"},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
if err := applyTransportPolicyKernelNftSets(current, staged); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if len(updates["agvpn_pi_edge_a_cidr"]) != 1 {
|
|
t.Fatalf("expected update for desired set, got: %#v", updates)
|
|
}
|
|
hasDeleteStale := false
|
|
for _, cmd := range cmds {
|
|
if strings.Contains(cmd, "delete set inet agvpn agvpn_pi_edge_old_cidr") {
|
|
hasDeleteStale = true
|
|
break
|
|
}
|
|
}
|
|
if !hasDeleteStale {
|
|
t.Fatalf("expected stale set cleanup command, got: %#v", cmds)
|
|
}
|
|
}
|
|
|
|
func TestApplyTransportPolicyKernelStageDisabled(t *testing.T) {
|
|
prevRun := transportPolicyKernelRunCommand
|
|
defer func() { transportPolicyKernelRunCommand = prevRun }()
|
|
called := false
|
|
transportPolicyKernelRunCommand = func(timeout time.Duration, name string, args ...string) (string, string, int, error) {
|
|
called = true
|
|
return "", "", 0, nil
|
|
}
|
|
|
|
prev := os.Getenv(transportPolicyKernelEnvEnable)
|
|
_ = os.Setenv(transportPolicyKernelEnvEnable, "0")
|
|
defer func() { _ = os.Setenv(transportPolicyKernelEnvEnable, prev) }()
|
|
|
|
if err := applyTransportPolicyKernelStage(transportPolicyRuntimeState{}, transportPolicyRuntimeState{}); err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if called {
|
|
t.Fatalf("kernel stage must be skipped when disabled")
|
|
}
|
|
}
|