platform: modularize api/gui, add docs-tests-web foundation, and refresh root config

This commit is contained in:
beckline
2026-03-26 22:40:54 +03:00
parent 0e2d7f61ea
commit 6a56d734c2
562 changed files with 70151 additions and 16423 deletions

View File

@@ -0,0 +1,38 @@
package app
import (
"context"
"errors"
"log"
"net/http"
appbootstrap "selective-vpn-api/app/bootstrap"
"time"
)
func runAPIServerAtAddr(addr string) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
prepareAPIRuntime()
log.Printf("selective-vpn API listening on %s", addr)
if err := appbootstrap.Run(ctx, appbootstrap.Config{
Addr: addr,
ReadHeaderTimeout: 5 * time.Second,
RegisterRoutes: registerAPIRoutes,
WrapHandler: logRequests,
StartWatchers: startWatchers,
}); err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("server error: %v", err)
}
}
func prepareAPIRuntime() {
ensureSeeds()
if err := ensureAppMarksNft(); err != nil {
log.Printf("traffic appmarks nft init warning: %v", err)
}
if err := restoreAppMarksFromState(); err != nil {
log.Printf("traffic appmarks restore warning: %v", err)
}
}

View File

@@ -0,0 +1,92 @@
package app
import (
"net/http"
apiroutes "selective-vpn-api/app/apiroutes"
)
func registerAPIRoutes(mux *http.ServeMux) {
apiroutes.Register(mux, apiroutes.Handlers{
Healthz: handleHealthz,
EventsStream: handleEventsStream,
GetStatus: handleGetStatus,
VPNLoginState: handleVPNLoginState,
SystemdState: handleSystemdState,
RoutesServiceStart: makeRoutesServiceActionHandler("start"),
RoutesServiceStop: makeRoutesServiceActionHandler("stop"),
RoutesServiceRestart: makeRoutesServiceActionHandler("restart"),
RoutesService: handleRoutesService,
RoutesUpdate: handleRoutesUpdate,
RoutesTimer: handleRoutesTimer,
RoutesTimerToggle: handleRoutesTimerToggle,
RoutesRollback: handleRoutesClear,
RoutesClear: handleRoutesClear,
RoutesCacheRestore: handleRoutesCacheRestore,
RoutesPrecheckDebug: handleRoutesPrecheckDebug,
RoutesFixPolicyRoute: handleFixPolicyRoute,
RoutesFixPolicyAlias: handleFixPolicyRoute,
TrafficMode: handleTrafficMode,
TrafficModeTest: handleTrafficModeTest,
TrafficAdvancedReset: handleTrafficAdvancedReset,
TrafficInterfaces: handleTrafficInterfaces,
TrafficCandidates: handleTrafficCandidates,
TrafficAppMarks: handleTrafficAppMarks,
TrafficAppMarksItems: handleTrafficAppMarksItems,
TrafficAppProfiles: handleTrafficAppProfiles,
TrafficAudit: handleTrafficAudit,
TransportClients: handleTransportClients,
TransportClientByID: handleTransportClientByID,
TransportInterfaces: handleTransportInterfaces,
TransportRuntimeObservability: handleTransportRuntimeObservability,
TransportPolicies: handleTransportPolicies,
TransportPoliciesValidate: handleTransportPoliciesValidate,
TransportPoliciesApply: handleTransportPoliciesApply,
TransportPoliciesRollback: handleTransportPoliciesRollback,
TransportConflicts: handleTransportConflicts,
TransportOwnership: handleTransportOwnership,
TransportOwnerLocks: handleTransportOwnerLocks,
TransportOwnerLocksClear: handleTransportOwnerLocksClear,
TransportCapabilities: handleTransportCapabilities,
TransportHealthRefresh: handleTransportHealthRefresh,
TransportNetnsToggle: handleTransportNetnsToggle,
TransportSingBoxProfiles: handleTransportSingBoxProfiles,
TransportSingBoxProfileByID: handleTransportSingBoxProfileByID,
TransportSingBoxFeatures: handleTransportSingBoxFeatures,
EgressIdentityGet: handleEgressIdentityGet,
EgressIdentityRefresh: handleEgressIdentityRefresh,
TraceTailPlain: handleTraceTailPlain,
TraceJSON: handleTraceJSON,
TraceAppend: handleTraceAppend,
DNSUpstreams: handleDNSUpstreams,
DNSUpstreamPool: handleDNSUpstreamPool,
DNSStatus: handleDNSStatus,
DNSModeSet: handleDNSModeSet,
DNSBenchmark: handleDNSBenchmark,
DNSSmartdnsService: handleDNSSmartdnsService,
SmartdnsService: handleSmartdnsService,
SmartdnsRuntime: handleSmartdnsRuntime,
SmartdnsPrewarm: handleSmartdnsPrewarm,
SmartdnsWildcards: handleSmartdnsWildcards,
DomainsTable: handleDomainsTable,
DomainsFile: handleDomainsFile,
VPNAutoloopStatus: handleVPNAutoloopStatus,
VPNStatus: handleVPNStatus,
VPNAutoconnect: handleVPNAutoconnect,
VPNListLocations: handleVPNListLocations,
VPNSetLocation: handleVPNSetLocation,
VPNLoginSessionStart: handleVPNLoginSessionStart,
VPNLoginSessionState: handleVPNLoginSessionState,
VPNLoginSessionAction: handleVPNLoginSessionAction,
VPNLoginSessionStop: handleVPNLoginSessionStop,
VPNLogout: handleVPNLogout,
})
}

View File

@@ -0,0 +1,198 @@
package apiroutes
import "net/http"
type Handlers struct {
Healthz http.HandlerFunc
EventsStream http.HandlerFunc
GetStatus http.HandlerFunc
VPNLoginState http.HandlerFunc
SystemdState http.HandlerFunc
RoutesServiceStart http.HandlerFunc
RoutesServiceStop http.HandlerFunc
RoutesServiceRestart http.HandlerFunc
RoutesService http.HandlerFunc
RoutesUpdate http.HandlerFunc
RoutesTimer http.HandlerFunc
RoutesTimerToggle http.HandlerFunc
RoutesRollback http.HandlerFunc
RoutesClear http.HandlerFunc
RoutesCacheRestore http.HandlerFunc
RoutesPrecheckDebug http.HandlerFunc
RoutesFixPolicyRoute http.HandlerFunc
RoutesFixPolicyAlias http.HandlerFunc
TrafficMode http.HandlerFunc
TrafficModeTest http.HandlerFunc
TrafficAdvancedReset http.HandlerFunc
TrafficInterfaces http.HandlerFunc
TrafficCandidates http.HandlerFunc
TrafficAppMarks http.HandlerFunc
TrafficAppMarksItems http.HandlerFunc
TrafficAppProfiles http.HandlerFunc
TrafficAudit http.HandlerFunc
TransportClients http.HandlerFunc
TransportClientByID http.HandlerFunc
TransportInterfaces http.HandlerFunc
TransportRuntimeObservability http.HandlerFunc
TransportPolicies http.HandlerFunc
TransportPoliciesValidate http.HandlerFunc
TransportPoliciesApply http.HandlerFunc
TransportPoliciesRollback http.HandlerFunc
TransportConflicts http.HandlerFunc
TransportOwnership http.HandlerFunc
TransportOwnerLocks http.HandlerFunc
TransportOwnerLocksClear http.HandlerFunc
TransportCapabilities http.HandlerFunc
TransportHealthRefresh http.HandlerFunc
TransportNetnsToggle http.HandlerFunc
TransportSingBoxProfiles http.HandlerFunc
TransportSingBoxProfileByID http.HandlerFunc
TransportSingBoxFeatures http.HandlerFunc
EgressIdentityGet http.HandlerFunc
EgressIdentityRefresh http.HandlerFunc
TraceTailPlain http.HandlerFunc
TraceJSON http.HandlerFunc
TraceAppend http.HandlerFunc
DNSUpstreams http.HandlerFunc
DNSUpstreamPool http.HandlerFunc
DNSStatus http.HandlerFunc
DNSModeSet http.HandlerFunc
DNSBenchmark http.HandlerFunc
DNSSmartdnsService http.HandlerFunc
SmartdnsService http.HandlerFunc
SmartdnsRuntime http.HandlerFunc
SmartdnsPrewarm http.HandlerFunc
SmartdnsWildcards http.HandlerFunc
DomainsTable http.HandlerFunc
DomainsFile http.HandlerFunc
VPNAutoloopStatus http.HandlerFunc
VPNStatus http.HandlerFunc
VPNAutoconnect http.HandlerFunc
VPNListLocations http.HandlerFunc
VPNSetLocation http.HandlerFunc
VPNLoginSessionStart http.HandlerFunc
VPNLoginSessionState http.HandlerFunc
VPNLoginSessionAction http.HandlerFunc
VPNLoginSessionStop http.HandlerFunc
VPNLogout http.HandlerFunc
}
func Register(mux *http.ServeMux, h Handlers) {
registerCoreRoutes(mux, h)
registerRoutesControlRoutes(mux, h)
registerTrafficRoutes(mux, h)
registerTransportRoutes(mux, h)
registerTraceRoutes(mux, h)
registerDNSRoutes(mux, h)
registerSmartDNSRoutes(mux, h)
registerDomainsRoutes(mux, h)
registerVPNRoutes(mux, h)
}
func registerCoreRoutes(mux *http.ServeMux, h Handlers) {
mux.HandleFunc("/healthz", h.Healthz)
mux.HandleFunc("/api/v1/events/stream", h.EventsStream)
mux.HandleFunc("/api/v1/status", h.GetStatus)
mux.HandleFunc("/api/v1/routes/status", h.GetStatus)
mux.HandleFunc("/api/v1/vpn/login-state", h.VPNLoginState)
mux.HandleFunc("/api/v1/systemd/state", h.SystemdState)
}
func registerRoutesControlRoutes(mux *http.ServeMux, h Handlers) {
mux.HandleFunc("/api/v1/routes/service/start", h.RoutesServiceStart)
mux.HandleFunc("/api/v1/routes/service/stop", h.RoutesServiceStop)
mux.HandleFunc("/api/v1/routes/service/restart", h.RoutesServiceRestart)
mux.HandleFunc("/api/v1/routes/service", h.RoutesService)
mux.HandleFunc("/api/v1/routes/update", h.RoutesUpdate)
mux.HandleFunc("/api/v1/routes/timer", h.RoutesTimer)
mux.HandleFunc("/api/v1/routes/timer/toggle", h.RoutesTimerToggle)
mux.HandleFunc("/api/v1/routes/rollback", h.RoutesRollback)
mux.HandleFunc("/api/v1/routes/clear", h.RoutesClear)
mux.HandleFunc("/api/v1/routes/cache/restore", h.RoutesCacheRestore)
mux.HandleFunc("/api/v1/routes/precheck/debug", h.RoutesPrecheckDebug)
mux.HandleFunc("/api/v1/routes/fix-policy-route", h.RoutesFixPolicyRoute)
mux.HandleFunc("/api/v1/routes/fix-policy", h.RoutesFixPolicyAlias)
}
func registerTrafficRoutes(mux *http.ServeMux, h Handlers) {
mux.HandleFunc("/api/v1/traffic/mode", h.TrafficMode)
mux.HandleFunc("/api/v1/traffic/mode/test", h.TrafficModeTest)
mux.HandleFunc("/api/v1/traffic/advanced/reset", h.TrafficAdvancedReset)
mux.HandleFunc("/api/v1/traffic/interfaces", h.TrafficInterfaces)
mux.HandleFunc("/api/v1/traffic/candidates", h.TrafficCandidates)
mux.HandleFunc("/api/v1/traffic/appmarks", h.TrafficAppMarks)
mux.HandleFunc("/api/v1/traffic/appmarks/items", h.TrafficAppMarksItems)
mux.HandleFunc("/api/v1/traffic/app-profiles", h.TrafficAppProfiles)
mux.HandleFunc("/api/v1/traffic/audit", h.TrafficAudit)
}
func registerTransportRoutes(mux *http.ServeMux, h Handlers) {
mux.HandleFunc("/api/v1/transport/clients", h.TransportClients)
mux.HandleFunc("/api/v1/transport/clients/", h.TransportClientByID)
mux.HandleFunc("/api/v1/transport/interfaces", h.TransportInterfaces)
mux.HandleFunc("/api/v1/transport/runtime/observability", h.TransportRuntimeObservability)
mux.HandleFunc("/api/v1/transport/policies", h.TransportPolicies)
mux.HandleFunc("/api/v1/transport/policies/validate", h.TransportPoliciesValidate)
mux.HandleFunc("/api/v1/transport/policies/apply", h.TransportPoliciesApply)
mux.HandleFunc("/api/v1/transport/policies/rollback", h.TransportPoliciesRollback)
mux.HandleFunc("/api/v1/transport/conflicts", h.TransportConflicts)
mux.HandleFunc("/api/v1/transport/owners", h.TransportOwnership)
mux.HandleFunc("/api/v1/transport/owner-locks", h.TransportOwnerLocks)
mux.HandleFunc("/api/v1/transport/owner-locks/clear", h.TransportOwnerLocksClear)
mux.HandleFunc("/api/v1/transport/capabilities", h.TransportCapabilities)
mux.HandleFunc("/api/v1/transport/health/refresh", h.TransportHealthRefresh)
mux.HandleFunc("/api/v1/transport/netns/toggle", h.TransportNetnsToggle)
mux.HandleFunc("/api/v1/transport/singbox/profiles", h.TransportSingBoxProfiles)
mux.HandleFunc("/api/v1/transport/singbox/profiles/", h.TransportSingBoxProfileByID)
mux.HandleFunc("/api/v1/transport/singbox/features", h.TransportSingBoxFeatures)
mux.HandleFunc("/api/v1/egress/identity", h.EgressIdentityGet)
mux.HandleFunc("/api/v1/egress/identity/refresh", h.EgressIdentityRefresh)
}
func registerTraceRoutes(mux *http.ServeMux, h Handlers) {
mux.HandleFunc("/api/v1/trace", h.TraceTailPlain)
mux.HandleFunc("/api/v1/trace-json", h.TraceJSON)
mux.HandleFunc("/api/v1/trace/append", h.TraceAppend)
}
func registerDNSRoutes(mux *http.ServeMux, h Handlers) {
mux.HandleFunc("/api/v1/dns-upstreams", h.DNSUpstreams)
mux.HandleFunc("/api/v1/dns/upstream-pool", h.DNSUpstreamPool)
mux.HandleFunc("/api/v1/dns/status", h.DNSStatus)
mux.HandleFunc("/api/v1/dns/mode", h.DNSModeSet)
mux.HandleFunc("/api/v1/dns/benchmark", h.DNSBenchmark)
mux.HandleFunc("/api/v1/dns/smartdns-service", h.DNSSmartdnsService)
}
func registerSmartDNSRoutes(mux *http.ServeMux, h Handlers) {
mux.HandleFunc("/api/v1/smartdns/service", h.SmartdnsService)
mux.HandleFunc("/api/v1/smartdns/runtime", h.SmartdnsRuntime)
mux.HandleFunc("/api/v1/smartdns/prewarm", h.SmartdnsPrewarm)
mux.HandleFunc("/api/v1/smartdns/wildcards", h.SmartdnsWildcards)
}
func registerDomainsRoutes(mux *http.ServeMux, h Handlers) {
mux.HandleFunc("/api/v1/domains/table", h.DomainsTable)
mux.HandleFunc("/api/v1/domains/file", h.DomainsFile)
}
func registerVPNRoutes(mux *http.ServeMux, h Handlers) {
mux.HandleFunc("/api/v1/vpn/autoloop-status", h.VPNAutoloopStatus)
mux.HandleFunc("/api/v1/vpn/status", h.VPNStatus)
mux.HandleFunc("/api/v1/vpn/autoconnect", h.VPNAutoconnect)
mux.HandleFunc("/api/v1/vpn/locations", h.VPNListLocations)
mux.HandleFunc("/api/v1/vpn/location", h.VPNSetLocation)
mux.HandleFunc("/api/v1/vpn/login/session/start", h.VPNLoginSessionStart)
mux.HandleFunc("/api/v1/vpn/login/session/state", h.VPNLoginSessionState)
mux.HandleFunc("/api/v1/vpn/login/session/action", h.VPNLoginSessionAction)
mux.HandleFunc("/api/v1/vpn/login/session/stop", h.VPNLoginSessionStop)
mux.HandleFunc("/api/v1/vpn/logout", h.VPNLogout)
}

View File

