mirror of
https://github.com/MetaCubeX/mihomo.git
synced 2026-03-01 10:09:54 +00:00
feat: sudoku support ws transport (#2589)
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
||||
N "github.com/metacubex/mihomo/common/net"
|
||||
C "github.com/metacubex/mihomo/constant"
|
||||
"github.com/metacubex/mihomo/transport/sudoku"
|
||||
"github.com/metacubex/mihomo/transport/sudoku/obfs/httpmask"
|
||||
)
|
||||
|
||||
type Sudoku struct {
|
||||
@@ -18,32 +19,43 @@ type Sudoku struct {
|
||||
option *SudokuOption
|
||||
baseConf sudoku.ProtocolConfig
|
||||
|
||||
httpMaskMu sync.Mutex
|
||||
httpMaskClient *sudoku.HTTPMaskTunnelClient
|
||||
|
||||
muxMu sync.Mutex
|
||||
muxClient *sudoku.MultiplexClient
|
||||
|
||||
httpMaskMu sync.Mutex
|
||||
httpMaskClient *httpmask.TunnelClient
|
||||
httpMaskKey string
|
||||
}
|
||||
|
||||
type SudokuOption struct {
|
||||
BasicOption
|
||||
Name string `proxy:"name"`
|
||||
Server string `proxy:"server"`
|
||||
Port int `proxy:"port"`
|
||||
Key string `proxy:"key"`
|
||||
AEADMethod string `proxy:"aead-method,omitempty"`
|
||||
PaddingMin *int `proxy:"padding-min,omitempty"`
|
||||
PaddingMax *int `proxy:"padding-max,omitempty"`
|
||||
TableType string `proxy:"table-type,omitempty"` // "prefer_ascii" or "prefer_entropy"
|
||||
EnablePureDownlink *bool `proxy:"enable-pure-downlink,omitempty"`
|
||||
HTTPMask bool `proxy:"http-mask,omitempty"`
|
||||
HTTPMaskMode string `proxy:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto"
|
||||
HTTPMaskTLS bool `proxy:"http-mask-tls,omitempty"` // only for http-mask-mode stream/poll/auto
|
||||
HTTPMaskHost string `proxy:"http-mask-host,omitempty"` // optional Host/SNI override (domain or domain:port)
|
||||
PathRoot string `proxy:"path-root,omitempty"` // optional first-level path prefix for HTTP tunnel endpoints
|
||||
HTTPMaskMultiplex string `proxy:"http-mask-multiplex,omitempty"` // "off" (default), "auto" (reuse h1/h2), "on" (single tunnel, multi-target)
|
||||
CustomTable string `proxy:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv
|
||||
CustomTables []string `proxy:"custom-tables,omitempty"` // optional table rotation patterns, overrides custom-table when non-empty
|
||||
Name string `proxy:"name"`
|
||||
Server string `proxy:"server"`
|
||||
Port int `proxy:"port"`
|
||||
Key string `proxy:"key"`
|
||||
AEADMethod string `proxy:"aead-method,omitempty"`
|
||||
PaddingMin *int `proxy:"padding-min,omitempty"`
|
||||
PaddingMax *int `proxy:"padding-max,omitempty"`
|
||||
TableType string `proxy:"table-type,omitempty"` // "prefer_ascii" or "prefer_entropy"
|
||||
EnablePureDownlink *bool `proxy:"enable-pure-downlink,omitempty"`
|
||||
HTTPMask *bool `proxy:"http-mask,omitempty"`
|
||||
HTTPMaskMode string `proxy:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto", "ws"
|
||||
HTTPMaskTLS bool `proxy:"http-mask-tls,omitempty"` // only for http-mask-mode stream/poll/auto
|
||||
HTTPMaskHost string `proxy:"http-mask-host,omitempty"` // optional Host/SNI override (domain or domain:port)
|
||||
PathRoot string `proxy:"path-root,omitempty"` // optional first-level path prefix for HTTP tunnel endpoints
|
||||
HTTPMaskMultiplex string `proxy:"http-mask-multiplex,omitempty"` // "off" (default), "auto" (reuse h1/h2), "on" (single tunnel, multi-target)
|
||||
HTTPMaskOptions *SudokuHTTPMaskOptions `proxy:"httpmask,omitempty"`
|
||||
CustomTable string `proxy:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv
|
||||
CustomTables []string `proxy:"custom-tables,omitempty"` // optional table rotation patterns, overrides custom-table when non-empty
|
||||
}
|
||||
|
||||
type SudokuHTTPMaskOptions struct {
|
||||
Disable bool `proxy:"disable,omitempty"`
|
||||
Mode string `proxy:"mode,omitempty"`
|
||||
TLS bool `proxy:"tls,omitempty"`
|
||||
Host string `proxy:"host,omitempty"`
|
||||
PathRoot string `proxy:"path_root,omitempty"`
|
||||
Multiplex string `proxy:"multiplex,omitempty"`
|
||||
}
|
||||
|
||||
// DialContext implements C.ProxyAdapter
|
||||
@@ -73,7 +85,7 @@ func (s *Sudoku) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Con
|
||||
return nil, fmt.Errorf("encode target address failed: %w", err)
|
||||
}
|
||||
|
||||
if _, err = c.Write(addrBuf); err != nil {
|
||||
if err = sudoku.WriteKIPMessage(c, sudoku.KIPTypeOpenTCP, addrBuf); err != nil {
|
||||
return nil, fmt.Errorf("send target address failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -96,9 +108,9 @@ func (s *Sudoku) ListenPacketContext(ctx context.Context, metadata *C.Metadata)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = sudoku.WritePreface(c); err != nil {
|
||||
if err = sudoku.WriteKIPMessage(c, sudoku.KIPTypeStartUoT, nil); err != nil {
|
||||
_ = c.Close()
|
||||
return nil, fmt.Errorf("send uot preface failed: %w", err)
|
||||
return nil, fmt.Errorf("start uot failed: %w", err)
|
||||
}
|
||||
|
||||
return newPacketConn(N.NewThreadSafePacketConn(sudoku.NewUoTPacketConn(c)), s), nil
|
||||
@@ -141,32 +153,45 @@ func NewSudoku(option SudokuOption) (*Sudoku, error) {
|
||||
return nil, fmt.Errorf("key is required")
|
||||
}
|
||||
|
||||
tableType := strings.ToLower(option.TableType)
|
||||
if tableType == "" {
|
||||
tableType = "prefer_ascii"
|
||||
defaultConf := sudoku.DefaultConfig()
|
||||
tableType, err := sudoku.NormalizeTableType(option.TableType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tableType != "prefer_ascii" && tableType != "prefer_entropy" {
|
||||
return nil, fmt.Errorf("table-type must be prefer_ascii or prefer_entropy")
|
||||
paddingMin, paddingMax := sudoku.ResolvePadding(option.PaddingMin, option.PaddingMax, defaultConf.PaddingMin, defaultConf.PaddingMax)
|
||||
enablePureDownlink := sudoku.DerefBool(option.EnablePureDownlink, defaultConf.EnablePureDownlink)
|
||||
|
||||
disableHTTPMask := defaultConf.DisableHTTPMask
|
||||
if option.HTTPMask != nil {
|
||||
disableHTTPMask = !*option.HTTPMask
|
||||
}
|
||||
httpMaskMode := defaultConf.HTTPMaskMode
|
||||
if option.HTTPMaskMode != "" {
|
||||
httpMaskMode = option.HTTPMaskMode
|
||||
}
|
||||
httpMaskTLS := option.HTTPMaskTLS
|
||||
httpMaskHost := option.HTTPMaskHost
|
||||
pathRoot := strings.TrimSpace(option.PathRoot)
|
||||
httpMaskMultiplex := defaultConf.HTTPMaskMultiplex
|
||||
if option.HTTPMaskMultiplex != "" {
|
||||
httpMaskMultiplex = option.HTTPMaskMultiplex
|
||||
}
|
||||
|
||||
defaultConf := sudoku.DefaultConfig()
|
||||
paddingMin := defaultConf.PaddingMin
|
||||
paddingMax := defaultConf.PaddingMax
|
||||
if option.PaddingMin != nil {
|
||||
paddingMin = *option.PaddingMin
|
||||
}
|
||||
if option.PaddingMax != nil {
|
||||
paddingMax = *option.PaddingMax
|
||||
}
|
||||
if option.PaddingMin == nil && option.PaddingMax != nil && paddingMax < paddingMin {
|
||||
paddingMin = paddingMax
|
||||
}
|
||||
if option.PaddingMax == nil && option.PaddingMin != nil && paddingMax < paddingMin {
|
||||
paddingMax = paddingMin
|
||||
}
|
||||
enablePureDownlink := defaultConf.EnablePureDownlink
|
||||
if option.EnablePureDownlink != nil {
|
||||
enablePureDownlink = *option.EnablePureDownlink
|
||||
if hm := option.HTTPMaskOptions; hm != nil {
|
||||
disableHTTPMask = hm.Disable
|
||||
if hm.Mode != "" {
|
||||
httpMaskMode = hm.Mode
|
||||
}
|
||||
httpMaskTLS = hm.TLS
|
||||
httpMaskHost = hm.Host
|
||||
if pr := strings.TrimSpace(hm.PathRoot); pr != "" {
|
||||
pathRoot = pr
|
||||
}
|
||||
if mux := strings.TrimSpace(hm.Multiplex); mux != "" {
|
||||
httpMaskMultiplex = mux
|
||||
} else {
|
||||
httpMaskMultiplex = defaultConf.HTTPMaskMultiplex
|
||||
}
|
||||
}
|
||||
|
||||
baseConf := sudoku.ProtocolConfig{
|
||||
@@ -177,20 +202,14 @@ func NewSudoku(option SudokuOption) (*Sudoku, error) {
|
||||
PaddingMax: paddingMax,
|
||||
EnablePureDownlink: enablePureDownlink,
|
||||
HandshakeTimeoutSeconds: defaultConf.HandshakeTimeoutSeconds,
|
||||
DisableHTTPMask: !option.HTTPMask,
|
||||
HTTPMaskMode: defaultConf.HTTPMaskMode,
|
||||
HTTPMaskTLSEnabled: option.HTTPMaskTLS,
|
||||
HTTPMaskHost: option.HTTPMaskHost,
|
||||
HTTPMaskPathRoot: strings.TrimSpace(option.PathRoot),
|
||||
HTTPMaskMultiplex: defaultConf.HTTPMaskMultiplex,
|
||||
DisableHTTPMask: disableHTTPMask,
|
||||
HTTPMaskMode: httpMaskMode,
|
||||
HTTPMaskTLSEnabled: httpMaskTLS,
|
||||
HTTPMaskHost: httpMaskHost,
|
||||
HTTPMaskPathRoot: pathRoot,
|
||||
HTTPMaskMultiplex: httpMaskMultiplex,
|
||||
}
|
||||
if option.HTTPMaskMode != "" {
|
||||
baseConf.HTTPMaskMode = option.HTTPMaskMode
|
||||
}
|
||||
if option.HTTPMaskMultiplex != "" {
|
||||
baseConf.HTTPMaskMultiplex = option.HTTPMaskMultiplex
|
||||
}
|
||||
tables, err := sudoku.NewTablesWithCustomPatterns(sudoku.ClientAEADSeed(option.Key), tableType, option.CustomTable, option.CustomTables)
|
||||
tables, err := sudoku.NewClientTablesWithCustomPatterns(sudoku.ClientAEADSeed(option.Key), tableType, option.CustomTable, option.CustomTables)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build table(s) failed: %w", err)
|
||||
}
|
||||
@@ -244,7 +263,7 @@ func normalizeHTTPMaskMultiplex(mode string) string {
|
||||
|
||||
func httpTunnelModeEnabled(mode string) bool {
|
||||
switch strings.ToLower(strings.TrimSpace(mode)) {
|
||||
case "stream", "poll", "auto":
|
||||
case "stream", "poll", "auto", "ws":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
@@ -271,14 +290,24 @@ func (s *Sudoku) dialAndHandshake(ctx context.Context, cfg *sudoku.ProtocolConfi
|
||||
)
|
||||
if !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) {
|
||||
muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex)
|
||||
switch muxMode {
|
||||
case "auto", "on":
|
||||
client, errX := s.getOrCreateHTTPMaskClient(cfg)
|
||||
if errX != nil {
|
||||
return nil, errX
|
||||
if muxMode == "auto" && strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) != "ws" {
|
||||
if client, cerr := s.getOrCreateHTTPMaskClient(cfg); cerr == nil && client != nil {
|
||||
c, err = client.DialTunnel(ctx, httpmask.TunnelDialOptions{
|
||||
Mode: cfg.HTTPMaskMode,
|
||||
TLSEnabled: cfg.HTTPMaskTLSEnabled,
|
||||
HostOverride: cfg.HTTPMaskHost,
|
||||
PathRoot: cfg.HTTPMaskPathRoot,
|
||||
AuthKey: sudoku.ClientAEADSeed(cfg.Key),
|
||||
Upgrade: upgrade,
|
||||
Multiplex: cfg.HTTPMaskMultiplex,
|
||||
DialContext: s.dialer.DialContext,
|
||||
})
|
||||
if err != nil {
|
||||
s.resetHTTPMaskClient()
|
||||
}
|
||||
}
|
||||
c, err = client.Dial(ctx, upgrade)
|
||||
default:
|
||||
}
|
||||
if c == nil && err == nil {
|
||||
c, err = sudoku.DialHTTPMaskTunnel(ctx, cfg.ServerAddress, cfg, s.dialer.DialContext, upgrade)
|
||||
}
|
||||
if err == nil && c != nil {
|
||||
@@ -372,34 +401,51 @@ func (s *Sudoku) resetMuxClient() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Sudoku) getOrCreateHTTPMaskClient(cfg *sudoku.ProtocolConfig) (*sudoku.HTTPMaskTunnelClient, error) {
|
||||
if s == nil {
|
||||
return nil, fmt.Errorf("nil adapter")
|
||||
}
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config is required")
|
||||
}
|
||||
|
||||
s.httpMaskMu.Lock()
|
||||
defer s.httpMaskMu.Unlock()
|
||||
|
||||
if s.httpMaskClient != nil {
|
||||
return s.httpMaskClient, nil
|
||||
}
|
||||
|
||||
c, err := sudoku.NewHTTPMaskTunnelClient(cfg.ServerAddress, cfg, s.dialer.DialContext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.httpMaskClient = c
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (s *Sudoku) resetHTTPMaskClient() {
|
||||
s.httpMaskMu.Lock()
|
||||
defer s.httpMaskMu.Unlock()
|
||||
if s.httpMaskClient != nil {
|
||||
s.httpMaskClient.CloseIdleConnections()
|
||||
s.httpMaskClient = nil
|
||||
s.httpMaskKey = ""
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Sudoku) getOrCreateHTTPMaskClient(cfg *sudoku.ProtocolConfig) (*httpmask.TunnelClient, error) {
|
||||
if s == nil || cfg == nil {
|
||||
return nil, fmt.Errorf("nil adapter or config")
|
||||
}
|
||||
|
||||
key := cfg.ServerAddress + "|" + strconv.FormatBool(cfg.HTTPMaskTLSEnabled) + "|" + strings.TrimSpace(cfg.HTTPMaskHost)
|
||||
|
||||
s.httpMaskMu.Lock()
|
||||
if s.httpMaskClient != nil && s.httpMaskKey == key {
|
||||
client := s.httpMaskClient
|
||||
s.httpMaskMu.Unlock()
|
||||
return client, nil
|
||||
}
|
||||
s.httpMaskMu.Unlock()
|
||||
|
||||
client, err := httpmask.NewTunnelClient(cfg.ServerAddress, httpmask.TunnelClientOptions{
|
||||
TLSEnabled: cfg.HTTPMaskTLSEnabled,
|
||||
HostOverride: cfg.HTTPMaskHost,
|
||||
DialContext: s.dialer.DialContext,
|
||||
MaxIdleConns: 32,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.httpMaskMu.Lock()
|
||||
defer s.httpMaskMu.Unlock()
|
||||
if s.httpMaskClient != nil && s.httpMaskKey == key {
|
||||
client.CloseIdleConnections()
|
||||
return s.httpMaskClient, nil
|
||||
}
|
||||
if s.httpMaskClient != nil {
|
||||
s.httpMaskClient.CloseIdleConnections()
|
||||
}
|
||||
s.httpMaskClient = client
|
||||
s.httpMaskKey = key
|
||||
return client, nil
|
||||
}
|
||||
|
||||
@@ -1092,12 +1092,22 @@ proxies: # socks5
|
||||
table-type: prefer_ascii # 可选值:prefer_ascii、prefer_entropy 前者全ascii映射,后者保证熵值(汉明1)低于3
|
||||
# custom-table: xpxvvpvv # 可选,自定义字节布局,必须包含2个x、2个p、4个v,可随意组合。启用此处则需配置`table-type`为`prefer_entropy`
|
||||
# custom-tables: ["xpxvvpvv", "vxpvxvvp"] # 可选,自定义字节布局列表(x/v/p),用于 xvp 模式轮换;非空时覆盖 custom-table
|
||||
http-mask: true # 是否启用http掩码
|
||||
# http-mask-mode: legacy # 可选:legacy(默认)、stream(split-stream)、poll、auto(先 stream 再 poll);stream/poll/auto 支持走 CDN/反代
|
||||
# http-mask-tls: true # 可选:仅在 http-mask-mode 为 stream/poll/auto 时生效;true 强制 https;false 强制 http(不会根据端口自动推断)
|
||||
# http-mask-host: "" # 可选:覆盖 Host/SNI(支持 example.com 或 example.com:443);仅在 http-mask-mode 为 stream/poll/auto 时生效
|
||||
# path-root: "" # 可选:HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" 或 "/aabbcc/" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload
|
||||
# http-mask-multiplex: off # 可选:off(默认)、auto(复用底层 HTTP 连接,减少建链 RTT)、on(Sudoku mux 单隧道多目标;仅在 http-mask-mode=stream/poll/auto 生效)
|
||||
# 推荐:使用 httpmask 对象统一管理 HTTPMask 相关字段:
|
||||
httpmask:
|
||||
disable: false # true 禁用所有 HTTP 伪装/隧道
|
||||
mode: legacy # 可选:legacy(默认)、stream(split-stream)、poll、auto(先 stream 再 poll)、ws(WebSocket 隧道)
|
||||
# tls: true # 可选:仅在 mode 为 stream/poll/auto/ws 时生效;true 强制 https/wss;false 强制 http/ws(不会根据端口自动推断)
|
||||
# host: "" # 可选:覆盖 Host/SNI(支持 example.com 或 example.com:443);仅在 mode 为 stream/poll/auto/ws 时生效
|
||||
# path_root: "" # 可选:HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" 或 "/aabbcc/" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload、/aabbcc/ws
|
||||
# multiplex: off # 可选:off(默认)、auto(复用底层 HTTP 连接,减少建链 RTT)、on(Sudoku mux 单隧道多目标;仅在 mode=stream/poll/auto 生效;ws 强制 off)
|
||||
#
|
||||
# 向后兼容旧写法:
|
||||
# http-mask: true # 是否启用 http 掩码
|
||||
# http-mask-mode: legacy # 可选:legacy(默认)、stream(split-stream)、poll、auto(先 stream 再 poll)、ws(WebSocket 隧道)
|
||||
# http-mask-tls: true # 可选:仅在 http-mask-mode 为 stream/poll/auto/ws 时生效;true 强制 https/wss;false 强制 http/ws
|
||||
# http-mask-host: "" # 可选:覆盖 Host/SNI(支持 example.com 或 example.com:443);仅在 http-mask-mode 为 stream/poll/auto/ws 时生效
|
||||
# path-root: "" # 可选:HTTP 隧道端点一级路径前缀(双方需一致)
|
||||
# http-mask-multiplex: off # 可选:off(默认)、auto(复用底层 HTTP 连接)、on(Sudoku mux 单隧道多目标;ws 强制 off)
|
||||
enable-pure-downlink: false # 可选:false=带宽优化下行(更快,要求 aead-method != none);true=纯 Sudoku 下行
|
||||
|
||||
# anytls
|
||||
@@ -1663,9 +1673,19 @@ listeners:
|
||||
# custom-tables: ["xpxvvpvv", "vxpvxvvp"] # 可选,自定义字节布局列表(x/v/p),用于 xvp 模式轮换;非空时覆盖 custom-table
|
||||
handshake-timeout: 5 # 可选(秒)
|
||||
enable-pure-downlink: false # 可选:false=带宽优化下行(更快,要求 aead-method != none);true=纯 Sudoku 下行
|
||||
disable-http-mask: false # 可选:禁用 http 掩码/隧道(默认为 false)
|
||||
# http-mask-mode: legacy # 可选:legacy(默认)、stream(split-stream)、poll、auto(先 stream 再 poll);stream/poll/auto 支持走 CDN/反代
|
||||
# path-root: "" # 可选:HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" 或 "/aabbcc/" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload
|
||||
# 推荐:使用 httpmask 对象统一管理 HTTPMask 相关字段:
|
||||
httpmask:
|
||||
disable: false # true 禁用所有 HTTP 伪装/隧道
|
||||
mode: legacy # 可选:legacy(默认)、stream(split-stream)、poll、auto(先 stream 再 poll)、ws(WebSocket 隧道)
|
||||
# path_root: "" # 可选:HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" 或 "/aabbcc/" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload、/aabbcc/ws
|
||||
#
|
||||
# 可选:当启用 HTTPMask 且识别到“像 HTTP 但不符合 tunnel/auth”的请求时,将原始字节透传给 fallback(常用于与其他服务共端口):
|
||||
# fallback: "127.0.0.1:80"
|
||||
#
|
||||
# 向后兼容旧写法:
|
||||
# disable-http-mask: false # 可选:禁用 http 掩码/隧道(默认为 false)
|
||||
# http-mask-mode: legacy # 可选:legacy(默认)、stream(split-stream)、poll、auto(先 stream 再 poll)、ws(WebSocket 隧道)
|
||||
# path-root: "" # 可选:HTTP 隧道端点一级路径前缀(双方需一致)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ type SudokuServer struct {
|
||||
DisableHTTPMask bool `json:"disable-http-mask,omitempty"`
|
||||
HTTPMaskMode string `json:"http-mask-mode,omitempty"`
|
||||
PathRoot string `json:"path-root,omitempty"`
|
||||
Fallback string `json:"fallback,omitempty"`
|
||||
|
||||
// mihomo private extension (not the part of standard Sudoku protocol)
|
||||
MuxOption sing.MuxOption `json:"mux-option,omitempty"`
|
||||
|
||||
@@ -13,23 +13,31 @@ import (
|
||||
|
||||
type SudokuOption struct {
|
||||
BaseOption
|
||||
Key string `inbound:"key"`
|
||||
AEADMethod string `inbound:"aead-method,omitempty"`
|
||||
PaddingMin *int `inbound:"padding-min,omitempty"`
|
||||
PaddingMax *int `inbound:"padding-max,omitempty"`
|
||||
TableType string `inbound:"table-type,omitempty"` // "prefer_ascii" or "prefer_entropy"
|
||||
HandshakeTimeoutSecond *int `inbound:"handshake-timeout,omitempty"`
|
||||
EnablePureDownlink *bool `inbound:"enable-pure-downlink,omitempty"`
|
||||
CustomTable string `inbound:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv
|
||||
CustomTables []string `inbound:"custom-tables,omitempty"`
|
||||
DisableHTTPMask bool `inbound:"disable-http-mask,omitempty"`
|
||||
HTTPMaskMode string `inbound:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto"
|
||||
PathRoot string `inbound:"path-root,omitempty"` // optional first-level path prefix for HTTP tunnel endpoints
|
||||
Key string `inbound:"key"`
|
||||
AEADMethod string `inbound:"aead-method,omitempty"`
|
||||
PaddingMin *int `inbound:"padding-min,omitempty"`
|
||||
PaddingMax *int `inbound:"padding-max,omitempty"`
|
||||
TableType string `inbound:"table-type,omitempty"` // "prefer_ascii" or "prefer_entropy"
|
||||
HandshakeTimeoutSecond *int `inbound:"handshake-timeout,omitempty"`
|
||||
EnablePureDownlink *bool `inbound:"enable-pure-downlink,omitempty"`
|
||||
CustomTable string `inbound:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv
|
||||
CustomTables []string `inbound:"custom-tables,omitempty"`
|
||||
DisableHTTPMask bool `inbound:"disable-http-mask,omitempty"`
|
||||
HTTPMaskMode string `inbound:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto"
|
||||
PathRoot string `inbound:"path-root,omitempty"` // optional first-level path prefix for HTTP tunnel endpoints
|
||||
Fallback string `inbound:"fallback,omitempty"`
|
||||
HTTPMaskOptions *SudokuHTTPMaskOptions `inbound:"httpmask,omitempty"`
|
||||
|
||||
// mihomo private extension (not the part of standard Sudoku protocol)
|
||||
MuxOption MuxOption `inbound:"mux-option,omitempty"`
|
||||
}
|
||||
|
||||
type SudokuHTTPMaskOptions struct {
|
||||
Disable bool `inbound:"disable,omitempty"`
|
||||
Mode string `inbound:"mode,omitempty"`
|
||||
PathRoot string `inbound:"path_root,omitempty"`
|
||||
}
|
||||
|
||||
func (o SudokuOption) Equal(config C.InboundConfig) bool {
|
||||
return optionToString(o) == optionToString(config)
|
||||
}
|
||||
@@ -65,6 +73,16 @@ func NewSudoku(options *SudokuOption) (*Sudoku, error) {
|
||||
DisableHTTPMask: options.DisableHTTPMask,
|
||||
HTTPMaskMode: options.HTTPMaskMode,
|
||||
PathRoot: strings.TrimSpace(options.PathRoot),
|
||||
Fallback: strings.TrimSpace(options.Fallback),
|
||||
}
|
||||
if hm := options.HTTPMaskOptions; hm != nil {
|
||||
serverConf.DisableHTTPMask = hm.Disable
|
||||
if hm.Mode != "" {
|
||||
serverConf.HTTPMaskMode = hm.Mode
|
||||
}
|
||||
if pr := strings.TrimSpace(hm.PathRoot); pr != "" {
|
||||
serverConf.PathRoot = pr
|
||||
}
|
||||
}
|
||||
serverConf.MuxOption = options.MuxOption.Build()
|
||||
|
||||
|
||||
@@ -168,16 +168,17 @@ func TestInboundSudoku_CustomTable(t *testing.T) {
|
||||
func TestInboundSudoku_HTTPMaskMode(t *testing.T) {
|
||||
key := "test_key_http_mask_mode"
|
||||
|
||||
for _, mode := range []string{"legacy", "stream", "poll", "auto"} {
|
||||
for _, mode := range []string{"ws", "stream", "poll", "auto"} {
|
||||
mode := mode
|
||||
t.Run(mode, func(t *testing.T) {
|
||||
inboundOptions := inbound.SudokuOption{
|
||||
Key: key,
|
||||
HTTPMaskMode: mode,
|
||||
}
|
||||
httpMask := true
|
||||
outboundOptions := outbound.SudokuOption{
|
||||
Key: key,
|
||||
HTTPMask: true,
|
||||
HTTPMask: &httpMask,
|
||||
HTTPMaskMode: mode,
|
||||
}
|
||||
testInboundSudoku(t, inboundOptions, outboundOptions)
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
package sudoku
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/metacubex/mihomo/adapter/inbound"
|
||||
N "github.com/metacubex/mihomo/common/net"
|
||||
"github.com/metacubex/mihomo/common/utils"
|
||||
C "github.com/metacubex/mihomo/constant"
|
||||
LC "github.com/metacubex/mihomo/listener/config"
|
||||
"github.com/metacubex/mihomo/listener/inner"
|
||||
"github.com/metacubex/mihomo/listener/sing"
|
||||
"github.com/metacubex/mihomo/log"
|
||||
"github.com/metacubex/mihomo/transport/socks5"
|
||||
@@ -23,6 +26,7 @@ type Listener struct {
|
||||
closed bool
|
||||
protoConf sudoku.ProtocolConfig
|
||||
tunnelSrv *sudoku.HTTPMaskTunnelServer
|
||||
fallback string
|
||||
handler *sing.ListenerHandler
|
||||
}
|
||||
|
||||
@@ -49,12 +53,19 @@ func (l *Listener) Close() error {
|
||||
}
|
||||
|
||||
func (l *Listener) handleConn(conn net.Conn, tunnel C.Tunnel, additions ...inbound.Addition) {
|
||||
log.Debugln("[Sudoku] accepted %s", conn.RemoteAddr())
|
||||
handshakeConn := conn
|
||||
handshakeCfg := &l.protoConf
|
||||
closeConns := func() {
|
||||
_ = handshakeConn.Close()
|
||||
if handshakeConn != conn {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
if l.tunnelSrv != nil {
|
||||
c, cfg, done, err := l.tunnelSrv.WrapConn(conn)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
closeConns()
|
||||
return
|
||||
}
|
||||
if done {
|
||||
@@ -68,9 +79,43 @@ func (l *Listener) handleConn(conn net.Conn, tunnel C.Tunnel, additions ...inbou
|
||||
}
|
||||
}
|
||||
|
||||
session, err := sudoku.ServerHandshake(handshakeConn, handshakeCfg)
|
||||
if l.fallback != "" {
|
||||
if r, ok := handshakeConn.(interface{ IsHTTPMaskRejected() bool }); ok && r.IsHTTPMaskRejected() {
|
||||
fb, err := inner.HandleTcp(tunnel, l.fallback, "")
|
||||
if err != nil {
|
||||
closeConns()
|
||||
return
|
||||
}
|
||||
N.Relay(handshakeConn, fb)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
cConn, meta, err := sudoku.ServerHandshake(handshakeConn, handshakeCfg)
|
||||
if err != nil {
|
||||
_ = handshakeConn.Close()
|
||||
fallbackAddr := l.fallback
|
||||
var susp *sudoku.SuspiciousError
|
||||
isSuspicious := errors.As(err, &susp) && susp != nil && susp.Conn != nil
|
||||
if isSuspicious {
|
||||
log.Warnln("[Sudoku] suspicious handshake from %s: %v", conn.RemoteAddr(), err)
|
||||
if fallbackAddr != "" {
|
||||
fb, err := inner.HandleTcp(tunnel, fallbackAddr, "")
|
||||
if err == nil {
|
||||
relayToFallback(susp.Conn, conn, fb)
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Debugln("[Sudoku] handshake failed from %s: %v", conn.RemoteAddr(), err)
|
||||
}
|
||||
closeConns()
|
||||
return
|
||||
}
|
||||
|
||||
session, err := sudoku.ReadServerSession(cConn, meta)
|
||||
if err != nil {
|
||||
log.Warnln("[Sudoku] read session failed from %s: %v", conn.RemoteAddr(), err)
|
||||
_ = cConn.Close()
|
||||
if handshakeConn != conn {
|
||||
_ = conn.Close()
|
||||
}
|
||||
@@ -103,6 +148,7 @@ func (l *Listener) handleConn(conn net.Conn, tunnel C.Tunnel, additions ...inbou
|
||||
default:
|
||||
targetAddr := socks5.ParseAddr(session.Target)
|
||||
if targetAddr == nil {
|
||||
log.Warnln("[Sudoku] invalid target from %s: %q", conn.RemoteAddr(), session.Target)
|
||||
_ = session.Conn.Close()
|
||||
return
|
||||
}
|
||||
@@ -164,6 +210,24 @@ func (p *uotPacket) LocalAddr() net.Addr {
|
||||
return p.rAddr
|
||||
}
|
||||
|
||||
func relayToFallback(wrapper net.Conn, rawConn net.Conn, fallback net.Conn) {
|
||||
if wrapper != nil {
|
||||
if recorder, ok := wrapper.(interface{ GetBufferedAndRecorded() []byte }); ok {
|
||||
badData := recorder.GetBufferedAndRecorded()
|
||||
if len(badData) > 0 {
|
||||
_ = fallback.SetWriteDeadline(time.Now().Add(3 * time.Second))
|
||||
if _, err := io.Copy(fallback, bytes.NewReader(badData)); err != nil {
|
||||
_ = fallback.Close()
|
||||
_ = rawConn.Close()
|
||||
return
|
||||
}
|
||||
_ = fallback.SetWriteDeadline(time.Time{})
|
||||
}
|
||||
}
|
||||
}
|
||||
N.Relay(rawConn, fallback)
|
||||
}
|
||||
|
||||
func New(config LC.SudokuServer, tunnel C.Tunnel, additions ...inbound.Addition) (*Listener, error) {
|
||||
if len(additions) == 0 {
|
||||
additions = []inbound.Addition{
|
||||
@@ -188,42 +252,24 @@ func New(config LC.SudokuServer, tunnel C.Tunnel, additions ...inbound.Addition)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tableType := strings.ToLower(config.TableType)
|
||||
if tableType == "" {
|
||||
tableType = "prefer_ascii"
|
||||
}
|
||||
|
||||
defaultConf := sudoku.DefaultConfig()
|
||||
paddingMin := defaultConf.PaddingMin
|
||||
paddingMax := defaultConf.PaddingMax
|
||||
if config.PaddingMin != nil {
|
||||
paddingMin = *config.PaddingMin
|
||||
}
|
||||
if config.PaddingMax != nil {
|
||||
paddingMax = *config.PaddingMax
|
||||
}
|
||||
if config.PaddingMin == nil && config.PaddingMax != nil && paddingMax < paddingMin {
|
||||
paddingMin = paddingMax
|
||||
}
|
||||
if config.PaddingMax == nil && config.PaddingMin != nil && paddingMax < paddingMin {
|
||||
paddingMax = paddingMin
|
||||
}
|
||||
enablePureDownlink := defaultConf.EnablePureDownlink
|
||||
if config.EnablePureDownlink != nil {
|
||||
enablePureDownlink = *config.EnablePureDownlink
|
||||
}
|
||||
|
||||
tables, err := sudoku.NewTablesWithCustomPatterns(config.Key, tableType, config.CustomTable, config.CustomTables)
|
||||
tableType, err := sudoku.NormalizeTableType(config.TableType)
|
||||
if err != nil {
|
||||
_ = l.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
handshakeTimeout := defaultConf.HandshakeTimeoutSeconds
|
||||
if config.HandshakeTimeoutSecond != nil {
|
||||
handshakeTimeout = *config.HandshakeTimeoutSecond
|
||||
defaultConf := sudoku.DefaultConfig()
|
||||
paddingMin, paddingMax := sudoku.ResolvePadding(config.PaddingMin, config.PaddingMax, defaultConf.PaddingMin, defaultConf.PaddingMax)
|
||||
enablePureDownlink := sudoku.DerefBool(config.EnablePureDownlink, defaultConf.EnablePureDownlink)
|
||||
|
||||
tables, err := sudoku.NewServerTablesWithCustomPatterns(sudoku.ServerAEADSeed(config.Key), tableType, config.CustomTable, config.CustomTables)
|
||||
if err != nil {
|
||||
_ = l.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
handshakeTimeout := sudoku.DerefInt(config.HandshakeTimeoutSecond, defaultConf.HandshakeTimeoutSeconds)
|
||||
|
||||
protoConf := sudoku.ProtocolConfig{
|
||||
Key: config.Key,
|
||||
AEADMethod: defaultConf.AEADMethod,
|
||||
@@ -249,8 +295,13 @@ func New(config LC.SudokuServer, tunnel C.Tunnel, additions ...inbound.Addition)
|
||||
addr: config.Listen,
|
||||
protoConf: protoConf,
|
||||
handler: h,
|
||||
fallback: strings.TrimSpace(config.Fallback),
|
||||
}
|
||||
if sl.fallback != "" {
|
||||
sl.tunnelSrv = sudoku.NewHTTPMaskTunnelServerWithFallback(&sl.protoConf)
|
||||
} else {
|
||||
sl.tunnelSrv = sudoku.NewHTTPMaskTunnelServer(&sl.protoConf)
|
||||
}
|
||||
sl.tunnelSrv = sudoku.NewHTTPMaskTunnelServer(&sl.protoConf)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func EncodeAddress(rawAddr string) ([]byte, error) {
|
||||
@@ -20,13 +21,21 @@ func EncodeAddress(rawAddr string) ([]byte, error) {
|
||||
}
|
||||
|
||||
var buf []byte
|
||||
if i := strings.IndexByte(host, '%'); i >= 0 {
|
||||
// Zone identifiers are not representable in SOCKS5 IPv6 address encoding.
|
||||
host = host[:i]
|
||||
}
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
buf = append(buf, 0x01) // IPv4
|
||||
buf = append(buf, ip4...)
|
||||
} else {
|
||||
buf = append(buf, 0x04) // IPv6
|
||||
buf = append(buf, ip...)
|
||||
ip16 := ip.To16()
|
||||
if ip16 == nil {
|
||||
return nil, fmt.Errorf("invalid ipv6: %q", host)
|
||||
}
|
||||
buf = append(buf, ip16...)
|
||||
}
|
||||
} else {
|
||||
if len(host) > 255 {
|
||||
|
||||
@@ -50,6 +50,7 @@ type ProtocolConfig struct {
|
||||
// - "stream": real HTTP tunnel (split-stream), CDN-compatible
|
||||
// - "poll": plain HTTP tunnel (authorize/push/pull), strong restricted-network pass-through
|
||||
// - "auto": try stream then fall back to poll
|
||||
// - "ws": WebSocket tunnel (GET upgrade), CDN-friendly
|
||||
HTTPMaskMode string
|
||||
|
||||
// HTTPMaskTLSEnabled enables HTTPS for HTTP tunnel modes (client-side).
|
||||
@@ -109,9 +110,9 @@ func (c *ProtocolConfig) Validate() error {
|
||||
}
|
||||
|
||||
switch strings.ToLower(strings.TrimSpace(c.HTTPMaskMode)) {
|
||||
case "", "legacy", "stream", "poll", "auto":
|
||||
case "", "legacy", "stream", "poll", "auto", "ws":
|
||||
default:
|
||||
return fmt.Errorf("invalid http-mask-mode: %s, must be one of: legacy, stream, poll, auto", c.HTTPMaskMode)
|
||||
return fmt.Errorf("invalid http-mask-mode: %s, must be one of: legacy, stream, poll, auto, ws", c.HTTPMaskMode)
|
||||
}
|
||||
|
||||
if v := strings.TrimSpace(c.HTTPMaskPathRoot); v != "" {
|
||||
@@ -166,6 +167,44 @@ func DefaultConfig() *ProtocolConfig {
|
||||
}
|
||||
}
|
||||
|
||||
func DerefInt(v *int, def int) int {
|
||||
if v == nil {
|
||||
return def
|
||||
}
|
||||
return *v
|
||||
}
|
||||
|
||||
func DerefBool(v *bool, def bool) bool {
|
||||
if v == nil {
|
||||
return def
|
||||
}
|
||||
return *v
|
||||
}
|
||||
|
||||
// ResolvePadding applies defaults and keeps min/max consistent when only one side is provided.
|
||||
func ResolvePadding(min, max *int, defMin, defMax int) (int, int) {
|
||||
paddingMin := DerefInt(min, defMin)
|
||||
paddingMax := DerefInt(max, defMax)
|
||||
switch {
|
||||
case min == nil && max != nil && paddingMax < paddingMin:
|
||||
paddingMin = paddingMax
|
||||
case max == nil && min != nil && paddingMax < paddingMin:
|
||||
paddingMax = paddingMin
|
||||
}
|
||||
return paddingMin, paddingMax
|
||||
}
|
||||
|
||||
func NormalizeTableType(tableType string) (string, error) {
|
||||
switch t := strings.ToLower(strings.TrimSpace(tableType)); t {
|
||||
case "", "prefer_ascii":
|
||||
return "prefer_ascii", nil
|
||||
case "prefer_entropy":
|
||||
return "prefer_entropy", nil
|
||||
default:
|
||||
return "", fmt.Errorf("table-type must be prefer_ascii or prefer_entropy")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ProtocolConfig) tableCandidates() []*sudoku.Table {
|
||||
if c == nil {
|
||||
return nil
|
||||
|
||||
@@ -80,8 +80,8 @@ func RecoverPublicKey(keyHex string) (*edwards25519.Point, error) {
|
||||
return nil, fmt.Errorf("invalid scalar: %w", err)
|
||||
}
|
||||
return new(edwards25519.Point).ScalarBaseMult(x), nil
|
||||
|
||||
} else if len(keyBytes) == 64 {
|
||||
}
|
||||
if len(keyBytes) == 64 {
|
||||
// Split Key r || k
|
||||
rBytes := keyBytes[:32]
|
||||
kBytes := keyBytes[32:]
|
||||
|
||||
374
transport/sudoku/crypto/record_conn.go
Normal file
374
transport/sudoku/crypto/record_conn.go
Normal file
@@ -0,0 +1,374 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
|
||||
// KeyUpdateAfterBytes controls automatic key rotation based on plaintext bytes.
|
||||
// It is a package var (not config) to enable targeted tests with smaller thresholds.
|
||||
var KeyUpdateAfterBytes int64 = 32 << 20 // 32 MiB
|
||||
|
||||
const (
|
||||
recordHeaderSize = 12 // epoch(uint32) + seq(uint64) - also used as nonce+AAD.
|
||||
maxFrameBodySize = 65535
|
||||
)
|
||||
|
||||
type recordKeys struct {
|
||||
baseSend []byte
|
||||
baseRecv []byte
|
||||
}
|
||||
|
||||
// RecordConn is a framed AEAD net.Conn with:
|
||||
// - deterministic per-record nonce (epoch+seq)
|
||||
// - per-direction key rotation (epoch), driven by plaintext byte counters
|
||||
// - replay/out-of-order protection within the connection (strict seq check)
|
||||
//
|
||||
// Wire format per record:
|
||||
// - uint16 bodyLen
|
||||
// - header[12] = epoch(uint32 BE) || seq(uint64 BE) (plaintext)
|
||||
// - ciphertext = AEAD(header as nonce, plaintext, header as AAD)
|
||||
type RecordConn struct {
|
||||
net.Conn
|
||||
method string
|
||||
|
||||
writeMu sync.Mutex
|
||||
readMu sync.Mutex
|
||||
|
||||
keys recordKeys
|
||||
|
||||
sendAEAD cipher.AEAD
|
||||
sendAEADEpoch uint32
|
||||
|
||||
recvAEAD cipher.AEAD
|
||||
recvAEADEpoch uint32
|
||||
|
||||
// Send direction state.
|
||||
sendEpoch uint32
|
||||
sendSeq uint64
|
||||
sendBytes int64
|
||||
|
||||
// Receive direction state.
|
||||
recvEpoch uint32
|
||||
recvSeq uint64
|
||||
|
||||
readBuf bytes.Buffer
|
||||
|
||||
// writeFrame is a reusable buffer for [len||header||ciphertext] on the wire.
|
||||
// Guarded by writeMu.
|
||||
writeFrame []byte
|
||||
}
|
||||
|
||||
func (c *RecordConn) CloseWrite() error {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
if cw, ok := c.Conn.(interface{ CloseWrite() error }); ok {
|
||||
return cw.CloseWrite()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *RecordConn) CloseRead() error {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
if cr, ok := c.Conn.(interface{ CloseRead() error }); ok {
|
||||
return cr.CloseRead()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewRecordConn(conn net.Conn, method string, baseSend, baseRecv []byte) (*RecordConn, error) {
|
||||
if conn == nil {
|
||||
return nil, fmt.Errorf("nil conn")
|
||||
}
|
||||
method = normalizeAEADMethod(method)
|
||||
if method != "none" {
|
||||
if err := validateBaseKey(baseSend); err != nil {
|
||||
return nil, fmt.Errorf("invalid send base key: %w", err)
|
||||
}
|
||||
if err := validateBaseKey(baseRecv); err != nil {
|
||||
return nil, fmt.Errorf("invalid recv base key: %w", err)
|
||||
}
|
||||
}
|
||||
rc := &RecordConn{Conn: conn, method: method}
|
||||
rc.keys = recordKeys{baseSend: cloneBytes(baseSend), baseRecv: cloneBytes(baseRecv)}
|
||||
return rc, nil
|
||||
}
|
||||
|
||||
func (c *RecordConn) Rekey(baseSend, baseRecv []byte) error {
|
||||
if c == nil {
|
||||
return fmt.Errorf("nil conn")
|
||||
}
|
||||
if c.method != "none" {
|
||||
if err := validateBaseKey(baseSend); err != nil {
|
||||
return fmt.Errorf("invalid send base key: %w", err)
|
||||
}
|
||||
if err := validateBaseKey(baseRecv); err != nil {
|
||||
return fmt.Errorf("invalid recv base key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
c.writeMu.Lock()
|
||||
c.readMu.Lock()
|
||||
defer c.readMu.Unlock()
|
||||
defer c.writeMu.Unlock()
|
||||
|
||||
c.keys = recordKeys{baseSend: cloneBytes(baseSend), baseRecv: cloneBytes(baseRecv)}
|
||||
c.sendEpoch = 0
|
||||
c.sendSeq = 0
|
||||
c.sendBytes = 0
|
||||
c.recvEpoch = 0
|
||||
c.recvSeq = 0
|
||||
c.readBuf.Reset()
|
||||
|
||||
c.sendAEAD = nil
|
||||
c.recvAEAD = nil
|
||||
c.sendAEADEpoch = 0
|
||||
c.recvAEADEpoch = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeAEADMethod(method string) string {
|
||||
switch method {
|
||||
case "", "chacha20-poly1305":
|
||||
return "chacha20-poly1305"
|
||||
case "aes-128-gcm", "none":
|
||||
return method
|
||||
default:
|
||||
return method
|
||||
}
|
||||
}
|
||||
|
||||
func validateBaseKey(b []byte) error {
|
||||
if len(b) < 32 {
|
||||
return fmt.Errorf("need at least 32 bytes, got %d", len(b))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func cloneBytes(b []byte) []byte {
|
||||
if len(b) == 0 {
|
||||
return nil
|
||||
}
|
||||
return append([]byte(nil), b...)
|
||||
}
|
||||
|
||||
func (c *RecordConn) newAEADFor(base []byte, epoch uint32) (cipher.AEAD, error) {
|
||||
if c.method == "none" {
|
||||
return nil, nil
|
||||
}
|
||||
key := deriveEpochKey(base, epoch, c.method)
|
||||
switch c.method {
|
||||
case "aes-128-gcm":
|
||||
block, err := aes.NewCipher(key[:16])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
a, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if a.NonceSize() != recordHeaderSize {
|
||||
return nil, fmt.Errorf("unexpected gcm nonce size: %d", a.NonceSize())
|
||||
}
|
||||
return a, nil
|
||||
case "chacha20-poly1305":
|
||||
a, err := chacha20poly1305.New(key[:32])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if a.NonceSize() != recordHeaderSize {
|
||||
return nil, fmt.Errorf("unexpected chacha nonce size: %d", a.NonceSize())
|
||||
}
|
||||
return a, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported cipher: %s", c.method)
|
||||
}
|
||||
}
|
||||
|
||||
func deriveEpochKey(base []byte, epoch uint32, method string) []byte {
|
||||
var b [4]byte
|
||||
binary.BigEndian.PutUint32(b[:], epoch)
|
||||
mac := hmac.New(sha256.New, base)
|
||||
_, _ = mac.Write([]byte("sudoku-record:"))
|
||||
_, _ = mac.Write([]byte(method))
|
||||
_, _ = mac.Write(b[:])
|
||||
return mac.Sum(nil)
|
||||
}
|
||||
|
||||
func (c *RecordConn) maybeBumpSendEpochLocked(addedPlain int) {
|
||||
if KeyUpdateAfterBytes <= 0 || c.method == "none" {
|
||||
return
|
||||
}
|
||||
c.sendBytes += int64(addedPlain)
|
||||
threshold := KeyUpdateAfterBytes * int64(c.sendEpoch+1)
|
||||
if c.sendBytes < threshold {
|
||||
return
|
||||
}
|
||||
c.sendEpoch++
|
||||
c.sendSeq = 0
|
||||
}
|
||||
|
||||
func (c *RecordConn) Write(p []byte) (int, error) {
|
||||
if c == nil || c.Conn == nil {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
if c.method == "none" {
|
||||
return c.Conn.Write(p)
|
||||
}
|
||||
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
|
||||
total := 0
|
||||
for len(p) > 0 {
|
||||
if c.sendAEAD == nil || c.sendAEADEpoch != c.sendEpoch {
|
||||
a, err := c.newAEADFor(c.keys.baseSend, c.sendEpoch)
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
c.sendAEAD = a
|
||||
c.sendAEADEpoch = c.sendEpoch
|
||||
}
|
||||
aead := c.sendAEAD
|
||||
|
||||
maxPlain := maxFrameBodySize - recordHeaderSize - aead.Overhead()
|
||||
if maxPlain <= 0 {
|
||||
return total, errors.New("frame size too small")
|
||||
}
|
||||
n := len(p)
|
||||
if n > maxPlain {
|
||||
n = maxPlain
|
||||
}
|
||||
chunk := p[:n]
|
||||
p = p[n:]
|
||||
|
||||
var header [recordHeaderSize]byte
|
||||
binary.BigEndian.PutUint32(header[:4], c.sendEpoch)
|
||||
binary.BigEndian.PutUint64(header[4:], c.sendSeq)
|
||||
c.sendSeq++
|
||||
|
||||
cipherLen := n + aead.Overhead()
|
||||
bodyLen := recordHeaderSize + cipherLen
|
||||
frameLen := 2 + bodyLen
|
||||
if bodyLen > maxFrameBodySize {
|
||||
return total, errors.New("frame too large")
|
||||
}
|
||||
if cap(c.writeFrame) < frameLen {
|
||||
c.writeFrame = make([]byte, frameLen)
|
||||
}
|
||||
frame := c.writeFrame[:frameLen]
|
||||
binary.BigEndian.PutUint16(frame[:2], uint16(bodyLen))
|
||||
copy(frame[2:2+recordHeaderSize], header[:])
|
||||
|
||||
dst := frame[2+recordHeaderSize : 2+recordHeaderSize : frameLen]
|
||||
_ = aead.Seal(dst[:0], header[:], chunk, header[:])
|
||||
|
||||
if err := writeFull(c.Conn, frame); err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
total += n
|
||||
c.maybeBumpSendEpochLocked(n)
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (c *RecordConn) Read(p []byte) (int, error) {
|
||||
if c == nil || c.Conn == nil {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
if c.method == "none" {
|
||||
return c.Conn.Read(p)
|
||||
}
|
||||
|
||||
c.readMu.Lock()
|
||||
defer c.readMu.Unlock()
|
||||
|
||||
if c.readBuf.Len() > 0 {
|
||||
return c.readBuf.Read(p)
|
||||
}
|
||||
|
||||
var lenBuf [2]byte
|
||||
if _, err := io.ReadFull(c.Conn, lenBuf[:]); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
bodyLen := int(binary.BigEndian.Uint16(lenBuf[:]))
|
||||
if bodyLen < recordHeaderSize {
|
||||
return 0, errors.New("frame too short")
|
||||
}
|
||||
if bodyLen > maxFrameBodySize {
|
||||
return 0, errors.New("frame too large")
|
||||
}
|
||||
|
||||
body := make([]byte, bodyLen)
|
||||
if _, err := io.ReadFull(c.Conn, body); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
header := body[:recordHeaderSize]
|
||||
ciphertext := body[recordHeaderSize:]
|
||||
|
||||
epoch := binary.BigEndian.Uint32(header[:4])
|
||||
seq := binary.BigEndian.Uint64(header[4:])
|
||||
|
||||
if epoch < c.recvEpoch {
|
||||
return 0, fmt.Errorf("replayed epoch: got %d want >=%d", epoch, c.recvEpoch)
|
||||
}
|
||||
if epoch == c.recvEpoch && seq != c.recvSeq {
|
||||
return 0, fmt.Errorf("out of order: epoch=%d got=%d want=%d", epoch, seq, c.recvSeq)
|
||||
}
|
||||
if epoch > c.recvEpoch {
|
||||
const maxJump = 8
|
||||
if epoch-c.recvEpoch > maxJump {
|
||||
return 0, fmt.Errorf("epoch jump too large: got=%d want<=%d", epoch-c.recvEpoch, maxJump)
|
||||
}
|
||||
c.recvEpoch = epoch
|
||||
c.recvSeq = 0
|
||||
if seq != 0 {
|
||||
return 0, fmt.Errorf("out of order: epoch advanced to %d but seq=%d", epoch, seq)
|
||||
}
|
||||
}
|
||||
|
||||
if c.recvAEAD == nil || c.recvAEADEpoch != c.recvEpoch {
|
||||
a, err := c.newAEADFor(c.keys.baseRecv, c.recvEpoch)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
c.recvAEAD = a
|
||||
c.recvAEADEpoch = c.recvEpoch
|
||||
}
|
||||
aead := c.recvAEAD
|
||||
|
||||
plaintext, err := aead.Open(nil, header, ciphertext, header)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("decryption failed: epoch=%d seq=%d: %w", epoch, seq, err)
|
||||
}
|
||||
c.recvSeq++
|
||||
|
||||
c.readBuf.Write(plaintext)
|
||||
return c.readBuf.Read(p)
|
||||
}
|
||||
|
||||
func writeFull(w io.Writer, b []byte) error {
|
||||
for len(b) > 0 {
|
||||
n, err := w.Write(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b = b[n:]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -43,12 +43,17 @@ func TestCustomTablesRotation_ProbedByServer(t *testing.T) {
|
||||
go func() {
|
||||
defer close(errCh)
|
||||
defer serverConn.Close()
|
||||
session, err := ServerHandshake(serverConn, serverCfg)
|
||||
c, meta, err := ServerHandshake(serverConn, serverCfg)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
defer session.Conn.Close()
|
||||
session, err := ReadServerSession(c, meta)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
if session.Type != SessionTypeTCP {
|
||||
errCh <- io.ErrUnexpectedEOF
|
||||
return
|
||||
@@ -69,7 +74,7 @@ func TestCustomTablesRotation_ProbedByServer(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("encode addr: %v", err)
|
||||
}
|
||||
if _, err := cConn.Write(addrBuf); err != nil {
|
||||
if err := WriteKIPMessage(cConn, KIPTypeOpenTCP, addrBuf); err != nil {
|
||||
t.Fatalf("write addr: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -2,19 +2,20 @@ package sudoku
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"bytes"
|
||||
"crypto/ecdh"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/metacubex/mihomo/transport/sudoku/crypto"
|
||||
"github.com/metacubex/mihomo/transport/sudoku/obfs/httpmask"
|
||||
"github.com/metacubex/mihomo/transport/sudoku/obfs/sudoku"
|
||||
|
||||
"github.com/metacubex/mihomo/log"
|
||||
)
|
||||
|
||||
type SessionType int
|
||||
@@ -30,18 +31,96 @@ type ServerSession struct {
|
||||
Type SessionType
|
||||
Target string
|
||||
|
||||
// UserHash is a stable per-key identifier derived from the handshake payload.
|
||||
// It is primarily useful for debugging / user attribution when table rotation is enabled.
|
||||
// UserHash is a stable per-key identifier derived from the client hello payload.
|
||||
UserHash string
|
||||
}
|
||||
|
||||
type bufferedConn struct {
|
||||
net.Conn
|
||||
r *bufio.Reader
|
||||
type HandshakeMeta struct {
|
||||
UserHash string
|
||||
}
|
||||
|
||||
func (bc *bufferedConn) Read(p []byte) (int, error) {
|
||||
return bc.r.Read(p)
|
||||
// SuspiciousError indicates a potential probing attempt or protocol violation.
|
||||
// When returned, Conn (if non-nil) should contain all bytes already consumed/buffered so the caller
|
||||
// can perform a best-effort fallback relay (e.g. to a local web server) without losing the request.
|
||||
type SuspiciousError struct {
|
||||
Err error
|
||||
Conn net.Conn
|
||||
}
|
||||
|
||||
func (e *SuspiciousError) Error() string {
|
||||
if e == nil || e.Err == nil {
|
||||
return ""
|
||||
}
|
||||
return e.Err.Error()
|
||||
}
|
||||
|
||||
func (e *SuspiciousError) Unwrap() error { return e.Err }
|
||||
|
||||
type recordedConn struct {
|
||||
net.Conn
|
||||
recorded []byte
|
||||
}
|
||||
|
||||
func (rc *recordedConn) GetBufferedAndRecorded() []byte { return rc.recorded }
|
||||
|
||||
type prefixedRecorderConn struct {
|
||||
net.Conn
|
||||
prefix []byte
|
||||
}
|
||||
|
||||
func (pc *prefixedRecorderConn) GetBufferedAndRecorded() []byte {
|
||||
var rest []byte
|
||||
if r, ok := pc.Conn.(interface{ GetBufferedAndRecorded() []byte }); ok {
|
||||
rest = r.GetBufferedAndRecorded()
|
||||
}
|
||||
out := make([]byte, 0, len(pc.prefix)+len(rest))
|
||||
out = append(out, pc.prefix...)
|
||||
out = append(out, rest...)
|
||||
return out
|
||||
}
|
||||
|
||||
// bufferedRecorderConn wraps a net.Conn and a shared bufio.Reader so we can expose buffered bytes.
|
||||
// This is used for legacy HTTP mask parsing errors so callers can fall back to a real HTTP server.
|
||||
type bufferedRecorderConn struct {
|
||||
net.Conn
|
||||
r *bufio.Reader
|
||||
recorder *bytes.Buffer
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (bc *bufferedRecorderConn) Read(p []byte) (n int, err error) {
|
||||
n, err = bc.r.Read(p)
|
||||
if n > 0 && bc.recorder != nil {
|
||||
bc.mu.Lock()
|
||||
bc.recorder.Write(p[:n])
|
||||
bc.mu.Unlock()
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (bc *bufferedRecorderConn) GetBufferedAndRecorded() []byte {
|
||||
if bc == nil {
|
||||
return nil
|
||||
}
|
||||
bc.mu.Lock()
|
||||
defer bc.mu.Unlock()
|
||||
|
||||
var recorded []byte
|
||||
if bc.recorder != nil {
|
||||
recorded = bc.recorder.Bytes()
|
||||
}
|
||||
buffered := 0
|
||||
if bc.r != nil {
|
||||
buffered = bc.r.Buffered()
|
||||
}
|
||||
if buffered <= 0 {
|
||||
return recorded
|
||||
}
|
||||
peeked, _ := bc.r.Peek(buffered)
|
||||
full := make([]byte, len(recorded)+len(peeked))
|
||||
copy(full, recorded)
|
||||
copy(full[len(recorded):], peeked)
|
||||
return full
|
||||
}
|
||||
|
||||
type preBufferedConn struct {
|
||||
@@ -61,6 +140,26 @@ func (p *preBufferedConn) Read(b []byte) (int, error) {
|
||||
return p.Conn.Read(b)
|
||||
}
|
||||
|
||||
func (p *preBufferedConn) CloseWrite() error {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
if cw, ok := p.Conn.(interface{ CloseWrite() error }); ok {
|
||||
return cw.CloseWrite()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *preBufferedConn) CloseRead() error {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
if cr, ok := p.Conn.(interface{ CloseRead() error }); ok {
|
||||
return cr.CloseRead()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type directionalConn struct {
|
||||
net.Conn
|
||||
reader io.Reader
|
||||
@@ -101,6 +200,26 @@ func (c *directionalConn) Close() error {
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (c *directionalConn) CloseWrite() error {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
if cw, ok := c.Conn.(interface{ CloseWrite() error }); ok {
|
||||
return cw.CloseWrite()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *directionalConn) CloseRead() error {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
if cr, ok := c.Conn.(interface{ CloseRead() error }); ok {
|
||||
return cr.CloseRead()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func absInt64(v int64) int64 {
|
||||
if v < 0 {
|
||||
return -v
|
||||
@@ -108,18 +227,6 @@ func absInt64(v int64) int64 {
|
||||
return v
|
||||
}
|
||||
|
||||
const (
|
||||
downlinkModePure byte = 0x01
|
||||
downlinkModePacked byte = 0x02
|
||||
)
|
||||
|
||||
func downlinkMode(cfg *ProtocolConfig) byte {
|
||||
if cfg.EnablePureDownlink {
|
||||
return downlinkModePure
|
||||
}
|
||||
return downlinkModePacked
|
||||
}
|
||||
|
||||
func buildClientObfsConn(raw net.Conn, cfg *ProtocolConfig, table *sudoku.Table) net.Conn {
|
||||
baseSudoku := sudoku.NewConn(raw, table, cfg.PaddingMin, cfg.PaddingMax, false)
|
||||
if cfg.EnablePureDownlink {
|
||||
@@ -138,50 +245,16 @@ func buildServerObfsConn(raw net.Conn, cfg *ProtocolConfig, table *sudoku.Table,
|
||||
return uplinkSudoku, newDirectionalConn(raw, uplinkSudoku, packed, packed.Flush)
|
||||
}
|
||||
|
||||
func buildHandshakePayload(key string) [16]byte {
|
||||
var payload [16]byte
|
||||
binary.BigEndian.PutUint64(payload[:8], uint64(time.Now().Unix()))
|
||||
|
||||
// Align with upstream: only decode hex bytes when this key is an ED25519 key material.
|
||||
// For plain UUID/strings (even if they look like hex), hash the string bytes as-is.
|
||||
src := []byte(key)
|
||||
if _, err := crypto.RecoverPublicKey(key); err == nil {
|
||||
if keyBytes, decErr := hex.DecodeString(key); decErr == nil && len(keyBytes) > 0 {
|
||||
src = keyBytes
|
||||
}
|
||||
func isLegacyHTTPMaskMode(mode string) bool {
|
||||
switch strings.ToLower(strings.TrimSpace(mode)) {
|
||||
case "", "legacy":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
hash := sha256.Sum256(src)
|
||||
copy(payload[8:], hash[:8])
|
||||
return payload
|
||||
}
|
||||
|
||||
func NewTable(key string, tableType string) *sudoku.Table {
|
||||
table, err := NewTableWithCustom(key, tableType, "")
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("[Sudoku] failed to init tables: %v", err))
|
||||
}
|
||||
return table
|
||||
}
|
||||
|
||||
func NewTableWithCustom(key string, tableType string, customTable string) (*sudoku.Table, error) {
|
||||
start := time.Now()
|
||||
table, err := sudoku.NewTableWithCustom(key, tableType, customTable)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Infoln("[Sudoku] Tables initialized (%s, custom=%v) in %v", tableType, customTable != "", time.Since(start))
|
||||
return table, nil
|
||||
}
|
||||
|
||||
func ClientAEADSeed(key string) string {
|
||||
if recovered, err := crypto.RecoverPublicKey(key); err == nil {
|
||||
return crypto.EncodePoint(recovered)
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
// ClientHandshake performs the client-side Sudoku handshake (without sending target address).
|
||||
// ClientHandshake performs the client-side Sudoku handshake (no target request).
|
||||
func ClientHandshake(rawConn net.Conn, cfg *ProtocolConfig) (net.Conn, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config is required")
|
||||
@@ -190,7 +263,7 @@ func ClientHandshake(rawConn net.Conn, cfg *ProtocolConfig) (net.Conn, error) {
|
||||
return nil, fmt.Errorf("invalid config: %w", err)
|
||||
}
|
||||
|
||||
if !cfg.DisableHTTPMask {
|
||||
if !cfg.DisableHTTPMask && isLegacyHTTPMaskMode(cfg.HTTPMaskMode) {
|
||||
if err := httpmask.WriteRandomRequestHeaderWithPathRoot(rawConn, cfg.ServerAddress, cfg.HTTPMaskPathRoot); err != nil {
|
||||
return nil, fmt.Errorf("write http mask failed: %w", err)
|
||||
}
|
||||
@@ -201,32 +274,68 @@ func ClientHandshake(rawConn net.Conn, cfg *ProtocolConfig) (net.Conn, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
seed := ClientAEADSeed(cfg.Key)
|
||||
obfsConn := buildClientObfsConn(rawConn, cfg, table)
|
||||
cConn, err := crypto.NewAEADConn(obfsConn, ClientAEADSeed(cfg.Key), cfg.AEADMethod)
|
||||
pskC2S, pskS2C := derivePSKDirectionalBases(seed)
|
||||
rc, err := crypto.NewRecordConn(obfsConn, cfg.AEADMethod, pskC2S, pskS2C)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setup crypto failed: %w", err)
|
||||
}
|
||||
|
||||
handshake := buildHandshakePayload(cfg.Key)
|
||||
if _, err := cConn.Write(handshake[:]); err != nil {
|
||||
cConn.Close()
|
||||
return nil, fmt.Errorf("send handshake failed: %w", err)
|
||||
}
|
||||
if _, err := cConn.Write([]byte{downlinkMode(cfg)}); err != nil {
|
||||
cConn.Close()
|
||||
return nil, fmt.Errorf("send downlink mode failed: %w", err)
|
||||
if _, err := kipHandshakeClient(rc, seed, kipUserHashFromKey(cfg.Key), KIPFeatAll); err != nil {
|
||||
_ = rc.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cConn, nil
|
||||
return rc, nil
|
||||
}
|
||||
|
||||
// ServerHandshake performs Sudoku server-side handshake and detects UoT preface.
|
||||
func ServerHandshake(rawConn net.Conn, cfg *ProtocolConfig) (*ServerSession, error) {
|
||||
func readFirstSessionMessage(conn net.Conn) (*KIPMessage, error) {
|
||||
for {
|
||||
msg, err := ReadKIPMessage(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if msg.Type == KIPTypeKeepAlive {
|
||||
continue
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
}
|
||||
|
||||
func maybeConsumeLegacyHTTPMask(rawConn net.Conn, r *bufio.Reader, cfg *ProtocolConfig) ([]byte, *SuspiciousError) {
|
||||
if rawConn == nil || r == nil || cfg == nil || cfg.DisableHTTPMask || !isLegacyHTTPMaskMode(cfg.HTTPMaskMode) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
peekBytes, _ := r.Peek(4) // ignore error; subsequent read will handle it
|
||||
if !httpmask.LooksLikeHTTPRequestStart(peekBytes) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
consumed, err := httpmask.ConsumeHeader(r)
|
||||
if err == nil {
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
recorder := new(bytes.Buffer)
|
||||
if len(consumed) > 0 {
|
||||
recorder.Write(consumed)
|
||||
}
|
||||
badConn := &bufferedRecorderConn{Conn: rawConn, r: r, recorder: recorder}
|
||||
return consumed, &SuspiciousError{Err: fmt.Errorf("invalid http header: %w", err), Conn: badConn}
|
||||
}
|
||||
|
||||
// ServerHandshake performs the server-side KIP handshake.
|
||||
func ServerHandshake(rawConn net.Conn, cfg *ProtocolConfig) (net.Conn, *HandshakeMeta, error) {
|
||||
if rawConn == nil {
|
||||
return nil, nil, fmt.Errorf("nil conn")
|
||||
}
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config is required")
|
||||
return nil, nil, fmt.Errorf("config is required")
|
||||
}
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid config: %w", err)
|
||||
return nil, nil, fmt.Errorf("invalid config: %w", err)
|
||||
}
|
||||
|
||||
handshakeTimeout := time.Duration(cfg.HandshakeTimeoutSeconds) * time.Second
|
||||
@@ -234,116 +343,113 @@ func ServerHandshake(rawConn net.Conn, cfg *ProtocolConfig) (*ServerSession, err
|
||||
handshakeTimeout = 5 * time.Second
|
||||
}
|
||||
|
||||
rawConn.SetReadDeadline(time.Now().Add(handshakeTimeout))
|
||||
|
||||
bufReader := bufio.NewReader(rawConn)
|
||||
if !cfg.DisableHTTPMask {
|
||||
if peek, err := bufReader.Peek(4); err == nil && httpmask.LooksLikeHTTPRequestStart(peek) {
|
||||
if _, err := httpmask.ConsumeHeader(bufReader); err != nil {
|
||||
return nil, fmt.Errorf("invalid http header: %w", err)
|
||||
}
|
||||
}
|
||||
_ = rawConn.SetReadDeadline(time.Now().Add(handshakeTimeout))
|
||||
defer func() { _ = rawConn.SetReadDeadline(time.Time{}) }()
|
||||
|
||||
httpHeaderData, susp := maybeConsumeLegacyHTTPMask(rawConn, bufReader, cfg)
|
||||
if susp != nil {
|
||||
return nil, nil, susp
|
||||
}
|
||||
|
||||
selectedTable, preRead, err := selectTableByProbe(bufReader, cfg, cfg.tableCandidates())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
combined := make([]byte, 0, len(httpHeaderData)+len(preRead))
|
||||
combined = append(combined, httpHeaderData...)
|
||||
combined = append(combined, preRead...)
|
||||
return nil, nil, &SuspiciousError{Err: err, Conn: &recordedConn{Conn: rawConn, recorded: combined}}
|
||||
}
|
||||
|
||||
baseConn := &preBufferedConn{Conn: rawConn, buf: preRead}
|
||||
bConn := &bufferedConn{Conn: baseConn, r: bufio.NewReader(baseConn)}
|
||||
sConn, obfsConn := buildServerObfsConn(bConn, cfg, selectedTable, true)
|
||||
cConn, err := crypto.NewAEADConn(obfsConn, cfg.Key, cfg.AEADMethod)
|
||||
sConn, obfsConn := buildServerObfsConn(baseConn, cfg, selectedTable, true)
|
||||
|
||||
seed := ServerAEADSeed(cfg.Key)
|
||||
pskC2S, pskS2C := derivePSKDirectionalBases(seed)
|
||||
// Server side: recv is client->server, send is server->client.
|
||||
rc, err := crypto.NewRecordConn(obfsConn, cfg.AEADMethod, pskS2C, pskC2S)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("crypto setup failed: %w", err)
|
||||
return nil, nil, fmt.Errorf("setup crypto failed: %w", err)
|
||||
}
|
||||
|
||||
var handshakeBuf [16]byte
|
||||
if _, err := io.ReadFull(cConn, handshakeBuf[:]); err != nil {
|
||||
cConn.Close()
|
||||
return nil, fmt.Errorf("read handshake failed: %w", err)
|
||||
msg, err := ReadKIPMessage(rc)
|
||||
if err != nil {
|
||||
return nil, nil, &SuspiciousError{Err: fmt.Errorf("handshake read failed: %w", err), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}}
|
||||
}
|
||||
if msg.Type != KIPTypeClientHello {
|
||||
return nil, nil, &SuspiciousError{Err: fmt.Errorf("unexpected handshake message: %d", msg.Type), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}}
|
||||
}
|
||||
ch, err := DecodeKIPClientHelloPayload(msg.Payload)
|
||||
if err != nil {
|
||||
return nil, nil, &SuspiciousError{Err: fmt.Errorf("decode client hello failed: %w", err), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}}
|
||||
}
|
||||
if absInt64(time.Now().Unix()-ch.Timestamp.Unix()) > int64(kipHandshakeSkew.Seconds()) {
|
||||
return nil, nil, &SuspiciousError{Err: fmt.Errorf("time skew/replay"), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}}
|
||||
}
|
||||
|
||||
ts := int64(binary.BigEndian.Uint64(handshakeBuf[:8]))
|
||||
if absInt64(time.Now().Unix()-ts) > 60 {
|
||||
cConn.Close()
|
||||
return nil, fmt.Errorf("timestamp skew detected")
|
||||
userHashHex := hex.EncodeToString(ch.UserHash[:])
|
||||
if !globalHandshakeReplay.allow(userHashHex, ch.Nonce, time.Now()) {
|
||||
return nil, nil, &SuspiciousError{Err: fmt.Errorf("replay"), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}}
|
||||
}
|
||||
|
||||
curve := ecdh.X25519()
|
||||
serverEphemeral, err := curve.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, nil, &SuspiciousError{Err: fmt.Errorf("ecdh generate failed: %w", err), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}}
|
||||
}
|
||||
shared, err := x25519SharedSecret(serverEphemeral, ch.ClientPub[:])
|
||||
if err != nil {
|
||||
return nil, nil, &SuspiciousError{Err: fmt.Errorf("ecdh failed: %w", err), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}}
|
||||
}
|
||||
sessC2S, sessS2C, err := deriveSessionDirectionalBases(seed, shared, ch.Nonce)
|
||||
if err != nil {
|
||||
return nil, nil, &SuspiciousError{Err: fmt.Errorf("derive session keys failed: %w", err), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}}
|
||||
}
|
||||
|
||||
var serverPub [kipHelloPubSize]byte
|
||||
copy(serverPub[:], serverEphemeral.PublicKey().Bytes())
|
||||
sh := &KIPServerHello{
|
||||
Nonce: ch.Nonce,
|
||||
ServerPub: serverPub,
|
||||
SelectedFeats: ch.Features & KIPFeatAll,
|
||||
}
|
||||
if err := WriteKIPMessage(rc, KIPTypeServerHello, sh.EncodePayload()); err != nil {
|
||||
return nil, nil, &SuspiciousError{Err: fmt.Errorf("write server hello failed: %w", err), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}}
|
||||
}
|
||||
if err := rc.Rekey(sessS2C, sessC2S); err != nil {
|
||||
return nil, nil, &SuspiciousError{Err: fmt.Errorf("rekey failed: %w", err), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}}
|
||||
}
|
||||
|
||||
userHash := userHashFromHandshake(handshakeBuf[:])
|
||||
sConn.StopRecording()
|
||||
return rc, &HandshakeMeta{UserHash: userHashHex}, nil
|
||||
}
|
||||
|
||||
modeBuf := []byte{0}
|
||||
if _, err := io.ReadFull(cConn, modeBuf); err != nil {
|
||||
cConn.Close()
|
||||
return nil, fmt.Errorf("read downlink mode failed: %w", err)
|
||||
// ReadServerSession consumes the first post-handshake KIP control message and returns the session intent.
|
||||
func ReadServerSession(conn net.Conn, meta *HandshakeMeta) (*ServerSession, error) {
|
||||
if conn == nil {
|
||||
return nil, fmt.Errorf("nil conn")
|
||||
}
|
||||
if modeBuf[0] != downlinkMode(cfg) {
|
||||
cConn.Close()
|
||||
return nil, fmt.Errorf("downlink mode mismatch: client=%d server=%d", modeBuf[0], downlinkMode(cfg))
|
||||
userHash := ""
|
||||
if meta != nil {
|
||||
userHash = meta.UserHash
|
||||
}
|
||||
|
||||
firstByte := make([]byte, 1)
|
||||
if _, err := io.ReadFull(cConn, firstByte); err != nil {
|
||||
cConn.Close()
|
||||
return nil, fmt.Errorf("read first byte failed: %w", err)
|
||||
first, err := readFirstSessionMessage(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if firstByte[0] == MultiplexMagicByte {
|
||||
rawConn.SetReadDeadline(time.Time{})
|
||||
return &ServerSession{Conn: cConn, Type: SessionTypeMultiplex, UserHash: userHash}, nil
|
||||
}
|
||||
|
||||
if firstByte[0] == UoTMagicByte {
|
||||
version := make([]byte, 1)
|
||||
if _, err := io.ReadFull(cConn, version); err != nil {
|
||||
cConn.Close()
|
||||
return nil, fmt.Errorf("read uot version failed: %w", err)
|
||||
switch first.Type {
|
||||
case KIPTypeStartUoT:
|
||||
return &ServerSession{Conn: conn, Type: SessionTypeUoT, UserHash: userHash}, nil
|
||||
case KIPTypeStartMux:
|
||||
return &ServerSession{Conn: conn, Type: SessionTypeMultiplex, UserHash: userHash}, nil
|
||||
case KIPTypeOpenTCP:
|
||||
target, err := DecodeAddress(bytes.NewReader(first.Payload))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode target address failed: %w", err)
|
||||
}
|
||||
if version[0] != uotVersion {
|
||||
cConn.Close()
|
||||
return nil, fmt.Errorf("unsupported uot version: %d", version[0])
|
||||
}
|
||||
rawConn.SetReadDeadline(time.Time{})
|
||||
return &ServerSession{Conn: cConn, Type: SessionTypeUoT, UserHash: userHash}, nil
|
||||
return &ServerSession{Conn: conn, Type: SessionTypeTCP, Target: target, UserHash: userHash}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown kip message: %d", first.Type)
|
||||
}
|
||||
|
||||
prefixed := &preBufferedConn{Conn: cConn, buf: firstByte}
|
||||
target, err := DecodeAddress(prefixed)
|
||||
if err != nil {
|
||||
cConn.Close()
|
||||
return nil, fmt.Errorf("read target address failed: %w", err)
|
||||
}
|
||||
|
||||
rawConn.SetReadDeadline(time.Time{})
|
||||
log.Debugln("[Sudoku] incoming TCP session target: %s", target)
|
||||
return &ServerSession{
|
||||
Conn: prefixed,
|
||||
Type: SessionTypeTCP,
|
||||
Target: target,
|
||||
UserHash: userHash,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func GenKeyPair() (privateKey, publicKey string, err error) {
|
||||
// Generate Master Key
|
||||
pair, err := crypto.GenerateMasterKey()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// Split the master private key to get Available Private Key
|
||||
availablePrivateKey, err := crypto.SplitPrivateKey(pair.Private)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
privateKey = availablePrivateKey // Available Private Key for client
|
||||
publicKey = crypto.EncodePoint(pair.Public) // Master Public Key for server
|
||||
return
|
||||
}
|
||||
|
||||
func userHashFromHandshake(handshakeBuf []byte) string {
|
||||
if len(handshakeBuf) < 16 {
|
||||
return ""
|
||||
}
|
||||
return hex.EncodeToString(handshakeBuf[8:16])
|
||||
}
|
||||
|
||||
73
transport/sudoku/handshake_kip.go
Normal file
73
transport/sudoku/handshake_kip.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package sudoku
|
||||
|
||||
import (
|
||||
"crypto/ecdh"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/metacubex/mihomo/transport/sudoku/crypto"
|
||||
)
|
||||
|
||||
const kipHandshakeSkew = 60 * time.Second
|
||||
|
||||
func kipHandshakeClient(rc *crypto.RecordConn, seed string, userHash [kipHelloUserHashSize]byte, feats uint32) (uint32, error) {
|
||||
if rc == nil {
|
||||
return 0, fmt.Errorf("nil conn")
|
||||
}
|
||||
|
||||
curve := ecdh.X25519()
|
||||
ephemeral, err := curve.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("ecdh generate failed: %w", err)
|
||||
}
|
||||
|
||||
var nonce [kipHelloNonceSize]byte
|
||||
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
|
||||
return 0, fmt.Errorf("nonce generate failed: %w", err)
|
||||
}
|
||||
|
||||
var clientPub [kipHelloPubSize]byte
|
||||
copy(clientPub[:], ephemeral.PublicKey().Bytes())
|
||||
|
||||
ch := &KIPClientHello{
|
||||
Timestamp: time.Now(),
|
||||
UserHash: userHash,
|
||||
Nonce: nonce,
|
||||
ClientPub: clientPub,
|
||||
Features: feats,
|
||||
}
|
||||
if err := WriteKIPMessage(rc, KIPTypeClientHello, ch.EncodePayload()); err != nil {
|
||||
return 0, fmt.Errorf("write client hello failed: %w", err)
|
||||
}
|
||||
|
||||
msg, err := ReadKIPMessage(rc)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("read server hello failed: %w", err)
|
||||
}
|
||||
if msg.Type != KIPTypeServerHello {
|
||||
return 0, fmt.Errorf("unexpected handshake message: %d", msg.Type)
|
||||
}
|
||||
sh, err := DecodeKIPServerHelloPayload(msg.Payload)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("decode server hello failed: %w", err)
|
||||
}
|
||||
if sh.Nonce != nonce {
|
||||
return 0, fmt.Errorf("handshake nonce mismatch")
|
||||
}
|
||||
|
||||
shared, err := x25519SharedSecret(ephemeral, sh.ServerPub[:])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("ecdh failed: %w", err)
|
||||
}
|
||||
sessC2S, sessS2C, err := deriveSessionDirectionalBases(seed, shared, nonce)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("derive session keys failed: %w", err)
|
||||
}
|
||||
if err := rc.Rekey(sessC2S, sessS2C); err != nil {
|
||||
return 0, fmt.Errorf("rekey failed: %w", err)
|
||||
}
|
||||
|
||||
return sh.SelectedFeats, nil
|
||||
}
|
||||
@@ -124,13 +124,18 @@ func runPackedTCPSession(id int, cfg *ProtocolConfig, errCh chan<- error) {
|
||||
|
||||
// Server side
|
||||
go func() {
|
||||
session, err := ServerHandshake(serverConn, cfg)
|
||||
c, meta, err := ServerHandshake(serverConn, cfg)
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("server handshake tcp: %w", err)
|
||||
return
|
||||
}
|
||||
defer session.Conn.Close()
|
||||
defer c.Close()
|
||||
|
||||
session, err := ReadServerSession(c, meta)
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("server read session tcp: %w", err)
|
||||
return
|
||||
}
|
||||
if session.Type != SessionTypeTCP {
|
||||
errCh <- fmt.Errorf("unexpected session type: %v", session.Type)
|
||||
return
|
||||
@@ -159,8 +164,8 @@ func runPackedTCPSession(id int, cfg *ProtocolConfig, errCh chan<- error) {
|
||||
errCh <- fmt.Errorf("encode address: %w", err)
|
||||
return
|
||||
}
|
||||
if _, err := cConn.Write(addrBuf); err != nil {
|
||||
errCh <- fmt.Errorf("client send addr: %w", err)
|
||||
if err := WriteKIPMessage(cConn, KIPTypeOpenTCP, addrBuf); err != nil {
|
||||
errCh <- fmt.Errorf("client send open tcp: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -182,13 +187,18 @@ func runPackedUoTSession(id int, cfg *ProtocolConfig, errCh chan<- error) {
|
||||
|
||||
// Server side
|
||||
go func() {
|
||||
session, err := ServerHandshake(serverConn, cfg)
|
||||
c, meta, err := ServerHandshake(serverConn, cfg)
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("server handshake uot: %w", err)
|
||||
return
|
||||
}
|
||||
defer session.Conn.Close()
|
||||
defer c.Close()
|
||||
|
||||
session, err := ReadServerSession(c, meta)
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("server read session uot: %w", err)
|
||||
return
|
||||
}
|
||||
if session.Type != SessionTypeUoT {
|
||||
errCh <- fmt.Errorf("unexpected session type: %v", session.Type)
|
||||
return
|
||||
@@ -208,8 +218,8 @@ func runPackedUoTSession(id int, cfg *ProtocolConfig, errCh chan<- error) {
|
||||
}
|
||||
defer cConn.Close()
|
||||
|
||||
if err := WritePreface(cConn); err != nil {
|
||||
errCh <- fmt.Errorf("client write preface: %w", err)
|
||||
if err := WriteKIPMessage(cConn, KIPTypeStartUoT, nil); err != nil {
|
||||
errCh <- fmt.Errorf("client start uot: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,14 @@ type HTTPMaskTunnelServer struct {
|
||||
}
|
||||
|
||||
func NewHTTPMaskTunnelServer(cfg *ProtocolConfig) *HTTPMaskTunnelServer {
|
||||
return newHTTPMaskTunnelServer(cfg, false)
|
||||
}
|
||||
|
||||
func NewHTTPMaskTunnelServerWithFallback(cfg *ProtocolConfig) *HTTPMaskTunnelServer {
|
||||
return newHTTPMaskTunnelServer(cfg, true)
|
||||
}
|
||||
|
||||
func newHTTPMaskTunnelServer(cfg *ProtocolConfig, passThroughOnReject bool) *HTTPMaskTunnelServer {
|
||||
if cfg == nil {
|
||||
return &HTTPMaskTunnelServer{}
|
||||
}
|
||||
@@ -22,11 +30,13 @@ func NewHTTPMaskTunnelServer(cfg *ProtocolConfig) *HTTPMaskTunnelServer {
|
||||
var ts *httpmask.TunnelServer
|
||||
if !cfg.DisableHTTPMask {
|
||||
switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) {
|
||||
case "stream", "poll", "auto":
|
||||
case "stream", "poll", "auto", "ws":
|
||||
ts = httpmask.NewTunnelServer(httpmask.TunnelServerOptions{
|
||||
Mode: cfg.HTTPMaskMode,
|
||||
PathRoot: cfg.HTTPMaskPathRoot,
|
||||
AuthKey: ClientAEADSeed(cfg.Key),
|
||||
AuthKey: ServerAEADSeed(cfg.Key),
|
||||
// When upstream fallback is enabled, preserve rejected HTTP requests for the caller.
|
||||
PassThroughOnReject: passThroughOnReject,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -62,6 +72,14 @@ func (s *HTTPMaskTunnelServer) WrapConn(rawConn net.Conn) (handshakeConn net.Con
|
||||
case httpmask.HandleStartTunnel:
|
||||
inner := *s.cfg
|
||||
inner.DisableHTTPMask = true
|
||||
// HTTPMask tunnel modes (stream/poll/auto/ws) add extra round trips before the first
|
||||
// handshake bytes can reach ServerHandshake, especially under high concurrency.
|
||||
// Bump the handshake timeout for tunneled conns to avoid flaky timeouts while keeping
|
||||
// the default strict for raw TCP handshakes.
|
||||
const minTunneledHandshakeTimeoutSeconds = 15
|
||||
if inner.HandshakeTimeoutSeconds <= 0 || inner.HandshakeTimeoutSeconds < minTunneledHandshakeTimeoutSeconds {
|
||||
inner.HandshakeTimeoutSeconds = minTunneledHandshakeTimeoutSeconds
|
||||
}
|
||||
return c, &inner, false, nil
|
||||
default:
|
||||
return nil, nil, true, nil
|
||||
@@ -70,7 +88,7 @@ func (s *HTTPMaskTunnelServer) WrapConn(rawConn net.Conn) (handshakeConn net.Con
|
||||
|
||||
type TunnelDialer func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
|
||||
// DialHTTPMaskTunnel dials a CDN-capable HTTP tunnel (stream/poll/auto) and returns a stream carrying raw Sudoku bytes.
|
||||
// DialHTTPMaskTunnel dials a CDN-capable HTTP tunnel (stream/poll/auto/ws) and returns a stream carrying raw Sudoku bytes.
|
||||
func DialHTTPMaskTunnel(ctx context.Context, serverAddress string, cfg *ProtocolConfig, dial TunnelDialer, upgrade func(net.Conn) (net.Conn, error)) (net.Conn, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config is required")
|
||||
@@ -79,7 +97,7 @@ func DialHTTPMaskTunnel(ctx context.Context, serverAddress string, cfg *Protocol
|
||||
return nil, fmt.Errorf("http mask is disabled")
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) {
|
||||
case "stream", "poll", "auto":
|
||||
case "stream", "poll", "auto", "ws":
|
||||
default:
|
||||
return nil, fmt.Errorf("http-mask-mode=%q does not use http tunnel", cfg.HTTPMaskMode)
|
||||
}
|
||||
@@ -94,64 +112,3 @@ func DialHTTPMaskTunnel(ctx context.Context, serverAddress string, cfg *Protocol
|
||||
DialContext: dial,
|
||||
})
|
||||
}
|
||||
|
||||
type HTTPMaskTunnelClient struct {
|
||||
mode string
|
||||
pathRoot string
|
||||
authKey string
|
||||
client *httpmask.TunnelClient
|
||||
}
|
||||
|
||||
func NewHTTPMaskTunnelClient(serverAddress string, cfg *ProtocolConfig, dial TunnelDialer) (*HTTPMaskTunnelClient, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("config is required")
|
||||
}
|
||||
if cfg.DisableHTTPMask {
|
||||
return nil, fmt.Errorf("http mask is disabled")
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) {
|
||||
case "stream", "poll", "auto":
|
||||
default:
|
||||
return nil, fmt.Errorf("http-mask-mode=%q does not use http tunnel", cfg.HTTPMaskMode)
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMultiplex)) {
|
||||
case "auto", "on":
|
||||
default:
|
||||
return nil, fmt.Errorf("http-mask-multiplex=%q does not enable reuse", cfg.HTTPMaskMultiplex)
|
||||
}
|
||||
|
||||
c, err := httpmask.NewTunnelClient(serverAddress, httpmask.TunnelClientOptions{
|
||||
TLSEnabled: cfg.HTTPMaskTLSEnabled,
|
||||
HostOverride: cfg.HTTPMaskHost,
|
||||
DialContext: dial,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &HTTPMaskTunnelClient{
|
||||
mode: cfg.HTTPMaskMode,
|
||||
pathRoot: cfg.HTTPMaskPathRoot,
|
||||
authKey: ClientAEADSeed(cfg.Key),
|
||||
client: c,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *HTTPMaskTunnelClient) Dial(ctx context.Context, upgrade func(net.Conn) (net.Conn, error)) (net.Conn, error) {
|
||||
if c == nil || c.client == nil {
|
||||
return nil, fmt.Errorf("nil httpmask tunnel client")
|
||||
}
|
||||
return c.client.DialTunnel(ctx, httpmask.TunnelDialOptions{
|
||||
Mode: c.mode,
|
||||
PathRoot: c.pathRoot,
|
||||
AuthKey: c.authKey,
|
||||
Upgrade: upgrade,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *HTTPMaskTunnelClient) CloseIdleConnections() {
|
||||
if c == nil || c.client == nil {
|
||||
return
|
||||
}
|
||||
c.client.CloseIdleConnections()
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ func startTunnelServer(t *testing.T, cfg *ProtocolConfig, handle func(*ServerSes
|
||||
return
|
||||
}
|
||||
|
||||
session, err := ServerHandshake(handshakeConn, handshakeCfg)
|
||||
cConn, meta, err := ServerHandshake(handshakeConn, handshakeCfg)
|
||||
if err != nil {
|
||||
_ = handshakeConn.Close()
|
||||
if handshakeConn != conn {
|
||||
@@ -70,8 +70,13 @@ func startTunnelServer(t *testing.T, cfg *ProtocolConfig, handle func(*ServerSes
|
||||
errC <- err
|
||||
return
|
||||
}
|
||||
defer session.Conn.Close()
|
||||
defer cConn.Close()
|
||||
|
||||
session, err := ReadServerSession(cConn, meta)
|
||||
if err != nil {
|
||||
errC <- err
|
||||
return
|
||||
}
|
||||
if handleErr := handle(session); handleErr != nil {
|
||||
errC <- handleErr
|
||||
}
|
||||
@@ -172,7 +177,7 @@ func TestHTTPMaskTunnel_Stream_TCPRoundTrip(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("encode addr: %v", err)
|
||||
}
|
||||
if _, err := cConn.Write(addrBuf); err != nil {
|
||||
if err := WriteKIPMessage(cConn, KIPTypeOpenTCP, addrBuf); err != nil {
|
||||
t.Fatalf("write addr: %v", err)
|
||||
}
|
||||
|
||||
@@ -239,8 +244,8 @@ func TestHTTPMaskTunnel_Poll_UoTRoundTrip(t *testing.T) {
|
||||
}
|
||||
defer cConn.Close()
|
||||
|
||||
if err := WritePreface(cConn); err != nil {
|
||||
t.Fatalf("write preface: %v", err)
|
||||
if err := WriteKIPMessage(cConn, KIPTypeStartUoT, nil); err != nil {
|
||||
t.Fatalf("start uot: %v", err)
|
||||
}
|
||||
if err := WriteDatagram(cConn, target, payload); err != nil {
|
||||
t.Fatalf("write datagram: %v", err)
|
||||
@@ -305,7 +310,68 @@ func TestHTTPMaskTunnel_Auto_TCPRoundTrip(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("encode addr: %v", err)
|
||||
}
|
||||
if _, err := cConn.Write(addrBuf); err != nil {
|
||||
if err := WriteKIPMessage(cConn, KIPTypeOpenTCP, addrBuf); err != nil {
|
||||
t.Fatalf("write addr: %v", err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 2)
|
||||
if _, err := io.ReadFull(cConn, buf); err != nil {
|
||||
t.Fatalf("read: %v", err)
|
||||
}
|
||||
if string(buf) != "ok" {
|
||||
t.Fatalf("unexpected payload: %q", buf)
|
||||
}
|
||||
|
||||
stop()
|
||||
for err := range errCh {
|
||||
t.Fatalf("server error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPMaskTunnel_WS_TCPRoundTrip(t *testing.T) {
|
||||
key := "tunnel-ws-key"
|
||||
target := "1.1.1.1:80"
|
||||
|
||||
serverCfg := newTunnelTestTable(t, key)
|
||||
serverCfg.HTTPMaskMode = "ws"
|
||||
|
||||
addr, stop, errCh := startTunnelServer(t, serverCfg, func(s *ServerSession) error {
|
||||
if s.Type != SessionTypeTCP {
|
||||
return fmt.Errorf("unexpected session type: %v", s.Type)
|
||||
}
|
||||
if s.Target != target {
|
||||
return fmt.Errorf("target mismatch: %s", s.Target)
|
||||
}
|
||||
_, _ = s.Conn.Write([]byte("ok"))
|
||||
return nil
|
||||
})
|
||||
defer stop()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
clientCfg := *serverCfg
|
||||
clientCfg.ServerAddress = addr
|
||||
|
||||
tunnelConn, err := DialHTTPMaskTunnel(ctx, clientCfg.ServerAddress, &clientCfg, (&net.Dialer{}).DialContext, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("dial tunnel: %v", err)
|
||||
}
|
||||
defer tunnelConn.Close()
|
||||
|
||||
handshakeCfg := clientCfg
|
||||
handshakeCfg.DisableHTTPMask = true
|
||||
cConn, err := ClientHandshake(tunnelConn, &handshakeCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("client handshake: %v", err)
|
||||
}
|
||||
defer cConn.Close()
|
||||
|
||||
addrBuf, err := EncodeAddress(target)
|
||||
if err != nil {
|
||||
t.Fatalf("encode addr: %v", err)
|
||||
}
|
||||
if err := WriteKIPMessage(cConn, KIPTypeOpenTCP, addrBuf); err != nil {
|
||||
t.Fatalf("write addr: %v", err)
|
||||
}
|
||||
|
||||
@@ -406,7 +472,7 @@ func TestHTTPMaskTunnel_Soak_Concurrent(t *testing.T) {
|
||||
runErr <- fmt.Errorf("encode addr: %w", err)
|
||||
return
|
||||
}
|
||||
if _, err := cConn.Write(addrBuf); err != nil {
|
||||
if err := WriteKIPMessage(cConn, KIPTypeOpenTCP, addrBuf); err != nil {
|
||||
runErr <- fmt.Errorf("write addr: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
97
transport/sudoku/init.go
Normal file
97
transport/sudoku/init.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package sudoku
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/metacubex/edwards25519"
|
||||
"github.com/metacubex/mihomo/transport/sudoku/crypto"
|
||||
"github.com/metacubex/mihomo/transport/sudoku/obfs/sudoku"
|
||||
)
|
||||
|
||||
func NewTable(key string, tableType string) *sudoku.Table {
|
||||
table, err := NewTableWithCustom(key, tableType, "")
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("[Sudoku] failed to init tables: %v", err))
|
||||
}
|
||||
return table
|
||||
}
|
||||
|
||||
func NewTableWithCustom(key string, tableType string, customTable string) (*sudoku.Table, error) {
|
||||
table, err := sudoku.NewTableWithCustom(key, tableType, customTable)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return table, nil
|
||||
}
|
||||
|
||||
// ClientAEADSeed returns a canonical "seed" that is stable between client private key material and server public key.
|
||||
func ClientAEADSeed(key string) string {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
b, err := hex.DecodeString(key)
|
||||
if err != nil {
|
||||
return key
|
||||
}
|
||||
|
||||
// Client-side key material can be:
|
||||
// - split private key: 64 bytes hex (r||k)
|
||||
// - master private scalar: 32 bytes hex (x)
|
||||
// - PSK string: non-hex
|
||||
//
|
||||
// We intentionally do NOT treat a 32-byte hex as a public key here; the client is expected
|
||||
// to carry private material. Server-side should use ServerAEADSeed for public keys.
|
||||
switch len(b) {
|
||||
case 64:
|
||||
case 32:
|
||||
default:
|
||||
return key
|
||||
}
|
||||
if recovered, err := crypto.RecoverPublicKey(key); err == nil {
|
||||
return crypto.EncodePoint(recovered)
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
// ServerAEADSeed returns a canonical seed for server-side configuration.
|
||||
//
|
||||
// When key is a public key (32-byte compressed point, hex), it returns the canonical point encoding.
|
||||
// When key is private key material (split/master scalar), it derives and returns the public key.
|
||||
func ServerAEADSeed(key string) string {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
b, err := hex.DecodeString(key)
|
||||
if err != nil {
|
||||
return key
|
||||
}
|
||||
|
||||
// Prefer interpreting 32-byte hex as a public key point, to avoid accidental scalar parsing.
|
||||
if len(b) == 32 {
|
||||
if p, err := new(edwards25519.Point).SetBytes(b); err == nil {
|
||||
return hex.EncodeToString(p.Bytes())
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to client-side rules for private key materials / other formats.
|
||||
return ClientAEADSeed(key)
|
||||
}
|
||||
|
||||
// GenKeyPair generates a client "available private key" and the corresponding server public key.
|
||||
func GenKeyPair() (privateKey, publicKey string, err error) {
|
||||
pair, err := crypto.GenerateMasterKey()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
availablePrivateKey, err := crypto.SplitPrivateKey(pair.Private)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
return availablePrivateKey, crypto.EncodePoint(pair.Public), nil
|
||||
}
|
||||
44
transport/sudoku/init_test.go
Normal file
44
transport/sudoku/init_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package sudoku
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
|
||||
"github.com/metacubex/edwards25519"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClientAEADSeed_IsStableForPrivAndPub(t *testing.T) {
|
||||
for i := 0; i < 64; i++ {
|
||||
priv, pub, err := GenKeyPair()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, pub, ClientAEADSeed(priv))
|
||||
require.Equal(t, pub, ServerAEADSeed(pub))
|
||||
require.Equal(t, pub, ServerAEADSeed(priv))
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientAEADSeed_Supports32ByteMasterScalar(t *testing.T) {
|
||||
var seed [64]byte
|
||||
_, err := rand.Read(seed[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
s, err := edwards25519.NewScalar().SetUniformBytes(seed[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
keyHex := hex.EncodeToString(s.Bytes())
|
||||
require.Len(t, keyHex, 64)
|
||||
require.NotEqual(t, keyHex, ClientAEADSeed(keyHex))
|
||||
require.Equal(t, ClientAEADSeed(keyHex), ServerAEADSeed(ClientAEADSeed(keyHex)))
|
||||
}
|
||||
|
||||
func TestServerAEADSeed_LeavesPublicKeyAsIs(t *testing.T) {
|
||||
for i := 0; i < 64; i++ {
|
||||
priv, pub, err := GenKeyPair()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, pub, ServerAEADSeed(pub))
|
||||
require.Equal(t, pub, ServerAEADSeed(priv))
|
||||
}
|
||||
}
|
||||
206
transport/sudoku/kip.go
Normal file
206
transport/sudoku/kip.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package sudoku
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
kipMagic = "kip"
|
||||
|
||||
KIPTypeClientHello byte = 0x01
|
||||
KIPTypeServerHello byte = 0x02
|
||||
|
||||
KIPTypeOpenTCP byte = 0x10
|
||||
KIPTypeStartMux byte = 0x11
|
||||
KIPTypeStartUoT byte = 0x12
|
||||
KIPTypeKeepAlive byte = 0x14
|
||||
)
|
||||
|
||||
// KIP feature bits are advisory capability flags negotiated during the handshake.
|
||||
// They represent control-plane message families.
|
||||
const (
|
||||
KIPFeatOpenTCP uint32 = 1 << 0
|
||||
KIPFeatMux uint32 = 1 << 1
|
||||
KIPFeatUoT uint32 = 1 << 2
|
||||
KIPFeatKeepAlive uint32 = 1 << 4
|
||||
|
||||
KIPFeatAll = KIPFeatOpenTCP | KIPFeatMux | KIPFeatUoT | KIPFeatKeepAlive
|
||||
)
|
||||
|
||||
const (
|
||||
kipHelloUserHashSize = 8
|
||||
kipHelloNonceSize = 16
|
||||
kipHelloPubSize = 32
|
||||
kipMaxPayload = 64 * 1024
|
||||
)
|
||||
|
||||
var errKIP = errors.New("kip protocol error")
|
||||
|
||||
type KIPMessage struct {
|
||||
Type byte
|
||||
Payload []byte
|
||||
}
|
||||
|
||||
func WriteKIPMessage(w io.Writer, typ byte, payload []byte) error {
|
||||
if w == nil {
|
||||
return fmt.Errorf("%w: nil writer", errKIP)
|
||||
}
|
||||
if len(payload) > kipMaxPayload {
|
||||
return fmt.Errorf("%w: payload too large: %d", errKIP, len(payload))
|
||||
}
|
||||
|
||||
var hdr [3 + 1 + 2]byte
|
||||
copy(hdr[:3], []byte(kipMagic))
|
||||
hdr[3] = typ
|
||||
binary.BigEndian.PutUint16(hdr[4:], uint16(len(payload)))
|
||||
|
||||
if err := writeFull(w, hdr[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(payload) == 0 {
|
||||
return nil
|
||||
}
|
||||
return writeFull(w, payload)
|
||||
}
|
||||
|
||||
func ReadKIPMessage(r io.Reader) (*KIPMessage, error) {
|
||||
if r == nil {
|
||||
return nil, fmt.Errorf("%w: nil reader", errKIP)
|
||||
}
|
||||
var hdr [3 + 1 + 2]byte
|
||||
if _, err := io.ReadFull(r, hdr[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if string(hdr[:3]) != kipMagic {
|
||||
return nil, fmt.Errorf("%w: bad magic", errKIP)
|
||||
}
|
||||
typ := hdr[3]
|
||||
n := int(binary.BigEndian.Uint16(hdr[4:]))
|
||||
if n < 0 || n > kipMaxPayload {
|
||||
return nil, fmt.Errorf("%w: invalid payload length: %d", errKIP, n)
|
||||
}
|
||||
var payload []byte
|
||||
if n > 0 {
|
||||
payload = make([]byte, n)
|
||||
if _, err := io.ReadFull(r, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return &KIPMessage{Type: typ, Payload: payload}, nil
|
||||
}
|
||||
|
||||
type KIPClientHello struct {
|
||||
Timestamp time.Time
|
||||
UserHash [kipHelloUserHashSize]byte
|
||||
Nonce [kipHelloNonceSize]byte
|
||||
ClientPub [kipHelloPubSize]byte
|
||||
Features uint32
|
||||
}
|
||||
|
||||
type KIPServerHello struct {
|
||||
Nonce [kipHelloNonceSize]byte
|
||||
ServerPub [kipHelloPubSize]byte
|
||||
SelectedFeats uint32
|
||||
}
|
||||
|
||||
func kipUserHashFromKey(psk string) [kipHelloUserHashSize]byte {
|
||||
var out [kipHelloUserHashSize]byte
|
||||
psk = strings.TrimSpace(psk)
|
||||
if psk == "" {
|
||||
return out
|
||||
}
|
||||
|
||||
// Align with upstream: when the client carries private key material (or even just a public key),
|
||||
// prefer hashing the raw hex bytes so different split/master keys can be distinguished.
|
||||
if keyBytes, err := hex.DecodeString(psk); err == nil && len(keyBytes) > 0 {
|
||||
sum := sha256.Sum256(keyBytes)
|
||||
copy(out[:], sum[:kipHelloUserHashSize])
|
||||
return out
|
||||
}
|
||||
|
||||
sum := sha256.Sum256([]byte(psk))
|
||||
copy(out[:], sum[:kipHelloUserHashSize])
|
||||
return out
|
||||
}
|
||||
|
||||
func KIPUserHashHexFromKey(psk string) string {
|
||||
uh := kipUserHashFromKey(psk)
|
||||
return hex.EncodeToString(uh[:])
|
||||
}
|
||||
|
||||
func (m *KIPClientHello) EncodePayload() []byte {
|
||||
var b bytes.Buffer
|
||||
var tmp [8]byte
|
||||
binary.BigEndian.PutUint64(tmp[:], uint64(m.Timestamp.Unix()))
|
||||
b.Write(tmp[:])
|
||||
b.Write(m.UserHash[:])
|
||||
b.Write(m.Nonce[:])
|
||||
b.Write(m.ClientPub[:])
|
||||
var f [4]byte
|
||||
binary.BigEndian.PutUint32(f[:], m.Features)
|
||||
b.Write(f[:])
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
func DecodeKIPClientHelloPayload(payload []byte) (*KIPClientHello, error) {
|
||||
const minLen = 8 + kipHelloUserHashSize + kipHelloNonceSize + kipHelloPubSize + 4
|
||||
if len(payload) < minLen {
|
||||
return nil, fmt.Errorf("%w: client hello too short", errKIP)
|
||||
}
|
||||
var h KIPClientHello
|
||||
ts := int64(binary.BigEndian.Uint64(payload[:8]))
|
||||
h.Timestamp = time.Unix(ts, 0)
|
||||
off := 8
|
||||
copy(h.UserHash[:], payload[off:off+kipHelloUserHashSize])
|
||||
off += kipHelloUserHashSize
|
||||
copy(h.Nonce[:], payload[off:off+kipHelloNonceSize])
|
||||
off += kipHelloNonceSize
|
||||
copy(h.ClientPub[:], payload[off:off+kipHelloPubSize])
|
||||
off += kipHelloPubSize
|
||||
h.Features = binary.BigEndian.Uint32(payload[off : off+4])
|
||||
return &h, nil
|
||||
}
|
||||
|
||||
func (m *KIPServerHello) EncodePayload() []byte {
|
||||
var b bytes.Buffer
|
||||
b.Write(m.Nonce[:])
|
||||
b.Write(m.ServerPub[:])
|
||||
var f [4]byte
|
||||
binary.BigEndian.PutUint32(f[:], m.SelectedFeats)
|
||||
b.Write(f[:])
|
||||
return b.Bytes()
|
||||
}
|
||||
|
||||
func DecodeKIPServerHelloPayload(payload []byte) (*KIPServerHello, error) {
|
||||
const want = kipHelloNonceSize + kipHelloPubSize + 4
|
||||
if len(payload) != want {
|
||||
return nil, fmt.Errorf("%w: server hello bad len: %d", errKIP, len(payload))
|
||||
}
|
||||
var h KIPServerHello
|
||||
off := 0
|
||||
copy(h.Nonce[:], payload[off:off+kipHelloNonceSize])
|
||||
off += kipHelloNonceSize
|
||||
copy(h.ServerPub[:], payload[off:off+kipHelloPubSize])
|
||||
off += kipHelloPubSize
|
||||
h.SelectedFeats = binary.BigEndian.Uint32(payload[off : off+4])
|
||||
return &h, nil
|
||||
}
|
||||
|
||||
func writeFull(w io.Writer, b []byte) error {
|
||||
for len(b) > 0 {
|
||||
n, err := w.Write(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b = b[n:]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -10,19 +10,14 @@ import (
|
||||
"github.com/metacubex/mihomo/transport/sudoku/multiplex"
|
||||
)
|
||||
|
||||
const (
|
||||
MultiplexMagicByte byte = multiplex.MagicByte
|
||||
MultiplexVersion byte = multiplex.Version
|
||||
)
|
||||
|
||||
// StartMultiplexClient writes the multiplex preface and upgrades an already-handshaked Sudoku tunnel into a multiplex session.
|
||||
// StartMultiplexClient upgrades an already-handshaked Sudoku tunnel into a multiplex session.
|
||||
func StartMultiplexClient(conn net.Conn) (*MultiplexClient, error) {
|
||||
if conn == nil {
|
||||
return nil, fmt.Errorf("nil conn")
|
||||
}
|
||||
|
||||
if err := multiplex.WritePreface(conn); err != nil {
|
||||
return nil, fmt.Errorf("write multiplex preface failed: %w", err)
|
||||
if err := WriteKIPMessage(conn, KIPTypeStartMux, nil); err != nil {
|
||||
return nil, fmt.Errorf("write mux start failed: %w", err)
|
||||
}
|
||||
|
||||
sess, err := multiplex.NewClientSession(conn)
|
||||
@@ -77,20 +72,10 @@ func (c *MultiplexClient) IsClosed() bool {
|
||||
}
|
||||
|
||||
// AcceptMultiplexServer upgrades a server-side, already-handshaked Sudoku connection into a multiplex session.
|
||||
//
|
||||
// The caller must have already consumed the multiplex magic byte (MultiplexMagicByte). This function consumes the
|
||||
// multiplex version byte and starts the session.
|
||||
func AcceptMultiplexServer(conn net.Conn) (*MultiplexServer, error) {
|
||||
if conn == nil {
|
||||
return nil, fmt.Errorf("nil conn")
|
||||
}
|
||||
v, err := multiplex.ReadVersion(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := multiplex.ValidateVersion(v); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sess, err := multiplex.NewServerSession(conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
package multiplex
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
const (
|
||||
// MagicByte marks a Sudoku tunnel connection that will switch into multiplex mode.
|
||||
// It is sent after the Sudoku handshake + downlink mode byte.
|
||||
//
|
||||
// Keep it distinct from UoTMagicByte and address type bytes.
|
||||
MagicByte byte = 0xED
|
||||
Version byte = 0x01
|
||||
)
|
||||
|
||||
func WritePreface(w io.Writer) error {
|
||||
if w == nil {
|
||||
return fmt.Errorf("nil writer")
|
||||
}
|
||||
_, err := w.Write([]byte{MagicByte, Version})
|
||||
return err
|
||||
}
|
||||
|
||||
func ReadVersion(r io.Reader) (byte, error) {
|
||||
var b [1]byte
|
||||
if _, err := io.ReadFull(r, b[:]); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return b[0], nil
|
||||
}
|
||||
|
||||
func ValidateVersion(v byte) error {
|
||||
if v != Version {
|
||||
return fmt.Errorf("unsupported multiplex version: %d", v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -18,7 +18,6 @@ func TestUserHash_StableAcrossTableRotation(t *testing.T) {
|
||||
sudokuobfs.NewTable("seed-b", "prefer_ascii"),
|
||||
}
|
||||
key := "userhash-stability-key"
|
||||
target := "example.com:80"
|
||||
|
||||
serverCfg := DefaultConfig()
|
||||
serverCfg.Key = key
|
||||
@@ -48,13 +47,16 @@ func TestUserHash_StableAcrossTableRotation(t *testing.T) {
|
||||
}
|
||||
go func(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
session, err := ServerHandshake(conn, serverCfg)
|
||||
_, meta, err := ServerHandshake(conn, serverCfg)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
defer session.Conn.Close()
|
||||
hashCh <- session.UserHash
|
||||
if meta == nil || meta.UserHash == "" {
|
||||
errCh <- io.ErrUnexpectedEOF
|
||||
return
|
||||
}
|
||||
hashCh <- meta.UserHash
|
||||
}(c)
|
||||
}
|
||||
}()
|
||||
@@ -77,15 +79,6 @@ func TestUserHash_StableAcrossTableRotation(t *testing.T) {
|
||||
t.Fatalf("handshake %d: %v", i, err)
|
||||
}
|
||||
|
||||
addrBuf, err := EncodeAddress(target)
|
||||
if err != nil {
|
||||
_ = cConn.Close()
|
||||
t.Fatalf("encode addr %d: %v", i, err)
|
||||
}
|
||||
if _, err := cConn.Write(addrBuf); err != nil {
|
||||
_ = cConn.Close()
|
||||
t.Fatalf("write addr %d: %v", i, err)
|
||||
}
|
||||
_ = cConn.Close()
|
||||
}
|
||||
|
||||
@@ -145,18 +138,22 @@ func TestMultiplex_TCP_Echo(t *testing.T) {
|
||||
}
|
||||
defer raw.Close()
|
||||
|
||||
session, err := ServerHandshake(raw, serverCfg)
|
||||
c, meta, err := ServerHandshake(raw, serverCfg)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
atomic.AddInt64(&handshakes, 1)
|
||||
|
||||
session, err := ReadServerSession(c, meta)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if session.Type != SessionTypeMultiplex {
|
||||
_ = session.Conn.Close()
|
||||
_ = c.Close()
|
||||
return
|
||||
}
|
||||
|
||||
mux, err := AcceptMultiplexServer(session.Conn)
|
||||
mux, err := AcceptMultiplexServer(c)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -240,21 +237,3 @@ func TestMultiplex_TCP_Echo(t *testing.T) {
|
||||
t.Fatalf("unexpected stream count: %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultiplex_Boundary_InvalidVersion(t *testing.T) {
|
||||
client, server := net.Pipe()
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
t.Cleanup(func() { _ = server.Close() })
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := AcceptMultiplexServer(server)
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
// AcceptMultiplexServer expects the magic byte to have been consumed already; write a bad version byte.
|
||||
_, _ = client.Write([]byte{0xFF})
|
||||
if err := <-errCh; err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,9 +6,10 @@ import (
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"github.com/metacubex/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/metacubex/http"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/metacubex/mihomo/component/ca"
|
||||
@@ -32,6 +33,7 @@ const (
|
||||
TunnelModeStream TunnelMode = "stream"
|
||||
TunnelModePoll TunnelMode = "poll"
|
||||
TunnelModeAuto TunnelMode = "auto"
|
||||
TunnelModeWS TunnelMode = "ws"
|
||||
)
|
||||
|
||||
func normalizeTunnelMode(mode string) TunnelMode {
|
||||
@@ -44,6 +46,8 @@ func normalizeTunnelMode(mode string) TunnelMode {
|
||||
return TunnelModePoll
|
||||
case string(TunnelModeAuto):
|
||||
return TunnelModeAuto
|
||||
case string(TunnelModeWS):
|
||||
return TunnelModeWS
|
||||
default:
|
||||
// Be conservative: unknown => legacy
|
||||
return TunnelModeLegacy
|
||||
@@ -88,7 +92,6 @@ type TunnelClientOptions struct {
|
||||
}
|
||||
|
||||
type TunnelClient struct {
|
||||
client *http.Client
|
||||
transport *http.Transport
|
||||
target httpClientTarget
|
||||
}
|
||||
@@ -105,7 +108,6 @@ func NewTunnelClient(serverAddress string, opts TunnelClientOptions) (*TunnelCli
|
||||
}
|
||||
|
||||
return &TunnelClient{
|
||||
client: &http.Client{Transport: transport},
|
||||
transport: transport,
|
||||
target: target,
|
||||
}, nil
|
||||
@@ -119,7 +121,7 @@ func (c *TunnelClient) CloseIdleConnections() {
|
||||
}
|
||||
|
||||
func (c *TunnelClient) DialTunnel(ctx context.Context, opts TunnelDialOptions) (net.Conn, error) {
|
||||
if c == nil || c.client == nil {
|
||||
if c == nil || c.transport == nil {
|
||||
return nil, fmt.Errorf("nil tunnel client")
|
||||
}
|
||||
tm := normalizeTunnelMode(opts.Mode)
|
||||
@@ -127,25 +129,31 @@ func (c *TunnelClient) DialTunnel(ctx context.Context, opts TunnelDialOptions) (
|
||||
return nil, fmt.Errorf("legacy mode does not use http tunnel")
|
||||
}
|
||||
|
||||
// Create a per-dial client while sharing the underlying Transport for connection reuse.
|
||||
// This matches upstream behavior and avoids potential client-level concurrency pitfalls.
|
||||
client := &http.Client{Transport: c.transport}
|
||||
|
||||
switch tm {
|
||||
case TunnelModeStream:
|
||||
return dialStreamWithClient(ctx, c.client, c.target, opts)
|
||||
return dialStreamWithClient(ctx, client, c.target, opts)
|
||||
case TunnelModePoll:
|
||||
return dialPollWithClient(ctx, c.client, c.target, opts)
|
||||
return dialPollWithClient(ctx, client, c.target, opts)
|
||||
case TunnelModeWS:
|
||||
return nil, fmt.Errorf("ws mode does not support TunnelClient reuse")
|
||||
case TunnelModeAuto:
|
||||
streamCtx, cancelX := context.WithTimeout(ctx, 3*time.Second)
|
||||
c1, errX := dialStreamWithClient(streamCtx, c.client, c.target, opts)
|
||||
c1, errX := dialStreamWithClient(streamCtx, client, c.target, opts)
|
||||
cancelX()
|
||||
if errX == nil {
|
||||
return c1, nil
|
||||
}
|
||||
c2, errP := dialPollWithClient(ctx, c.client, c.target, opts)
|
||||
c2, errP := dialPollWithClient(ctx, client, c.target, opts)
|
||||
if errP == nil {
|
||||
return c2, nil
|
||||
}
|
||||
return nil, fmt.Errorf("auto tunnel failed: stream: %v; poll: %w", errX, errP)
|
||||
default:
|
||||
return dialStreamWithClient(ctx, c.client, c.target, opts)
|
||||
return dialStreamWithClient(ctx, client, c.target, opts)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -166,6 +174,8 @@ func DialTunnel(ctx context.Context, serverAddress string, opts TunnelDialOption
|
||||
return dialStreamFn(ctx, serverAddress, opts)
|
||||
case TunnelModePoll:
|
||||
return dialPollFn(ctx, serverAddress, opts)
|
||||
case TunnelModeWS:
|
||||
return dialWS(ctx, serverAddress, opts)
|
||||
case TunnelModeAuto:
|
||||
// "stream" can hang on some CDNs that buffer uploads until request body completes.
|
||||
// Keep it on a short leash so we can fall back to poll within the caller's deadline.
|
||||
@@ -306,6 +316,36 @@ type sessionDialInfo struct {
|
||||
auth *tunnelAuth
|
||||
}
|
||||
|
||||
type httpStatusError struct {
|
||||
code int
|
||||
status string
|
||||
}
|
||||
|
||||
func (e *httpStatusError) Error() string {
|
||||
if e == nil {
|
||||
return "bad status"
|
||||
}
|
||||
if e.status != "" {
|
||||
return "bad status: " + e.status
|
||||
}
|
||||
return "bad status"
|
||||
}
|
||||
|
||||
func isRetryableStatusCode(code int) bool {
|
||||
return code == http.StatusRequestTimeout || code == http.StatusTooManyRequests || code >= 500
|
||||
}
|
||||
|
||||
type idleConnCloser interface{ CloseIdleConnections() }
|
||||
|
||||
func closeIdleConnections(client *http.Client) {
|
||||
if client == nil || client.Transport == nil {
|
||||
return
|
||||
}
|
||||
if c, ok := client.Transport.(idleConnCloser); ok {
|
||||
c.CloseIdleConnections()
|
||||
}
|
||||
}
|
||||
|
||||
func dialSessionWithClient(ctx context.Context, client *http.Client, target httpClientTarget, mode TunnelMode, opts TunnelDialOptions) (*sessionDialInfo, error) {
|
||||
if client == nil {
|
||||
return nil, fmt.Errorf("nil http client")
|
||||
@@ -313,25 +353,61 @@ func dialSessionWithClient(ctx context.Context, client *http.Client, target http
|
||||
|
||||
auth := newTunnelAuth(opts.AuthKey, 0)
|
||||
authorizeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/session")}).String()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, authorizeURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Host = target.headerHost
|
||||
applyTunnelHeaders(req.Header, target.headerHost, mode)
|
||||
applyTunnelAuth(req, auth, mode, http.MethodGet, "/session")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, 4*1024))
|
||||
_ = resp.Body.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("%s authorize bad status: %s (%s)", mode, resp.Status, strings.TrimSpace(string(bodyBytes)))
|
||||
var bodyBytes []byte
|
||||
for attempt := 0; ; attempt++ {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, authorizeURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Host = target.headerHost
|
||||
applyTunnelHeaders(req.Header, target.headerHost, mode)
|
||||
applyTunnelAuth(req, auth, mode, http.MethodGet, "/session")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
// Transient failure on reused keep-alive conns (multiplex=auto). Retry a few times.
|
||||
if attempt < 2 && (isDialError(err) || isRetryableRequestError(err)) {
|
||||
closeIdleConnections(client)
|
||||
select {
|
||||
case <-time.After(25 * time.Millisecond):
|
||||
continue
|
||||
case <-ctx.Done():
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bodyBytes, err = io.ReadAll(io.LimitReader(resp.Body, 4*1024))
|
||||
_ = resp.Body.Close()
|
||||
if err != nil {
|
||||
if attempt < 2 && isRetryableRequestError(err) {
|
||||
closeIdleConnections(client)
|
||||
select {
|
||||
case <-time.After(25 * time.Millisecond):
|
||||
continue
|
||||
case <-ctx.Done():
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
// Retry some transient proxy/CDN errors.
|
||||
if attempt < 2 && resp.StatusCode >= 500 {
|
||||
closeIdleConnections(client)
|
||||
select {
|
||||
case <-time.After(25 * time.Millisecond):
|
||||
continue
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("%s authorize bad status: %s (%s)", mode, resp.Status, strings.TrimSpace(string(bodyBytes)))
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("%s authorize bad status: %s (%s)", mode, resp.Status, strings.TrimSpace(string(bodyBytes)))
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
token, err := parseTunnelToken(bodyBytes)
|
||||
@@ -544,9 +620,8 @@ type streamSplitConn struct {
|
||||
auth *tunnelAuth
|
||||
}
|
||||
|
||||
func (c *streamSplitConn) Close() error {
|
||||
_ = c.closeWithError(io.ErrClosedPipe)
|
||||
|
||||
func (c *streamSplitConn) closeWithError(err error) error {
|
||||
_ = c.queuedConn.closeWithError(err)
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
@@ -554,6 +629,8 @@ func (c *streamSplitConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *streamSplitConn) Close() error { return c.closeWithError(io.ErrClosedPipe) }
|
||||
|
||||
func newStreamSplitConnFromInfo(info *sessionDialInfo) *streamSplitConn {
|
||||
if info == nil {
|
||||
return nil
|
||||
@@ -659,7 +736,7 @@ func (c *streamSplitConn) pullLoop() {
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, c.pullURL, nil)
|
||||
if err != nil {
|
||||
cancel()
|
||||
_ = c.Close()
|
||||
_ = c.closeWithError(fmt.Errorf("stream pull build request failed: %w", err))
|
||||
return
|
||||
}
|
||||
req.Host = c.headerHost
|
||||
@@ -669,8 +746,9 @@ func (c *streamSplitConn) pullLoop() {
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
cancel()
|
||||
if isDialError(err) && dialRetry < maxDialRetry {
|
||||
if (isDialError(err) || isRetryableRequestError(err)) && dialRetry < maxDialRetry {
|
||||
dialRetry++
|
||||
closeIdleConnections(c.client)
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-c.closed:
|
||||
@@ -682,16 +760,33 @@ func (c *streamSplitConn) pullLoop() {
|
||||
}
|
||||
continue
|
||||
}
|
||||
_ = c.Close()
|
||||
_ = c.closeWithError(fmt.Errorf("stream pull request failed: %w", err))
|
||||
return
|
||||
}
|
||||
dialRetry = 0
|
||||
backoff = minBackoff
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if isRetryableStatusCode(resp.StatusCode) && dialRetry < maxDialRetry {
|
||||
dialRetry++
|
||||
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024))
|
||||
_ = resp.Body.Close()
|
||||
cancel()
|
||||
closeIdleConnections(c.client)
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-c.closed:
|
||||
return
|
||||
}
|
||||
backoff *= 2
|
||||
if backoff > maxBackoff {
|
||||
backoff = maxBackoff
|
||||
}
|
||||
continue
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
cancel()
|
||||
_ = c.Close()
|
||||
_ = c.closeWithError(fmt.Errorf("stream pull bad status: %s", resp.Status))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -717,7 +812,12 @@ func (c *streamSplitConn) pullLoop() {
|
||||
// Long-poll ended; retry.
|
||||
break
|
||||
}
|
||||
_ = c.Close()
|
||||
// Some environments may sporadically reset the HTTP connection under load; treat
|
||||
// it as an ended long-poll and retry instead of tearing down the whole tunnel.
|
||||
if errors.Is(rerr, io.ErrUnexpectedEOF) || isRetryableRequestError(rerr) {
|
||||
break
|
||||
}
|
||||
_ = c.closeWithError(fmt.Errorf("stream pull read failed: %w", rerr))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -735,8 +835,13 @@ func (c *streamSplitConn) pullLoop() {
|
||||
|
||||
func (c *streamSplitConn) pushLoop() {
|
||||
const (
|
||||
maxBatchBytes = 256 * 1024
|
||||
flushInterval = 5 * time.Millisecond
|
||||
// Batching is critical for stability under high concurrency: every flush is a new TCP
|
||||
// connection in HTTP/1.1, and too many tiny uploads can overwhelm the accept backlog,
|
||||
// causing sporadic RSTs (connection reset by peer).
|
||||
//
|
||||
// Keep this below the server-side maxUploadBytes limit in streamPush().
|
||||
maxBatchBytes = 512 * 1024
|
||||
flushInterval = 25 * time.Millisecond
|
||||
requestTimeout = 20 * time.Second
|
||||
maxDialRetry = 12
|
||||
minBackoff = 10 * time.Millisecond
|
||||
@@ -754,12 +859,18 @@ func (c *streamSplitConn) pushLoop() {
|
||||
return nil
|
||||
}
|
||||
|
||||
payload := buf.Bytes()
|
||||
reqCtx, cancel := context.WithTimeout(c.ctx, requestTimeout)
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.pushURL, bytes.NewReader(buf.Bytes()))
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.pushURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
cancel()
|
||||
return err
|
||||
}
|
||||
// Be explicit: some http client forks won't auto-populate GetBody, which makes POST retries on stale
|
||||
// keep-alive connections flaky under multiplex=auto.
|
||||
req.GetBody = func() (io.ReadCloser, error) {
|
||||
return io.NopCloser(bytes.NewReader(payload)), nil
|
||||
}
|
||||
req.Host = c.headerHost
|
||||
applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream)
|
||||
applyTunnelAuth(req, c.auth, TunnelModeStream, http.MethodPost, "/api/v1/upload")
|
||||
@@ -774,7 +885,7 @@ func (c *streamSplitConn) pushLoop() {
|
||||
_ = resp.Body.Close()
|
||||
cancel()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("bad status: %s", resp.Status)
|
||||
return &httpStatusError{code: resp.StatusCode, status: resp.Status}
|
||||
}
|
||||
|
||||
buf.Reset()
|
||||
@@ -787,8 +898,22 @@ func (c *streamSplitConn) pushLoop() {
|
||||
for {
|
||||
if err := flush(); err == nil {
|
||||
return nil
|
||||
} else if isDialError(err) && dialRetry < maxDialRetry {
|
||||
} else if se := (*httpStatusError)(nil); errors.As(err, &se) && isRetryableStatusCode(se.code) && dialRetry < maxDialRetry {
|
||||
dialRetry++
|
||||
closeIdleConnections(c.client)
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-c.closed:
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
backoff *= 2
|
||||
if backoff > maxBackoff {
|
||||
backoff = maxBackoff
|
||||
}
|
||||
continue
|
||||
} else if (isDialError(err) || isRetryableRequestError(err)) && dialRetry < maxDialRetry {
|
||||
dialRetry++
|
||||
closeIdleConnections(c.client)
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-c.closed:
|
||||
@@ -829,7 +954,7 @@ func (c *streamSplitConn) pushLoop() {
|
||||
}
|
||||
if buf.Len()+len(b) > maxBatchBytes {
|
||||
if err := flushWithRetry(); err != nil {
|
||||
_ = c.Close()
|
||||
_ = c.closeWithError(fmt.Errorf("stream push flush failed: %w", err))
|
||||
return
|
||||
}
|
||||
resetTimer()
|
||||
@@ -837,14 +962,14 @@ func (c *streamSplitConn) pushLoop() {
|
||||
_, _ = buf.Write(b)
|
||||
if buf.Len() >= maxBatchBytes {
|
||||
if err := flushWithRetry(); err != nil {
|
||||
_ = c.Close()
|
||||
_ = c.closeWithError(fmt.Errorf("stream push flush failed: %w", err))
|
||||
return
|
||||
}
|
||||
resetTimer()
|
||||
}
|
||||
case <-timer.C:
|
||||
if err := flushWithRetry(); err != nil {
|
||||
_ = c.Close()
|
||||
_ = c.closeWithError(fmt.Errorf("stream push flush failed: %w", err))
|
||||
return
|
||||
}
|
||||
resetTimer()
|
||||
@@ -858,7 +983,7 @@ func (c *streamSplitConn) pushLoop() {
|
||||
}
|
||||
if buf.Len()+len(b) > maxBatchBytes {
|
||||
if err := flushWithRetry(); err != nil {
|
||||
_ = c.Close()
|
||||
_ = c.closeWithError(fmt.Errorf("stream push flush failed: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -905,6 +1030,43 @@ func isDialError(err error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func isRetryableRequestError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return true
|
||||
}
|
||||
// net/http may return this when reusing a keep-alive conn that the peer already closed.
|
||||
// Treat it as retryable: callers already implement bounded backoff retries.
|
||||
if strings.Contains(strings.ToLower(err.Error()), "server closed idle connection") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Unwrap common wrappers.
|
||||
var urlErr *url.Error
|
||||
if errors.As(err, &urlErr) {
|
||||
return isRetryableRequestError(urlErr.Err)
|
||||
}
|
||||
|
||||
// Connection-level transient failures.
|
||||
if errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) {
|
||||
return true
|
||||
}
|
||||
if errors.Is(err, io.ErrClosedPipe) || errors.Is(err, net.ErrClosed) {
|
||||
return true
|
||||
}
|
||||
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) {
|
||||
return netErr.Timeout() || netErr.Temporary()
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *pollConn) closeWithError(err error) error {
|
||||
_ = c.queuedConn.closeWithError(err)
|
||||
if c.cancel != nil {
|
||||
@@ -1012,8 +1174,10 @@ func (c *pollConn) pullLoop() {
|
||||
default:
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, c.pullURL, nil)
|
||||
reqCtx, cancel := context.WithTimeout(c.ctx, 30*time.Second)
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, c.pullURL, nil)
|
||||
if err != nil {
|
||||
cancel()
|
||||
_ = c.Close()
|
||||
return
|
||||
}
|
||||
@@ -1023,8 +1187,10 @@ func (c *pollConn) pullLoop() {
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
if isDialError(err) && dialRetry < maxDialRetry {
|
||||
cancel()
|
||||
if (isDialError(err) || isRetryableRequestError(err)) && dialRetry < maxDialRetry {
|
||||
dialRetry++
|
||||
closeIdleConnections(c.client)
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-c.closed:
|
||||
@@ -1043,7 +1209,25 @@ func (c *pollConn) pullLoop() {
|
||||
backoff = minBackoff
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if isRetryableStatusCode(resp.StatusCode) && dialRetry < maxDialRetry {
|
||||
dialRetry++
|
||||
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024))
|
||||
_ = resp.Body.Close()
|
||||
cancel()
|
||||
closeIdleConnections(c.client)
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-c.closed:
|
||||
return
|
||||
}
|
||||
backoff *= 2
|
||||
if backoff > maxBackoff {
|
||||
backoff = maxBackoff
|
||||
}
|
||||
continue
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
cancel()
|
||||
_ = c.closeWithError(fmt.Errorf("poll pull bad status: %s", resp.Status))
|
||||
return
|
||||
}
|
||||
@@ -1068,7 +1252,12 @@ func (c *pollConn) pullLoop() {
|
||||
}
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
cancel()
|
||||
if err := scanner.Err(); err != nil {
|
||||
// Treat transient stream breaks (RST/EOF) as an ended long-poll and retry.
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) || isRetryableRequestError(err) {
|
||||
continue
|
||||
}
|
||||
_ = c.closeWithError(fmt.Errorf("poll pull scan failed: %w", err))
|
||||
return
|
||||
}
|
||||
@@ -1077,8 +1266,8 @@ func (c *pollConn) pullLoop() {
|
||||
|
||||
func (c *pollConn) pushLoop() {
|
||||
const (
|
||||
maxBatchBytes = 64 * 1024
|
||||
flushInterval = 5 * time.Millisecond
|
||||
maxBatchBytes = 512 * 1024
|
||||
flushInterval = 50 * time.Millisecond
|
||||
maxLineRawBytes = 16 * 1024
|
||||
maxDialRetry = 12
|
||||
minBackoff = 10 * time.Millisecond
|
||||
@@ -1097,12 +1286,16 @@ func (c *pollConn) pushLoop() {
|
||||
return nil
|
||||
}
|
||||
|
||||
payload := buf.Bytes()
|
||||
reqCtx, cancel := context.WithTimeout(c.ctx, 20*time.Second)
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.pushURL, bytes.NewReader(buf.Bytes()))
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.pushURL, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
cancel()
|
||||
return err
|
||||
}
|
||||
req.GetBody = func() (io.ReadCloser, error) {
|
||||
return io.NopCloser(bytes.NewReader(payload)), nil
|
||||
}
|
||||
req.Host = c.headerHost
|
||||
applyTunnelHeaders(req.Header, c.headerHost, TunnelModePoll)
|
||||
applyTunnelAuth(req, c.auth, TunnelModePoll, http.MethodPost, "/api/v1/upload")
|
||||
@@ -1117,7 +1310,7 @@ func (c *pollConn) pushLoop() {
|
||||
_ = resp.Body.Close()
|
||||
cancel()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("bad status: %s", resp.Status)
|
||||
return &httpStatusError{code: resp.StatusCode, status: resp.Status}
|
||||
}
|
||||
|
||||
buf.Reset()
|
||||
@@ -1131,8 +1324,22 @@ func (c *pollConn) pushLoop() {
|
||||
for {
|
||||
if err := flush(); err == nil {
|
||||
return nil
|
||||
} else if isDialError(err) && dialRetry < maxDialRetry {
|
||||
} else if se := (*httpStatusError)(nil); errors.As(err, &se) && isRetryableStatusCode(se.code) && dialRetry < maxDialRetry {
|
||||
dialRetry++
|
||||
closeIdleConnections(c.client)
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-c.closed:
|
||||
return c.closedErr()
|
||||
}
|
||||
backoff *= 2
|
||||
if backoff > maxBackoff {
|
||||
backoff = maxBackoff
|
||||
}
|
||||
continue
|
||||
} else if (isDialError(err) || isRetryableRequestError(err)) && dialRetry < maxDialRetry {
|
||||
dialRetry++
|
||||
closeIdleConnections(c.client)
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
case <-c.closed:
|
||||
@@ -1482,6 +1689,16 @@ func (s *TunnelServer) HandleConn(rawConn net.Conn) (HandleResult, net.Conn, err
|
||||
return HandleDone, nil, nil
|
||||
}
|
||||
return s.handlePoll(rawConn, req, headerBytes, buffered)
|
||||
case TunnelModeWS:
|
||||
if s.mode != TunnelModeWS && s.mode != TunnelModeAuto {
|
||||
if s.passThroughOnReject {
|
||||
return reject()
|
||||
}
|
||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
|
||||
_ = rawConn.Close()
|
||||
return HandleDone, nil, nil
|
||||
}
|
||||
return s.handleWS(rawConn, req, headerBytes, buffered)
|
||||
default:
|
||||
if s.passThroughOnReject {
|
||||
return reject()
|
||||
|
||||
176
transport/sudoku/obfs/httpmask/tunnel_ws.go
Normal file
176
transport/sudoku/obfs/httpmask/tunnel_ws.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package httpmask
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
mrand "math/rand"
|
||||
"net"
|
||||
stdhttp "net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gobwas/ws"
|
||||
"github.com/metacubex/tls"
|
||||
)
|
||||
|
||||
func normalizeWSSchemeFromAddress(serverAddress string, tlsEnabled bool) (string, string) {
|
||||
addr := strings.TrimSpace(serverAddress)
|
||||
if strings.Contains(addr, "://") {
|
||||
if u, err := url.Parse(addr); err == nil && u != nil {
|
||||
switch strings.ToLower(strings.TrimSpace(u.Scheme)) {
|
||||
case "ws":
|
||||
return "ws", u.Host
|
||||
case "wss":
|
||||
return "wss", u.Host
|
||||
}
|
||||
}
|
||||
}
|
||||
if tlsEnabled {
|
||||
return "wss", addr
|
||||
}
|
||||
return "ws", addr
|
||||
}
|
||||
|
||||
func normalizeWSDialTarget(serverAddress string, tlsEnabled bool, hostOverride string) (scheme, urlHost, dialAddr, serverName string, err error) {
|
||||
scheme, addr := normalizeWSSchemeFromAddress(serverAddress, tlsEnabled)
|
||||
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
// Allow ws(s)://host without port.
|
||||
if strings.Contains(addr, ":") {
|
||||
return "", "", "", "", fmt.Errorf("invalid server address %q: %w", serverAddress, err)
|
||||
}
|
||||
switch scheme {
|
||||
case "wss":
|
||||
port = "443"
|
||||
default:
|
||||
port = "80"
|
||||
}
|
||||
host = addr
|
||||
}
|
||||
|
||||
if hostOverride != "" {
|
||||
// Allow "example.com" or "example.com:443"
|
||||
if h, p, splitErr := net.SplitHostPort(hostOverride); splitErr == nil {
|
||||
if h != "" {
|
||||
hostOverride = h
|
||||
}
|
||||
if p != "" {
|
||||
port = p
|
||||
}
|
||||
}
|
||||
serverName = hostOverride
|
||||
urlHost = net.JoinHostPort(hostOverride, port)
|
||||
} else {
|
||||
serverName = host
|
||||
urlHost = net.JoinHostPort(host, port)
|
||||
}
|
||||
|
||||
dialAddr = net.JoinHostPort(host, port)
|
||||
return scheme, urlHost, dialAddr, trimPortForHost(serverName), nil
|
||||
}
|
||||
|
||||
func applyWSHeaders(h stdhttp.Header, host string) {
|
||||
if h == nil {
|
||||
return
|
||||
}
|
||||
r := rngPool.Get().(*mrand.Rand)
|
||||
ua := userAgents[r.Intn(len(userAgents))]
|
||||
accept := accepts[r.Intn(len(accepts))]
|
||||
lang := acceptLanguages[r.Intn(len(acceptLanguages))]
|
||||
enc := acceptEncodings[r.Intn(len(acceptEncodings))]
|
||||
rngPool.Put(r)
|
||||
|
||||
h.Set("User-Agent", ua)
|
||||
h.Set("Accept", accept)
|
||||
h.Set("Accept-Language", lang)
|
||||
h.Set("Accept-Encoding", enc)
|
||||
h.Set("Cache-Control", "no-cache")
|
||||
h.Set("Pragma", "no-cache")
|
||||
h.Set("X-Sudoku-Tunnel", string(TunnelModeWS))
|
||||
h.Set("X-Sudoku-Version", "1")
|
||||
}
|
||||
|
||||
func dialWS(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
|
||||
if opts.DialContext == nil {
|
||||
panic("httpmask: DialContext is nil")
|
||||
}
|
||||
|
||||
scheme, urlHost, dialAddr, serverName, err := normalizeWSDialTarget(serverAddress, opts.TLSEnabled, opts.HostOverride)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpScheme := "http"
|
||||
if scheme == "wss" {
|
||||
httpScheme = "https"
|
||||
}
|
||||
headerHost := canonicalHeaderHost(urlHost, httpScheme)
|
||||
auth := newTunnelAuth(opts.AuthKey, 0)
|
||||
|
||||
u := &url.URL{
|
||||
Scheme: scheme,
|
||||
Host: urlHost,
|
||||
Path: joinPathRoot(opts.PathRoot, "/ws"),
|
||||
}
|
||||
|
||||
header := make(stdhttp.Header)
|
||||
applyWSHeaders(header, headerHost)
|
||||
|
||||
if auth != nil {
|
||||
token := auth.token(TunnelModeWS, stdhttp.MethodGet, "/ws", time.Now())
|
||||
if token != "" {
|
||||
header.Set("Authorization", "Bearer "+token)
|
||||
q := u.Query()
|
||||
q.Set(tunnelAuthQueryKey, token)
|
||||
u.RawQuery = q.Encode()
|
||||
}
|
||||
}
|
||||
|
||||
d := ws.Dialer{
|
||||
Host: headerHost,
|
||||
Header: ws.HandshakeHeaderHTTP(header),
|
||||
NetDial: func(dialCtx context.Context, network, addr string) (net.Conn, error) {
|
||||
if addr == urlHost {
|
||||
addr = dialAddr
|
||||
}
|
||||
return opts.DialContext(dialCtx, network, addr)
|
||||
},
|
||||
}
|
||||
if scheme == "wss" {
|
||||
tlsConfig := &tls.Config{
|
||||
ServerName: serverName,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
d.TLSClient = func(conn net.Conn, hostname string) net.Conn {
|
||||
return tls.Client(conn, tlsConfig)
|
||||
}
|
||||
}
|
||||
|
||||
conn, br, _, err := d.Dial(ctx, u.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if br != nil && br.Buffered() > 0 {
|
||||
pre := make([]byte, br.Buffered())
|
||||
_, _ = io.ReadFull(br, pre)
|
||||
conn = newPreBufferedConn(conn, pre)
|
||||
}
|
||||
|
||||
wsConn := newWSStreamConn(conn, ws.StateClientSide)
|
||||
if opts.Upgrade == nil {
|
||||
return wsConn, nil
|
||||
}
|
||||
upgraded, err := opts.Upgrade(wsConn)
|
||||
if err != nil {
|
||||
_ = wsConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
if upgraded != nil {
|
||||
return upgraded, nil
|
||||
}
|
||||
return wsConn, nil
|
||||
}
|
||||
77
transport/sudoku/obfs/httpmask/tunnel_ws_server.go
Normal file
77
transport/sudoku/obfs/httpmask/tunnel_ws_server.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package httpmask
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gobwas/ws"
|
||||
)
|
||||
|
||||
func looksLikeWebSocketUpgrade(headers map[string]string) bool {
|
||||
if headers == nil {
|
||||
return false
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(headers["upgrade"]), "websocket") {
|
||||
return false
|
||||
}
|
||||
conn := headers["connection"]
|
||||
for _, part := range strings.Split(conn, ",") {
|
||||
if strings.EqualFold(strings.TrimSpace(part), "upgrade") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *TunnelServer) handleWS(rawConn net.Conn, req *httpRequestHeader, headerBytes []byte, buffered []byte) (HandleResult, net.Conn, error) {
|
||||
rejectOrReply := func(code int, body string) (HandleResult, net.Conn, error) {
|
||||
if s.passThroughOnReject {
|
||||
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
|
||||
prefix = append(prefix, headerBytes...)
|
||||
prefix = append(prefix, buffered...)
|
||||
return HandlePassThrough, newRejectedPreBufferedConn(rawConn, prefix), nil
|
||||
}
|
||||
_ = writeSimpleHTTPResponse(rawConn, code, body)
|
||||
_ = rawConn.Close()
|
||||
return HandleDone, nil, nil
|
||||
}
|
||||
|
||||
u, err := url.ParseRequestURI(req.target)
|
||||
if err != nil {
|
||||
return rejectOrReply(http.StatusBadRequest, "bad request")
|
||||
}
|
||||
|
||||
path, ok := stripPathRoot(s.pathRoot, u.Path)
|
||||
if !ok || path != "/ws" {
|
||||
return rejectOrReply(http.StatusNotFound, "not found")
|
||||
}
|
||||
if strings.ToUpper(strings.TrimSpace(req.method)) != http.MethodGet {
|
||||
return rejectOrReply(http.StatusBadRequest, "bad request")
|
||||
}
|
||||
if !looksLikeWebSocketUpgrade(req.headers) {
|
||||
return rejectOrReply(http.StatusBadRequest, "bad request")
|
||||
}
|
||||
|
||||
authVal := req.headers["authorization"]
|
||||
if authVal == "" {
|
||||
authVal = u.Query().Get(tunnelAuthQueryKey)
|
||||
}
|
||||
if !s.auth.verifyValue(authVal, TunnelModeWS, req.method, path, time.Now()) {
|
||||
return rejectOrReply(http.StatusNotFound, "not found")
|
||||
}
|
||||
|
||||
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
|
||||
prefix = append(prefix, headerBytes...)
|
||||
prefix = append(prefix, buffered...)
|
||||
wsConnRaw := newPreBufferedConn(rawConn, prefix)
|
||||
|
||||
if _, err := ws.Upgrade(wsConnRaw); err != nil {
|
||||
_ = rawConn.Close()
|
||||
return HandleDone, nil, nil
|
||||
}
|
||||
|
||||
return HandleStartTunnel, newWSStreamConn(wsConnRaw, ws.StateServerSide), nil
|
||||
}
|
||||
78
transport/sudoku/obfs/httpmask/ws_stream_conn.go
Normal file
78
transport/sudoku/obfs/httpmask/ws_stream_conn.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package httpmask
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/gobwas/ws"
|
||||
"github.com/gobwas/ws/wsutil"
|
||||
)
|
||||
|
||||
type wsStreamConn struct {
|
||||
net.Conn
|
||||
state ws.State
|
||||
reader *wsutil.Reader
|
||||
controlHandler wsutil.FrameHandlerFunc
|
||||
}
|
||||
|
||||
func newWSStreamConn(conn net.Conn, state ws.State) net.Conn {
|
||||
controlHandler := wsutil.ControlFrameHandler(conn, state)
|
||||
return &wsStreamConn{
|
||||
Conn: conn,
|
||||
state: state,
|
||||
reader: &wsutil.Reader{
|
||||
Source: conn,
|
||||
State: state,
|
||||
},
|
||||
controlHandler: controlHandler,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *wsStreamConn) Read(b []byte) (n int, err error) {
|
||||
defer func() {
|
||||
if v := recover(); v != nil {
|
||||
err = fmt.Errorf("websocket error: %v", v)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
n, err = c.reader.Read(b)
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
}
|
||||
if !errors.Is(err, wsutil.ErrNoFrameAdvance) {
|
||||
return n, err
|
||||
}
|
||||
|
||||
hdr, err2 := c.reader.NextFrame()
|
||||
if err2 != nil {
|
||||
return 0, err2
|
||||
}
|
||||
if hdr.OpCode.IsControl() {
|
||||
if err := c.controlHandler(hdr, c.reader); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
if hdr.OpCode&(ws.OpBinary|ws.OpText) == 0 {
|
||||
if err := c.reader.Discard(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *wsStreamConn) Write(b []byte) (int, error) {
|
||||
if err := wsutil.WriteMessage(c.Conn, c.state, ws.OpBinary, b); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c *wsStreamConn) Close() error {
|
||||
_ = wsutil.WriteMessage(c.Conn, c.state, ws.OpClose, ws.NewCloseFrameBody(ws.StatusNormalClosure, ""))
|
||||
return c.Conn.Close()
|
||||
}
|
||||
@@ -4,9 +4,7 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"log"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -26,7 +24,7 @@ type Table struct {
|
||||
func NewTable(key string, mode string) *Table {
|
||||
t, err := NewTableWithCustom(key, mode, "")
|
||||
if err != nil {
|
||||
log.Panicf("failed to build table: %v", err)
|
||||
panic(err)
|
||||
}
|
||||
return t
|
||||
}
|
||||
@@ -35,8 +33,6 @@ func NewTable(key string, mode string) *Table {
|
||||
// mode: "prefer_ascii" or "prefer_entropy". If a custom pattern is provided, ASCII mode still takes precedence.
|
||||
// The customPattern must contain 8 characters with exactly 2 x, 2 p, and 4 v (case-insensitive).
|
||||
func NewTableWithCustom(key string, mode string, customPattern string) (*Table, error) {
|
||||
start := time.Now()
|
||||
|
||||
layout, err := resolveLayout(mode, customPattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -126,7 +122,6 @@ func NewTableWithCustom(key string, mode string, customPattern string) (*Table,
|
||||
}
|
||||
}
|
||||
}
|
||||
log.Printf("[Init] Sudoku Tables initialized (%s) in %v", layout.name, time.Since(start))
|
||||
return t, nil
|
||||
}
|
||||
|
||||
|
||||
74
transport/sudoku/replay.go
Normal file
74
transport/sudoku/replay.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package sudoku
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var handshakeReplayTTL = 60 * time.Second
|
||||
|
||||
type nonceSet struct {
|
||||
mu sync.Mutex
|
||||
m map[[kipHelloNonceSize]byte]time.Time
|
||||
maxEntries int
|
||||
lastPrune time.Time
|
||||
}
|
||||
|
||||
func newNonceSet(maxEntries int) *nonceSet {
|
||||
if maxEntries <= 0 {
|
||||
maxEntries = 4096
|
||||
}
|
||||
return &nonceSet{
|
||||
m: make(map[[kipHelloNonceSize]byte]time.Time),
|
||||
maxEntries: maxEntries,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *nonceSet) allow(nonce [kipHelloNonceSize]byte, now time.Time, ttl time.Duration) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if ttl <= 0 {
|
||||
ttl = 60 * time.Second
|
||||
}
|
||||
|
||||
if now.Sub(s.lastPrune) > ttl/2 || len(s.m) > s.maxEntries {
|
||||
for k, exp := range s.m {
|
||||
if !now.Before(exp) {
|
||||
delete(s.m, k)
|
||||
}
|
||||
}
|
||||
s.lastPrune = now
|
||||
for len(s.m) > s.maxEntries {
|
||||
for k := range s.m {
|
||||
delete(s.m, k)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if exp, ok := s.m[nonce]; ok && now.Before(exp) {
|
||||
return false
|
||||
}
|
||||
s.m[nonce] = now.Add(ttl)
|
||||
return true
|
||||
}
|
||||
|
||||
type handshakeReplayProtector struct {
|
||||
users sync.Map // map[userHash string]*nonceSet
|
||||
}
|
||||
|
||||
func (p *handshakeReplayProtector) allow(userHash string, nonce [kipHelloNonceSize]byte, now time.Time) bool {
|
||||
if userHash == "" {
|
||||
userHash = "_"
|
||||
}
|
||||
val, _ := p.users.LoadOrStore(userHash, newNonceSet(4096))
|
||||
set, ok := val.(*nonceSet)
|
||||
if !ok || set == nil {
|
||||
set = newNonceSet(4096)
|
||||
p.users.Store(userHash, set)
|
||||
}
|
||||
return set.allow(nonce, now, handshakeReplayTTL)
|
||||
}
|
||||
|
||||
var globalHandshakeReplay = &handshakeReplayProtector{}
|
||||
58
transport/sudoku/session_keys.go
Normal file
58
transport/sudoku/session_keys.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package sudoku
|
||||
|
||||
import (
|
||||
"crypto/ecdh"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
func derivePSKDirectionalBases(psk string) (c2s, s2c []byte) {
|
||||
sum := sha256.Sum256([]byte(psk))
|
||||
c2sKey := make([]byte, 32)
|
||||
s2cKey := make([]byte, 32)
|
||||
if _, err := io.ReadFull(hkdf.Expand(sha256.New, sum[:], []byte("sudoku-psk-c2s")), c2sKey); err != nil {
|
||||
panic("sudoku: hkdf expand failed")
|
||||
}
|
||||
if _, err := io.ReadFull(hkdf.Expand(sha256.New, sum[:], []byte("sudoku-psk-s2c")), s2cKey); err != nil {
|
||||
panic("sudoku: hkdf expand failed")
|
||||
}
|
||||
return c2sKey, s2cKey
|
||||
}
|
||||
|
||||
func deriveSessionDirectionalBases(psk string, shared []byte, nonce [kipHelloNonceSize]byte) (c2s, s2c []byte, err error) {
|
||||
sum := sha256.Sum256([]byte(psk))
|
||||
ikm := make([]byte, 0, len(shared)+len(nonce))
|
||||
ikm = append(ikm, shared...)
|
||||
ikm = append(ikm, nonce[:]...)
|
||||
|
||||
prk := hkdf.Extract(sha256.New, ikm, sum[:])
|
||||
|
||||
c2sKey := make([]byte, 32)
|
||||
s2cKey := make([]byte, 32)
|
||||
if _, err := io.ReadFull(hkdf.Expand(sha256.New, prk, []byte("sudoku-session-c2s")), c2sKey); err != nil {
|
||||
return nil, nil, fmt.Errorf("hkdf expand c2s: %w", err)
|
||||
}
|
||||
if _, err := io.ReadFull(hkdf.Expand(sha256.New, prk, []byte("sudoku-session-s2c")), s2cKey); err != nil {
|
||||
return nil, nil, fmt.Errorf("hkdf expand s2c: %w", err)
|
||||
}
|
||||
return c2sKey, s2cKey, nil
|
||||
}
|
||||
|
||||
func x25519SharedSecret(priv *ecdh.PrivateKey, peerPub []byte) ([]byte, error) {
|
||||
if priv == nil {
|
||||
return nil, fmt.Errorf("nil priv")
|
||||
}
|
||||
curve := ecdh.X25519()
|
||||
pk, err := curve.NewPublicKey(peerPub)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse peer pub: %w", err)
|
||||
}
|
||||
secret, err := priv.ECDH(pk)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ecdh: %w", err)
|
||||
}
|
||||
return secret, nil
|
||||
}
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
crand "crypto/rand"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -56,26 +55,27 @@ func drainBuffered(r *bufio.Reader) ([]byte, error) {
|
||||
func probeHandshakeBytes(probe []byte, cfg *ProtocolConfig, table *sudoku.Table) error {
|
||||
rc := &readOnlyConn{Reader: bytes.NewReader(probe)}
|
||||
_, obfsConn := buildServerObfsConn(rc, cfg, table, false)
|
||||
cConn, err := crypto.NewAEADConn(obfsConn, cfg.Key, cfg.AEADMethod)
|
||||
seed := ServerAEADSeed(cfg.Key)
|
||||
pskC2S, pskS2C := derivePSKDirectionalBases(seed)
|
||||
// Server side: recv is client->server, send is server->client.
|
||||
cConn, err := crypto.NewRecordConn(obfsConn, cfg.AEADMethod, pskS2C, pskC2S)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var handshakeBuf [16]byte
|
||||
if _, err := io.ReadFull(cConn, handshakeBuf[:]); err != nil {
|
||||
msg, err := ReadKIPMessage(cConn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ts := int64(binary.BigEndian.Uint64(handshakeBuf[:8]))
|
||||
if absInt64(time.Now().Unix()-ts) > 60 {
|
||||
return fmt.Errorf("timestamp skew/replay detected")
|
||||
if msg.Type != KIPTypeClientHello {
|
||||
return fmt.Errorf("unexpected handshake message: %d", msg.Type)
|
||||
}
|
||||
|
||||
modeBuf := []byte{0}
|
||||
if _, err := io.ReadFull(cConn, modeBuf); err != nil {
|
||||
ch, err := DecodeKIPClientHelloPayload(msg.Payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if modeBuf[0] != downlinkMode(cfg) {
|
||||
return fmt.Errorf("downlink mode mismatch")
|
||||
if absInt64(time.Now().Unix()-ch.Timestamp.Unix()) > int64(kipHandshakeSkew.Seconds()) {
|
||||
return fmt.Errorf("time skew/replay")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -93,6 +93,17 @@ func selectTableByProbe(r *bufio.Reader, cfg *ProtocolConfig, tables []*sudoku.T
|
||||
return nil, nil, fmt.Errorf("too many table candidates: %d", len(tables))
|
||||
}
|
||||
|
||||
// Copy so we can prune candidates without mutating the caller slice.
|
||||
candidates := make([]*sudoku.Table, 0, len(tables))
|
||||
for _, t := range tables {
|
||||
if t != nil {
|
||||
candidates = append(candidates, t)
|
||||
}
|
||||
}
|
||||
if len(candidates) == 0 {
|
||||
return nil, nil, fmt.Errorf("no table candidates")
|
||||
}
|
||||
|
||||
probe, err := drainBuffered(r)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("drain buffered bytes failed: %w", err)
|
||||
@@ -100,17 +111,18 @@ func selectTableByProbe(r *bufio.Reader, cfg *ProtocolConfig, tables []*sudoku.T
|
||||
|
||||
tmp := make([]byte, readChunk)
|
||||
for {
|
||||
if len(tables) == 1 {
|
||||
if len(candidates) == 1 {
|
||||
tail, err := drainBuffered(r)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("drain buffered bytes failed: %w", err)
|
||||
}
|
||||
probe = append(probe, tail...)
|
||||
return tables[0], probe, nil
|
||||
return candidates[0], probe, nil
|
||||
}
|
||||
|
||||
needMore := false
|
||||
for _, table := range tables {
|
||||
next := candidates[:0]
|
||||
for _, table := range candidates {
|
||||
err := probeHandshakeBytes(probe, cfg, table)
|
||||
if err == nil {
|
||||
tail, err := drainBuffered(r)
|
||||
@@ -122,10 +134,13 @@ func selectTableByProbe(r *bufio.Reader, cfg *ProtocolConfig, tables []*sudoku.T
|
||||
}
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
needMore = true
|
||||
next = append(next, table)
|
||||
}
|
||||
// Definitive mismatch: drop table.
|
||||
}
|
||||
candidates = next
|
||||
|
||||
if !needMore {
|
||||
if len(candidates) == 0 || !needMore {
|
||||
return nil, probe, fmt.Errorf("handshake table selection failed")
|
||||
}
|
||||
if len(probe) >= maxProbeBytes {
|
||||
|
||||
@@ -6,9 +6,7 @@ import (
|
||||
"github.com/metacubex/mihomo/transport/sudoku/obfs/sudoku"
|
||||
)
|
||||
|
||||
// NewTablesWithCustomPatterns builds one or more obfuscation tables from x/v/p custom patterns.
|
||||
// When customTables is non-empty it overrides customTable (matching upstream Sudoku behavior).
|
||||
func NewTablesWithCustomPatterns(key string, tableType string, customTable string, customTables []string) ([]*sudoku.Table, error) {
|
||||
func normalizeCustomPatterns(customTable string, customTables []string) []string {
|
||||
patterns := customTables
|
||||
if len(patterns) == 0 && strings.TrimSpace(customTable) != "" {
|
||||
patterns = []string{customTable}
|
||||
@@ -16,7 +14,15 @@ func NewTablesWithCustomPatterns(key string, tableType string, customTable strin
|
||||
if len(patterns) == 0 {
|
||||
patterns = []string{""}
|
||||
}
|
||||
return patterns
|
||||
}
|
||||
|
||||
// NewTablesWithCustomPatterns builds one or more obfuscation tables from x/v/p custom patterns.
|
||||
// When customTables is non-empty it overrides customTable (matching upstream Sudoku behavior).
|
||||
//
|
||||
// Deprecated-ish: prefer NewClientTablesWithCustomPatterns / NewServerTablesWithCustomPatterns.
|
||||
func NewTablesWithCustomPatterns(key string, tableType string, customTable string, customTables []string) ([]*sudoku.Table, error) {
|
||||
patterns := normalizeCustomPatterns(customTable, customTables)
|
||||
tables := make([]*sudoku.Table, 0, len(patterns))
|
||||
for _, pattern := range patterns {
|
||||
pattern = strings.TrimSpace(pattern)
|
||||
@@ -28,3 +34,17 @@ func NewTablesWithCustomPatterns(key string, tableType string, customTable strin
|
||||
}
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
func NewClientTablesWithCustomPatterns(key string, tableType string, customTable string, customTables []string) ([]*sudoku.Table, error) {
|
||||
return NewTablesWithCustomPatterns(key, tableType, customTable, customTables)
|
||||
}
|
||||
|
||||
// NewServerTablesWithCustomPatterns matches upstream server behavior: when custom table rotation is enabled,
|
||||
// also accept the default table to avoid forcing clients to update in lockstep.
|
||||
func NewServerTablesWithCustomPatterns(key string, tableType string, customTable string, customTables []string) ([]*sudoku.Table, error) {
|
||||
patterns := normalizeCustomPatterns(customTable, customTables)
|
||||
if len(patterns) > 0 && strings.TrimSpace(patterns[0]) != "" {
|
||||
patterns = append([]string{""}, patterns...)
|
||||
}
|
||||
return NewTablesWithCustomPatterns(key, tableType, "", patterns)
|
||||
}
|
||||
|
||||
@@ -16,17 +16,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
UoTMagicByte byte = 0xEE
|
||||
uotVersion = 0x01
|
||||
maxUoTPayload = 64 * 1024
|
||||
maxUoTPayload = 64 * 1024
|
||||
)
|
||||
|
||||
// WritePreface writes the UDP-over-TCP marker and version.
|
||||
func WritePreface(w io.Writer) error {
|
||||
_, err := w.Write([]byte{UoTMagicByte, uotVersion})
|
||||
return err
|
||||
}
|
||||
|
||||
// WriteDatagram sends a single UDP datagram frame over a reliable stream.
|
||||
func WriteDatagram(w io.Writer, addr string, payload []byte) error {
|
||||
addrBuf, err := EncodeAddress(addr)
|
||||
@@ -45,14 +37,13 @@ func WriteDatagram(w io.Writer, addr string, payload []byte) error {
|
||||
binary.BigEndian.PutUint16(header[:2], uint16(len(addrBuf)))
|
||||
binary.BigEndian.PutUint16(header[2:], uint16(len(payload)))
|
||||
|
||||
if _, err := w.Write(header[:]); err != nil {
|
||||
if err := writeFull(w, header[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.Write(addrBuf); err != nil {
|
||||
if err := writeFull(w, addrBuf); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.Write(payload)
|
||||
return err
|
||||
return writeFull(w, payload)
|
||||
}
|
||||
|
||||
// ReadDatagram parses a single UDP datagram frame from the reliable stream.
|
||||
|
||||
Reference in New Issue
Block a user