Files
elmprodvpn/selective-vpn-api/app/routes_cache.go

417 lines
11 KiB
Go

package app
import (
"context"
"encoding/json"
"fmt"
"os"
"sort"
"strings"
"time"
)
// ---------------------------------------------------------------------
// routes clear cache (safe clear / fast restore)
// ---------------------------------------------------------------------
// routesClearCacheMeta is the metadata persisted (as indented JSON in the
// routesCacheMeta file) before a routes clear, so that routes can later be
// restored quickly from the cache without running a full domain resolve again.
type routesClearCacheMeta struct {
	// CreatedAt is the snapshot time in UTC, formatted as RFC 3339.
	CreatedAt string `json:"created_at"`
	// Iface is the first interface named after "dev" in the cached route
	// lines; omitted from JSON when no such interface was found.
	Iface string `json:"iface,omitempty"`
	// RouteCount is the number of route lines saved from the routes table.
	RouteCount int `json:"route_count"`
	// IPCount is the number of entries cached for the agvpn4 nft set (or the
	// line count of the last-ips fallback when the nft snapshot failed).
	IPCount int `json:"ip_count"`
	// DynIPCount is the number of entries cached for the agvpn_dyn4 nft set;
	// 0 when that snapshot failed.
	DynIPCount int `json:"dyn_ip_count"`
	// HasIPMap reports whether the cached last-ips map file was available when
	// the snapshot was taken (as determined by saveRoutesClearCache).
	HasIPMap bool `json:"has_ip_map"`
}
// saveRoutesClearCache snapshots the current routes table, the agvpn4 and
// agvpn_dyn4 nft sets and the last-ips map files into the routes cache, then
// writes the routesClearCacheMeta JSON describing the snapshot.
//
// Return contract:
//   - zero meta + error: a hard failure (state dir, reading/writing the routes
//     table, marshalling or writing the meta file); nothing usable was saved.
//   - meta + non-nil error: the meta file was written, but one or more
//     best-effort steps failed; the error joins their messages with "; ".
//   - meta + nil error: everything was cached successfully.
func saveRoutesClearCache() (routesClearCacheMeta, error) {
	if err := os.MkdirAll(stateDir, 0o755); err != nil {
		return routesClearCacheMeta{}, err
	}
	routes, err := readCurrentRoutesTableLines()
	if err != nil {
		return routesClearCacheMeta{}, err
	}
	if err := writeLinesFile(routesCacheRT, routes); err != nil {
		return routesClearCacheMeta{}, err
	}

	// Everything below is best-effort: failures are collected as warnings so
	// a partial cache is still usable for restore.
	var warns []string

	ipCount, err := snapshotNftSetToFile("agvpn4", routesCacheIPs)
	if err != nil {
		warns = append(warns, fmt.Sprintf("agvpn4 snapshot failed: %v", err))
		// Fall back to the last resolved IP list so restore still has data.
		if cerr := cacheCopyOrEmpty(stateDir+"/last-ips.txt", routesCacheIPs); cerr != nil {
			warns = append(warns, fmt.Sprintf("last-ips cache copy failed: %v", cerr))
		}
		ipCount = len(readNonEmptyLines(routesCacheIPs))
	}

	dynIPCount, err := snapshotNftSetToFile("agvpn_dyn4", routesCacheDyn)
	if err != nil {
		warns = append(warns, fmt.Sprintf("agvpn_dyn4 snapshot failed: %v", err))
		// Reset the dyn cache to empty so a stale file is not restored later.
		if werr := os.WriteFile(routesCacheDyn, []byte{}, 0o644); werr != nil {
			warns = append(warns, fmt.Sprintf("agvpn_dyn4 cache reset failed: %v", werr))
		}
		dynIPCount = 0
	}

	if err := cacheCopyOrEmpty(stateDir+"/last-ips-map.txt", routesCacheMap); err != nil {
		warns = append(warns, fmt.Sprintf("last-ips-map cache copy failed: %v", err))
	}
	if err := cacheCopyOrEmpty(lastIPsMapDirect, routesCacheMapD); err != nil {
		warns = append(warns, fmt.Sprintf("last-ips-map-direct cache copy failed: %v", err))
	}
	if err := cacheCopyOrEmpty(lastIPsMapDyn, routesCacheMapW); err != nil {
		warns = append(warns, fmt.Sprintf("last-ips-map-wildcard cache copy failed: %v", err))
	}

	// cacheCopyOrEmpty always leaves routesCacheMap in place (empty when the
	// source was missing), so mere existence says nothing; report whether the
	// cached map actually has content.
	hasIPMap := false
	if info, statErr := os.Stat(routesCacheMap); statErr == nil && !info.IsDir() && info.Size() > 0 {
		hasIPMap = true
	}

	meta := routesClearCacheMeta{
		CreatedAt:  time.Now().UTC().Format(time.RFC3339),
		Iface:      detectIfaceFromRoutes(routes),
		RouteCount: len(routes),
		IPCount:    ipCount,
		DynIPCount: dynIPCount,
		HasIPMap:   hasIPMap,
	}
	data, err := json.MarshalIndent(meta, "", " ")
	if err != nil {
		return routesClearCacheMeta{}, err
	}
	if err := os.WriteFile(routesCacheMeta, data, 0o644); err != nil {
		return routesClearCacheMeta{}, err
	}
	if len(warns) > 0 {
		return meta, fmt.Errorf("%s", strings.Join(warns, "; "))
	}
	return meta, nil
}
// restoreRoutesFromCache restores routes, nft sets and traffic mode from the
// routes clear cache. The actual work runs inside withRoutesOpLock under the
// "routes restore" label (presumably serializing it with other routes
// operations — see withRoutesOpLock).
func restoreRoutesFromCache() cmdResult {
	restore := restoreRoutesFromCacheUnlocked
	return withRoutesOpLock("routes restore", restore)
}
// restoreRoutesFromCacheUnlocked rebuilds the policy routes table and the
// agvpn4 / agvpn_dyn4 nft sets from the files written by saveRoutesClearCache,
// re-applies the traffic mode, and copies the cached IP files back to their
// last-* locations. It takes no lock itself; restoreRoutesFromCache wraps it
// with withRoutesOpLock.
//
// Order matters: traffic rules are removed and the table is flushed before any
// cached route is re-added, so an early failure return leaves the table only
// partially restored. Returns OK=false with a message on the first hard error.
func restoreRoutesFromCacheUnlocked() cmdResult {
	// Without the meta file there is no cache to restore from.
	meta, err := loadRoutesClearCacheMeta()
	if err != nil {
		return cmdResult{
			OK:      false,
			Message: fmt.Sprintf("routes cache missing: %v", err),
		}
	}
	ips := readNonEmptyLines(routesCacheIPs)
	dynIPs := readNonEmptyLines(routesCacheDyn)
	// Read error deliberately ignored: an unreadable routes file is treated as
	// "no cached routes" (see the default-route fallback below).
	routeLines, _ := readLinesFile(routesCacheRT)
	ensureRoutesTableEntry()
	removeTrafficRulesForTable()
	// Start from an empty table; the flush result is ignored (best-effort).
	_, _, _, _ = runCommandTimeout(5*time.Second, "ip", "route", "flush", "table", routesTableName())
	ignoredRoutes := 0
	for _, ln := range routeLines {
		if err := restoreRouteLine(ln); err != nil {
			// Non-critical failures (e.g. routes on interfaces that no longer
			// exist) are traced and skipped; anything else aborts the restore.
			if shouldIgnoreRestoreRouteError(ln, err) {
				ignoredRoutes++
				appendTraceLine("routes", fmt.Sprintf("restore route skipped (%q): %v", ln, err))
				continue
			}
			return cmdResult{
				OK:      false,
				Message: fmt.Sprintf("restore route failed (%q): %v", ln, err),
			}
		}
	}
	if ignoredRoutes > 0 {
		appendTraceLine("routes", fmt.Sprintf("restore route: skipped non-critical routes=%d", ignoredRoutes))
	}
	// No cached route lines: fall back to a single default route through the
	// interface recorded in the meta file (best-effort, result ignored).
	if len(routeLines) == 0 && strings.TrimSpace(meta.Iface) != "" {
		_, _, _, _ = runCommandTimeout(
			5*time.Second,
			"ip", "-4", "route", "replace",
			"default", "dev", meta.Iface,
			"table", routesTableName(),
			"mtu", policyRouteMTU,
		)
	}
	// Make sure the inet agvpn table and both interval sets exist, then empty
	// them. Results are ignored; nft "add" (unlike "create") does not fail when
	// the object already exists.
	_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "table", "inet", "agvpn")
	_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "set", "inet", "agvpn", "agvpn4", "{", "type", "ipv4_addr", ";", "flags", "interval", ";", "}")
	_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "set", "inet", "agvpn", "agvpn_dyn4", "{", "type", "ipv4_addr", ";", "flags", "interval", ";", "}")
	_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "flush", "set", "inet", "agvpn", "agvpn4")
	_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "flush", "set", "inet", "agvpn", "agvpn_dyn4")
	// One 3-minute budget shared by both set loads.
	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
	defer cancel()
	if len(ips) > 0 {
		if err := nftUpdateSetIPsSmart(ctx, "agvpn4", ips, nil); err != nil {
			return cmdResult{
				OK:      false,
				Message: fmt.Sprintf("restore nft cache failed for agvpn4: %v", err),
			}
		}
	}
	if len(dynIPs) > 0 {
		if err := nftUpdateSetIPsSmart(ctx, "agvpn_dyn4", dynIPs, nil); err != nil {
			return cmdResult{
				OK:      false,
				Message: fmt.Sprintf("restore nft cache failed for agvpn_dyn4: %v", err),
			}
		}
	}
	// Pick the interface for traffic mode: meta first, then the restored
	// routes, then whatever resolveTrafficIface returns for the preferred one.
	traffic := loadTrafficModeState()
	iface := strings.TrimSpace(meta.Iface)
	if iface == "" {
		iface = detectIfaceFromRoutes(routeLines)
	}
	if iface == "" {
		iface, _ = resolveTrafficIface(traffic.PreferredIface)
	}
	// Traffic mode is only re-applied when an interface was found.
	if iface != "" {
		if err := applyTrafficMode(traffic, iface); err != nil {
			return cmdResult{
				OK:      false,
				Message: fmt.Sprintf("cache restored, but traffic mode apply failed: %v", err),
			}
		}
	}
	// Copy the cached IP lists/maps back to the live last-* files
	// (best-effort; copy errors are ignored).
	_ = cacheCopyOrEmpty(routesCacheIPs, stateDir+"/last-ips.txt")
	if fileExists(routesCacheMap) {
		_ = cacheCopyOrEmpty(routesCacheMap, stateDir+"/last-ips-map.txt")
	}
	if fileExists(routesCacheMapD) {
		_ = cacheCopyOrEmpty(routesCacheMapD, lastIPsMapDirect)
	}
	if fileExists(routesCacheMapW) {
		_ = cacheCopyOrEmpty(routesCacheMapW, lastIPsMapDyn)
	}
	_ = writeStatusSnapshot(len(ips)+len(dynIPs), iface)
	return cmdResult{
		OK: true,
		Message: fmt.Sprintf(
			"routes restored from cache: agvpn4=%d agvpn_dyn4=%d routes=%d iface=%s",
			len(ips), len(dynIPs), len(routeLines), ifaceOrDash(iface),
		),
	}
}
// readCurrentRoutesTableLines lists the IPv4 routes of the policy routes table
// ("ip -4 route show table <name>") and returns the non-empty, trimmed output
// lines. The command error is returned only when it coincides with a non-zero
// exit code; otherwise whatever stdout was produced is parsed.
func readCurrentRoutesTableLines() ([]string, error) {
	stdout, _, exitCode, runErr := runCommandTimeout(5*time.Second, "ip", "-4", "route", "show", "table", routesTableName())
	if exitCode != 0 && runErr != nil {
		return nil, runErr
	}
	result := make([]string, 0, 32)
	for _, row := range strings.Split(stdout, "\n") {
		if trimmed := strings.TrimSpace(row); trimmed != "" {
			result = append(result, trimmed)
		}
	}
	return result, nil
}
// writeLinesFile writes lines to path (mode 0644), joined by "\n" and
// terminated by a single trailing newline unless the joined text already ends
// with one. An empty slice produces an empty file.
func writeLinesFile(path string, lines []string) error {
	var buf strings.Builder
	if len(lines) > 0 {
		buf.WriteString(strings.Join(lines, "\n"))
		if !strings.HasSuffix(buf.String(), "\n") {
			buf.WriteByte('\n')
		}
	}
	return os.WriteFile(path, []byte(buf.String()), 0o644)
}
// readLinesFile reads path and returns its lines with surrounding whitespace
// trimmed and blank lines dropped. On a read error it returns (nil, err).
func readLinesFile(path string) ([]string, error) {
	raw, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}
	rows := strings.Split(string(raw), "\n")
	result := make([]string, 0, 64)
	for i := range rows {
		if trimmed := strings.TrimSpace(rows[i]); trimmed != "" {
			result = append(result, trimmed)
		}
	}
	return result, nil
}
// detectIfaceFromRoutes scans route lines in order and returns the token that
// follows the first "dev" keyword found in any of them, or "" if none has one.
func detectIfaceFromRoutes(lines []string) string {
	for _, line := range lines {
		tokens := strings.Fields(line)
		for idx := 1; idx < len(tokens); idx++ {
			if tokens[idx-1] == "dev" {
				return strings.TrimSpace(tokens[idx])
			}
		}
	}
	return ""
}
// restoreRouteLine re-adds one saved route line with "ip -4 route replace".
// If the line does not name a table, the policy routes table is appended.
// Blank lines are a no-op.
//
// On failure the returned error carries ip's stderr when there is any, because
// shouldIgnoreRestoreRouteError classifies failures by message text (e.g.
// "cannot find device"); a bare "exit code N" could never match those checks.
func restoreRouteLine(line string) error {
	fields := strings.Fields(strings.TrimSpace(line))
	if len(fields) == 0 {
		return nil
	}
	args := append([]string{"-4", "route", "replace"}, fields...)
	hasTable := false
	for i := 0; i+1 < len(fields); i++ {
		if fields[i] == "table" {
			hasTable = true
			break
		}
	}
	if !hasTable {
		args = append(args, "table", routesTableName())
	}
	_, stderr, code, err := runCommandTimeout(5*time.Second, "ip", args...)
	if err == nil && code == 0 {
		return nil
	}
	detail := strings.TrimSpace(stderr)
	if err != nil {
		if detail != "" {
			return fmt.Errorf("%w: %s", err, detail)
		}
		return err
	}
	if detail != "" {
		return fmt.Errorf("exit code %d: %s", code, detail)
	}
	return fmt.Errorf("exit code %d", code)
}
func shouldIgnoreRestoreRouteError(line string, err error) bool {
ln := strings.ToLower(strings.TrimSpace(line))
if strings.Contains(ln, " linkdown") {
return true
}
dev := routeLineDevice(ln)
if dev != "" && !strings.HasPrefix(ln, "default ") && !ifaceExists(dev) {
return true
}
msg := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", err)))
if strings.HasPrefix(ln, "default ") {
return false
}
if strings.Contains(msg, "cannot find device") ||
strings.Contains(msg, "no such device") ||
strings.Contains(msg, "network is down") {
return true
}
return false
}
// routeLineDevice returns the device named after the first "dev" keyword in a
// route line, or "" when the line has no "dev <name>" pair.
func routeLineDevice(line string) string {
	tokens := strings.Fields(strings.TrimSpace(line))
	for idx := 1; idx < len(tokens); idx++ {
		if tokens[idx-1] == "dev" {
			return strings.TrimSpace(tokens[idx])
		}
	}
	return ""
}
// cacheCopyOrEmpty copies src to dst. If the copy fails for any reason, dst is
// (re)written as an empty file instead, so dst is expected to exist afterwards.
// Only the error of that fallback write is returned.
func cacheCopyOrEmpty(src, dst string) error {
	if copyErr := copyFile(src, dst); copyErr != nil {
		return os.WriteFile(dst, []byte{}, 0o644)
	}
	return nil
}
// snapshotNftSetToFile writes the elements of nft set inet agvpn <setName> to
// dst, one per line, and returns how many elements were written. On any error
// the count is 0.
func snapshotNftSetToFile(setName, dst string) (int, error) {
	elems, err := readNftSetElements(setName)
	if err == nil {
		err = writeLinesFile(dst, elems)
	}
	if err != nil {
		return 0, err
	}
	return len(elems), nil
}
// readNftSetElements runs "nft list set inet agvpn <setName>" and returns the
// parsed, sorted, de-duplicated elements. When the command fails and its
// output contains "no such file", "not found" or "does not exist" (presumably
// a missing table/set), it returns (nil, nil) so callers treat it as empty.
func readNftSetElements(setName string) ([]string, error) {
	stdout, stderr, exitCode, runErr := runCommandTimeout(
		8*time.Second, "nft", "list", "set", "inet", "agvpn", setName,
	)
	if runErr == nil && exitCode == 0 {
		return parseNftSetElementsText(stdout), nil
	}
	combined := strings.ToLower(strings.TrimSpace(stdout + " " + stderr))
	for _, marker := range []string{"no such file", "not found", "does not exist"} {
		if strings.Contains(combined, marker) {
			return nil, nil
		}
	}
	if runErr != nil {
		return nil, fmt.Errorf("nft list set %s failed: %w", setName, runErr)
	}
	return nil, fmt.Errorf("nft list set %s failed: %s", setName, strings.TrimSpace(stderr))
}
// parseNftSetElementsText extracts the comma-separated entries between the
// braces of the first "elements = { ... }" clause in `nft list set` output.
// Entries are whitespace-trimmed, de-duplicated and returned sorted. It returns
// nil when there is no "elements =" clause or no opening brace after it, and a
// non-nil empty slice when the clause is present but empty.
func parseNftSetElementsText(raw string) []string {
	_, rest, found := strings.Cut(raw, "elements =")
	if !found {
		return nil
	}
	_, body, found := strings.Cut(rest, "{")
	if !found {
		return nil
	}
	if end := strings.Index(body, "}"); end >= 0 {
		body = body[:end]
	}
	// nft wraps long element lists across lines; flatten before splitting.
	body = strings.NewReplacer("\r", " ", "\n", " ").Replace(body)

	items := make([]string, 0, 1024)
	for _, piece := range strings.Split(body, ",") {
		if v := strings.TrimSpace(piece); v != "" {
			items = append(items, v)
		}
	}
	sort.Strings(items)
	// Sorted input means duplicates are adjacent; compact in place.
	uniq := items[:0]
	for _, v := range items {
		if len(uniq) > 0 && uniq[len(uniq)-1] == v {
			continue
		}
		uniq = append(uniq, v)
	}
	return uniq
}
// loadRoutesClearCacheMeta reads and decodes the routes cache meta JSON file.
// On a read or decode error it returns the zero meta together with that error.
func loadRoutesClearCacheMeta() (routesClearCacheMeta, error) {
	var meta routesClearCacheMeta
	data, err := os.ReadFile(routesCacheMeta)
	if err == nil {
		err = json.Unmarshal(data, &meta)
	}
	if err != nil {
		return routesClearCacheMeta{}, err
	}
	return meta, nil
}
// fileExists reports whether path can be stat'ed and is not a directory.
// Any stat error (including permission errors) yields false.
func fileExists(path string) bool {
	info, err := os.Stat(path)
	return err == nil && !info.IsDir()
}
// ifaceOrDash returns iface unchanged, or "-" when it is empty or whitespace.
func ifaceOrDash(iface string) string {
	if strings.TrimSpace(iface) != "" {
		return iface
	}
	return "-"
}