feat: sudoku support ws transport (#2589)

This commit is contained in:
saba-futai
2026-03-01 10:22:53 +08:00
committed by GitHub
parent dda1d525c1
commit 9033717190
33 changed files with 2331 additions and 581 deletions

View File

@@ -11,6 +11,7 @@ import (
N "github.com/metacubex/mihomo/common/net"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/sudoku"
"github.com/metacubex/mihomo/transport/sudoku/obfs/httpmask"
)
type Sudoku struct {
@@ -18,32 +19,43 @@ type Sudoku struct {
option *SudokuOption
baseConf sudoku.ProtocolConfig
httpMaskMu sync.Mutex
httpMaskClient *sudoku.HTTPMaskTunnelClient
muxMu sync.Mutex
muxClient *sudoku.MultiplexClient
httpMaskMu sync.Mutex
httpMaskClient *httpmask.TunnelClient
httpMaskKey string
}
type SudokuOption struct {
BasicOption
Name string `proxy:"name"`
Server string `proxy:"server"`
Port int `proxy:"port"`
Key string `proxy:"key"`
AEADMethod string `proxy:"aead-method,omitempty"`
PaddingMin *int `proxy:"padding-min,omitempty"`
PaddingMax *int `proxy:"padding-max,omitempty"`
TableType string `proxy:"table-type,omitempty"` // "prefer_ascii" or "prefer_entropy"
EnablePureDownlink *bool `proxy:"enable-pure-downlink,omitempty"`
HTTPMask bool `proxy:"http-mask,omitempty"`
HTTPMaskMode string `proxy:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto"
HTTPMaskTLS bool `proxy:"http-mask-tls,omitempty"` // only for http-mask-mode stream/poll/auto
HTTPMaskHost string `proxy:"http-mask-host,omitempty"` // optional Host/SNI override (domain or domain:port)
PathRoot string `proxy:"path-root,omitempty"` // optional first-level path prefix for HTTP tunnel endpoints
HTTPMaskMultiplex string `proxy:"http-mask-multiplex,omitempty"` // "off" (default), "auto" (reuse h1/h2), "on" (single tunnel, multi-target)
CustomTable string `proxy:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv
CustomTables []string `proxy:"custom-tables,omitempty"` // optional table rotation patterns, overrides custom-table when non-empty
Name string `proxy:"name"`
Server string `proxy:"server"`
Port int `proxy:"port"`
Key string `proxy:"key"`
AEADMethod string `proxy:"aead-method,omitempty"`
PaddingMin *int `proxy:"padding-min,omitempty"`
PaddingMax *int `proxy:"padding-max,omitempty"`
TableType string `proxy:"table-type,omitempty"` // "prefer_ascii" or "prefer_entropy"
EnablePureDownlink *bool `proxy:"enable-pure-downlink,omitempty"`
HTTPMask *bool `proxy:"http-mask,omitempty"`
HTTPMaskMode string `proxy:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto", "ws"
HTTPMaskTLS bool `proxy:"http-mask-tls,omitempty"` // only for http-mask-mode stream/poll/auto
HTTPMaskHost string `proxy:"http-mask-host,omitempty"` // optional Host/SNI override (domain or domain:port)
PathRoot string `proxy:"path-root,omitempty"` // optional first-level path prefix for HTTP tunnel endpoints
HTTPMaskMultiplex string `proxy:"http-mask-multiplex,omitempty"` // "off" (default), "auto" (reuse h1/h2), "on" (single tunnel, multi-target)
HTTPMaskOptions *SudokuHTTPMaskOptions `proxy:"httpmask,omitempty"`
CustomTable string `proxy:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv
CustomTables []string `proxy:"custom-tables,omitempty"` // optional table rotation patterns, overrides custom-table when non-empty
}
type SudokuHTTPMaskOptions struct {
Disable bool `proxy:"disable,omitempty"`
Mode string `proxy:"mode,omitempty"`
TLS bool `proxy:"tls,omitempty"`
Host string `proxy:"host,omitempty"`
PathRoot string `proxy:"path_root,omitempty"`
Multiplex string `proxy:"multiplex,omitempty"`
}
// DialContext implements C.ProxyAdapter
@@ -73,7 +85,7 @@ func (s *Sudoku) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Con
return nil, fmt.Errorf("encode target address failed: %w", err)
}
if _, err = c.Write(addrBuf); err != nil {
if err = sudoku.WriteKIPMessage(c, sudoku.KIPTypeOpenTCP, addrBuf); err != nil {
return nil, fmt.Errorf("send target address failed: %w", err)
}
@@ -96,9 +108,9 @@ func (s *Sudoku) ListenPacketContext(ctx context.Context, metadata *C.Metadata)
return nil, err
}
if err = sudoku.WritePreface(c); err != nil {
if err = sudoku.WriteKIPMessage(c, sudoku.KIPTypeStartUoT, nil); err != nil {
_ = c.Close()
return nil, fmt.Errorf("send uot preface failed: %w", err)
return nil, fmt.Errorf("start uot failed: %w", err)
}
return newPacketConn(N.NewThreadSafePacketConn(sudoku.NewUoTPacketConn(c)), s), nil
@@ -141,32 +153,45 @@ func NewSudoku(option SudokuOption) (*Sudoku, error) {
return nil, fmt.Errorf("key is required")
}
tableType := strings.ToLower(option.TableType)
if tableType == "" {
tableType = "prefer_ascii"
defaultConf := sudoku.DefaultConfig()
tableType, err := sudoku.NormalizeTableType(option.TableType)
if err != nil {
return nil, err
}
if tableType != "prefer_ascii" && tableType != "prefer_entropy" {
return nil, fmt.Errorf("table-type must be prefer_ascii or prefer_entropy")
paddingMin, paddingMax := sudoku.ResolvePadding(option.PaddingMin, option.PaddingMax, defaultConf.PaddingMin, defaultConf.PaddingMax)
enablePureDownlink := sudoku.DerefBool(option.EnablePureDownlink, defaultConf.EnablePureDownlink)
disableHTTPMask := defaultConf.DisableHTTPMask
if option.HTTPMask != nil {
disableHTTPMask = !*option.HTTPMask
}
httpMaskMode := defaultConf.HTTPMaskMode
if option.HTTPMaskMode != "" {
httpMaskMode = option.HTTPMaskMode
}
httpMaskTLS := option.HTTPMaskTLS
httpMaskHost := option.HTTPMaskHost
pathRoot := strings.TrimSpace(option.PathRoot)
httpMaskMultiplex := defaultConf.HTTPMaskMultiplex
if option.HTTPMaskMultiplex != "" {
httpMaskMultiplex = option.HTTPMaskMultiplex
}
defaultConf := sudoku.DefaultConfig()
paddingMin := defaultConf.PaddingMin
paddingMax := defaultConf.PaddingMax
if option.PaddingMin != nil {
paddingMin = *option.PaddingMin
}
if option.PaddingMax != nil {
paddingMax = *option.PaddingMax
}
if option.PaddingMin == nil && option.PaddingMax != nil && paddingMax < paddingMin {
paddingMin = paddingMax
}
if option.PaddingMax == nil && option.PaddingMin != nil && paddingMax < paddingMin {
paddingMax = paddingMin
}
enablePureDownlink := defaultConf.EnablePureDownlink
if option.EnablePureDownlink != nil {
enablePureDownlink = *option.EnablePureDownlink
if hm := option.HTTPMaskOptions; hm != nil {
disableHTTPMask = hm.Disable
if hm.Mode != "" {
httpMaskMode = hm.Mode
}
httpMaskTLS = hm.TLS
httpMaskHost = hm.Host
if pr := strings.TrimSpace(hm.PathRoot); pr != "" {
pathRoot = pr
}
if mux := strings.TrimSpace(hm.Multiplex); mux != "" {
httpMaskMultiplex = mux
} else {
httpMaskMultiplex = defaultConf.HTTPMaskMultiplex
}
}
baseConf := sudoku.ProtocolConfig{
@@ -177,20 +202,14 @@ func NewSudoku(option SudokuOption) (*Sudoku, error) {
PaddingMax: paddingMax,
EnablePureDownlink: enablePureDownlink,
HandshakeTimeoutSeconds: defaultConf.HandshakeTimeoutSeconds,
DisableHTTPMask: !option.HTTPMask,
HTTPMaskMode: defaultConf.HTTPMaskMode,
HTTPMaskTLSEnabled: option.HTTPMaskTLS,
HTTPMaskHost: option.HTTPMaskHost,
HTTPMaskPathRoot: strings.TrimSpace(option.PathRoot),
HTTPMaskMultiplex: defaultConf.HTTPMaskMultiplex,
DisableHTTPMask: disableHTTPMask,
HTTPMaskMode: httpMaskMode,
HTTPMaskTLSEnabled: httpMaskTLS,
HTTPMaskHost: httpMaskHost,
HTTPMaskPathRoot: pathRoot,
HTTPMaskMultiplex: httpMaskMultiplex,
}
if option.HTTPMaskMode != "" {
baseConf.HTTPMaskMode = option.HTTPMaskMode
}
if option.HTTPMaskMultiplex != "" {
baseConf.HTTPMaskMultiplex = option.HTTPMaskMultiplex
}
tables, err := sudoku.NewTablesWithCustomPatterns(sudoku.ClientAEADSeed(option.Key), tableType, option.CustomTable, option.CustomTables)
tables, err := sudoku.NewClientTablesWithCustomPatterns(sudoku.ClientAEADSeed(option.Key), tableType, option.CustomTable, option.CustomTables)
if err != nil {
return nil, fmt.Errorf("build table(s) failed: %w", err)
}
@@ -244,7 +263,7 @@ func normalizeHTTPMaskMultiplex(mode string) string {
func httpTunnelModeEnabled(mode string) bool {
switch strings.ToLower(strings.TrimSpace(mode)) {
case "stream", "poll", "auto":
case "stream", "poll", "auto", "ws":
return true
default:
return false
@@ -271,14 +290,24 @@ func (s *Sudoku) dialAndHandshake(ctx context.Context, cfg *sudoku.ProtocolConfi
)
if !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) {
muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex)
switch muxMode {
case "auto", "on":
client, errX := s.getOrCreateHTTPMaskClient(cfg)
if errX != nil {
return nil, errX
if muxMode == "auto" && strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) != "ws" {
if client, cerr := s.getOrCreateHTTPMaskClient(cfg); cerr == nil && client != nil {
c, err = client.DialTunnel(ctx, httpmask.TunnelDialOptions{
Mode: cfg.HTTPMaskMode,
TLSEnabled: cfg.HTTPMaskTLSEnabled,
HostOverride: cfg.HTTPMaskHost,
PathRoot: cfg.HTTPMaskPathRoot,
AuthKey: sudoku.ClientAEADSeed(cfg.Key),
Upgrade: upgrade,
Multiplex: cfg.HTTPMaskMultiplex,
DialContext: s.dialer.DialContext,
})
if err != nil {
s.resetHTTPMaskClient()
}
}
c, err = client.Dial(ctx, upgrade)
default:
}
if c == nil && err == nil {
c, err = sudoku.DialHTTPMaskTunnel(ctx, cfg.ServerAddress, cfg, s.dialer.DialContext, upgrade)
}
if err == nil && c != nil {
@@ -372,34 +401,51 @@ func (s *Sudoku) resetMuxClient() {
}
}
func (s *Sudoku) getOrCreateHTTPMaskClient(cfg *sudoku.ProtocolConfig) (*sudoku.HTTPMaskTunnelClient, error) {
if s == nil {
return nil, fmt.Errorf("nil adapter")
}
if cfg == nil {
return nil, fmt.Errorf("config is required")
}
s.httpMaskMu.Lock()
defer s.httpMaskMu.Unlock()
if s.httpMaskClient != nil {
return s.httpMaskClient, nil
}
c, err := sudoku.NewHTTPMaskTunnelClient(cfg.ServerAddress, cfg, s.dialer.DialContext)
if err != nil {
return nil, err
}
s.httpMaskClient = c
return c, nil
}
func (s *Sudoku) resetHTTPMaskClient() {
s.httpMaskMu.Lock()
defer s.httpMaskMu.Unlock()
if s.httpMaskClient != nil {
s.httpMaskClient.CloseIdleConnections()
s.httpMaskClient = nil
s.httpMaskKey = ""
}
}
func (s *Sudoku) getOrCreateHTTPMaskClient(cfg *sudoku.ProtocolConfig) (*httpmask.TunnelClient, error) {
if s == nil || cfg == nil {
return nil, fmt.Errorf("nil adapter or config")
}
key := cfg.ServerAddress + "|" + strconv.FormatBool(cfg.HTTPMaskTLSEnabled) + "|" + strings.TrimSpace(cfg.HTTPMaskHost)
s.httpMaskMu.Lock()
if s.httpMaskClient != nil && s.httpMaskKey == key {
client := s.httpMaskClient
s.httpMaskMu.Unlock()
return client, nil
}
s.httpMaskMu.Unlock()
client, err := httpmask.NewTunnelClient(cfg.ServerAddress, httpmask.TunnelClientOptions{
TLSEnabled: cfg.HTTPMaskTLSEnabled,
HostOverride: cfg.HTTPMaskHost,
DialContext: s.dialer.DialContext,
MaxIdleConns: 32,
})
if err != nil {
return nil, err
}
s.httpMaskMu.Lock()
defer s.httpMaskMu.Unlock()
if s.httpMaskClient != nil && s.httpMaskKey == key {
client.CloseIdleConnections()
return s.httpMaskClient, nil
}
if s.httpMaskClient != nil {
s.httpMaskClient.CloseIdleConnections()
}
s.httpMaskClient = client
s.httpMaskKey = key
return client, nil
}

View File

@@ -1092,12 +1092,22 @@ proxies: # socks5
table-type: prefer_ascii # 可选值prefer_ascii、prefer_entropy 前者全ascii映射后者保证熵值汉明1低于3
# custom-table: xpxvvpvv # 可选自定义字节布局必须包含2个x、2个p、4个v可随意组合。启用此处则需配置`table-type`为`prefer_entropy`
# custom-tables: ["xpxvvpvv", "vxpvxvvp"] # 可选自定义字节布局列表x/v/p用于 xvp 模式轮换;非空时覆盖 custom-table
http-mask: true # 是否启用http掩码
# http-mask-mode: legacy # 可选legacy默认、streamsplit-stream、poll、auto先 stream 再 pollstream/poll/auto 支持走 CDN/反代
# http-mask-tls: true # 可选:仅在 http-mask-mode 为 stream/poll/auto 时生效true 强制 httpsfalse 强制 http不会根据端口自动推断
# http-mask-host: "" # 可选:覆盖 Host/SNI支持 example.com 或 example.com:443仅在 http-mask-mode 为 stream/poll/auto 时生效
# path-root: "" # 可选HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" 或 "/aabbcc/" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload
# http-mask-multiplex: off # 可选off默认、auto复用底层 HTTP 连接,减少建链 RTT、onSudoku mux 单隧道多目标;仅在 http-mask-mode=stream/poll/auto 生效
# 推荐:使用 httpmask 对象统一管理 HTTPMask 相关字段:
httpmask:
disable: false # true 禁用所有 HTTP 伪装/隧道
mode: legacy # 可选legacy默认、streamsplit-stream、poll、auto streampoll、wsWebSocket 隧道)
# tls: true # 可选:仅在 mode 为 stream/poll/auto/ws 时生效true 强制 https/wssfalse 强制 http/ws不会根据端口自动推断
# host: "" # 可选:覆盖 Host/SNI支持 example.com 或 example.com:443仅在 modestream/poll/auto/ws 时生效
# path_root: "" # 可选HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" 或 "/aabbcc/" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload、/aabbcc/ws
# multiplex: off # 可选off默认、auto复用底层 HTTP 连接,减少建链 RTT、onSudoku mux 单隧道多目标;仅在 mode=stream/poll/auto 生效ws 强制 off
#
# 向后兼容旧写法:
# http-mask: true # 是否启用 http 掩码
# http-mask-mode: legacy # 可选legacy默认、streamsplit-stream、poll、auto先 stream 再 poll、wsWebSocket 隧道)
# http-mask-tls: true # 可选:仅在 http-mask-mode 为 stream/poll/auto/ws 时生效true 强制 https/wssfalse 强制 http/ws
# http-mask-host: "" # 可选:覆盖 Host/SNI支持 example.com 或 example.com:443仅在 http-mask-mode 为 stream/poll/auto/ws 时生效
# path-root: "" # 可选HTTP 隧道端点一级路径前缀(双方需一致)
# http-mask-multiplex: off # 可选off默认、auto复用底层 HTTP 连接、onSudoku mux 单隧道多目标ws 强制 off
enable-pure-downlink: false # 可选false=带宽优化下行(更快,要求 aead-method != nonetrue=纯 Sudoku 下行
# anytls
@@ -1663,9 +1673,19 @@ listeners:
# custom-tables: ["xpxvvpvv", "vxpvxvvp"] # 可选自定义字节布局列表x/v/p用于 xvp 模式轮换;非空时覆盖 custom-table
handshake-timeout: 5 # 可选(秒)
enable-pure-downlink: false # 可选false=带宽优化下行(更快,要求 aead-method != nonetrue=纯 Sudoku 下行
disable-http-mask: false # 可选:禁用 http 掩码/隧道(默认为 false
# http-mask-mode: legacy # 可选legacy默认、streamsplit-stream、poll、auto先 stream 再 pollstream/poll/auto 支持走 CDN/反代
# path-root: "" # 可选HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" 或 "/aabbcc/" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload
# 推荐:使用 httpmask 对象统一管理 HTTPMask 相关字段:
httpmask:
disable: false # true 禁用所有 HTTP 伪装/隧道
mode: legacy # 可选legacy默认、streamsplit-stream、poll、auto先 stream 再 poll、wsWebSocket 隧道)
# path_root: "" # 可选HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" 或 "/aabbcc/" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload、/aabbcc/ws
#
# 可选:当启用 HTTPMask 且识别到“像 HTTP 但不符合 tunnel/auth”的请求时将原始字节透传给 fallback常用于与其他服务共端口
# fallback: "127.0.0.1:80"
#
# 向后兼容旧写法:
# disable-http-mask: false # 可选:禁用 http 掩码/隧道(默认为 false
# http-mask-mode: legacy # 可选legacy默认、streamsplit-stream、poll、auto先 stream 再 poll、wsWebSocket 隧道)
# path-root: "" # 可选HTTP 隧道端点一级路径前缀(双方需一致)

View File

@@ -23,6 +23,7 @@ type SudokuServer struct {
DisableHTTPMask bool `json:"disable-http-mask,omitempty"`
HTTPMaskMode string `json:"http-mask-mode,omitempty"`
PathRoot string `json:"path-root,omitempty"`
Fallback string `json:"fallback,omitempty"`
// mihomo private extension (not the part of standard Sudoku protocol)
MuxOption sing.MuxOption `json:"mux-option,omitempty"`

View File

@@ -13,23 +13,31 @@ import (
type SudokuOption struct {
BaseOption
Key string `inbound:"key"`
AEADMethod string `inbound:"aead-method,omitempty"`
PaddingMin *int `inbound:"padding-min,omitempty"`
PaddingMax *int `inbound:"padding-max,omitempty"`
TableType string `inbound:"table-type,omitempty"` // "prefer_ascii" or "prefer_entropy"
HandshakeTimeoutSecond *int `inbound:"handshake-timeout,omitempty"`
EnablePureDownlink *bool `inbound:"enable-pure-downlink,omitempty"`
CustomTable string `inbound:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv
CustomTables []string `inbound:"custom-tables,omitempty"`
DisableHTTPMask bool `inbound:"disable-http-mask,omitempty"`
HTTPMaskMode string `inbound:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto"
PathRoot string `inbound:"path-root,omitempty"` // optional first-level path prefix for HTTP tunnel endpoints
Key string `inbound:"key"`
AEADMethod string `inbound:"aead-method,omitempty"`
PaddingMin *int `inbound:"padding-min,omitempty"`
PaddingMax *int `inbound:"padding-max,omitempty"`
TableType string `inbound:"table-type,omitempty"` // "prefer_ascii" or "prefer_entropy"
HandshakeTimeoutSecond *int `inbound:"handshake-timeout,omitempty"`
EnablePureDownlink *bool `inbound:"enable-pure-downlink,omitempty"`
CustomTable string `inbound:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv
CustomTables []string `inbound:"custom-tables,omitempty"`
DisableHTTPMask bool `inbound:"disable-http-mask,omitempty"`
HTTPMaskMode string `inbound:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto"
PathRoot string `inbound:"path-root,omitempty"` // optional first-level path prefix for HTTP tunnel endpoints
Fallback string `inbound:"fallback,omitempty"`
HTTPMaskOptions *SudokuHTTPMaskOptions `inbound:"httpmask,omitempty"`
// mihomo private extension (not the part of standard Sudoku protocol)
MuxOption MuxOption `inbound:"mux-option,omitempty"`
}
type SudokuHTTPMaskOptions struct {
Disable bool `inbound:"disable,omitempty"`
Mode string `inbound:"mode,omitempty"`
PathRoot string `inbound:"path_root,omitempty"`
}
func (o SudokuOption) Equal(config C.InboundConfig) bool {
return optionToString(o) == optionToString(config)
}
@@ -65,6 +73,16 @@ func NewSudoku(options *SudokuOption) (*Sudoku, error) {
DisableHTTPMask: options.DisableHTTPMask,
HTTPMaskMode: options.HTTPMaskMode,
PathRoot: strings.TrimSpace(options.PathRoot),
Fallback: strings.TrimSpace(options.Fallback),
}
if hm := options.HTTPMaskOptions; hm != nil {
serverConf.DisableHTTPMask = hm.Disable
if hm.Mode != "" {
serverConf.HTTPMaskMode = hm.Mode
}
if pr := strings.TrimSpace(hm.PathRoot); pr != "" {
serverConf.PathRoot = pr
}
}
serverConf.MuxOption = options.MuxOption.Build()

View File

@@ -168,16 +168,17 @@ func TestInboundSudoku_CustomTable(t *testing.T) {
func TestInboundSudoku_HTTPMaskMode(t *testing.T) {
key := "test_key_http_mask_mode"
for _, mode := range []string{"legacy", "stream", "poll", "auto"} {
for _, mode := range []string{"ws", "stream", "poll", "auto"} {
mode := mode
t.Run(mode, func(t *testing.T) {
inboundOptions := inbound.SudokuOption{
Key: key,
HTTPMaskMode: mode,
}
httpMask := true
outboundOptions := outbound.SudokuOption{
Key: key,
HTTPMask: true,
HTTPMask: &httpMask,
HTTPMaskMode: mode,
}
testInboundSudoku(t, inboundOptions, outboundOptions)

View File

@@ -1,16 +1,19 @@
package sudoku
import (
"bytes"
"errors"
"io"
"net"
"strings"
"time"
"github.com/metacubex/mihomo/adapter/inbound"
N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/utils"
C "github.com/metacubex/mihomo/constant"
LC "github.com/metacubex/mihomo/listener/config"
"github.com/metacubex/mihomo/listener/inner"
"github.com/metacubex/mihomo/listener/sing"
"github.com/metacubex/mihomo/log"
"github.com/metacubex/mihomo/transport/socks5"
@@ -23,6 +26,7 @@ type Listener struct {
closed bool
protoConf sudoku.ProtocolConfig
tunnelSrv *sudoku.HTTPMaskTunnelServer
fallback string
handler *sing.ListenerHandler
}
@@ -49,12 +53,19 @@ func (l *Listener) Close() error {
}
func (l *Listener) handleConn(conn net.Conn, tunnel C.Tunnel, additions ...inbound.Addition) {
log.Debugln("[Sudoku] accepted %s", conn.RemoteAddr())
handshakeConn := conn
handshakeCfg := &l.protoConf
closeConns := func() {
_ = handshakeConn.Close()
if handshakeConn != conn {
_ = conn.Close()
}
}
if l.tunnelSrv != nil {
c, cfg, done, err := l.tunnelSrv.WrapConn(conn)
if err != nil {
_ = conn.Close()
closeConns()
return
}
if done {
@@ -68,9 +79,43 @@ func (l *Listener) handleConn(conn net.Conn, tunnel C.Tunnel, additions ...inbou
}
}
session, err := sudoku.ServerHandshake(handshakeConn, handshakeCfg)
if l.fallback != "" {
if r, ok := handshakeConn.(interface{ IsHTTPMaskRejected() bool }); ok && r.IsHTTPMaskRejected() {
fb, err := inner.HandleTcp(tunnel, l.fallback, "")
if err != nil {
closeConns()
return
}
N.Relay(handshakeConn, fb)
return
}
}
cConn, meta, err := sudoku.ServerHandshake(handshakeConn, handshakeCfg)
if err != nil {
_ = handshakeConn.Close()
fallbackAddr := l.fallback
var susp *sudoku.SuspiciousError
isSuspicious := errors.As(err, &susp) && susp != nil && susp.Conn != nil
if isSuspicious {
log.Warnln("[Sudoku] suspicious handshake from %s: %v", conn.RemoteAddr(), err)
if fallbackAddr != "" {
fb, err := inner.HandleTcp(tunnel, fallbackAddr, "")
if err == nil {
relayToFallback(susp.Conn, conn, fb)
return
}
}
} else {
log.Debugln("[Sudoku] handshake failed from %s: %v", conn.RemoteAddr(), err)
}
closeConns()
return
}
session, err := sudoku.ReadServerSession(cConn, meta)
if err != nil {
log.Warnln("[Sudoku] read session failed from %s: %v", conn.RemoteAddr(), err)
_ = cConn.Close()
if handshakeConn != conn {
_ = conn.Close()
}
@@ -103,6 +148,7 @@ func (l *Listener) handleConn(conn net.Conn, tunnel C.Tunnel, additions ...inbou
default:
targetAddr := socks5.ParseAddr(session.Target)
if targetAddr == nil {
log.Warnln("[Sudoku] invalid target from %s: %q", conn.RemoteAddr(), session.Target)
_ = session.Conn.Close()
return
}
@@ -164,6 +210,24 @@ func (p *uotPacket) LocalAddr() net.Addr {
return p.rAddr
}
func relayToFallback(wrapper net.Conn, rawConn net.Conn, fallback net.Conn) {
if wrapper != nil {
if recorder, ok := wrapper.(interface{ GetBufferedAndRecorded() []byte }); ok {
badData := recorder.GetBufferedAndRecorded()
if len(badData) > 0 {
_ = fallback.SetWriteDeadline(time.Now().Add(3 * time.Second))
if _, err := io.Copy(fallback, bytes.NewReader(badData)); err != nil {
_ = fallback.Close()
_ = rawConn.Close()
return
}
_ = fallback.SetWriteDeadline(time.Time{})
}
}
}
N.Relay(rawConn, fallback)
}
func New(config LC.SudokuServer, tunnel C.Tunnel, additions ...inbound.Addition) (*Listener, error) {
if len(additions) == 0 {
additions = []inbound.Addition{
@@ -188,42 +252,24 @@ func New(config LC.SudokuServer, tunnel C.Tunnel, additions ...inbound.Addition)
return nil, err
}
tableType := strings.ToLower(config.TableType)
if tableType == "" {
tableType = "prefer_ascii"
}
defaultConf := sudoku.DefaultConfig()
paddingMin := defaultConf.PaddingMin
paddingMax := defaultConf.PaddingMax
if config.PaddingMin != nil {
paddingMin = *config.PaddingMin
}
if config.PaddingMax != nil {
paddingMax = *config.PaddingMax
}
if config.PaddingMin == nil && config.PaddingMax != nil && paddingMax < paddingMin {
paddingMin = paddingMax
}
if config.PaddingMax == nil && config.PaddingMin != nil && paddingMax < paddingMin {
paddingMax = paddingMin
}
enablePureDownlink := defaultConf.EnablePureDownlink
if config.EnablePureDownlink != nil {
enablePureDownlink = *config.EnablePureDownlink
}
tables, err := sudoku.NewTablesWithCustomPatterns(config.Key, tableType, config.CustomTable, config.CustomTables)
tableType, err := sudoku.NormalizeTableType(config.TableType)
if err != nil {
_ = l.Close()
return nil, err
}
handshakeTimeout := defaultConf.HandshakeTimeoutSeconds
if config.HandshakeTimeoutSecond != nil {
handshakeTimeout = *config.HandshakeTimeoutSecond
defaultConf := sudoku.DefaultConfig()
paddingMin, paddingMax := sudoku.ResolvePadding(config.PaddingMin, config.PaddingMax, defaultConf.PaddingMin, defaultConf.PaddingMax)
enablePureDownlink := sudoku.DerefBool(config.EnablePureDownlink, defaultConf.EnablePureDownlink)
tables, err := sudoku.NewServerTablesWithCustomPatterns(sudoku.ServerAEADSeed(config.Key), tableType, config.CustomTable, config.CustomTables)
if err != nil {
_ = l.Close()
return nil, err
}
handshakeTimeout := sudoku.DerefInt(config.HandshakeTimeoutSecond, defaultConf.HandshakeTimeoutSeconds)
protoConf := sudoku.ProtocolConfig{
Key: config.Key,
AEADMethod: defaultConf.AEADMethod,
@@ -249,8 +295,13 @@ func New(config LC.SudokuServer, tunnel C.Tunnel, additions ...inbound.Addition)
addr: config.Listen,
protoConf: protoConf,
handler: h,
fallback: strings.TrimSpace(config.Fallback),
}
if sl.fallback != "" {
sl.tunnelSrv = sudoku.NewHTTPMaskTunnelServerWithFallback(&sl.protoConf)
} else {
sl.tunnelSrv = sudoku.NewHTTPMaskTunnelServer(&sl.protoConf)
}
sl.tunnelSrv = sudoku.NewHTTPMaskTunnelServer(&sl.protoConf)
go func() {
for {

View File

@@ -6,6 +6,7 @@ import (
"io"
"net"
"strconv"
"strings"
)
func EncodeAddress(rawAddr string) ([]byte, error) {
@@ -20,13 +21,21 @@ func EncodeAddress(rawAddr string) ([]byte, error) {
}
var buf []byte
if i := strings.IndexByte(host, '%'); i >= 0 {
// Zone identifiers are not representable in SOCKS5 IPv6 address encoding.
host = host[:i]
}
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
buf = append(buf, 0x01) // IPv4
buf = append(buf, ip4...)
} else {
buf = append(buf, 0x04) // IPv6
buf = append(buf, ip...)
ip16 := ip.To16()
if ip16 == nil {
return nil, fmt.Errorf("invalid ipv6: %q", host)
}
buf = append(buf, ip16...)
}
} else {
if len(host) > 255 {

View File

@@ -50,6 +50,7 @@ type ProtocolConfig struct {
// - "stream": real HTTP tunnel (split-stream), CDN-compatible
// - "poll": plain HTTP tunnel (authorize/push/pull), strong restricted-network pass-through
// - "auto": try stream then fall back to poll
// - "ws": WebSocket tunnel (GET upgrade), CDN-friendly
HTTPMaskMode string
// HTTPMaskTLSEnabled enables HTTPS for HTTP tunnel modes (client-side).
@@ -109,9 +110,9 @@ func (c *ProtocolConfig) Validate() error {
}
switch strings.ToLower(strings.TrimSpace(c.HTTPMaskMode)) {
case "", "legacy", "stream", "poll", "auto":
case "", "legacy", "stream", "poll", "auto", "ws":
default:
return fmt.Errorf("invalid http-mask-mode: %s, must be one of: legacy, stream, poll, auto", c.HTTPMaskMode)
return fmt.Errorf("invalid http-mask-mode: %s, must be one of: legacy, stream, poll, auto, ws", c.HTTPMaskMode)
}
if v := strings.TrimSpace(c.HTTPMaskPathRoot); v != "" {
@@ -166,6 +167,44 @@ func DefaultConfig() *ProtocolConfig {
}
}
func DerefInt(v *int, def int) int {
if v == nil {
return def
}
return *v
}
func DerefBool(v *bool, def bool) bool {
if v == nil {
return def
}
return *v
}
// ResolvePadding applies defaults and keeps min/max consistent when only one side is provided.
func ResolvePadding(min, max *int, defMin, defMax int) (int, int) {
paddingMin := DerefInt(min, defMin)
paddingMax := DerefInt(max, defMax)
switch {
case min == nil && max != nil && paddingMax < paddingMin:
paddingMin = paddingMax
case max == nil && min != nil && paddingMax < paddingMin:
paddingMax = paddingMin
}
return paddingMin, paddingMax
}
func NormalizeTableType(tableType string) (string, error) {
switch t := strings.ToLower(strings.TrimSpace(tableType)); t {
case "", "prefer_ascii":
return "prefer_ascii", nil
case "prefer_entropy":
return "prefer_entropy", nil
default:
return "", fmt.Errorf("table-type must be prefer_ascii or prefer_entropy")
}
}
func (c *ProtocolConfig) tableCandidates() []*sudoku.Table {
if c == nil {
return nil

View File

@@ -80,8 +80,8 @@ func RecoverPublicKey(keyHex string) (*edwards25519.Point, error) {
return nil, fmt.Errorf("invalid scalar: %w", err)
}
return new(edwards25519.Point).ScalarBaseMult(x), nil
} else if len(keyBytes) == 64 {
}
if len(keyBytes) == 64 {
// Split Key r || k
rBytes := keyBytes[:32]
kBytes := keyBytes[32:]

View File

@@ -0,0 +1,374 @@
package crypto
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/sha256"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"sync"
"golang.org/x/crypto/chacha20poly1305"
)
// KeyUpdateAfterBytes controls automatic key rotation based on plaintext bytes.
// It is a package var (not config) to enable targeted tests with smaller thresholds.
var KeyUpdateAfterBytes int64 = 32 << 20 // 32 MiB
const (
recordHeaderSize = 12 // epoch(uint32) + seq(uint64) - also used as nonce+AAD.
maxFrameBodySize = 65535
)
type recordKeys struct {
baseSend []byte
baseRecv []byte
}
// RecordConn is a framed AEAD net.Conn with:
// - deterministic per-record nonce (epoch+seq)
// - per-direction key rotation (epoch), driven by plaintext byte counters
// - replay/out-of-order protection within the connection (strict seq check)
//
// Wire format per record:
// - uint16 bodyLen
// - header[12] = epoch(uint32 BE) || seq(uint64 BE) (plaintext)
// - ciphertext = AEAD(header as nonce, plaintext, header as AAD)
type RecordConn struct {
net.Conn
method string
writeMu sync.Mutex
readMu sync.Mutex
keys recordKeys
sendAEAD cipher.AEAD
sendAEADEpoch uint32
recvAEAD cipher.AEAD
recvAEADEpoch uint32
// Send direction state.
sendEpoch uint32
sendSeq uint64
sendBytes int64
// Receive direction state.
recvEpoch uint32
recvSeq uint64
readBuf bytes.Buffer
// writeFrame is a reusable buffer for [len||header||ciphertext] on the wire.
// Guarded by writeMu.
writeFrame []byte
}
func (c *RecordConn) CloseWrite() error {
if c == nil {
return nil
}
if cw, ok := c.Conn.(interface{ CloseWrite() error }); ok {
return cw.CloseWrite()
}
return nil
}
func (c *RecordConn) CloseRead() error {
if c == nil {
return nil
}
if cr, ok := c.Conn.(interface{ CloseRead() error }); ok {
return cr.CloseRead()
}
return nil
}
func NewRecordConn(conn net.Conn, method string, baseSend, baseRecv []byte) (*RecordConn, error) {
if conn == nil {
return nil, fmt.Errorf("nil conn")
}
method = normalizeAEADMethod(method)
if method != "none" {
if err := validateBaseKey(baseSend); err != nil {
return nil, fmt.Errorf("invalid send base key: %w", err)
}
if err := validateBaseKey(baseRecv); err != nil {
return nil, fmt.Errorf("invalid recv base key: %w", err)
}
}
rc := &RecordConn{Conn: conn, method: method}
rc.keys = recordKeys{baseSend: cloneBytes(baseSend), baseRecv: cloneBytes(baseRecv)}
return rc, nil
}
func (c *RecordConn) Rekey(baseSend, baseRecv []byte) error {
if c == nil {
return fmt.Errorf("nil conn")
}
if c.method != "none" {
if err := validateBaseKey(baseSend); err != nil {
return fmt.Errorf("invalid send base key: %w", err)
}
if err := validateBaseKey(baseRecv); err != nil {
return fmt.Errorf("invalid recv base key: %w", err)
}
}
c.writeMu.Lock()
c.readMu.Lock()
defer c.readMu.Unlock()
defer c.writeMu.Unlock()
c.keys = recordKeys{baseSend: cloneBytes(baseSend), baseRecv: cloneBytes(baseRecv)}
c.sendEpoch = 0
c.sendSeq = 0
c.sendBytes = 0
c.recvEpoch = 0
c.recvSeq = 0
c.readBuf.Reset()
c.sendAEAD = nil
c.recvAEAD = nil
c.sendAEADEpoch = 0
c.recvAEADEpoch = 0
return nil
}
func normalizeAEADMethod(method string) string {
switch method {
case "", "chacha20-poly1305":
return "chacha20-poly1305"
case "aes-128-gcm", "none":
return method
default:
return method
}
}
func validateBaseKey(b []byte) error {
if len(b) < 32 {
return fmt.Errorf("need at least 32 bytes, got %d", len(b))
}
return nil
}
func cloneBytes(b []byte) []byte {
if len(b) == 0 {
return nil
}
return append([]byte(nil), b...)
}
func (c *RecordConn) newAEADFor(base []byte, epoch uint32) (cipher.AEAD, error) {
if c.method == "none" {
return nil, nil
}
key := deriveEpochKey(base, epoch, c.method)
switch c.method {
case "aes-128-gcm":
block, err := aes.NewCipher(key[:16])
if err != nil {
return nil, err
}
a, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
if a.NonceSize() != recordHeaderSize {
return nil, fmt.Errorf("unexpected gcm nonce size: %d", a.NonceSize())
}
return a, nil
case "chacha20-poly1305":
a, err := chacha20poly1305.New(key[:32])
if err != nil {
return nil, err
}
if a.NonceSize() != recordHeaderSize {
return nil, fmt.Errorf("unexpected chacha nonce size: %d", a.NonceSize())
}
return a, nil
default:
return nil, fmt.Errorf("unsupported cipher: %s", c.method)
}
}
func deriveEpochKey(base []byte, epoch uint32, method string) []byte {
var b [4]byte
binary.BigEndian.PutUint32(b[:], epoch)
mac := hmac.New(sha256.New, base)
_, _ = mac.Write([]byte("sudoku-record:"))
_, _ = mac.Write([]byte(method))
_, _ = mac.Write(b[:])
return mac.Sum(nil)
}
func (c *RecordConn) maybeBumpSendEpochLocked(addedPlain int) {
if KeyUpdateAfterBytes <= 0 || c.method == "none" {
return
}
c.sendBytes += int64(addedPlain)
threshold := KeyUpdateAfterBytes * int64(c.sendEpoch+1)
if c.sendBytes < threshold {
return
}
c.sendEpoch++
c.sendSeq = 0
}
func (c *RecordConn) Write(p []byte) (int, error) {
if c == nil || c.Conn == nil {
return 0, net.ErrClosed
}
if c.method == "none" {
return c.Conn.Write(p)
}
c.writeMu.Lock()
defer c.writeMu.Unlock()
total := 0
for len(p) > 0 {
if c.sendAEAD == nil || c.sendAEADEpoch != c.sendEpoch {
a, err := c.newAEADFor(c.keys.baseSend, c.sendEpoch)
if err != nil {
return total, err
}
c.sendAEAD = a
c.sendAEADEpoch = c.sendEpoch
}
aead := c.sendAEAD
maxPlain := maxFrameBodySize - recordHeaderSize - aead.Overhead()
if maxPlain <= 0 {
return total, errors.New("frame size too small")
}
n := len(p)
if n > maxPlain {
n = maxPlain
}
chunk := p[:n]
p = p[n:]
var header [recordHeaderSize]byte
binary.BigEndian.PutUint32(header[:4], c.sendEpoch)
binary.BigEndian.PutUint64(header[4:], c.sendSeq)
c.sendSeq++
cipherLen := n + aead.Overhead()
bodyLen := recordHeaderSize + cipherLen
frameLen := 2 + bodyLen
if bodyLen > maxFrameBodySize {
return total, errors.New("frame too large")
}
if cap(c.writeFrame) < frameLen {
c.writeFrame = make([]byte, frameLen)
}
frame := c.writeFrame[:frameLen]
binary.BigEndian.PutUint16(frame[:2], uint16(bodyLen))
copy(frame[2:2+recordHeaderSize], header[:])
dst := frame[2+recordHeaderSize : 2+recordHeaderSize : frameLen]
_ = aead.Seal(dst[:0], header[:], chunk, header[:])
if err := writeFull(c.Conn, frame); err != nil {
return total, err
}
total += n
c.maybeBumpSendEpochLocked(n)
}
return total, nil
}
func (c *RecordConn) Read(p []byte) (int, error) {
if c == nil || c.Conn == nil {
return 0, net.ErrClosed
}
if c.method == "none" {
return c.Conn.Read(p)
}
c.readMu.Lock()
defer c.readMu.Unlock()
if c.readBuf.Len() > 0 {
return c.readBuf.Read(p)
}
var lenBuf [2]byte
if _, err := io.ReadFull(c.Conn, lenBuf[:]); err != nil {
return 0, err
}
bodyLen := int(binary.BigEndian.Uint16(lenBuf[:]))
if bodyLen < recordHeaderSize {
return 0, errors.New("frame too short")
}
if bodyLen > maxFrameBodySize {
return 0, errors.New("frame too large")
}
body := make([]byte, bodyLen)
if _, err := io.ReadFull(c.Conn, body); err != nil {
return 0, err
}
header := body[:recordHeaderSize]
ciphertext := body[recordHeaderSize:]
epoch := binary.BigEndian.Uint32(header[:4])
seq := binary.BigEndian.Uint64(header[4:])
if epoch < c.recvEpoch {
return 0, fmt.Errorf("replayed epoch: got %d want >=%d", epoch, c.recvEpoch)
}
if epoch == c.recvEpoch && seq != c.recvSeq {
return 0, fmt.Errorf("out of order: epoch=%d got=%d want=%d", epoch, seq, c.recvSeq)
}
if epoch > c.recvEpoch {
const maxJump = 8
if epoch-c.recvEpoch > maxJump {
return 0, fmt.Errorf("epoch jump too large: got=%d want<=%d", epoch-c.recvEpoch, maxJump)
}
c.recvEpoch = epoch
c.recvSeq = 0
if seq != 0 {
return 0, fmt.Errorf("out of order: epoch advanced to %d but seq=%d", epoch, seq)
}
}
if c.recvAEAD == nil || c.recvAEADEpoch != c.recvEpoch {
a, err := c.newAEADFor(c.keys.baseRecv, c.recvEpoch)
if err != nil {
return 0, err
}
c.recvAEAD = a
c.recvAEADEpoch = c.recvEpoch
}
aead := c.recvAEAD
plaintext, err := aead.Open(nil, header, ciphertext, header)
if err != nil {
return 0, fmt.Errorf("decryption failed: epoch=%d seq=%d: %w", epoch, seq, err)
}
c.recvSeq++
c.readBuf.Write(plaintext)
return c.readBuf.Read(p)
}
func writeFull(w io.Writer, b []byte) error {
for len(b) > 0 {
n, err := w.Write(b)
if err != nil {
return err
}
b = b[n:]
}
return nil
}

View File

@@ -43,12 +43,17 @@ func TestCustomTablesRotation_ProbedByServer(t *testing.T) {
go func() {
defer close(errCh)
defer serverConn.Close()
session, err := ServerHandshake(serverConn, serverCfg)
c, meta, err := ServerHandshake(serverConn, serverCfg)
if err != nil {
errCh <- err
return
}
defer session.Conn.Close()
session, err := ReadServerSession(c, meta)
if err != nil {
errCh <- err
return
}
defer c.Close()
if session.Type != SessionTypeTCP {
errCh <- io.ErrUnexpectedEOF
return
@@ -69,7 +74,7 @@ func TestCustomTablesRotation_ProbedByServer(t *testing.T) {
if err != nil {
t.Fatalf("encode addr: %v", err)
}
if _, err := cConn.Write(addrBuf); err != nil {
if err := WriteKIPMessage(cConn, KIPTypeOpenTCP, addrBuf); err != nil {
t.Fatalf("write addr: %v", err)
}

View File

@@ -2,19 +2,20 @@ package sudoku
import (
"bufio"
"crypto/sha256"
"encoding/binary"
"bytes"
"crypto/ecdh"
"crypto/rand"
"encoding/hex"
"fmt"
"io"
"net"
"strings"
"sync"
"time"
"github.com/metacubex/mihomo/transport/sudoku/crypto"
"github.com/metacubex/mihomo/transport/sudoku/obfs/httpmask"
"github.com/metacubex/mihomo/transport/sudoku/obfs/sudoku"
"github.com/metacubex/mihomo/log"
)
type SessionType int
@@ -30,18 +31,96 @@ type ServerSession struct {
Type SessionType
Target string
// UserHash is a stable per-key identifier derived from the handshake payload.
// It is primarily useful for debugging / user attribution when table rotation is enabled.
// UserHash is a stable per-key identifier derived from the client hello payload.
UserHash string
}
type bufferedConn struct {
net.Conn
r *bufio.Reader
type HandshakeMeta struct {
UserHash string
}
func (bc *bufferedConn) Read(p []byte) (int, error) {
return bc.r.Read(p)
// SuspiciousError indicates a potential probing attempt or protocol violation.
// When returned, Conn (if non-nil) should contain all bytes already consumed/buffered so the caller
// can perform a best-effort fallback relay (e.g. to a local web server) without losing the request.
type SuspiciousError struct {
Err error
Conn net.Conn
}
func (e *SuspiciousError) Error() string {
if e == nil || e.Err == nil {
return ""
}
return e.Err.Error()
}
func (e *SuspiciousError) Unwrap() error { return e.Err }
type recordedConn struct {
net.Conn
recorded []byte
}
func (rc *recordedConn) GetBufferedAndRecorded() []byte { return rc.recorded }
type prefixedRecorderConn struct {
net.Conn
prefix []byte
}
func (pc *prefixedRecorderConn) GetBufferedAndRecorded() []byte {
var rest []byte
if r, ok := pc.Conn.(interface{ GetBufferedAndRecorded() []byte }); ok {
rest = r.GetBufferedAndRecorded()
}
out := make([]byte, 0, len(pc.prefix)+len(rest))
out = append(out, pc.prefix...)
out = append(out, rest...)
return out
}
// bufferedRecorderConn wraps a net.Conn and a shared bufio.Reader so we can expose buffered bytes.
// This is used for legacy HTTP mask parsing errors so callers can fall back to a real HTTP server.
type bufferedRecorderConn struct {
net.Conn
r *bufio.Reader
recorder *bytes.Buffer
mu sync.Mutex
}
func (bc *bufferedRecorderConn) Read(p []byte) (n int, err error) {
n, err = bc.r.Read(p)
if n > 0 && bc.recorder != nil {
bc.mu.Lock()
bc.recorder.Write(p[:n])
bc.mu.Unlock()
}
return n, err
}
func (bc *bufferedRecorderConn) GetBufferedAndRecorded() []byte {
if bc == nil {
return nil
}
bc.mu.Lock()
defer bc.mu.Unlock()
var recorded []byte
if bc.recorder != nil {
recorded = bc.recorder.Bytes()
}
buffered := 0
if bc.r != nil {
buffered = bc.r.Buffered()
}
if buffered <= 0 {
return recorded
}
peeked, _ := bc.r.Peek(buffered)
full := make([]byte, len(recorded)+len(peeked))
copy(full, recorded)
copy(full[len(recorded):], peeked)
return full
}
type preBufferedConn struct {
@@ -61,6 +140,26 @@ func (p *preBufferedConn) Read(b []byte) (int, error) {
return p.Conn.Read(b)
}
func (p *preBufferedConn) CloseWrite() error {
if p == nil {
return nil
}
if cw, ok := p.Conn.(interface{ CloseWrite() error }); ok {
return cw.CloseWrite()
}
return nil
}
func (p *preBufferedConn) CloseRead() error {
if p == nil {
return nil
}
if cr, ok := p.Conn.(interface{ CloseRead() error }); ok {
return cr.CloseRead()
}
return nil
}
type directionalConn struct {
net.Conn
reader io.Reader
@@ -101,6 +200,26 @@ func (c *directionalConn) Close() error {
return firstErr
}
func (c *directionalConn) CloseWrite() error {
if c == nil {
return nil
}
if cw, ok := c.Conn.(interface{ CloseWrite() error }); ok {
return cw.CloseWrite()
}
return nil
}
func (c *directionalConn) CloseRead() error {
if c == nil {
return nil
}
if cr, ok := c.Conn.(interface{ CloseRead() error }); ok {
return cr.CloseRead()
}
return nil
}
func absInt64(v int64) int64 {
if v < 0 {
return -v
@@ -108,18 +227,6 @@ func absInt64(v int64) int64 {
return v
}
const (
downlinkModePure byte = 0x01
downlinkModePacked byte = 0x02
)
func downlinkMode(cfg *ProtocolConfig) byte {
if cfg.EnablePureDownlink {
return downlinkModePure
}
return downlinkModePacked
}
func buildClientObfsConn(raw net.Conn, cfg *ProtocolConfig, table *sudoku.Table) net.Conn {
baseSudoku := sudoku.NewConn(raw, table, cfg.PaddingMin, cfg.PaddingMax, false)
if cfg.EnablePureDownlink {
@@ -138,50 +245,16 @@ func buildServerObfsConn(raw net.Conn, cfg *ProtocolConfig, table *sudoku.Table,
return uplinkSudoku, newDirectionalConn(raw, uplinkSudoku, packed, packed.Flush)
}
func buildHandshakePayload(key string) [16]byte {
var payload [16]byte
binary.BigEndian.PutUint64(payload[:8], uint64(time.Now().Unix()))
// Align with upstream: only decode hex bytes when this key is an ED25519 key material.
// For plain UUID/strings (even if they look like hex), hash the string bytes as-is.
src := []byte(key)
if _, err := crypto.RecoverPublicKey(key); err == nil {
if keyBytes, decErr := hex.DecodeString(key); decErr == nil && len(keyBytes) > 0 {
src = keyBytes
}
func isLegacyHTTPMaskMode(mode string) bool {
switch strings.ToLower(strings.TrimSpace(mode)) {
case "", "legacy":
return true
default:
return false
}
hash := sha256.Sum256(src)
copy(payload[8:], hash[:8])
return payload
}
func NewTable(key string, tableType string) *sudoku.Table {
table, err := NewTableWithCustom(key, tableType, "")
if err != nil {
panic(fmt.Sprintf("[Sudoku] failed to init tables: %v", err))
}
return table
}
func NewTableWithCustom(key string, tableType string, customTable string) (*sudoku.Table, error) {
start := time.Now()
table, err := sudoku.NewTableWithCustom(key, tableType, customTable)
if err != nil {
return nil, err
}
log.Infoln("[Sudoku] Tables initialized (%s, custom=%v) in %v", tableType, customTable != "", time.Since(start))
return table, nil
}
func ClientAEADSeed(key string) string {
if recovered, err := crypto.RecoverPublicKey(key); err == nil {
return crypto.EncodePoint(recovered)
}
return key
}
// ClientHandshake performs the client-side Sudoku handshake (without sending target address).
// ClientHandshake performs the client-side Sudoku handshake (no target request).
func ClientHandshake(rawConn net.Conn, cfg *ProtocolConfig) (net.Conn, error) {
if cfg == nil {
return nil, fmt.Errorf("config is required")
@@ -190,7 +263,7 @@ func ClientHandshake(rawConn net.Conn, cfg *ProtocolConfig) (net.Conn, error) {
return nil, fmt.Errorf("invalid config: %w", err)
}
if !cfg.DisableHTTPMask {
if !cfg.DisableHTTPMask && isLegacyHTTPMaskMode(cfg.HTTPMaskMode) {
if err := httpmask.WriteRandomRequestHeaderWithPathRoot(rawConn, cfg.ServerAddress, cfg.HTTPMaskPathRoot); err != nil {
return nil, fmt.Errorf("write http mask failed: %w", err)
}
@@ -201,32 +274,68 @@ func ClientHandshake(rawConn net.Conn, cfg *ProtocolConfig) (net.Conn, error) {
return nil, err
}
seed := ClientAEADSeed(cfg.Key)
obfsConn := buildClientObfsConn(rawConn, cfg, table)
cConn, err := crypto.NewAEADConn(obfsConn, ClientAEADSeed(cfg.Key), cfg.AEADMethod)
pskC2S, pskS2C := derivePSKDirectionalBases(seed)
rc, err := crypto.NewRecordConn(obfsConn, cfg.AEADMethod, pskC2S, pskS2C)
if err != nil {
return nil, fmt.Errorf("setup crypto failed: %w", err)
}
handshake := buildHandshakePayload(cfg.Key)
if _, err := cConn.Write(handshake[:]); err != nil {
cConn.Close()
return nil, fmt.Errorf("send handshake failed: %w", err)
}
if _, err := cConn.Write([]byte{downlinkMode(cfg)}); err != nil {
cConn.Close()
return nil, fmt.Errorf("send downlink mode failed: %w", err)
if _, err := kipHandshakeClient(rc, seed, kipUserHashFromKey(cfg.Key), KIPFeatAll); err != nil {
_ = rc.Close()
return nil, err
}
return cConn, nil
return rc, nil
}
// ServerHandshake performs Sudoku server-side handshake and detects UoT preface.
func ServerHandshake(rawConn net.Conn, cfg *ProtocolConfig) (*ServerSession, error) {
func readFirstSessionMessage(conn net.Conn) (*KIPMessage, error) {
for {
msg, err := ReadKIPMessage(conn)
if err != nil {
return nil, err
}
if msg.Type == KIPTypeKeepAlive {
continue
}
return msg, nil
}
}
func maybeConsumeLegacyHTTPMask(rawConn net.Conn, r *bufio.Reader, cfg *ProtocolConfig) ([]byte, *SuspiciousError) {
if rawConn == nil || r == nil || cfg == nil || cfg.DisableHTTPMask || !isLegacyHTTPMaskMode(cfg.HTTPMaskMode) {
return nil, nil
}
peekBytes, _ := r.Peek(4) // ignore error; subsequent read will handle it
if !httpmask.LooksLikeHTTPRequestStart(peekBytes) {
return nil, nil
}
consumed, err := httpmask.ConsumeHeader(r)
if err == nil {
return consumed, nil
}
recorder := new(bytes.Buffer)
if len(consumed) > 0 {
recorder.Write(consumed)
}
badConn := &bufferedRecorderConn{Conn: rawConn, r: r, recorder: recorder}
return consumed, &SuspiciousError{Err: fmt.Errorf("invalid http header: %w", err), Conn: badConn}
}
// ServerHandshake performs the server-side KIP handshake.
func ServerHandshake(rawConn net.Conn, cfg *ProtocolConfig) (net.Conn, *HandshakeMeta, error) {
if rawConn == nil {
return nil, nil, fmt.Errorf("nil conn")
}
if cfg == nil {
return nil, fmt.Errorf("config is required")
return nil, nil, fmt.Errorf("config is required")
}
if err := cfg.Validate(); err != nil {
return nil, fmt.Errorf("invalid config: %w", err)
return nil, nil, fmt.Errorf("invalid config: %w", err)
}
handshakeTimeout := time.Duration(cfg.HandshakeTimeoutSeconds) * time.Second
@@ -234,116 +343,113 @@ func ServerHandshake(rawConn net.Conn, cfg *ProtocolConfig) (*ServerSession, err
handshakeTimeout = 5 * time.Second
}
rawConn.SetReadDeadline(time.Now().Add(handshakeTimeout))
bufReader := bufio.NewReader(rawConn)
if !cfg.DisableHTTPMask {
if peek, err := bufReader.Peek(4); err == nil && httpmask.LooksLikeHTTPRequestStart(peek) {
if _, err := httpmask.ConsumeHeader(bufReader); err != nil {
return nil, fmt.Errorf("invalid http header: %w", err)
}
}
_ = rawConn.SetReadDeadline(time.Now().Add(handshakeTimeout))
defer func() { _ = rawConn.SetReadDeadline(time.Time{}) }()
httpHeaderData, susp := maybeConsumeLegacyHTTPMask(rawConn, bufReader, cfg)
if susp != nil {
return nil, nil, susp
}
selectedTable, preRead, err := selectTableByProbe(bufReader, cfg, cfg.tableCandidates())
if err != nil {
return nil, err
combined := make([]byte, 0, len(httpHeaderData)+len(preRead))
combined = append(combined, httpHeaderData...)
combined = append(combined, preRead...)
return nil, nil, &SuspiciousError{Err: err, Conn: &recordedConn{Conn: rawConn, recorded: combined}}
}
baseConn := &preBufferedConn{Conn: rawConn, buf: preRead}
bConn := &bufferedConn{Conn: baseConn, r: bufio.NewReader(baseConn)}
sConn, obfsConn := buildServerObfsConn(bConn, cfg, selectedTable, true)
cConn, err := crypto.NewAEADConn(obfsConn, cfg.Key, cfg.AEADMethod)
sConn, obfsConn := buildServerObfsConn(baseConn, cfg, selectedTable, true)
seed := ServerAEADSeed(cfg.Key)
pskC2S, pskS2C := derivePSKDirectionalBases(seed)
// Server side: recv is client->server, send is server->client.
rc, err := crypto.NewRecordConn(obfsConn, cfg.AEADMethod, pskS2C, pskC2S)
if err != nil {
return nil, fmt.Errorf("crypto setup failed: %w", err)
return nil, nil, fmt.Errorf("setup crypto failed: %w", err)
}
var handshakeBuf [16]byte
if _, err := io.ReadFull(cConn, handshakeBuf[:]); err != nil {
cConn.Close()
return nil, fmt.Errorf("read handshake failed: %w", err)
msg, err := ReadKIPMessage(rc)
if err != nil {
return nil, nil, &SuspiciousError{Err: fmt.Errorf("handshake read failed: %w", err), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}}
}
if msg.Type != KIPTypeClientHello {
return nil, nil, &SuspiciousError{Err: fmt.Errorf("unexpected handshake message: %d", msg.Type), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}}
}
ch, err := DecodeKIPClientHelloPayload(msg.Payload)
if err != nil {
return nil, nil, &SuspiciousError{Err: fmt.Errorf("decode client hello failed: %w", err), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}}
}
if absInt64(time.Now().Unix()-ch.Timestamp.Unix()) > int64(kipHandshakeSkew.Seconds()) {
return nil, nil, &SuspiciousError{Err: fmt.Errorf("time skew/replay"), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}}
}
ts := int64(binary.BigEndian.Uint64(handshakeBuf[:8]))
if absInt64(time.Now().Unix()-ts) > 60 {
cConn.Close()
return nil, fmt.Errorf("timestamp skew detected")
userHashHex := hex.EncodeToString(ch.UserHash[:])
if !globalHandshakeReplay.allow(userHashHex, ch.Nonce, time.Now()) {
return nil, nil, &SuspiciousError{Err: fmt.Errorf("replay"), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}}
}
curve := ecdh.X25519()
serverEphemeral, err := curve.GenerateKey(rand.Reader)
if err != nil {
return nil, nil, &SuspiciousError{Err: fmt.Errorf("ecdh generate failed: %w", err), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}}
}
shared, err := x25519SharedSecret(serverEphemeral, ch.ClientPub[:])
if err != nil {
return nil, nil, &SuspiciousError{Err: fmt.Errorf("ecdh failed: %w", err), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}}
}
sessC2S, sessS2C, err := deriveSessionDirectionalBases(seed, shared, ch.Nonce)
if err != nil {
return nil, nil, &SuspiciousError{Err: fmt.Errorf("derive session keys failed: %w", err), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}}
}
var serverPub [kipHelloPubSize]byte
copy(serverPub[:], serverEphemeral.PublicKey().Bytes())
sh := &KIPServerHello{
Nonce: ch.Nonce,
ServerPub: serverPub,
SelectedFeats: ch.Features & KIPFeatAll,
}
if err := WriteKIPMessage(rc, KIPTypeServerHello, sh.EncodePayload()); err != nil {
return nil, nil, &SuspiciousError{Err: fmt.Errorf("write server hello failed: %w", err), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}}
}
if err := rc.Rekey(sessS2C, sessC2S); err != nil {
return nil, nil, &SuspiciousError{Err: fmt.Errorf("rekey failed: %w", err), Conn: &prefixedRecorderConn{Conn: sConn, prefix: httpHeaderData}}
}
userHash := userHashFromHandshake(handshakeBuf[:])
sConn.StopRecording()
return rc, &HandshakeMeta{UserHash: userHashHex}, nil
}
modeBuf := []byte{0}
if _, err := io.ReadFull(cConn, modeBuf); err != nil {
cConn.Close()
return nil, fmt.Errorf("read downlink mode failed: %w", err)
// ReadServerSession consumes the first post-handshake KIP control message and returns the session intent.
func ReadServerSession(conn net.Conn, meta *HandshakeMeta) (*ServerSession, error) {
if conn == nil {
return nil, fmt.Errorf("nil conn")
}
if modeBuf[0] != downlinkMode(cfg) {
cConn.Close()
return nil, fmt.Errorf("downlink mode mismatch: client=%d server=%d", modeBuf[0], downlinkMode(cfg))
userHash := ""
if meta != nil {
userHash = meta.UserHash
}
firstByte := make([]byte, 1)
if _, err := io.ReadFull(cConn, firstByte); err != nil {
cConn.Close()
return nil, fmt.Errorf("read first byte failed: %w", err)
first, err := readFirstSessionMessage(conn)
if err != nil {
return nil, err
}
if firstByte[0] == MultiplexMagicByte {
rawConn.SetReadDeadline(time.Time{})
return &ServerSession{Conn: cConn, Type: SessionTypeMultiplex, UserHash: userHash}, nil
}
if firstByte[0] == UoTMagicByte {
version := make([]byte, 1)
if _, err := io.ReadFull(cConn, version); err != nil {
cConn.Close()
return nil, fmt.Errorf("read uot version failed: %w", err)
switch first.Type {
case KIPTypeStartUoT:
return &ServerSession{Conn: conn, Type: SessionTypeUoT, UserHash: userHash}, nil
case KIPTypeStartMux:
return &ServerSession{Conn: conn, Type: SessionTypeMultiplex, UserHash: userHash}, nil
case KIPTypeOpenTCP:
target, err := DecodeAddress(bytes.NewReader(first.Payload))
if err != nil {
return nil, fmt.Errorf("decode target address failed: %w", err)
}
if version[0] != uotVersion {
cConn.Close()
return nil, fmt.Errorf("unsupported uot version: %d", version[0])
}
rawConn.SetReadDeadline(time.Time{})
return &ServerSession{Conn: cConn, Type: SessionTypeUoT, UserHash: userHash}, nil
return &ServerSession{Conn: conn, Type: SessionTypeTCP, Target: target, UserHash: userHash}, nil
default:
return nil, fmt.Errorf("unknown kip message: %d", first.Type)
}
prefixed := &preBufferedConn{Conn: cConn, buf: firstByte}
target, err := DecodeAddress(prefixed)
if err != nil {
cConn.Close()
return nil, fmt.Errorf("read target address failed: %w", err)
}
rawConn.SetReadDeadline(time.Time{})
log.Debugln("[Sudoku] incoming TCP session target: %s", target)
return &ServerSession{
Conn: prefixed,
Type: SessionTypeTCP,
Target: target,
UserHash: userHash,
}, nil
}
func GenKeyPair() (privateKey, publicKey string, err error) {
// Generate Master Key
pair, err := crypto.GenerateMasterKey()
if err != nil {
return
}
// Split the master private key to get Available Private Key
availablePrivateKey, err := crypto.SplitPrivateKey(pair.Private)
if err != nil {
return
}
privateKey = availablePrivateKey // Available Private Key for client
publicKey = crypto.EncodePoint(pair.Public) // Master Public Key for server
return
}
func userHashFromHandshake(handshakeBuf []byte) string {
if len(handshakeBuf) < 16 {
return ""
}
return hex.EncodeToString(handshakeBuf[8:16])
}

View File

@@ -0,0 +1,73 @@
package sudoku
import (
"crypto/ecdh"
"crypto/rand"
"fmt"
"io"
"time"
"github.com/metacubex/mihomo/transport/sudoku/crypto"
)
const kipHandshakeSkew = 60 * time.Second
func kipHandshakeClient(rc *crypto.RecordConn, seed string, userHash [kipHelloUserHashSize]byte, feats uint32) (uint32, error) {
if rc == nil {
return 0, fmt.Errorf("nil conn")
}
curve := ecdh.X25519()
ephemeral, err := curve.GenerateKey(rand.Reader)
if err != nil {
return 0, fmt.Errorf("ecdh generate failed: %w", err)
}
var nonce [kipHelloNonceSize]byte
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
return 0, fmt.Errorf("nonce generate failed: %w", err)
}
var clientPub [kipHelloPubSize]byte
copy(clientPub[:], ephemeral.PublicKey().Bytes())
ch := &KIPClientHello{
Timestamp: time.Now(),
UserHash: userHash,
Nonce: nonce,
ClientPub: clientPub,
Features: feats,
}
if err := WriteKIPMessage(rc, KIPTypeClientHello, ch.EncodePayload()); err != nil {
return 0, fmt.Errorf("write client hello failed: %w", err)
}
msg, err := ReadKIPMessage(rc)
if err != nil {
return 0, fmt.Errorf("read server hello failed: %w", err)
}
if msg.Type != KIPTypeServerHello {
return 0, fmt.Errorf("unexpected handshake message: %d", msg.Type)
}
sh, err := DecodeKIPServerHelloPayload(msg.Payload)
if err != nil {
return 0, fmt.Errorf("decode server hello failed: %w", err)
}
if sh.Nonce != nonce {
return 0, fmt.Errorf("handshake nonce mismatch")
}
shared, err := x25519SharedSecret(ephemeral, sh.ServerPub[:])
if err != nil {
return 0, fmt.Errorf("ecdh failed: %w", err)
}
sessC2S, sessS2C, err := deriveSessionDirectionalBases(seed, shared, nonce)
if err != nil {
return 0, fmt.Errorf("derive session keys failed: %w", err)
}
if err := rc.Rekey(sessC2S, sessS2C); err != nil {
return 0, fmt.Errorf("rekey failed: %w", err)
}
return sh.SelectedFeats, nil
}

View File

@@ -124,13 +124,18 @@ func runPackedTCPSession(id int, cfg *ProtocolConfig, errCh chan<- error) {
// Server side
go func() {
session, err := ServerHandshake(serverConn, cfg)
c, meta, err := ServerHandshake(serverConn, cfg)
if err != nil {
errCh <- fmt.Errorf("server handshake tcp: %w", err)
return
}
defer session.Conn.Close()
defer c.Close()
session, err := ReadServerSession(c, meta)
if err != nil {
errCh <- fmt.Errorf("server read session tcp: %w", err)
return
}
if session.Type != SessionTypeTCP {
errCh <- fmt.Errorf("unexpected session type: %v", session.Type)
return
@@ -159,8 +164,8 @@ func runPackedTCPSession(id int, cfg *ProtocolConfig, errCh chan<- error) {
errCh <- fmt.Errorf("encode address: %w", err)
return
}
if _, err := cConn.Write(addrBuf); err != nil {
errCh <- fmt.Errorf("client send addr: %w", err)
if err := WriteKIPMessage(cConn, KIPTypeOpenTCP, addrBuf); err != nil {
errCh <- fmt.Errorf("client send open tcp: %w", err)
return
}
@@ -182,13 +187,18 @@ func runPackedUoTSession(id int, cfg *ProtocolConfig, errCh chan<- error) {
// Server side
go func() {
session, err := ServerHandshake(serverConn, cfg)
c, meta, err := ServerHandshake(serverConn, cfg)
if err != nil {
errCh <- fmt.Errorf("server handshake uot: %w", err)
return
}
defer session.Conn.Close()
defer c.Close()
session, err := ReadServerSession(c, meta)
if err != nil {
errCh <- fmt.Errorf("server read session uot: %w", err)
return
}
if session.Type != SessionTypeUoT {
errCh <- fmt.Errorf("unexpected session type: %v", session.Type)
return
@@ -208,8 +218,8 @@ func runPackedUoTSession(id int, cfg *ProtocolConfig, errCh chan<- error) {
}
defer cConn.Close()
if err := WritePreface(cConn); err != nil {
errCh <- fmt.Errorf("client write preface: %w", err)
if err := WriteKIPMessage(cConn, KIPTypeStartUoT, nil); err != nil {
errCh <- fmt.Errorf("client start uot: %w", err)
return
}

View File

@@ -15,6 +15,14 @@ type HTTPMaskTunnelServer struct {
}
func NewHTTPMaskTunnelServer(cfg *ProtocolConfig) *HTTPMaskTunnelServer {
return newHTTPMaskTunnelServer(cfg, false)
}
func NewHTTPMaskTunnelServerWithFallback(cfg *ProtocolConfig) *HTTPMaskTunnelServer {
return newHTTPMaskTunnelServer(cfg, true)
}
func newHTTPMaskTunnelServer(cfg *ProtocolConfig, passThroughOnReject bool) *HTTPMaskTunnelServer {
if cfg == nil {
return &HTTPMaskTunnelServer{}
}
@@ -22,11 +30,13 @@ func NewHTTPMaskTunnelServer(cfg *ProtocolConfig) *HTTPMaskTunnelServer {
var ts *httpmask.TunnelServer
if !cfg.DisableHTTPMask {
switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) {
case "stream", "poll", "auto":
case "stream", "poll", "auto", "ws":
ts = httpmask.NewTunnelServer(httpmask.TunnelServerOptions{
Mode: cfg.HTTPMaskMode,
PathRoot: cfg.HTTPMaskPathRoot,
AuthKey: ClientAEADSeed(cfg.Key),
AuthKey: ServerAEADSeed(cfg.Key),
// When upstream fallback is enabled, preserve rejected HTTP requests for the caller.
PassThroughOnReject: passThroughOnReject,
})
}
}
@@ -62,6 +72,14 @@ func (s *HTTPMaskTunnelServer) WrapConn(rawConn net.Conn) (handshakeConn net.Con
case httpmask.HandleStartTunnel:
inner := *s.cfg
inner.DisableHTTPMask = true
// HTTPMask tunnel modes (stream/poll/auto/ws) add extra round trips before the first
// handshake bytes can reach ServerHandshake, especially under high concurrency.
// Bump the handshake timeout for tunneled conns to avoid flaky timeouts while keeping
// the default strict for raw TCP handshakes.
const minTunneledHandshakeTimeoutSeconds = 15
if inner.HandshakeTimeoutSeconds <= 0 || inner.HandshakeTimeoutSeconds < minTunneledHandshakeTimeoutSeconds {
inner.HandshakeTimeoutSeconds = minTunneledHandshakeTimeoutSeconds
}
return c, &inner, false, nil
default:
return nil, nil, true, nil
@@ -70,7 +88,7 @@ func (s *HTTPMaskTunnelServer) WrapConn(rawConn net.Conn) (handshakeConn net.Con
type TunnelDialer func(ctx context.Context, network, addr string) (net.Conn, error)
// DialHTTPMaskTunnel dials a CDN-capable HTTP tunnel (stream/poll/auto) and returns a stream carrying raw Sudoku bytes.
// DialHTTPMaskTunnel dials a CDN-capable HTTP tunnel (stream/poll/auto/ws) and returns a stream carrying raw Sudoku bytes.
func DialHTTPMaskTunnel(ctx context.Context, serverAddress string, cfg *ProtocolConfig, dial TunnelDialer, upgrade func(net.Conn) (net.Conn, error)) (net.Conn, error) {
if cfg == nil {
return nil, fmt.Errorf("config is required")
@@ -79,7 +97,7 @@ func DialHTTPMaskTunnel(ctx context.Context, serverAddress string, cfg *Protocol
return nil, fmt.Errorf("http mask is disabled")
}
switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) {
case "stream", "poll", "auto":
case "stream", "poll", "auto", "ws":
default:
return nil, fmt.Errorf("http-mask-mode=%q does not use http tunnel", cfg.HTTPMaskMode)
}
@@ -94,64 +112,3 @@ func DialHTTPMaskTunnel(ctx context.Context, serverAddress string, cfg *Protocol
DialContext: dial,
})
}
type HTTPMaskTunnelClient struct {
mode string
pathRoot string
authKey string
client *httpmask.TunnelClient
}
func NewHTTPMaskTunnelClient(serverAddress string, cfg *ProtocolConfig, dial TunnelDialer) (*HTTPMaskTunnelClient, error) {
if cfg == nil {
return nil, fmt.Errorf("config is required")
}
if cfg.DisableHTTPMask {
return nil, fmt.Errorf("http mask is disabled")
}
switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) {
case "stream", "poll", "auto":
default:
return nil, fmt.Errorf("http-mask-mode=%q does not use http tunnel", cfg.HTTPMaskMode)
}
switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMultiplex)) {
case "auto", "on":
default:
return nil, fmt.Errorf("http-mask-multiplex=%q does not enable reuse", cfg.HTTPMaskMultiplex)
}
c, err := httpmask.NewTunnelClient(serverAddress, httpmask.TunnelClientOptions{
TLSEnabled: cfg.HTTPMaskTLSEnabled,
HostOverride: cfg.HTTPMaskHost,
DialContext: dial,
})
if err != nil {
return nil, err
}
return &HTTPMaskTunnelClient{
mode: cfg.HTTPMaskMode,
pathRoot: cfg.HTTPMaskPathRoot,
authKey: ClientAEADSeed(cfg.Key),
client: c,
}, nil
}
func (c *HTTPMaskTunnelClient) Dial(ctx context.Context, upgrade func(net.Conn) (net.Conn, error)) (net.Conn, error) {
if c == nil || c.client == nil {
return nil, fmt.Errorf("nil httpmask tunnel client")
}
return c.client.DialTunnel(ctx, httpmask.TunnelDialOptions{
Mode: c.mode,
PathRoot: c.pathRoot,
AuthKey: c.authKey,
Upgrade: upgrade,
})
}
func (c *HTTPMaskTunnelClient) CloseIdleConnections() {
if c == nil || c.client == nil {
return
}
c.client.CloseIdleConnections()
}

View File

@@ -61,7 +61,7 @@ func startTunnelServer(t *testing.T, cfg *ProtocolConfig, handle func(*ServerSes
return
}
session, err := ServerHandshake(handshakeConn, handshakeCfg)
cConn, meta, err := ServerHandshake(handshakeConn, handshakeCfg)
if err != nil {
_ = handshakeConn.Close()
if handshakeConn != conn {
@@ -70,8 +70,13 @@ func startTunnelServer(t *testing.T, cfg *ProtocolConfig, handle func(*ServerSes
errC <- err
return
}
defer session.Conn.Close()
defer cConn.Close()
session, err := ReadServerSession(cConn, meta)
if err != nil {
errC <- err
return
}
if handleErr := handle(session); handleErr != nil {
errC <- handleErr
}
@@ -172,7 +177,7 @@ func TestHTTPMaskTunnel_Stream_TCPRoundTrip(t *testing.T) {
if err != nil {
t.Fatalf("encode addr: %v", err)
}
if _, err := cConn.Write(addrBuf); err != nil {
if err := WriteKIPMessage(cConn, KIPTypeOpenTCP, addrBuf); err != nil {
t.Fatalf("write addr: %v", err)
}
@@ -239,8 +244,8 @@ func TestHTTPMaskTunnel_Poll_UoTRoundTrip(t *testing.T) {
}
defer cConn.Close()
if err := WritePreface(cConn); err != nil {
t.Fatalf("write preface: %v", err)
if err := WriteKIPMessage(cConn, KIPTypeStartUoT, nil); err != nil {
t.Fatalf("start uot: %v", err)
}
if err := WriteDatagram(cConn, target, payload); err != nil {
t.Fatalf("write datagram: %v", err)
@@ -305,7 +310,68 @@ func TestHTTPMaskTunnel_Auto_TCPRoundTrip(t *testing.T) {
if err != nil {
t.Fatalf("encode addr: %v", err)
}
if _, err := cConn.Write(addrBuf); err != nil {
if err := WriteKIPMessage(cConn, KIPTypeOpenTCP, addrBuf); err != nil {
t.Fatalf("write addr: %v", err)
}
buf := make([]byte, 2)
if _, err := io.ReadFull(cConn, buf); err != nil {
t.Fatalf("read: %v", err)
}
if string(buf) != "ok" {
t.Fatalf("unexpected payload: %q", buf)
}
stop()
for err := range errCh {
t.Fatalf("server error: %v", err)
}
}
func TestHTTPMaskTunnel_WS_TCPRoundTrip(t *testing.T) {
key := "tunnel-ws-key"
target := "1.1.1.1:80"
serverCfg := newTunnelTestTable(t, key)
serverCfg.HTTPMaskMode = "ws"
addr, stop, errCh := startTunnelServer(t, serverCfg, func(s *ServerSession) error {
if s.Type != SessionTypeTCP {
return fmt.Errorf("unexpected session type: %v", s.Type)
}
if s.Target != target {
return fmt.Errorf("target mismatch: %s", s.Target)
}
_, _ = s.Conn.Write([]byte("ok"))
return nil
})
defer stop()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
clientCfg := *serverCfg
clientCfg.ServerAddress = addr
tunnelConn, err := DialHTTPMaskTunnel(ctx, clientCfg.ServerAddress, &clientCfg, (&net.Dialer{}).DialContext, nil)
if err != nil {
t.Fatalf("dial tunnel: %v", err)
}
defer tunnelConn.Close()
handshakeCfg := clientCfg
handshakeCfg.DisableHTTPMask = true
cConn, err := ClientHandshake(tunnelConn, &handshakeCfg)
if err != nil {
t.Fatalf("client handshake: %v", err)
}
defer cConn.Close()
addrBuf, err := EncodeAddress(target)
if err != nil {
t.Fatalf("encode addr: %v", err)
}
if err := WriteKIPMessage(cConn, KIPTypeOpenTCP, addrBuf); err != nil {
t.Fatalf("write addr: %v", err)
}
@@ -406,7 +472,7 @@ func TestHTTPMaskTunnel_Soak_Concurrent(t *testing.T) {
runErr <- fmt.Errorf("encode addr: %w", err)
return
}
if _, err := cConn.Write(addrBuf); err != nil {
if err := WriteKIPMessage(cConn, KIPTypeOpenTCP, addrBuf); err != nil {
runErr <- fmt.Errorf("write addr: %w", err)
return
}

97
transport/sudoku/init.go Normal file
View File

@@ -0,0 +1,97 @@
package sudoku
import (
"encoding/hex"
"fmt"
"strings"
"github.com/metacubex/edwards25519"
"github.com/metacubex/mihomo/transport/sudoku/crypto"
"github.com/metacubex/mihomo/transport/sudoku/obfs/sudoku"
)
func NewTable(key string, tableType string) *sudoku.Table {
table, err := NewTableWithCustom(key, tableType, "")
if err != nil {
panic(fmt.Sprintf("[Sudoku] failed to init tables: %v", err))
}
return table
}
func NewTableWithCustom(key string, tableType string, customTable string) (*sudoku.Table, error) {
table, err := sudoku.NewTableWithCustom(key, tableType, customTable)
if err != nil {
return nil, err
}
return table, nil
}
// ClientAEADSeed returns a canonical "seed" that is stable between client private key material and server public key.
func ClientAEADSeed(key string) string {
key = strings.TrimSpace(key)
if key == "" {
return ""
}
b, err := hex.DecodeString(key)
if err != nil {
return key
}
// Client-side key material can be:
// - split private key: 64 bytes hex (r||k)
// - master private scalar: 32 bytes hex (x)
// - PSK string: non-hex
//
// We intentionally do NOT treat a 32-byte hex as a public key here; the client is expected
// to carry private material. Server-side should use ServerAEADSeed for public keys.
switch len(b) {
case 64:
case 32:
default:
return key
}
if recovered, err := crypto.RecoverPublicKey(key); err == nil {
return crypto.EncodePoint(recovered)
}
return key
}
// ServerAEADSeed returns a canonical seed for server-side configuration.
//
// When key is a public key (32-byte compressed point, hex), it returns the canonical point encoding.
// When key is private key material (split/master scalar), it derives and returns the public key.
func ServerAEADSeed(key string) string {
key = strings.TrimSpace(key)
if key == "" {
return ""
}
b, err := hex.DecodeString(key)
if err != nil {
return key
}
// Prefer interpreting 32-byte hex as a public key point, to avoid accidental scalar parsing.
if len(b) == 32 {
if p, err := new(edwards25519.Point).SetBytes(b); err == nil {
return hex.EncodeToString(p.Bytes())
}
}
// Fall back to client-side rules for private key materials / other formats.
return ClientAEADSeed(key)
}
// GenKeyPair generates a client "available private key" and the corresponding server public key.
func GenKeyPair() (privateKey, publicKey string, err error) {
pair, err := crypto.GenerateMasterKey()
if err != nil {
return "", "", err
}
availablePrivateKey, err := crypto.SplitPrivateKey(pair.Private)
if err != nil {
return "", "", err
}
return availablePrivateKey, crypto.EncodePoint(pair.Public), nil
}

View File

@@ -0,0 +1,44 @@
package sudoku
import (
"crypto/rand"
"encoding/hex"
"testing"
"github.com/metacubex/edwards25519"
"github.com/stretchr/testify/require"
)
func TestClientAEADSeed_IsStableForPrivAndPub(t *testing.T) {
for i := 0; i < 64; i++ {
priv, pub, err := GenKeyPair()
require.NoError(t, err)
require.Equal(t, pub, ClientAEADSeed(priv))
require.Equal(t, pub, ServerAEADSeed(pub))
require.Equal(t, pub, ServerAEADSeed(priv))
}
}
func TestClientAEADSeed_Supports32ByteMasterScalar(t *testing.T) {
var seed [64]byte
_, err := rand.Read(seed[:])
require.NoError(t, err)
s, err := edwards25519.NewScalar().SetUniformBytes(seed[:])
require.NoError(t, err)
keyHex := hex.EncodeToString(s.Bytes())
require.Len(t, keyHex, 64)
require.NotEqual(t, keyHex, ClientAEADSeed(keyHex))
require.Equal(t, ClientAEADSeed(keyHex), ServerAEADSeed(ClientAEADSeed(keyHex)))
}
func TestServerAEADSeed_LeavesPublicKeyAsIs(t *testing.T) {
for i := 0; i < 64; i++ {
priv, pub, err := GenKeyPair()
require.NoError(t, err)
require.Equal(t, pub, ServerAEADSeed(pub))
require.Equal(t, pub, ServerAEADSeed(priv))
}
}

206
transport/sudoku/kip.go Normal file
View File

@@ -0,0 +1,206 @@
package sudoku
import (
"bytes"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"io"
"strings"
"time"
)
const (
kipMagic = "kip"
KIPTypeClientHello byte = 0x01
KIPTypeServerHello byte = 0x02
KIPTypeOpenTCP byte = 0x10
KIPTypeStartMux byte = 0x11
KIPTypeStartUoT byte = 0x12
KIPTypeKeepAlive byte = 0x14
)
// KIP feature bits are advisory capability flags negotiated during the handshake.
// They represent control-plane message families.
const (
KIPFeatOpenTCP uint32 = 1 << 0
KIPFeatMux uint32 = 1 << 1
KIPFeatUoT uint32 = 1 << 2
KIPFeatKeepAlive uint32 = 1 << 4
KIPFeatAll = KIPFeatOpenTCP | KIPFeatMux | KIPFeatUoT | KIPFeatKeepAlive
)
const (
kipHelloUserHashSize = 8
kipHelloNonceSize = 16
kipHelloPubSize = 32
kipMaxPayload = 64 * 1024
)
var errKIP = errors.New("kip protocol error")
type KIPMessage struct {
Type byte
Payload []byte
}
func WriteKIPMessage(w io.Writer, typ byte, payload []byte) error {
if w == nil {
return fmt.Errorf("%w: nil writer", errKIP)
}
if len(payload) > kipMaxPayload {
return fmt.Errorf("%w: payload too large: %d", errKIP, len(payload))
}
var hdr [3 + 1 + 2]byte
copy(hdr[:3], []byte(kipMagic))
hdr[3] = typ
binary.BigEndian.PutUint16(hdr[4:], uint16(len(payload)))
if err := writeFull(w, hdr[:]); err != nil {
return err
}
if len(payload) == 0 {
return nil
}
return writeFull(w, payload)
}
func ReadKIPMessage(r io.Reader) (*KIPMessage, error) {
if r == nil {
return nil, fmt.Errorf("%w: nil reader", errKIP)
}
var hdr [3 + 1 + 2]byte
if _, err := io.ReadFull(r, hdr[:]); err != nil {
return nil, err
}
if string(hdr[:3]) != kipMagic {
return nil, fmt.Errorf("%w: bad magic", errKIP)
}
typ := hdr[3]
n := int(binary.BigEndian.Uint16(hdr[4:]))
if n < 0 || n > kipMaxPayload {
return nil, fmt.Errorf("%w: invalid payload length: %d", errKIP, n)
}
var payload []byte
if n > 0 {
payload = make([]byte, n)
if _, err := io.ReadFull(r, payload); err != nil {
return nil, err
}
}
return &KIPMessage{Type: typ, Payload: payload}, nil
}
type KIPClientHello struct {
Timestamp time.Time
UserHash [kipHelloUserHashSize]byte
Nonce [kipHelloNonceSize]byte
ClientPub [kipHelloPubSize]byte
Features uint32
}
type KIPServerHello struct {
Nonce [kipHelloNonceSize]byte
ServerPub [kipHelloPubSize]byte
SelectedFeats uint32
}
func kipUserHashFromKey(psk string) [kipHelloUserHashSize]byte {
var out [kipHelloUserHashSize]byte
psk = strings.TrimSpace(psk)
if psk == "" {
return out
}
// Align with upstream: when the client carries private key material (or even just a public key),
// prefer hashing the raw hex bytes so different split/master keys can be distinguished.
if keyBytes, err := hex.DecodeString(psk); err == nil && len(keyBytes) > 0 {
sum := sha256.Sum256(keyBytes)
copy(out[:], sum[:kipHelloUserHashSize])
return out
}
sum := sha256.Sum256([]byte(psk))
copy(out[:], sum[:kipHelloUserHashSize])
return out
}
func KIPUserHashHexFromKey(psk string) string {
uh := kipUserHashFromKey(psk)
return hex.EncodeToString(uh[:])
}
func (m *KIPClientHello) EncodePayload() []byte {
var b bytes.Buffer
var tmp [8]byte
binary.BigEndian.PutUint64(tmp[:], uint64(m.Timestamp.Unix()))
b.Write(tmp[:])
b.Write(m.UserHash[:])
b.Write(m.Nonce[:])
b.Write(m.ClientPub[:])
var f [4]byte
binary.BigEndian.PutUint32(f[:], m.Features)
b.Write(f[:])
return b.Bytes()
}
func DecodeKIPClientHelloPayload(payload []byte) (*KIPClientHello, error) {
const minLen = 8 + kipHelloUserHashSize + kipHelloNonceSize + kipHelloPubSize + 4
if len(payload) < minLen {
return nil, fmt.Errorf("%w: client hello too short", errKIP)
}
var h KIPClientHello
ts := int64(binary.BigEndian.Uint64(payload[:8]))
h.Timestamp = time.Unix(ts, 0)
off := 8
copy(h.UserHash[:], payload[off:off+kipHelloUserHashSize])
off += kipHelloUserHashSize
copy(h.Nonce[:], payload[off:off+kipHelloNonceSize])
off += kipHelloNonceSize
copy(h.ClientPub[:], payload[off:off+kipHelloPubSize])
off += kipHelloPubSize
h.Features = binary.BigEndian.Uint32(payload[off : off+4])
return &h, nil
}
func (m *KIPServerHello) EncodePayload() []byte {
var b bytes.Buffer
b.Write(m.Nonce[:])
b.Write(m.ServerPub[:])
var f [4]byte
binary.BigEndian.PutUint32(f[:], m.SelectedFeats)
b.Write(f[:])
return b.Bytes()
}
func DecodeKIPServerHelloPayload(payload []byte) (*KIPServerHello, error) {
const want = kipHelloNonceSize + kipHelloPubSize + 4
if len(payload) != want {
return nil, fmt.Errorf("%w: server hello bad len: %d", errKIP, len(payload))
}
var h KIPServerHello
off := 0
copy(h.Nonce[:], payload[off:off+kipHelloNonceSize])
off += kipHelloNonceSize
copy(h.ServerPub[:], payload[off:off+kipHelloPubSize])
off += kipHelloPubSize
h.SelectedFeats = binary.BigEndian.Uint32(payload[off : off+4])
return &h, nil
}
func writeFull(w io.Writer, b []byte) error {
for len(b) > 0 {
n, err := w.Write(b)
if err != nil {
return err
}
b = b[n:]
}
return nil
}

View File

@@ -10,19 +10,14 @@ import (
"github.com/metacubex/mihomo/transport/sudoku/multiplex"
)
const (
MultiplexMagicByte byte = multiplex.MagicByte
MultiplexVersion byte = multiplex.Version
)
// StartMultiplexClient writes the multiplex preface and upgrades an already-handshaked Sudoku tunnel into a multiplex session.
// StartMultiplexClient upgrades an already-handshaked Sudoku tunnel into a multiplex session.
func StartMultiplexClient(conn net.Conn) (*MultiplexClient, error) {
if conn == nil {
return nil, fmt.Errorf("nil conn")
}
if err := multiplex.WritePreface(conn); err != nil {
return nil, fmt.Errorf("write multiplex preface failed: %w", err)
if err := WriteKIPMessage(conn, KIPTypeStartMux, nil); err != nil {
return nil, fmt.Errorf("write mux start failed: %w", err)
}
sess, err := multiplex.NewClientSession(conn)
@@ -77,20 +72,10 @@ func (c *MultiplexClient) IsClosed() bool {
}
// AcceptMultiplexServer upgrades a server-side, already-handshaked Sudoku connection into a multiplex session.
//
// The caller must have already consumed the multiplex magic byte (MultiplexMagicByte). This function consumes the
// multiplex version byte and starts the session.
func AcceptMultiplexServer(conn net.Conn) (*MultiplexServer, error) {
if conn == nil {
return nil, fmt.Errorf("nil conn")
}
v, err := multiplex.ReadVersion(conn)
if err != nil {
return nil, err
}
if err := multiplex.ValidateVersion(v); err != nil {
return nil, err
}
sess, err := multiplex.NewServerSession(conn)
if err != nil {
return nil, err

View File

@@ -1,39 +0,0 @@
package multiplex
import (
"fmt"
"io"
)
const (
// MagicByte marks a Sudoku tunnel connection that will switch into multiplex mode.
// It is sent after the Sudoku handshake + downlink mode byte.
//
// Keep it distinct from UoTMagicByte and address type bytes.
MagicByte byte = 0xED
Version byte = 0x01
)
func WritePreface(w io.Writer) error {
if w == nil {
return fmt.Errorf("nil writer")
}
_, err := w.Write([]byte{MagicByte, Version})
return err
}
func ReadVersion(r io.Reader) (byte, error) {
var b [1]byte
if _, err := io.ReadFull(r, b[:]); err != nil {
return 0, err
}
return b[0], nil
}
func ValidateVersion(v byte) error {
if v != Version {
return fmt.Errorf("unsupported multiplex version: %d", v)
}
return nil
}

View File

@@ -18,7 +18,6 @@ func TestUserHash_StableAcrossTableRotation(t *testing.T) {
sudokuobfs.NewTable("seed-b", "prefer_ascii"),
}
key := "userhash-stability-key"
target := "example.com:80"
serverCfg := DefaultConfig()
serverCfg.Key = key
@@ -48,13 +47,16 @@ func TestUserHash_StableAcrossTableRotation(t *testing.T) {
}
go func(conn net.Conn) {
defer conn.Close()
session, err := ServerHandshake(conn, serverCfg)
_, meta, err := ServerHandshake(conn, serverCfg)
if err != nil {
errCh <- err
return
}
defer session.Conn.Close()
hashCh <- session.UserHash
if meta == nil || meta.UserHash == "" {
errCh <- io.ErrUnexpectedEOF
return
}
hashCh <- meta.UserHash
}(c)
}
}()
@@ -77,15 +79,6 @@ func TestUserHash_StableAcrossTableRotation(t *testing.T) {
t.Fatalf("handshake %d: %v", i, err)
}
addrBuf, err := EncodeAddress(target)
if err != nil {
_ = cConn.Close()
t.Fatalf("encode addr %d: %v", i, err)
}
if _, err := cConn.Write(addrBuf); err != nil {
_ = cConn.Close()
t.Fatalf("write addr %d: %v", i, err)
}
_ = cConn.Close()
}
@@ -145,18 +138,22 @@ func TestMultiplex_TCP_Echo(t *testing.T) {
}
defer raw.Close()
session, err := ServerHandshake(raw, serverCfg)
c, meta, err := ServerHandshake(raw, serverCfg)
if err != nil {
return
}
atomic.AddInt64(&handshakes, 1)
session, err := ReadServerSession(c, meta)
if err != nil {
return
}
if session.Type != SessionTypeMultiplex {
_ = session.Conn.Close()
_ = c.Close()
return
}
mux, err := AcceptMultiplexServer(session.Conn)
mux, err := AcceptMultiplexServer(c)
if err != nil {
return
}
@@ -240,21 +237,3 @@ func TestMultiplex_TCP_Echo(t *testing.T) {
t.Fatalf("unexpected stream count: %d", got)
}
}
func TestMultiplex_Boundary_InvalidVersion(t *testing.T) {
client, server := net.Pipe()
t.Cleanup(func() { _ = client.Close() })
t.Cleanup(func() { _ = server.Close() })
errCh := make(chan error, 1)
go func() {
_, err := AcceptMultiplexServer(server)
errCh <- err
}()
// AcceptMultiplexServer expects the magic byte to have been consumed already; write a bad version byte.
_, _ = client.Write([]byte{0xFF})
if err := <-errCh; err == nil {
t.Fatalf("expected error")
}
}

View File

@@ -6,9 +6,10 @@ import (
"crypto/subtle"
"encoding/base64"
"encoding/binary"
"github.com/metacubex/http"
"strings"
"time"
"github.com/metacubex/http"
)
const (

View File

@@ -16,6 +16,7 @@ import (
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/metacubex/mihomo/component/ca"
@@ -32,6 +33,7 @@ const (
TunnelModeStream TunnelMode = "stream"
TunnelModePoll TunnelMode = "poll"
TunnelModeAuto TunnelMode = "auto"
TunnelModeWS TunnelMode = "ws"
)
func normalizeTunnelMode(mode string) TunnelMode {
@@ -44,6 +46,8 @@ func normalizeTunnelMode(mode string) TunnelMode {
return TunnelModePoll
case string(TunnelModeAuto):
return TunnelModeAuto
case string(TunnelModeWS):
return TunnelModeWS
default:
// Be conservative: unknown => legacy
return TunnelModeLegacy
@@ -88,7 +92,6 @@ type TunnelClientOptions struct {
}
type TunnelClient struct {
client *http.Client
transport *http.Transport
target httpClientTarget
}
@@ -105,7 +108,6 @@ func NewTunnelClient(serverAddress string, opts TunnelClientOptions) (*TunnelCli
}
return &TunnelClient{
client: &http.Client{Transport: transport},
transport: transport,
target: target,
}, nil
@@ -119,7 +121,7 @@ func (c *TunnelClient) CloseIdleConnections() {
}
func (c *TunnelClient) DialTunnel(ctx context.Context, opts TunnelDialOptions) (net.Conn, error) {
if c == nil || c.client == nil {
if c == nil || c.transport == nil {
return nil, fmt.Errorf("nil tunnel client")
}
tm := normalizeTunnelMode(opts.Mode)
@@ -127,25 +129,31 @@ func (c *TunnelClient) DialTunnel(ctx context.Context, opts TunnelDialOptions) (
return nil, fmt.Errorf("legacy mode does not use http tunnel")
}
// Create a per-dial client while sharing the underlying Transport for connection reuse.
// This matches upstream behavior and avoids potential client-level concurrency pitfalls.
client := &http.Client{Transport: c.transport}
switch tm {
case TunnelModeStream:
return dialStreamWithClient(ctx, c.client, c.target, opts)
return dialStreamWithClient(ctx, client, c.target, opts)
case TunnelModePoll:
return dialPollWithClient(ctx, c.client, c.target, opts)
return dialPollWithClient(ctx, client, c.target, opts)
case TunnelModeWS:
return nil, fmt.Errorf("ws mode does not support TunnelClient reuse")
case TunnelModeAuto:
streamCtx, cancelX := context.WithTimeout(ctx, 3*time.Second)
c1, errX := dialStreamWithClient(streamCtx, c.client, c.target, opts)
c1, errX := dialStreamWithClient(streamCtx, client, c.target, opts)
cancelX()
if errX == nil {
return c1, nil
}
c2, errP := dialPollWithClient(ctx, c.client, c.target, opts)
c2, errP := dialPollWithClient(ctx, client, c.target, opts)
if errP == nil {
return c2, nil
}
return nil, fmt.Errorf("auto tunnel failed: stream: %v; poll: %w", errX, errP)
default:
return dialStreamWithClient(ctx, c.client, c.target, opts)
return dialStreamWithClient(ctx, client, c.target, opts)
}
}
@@ -166,6 +174,8 @@ func DialTunnel(ctx context.Context, serverAddress string, opts TunnelDialOption
return dialStreamFn(ctx, serverAddress, opts)
case TunnelModePoll:
return dialPollFn(ctx, serverAddress, opts)
case TunnelModeWS:
return dialWS(ctx, serverAddress, opts)
case TunnelModeAuto:
// "stream" can hang on some CDNs that buffer uploads until request body completes.
// Keep it on a short leash so we can fall back to poll within the caller's deadline.
@@ -306,6 +316,36 @@ type sessionDialInfo struct {
auth *tunnelAuth
}
type httpStatusError struct {
code int
status string
}
func (e *httpStatusError) Error() string {
if e == nil {
return "bad status"
}
if e.status != "" {
return "bad status: " + e.status
}
return "bad status"
}
func isRetryableStatusCode(code int) bool {
return code == http.StatusRequestTimeout || code == http.StatusTooManyRequests || code >= 500
}
type idleConnCloser interface{ CloseIdleConnections() }
func closeIdleConnections(client *http.Client) {
if client == nil || client.Transport == nil {
return
}
if c, ok := client.Transport.(idleConnCloser); ok {
c.CloseIdleConnections()
}
}
func dialSessionWithClient(ctx context.Context, client *http.Client, target httpClientTarget, mode TunnelMode, opts TunnelDialOptions) (*sessionDialInfo, error) {
if client == nil {
return nil, fmt.Errorf("nil http client")
@@ -313,25 +353,61 @@ func dialSessionWithClient(ctx context.Context, client *http.Client, target http
auth := newTunnelAuth(opts.AuthKey, 0)
authorizeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/session")}).String()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, authorizeURL, nil)
if err != nil {
return nil, err
}
req.Host = target.headerHost
applyTunnelHeaders(req.Header, target.headerHost, mode)
applyTunnelAuth(req, auth, mode, http.MethodGet, "/session")
resp, err := client.Do(req)
if err != nil {
return nil, err
}
bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, 4*1024))
_ = resp.Body.Close()
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%s authorize bad status: %s (%s)", mode, resp.Status, strings.TrimSpace(string(bodyBytes)))
var bodyBytes []byte
for attempt := 0; ; attempt++ {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, authorizeURL, nil)
if err != nil {
return nil, err
}
req.Host = target.headerHost
applyTunnelHeaders(req.Header, target.headerHost, mode)
applyTunnelAuth(req, auth, mode, http.MethodGet, "/session")
resp, err := client.Do(req)
if err != nil {
// Transient failure on reused keep-alive conns (multiplex=auto). Retry a few times.
if attempt < 2 && (isDialError(err) || isRetryableRequestError(err)) {
closeIdleConnections(client)
select {
case <-time.After(25 * time.Millisecond):
continue
case <-ctx.Done():
return nil, err
}
}
return nil, err
}
bodyBytes, err = io.ReadAll(io.LimitReader(resp.Body, 4*1024))
_ = resp.Body.Close()
if err != nil {
if attempt < 2 && isRetryableRequestError(err) {
closeIdleConnections(client)
select {
case <-time.After(25 * time.Millisecond):
continue
case <-ctx.Done():
return nil, err
}
}
return nil, err
}
if resp.StatusCode != http.StatusOK {
// Retry some transient proxy/CDN errors.
if attempt < 2 && resp.StatusCode >= 500 {
closeIdleConnections(client)
select {
case <-time.After(25 * time.Millisecond):
continue
case <-ctx.Done():
return nil, fmt.Errorf("%s authorize bad status: %s (%s)", mode, resp.Status, strings.TrimSpace(string(bodyBytes)))
}
}
return nil, fmt.Errorf("%s authorize bad status: %s (%s)", mode, resp.Status, strings.TrimSpace(string(bodyBytes)))
}
break
}
token, err := parseTunnelToken(bodyBytes)
@@ -544,9 +620,8 @@ type streamSplitConn struct {
auth *tunnelAuth
}
func (c *streamSplitConn) Close() error {
_ = c.closeWithError(io.ErrClosedPipe)
func (c *streamSplitConn) closeWithError(err error) error {
_ = c.queuedConn.closeWithError(err)
if c.cancel != nil {
c.cancel()
}
@@ -554,6 +629,8 @@ func (c *streamSplitConn) Close() error {
return nil
}
func (c *streamSplitConn) Close() error { return c.closeWithError(io.ErrClosedPipe) }
func newStreamSplitConnFromInfo(info *sessionDialInfo) *streamSplitConn {
if info == nil {
return nil
@@ -659,7 +736,7 @@ func (c *streamSplitConn) pullLoop() {
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, c.pullURL, nil)
if err != nil {
cancel()
_ = c.Close()
_ = c.closeWithError(fmt.Errorf("stream pull build request failed: %w", err))
return
}
req.Host = c.headerHost
@@ -669,8 +746,9 @@ func (c *streamSplitConn) pullLoop() {
resp, err := c.client.Do(req)
if err != nil {
cancel()
if isDialError(err) && dialRetry < maxDialRetry {
if (isDialError(err) || isRetryableRequestError(err)) && dialRetry < maxDialRetry {
dialRetry++
closeIdleConnections(c.client)
select {
case <-time.After(backoff):
case <-c.closed:
@@ -682,16 +760,33 @@ func (c *streamSplitConn) pullLoop() {
}
continue
}
_ = c.Close()
_ = c.closeWithError(fmt.Errorf("stream pull request failed: %w", err))
return
}
dialRetry = 0
backoff = minBackoff
if resp.StatusCode != http.StatusOK {
if isRetryableStatusCode(resp.StatusCode) && dialRetry < maxDialRetry {
dialRetry++
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024))
_ = resp.Body.Close()
cancel()
closeIdleConnections(c.client)
select {
case <-time.After(backoff):
case <-c.closed:
return
}
backoff *= 2
if backoff > maxBackoff {
backoff = maxBackoff
}
continue
}
_ = resp.Body.Close()
cancel()
_ = c.Close()
_ = c.closeWithError(fmt.Errorf("stream pull bad status: %s", resp.Status))
return
}
@@ -717,7 +812,12 @@ func (c *streamSplitConn) pullLoop() {
// Long-poll ended; retry.
break
}
_ = c.Close()
// Some environments may sporadically reset the HTTP connection under load; treat
// it as an ended long-poll and retry instead of tearing down the whole tunnel.
if errors.Is(rerr, io.ErrUnexpectedEOF) || isRetryableRequestError(rerr) {
break
}
_ = c.closeWithError(fmt.Errorf("stream pull read failed: %w", rerr))
return
}
}
@@ -735,8 +835,13 @@ func (c *streamSplitConn) pullLoop() {
func (c *streamSplitConn) pushLoop() {
const (
maxBatchBytes = 256 * 1024
flushInterval = 5 * time.Millisecond
// Batching is critical for stability under high concurrency: every flush is a new TCP
// connection in HTTP/1.1, and too many tiny uploads can overwhelm the accept backlog,
// causing sporadic RSTs (connection reset by peer).
//
// Keep this below the server-side maxUploadBytes limit in streamPush().
maxBatchBytes = 512 * 1024
flushInterval = 25 * time.Millisecond
requestTimeout = 20 * time.Second
maxDialRetry = 12
minBackoff = 10 * time.Millisecond
@@ -754,12 +859,18 @@ func (c *streamSplitConn) pushLoop() {
return nil
}
payload := buf.Bytes()
reqCtx, cancel := context.WithTimeout(c.ctx, requestTimeout)
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.pushURL, bytes.NewReader(buf.Bytes()))
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.pushURL, bytes.NewReader(payload))
if err != nil {
cancel()
return err
}
// Be explicit: some http client forks won't auto-populate GetBody, which makes POST retries on stale
// keep-alive connections flaky under multiplex=auto.
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(bytes.NewReader(payload)), nil
}
req.Host = c.headerHost
applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream)
applyTunnelAuth(req, c.auth, TunnelModeStream, http.MethodPost, "/api/v1/upload")
@@ -774,7 +885,7 @@ func (c *streamSplitConn) pushLoop() {
_ = resp.Body.Close()
cancel()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("bad status: %s", resp.Status)
return &httpStatusError{code: resp.StatusCode, status: resp.Status}
}
buf.Reset()
@@ -787,8 +898,22 @@ func (c *streamSplitConn) pushLoop() {
for {
if err := flush(); err == nil {
return nil
} else if isDialError(err) && dialRetry < maxDialRetry {
} else if se := (*httpStatusError)(nil); errors.As(err, &se) && isRetryableStatusCode(se.code) && dialRetry < maxDialRetry {
dialRetry++
closeIdleConnections(c.client)
select {
case <-time.After(backoff):
case <-c.closed:
return io.ErrClosedPipe
}
backoff *= 2
if backoff > maxBackoff {
backoff = maxBackoff
}
continue
} else if (isDialError(err) || isRetryableRequestError(err)) && dialRetry < maxDialRetry {
dialRetry++
closeIdleConnections(c.client)
select {
case <-time.After(backoff):
case <-c.closed:
@@ -829,7 +954,7 @@ func (c *streamSplitConn) pushLoop() {
}
if buf.Len()+len(b) > maxBatchBytes {
if err := flushWithRetry(); err != nil {
_ = c.Close()
_ = c.closeWithError(fmt.Errorf("stream push flush failed: %w", err))
return
}
resetTimer()
@@ -837,14 +962,14 @@ func (c *streamSplitConn) pushLoop() {
_, _ = buf.Write(b)
if buf.Len() >= maxBatchBytes {
if err := flushWithRetry(); err != nil {
_ = c.Close()
_ = c.closeWithError(fmt.Errorf("stream push flush failed: %w", err))
return
}
resetTimer()
}
case <-timer.C:
if err := flushWithRetry(); err != nil {
_ = c.Close()
_ = c.closeWithError(fmt.Errorf("stream push flush failed: %w", err))
return
}
resetTimer()
@@ -858,7 +983,7 @@ func (c *streamSplitConn) pushLoop() {
}
if buf.Len()+len(b) > maxBatchBytes {
if err := flushWithRetry(); err != nil {
_ = c.Close()
_ = c.closeWithError(fmt.Errorf("stream push flush failed: %w", err))
return
}
}
@@ -905,6 +1030,43 @@ func isDialError(err error) bool {
return false
}
func isRetryableRequestError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return false
}
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
return true
}
// net/http may return this when reusing a keep-alive conn that the peer already closed.
// Treat it as retryable: callers already implement bounded backoff retries.
if strings.Contains(strings.ToLower(err.Error()), "server closed idle connection") {
return true
}
// Unwrap common wrappers.
var urlErr *url.Error
if errors.As(err, &urlErr) {
return isRetryableRequestError(urlErr.Err)
}
// Connection-level transient failures.
if errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) {
return true
}
if errors.Is(err, io.ErrClosedPipe) || errors.Is(err, net.ErrClosed) {
return true
}
var netErr net.Error
if errors.As(err, &netErr) {
return netErr.Timeout() || netErr.Temporary()
}
return false
}
func (c *pollConn) closeWithError(err error) error {
_ = c.queuedConn.closeWithError(err)
if c.cancel != nil {
@@ -1012,8 +1174,10 @@ func (c *pollConn) pullLoop() {
default:
}
req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, c.pullURL, nil)
reqCtx, cancel := context.WithTimeout(c.ctx, 30*time.Second)
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, c.pullURL, nil)
if err != nil {
cancel()
_ = c.Close()
return
}
@@ -1023,8 +1187,10 @@ func (c *pollConn) pullLoop() {
resp, err := c.client.Do(req)
if err != nil {
if isDialError(err) && dialRetry < maxDialRetry {
cancel()
if (isDialError(err) || isRetryableRequestError(err)) && dialRetry < maxDialRetry {
dialRetry++
closeIdleConnections(c.client)
select {
case <-time.After(backoff):
case <-c.closed:
@@ -1043,7 +1209,25 @@ func (c *pollConn) pullLoop() {
backoff = minBackoff
if resp.StatusCode != http.StatusOK {
if isRetryableStatusCode(resp.StatusCode) && dialRetry < maxDialRetry {
dialRetry++
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024))
_ = resp.Body.Close()
cancel()
closeIdleConnections(c.client)
select {
case <-time.After(backoff):
case <-c.closed:
return
}
backoff *= 2
if backoff > maxBackoff {
backoff = maxBackoff
}
continue
}
_ = resp.Body.Close()
cancel()
_ = c.closeWithError(fmt.Errorf("poll pull bad status: %s", resp.Status))
return
}
@@ -1068,7 +1252,12 @@ func (c *pollConn) pullLoop() {
}
}
_ = resp.Body.Close()
cancel()
if err := scanner.Err(); err != nil {
// Treat transient stream breaks (RST/EOF) as an ended long-poll and retry.
if errors.Is(err, io.ErrUnexpectedEOF) || isRetryableRequestError(err) {
continue
}
_ = c.closeWithError(fmt.Errorf("poll pull scan failed: %w", err))
return
}
@@ -1077,8 +1266,8 @@ func (c *pollConn) pullLoop() {
func (c *pollConn) pushLoop() {
const (
maxBatchBytes = 64 * 1024
flushInterval = 5 * time.Millisecond
maxBatchBytes = 512 * 1024
flushInterval = 50 * time.Millisecond
maxLineRawBytes = 16 * 1024
maxDialRetry = 12
minBackoff = 10 * time.Millisecond
@@ -1097,12 +1286,16 @@ func (c *pollConn) pushLoop() {
return nil
}
payload := buf.Bytes()
reqCtx, cancel := context.WithTimeout(c.ctx, 20*time.Second)
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.pushURL, bytes.NewReader(buf.Bytes()))
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.pushURL, bytes.NewReader(payload))
if err != nil {
cancel()
return err
}
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(bytes.NewReader(payload)), nil
}
req.Host = c.headerHost
applyTunnelHeaders(req.Header, c.headerHost, TunnelModePoll)
applyTunnelAuth(req, c.auth, TunnelModePoll, http.MethodPost, "/api/v1/upload")
@@ -1117,7 +1310,7 @@ func (c *pollConn) pushLoop() {
_ = resp.Body.Close()
cancel()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("bad status: %s", resp.Status)
return &httpStatusError{code: resp.StatusCode, status: resp.Status}
}
buf.Reset()
@@ -1131,8 +1324,22 @@ func (c *pollConn) pushLoop() {
for {
if err := flush(); err == nil {
return nil
} else if isDialError(err) && dialRetry < maxDialRetry {
} else if se := (*httpStatusError)(nil); errors.As(err, &se) && isRetryableStatusCode(se.code) && dialRetry < maxDialRetry {
dialRetry++
closeIdleConnections(c.client)
select {
case <-time.After(backoff):
case <-c.closed:
return c.closedErr()
}
backoff *= 2
if backoff > maxBackoff {
backoff = maxBackoff
}
continue
} else if (isDialError(err) || isRetryableRequestError(err)) && dialRetry < maxDialRetry {
dialRetry++
closeIdleConnections(c.client)
select {
case <-time.After(backoff):
case <-c.closed:
@@ -1482,6 +1689,16 @@ func (s *TunnelServer) HandleConn(rawConn net.Conn) (HandleResult, net.Conn, err
return HandleDone, nil, nil
}
return s.handlePoll(rawConn, req, headerBytes, buffered)
case TunnelModeWS:
if s.mode != TunnelModeWS && s.mode != TunnelModeAuto {
if s.passThroughOnReject {
return reject()
}
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
_ = rawConn.Close()
return HandleDone, nil, nil
}
return s.handleWS(rawConn, req, headerBytes, buffered)
default:
if s.passThroughOnReject {
return reject()

View File

@@ -0,0 +1,176 @@
package httpmask
import (
"context"
"fmt"
"io"
mrand "math/rand"
"net"
stdhttp "net/http"
"net/url"
"strings"
"time"
"github.com/gobwas/ws"
"github.com/metacubex/tls"
)
func normalizeWSSchemeFromAddress(serverAddress string, tlsEnabled bool) (string, string) {
addr := strings.TrimSpace(serverAddress)
if strings.Contains(addr, "://") {
if u, err := url.Parse(addr); err == nil && u != nil {
switch strings.ToLower(strings.TrimSpace(u.Scheme)) {
case "ws":
return "ws", u.Host
case "wss":
return "wss", u.Host
}
}
}
if tlsEnabled {
return "wss", addr
}
return "ws", addr
}
func normalizeWSDialTarget(serverAddress string, tlsEnabled bool, hostOverride string) (scheme, urlHost, dialAddr, serverName string, err error) {
scheme, addr := normalizeWSSchemeFromAddress(serverAddress, tlsEnabled)
host, port, err := net.SplitHostPort(addr)
if err != nil {
// Allow ws(s)://host without port.
if strings.Contains(addr, ":") {
return "", "", "", "", fmt.Errorf("invalid server address %q: %w", serverAddress, err)
}
switch scheme {
case "wss":
port = "443"
default:
port = "80"
}
host = addr
}
if hostOverride != "" {
// Allow "example.com" or "example.com:443"
if h, p, splitErr := net.SplitHostPort(hostOverride); splitErr == nil {
if h != "" {
hostOverride = h
}
if p != "" {
port = p
}
}
serverName = hostOverride
urlHost = net.JoinHostPort(hostOverride, port)
} else {
serverName = host
urlHost = net.JoinHostPort(host, port)
}
dialAddr = net.JoinHostPort(host, port)
return scheme, urlHost, dialAddr, trimPortForHost(serverName), nil
}
func applyWSHeaders(h stdhttp.Header, host string) {
if h == nil {
return
}
r := rngPool.Get().(*mrand.Rand)
ua := userAgents[r.Intn(len(userAgents))]
accept := accepts[r.Intn(len(accepts))]
lang := acceptLanguages[r.Intn(len(acceptLanguages))]
enc := acceptEncodings[r.Intn(len(acceptEncodings))]
rngPool.Put(r)
h.Set("User-Agent", ua)
h.Set("Accept", accept)
h.Set("Accept-Language", lang)
h.Set("Accept-Encoding", enc)
h.Set("Cache-Control", "no-cache")
h.Set("Pragma", "no-cache")
h.Set("X-Sudoku-Tunnel", string(TunnelModeWS))
h.Set("X-Sudoku-Version", "1")
}
func dialWS(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
if opts.DialContext == nil {
panic("httpmask: DialContext is nil")
}
scheme, urlHost, dialAddr, serverName, err := normalizeWSDialTarget(serverAddress, opts.TLSEnabled, opts.HostOverride)
if err != nil {
return nil, err
}
httpScheme := "http"
if scheme == "wss" {
httpScheme = "https"
}
headerHost := canonicalHeaderHost(urlHost, httpScheme)
auth := newTunnelAuth(opts.AuthKey, 0)
u := &url.URL{
Scheme: scheme,
Host: urlHost,
Path: joinPathRoot(opts.PathRoot, "/ws"),
}
header := make(stdhttp.Header)
applyWSHeaders(header, headerHost)
if auth != nil {
token := auth.token(TunnelModeWS, stdhttp.MethodGet, "/ws", time.Now())
if token != "" {
header.Set("Authorization", "Bearer "+token)
q := u.Query()
q.Set(tunnelAuthQueryKey, token)
u.RawQuery = q.Encode()
}
}
d := ws.Dialer{
Host: headerHost,
Header: ws.HandshakeHeaderHTTP(header),
NetDial: func(dialCtx context.Context, network, addr string) (net.Conn, error) {
if addr == urlHost {
addr = dialAddr
}
return opts.DialContext(dialCtx, network, addr)
},
}
if scheme == "wss" {
tlsConfig := &tls.Config{
ServerName: serverName,
MinVersion: tls.VersionTLS12,
}
d.TLSClient = func(conn net.Conn, hostname string) net.Conn {
return tls.Client(conn, tlsConfig)
}
}
conn, br, _, err := d.Dial(ctx, u.String())
if err != nil {
return nil, err
}
if br != nil && br.Buffered() > 0 {
pre := make([]byte, br.Buffered())
_, _ = io.ReadFull(br, pre)
conn = newPreBufferedConn(conn, pre)
}
wsConn := newWSStreamConn(conn, ws.StateClientSide)
if opts.Upgrade == nil {
return wsConn, nil
}
upgraded, err := opts.Upgrade(wsConn)
if err != nil {
_ = wsConn.Close()
return nil, err
}
if upgraded != nil {
return upgraded, nil
}
return wsConn, nil
}

View File

@@ -0,0 +1,77 @@
package httpmask
import (
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/gobwas/ws"
)
func looksLikeWebSocketUpgrade(headers map[string]string) bool {
if headers == nil {
return false
}
if !strings.EqualFold(strings.TrimSpace(headers["upgrade"]), "websocket") {
return false
}
conn := headers["connection"]
for _, part := range strings.Split(conn, ",") {
if strings.EqualFold(strings.TrimSpace(part), "upgrade") {
return true
}
}
return false
}
func (s *TunnelServer) handleWS(rawConn net.Conn, req *httpRequestHeader, headerBytes []byte, buffered []byte) (HandleResult, net.Conn, error) {
rejectOrReply := func(code int, body string) (HandleResult, net.Conn, error) {
if s.passThroughOnReject {
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
prefix = append(prefix, headerBytes...)
prefix = append(prefix, buffered...)
return HandlePassThrough, newRejectedPreBufferedConn(rawConn, prefix), nil
}
_ = writeSimpleHTTPResponse(rawConn, code, body)
_ = rawConn.Close()
return HandleDone, nil, nil
}
u, err := url.ParseRequestURI(req.target)
if err != nil {
return rejectOrReply(http.StatusBadRequest, "bad request")
}
path, ok := stripPathRoot(s.pathRoot, u.Path)
if !ok || path != "/ws" {
return rejectOrReply(http.StatusNotFound, "not found")
}
if strings.ToUpper(strings.TrimSpace(req.method)) != http.MethodGet {
return rejectOrReply(http.StatusBadRequest, "bad request")
}
if !looksLikeWebSocketUpgrade(req.headers) {
return rejectOrReply(http.StatusBadRequest, "bad request")
}
authVal := req.headers["authorization"]
if authVal == "" {
authVal = u.Query().Get(tunnelAuthQueryKey)
}
if !s.auth.verifyValue(authVal, TunnelModeWS, req.method, path, time.Now()) {
return rejectOrReply(http.StatusNotFound, "not found")
}
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
prefix = append(prefix, headerBytes...)
prefix = append(prefix, buffered...)
wsConnRaw := newPreBufferedConn(rawConn, prefix)
if _, err := ws.Upgrade(wsConnRaw); err != nil {
_ = rawConn.Close()
return HandleDone, nil, nil
}
return HandleStartTunnel, newWSStreamConn(wsConnRaw, ws.StateServerSide), nil
}

View File

@@ -0,0 +1,78 @@
package httpmask
import (
"errors"
"fmt"
"io"
"net"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
)
type wsStreamConn struct {
net.Conn
state ws.State
reader *wsutil.Reader
controlHandler wsutil.FrameHandlerFunc
}
func newWSStreamConn(conn net.Conn, state ws.State) net.Conn {
controlHandler := wsutil.ControlFrameHandler(conn, state)
return &wsStreamConn{
Conn: conn,
state: state,
reader: &wsutil.Reader{
Source: conn,
State: state,
},
controlHandler: controlHandler,
}
}
func (c *wsStreamConn) Read(b []byte) (n int, err error) {
defer func() {
if v := recover(); v != nil {
err = fmt.Errorf("websocket error: %v", v)
}
}()
for {
n, err = c.reader.Read(b)
if errors.Is(err, io.EOF) {
err = nil
}
if !errors.Is(err, wsutil.ErrNoFrameAdvance) {
return n, err
}
hdr, err2 := c.reader.NextFrame()
if err2 != nil {
return 0, err2
}
if hdr.OpCode.IsControl() {
if err := c.controlHandler(hdr, c.reader); err != nil {
return 0, err
}
continue
}
if hdr.OpCode&(ws.OpBinary|ws.OpText) == 0 {
if err := c.reader.Discard(); err != nil {
return 0, err
}
continue
}
}
}
func (c *wsStreamConn) Write(b []byte) (int, error) {
if err := wsutil.WriteMessage(c.Conn, c.state, ws.OpBinary, b); err != nil {
return 0, err
}
return len(b), nil
}
func (c *wsStreamConn) Close() error {
_ = wsutil.WriteMessage(c.Conn, c.state, ws.OpClose, ws.NewCloseFrameBody(ws.StatusNormalClosure, ""))
return c.Conn.Close()
}

View File

@@ -4,9 +4,7 @@ import (
"crypto/sha256"
"encoding/binary"
"errors"
"log"
"math/rand"
"time"
)
var (
@@ -26,7 +24,7 @@ type Table struct {
func NewTable(key string, mode string) *Table {
t, err := NewTableWithCustom(key, mode, "")
if err != nil {
log.Panicf("failed to build table: %v", err)
panic(err)
}
return t
}
@@ -35,8 +33,6 @@ func NewTable(key string, mode string) *Table {
// mode: "prefer_ascii" or "prefer_entropy". If a custom pattern is provided, ASCII mode still takes precedence.
// The customPattern must contain 8 characters with exactly 2 x, 2 p, and 4 v (case-insensitive).
func NewTableWithCustom(key string, mode string, customPattern string) (*Table, error) {
start := time.Now()
layout, err := resolveLayout(mode, customPattern)
if err != nil {
return nil, err
@@ -126,7 +122,6 @@ func NewTableWithCustom(key string, mode string, customPattern string) (*Table,
}
}
}
log.Printf("[Init] Sudoku Tables initialized (%s) in %v", layout.name, time.Since(start))
return t, nil
}

View File

@@ -0,0 +1,74 @@
package sudoku
import (
"sync"
"time"
)
var handshakeReplayTTL = 60 * time.Second
type nonceSet struct {
mu sync.Mutex
m map[[kipHelloNonceSize]byte]time.Time
maxEntries int
lastPrune time.Time
}
func newNonceSet(maxEntries int) *nonceSet {
if maxEntries <= 0 {
maxEntries = 4096
}
return &nonceSet{
m: make(map[[kipHelloNonceSize]byte]time.Time),
maxEntries: maxEntries,
}
}
func (s *nonceSet) allow(nonce [kipHelloNonceSize]byte, now time.Time, ttl time.Duration) bool {
s.mu.Lock()
defer s.mu.Unlock()
if ttl <= 0 {
ttl = 60 * time.Second
}
if now.Sub(s.lastPrune) > ttl/2 || len(s.m) > s.maxEntries {
for k, exp := range s.m {
if !now.Before(exp) {
delete(s.m, k)
}
}
s.lastPrune = now
for len(s.m) > s.maxEntries {
for k := range s.m {
delete(s.m, k)
break
}
}
}
if exp, ok := s.m[nonce]; ok && now.Before(exp) {
return false
}
s.m[nonce] = now.Add(ttl)
return true
}
type handshakeReplayProtector struct {
users sync.Map // map[userHash string]*nonceSet
}
func (p *handshakeReplayProtector) allow(userHash string, nonce [kipHelloNonceSize]byte, now time.Time) bool {
if userHash == "" {
userHash = "_"
}
val, _ := p.users.LoadOrStore(userHash, newNonceSet(4096))
set, ok := val.(*nonceSet)
if !ok || set == nil {
set = newNonceSet(4096)
p.users.Store(userHash, set)
}
return set.allow(nonce, now, handshakeReplayTTL)
}
var globalHandshakeReplay = &handshakeReplayProtector{}

View File

@@ -0,0 +1,58 @@
package sudoku
import (
"crypto/ecdh"
"crypto/sha256"
"fmt"
"io"
"golang.org/x/crypto/hkdf"
)
func derivePSKDirectionalBases(psk string) (c2s, s2c []byte) {
sum := sha256.Sum256([]byte(psk))
c2sKey := make([]byte, 32)
s2cKey := make([]byte, 32)
if _, err := io.ReadFull(hkdf.Expand(sha256.New, sum[:], []byte("sudoku-psk-c2s")), c2sKey); err != nil {
panic("sudoku: hkdf expand failed")
}
if _, err := io.ReadFull(hkdf.Expand(sha256.New, sum[:], []byte("sudoku-psk-s2c")), s2cKey); err != nil {
panic("sudoku: hkdf expand failed")
}
return c2sKey, s2cKey
}
func deriveSessionDirectionalBases(psk string, shared []byte, nonce [kipHelloNonceSize]byte) (c2s, s2c []byte, err error) {
sum := sha256.Sum256([]byte(psk))
ikm := make([]byte, 0, len(shared)+len(nonce))
ikm = append(ikm, shared...)
ikm = append(ikm, nonce[:]...)
prk := hkdf.Extract(sha256.New, ikm, sum[:])
c2sKey := make([]byte, 32)
s2cKey := make([]byte, 32)
if _, err := io.ReadFull(hkdf.Expand(sha256.New, prk, []byte("sudoku-session-c2s")), c2sKey); err != nil {
return nil, nil, fmt.Errorf("hkdf expand c2s: %w", err)
}
if _, err := io.ReadFull(hkdf.Expand(sha256.New, prk, []byte("sudoku-session-s2c")), s2cKey); err != nil {
return nil, nil, fmt.Errorf("hkdf expand s2c: %w", err)
}
return c2sKey, s2cKey, nil
}
func x25519SharedSecret(priv *ecdh.PrivateKey, peerPub []byte) ([]byte, error) {
if priv == nil {
return nil, fmt.Errorf("nil priv")
}
curve := ecdh.X25519()
pk, err := curve.NewPublicKey(peerPub)
if err != nil {
return nil, fmt.Errorf("parse peer pub: %w", err)
}
secret, err := priv.ECDH(pk)
if err != nil {
return nil, fmt.Errorf("ecdh: %w", err)
}
return secret, nil
}

View File

@@ -4,7 +4,6 @@ import (
"bufio"
"bytes"
crand "crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
@@ -56,26 +55,27 @@ func drainBuffered(r *bufio.Reader) ([]byte, error) {
func probeHandshakeBytes(probe []byte, cfg *ProtocolConfig, table *sudoku.Table) error {
rc := &readOnlyConn{Reader: bytes.NewReader(probe)}
_, obfsConn := buildServerObfsConn(rc, cfg, table, false)
cConn, err := crypto.NewAEADConn(obfsConn, cfg.Key, cfg.AEADMethod)
seed := ServerAEADSeed(cfg.Key)
pskC2S, pskS2C := derivePSKDirectionalBases(seed)
// Server side: recv is client->server, send is server->client.
cConn, err := crypto.NewRecordConn(obfsConn, cfg.AEADMethod, pskS2C, pskC2S)
if err != nil {
return err
}
var handshakeBuf [16]byte
if _, err := io.ReadFull(cConn, handshakeBuf[:]); err != nil {
msg, err := ReadKIPMessage(cConn)
if err != nil {
return err
}
ts := int64(binary.BigEndian.Uint64(handshakeBuf[:8]))
if absInt64(time.Now().Unix()-ts) > 60 {
return fmt.Errorf("timestamp skew/replay detected")
if msg.Type != KIPTypeClientHello {
return fmt.Errorf("unexpected handshake message: %d", msg.Type)
}
modeBuf := []byte{0}
if _, err := io.ReadFull(cConn, modeBuf); err != nil {
ch, err := DecodeKIPClientHelloPayload(msg.Payload)
if err != nil {
return err
}
if modeBuf[0] != downlinkMode(cfg) {
return fmt.Errorf("downlink mode mismatch")
if absInt64(time.Now().Unix()-ch.Timestamp.Unix()) > int64(kipHandshakeSkew.Seconds()) {
return fmt.Errorf("time skew/replay")
}
return nil
@@ -93,6 +93,17 @@ func selectTableByProbe(r *bufio.Reader, cfg *ProtocolConfig, tables []*sudoku.T
return nil, nil, fmt.Errorf("too many table candidates: %d", len(tables))
}
// Copy so we can prune candidates without mutating the caller slice.
candidates := make([]*sudoku.Table, 0, len(tables))
for _, t := range tables {
if t != nil {
candidates = append(candidates, t)
}
}
if len(candidates) == 0 {
return nil, nil, fmt.Errorf("no table candidates")
}
probe, err := drainBuffered(r)
if err != nil {
return nil, nil, fmt.Errorf("drain buffered bytes failed: %w", err)
@@ -100,17 +111,18 @@ func selectTableByProbe(r *bufio.Reader, cfg *ProtocolConfig, tables []*sudoku.T
tmp := make([]byte, readChunk)
for {
if len(tables) == 1 {
if len(candidates) == 1 {
tail, err := drainBuffered(r)
if err != nil {
return nil, nil, fmt.Errorf("drain buffered bytes failed: %w", err)
}
probe = append(probe, tail...)
return tables[0], probe, nil
return candidates[0], probe, nil
}
needMore := false
for _, table := range tables {
next := candidates[:0]
for _, table := range candidates {
err := probeHandshakeBytes(probe, cfg, table)
if err == nil {
tail, err := drainBuffered(r)
@@ -122,10 +134,13 @@ func selectTableByProbe(r *bufio.Reader, cfg *ProtocolConfig, tables []*sudoku.T
}
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
needMore = true
next = append(next, table)
}
// Definitive mismatch: drop table.
}
candidates = next
if !needMore {
if len(candidates) == 0 || !needMore {
return nil, probe, fmt.Errorf("handshake table selection failed")
}
if len(probe) >= maxProbeBytes {

View File

@@ -6,9 +6,7 @@ import (
"github.com/metacubex/mihomo/transport/sudoku/obfs/sudoku"
)
// NewTablesWithCustomPatterns builds one or more obfuscation tables from x/v/p custom patterns.
// When customTables is non-empty it overrides customTable (matching upstream Sudoku behavior).
func NewTablesWithCustomPatterns(key string, tableType string, customTable string, customTables []string) ([]*sudoku.Table, error) {
func normalizeCustomPatterns(customTable string, customTables []string) []string {
patterns := customTables
if len(patterns) == 0 && strings.TrimSpace(customTable) != "" {
patterns = []string{customTable}
@@ -16,7 +14,15 @@ func NewTablesWithCustomPatterns(key string, tableType string, customTable strin
if len(patterns) == 0 {
patterns = []string{""}
}
return patterns
}
// NewTablesWithCustomPatterns builds one or more obfuscation tables from x/v/p custom patterns.
// When customTables is non-empty it overrides customTable (matching upstream Sudoku behavior).
//
// Deprecated-ish: prefer NewClientTablesWithCustomPatterns / NewServerTablesWithCustomPatterns.
func NewTablesWithCustomPatterns(key string, tableType string, customTable string, customTables []string) ([]*sudoku.Table, error) {
patterns := normalizeCustomPatterns(customTable, customTables)
tables := make([]*sudoku.Table, 0, len(patterns))
for _, pattern := range patterns {
pattern = strings.TrimSpace(pattern)
@@ -28,3 +34,17 @@ func NewTablesWithCustomPatterns(key string, tableType string, customTable strin
}
return tables, nil
}
func NewClientTablesWithCustomPatterns(key string, tableType string, customTable string, customTables []string) ([]*sudoku.Table, error) {
return NewTablesWithCustomPatterns(key, tableType, customTable, customTables)
}
// NewServerTablesWithCustomPatterns matches upstream server behavior: when custom table rotation is enabled,
// also accept the default table to avoid forcing clients to update in lockstep.
func NewServerTablesWithCustomPatterns(key string, tableType string, customTable string, customTables []string) ([]*sudoku.Table, error) {
patterns := normalizeCustomPatterns(customTable, customTables)
if len(patterns) > 0 && strings.TrimSpace(patterns[0]) != "" {
patterns = append([]string{""}, patterns...)
}
return NewTablesWithCustomPatterns(key, tableType, "", patterns)
}

View File

@@ -16,17 +16,9 @@ import (
)
const (
UoTMagicByte byte = 0xEE
uotVersion = 0x01
maxUoTPayload = 64 * 1024
maxUoTPayload = 64 * 1024
)
// WritePreface writes the UDP-over-TCP marker and version.
func WritePreface(w io.Writer) error {
_, err := w.Write([]byte{UoTMagicByte, uotVersion})
return err
}
// WriteDatagram sends a single UDP datagram frame over a reliable stream.
func WriteDatagram(w io.Writer, addr string, payload []byte) error {
addrBuf, err := EncodeAddress(addr)
@@ -45,14 +37,13 @@ func WriteDatagram(w io.Writer, addr string, payload []byte) error {
binary.BigEndian.PutUint16(header[:2], uint16(len(addrBuf)))
binary.BigEndian.PutUint16(header[2:], uint16(len(payload)))
if _, err := w.Write(header[:]); err != nil {
if err := writeFull(w, header[:]); err != nil {
return err
}
if _, err := w.Write(addrBuf); err != nil {
if err := writeFull(w, addrBuf); err != nil {
return err
}
_, err = w.Write(payload)
return err
return writeFull(w, payload)
}
// ReadDatagram parses a single UDP datagram frame from the reliable stream.