393 lines
8.8 KiB
Go
393 lines
8.8 KiB
Go
package dnscfg
|
|
|
|
import (
|
|
"fmt"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// PrewarmDNSUpstreamMetrics holds the DNS outcome counters recorded for one
// upstream during a prewarm run. PrewarmDNSMetrics keeps one value of this type
// per upstream name in its PerUpstream map.
//
// NOTE(review): the counters are filled in by the DigA callback, whose
// implementation is not visible here; the per-field meanings below are inferred
// from the field names and should be confirmed against it.
type PrewarmDNSUpstreamMetrics struct {
	// Attempts is the number of attempts recorded for this upstream.
	Attempts int
	// OK counts successful outcomes.
	OK int
	// NXDomain counts NXDOMAIN outcomes (presumably "name does not exist").
	NXDomain int
	// Timeout counts timed-out outcomes.
	Timeout int
	// Temporary counts temporary failures.
	Temporary int
	// Other counts failures that fit none of the categories above.
	Other int
	// Skipped counts skipped attempts.
	Skipped int
}
|
|
|
|
type PrewarmDNSMetrics struct {
|
|
Attempts int
|
|
OK int
|
|
NXDomain int
|
|
Timeout int
|
|
Temporary int
|
|
Other int
|
|
Skipped int
|
|
|
|
PerUpstream map[string]PrewarmDNSUpstreamMetrics
|
|
}
|
|
|
|
func (m *PrewarmDNSMetrics) Merge(other PrewarmDNSMetrics) {
|
|
m.Attempts += other.Attempts
|
|
m.OK += other.OK
|
|
m.NXDomain += other.NXDomain
|
|
m.Timeout += other.Timeout
|
|
m.Temporary += other.Temporary
|
|
m.Other += other.Other
|
|
m.Skipped += other.Skipped
|
|
if len(other.PerUpstream) == 0 {
|
|
return
|
|
}
|
|
if m.PerUpstream == nil {
|
|
m.PerUpstream = map[string]PrewarmDNSUpstreamMetrics{}
|
|
}
|
|
for upstream, src := range other.PerUpstream {
|
|
dst := m.PerUpstream[upstream]
|
|
dst.Attempts += src.Attempts
|
|
dst.OK += src.OK
|
|
dst.NXDomain += src.NXDomain
|
|
dst.Timeout += src.Timeout
|
|
dst.Temporary += src.Temporary
|
|
dst.Other += src.Other
|
|
dst.Skipped += src.Skipped
|
|
m.PerUpstream[upstream] = dst
|
|
}
|
|
}
|
|
|
|
func (m PrewarmDNSMetrics) TotalErrors() int {
|
|
return m.NXDomain + m.Timeout + m.Temporary + m.Other
|
|
}
|
|
|
|
func (m PrewarmDNSMetrics) FormatPerUpstream() string {
|
|
if len(m.PerUpstream) == 0 {
|
|
return ""
|
|
}
|
|
keys := make([]string, 0, len(m.PerUpstream))
|
|
for k := range m.PerUpstream {
|
|
keys = append(keys, k)
|
|
}
|
|
sort.Strings(keys)
|
|
parts := make([]string, 0, len(keys))
|
|
for _, k := range keys {
|
|
v := m.PerUpstream[k]
|
|
parts = append(parts, fmt.Sprintf("%s{attempts=%d ok=%d nxdomain=%d timeout=%d temporary=%d other=%d skipped=%d}", k, v.Attempts, v.OK, v.NXDomain, v.Timeout, v.Temporary, v.Other, v.Skipped))
|
|
}
|
|
return strings.Join(parts, "; ")
|
|
}
|
|
|
|
// PrewarmInput carries the settings for one RunPrewarm call.
type PrewarmInput struct {
	// Mode is a free-form label; RunPrewarm only echoes it in the start log line.
	Mode string
	// Source is a free-form label echoed in the start and summary log lines.
	Source string
	// RuntimeEnabled, when false, makes RunPrewarm merge the resolved IPs into
	// the set obtained from ReadDynSet and pass the result to ApplyDynSet. When
	// true that manual merge is skipped.
	RuntimeEnabled bool
	// SmartDNSAddr is the DNS server passed to DigA for every lookup. It is
	// trimmed; an empty value makes RunPrewarm fail.
	SmartDNSAddr string
	// Wildcards lists the base domains to resolve. Entries are trimmed and
	// de-duplicated; an empty list turns the run into a successful no-op.
	Wildcards []string
	// AggressiveSubs enables expanding each base domain with the Subs labels
	// ("sub.base"), except for bases flagged by PrewarmDeps.IsGoogleLike.
	AggressiveSubs bool
	// Subs lists the subdomain labels used when AggressiveSubs is set.
	Subs []string
	// SubsPerBaseLimit caps how many leading Subs entries are used per base
	// domain. Values <= 0 mean no cap.
	SubsPerBaseLimit int
	// Limit caps the size of the sorted, expanded host list. Values <= 0 mean
	// no cap.
	Limit int
	// Workers is the number of concurrent resolvers. When <= 0, EnvWorkers is
	// used instead, then 24; the result is capped at 200.
	Workers int
	// TimeoutMS is the per-lookup timeout in milliseconds. When <= 0,
	// EnvTimeoutMS is used instead, then 1800; the result is clamped to
	// [200, 15000].
	TimeoutMS int
	// EnvWorkers is the fallback for Workers (presumably read from the
	// environment by the caller).
	EnvWorkers int
	// EnvTimeoutMS is the fallback for TimeoutMS (presumably read from the
	// environment by the caller).
	EnvTimeoutMS int
	// MaxHostsLog limits the number of per-host "prewarm add" log lines. Values
	// <= 0 fall back to 200.
	MaxHostsLog int
	// WildcardMapPath is only mentioned in the "trace truncated" log line as
	// the place to find the full host map; RunPrewarm never opens it.
	WildcardMapPath string
}
|
|
|
|
type PrewarmDeps struct {
|
|
IsGoogleLike func(string) bool
|
|
EnsureRuntimeSet func()
|
|
DigA func(host string, dnsList []string, timeout time.Duration) ([]string, PrewarmDNSMetrics)
|
|
ReadDynSet func() ([]string, error)
|
|
ApplyDynSet func([]string) error
|
|
Logf func(message string)
|
|
}
|
|
|
|
// PrewarmResult is the outcome reported by RunPrewarm.
type PrewarmResult struct {
	// OK is false when the run could not be carried out (empty SmartDNS address
	// or a failed ApplyDynSet) and true otherwise, including skipped runs.
	OK bool
	// Message is the failure reason, the skip reason, or the final summary line.
	Message string
	// ExitCode is set by RunPrewarm to the same value as ResolvedHosts on a
	// completed run and left at 0 otherwise.
	// NOTE(review): a host count is an unusual exit code — confirm that callers
	// rely on this before changing it.
	ExitCode int
	// ResolvedHosts is the number of hosts for which at least one address was
	// returned.
	ResolvedHosts int
}
|
|
|
|
func RunPrewarm(in PrewarmInput, deps PrewarmDeps) PrewarmResult {
|
|
smartdnsAddr := strings.TrimSpace(in.SmartDNSAddr)
|
|
if smartdnsAddr == "" {
|
|
return PrewarmResult{OK: false, Message: "SmartDNS address is empty"}
|
|
}
|
|
|
|
wildcards := trimNonEmptyUnique(in.Wildcards)
|
|
if len(wildcards) == 0 {
|
|
msg := "prewarm skipped: wildcard list is empty"
|
|
logPrewarm(deps.Logf, msg)
|
|
return PrewarmResult{OK: true, Message: msg}
|
|
}
|
|
|
|
aggressive := in.AggressiveSubs
|
|
subs := trimNonEmptyUnique(in.Subs)
|
|
subsPerBaseLimit := in.SubsPerBaseLimit
|
|
if subsPerBaseLimit < 0 {
|
|
subsPerBaseLimit = 0
|
|
}
|
|
|
|
domainSet := make(map[string]struct{}, len(wildcards)*(len(subs)+1))
|
|
for _, d := range wildcards {
|
|
domainSet[d] = struct{}{}
|
|
if !aggressive || isGoogleLikeSafe(deps.IsGoogleLike, d) {
|
|
continue
|
|
}
|
|
maxSubs := len(subs)
|
|
if subsPerBaseLimit > 0 && subsPerBaseLimit < maxSubs {
|
|
maxSubs = subsPerBaseLimit
|
|
}
|
|
for i := 0; i < maxSubs; i++ {
|
|
domainSet[subs[i]+"."+d] = struct{}{}
|
|
}
|
|
}
|
|
|
|
domains := make([]string, 0, len(domainSet))
|
|
for d := range domainSet {
|
|
domains = append(domains, d)
|
|
}
|
|
sort.Strings(domains)
|
|
if in.Limit > 0 && len(domains) > in.Limit {
|
|
domains = domains[:in.Limit]
|
|
}
|
|
if len(domains) == 0 {
|
|
msg := "prewarm skipped: expanded wildcard list is empty"
|
|
logPrewarm(deps.Logf, msg)
|
|
return PrewarmResult{OK: true, Message: msg}
|
|
}
|
|
|
|
workers := in.Workers
|
|
if workers <= 0 {
|
|
workers = in.EnvWorkers
|
|
if workers <= 0 {
|
|
workers = 24
|
|
}
|
|
}
|
|
if workers < 1 {
|
|
workers = 1
|
|
}
|
|
if workers > 200 {
|
|
workers = 200
|
|
}
|
|
|
|
timeoutMS := in.TimeoutMS
|
|
if timeoutMS <= 0 {
|
|
timeoutMS = in.EnvTimeoutMS
|
|
if timeoutMS <= 0 {
|
|
timeoutMS = 1800
|
|
}
|
|
}
|
|
if timeoutMS < 200 {
|
|
timeoutMS = 200
|
|
}
|
|
if timeoutMS > 15000 {
|
|
timeoutMS = 15000
|
|
}
|
|
timeout := time.Duration(timeoutMS) * time.Millisecond
|
|
|
|
if deps.EnsureRuntimeSet != nil {
|
|
deps.EnsureRuntimeSet()
|
|
}
|
|
|
|
logPrewarm(
|
|
deps.Logf,
|
|
fmt.Sprintf(
|
|
"prewarm start: mode=%s source=%s runtime_nftset=%t smartdns=%s wildcard_domains=%d expanded=%d aggressive_subs=%t workers=%d timeout_ms=%d",
|
|
strings.TrimSpace(in.Mode),
|
|
strings.TrimSpace(in.Source),
|
|
in.RuntimeEnabled,
|
|
smartdnsAddr,
|
|
len(wildcards),
|
|
len(domains),
|
|
aggressive,
|
|
workers,
|
|
timeoutMS,
|
|
),
|
|
)
|
|
|
|
type prewarmItem struct {
|
|
host string
|
|
ips []string
|
|
stats PrewarmDNSMetrics
|
|
}
|
|
|
|
jobs := make(chan string, len(domains))
|
|
results := make(chan prewarmItem, len(domains))
|
|
for i := 0; i < workers; i++ {
|
|
go func() {
|
|
for host := range jobs {
|
|
ips, stats := safeDigA(deps.DigA, host, []string{smartdnsAddr}, timeout)
|
|
results <- prewarmItem{host: host, ips: ips, stats: stats}
|
|
}
|
|
}()
|
|
}
|
|
for _, host := range domains {
|
|
jobs <- host
|
|
}
|
|
close(jobs)
|
|
|
|
resolvedHosts := 0
|
|
totalIPs := 0
|
|
errorHosts := 0
|
|
stats := PrewarmDNSMetrics{}
|
|
resolvedIPSet := map[string]struct{}{}
|
|
loggedHosts := 0
|
|
maxHostsLog := in.MaxHostsLog
|
|
if maxHostsLog <= 0 {
|
|
maxHostsLog = 200
|
|
}
|
|
|
|
for i := 0; i < len(domains); i++ {
|
|
item := <-results
|
|
stats.Merge(item.stats)
|
|
if item.stats.TotalErrors() > 0 {
|
|
errorHosts++
|
|
}
|
|
if len(item.ips) == 0 {
|
|
continue
|
|
}
|
|
resolvedHosts++
|
|
totalIPs += len(item.ips)
|
|
for _, ip := range item.ips {
|
|
if strings.TrimSpace(ip) != "" {
|
|
resolvedIPSet[ip] = struct{}{}
|
|
}
|
|
}
|
|
if loggedHosts < maxHostsLog {
|
|
logPrewarm(deps.Logf, fmt.Sprintf("prewarm add: %s -> %s", item.host, strings.Join(item.ips, ", ")))
|
|
loggedHosts++
|
|
}
|
|
}
|
|
|
|
manualAdded := 0
|
|
totalDynText := "n/a"
|
|
if !in.RuntimeEnabled {
|
|
existing, _ := safeReadDynSet(deps.ReadDynSet)
|
|
mergedSet := make(map[string]struct{}, len(existing)+len(resolvedIPSet))
|
|
for _, ip := range existing {
|
|
if strings.TrimSpace(ip) != "" {
|
|
mergedSet[ip] = struct{}{}
|
|
}
|
|
}
|
|
for ip := range resolvedIPSet {
|
|
if _, ok := mergedSet[ip]; !ok {
|
|
manualAdded++
|
|
}
|
|
mergedSet[ip] = struct{}{}
|
|
}
|
|
merged := make([]string, 0, len(mergedSet))
|
|
for ip := range mergedSet {
|
|
merged = append(merged, ip)
|
|
}
|
|
totalDynText = fmt.Sprintf("%d", len(merged))
|
|
if err := safeApplyDynSet(deps.ApplyDynSet, merged); err != nil {
|
|
msg := fmt.Sprintf("prewarm manual apply failed: %v", err)
|
|
logPrewarm(deps.Logf, msg)
|
|
return PrewarmResult{OK: false, Message: msg}
|
|
}
|
|
logPrewarm(
|
|
deps.Logf,
|
|
fmt.Sprintf("prewarm manual merge: existing=%d resolved=%d added=%d total_dyn=%d", len(existing), len(resolvedIPSet), manualAdded, len(merged)),
|
|
)
|
|
}
|
|
|
|
if len(domains) > loggedHosts {
|
|
logPrewarm(
|
|
deps.Logf,
|
|
fmt.Sprintf(
|
|
"prewarm add: trace truncated, omitted=%d hosts (full wildcard map: %s)",
|
|
len(domains)-loggedHosts,
|
|
strings.TrimSpace(in.WildcardMapPath),
|
|
),
|
|
)
|
|
}
|
|
|
|
msg := fmt.Sprintf(
|
|
"prewarm done: source=%s expanded=%d resolved=%d total_ips=%d error_hosts=%d dns_attempts=%d dns_ok=%d dns_errors=%d manual_added=%d dyn_total=%s",
|
|
strings.TrimSpace(in.Source),
|
|
len(domains),
|
|
resolvedHosts,
|
|
totalIPs,
|
|
errorHosts,
|
|
stats.Attempts,
|
|
stats.OK,
|
|
stats.TotalErrors(),
|
|
manualAdded,
|
|
totalDynText,
|
|
)
|
|
logPrewarm(deps.Logf, msg)
|
|
if perUpstream := stats.FormatPerUpstream(); perUpstream != "" {
|
|
logPrewarm(deps.Logf, "prewarm dns upstreams: "+perUpstream)
|
|
}
|
|
|
|
return PrewarmResult{
|
|
OK: true,
|
|
Message: msg,
|
|
ExitCode: resolvedHosts,
|
|
ResolvedHosts: resolvedHosts,
|
|
}
|
|
}
|
|
|
|
// logPrewarm forwards msg to logf. A nil logf is allowed and turns the call
// into a no-op, so callers never have to check the callback themselves.
func logPrewarm(logf func(string), msg string) {
	if logf == nil {
		return
	}
	logf(msg)
}
|
|
|
|
func safeDigA(
|
|
dig func(host string, dnsList []string, timeout time.Duration) ([]string, PrewarmDNSMetrics),
|
|
host string,
|
|
dnsList []string,
|
|
timeout time.Duration,
|
|
) ([]string, PrewarmDNSMetrics) {
|
|
if dig == nil {
|
|
return nil, PrewarmDNSMetrics{}
|
|
}
|
|
return dig(host, dnsList, timeout)
|
|
}
|
|
|
|
// safeReadDynSet calls read and returns its result unchanged. A nil read is
// treated as an empty set: it yields (nil, nil).
func safeReadDynSet(read func() ([]string, error)) ([]string, error) {
	if read != nil {
		return read()
	}
	return nil, nil
}
|
|
|
|
// safeApplyDynSet calls apply with ips and returns its error unchanged. Unlike
// the other safe* helpers, a nil callback is not silently tolerated: it is
// reported as an error, because the set could not be applied.
func safeApplyDynSet(apply func([]string) error, ips []string) error {
	if apply != nil {
		return apply(ips)
	}
	return fmt.Errorf("apply dyn set callback is nil")
}
|
|
|
|
// isGoogleLikeSafe returns check(domain). A nil check means no domain is
// considered google-like, so the result is false.
func isGoogleLikeSafe(check func(string) bool, domain string) bool {
	return check != nil && check(domain)
}
|
|
|
|
// trimNonEmptyUnique trims surrounding whitespace from every element of in,
// drops elements that end up empty, and removes duplicates while keeping the
// order of first occurrence. The input slice is not modified.
//
// It returns nil for empty input and a non-nil (possibly empty) slice
// otherwise.
func trimNonEmptyUnique(in []string) []string {
	if len(in) == 0 {
		return nil
	}
	seen := make(map[string]struct{}, len(in))
	out := make([]string, 0, len(in))
	for _, raw := range in {
		v := strings.TrimSpace(raw)
		if v == "" {
			continue
		}
		if _, dup := seen[v]; dup {
			continue
		}
		seen[v] = struct{}{}
		out = append(out, v)
	}
	return out
}
|