Files
elmprodvpn/selective-vpn-api/app/resolver/domain_cache.go

650 lines
16 KiB
Go

package resolver
import (
"encoding/json"
"fmt"
"os"
"sort"
"strings"
)
// DomainCacheSource names the bucket of a DomainCacheRecord an operation acts
// on. Each record keeps one independent entry per source.
type DomainCacheSource string
const (
// DomainCacheSourceDirect selects the Direct entry of a record. Lookups in this
// file also fall back to Direct for any source value other than "wildcard".
DomainCacheSourceDirect DomainCacheSource = "direct"
// DomainCacheSourceWildcard selects the Wildcard entry of a record.
DomainCacheSourceWildcard DomainCacheSource = "wildcard"
)
// Domain state names, score bounds, cache format version and default
// quarantine durations used by the domain cache.
const (
// DomainStateActive is the state DomainStateFromScore assigns to scores >= 20.
DomainStateActive = "active"
// DomainStateStable is assigned to scores in [5, 20).
DomainStateStable = "stable"
// DomainStateSuspect is assigned to scores in [-10, 5).
DomainStateSuspect = "suspect"
// DomainStateQuarantine is assigned to scores in [-30, -10).
DomainStateQuarantine = "quarantine"
// DomainStateHardQuar is assigned to scores below -30.
DomainStateHardQuar = "hard_quarantine"
// DomainScoreMin is the lower bound enforced by ClampDomainScore.
DomainScoreMin = -100
// DomainScoreMax is the upper bound enforced by ClampDomainScore.
DomainScoreMax = 100
// DomainCacheVersion is the version number written by NewDomainCacheState and ToMap.
DomainCacheVersion = 4
// DefaultQuarantineTTL is the fallback for RESOLVE_QUARANTINE_TTL_SEC
// (24*3600, i.e. one day if the unit is seconds as the key name suggests).
DefaultQuarantineTTL = 24 * 3600
// DefaultHardQuarTTL is the fallback for RESOLVE_HARD_QUARANTINE_TTL_SEC
// (7*24*3600, i.e. seven days if the unit is seconds).
DefaultHardQuarTTL = 7 * 24 * 3600
)
// EnvInt returns the integer setting named key, or def when it is not
// available. The default implementation ignores key and always returns def; it
// is a package variable so it can be replaced (presumably by an
// environment-backed lookup — confirm against the code that assigns it).
var EnvInt = func(key string, def int) int { return def }
// NXHardQuarantineEnabled reports whether an NXDOMAIN error may leave a domain
// in the hard-quarantine state; when it returns false SetErrorWithStats
// downgrades such entries to plain quarantine. The default implementation
// returns false; it is a package variable so it can be replaced.
var NXHardQuarantineEnabled = func() bool { return false }
// DomainCacheEntry is the cached state of one domain within one source bucket:
// the last successful answer, the last recorded error, and the score, state
// and quarantine bookkeeping derived from them.
type DomainCacheEntry struct {
// IPs is the address list stored by the last successful Set.
IPs []string `json:"ips,omitempty"`
// LastResolved is the `now` value passed to the last successful Set; 0 means
// no successful resolution is recorded.
LastResolved int `json:"last_resolved,omitempty"`
// LastErrorKind is the string form of the last recorded DNSErrorKind.
LastErrorKind string `json:"last_error_kind,omitempty"`
// LastErrorAt is the `now` value of the last recorded error; Set resets it to 0.
LastErrorAt int `json:"last_error_at,omitempty"`
// Score is the health score, kept within [DomainScoreMin, DomainScoreMax].
Score int `json:"score,omitempty"`
// State is one of the DomainState* names, or empty when unset.
State string `json:"state,omitempty"`
// QuarantineUntil is the timestamp at which quarantine ends; 0 means none.
QuarantineUntil int `json:"quarantine_until,omitempty"`
}
// DomainCacheRecord holds the per-source entries of a single host. Either
// pointer may be nil when nothing is cached for that source.
type DomainCacheRecord struct {
// Direct is the entry for DomainCacheSourceDirect.
Direct *DomainCacheEntry `json:"direct,omitempty"`
// Wildcard is the entry for DomainCacheSourceWildcard.
Wildcard *DomainCacheEntry `json:"wildcard,omitempty"`
}
// DomainCacheState is the whole domain cache as it is serialized to JSON.
type DomainCacheState struct {
// Version is the format version of the cache.
Version int `json:"version"`
// Domains maps a host name to its record. The methods in this file look hosts
// up and store them under their lower-cased, whitespace-trimmed form.
Domains map[string]DomainCacheRecord `json:"domains"`
}
// NewDomainCacheState returns an empty cache stamped with DomainCacheVersion
// and carrying a non-nil, ready-to-use Domains map.
func NewDomainCacheState() DomainCacheState {
	var st DomainCacheState
	st.Version = DomainCacheVersion
	st.Domains = make(map[string]DomainCacheRecord)
	return st
}
// NormalizeCacheIPs returns a cleaned copy of raw. Every value is trimmed of
// surrounding whitespace; empty values and values for which IsPrivateIPv4
// reports true are dropped; duplicates are removed; the result is sorted
// lexicographically. The input slice is not modified and the returned slice is
// never nil.
func NormalizeCacheIPs(raw []string) []string {
	result := make([]string, 0, len(raw))
	known := make(map[string]struct{})
	for i := range raw {
		addr := strings.TrimSpace(raw[i])
		if addr == "" {
			continue
		}
		if IsPrivateIPv4(addr) {
			continue
		}
		if _, dup := known[addr]; !dup {
			known[addr] = struct{}{}
			result = append(result, addr)
		}
	}
	sort.Strings(result)
	return result
}
// NormalizeCacheErrorKind maps raw onto one of the known DNS error kinds
// (NXDomain, Timeout, Temporary, Other). The comparison is done on the
// trimmed, lower-cased form of raw. The boolean is false, and the kind empty,
// when raw matches none of them.
func NormalizeCacheErrorKind(raw string) (DNSErrorKind, bool) {
	needle := strings.ToLower(strings.TrimSpace(raw))
	known := [...]DNSErrorKind{
		DNSErrorNXDomain,
		DNSErrorTimeout,
		DNSErrorTemporary,
		DNSErrorOther,
	}
	for _, kind := range known {
		if needle == string(kind) {
			return kind, true
		}
	}
	return "", false
}
// NormalizeDomainCacheEntry returns a sanitized copy of in, or nil when in is
// nil or carries nothing worth keeping.
//
// The success pair (IPs, LastResolved) is kept only when the normalized IP
// list is non-empty and LastResolved is positive. The error pair
// (LastErrorKind, LastErrorAt) is kept only when the kind is recognised and
// LastErrorAt is positive. Score is clamped, State is normalized against the
// clamped score, and QuarantineUntil is kept when positive. An entry left with
// no success pair, no error pair, a zero score and no quarantine deadline is
// dropped (nil is returned).
func NormalizeDomainCacheEntry(in *DomainCacheEntry) *DomainCacheEntry {
	if in == nil {
		return nil
	}

	var out DomainCacheEntry

	if addrs := NormalizeCacheIPs(in.IPs); in.LastResolved > 0 && len(addrs) > 0 {
		out.IPs, out.LastResolved = addrs, in.LastResolved
	}

	if kind, known := NormalizeCacheErrorKind(in.LastErrorKind); known && in.LastErrorAt > 0 {
		out.LastErrorKind, out.LastErrorAt = string(kind), in.LastErrorAt
	}

	out.Score = ClampDomainScore(in.Score)
	// NormalizeDomainState yields "" for "no state", which is also the zero value.
	out.State = NormalizeDomainState(in.State, out.Score)

	if in.QuarantineUntil > 0 {
		out.QuarantineUntil = in.QuarantineUntil
	}

	hasTimestamps := out.LastResolved > 0 || out.LastErrorAt > 0
	hasBookkeeping := out.Score != 0 || out.QuarantineUntil > 0
	if !hasTimestamps && !hasBookkeeping {
		return nil
	}
	return &out
}
// parseAnyStringSlice converts a loosely typed value into a string slice.
// A []string is copied; a []any yields its string elements in order, silently
// skipping elements of any other type; every other input yields nil.
func parseAnyStringSlice(raw any) []string {
	if strs, ok := raw.([]string); ok {
		// Copy so the caller cannot alias the input's backing array.
		return append([]string(nil), strs...)
	}

	items, ok := raw.([]any)
	if !ok {
		return nil
	}

	collected := make([]string, 0, len(items))
	for _, item := range items {
		str, isString := item.(string)
		if !isString {
			continue
		}
		collected = append(collected, str)
	}
	return collected
}
// parseLegacyDomainCacheEntry decodes one value of the legacy flat cache
// layout. raw must be a JSON object (map[string]any) with an "ips" list that
// is non-empty after normalization and a positive "last_resolved" value as
// accepted by parseAnyInt. On success it returns an entry holding only IPs and
// LastResolved; otherwise it returns the zero entry and false.
func parseLegacyDomainCacheEntry(raw any) (DomainCacheEntry, bool) {
	var none DomainCacheEntry

	fields, isObject := raw.(map[string]any)
	if !isObject {
		return none, false
	}

	addrs := NormalizeCacheIPs(parseAnyStringSlice(fields["ips"]))
	if len(addrs) == 0 {
		return none, false
	}

	resolvedAt, valid := parseAnyInt(fields["last_resolved"])
	if valid && resolvedAt > 0 {
		return DomainCacheEntry{IPs: addrs, LastResolved: resolvedAt}, true
	}
	return none, false
}
// LoadDomainCacheState reads the JSON domain cache at path and returns it in
// normalized form. It never reports an error: an unreadable or empty file
// yields a fresh empty state, and so does a file that is not valid JSON (that
// case is reported through logf when logf is non-nil).
//
// Two layouts are accepted:
//   - the current split layout, recognised by a non-null "domains" object;
//     every direct/wildcard entry is passed through NormalizeDomainCacheEntry
//     and records left with no entry are dropped;
//   - the legacy flat layout, a top-level object mapping host to an object
//     with "ips" and "last_resolved"; valid entries are migrated into the
//     direct bucket and the number migrated is logged through logf.
//
// Host keys are lower-cased and trimmed, and empty hosts are skipped. Keys are
// visited in sorted order so that, when several raw keys normalize to the same
// host, the outcome is deterministic (the key that sorts last wins) instead of
// depending on Go's randomized map iteration order. The returned state always
// carries DomainCacheVersion, whatever version the file declared.
func LoadDomainCacheState(path string, logf func(string, ...any)) DomainCacheState {
	data, err := os.ReadFile(path)
	if err != nil || len(data) == 0 {
		return NewDomainCacheState()
	}

	// Current split layout.
	var st DomainCacheState
	if err := json.Unmarshal(data, &st); err == nil && st.Domains != nil {
		keys := make([]string, 0, len(st.Domains))
		for key := range st.Domains {
			keys = append(keys, key)
		}
		sort.Strings(keys)

		normalized := NewDomainCacheState()
		for _, key := range keys {
			host := strings.TrimSpace(strings.ToLower(key))
			if host == "" {
				continue
			}
			rec := st.Domains[key]
			nrec := DomainCacheRecord{
				Direct:   NormalizeDomainCacheEntry(rec.Direct),
				Wildcard: NormalizeDomainCacheEntry(rec.Wildcard),
			}
			if nrec.Direct != nil || nrec.Wildcard != nil {
				normalized.Domains[host] = nrec
			}
		}
		return normalized
	}

	// Legacy flat layout: host -> {"ips": [...], "last_resolved": N}.
	var legacy map[string]any
	if err := json.Unmarshal(data, &legacy); err != nil {
		if logf != nil {
			logf("domain-cache: invalid json at %s, ignore", path)
		}
		return NewDomainCacheState()
	}

	keys := make([]string, 0, len(legacy))
	for key := range legacy {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	out := NewDomainCacheState()
	migrated := 0
	for _, key := range keys {
		host := strings.TrimSpace(strings.ToLower(key))
		// "version" and "domains" are structural keys of the split layout, not hosts.
		if host == "" || host == "version" || host == "domains" {
			continue
		}
		entry, ok := parseLegacyDomainCacheEntry(legacy[key])
		if !ok {
			continue
		}
		rec := out.Domains[host]
		rec.Direct = &entry
		out.Domains[host] = rec
		migrated++
	}
	if logf != nil && migrated > 0 {
		logf("domain-cache: migrated legacy entries=%d into split cache (direct bucket)", migrated)
	}
	return out
}
// Get returns the cached addresses for domain in the given source bucket when
// the entry is still fresh, i.e. now-LastResolved does not exceed ttl. The
// domain is matched on its lower-cased, trimmed form, and any source other
// than DomainCacheSourceWildcard reads the direct bucket. The boolean is false
// when there is no such entry, it has never been resolved, it is older than
// ttl, or its normalized address list is empty.
//
// NOTE(review): an entry whose LastResolved lies after now has a negative age
// and therefore passes the freshness check for any non-negative ttl; GetStale
// rejects that case. Confirm this difference is intended.
func (s DomainCacheState) Get(domain string, source DomainCacheSource, now, ttl int) ([]string, bool) {
	key := strings.TrimSpace(strings.ToLower(domain))
	rec, found := s.Domains[key]
	if !found {
		return nil, false
	}

	entry := rec.Direct
	if source == DomainCacheSourceWildcard {
		entry = rec.Wildcard
	}

	if entry == nil || entry.LastResolved <= 0 || now-entry.LastResolved > ttl {
		return nil, false
	}

	if addrs := NormalizeCacheIPs(entry.IPs); len(addrs) > 0 {
		return addrs, true
	}
	return nil, false
}
// GetNegative reports a still-valid cached error for domain in the given
// source bucket. It returns the error kind, its age (now-LastErrorAt) and
// true when the entry has a recognised error kind whose age is non-negative
// and does not exceed the TTL chosen for that kind: nxTTL for NXDomain,
// timeoutTTL for Timeout, temporaryTTL for Temporary, otherTTL for Other.
// A TTL of zero or less disables negative caching for that kind. In every
// other case it returns ("", 0, false).
func (s DomainCacheState) GetNegative(domain string, source DomainCacheSource, now, nxTTL, timeoutTTL, temporaryTTL, otherTTL int) (DNSErrorKind, int, bool) {
	key := strings.TrimSpace(strings.ToLower(domain))
	rec, found := s.Domains[key]
	if !found {
		return "", 0, false
	}

	entry := rec.Direct
	if source == DomainCacheSourceWildcard {
		entry = rec.Wildcard
	}
	if entry == nil || entry.LastErrorAt <= 0 {
		return "", 0, false
	}

	kind, known := NormalizeCacheErrorKind(entry.LastErrorKind)
	if !known {
		return "", 0, false
	}

	elapsed := now - entry.LastErrorAt
	if elapsed < 0 {
		// Error timestamp lies in the future relative to now: do not trust it.
		return "", 0, false
	}

	var limit int
	switch kind {
	case DNSErrorNXDomain:
		limit = nxTTL
	case DNSErrorTimeout:
		limit = timeoutTTL
	case DNSErrorTemporary:
		limit = temporaryTTL
	case DNSErrorOther:
		limit = otherTTL
	}

	if limit > 0 && elapsed <= limit {
		return kind, elapsed, true
	}
	return "", 0, false
}
// GetStoredIPs returns the normalized address list stored for domain in the
// given source bucket, regardless of its age. It returns nil when the domain
// or the bucket entry does not exist.
func (s DomainCacheState) GetStoredIPs(domain string, source DomainCacheSource) []string {
	key := strings.TrimSpace(strings.ToLower(domain))
	if rec, found := s.Domains[key]; found {
		if entry := GetCacheEntryBySource(rec, source); entry != nil {
			return NormalizeCacheIPs(entry.IPs)
		}
	}
	return nil
}
// GetLastErrorKind returns the kind of the last error recorded for domain in
// the given source bucket, regardless of its age. The boolean is false when
// the domain or entry is missing, no error timestamp is recorded, or the
// stored kind is not recognised.
func (s DomainCacheState) GetLastErrorKind(domain string, source DomainCacheSource) (DNSErrorKind, bool) {
	key := strings.TrimSpace(strings.ToLower(domain))
	if rec, found := s.Domains[key]; found {
		entry := GetCacheEntryBySource(rec, source)
		if entry != nil && entry.LastErrorAt > 0 {
			return NormalizeCacheErrorKind(entry.LastErrorKind)
		}
	}
	return "", false
}
// GetQuarantine reports whether domain is currently quarantined in the given
// source bucket, i.e. its QuarantineUntil is set and still after now. On a hit
// it returns the entry's normalized state (DomainStateQuarantine when the
// state cannot be determined), the age of the last error (now-LastErrorAt, or
// 0 when no error timestamp is recorded) and true. Otherwise it returns
// ("", 0, false).
func (s DomainCacheState) GetQuarantine(domain string, source DomainCacheSource, now int) (string, int, bool) {
	key := strings.TrimSpace(strings.ToLower(domain))
	rec, found := s.Domains[key]
	if !found {
		return "", 0, false
	}

	entry := GetCacheEntryBySource(rec, source)
	if entry == nil {
		return "", 0, false
	}
	if deadline := entry.QuarantineUntil; deadline <= 0 || deadline <= now {
		return "", 0, false
	}

	label := NormalizeDomainState(entry.State, entry.Score)
	if label == "" {
		label = DomainStateQuarantine
	}

	if entry.LastErrorAt <= 0 {
		return label, 0, true
	}
	return label, now - entry.LastErrorAt, true
}
// GetStale returns previously resolved addresses for domain in the given
// source bucket as long as they are no older than maxAge. It returns the
// normalized addresses, their age (now-LastResolved) and true on a hit. It
// returns (nil, 0, false) when maxAge is not positive, the entry is missing or
// was never resolved, the age is negative or greater than maxAge, or the
// normalized address list is empty.
func (s DomainCacheState) GetStale(domain string, source DomainCacheSource, now, maxAge int) ([]string, int, bool) {
	if maxAge <= 0 {
		return nil, 0, false
	}

	key := strings.TrimSpace(strings.ToLower(domain))
	rec, found := s.Domains[key]
	if !found {
		return nil, 0, false
	}

	entry := GetCacheEntryBySource(rec, source)
	if entry == nil || entry.LastResolved <= 0 {
		return nil, 0, false
	}

	elapsed := now - entry.LastResolved
	if withinWindow := elapsed >= 0 && elapsed <= maxAge; !withinWindow {
		return nil, 0, false
	}

	addrs := NormalizeCacheIPs(entry.IPs)
	if len(addrs) == 0 {
		return nil, 0, false
	}
	return addrs, elapsed, true
}
// Set records a successful resolution of domain for the given source bucket.
// It is a no-op when the normalized host is empty, now is not positive, or ips
// is empty after normalization.
//
// The bucket entry is replaced by a new one holding the normalized addresses
// and now as LastResolved. Any previous error and quarantine deadline are
// cleared. The score is the previous score (0 if there was no entry) plus
// EnvInt("RESOLVE_DOMAIN_SCORE_OK", 8), clamped, and the state is derived from
// that score. The Domains map is created on demand.
func (s *DomainCacheState) Set(domain string, source DomainCacheSource, ips []string, now int) {
	key := strings.TrimSpace(strings.ToLower(domain))
	if key == "" || now <= 0 {
		return
	}
	addrs := NormalizeCacheIPs(ips)
	if len(addrs) == 0 {
		return
	}

	if s.Domains == nil {
		s.Domains = make(map[string]DomainCacheRecord)
	}
	rec := s.Domains[key]

	base := 0
	if prev := GetCacheEntryBySource(rec, source); prev != nil {
		base = prev.Score
	}
	score := ClampDomainScore(base + EnvInt("RESOLVE_DOMAIN_SCORE_OK", 8))

	// Error fields and QuarantineUntil are left at their zero values on purpose.
	fresh := &DomainCacheEntry{
		IPs:          addrs,
		LastResolved: now,
		Score:        score,
		State:        DomainStateFromScore(score),
	}

	if source == DomainCacheSourceWildcard {
		rec.Wildcard = fresh
	} else {
		rec.Direct = fresh
	}
	s.Domains[key] = rec
}
// GetCacheEntryBySource returns the entry of rec that belongs to source: the
// Wildcard entry for DomainCacheSourceWildcard and the Direct entry for every
// other value. The result is nil when that bucket is empty.
func GetCacheEntryBySource(rec DomainCacheRecord, source DomainCacheSource) *DomainCacheEntry {
	if source == DomainCacheSourceWildcard {
		return rec.Wildcard
	}
	return rec.Direct
}
// ClampDomainScore limits v to the inclusive range
// [DomainScoreMin, DomainScoreMax].
func ClampDomainScore(v int) int {
	switch {
	case v < DomainScoreMin:
		return DomainScoreMin
	case v > DomainScoreMax:
		return DomainScoreMax
	default:
		return v
	}
}
// DomainStateFromScore maps a score to a state name using fixed thresholds:
// >= 20 active, >= 5 stable, >= -10 suspect, >= -30 quarantine, and
// hard quarantine for anything lower.
func DomainStateFromScore(score int) string {
	// Tiers are ordered from the highest floor down; the first floor the score
	// reaches decides the state.
	tiers := [...]struct {
		floor int
		state string
	}{
		{20, DomainStateActive},
		{5, DomainStateStable},
		{-10, DomainStateSuspect},
		{-30, DomainStateQuarantine},
	}
	for _, tier := range tiers {
		if score >= tier.floor {
			return tier.state
		}
	}
	return DomainStateHardQuar
}
// NormalizeDomainState returns the canonical state name for raw. When raw,
// lower-cased and trimmed, is one of the known state names, that name is
// returned. Otherwise the state is derived from score, except that a score of
// 0 yields the empty string ("no state").
func NormalizeDomainState(raw string, score int) string {
	switch name := strings.TrimSpace(strings.ToLower(raw)); name {
	case DomainStateActive,
		DomainStateStable,
		DomainStateSuspect,
		DomainStateQuarantine,
		DomainStateHardQuar:
		return name
	}

	if score == 0 {
		return ""
	}
	return DomainStateFromScore(score)
}
// DomainScorePenalty returns the score adjustment for a failed resolution
// described by stats. The first matching rule wins, each value being
// overridable through EnvInt:
//   - two or more NXDOMAIN answers: RESOLVE_DOMAIN_SCORE_NX_CONFIRMED (-15)
//   - a single NXDOMAIN answer:     RESOLVE_DOMAIN_SCORE_NX_SINGLE (-7)
//   - any timeout:                  RESOLVE_DOMAIN_SCORE_TIMEOUT (-3)
//   - any temporary failure:        RESOLVE_DOMAIN_SCORE_TEMPORARY (-2)
//   - anything else:                RESOLVE_DOMAIN_SCORE_OTHER (-2)
func DomainScorePenalty(stats DNSMetrics) int {
	switch {
	case stats.NXDomain >= 2:
		return EnvInt("RESOLVE_DOMAIN_SCORE_NX_CONFIRMED", -15)
	case stats.NXDomain > 0:
		return EnvInt("RESOLVE_DOMAIN_SCORE_NX_SINGLE", -7)
	case stats.Timeout > 0:
		return EnvInt("RESOLVE_DOMAIN_SCORE_TIMEOUT", -3)
	case stats.Temporary > 0:
		return EnvInt("RESOLVE_DOMAIN_SCORE_TEMPORARY", -2)
	default:
		return EnvInt("RESOLVE_DOMAIN_SCORE_OTHER", -2)
	}
}
// classifyHostErrorKind picks the single error kind that represents stats.
// Precedence is Timeout, then Temporary, then Other, then NXDomain, so
// NXDOMAIN is reported only when no other error category was counted. The
// boolean is false when all four counters are zero or negative.
func classifyHostErrorKind(stats DNSMetrics) (DNSErrorKind, bool) {
	switch {
	case stats.Timeout > 0:
		return DNSErrorTimeout, true
	case stats.Temporary > 0:
		return DNSErrorTemporary, true
	case stats.Other > 0:
		return DNSErrorOther, true
	case stats.NXDomain > 0:
		return DNSErrorNXDomain, true
	default:
		return "", false
	}
}
// SetErrorWithStats records a failed resolution of domain for the given source
// bucket. It is a no-op when the normalized host is empty, now is not
// positive, or stats contains no classifiable error.
//
// The existing entry (or a new empty one) is updated in place:
//   - the score is lowered by DomainScorePenalty(stats) and clamped, and the
//     state is derived from the new score;
//   - a timeout that does not follow a recorded NXDOMAIN never goes beyond
//     "suspect": the score is floored at -10 and the state forced to suspect;
//   - an NXDOMAIN that would reach hard quarantine while
//     NXHardQuarantineEnabled() is false is downgraded to quarantine with the
//     score floored at -30;
//   - LastErrorKind and LastErrorAt are set to the classified kind and now;
//   - QuarantineUntil becomes now plus RESOLVE_HARD_QUARANTINE_TTL_SEC or
//     RESOLVE_QUARANTINE_TTL_SEC (negative values treated as 0) for the hard
//     quarantine and quarantine states, and 0 for every other state.
//
// Previously cached IPs and LastResolved are left untouched.
func (s *DomainCacheState) SetErrorWithStats(domain string, source DomainCacheSource, stats DNSMetrics, now int) {
	key := strings.TrimSpace(strings.ToLower(domain))
	if key == "" || now <= 0 {
		return
	}

	classified, ok := classifyHostErrorKind(stats)
	if !ok {
		return
	}
	kind, ok := NormalizeCacheErrorKind(string(classified))
	if !ok {
		return
	}

	delta := DomainScorePenalty(stats)
	softTTL := EnvInt("RESOLVE_QUARANTINE_TTL_SEC", DefaultQuarantineTTL)
	if softTTL < 0 {
		softTTL = 0
	}
	hardTTL := EnvInt("RESOLVE_HARD_QUARANTINE_TTL_SEC", DefaultHardQuarTTL)
	if hardTTL < 0 {
		hardTTL = 0
	}

	if s.Domains == nil {
		s.Domains = make(map[string]DomainCacheRecord)
	}
	rec := s.Domains[key]
	entry := GetCacheEntryBySource(rec, source)
	if entry == nil {
		entry = new(DomainCacheEntry)
	}
	// Read the previous kind before it is overwritten below.
	prevKind, _ := NormalizeCacheErrorKind(entry.LastErrorKind)

	score := ClampDomainScore(entry.Score + delta)
	state := DomainStateFromScore(score)

	if kind == DNSErrorTimeout && prevKind != DNSErrorNXDomain {
		// -10 is the lowest score DomainStateFromScore still maps to suspect.
		if score < -10 {
			score = -10
		}
		state = DomainStateSuspect
	}
	if kind == DNSErrorNXDomain && !NXHardQuarantineEnabled() && state == DomainStateHardQuar {
		// -30 is the lowest score DomainStateFromScore still maps to quarantine.
		state = DomainStateQuarantine
		if score < -30 {
			score = -30
		}
	}

	until := 0
	switch state {
	case DomainStateHardQuar:
		until = now + hardTTL
	case DomainStateQuarantine:
		until = now + softTTL
	}

	entry.Score = score
	entry.State = state
	entry.LastErrorKind = string(kind)
	entry.LastErrorAt = now
	entry.QuarantineUntil = until

	if source == DomainCacheSourceWildcard {
		rec.Wildcard = entry
	} else {
		rec.Direct = entry
	}
	s.Domains[key] = rec
}
// ToMap converts the cache into a generic map with the same layout as its JSON
// form: {"version": DomainCacheVersion, "domains": {host: {"direct": {...},
// "wildcard": {...}}}}. The version written is always DomainCacheVersion, not
// s.Version.
//
// For each entry, "ips"/"last_resolved" are emitted only when both are set,
// "last_error_kind"/"last_error_at" only when the kind is recognised and the
// timestamp is positive, "score" only when non-zero, "state" only when
// NormalizeDomainState yields a non-empty name, and "quarantine_until" only
// when positive. Entries and hosts that would serialize to an empty object are
// omitted. The "ips" value is the entry's own slice, not a copy.
//
// Both buckets share one encoder, so direct and wildcard entries cannot drift
// apart in how they are serialized. Hosts are not sorted first, because the
// result is an unordered map.
func (s DomainCacheState) ToMap() map[string]any {
	// encode returns the map form of one entry, or nil when there is nothing to emit.
	encode := func(e *DomainCacheEntry) map[string]any {
		if e == nil {
			return nil
		}
		fields := map[string]any{}
		if len(e.IPs) > 0 && e.LastResolved > 0 {
			fields["ips"] = e.IPs
			fields["last_resolved"] = e.LastResolved
		}
		if kind, ok := NormalizeCacheErrorKind(e.LastErrorKind); ok && e.LastErrorAt > 0 {
			fields["last_error_kind"] = string(kind)
			fields["last_error_at"] = e.LastErrorAt
		}
		if e.Score != 0 {
			fields["score"] = e.Score
		}
		if st := NormalizeDomainState(e.State, e.Score); st != "" {
			fields["state"] = st
		}
		if e.QuarantineUntil > 0 {
			fields["quarantine_until"] = e.QuarantineUntil
		}
		if len(fields) == 0 {
			return nil
		}
		return fields
	}

	domains := make(map[string]any, len(s.Domains))
	for host, rec := range s.Domains {
		recOut := map[string]any{}
		if direct := encode(rec.Direct); direct != nil {
			recOut["direct"] = direct
		}
		if wildcard := encode(rec.Wildcard); wildcard != nil {
			recOut["wildcard"] = wildcard
		}
		if len(recOut) > 0 {
			domains[host] = recOut
		}
	}

	return map[string]any{
		"version": DomainCacheVersion,
		"domains": domains,
	}
}
// FormatStateSummary returns a one-line count of cache entries per state,
// counting the direct and wildcard buckets of every host separately.
//
// An entry whose QuarantineUntil is still after now is counted as hard
// quarantine when its normalized state says so and as quarantine otherwise.
// Every other entry is counted under its normalized state, and entries with no
// determinable state are not counted. The result is the empty string when
// nothing was counted.
func (s DomainCacheState) FormatStateSummary(now int) string {
	var active, stable, suspect, quarantined, hard int

	tally := func(e *DomainCacheEntry) {
		if e == nil {
			return
		}
		state := NormalizeDomainState(e.State, e.Score)
		// A running quarantine overrides any state except hard quarantine.
		if e.QuarantineUntil > now && state != DomainStateHardQuar {
			state = DomainStateQuarantine
		}
		switch state {
		case DomainStateActive:
			active++
		case DomainStateStable:
			stable++
		case DomainStateSuspect:
			suspect++
		case DomainStateQuarantine:
			quarantined++
		case DomainStateHardQuar:
			hard++
		}
	}

	for _, rec := range s.Domains {
		tally(rec.Direct)
		tally(rec.Wildcard)
	}

	total := active + stable + suspect + quarantined + hard
	if total == 0 {
		return ""
	}
	return fmt.Sprintf(
		"active=%d stable=%d suspect=%d quarantine=%d hard_quarantine=%d total=%d",
		active, stable, suspect, quarantined, hard, total,
	)
}