@@ -4,7 +4,6 @@ import (
"fmt"
"log"
"os"
"regexp"
"strings"
"time"
)
@@ -44,28 +43,7 @@ func runAutoloop(iface, table string, mtu int, stateDirPath, defaultLoc string)
}
writeLoginState := func(state, email, msg string) {
ts := time.Now().Format(time.RFC3339)
payload := fmt.Sprintf(`{"ts":"%s","state":"%s","email":"%s","msg":"%s"}`, ts, escapeJSON(state), escapeJSON(email), escapeJSON(msg))
_ = os.WriteFile(loginStateFile, []byte(payload), 0o644)
}
getLocation := func() string {
if data, err := os.ReadFile(locFile); err == nil {
for _, ln := range strings.Split(string(data), "\n") {
t := strings.TrimSpace(ln)
if t != "" && !strings.HasPrefix(t, "#") {
return t
}
}
}
return defaultLoc
}
isConnected := func(out string) bool {
low := strings.ToLower(out)
return strings.Contains(low, "vpn is connected") ||
strings.Contains(low, "connected to") ||
strings.Contains(low, "after connect: connected")
writeAutoloopLoginState(loginStateFile, state, email, msg)
}
fixPolicy := func() {
@@ -83,45 +61,9 @@ func runAutoloop(iface, table string, mtu int, stateDirPath, defaultLoc string)
" mtu " + fmt.Sprintf("%d", mtu) + " OK")
}
}
var emailRe = regexp.MustCompile(`[A-Za-z0-9._%+\-]+@[A-Za-z0-9.\-]+`)
parseEmail := func(text string) string {
return emailRe.FindString(text)
}
isLoginRequired := func(t string) bool {
low := strings.ToLower(t)
return strings.Contains(low, "please log in") ||
strings.Contains(low, "not logged in") ||
strings.Contains(low, "login required") ||
strings.Contains(low, "sign in")
}
updateLoginStateFromText := func(text string) {
if isLoginRequired(text) {
writeLoginState("no_login", "", "NOT LOGGED IN")
logLine("login: NO (detected from output)")
return
}
if em := parseEmail(text); em != "" {
writeLoginState("ok", em, "logged in")
logLine("login: OK email=" + em)
return
}
low := strings.ToLower(text)
if strings.Contains(low, "not logged in") ||
strings.Contains(low, "expired") ||
strings.Contains(low, "no active license") {
writeLoginState("no_login", "", "NOT LOGGED IN (license)")
logLine("login: NO (license says not logged in)")
return
}
if strings.Contains(low, "license") &&
(strings.Contains(low, "active") || strings.Contains(low, "valid")) {
writeLoginState("ok", "", "logged in (license ok)")
logLine("login: OK (license ok, email not found)")
return
}
updateAutoloopLoginStateFromText(text, writeLoginState, logLine)
}
updateLicense := func() {
@@ -144,7 +86,7 @@ func runAutoloop(iface, table string, mtu int, stateDirPath, defaultLoc string)
if err != nil {
logLine(fmt.Sprintf("status: ERROR exit=%d err=%v raw=%q", exitCode, err, statusOut))
}
if isConnected(statusOut) {
if isAutoloopConnected(statusOut) {
logLine("status: CONNECTED; raw: " + statusOut)
fixPolicy()
updateLicense()
@@ -163,18 +105,30 @@ func runAutoloop(iface, table string, mtu int, stateDirPath, defaultLoc string)
})
updateLoginStateFromText(statusOut)
loc := getLocation()
logLine("reconnecting to " + loc)
loc := resolveAutoloopLocationSpec(locFile, defaultLoc)
primary := strings.TrimSpace(loc.Primary)
if primary == "" {
primary = strings.TrimSpace(defaultLoc)
}
logLine("reconnecting to " + primary)
_, _, _, _ = runCommandTimeout(disconnectTimeout, adgvpnCLI, "disconnect")
connectOut, _, _, _ := runCommandTimeout(connectTimeout, adgvpnCLI, "connect", "-l", loc, "--log-to-file")
connectOut, _, _, _ := runCommandTimeout(connectTimeout, adgvpnCLI, "connect", "-l", primary, "--log-to-file")
connectOut = stripANSI(connectOut)
logLine("connect raw: " + connectOut)
updateLoginStateFromText(connectOut)
if !isAutoloopConnected(connectOut) && loc.ISO != "" && !strings.EqualFold(loc.ISO, primary) {
logLine("connect fallback to ISO: " + loc.ISO)
fallbackOut, _, _, _ := runCommandTimeout(connectTimeout, adgvpnCLI, "connect", "-l", loc.ISO, "--log-to-file")
fallbackOut = stripANSI(fallbackOut)
logLine("connect fallback raw: " + fallbackOut)
updateLoginStateFromText(fallbackOut)
}
statusAfter, _, _, _ := runCommandTimeout(statusTimeout, adgvpnCLI, "status")
statusAfter = stripANSI(statusAfter)
if isConnected(statusAfter) {
if isAutoloopConnected(statusAfter) {
logLine("after connect: CONNECTED; raw: " + statusAfter)
fixPolicy()
updateLicense()
@@ -190,15 +144,3 @@ func runAutoloop(iface, table string, mtu int, stateDirPath, defaultLoc string)
time.Sleep(10 * time.Second)
}
}
// ---------------------------------------------------------------------
// autoloop helpers
// ---------------------------------------------------------------------
func escapeJSON(s string) string {
s = strings.ReplaceAll(s, `\`, `\\`)
s = strings.ReplaceAll(s, `"`, `\\"`)
s = strings.ReplaceAll(s, "\n", "\\n")
s = strings.ReplaceAll(s, "\r", "")
return s
}

View File

@@ -0,0 +1,12 @@
package app
import (
"regexp"
)
type autoloopLocationSpec struct {
Primary string
ISO string
}
var autoloopEmailRe = regexp.MustCompile(`[A-Za-z0-9._%+\-]+@[A-Za-z0-9.\-]+`)

View File

@@ -0,0 +1,107 @@
package app
import (
"encoding/json"
"os"
"strings"
)
func resolveAutoloopLocationSpec(locFile, defaultLoc string) autoloopLocationSpec {
raw := ""
if data, err := os.ReadFile(locFile); err == nil {
for _, ln := range strings.Split(string(data), "\n") {
t := strings.TrimSpace(ln)
if t != "" && !strings.HasPrefix(t, "#") {
raw = t
break
}
}
}
if raw == "" {
raw = strings.TrimSpace(defaultLoc)
}
raw = strings.TrimSpace(raw)
primary := raw
iso := ""
if p := strings.SplitN(raw, "|", 2); len(p) == 2 {
primary = strings.TrimSpace(p[0])
iso = strings.ToUpper(strings.TrimSpace(p[1]))
}
if primary == "" {
primary = strings.TrimSpace(defaultLoc)
}
if isISO2(primary) {
iso = strings.ToUpper(primary)
}
if iso == "" {
if tokens := strings.Fields(primary); len(tokens) > 0 && isISO2(tokens[0]) {
iso = strings.ToUpper(tokens[0])
}
}
if iso == "" {
iso = lookupISOFromLocationsCache(primary)
}
if iso == "" && isISO2(defaultLoc) {
iso = strings.ToUpper(strings.TrimSpace(defaultLoc))
}
return autoloopLocationSpec{
Primary: primary,
ISO: iso,
}
}
func isISO2(v string) bool {
s := strings.TrimSpace(v)
if len(s) != 2 {
return false
}
for _, ch := range s {
if (ch < 'A' || ch > 'Z') && (ch < 'a' || ch > 'z') {
return false
}
}
return true
}
func lookupISOFromLocationsCache(primary string) string {
want := strings.ToLower(strings.TrimSpace(primary))
if want == "" {
return ""
}
var disk struct {
Locations []struct {
ISO string `json:"iso"`
Label string `json:"label"`
Target string `json:"target"`
} `json:"locations"`
}
data, err := os.ReadFile(vpnLocationsCachePath)
if err != nil {
return ""
}
if err := json.Unmarshal(data, &disk); err != nil {
return ""
}
for _, it := range disk.Locations {
iso := strings.ToUpper(strings.TrimSpace(it.ISO))
if !isISO2(iso) {
continue
}
target := strings.ToLower(strings.TrimSpace(it.Target))
if target != "" && target == want {
return iso
}
labelNorm := strings.ToLower(strings.TrimSpace(it.Label))
if labelNorm != "" {
if inferVPNLocationTargetFromLabel(it.Label, iso) == primary {
return iso
}
if strings.Contains(labelNorm, want) {
return iso
}
}
}
return ""
}

View File

@@ -0,0 +1,89 @@
package app
import (
"fmt"
"os"
"strings"
"time"
)
func writeAutoloopLoginState(loginStateFile, state, email, msg string) {
ts := time.Now().Format(time.RFC3339)
payload := fmt.Sprintf(`{"ts":"%s","state":"%s","email":"%s","msg":"%s"}`,
ts,
escapeJSON(state),
escapeJSON(email),
escapeJSON(msg),
)
_ = os.WriteFile(loginStateFile, []byte(payload), 0o644)
}
func isAutoloopConnected(out string) bool {
low := strings.ToLower(out)
return strings.Contains(low, "vpn is connected") ||
strings.Contains(low, "connected to") ||
strings.Contains(low, "after connect: connected")
}
func parseAutoloopEmail(text string) string {
return autoloopEmailRe.FindString(text)
}
func isAutoloopLoginRequired(text string) bool {
low := strings.ToLower(text)
return strings.Contains(low, "please log in") ||
strings.Contains(low, "not logged in") ||
strings.Contains(low, "login required") ||
strings.Contains(low, "sign in")
}
func updateAutoloopLoginStateFromText(
text string,
writeState func(state, email, msg string),
logLine func(msg string),
) {
if writeState == nil {
return
}
if isAutoloopLoginRequired(text) {
writeState("no_login", "", "NOT LOGGED IN")
if logLine != nil {
logLine("login: NO (detected from output)")
}
return
}
if em := parseAutoloopEmail(text); em != "" {
writeState("ok", em, "logged in")
if logLine != nil {
logLine("login: OK email=" + em)
}
return
}
low := strings.ToLower(text)
if strings.Contains(low, "not logged in") ||
strings.Contains(low, "expired") ||
strings.Contains(low, "no active license") {
writeState("no_login", "", "NOT LOGGED IN (license)")
if logLine != nil {
logLine("login: NO (license says not logged in)")
}
return
}
if strings.Contains(low, "license") &&
(strings.Contains(low, "active") || strings.Contains(low, "valid")) {
writeState("ok", "", "logged in (license ok)")
if logLine != nil {
logLine("login: OK (license ok, email not found)")
}
return
}
}
func escapeJSON(s string) string {
s = strings.ReplaceAll(s, `\`, `\\`)
s = strings.ReplaceAll(s, `"`, `\\"`)
s = strings.ReplaceAll(s, "\n", "\\n")
s = strings.ReplaceAll(s, "\r", "")
return s
}

View File

@@ -0,0 +1,54 @@
package bootstrap
import (
"context"
"errors"
"fmt"
"net/http"
"time"
)
type Config struct {
Addr string
ReadHeaderTimeout time.Duration
RegisterRoutes func(mux *http.ServeMux)
WrapHandler func(next http.Handler) http.Handler
StartWatchers func(ctx context.Context)
}
func Run(ctx context.Context, cfg Config) error {
if cfg.RegisterRoutes == nil {
return fmt.Errorf("register routes callback is required")
}
if cfg.Addr == "" {
return fmt.Errorf("addr is required")
}
readHeaderTimeout := cfg.ReadHeaderTimeout
if readHeaderTimeout <= 0 {
readHeaderTimeout = 5 * time.Second
}
mux := http.NewServeMux()
cfg.RegisterRoutes(mux)
handler := http.Handler(mux)
if cfg.WrapHandler != nil {
handler = cfg.WrapHandler(handler)
}
srv := &http.Server{
Addr: cfg.Addr,
Handler: handler,
ReadHeaderTimeout: readHeaderTimeout,
}
if cfg.StartWatchers != nil {
go cfg.StartWatchers(ctx)
}
if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
}
return nil
}

View File

@@ -0,0 +1,59 @@
package cli
import (
"flag"
"fmt"
"io"
"os"
)
type AutoloopParams struct {
Iface string
Table string
MTU int
StateDir string
DefaultLocation string
}
type AutoloopDeps struct {
StateDirDefault string
ResolveIface func(flagIface string) string
Run func(params AutoloopParams)
Stderr io.Writer
}
func RunAutoloop(args []string, deps AutoloopDeps) int {
if deps.ResolveIface == nil || deps.Run == nil {
return 1
}
stderr := deps.Stderr
if stderr == nil {
stderr = os.Stderr
}
fs := flag.NewFlagSet("autoloop", flag.ContinueOnError)
fs.SetOutput(stderr)
iface := fs.String("iface", "", "VPN interface (empty/auto = detect active)")
table := fs.String("table", "agvpn", "routing table name")
mtu := fs.Int("mtu", 1380, "MTU for default route")
stateDir := fs.String("state-dir", deps.StateDirDefault, "state directory")
defaultLoc := fs.String("default-location", "Austria", "default location")
if err := fs.Parse(args); err != nil {
return 2
}
resolvedIface := deps.ResolveIface(*iface)
if resolvedIface == "" {
fmt.Fprintln(stderr, "autoloop: cannot resolve VPN interface (set --iface or preferred iface)")
return 1
}
deps.Run(AutoloopParams{
Iface: resolvedIface,
Table: *table,
MTU: *mtu,
StateDir: *stateDir,
DefaultLocation: *defaultLoc,
})
return 0
}

View File

@@ -0,0 +1,47 @@
package cli
import (
"flag"
"fmt"
"io"
"os"
"strings"
)
type RoutesClearDeps struct {
Clear func() (ok bool, message string)
Stdout io.Writer
Stderr io.Writer
}
func RunRoutesClear(args []string, deps RoutesClearDeps) int {
if deps.Clear == nil {
return 1
}
stdout := deps.Stdout
if stdout == nil {
stdout = os.Stdout
}
stderr := deps.Stderr
if stderr == nil {
stderr = os.Stderr
}
fs := flag.NewFlagSet("routes-clear", flag.ContinueOnError)
fs.SetOutput(stderr)
if err := fs.Parse(args); err != nil {
return 2
}
ok, message := deps.Clear()
if ok {
fmt.Fprintln(stdout, strings.TrimSpace(message))
return 0
}
msg := strings.TrimSpace(message)
if msg == "" {
msg = "routes clear failed"
}
fmt.Fprintln(stderr, msg)
return 1
}

View File

@@ -0,0 +1,65 @@
package cli
import (
"flag"
"fmt"
"io"
"os"
"strings"
"syscall"
)
type RoutesUpdateDeps struct {
LockFile string
Update func(iface string) (ok bool, message string)
Stdout io.Writer
Stderr io.Writer
}
func RunRoutesUpdate(args []string, deps RoutesUpdateDeps) int {
if deps.Update == nil || deps.LockFile == "" {
return 1
}
stdout := deps.Stdout
if stdout == nil {
stdout = os.Stdout
}
stderr := deps.Stderr
if stderr == nil {
stderr = os.Stderr
}
fs := flag.NewFlagSet("routes-update", flag.ContinueOnError)
fs.SetOutput(stderr)
iface := fs.String("iface", "", "VPN interface (empty/auto = detect active)")
if err := fs.Parse(args); err != nil {
return 2
}
lock, err := os.OpenFile(deps.LockFile, os.O_CREATE|os.O_RDWR, 0o644)
if err != nil {
fmt.Fprintf(stderr, "lock open error: %v\n", err)
return 1
}
defer lock.Close()
if err := syscall.Flock(int(lock.Fd()), syscall.LOCK_EX|syscall.LOCK_NB); err != nil {
fmt.Fprintln(stdout, "routes update already running")
return 0
}
defer func() {
_ = syscall.Flock(int(lock.Fd()), syscall.LOCK_UN)
}()
ok, message := deps.Update(*iface)
if ok {
fmt.Fprintln(stdout, strings.TrimSpace(message))
return 0
}
msg := strings.TrimSpace(message)
if msg == "" {
msg = "routes update failed"
}
fmt.Fprintln(stderr, msg)
return 1
}

View File

@@ -12,28 +12,40 @@ import "embed"
// ---------------------------------------------------------------------
const (
stateDir = "/var/lib/selective-vpn"
statusFilePath = stateDir + "/status.json"
dnsModePath = stateDir + "/dns-mode.json"
trafficModePath = stateDir + "/traffic-mode.json"
trafficAppMarksPath = stateDir + "/traffic-appmarks.json"
trafficAppProfilesPath = stateDir + "/traffic-app-profiles.json"
stateDir = "/var/lib/selective-vpn"
statusFilePath = stateDir + "/status.json"
dnsModePath = stateDir + "/dns-mode.json"
trafficModePath = stateDir + "/traffic-mode.json"
trafficAppMarksPath = stateDir + "/traffic-appmarks.json"
trafficAppProfilesPath = stateDir + "/traffic-app-profiles.json"
transportClientsPath = stateDir + "/transport-clients.json"
transportInterfacesPath = stateDir + "/transport-interfaces.json"
transportPolicyPath = stateDir + "/transport-policies.json"
transportPolicyPlanPath = stateDir + "/transport-policies.plan.json"
transportPolicyRuntimePath = stateDir + "/transport-policies.runtime.json"
transportPolicyRuntimeSnap = stateDir + "/transport-policies.runtime.prev.json"
transportOwnershipPath = stateDir + "/transport-ownership.json"
transportOwnerLocksPath = stateDir + "/transport-owner-locks.json"
transportConflictsPath = stateDir + "/transport-conflicts.json"
transportPolicySnap = stateDir + "/transport-policies.prev.json"
transportBootstrapPath = stateDir + "/transport-bootstrap-routes.json"
traceLogPath = stateDir + "/trace.log"
smartdnsLogPath = stateDir + "/smartdns.log"
lastIPsPath = stateDir + "/last-ips.txt"
lastIPsMapPath = stateDir + "/last-ips-map.txt"
lastIPsDirect = stateDir + "/last-ips-direct.txt"
lastIPsDyn = stateDir + "/last-ips-dyn.txt"
lastIPsMapDirect = stateDir + "/last-ips-map-direct.txt"
lastIPsMapDyn = stateDir + "/last-ips-map-wildcard.txt"
routesCacheMeta = stateDir + "/routes-clear-cache.json"
routesCacheIPs = stateDir + "/routes-clear-cache-ips.txt"
routesCacheDyn = stateDir + "/routes-clear-cache-ips-dyn.txt"
routesCacheMap = stateDir + "/routes-clear-cache-ips-map.txt"
routesCacheMapD = stateDir + "/routes-clear-cache-ips-map-direct.txt"
routesCacheMapW = stateDir + "/routes-clear-cache-ips-map-wildcard.txt"
routesCacheRT = stateDir + "/routes-clear-cache-routes.txt"
traceLogPath = stateDir + "/trace.log"
smartdnsLogPath = stateDir + "/smartdns.log"
lastIPsPath = stateDir + "/last-ips.txt"
lastIPsMapPath = stateDir + "/last-ips-map.txt"
lastIPsDirect = stateDir + "/last-ips-direct.txt"
lastIPsDyn = stateDir + "/last-ips-dyn.txt"
lastIPsMapDirect = stateDir + "/last-ips-map-direct.txt"
lastIPsMapDyn = stateDir + "/last-ips-map-wildcard.txt"
routesCacheMeta = stateDir + "/routes-clear-cache.json"
routesCacheIPs = stateDir + "/routes-clear-cache-ips.txt"
routesCacheDyn = stateDir + "/routes-clear-cache-ips-dyn.txt"
routesCacheMap = stateDir + "/routes-clear-cache-ips-map.txt"
routesCacheMapD = stateDir + "/routes-clear-cache-ips-map-direct.txt"
routesCacheMapW = stateDir + "/routes-clear-cache-ips-map-wildcard.txt"
routesCacheRT = stateDir + "/routes-clear-cache-routes.txt"
precheckForcePath = stateDir + "/precheck-force.once"
autoloopLogPath = stateDir + "/adguard-autoloop.log"
loginStatePath = stateDir + "/adguard-login.json"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,81 @@
package app
import (
dnscfgpkg "selective-vpn-api/app/dnscfg"
"time"
)
var dnsBenchmarkDefaultDomains = append([]string(nil), dnscfgpkg.BenchmarkDefaultDomains...)
const (
dnsBenchmarkProfileQuick = "quick"
dnsBenchmarkProfileLoad = "load"
)
type dnsBenchmarkOptions = dnscfgpkg.BenchmarkOptions
func normalizeBenchmarkUpstreams(in []DNSBenchmarkUpstream) []string {
if len(in) == 0 {
return nil
}
raw := make([]string, 0, len(in))
for _, item := range in {
raw = append(raw, item.Addr)
}
return dnscfgpkg.NormalizeBenchmarkUpstreamStrings(raw, normalizeDNSUpstream)
}
func benchmarkDNSUpstream(upstream string, domains []string, timeout time.Duration, attempts int, opts dnsBenchmarkOptions) DNSBenchmarkResult {
classify := func(err error) string {
switch classifyDNSError(err) {
case dnsErrorNXDomain:
return dnscfgpkg.BenchmarkErrorNXDomain
case dnsErrorTimeout:
return dnscfgpkg.BenchmarkErrorTimeout
case dnsErrorTemporary:
return dnscfgpkg.BenchmarkErrorTemporary
default:
return dnscfgpkg.BenchmarkErrorOther
}
}
pkgRes := dnscfgpkg.BenchmarkDNSUpstream(upstream, domains, timeout, attempts, opts, dnsLookupAOnce, classify)
return DNSBenchmarkResult{
Upstream: pkgRes.Upstream,
Attempts: pkgRes.Attempts,
OK: pkgRes.OK,
Fail: pkgRes.Fail,
NXDomain: pkgRes.NXDomain,
Timeout: pkgRes.Timeout,
Temporary: pkgRes.Temporary,
Other: pkgRes.Other,
AvgMS: pkgRes.AvgMS,
P95MS: pkgRes.P95MS,
Score: pkgRes.Score,
Color: pkgRes.Color,
}
}
func dnsLookupAOnce(host string, upstream string, timeout time.Duration) ([]string, error) {
return dnscfgpkg.DNSLookupAOnce(host, upstream, timeout, splitDNS, isPrivateIPv4)
}
func benchmarkTopN(results []DNSBenchmarkResult, n int, fallback []string) []string {
pkgResults := make([]dnscfgpkg.BenchmarkResult, 0, len(results))
for _, item := range results {
pkgResults = append(pkgResults, dnscfgpkg.BenchmarkResult{
Upstream: item.Upstream,
Attempts: item.Attempts,
OK: item.OK,
Fail: item.Fail,
NXDomain: item.NXDomain,
Timeout: item.Timeout,
Temporary: item.Temporary,
Other: item.Other,
AvgMS: item.AvgMS,
P95MS: item.P95MS,
Score: item.Score,
Color: item.Color,
})
}
return dnscfgpkg.BenchmarkTopN(pkgResults, n, fallback)
}

View File

@@ -0,0 +1,140 @@
package app
import (
"encoding/json"
"io"
"net/http"
dnscfgpkg "selective-vpn-api/app/dnscfg"
"sort"
"sync"
"time"
)
func handleDNSBenchmark(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var req DNSBenchmarkRequest
if r.Body != nil {
defer r.Body.Close()
if err := json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&req); err != nil && err != io.EOF {
http.Error(w, "bad json", http.StatusBadRequest)
return
}
}
upstreams := normalizeBenchmarkUpstreams(req.Upstreams)
if len(upstreams) == 0 {
pool := loadDNSUpstreamPoolState()
if len(pool) > 0 {
tmp := make([]DNSBenchmarkUpstream, 0, len(pool))
for _, item := range pool {
tmp = append(tmp, DNSBenchmarkUpstream{Addr: item.Addr, Enabled: item.Enabled})
}
upstreams = normalizeBenchmarkUpstreams(tmp)
}
if len(upstreams) == 0 {
cfg := loadDNSUpstreamsConf()
upstreams = dnscfgpkg.NormalizeBenchmarkUpstreamStrings([]string{
cfg.Default1,
cfg.Default2,
cfg.Meta1,
cfg.Meta2,
}, normalizeDNSUpstream)
}
}
if len(upstreams) == 0 {
http.Error(w, "no upstreams", http.StatusBadRequest)
return
}
domains := dnscfgpkg.NormalizeBenchmarkDomains(req.Domains)
if len(domains) == 0 {
domains = append(domains, dnsBenchmarkDefaultDomains...)
}
timeoutMS := req.TimeoutMS
if timeoutMS <= 0 {
timeoutMS = 1800
}
if timeoutMS < 300 {
timeoutMS = 300
}
if timeoutMS > 5000 {
timeoutMS = 5000
}
attempts := req.Attempts
if attempts <= 0 {
attempts = 1
}
if attempts > 3 {
attempts = 3
}
profile := dnscfgpkg.NormalizeBenchmarkProfile(req.Profile)
if profile == dnsBenchmarkProfileLoad && attempts < 2 {
// Load profile should emulate real resolver pressure.
attempts = 2
}
concurrency := req.Concurrency
if concurrency <= 0 {
concurrency = 6
}
if concurrency < 1 {
concurrency = 1
}
if concurrency > 32 {
concurrency = 32
}
if concurrency > len(upstreams) {
concurrency = len(upstreams)
}
opts := dnscfgpkg.MakeDNSBenchmarkOptions(profile, concurrency)
results := make([]DNSBenchmarkResult, 0, len(upstreams))
var mu sync.Mutex
var wg sync.WaitGroup
sem := make(chan struct{}, concurrency)
timeout := time.Duration(timeoutMS) * time.Millisecond
for _, upstream := range upstreams {
wg.Add(1)
sem <- struct{}{}
go func(upstream string) {
defer wg.Done()
defer func() { <-sem }()
result := benchmarkDNSUpstream(upstream, domains, timeout, attempts, opts)
mu.Lock()
results = append(results, result)
mu.Unlock()
}(upstream)
}
wg.Wait()
sort.Slice(results, func(i, j int) bool {
if results[i].Score == results[j].Score {
if results[i].AvgMS == results[j].AvgMS {
if results[i].OK == results[j].OK {
return results[i].Upstream < results[j].Upstream
}
return results[i].OK > results[j].OK
}
return results[i].AvgMS < results[j].AvgMS
}
return results[i].Score > results[j].Score
})
resp := DNSBenchmarkResponse{
Results: results,
DomainsUsed: domains,
TimeoutMS: timeoutMS,
AttemptsPerDomain: attempts,
Profile: profile,
}
resp.RecommendedDefault = benchmarkTopN(results, 2, upstreams)
resp.RecommendedMeta = benchmarkTopN(results, 2, upstreams)
writeJSON(w, http.StatusOK, resp)
}

View File

@@ -0,0 +1,72 @@
package app
import (
"encoding/json"
"io"
"net/http"
"strings"
)
// ---------------------------------------------------------------------
// EN: `handleDNSStatus` is an HTTP handler for dns status.
// RU: `handleDNSStatus` - HTTP-обработчик для dns status.
// ---------------------------------------------------------------------
func handleDNSStatus(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
mode := loadDNSMode()
writeJSON(w, http.StatusOK, makeDNSStatusResponse(mode))
}
// ---------------------------------------------------------------------
// EN: `handleDNSModeSet` is an HTTP handler for dns mode set.
// RU: `handleDNSModeSet` - HTTP-обработчик для dns mode set.
// ---------------------------------------------------------------------
func handleDNSModeSet(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var req DNSModeRequest
if r.Body != nil {
defer r.Body.Close()
if err := json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&req); err != nil && err != io.EOF {
http.Error(w, "bad json", http.StatusBadRequest)
return
}
}
mode := loadDNSMode()
mode.Mode = normalizeDNSResolverMode(req.Mode, req.ViaSmartDNS)
mode.ViaSmartDNS = mode.Mode != DNSModeDirect
if strings.TrimSpace(req.SmartDNSAddr) != "" {
mode.SmartDNSAddr = req.SmartDNSAddr
}
if err := saveDNSMode(mode); err != nil {
http.Error(w, "write error", http.StatusInternalServerError)
return
}
mode = loadDNSMode()
writeJSON(w, http.StatusOK, makeDNSStatusResponse(mode))
}
func makeDNSStatusResponse(mode DNSMode) DNSStatusResponse {
rt := smartDNSRuntimeSnapshot()
resp := DNSStatusResponse{
ViaSmartDNS: mode.ViaSmartDNS,
SmartDNSAddr: mode.SmartDNSAddr,
Mode: mode.Mode,
UnitState: smartdnsUnitState(),
RuntimeNftset: rt.Enabled,
WildcardSource: rt.WildcardSource,
RuntimeCfgPath: rt.ConfigPath,
}
if rt.Message != "" {
resp.RuntimeCfgError = rt.Message
}
return resp
}

View File

@@ -0,0 +1,88 @@
package app
import (
"encoding/json"
"io"
"net/http"
"strings"
)
// ---------------------------------------------------------------------
// EN: `handleDNSSmartdnsService` is an HTTP handler for dns smartdns service.
// RU: `handleDNSSmartdnsService` - HTTP-обработчик для dns smartdns service.
// ---------------------------------------------------------------------
func handleDNSSmartdnsService(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var body struct {
Action string `json:"action"`
}
if r.Body != nil {
defer r.Body.Close()
_ = json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&body)
}
action := strings.ToLower(strings.TrimSpace(body.Action))
if action == "" {
action = "restart"
}
switch action {
case "start", "stop", "restart":
default:
http.Error(w, "unknown action", http.StatusBadRequest)
return
}
res := runSmartdnsUnitAction(action)
mode := loadDNSMode()
rt := smartDNSRuntimeSnapshot()
writeJSON(w, http.StatusOK, map[string]any{
"ok": res.OK,
"message": res.Message,
"exitCode": res.ExitCode,
"stdout": res.Stdout,
"stderr": res.Stderr,
"unit_state": smartdnsUnitState(),
"via_smartdns": mode.ViaSmartDNS,
"smartdns_addr": mode.SmartDNSAddr,
"mode": mode.Mode,
"runtime_nftset": rt.Enabled,
"wildcard_source": rt.WildcardSource,
})
}
// ---------------------------------------------------------------------
// EN: `handleSmartdnsService` is an HTTP handler for smartdns service.
// RU: `handleSmartdnsService` - HTTP-обработчик для smartdns service.
// ---------------------------------------------------------------------
func handleSmartdnsService(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
writeJSON(w, http.StatusOK, map[string]string{"state": smartdnsUnitState()})
case http.MethodPost:
var body struct {
Action string `json:"action"`
}
if r.Body != nil {
defer r.Body.Close()
_ = json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&body)
}
action := strings.ToLower(strings.TrimSpace(body.Action))
if action == "" {
action = "restart"
}
switch action {
case "start", "stop", "restart":
default:
http.Error(w, "unknown action", http.StatusBadRequest)
return
}
writeJSON(w, http.StatusOK, runSmartdnsUnitAction(action))
default:
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
}

View File

@@ -0,0 +1,61 @@
package app
import (
"encoding/json"
"io"
"net/http"
)
// ---------------------------------------------------------------------
// EN: `handleDNSUpstreams` is an HTTP handler for dns upstreams.
// RU: `handleDNSUpstreams` - HTTP-обработчик для dns upstreams.
// ---------------------------------------------------------------------
func handleDNSUpstreams(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
writeJSON(w, http.StatusOK, loadDNSUpstreamsConf())
case http.MethodPost:
var cfg DNSUpstreams
if r.Body != nil {
defer r.Body.Close()
if err := json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&cfg); err != nil && err != io.EOF {
http.Error(w, "bad json", http.StatusBadRequest)
return
}
}
if err := saveDNSUpstreamsConf(cfg); err != nil {
http.Error(w, "write error", http.StatusInternalServerError)
return
}
writeJSON(w, http.StatusOK, map[string]any{
"ok": true,
"cfg": loadDNSUpstreamsConf(),
})
default:
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
}
func handleDNSUpstreamPool(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
items := loadDNSUpstreamPoolState()
writeJSON(w, http.StatusOK, DNSUpstreamPoolState{Items: items})
case http.MethodPost:
var body DNSUpstreamPoolState
if r.Body != nil {
defer r.Body.Close()
if err := json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&body); err != nil && err != io.EOF {
http.Error(w, "bad json", http.StatusBadRequest)
return
}
}
if err := saveDNSUpstreamPoolState(body.Items); err != nil {
http.Error(w, "write error", http.StatusInternalServerError)
return
}
writeJSON(w, http.StatusOK, DNSUpstreamPoolState{Items: loadDNSUpstreamPoolState()})
default:
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
}

View File

@@ -0,0 +1,63 @@
package app
import dnscfgpkg "selective-vpn-api/app/dnscfg"
// ---------------------------------------------------------------------
// EN: `loadDNSMode` loads dns mode from storage or config.
// RU: `loadDNSMode` - загружает dns mode из хранилища или конфига.
// ---------------------------------------------------------------------
func loadDNSMode() DNSMode {
st, needPersist := dnscfgpkg.LoadMode(dnscfgModeConfig())
mode := DNSMode{
ViaSmartDNS: st.ViaSmartDNS,
SmartDNSAddr: st.SmartDNSAddr,
Mode: DNSResolverMode(st.Mode),
}
if needPersist {
_ = saveDNSMode(mode)
}
return mode
}
// ---------------------------------------------------------------------
// EN: `saveDNSMode` saves dns mode to persistent storage.
// RU: `saveDNSMode` - сохраняет dns mode в постоянное хранилище.
// ---------------------------------------------------------------------
func saveDNSMode(mode DNSMode) error {
return dnscfgpkg.SaveMode(
dnscfgModeConfig(),
dnscfgpkg.ModeState{
ViaSmartDNS: mode.ViaSmartDNS,
SmartDNSAddr: mode.SmartDNSAddr,
Mode: string(mode.Mode),
},
)
}
func dnscfgModeConfig() dnscfgpkg.ModeConfig {
return dnscfgpkg.ModeConfig{
Path: dnsModePath,
DirectMode: string(DNSModeDirect),
DefaultSmartDNSAddr: resolveDefaultSmartDNSAddr(),
NormalizeResolverMode: func(mode string, viaSmartDNS bool) string {
return string(normalizeDNSResolverMode(DNSResolverMode(mode), viaSmartDNS))
},
NormalizeSmartDNSAddr: normalizeSmartDNSAddr,
}
}
// ---------------------------------------------------------------------
// EN: `normalizeDNSResolverMode` normalizes dns resolver mode values.
// RU: `normalizeDNSResolverMode` - нормализует значения режима dns резолвера.
// ---------------------------------------------------------------------
func normalizeDNSResolverMode(mode DNSResolverMode, viaSmartDNS bool) DNSResolverMode {
return DNSResolverMode(
dnscfgpkg.NormalizeResolverMode(
string(mode),
viaSmartDNS,
string(DNSModeDirect),
string(DNSModeSmartDNS),
string(DNSModeHybridWildcard),
),
)
}

View File

@@ -0,0 +1,90 @@
package app
import (
"os"
"strings"
dnscfgpkg "selective-vpn-api/app/dnscfg"
)
// ---------------------------------------------------------------------
// EN: `smartDNSAddr` contains core logic for smart d n s addr.
// RU: `smartDNSAddr` - содержит основную логику для smart d n s addr.
// ---------------------------------------------------------------------
func smartDNSAddr() string {
return loadDNSMode().SmartDNSAddr
}
// ---------------------------------------------------------------------
// EN: `smartDNSForced` contains core logic for smart d n s forced.
// RU: `smartDNSForced` - содержит основную логику для smart d n s forced.
// ---------------------------------------------------------------------
func smartDNSForced() bool {
return dnscfgpkg.SmartDNSForced(os.Getenv(smartDNSForceEnv))
}
// ---------------------------------------------------------------------
// EN: `smartdnsUnitState` contains core logic for smartdns unit state.
// RU: `smartdnsUnitState` - содержит основную логику для smartdns unit state.
// ---------------------------------------------------------------------
func smartdnsUnitState() string {
return dnscfgpkg.UnitState(runCommand, "smartdns-local.service")
}
// ---------------------------------------------------------------------
// EN: `runSmartdnsUnitAction` runs the workflow for smartdns unit action.
// RU: `runSmartdnsUnitAction` - запускает рабочий процесс для smartdns unit action.
// ---------------------------------------------------------------------
func runSmartdnsUnitAction(action string) cmdResult {
res := dnscfgpkg.RunUnitAction(runCommand, "smartdns-local.service", action)
msg := res.Message
if res.OK {
msg = "smartdns " + strings.TrimSpace(action) + " done"
}
return cmdResult{
OK: res.OK,
ExitCode: res.ExitCode,
Stdout: res.Stdout,
Stderr: res.Stderr,
Message: msg,
}
}
// ---------------------------------------------------------------------
// EN: `resolveDefaultSmartDNSAddr` resolves default smart d n s addr into concrete values.
// RU: `resolveDefaultSmartDNSAddr` - резолвит default smart d n s addr в конкретные значения.
// ---------------------------------------------------------------------
func resolveDefaultSmartDNSAddr() string {
return dnscfgpkg.ResolveDefaultSmartDNSAddr(
os.Getenv(smartDNSAddrEnv),
[]string{
"/opt/stack/adguardapp/smartdns.conf",
"/etc/selective-vpn/smartdns.conf",
},
smartDNSDefaultAddr,
)
}
// ---------------------------------------------------------------------
// EN: `smartDNSAddrFromConfig` loads smart d n s addr from config.
// RU: `smartDNSAddrFromConfig` - загружает smart d n s addr из конфига.
// ---------------------------------------------------------------------
func smartDNSAddrFromConfig(path string) string {
return dnscfgpkg.SmartDNSAddrFromConfig(path)
}
// ---------------------------------------------------------------------
// EN: `normalizeDNSUpstream` parses dns upstream and returns normalized values.
// RU: `normalizeDNSUpstream` - парсит dns upstream и возвращает нормализованные значения.
// ---------------------------------------------------------------------
func normalizeDNSUpstream(raw string, defaultPort string) string {
return dnscfgpkg.NormalizeDNSUpstream(raw, defaultPort)
}
// ---------------------------------------------------------------------
// EN: `normalizeSmartDNSAddr` parses smart d n s addr and returns normalized values.
// RU: `normalizeSmartDNSAddr` - парсит smart d n s addr и возвращает нормализованные значения.
// ---------------------------------------------------------------------
func normalizeSmartDNSAddr(raw string) string {
return dnscfgpkg.NormalizeSmartDNSAddr(raw)
}

View File

@@ -0,0 +1,104 @@
package app
import dnscfgpkg "selective-vpn-api/app/dnscfg"
func dnscfgPoolItemsFromApp(items []DNSUpstreamPoolItem) []dnscfgpkg.UpstreamPoolItem {
out := make([]dnscfgpkg.UpstreamPoolItem, 0, len(items))
for _, item := range items {
out = append(out, dnscfgpkg.UpstreamPoolItem{
Addr: item.Addr,
Enabled: item.Enabled,
})
}
return out
}
func dnscfgPoolItemsToApp(items []dnscfgpkg.UpstreamPoolItem) []DNSUpstreamPoolItem {
out := make([]DNSUpstreamPoolItem, 0, len(items))
for _, item := range items {
out = append(out, DNSUpstreamPoolItem{
Addr: item.Addr,
Enabled: item.Enabled,
})
}
return out
}
func dnscfgUpstreamsFromApp(cfg DNSUpstreams) dnscfgpkg.Upstreams {
return dnscfgpkg.Upstreams{
Default1: cfg.Default1,
Default2: cfg.Default2,
Meta1: cfg.Meta1,
Meta2: cfg.Meta2,
}
}
func dnscfgUpstreamsToApp(cfg dnscfgpkg.Upstreams) DNSUpstreams {
return DNSUpstreams{
Default1: cfg.Default1,
Default2: cfg.Default2,
Meta1: cfg.Meta1,
Meta2: cfg.Meta2,
}
}
func dnscfgLegacyDefaults() dnscfgpkg.Upstreams {
return dnscfgpkg.Upstreams{
Default1: defaultDNS1,
Default2: defaultDNS2,
Meta1: defaultMeta1,
Meta2: defaultMeta2,
}
}
func normalizeDNSUpstreamPoolItems(items []DNSUpstreamPoolItem) []DNSUpstreamPoolItem {
return dnscfgPoolItemsToApp(
dnscfgpkg.NormalizeUpstreamPoolItems(dnscfgPoolItemsFromApp(items), normalizeDNSUpstream),
)
}
func dnsUpstreamPoolFromLegacy(cfg DNSUpstreams) []DNSUpstreamPoolItem {
return dnscfgPoolItemsToApp(
dnscfgpkg.UpstreamPoolFromLegacy(dnscfgUpstreamsFromApp(cfg), normalizeDNSUpstream),
)
}
func dnsUpstreamPoolToLegacy(items []DNSUpstreamPoolItem) DNSUpstreams {
return dnscfgUpstreamsToApp(
dnscfgpkg.UpstreamPoolToLegacy(
dnscfgPoolItemsFromApp(items),
dnscfgLegacyDefaults(),
normalizeDNSUpstream,
),
)
}
func loadEnabledDNSUpstreamPool() []string {
items := loadDNSUpstreamPoolState()
return uniqueStrings(
dnscfgpkg.EnabledPool(dnscfgPoolItemsFromApp(items), normalizeDNSUpstream),
)
}
// ---------------------------------------------------------------------
// EN: `loadDNSUpstreamsConf` loads dns upstreams conf from storage or config.
// RU: `loadDNSUpstreamsConf` - загружает dns upstreams conf из хранилища или конфига.
// ---------------------------------------------------------------------
func loadDNSUpstreamsConf() DNSUpstreams {
pool := loadDNSUpstreamPoolState()
if len(pool) > 0 {
return dnsUpstreamPoolToLegacy(pool)
}
return loadDNSUpstreamsConfFile()
}
// ---------------------------------------------------------------------
// EN: `saveDNSUpstreamsConf` saves dns upstreams conf to persistent storage.
// RU: `saveDNSUpstreamsConf` - сохраняет dns upstreams conf в постоянное хранилище.
// ---------------------------------------------------------------------
func saveDNSUpstreamsConf(cfg DNSUpstreams) error {
if err := saveDNSUpstreamsConfFile(cfg); err != nil {
return err
}
return saveDNSUpstreamPoolFile(dnsUpstreamPoolFromLegacy(cfg))
}

View File

@@ -0,0 +1,42 @@
package app
import (
"encoding/json"
"os"
"path/filepath"
dnscfgpkg "selective-vpn-api/app/dnscfg"
)
func loadDNSUpstreamsConfFile() DNSUpstreams {
data, err := os.ReadFile(dnsUpstreamsConf)
if err != nil {
return dnscfgUpstreamsToApp(dnscfgLegacyDefaults())
}
pkgCfg := dnscfgpkg.ParseUpstreamsConf(string(data), dnscfgLegacyDefaults(), normalizeDNSUpstream)
return dnscfgUpstreamsToApp(pkgCfg)
}
func saveDNSUpstreamsConfFile(cfg DNSUpstreams) error {
pkgCfg := dnscfgpkg.NormalizeUpstreams(dnscfgUpstreamsFromApp(cfg), dnscfgLegacyDefaults(), normalizeDNSUpstream)
cfg = dnscfgUpstreamsToApp(pkgCfg)
content := dnscfgpkg.RenderUpstreamsConf(pkgCfg)
if err := os.MkdirAll(filepath.Dir(dnsUpstreamsConf), 0o755); err != nil {
return err
}
tmp := dnsUpstreamsConf + ".tmp"
if err := os.WriteFile(tmp, []byte(content), 0o644); err != nil {
return err
}
if err := os.Rename(tmp, dnsUpstreamsConf); err != nil {
return err
}
// Legacy JSON mirror for backward compatibility with older UI/runtime bits.
_ = os.MkdirAll(stateDir, 0o755)
if b, err := json.MarshalIndent(cfg, "", " "); err == nil {
_ = os.WriteFile(dnsUpstreamsPath, b, 0o644)
}
return nil
}

View File

@@ -0,0 +1,53 @@
package app
import (
"encoding/json"
"os"
"path/filepath"
)
func saveDNSUpstreamPoolFile(items []DNSUpstreamPoolItem) error {
state := DNSUpstreamPoolState{Items: normalizeDNSUpstreamPoolItems(items)}
if err := os.MkdirAll(filepath.Dir(dnsUpstreamPool), 0o755); err != nil {
return err
}
tmp := dnsUpstreamPool + ".tmp"
b, err := json.MarshalIndent(state, "", " ")
if err != nil {
return err
}
if err := os.WriteFile(tmp, b, 0o644); err != nil {
return err
}
return os.Rename(tmp, dnsUpstreamPool)
}
func loadDNSUpstreamPoolState() []DNSUpstreamPoolItem {
data, err := os.ReadFile(dnsUpstreamPool)
if err == nil {
var st DNSUpstreamPoolState
if json.Unmarshal(data, &st) == nil {
items := normalizeDNSUpstreamPoolItems(st.Items)
if len(items) > 0 {
return items
}
}
}
legacy := loadDNSUpstreamsConfFile()
items := dnsUpstreamPoolFromLegacy(legacy)
if len(items) > 0 {
_ = saveDNSUpstreamPoolFile(items)
}
return items
}
func saveDNSUpstreamPoolState(items []DNSUpstreamPoolItem) error {
items = normalizeDNSUpstreamPoolItems(items)
if len(items) == 0 {
items = dnsUpstreamPoolFromLegacy(loadDNSUpstreamsConfFile())
}
if err := saveDNSUpstreamPoolFile(items); err != nil {
return err
}
return saveDNSUpstreamsConfFile(dnsUpstreamPoolToLegacy(items))
}

View File

@@ -0,0 +1,5 @@
package app
// SmartDNS HTTP handlers are split by role:
// - runtime status/toggle: dns_smartdns_handlers_runtime.go
// - prewarm execution/helpers: dns_smartdns_handlers_prewarm.go

View File

@@ -0,0 +1,141 @@
package app
import (
"context"
"encoding/json"
"io"
"net/http"
"os"
dnscfgpkg "selective-vpn-api/app/dnscfg"
"time"
)
// ---------------------------------------------------------------------
// EN: `handleSmartdnsPrewarm` forces DNS lookups for wildcard domains via SmartDNS.
// EN: This warms agvpn_dyn4 in realtime through SmartDNS nftset runtime integration.
// RU: `handleSmartdnsPrewarm` принудительно резолвит wildcard-домены через SmartDNS.
// RU: Это прогревает agvpn_dyn4 в realtime через runtime-интеграцию SmartDNS nftset.
// ---------------------------------------------------------------------
func handleSmartdnsPrewarm(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var body struct {
Limit int `json:"limit"`
Workers int `json:"workers"`
TimeoutMS int `json:"timeout_ms"`
AggressiveSubs bool `json:"aggressive_subs"`
}
if r.Body != nil {
defer r.Body.Close()
_ = json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&body)
}
writeJSON(w, http.StatusOK, runSmartdnsPrewarm(body.Limit, body.Workers, body.TimeoutMS, body.AggressiveSubs))
}
func runSmartdnsPrewarm(limit, workers, timeoutMS int, aggressiveSubs bool) cmdResult {
mode := loadDNSMode()
runtimeEnabled := smartDNSRuntimeEnabled()
source := "resolver"
if runtimeEnabled {
source = "smartdns_runtime"
}
smartdnsAddr := normalizeSmartDNSAddr(mode.SmartDNSAddr)
if smartdnsAddr == "" {
smartdnsAddr = resolveDefaultSmartDNSAddr()
}
aggressive := aggressiveSubs || prewarmAggressiveFromEnv()
subs := []string{}
subsPerBaseLimit := 0
if aggressive {
subs = loadList(domainDir + "/subs.txt")
subsPerBaseLimit = envInt("RESOLVE_SUBS_PER_BASE_LIMIT", 0)
if subsPerBaseLimit < 0 {
subsPerBaseLimit = 0
}
}
res := dnscfgpkg.RunPrewarm(
dnscfgpkg.PrewarmInput{
Mode: string(mode.Mode),
Source: source,
RuntimeEnabled: runtimeEnabled,
SmartDNSAddr: smartdnsAddr,
Wildcards: loadSmartDNSWildcardDomains(nil),
AggressiveSubs: aggressive,
Subs: subs,
SubsPerBaseLimit: subsPerBaseLimit,
Limit: limit,
Workers: workers,
TimeoutMS: timeoutMS,
EnvWorkers: envInt("SMARTDNS_PREWARM_WORKERS", 24),
EnvTimeoutMS: envInt("SMARTDNS_PREWARM_TIMEOUT_MS", 1800),
MaxHostsLog: 200,
WildcardMapPath: lastIPsMapDyn,
},
dnscfgpkg.PrewarmDeps{
IsGoogleLike: isGoogleLike,
EnsureRuntimeSet: func() {
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "table", "inet", "agvpn")
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "set", "inet", "agvpn", "agvpn_dyn4", "{", "type", "ipv4_addr", ";", "flags", "interval", ";", "}")
},
DigA: func(host string, dnsList []string, timeout time.Duration) ([]string, dnscfgpkg.PrewarmDNSMetrics) {
ips, stats := digA(host, dnsList, timeout, nil)
return ips, prewarmMetricsFromDNSMetrics(stats)
},
ReadDynSet: func() ([]string, error) {
return readNftSetElements("agvpn_dyn4")
},
ApplyDynSet: func(ips []string) error {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
return nftUpdateSetIPsSmart(ctx, "agvpn_dyn4", ips, nil)
},
Logf: func(message string) {
appendTraceLineTo(smartdnsLogPath, "smartdns", message)
},
},
)
return cmdResult{
OK: res.OK,
Message: res.Message,
ExitCode: res.ExitCode,
}
}
func prewarmAggressiveFromEnv() bool {
return dnscfgpkg.SmartDNSForced(os.Getenv("SMARTDNS_PREWARM_AGGRESSIVE"))
}
func prewarmMetricsFromDNSMetrics(in dnsMetrics) dnscfgpkg.PrewarmDNSMetrics {
out := dnscfgpkg.PrewarmDNSMetrics{
Attempts: in.Attempts,
OK: in.OK,
NXDomain: in.NXDomain,
Timeout: in.Timeout,
Temporary: in.Temporary,
Other: in.Other,
Skipped: in.Skipped,
}
if len(in.PerUpstream) > 0 {
out.PerUpstream = make(map[string]dnscfgpkg.PrewarmDNSUpstreamMetrics, len(in.PerUpstream))
for upstream, stats := range in.PerUpstream {
if stats == nil {
continue
}
out.PerUpstream[upstream] = dnscfgpkg.PrewarmDNSUpstreamMetrics{
Attempts: stats.Attempts,
OK: stats.OK,
NXDomain: stats.NXDomain,
Timeout: stats.Timeout,
Temporary: stats.Temporary,
Other: stats.Other,
Skipped: stats.Skipped,
}
}
}
return out
}

View File

@@ -0,0 +1,70 @@
package app
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
)
func handleSmartdnsRuntime(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
writeJSON(w, http.StatusOK, smartDNSRuntimeSnapshot())
case http.MethodPost:
var body SmartDNSRuntimeRequest
if r.Body != nil {
defer r.Body.Close()
if err := json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&body); err != nil && err != io.EOF {
http.Error(w, "bad json", http.StatusBadRequest)
return
}
}
if body.Enabled == nil {
http.Error(w, "enabled is required", http.StatusBadRequest)
return
}
prev := loadSmartDNSRuntimeState(nil)
next := prev
next.Enabled = *body.Enabled
if err := saveSmartDNSRuntimeState(next); err != nil {
http.Error(w, "runtime state write error", http.StatusInternalServerError)
return
}
changed, err := applySmartDNSRuntimeConfig(next.Enabled)
if err != nil {
_ = saveSmartDNSRuntimeState(prev)
http.Error(w, "runtime config apply error: "+err.Error(), http.StatusInternalServerError)
return
}
restart := true
if body.Restart != nil {
restart = *body.Restart
}
restarted := false
msg := ""
if restart && smartdnsUnitState() == "active" {
res := runSmartdnsUnitAction("restart")
restarted = res.OK
if !res.OK {
msg = "runtime config changed, but smartdns restart failed: " + strings.TrimSpace(res.Message)
}
}
if msg == "" {
msg = fmt.Sprintf("smartdns runtime set: enabled=%t changed=%t restarted=%t", next.Enabled, changed, restarted)
}
appendTraceLineTo(smartdnsLogPath, "smartdns", msg)
resp := smartDNSRuntimeSnapshot()
resp.Changed = changed
resp.Restarted = restarted
resp.Message = msg
writeJSON(w, http.StatusOK, resp)
default:
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
}

View File

@@ -0,0 +1,358 @@
package dnscfg
import (
"context"
"fmt"
"net"
"sort"
"strings"
"sync"
"time"
)
const (
BenchmarkProfileQuick = "quick"
BenchmarkProfileLoad = "load"
BenchmarkErrorNXDomain = "nxdomain"
BenchmarkErrorTimeout = "timeout"
BenchmarkErrorTemporary = "temporary"
BenchmarkErrorOther = "other"
)
var BenchmarkDefaultDomains = []string{
"cloudflare.com",
"google.com",
"telegram.org",
"github.com",
"youtube.com",
"twitter.com",
}
type BenchmarkOptions struct {
Profile string
LoadWorkers int
Rounds int
SyntheticPerDomain int
}
type BenchmarkResult struct {
Upstream string
Attempts int
OK int
Fail int
NXDomain int
Timeout int
Temporary int
Other int
AvgMS int
P95MS int
Score float64
Color string
}
func NormalizeBenchmarkProfile(raw string) string {
switch strings.ToLower(strings.TrimSpace(raw)) {
case "", BenchmarkProfileLoad:
return BenchmarkProfileLoad
case BenchmarkProfileQuick:
return BenchmarkProfileQuick
default:
return BenchmarkProfileLoad
}
}
func MakeDNSBenchmarkOptions(profile string, concurrency int) BenchmarkOptions {
if concurrency < 1 {
concurrency = 1
}
if profile == BenchmarkProfileQuick {
return BenchmarkOptions{
Profile: BenchmarkProfileQuick,
LoadWorkers: 1,
Rounds: 1,
SyntheticPerDomain: 0,
}
}
workers := concurrency * 2
if workers < 4 {
workers = 4
}
if workers > 16 {
workers = 16
}
return BenchmarkOptions{
Profile: BenchmarkProfileLoad,
LoadWorkers: workers,
Rounds: 3,
SyntheticPerDomain: 2,
}
}
func NormalizeBenchmarkUpstreamStrings(in []string, normalizeUpstream func(string, string) string) []string {
out := make([]string, 0, len(in))
seen := map[string]struct{}{}
for _, raw := range in {
n := strings.TrimSpace(raw)
if normalizeUpstream != nil {
n = normalizeUpstream(n, "53")
}
if n == "" {
continue
}
if _, ok := seen[n]; ok {
continue
}
seen[n] = struct{}{}
out = append(out, n)
}
return out
}
func NormalizeBenchmarkDomains(in []string) []string {
if len(in) == 0 {
return nil
}
out := make([]string, 0, len(in))
seen := map[string]struct{}{}
for _, raw := range in {
d := strings.TrimSuffix(strings.ToLower(strings.TrimSpace(raw)), ".")
if d == "" || strings.HasPrefix(d, "#") {
continue
}
if _, ok := seen[d]; ok {
continue
}
seen[d] = struct{}{}
out = append(out, d)
}
if len(out) > 100 {
out = out[:100]
}
return out
}
func BenchmarkDNSUpstream(
upstream string,
domains []string,
timeout time.Duration,
attempts int,
opts BenchmarkOptions,
lookupAOnce func(host, upstream string, timeout time.Duration) ([]string, error),
classifyErr func(error) string,
) BenchmarkResult {
res := BenchmarkResult{Upstream: upstream}
probes := BuildBenchmarkProbeHosts(domains, attempts, opts)
if len(probes) == 0 {
return res
}
durations := make([]int, 0, len(probes))
var mu sync.Mutex
jobs := make(chan string, len(probes))
workers := opts.LoadWorkers
if workers < 1 {
workers = 1
}
if workers > len(probes) {
workers = len(probes)
}
var wg sync.WaitGroup
for i := 0; i < workers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for host := range jobs {
start := time.Now()
_, err := lookupAOnce(host, upstream, timeout)
elapsed := int(time.Since(start).Milliseconds())
if elapsed < 1 {
elapsed = 1
}
mu.Lock()
res.Attempts++
durations = append(durations, elapsed)
if err != nil {
res.Fail++
switch strings.ToLower(strings.TrimSpace(classifyErr(err))) {
case BenchmarkErrorNXDomain:
res.NXDomain++
case BenchmarkErrorTimeout:
res.Timeout++
case BenchmarkErrorTemporary:
res.Temporary++
default:
res.Other++
}
} else {
res.OK++
}
mu.Unlock()
}
}()
}
for _, host := range probes {
jobs <- host
}
close(jobs)
wg.Wait()
if len(durations) > 0 {
sort.Ints(durations)
sum := 0
for _, d := range durations {
sum += d
}
res.AvgMS = sum / len(durations)
idx := int(float64(len(durations)-1) * 0.95)
if idx < 0 {
idx = 0
}
res.P95MS = durations[idx]
}
total := res.Attempts
if total > 0 {
okRate := float64(res.OK) / float64(total)
answeredRate := float64(res.OK+res.NXDomain+res.Temporary+res.Other) / float64(total)
timeoutRate := float64(res.Timeout) / float64(total)
temporaryRate := float64(res.Temporary) / float64(total)
otherRate := float64(res.Other) / float64(total)
avg := float64(res.AvgMS)
if avg <= 0 {
avg = float64(timeout.Milliseconds())
}
p95 := float64(res.P95MS)
if p95 <= 0 {
p95 = avg
}
res.Score = answeredRate*100.0 + okRate*15.0 - timeoutRate*120.0 - temporaryRate*35.0 - otherRate*20.0 - (avg / 25.0) - (p95 / 45.0)
}
timeoutRate := 0.0
answeredRate := 0.0
if res.Attempts > 0 {
timeoutRate = float64(res.Timeout) / float64(res.Attempts)
answeredRate = float64(res.OK+res.NXDomain+res.Temporary+res.Other) / float64(res.Attempts)
}
switch {
case answeredRate < 0.85 || timeoutRate >= 0.10 || res.P95MS > 1800:
res.Color = "red"
case answeredRate >= 0.97 && timeoutRate <= 0.02 && res.P95MS <= 700:
res.Color = "green"
default:
res.Color = "yellow"
}
return res
}
func BuildBenchmarkProbeHosts(domains []string, attempts int, opts BenchmarkOptions) []string {
if len(domains) == 0 {
return nil
}
if attempts < 1 {
attempts = 1
}
rounds := opts.Rounds
if rounds < 1 {
rounds = 1
}
synth := opts.SyntheticPerDomain
if synth < 0 {
synth = 0
}
out := make([]string, 0, len(domains)*attempts*rounds*(1+synth))
for round := 0; round < rounds; round++ {
for _, host := range domains {
for i := 0; i < attempts; i++ {
out = append(out, host)
}
for n := 0; n < synth; n++ {
out = append(out, fmt.Sprintf("svpn-bench-%d-%d.%s", round+1, n+1, host))
}
}
}
if len(out) > 10000 {
out = out[:10000]
}
return out
}
func DNSLookupAOnce(
host string,
upstream string,
timeout time.Duration,
splitDNS func(string) (string, string),
isPrivateIPv4 func(string) bool,
) ([]string, error) {
if splitDNS == nil {
return nil, fmt.Errorf("splitDNS callback is nil")
}
server, port := splitDNS(upstream)
if server == "" {
return nil, fmt.Errorf("upstream empty")
}
if port == "" {
port = "53"
}
addr := net.JoinHostPort(server, port)
resolver := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{}
return d.DialContext(ctx, "udp", addr)
},
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
records, err := resolver.LookupHost(ctx, host)
cancel()
if err != nil {
return nil, err
}
seen := map[string]struct{}{}
out := make([]string, 0, len(records))
for _, ip := range records {
if isPrivateIPv4 != nil && isPrivateIPv4(ip) {
continue
}
if _, ok := seen[ip]; ok {
continue
}
seen[ip] = struct{}{}
out = append(out, ip)
}
if len(out) == 0 {
return nil, fmt.Errorf("no public ips")
}
return out, nil
}
func BenchmarkTopN(results []BenchmarkResult, n int, fallback []string) []string {
out := make([]string, 0, n)
for _, item := range results {
if item.OK <= 0 {
continue
}
out = append(out, item.Upstream)
if len(out) >= n {
return out
}
}
for _, item := range fallback {
if len(out) >= n {
break
}
dup := false
for _, e := range out {
if e == item {
dup = true
break
}
}
if !dup {
out = append(out, item)
}
}
return out
}

View File

@@ -0,0 +1,136 @@
package dnscfg
import (
"encoding/json"
"os"
"path/filepath"
"strings"
)
func NormalizeResolverMode(mode string, viaSmartDNS bool, directMode string, smartDNSMode string, hybridWildcardMode string) string {
switch strings.ToLower(strings.TrimSpace(mode)) {
case strings.ToLower(strings.TrimSpace(directMode)):
return directMode
case strings.ToLower(strings.TrimSpace(smartDNSMode)):
return hybridWildcardMode
case strings.ToLower(strings.TrimSpace(hybridWildcardMode)), "hybrid":
return hybridWildcardMode
default:
if viaSmartDNS {
return hybridWildcardMode
}
return directMode
}
}
func SmartDNSForced(envRaw string) bool {
switch strings.TrimSpace(strings.ToLower(envRaw)) {
case "1", "true", "yes", "on":
return true
default:
return false
}
}
type ModeState struct {
ViaSmartDNS bool `json:"via_smartdns"`
SmartDNSAddr string `json:"smartdns_addr"`
Mode string `json:"mode"`
}
type ModeConfig struct {
Path string
DirectMode string
DefaultSmartDNSAddr string
NormalizeResolverMode func(mode string, viaSmartDNS bool) string
NormalizeSmartDNSAddr func(raw string) string
}
func LoadMode(cfg ModeConfig) (ModeState, bool) {
mode := ModeState{
ViaSmartDNS: false,
SmartDNSAddr: strings.TrimSpace(cfg.DefaultSmartDNSAddr),
Mode: strings.TrimSpace(cfg.DirectMode),
}
needPersist := false
data, err := os.ReadFile(strings.TrimSpace(cfg.Path))
switch {
case err == nil:
var stored ModeState
if err := json.Unmarshal(data, &stored); err == nil {
normalized, changed := normalizeModeState(stored, cfg)
mode = normalized
if strings.TrimSpace(stored.Mode) == "" || stored.ViaSmartDNS != normalized.ViaSmartDNS || changed {
needPersist = true
}
} else {
needPersist = true
}
case os.IsNotExist(err):
needPersist = true
}
normalized, changed := normalizeModeState(mode, cfg)
mode = normalized
if changed {
needPersist = true
}
return mode, needPersist
}
func SaveMode(cfg ModeConfig, mode ModeState) error {
normalized, _ := normalizeModeState(mode, cfg)
path := strings.TrimSpace(cfg.Path)
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return err
}
tmp := path + ".tmp"
b, err := json.MarshalIndent(normalized, "", " ")
if err != nil {
return err
}
if err := os.WriteFile(tmp, b, 0o644); err != nil {
return err
}
return os.Rename(tmp, path)
}
func normalizeModeState(mode ModeState, cfg ModeConfig) (ModeState, bool) {
changed := false
prevMode := mode.Mode
mode.Mode = normalizeResolverMode(cfg, mode.Mode, mode.ViaSmartDNS)
if mode.Mode != prevMode {
changed = true
}
viaSmartDNS := mode.Mode != strings.TrimSpace(cfg.DirectMode)
if mode.ViaSmartDNS != viaSmartDNS {
mode.ViaSmartDNS = viaSmartDNS
changed = true
}
prevAddr := mode.SmartDNSAddr
mode.SmartDNSAddr = normalizeSmartDNSAddr(cfg, mode.SmartDNSAddr)
if mode.SmartDNSAddr == "" {
mode.SmartDNSAddr = strings.TrimSpace(cfg.DefaultSmartDNSAddr)
}
if mode.SmartDNSAddr != prevAddr {
changed = true
}
return mode, changed
}
func normalizeResolverMode(cfg ModeConfig, mode string, viaSmartDNS bool) string {
if cfg.NormalizeResolverMode == nil {
return strings.TrimSpace(mode)
}
return strings.TrimSpace(cfg.NormalizeResolverMode(mode, viaSmartDNS))
}
func normalizeSmartDNSAddr(cfg ModeConfig, raw string) string {
if cfg.NormalizeSmartDNSAddr == nil {
return strings.TrimSpace(raw)
}
return strings.TrimSpace(cfg.NormalizeSmartDNSAddr(raw))
}

View File

@@ -0,0 +1,100 @@
package dnscfg
import "strings"
type Upstreams struct {
Default1 string
Default2 string
Meta1 string
Meta2 string
}
type UpstreamPoolItem struct {
Addr string
Enabled bool
}
func NormalizeUpstreamPoolItems(items []UpstreamPoolItem, normalizeUpstream func(raw string, defaultPort string) string) []UpstreamPoolItem {
if len(items) == 0 {
return nil
}
seen := map[string]struct{}{}
out := make([]UpstreamPoolItem, 0, len(items))
for _, item := range items {
addr := strings.TrimSpace(item.Addr)
if normalizeUpstream != nil {
addr = normalizeUpstream(addr, "53")
}
if addr == "" {
continue
}
if _, ok := seen[addr]; ok {
continue
}
seen[addr] = struct{}{}
out = append(out, UpstreamPoolItem{
Addr: addr,
Enabled: item.Enabled,
})
}
return out
}
func UpstreamPoolFromLegacy(cfg Upstreams, normalizeUpstream func(raw string, defaultPort string) string) []UpstreamPoolItem {
out := []UpstreamPoolItem{
{Addr: cfg.Default1, Enabled: true},
{Addr: cfg.Default2, Enabled: true},
{Addr: cfg.Meta1, Enabled: true},
{Addr: cfg.Meta2, Enabled: true},
}
return NormalizeUpstreamPoolItems(out, normalizeUpstream)
}
func UpstreamPoolToLegacy(items []UpstreamPoolItem, defaults Upstreams, normalizeUpstream func(raw string, defaultPort string) string) Upstreams {
items = NormalizeUpstreamPoolItems(items, normalizeUpstream)
out := defaults
enabled := make([]string, 0, len(items))
for _, item := range items {
if !item.Enabled {
continue
}
addr := strings.TrimSpace(item.Addr)
if normalizeUpstream != nil {
addr = normalizeUpstream(addr, "53")
}
if addr != "" {
enabled = append(enabled, addr)
}
}
if len(enabled) > 0 {
out.Default1 = enabled[0]
}
if len(enabled) > 1 {
out.Default2 = enabled[1]
}
if len(enabled) > 2 {
out.Meta1 = enabled[2]
}
if len(enabled) > 3 {
out.Meta2 = enabled[3]
}
return out
}
func EnabledPool(items []UpstreamPoolItem, normalizeUpstream func(raw string, defaultPort string) string) []string {
items = NormalizeUpstreamPoolItems(items, normalizeUpstream)
out := make([]string, 0, len(items))
for _, item := range items {
if !item.Enabled {
continue
}
addr := strings.TrimSpace(item.Addr)
if normalizeUpstream != nil {
addr = normalizeUpstream(addr, "53")
}
if addr != "" {
out = append(out, addr)
}
}
return out
}

View File

@@ -0,0 +1,392 @@
package dnscfg
import (
"fmt"
"sort"
"strings"
"time"
)
type PrewarmDNSUpstreamMetrics struct {
Attempts int
OK int
NXDomain int
Timeout int
Temporary int
Other int
Skipped int
}
type PrewarmDNSMetrics struct {
Attempts int
OK int
NXDomain int
Timeout int
Temporary int
Other int
Skipped int
PerUpstream map[string]PrewarmDNSUpstreamMetrics
}
func (m *PrewarmDNSMetrics) Merge(other PrewarmDNSMetrics) {
m.Attempts += other.Attempts
m.OK += other.OK
m.NXDomain += other.NXDomain
m.Timeout += other.Timeout
m.Temporary += other.Temporary
m.Other += other.Other
m.Skipped += other.Skipped
if len(other.PerUpstream) == 0 {
return
}
if m.PerUpstream == nil {
m.PerUpstream = map[string]PrewarmDNSUpstreamMetrics{}
}
for upstream, src := range other.PerUpstream {
dst := m.PerUpstream[upstream]
dst.Attempts += src.Attempts
dst.OK += src.OK
dst.NXDomain += src.NXDomain
dst.Timeout += src.Timeout
dst.Temporary += src.Temporary
dst.Other += src.Other
dst.Skipped += src.Skipped
m.PerUpstream[upstream] = dst
}
}
func (m PrewarmDNSMetrics) TotalErrors() int {
return m.NXDomain + m.Timeout + m.Temporary + m.Other
}
func (m PrewarmDNSMetrics) FormatPerUpstream() string {
if len(m.PerUpstream) == 0 {
return ""
}
keys := make([]string, 0, len(m.PerUpstream))
for k := range m.PerUpstream {
keys = append(keys, k)
}
sort.Strings(keys)
parts := make([]string, 0, len(keys))
for _, k := range keys {
v := m.PerUpstream[k]
parts = append(parts, fmt.Sprintf("%s{attempts=%d ok=%d nxdomain=%d timeout=%d temporary=%d other=%d skipped=%d}", k, v.Attempts, v.OK, v.NXDomain, v.Timeout, v.Temporary, v.Other, v.Skipped))
}
return strings.Join(parts, "; ")
}
type PrewarmInput struct {
Mode string
Source string
RuntimeEnabled bool
SmartDNSAddr string
Wildcards []string
AggressiveSubs bool
Subs []string
SubsPerBaseLimit int
Limit int
Workers int
TimeoutMS int
EnvWorkers int
EnvTimeoutMS int
MaxHostsLog int
WildcardMapPath string
}
type PrewarmDeps struct {
IsGoogleLike func(string) bool
EnsureRuntimeSet func()
DigA func(host string, dnsList []string, timeout time.Duration) ([]string, PrewarmDNSMetrics)
ReadDynSet func() ([]string, error)
ApplyDynSet func([]string) error
Logf func(message string)
}
type PrewarmResult struct {
OK bool
Message string
ExitCode int
ResolvedHosts int
}
func RunPrewarm(in PrewarmInput, deps PrewarmDeps) PrewarmResult {
smartdnsAddr := strings.TrimSpace(in.SmartDNSAddr)
if smartdnsAddr == "" {
return PrewarmResult{OK: false, Message: "SmartDNS address is empty"}
}
wildcards := trimNonEmptyUnique(in.Wildcards)
if len(wildcards) == 0 {
msg := "prewarm skipped: wildcard list is empty"
logPrewarm(deps.Logf, msg)
return PrewarmResult{OK: true, Message: msg}
}
aggressive := in.AggressiveSubs
subs := trimNonEmptyUnique(in.Subs)
subsPerBaseLimit := in.SubsPerBaseLimit
if subsPerBaseLimit < 0 {
subsPerBaseLimit = 0
}
domainSet := make(map[string]struct{}, len(wildcards)*(len(subs)+1))
for _, d := range wildcards {
domainSet[d] = struct{}{}
if !aggressive || isGoogleLikeSafe(deps.IsGoogleLike, d) {
continue
}
maxSubs := len(subs)
if subsPerBaseLimit > 0 && subsPerBaseLimit < maxSubs {
maxSubs = subsPerBaseLimit
}
for i := 0; i < maxSubs; i++ {
domainSet[subs[i]+"."+d] = struct{}{}
}
}
domains := make([]string, 0, len(domainSet))
for d := range domainSet {
domains = append(domains, d)
}
sort.Strings(domains)
if in.Limit > 0 && len(domains) > in.Limit {
domains = domains[:in.Limit]
}
if len(domains) == 0 {
msg := "prewarm skipped: expanded wildcard list is empty"
logPrewarm(deps.Logf, msg)
return PrewarmResult{OK: true, Message: msg}
}
workers := in.Workers
if workers <= 0 {
workers = in.EnvWorkers
if workers <= 0 {
workers = 24
}
}
if workers < 1 {
workers = 1
}
if workers > 200 {
workers = 200
}
timeoutMS := in.TimeoutMS
if timeoutMS <= 0 {
timeoutMS = in.EnvTimeoutMS
if timeoutMS <= 0 {
timeoutMS = 1800
}
}
if timeoutMS < 200 {
timeoutMS = 200
}
if timeoutMS > 15000 {
timeoutMS = 15000
}
timeout := time.Duration(timeoutMS) * time.Millisecond
if deps.EnsureRuntimeSet != nil {
deps.EnsureRuntimeSet()
}
logPrewarm(
deps.Logf,
fmt.Sprintf(
"prewarm start: mode=%s source=%s runtime_nftset=%t smartdns=%s wildcard_domains=%d expanded=%d aggressive_subs=%t workers=%d timeout_ms=%d",
strings.TrimSpace(in.Mode),
strings.TrimSpace(in.Source),
in.RuntimeEnabled,
smartdnsAddr,
len(wildcards),
len(domains),
aggressive,
workers,
timeoutMS,
),
)
type prewarmItem struct {
host string
ips []string
stats PrewarmDNSMetrics
}
jobs := make(chan string, len(domains))
results := make(chan prewarmItem, len(domains))
for i := 0; i < workers; i++ {
go func() {
for host := range jobs {
ips, stats := safeDigA(deps.DigA, host, []string{smartdnsAddr}, timeout)
results <- prewarmItem{host: host, ips: ips, stats: stats}
}
}()
}
for _, host := range domains {
jobs <- host
}
close(jobs)
resolvedHosts := 0
totalIPs := 0
errorHosts := 0
stats := PrewarmDNSMetrics{}
resolvedIPSet := map[string]struct{}{}
loggedHosts := 0
maxHostsLog := in.MaxHostsLog
if maxHostsLog <= 0 {
maxHostsLog = 200
}
for i := 0; i < len(domains); i++ {
item := <-results
stats.Merge(item.stats)
if item.stats.TotalErrors() > 0 {
errorHosts++
}
if len(item.ips) == 0 {
continue
}
resolvedHosts++
totalIPs += len(item.ips)
for _, ip := range item.ips {
if strings.TrimSpace(ip) != "" {
resolvedIPSet[ip] = struct{}{}
}
}
if loggedHosts < maxHostsLog {
logPrewarm(deps.Logf, fmt.Sprintf("prewarm add: %s -> %s", item.host, strings.Join(item.ips, ", ")))
loggedHosts++
}
}
manualAdded := 0
totalDynText := "n/a"
if !in.RuntimeEnabled {
existing, _ := safeReadDynSet(deps.ReadDynSet)
mergedSet := make(map[string]struct{}, len(existing)+len(resolvedIPSet))
for _, ip := range existing {
if strings.TrimSpace(ip) != "" {
mergedSet[ip] = struct{}{}
}
}
for ip := range resolvedIPSet {
if _, ok := mergedSet[ip]; !ok {
manualAdded++
}
mergedSet[ip] = struct{}{}
}
merged := make([]string, 0, len(mergedSet))
for ip := range mergedSet {
merged = append(merged, ip)
}
totalDynText = fmt.Sprintf("%d", len(merged))
if err := safeApplyDynSet(deps.ApplyDynSet, merged); err != nil {
msg := fmt.Sprintf("prewarm manual apply failed: %v", err)
logPrewarm(deps.Logf, msg)
return PrewarmResult{OK: false, Message: msg}
}
logPrewarm(
deps.Logf,
fmt.Sprintf("prewarm manual merge: existing=%d resolved=%d added=%d total_dyn=%d", len(existing), len(resolvedIPSet), manualAdded, len(merged)),
)
}
if len(domains) > loggedHosts {
logPrewarm(
deps.Logf,
fmt.Sprintf(
"prewarm add: trace truncated, omitted=%d hosts (full wildcard map: %s)",
len(domains)-loggedHosts,
strings.TrimSpace(in.WildcardMapPath),
),
)
}
msg := fmt.Sprintf(
"prewarm done: source=%s expanded=%d resolved=%d total_ips=%d error_hosts=%d dns_attempts=%d dns_ok=%d dns_errors=%d manual_added=%d dyn_total=%s",
strings.TrimSpace(in.Source),
len(domains),
resolvedHosts,
totalIPs,
errorHosts,
stats.Attempts,
stats.OK,
stats.TotalErrors(),
manualAdded,
totalDynText,
)
logPrewarm(deps.Logf, msg)
if perUpstream := stats.FormatPerUpstream(); perUpstream != "" {
logPrewarm(deps.Logf, "prewarm dns upstreams: "+perUpstream)
}
return PrewarmResult{
OK: true,
Message: msg,
ExitCode: resolvedHosts,
ResolvedHosts: resolvedHosts,
}
}
func logPrewarm(logf func(string), msg string) {
if logf != nil {
logf(msg)
}
}
func safeDigA(
dig func(host string, dnsList []string, timeout time.Duration) ([]string, PrewarmDNSMetrics),
host string,
dnsList []string,
timeout time.Duration,
) ([]string, PrewarmDNSMetrics) {
if dig == nil {
return nil, PrewarmDNSMetrics{}
}
return dig(host, dnsList, timeout)
}
func safeReadDynSet(read func() ([]string, error)) ([]string, error) {
if read == nil {
return nil, nil
}
return read()
}
func safeApplyDynSet(apply func([]string) error, ips []string) error {
if apply == nil {
return fmt.Errorf("apply dyn set callback is nil")
}
return apply(ips)
}
func isGoogleLikeSafe(check func(string) bool, domain string) bool {
if check == nil {
return false
}
return check(domain)
}
func trimNonEmptyUnique(in []string) []string {
if len(in) == 0 {
return nil
}
seen := map[string]struct{}{}
out := make([]string, 0, len(in))
for _, item := range in {
v := strings.TrimSpace(item)
if v == "" {
continue
}
if _, ok := seen[v]; ok {
continue
}
seen[v] = struct{}{}
out = append(out, v)
}
return out
}

View File

@@ -0,0 +1,102 @@
package dnscfg
import (
"net"
"os"
"strings"
)
func ResolveDefaultSmartDNSAddr(addrEnvValue string, configPaths []string, fallback string) string {
if v := strings.TrimSpace(addrEnvValue); v != "" {
if addr := NormalizeSmartDNSAddr(v); addr != "" {
return addr
}
}
for _, path := range configPaths {
if addr := SmartDNSAddrFromConfig(path); addr != "" {
return addr
}
}
return strings.TrimSpace(fallback)
}
func SmartDNSAddrFromConfig(path string) string {
data, err := os.ReadFile(path)
if err != nil {
return ""
}
for _, ln := range strings.Split(string(data), "\n") {
s := strings.TrimSpace(ln)
if s == "" || strings.HasPrefix(s, "#") {
continue
}
if !strings.HasPrefix(strings.ToLower(s), "bind ") {
continue
}
parts := strings.Fields(s)
if len(parts) < 2 {
continue
}
if addr := NormalizeSmartDNSAddr(parts[1]); addr != "" {
return addr
}
}
return ""
}
func NormalizeDNSUpstream(raw string, defaultPort string) string {
s := strings.TrimSpace(raw)
if s == "" {
return ""
}
s = strings.TrimPrefix(s, "udp://")
s = strings.TrimPrefix(s, "tcp://")
if strings.Contains(s, "#") {
parts := strings.SplitN(s, "#", 2)
host := strings.Trim(strings.TrimSpace(parts[0]), "[]")
port := strings.TrimSpace(parts[1])
if host == "" {
return ""
}
if port == "" {
port = defaultPort
}
return host + "#" + port
}
if host, port, err := net.SplitHostPort(s); err == nil {
host = strings.Trim(strings.TrimSpace(host), "[]")
port = strings.TrimSpace(port)
if host == "" {
return ""
}
if port == "" {
port = defaultPort
}
return host + "#" + port
}
if strings.Count(s, ":") == 1 {
parts := strings.SplitN(s, ":", 2)
host := strings.TrimSpace(parts[0])
port := strings.TrimSpace(parts[1])
if host != "" && port != "" {
return host + "#" + port
}
}
return s
}
func NormalizeSmartDNSAddr(raw string) string {
s := NormalizeDNSUpstream(raw, "6053")
if s == "" {
return ""
}
if strings.Contains(s, "#") {
return s
}
return s + "#6053"
}

View File

@@ -0,0 +1,47 @@
package dnscfg
import "strings"
type RunCommandFunc func(name string, args ...string) (stdout string, stderr string, exitCode int, err error)
type CmdResult struct {
OK bool
ExitCode int
Stdout string
Stderr string
Message string
}
func UnitState(run RunCommandFunc, unit string) string {
if run == nil {
return "unknown"
}
stdout, _, _, _ := run("systemctl", "is-active", strings.TrimSpace(unit))
st := strings.TrimSpace(stdout)
if st == "" {
return "unknown"
}
return st
}
func RunUnitAction(run RunCommandFunc, unit, action string) CmdResult {
if run == nil {
return CmdResult{
OK: false,
Message: "run command func is nil",
}
}
stdout, stderr, exitCode, err := run("systemctl", strings.TrimSpace(action), strings.TrimSpace(unit))
res := CmdResult{
OK: err == nil && exitCode == 0,
ExitCode: exitCode,
Stdout: stdout,
Stderr: stderr,
}
if err != nil {
res.Message = err.Error()
} else {
res.Message = strings.TrimSpace(unit) + " " + strings.TrimSpace(action) + " done"
}
return res
}

View File

@@ -0,0 +1,69 @@
package dnscfg
import (
"fmt"
"strings"
)
func NormalizeUpstreams(cfg Upstreams, defaults Upstreams, normalizeUpstream func(raw string, defaultPort string) string) Upstreams {
if normalizeUpstream != nil {
cfg.Default1 = normalizeUpstream(cfg.Default1, "53")
cfg.Default2 = normalizeUpstream(cfg.Default2, "53")
cfg.Meta1 = normalizeUpstream(cfg.Meta1, "53")
cfg.Meta2 = normalizeUpstream(cfg.Meta2, "53")
}
if strings.TrimSpace(cfg.Default1) == "" {
cfg.Default1 = defaults.Default1
}
if strings.TrimSpace(cfg.Default2) == "" {
cfg.Default2 = defaults.Default2
}
if strings.TrimSpace(cfg.Meta1) == "" {
cfg.Meta1 = defaults.Meta1
}
if strings.TrimSpace(cfg.Meta2) == "" {
cfg.Meta2 = defaults.Meta2
}
return cfg
}
func ParseUpstreamsConf(content string, defaults Upstreams, normalizeUpstream func(raw string, defaultPort string) string) Upstreams {
cfg := defaults
for _, ln := range strings.Split(content, "\n") {
s := strings.TrimSpace(ln)
if s == "" || strings.HasPrefix(s, "#") {
continue
}
parts := strings.Fields(s)
if len(parts) < 2 {
continue
}
key := strings.ToLower(strings.TrimSpace(parts[0]))
vals := parts[1:]
switch key {
case "default":
if len(vals) > 0 {
cfg.Default1 = vals[0]
}
if len(vals) > 1 {
cfg.Default2 = vals[1]
}
case "meta":
if len(vals) > 0 {
cfg.Meta1 = vals[0]
}
if len(vals) > 1 {
cfg.Meta2 = vals[1]
}
}
}
return NormalizeUpstreams(cfg, defaults, normalizeUpstream)
}
func RenderUpstreamsConf(cfg Upstreams) string {
return fmt.Sprintf(
"default %s %s\nmeta %s %s\n",
cfg.Default1, cfg.Default2, cfg.Meta1, cfg.Meta2,
)
}

View File

@@ -1,16 +1,5 @@
package app
import (
"encoding/json"
"io"
"io/fs"
"net/http"
"os"
"path/filepath"
"sort"
"strings"
)
// ---------------------------------------------------------------------
// domains editor + smartdns wildcards
// ---------------------------------------------------------------------
@@ -31,210 +20,3 @@ var domainFiles = map[string]string{
"last-ips-map-direct": lastIPsMapDirect,
"last-ips-map-wildcard": lastIPsMapDyn,
}
// ---------------------------------------------------------------------
// domains table
// ---------------------------------------------------------------------
// GET /api/v1/domains/table -> { "lines": [ ... ] }
func handleDomainsTable(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
lines := []string{}
for _, setName := range []string{"agvpn4", "agvpn_dyn4"} {
stdout, _, code, _ := runCommand("nft", "list", "set", "inet", "agvpn", setName)
if code == 0 {
for _, l := range strings.Split(stdout, "\n") {
l = strings.TrimRight(l, "\r")
if l != "" {
lines = append(lines, l)
}
}
continue
}
// Backward-compatible fallback for legacy hosts that still have ipset.
stdout, _, code, _ = runCommand("ipset", "list", setName)
if code != 0 {
continue
}
for _, l := range strings.Split(stdout, "\n") {
l = strings.TrimRight(l, "\r")
if l != "" {
lines = append(lines, l)
}
}
}
writeJSON(w, http.StatusOK, map[string]any{"lines": lines})
}
// ---------------------------------------------------------------------
// domains file
// ---------------------------------------------------------------------
// GET /api/v1/domains/file?name=bases|meta|subs|static|smartdns|last-ips-map|last-ips-map-direct|last-ips-map-wildcard|wildcard-observed-hosts
// POST /api/v1/domains/file { "name": "...", "content": "..." }
func handleDomainsFile(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
name := strings.TrimSpace(r.URL.Query().Get("name"))
if name == "smartdns" {
domains, source := loadSmartDNSWildcardDomainsState(nil)
writeJSON(w, http.StatusOK, map[string]string{
"content": renderSmartDNSDomainsContent(domains),
"source": source,
})
return
}
if name == "wildcard-observed-hosts" {
writeJSON(w, http.StatusOK, map[string]string{
"content": readWildcardObservedHostsContent(),
"source": "derived",
})
return
}
path, ok := domainFiles[name]
if !ok {
http.Error(w, "unknown file name", http.StatusBadRequest)
return
}
source := "file"
if strings.HasPrefix(name, "last-ips-map") {
source = "artifact"
}
data, err := os.ReadFile(path)
if err != nil {
if !os.IsNotExist(err) {
http.Error(w, "read error", http.StatusInternalServerError)
return
}
switch name {
case "bases", "meta", "subs":
// fallback to embedded seed
embedName := name + ".txt"
if name == "meta" {
embedName = "meta-special.txt"
}
data, _ = fs.ReadFile(embeddedDomains, "assets/domains/"+embedName)
source = "embedded"
default:
data = []byte{}
}
}
writeJSON(w, http.StatusOK, map[string]string{
"content": string(data),
"source": source,
})
case http.MethodPost:
var body struct {
Name string `json:"name"`
Content string `json:"content"`
}
if r.Body != nil {
defer r.Body.Close()
if err := json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&body); err != nil {
http.Error(w, "bad json", http.StatusBadRequest)
return
}
}
if strings.TrimSpace(body.Name) == "smartdns" {
domains := parseSmartDNSDomainsContent(body.Content)
if err := saveSmartDNSWildcardDomainsState(domains); err != nil {
http.Error(w, "write error", http.StatusInternalServerError)
return
}
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
return
}
if body.Name == "last-ips-map-direct" || body.Name == "last-ips-map-wildcard" || body.Name == "wildcard-observed-hosts" {
http.Error(w, "read-only file name", http.StatusBadRequest)
return
}
path, ok := domainFiles[strings.TrimSpace(body.Name)]
if !ok {
http.Error(w, "unknown file name", http.StatusBadRequest)
return
}
_ = os.MkdirAll(filepath.Dir(path), 0o755)
if err := os.WriteFile(path, []byte(body.Content), 0o644); err != nil {
http.Error(w, "write error", http.StatusInternalServerError)
return
}
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
default:
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
}
func readWildcardObservedHostsContent() string {
data, err := os.ReadFile(lastIPsMapDyn)
if err != nil {
return ""
}
seen := make(map[string]struct{})
out := make([]string, 0, 256)
for _, ln := range strings.Split(string(data), "\n") {
ln = strings.TrimSpace(ln)
if ln == "" || strings.HasPrefix(ln, "#") {
continue
}
fields := strings.Fields(ln)
if len(fields) < 2 {
continue
}
host := strings.TrimSpace(fields[1])
if host == "" || strings.HasPrefix(host, "[") {
continue
}
if _, ok := seen[host]; ok {
continue
}
seen[host] = struct{}{}
out = append(out, host)
}
sort.Strings(out)
if len(out) == 0 {
return ""
}
return strings.Join(out, "\n") + "\n"
}
// ---------------------------------------------------------------------
// smartdns wildcards
// ---------------------------------------------------------------------
func handleSmartdnsWildcards(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
payload := struct {
Domains []string `json:"domains"`
}{Domains: readSmartDNSWildcardDomains()}
writeJSON(w, http.StatusOK, payload)
case http.MethodPost:
var payload struct {
Domains []string `json:"domains"`
}
if r.Body != nil {
defer r.Body.Close()
if err := json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&payload); err != nil {
http.Error(w, "bad json", http.StatusBadRequest)
return
}
}
if err := saveSmartDNSWildcardDomainsState(payload.Domains); err != nil {
http.Error(w, "write error", http.StatusInternalServerError)
return
}
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
default:
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
}
func readSmartDNSWildcardDomains() []string {
domains, _ := loadSmartDNSWildcardDomainsState(nil)
return domains
}

View File

@@ -0,0 +1,109 @@
package app
import (
"encoding/json"
"io"
"io/fs"
"net/http"
"os"
"path/filepath"
"strings"
)
// ---------------------------------------------------------------------
// domains file
// ---------------------------------------------------------------------
// GET /api/v1/domains/file?name=bases|meta|subs|static|smartdns|last-ips-map|last-ips-map-direct|last-ips-map-wildcard|wildcard-observed-hosts
// POST /api/v1/domains/file { "name": "...", "content": "..." }
func handleDomainsFile(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
name := strings.TrimSpace(r.URL.Query().Get("name"))
if name == "smartdns" {
domains, source := loadSmartDNSWildcardDomainsState(nil)
writeJSON(w, http.StatusOK, map[string]string{
"content": renderSmartDNSDomainsContent(domains),
"source": source,
})
return
}
if name == "wildcard-observed-hosts" {
writeJSON(w, http.StatusOK, map[string]string{
"content": readWildcardObservedHostsContent(),
"source": "derived",
})
return
}
path, ok := domainFiles[name]
if !ok {
http.Error(w, "unknown file name", http.StatusBadRequest)
return
}
source := "file"
if strings.HasPrefix(name, "last-ips-map") {
source = "artifact"
}
data, err := os.ReadFile(path)
if err != nil {
if !os.IsNotExist(err) {
http.Error(w, "read error", http.StatusInternalServerError)
return
}
switch name {
case "bases", "meta", "subs":
// fallback to embedded seed
embedName := name + ".txt"
if name == "meta" {
embedName = "meta-special.txt"
}
data, _ = fs.ReadFile(embeddedDomains, "assets/domains/"+embedName)
source = "embedded"
default:
data = []byte{}
}
}
writeJSON(w, http.StatusOK, map[string]string{
"content": string(data),
"source": source,
})
case http.MethodPost:
var body struct {
Name string `json:"name"`
Content string `json:"content"`
}
if r.Body != nil {
defer r.Body.Close()
if err := json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&body); err != nil {
http.Error(w, "bad json", http.StatusBadRequest)
return
}
}
if strings.TrimSpace(body.Name) == "smartdns" {
domains := parseSmartDNSDomainsContent(body.Content)
if err := saveSmartDNSWildcardDomainsState(domains); err != nil {
http.Error(w, "write error", http.StatusInternalServerError)
return
}
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
return
}
if body.Name == "last-ips-map-direct" || body.Name == "last-ips-map-wildcard" || body.Name == "wildcard-observed-hosts" {
http.Error(w, "read-only file name", http.StatusBadRequest)
return
}
path, ok := domainFiles[strings.TrimSpace(body.Name)]
if !ok {
http.Error(w, "unknown file name", http.StatusBadRequest)
return
}
_ = os.MkdirAll(filepath.Dir(path), 0o755)
if err := os.WriteFile(path, []byte(body.Content), 0o644); err != nil {
http.Error(w, "write error", http.StatusInternalServerError)
return
}
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
default:
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
}

View File

@@ -0,0 +1,40 @@
package app
import (
"os"
"sort"
"strings"
)
func readWildcardObservedHostsContent() string {
data, err := os.ReadFile(lastIPsMapDyn)
if err != nil {
return ""
}
seen := make(map[string]struct{})
out := make([]string, 0, 256)
for _, ln := range strings.Split(string(data), "\n") {
ln = strings.TrimSpace(ln)
if ln == "" || strings.HasPrefix(ln, "#") {
continue
}
fields := strings.Fields(ln)
if len(fields) < 2 {
continue
}
host := strings.TrimSpace(fields[1])
if host == "" || strings.HasPrefix(host, "[") {
continue
}
if _, ok := seen[host]; ok {
continue
}
seen[host] = struct{}{}
out = append(out, host)
}
sort.Strings(out)
if len(out) == 0 {
return ""
}
return strings.Join(out, "\n") + "\n"
}

View File

@@ -0,0 +1,44 @@
package app
import (
"encoding/json"
"io"
"net/http"
)
// ---------------------------------------------------------------------
// smartdns wildcards
// ---------------------------------------------------------------------
func handleSmartdnsWildcards(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
payload := struct {
Domains []string `json:"domains"`
}{Domains: readSmartDNSWildcardDomains()}
writeJSON(w, http.StatusOK, payload)
case http.MethodPost:
var payload struct {
Domains []string `json:"domains"`
}
if r.Body != nil {
defer r.Body.Close()
if err := json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&payload); err != nil {
http.Error(w, "bad json", http.StatusBadRequest)
return
}
}
if err := saveSmartDNSWildcardDomainsState(payload.Domains); err != nil {
http.Error(w, "write error", http.StatusInternalServerError)
return
}
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
default:
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
}
func readSmartDNSWildcardDomains() []string {
domains, _ := loadSmartDNSWildcardDomainsState(nil)
return domains
}

View File

@@ -0,0 +1,45 @@
package app
import (
"net/http"
"strings"
)
// ---------------------------------------------------------------------
// domains table
// ---------------------------------------------------------------------
// GET /api/v1/domains/table -> { "lines": [ ... ] }
func handleDomainsTable(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
lines := []string{}
for _, setName := range []string{"agvpn4", "agvpn_dyn4"} {
stdout, _, code, _ := runCommand("nft", "list", "set", "inet", "agvpn", setName)
if code == 0 {
for _, l := range strings.Split(stdout, "\n") {
l = strings.TrimRight(l, "\r")
if l != "" {
lines = append(lines, l)
}
}
continue
}
// Backward-compatible fallback for legacy hosts that still have ipset.
stdout, _, code, _ = runCommand("ipset", "list", setName)
if code != 0 {
continue
}
for _, l := range strings.Split(stdout, "\n") {
l = strings.TrimRight(l, "\r")
if l != "" {
lines = append(lines, l)
}
}
}
writeJSON(w, http.StatusOK, map[string]any{"lines": lines})
}

View File

@@ -0,0 +1,78 @@
package app
import (
"net/http"
"sync"
"time"
)
const (
egressIdentityFreshTTL = 3 * time.Minute
egressIdentityBackoffMin = 3 * time.Second
egressIdentityBackoffMax = 2 * time.Minute
egressIdentityProbeTimeout = 4 * time.Second
egressIdentityGeoTimeout = 4 * time.Second
egressIdentityGeoCacheTTL = 24 * time.Hour
egressIdentityGeoFailTTL = 30 * time.Second
egressIdentityMaxConcurrency = 2
)
var (
egressIdentitySWR = newEgressIdentityService(
envInt("SVPN_EGRESS_MAX_PARALLEL", egressIdentityMaxConcurrency),
)
egressHTTPClient = &http.Client{}
)
type egressScopeTarget struct {
Scope string
Source string
SourceID string
}
type egressSourceProvider interface {
Probe(target egressScopeTarget) (string, error)
}
type egressIdentityEntry struct {
item EgressIdentity
swr refreshCoordinator
}
type egressGeoCacheEntry struct {
CountryCode string
CountryName string
LastError string
ExpiresAt time.Time
}
type egressIdentityService struct {
mu sync.Mutex
entries map[string]*egressIdentityEntry
sem chan struct{}
providers map[string]egressSourceProvider
geoMu sync.Mutex
geoCache map[string]egressGeoCacheEntry
}
type egressSystemProvider struct{}
type egressAdGuardProvider struct{}
type egressTransportProvider struct{}
func newEgressIdentityService(maxConcurrent int) *egressIdentityService {
n := maxConcurrent
if n <= 0 {
n = egressIdentityMaxConcurrency
}
return &egressIdentityService{
entries: map[string]*egressIdentityEntry{},
sem: make(chan struct{}, n),
providers: map[string]egressSourceProvider{
"system": egressSystemProvider{},
"adguardvpn": egressAdGuardProvider{},
"transport": egressTransportProvider{},
},
geoCache: map[string]egressGeoCacheEntry{},
}
}

View File

@@ -0,0 +1,76 @@
package app
import (
"fmt"
egressutilpkg "selective-vpn-api/app/egressutil"
"strings"
"time"
)
func (s *egressIdentityService) lookupGeo(ip string, force bool) (string, string, error) {
ip = strings.TrimSpace(ip)
if ip == "" {
return "", "", fmt.Errorf("empty ip")
}
now := time.Now()
s.geoMu.Lock()
if entry, ok := s.geoCache[ip]; ok && !entry.ExpiresAt.IsZero() && now.Before(entry.ExpiresAt) {
code := egressutilpkg.NormalizeCountryCode(entry.CountryCode)
name := strings.TrimSpace(entry.CountryName)
errMsg := strings.TrimSpace(entry.LastError)
s.geoMu.Unlock()
if code != "" || name != "" {
return code, name, nil
}
if errMsg != "" && !force {
return "", "", fmt.Errorf("%s", errMsg)
}
if !force {
return "", "", nil
}
// Force refresh bypasses negative geo cache to recover country flag quickly.
}
stale := s.geoCache[ip]
s.geoMu.Unlock()
geoURLs := egressGeoEndpointsForIP(ip)
errs := make([]string, 0, len(geoURLs))
for _, rawURL := range geoURLs {
body, err := egressutilpkg.HTTPGetBody(egressHTTPClient, rawURL, egressIdentityGeoTimeout, "selective-vpn-api/egress-identity", 8*1024)
if err != nil {
errs = append(errs, err.Error())
continue
}
code, name, err := egressutilpkg.ParseGeoResponse(body)
if err != nil {
errs = append(errs, err.Error())
continue
}
s.geoMu.Lock()
s.geoCache[ip] = egressGeoCacheEntry{
CountryCode: egressutilpkg.NormalizeCountryCode(code),
CountryName: strings.TrimSpace(name),
ExpiresAt: now.Add(egressIdentityGeoCacheTTL),
}
s.geoMu.Unlock()
return egressutilpkg.NormalizeCountryCode(code), strings.TrimSpace(name), nil
}
if strings.TrimSpace(stale.CountryCode) != "" || strings.TrimSpace(stale.CountryName) != "" {
return egressutilpkg.NormalizeCountryCode(stale.CountryCode), strings.TrimSpace(stale.CountryName), nil
}
msg := "geo lookup failed"
if len(errs) > 0 {
msg = strings.Join(errs, "; ")
}
s.geoMu.Lock()
s.geoCache[ip] = egressGeoCacheEntry{
LastError: msg,
ExpiresAt: now.Add(egressIdentityGeoFailTTL),
}
s.geoMu.Unlock()
return "", "", fmt.Errorf("%s", msg)
}

View File

@@ -0,0 +1,61 @@
package app
import (
"encoding/json"
"io"
"net/http"
"strings"
)
func handleEgressIdentityGet(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
scope := strings.TrimSpace(r.URL.Query().Get("scope"))
if scope == "" {
http.Error(w, "scope is required", http.StatusBadRequest)
return
}
refresh := false
switch strings.ToLower(strings.TrimSpace(r.URL.Query().Get("refresh"))) {
case "1", "true", "yes", "on":
refresh = true
}
item, err := egressIdentitySWR.getSnapshot(scope, refresh)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
writeJSON(w, http.StatusOK, EgressIdentityResponse{
OK: true,
Message: "ok",
Item: item,
})
}
func handleEgressIdentityRefresh(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var body EgressIdentityRefreshRequest
if r.Body != nil {
defer r.Body.Close()
if err := json.NewDecoder(io.LimitReader(r.Body, 1<<20)).Decode(&body); err != nil && err != io.EOF {
http.Error(w, "bad json", http.StatusBadRequest)
return
}
}
resp, err := egressIdentitySWR.queueRefresh(body.Scopes, body.Force)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
writeJSON(w, http.StatusOK, resp)
}

View File

@@ -0,0 +1 @@
package app

View File

@@ -0,0 +1,44 @@
package app
import (
"fmt"
"strings"
egressutilpkg "selective-vpn-api/app/egressutil"
)
func egressProbeExternalIP() (string, error) {
endpoints := egressIPEndpoints()
ip, errs := egressutilpkg.ProbeFirstSuccess(endpoints, func(rawURL string) (string, error) {
body, err := egressutilpkg.HTTPGetBody(egressHTTPClient, rawURL, egressIdentityProbeTimeout, "selective-vpn-api/egress-identity", 8*1024)
if err != nil {
return "", err
}
return egressutilpkg.ParseIPFromBody(body)
})
if strings.TrimSpace(ip) != "" {
return ip, nil
}
if len(errs) == 0 {
return "", fmt.Errorf("egress probe endpoints are not configured")
}
return "", fmt.Errorf("%s", strings.Join(errs, "; "))
}
func egressProbeExternalIPViaInterface(iface string) (string, error) {
iface = strings.TrimSpace(iface)
if iface == "" {
return egressProbeExternalIP()
}
endpoints := egressIPEndpoints()
ip, errs := egressutilpkg.ProbeFirstSuccess(endpoints, func(rawURL string) (string, error) {
return egressProbeURLViaInterface(rawURL, iface, egressIdentityProbeTimeout)
})
if strings.TrimSpace(ip) != "" {
return ip, nil
}
if len(errs) == 0 {
return "", fmt.Errorf("egress probe endpoints are not configured")
}
return "", fmt.Errorf("%s", strings.Join(errs, "; "))
}

View File

@@ -0,0 +1,74 @@
package app
import (
"encoding/json"
"net/netip"
"os"
"strings"
"time"
egressutilpkg "selective-vpn-api/app/egressutil"
)
func egressIPEndpoints() []string {
return egressutilpkg.IPEndpoints(os.Getenv("SVPN_EGRESS_IP_ENDPOINTS"))
}
func egressGeoEndpointsForIP(ip string) []string {
return egressutilpkg.GeoEndpointsForIP(os.Getenv("SVPN_EGRESS_GEO_ENDPOINTS"), ip)
}
func egressLimitEndpointsForNetns(in []string) []string {
maxN := envInt("SVPN_EGRESS_NETNS_MAX_ENDPOINTS", 1)
return egressutilpkg.LimitEndpoints(in, maxN)
}
func egressJoinErrorsCompact(errs []string) string {
return egressutilpkg.JoinErrorsCompact(errs)
}
func egressSingBoxSOCKSProxyURL(client TransportClient) string {
if client.Kind != TransportClientSingBox {
return ""
}
path := transportSingBoxConfigPath(client)
if strings.TrimSpace(path) == "" {
return ""
}
data, err := os.ReadFile(path)
if err != nil || len(data) == 0 {
return ""
}
var root map[string]any
if err := json.Unmarshal(data, &root); err != nil {
return ""
}
return egressutilpkg.ParseSingBoxSOCKSProxyURL(root)
}
func egressInterfaceBindAddress(iface string) string {
iface = strings.TrimSpace(iface)
if iface == "" {
return ""
}
stdout, _, code, err := runCommandTimeout(1500*time.Millisecond, "ip", "-4", "-o", "addr", "show", "dev", iface, "scope", "global")
if err != nil || code != 0 {
return ""
}
for _, line := range strings.Split(stdout, "\n") {
fields := strings.Fields(strings.TrimSpace(line))
for i := 0; i < len(fields); i++ {
if fields[i] != "inet" || i+1 >= len(fields) {
continue
}
ip := strings.TrimSpace(fields[i+1])
if slash := strings.Index(ip, "/"); slash > 0 {
ip = ip[:slash]
}
if addr, err := netip.ParseAddr(ip); err == nil && addr.Is4() {
return addr.String()
}
}
}
return ""
}

View File

@@ -0,0 +1,146 @@
package app
import (
"fmt"
"strconv"
"strings"
"time"
egressutilpkg "selective-vpn-api/app/egressutil"
)
func egressProbeExternalIPInNetns(client TransportClient, ns string) (string, error) {
endpoints := egressLimitEndpointsForNetns(egressIPEndpoints())
ip, errs := egressutilpkg.ProbeFirstSuccess(endpoints, func(rawURL string) (string, error) {
return egressProbeURLInNetns(client, ns, rawURL, egressIdentityProbeTimeout)
})
if strings.TrimSpace(ip) != "" {
return ip, nil
}
if len(errs) == 0 {
return "", fmt.Errorf("egress probe endpoints are not configured")
}
return "", fmt.Errorf("%s", egressJoinErrorsCompact(errs))
}
func egressProbeExternalIPInNetnsViaProxy(client TransportClient, ns, proxyURL string) (string, error) {
proxy := strings.TrimSpace(proxyURL)
if proxy == "" {
return "", fmt.Errorf("proxy url is empty")
}
endpoints := egressLimitEndpointsForNetns(egressIPEndpoints())
ip, errs := egressutilpkg.ProbeFirstSuccess(endpoints, func(rawURL string) (string, error) {
return egressProbeURLInNetnsViaProxy(client, ns, rawURL, proxy, egressIdentityProbeTimeout)
})
if strings.TrimSpace(ip) != "" {
return ip, nil
}
if len(errs) == 0 {
return "", fmt.Errorf("egress probe endpoints are not configured")
}
return "", fmt.Errorf("%s", egressJoinErrorsCompact(errs))
}
func egressProbeURLViaInterface(rawURL, iface string, timeout time.Duration) (string, error) {
curl := egressutilpkg.ResolveCurlPath()
sec := egressutilpkg.TimeoutSec(timeout)
if curl != "" {
args := []string{
"-4",
"-fsSL",
"--max-time", strconv.Itoa(sec),
"--connect-timeout", "2",
"--interface", iface,
rawURL,
}
stdout, stderr, code, err := runCommandTimeout(timeout+time.Second, curl, args...)
if err != nil || code != 0 {
return "", transportCommandError(shellJoinArgs(append([]string{curl}, args...)), stdout, stderr, code, err)
}
return egressutilpkg.ParseIPFromBody(stdout)
}
wget := egressutilpkg.ResolveWgetPath()
if wget == "" {
return "", fmt.Errorf("curl/wget are not available for interface-bound egress probe")
}
bindAddr := egressInterfaceBindAddress(iface)
if bindAddr == "" {
return "", fmt.Errorf("cannot resolve IPv4 address for interface %q", iface)
}
args := []string{
"-4",
"-q",
"-T", strconv.Itoa(sec),
"-O", "-",
"--bind-address", bindAddr,
rawURL,
}
stdout, stderr, code, err := runCommandTimeout(timeout+time.Second, wget, args...)
if err != nil || code != 0 {
return "", transportCommandError(shellJoinArgs(append([]string{wget}, args...)), stdout, stderr, code, err)
}
return egressutilpkg.ParseIPFromBody(stdout)
}
func egressProbeURLInNetns(client TransportClient, ns, rawURL string, timeout time.Duration) (string, error) {
sec := egressutilpkg.TimeoutSec(timeout)
resolveHost, resolvePort, resolveIP := egressutilpkg.ResolvedHostForURL(rawURL)
curlBin := egressutilpkg.ResolveCurlPath()
if curlBin == "" {
return "", fmt.Errorf("curl is not available for netns probe")
}
curlArgs := []string{
"-4",
"-fsSL",
"--max-time", strconv.Itoa(sec),
"--connect-timeout", "2",
}
if resolveHost != "" && resolveIP != "" && resolvePort > 0 {
curlArgs = append(curlArgs, "--resolve", fmt.Sprintf("%s:%d:%s", resolveHost, resolvePort, resolveIP))
}
curlArgs = append(curlArgs, rawURL)
curlCmd := append([]string{curlBin}, curlArgs...)
name, args, err := transportNetnsExecCommand(client, ns, curlCmd...)
if err != nil {
return "", err
}
stdout, stderr, code, runErr := runCommandTimeout(timeout+time.Second, name, args...)
if runErr != nil || code != 0 {
return "", transportCommandError(shellJoinArgs(append([]string{name}, args...)), stdout, stderr, code, runErr)
}
return egressutilpkg.ParseIPFromBody(stdout)
}
func egressProbeURLInNetnsViaProxy(
client TransportClient,
ns string,
rawURL string,
proxyURL string,
timeout time.Duration,
) (string, error) {
curlBin := egressutilpkg.ResolveCurlPath()
if curlBin == "" {
return "", fmt.Errorf("curl is not available for proxy probe")
}
sec := egressutilpkg.TimeoutSec(timeout)
args := []string{
"-4",
"-fsSL",
"--max-time", strconv.Itoa(sec),
"--connect-timeout", "3",
"--proxy", strings.TrimSpace(proxyURL),
rawURL,
}
cmd := append([]string{curlBin}, args...)
name, netnsArgs, err := transportNetnsExecCommand(client, ns, cmd...)
if err != nil {
return "", err
}
stdout, stderr, code, runErr := runCommandTimeout(timeout+time.Second, name, netnsArgs...)
if runErr != nil || code != 0 {
return "", transportCommandError(shellJoinArgs(append([]string{name}, netnsArgs...)), stdout, stderr, code, runErr)
}
return egressutilpkg.ParseIPFromBody(stdout)
}

View File

@@ -0,0 +1,95 @@
package app
import (
"fmt"
"strings"
"time"
)
func (egressSystemProvider) Probe(_ egressScopeTarget) (string, error) {
return egressProbeExternalIP()
}
func (egressAdGuardProvider) Probe(_ egressScopeTarget) (string, error) {
stdout, stderr, code, err := runCommandTimeout(2*time.Second, "systemctl", "is-active", adgvpnUnit)
state := strings.ToLower(strings.TrimSpace(stdout))
if state != "active" || err != nil || code != 0 {
return "", transportCommandError("systemctl is-active "+adgvpnUnit, stdout, stderr, code, err)
}
iface, _ := resolveTrafficIface(loadTrafficModeState().PreferredIface)
if iface = strings.TrimSpace(iface); iface == "" {
return "", fmt.Errorf("adguardvpn interface is not resolved")
}
if !ifaceExists(iface) {
return "", fmt.Errorf("adguardvpn interface %q is not available", iface)
}
return egressProbeExternalIPViaInterface(iface)
}
func (egressTransportProvider) Probe(target egressScopeTarget) (string, error) {
id := sanitizeID(target.SourceID)
if id == "" {
return "", fmt.Errorf("invalid transport source id")
}
transportMu.Lock()
st := loadTransportClientsState()
idx := findTransportClientIndex(st.Items, id)
var client TransportClient
if idx >= 0 {
client = st.Items[idx]
}
transportMu.Unlock()
if idx < 0 {
return "", fmt.Errorf("transport client %q not found", id)
}
if !client.Enabled {
return "", fmt.Errorf("transport client %q is disabled", id)
}
if normalizeTransportStatus(client.Status) == TransportClientDown {
backend := selectTransportBackend(client)
live := backend.Health(client)
if normalizeTransportStatus(live.Status) != TransportClientUp {
msg := strings.TrimSpace(live.Message)
if msg == "" {
msg = fmt.Sprintf("transport client %q is down", id)
}
return "", fmt.Errorf("%s", msg)
}
client.Status = TransportClientUp
}
if transportNetnsEnabled(client) {
ns := transportNetnsName(client)
if strings.TrimSpace(ns) == "" {
return "", fmt.Errorf("transport client %q netns is enabled but netns_name is empty", id)
}
if client.Kind == TransportClientSingBox {
proxyURL := egressSingBoxSOCKSProxyURL(client)
if proxyURL == "" {
return "", fmt.Errorf("proxy probe failed: singbox socks inbound not found")
}
// For SingBox in netns we must use tunnel egress probe (SOCKS inbound -> outbound proxy).
// Direct netns probe is intentionally not used: in selective mode it may return AdGuard/system IP.
ip, err := egressProbeExternalIPInNetnsViaProxy(client, ns, proxyURL)
if err == nil {
return ip, nil
}
return "", fmt.Errorf("proxy probe failed: %v", err)
}
ip, err := egressProbeExternalIPInNetns(client, ns)
if err == nil {
return ip, nil
}
return "", err
}
iface := strings.TrimSpace(client.Iface)
if iface != "" && ifaceExists(iface) {
return egressProbeExternalIPViaInterface(iface)
}
return egressProbeExternalIP()
}

View File

@@ -0,0 +1,140 @@
package app
import (
"sort"
"strings"
"time"
)
func (s *egressIdentityService) getSnapshot(scopeRaw string, refresh bool) (EgressIdentity, error) {
target, err := parseEgressScope(scopeRaw)
if err != nil {
return EgressIdentity{}, err
}
if refresh {
s.queueScopeRefresh(target, false)
}
return s.snapshot(target, time.Now()), nil
}
func (s *egressIdentityService) queueRefresh(scopes []string, force bool) (EgressIdentityRefreshResponse, error) {
rawTargets := make([]string, 0, len(scopes))
for _, raw := range scopes {
v := strings.TrimSpace(raw)
if v == "" {
continue
}
rawTargets = append(rawTargets, v)
}
if len(rawTargets) == 0 {
rawTargets = s.knownScopes()
}
if len(rawTargets) == 0 {
rawTargets = []string{"adguardvpn", "system"}
}
targets := make([]egressScopeTarget, 0, len(rawTargets))
seen := map[string]struct{}{}
for _, raw := range rawTargets {
target, err := parseEgressScope(raw)
if err != nil {
return EgressIdentityRefreshResponse{}, err
}
if _, ok := seen[target.Scope]; ok {
continue
}
seen[target.Scope] = struct{}{}
targets = append(targets, target)
}
resp := EgressIdentityRefreshResponse{
OK: true,
Message: "refresh queued",
Items: make([]EgressIdentityRefreshItem, 0, len(targets)),
}
for _, target := range targets {
queued, reason := s.queueScopeRefresh(target, force)
item := EgressIdentityRefreshItem{
Scope: target.Scope,
Queued: queued,
}
if queued {
item.Status = "queued"
resp.Queued++
} else {
item.Status = "skipped"
item.Reason = strings.TrimSpace(reason)
if item.Reason == "" {
item.Reason = "throttled or already fresh"
}
resp.Skipped++
}
resp.Items = append(resp.Items, item)
}
resp.Count = len(resp.Items)
if resp.Queued == 0 {
resp.Message = "refresh skipped"
}
return resp, nil
}
func (s *egressIdentityService) knownScopes() []string {
outSet := map[string]struct{}{
"adguardvpn": {},
"system": {},
}
transportMu.Lock()
st := loadTransportClientsState()
transportMu.Unlock()
for _, it := range st.Items {
id := sanitizeID(it.ID)
if id == "" {
continue
}
outSet["transport:"+id] = struct{}{}
}
s.mu.Lock()
for scope := range s.entries {
outSet[scope] = struct{}{}
}
s.mu.Unlock()
out := make([]string, 0, len(outSet))
for scope := range outSet {
out = append(out, scope)
}
sort.Strings(out)
return out
}
func (s *egressIdentityService) queueScopeRefresh(target egressScopeTarget, force bool) (bool, string) {
now := time.Now()
s.mu.Lock()
entry := s.ensureEntryLocked(target)
hasData := strings.TrimSpace(entry.item.IP) != ""
switch {
case entry.swr.refreshInProgress():
s.mu.Unlock()
return false, "already in progress"
case !force && !entry.swr.nextRetryAt().IsZero() && now.Before(entry.swr.nextRetryAt()):
s.mu.Unlock()
return false, "backoff in progress"
case !force && hasData && !entry.swr.isStale(now):
s.mu.Unlock()
return false, "already fresh"
}
if force {
entry.swr.clearBackoff()
}
if !entry.swr.beginRefresh(now, force, hasData) {
s.mu.Unlock()
return false, "throttled or already fresh"
}
s.mu.Unlock()
go s.refreshScope(target, force)
return true, ""
}

View File

@@ -0,0 +1,117 @@
package app
import (
"log"
"strings"
"time"
egressutilpkg "selective-vpn-api/app/egressutil"
)
func (s *egressIdentityService) refreshScope(target egressScopeTarget, force bool) {
s.acquire()
defer s.release()
now := time.Now().UTC()
provider := s.providerFor(target.Source)
if provider == nil {
s.finishError(target, "provider is not configured for scope source", now)
return
}
ip, err := provider.Probe(target)
if err != nil {
s.finishError(target, err.Error(), now)
return
}
code, name, geoErr := s.lookupGeo(ip, force)
s.finishSuccess(target, ip, code, name, geoErr, now)
}
func (s *egressIdentityService) providerFor(source string) egressSourceProvider {
s.mu.Lock()
defer s.mu.Unlock()
return s.providers[strings.ToLower(strings.TrimSpace(source))]
}
func (s *egressIdentityService) finishError(target egressScopeTarget, msg string, at time.Time) {
s.mu.Lock()
entry := s.ensureEntryLocked(target)
prev := s.entrySnapshotLocked(entry, target, at)
entry.swr.finishError(msg, at)
next := s.entrySnapshotLocked(entry, target, at)
changed := egressIdentityChanged(prev, next)
s.mu.Unlock()
if changed {
events.push("egress_identity_changed", map[string]any{
"scope": next.Scope,
"ip": next.IP,
"country_code": next.CountryCode,
"country_name": next.CountryName,
"updated_at": next.UpdatedAt,
"stale": next.Stale,
"last_error": next.LastError,
})
if target.Source == "transport" {
publishTransportRuntimeObservabilitySnapshotChanged(
"egress_identity_changed",
[]string{target.SourceID},
nil,
)
}
}
}
func (s *egressIdentityService) finishSuccess(
target egressScopeTarget,
ip string,
code string,
name string,
geoErr error,
at time.Time,
) {
s.mu.Lock()
entry := s.ensureEntryLocked(target)
prev := s.entrySnapshotLocked(entry, target, at)
previousIP := strings.TrimSpace(entry.item.IP)
entry.item.Scope = target.Scope
entry.item.Source = target.Source
entry.item.SourceID = target.SourceID
entry.item.IP = strings.TrimSpace(ip)
if geoErr == nil {
entry.item.CountryCode = egressutilpkg.NormalizeCountryCode(code)
entry.item.CountryName = strings.TrimSpace(name)
} else if previousIP != strings.TrimSpace(ip) {
entry.item.CountryCode = ""
entry.item.CountryName = ""
}
entry.swr.finishSuccess(at)
next := s.entrySnapshotLocked(entry, target, at)
changed := egressIdentityChanged(prev, next)
s.mu.Unlock()
if geoErr != nil {
log.Printf("egress geo lookup warning: scope=%s ip=%s err=%v", target.Scope, ip, geoErr)
}
if changed {
events.push("egress_identity_changed", map[string]any{
"scope": next.Scope,
"ip": next.IP,
"country_code": next.CountryCode,
"country_name": next.CountryName,
"updated_at": next.UpdatedAt,
"stale": next.Stale,
"last_error": next.LastError,
})
if target.Source == "transport" {
publishTransportRuntimeObservabilitySnapshotChanged(
"egress_identity_changed",
[]string{target.SourceID},
nil,
)
}
}
}

View File

@@ -0,0 +1,78 @@
package app
import (
"strings"
"time"
)
func (s *egressIdentityService) snapshot(target egressScopeTarget, now time.Time) EgressIdentity {
s.mu.Lock()
defer s.mu.Unlock()
entry := s.ensureEntryLocked(target)
return s.entrySnapshotLocked(entry, target, now)
}
func (s *egressIdentityService) ensureEntryLocked(target egressScopeTarget) *egressIdentityEntry {
entry := s.entries[target.Scope]
if entry != nil {
if entry.item.Scope == "" {
entry.item.Scope = target.Scope
}
if entry.item.Source == "" {
entry.item.Source = target.Source
}
if entry.item.SourceID == "" {
entry.item.SourceID = target.SourceID
}
return entry
}
entry = &egressIdentityEntry{
item: EgressIdentity{
Scope: target.Scope,
Source: target.Source,
SourceID: target.SourceID,
},
swr: newRefreshCoordinator(
egressIdentityFreshTTL,
egressIdentityBackoffMin,
egressIdentityBackoffMax,
),
}
s.entries[target.Scope] = entry
return entry
}
func (s *egressIdentityService) entrySnapshotLocked(
entry *egressIdentityEntry,
target egressScopeTarget,
now time.Time,
) EgressIdentity {
item := entry.item
if item.Scope == "" {
item.Scope = target.Scope
}
if item.Source == "" {
item.Source = target.Source
}
if item.SourceID == "" {
item.SourceID = target.SourceID
}
meta := entry.swr.snapshot(now)
item.UpdatedAt = meta.UpdatedAt
item.Stale = meta.Stale
item.RefreshInProgress = meta.RefreshInProgress
item.LastError = strings.TrimSpace(meta.LastError)
item.NextRetryAt = meta.NextRetryAt
return item
}
func (s *egressIdentityService) acquire() {
s.sem <- struct{}{}
}
func (s *egressIdentityService) release() {
select {
case <-s.sem:
default:
}
}

View File

@@ -0,0 +1,35 @@
package app
import egressutilpkg "selective-vpn-api/app/egressutil"
func parseEgressScope(raw string) (egressScopeTarget, error) {
target, err := egressutilpkg.ParseScope(raw, sanitizeID)
if err != nil {
return egressScopeTarget{}, err
}
return egressScopeTarget{
Scope: target.Scope,
Source: target.Source,
SourceID: target.SourceID,
}, nil
}
func egressIdentityChanged(prev, next EgressIdentity) bool {
return egressutilpkg.IdentityChanged(
egressIdentitySnapshot(prev),
egressIdentitySnapshot(next),
)
}
func egressIdentitySnapshot(item EgressIdentity) egressutilpkg.IdentitySnapshot {
return egressutilpkg.IdentitySnapshot{
IP: item.IP,
CountryCode: item.CountryCode,
CountryName: item.CountryName,
UpdatedAt: item.UpdatedAt,
Stale: item.Stale,
RefreshInProgress: item.RefreshInProgress,
LastError: item.LastError,
NextRetryAt: item.NextRetryAt,
}
}

View File

@@ -0,0 +1,188 @@
package app
import (
egressutilpkg "selective-vpn-api/app/egressutil"
"testing"
)
func TestParseEgressScope(t *testing.T) {
tests := []struct {
name string
in string
wantOK bool
want egressScopeTarget
}{
{
name: "adguardvpn",
in: "adguardvpn",
wantOK: true,
want: egressScopeTarget{
Scope: "adguardvpn",
Source: "adguardvpn",
},
},
{
name: "system",
in: "system",
wantOK: true,
want: egressScopeTarget{
Scope: "system",
Source: "system",
},
},
{
name: "transport normalize",
in: "transport: SG RealNetns ",
wantOK: true,
want: egressScopeTarget{
Scope: "transport:sg-realnetns",
Source: "transport",
SourceID: "sg-realnetns",
},
},
{
name: "bad scope",
in: "transport:",
wantOK: false,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
got, err := parseEgressScope(tc.in)
if tc.wantOK && err != nil {
t.Fatalf("parseEgressScope(%q) unexpected error: %v", tc.in, err)
}
if !tc.wantOK {
if err == nil {
t.Fatalf("parseEgressScope(%q) expected error", tc.in)
}
return
}
if got != tc.want {
t.Fatalf("parseEgressScope(%q)=%+v want %+v", tc.in, got, tc.want)
}
})
}
}
func TestParseEgressIPFromBody(t *testing.T) {
tests := []struct {
name string
in string
wantIP string
wantOK bool
}{
{
name: "plain ipv4",
in: "203.0.113.10\n",
wantIP: "203.0.113.10",
wantOK: true,
},
{
name: "json ip",
in: `{"ip":"198.51.100.7"}`,
wantIP: "198.51.100.7",
wantOK: true,
},
{
name: "json query",
in: `{"query":"2001:db8::1"}`,
wantIP: "2001:db8::1",
wantOK: true,
},
{
name: "invalid",
in: "not-an-ip",
wantOK: false,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
got, err := egressutilpkg.ParseIPFromBody(tc.in)
if tc.wantOK && err != nil {
t.Fatalf("ParseIPFromBody unexpected error: %v", err)
}
if !tc.wantOK {
if err == nil {
t.Fatalf("ParseIPFromBody expected error")
}
return
}
if got != tc.wantIP {
t.Fatalf("ParseIPFromBody=%q want %q", got, tc.wantIP)
}
})
}
}
func TestEgressParseGeoResponse(t *testing.T) {
code, name, err := egressutilpkg.ParseGeoResponse(`{"success":true,"country":"Singapore","country_code":"SG"}`)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if code != "SG" || name != "Singapore" {
t.Fatalf("unexpected geo parse result: code=%q name=%q", code, name)
}
code, name, err = egressutilpkg.ParseGeoResponse(`{"status":"success","country":"Netherlands","countryCode":"NL"}`)
if err != nil {
t.Fatalf("unexpected ip-api error: %v", err)
}
if code != "NL" || name != "Netherlands" {
t.Fatalf("unexpected ip-api geo parse result: code=%q name=%q", code, name)
}
if _, _, err := egressutilpkg.ParseGeoResponse(`{"status":"fail","message":"private range"}`); err == nil {
t.Fatalf("expected geo parse error for fail status")
}
}
func TestNormalizeCountryCode(t *testing.T) {
if got := egressutilpkg.NormalizeCountryCode("sg"); got != "SG" {
t.Fatalf("NormalizeCountryCode(sg)=%q", got)
}
if got := egressutilpkg.NormalizeCountryCode("123"); got != "" {
t.Fatalf("NormalizeCountryCode(123) should be empty, got=%q", got)
}
}
func TestEgressResolvedHostForURL(t *testing.T) {
host, port, ip := egressutilpkg.ResolvedHostForURL("https://127.0.0.1/ip")
if host != "" || port != 0 || ip != "" {
t.Fatalf("expected empty resolve tuple for literal IP, got host=%q port=%d ip=%q", host, port, ip)
}
}
func TestEgressParseSingBoxSOCKSProxyURL(t *testing.T) {
root := map[string]any{
"inbounds": []any{
map[string]any{
"type": "socks",
"listen": "127.0.0.1",
"listen_port": 2080,
},
},
}
got := egressutilpkg.ParseSingBoxSOCKSProxyURL(root)
if got != "socks5h://127.0.0.1:2080" {
t.Fatalf("unexpected proxy url: %q", got)
}
root2 := map[string]any{
"inbounds": []any{
map[string]any{
"type": "mixed",
"listen": "0.0.0.0",
"listen_port": float64(1080),
},
},
}
got2 := egressutilpkg.ParseSingBoxSOCKSProxyURL(root2)
if got2 != "socks5h://127.0.0.1:1080" {
t.Fatalf("unexpected mixed proxy url: %q", got2)
}
}

View File

@@ -0,0 +1,48 @@
package egressutil
import (
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
)
func HTTPGetBody(client *http.Client, rawURL string, timeout time.Duration, userAgent string, maxBytes int64) (string, error) {
if client == nil {
return "", fmt.Errorf("http client is nil")
}
limit := maxBytes
if limit <= 0 {
limit = 8 * 1024
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
if err != nil {
return "", err
}
if strings.TrimSpace(userAgent) != "" {
req.Header.Set("User-Agent", strings.TrimSpace(userAgent))
}
req.Header.Set("Accept", "application/json, text/plain, */*")
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
body, _ := io.ReadAll(io.LimitReader(resp.Body, limit))
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
msg := strings.TrimSpace(string(body))
if msg == "" {
msg = resp.Status
}
return "", fmt.Errorf("%s -> %s", rawURL, msg)
}
return string(body), nil
}

View File

@@ -0,0 +1,25 @@
package egressutil
import "strings"
type IdentitySnapshot struct {
IP string
CountryCode string
CountryName string
UpdatedAt string
Stale bool
RefreshInProgress bool
LastError string
NextRetryAt string
}
func IdentityChanged(prev, next IdentitySnapshot) bool {
return strings.TrimSpace(prev.IP) != strings.TrimSpace(next.IP) ||
strings.TrimSpace(prev.CountryCode) != strings.TrimSpace(next.CountryCode) ||
strings.TrimSpace(prev.CountryName) != strings.TrimSpace(next.CountryName) ||
strings.TrimSpace(prev.UpdatedAt) != strings.TrimSpace(next.UpdatedAt) ||
prev.Stale != next.Stale ||
prev.RefreshInProgress != next.RefreshInProgress ||
strings.TrimSpace(prev.LastError) != strings.TrimSpace(next.LastError) ||
strings.TrimSpace(prev.NextRetryAt) != strings.TrimSpace(next.NextRetryAt)
}

View File

@@ -0,0 +1,19 @@
package egressutil
func ProbeFirstSuccess(endpoints []string, probe func(rawURL string) (string, error)) (string, []string) {
if len(endpoints) == 0 {
return "", nil
}
errs := make([]string, 0, len(endpoints))
for _, rawURL := range endpoints {
if probe == nil {
continue
}
val, err := probe(rawURL)
if err == nil {
return val, nil
}
errs = append(errs, err.Error())
}
return "", errs
}

View File

@@ -0,0 +1,43 @@
package egressutil
import (
"fmt"
"strings"
)
type ScopeTarget struct {
Scope string
Source string
SourceID string
}
func ParseScope(raw string, sanitizeID func(string) string) (ScopeTarget, error) {
scope := strings.ToLower(strings.TrimSpace(raw))
switch {
case scope == "adguardvpn":
return ScopeTarget{
Scope: "adguardvpn",
Source: "adguardvpn",
}, nil
case scope == "system":
return ScopeTarget{
Scope: "system",
Source: "system",
}, nil
case strings.HasPrefix(scope, "transport:"):
id := strings.TrimSpace(strings.TrimPrefix(scope, "transport:"))
if sanitizeID != nil {
id = sanitizeID(id)
}
if id == "" {
return ScopeTarget{}, fmt.Errorf("invalid transport scope id")
}
return ScopeTarget{
Scope: "transport:" + id,
Source: "transport",
SourceID: id,
}, nil
default:
return ScopeTarget{}, fmt.Errorf("invalid scope, expected adguardvpn|system|transport:<id>")
}
}

View File

@@ -0,0 +1,399 @@
package egressutil
import (
"context"
"encoding/json"
"fmt"
"net"
"net/netip"
"net/url"
"os/exec"
"strconv"
"strings"
"sync"
"time"
)
var (
curlPathOnce sync.Once
curlPath string
wgetPathOnce sync.Once
wgetPath string
)
func ParseIPFromBody(raw string) (string, error) {
s := strings.TrimSpace(raw)
if s == "" {
return "", fmt.Errorf("empty response")
}
var obj map[string]any
if strings.HasPrefix(s, "{") && strings.HasSuffix(s, "}") {
if err := json.Unmarshal([]byte(s), &obj); err == nil && obj != nil {
keys := []string{
"ip", "origin", "query", "your_ip", "client_ip",
"ip_addr", "ip_address", "address",
}
for _, key := range keys {
if v := strings.TrimSpace(AnyToString(obj[key])); v != "" {
if i := strings.Index(v, ","); i >= 0 {
v = strings.TrimSpace(v[:i])
}
if addr, err := netip.ParseAddr(v); err == nil {
return addr.String(), nil
}
}
}
}
}
parts := strings.FieldsFunc(s, func(r rune) bool {
switch r {
case '\n', '\r', '\t', ' ', ',', ';', '[', ']', '"', '\'':
return true
default:
return false
}
})
for _, part := range parts {
v := strings.TrimSpace(strings.TrimPrefix(part, "ip="))
if v == "" {
continue
}
if addr, err := netip.ParseAddr(v); err == nil {
return addr.String(), nil
}
}
return "", fmt.Errorf("cannot parse egress ip from response: %q", s)
}
func ParseGeoResponse(raw string) (string, string, error) {
var obj map[string]any
if err := json.Unmarshal([]byte(raw), &obj); err != nil {
return "", "", err
}
if v, ok := obj["success"]; ok {
if b, ok := v.(bool); ok && !b {
msg := strings.TrimSpace(AnyToString(obj["message"]))
if msg == "" {
msg = "geo lookup reported success=false"
}
return "", "", fmt.Errorf("%s", msg)
}
}
if status := strings.ToLower(strings.TrimSpace(AnyToString(obj["status"]))); status == "fail" {
msg := strings.TrimSpace(AnyToString(obj["message"]))
if msg == "" {
msg = "geo lookup status=fail"
}
return "", "", fmt.Errorf("%s", msg)
}
code := NormalizeCountryCode(FirstNonEmptyAny(obj, "country_code", "countryCode", "cc"))
name := strings.TrimSpace(FirstNonEmptyAny(obj, "country_name", "country", "countryName"))
if code == "" && name == "" {
return "", "", fmt.Errorf("geo response does not contain country fields")
}
return code, name, nil
}
func NormalizeCountryCode(raw string) string {
cc := strings.ToUpper(strings.TrimSpace(raw))
if len(cc) != 2 {
return ""
}
for _, ch := range cc {
if ch < 'A' || ch > 'Z' {
return ""
}
}
return cc
}
func IPEndpoints(envRaw string) []string {
raw := strings.TrimSpace(strings.ReplaceAll(envRaw, ";", ","))
if raw == "" {
return []string{
"https://api64.ipify.org",
"https://api.ipify.org",
"https://ifconfig.me/ip",
}
}
return ParseURLList(raw)
}
func GeoEndpointsForIP(envRaw, ip string) []string {
raw := strings.TrimSpace(strings.ReplaceAll(envRaw, ";", ","))
if raw == "" {
raw = "https://ipwho.is/%s,http://ip-api.com/json/%s?fields=status,country,countryCode,query,message"
}
base := ParseURLList(raw)
out := make([]string, 0, len(base))
for _, item := range base {
if strings.Contains(item, "%s") {
out = append(out, fmt.Sprintf(item, ip))
continue
}
out = append(out, strings.TrimRight(item, "/")+"/"+ip)
}
return out
}
func ParseURLList(raw string) []string {
parts := strings.FieldsFunc(raw, func(r rune) bool {
return r == ',' || r == '\n' || r == '\r' || r == '\t' || r == ' '
})
out := make([]string, 0, len(parts))
for _, part := range parts {
v := strings.TrimSpace(part)
if v == "" {
continue
}
if !strings.Contains(v, "://") {
v = "https://" + v
}
out = append(out, v)
}
return dedupeStrings(out)
}
func LimitEndpoints(in []string, maxN int) []string {
if len(in) == 0 {
return nil
}
if maxN <= 0 || maxN >= len(in) {
out := make([]string, 0, len(in))
out = append(out, in...)
return out
}
out := make([]string, 0, maxN)
out = append(out, in[:maxN]...)
return out
}
func JoinErrorsCompact(errs []string) string {
if len(errs) == 0 {
return "probe failed"
}
first := strings.TrimSpace(errs[0])
if first == "" {
first = "probe failed"
}
if len(errs) == 1 {
return first
}
return fmt.Sprintf("%s; +%d more", first, len(errs)-1)
}
func ParseSingBoxSOCKSProxyURL(root map[string]any) string {
if root == nil {
return ""
}
rawInbounds, ok := root["inbounds"].([]any)
if !ok || len(rawInbounds) == 0 {
return ""
}
for _, raw := range rawInbounds {
inb, ok := raw.(map[string]any)
if !ok {
continue
}
typ := strings.ToLower(strings.TrimSpace(AnyToString(inb["type"])))
if typ != "socks" && typ != "mixed" {
continue
}
port, ok := parseIntAny(inb["listen_port"])
if !ok || port <= 0 || port > 65535 {
continue
}
host := strings.TrimSpace(AnyToString(inb["listen"]))
switch host {
case "", "::", "::1", "0.0.0.0", "[::]":
host = "127.0.0.1"
}
if strings.TrimSpace(host) == "" {
host = "127.0.0.1"
}
return fmt.Sprintf("socks5h://%s:%d", host, port)
}
return ""
}
func ResolvedHostForURL(rawURL string) (string, int, string) {
u, err := url.Parse(strings.TrimSpace(rawURL))
if err != nil {
return "", 0, ""
}
host := strings.TrimSpace(u.Hostname())
if host == "" {
return "", 0, ""
}
if _, err := netip.ParseAddr(host); err == nil {
return "", 0, ""
}
port := 0
if p := strings.TrimSpace(u.Port()); p != "" {
n, err := strconv.Atoi(p)
if err == nil && n > 0 && n <= 65535 {
port = n
}
}
if port == 0 {
switch strings.ToLower(strings.TrimSpace(u.Scheme)) {
case "http":
port = 80
default:
port = 443
}
}
ip, err := ResolveHostIPv4(host, 2*time.Second)
if err != nil || ip == "" {
return "", 0, ""
}
return host, port, ip
}
func ResolveHostIPv4(host string, timeout time.Duration) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
addrs, err := net.DefaultResolver.LookupIPAddr(ctx, strings.TrimSpace(host))
if err != nil {
return "", err
}
for _, a := range addrs {
if ip4 := a.IP.To4(); ip4 != nil {
return ip4.String(), nil
}
}
return "", fmt.Errorf("no ipv4 address for host %q", host)
}
func ResolveCurlPath() string {
curlPathOnce.Do(func() {
if p, err := exec.LookPath("curl"); err == nil {
curlPath = strings.TrimSpace(p)
return
}
for _, cand := range []string{"/usr/bin/curl", "/bin/curl"} {
if _, err := exec.LookPath(cand); err == nil {
curlPath = strings.TrimSpace(cand)
return
}
}
curlPath = ""
})
return curlPath
}
func ResolveWgetPath() string {
wgetPathOnce.Do(func() {
if p, err := exec.LookPath("wget"); err == nil {
wgetPath = strings.TrimSpace(p)
return
}
for _, cand := range []string{"/usr/bin/wget", "/bin/wget"} {
if _, err := exec.LookPath(cand); err == nil {
wgetPath = strings.TrimSpace(cand)
return
}
}
wgetPath = ""
})
return wgetPath
}
func TimeoutSec(timeout time.Duration) int {
sec := int(timeout.Seconds())
if sec < 1 {
sec = 1
}
return sec
}
func FirstNonEmptyAny(obj map[string]any, keys ...string) string {
for _, key := range keys {
if v := strings.TrimSpace(AnyToString(obj[key])); v != "" {
return v
}
}
return ""
}
func AnyToString(v any) string {
switch x := v.(type) {
case string:
return x
case fmt.Stringer:
return x.String()
case int:
return strconv.Itoa(x)
case int64:
return strconv.FormatInt(x, 10)
case float64:
return strconv.FormatFloat(x, 'f', -1, 64)
case bool:
if x {
return "true"
}
return "false"
default:
return ""
}
}
func dedupeStrings(in []string) []string {
seen := map[string]struct{}{}
out := make([]string, 0, len(in))
for _, raw := range in {
v := strings.TrimSpace(raw)
if v == "" {
continue
}
if _, ok := seen[v]; ok {
continue
}
seen[v] = struct{}{}
out = append(out, v)
}
return out
}
func parseIntAny(v any) (int, bool) {
switch x := v.(type) {
case int:
return x, true
case int8:
return int(x), true
case int16:
return int(x), true
case int32:
return int(x), true
case int64:
return int(x), true
case uint:
return int(x), true
case uint8:
return int(x), true
case uint16:
return int(x), true
case uint32:
return int(x), true
case uint64:
return int(x), true
case float64:
return int(x), true
case string:
n, err := strconv.Atoi(strings.TrimSpace(x))
if err != nil {
return 0, false
}
return n, true
default:
return 0, false
}
}

View File

@@ -0,0 +1,90 @@
package app
import (
"os"
appcli "selective-vpn-api/app/cli"
)
// EN: Application entrypoint and process bootstrap.
// EN: This file wires CLI modes and delegates API runtime bootstrap.
// RU: Точка входа приложения и bootstrap процесса.
// RU: Этот файл связывает CLI-режимы и делегирует запуск API-сервера.
func Run() {
if code, handled := runLegacyCLIMode(os.Args[1:]); handled {
if code != 0 {
os.Exit(code)
}
return
}
RunAPIServer()
}
// RunAPIServer starts the HTTP/SSE API server mode.
func RunAPIServer() {
runAPIServerAtAddr("127.0.0.1:8080")
}
// RunRoutesUpdateCLI executes one-shot routes update mode.
func RunRoutesUpdateCLI(args []string) int {
return appcli.RunRoutesUpdate(args, appcli.RoutesUpdateDeps{
LockFile: lockFile,
Update: func(iface string) (bool, string) {
res := routesUpdate(iface)
return res.OK, res.Message
},
Stdout: os.Stdout,
Stderr: os.Stderr,
})
}
// RunRoutesClearCLI executes one-shot routes clear mode.
func RunRoutesClearCLI(args []string) int {
return appcli.RunRoutesClear(args, appcli.RoutesClearDeps{
Clear: func() (bool, string) {
res := routesClear()
return res.OK, res.Message
},
Stdout: os.Stdout,
Stderr: os.Stderr,
})
}
// RunAutoloopCLI executes autoloop mode.
func RunAutoloopCLI(args []string) int {
return appcli.RunAutoloop(args, appcli.AutoloopDeps{
StateDirDefault: stateDir,
ResolveIface: func(flagIface string) string {
resolvedIface := normalizePreferredIface(flagIface)
if resolvedIface == "" {
resolvedIface, _ = resolveTrafficIface(loadTrafficModeState().PreferredIface)
}
return resolvedIface
},
Run: func(params appcli.AutoloopParams) {
runAutoloop(
params.Iface,
params.Table,
params.MTU,
params.StateDir,
params.DefaultLocation,
)
},
Stderr: os.Stderr,
})
}
func runLegacyCLIMode(args []string) (int, bool) {
if len(args) == 0 {
return 0, false
}
switch args[0] {
case "routes-update", "-routes-update":
return RunRoutesUpdateCLI(args[1:]), true
case "routes-clear":
return RunRoutesClearCLI(args[1:]), true
case "autoloop":
return RunAutoloopCLI(args[1:]), true
default:
return 0, false
}
}

View File

@@ -4,97 +4,35 @@ import (
"os"
"strconv"
"strings"
"sync"
"time"
eventsbuspkg "selective-vpn-api/app/eventsbus"
)
// ---------------------------------------------------------------------
// события / event bus
// ---------------------------------------------------------------------
// EN: In-memory bounded event bus used for SSE replay and polling watchers.
// EN: It keeps only the latest N events and assigns monotonically increasing IDs.
// RU: Ограниченная in-memory шина событий для SSE-реплея и фоновых вотчеров.
// RU: Хранит только последние N событий и присваивает монотонно растущие ID.
type eventBus struct {
mu sync.Mutex
cond *sync.Cond
buf []Event
cap int
next int64
inner *eventsbuspkg.Bus
}
// ---------------------------------------------------------------------
// EN: `newEventBus` creates a new instance for event bus.
// RU: `newEventBus` - создает новый экземпляр для event bus.
// ---------------------------------------------------------------------
func newEventBus(capacity int) *eventBus {
if capacity < 16 {
capacity = 16
}
b := &eventBus{
cap: capacity,
buf: make([]Event, 0, capacity),
}
b.cond = sync.NewCond(&b.mu)
return b
return &eventBus{inner: eventsbuspkg.New(capacity)}
}
// ---------------------------------------------------------------------
// EN: `push` contains core logic for push.
// RU: `push` - содержит основную логику для push.
// ---------------------------------------------------------------------
func (b *eventBus) push(kind string, data interface{}) Event {
b.mu.Lock()
defer b.mu.Unlock()
b.next++
evt := Event{
ID: b.next,
Kind: kind,
Ts: time.Now().UTC().Format(time.RFC3339Nano),
Data: data,
}
if len(b.buf) >= b.cap {
b.buf = b.buf[1:]
}
b.buf = append(b.buf, evt)
b.cond.Broadcast()
return evt
ev := b.inner.Push(kind, data)
return Event{ID: ev.ID, Kind: ev.Kind, Ts: ev.Ts, Data: ev.Data}
}
// ---------------------------------------------------------------------
// EN: `since` contains core logic for since.
// RU: `since` - содержит основную логику для since.
// ---------------------------------------------------------------------
func (b *eventBus) since(id int64) []Event {
b.mu.Lock()
defer b.mu.Unlock()
return b.sinceLocked(id)
}
// ---------------------------------------------------------------------
// EN: `sinceLocked` contains core logic for since locked.
// RU: `sinceLocked` - содержит основную логику для since locked.
// ---------------------------------------------------------------------
func (b *eventBus) sinceLocked(id int64) []Event {
if len(b.buf) == 0 {
raw := b.inner.Since(id)
if len(raw) == 0 {
return nil
}
var out []Event
for _, ev := range b.buf {
if ev.ID > id {
out = append(out, ev)
}
out := make([]Event, 0, len(raw))
for _, ev := range raw {
out = append(out, Event{ID: ev.ID, Kind: ev.Kind, Ts: ev.Ts, Data: ev.Data})
}
return out
}
// ---------------------------------------------------------------------
// env helpers
// ---------------------------------------------------------------------
// EN: Positive integer env reader with safe default fallback.
// RU: Чтение положительного целого из env с безопасным fallback на дефолт.
func envInt(key string, def int) int {

View File

@@ -1,111 +1,32 @@
package app
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
"strings"
eventstreampkg "selective-vpn-api/app/eventstream"
"time"
)
// ---------------------------------------------------------------------
// events (SSE)
// ---------------------------------------------------------------------
// EN: Server-Sent Events transport with replay support via Last-Event-ID/since,
// EN: heartbeat pings, and periodic polling of the in-memory event buffer.
// RU: Транспорт Server-Sent Events с поддержкой реплея через Last-Event-ID/since,
// RU: heartbeat-пингами и периодическим опросом in-memory буфера событий.
// ---------------------------------------------------------------------
// SSE helpers
// ---------------------------------------------------------------------
func parseSinceID(r *http.Request) int64 {
sinceStr := strings.TrimSpace(r.URL.Query().Get("since"))
if sinceStr == "" {
sinceStr = strings.TrimSpace(r.Header.Get("Last-Event-ID"))
}
if sinceStr == "" {
return 0
}
if v, err := strconv.ParseInt(sinceStr, 10, 64); err == nil && v >= 0 {
return v
}
return 0
return eventstreampkg.ParseSinceID(r)
}
// ---------------------------------------------------------------------
// SSE stream handler
// ---------------------------------------------------------------------
func handleEventsStream(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming unsupported", http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
ctx := r.Context()
since := parseSinceID(r)
send := func(ev Event) error {
payload, err := json.Marshal(ev)
if err != nil {
return err
}
if _, err := fmt.Fprintf(w, "id: %d\nevent: %s\ndata: %s\n\n", ev.ID, ev.Kind, string(payload)); err != nil {
return err
}
flusher.Flush()
return nil
}
// initial replay
for _, ev := range events.since(since) {
if err := send(ev); err != nil {
return
}
since = ev.ID
}
// polling loop; lightweight for localhost
pollEvery := 500 * time.Millisecond
heartbeat := time.Duration(envInt("SVPN_EVENTS_HEARTBEAT_SEC", defaultHeartbeatSeconds)) * time.Second
pollTicker := time.NewTicker(pollEvery)
pingTicker := time.NewTicker(heartbeat)
defer pollTicker.Stop()
defer pingTicker.Stop()
for {
select {
case <-ctx.Done():
return
case <-pingTicker.C:
_, _ = io.WriteString(w, ": ping\n\n")
flusher.Flush()
case <-pollTicker.C:
evs := events.since(since)
if len(evs) == 0 {
continue
eventstreampkg.Serve(
w,
r,
500*time.Millisecond,
heartbeat,
func(since int64) []eventstreampkg.Event {
raw := events.since(since)
if len(raw) == 0 {
return nil
}
for _, ev := range evs {
if err := send(ev); err != nil {
return
}
since = ev.ID
out := make([]eventstreampkg.Event, 0, len(raw))
for _, ev := range raw {
out = append(out, eventstreampkg.Event{ID: ev.ID, Kind: ev.Kind, Data: ev})
}
}
}
return out
},
)
}

View File

@@ -0,0 +1,68 @@
package eventsbus
import (
"sync"
"time"
)
type Event struct {
ID int64
Kind string
Ts string
Data any
}
type Bus struct {
mu sync.Mutex
cond *sync.Cond
buf []Event
cap int
next int64
}
func New(capacity int) *Bus {
if capacity < 16 {
capacity = 16
}
b := &Bus{
cap: capacity,
buf: make([]Event, 0, capacity),
}
b.cond = sync.NewCond(&b.mu)
return b
}
func (b *Bus) Push(kind string, data any) Event {
b.mu.Lock()
defer b.mu.Unlock()
b.next++
evt := Event{
ID: b.next,
Kind: kind,
Ts: time.Now().UTC().Format(time.RFC3339Nano),
Data: data,
}
if len(b.buf) >= b.cap {
b.buf = b.buf[1:]
}
b.buf = append(b.buf, evt)
b.cond.Broadcast()
return evt
}
func (b *Bus) Since(id int64) []Event {
b.mu.Lock()
defer b.mu.Unlock()
if len(b.buf) == 0 {
return nil
}
out := make([]Event, 0, len(b.buf))
for _, ev := range b.buf {
if ev.ID > id {
out = append(out, ev)
}
}
return out
}

View File

@@ -0,0 +1,107 @@
package eventstream
import (
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
)
type Event struct {
ID int64
Kind string
Data any
}
func ParseSinceID(r *http.Request) int64 {
sinceStr := strings.TrimSpace(r.URL.Query().Get("since"))
if sinceStr == "" {
sinceStr = strings.TrimSpace(r.Header.Get("Last-Event-ID"))
}
if sinceStr == "" {
return 0
}
if v, err := strconv.ParseInt(sinceStr, 10, 64); err == nil && v >= 0 {
return v
}
return 0
}
func Serve(w http.ResponseWriter, r *http.Request, pollEvery, heartbeat time.Duration, loadSince func(int64) []Event) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming unsupported", http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
if pollEvery <= 0 {
pollEvery = 500 * time.Millisecond
}
if heartbeat <= 0 {
heartbeat = 15 * time.Second
}
ctx := r.Context()
since := ParseSinceID(r)
send := func(ev Event) error {
payload, err := json.Marshal(ev.Data)
if err != nil {
return err
}
if _, err := fmt.Fprintf(w, "id: %d\nevent: %s\ndata: %s\n\n", ev.ID, ev.Kind, string(payload)); err != nil {
return err
}
flusher.Flush()
return nil
}
if loadSince != nil {
for _, ev := range loadSince(since) {
if err := send(ev); err != nil {
return
}
since = ev.ID
}
}
pollTicker := time.NewTicker(pollEvery)
pingTicker := time.NewTicker(heartbeat)
defer pollTicker.Stop()
defer pingTicker.Stop()
for {
select {
case <-ctx.Done():
return
case <-pingTicker.C:
_, _ = io.WriteString(w, ": ping\n\n")
flusher.Flush()
case <-pollTicker.C:
if loadSince == nil {
continue
}
evs := loadSince(since)
if len(evs) == 0 {
continue
}
for _, ev := range evs {
if err := send(ev); err != nil {
return
}
since = ev.ID
}
}
}
}

View File

@@ -0,0 +1,30 @@
package app
import (
"fmt"
"strconv"
"strings"
)
func parsePort(raw string) int {
p := strings.TrimSpace(raw)
if p == "" {
return 0
}
n, err := strconv.Atoi(p)
if err != nil || n <= 0 || n > 65535 {
return 0
}
return n
}
func asString(v any) string {
switch vv := v.(type) {
case string:
return strings.TrimSpace(vv)
case nil:
return ""
default:
return strings.TrimSpace(fmt.Sprint(vv))
}
}

View File

@@ -1,59 +1,18 @@
package app
import (
"encoding/json"
"log"
"net/http"
"time"
httpxpkg "selective-vpn-api/app/httpx"
)
// ---------------------------------------------------------------------
// HTTP helpers
// ---------------------------------------------------------------------
// EN: Common HTTP helpers used by all endpoint groups for consistent JSON output,
// EN: lightweight request timing logs, and health probing.
// RU: Общие HTTP-хелперы для всех групп эндпоинтов: единый JSON-ответ,
// RU: лёгкое логирование длительности запросов и health-check.
// ---------------------------------------------------------------------
// request logging
// ---------------------------------------------------------------------
func logRequests(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
next.ServeHTTP(w, r)
log.Printf("%s %s %s", r.Method, r.URL.Path, time.Since(start))
})
return httpxpkg.LogRequests(next)
}
// ---------------------------------------------------------------------
// JSON response helper
// ---------------------------------------------------------------------
func writeJSON(w http.ResponseWriter, status int, v any) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(status)
if v == nil {
return
}
if err := json.NewEncoder(w).Encode(v); err != nil {
log.Printf("writeJSON error: %v", err)
}
httpxpkg.WriteJSON(w, status, v)
}
// ---------------------------------------------------------------------
// health endpoint
// ---------------------------------------------------------------------
func handleHealthz(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
writeJSON(w, http.StatusOK, map[string]string{
"status": "ok",
"time": time.Now().Format(time.RFC3339),
})
httpxpkg.HandleHealthz(w, r)
}

View File

@@ -0,0 +1,38 @@
package httpx
import (
"encoding/json"
"log"
"net/http"
"time"
)
func LogRequests(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
next.ServeHTTP(w, r)
log.Printf("%s %s %s", r.Method, r.URL.Path, time.Since(start))
})
}
func WriteJSON(w http.ResponseWriter, status int, v any) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(status)
if v == nil {
return
}
if err := json.NewEncoder(w).Encode(v); err != nil {
log.Printf("writeJSON error: %v", err)
}
}
func HandleHealthz(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
WriteJSON(w, http.StatusOK, map[string]string{
"status": "ok",
"time": time.Now().Format(time.RFC3339),
})
}

View File

@@ -1,400 +1,40 @@
package app
import (
"bytes"
"context"
"errors"
"fmt"
"net/netip"
"os/exec"
"sort"
"strings"
"time"
"github.com/cenkalti/backoff/v4"
nftupdatepkg "selective-vpn-api/app/nftupdate"
)
// ---------------------------------------------------------------------
// nft update helpers
// ---------------------------------------------------------------------
// EN: NFT set update strategy with interval compression and two execution modes:
// EN: atomic transaction first, then chunked fallback with per-IP recovery.
// RU: Стратегия обновления NFT-набора с компрессией интервалов и двумя режимами:
// RU: сначала атомарная транзакция, затем chunked fallback с поштучным восстановлением.
func nftLog(format string, args ...any) {
appendTraceLine("routes", fmt.Sprintf(format, args...))
}
// ---------------------------------------------------------------------
// interval compression
// ---------------------------------------------------------------------
// compressIPIntervals убирает:
// - дубликаты строк
// - подсети, целиком покрытые более широкими подсетями
// - одиночные IP, попадающие в уже имеющиеся подсети
func compressIPIntervals(ips []string) []string {
// чтобы не гонять дубликаты строк
seen := make(map[string]struct{})
type prefixItem struct {
p netip.Prefix
raw string
}
type addrItem struct {
a netip.Addr
raw string
}
var prefixes []prefixItem
var addrs []addrItem
for _, s := range ips {
s = strings.TrimSpace(s)
if s == "" {
continue
}
if _, ok := seen[s]; ok {
continue
}
seen[s] = struct{}{}
if strings.Contains(s, "/") {
p, err := netip.ParsePrefix(s)
if err != nil {
// если формат кривой — просто пропускаем
continue
}
prefixes = append(prefixes, prefixItem{p: p, raw: s})
} else {
a, err := netip.ParseAddr(s)
if err != nil {
continue
}
addrs = append(addrs, addrItem{a: a, raw: s})
}
}
// 1) Убираем подсети, полностью покрытые более крупными подсетями.
//
// Сначала сортируем по:
// - адресу
// - длине префикса (меньший Bits = более широкая сеть) — раньше
sort.Slice(prefixes, func(i, j int) bool {
ai := prefixes[i].p.Addr()
aj := prefixes[j].p.Addr()
if ai == aj {
return prefixes[i].p.Bits() < prefixes[j].p.Bits()
}
return ai.Less(aj)
})
var keptPrefixes []prefixItem
for _, pi := range prefixes {
covered := false
for _, kp := range keptPrefixes {
// если более крупная сеть kp уже покрывает эту — пропускаем
if kp.p.Bits() <= pi.p.Bits() && kp.p.Contains(pi.p.Addr()) {
covered = true
break
}
}
if !covered {
keptPrefixes = append(keptPrefixes, pi)
}
}
var keptAddrs []addrItem
for _, ai := range addrs {
inNet := false
for _, kp := range keptPrefixes {
if kp.p.Contains(ai.a) {
inNet = true
break
}
}
if !inNet {
keptAddrs = append(keptAddrs, ai)
}
}
// 3) Собираем финальный список строк
out := make([]string, 0, len(keptPrefixes)+len(keptAddrs))
for _, ai := range keptAddrs {
out = append(out, ai.raw)
}
for _, pi := range keptPrefixes {
out = append(out, pi.raw)
}
return out
}
// ---------------------------------------------------------------------
// smart update strategy
// ---------------------------------------------------------------------
// умный апдейтер: сначала atomic, при фейле fallback на chunked
func nftUpdateIPsSmart(ctx context.Context, ips []string, progressCb ProgressCallback) error {
return nftUpdateSetIPsSmart(ctx, "agvpn4", ips, progressCb)
return nftupdatepkg.UpdateIPsSmart(
ctx,
ips,
func(percent int, message string) {
if progressCb != nil {
progressCb(percent, message)
}
},
runCommandTimeout,
nftLog,
)
}
// nftUpdateSetIPsSmart — тот же апдейтер, но для произвольного nft set.
func nftUpdateSetIPsSmart(ctx context.Context, setName string, ips []string, progressCb ProgressCallback) error {
setName = strings.TrimSpace(setName)
if setName == "" {
setName = "agvpn4"
}
if len(ips) == 0 {
if progressCb != nil {
progressCb(100, "nothing to update")
}
return nil
}
// Сжимаем IP / подсети, убираем пересечения и дубликаты
origCount := len(ips)
ips = compressIPIntervals(ips)
if len(ips) != origCount {
nftLog(
"compress(%s): %d -> %d IP elements (removed %d covered/duplicate entries)",
setName, origCount, len(ips), origCount-len(ips),
)
}
if len(ips) == 0 {
if progressCb != nil {
progressCb(100, "nothing to update after compression")
}
return nil
}
nftLog("nftUpdateSetIPsSmart(%s): start, ips=%d", setName, len(ips))
// 1) atomic транзакция через nft -f -
if err := nftAtomicUpdateWithProgress(ctx, setName, ips, progressCb); err == nil {
return nil
} else {
// если контекст умер дальше не идём
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
nftLog("atomic update cancelled (%s): %v", setName, err)
return err
}
nftLog("atomic nft update failed (%s): %v; falling back to chunked mode", setName, err)
if progressCb != nil {
progressCb(0, "Falling back to non-atomic update")
}
}
// 2) fallback: flush + chunked с поштучным фолбэком
return nftChunkedUpdateWithFallback(ctx, setName, ips, progressCb)
}
// ---------------------------------------------------------------------
// atomic updater
// ---------------------------------------------------------------------
// атомарный апдейт через один nft-транзакционный скрипт
func nftAtomicUpdateWithProgress(ctx context.Context, setName string, ips []string, progressCb ProgressCallback) error {
if len(ips) == 0 {
if progressCb != nil {
progressCb(100, "nothing to update")
}
return nil
}
sort.Strings(ips) // стабильность
total := len(ips)
chunkSize := 500 // старт
bo := backoff.NewExponentialBackOff()
bo.InitialInterval = 500 * time.Millisecond
bo.MaxInterval = 10 * time.Second
bo.MaxElapsedTime = 2 * time.Minute
return backoff.Retry(func() error {
select {
case <-ctx.Done():
if progressCb != nil {
progressCb(0, "Cancelled by context")
}
return ctx.Err()
default:
}
var script strings.Builder
script.WriteString("flush set inet agvpn " + setName + "\n")
processed := 0
chunksTotal := (len(ips) + chunkSize - 1) / chunkSize
for i := 0; i < len(ips); i += chunkSize {
end := i + chunkSize
if end > len(ips) {
end = len(ips)
}
chunk := ips[i:end]
script.WriteString("add element inet agvpn " + setName + " { ")
script.WriteString(strings.Join(chunk, ", "))
script.WriteString(" }\n")
processed += len(chunk)
if progressCb != nil {
percent := processed * 100 / total
progressCb(percent, fmt.Sprintf(
"Preparing chunk %d/%d (%d/%d IPs)",
i/chunkSize+1, chunksTotal, processed, total,
))
}
}
if progressCb != nil {
progressCb(90, "Executing nft transaction...")
}
cmd := exec.CommandContext(ctx, "nft", "-f", "-")
cmd.Stdin = strings.NewReader(script.String())
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
if err == nil {
nftLog("nft atomic transaction success (%s): %d IPs added", setName, len(ips))
if progressCb != nil {
progressCb(100, "Update complete")
}
return nil
}
errStr := stderr.String()
nftLog("nft atomic transaction failed (%s): err=%v, stderr=%q", setName, err, errStr)
// Ошибки, требующие уменьшения чанка
if strings.Contains(errStr, "too many elements") ||
strings.Contains(errStr, "out of memory") ||
strings.Contains(errStr, "interval overlaps") ||
strings.Contains(errStr, "conflicting intervals") {
newSize := chunkSize / 2
if newSize < 100 {
newSize = 100
}
if newSize == chunkSize {
// дальше делить некуда — Permanent → fallback
return backoff.Permanent(fmt.Errorf("atomic nft cannot shrink further: %w", err))
}
nftLog("reducing atomic chunk size from %d to %d and retrying", chunkSize, newSize)
chunkSize = newSize
if progressCb != nil {
progressCb(0, fmt.Sprintf("Retrying atomic with smaller chunks (%d IPs)", chunkSize))
}
return fmt.Errorf("retry atomic with smaller chunks")
}
// Другие ошибки — Permanent (переход к chunked)
return backoff.Permanent(fmt.Errorf("nft atomic transaction failed: %w", err))
}, bo)
}
// ---------------------------------------------------------------------
// chunked fallback updater
// ---------------------------------------------------------------------
// nftChunkedUpdateWithFallback — fallback-режим: flush + чанки + поштучно при ошибках
func nftChunkedUpdateWithFallback(ctx context.Context, setName string, ips []string, progressCb ProgressCallback) error {
if len(ips) == 0 {
if progressCb != nil {
progressCb(100, "nothing to update")
}
return nil
}
sort.Strings(ips)
total := len(ips)
chunkSize := 500
// flush
_, stderr, _, err := runCommandTimeout(10*time.Second,
"nft", "flush", "set", "inet", "agvpn", setName)
if err != nil {
return fmt.Errorf("flush set failed: %v (%s)", err, stderr)
}
processed := 0
for i := 0; i < len(ips); i += chunkSize {
select {
case <-ctx.Done():
if progressCb != nil {
progressCb(0, "Cancelled during chunked update")
}
return ctx.Err()
default:
}
end := i + chunkSize
if end > len(ips) {
end = len(ips)
}
chunk := ips[i:end]
cmdArgs := []string{
"nft", "add", "element", "inet", "agvpn", setName,
"{ " + strings.Join(chunk, ", ") + " }",
}
cmdName := cmdArgs[0]
cmdRest := cmdArgs[1:]
_, stderr, _, err := runCommandTimeout(15*time.Second, cmdName, cmdRest...)
if err != nil {
// типичные ошибки → поштучно
if strings.Contains(stderr, "interval overlaps") ||
strings.Contains(stderr, "too many elements") ||
strings.Contains(stderr, "out of memory") ||
strings.Contains(stderr, "conflicting intervals") {
nftLog("chunk failed (%d IPs), fallback per-ip", len(chunk))
if progressCb != nil {
progressCb(processed*100/total,
fmt.Sprintf("Chunk failed -> adding %d IPs one by one", len(chunk)))
}
for _, ip := range chunk {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
_, _, _, _ = runCommandTimeout(5*time.Second,
"nft", "add", "element", "inet", "agvpn", setName, "{ "+ip+" }")
}
} else {
return fmt.Errorf("nft chunk add failed: %v (%s)", err, stderr)
}
}
processed += len(chunk)
if progressCb != nil {
percent := processed * 100 / total
progressCb(percent, fmt.Sprintf("Added %d/%d IPs", processed, total))
}
}
if progressCb != nil {
progressCb(100, "chunked update complete")
}
nftLog("nft chunked update success (%s): %d IPs", setName, len(ips))
return nil
return nftupdatepkg.UpdateSetIPsSmart(
ctx,
setName,
ips,
func(percent int, message string) {
if progressCb != nil {
progressCb(percent, message)
}
},
runCommandTimeout,
nftLog,
)
}

View File

@@ -0,0 +1,338 @@
package nftupdate
import (
"bytes"
"context"
"errors"
"fmt"
"net/netip"
"os/exec"
"sort"
"strings"
"time"
"github.com/cenkalti/backoff/v4"
)
type ProgressCallback func(percent int, message string)
type CmdRunner func(timeout time.Duration, name string, args ...string) (stdout, stderr string, exitCode int, err error)
type Logger func(format string, args ...any)
func UpdateIPsSmart(ctx context.Context, ips []string, progressCb ProgressCallback, runCmd CmdRunner, logf Logger) error {
return UpdateSetIPsSmart(ctx, "agvpn4", ips, progressCb, runCmd, logf)
}
func UpdateSetIPsSmart(ctx context.Context, setName string, ips []string, progressCb ProgressCallback, runCmd CmdRunner, logf Logger) error {
setName = strings.TrimSpace(setName)
if setName == "" {
setName = "agvpn4"
}
if runCmd == nil {
return fmt.Errorf("run command function is not configured")
}
if len(ips) == 0 {
if progressCb != nil {
progressCb(100, "nothing to update")
}
return nil
}
origCount := len(ips)
ips = compressIPIntervals(ips)
if len(ips) != origCount {
log(logf, "compress(%s): %d -> %d IP elements (removed %d covered/duplicate entries)", setName, origCount, len(ips), origCount-len(ips))
}
if len(ips) == 0 {
if progressCb != nil {
progressCb(100, "nothing to update after compression")
}
return nil
}
log(logf, "nft UpdateSetIPsSmart(%s): start, ips=%d", setName, len(ips))
if err := atomicUpdateWithProgress(ctx, setName, ips, progressCb, logf); err == nil {
return nil
} else {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
log(logf, "atomic update cancelled (%s): %v", setName, err)
return err
}
log(logf, "atomic nft update failed (%s): %v; falling back to chunked mode", setName, err)
if progressCb != nil {
progressCb(0, "Falling back to non-atomic update")
}
}
return chunkedUpdateWithFallback(ctx, setName, ips, progressCb, runCmd, logf)
}
func compressIPIntervals(ips []string) []string {
seen := make(map[string]struct{})
type prefixItem struct {
p netip.Prefix
raw string
}
type addrItem struct {
a netip.Addr
raw string
}
var prefixes []prefixItem
var addrs []addrItem
for _, s := range ips {
s = strings.TrimSpace(s)
if s == "" {
continue
}
if _, ok := seen[s]; ok {
continue
}
seen[s] = struct{}{}
if strings.Contains(s, "/") {
p, err := netip.ParsePrefix(s)
if err != nil {
continue
}
prefixes = append(prefixes, prefixItem{p: p, raw: s})
} else {
a, err := netip.ParseAddr(s)
if err != nil {
continue
}
addrs = append(addrs, addrItem{a: a, raw: s})
}
}
sort.Slice(prefixes, func(i, j int) bool {
ai := prefixes[i].p.Addr()
aj := prefixes[j].p.Addr()
if ai == aj {
return prefixes[i].p.Bits() < prefixes[j].p.Bits()
}
return ai.Less(aj)
})
var keptPrefixes []prefixItem
for _, pi := range prefixes {
covered := false
for _, kp := range keptPrefixes {
if kp.p.Bits() <= pi.p.Bits() && kp.p.Contains(pi.p.Addr()) {
covered = true
break
}
}
if !covered {
keptPrefixes = append(keptPrefixes, pi)
}
}
var keptAddrs []addrItem
for _, ai := range addrs {
inNet := false
for _, kp := range keptPrefixes {
if kp.p.Contains(ai.a) {
inNet = true
break
}
}
if !inNet {
keptAddrs = append(keptAddrs, ai)
}
}
out := make([]string, 0, len(keptPrefixes)+len(keptAddrs))
for _, ai := range keptAddrs {
out = append(out, ai.raw)
}
for _, pi := range keptPrefixes {
out = append(out, pi.raw)
}
return out
}
func atomicUpdateWithProgress(ctx context.Context, setName string, ips []string, progressCb ProgressCallback, logf Logger) error {
if len(ips) == 0 {
if progressCb != nil {
progressCb(100, "nothing to update")
}
return nil
}
sort.Strings(ips)
total := len(ips)
chunkSize := 500
bo := backoff.NewExponentialBackOff()
bo.InitialInterval = 500 * time.Millisecond
bo.MaxInterval = 10 * time.Second
bo.MaxElapsedTime = 2 * time.Minute
return backoff.Retry(func() error {
select {
case <-ctx.Done():
if progressCb != nil {
progressCb(0, "Cancelled by context")
}
return ctx.Err()
default:
}
var script strings.Builder
script.WriteString("flush set inet agvpn " + setName + "\n")
processed := 0
chunksTotal := (len(ips) + chunkSize - 1) / chunkSize
for i := 0; i < len(ips); i += chunkSize {
end := i + chunkSize
if end > len(ips) {
end = len(ips)
}
chunk := ips[i:end]
script.WriteString("add element inet agvpn " + setName + " { ")
script.WriteString(strings.Join(chunk, ", "))
script.WriteString(" }\n")
processed += len(chunk)
if progressCb != nil {
percent := processed * 100 / total
progressCb(percent, fmt.Sprintf("Preparing chunk %d/%d (%d/%d IPs)", i/chunkSize+1, chunksTotal, processed, total))
}
}
if progressCb != nil {
progressCb(90, "Executing nft transaction...")
}
cmd := exec.CommandContext(ctx, "nft", "-f", "-")
cmd.Stdin = strings.NewReader(script.String())
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
if err == nil {
log(logf, "nft atomic transaction success (%s): %d IPs added", setName, len(ips))
if progressCb != nil {
progressCb(100, "Update complete")
}
return nil
}
errStr := stderr.String()
log(logf, "nft atomic transaction failed (%s): err=%v, stderr=%q", setName, err, errStr)
if strings.Contains(errStr, "too many elements") ||
strings.Contains(errStr, "out of memory") ||
strings.Contains(errStr, "interval overlaps") ||
strings.Contains(errStr, "conflicting intervals") {
newSize := chunkSize / 2
if newSize < 100 {
newSize = 100
}
if newSize == chunkSize {
return backoff.Permanent(fmt.Errorf("atomic nft cannot shrink further: %w", err))
}
log(logf, "reducing atomic chunk size from %d to %d and retrying", chunkSize, newSize)
chunkSize = newSize
if progressCb != nil {
progressCb(0, fmt.Sprintf("Retrying atomic with smaller chunks (%d IPs)", chunkSize))
}
return fmt.Errorf("retry atomic with smaller chunks")
}
return backoff.Permanent(fmt.Errorf("nft atomic transaction failed: %w", err))
}, bo)
}
func chunkedUpdateWithFallback(ctx context.Context, setName string, ips []string, progressCb ProgressCallback, runCmd CmdRunner, logf Logger) error {
if len(ips) == 0 {
if progressCb != nil {
progressCb(100, "nothing to update")
}
return nil
}
sort.Strings(ips)
total := len(ips)
chunkSize := 500
_, stderr, _, err := runCmd(10*time.Second, "nft", "flush", "set", "inet", "agvpn", setName)
if err != nil {
return fmt.Errorf("flush set failed: %v (%s)", err, stderr)
}
processed := 0
for i := 0; i < len(ips); i += chunkSize {
select {
case <-ctx.Done():
if progressCb != nil {
progressCb(0, "Cancelled during chunked update")
}
return ctx.Err()
default:
}
end := i + chunkSize
if end > len(ips) {
end = len(ips)
}
chunk := ips[i:end]
_, stderr, _, err := runCmd(15*time.Second, "nft", "add", "element", "inet", "agvpn", setName, "{ "+strings.Join(chunk, ", ")+" }")
if err != nil {
if strings.Contains(stderr, "interval overlaps") ||
strings.Contains(stderr, "too many elements") ||
strings.Contains(stderr, "out of memory") ||
strings.Contains(stderr, "conflicting intervals") {
log(logf, "chunk failed (%d IPs), fallback per-ip", len(chunk))
if progressCb != nil {
progressCb(processed*100/total, fmt.Sprintf("Chunk failed -> adding %d IPs one by one", len(chunk)))
}
for _, ip := range chunk {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
_, _, _, _ = runCmd(5*time.Second, "nft", "add", "element", "inet", "agvpn", setName, "{ "+ip+" }")
}
} else {
return fmt.Errorf("nft chunk add failed: %v (%s)", err, stderr)
}
}
processed += len(chunk)
if progressCb != nil {
percent := processed * 100 / total
progressCb(percent, fmt.Sprintf("Added %d/%d IPs", processed, total))
}
}
if progressCb != nil {
progressCb(100, "chunked update complete")
}
log(logf, "nft chunked update success (%s): %d IPs", setName, len(ips))
return nil
}
func log(logf Logger, format string, args ...any) {
if logf != nil {
logf(format, args...)
}
}

View File

@@ -0,0 +1,64 @@
package app
import (
refreshcoordpkg "selective-vpn-api/app/refreshcoord"
"time"
)
type refreshStateSnapshot = refreshcoordpkg.Snapshot
type refreshCoordinator struct {
inner refreshcoordpkg.Coordinator
}
func newRefreshCoordinator(freshTTL, backoffMin, backoffMax time.Duration) refreshCoordinator {
return refreshCoordinator{inner: refreshcoordpkg.New(freshTTL, backoffMin, backoffMax)}
}
func (c *refreshCoordinator) setUpdatedAt(at time.Time) {
c.inner.SetUpdatedAt(at)
}
func (c *refreshCoordinator) beginRefresh(now time.Time, force bool, hasData bool) bool {
return c.inner.BeginRefresh(now, force, hasData)
}
func (c *refreshCoordinator) shouldRefresh(now time.Time, force bool, hasData bool) bool {
return c.inner.ShouldRefresh(now, force, hasData)
}
func (c *refreshCoordinator) isStale(now time.Time) bool {
return c.inner.IsStale(now)
}
func (c *refreshCoordinator) finishSuccess(now time.Time) {
c.inner.FinishSuccess(now)
}
func (c *refreshCoordinator) finishError(msg string, now time.Time) {
c.inner.FinishError(msg, now)
}
func (c *refreshCoordinator) snapshot(now time.Time) refreshStateSnapshot {
return c.inner.Snapshot(now)
}
func (c *refreshCoordinator) refreshInProgress() bool {
return c.inner.RefreshInProgress()
}
func (c *refreshCoordinator) nextRetryAt() time.Time {
return c.inner.NextRetryAt()
}
func (c *refreshCoordinator) clearBackoff() {
c.inner.ClearBackoff()
}
func (c *refreshCoordinator) consecutiveErrors() int {
return c.inner.ConsecutiveErrors()
}
func (c *refreshCoordinator) lastError() string {
return c.inner.LastError()
}

View File

@@ -0,0 +1,80 @@
package app
import (
"testing"
"time"
)
func TestRefreshCoordinatorLifecycle(t *testing.T) {
now := time.Date(2026, time.March, 9, 21, 0, 0, 0, time.UTC)
rc := newRefreshCoordinator(10*time.Minute, 2*time.Second, 60*time.Second)
if !rc.shouldRefresh(now, false, false) {
t.Fatalf("expected refresh when cache is empty")
}
if !rc.beginRefresh(now, false, false) {
t.Fatalf("expected beginRefresh=true for empty cache")
}
if rc.shouldRefresh(now, false, false) {
t.Fatalf("must not refresh while refresh is in progress")
}
rc.finishSuccess(now)
snap := rc.snapshot(now)
if snap.RefreshInProgress {
t.Fatalf("refresh must be finished after success")
}
if snap.Stale {
t.Fatalf("fresh snapshot must not be stale")
}
if snap.LastError != "" || snap.NextRetryAt != "" {
t.Fatalf("success must clear error and retry metadata: %#v", snap)
}
if rc.shouldRefresh(now.Add(5*time.Minute), false, true) {
t.Fatalf("fresh cache should not refresh yet")
}
if !rc.shouldRefresh(now.Add(11*time.Minute), false, true) {
t.Fatalf("stale cache should refresh")
}
}
func TestRefreshCoordinatorBackoffAndReset(t *testing.T) {
start := time.Date(2026, time.March, 9, 21, 5, 0, 0, time.UTC)
rc := newRefreshCoordinator(10*time.Minute, 2*time.Second, 60*time.Second)
type step struct {
at time.Time
expected time.Duration
}
steps := []step{
{at: start, expected: 2 * time.Second},
{at: start.Add(2 * time.Second), expected: 4 * time.Second},
{at: start.Add(6 * time.Second), expected: 8 * time.Second},
{at: start.Add(14 * time.Second), expected: 16 * time.Second},
{at: start.Add(30 * time.Second), expected: 32 * time.Second},
{at: start.Add(62 * time.Second), expected: 60 * time.Second},
}
for i, st := range steps {
rc.finishError("probe failed", st.at)
got := rc.nextRetryAt().Sub(st.at)
if got != st.expected {
t.Fatalf("step=%d unexpected backoff: got=%s want=%s", i+1, got, st.expected)
}
if rc.lastError() == "" {
t.Fatalf("step=%d expected non-empty lastError", i+1)
}
}
rc.finishSuccess(start.Add(2 * time.Minute))
if rc.consecutiveErrors() != 0 {
t.Fatalf("expected consecutiveErrors reset, got %d", rc.consecutiveErrors())
}
if !rc.nextRetryAt().IsZero() {
t.Fatalf("expected nextRetryAt reset")
}
if rc.lastError() != "" {
t.Fatalf("expected lastError reset, got %q", rc.lastError())
}
}

View File

@@ -0,0 +1,164 @@
package refreshcoord
import (
"strings"
"time"
)
type Snapshot struct {
UpdatedAt string
Stale bool
RefreshInProgress bool
LastError string
NextRetryAt string
}
// Coordinator is a small reusable SWR-like state machine:
// stale-while-refresh + single-flight + exponential backoff after failures.
//
// All methods are intentionally lock-free; caller owns synchronization.
type Coordinator struct {
freshTTL time.Duration
backoffMin time.Duration
backoffMax time.Duration
updatedAt time.Time
lastError string
refreshInProgress bool
consecutiveErrors int
nextRetryAt time.Time
}
func New(freshTTL, backoffMin, backoffMax time.Duration) Coordinator {
if freshTTL <= 0 {
freshTTL = 10 * time.Minute
}
if backoffMin <= 0 {
backoffMin = 2 * time.Second
}
if backoffMax <= 0 {
backoffMax = 60 * time.Second
}
if backoffMax < backoffMin {
backoffMax = backoffMin
}
return Coordinator{
freshTTL: freshTTL,
backoffMin: backoffMin,
backoffMax: backoffMax,
}
}
func (c *Coordinator) SetUpdatedAt(at time.Time) {
c.updatedAt = at
}
func (c *Coordinator) BeginRefresh(now time.Time, force bool, hasData bool) bool {
if !c.ShouldRefresh(now, force, hasData) {
return false
}
c.refreshInProgress = true
return true
}
func (c *Coordinator) ShouldRefresh(now time.Time, force bool, hasData bool) bool {
if c.refreshInProgress {
return false
}
if !c.nextRetryAt.IsZero() && now.Before(c.nextRetryAt) {
return false
}
if force {
return true
}
if !hasData {
return true
}
return c.IsStale(now)
}
func (c *Coordinator) IsStale(now time.Time) bool {
if c.updatedAt.IsZero() {
return true
}
return now.Sub(c.updatedAt) > c.freshTTL
}
func (c *Coordinator) FinishSuccess(now time.Time) {
c.updatedAt = now
c.lastError = ""
c.refreshInProgress = false
c.consecutiveErrors = 0
c.nextRetryAt = time.Time{}
}
func (c *Coordinator) FinishError(msg string, now time.Time) {
c.lastError = strings.TrimSpace(msg)
c.refreshInProgress = false
c.consecutiveErrors++
c.nextRetryAt = now.Add(c.nextBackoff())
}
func (c *Coordinator) Snapshot(now time.Time) Snapshot {
out := Snapshot{
Stale: c.IsStale(now),
RefreshInProgress: c.refreshInProgress,
LastError: strings.TrimSpace(c.lastError),
}
if !c.updatedAt.IsZero() {
out.UpdatedAt = c.updatedAt.UTC().Format(time.RFC3339)
}
if !c.nextRetryAt.IsZero() {
out.NextRetryAt = c.nextRetryAt.UTC().Format(time.RFC3339)
}
return out
}
func (c *Coordinator) RefreshInProgress() bool {
return c.refreshInProgress
}
func (c *Coordinator) NextRetryAt() time.Time {
return c.nextRetryAt
}
func (c *Coordinator) ConsecutiveErrors() int {
return c.consecutiveErrors
}
func (c *Coordinator) LastError() string {
return c.lastError
}
func (c *Coordinator) ClearBackoff() {
c.nextRetryAt = time.Time{}
}
func (c *Coordinator) nextBackoff() time.Duration {
backoff := c.backoffMin
if backoff <= 0 {
backoff = 2 * time.Second
}
maxBackoff := c.backoffMax
if maxBackoff <= 0 {
maxBackoff = backoff
}
if maxBackoff < backoff {
maxBackoff = backoff
}
for i := 1; i < c.consecutiveErrors; i++ {
if backoff >= maxBackoff {
return maxBackoff
}
if backoff > maxBackoff/2 {
return maxBackoff
}
backoff *= 2
}
if backoff > maxBackoff {
backoff = maxBackoff
}
return backoff
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,88 @@
package resolver
import "sort"
type ResolverArtifacts struct {
IPs []string
DirectIPs []string
WildcardIPs []string
IPMap [][2]string
DirectIPMap [][2]string
WildcardIPMap [][2]string
}
func BuildResolverArtifacts(resolved map[string][]string, staticLabels map[string][]string, isWildcardHost func(string) bool) ResolverArtifacts {
ipSetAll := map[string]struct{}{}
ipSetDirect := map[string]struct{}{}
ipSetWildcard := map[string]struct{}{}
ipMapAll := map[string]map[string]struct{}{}
ipMapDirect := map[string]map[string]struct{}{}
ipMapWildcard := map[string]map[string]struct{}{}
add := func(set map[string]struct{}, labels map[string]map[string]struct{}, ip, label string) {
if ip == "" {
return
}
set[ip] = struct{}{}
m := labels[ip]
if m == nil {
m = map[string]struct{}{}
labels[ip] = m
}
m[label] = struct{}{}
}
for host, ips := range resolved {
wildcardHost := false
if isWildcardHost != nil {
wildcardHost = isWildcardHost(host)
}
for _, ip := range ips {
add(ipSetAll, ipMapAll, ip, host)
if wildcardHost {
add(ipSetWildcard, ipMapWildcard, ip, host)
} else {
add(ipSetDirect, ipMapDirect, ip, host)
}
}
}
for ipEntry, labels := range staticLabels {
for _, lbl := range labels {
add(ipSetAll, ipMapAll, ipEntry, lbl)
add(ipSetDirect, ipMapDirect, ipEntry, lbl)
}
}
var out ResolverArtifacts
appendMapPairs := func(dst *[][2]string, labelsByIP map[string]map[string]struct{}) {
for ip := range labelsByIP {
labels := labelsByIP[ip]
for lbl := range labels {
*dst = append(*dst, [2]string{ip, lbl})
}
}
sort.Slice(*dst, func(i, j int) bool {
if (*dst)[i][0] == (*dst)[j][0] {
return (*dst)[i][1] < (*dst)[j][1]
}
return (*dst)[i][0] < (*dst)[j][0]
})
}
appendIPs := func(dst *[]string, set map[string]struct{}) {
for ip := range set {
*dst = append(*dst, ip)
}
sort.Strings(*dst)
}
appendMapPairs(&out.IPMap, ipMapAll)
appendMapPairs(&out.DirectIPMap, ipMapDirect)
appendMapPairs(&out.WildcardIPMap, ipMapWildcard)
appendIPs(&out.IPs, ipSetAll)
appendIPs(&out.DirectIPs, ipSetDirect)
appendIPs(&out.WildcardIPs, ipSetWildcard)
return out
}

View File

@@ -0,0 +1,60 @@
package resolver
import (
"hash/fnv"
"regexp"
"strconv"
"strings"
)
var reANSI = regexp.MustCompile(`\x1B\[[0-9;]*[A-Za-z]`)
func UniqueStrings(in []string) []string {
seen := map[string]struct{}{}
var out []string
for _, v := range in {
if _, ok := seen[v]; !ok {
seen[v] = struct{}{}
out = append(out, v)
}
}
return out
}
func PickDNSStartIndex(host string, size int) int {
if size <= 1 {
return 0
}
h := fnv.New32a()
_, _ = h.Write([]byte(strings.ToLower(strings.TrimSpace(host))))
return int(h.Sum32() % uint32(size))
}
func StripANSI(s string) string {
return reANSI.ReplaceAllString(s, "")
}
func IsPrivateIPv4(ip string) bool {
parts := strings.Split(strings.Split(ip, "/")[0], ".")
if len(parts) != 4 {
return true
}
vals := make([]int, 4)
for i, p := range parts {
n, err := strconv.Atoi(p)
if err != nil || n < 0 || n > 255 {
return true
}
vals[i] = n
}
if vals[0] == 10 || vals[0] == 127 || vals[0] == 0 {
return true
}
if vals[0] == 192 && vals[1] == 168 {
return true
}
if vals[0] == 172 && vals[1] >= 16 && vals[1] <= 31 {
return true
}
return false
}

View File

@@ -0,0 +1,150 @@
package resolver
import (
"os"
"strings"
)
type DNSConfig struct {
Default []string
Meta []string
SmartDNS string
Mode string
}
type DNSConfigDeps struct {
ActivePool []string
IsSmartDNSForced bool
SmartDNSAddr string
SmartDNSForcedMode string
ResolveFallbackPool func() []string
MergeDNSUpstreamPools func(primary, fallback []string) []string
NormalizeDNSUpstream func(raw string, defaultPort string) string
NormalizeSmartDNSAddr func(raw string) string
NormalizeDNSResolverMode func(raw string) string
}
func LoadDNSConfig(path string, base DNSConfig, deps DNSConfigDeps, logf func(string, ...any)) DNSConfig {
cfg := DNSConfig{
Default: append([]string(nil), base.Default...),
Meta: append([]string(nil), base.Meta...),
SmartDNS: strings.TrimSpace(base.SmartDNS),
Mode: strings.TrimSpace(base.Mode),
}
if cfg.Mode == "" {
cfg.Mode = "direct"
}
if len(deps.ActivePool) > 0 {
cfg.Default = append([]string(nil), deps.ActivePool...)
cfg.Meta = append([]string(nil), deps.ActivePool...)
}
if deps.IsSmartDNSForced {
addr := strings.TrimSpace(deps.SmartDNSAddr)
if deps.NormalizeSmartDNSAddr != nil {
if n := deps.NormalizeSmartDNSAddr(addr); n != "" {
addr = n
}
}
if addr == "" {
addr = cfg.SmartDNS
}
cfg.Default = []string{addr}
cfg.Meta = []string{addr}
cfg.SmartDNS = addr
if strings.TrimSpace(deps.SmartDNSForcedMode) != "" {
cfg.Mode = deps.SmartDNSForcedMode
} else {
cfg.Mode = "smartdns"
}
if logf != nil {
logf("dns-config: SmartDNS forced (%s), ignore %s", addr, path)
}
return cfg
}
data, err := os.ReadFile(path)
if err != nil {
if logf != nil {
logf("dns-config: can't read %s: %v", path, err)
}
fallback := []string(nil)
if deps.ResolveFallbackPool != nil {
fallback = deps.ResolveFallbackPool()
}
if deps.MergeDNSUpstreamPools != nil {
cfg.Default = deps.MergeDNSUpstreamPools(cfg.Default, fallback)
cfg.Meta = deps.MergeDNSUpstreamPools(cfg.Meta, fallback)
}
return cfg
}
var def, meta []string
lines := strings.Split(string(data), "\n")
for _, ln := range lines {
s := strings.TrimSpace(ln)
if s == "" || strings.HasPrefix(s, "#") {
continue
}
parts := strings.Fields(s)
if len(parts) < 2 {
continue
}
key := strings.ToLower(parts[0])
vals := parts[1:]
switch key {
case "default":
for _, v := range vals {
if deps.NormalizeDNSUpstream != nil {
if n := deps.NormalizeDNSUpstream(v, "53"); n != "" {
def = append(def, n)
}
}
}
case "meta":
for _, v := range vals {
if deps.NormalizeDNSUpstream != nil {
if n := deps.NormalizeDNSUpstream(v, "53"); n != "" {
meta = append(meta, n)
}
}
}
case "smartdns":
if len(vals) > 0 && deps.NormalizeSmartDNSAddr != nil {
if n := deps.NormalizeSmartDNSAddr(vals[0]); n != "" {
cfg.SmartDNS = n
}
}
case "mode":
if len(vals) > 0 {
rawMode := vals[0]
if deps.NormalizeDNSResolverMode != nil {
cfg.Mode = deps.NormalizeDNSResolverMode(rawMode)
} else {
cfg.Mode = strings.ToLower(strings.TrimSpace(rawMode))
}
}
}
}
if len(deps.ActivePool) == 0 {
if len(def) > 0 {
cfg.Default = def
}
if len(meta) > 0 {
cfg.Meta = meta
}
}
fallback := []string(nil)
if deps.ResolveFallbackPool != nil {
fallback = deps.ResolveFallbackPool()
}
if deps.MergeDNSUpstreamPools != nil {
cfg.Default = deps.MergeDNSUpstreamPools(cfg.Default, fallback)
cfg.Meta = deps.MergeDNSUpstreamPools(cfg.Meta, fallback)
}
if logf != nil {
logf("dns-config: accept %s: mode=%s smartdns=%s default=%v; meta=%v", path, cfg.Mode, cfg.SmartDNS, cfg.Default, cfg.Meta)
}
return cfg
}

View File

@@ -0,0 +1,52 @@
package resolver
import (
"errors"
"net"
"strings"
)
func ClassifyDNSError(err error) string {
if err == nil {
return "other"
}
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) {
if dnsErr.IsNotFound {
return "nxdomain"
}
if dnsErr.IsTimeout {
return "timeout"
}
if dnsErr.IsTemporary {
return "temporary"
}
}
msg := strings.ToLower(err.Error())
switch {
case strings.Contains(msg, "no such host"), strings.Contains(msg, "nxdomain"):
return "nxdomain"
case strings.Contains(msg, "i/o timeout"), strings.Contains(msg, "timeout"):
return "timeout"
case strings.Contains(msg, "temporary"):
return "temporary"
default:
return "other"
}
}
func SplitDNS(dns string) (string, string) {
if strings.Contains(dns, "#") {
parts := strings.SplitN(dns, "#", 2)
host := strings.TrimSpace(parts[0])
port := strings.TrimSpace(parts[1])
if host == "" {
host = "127.0.0.1"
}
if port == "" {
port = "53"
}
return host, port
}
return strings.TrimSpace(dns), ""
}

View File

@@ -0,0 +1,158 @@
package resolver
import (
"fmt"
"sort"
"strings"
)
type DNSErrorKind string
const (
DNSErrorNXDomain DNSErrorKind = "nxdomain"
DNSErrorTimeout DNSErrorKind = "timeout"
DNSErrorTemporary DNSErrorKind = "temporary"
DNSErrorOther DNSErrorKind = "other"
)
type DNSUpstreamMetrics struct {
Attempts int
OK int
NXDomain int
Timeout int
Temporary int
Other int
Skipped int
}
type DNSMetrics struct {
Attempts int
OK int
NXDomain int
Timeout int
Temporary int
Other int
Skipped int
PerUpstream map[string]*DNSUpstreamMetrics
}
func (m *DNSMetrics) EnsureUpstream(upstream string) *DNSUpstreamMetrics {
if m.PerUpstream == nil {
m.PerUpstream = map[string]*DNSUpstreamMetrics{}
}
if us, ok := m.PerUpstream[upstream]; ok {
return us
}
us := &DNSUpstreamMetrics{}
m.PerUpstream[upstream] = us
return us
}
func (m *DNSMetrics) AddSuccess(upstream string) {
m.Attempts++
m.OK++
us := m.EnsureUpstream(upstream)
us.Attempts++
us.OK++
}
func (m *DNSMetrics) AddError(upstream string, kind DNSErrorKind) {
m.Attempts++
us := m.EnsureUpstream(upstream)
us.Attempts++
switch kind {
case DNSErrorNXDomain:
m.NXDomain++
us.NXDomain++
case DNSErrorTimeout:
m.Timeout++
us.Timeout++
case DNSErrorTemporary:
m.Temporary++
us.Temporary++
default:
m.Other++
us.Other++
}
}
func (m *DNSMetrics) AddCooldownSkip(upstream string) {
m.Skipped++
us := m.EnsureUpstream(upstream)
us.Skipped++
}
func (m *DNSMetrics) Merge(other DNSMetrics) {
m.Attempts += other.Attempts
m.OK += other.OK
m.NXDomain += other.NXDomain
m.Timeout += other.Timeout
m.Temporary += other.Temporary
m.Other += other.Other
m.Skipped += other.Skipped
for upstream, src := range other.PerUpstream {
dst := m.EnsureUpstream(upstream)
dst.Attempts += src.Attempts
dst.OK += src.OK
dst.NXDomain += src.NXDomain
dst.Timeout += src.Timeout
dst.Temporary += src.Temporary
dst.Other += src.Other
dst.Skipped += src.Skipped
}
}
func (m DNSMetrics) TotalErrors() int {
return m.NXDomain + m.Timeout + m.Temporary + m.Other
}
func (m DNSMetrics) FormatPerUpstream() string {
if len(m.PerUpstream) == 0 {
return ""
}
keys := make([]string, 0, len(m.PerUpstream))
for k := range m.PerUpstream {
keys = append(keys, k)
}
sort.Strings(keys)
parts := make([]string, 0, len(keys))
for _, k := range keys {
v := m.PerUpstream[k]
parts = append(parts, fmt.Sprintf("%s{attempts=%d ok=%d nxdomain=%d timeout=%d temporary=%d other=%d skipped=%d}", k, v.Attempts, v.OK, v.NXDomain, v.Timeout, v.Temporary, v.Other, v.Skipped))
}
return strings.Join(parts, "; ")
}
func (m DNSMetrics) FormatResolverHealth() string {
if len(m.PerUpstream) == 0 {
return ""
}
keys := make([]string, 0, len(m.PerUpstream))
for k := range m.PerUpstream {
keys = append(keys, k)
}
sort.Strings(keys)
parts := make([]string, 0, len(keys))
for _, k := range keys {
v := m.PerUpstream[k]
if v == nil || v.Attempts <= 0 {
continue
}
okRate := float64(v.OK) / float64(v.Attempts)
timeoutRate := float64(v.Timeout) / float64(v.Attempts)
score := okRate*100.0 - timeoutRate*50.0
state := "bad"
switch {
case score >= 70 && timeoutRate <= 0.05:
state = "good"
case score >= 35:
state = "degraded"
default:
state = "bad"
}
parts = append(parts, fmt.Sprintf("%s{score=%.1f state=%s attempts=%d ok=%d timeout=%d nxdomain=%d temporary=%d other=%d skipped=%d}", k, score, state, v.Attempts, v.OK, v.Timeout, v.NXDomain, v.Temporary, v.Other, v.Skipped))
}
return strings.Join(parts, "; ")
}

View File

@@ -0,0 +1,57 @@
package resolver
import "strings"
func BuildResolverFallbackPool(raw string, fallbackDefaults []string, normalize func(string) string) []string {
switch strings.ToLower(strings.TrimSpace(raw)) {
case "off", "none", "0":
return nil
}
candidates := fallbackDefaults
if strings.TrimSpace(raw) != "" {
candidates = nil
fields := strings.FieldsFunc(raw, func(r rune) bool {
return r == ',' || r == ';' || r == ' ' || r == '\n' || r == '\t'
})
for _, f := range fields {
if normalize == nil {
continue
}
if n := normalize(f); n != "" {
candidates = append(candidates, n)
}
}
}
return UniqueStrings(candidates)
}
func MergeDNSUpstreamPools(primary, fallback []string, maxUpstreams int, normalize func(string) string) []string {
if maxUpstreams < 1 {
maxUpstreams = 1
}
out := make([]string, 0, len(primary)+len(fallback))
seen := map[string]struct{}{}
add := func(items []string) {
for _, item := range items {
if len(out) >= maxUpstreams {
return
}
if normalize == nil {
continue
}
n := normalize(item)
if n == "" {
continue
}
if _, ok := seen[n]; ok {
continue
}
seen[n] = struct{}{}
out = append(out, n)
}
}
add(primary)
add(fallback)
return out
}

View File

@@ -0,0 +1,649 @@
package resolver
import (
"encoding/json"
"fmt"
"os"
"sort"
"strings"
)
type DomainCacheSource string
const (
DomainCacheSourceDirect DomainCacheSource = "direct"
DomainCacheSourceWildcard DomainCacheSource = "wildcard"
)
const (
DomainStateActive = "active"
DomainStateStable = "stable"
DomainStateSuspect = "suspect"
DomainStateQuarantine = "quarantine"
DomainStateHardQuar = "hard_quarantine"
DomainScoreMin = -100
DomainScoreMax = 100
DomainCacheVersion = 4
DefaultQuarantineTTL = 24 * 3600
DefaultHardQuarTTL = 7 * 24 * 3600
)
var EnvInt = func(key string, def int) int { return def }
var NXHardQuarantineEnabled = func() bool { return false }
type DomainCacheEntry struct {
IPs []string `json:"ips,omitempty"`
LastResolved int `json:"last_resolved,omitempty"`
LastErrorKind string `json:"last_error_kind,omitempty"`
LastErrorAt int `json:"last_error_at,omitempty"`
Score int `json:"score,omitempty"`
State string `json:"state,omitempty"`
QuarantineUntil int `json:"quarantine_until,omitempty"`
}
type DomainCacheRecord struct {
Direct *DomainCacheEntry `json:"direct,omitempty"`
Wildcard *DomainCacheEntry `json:"wildcard,omitempty"`
}
type DomainCacheState struct {
Version int `json:"version"`
Domains map[string]DomainCacheRecord `json:"domains"`
}
func NewDomainCacheState() DomainCacheState {
return DomainCacheState{
Version: DomainCacheVersion,
Domains: map[string]DomainCacheRecord{},
}
}
func NormalizeCacheIPs(raw []string) []string {
seen := map[string]struct{}{}
out := make([]string, 0, len(raw))
for _, ip := range raw {
ip = strings.TrimSpace(ip)
if ip == "" || IsPrivateIPv4(ip) {
continue
}
if _, ok := seen[ip]; ok {
continue
}
seen[ip] = struct{}{}
out = append(out, ip)
}
sort.Strings(out)
return out
}
func NormalizeCacheErrorKind(raw string) (DNSErrorKind, bool) {
switch strings.ToLower(strings.TrimSpace(raw)) {
case string(DNSErrorNXDomain):
return DNSErrorNXDomain, true
case string(DNSErrorTimeout):
return DNSErrorTimeout, true
case string(DNSErrorTemporary):
return DNSErrorTemporary, true
case string(DNSErrorOther):
return DNSErrorOther, true
default:
return "", false
}
}
func NormalizeDomainCacheEntry(in *DomainCacheEntry) *DomainCacheEntry {
if in == nil {
return nil
}
out := &DomainCacheEntry{}
ips := NormalizeCacheIPs(in.IPs)
if len(ips) > 0 && in.LastResolved > 0 {
out.IPs = ips
out.LastResolved = in.LastResolved
}
if kind, ok := NormalizeCacheErrorKind(in.LastErrorKind); ok && in.LastErrorAt > 0 {
out.LastErrorKind = string(kind)
out.LastErrorAt = in.LastErrorAt
}
out.Score = ClampDomainScore(in.Score)
if st := NormalizeDomainState(in.State, out.Score); st != "" {
out.State = st
}
if in.QuarantineUntil > 0 {
out.QuarantineUntil = in.QuarantineUntil
}
if out.LastResolved <= 0 && out.LastErrorAt <= 0 {
if out.Score == 0 && out.QuarantineUntil <= 0 {
return nil
}
}
return out
}
func parseAnyStringSlice(raw any) []string {
switch v := raw.(type) {
case []string:
return append([]string(nil), v...)
case []any:
out := make([]string, 0, len(v))
for _, x := range v {
if s, ok := x.(string); ok {
out = append(out, s)
}
}
return out
default:
return nil
}
}
func parseLegacyDomainCacheEntry(raw any) (DomainCacheEntry, bool) {
m, ok := raw.(map[string]any)
if !ok {
return DomainCacheEntry{}, false
}
ips := NormalizeCacheIPs(parseAnyStringSlice(m["ips"]))
if len(ips) == 0 {
return DomainCacheEntry{}, false
}
ts, ok := parseAnyInt(m["last_resolved"])
if !ok || ts <= 0 {
return DomainCacheEntry{}, false
}
return DomainCacheEntry{IPs: ips, LastResolved: ts}, true
}
func LoadDomainCacheState(path string, logf func(string, ...any)) DomainCacheState {
data, err := os.ReadFile(path)
if err != nil || len(data) == 0 {
return NewDomainCacheState()
}
var st DomainCacheState
if err := json.Unmarshal(data, &st); err == nil && st.Domains != nil {
if st.Version <= 0 {
st.Version = DomainCacheVersion
}
normalized := NewDomainCacheState()
for host, rec := range st.Domains {
host = strings.TrimSpace(strings.ToLower(host))
if host == "" {
continue
}
nrec := DomainCacheRecord{}
nrec.Direct = NormalizeDomainCacheEntry(rec.Direct)
nrec.Wildcard = NormalizeDomainCacheEntry(rec.Wildcard)
if nrec.Direct != nil || nrec.Wildcard != nil {
normalized.Domains[host] = nrec
}
}
return normalized
}
var legacy map[string]any
if err := json.Unmarshal(data, &legacy); err != nil {
if logf != nil {
logf("domain-cache: invalid json at %s, ignore", path)
}
return NewDomainCacheState()
}
out := NewDomainCacheState()
migrated := 0
for host, raw := range legacy {
host = strings.TrimSpace(strings.ToLower(host))
if host == "" || host == "version" || host == "domains" {
continue
}
entry, ok := parseLegacyDomainCacheEntry(raw)
if !ok {
continue
}
rec := out.Domains[host]
rec.Direct = &entry
out.Domains[host] = rec
migrated++
}
if logf != nil && migrated > 0 {
logf("domain-cache: migrated legacy entries=%d into split cache (direct bucket)", migrated)
}
return out
}
func (s DomainCacheState) Get(domain string, source DomainCacheSource, now, ttl int) ([]string, bool) {
rec, ok := s.Domains[strings.TrimSpace(strings.ToLower(domain))]
if !ok {
return nil, false
}
var entry *DomainCacheEntry
switch source {
case DomainCacheSourceWildcard:
entry = rec.Wildcard
default:
entry = rec.Direct
}
if entry == nil || entry.LastResolved <= 0 {
return nil, false
}
if now-entry.LastResolved > ttl {
return nil, false
}
ips := NormalizeCacheIPs(entry.IPs)
if len(ips) == 0 {
return nil, false
}
return ips, true
}
func (s DomainCacheState) GetNegative(domain string, source DomainCacheSource, now, nxTTL, timeoutTTL, temporaryTTL, otherTTL int) (DNSErrorKind, int, bool) {
rec, ok := s.Domains[strings.TrimSpace(strings.ToLower(domain))]
if !ok {
return "", 0, false
}
var entry *DomainCacheEntry
switch source {
case DomainCacheSourceWildcard:
entry = rec.Wildcard
default:
entry = rec.Direct
}
if entry == nil || entry.LastErrorAt <= 0 {
return "", 0, false
}
kind, ok := NormalizeCacheErrorKind(entry.LastErrorKind)
if !ok {
return "", 0, false
}
age := now - entry.LastErrorAt
if age < 0 {
return "", 0, false
}
cacheTTL := 0
switch kind {
case DNSErrorNXDomain:
cacheTTL = nxTTL
case DNSErrorTimeout:
cacheTTL = timeoutTTL
case DNSErrorTemporary:
cacheTTL = temporaryTTL
case DNSErrorOther:
cacheTTL = otherTTL
}
if cacheTTL <= 0 || age > cacheTTL {
return "", 0, false
}
return kind, age, true
}
func (s DomainCacheState) GetStoredIPs(domain string, source DomainCacheSource) []string {
rec, ok := s.Domains[strings.TrimSpace(strings.ToLower(domain))]
if !ok {
return nil
}
entry := GetCacheEntryBySource(rec, source)
if entry == nil {
return nil
}
return NormalizeCacheIPs(entry.IPs)
}
func (s DomainCacheState) GetLastErrorKind(domain string, source DomainCacheSource) (DNSErrorKind, bool) {
rec, ok := s.Domains[strings.TrimSpace(strings.ToLower(domain))]
if !ok {
return "", false
}
entry := GetCacheEntryBySource(rec, source)
if entry == nil || entry.LastErrorAt <= 0 {
return "", false
}
return NormalizeCacheErrorKind(entry.LastErrorKind)
}
func (s DomainCacheState) GetQuarantine(domain string, source DomainCacheSource, now int) (string, int, bool) {
rec, ok := s.Domains[strings.TrimSpace(strings.ToLower(domain))]
if !ok {
return "", 0, false
}
entry := GetCacheEntryBySource(rec, source)
if entry == nil || entry.QuarantineUntil <= 0 {
return "", 0, false
}
if now >= entry.QuarantineUntil {
return "", 0, false
}
state := NormalizeDomainState(entry.State, entry.Score)
if state == "" {
state = DomainStateQuarantine
}
age := 0
if entry.LastErrorAt > 0 {
age = now - entry.LastErrorAt
}
return state, age, true
}
func (s DomainCacheState) GetStale(domain string, source DomainCacheSource, now, maxAge int) ([]string, int, bool) {
if maxAge <= 0 {
return nil, 0, false
}
rec, ok := s.Domains[strings.TrimSpace(strings.ToLower(domain))]
if !ok {
return nil, 0, false
}
entry := GetCacheEntryBySource(rec, source)
if entry == nil || entry.LastResolved <= 0 {
return nil, 0, false
}
age := now - entry.LastResolved
if age < 0 || age > maxAge {
return nil, 0, false
}
ips := NormalizeCacheIPs(entry.IPs)
if len(ips) == 0 {
return nil, 0, false
}
return ips, age, true
}
func (s *DomainCacheState) Set(domain string, source DomainCacheSource, ips []string, now int) {
host := strings.TrimSpace(strings.ToLower(domain))
if host == "" || now <= 0 {
return
}
norm := NormalizeCacheIPs(ips)
if len(norm) == 0 {
return
}
if s.Domains == nil {
s.Domains = map[string]DomainCacheRecord{}
}
rec := s.Domains[host]
prev := GetCacheEntryBySource(rec, source)
prevScore := 0
if prev != nil {
prevScore = prev.Score
}
entry := &DomainCacheEntry{
IPs: norm,
LastResolved: now,
LastErrorKind: "",
LastErrorAt: 0,
Score: ClampDomainScore(prevScore + EnvInt("RESOLVE_DOMAIN_SCORE_OK", 8)),
QuarantineUntil: 0,
}
entry.State = DomainStateFromScore(entry.Score)
switch source {
case DomainCacheSourceWildcard:
rec.Wildcard = entry
default:
rec.Direct = entry
}
s.Domains[host] = rec
}
func GetCacheEntryBySource(rec DomainCacheRecord, source DomainCacheSource) *DomainCacheEntry {
switch source {
case DomainCacheSourceWildcard:
return rec.Wildcard
default:
return rec.Direct
}
}
func ClampDomainScore(v int) int {
if v < DomainScoreMin {
return DomainScoreMin
}
if v > DomainScoreMax {
return DomainScoreMax
}
return v
}
func DomainStateFromScore(score int) string {
switch {
case score >= 20:
return DomainStateActive
case score >= 5:
return DomainStateStable
case score >= -10:
return DomainStateSuspect
case score >= -30:
return DomainStateQuarantine
default:
return DomainStateHardQuar
}
}
func NormalizeDomainState(raw string, score int) string {
switch strings.TrimSpace(strings.ToLower(raw)) {
case DomainStateActive:
return DomainStateActive
case DomainStateStable:
return DomainStateStable
case DomainStateSuspect:
return DomainStateSuspect
case DomainStateQuarantine:
return DomainStateQuarantine
case DomainStateHardQuar:
return DomainStateHardQuar
default:
if score == 0 {
return ""
}
return DomainStateFromScore(score)
}
}
func DomainScorePenalty(stats DNSMetrics) int {
if stats.NXDomain >= 2 {
return EnvInt("RESOLVE_DOMAIN_SCORE_NX_CONFIRMED", -15)
}
if stats.NXDomain > 0 {
return EnvInt("RESOLVE_DOMAIN_SCORE_NX_SINGLE", -7)
}
if stats.Timeout > 0 {
return EnvInt("RESOLVE_DOMAIN_SCORE_TIMEOUT", -3)
}
if stats.Temporary > 0 {
return EnvInt("RESOLVE_DOMAIN_SCORE_TEMPORARY", -2)
}
return EnvInt("RESOLVE_DOMAIN_SCORE_OTHER", -2)
}
func classifyHostErrorKind(stats DNSMetrics) (DNSErrorKind, bool) {
if stats.Timeout > 0 {
return DNSErrorTimeout, true
}
if stats.Temporary > 0 {
return DNSErrorTemporary, true
}
if stats.Other > 0 {
return DNSErrorOther, true
}
if stats.NXDomain > 0 {
return DNSErrorNXDomain, true
}
return "", false
}
func (s *DomainCacheState) SetErrorWithStats(domain string, source DomainCacheSource, stats DNSMetrics, now int) {
host := strings.TrimSpace(strings.ToLower(domain))
if host == "" || now <= 0 {
return
}
kind, ok := classifyHostErrorKind(stats)
if !ok {
return
}
normKind, ok := NormalizeCacheErrorKind(string(kind))
if !ok {
return
}
penalty := DomainScorePenalty(stats)
quarantineTTL := EnvInt("RESOLVE_QUARANTINE_TTL_SEC", DefaultQuarantineTTL)
if quarantineTTL < 0 {
quarantineTTL = 0
}
hardQuarantineTTL := EnvInt("RESOLVE_HARD_QUARANTINE_TTL_SEC", DefaultHardQuarTTL)
if hardQuarantineTTL < 0 {
hardQuarantineTTL = 0
}
if s.Domains == nil {
s.Domains = map[string]DomainCacheRecord{}
}
rec := s.Domains[host]
entry := GetCacheEntryBySource(rec, source)
if entry == nil {
entry = &DomainCacheEntry{}
}
prevKind, _ := NormalizeCacheErrorKind(entry.LastErrorKind)
entry.Score = ClampDomainScore(entry.Score + penalty)
entry.State = DomainStateFromScore(entry.Score)
if normKind == DNSErrorTimeout && prevKind != DNSErrorNXDomain {
if entry.Score < -10 {
entry.Score = -10
}
entry.State = DomainStateSuspect
}
if normKind == DNSErrorNXDomain && !NXHardQuarantineEnabled() && entry.State == DomainStateHardQuar {
entry.State = DomainStateQuarantine
if entry.Score < -30 {
entry.Score = -30
}
}
entry.LastErrorKind = string(normKind)
entry.LastErrorAt = now
switch entry.State {
case DomainStateHardQuar:
entry.QuarantineUntil = now + hardQuarantineTTL
case DomainStateQuarantine:
entry.QuarantineUntil = now + quarantineTTL
default:
entry.QuarantineUntil = 0
}
switch source {
case DomainCacheSourceWildcard:
rec.Wildcard = entry
default:
rec.Direct = entry
}
s.Domains[host] = rec
}
func (s DomainCacheState) ToMap() map[string]any {
out := map[string]any{
"version": DomainCacheVersion,
"domains": map[string]any{},
}
domainsAny := out["domains"].(map[string]any)
hosts := make([]string, 0, len(s.Domains))
for host := range s.Domains {
hosts = append(hosts, host)
}
sort.Strings(hosts)
for _, host := range hosts {
rec := s.Domains[host]
recOut := map[string]any{}
if rec.Direct != nil {
directOut := map[string]any{}
if len(rec.Direct.IPs) > 0 && rec.Direct.LastResolved > 0 {
directOut["ips"] = rec.Direct.IPs
directOut["last_resolved"] = rec.Direct.LastResolved
}
if kind, ok := NormalizeCacheErrorKind(rec.Direct.LastErrorKind); ok && rec.Direct.LastErrorAt > 0 {
directOut["last_error_kind"] = string(kind)
directOut["last_error_at"] = rec.Direct.LastErrorAt
}
if rec.Direct.Score != 0 {
directOut["score"] = rec.Direct.Score
}
if st := NormalizeDomainState(rec.Direct.State, rec.Direct.Score); st != "" {
directOut["state"] = st
}
if rec.Direct.QuarantineUntil > 0 {
directOut["quarantine_until"] = rec.Direct.QuarantineUntil
}
if len(directOut) > 0 {
recOut["direct"] = directOut
}
}
if rec.Wildcard != nil {
wildOut := map[string]any{}
if len(rec.Wildcard.IPs) > 0 && rec.Wildcard.LastResolved > 0 {
wildOut["ips"] = rec.Wildcard.IPs
wildOut["last_resolved"] = rec.Wildcard.LastResolved
}
if kind, ok := NormalizeCacheErrorKind(rec.Wildcard.LastErrorKind); ok && rec.Wildcard.LastErrorAt > 0 {
wildOut["last_error_kind"] = string(kind)
wildOut["last_error_at"] = rec.Wildcard.LastErrorAt
}
if rec.Wildcard.Score != 0 {
wildOut["score"] = rec.Wildcard.Score
}
if st := NormalizeDomainState(rec.Wildcard.State, rec.Wildcard.Score); st != "" {
wildOut["state"] = st
}
if rec.Wildcard.QuarantineUntil > 0 {
wildOut["quarantine_until"] = rec.Wildcard.QuarantineUntil
}
if len(wildOut) > 0 {
recOut["wildcard"] = wildOut
}
}
if len(recOut) > 0 {
domainsAny[host] = recOut
}
}
return out
}
func (s DomainCacheState) FormatStateSummary(now int) string {
type counters struct {
active int
stable int
suspect int
quarantine int
hardQuar int
}
add := func(c *counters, entry *DomainCacheEntry) {
if entry == nil {
return
}
st := NormalizeDomainState(entry.State, entry.Score)
if entry.QuarantineUntil > now {
if st == DomainStateHardQuar {
c.hardQuar++
return
}
c.quarantine++
return
}
switch st {
case DomainStateActive:
c.active++
case DomainStateStable:
c.stable++
case DomainStateSuspect:
c.suspect++
case DomainStateQuarantine:
c.quarantine++
case DomainStateHardQuar:
c.hardQuar++
}
}
var c counters
for _, rec := range s.Domains {
add(&c, rec.Direct)
add(&c, rec.Wildcard)
}
total := c.active + c.stable + c.suspect + c.quarantine + c.hardQuar
if total == 0 {
return ""
}
return fmt.Sprintf(
"active=%d stable=%d suspect=%d quarantine=%d hard_quarantine=%d total=%d",
c.active, c.stable, c.suspect, c.quarantine, c.hardQuar, total,
)
}

View File

@@ -0,0 +1,53 @@
package resolver
import (
"os"
"strings"
)
func SmartDNSFallbackForTimeoutEnabled() bool {
switch strings.ToLower(strings.TrimSpace(os.Getenv("RESOLVE_SMARTDNS_TIMEOUT_FALLBACK"))) {
case "1", "true", "yes", "on":
return true
case "0", "false", "no", "off":
return false
default:
return false
}
}
func ShouldFallbackToSmartDNS(stats DNSMetrics) bool {
if stats.OK > 0 {
return false
}
if stats.NXDomain > 0 {
return false
}
if stats.Timeout > 0 || stats.Temporary > 0 {
return true
}
return stats.Other > 0
}
func ClassifyHostErrorKind(stats DNSMetrics) (DNSErrorKind, bool) {
if stats.Timeout > 0 {
return DNSErrorTimeout, true
}
if stats.Temporary > 0 {
return DNSErrorTemporary, true
}
if stats.Other > 0 {
return DNSErrorOther, true
}
if stats.NXDomain > 0 {
return DNSErrorNXDomain, true
}
return "", false
}
func ShouldUseStaleOnError(stats DNSMetrics) bool {
if stats.OK > 0 {
return false
}
return stats.Timeout > 0 || stats.Temporary > 0 || stats.Other > 0
}

View File

@@ -0,0 +1,234 @@
package resolver
import (
"context"
"net"
"time"
)
const (
dnsModeSmartDNS = "smartdns"
dnsModeHybridWildcard = "hybrid_wildcard"
)
type DNSAttemptPolicy struct {
TryLimit int
DomainBudget time.Duration
StopOnNX bool
}
type DNSCooldown interface {
ShouldSkip(upstream string, now int64) bool
ObserveSuccess(upstream string)
ObserveError(upstream string, kind DNSErrorKind, now int64) (bool, int)
}
func ResolveHost(
host string,
cfg DNSConfig,
metaSpecial []string,
isWildcard func(string) bool,
timeout time.Duration,
cooldown DNSCooldown,
directPolicyFor func(int) DNSAttemptPolicy,
wildcardPolicyFor func(int) DNSAttemptPolicy,
smartDNSFallbackEnabled bool,
logf func(string, ...any),
) ([]string, DNSMetrics) {
useMeta := false
for _, m := range metaSpecial {
if host == m {
useMeta = true
break
}
}
dnsList := cfg.Default
if useMeta {
dnsList = cfg.Meta
}
primaryViaSmartDNS := false
switch cfg.Mode {
case dnsModeSmartDNS:
if cfg.SmartDNS != "" {
dnsList = []string{cfg.SmartDNS}
primaryViaSmartDNS = true
}
case dnsModeHybridWildcard:
if cfg.SmartDNS != "" && isWildcard != nil && isWildcard(host) {
dnsList = []string{cfg.SmartDNS}
primaryViaSmartDNS = true
}
}
policy := safePolicy(directPolicyFor, len(dnsList), timeout)
if primaryViaSmartDNS {
policy = safePolicy(wildcardPolicyFor, len(dnsList), timeout)
}
ips, stats := DigAWithPolicy(host, dnsList, timeout, policy, cooldown, logf)
if len(ips) == 0 &&
!primaryViaSmartDNS &&
cfg.SmartDNS != "" &&
smartDNSFallbackEnabled &&
ShouldFallbackToSmartDNS(stats) {
if logf != nil {
logf(
"dns fallback %s: trying smartdns=%s after errors nxdomain=%d timeout=%d temporary=%d other=%d",
host,
cfg.SmartDNS,
stats.NXDomain,
stats.Timeout,
stats.Temporary,
stats.Other,
)
}
fallbackPolicy := safePolicy(wildcardPolicyFor, 1, timeout)
fallbackIPs, fallbackStats := DigAWithPolicy(host, []string{cfg.SmartDNS}, timeout, fallbackPolicy, cooldown, logf)
stats.Merge(fallbackStats)
if len(fallbackIPs) > 0 {
ips = fallbackIPs
if logf != nil {
logf("dns fallback %s: resolved via smartdns (%d ips)", host, len(fallbackIPs))
}
}
}
out := make([]string, 0, len(ips))
seen := map[string]struct{}{}
for _, ip := range ips {
if IsPrivateIPv4(ip) {
continue
}
if _, ok := seen[ip]; ok {
continue
}
seen[ip] = struct{}{}
out = append(out, ip)
}
return out, stats
}
func DigAWithPolicy(
host string,
dnsList []string,
timeout time.Duration,
policy DNSAttemptPolicy,
cooldown DNSCooldown,
logf func(string, ...any),
) ([]string, DNSMetrics) {
stats := DNSMetrics{}
if len(dnsList) == 0 {
return nil, stats
}
tryLimit := policy.TryLimit
if tryLimit <= 0 {
tryLimit = 1
}
if tryLimit > len(dnsList) {
tryLimit = len(dnsList)
}
budget := policy.DomainBudget
if budget <= 0 {
budget = time.Duration(tryLimit) * timeout
}
if budget < 200*time.Millisecond {
budget = 200 * time.Millisecond
}
deadline := time.Now().Add(budget)
start := PickDNSStartIndex(host, len(dnsList))
for attempt := 0; attempt < tryLimit; attempt++ {
remaining := time.Until(deadline)
if remaining <= 0 {
if logf != nil {
logf("dns budget exhausted %s: attempts=%d budget_ms=%d", host, attempt, budget.Milliseconds())
}
break
}
entry := dnsList[(start+attempt)%len(dnsList)]
server, port := SplitDNS(entry)
if server == "" {
continue
}
if port == "" {
port = "53"
}
addr := net.JoinHostPort(server, port)
if cooldown != nil && cooldown.ShouldSkip(addr, time.Now().Unix()) {
stats.AddCooldownSkip(addr)
continue
}
r := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{}
return d.DialContext(ctx, "udp", addr)
},
}
perAttemptTimeout := timeout
if remaining < perAttemptTimeout {
perAttemptTimeout = remaining
}
if perAttemptTimeout < 100*time.Millisecond {
perAttemptTimeout = 100 * time.Millisecond
}
ctx, cancel := context.WithTimeout(context.Background(), perAttemptTimeout)
records, err := r.LookupHost(ctx, host)
cancel()
if err != nil {
kindRaw := ClassifyDNSError(err)
kind, ok := NormalizeCacheErrorKind(kindRaw)
if !ok {
kind = DNSErrorOther
}
stats.AddError(addr, kind)
if cooldown != nil {
if banned, banSec := cooldown.ObserveError(addr, kind, time.Now().Unix()); banned && logf != nil {
logf("dns cooldown ban %s: timeout-like failures; ban_sec=%d", addr, banSec)
}
}
if logf != nil {
logf("dns warn %s via %s: kind=%s attempt=%d/%d err=%v", host, addr, kind, attempt+1, tryLimit, err)
}
if policy.StopOnNX && kind == DNSErrorNXDomain {
if logf != nil {
logf("dns early-stop %s: nxdomain via %s (attempt=%d/%d)", host, addr, attempt+1, tryLimit)
}
break
}
continue
}
var ips []string
for _, ip := range records {
if IsPrivateIPv4(ip) {
continue
}
ips = append(ips, ip)
}
if len(ips) == 0 {
stats.AddError(addr, DNSErrorOther)
if cooldown != nil {
_, _ = cooldown.ObserveError(addr, DNSErrorOther, time.Now().Unix())
}
if logf != nil {
logf("dns warn %s via %s: kind=other err=no_public_ips", host, addr)
}
continue
}
stats.AddSuccess(addr)
if cooldown != nil {
cooldown.ObserveSuccess(addr)
}
return UniqueStrings(ips), stats
}
return nil, stats
}
func safePolicy(factory func(int) DNSAttemptPolicy, count int, timeout time.Duration) DNSAttemptPolicy {
if factory != nil {
return factory(count)
}
return DNSAttemptPolicy{
TryLimit: 1,
DomainBudget: timeout,
StopOnNX: true,
}
}

View File

@@ -0,0 +1,135 @@
package resolver
import (
"encoding/json"
"os"
"strconv"
"strings"
)
func ReadLinesAllowMissing(path string) []string {
data, err := os.ReadFile(path)
if err != nil {
return nil
}
return strings.Split(strings.ReplaceAll(string(data), "\r\n", "\n"), "\n")
}
func LoadJSONMap(path string) map[string]any {
data, err := os.ReadFile(path)
if err != nil {
return map[string]any{}
}
var out map[string]any
if err := json.Unmarshal(data, &out); err != nil {
return map[string]any{}
}
return out
}
func SaveJSON(data any, path string) {
tmp := path + ".tmp"
b, err := json.MarshalIndent(data, "", " ")
if err != nil {
return
}
_ = os.WriteFile(tmp, b, 0o644)
_ = os.Rename(tmp, path)
}
func parseAnyInt(raw any) (int, bool) {
switch v := raw.(type) {
case int:
return v, true
case int64:
return int(v), true
case float64:
return int(v), true
case json.Number:
n, err := v.Int64()
if err != nil {
return 0, false
}
return int(n), true
case string:
s := strings.TrimSpace(v)
if s == "" {
return 0, false
}
n, err := strconv.Atoi(s)
if err != nil {
return 0, false
}
return n, true
default:
return 0, false
}
}
func LoadResolverPrecheckLastRun(path string) int {
m := LoadJSONMap(path)
if len(m) == 0 {
return 0
}
v, ok := parseAnyInt(m["last_run"])
if !ok || v <= 0 {
return 0
}
return v
}
func LoadResolverLiveBatchTarget(path string, fallback, minV, maxV int) int {
if fallback < minV {
fallback = minV
}
if fallback > maxV {
fallback = maxV
}
m := LoadJSONMap(path)
if len(m) == 0 {
return fallback
}
raw := m["live_batch_next_target"]
if raw == nil {
raw = m["live_batch_target"]
}
v, ok := parseAnyInt(raw)
if !ok || v <= 0 {
return fallback
}
if v < minV {
v = minV
}
if v > maxV {
v = maxV
}
return v
}
func LoadResolverLiveBatchNXHeavyPct(path string, fallback, minV, maxV int) int {
if fallback < minV {
fallback = minV
}
if fallback > maxV {
fallback = maxV
}
m := LoadJSONMap(path)
if len(m) == 0 {
return fallback
}
raw := m["live_batch_nxheavy_next_pct"]
if raw == nil {
raw = m["live_batch_nxheavy_pct"]
}
v, ok := parseAnyInt(raw)
if !ok {
return fallback
}
if v < minV {
v = minV
}
if v > maxV {
v = maxV
}
return v
}

View File

@@ -0,0 +1,113 @@
package resolver
func ComputeNextLiveBatchTarget(current, minV, maxV int, dnsStats DNSMetrics, deferred int) (int, string) {
if current < minV {
current = minV
}
if current > maxV {
current = maxV
}
next := current
reason := "stable"
attempts := dnsStats.Attempts
timeoutRate := 0.0
if attempts > 0 {
timeoutRate = float64(dnsStats.Timeout) / float64(attempts)
}
switch {
case attempts == 0:
reason = "no_dns_attempts"
case dnsStats.Skipped > 0 || timeoutRate >= 0.15:
next = int(float64(current) * 0.75)
reason = "timeout_high_or_cooldown"
case timeoutRate >= 0.08:
next = int(float64(current) * 0.90)
reason = "timeout_medium"
case timeoutRate <= 0.03 && deferred > 0:
next = int(float64(current) * 1.15)
reason = "timeout_low_expand"
case timeoutRate <= 0.03:
next = int(float64(current) * 1.10)
reason = "timeout_low"
}
if next < minV {
next = minV
}
if next > maxV {
next = maxV
}
if next == current && reason == "timeout_low" {
reason = "stable"
}
return next, reason
}
func ComputeNextLiveBatchNXHeavyPct(
current, minV, maxV int,
dnsStats DNSMetrics,
resolvedNowDNS int,
selectedP3 int,
nxTotal int,
liveNXHeavySkip int,
) (int, string) {
if current < minV {
current = minV
}
if current > maxV {
current = maxV
}
next := current
reason := "stable"
attempts := dnsStats.Attempts
timeoutRate := 0.0
okRate := 0.0
nxRate := 0.0
if attempts > 0 {
timeoutRate = float64(dnsStats.Timeout) / float64(attempts)
okRate = float64(dnsStats.OK) / float64(attempts)
nxRate = float64(dnsStats.NXDomain) / float64(attempts)
}
nxSelectedRatio := 0.0
if nxTotal > 0 {
nxSelectedRatio = float64(selectedP3) / float64(nxTotal)
}
switch {
case attempts == 0:
reason = "no_dns_attempts"
case timeoutRate >= 0.20 || dnsStats.Skipped > 0:
next = current - 3
reason = "timeout_very_high_or_cooldown"
case timeoutRate >= 0.12:
next = current - 2
reason = "timeout_high"
case dnsStats.OK == 0 && dnsStats.NXDomain > 0:
next = current - 2
reason = "no_success_nx_only"
case nxRate >= 0.90 && resolvedNowDNS == 0:
next = current - 2
reason = "nx_dominant_no_resolve"
case nxSelectedRatio >= 0.95 && resolvedNowDNS == 0:
next = current - 1
reason = "nxheavy_selected_dominant"
case timeoutRate <= 0.02 && okRate >= 0.10 && liveNXHeavySkip > 0:
next = current + 2
reason = "healthy_fast_reintroduce_nxheavy"
case timeoutRate <= 0.04 && resolvedNowDNS > 0 && liveNXHeavySkip > 0:
next = current + 1
reason = "healthy_reintroduce_nxheavy"
}
if next < minV {
next = minV
}
if next > maxV {
next = maxV
}
if next == current && reason != "no_dns_attempts" {
reason = "stable"
}
return next, reason
}

View File

@@ -0,0 +1,161 @@
package resolver
import "strings"
func ClassifyLiveBatchHost(
host string,
cache DomainCacheState,
cacheSourceForHost func(string) DomainCacheSource,
wildcards WildcardMatcher,
) (priority int, nxHeavy bool) {
h := strings.TrimSpace(strings.ToLower(host))
if h == "" {
return 2, false
}
if wildcards.IsExact(h) {
return 1, false
}
source := cacheSourceForHost(h)
rec, ok := cache.Domains[h]
if !ok {
return 2, false
}
entry := GetCacheEntryBySource(rec, source)
if entry == nil {
return 2, false
}
stored := NormalizeCacheIPs(entry.IPs)
state := NormalizeDomainState(entry.State, entry.Score)
errKind, hasErr := NormalizeCacheErrorKind(entry.LastErrorKind)
nxHeavy = hasErr && errKind == DNSErrorNXDomain && (state == DomainStateQuarantine || state == DomainStateHardQuar || entry.Score <= -10)
switch {
case len(stored) > 0:
return 1, false
case state == DomainStateActive || state == DomainStateStable || state == DomainStateSuspect:
return 1, false
case nxHeavy:
return 3, true
default:
return 2, false
}
}
func SplitLiveBatchCandidates(
candidates []string,
cache DomainCacheState,
cacheSourceForHost func(string) DomainCacheSource,
wildcards WildcardMatcher,
) (p1, p2, p3 []string, nxHeavyTotal int) {
for _, host := range candidates {
h := strings.TrimSpace(strings.ToLower(host))
if h == "" {
continue
}
prio, nxHeavy := ClassifyLiveBatchHost(h, cache, cacheSourceForHost, wildcards)
switch prio {
case 1:
p1 = append(p1, h)
case 3:
nxHeavyTotal++
p3 = append(p3, h)
case 2:
p2 = append(p2, h)
default:
if nxHeavy {
nxHeavyTotal++
p3 = append(p3, h)
} else {
p2 = append(p2, h)
}
}
}
return p1, p2, p3, nxHeavyTotal
}
func PickAdaptiveLiveBatch(
candidates []string,
target int,
now int,
nxHeavyPct int,
cache DomainCacheState,
cacheSourceForHost func(string) DomainCacheSource,
wildcards WildcardMatcher,
) ([]string, int, int, int, int, int) {
if len(candidates) == 0 {
return nil, 0, 0, 0, 0, 0
}
if target <= 0 {
p1, p2, p3, nxTotal := SplitLiveBatchCandidates(candidates, cache, cacheSourceForHost, wildcards)
return append([]string(nil), candidates...), len(p1), len(p2), len(p3), nxTotal, 0
}
if target > len(candidates) {
target = len(candidates)
}
if nxHeavyPct < 0 {
nxHeavyPct = 0
}
if nxHeavyPct > 100 {
nxHeavyPct = 100
}
start := now % len(candidates)
if start < 0 {
start = 0
}
rotated := make([]string, 0, len(candidates))
for i := 0; i < len(candidates); i++ {
idx := (start + i) % len(candidates)
rotated = append(rotated, candidates[idx])
}
p1, p2, p3, nxTotal := SplitLiveBatchCandidates(rotated, cache, cacheSourceForHost, wildcards)
out := make([]string, 0, target)
selectedP1 := 0
selectedP2 := 0
selectedP3 := 0
take := func(src []string, n int) ([]string, int) {
if n <= 0 || len(src) == 0 {
return src, 0
}
if n > len(src) {
n = len(src)
}
out = append(out, src[:n]...)
return src[n:], n
}
remain := target
var took int
p1, took = take(p1, remain)
selectedP1 += took
remain = target - len(out)
p2, took = take(p2, remain)
selectedP2 += took
remain = target - len(out)
p3Cap := (target * nxHeavyPct) / 100
if nxHeavyPct > 0 && p3Cap == 0 {
p3Cap = 1
}
if len(out) == 0 && len(p3) > 0 && p3Cap == 0 {
p3Cap = 1
}
if p3Cap > remain {
p3Cap = remain
}
p3, took = take(p3, p3Cap)
selectedP3 += took
if len(out) == 0 && len(p3) > 0 && target > 0 {
remain = target - len(out)
p3, took = take(p3, remain)
selectedP3 += took
}
nxSkipped := nxTotal - selectedP3
if nxSkipped < 0 {
nxSkipped = 0
}
return out, selectedP1, selectedP2, selectedP3, nxTotal, nxSkipped
}

View File

@@ -0,0 +1,62 @@
package resolver
import "strings"
type DNSModeRuntimeInput struct {
Config DNSConfig
Mode string
ViaSmartDNS bool
SmartDNSAddr string
SmartDNSForced bool
SmartDNSDefault string
NormalizeMode func(mode string, viaSmartDNS bool) string
NormalizeSmartDNSAddr func(raw string) string
}
func ApplyDNSModeRuntime(in DNSModeRuntimeInput) DNSConfig {
cfg := DNSConfig{
Default: append([]string(nil), in.Config.Default...),
Meta: append([]string(nil), in.Config.Meta...),
SmartDNS: strings.TrimSpace(in.Config.SmartDNS),
Mode: strings.TrimSpace(in.Config.Mode),
}
if !in.SmartDNSForced && in.NormalizeMode != nil {
if mode := strings.TrimSpace(in.NormalizeMode(in.Mode, in.ViaSmartDNS)); mode != "" {
cfg.Mode = mode
}
}
if in.NormalizeSmartDNSAddr != nil {
if addr := strings.TrimSpace(in.NormalizeSmartDNSAddr(in.SmartDNSAddr)); addr != "" {
cfg.SmartDNS = addr
}
} else if addr := strings.TrimSpace(in.SmartDNSAddr); addr != "" {
cfg.SmartDNS = addr
}
if cfg.SmartDNS == "" {
cfg.SmartDNS = strings.TrimSpace(in.SmartDNSDefault)
}
if cfg.Mode == "smartdns" && cfg.SmartDNS != "" {
cfg.Default = []string{cfg.SmartDNS}
cfg.Meta = []string{cfg.SmartDNS}
}
return cfg
}
func LogDNSMode(cfg DNSConfig, wildcardCount int, logf func(string, ...any)) {
if logf == nil {
return
}
switch cfg.Mode {
case "smartdns":
logf("resolver dns mode: SmartDNS-only (%s)", cfg.SmartDNS)
case "hybrid_wildcard":
logf("resolver dns mode: hybrid_wildcard smartdns=%s wildcards=%d default=%v meta=%v", cfg.SmartDNS, wildcardCount, cfg.Default, cfg.Meta)
default:
logf("resolver dns mode: direct default=%v meta=%v", cfg.Default, cfg.Meta)
}
}

View File

@@ -0,0 +1,119 @@
package resolver
type ResolvePlanningInput struct {
Domains []string
Now int
TTL int
PrecheckDue bool
PrecheckMaxDomains int
StaleKeepSec int
NegTTLNX int
NegTTLTimeout int
NegTTLTemporary int
NegTTLOther int
}
type ResolvePlanningResult struct {
Fresh map[string][]string
ToResolve []string
CacheNegativeHits int
QuarantineHits int
StaleHits int
PrecheckScheduled int
}
func BuildResolvePlanning(
in ResolvePlanningInput,
domainCache *DomainCacheState,
cacheSourceForHost func(string) DomainCacheSource,
logf func(string, ...any),
) ResolvePlanningResult {
result := ResolvePlanningResult{
Fresh: map[string][]string{},
}
if domainCache == nil {
result.ToResolve = append(result.ToResolve, in.Domains...)
return result
}
resolveSource := cacheSourceForHost
if resolveSource == nil {
resolveSource = func(string) DomainCacheSource { return DomainCacheSourceDirect }
}
for _, d := range in.Domains {
source := resolveSource(d)
if ips, ok := domainCache.Get(d, source, in.Now, in.TTL); ok {
result.Fresh[d] = ips
if logf != nil {
logf("cache hit[%s]: %s -> %v", source, d, ips)
}
continue
}
// Quarantine has priority over negative TTL cache so 24h quarantine
// is not silently overridden by shorter negative cache windows.
if state, age, ok := domainCache.GetQuarantine(d, source, in.Now); ok {
kind, hasKind := domainCache.GetLastErrorKind(d, source)
timeoutKind := hasKind && kind == DNSErrorTimeout
if in.PrecheckDue && result.PrecheckScheduled < in.PrecheckMaxDomains {
// Timeout-based quarantine is rechecked in background batch and should
// not flood trace with per-domain debug lines.
if timeoutKind {
result.QuarantineHits++
if in.StaleKeepSec > 0 {
if staleIPs, staleAge, ok := domainCache.GetStale(d, source, in.Now, in.StaleKeepSec); ok {
result.StaleHits++
result.Fresh[d] = staleIPs
if logf != nil {
logf("cache stale-keep (quarantine)[age=%ds]: %s -> %v", staleAge, d, staleIPs)
}
}
}
continue
}
result.PrecheckScheduled++
result.ToResolve = append(result.ToResolve, d)
if logf != nil {
logf("precheck schedule[quarantine/%s age=%ds]: %s (%s)", state, age, d, source)
}
continue
}
result.QuarantineHits++
if logf != nil {
logf("cache quarantine hit[%s age=%ds]: %s (%s)", state, age, d, source)
}
if in.StaleKeepSec > 0 {
if staleIPs, staleAge, ok := domainCache.GetStale(d, source, in.Now, in.StaleKeepSec); ok {
result.StaleHits++
result.Fresh[d] = staleIPs
if logf != nil {
logf("cache stale-keep (quarantine)[age=%ds]: %s -> %v", staleAge, d, staleIPs)
}
}
}
continue
}
if kind, age, ok := domainCache.GetNegative(d, source, in.Now, in.NegTTLNX, in.NegTTLTimeout, in.NegTTLTemporary, in.NegTTLOther); ok {
if in.PrecheckDue && result.PrecheckScheduled < in.PrecheckMaxDomains {
if kind == DNSErrorTimeout {
result.CacheNegativeHits++
continue
}
result.PrecheckScheduled++
result.ToResolve = append(result.ToResolve, d)
if logf != nil {
logf("precheck schedule[negative/%s age=%ds]: %s (%s)", kind, age, d, source)
}
continue
}
result.CacheNegativeHits++
if logf != nil {
logf("cache neg hit[%s/%s age=%ds]: %s", source, kind, age, d)
}
continue
}
result.ToResolve = append(result.ToResolve, d)
}
return result
}

View File

@@ -0,0 +1,106 @@
package resolver
import (
"os"
"strings"
)
type ResolverPrecheckFinalizeInput struct {
PrecheckDue bool
PrecheckStatePath string
Now int
TimeoutRecheck ResolverTimeoutRecheckStats
LiveBatchTarget int
LiveBatchMin int
LiveBatchMax int
LiveBatchNXHeavyPct int
LiveBatchNXHeavyMin int
LiveBatchNXHeavyMax int
DNSStats DNSMetrics
LiveDeferred int
ResolvedNowDNS int
LiveP1 int
LiveP2 int
LiveP3 int
LiveNXHeavyTotal int
LiveNXHeavySkip int
ToResolveTotal int
PrecheckFileForced bool
PrecheckForcePath string
}
type ResolverPrecheckFinalizeResult struct {
NextTarget int
NextReason string
NextNXPct int
NextNXReason string
Saved bool
ForceFileConsumed bool
}
func FinalizeResolverPrecheck(in ResolverPrecheckFinalizeInput, logf func(string, ...any)) ResolverPrecheckFinalizeResult {
out := ResolverPrecheckFinalizeResult{}
if in.PrecheckDue {
nextTarget, nextReason := ComputeNextLiveBatchTarget(in.LiveBatchTarget, in.LiveBatchMin, in.LiveBatchMax, in.DNSStats, in.LiveDeferred)
nextNXPct, nextNXReason := ComputeNextLiveBatchNXHeavyPct(
in.LiveBatchNXHeavyPct,
in.LiveBatchNXHeavyMin,
in.LiveBatchNXHeavyMax,
in.DNSStats,
in.ResolvedNowDNS,
in.LiveP3,
in.LiveNXHeavyTotal,
in.LiveNXHeavySkip,
)
if logf != nil {
logf(
"resolve live-batch nxheavy: pct=%d next=%d reason=%s selected=%d total=%d skipped=%d",
in.LiveBatchNXHeavyPct,
nextNXPct,
nextNXReason,
in.LiveP3,
in.LiveNXHeavyTotal,
in.LiveNXHeavySkip,
)
}
SaveResolverPrecheckState(
in.PrecheckStatePath,
in.Now,
in.TimeoutRecheck,
ResolverLiveBatchStats{
Target: in.LiveBatchTarget,
Total: in.ToResolveTotal,
Deferred: in.LiveDeferred,
P1: in.LiveP1,
P2: in.LiveP2,
P3: in.LiveP3,
NXHeavyPct: in.LiveBatchNXHeavyPct,
NXHeavyTotal: in.LiveNXHeavyTotal,
NXHeavySkip: in.LiveNXHeavySkip,
NextTarget: nextTarget,
NextReason: nextReason,
NextNXPct: nextNXPct,
NextNXReason: nextNXReason,
DNSAttempts: in.DNSStats.Attempts,
DNSTimeout: in.DNSStats.Timeout,
DNSCoolSkips: in.DNSStats.Skipped,
},
)
out.NextTarget = nextTarget
out.NextReason = nextReason
out.NextNXPct = nextNXPct
out.NextNXReason = nextNXReason
out.Saved = true
}
if in.PrecheckFileForced && strings.TrimSpace(in.PrecheckForcePath) != "" {
_ = os.Remove(in.PrecheckForcePath)
if logf != nil {
logf("resolve precheck force-file consumed: %s", in.PrecheckForcePath)
}
out.ForceFileConsumed = true
}
return out
}

View File

@@ -0,0 +1,39 @@
package resolver
func SaveResolverPrecheckState(path string, ts int, timeoutStats ResolverTimeoutRecheckStats, live ResolverLiveBatchStats) {
if path == "" || ts <= 0 {
return
}
state := LoadJSONMap(path)
if state == nil {
state = map[string]any{}
}
state["last_run"] = ts
state["timeout_recheck"] = map[string]any{
"checked": timeoutStats.Checked,
"recovered": timeoutStats.Recovered,
"recovered_ips": timeoutStats.RecoveredIPs,
"still_timeout": timeoutStats.StillTimeout,
"now_nxdomain": timeoutStats.NowNXDomain,
"now_temporary": timeoutStats.NowTemporary,
"now_other": timeoutStats.NowOther,
"no_signal": timeoutStats.NoSignal,
}
state["live_batch_target"] = live.Target
state["live_batch_total"] = live.Total
state["live_batch_deferred"] = live.Deferred
state["live_batch_p1"] = live.P1
state["live_batch_p2"] = live.P2
state["live_batch_p3"] = live.P3
state["live_batch_nxheavy_pct"] = live.NXHeavyPct
state["live_batch_nxheavy_total"] = live.NXHeavyTotal
state["live_batch_nxheavy_skip"] = live.NXHeavySkip
state["live_batch_nxheavy_next_pct"] = live.NextNXPct
state["live_batch_nxheavy_next_reason"] = live.NextNXReason
state["live_batch_next_target"] = live.NextTarget
state["live_batch_next_reason"] = live.NextReason
state["live_batch_dns_attempts"] = live.DNSAttempts
state["live_batch_dns_timeout"] = live.DNSTimeout
state["live_batch_dns_cooldown_skips"] = live.DNSCoolSkips
SaveJSON(state, path)
}

View File

@@ -0,0 +1,31 @@
package resolver
type ResolverTimeoutRecheckStats struct {
Checked int
Recovered int
RecoveredIPs int
StillTimeout int
NowNXDomain int
NowTemporary int
NowOther int
NoSignal int
}
type ResolverLiveBatchStats struct {
Target int
Total int
Deferred int
P1 int
P2 int
P3 int
NXHeavyPct int
NXHeavyTotal int
NXHeavySkip int
NextTarget int
NextReason string
NextNXPct int
NextNXReason string
DNSAttempts int
DNSTimeout int
DNSCoolSkips int
}

View File

@@ -0,0 +1,115 @@
package resolver
type ResolveBatchInput struct {
ToResolve []string
Workers int
Now int
StaleKeepSec int
}
type ResolveBatchResult struct {
DNSStats DNSMetrics
ResolvedNowDNS int
ResolvedNowStale int
UnresolvedAfterAttempts int
StaleHitsDelta int
}
func ExecuteResolveBatch(
in ResolveBatchInput,
resolved map[string][]string,
domainCache *DomainCacheState,
cacheSourceForHost func(string) DomainCacheSource,
resolveHost func(string) ([]string, DNSMetrics),
logf func(string, ...any),
) ResolveBatchResult {
out := ResolveBatchResult{}
if len(in.ToResolve) == 0 || resolveHost == nil || domainCache == nil {
return out
}
workers := in.Workers
if workers < 1 {
workers = 1
}
if workers > 500 {
workers = 500
}
resolveSource := cacheSourceForHost
if resolveSource == nil {
resolveSource = func(string) DomainCacheSource { return DomainCacheSourceDirect }
}
type result struct {
host string
ips []string
stats DNSMetrics
}
jobs := make(chan string, len(in.ToResolve))
results := make(chan result, len(in.ToResolve))
for i := 0; i < workers; i++ {
go func() {
for host := range jobs {
ips, stats := resolveHost(host)
results <- result{host: host, ips: ips, stats: stats}
}
}()
}
for _, host := range in.ToResolve {
jobs <- host
}
close(jobs)
for i := 0; i < len(in.ToResolve); i++ {
r := <-results
out.DNSStats.Merge(r.stats)
hostErrors := r.stats.TotalErrors()
if hostErrors > 0 && logf != nil {
logf("resolve errors for %s: total=%d nxdomain=%d timeout=%d temporary=%d other=%d", r.host, hostErrors, r.stats.NXDomain, r.stats.Timeout, r.stats.Temporary, r.stats.Other)
}
if len(r.ips) > 0 {
if resolved != nil {
resolved[r.host] = r.ips
}
out.ResolvedNowDNS++
source := resolveSource(r.host)
domainCache.Set(r.host, source, r.ips, in.Now)
if logf != nil {
logf("%s -> %v", r.host, r.ips)
}
continue
}
staleApplied := false
if hostErrors > 0 {
source := resolveSource(r.host)
domainCache.SetErrorWithStats(r.host, source, r.stats, in.Now)
if in.StaleKeepSec > 0 && ShouldUseStaleOnError(r.stats) {
if staleIPs, staleAge, ok := domainCache.GetStale(r.host, source, in.Now, in.StaleKeepSec); ok {
out.StaleHitsDelta++
out.ResolvedNowStale++
staleApplied = true
if resolved != nil {
resolved[r.host] = staleIPs
}
if logf != nil {
logf("cache stale-keep (error)[age=%ds]: %s -> %v", staleAge, r.host, staleIPs)
}
}
}
}
if !staleApplied {
out.UnresolvedAfterAttempts++
}
if logf != nil && resolved != nil {
if _, ok := resolved[r.host]; !ok {
logf("%s: no IPs", r.host)
}
}
}
return out
}

View File

@@ -0,0 +1,205 @@
package resolver
import "time"
type ResolverRuntimeTuningInput struct {
TTL int
Workers int
Now int
PrecheckStatePath string
PrecheckEnvForced bool
PrecheckFileForced bool
}
type ResolverRuntimeTuning struct {
TTL int
Workers int
DNSTimeoutMS int
DNSTimeout time.Duration
PrecheckEverySec int
PrecheckMaxDomains int
TimeoutRecheckMax int
LiveBatchMin int
LiveBatchMax int
LiveBatchTarget int
LiveBatchNXHeavyMin int
LiveBatchNXHeavyMax int
LiveBatchNXHeavyPct int
PrecheckDue bool
StaleKeepSec int
NegTTLNX int
NegTTLTimeout int
NegTTLTemporary int
NegTTLOther int
}
type ResolverRuntimeTuningDeps struct {
EnvInt func(string, int) int
LoadResolverPrecheckLastRun func(path string) int
LoadResolverLiveBatchTarget func(path string, fallback, minV, maxV int) int
LoadResolverLiveBatchNXHeavyPct func(path string, fallback, minV, maxV int) int
}
func BuildResolverRuntimeTuning(in ResolverRuntimeTuningInput, deps ResolverRuntimeTuningDeps) ResolverRuntimeTuning {
envInt := deps.EnvInt
if envInt == nil {
envInt = func(_ string, def int) int { return def }
}
ttl := in.TTL
if ttl <= 0 {
ttl = 24 * 3600
}
if ttl < 60 {
ttl = 60
}
if ttl > 24*3600 {
ttl = 24 * 3600
}
workers := in.Workers
if workers <= 0 {
workers = 80
}
if workers < 1 {
workers = 1
}
if workers > 500 {
workers = 500
}
dnsTimeoutMs := envInt("RESOLVE_DNS_TIMEOUT_MS", 1800)
if dnsTimeoutMs < 300 {
dnsTimeoutMs = 300
}
if dnsTimeoutMs > 5000 {
dnsTimeoutMs = 5000
}
precheckEverySec := envInt("RESOLVE_PRECHECK_EVERY_SEC", 24*3600)
if precheckEverySec < 0 {
precheckEverySec = 0
}
precheckMaxDomains := envInt("RESOLVE_PRECHECK_MAX_DOMAINS", 3000)
if precheckMaxDomains < 0 {
precheckMaxDomains = 0
}
if precheckMaxDomains > 50000 {
precheckMaxDomains = 50000
}
timeoutRecheckMax := envInt("RESOLVE_TIMEOUT_RECHECK_MAX", precheckMaxDomains)
if timeoutRecheckMax < 0 {
timeoutRecheckMax = 0
}
if timeoutRecheckMax > 50000 {
timeoutRecheckMax = 50000
}
liveBatchMin := envInt("RESOLVE_LIVE_BATCH_MIN", 800)
liveBatchMax := envInt("RESOLVE_LIVE_BATCH_MAX", 3000)
liveBatchDefault := envInt("RESOLVE_LIVE_BATCH_DEFAULT", 1800)
if liveBatchMin < 200 {
liveBatchMin = 200
}
if liveBatchMin > 50000 {
liveBatchMin = 50000
}
if liveBatchMax < liveBatchMin {
liveBatchMax = liveBatchMin
}
if liveBatchMax > 50000 {
liveBatchMax = 50000
}
if liveBatchDefault < liveBatchMin {
liveBatchDefault = liveBatchMin
}
if liveBatchDefault > liveBatchMax {
liveBatchDefault = liveBatchMax
}
liveBatchTarget := liveBatchDefault
if deps.LoadResolverLiveBatchTarget != nil {
liveBatchTarget = deps.LoadResolverLiveBatchTarget(in.PrecheckStatePath, liveBatchDefault, liveBatchMin, liveBatchMax)
}
liveBatchNXHeavyMin := envInt("RESOLVE_LIVE_BATCH_NX_HEAVY_MIN_PCT", 5)
liveBatchNXHeavyMax := envInt("RESOLVE_LIVE_BATCH_NX_HEAVY_MAX_PCT", 35)
liveBatchNXHeavyDefault := envInt("RESOLVE_LIVE_BATCH_NX_HEAVY_PCT", 10)
if liveBatchNXHeavyMin < 0 {
liveBatchNXHeavyMin = 0
}
if liveBatchNXHeavyMin > 100 {
liveBatchNXHeavyMin = 100
}
if liveBatchNXHeavyMax < liveBatchNXHeavyMin {
liveBatchNXHeavyMax = liveBatchNXHeavyMin
}
if liveBatchNXHeavyMax > 100 {
liveBatchNXHeavyMax = 100
}
if liveBatchNXHeavyDefault < liveBatchNXHeavyMin {
liveBatchNXHeavyDefault = liveBatchNXHeavyMin
}
if liveBatchNXHeavyDefault > liveBatchNXHeavyMax {
liveBatchNXHeavyDefault = liveBatchNXHeavyMax
}
liveBatchNXHeavyPct := liveBatchNXHeavyDefault
if deps.LoadResolverLiveBatchNXHeavyPct != nil {
liveBatchNXHeavyPct = deps.LoadResolverLiveBatchNXHeavyPct(in.PrecheckStatePath, liveBatchNXHeavyDefault, liveBatchNXHeavyMin, liveBatchNXHeavyMax)
}
precheckLastRun := 0
if deps.LoadResolverPrecheckLastRun != nil {
precheckLastRun = deps.LoadResolverPrecheckLastRun(in.PrecheckStatePath)
}
precheckDue := in.PrecheckEnvForced || in.PrecheckFileForced || (precheckEverySec > 0 && (precheckLastRun <= 0 || in.Now-precheckLastRun >= precheckEverySec))
staleKeepSec := envInt("RESOLVE_STALE_KEEP_SEC", 48*3600)
if staleKeepSec < 0 {
staleKeepSec = 0
}
if staleKeepSec > 7*24*3600 {
staleKeepSec = 7 * 24 * 3600
}
negTTLNX := envInt("RESOLVE_NEGATIVE_TTL_NX", 6*3600)
negTTLTimeout := envInt("RESOLVE_NEGATIVE_TTL_TIMEOUT", 15*60)
negTTLTemporary := envInt("RESOLVE_NEGATIVE_TTL_TEMPORARY", 10*60)
negTTLOther := envInt("RESOLVE_NEGATIVE_TTL_OTHER", 10*60)
clampTTL := func(v int) int {
if v < 0 {
return 0
}
if v > 24*3600 {
return 24 * 3600
}
return v
}
negTTLNX = clampTTL(negTTLNX)
negTTLTimeout = clampTTL(negTTLTimeout)
negTTLTemporary = clampTTL(negTTLTemporary)
negTTLOther = clampTTL(negTTLOther)
return ResolverRuntimeTuning{
TTL: ttl,
Workers: workers,
DNSTimeoutMS: dnsTimeoutMs,
DNSTimeout: time.Duration(dnsTimeoutMs) * time.Millisecond,
PrecheckEverySec: precheckEverySec,
PrecheckMaxDomains: precheckMaxDomains,
TimeoutRecheckMax: timeoutRecheckMax,
LiveBatchMin: liveBatchMin,
LiveBatchMax: liveBatchMax,
LiveBatchTarget: liveBatchTarget,
LiveBatchNXHeavyMin: liveBatchNXHeavyMin,
LiveBatchNXHeavyMax: liveBatchNXHeavyMax,
LiveBatchNXHeavyPct: liveBatchNXHeavyPct,
PrecheckDue: precheckDue,
StaleKeepSec: staleKeepSec,
NegTTLNX: negTTLNX,
NegTTLTimeout: negTTLTimeout,
NegTTLTemporary: negTTLTemporary,
NegTTLOther: negTTLOther,
}
}

View File

@@ -0,0 +1,64 @@
package resolver
type ResolverStartLogInput struct {
DomainsTotal int
TTL int
Workers int
DNSTimeoutMS int
DirectTry int
DirectBudgetMS int64
WildcardTry int
WildcardBudgetMS int64
NXEarlyStop bool
NXHardQuarantine bool
CooldownEnabled bool
CooldownMinAttempts int
CooldownTimeoutRate int
CooldownFailStreak int
CooldownBanSec int
CooldownMaxBanSec int
LiveBatchTarget int
LiveBatchMin int
LiveBatchMax int
LiveBatchNXHeavyPct int
LiveBatchNXHeavyMin int
LiveBatchNXHeavyMax int
StaleKeepSec int
PrecheckEverySec int
PrecheckMaxDomains int
PrecheckForcedEnv bool
PrecheckForcedFile bool
}
func LogResolverStart(in ResolverStartLogInput, logf func(string, ...any)) {
if logf == nil {
return
}
logf("resolver start: domains=%d ttl=%ds workers=%d dns_timeout_ms=%d", in.DomainsTotal, in.TTL, in.Workers, in.DNSTimeoutMS)
logf(
"resolver policy: direct_try=%d direct_budget_ms=%d wildcard_try=%d wildcard_budget_ms=%d nx_early_stop=%t nx_hard_quarantine=%t cooldown_enabled=%t cooldown_min_attempts=%d cooldown_timeout_rate=%d cooldown_fail_streak=%d cooldown_ban_sec=%d cooldown_max_ban_sec=%d live_batch_target=%d live_batch_min=%d live_batch_max=%d live_batch_nx_heavy_pct=%d live_batch_nx_heavy_min=%d live_batch_nx_heavy_max=%d stale_keep_sec=%d precheck_every_sec=%d precheck_max=%d precheck_forced_env=%t precheck_forced_file=%t",
in.DirectTry,
in.DirectBudgetMS,
in.WildcardTry,
in.WildcardBudgetMS,
in.NXEarlyStop,
in.NXHardQuarantine,
in.CooldownEnabled,
in.CooldownMinAttempts,
in.CooldownTimeoutRate,
in.CooldownFailStreak,
in.CooldownBanSec,
in.CooldownMaxBanSec,
in.LiveBatchTarget,
in.LiveBatchMin,
in.LiveBatchMax,
in.LiveBatchNXHeavyPct,
in.LiveBatchNXHeavyMin,
in.LiveBatchNXHeavyMax,
in.StaleKeepSec,
in.PrecheckEverySec,
in.PrecheckMaxDomains,
in.PrecheckForcedEnv,
in.PrecheckForcedFile,
)
}

View File

@@ -0,0 +1,139 @@
package resolver
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
"time"
)
func ParseStaticEntries(lines []string, logf func(string, ...any)) (entries [][3]string, skipped int) {
for _, ln := range lines {
s := strings.TrimSpace(ln)
if s == "" || strings.HasPrefix(s, "#") {
continue
}
comment := ""
if idx := strings.Index(s, "#"); idx >= 0 {
comment = strings.TrimSpace(s[idx+1:])
s = strings.TrimSpace(s[:idx])
}
if s == "" || IsPrivateIPv4(s) {
continue
}
rawBase := strings.SplitN(s, "/", 2)[0]
if strings.Contains(s, "/") {
if _, err := netip.ParsePrefix(s); err != nil {
skipped++
if logf != nil {
logf("static skip invalid prefix %q: %v", s, err)
}
continue
}
} else {
if _, err := netip.ParseAddr(rawBase); err != nil {
skipped++
if logf != nil {
logf("static skip invalid ip %q: %v", s, err)
}
continue
}
}
entries = append(entries, [3]string{s, rawBase, comment})
}
return entries, skipped
}
func ResolveStaticLabels(entries [][3]string, dnsForPtr string, ptrCache map[string]any, ttl int, logf func(string, ...any)) (map[string][]string, int, int) {
now := int(time.Now().Unix())
result := map[string][]string{}
ptrLookups := 0
ptrErrors := 0
for _, e := range entries {
ipEntry, baseIP, comment := e[0], e[1], e[2]
var labels []string
if comment != "" {
labels = append(labels, "*"+comment)
}
if comment == "" {
if cached, ok := ptrCache[baseIP].(map[string]any); ok {
names, _ := cached["names"].([]any)
last, _ := cached["last_resolved"].(float64)
if len(names) > 0 && last > 0 && now-int(last) <= ttl {
for _, n := range names {
if s, ok := n.(string); ok && s != "" {
labels = append(labels, "*"+s)
}
}
}
}
if len(labels) == 0 {
ptrLookups++
names, err := DigPTR(baseIP, dnsForPtr, 3*time.Second, logf)
if err != nil {
ptrErrors++
}
if len(names) > 0 {
ptrCache[baseIP] = map[string]any{"names": names, "last_resolved": now}
for _, n := range names {
labels = append(labels, "*"+n)
}
}
}
}
if len(labels) == 0 {
labels = []string{"*[STATIC-IP]"}
}
result[ipEntry] = labels
if logf != nil {
logf("static %s -> %v", ipEntry, labels)
}
}
return result, ptrLookups, ptrErrors
}
func DigPTR(ip, upstream string, timeout time.Duration, logf func(string, ...any)) ([]string, error) {
server, port := SplitDNS(upstream)
if server == "" {
return nil, fmt.Errorf("upstream empty")
}
if port == "" {
port = "53"
}
addr := net.JoinHostPort(server, port)
r := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{}
return d.DialContext(ctx, "udp", addr)
},
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
names, err := r.LookupAddr(ctx, ip)
cancel()
if err != nil {
if logf != nil {
logf("ptr error %s via %s: %v", ip, addr, err)
}
return nil, err
}
seen := map[string]struct{}{}
var out []string
for _, n := range names {
n = strings.TrimSuffix(strings.ToLower(strings.TrimSpace(n)), ".")
if n == "" {
continue
}
if _, ok := seen[n]; !ok {
seen[n] = struct{}{}
out = append(out, n)
}
}
return out, nil
}

View File

@@ -0,0 +1,112 @@
package resolver
type ResolverSummaryLogInput struct {
DomainsTotal int
FreshCount int
CacheNegativeHits int
QuarantineHits int
StaleHits int
ResolvedTotal int
UnresolvedAfterAttempts int
LiveBatchTarget int
LiveDeferred int
LiveP1 int
LiveP2 int
LiveP3 int
LiveBatchNXHeavyPct int
LiveNXHeavyTotal int
LiveNXHeavySkip int
StaticEntries int
StaticSkipped int
UniqueIPs int
DirectIPs int
WildcardIPs int
PtrLookups int
PtrErrors int
DNSStats DNSMetrics
TimeoutRecheck ResolverTimeoutRecheckStats
DurationMS int64
DomainStateSummary string
ResolvedNowDNS int
ResolvedNowStale int
PrecheckDue bool
PrecheckScheduled int
PrecheckStatePath string
}
func LogResolverSummary(in ResolverSummaryLogInput, logf func(string, ...any)) {
if logf == nil {
return
}
dnsErrors := in.DNSStats.TotalErrors()
unresolved := in.DomainsTotal - in.ResolvedTotal
unresolvedSuppressed := in.CacheNegativeHits + in.QuarantineHits + in.LiveDeferred
logf(
"resolve summary: domains=%d cache_hits=%d cache_neg_hits=%d quarantine_hits=%d stale_hits=%d resolved_now=%d unresolved=%d unresolved_live=%d unresolved_suppressed=%d live_batch_target=%d live_batch_deferred=%d live_batch_p1=%d live_batch_p2=%d live_batch_p3=%d live_batch_nxheavy_pct=%d live_batch_nxheavy_total=%d live_batch_nxheavy_skip=%d static_entries=%d static_skipped=%d unique_ips=%d direct_ips=%d wildcard_ips=%d ptr_lookups=%d ptr_errors=%d dns_attempts=%d dns_ok=%d dns_nxdomain=%d dns_timeout=%d dns_temporary=%d dns_other=%d dns_cooldown_skips=%d dns_errors=%d timeout_recheck_checked=%d timeout_recheck_recovered=%d timeout_recheck_recovered_ips=%d timeout_recheck_still_timeout=%d timeout_recheck_now_nxdomain=%d timeout_recheck_now_temporary=%d timeout_recheck_now_other=%d timeout_recheck_no_signal=%d duration_ms=%d",
in.DomainsTotal,
in.FreshCount,
in.CacheNegativeHits,
in.QuarantineHits,
in.StaleHits,
in.ResolvedTotal-in.FreshCount,
unresolved,
in.UnresolvedAfterAttempts,
unresolvedSuppressed,
in.LiveBatchTarget,
in.LiveDeferred,
in.LiveP1,
in.LiveP2,
in.LiveP3,
in.LiveBatchNXHeavyPct,
in.LiveNXHeavyTotal,
in.LiveNXHeavySkip,
in.StaticEntries,
in.StaticSkipped,
in.UniqueIPs,
in.DirectIPs,
in.WildcardIPs,
in.PtrLookups,
in.PtrErrors,
in.DNSStats.Attempts,
in.DNSStats.OK,
in.DNSStats.NXDomain,
in.DNSStats.Timeout,
in.DNSStats.Temporary,
in.DNSStats.Other,
in.DNSStats.Skipped,
dnsErrors,
in.TimeoutRecheck.Checked,
in.TimeoutRecheck.Recovered,
in.TimeoutRecheck.RecoveredIPs,
in.TimeoutRecheck.StillTimeout,
in.TimeoutRecheck.NowNXDomain,
in.TimeoutRecheck.NowTemporary,
in.TimeoutRecheck.NowOther,
in.TimeoutRecheck.NoSignal,
in.DurationMS,
)
if perUpstream := in.DNSStats.FormatPerUpstream(); perUpstream != "" {
logf("resolve dns upstreams: %s", perUpstream)
}
if health := in.DNSStats.FormatResolverHealth(); health != "" {
logf("resolve dns health: %s", health)
}
if in.DomainStateSummary != "" {
logf("resolve domain states: %s", in.DomainStateSummary)
}
logf(
"resolve breakdown: resolved_now_total=%d resolved_now_dns=%d resolved_now_stale=%d skipped_neg=%d skipped_quarantine=%d deferred_live_batch=%d unresolved_after_attempts=%d",
in.ResolvedTotal-in.FreshCount,
in.ResolvedNowDNS,
in.ResolvedNowStale,
in.CacheNegativeHits,
in.QuarantineHits,
in.LiveDeferred,
in.UnresolvedAfterAttempts,
)
if in.PrecheckDue {
logf("resolve precheck done: scheduled=%d state=%s", in.PrecheckScheduled, in.PrecheckStatePath)
}
}

View File

@@ -0,0 +1,127 @@
package resolver
import "strings"
func RunTimeoutQuarantineRecheck(
domains []string,
now int,
limit int,
workers int,
domainCache *DomainCacheState,
cacheSourceForHost func(string) DomainCacheSource,
resolveHost func(string) ([]string, DNSMetrics),
) ResolverTimeoutRecheckStats {
stats := ResolverTimeoutRecheckStats{}
if limit <= 0 || now <= 0 || domainCache == nil || resolveHost == nil {
return stats
}
if workers < 1 {
workers = 1
}
if workers > 200 {
workers = 200
}
resolveSource := cacheSourceForHost
if resolveSource == nil {
resolveSource = func(string) DomainCacheSource { return DomainCacheSourceDirect }
}
seen := map[string]struct{}{}
capHint := len(domains)
if capHint > limit {
capHint = limit
}
candidates := make([]string, 0, capHint)
for _, raw := range domains {
host := strings.TrimSpace(strings.ToLower(raw))
if host == "" {
continue
}
if _, ok := seen[host]; ok {
continue
}
seen[host] = struct{}{}
source := resolveSource(host)
if _, _, ok := domainCache.GetQuarantine(host, source, now); !ok {
continue
}
kind, ok := domainCache.GetLastErrorKind(host, source)
if !ok || kind != DNSErrorTimeout {
continue
}
candidates = append(candidates, host)
if len(candidates) >= limit {
break
}
}
if len(candidates) == 0 {
return stats
}
recoveredIPSet := map[string]struct{}{}
type result struct {
host string
source DomainCacheSource
ips []string
dns DNSMetrics
}
jobs := make(chan string, len(candidates))
results := make(chan result, len(candidates))
for i := 0; i < workers; i++ {
go func() {
for host := range jobs {
src := resolveSource(host)
ips, dnsStats := resolveHost(host)
results <- result{host: host, source: src, ips: ips, dns: dnsStats}
}
}()
}
for _, host := range candidates {
jobs <- host
}
close(jobs)
for i := 0; i < len(candidates); i++ {
r := <-results
stats.Checked++
if len(r.ips) > 0 {
for _, ip := range r.ips {
ip = strings.TrimSpace(ip)
if ip == "" {
continue
}
recoveredIPSet[ip] = struct{}{}
}
domainCache.Set(r.host, r.source, r.ips, now)
stats.Recovered++
continue
}
if r.dns.TotalErrors() > 0 {
domainCache.SetErrorWithStats(r.host, r.source, r.dns, now)
}
kind, ok := ClassifyHostErrorKind(r.dns)
if !ok {
stats.NoSignal++
continue
}
switch kind {
case DNSErrorTimeout:
stats.StillTimeout++
case DNSErrorNXDomain:
stats.NowNXDomain++
case DNSErrorTemporary:
stats.NowTemporary++
default:
stats.NowOther++
}
}
stats.RecoveredIPs = len(recoveredIPSet)
return stats
}

View File

@@ -0,0 +1,70 @@
package resolver
import "strings"
type WildcardMatcher struct {
exact map[string]struct{}
suffix []string
}
func NormalizeWildcardDomain(raw string) string {
d := strings.TrimSpace(strings.SplitN(raw, "#", 2)[0])
d = strings.ToLower(d)
d = strings.TrimPrefix(d, "*.")
d = strings.TrimPrefix(d, ".")
d = strings.TrimSuffix(d, ".")
return d
}
func NewWildcardMatcher(domains []string) WildcardMatcher {
seen := map[string]struct{}{}
m := WildcardMatcher{exact: map[string]struct{}{}}
for _, raw := range domains {
d := NormalizeWildcardDomain(raw)
if d == "" {
continue
}
if _, ok := seen[d]; ok {
continue
}
seen[d] = struct{}{}
m.exact[d] = struct{}{}
m.suffix = append(m.suffix, "."+d)
}
return m
}
func (m WildcardMatcher) Match(host string) bool {
if len(m.exact) == 0 {
return false
}
h := strings.TrimSuffix(strings.ToLower(strings.TrimSpace(host)), ".")
if h == "" {
return false
}
if _, ok := m.exact[h]; ok {
return true
}
for _, suffix := range m.suffix {
if strings.HasSuffix(h, suffix) {
return true
}
}
return false
}
func (m WildcardMatcher) IsExact(host string) bool {
if len(m.exact) == 0 {
return false
}
h := strings.TrimSuffix(strings.ToLower(strings.TrimSpace(host)), ".")
if h == "" {
return false
}
_, ok := m.exact[h]
return ok
}
func (m WildcardMatcher) Count() int {
return len(m.exact)
}

View File

@@ -0,0 +1,63 @@
package app
import resolverpkg "selective-vpn-api/app/resolver"
// ---------------------------------------------------------------------
// resolver bridge layer (consolidated)
// ---------------------------------------------------------------------
type dnsErrorKind = resolverpkg.DNSErrorKind
const (
dnsErrorNXDomain dnsErrorKind = resolverpkg.DNSErrorNXDomain
dnsErrorTimeout dnsErrorKind = resolverpkg.DNSErrorTimeout
dnsErrorTemporary dnsErrorKind = resolverpkg.DNSErrorTemporary
dnsErrorOther dnsErrorKind = resolverpkg.DNSErrorOther
)
type dnsUpstreamMetrics = resolverpkg.DNSUpstreamMetrics
type dnsMetrics = resolverpkg.DNSMetrics
type wildcardMatcher = resolverpkg.WildcardMatcher
type domainCacheSource = resolverpkg.DomainCacheSource
const (
domainCacheSourceDirect domainCacheSource = resolverpkg.DomainCacheSourceDirect
domainCacheSourceWildcard domainCacheSource = resolverpkg.DomainCacheSourceWildcard
)
const (
domainStateActive = resolverpkg.DomainStateActive
domainStateStable = resolverpkg.DomainStateStable
domainStateSuspect = resolverpkg.DomainStateSuspect
domainStateQuarantine = resolverpkg.DomainStateQuarantine
domainStateHardQuar = resolverpkg.DomainStateHardQuar
domainScoreMin = resolverpkg.DomainScoreMin
domainScoreMax = resolverpkg.DomainScoreMax
defaultQuarantineTTL = resolverpkg.DefaultQuarantineTTL
defaultHardQuarantineTT = resolverpkg.DefaultHardQuarTTL
)
type domainCacheEntry = resolverpkg.DomainCacheEntry
type domainCacheRecord = resolverpkg.DomainCacheRecord
type domainCacheState resolverpkg.DomainCacheState
type resolverPlanningResult = resolverpkg.ResolvePlanningResult
type resolverTimeoutRecheckStats = resolverpkg.ResolverTimeoutRecheckStats
type resolverLiveBatchStats = resolverpkg.ResolverLiveBatchStats
type resolverResolveBatchResult = resolverpkg.ResolveBatchResult
type resolverRuntimeTuning = resolverpkg.ResolverRuntimeTuning
type resolverStartLogInput = resolverpkg.ResolverStartLogInput
type resolverSummaryLogInput = resolverpkg.ResolverSummaryLogInput
func init() {
resolverpkg.EnvInt = envInt
resolverpkg.NXHardQuarantineEnabled = resolveNXHardQuarantineEnabled
}

View File

@@ -0,0 +1,90 @@
package app
import resolverpkg "selective-vpn-api/app/resolver"
func newDomainCacheState() domainCacheState {
return domainCacheState(resolverpkg.NewDomainCacheState())
}
func normalizeCacheIPs(raw []string) []string {
return resolverpkg.NormalizeCacheIPs(raw)
}
func normalizeCacheErrorKind(raw string) (dnsErrorKind, bool) {
kind, ok := resolverpkg.NormalizeCacheErrorKind(raw)
return dnsErrorKind(kind), ok
}
func normalizeDomainCacheEntry(in *domainCacheEntry) *domainCacheEntry {
return resolverpkg.NormalizeDomainCacheEntry(in)
}
func loadDomainCacheState(path string, logf func(string, ...any)) domainCacheState {
return domainCacheState(resolverpkg.LoadDomainCacheState(path, logf))
}
func getCacheEntryBySource(rec domainCacheRecord, source domainCacheSource) *domainCacheEntry {
return resolverpkg.GetCacheEntryBySource(rec, source)
}
func clampDomainScore(v int) int {
return resolverpkg.ClampDomainScore(v)
}
func domainStateFromScore(score int) string {
return resolverpkg.DomainStateFromScore(score)
}
func normalizeDomainState(raw string, score int) string {
return resolverpkg.NormalizeDomainState(raw, score)
}
func domainScorePenalty(stats dnsMetrics) int {
return resolverpkg.DomainScorePenalty(stats)
}
func (s domainCacheState) get(domain string, source domainCacheSource, now, ttl int) ([]string, bool) {
return resolverpkg.DomainCacheState(s).Get(domain, source, now, ttl)
}
func (s domainCacheState) getNegative(domain string, source domainCacheSource, now, nxTTL, timeoutTTL, temporaryTTL, otherTTL int) (dnsErrorKind, int, bool) {
kind, age, ok := resolverpkg.DomainCacheState(s).GetNegative(domain, source, now, nxTTL, timeoutTTL, temporaryTTL, otherTTL)
return dnsErrorKind(kind), age, ok
}
func (s domainCacheState) getStoredIPs(domain string, source domainCacheSource) []string {
return resolverpkg.DomainCacheState(s).GetStoredIPs(domain, source)
}
func (s domainCacheState) getLastErrorKind(domain string, source domainCacheSource) (dnsErrorKind, bool) {
kind, ok := resolverpkg.DomainCacheState(s).GetLastErrorKind(domain, source)
return dnsErrorKind(kind), ok
}
func (s domainCacheState) getQuarantine(domain string, source domainCacheSource, now int) (string, int, bool) {
return resolverpkg.DomainCacheState(s).GetQuarantine(domain, source, now)
}
func (s domainCacheState) getStale(domain string, source domainCacheSource, now, maxAge int) ([]string, int, bool) {
return resolverpkg.DomainCacheState(s).GetStale(domain, source, now, maxAge)
}
func (s *domainCacheState) set(domain string, source domainCacheSource, ips []string, now int) {
state := resolverpkg.DomainCacheState(*s)
state.Set(domain, source, ips, now)
*s = domainCacheState(state)
}
func (s *domainCacheState) setErrorWithStats(domain string, source domainCacheSource, stats dnsMetrics, now int) {
state := resolverpkg.DomainCacheState(*s)
state.SetErrorWithStats(domain, source, stats, now)
*s = domainCacheState(state)
}
func (s domainCacheState) toMap() map[string]any {
return resolverpkg.DomainCacheState(s).ToMap()
}
func (s domainCacheState) formatStateSummary(now int) string {
return resolverpkg.DomainCacheState(s).FormatStateSummary(now)
}

View File

@@ -0,0 +1,171 @@
package app
import (
"strings"
"time"
resolverpkg "selective-vpn-api/app/resolver"
)
func loadDNSConfig(path string, logf func(string, ...any)) dnsConfig {
cfg := resolverpkg.LoadDNSConfig(
path,
resolverpkg.DNSConfig{
Default: []string{defaultDNS1, defaultDNS2},
Meta: []string{defaultMeta1, defaultMeta2},
SmartDNS: smartDNSAddr(),
Mode: string(DNSModeDirect),
},
resolverpkg.DNSConfigDeps{
ActivePool: loadEnabledDNSUpstreamPool(),
IsSmartDNSForced: smartDNSForced(),
SmartDNSAddr: smartDNSAddr(),
SmartDNSForcedMode: string(DNSModeSmartDNS),
ResolveFallbackPool: func() []string {
return resolverFallbackPool()
},
MergeDNSUpstreamPools: func(primary, fallback []string) []string {
return mergeDNSUpstreamPools(primary, fallback)
},
NormalizeDNSUpstream: func(raw string, defaultPort string) string {
return normalizeDNSUpstream(raw, defaultPort)
},
NormalizeSmartDNSAddr: normalizeSmartDNSAddr,
NormalizeDNSResolverMode: func(raw string) string {
return string(normalizeDNSResolverMode(DNSResolverMode(raw), false))
},
},
logf,
)
return dnsConfig{
Default: cfg.Default,
Meta: cfg.Meta,
SmartDNS: cfg.SmartDNS,
Mode: DNSResolverMode(cfg.Mode),
}
}
type resolverCooldownAdapter struct {
inner *dnsRunCooldown
}
func (a resolverCooldownAdapter) ShouldSkip(upstream string, now int64) bool {
if a.inner == nil {
return false
}
return a.inner.shouldSkip(upstream, now)
}
func (a resolverCooldownAdapter) ObserveSuccess(upstream string) {
if a.inner == nil {
return
}
a.inner.observeSuccess(upstream)
}
func (a resolverCooldownAdapter) ObserveError(upstream string, kind resolverpkg.DNSErrorKind, now int64) (bool, int) {
if a.inner == nil {
return false, 0
}
return a.inner.observeError(upstream, dnsErrorKind(kind), now)
}
func toResolverPolicy(policy dnsAttemptPolicy) resolverpkg.DNSAttemptPolicy {
return resolverpkg.DNSAttemptPolicy{
TryLimit: policy.TryLimit,
DomainBudget: policy.DomainBudget,
StopOnNX: policy.StopOnNX,
}
}
func resolveHostGo(host string, cfg dnsConfig, metaSpecial []string, wildcards wildcardMatcher, timeout time.Duration, cooldown *dnsRunCooldown, logf func(string, ...any)) ([]string, dnsMetrics) {
return resolverpkg.ResolveHost(
host,
resolverpkg.DNSConfig{
Default: cfg.Default,
Meta: cfg.Meta,
SmartDNS: cfg.SmartDNS,
Mode: string(cfg.Mode),
},
metaSpecial,
func(h string) bool { return wildcards.Match(h) },
timeout,
resolverCooldownAdapter{inner: cooldown},
func(dnsCount int) resolverpkg.DNSAttemptPolicy {
return toResolverPolicy(directDNSAttemptPolicy(dnsCount))
},
func(dnsCount int) resolverpkg.DNSAttemptPolicy {
return toResolverPolicy(wildcardDNSAttemptPolicy(dnsCount))
},
smartDNSFallbackForTimeoutEnabled(),
logf,
)
}
func digA(host string, dnsList []string, timeout time.Duration, logf func(string, ...any)) ([]string, dnsMetrics) {
policy := toResolverPolicy(defaultDNSAttemptPolicy(len(dnsList)))
return resolverpkg.DigAWithPolicy(host, dnsList, timeout, policy, nil, logf)
}
func digAWithPolicy(host string, dnsList []string, timeout time.Duration, logf func(string, ...any), policy dnsAttemptPolicy, cooldown *dnsRunCooldown) ([]string, dnsMetrics) {
return resolverpkg.DigAWithPolicy(host, dnsList, timeout, toResolverPolicy(policy), resolverCooldownAdapter{inner: cooldown}, logf)
}
func applyResolverDNSModeRuntime(cfg dnsConfig, opts ResolverOpts) dnsConfig {
out := resolverpkg.ApplyDNSModeRuntime(resolverpkg.DNSModeRuntimeInput{
Config: resolverpkg.DNSConfig{
Default: cfg.Default,
Meta: cfg.Meta,
SmartDNS: cfg.SmartDNS,
Mode: string(cfg.Mode),
},
Mode: string(opts.Mode),
ViaSmartDNS: opts.ViaSmartDNS,
SmartDNSAddr: opts.SmartDNSAddr,
SmartDNSForced: smartDNSForced(),
SmartDNSDefault: smartDNSAddr(),
NormalizeMode: func(mode string, viaSmartDNS bool) string {
return string(normalizeDNSResolverMode(DNSResolverMode(mode), viaSmartDNS))
},
NormalizeSmartDNSAddr: normalizeSmartDNSAddr,
})
return dnsConfig{
Default: out.Default,
Meta: out.Meta,
SmartDNS: out.SmartDNS,
Mode: DNSResolverMode(out.Mode),
}
}
func logResolverDNSMode(cfg dnsConfig, wildcards wildcardMatcher, logf func(string, ...any)) {
resolverpkg.LogDNSMode(
resolverpkg.DNSConfig{
Default: cfg.Default,
Meta: cfg.Meta,
SmartDNS: cfg.SmartDNS,
Mode: string(cfg.Mode),
},
wildcards.Count(),
logf,
)
}
func parseStaticEntriesGo(lines []string, logf func(string, ...any)) (entries [][3]string, skipped int) {
return resolverpkg.ParseStaticEntries(lines, logf)
}
func resolveStaticLabels(entries [][3]string, cfg dnsConfig, ptrCache map[string]any, ttl int, logf func(string, ...any)) (map[string][]string, int, int) {
dnsForPtr := defaultDNS1
if len(cfg.Default) > 0 && strings.TrimSpace(cfg.Default[0]) != "" {
dnsForPtr = cfg.Default[0]
}
return resolverpkg.ResolveStaticLabels(entries, dnsForPtr, ptrCache, ttl, logf)
}
func digPTR(ip, upstream string, timeout time.Duration, logf func(string, ...any)) ([]string, error) {
return resolverpkg.DigPTR(ip, upstream, timeout, logf)
}
func logResolverSummary(input resolverpkg.ResolverSummaryLogInput, logf func(string, ...any)) {
resolverpkg.LogResolverSummary(input, logf)
}

View File

@@ -0,0 +1,5 @@
package app
// Resolver bridge pipeline helpers are split by role:
// - planning/runtime tuning wrappers: resolver_bridge_pipeline_planning.go
// - execution/recheck/artifacts wrappers: resolver_bridge_pipeline_exec.go

View File

@@ -0,0 +1,86 @@
package app
import (
"time"
resolverpkg "selective-vpn-api/app/resolver"
)
func executeResolverBatch(
toResolve []string,
workers int,
now int,
staleKeepSec int,
resolved map[string][]string,
domainCache *domainCacheState,
cacheSourceForHost func(string) domainCacheSource,
resolveHost func(string) ([]string, dnsMetrics),
logf func(string, ...any),
) resolverResolveBatchResult {
if domainCache == nil {
return resolverResolveBatchResult{}
}
sourceFn := func(host string) resolverpkg.DomainCacheSource {
if cacheSourceForHost == nil {
return resolverpkg.DomainCacheSourceDirect
}
return resolverpkg.DomainCacheSource(cacheSourceForHost(host))
}
resolveFn := func(host string) ([]string, resolverpkg.DNSMetrics) {
if resolveHost == nil {
return nil, resolverpkg.DNSMetrics{}
}
return resolveHost(host)
}
return resolverpkg.ExecuteResolveBatch(
resolverpkg.ResolveBatchInput{
ToResolve: toResolve,
Workers: workers,
Now: now,
StaleKeepSec: staleKeepSec,
},
resolved,
(*resolverpkg.DomainCacheState)(domainCache),
sourceFn,
resolveFn,
logf,
)
}
func runTimeoutQuarantineRecheck(
domains []string,
cfg dnsConfig,
metaSpecial []string,
wildcards wildcardMatcher,
timeout time.Duration,
domainCache *domainCacheState,
cacheSourceForHost func(string) domainCacheSource,
now int,
limit int,
workers int,
) resolverTimeoutRecheckStats {
if domainCache == nil {
return resolverTimeoutRecheckStats{}
}
sourceFn := func(host string) resolverpkg.DomainCacheSource {
if cacheSourceForHost == nil {
return resolverpkg.DomainCacheSourceDirect
}
return resolverpkg.DomainCacheSource(cacheSourceForHost(host))
}
return resolverpkg.RunTimeoutQuarantineRecheck(
domains,
now,
limit,
workers,
(*resolverpkg.DomainCacheState)(domainCache),
sourceFn,
func(host string) ([]string, resolverpkg.DNSMetrics) {
return resolveHostGo(host, cfg, metaSpecial, wildcards, timeout, nil, nil)
},
)
}
func buildResolverArtifacts(resolved map[string][]string, staticLabels map[string][]string, isWildcardHost func(string) bool) resolverpkg.ResolverArtifacts {
return resolverpkg.BuildResolverArtifacts(resolved, staticLabels, isWildcardHost)
}

View File

@@ -0,0 +1,118 @@
package app
import resolverpkg "selective-vpn-api/app/resolver"
func buildResolverPlanning(
domains []string,
now int,
ttl int,
precheckDue bool,
precheckMaxDomains int,
staleKeepSec int,
negTTLNX int,
negTTLTimeout int,
negTTLTemporary int,
negTTLOther int,
domainCache *domainCacheState,
cacheSourceForHost func(string) domainCacheSource,
logf func(string, ...any),
) resolverPlanningResult {
sourceFn := func(host string) resolverpkg.DomainCacheSource {
if cacheSourceForHost == nil {
return resolverpkg.DomainCacheSourceDirect
}
return resolverpkg.DomainCacheSource(cacheSourceForHost(host))
}
return resolverpkg.BuildResolvePlanning(
resolverpkg.ResolvePlanningInput{
Domains: domains,
Now: now,
TTL: ttl,
PrecheckDue: precheckDue,
PrecheckMaxDomains: precheckMaxDomains,
StaleKeepSec: staleKeepSec,
NegTTLNX: negTTLNX,
NegTTLTimeout: negTTLTimeout,
NegTTLTemporary: negTTLTemporary,
NegTTLOther: negTTLOther,
},
(*resolverpkg.DomainCacheState)(domainCache),
sourceFn,
logf,
)
}
func finalizeResolverPrecheck(
precheckDue bool,
precheckStatePath string,
now int,
timeoutRecheck resolverTimeoutRecheckStats,
liveBatchTarget int,
liveBatchMin int,
liveBatchMax int,
liveBatchNXHeavyPct int,
liveBatchNXHeavyMin int,
liveBatchNXHeavyMax int,
dnsStats dnsMetrics,
liveDeferred int,
resolvedNowDNS int,
liveP1 int,
liveP2 int,
liveP3 int,
liveNXHeavyTotal int,
liveNXHeavySkip int,
toResolveTotal int,
precheckFileForced bool,
precheckForcePath string,
logf func(string, ...any),
) resolverpkg.ResolverPrecheckFinalizeResult {
return resolverpkg.FinalizeResolverPrecheck(
resolverpkg.ResolverPrecheckFinalizeInput{
PrecheckDue: precheckDue,
PrecheckStatePath: precheckStatePath,
Now: now,
TimeoutRecheck: timeoutRecheck,
LiveBatchTarget: liveBatchTarget,
LiveBatchMin: liveBatchMin,
LiveBatchMax: liveBatchMax,
LiveBatchNXHeavyPct: liveBatchNXHeavyPct,
LiveBatchNXHeavyMin: liveBatchNXHeavyMin,
LiveBatchNXHeavyMax: liveBatchNXHeavyMax,
DNSStats: dnsStats,
LiveDeferred: liveDeferred,
ResolvedNowDNS: resolvedNowDNS,
LiveP1: liveP1,
LiveP2: liveP2,
LiveP3: liveP3,
LiveNXHeavyTotal: liveNXHeavyTotal,
LiveNXHeavySkip: liveNXHeavySkip,
ToResolveTotal: toResolveTotal,
PrecheckFileForced: precheckFileForced,
PrecheckForcePath: precheckForcePath,
},
logf,
)
}
func buildResolverRuntimeTuning(opts ResolverOpts, now int, precheckStatePath string, precheckEnvForced bool, precheckFileForced bool) resolverRuntimeTuning {
return resolverpkg.BuildResolverRuntimeTuning(
resolverpkg.ResolverRuntimeTuningInput{
TTL: opts.TTL,
Workers: opts.Workers,
Now: now,
PrecheckStatePath: precheckStatePath,
PrecheckEnvForced: precheckEnvForced,
PrecheckFileForced: precheckFileForced,
},
resolverpkg.ResolverRuntimeTuningDeps{
EnvInt: envInt,
LoadResolverPrecheckLastRun: loadResolverPrecheckLastRun,
LoadResolverLiveBatchTarget: loadResolverLiveBatchTarget,
LoadResolverLiveBatchNXHeavyPct: loadResolverLiveBatchNXHeavyPct,
},
)
}
func logResolverStart(input resolverStartLogInput, logf func(string, ...any)) {
resolverpkg.LogResolverStart(input, logf)
}

Some files were not shown because too many files have changed in this diff Show More