403 lines
10 KiB
Go
403 lines
10 KiB
Go
package transportcfg
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"net/url"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// Fallback timeouts applied when the corresponding ProbeDeps field is <= 0.
const (
	// DefaultHealthTimeout is the overall budget ProbeClientLatency spends
	// probing all of a client's endpoints.
	DefaultHealthTimeout = 5000 * time.Millisecond

	// DefaultProbeTimeout caps the time spent on a single endpoint probe.
	DefaultProbeTimeout = 900 * time.Millisecond
)
|
|
|
|
// DialRunner has the same shape as (*net.Dialer).DialContext.
// ProbeDialEndpointHost calls it to open the probe connection, so callers and
// tests can substitute their own dialer (nil there means a zero-value
// net.Dialer).
type DialRunner func(ctx context.Context, network, address string) (net.Conn, error)
|
|
|
|
// Endpoint is a TCP host/port pair that the latency probes connect to.
type Endpoint struct {
	// Host is a hostname or IP literal. ParseDialEndpoint and
	// DedupeProbeEndpoints store it trimmed and lower-cased.
	Host string

	// Port is the TCP port. The probe helpers reject values outside 1-65535.
	Port int
}
|
|
|
|
func (ep Endpoint) Address() string {
|
|
return net.JoinHostPort(ep.Host, strconv.Itoa(ep.Port))
|
|
}
|
|
|
|
// ProbeDeps bundles the injectable hooks and tunables used by the latency
// probes in this file. Nil hooks fall back to package defaults where noted;
// when a hook required by the network-namespace probe is missing, that probe
// returns an error and ProbeDialEndpoint falls back to a host-side dial.
type ProbeDeps struct {
	// Dial opens the TCP connection for host-side probes. Nil means a
	// zero-value net.Dialer.
	Dial DialRunner

	// HealthTimeout is the overall time budget ProbeClientLatency spends
	// across all endpoints. Values <= 0 mean DefaultHealthTimeout.
	HealthTimeout time.Duration

	// ProbeTimeout caps a single endpoint probe. Values <= 0 mean
	// DefaultProbeTimeout.
	ProbeTimeout time.Duration

	// NetnsEnabled reports whether the client should first be probed from
	// inside its network namespace. Nil disables the namespace probe.
	NetnsEnabled func(Client) bool

	// NetnsName returns the namespace name for the client. Required by the
	// namespace probe; a blank result makes that probe fail.
	NetnsName func(Client) string

	// NetnsExecCommand receives the client, the namespace name and a program
	// with its arguments, and returns the program name and argv that run it
	// inside that namespace. Required by the namespace probe.
	NetnsExecCommand func(Client, string, ...string) (string, []string, error)

	// RunCommand runs a program with args under the given timeout and returns
	// stdout, stderr, the exit code and any execution error. Required by the
	// namespace probe.
	RunCommand func(time.Duration, string, ...string) (string, string, int, error)

	// CommandError builds the error returned when the namespace probe command
	// fails, from the joined command line, stdout, stderr, exit code and run
	// error. Nil means defaultCommandError.
	CommandError func(string, string, string, int, error) error

	// ShellJoinArgs renders a command line for error messages. Nil means the
	// package-level ShellJoinArgs with ShellQuoteArg.
	ShellJoinArgs func([]string) string

	// ReadFile loads the sing-box config file. Nil means os.ReadFile.
	ReadFile func(string) ([]byte, error)

	// ConfigInt reads an integer from a client config map, returning the given
	// default when absent. Nil means the package-level ConfigInt.
	ConfigInt func(map[string]any, string, int) int
}
|
|
|
|
func ProbeClientLatency(client Client, deps ProbeDeps) (int, error) {
|
|
healthTimeout := deps.HealthTimeout
|
|
if healthTimeout <= 0 {
|
|
healthTimeout = DefaultHealthTimeout
|
|
}
|
|
probeTimeout := deps.ProbeTimeout
|
|
if probeTimeout <= 0 {
|
|
probeTimeout = DefaultProbeTimeout
|
|
}
|
|
|
|
endpoints := CollectProbeEndpoints(client, deps.ReadFile, deps.ConfigInt)
|
|
if len(endpoints) == 0 {
|
|
return 0, nil
|
|
}
|
|
|
|
deadline := time.Now().Add(healthTimeout)
|
|
var firstErr error
|
|
for _, ep := range endpoints {
|
|
remaining := time.Until(deadline)
|
|
if remaining <= 0 {
|
|
break
|
|
}
|
|
if remaining > probeTimeout {
|
|
remaining = probeTimeout
|
|
}
|
|
ms, err := ProbeDialEndpoint(client, ep, remaining, deps)
|
|
if err == nil && ms >= 0 {
|
|
return ms, nil
|
|
}
|
|
if firstErr == nil && err != nil {
|
|
firstErr = err
|
|
}
|
|
}
|
|
if firstErr != nil {
|
|
return 0, firstErr
|
|
}
|
|
return 0, nil
|
|
}
|
|
|
|
func ProbeDialEndpoint(client Client, ep Endpoint, timeout time.Duration, deps ProbeDeps) (int, error) {
|
|
probeTimeout := deps.ProbeTimeout
|
|
if probeTimeout <= 0 {
|
|
probeTimeout = DefaultProbeTimeout
|
|
}
|
|
if timeout <= 0 {
|
|
timeout = probeTimeout
|
|
}
|
|
if deps.NetnsEnabled != nil && deps.NetnsEnabled(client) {
|
|
if ms, err := ProbeDialEndpointInNetns(client, ep, timeout, deps); err == nil {
|
|
return ms, nil
|
|
}
|
|
// Fall back to host probe for compatibility.
|
|
}
|
|
return ProbeDialEndpointHost(ep, timeout, deps.Dial)
|
|
}
|
|
|
|
func ProbeDialEndpointHost(ep Endpoint, timeout time.Duration, dial DialRunner) (int, error) {
|
|
host := strings.TrimSpace(ep.Host)
|
|
if addr, err := netip.ParseAddr(strings.TrimSpace(ep.Host)); err == nil {
|
|
host = addr.Unmap().String()
|
|
}
|
|
if host == "" || ep.Port <= 0 || ep.Port > 65535 {
|
|
return 0, fmt.Errorf("invalid endpoint")
|
|
}
|
|
if timeout <= 0 {
|
|
timeout = DefaultProbeTimeout
|
|
}
|
|
d := dial
|
|
if d == nil {
|
|
d = func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
var nd net.Dialer
|
|
return nd.DialContext(ctx, network, address)
|
|
}
|
|
}
|
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
defer cancel()
|
|
start := time.Now()
|
|
conn, err := d(ctx, "tcp4", net.JoinHostPort(host, strconv.Itoa(ep.Port)))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
_ = conn.Close()
|
|
ms := int(time.Since(start).Milliseconds())
|
|
if ms < 1 {
|
|
ms = 1
|
|
}
|
|
return ms, nil
|
|
}
|
|
|
|
func ProbeDialEndpointInNetns(client Client, ep Endpoint, timeout time.Duration, deps ProbeDeps) (int, error) {
|
|
probeTimeout := deps.ProbeTimeout
|
|
if probeTimeout <= 0 {
|
|
probeTimeout = DefaultProbeTimeout
|
|
}
|
|
if timeout <= 0 {
|
|
timeout = probeTimeout
|
|
}
|
|
if deps.NetnsName == nil || deps.NetnsExecCommand == nil || deps.RunCommand == nil {
|
|
return 0, fmt.Errorf("netns probe dependencies are not configured")
|
|
}
|
|
ns := strings.TrimSpace(deps.NetnsName(client))
|
|
if ns == "" {
|
|
return 0, fmt.Errorf("netns name is empty")
|
|
}
|
|
script := fmt.Sprintf(
|
|
"set -e; t0=$(date +%%s%%3N); exec 3<>/dev/tcp/%s/%d; exec 3>&-; t1=$(date +%%s%%3N); echo $((t1-t0))",
|
|
strings.TrimSpace(ep.Host),
|
|
ep.Port,
|
|
)
|
|
name, args, err := deps.NetnsExecCommand(client, ns, "bash", "-lc", script)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
start := time.Now()
|
|
stdout, stderr, code, runErr := deps.RunCommand(timeout+500*time.Millisecond, name, args...)
|
|
if runErr != nil || code != 0 {
|
|
cmdErr := deps.CommandError
|
|
if cmdErr == nil {
|
|
cmdErr = defaultCommandError
|
|
}
|
|
join := deps.ShellJoinArgs
|
|
if join == nil {
|
|
join = func(in []string) string { return ShellJoinArgs(in, ShellQuoteArg) }
|
|
}
|
|
return 0, cmdErr(join(append([]string{name}, args...)), stdout, stderr, code, runErr)
|
|
}
|
|
val := strings.TrimSpace(stdout)
|
|
ms, err := strconv.Atoi(val)
|
|
if err != nil || ms <= 0 {
|
|
ms = int(time.Since(start).Milliseconds())
|
|
}
|
|
if ms < 1 {
|
|
ms = 1
|
|
}
|
|
return ms, nil
|
|
}
|
|
|
|
func CollectProbeEndpoints(client Client, readFile func(string) ([]byte, error), configInt func(map[string]any, string, int) int) []Endpoint {
|
|
combined := make([]Endpoint, 0, 12)
|
|
if strings.ToLower(strings.TrimSpace(client.Kind)) == KindSingBox {
|
|
combined = append(combined, CollectSingBoxConfigProbeEndpoints(client, readFile)...)
|
|
}
|
|
combined = append(combined, CollectConfigProbeEndpoints(client.Config, configInt)...)
|
|
return DedupeProbeEndpoints(combined)
|
|
}
|
|
|
|
func CollectConfigProbeEndpoints(cfg map[string]any, configInt func(map[string]any, string, int) int) []Endpoint {
|
|
intGetter := configInt
|
|
if intGetter == nil {
|
|
intGetter = ConfigInt
|
|
}
|
|
if cfg == nil {
|
|
return nil
|
|
}
|
|
rawHosts := SplitCSV(ConfigString(cfg, "probe_endpoints"))
|
|
if len(rawHosts) == 0 {
|
|
host := strings.TrimSpace(ConfigString(cfg, "endpoint_host"))
|
|
port := intGetter(cfg, "endpoint_port", 443)
|
|
if host != "" && port > 0 {
|
|
rawHosts = []string{fmt.Sprintf("%s:%d", host, port)}
|
|
}
|
|
}
|
|
fallbackPort := intGetter(cfg, "probe_port", 443)
|
|
out := make([]Endpoint, 0, len(rawHosts))
|
|
for _, raw := range rawHosts {
|
|
if ep, ok := ParseDialEndpoint(raw, fallbackPort); ok {
|
|
out = append(out, ep)
|
|
}
|
|
}
|
|
return out
|
|
}
|
|
|
|
func CollectSingBoxConfigProbeEndpoints(client Client, readFile func(string) ([]byte, error)) []Endpoint {
|
|
reader := readFile
|
|
if reader == nil {
|
|
reader = os.ReadFile
|
|
}
|
|
path := strings.TrimSpace(ConfigString(client.Config, "config_path"))
|
|
if path == "" {
|
|
path = strings.TrimSpace(ConfigString(client.Config, "singbox_config_path"))
|
|
}
|
|
if path == "" {
|
|
return nil
|
|
}
|
|
data, err := reader(path)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
var raw any
|
|
if err := json.Unmarshal(data, &raw); err != nil {
|
|
return nil
|
|
}
|
|
out := make([]Endpoint, 0, 8)
|
|
if raw != nil {
|
|
CollectProbeEndpointsRecursive(raw, &out)
|
|
}
|
|
return out
|
|
}
|
|
|
|
func CollectProbeEndpointsRecursive(node any, out *[]Endpoint) {
|
|
switch v := node.(type) {
|
|
case map[string]any:
|
|
fallbackPort := 443
|
|
for _, key := range []string{"server_port", "port", "listen_port"} {
|
|
if p, ok := ParseInt(v[key]); ok && p > 0 && p <= 65535 {
|
|
fallbackPort = p
|
|
break
|
|
}
|
|
}
|
|
for _, key := range []string{"server", "address", "host"} {
|
|
raw, ok := v[key]
|
|
if !ok || raw == nil {
|
|
continue
|
|
}
|
|
if vv, ok := raw.(string); ok {
|
|
if ep, ok := ParseDialEndpoint(vv, fallbackPort); ok {
|
|
*out = append(*out, ep)
|
|
}
|
|
}
|
|
}
|
|
for _, child := range v {
|
|
CollectProbeEndpointsRecursive(child, out)
|
|
}
|
|
case []any:
|
|
for _, child := range v {
|
|
CollectProbeEndpointsRecursive(child, out)
|
|
}
|
|
}
|
|
}
|
|
|
|
func ParseDialEndpoint(raw string, fallbackPort int) (Endpoint, bool) {
|
|
s := strings.TrimSpace(raw)
|
|
if s == "" {
|
|
return Endpoint{}, false
|
|
}
|
|
if strings.Contains(s, "://") {
|
|
if u, err := url.Parse(s); err == nil {
|
|
host := strings.TrimSpace(u.Hostname())
|
|
if host != "" {
|
|
port := fallbackPort
|
|
if p := u.Port(); p != "" {
|
|
if parsed, err := strconv.Atoi(p); err == nil {
|
|
port = parsed
|
|
}
|
|
}
|
|
if port <= 0 {
|
|
port = fallbackPort
|
|
}
|
|
if port > 0 && port <= 65535 {
|
|
return Endpoint{Host: strings.ToLower(host), Port: port}, true
|
|
}
|
|
}
|
|
}
|
|
return Endpoint{}, false
|
|
}
|
|
host := s
|
|
port := fallbackPort
|
|
if h, p, err := net.SplitHostPort(s); err == nil {
|
|
host = strings.TrimSpace(h)
|
|
if parsed, err := strconv.Atoi(strings.TrimSpace(p)); err == nil {
|
|
port = parsed
|
|
}
|
|
} else if idx := strings.LastIndex(s, ":"); idx > 0 && idx+1 < len(s) && !strings.Contains(s[idx+1:], ":") {
|
|
candidateHost := strings.TrimSpace(s[:idx])
|
|
candidatePort := strings.TrimSpace(s[idx+1:])
|
|
if parsed, err := strconv.Atoi(candidatePort); err == nil {
|
|
host = candidateHost
|
|
port = parsed
|
|
}
|
|
}
|
|
host = strings.TrimSpace(host)
|
|
if host == "" {
|
|
return Endpoint{}, false
|
|
}
|
|
if addr, err := netip.ParseAddr(host); err == nil {
|
|
host = addr.Unmap().String()
|
|
}
|
|
if port <= 0 || port > 65535 {
|
|
return Endpoint{}, false
|
|
}
|
|
return Endpoint{Host: strings.ToLower(host), Port: port}, true
|
|
}
|
|
|
|
func DedupeProbeEndpoints(in []Endpoint) []Endpoint {
|
|
if len(in) == 0 {
|
|
return nil
|
|
}
|
|
seen := map[string]struct{}{}
|
|
out := make([]Endpoint, 0, len(in))
|
|
for _, ep := range in {
|
|
host := strings.ToLower(strings.TrimSpace(ep.Host))
|
|
if host == "" || ep.Port <= 0 || ep.Port > 65535 {
|
|
continue
|
|
}
|
|
key := fmt.Sprintf("%s:%d", host, ep.Port)
|
|
if _, ok := seen[key]; ok {
|
|
continue
|
|
}
|
|
seen[key] = struct{}{}
|
|
out = append(out, Endpoint{Host: host, Port: ep.Port})
|
|
}
|
|
return out
|
|
}
|
|
|
|
// ParseInt converts a loosely typed configuration value into an int.
//
// Supported inputs:
//   - any built-in signed or unsigned integer type;
//   - float32/float64 holding an exact integer (encoding/json decodes JSON
//     numbers into float64 when the target is any);
//   - json.Number, as produced by a json.Decoder with UseNumber;
//   - decimal strings, with surrounding whitespace ignored.
//
// It reports false for nil, unsupported types, empty or malformed strings,
// fractional, NaN or infinite floats, and values that do not fit in an int.
// Fractional floats were previously truncated silently.
func ParseInt(raw any) (int, bool) {
	var n int64
	var f float64
	isFloat := false

	switch v := raw.(type) {
	case int:
		return v, true
	case int8:
		n = int64(v)
	case int16:
		n = int64(v)
	case int32:
		n = int64(v)
	case int64:
		n = v
	case uint8:
		n = int64(v)
	case uint16:
		n = int64(v)
	case uint32:
		n = int64(v)
	case uint:
		if uint64(v) > 1<<63-1 {
			return 0, false
		}
		n = int64(v)
	case uint64:
		if v > 1<<63-1 {
			return 0, false
		}
		n = int64(v)
	case float32:
		f, isFloat = float64(v), true
	case float64:
		f, isFloat = v, true
	case json.Number:
		s := strings.TrimSpace(string(v))
		if i, err := strconv.Atoi(s); err == nil {
			return i, true
		}
		parsed, err := strconv.ParseFloat(s, 64)
		if err != nil {
			return 0, false
		}
		f, isFloat = parsed, true
	case string:
		s := strings.TrimSpace(v)
		if s == "" {
			return 0, false
		}
		i, err := strconv.Atoi(s)
		if err != nil {
			return 0, false
		}
		return i, true
	default:
		return 0, false
	}

	if isFloat {
		// The range test also rejects NaN (every comparison with NaN is
		// false) and ±Inf. Inside the range the int64 conversion is well
		// defined, so a round-trip mismatch means f had a fractional part.
		if !(f >= -(1<<63) && f < 1<<63) {
			return 0, false
		}
		n = int64(f)
		if float64(n) != f {
			return 0, false
		}
	}

	// Guard against truncation on platforms where int is 32 bits wide.
	if int64(int(n)) != n {
		return 0, false
	}
	return int(n), true
}
|
|
|
|
// SplitCSV splits raw on commas, trims whitespace around each item and drops
// empty items. The result is never nil; blank input yields an empty slice.
func SplitCSV(raw string) []string {
	items := make([]string, 0, strings.Count(raw, ",")+1)
	rest := raw
	for {
		head, tail, more := strings.Cut(rest, ",")
		if item := strings.TrimSpace(head); item != "" {
			items = append(items, item)
		}
		if !more {
			break
		}
		rest = tail
	}
	return items
}
|
|
|
|
// defaultCommandError formats a failed command as
// "<cmd>: <cause>[: <detail>]".
//
// cause is err, or "exit code N" when err is nil, and is wrapped with %w so
// callers can inspect it. detail is the trimmed stderr, or the trimmed stdout
// when stderr is blank; it is omitted when both are blank.
func defaultCommandError(cmd, stdout, stderr string, code int, err error) error {
	cause := err
	if cause == nil {
		cause = fmt.Errorf("exit code %d", code)
	}
	cmd = strings.TrimSpace(cmd)

	detail := strings.TrimSpace(stderr)
	if detail == "" {
		detail = strings.TrimSpace(stdout)
	}
	if detail == "" {
		return fmt.Errorf("%s: %w", cmd, cause)
	}
	return fmt.Errorf("%s: %w: %s", cmd, cause, detail)
}
|