feat: add path-root for sudoku (#2511)

This commit is contained in:
saba-futai
2026-01-14 21:25:05 +08:00
committed by GitHub
parent f38fc2020f
commit 06f5fbac06
17 changed files with 660 additions and 187 deletions

View File

@@ -43,6 +43,7 @@ type SudokuOption struct {
HTTPMaskMode string `proxy:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto"
HTTPMaskTLS bool `proxy:"http-mask-tls,omitempty"` // only for http-mask-mode stream/poll/auto
HTTPMaskHost string `proxy:"http-mask-host,omitempty"` // optional Host/SNI override (domain or domain:port)
PathRoot string `proxy:"path-root,omitempty"` // optional first-level path prefix for HTTP tunnel endpoints
HTTPMaskMultiplex string `proxy:"http-mask-multiplex,omitempty"` // "off" (default), "auto" (reuse h1/h2), "on" (single tunnel, multi-target)
CustomTable string `proxy:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv
CustomTables []string `proxy:"custom-tables,omitempty"` // optional table rotation patterns, overrides custom-table when non-empty
@@ -183,6 +184,7 @@ func NewSudoku(option SudokuOption) (*Sudoku, error) {
HTTPMaskMode: defaultConf.HTTPMaskMode,
HTTPMaskTLSEnabled: option.HTTPMaskTLS,
HTTPMaskHost: option.HTTPMaskHost,
HTTPMaskPathRoot: strings.TrimSpace(option.PathRoot),
HTTPMaskMultiplex: defaultConf.HTTPMaskMultiplex,
}
if option.HTTPMaskMode != "" {
@@ -257,7 +259,19 @@ func (s *Sudoku) dialAndHandshake(ctx context.Context, cfg *sudoku.ProtocolConfi
return nil, fmt.Errorf("config is required")
}
var c net.Conn
handshakeCfg := *cfg
if !handshakeCfg.DisableHTTPMask && httpTunnelModeEnabled(handshakeCfg.HTTPMaskMode) {
handshakeCfg.DisableHTTPMask = true
}
upgrade := func(raw net.Conn) (net.Conn, error) {
return sudoku.ClientHandshakeWithOptions(raw, &handshakeCfg, sudoku.ClientHandshakeOptions{})
}
var (
c net.Conn
handshakeDone bool
)
if !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) {
muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex)
switch muxMode {
@@ -266,9 +280,12 @@ func (s *Sudoku) dialAndHandshake(ctx context.Context, cfg *sudoku.ProtocolConfi
if errX != nil {
return nil, errX
}
c, err = client.Dial(ctx)
c, err = client.Dial(ctx, upgrade)
default:
c, err = sudoku.DialHTTPMaskTunnel(ctx, cfg.ServerAddress, cfg, s.dialer.DialContext)
c, err = sudoku.DialHTTPMaskTunnel(ctx, cfg.ServerAddress, cfg, s.dialer.DialContext, upgrade)
}
if err == nil && c != nil {
handshakeDone = true
}
}
if c == nil && err == nil {
@@ -285,14 +302,11 @@ func (s *Sudoku) dialAndHandshake(ctx context.Context, cfg *sudoku.ProtocolConfi
defer done(&err)
}
handshakeCfg := *cfg
if !handshakeCfg.DisableHTTPMask && httpTunnelModeEnabled(handshakeCfg.HTTPMaskMode) {
handshakeCfg.DisableHTTPMask = true
}
c, err = sudoku.ClientHandshakeWithOptions(c, &handshakeCfg, sudoku.ClientHandshakeOptions{})
if err != nil {
return nil, err
if !handshakeDone {
c, err = sudoku.ClientHandshakeWithOptions(c, &handshakeCfg, sudoku.ClientHandshakeOptions{})
if err != nil {
return nil, err
}
}
return c, nil

View File

@@ -1072,6 +1072,7 @@ proxies: # socks5
# http-mask-mode: legacy # 可选legacy默认、stream、poll、autostream/poll/auto 支持走 CDN/反代
# http-mask-tls: true # 可选:仅在 http-mask-mode 为 stream/poll/auto 时生效true 强制 httpsfalse 强制 http不会根据端口自动推断
# http-mask-host: "" # 可选:覆盖 Host/SNI支持 example.com 或 example.com:443仅在 http-mask-mode 为 stream/poll/auto 时生效
# path-root: "" # 可选HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload
# http-mask-multiplex: off # 可选off默认、auto复用 h1.1 keep-alive / h2 连接,减少每次建链 RTT、on单条隧道内多路复用多个目标连接仅在 http-mask-mode=stream/poll/auto 生效)
enable-pure-downlink: false # 是否启用混淆下行false的情况下能在保证数据安全的前提下极大提升下行速度与服务端端保持相同(如果此处为false则要求aead不可为none)
@@ -1621,6 +1622,7 @@ listeners:
enable-pure-downlink: false # 是否启用混淆下行false的情况下能在保证数据安全的前提下极大提升下行速度与客户端保持相同(如果此处为false则要求aead不可为none)
disable-http-mask: false # 可选:禁用 http 掩码/隧道(默认为 false
# http-mask-mode: legacy # 可选legacy默认、stream、poll、autostream/poll/auto 支持走 CDN/反代
# path-root: "" # 可选HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload

View File

@@ -22,6 +22,7 @@ type SudokuServer struct {
CustomTables []string `json:"custom-tables,omitempty"`
DisableHTTPMask bool `json:"disable-http-mask,omitempty"`
HTTPMaskMode string `json:"http-mask-mode,omitempty"`
PathRoot string `json:"path-root,omitempty"`
// mihomo private extension (not the part of standard Sudoku protocol)
MuxOption sing.MuxOption `json:"mux-option,omitempty"`

View File

@@ -24,6 +24,7 @@ type SudokuOption struct {
CustomTables []string `inbound:"custom-tables,omitempty"`
DisableHTTPMask bool `inbound:"disable-http-mask,omitempty"`
HTTPMaskMode string `inbound:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto"
PathRoot string `inbound:"path-root,omitempty"` // optional first-level path prefix for HTTP tunnel endpoints
// mihomo private extension (not the part of standard Sudoku protocol)
MuxOption MuxOption `inbound:"mux-option,omitempty"`
@@ -63,6 +64,7 @@ func NewSudoku(options *SudokuOption) (*Sudoku, error) {
CustomTables: options.CustomTables,
DisableHTTPMask: options.DisableHTTPMask,
HTTPMaskMode: options.HTTPMaskMode,
PathRoot: strings.TrimSpace(options.PathRoot),
}
serverConf.MuxOption = options.MuxOption.Build()

View File

@@ -229,6 +229,7 @@ func New(config LC.SudokuServer, tunnel C.Tunnel, additions ...inbound.Addition)
HandshakeTimeoutSeconds: handshakeTimeout,
DisableHTTPMask: config.DisableHTTPMask,
HTTPMaskMode: config.HTTPMaskMode,
HTTPMaskPathRoot: strings.TrimSpace(config.PathRoot),
}
if len(tables) == 1 {
protoConf.Table = tables[0]

View File

@@ -58,6 +58,10 @@ type ProtocolConfig struct {
// HTTPMaskHost optionally overrides the HTTP Host header / SNI host for HTTP tunnel modes (client-side).
HTTPMaskHost string
// HTTPMaskPathRoot optionally prefixes all HTTP mask paths with a first-level segment.
// Example: "aabbcc" => "/aabbcc/session", "/aabbcc/api/v1/upload", ...
HTTPMaskPathRoot string
// HTTPMaskMultiplex controls multiplex behavior when HTTPMask tunnel modes are enabled:
// - "off": disable reuse; each Dial establishes its own HTTPMask tunnel
// - "auto": reuse underlying HTTP connections across multiple tunnel dials (HTTP/1.1 keep-alive / HTTP/2)
@@ -109,6 +113,23 @@ func (c *ProtocolConfig) Validate() error {
return fmt.Errorf("invalid http-mask-mode: %s, must be one of: legacy, stream, poll, auto", c.HTTPMaskMode)
}
if v := strings.TrimSpace(c.HTTPMaskPathRoot); v != "" {
if strings.Contains(v, "/") {
return fmt.Errorf("invalid http-mask-path-root: must be a single path segment")
}
for i := 0; i < len(v); i++ {
ch := v[i]
switch {
case ch >= 'a' && ch <= 'z':
case ch >= 'A' && ch <= 'Z':
case ch >= '0' && ch <= '9':
case ch == '_' || ch == '-':
default:
return fmt.Errorf("invalid http-mask-path-root: contains invalid character %q", ch)
}
}
}
switch strings.ToLower(strings.TrimSpace(c.HTTPMaskMultiplex)) {
case "", "off", "auto", "on":
default:

View File

@@ -2,7 +2,6 @@ package sudoku
import (
"bufio"
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
@@ -153,14 +152,17 @@ func buildServerObfsConn(raw net.Conn, cfg *ProtocolConfig, table *sudoku.Table,
func buildHandshakePayload(key string) [16]byte {
var payload [16]byte
binary.BigEndian.PutUint64(payload[:8], uint64(time.Now().Unix()))
// Hash the decoded HEX bytes of the key, not the HEX string itself.
// This ensures the user hash is computed on the actual key bytes.
keyBytes, err := hex.DecodeString(key)
if err != nil {
// Fallback: if key is not valid HEX (e.g., a UUID or plain string), hash the string bytes
keyBytes = []byte(key)
// Align with upstream: only decode hex bytes when this key is an ED25519 key material.
// For plain UUID/strings (even if they look like hex), hash the string bytes as-is.
src := []byte(key)
if _, err := crypto.RecoverPublicKey(key); err == nil {
if keyBytes, decErr := hex.DecodeString(key); decErr == nil && len(keyBytes) > 0 {
src = keyBytes
}
}
hash := sha256.Sum256(keyBytes)
hash := sha256.Sum256(src)
copy(payload[8:], hash[:8])
return payload
}
@@ -211,12 +213,12 @@ func ClientHandshakeWithOptions(rawConn net.Conn, cfg *ProtocolConfig, opt Clien
}
if !cfg.DisableHTTPMask {
if err := WriteHTTPMaskHeader(rawConn, cfg.ServerAddress, opt.HTTPMaskStrategy); err != nil {
if err := WriteHTTPMaskHeader(rawConn, cfg.ServerAddress, cfg.HTTPMaskPathRoot, opt.HTTPMaskStrategy); err != nil {
return nil, fmt.Errorf("write http mask failed: %w", err)
}
}
table, tableID, err := pickClientTable(cfg)
table, err := pickClientTable(cfg)
if err != nil {
return nil, err
}
@@ -228,9 +230,6 @@ func ClientHandshakeWithOptions(rawConn net.Conn, cfg *ProtocolConfig, opt Clien
}
handshake := buildHandshakePayload(cfg.Key)
if len(cfg.tableCandidates()) > 1 {
handshake[8] = tableID
}
if _, err := cConn.Write(handshake[:]); err != nil {
cConn.Close()
return nil, fmt.Errorf("send handshake failed: %w", err)
@@ -376,19 +375,9 @@ func normalizeHTTPMaskStrategy(strategy string) string {
}
}
// randomByte returns a cryptographically random byte (with a math/rand fallback).
func randomByte() byte {
var b [1]byte
if _, err := rand.Read(b[:]); err == nil {
return b[0]
}
return byte(time.Now().UnixNano())
}
func userHashFromHandshake(handshakeBuf []byte) string {
if len(handshakeBuf) < 16 {
return ""
}
// handshake[8] may be a table ID when table rotation is enabled; use [9:16] as stable user hash bytes.
return hex.EncodeToString(handshakeBuf[9:16])
return hex.EncodeToString(handshakeBuf[8:16])
}

View File

@@ -7,6 +7,7 @@ import (
"math/rand"
"net"
"strconv"
"strings"
"sync"
"time"
@@ -92,24 +93,24 @@ func appendCommonHeaders(buf []byte, host string, r *rand.Rand) []byte {
// WriteHTTPMaskHeader writes an HTTP/1.x request header as a mask, according to strategy.
// Supported strategies: ""/"random", "post", "websocket".
func WriteHTTPMaskHeader(w io.Writer, host string, strategy string) error {
func WriteHTTPMaskHeader(w io.Writer, host string, pathRoot string, strategy string) error {
switch normalizeHTTPMaskStrategy(strategy) {
case "random":
return httpmask.WriteRandomRequestHeader(w, host)
return httpmask.WriteRandomRequestHeaderWithPathRoot(w, host, pathRoot)
case "post":
return writeHTTPMaskPOST(w, host)
return writeHTTPMaskPOST(w, host, pathRoot)
case "websocket":
return writeHTTPMaskWebSocket(w, host)
return writeHTTPMaskWebSocket(w, host, pathRoot)
default:
return fmt.Errorf("unsupported http-mask-strategy: %s", strategy)
}
}
func writeHTTPMaskPOST(w io.Writer, host string) error {
func writeHTTPMaskPOST(w io.Writer, host string, pathRoot string) error {
r := httpMaskRngPool.Get().(*rand.Rand)
defer httpMaskRngPool.Put(r)
path := httpMaskPaths[r.Intn(len(httpMaskPaths))]
path := joinPathRoot(pathRoot, httpMaskPaths[r.Intn(len(httpMaskPaths))])
ctype := httpMaskContentTypes[r.Intn(len(httpMaskContentTypes))]
bufPtr := httpMaskBufPool.Get().(*[]byte)
@@ -140,11 +141,11 @@ func writeHTTPMaskPOST(w io.Writer, host string) error {
return err
}
func writeHTTPMaskWebSocket(w io.Writer, host string) error {
func writeHTTPMaskWebSocket(w io.Writer, host string, pathRoot string) error {
r := httpMaskRngPool.Get().(*rand.Rand)
defer httpMaskRngPool.Put(r)
path := httpMaskPaths[r.Intn(len(httpMaskPaths))]
path := joinPathRoot(pathRoot, httpMaskPaths[r.Intn(len(httpMaskPaths))])
bufPtr := httpMaskBufPool.Get().(*[]byte)
buf := *bufPtr
@@ -177,3 +178,37 @@ func writeHTTPMaskWebSocket(w io.Writer, host string) error {
_, err := w.Write(buf)
return err
}
func normalizePathRoot(root string) string {
root = strings.TrimSpace(root)
root = strings.Trim(root, "/")
if root == "" {
return ""
}
for i := 0; i < len(root); i++ {
c := root[i]
switch {
case c >= 'a' && c <= 'z':
case c >= 'A' && c <= 'Z':
case c >= '0' && c <= '9':
case c == '_' || c == '-':
default:
return ""
}
}
return "/" + root
}
func joinPathRoot(root, path string) string {
root = normalizePathRoot(root)
if root == "" {
return path
}
if path == "" {
return root
}
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
return root + path
}

View File

@@ -23,7 +23,11 @@ func NewHTTPMaskTunnelServer(cfg *ProtocolConfig) *HTTPMaskTunnelServer {
if !cfg.DisableHTTPMask {
switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) {
case "stream", "poll", "auto":
ts = httpmask.NewTunnelServer(httpmask.TunnelServerOptions{Mode: cfg.HTTPMaskMode})
ts = httpmask.NewTunnelServer(httpmask.TunnelServerOptions{
Mode: cfg.HTTPMaskMode,
PathRoot: cfg.HTTPMaskPathRoot,
AuthKey: ClientAEADSeed(cfg.Key),
})
}
}
return &HTTPMaskTunnelServer{cfg: cfg, ts: ts}
@@ -67,7 +71,7 @@ func (s *HTTPMaskTunnelServer) WrapConn(rawConn net.Conn) (handshakeConn net.Con
type TunnelDialer func(ctx context.Context, network, addr string) (net.Conn, error)
// DialHTTPMaskTunnel dials a CDN-capable HTTP tunnel (stream/poll/auto) and returns a stream carrying raw Sudoku bytes.
func DialHTTPMaskTunnel(ctx context.Context, serverAddress string, cfg *ProtocolConfig, dial TunnelDialer) (net.Conn, error) {
func DialHTTPMaskTunnel(ctx context.Context, serverAddress string, cfg *ProtocolConfig, dial TunnelDialer, upgrade func(net.Conn) (net.Conn, error)) (net.Conn, error) {
if cfg == nil {
return nil, fmt.Errorf("config is required")
}
@@ -83,14 +87,19 @@ func DialHTTPMaskTunnel(ctx context.Context, serverAddress string, cfg *Protocol
Mode: cfg.HTTPMaskMode,
TLSEnabled: cfg.HTTPMaskTLSEnabled,
HostOverride: cfg.HTTPMaskHost,
PathRoot: cfg.HTTPMaskPathRoot,
AuthKey: ClientAEADSeed(cfg.Key),
Upgrade: upgrade,
Multiplex: cfg.HTTPMaskMultiplex,
DialContext: dial,
})
}
type HTTPMaskTunnelClient struct {
mode string
client *httpmask.TunnelClient
mode string
pathRoot string
authKey string
client *httpmask.TunnelClient
}
func NewHTTPMaskTunnelClient(serverAddress string, cfg *ProtocolConfig, dial TunnelDialer) (*HTTPMaskTunnelClient, error) {
@@ -121,16 +130,23 @@ func NewHTTPMaskTunnelClient(serverAddress string, cfg *ProtocolConfig, dial Tun
}
return &HTTPMaskTunnelClient{
mode: cfg.HTTPMaskMode,
client: c,
mode: cfg.HTTPMaskMode,
pathRoot: cfg.HTTPMaskPathRoot,
authKey: ClientAEADSeed(cfg.Key),
client: c,
}, nil
}
func (c *HTTPMaskTunnelClient) Dial(ctx context.Context) (net.Conn, error) {
func (c *HTTPMaskTunnelClient) Dial(ctx context.Context, upgrade func(net.Conn) (net.Conn, error)) (net.Conn, error) {
if c == nil || c.client == nil {
return nil, fmt.Errorf("nil httpmask tunnel client")
}
return c.client.DialTunnel(ctx, c.mode)
return c.client.DialTunnel(ctx, httpmask.TunnelDialOptions{
Mode: c.mode,
PathRoot: c.pathRoot,
AuthKey: c.authKey,
Upgrade: upgrade,
})
}
func (c *HTTPMaskTunnelClient) CloseIdleConnections() {

View File

@@ -154,7 +154,7 @@ func TestHTTPMaskTunnel_Stream_TCPRoundTrip(t *testing.T) {
clientCfg.ServerAddress = addr
clientCfg.HTTPMaskHost = "example.com"
tunnelConn, err := DialHTTPMaskTunnel(ctx, clientCfg.ServerAddress, &clientCfg, (&net.Dialer{}).DialContext)
tunnelConn, err := DialHTTPMaskTunnel(ctx, clientCfg.ServerAddress, &clientCfg, (&net.Dialer{}).DialContext, nil)
if err != nil {
t.Fatalf("dial tunnel: %v", err)
}
@@ -225,7 +225,7 @@ func TestHTTPMaskTunnel_Poll_UoTRoundTrip(t *testing.T) {
clientCfg := *serverCfg
clientCfg.ServerAddress = addr
tunnelConn, err := DialHTTPMaskTunnel(ctx, clientCfg.ServerAddress, &clientCfg, (&net.Dialer{}).DialContext)
tunnelConn, err := DialHTTPMaskTunnel(ctx, clientCfg.ServerAddress, &clientCfg, (&net.Dialer{}).DialContext, nil)
if err != nil {
t.Fatalf("dial tunnel: %v", err)
}
@@ -287,7 +287,7 @@ func TestHTTPMaskTunnel_Auto_TCPRoundTrip(t *testing.T) {
clientCfg := *serverCfg
clientCfg.ServerAddress = addr
tunnelConn, err := DialHTTPMaskTunnel(ctx, clientCfg.ServerAddress, &clientCfg, (&net.Dialer{}).DialContext)
tunnelConn, err := DialHTTPMaskTunnel(ctx, clientCfg.ServerAddress, &clientCfg, (&net.Dialer{}).DialContext, nil)
if err != nil {
t.Fatalf("dial tunnel: %v", err)
}
@@ -331,13 +331,13 @@ func TestHTTPMaskTunnel_Validation(t *testing.T) {
cfg.DisableHTTPMask = true
cfg.HTTPMaskMode = "stream"
if _, err := DialHTTPMaskTunnel(context.Background(), cfg.ServerAddress, cfg, (&net.Dialer{}).DialContext); err == nil {
if _, err := DialHTTPMaskTunnel(context.Background(), cfg.ServerAddress, cfg, (&net.Dialer{}).DialContext, nil); err == nil {
t.Fatalf("expected error for disabled http mask")
}
cfg.DisableHTTPMask = false
cfg.HTTPMaskMode = "legacy"
if _, err := DialHTTPMaskTunnel(context.Background(), cfg.ServerAddress, cfg, (&net.Dialer{}).DialContext); err == nil {
if _, err := DialHTTPMaskTunnel(context.Background(), cfg.ServerAddress, cfg, (&net.Dialer{}).DialContext, nil); err == nil {
t.Fatalf("expected error for legacy mode")
}
}
@@ -385,7 +385,7 @@ func TestHTTPMaskTunnel_Soak_Concurrent(t *testing.T) {
clientCfg.ServerAddress = addr
clientCfg.HTTPMaskHost = strings.TrimSpace(clientCfg.HTTPMaskHost)
tunnelConn, err := DialHTTPMaskTunnel(ctx, clientCfg.ServerAddress, &clientCfg, (&net.Dialer{}).DialContext)
tunnelConn, err := DialHTTPMaskTunnel(ctx, clientCfg.ServerAddress, &clientCfg, (&net.Dialer{}).DialContext, nil)
if err != nil {
runErr <- fmt.Errorf("dial: %w", err)
return

View File

@@ -99,7 +99,7 @@ func TestUserHash_StableAcrossTableRotation(t *testing.T) {
if h == "" {
t.Fatalf("empty user hash")
}
if len(h) != 14 {
if len(h) != 16 {
t.Fatalf("unexpected user hash length: %d", len(h))
}
unique[h] = struct{}{}
@@ -258,4 +258,3 @@ func TestMultiplex_Boundary_InvalidVersion(t *testing.T) {
t.Fatalf("expected error")
}
}

View File

@@ -0,0 +1,137 @@
package httpmask
import (
"crypto/hmac"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/binary"
"strings"
"time"
)
const (
tunnelAuthHeaderKey = "Authorization"
tunnelAuthHeaderPrefix = "Bearer "
)
type tunnelAuth struct {
key [32]byte // derived HMAC key
skew time.Duration
}
func newTunnelAuth(key string, skew time.Duration) *tunnelAuth {
key = strings.TrimSpace(key)
if key == "" {
return nil
}
if skew <= 0 {
skew = 60 * time.Second
}
// Domain separation: keep this HMAC key independent from other uses of cfg.Key.
h := sha256.New()
_, _ = h.Write([]byte("sudoku-httpmask-auth-v1:"))
_, _ = h.Write([]byte(key))
var sum [32]byte
h.Sum(sum[:0])
return &tunnelAuth{key: sum, skew: skew}
}
func (a *tunnelAuth) token(mode TunnelMode, method, path string, now time.Time) string {
if a == nil {
return ""
}
ts := now.Unix()
sig := a.sign(mode, method, path, ts)
var buf [8 + 16]byte
binary.BigEndian.PutUint64(buf[:8], uint64(ts))
copy(buf[8:], sig[:])
return base64.RawURLEncoding.EncodeToString(buf[:])
}
func (a *tunnelAuth) verify(headers map[string]string, mode TunnelMode, method, path string, now time.Time) bool {
if a == nil {
return true
}
if headers == nil {
return false
}
val := strings.TrimSpace(headers["authorization"])
if val == "" {
return false
}
// Accept both "Bearer <token>" and raw token forms (for forward proxies / CDNs that may normalize headers).
if len(val) > len(tunnelAuthHeaderPrefix) && strings.EqualFold(val[:len(tunnelAuthHeaderPrefix)], tunnelAuthHeaderPrefix) {
val = strings.TrimSpace(val[len(tunnelAuthHeaderPrefix):])
}
if val == "" {
return false
}
raw, err := base64.RawURLEncoding.DecodeString(val)
if err != nil || len(raw) != 8+16 {
return false
}
ts := int64(binary.BigEndian.Uint64(raw[:8]))
nowTS := now.Unix()
delta := nowTS - ts
if delta < 0 {
delta = -delta
}
if delta > int64(a.skew.Seconds()) {
return false
}
want := a.sign(mode, method, path, ts)
return subtle.ConstantTimeCompare(raw[8:], want[:]) == 1
}
func (a *tunnelAuth) sign(mode TunnelMode, method, path string, ts int64) [16]byte {
method = strings.ToUpper(strings.TrimSpace(method))
if method == "" {
method = "GET"
}
path = strings.TrimSpace(path)
var tsBuf [8]byte
binary.BigEndian.PutUint64(tsBuf[:], uint64(ts))
mac := hmac.New(sha256.New, a.key[:])
_, _ = mac.Write([]byte(mode))
_, _ = mac.Write([]byte{0})
_, _ = mac.Write([]byte(method))
_, _ = mac.Write([]byte{0})
_, _ = mac.Write([]byte(path))
_, _ = mac.Write([]byte{0})
_, _ = mac.Write(tsBuf[:])
var full [32]byte
mac.Sum(full[:0])
var out [16]byte
copy(out[:], full[:16])
return out
}
type headerSetter interface {
Set(key, value string)
}
func applyTunnelAuthHeader(h headerSetter, auth *tunnelAuth, mode TunnelMode, method, path string) {
if auth == nil || h == nil {
return
}
token := auth.token(mode, method, path, time.Now())
if token == "" {
return
}
h.Set(tunnelAuthHeaderKey, tunnelAuthHeaderPrefix+token)
}

View File

@@ -129,11 +129,17 @@ func appendCommonHeaders(buf []byte, host string, r *rand.Rand) []byte {
// WriteRandomRequestHeader writes a plausible HTTP/1.1 request header as a mask.
func WriteRandomRequestHeader(w io.Writer, host string) error {
return WriteRandomRequestHeaderWithPathRoot(w, host, "")
}
// WriteRandomRequestHeaderWithPathRoot is like WriteRandomRequestHeader but prefixes all paths with pathRoot.
// pathRoot must be a single segment (e.g. "aabbcc"); invalid inputs are treated as empty (disabled).
func WriteRandomRequestHeaderWithPathRoot(w io.Writer, host string, pathRoot string) error {
// Get RNG from pool
r := rngPool.Get().(*rand.Rand)
defer rngPool.Put(r)
path := paths[r.Intn(len(paths))]
path := joinPathRoot(pathRoot, paths[r.Intn(len(paths))])
ctype := contentTypes[r.Intn(len(contentTypes))]
// Use buffer pool

View File

@@ -0,0 +1,52 @@
package httpmask
import "strings"
// normalizePathRoot normalizes the configured path root into "/<segment>" form.
//
// It is intentionally strict: only a single path segment is allowed, consisting of
// [A-Za-z0-9_-]. Invalid inputs are treated as empty (disabled).
func normalizePathRoot(root string) string {
root = strings.TrimSpace(root)
root = strings.Trim(root, "/")
if root == "" {
return ""
}
for i := 0; i < len(root); i++ {
c := root[i]
switch {
case c >= 'a' && c <= 'z':
case c >= 'A' && c <= 'Z':
case c >= '0' && c <= '9':
case c == '_' || c == '-':
default:
return ""
}
}
return "/" + root
}
func joinPathRoot(root, path string) string {
root = normalizePathRoot(root)
if root == "" {
return path
}
if path == "" {
return root
}
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
return root + path
}
func stripPathRoot(root, fullPath string) (string, bool) {
root = normalizePathRoot(root)
if root == "" {
return fullPath, true
}
if !strings.HasPrefix(fullPath, root+"/") {
return "", false
}
return strings.TrimPrefix(fullPath, root), true
}

View File

@@ -62,6 +62,15 @@ type TunnelDialOptions struct {
Mode string
TLSEnabled bool // when true, use HTTPS; otherwise, use HTTP (no port-based inference)
HostOverride string // optional Host header / SNI host (without scheme); accepts "example.com" or "example.com:443"
// PathRoot is an optional first-level path prefix for all HTTP tunnel endpoints.
// Example: "aabbcc" => "/aabbcc/session", "/aabbcc/api/v1/upload", ...
PathRoot string
// AuthKey enables short-term HMAC auth for HTTP tunnel requests (anti-probing).
// When set (non-empty), each HTTP request carries an Authorization bearer token derived from AuthKey.
AuthKey string
// Upgrade optionally wraps the raw tunnel conn and/or writes a small prelude before DialTunnel returns.
// It is called with the raw tunnel conn; if it returns a non-nil conn, that conn is returned by DialTunnel.
Upgrade func(raw net.Conn) (net.Conn, error)
// Multiplex controls whether the caller should reuse underlying HTTP connections (HTTP/1.1 keep-alive / HTTP/2).
// To reuse across multiple dials, create a TunnelClient per proxy and reuse it.
// Values: "off" disables reuse; "auto"/"on" enables it.
@@ -109,34 +118,34 @@ func (c *TunnelClient) CloseIdleConnections() {
c.transport.CloseIdleConnections()
}
func (c *TunnelClient) DialTunnel(ctx context.Context, mode string) (net.Conn, error) {
func (c *TunnelClient) DialTunnel(ctx context.Context, opts TunnelDialOptions) (net.Conn, error) {
if c == nil || c.client == nil {
return nil, fmt.Errorf("nil tunnel client")
}
tm := normalizeTunnelMode(mode)
tm := normalizeTunnelMode(opts.Mode)
if tm == TunnelModeLegacy {
return nil, fmt.Errorf("legacy mode does not use http tunnel")
}
switch tm {
case TunnelModeStream:
return dialStreamWithClient(ctx, c.client, c.target)
return dialStreamWithClient(ctx, c.client, c.target, opts)
case TunnelModePoll:
return dialPollWithClient(ctx, c.client, c.target)
return dialPollWithClient(ctx, c.client, c.target, opts)
case TunnelModeAuto:
streamCtx, cancelX := context.WithTimeout(ctx, 3*time.Second)
c1, errX := dialStreamWithClient(streamCtx, c.client, c.target)
c1, errX := dialStreamWithClient(streamCtx, c.client, c.target, opts)
cancelX()
if errX == nil {
return c1, nil
}
c2, errP := dialPollWithClient(ctx, c.client, c.target)
c2, errP := dialPollWithClient(ctx, c.client, c.target, opts)
if errP == nil {
return c2, nil
}
return nil, fmt.Errorf("auto tunnel failed: stream: %v; poll: %w", errX, errP)
default:
return dialStreamWithClient(ctx, c.client, c.target)
return dialStreamWithClient(ctx, c.client, c.target, opts)
}
}
@@ -248,8 +257,13 @@ func (c *httpStreamConn) Close() error {
if c.cancel != nil {
c.cancel()
}
_ = c.writer.CloseWithError(io.ErrClosedPipe)
return c.reader.Close()
if c.writer != nil {
_ = c.writer.CloseWithError(io.ErrClosedPipe)
}
if c.reader != nil {
return c.reader.Close()
}
return nil
}
func (c *httpStreamConn) LocalAddr() net.Addr { return c.localAddr }
@@ -320,20 +334,23 @@ type sessionDialInfo struct {
pullURL string
closeURL string
headerHost string
auth *tunnelAuth
}
func dialSessionWithClient(ctx context.Context, client *http.Client, target httpClientTarget, mode TunnelMode) (*sessionDialInfo, error) {
func dialSessionWithClient(ctx context.Context, client *http.Client, target httpClientTarget, mode TunnelMode, opts TunnelDialOptions) (*sessionDialInfo, error) {
if client == nil {
return nil, fmt.Errorf("nil http client")
}
authorizeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/session"}).String()
auth := newTunnelAuth(opts.AuthKey, 0)
authorizeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/session")}).String()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, authorizeURL, nil)
if err != nil {
return nil, err
}
req.Host = target.headerHost
applyTunnelHeaders(req.Header, target.headerHost, mode)
applyTunnelAuthHeader(req.Header, auth, mode, http.MethodGet, "/session")
resp, err := client.Do(req)
if err != nil {
@@ -356,9 +373,9 @@ func dialSessionWithClient(ctx context.Context, client *http.Client, target http
return nil, fmt.Errorf("%s authorize empty token", mode)
}
pushURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/api/v1/upload", RawQuery: "token=" + url.QueryEscape(token)}).String()
pullURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/stream", RawQuery: "token=" + url.QueryEscape(token)}).String()
closeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/api/v1/upload", RawQuery: "token=" + url.QueryEscape(token) + "&close=1"}).String()
pushURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/api/v1/upload"), RawQuery: "token=" + url.QueryEscape(token)}).String()
pullURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/stream"), RawQuery: "token=" + url.QueryEscape(token)}).String()
closeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/api/v1/upload"), RawQuery: "token=" + url.QueryEscape(token) + "&close=1"}).String()
return &sessionDialInfo{
client: client,
@@ -366,6 +383,7 @@ func dialSessionWithClient(ctx context.Context, client *http.Client, target http
pullURL: pullURL,
closeURL: closeURL,
headerHost: target.headerHost,
auth: auth,
}, nil
}
@@ -374,10 +392,10 @@ func dialSession(ctx context.Context, serverAddress string, opts TunnelDialOptio
if err != nil {
return nil, err
}
return dialSessionWithClient(ctx, client, target, mode)
return dialSessionWithClient(ctx, client, target, mode, opts)
}
func bestEffortCloseSession(client *http.Client, closeURL, headerHost string, mode TunnelMode) {
func bestEffortCloseSession(client *http.Client, closeURL, headerHost string, mode TunnelMode, auth *tunnelAuth) {
if client == nil || closeURL == "" || headerHost == "" {
return
}
@@ -391,6 +409,7 @@ func bestEffortCloseSession(client *http.Client, closeURL, headerHost string, mo
}
req.Host = headerHost
applyTunnelHeaders(req.Header, headerHost, mode)
applyTunnelAuthHeader(req.Header, auth, mode, http.MethodPost, "/api/v1/upload")
resp, err := client.Do(req)
if err != nil || resp == nil {
@@ -400,13 +419,13 @@ func bestEffortCloseSession(client *http.Client, closeURL, headerHost string, mo
_ = resp.Body.Close()
}
func dialStreamWithClient(ctx context.Context, client *http.Client, target httpClientTarget) (net.Conn, error) {
// Prefer split session (Cloudflare-friendly). Fall back to stream-one for older servers / environments.
c, errSplit := dialStreamSplitWithClient(ctx, client, target)
func dialStreamWithClient(ctx context.Context, client *http.Client, target httpClientTarget, opts TunnelDialOptions) (net.Conn, error) {
// Prefer split-session (Cloudflare-friendly). Fall back to stream-one for older servers / environments.
c, errSplit := dialStreamSplitWithClient(ctx, client, target, opts)
if errSplit == nil {
return c, nil
}
c2, errOne := dialStreamOneWithClient(ctx, client, target)
c2, errOne := dialStreamOneWithClient(ctx, client, target, opts)
if errOne == nil {
return c2, nil
}
@@ -414,7 +433,7 @@ func dialStreamWithClient(ctx context.Context, client *http.Client, target httpC
}
func dialStream(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
// Prefer split session (Cloudflare-friendly). Fall back to stream-one for older servers / environments.
// Prefer split-session (Cloudflare-friendly). Fall back to stream-one for older servers / environments.
c, errSplit := dialStreamSplit(ctx, serverAddress, opts)
if errSplit == nil {
return c, nil
@@ -426,13 +445,15 @@ func dialStream(ctx context.Context, serverAddress string, opts TunnelDialOption
return nil, fmt.Errorf("dial stream failed: split: %v; stream-one: %w", errSplit, errOne)
}
func dialStreamOneWithClient(ctx context.Context, client *http.Client, target httpClientTarget) (net.Conn, error) {
func dialStreamOneWithClient(ctx context.Context, client *http.Client, target httpClientTarget, opts TunnelDialOptions) (net.Conn, error) {
if client == nil {
return nil, fmt.Errorf("nil http client")
}
auth := newTunnelAuth(opts.AuthKey, 0)
r := rngPool.Get().(*mrand.Rand)
path := paths[r.Intn(len(paths))]
basePath := paths[r.Intn(len(paths))]
path := joinPathRoot(opts.PathRoot, basePath)
ctype := contentTypes[r.Intn(len(contentTypes))]
rngPool.Put(r)
@@ -454,6 +475,7 @@ func dialStreamOneWithClient(ctx context.Context, client *http.Client, target ht
req.Host = target.headerHost
applyTunnelHeaders(req.Header, target.headerHost, TunnelModeStream)
applyTunnelAuthHeader(req.Header, auth, TunnelModeStream, http.MethodPost, basePath)
req.Header.Set("Content-Type", ctype)
type doResult struct {
@@ -466,33 +488,84 @@ func dialStreamOneWithClient(ctx context.Context, client *http.Client, target ht
doCh <- doResult{resp: resp, err: doErr}
}()
select {
case <-ctx.Done():
connCancel()
_ = reqBodyW.Close()
return nil, ctx.Err()
case r := <-doCh:
if r.err != nil {
connCancel()
_ = reqBodyW.Close()
return nil, r.err
}
if r.resp.StatusCode != http.StatusOK {
defer r.resp.Body.Close()
body, _ := io.ReadAll(io.LimitReader(r.resp.Body, 4*1024))
connCancel()
_ = reqBodyW.Close()
return nil, fmt.Errorf("stream bad status: %s (%s)", r.resp.Status, strings.TrimSpace(string(body)))
}
return &httpStreamConn{
reader: r.resp.Body,
writer: reqBodyW,
cancel: connCancel,
localAddr: &net.TCPAddr{},
remoteAddr: &net.TCPAddr{},
}, nil
streamConn := &httpStreamConn{
writer: reqBodyW,
cancel: connCancel,
localAddr: &net.TCPAddr{},
remoteAddr: &net.TCPAddr{},
}
type upgradeResult struct {
conn net.Conn
err error
}
upgradeCh := make(chan upgradeResult, 1)
if opts.Upgrade == nil {
upgradeCh <- upgradeResult{conn: streamConn, err: nil}
} else {
go func() {
upgradeConn, err := opts.Upgrade(streamConn)
if err != nil {
upgradeCh <- upgradeResult{conn: nil, err: err}
return
}
if upgradeConn == nil {
upgradeConn = streamConn
}
upgradeCh <- upgradeResult{conn: upgradeConn, err: nil}
}()
}
var (
outConn net.Conn
upgradeDone bool
responseReady bool
)
for !(upgradeDone && responseReady) {
select {
case <-ctx.Done():
_ = streamConn.Close()
if outConn != nil && outConn != streamConn {
_ = outConn.Close()
}
return nil, ctx.Err()
case u := <-upgradeCh:
if u.err != nil {
_ = streamConn.Close()
return nil, u.err
}
outConn = u.conn
if outConn == nil {
outConn = streamConn
}
upgradeDone = true
case r := <-doCh:
if r.err != nil {
_ = streamConn.Close()
if outConn != nil && outConn != streamConn {
_ = outConn.Close()
}
return nil, r.err
}
if r.resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(r.resp.Body, 4*1024))
_ = r.resp.Body.Close()
_ = streamConn.Close()
if outConn != nil && outConn != streamConn {
_ = outConn.Close()
}
return nil, fmt.Errorf("stream bad status: %s (%s)", r.resp.Status, strings.TrimSpace(string(body)))
}
streamConn.reader = r.resp.Body
responseReady = true
}
}
return outConn, nil
}
func dialStreamOne(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
@@ -500,7 +573,7 @@ func dialStreamOne(ctx context.Context, serverAddress string, opts TunnelDialOpt
if err != nil {
return nil, err
}
return dialStreamOneWithClient(ctx, client, target)
return dialStreamOneWithClient(ctx, client, target, opts)
}
type queuedConn struct {
@@ -599,6 +672,7 @@ type streamSplitConn struct {
pullURL string
closeURL string
headerHost string
auth *tunnelAuth
}
func (c *streamSplitConn) Close() error {
@@ -607,7 +681,7 @@ func (c *streamSplitConn) Close() error {
if c.cancel != nil {
c.cancel()
}
bestEffortCloseSession(c.client, c.closeURL, c.headerHost, TunnelModeStream)
bestEffortCloseSession(c.client, c.closeURL, c.headerHost, TunnelModeStream, c.auth)
return nil
}
@@ -625,6 +699,7 @@ func newStreamSplitConnFromInfo(info *sessionDialInfo) *streamSplitConn {
pullURL: info.pullURL,
closeURL: info.closeURL,
headerHost: info.headerHost,
auth: info.auth,
queuedConn: queuedConn{
rxc: make(chan []byte, 256),
closed: make(chan struct{}),
@@ -639,8 +714,8 @@ func newStreamSplitConnFromInfo(info *sessionDialInfo) *streamSplitConn {
return c
}
func dialStreamSplitWithClient(ctx context.Context, client *http.Client, target httpClientTarget) (net.Conn, error) {
info, err := dialSessionWithClient(ctx, client, target, TunnelModeStream)
func dialStreamSplitWithClient(ctx context.Context, client *http.Client, target httpClientTarget, opts TunnelDialOptions) (net.Conn, error) {
info, err := dialSessionWithClient(ctx, client, target, TunnelModeStream, opts)
if err != nil {
return nil, err
}
@@ -648,7 +723,18 @@ func dialStreamSplitWithClient(ctx context.Context, client *http.Client, target
if c == nil {
return nil, fmt.Errorf("failed to build stream split conn")
}
return c, nil
outConn := net.Conn(c)
if opts.Upgrade != nil {
upgraded, err := opts.Upgrade(c)
if err != nil {
_ = c.Close()
return nil, err
}
if upgraded != nil {
outConn = upgraded
}
}
return outConn, nil
}
func dialStreamSplit(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
@@ -660,7 +746,18 @@ func dialStreamSplit(ctx context.Context, serverAddress string, opts TunnelDialO
if c == nil {
return nil, fmt.Errorf("failed to build stream split conn")
}
return c, nil
outConn := net.Conn(c)
if opts.Upgrade != nil {
upgraded, err := opts.Upgrade(c)
if err != nil {
_ = c.Close()
return nil, err
}
if upgraded != nil {
outConn = upgraded
}
}
return outConn, nil
}
func (c *streamSplitConn) pullLoop() {
@@ -696,6 +793,7 @@ func (c *streamSplitConn) pullLoop() {
}
req.Host = c.headerHost
applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream)
applyTunnelAuthHeader(req.Header, c.auth, TunnelModeStream, http.MethodGet, "/stream")
resp, err := c.client.Do(req)
if err != nil {
@@ -793,6 +891,7 @@ func (c *streamSplitConn) pushLoop() {
}
req.Host = c.headerHost
applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream)
applyTunnelAuthHeader(req.Header, c.auth, TunnelModeStream, http.MethodPost, "/api/v1/upload")
req.Header.Set("Content-Type", "application/octet-stream")
resp, err := c.client.Do(req)
@@ -896,6 +995,7 @@ type pollConn struct {
pullURL string
closeURL string
headerHost string
auth *tunnelAuth
}
func isDialError(err error) bool {
@@ -917,7 +1017,7 @@ func (c *pollConn) closeWithError(err error) error {
if c.cancel != nil {
c.cancel()
}
bestEffortCloseSession(c.client, c.closeURL, c.headerHost, TunnelModePoll)
bestEffortCloseSession(c.client, c.closeURL, c.headerHost, TunnelModePoll, c.auth)
return nil
}
@@ -939,6 +1039,7 @@ func newPollConnFromInfo(info *sessionDialInfo) *pollConn {
pullURL: info.pullURL,
closeURL: info.closeURL,
headerHost: info.headerHost,
auth: info.auth,
queuedConn: queuedConn{
rxc: make(chan []byte, 128),
closed: make(chan struct{}),
@@ -953,8 +1054,8 @@ func newPollConnFromInfo(info *sessionDialInfo) *pollConn {
return c
}
func dialPollWithClient(ctx context.Context, client *http.Client, target httpClientTarget) (net.Conn, error) {
info, err := dialSessionWithClient(ctx, client, target, TunnelModePoll)
func dialPollWithClient(ctx context.Context, client *http.Client, target httpClientTarget, opts TunnelDialOptions) (net.Conn, error) {
info, err := dialSessionWithClient(ctx, client, target, TunnelModePoll, opts)
if err != nil {
return nil, err
}
@@ -962,7 +1063,18 @@ func dialPollWithClient(ctx context.Context, client *http.Client, target httpCli
if c == nil {
return nil, fmt.Errorf("failed to build poll conn")
}
return c, nil
outConn := net.Conn(c)
if opts.Upgrade != nil {
upgraded, err := opts.Upgrade(c)
if err != nil {
_ = c.Close()
return nil, err
}
if upgraded != nil {
outConn = upgraded
}
}
return outConn, nil
}
func dialPoll(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
@@ -974,7 +1086,18 @@ func dialPoll(ctx context.Context, serverAddress string, opts TunnelDialOptions)
if c == nil {
return nil, fmt.Errorf("failed to build poll conn")
}
return c, nil
outConn := net.Conn(c)
if opts.Upgrade != nil {
upgraded, err := opts.Upgrade(c)
if err != nil {
_ = c.Close()
return nil, err
}
if upgraded != nil {
outConn = upgraded
}
}
return outConn, nil
}
func (c *pollConn) pullLoop() {
@@ -1001,6 +1124,7 @@ func (c *pollConn) pullLoop() {
}
req.Host = c.headerHost
applyTunnelHeaders(req.Header, c.headerHost, TunnelModePoll)
applyTunnelAuthHeader(req.Header, c.auth, TunnelModePoll, http.MethodGet, "/stream")
resp, err := c.client.Do(req)
if err != nil {
@@ -1084,6 +1208,7 @@ func (c *pollConn) pushLoop() {
}
req.Host = c.headerHost
applyTunnelHeaders(req.Header, c.headerHost, TunnelModePoll)
applyTunnelAuthHeader(req.Header, c.auth, TunnelModePoll, http.MethodPost, "/api/v1/upload")
req.Header.Set("Content-Type", "text/plain")
resp, err := c.client.Do(req)
@@ -1246,6 +1371,18 @@ func applyTunnelHeaders(h http.Header, host string, mode TunnelMode) {
type TunnelServerOptions struct {
Mode string
// PathRoot is an optional first-level path prefix for all HTTP tunnel endpoints.
// Example: "aabbcc" => "/aabbcc/session", "/aabbcc/api/v1/upload", ...
PathRoot string
// AuthKey enables short-term HMAC auth for HTTP tunnel requests (anti-probing).
// When set (non-empty), the server requires each request to carry a valid Authorization bearer token.
AuthKey string
// AuthSkew controls allowed clock skew / replay window for AuthKey. 0 uses a conservative default.
AuthSkew time.Duration
// PassThroughOnReject controls how the server handles "recognized but rejected" tunnel requests
// (e.g., wrong mode / wrong path / invalid token). When true, the request bytes are replayed back
// to the caller as HandlePassThrough to allow higher-level fallback handling.
PassThroughOnReject bool
// PullReadTimeout controls how long the server long-poll waits for tunnel downlink data before replying with a keepalive newline.
PullReadTimeout time.Duration
// SessionTTL is a best-effort TTL to prevent leaked sessions. 0 uses a conservative default.
@@ -1253,7 +1390,10 @@ type TunnelServerOptions struct {
}
type TunnelServer struct {
mode TunnelMode
mode TunnelMode
pathRoot string
passThroughOnReject bool
auth *tunnelAuth
pullReadTimeout time.Duration
sessionTTL time.Duration
@@ -1272,6 +1412,8 @@ func NewTunnelServer(opts TunnelServerOptions) *TunnelServer {
if mode == TunnelModeLegacy {
// Server-side "legacy" means: don't accept stream/poll tunnels; only passthrough.
}
pathRoot := normalizePathRoot(opts.PathRoot)
auth := newTunnelAuth(opts.AuthKey, opts.AuthSkew)
timeout := opts.PullReadTimeout
if timeout <= 0 {
timeout = 10 * time.Second
@@ -1281,10 +1423,13 @@ func NewTunnelServer(opts TunnelServerOptions) *TunnelServer {
ttl = 2 * time.Minute
}
return &TunnelServer{
mode: mode,
pullReadTimeout: timeout,
sessionTTL: ttl,
sessions: make(map[string]*tunnelSession),
mode: mode,
pathRoot: pathRoot,
auth: auth,
passThroughOnReject: opts.PassThroughOnReject,
pullReadTimeout: timeout,
sessionTTL: ttl,
sessions: make(map[string]*tunnelSession),
}
}
@@ -1340,6 +1485,12 @@ func (s *TunnelServer) HandleConn(rawConn net.Conn) (HandleResult, net.Conn, err
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
}
if s.mode == TunnelModeLegacy {
if s.passThroughOnReject {
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
prefix = append(prefix, headerBytes...)
prefix = append(prefix, buffered...)
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
}
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
_ = rawConn.Close()
return HandleDone, nil, nil
@@ -1348,19 +1499,37 @@ func (s *TunnelServer) HandleConn(rawConn net.Conn) (HandleResult, net.Conn, err
switch TunnelMode(tunnelHeader) {
case TunnelModeStream:
if s.mode != TunnelModeStream && s.mode != TunnelModeAuto {
if s.passThroughOnReject {
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
prefix = append(prefix, headerBytes...)
prefix = append(prefix, buffered...)
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
}
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
_ = rawConn.Close()
return HandleDone, nil, nil
}
return s.handleStream(rawConn, req, buffered)
return s.handleStream(rawConn, req, headerBytes, buffered)
case TunnelModePoll:
if s.mode != TunnelModePoll && s.mode != TunnelModeAuto {
if s.passThroughOnReject {
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
prefix = append(prefix, headerBytes...)
prefix = append(prefix, buffered...)
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
}
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
_ = rawConn.Close()
return HandleDone, nil, nil
}
return s.handlePoll(rawConn, req, buffered)
return s.handlePoll(rawConn, req, headerBytes, buffered)
default:
if s.passThroughOnReject {
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
prefix = append(prefix, headerBytes...)
prefix = append(prefix, buffered...)
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
}
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
_ = rawConn.Close()
return HandleDone, nil, nil
@@ -1507,19 +1676,31 @@ func (c *bodyConn) Close() error {
return firstErr
}
func (s *TunnelServer) handleStream(rawConn net.Conn, req *httpRequestHeader, buffered []byte) (HandleResult, net.Conn, error) {
u, err := url.ParseRequestURI(req.target)
if err != nil {
_ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request")
func (s *TunnelServer) handleStream(rawConn net.Conn, req *httpRequestHeader, headerBytes []byte, buffered []byte) (HandleResult, net.Conn, error) {
rejectOrReply := func(code int, body string) (HandleResult, net.Conn, error) {
if s.passThroughOnReject {
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
prefix = append(prefix, headerBytes...)
prefix = append(prefix, buffered...)
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
}
_ = writeSimpleHTTPResponse(rawConn, code, body)
_ = rawConn.Close()
return HandleDone, nil, nil
}
u, err := url.ParseRequestURI(req.target)
if err != nil {
return rejectOrReply(http.StatusBadRequest, "bad request")
}
// Only accept plausible paths to reduce accidental exposure.
if !isAllowedPath(req.target) {
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
_ = rawConn.Close()
return HandleDone, nil, nil
path, ok := stripPathRoot(s.pathRoot, u.Path)
if !ok || !s.isAllowedBasePath(path) {
return rejectOrReply(http.StatusNotFound, "not found")
}
if !s.auth.verify(req.headers, TunnelModeStream, req.method, path, time.Now()) {
return rejectOrReply(http.StatusNotFound, "not found")
}
token := u.Query().Get("token")
@@ -1528,31 +1709,25 @@ func (s *TunnelServer) handleStream(rawConn net.Conn, req *httpRequestHeader, bu
switch strings.ToUpper(req.method) {
case http.MethodGet:
// Stream split-session: GET /session (no token) => token + start tunnel on a server-side pipe.
if token == "" && u.Path == "/session" {
if token == "" && path == "/session" {
return s.authorizeSession(rawConn)
}
// Stream split-session: GET /stream?token=... => downlink poll.
if token != "" && u.Path == "/stream" {
if token != "" && path == "/stream" {
return s.streamPull(rawConn, token)
}
_ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request")
_ = rawConn.Close()
return HandleDone, nil, nil
return rejectOrReply(http.StatusBadRequest, "bad request")
case http.MethodPost:
// Stream split-session: POST /api/v1/upload?token=... => uplink push.
if token != "" && u.Path == "/api/v1/upload" {
if token != "" && path == "/api/v1/upload" {
if closeFlag {
s.closeSession(token)
_ = writeSimpleHTTPResponse(rawConn, http.StatusOK, "")
_ = rawConn.Close()
return HandleDone, nil, nil
return rejectOrReply(http.StatusOK, "")
}
bodyReader, err := newRequestBodyReader(newPreBufferedConn(rawConn, buffered), req.headers)
if err != nil {
_ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request")
_ = rawConn.Close()
return HandleDone, nil, nil
return rejectOrReply(http.StatusBadRequest, "bad request")
}
return s.streamPush(rawConn, token, bodyReader)
}
@@ -1581,19 +1756,13 @@ func (s *TunnelServer) handleStream(rawConn net.Conn, req *httpRequestHeader, bu
return HandleStartTunnel, stream, nil
default:
_ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request")
_ = rawConn.Close()
return HandleDone, nil, nil
return rejectOrReply(http.StatusBadRequest, "bad request")
}
}
func isAllowedPath(target string) bool {
u, err := url.ParseRequestURI(target)
if err != nil {
return false
}
func (s *TunnelServer) isAllowedBasePath(path string) bool {
for _, p := range paths {
if u.Path == p {
if path == p {
return true
}
}
@@ -1650,51 +1819,58 @@ func writeTokenHTTPResponse(w io.Writer, token string) error {
return err
}
func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, buffered []byte) (HandleResult, net.Conn, error) {
u, err := url.ParseRequestURI(req.target)
if err != nil {
_ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request")
func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, headerBytes []byte, buffered []byte) (HandleResult, net.Conn, error) {
rejectOrReply := func(code int, body string) (HandleResult, net.Conn, error) {
if s.passThroughOnReject {
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
prefix = append(prefix, headerBytes...)
prefix = append(prefix, buffered...)
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
}
_ = writeSimpleHTTPResponse(rawConn, code, body)
_ = rawConn.Close()
return HandleDone, nil, nil
}
if !isAllowedPath(req.target) {
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
_ = rawConn.Close()
return HandleDone, nil, nil
u, err := url.ParseRequestURI(req.target)
if err != nil {
return rejectOrReply(http.StatusBadRequest, "bad request")
}
path, ok := stripPathRoot(s.pathRoot, u.Path)
if !ok || !s.isAllowedBasePath(path) {
return rejectOrReply(http.StatusNotFound, "not found")
}
if !s.auth.verify(req.headers, TunnelModePoll, req.method, path, time.Now()) {
return rejectOrReply(http.StatusNotFound, "not found")
}
token := u.Query().Get("token")
closeFlag := u.Query().Get("close") == "1"
switch strings.ToUpper(req.method) {
case http.MethodGet:
if token == "" {
if token == "" && path == "/session" {
return s.authorizeSession(rawConn)
}
return s.pollPull(rawConn, token)
if token != "" && path == "/stream" {
return s.pollPull(rawConn, token)
}
return rejectOrReply(http.StatusBadRequest, "bad request")
case http.MethodPost:
if token == "" {
_ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "missing token")
_ = rawConn.Close()
return HandleDone, nil, nil
if token == "" || path != "/api/v1/upload" {
return rejectOrReply(http.StatusBadRequest, "bad request")
}
if closeFlag {
s.closeSession(token)
_ = writeSimpleHTTPResponse(rawConn, http.StatusOK, "")
_ = rawConn.Close()
return HandleDone, nil, nil
return rejectOrReply(http.StatusOK, "")
}
bodyReader, err := newRequestBodyReader(newPreBufferedConn(rawConn, buffered), req.headers)
if err != nil {
_ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request")
_ = rawConn.Close()
return HandleDone, nil, nil
return rejectOrReply(http.StatusBadRequest, "bad request")
}
return s.pollPush(rawConn, token, bodyReader)
default:
_ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request")
_ = rawConn.Close()
return HandleDone, nil, nil
return rejectOrReply(http.StatusBadRequest, "bad request")
}
}

View File

@@ -20,7 +20,11 @@ type byteLayout struct {
}
func (l *byteLayout) isHint(b byte) bool {
return (b & l.hintMask) == l.hintValue
if (b & l.hintMask) == l.hintValue {
return true
}
// ASCII layout maps the single non-printable marker (0x7F) to '\n' on the wire.
return l.name == "ascii" && b == '\n'
}
// resolveLayout picks the byte layout based on ASCII preference and optional custom pattern.
@@ -53,12 +57,25 @@ func newASCIILayout() *byteLayout {
padMarker: 0x3F,
paddingPool: padding,
encodeHint: func(val, pos byte) byte {
return 0x40 | ((val & 0x03) << 4) | (pos & 0x0F)
b := 0x40 | ((val & 0x03) << 4) | (pos & 0x0F)
// Avoid DEL (0x7F) in prefer_ascii mode; map it to '\n' to reduce fingerprint.
if b == 0x7F {
return '\n'
}
return b
},
encodeGroup: func(group byte) byte {
return 0x40 | (group & 0x3F)
b := 0x40 | (group & 0x3F)
// Avoid DEL (0x7F) in prefer_ascii mode; map it to '\n' to reduce fingerprint.
if b == 0x7F {
return '\n'
}
return b
},
decodeGroup: func(b byte) (byte, bool) {
if b == '\n' {
return 0x3F, true
}
if (b & 0x40) == 0 {
return 0, false
}

View File

@@ -3,6 +3,7 @@ package sudoku
import (
"bufio"
"bytes"
crand "crypto/rand"
"encoding/binary"
"errors"
"fmt"
@@ -14,16 +15,20 @@ import (
"github.com/metacubex/mihomo/transport/sudoku/obfs/sudoku"
)
func pickClientTable(cfg *ProtocolConfig) (*sudoku.Table, byte, error) {
func pickClientTable(cfg *ProtocolConfig) (*sudoku.Table, error) {
candidates := cfg.tableCandidates()
if len(candidates) == 0 {
return nil, 0, fmt.Errorf("no table configured")
return nil, fmt.Errorf("no table configured")
}
if len(candidates) == 1 {
return candidates[0], 0, nil
return candidates[0], nil
}
idx := int(randomByte()) % len(candidates)
return candidates[idx], byte(idx), nil
var b [1]byte
if _, err := crand.Read(b[:]); err != nil {
return nil, fmt.Errorf("random table pick failed: %w", err)
}
idx := int(b[0]) % len(candidates)
return candidates[idx], nil
}
type readOnlyConn struct {