from __future__ import annotations from typing import Any, Dict, List, cast from .models import * class DNSApiMixin: def dns_upstreams_get(self) -> DnsUpstreams: data = cast(Dict[str, Any], self._json(self._request("GET", "/api/v1/dns-upstreams")) or {}) return DnsUpstreams( default1=str(data.get("default1") or ""), default2=str(data.get("default2") or ""), meta1=str(data.get("meta1") or ""), meta2=str(data.get("meta2") or ""), ) def dns_upstreams_set(self, cfg: DnsUpstreams) -> None: self._request( "POST", "/api/v1/dns-upstreams", json_body={ "default1": cfg.default1, "default2": cfg.default2, "meta1": cfg.meta1, "meta2": cfg.meta2, }, ) def dns_upstream_pool_get(self) -> DNSUpstreamPoolState: data = cast(Dict[str, Any], self._json(self._request("GET", "/api/v1/dns/upstream-pool")) or {}) raw = data.get("items") or [] if not isinstance(raw, list): raw = [] items: List[DNSBenchmarkUpstream] = [] for row in raw: if not isinstance(row, dict): continue addr = str(row.get("addr") or "").strip() if not addr: continue items.append(DNSBenchmarkUpstream(addr=addr, enabled=bool(row.get("enabled", True)))) return DNSUpstreamPoolState(items=items) def dns_upstream_pool_set(self, items: List[DNSBenchmarkUpstream]) -> DNSUpstreamPoolState: data = cast( Dict[str, Any], self._json( self._request( "POST", "/api/v1/dns/upstream-pool", json_body={ "items": [{"addr": u.addr, "enabled": bool(u.enabled)} for u in (items or [])], }, ) ) or {}, ) raw = data.get("items") or [] if not isinstance(raw, list): raw = [] out: List[DNSBenchmarkUpstream] = [] for row in raw: if not isinstance(row, dict): continue addr = str(row.get("addr") or "").strip() if not addr: continue out.append(DNSBenchmarkUpstream(addr=addr, enabled=bool(row.get("enabled", True)))) return DNSUpstreamPoolState(items=out) def dns_benchmark( self, upstreams: List[DNSBenchmarkUpstream], domains: List[str], timeout_ms: int = 1800, attempts: int = 1, concurrency: int = 6, profile: str = "load", ) -> DNSBenchmarkResponse: # Benchmark can legitimately run much longer than the default 5s API timeout. # Estimate a safe read-timeout from payload size and keep an upper cap. upstream_count = len(upstreams or []) domain_count = len(domains or []) if domain_count <= 0: domain_count = 6 # backend default domains clamped_attempts = max(1, min(int(attempts), 3)) clamped_concurrency = max(1, min(int(concurrency), 32)) if upstream_count <= 0: upstream_count = 1 waves = (upstream_count + clamped_concurrency - 1) // clamped_concurrency mode = str(profile or "load").strip().lower() if mode not in ("quick", "load"): mode = "load" # Rough estimator for backend load profile. load_factor = 1.0 if mode == "quick" else 6.0 per_wave_sec = domain_count * max(1, clamped_attempts) * (max(300, int(timeout_ms)) / 1000.0) * load_factor bench_timeout = min(420.0, max(20.0, waves * per_wave_sec * 1.1 + 8.0)) data = cast( Dict[str, Any], self._json( self._request( "POST", "/api/v1/dns/benchmark", json_body={ "upstreams": [{"addr": u.addr, "enabled": bool(u.enabled)} for u in (upstreams or [])], "domains": [str(d or "").strip() for d in (domains or []) if str(d or "").strip()], "timeout_ms": int(timeout_ms), "attempts": int(attempts), "concurrency": int(concurrency), "profile": mode, }, timeout=bench_timeout, ) ) or {}, ) raw_results = data.get("results") or [] if not isinstance(raw_results, list): raw_results = [] results: List[DNSBenchmarkResult] = [] for row in raw_results: if not isinstance(row, dict): continue results.append( DNSBenchmarkResult( upstream=str(row.get("upstream") or "").strip(), attempts=int(row.get("attempts", 0) or 0), ok=int(row.get("ok", 0) or 0), fail=int(row.get("fail", 0) or 0), nxdomain=int(row.get("nxdomain", 0) or 0), timeout=int(row.get("timeout", 0) or 0), temporary=int(row.get("temporary", 0) or 0), other=int(row.get("other", 0) or 0), avg_ms=int(row.get("avg_ms", 0) or 0), p95_ms=int(row.get("p95_ms", 0) or 0), score=float(row.get("score", 0.0) or 0.0), color=str(row.get("color") or "").strip().lower(), ) ) return DNSBenchmarkResponse( results=results, domains_used=[str(d or "").strip() for d in (data.get("domains_used") or []) if str(d or "").strip()], timeout_ms=int(data.get("timeout_ms", 0) or 0), attempts_per_domain=int(data.get("attempts_per_domain", 0) or 0), profile=str(data.get("profile") or mode), recommended_default=[str(d or "").strip() for d in (data.get("recommended_default") or []) if str(d or "").strip()], recommended_meta=[str(d or "").strip() for d in (data.get("recommended_meta") or []) if str(d or "").strip()], ) def dns_status_get(self) -> DNSStatus: data = cast(Dict[str, Any], self._json(self._request("GET", "/api/v1/dns/status")) or {}) return self._parse_dns_status(data) def dns_mode_set(self, via_smartdns: bool, smartdns_addr: str) -> DNSStatus: mode = "hybrid_wildcard" if bool(via_smartdns) else "direct" data = cast( Dict[str, Any], self._json( self._request( "POST", "/api/v1/dns/mode", json_body={ "via_smartdns": bool(via_smartdns), "smartdns_addr": str(smartdns_addr or ""), "mode": mode, }, ) ) or {}, ) return self._parse_dns_status(data) def dns_smartdns_service_set(self, action: ServiceAction) -> DNSStatus: act = action.lower() if act not in ("start", "stop", "restart"): raise ValueError(f"Invalid action: {action}") data = cast( Dict[str, Any], self._json( self._request( "POST", "/api/v1/dns/smartdns-service", json_body={"action": act}, ) ) or {}, ) if not bool(data.get("ok", False)): raise ValueError(str(data.get("message") or f"SmartDNS {act} failed")) return self._parse_dns_status(data) def smartdns_service_get(self) -> SmartdnsServiceState: data = cast(Dict[str, Any], self._json(self._request("GET", "/api/v1/smartdns/service")) or {}) return SmartdnsServiceState(state=str(data.get("state") or "unknown")) def smartdns_service_set(self, action: ServiceAction) -> CmdResult: act = action.lower() if act not in ("start", "stop", "restart"): raise ValueError(f"Invalid action: {action}") data = cast( Dict[str, Any], self._json(self._request("POST", "/api/v1/smartdns/service", json_body={"action": act})) or {}, ) return self._parse_cmd_result(data) def smartdns_runtime_get(self) -> SmartdnsRuntimeState: data = cast(Dict[str, Any], self._json(self._request("GET", "/api/v1/smartdns/runtime")) or {}) return SmartdnsRuntimeState( enabled=bool(data.get("enabled", False)), applied_enabled=bool(data.get("applied_enabled", False)), wildcard_source=str(data.get("wildcard_source") or ("both" if bool(data.get("enabled", False)) else "resolver")), unit_state=str(data.get("unit_state") or "unknown"), config_path=str(data.get("config_path") or ""), changed=bool(data.get("changed", False)), restarted=bool(data.get("restarted", False)), message=str(data.get("message") or ""), ) def smartdns_runtime_set(self, enabled: bool, restart: bool = True) -> SmartdnsRuntimeState: data = cast( Dict[str, Any], self._json( self._request( "POST", "/api/v1/smartdns/runtime", json_body={"enabled": bool(enabled), "restart": bool(restart)}, ) ) or {}, ) return SmartdnsRuntimeState( enabled=bool(data.get("enabled", False)), applied_enabled=bool(data.get("applied_enabled", False)), wildcard_source=str(data.get("wildcard_source") or ("both" if bool(data.get("enabled", False)) else "resolver")), unit_state=str(data.get("unit_state") or "unknown"), config_path=str(data.get("config_path") or ""), changed=bool(data.get("changed", False)), restarted=bool(data.get("restarted", False)), message=str(data.get("message") or ""), ) def smartdns_prewarm(self, limit: int = 0, aggressive_subs: bool = False) -> CmdResult: payload: Dict[str, Any] = {} if int(limit) > 0: payload["limit"] = int(limit) if aggressive_subs: payload["aggressive_subs"] = True data = cast( Dict[str, Any], self._json(self._request("POST", "/api/v1/smartdns/prewarm", json_body=payload)) or {}, ) return self._parse_cmd_result(data) def _parse_dns_status(self, data: Dict[str, Any]) -> DNSStatus: via = bool(data.get("via_smartdns", False)) runtime = bool(data.get("runtime_nftset", True)) return DNSStatus( via_smartdns=via, smartdns_addr=str(data.get("smartdns_addr") or ""), mode=str(data.get("mode") or ("hybrid_wildcard" if via else "direct")), unit_state=str(data.get("unit_state") or "unknown"), runtime_nftset=runtime, wildcard_source=str(data.get("wildcard_source") or ("both" if runtime else "resolver")), runtime_config_path=str(data.get("runtime_config_path") or ""), runtime_config_error=str(data.get("runtime_config_error") or ""), )