Files

403 lines
10 KiB
Go

package transportcfg
import (
	"context"
	"encoding/json"
	"fmt"
	"net"
	"net/netip"
	"net/url"
	"os"
	"sort"
	"strconv"
	"strings"
	"time"
)
// Fallback timeouts used when the corresponding ProbeDeps field is not positive.
const (
	// DefaultProbeTimeout caps a single endpoint probe (ProbeDeps.ProbeTimeout).
	DefaultProbeTimeout = 900 * time.Millisecond
	// DefaultHealthTimeout caps a whole ProbeClientLatency run across all of a
	// client's endpoints (ProbeDeps.HealthTimeout).
	DefaultHealthTimeout = 5 * time.Second
)
// DialRunner opens a network connection. It has the same shape as
// (*net.Dialer).DialContext, so such a method value can be used directly.
type DialRunner func(ctx context.Context, network, address string) (net.Conn, error)

// Endpoint is a TCP probe target.
type Endpoint struct {
	Host string // host name or IP literal, expected without IPv6 brackets
	Port int    // TCP port
}

// Address renders ep as "host:port", bracketing IPv6 literals
// (for example "[::1]:443") as the net package expects.
func (ep Endpoint) Address() string {
	port := strconv.Itoa(ep.Port)
	return net.JoinHostPort(ep.Host, port)
}
// ProbeDeps carries the timeouts and injectable dependencies used by the
// latency probes (ProbeClientLatency, ProbeDialEndpoint and its helpers).
// The zero value is usable for host-side probing: unset timeouts and most nil
// functions fall back to defaults, and a nil NetnsEnabled disables the
// network-namespace path.
type ProbeDeps struct {
	// Dial opens host-side TCP connections; nil means a zero net.Dialer.
	Dial DialRunner
	// HealthTimeout is the overall budget of ProbeClientLatency across all
	// endpoints; <= 0 means DefaultHealthTimeout.
	HealthTimeout time.Duration
	// ProbeTimeout caps a single endpoint probe; <= 0 means DefaultProbeTimeout.
	ProbeTimeout time.Duration
	// NetnsEnabled reports whether a client should first be probed inside a
	// network namespace; nil means never.
	NetnsEnabled func(Client) bool
	// NetnsName returns the namespace name for a client. Required, together
	// with NetnsExecCommand and RunCommand, for the namespace probe.
	NetnsName func(Client) string
	// NetnsExecCommand turns (client, namespace, program, args...) into the
	// program and argv that are actually executed; presumably it wraps the
	// command so it runs inside the namespace (e.g. `ip netns exec`) — confirm.
	NetnsExecCommand func(Client, string, ...string) (string, []string, error)
	// RunCommand runs (timeout, program, args...) and returns stdout, stderr,
	// the exit code and any run error.
	RunCommand func(time.Duration, string, ...string) (string, string, int, error)
	// CommandError builds the error for a failed command from (command line,
	// stdout, stderr, exit code, run error); nil means defaultCommandError.
	CommandError func(string, string, string, int, error) error
	// ShellJoinArgs renders an argv as one string for error messages; nil
	// means ShellJoinArgs with ShellQuoteArg.
	ShellJoinArgs func([]string) string
	// ReadFile reads the sing-box config file; nil means os.ReadFile.
	ReadFile func(string) ([]byte, error)
	// ConfigInt reads an integer setting with a default; nil means ConfigInt.
	ConfigInt func(map[string]any, string, int) int
}
// ProbeClientLatency measures the TCP connect latency, in milliseconds, to the
// first reachable probe endpoint of client.
//
// Endpoints come from CollectProbeEndpoints and are tried in order. Each
// attempt gets at most deps.ProbeTimeout (DefaultProbeTimeout when unset),
// further capped by what remains of the overall deps.HealthTimeout budget
// (DefaultHealthTimeout when unset). The first successful probe wins.
//
// It returns (0, nil) when the client has no probe endpoints, or when no
// attempt produced an error before the budget ran out. Otherwise, if no
// endpoint succeeded, it returns the first probe error encountered.
func ProbeClientLatency(client Client, deps ProbeDeps) (int, error) {
	budget := deps.HealthTimeout
	if budget <= 0 {
		budget = DefaultHealthTimeout
	}
	perProbe := deps.ProbeTimeout
	if perProbe <= 0 {
		perProbe = DefaultProbeTimeout
	}

	targets := CollectProbeEndpoints(client, deps.ReadFile, deps.ConfigInt)
	if len(targets) == 0 {
		return 0, nil
	}

	stopAt := time.Now().Add(budget)
	var firstFailure error
	for i := 0; i < len(targets); i++ {
		left := time.Until(stopAt)
		if left <= 0 {
			break
		}
		attempt := perProbe
		if left < attempt {
			attempt = left
		}
		latency, probeErr := ProbeDialEndpoint(client, targets[i], attempt, deps)
		if probeErr == nil && latency >= 0 {
			return latency, nil
		}
		if probeErr != nil && firstFailure == nil {
			firstFailure = probeErr
		}
	}
	return 0, firstFailure
}
// ProbeDialEndpoint measures the TCP connect latency, in milliseconds, to ep on
// behalf of client, staying within timeout. A non-positive timeout means
// deps.ProbeTimeout, or DefaultProbeTimeout when that is unset too.
//
// When deps.NetnsEnabled reports that client should be probed inside a
// network namespace, ProbeDialEndpointInNetns is tried first. If it fails,
// the probe falls back to dialing from the host (for compatibility). The
// fallback gets only what is left of timeout, because the caller (e.g.
// ProbeClientLatency) sizes timeout to fit its own overall deadline. If the
// namespace attempt used up the whole budget, its error is returned without
// trying the host.
//
// When both attempts fail, the returned error wraps the host error and also
// mentions the namespace error.
func ProbeDialEndpoint(client Client, ep Endpoint, timeout time.Duration, deps ProbeDeps) (int, error) {
	if timeout <= 0 {
		timeout = deps.ProbeTimeout
		if timeout <= 0 {
			timeout = DefaultProbeTimeout
		}
	}
	if deps.NetnsEnabled == nil || !deps.NetnsEnabled(client) {
		return ProbeDialEndpointHost(ep, timeout, deps.Dial)
	}

	started := time.Now()
	ms, nsErr := ProbeDialEndpointInNetns(client, ep, timeout, deps)
	if nsErr == nil {
		return ms, nil
	}

	// Don't restart the clock for the fallback: the netns attempt may already
	// have consumed part (or all) of the caller's budget.
	remaining := timeout - time.Since(started)
	if remaining <= 0 {
		return 0, fmt.Errorf("netns probe: %w", nsErr)
	}
	ms, hostErr := ProbeDialEndpointHost(ep, remaining, deps.Dial)
	if hostErr != nil {
		return 0, fmt.Errorf("%w (netns probe: %v)", hostErr, nsErr)
	}
	return ms, nil
}
// ProbeDialEndpointHost measures how long it takes to open (and immediately
// close) a TCP connection to ep from the host, in milliseconds, reported as
// at least 1.
//
// dial opens the connection; when nil, a zero net.Dialer is used. A
// non-positive timeout means DefaultProbeTimeout. IP literals are normalized,
// with IPv4-mapped IPv6 addresses unmapped. Host names and IPv4 addresses are
// dialed over "tcp4", keeping the existing IPv4-only policy. IPv6 literals are
// dialed over "tcp6", because "tcp4" can never reach them.
//
// It returns an "invalid endpoint" error for an empty host or a port outside
// 1-65535; dial errors are returned unchanged.
func ProbeDialEndpointHost(ep Endpoint, timeout time.Duration, dial DialRunner) (int, error) {
	host := strings.TrimSpace(ep.Host)
	network := "tcp4"
	if addr, err := netip.ParseAddr(host); err == nil {
		addr = addr.Unmap()
		host = addr.String()
		if addr.Is6() {
			network = "tcp6"
		}
	}
	if host == "" || ep.Port <= 0 || ep.Port > 65535 {
		return 0, fmt.Errorf("invalid endpoint")
	}
	if timeout <= 0 {
		timeout = DefaultProbeTimeout
	}
	if dial == nil {
		dial = func(ctx context.Context, nw, address string) (net.Conn, error) {
			var d net.Dialer
			return d.DialContext(ctx, nw, address)
		}
	}

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	start := time.Now()
	conn, err := dial(ctx, network, net.JoinHostPort(host, strconv.Itoa(ep.Port)))
	if err != nil {
		return 0, err
	}
	if conn != nil {
		// Only the handshake time matters; a close error tells us nothing.
		_ = conn.Close()
	}
	ms := int(time.Since(start).Milliseconds())
	if ms < 1 {
		ms = 1
	}
	return ms, nil
}
// ProbeDialEndpointInNetns measures the TCP connect latency, in milliseconds,
// to ep from inside the client's network namespace.
//
// It builds a small bash script that times opening /dev/tcp/<host>/<port>,
// wraps it with deps.NetnsExecCommand for the namespace named by
// deps.NetnsName, and runs it through deps.RunCommand with timeout plus a
// 500ms grace period. A non-positive timeout means deps.ProbeTimeout, or
// DefaultProbeTimeout when that is unset too.
//
// Errors are returned when:
//   - NetnsName, NetnsExecCommand or RunCommand is nil;
//   - the namespace name is empty;
//   - the host contains characters that are unsafe in the script, or the port
//     is outside 1-65535;
//   - wrapping the command fails;
//   - the command fails or exits non-zero. This error is built by
//     deps.CommandError (default defaultCommandError) from the command line
//     rendered by deps.ShellJoinArgs (default ShellJoinArgs with ShellQuoteArg).
//
// If the script's output cannot be parsed as a positive number, the wall-clock
// duration of the command is reported instead (at least 1ms).
func ProbeDialEndpointInNetns(client Client, ep Endpoint, timeout time.Duration, deps ProbeDeps) (int, error) {
	if timeout <= 0 {
		timeout = deps.ProbeTimeout
		if timeout <= 0 {
			timeout = DefaultProbeTimeout
		}
	}
	if deps.NetnsName == nil || deps.NetnsExecCommand == nil || deps.RunCommand == nil {
		return 0, fmt.Errorf("netns probe dependencies are not configured")
	}
	ns := strings.TrimSpace(deps.NetnsName(client))
	if ns == "" {
		return 0, fmt.Errorf("netns name is empty")
	}

	// The host is spliced unquoted into a bash script, and endpoints come from
	// config files. Allow only characters that can occur in a host name or an
	// IP literal (including an IPv6 zone); anything else could inject shell
	// syntax or break the /dev/tcp path.
	host := strings.TrimSpace(ep.Host)
	safe := host != ""
	for i := 0; safe && i < len(host); i++ {
		switch c := host[i]; {
		case 'a' <= c && c <= 'z', 'A' <= c && c <= 'Z', '0' <= c && c <= '9':
		case c == '.' || c == '-' || c == '_' || c == ':' || c == '%':
		default:
			safe = false
		}
	}
	if !safe || ep.Port <= 0 || ep.Port > 65535 {
		return 0, fmt.Errorf("invalid endpoint for netns probe: host %q port %d", ep.Host, ep.Port)
	}

	script := fmt.Sprintf(
		"set -e; t0=$(date +%%s%%3N); exec 3<>/dev/tcp/%s/%d; exec 3>&-; t1=$(date +%%s%%3N); echo $((t1-t0))",
		host,
		ep.Port,
	)
	name, args, err := deps.NetnsExecCommand(client, ns, "bash", "-lc", script)
	if err != nil {
		return 0, err
	}
	start := time.Now()
	stdout, stderr, code, runErr := deps.RunCommand(timeout+500*time.Millisecond, name, args...)
	if runErr != nil || code != 0 {
		cmdErr := deps.CommandError
		if cmdErr == nil {
			cmdErr = defaultCommandError
		}
		join := deps.ShellJoinArgs
		if join == nil {
			join = func(in []string) string { return ShellJoinArgs(in, ShellQuoteArg) }
		}
		return 0, cmdErr(join(append([]string{name}, args...)), stdout, stderr, code, runErr)
	}

	// A login shell (`bash -l`) may print banners before the measurement, so
	// take the number from the last output line.
	val := strings.TrimSpace(stdout)
	if i := strings.LastIndexByte(val, '\n'); i >= 0 {
		val = strings.TrimSpace(val[i+1:])
	}
	ms, err := strconv.Atoi(val)
	if err != nil || ms <= 0 {
		ms = int(time.Since(start).Milliseconds())
	}
	if ms < 1 {
		ms = 1
	}
	return ms, nil
}
// CollectProbeEndpoints gathers the TCP endpoints that can be probed for
// client, de-duplicated with DedupeProbeEndpoints.
//
// For sing-box clients (client.Kind equal to KindSingBox once trimmed and
// lower-cased), endpoints found in the sing-box JSON config come first.
// Endpoints declared directly in client.Config follow. readFile and
// configInt may be nil, in which case the collectors fall back to os.ReadFile
// and ConfigInt respectively.
func CollectProbeEndpoints(client Client, readFile func(string) ([]byte, error), configInt func(map[string]any, string, int) int) []Endpoint {
	isSingBox := strings.ToLower(strings.TrimSpace(client.Kind)) == KindSingBox

	var all []Endpoint
	if isSingBox {
		all = CollectSingBoxConfigProbeEndpoints(client, readFile)
	}
	all = append(all, CollectConfigProbeEndpoints(client.Config, configInt)...)
	return DedupeProbeEndpoints(all)
}
// CollectConfigProbeEndpoints extracts the probe endpoints declared directly
// in a client's config map.
//
// The comma-separated "probe_endpoints" setting takes precedence; entries
// without a port use "probe_port" (default 443). When it is empty,
// "endpoint_host" together with "endpoint_port" (default 443) forms a single
// candidate. Candidates rejected by ParseDialEndpoint are skipped. configInt
// reads integer settings and defaults to ConfigInt when nil. A nil cfg
// yields nil.
func CollectConfigProbeEndpoints(cfg map[string]any, configInt func(map[string]any, string, int) int) []Endpoint {
	if cfg == nil {
		return nil
	}
	getInt := configInt
	if getInt == nil {
		getInt = ConfigInt
	}

	candidates := SplitCSV(ConfigString(cfg, "probe_endpoints"))
	if len(candidates) == 0 {
		h := strings.TrimSpace(ConfigString(cfg, "endpoint_host"))
		p := getInt(cfg, "endpoint_port", 443)
		if h != "" && p > 0 {
			candidates = []string{fmt.Sprintf("%s:%d", h, p)}
		}
	}

	defaultPort := getInt(cfg, "probe_port", 443)
	endpoints := make([]Endpoint, 0, len(candidates))
	for i := range candidates {
		ep, ok := ParseDialEndpoint(candidates[i], defaultPort)
		if !ok {
			continue
		}
		endpoints = append(endpoints, ep)
	}
	return endpoints
}
// CollectSingBoxConfigProbeEndpoints reads the sing-box JSON config named by
// client.Config["config_path"] (or, when that is empty,
// client.Config["singbox_config_path"]) and returns every endpoint that
// CollectProbeEndpointsRecursive finds in it.
//
// readFile loads the file and defaults to os.ReadFile when nil. A missing
// path yields nil, as does a read error or invalid JSON; these are treated as
// "no endpoints" rather than reported.
func CollectSingBoxConfigProbeEndpoints(client Client, readFile func(string) ([]byte, error)) []Endpoint {
	var path string
	for _, key := range []string{"config_path", "singbox_config_path"} {
		if path = strings.TrimSpace(ConfigString(client.Config, key)); path != "" {
			break
		}
	}
	if path == "" {
		return nil
	}

	read := readFile
	if read == nil {
		read = os.ReadFile
	}
	data, err := read(path)
	if err != nil {
		return nil
	}
	var doc any
	if json.Unmarshal(data, &doc) != nil {
		return nil
	}

	found := make([]Endpoint, 0, 8)
	if doc != nil {
		CollectProbeEndpointsRecursive(doc, &found)
	}
	return found
}
// CollectProbeEndpointsRecursive walks a value decoded by encoding/json into
// an `any` and appends every endpoint it can find to out.
//
// In each object, string values under "server", "address" and "host" (checked
// in that order) are parsed with ParseDialEndpoint. Entries without an
// explicit port use the object's first valid (1-65535) "server_port", "port"
// or "listen_port" value, or 443 when there is none. Nested objects and
// arrays are visited depth-first. Object members are visited in sorted key
// order, so the result is deterministic. Other value types are ignored.
func CollectProbeEndpointsRecursive(node any, out *[]Endpoint) {
	switch v := node.(type) {
	case map[string]any:
		fallbackPort := 443
		for _, key := range []string{"server_port", "port", "listen_port"} {
			if p, ok := ParseInt(v[key]); ok && p > 0 && p <= 65535 {
				fallbackPort = p
				break
			}
		}
		for _, key := range []string{"server", "address", "host"} {
			s, ok := v[key].(string)
			if !ok {
				continue
			}
			if ep, ok := ParseDialEndpoint(s, fallbackPort); ok {
				*out = append(*out, ep)
			}
		}
		// Go randomizes map iteration order. Visiting children in sorted key
		// order keeps the endpoint order stable between runs, and with it the
		// endpoint ProbeClientLatency tries first.
		keys := make([]string, 0, len(v))
		for k := range v {
			keys = append(keys, k)
		}
		sort.Strings(keys)
		for _, k := range keys {
			CollectProbeEndpointsRecursive(v[k], out)
		}
	case []any:
		for _, child := range v {
			CollectProbeEndpointsRecursive(child, out)
		}
	}
}
// ParseDialEndpoint parses a probe target into an Endpoint.
//
// Accepted forms:
//   - URLs such as "tls://host:853" or "https://[::1]/path". The URL host is
//     used; a missing, non-numeric or non-positive port means fallbackPort.
//   - "host:port" and "[ipv6]:port".
//   - A bare host, "[ipv6]", or an unbracketed IPv6 address such as
//     "2001:db8::1"; these take fallbackPort.
//   - An unbracketed "ipv6:port" such as "2001:db8::1:8443". This form is
//     inherently ambiguous; it is read as address plus port only when the
//     part before the last colon is itself a valid IP address.
//
// Hosts are lower-cased, and IP literals are normalized (IPv4-mapped IPv6
// addresses are unmapped). It reports false for empty input, an empty host,
// a non-IP host containing ':', or a port outside 1-65535.
func ParseDialEndpoint(raw string, fallbackPort int) (Endpoint, bool) {
	s := strings.TrimSpace(raw)
	if s == "" {
		return Endpoint{}, false
	}

	var host string
	port := fallbackPort
	if strings.Contains(s, "://") {
		u, err := url.Parse(s)
		if err != nil {
			return Endpoint{}, false
		}
		host = strings.TrimSpace(u.Hostname())
		if p := u.Port(); p != "" {
			if parsed, err := strconv.Atoi(p); err == nil && parsed > 0 {
				port = parsed
			}
		}
	} else {
		host = s
		if h, p, err := net.SplitHostPort(s); err == nil {
			host = strings.TrimSpace(h)
			if parsed, err := strconv.Atoi(strings.TrimSpace(p)); err == nil {
				port = parsed
			}
		} else if idx := strings.LastIndex(s, ":"); idx > 0 && idx+1 < len(s) {
			candidateHost := strings.TrimSpace(s[:idx])
			parsed, perr := strconv.Atoi(strings.TrimSpace(s[idx+1:]))
			// A bare IPv6 address like "2001:db8::1" also ends in ":<digits>".
			// Only split off a port when what precedes it is still a usable
			// host: no colon at all, or a complete IP address.
			hostOK := !strings.Contains(candidateHost, ":")
			if !hostOK {
				_, aerr := netip.ParseAddr(candidateHost)
				hostOK = aerr == nil
			}
			if perr == nil && hostOK {
				host = candidateHost
				port = parsed
			}
		}
		host = strings.TrimSpace(host)
		// "[::1]" without a port fails SplitHostPort, so drop the brackets here.
		if len(host) >= 2 && host[0] == '[' && host[len(host)-1] == ']' {
			host = strings.TrimSpace(host[1 : len(host)-1])
		}
	}

	if host == "" {
		return Endpoint{}, false
	}
	if addr, err := netip.ParseAddr(host); err == nil {
		host = addr.Unmap().String()
	} else if strings.Contains(host, ":") {
		// Host names never contain ':'; this can only be a mangled address.
		return Endpoint{}, false
	}
	if port <= 0 || port > 65535 {
		return Endpoint{}, false
	}
	return Endpoint{Host: strings.ToLower(host), Port: port}, true
}
// DedupeProbeEndpoints normalizes endpoints and drops duplicates, keeping the
// first occurrence of each (host, port) pair in input order.
//
// Hosts are trimmed and lower-cased before comparison. Entries with an empty
// host or a port outside 1-65535 are dropped. A nil or empty input yields nil;
// any other input yields a non-nil slice, even if every entry was dropped.
func DedupeProbeEndpoints(in []Endpoint) []Endpoint {
	if len(in) == 0 {
		return nil
	}
	kept := make([]Endpoint, 0, len(in))
	seen := make(map[Endpoint]struct{}, len(in))
	for i := range in {
		norm := Endpoint{
			Host: strings.ToLower(strings.TrimSpace(in[i].Host)),
			Port: in[i].Port,
		}
		if norm.Host == "" || norm.Port < 1 || norm.Port > 65535 {
			continue
		}
		if _, dup := seen[norm]; dup {
			continue
		}
		seen[norm] = struct{}{}
		kept = append(kept, norm)
	}
	return kept
}
// ParseInt converts a loosely typed config value to an int.
//
// It accepts int, int32, int64, float64 (converted with Go's truncation
// toward zero; encoding/json decodes numbers into float64) and decimal
// strings with surrounding whitespace ignored. Any other type, or a string
// that is empty or fails strconv.Atoi, yields (0, false).
func ParseInt(raw any) (int, bool) {
	var n int
	switch val := raw.(type) {
	case int:
		n = val
	case int32:
		n = int(val)
	case int64:
		n = int(val)
	case float64:
		n = int(val)
	case string:
		parsed, err := strconv.Atoi(strings.TrimSpace(val))
		if err != nil {
			return 0, false
		}
		n = parsed
	default:
		return 0, false
	}
	return n, true
}
// SplitCSV splits raw on commas, trims whitespace around each field and
// drops empty fields. The result is never nil, though it may be empty.
func SplitCSV(raw string) []string {
	fields := make([]string, 0, strings.Count(raw, ",")+1)
	rest := raw
	for {
		field, tail, more := strings.Cut(rest, ",")
		if f := strings.TrimSpace(field); f != "" {
			fields = append(fields, f)
		}
		if !more {
			return fields
		}
		rest = tail
	}
}
// defaultCommandError formats a failed command as "<cmd>: <cause>[: <detail>]".
//
// The cause is err, or "exit code <code>" when err is nil, and it is wrapped
// with %w so errors.Is/As still see it. The detail is the trimmed stderr, or
// the trimmed stdout when stderr is blank; it is omitted when both are blank.
// cmd is trimmed as well.
func defaultCommandError(cmd, stdout, stderr string, code int, err error) error {
	cause := err
	if cause == nil {
		cause = fmt.Errorf("exit code %d", code)
	}
	cmd = strings.TrimSpace(cmd)

	detail := strings.TrimSpace(stderr)
	if detail == "" {
		detail = strings.TrimSpace(stdout)
	}
	if detail == "" {
		return fmt.Errorf("%s: %w", cmd, cause)
	}
	return fmt.Errorf("%s: %w: %s", cmd, cause, detail)
}