339 lines
8.2 KiB
Go
339 lines
8.2 KiB
Go
package nftupdate
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net/netip"
|
|
"os/exec"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/cenkalti/backoff/v4"
|
|
)
|
|
|
|
type (
	// ProgressCallback receives a completion percentage (this package reports
	// values from 0 to 100) together with a human-readable status message.
	ProgressCallback func(percent int, message string)

	// CmdRunner executes the external program name with args, bounded by
	// timeout, and returns its captured stdout and stderr, its exit code and an
	// error. NOTE(review): whether a non-zero exit always yields a non-nil err
	// depends on the implementation; callers in this package only inspect err.
	CmdRunner func(timeout time.Duration, name string, args ...string) (stdout, stderr string, exitCode int, err error)

	// Logger is a printf-style sink for diagnostic messages.
	Logger func(format string, args ...any)
)
|
|
|
|
func UpdateIPsSmart(ctx context.Context, ips []string, progressCb ProgressCallback, runCmd CmdRunner, logf Logger) error {
|
|
return UpdateSetIPsSmart(ctx, "agvpn4", ips, progressCb, runCmd, logf)
|
|
}
|
|
|
|
func UpdateSetIPsSmart(ctx context.Context, setName string, ips []string, progressCb ProgressCallback, runCmd CmdRunner, logf Logger) error {
|
|
setName = strings.TrimSpace(setName)
|
|
if setName == "" {
|
|
setName = "agvpn4"
|
|
}
|
|
if runCmd == nil {
|
|
return fmt.Errorf("run command function is not configured")
|
|
}
|
|
|
|
if len(ips) == 0 {
|
|
if progressCb != nil {
|
|
progressCb(100, "nothing to update")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
origCount := len(ips)
|
|
ips = compressIPIntervals(ips)
|
|
if len(ips) != origCount {
|
|
log(logf, "compress(%s): %d -> %d IP elements (removed %d covered/duplicate entries)", setName, origCount, len(ips), origCount-len(ips))
|
|
}
|
|
|
|
if len(ips) == 0 {
|
|
if progressCb != nil {
|
|
progressCb(100, "nothing to update after compression")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
log(logf, "nft UpdateSetIPsSmart(%s): start, ips=%d", setName, len(ips))
|
|
|
|
if err := atomicUpdateWithProgress(ctx, setName, ips, progressCb, logf); err == nil {
|
|
return nil
|
|
} else {
|
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
|
log(logf, "atomic update cancelled (%s): %v", setName, err)
|
|
return err
|
|
}
|
|
log(logf, "atomic nft update failed (%s): %v; falling back to chunked mode", setName, err)
|
|
if progressCb != nil {
|
|
progressCb(0, "Falling back to non-atomic update")
|
|
}
|
|
}
|
|
|
|
return chunkedUpdateWithFallback(ctx, setName, ips, progressCb, runCmd, logf)
|
|
}
|
|
|
|
// compressIPIntervals normalizes and de-duplicates a list of IPv4/IPv6
// addresses and CIDR prefixes so that no returned element overlaps another.
//
// Processing steps:
//   - Entries are whitespace-trimmed; blank and unparseable entries are
//     silently dropped.
//   - Prefixes are masked to their network address, so "10.0.0.5/8" becomes
//     "10.0.0.0/8".
//   - Prefixes contained in another prefix are removed.
//   - Single addresses that are duplicates (by value) or fall inside a kept
//     prefix are removed.
//
// The result lists the remaining single addresses first, in input order and
// canonical form. They are followed by the kept prefixes, sorted by network
// address. Runs in O(n log n).
func compressIPIntervals(ips []string) []string {
	var prefixes []netip.Prefix
	var addrs []netip.Addr
	seenAddr := make(map[netip.Addr]struct{})

	for _, s := range ips {
		s = strings.TrimSpace(s)
		if s == "" {
			continue
		}
		if strings.Contains(s, "/") {
			p, err := netip.ParsePrefix(s)
			if err != nil {
				continue
			}
			// Without masking, "10.0.0.5/8" would sort after "10.0.0.0/16"
			// and the broader prefix would never be seen as covering it.
			prefixes = append(prefixes, p.Masked())
			continue
		}
		a, err := netip.ParseAddr(s)
		if err != nil {
			continue
		}
		if _, dup := seenAddr[a]; dup {
			continue
		}
		seenAddr[a] = struct{}{}
		addrs = append(addrs, a)
	}

	// Order by network address, then by prefix length (broadest first), so a
	// covering prefix always precedes everything it covers.
	sort.Slice(prefixes, func(i, j int) bool {
		if c := prefixes[i].Addr().Compare(prefixes[j].Addr()); c != 0 {
			return c < 0
		}
		return prefixes[i].Bits() < prefixes[j].Bits()
	})

	// The kept prefixes come out disjoint and ascending. Any prefix covering
	// the next candidate must therefore be the most recently kept one.
	kept := make([]netip.Prefix, 0, len(prefixes))
	for _, p := range prefixes {
		if n := len(kept); n > 0 && kept[n-1].Bits() <= p.Bits() && kept[n-1].Contains(p.Addr()) {
			continue
		}
		kept = append(kept, p)
	}

	out := make([]string, 0, len(addrs)+len(kept))
	for _, a := range addrs {
		// Find the last kept prefix whose network address is <= a. Because
		// kept is disjoint and sorted, it is the only prefix that can contain a.
		i := sort.Search(len(kept), func(k int) bool { return a.Less(kept[k].Addr()) })
		if i > 0 && kept[i-1].Contains(a) {
			continue
		}
		out = append(out, a.String())
	}
	for _, p := range kept {
		out = append(out, p.String())
	}
	return out
}
|
|
|
|
func atomicUpdateWithProgress(ctx context.Context, setName string, ips []string, progressCb ProgressCallback, logf Logger) error {
|
|
if len(ips) == 0 {
|
|
if progressCb != nil {
|
|
progressCb(100, "nothing to update")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
sort.Strings(ips)
|
|
total := len(ips)
|
|
chunkSize := 500
|
|
|
|
bo := backoff.NewExponentialBackOff()
|
|
bo.InitialInterval = 500 * time.Millisecond
|
|
bo.MaxInterval = 10 * time.Second
|
|
bo.MaxElapsedTime = 2 * time.Minute
|
|
|
|
return backoff.Retry(func() error {
|
|
select {
|
|
case <-ctx.Done():
|
|
if progressCb != nil {
|
|
progressCb(0, "Cancelled by context")
|
|
}
|
|
return ctx.Err()
|
|
default:
|
|
}
|
|
|
|
var script strings.Builder
|
|
script.WriteString("flush set inet agvpn " + setName + "\n")
|
|
|
|
processed := 0
|
|
chunksTotal := (len(ips) + chunkSize - 1) / chunkSize
|
|
|
|
for i := 0; i < len(ips); i += chunkSize {
|
|
end := i + chunkSize
|
|
if end > len(ips) {
|
|
end = len(ips)
|
|
}
|
|
chunk := ips[i:end]
|
|
|
|
script.WriteString("add element inet agvpn " + setName + " { ")
|
|
script.WriteString(strings.Join(chunk, ", "))
|
|
script.WriteString(" }\n")
|
|
|
|
processed += len(chunk)
|
|
if progressCb != nil {
|
|
percent := processed * 100 / total
|
|
progressCb(percent, fmt.Sprintf("Preparing chunk %d/%d (%d/%d IPs)", i/chunkSize+1, chunksTotal, processed, total))
|
|
}
|
|
}
|
|
|
|
if progressCb != nil {
|
|
progressCb(90, "Executing nft transaction...")
|
|
}
|
|
|
|
cmd := exec.CommandContext(ctx, "nft", "-f", "-")
|
|
cmd.Stdin = strings.NewReader(script.String())
|
|
|
|
var stdout, stderr bytes.Buffer
|
|
cmd.Stdout = &stdout
|
|
cmd.Stderr = &stderr
|
|
|
|
err := cmd.Run()
|
|
if err == nil {
|
|
log(logf, "nft atomic transaction success (%s): %d IPs added", setName, len(ips))
|
|
if progressCb != nil {
|
|
progressCb(100, "Update complete")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
errStr := stderr.String()
|
|
log(logf, "nft atomic transaction failed (%s): err=%v, stderr=%q", setName, err, errStr)
|
|
|
|
if strings.Contains(errStr, "too many elements") ||
|
|
strings.Contains(errStr, "out of memory") ||
|
|
strings.Contains(errStr, "interval overlaps") ||
|
|
strings.Contains(errStr, "conflicting intervals") {
|
|
|
|
newSize := chunkSize / 2
|
|
if newSize < 100 {
|
|
newSize = 100
|
|
}
|
|
if newSize == chunkSize {
|
|
return backoff.Permanent(fmt.Errorf("atomic nft cannot shrink further: %w", err))
|
|
}
|
|
|
|
log(logf, "reducing atomic chunk size from %d to %d and retrying", chunkSize, newSize)
|
|
chunkSize = newSize
|
|
if progressCb != nil {
|
|
progressCb(0, fmt.Sprintf("Retrying atomic with smaller chunks (%d IPs)", chunkSize))
|
|
}
|
|
return fmt.Errorf("retry atomic with smaller chunks")
|
|
}
|
|
|
|
return backoff.Permanent(fmt.Errorf("nft atomic transaction failed: %w", err))
|
|
}, bo)
|
|
}
|
|
|
|
func chunkedUpdateWithFallback(ctx context.Context, setName string, ips []string, progressCb ProgressCallback, runCmd CmdRunner, logf Logger) error {
|
|
if len(ips) == 0 {
|
|
if progressCb != nil {
|
|
progressCb(100, "nothing to update")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
sort.Strings(ips)
|
|
total := len(ips)
|
|
chunkSize := 500
|
|
|
|
_, stderr, _, err := runCmd(10*time.Second, "nft", "flush", "set", "inet", "agvpn", setName)
|
|
if err != nil {
|
|
return fmt.Errorf("flush set failed: %v (%s)", err, stderr)
|
|
}
|
|
|
|
processed := 0
|
|
for i := 0; i < len(ips); i += chunkSize {
|
|
select {
|
|
case <-ctx.Done():
|
|
if progressCb != nil {
|
|
progressCb(0, "Cancelled during chunked update")
|
|
}
|
|
return ctx.Err()
|
|
default:
|
|
}
|
|
|
|
end := i + chunkSize
|
|
if end > len(ips) {
|
|
end = len(ips)
|
|
}
|
|
chunk := ips[i:end]
|
|
|
|
_, stderr, _, err := runCmd(15*time.Second, "nft", "add", "element", "inet", "agvpn", setName, "{ "+strings.Join(chunk, ", ")+" }")
|
|
if err != nil {
|
|
if strings.Contains(stderr, "interval overlaps") ||
|
|
strings.Contains(stderr, "too many elements") ||
|
|
strings.Contains(stderr, "out of memory") ||
|
|
strings.Contains(stderr, "conflicting intervals") {
|
|
|
|
log(logf, "chunk failed (%d IPs), fallback per-ip", len(chunk))
|
|
if progressCb != nil {
|
|
progressCb(processed*100/total, fmt.Sprintf("Chunk failed -> adding %d IPs one by one", len(chunk)))
|
|
}
|
|
|
|
for _, ip := range chunk {
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
default:
|
|
}
|
|
_, _, _, _ = runCmd(5*time.Second, "nft", "add", "element", "inet", "agvpn", setName, "{ "+ip+" }")
|
|
}
|
|
} else {
|
|
return fmt.Errorf("nft chunk add failed: %v (%s)", err, stderr)
|
|
}
|
|
}
|
|
|
|
processed += len(chunk)
|
|
if progressCb != nil {
|
|
percent := processed * 100 / total
|
|
progressCb(percent, fmt.Sprintf("Added %d/%d IPs", processed, total))
|
|
}
|
|
}
|
|
|
|
if progressCb != nil {
|
|
progressCb(100, "chunked update complete")
|
|
}
|
|
log(logf, "nft chunked update success (%s): %d IPs", setName, len(ips))
|
|
return nil
|
|
}
|
|
|
|
func log(logf Logger, format string, args ...any) {
|
|
if logf != nil {
|
|
logf(format, args...)
|
|
}
|
|
}
|