package transportcfg import ( "context" "encoding/json" "fmt" "net" "net/netip" "net/url" "os" "strconv" "strings" "time" ) const ( DefaultHealthTimeout = 5 * time.Second DefaultProbeTimeout = 900 * time.Millisecond ) type DialRunner func(ctx context.Context, network, address string) (net.Conn, error) type Endpoint struct { Host string Port int } func (ep Endpoint) Address() string { return net.JoinHostPort(ep.Host, strconv.Itoa(ep.Port)) } type ProbeDeps struct { Dial DialRunner HealthTimeout time.Duration ProbeTimeout time.Duration NetnsEnabled func(Client) bool NetnsName func(Client) string NetnsExecCommand func(Client, string, ...string) (string, []string, error) RunCommand func(time.Duration, string, ...string) (string, string, int, error) CommandError func(string, string, string, int, error) error ShellJoinArgs func([]string) string ReadFile func(string) ([]byte, error) ConfigInt func(map[string]any, string, int) int } func ProbeClientLatency(client Client, deps ProbeDeps) (int, error) { healthTimeout := deps.HealthTimeout if healthTimeout <= 0 { healthTimeout = DefaultHealthTimeout } probeTimeout := deps.ProbeTimeout if probeTimeout <= 0 { probeTimeout = DefaultProbeTimeout } endpoints := CollectProbeEndpoints(client, deps.ReadFile, deps.ConfigInt) if len(endpoints) == 0 { return 0, nil } deadline := time.Now().Add(healthTimeout) var firstErr error for _, ep := range endpoints { remaining := time.Until(deadline) if remaining <= 0 { break } if remaining > probeTimeout { remaining = probeTimeout } ms, err := ProbeDialEndpoint(client, ep, remaining, deps) if err == nil && ms >= 0 { return ms, nil } if firstErr == nil && err != nil { firstErr = err } } if firstErr != nil { return 0, firstErr } return 0, nil } func ProbeDialEndpoint(client Client, ep Endpoint, timeout time.Duration, deps ProbeDeps) (int, error) { probeTimeout := deps.ProbeTimeout if probeTimeout <= 0 { probeTimeout = DefaultProbeTimeout } if timeout <= 0 { timeout = probeTimeout } if deps.NetnsEnabled != nil && deps.NetnsEnabled(client) { if ms, err := ProbeDialEndpointInNetns(client, ep, timeout, deps); err == nil { return ms, nil } // Fall back to host probe for compatibility. } return ProbeDialEndpointHost(ep, timeout, deps.Dial) } func ProbeDialEndpointHost(ep Endpoint, timeout time.Duration, dial DialRunner) (int, error) { host := strings.TrimSpace(ep.Host) if addr, err := netip.ParseAddr(strings.TrimSpace(ep.Host)); err == nil { host = addr.Unmap().String() } if host == "" || ep.Port <= 0 || ep.Port > 65535 { return 0, fmt.Errorf("invalid endpoint") } if timeout <= 0 { timeout = DefaultProbeTimeout } d := dial if d == nil { d = func(ctx context.Context, network, address string) (net.Conn, error) { var nd net.Dialer return nd.DialContext(ctx, network, address) } } ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() start := time.Now() conn, err := d(ctx, "tcp4", net.JoinHostPort(host, strconv.Itoa(ep.Port))) if err != nil { return 0, err } _ = conn.Close() ms := int(time.Since(start).Milliseconds()) if ms < 1 { ms = 1 } return ms, nil } func ProbeDialEndpointInNetns(client Client, ep Endpoint, timeout time.Duration, deps ProbeDeps) (int, error) { probeTimeout := deps.ProbeTimeout if probeTimeout <= 0 { probeTimeout = DefaultProbeTimeout } if timeout <= 0 { timeout = probeTimeout } if deps.NetnsName == nil || deps.NetnsExecCommand == nil || deps.RunCommand == nil { return 0, fmt.Errorf("netns probe dependencies are not configured") } ns := strings.TrimSpace(deps.NetnsName(client)) if ns == "" { return 0, fmt.Errorf("netns name is empty") } script := fmt.Sprintf( "set -e; t0=$(date +%%s%%3N); exec 3<>/dev/tcp/%s/%d; exec 3>&-; t1=$(date +%%s%%3N); echo $((t1-t0))", strings.TrimSpace(ep.Host), ep.Port, ) name, args, err := deps.NetnsExecCommand(client, ns, "bash", "-lc", script) if err != nil { return 0, err } start := time.Now() stdout, stderr, code, runErr := deps.RunCommand(timeout+500*time.Millisecond, name, args...) if runErr != nil || code != 0 { cmdErr := deps.CommandError if cmdErr == nil { cmdErr = defaultCommandError } join := deps.ShellJoinArgs if join == nil { join = func(in []string) string { return ShellJoinArgs(in, ShellQuoteArg) } } return 0, cmdErr(join(append([]string{name}, args...)), stdout, stderr, code, runErr) } val := strings.TrimSpace(stdout) ms, err := strconv.Atoi(val) if err != nil || ms <= 0 { ms = int(time.Since(start).Milliseconds()) } if ms < 1 { ms = 1 } return ms, nil } func CollectProbeEndpoints(client Client, readFile func(string) ([]byte, error), configInt func(map[string]any, string, int) int) []Endpoint { combined := make([]Endpoint, 0, 12) if strings.ToLower(strings.TrimSpace(client.Kind)) == KindSingBox { combined = append(combined, CollectSingBoxConfigProbeEndpoints(client, readFile)...) } combined = append(combined, CollectConfigProbeEndpoints(client.Config, configInt)...) return DedupeProbeEndpoints(combined) } func CollectConfigProbeEndpoints(cfg map[string]any, configInt func(map[string]any, string, int) int) []Endpoint { intGetter := configInt if intGetter == nil { intGetter = ConfigInt } if cfg == nil { return nil } rawHosts := SplitCSV(ConfigString(cfg, "probe_endpoints")) if len(rawHosts) == 0 { host := strings.TrimSpace(ConfigString(cfg, "endpoint_host")) port := intGetter(cfg, "endpoint_port", 443) if host != "" && port > 0 { rawHosts = []string{fmt.Sprintf("%s:%d", host, port)} } } fallbackPort := intGetter(cfg, "probe_port", 443) out := make([]Endpoint, 0, len(rawHosts)) for _, raw := range rawHosts { if ep, ok := ParseDialEndpoint(raw, fallbackPort); ok { out = append(out, ep) } } return out } func CollectSingBoxConfigProbeEndpoints(client Client, readFile func(string) ([]byte, error)) []Endpoint { reader := readFile if reader == nil { reader = os.ReadFile } path := strings.TrimSpace(ConfigString(client.Config, "config_path")) if path == "" { path = strings.TrimSpace(ConfigString(client.Config, "singbox_config_path")) } if path == "" { return nil } data, err := reader(path) if err != nil { return nil } var raw any if err := json.Unmarshal(data, &raw); err != nil { return nil } out := make([]Endpoint, 0, 8) if raw != nil { CollectProbeEndpointsRecursive(raw, &out) } return out } func CollectProbeEndpointsRecursive(node any, out *[]Endpoint) { switch v := node.(type) { case map[string]any: fallbackPort := 443 for _, key := range []string{"server_port", "port", "listen_port"} { if p, ok := ParseInt(v[key]); ok && p > 0 && p <= 65535 { fallbackPort = p break } } for _, key := range []string{"server", "address", "host"} { raw, ok := v[key] if !ok || raw == nil { continue } if vv, ok := raw.(string); ok { if ep, ok := ParseDialEndpoint(vv, fallbackPort); ok { *out = append(*out, ep) } } } for _, child := range v { CollectProbeEndpointsRecursive(child, out) } case []any: for _, child := range v { CollectProbeEndpointsRecursive(child, out) } } } func ParseDialEndpoint(raw string, fallbackPort int) (Endpoint, bool) { s := strings.TrimSpace(raw) if s == "" { return Endpoint{}, false } if strings.Contains(s, "://") { if u, err := url.Parse(s); err == nil { host := strings.TrimSpace(u.Hostname()) if host != "" { port := fallbackPort if p := u.Port(); p != "" { if parsed, err := strconv.Atoi(p); err == nil { port = parsed } } if port <= 0 { port = fallbackPort } if port > 0 && port <= 65535 { return Endpoint{Host: strings.ToLower(host), Port: port}, true } } } return Endpoint{}, false } host := s port := fallbackPort if h, p, err := net.SplitHostPort(s); err == nil { host = strings.TrimSpace(h) if parsed, err := strconv.Atoi(strings.TrimSpace(p)); err == nil { port = parsed } } else if idx := strings.LastIndex(s, ":"); idx > 0 && idx+1 < len(s) && !strings.Contains(s[idx+1:], ":") { candidateHost := strings.TrimSpace(s[:idx]) candidatePort := strings.TrimSpace(s[idx+1:]) if parsed, err := strconv.Atoi(candidatePort); err == nil { host = candidateHost port = parsed } } host = strings.TrimSpace(host) if host == "" { return Endpoint{}, false } if addr, err := netip.ParseAddr(host); err == nil { host = addr.Unmap().String() } if port <= 0 || port > 65535 { return Endpoint{}, false } return Endpoint{Host: strings.ToLower(host), Port: port}, true } func DedupeProbeEndpoints(in []Endpoint) []Endpoint { if len(in) == 0 { return nil } seen := map[string]struct{}{} out := make([]Endpoint, 0, len(in)) for _, ep := range in { host := strings.ToLower(strings.TrimSpace(ep.Host)) if host == "" || ep.Port <= 0 || ep.Port > 65535 { continue } key := fmt.Sprintf("%s:%d", host, ep.Port) if _, ok := seen[key]; ok { continue } seen[key] = struct{}{} out = append(out, Endpoint{Host: host, Port: ep.Port}) } return out } func ParseInt(raw any) (int, bool) { switch v := raw.(type) { case int: return v, true case int32: return int(v), true case int64: return int(v), true case float64: return int(v), true case string: s := strings.TrimSpace(v) if s == "" { return 0, false } n, err := strconv.Atoi(s) if err != nil { return 0, false } return n, true default: return 0, false } } func SplitCSV(raw string) []string { parts := strings.Split(raw, ",") out := make([]string, 0, len(parts)) for _, p := range parts { v := strings.TrimSpace(p) if v == "" { continue } out = append(out, v) } return out } func defaultCommandError(cmd, stdout, stderr string, code int, err error) error { if err == nil { err = fmt.Errorf("exit code %d", code) } cmd = strings.TrimSpace(cmd) stderr = strings.TrimSpace(stderr) if stderr != "" { return fmt.Errorf("%s: %w: %s", cmd, err, stderr) } stdout = strings.TrimSpace(stdout) if stdout != "" { return fmt.Errorf("%s: %w: %s", cmd, err, stdout) } return fmt.Errorf("%s: %w", cmd, err) }