Files
elmprodvpn/selective-vpn-api/app/dnscfg/mode.go

137 lines
3.4 KiB
Go

package dnscfg
import (
"encoding/json"
"os"
"path/filepath"
"strings"
)
// NormalizeResolverMode maps a user-supplied or stored resolver mode string
// onto one of the canonical mode names passed in by the caller. Matching is
// case-insensitive and ignores surrounding whitespace on both sides:
//
//   - directMode                    -> directMode
//   - smartDNSMode                  -> hybridWildcardMode (treated as an alias)
//   - hybridWildcardMode, "hybrid"  -> hybridWildcardMode
//   - anything else                 -> hybridWildcardMode when viaSmartDNS is
//     true, otherwise directMode
//
// The returned value is the canonical argument exactly as given (it is not
// trimmed or lower-cased).
func NormalizeResolverMode(mode string, viaSmartDNS bool, directMode string, smartDNSMode string, hybridWildcardMode string) string {
	canon := func(s string) string { return strings.ToLower(strings.TrimSpace(s)) }
	key := canon(mode)

	// The direct mode is checked first so it wins if it ever collides with
	// one of the other names.
	if key == canon(directMode) {
		return directMode
	}
	if key == canon(smartDNSMode) || key == canon(hybridWildcardMode) || key == "hybrid" {
		return hybridWildcardMode
	}

	// Unknown or empty mode: fall back to the legacy boolean flag.
	if viaSmartDNS {
		return hybridWildcardMode
	}
	return directMode
}
// SmartDNSForced interprets an environment-variable style boolean flag.
// It reports true only for "1", "true", "yes" or "on" (case-insensitive,
// surrounding whitespace ignored); every other value, including "", is false.
func SmartDNSForced(envRaw string) bool {
	v := strings.TrimSpace(strings.ToLower(envRaw))
	return v == "1" || v == "true" || v == "yes" || v == "on"
}
type (
	// ModeState is the resolver-mode state that SaveMode writes as JSON and
	// LoadMode reads back.
	ModeState struct {
		// ViaSmartDNS is re-derived during normalization: it is true exactly
		// when Mode differs from the (trimmed) ModeConfig.DirectMode.
		ViaSmartDNS bool `json:"via_smartdns"`

		// SmartDNSAddr is the SmartDNS address; its accepted format is
		// whatever ModeConfig.NormalizeSmartDNSAddr produces.
		SmartDNSAddr string `json:"smartdns_addr"`

		// Mode is the resolver mode name, normalized through
		// ModeConfig.NormalizeResolverMode.
		Mode string `json:"mode"`
	}

	// ModeConfig parameterizes LoadMode and SaveMode.
	ModeConfig struct {
		// Path is the state file location; surrounding whitespace is trimmed.
		Path string

		// DirectMode is the default Mode and the only Mode value for which
		// ViaSmartDNS is false after normalization.
		DirectMode string

		// DefaultSmartDNSAddr is used initially and whenever address
		// normalization yields an empty string.
		DefaultSmartDNSAddr string

		// NormalizeResolverMode is an optional hook; when nil the mode is
		// only whitespace-trimmed.
		NormalizeResolverMode func(mode string, viaSmartDNS bool) string

		// NormalizeSmartDNSAddr is an optional hook; when nil the address is
		// only whitespace-trimmed.
		NormalizeSmartDNSAddr func(raw string) string
	}
)
// LoadMode reads the resolver-mode state from cfg.Path (whitespace-trimmed)
// and returns it normalized through cfg's hooks, together with a flag that
// tells the caller the state should be (re)persisted.
//
// The starting point is a default state: ViaSmartDNS=false, the trimmed
// cfg.DirectMode and the trimmed cfg.DefaultSmartDNSAddr. The flag is set when
//   - the file does not exist;
//   - the file cannot be decoded as a ModeState (the defaults are used);
//   - the stored "mode" field is blank; or
//   - normalization altered any field.
//
// Any other read error is not reported: the defaults are returned and the
// flag is only set if normalizing those defaults changed them.
func LoadMode(cfg ModeConfig) (ModeState, bool) {
	state := ModeState{
		SmartDNSAddr: strings.TrimSpace(cfg.DefaultSmartDNSAddr),
		Mode:         strings.TrimSpace(cfg.DirectMode),
	}
	persist := false

	raw, readErr := os.ReadFile(strings.TrimSpace(cfg.Path))
	if readErr == nil {
		var stored ModeState
		if jsonErr := json.Unmarshal(raw, &stored); jsonErr != nil {
			// Undecodable content: keep the defaults and ask for a rewrite.
			persist = true
		} else {
			fixed, modified := normalizeModeState(stored, cfg)
			state = fixed
			persist = modified ||
				strings.TrimSpace(stored.Mode) == "" ||
				fixed.ViaSmartDNS != stored.ViaSmartDNS
		}
	} else if os.IsNotExist(readErr) {
		persist = true
	}

	// Second pass: makes sure the default state (missing, unreadable or
	// undecodable file) also goes through cfg's hooks. For a state that was
	// already normalized above this is expected to be a no-op, assuming the
	// hooks are idempotent.
	final, renormalized := normalizeModeState(state, cfg)
	return final, persist || renormalized
}
// SaveMode normalizes mode through cfg's hooks and writes it as indented JSON
// to cfg.Path (surrounding whitespace trimmed), creating the parent directory
// (mode 0755) if needed.
//
// The JSON is first written to "<path>.tmp" (mode 0644) and then renamed over
// the target. Because the temp file lives in the same directory, the rename
// replaces the target in one step on POSIX filesystems. No fsync is done, so
// this does not guarantee durability across a crash.
//
// A blank cfg.Path is rejected with an *os.PathError wrapping os.ErrInvalid.
// Without that check, filepath.Dir("") resolves to "." and a stray ".tmp"
// file would be written into the working directory. If writing or renaming
// the temporary file fails, the temporary file is removed before the error
// is returned.
func SaveMode(cfg ModeConfig, mode ModeState) error {
	path := strings.TrimSpace(cfg.Path)
	if path == "" {
		return &os.PathError{Op: "save", Path: cfg.Path, Err: os.ErrInvalid}
	}

	// The "changed" flag only matters to LoadMode; when saving we always write.
	normalized, _ := normalizeModeState(mode, cfg)
	b, err := json.MarshalIndent(normalized, "", " ")
	if err != nil {
		return err
	}

	if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
		return err
	}

	tmp := path + ".tmp"
	if err := os.WriteFile(tmp, b, 0o644); err != nil {
		_ = os.Remove(tmp) // best effort: don't leave a partial temp file behind
		return err
	}
	if err := os.Rename(tmp, path); err != nil {
		_ = os.Remove(tmp) // best effort: the target was not replaced
		return err
	}
	return nil
}
// normalizeModeState returns a canonical copy of mode and reports whether any
// field differs from the input:
//   - Mode is passed through cfg.NormalizeResolverMode (using the incoming
//     ViaSmartDNS) and trimmed;
//   - ViaSmartDNS is recomputed as "Mode differs from the trimmed
//     cfg.DirectMode", overriding the incoming value;
//   - SmartDNSAddr is passed through cfg.NormalizeSmartDNSAddr and trimmed,
//     falling back to the trimmed cfg.DefaultSmartDNSAddr when that is empty.
func normalizeModeState(mode ModeState, cfg ModeConfig) (ModeState, bool) {
	out := mode
	out.Mode = normalizeResolverMode(cfg, mode.Mode, mode.ViaSmartDNS)
	out.ViaSmartDNS = out.Mode != strings.TrimSpace(cfg.DirectMode)

	addr := normalizeSmartDNSAddr(cfg, mode.SmartDNSAddr)
	if addr == "" {
		addr = strings.TrimSpace(cfg.DefaultSmartDNSAddr)
	}
	out.SmartDNSAddr = addr

	// ModeState holds only comparable fields, so a struct comparison covers
	// every field normalization may have touched.
	return out, out != mode
}
// normalizeResolverMode runs cfg.NormalizeResolverMode on mode when that hook
// is set. In either case the result is whitespace-trimmed.
func normalizeResolverMode(cfg ModeConfig, mode string, viaSmartDNS bool) string {
	out := mode
	if hook := cfg.NormalizeResolverMode; hook != nil {
		out = hook(mode, viaSmartDNS)
	}
	return strings.TrimSpace(out)
}
// normalizeSmartDNSAddr runs cfg.NormalizeSmartDNSAddr on raw when that hook
// is set. In either case the result is whitespace-trimmed.
func normalizeSmartDNSAddr(cfg ModeConfig, raw string) string {
	out := raw
	if hook := cfg.NormalizeSmartDNSAddr; hook != nil {
		out = hook(raw)
	}
	return strings.TrimSpace(out)
}