package resolver

import (
	"context"
	"net"
	"time"
)

// DNS resolution modes recognized by ResolveHost via DNSConfig.Mode.
const (
	// dnsModeSmartDNS sends every lookup to cfg.SmartDNS when it is set.
	dnsModeSmartDNS = "smartdns"
	// dnsModeHybridWildcard sends only hosts accepted by isWildcard to
	// cfg.SmartDNS; every other host uses the regular upstream list.
	dnsModeHybridWildcard = "hybrid_wildcard"
)

// DNSAttemptPolicy bounds how DigAWithPolicy walks an upstream list for a
// single host.
type DNSAttemptPolicy struct {
	// TryLimit is the maximum number of list entries visited. Values <= 0
	// are treated as 1; values above the list length are clamped to it.
	TryLimit int
	// DomainBudget is the total wall-clock time allowed for one host. When
	// <= 0 it defaults to TryLimit*timeout; it is never less than 200ms.
	DomainBudget time.Duration
	// StopOnNX stops trying further upstreams after the first NXDOMAIN.
	StopOnNX bool
}

// DNSCooldown lets callers temporarily take upstreams out of rotation.
// Upstreams are identified by their "host:port" address and times are Unix
// seconds.
type DNSCooldown interface {
	// ShouldSkip reports whether upstream must not be queried at now.
	ShouldSkip(upstream string, now int64) bool
	// ObserveSuccess records a query to upstream that produced usable IPs.
	ObserveSuccess(upstream string)
	// ObserveError records a failed query. It reports whether upstream was
	// banned as a result, plus the ban duration in seconds (logged as
	// ban_sec).
	ObserveError(upstream string, kind DNSErrorKind, now int64) (bool, int)
}

// ResolveHost resolves host to a deduplicated list of addresses that
// IsPrivateIPv4 does not reject. It also returns the metrics collected across
// every upstream it queried.
//
// Upstream selection:
//   - hosts listed in metaSpecial use cfg.Meta; all others use cfg.Default;
//   - in smartdns mode, or in hybrid_wildcard mode for hosts accepted by
//     isWildcard, a non-empty cfg.SmartDNS replaces that list, and
//     wildcardPolicyFor is used instead of directPolicyFor.
//
// If the primary lookup used the regular list and returned nothing,
// cfg.SmartDNS is tried once more. This happens only when cfg.SmartDNS is
// configured, smartDNSFallbackEnabled is true, and ShouldFallbackToSmartDNS
// accepts the collected stats.
//
// timeout is the per-upstream query timeout. isWildcard, cooldown, both policy
// factories and logf may be nil.
func ResolveHost(
	host string,
	cfg DNSConfig,
	metaSpecial []string,
	isWildcard func(string) bool,
	timeout time.Duration,
	cooldown DNSCooldown,
	directPolicyFor func(int) DNSAttemptPolicy,
	wildcardPolicyFor func(int) DNSAttemptPolicy,
	smartDNSFallbackEnabled bool,
	logf func(string, ...any),
) ([]string, DNSMetrics) {
	useMeta := false
	for _, m := range metaSpecial {
		if host == m {
			useMeta = true
			break
		}
	}
	dnsList := cfg.Default
	if useMeta {
		dnsList = cfg.Meta
	}

	primaryViaSmartDNS := false
	switch cfg.Mode {
	case dnsModeSmartDNS:
		if cfg.SmartDNS != "" {
			dnsList = []string{cfg.SmartDNS}
			primaryViaSmartDNS = true
		}
	case dnsModeHybridWildcard:
		if cfg.SmartDNS != "" && isWildcard != nil && isWildcard(host) {
			dnsList = []string{cfg.SmartDNS}
			primaryViaSmartDNS = true
		}
	}

	var policy DNSAttemptPolicy
	if primaryViaSmartDNS {
		policy = safePolicy(wildcardPolicyFor, len(dnsList), timeout)
	} else {
		policy = safePolicy(directPolicyFor, len(dnsList), timeout)
	}

	ips, stats := DigAWithPolicy(host, dnsList, timeout, policy, cooldown, logf)

	// Fall back to SmartDNS only when it was not already the primary path.
	if len(ips) == 0 && !primaryViaSmartDNS && cfg.SmartDNS != "" &&
		smartDNSFallbackEnabled && ShouldFallbackToSmartDNS(stats) {
		if logf != nil {
			logf(
				"dns fallback %s: trying smartdns=%s after errors nxdomain=%d timeout=%d temporary=%d other=%d",
				host, cfg.SmartDNS, stats.NXDomain, stats.Timeout, stats.Temporary, stats.Other,
			)
		}
		fallbackPolicy := safePolicy(wildcardPolicyFor, 1, timeout)
		fallbackIPs, fallbackStats := DigAWithPolicy(host, []string{cfg.SmartDNS}, timeout, fallbackPolicy, cooldown, logf)
		stats.Merge(fallbackStats)
		if len(fallbackIPs) > 0 {
			ips = fallbackIPs
			if logf != nil {
				logf("dns fallback %s: resolved via smartdns (%d ips)", host, len(fallbackIPs))
			}
		}
	}

	// DigAWithPolicy already filters and deduplicates. Both steps are
	// re-applied defensively so the output contract holds whichever lookup
	// produced ips.
	out := make([]string, 0, len(ips))
	seen := make(map[string]struct{}, len(ips))
	for _, ip := range ips {
		if IsPrivateIPv4(ip) {
			continue
		}
		if _, ok := seen[ip]; ok {
			continue
		}
		seen[ip] = struct{}{}
		out = append(out, ip)
	}
	return out, stats
}

