// Package nftupdate replaces the contents of nftables sets in the
// "inet agvpn" table with a new list of IP addresses and prefixes. It first
// attempts a single atomic "nft -f -" transaction and, if that fails, falls
// back to a non-atomic flush followed by chunked "add element" commands.
package nftupdate

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"net/netip"
	"os/exec"
	"sort"
	"strings"
	"time"

	"github.com/cenkalti/backoff/v4"
)

// ProgressCallback receives progress notifications: a percentage and a
// human-readable status message. The percentage is reset to 0 when an update
// strategy is restarted or cancelled.
type ProgressCallback func(percent int, message string)

// CmdRunner executes an external command with the given timeout and returns
// its captured stdout and stderr, its exit code and an execution error.
type CmdRunner func(timeout time.Duration, name string, args ...string) (stdout, stderr string, exitCode int, err error)

// Logger is a printf-style logging function.
type Logger func(format string, args ...any)

// UpdateIPsSmart replaces the contents of the "agvpn4" set with ips.
// It is shorthand for UpdateSetIPsSmart with setName "agvpn4".
func UpdateIPsSmart(ctx context.Context, ips []string, progressCb ProgressCallback, runCmd CmdRunner, logf Logger) error {
	return UpdateSetIPsSmart(ctx, "agvpn4", ips, progressCb, runCmd, logf)
}

// UpdateSetIPsSmart replaces the contents of set setName in table
// "inet agvpn" with ips.
//
// A blank setName defaults to "agvpn4". The list is first compressed
// (duplicates, unparsable entries and entries covered by a listed prefix are
// dropped). An empty list, before or after compression, is reported as
// "nothing to update" and leaves the set untouched. The update is attempted
// as one atomic nft transaction; if that fails for a reason other than
// context cancellation, the set is rebuilt non-atomically in chunks through
// runCmd.
//
// progressCb and logf may be nil. runCmd must not be nil. When ctx is
// cancelled or its deadline expires, the context error is returned and no
// fallback is attempted.
func UpdateSetIPsSmart(ctx context.Context, setName string, ips []string, progressCb ProgressCallback, runCmd CmdRunner, logf Logger) error {
	setName = strings.TrimSpace(setName)
	if setName == "" {
		setName = "agvpn4"
	}
	if runCmd == nil {
		return fmt.Errorf("run command function is not configured")
	}
	if len(ips) == 0 {
		notifyProgress(progressCb, 100, "nothing to update")
		return nil
	}

	origCount := len(ips)
	ips = compressIPIntervals(ips)
	if len(ips) != origCount {
		log(logf, "compress(%s): %d -> %d IP elements (removed %d covered/duplicate entries)", setName, origCount, len(ips), origCount-len(ips))
	}
	if len(ips) == 0 {
		notifyProgress(progressCb, 100, "nothing to update after compression")
		return nil
	}

	log(logf, "nft UpdateSetIPsSmart(%s): start, ips=%d", setName, len(ips))
	err := atomicUpdateWithProgress(ctx, setName, ips, progressCb, logf)
	if err == nil {
		return nil
	}
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		log(logf, "atomic update cancelled (%s): %v", setName, err)
		return err
	}
	// Never start the non-atomic fallback once the context is done: it begins
	// by flushing the set and would be interrupted right after, leaving the
	// set empty.
	if ctxErr := ctx.Err(); ctxErr != nil {
		log(logf, "atomic update cancelled (%s): %v", setName, err)
		return ctxErr
	}

	log(logf, "atomic nft update failed (%s): %v; falling back to chunked mode", setName, err)
	notifyProgress(progressCb, 0, "Falling back to non-atomic update")
	return chunkedUpdateWithFallback(ctx, setName, ips, progressCb, runCmd, logf)
}

