Files

393 lines
8.8 KiB
Go

package dnscfg
import (
"fmt"
"sort"
"strings"
"time"
)
// PrewarmDNSUpstreamMetrics holds the lookup counters attributed to one
// upstream key. The values are produced by the PrewarmDeps.DigA callback,
// which decides how each lookup is classified.
type PrewarmDNSUpstreamMetrics struct {
	// Attempts and OK are the total and successful lookup counts as
	// reported by DigA.
	Attempts, OK int
	// NXDomain, Timeout, Temporary and Other are the failure buckets.
	NXDomain, Timeout, Temporary, Other int
	// Skipped counts lookups DigA reported as skipped.
	Skipped int
}
// PrewarmDNSMetrics aggregates lookup counters for a prewarm run, both in
// total and broken down per upstream. Values originate from the
// PrewarmDeps.DigA callback and are combined with Merge.
type PrewarmDNSMetrics struct {
	// Attempts and OK are the total and successful lookup counts as
	// reported by DigA.
	Attempts, OK int
	// NXDomain, Timeout, Temporary and Other are the failure buckets;
	// TotalErrors sums exactly these four.
	NXDomain, Timeout, Temporary, Other int
	// Skipped counts lookups DigA reported as skipped (not an error).
	Skipped int
	// PerUpstream holds the same counters keyed by upstream (presumably
	// the resolver address — confirm against the DigA implementation).
	PerUpstream map[string]PrewarmDNSUpstreamMetrics
}
// Merge adds every counter of other into m, including the per-upstream
// breakdown. m.PerUpstream is allocated lazily, only when other actually
// carries at least one per-upstream entry.
func (m *PrewarmDNSMetrics) Merge(other PrewarmDNSMetrics) {
	m.Attempts += other.Attempts
	m.OK += other.OK
	m.NXDomain += other.NXDomain
	m.Timeout += other.Timeout
	m.Temporary += other.Temporary
	m.Other += other.Other
	m.Skipped += other.Skipped

	// Ranging over a nil/empty map is a no-op, so no explicit guard is needed.
	for name, add := range other.PerUpstream {
		if m.PerUpstream == nil {
			m.PerUpstream = make(map[string]PrewarmDNSUpstreamMetrics, len(other.PerUpstream))
		}
		acc := m.PerUpstream[name]
		acc.Attempts += add.Attempts
		acc.OK += add.OK
		acc.NXDomain += add.NXDomain
		acc.Timeout += add.Timeout
		acc.Temporary += add.Temporary
		acc.Other += add.Other
		acc.Skipped += add.Skipped
		m.PerUpstream[name] = acc
	}
}
// TotalErrors returns the number of failed lookups: the sum of the
// NXDomain, Timeout, Temporary and Other counters. OK and Skipped are not
// counted as errors.
func (m PrewarmDNSMetrics) TotalErrors() int {
	failed := m.NXDomain
	failed += m.Timeout
	failed += m.Temporary
	failed += m.Other
	return failed
}
// FormatPerUpstream renders the per-upstream counters as a single line,
// one "key{attempts=… ok=… …}" group per upstream, sorted by key and
// separated by "; ". It returns "" when there is no per-upstream data.
func (m PrewarmDNSMetrics) FormatPerUpstream() string {
	if len(m.PerUpstream) == 0 {
		return ""
	}
	names := make([]string, 0, len(m.PerUpstream))
	for name := range m.PerUpstream {
		names = append(names, name)
	}
	// Map iteration order is random; sort for stable log output.
	sort.Strings(names)

	var b strings.Builder
	for i, name := range names {
		if i > 0 {
			b.WriteString("; ")
		}
		u := m.PerUpstream[name]
		fmt.Fprintf(&b, "%s{attempts=%d ok=%d nxdomain=%d timeout=%d temporary=%d other=%d skipped=%d}",
			name, u.Attempts, u.OK, u.NXDomain, u.Timeout, u.Temporary, u.Other, u.Skipped)
	}
	return b.String()
}
// PrewarmInput carries the settings for a single RunPrewarm call.
type PrewarmInput struct {
	// Mode is only echoed in the "prewarm start" log line.
	Mode string
	// Source is echoed in the start and summary log lines.
	Source string
	// RuntimeEnabled, when false, makes RunPrewarm merge the resolved IPs
	// into the dyn set itself via PrewarmDeps.ReadDynSet/ApplyDynSet.
	RuntimeEnabled bool
	// SmartDNSAddr is the only resolver handed to PrewarmDeps.DigA; a blank
	// value makes RunPrewarm fail immediately.
	SmartDNSAddr string
	// Wildcards are the base domains to prewarm; RunPrewarm trims and
	// de-duplicates them.
	Wildcards []string
	// AggressiveSubs enables prefixing each base domain with labels from Subs.
	AggressiveSubs bool
	// Subs are the subdomain labels used when AggressiveSubs is set.
	Subs []string
	// SubsPerBaseLimit caps how many Subs are applied per base domain;
	// zero or negative applies all of them.
	SubsPerBaseLimit int
	// Limit caps the sorted, expanded host list; zero or negative disables it.
	Limit int
	// Workers and TimeoutMS are explicit tuning values. When <= 0 they fall
	// back to EnvWorkers / EnvTimeoutMS and then to 24 workers / 1800 ms.
	// Workers is capped at 200; TimeoutMS is clamped to [200, 15000].
	Workers, TimeoutMS int
	// EnvWorkers and EnvTimeoutMS are the fallback values described above.
	EnvWorkers, EnvTimeoutMS int
	// MaxHostsLog caps the per-host "prewarm add" log lines; <= 0 means 200.
	MaxHostsLog int
	// WildcardMapPath is only used in the trace-truncation log message.
	WildcardMapPath string
}
// PrewarmDeps bundles the callbacks RunPrewarm relies on. Every field may
// be nil; the consequence of a nil value is noted per field.
type PrewarmDeps struct {
	// IsGoogleLike reports base domains that must not be expanded with
	// subdomain prefixes. nil means no domain is treated as Google-like.
	IsGoogleLike func(string) bool
	// EnsureRuntimeSet, if non-nil, is invoked once before lookups start.
	EnsureRuntimeSet func()
	// DigA presumably resolves A records for host against dnsList within
	// timeout, returning the IPs and lookup metrics. RunPrewarm passes the
	// SmartDNS address as the only dnsList entry. nil means no host resolves.
	DigA func(host string, dnsList []string, timeout time.Duration) ([]string, PrewarmDNSMetrics)
	// ReadDynSet returns the current dyn set contents; consulted only when
	// PrewarmInput.RuntimeEnabled is false. nil is treated as an empty set.
	ReadDynSet func() ([]string, error)
	// ApplyDynSet stores the merged IP list; consulted only when
	// PrewarmInput.RuntimeEnabled is false, in which case nil is an error.
	ApplyDynSet func([]string) error
	// Logf is an optional sink for progress lines; nil discards them.
	Logf func(message string)
}
// PrewarmResult summarizes a RunPrewarm call.
type PrewarmResult struct {
	// OK is false only for hard failures (blank SmartDNS address or a
	// failed dyn-set apply); skipped runs still report OK.
	OK bool
	// Message is the final summary or failure line, identical to what was logged.
	Message string
	// ExitCode is set to the resolved-host count on a completed run and left
	// at 0 otherwise. NOTE(review): confirm callers really expect a count
	// here rather than a conventional process exit status.
	ExitCode int
	// ResolvedHosts is the number of hosts that resolved to at least one IP.
	ResolvedHosts int
}
// RunPrewarm expands the configured wildcard domains into a host list,
// resolves every host through the SmartDNS address with a bounded worker
// pool, and reports aggregate statistics through deps.Logf and the result.
//
// Expansion: each trimmed, de-duplicated wildcard is kept as-is. When
// in.AggressiveSubs is set, each wildcard that deps.IsGoogleLike does not
// flag is also prefixed with up to in.SubsPerBaseLimit labels from in.Subs
// (all of them when the limit is <= 0). The expanded list is sorted and,
// when in.Limit > 0, cut to in.Limit entries.
//
// Tuning: the worker count falls back to in.EnvWorkers and then 24, is
// capped at 200 and never exceeds the number of hosts. The per-lookup
// timeout falls back to in.EnvTimeoutMS and then 1800 ms, clamped to
// [200, 15000] ms.
//
// When in.RuntimeEnabled is false, the resolved IPs are merged with the
// contents returned by deps.ReadDynSet and written back, sorted, through
// deps.ApplyDynSet. A read error is logged and the merge proceeds with
// whatever was returned. An apply error (including a nil ApplyDynSet)
// yields a non-OK result.
//
// A blank SmartDNS address is an error. An empty wildcard list or expanded
// list is a successful no-op. On a completed run, ExitCode and
// ResolvedHosts both carry the number of hosts that resolved to at least
// one IP.
func RunPrewarm(in PrewarmInput, deps PrewarmDeps) PrewarmResult {
	smartdnsAddr := strings.TrimSpace(in.SmartDNSAddr)
	if smartdnsAddr == "" {
		return PrewarmResult{OK: false, Message: "SmartDNS address is empty"}
	}
	wildcards := trimNonEmptyUnique(in.Wildcards)
	if len(wildcards) == 0 {
		msg := "prewarm skipped: wildcard list is empty"
		logPrewarm(deps.Logf, msg)
		return PrewarmResult{OK: true, Message: msg}
	}

	// Expand wildcards into concrete hosts.
	aggressive := in.AggressiveSubs
	subs := trimNonEmptyUnique(in.Subs)
	// Number of sub labels applied per base domain; a non-positive
	// SubsPerBaseLimit means "use every label". Loop-invariant, so computed once.
	maxSubs := len(subs)
	if in.SubsPerBaseLimit > 0 && in.SubsPerBaseLimit < maxSubs {
		maxSubs = in.SubsPerBaseLimit
	}
	domainSet := make(map[string]struct{}, len(wildcards)*(maxSubs+1))
	for _, base := range wildcards {
		domainSet[base] = struct{}{}
		// Google-like domains are never expanded with sub labels.
		if !aggressive || isGoogleLikeSafe(deps.IsGoogleLike, base) {
			continue
		}
		for _, sub := range subs[:maxSubs] {
			domainSet[sub+"."+base] = struct{}{}
		}
	}
	domains := make([]string, 0, len(domainSet))
	for d := range domainSet {
		domains = append(domains, d)
	}
	// Sort before applying Limit so the retained subset is deterministic.
	sort.Strings(domains)
	if in.Limit > 0 && len(domains) > in.Limit {
		domains = domains[:in.Limit]
	}
	if len(domains) == 0 {
		msg := "prewarm skipped: expanded wildcard list is empty"
		logPrewarm(deps.Logf, msg)
		return PrewarmResult{OK: true, Message: msg}
	}

	// Tuning knobs: explicit value, then env value, then built-in default.
	workers := in.Workers
	if workers <= 0 {
		workers = in.EnvWorkers
	}
	if workers <= 0 {
		workers = 24
	}
	if workers > 200 {
		workers = 200
	}
	// More workers than hosts would only park idle goroutines.
	if workers > len(domains) {
		workers = len(domains)
	}
	timeoutMS := in.TimeoutMS
	if timeoutMS <= 0 {
		timeoutMS = in.EnvTimeoutMS
	}
	if timeoutMS <= 0 {
		timeoutMS = 1800
	}
	if timeoutMS < 200 {
		timeoutMS = 200
	}
	if timeoutMS > 15000 {
		timeoutMS = 15000
	}
	timeout := time.Duration(timeoutMS) * time.Millisecond

	if deps.EnsureRuntimeSet != nil {
		deps.EnsureRuntimeSet()
	}
	logPrewarm(
		deps.Logf,
		fmt.Sprintf(
			"prewarm start: mode=%s source=%s runtime_nftset=%t smartdns=%s wildcard_domains=%d expanded=%d aggressive_subs=%t workers=%d timeout_ms=%d",
			strings.TrimSpace(in.Mode),
			strings.TrimSpace(in.Source),
			in.RuntimeEnabled,
			smartdnsAddr,
			len(wildcards),
			len(domains),
			aggressive,
			workers,
			timeoutMS,
		),
	)

	// Resolve all hosts with a fixed worker pool. Both channels are buffered
	// for every host, so neither the producer loop nor the workers block,
	// and each worker exits once jobs is drained and closed.
	type prewarmItem struct {
		host  string
		ips   []string
		stats PrewarmDNSMetrics
	}
	jobs := make(chan string, len(domains))
	results := make(chan prewarmItem, len(domains))
	for i := 0; i < workers; i++ {
		go func() {
			for host := range jobs {
				// Fresh dnsList per call in case DigA mutates its argument.
				ips, stats := safeDigA(deps.DigA, host, []string{smartdnsAddr}, timeout)
				results <- prewarmItem{host: host, ips: ips, stats: stats}
			}
		}()
	}
	for _, host := range domains {
		jobs <- host
	}
	close(jobs)

	// Aggregate exactly one result per host.
	resolvedHosts := 0
	totalIPs := 0
	errorHosts := 0
	stats := PrewarmDNSMetrics{}
	resolvedIPSet := map[string]struct{}{}
	loggedHosts := 0
	maxHostsLog := in.MaxHostsLog
	if maxHostsLog <= 0 {
		maxHostsLog = 200
	}
	for range domains {
		item := <-results
		stats.Merge(item.stats)
		if item.stats.TotalErrors() > 0 {
			errorHosts++
		}
		// Normalize once so the counters, the IP set and the log line all
		// agree on which entries are real addresses.
		ips := make([]string, 0, len(item.ips))
		for _, ip := range item.ips {
			if ip = strings.TrimSpace(ip); ip != "" {
				ips = append(ips, ip)
			}
		}
		if len(ips) == 0 {
			continue
		}
		resolvedHosts++
		totalIPs += len(ips)
		for _, ip := range ips {
			resolvedIPSet[ip] = struct{}{}
		}
		if loggedHosts < maxHostsLog {
			logPrewarm(deps.Logf, fmt.Sprintf("prewarm add: %s -> %s", item.host, strings.Join(ips, ", ")))
			loggedHosts++
		}
	}

	// Without the runtime set, merge the resolved IPs into the dyn set here.
	manualAdded := 0
	totalDynText := "n/a"
	if !in.RuntimeEnabled {
		existing, readErr := safeReadDynSet(deps.ReadDynSet)
		if readErr != nil {
			// Non-fatal: the merge proceeds with whatever was returned. It is
			// logged because, if ApplyDynSet replaces the whole set, IPs that
			// could not be read are not carried over.
			logPrewarm(deps.Logf, fmt.Sprintf("prewarm manual read failed: %v", readErr))
		}
		mergedSet := make(map[string]struct{}, len(existing)+len(resolvedIPSet))
		for _, ip := range existing {
			if ip = strings.TrimSpace(ip); ip != "" {
				mergedSet[ip] = struct{}{}
			}
		}
		for ip := range resolvedIPSet {
			if _, ok := mergedSet[ip]; !ok {
				manualAdded++
				mergedSet[ip] = struct{}{}
			}
		}
		merged := make([]string, 0, len(mergedSet))
		for ip := range mergedSet {
			merged = append(merged, ip)
		}
		// Deterministic order keeps ApplyDynSet input reproducible.
		sort.Strings(merged)
		totalDynText = fmt.Sprintf("%d", len(merged))
		if applyErr := safeApplyDynSet(deps.ApplyDynSet, merged); applyErr != nil {
			msg := fmt.Sprintf("prewarm manual apply failed: %v", applyErr)
			logPrewarm(deps.Logf, msg)
			return PrewarmResult{OK: false, Message: msg}
		}
		logPrewarm(
			deps.Logf,
			fmt.Sprintf("prewarm manual merge: existing=%d resolved=%d added=%d total_dyn=%d", len(existing), len(resolvedIPSet), manualAdded, len(merged)),
		)
	}

	// Only resolved hosts are traced, so truncation is measured against
	// resolvedHosts; unresolved hosts were never candidates for the trace.
	if resolvedHosts > loggedHosts {
		logPrewarm(
			deps.Logf,
			fmt.Sprintf(
				"prewarm add: trace truncated, omitted=%d hosts (full wildcard map: %s)",
				resolvedHosts-loggedHosts,
				strings.TrimSpace(in.WildcardMapPath),
			),
		)
	}
	msg := fmt.Sprintf(
		"prewarm done: source=%s expanded=%d resolved=%d total_ips=%d error_hosts=%d dns_attempts=%d dns_ok=%d dns_errors=%d manual_added=%d dyn_total=%s",
		strings.TrimSpace(in.Source),
		len(domains),
		resolvedHosts,
		totalIPs,
		errorHosts,
		stats.Attempts,
		stats.OK,
		stats.TotalErrors(),
		manualAdded,
		totalDynText,
	)
	logPrewarm(deps.Logf, msg)
	if perUpstream := stats.FormatPerUpstream(); perUpstream != "" {
		logPrewarm(deps.Logf, "prewarm dns upstreams: "+perUpstream)
	}
	return PrewarmResult{
		OK:            true,
		Message:       msg,
		ExitCode:      resolvedHosts,
		ResolvedHosts: resolvedHosts,
	}
}
// logPrewarm forwards msg to logf; a nil logf silently discards it.
func logPrewarm(logf func(string), msg string) {
	if logf == nil {
		return
	}
	logf(msg)
}
// safeDigA invokes dig with the given arguments, or returns no IPs and
// zero-valued metrics when dig is nil.
func safeDigA(
	dig func(host string, dnsList []string, timeout time.Duration) ([]string, PrewarmDNSMetrics),
	host string,
	dnsList []string,
	timeout time.Duration,
) ([]string, PrewarmDNSMetrics) {
	if dig != nil {
		return dig(host, dnsList, timeout)
	}
	var none PrewarmDNSMetrics
	return nil, none
}
// safeReadDynSet calls read and returns its result unchanged; a nil read
// yields (nil, nil), i.e. it is treated as an empty set.
func safeReadDynSet(read func() ([]string, error)) ([]string, error) {
	if read != nil {
		return read()
	}
	return nil, nil
}
// safeApplyDynSet passes ips to apply. Unlike the other safe* helpers, a
// nil callback is an error here, because silently skipping the write would
// lose the merged set.
func safeApplyDynSet(apply func([]string) error, ips []string) error {
	if apply != nil {
		return apply(ips)
	}
	return fmt.Errorf("apply dyn set callback is nil")
}
// isGoogleLikeSafe reports check(domain), treating a nil check as "not
// Google-like".
func isGoogleLikeSafe(check func(string) bool, domain string) bool {
	return check != nil && check(domain)
}
// trimNonEmptyUnique returns the whitespace-trimmed entries of in, dropping
// blanks and later duplicates while keeping first-seen order. It returns
// nil for an empty input, and a non-nil empty slice when every entry is
// blank.
func trimNonEmptyUnique(in []string) []string {
	if len(in) == 0 {
		return nil
	}
	out := make([]string, 0, len(in))
	seen := make(map[string]struct{}, len(in))
	for _, raw := range in {
		v := strings.TrimSpace(raw)
		if _, dup := seen[v]; v == "" || dup {
			continue
		}
		seen[v] = struct{}{}
		out = append(out, v)
	}
	return out
}