package dnscfg

import (
	"fmt"
	"sort"
	"strings"
	"time"
)

// PrewarmDNSUpstreamMetrics holds DNS query counters for a single upstream.
type PrewarmDNSUpstreamMetrics struct {
	Attempts  int
	OK        int
	NXDomain  int
	Timeout   int
	Temporary int
	Other     int
	Skipped   int
}

// PrewarmDNSMetrics holds aggregate DNS query counters plus an optional
// per-upstream breakdown keyed by upstream.
type PrewarmDNSMetrics struct {
	Attempts    int
	OK          int
	NXDomain    int
	Timeout     int
	Temporary   int
	Other       int
	Skipped     int
	PerUpstream map[string]PrewarmDNSUpstreamMetrics
}

// Merge adds the counters from other into m, including the per-upstream
// breakdown. The PerUpstream map is allocated lazily on first use.
func (m *PrewarmDNSMetrics) Merge(other PrewarmDNSMetrics) {
	m.Attempts += other.Attempts
	m.OK += other.OK
	m.NXDomain += other.NXDomain
	m.Timeout += other.Timeout
	m.Temporary += other.Temporary
	m.Other += other.Other
	m.Skipped += other.Skipped
	if len(other.PerUpstream) == 0 {
		return
	}
	if m.PerUpstream == nil {
		m.PerUpstream = map[string]PrewarmDNSUpstreamMetrics{}
	}
	for upstream, src := range other.PerUpstream {
		dst := m.PerUpstream[upstream]
		dst.Attempts += src.Attempts
		dst.OK += src.OK
		dst.NXDomain += src.NXDomain
		dst.Timeout += src.Timeout
		dst.Temporary += src.Temporary
		dst.Other += src.Other
		dst.Skipped += src.Skipped
		m.PerUpstream[upstream] = dst
	}
}

// TotalErrors returns the sum of all error counters. Skipped is not counted.
func (m PrewarmDNSMetrics) TotalErrors() int {
	return m.NXDomain + m.Timeout + m.Temporary + m.Other
}