// DigAWithPolicy queries the upstreams in dnsList for host one at a time. It
// returns the addresses from the first upstream whose answer still contains
// something after IsPrivateIPv4 filtering, deduplicated with UniqueStrings.
// It also returns per-upstream metrics. The address slice is nil when no
// upstream produced a usable answer.
//
// The walk starts at PickDNSStartIndex(host, len(dnsList)) and wraps around.
// At most policy.TryLimit entries are visited, clamped to [1, len(dnsList)].
// Entries with an empty server, and entries the cooldown says to skip, are
// passed over but still use up one of those visits.
// NOTE(review): confirm this is intended. A host whose first TryLimit
// upstreams are all cooling down gets no query at all.
//
// The whole walk is bounded by policy.DomainBudget, which defaults to
// TryLimit*timeout with a 200ms minimum. Each query is bounded by
// min(timeout, remaining budget), floored at 100ms. Entries are parsed with
// SplitDNS, and a missing port defaults to 53.
//
// NOTE(review): LookupHost returns both A and AAAA results; confirm IPv6
// addresses are wanted despite the "A" in the name.
func DigAWithPolicy(
	host string,
	dnsList []string,
	timeout time.Duration,
	policy DNSAttemptPolicy,
	cooldown DNSCooldown,
	logf func(string, ...any),
) ([]string, DNSMetrics) {
	var stats DNSMetrics
	if len(dnsList) == 0 {
		return nil, stats
	}

	tryLimit := policy.TryLimit
	if tryLimit <= 0 {
		tryLimit = 1
	}
	if tryLimit > len(dnsList) {
		tryLimit = len(dnsList)
	}

	budget := policy.DomainBudget
	if budget <= 0 {
		budget = time.Duration(tryLimit) * timeout
	}
	if budget < 200*time.Millisecond {
		budget = 200 * time.Millisecond
	}
	deadline := time.Now().Add(budget)

	start := PickDNSStartIndex(host, len(dnsList))
	for attempt := 0; attempt < tryLimit; attempt++ {
		remaining := time.Until(deadline)
		if remaining <= 0 {
			if logf != nil {
				logf("dns budget exhausted %s: attempts=%d budget_ms=%d", host, attempt, budget.Milliseconds())
			}
			break
		}

		entry := dnsList[(start+attempt)%len(dnsList)]
		server, port := SplitDNS(entry)
		if server == "" {
			continue
		}
		if port == "" {
			port = "53"
		}
		addr := net.JoinHostPort(server, port)

		if cooldown != nil && cooldown.ShouldSkip(addr, time.Now().Unix()) {
			stats.AddCooldownSkip(addr)
			continue
		}

		r := &net.Resolver{
			PreferGo: true,
			// Pin every connection to the chosen upstream, but keep the
			// network the resolver asked for. Go's resolver re-dials with
			// "tcp" when a UDP answer is truncated. Handing it a UDP
			// connection there would only repeat the truncated exchange, and
			// the lookup would fail.
			Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
				var d net.Dialer
				return d.DialContext(ctx, network, addr)
			},
		}

		perAttemptTimeout := timeout
		if remaining < perAttemptTimeout {
			perAttemptTimeout = remaining
		}
		if perAttemptTimeout < 100*time.Millisecond {
			perAttemptTimeout = 100 * time.Millisecond
		}
		ctx, cancel := context.WithTimeout(context.Background(), perAttemptTimeout)
		records, err := r.LookupHost(ctx, host)
		cancel()

		if err != nil {
			kind, ok := NormalizeCacheErrorKind(ClassifyDNSError(err))
			if !ok {
				kind = DNSErrorOther
			}
			stats.AddError(addr, kind)
			if cooldown != nil {
				if banned, banSec := cooldown.ObserveError(addr, kind, time.Now().Unix()); banned && logf != nil {
					logf("dns cooldown ban %s: timeout-like failures; ban_sec=%d", addr, banSec)
				}
			}
			if logf != nil {
				logf("dns warn %s via %s: kind=%s attempt=%d/%d err=%v", host, addr, kind, attempt+1, tryLimit, err)
			}
			if policy.StopOnNX && kind == DNSErrorNXDomain {
				if logf != nil {
					logf("dns early-stop %s: nxdomain via %s (attempt=%d/%d)", host, addr, attempt+1, tryLimit)
				}
				break
			}
			continue
		}

		var ips []string
		for _, ip := range records {
			if IsPrivateIPv4(ip) {
				continue
			}
			ips = append(ips, ip)
		}
		if len(ips) == 0 {
			// An answer made up only of private addresses counts as a
			// failure for this upstream.
			stats.AddError(addr, DNSErrorOther)
			if cooldown != nil {
				// The ban result is intentionally not logged on this path.
				_, _ = cooldown.ObserveError(addr, DNSErrorOther, time.Now().Unix())
			}
			if logf != nil {
				logf("dns warn %s via %s: kind=other err=no_public_ips", host, addr)
			}
			continue
		}

		stats.AddSuccess(addr)
		if cooldown != nil {
			cooldown.ObserveSuccess(addr)
		}
		return UniqueStrings(ips), stats
	}
	return nil, stats
}

// safePolicy returns factory(count). When factory is nil it returns a
// conservative default: a single attempt, a budget of timeout, and stop on
// NXDOMAIN.
func safePolicy(factory func(int) DNSAttemptPolicy, count int, timeout time.Duration) DNSAttemptPolicy {
	if factory != nil {
		return factory(count)
	}
	return DNSAttemptPolicy{
		TryLimit:     1,
		DomainBudget: timeout,
		StopOnNX:     true,
	}
}