// Source: elmprodvpn/selective-vpn-api/app/resolver.go (2290 lines, 60 KiB, Go)

package app
import (
"context"
"encoding/json"
"errors"
"fmt"
"hash/fnv"
"net"
"net/netip"
"os"
"regexp"
"sort"
"strconv"
"strings"
"time"
)
// ---------------------------------------------------------------------
// Go resolver
// ---------------------------------------------------------------------
// Go-based domain resolver pipeline used by routes update.
// Handles cache reuse, concurrent DNS lookups, PTR labeling for static entries,
// and returns deduplicated IP sets plus IP-to-label mapping artifacts.

// dnsErrorKind classifies a failed DNS lookup attempt.
type dnsErrorKind string

const (
	dnsErrorNXDomain  dnsErrorKind = "nxdomain"
	dnsErrorTimeout   dnsErrorKind = "timeout"
	dnsErrorTemporary dnsErrorKind = "temporary"
	dnsErrorOther     dnsErrorKind = "other"
)

// dnsUpstreamMetrics holds lookup attempt counters for a single upstream.
type dnsUpstreamMetrics struct {
	Attempts  int // every attempt, successful or not
	OK        int // successful attempts
	NXDomain  int // attempts classified as dnsErrorNXDomain
	Timeout   int // attempts classified as dnsErrorTimeout
	Temporary int // attempts classified as dnsErrorTemporary
	Other     int // attempts with any other error kind
}

// dnsMetrics aggregates lookup attempt counters, both in total and per
// upstream. The zero value is ready to use; PerUpstream is allocated lazily.
type dnsMetrics struct {
	Attempts    int
	OK          int
	NXDomain    int
	Timeout     int
	Temporary   int
	Other       int
	PerUpstream map[string]*dnsUpstreamMetrics // keyed by the upstream string passed to addSuccess/addError
}

// ensureUpstream returns the counters for upstream, allocating the map and the
// entry when needed. A nil entry stored under the key is replaced, so the
// result is never nil.
func (m *dnsMetrics) ensureUpstream(upstream string) *dnsUpstreamMetrics {
	if m.PerUpstream == nil {
		m.PerUpstream = map[string]*dnsUpstreamMetrics{}
	}
	if us := m.PerUpstream[upstream]; us != nil {
		return us
	}
	us := &dnsUpstreamMetrics{}
	m.PerUpstream[upstream] = us
	return us
}

// addSuccess records one successful attempt against upstream.
func (m *dnsMetrics) addSuccess(upstream string) {
	m.Attempts++
	m.OK++
	us := m.ensureUpstream(upstream)
	us.Attempts++
	us.OK++
}

// addError records one failed attempt against upstream. Kinds other than
// nxdomain, timeout and temporary are counted as "other".
func (m *dnsMetrics) addError(upstream string, kind dnsErrorKind) {
	m.Attempts++
	us := m.ensureUpstream(upstream)
	us.Attempts++
	switch kind {
	case dnsErrorNXDomain:
		m.NXDomain++
		us.NXDomain++
	case dnsErrorTimeout:
		m.Timeout++
		us.Timeout++
	case dnsErrorTemporary:
		m.Temporary++
		us.Temporary++
	default:
		m.Other++
		us.Other++
	}
}

// merge adds every counter of other (totals and per-upstream) into m.
// Nil per-upstream entries in other are ignored.
func (m *dnsMetrics) merge(other dnsMetrics) {
	m.Attempts += other.Attempts
	m.OK += other.OK
	m.NXDomain += other.NXDomain
	m.Timeout += other.Timeout
	m.Temporary += other.Temporary
	m.Other += other.Other
	for upstream, src := range other.PerUpstream {
		if src == nil {
			continue
		}
		dst := m.ensureUpstream(upstream)
		dst.Attempts += src.Attempts
		dst.OK += src.OK
		dst.NXDomain += src.NXDomain
		dst.Timeout += src.Timeout
		dst.Temporary += src.Temporary
		dst.Other += src.Other
	}
}

// totalErrors returns the number of failed attempts of any kind.
func (m dnsMetrics) totalErrors() int {
	return m.NXDomain + m.Timeout + m.Temporary + m.Other
}

// formatPerUpstream renders the per-upstream counters as
// "upstream{...}; upstream{...}" sorted by upstream name. It returns "" when
// there is nothing to report; nil entries are skipped.
func (m dnsMetrics) formatPerUpstream() string {
	if len(m.PerUpstream) == 0 {
		return ""
	}
	keys := make([]string, 0, len(m.PerUpstream))
	for k := range m.PerUpstream {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	parts := make([]string, 0, len(keys))
	for _, k := range keys {
		v := m.PerUpstream[k]
		if v == nil {
			continue
		}
		parts = append(parts, fmt.Sprintf("%s{attempts=%d ok=%d nxdomain=%d timeout=%d temporary=%d other=%d}", k, v.Attempts, v.OK, v.NXDomain, v.Timeout, v.Temporary, v.Other))
	}
	return strings.Join(parts, "; ")
}

// formatResolverHealth renders a health line per upstream, sorted by name.
// The score is okRate*100 - timeoutRate*50. An upstream is "good" when the
// score is at least 70 and at most 5% of attempts timed out, "degraded" when
// the score is at least 35, and "bad" otherwise. Upstreams without attempts
// are skipped.
func (m dnsMetrics) formatResolverHealth() string {
	if len(m.PerUpstream) == 0 {
		return ""
	}
	keys := make([]string, 0, len(m.PerUpstream))
	for k := range m.PerUpstream {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	parts := make([]string, 0, len(keys))
	for _, k := range keys {
		v := m.PerUpstream[k]
		if v == nil || v.Attempts <= 0 {
			continue
		}
		okRate := float64(v.OK) / float64(v.Attempts)
		timeoutRate := float64(v.Timeout) / float64(v.Attempts)
		score := okRate*100.0 - timeoutRate*50.0
		state := "bad"
		switch {
		case score >= 70 && timeoutRate <= 0.05:
			state = "good"
		case score >= 35:
			state = "degraded"
		}
		parts = append(parts, fmt.Sprintf("%s{score=%.1f state=%s attempts=%d ok=%d timeout=%d nxdomain=%d temporary=%d other=%d}", k, score, state, v.Attempts, v.OK, v.Timeout, v.NXDomain, v.Temporary, v.Other))
	}
	return strings.Join(parts, "; ")
}
// wildcardMatcher matches hosts against a set of wildcard base domains.
type wildcardMatcher struct {
	// exact holds the normalized base domains themselves.
	exact map[string]struct{}
	// suffix holds "."+domain for every base domain, used for subdomain matching.
	suffix []string
}
// dnsAttemptPolicy bounds how hard a single domain is resolved.
type dnsAttemptPolicy struct {
	// TryLimit is the maximum number of upstream attempts per domain.
	TryLimit int
	// DomainBudget is the total time allowed for all attempts of one domain.
	DomainBudget time.Duration
	// StopOnNX stops trying further upstreams after the first NXDOMAIN answer.
	StopOnNX bool
}
// Domain health states stored in the domain cache.
const (
	domainStateActive     = "active"
	domainStateStable     = "stable"
	domainStateSuspect    = "suspect"
	domainStateQuarantine = "quarantine"
	domainStateHardQuar   = "hard_quarantine"
)

// Domain score bounds and default quarantine durations (seconds).
const (
	domainScoreMin          = -100
	domainScoreMax          = 100
	defaultQuarantineTTL    = 24 * 3600     // 24 hours
	defaultHardQuarantineTT = 7 * 24 * 3600 // 7 days
)
// resolverTimeoutRecheckStats summarizes one recheck pass over domains that
// are quarantined because of DNS timeouts.
type resolverTimeoutRecheckStats struct {
	Checked      int // domains that were re-resolved
	Recovered    int // domains that returned at least one IP
	RecoveredIPs int // distinct IPs across all recovered domains
	StillTimeout int // domains whose recheck was classified as timeout
	NowNXDomain  int // domains whose recheck was classified as NXDOMAIN
	NowTemporary int // domains whose recheck was classified as temporary error
	NowOther     int // domains whose recheck was classified as another error
	NoSignal     int // domains with no IPs and no classifiable error
}
// resolverFallbackDNS is an optional list of fallback DNS upstreams.
// Empty by default: primary resolver pool comes from DNS upstream pool state.
// Optional fallback list can still be provided via RESOLVE_DNS_FALLBACKS env.
// NOTE(review): the code that fills this variable is not in this part of the
// file; confirm where RESOLVE_DNS_FALLBACKS is parsed.
var resolverFallbackDNS []string
// normalizeWildcardDomain turns a raw wildcard list line into a bare
// lower-case domain: a trailing "#..." comment is dropped, surrounding
// whitespace is trimmed, one leading "*." and one leading "." are removed,
// and a trailing root dot is stripped. It returns "" for blank or
// comment-only input.
func normalizeWildcardDomain(raw string) string {
	if hash := strings.IndexByte(raw, '#'); hash >= 0 {
		raw = raw[:hash]
	}
	domain := strings.ToLower(strings.TrimSpace(raw))
	for _, prefix := range [...]string{"*.", "."} {
		domain = strings.TrimPrefix(domain, prefix)
	}
	return strings.TrimSuffix(domain, ".")
}
// newWildcardMatcher builds a matcher from raw wildcard domain lines.
// Each line is normalized with normalizeWildcardDomain; empty results and
// duplicates are dropped. The suffix list keeps first-seen order.
func newWildcardMatcher(domains []string) wildcardMatcher {
	matcher := wildcardMatcher{exact: make(map[string]struct{})}
	for i := range domains {
		name := normalizeWildcardDomain(domains[i])
		if name == "" {
			continue
		}
		// The exact set doubles as the "already seen" set.
		if _, dup := matcher.exact[name]; dup {
			continue
		}
		matcher.exact[name] = struct{}{}
		matcher.suffix = append(matcher.suffix, "."+name)
	}
	return matcher
}
// match reports whether host equals one of the wildcard base domains or is a
// subdomain of one. The host is compared lower-cased, trimmed, and without a
// trailing root dot. An empty matcher or empty host never matches.
func (m wildcardMatcher) match(host string) bool {
	if len(m.exact) == 0 {
		return false
	}
	name := strings.ToLower(strings.TrimSpace(host))
	name = strings.TrimSuffix(name, ".")
	if name == "" {
		return false
	}
	if _, hit := m.exact[name]; hit {
		return true
	}
	for i := range m.suffix {
		if strings.HasSuffix(name, m.suffix[i]) {
			return true
		}
	}
	return false
}
// ---------------------------------------------------------------------
// `runResolverJob` runs the complete resolver workflow for one job.
//
// Steps, in order:
//  1. Load the domain, meta-special and static lists and build the wildcard
//     matcher from opts.SmartDNSWildcards.
//  2. Load the DNS config and apply mode / SmartDNS address overrides.
//  3. Clamp TTL, worker count, DNS timeout and the env-driven tuning knobs.
//  4. Optionally re-check timeout-quarantined domains when a precheck is due.
//  5. Split domains into cache hits, skipped (quarantine / negative cache)
//     and domains to resolve; resolve the latter concurrently.
//  6. Label static entries and build the sorted IP lists and (ip, label)
//     pairs for the "all", "direct" and "wildcard" sets.
//  7. Log summaries and persist the precheck state.
//
// logf may be nil. It is also handed to resolveHostGo inside the worker
// goroutines, so it is invoked concurrently.
// The returned error is always nil in the current implementation.
// ---------------------------------------------------------------------
func runResolverJob(opts ResolverOpts, logf func(string, ...any)) (resolverResult, error) {
res := resolverResult{
DomainCache: map[string]any{},
PtrCache: map[string]any{},
}
domains := loadList(opts.DomainsPath)
metaSpecial := loadList(opts.MetaPath)
staticLines := readLinesAllowMissing(opts.StaticPath)
wildcards := newWildcardMatcher(opts.SmartDNSWildcards)
cfg := loadDNSConfig(opts.DNSConfigPath, logf)
// The job options choose the resolver mode unless SmartDNS is forced.
if !smartDNSForced() {
cfg.Mode = normalizeDNSResolverMode(opts.Mode, opts.ViaSmartDNS)
}
// SmartDNS address precedence: job option, then DNS config, then smartDNSAddr().
if addr := normalizeSmartDNSAddr(opts.SmartDNSAddr); addr != "" {
cfg.SmartDNS = addr
}
if cfg.SmartDNS == "" {
cfg.SmartDNS = smartDNSAddr()
}
// In SmartDNS-only mode every lookup goes to the SmartDNS address.
if cfg.Mode == DNSModeSmartDNS && cfg.SmartDNS != "" {
cfg.Default = []string{cfg.SmartDNS}
cfg.Meta = []string{cfg.SmartDNS}
}
if logf != nil {
switch cfg.Mode {
case DNSModeSmartDNS:
logf("resolver dns mode: SmartDNS-only (%s)", cfg.SmartDNS)
case DNSModeHybridWildcard:
logf("resolver dns mode: hybrid_wildcard smartdns=%s wildcards=%d default=%v meta=%v", cfg.SmartDNS, len(wildcards.exact), cfg.Default, cfg.Meta)
default:
logf("resolver dns mode: direct default=%v meta=%v", cfg.Default, cfg.Meta)
}
}
ttl := opts.TTL
if ttl <= 0 {
ttl = 24 * 3600
}
// safety clamp: 60s .. 24h
if ttl < 60 {
ttl = 60
}
if ttl > 24*3600 {
ttl = 24 * 3600
}
workers := opts.Workers
if workers <= 0 {
workers = 80
}
// safety clamp: 1..500
// NOTE(review): the lower bound cannot trigger here because non-positive
// values were already replaced by the default above.
if workers < 1 {
workers = 1
}
if workers > 500 {
workers = 500
}
// Per-attempt DNS timeout, clamped to 300..5000 ms.
dnsTimeoutMs := envInt("RESOLVE_DNS_TIMEOUT_MS", 1800)
if dnsTimeoutMs < 300 {
dnsTimeoutMs = 300
}
if dnsTimeoutMs > 5000 {
dnsTimeoutMs = 5000
}
dnsTimeout := time.Duration(dnsTimeoutMs) * time.Millisecond
domainCache := loadDomainCacheState(opts.CachePath, logf)
ptrCache := loadJSONMap(opts.PtrCachePath)
now := int(time.Now().Unix())
// Precheck settings: how often quarantined / negative-cached domains are
// re-resolved and how many of them per run.
precheckEverySec := envInt("RESOLVE_PRECHECK_EVERY_SEC", 24*3600)
if precheckEverySec < 0 {
precheckEverySec = 0
}
precheckMaxDomains := envInt("RESOLVE_PRECHECK_MAX_DOMAINS", 3000)
if precheckMaxDomains < 0 {
precheckMaxDomains = 0
}
if precheckMaxDomains > 50000 {
precheckMaxDomains = 50000
}
timeoutRecheckMax := envInt("RESOLVE_TIMEOUT_RECHECK_MAX", precheckMaxDomains)
if timeoutRecheckMax < 0 {
timeoutRecheckMax = 0
}
if timeoutRecheckMax > 50000 {
timeoutRecheckMax = 50000
}
precheckStatePath := opts.CachePath + ".precheck.json"
precheckLastRun := loadResolverPrecheckLastRun(precheckStatePath)
precheckEnvForced := resolvePrecheckForceEnvEnabled()
precheckFileForced := resolvePrecheckForceFileEnabled(precheckForcePath)
// A precheck is due when forced (env or file), or when periodic prechecks
// are enabled and the last run is unknown or older than the interval.
precheckDue := precheckEnvForced || precheckFileForced || (precheckEverySec > 0 && (precheckLastRun <= 0 || now-precheckLastRun >= precheckEverySec))
precheckScheduled := 0
// Maximum age of stale IPs that may be kept on errors, clamped to 0..7 days.
staleKeepSec := envInt("RESOLVE_STALE_KEEP_SEC", 48*3600)
if staleKeepSec < 0 {
staleKeepSec = 0
}
if staleKeepSec > 7*24*3600 {
staleKeepSec = 7 * 24 * 3600
}
// Negative-cache TTLs per error kind, each clamped to 0..24h below.
negTTLNX := envInt("RESOLVE_NEGATIVE_TTL_NX", 6*3600)
negTTLTimeout := envInt("RESOLVE_NEGATIVE_TTL_TIMEOUT", 15*60)
negTTLTemporary := envInt("RESOLVE_NEGATIVE_TTL_TEMPORARY", 10*60)
negTTLOther := envInt("RESOLVE_NEGATIVE_TTL_OTHER", 10*60)
clampTTL := func(v int) int {
if v < 0 {
return 0
}
if v > 24*3600 {
return 24 * 3600
}
return v
}
negTTLNX = clampTTL(negTTLNX)
negTTLTimeout = clampTTL(negTTLTimeout)
negTTLTemporary = clampTTL(negTTLTemporary)
negTTLOther = clampTTL(negTTLOther)
// cacheSourceForHost selects the cache namespace for a host: "wildcard" for
// every host in SmartDNS-only mode and for wildcard matches in hybrid mode,
// "direct" otherwise.
cacheSourceForHost := func(host string) domainCacheSource {
switch cfg.Mode {
case DNSModeSmartDNS:
return domainCacheSourceWildcard
case DNSModeHybridWildcard:
if wildcards.match(host) {
return domainCacheSourceWildcard
}
}
return domainCacheSourceDirect
}
// Batch recheck of timeout-quarantined domains; runs only when a precheck
// is due and the recheck limit is positive.
timeoutRecheck := resolverTimeoutRecheckStats{}
if precheckDue && timeoutRecheckMax > 0 {
timeoutRecheck = runTimeoutQuarantineRecheck(
domains,
cfg,
metaSpecial,
wildcards,
dnsTimeout,
&domainCache,
cacheSourceForHost,
now,
timeoutRecheckMax,
workers,
)
}
if logf != nil {
logf("resolver start: domains=%d ttl=%ds workers=%d dns_timeout_ms=%d", len(domains), ttl, workers, dnsTimeoutMs)
directPolicy := directDNSAttemptPolicy(len(cfg.Default))
wildcardPolicy := wildcardDNSAttemptPolicy(1)
logf(
"resolver policy: direct_try=%d direct_budget_ms=%d wildcard_try=%d wildcard_budget_ms=%d nx_early_stop=%t stale_keep_sec=%d precheck_every_sec=%d precheck_max=%d precheck_forced_env=%t precheck_forced_file=%t",
directPolicy.TryLimit,
directPolicy.DomainBudget.Milliseconds(),
wildcardPolicy.TryLimit,
wildcardPolicy.DomainBudget.Milliseconds(),
resolveNXEarlyStopEnabled(),
staleKeepSec,
precheckEverySec,
precheckMaxDomains,
precheckEnvForced,
precheckFileForced,
)
}
start := time.Now()
// fresh collects IPs served from cache (fresh hits and stale-keep hits).
fresh := map[string][]string{}
cacheNegativeHits := 0
quarantineHits := 0
staleHits := 0
var toResolve []string
// Classify every domain: fresh cache hit, quarantine, negative cache, or
// needs resolving.
for _, d := range domains {
source := cacheSourceForHost(d)
if ips, ok := domainCache.get(d, source, now, ttl); ok {
fresh[d] = ips
if logf != nil {
logf("cache hit[%s]: %s -> %v", source, d, ips)
}
continue
}
// Quarantine has priority over negative TTL cache so 24h quarantine
// is not silently overridden by shorter negative cache windows.
if state, age, ok := domainCache.getQuarantine(d, source, now); ok {
kind, hasKind := domainCache.getLastErrorKind(d, source)
timeoutKind := hasKind && kind == dnsErrorTimeout
if precheckDue && precheckScheduled < precheckMaxDomains {
// Timeout-based quarantine is rechecked in background batch and should
// not flood trace with per-domain debug lines.
if timeoutKind {
quarantineHits++
if staleKeepSec > 0 {
if staleIPs, staleAge, ok := domainCache.getStale(d, source, now, staleKeepSec); ok {
staleHits++
fresh[d] = staleIPs
if logf != nil {
logf("cache stale-keep (quarantine)[age=%ds]: %s -> %v", staleAge, d, staleIPs)
}
}
}
continue
}
// Non-timeout quarantine: schedule a precheck resolve for this domain.
precheckScheduled++
toResolve = append(toResolve, d)
if logf != nil {
logf("precheck schedule[quarantine/%s age=%ds]: %s (%s)", state, age, d, source)
}
continue
}
// No precheck slot available: skip the domain, keeping stale IPs if allowed.
quarantineHits++
if logf != nil {
logf("cache quarantine hit[%s age=%ds]: %s (%s)", state, age, d, source)
}
if staleKeepSec > 0 {
if staleIPs, staleAge, ok := domainCache.getStale(d, source, now, staleKeepSec); ok {
staleHits++
fresh[d] = staleIPs
if logf != nil {
logf("cache stale-keep (quarantine)[age=%ds]: %s -> %v", staleAge, d, staleIPs)
}
}
}
continue
}
if kind, age, ok := domainCache.getNegative(d, source, now, negTTLNX, negTTLTimeout, negTTLTemporary, negTTLOther); ok {
if precheckDue && precheckScheduled < precheckMaxDomains {
// Timeout negatives are not scheduled for a precheck; they stay skipped.
if kind == dnsErrorTimeout {
cacheNegativeHits++
continue
}
precheckScheduled++
toResolve = append(toResolve, d)
if logf != nil {
logf("precheck schedule[negative/%s age=%ds]: %s (%s)", kind, age, d, source)
}
continue
}
cacheNegativeHits++
if logf != nil {
logf("cache neg hit[%s/%s age=%ds]: %s", source, kind, age, d)
}
continue
}
toResolve = append(toResolve, d)
}
// resolved starts as a copy of the cache-served results and is extended
// with the results of this run.
resolved := map[string][]string{}
for k, v := range fresh {
resolved[k] = v
}
if logf != nil {
logf("resolve: domains=%d cache_hits=%d cache_neg_hits=%d quarantine_hits=%d stale_hits=%d precheck_due=%t precheck_scheduled=%d to_resolve=%d", len(domains), len(fresh), cacheNegativeHits, quarantineHits, staleHits, precheckDue, precheckScheduled, len(toResolve))
}
dnsStats := dnsMetrics{}
resolvedNowDNS := 0
resolvedNowStale := 0
unresolvedAfterAttempts := 0
if len(toResolve) > 0 {
type job struct {
host string
}
// Both channels are buffered for every job, so neither the producer nor
// the workers block on a slow peer.
jobs := make(chan job, len(toResolve))
results := make(chan struct {
host string
ips []string
stats dnsMetrics
}, len(toResolve))
for i := 0; i < workers; i++ {
go func() {
for j := range jobs {
ips, stats := resolveHostGo(j.host, cfg, metaSpecial, wildcards, dnsTimeout, logf)
results <- struct {
host string
ips []string
stats dnsMetrics
}{j.host, ips, stats}
}
}()
}
for _, h := range toResolve {
jobs <- job{host: h}
}
close(jobs)
// Results are consumed on this goroutine only, so cache and counter
// updates below need no extra locking.
for i := 0; i < len(toResolve); i++ {
r := <-results
dnsStats.merge(r.stats)
hostErrors := r.stats.totalErrors()
if hostErrors > 0 && logf != nil {
logf("resolve errors for %s: total=%d nxdomain=%d timeout=%d temporary=%d other=%d", r.host, hostErrors, r.stats.NXDomain, r.stats.Timeout, r.stats.Temporary, r.stats.Other)
}
if len(r.ips) > 0 {
resolved[r.host] = r.ips
resolvedNowDNS++
source := cacheSourceForHost(r.host)
domainCache.set(r.host, source, r.ips, now)
if logf != nil {
logf("%s -> %v", r.host, r.ips)
}
} else {
// No IPs: record the error and, when allowed, fall back to stale IPs.
staleApplied := false
if hostErrors > 0 {
source := cacheSourceForHost(r.host)
domainCache.setErrorWithStats(r.host, source, r.stats, now)
if staleKeepSec > 0 && shouldUseStaleOnError(r.stats) {
if staleIPs, staleAge, ok := domainCache.getStale(r.host, source, now, staleKeepSec); ok {
staleHits++
resolvedNowStale++
staleApplied = true
resolved[r.host] = staleIPs
if logf != nil {
logf("cache stale-keep (error)[age=%ds]: %s -> %v", staleAge, r.host, staleIPs)
}
}
}
}
if !staleApplied {
unresolvedAfterAttempts++
}
if logf != nil {
if _, ok := resolved[r.host]; !ok {
logf("%s: no IPs", r.host)
}
}
}
}
}
staticEntries, staticSkipped := parseStaticEntriesGo(staticLines, logf)
staticLabels, ptrLookups, ptrErrors := resolveStaticLabels(staticEntries, cfg, ptrCache, ttl, logf)
// IP sets and ip -> label-set maps for the "all", "direct" and "wildcard" views.
ipSetAll := map[string]struct{}{}
ipSetDirect := map[string]struct{}{}
ipSetWildcard := map[string]struct{}{}
ipMapAll := map[string]map[string]struct{}{}
ipMapDirect := map[string]map[string]struct{}{}
ipMapWildcard := map[string]map[string]struct{}{}
// add records ip in set and attaches label to it; empty IPs are ignored.
add := func(set map[string]struct{}, labels map[string]map[string]struct{}, ip, label string) {
if ip == "" {
return
}
set[ip] = struct{}{}
m := labels[ip]
if m == nil {
m = map[string]struct{}{}
labels[ip] = m
}
m[label] = struct{}{}
}
// isWildcardHost mirrors cacheSourceForHost as a boolean.
isWildcardHost := func(host string) bool {
switch cfg.Mode {
case DNSModeSmartDNS:
return true
case DNSModeHybridWildcard:
return wildcards.match(host)
default:
return false
}
}
for host, ips := range resolved {
wildcardHost := isWildcardHost(host)
for _, ip := range ips {
add(ipSetAll, ipMapAll, ip, host)
if wildcardHost {
add(ipSetWildcard, ipMapWildcard, ip, host)
} else {
add(ipSetDirect, ipMapDirect, ip, host)
}
}
}
for ipEntry, labels := range staticLabels {
for _, lbl := range labels {
add(ipSetAll, ipMapAll, ipEntry, lbl)
// Static entries are explicit operator rules; keep them in direct set.
add(ipSetDirect, ipMapDirect, ipEntry, lbl)
}
}
// appendMapPairs flattens an ip -> labels map into (ip, label) pairs sorted
// by IP, then by label.
appendMapPairs := func(dst *[][2]string, labelsByIP map[string]map[string]struct{}) {
for ip := range labelsByIP {
labels := labelsByIP[ip]
for lbl := range labels {
*dst = append(*dst, [2]string{ip, lbl})
}
}
sort.Slice(*dst, func(i, j int) bool {
if (*dst)[i][0] == (*dst)[j][0] {
return (*dst)[i][1] < (*dst)[j][1]
}
return (*dst)[i][0] < (*dst)[j][0]
})
}
// appendIPs appends the set members to dst and sorts dst lexicographically.
appendIPs := func(dst *[]string, set map[string]struct{}) {
for ip := range set {
*dst = append(*dst, ip)
}
sort.Strings(*dst)
}
appendMapPairs(&res.IPMap, ipMapAll)
appendMapPairs(&res.DirectIPMap, ipMapDirect)
appendMapPairs(&res.WildcardIPMap, ipMapWildcard)
appendIPs(&res.IPs, ipSetAll)
appendIPs(&res.DirectIPs, ipSetDirect)
appendIPs(&res.WildcardIPs, ipSetWildcard)
res.DomainCache = domainCache.toMap()
res.PtrCache = ptrCache
if logf != nil {
dnsErrors := dnsStats.totalErrors()
logf(
"resolve summary: domains=%d cache_hits=%d cache_neg_hits=%d quarantine_hits=%d stale_hits=%d resolved_now=%d unresolved=%d static_entries=%d static_skipped=%d unique_ips=%d direct_ips=%d wildcard_ips=%d ptr_lookups=%d ptr_errors=%d dns_attempts=%d dns_ok=%d dns_nxdomain=%d dns_timeout=%d dns_temporary=%d dns_other=%d dns_errors=%d timeout_recheck_checked=%d timeout_recheck_recovered=%d timeout_recheck_recovered_ips=%d timeout_recheck_still_timeout=%d timeout_recheck_now_nxdomain=%d timeout_recheck_now_temporary=%d timeout_recheck_now_other=%d timeout_recheck_no_signal=%d duration_ms=%d",
len(domains),
len(fresh),
cacheNegativeHits,
quarantineHits,
staleHits,
len(resolved)-len(fresh),
len(domains)-len(resolved),
len(staticEntries),
staticSkipped,
len(res.IPs),
len(res.DirectIPs),
len(res.WildcardIPs),
ptrLookups,
ptrErrors,
dnsStats.Attempts,
dnsStats.OK,
dnsStats.NXDomain,
dnsStats.Timeout,
dnsStats.Temporary,
dnsStats.Other,
dnsErrors,
timeoutRecheck.Checked,
timeoutRecheck.Recovered,
timeoutRecheck.RecoveredIPs,
timeoutRecheck.StillTimeout,
timeoutRecheck.NowNXDomain,
timeoutRecheck.NowTemporary,
timeoutRecheck.NowOther,
timeoutRecheck.NoSignal,
time.Since(start).Milliseconds(),
)
if perUpstream := dnsStats.formatPerUpstream(); perUpstream != "" {
logf("resolve dns upstreams: %s", perUpstream)
}
if health := dnsStats.formatResolverHealth(); health != "" {
logf("resolve dns health: %s", health)
}
if stateSummary := domainCache.formatStateSummary(now); stateSummary != "" {
logf("resolve domain states: %s", stateSummary)
}
logf(
"resolve breakdown: resolved_now_total=%d resolved_now_dns=%d resolved_now_stale=%d skipped_neg=%d skipped_quarantine=%d unresolved_after_attempts=%d",
len(resolved)-len(fresh),
resolvedNowDNS,
resolvedNowStale,
cacheNegativeHits,
quarantineHits,
unresolvedAfterAttempts,
)
if precheckDue {
logf("resolve precheck done: scheduled=%d state=%s", precheckScheduled, precheckStatePath)
}
}
if precheckDue {
saveResolverPrecheckState(precheckStatePath, now, timeoutRecheck)
}
// The force-file is a one-shot trigger: remove it once consumed. A removal
// error is deliberately ignored (best effort).
if precheckFileForced {
_ = os.Remove(precheckForcePath)
if logf != nil {
logf("resolve precheck force-file consumed: %s", precheckForcePath)
}
}
return res, nil
}
// ---------------------------------------------------------------------
// DNS resolve helpers
// ---------------------------------------------------------------------
// resolveHostGo resolves one host and returns its de-duplicated, non-private
// IPs together with the DNS attempt metrics.
//
// Upstream selection: hosts listed in metaSpecial use cfg.Meta, all others
// cfg.Default. When a SmartDNS address is configured it replaces the list in
// SmartDNS-only mode, and in hybrid mode for hosts matching wildcards.
// If a direct lookup yields no IPs, SmartDNS is tried once more when the
// fallback is enabled via env and the errors look transport-related.
// logf may be nil.
func resolveHostGo(host string, cfg dnsConfig, metaSpecial []string, wildcards wildcardMatcher, timeout time.Duration, logf func(string, ...any)) ([]string, dnsMetrics) {
	servers := cfg.Default
	for _, special := range metaSpecial {
		if special == host {
			servers = cfg.Meta
			break
		}
	}

	viaSmartDNS := false
	if cfg.SmartDNS != "" {
		switch cfg.Mode {
		case DNSModeSmartDNS:
			viaSmartDNS = true
		case DNSModeHybridWildcard:
			viaSmartDNS = wildcards.match(host)
		}
	}
	if viaSmartDNS {
		servers = []string{cfg.SmartDNS}
	}

	policy := directDNSAttemptPolicy(len(servers))
	if viaSmartDNS {
		policy = wildcardDNSAttemptPolicy(len(servers))
	}

	ips, stats := digAWithPolicy(host, servers, timeout, logf, policy)

	tryFallback := len(ips) == 0 &&
		!viaSmartDNS &&
		cfg.SmartDNS != "" &&
		smartDNSFallbackForTimeoutEnabled() &&
		shouldFallbackToSmartDNS(stats)
	if tryFallback {
		if logf != nil {
			logf(
				"dns fallback %s: trying smartdns=%s after errors nxdomain=%d timeout=%d temporary=%d other=%d",
				host,
				cfg.SmartDNS,
				stats.NXDomain,
				stats.Timeout,
				stats.Temporary,
				stats.Other,
			)
		}
		viaFallback, fallbackStats := digAWithPolicy(host, []string{cfg.SmartDNS}, timeout, logf, wildcardDNSAttemptPolicy(1))
		stats.merge(fallbackStats)
		if len(viaFallback) > 0 {
			ips = viaFallback
			if logf != nil {
				logf("dns fallback %s: resolved via smartdns (%d ips)", host, len(viaFallback))
			}
		}
	}

	// Drop private addresses and duplicates while keeping the answer order.
	unique := make([]string, 0, len(ips))
	known := make(map[string]struct{}, len(ips))
	for _, ip := range ips {
		if isPrivateIPv4(ip) {
			continue
		}
		if _, dup := known[ip]; dup {
			continue
		}
		known[ip] = struct{}{}
		unique = append(unique, ip)
	}
	return unique, stats
}
// smartDNSFallbackForTimeoutEnabled reports whether the direct->SmartDNS
// fallback is enabled. It is off by default to avoid overloading SmartDNS on
// large unresolved batches; set RESOLVE_SMARTDNS_TIMEOUT_FALLBACK to one of
// 1/true/yes/on (case-insensitive) to enable it. Any other value disables it.
func smartDNSFallbackForTimeoutEnabled() bool {
	raw := os.Getenv("RESOLVE_SMARTDNS_TIMEOUT_FALLBACK")
	switch strings.ToLower(strings.TrimSpace(raw)) {
	case "1", "true", "yes", "on":
		return true
	}
	return false
}
// shouldFallbackToSmartDNS reports whether a failed direct lookup is worth
// retrying via SmartDNS. Fallback is useful only for transport-like errors:
// once any upstream answered (OK or NXDOMAIN), SmartDNS is unlikely to change
// the result and would only add latency and noise.
func shouldFallbackToSmartDNS(stats dnsMetrics) bool {
	if stats.OK > 0 || stats.NXDomain > 0 {
		return false
	}
	return stats.Timeout > 0 || stats.Temporary > 0 || stats.Other > 0
}
// classifyHostErrorKind reduces the per-host metrics to a single error kind.
// Precedence is timeout, temporary, other, nxdomain. The boolean is false
// when no error was recorded at all.
func classifyHostErrorKind(stats dnsMetrics) (dnsErrorKind, bool) {
	switch {
	case stats.Timeout > 0:
		return dnsErrorTimeout, true
	case stats.Temporary > 0:
		return dnsErrorTemporary, true
	case stats.Other > 0:
		return dnsErrorOther, true
	case stats.NXDomain > 0:
		return dnsErrorNXDomain, true
	default:
		return "", false
	}
}
// shouldUseStaleOnError reports whether stale cached IPs may stand in for a
// failed lookup: no attempt succeeded and at least one error was a timeout,
// temporary or other error. NXDOMAIN alone never qualifies.
func shouldUseStaleOnError(stats dnsMetrics) bool {
	transportErrors := stats.Timeout > 0 || stats.Temporary > 0 || stats.Other > 0
	return stats.OK <= 0 && transportErrors
}
// runTimeoutQuarantineRecheck re-resolves domains that are currently
// quarantined with a timeout as their last error kind.
//
// At most limit distinct domains (lower-cased, trimmed) are picked in input
// order and resolved by up to 200 worker goroutines without logging.
// Recovered domains are written back to domainCache as successes; failures
// with at least one recorded error update the cached error state. The
// returned stats summarize the outcome. Nothing is done when limit or now is
// not positive.
func runTimeoutQuarantineRecheck(
	domains []string,
	cfg dnsConfig,
	metaSpecial []string,
	wildcards wildcardMatcher,
	timeout time.Duration,
	domainCache *domainCacheState,
	cacheSourceForHost func(string) domainCacheSource,
	now int,
	limit int,
	workers int,
) resolverTimeoutRecheckStats {
	var summary resolverTimeoutRecheckStats
	if limit <= 0 || now <= 0 {
		return summary
	}
	switch {
	case workers < 1:
		workers = 1
	case workers > 200:
		workers = 200
	}

	// Pick the candidates: unique hosts in timeout quarantine, up to limit.
	sizeHint := limit
	if len(domains) < sizeHint {
		sizeHint = len(domains)
	}
	visited := make(map[string]struct{})
	pending := make([]string, 0, sizeHint)
	for _, raw := range domains {
		name := strings.TrimSpace(strings.ToLower(raw))
		if name == "" {
			continue
		}
		if _, dup := visited[name]; dup {
			continue
		}
		visited[name] = struct{}{}
		src := cacheSourceForHost(name)
		if _, _, quarantined := domainCache.getQuarantine(name, src, now); !quarantined {
			continue
		}
		if kind, known := domainCache.getLastErrorKind(name, src); !known || kind != dnsErrorTimeout {
			continue
		}
		pending = append(pending, name)
		if len(pending) >= limit {
			break
		}
	}
	if len(pending) == 0 {
		return summary
	}

	type outcome struct {
		host   string
		source domainCacheSource
		ips    []string
		dns    dnsMetrics
	}
	// Both channels hold every candidate, so sends never block.
	queue := make(chan string, len(pending))
	done := make(chan outcome, len(pending))
	for w := 0; w < workers; w++ {
		go func() {
			for name := range queue {
				src := cacheSourceForHost(name)
				ips, metrics := resolveHostGo(name, cfg, metaSpecial, wildcards, timeout, nil)
				done <- outcome{host: name, source: src, ips: ips, dns: metrics}
			}
		}()
	}
	for _, name := range pending {
		queue <- name
	}
	close(queue)

	// Cache updates happen only on this goroutine.
	recovered := make(map[string]struct{})
	for range pending {
		out := <-done
		summary.Checked++
		if len(out.ips) > 0 {
			for _, ip := range out.ips {
				if ip = strings.TrimSpace(ip); ip != "" {
					recovered[ip] = struct{}{}
				}
			}
			domainCache.set(out.host, out.source, out.ips, now)
			summary.Recovered++
			continue
		}
		if out.dns.totalErrors() > 0 {
			domainCache.setErrorWithStats(out.host, out.source, out.dns, now)
		}
		kind, classified := classifyHostErrorKind(out.dns)
		switch {
		case !classified:
			summary.NoSignal++
		case kind == dnsErrorTimeout:
			summary.StillTimeout++
		case kind == dnsErrorNXDomain:
			summary.NowNXDomain++
		case kind == dnsErrorTemporary:
			summary.NowTemporary++
		default:
			summary.NowOther++
		}
	}
	summary.RecoveredIPs = len(recovered)
	return summary
}
// ---------------------------------------------------------------------
// `digA` resolves host through dnsList using the default attempt policy
// (see defaultDNSAttemptPolicy) and returns the IPs plus attempt metrics.
// ---------------------------------------------------------------------
func digA(host string, dnsList []string, timeout time.Duration, logf func(string, ...any)) ([]string, dnsMetrics) {
	policy := defaultDNSAttemptPolicy(len(dnsList))
	return digAWithPolicy(host, dnsList, timeout, logf, policy)
}
// digAWithPolicy resolves host against the upstreams in dnsList and returns
// the public IPs of the first upstream that answers, plus attempt metrics.
//
// Parameters:
//   - dnsList: upstream entries in "host" or "host#port" form (see splitDNS);
//     the port defaults to 53.
//   - timeout: per-attempt timeout, further limited by the remaining budget
//     but never below 100ms.
//   - policy: TryLimit caps the number of upstream attempts (at least 1, at
//     most len(dnsList)); DomainBudget caps the total time (at least 200ms,
//     defaulting to TryLimit*timeout); StopOnNX ends the loop on NXDOMAIN.
//   - logf: optional logger, may be nil.
//
// The starting upstream is chosen per host by pickDNSStartIndex and attempts
// walk the list round-robin from there. Answers that contain only private
// addresses are counted as "other" errors. It returns nil IPs when no upstream
// produced a public address.
func digAWithPolicy(host string, dnsList []string, timeout time.Duration, logf func(string, ...any), policy dnsAttemptPolicy) ([]string, dnsMetrics) {
	stats := dnsMetrics{}
	if len(dnsList) == 0 {
		return nil, stats
	}
	tryLimit := policy.TryLimit
	if tryLimit <= 0 {
		tryLimit = 1
	}
	if tryLimit > len(dnsList) {
		tryLimit = len(dnsList)
	}
	budget := policy.DomainBudget
	if budget <= 0 {
		budget = time.Duration(tryLimit) * timeout
	}
	if budget < 200*time.Millisecond {
		budget = 200 * time.Millisecond
	}
	deadline := time.Now().Add(budget)
	start := pickDNSStartIndex(host, len(dnsList))
	for attempt := 0; attempt < tryLimit; attempt++ {
		remaining := time.Until(deadline)
		if remaining <= 0 {
			if logf != nil {
				logf("dns budget exhausted %s: attempts=%d budget_ms=%d", host, attempt, budget.Milliseconds())
			}
			break
		}
		entry := dnsList[(start+attempt)%len(dnsList)]
		server, port := splitDNS(entry)
		if server == "" {
			// A blank entry still consumes one attempt slot.
			continue
		}
		if port == "" {
			port = "53"
		}
		addr := net.JoinHostPort(server, port)
		r := &net.Resolver{
			PreferGo: true,
			// Always dial the selected upstream instead of the system
			// resolver address, but honor the network the resolver asks
			// for: it starts with UDP and retries over TCP when the UDP
			// answer is truncated. Forcing "udp" here would make that
			// retry truncate again and fail for large answers.
			Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
				var d net.Dialer
				return d.DialContext(ctx, network, addr)
			},
		}
		perAttemptTimeout := timeout
		if remaining < perAttemptTimeout {
			perAttemptTimeout = remaining
		}
		if perAttemptTimeout < 100*time.Millisecond {
			perAttemptTimeout = 100 * time.Millisecond
		}
		ctx, cancel := context.WithTimeout(context.Background(), perAttemptTimeout)
		records, err := r.LookupHost(ctx, host)
		cancel()
		if err != nil {
			kind := classifyDNSError(err)
			stats.addError(addr, kind)
			if logf != nil {
				logf("dns warn %s via %s: kind=%s attempt=%d/%d err=%v", host, addr, kind, attempt+1, tryLimit, err)
			}
			if policy.StopOnNX && kind == dnsErrorNXDomain {
				if logf != nil {
					logf("dns early-stop %s: nxdomain via %s (attempt=%d/%d)", host, addr, attempt+1, tryLimit)
				}
				break
			}
			continue
		}
		var ips []string
		for _, ip := range records {
			if isPrivateIPv4(ip) {
				continue
			}
			ips = append(ips, ip)
		}
		if len(ips) == 0 {
			stats.addError(addr, dnsErrorOther)
			if logf != nil {
				logf("dns warn %s via %s: kind=other err=no_public_ips", host, addr)
			}
			continue
		}
		stats.addSuccess(addr)
		return uniqueStrings(ips), stats
	}
	return nil, stats
}
// defaultDNSAttemptPolicy builds the attempt policy used by digA.
// RESOLVE_DNS_TRY_LIMIT (default 2, at least 1) is capped by dnsCount when
// dnsCount is positive; RESOLVE_DNS_DOMAIN_BUDGET_MS (default 1200) is clamped
// to 200..15000 ms. NXDOMAIN early-stop follows resolveNXEarlyStopEnabled.
func defaultDNSAttemptPolicy(dnsCount int) dnsAttemptPolicy {
	attempts := envInt("RESOLVE_DNS_TRY_LIMIT", 2)
	if attempts < 1 {
		attempts = 1
	}
	if dnsCount > 0 && dnsCount < attempts {
		attempts = dnsCount
	}

	budgetMillis := envInt("RESOLVE_DNS_DOMAIN_BUDGET_MS", 1200)
	switch {
	case budgetMillis < 200:
		budgetMillis = 200
	case budgetMillis > 15000:
		budgetMillis = 15000
	}

	var policy dnsAttemptPolicy
	policy.TryLimit = attempts
	policy.DomainBudget = time.Duration(budgetMillis) * time.Millisecond
	policy.StopOnNX = resolveNXEarlyStopEnabled()
	return policy
}
// directDNSAttemptPolicy builds the attempt policy for direct upstreams.
// RESOLVE_DIRECT_TRY_LIMIT (default 2) is clamped to 1..3 and then capped by
// dnsCount when dnsCount is positive; RESOLVE_DIRECT_BUDGET_MS (default 1200)
// is clamped to 200..15000 ms. NXDOMAIN early-stop follows
// resolveNXEarlyStopEnabled.
func directDNSAttemptPolicy(dnsCount int) dnsAttemptPolicy {
	attempts := envInt("RESOLVE_DIRECT_TRY_LIMIT", 2)
	switch {
	case attempts < 1:
		attempts = 1
	case attempts > 3:
		attempts = 3
	}
	if dnsCount > 0 && dnsCount < attempts {
		attempts = dnsCount
	}

	budgetMillis := envInt("RESOLVE_DIRECT_BUDGET_MS", 1200)
	switch {
	case budgetMillis < 200:
		budgetMillis = 200
	case budgetMillis > 15000:
		budgetMillis = 15000
	}

	var policy dnsAttemptPolicy
	policy.TryLimit = attempts
	policy.DomainBudget = time.Duration(budgetMillis) * time.Millisecond
	policy.StopOnNX = resolveNXEarlyStopEnabled()
	return policy
}
// wildcardDNSAttemptPolicy builds the attempt policy for SmartDNS/wildcard
// lookups. RESOLVE_WILDCARD_TRY_LIMIT (default 1) is clamped to 1..2 and then
// capped by dnsCount when dnsCount is positive; RESOLVE_WILDCARD_BUDGET_MS
// (default 1200) is clamped to 200..15000 ms. NXDOMAIN early-stop follows
// resolveNXEarlyStopEnabled.
func wildcardDNSAttemptPolicy(dnsCount int) dnsAttemptPolicy {
	attempts := envInt("RESOLVE_WILDCARD_TRY_LIMIT", 1)
	switch {
	case attempts < 1:
		attempts = 1
	case attempts > 2:
		attempts = 2
	}
	if dnsCount > 0 && dnsCount < attempts {
		attempts = dnsCount
	}

	budgetMillis := envInt("RESOLVE_WILDCARD_BUDGET_MS", 1200)
	switch {
	case budgetMillis < 200:
		budgetMillis = 200
	case budgetMillis > 15000:
		budgetMillis = 15000
	}

	var policy dnsAttemptPolicy
	policy.TryLimit = attempts
	policy.DomainBudget = time.Duration(budgetMillis) * time.Millisecond
	policy.StopOnNX = resolveNXEarlyStopEnabled()
	return policy
}
// resolveNXEarlyStopEnabled reports whether lookups stop at the first
// NXDOMAIN answer. It is enabled unless RESOLVE_NX_EARLY_STOP is set to one of
// 0/false/no/off (case-insensitive, surrounding whitespace ignored).
func resolveNXEarlyStopEnabled() bool {
	raw := strings.TrimSpace(os.Getenv("RESOLVE_NX_EARLY_STOP"))
	switch strings.ToLower(raw) {
	case "0", "false", "no", "off":
		return false
	}
	return true
}
// resolvePrecheckForceEnvEnabled reports whether a precheck is forced through
// the environment: RESOLVE_PRECHECK_FORCE set to one of 1/true/yes/on
// (case-insensitive, surrounding whitespace ignored).
func resolvePrecheckForceEnvEnabled() bool {
	raw := strings.TrimSpace(os.Getenv("RESOLVE_PRECHECK_FORCE"))
	switch strings.ToLower(raw) {
	case "1", "true", "yes", "on":
		return true
	}
	return false
}
// resolvePrecheckForceFileEnabled reports whether the force-file at path
// exists. A blank path or any os.Stat error yields false.
func resolvePrecheckForceFileEnabled(path string) bool {
	if strings.TrimSpace(path) == "" {
		return false
	}
	if _, err := os.Stat(path); err != nil {
		return false
	}
	return true
}
// classifyDNSError maps a lookup error to a dnsErrorKind.
// A *net.DNSError is classified by its IsNotFound, IsTimeout and IsTemporary
// flags (in that order). If none is set, or the error is of another type, the
// lower-cased error text is searched for known substrings. A nil error and
// anything unrecognized are reported as dnsErrorOther.
func classifyDNSError(err error) dnsErrorKind {
	if err == nil {
		return dnsErrorOther
	}

	var dnsErr *net.DNSError
	if errors.As(err, &dnsErr) {
		switch {
		case dnsErr.IsNotFound:
			return dnsErrorNXDomain
		case dnsErr.IsTimeout:
			return dnsErrorTimeout
		case dnsErr.IsTemporary:
			return dnsErrorTemporary
		}
	}

	text := strings.ToLower(err.Error())
	mentions := func(needle string) bool { return strings.Contains(text, needle) }
	if mentions("no such host") || mentions("nxdomain") {
		return dnsErrorNXDomain
	}
	if mentions("i/o timeout") || mentions("timeout") {
		return dnsErrorTimeout
	}
	if mentions("temporary") {
		return dnsErrorTemporary
	}
	return dnsErrorOther
}
// ---------------------------------------------------------------------
// `splitDNS` splits a DNS upstream entry into host and port.
// With a "#" separator ("host#port") both parts are trimmed, an empty host
// becomes "127.0.0.1" and an empty port becomes "53". Without "#" the trimmed
// entry is returned as host and the port is "".
// ---------------------------------------------------------------------
func splitDNS(dns string) (string, string) {
	rawHost, rawPort, hasPort := strings.Cut(dns, "#")
	if !hasPort {
		return strings.TrimSpace(dns), ""
	}
	host, port := strings.TrimSpace(rawHost), strings.TrimSpace(rawPort)
	if host == "" {
		host = "127.0.0.1"
	}
	if port == "" {
		port = "53"
	}
	return host, port
}
// ---------------------------------------------------------------------
// static entries + PTR labels
// ---------------------------------------------------------------------
// parseStaticEntriesGo parses static IP/prefix lines.
//
// Blank lines and lines starting with "#" are ignored. Text after the first
// "#" is kept as the entry comment. Entries rejected by isPrivateIPv4 are
// dropped silently; entries that fail netip parsing are logged (when logf is
// set) and counted in skipped. Each returned entry is
// {entry as written, address part before "/", comment}.
func parseStaticEntriesGo(lines []string, logf func(string, ...any)) (entries [][3]string, skipped int) {
	for _, line := range lines {
		entry := strings.TrimSpace(line)
		if entry == "" || strings.HasPrefix(entry, "#") {
			continue
		}

		comment := ""
		if hash := strings.Index(entry, "#"); hash >= 0 {
			comment = strings.TrimSpace(entry[hash+1:])
			entry = strings.TrimSpace(entry[:hash])
		}
		if entry == "" || isPrivateIPv4(entry) {
			continue
		}

		// Validate as a prefix when a "/" is present, as a plain address otherwise.
		base, _, isPrefix := strings.Cut(entry, "/")
		if isPrefix {
			if _, err := netip.ParsePrefix(entry); err != nil {
				skipped++
				if logf != nil {
					logf("static skip invalid prefix %q: %v", entry, err)
				}
				continue
			}
		} else if _, err := netip.ParseAddr(base); err != nil {
			skipped++
			if logf != nil {
				logf("static skip invalid ip %q: %v", entry, err)
			}
			continue
		}

		entries = append(entries, [3]string{entry, base, comment})
	}
	return entries, skipped
}
// ---------------------------------------------------------------------
// EN: `resolveStaticLabels` resolves static labels into concrete values.
// RU: `resolveStaticLabels` - резолвит static labels в конкретные значения.
// ---------------------------------------------------------------------
// resolveStaticLabels builds the label list for every static entry.
//
// Each entry is {entry, base IP, comment} as produced by
// parseStaticEntriesGo. Labels are chosen in this order:
//  1. the inline comment, when present;
//  2. PTR names from ptrCache that are not older than ttl seconds;
//  3. a fresh PTR lookup against the first upstream in cfg.Default
//     (defaultDNS1 when that list is empty), whose result is stored back
//     into ptrCache;
//  4. the placeholder "*[STATIC-IP]".
//
// Every label is prefixed with "*". The function returns the entry -> labels
// map, the number of PTR lookups performed and the number of lookups that
// returned an error. ptrCache may be nil, in which case nothing is cached.
func resolveStaticLabels(entries [][3]string, cfg dnsConfig, ptrCache map[string]any, ttl int, logf func(string, ...any)) (map[string][]string, int, int) {
	now := int(time.Now().Unix())
	result := make(map[string][]string, len(entries))
	ptrLookups, ptrErrors := 0, 0

	dnsForPtr := defaultDNS1
	if len(cfg.Default) > 0 {
		dnsForPtr = cfg.Default[0]
	}

	for _, e := range entries {
		ipEntry, baseIP, comment := e[0], e[1], e[2]

		var labels []string
		if comment != "" {
			labels = append(labels, "*"+comment)
		} else {
			// Reading a missing key (or a nil map) yields nil, which the
			// helper treats as "no cached names".
			for _, name := range ptrCacheFreshNames(ptrCache[baseIP], now, ttl) {
				labels = append(labels, "*"+name)
			}
			if len(labels) == 0 {
				ptrLookups++
				names, err := digPTR(baseIP, dnsForPtr, 3*time.Second, logf)
				if err != nil {
					ptrErrors++
				}
				if len(names) > 0 {
					// Writing to a nil map would panic, so only cache when
					// the caller supplied one.
					if ptrCache != nil {
						ptrCache[baseIP] = map[string]any{"names": names, "last_resolved": now}
					}
					for _, n := range names {
						labels = append(labels, "*"+n)
					}
				}
			}
		}

		if len(labels) == 0 {
			labels = []string{"*[STATIC-IP]"}
		}
		result[ipEntry] = labels
		if logf != nil {
			logf("static %s -> %v", ipEntry, labels)
		}
	}
	return result, ptrLookups, ptrErrors
}

// ptrCacheFreshNames extracts the non-empty PTR names from one ptrCache value
// when its "last_resolved" timestamp is positive and at most ttl seconds
// before now; otherwise it returns nil.
//
// It accepts both shapes the cache can hold: values decoded from JSON
// ([]any names, float64 timestamp) and values written by
// resolveStaticLabels during the current run ([]string names, int
// timestamp). Handling only the JSON shape would make entries cached in this
// run invisible to later entries sharing the same base IP.
func ptrCacheFreshNames(raw any, now, ttl int) []string {
	cached, ok := raw.(map[string]any)
	if !ok {
		return nil
	}

	var last int
	switch v := cached["last_resolved"].(type) {
	case int:
		last = v
	case int64:
		last = int(v)
	case float64:
		last = int(v)
	}
	if last <= 0 || now-last > ttl {
		return nil
	}

	var names []string
	switch v := cached["names"].(type) {
	case []string:
		for _, s := range v {
			if s != "" {
				names = append(names, s)
			}
		}
	case []any:
		for _, x := range v {
			if s, isString := x.(string); isString && s != "" {
				names = append(names, s)
			}
		}
	}
	return names
}
// ---------------------------------------------------------------------
// DNS config + cache helpers
// ---------------------------------------------------------------------
// domainCacheSource selects which bucket of a cached domain record is used.
type domainCacheSource string

// Buckets kept per domain. The accessors in this file treat every value
// other than the wildcard one as the direct bucket.
const (
	domainCacheSourceDirect   domainCacheSource = "direct"
	domainCacheSourceWildcard domainCacheSource = "wildcard"
)

type (
	// domainCacheEntry is the persisted state of one bucket of a domain:
	// the last successful resolution, the last failure, and the
	// score/state/quarantine bookkeeping derived from them. Timestamps are
	// plain ints compared against the caller-supplied "now" (presumably
	// Unix seconds; confirm against callers).
	domainCacheEntry struct {
		IPs             []string `json:"ips,omitempty"`
		LastResolved    int      `json:"last_resolved,omitempty"`
		LastErrorKind   string   `json:"last_error_kind,omitempty"`
		LastErrorAt     int      `json:"last_error_at,omitempty"`
		Score           int      `json:"score,omitempty"`
		State           string   `json:"state,omitempty"`
		QuarantineUntil int      `json:"quarantine_until,omitempty"`
	}

	// domainCacheRecord groups the two buckets stored for a single domain.
	// Either pointer may be nil when that bucket has no data.
	domainCacheRecord struct {
		Direct   *domainCacheEntry `json:"direct,omitempty"`
		Wildcard *domainCacheEntry `json:"wildcard,omitempty"`
	}

	// domainCacheState is the top-level cache document: a format version
	// plus the per-domain records keyed by host name.
	domainCacheState struct {
		Version int                          `json:"version"`
		Domains map[string]domainCacheRecord `json:"domains"`
	}
)
// newDomainCacheState returns an empty cache document at format version 4
// with a non-nil Domains map ready for writes.
func newDomainCacheState() domainCacheState {
	var st domainCacheState
	st.Version = 4
	st.Domains = make(map[string]domainCacheRecord)
	return st
}
// normalizeCacheIPs trims every value, drops empty strings and anything
// rejected by isPrivateIPv4, removes duplicates and returns the rest sorted
// lexicographically. The result is never nil.
func normalizeCacheIPs(raw []string) []string {
	unique := make([]string, 0, len(raw))
	known := make(map[string]struct{}, len(raw))
	for _, candidate := range raw {
		addr := strings.TrimSpace(candidate)
		if addr == "" || isPrivateIPv4(addr) {
			continue
		}
		if _, dup := known[addr]; !dup {
			known[addr] = struct{}{}
			unique = append(unique, addr)
		}
	}
	sort.Strings(unique)
	return unique
}
// normalizeCacheErrorKind maps a stored error-kind string onto one of the
// known dnsErrorKind values, ignoring case and surrounding whitespace. The
// boolean is false (and the kind empty) for anything unrecognised.
func normalizeCacheErrorKind(raw string) (dnsErrorKind, bool) {
	needle := strings.ToLower(strings.TrimSpace(raw))
	known := [...]dnsErrorKind{
		dnsErrorNXDomain,
		dnsErrorTimeout,
		dnsErrorTemporary,
		dnsErrorOther,
	}
	for _, kind := range known {
		if needle == string(kind) {
			return kind, true
		}
	}
	return "", false
}
// normalizeDomainCacheEntry returns a sanitised copy of in, or nil when in is
// nil or carries nothing worth keeping.
//
// IPs and LastResolved are kept only together (non-empty normalized IPs and a
// positive timestamp); the same holds for LastErrorKind and LastErrorAt. The
// score is clamped and the state re-derived via normalizeDomainState. An
// entry with no resolution, no error, a zero score and no quarantine
// deadline is dropped.
func normalizeDomainCacheEntry(in *domainCacheEntry) *domainCacheEntry {
	if in == nil {
		return nil
	}

	var norm domainCacheEntry
	if addrs := normalizeCacheIPs(in.IPs); len(addrs) > 0 && in.LastResolved > 0 {
		norm.IPs = addrs
		norm.LastResolved = in.LastResolved
	}
	if in.LastErrorAt > 0 {
		if kind, ok := normalizeCacheErrorKind(in.LastErrorKind); ok {
			norm.LastErrorKind = string(kind)
			norm.LastErrorAt = in.LastErrorAt
		}
	}
	norm.Score = clampDomainScore(in.Score)
	norm.State = normalizeDomainState(in.State, norm.Score)
	if in.QuarantineUntil > 0 {
		norm.QuarantineUntil = in.QuarantineUntil
	}

	hasData := norm.LastResolved > 0 ||
		norm.LastErrorAt > 0 ||
		norm.Score != 0 ||
		norm.QuarantineUntil > 0
	if !hasData {
		return nil
	}
	return &norm
}
// parseAnyStringSlice converts a loosely typed value into a string slice.
//
// A []string is copied (an empty one yields nil); a []any yields its string
// elements in order, silently skipping non-string items, and is never nil.
// Any other type yields nil.
func parseAnyStringSlice(raw any) []string {
	if typed, ok := raw.([]string); ok {
		if len(typed) == 0 {
			return nil
		}
		dup := make([]string, len(typed))
		copy(dup, typed)
		return dup
	}

	generic, ok := raw.([]any)
	if !ok {
		return nil
	}
	collected := make([]string, 0, len(generic))
	for i := range generic {
		if text, isString := generic[i].(string); isString {
			collected = append(collected, text)
		}
	}
	return collected
}
// parseAnyInt converts a loosely typed numeric value into an int.
//
// int, int64 and float64 (truncated toward zero) are accepted directly; a
// json.Number is accepted when it parses as a 64-bit integer. Everything
// else, including a json.Number that is not an integer, yields (0, false).
func parseAnyInt(raw any) (int, bool) {
	switch num := raw.(type) {
	case json.Number:
		if parsed, err := num.Int64(); err == nil {
			return int(parsed), true
		}
	case float64:
		return int(num), true
	case int64:
		return int(num), true
	case int:
		return num, true
	}
	return 0, false
}
// parseLegacyDomainCacheEntry reads one value of the legacy cache layout,
// an object with "ips" and "last_resolved" keys. It succeeds only when the
// value is an object, at least one IP survives normalization and the
// timestamp is a positive integer.
func parseLegacyDomainCacheEntry(raw any) (domainCacheEntry, bool) {
	var none domainCacheEntry

	fields, isMap := raw.(map[string]any)
	if !isMap {
		return none, false
	}
	addrs := normalizeCacheIPs(parseAnyStringSlice(fields["ips"]))
	resolvedAt, hasTS := parseAnyInt(fields["last_resolved"])
	if len(addrs) == 0 || !hasTS || resolvedAt <= 0 {
		return none, false
	}
	return domainCacheEntry{IPs: addrs, LastResolved: resolvedAt}, true
}
// loadDomainCacheState reads the domain cache from path and returns it in
// normalized form. It never fails: an unreadable, empty or invalid file
// yields a fresh empty state.
//
// Two layouts are understood. The current one has a top-level "domains"
// object; its records are re-keyed by lower-cased, trimmed host name and
// each bucket is passed through normalizeDomainCacheEntry. Otherwise the
// file is treated as the legacy layout, a flat object mapping host names to
// {"ips", "last_resolved"}, and every valid entry is migrated into the
// direct bucket. logf (may be nil) receives a note about invalid JSON and
// about the number of migrated legacy entries.
func loadDomainCacheState(path string, logf func(string, ...any)) domainCacheState {
	data, err := os.ReadFile(path)
	if err != nil || len(data) == 0 {
		return newDomainCacheState()
	}

	// Current layout.
	var current domainCacheState
	if json.Unmarshal(data, &current) == nil && current.Domains != nil {
		fresh := newDomainCacheState()
		for name, rec := range current.Domains {
			key := strings.TrimSpace(strings.ToLower(name))
			if key == "" {
				continue
			}
			direct := normalizeDomainCacheEntry(rec.Direct)
			wildcard := normalizeDomainCacheEntry(rec.Wildcard)
			if direct == nil && wildcard == nil {
				continue
			}
			fresh.Domains[key] = domainCacheRecord{Direct: direct, Wildcard: wildcard}
		}
		return fresh
	}

	// Legacy shape: { "domain.tld": {"ips":[...], "last_resolved":...}, ... }
	var legacy map[string]any
	if err := json.Unmarshal(data, &legacy); err != nil {
		if logf != nil {
			logf("domain-cache: invalid json at %s, ignore", path)
		}
		return newDomainCacheState()
	}

	converted := newDomainCacheState()
	migrated := 0
	for name, raw := range legacy {
		key := strings.TrimSpace(strings.ToLower(name))
		switch key {
		case "", "version", "domains":
			// Not a host name: empty key or a field of the current layout.
			continue
		}
		entry, ok := parseLegacyDomainCacheEntry(raw)
		if !ok {
			continue
		}
		converted.Domains[key] = domainCacheRecord{Direct: &entry}
		migrated++
	}
	if migrated > 0 && logf != nil {
		logf("domain-cache: migrated legacy entries=%d into split cache (direct bucket)", migrated)
	}
	return converted
}
// get returns the cached IPs of domain for the given bucket when the last
// successful resolution is at most ttl older than now. The boolean is false
// when the domain or bucket is unknown, the entry has expired, or no IP
// survives normalization.
func (s domainCacheState) get(domain string, source domainCacheSource, now, ttl int) ([]string, bool) {
	key := strings.TrimSpace(strings.ToLower(domain))
	rec, found := s.Domains[key]
	if !found {
		return nil, false
	}

	entry := rec.Direct
	if source == domainCacheSourceWildcard {
		entry = rec.Wildcard
	}
	if entry == nil || entry.LastResolved <= 0 || now-entry.LastResolved > ttl {
		return nil, false
	}

	if ips := normalizeCacheIPs(entry.IPs); len(ips) > 0 {
		return ips, true
	}
	return nil, false
}
// getNegative reports a still-valid cached failure for domain in the given
// bucket. It returns the error kind and its age (now minus the error time)
// when the stored kind is recognised, the age is not negative, and the age
// does not exceed the TTL configured for that kind. A TTL of zero or less
// disables negative caching for the kind.
func (s domainCacheState) getNegative(domain string, source domainCacheSource, now, nxTTL, timeoutTTL, temporaryTTL, otherTTL int) (dnsErrorKind, int, bool) {
	key := strings.TrimSpace(strings.ToLower(domain))
	rec, found := s.Domains[key]
	if !found {
		return "", 0, false
	}

	entry := rec.Direct
	if source == domainCacheSourceWildcard {
		entry = rec.Wildcard
	}
	if entry == nil || entry.LastErrorAt <= 0 {
		return "", 0, false
	}

	kind, known := normalizeCacheErrorKind(entry.LastErrorKind)
	if !known {
		return "", 0, false
	}
	age := now - entry.LastErrorAt
	if age < 0 {
		return "", 0, false
	}

	var limit int
	switch kind {
	case dnsErrorNXDomain:
		limit = nxTTL
	case dnsErrorTimeout:
		limit = timeoutTTL
	case dnsErrorTemporary:
		limit = temporaryTTL
	case dnsErrorOther:
		limit = otherTTL
	}
	if limit > 0 && age <= limit {
		return kind, age, true
	}
	return "", 0, false
}
// getStoredIPs returns the normalized IPs stored for domain in the given
// bucket regardless of their age, or nil when the domain or bucket is
// absent.
func (s domainCacheState) getStoredIPs(domain string, source domainCacheSource) []string {
	key := strings.TrimSpace(strings.ToLower(domain))
	if rec, found := s.Domains[key]; found {
		if entry := getCacheEntryBySource(rec, source); entry != nil {
			return normalizeCacheIPs(entry.IPs)
		}
	}
	return nil
}
// getLastErrorKind returns the normalized kind of the last recorded failure
// for domain in the given bucket. The boolean is false when there is no
// entry, no recorded error time, or the stored kind is not recognised.
func (s domainCacheState) getLastErrorKind(domain string, source domainCacheSource) (dnsErrorKind, bool) {
	key := strings.TrimSpace(strings.ToLower(domain))
	rec, found := s.Domains[key]
	if !found {
		return "", false
	}
	if entry := getCacheEntryBySource(rec, source); entry != nil && entry.LastErrorAt > 0 {
		return normalizeCacheErrorKind(entry.LastErrorKind)
	}
	return "", false
}
// getQuarantine reports whether domain is quarantined in the given bucket at
// time now, i.e. the entry has a quarantine deadline that is still in the
// future. On success it returns the entry state (falling back to
// domainStateQuarantine when none can be derived) and the time elapsed since
// the last recorded error, or 0 when no error time is stored.
func (s domainCacheState) getQuarantine(domain string, source domainCacheSource, now int) (string, int, bool) {
	key := strings.TrimSpace(strings.ToLower(domain))
	rec, found := s.Domains[key]
	if !found {
		return "", 0, false
	}
	entry := getCacheEntryBySource(rec, source)
	if entry == nil || entry.QuarantineUntil <= 0 || now >= entry.QuarantineUntil {
		return "", 0, false
	}

	state := normalizeDomainState(entry.State, entry.Score)
	if state == "" {
		state = domainStateQuarantine
	}
	var sinceError int
	if entry.LastErrorAt > 0 {
		sinceError = now - entry.LastErrorAt
	}
	return state, sinceError, true
}
// getStale returns previously resolved IPs of domain that may be older than
// the regular TTL but are no older than maxAge. It also returns their age.
// The boolean is false when maxAge is not positive, the entry is missing,
// the age is negative or exceeds maxAge, or no IP survives normalization.
func (s domainCacheState) getStale(domain string, source domainCacheSource, now, maxAge int) ([]string, int, bool) {
	if maxAge <= 0 {
		return nil, 0, false
	}
	key := strings.TrimSpace(strings.ToLower(domain))
	rec, found := s.Domains[key]
	if !found {
		return nil, 0, false
	}
	entry := getCacheEntryBySource(rec, source)
	if entry == nil || entry.LastResolved <= 0 {
		return nil, 0, false
	}

	age := now - entry.LastResolved
	withinWindow := age >= 0 && age <= maxAge
	if !withinWindow {
		return nil, 0, false
	}
	if ips := normalizeCacheIPs(entry.IPs); len(ips) > 0 {
		return ips, age, true
	}
	return nil, 0, false
}
// set records a successful resolution of domain in the given bucket.
//
// The call is ignored when the host name is empty, now is not positive, or
// no IP survives normalization. Otherwise the bucket is replaced by a fresh
// entry: error and quarantine fields are cleared, and the score becomes the
// previous score plus the RESOLVE_DOMAIN_SCORE_OK bonus (default 8),
// clamped, with the state derived from that score.
func (s *domainCacheState) set(domain string, source domainCacheSource, ips []string, now int) {
	key := strings.TrimSpace(strings.ToLower(domain))
	if key == "" || now <= 0 {
		return
	}
	addrs := normalizeCacheIPs(ips)
	if len(addrs) == 0 {
		return
	}
	if s.Domains == nil {
		s.Domains = make(map[string]domainCacheRecord)
	}

	rec := s.Domains[key]
	carried := 0
	if prev := getCacheEntryBySource(rec, source); prev != nil {
		carried = prev.Score
	}
	score := clampDomainScore(carried + envInt("RESOLVE_DOMAIN_SCORE_OK", 8))

	fresh := &domainCacheEntry{
		IPs:          addrs,
		LastResolved: now,
		Score:        score,
		State:        domainStateFromScore(score),
	}
	if source == domainCacheSourceWildcard {
		rec.Wildcard = fresh
	} else {
		rec.Direct = fresh
	}
	s.Domains[key] = rec
}
// getCacheEntryBySource picks the bucket of rec addressed by source: the
// wildcard bucket for domainCacheSourceWildcard, the direct bucket for any
// other value. The result may be nil.
func getCacheEntryBySource(rec domainCacheRecord, source domainCacheSource) *domainCacheEntry {
	if source == domainCacheSourceWildcard {
		return rec.Wildcard
	}
	return rec.Direct
}
// clampDomainScore limits v to the inclusive range
// [domainScoreMin, domainScoreMax].
func clampDomainScore(v int) int {
	switch {
	case v < domainScoreMin:
		return domainScoreMin
	case v > domainScoreMax:
		return domainScoreMax
	default:
		return v
	}
}
// domainStateFromScore maps a score onto a state name using fixed
// thresholds: >= 20 active, >= 5 stable, >= -10 suspect, >= -30 quarantine,
// anything lower hard quarantine.
func domainStateFromScore(score int) string {
	if score >= 20 {
		return domainStateActive
	}
	if score >= 5 {
		return domainStateStable
	}
	if score >= -10 {
		return domainStateSuspect
	}
	if score >= -30 {
		return domainStateQuarantine
	}
	return domainStateHardQuar
}
// normalizeDomainState returns the canonical form of a stored state name,
// ignoring case and surrounding whitespace. When raw is not a known state
// the result is derived from score, except that a zero score yields the
// empty string (no state).
func normalizeDomainState(raw string, score int) string {
	cleaned := strings.TrimSpace(strings.ToLower(raw))
	known := [...]string{
		domainStateActive,
		domainStateStable,
		domainStateSuspect,
		domainStateQuarantine,
		domainStateHardQuar,
	}
	for _, state := range known {
		if cleaned == state {
			return state
		}
	}
	if score == 0 {
		return ""
	}
	return domainStateFromScore(score)
}
// domainScorePenalty picks the score adjustment for a failed resolution from
// the per-attempt counters in stats. The first matching rule wins: two or
// more NXDOMAIN answers, a single NXDOMAIN, any timeout, any temporary
// error, then everything else. Each value is read through envInt with the
// default shown in the code.
func domainScorePenalty(stats dnsMetrics) int {
	switch {
	case stats.NXDomain >= 2:
		return envInt("RESOLVE_DOMAIN_SCORE_NX_CONFIRMED", -15)
	case stats.NXDomain > 0:
		return envInt("RESOLVE_DOMAIN_SCORE_NX_SINGLE", -7)
	case stats.Timeout > 0:
		return envInt("RESOLVE_DOMAIN_SCORE_TIMEOUT", -3)
	case stats.Temporary > 0:
		return envInt("RESOLVE_DOMAIN_SCORE_TEMPORARY", -2)
	default:
		return envInt("RESOLVE_DOMAIN_SCORE_OTHER", -2)
	}
}
// setErrorWithStats records a failed resolution of domain in the given
// bucket and updates its score, state and quarantine deadline.
//
// The call is ignored when the host name is empty, now is not positive, or
// stats cannot be classified into a known error kind. The existing entry
// (or a new empty one) keeps its resolved IPs; its score is lowered by
// domainScorePenalty and clamped, and the state is derived from the score.
// A quarantine or hard-quarantine state sets QuarantineUntil to now plus the
// matching TTL (RESOLVE_QUARANTINE_TTL_SEC / RESOLVE_HARD_QUARANTINE_TTL_SEC,
// negative values treated as 0); any other state clears it.
func (s *domainCacheState) setErrorWithStats(domain string, source domainCacheSource, stats dnsMetrics, now int) {
	key := strings.TrimSpace(strings.ToLower(domain))
	if key == "" || now <= 0 {
		return
	}
	classified, ok := classifyHostErrorKind(stats)
	if !ok {
		return
	}
	kind, ok := normalizeCacheErrorKind(string(classified))
	if !ok {
		return
	}

	penalty := domainScorePenalty(stats)
	softTTL := envInt("RESOLVE_QUARANTINE_TTL_SEC", defaultQuarantineTTL)
	if softTTL < 0 {
		softTTL = 0
	}
	hardTTL := envInt("RESOLVE_HARD_QUARANTINE_TTL_SEC", defaultHardQuarantineTT)
	if hardTTL < 0 {
		hardTTL = 0
	}

	if s.Domains == nil {
		s.Domains = make(map[string]domainCacheRecord)
	}
	rec := s.Domains[key]
	entry := getCacheEntryBySource(rec, source)
	if entry == nil {
		entry = new(domainCacheEntry)
	}
	// Must be read before LastErrorKind is overwritten below.
	previous, _ := normalizeCacheErrorKind(entry.LastErrorKind)

	score := clampDomainScore(entry.Score + penalty)
	state := domainStateFromScore(score)
	// Timeout-only failures are treated as transient transport noise by default.
	// Keep them in suspect bucket (no quarantine) unless we have NX signal.
	if kind == dnsErrorTimeout && previous != dnsErrorNXDomain {
		if score < -10 {
			score = -10
		}
		state = domainStateSuspect
	}

	var until int
	switch state {
	case domainStateHardQuar:
		until = now + hardTTL
	case domainStateQuarantine:
		until = now + softTTL
	}

	entry.Score, entry.State = score, state
	entry.LastErrorKind, entry.LastErrorAt = string(kind), now
	entry.QuarantineUntil = until

	if source == domainCacheSourceWildcard {
		rec.Wildcard = entry
	} else {
		rec.Direct = entry
	}
	s.Domains[key] = rec
}
// toMap converts the state into generic maps suitable for JSON encoding:
// {"version": 4, "domains": {host: {"direct": {...}, "wildcard": {...}}}}.
//
// Only meaningful fields are emitted: IPs together with a positive
// LastResolved, a recognised error kind together with a positive
// LastErrorAt, a non-zero score, a non-empty normalized state and a positive
// quarantine deadline. Buckets and hosts that end up empty are omitted.
func (s domainCacheState) toMap() map[string]any {
	// encode renders one bucket; a nil entry yields an empty map.
	encode := func(e *domainCacheEntry) map[string]any {
		fields := map[string]any{}
		if e == nil {
			return fields
		}
		if len(e.IPs) > 0 && e.LastResolved > 0 {
			fields["ips"] = e.IPs
			fields["last_resolved"] = e.LastResolved
		}
		if kind, ok := normalizeCacheErrorKind(e.LastErrorKind); ok && e.LastErrorAt > 0 {
			fields["last_error_kind"] = string(kind)
			fields["last_error_at"] = e.LastErrorAt
		}
		if e.Score != 0 {
			fields["score"] = e.Score
		}
		if st := normalizeDomainState(e.State, e.Score); st != "" {
			fields["state"] = st
		}
		if e.QuarantineUntil > 0 {
			fields["quarantine_until"] = e.QuarantineUntil
		}
		return fields
	}

	names := make([]string, 0, len(s.Domains))
	for name := range s.Domains {
		names = append(names, name)
	}
	sort.Strings(names)

	domains := map[string]any{}
	for _, name := range names {
		rec := s.Domains[name]
		item := map[string]any{}
		if direct := encode(rec.Direct); len(direct) > 0 {
			item["direct"] = direct
		}
		if wildcard := encode(rec.Wildcard); len(wildcard) > 0 {
			item["wildcard"] = wildcard
		}
		if len(item) > 0 {
			domains[name] = item
		}
	}

	return map[string]any{
		"version": 4,
		"domains": domains,
	}
}
// formatStateSummary counts the buckets of all domains per state and renders
// the counters as one log-friendly line. Entries whose quarantine deadline
// is still after now are counted as quarantine, unless their state is hard
// quarantine. Entries without a derivable state are not counted. The result
// is empty when nothing was counted.
func (s domainCacheState) formatStateSummary(now int) string {
	var active, stable, suspect, quarantine, hardQuar int

	tally := func(entry *domainCacheEntry) {
		if entry == nil {
			return
		}
		bucket := normalizeDomainState(entry.State, entry.Score)
		// Keep hard quarantine state if explicitly marked,
		// otherwise an active deadline forces the quarantine bucket.
		if entry.QuarantineUntil > now && bucket != domainStateHardQuar {
			bucket = domainStateQuarantine
		}
		switch bucket {
		case domainStateActive:
			active++
		case domainStateStable:
			stable++
		case domainStateSuspect:
			suspect++
		case domainStateQuarantine:
			quarantine++
		case domainStateHardQuar:
			hardQuar++
		}
	}

	for _, rec := range s.Domains {
		tally(rec.Direct)
		tally(rec.Wildcard)
	}

	total := active + stable + suspect + quarantine + hardQuar
	if total == 0 {
		return ""
	}
	return fmt.Sprintf(
		"active=%d stable=%d suspect=%d quarantine=%d hard_quarantine=%d total=%d",
		active, stable, suspect, quarantine, hardQuar, total,
	)
}
// digPTR performs a reverse (PTR) lookup of ip against a single upstream.
//
// upstream is parsed with splitDNS; a missing port defaults to 53. The
// lookup uses Go's built-in resolver with every connection dialed over UDP
// to that upstream, bounded by timeout. Returned names are lower-cased,
// stripped of the trailing dot and de-duplicated in the order received. On
// failure the error is passed to logf (when not nil) and returned.
func digPTR(ip, upstream string, timeout time.Duration, logf func(string, ...any)) ([]string, error) {
	host, port := splitDNS(upstream)
	if host == "" {
		return nil, fmt.Errorf("upstream empty")
	}
	if port == "" {
		port = "53"
	}
	target := net.JoinHostPort(host, port)

	resolver := &net.Resolver{
		PreferGo: true,
		// The resolver's own choice of network and address is ignored so
		// that all queries go to the selected upstream over UDP.
		Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
			var dialer net.Dialer
			return dialer.DialContext(ctx, "udp", target)
		},
	}

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	found, err := resolver.LookupAddr(ctx, ip)
	if err != nil {
		if logf != nil {
			logf("ptr error %s via %s: %v", ip, target, err)
		}
		return nil, err
	}

	var out []string
	seen := make(map[string]struct{}, len(found))
	for _, raw := range found {
		name := strings.TrimSuffix(strings.ToLower(strings.TrimSpace(raw)), ".")
		if name == "" {
			continue
		}
		if _, dup := seen[name]; dup {
			continue
		}
		seen[name] = struct{}{}
		out = append(out, name)
	}
	return out, nil
}
// ---------------------------------------------------------------------
// EN: `loadDNSConfig` loads dns config from storage or config.
// RU: `loadDNSConfig` - загружает dns config из хранилища или конфига.
// ---------------------------------------------------------------------
// loadDNSConfig assembles the resolver configuration.
//
// Starting from built-in defaults, the enabled upstream pool (when not
// empty) replaces both the default and meta lists. If SmartDNS is forced,
// the file at path is ignored and both lists contain only the SmartDNS
// address. Otherwise the file is parsed line by line ("default", "meta",
// "smartdns" and "mode" keys, "#" comments); its default/meta lists are
// applied only when no enabled pool exists. Finally both lists are merged
// with the fallback pool. A file that cannot be read is logged and the
// values gathered so far are returned with fallbacks merged in.
func loadDNSConfig(path string, logf func(string, ...any)) dnsConfig {
	cfg := dnsConfig{
		Default:  []string{defaultDNS1, defaultDNS2},
		Meta:     []string{defaultMeta1, defaultMeta2},
		SmartDNS: smartDNSAddr(),
		Mode:     DNSModeDirect,
	}
	activePool := loadEnabledDNSUpstreamPool()
	if len(activePool) > 0 {
		cfg.Default, cfg.Meta = activePool, activePool
	}

	// 1) When SmartDNS is forced, ignore the file entirely and use only the
	// local resolver.
	if smartDNSForced() {
		forced := smartDNSAddr()
		cfg.Default = []string{forced}
		cfg.Meta = []string{forced}
		cfg.SmartDNS = forced
		cfg.Mode = DNSModeSmartDNS
		if logf != nil {
			logf("dns-config: SmartDNS forced (%s), ignore %s", forced, path)
		}
		return cfg
	}

	// withFallbacks appends the fallback upstreams to both lists.
	withFallbacks := func() {
		cfg.Default = mergeDNSUpstreamPools(cfg.Default, resolverFallbackPool())
		cfg.Meta = mergeDNSUpstreamPools(cfg.Meta, resolverFallbackPool())
	}

	// 2) Read dns-upstreams.conf for legacy compatibility and for the
	// smartdns/mode values.
	data, err := os.ReadFile(path)
	if err != nil {
		if logf != nil {
			logf("dns-config: can't read %s: %v", path, err)
		}
		withFallbacks()
		return cfg
	}

	// collect appends every value that normalizes to a usable upstream.
	collect := func(dst, values []string) []string {
		for _, v := range values {
			if n := normalizeDNSUpstream(v, "53"); n != "" {
				dst = append(dst, n)
			}
		}
		return dst
	}

	var fileDefault, fileMeta []string
	for _, raw := range strings.Split(string(data), "\n") {
		line := strings.TrimSpace(raw)
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		fields := strings.Fields(line)
		if len(fields) < 2 {
			continue
		}
		// At least one value is guaranteed by the length check above.
		values := fields[1:]
		switch strings.ToLower(fields[0]) {
		case "default":
			fileDefault = collect(fileDefault, values)
		case "meta":
			fileMeta = collect(fileMeta, values)
		case "smartdns":
			if addr := normalizeSmartDNSAddr(values[0]); addr != "" {
				cfg.SmartDNS = addr
			}
		case "mode":
			cfg.Mode = normalizeDNSResolverMode(DNSResolverMode(values[0]), false)
		}
	}

	// The enabled pool, when present, takes precedence over the file lists.
	if len(activePool) == 0 {
		if len(fileDefault) > 0 {
			cfg.Default = fileDefault
		}
		if len(fileMeta) > 0 {
			cfg.Meta = fileMeta
		}
	}
	withFallbacks()

	if logf != nil {
		logf("dns-config: accept %s: mode=%s smartdns=%s default=%v; meta=%v", path, cfg.Mode, cfg.SmartDNS, cfg.Default, cfg.Meta)
	}
	return cfg
}
// ---------------------------------------------------------------------
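// EN: `readLinesAllowMissing` reads a file and returns its lines (CRLF normalized to LF); a missing or unreadable file yields nil.
// RU: `readLinesAllowMissing` - читает файл и возвращает его строки (CRLF приводится к LF); если файл отсутствует или не читается, возвращает nil.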
// ---------------------------------------------------------------------
// readLinesAllowMissing returns the content of path split into lines, with
// "\r\n" line endings converted to "\n" first. Any read error, including a
// missing file, yields nil instead of an error. A trailing newline produces
// a final empty element.
func readLinesAllowMissing(path string) []string {
	raw, err := os.ReadFile(path)
	if err != nil {
		return nil
	}
	text := strings.ReplaceAll(string(raw), "\r\n", "\n")
	return strings.Split(text, "\n")
}
// ---------------------------------------------------------------------
// EN: `loadJSONMap` loads json map from storage or config.
// RU: `loadJSONMap` - загружает json map из хранилища или конфига.
// ---------------------------------------------------------------------
// loadJSONMap reads path and decodes it as a JSON object. A read or decode
// error yields an empty, non-nil map. A document that decodes without error
// but leaves the map unset (the JSON literal null) yields a nil map.
func loadJSONMap(path string) map[string]any {
	raw, readErr := os.ReadFile(path)
	if readErr != nil {
		return map[string]any{}
	}
	var decoded map[string]any
	if json.Unmarshal(raw, &decoded) != nil {
		return map[string]any{}
	}
	return decoded
}
func loadResolverPrecheckLastRun(path string) int {
m := loadJSONMap(path)
if len(m) == 0 {
return 0
}
v, ok := parseAnyInt(m["last_run"])
if !ok || v <= 0 {
return 0
}
return v
}
// saveResolverPrecheckState updates the JSON state file at path with the
// run timestamp ts ("last_run") and the counters of the timeout recheck
// ("timeout_recheck"). Other keys already present in the file are kept. The
// call is a no-op when path is empty or ts is not positive.
func saveResolverPrecheckState(path string, ts int, timeoutStats resolverTimeoutRecheckStats) {
	if path == "" || ts <= 0 {
		return
	}

	recheck := map[string]any{
		"checked":       timeoutStats.Checked,
		"recovered":     timeoutStats.Recovered,
		"recovered_ips": timeoutStats.RecoveredIPs,
		"still_timeout": timeoutStats.StillTimeout,
		"now_nxdomain":  timeoutStats.NowNXDomain,
		"now_temporary": timeoutStats.NowTemporary,
		"now_other":     timeoutStats.NowOther,
		"no_signal":     timeoutStats.NoSignal,
	}

	state := loadJSONMap(path)
	if state == nil {
		// loadJSONMap can return nil for a file containing JSON null.
		state = make(map[string]any, 2)
	}
	state["last_run"] = ts
	state["timeout_recheck"] = recheck
	saveJSON(state, path)
}
// ---------------------------------------------------------------------
// EN: `saveJSON` saves json to persistent storage.
// RU: `saveJSON` - сохраняет json в постоянное хранилище.
// ---------------------------------------------------------------------
// saveJSON writes data as indented JSON to path on a best-effort basis.
//
// The document is first written to path+".tmp" and then renamed over path,
// so readers never observe a half-written file. Errors are not reported to
// the caller. If encoding or the temporary write fails, path is left
// untouched. The temporary file is removed (best effort) whenever it could
// not be moved into place.
func saveJSON(data any, path string) {
	b, err := json.MarshalIndent(data, "", " ")
	if err != nil {
		return
	}

	tmp := path + ".tmp"
	if err := os.WriteFile(tmp, b, 0o644); err != nil {
		// Never rename after a failed or partial write: that would replace
		// a good destination with broken content.
		_ = os.Remove(tmp)
		return
	}
	if err := os.Rename(tmp, path); err != nil {
		// Do not leave the temporary file behind.
		_ = os.Remove(tmp)
	}
}
// ---------------------------------------------------------------------
// EN: `uniqueStrings` returns the input values without duplicates, keeping first-seen order.
// RU: `uniqueStrings` - возвращает значения без дубликатов, сохраняя порядок первого появления.
// ---------------------------------------------------------------------
// uniqueStrings returns the values of in with duplicates removed, keeping
// the position of each value's first occurrence. Values are compared
// exactly (no trimming or case folding). The result is nil when in has no
// elements.
func uniqueStrings(in []string) []string {
	var distinct []string
	known := make(map[string]struct{})
	for i := range in {
		item := in[i]
		if _, dup := known[item]; dup {
			continue
		}
		known[item] = struct{}{}
		distinct = append(distinct, item)
	}
	return distinct
}
// pickDNSStartIndex maps host onto an index in [0, size) using the 32-bit
// FNV-1a hash of the trimmed, lower-cased host name, so the same host always
// starts at the same position. It returns 0 when size is 1 or less.
func pickDNSStartIndex(host string, size int) int {
	if size <= 1 {
		return 0
	}
	key := strings.ToLower(strings.TrimSpace(host))
	hasher := fnv.New32a()
	_, _ = hasher.Write([]byte(key)) // hash.Hash writes never return an error
	return int(hasher.Sum32() % uint32(size))
}
func resolverFallbackPool() []string {
raw := strings.TrimSpace(os.Getenv("RESOLVE_DNS_FALLBACKS"))
switch strings.ToLower(raw) {
case "off", "none", "0":
return nil
}
candidates := resolverFallbackDNS
if raw != "" {
candidates = nil
fields := strings.FieldsFunc(raw, func(r rune) bool {
return r == ',' || r == ';' || r == ' ' || r == '\n' || r == '\t'
})
for _, f := range fields {
if n := normalizeDNSUpstream(f, "53"); n != "" {
candidates = append(candidates, n)
}
}
}
return uniqueStrings(candidates)
}
// mergeDNSUpstreamPools concatenates primary and fallback upstreams into one
// list. Every item is normalized (default port 53); items that normalize to
// an empty string and duplicates are skipped. The result holds at most
// RESOLVE_DNS_MAX_UPSTREAMS entries (default 12, minimum 1), with primary
// entries taking precedence.
func mergeDNSUpstreamPools(primary, fallback []string) []string {
	limit := envInt("RESOLVE_DNS_MAX_UPSTREAMS", 12)
	if limit < 1 {
		limit = 1
	}

	merged := make([]string, 0, len(primary)+len(fallback))
	taken := map[string]struct{}{}
	for _, pool := range [][]string{primary, fallback} {
		for _, candidate := range pool {
			if len(merged) >= limit {
				return merged
			}
			n := normalizeDNSUpstream(candidate, "53")
			if n == "" {
				continue
			}
			if _, dup := taken[n]; dup {
				continue
			}
			taken[n] = struct{}{}
			merged = append(merged, n)
		}
	}
	return merged
}
// ---------------------------------------------------------------------
// text cleanup + IP classifiers
// ---------------------------------------------------------------------
// reANSI matches an ESC "[" sequence followed by digits and semicolons and
// terminated by a single ASCII letter (e.g. colour codes).
var reANSI = regexp.MustCompile(`\x1B\[[0-9;]*[A-Za-z]`)

// stripANSI returns s with every sequence matched by reANSI removed.
func stripANSI(s string) string {
	cleaned := reANSI.ReplaceAllString(s, "")
	return cleaned
}
// ---------------------------------------------------------------------
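// EN: `isPrivateIPv4` reports whether ip (optionally with a /prefix) is NOT a public IPv4 address: 0/8, 10/8, 127/8, 172.16/12, 192.168/16, or anything that is not a valid dotted quad.
// RU: `isPrivateIPv4` - сообщает, что ip (возможно с /префиксом) НЕ является публичным IPv4-адресом: 0/8, 10/8, 127/8, 172.16/12, 192.168/16 либо строка не является корректным IPv4.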
// ---------------------------------------------------------------------
// isPrivateIPv4 reports whether ip should be treated as non-public.
//
// Only the part before an optional "/prefix" is inspected. The result is
// true for addresses in 0.0.0.0/8, 10.0.0.0/8, 127.0.0.0/8, 172.16.0.0/12
// and 192.168.0.0/16, and also for any input that is not four dot-separated
// decimal numbers in the range 0..255 (which includes IPv6 addresses and
// host names). Everything else yields false.
func isPrivateIPv4(ip string) bool {
	host, _, _ := strings.Cut(ip, "/")
	fields := strings.Split(host, ".")
	if len(fields) != 4 {
		return true
	}

	var octet [4]int
	for i, field := range fields {
		v, err := strconv.Atoi(field)
		if err != nil || v < 0 || v > 255 {
			return true
		}
		octet[i] = v
	}

	switch octet[0] {
	case 0, 10, 127:
		return true
	case 172:
		return octet[1] >= 16 && octet[1] <= 31
	case 192:
		return octet[1] == 168
	}
	return false
}