// compressIPIntervals removes entries that would be redundant or conflicting
// in an interval set: blank and unparsable strings, exact duplicate strings,
// prefixes covered by another listed prefix, and single addresses covered by
// a listed prefix.
//
// Entries containing "/" are parsed as prefixes, all others as addresses.
// Coverage is decided on the canonical (masked) network, so a non-canonical
// spelling such as "10.0.0.5/8" still covers "10.0.0.0/24". The returned
// slice is newly allocated and holds the surviving entries in their original
// (trimmed) spelling: addresses in input order, followed by prefixes ordered
// by network address.
func compressIPIntervals(ips []string) []string {
	type prefixItem struct {
		p   netip.Prefix // masked form, used for ordering and coverage checks
		raw string       // spelling as supplied by the caller
	}
	type addrItem struct {
		a   netip.Addr
		raw string
	}

	seen := make(map[string]struct{}, len(ips))
	var prefixes []prefixItem
	var addrs []addrItem
	for _, s := range ips {
		s = strings.TrimSpace(s)
		if s == "" {
			continue
		}
		if _, dup := seen[s]; dup {
			continue
		}
		seen[s] = struct{}{}

		if strings.Contains(s, "/") {
			p, err := netip.ParsePrefix(s)
			if err != nil {
				continue
			}
			prefixes = append(prefixes, prefixItem{p: p.Masked(), raw: s})
			continue
		}
		a, err := netip.ParseAddr(s)
		if err != nil {
			continue
		}
		addrs = append(addrs, addrItem{a: a, raw: s})
	}

	// Order by network address, widest prefix first on ties, so that every
	// prefix is visited after any prefix that contains it. The stable sort
	// makes the first-seen spelling win among equivalent prefixes.
	sort.SliceStable(prefixes, func(i, j int) bool {
		ai, aj := prefixes[i].p.Addr(), prefixes[j].p.Addr()
		if ai == aj {
			return prefixes[i].p.Bits() < prefixes[j].p.Bits()
		}
		return ai.Less(aj)
	})

	// Masked prefixes are either nested or disjoint. In the order above the
	// kept prefixes therefore stay pairwise disjoint and ascending, and a new
	// prefix can only be covered by the most recently kept one.
	keptPrefixes := make([]prefixItem, 0, len(prefixes))
	for _, pi := range prefixes {
		if n := len(keptPrefixes); n > 0 {
			last := keptPrefixes[n-1].p
			if last.Bits() <= pi.p.Bits() && last.Contains(pi.p.Addr()) {
				continue
			}
		}
		keptPrefixes = append(keptPrefixes, pi)
	}

	// Because the kept prefixes are disjoint and ascending, only the last
	// prefix starting at or before an address can contain that address.
	keptAddrs := make([]addrItem, 0, len(addrs))
	for _, ai := range addrs {
		firstAfter := sort.Search(len(keptPrefixes), func(i int) bool {
			return ai.a.Less(keptPrefixes[i].p.Addr())
		})
		if firstAfter > 0 && keptPrefixes[firstAfter-1].p.Contains(ai.a) {
			continue
		}
		keptAddrs = append(keptAddrs, ai)
	}

	out := make([]string, 0, len(keptPrefixes)+len(keptAddrs))
	for _, ai := range keptAddrs {
		out = append(out, ai.raw)
	}
	for _, pi := range keptPrefixes {
		out = append(out, pi.raw)
	}
	return out
}

