Files

339 lines
8.2 KiB
Go

package nftupdate
import (
"bytes"
"context"
"errors"
"fmt"
"net/netip"
"os/exec"
"sort"
"strings"
"time"
"github.com/cenkalti/backoff/v4"
)
// ProgressCallback receives progress reports while a set is being updated:
// percent is the completion estimate passed by the updater (0 is also used to
// signal a restart, fallback or cancellation) and message is a short
// human-readable status line. Every call site in this package tolerates a nil
// callback.
type ProgressCallback func(percent int, message string)
// CmdRunner executes an external command (this package only ever runs "nft")
// with the given arguments and returns its captured stdout and stderr, its
// exit code and an error.
// NOTE(review): timeout is presumably the maximum run time enforced by the
// implementation, and err is presumably non-nil on a non-zero exit; both are
// decided by the caller-supplied implementation — confirm there.
type CmdRunner func(timeout time.Duration, name string, args ...string) (stdout, stderr string, exitCode int, err error)
// Logger is a printf-style logging sink. A nil Logger is allowed; the log
// helper in this package drops messages in that case.
type Logger func(format string, args ...any)
// UpdateIPsSmart replaces the contents of the default nft set "agvpn4" with
// ips. It is a convenience wrapper that forwards every argument unchanged to
// UpdateSetIPsSmart and returns its result.
func UpdateIPsSmart(ctx context.Context, ips []string, progressCb ProgressCallback, runCmd CmdRunner, logf Logger) error {
	const defaultSet = "agvpn4"
	return UpdateSetIPsSmart(ctx, defaultSet, ips, progressCb, runCmd, logf)
}
// UpdateSetIPsSmart replaces the contents of the nft set setName (table
// "inet agvpn") with ips.
//
// The set name is trimmed and defaults to "agvpn4" when empty. The input list
// is first reduced with compressIPIntervals. The update is then attempted via
// atomicUpdateWithProgress; if that fails for any reason other than context
// cancellation or deadline expiry, the function falls back to
// chunkedUpdateWithFallback, which issues commands through runCmd.
//
// progressCb and logf may be nil. runCmd must not be nil. The returned error
// is nil on success, the context error when the atomic attempt was cancelled,
// or whatever the chunked fallback returns.
func UpdateSetIPsSmart(ctx context.Context, setName string, ips []string, progressCb ProgressCallback, runCmd CmdRunner, logf Logger) error {
	// report forwards a progress event only when the caller supplied a callback.
	report := func(percent int, message string) {
		if progressCb != nil {
			progressCb(percent, message)
		}
	}

	if setName = strings.TrimSpace(setName); setName == "" {
		setName = "agvpn4"
	}
	if runCmd == nil {
		return fmt.Errorf("run command function is not configured")
	}
	if len(ips) == 0 {
		report(100, "nothing to update")
		return nil
	}

	before := len(ips)
	ips = compressIPIntervals(ips)
	after := len(ips)
	if after != before {
		log(logf, "compress(%s): %d -> %d IP elements (removed %d covered/duplicate entries)", setName, before, after, before-after)
	}
	if after == 0 {
		report(100, "nothing to update after compression")
		return nil
	}

	log(logf, "nft UpdateSetIPsSmart(%s): start, ips=%d", setName, after)

	err := atomicUpdateWithProgress(ctx, setName, ips, progressCb, logf)
	switch {
	case err == nil:
		return nil
	case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
		// A cancelled update must not silently continue in fallback mode.
		log(logf, "atomic update cancelled (%s): %v", setName, err)
		return err
	}

	log(logf, "atomic nft update failed (%s): %v; falling back to chunked mode", setName, err)
	report(0, "Falling back to non-atomic update")
	return chunkedUpdateWithFallback(ctx, setName, ips, progressCb, runCmd, logf)
}
// compressIPIntervals reduces a list of IP addresses and CIDR prefixes to a
// minimal, non-overlapping list suitable for loading into an interval set.
//
// Each entry is trimmed; empty and unparsable entries are dropped. An entry
// containing "/" is parsed as a prefix, anything else as a single address.
// Prefixes are compared in masked (canonical) form, so "10.0.0.5/24" and
// "10.0.0.0/24" are the same network; a prefix written with host bits set is
// emitted in its masked spelling, every other entry keeps its trimmed input
// text. Duplicates are detected on the parsed value, not on the raw text.
//
// A prefix is removed when a wider (or equal) prefix in the list covers it,
// and an address is removed when any kept prefix contains it.
//
// The result lists the surviving addresses first, in input order, followed by
// the surviving prefixes sorted by address (IPv4 before IPv6) and then by
// prefix length. It is never nil.
func compressIPIntervals(ips []string) []string {
	type prefixItem struct {
		p   netip.Prefix // masked network, used for every comparison
		raw string       // text emitted in the result
	}
	type addrItem struct {
		a   netip.Addr
		raw string
	}

	var (
		prefixes   []prefixItem
		addrs      []addrItem
		seenPrefix = make(map[netip.Prefix]struct{})
		seenAddr   = make(map[netip.Addr]struct{})
	)
	for _, s := range ips {
		s = strings.TrimSpace(s)
		if s == "" {
			continue
		}
		if strings.Contains(s, "/") {
			p, err := netip.ParsePrefix(s)
			if err != nil {
				continue
			}
			// Canonicalise prefixes that carry host bits so that sorting and
			// the coverage test below operate on real network boundaries.
			if m := p.Masked(); m != p {
				p, s = m, m.String()
			}
			if _, dup := seenPrefix[p]; dup {
				continue
			}
			seenPrefix[p] = struct{}{}
			prefixes = append(prefixes, prefixItem{p: p, raw: s})
			continue
		}
		a, err := netip.ParseAddr(s)
		if err != nil {
			continue
		}
		if _, dup := seenAddr[a]; dup {
			continue
		}
		seenAddr[a] = struct{}{}
		addrs = append(addrs, addrItem{a: a, raw: s})
	}

	// Order by network address, widest prefix first for equal addresses, so a
	// covering prefix is always visited before the prefixes it covers.
	sort.Slice(prefixes, func(i, j int) bool {
		ai, aj := prefixes[i].p.Addr(), prefixes[j].p.Addr()
		if ai == aj {
			return prefixes[i].p.Bits() < prefixes[j].p.Bits()
		}
		return ai.Less(aj)
	})

	// Masked prefixes are either nested or disjoint. Because the kept list is
	// sorted and disjoint, only the most recently kept prefix can cover the
	// next candidate, so a single comparison replaces a scan of the list.
	kept := make([]prefixItem, 0, len(prefixes))
	for _, pi := range prefixes {
		if n := len(kept); n > 0 {
			last := kept[n-1].p
			if last.Bits() <= pi.p.Bits() && last.Contains(pi.p.Addr()) {
				continue
			}
		}
		kept = append(kept, pi)
	}

	out := make([]string, 0, len(addrs)+len(kept))
	for _, ai := range addrs {
		// Locate the first kept prefix that starts after the address; the one
		// just before it is the only prefix that can possibly contain it.
		idx := sort.Search(len(kept), func(i int) bool {
			return ai.a.Less(kept[i].p.Addr())
		})
		if idx > 0 && kept[idx-1].p.Contains(ai.a) {
			continue
		}
		out = append(out, ai.raw)
	}
	for _, pi := range kept {
		out = append(out, pi.raw)
	}
	return out
}
// atomicUpdateWithProgress replaces the contents of set setName in table
// "inet agvpn" by piping a single script to "nft -f -": one "flush set" line
// followed by "add element" lines of at most chunkSize entries each.
//
// ips is sorted in place. progressCb and logf may be nil. Script preparation
// is reported on a 0-90 scale, 90 marks the start of the nft run and 100 is
// reported only after nft has succeeded, so progress never moves backwards
// within an attempt.
//
// When nft's stderr mentions "too many elements", "out of memory",
// "interval overlaps" or "conflicting intervals", the chunk size is halved
// (never below 100) and the script is retried under an exponential backoff
// limited to two minutes. Any other nft failure is returned immediately.
//
// Cancellation of ctx is permanent: it is never retried, it interrupts the
// backoff sleep, and the returned error satisfies errors.Is for ctx.Err(), so
// callers can tell it apart from an nft failure.
//
// NOTE(review): setName and the elements are written into the script without
// escaping; callers are assumed to pass trusted values — confirm.
func atomicUpdateWithProgress(ctx context.Context, setName string, ips []string, progressCb ProgressCallback, logf Logger) error {
	if len(ips) == 0 {
		if progressCb != nil {
			progressCb(100, "nothing to update")
		}
		return nil
	}
	sort.Strings(ips)
	total := len(ips)
	chunkSize := 500

	bo := backoff.NewExponentialBackOff()
	bo.InitialInterval = 500 * time.Millisecond
	bo.MaxInterval = 10 * time.Second
	bo.MaxElapsedTime = 2 * time.Minute

	attempt := func() error {
		// A plain error would be retried by backoff.Retry; cancellation has
		// to be marked permanent to stop the loop right away.
		if cerr := ctx.Err(); cerr != nil {
			if progressCb != nil {
				progressCb(0, "Cancelled by context")
			}
			return backoff.Permanent(cerr)
		}

		var script strings.Builder
		script.WriteString("flush set inet agvpn " + setName + "\n")
		processed := 0
		chunksTotal := (total + chunkSize - 1) / chunkSize
		for i := 0; i < total; i += chunkSize {
			end := i + chunkSize
			if end > total {
				end = total
			}
			chunk := ips[i:end]
			script.WriteString("add element inet agvpn " + setName + " { ")
			script.WriteString(strings.Join(chunk, ", "))
			script.WriteString(" }\n")
			processed += len(chunk)
			if progressCb != nil {
				// Preparation only covers 0-90; nothing has been applied yet.
				percent := processed * 90 / total
				progressCb(percent, fmt.Sprintf("Preparing chunk %d/%d (%d/%d IPs)", i/chunkSize+1, chunksTotal, processed, total))
			}
		}
		if progressCb != nil {
			progressCb(90, "Executing nft transaction...")
		}

		cmd := exec.CommandContext(ctx, "nft", "-f", "-")
		cmd.Stdin = strings.NewReader(script.String())
		var stderr bytes.Buffer
		cmd.Stderr = &stderr
		err := cmd.Run()
		if err == nil {
			log(logf, "nft atomic transaction success (%s): %d IPs added", setName, total)
			if progressCb != nil {
				progressCb(100, "Update complete")
			}
			return nil
		}

		// exec.CommandContext kills nft when ctx ends; report that as a
		// cancellation instead of an nft failure so the caller does not fall
		// back to the chunked path for an update nobody wants any more.
		if cerr := ctx.Err(); cerr != nil {
			log(logf, "nft atomic transaction interrupted (%s): %v", setName, cerr)
			if progressCb != nil {
				progressCb(0, "Cancelled by context")
			}
			return backoff.Permanent(fmt.Errorf("nft atomic transaction interrupted: %w", cerr))
		}

		errStr := stderr.String()
		log(logf, "nft atomic transaction failed (%s): err=%v, stderr=%q", setName, err, errStr)
		if strings.Contains(errStr, "too many elements") ||
			strings.Contains(errStr, "out of memory") ||
			strings.Contains(errStr, "interval overlaps") ||
			strings.Contains(errStr, "conflicting intervals") {
			newSize := chunkSize / 2
			if newSize < 100 {
				newSize = 100
			}
			if newSize == chunkSize {
				return backoff.Permanent(fmt.Errorf("atomic nft cannot shrink further: %w", err))
			}
			log(logf, "reducing atomic chunk size from %d to %d and retrying", chunkSize, newSize)
			chunkSize = newSize
			if progressCb != nil {
				progressCb(0, fmt.Sprintf("Retrying atomic with smaller chunks (%d IPs)", chunkSize))
			}
			// Keep the nft error in the chain in case the backoff gives up.
			return fmt.Errorf("retry atomic with smaller chunks: %w", err)
		}
		return backoff.Permanent(fmt.Errorf("nft atomic transaction failed: %w", err))
	}

	// Binding the backoff to ctx makes the sleep between retries
	// interruptible as well.
	return backoff.Retry(attempt, backoff.WithContext(bo, ctx))
}
// chunkedUpdateWithFallback replaces the contents of set setName in table
// "inet agvpn" without a single transaction: it flushes the set through
// runCmd and then adds the elements in chunks of 500.
//
// ips is sorted in place. progressCb and logf may be nil.
//
// When a chunk is rejected and nft's stderr mentions "interval overlaps",
// "too many elements", "out of memory" or "conflicting intervals", the chunk
// is retried one element at a time. That retry is best effort: an element nft
// still refuses is logged and skipped, and the update carries on. Any other
// chunk failure, and a failure to flush, aborts the update with an error that
// wraps the one returned by runCmd.
//
// ctx is checked before every chunk and before every single-element add; on
// cancellation the progress callback is notified and ctx.Err() is returned.
// The set may be left partially filled in that case.
//
// NOTE(review): only the error returned by runCmd is inspected, the exit code
// is ignored; this assumes runCmd reports a non-zero exit as an error —
// confirm against the implementation.
func chunkedUpdateWithFallback(ctx context.Context, setName string, ips []string, progressCb ProgressCallback, runCmd CmdRunner, logf Logger) error {
	if len(ips) == 0 {
		if progressCb != nil {
			progressCb(100, "nothing to update")
		}
		return nil
	}
	sort.Strings(ips)
	total := len(ips)
	chunkSize := 500

	if _, stderr, _, err := runCmd(10*time.Second, "nft", "flush", "set", "inet", "agvpn", setName); err != nil {
		return fmt.Errorf("flush set failed: %w (%s)", err, stderr)
	}

	processed := 0
	skipped := 0
	for i := 0; i < total; i += chunkSize {
		if cerr := ctx.Err(); cerr != nil {
			if progressCb != nil {
				progressCb(0, "Cancelled during chunked update")
			}
			return cerr
		}
		end := i + chunkSize
		if end > total {
			end = total
		}
		chunk := ips[i:end]

		_, stderr, _, err := runCmd(15*time.Second, "nft", "add", "element", "inet", "agvpn", setName, "{ "+strings.Join(chunk, ", ")+" }")
		if err != nil {
			recoverable := strings.Contains(stderr, "interval overlaps") ||
				strings.Contains(stderr, "too many elements") ||
				strings.Contains(stderr, "out of memory") ||
				strings.Contains(stderr, "conflicting intervals")
			if !recoverable {
				return fmt.Errorf("nft chunk add failed: %w (%s)", err, stderr)
			}

			log(logf, "chunk failed (%d IPs), fallback per-ip", len(chunk))
			if progressCb != nil {
				progressCb(processed*100/total, fmt.Sprintf("Chunk failed -> adding %d IPs one by one", len(chunk)))
			}
			for _, ip := range chunk {
				if cerr := ctx.Err(); cerr != nil {
					if progressCb != nil {
						progressCb(0, "Cancelled during chunked update")
					}
					return cerr
				}
				// Best effort: a single bad element must not abort the whole
				// update, but it should leave a trace in the log.
				if _, ipStderr, _, ipErr := runCmd(5*time.Second, "nft", "add", "element", "inet", "agvpn", setName, "{ "+ip+" }"); ipErr != nil {
					skipped++
					log(logf, "per-ip add failed (%s): %s: %v (%s)", setName, ip, ipErr, strings.TrimSpace(ipStderr))
				}
			}
		}

		processed += len(chunk)
		if progressCb != nil {
			percent := processed * 100 / total
			progressCb(percent, fmt.Sprintf("Added %d/%d IPs", processed, total))
		}
	}

	if progressCb != nil {
		progressCb(100, "chunked update complete")
	}
	if skipped > 0 {
		log(logf, "nft chunked update (%s): %d of %d IPs could not be added", setName, skipped, total)
	}
	log(logf, "nft chunked update success (%s): %d IPs", setName, total)
	return nil
}
// log forwards a printf-style message to logf. A nil logf turns the call into
// a no-op, so callers never have to check the logger themselves.
func log(logf Logger, format string, args ...any) {
	if logf == nil {
		return
	}
	logf(format, args...)
}