225 lines
6.0 KiB
Go
225 lines
6.0 KiB
Go
package app
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// ---------------------------------------------------------------------
|
|
// smartdns runtime accelerator (nftset -> agvpn_dyn4)
|
|
// ---------------------------------------------------------------------
|
|
|
|
// EN: Runtime accelerator state is persisted separately from DNS mode.
|
|
// EN: This allows enabling/disabling the SmartDNS nftset hook without changing
// EN: the resolver's primary behavior.
|
|
// RU: Состояние runtime-ускорителя хранится отдельно от DNS mode.
|
|
// RU: Это позволяет включать/выключать SmartDNS nftset-hook независимо от
|
|
// RU: основного пути через резолвер.
|
|
|
|
// Managed SmartDNS directives and the schema version of the persisted
// runtime state document.
const (
	// smartdnsRuntimeStateVersion is the schema version written into the
	// runtime state file.
	smartdnsRuntimeStateVersion = 1

	// smartdnsRuntimeDomainSetLine declares the agvpn_wild domain set, read
	// from the selective-vpn domain list file.
	smartdnsRuntimeDomainSetLine = "domain-set -name agvpn_wild -file /etc/selective-vpn/smartdns.conf"

	// smartdnsRuntimeNftsetLine hooks the agvpn_wild domain set to the
	// nftables set inet/agvpn/agvpn_dyn4 ("#4" is SmartDNS's IPv4 selector).
	smartdnsRuntimeNftsetLine = "nftset /domain-set:agvpn_wild/#4:inet#agvpn#agvpn_dyn4"
)
|
|
|
|
// smartDNSRuntimeState is the JSON document persisted at smartdnsRTPath that
// records whether the SmartDNS nftset hook should be enabled. It is stored
// separately from the DNS mode.
type smartDNSRuntimeState struct {
	// Version is the schema version. Values <= 0 are replaced with
	// smartdnsRuntimeStateVersion by normalizeSmartDNSRuntimeState.
	Version int `json:"version"`
	// Enabled is the desired state of the hook. Whether it is actually
	// applied to the SmartDNS config is checked separately.
	Enabled bool `json:"enabled"`
	// UpdatedAt is the UTC RFC 3339 timestamp set when the state is saved.
	UpdatedAt string `json:"updated_at"`
}
|
|
|
|
// wildcardFillSource names what fills the wildcard set. It returns "both"
// when the SmartDNS runtime hook is enabled (presumably resolver plus the
// SmartDNS nftset hook) and "resolver" otherwise.
func wildcardFillSource(runtimeEnabled bool) string {
	source := "resolver"
	if runtimeEnabled {
		source = "both"
	}
	return source
}
|
|
|
|
func normalizeSmartDNSRuntimeState(st smartDNSRuntimeState) smartDNSRuntimeState {
|
|
if st.Version <= 0 {
|
|
st.Version = smartdnsRuntimeStateVersion
|
|
}
|
|
return st
|
|
}
|
|
|
|
func smartDNSRuntimeEnabledFromConfig() (bool, error) {
|
|
data, err := os.ReadFile(smartdnsMainConfig)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
for _, raw := range strings.Split(string(data), "\n") {
|
|
trimmed := strings.TrimSpace(raw)
|
|
if trimmed == "" || strings.HasPrefix(trimmed, "#") {
|
|
continue
|
|
}
|
|
if strings.Contains(trimmed, "nftset") &&
|
|
strings.Contains(trimmed, "domain-set:agvpn_wild") &&
|
|
strings.Contains(trimmed, "agvpn_dyn4") {
|
|
return true, nil
|
|
}
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
func inferSmartDNSRuntimeEnabled() bool {
|
|
enabled, err := smartDNSRuntimeEnabledFromConfig()
|
|
if err != nil {
|
|
// Keep historical behavior on first run when config is unavailable.
|
|
return true
|
|
}
|
|
return enabled
|
|
}
|
|
|
|
func loadSmartDNSRuntimeState(logf func(string, ...any)) smartDNSRuntimeState {
|
|
if data, err := os.ReadFile(smartdnsRTPath); err == nil {
|
|
var st smartDNSRuntimeState
|
|
if json.Unmarshal(data, &st) == nil {
|
|
return normalizeSmartDNSRuntimeState(st)
|
|
}
|
|
if logf != nil {
|
|
logf("smartdns runtime: invalid state json at %s, rebuilding", smartdnsRTPath)
|
|
}
|
|
}
|
|
|
|
st := smartDNSRuntimeState{
|
|
Version: smartdnsRuntimeStateVersion,
|
|
Enabled: inferSmartDNSRuntimeEnabled(),
|
|
UpdatedAt: time.Now().UTC().Format(time.RFC3339),
|
|
}
|
|
_ = saveSmartDNSRuntimeState(st)
|
|
return st
|
|
}
|
|
|
|
func saveSmartDNSRuntimeState(st smartDNSRuntimeState) error {
|
|
st = normalizeSmartDNSRuntimeState(st)
|
|
st.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
|
|
data, err := json.MarshalIndent(st, "", " ")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := os.MkdirAll(filepath.Dir(smartdnsRTPath), 0o755); err != nil {
|
|
return err
|
|
}
|
|
tmp := smartdnsRTPath + ".tmp"
|
|
if err := os.WriteFile(tmp, data, 0o644); err != nil {
|
|
return err
|
|
}
|
|
return os.Rename(tmp, smartdnsRTPath)
|
|
}
|
|
|
|
func smartDNSRuntimeEnabled() bool {
|
|
return loadSmartDNSRuntimeState(nil).Enabled
|
|
}
|
|
|
|
func normalizeSmartDNSMainConfig(content string, enabled bool) string {
|
|
normalized := strings.ReplaceAll(content, "\r\n", "\n")
|
|
lines := strings.Split(normalized, "\n")
|
|
out := make([]string, 0, len(lines)+4)
|
|
|
|
seenDomain := false
|
|
seenNftset := false
|
|
|
|
isDomainLine := func(raw string) bool {
|
|
t := strings.TrimSpace(raw)
|
|
if strings.HasPrefix(t, "#") {
|
|
t = strings.TrimSpace(strings.TrimPrefix(t, "#"))
|
|
}
|
|
return strings.HasPrefix(t, "domain-set ") &&
|
|
strings.Contains(t, "-name agvpn_wild") &&
|
|
strings.Contains(t, "/etc/selective-vpn/smartdns.conf")
|
|
}
|
|
isNftsetLine := func(raw string) bool {
|
|
t := strings.TrimSpace(raw)
|
|
if strings.HasPrefix(t, "#") {
|
|
t = strings.TrimSpace(strings.TrimPrefix(t, "#"))
|
|
}
|
|
return strings.HasPrefix(t, "nftset ") &&
|
|
strings.Contains(t, "domain-set:agvpn_wild") &&
|
|
strings.Contains(t, "agvpn_dyn4")
|
|
}
|
|
|
|
for _, raw := range lines {
|
|
switch {
|
|
case isDomainLine(raw):
|
|
if !seenDomain {
|
|
if enabled {
|
|
out = append(out, smartdnsRuntimeDomainSetLine)
|
|
} else {
|
|
out = append(out, "# "+smartdnsRuntimeDomainSetLine)
|
|
}
|
|
seenDomain = true
|
|
}
|
|
case isNftsetLine(raw):
|
|
if !seenNftset {
|
|
if enabled {
|
|
out = append(out, smartdnsRuntimeNftsetLine)
|
|
} else {
|
|
out = append(out, "# "+smartdnsRuntimeNftsetLine)
|
|
}
|
|
seenNftset = true
|
|
}
|
|
default:
|
|
out = append(out, raw)
|
|
}
|
|
}
|
|
|
|
if enabled && (!seenDomain || !seenNftset) {
|
|
if len(out) > 0 && strings.TrimSpace(out[len(out)-1]) != "" {
|
|
out = append(out, "")
|
|
}
|
|
if !seenDomain {
|
|
out = append(out, smartdnsRuntimeDomainSetLine)
|
|
}
|
|
if !seenNftset {
|
|
out = append(out, smartdnsRuntimeNftsetLine)
|
|
}
|
|
}
|
|
|
|
rendered := strings.Join(out, "\n")
|
|
if !strings.HasSuffix(rendered, "\n") {
|
|
rendered += "\n"
|
|
}
|
|
return rendered
|
|
}
|
|
|
|
func applySmartDNSRuntimeConfig(enabled bool) (bool, error) {
|
|
data, err := os.ReadFile(smartdnsMainConfig)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
current := strings.ReplaceAll(string(data), "\r\n", "\n")
|
|
next := normalizeSmartDNSMainConfig(current, enabled)
|
|
if next == current {
|
|
return false, nil
|
|
}
|
|
tmp := smartdnsMainConfig + ".tmp"
|
|
if err := os.WriteFile(tmp, []byte(next), 0o644); err != nil {
|
|
return false, err
|
|
}
|
|
if err := os.Rename(tmp, smartdnsMainConfig); err != nil {
|
|
return false, err
|
|
}
|
|
return true, nil
|
|
}
|
|
|
|
func smartDNSRuntimeSnapshot() SmartDNSRuntimeStatusResponse {
|
|
st := loadSmartDNSRuntimeState(nil)
|
|
appliedEnabled, err := smartDNSRuntimeEnabledFromConfig()
|
|
msg := ""
|
|
if err != nil {
|
|
msg = fmt.Sprintf("config read error: %v", err)
|
|
}
|
|
return SmartDNSRuntimeStatusResponse{
|
|
Enabled: st.Enabled,
|
|
AppliedEnable: appliedEnabled,
|
|
WildcardSource: wildcardFillSource(st.Enabled),
|
|
UnitState: smartdnsUnitState(),
|
|
ConfigPath: smartdnsMainConfig,
|
|
Message: msg,
|
|
}
|
|
}
|