// atomicUpdateWithProgress replaces the contents of set setName with ips in
// a single transaction: it pipes a script consisting of one "flush set"
// followed by "add element" statements into "nft -f -".
//
// ips is sorted in place. When nft reports a capacity or interval-overlap
// error, the attempt is repeated under exponential backoff (at most two
// minutes in total) with the elements split into smaller "add element"
// statements, down to 100 per statement. Any other nft failure is returned
// immediately. Cancellation of ctx stops the retries at once and is returned
// as an error wrapping the context error.
func atomicUpdateWithProgress(ctx context.Context, setName string, ips []string, progressCb ProgressCallback, logf Logger) error {
	if len(ips) == 0 {
		notifyProgress(progressCb, 100, "nothing to update")
		return nil
	}

	sort.Strings(ips)
	total := len(ips)
	chunkSize := 500
	const minChunkSize = 100

	bo := backoff.NewExponentialBackOff()
	bo.InitialInterval = 500 * time.Millisecond
	bo.MaxInterval = 10 * time.Second
	bo.MaxElapsedTime = 2 * time.Minute

	attempt := func() error {
		// A cancelled context must end the retry loop. Returning a plain
		// error here would make backoff.Retry keep retrying until
		// MaxElapsedTime has passed.
		if err := ctx.Err(); err != nil {
			notifyProgress(progressCb, 0, "Cancelled by context")
			return backoff.Permanent(err)
		}

		var script strings.Builder
		script.WriteString("flush set inet agvpn " + setName + "\n")
		processed := 0
		chunksTotal := (total + chunkSize - 1) / chunkSize
		for i := 0; i < total; i += chunkSize {
			end := i + chunkSize
			if end > total {
				end = total
			}
			chunk := ips[i:end]
			script.WriteString("add element inet agvpn " + setName + " { ")
			script.WriteString(strings.Join(chunk, ", "))
			script.WriteString(" }\n")
			processed += len(chunk)
			notifyProgress(progressCb, processed*100/total,
				fmt.Sprintf("Preparing chunk %d/%d (%d/%d IPs)", i/chunkSize+1, chunksTotal, processed, total))
		}

		notifyProgress(progressCb, 90, "Executing nft transaction...")
		cmd := exec.CommandContext(ctx, "nft", "-f", "-")
		cmd.Stdin = strings.NewReader(script.String())
		var stderr bytes.Buffer
		cmd.Stderr = &stderr
		err := cmd.Run()
		if err == nil {
			log(logf, "nft atomic transaction success (%s): %d IPs added", setName, total)
			notifyProgress(progressCb, 100, "Update complete")
			return nil
		}

		errStr := stderr.String()
		log(logf, "nft atomic transaction failed (%s): err=%v, stderr=%q", setName, err, errStr)

		// When the context ends, CommandContext kills nft and Run reports
		// the kill rather than a context error. Surface the context error so
		// the caller does not mistake this for an nft failure and start the
		// non-atomic fallback.
		if ctxErr := ctx.Err(); ctxErr != nil {
			notifyProgress(progressCb, 0, "Cancelled by context")
			return backoff.Permanent(fmt.Errorf("nft atomic transaction interrupted: %w", ctxErr))
		}

		if !nftCapacityOrOverlapError(errStr) {
			return backoff.Permanent(fmt.Errorf("nft atomic transaction failed: %w", err))
		}

		newSize := chunkSize / 2
		if newSize < minChunkSize {
			newSize = minChunkSize
		}
		if newSize == chunkSize {
			return backoff.Permanent(fmt.Errorf("atomic nft cannot shrink further: %w", err))
		}
		log(logf, "reducing atomic chunk size from %d to %d and retrying", chunkSize, newSize)
		chunkSize = newSize
		notifyProgress(progressCb, 0, fmt.Sprintf("Retrying atomic with smaller chunks (%d IPs)", chunkSize))
		return fmt.Errorf("retry atomic with smaller chunks: %w", err)
	}

	// WithContext also interrupts the sleep between attempts on cancellation.
	return backoff.Retry(attempt, backoff.WithContext(bo, ctx))
}

