Harden resolver and expand traffic runtime controls
This commit is contained in:
@@ -836,7 +836,15 @@ func runSmartdnsPrewarm(limit, workers, timeoutMS int, aggressiveSubs bool) cmdR
|
||||
)
|
||||
}
|
||||
if len(domains) > loggedHosts {
|
||||
appendTraceLineTo(smartdnsLogPath, "smartdns", fmt.Sprintf("prewarm add: +%d domains omitted", len(domains)-loggedHosts))
|
||||
appendTraceLineTo(
|
||||
smartdnsLogPath,
|
||||
"smartdns",
|
||||
fmt.Sprintf(
|
||||
"prewarm add: trace truncated, omitted=%d hosts (full wildcard map: %s)",
|
||||
len(domains)-loggedHosts,
|
||||
lastIPsMapDyn,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf(
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@@ -42,9 +43,24 @@ func handleDomainsTable(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
stdout, _, _, err := runCommand("ipset", "list", "agvpn4")
|
||||
lines := []string{}
|
||||
if err == nil {
|
||||
for _, setName := range []string{"agvpn4", "agvpn_dyn4"} {
|
||||
stdout, _, code, _ := runCommand("nft", "list", "set", "inet", "agvpn", setName)
|
||||
if code == 0 {
|
||||
for _, l := range strings.Split(stdout, "\n") {
|
||||
l = strings.TrimRight(l, "\r")
|
||||
if l != "" {
|
||||
lines = append(lines, l)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Backward-compatible fallback for legacy hosts that still have ipset.
|
||||
stdout, _, code, _ = runCommand("ipset", "list", setName)
|
||||
if code != 0 {
|
||||
continue
|
||||
}
|
||||
for _, l := range strings.Split(stdout, "\n") {
|
||||
l = strings.TrimRight(l, "\r")
|
||||
if l != "" {
|
||||
@@ -59,7 +75,7 @@ func handleDomainsTable(w http.ResponseWriter, r *http.Request) {
|
||||
// domains file
|
||||
// ---------------------------------------------------------------------
|
||||
|
||||
// GET /api/v1/domains/file?name=bases|meta|subs|static|smartdns|last-ips-map|last-ips-map-direct|last-ips-map-wildcard
|
||||
// GET /api/v1/domains/file?name=bases|meta|subs|static|smartdns|last-ips-map|last-ips-map-direct|last-ips-map-wildcard|wildcard-observed-hosts
|
||||
// POST /api/v1/domains/file { "name": "...", "content": "..." }
|
||||
func handleDomainsFile(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
@@ -73,6 +89,13 @@ func handleDomainsFile(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if name == "wildcard-observed-hosts" {
|
||||
writeJSON(w, http.StatusOK, map[string]string{
|
||||
"content": readWildcardObservedHostsContent(),
|
||||
"source": "derived",
|
||||
})
|
||||
return
|
||||
}
|
||||
path, ok := domainFiles[name]
|
||||
if !ok {
|
||||
http.Error(w, "unknown file name", http.StatusBadRequest)
|
||||
@@ -126,7 +149,7 @@ func handleDomainsFile(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
|
||||
return
|
||||
}
|
||||
if body.Name == "last-ips-map-direct" || body.Name == "last-ips-map-wildcard" {
|
||||
if body.Name == "last-ips-map-direct" || body.Name == "last-ips-map-wildcard" || body.Name == "wildcard-observed-hosts" {
|
||||
http.Error(w, "read-only file name", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
@@ -146,6 +169,39 @@ func handleDomainsFile(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
func readWildcardObservedHostsContent() string {
|
||||
data, err := os.ReadFile(lastIPsMapDyn)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
seen := make(map[string]struct{})
|
||||
out := make([]string, 0, 256)
|
||||
for _, ln := range strings.Split(string(data), "\n") {
|
||||
ln = strings.TrimSpace(ln)
|
||||
if ln == "" || strings.HasPrefix(ln, "#") {
|
||||
continue
|
||||
}
|
||||
fields := strings.Fields(ln)
|
||||
if len(fields) < 2 {
|
||||
continue
|
||||
}
|
||||
host := strings.TrimSpace(fields[1])
|
||||
if host == "" || strings.HasPrefix(host, "[") {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[host]; ok {
|
||||
continue
|
||||
}
|
||||
seen[host] = struct{}{}
|
||||
out = append(out, host)
|
||||
}
|
||||
sort.Strings(out)
|
||||
if len(out) == 0 {
|
||||
return ""
|
||||
}
|
||||
return strings.Join(out, "\n") + "\n"
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------
|
||||
// smartdns wildcards
|
||||
// ---------------------------------------------------------------------
|
||||
|
||||
@@ -265,6 +265,23 @@ func runResolverJob(opts ResolverOpts, logf func(string, ...any)) (resolverResul
|
||||
domainCache := loadDomainCacheState(opts.CachePath, logf)
|
||||
ptrCache := loadJSONMap(opts.PtrCachePath)
|
||||
now := int(time.Now().Unix())
|
||||
negTTLNX := envInt("RESOLVE_NEGATIVE_TTL_NX", 6*3600)
|
||||
negTTLTimeout := envInt("RESOLVE_NEGATIVE_TTL_TIMEOUT", 15*60)
|
||||
negTTLTemporary := envInt("RESOLVE_NEGATIVE_TTL_TEMPORARY", 10*60)
|
||||
negTTLOther := envInt("RESOLVE_NEGATIVE_TTL_OTHER", 10*60)
|
||||
clampTTL := func(v int) int {
|
||||
if v < 0 {
|
||||
return 0
|
||||
}
|
||||
if v > 24*3600 {
|
||||
return 24 * 3600
|
||||
}
|
||||
return v
|
||||
}
|
||||
negTTLNX = clampTTL(negTTLNX)
|
||||
negTTLTimeout = clampTTL(negTTLTimeout)
|
||||
negTTLTemporary = clampTTL(negTTLTemporary)
|
||||
negTTLOther = clampTTL(negTTLOther)
|
||||
|
||||
cacheSourceForHost := func(host string) domainCacheSource {
|
||||
switch cfg.Mode {
|
||||
@@ -284,6 +301,7 @@ func runResolverJob(opts ResolverOpts, logf func(string, ...any)) (resolverResul
|
||||
start := time.Now()
|
||||
|
||||
fresh := map[string][]string{}
|
||||
cacheNegativeHits := 0
|
||||
var toResolve []string
|
||||
for _, d := range domains {
|
||||
source := cacheSourceForHost(d)
|
||||
@@ -294,6 +312,13 @@ func runResolverJob(opts ResolverOpts, logf func(string, ...any)) (resolverResul
|
||||
}
|
||||
continue
|
||||
}
|
||||
if kind, age, ok := domainCache.getNegative(d, source, now, negTTLNX, negTTLTimeout, negTTLTemporary, negTTLOther); ok {
|
||||
cacheNegativeHits++
|
||||
if logf != nil {
|
||||
logf("cache neg hit[%s/%s age=%ds]: %s", source, kind, age, d)
|
||||
}
|
||||
continue
|
||||
}
|
||||
toResolve = append(toResolve, d)
|
||||
}
|
||||
|
||||
@@ -303,7 +328,7 @@ func runResolverJob(opts ResolverOpts, logf func(string, ...any)) (resolverResul
|
||||
}
|
||||
|
||||
if logf != nil {
|
||||
logf("resolve: domains=%d cache_hits=%d to_resolve=%d", len(domains), len(fresh), len(toResolve))
|
||||
logf("resolve: domains=%d cache_hits=%d cache_neg_hits=%d to_resolve=%d", len(domains), len(fresh), cacheNegativeHits, len(toResolve))
|
||||
}
|
||||
|
||||
dnsStats := dnsMetrics{}
|
||||
@@ -349,8 +374,16 @@ func runResolverJob(opts ResolverOpts, logf func(string, ...any)) (resolverResul
|
||||
if logf != nil {
|
||||
logf("%s -> %v", r.host, r.ips)
|
||||
}
|
||||
} else if logf != nil {
|
||||
logf("%s: no IPs", r.host)
|
||||
} else {
|
||||
if hostErrors > 0 {
|
||||
source := cacheSourceForHost(r.host)
|
||||
if kind, ok := classifyHostErrorKind(r.stats); ok {
|
||||
domainCache.setError(r.host, source, kind, now)
|
||||
}
|
||||
}
|
||||
if logf != nil {
|
||||
logf("%s: no IPs", r.host)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -443,9 +476,10 @@ func runResolverJob(opts ResolverOpts, logf func(string, ...any)) (resolverResul
|
||||
if logf != nil {
|
||||
dnsErrors := dnsStats.totalErrors()
|
||||
logf(
|
||||
"resolve summary: domains=%d cache_hits=%d resolved_now=%d unresolved=%d static_entries=%d static_skipped=%d unique_ips=%d direct_ips=%d wildcard_ips=%d ptr_lookups=%d ptr_errors=%d dns_attempts=%d dns_ok=%d dns_nxdomain=%d dns_timeout=%d dns_temporary=%d dns_other=%d dns_errors=%d duration_ms=%d",
|
||||
"resolve summary: domains=%d cache_hits=%d cache_neg_hits=%d resolved_now=%d unresolved=%d static_entries=%d static_skipped=%d unique_ips=%d direct_ips=%d wildcard_ips=%d ptr_lookups=%d ptr_errors=%d dns_attempts=%d dns_ok=%d dns_nxdomain=%d dns_timeout=%d dns_temporary=%d dns_other=%d dns_errors=%d duration_ms=%d",
|
||||
len(domains),
|
||||
len(fresh),
|
||||
cacheNegativeHits,
|
||||
len(resolved)-len(fresh),
|
||||
len(domains)-len(resolved),
|
||||
len(staticEntries),
|
||||
@@ -487,17 +521,45 @@ func resolveHostGo(host string, cfg dnsConfig, metaSpecial []string, wildcards w
|
||||
if useMeta {
|
||||
dnsList = cfg.Meta
|
||||
}
|
||||
primaryViaSmartDNS := false
|
||||
switch cfg.Mode {
|
||||
case DNSModeSmartDNS:
|
||||
if cfg.SmartDNS != "" {
|
||||
dnsList = []string{cfg.SmartDNS}
|
||||
primaryViaSmartDNS = true
|
||||
}
|
||||
case DNSModeHybridWildcard:
|
||||
if cfg.SmartDNS != "" && wildcards.match(host) {
|
||||
dnsList = []string{cfg.SmartDNS}
|
||||
primaryViaSmartDNS = true
|
||||
}
|
||||
}
|
||||
ips, stats := digA(host, dnsList, timeout, logf)
|
||||
if len(ips) == 0 &&
|
||||
!primaryViaSmartDNS &&
|
||||
cfg.SmartDNS != "" &&
|
||||
smartDNSFallbackForTimeoutEnabled() &&
|
||||
shouldFallbackToSmartDNS(stats) {
|
||||
if logf != nil {
|
||||
logf(
|
||||
"dns fallback %s: trying smartdns=%s after errors nxdomain=%d timeout=%d temporary=%d other=%d",
|
||||
host,
|
||||
cfg.SmartDNS,
|
||||
stats.NXDomain,
|
||||
stats.Timeout,
|
||||
stats.Temporary,
|
||||
stats.Other,
|
||||
)
|
||||
}
|
||||
fallbackIPs, fallbackStats := digA(host, []string{cfg.SmartDNS}, timeout, logf)
|
||||
stats.merge(fallbackStats)
|
||||
if len(fallbackIPs) > 0 {
|
||||
ips = fallbackIPs
|
||||
if logf != nil {
|
||||
logf("dns fallback %s: resolved via smartdns (%d ips)", host, len(fallbackIPs))
|
||||
}
|
||||
}
|
||||
}
|
||||
out := []string{}
|
||||
seen := map[string]struct{}{}
|
||||
for _, ip := range ips {
|
||||
@@ -512,6 +574,52 @@ func resolveHostGo(host string, cfg dnsConfig, metaSpecial []string, wildcards w
|
||||
return out, stats
|
||||
}
|
||||
|
||||
// smartDNSFallbackForTimeoutEnabled controls direct->SmartDNS fallback behavior.
|
||||
// Default is disabled to avoid overloading SmartDNS on large unresolved batches.
|
||||
// Set RESOLVE_SMARTDNS_TIMEOUT_FALLBACK=1 to enable.
|
||||
func smartDNSFallbackForTimeoutEnabled() bool {
|
||||
v := strings.ToLower(strings.TrimSpace(os.Getenv("RESOLVE_SMARTDNS_TIMEOUT_FALLBACK")))
|
||||
switch v {
|
||||
case "1", "true", "yes", "on":
|
||||
return true
|
||||
case "0", "false", "no", "off":
|
||||
return false
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback is useful only for transport-like errors. If we already got NXDOMAIN,
|
||||
// SmartDNS fallback is unlikely to change result and only adds latency/noise.
|
||||
func shouldFallbackToSmartDNS(stats dnsMetrics) bool {
|
||||
if stats.OK > 0 {
|
||||
return false
|
||||
}
|
||||
if stats.NXDomain > 0 {
|
||||
return false
|
||||
}
|
||||
if stats.Timeout > 0 || stats.Temporary > 0 {
|
||||
return true
|
||||
}
|
||||
return stats.Other > 0
|
||||
}
|
||||
|
||||
func classifyHostErrorKind(stats dnsMetrics) (dnsErrorKind, bool) {
|
||||
if stats.Timeout > 0 {
|
||||
return dnsErrorTimeout, true
|
||||
}
|
||||
if stats.Temporary > 0 {
|
||||
return dnsErrorTemporary, true
|
||||
}
|
||||
if stats.Other > 0 {
|
||||
return dnsErrorOther, true
|
||||
}
|
||||
if stats.NXDomain > 0 {
|
||||
return dnsErrorNXDomain, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------
|
||||
// EN: `digA` contains core logic for dig a.
|
||||
// RU: `digA` - содержит основную логику для dig a.
|
||||
@@ -742,8 +850,10 @@ const (
|
||||
)
|
||||
|
||||
type domainCacheEntry struct {
|
||||
IPs []string `json:"ips"`
|
||||
LastResolved int `json:"last_resolved"`
|
||||
IPs []string `json:"ips,omitempty"`
|
||||
LastResolved int `json:"last_resolved,omitempty"`
|
||||
LastErrorKind string `json:"last_error_kind,omitempty"`
|
||||
LastErrorAt int `json:"last_error_at,omitempty"`
|
||||
}
|
||||
|
||||
type domainCacheRecord struct {
|
||||
@@ -758,7 +868,7 @@ type domainCacheState struct {
|
||||
|
||||
func newDomainCacheState() domainCacheState {
|
||||
return domainCacheState{
|
||||
Version: 2,
|
||||
Version: 3,
|
||||
Domains: map[string]domainCacheRecord{},
|
||||
}
|
||||
}
|
||||
@@ -781,6 +891,41 @@ func normalizeCacheIPs(raw []string) []string {
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeCacheErrorKind(raw string) (dnsErrorKind, bool) {
|
||||
switch strings.ToLower(strings.TrimSpace(raw)) {
|
||||
case string(dnsErrorNXDomain):
|
||||
return dnsErrorNXDomain, true
|
||||
case string(dnsErrorTimeout):
|
||||
return dnsErrorTimeout, true
|
||||
case string(dnsErrorTemporary):
|
||||
return dnsErrorTemporary, true
|
||||
case string(dnsErrorOther):
|
||||
return dnsErrorOther, true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeDomainCacheEntry(in *domainCacheEntry) *domainCacheEntry {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := &domainCacheEntry{}
|
||||
ips := normalizeCacheIPs(in.IPs)
|
||||
if len(ips) > 0 && in.LastResolved > 0 {
|
||||
out.IPs = ips
|
||||
out.LastResolved = in.LastResolved
|
||||
}
|
||||
if kind, ok := normalizeCacheErrorKind(in.LastErrorKind); ok && in.LastErrorAt > 0 {
|
||||
out.LastErrorKind = string(kind)
|
||||
out.LastErrorAt = in.LastErrorAt
|
||||
}
|
||||
if out.LastResolved <= 0 && out.LastErrorAt <= 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func parseAnyStringSlice(raw any) []string {
|
||||
switch v := raw.(type) {
|
||||
case []string:
|
||||
@@ -842,7 +987,7 @@ func loadDomainCacheState(path string, logf func(string, ...any)) domainCacheSta
|
||||
var st domainCacheState
|
||||
if err := json.Unmarshal(data, &st); err == nil && st.Domains != nil {
|
||||
if st.Version <= 0 {
|
||||
st.Version = 2
|
||||
st.Version = 3
|
||||
}
|
||||
normalized := newDomainCacheState()
|
||||
for host, rec := range st.Domains {
|
||||
@@ -851,18 +996,8 @@ func loadDomainCacheState(path string, logf func(string, ...any)) domainCacheSta
|
||||
continue
|
||||
}
|
||||
nrec := domainCacheRecord{}
|
||||
if rec.Direct != nil {
|
||||
ips := normalizeCacheIPs(rec.Direct.IPs)
|
||||
if len(ips) > 0 && rec.Direct.LastResolved > 0 {
|
||||
nrec.Direct = &domainCacheEntry{IPs: ips, LastResolved: rec.Direct.LastResolved}
|
||||
}
|
||||
}
|
||||
if rec.Wildcard != nil {
|
||||
ips := normalizeCacheIPs(rec.Wildcard.IPs)
|
||||
if len(ips) > 0 && rec.Wildcard.LastResolved > 0 {
|
||||
nrec.Wildcard = &domainCacheEntry{IPs: ips, LastResolved: rec.Wildcard.LastResolved}
|
||||
}
|
||||
}
|
||||
nrec.Direct = normalizeDomainCacheEntry(rec.Direct)
|
||||
nrec.Wildcard = normalizeDomainCacheEntry(rec.Wildcard)
|
||||
if nrec.Direct != nil || nrec.Wildcard != nil {
|
||||
normalized.Domains[host] = nrec
|
||||
}
|
||||
@@ -926,6 +1061,46 @@ func (s domainCacheState) get(domain string, source domainCacheSource, now, ttl
|
||||
return ips, true
|
||||
}
|
||||
|
||||
func (s domainCacheState) getNegative(domain string, source domainCacheSource, now, nxTTL, timeoutTTL, temporaryTTL, otherTTL int) (dnsErrorKind, int, bool) {
|
||||
rec, ok := s.Domains[strings.TrimSpace(strings.ToLower(domain))]
|
||||
if !ok {
|
||||
return "", 0, false
|
||||
}
|
||||
var entry *domainCacheEntry
|
||||
switch source {
|
||||
case domainCacheSourceWildcard:
|
||||
entry = rec.Wildcard
|
||||
default:
|
||||
entry = rec.Direct
|
||||
}
|
||||
if entry == nil || entry.LastErrorAt <= 0 {
|
||||
return "", 0, false
|
||||
}
|
||||
kind, ok := normalizeCacheErrorKind(entry.LastErrorKind)
|
||||
if !ok {
|
||||
return "", 0, false
|
||||
}
|
||||
age := now - entry.LastErrorAt
|
||||
if age < 0 {
|
||||
return "", 0, false
|
||||
}
|
||||
cacheTTL := 0
|
||||
switch kind {
|
||||
case dnsErrorNXDomain:
|
||||
cacheTTL = nxTTL
|
||||
case dnsErrorTimeout:
|
||||
cacheTTL = timeoutTTL
|
||||
case dnsErrorTemporary:
|
||||
cacheTTL = temporaryTTL
|
||||
case dnsErrorOther:
|
||||
cacheTTL = otherTTL
|
||||
}
|
||||
if cacheTTL <= 0 || age > cacheTTL {
|
||||
return "", 0, false
|
||||
}
|
||||
return kind, age, true
|
||||
}
|
||||
|
||||
func (s *domainCacheState) set(domain string, source domainCacheSource, ips []string, now int) {
|
||||
host := strings.TrimSpace(strings.ToLower(domain))
|
||||
if host == "" || now <= 0 {
|
||||
@@ -939,7 +1114,10 @@ func (s *domainCacheState) set(domain string, source domainCacheSource, ips []st
|
||||
s.Domains = map[string]domainCacheRecord{}
|
||||
}
|
||||
rec := s.Domains[host]
|
||||
entry := &domainCacheEntry{IPs: norm, LastResolved: now}
|
||||
entry := &domainCacheEntry{
|
||||
IPs: norm,
|
||||
LastResolved: now,
|
||||
}
|
||||
switch source {
|
||||
case domainCacheSourceWildcard:
|
||||
rec.Wildcard = entry
|
||||
@@ -949,9 +1127,39 @@ func (s *domainCacheState) set(domain string, source domainCacheSource, ips []st
|
||||
s.Domains[host] = rec
|
||||
}
|
||||
|
||||
func (s *domainCacheState) setError(domain string, source domainCacheSource, kind dnsErrorKind, now int) {
|
||||
host := strings.TrimSpace(strings.ToLower(domain))
|
||||
if host == "" || now <= 0 {
|
||||
return
|
||||
}
|
||||
normKind, ok := normalizeCacheErrorKind(string(kind))
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if s.Domains == nil {
|
||||
s.Domains = map[string]domainCacheRecord{}
|
||||
}
|
||||
rec := s.Domains[host]
|
||||
switch source {
|
||||
case domainCacheSourceWildcard:
|
||||
if rec.Wildcard == nil {
|
||||
rec.Wildcard = &domainCacheEntry{}
|
||||
}
|
||||
rec.Wildcard.LastErrorKind = string(normKind)
|
||||
rec.Wildcard.LastErrorAt = now
|
||||
default:
|
||||
if rec.Direct == nil {
|
||||
rec.Direct = &domainCacheEntry{}
|
||||
}
|
||||
rec.Direct.LastErrorKind = string(normKind)
|
||||
rec.Direct.LastErrorAt = now
|
||||
}
|
||||
s.Domains[host] = rec
|
||||
}
|
||||
|
||||
func (s domainCacheState) toMap() map[string]any {
|
||||
out := map[string]any{
|
||||
"version": 2,
|
||||
"version": 3,
|
||||
"domains": map[string]any{},
|
||||
}
|
||||
domainsAny := out["domains"].(map[string]any)
|
||||
@@ -963,16 +1171,32 @@ func (s domainCacheState) toMap() map[string]any {
|
||||
for _, host := range hosts {
|
||||
rec := s.Domains[host]
|
||||
recOut := map[string]any{}
|
||||
if rec.Direct != nil && len(rec.Direct.IPs) > 0 && rec.Direct.LastResolved > 0 {
|
||||
recOut["direct"] = map[string]any{
|
||||
"ips": rec.Direct.IPs,
|
||||
"last_resolved": rec.Direct.LastResolved,
|
||||
if rec.Direct != nil {
|
||||
directOut := map[string]any{}
|
||||
if len(rec.Direct.IPs) > 0 && rec.Direct.LastResolved > 0 {
|
||||
directOut["ips"] = rec.Direct.IPs
|
||||
directOut["last_resolved"] = rec.Direct.LastResolved
|
||||
}
|
||||
if kind, ok := normalizeCacheErrorKind(rec.Direct.LastErrorKind); ok && rec.Direct.LastErrorAt > 0 {
|
||||
directOut["last_error_kind"] = string(kind)
|
||||
directOut["last_error_at"] = rec.Direct.LastErrorAt
|
||||
}
|
||||
if len(directOut) > 0 {
|
||||
recOut["direct"] = directOut
|
||||
}
|
||||
}
|
||||
if rec.Wildcard != nil && len(rec.Wildcard.IPs) > 0 && rec.Wildcard.LastResolved > 0 {
|
||||
recOut["wildcard"] = map[string]any{
|
||||
"ips": rec.Wildcard.IPs,
|
||||
"last_resolved": rec.Wildcard.LastResolved,
|
||||
if rec.Wildcard != nil {
|
||||
wildOut := map[string]any{}
|
||||
if len(rec.Wildcard.IPs) > 0 && rec.Wildcard.LastResolved > 0 {
|
||||
wildOut["ips"] = rec.Wildcard.IPs
|
||||
wildOut["last_resolved"] = rec.Wildcard.LastResolved
|
||||
}
|
||||
if kind, ok := normalizeCacheErrorKind(rec.Wildcard.LastErrorKind); ok && rec.Wildcard.LastErrorAt > 0 {
|
||||
wildOut["last_error_kind"] = string(kind)
|
||||
wildOut["last_error_at"] = rec.Wildcard.LastErrorAt
|
||||
}
|
||||
if len(wildOut) > 0 {
|
||||
recOut["wildcard"] = wildOut
|
||||
}
|
||||
}
|
||||
if len(recOut) > 0 {
|
||||
|
||||
@@ -59,6 +59,12 @@ func saveRoutesClearCache() (routesClearCacheMeta, error) {
|
||||
if err := cacheCopyOrEmpty(stateDir+"/last-ips-map.txt", routesCacheMap); err != nil {
|
||||
warns = append(warns, fmt.Sprintf("last-ips-map cache copy failed: %v", err))
|
||||
}
|
||||
if err := cacheCopyOrEmpty(lastIPsMapDirect, routesCacheMapD); err != nil {
|
||||
warns = append(warns, fmt.Sprintf("last-ips-map-direct cache copy failed: %v", err))
|
||||
}
|
||||
if err := cacheCopyOrEmpty(lastIPsMapDyn, routesCacheMapW); err != nil {
|
||||
warns = append(warns, fmt.Sprintf("last-ips-map-wildcard cache copy failed: %v", err))
|
||||
}
|
||||
|
||||
meta := routesClearCacheMeta{
|
||||
CreatedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
@@ -83,6 +89,10 @@ func saveRoutesClearCache() (routesClearCacheMeta, error) {
|
||||
}
|
||||
|
||||
func restoreRoutesFromCache() cmdResult {
|
||||
return withRoutesOpLock("routes restore", restoreRoutesFromCacheUnlocked)
|
||||
}
|
||||
|
||||
func restoreRoutesFromCacheUnlocked() cmdResult {
|
||||
meta, err := loadRoutesClearCacheMeta()
|
||||
if err != nil {
|
||||
return cmdResult{
|
||||
@@ -174,6 +184,13 @@ func restoreRoutesFromCache() cmdResult {
|
||||
if fileExists(routesCacheMap) {
|
||||
_ = cacheCopyOrEmpty(routesCacheMap, stateDir+"/last-ips-map.txt")
|
||||
}
|
||||
if fileExists(routesCacheMapD) {
|
||||
_ = cacheCopyOrEmpty(routesCacheMapD, lastIPsMapDirect)
|
||||
}
|
||||
if fileExists(routesCacheMapW) {
|
||||
_ = cacheCopyOrEmpty(routesCacheMapW, lastIPsMapDyn)
|
||||
}
|
||||
_ = writeStatusSnapshot(len(ips)+len(dynIPs), iface)
|
||||
|
||||
return cmdResult{
|
||||
OK: true,
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------
|
||||
@@ -263,6 +264,10 @@ func handleRoutesCacheRestore(w http.ResponseWriter, r *http.Request) {
|
||||
// RU: `routesClear` - содержит основную логику для routes clear.
|
||||
// ---------------------------------------------------------------------
|
||||
func routesClear() cmdResult {
|
||||
return withRoutesOpLock("routes clear", routesClearUnlocked)
|
||||
}
|
||||
|
||||
func routesClearUnlocked() cmdResult {
|
||||
cacheMeta, cacheErr := saveRoutesClearCache()
|
||||
|
||||
stdout, stderr, _, err := runCommand("ip", "rule", "show")
|
||||
@@ -273,6 +278,11 @@ func routesClear() cmdResult {
|
||||
_, _, _, _ = runCommand("ip", "route", "flush", "table", routesTableName())
|
||||
_, _, _, _ = runCommand("nft", "flush", "set", "inet", "agvpn", "agvpn4")
|
||||
_, _, _, _ = runCommand("nft", "flush", "set", "inet", "agvpn", "agvpn_dyn4")
|
||||
iface := strings.TrimSpace(cacheMeta.Iface)
|
||||
if iface == "" {
|
||||
iface, _ = resolveTrafficIface(loadTrafficModeState().PreferredIface)
|
||||
}
|
||||
_ = writeStatusSnapshot(0, iface)
|
||||
|
||||
res := cmdResult{
|
||||
OK: true,
|
||||
@@ -297,6 +307,50 @@ func routesClear() cmdResult {
|
||||
return res
|
||||
}
|
||||
|
||||
func withRoutesOpLock(opName string, fn func() cmdResult) cmdResult {
|
||||
lock, err := os.OpenFile(lockFile, os.O_CREATE|os.O_RDWR, 0o644)
|
||||
if err != nil {
|
||||
return cmdResult{
|
||||
OK: false,
|
||||
Message: fmt.Sprintf("%s lock open error: %v", opName, err),
|
||||
}
|
||||
}
|
||||
defer lock.Close()
|
||||
|
||||
if err := syscall.Flock(int(lock.Fd()), syscall.LOCK_EX|syscall.LOCK_NB); err != nil {
|
||||
return cmdResult{
|
||||
OK: false,
|
||||
Message: fmt.Sprintf("%s skipped: routes operation already running", opName),
|
||||
}
|
||||
}
|
||||
defer syscall.Flock(int(lock.Fd()), syscall.LOCK_UN)
|
||||
|
||||
return fn()
|
||||
}
|
||||
|
||||
func writeStatusSnapshot(ipCount int, iface string) error {
|
||||
if ipCount < 0 {
|
||||
ipCount = 0
|
||||
}
|
||||
iface = strings.TrimSpace(iface)
|
||||
if iface == "" {
|
||||
iface = "-"
|
||||
}
|
||||
st := Status{
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
IPCount: ipCount,
|
||||
DomainCount: countDomainsFromMap(lastIPsMapPath),
|
||||
Iface: iface,
|
||||
Table: routesTableName(),
|
||||
Mark: MARK,
|
||||
}
|
||||
data, err := json.MarshalIndent(st, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(statusFilePath, data, 0o644)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------
|
||||
// policy route
|
||||
// ---------------------------------------------------------------------
|
||||
|
||||
@@ -189,6 +189,13 @@ func routesUpdate(iface string) cmdResult {
|
||||
bases := loadList(domainDir + "/bases.txt")
|
||||
subs := loadList(domainDir + "/subs.txt")
|
||||
wildcards := loadSmartDNSWildcardDomains(logp)
|
||||
wildcardBaseSet := make(map[string]struct{}, len(wildcards))
|
||||
for _, d := range wildcards {
|
||||
d = strings.TrimSpace(d)
|
||||
if d != "" {
|
||||
wildcardBaseSet[d] = struct{}{}
|
||||
}
|
||||
}
|
||||
wildcardBasesAdded := 0
|
||||
for _, d := range wildcards {
|
||||
d = strings.TrimSpace(d)
|
||||
@@ -212,7 +219,10 @@ func routesUpdate(iface string) cmdResult {
|
||||
twitterAdded := 0
|
||||
for _, d := range bases {
|
||||
domainSet[d] = struct{}{}
|
||||
if !isGoogleLike(d) {
|
||||
_, wildcardBase := wildcardBaseSet[d]
|
||||
// Wildcard bases are now resolved "as-is" (no subs fan-out) to keep
|
||||
// SmartDNS wildcard behavior transparent and avoid synthetic host noise.
|
||||
if !wildcardBase && !isGoogleLike(d) {
|
||||
limit := len(subs)
|
||||
if subsPerBaseLimit > 0 && subsPerBaseLimit < limit {
|
||||
limit = subsPerBaseLimit
|
||||
@@ -258,6 +268,14 @@ func routesUpdate(iface string) cmdResult {
|
||||
)
|
||||
if wildcardBasesAdded > 0 {
|
||||
logp("domains wildcard seed added: %d base domains from smartdns.conf state", wildcardBasesAdded)
|
||||
appendTraceLineTo(
|
||||
smartdnsLogPath,
|
||||
"smartdns",
|
||||
fmt.Sprintf(
|
||||
"wildcard plan: base_domains=%d sub_expanded=0 (routes update uses pure wildcard bases; subs fan-out only in aggressive prewarm)",
|
||||
wildcardBasesAdded,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
domTmp, _ := os.CreateTemp(stateDir, "domains-*.txt")
|
||||
@@ -612,19 +630,27 @@ func logWildcardSmartDNSTrace(mode DNSMode, source string, pairs [][2]string, wi
|
||||
}
|
||||
sort.Strings(hosts)
|
||||
|
||||
const maxHostsLog = 200
|
||||
omitted := 0
|
||||
if len(hosts) > maxHostsLog {
|
||||
omitted = len(hosts) - maxHostsLog
|
||||
}
|
||||
|
||||
appendTraceLineTo(
|
||||
smartdnsLogPath,
|
||||
"smartdns",
|
||||
fmt.Sprintf("wildcard sync: mode=%s source=%s domains=%d ips=%d", mode.Mode, source, len(hosts), wildcardIPCount),
|
||||
fmt.Sprintf(
|
||||
"wildcard sync: mode=%s source=%s domains=%d ips=%d logged=%d omitted=%d map=%s",
|
||||
mode.Mode, source, len(hosts), wildcardIPCount, len(hosts)-omitted, omitted, lastIPsMapDyn,
|
||||
),
|
||||
)
|
||||
|
||||
const maxHostsLog = 200
|
||||
for i, host := range hosts {
|
||||
if i >= maxHostsLog {
|
||||
appendTraceLineTo(
|
||||
smartdnsLogPath,
|
||||
"smartdns",
|
||||
fmt.Sprintf("wildcard sync: +%d domains omitted", len(hosts)-maxHostsLog),
|
||||
fmt.Sprintf("wildcard sync: trace truncated, %d domains not shown (see %s)", omitted, lastIPsMapDyn),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ import (
|
||||
// RU: привязаны к конкретному systemd unit/cgroup.
|
||||
|
||||
const (
|
||||
trafficAppProfilesDefaultTTLSec = 24 * 60 * 60
|
||||
trafficAppProfilesDefaultTTLSec = 0 // 0 = persistent runtime mark policy
|
||||
)
|
||||
|
||||
var trafficAppProfilesMu sync.Mutex
|
||||
@@ -295,6 +295,11 @@ func loadTrafficAppProfilesState() trafficAppProfilesState {
|
||||
st.Profiles[i].AppKey = canon
|
||||
changed = true
|
||||
}
|
||||
st.Profiles[i].Target = strings.ToLower(strings.TrimSpace(st.Profiles[i].Target))
|
||||
}
|
||||
if deduped, dedupChanged := dedupeTrafficAppProfiles(st.Profiles); dedupChanged {
|
||||
st.Profiles = deduped
|
||||
changed = true
|
||||
}
|
||||
if changed {
|
||||
_ = saveTrafficAppProfilesState(st)
|
||||
@@ -302,6 +307,89 @@ func loadTrafficAppProfilesState() trafficAppProfilesState {
|
||||
return st
|
||||
}
|
||||
|
||||
func dedupeTrafficAppProfiles(in []TrafficAppProfile) ([]TrafficAppProfile, bool) {
|
||||
if len(in) <= 1 {
|
||||
return in, false
|
||||
}
|
||||
|
||||
out := make([]TrafficAppProfile, 0, len(in))
|
||||
byID := map[string]int{}
|
||||
byAppTarget := map[string]int{}
|
||||
changed := false
|
||||
|
||||
for _, raw := range in {
|
||||
p := raw
|
||||
p.ID = strings.TrimSpace(p.ID)
|
||||
p.Target = strings.ToLower(strings.TrimSpace(p.Target))
|
||||
p.AppKey = canonicalizeAppKey(p.AppKey, p.Command)
|
||||
|
||||
if p.ID == "" {
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
if p.Target != "vpn" && p.Target != "direct" {
|
||||
p.Target = "vpn"
|
||||
changed = true
|
||||
}
|
||||
|
||||
if idx, ok := byID[p.ID]; ok {
|
||||
if preferTrafficProfile(p, out[idx]) {
|
||||
out[idx] = p
|
||||
}
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
|
||||
if p.AppKey != "" {
|
||||
key := p.Target + "|" + p.AppKey
|
||||
if idx, ok := byAppTarget[key]; ok {
|
||||
if preferTrafficProfile(p, out[idx]) {
|
||||
byID[p.ID] = idx
|
||||
out[idx] = p
|
||||
}
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
byAppTarget[key] = len(out)
|
||||
}
|
||||
|
||||
byID[p.ID] = len(out)
|
||||
out = append(out, p)
|
||||
}
|
||||
return out, changed
|
||||
}
|
||||
|
||||
func preferTrafficProfile(cand, cur TrafficAppProfile) bool {
|
||||
cu := strings.TrimSpace(cand.UpdatedAt)
|
||||
ou := strings.TrimSpace(cur.UpdatedAt)
|
||||
if cu != ou {
|
||||
if cu == "" {
|
||||
return false
|
||||
}
|
||||
if ou == "" {
|
||||
return true
|
||||
}
|
||||
return cu > ou
|
||||
}
|
||||
|
||||
cc := strings.TrimSpace(cand.CreatedAt)
|
||||
oc := strings.TrimSpace(cur.CreatedAt)
|
||||
if cc != oc {
|
||||
if cc == "" {
|
||||
return false
|
||||
}
|
||||
if oc == "" {
|
||||
return true
|
||||
}
|
||||
return cc > oc
|
||||
}
|
||||
|
||||
if strings.TrimSpace(cand.Command) != "" && strings.TrimSpace(cur.Command) == "" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func saveTrafficAppProfilesState(st trafficAppProfilesState) error {
|
||||
st.Version = 1
|
||||
st.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
|
||||
|
||||
@@ -32,7 +32,7 @@ func canonicalizeAppKey(appKey string, command string) string {
|
||||
key := strings.TrimSpace(appKey)
|
||||
cmd := strings.TrimSpace(command)
|
||||
|
||||
fields := strings.Fields(cmd)
|
||||
fields := splitCommandTokens(cmd)
|
||||
if len(fields) == 0 && key != "" {
|
||||
fields = []string{key}
|
||||
}
|
||||
@@ -61,12 +61,12 @@ func canonicalizeAppKey(appKey string, command string) string {
|
||||
switch base {
|
||||
case "flatpak":
|
||||
if id := extractRunTarget(clean); id != "" {
|
||||
return "flatpak:" + id
|
||||
return "flatpak:" + strings.ToLower(strings.TrimSpace(id))
|
||||
}
|
||||
return "flatpak"
|
||||
case "snap":
|
||||
if name := extractRunTarget(clean); name != "" {
|
||||
return "snap:" + name
|
||||
return "snap:" + strings.ToLower(strings.TrimSpace(name))
|
||||
}
|
||||
return "snap"
|
||||
case "gtk-launch":
|
||||
@@ -74,7 +74,7 @@ func canonicalizeAppKey(appKey string, command string) string {
|
||||
if len(clean) >= 2 {
|
||||
id := strings.TrimSpace(clean[1])
|
||||
if id != "" && !strings.HasPrefix(id, "-") {
|
||||
return "desktop:" + id
|
||||
return "desktop:" + strings.ToLower(id)
|
||||
}
|
||||
}
|
||||
case "env":
|
||||
@@ -102,11 +102,11 @@ func canonicalizeAppKey(appKey string, command string) string {
|
||||
if strings.Contains(primary, "/") {
|
||||
b := filepath.Base(primary)
|
||||
if b != "" && b != "." && b != "/" {
|
||||
return b
|
||||
return strings.ToLower(strings.TrimSpace(b))
|
||||
}
|
||||
}
|
||||
|
||||
return primary
|
||||
return strings.ToLower(strings.TrimSpace(primary))
|
||||
}
|
||||
|
||||
func stripOuterQuotes(s string) string {
|
||||
@@ -151,3 +151,65 @@ func extractRunTarget(fields []string) string {
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// splitCommandTokens performs lightweight shell-style tokenization.
|
||||
// It supports single/double quotes and backslash escaping which is enough
|
||||
// for canonical app key extraction.
|
||||
func splitCommandTokens(raw string) []string {
|
||||
s := strings.TrimSpace(raw)
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make([]string, 0, 8)
|
||||
var cur strings.Builder
|
||||
inSingle := false
|
||||
inDouble := false
|
||||
escaped := false
|
||||
|
||||
flush := func() {
|
||||
if cur.Len() == 0 {
|
||||
return
|
||||
}
|
||||
out = append(out, cur.String())
|
||||
cur.Reset()
|
||||
}
|
||||
|
||||
for _, r := range s {
|
||||
if escaped {
|
||||
cur.WriteRune(r)
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
switch r {
|
||||
case '\\':
|
||||
if inSingle {
|
||||
cur.WriteRune(r)
|
||||
} else {
|
||||
escaped = true
|
||||
}
|
||||
case '\'':
|
||||
if inDouble {
|
||||
cur.WriteRune(r)
|
||||
} else {
|
||||
inSingle = !inSingle
|
||||
}
|
||||
case '"':
|
||||
if inSingle {
|
||||
cur.WriteRune(r)
|
||||
} else {
|
||||
inDouble = !inDouble
|
||||
}
|
||||
case ' ', '\t', '\n', '\r':
|
||||
if inSingle || inDouble {
|
||||
cur.WriteRune(r)
|
||||
} else {
|
||||
flush()
|
||||
}
|
||||
default:
|
||||
cur.WriteRune(r)
|
||||
}
|
||||
}
|
||||
flush()
|
||||
return out
|
||||
}
|
||||
|
||||
135
selective-vpn-api/app/traffic_appkey_test.go
Normal file
135
selective-vpn-api/app/traffic_appkey_test.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package app
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestCanonicalizeAppKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
appKey string
|
||||
command string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "path vs bare command normalized to lowercase basename",
|
||||
command: "/usr/bin/Google-Chrome-Stable --new-window",
|
||||
want: "google-chrome-stable",
|
||||
},
|
||||
{
|
||||
name: "quoted path with spaces",
|
||||
command: "'/opt/My Apps/Opera' --private",
|
||||
want: "opera",
|
||||
},
|
||||
{
|
||||
name: "env wrapper skips assignments",
|
||||
command: "env GTK_THEME=Adwaita /usr/bin/Brave-Browser --incognito",
|
||||
want: "brave-browser",
|
||||
},
|
||||
{
|
||||
name: "flatpak run app id",
|
||||
command: "flatpak run org.mozilla.Firefox",
|
||||
want: "flatpak:org.mozilla.firefox",
|
||||
},
|
||||
{
|
||||
name: "snap run app id",
|
||||
command: "snap run --experimental foo.Bar",
|
||||
want: "snap:foo.bar",
|
||||
},
|
||||
{
|
||||
name: "gtk-launch desktop id",
|
||||
command: "gtk-launch Org.Gnome.Nautilus.desktop",
|
||||
want: "desktop:org.gnome.nautilus.desktop",
|
||||
},
|
||||
{
|
||||
name: "explicit app key fallback",
|
||||
appKey: "Opera",
|
||||
want: "opera",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := canonicalizeAppKey(tc.appKey, tc.command)
|
||||
if got != tc.want {
|
||||
t.Fatalf("canonicalizeAppKey(%q,%q) = %q, want %q", tc.appKey, tc.command, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitCommandTokens(t *testing.T) {
|
||||
in := `env A=1 "/opt/My App/bin/App" --flag="x y"`
|
||||
got := splitCommandTokens(in)
|
||||
want := []string{"env", "A=1", "/opt/My App/bin/App", "--flag=x y"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("tokens len=%d want=%d tokens=%v", len(got), len(want), got)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("token[%d]=%q want=%q all=%v", i, got[i], want[i], got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDedupeTrafficAppProfilesByCanonicalAppKey(t *testing.T) {
|
||||
in := []TrafficAppProfile{
|
||||
{
|
||||
ID: "chrome-old",
|
||||
Target: "VPN",
|
||||
AppKey: "Google-Chrome-Stable",
|
||||
Command: "/usr/bin/Google-Chrome-Stable --new-window",
|
||||
UpdatedAt: "2026-02-20T10:00:00Z",
|
||||
},
|
||||
{
|
||||
ID: "chrome-new",
|
||||
Target: "vpn",
|
||||
AppKey: "google-chrome-stable",
|
||||
Command: "google-chrome-stable --incognito",
|
||||
UpdatedAt: "2026-02-20T11:00:00Z",
|
||||
},
|
||||
}
|
||||
out, changed := dedupeTrafficAppProfiles(in)
|
||||
if !changed {
|
||||
t.Fatalf("expected changed=true")
|
||||
}
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("expected 1 profile, got %d", len(out))
|
||||
}
|
||||
if out[0].ID != "chrome-new" {
|
||||
t.Fatalf("expected newest profile to win, got id=%q", out[0].ID)
|
||||
}
|
||||
if out[0].AppKey != "google-chrome-stable" {
|
||||
t.Fatalf("expected canonical app key, got %q", out[0].AppKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDedupeAppMarkItemsByCanonicalAppKey(t *testing.T) {
|
||||
in := []appMarkItem{
|
||||
{
|
||||
ID: 101,
|
||||
Target: "VPN",
|
||||
AppKey: "Opera",
|
||||
Command: "/usr/bin/Opera --private",
|
||||
AddedAt: "2026-02-20T10:00:00Z",
|
||||
},
|
||||
{
|
||||
ID: 202,
|
||||
Target: "vpn",
|
||||
AppKey: "opera",
|
||||
Command: "opera --new-window",
|
||||
AddedAt: "2026-02-20T11:00:00Z",
|
||||
},
|
||||
}
|
||||
out, changed := dedupeAppMarkItems(in)
|
||||
if !changed {
|
||||
t.Fatalf("expected changed=true")
|
||||
}
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("expected 1 app mark item, got %d", len(out))
|
||||
}
|
||||
if out[0].ID != 202 {
|
||||
t.Fatalf("expected newest item to win, got id=%d", out[0].ID)
|
||||
}
|
||||
if out[0].AppKey != "opera" {
|
||||
t.Fatalf("expected canonical app key, got %q", out[0].AppKey)
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
@@ -31,8 +32,11 @@ import (
|
||||
const (
|
||||
appMarksTable = "agvpn"
|
||||
appMarksChain = "output_apps"
|
||||
appMarksGuardChain = "output_guard"
|
||||
appMarksLocalBypassSet = "svpn_local4"
|
||||
appMarkCommentPrefix = "svpn_appmark"
|
||||
defaultAppMarkTTLSeconds = 24 * 60 * 60
|
||||
appGuardCommentPrefix = "svpn_appguard"
|
||||
defaultAppMarkTTLSeconds = 0 // 0 = persistent until explicit unmark/clear
|
||||
)
|
||||
|
||||
var appMarksMu sync.Mutex
|
||||
@@ -129,9 +133,6 @@ func handleTrafficAppMarks(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
ttl := timeoutSec
|
||||
if ttl == 0 {
|
||||
ttl = defaultAppMarkTTLSeconds
|
||||
}
|
||||
|
||||
rel, level, inodeID, cgAbs, err := resolveCgroupV2PathForNft(cgroup)
|
||||
if err != nil {
|
||||
@@ -145,6 +146,7 @@ func handleTrafficAppMarks(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
vpnIface := ""
|
||||
if target == "vpn" {
|
||||
traffic := loadTrafficModeState()
|
||||
iface, _ := resolveTrafficIface(traffic.PreferredIface)
|
||||
@@ -159,6 +161,7 @@ func handleTrafficAppMarks(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
vpnIface = strings.TrimSpace(iface)
|
||||
if err := ensureTrafficRouteBase(iface, traffic.AutoLocalBypass); err != nil {
|
||||
writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
|
||||
OK: false,
|
||||
@@ -172,7 +175,7 @@ func handleTrafficAppMarks(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
if err := appMarksAdd(target, inodeID, cgAbs, rel, level, unit, command, appKey, ttl); err != nil {
|
||||
if err := appMarksAdd(target, inodeID, cgAbs, rel, level, unit, command, appKey, ttl, vpnIface); err != nil {
|
||||
writeJSON(w, http.StatusOK, TrafficAppMarksResponse{
|
||||
OK: false,
|
||||
Op: string(op),
|
||||
@@ -253,11 +256,16 @@ func handleTrafficAppMarksItems(w http.ResponseWriter, r *http.Request) {
|
||||
now := time.Now().UTC()
|
||||
items := make([]TrafficAppMarkItemView, 0, len(st.Items))
|
||||
for _, it := range st.Items {
|
||||
rem := 0
|
||||
exp, err := time.Parse(time.RFC3339, strings.TrimSpace(it.ExpiresAt))
|
||||
if err == nil {
|
||||
rem = int(exp.Sub(now).Seconds())
|
||||
if rem < 0 {
|
||||
rem := -1 // persistent by default
|
||||
expRaw := strings.TrimSpace(it.ExpiresAt)
|
||||
if expRaw != "" {
|
||||
exp, err := time.Parse(time.RFC3339, expRaw)
|
||||
if err == nil {
|
||||
rem = int(exp.Sub(now).Seconds())
|
||||
if rem < 0 {
|
||||
rem = 0
|
||||
}
|
||||
} else {
|
||||
rem = 0
|
||||
}
|
||||
}
|
||||
@@ -308,7 +316,7 @@ func appMarksGetStatus() (vpnCount int, directCount int) {
|
||||
return vpnCount, directCount
|
||||
}
|
||||
|
||||
func appMarksAdd(target string, id uint64, cgAbs string, rel string, level int, unit string, command string, appKey string, ttlSec int) error {
|
||||
func appMarksAdd(target string, id uint64, cgAbs string, rel string, level int, unit string, command string, appKey string, ttlSec int, vpnIface string) error {
|
||||
target = strings.ToLower(strings.TrimSpace(target))
|
||||
if target != "vpn" && target != "direct" {
|
||||
return fmt.Errorf("invalid target")
|
||||
@@ -333,30 +341,51 @@ func appMarksAdd(target string, id uint64, cgAbs string, rel string, level int,
|
||||
command = strings.TrimSpace(command)
|
||||
appKey = canonicalizeAppKey(appKey, command)
|
||||
|
||||
// EN: Avoid unbounded growth of marks for the same app.
|
||||
// RU: Не даём бесконечно плодить метки на одно и то же приложение.
|
||||
if appKey != "" {
|
||||
kept := st.Items[:0]
|
||||
for _, it := range st.Items {
|
||||
if strings.ToLower(strings.TrimSpace(it.Target)) == target &&
|
||||
strings.TrimSpace(it.AppKey) == appKey &&
|
||||
it.ID != id {
|
||||
_ = nftDeleteAppMarkRule(target, it.ID)
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
kept = append(kept, it)
|
||||
// EN: Keep only one effective mark per app and avoid cross-target conflicts.
|
||||
// EN: If the same app_key is re-marked with another target, old mark is removed first.
|
||||
// RU: Держим только одну эффективную метку на приложение и убираем конфликты между target.
|
||||
// RU: Если тот же app_key перемечается на другой target — старая метка удаляется.
|
||||
kept := st.Items[:0]
|
||||
for _, it := range st.Items {
|
||||
itTarget := strings.ToLower(strings.TrimSpace(it.Target))
|
||||
itKey := strings.TrimSpace(it.AppKey)
|
||||
remove := false
|
||||
|
||||
// Same cgroup id but different target => conflicting rules (mark+guard).
|
||||
if it.ID == id && it.ID != 0 && itTarget != target {
|
||||
remove = true
|
||||
}
|
||||
st.Items = kept
|
||||
// Same app_key (if known) should not keep multiple active runtime routes.
|
||||
if !remove && appKey != "" && itKey != "" && itKey == appKey {
|
||||
if it.ID != id || itTarget != target {
|
||||
remove = true
|
||||
}
|
||||
}
|
||||
|
||||
if remove {
|
||||
_ = nftDeleteAppMarkRule(itTarget, it.ID)
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
kept = append(kept, it)
|
||||
}
|
||||
st.Items = kept
|
||||
|
||||
// Replace any existing rule/state for this (target,id).
|
||||
_ = nftDeleteAppMarkRule(target, id)
|
||||
if err := nftInsertAppMarkRule(target, rel, level, id); err != nil {
|
||||
if err := nftInsertAppMarkRule(target, rel, level, id, vpnIface); err != nil {
|
||||
return err
|
||||
}
|
||||
if !nftHasAppMarkRule(target, id) {
|
||||
_ = nftDeleteAppMarkRule(target, id)
|
||||
return fmt.Errorf("appmark rule not active after insert (target=%s id=%d)", target, id)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
expiresAt := ""
|
||||
if ttlSec > 0 {
|
||||
expiresAt = now.Add(time.Duration(ttlSec) * time.Second).Format(time.RFC3339)
|
||||
}
|
||||
item := appMarkItem{
|
||||
ID: id,
|
||||
Target: target,
|
||||
@@ -367,13 +396,15 @@ func appMarksAdd(target string, id uint64, cgAbs string, rel string, level int,
|
||||
Command: command,
|
||||
AppKey: appKey,
|
||||
AddedAt: now.Format(time.RFC3339),
|
||||
ExpiresAt: now.Add(time.Duration(ttlSec) * time.Second).Format(time.RFC3339),
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
st.Items = upsertAppMarkItem(st.Items, item)
|
||||
changed = true
|
||||
|
||||
if changed {
|
||||
if err := saveAppMarksState(st); err != nil {
|
||||
// Keep runtime state and nft in sync on disk write errors.
|
||||
_ = nftDeleteAppMarkRule(target, id)
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -479,7 +510,9 @@ func ensureAppMarksNft() error {
|
||||
// Best-effort "ensure": ignore "exists" errors and proceed.
|
||||
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "table", "inet", appMarksTable)
|
||||
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "chain", "inet", appMarksTable, "output", "{", "type", "route", "hook", "output", "priority", "mangle;", "policy", "accept;", "}")
|
||||
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "chain", "inet", appMarksTable, appMarksGuardChain, "{", "type", "filter", "hook", "output", "priority", "filter;", "policy", "accept;", "}")
|
||||
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "chain", "inet", appMarksTable, appMarksChain)
|
||||
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "set", "inet", appMarksTable, appMarksLocalBypassSet, "{", "type", "ipv4_addr", ";", "flags", "interval", ";", "}")
|
||||
|
||||
out, _, _, _ := runCommandTimeout(5*time.Second, "nft", "list", "chain", "inet", appMarksTable, "output")
|
||||
if !strings.Contains(out, "jump "+appMarksChain) {
|
||||
@@ -514,7 +547,102 @@ func appMarkComment(target string, id uint64) string {
|
||||
return fmt.Sprintf("%s:%s:%d", appMarkCommentPrefix, target, id)
|
||||
}
|
||||
|
||||
func nftInsertAppMarkRule(target string, rel string, level int, id uint64) error {
|
||||
func appGuardComment(target string, id uint64) string {
|
||||
return fmt.Sprintf("%s:%s:%d", appGuardCommentPrefix, target, id)
|
||||
}
|
||||
|
||||
func appGuardEnabled() bool {
|
||||
v := strings.ToLower(strings.TrimSpace(os.Getenv("SVPN_APP_GUARD")))
|
||||
return v == "1" || v == "true" || v == "yes" || v == "on"
|
||||
}
|
||||
|
||||
func updateAppMarkLocalBypassSet(vpnIface string) error {
|
||||
// EN: Keep a small allowlist for local/LAN/container destinations so VPN app kill-switch
|
||||
// EN: does not break host-local access.
|
||||
// RU: Держим небольшой allowlist локальных/LAN/container направлений, чтобы VPN kill-switch
|
||||
// RU: не ломал локальный доступ хоста.
|
||||
vpnIface = strings.TrimSpace(vpnIface)
|
||||
_ = ensureAppMarksNft()
|
||||
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "flush", "set", "inet", appMarksTable, appMarksLocalBypassSet)
|
||||
|
||||
elems := []string{"127.0.0.0/8"}
|
||||
for _, rt := range detectAutoLocalBypassRoutes(vpnIface) {
|
||||
dst := strings.TrimSpace(rt.Dst)
|
||||
if dst == "" || dst == "default" {
|
||||
continue
|
||||
}
|
||||
elems = append(elems, dst)
|
||||
}
|
||||
elems = compactIPv4IntervalElements(elems)
|
||||
for _, e := range elems {
|
||||
_, out, code, err := runCommandTimeout(
|
||||
5*time.Second,
|
||||
"nft", "add", "element", "inet", appMarksTable, appMarksLocalBypassSet,
|
||||
"{", e, "}",
|
||||
)
|
||||
if err != nil || code != 0 {
|
||||
if err == nil {
|
||||
err = fmt.Errorf("nft add element exited with %d", code)
|
||||
}
|
||||
return fmt.Errorf("failed to update %s: %w (%s)", appMarksLocalBypassSet, err, strings.TrimSpace(out))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func compactIPv4IntervalElements(raw []string) []string {
|
||||
pfxs := make([]netip.Prefix, 0, len(raw))
|
||||
for _, v := range raw {
|
||||
s := strings.TrimSpace(v)
|
||||
if s == "" {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(s, "/") {
|
||||
p, err := netip.ParsePrefix(s)
|
||||
if err != nil || !p.Addr().Is4() {
|
||||
continue
|
||||
}
|
||||
pfxs = append(pfxs, p.Masked())
|
||||
continue
|
||||
}
|
||||
a, err := netip.ParseAddr(s)
|
||||
if err != nil || !a.Is4() {
|
||||
continue
|
||||
}
|
||||
pfxs = append(pfxs, netip.PrefixFrom(a, 32))
|
||||
}
|
||||
|
||||
sort.Slice(pfxs, func(i, j int) bool {
|
||||
ib, jb := pfxs[i].Bits(), pfxs[j].Bits()
|
||||
if ib != jb {
|
||||
return ib < jb // broader first
|
||||
}
|
||||
return pfxs[i].Addr().Less(pfxs[j].Addr())
|
||||
})
|
||||
|
||||
out := make([]netip.Prefix, 0, len(pfxs))
|
||||
for _, p := range pfxs {
|
||||
covered := false
|
||||
for _, ex := range out {
|
||||
if ex.Contains(p.Addr()) {
|
||||
covered = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if covered {
|
||||
continue
|
||||
}
|
||||
out = append(out, p)
|
||||
}
|
||||
|
||||
res := make([]string, 0, len(out))
|
||||
for _, p := range out {
|
||||
res = append(res, p.String())
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func nftInsertAppMarkRule(target string, rel string, level int, id uint64, vpnIface string) error {
|
||||
mark := MARK_DIRECT
|
||||
if target == "vpn" {
|
||||
mark = MARK_APP
|
||||
@@ -527,6 +655,58 @@ func nftInsertAppMarkRule(target string, rel string, level int, id uint64) error
|
||||
pathLit := fmt.Sprintf("\"%s\"", rel)
|
||||
commentLit := fmt.Sprintf("\"%s\"", comment)
|
||||
|
||||
if target == "vpn" {
|
||||
if !appGuardEnabled() {
|
||||
goto insertMark
|
||||
}
|
||||
iface := strings.TrimSpace(vpnIface)
|
||||
if iface == "" {
|
||||
return fmt.Errorf("vpn interface required for app guard")
|
||||
}
|
||||
if err := updateAppMarkLocalBypassSet(iface); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
guardComment := appGuardComment(target, id)
|
||||
guardCommentLit := fmt.Sprintf("\"%s\"", guardComment)
|
||||
// IPv4: drop non-tun egress except local bypass ranges.
|
||||
_, out, code, err := runCommandTimeout(
|
||||
5*time.Second,
|
||||
"nft", "insert", "rule", "inet", appMarksTable, appMarksGuardChain,
|
||||
"socket", "cgroupv2", "level", strconv.Itoa(level), pathLit,
|
||||
"meta", "mark", MARK_APP,
|
||||
"oifname", "!=", iface,
|
||||
"ip", "daddr", "!=", "@"+appMarksLocalBypassSet,
|
||||
"drop",
|
||||
"comment", guardCommentLit,
|
||||
)
|
||||
if err != nil || code != 0 {
|
||||
if err == nil {
|
||||
err = fmt.Errorf("nft insert guard(v4) exited with %d", code)
|
||||
}
|
||||
return fmt.Errorf("nft insert app guard(v4) failed: %w (%s)", err, strings.TrimSpace(out))
|
||||
}
|
||||
|
||||
// IPv6: default deny outside VPN iface to prevent WebRTC/STUN leaks on dual-stack hosts.
|
||||
_, out, code, err = runCommandTimeout(
|
||||
5*time.Second,
|
||||
"nft", "insert", "rule", "inet", appMarksTable, appMarksGuardChain,
|
||||
"socket", "cgroupv2", "level", strconv.Itoa(level), pathLit,
|
||||
"meta", "mark", MARK_APP,
|
||||
"oifname", "!=", iface,
|
||||
"meta", "nfproto", "ipv6",
|
||||
"drop",
|
||||
"comment", guardCommentLit,
|
||||
)
|
||||
if err != nil || code != 0 {
|
||||
if err == nil {
|
||||
err = fmt.Errorf("nft insert guard(v6) exited with %d", code)
|
||||
}
|
||||
return fmt.Errorf("nft insert app guard(v6) failed: %w (%s)", err, strings.TrimSpace(out))
|
||||
}
|
||||
}
|
||||
|
||||
insertMark:
|
||||
_, out, code, err := runCommandTimeout(
|
||||
5*time.Second,
|
||||
"nft", "insert", "rule", "inet", appMarksTable, appMarksChain,
|
||||
@@ -539,27 +719,71 @@ func nftInsertAppMarkRule(target string, rel string, level int, id uint64) error
|
||||
if err == nil {
|
||||
err = fmt.Errorf("nft insert rule exited with %d", code)
|
||||
}
|
||||
_ = nftDeleteAppMarkRule(target, id)
|
||||
return fmt.Errorf("nft insert appmark rule failed: %w (%s)", err, strings.TrimSpace(out))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func nftDeleteAppMarkRule(target string, id uint64) error {
|
||||
comment := appMarkComment(target, id)
|
||||
out, _, _, _ := runCommandTimeout(5*time.Second, "nft", "-a", "list", "chain", "inet", appMarksTable, appMarksChain)
|
||||
for _, line := range strings.Split(out, "\n") {
|
||||
if !strings.Contains(line, comment) {
|
||||
continue
|
||||
comments := []string{
|
||||
appMarkComment(target, id),
|
||||
appGuardComment(target, id),
|
||||
}
|
||||
chains := []string{appMarksChain, appMarksGuardChain}
|
||||
for _, chain := range chains {
|
||||
out, _, _, _ := runCommandTimeout(5*time.Second, "nft", "-a", "list", "chain", "inet", appMarksTable, chain)
|
||||
for _, line := range strings.Split(out, "\n") {
|
||||
match := false
|
||||
for _, comment := range comments {
|
||||
if strings.Contains(line, comment) {
|
||||
match = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !match {
|
||||
continue
|
||||
}
|
||||
h := parseNftHandle(line)
|
||||
if h <= 0 {
|
||||
continue
|
||||
}
|
||||
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "delete", "rule", "inet", appMarksTable, chain, "handle", strconv.Itoa(h))
|
||||
}
|
||||
h := parseNftHandle(line)
|
||||
if h <= 0 {
|
||||
continue
|
||||
}
|
||||
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "delete", "rule", "inet", appMarksTable, appMarksChain, "handle", strconv.Itoa(h))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func nftHasAppMarkRule(target string, id uint64) bool {
|
||||
markComment := appMarkComment(target, id)
|
||||
guardComment := appGuardComment(target, id)
|
||||
|
||||
hasMark := false
|
||||
out, _, _, _ := runCommandTimeout(5*time.Second, "nft", "-a", "list", "chain", "inet", appMarksTable, appMarksChain)
|
||||
for _, line := range strings.Split(out, "\n") {
|
||||
if strings.Contains(line, markComment) {
|
||||
hasMark = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasMark {
|
||||
return false
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(target), "vpn") {
|
||||
if !appGuardEnabled() {
|
||||
return true
|
||||
}
|
||||
out, _, _, _ = runCommandTimeout(5*time.Second, "nft", "-a", "list", "chain", "inet", appMarksTable, appMarksGuardChain)
|
||||
for _, line := range strings.Split(out, "\n") {
|
||||
if strings.Contains(line, guardComment) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func parseNftHandle(line string) int {
|
||||
fields := strings.Fields(line)
|
||||
for i := 0; i < len(fields)-1; i++ {
|
||||
@@ -638,8 +862,20 @@ func pruneExpiredAppMarksLocked(st *appMarksState, now time.Time) (changed bool)
|
||||
}
|
||||
kept := st.Items[:0]
|
||||
for _, it := range st.Items {
|
||||
exp, err := time.Parse(time.RFC3339, strings.TrimSpace(it.ExpiresAt))
|
||||
if err != nil || !exp.After(now) {
|
||||
expRaw := strings.TrimSpace(it.ExpiresAt)
|
||||
if expRaw == "" {
|
||||
kept = append(kept, it)
|
||||
continue
|
||||
}
|
||||
exp, err := time.Parse(time.RFC3339, expRaw)
|
||||
if err != nil {
|
||||
// Corrupted timestamp: keep mark as persistent to avoid accidental route leak.
|
||||
it.ExpiresAt = ""
|
||||
kept = append(kept, it)
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
if !exp.After(now) {
|
||||
_ = nftDeleteAppMarkRule(strings.ToLower(strings.TrimSpace(it.Target)), it.ID)
|
||||
changed = true
|
||||
continue
|
||||
@@ -662,6 +898,116 @@ func upsertAppMarkItem(items []appMarkItem, next appMarkItem) []appMarkItem {
|
||||
return out
|
||||
}
|
||||
|
||||
func clearManagedAppMarkRules(chain string) {
|
||||
out, _, _, _ := runCommandTimeout(5*time.Second, "nft", "-a", "list", "chain", "inet", appMarksTable, chain)
|
||||
for _, line := range strings.Split(out, "\n") {
|
||||
l := strings.ToLower(line)
|
||||
if !strings.Contains(l, strings.ToLower(appMarkCommentPrefix)) &&
|
||||
!strings.Contains(l, strings.ToLower(appGuardCommentPrefix)) {
|
||||
continue
|
||||
}
|
||||
h := parseNftHandle(line)
|
||||
if h <= 0 {
|
||||
continue
|
||||
}
|
||||
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "delete", "rule", "inet", appMarksTable, chain, "handle", strconv.Itoa(h))
|
||||
}
|
||||
}
|
||||
|
||||
func restoreAppMarksFromState() error {
|
||||
appMarksMu.Lock()
|
||||
defer appMarksMu.Unlock()
|
||||
|
||||
if err := ensureAppMarksNft(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
st := loadAppMarksState()
|
||||
now := time.Now().UTC()
|
||||
changed := pruneExpiredAppMarksLocked(&st, now)
|
||||
|
||||
clearManagedAppMarkRules(appMarksChain)
|
||||
clearManagedAppMarkRules(appMarksGuardChain)
|
||||
|
||||
traffic := loadTrafficModeState()
|
||||
vpnIface, _ := resolveTrafficIface(traffic.PreferredIface)
|
||||
vpnIface = strings.TrimSpace(vpnIface)
|
||||
|
||||
kept := make([]appMarkItem, 0, len(st.Items))
|
||||
for _, it := range st.Items {
|
||||
target := strings.ToLower(strings.TrimSpace(it.Target))
|
||||
if target != "vpn" && target != "direct" {
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
|
||||
rel := normalizeCgroupRelOnly(it.CgroupRel)
|
||||
if rel == "" {
|
||||
rel = normalizeCgroupRelOnly(it.Cgroup)
|
||||
}
|
||||
if rel == "" {
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
|
||||
id := it.ID
|
||||
if id == 0 {
|
||||
inode, err := cgroupDirInode(rel)
|
||||
if err != nil {
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
id = inode
|
||||
it.ID = inode
|
||||
changed = true
|
||||
}
|
||||
|
||||
level := it.Level
|
||||
if level <= 0 {
|
||||
level = strings.Count(strings.Trim(rel, "/"), "/") + 1
|
||||
it.Level = level
|
||||
changed = true
|
||||
}
|
||||
|
||||
abs := "/" + strings.TrimPrefix(rel, "/")
|
||||
it.CgroupRel = rel
|
||||
it.Cgroup = abs
|
||||
|
||||
if _, err := cgroupDirInode(rel); err != nil {
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
|
||||
iface := ""
|
||||
if target == "vpn" {
|
||||
if vpnIface == "" {
|
||||
// Keep state for later retry when VPN interface appears.
|
||||
kept = append(kept, it)
|
||||
continue
|
||||
}
|
||||
iface = vpnIface
|
||||
}
|
||||
|
||||
if err := nftInsertAppMarkRule(target, rel, level, id, iface); err != nil {
|
||||
appendTraceLine("traffic", fmt.Sprintf("appmarks restore failed target=%s id=%d err=%v", target, id, err))
|
||||
kept = append(kept, it)
|
||||
continue
|
||||
}
|
||||
if !nftHasAppMarkRule(target, id) {
|
||||
appendTraceLine("traffic", fmt.Sprintf("appmarks restore post-check failed target=%s id=%d", target, id))
|
||||
kept = append(kept, it)
|
||||
continue
|
||||
}
|
||||
kept = append(kept, it)
|
||||
}
|
||||
st.Items = kept
|
||||
|
||||
if changed {
|
||||
return saveAppMarksState(st)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadAppMarksState() appMarksState {
|
||||
st := appMarksState{Version: 1}
|
||||
data, err := os.ReadFile(trafficAppMarksPath)
|
||||
@@ -679,18 +1025,88 @@ func loadAppMarksState() appMarksState {
|
||||
// RU: Best-effort миграция: нормализуем app_key в канонический вид.
|
||||
changed := false
|
||||
for i := range st.Items {
|
||||
st.Items[i].Target = strings.ToLower(strings.TrimSpace(st.Items[i].Target))
|
||||
canon := canonicalizeAppKey(st.Items[i].AppKey, st.Items[i].Command)
|
||||
if canon != "" && strings.TrimSpace(st.Items[i].AppKey) != canon {
|
||||
st.Items[i].AppKey = canon
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
if deduped, dedupChanged := dedupeAppMarkItems(st.Items); dedupChanged {
|
||||
st.Items = deduped
|
||||
changed = true
|
||||
}
|
||||
if changed {
|
||||
_ = saveAppMarksState(st)
|
||||
}
|
||||
return st
|
||||
}
|
||||
|
||||
func dedupeAppMarkItems(in []appMarkItem) ([]appMarkItem, bool) {
|
||||
if len(in) <= 1 {
|
||||
return in, false
|
||||
}
|
||||
out := make([]appMarkItem, 0, len(in))
|
||||
byTargetID := map[string]int{}
|
||||
byTargetApp := map[string]int{}
|
||||
changed := false
|
||||
|
||||
for _, raw := range in {
|
||||
it := raw
|
||||
it.Target = strings.ToLower(strings.TrimSpace(it.Target))
|
||||
if it.Target != "vpn" && it.Target != "direct" {
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
it.AppKey = canonicalizeAppKey(it.AppKey, it.Command)
|
||||
|
||||
if it.ID > 0 {
|
||||
idKey := fmt.Sprintf("%s:%d", it.Target, it.ID)
|
||||
if idx, ok := byTargetID[idKey]; ok {
|
||||
if preferAppMarkItem(it, out[idx]) {
|
||||
out[idx] = it
|
||||
}
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
byTargetID[idKey] = len(out)
|
||||
}
|
||||
|
||||
if it.AppKey != "" {
|
||||
appKey := it.Target + "|" + it.AppKey
|
||||
if idx, ok := byTargetApp[appKey]; ok {
|
||||
if preferAppMarkItem(it, out[idx]) {
|
||||
out[idx] = it
|
||||
}
|
||||
changed = true
|
||||
continue
|
||||
}
|
||||
byTargetApp[appKey] = len(out)
|
||||
}
|
||||
|
||||
out = append(out, it)
|
||||
}
|
||||
return out, changed
|
||||
}
|
||||
|
||||
func preferAppMarkItem(cand, cur appMarkItem) bool {
|
||||
ca := strings.TrimSpace(cand.AddedAt)
|
||||
oa := strings.TrimSpace(cur.AddedAt)
|
||||
if ca != oa {
|
||||
if ca == "" {
|
||||
return false
|
||||
}
|
||||
if oa == "" {
|
||||
return true
|
||||
}
|
||||
return ca > oa
|
||||
}
|
||||
if strings.TrimSpace(cand.Command) != "" && strings.TrimSpace(cur.Command) == "" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func saveAppMarksState(st appMarksState) error {
|
||||
st.Version = 1
|
||||
st.UpdatedAt = time.Now().UTC().Format(time.RFC3339)
|
||||
|
||||
@@ -11,11 +11,13 @@ import (
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
trafficRulePrefMarkDirect = 11500
|
||||
trafficRulePrefMarkIngressReply = 11505
|
||||
trafficRulePrefMarkAppVPN = 11510
|
||||
trafficRulePrefDirectSubnetStart = 11600
|
||||
trafficRulePrefDirectUIDStart = 11680
|
||||
@@ -27,6 +29,13 @@ const (
|
||||
trafficRulePrefManagedMax = 12099
|
||||
trafficRulePerKindLimit = 70
|
||||
trafficAutoLocalDefault = true
|
||||
trafficIngressReplyDefault = false
|
||||
|
||||
trafficIngressPreroutingChain = "prerouting_ingress_reply"
|
||||
trafficIngressOutputChain = "output_ingress_reply"
|
||||
|
||||
trafficIngressCaptureComment = "svpn_ingress_reply_capture"
|
||||
trafficIngressRestoreComment = "svpn_ingress_reply_restore"
|
||||
)
|
||||
|
||||
var cgnatPrefix = netip.MustParsePrefix("100.64.0.0/10")
|
||||
@@ -199,6 +208,7 @@ func loadTrafficModeState() TrafficModeState {
|
||||
Mode TrafficMode `json:"mode"`
|
||||
PreferredIface string `json:"preferred_iface,omitempty"`
|
||||
AutoLocalBypass *bool `json:"auto_local_bypass,omitempty"`
|
||||
IngressReplyBypass *bool `json:"ingress_reply_bypass,omitempty"`
|
||||
ForceVPNSubnets []string `json:"force_vpn_subnets,omitempty"`
|
||||
ForceVPNUIDs []string `json:"force_vpn_uids,omitempty"`
|
||||
ForceVPNCGroups []string `json:"force_vpn_cgroups,omitempty"`
|
||||
@@ -214,6 +224,7 @@ func loadTrafficModeState() TrafficModeState {
|
||||
Mode: raw.Mode,
|
||||
PreferredIface: raw.PreferredIface,
|
||||
AutoLocalBypass: trafficAutoLocalDefault,
|
||||
IngressReplyBypass: trafficIngressReplyDefault,
|
||||
ForceVPNSubnets: append([]string(nil), raw.ForceVPNSubnets...),
|
||||
ForceVPNUIDs: append([]string(nil), raw.ForceVPNUIDs...),
|
||||
ForceVPNCGroups: append([]string(nil), raw.ForceVPNCGroups...),
|
||||
@@ -224,6 +235,9 @@ func loadTrafficModeState() TrafficModeState {
|
||||
if raw.AutoLocalBypass != nil {
|
||||
st.AutoLocalBypass = *raw.AutoLocalBypass
|
||||
}
|
||||
if raw.IngressReplyBypass != nil {
|
||||
st.IngressReplyBypass = *raw.IngressReplyBypass
|
||||
}
|
||||
return normalizeTrafficModeState(st)
|
||||
}
|
||||
|
||||
@@ -253,6 +267,7 @@ func inferTrafficModeState() TrafficModeState {
|
||||
Mode: mode,
|
||||
PreferredIface: iface,
|
||||
AutoLocalBypass: trafficAutoLocalDefault,
|
||||
IngressReplyBypass: trafficIngressReplyDefault,
|
||||
ForceVPNSubnets: nil,
|
||||
ForceVPNUIDs: nil,
|
||||
ForceVPNCGroups: nil,
|
||||
@@ -529,6 +544,116 @@ func applyAutoLocalBypass(vpnIface string) {
|
||||
}
|
||||
}
|
||||
|
||||
func nftObjectMissing(stdout, stderr string) bool {
|
||||
text := strings.ToLower(strings.TrimSpace(stdout + " " + stderr))
|
||||
return strings.Contains(text, "no such file") || strings.Contains(text, "not found")
|
||||
}
|
||||
|
||||
func ensureIngressReplyBypassChains() {
|
||||
_, _, _, _ = runCommandTimeout(5*time.Second, "nft", "add", "table", "inet", routesTableName())
|
||||
_, _, _, _ = runCommandTimeout(
|
||||
5*time.Second,
|
||||
"nft", "add", "chain", "inet", routesTableName(), trafficIngressPreroutingChain,
|
||||
"{", "type", "filter", "hook", "prerouting", "priority", "mangle;", "policy", "accept;", "}",
|
||||
)
|
||||
_, _, _, _ = runCommandTimeout(
|
||||
5*time.Second,
|
||||
"nft", "add", "chain", "inet", routesTableName(), trafficIngressOutputChain,
|
||||
"{", "type", "route", "hook", "output", "priority", "mangle;", "policy", "accept;", "}",
|
||||
)
|
||||
}
|
||||
|
||||
func flushIngressReplyBypassChains() error {
|
||||
for _, chain := range []string{trafficIngressPreroutingChain, trafficIngressOutputChain} {
|
||||
out, errOut, code, err := runCommandTimeout(5*time.Second, "nft", "flush", "chain", "inet", routesTableName(), chain)
|
||||
if err == nil && code == 0 {
|
||||
continue
|
||||
}
|
||||
if nftObjectMissing(out, errOut) {
|
||||
continue
|
||||
}
|
||||
if err == nil {
|
||||
err = fmt.Errorf("nft flush chain exited with %d", code)
|
||||
}
|
||||
return fmt.Errorf("flush %s failed: %w (%s %s)", chain, err, strings.TrimSpace(out), strings.TrimSpace(errOut))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func enableIngressReplyBypass(vpnIface string) error {
|
||||
vpnIface = strings.TrimSpace(vpnIface)
|
||||
if vpnIface == "" {
|
||||
return fmt.Errorf("empty vpn iface for ingress bypass")
|
||||
}
|
||||
|
||||
ensureIngressReplyBypassChains()
|
||||
if err := flushIngressReplyBypassChains(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
addRule := func(chain string, args ...string) error {
|
||||
out, errOut, code, err := runCommandTimeout(5*time.Second, "nft", append([]string{"add", "rule", "inet", routesTableName(), chain}, args...)...)
|
||||
if err != nil || code != 0 {
|
||||
if err == nil {
|
||||
err = fmt.Errorf("nft add rule exited with %d", code)
|
||||
}
|
||||
return fmt.Errorf("nft add rule %s failed: %w (%s %s)", chain, err, strings.TrimSpace(out), strings.TrimSpace(errOut))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// EN: Mark inbound NEW connections (except loopback/VPN iface) so reply path can stay direct in full tunnel.
|
||||
// RU: Помечаем входящие NEW-соединения (кроме loopback/VPN iface), чтобы ответ шел напрямую в full tunnel.
|
||||
if err := addRule(
|
||||
trafficIngressPreroutingChain,
|
||||
"iifname", "!=", "lo",
|
||||
"iifname", "!=", vpnIface,
|
||||
"fib", "daddr", "type", "local",
|
||||
"ct", "state", "new",
|
||||
"ct", "mark", "set", MARK_INGRESS,
|
||||
"comment", trafficIngressCaptureComment,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
// EN: Restore fwmark from ct mark in prerouting for forwarded reply traffic.
|
||||
// RU: Восстанавливаем fwmark из ct mark в prerouting для forwarded-ответов.
|
||||
if err := addRule(
|
||||
trafficIngressPreroutingChain,
|
||||
"ct", "mark", MARK_INGRESS,
|
||||
"meta", "mark", "set", MARK_INGRESS,
|
||||
"comment", trafficIngressRestoreComment,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
// EN: Restore fwmark from ct mark in output for local-process replies.
|
||||
// RU: Восстанавливаем fwmark из ct mark в output для ответов локальных процессов.
|
||||
if err := addRule(
|
||||
trafficIngressOutputChain,
|
||||
"ct", "mark", MARK_INGRESS,
|
||||
"meta", "mark", "set", MARK_INGRESS,
|
||||
"comment", trafficIngressRestoreComment,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func disableIngressReplyBypass() error {
|
||||
ensureIngressReplyBypassChains()
|
||||
return flushIngressReplyBypassChains()
|
||||
}
|
||||
|
||||
func ingressReplyNftActive() bool {
|
||||
outPre, _, codePre, _ := runCommandTimeout(5*time.Second, "nft", "-a", "list", "chain", "inet", routesTableName(), trafficIngressPreroutingChain)
|
||||
outOut, _, codeOut, _ := runCommandTimeout(5*time.Second, "nft", "-a", "list", "chain", "inet", routesTableName(), trafficIngressOutputChain)
|
||||
if codePre != 0 || codeOut != 0 {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(outPre, trafficIngressCaptureComment) &&
|
||||
strings.Contains(outPre, trafficIngressRestoreComment) &&
|
||||
strings.Contains(outOut, trafficIngressRestoreComment)
|
||||
}
|
||||
|
||||
func prefStr(v int) string {
|
||||
return strconv.Itoa(v)
|
||||
}
|
||||
@@ -827,16 +952,22 @@ func ensureTrafficRouteBase(iface string, autoLocalBypass bool) error {
|
||||
func applyTrafficMode(st TrafficModeState, iface string) error {
|
||||
st = normalizeTrafficModeState(st)
|
||||
eff := buildEffectiveOverrides(st)
|
||||
advancedActive := st.Mode == TrafficModeFullTunnel
|
||||
autoLocalActive := advancedActive && st.AutoLocalBypass
|
||||
ingressReplyActive := advancedActive && st.IngressReplyBypass
|
||||
|
||||
removeTrafficRulesForTable()
|
||||
|
||||
// EN: Ensure the policy table name exists even in direct mode so mark-based rules can be installed.
|
||||
// RU: Гарантируем наличие имени policy-table даже в direct режиме, чтобы можно было ставить mark-правила.
|
||||
ensureRoutesTableEntry()
|
||||
if err := disableIngressReplyBypass(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
needVPNTable := st.Mode != TrafficModeDirect || len(eff.VPNSubnets) > 0 || len(eff.VPNUIDs) > 0
|
||||
if needVPNTable {
|
||||
if err := ensureTrafficRouteBase(iface, st.AutoLocalBypass); err != nil {
|
||||
if err := ensureTrafficRouteBase(iface, autoLocalActive); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -852,6 +983,11 @@ func applyTrafficMode(st TrafficModeState, iface string) error {
|
||||
if err := applyRule(trafficRulePrefMarkDirect, "fwmark", MARK_DIRECT, "lookup", "main"); err != nil {
|
||||
return err
|
||||
}
|
||||
if ingressReplyActive {
|
||||
if err := applyRule(trafficRulePrefMarkIngressReply, "fwmark", MARK_INGRESS, "lookup", "main"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := applyRule(trafficRulePrefMarkAppVPN, "fwmark", MARK_APP, "lookup", routesTableName()); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -870,13 +1006,23 @@ func applyTrafficMode(st TrafficModeState, iface string) error {
|
||||
default:
|
||||
return fmt.Errorf("unknown traffic mode: %s", st.Mode)
|
||||
}
|
||||
if ingressReplyActive {
|
||||
if err := enableIngressReplyBypass(iface); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := restoreAppMarksFromState(); err != nil {
|
||||
appendTraceLine("traffic", fmt.Sprintf("appmarks restore warning: %v", err))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type trafficRulesState struct {
|
||||
Mark bool
|
||||
Full bool
|
||||
Mark bool
|
||||
Full bool
|
||||
IngressReply bool
|
||||
}
|
||||
|
||||
func readTrafficRules() trafficRulesState {
|
||||
@@ -884,7 +1030,7 @@ func readTrafficRules() trafficRulesState {
|
||||
var st trafficRulesState
|
||||
for _, line := range strings.Split(out, "\n") {
|
||||
l := strings.ToLower(strings.TrimSpace(line))
|
||||
if l == "" || !strings.Contains(l, "lookup "+routesTableName()) {
|
||||
if l == "" {
|
||||
continue
|
||||
}
|
||||
fields := strings.Fields(l)
|
||||
@@ -895,9 +1041,17 @@ func readTrafficRules() trafficRulesState {
|
||||
pref, _ := strconv.Atoi(prefRaw)
|
||||
switch pref {
|
||||
case trafficRulePrefSelective:
|
||||
st.Mark = true
|
||||
if strings.Contains(l, "lookup "+routesTableName()) {
|
||||
st.Mark = true
|
||||
}
|
||||
case trafficRulePrefFull:
|
||||
st.Full = true
|
||||
if strings.Contains(l, "lookup "+routesTableName()) {
|
||||
st.Full = true
|
||||
}
|
||||
case trafficRulePrefMarkIngressReply:
|
||||
if strings.Contains(l, "fwmark "+strings.ToLower(MARK_INGRESS)) && strings.Contains(l, "lookup main") {
|
||||
st.IngressReply = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return st
|
||||
@@ -954,12 +1108,20 @@ func probeTrafficMode(mode TrafficMode, iface string) (bool, string) {
|
||||
func evaluateTrafficMode(st TrafficModeState) TrafficModeStatusResponse {
|
||||
st = normalizeTrafficModeState(st)
|
||||
eff := buildEffectiveOverrides(st)
|
||||
advancedActive := st.Mode == TrafficModeFullTunnel
|
||||
autoLocalActive := advancedActive && st.AutoLocalBypass
|
||||
ingressDesired := st.IngressReplyBypass
|
||||
ingressExpected := advancedActive && ingressDesired
|
||||
hasVPN := len(eff.VPNSubnets) > 0 || len(eff.VPNUIDs) > 0
|
||||
iface, reason := resolveTrafficIface(st.PreferredIface)
|
||||
rules := readTrafficRules()
|
||||
applied := detectAppliedTrafficMode(rules)
|
||||
ingressNft := false
|
||||
if rules.IngressReply || st.Mode == TrafficModeFullTunnel || st.IngressReplyBypass {
|
||||
ingressNft = ingressReplyNftActive()
|
||||
}
|
||||
bypassCandidates := 0
|
||||
if st.AutoLocalBypass && (st.Mode != TrafficModeDirect || hasVPN) {
|
||||
if autoLocalActive && (st.Mode != TrafficModeDirect || hasVPN) {
|
||||
bypassCandidates = len(detectAutoLocalBypassRoutes(iface))
|
||||
}
|
||||
|
||||
@@ -976,7 +1138,11 @@ func evaluateTrafficMode(st TrafficModeState) TrafficModeStatusResponse {
|
||||
DesiredMode: st.Mode,
|
||||
AppliedMode: applied,
|
||||
PreferredIface: st.PreferredIface,
|
||||
AdvancedActive: advancedActive,
|
||||
AutoLocalBypass: st.AutoLocalBypass,
|
||||
AutoLocalActive: autoLocalActive,
|
||||
IngressReplyBypass: ingressDesired,
|
||||
IngressReplyActive: rules.IngressReply && ingressNft,
|
||||
BypassCandidates: bypassCandidates,
|
||||
ForceVPNSubnets: append([]string(nil), st.ForceVPNSubnets...),
|
||||
ForceVPNUIDs: append([]string(nil), st.ForceVPNUIDs...),
|
||||
@@ -991,6 +1157,8 @@ func evaluateTrafficMode(st TrafficModeState) TrafficModeStatusResponse {
|
||||
IfaceReason: reason,
|
||||
RuleMark: rules.Mark,
|
||||
RuleFull: rules.Full,
|
||||
IngressRulePresent: rules.IngressReply,
|
||||
IngressNftActive: ingressNft,
|
||||
TableDefault: tableDefault,
|
||||
}
|
||||
|
||||
@@ -1001,14 +1169,18 @@ func evaluateTrafficMode(st TrafficModeState) TrafficModeStatusResponse {
|
||||
// direct mode can still be healthy when vpn overrides exist
|
||||
// (base full/selective rules must be absent).
|
||||
if hasVPN {
|
||||
res.Healthy = !rules.Mark && !rules.Full && tableDefault && iface != "" && res.ProbeOK
|
||||
res.Healthy = !rules.Mark && !rules.Full && !rules.IngressReply && !ingressNft && tableDefault && iface != "" && res.ProbeOK
|
||||
} else {
|
||||
res.Healthy = !rules.Mark && !rules.Full && res.ProbeOK
|
||||
res.Healthy = !rules.Mark && !rules.Full && !rules.IngressReply && !ingressNft && res.ProbeOK
|
||||
}
|
||||
case TrafficModeFullTunnel:
|
||||
res.Healthy = rules.Full && !rules.Mark && tableDefault && iface != "" && res.ProbeOK
|
||||
if ingressExpected {
|
||||
res.Healthy = rules.Full && !rules.Mark && rules.IngressReply && ingressNft && tableDefault && iface != "" && res.ProbeOK
|
||||
} else {
|
||||
res.Healthy = rules.Full && !rules.Mark && !rules.IngressReply && !ingressNft && tableDefault && iface != "" && res.ProbeOK
|
||||
}
|
||||
case TrafficModeSelective:
|
||||
res.Healthy = rules.Mark && !rules.Full && tableDefault && iface != "" && res.ProbeOK
|
||||
res.Healthy = rules.Mark && !rules.Full && !rules.IngressReply && !ingressNft && tableDefault && iface != "" && res.ProbeOK
|
||||
default:
|
||||
res.Healthy = false
|
||||
}
|
||||
@@ -1037,6 +1209,14 @@ func evaluateTrafficMode(st TrafficModeState) TrafficModeStatusResponse {
|
||||
res.Message = "conflicting traffic rules detected"
|
||||
return res
|
||||
}
|
||||
if ingressExpected && (!rules.IngressReply || !ingressNft) {
|
||||
res.Message = "ingress-reply bypass rule is not active"
|
||||
return res
|
||||
}
|
||||
if !ingressExpected && (rules.IngressReply || ingressNft) {
|
||||
res.Message = "stale ingress-reply bypass rule is active"
|
||||
return res
|
||||
}
|
||||
res.Message = "traffic mode check failed"
|
||||
return res
|
||||
}
|
||||
@@ -1067,12 +1247,102 @@ func handleTrafficModeTest(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusOK, evaluateTrafficMode(st))
|
||||
}
|
||||
|
||||
func acquireTrafficApplyLock() (*os.File, *TrafficModeStatusResponse) {
|
||||
lock, err := os.OpenFile(lockFile, os.O_CREATE|os.O_RDWR, 0o644)
|
||||
if err != nil {
|
||||
msg := evaluateTrafficMode(loadTrafficModeState())
|
||||
msg.Message = "traffic lock open failed: " + err.Error()
|
||||
return nil, &msg
|
||||
}
|
||||
if err := syscall.Flock(int(lock.Fd()), syscall.LOCK_EX|syscall.LOCK_NB); err != nil {
|
||||
_ = lock.Close()
|
||||
msg := evaluateTrafficMode(loadTrafficModeState())
|
||||
msg.Message = "traffic apply skipped: routes operation already running"
|
||||
return nil, &msg
|
||||
}
|
||||
return lock, nil
|
||||
}
|
||||
|
||||
func handleTrafficAdvancedReset(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
lock, lockMsg := acquireTrafficApplyLock()
|
||||
if lockMsg != nil {
|
||||
writeJSON(w, http.StatusOK, *lockMsg)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = syscall.Flock(int(lock.Fd()), syscall.LOCK_UN)
|
||||
_ = lock.Close()
|
||||
}()
|
||||
|
||||
prev := normalizeTrafficModeState(loadTrafficModeState())
|
||||
next := prev
|
||||
next.AutoLocalBypass = false
|
||||
next.IngressReplyBypass = false
|
||||
|
||||
nextIface, _ := resolveTrafficIface(next.PreferredIface)
|
||||
if err := applyTrafficMode(next, nextIface); err != nil {
|
||||
prevIface, _ := resolveTrafficIface(prev.PreferredIface)
|
||||
_ = applyTrafficMode(prev, prevIface)
|
||||
msg := evaluateTrafficMode(prev)
|
||||
msg.Message = "advanced reset failed, rolled back: " + err.Error()
|
||||
writeJSON(w, http.StatusOK, msg)
|
||||
return
|
||||
}
|
||||
|
||||
if err := saveTrafficModeState(next); err != nil {
|
||||
prevIface, _ := resolveTrafficIface(prev.PreferredIface)
|
||||
_ = applyTrafficMode(prev, prevIface)
|
||||
_ = saveTrafficModeState(prev)
|
||||
msg := evaluateTrafficMode(prev)
|
||||
msg.Message = "advanced reset save failed, rolled back: " + err.Error()
|
||||
writeJSON(w, http.StatusOK, msg)
|
||||
return
|
||||
}
|
||||
|
||||
res := evaluateTrafficMode(next)
|
||||
if !res.Healthy {
|
||||
prevIface, _ := resolveTrafficIface(prev.PreferredIface)
|
||||
_ = applyTrafficMode(prev, prevIface)
|
||||
_ = saveTrafficModeState(prev)
|
||||
rolled := evaluateTrafficMode(prev)
|
||||
rolled.Message = "advanced reset verification failed, rolled back: " + res.Message
|
||||
writeJSON(w, http.StatusOK, rolled)
|
||||
return
|
||||
}
|
||||
|
||||
events.push("traffic_advanced_reset", map[string]any{
|
||||
"mode": res.Mode,
|
||||
"applied": res.AppliedMode,
|
||||
"active_iface": res.ActiveIface,
|
||||
"healthy": res.Healthy,
|
||||
"auto_local": res.AutoLocalBypass,
|
||||
"ingress_reply": res.IngressReplyBypass,
|
||||
"advanced_active": res.AdvancedActive,
|
||||
})
|
||||
res.Message = "advanced bypass reset"
|
||||
writeJSON(w, http.StatusOK, res)
|
||||
}
|
||||
|
||||
func handleTrafficMode(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
st := loadTrafficModeState()
|
||||
writeJSON(w, http.StatusOK, evaluateTrafficMode(st))
|
||||
case http.MethodPost:
|
||||
lock, lockMsg := acquireTrafficApplyLock()
|
||||
if lockMsg != nil {
|
||||
writeJSON(w, http.StatusOK, *lockMsg)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = syscall.Flock(int(lock.Fd()), syscall.LOCK_UN)
|
||||
_ = lock.Close()
|
||||
}()
|
||||
|
||||
prev := loadTrafficModeState()
|
||||
next := prev
|
||||
|
||||
@@ -1094,6 +1364,9 @@ func handleTrafficMode(w http.ResponseWriter, r *http.Request) {
|
||||
if body.AutoLocalBypass != nil {
|
||||
next.AutoLocalBypass = *body.AutoLocalBypass
|
||||
}
|
||||
if body.IngressReplyBypass != nil {
|
||||
next.IngressReplyBypass = *body.IngressReplyBypass
|
||||
}
|
||||
if body.ForceVPNSubnets != nil {
|
||||
next.ForceVPNSubnets = append([]string(nil), (*body.ForceVPNSubnets)...)
|
||||
}
|
||||
@@ -1127,21 +1400,12 @@ func handleTrafficMode(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if err := saveTrafficModeState(next); err != nil {
|
||||
writeJSON(w, http.StatusOK, TrafficModeStatusResponse{
|
||||
Mode: next.Mode,
|
||||
DesiredMode: next.Mode,
|
||||
PreferredIface: next.PreferredIface,
|
||||
AutoLocalBypass: next.AutoLocalBypass,
|
||||
ForceVPNSubnets: append([]string(nil), next.ForceVPNSubnets...),
|
||||
ForceVPNUIDs: append([]string(nil), next.ForceVPNUIDs...),
|
||||
ForceVPNCGroups: append([]string(nil), next.ForceVPNCGroups...),
|
||||
ForceDirectSubnets: append([]string(nil), next.ForceDirectSubnets...),
|
||||
ForceDirectUIDs: append([]string(nil), next.ForceDirectUIDs...),
|
||||
ForceDirectCGroups: append([]string(nil), next.ForceDirectCGroups...),
|
||||
OverridesApplied: len(next.ForceVPNSubnets) + len(next.ForceVPNUIDs) + len(next.ForceDirectSubnets) + len(next.ForceDirectUIDs),
|
||||
Healthy: false,
|
||||
Message: "state save failed: " + err.Error(),
|
||||
})
|
||||
prevIface, _ := resolveTrafficIface(prev.PreferredIface)
|
||||
_ = applyTrafficMode(prev, prevIface)
|
||||
_ = saveTrafficModeState(prev)
|
||||
rolled := evaluateTrafficMode(prev)
|
||||
rolled.Message = "state save failed, rolled back: " + err.Error()
|
||||
writeJSON(w, http.StatusOK, rolled)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1161,7 +1425,11 @@ func handleTrafficMode(w http.ResponseWriter, r *http.Request) {
|
||||
"applied": res.AppliedMode,
|
||||
"active_iface": res.ActiveIface,
|
||||
"healthy": res.Healthy,
|
||||
"advanced_active": res.AdvancedActive,
|
||||
"auto_local_bypass": res.AutoLocalBypass,
|
||||
"auto_local_active": res.AutoLocalActive,
|
||||
"ingress_reply": res.IngressReplyBypass,
|
||||
"ingress_active": res.IngressReplyActive,
|
||||
"overrides_applied": res.OverridesApplied,
|
||||
})
|
||||
writeJSON(w, http.StatusOK, res)
|
||||
|
||||
Reference in New Issue
Block a user