From 903371719021d74cbec9cd24efb66e386c8b7595 Mon Sep 17 00:00:00 2001 From: saba-futai <120904569+saba-futai@users.noreply.github.com> Date: Sun, 1 Mar 2026 10:22:53 +0800 Subject: [PATCH] feat: sudoku support ws transport (#2589) --- adapter/outbound/sudoku.go | 226 +++++---- docs/config.yaml | 38 +- listener/config/sudoku.go | 1 + listener/inbound/sudoku.go | 42 +- listener/inbound/sudoku_test.go | 5 +- listener/sudoku/server.go | 117 +++-- transport/sudoku/address.go | 11 +- transport/sudoku/config.go | 43 +- transport/sudoku/crypto/ed25519.go | 4 +- transport/sudoku/crypto/record_conn.go | 374 +++++++++++++++ transport/sudoku/features_test.go | 11 +- transport/sudoku/handshake.go | 440 +++++++++++------- transport/sudoku/handshake_kip.go | 73 +++ transport/sudoku/handshake_test.go | 26 +- transport/sudoku/httpmask_tunnel.go | 87 +--- transport/sudoku/httpmask_tunnel_test.go | 80 +++- transport/sudoku/init.go | 97 ++++ transport/sudoku/init_test.go | 44 ++ transport/sudoku/kip.go | 206 ++++++++ transport/sudoku/multiplex.go | 21 +- transport/sudoku/multiplex/mux.go | 39 -- transport/sudoku/multiplex_test.go | 47 +- transport/sudoku/obfs/httpmask/auth.go | 3 +- transport/sudoku/obfs/httpmask/tunnel.go | 317 +++++++++++-- transport/sudoku/obfs/httpmask/tunnel_ws.go | 176 +++++++ .../sudoku/obfs/httpmask/tunnel_ws_server.go | 77 +++ .../sudoku/obfs/httpmask/ws_stream_conn.go | 78 ++++ transport/sudoku/obfs/sudoku/table.go | 7 +- transport/sudoku/replay.go | 74 +++ transport/sudoku/session_keys.go | 58 +++ transport/sudoku/table_probe.go | 47 +- transport/sudoku/tables.go | 26 +- transport/sudoku/uot.go | 17 +- 33 files changed, 2331 insertions(+), 581 deletions(-) create mode 100644 transport/sudoku/crypto/record_conn.go create mode 100644 transport/sudoku/handshake_kip.go create mode 100644 transport/sudoku/init.go create mode 100644 transport/sudoku/init_test.go create mode 100644 transport/sudoku/kip.go delete mode 100644 transport/sudoku/multiplex/mux.go create mode 100644 transport/sudoku/obfs/httpmask/tunnel_ws.go create mode 100644 transport/sudoku/obfs/httpmask/tunnel_ws_server.go create mode 100644 transport/sudoku/obfs/httpmask/ws_stream_conn.go create mode 100644 transport/sudoku/replay.go create mode 100644 transport/sudoku/session_keys.go diff --git a/adapter/outbound/sudoku.go b/adapter/outbound/sudoku.go index 1f6e9781..d8550bf4 100644 --- a/adapter/outbound/sudoku.go +++ b/adapter/outbound/sudoku.go @@ -11,6 +11,7 @@ import ( N "github.com/metacubex/mihomo/common/net" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/transport/sudoku" + "github.com/metacubex/mihomo/transport/sudoku/obfs/httpmask" ) type Sudoku struct { @@ -18,32 +19,43 @@ type Sudoku struct { option *SudokuOption baseConf sudoku.ProtocolConfig - httpMaskMu sync.Mutex - httpMaskClient *sudoku.HTTPMaskTunnelClient - muxMu sync.Mutex muxClient *sudoku.MultiplexClient + + httpMaskMu sync.Mutex + httpMaskClient *httpmask.TunnelClient + httpMaskKey string } type SudokuOption struct { BasicOption - Name string `proxy:"name"` - Server string `proxy:"server"` - Port int `proxy:"port"` - Key string `proxy:"key"` - AEADMethod string `proxy:"aead-method,omitempty"` - PaddingMin *int `proxy:"padding-min,omitempty"` - PaddingMax *int `proxy:"padding-max,omitempty"` - TableType string `proxy:"table-type,omitempty"` // "prefer_ascii" or "prefer_entropy" - EnablePureDownlink *bool `proxy:"enable-pure-downlink,omitempty"` - HTTPMask bool `proxy:"http-mask,omitempty"` - HTTPMaskMode string `proxy:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto" - HTTPMaskTLS bool `proxy:"http-mask-tls,omitempty"` // only for http-mask-mode stream/poll/auto - HTTPMaskHost string `proxy:"http-mask-host,omitempty"` // optional Host/SNI override (domain or domain:port) - PathRoot string `proxy:"path-root,omitempty"` // optional first-level path prefix for HTTP tunnel endpoints - HTTPMaskMultiplex string `proxy:"http-mask-multiplex,omitempty"` // "off" (default), "auto" (reuse h1/h2), "on" (single tunnel, multi-target) - CustomTable string `proxy:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv - CustomTables []string `proxy:"custom-tables,omitempty"` // optional table rotation patterns, overrides custom-table when non-empty + Name string `proxy:"name"` + Server string `proxy:"server"` + Port int `proxy:"port"` + Key string `proxy:"key"` + AEADMethod string `proxy:"aead-method,omitempty"` + PaddingMin *int `proxy:"padding-min,omitempty"` + PaddingMax *int `proxy:"padding-max,omitempty"` + TableType string `proxy:"table-type,omitempty"` // "prefer_ascii" or "prefer_entropy" + EnablePureDownlink *bool `proxy:"enable-pure-downlink,omitempty"` + HTTPMask *bool `proxy:"http-mask,omitempty"` + HTTPMaskMode string `proxy:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto", "ws" + HTTPMaskTLS bool `proxy:"http-mask-tls,omitempty"` // only for http-mask-mode stream/poll/auto + HTTPMaskHost string `proxy:"http-mask-host,omitempty"` // optional Host/SNI override (domain or domain:port) + PathRoot string `proxy:"path-root,omitempty"` // optional first-level path prefix for HTTP tunnel endpoints + HTTPMaskMultiplex string `proxy:"http-mask-multiplex,omitempty"` // "off" (default), "auto" (reuse h1/h2), "on" (single tunnel, multi-target) + HTTPMaskOptions *SudokuHTTPMaskOptions `proxy:"httpmask,omitempty"` + CustomTable string `proxy:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv + CustomTables []string `proxy:"custom-tables,omitempty"` // optional table rotation patterns, overrides custom-table when non-empty +} + +type SudokuHTTPMaskOptions struct { + Disable bool `proxy:"disable,omitempty"` + Mode string `proxy:"mode,omitempty"` + TLS bool `proxy:"tls,omitempty"` + Host string `proxy:"host,omitempty"` + PathRoot string `proxy:"path_root,omitempty"` + Multiplex string `proxy:"multiplex,omitempty"` } // DialContext implements C.ProxyAdapter @@ -73,7 +85,7 @@ func (s *Sudoku) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Con return nil, fmt.Errorf("encode target address failed: %w", err) } - if _, err = c.Write(addrBuf); err != nil { + if err = sudoku.WriteKIPMessage(c, sudoku.KIPTypeOpenTCP, addrBuf); err != nil { return nil, fmt.Errorf("send target address failed: %w", err) } @@ -96,9 +108,9 @@ func (s *Sudoku) ListenPacketContext(ctx context.Context, metadata *C.Metadata) return nil, err } - if err = sudoku.WritePreface(c); err != nil { + if err = sudoku.WriteKIPMessage(c, sudoku.KIPTypeStartUoT, nil); err != nil { _ = c.Close() - return nil, fmt.Errorf("send uot preface failed: %w", err) + return nil, fmt.Errorf("start uot failed: %w", err) } return newPacketConn(N.NewThreadSafePacketConn(sudoku.NewUoTPacketConn(c)), s), nil @@ -141,32 +153,45 @@ func NewSudoku(option SudokuOption) (*Sudoku, error) { return nil, fmt.Errorf("key is required") } - tableType := strings.ToLower(option.TableType) - if tableType == "" { - tableType = "prefer_ascii" + defaultConf := sudoku.DefaultConfig() + tableType, err := sudoku.NormalizeTableType(option.TableType) + if err != nil { + return nil, err } - if tableType != "prefer_ascii" && tableType != "prefer_entropy" { - return nil, fmt.Errorf("table-type must be prefer_ascii or prefer_entropy") + paddingMin, paddingMax := sudoku.ResolvePadding(option.PaddingMin, option.PaddingMax, defaultConf.PaddingMin, defaultConf.PaddingMax) + enablePureDownlink := sudoku.DerefBool(option.EnablePureDownlink, defaultConf.EnablePureDownlink) + + disableHTTPMask := defaultConf.DisableHTTPMask + if option.HTTPMask != nil { + disableHTTPMask = !*option.HTTPMask + } + httpMaskMode := defaultConf.HTTPMaskMode + if option.HTTPMaskMode != "" { + httpMaskMode = option.HTTPMaskMode + } + httpMaskTLS := option.HTTPMaskTLS + httpMaskHost := option.HTTPMaskHost + pathRoot := strings.TrimSpace(option.PathRoot) + httpMaskMultiplex := defaultConf.HTTPMaskMultiplex + if option.HTTPMaskMultiplex != "" { + httpMaskMultiplex = option.HTTPMaskMultiplex } - defaultConf := sudoku.DefaultConfig() - paddingMin := defaultConf.PaddingMin - paddingMax := defaultConf.PaddingMax - if option.PaddingMin != nil { - paddingMin = *option.PaddingMin - } - if option.PaddingMax != nil { - paddingMax = *option.PaddingMax - } - if option.PaddingMin == nil && option.PaddingMax != nil && paddingMax < paddingMin { - paddingMin = paddingMax - } - if option.PaddingMax == nil && option.PaddingMin != nil && paddingMax < paddingMin { - paddingMax = paddingMin - } - enablePureDownlink := defaultConf.EnablePureDownlink - if option.EnablePureDownlink != nil { - enablePureDownlink = *option.EnablePureDownlink + if hm := option.HTTPMaskOptions; hm != nil { + disableHTTPMask = hm.Disable + if hm.Mode != "" { + httpMaskMode = hm.Mode + } + httpMaskTLS = hm.TLS + httpMaskHost = hm.Host + if pr := strings.TrimSpace(hm.PathRoot); pr != "" { + pathRoot = pr + } + if mux := strings.TrimSpace(hm.Multiplex); mux != "" { + httpMaskMultiplex = mux + } else { + httpMaskMultiplex = defaultConf.HTTPMaskMultiplex + } } baseConf := sudoku.ProtocolConfig{ @@ -177,20 +202,14 @@ func NewSudoku(option SudokuOption) (*Sudoku, error) { PaddingMax: paddingMax, EnablePureDownlink: enablePureDownlink, HandshakeTimeoutSeconds: defaultConf.HandshakeTimeoutSeconds, - DisableHTTPMask: !option.HTTPMask, - HTTPMaskMode: defaultConf.HTTPMaskMode, - HTTPMaskTLSEnabled: option.HTTPMaskTLS, - HTTPMaskHost: option.HTTPMaskHost, - HTTPMaskPathRoot: strings.TrimSpace(option.PathRoot), - HTTPMaskMultiplex: defaultConf.HTTPMaskMultiplex, + DisableHTTPMask: disableHTTPMask, + HTTPMaskMode: httpMaskMode, + HTTPMaskTLSEnabled: httpMaskTLS, + HTTPMaskHost: httpMaskHost, + HTTPMaskPathRoot: pathRoot, + HTTPMaskMultiplex: httpMaskMultiplex, } - if option.HTTPMaskMode != "" { - baseConf.HTTPMaskMode = option.HTTPMaskMode - } - if option.HTTPMaskMultiplex != "" { - baseConf.HTTPMaskMultiplex = option.HTTPMaskMultiplex - } - tables, err := sudoku.NewTablesWithCustomPatterns(sudoku.ClientAEADSeed(option.Key), tableType, option.CustomTable, option.CustomTables) + tables, err := sudoku.NewClientTablesWithCustomPatterns(sudoku.ClientAEADSeed(option.Key), tableType, option.CustomTable, option.CustomTables) if err != nil { return nil, fmt.Errorf("build table(s) failed: %w", err) } @@ -244,7 +263,7 @@ func normalizeHTTPMaskMultiplex(mode string) string { func httpTunnelModeEnabled(mode string) bool { switch strings.ToLower(strings.TrimSpace(mode)) { - case "stream", "poll", "auto": + case "stream", "poll", "auto", "ws": return true default: return false @@ -271,14 +290,24 @@ func (s *Sudoku) dialAndHandshake(ctx context.Context, cfg *sudoku.ProtocolConfi ) if !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) { muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex) - switch muxMode { - case "auto", "on": - client, errX := s.getOrCreateHTTPMaskClient(cfg) - if errX != nil { - return nil, errX + if muxMode == "auto" && strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) != "ws" { + if client, cerr := s.getOrCreateHTTPMaskClient(cfg); cerr == nil && client != nil { + c, err = client.DialTunnel(ctx, httpmask.TunnelDialOptions{ + Mode: cfg.HTTPMaskMode, + TLSEnabled: cfg.HTTPMaskTLSEnabled, + HostOverride: cfg.HTTPMaskHost, + PathRoot: cfg.HTTPMaskPathRoot, + AuthKey: sudoku.ClientAEADSeed(cfg.Key), + Upgrade: upgrade, + Multiplex: cfg.HTTPMaskMultiplex, + DialContext: s.dialer.DialContext, + }) + if err != nil { + s.resetHTTPMaskClient() + } } - c, err = client.Dial(ctx, upgrade) - default: + } + if c == nil && err == nil { c, err = sudoku.DialHTTPMaskTunnel(ctx, cfg.ServerAddress, cfg, s.dialer.DialContext, upgrade) } if err == nil && c != nil { @@ -372,34 +401,51 @@ func (s *Sudoku) resetMuxClient() { } } -func (s *Sudoku) getOrCreateHTTPMaskClient(cfg *sudoku.ProtocolConfig) (*sudoku.HTTPMaskTunnelClient, error) { - if s == nil { - return nil, fmt.Errorf("nil adapter") - } - if cfg == nil { - return nil, fmt.Errorf("config is required") - } - - s.httpMaskMu.Lock() - defer s.httpMaskMu.Unlock() - - if s.httpMaskClient != nil { - return s.httpMaskClient, nil - } - - c, err := sudoku.NewHTTPMaskTunnelClient(cfg.ServerAddress, cfg, s.dialer.DialContext) - if err != nil { - return nil, err - } - s.httpMaskClient = c - return c, nil -} - func (s *Sudoku) resetHTTPMaskClient() { s.httpMaskMu.Lock() defer s.httpMaskMu.Unlock() if s.httpMaskClient != nil { s.httpMaskClient.CloseIdleConnections() s.httpMaskClient = nil + s.httpMaskKey = "" } } + +func (s *Sudoku) getOrCreateHTTPMaskClient(cfg *sudoku.ProtocolConfig) (*httpmask.TunnelClient, error) { + if s == nil || cfg == nil { + return nil, fmt.Errorf("nil adapter or config") + } + + key := cfg.ServerAddress + "|" + strconv.FormatBool(cfg.HTTPMaskTLSEnabled) + "|" + strings.TrimSpace(cfg.HTTPMaskHost) + + s.httpMaskMu.Lock() + if s.httpMaskClient != nil && s.httpMaskKey == key { + client := s.httpMaskClient + s.httpMaskMu.Unlock() + return client, nil + } + s.httpMaskMu.Unlock() + + client, err := httpmask.NewTunnelClient(cfg.ServerAddress, httpmask.TunnelClientOptions{ + TLSEnabled: cfg.HTTPMaskTLSEnabled, + HostOverride: cfg.HTTPMaskHost, + DialContext: s.dialer.DialContext, + MaxIdleConns: 32, + }) + if err != nil { + return nil, err + } + + s.httpMaskMu.Lock() + defer s.httpMaskMu.Unlock() + if s.httpMaskClient != nil && s.httpMaskKey == key { + client.CloseIdleConnections() + return s.httpMaskClient, nil + } + if s.httpMaskClient != nil { + s.httpMaskClient.CloseIdleConnections() + } + s.httpMaskClient = client + s.httpMaskKey = key + return client, nil +} diff --git a/docs/config.yaml b/docs/config.yaml index 864a9b80..6368f5ed 100644 --- a/docs/config.yaml +++ b/docs/config.yaml @@ -1092,12 +1092,22 @@ proxies: # socks5 table-type: prefer_ascii # 可选值:prefer_ascii、prefer_entropy 前者全ascii映射,后者保证熵值(汉明1)低于3 # custom-table: xpxvvpvv # 可选,自定义字节布局,必须包含2个x、2个p、4个v,可随意组合。启用此处则需配置`table-type`为`prefer_entropy` # custom-tables: ["xpxvvpvv", "vxpvxvvp"] # 可选,自定义字节布局列表(x/v/p),用于 xvp 模式轮换;非空时覆盖 custom-table - http-mask: true # 是否启用http掩码 - # http-mask-mode: legacy # 可选:legacy(默认)、stream(split-stream)、poll、auto(先 stream 再 poll);stream/poll/auto 支持走 CDN/反代 - # http-mask-tls: true # 可选:仅在 http-mask-mode 为 stream/poll/auto 时生效;true 强制 https;false 强制 http(不会根据端口自动推断) - # http-mask-host: "" # 可选:覆盖 Host/SNI(支持 example.com 或 example.com:443);仅在 http-mask-mode 为 stream/poll/auto 时生效 - # path-root: "" # 可选:HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" 或 "/aabbcc/" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload - # http-mask-multiplex: off # 可选:off(默认)、auto(复用底层 HTTP 连接,减少建链 RTT)、on(Sudoku mux 单隧道多目标;仅在 http-mask-mode=stream/poll/auto 生效) + # 推荐:使用 httpmask 对象统一管理 HTTPMask 相关字段: + httpmask: + disable: false # true 禁用所有 HTTP 伪装/隧道 + mode: legacy # 可选:legacy(默认)、stream(split-stream)、poll、auto(先 stream 再 poll)、ws(WebSocket 隧道) + # tls: true # 可选:仅在 mode 为 stream/poll/auto/ws 时生效;true 强制 https/wss;false 强制 http/ws(不会根据端口自动推断) + # host: "" # 可选:覆盖 Host/SNI(支持 example.com 或 example.com:443);仅在 mode 为 stream/poll/auto/ws 时生效 + # path_root: "" # 可选:HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" 或 "/aabbcc/" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload、/aabbcc/ws + # multiplex: off # 可选:off(默认)、auto(复用底层 HTTP 连接,减少建链 RTT)、on(Sudoku mux 单隧道多目标;仅在 mode=stream/poll/auto 生效;ws 强制 off) + # + # 向后兼容旧写法: + # http-mask: true # 是否启用 http 掩码 + # http-mask-mode: legacy # 可选:legacy(默认)、stream(split-stream)、poll、auto(先 stream 再 poll)、ws(WebSocket 隧道) + # http-mask-tls: true # 可选:仅在 http-mask-mode 为 stream/poll/auto/ws 时生效;true 强制 https/wss;false 强制 http/ws + # http-mask-host: "" # 可选:覆盖 Host/SNI(支持 example.com 或 example.com:443);仅在 http-mask-mode 为 stream/poll/auto/ws 时生效 + # path-root: "" # 可选:HTTP 隧道端点一级路径前缀(双方需一致) + # http-mask-multiplex: off # 可选:off(默认)、auto(复用底层 HTTP 连接)、on(Sudoku mux 单隧道多目标;ws 强制 off) enable-pure-downlink: false # 可选:false=带宽优化下行(更快,要求 aead-method != none);true=纯 Sudoku 下行 # anytls @@ -1663,9 +1673,19 @@ listeners: # custom-tables: ["xpxvvpvv", "vxpvxvvp"] # 可选,自定义字节布局列表(x/v/p),用于 xvp 模式轮换;非空时覆盖 custom-table handshake-timeout: 5 # 可选(秒) enable-pure-downlink: false # 可选:false=带宽优化下行(更快,要求 aead-method != none);true=纯 Sudoku 下行 - disable-http-mask: false # 可选:禁用 http 掩码/隧道(默认为 false) - # http-mask-mode: legacy # 可选:legacy(默认)、stream(split-stream)、poll、auto(先 stream 再 poll);stream/poll/auto 支持走 CDN/反代 - # path-root: "" # 可选:HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" 或 "/aabbcc/" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload + # 推荐:使用 httpmask 对象统一管理 HTTPMask 相关字段: + httpmask: + disable: false # true 禁用所有 HTTP 伪装/隧道 + mode: legacy # 可选:legacy(默认)、stream(split-stream)、poll、auto(先 stream 再 poll)、ws(WebSocket 隧道) + # path_root: "" # 可选:HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" 或 "/aabbcc/" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload、/aabbcc/ws + # + # 可选:当启用 HTTPMask 且识别到“像 HTTP 但不符合 tunnel/auth”的请求时,将原始字节透传给 fallback(常用于与其他服务共端口): + # fallback: "127.0.0.1:80" + # + # 向后兼容旧写法: + # disable-http-mask: false # 可选:禁用 http 掩码/隧道(默认为 false) + # http-mask-mode: legacy # 可选:legacy(默认)、stream(split-stream)、poll、auto(先 stream 再 poll)、ws(WebSocket 隧道) + # path-root: "" # 可选:HTTP 隧道端点一级路径前缀(双方需一致) diff --git a/listener/config/sudoku.go b/listener/config/sudoku.go index e22f2418..17310bcf 100644 --- a/listener/config/sudoku.go +++ b/listener/config/sudoku.go @@ -23,6 +23,7 @@ type SudokuServer struct { DisableHTTPMask bool `json:"disable-http-mask,omitempty"` HTTPMaskMode string `json:"http-mask-mode,omitempty"` PathRoot string `json:"path-root,omitempty"` + Fallback string `json:"fallback,omitempty"` // mihomo private extension (not the part of standard Sudoku protocol) MuxOption sing.MuxOption `json:"mux-option,omitempty"` diff --git a/listener/inbound/sudoku.go b/listener/inbound/sudoku.go index 04b47de0..02ef5a34 100644 --- a/listener/inbound/sudoku.go +++ b/listener/inbound/sudoku.go @@ -13,23 +13,31 @@ import ( type SudokuOption struct { BaseOption - Key string `inbound:"key"` - AEADMethod string `inbound:"aead-method,omitempty"` - PaddingMin *int `inbound:"padding-min,omitempty"` - PaddingMax *int `inbound:"padding-max,omitempty"` - TableType string `inbound:"table-type,omitempty"` // "prefer_ascii" or "prefer_entropy" - HandshakeTimeoutSecond *int `inbound:"handshake-timeout,omitempty"` - EnablePureDownlink *bool `inbound:"enable-pure-downlink,omitempty"` - CustomTable string `inbound:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv - CustomTables []string `inbound:"custom-tables,omitempty"` - DisableHTTPMask bool `inbound:"disable-http-mask,omitempty"` - HTTPMaskMode string `inbound:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto" - PathRoot string `inbound:"path-root,omitempty"` // optional first-level path prefix for HTTP tunnel endpoints + Key string `inbound:"key"` + AEADMethod string `inbound:"aead-method,omitempty"` + PaddingMin *int `inbound:"padding-min,omitempty"` + PaddingMax *int `inbound:"padding-max,omitempty"` + TableType string `inbound:"table-type,omitempty"` // "prefer_ascii" or "prefer_entropy" + HandshakeTimeoutSecond *int `inbound:"handshake-timeout,omitempty"` + EnablePureDownlink *bool `inbound:"enable-pure-downlink,omitempty"` + CustomTable string `inbound:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv + CustomTables []string `inbound:"custom-tables,omitempty"` + DisableHTTPMask bool `inbound:"disable-http-mask,omitempty"` + HTTPMaskMode string `inbound:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto" + PathRoot string `inbound:"path-root,omitempty"` // optional first-level path prefix for HTTP tunnel endpoints + Fallback string `inbound:"fallback,omitempty"` + HTTPMaskOptions *SudokuHTTPMaskOptions `inbound:"httpmask,omitempty"` // mihomo private extension (not the part of standard Sudoku protocol) MuxOption MuxOption `inbound:"mux-option,omitempty"` } +type SudokuHTTPMaskOptions struct { + Disable bool `inbound:"disable,omitempty"` + Mode string `inbound:"mode,omitempty"` + PathRoot string `inbound:"path_root,omitempty"` +} + func (o SudokuOption) Equal(config C.InboundConfig) bool { return optionToString(o) == optionToString(config) } @@ -65,6 +73,16 @@ func NewSudoku(options *SudokuOption) (*Sudoku, error) { DisableHTTPMask: options.DisableHTTPMask, HTTPMaskMode: options.HTTPMaskMode, PathRoot: strings.TrimSpace(options.PathRoot), + Fallback: strings.TrimSpace(options.Fallback), + } + if hm := options.HTTPMaskOptions; hm != nil { + serverConf.DisableHTTPMask = hm.Disable + if hm.Mode != "" { + serverConf.HTTPMaskMode = hm.Mode + } + if pr := strings.TrimSpace(hm.PathRoot); pr != "" { + serverConf.PathRoot = pr + } } serverConf.MuxOption = options.MuxOption.Build() diff --git a/listener/inbound/sudoku_test.go b/listener/inbound/sudoku_test.go index 1d6c4e59..2bd277b8 100644 --- a/listener/inbound/sudoku_test.go +++ b/listener/inbound/sudoku_test.go @@ -168,16 +168,17 @@ func TestInboundSudoku_CustomTable(t *testing.T) { func TestInboundSudoku_HTTPMaskMode(t *testing.T) { key := "test_key_http_mask_mode" - for _, mode := range []string{"legacy", "stream", "poll", "auto"} { + for _, mode := range []string{"ws", "stream", "poll", "auto"} { mode := mode t.Run(mode, func(t *testing.T) { inboundOptions := inbound.SudokuOption{ Key: key, HTTPMaskMode: mode, } + httpMask := true outboundOptions := outbound.SudokuOption{ Key: key, - HTTPMask: true, + HTTPMask: &httpMask, HTTPMaskMode: mode, } testInboundSudoku(t, inboundOptions, outboundOptions) diff --git a/listener/sudoku/server.go b/listener/sudoku/server.go index 5d64110b..1811b133 100644 --- a/listener/sudoku/server.go +++ b/listener/sudoku/server.go @@ -1,16 +1,19 @@ package sudoku import ( + "bytes" "errors" "io" "net" "strings" + "time" "github.com/metacubex/mihomo/adapter/inbound" N "github.com/metacubex/mihomo/common/net" "github.com/metacubex/mihomo/common/utils" C "github.com/metacubex/mihomo/constant" LC "github.com/metacubex/mihomo/listener/config" + "github.com/metacubex/mihomo/listener/inner" "github.com/metacubex/mihomo/listener/sing" "github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/transport/socks5" @@ -23,6 +26,7 @@ type Listener struct { closed bool protoConf sudoku.ProtocolConfig tunnelSrv *sudoku.HTTPMaskTunnelServer + fallback string handler *sing.ListenerHandler } @@ -49,12 +53,19 @@ func (l *Listener) Close() error { } func (l *Listener) handleConn(conn net.Conn, tunnel C.Tunnel, additions ...inbound.Addition) { + log.Debugln("[Sudoku] accepted %s", conn.RemoteAddr()) handshakeConn := conn handshakeCfg := &l.protoConf + closeConns := func() { + _ = handshakeConn.Close() + if handshakeConn != conn { + _ = conn.Close() + } + } if l.tunnelSrv != nil { c, cfg, done, err := l.tunnelSrv.WrapConn(conn) if err != nil { - _ = conn.Close() + closeConns() return } if done { @@ -68,9 +79,43 @@ func (l *Listener) handleConn(conn net.Conn, tunnel C.Tunnel, additions ...inbou } } - session, err := sudoku.ServerHandshake(handshakeConn, handshakeCfg) + if l.fallback != "" { + if r, ok := handshakeConn.(interface{ IsHTTPMaskRejected() bool }); ok && r.IsHTTPMaskRejected() { + fb, err := inner.HandleTcp(tunnel, l.fallback, "") + if err != nil { + closeConns() + return + } + N.Relay(handshakeConn, fb) + return + } + } + + cConn, meta, err := sudoku.ServerHandshake(handshakeConn, handshakeCfg) if err != nil { - _ = handshakeConn.Close() + fallbackAddr := l.fallback + var susp *sudoku.SuspiciousError + isSuspicious := errors.As(err, &susp) && susp != nil && susp.Conn != nil + if isSuspicious { + log.Warnln("[Sudoku] suspicious handshake from %s: %v", conn.RemoteAddr(), err) + if fallbackAddr != "" { + fb, err := inner.HandleTcp(tunnel, fallbackAddr, "") + if err == nil { + relayToFallback(susp.Conn, conn, fb) + return + } + } + } else { + log.Debugln("[Sudoku] handshake failed from %s: %v", conn.RemoteAddr(), err) + } + closeConns() + return + } + + session, err := sudoku.ReadServerSession(cConn, meta) + if err != nil { + log.Warnln("[Sudoku] read session failed from %s: %v", conn.RemoteAddr(), err) + _ = cConn.Close() if handshakeConn != conn { _ = conn.Close() } @@ -103,6 +148,7 @@ func (l *Listener) handleConn(conn net.Conn, tunnel C.Tunnel, additions ...inbou default: targetAddr := socks5.ParseAddr(session.Target) if targetAddr == nil { + log.Warnln("[Sudoku] invalid target from %s: %q", conn.RemoteAddr(), session.Target) _ = session.Conn.Close() return } @@ -164,6 +210,24 @@ func (p *uotPacket) LocalAddr() net.Addr { return p.rAddr } +func relayToFallback(wrapper net.Conn, rawConn net.Conn, fallback net.Conn) { + if wrapper != nil { + if recorder, ok := wrapper.(interface{ GetBufferedAndRecorded() []byte }); ok { + badData := recorder.GetBufferedAndRecorded() + if len(badData) > 0 { + _ = fallback.SetWriteDeadline(time.Now().Add(3 * time.Second)) + if _, err := io.Copy(fallback, bytes.NewReader(badData)); err != nil { + _ = fallback.Close() + _ = rawConn.Close() + return + } + _ = fallback.SetWriteDeadline(time.Time{}) + } + } + } + N.Relay(rawConn, fallback) +} + func New(config LC.SudokuServer, tunnel C.Tunnel, additions ...inbound.Addition) (*Listener, error) { if len(additions) == 0 { additions = []inbound.Addition{ @@ -188,42 +252,24 @@ func New(config LC.SudokuServer, tunnel C.Tunnel, additions ...inbound.Addition) return nil, err } - tableType := strings.ToLower(config.TableType) - if tableType == "" { - tableType = "prefer_ascii" - } - - defaultConf := sudoku.DefaultConfig() - paddingMin := defaultConf.PaddingMin - paddingMax := defaultConf.PaddingMax - if config.PaddingMin != nil { - paddingMin = *config.PaddingMin - } - if config.PaddingMax != nil { - paddingMax = *config.PaddingMax - } - if config.PaddingMin == nil && config.PaddingMax != nil && paddingMax < paddingMin { - paddingMin = paddingMax - } - if config.PaddingMax == nil && config.PaddingMin != nil && paddingMax < paddingMin { - paddingMax = paddingMin - } - enablePureDownlink := defaultConf.EnablePureDownlink - if config.EnablePureDownlink != nil { - enablePureDownlink = *config.EnablePureDownlink - } - - tables, err := sudoku.NewTablesWithCustomPatterns(config.Key, tableType, config.CustomTable, config.CustomTables) + tableType, err := sudoku.NormalizeTableType(config.TableType) if err != nil { _ = l.Close() return nil, err } - handshakeTimeout := defaultConf.HandshakeTimeoutSeconds - if config.HandshakeTimeoutSecond != nil { - handshakeTimeout = *config.HandshakeTimeoutSecond + defaultConf := sudoku.DefaultConfig() + paddingMin, paddingMax := sudoku.ResolvePadding(config.PaddingMin, config.PaddingMax, defaultConf.PaddingMin, defaultConf.PaddingMax) + enablePureDownlink := sudoku.DerefBool(config.EnablePureDownlink, defaultConf.EnablePureDownlink) + + tables, err := sudoku.NewServerTablesWithCustomPatterns(sudoku.ServerAEADSeed(config.Key), tableType, config.CustomTable, config.CustomTables) + if err != nil { + _ = l.Close() + return nil, err } + handshakeTimeout := sudoku.DerefInt(config.HandshakeTimeoutSecond, defaultConf.HandshakeTimeoutSeconds) + protoConf := sudoku.ProtocolConfig{ Key: config.Key, AEADMethod: defaultConf.AEADMethod, @@ -249,8 +295,13 @@ func New(config LC.SudokuServer, tunnel C.Tunnel, additions ...inbound.Addition) addr: config.Listen, protoConf: protoConf, handler: h, + fallback: strings.TrimSpace(config.Fallback), + } + if sl.fallback != "" { + sl.tunnelSrv = sudoku.NewHTTPMaskTunnelServerWithFallback(&sl.protoConf) + } else { + sl.tunnelSrv = sudoku.NewHTTPMaskTunnelServer(&sl.protoConf) } - sl.tunnelSrv = sudoku.NewHTTPMaskTunnelServer(&sl.protoConf) go func() { for { diff --git a/transport/sudoku/address.go b/transport/sudoku/address.go index 48b32296..d68205ad 100644 --- a/transport/sudoku/address.go +++ b/transport/sudoku/address.go @@ -6,6 +6,7 @@ import ( "io" "net" "strconv" + "strings" ) func EncodeAddress(rawAddr string) ([]byte, error) { @@ -20,13 +21,21 @@ func EncodeAddress(rawAddr string) ([]byte, error) { } var buf []byte + if i := strings.IndexByte(host, '%'); i >= 0 { + // Zone identifiers are not representable in SOCKS5 IPv6 address encoding. + host = host[:i] + } if ip := net.ParseIP(host); ip != nil { if ip4 := ip.To4(); ip4 != nil { buf = append(buf, 0x01) // IPv4 buf = append(buf, ip4...) } else { buf = append(buf, 0x04) // IPv6 - buf = append(buf, ip...) + ip16 := ip.To16() + if ip16 == nil { + return nil, fmt.Errorf("invalid ipv6: %q", host) + } + buf = append(buf, ip16...) } } else { if len(host) > 255 { diff --git a/transport/sudoku/config.go b/transport/sudoku/config.go index 6a1a465a..f5c1229c 100644 --- a/transport/sudoku/config.go +++ b/transport/sudoku/config.go @@ -50,6 +50,7 @@ type ProtocolConfig struct { // - "stream": real HTTP tunnel (split-stream), CDN-compatible // - "poll": plain HTTP tunnel (authorize/push/pull), strong restricted-network pass-through // - "auto": try stream then fall back to poll + // - "ws": WebSocket tunnel (GET upgrade), CDN-friendly HTTPMaskMode string // HTTPMaskTLSEnabled enables HTTPS for HTTP tunnel modes (client-side). @@ -109,9 +110,9 @@ func (c *ProtocolConfig) Validate() error { } switch strings.ToLower(strings.TrimSpace(c.HTTPMaskMode)) { - case "", "legacy", "stream", "poll", "auto": + case "", "legacy", "stream", "poll", "auto", "ws": default: - return fmt.Errorf("invalid http-mask-mode: %s, must be one of: legacy, stream, poll, auto", c.HTTPMaskMode) + return fmt.Errorf("invalid http-mask-mode: %s, must be one of: legacy, stream, poll, auto, ws", c.HTTPMaskMode) } if v := strings.TrimSpace(c.HTTPMaskPathRoot); v != "" { @@ -166,6 +167,44 @@ func DefaultConfig() *ProtocolConfig { } } +func DerefInt(v *int, def int) int { + if v == nil { + return def + } + return *v +} + +func DerefBool(v *bool, def bool) bool { + if v == nil { + return def + } + return *v +} + +// ResolvePadding applies defaults and keeps min/max consistent when only one side is provided. +func ResolvePadding(min, max *int, defMin, defMax int) (int, int) { + paddingMin := DerefInt(min, defMin) + paddingMax := DerefInt(max, defMax) + switch { + case min == nil && max != nil && paddingMax < paddingMin: + paddingMin = paddingMax + case max == nil && min != nil && paddingMax < paddingMin: + paddingMax = paddingMin + } + return paddingMin, paddingMax +} + +func NormalizeTableType(tableType string) (string, error) { + switch t := strings.ToLower(strings.TrimSpace(tableType)); t { + case "", "prefer_ascii": + return "prefer_ascii", nil + case "prefer_entropy": + return "prefer_entropy", nil + default: + return "", fmt.Errorf("table-type must be prefer_ascii or prefer_entropy") + } +} + func (c *ProtocolConfig) tableCandidates() []*sudoku.Table { if c == nil { return nil diff --git a/transport/sudoku/crypto/ed25519.go b/transport/sudoku/crypto/ed25519.go index c0e4395b..1ff627d4 100644 --- a/transport/sudoku/crypto/ed25519.go +++ b/transport/sudoku/crypto/ed25519.go @@ -80,8 +80,8 @@ func RecoverPublicKey(keyHex string) (*edwards25519.Point, error) { return nil, fmt.Errorf("invalid scalar: %w", err) } return new(edwards25519.Point).ScalarBaseMult(x), nil - - } else if len(keyBytes) == 64 { + } + if len(keyBytes) == 64 { // Split Key r || k rBytes := keyBytes[:32] kBytes := keyBytes[32:] diff --git a/transport/sudoku/crypto/record_conn.go b/transport/sudoku/crypto/record_conn.go new file mode 100644 index 00000000..7a80c7f5 --- /dev/null +++ b/transport/sudoku/crypto/record_conn.go @@ -0,0 +1,374 @@ +package crypto + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/sha256" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "sync" + + "golang.org/x/crypto/chacha20poly1305" +) + +// KeyUpdateAfterBytes controls automatic key rotation based on plaintext bytes. +// It is a package var (not config) to enable targeted tests with smaller thresholds. +var KeyUpdateAfterBytes int64 = 32 << 20 // 32 MiB + +const ( + recordHeaderSize = 12 // epoch(uint32) + seq(uint64) - also used as nonce+AAD. + maxFrameBodySize = 65535 +) + +type recordKeys struct { + baseSend []byte + baseRecv []byte +} + +// RecordConn is a framed AEAD net.Conn with: +// - deterministic per-record nonce (epoch+seq) +// - per-direction key rotation (epoch), driven by plaintext byte counters +// - replay/out-of-order protection within the connection (strict seq check) +// +// Wire format per record: +// - uint16 bodyLen +// - header[12] = epoch(uint32 BE) || seq(uint64 BE) (plaintext) +// - ciphertext = AEAD(header as nonce, plaintext, header as AAD) +type RecordConn struct { + net.Conn + method string + + writeMu sync.Mutex + readMu sync.Mutex + + keys recordKeys + + sendAEAD cipher.AEAD + sendAEADEpoch uint32 + + recvAEAD cipher.AEAD + recvAEADEpoch uint32 + + // Send direction state. + sendEpoch uint32 + sendSeq uint64 + sendBytes int64 + + // Receive direction state. + recvEpoch uint32 + recvSeq uint64 + + readBuf bytes.Buffer + + // writeFrame is a reusable buffer for [len||header||ciphertext] on the wire. + // Guarded by writeMu. + writeFrame []byte +} + +func (c *RecordConn) CloseWrite() error { + if c == nil { + return nil + } + if cw, ok := c.Conn.(interface{ CloseWrite() error }); ok { + return cw.CloseWrite() + } + return nil +} + +func (c *RecordConn) CloseRead() error { + if c == nil { + return nil + } + if cr, ok := c.Conn.(interface{ CloseRead() error }); ok { + return cr.CloseRead() + } + return nil +} + +func NewRecordConn(conn net.Conn, method string, baseSend, baseRecv []byte) (*RecordConn, error) { + if conn == nil { + return nil, fmt.Errorf("nil conn") + } + method = normalizeAEADMethod(method) + if method != "none" { + if err := validateBaseKey(baseSend); err != nil { + return nil, fmt.Errorf("invalid send base key: %w", err) + } + if err := validateBaseKey(baseRecv); err != nil { + return nil, fmt.Errorf("invalid recv base key: %w", err) + } + } + rc := &RecordConn{Conn: conn, method: method} + rc.keys = recordKeys{baseSend: cloneBytes(baseSend), baseRecv: cloneBytes(baseRecv)} + return rc, nil +} + +func (c *RecordConn) Rekey(baseSend, baseRecv []byte) error { + if c == nil { + return fmt.Errorf("nil conn") + } + if c.method != "none" { + if err := validateBaseKey(baseSend); err != nil { + return fmt.Errorf("invalid send base key: %w", err) + } + if err := validateBaseKey(baseRecv); err != nil { + return fmt.Errorf("invalid recv base key: %w", err) + } + } + + c.writeMu.Lock() + c.readMu.Lock() + defer c.readMu.Unlock() + defer c.writeMu.Unlock() + + c.keys = recordKeys{baseSend: cloneBytes(baseSend), baseRecv: cloneBytes(baseRecv)} + c.sendEpoch = 0 + c.sendSeq = 0 + c.sendBytes = 0 + c.recvEpoch = 0 + c.recvSeq = 0 + c.readBuf.Reset() + + c.sendAEAD = nil + c.recvAEAD = nil + c.sendAEADEpoch = 0 + c.recvAEADEpoch = 0 + return nil +} + +func normalizeAEADMethod(method string) string { + switch method { + case "", "chacha20-poly1305": + return "chacha20-poly1305" + case "aes-128-gcm", "none": + return method + default: + return method + } +} + +func validateBaseKey(b []byte) error { + if len(b) < 32 { + return fmt.Errorf("need at least 32 bytes, got %d", len(b)) + } + return nil +} + +func cloneBytes(b []byte) []byte { + if len(b) == 0 { + return nil + } + return append([]byte(nil), b...) +} + +func (c *RecordConn) newAEADFor(base []byte, epoch uint32) (cipher.AEAD, error) { + if c.method == "none" { + return nil, nil + } + key := deriveEpochKey(base, epoch, c.method) + switch c.method { + case "aes-128-gcm": + block, err := aes.NewCipher(key[:16]) + if err != nil { + return nil, err + } + a, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + if a.NonceSize() != recordHeaderSize { + return nil, fmt.Errorf("unexpected gcm nonce size: %d", a.NonceSize()) + } + return a, nil + case "chacha20-poly1305": + a, err := chacha20poly1305.New(key[:32]) + if err != nil { + return nil, err + } + if a.NonceSize() != recordHeaderSize { + return nil, fmt.Errorf("unexpected chacha nonce size: %d", a.NonceSize()) + } + return a, nil + default: + return nil, fmt.Errorf("unsupported cipher: %s", c.method) + } +} + +func deriveEpochKey(base []byte, epoch uint32, method string) []byte { + var b [4]byte + binary.BigEndian.PutUint32(b[:], epoch) + mac := hmac.New(sha256.New, base) + _, _ = mac.Write([]byte("sudoku-record:")) + _, _ = mac.Write([]byte(method)) + _, _ = mac.Write(b[:]) + return mac.Sum(nil) +} + +func (c *RecordConn) maybeBumpSendEpochLocked(addedPlain int) { + if KeyUpdateAfterBytes <= 0 || c.method == "none" { + return + } + c.sendBytes += int64(addedPlain) + threshold := KeyUpdateAfterBytes * int64(c.sendEpoch+1) + if c.sendBytes < threshold { + return + } + c.sendEpoch++ + c.sendSeq = 0 +} + +func (c *RecordConn) Write(p []byte) (int, error) { + if c == nil || c.Conn == nil { + return 0, net.ErrClosed + } + if c.method == "none" { + return c.Conn.Write(p) + } + + c.writeMu.Lock() + defer c.writeMu.Unlock() + + total := 0 + for len(p) > 0 { + if c.sendAEAD == nil || c.sendAEADEpoch != c.sendEpoch { + a, err := c.newAEADFor(c.keys.baseSend, c.sendEpoch) + if err != nil { + return total, err + } + c.sendAEAD = a + c.sendAEADEpoch = c.sendEpoch + } + aead := c.sendAEAD + + maxPlain := maxFrameBodySize - recordHeaderSize - aead.Overhead() + if maxPlain <= 0 { + return total, errors.New("frame size too small") + } + n := len(p) + if n > maxPlain { + n = maxPlain + } + chunk := p[:n] + p = p[n:] + + var header [recordHeaderSize]byte + binary.BigEndian.PutUint32(header[:4], c.sendEpoch) + binary.BigEndian.PutUint64(header[4:], c.sendSeq) + c.sendSeq++ + + cipherLen := n + aead.Overhead() + bodyLen := recordHeaderSize + cipherLen + frameLen := 2 + bodyLen + if bodyLen > maxFrameBodySize { + return total, errors.New("frame too large") + } + if cap(c.writeFrame) < frameLen { + c.writeFrame = make([]byte, frameLen) + } + frame := c.writeFrame[:frameLen] + binary.BigEndian.PutUint16(frame[:2], uint16(bodyLen)) + copy(frame[2:2+recordHeaderSize], header[:]) + + dst := frame[2+recordHeaderSize : 2+recordHeaderSize : frameLen] + _ = aead.Seal(dst[:0], header[:], chunk, header[:]) + + if err := writeFull(c.Conn, frame); err != nil { + return total, err + } + + total += n + c.maybeBumpSendEpochLocked(n) + } + return total, nil +} + +func (c *RecordConn) Read(p []byte) (int, error) { + if c == nil || c.Conn == nil { + return 0, net.ErrClosed + } + if c.method == "none" { + return c.Conn.Read(p) + } + + c.readMu.Lock() + defer c.readMu.Unlock() + + if c.readBuf.Len() > 0 { + return c.readBuf.Read(p) + } + + var lenBuf [2]byte + if _, err := io.ReadFull(c.Conn, lenBuf[:]); err != nil { + return 0, err + } + bodyLen := int(binary.BigEndian.Uint16(lenBuf[:])) + if bodyLen < recordHeaderSize { + return 0, errors.New("frame too short") + } + if bodyLen > maxFrameBodySize { + return 0, errors.New("frame too large") + } + + body := make([]byte, bodyLen) + if _, err := io.ReadFull(c.Conn, body); err != nil { + return 0, err + } + header := body[:recordHeaderSize] + ciphertext := body[recordHeaderSize:] + + epoch := binary.BigEndian.Uint32(header[:4]) + seq := binary.BigEndian.Uint64(header[4:]) + + if epoch < c.recvEpoch { + return 0, fmt.Errorf("replayed epoch: got %d want >=%d", epoch, c.recvEpoch) + } + if epoch == c.recvEpoch && seq != c.recvSeq { + return 0, fmt.Errorf("out of order: epoch=%d got=%d want=%d", epoch, seq, c.recvSeq) + } + if epoch > c.recvEpoch { + const maxJump = 8 + if epoch-c.recvEpoch > maxJump { + return 0, fmt.Errorf("epoch jump too large: got=%d want<=%d", epoch-c.recvEpoch, maxJump) + } + c.recvEpoch = epoch + c.recvSeq = 0 + if seq != 0 { + return 0, fmt.Errorf("out of order: epoch advanced to %d but seq=%d", epoch, seq) + } + } + + if c.recvAEAD == nil || c.recvAEADEpoch != c.recvEpoch { + a, err := c.newAEADFor(c.keys.baseRecv, c.recvEpoch) + if err != nil { + return 0, err + } + c.recvAEAD = a + c.recvAEADEpoch = c.recvEpoch + } + aead := c.recvAEAD + + plaintext, err := aead.Open(nil, header, ciphertext, header) + if err != nil { + return 0, fmt.Errorf("decryption failed: epoch=%d seq=%d: %w", epoch, seq, err) + } + c.recvSeq++ + + c.readBuf.Write(plaintext) + return c.readBuf.Read(p) +} + +func writeFull(w io.Writer, b []byte) error { + for len(b) > 0 { + n, err := w.Write(b) + if err != nil { + return err + } + b = b[n:] + } + return nil +} diff --git a/transport/sudoku/features_test.go b/transport/sudoku/features_test.go index 68baab45..dd2b1873 100644 --- a/transport/sudoku/features_test.go +++ b/transport/sudoku/features_test.go @@ -43,12 +43,17 @@ func TestCustomTablesRotation_ProbedByServer(t *testing.T) { go func() { defer close(errCh) defer serverConn.Close() - session, err := ServerHandshake(serverConn, serverCfg) + c, meta, err := ServerHandshake(serverConn, serverCfg) if err != nil { errCh <- err return } - defer session.Conn.Close() + session, err := ReadServerSession(c, meta) + if err != nil { + errCh <- err + return + } + defer c.Close() if session.Type != SessionTypeTCP { errCh <- io.ErrUnexpectedEOF return @@ -69,7 +74,7 @@ func TestCustomTablesRotation_ProbedByServer(t *testing.T) { if err != nil { t.Fatalf("encode addr: %v", err) } - if _, err := cConn.Write(addrBuf); err != nil { + if err := WriteKIPMessage(cConn, KIPTypeOpenTCP, addrBuf); err != nil { t.Fatalf("write addr: %v", err) } diff --git a/transport/sudoku/handshake.go b/transport/sudoku/handshake.go index 1cf119c9..971d47fd 100644 --- a/transport/sudoku/handshake.go +++ b/transport/sudoku/handshake.go @@ -2,19 +2,20 @@ package sudoku import ( "bufio" - "crypto/sha256" - "encoding/binary" + "bytes" + "crypto/ecdh" + "crypto/rand" "encoding/hex" "fmt" "io" "net" + "strings" + "sync" "time" "github.com/metacubex/mihomo/transport/sudoku/crypto" "github.com/metacubex/mihomo/transport/sudoku/obfs/httpmask" "github.com/metacubex/mihomo/transport/sudoku/obfs/sudoku" - - "github.com/metacubex/mihomo/log" ) type SessionType int @@ -30,18 +31,96 @@ type ServerSession struct { Type SessionType Target string - // UserHash is a stable per-key identifier derived from the handshake payload. - // It is primarily useful for debugging / user attribution when table rotation is enabled. + // UserHash is a stable per-key identifier derived from the client hello payload. UserHash string } -type bufferedConn struct { - net.Conn - r *bufio.Reader +type HandshakeMeta struct { + UserHash string } -func (bc *bufferedConn) Read(p []byte) (int, error) { - return bc.r.Read(p) +// SuspiciousError indicates a potential probing attempt or protocol violation. +// When returned, Conn (if non-nil) should contain all bytes already consumed/buffered so the caller +// can perform a best-effort fallback relay (e.g. to a local web server) without losing the request. +type SuspiciousError struct { + Err error + Conn net.Conn +} + +func (e *SuspiciousError) Error() string { + if e == nil || e.Err == nil { + return "" + } + return e.Err.Error() +} + +func (e *SuspiciousError) Unwrap() error { return e.Err } + +type recordedConn struct { + net.Conn + recorded []byte +} + +func (rc *recordedConn) GetBufferedAndRecorded() []byte { return rc.recorded } + +type prefixedRecorderConn struct { + net.Conn + prefix []byte +} + +func (pc *prefixedRecorderConn) GetBufferedAndRecorded() []byte { + var rest []byte + if r, ok := pc.Conn.(interface{ GetBufferedAndRecorded() []byte }); ok { + rest = r.GetBufferedAndRecorded() + } + out := make([]byte, 0, len(pc.prefix)+len(rest)) + out = append(out, pc.prefix...) + out = append(out, rest...) + return out +} + +// bufferedRecorderConn wraps a net.Conn and a shared bufio.Reader so we can expose buffered bytes. +// This is used for legacy HTTP mask parsing errors so callers can fall back to a real HTTP server. +type bufferedRecorderConn struct { + net.Conn + r *bufio.Reader + recorder *bytes.Buffer + mu sync.Mutex +} + +func (bc *bufferedRecorderConn) Read(p []byte) (n int, err error) { + n, err = bc.r.Read(p) + if n > 0 && bc.recorder != nil { + bc.mu.Lock() + bc.recorder.Write(p[:n]) + bc.mu.Unlock() + } + return n, err +} + +func (bc *bufferedRecorderConn) GetBufferedAndRecorded() []byte { + if bc == nil { + return nil + } + bc.mu.Lock() + defer bc.mu.Unlock() + + var recorded []byte + if bc.recorder != nil { + recorded = bc.recorder.Bytes() + } + buffered := 0 + if bc.r != nil { + buffered = bc.r.Buffered() + } + if buffered <= 0 { + return recorded + } + peeked, _ := bc.r.Peek(buffered) + full := make([]byte, len(recorded)+len(peeked)) + copy(full, recorded) + copy(full[len(recorded):], peeked) + return full } type preBufferedConn struct { @@ -61,6 +140,26 @@ func (p *preBufferedConn) Read(b []byte) (int, error) { return p.Conn.Read(b) } +func (p *preBufferedConn) CloseWrite() error { + if p == nil { + return nil + } + if cw, ok := p.Conn.(interface{ CloseWrite() error }); ok { + return cw.CloseWrite() + } + return nil +} + +func (p *preBufferedConn) CloseRead() error { + if p == nil { + return nil + } + if cr, ok := p.Conn.(interface{ CloseRead() error }); ok { + return cr.CloseRead() + } + return nil +} + type directionalConn struct { net.Conn reader io.Reader @@ -101,6 +200,26 @@ func (c *directionalConn) Close() error { return firstErr } +func (c *directionalConn) CloseWrite() error { + if c == nil { + return nil + } + if cw, ok := c.Conn.(interface{ CloseWrite() error }); ok { + return cw.CloseWrite() + } + return nil +} + +func (c *directionalConn) CloseRead() error { + if c == nil { + return nil + } + if cr, ok := c.Conn.(interface{ CloseRead() error }); ok { + return cr.CloseRead() + } + return nil +} + func absInt64(v int64) int64 { if v < 0 { return -v @@ -108,18 +227,6 @@ func absInt64(v int64) int64 { return v } -const ( - downlinkModePure byte = 0x01 - downlinkModePacked byte = 0x02 -) - -func downlinkMode(cfg *ProtocolConfig) byte { - if cfg.EnablePureDownlink { - return downlinkModePure - } - return downlinkModePacked -} - func buildClientObfsConn(raw net.Conn, cfg *ProtocolConfig, table *sudoku.Table) net.Conn { baseSudoku := sudoku.NewConn(raw, table, cfg.PaddingMin, cfg.PaddingMax, false) if cfg.EnablePureDownlink { @@ -138,50 +245,16 @@ func buildServerObfsConn(raw net.Conn, cfg *ProtocolConfig, table *sudoku.Table, return uplinkSudoku, newDirectionalConn(raw, uplinkSudoku, packed, packed.Flush) } -func buildHandshakePayload(key string) [16]byte { - var payload [16]byte - binary.BigEndian.PutUint64(payload[:8], uint64(time.Now().Unix())) - - // Align with upstream: only decode hex bytes when this key is an ED25519 key material. - // For plain UUID/strings (even if they look like hex), hash the string bytes as-is. - src := []byte(key) - if _, err := crypto.RecoverPublicKey(key); err == nil { - if keyBytes, decErr := hex.DecodeString(key); decErr == nil && len(keyBytes) > 0 { - src = keyBytes - } +func isLegacyHTTPMaskMode(mode string) bool { + switch strings.ToLower(strings.TrimSpace(mode)) { + case "", "legacy": + return true + default: + return false } - - hash := sha256.Sum256(src) - copy(payload[8:], hash[:8]) - return payload } -func NewTable(key string, tableType string) *sudoku.Table { - table, err := NewTableWithCustom(key, tableType, "") - if err != nil { - panic(fmt.Sprintf("[Sudoku] failed to init tables: %v", err)) - } - return table -} - -func NewTableWithCustom(key string, tableType string, customTable string) (*sudoku.Table, error) { - start := time.Now() - table, err := sudoku.NewTableWithCustom(key, tableType, customTable) - if err != nil { - return nil, err - } - log.Infoln("[Sudoku] Tables initialized (%s, custom=%v) in %v", tableType, customTable != "", time.Since(start)) - return table, nil -} - -func ClientAEADSeed(key string) string { - if recovered, err := crypto.RecoverPublicKey(key); err == nil { - return crypto.EncodePoint(recovered) - } - return key -} - -// ClientHandshake performs the client-side Sudoku handshake (without sending target address). +// ClientHandshake performs the client-side Sudoku handshake (no target request). func ClientHandshake(rawConn net.Conn, cfg *ProtocolConfig) (net.Conn, error) { if cfg == nil { return nil, fmt.Errorf("config is required") @@ -190,7 +263,7 @@ func ClientHandshake(rawConn net.Conn, cfg *ProtocolConfig) (net.Conn, error) { return nil, fmt.Errorf("invalid config: %w", err) } - if !cfg.DisableHTTPMask { + if !cfg.DisableHTTPMask && isLegacyHTTPMaskMode(cfg.HTTPMaskMode) { if err := httpmask.WriteRandomRequestHeaderWithPathRoot(rawConn, cfg.ServerAddress, cfg.HTTPMaskPathRoot); err != nil { return nil, fmt.Errorf("write http mask failed: %w", err) } @@ -201,32 +274,68 @@ func ClientHandshake(rawConn net.Conn, cfg *ProtocolConfig) (net.Conn, error) { return nil, err } + seed := ClientAEADSeed(cfg.Key) obfsConn := buildClientObfsConn(rawConn, cfg, table) - cConn, err := crypto.NewAEADConn(obfsConn, ClientAEADSeed(cfg.Key), cfg.AEADMethod) + pskC2S, pskS2C := derivePSKDirectionalBases(seed) + rc, err := crypto.NewRecordConn(obfsConn, cfg.AEADMethod, pskC2S, pskS2C) if err != nil { return nil, fmt.Errorf("setup crypto failed: %w", err) } - handshake := buildHandshakePayload(cfg.Key) - if _, err := cConn.Write(handshake[:]); err != nil { - cConn.Close() - return nil, fmt.Errorf("send handshake failed: %w", err) - } - if _, err := cConn.Write([]byte{downlinkMode(cfg)}); err != nil { - cConn.Close() - return nil, fmt.Errorf("send downlink mode failed: %w", err) + if _, err := kipHandshakeClient(rc, seed, kipUserHashFromKey(cfg.Key), KIPFeatAll); err != nil { + _ = rc.Close() + return nil, err } - return cConn, nil + return rc, nil } -// ServerHandshake performs Sudoku server-side handshake and detects UoT preface. -func ServerHandshake(rawConn net.Conn, cfg *ProtocolConfig) (*ServerSession, error) { +func readFirstSessionMessage(conn net.Conn) (*KIPMessage, error) { + for { + msg, err := ReadKIPMessage(conn) + if err != nil { + return nil, err + } + if msg.Type == KIPTypeKeepAlive { + continue + } + return msg, nil + } +} + +func maybeConsumeLegacyHTTPMask(rawConn net.Conn, r *bufio.Reader, cfg *ProtocolConfig) ([]byte, *SuspiciousError) { + if rawConn == nil || r == nil || cfg == nil || cfg.DisableHTTPMask || !isLegacyHTTPMaskMode(cfg.HTTPMaskMode) { + return nil, nil + } + + peekBytes, _ := r.Peek(4) // ignore error; subsequent read will handle it + if !httpmask.LooksLikeHTTPRequestStart(peekBytes) { + return nil, nil + } + + consumed, err := httpmask.ConsumeHeader(r) + if err == nil { + return consumed, nil + } + + recorder := new(bytes.Buffer) + if len(consumed) > 0 { + recorder.Write(consumed) + } + badConn := &bufferedRecorderConn{Conn: rawConn, r: r, recorder: recorder} + return consumed, &SuspiciousError{Err: fmt.Errorf("invalid http header: %w", err), Conn: badConn} +} + +// ServerHandshake performs the server-side KIP handshake. +func ServerHandshake(rawConn net.Conn, cfg *ProtocolConfig) (net.Conn, *HandshakeMeta, error) { + if rawConn == nil { + return nil, nil, fmt.Errorf("nil conn") + } if cfg == nil { - return nil, fmt.Errorf("config is required") + return nil, nil, fmt.Errorf("config is required") } if err := cfg.Validate(); err != nil { - return nil, fmt.Errorf("invalid config: %w", err) + return nil, nil, fmt.Errorf("invalid config: %w", err) } handshakeTimeout := time.Duration(cfg.HandshakeTimeoutSeconds) * time.Second @@ -234,116 +343,113 @@ func ServerHandshake(rawConn net.Conn, cfg *ProtocolConfig) (*ServerSession, err handshakeTimeout = 5 * time.Second } - rawConn.SetReadDeadline(time.Now().Add(handshakeTimeout)) - bufReader := bufio.NewReader(rawConn) - if !cfg.DisableHTTPMask { - if peek, err := bufReader.Peek(4); err == nil && httpmask.LooksLikeHTTPRequestStart(peek) { - if _, err := httpmask.ConsumeHeader(bufReader); err != nil { - return nil, fmt.Errorf("invalid http header: %w", err) - } - } + _ = rawConn.SetReadDeadline(time.Now().Add(handshakeTimeout)) + defer func() { _ = rawConn.SetReadDeadline(time.Time{}) }() + + httpHeaderData, susp := maybeConsumeLegacyHTTPMask(rawConn, bufReader, cfg) + if susp != nil { + return nil, nil, susp } selectedTable, preRead, err := selectTableByProbe(bufReader, cfg, cfg.tableCandidates()) if err != nil { - return nil, err + combined := make([]byte, 0, len(httpHeaderData)+len(preRead)) + combined = append(combined, httpHeaderData...) + combined = append(combined, preRead...) + return nil, nil, &SuspiciousError{Err: err, Conn: &recordedConn{Conn: rawConn, recorded: combined}} } baseConn := &preBufferedConn{Conn: rawConn, buf: preRead} - bConn := &bufferedConn{Conn: baseConn, r: bufio.NewReader(baseConn)} - sConn, obfsConn := buildServerObfsConn(bConn, cfg, selectedTable, true) - cConn, err := crypto.NewAEADConn(obfsConn, cfg.Key, cfg.AEADMethod) + sConn, obfsConn := buildServerObfsConn(baseConn, cfg, selectedTable, true) + + seed := ServerAEADSeed(cfg.Key) + pskC2S, pskS2C := derivePSKDirectionalBases(seed) + // Server side: recv is client->server, send is server->client. + rc, err := crypto.NewRecordConn(obfsConn, cfg.AEADMethod, pskS2C, pskC2S) if err != nil { - return nil, fmt.Errorf("crypto setup failed: %w", err) + return nil, nil, fmt.Errorf("setup crypto failed: %w", err) } - var handshakeBuf [16]byte - if _, err := io.ReadFull(cConn, handshakeBuf[:]); err != nil { - cConn.Close() - return nil, fmt.Errorf("read handshake failed: %w", err) + msg, err := ReadKIPMessage(rc) + if err != nil { + return nil, nil, &SuspiciousError{Err: fmt.Errorf("handshake read failed: %w", err), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}} + } + if msg.Type != KIPTypeClientHello { + return nil, nil, &SuspiciousError{Err: fmt.Errorf("unexpected handshake message: %d", msg.Type), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}} + } + ch, err := DecodeKIPClientHelloPayload(msg.Payload) + if err != nil { + return nil, nil, &SuspiciousError{Err: fmt.Errorf("decode client hello failed: %w", err), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}} + } + if absInt64(time.Now().Unix()-ch.Timestamp.Unix()) > int64(kipHandshakeSkew.Seconds()) { + return nil, nil, &SuspiciousError{Err: fmt.Errorf("time skew/replay"), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}} } - ts := int64(binary.BigEndian.Uint64(handshakeBuf[:8])) - if absInt64(time.Now().Unix()-ts) > 60 { - cConn.Close() - return nil, fmt.Errorf("timestamp skew detected") + userHashHex := hex.EncodeToString(ch.UserHash[:]) + if !globalHandshakeReplay.allow(userHashHex, ch.Nonce, time.Now()) { + return nil, nil, &SuspiciousError{Err: fmt.Errorf("replay"), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}} + } + + curve := ecdh.X25519() + serverEphemeral, err := curve.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, &SuspiciousError{Err: fmt.Errorf("ecdh generate failed: %w", err), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}} + } + shared, err := x25519SharedSecret(serverEphemeral, ch.ClientPub[:]) + if err != nil { + return nil, nil, &SuspiciousError{Err: fmt.Errorf("ecdh failed: %w", err), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}} + } + sessC2S, sessS2C, err := deriveSessionDirectionalBases(seed, shared, ch.Nonce) + if err != nil { + return nil, nil, &SuspiciousError{Err: fmt.Errorf("derive session keys failed: %w", err), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}} + } + + var serverPub [kipHelloPubSize]byte + copy(serverPub[:], serverEphemeral.PublicKey().Bytes()) + sh := &KIPServerHello{ + Nonce: ch.Nonce, + ServerPub: serverPub, + SelectedFeats: ch.Features & KIPFeatAll, + } + if err := WriteKIPMessage(rc, KIPTypeServerHello, sh.EncodePayload()); err != nil { + return nil, nil, &SuspiciousError{Err: fmt.Errorf("write server hello failed: %w", err), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}} + } + if err := rc.Rekey(sessS2C, sessC2S); err != nil { + return nil, nil, &SuspiciousError{Err: fmt.Errorf("rekey failed: %w", err), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}} } - userHash := userHashFromHandshake(handshakeBuf[:]) sConn.StopRecording() + return rc, &HandshakeMeta{UserHash: userHashHex}, nil +} - modeBuf := []byte{0} - if _, err := io.ReadFull(cConn, modeBuf); err != nil { - cConn.Close() - return nil, fmt.Errorf("read downlink mode failed: %w", err) +// ReadServerSession consumes the first post-handshake KIP control message and returns the session intent. +func ReadServerSession(conn net.Conn, meta *HandshakeMeta) (*ServerSession, error) { + if conn == nil { + return nil, fmt.Errorf("nil conn") } - if modeBuf[0] != downlinkMode(cfg) { - cConn.Close() - return nil, fmt.Errorf("downlink mode mismatch: client=%d server=%d", modeBuf[0], downlinkMode(cfg)) + userHash := "" + if meta != nil { + userHash = meta.UserHash } - firstByte := make([]byte, 1) - if _, err := io.ReadFull(cConn, firstByte); err != nil { - cConn.Close() - return nil, fmt.Errorf("read first byte failed: %w", err) + first, err := readFirstSessionMessage(conn) + if err != nil { + return nil, err } - if firstByte[0] == MultiplexMagicByte { - rawConn.SetReadDeadline(time.Time{}) - return &ServerSession{Conn: cConn, Type: SessionTypeMultiplex, UserHash: userHash}, nil - } - - if firstByte[0] == UoTMagicByte { - version := make([]byte, 1) - if _, err := io.ReadFull(cConn, version); err != nil { - cConn.Close() - return nil, fmt.Errorf("read uot version failed: %w", err) + switch first.Type { + case KIPTypeStartUoT: + return &ServerSession{Conn: conn, Type: SessionTypeUoT, UserHash: userHash}, nil + case KIPTypeStartMux: + return &ServerSession{Conn: conn, Type: SessionTypeMultiplex, UserHash: userHash}, nil + case KIPTypeOpenTCP: + target, err := DecodeAddress(bytes.NewReader(first.Payload)) + if err != nil { + return nil, fmt.Errorf("decode target address failed: %w", err) } - if version[0] != uotVersion { - cConn.Close() - return nil, fmt.Errorf("unsupported uot version: %d", version[0]) - } - rawConn.SetReadDeadline(time.Time{}) - return &ServerSession{Conn: cConn, Type: SessionTypeUoT, UserHash: userHash}, nil + return &ServerSession{Conn: conn, Type: SessionTypeTCP, Target: target, UserHash: userHash}, nil + default: + return nil, fmt.Errorf("unknown kip message: %d", first.Type) } - - prefixed := &preBufferedConn{Conn: cConn, buf: firstByte} - target, err := DecodeAddress(prefixed) - if err != nil { - cConn.Close() - return nil, fmt.Errorf("read target address failed: %w", err) - } - - rawConn.SetReadDeadline(time.Time{}) - log.Debugln("[Sudoku] incoming TCP session target: %s", target) - return &ServerSession{ - Conn: prefixed, - Type: SessionTypeTCP, - Target: target, - UserHash: userHash, - }, nil -} - -func GenKeyPair() (privateKey, publicKey string, err error) { - // Generate Master Key - pair, err := crypto.GenerateMasterKey() - if err != nil { - return - } - // Split the master private key to get Available Private Key - availablePrivateKey, err := crypto.SplitPrivateKey(pair.Private) - if err != nil { - return - } - privateKey = availablePrivateKey // Available Private Key for client - publicKey = crypto.EncodePoint(pair.Public) // Master Public Key for server - return -} - -func userHashFromHandshake(handshakeBuf []byte) string { - if len(handshakeBuf) < 16 { - return "" - } - return hex.EncodeToString(handshakeBuf[8:16]) } diff --git a/transport/sudoku/handshake_kip.go b/transport/sudoku/handshake_kip.go new file mode 100644 index 00000000..a99c6385 --- /dev/null +++ b/transport/sudoku/handshake_kip.go @@ -0,0 +1,73 @@ +package sudoku + +import ( + "crypto/ecdh" + "crypto/rand" + "fmt" + "io" + "time" + + "github.com/metacubex/mihomo/transport/sudoku/crypto" +) + +const kipHandshakeSkew = 60 * time.Second + +func kipHandshakeClient(rc *crypto.RecordConn, seed string, userHash [kipHelloUserHashSize]byte, feats uint32) (uint32, error) { + if rc == nil { + return 0, fmt.Errorf("nil conn") + } + + curve := ecdh.X25519() + ephemeral, err := curve.GenerateKey(rand.Reader) + if err != nil { + return 0, fmt.Errorf("ecdh generate failed: %w", err) + } + + var nonce [kipHelloNonceSize]byte + if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil { + return 0, fmt.Errorf("nonce generate failed: %w", err) + } + + var clientPub [kipHelloPubSize]byte + copy(clientPub[:], ephemeral.PublicKey().Bytes()) + + ch := &KIPClientHello{ + Timestamp: time.Now(), + UserHash: userHash, + Nonce: nonce, + ClientPub: clientPub, + Features: feats, + } + if err := WriteKIPMessage(rc, KIPTypeClientHello, ch.EncodePayload()); err != nil { + return 0, fmt.Errorf("write client hello failed: %w", err) + } + + msg, err := ReadKIPMessage(rc) + if err != nil { + return 0, fmt.Errorf("read server hello failed: %w", err) + } + if msg.Type != KIPTypeServerHello { + return 0, fmt.Errorf("unexpected handshake message: %d", msg.Type) + } + sh, err := DecodeKIPServerHelloPayload(msg.Payload) + if err != nil { + return 0, fmt.Errorf("decode server hello failed: %w", err) + } + if sh.Nonce != nonce { + return 0, fmt.Errorf("handshake nonce mismatch") + } + + shared, err := x25519SharedSecret(ephemeral, sh.ServerPub[:]) + if err != nil { + return 0, fmt.Errorf("ecdh failed: %w", err) + } + sessC2S, sessS2C, err := deriveSessionDirectionalBases(seed, shared, nonce) + if err != nil { + return 0, fmt.Errorf("derive session keys failed: %w", err) + } + if err := rc.Rekey(sessC2S, sessS2C); err != nil { + return 0, fmt.Errorf("rekey failed: %w", err) + } + + return sh.SelectedFeats, nil +} diff --git a/transport/sudoku/handshake_test.go b/transport/sudoku/handshake_test.go index b2f0999e..f62c8189 100644 --- a/transport/sudoku/handshake_test.go +++ b/transport/sudoku/handshake_test.go @@ -124,13 +124,18 @@ func runPackedTCPSession(id int, cfg *ProtocolConfig, errCh chan<- error) { // Server side go func() { - session, err := ServerHandshake(serverConn, cfg) + c, meta, err := ServerHandshake(serverConn, cfg) if err != nil { errCh <- fmt.Errorf("server handshake tcp: %w", err) return } - defer session.Conn.Close() + defer c.Close() + session, err := ReadServerSession(c, meta) + if err != nil { + errCh <- fmt.Errorf("server read session tcp: %w", err) + return + } if session.Type != SessionTypeTCP { errCh <- fmt.Errorf("unexpected session type: %v", session.Type) return @@ -159,8 +164,8 @@ func runPackedTCPSession(id int, cfg *ProtocolConfig, errCh chan<- error) { errCh <- fmt.Errorf("encode address: %w", err) return } - if _, err := cConn.Write(addrBuf); err != nil { - errCh <- fmt.Errorf("client send addr: %w", err) + if err := WriteKIPMessage(cConn, KIPTypeOpenTCP, addrBuf); err != nil { + errCh <- fmt.Errorf("client send open tcp: %w", err) return } @@ -182,13 +187,18 @@ func runPackedUoTSession(id int, cfg *ProtocolConfig, errCh chan<- error) { // Server side go func() { - session, err := ServerHandshake(serverConn, cfg) + c, meta, err := ServerHandshake(serverConn, cfg) if err != nil { errCh <- fmt.Errorf("server handshake uot: %w", err) return } - defer session.Conn.Close() + defer c.Close() + session, err := ReadServerSession(c, meta) + if err != nil { + errCh <- fmt.Errorf("server read session uot: %w", err) + return + } if session.Type != SessionTypeUoT { errCh <- fmt.Errorf("unexpected session type: %v", session.Type) return @@ -208,8 +218,8 @@ func runPackedUoTSession(id int, cfg *ProtocolConfig, errCh chan<- error) { } defer cConn.Close() - if err := WritePreface(cConn); err != nil { - errCh <- fmt.Errorf("client write preface: %w", err) + if err := WriteKIPMessage(cConn, KIPTypeStartUoT, nil); err != nil { + errCh <- fmt.Errorf("client start uot: %w", err) return } diff --git a/transport/sudoku/httpmask_tunnel.go b/transport/sudoku/httpmask_tunnel.go index 45d79abc..1ff2bb38 100644 --- a/transport/sudoku/httpmask_tunnel.go +++ b/transport/sudoku/httpmask_tunnel.go @@ -15,6 +15,14 @@ type HTTPMaskTunnelServer struct { } func NewHTTPMaskTunnelServer(cfg *ProtocolConfig) *HTTPMaskTunnelServer { + return newHTTPMaskTunnelServer(cfg, false) +} + +func NewHTTPMaskTunnelServerWithFallback(cfg *ProtocolConfig) *HTTPMaskTunnelServer { + return newHTTPMaskTunnelServer(cfg, true) +} + +func newHTTPMaskTunnelServer(cfg *ProtocolConfig, passThroughOnReject bool) *HTTPMaskTunnelServer { if cfg == nil { return &HTTPMaskTunnelServer{} } @@ -22,11 +30,13 @@ func NewHTTPMaskTunnelServer(cfg *ProtocolConfig) *HTTPMaskTunnelServer { var ts *httpmask.TunnelServer if !cfg.DisableHTTPMask { switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) { - case "stream", "poll", "auto": + case "stream", "poll", "auto", "ws": ts = httpmask.NewTunnelServer(httpmask.TunnelServerOptions{ Mode: cfg.HTTPMaskMode, PathRoot: cfg.HTTPMaskPathRoot, - AuthKey: ClientAEADSeed(cfg.Key), + AuthKey: ServerAEADSeed(cfg.Key), + // When upstream fallback is enabled, preserve rejected HTTP requests for the caller. + PassThroughOnReject: passThroughOnReject, }) } } @@ -62,6 +72,14 @@ func (s *HTTPMaskTunnelServer) WrapConn(rawConn net.Conn) (handshakeConn net.Con case httpmask.HandleStartTunnel: inner := *s.cfg inner.DisableHTTPMask = true + // HTTPMask tunnel modes (stream/poll/auto/ws) add extra round trips before the first + // handshake bytes can reach ServerHandshake, especially under high concurrency. + // Bump the handshake timeout for tunneled conns to avoid flaky timeouts while keeping + // the default strict for raw TCP handshakes. + const minTunneledHandshakeTimeoutSeconds = 15 + if inner.HandshakeTimeoutSeconds <= 0 || inner.HandshakeTimeoutSeconds < minTunneledHandshakeTimeoutSeconds { + inner.HandshakeTimeoutSeconds = minTunneledHandshakeTimeoutSeconds + } return c, &inner, false, nil default: return nil, nil, true, nil @@ -70,7 +88,7 @@ func (s *HTTPMaskTunnelServer) WrapConn(rawConn net.Conn) (handshakeConn net.Con type TunnelDialer func(ctx context.Context, network, addr string) (net.Conn, error) -// DialHTTPMaskTunnel dials a CDN-capable HTTP tunnel (stream/poll/auto) and returns a stream carrying raw Sudoku bytes. +// DialHTTPMaskTunnel dials a CDN-capable HTTP tunnel (stream/poll/auto/ws) and returns a stream carrying raw Sudoku bytes. func DialHTTPMaskTunnel(ctx context.Context, serverAddress string, cfg *ProtocolConfig, dial TunnelDialer, upgrade func(net.Conn) (net.Conn, error)) (net.Conn, error) { if cfg == nil { return nil, fmt.Errorf("config is required") @@ -79,7 +97,7 @@ func DialHTTPMaskTunnel(ctx context.Context, serverAddress string, cfg *Protocol return nil, fmt.Errorf("http mask is disabled") } switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) { - case "stream", "poll", "auto": + case "stream", "poll", "auto", "ws": default: return nil, fmt.Errorf("http-mask-mode=%q does not use http tunnel", cfg.HTTPMaskMode) } @@ -94,64 +112,3 @@ func DialHTTPMaskTunnel(ctx context.Context, serverAddress string, cfg *Protocol DialContext: dial, }) } - -type HTTPMaskTunnelClient struct { - mode string - pathRoot string - authKey string - client *httpmask.TunnelClient -} - -func NewHTTPMaskTunnelClient(serverAddress string, cfg *ProtocolConfig, dial TunnelDialer) (*HTTPMaskTunnelClient, error) { - if cfg == nil { - return nil, fmt.Errorf("config is required") - } - if cfg.DisableHTTPMask { - return nil, fmt.Errorf("http mask is disabled") - } - switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) { - case "stream", "poll", "auto": - default: - return nil, fmt.Errorf("http-mask-mode=%q does not use http tunnel", cfg.HTTPMaskMode) - } - switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMultiplex)) { - case "auto", "on": - default: - return nil, fmt.Errorf("http-mask-multiplex=%q does not enable reuse", cfg.HTTPMaskMultiplex) - } - - c, err := httpmask.NewTunnelClient(serverAddress, httpmask.TunnelClientOptions{ - TLSEnabled: cfg.HTTPMaskTLSEnabled, - HostOverride: cfg.HTTPMaskHost, - DialContext: dial, - }) - if err != nil { - return nil, err - } - - return &HTTPMaskTunnelClient{ - mode: cfg.HTTPMaskMode, - pathRoot: cfg.HTTPMaskPathRoot, - authKey: ClientAEADSeed(cfg.Key), - client: c, - }, nil -} - -func (c *HTTPMaskTunnelClient) Dial(ctx context.Context, upgrade func(net.Conn) (net.Conn, error)) (net.Conn, error) { - if c == nil || c.client == nil { - return nil, fmt.Errorf("nil httpmask tunnel client") - } - return c.client.DialTunnel(ctx, httpmask.TunnelDialOptions{ - Mode: c.mode, - PathRoot: c.pathRoot, - AuthKey: c.authKey, - Upgrade: upgrade, - }) -} - -func (c *HTTPMaskTunnelClient) CloseIdleConnections() { - if c == nil || c.client == nil { - return - } - c.client.CloseIdleConnections() -} diff --git a/transport/sudoku/httpmask_tunnel_test.go b/transport/sudoku/httpmask_tunnel_test.go index d831c53f..8894882e 100644 --- a/transport/sudoku/httpmask_tunnel_test.go +++ b/transport/sudoku/httpmask_tunnel_test.go @@ -61,7 +61,7 @@ func startTunnelServer(t *testing.T, cfg *ProtocolConfig, handle func(*ServerSes return } - session, err := ServerHandshake(handshakeConn, handshakeCfg) + cConn, meta, err := ServerHandshake(handshakeConn, handshakeCfg) if err != nil { _ = handshakeConn.Close() if handshakeConn != conn { @@ -70,8 +70,13 @@ func startTunnelServer(t *testing.T, cfg *ProtocolConfig, handle func(*ServerSes errC <- err return } - defer session.Conn.Close() + defer cConn.Close() + session, err := ReadServerSession(cConn, meta) + if err != nil { + errC <- err + return + } if handleErr := handle(session); handleErr != nil { errC <- handleErr } @@ -172,7 +177,7 @@ func TestHTTPMaskTunnel_Stream_TCPRoundTrip(t *testing.T) { if err != nil { t.Fatalf("encode addr: %v", err) } - if _, err := cConn.Write(addrBuf); err != nil { + if err := WriteKIPMessage(cConn, KIPTypeOpenTCP, addrBuf); err != nil { t.Fatalf("write addr: %v", err) } @@ -239,8 +244,8 @@ func TestHTTPMaskTunnel_Poll_UoTRoundTrip(t *testing.T) { } defer cConn.Close() - if err := WritePreface(cConn); err != nil { - t.Fatalf("write preface: %v", err) + if err := WriteKIPMessage(cConn, KIPTypeStartUoT, nil); err != nil { + t.Fatalf("start uot: %v", err) } if err := WriteDatagram(cConn, target, payload); err != nil { t.Fatalf("write datagram: %v", err) @@ -305,7 +310,68 @@ func TestHTTPMaskTunnel_Auto_TCPRoundTrip(t *testing.T) { if err != nil { t.Fatalf("encode addr: %v", err) } - if _, err := cConn.Write(addrBuf); err != nil { + if err := WriteKIPMessage(cConn, KIPTypeOpenTCP, addrBuf); err != nil { + t.Fatalf("write addr: %v", err) + } + + buf := make([]byte, 2) + if _, err := io.ReadFull(cConn, buf); err != nil { + t.Fatalf("read: %v", err) + } + if string(buf) != "ok" { + t.Fatalf("unexpected payload: %q", buf) + } + + stop() + for err := range errCh { + t.Fatalf("server error: %v", err) + } +} + +func TestHTTPMaskTunnel_WS_TCPRoundTrip(t *testing.T) { + key := "tunnel-ws-key" + target := "1.1.1.1:80" + + serverCfg := newTunnelTestTable(t, key) + serverCfg.HTTPMaskMode = "ws" + + addr, stop, errCh := startTunnelServer(t, serverCfg, func(s *ServerSession) error { + if s.Type != SessionTypeTCP { + return fmt.Errorf("unexpected session type: %v", s.Type) + } + if s.Target != target { + return fmt.Errorf("target mismatch: %s", s.Target) + } + _, _ = s.Conn.Write([]byte("ok")) + return nil + }) + defer stop() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + clientCfg := *serverCfg + clientCfg.ServerAddress = addr + + tunnelConn, err := DialHTTPMaskTunnel(ctx, clientCfg.ServerAddress, &clientCfg, (&net.Dialer{}).DialContext, nil) + if err != nil { + t.Fatalf("dial tunnel: %v", err) + } + defer tunnelConn.Close() + + handshakeCfg := clientCfg + handshakeCfg.DisableHTTPMask = true + cConn, err := ClientHandshake(tunnelConn, &handshakeCfg) + if err != nil { + t.Fatalf("client handshake: %v", err) + } + defer cConn.Close() + + addrBuf, err := EncodeAddress(target) + if err != nil { + t.Fatalf("encode addr: %v", err) + } + if err := WriteKIPMessage(cConn, KIPTypeOpenTCP, addrBuf); err != nil { t.Fatalf("write addr: %v", err) } @@ -406,7 +472,7 @@ func TestHTTPMaskTunnel_Soak_Concurrent(t *testing.T) { runErr <- fmt.Errorf("encode addr: %w", err) return } - if _, err := cConn.Write(addrBuf); err != nil { + if err := WriteKIPMessage(cConn, KIPTypeOpenTCP, addrBuf); err != nil { runErr <- fmt.Errorf("write addr: %w", err) return } diff --git a/transport/sudoku/init.go b/transport/sudoku/init.go new file mode 100644 index 00000000..7dbf1d2c --- /dev/null +++ b/transport/sudoku/init.go @@ -0,0 +1,97 @@ +package sudoku + +import ( + "encoding/hex" + "fmt" + "strings" + + "github.com/metacubex/edwards25519" + "github.com/metacubex/mihomo/transport/sudoku/crypto" + "github.com/metacubex/mihomo/transport/sudoku/obfs/sudoku" +) + +func NewTable(key string, tableType string) *sudoku.Table { + table, err := NewTableWithCustom(key, tableType, "") + if err != nil { + panic(fmt.Sprintf("[Sudoku] failed to init tables: %v", err)) + } + return table +} + +func NewTableWithCustom(key string, tableType string, customTable string) (*sudoku.Table, error) { + table, err := sudoku.NewTableWithCustom(key, tableType, customTable) + if err != nil { + return nil, err + } + return table, nil +} + +// ClientAEADSeed returns a canonical "seed" that is stable between client private key material and server public key. +func ClientAEADSeed(key string) string { + key = strings.TrimSpace(key) + if key == "" { + return "" + } + + b, err := hex.DecodeString(key) + if err != nil { + return key + } + + // Client-side key material can be: + // - split private key: 64 bytes hex (r||k) + // - master private scalar: 32 bytes hex (x) + // - PSK string: non-hex + // + // We intentionally do NOT treat a 32-byte hex as a public key here; the client is expected + // to carry private material. Server-side should use ServerAEADSeed for public keys. + switch len(b) { + case 64: + case 32: + default: + return key + } + if recovered, err := crypto.RecoverPublicKey(key); err == nil { + return crypto.EncodePoint(recovered) + } + return key +} + +// ServerAEADSeed returns a canonical seed for server-side configuration. +// +// When key is a public key (32-byte compressed point, hex), it returns the canonical point encoding. +// When key is private key material (split/master scalar), it derives and returns the public key. +func ServerAEADSeed(key string) string { + key = strings.TrimSpace(key) + if key == "" { + return "" + } + + b, err := hex.DecodeString(key) + if err != nil { + return key + } + + // Prefer interpreting 32-byte hex as a public key point, to avoid accidental scalar parsing. + if len(b) == 32 { + if p, err := new(edwards25519.Point).SetBytes(b); err == nil { + return hex.EncodeToString(p.Bytes()) + } + } + + // Fall back to client-side rules for private key materials / other formats. + return ClientAEADSeed(key) +} + +// GenKeyPair generates a client "available private key" and the corresponding server public key. +func GenKeyPair() (privateKey, publicKey string, err error) { + pair, err := crypto.GenerateMasterKey() + if err != nil { + return "", "", err + } + availablePrivateKey, err := crypto.SplitPrivateKey(pair.Private) + if err != nil { + return "", "", err + } + return availablePrivateKey, crypto.EncodePoint(pair.Public), nil +} diff --git a/transport/sudoku/init_test.go b/transport/sudoku/init_test.go new file mode 100644 index 00000000..9d9841cb --- /dev/null +++ b/transport/sudoku/init_test.go @@ -0,0 +1,44 @@ +package sudoku + +import ( + "crypto/rand" + "encoding/hex" + "testing" + + "github.com/metacubex/edwards25519" + "github.com/stretchr/testify/require" +) + +func TestClientAEADSeed_IsStableForPrivAndPub(t *testing.T) { + for i := 0; i < 64; i++ { + priv, pub, err := GenKeyPair() + require.NoError(t, err) + + require.Equal(t, pub, ClientAEADSeed(priv)) + require.Equal(t, pub, ServerAEADSeed(pub)) + require.Equal(t, pub, ServerAEADSeed(priv)) + } +} + +func TestClientAEADSeed_Supports32ByteMasterScalar(t *testing.T) { + var seed [64]byte + _, err := rand.Read(seed[:]) + require.NoError(t, err) + + s, err := edwards25519.NewScalar().SetUniformBytes(seed[:]) + require.NoError(t, err) + + keyHex := hex.EncodeToString(s.Bytes()) + require.Len(t, keyHex, 64) + require.NotEqual(t, keyHex, ClientAEADSeed(keyHex)) + require.Equal(t, ClientAEADSeed(keyHex), ServerAEADSeed(ClientAEADSeed(keyHex))) +} + +func TestServerAEADSeed_LeavesPublicKeyAsIs(t *testing.T) { + for i := 0; i < 64; i++ { + priv, pub, err := GenKeyPair() + require.NoError(t, err) + require.Equal(t, pub, ServerAEADSeed(pub)) + require.Equal(t, pub, ServerAEADSeed(priv)) + } +} diff --git a/transport/sudoku/kip.go b/transport/sudoku/kip.go new file mode 100644 index 00000000..fab44fb9 --- /dev/null +++ b/transport/sudoku/kip.go @@ -0,0 +1,206 @@ +package sudoku + +import ( + "bytes" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "io" + "strings" + "time" +) + +const ( + kipMagic = "kip" + + KIPTypeClientHello byte = 0x01 + KIPTypeServerHello byte = 0x02 + + KIPTypeOpenTCP byte = 0x10 + KIPTypeStartMux byte = 0x11 + KIPTypeStartUoT byte = 0x12 + KIPTypeKeepAlive byte = 0x14 +) + +// KIP feature bits are advisory capability flags negotiated during the handshake. +// They represent control-plane message families. +const ( + KIPFeatOpenTCP uint32 = 1 << 0 + KIPFeatMux uint32 = 1 << 1 + KIPFeatUoT uint32 = 1 << 2 + KIPFeatKeepAlive uint32 = 1 << 4 + + KIPFeatAll = KIPFeatOpenTCP | KIPFeatMux | KIPFeatUoT | KIPFeatKeepAlive +) + +const ( + kipHelloUserHashSize = 8 + kipHelloNonceSize = 16 + kipHelloPubSize = 32 + kipMaxPayload = 64 * 1024 +) + +var errKIP = errors.New("kip protocol error") + +type KIPMessage struct { + Type byte + Payload []byte +} + +func WriteKIPMessage(w io.Writer, typ byte, payload []byte) error { + if w == nil { + return fmt.Errorf("%w: nil writer", errKIP) + } + if len(payload) > kipMaxPayload { + return fmt.Errorf("%w: payload too large: %d", errKIP, len(payload)) + } + + var hdr [3 + 1 + 2]byte + copy(hdr[:3], []byte(kipMagic)) + hdr[3] = typ + binary.BigEndian.PutUint16(hdr[4:], uint16(len(payload))) + + if err := writeFull(w, hdr[:]); err != nil { + return err + } + if len(payload) == 0 { + return nil + } + return writeFull(w, payload) +} + +func ReadKIPMessage(r io.Reader) (*KIPMessage, error) { + if r == nil { + return nil, fmt.Errorf("%w: nil reader", errKIP) + } + var hdr [3 + 1 + 2]byte + if _, err := io.ReadFull(r, hdr[:]); err != nil { + return nil, err + } + if string(hdr[:3]) != kipMagic { + return nil, fmt.Errorf("%w: bad magic", errKIP) + } + typ := hdr[3] + n := int(binary.BigEndian.Uint16(hdr[4:])) + if n < 0 || n > kipMaxPayload { + return nil, fmt.Errorf("%w: invalid payload length: %d", errKIP, n) + } + var payload []byte + if n > 0 { + payload = make([]byte, n) + if _, err := io.ReadFull(r, payload); err != nil { + return nil, err + } + } + return &KIPMessage{Type: typ, Payload: payload}, nil +} + +type KIPClientHello struct { + Timestamp time.Time + UserHash [kipHelloUserHashSize]byte + Nonce [kipHelloNonceSize]byte + ClientPub [kipHelloPubSize]byte + Features uint32 +} + +type KIPServerHello struct { + Nonce [kipHelloNonceSize]byte + ServerPub [kipHelloPubSize]byte + SelectedFeats uint32 +} + +func kipUserHashFromKey(psk string) [kipHelloUserHashSize]byte { + var out [kipHelloUserHashSize]byte + psk = strings.TrimSpace(psk) + if psk == "" { + return out + } + + // Align with upstream: when the client carries private key material (or even just a public key), + // prefer hashing the raw hex bytes so different split/master keys can be distinguished. + if keyBytes, err := hex.DecodeString(psk); err == nil && len(keyBytes) > 0 { + sum := sha256.Sum256(keyBytes) + copy(out[:], sum[:kipHelloUserHashSize]) + return out + } + + sum := sha256.Sum256([]byte(psk)) + copy(out[:], sum[:kipHelloUserHashSize]) + return out +} + +func KIPUserHashHexFromKey(psk string) string { + uh := kipUserHashFromKey(psk) + return hex.EncodeToString(uh[:]) +} + +func (m *KIPClientHello) EncodePayload() []byte { + var b bytes.Buffer + var tmp [8]byte + binary.BigEndian.PutUint64(tmp[:], uint64(m.Timestamp.Unix())) + b.Write(tmp[:]) + b.Write(m.UserHash[:]) + b.Write(m.Nonce[:]) + b.Write(m.ClientPub[:]) + var f [4]byte + binary.BigEndian.PutUint32(f[:], m.Features) + b.Write(f[:]) + return b.Bytes() +} + +func DecodeKIPClientHelloPayload(payload []byte) (*KIPClientHello, error) { + const minLen = 8 + kipHelloUserHashSize + kipHelloNonceSize + kipHelloPubSize + 4 + if len(payload) < minLen { + return nil, fmt.Errorf("%w: client hello too short", errKIP) + } + var h KIPClientHello + ts := int64(binary.BigEndian.Uint64(payload[:8])) + h.Timestamp = time.Unix(ts, 0) + off := 8 + copy(h.UserHash[:], payload[off:off+kipHelloUserHashSize]) + off += kipHelloUserHashSize + copy(h.Nonce[:], payload[off:off+kipHelloNonceSize]) + off += kipHelloNonceSize + copy(h.ClientPub[:], payload[off:off+kipHelloPubSize]) + off += kipHelloPubSize + h.Features = binary.BigEndian.Uint32(payload[off : off+4]) + return &h, nil +} + +func (m *KIPServerHello) EncodePayload() []byte { + var b bytes.Buffer + b.Write(m.Nonce[:]) + b.Write(m.ServerPub[:]) + var f [4]byte + binary.BigEndian.PutUint32(f[:], m.SelectedFeats) + b.Write(f[:]) + return b.Bytes() +} + +func DecodeKIPServerHelloPayload(payload []byte) (*KIPServerHello, error) { + const want = kipHelloNonceSize + kipHelloPubSize + 4 + if len(payload) != want { + return nil, fmt.Errorf("%w: server hello bad len: %d", errKIP, len(payload)) + } + var h KIPServerHello + off := 0 + copy(h.Nonce[:], payload[off:off+kipHelloNonceSize]) + off += kipHelloNonceSize + copy(h.ServerPub[:], payload[off:off+kipHelloPubSize]) + off += kipHelloPubSize + h.SelectedFeats = binary.BigEndian.Uint32(payload[off : off+4]) + return &h, nil +} + +func writeFull(w io.Writer, b []byte) error { + for len(b) > 0 { + n, err := w.Write(b) + if err != nil { + return err + } + b = b[n:] + } + return nil +} diff --git a/transport/sudoku/multiplex.go b/transport/sudoku/multiplex.go index da635708..3f0bfa69 100644 --- a/transport/sudoku/multiplex.go +++ b/transport/sudoku/multiplex.go @@ -10,19 +10,14 @@ import ( "github.com/metacubex/mihomo/transport/sudoku/multiplex" ) -const ( - MultiplexMagicByte byte = multiplex.MagicByte - MultiplexVersion byte = multiplex.Version -) - -// StartMultiplexClient writes the multiplex preface and upgrades an already-handshaked Sudoku tunnel into a multiplex session. +// StartMultiplexClient upgrades an already-handshaked Sudoku tunnel into a multiplex session. func StartMultiplexClient(conn net.Conn) (*MultiplexClient, error) { if conn == nil { return nil, fmt.Errorf("nil conn") } - if err := multiplex.WritePreface(conn); err != nil { - return nil, fmt.Errorf("write multiplex preface failed: %w", err) + if err := WriteKIPMessage(conn, KIPTypeStartMux, nil); err != nil { + return nil, fmt.Errorf("write mux start failed: %w", err) } sess, err := multiplex.NewClientSession(conn) @@ -77,20 +72,10 @@ func (c *MultiplexClient) IsClosed() bool { } // AcceptMultiplexServer upgrades a server-side, already-handshaked Sudoku connection into a multiplex session. -// -// The caller must have already consumed the multiplex magic byte (MultiplexMagicByte). This function consumes the -// multiplex version byte and starts the session. func AcceptMultiplexServer(conn net.Conn) (*MultiplexServer, error) { if conn == nil { return nil, fmt.Errorf("nil conn") } - v, err := multiplex.ReadVersion(conn) - if err != nil { - return nil, err - } - if err := multiplex.ValidateVersion(v); err != nil { - return nil, err - } sess, err := multiplex.NewServerSession(conn) if err != nil { return nil, err diff --git a/transport/sudoku/multiplex/mux.go b/transport/sudoku/multiplex/mux.go deleted file mode 100644 index 38bc34c8..00000000 --- a/transport/sudoku/multiplex/mux.go +++ /dev/null @@ -1,39 +0,0 @@ -package multiplex - -import ( - "fmt" - "io" -) - -const ( - // MagicByte marks a Sudoku tunnel connection that will switch into multiplex mode. - // It is sent after the Sudoku handshake + downlink mode byte. - // - // Keep it distinct from UoTMagicByte and address type bytes. - MagicByte byte = 0xED - Version byte = 0x01 -) - -func WritePreface(w io.Writer) error { - if w == nil { - return fmt.Errorf("nil writer") - } - _, err := w.Write([]byte{MagicByte, Version}) - return err -} - -func ReadVersion(r io.Reader) (byte, error) { - var b [1]byte - if _, err := io.ReadFull(r, b[:]); err != nil { - return 0, err - } - return b[0], nil -} - -func ValidateVersion(v byte) error { - if v != Version { - return fmt.Errorf("unsupported multiplex version: %d", v) - } - return nil -} - diff --git a/transport/sudoku/multiplex_test.go b/transport/sudoku/multiplex_test.go index 93962906..8be7d9a4 100644 --- a/transport/sudoku/multiplex_test.go +++ b/transport/sudoku/multiplex_test.go @@ -18,7 +18,6 @@ func TestUserHash_StableAcrossTableRotation(t *testing.T) { sudokuobfs.NewTable("seed-b", "prefer_ascii"), } key := "userhash-stability-key" - target := "example.com:80" serverCfg := DefaultConfig() serverCfg.Key = key @@ -48,13 +47,16 @@ func TestUserHash_StableAcrossTableRotation(t *testing.T) { } go func(conn net.Conn) { defer conn.Close() - session, err := ServerHandshake(conn, serverCfg) + _, meta, err := ServerHandshake(conn, serverCfg) if err != nil { errCh <- err return } - defer session.Conn.Close() - hashCh <- session.UserHash + if meta == nil || meta.UserHash == "" { + errCh <- io.ErrUnexpectedEOF + return + } + hashCh <- meta.UserHash }(c) } }() @@ -77,15 +79,6 @@ func TestUserHash_StableAcrossTableRotation(t *testing.T) { t.Fatalf("handshake %d: %v", i, err) } - addrBuf, err := EncodeAddress(target) - if err != nil { - _ = cConn.Close() - t.Fatalf("encode addr %d: %v", i, err) - } - if _, err := cConn.Write(addrBuf); err != nil { - _ = cConn.Close() - t.Fatalf("write addr %d: %v", i, err) - } _ = cConn.Close() } @@ -145,18 +138,22 @@ func TestMultiplex_TCP_Echo(t *testing.T) { } defer raw.Close() - session, err := ServerHandshake(raw, serverCfg) + c, meta, err := ServerHandshake(raw, serverCfg) if err != nil { return } atomic.AddInt64(&handshakes, 1) + session, err := ReadServerSession(c, meta) + if err != nil { + return + } if session.Type != SessionTypeMultiplex { - _ = session.Conn.Close() + _ = c.Close() return } - mux, err := AcceptMultiplexServer(session.Conn) + mux, err := AcceptMultiplexServer(c) if err != nil { return } @@ -240,21 +237,3 @@ func TestMultiplex_TCP_Echo(t *testing.T) { t.Fatalf("unexpected stream count: %d", got) } } - -func TestMultiplex_Boundary_InvalidVersion(t *testing.T) { - client, server := net.Pipe() - t.Cleanup(func() { _ = client.Close() }) - t.Cleanup(func() { _ = server.Close() }) - - errCh := make(chan error, 1) - go func() { - _, err := AcceptMultiplexServer(server) - errCh <- err - }() - - // AcceptMultiplexServer expects the magic byte to have been consumed already; write a bad version byte. - _, _ = client.Write([]byte{0xFF}) - if err := <-errCh; err == nil { - t.Fatalf("expected error") - } -} diff --git a/transport/sudoku/obfs/httpmask/auth.go b/transport/sudoku/obfs/httpmask/auth.go index f59d6958..15b641d0 100644 --- a/transport/sudoku/obfs/httpmask/auth.go +++ b/transport/sudoku/obfs/httpmask/auth.go @@ -6,9 +6,10 @@ import ( "crypto/subtle" "encoding/base64" "encoding/binary" - "github.com/metacubex/http" "strings" "time" + + "github.com/metacubex/http" ) const ( diff --git a/transport/sudoku/obfs/httpmask/tunnel.go b/transport/sudoku/obfs/httpmask/tunnel.go index 7f9f786e..20981c39 100644 --- a/transport/sudoku/obfs/httpmask/tunnel.go +++ b/transport/sudoku/obfs/httpmask/tunnel.go @@ -16,6 +16,7 @@ import ( "strconv" "strings" "sync" + "syscall" "time" "github.com/metacubex/mihomo/component/ca" @@ -32,6 +33,7 @@ const ( TunnelModeStream TunnelMode = "stream" TunnelModePoll TunnelMode = "poll" TunnelModeAuto TunnelMode = "auto" + TunnelModeWS TunnelMode = "ws" ) func normalizeTunnelMode(mode string) TunnelMode { @@ -44,6 +46,8 @@ func normalizeTunnelMode(mode string) TunnelMode { return TunnelModePoll case string(TunnelModeAuto): return TunnelModeAuto + case string(TunnelModeWS): + return TunnelModeWS default: // Be conservative: unknown => legacy return TunnelModeLegacy @@ -88,7 +92,6 @@ type TunnelClientOptions struct { } type TunnelClient struct { - client *http.Client transport *http.Transport target httpClientTarget } @@ -105,7 +108,6 @@ func NewTunnelClient(serverAddress string, opts TunnelClientOptions) (*TunnelCli } return &TunnelClient{ - client: &http.Client{Transport: transport}, transport: transport, target: target, }, nil @@ -119,7 +121,7 @@ func (c *TunnelClient) CloseIdleConnections() { } func (c *TunnelClient) DialTunnel(ctx context.Context, opts TunnelDialOptions) (net.Conn, error) { - if c == nil || c.client == nil { + if c == nil || c.transport == nil { return nil, fmt.Errorf("nil tunnel client") } tm := normalizeTunnelMode(opts.Mode) @@ -127,25 +129,31 @@ func (c *TunnelClient) DialTunnel(ctx context.Context, opts TunnelDialOptions) ( return nil, fmt.Errorf("legacy mode does not use http tunnel") } + // Create a per-dial client while sharing the underlying Transport for connection reuse. + // This matches upstream behavior and avoids potential client-level concurrency pitfalls. + client := &http.Client{Transport: c.transport} + switch tm { case TunnelModeStream: - return dialStreamWithClient(ctx, c.client, c.target, opts) + return dialStreamWithClient(ctx, client, c.target, opts) case TunnelModePoll: - return dialPollWithClient(ctx, c.client, c.target, opts) + return dialPollWithClient(ctx, client, c.target, opts) + case TunnelModeWS: + return nil, fmt.Errorf("ws mode does not support TunnelClient reuse") case TunnelModeAuto: streamCtx, cancelX := context.WithTimeout(ctx, 3*time.Second) - c1, errX := dialStreamWithClient(streamCtx, c.client, c.target, opts) + c1, errX := dialStreamWithClient(streamCtx, client, c.target, opts) cancelX() if errX == nil { return c1, nil } - c2, errP := dialPollWithClient(ctx, c.client, c.target, opts) + c2, errP := dialPollWithClient(ctx, client, c.target, opts) if errP == nil { return c2, nil } return nil, fmt.Errorf("auto tunnel failed: stream: %v; poll: %w", errX, errP) default: - return dialStreamWithClient(ctx, c.client, c.target, opts) + return dialStreamWithClient(ctx, client, c.target, opts) } } @@ -166,6 +174,8 @@ func DialTunnel(ctx context.Context, serverAddress string, opts TunnelDialOption return dialStreamFn(ctx, serverAddress, opts) case TunnelModePoll: return dialPollFn(ctx, serverAddress, opts) + case TunnelModeWS: + return dialWS(ctx, serverAddress, opts) case TunnelModeAuto: // "stream" can hang on some CDNs that buffer uploads until request body completes. // Keep it on a short leash so we can fall back to poll within the caller's deadline. @@ -306,6 +316,36 @@ type sessionDialInfo struct { auth *tunnelAuth } +type httpStatusError struct { + code int + status string +} + +func (e *httpStatusError) Error() string { + if e == nil { + return "bad status" + } + if e.status != "" { + return "bad status: " + e.status + } + return "bad status" +} + +func isRetryableStatusCode(code int) bool { + return code == http.StatusRequestTimeout || code == http.StatusTooManyRequests || code >= 500 +} + +type idleConnCloser interface{ CloseIdleConnections() } + +func closeIdleConnections(client *http.Client) { + if client == nil || client.Transport == nil { + return + } + if c, ok := client.Transport.(idleConnCloser); ok { + c.CloseIdleConnections() + } +} + func dialSessionWithClient(ctx context.Context, client *http.Client, target httpClientTarget, mode TunnelMode, opts TunnelDialOptions) (*sessionDialInfo, error) { if client == nil { return nil, fmt.Errorf("nil http client") @@ -313,25 +353,61 @@ func dialSessionWithClient(ctx context.Context, client *http.Client, target http auth := newTunnelAuth(opts.AuthKey, 0) authorizeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/session")}).String() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, authorizeURL, nil) - if err != nil { - return nil, err - } - req.Host = target.headerHost - applyTunnelHeaders(req.Header, target.headerHost, mode) - applyTunnelAuth(req, auth, mode, http.MethodGet, "/session") - resp, err := client.Do(req) - if err != nil { - return nil, err - } - bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, 4*1024)) - _ = resp.Body.Close() - if err != nil { - return nil, err - } - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("%s authorize bad status: %s (%s)", mode, resp.Status, strings.TrimSpace(string(bodyBytes))) + var bodyBytes []byte + for attempt := 0; ; attempt++ { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, authorizeURL, nil) + if err != nil { + return nil, err + } + req.Host = target.headerHost + applyTunnelHeaders(req.Header, target.headerHost, mode) + applyTunnelAuth(req, auth, mode, http.MethodGet, "/session") + + resp, err := client.Do(req) + if err != nil { + // Transient failure on reused keep-alive conns (multiplex=auto). Retry a few times. + if attempt < 2 && (isDialError(err) || isRetryableRequestError(err)) { + closeIdleConnections(client) + select { + case <-time.After(25 * time.Millisecond): + continue + case <-ctx.Done(): + return nil, err + } + } + return nil, err + } + + bodyBytes, err = io.ReadAll(io.LimitReader(resp.Body, 4*1024)) + _ = resp.Body.Close() + if err != nil { + if attempt < 2 && isRetryableRequestError(err) { + closeIdleConnections(client) + select { + case <-time.After(25 * time.Millisecond): + continue + case <-ctx.Done(): + return nil, err + } + } + return nil, err + } + + if resp.StatusCode != http.StatusOK { + // Retry some transient proxy/CDN errors. + if attempt < 2 && resp.StatusCode >= 500 { + closeIdleConnections(client) + select { + case <-time.After(25 * time.Millisecond): + continue + case <-ctx.Done(): + return nil, fmt.Errorf("%s authorize bad status: %s (%s)", mode, resp.Status, strings.TrimSpace(string(bodyBytes))) + } + } + return nil, fmt.Errorf("%s authorize bad status: %s (%s)", mode, resp.Status, strings.TrimSpace(string(bodyBytes))) + } + break } token, err := parseTunnelToken(bodyBytes) @@ -544,9 +620,8 @@ type streamSplitConn struct { auth *tunnelAuth } -func (c *streamSplitConn) Close() error { - _ = c.closeWithError(io.ErrClosedPipe) - +func (c *streamSplitConn) closeWithError(err error) error { + _ = c.queuedConn.closeWithError(err) if c.cancel != nil { c.cancel() } @@ -554,6 +629,8 @@ func (c *streamSplitConn) Close() error { return nil } +func (c *streamSplitConn) Close() error { return c.closeWithError(io.ErrClosedPipe) } + func newStreamSplitConnFromInfo(info *sessionDialInfo) *streamSplitConn { if info == nil { return nil @@ -659,7 +736,7 @@ func (c *streamSplitConn) pullLoop() { req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, c.pullURL, nil) if err != nil { cancel() - _ = c.Close() + _ = c.closeWithError(fmt.Errorf("stream pull build request failed: %w", err)) return } req.Host = c.headerHost @@ -669,8 +746,9 @@ func (c *streamSplitConn) pullLoop() { resp, err := c.client.Do(req) if err != nil { cancel() - if isDialError(err) && dialRetry < maxDialRetry { + if (isDialError(err) || isRetryableRequestError(err)) && dialRetry < maxDialRetry { dialRetry++ + closeIdleConnections(c.client) select { case <-time.After(backoff): case <-c.closed: @@ -682,16 +760,33 @@ func (c *streamSplitConn) pullLoop() { } continue } - _ = c.Close() + _ = c.closeWithError(fmt.Errorf("stream pull request failed: %w", err)) return } dialRetry = 0 backoff = minBackoff if resp.StatusCode != http.StatusOK { + if isRetryableStatusCode(resp.StatusCode) && dialRetry < maxDialRetry { + dialRetry++ + _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024)) + _ = resp.Body.Close() + cancel() + closeIdleConnections(c.client) + select { + case <-time.After(backoff): + case <-c.closed: + return + } + backoff *= 2 + if backoff > maxBackoff { + backoff = maxBackoff + } + continue + } _ = resp.Body.Close() cancel() - _ = c.Close() + _ = c.closeWithError(fmt.Errorf("stream pull bad status: %s", resp.Status)) return } @@ -717,7 +812,12 @@ func (c *streamSplitConn) pullLoop() { // Long-poll ended; retry. break } - _ = c.Close() + // Some environments may sporadically reset the HTTP connection under load; treat + // it as an ended long-poll and retry instead of tearing down the whole tunnel. + if errors.Is(rerr, io.ErrUnexpectedEOF) || isRetryableRequestError(rerr) { + break + } + _ = c.closeWithError(fmt.Errorf("stream pull read failed: %w", rerr)) return } } @@ -735,8 +835,13 @@ func (c *streamSplitConn) pullLoop() { func (c *streamSplitConn) pushLoop() { const ( - maxBatchBytes = 256 * 1024 - flushInterval = 5 * time.Millisecond + // Batching is critical for stability under high concurrency: every flush is a new TCP + // connection in HTTP/1.1, and too many tiny uploads can overwhelm the accept backlog, + // causing sporadic RSTs (connection reset by peer). + // + // Keep this below the server-side maxUploadBytes limit in streamPush(). + maxBatchBytes = 512 * 1024 + flushInterval = 25 * time.Millisecond requestTimeout = 20 * time.Second maxDialRetry = 12 minBackoff = 10 * time.Millisecond @@ -754,12 +859,18 @@ func (c *streamSplitConn) pushLoop() { return nil } + payload := buf.Bytes() reqCtx, cancel := context.WithTimeout(c.ctx, requestTimeout) - req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.pushURL, bytes.NewReader(buf.Bytes())) + req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.pushURL, bytes.NewReader(payload)) if err != nil { cancel() return err } + // Be explicit: some http client forks won't auto-populate GetBody, which makes POST retries on stale + // keep-alive connections flaky under multiplex=auto. + req.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(payload)), nil + } req.Host = c.headerHost applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream) applyTunnelAuth(req, c.auth, TunnelModeStream, http.MethodPost, "/api/v1/upload") @@ -774,7 +885,7 @@ func (c *streamSplitConn) pushLoop() { _ = resp.Body.Close() cancel() if resp.StatusCode != http.StatusOK { - return fmt.Errorf("bad status: %s", resp.Status) + return &httpStatusError{code: resp.StatusCode, status: resp.Status} } buf.Reset() @@ -787,8 +898,22 @@ func (c *streamSplitConn) pushLoop() { for { if err := flush(); err == nil { return nil - } else if isDialError(err) && dialRetry < maxDialRetry { + } else if se := (*httpStatusError)(nil); errors.As(err, &se) && isRetryableStatusCode(se.code) && dialRetry < maxDialRetry { dialRetry++ + closeIdleConnections(c.client) + select { + case <-time.After(backoff): + case <-c.closed: + return io.ErrClosedPipe + } + backoff *= 2 + if backoff > maxBackoff { + backoff = maxBackoff + } + continue + } else if (isDialError(err) || isRetryableRequestError(err)) && dialRetry < maxDialRetry { + dialRetry++ + closeIdleConnections(c.client) select { case <-time.After(backoff): case <-c.closed: @@ -829,7 +954,7 @@ func (c *streamSplitConn) pushLoop() { } if buf.Len()+len(b) > maxBatchBytes { if err := flushWithRetry(); err != nil { - _ = c.Close() + _ = c.closeWithError(fmt.Errorf("stream push flush failed: %w", err)) return } resetTimer() @@ -837,14 +962,14 @@ func (c *streamSplitConn) pushLoop() { _, _ = buf.Write(b) if buf.Len() >= maxBatchBytes { if err := flushWithRetry(); err != nil { - _ = c.Close() + _ = c.closeWithError(fmt.Errorf("stream push flush failed: %w", err)) return } resetTimer() } case <-timer.C: if err := flushWithRetry(); err != nil { - _ = c.Close() + _ = c.closeWithError(fmt.Errorf("stream push flush failed: %w", err)) return } resetTimer() @@ -858,7 +983,7 @@ func (c *streamSplitConn) pushLoop() { } if buf.Len()+len(b) > maxBatchBytes { if err := flushWithRetry(); err != nil { - _ = c.Close() + _ = c.closeWithError(fmt.Errorf("stream push flush failed: %w", err)) return } } @@ -905,6 +1030,43 @@ func isDialError(err error) bool { return false } +func isRetryableRequestError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return false + } + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + return true + } + // net/http may return this when reusing a keep-alive conn that the peer already closed. + // Treat it as retryable: callers already implement bounded backoff retries. + if strings.Contains(strings.ToLower(err.Error()), "server closed idle connection") { + return true + } + + // Unwrap common wrappers. + var urlErr *url.Error + if errors.As(err, &urlErr) { + return isRetryableRequestError(urlErr.Err) + } + + // Connection-level transient failures. + if errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) { + return true + } + if errors.Is(err, io.ErrClosedPipe) || errors.Is(err, net.ErrClosed) { + return true + } + + var netErr net.Error + if errors.As(err, &netErr) { + return netErr.Timeout() || netErr.Temporary() + } + return false +} + func (c *pollConn) closeWithError(err error) error { _ = c.queuedConn.closeWithError(err) if c.cancel != nil { @@ -1012,8 +1174,10 @@ func (c *pollConn) pullLoop() { default: } - req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, c.pullURL, nil) + reqCtx, cancel := context.WithTimeout(c.ctx, 30*time.Second) + req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, c.pullURL, nil) if err != nil { + cancel() _ = c.Close() return } @@ -1023,8 +1187,10 @@ func (c *pollConn) pullLoop() { resp, err := c.client.Do(req) if err != nil { - if isDialError(err) && dialRetry < maxDialRetry { + cancel() + if (isDialError(err) || isRetryableRequestError(err)) && dialRetry < maxDialRetry { dialRetry++ + closeIdleConnections(c.client) select { case <-time.After(backoff): case <-c.closed: @@ -1043,7 +1209,25 @@ func (c *pollConn) pullLoop() { backoff = minBackoff if resp.StatusCode != http.StatusOK { + if isRetryableStatusCode(resp.StatusCode) && dialRetry < maxDialRetry { + dialRetry++ + _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024)) + _ = resp.Body.Close() + cancel() + closeIdleConnections(c.client) + select { + case <-time.After(backoff): + case <-c.closed: + return + } + backoff *= 2 + if backoff > maxBackoff { + backoff = maxBackoff + } + continue + } _ = resp.Body.Close() + cancel() _ = c.closeWithError(fmt.Errorf("poll pull bad status: %s", resp.Status)) return } @@ -1068,7 +1252,12 @@ func (c *pollConn) pullLoop() { } } _ = resp.Body.Close() + cancel() if err := scanner.Err(); err != nil { + // Treat transient stream breaks (RST/EOF) as an ended long-poll and retry. + if errors.Is(err, io.ErrUnexpectedEOF) || isRetryableRequestError(err) { + continue + } _ = c.closeWithError(fmt.Errorf("poll pull scan failed: %w", err)) return } @@ -1077,8 +1266,8 @@ func (c *pollConn) pullLoop() { func (c *pollConn) pushLoop() { const ( - maxBatchBytes = 64 * 1024 - flushInterval = 5 * time.Millisecond + maxBatchBytes = 512 * 1024 + flushInterval = 50 * time.Millisecond maxLineRawBytes = 16 * 1024 maxDialRetry = 12 minBackoff = 10 * time.Millisecond @@ -1097,12 +1286,16 @@ func (c *pollConn) pushLoop() { return nil } + payload := buf.Bytes() reqCtx, cancel := context.WithTimeout(c.ctx, 20*time.Second) - req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.pushURL, bytes.NewReader(buf.Bytes())) + req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.pushURL, bytes.NewReader(payload)) if err != nil { cancel() return err } + req.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(payload)), nil + } req.Host = c.headerHost applyTunnelHeaders(req.Header, c.headerHost, TunnelModePoll) applyTunnelAuth(req, c.auth, TunnelModePoll, http.MethodPost, "/api/v1/upload") @@ -1117,7 +1310,7 @@ func (c *pollConn) pushLoop() { _ = resp.Body.Close() cancel() if resp.StatusCode != http.StatusOK { - return fmt.Errorf("bad status: %s", resp.Status) + return &httpStatusError{code: resp.StatusCode, status: resp.Status} } buf.Reset() @@ -1131,8 +1324,22 @@ func (c *pollConn) pushLoop() { for { if err := flush(); err == nil { return nil - } else if isDialError(err) && dialRetry < maxDialRetry { + } else if se := (*httpStatusError)(nil); errors.As(err, &se) && isRetryableStatusCode(se.code) && dialRetry < maxDialRetry { dialRetry++ + closeIdleConnections(c.client) + select { + case <-time.After(backoff): + case <-c.closed: + return c.closedErr() + } + backoff *= 2 + if backoff > maxBackoff { + backoff = maxBackoff + } + continue + } else if (isDialError(err) || isRetryableRequestError(err)) && dialRetry < maxDialRetry { + dialRetry++ + closeIdleConnections(c.client) select { case <-time.After(backoff): case <-c.closed: @@ -1482,6 +1689,16 @@ func (s *TunnelServer) HandleConn(rawConn net.Conn) (HandleResult, net.Conn, err return HandleDone, nil, nil } return s.handlePoll(rawConn, req, headerBytes, buffered) + case TunnelModeWS: + if s.mode != TunnelModeWS && s.mode != TunnelModeAuto { + if s.passThroughOnReject { + return reject() + } + _ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found") + _ = rawConn.Close() + return HandleDone, nil, nil + } + return s.handleWS(rawConn, req, headerBytes, buffered) default: if s.passThroughOnReject { return reject() diff --git a/transport/sudoku/obfs/httpmask/tunnel_ws.go b/transport/sudoku/obfs/httpmask/tunnel_ws.go new file mode 100644 index 00000000..e1299e3d --- /dev/null +++ b/transport/sudoku/obfs/httpmask/tunnel_ws.go @@ -0,0 +1,176 @@ +package httpmask + +import ( + "context" + "fmt" + "io" + mrand "math/rand" + "net" + stdhttp "net/http" + "net/url" + "strings" + "time" + + "github.com/gobwas/ws" + "github.com/metacubex/tls" +) + +func normalizeWSSchemeFromAddress(serverAddress string, tlsEnabled bool) (string, string) { + addr := strings.TrimSpace(serverAddress) + if strings.Contains(addr, "://") { + if u, err := url.Parse(addr); err == nil && u != nil { + switch strings.ToLower(strings.TrimSpace(u.Scheme)) { + case "ws": + return "ws", u.Host + case "wss": + return "wss", u.Host + } + } + } + if tlsEnabled { + return "wss", addr + } + return "ws", addr +} + +func normalizeWSDialTarget(serverAddress string, tlsEnabled bool, hostOverride string) (scheme, urlHost, dialAddr, serverName string, err error) { + scheme, addr := normalizeWSSchemeFromAddress(serverAddress, tlsEnabled) + + host, port, err := net.SplitHostPort(addr) + if err != nil { + // Allow ws(s)://host without port. + if strings.Contains(addr, ":") { + return "", "", "", "", fmt.Errorf("invalid server address %q: %w", serverAddress, err) + } + switch scheme { + case "wss": + port = "443" + default: + port = "80" + } + host = addr + } + + if hostOverride != "" { + // Allow "example.com" or "example.com:443" + if h, p, splitErr := net.SplitHostPort(hostOverride); splitErr == nil { + if h != "" { + hostOverride = h + } + if p != "" { + port = p + } + } + serverName = hostOverride + urlHost = net.JoinHostPort(hostOverride, port) + } else { + serverName = host + urlHost = net.JoinHostPort(host, port) + } + + dialAddr = net.JoinHostPort(host, port) + return scheme, urlHost, dialAddr, trimPortForHost(serverName), nil +} + +func applyWSHeaders(h stdhttp.Header, host string) { + if h == nil { + return + } + r := rngPool.Get().(*mrand.Rand) + ua := userAgents[r.Intn(len(userAgents))] + accept := accepts[r.Intn(len(accepts))] + lang := acceptLanguages[r.Intn(len(acceptLanguages))] + enc := acceptEncodings[r.Intn(len(acceptEncodings))] + rngPool.Put(r) + + h.Set("User-Agent", ua) + h.Set("Accept", accept) + h.Set("Accept-Language", lang) + h.Set("Accept-Encoding", enc) + h.Set("Cache-Control", "no-cache") + h.Set("Pragma", "no-cache") + h.Set("X-Sudoku-Tunnel", string(TunnelModeWS)) + h.Set("X-Sudoku-Version", "1") +} + +func dialWS(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) { + if opts.DialContext == nil { + panic("httpmask: DialContext is nil") + } + + scheme, urlHost, dialAddr, serverName, err := normalizeWSDialTarget(serverAddress, opts.TLSEnabled, opts.HostOverride) + if err != nil { + return nil, err + } + + httpScheme := "http" + if scheme == "wss" { + httpScheme = "https" + } + headerHost := canonicalHeaderHost(urlHost, httpScheme) + auth := newTunnelAuth(opts.AuthKey, 0) + + u := &url.URL{ + Scheme: scheme, + Host: urlHost, + Path: joinPathRoot(opts.PathRoot, "/ws"), + } + + header := make(stdhttp.Header) + applyWSHeaders(header, headerHost) + + if auth != nil { + token := auth.token(TunnelModeWS, stdhttp.MethodGet, "/ws", time.Now()) + if token != "" { + header.Set("Authorization", "Bearer "+token) + q := u.Query() + q.Set(tunnelAuthQueryKey, token) + u.RawQuery = q.Encode() + } + } + + d := ws.Dialer{ + Host: headerHost, + Header: ws.HandshakeHeaderHTTP(header), + NetDial: func(dialCtx context.Context, network, addr string) (net.Conn, error) { + if addr == urlHost { + addr = dialAddr + } + return opts.DialContext(dialCtx, network, addr) + }, + } + if scheme == "wss" { + tlsConfig := &tls.Config{ + ServerName: serverName, + MinVersion: tls.VersionTLS12, + } + d.TLSClient = func(conn net.Conn, hostname string) net.Conn { + return tls.Client(conn, tlsConfig) + } + } + + conn, br, _, err := d.Dial(ctx, u.String()) + if err != nil { + return nil, err + } + + if br != nil && br.Buffered() > 0 { + pre := make([]byte, br.Buffered()) + _, _ = io.ReadFull(br, pre) + conn = newPreBufferedConn(conn, pre) + } + + wsConn := newWSStreamConn(conn, ws.StateClientSide) + if opts.Upgrade == nil { + return wsConn, nil + } + upgraded, err := opts.Upgrade(wsConn) + if err != nil { + _ = wsConn.Close() + return nil, err + } + if upgraded != nil { + return upgraded, nil + } + return wsConn, nil +} diff --git a/transport/sudoku/obfs/httpmask/tunnel_ws_server.go b/transport/sudoku/obfs/httpmask/tunnel_ws_server.go new file mode 100644 index 00000000..3e79e58a --- /dev/null +++ b/transport/sudoku/obfs/httpmask/tunnel_ws_server.go @@ -0,0 +1,77 @@ +package httpmask + +import ( + "net" + "net/http" + "net/url" + "strings" + "time" + + "github.com/gobwas/ws" +) + +func looksLikeWebSocketUpgrade(headers map[string]string) bool { + if headers == nil { + return false + } + if !strings.EqualFold(strings.TrimSpace(headers["upgrade"]), "websocket") { + return false + } + conn := headers["connection"] + for _, part := range strings.Split(conn, ",") { + if strings.EqualFold(strings.TrimSpace(part), "upgrade") { + return true + } + } + return false +} + +func (s *TunnelServer) handleWS(rawConn net.Conn, req *httpRequestHeader, headerBytes []byte, buffered []byte) (HandleResult, net.Conn, error) { + rejectOrReply := func(code int, body string) (HandleResult, net.Conn, error) { + if s.passThroughOnReject { + prefix := make([]byte, 0, len(headerBytes)+len(buffered)) + prefix = append(prefix, headerBytes...) + prefix = append(prefix, buffered...) + return HandlePassThrough, newRejectedPreBufferedConn(rawConn, prefix), nil + } + _ = writeSimpleHTTPResponse(rawConn, code, body) + _ = rawConn.Close() + return HandleDone, nil, nil + } + + u, err := url.ParseRequestURI(req.target) + if err != nil { + return rejectOrReply(http.StatusBadRequest, "bad request") + } + + path, ok := stripPathRoot(s.pathRoot, u.Path) + if !ok || path != "/ws" { + return rejectOrReply(http.StatusNotFound, "not found") + } + if strings.ToUpper(strings.TrimSpace(req.method)) != http.MethodGet { + return rejectOrReply(http.StatusBadRequest, "bad request") + } + if !looksLikeWebSocketUpgrade(req.headers) { + return rejectOrReply(http.StatusBadRequest, "bad request") + } + + authVal := req.headers["authorization"] + if authVal == "" { + authVal = u.Query().Get(tunnelAuthQueryKey) + } + if !s.auth.verifyValue(authVal, TunnelModeWS, req.method, path, time.Now()) { + return rejectOrReply(http.StatusNotFound, "not found") + } + + prefix := make([]byte, 0, len(headerBytes)+len(buffered)) + prefix = append(prefix, headerBytes...) + prefix = append(prefix, buffered...) + wsConnRaw := newPreBufferedConn(rawConn, prefix) + + if _, err := ws.Upgrade(wsConnRaw); err != nil { + _ = rawConn.Close() + return HandleDone, nil, nil + } + + return HandleStartTunnel, newWSStreamConn(wsConnRaw, ws.StateServerSide), nil +} diff --git a/transport/sudoku/obfs/httpmask/ws_stream_conn.go b/transport/sudoku/obfs/httpmask/ws_stream_conn.go new file mode 100644 index 00000000..46fc3804 --- /dev/null +++ b/transport/sudoku/obfs/httpmask/ws_stream_conn.go @@ -0,0 +1,78 @@ +package httpmask + +import ( + "errors" + "fmt" + "io" + "net" + + "github.com/gobwas/ws" + "github.com/gobwas/ws/wsutil" +) + +type wsStreamConn struct { + net.Conn + state ws.State + reader *wsutil.Reader + controlHandler wsutil.FrameHandlerFunc +} + +func newWSStreamConn(conn net.Conn, state ws.State) net.Conn { + controlHandler := wsutil.ControlFrameHandler(conn, state) + return &wsStreamConn{ + Conn: conn, + state: state, + reader: &wsutil.Reader{ + Source: conn, + State: state, + }, + controlHandler: controlHandler, + } +} + +func (c *wsStreamConn) Read(b []byte) (n int, err error) { + defer func() { + if v := recover(); v != nil { + err = fmt.Errorf("websocket error: %v", v) + } + }() + + for { + n, err = c.reader.Read(b) + if errors.Is(err, io.EOF) { + err = nil + } + if !errors.Is(err, wsutil.ErrNoFrameAdvance) { + return n, err + } + + hdr, err2 := c.reader.NextFrame() + if err2 != nil { + return 0, err2 + } + if hdr.OpCode.IsControl() { + if err := c.controlHandler(hdr, c.reader); err != nil { + return 0, err + } + continue + } + if hdr.OpCode&(ws.OpBinary|ws.OpText) == 0 { + if err := c.reader.Discard(); err != nil { + return 0, err + } + continue + } + } +} + +func (c *wsStreamConn) Write(b []byte) (int, error) { + if err := wsutil.WriteMessage(c.Conn, c.state, ws.OpBinary, b); err != nil { + return 0, err + } + return len(b), nil +} + +func (c *wsStreamConn) Close() error { + _ = wsutil.WriteMessage(c.Conn, c.state, ws.OpClose, ws.NewCloseFrameBody(ws.StatusNormalClosure, "")) + return c.Conn.Close() +} diff --git a/transport/sudoku/obfs/sudoku/table.go b/transport/sudoku/obfs/sudoku/table.go index d86e642f..c845db91 100644 --- a/transport/sudoku/obfs/sudoku/table.go +++ b/transport/sudoku/obfs/sudoku/table.go @@ -4,9 +4,7 @@ import ( "crypto/sha256" "encoding/binary" "errors" - "log" "math/rand" - "time" ) var ( @@ -26,7 +24,7 @@ type Table struct { func NewTable(key string, mode string) *Table { t, err := NewTableWithCustom(key, mode, "") if err != nil { - log.Panicf("failed to build table: %v", err) + panic(err) } return t } @@ -35,8 +33,6 @@ func NewTable(key string, mode string) *Table { // mode: "prefer_ascii" or "prefer_entropy". If a custom pattern is provided, ASCII mode still takes precedence. // The customPattern must contain 8 characters with exactly 2 x, 2 p, and 4 v (case-insensitive). func NewTableWithCustom(key string, mode string, customPattern string) (*Table, error) { - start := time.Now() - layout, err := resolveLayout(mode, customPattern) if err != nil { return nil, err @@ -126,7 +122,6 @@ func NewTableWithCustom(key string, mode string, customPattern string) (*Table, } } } - log.Printf("[Init] Sudoku Tables initialized (%s) in %v", layout.name, time.Since(start)) return t, nil } diff --git a/transport/sudoku/replay.go b/transport/sudoku/replay.go new file mode 100644 index 00000000..b2695082 --- /dev/null +++ b/transport/sudoku/replay.go @@ -0,0 +1,74 @@ +package sudoku + +import ( + "sync" + "time" +) + +var handshakeReplayTTL = 60 * time.Second + +type nonceSet struct { + mu sync.Mutex + m map[[kipHelloNonceSize]byte]time.Time + maxEntries int + lastPrune time.Time +} + +func newNonceSet(maxEntries int) *nonceSet { + if maxEntries <= 0 { + maxEntries = 4096 + } + return &nonceSet{ + m: make(map[[kipHelloNonceSize]byte]time.Time), + maxEntries: maxEntries, + } +} + +func (s *nonceSet) allow(nonce [kipHelloNonceSize]byte, now time.Time, ttl time.Duration) bool { + s.mu.Lock() + defer s.mu.Unlock() + + if ttl <= 0 { + ttl = 60 * time.Second + } + + if now.Sub(s.lastPrune) > ttl/2 || len(s.m) > s.maxEntries { + for k, exp := range s.m { + if !now.Before(exp) { + delete(s.m, k) + } + } + s.lastPrune = now + for len(s.m) > s.maxEntries { + for k := range s.m { + delete(s.m, k) + break + } + } + } + + if exp, ok := s.m[nonce]; ok && now.Before(exp) { + return false + } + s.m[nonce] = now.Add(ttl) + return true +} + +type handshakeReplayProtector struct { + users sync.Map // map[userHash string]*nonceSet +} + +func (p *handshakeReplayProtector) allow(userHash string, nonce [kipHelloNonceSize]byte, now time.Time) bool { + if userHash == "" { + userHash = "_" + } + val, _ := p.users.LoadOrStore(userHash, newNonceSet(4096)) + set, ok := val.(*nonceSet) + if !ok || set == nil { + set = newNonceSet(4096) + p.users.Store(userHash, set) + } + return set.allow(nonce, now, handshakeReplayTTL) +} + +var globalHandshakeReplay = &handshakeReplayProtector{} diff --git a/transport/sudoku/session_keys.go b/transport/sudoku/session_keys.go new file mode 100644 index 00000000..48971f97 --- /dev/null +++ b/transport/sudoku/session_keys.go @@ -0,0 +1,58 @@ +package sudoku + +import ( + "crypto/ecdh" + "crypto/sha256" + "fmt" + "io" + + "golang.org/x/crypto/hkdf" +) + +func derivePSKDirectionalBases(psk string) (c2s, s2c []byte) { + sum := sha256.Sum256([]byte(psk)) + c2sKey := make([]byte, 32) + s2cKey := make([]byte, 32) + if _, err := io.ReadFull(hkdf.Expand(sha256.New, sum[:], []byte("sudoku-psk-c2s")), c2sKey); err != nil { + panic("sudoku: hkdf expand failed") + } + if _, err := io.ReadFull(hkdf.Expand(sha256.New, sum[:], []byte("sudoku-psk-s2c")), s2cKey); err != nil { + panic("sudoku: hkdf expand failed") + } + return c2sKey, s2cKey +} + +func deriveSessionDirectionalBases(psk string, shared []byte, nonce [kipHelloNonceSize]byte) (c2s, s2c []byte, err error) { + sum := sha256.Sum256([]byte(psk)) + ikm := make([]byte, 0, len(shared)+len(nonce)) + ikm = append(ikm, shared...) + ikm = append(ikm, nonce[:]...) + + prk := hkdf.Extract(sha256.New, ikm, sum[:]) + + c2sKey := make([]byte, 32) + s2cKey := make([]byte, 32) + if _, err := io.ReadFull(hkdf.Expand(sha256.New, prk, []byte("sudoku-session-c2s")), c2sKey); err != nil { + return nil, nil, fmt.Errorf("hkdf expand c2s: %w", err) + } + if _, err := io.ReadFull(hkdf.Expand(sha256.New, prk, []byte("sudoku-session-s2c")), s2cKey); err != nil { + return nil, nil, fmt.Errorf("hkdf expand s2c: %w", err) + } + return c2sKey, s2cKey, nil +} + +func x25519SharedSecret(priv *ecdh.PrivateKey, peerPub []byte) ([]byte, error) { + if priv == nil { + return nil, fmt.Errorf("nil priv") + } + curve := ecdh.X25519() + pk, err := curve.NewPublicKey(peerPub) + if err != nil { + return nil, fmt.Errorf("parse peer pub: %w", err) + } + secret, err := priv.ECDH(pk) + if err != nil { + return nil, fmt.Errorf("ecdh: %w", err) + } + return secret, nil +} diff --git a/transport/sudoku/table_probe.go b/transport/sudoku/table_probe.go index c885756e..09799d06 100644 --- a/transport/sudoku/table_probe.go +++ b/transport/sudoku/table_probe.go @@ -4,7 +4,6 @@ import ( "bufio" "bytes" crand "crypto/rand" - "encoding/binary" "errors" "fmt" "io" @@ -56,26 +55,27 @@ func drainBuffered(r *bufio.Reader) ([]byte, error) { func probeHandshakeBytes(probe []byte, cfg *ProtocolConfig, table *sudoku.Table) error { rc := &readOnlyConn{Reader: bytes.NewReader(probe)} _, obfsConn := buildServerObfsConn(rc, cfg, table, false) - cConn, err := crypto.NewAEADConn(obfsConn, cfg.Key, cfg.AEADMethod) + seed := ServerAEADSeed(cfg.Key) + pskC2S, pskS2C := derivePSKDirectionalBases(seed) + // Server side: recv is client->server, send is server->client. + cConn, err := crypto.NewRecordConn(obfsConn, cfg.AEADMethod, pskS2C, pskC2S) if err != nil { return err } - var handshakeBuf [16]byte - if _, err := io.ReadFull(cConn, handshakeBuf[:]); err != nil { + msg, err := ReadKIPMessage(cConn) + if err != nil { return err } - ts := int64(binary.BigEndian.Uint64(handshakeBuf[:8])) - if absInt64(time.Now().Unix()-ts) > 60 { - return fmt.Errorf("timestamp skew/replay detected") + if msg.Type != KIPTypeClientHello { + return fmt.Errorf("unexpected handshake message: %d", msg.Type) } - - modeBuf := []byte{0} - if _, err := io.ReadFull(cConn, modeBuf); err != nil { + ch, err := DecodeKIPClientHelloPayload(msg.Payload) + if err != nil { return err } - if modeBuf[0] != downlinkMode(cfg) { - return fmt.Errorf("downlink mode mismatch") + if absInt64(time.Now().Unix()-ch.Timestamp.Unix()) > int64(kipHandshakeSkew.Seconds()) { + return fmt.Errorf("time skew/replay") } return nil @@ -93,6 +93,17 @@ func selectTableByProbe(r *bufio.Reader, cfg *ProtocolConfig, tables []*sudoku.T return nil, nil, fmt.Errorf("too many table candidates: %d", len(tables)) } + // Copy so we can prune candidates without mutating the caller slice. + candidates := make([]*sudoku.Table, 0, len(tables)) + for _, t := range tables { + if t != nil { + candidates = append(candidates, t) + } + } + if len(candidates) == 0 { + return nil, nil, fmt.Errorf("no table candidates") + } + probe, err := drainBuffered(r) if err != nil { return nil, nil, fmt.Errorf("drain buffered bytes failed: %w", err) @@ -100,17 +111,18 @@ func selectTableByProbe(r *bufio.Reader, cfg *ProtocolConfig, tables []*sudoku.T tmp := make([]byte, readChunk) for { - if len(tables) == 1 { + if len(candidates) == 1 { tail, err := drainBuffered(r) if err != nil { return nil, nil, fmt.Errorf("drain buffered bytes failed: %w", err) } probe = append(probe, tail...) - return tables[0], probe, nil + return candidates[0], probe, nil } needMore := false - for _, table := range tables { + next := candidates[:0] + for _, table := range candidates { err := probeHandshakeBytes(probe, cfg, table) if err == nil { tail, err := drainBuffered(r) @@ -122,10 +134,13 @@ func selectTableByProbe(r *bufio.Reader, cfg *ProtocolConfig, tables []*sudoku.T } if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { needMore = true + next = append(next, table) } + // Definitive mismatch: drop table. } + candidates = next - if !needMore { + if len(candidates) == 0 || !needMore { return nil, probe, fmt.Errorf("handshake table selection failed") } if len(probe) >= maxProbeBytes { diff --git a/transport/sudoku/tables.go b/transport/sudoku/tables.go index 2630ea52..b6a25bb3 100644 --- a/transport/sudoku/tables.go +++ b/transport/sudoku/tables.go @@ -6,9 +6,7 @@ import ( "github.com/metacubex/mihomo/transport/sudoku/obfs/sudoku" ) -// NewTablesWithCustomPatterns builds one or more obfuscation tables from x/v/p custom patterns. -// When customTables is non-empty it overrides customTable (matching upstream Sudoku behavior). -func NewTablesWithCustomPatterns(key string, tableType string, customTable string, customTables []string) ([]*sudoku.Table, error) { +func normalizeCustomPatterns(customTable string, customTables []string) []string { patterns := customTables if len(patterns) == 0 && strings.TrimSpace(customTable) != "" { patterns = []string{customTable} @@ -16,7 +14,15 @@ func NewTablesWithCustomPatterns(key string, tableType string, customTable strin if len(patterns) == 0 { patterns = []string{""} } + return patterns +} +// NewTablesWithCustomPatterns builds one or more obfuscation tables from x/v/p custom patterns. +// When customTables is non-empty it overrides customTable (matching upstream Sudoku behavior). +// +// Deprecated-ish: prefer NewClientTablesWithCustomPatterns / NewServerTablesWithCustomPatterns. +func NewTablesWithCustomPatterns(key string, tableType string, customTable string, customTables []string) ([]*sudoku.Table, error) { + patterns := normalizeCustomPatterns(customTable, customTables) tables := make([]*sudoku.Table, 0, len(patterns)) for _, pattern := range patterns { pattern = strings.TrimSpace(pattern) @@ -28,3 +34,17 @@ func NewTablesWithCustomPatterns(key string, tableType string, customTable strin } return tables, nil } + +func NewClientTablesWithCustomPatterns(key string, tableType string, customTable string, customTables []string) ([]*sudoku.Table, error) { + return NewTablesWithCustomPatterns(key, tableType, customTable, customTables) +} + +// NewServerTablesWithCustomPatterns matches upstream server behavior: when custom table rotation is enabled, +// also accept the default table to avoid forcing clients to update in lockstep. +func NewServerTablesWithCustomPatterns(key string, tableType string, customTable string, customTables []string) ([]*sudoku.Table, error) { + patterns := normalizeCustomPatterns(customTable, customTables) + if len(patterns) > 0 && strings.TrimSpace(patterns[0]) != "" { + patterns = append([]string{""}, patterns...) + } + return NewTablesWithCustomPatterns(key, tableType, "", patterns) +} diff --git a/transport/sudoku/uot.go b/transport/sudoku/uot.go index be3fe900..054108a6 100644 --- a/transport/sudoku/uot.go +++ b/transport/sudoku/uot.go @@ -16,17 +16,9 @@ import ( ) const ( - UoTMagicByte byte = 0xEE - uotVersion = 0x01 - maxUoTPayload = 64 * 1024 + maxUoTPayload = 64 * 1024 ) -// WritePreface writes the UDP-over-TCP marker and version. -func WritePreface(w io.Writer) error { - _, err := w.Write([]byte{UoTMagicByte, uotVersion}) - return err -} - // WriteDatagram sends a single UDP datagram frame over a reliable stream. func WriteDatagram(w io.Writer, addr string, payload []byte) error { addrBuf, err := EncodeAddress(addr) @@ -45,14 +37,13 @@ func WriteDatagram(w io.Writer, addr string, payload []byte) error { binary.BigEndian.PutUint16(header[:2], uint16(len(addrBuf))) binary.BigEndian.PutUint16(header[2:], uint16(len(payload))) - if _, err := w.Write(header[:]); err != nil { + if err := writeFull(w, header[:]); err != nil { return err } - if _, err := w.Write(addrBuf); err != nil { + if err := writeFull(w, addrBuf); err != nil { return err } - _, err = w.Write(payload) - return err + return writeFull(w, payload) } // ReadDatagram parses a single UDP datagram frame from the reliable stream.