// chunkedUpdateWithFallback replaces the contents of set setName with ips
// non-atomically: it flushes the set through runCmd and then adds the
// elements in chunks of 500.
//
// ips is sorted in place. A chunk rejected with a capacity or
// interval-overlap error is retried one element at a time on a best-effort
// basis: elements that still fail are skipped, counted and logged, and do
// not fail the update. Any other chunk error aborts the update. ctx is
// checked before the flush, before every chunk and before every single
// element; on cancellation the context error is returned and the set may be
// left partially filled.
//
// NOTE(review): the exit code returned by runCmd is ignored; this assumes
// runCmd reports a non-zero exit status through err — confirm.
func chunkedUpdateWithFallback(ctx context.Context, setName string, ips []string, progressCb ProgressCallback, runCmd CmdRunner, logf Logger) error {
	if len(ips) == 0 {
		notifyProgress(progressCb, 100, "nothing to update")
		return nil
	}
	// Do not flush the live set if the update is already cancelled.
	if err := ctx.Err(); err != nil {
		notifyProgress(progressCb, 0, "Cancelled during chunked update")
		return err
	}

	sort.Strings(ips)
	total := len(ips)
	const chunkSize = 500

	if _, stderr, _, err := runCmd(10*time.Second, "nft", "flush", "set", "inet", "agvpn", setName); err != nil {
		return fmt.Errorf("flush set failed: %w (%s)", err, stderr)
	}

	processed, skipped := 0, 0
	for i := 0; i < total; i += chunkSize {
		if err := ctx.Err(); err != nil {
			notifyProgress(progressCb, 0, "Cancelled during chunked update")
			return err
		}

		end := i + chunkSize
		if end > total {
			end = total
		}
		chunk := ips[i:end]

		_, chunkStderr, _, chunkErr := runCmd(15*time.Second, "nft", "add", "element", "inet", "agvpn", setName, "{ "+strings.Join(chunk, ", ")+" }")
		switch {
		case chunkErr == nil:
			// Whole chunk accepted.
		case nftCapacityOrOverlapError(chunkStderr):
			log(logf, "chunk failed (%d IPs), fallback per-ip", len(chunk))
			notifyProgress(progressCb, processed*100/total, fmt.Sprintf("Chunk failed -> adding %d IPs one by one", len(chunk)))

			failed := 0
			var firstErr error
			for _, ip := range chunk {
				if err := ctx.Err(); err != nil {
					notifyProgress(progressCb, 0, "Cancelled during chunked update")
					return err
				}
				_, ipStderr, _, ipErr := runCmd(5*time.Second, "nft", "add", "element", "inet", "agvpn", setName, "{ "+ip+" }")
				if ipErr == nil {
					continue
				}
				// Best effort: keep going, but remember what was skipped
				// instead of discarding the error silently.
				failed++
				if firstErr == nil {
					firstErr = fmt.Errorf("%s: %w (%s)", ip, ipErr, strings.TrimSpace(ipStderr))
				}
			}
			if failed > 0 {
				skipped += failed
				log(logf, "per-ip fallback (%s): %d of %d IPs could not be added; first error: %v", setName, failed, len(chunk), firstErr)
			}
		default:
			return fmt.Errorf("nft chunk add failed: %w (%s)", chunkErr, chunkStderr)
		}

		processed += len(chunk)
		notifyProgress(progressCb, processed*100/total, fmt.Sprintf("Added %d/%d IPs", processed, total))
	}

	notifyProgress(progressCb, 100, "chunked update complete")
	if skipped > 0 {
		log(logf, "nft chunked update (%s): %d of %d IPs were skipped after per-ip errors", setName, skipped, total)
	}
	log(logf, "nft chunked update success (%s): %d IPs", setName, total)
	return nil
}

// nftCapacityOrOverlapError reports whether nft's stderr output names one of
// the errors for which retrying with smaller batches is attempted.
func nftCapacityOrOverlapError(stderr string) bool {
	for _, marker := range []string{
		"too many elements",
		"out of memory",
		"interval overlaps",
		"conflicting intervals",
	} {
		if strings.Contains(stderr, marker) {
			return true
		}
	}
	return false
}

// notifyProgress invokes cb with percent and message if cb is non-nil.
func notifyProgress(cb ProgressCallback, percent int, message string) {
	if cb != nil {
		cb(percent, message)
	}
}

// log forwards a formatted message to logf if logf is non-nil.
func log(logf Logger, format string, args ...any) {
	if logf != nil {
		logf(format, args...)
	}
}