Files
elmprodvpn/selective-vpn-gui/api/dns.py

270 lines
11 KiB
Python

from __future__ import annotations
from typing import Any, Dict, List, cast
from .models import *
class DNSApiMixin:
    """DNS / SmartDNS endpoints of the selective-VPN HTTP API.

    This is a mixin: the host client class must provide ``_request`` (performs
    an HTTP call), ``_json`` (decodes a response body) and
    ``_parse_cmd_result``. The response parsers fill missing or ``null``
    fields with neutral defaults rather than raising.
    """

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #

    def _dns_request_dict(self, method: str, path: str, **kwargs: Any) -> Dict[str, Any]:
        """Perform a request and return the decoded JSON body as a dict.

        An empty body, or JSON that is not an object (e.g. a bare list), yields
        ``{}`` so the field-by-field parsers can fall back to their defaults.
        Extra keyword arguments (``json_body``, ``timeout``) go to ``_request``.
        """
        data = self._json(self._request(method, path, **kwargs))
        return cast(Dict[str, Any], data) if isinstance(data, dict) else {}

    @staticmethod
    def _dns_str_list(raw: Any) -> List[str]:
        """Return the stripped, non-blank string items of a JSON array.

        Anything that is not a list/tuple (``None``, a bare string, ...) yields
        ``[]``; iterating a string would otherwise split it into characters.
        """
        if not isinstance(raw, (list, tuple)):
            return []
        cleaned = (str(item or "").strip() for item in raw)
        return [item for item in cleaned if item]

    @staticmethod
    def _dns_service_action(action: Any) -> str:
        """Lower-case and validate a service action.

        Raises:
            ValueError: if the action is not start/stop/restart.
        """
        act = action.lower()
        if act not in ("start", "stop", "restart"):
            raise ValueError(f"Invalid action: {action}")
        return act

    @staticmethod
    def _dns_benchmark_profile(profile: Any) -> str:
        """Normalize a benchmark profile to ``"quick"`` or ``"load"`` (default)."""
        mode = str(profile or "load").strip().lower()
        return mode if mode in ("quick", "load") else "load"

    @staticmethod
    def _dns_benchmark_timeout(
        upstream_count: int,
        domain_count: int,
        timeout_ms: int,
        attempts: int,
        concurrency: int,
        profile: str,
    ) -> float:
        """Estimate a client read-timeout (seconds) for a benchmark request.

        A benchmark can legitimately run much longer than the default 5s API
        timeout. This is a rough estimate: upstreams are treated as being
        probed in waves of ``concurrency``. Each wave costs about
        ``domains * attempts * per-query timeout``, scaled 6x for the ``load``
        profile. The total gets 10% + 8s slack and is bounded to [20s, 420s].
        """
        upstream_count = max(1, int(upstream_count))
        # An empty domain list makes the backend use its default set (6 domains).
        domain_count = int(domain_count) if int(domain_count) > 0 else 6
        # NOTE(review): these bounds presumably mirror backend-side clamping; confirm.
        attempts = max(1, min(int(attempts), 3))
        concurrency = max(1, min(int(concurrency), 32))
        per_query_sec = max(300, int(timeout_ms)) / 1000.0
        load_factor = 1.0 if profile == "quick" else 6.0
        waves = (upstream_count + concurrency - 1) // concurrency  # ceiling division
        per_wave_sec = domain_count * attempts * per_query_sec * load_factor
        return min(420.0, max(20.0, waves * per_wave_sec * 1.1 + 8.0))

    def _parse_dns_upstream_pool(self, data: Dict[str, Any]) -> DNSUpstreamPoolState:
        """Build a pool state from an ``{"items": [{"addr", "enabled"}, ...]}`` body.

        Non-dict rows and rows with a blank ``addr`` are skipped; ``enabled``
        defaults to ``True`` when absent.
        """
        raw = data.get("items") or []
        if not isinstance(raw, list):
            raw = []
        items: List[DNSBenchmarkUpstream] = []
        for row in raw:
            if not isinstance(row, dict):
                continue
            addr = str(row.get("addr") or "").strip()
            if addr:
                items.append(DNSBenchmarkUpstream(addr=addr, enabled=bool(row.get("enabled", True))))
        return DNSUpstreamPoolState(items=items)

    def _parse_smartdns_runtime(self, data: Dict[str, Any]) -> SmartdnsRuntimeState:
        """Build a SmartDNS runtime state from a runtime GET/POST response body."""
        enabled = bool(data.get("enabled", False))
        return SmartdnsRuntimeState(
            enabled=enabled,
            applied_enabled=bool(data.get("applied_enabled", False)),
            # When the backend omits wildcard_source it is derived from `enabled`.
            wildcard_source=str(data.get("wildcard_source") or ("both" if enabled else "resolver")),
            unit_state=str(data.get("unit_state") or "unknown"),
            config_path=str(data.get("config_path") or ""),
            changed=bool(data.get("changed", False)),
            restarted=bool(data.get("restarted", False)),
            message=str(data.get("message") or ""),
        )

    def _parse_dns_status(self, data: Dict[str, Any]) -> DNSStatus:
        """Build a DNSStatus from a status-shaped response body.

        ``runtime_nftset`` defaults to True. When absent, ``mode`` is derived
        from ``via_smartdns`` and ``wildcard_source`` from ``runtime_nftset``.
        """
        via = bool(data.get("via_smartdns", False))
        runtime = bool(data.get("runtime_nftset", True))
        return DNSStatus(
            via_smartdns=via,
            smartdns_addr=str(data.get("smartdns_addr") or ""),
            mode=str(data.get("mode") or ("hybrid_wildcard" if via else "direct")),
            unit_state=str(data.get("unit_state") or "unknown"),
            runtime_nftset=runtime,
            wildcard_source=str(data.get("wildcard_source") or ("both" if runtime else "resolver")),
            runtime_config_path=str(data.get("runtime_config_path") or ""),
            runtime_config_error=str(data.get("runtime_config_error") or ""),
        )

    # ------------------------------------------------------------------ #
    # Upstream resolvers
    # ------------------------------------------------------------------ #

    def dns_upstreams_get(self) -> DnsUpstreams:
        """Fetch the configured default/meta upstream resolvers (blank if unset)."""
        data = self._dns_request_dict("GET", "/api/v1/dns-upstreams")
        return DnsUpstreams(
            default1=str(data.get("default1") or ""),
            default2=str(data.get("default2") or ""),
            meta1=str(data.get("meta1") or ""),
            meta2=str(data.get("meta2") or ""),
        )

    def dns_upstreams_set(self, cfg: DnsUpstreams) -> None:
        """Store the default/meta upstream resolvers; the response body is ignored."""
        self._request(
            "POST",
            "/api/v1/dns-upstreams",
            json_body={
                "default1": cfg.default1,
                "default2": cfg.default2,
                "meta1": cfg.meta1,
                "meta2": cfg.meta2,
            },
        )

    def dns_upstream_pool_get(self) -> DNSUpstreamPoolState:
        """Fetch the upstream pool (candidate resolvers with enabled flags)."""
        return self._parse_dns_upstream_pool(self._dns_request_dict("GET", "/api/v1/dns/upstream-pool"))

    def dns_upstream_pool_set(self, items: List[DNSBenchmarkUpstream]) -> DNSUpstreamPoolState:
        """Replace the upstream pool and return the pool as echoed by the backend."""
        data = self._dns_request_dict(
            "POST",
            "/api/v1/dns/upstream-pool",
            json_body={
                "items": [{"addr": u.addr, "enabled": bool(u.enabled)} for u in (items or [])],
            },
        )
        return self._parse_dns_upstream_pool(data)

    # ------------------------------------------------------------------ #
    # Benchmark
    # ------------------------------------------------------------------ #

    def dns_benchmark(
        self,
        upstreams: List[DNSBenchmarkUpstream],
        domains: List[str],
        timeout_ms: int = 1800,
        attempts: int = 1,
        concurrency: int = 6,
        profile: str = "load",
    ) -> DNSBenchmarkResponse:
        """Run a DNS benchmark on the backend and return per-upstream results.

        Blank domains are dropped before sending; an empty list lets the
        backend use its default domain set. ``profile`` is normalized to
        "quick" or "load". The request uses an extended read timeout sized to
        the workload, because the call can take minutes.
        """
        mode = self._dns_benchmark_profile(profile)
        upstream_list = list(upstreams or [])
        domain_list = self._dns_str_list(list(domains or []))
        # Size the timeout from what is actually sent (blank domains excluded).
        bench_timeout = self._dns_benchmark_timeout(
            len(upstream_list), len(domain_list), timeout_ms, attempts, concurrency, mode
        )
        data = self._dns_request_dict(
            "POST",
            "/api/v1/dns/benchmark",
            json_body={
                "upstreams": [{"addr": u.addr, "enabled": bool(u.enabled)} for u in upstream_list],
                "domains": domain_list,
                "timeout_ms": int(timeout_ms),
                "attempts": int(attempts),
                "concurrency": int(concurrency),
                "profile": mode,
            },
            timeout=bench_timeout,
        )
        raw_results = data.get("results") or []
        if not isinstance(raw_results, list):
            raw_results = []
        results: List[DNSBenchmarkResult] = []
        for row in raw_results:
            if not isinstance(row, dict):
                continue
            results.append(
                DNSBenchmarkResult(
                    upstream=str(row.get("upstream") or "").strip(),
                    attempts=int(row.get("attempts", 0) or 0),
                    ok=int(row.get("ok", 0) or 0),
                    fail=int(row.get("fail", 0) or 0),
                    nxdomain=int(row.get("nxdomain", 0) or 0),
                    timeout=int(row.get("timeout", 0) or 0),
                    temporary=int(row.get("temporary", 0) or 0),
                    other=int(row.get("other", 0) or 0),
                    avg_ms=int(row.get("avg_ms", 0) or 0),
                    p95_ms=int(row.get("p95_ms", 0) or 0),
                    score=float(row.get("score", 0.0) or 0.0),
                    color=str(row.get("color") or "").strip().lower(),
                )
            )
        return DNSBenchmarkResponse(
            results=results,
            domains_used=self._dns_str_list(data.get("domains_used")),
            timeout_ms=int(data.get("timeout_ms", 0) or 0),
            attempts_per_domain=int(data.get("attempts_per_domain", 0) or 0),
            profile=str(data.get("profile") or mode),
            recommended_default=self._dns_str_list(data.get("recommended_default")),
            recommended_meta=self._dns_str_list(data.get("recommended_meta")),
        )

    # ------------------------------------------------------------------ #
    # DNS mode / status
    # ------------------------------------------------------------------ #

    def dns_status_get(self) -> DNSStatus:
        """Fetch the current DNS routing status."""
        return self._parse_dns_status(self._dns_request_dict("GET", "/api/v1/dns/status"))

    def dns_mode_set(self, via_smartdns: bool, smartdns_addr: str) -> DNSStatus:
        """Switch DNS routing between SmartDNS ("hybrid_wildcard") and "direct"."""
        mode = "hybrid_wildcard" if bool(via_smartdns) else "direct"
        data = self._dns_request_dict(
            "POST",
            "/api/v1/dns/mode",
            json_body={
                "via_smartdns": bool(via_smartdns),
                "smartdns_addr": str(smartdns_addr or ""),
                "mode": mode,
            },
        )
        return self._parse_dns_status(data)

    # ------------------------------------------------------------------ #
    # SmartDNS service / runtime
    # ------------------------------------------------------------------ #

    def dns_smartdns_service_set(self, action: ServiceAction) -> DNSStatus:
        """Start/stop/restart SmartDNS through the DNS endpoint.

        Raises:
            ValueError: for an invalid action, or when the backend reports
                ``ok`` false (its ``message`` is used when present).
        """
        act = self._dns_service_action(action)
        data = self._dns_request_dict(
            "POST",
            "/api/v1/dns/smartdns-service",
            json_body={"action": act},
        )
        if not bool(data.get("ok", False)):
            raise ValueError(str(data.get("message") or f"SmartDNS {act} failed"))
        return self._parse_dns_status(data)

    def smartdns_service_get(self) -> SmartdnsServiceState:
        """Fetch the SmartDNS service state ("unknown" when not reported)."""
        data = self._dns_request_dict("GET", "/api/v1/smartdns/service")
        return SmartdnsServiceState(state=str(data.get("state") or "unknown"))

    def smartdns_service_set(self, action: ServiceAction) -> CmdResult:
        """Start/stop/restart SmartDNS and return the command result.

        Raises:
            ValueError: for an action other than start/stop/restart.
        """
        act = self._dns_service_action(action)
        data = self._dns_request_dict("POST", "/api/v1/smartdns/service", json_body={"action": act})
        return self._parse_cmd_result(data)

    def smartdns_runtime_get(self) -> SmartdnsRuntimeState:
        """Fetch the SmartDNS runtime (nftset) configuration state."""
        return self._parse_smartdns_runtime(self._dns_request_dict("GET", "/api/v1/smartdns/runtime"))

    def smartdns_runtime_set(self, enabled: bool, restart: bool = True) -> SmartdnsRuntimeState:
        """Enable/disable the SmartDNS runtime, optionally requesting a restart."""
        data = self._dns_request_dict(
            "POST",
            "/api/v1/smartdns/runtime",
            json_body={"enabled": bool(enabled), "restart": bool(restart)},
        )
        return self._parse_smartdns_runtime(data)

    def smartdns_prewarm(self, limit: int = 0, aggressive_subs: bool = False) -> CmdResult:
        """Ask the backend to prewarm the SmartDNS cache.

        ``limit`` is sent only when positive and ``aggressive_subs`` only when
        true; otherwise the backend's defaults apply.
        """
        payload: Dict[str, Any] = {}
        if int(limit) > 0:
            payload["limit"] = int(limit)
        if aggressive_subs:
            payload["aggressive_subs"] = True
        data = self._dns_request_dict("POST", "/api/v1/smartdns/prewarm", json_body=payload)
        return self._parse_cmd_result(data)