Files

235 lines
5.6 KiB
Go

package resolver
import (
"context"
"net"
"time"
)
// DNS mode values recognised in DNSConfig.Mode by ResolveHost.
const (
	// dnsModeSmartDNS sends every lookup to the configured SmartDNS upstream
	// (when one is set) instead of the default/meta upstream list.
	dnsModeSmartDNS = "smartdns"
	// dnsModeHybridWildcard sends only hosts matched by the wildcard predicate
	// to SmartDNS; all other hosts use the default/meta upstream list.
	dnsModeHybridWildcard = "hybrid_wildcard"
)
// DNSAttemptPolicy bounds how hard DigAWithPolicy works to resolve one domain.
type DNSAttemptPolicy struct {
	// TryLimit is the maximum number of upstream entries visited. DigAWithPolicy
	// treats values <= 0 as 1 and caps the limit at the number of upstreams.
	TryLimit int
	// DomainBudget is the total wall-clock budget for all attempts. DigAWithPolicy
	// derives TryLimit*timeout when it is <= 0 and never uses less than 200ms.
	DomainBudget time.Duration
	// StopOnNX ends the walk as soon as an upstream reports NXDOMAIN.
	StopOnNX bool
}
// DNSCooldown tracks upstream health so that DigAWithPolicy can avoid
// upstreams that have been failing. Upstreams are identified by the
// "host:port" address DigAWithPolicy dials, and timestamps are Unix seconds.
type DNSCooldown interface {
	// ShouldSkip reports whether upstream should be passed over at time now.
	ShouldSkip(upstream string, now int64) bool
	// ObserveSuccess records a lookup via upstream that produced public addresses.
	ObserveSuccess(upstream string)
	// ObserveError records a failed lookup of the given kind at time now. It
	// reports whether upstream is now banned and, if so, the ban length in seconds.
	ObserveError(upstream string, kind DNSErrorKind, now int64) (banned bool, banSec int)
}
// ResolveHost resolves host to a de-duplicated list of addresses, dropping any
// for which IsPrivateIPv4 reports true. The returned slice is never nil.
//
// Upstream selection:
//   - hosts listed in metaSpecial use cfg.Meta, all others cfg.Default;
//   - in dnsModeSmartDNS, a non-empty cfg.SmartDNS replaces that list;
//   - in dnsModeHybridWildcard, a non-empty cfg.SmartDNS replaces that list
//     only when isWildcard is non-nil and reports host as a wildcard match.
//
// The policy comes from wildcardPolicyFor when SmartDNS is the primary
// upstream, otherwise from directPolicyFor (see safePolicy for nil factories).
//
// If the primary lookup used regular upstreams and produced nothing, a second
// lookup against cfg.SmartDNS is made when cfg.SmartDNS is set,
// smartDNSFallbackEnabled is true, and ShouldFallbackToSmartDNS approves the
// primary metrics. Metrics from both lookups are merged into the result.
// logf may be nil.
func ResolveHost(
	host string,
	cfg DNSConfig,
	metaSpecial []string,
	isWildcard func(string) bool,
	timeout time.Duration,
	cooldown DNSCooldown,
	directPolicyFor func(int) DNSAttemptPolicy,
	wildcardPolicyFor func(int) DNSAttemptPolicy,
	smartDNSFallbackEnabled bool,
	logf func(string, ...any),
) ([]string, DNSMetrics) {
	upstreams := cfg.Default
	for _, special := range metaSpecial {
		if special == host {
			upstreams = cfg.Meta
			break
		}
	}

	viaSmart := false
	if cfg.SmartDNS != "" {
		switch {
		case cfg.Mode == dnsModeSmartDNS:
			viaSmart = true
		case cfg.Mode == dnsModeHybridWildcard && isWildcard != nil && isWildcard(host):
			viaSmart = true
		}
	}
	if viaSmart {
		upstreams = []string{cfg.SmartDNS}
	}

	// The direct policy is always built first and then replaced on the
	// SmartDNS path, so both factories see the same call pattern as before.
	policy := safePolicy(directPolicyFor, len(upstreams), timeout)
	if viaSmart {
		policy = safePolicy(wildcardPolicyFor, len(upstreams), timeout)
	}

	ips, stats := DigAWithPolicy(host, upstreams, timeout, policy, cooldown, logf)

	tryFallback := len(ips) == 0 &&
		!viaSmart &&
		cfg.SmartDNS != "" &&
		smartDNSFallbackEnabled &&
		ShouldFallbackToSmartDNS(stats)
	if tryFallback {
		if logf != nil {
			logf("dns fallback %s: trying smartdns=%s after errors nxdomain=%d timeout=%d temporary=%d other=%d",
				host, cfg.SmartDNS, stats.NXDomain, stats.Timeout, stats.Temporary, stats.Other)
		}
		smartIPs, smartStats := DigAWithPolicy(
			host,
			[]string{cfg.SmartDNS},
			timeout,
			safePolicy(wildcardPolicyFor, 1, timeout),
			cooldown,
			logf,
		)
		stats.Merge(smartStats)
		if len(smartIPs) > 0 {
			ips = smartIPs
			if logf != nil {
				logf("dns fallback %s: resolved via smartdns (%d ips)", host, len(smartIPs))
			}
		}
	}

	public := make([]string, 0, len(ips))
	seen := make(map[string]struct{}, len(ips))
	for _, addr := range ips {
		if IsPrivateIPv4(addr) {
			continue
		}
		if _, dup := seen[addr]; dup {
			continue
		}
		seen[addr] = struct{}{}
		public = append(public, addr)
	}
	return public, stats
}
// DigAWithPolicy resolves host against the upstreams in dnsList under policy.
// It returns the first non-empty set of public addresses, de-duplicated with
// UniqueStrings, together with per-upstream metrics. The addresses are nil when
// dnsList is empty or every attempt fails.
//
// Upstreams are visited round-robin starting at PickDNSStartIndex(host, n):
//   - At most policy.TryLimit entries are visited. A non-positive limit means 1,
//     and the limit is capped at len(dnsList).
//   - The whole walk is bounded by policy.DomainBudget. A non-positive budget
//     becomes TryLimit*timeout, with a 200ms floor.
//   - Each lookup is bounded by min(timeout, remaining budget), with a 100ms floor.
//
// Entries are parsed with SplitDNS, and a missing port defaults to 53. Entries
// with an empty server, or that cooldown says to skip, are passed over without
// a lookup but still count toward TryLimit.
//
// A lookup that errors, or that returns only addresses for which IsPrivateIPv4
// is true, is recorded in the metrics and reported to cooldown. The walk then
// continues, unless policy.StopOnNX is set and the error is NXDOMAIN.
// cooldown and logf may be nil.
func DigAWithPolicy(
	host string,
	dnsList []string,
	timeout time.Duration,
	policy DNSAttemptPolicy,
	cooldown DNSCooldown,
	logf func(string, ...any),
) ([]string, DNSMetrics) {
	stats := DNSMetrics{}
	if len(dnsList) == 0 {
		return nil, stats
	}

	tryLimit := policy.TryLimit
	if tryLimit <= 0 {
		tryLimit = 1
	}
	if tryLimit > len(dnsList) {
		tryLimit = len(dnsList)
	}

	budget := policy.DomainBudget
	if budget <= 0 {
		budget = time.Duration(tryLimit) * timeout
	}
	if budget < 200*time.Millisecond {
		budget = 200 * time.Millisecond
	}
	deadline := time.Now().Add(budget)

	start := PickDNSStartIndex(host, len(dnsList))
	for attempt := 0; attempt < tryLimit; attempt++ {
		remaining := time.Until(deadline)
		if remaining <= 0 {
			if logf != nil {
				logf("dns budget exhausted %s: attempts=%d budget_ms=%d", host, attempt, budget.Milliseconds())
			}
			break
		}

		entry := dnsList[(start+attempt)%len(dnsList)]
		server, port := SplitDNS(entry)
		if server == "" {
			// Nothing to query for this entry, but it still uses up an attempt.
			continue
		}
		if port == "" {
			port = "53"
		}
		addr := net.JoinHostPort(server, port)
		if cooldown != nil && cooldown.ShouldSkip(addr, time.Now().Unix()) {
			// NOTE(review): a cooled-down upstream consumes one of the tryLimit
			// attempts. With tryLimit < len(dnsList), a healthy upstream later in
			// the rotation may never be tried. Confirm this is intended.
			stats.AddCooldownSkip(addr)
			continue
		}

		r := &net.Resolver{
			PreferGo: true,
			// Pin every query to this upstream, but keep the transport the
			// resolver asks for. It starts with "udp" and retries with "tcp" when
			// the UDP answer is truncated. Forcing "udp" here made that retry go
			// over UDP again, so large responses always failed.
			Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
				var d net.Dialer
				return d.DialContext(ctx, network, addr)
			},
		}

		perAttemptTimeout := timeout
		if remaining < perAttemptTimeout {
			perAttemptTimeout = remaining
		}
		// The floor gives a nearly exhausted budget one usable attempt. It can
		// overrun the domain budget by up to 100ms.
		if perAttemptTimeout < 100*time.Millisecond {
			perAttemptTimeout = 100 * time.Millisecond
		}
		ctx, cancel := context.WithTimeout(context.Background(), perAttemptTimeout)
		// NOTE(review): LookupHost asks for both A and AAAA records, so IPv6
		// addresses can be returned despite the "DigA" name. Confirm whether
		// callers expect IPv4 only (LookupIP(ctx, "ip4", host) would enforce it).
		records, err := r.LookupHost(ctx, host)
		cancel()

		if err != nil {
			kindRaw := ClassifyDNSError(err)
			kind, ok := NormalizeCacheErrorKind(kindRaw)
			if !ok {
				kind = DNSErrorOther
			}
			stats.AddError(addr, kind)
			if cooldown != nil {
				if banned, banSec := cooldown.ObserveError(addr, kind, time.Now().Unix()); banned && logf != nil {
					logf("dns cooldown ban %s: timeout-like failures; ban_sec=%d", addr, banSec)
				}
			}
			if logf != nil {
				logf("dns warn %s via %s: kind=%s attempt=%d/%d err=%v", host, addr, kind, attempt+1, tryLimit, err)
			}
			if policy.StopOnNX && kind == DNSErrorNXDomain {
				if logf != nil {
					logf("dns early-stop %s: nxdomain via %s (attempt=%d/%d)", host, addr, attempt+1, tryLimit)
				}
				break
			}
			continue
		}

		var ips []string
		for _, ip := range records {
			if IsPrivateIPv4(ip) {
				continue
			}
			ips = append(ips, ip)
		}
		if len(ips) == 0 {
			// An answer with only private addresses is treated as a failure of
			// this upstream, so the walk moves on to the next one.
			stats.AddError(addr, DNSErrorOther)
			if cooldown != nil {
				// The ban decision is only used for logging on the error path
				// above, so it is deliberately ignored here.
				_, _ = cooldown.ObserveError(addr, DNSErrorOther, time.Now().Unix())
			}
			if logf != nil {
				logf("dns warn %s via %s: kind=other err=no_public_ips", host, addr)
			}
			continue
		}

		stats.AddSuccess(addr)
		if cooldown != nil {
			cooldown.ObserveSuccess(addr)
		}
		return UniqueStrings(ips), stats
	}
	return nil, stats
}
// safePolicy returns factory(count) when factory is non-nil. Otherwise it
// returns a single-attempt policy whose domain budget is timeout and which
// stops on NXDOMAIN.
func safePolicy(factory func(int) DNSAttemptPolicy, count int, timeout time.Duration) DNSAttemptPolicy {
	if factory == nil {
		var fallback DNSAttemptPolicy
		fallback.TryLimit = 1
		fallback.DomainBudget = timeout
		fallback.StopOnNX = true
		return fallback
	}
	return factory(count)
}