feat: add path-root for sudoku (#2511)

This commit is contained in:
saba-futai
2026-01-14 21:25:05 +08:00
committed by GitHub
parent f38fc2020f
commit 06f5fbac06
17 changed files with 660 additions and 187 deletions

View File

@@ -43,6 +43,7 @@ type SudokuOption struct {
HTTPMaskMode string `proxy:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto" HTTPMaskMode string `proxy:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto"
HTTPMaskTLS bool `proxy:"http-mask-tls,omitempty"` // only for http-mask-mode stream/poll/auto HTTPMaskTLS bool `proxy:"http-mask-tls,omitempty"` // only for http-mask-mode stream/poll/auto
HTTPMaskHost string `proxy:"http-mask-host,omitempty"` // optional Host/SNI override (domain or domain:port) HTTPMaskHost string `proxy:"http-mask-host,omitempty"` // optional Host/SNI override (domain or domain:port)
PathRoot string `proxy:"path-root,omitempty"` // optional first-level path prefix for HTTP tunnel endpoints
HTTPMaskMultiplex string `proxy:"http-mask-multiplex,omitempty"` // "off" (default), "auto" (reuse h1/h2), "on" (single tunnel, multi-target) HTTPMaskMultiplex string `proxy:"http-mask-multiplex,omitempty"` // "off" (default), "auto" (reuse h1/h2), "on" (single tunnel, multi-target)
CustomTable string `proxy:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv CustomTable string `proxy:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv
CustomTables []string `proxy:"custom-tables,omitempty"` // optional table rotation patterns, overrides custom-table when non-empty CustomTables []string `proxy:"custom-tables,omitempty"` // optional table rotation patterns, overrides custom-table when non-empty
@@ -183,6 +184,7 @@ func NewSudoku(option SudokuOption) (*Sudoku, error) {
HTTPMaskMode: defaultConf.HTTPMaskMode, HTTPMaskMode: defaultConf.HTTPMaskMode,
HTTPMaskTLSEnabled: option.HTTPMaskTLS, HTTPMaskTLSEnabled: option.HTTPMaskTLS,
HTTPMaskHost: option.HTTPMaskHost, HTTPMaskHost: option.HTTPMaskHost,
HTTPMaskPathRoot: strings.TrimSpace(option.PathRoot),
HTTPMaskMultiplex: defaultConf.HTTPMaskMultiplex, HTTPMaskMultiplex: defaultConf.HTTPMaskMultiplex,
} }
if option.HTTPMaskMode != "" { if option.HTTPMaskMode != "" {
@@ -257,7 +259,19 @@ func (s *Sudoku) dialAndHandshake(ctx context.Context, cfg *sudoku.ProtocolConfi
return nil, fmt.Errorf("config is required") return nil, fmt.Errorf("config is required")
} }
var c net.Conn handshakeCfg := *cfg
if !handshakeCfg.DisableHTTPMask && httpTunnelModeEnabled(handshakeCfg.HTTPMaskMode) {
handshakeCfg.DisableHTTPMask = true
}
upgrade := func(raw net.Conn) (net.Conn, error) {
return sudoku.ClientHandshakeWithOptions(raw, &handshakeCfg, sudoku.ClientHandshakeOptions{})
}
var (
c net.Conn
handshakeDone bool
)
if !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) { if !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) {
muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex) muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex)
switch muxMode { switch muxMode {
@@ -266,9 +280,12 @@ func (s *Sudoku) dialAndHandshake(ctx context.Context, cfg *sudoku.ProtocolConfi
if errX != nil { if errX != nil {
return nil, errX return nil, errX
} }
c, err = client.Dial(ctx) c, err = client.Dial(ctx, upgrade)
default: default:
c, err = sudoku.DialHTTPMaskTunnel(ctx, cfg.ServerAddress, cfg, s.dialer.DialContext) c, err = sudoku.DialHTTPMaskTunnel(ctx, cfg.ServerAddress, cfg, s.dialer.DialContext, upgrade)
}
if err == nil && c != nil {
handshakeDone = true
} }
} }
if c == nil && err == nil { if c == nil && err == nil {
@@ -285,14 +302,11 @@ func (s *Sudoku) dialAndHandshake(ctx context.Context, cfg *sudoku.ProtocolConfi
defer done(&err) defer done(&err)
} }
handshakeCfg := *cfg if !handshakeDone {
if !handshakeCfg.DisableHTTPMask && httpTunnelModeEnabled(handshakeCfg.HTTPMaskMode) { c, err = sudoku.ClientHandshakeWithOptions(c, &handshakeCfg, sudoku.ClientHandshakeOptions{})
handshakeCfg.DisableHTTPMask = true if err != nil {
} return nil, err
}
c, err = sudoku.ClientHandshakeWithOptions(c, &handshakeCfg, sudoku.ClientHandshakeOptions{})
if err != nil {
return nil, err
} }
return c, nil return c, nil

View File

@@ -1072,6 +1072,7 @@ proxies: # socks5
# http-mask-mode: legacy # 可选legacy默认、stream、poll、autostream/poll/auto 支持走 CDN/反代 # http-mask-mode: legacy # 可选legacy默认、stream、poll、autostream/poll/auto 支持走 CDN/反代
# http-mask-tls: true # 可选:仅在 http-mask-mode 为 stream/poll/auto 时生效true 强制 httpsfalse 强制 http不会根据端口自动推断 # http-mask-tls: true # 可选:仅在 http-mask-mode 为 stream/poll/auto 时生效true 强制 httpsfalse 强制 http不会根据端口自动推断
# http-mask-host: "" # 可选:覆盖 Host/SNI支持 example.com 或 example.com:443仅在 http-mask-mode 为 stream/poll/auto 时生效 # http-mask-host: "" # 可选:覆盖 Host/SNI支持 example.com 或 example.com:443仅在 http-mask-mode 为 stream/poll/auto 时生效
# path-root: "" # 可选HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload
# http-mask-multiplex: off # 可选off默认、auto复用 h1.1 keep-alive / h2 连接,减少每次建链 RTT、on单条隧道内多路复用多个目标连接仅在 http-mask-mode=stream/poll/auto 生效) # http-mask-multiplex: off # 可选off默认、auto复用 h1.1 keep-alive / h2 连接,减少每次建链 RTT、on单条隧道内多路复用多个目标连接仅在 http-mask-mode=stream/poll/auto 生效)
enable-pure-downlink: false # 是否启用混淆下行false的情况下能在保证数据安全的前提下极大提升下行速度与服务端端保持相同(如果此处为false则要求aead不可为none) enable-pure-downlink: false # 是否启用混淆下行false的情况下能在保证数据安全的前提下极大提升下行速度与服务端端保持相同(如果此处为false则要求aead不可为none)
@@ -1621,6 +1622,7 @@ listeners:
enable-pure-downlink: false # 是否启用混淆下行false的情况下能在保证数据安全的前提下极大提升下行速度与客户端保持相同(如果此处为false则要求aead不可为none) enable-pure-downlink: false # 是否启用混淆下行false的情况下能在保证数据安全的前提下极大提升下行速度与客户端保持相同(如果此处为false则要求aead不可为none)
disable-http-mask: false # 可选:禁用 http 掩码/隧道(默认为 false disable-http-mask: false # 可选:禁用 http 掩码/隧道(默认为 false
# http-mask-mode: legacy # 可选legacy默认、stream、poll、autostream/poll/auto 支持走 CDN/反代 # http-mask-mode: legacy # 可选legacy默认、stream、poll、autostream/poll/auto 支持走 CDN/反代
# path-root: "" # 可选HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload

View File

@@ -22,6 +22,7 @@ type SudokuServer struct {
CustomTables []string `json:"custom-tables,omitempty"` CustomTables []string `json:"custom-tables,omitempty"`
DisableHTTPMask bool `json:"disable-http-mask,omitempty"` DisableHTTPMask bool `json:"disable-http-mask,omitempty"`
HTTPMaskMode string `json:"http-mask-mode,omitempty"` HTTPMaskMode string `json:"http-mask-mode,omitempty"`
PathRoot string `json:"path-root,omitempty"`
// mihomo private extension (not the part of standard Sudoku protocol) // mihomo private extension (not the part of standard Sudoku protocol)
MuxOption sing.MuxOption `json:"mux-option,omitempty"` MuxOption sing.MuxOption `json:"mux-option,omitempty"`

View File

@@ -24,6 +24,7 @@ type SudokuOption struct {
CustomTables []string `inbound:"custom-tables,omitempty"` CustomTables []string `inbound:"custom-tables,omitempty"`
DisableHTTPMask bool `inbound:"disable-http-mask,omitempty"` DisableHTTPMask bool `inbound:"disable-http-mask,omitempty"`
HTTPMaskMode string `inbound:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto" HTTPMaskMode string `inbound:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto"
PathRoot string `inbound:"path-root,omitempty"` // optional first-level path prefix for HTTP tunnel endpoints
// mihomo private extension (not the part of standard Sudoku protocol) // mihomo private extension (not the part of standard Sudoku protocol)
MuxOption MuxOption `inbound:"mux-option,omitempty"` MuxOption MuxOption `inbound:"mux-option,omitempty"`
@@ -63,6 +64,7 @@ func NewSudoku(options *SudokuOption) (*Sudoku, error) {
CustomTables: options.CustomTables, CustomTables: options.CustomTables,
DisableHTTPMask: options.DisableHTTPMask, DisableHTTPMask: options.DisableHTTPMask,
HTTPMaskMode: options.HTTPMaskMode, HTTPMaskMode: options.HTTPMaskMode,
PathRoot: strings.TrimSpace(options.PathRoot),
} }
serverConf.MuxOption = options.MuxOption.Build() serverConf.MuxOption = options.MuxOption.Build()

View File

@@ -229,6 +229,7 @@ func New(config LC.SudokuServer, tunnel C.Tunnel, additions ...inbound.Addition)
HandshakeTimeoutSeconds: handshakeTimeout, HandshakeTimeoutSeconds: handshakeTimeout,
DisableHTTPMask: config.DisableHTTPMask, DisableHTTPMask: config.DisableHTTPMask,
HTTPMaskMode: config.HTTPMaskMode, HTTPMaskMode: config.HTTPMaskMode,
HTTPMaskPathRoot: strings.TrimSpace(config.PathRoot),
} }
if len(tables) == 1 { if len(tables) == 1 {
protoConf.Table = tables[0] protoConf.Table = tables[0]

View File

@@ -58,6 +58,10 @@ type ProtocolConfig struct {
// HTTPMaskHost optionally overrides the HTTP Host header / SNI host for HTTP tunnel modes (client-side). // HTTPMaskHost optionally overrides the HTTP Host header / SNI host for HTTP tunnel modes (client-side).
HTTPMaskHost string HTTPMaskHost string
// HTTPMaskPathRoot optionally prefixes all HTTP mask paths with a first-level segment.
// Example: "aabbcc" => "/aabbcc/session", "/aabbcc/api/v1/upload", ...
HTTPMaskPathRoot string
// HTTPMaskMultiplex controls multiplex behavior when HTTPMask tunnel modes are enabled: // HTTPMaskMultiplex controls multiplex behavior when HTTPMask tunnel modes are enabled:
// - "off": disable reuse; each Dial establishes its own HTTPMask tunnel // - "off": disable reuse; each Dial establishes its own HTTPMask tunnel
// - "auto": reuse underlying HTTP connections across multiple tunnel dials (HTTP/1.1 keep-alive / HTTP/2) // - "auto": reuse underlying HTTP connections across multiple tunnel dials (HTTP/1.1 keep-alive / HTTP/2)
@@ -109,6 +113,23 @@ func (c *ProtocolConfig) Validate() error {
return fmt.Errorf("invalid http-mask-mode: %s, must be one of: legacy, stream, poll, auto", c.HTTPMaskMode) return fmt.Errorf("invalid http-mask-mode: %s, must be one of: legacy, stream, poll, auto", c.HTTPMaskMode)
} }
if v := strings.TrimSpace(c.HTTPMaskPathRoot); v != "" {
if strings.Contains(v, "/") {
return fmt.Errorf("invalid http-mask-path-root: must be a single path segment")
}
for i := 0; i < len(v); i++ {
ch := v[i]
switch {
case ch >= 'a' && ch <= 'z':
case ch >= 'A' && ch <= 'Z':
case ch >= '0' && ch <= '9':
case ch == '_' || ch == '-':
default:
return fmt.Errorf("invalid http-mask-path-root: contains invalid character %q", ch)
}
}
}
switch strings.ToLower(strings.TrimSpace(c.HTTPMaskMultiplex)) { switch strings.ToLower(strings.TrimSpace(c.HTTPMaskMultiplex)) {
case "", "off", "auto", "on": case "", "off", "auto", "on":
default: default:

View File

@@ -2,7 +2,6 @@ package sudoku
import ( import (
"bufio" "bufio"
"crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/binary" "encoding/binary"
"encoding/hex" "encoding/hex"
@@ -153,14 +152,17 @@ func buildServerObfsConn(raw net.Conn, cfg *ProtocolConfig, table *sudoku.Table,
func buildHandshakePayload(key string) [16]byte { func buildHandshakePayload(key string) [16]byte {
var payload [16]byte var payload [16]byte
binary.BigEndian.PutUint64(payload[:8], uint64(time.Now().Unix())) binary.BigEndian.PutUint64(payload[:8], uint64(time.Now().Unix()))
// Hash the decoded HEX bytes of the key, not the HEX string itself.
// This ensures the user hash is computed on the actual key bytes. // Align with upstream: only decode hex bytes when this key is an ED25519 key material.
keyBytes, err := hex.DecodeString(key) // For plain UUID/strings (even if they look like hex), hash the string bytes as-is.
if err != nil { src := []byte(key)
// Fallback: if key is not valid HEX (e.g., a UUID or plain string), hash the string bytes if _, err := crypto.RecoverPublicKey(key); err == nil {
keyBytes = []byte(key) if keyBytes, decErr := hex.DecodeString(key); decErr == nil && len(keyBytes) > 0 {
src = keyBytes
}
} }
hash := sha256.Sum256(keyBytes)
hash := sha256.Sum256(src)
copy(payload[8:], hash[:8]) copy(payload[8:], hash[:8])
return payload return payload
} }
@@ -211,12 +213,12 @@ func ClientHandshakeWithOptions(rawConn net.Conn, cfg *ProtocolConfig, opt Clien
} }
if !cfg.DisableHTTPMask { if !cfg.DisableHTTPMask {
if err := WriteHTTPMaskHeader(rawConn, cfg.ServerAddress, opt.HTTPMaskStrategy); err != nil { if err := WriteHTTPMaskHeader(rawConn, cfg.ServerAddress, cfg.HTTPMaskPathRoot, opt.HTTPMaskStrategy); err != nil {
return nil, fmt.Errorf("write http mask failed: %w", err) return nil, fmt.Errorf("write http mask failed: %w", err)
} }
} }
table, tableID, err := pickClientTable(cfg) table, err := pickClientTable(cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -228,9 +230,6 @@ func ClientHandshakeWithOptions(rawConn net.Conn, cfg *ProtocolConfig, opt Clien
} }
handshake := buildHandshakePayload(cfg.Key) handshake := buildHandshakePayload(cfg.Key)
if len(cfg.tableCandidates()) > 1 {
handshake[8] = tableID
}
if _, err := cConn.Write(handshake[:]); err != nil { if _, err := cConn.Write(handshake[:]); err != nil {
cConn.Close() cConn.Close()
return nil, fmt.Errorf("send handshake failed: %w", err) return nil, fmt.Errorf("send handshake failed: %w", err)
@@ -376,19 +375,9 @@ func normalizeHTTPMaskStrategy(strategy string) string {
} }
} }
// randomByte returns a cryptographically random byte (with a math/rand fallback).
func randomByte() byte {
var b [1]byte
if _, err := rand.Read(b[:]); err == nil {
return b[0]
}
return byte(time.Now().UnixNano())
}
func userHashFromHandshake(handshakeBuf []byte) string { func userHashFromHandshake(handshakeBuf []byte) string {
if len(handshakeBuf) < 16 { if len(handshakeBuf) < 16 {
return "" return ""
} }
// handshake[8] may be a table ID when table rotation is enabled; use [9:16] as stable user hash bytes. return hex.EncodeToString(handshakeBuf[8:16])
return hex.EncodeToString(handshakeBuf[9:16])
} }

View File

@@ -7,6 +7,7 @@ import (
"math/rand" "math/rand"
"net" "net"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
@@ -92,24 +93,24 @@ func appendCommonHeaders(buf []byte, host string, r *rand.Rand) []byte {
// WriteHTTPMaskHeader writes an HTTP/1.x request header as a mask, according to strategy. // WriteHTTPMaskHeader writes an HTTP/1.x request header as a mask, according to strategy.
// Supported strategies: ""/"random", "post", "websocket". // Supported strategies: ""/"random", "post", "websocket".
func WriteHTTPMaskHeader(w io.Writer, host string, strategy string) error { func WriteHTTPMaskHeader(w io.Writer, host string, pathRoot string, strategy string) error {
switch normalizeHTTPMaskStrategy(strategy) { switch normalizeHTTPMaskStrategy(strategy) {
case "random": case "random":
return httpmask.WriteRandomRequestHeader(w, host) return httpmask.WriteRandomRequestHeaderWithPathRoot(w, host, pathRoot)
case "post": case "post":
return writeHTTPMaskPOST(w, host) return writeHTTPMaskPOST(w, host, pathRoot)
case "websocket": case "websocket":
return writeHTTPMaskWebSocket(w, host) return writeHTTPMaskWebSocket(w, host, pathRoot)
default: default:
return fmt.Errorf("unsupported http-mask-strategy: %s", strategy) return fmt.Errorf("unsupported http-mask-strategy: %s", strategy)
} }
} }
func writeHTTPMaskPOST(w io.Writer, host string) error { func writeHTTPMaskPOST(w io.Writer, host string, pathRoot string) error {
r := httpMaskRngPool.Get().(*rand.Rand) r := httpMaskRngPool.Get().(*rand.Rand)
defer httpMaskRngPool.Put(r) defer httpMaskRngPool.Put(r)
path := httpMaskPaths[r.Intn(len(httpMaskPaths))] path := joinPathRoot(pathRoot, httpMaskPaths[r.Intn(len(httpMaskPaths))])
ctype := httpMaskContentTypes[r.Intn(len(httpMaskContentTypes))] ctype := httpMaskContentTypes[r.Intn(len(httpMaskContentTypes))]
bufPtr := httpMaskBufPool.Get().(*[]byte) bufPtr := httpMaskBufPool.Get().(*[]byte)
@@ -140,11 +141,11 @@ func writeHTTPMaskPOST(w io.Writer, host string) error {
return err return err
} }
func writeHTTPMaskWebSocket(w io.Writer, host string) error { func writeHTTPMaskWebSocket(w io.Writer, host string, pathRoot string) error {
r := httpMaskRngPool.Get().(*rand.Rand) r := httpMaskRngPool.Get().(*rand.Rand)
defer httpMaskRngPool.Put(r) defer httpMaskRngPool.Put(r)
path := httpMaskPaths[r.Intn(len(httpMaskPaths))] path := joinPathRoot(pathRoot, httpMaskPaths[r.Intn(len(httpMaskPaths))])
bufPtr := httpMaskBufPool.Get().(*[]byte) bufPtr := httpMaskBufPool.Get().(*[]byte)
buf := *bufPtr buf := *bufPtr
@@ -177,3 +178,37 @@ func writeHTTPMaskWebSocket(w io.Writer, host string) error {
_, err := w.Write(buf) _, err := w.Write(buf)
return err return err
} }
func normalizePathRoot(root string) string {
root = strings.TrimSpace(root)
root = strings.Trim(root, "/")
if root == "" {
return ""
}
for i := 0; i < len(root); i++ {
c := root[i]
switch {
case c >= 'a' && c <= 'z':
case c >= 'A' && c <= 'Z':
case c >= '0' && c <= '9':
case c == '_' || c == '-':
default:
return ""
}
}
return "/" + root
}
func joinPathRoot(root, path string) string {
root = normalizePathRoot(root)
if root == "" {
return path
}
if path == "" {
return root
}
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
return root + path
}

View File

@@ -23,7 +23,11 @@ func NewHTTPMaskTunnelServer(cfg *ProtocolConfig) *HTTPMaskTunnelServer {
if !cfg.DisableHTTPMask { if !cfg.DisableHTTPMask {
switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) { switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) {
case "stream", "poll", "auto": case "stream", "poll", "auto":
ts = httpmask.NewTunnelServer(httpmask.TunnelServerOptions{Mode: cfg.HTTPMaskMode}) ts = httpmask.NewTunnelServer(httpmask.TunnelServerOptions{
Mode: cfg.HTTPMaskMode,
PathRoot: cfg.HTTPMaskPathRoot,
AuthKey: ClientAEADSeed(cfg.Key),
})
} }
} }
return &HTTPMaskTunnelServer{cfg: cfg, ts: ts} return &HTTPMaskTunnelServer{cfg: cfg, ts: ts}
@@ -67,7 +71,7 @@ func (s *HTTPMaskTunnelServer) WrapConn(rawConn net.Conn) (handshakeConn net.Con
type TunnelDialer func(ctx context.Context, network, addr string) (net.Conn, error) type TunnelDialer func(ctx context.Context, network, addr string) (net.Conn, error)
// DialHTTPMaskTunnel dials a CDN-capable HTTP tunnel (stream/poll/auto) and returns a stream carrying raw Sudoku bytes. // DialHTTPMaskTunnel dials a CDN-capable HTTP tunnel (stream/poll/auto) and returns a stream carrying raw Sudoku bytes.
func DialHTTPMaskTunnel(ctx context.Context, serverAddress string, cfg *ProtocolConfig, dial TunnelDialer) (net.Conn, error) { func DialHTTPMaskTunnel(ctx context.Context, serverAddress string, cfg *ProtocolConfig, dial TunnelDialer, upgrade func(net.Conn) (net.Conn, error)) (net.Conn, error) {
if cfg == nil { if cfg == nil {
return nil, fmt.Errorf("config is required") return nil, fmt.Errorf("config is required")
} }
@@ -83,14 +87,19 @@ func DialHTTPMaskTunnel(ctx context.Context, serverAddress string, cfg *Protocol
Mode: cfg.HTTPMaskMode, Mode: cfg.HTTPMaskMode,
TLSEnabled: cfg.HTTPMaskTLSEnabled, TLSEnabled: cfg.HTTPMaskTLSEnabled,
HostOverride: cfg.HTTPMaskHost, HostOverride: cfg.HTTPMaskHost,
PathRoot: cfg.HTTPMaskPathRoot,
AuthKey: ClientAEADSeed(cfg.Key),
Upgrade: upgrade,
Multiplex: cfg.HTTPMaskMultiplex, Multiplex: cfg.HTTPMaskMultiplex,
DialContext: dial, DialContext: dial,
}) })
} }
type HTTPMaskTunnelClient struct { type HTTPMaskTunnelClient struct {
mode string mode string
client *httpmask.TunnelClient pathRoot string
authKey string
client *httpmask.TunnelClient
} }
func NewHTTPMaskTunnelClient(serverAddress string, cfg *ProtocolConfig, dial TunnelDialer) (*HTTPMaskTunnelClient, error) { func NewHTTPMaskTunnelClient(serverAddress string, cfg *ProtocolConfig, dial TunnelDialer) (*HTTPMaskTunnelClient, error) {
@@ -121,16 +130,23 @@ func NewHTTPMaskTunnelClient(serverAddress string, cfg *ProtocolConfig, dial Tun
} }
return &HTTPMaskTunnelClient{ return &HTTPMaskTunnelClient{
mode: cfg.HTTPMaskMode, mode: cfg.HTTPMaskMode,
client: c, pathRoot: cfg.HTTPMaskPathRoot,
authKey: ClientAEADSeed(cfg.Key),
client: c,
}, nil }, nil
} }
func (c *HTTPMaskTunnelClient) Dial(ctx context.Context) (net.Conn, error) { func (c *HTTPMaskTunnelClient) Dial(ctx context.Context, upgrade func(net.Conn) (net.Conn, error)) (net.Conn, error) {
if c == nil || c.client == nil { if c == nil || c.client == nil {
return nil, fmt.Errorf("nil httpmask tunnel client") return nil, fmt.Errorf("nil httpmask tunnel client")
} }
return c.client.DialTunnel(ctx, c.mode) return c.client.DialTunnel(ctx, httpmask.TunnelDialOptions{
Mode: c.mode,
PathRoot: c.pathRoot,
AuthKey: c.authKey,
Upgrade: upgrade,
})
} }
func (c *HTTPMaskTunnelClient) CloseIdleConnections() { func (c *HTTPMaskTunnelClient) CloseIdleConnections() {

View File

@@ -154,7 +154,7 @@ func TestHTTPMaskTunnel_Stream_TCPRoundTrip(t *testing.T) {
clientCfg.ServerAddress = addr clientCfg.ServerAddress = addr
clientCfg.HTTPMaskHost = "example.com" clientCfg.HTTPMaskHost = "example.com"
tunnelConn, err := DialHTTPMaskTunnel(ctx, clientCfg.ServerAddress, &clientCfg, (&net.Dialer{}).DialContext) tunnelConn, err := DialHTTPMaskTunnel(ctx, clientCfg.ServerAddress, &clientCfg, (&net.Dialer{}).DialContext, nil)
if err != nil { if err != nil {
t.Fatalf("dial tunnel: %v", err) t.Fatalf("dial tunnel: %v", err)
} }
@@ -225,7 +225,7 @@ func TestHTTPMaskTunnel_Poll_UoTRoundTrip(t *testing.T) {
clientCfg := *serverCfg clientCfg := *serverCfg
clientCfg.ServerAddress = addr clientCfg.ServerAddress = addr
tunnelConn, err := DialHTTPMaskTunnel(ctx, clientCfg.ServerAddress, &clientCfg, (&net.Dialer{}).DialContext) tunnelConn, err := DialHTTPMaskTunnel(ctx, clientCfg.ServerAddress, &clientCfg, (&net.Dialer{}).DialContext, nil)
if err != nil { if err != nil {
t.Fatalf("dial tunnel: %v", err) t.Fatalf("dial tunnel: %v", err)
} }
@@ -287,7 +287,7 @@ func TestHTTPMaskTunnel_Auto_TCPRoundTrip(t *testing.T) {
clientCfg := *serverCfg clientCfg := *serverCfg
clientCfg.ServerAddress = addr clientCfg.ServerAddress = addr
tunnelConn, err := DialHTTPMaskTunnel(ctx, clientCfg.ServerAddress, &clientCfg, (&net.Dialer{}).DialContext) tunnelConn, err := DialHTTPMaskTunnel(ctx, clientCfg.ServerAddress, &clientCfg, (&net.Dialer{}).DialContext, nil)
if err != nil { if err != nil {
t.Fatalf("dial tunnel: %v", err) t.Fatalf("dial tunnel: %v", err)
} }
@@ -331,13 +331,13 @@ func TestHTTPMaskTunnel_Validation(t *testing.T) {
cfg.DisableHTTPMask = true cfg.DisableHTTPMask = true
cfg.HTTPMaskMode = "stream" cfg.HTTPMaskMode = "stream"
if _, err := DialHTTPMaskTunnel(context.Background(), cfg.ServerAddress, cfg, (&net.Dialer{}).DialContext); err == nil { if _, err := DialHTTPMaskTunnel(context.Background(), cfg.ServerAddress, cfg, (&net.Dialer{}).DialContext, nil); err == nil {
t.Fatalf("expected error for disabled http mask") t.Fatalf("expected error for disabled http mask")
} }
cfg.DisableHTTPMask = false cfg.DisableHTTPMask = false
cfg.HTTPMaskMode = "legacy" cfg.HTTPMaskMode = "legacy"
if _, err := DialHTTPMaskTunnel(context.Background(), cfg.ServerAddress, cfg, (&net.Dialer{}).DialContext); err == nil { if _, err := DialHTTPMaskTunnel(context.Background(), cfg.ServerAddress, cfg, (&net.Dialer{}).DialContext, nil); err == nil {
t.Fatalf("expected error for legacy mode") t.Fatalf("expected error for legacy mode")
} }
} }
@@ -385,7 +385,7 @@ func TestHTTPMaskTunnel_Soak_Concurrent(t *testing.T) {
clientCfg.ServerAddress = addr clientCfg.ServerAddress = addr
clientCfg.HTTPMaskHost = strings.TrimSpace(clientCfg.HTTPMaskHost) clientCfg.HTTPMaskHost = strings.TrimSpace(clientCfg.HTTPMaskHost)
tunnelConn, err := DialHTTPMaskTunnel(ctx, clientCfg.ServerAddress, &clientCfg, (&net.Dialer{}).DialContext) tunnelConn, err := DialHTTPMaskTunnel(ctx, clientCfg.ServerAddress, &clientCfg, (&net.Dialer{}).DialContext, nil)
if err != nil { if err != nil {
runErr <- fmt.Errorf("dial: %w", err) runErr <- fmt.Errorf("dial: %w", err)
return return

View File

@@ -99,7 +99,7 @@ func TestUserHash_StableAcrossTableRotation(t *testing.T) {
if h == "" { if h == "" {
t.Fatalf("empty user hash") t.Fatalf("empty user hash")
} }
if len(h) != 14 { if len(h) != 16 {
t.Fatalf("unexpected user hash length: %d", len(h)) t.Fatalf("unexpected user hash length: %d", len(h))
} }
unique[h] = struct{}{} unique[h] = struct{}{}
@@ -258,4 +258,3 @@ func TestMultiplex_Boundary_InvalidVersion(t *testing.T) {
t.Fatalf("expected error") t.Fatalf("expected error")
} }
} }

View File

@@ -0,0 +1,137 @@
package httpmask
import (
"crypto/hmac"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/binary"
"strings"
"time"
)
const (
tunnelAuthHeaderKey = "Authorization"
tunnelAuthHeaderPrefix = "Bearer "
)
type tunnelAuth struct {
key [32]byte // derived HMAC key
skew time.Duration
}
func newTunnelAuth(key string, skew time.Duration) *tunnelAuth {
key = strings.TrimSpace(key)
if key == "" {
return nil
}
if skew <= 0 {
skew = 60 * time.Second
}
// Domain separation: keep this HMAC key independent from other uses of cfg.Key.
h := sha256.New()
_, _ = h.Write([]byte("sudoku-httpmask-auth-v1:"))
_, _ = h.Write([]byte(key))
var sum [32]byte
h.Sum(sum[:0])
return &tunnelAuth{key: sum, skew: skew}
}
func (a *tunnelAuth) token(mode TunnelMode, method, path string, now time.Time) string {
if a == nil {
return ""
}
ts := now.Unix()
sig := a.sign(mode, method, path, ts)
var buf [8 + 16]byte
binary.BigEndian.PutUint64(buf[:8], uint64(ts))
copy(buf[8:], sig[:])
return base64.RawURLEncoding.EncodeToString(buf[:])
}
func (a *tunnelAuth) verify(headers map[string]string, mode TunnelMode, method, path string, now time.Time) bool {
if a == nil {
return true
}
if headers == nil {
return false
}
val := strings.TrimSpace(headers["authorization"])
if val == "" {
return false
}
// Accept both "Bearer <token>" and raw token forms (for forward proxies / CDNs that may normalize headers).
if len(val) > len(tunnelAuthHeaderPrefix) && strings.EqualFold(val[:len(tunnelAuthHeaderPrefix)], tunnelAuthHeaderPrefix) {
val = strings.TrimSpace(val[len(tunnelAuthHeaderPrefix):])
}
if val == "" {
return false
}
raw, err := base64.RawURLEncoding.DecodeString(val)
if err != nil || len(raw) != 8+16 {
return false
}
ts := int64(binary.BigEndian.Uint64(raw[:8]))
nowTS := now.Unix()
delta := nowTS - ts
if delta < 0 {
delta = -delta
}
if delta > int64(a.skew.Seconds()) {
return false
}
want := a.sign(mode, method, path, ts)
return subtle.ConstantTimeCompare(raw[8:], want[:]) == 1
}
func (a *tunnelAuth) sign(mode TunnelMode, method, path string, ts int64) [16]byte {
method = strings.ToUpper(strings.TrimSpace(method))
if method == "" {
method = "GET"
}
path = strings.TrimSpace(path)
var tsBuf [8]byte
binary.BigEndian.PutUint64(tsBuf[:], uint64(ts))
mac := hmac.New(sha256.New, a.key[:])
_, _ = mac.Write([]byte(mode))
_, _ = mac.Write([]byte{0})
_, _ = mac.Write([]byte(method))
_, _ = mac.Write([]byte{0})
_, _ = mac.Write([]byte(path))
_, _ = mac.Write([]byte{0})
_, _ = mac.Write(tsBuf[:])
var full [32]byte
mac.Sum(full[:0])
var out [16]byte
copy(out[:], full[:16])
return out
}
type headerSetter interface {
Set(key, value string)
}
func applyTunnelAuthHeader(h headerSetter, auth *tunnelAuth, mode TunnelMode, method, path string) {
if auth == nil || h == nil {
return
}
token := auth.token(mode, method, path, time.Now())
if token == "" {
return
}
h.Set(tunnelAuthHeaderKey, tunnelAuthHeaderPrefix+token)
}

View File

@@ -129,11 +129,17 @@ func appendCommonHeaders(buf []byte, host string, r *rand.Rand) []byte {
// WriteRandomRequestHeader writes a plausible HTTP/1.1 request header as a mask. // WriteRandomRequestHeader writes a plausible HTTP/1.1 request header as a mask.
func WriteRandomRequestHeader(w io.Writer, host string) error { func WriteRandomRequestHeader(w io.Writer, host string) error {
return WriteRandomRequestHeaderWithPathRoot(w, host, "")
}
// WriteRandomRequestHeaderWithPathRoot is like WriteRandomRequestHeader but prefixes all paths with pathRoot.
// pathRoot must be a single segment (e.g. "aabbcc"); invalid inputs are treated as empty (disabled).
func WriteRandomRequestHeaderWithPathRoot(w io.Writer, host string, pathRoot string) error {
// Get RNG from pool // Get RNG from pool
r := rngPool.Get().(*rand.Rand) r := rngPool.Get().(*rand.Rand)
defer rngPool.Put(r) defer rngPool.Put(r)
path := paths[r.Intn(len(paths))] path := joinPathRoot(pathRoot, paths[r.Intn(len(paths))])
ctype := contentTypes[r.Intn(len(contentTypes))] ctype := contentTypes[r.Intn(len(contentTypes))]
// Use buffer pool // Use buffer pool

View File

@@ -0,0 +1,52 @@
package httpmask
import "strings"
// normalizePathRoot normalizes the configured path root into "/<segment>" form.
//
// It is intentionally strict: only a single path segment is allowed, consisting of
// [A-Za-z0-9_-]. Invalid inputs are treated as empty (disabled).
func normalizePathRoot(root string) string {
root = strings.TrimSpace(root)
root = strings.Trim(root, "/")
if root == "" {
return ""
}
for i := 0; i < len(root); i++ {
c := root[i]
switch {
case c >= 'a' && c <= 'z':
case c >= 'A' && c <= 'Z':
case c >= '0' && c <= '9':
case c == '_' || c == '-':
default:
return ""
}
}
return "/" + root
}
func joinPathRoot(root, path string) string {
root = normalizePathRoot(root)
if root == "" {
return path
}
if path == "" {
return root
}
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
return root + path
}
func stripPathRoot(root, fullPath string) (string, bool) {
root = normalizePathRoot(root)
if root == "" {
return fullPath, true
}
if !strings.HasPrefix(fullPath, root+"/") {
return "", false
}
return strings.TrimPrefix(fullPath, root), true
}

View File

@@ -62,6 +62,15 @@ type TunnelDialOptions struct {
Mode string Mode string
TLSEnabled bool // when true, use HTTPS; otherwise, use HTTP (no port-based inference) TLSEnabled bool // when true, use HTTPS; otherwise, use HTTP (no port-based inference)
HostOverride string // optional Host header / SNI host (without scheme); accepts "example.com" or "example.com:443" HostOverride string // optional Host header / SNI host (without scheme); accepts "example.com" or "example.com:443"
// PathRoot is an optional first-level path prefix for all HTTP tunnel endpoints.
// Example: "aabbcc" => "/aabbcc/session", "/aabbcc/api/v1/upload", ...
PathRoot string
// AuthKey enables short-term HMAC auth for HTTP tunnel requests (anti-probing).
// When set (non-empty), each HTTP request carries an Authorization bearer token derived from AuthKey.
AuthKey string
// Upgrade optionally wraps the raw tunnel conn and/or writes a small prelude before DialTunnel returns.
// It is called with the raw tunnel conn; if it returns a non-nil conn, that conn is returned by DialTunnel.
Upgrade func(raw net.Conn) (net.Conn, error)
// Multiplex controls whether the caller should reuse underlying HTTP connections (HTTP/1.1 keep-alive / HTTP/2). // Multiplex controls whether the caller should reuse underlying HTTP connections (HTTP/1.1 keep-alive / HTTP/2).
// To reuse across multiple dials, create a TunnelClient per proxy and reuse it. // To reuse across multiple dials, create a TunnelClient per proxy and reuse it.
// Values: "off" disables reuse; "auto"/"on" enables it. // Values: "off" disables reuse; "auto"/"on" enables it.
@@ -109,34 +118,34 @@ func (c *TunnelClient) CloseIdleConnections() {
c.transport.CloseIdleConnections() c.transport.CloseIdleConnections()
} }
func (c *TunnelClient) DialTunnel(ctx context.Context, mode string) (net.Conn, error) { func (c *TunnelClient) DialTunnel(ctx context.Context, opts TunnelDialOptions) (net.Conn, error) {
if c == nil || c.client == nil { if c == nil || c.client == nil {
return nil, fmt.Errorf("nil tunnel client") return nil, fmt.Errorf("nil tunnel client")
} }
tm := normalizeTunnelMode(mode) tm := normalizeTunnelMode(opts.Mode)
if tm == TunnelModeLegacy { if tm == TunnelModeLegacy {
return nil, fmt.Errorf("legacy mode does not use http tunnel") return nil, fmt.Errorf("legacy mode does not use http tunnel")
} }
switch tm { switch tm {
case TunnelModeStream: case TunnelModeStream:
return dialStreamWithClient(ctx, c.client, c.target) return dialStreamWithClient(ctx, c.client, c.target, opts)
case TunnelModePoll: case TunnelModePoll:
return dialPollWithClient(ctx, c.client, c.target) return dialPollWithClient(ctx, c.client, c.target, opts)
case TunnelModeAuto: case TunnelModeAuto:
streamCtx, cancelX := context.WithTimeout(ctx, 3*time.Second) streamCtx, cancelX := context.WithTimeout(ctx, 3*time.Second)
c1, errX := dialStreamWithClient(streamCtx, c.client, c.target) c1, errX := dialStreamWithClient(streamCtx, c.client, c.target, opts)
cancelX() cancelX()
if errX == nil { if errX == nil {
return c1, nil return c1, nil
} }
c2, errP := dialPollWithClient(ctx, c.client, c.target) c2, errP := dialPollWithClient(ctx, c.client, c.target, opts)
if errP == nil { if errP == nil {
return c2, nil return c2, nil
} }
return nil, fmt.Errorf("auto tunnel failed: stream: %v; poll: %w", errX, errP) return nil, fmt.Errorf("auto tunnel failed: stream: %v; poll: %w", errX, errP)
default: default:
return dialStreamWithClient(ctx, c.client, c.target) return dialStreamWithClient(ctx, c.client, c.target, opts)
} }
} }
@@ -248,8 +257,13 @@ func (c *httpStreamConn) Close() error {
if c.cancel != nil { if c.cancel != nil {
c.cancel() c.cancel()
} }
_ = c.writer.CloseWithError(io.ErrClosedPipe) if c.writer != nil {
return c.reader.Close() _ = c.writer.CloseWithError(io.ErrClosedPipe)
}
if c.reader != nil {
return c.reader.Close()
}
return nil
} }
func (c *httpStreamConn) LocalAddr() net.Addr { return c.localAddr } func (c *httpStreamConn) LocalAddr() net.Addr { return c.localAddr }
@@ -320,20 +334,23 @@ type sessionDialInfo struct {
pullURL string pullURL string
closeURL string closeURL string
headerHost string headerHost string
auth *tunnelAuth
} }
func dialSessionWithClient(ctx context.Context, client *http.Client, target httpClientTarget, mode TunnelMode) (*sessionDialInfo, error) { func dialSessionWithClient(ctx context.Context, client *http.Client, target httpClientTarget, mode TunnelMode, opts TunnelDialOptions) (*sessionDialInfo, error) {
if client == nil { if client == nil {
return nil, fmt.Errorf("nil http client") return nil, fmt.Errorf("nil http client")
} }
authorizeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/session"}).String() auth := newTunnelAuth(opts.AuthKey, 0)
authorizeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/session")}).String()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, authorizeURL, nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, authorizeURL, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.Host = target.headerHost req.Host = target.headerHost
applyTunnelHeaders(req.Header, target.headerHost, mode) applyTunnelHeaders(req.Header, target.headerHost, mode)
applyTunnelAuthHeader(req.Header, auth, mode, http.MethodGet, "/session")
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
@@ -356,9 +373,9 @@ func dialSessionWithClient(ctx context.Context, client *http.Client, target http
return nil, fmt.Errorf("%s authorize empty token", mode) return nil, fmt.Errorf("%s authorize empty token", mode)
} }
pushURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/api/v1/upload", RawQuery: "token=" + url.QueryEscape(token)}).String() pushURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/api/v1/upload"), RawQuery: "token=" + url.QueryEscape(token)}).String()
pullURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/stream", RawQuery: "token=" + url.QueryEscape(token)}).String() pullURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/stream"), RawQuery: "token=" + url.QueryEscape(token)}).String()
closeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/api/v1/upload", RawQuery: "token=" + url.QueryEscape(token) + "&close=1"}).String() closeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/api/v1/upload"), RawQuery: "token=" + url.QueryEscape(token) + "&close=1"}).String()
return &sessionDialInfo{ return &sessionDialInfo{
client: client, client: client,
@@ -366,6 +383,7 @@ func dialSessionWithClient(ctx context.Context, client *http.Client, target http
pullURL: pullURL, pullURL: pullURL,
closeURL: closeURL, closeURL: closeURL,
headerHost: target.headerHost, headerHost: target.headerHost,
auth: auth,
}, nil }, nil
} }
@@ -374,10 +392,10 @@ func dialSession(ctx context.Context, serverAddress string, opts TunnelDialOptio
if err != nil { if err != nil {
return nil, err return nil, err
} }
return dialSessionWithClient(ctx, client, target, mode) return dialSessionWithClient(ctx, client, target, mode, opts)
} }
func bestEffortCloseSession(client *http.Client, closeURL, headerHost string, mode TunnelMode) { func bestEffortCloseSession(client *http.Client, closeURL, headerHost string, mode TunnelMode, auth *tunnelAuth) {
if client == nil || closeURL == "" || headerHost == "" { if client == nil || closeURL == "" || headerHost == "" {
return return
} }
@@ -391,6 +409,7 @@ func bestEffortCloseSession(client *http.Client, closeURL, headerHost string, mo
} }
req.Host = headerHost req.Host = headerHost
applyTunnelHeaders(req.Header, headerHost, mode) applyTunnelHeaders(req.Header, headerHost, mode)
applyTunnelAuthHeader(req.Header, auth, mode, http.MethodPost, "/api/v1/upload")
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil || resp == nil { if err != nil || resp == nil {
@@ -400,13 +419,13 @@ func bestEffortCloseSession(client *http.Client, closeURL, headerHost string, mo
_ = resp.Body.Close() _ = resp.Body.Close()
} }
func dialStreamWithClient(ctx context.Context, client *http.Client, target httpClientTarget) (net.Conn, error) { func dialStreamWithClient(ctx context.Context, client *http.Client, target httpClientTarget, opts TunnelDialOptions) (net.Conn, error) {
// Prefer split session (Cloudflare-friendly). Fall back to stream-one for older servers / environments. // Prefer split-session (Cloudflare-friendly). Fall back to stream-one for older servers / environments.
c, errSplit := dialStreamSplitWithClient(ctx, client, target) c, errSplit := dialStreamSplitWithClient(ctx, client, target, opts)
if errSplit == nil { if errSplit == nil {
return c, nil return c, nil
} }
c2, errOne := dialStreamOneWithClient(ctx, client, target) c2, errOne := dialStreamOneWithClient(ctx, client, target, opts)
if errOne == nil { if errOne == nil {
return c2, nil return c2, nil
} }
@@ -414,7 +433,7 @@ func dialStreamWithClient(ctx context.Context, client *http.Client, target httpC
} }
func dialStream(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) { func dialStream(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
// Prefer split session (Cloudflare-friendly). Fall back to stream-one for older servers / environments. // Prefer split-session (Cloudflare-friendly). Fall back to stream-one for older servers / environments.
c, errSplit := dialStreamSplit(ctx, serverAddress, opts) c, errSplit := dialStreamSplit(ctx, serverAddress, opts)
if errSplit == nil { if errSplit == nil {
return c, nil return c, nil
@@ -426,13 +445,15 @@ func dialStream(ctx context.Context, serverAddress string, opts TunnelDialOption
return nil, fmt.Errorf("dial stream failed: split: %v; stream-one: %w", errSplit, errOne) return nil, fmt.Errorf("dial stream failed: split: %v; stream-one: %w", errSplit, errOne)
} }
func dialStreamOneWithClient(ctx context.Context, client *http.Client, target httpClientTarget) (net.Conn, error) { func dialStreamOneWithClient(ctx context.Context, client *http.Client, target httpClientTarget, opts TunnelDialOptions) (net.Conn, error) {
if client == nil { if client == nil {
return nil, fmt.Errorf("nil http client") return nil, fmt.Errorf("nil http client")
} }
auth := newTunnelAuth(opts.AuthKey, 0)
r := rngPool.Get().(*mrand.Rand) r := rngPool.Get().(*mrand.Rand)
path := paths[r.Intn(len(paths))] basePath := paths[r.Intn(len(paths))]
path := joinPathRoot(opts.PathRoot, basePath)
ctype := contentTypes[r.Intn(len(contentTypes))] ctype := contentTypes[r.Intn(len(contentTypes))]
rngPool.Put(r) rngPool.Put(r)
@@ -454,6 +475,7 @@ func dialStreamOneWithClient(ctx context.Context, client *http.Client, target ht
req.Host = target.headerHost req.Host = target.headerHost
applyTunnelHeaders(req.Header, target.headerHost, TunnelModeStream) applyTunnelHeaders(req.Header, target.headerHost, TunnelModeStream)
applyTunnelAuthHeader(req.Header, auth, TunnelModeStream, http.MethodPost, basePath)
req.Header.Set("Content-Type", ctype) req.Header.Set("Content-Type", ctype)
type doResult struct { type doResult struct {
@@ -466,33 +488,84 @@ func dialStreamOneWithClient(ctx context.Context, client *http.Client, target ht
doCh <- doResult{resp: resp, err: doErr} doCh <- doResult{resp: resp, err: doErr}
}() }()
select { streamConn := &httpStreamConn{
case <-ctx.Done(): writer: reqBodyW,
connCancel() cancel: connCancel,
_ = reqBodyW.Close() localAddr: &net.TCPAddr{},
return nil, ctx.Err() remoteAddr: &net.TCPAddr{},
case r := <-doCh:
if r.err != nil {
connCancel()
_ = reqBodyW.Close()
return nil, r.err
}
if r.resp.StatusCode != http.StatusOK {
defer r.resp.Body.Close()
body, _ := io.ReadAll(io.LimitReader(r.resp.Body, 4*1024))
connCancel()
_ = reqBodyW.Close()
return nil, fmt.Errorf("stream bad status: %s (%s)", r.resp.Status, strings.TrimSpace(string(body)))
}
return &httpStreamConn{
reader: r.resp.Body,
writer: reqBodyW,
cancel: connCancel,
localAddr: &net.TCPAddr{},
remoteAddr: &net.TCPAddr{},
}, nil
} }
type upgradeResult struct {
conn net.Conn
err error
}
upgradeCh := make(chan upgradeResult, 1)
if opts.Upgrade == nil {
upgradeCh <- upgradeResult{conn: streamConn, err: nil}
} else {
go func() {
upgradeConn, err := opts.Upgrade(streamConn)
if err != nil {
upgradeCh <- upgradeResult{conn: nil, err: err}
return
}
if upgradeConn == nil {
upgradeConn = streamConn
}
upgradeCh <- upgradeResult{conn: upgradeConn, err: nil}
}()
}
var (
outConn net.Conn
upgradeDone bool
responseReady bool
)
for !(upgradeDone && responseReady) {
select {
case <-ctx.Done():
_ = streamConn.Close()
if outConn != nil && outConn != streamConn {
_ = outConn.Close()
}
return nil, ctx.Err()
case u := <-upgradeCh:
if u.err != nil {
_ = streamConn.Close()
return nil, u.err
}
outConn = u.conn
if outConn == nil {
outConn = streamConn
}
upgradeDone = true
case r := <-doCh:
if r.err != nil {
_ = streamConn.Close()
if outConn != nil && outConn != streamConn {
_ = outConn.Close()
}
return nil, r.err
}
if r.resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(r.resp.Body, 4*1024))
_ = r.resp.Body.Close()
_ = streamConn.Close()
if outConn != nil && outConn != streamConn {
_ = outConn.Close()
}
return nil, fmt.Errorf("stream bad status: %s (%s)", r.resp.Status, strings.TrimSpace(string(body)))
}
streamConn.reader = r.resp.Body
responseReady = true
}
}
return outConn, nil
} }
func dialStreamOne(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) { func dialStreamOne(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
@@ -500,7 +573,7 @@ func dialStreamOne(ctx context.Context, serverAddress string, opts TunnelDialOpt
if err != nil { if err != nil {
return nil, err return nil, err
} }
return dialStreamOneWithClient(ctx, client, target) return dialStreamOneWithClient(ctx, client, target, opts)
} }
type queuedConn struct { type queuedConn struct {
@@ -599,6 +672,7 @@ type streamSplitConn struct {
pullURL string pullURL string
closeURL string closeURL string
headerHost string headerHost string
auth *tunnelAuth
} }
func (c *streamSplitConn) Close() error { func (c *streamSplitConn) Close() error {
@@ -607,7 +681,7 @@ func (c *streamSplitConn) Close() error {
if c.cancel != nil { if c.cancel != nil {
c.cancel() c.cancel()
} }
bestEffortCloseSession(c.client, c.closeURL, c.headerHost, TunnelModeStream) bestEffortCloseSession(c.client, c.closeURL, c.headerHost, TunnelModeStream, c.auth)
return nil return nil
} }
@@ -625,6 +699,7 @@ func newStreamSplitConnFromInfo(info *sessionDialInfo) *streamSplitConn {
pullURL: info.pullURL, pullURL: info.pullURL,
closeURL: info.closeURL, closeURL: info.closeURL,
headerHost: info.headerHost, headerHost: info.headerHost,
auth: info.auth,
queuedConn: queuedConn{ queuedConn: queuedConn{
rxc: make(chan []byte, 256), rxc: make(chan []byte, 256),
closed: make(chan struct{}), closed: make(chan struct{}),
@@ -639,8 +714,8 @@ func newStreamSplitConnFromInfo(info *sessionDialInfo) *streamSplitConn {
return c return c
} }
func dialStreamSplitWithClient(ctx context.Context, client *http.Client, target httpClientTarget) (net.Conn, error) { func dialStreamSplitWithClient(ctx context.Context, client *http.Client, target httpClientTarget, opts TunnelDialOptions) (net.Conn, error) {
info, err := dialSessionWithClient(ctx, client, target, TunnelModeStream) info, err := dialSessionWithClient(ctx, client, target, TunnelModeStream, opts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -648,7 +723,18 @@ func dialStreamSplitWithClient(ctx context.Context, client *http.Client, target
if c == nil { if c == nil {
return nil, fmt.Errorf("failed to build stream split conn") return nil, fmt.Errorf("failed to build stream split conn")
} }
return c, nil outConn := net.Conn(c)
if opts.Upgrade != nil {
upgraded, err := opts.Upgrade(c)
if err != nil {
_ = c.Close()
return nil, err
}
if upgraded != nil {
outConn = upgraded
}
}
return outConn, nil
} }
func dialStreamSplit(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) { func dialStreamSplit(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
@@ -660,7 +746,18 @@ func dialStreamSplit(ctx context.Context, serverAddress string, opts TunnelDialO
if c == nil { if c == nil {
return nil, fmt.Errorf("failed to build stream split conn") return nil, fmt.Errorf("failed to build stream split conn")
} }
return c, nil outConn := net.Conn(c)
if opts.Upgrade != nil {
upgraded, err := opts.Upgrade(c)
if err != nil {
_ = c.Close()
return nil, err
}
if upgraded != nil {
outConn = upgraded
}
}
return outConn, nil
} }
func (c *streamSplitConn) pullLoop() { func (c *streamSplitConn) pullLoop() {
@@ -696,6 +793,7 @@ func (c *streamSplitConn) pullLoop() {
} }
req.Host = c.headerHost req.Host = c.headerHost
applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream) applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream)
applyTunnelAuthHeader(req.Header, c.auth, TunnelModeStream, http.MethodGet, "/stream")
resp, err := c.client.Do(req) resp, err := c.client.Do(req)
if err != nil { if err != nil {
@@ -793,6 +891,7 @@ func (c *streamSplitConn) pushLoop() {
} }
req.Host = c.headerHost req.Host = c.headerHost
applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream) applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream)
applyTunnelAuthHeader(req.Header, c.auth, TunnelModeStream, http.MethodPost, "/api/v1/upload")
req.Header.Set("Content-Type", "application/octet-stream") req.Header.Set("Content-Type", "application/octet-stream")
resp, err := c.client.Do(req) resp, err := c.client.Do(req)
@@ -896,6 +995,7 @@ type pollConn struct {
pullURL string pullURL string
closeURL string closeURL string
headerHost string headerHost string
auth *tunnelAuth
} }
func isDialError(err error) bool { func isDialError(err error) bool {
@@ -917,7 +1017,7 @@ func (c *pollConn) closeWithError(err error) error {
if c.cancel != nil { if c.cancel != nil {
c.cancel() c.cancel()
} }
bestEffortCloseSession(c.client, c.closeURL, c.headerHost, TunnelModePoll) bestEffortCloseSession(c.client, c.closeURL, c.headerHost, TunnelModePoll, c.auth)
return nil return nil
} }
@@ -939,6 +1039,7 @@ func newPollConnFromInfo(info *sessionDialInfo) *pollConn {
pullURL: info.pullURL, pullURL: info.pullURL,
closeURL: info.closeURL, closeURL: info.closeURL,
headerHost: info.headerHost, headerHost: info.headerHost,
auth: info.auth,
queuedConn: queuedConn{ queuedConn: queuedConn{
rxc: make(chan []byte, 128), rxc: make(chan []byte, 128),
closed: make(chan struct{}), closed: make(chan struct{}),
@@ -953,8 +1054,8 @@ func newPollConnFromInfo(info *sessionDialInfo) *pollConn {
return c return c
} }
func dialPollWithClient(ctx context.Context, client *http.Client, target httpClientTarget) (net.Conn, error) { func dialPollWithClient(ctx context.Context, client *http.Client, target httpClientTarget, opts TunnelDialOptions) (net.Conn, error) {
info, err := dialSessionWithClient(ctx, client, target, TunnelModePoll) info, err := dialSessionWithClient(ctx, client, target, TunnelModePoll, opts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -962,7 +1063,18 @@ func dialPollWithClient(ctx context.Context, client *http.Client, target httpCli
if c == nil { if c == nil {
return nil, fmt.Errorf("failed to build poll conn") return nil, fmt.Errorf("failed to build poll conn")
} }
return c, nil outConn := net.Conn(c)
if opts.Upgrade != nil {
upgraded, err := opts.Upgrade(c)
if err != nil {
_ = c.Close()
return nil, err
}
if upgraded != nil {
outConn = upgraded
}
}
return outConn, nil
} }
func dialPoll(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) { func dialPoll(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
@@ -974,7 +1086,18 @@ func dialPoll(ctx context.Context, serverAddress string, opts TunnelDialOptions)
if c == nil { if c == nil {
return nil, fmt.Errorf("failed to build poll conn") return nil, fmt.Errorf("failed to build poll conn")
} }
return c, nil outConn := net.Conn(c)
if opts.Upgrade != nil {
upgraded, err := opts.Upgrade(c)
if err != nil {
_ = c.Close()
return nil, err
}
if upgraded != nil {
outConn = upgraded
}
}
return outConn, nil
} }
func (c *pollConn) pullLoop() { func (c *pollConn) pullLoop() {
@@ -1001,6 +1124,7 @@ func (c *pollConn) pullLoop() {
} }
req.Host = c.headerHost req.Host = c.headerHost
applyTunnelHeaders(req.Header, c.headerHost, TunnelModePoll) applyTunnelHeaders(req.Header, c.headerHost, TunnelModePoll)
applyTunnelAuthHeader(req.Header, c.auth, TunnelModePoll, http.MethodGet, "/stream")
resp, err := c.client.Do(req) resp, err := c.client.Do(req)
if err != nil { if err != nil {
@@ -1084,6 +1208,7 @@ func (c *pollConn) pushLoop() {
} }
req.Host = c.headerHost req.Host = c.headerHost
applyTunnelHeaders(req.Header, c.headerHost, TunnelModePoll) applyTunnelHeaders(req.Header, c.headerHost, TunnelModePoll)
applyTunnelAuthHeader(req.Header, c.auth, TunnelModePoll, http.MethodPost, "/api/v1/upload")
req.Header.Set("Content-Type", "text/plain") req.Header.Set("Content-Type", "text/plain")
resp, err := c.client.Do(req) resp, err := c.client.Do(req)
@@ -1246,6 +1371,18 @@ func applyTunnelHeaders(h http.Header, host string, mode TunnelMode) {
type TunnelServerOptions struct { type TunnelServerOptions struct {
Mode string Mode string
// PathRoot is an optional first-level path prefix for all HTTP tunnel endpoints.
// Example: "aabbcc" => "/aabbcc/session", "/aabbcc/api/v1/upload", ...
PathRoot string
// AuthKey enables short-term HMAC auth for HTTP tunnel requests (anti-probing).
// When set (non-empty), the server requires each request to carry a valid Authorization bearer token.
AuthKey string
// AuthSkew controls allowed clock skew / replay window for AuthKey. 0 uses a conservative default.
AuthSkew time.Duration
// PassThroughOnReject controls how the server handles "recognized but rejected" tunnel requests
// (e.g., wrong mode / wrong path / invalid token). When true, the request bytes are replayed back
// to the caller as HandlePassThrough to allow higher-level fallback handling.
PassThroughOnReject bool
// PullReadTimeout controls how long the server long-poll waits for tunnel downlink data before replying with a keepalive newline. // PullReadTimeout controls how long the server long-poll waits for tunnel downlink data before replying with a keepalive newline.
PullReadTimeout time.Duration PullReadTimeout time.Duration
// SessionTTL is a best-effort TTL to prevent leaked sessions. 0 uses a conservative default. // SessionTTL is a best-effort TTL to prevent leaked sessions. 0 uses a conservative default.
@@ -1253,7 +1390,10 @@ type TunnelServerOptions struct {
} }
type TunnelServer struct { type TunnelServer struct {
mode TunnelMode mode TunnelMode
pathRoot string
passThroughOnReject bool
auth *tunnelAuth
pullReadTimeout time.Duration pullReadTimeout time.Duration
sessionTTL time.Duration sessionTTL time.Duration
@@ -1272,6 +1412,8 @@ func NewTunnelServer(opts TunnelServerOptions) *TunnelServer {
if mode == TunnelModeLegacy { if mode == TunnelModeLegacy {
// Server-side "legacy" means: don't accept stream/poll tunnels; only passthrough. // Server-side "legacy" means: don't accept stream/poll tunnels; only passthrough.
} }
pathRoot := normalizePathRoot(opts.PathRoot)
auth := newTunnelAuth(opts.AuthKey, opts.AuthSkew)
timeout := opts.PullReadTimeout timeout := opts.PullReadTimeout
if timeout <= 0 { if timeout <= 0 {
timeout = 10 * time.Second timeout = 10 * time.Second
@@ -1281,10 +1423,13 @@ func NewTunnelServer(opts TunnelServerOptions) *TunnelServer {
ttl = 2 * time.Minute ttl = 2 * time.Minute
} }
return &TunnelServer{ return &TunnelServer{
mode: mode, mode: mode,
pullReadTimeout: timeout, pathRoot: pathRoot,
sessionTTL: ttl, auth: auth,
sessions: make(map[string]*tunnelSession), passThroughOnReject: opts.PassThroughOnReject,
pullReadTimeout: timeout,
sessionTTL: ttl,
sessions: make(map[string]*tunnelSession),
} }
} }
@@ -1340,6 +1485,12 @@ func (s *TunnelServer) HandleConn(rawConn net.Conn) (HandleResult, net.Conn, err
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
} }
if s.mode == TunnelModeLegacy { if s.mode == TunnelModeLegacy {
if s.passThroughOnReject {
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
prefix = append(prefix, headerBytes...)
prefix = append(prefix, buffered...)
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
}
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found") _ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
_ = rawConn.Close() _ = rawConn.Close()
return HandleDone, nil, nil return HandleDone, nil, nil
@@ -1348,19 +1499,37 @@ func (s *TunnelServer) HandleConn(rawConn net.Conn) (HandleResult, net.Conn, err
switch TunnelMode(tunnelHeader) { switch TunnelMode(tunnelHeader) {
case TunnelModeStream: case TunnelModeStream:
if s.mode != TunnelModeStream && s.mode != TunnelModeAuto { if s.mode != TunnelModeStream && s.mode != TunnelModeAuto {
if s.passThroughOnReject {
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
prefix = append(prefix, headerBytes...)
prefix = append(prefix, buffered...)
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
}
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found") _ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
_ = rawConn.Close() _ = rawConn.Close()
return HandleDone, nil, nil return HandleDone, nil, nil
} }
return s.handleStream(rawConn, req, buffered) return s.handleStream(rawConn, req, headerBytes, buffered)
case TunnelModePoll: case TunnelModePoll:
if s.mode != TunnelModePoll && s.mode != TunnelModeAuto { if s.mode != TunnelModePoll && s.mode != TunnelModeAuto {
if s.passThroughOnReject {
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
prefix = append(prefix, headerBytes...)
prefix = append(prefix, buffered...)
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
}
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found") _ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
_ = rawConn.Close() _ = rawConn.Close()
return HandleDone, nil, nil return HandleDone, nil, nil
} }
return s.handlePoll(rawConn, req, buffered) return s.handlePoll(rawConn, req, headerBytes, buffered)
default: default:
if s.passThroughOnReject {
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
prefix = append(prefix, headerBytes...)
prefix = append(prefix, buffered...)
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
}
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found") _ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
_ = rawConn.Close() _ = rawConn.Close()
return HandleDone, nil, nil return HandleDone, nil, nil
@@ -1507,19 +1676,31 @@ func (c *bodyConn) Close() error {
return firstErr return firstErr
} }
func (s *TunnelServer) handleStream(rawConn net.Conn, req *httpRequestHeader, buffered []byte) (HandleResult, net.Conn, error) { func (s *TunnelServer) handleStream(rawConn net.Conn, req *httpRequestHeader, headerBytes []byte, buffered []byte) (HandleResult, net.Conn, error) {
u, err := url.ParseRequestURI(req.target) rejectOrReply := func(code int, body string) (HandleResult, net.Conn, error) {
if err != nil { if s.passThroughOnReject {
_ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request") prefix := make([]byte, 0, len(headerBytes)+len(buffered))
prefix = append(prefix, headerBytes...)
prefix = append(prefix, buffered...)
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
}
_ = writeSimpleHTTPResponse(rawConn, code, body)
_ = rawConn.Close() _ = rawConn.Close()
return HandleDone, nil, nil return HandleDone, nil, nil
} }
u, err := url.ParseRequestURI(req.target)
if err != nil {
return rejectOrReply(http.StatusBadRequest, "bad request")
}
// Only accept plausible paths to reduce accidental exposure. // Only accept plausible paths to reduce accidental exposure.
if !isAllowedPath(req.target) { path, ok := stripPathRoot(s.pathRoot, u.Path)
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found") if !ok || !s.isAllowedBasePath(path) {
_ = rawConn.Close() return rejectOrReply(http.StatusNotFound, "not found")
return HandleDone, nil, nil }
if !s.auth.verify(req.headers, TunnelModeStream, req.method, path, time.Now()) {
return rejectOrReply(http.StatusNotFound, "not found")
} }
token := u.Query().Get("token") token := u.Query().Get("token")
@@ -1528,31 +1709,25 @@ func (s *TunnelServer) handleStream(rawConn net.Conn, req *httpRequestHeader, bu
switch strings.ToUpper(req.method) { switch strings.ToUpper(req.method) {
case http.MethodGet: case http.MethodGet:
// Stream split-session: GET /session (no token) => token + start tunnel on a server-side pipe. // Stream split-session: GET /session (no token) => token + start tunnel on a server-side pipe.
if token == "" && u.Path == "/session" { if token == "" && path == "/session" {
return s.authorizeSession(rawConn) return s.authorizeSession(rawConn)
} }
// Stream split-session: GET /stream?token=... => downlink poll. // Stream split-session: GET /stream?token=... => downlink poll.
if token != "" && u.Path == "/stream" { if token != "" && path == "/stream" {
return s.streamPull(rawConn, token) return s.streamPull(rawConn, token)
} }
_ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request") return rejectOrReply(http.StatusBadRequest, "bad request")
_ = rawConn.Close()
return HandleDone, nil, nil
case http.MethodPost: case http.MethodPost:
// Stream split-session: POST /api/v1/upload?token=... => uplink push. // Stream split-session: POST /api/v1/upload?token=... => uplink push.
if token != "" && u.Path == "/api/v1/upload" { if token != "" && path == "/api/v1/upload" {
if closeFlag { if closeFlag {
s.closeSession(token) s.closeSession(token)
_ = writeSimpleHTTPResponse(rawConn, http.StatusOK, "") return rejectOrReply(http.StatusOK, "")
_ = rawConn.Close()
return HandleDone, nil, nil
} }
bodyReader, err := newRequestBodyReader(newPreBufferedConn(rawConn, buffered), req.headers) bodyReader, err := newRequestBodyReader(newPreBufferedConn(rawConn, buffered), req.headers)
if err != nil { if err != nil {
_ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request") return rejectOrReply(http.StatusBadRequest, "bad request")
_ = rawConn.Close()
return HandleDone, nil, nil
} }
return s.streamPush(rawConn, token, bodyReader) return s.streamPush(rawConn, token, bodyReader)
} }
@@ -1581,19 +1756,13 @@ func (s *TunnelServer) handleStream(rawConn net.Conn, req *httpRequestHeader, bu
return HandleStartTunnel, stream, nil return HandleStartTunnel, stream, nil
default: default:
_ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request") return rejectOrReply(http.StatusBadRequest, "bad request")
_ = rawConn.Close()
return HandleDone, nil, nil
} }
} }
func isAllowedPath(target string) bool { func (s *TunnelServer) isAllowedBasePath(path string) bool {
u, err := url.ParseRequestURI(target)
if err != nil {
return false
}
for _, p := range paths { for _, p := range paths {
if u.Path == p { if path == p {
return true return true
} }
} }
@@ -1650,51 +1819,58 @@ func writeTokenHTTPResponse(w io.Writer, token string) error {
return err return err
} }
func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, buffered []byte) (HandleResult, net.Conn, error) { func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, headerBytes []byte, buffered []byte) (HandleResult, net.Conn, error) {
u, err := url.ParseRequestURI(req.target) rejectOrReply := func(code int, body string) (HandleResult, net.Conn, error) {
if err != nil { if s.passThroughOnReject {
_ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request") prefix := make([]byte, 0, len(headerBytes)+len(buffered))
prefix = append(prefix, headerBytes...)
prefix = append(prefix, buffered...)
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
}
_ = writeSimpleHTTPResponse(rawConn, code, body)
_ = rawConn.Close() _ = rawConn.Close()
return HandleDone, nil, nil return HandleDone, nil, nil
} }
if !isAllowedPath(req.target) { u, err := url.ParseRequestURI(req.target)
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found") if err != nil {
_ = rawConn.Close() return rejectOrReply(http.StatusBadRequest, "bad request")
return HandleDone, nil, nil }
path, ok := stripPathRoot(s.pathRoot, u.Path)
if !ok || !s.isAllowedBasePath(path) {
return rejectOrReply(http.StatusNotFound, "not found")
}
if !s.auth.verify(req.headers, TunnelModePoll, req.method, path, time.Now()) {
return rejectOrReply(http.StatusNotFound, "not found")
} }
token := u.Query().Get("token") token := u.Query().Get("token")
closeFlag := u.Query().Get("close") == "1" closeFlag := u.Query().Get("close") == "1"
switch strings.ToUpper(req.method) { switch strings.ToUpper(req.method) {
case http.MethodGet: case http.MethodGet:
if token == "" { if token == "" && path == "/session" {
return s.authorizeSession(rawConn) return s.authorizeSession(rawConn)
} }
return s.pollPull(rawConn, token) if token != "" && path == "/stream" {
return s.pollPull(rawConn, token)
}
return rejectOrReply(http.StatusBadRequest, "bad request")
case http.MethodPost: case http.MethodPost:
if token == "" { if token == "" || path != "/api/v1/upload" {
_ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "missing token") return rejectOrReply(http.StatusBadRequest, "bad request")
_ = rawConn.Close()
return HandleDone, nil, nil
} }
if closeFlag { if closeFlag {
s.closeSession(token) s.closeSession(token)
_ = writeSimpleHTTPResponse(rawConn, http.StatusOK, "") return rejectOrReply(http.StatusOK, "")
_ = rawConn.Close()
return HandleDone, nil, nil
} }
bodyReader, err := newRequestBodyReader(newPreBufferedConn(rawConn, buffered), req.headers) bodyReader, err := newRequestBodyReader(newPreBufferedConn(rawConn, buffered), req.headers)
if err != nil { if err != nil {
_ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request") return rejectOrReply(http.StatusBadRequest, "bad request")
_ = rawConn.Close()
return HandleDone, nil, nil
} }
return s.pollPush(rawConn, token, bodyReader) return s.pollPush(rawConn, token, bodyReader)
default: default:
_ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request") return rejectOrReply(http.StatusBadRequest, "bad request")
_ = rawConn.Close()
return HandleDone, nil, nil
} }
} }

View File

@@ -20,7 +20,11 @@ type byteLayout struct {
} }
func (l *byteLayout) isHint(b byte) bool { func (l *byteLayout) isHint(b byte) bool {
return (b & l.hintMask) == l.hintValue if (b & l.hintMask) == l.hintValue {
return true
}
// ASCII layout maps the single non-printable marker (0x7F) to '\n' on the wire.
return l.name == "ascii" && b == '\n'
} }
// resolveLayout picks the byte layout based on ASCII preference and optional custom pattern. // resolveLayout picks the byte layout based on ASCII preference and optional custom pattern.
@@ -53,12 +57,25 @@ func newASCIILayout() *byteLayout {
padMarker: 0x3F, padMarker: 0x3F,
paddingPool: padding, paddingPool: padding,
encodeHint: func(val, pos byte) byte { encodeHint: func(val, pos byte) byte {
return 0x40 | ((val & 0x03) << 4) | (pos & 0x0F) b := 0x40 | ((val & 0x03) << 4) | (pos & 0x0F)
// Avoid DEL (0x7F) in prefer_ascii mode; map it to '\n' to reduce fingerprint.
if b == 0x7F {
return '\n'
}
return b
}, },
encodeGroup: func(group byte) byte { encodeGroup: func(group byte) byte {
return 0x40 | (group & 0x3F) b := 0x40 | (group & 0x3F)
// Avoid DEL (0x7F) in prefer_ascii mode; map it to '\n' to reduce fingerprint.
if b == 0x7F {
return '\n'
}
return b
}, },
decodeGroup: func(b byte) (byte, bool) { decodeGroup: func(b byte) (byte, bool) {
if b == '\n' {
return 0x3F, true
}
if (b & 0x40) == 0 { if (b & 0x40) == 0 {
return 0, false return 0, false
} }

View File

@@ -3,6 +3,7 @@ package sudoku
import ( import (
"bufio" "bufio"
"bytes" "bytes"
crand "crypto/rand"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
@@ -14,16 +15,20 @@ import (
"github.com/metacubex/mihomo/transport/sudoku/obfs/sudoku" "github.com/metacubex/mihomo/transport/sudoku/obfs/sudoku"
) )
func pickClientTable(cfg *ProtocolConfig) (*sudoku.Table, byte, error) { func pickClientTable(cfg *ProtocolConfig) (*sudoku.Table, error) {
candidates := cfg.tableCandidates() candidates := cfg.tableCandidates()
if len(candidates) == 0 { if len(candidates) == 0 {
return nil, 0, fmt.Errorf("no table configured") return nil, fmt.Errorf("no table configured")
} }
if len(candidates) == 1 { if len(candidates) == 1 {
return candidates[0], 0, nil return candidates[0], nil
} }
idx := int(randomByte()) % len(candidates) var b [1]byte
return candidates[idx], byte(idx), nil if _, err := crand.Read(b[:]); err != nil {
return nil, fmt.Errorf("random table pick failed: %w", err)
}
idx := int(b[0]) % len(candidates)
return candidates[idx], nil
} }
type readOnlyConn struct { type readOnlyConn struct {