Files
elmprodvpn/selective-vpn-api/app/nft_update.go
beckline 10a10f44a8 baseline: api+gui traffic mode + candidates picker
Snapshot before app-launcher (cgroup/mark) work; ignore binaries/backups.
2026-02-14 15:52:20 +03:00

401 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package app
import (
"bytes"
"context"
"errors"
"fmt"
"net/netip"
"os/exec"
"sort"
"strings"
"time"
"github.com/cenkalti/backoff/v4"
)
// ---------------------------------------------------------------------
// nft update helpers
// ---------------------------------------------------------------------
// EN: NFT set update strategy with interval compression and two execution modes:
// EN: atomic transaction first, then chunked fallback with per-IP recovery.
// RU: Стратегия обновления NFT-набора с компрессией интервалов и двумя режимами:
// RU: сначала атомарная транзакция, затем chunked fallback с поштучным восстановлением.
func nftLog(format string, args ...any) {
appendTraceLine("routes", fmt.Sprintf(format, args...))
}
// ---------------------------------------------------------------------
// interval compression
// ---------------------------------------------------------------------
// compressIPIntervals убирает:
// - дубликаты строк
// - подсети, целиком покрытые более широкими подсетями
// - одиночные IP, попадающие в уже имеющиеся подсети
func compressIPIntervals(ips []string) []string {
// чтобы не гонять дубликаты строк
seen := make(map[string]struct{})
type prefixItem struct {
p netip.Prefix
raw string
}
type addrItem struct {
a netip.Addr
raw string
}
var prefixes []prefixItem
var addrs []addrItem
for _, s := range ips {
s = strings.TrimSpace(s)
if s == "" {
continue
}
if _, ok := seen[s]; ok {
continue
}
seen[s] = struct{}{}
if strings.Contains(s, "/") {
p, err := netip.ParsePrefix(s)
if err != nil {
// если формат кривой — просто пропускаем
continue
}
prefixes = append(prefixes, prefixItem{p: p, raw: s})
} else {
a, err := netip.ParseAddr(s)
if err != nil {
continue
}
addrs = append(addrs, addrItem{a: a, raw: s})
}
}
// 1) Убираем подсети, полностью покрытые более крупными подсетями.
//
// Сначала сортируем по:
// - адресу
// - длине префикса (меньший Bits = более широкая сеть) — раньше
sort.Slice(prefixes, func(i, j int) bool {
ai := prefixes[i].p.Addr()
aj := prefixes[j].p.Addr()
if ai == aj {
return prefixes[i].p.Bits() < prefixes[j].p.Bits()
}
return ai.Less(aj)
})
var keptPrefixes []prefixItem
for _, pi := range prefixes {
covered := false
for _, kp := range keptPrefixes {
// если более крупная сеть kp уже покрывает эту — пропускаем
if kp.p.Bits() <= pi.p.Bits() && kp.p.Contains(pi.p.Addr()) {
covered = true
break
}
}
if !covered {
keptPrefixes = append(keptPrefixes, pi)
}
}
var keptAddrs []addrItem
for _, ai := range addrs {
inNet := false
for _, kp := range keptPrefixes {
if kp.p.Contains(ai.a) {
inNet = true
break
}
}
if !inNet {
keptAddrs = append(keptAddrs, ai)
}
}
// 3) Собираем финальный список строк
out := make([]string, 0, len(keptPrefixes)+len(keptAddrs))
for _, ai := range keptAddrs {
out = append(out, ai.raw)
}
for _, pi := range keptPrefixes {
out = append(out, pi.raw)
}
return out
}
// ---------------------------------------------------------------------
// smart update strategy
// ---------------------------------------------------------------------
// умный апдейтер: сначала atomic, при фейле fallback на chunked
func nftUpdateIPsSmart(ctx context.Context, ips []string, progressCb ProgressCallback) error {
return nftUpdateSetIPsSmart(ctx, "agvpn4", ips, progressCb)
}
// nftUpdateSetIPsSmart — тот же апдейтер, но для произвольного nft set.
func nftUpdateSetIPsSmart(ctx context.Context, setName string, ips []string, progressCb ProgressCallback) error {
setName = strings.TrimSpace(setName)
if setName == "" {
setName = "agvpn4"
}
if len(ips) == 0 {
if progressCb != nil {
progressCb(100, "nothing to update")
}
return nil
}
// Сжимаем IP / подсети, убираем пересечения и дубликаты
origCount := len(ips)
ips = compressIPIntervals(ips)
if len(ips) != origCount {
nftLog(
"compress(%s): %d -> %d IP elements (removed %d covered/duplicate entries)",
setName, origCount, len(ips), origCount-len(ips),
)
}
if len(ips) == 0 {
if progressCb != nil {
progressCb(100, "nothing to update after compression")
}
return nil
}
nftLog("nftUpdateSetIPsSmart(%s): start, ips=%d", setName, len(ips))
// 1) atomic транзакция через nft -f -
if err := nftAtomicUpdateWithProgress(ctx, setName, ips, progressCb); err == nil {
return nil
} else {
// если контекст умер дальше не идём
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
nftLog("atomic update cancelled (%s): %v", setName, err)
return err
}
nftLog("atomic nft update failed (%s): %v; falling back to chunked mode", setName, err)
if progressCb != nil {
progressCb(0, "Falling back to non-atomic update")
}
}
// 2) fallback: flush + chunked с поштучным фолбэком
return nftChunkedUpdateWithFallback(ctx, setName, ips, progressCb)
}
// ---------------------------------------------------------------------
// atomic updater
// ---------------------------------------------------------------------
// атомарный апдейт через один nft-транзакционный скрипт
func nftAtomicUpdateWithProgress(ctx context.Context, setName string, ips []string, progressCb ProgressCallback) error {
if len(ips) == 0 {
if progressCb != nil {
progressCb(100, "nothing to update")
}
return nil
}
sort.Strings(ips) // стабильность
total := len(ips)
chunkSize := 500 // старт
bo := backoff.NewExponentialBackOff()
bo.InitialInterval = 500 * time.Millisecond
bo.MaxInterval = 10 * time.Second
bo.MaxElapsedTime = 2 * time.Minute
return backoff.Retry(func() error {
select {
case <-ctx.Done():
if progressCb != nil {
progressCb(0, "Cancelled by context")
}
return ctx.Err()
default:
}
var script strings.Builder
script.WriteString("flush set inet agvpn " + setName + "\n")
processed := 0
chunksTotal := (len(ips) + chunkSize - 1) / chunkSize
for i := 0; i < len(ips); i += chunkSize {
end := i + chunkSize
if end > len(ips) {
end = len(ips)
}
chunk := ips[i:end]
script.WriteString("add element inet agvpn " + setName + " { ")
script.WriteString(strings.Join(chunk, ", "))
script.WriteString(" }\n")
processed += len(chunk)
if progressCb != nil {
percent := processed * 100 / total
progressCb(percent, fmt.Sprintf(
"Preparing chunk %d/%d (%d/%d IPs)",
i/chunkSize+1, chunksTotal, processed, total,
))
}
}
if progressCb != nil {
progressCb(90, "Executing nft transaction...")
}
cmd := exec.CommandContext(ctx, "nft", "-f", "-")
cmd.Stdin = strings.NewReader(script.String())
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
if err == nil {
nftLog("nft atomic transaction success (%s): %d IPs added", setName, len(ips))
if progressCb != nil {
progressCb(100, "Update complete")
}
return nil
}
errStr := stderr.String()
nftLog("nft atomic transaction failed (%s): err=%v, stderr=%q", setName, err, errStr)
// Ошибки, требующие уменьшения чанка
if strings.Contains(errStr, "too many elements") ||
strings.Contains(errStr, "out of memory") ||
strings.Contains(errStr, "interval overlaps") ||
strings.Contains(errStr, "conflicting intervals") {
newSize := chunkSize / 2
if newSize < 100 {
newSize = 100
}
if newSize == chunkSize {
// дальше делить некуда — Permanent → fallback
return backoff.Permanent(fmt.Errorf("atomic nft cannot shrink further: %w", err))
}
nftLog("reducing atomic chunk size from %d to %d and retrying", chunkSize, newSize)
chunkSize = newSize
if progressCb != nil {
progressCb(0, fmt.Sprintf("Retrying atomic with smaller chunks (%d IPs)", chunkSize))
}
return fmt.Errorf("retry atomic with smaller chunks")
}
// Другие ошибки — Permanent (переход к chunked)
return backoff.Permanent(fmt.Errorf("nft atomic transaction failed: %w", err))
}, bo)
}
// ---------------------------------------------------------------------
// chunked fallback updater
// ---------------------------------------------------------------------
// nftChunkedUpdateWithFallback — fallback-режим: flush + чанки + поштучно при ошибках
func nftChunkedUpdateWithFallback(ctx context.Context, setName string, ips []string, progressCb ProgressCallback) error {
if len(ips) == 0 {
if progressCb != nil {
progressCb(100, "nothing to update")
}
return nil
}
sort.Strings(ips)
total := len(ips)
chunkSize := 500
// flush
_, stderr, _, err := runCommandTimeout(10*time.Second,
"nft", "flush", "set", "inet", "agvpn", setName)
if err != nil {
return fmt.Errorf("flush set failed: %v (%s)", err, stderr)
}
processed := 0
for i := 0; i < len(ips); i += chunkSize {
select {
case <-ctx.Done():
if progressCb != nil {
progressCb(0, "Cancelled during chunked update")
}
return ctx.Err()
default:
}
end := i + chunkSize
if end > len(ips) {
end = len(ips)
}
chunk := ips[i:end]
cmdArgs := []string{
"nft", "add", "element", "inet", "agvpn", setName,
"{ " + strings.Join(chunk, ", ") + " }",
}
cmdName := cmdArgs[0]
cmdRest := cmdArgs[1:]
_, stderr, _, err := runCommandTimeout(15*time.Second, cmdName, cmdRest...)
if err != nil {
// типичные ошибки → поштучно
if strings.Contains(stderr, "interval overlaps") ||
strings.Contains(stderr, "too many elements") ||
strings.Contains(stderr, "out of memory") ||
strings.Contains(stderr, "conflicting intervals") {
nftLog("chunk failed (%d IPs), fallback per-ip", len(chunk))
if progressCb != nil {
progressCb(processed*100/total,
fmt.Sprintf("Chunk failed -> adding %d IPs one by one", len(chunk)))
}
for _, ip := range chunk {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
_, _, _, _ = runCommandTimeout(5*time.Second,
"nft", "add", "element", "inet", "agvpn", setName, "{ "+ip+" }")
}
} else {
return fmt.Errorf("nft chunk add failed: %v (%s)", err, stderr)
}
}
processed += len(chunk)
if progressCb != nil {
percent := processed * 100 / total
progressCb(percent, fmt.Sprintf("Added %d/%d IPs", processed, total))
}
}
if progressCb != nil {
progressCb(100, "chunked update complete")
}
nftLog("nft chunked update success (%s): %d IPs", setName, len(ips))
return nil
}