// FormatPerUpstream renders the per-upstream counters as a single line,
// sorted by upstream key. It returns "" when there is no breakdown.
func (m PrewarmDNSMetrics) FormatPerUpstream() string {
	if len(m.PerUpstream) == 0 {
		return ""
	}
	keys := make([]string, 0, len(m.PerUpstream))
	for k := range m.PerUpstream {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	parts := make([]string, 0, len(keys))
	for _, k := range keys {
		v := m.PerUpstream[k]
		parts = append(parts, fmt.Sprintf("%s{attempts=%d ok=%d nxdomain=%d timeout=%d temporary=%d other=%d skipped=%d}", k, v.Attempts, v.OK, v.NXDomain, v.Timeout, v.Temporary, v.Other, v.Skipped))
	}
	return strings.Join(parts, "; ")
}

// PrewarmInput carries the settings for one RunPrewarm call.
type PrewarmInput struct {
	Mode             string
	Source           string
	RuntimeEnabled   bool
	SmartDNSAddr     string
	Wildcards        []string
	AggressiveSubs   bool
	Subs             []string
	SubsPerBaseLimit int
	Limit            int
	Workers          int
	TimeoutMS        int
	EnvWorkers       int
	EnvTimeoutMS     int
	MaxHostsLog      int
	WildcardMapPath  string
}

// PrewarmDeps bundles the callbacks RunPrewarm depends on. Nil callbacks are
// tolerated, except ApplyDynSet, which is required when the runtime set is
// disabled.
type PrewarmDeps struct {
	IsGoogleLike     func(string) bool
	EnsureRuntimeSet func()
	DigA             func(host string, dnsList []string, timeout time.Duration) ([]string, PrewarmDNSMetrics)
	ReadDynSet       func() ([]string, error)
	ApplyDynSet      func([]string) error
	Logf             func(message string)
}

// PrewarmResult reports the outcome of RunPrewarm.
type PrewarmResult struct {
	OK            bool
	Message       string
	ExitCode      int
	ResolvedHosts int
}

// RunPrewarm expands the wildcard domains from in, resolves them through the
// SmartDNS address with a pool of workers, and, when the runtime set is
// disabled, merges the resolved IPs into the dynamic set via deps.
func RunPrewarm(in PrewarmInput, deps PrewarmDeps) PrewarmResult {
	smartdnsAddr := strings.TrimSpace(in.SmartDNSAddr)
	if smartdnsAddr == "" {
		return PrewarmResult{OK: false, Message: "SmartDNS address is empty"}
	}

	wildcards := trimNonEmptyUnique(in.Wildcards)
	if len(wildcards) == 0 {
		msg := "prewarm skipped: wildcard list is empty"
		logPrewarm(deps.Logf, msg)
		return PrewarmResult{OK: true, Message: msg}
	}

	aggressive := in.AggressiveSubs
	subs := trimNonEmptyUnique(in.Subs)
	subsPerBaseLimit := in.SubsPerBaseLimit
	if subsPerBaseLimit < 0 {
		subsPerBaseLimit = 0
	}

	// Each base domain is always included. In aggressive mode, non-Google-like
	// bases also get up to subsPerBaseLimit subdomain prefixes (0 = no limit).
	domainSet := make(map[string]struct{}, len(wildcards)*(len(subs)+1))
	for _, d := range wildcards {
		domainSet[d] = struct{}{}
		if !aggressive || isGoogleLikeSafe(deps.IsGoogleLike, d) {
			continue
		}
		maxSubs := len(subs)
		if subsPerBaseLimit > 0 && subsPerBaseLimit < maxSubs {
			maxSubs = subsPerBaseLimit
		}
		for i := 0; i < maxSubs; i++ {
			domainSet[subs[i]+"."+d] = struct{}{}
		}
	}

	// Sort before applying the limit so truncation is deterministic.
	domains := make([]string, 0, len(domainSet))
	for d := range domainSet {
		domains = append(domains, d)
	}
	sort.Strings(domains)
	if in.Limit > 0 && len(domains) > in.Limit {
		domains = domains[:in.Limit]
	}
	if len(domains) == 0 {
		msg := "prewarm skipped: expanded wildcard list is empty"
		logPrewarm(deps.Logf, msg)
		return PrewarmResult{OK: true, Message: msg}
	}

	// Worker count: explicit input, then environment value, then 24;
	// clamped to [1, 200].
	workers := in.Workers
	if workers <= 0 {
		workers = in.EnvWorkers
		if workers <= 0 {
			workers = 24
		}
	}
	if workers < 1 {
		workers = 1
	}
	if workers > 200 {
		workers = 200
	}

	// Timeout: explicit input, then environment value, then 1800 ms;
	// clamped to [200, 15000] ms.
	timeoutMS := in.TimeoutMS
	if timeoutMS <= 0 {
		timeoutMS = in.EnvTimeoutMS
		if timeoutMS <= 0 {
			timeoutMS = 1800
		}
	}
	if timeoutMS < 200 {
		timeoutMS = 200
	}
	if timeoutMS > 15000 {
		timeoutMS = 15000
	}
	timeout := time.Duration(timeoutMS) * time.Millisecond

	if deps.EnsureRuntimeSet != nil {
		deps.EnsureRuntimeSet()
	}

	logPrewarm(
		deps.Logf,
		fmt.Sprintf(
			"prewarm start: mode=%s source=%s runtime_nftset=%t smartdns=%s wildcard_domains=%d expanded=%d aggressive_subs=%t workers=%d timeout_ms=%d",
			strings.TrimSpace(in.Mode),
			strings.TrimSpace(in.Source),
			in.RuntimeEnabled,
			smartdnsAddr,
			len(wildcards),
			len(domains),
			aggressive,
			workers,
			timeoutMS,
		),
	)

	type prewarmItem struct {
		host  string
		ips   []string
		stats PrewarmDNSMetrics
	}

	// Both channels are buffered to len(domains), so neither the producer
	// loop nor the workers ever block on a send.
	jobs := make(chan string, len(domains))
	results := make(chan prewarmItem, len(domains))
	for i := 0; i < workers; i++ {
		go func() {
			for host := range jobs {
				ips, stats := safeDigA(deps.DigA, host, []string{smartdnsAddr}, timeout)
				results <- prewarmItem{host: host, ips: ips, stats: stats}
			}
		}()
	}
	for _, host := range domains {
		jobs <- host
	}
	close(jobs)

	resolvedHosts := 0
	totalIPs := 0
	errorHosts := 0
	stats := PrewarmDNSMetrics{}
	resolvedIPSet := map[string]struct{}{}
	loggedHosts := 0
	maxHostsLog := in.MaxHostsLog
	if maxHostsLog <= 0 {
		maxHostsLog = 200
	}

	// Collect exactly one result per domain.
	for i := 0; i < len(domains); i++ {
		item := <-results
		stats.Merge(item.stats)
		if item.stats.TotalErrors() > 0 {
			errorHosts++
		}
		if len(item.ips) == 0 {
			continue
		}
		resolvedHosts++
		totalIPs += len(item.ips)
		for _, ip := range item.ips {
			if strings.TrimSpace(ip) != "" {
				resolvedIPSet[ip] = struct{}{}
			}
		}
		if loggedHosts < maxHostsLog {
			logPrewarm(deps.Logf, fmt.Sprintf("prewarm add: %s -> %s", item.host, strings.Join(item.ips, ", ")))
			loggedHosts++
		}
	}

	manualAdded := 0
	totalDynText := "n/a"
	if !in.RuntimeEnabled {
		// Manual mode: merge the resolved IPs into the existing dynamic set
		// and apply the union. The read error is discarded; the merge uses
		// whatever slice ReadDynSet returned.
		existing, _ := safeReadDynSet(deps.ReadDynSet)
		mergedSet := make(map[string]struct{}, len(existing)+len(resolvedIPSet))
		for _, ip := range existing {
			if strings.TrimSpace(ip) != "" {
				mergedSet[ip] = struct{}{}
			}
		}
		for ip := range resolvedIPSet {
			if _, ok := mergedSet[ip]; !ok {
				manualAdded++
			}
			mergedSet[ip] = struct{}{}
		}
		merged := make([]string, 0, len(mergedSet))
		for ip := range mergedSet {
			merged = append(merged, ip)
		}
		totalDynText = fmt.Sprintf("%d", len(merged))
		if err := safeApplyDynSet(deps.ApplyDynSet, merged); err != nil {
			msg := fmt.Sprintf("prewarm manual apply failed: %v", err)
			logPrewarm(deps.Logf, msg)
			return PrewarmResult{OK: false, Message: msg}
		}
		logPrewarm(
			deps.Logf,
			fmt.Sprintf("prewarm manual merge: existing=%d resolved=%d added=%d total_dyn=%d", len(existing), len(resolvedIPSet), manualAdded, len(merged)),
		)
	}

	// omitted counts every expanded host that was not logged, including
	// hosts that did not resolve.
	if len(domains) > loggedHosts {
		logPrewarm(
			deps.Logf,
			fmt.Sprintf(
				"prewarm add: trace truncated, omitted=%d hosts (full wildcard map: %s)",
				len(domains)-loggedHosts,
				strings.TrimSpace(in.WildcardMapPath),
			),
		)
	}

	msg := fmt.Sprintf(
		"prewarm done: source=%s expanded=%d resolved=%d total_ips=%d error_hosts=%d dns_attempts=%d dns_ok=%d dns_errors=%d manual_added=%d dyn_total=%s",
		strings.TrimSpace(in.Source),
		len(domains),
		resolvedHosts,
		totalIPs,
		errorHosts,
		stats.Attempts,
		stats.OK,
		stats.TotalErrors(),
		manualAdded,
		totalDynText,
	)
	logPrewarm(deps.Logf, msg)
	if perUpstream := stats.FormatPerUpstream(); perUpstream != "" {
		logPrewarm(deps.Logf, "prewarm dns upstreams: "+perUpstream)
	}

	// ExitCode carries the resolved-host count, the same value as ResolvedHosts.
	return PrewarmResult{
		OK:            true,
		Message:       msg,
		ExitCode:      resolvedHosts,
		ResolvedHosts: resolvedHosts,
	}
}

// logPrewarm calls logf with msg when logf is non-nil.
func logPrewarm(logf func(string), msg string) {
	if logf != nil {
		logf(msg)
	}
}

// safeDigA calls dig, or returns no IPs and zero metrics when dig is nil.
func safeDigA(
	dig func(host string, dnsList []string, timeout time.Duration) ([]string, PrewarmDNSMetrics),
	host string,
	dnsList []string,
	timeout time.Duration,
) ([]string, PrewarmDNSMetrics) {
	if dig == nil {
		return nil, PrewarmDNSMetrics{}
	}
	return dig(host, dnsList, timeout)
}

// safeReadDynSet calls read, or returns (nil, nil) when read is nil.
func safeReadDynSet(read func() ([]string, error)) ([]string, error) {
	if read == nil {
		return nil, nil
	}
	return read()
}

// safeApplyDynSet calls apply, or returns an error when apply is nil.
func safeApplyDynSet(apply func([]string) error, ips []string) error {
	if apply == nil {
		return fmt.Errorf("apply dyn set callback is nil")
	}
	return apply(ips)
}

// isGoogleLikeSafe calls check, or returns false when check is nil.
func isGoogleLikeSafe(check func(string) bool, domain string) bool {
	if check == nil {
		return false
	}
	return check(domain)
}

// trimNonEmptyUnique trims each item, drops empty strings and duplicates,
// and preserves first-seen order. It returns nil for empty input.
func trimNonEmptyUnique(in []string) []string {
	if len(in) == 0 {
		return nil
	}
	seen := map[string]struct{}{}
	out := make([]string, 0, len(in))
	for _, item := range in {
		v := strings.TrimSpace(item)
		if v == "" {
			continue
		}
		if _, ok := seen[v]; ok {
			continue
		}
		seen[v] = struct{}{}
		out = append(out, v)
	}
	return out
}