Compare commits

...

5 Commits
Meta ... Alpha

Author SHA1 Message Date
wwqgtxx
0317d9f742 action: adjust setup go condition 2026-03-11 15:54:02 +08:00
wwqgtxx
61c13586e9 feat: add ping-interval to grpc-opts 2026-03-11 02:21:39 +08:00
saba-futai
b3c81602a2 chore: reduce the inherent 1rtt in httpmask mode for sudoku (#2610) 2026-03-11 00:00:32 +08:00
wwqgtxx
6517d2a9b2 chore: align with legacy behavior 2026-03-10 16:53:40 +08:00
wwqgtxx
e28fe24fee fix: incorrect use of hyphen 2026-03-10 01:45:00 +08:00
24 changed files with 1259 additions and 277 deletions

View File

@@ -53,7 +53,7 @@ jobs:
- { goos: linux, goarch: mipsle, gomips: softfloat, output: mipsle-softfloat }
- { goos: linux, goarch: mips64, output: mips64 }
- { goos: linux, goarch: mips64le, output: mips64le, debian: mips64el, rpm: mips64el }
- { goos: linux, goarch: loong64, output: loong64-abi1, abi: '1', debian: loongarch64, rpm: loongarch64 }
- { goos: linux, goarch: loong64, output: loong64-abi1, abi: '1', debian: loongarch64, rpm: loongarch64, goversion: 'custom' }
- { goos: linux, goarch: loong64, output: loong64-abi2, abi: '2', debian: loong64, rpm: loong64 }
- { goos: linux, goarch: riscv64, output: riscv64, debian: riscv64, rpm: riscv64 }
- { goos: linux, goarch: s390x, output: s390x, debian: s390x, rpm: s390x }
@@ -158,14 +158,14 @@ jobs:
- uses: actions/checkout@v5
- name: Set up Go
if: ${{ matrix.jobs.goversion == '' && matrix.jobs.abi != '1' }}
if: ${{ matrix.jobs.goversion == '' }}
uses: actions/setup-go@v6
with:
go-version: '1.26'
check-latest: true # Always check for the latest patch release
- name: Set up Go
if: ${{ matrix.jobs.goversion != '' && matrix.jobs.abi != '1' }}
if: ${{ matrix.jobs.goversion != '' && matrix.jobs.goversion != 'custom' }}
uses: actions/setup-go@v6
with:
go-version: ${{ matrix.jobs.goversion }}

View File

@@ -54,7 +54,7 @@ type SudokuHTTPMaskOptions struct {
Mode string `proxy:"mode,omitempty"`
TLS bool `proxy:"tls,omitempty"`
Host string `proxy:"host,omitempty"`
PathRoot string `proxy:"path_root,omitempty"`
PathRoot string `proxy:"path-root,omitempty"`
Multiplex string `proxy:"multiplex,omitempty"`
}

View File

@@ -27,8 +27,7 @@ type Trojan struct {
hexPassword [trojan.KeyLength]byte
// for gun mux
gunConfig *gun.Config
gunTransport *gun.TransportWrap
gunTransport *gun.Transport
realityConfig *tlsC.RealityConfig
echConfig *ech.Config
@@ -178,7 +177,7 @@ func (t *Trojan) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Con
var c net.Conn
// gun transport
if t.gunTransport != nil {
c, err = gun.StreamGunWithTransport(t.gunTransport, t.gunConfig)
c, err = t.gunTransport.Dial()
} else {
c, err = t.dialer.DialContext(ctx, "tcp", t.addr)
}
@@ -206,7 +205,7 @@ func (t *Trojan) ListenPacketContext(ctx context.Context, metadata *C.Metadata)
var c net.Conn
// grpc transport
if t.gunTransport != nil {
c, err = gun.StreamGunWithTransport(t.gunTransport, t.gunConfig)
c, err = t.gunTransport.Dial()
} else {
c, err = t.dialer.DialContext(ctx, "tcp", t.addr)
}
@@ -317,13 +316,14 @@ func NewTrojan(option TrojanOption) (*Trojan, error) {
Reality: t.realityConfig,
}
t.gunTransport = gun.NewHTTP2Client(dialFn, tlsConfig)
t.gunConfig = &gun.Config{
gunConfig := &gun.Config{
ServiceName: option.GrpcOpts.GrpcServiceName,
UserAgent: option.GrpcOpts.GrpcUserAgent,
Host: option.SNI,
PingInterval: option.GrpcOpts.PingInterval,
}
t.gunTransport = gun.NewTransport(dialFn, tlsConfig, gunConfig)
}
return t, nil

View File

@@ -33,8 +33,7 @@ type Vless struct {
encryption *encryption.ClientInstance
// for gun mux
gunConfig *gun.Config
gunTransport *gun.TransportWrap
gunTransport *gun.Transport
realityConfig *tlsC.RealityConfig
echConfig *ech.Config
@@ -234,7 +233,7 @@ func (v *Vless) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn
var c net.Conn
// gun transport
if v.gunTransport != nil {
c, err = gun.StreamGunWithTransport(v.gunTransport, v.gunConfig)
c, err = v.gunTransport.Dial()
} else {
c, err = v.dialer.DialContext(ctx, "tcp", v.addr)
}
@@ -260,7 +259,7 @@ func (v *Vless) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (
var c net.Conn
// gun transport
if v.gunTransport != nil {
c, err = gun.StreamGunWithTransport(v.gunTransport, v.gunConfig)
c, err = v.gunTransport.Dial()
} else {
c, err = v.dialer.DialContext(ctx, "tcp", v.addr)
}
@@ -434,6 +433,7 @@ func NewVless(option VlessOption) (*Vless, error) {
ServiceName: option.GrpcOpts.GrpcServiceName,
UserAgent: option.GrpcOpts.GrpcUserAgent,
Host: option.ServerName,
PingInterval: option.GrpcOpts.PingInterval,
}
if option.ServerName == "" {
gunConfig.Host = v.addr
@@ -457,9 +457,7 @@ func NewVless(option VlessOption) (*Vless, error) {
}
}
v.gunConfig = gunConfig
v.gunTransport = gun.NewHTTP2Client(dialFn, tlsConfig)
v.gunTransport = gun.NewTransport(dialFn, tlsConfig, gunConfig)
}
return v, nil

View File

@@ -34,8 +34,7 @@ type Vmess struct {
option *VmessOption
// for gun mux
gunConfig *gun.Config
gunTransport *gun.TransportWrap
gunTransport *gun.Transport
realityConfig *tlsC.RealityConfig
echConfig *ech.Config
@@ -86,6 +85,7 @@ type HTTP2Options struct {
type GrpcOptions struct {
GrpcServiceName string `proxy:"grpc-service-name,omitempty"`
GrpcUserAgent string `proxy:"grpc-user-agent,omitempty"`
PingInterval int `proxy:"ping-interval,omitempty"`
}
type WSOptions struct {
@@ -295,7 +295,7 @@ func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn
var c net.Conn
// gun transport
if v.gunTransport != nil {
c, err = gun.StreamGunWithTransport(v.gunTransport, v.gunConfig)
c, err = v.gunTransport.Dial()
} else {
c, err = v.dialer.DialContext(ctx, "tcp", v.addr)
}
@@ -318,7 +318,7 @@ func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (
var c net.Conn
// gun transport
if v.gunTransport != nil {
c, err = gun.StreamGunWithTransport(v.gunTransport, v.gunConfig)
c, err = v.gunTransport.Dial()
} else {
c, err = v.dialer.DialContext(ctx, "tcp", v.addr)
}
@@ -440,6 +440,7 @@ func NewVmess(option VmessOption) (*Vmess, error) {
ServiceName: option.GrpcOpts.GrpcServiceName,
UserAgent: option.GrpcOpts.GrpcUserAgent,
Host: option.ServerName,
PingInterval: option.GrpcOpts.PingInterval,
}
if option.ServerName == "" {
gunConfig.Host = v.addr
@@ -463,9 +464,7 @@ func NewVmess(option VmessOption) (*Vmess, error) {
}
}
v.gunConfig = gunConfig
v.gunTransport = gun.NewHTTP2Client(dialFn, tlsConfig)
v.gunTransport = gun.NewTransport(dialFn, tlsConfig, gunConfig)
}
return v, nil

View File

@@ -18,8 +18,8 @@ var (
type healthCheckSchema struct {
Enable bool `provider:"enable"`
URL string `provider:"url"`
Interval int `provider:"interval"`
URL string `provider:"url,omitempty"`
Interval int `provider:"interval,omitempty"`
TestTimeout int `provider:"timeout,omitempty"`
Lazy bool `provider:"lazy,omitempty"`
ExpectedStatus string `provider:"expected-status,omitempty"`

View File

@@ -669,6 +669,7 @@ proxies: # socks5
grpc-opts:
grpc-service-name: "example"
# grpc-user-agent: "grpc-go/1.36.0"
# ping-interval: 0 # 默认关闭,单位为秒
# ip-version: ipv4
# vless
@@ -759,6 +760,7 @@ proxies: # socks5
grpc-opts:
grpc-service-name: "grpc"
# grpc-user-agent: "grpc-go/1.36.0"
# ping-interval: 0 # 默认关闭,单位为秒
reality-opts:
public-key: CrrQSjAG_YkHLwvM2M-7XkKJilgL5upBKCp0od0tLhE
@@ -830,6 +832,7 @@ proxies: # socks5
grpc-opts:
grpc-service-name: "example"
# grpc-user-agent: "grpc-go/1.36.0"
# ping-interval: 0 # 默认关闭,单位为秒
- name: trojan-ws
server: server
@@ -1098,7 +1101,7 @@ proxies: # socks5
mode: legacy # 可选legacy默认、streamsplit-stream、poll、auto先 stream 再 poll、wsWebSocket 隧道)
# tls: true # 可选:仅在 mode 为 stream/poll/auto/ws 时生效true 强制 https/wssfalse 强制 http/ws不会根据端口自动推断
# host: "" # 可选:覆盖 Host/SNI支持 example.com 或 example.com:443仅在 mode 为 stream/poll/auto/ws 时生效
# path_root: "" # 可选HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" 或 "/aabbcc/" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload、/aabbcc/ws
# path-root: "" # 可选HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" 或 "/aabbcc/" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload、/aabbcc/ws
# multiplex: off # 可选off默认、auto复用底层 HTTP 连接,减少建链 RTT、onSudoku mux 单隧道多目标;仅在 mode=stream/poll/auto 生效ws 强制 off
#
# 向后兼容旧写法:
@@ -1677,7 +1680,7 @@ listeners:
httpmask:
disable: false # true 禁用所有 HTTP 伪装/隧道
mode: legacy # 可选legacy默认、streamsplit-stream、poll、auto先 stream 再 poll、wsWebSocket 隧道)
# path_root: "" # 可选HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" 或 "/aabbcc/" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload、/aabbcc/ws
# path-root: "" # 可选HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" 或 "/aabbcc/" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload、/aabbcc/ws
#
# 可选:当启用 HTTPMask 且识别到“像 HTTP 但不符合 tunnel/auth”的请求时将原始字节透传给 fallback常用于与其他服务共端口
# fallback: "127.0.0.1:80"

View File

@@ -35,7 +35,7 @@ type SudokuOption struct {
type SudokuHTTPMaskOptions struct {
Disable bool `inbound:"disable,omitempty"`
Mode string `inbound:"mode,omitempty"`
PathRoot string `inbound:"path_root,omitempty"`
PathRoot string `inbound:"path-root,omitempty"`
}
func (o SudokuOption) Equal(config C.InboundConfig) bool {

View File

@@ -62,6 +62,7 @@ type Config struct {
ServiceName string
UserAgent string
Host string
PingInterval int
}
func (g *Conn) initReader() {
@@ -246,7 +247,7 @@ func (g *Conn) SetDeadline(t time.Time) error {
return nil
}
func NewHTTP2Client(dialFn DialFn, tlsConfig *vmess.TLSConfig) *TransportWrap {
func NewTransport(dialFn DialFn, tlsConfig *vmess.TLSConfig, gunCfg *Config) *Transport {
dialFunc := func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
ctx, cancel := context.WithTimeout(ctx, C.DefaultTLSTimeout)
defer cancel()
@@ -288,12 +289,14 @@ func NewHTTP2Client(dialFn DialFn, tlsConfig *vmess.TLSConfig) *TransportWrap {
DialTLSContext: dialFunc,
AllowHTTP: false,
DisableCompression: true,
ReadIdleTimeout: time.Duration(gunCfg.PingInterval) * time.Second, // If zero, no health check is performed
PingTimeout: 0,
}
ctx, cancel := context.WithCancel(context.Background())
wrap := &TransportWrap{
Http2Transport: transport,
wrap := &Transport{
transport: transport,
cfg: gunCfg,
ctx: ctx,
cancel: cancel,
}
@@ -307,18 +310,18 @@ func ServiceNameToPath(serviceName string) string {
return "/" + serviceName + "/Tun"
}
func StreamGunWithTransport(transport *TransportWrap, cfg *Config) (net.Conn, error) {
func (t *Transport) Dial() (net.Conn, error) {
serviceName := "GunService"
if cfg.ServiceName != "" {
serviceName = cfg.ServiceName
if t.cfg.ServiceName != "" {
serviceName = t.cfg.ServiceName
}
path := ServiceNameToPath(serviceName)
reader, writer := io.Pipe()
header := defaultHeader.Clone()
if cfg.UserAgent != "" {
header.Set("User-Agent", cfg.UserAgent)
if t.cfg.UserAgent != "" {
header.Set("User-Agent", t.cfg.UserAgent)
}
request := &http.Request{
@@ -326,17 +329,17 @@ func StreamGunWithTransport(transport *TransportWrap, cfg *Config) (net.Conn, er
Body: reader,
URL: &url.URL{
Scheme: "https",
Host: cfg.Host,
Host: t.cfg.Host,
Path: path,
// for unescape path
Opaque: "//" + cfg.Host + path,
Opaque: "//" + t.cfg.Host + path,
},
Proto: "HTTP/2",
ProtoMajor: 2,
ProtoMinor: 0,
Header: header,
}
request = request.WithContext(transport.ctx)
request = request.WithContext(t.ctx)
conn := &Conn{
initFn: func() (io.ReadCloser, NetAddr, error) {
@@ -348,7 +351,7 @@ func StreamGunWithTransport(transport *TransportWrap, cfg *Config) (net.Conn, er
},
}
request = request.WithContext(httptrace.WithClientTrace(request.Context(), trace))
response, err := transport.RoundTrip(request)
response, err := t.transport.RoundTrip(request)
if err != nil {
return nil, nAddr, err
}
@@ -361,13 +364,13 @@ func StreamGunWithTransport(transport *TransportWrap, cfg *Config) (net.Conn, er
return conn, nil
}
func StreamGunWithConn(conn net.Conn, tlsConfig *vmess.TLSConfig, cfg *Config) (net.Conn, error) {
func StreamGunWithConn(conn net.Conn, tlsConfig *vmess.TLSConfig, gunCfg *Config) (net.Conn, error) {
dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) {
return conn, nil
}
transport := NewHTTP2Client(dialFn, tlsConfig)
c, err := StreamGunWithTransport(transport, cfg)
transport := NewTransport(dialFn, tlsConfig, gunCfg)
c, err := transport.Dial()
if err != nil {
return nil, err
}

View File

@@ -10,17 +10,18 @@ import (
"github.com/metacubex/http"
)
type TransportWrap struct {
*http.Http2Transport
type Transport struct {
transport *http.Http2Transport
cfg *Config
ctx context.Context
cancel context.CancelFunc
closeOnce sync.Once
}
func (tw *TransportWrap) Close() error {
tw.closeOnce.Do(func() {
tw.cancel()
CloseTransport(tw.Http2Transport)
func (t *Transport) Close() error {
t.closeOnce.Do(func() {
t.cancel()
CloseHttp2Transport(t.transport)
})
return nil
}

View File

@@ -44,7 +44,7 @@ func closeClientConn(cc *http.Http2ClientConn) { // like forceCloseConn() in htt
_ = cc.Close()
}
func CloseTransport(tr *http.Http2Transport) {
func CloseHttp2Transport(tr *http.Http2Transport) {
connPool := transportConnPool(tr)
p := (*clientConnPool)((*efaceWords)(unsafe.Pointer(&connPool)).data)
p.mu.Lock()

View File

@@ -5,6 +5,7 @@ import (
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"errors"
@@ -12,6 +13,7 @@ import (
"io"
"net"
"sync"
"sync/atomic"
"golang.org/x/crypto/chacha20poly1305"
)
@@ -58,10 +60,12 @@ type RecordConn struct {
sendEpoch uint32
sendSeq uint64
sendBytes int64
sendEpochUpdates uint32
// Receive direction state.
recvEpoch uint32
recvSeq uint64
recvInitialized bool
readBuf bytes.Buffer
@@ -105,6 +109,9 @@ func NewRecordConn(conn net.Conn, method string, baseSend, baseRecv []byte) (*Re
}
rc := &RecordConn{Conn: conn, method: method}
rc.keys = recordKeys{baseSend: cloneBytes(baseSend), baseRecv: cloneBytes(baseRecv)}
if err := rc.resetTrafficState(); err != nil {
return nil, err
}
return rc, nil
}
@@ -127,11 +134,9 @@ func (c *RecordConn) Rekey(baseSend, baseRecv []byte) error {
defer c.writeMu.Unlock()
c.keys = recordKeys{baseSend: cloneBytes(baseSend), baseRecv: cloneBytes(baseRecv)}
c.sendEpoch = 0
c.sendSeq = 0
c.sendBytes = 0
c.recvEpoch = 0
c.recvSeq = 0
if err := c.resetTrafficState(); err != nil {
return err
}
c.readBuf.Reset()
c.sendAEAD = nil
@@ -141,6 +146,21 @@ func (c *RecordConn) Rekey(baseSend, baseRecv []byte) error {
return nil
}
func (c *RecordConn) resetTrafficState() error {
sendEpoch, sendSeq, err := randomRecordCounters()
if err != nil {
return fmt.Errorf("initialize record counters: %w", err)
}
c.sendEpoch = sendEpoch
c.sendSeq = sendSeq
c.sendBytes = 0
c.sendEpochUpdates = 0
c.recvEpoch = 0
c.recvSeq = 0
c.recvInitialized = false
return nil
}
func normalizeAEADMethod(method string) string {
switch method {
case "", "chacha20-poly1305":
@@ -166,6 +186,44 @@ func cloneBytes(b []byte) []byte {
return append([]byte(nil), b...)
}
func randomRecordCounters() (uint32, uint64, error) {
epoch, err := randomNonZeroUint32()
if err != nil {
return 0, 0, err
}
seq, err := randomNonZeroUint64()
if err != nil {
return 0, 0, err
}
return epoch, seq, nil
}
func randomNonZeroUint32() (uint32, error) {
var b [4]byte
for {
if _, err := io.ReadFull(rand.Reader, b[:]); err != nil {
return 0, err
}
v := binary.BigEndian.Uint32(b[:])
if v != 0 && v != ^uint32(0) {
return v, nil
}
}
}
func randomNonZeroUint64() (uint64, error) {
var b [8]byte
for {
if _, err := io.ReadFull(rand.Reader, b[:]); err != nil {
return 0, err
}
v := binary.BigEndian.Uint64(b[:])
if v != 0 && v != ^uint64(0) {
return v, nil
}
}
}
func (c *RecordConn) newAEADFor(base []byte, epoch uint32) (cipher.AEAD, error) {
if c.method == "none" {
return nil, nil
@@ -209,17 +267,49 @@ func deriveEpochKey(base []byte, epoch uint32, method string) []byte {
return mac.Sum(nil)
}
func (c *RecordConn) maybeBumpSendEpochLocked(addedPlain int) {
if KeyUpdateAfterBytes <= 0 || c.method == "none" {
return
func (c *RecordConn) maybeBumpSendEpochLocked(addedPlain int) error {
ku := atomic.LoadInt64(&KeyUpdateAfterBytes)
if ku <= 0 || c.method == "none" {
return nil
}
c.sendBytes += int64(addedPlain)
threshold := KeyUpdateAfterBytes * int64(c.sendEpoch+1)
threshold := ku * int64(c.sendEpochUpdates+1)
if c.sendBytes < threshold {
return
return nil
}
c.sendEpoch++
c.sendSeq = 0
c.sendEpochUpdates++
nextSeq, err := randomNonZeroUint64()
if err != nil {
return fmt.Errorf("rotate record seq: %w", err)
}
c.sendSeq = nextSeq
return nil
}
func (c *RecordConn) validateRecvPosition(epoch uint32, seq uint64) error {
if !c.recvInitialized {
return nil
}
if epoch < c.recvEpoch {
return fmt.Errorf("replayed epoch: got %d want >=%d", epoch, c.recvEpoch)
}
if epoch == c.recvEpoch && seq != c.recvSeq {
return fmt.Errorf("out of order: epoch=%d got=%d want=%d", epoch, seq, c.recvSeq)
}
if epoch > c.recvEpoch {
const maxJump = 8
if epoch-c.recvEpoch > maxJump {
return fmt.Errorf("epoch jump too large: got=%d want<=%d", epoch-c.recvEpoch, maxJump)
}
}
return nil
}
func (c *RecordConn) markRecvPosition(epoch uint32, seq uint64) {
c.recvEpoch = epoch
c.recvSeq = seq + 1
c.recvInitialized = true
}
func (c *RecordConn) Write(p []byte) (int, error) {
@@ -282,7 +372,9 @@ func (c *RecordConn) Write(p []byte) (int, error) {
}
total += n
c.maybeBumpSendEpochLocked(n)
if err := c.maybeBumpSendEpochLocked(n); err != nil {
return total, err
}
}
return total, nil
}
@@ -324,31 +416,17 @@ func (c *RecordConn) Read(p []byte) (int, error) {
epoch := binary.BigEndian.Uint32(header[:4])
seq := binary.BigEndian.Uint64(header[4:])
if epoch < c.recvEpoch {
return 0, fmt.Errorf("replayed epoch: got %d want >=%d", epoch, c.recvEpoch)
}
if epoch == c.recvEpoch && seq != c.recvSeq {
return 0, fmt.Errorf("out of order: epoch=%d got=%d want=%d", epoch, seq, c.recvSeq)
}
if epoch > c.recvEpoch {
const maxJump = 8
if epoch-c.recvEpoch > maxJump {
return 0, fmt.Errorf("epoch jump too large: got=%d want<=%d", epoch-c.recvEpoch, maxJump)
}
c.recvEpoch = epoch
c.recvSeq = 0
if seq != 0 {
return 0, fmt.Errorf("out of order: epoch advanced to %d but seq=%d", epoch, seq)
}
if err := c.validateRecvPosition(epoch, seq); err != nil {
return 0, err
}
if c.recvAEAD == nil || c.recvAEADEpoch != c.recvEpoch {
a, err := c.newAEADFor(c.keys.baseRecv, c.recvEpoch)
if c.recvAEAD == nil || c.recvAEADEpoch != epoch {
a, err := c.newAEADFor(c.keys.baseRecv, epoch)
if err != nil {
return 0, err
}
c.recvAEAD = a
c.recvAEADEpoch = c.recvEpoch
c.recvAEADEpoch = epoch
}
aead := c.recvAEAD
@@ -356,7 +434,7 @@ func (c *RecordConn) Read(p []byte) (int, error) {
if err != nil {
return 0, fmt.Errorf("decryption failed: epoch=%d seq=%d: %w", epoch, seq, err)
}
c.recvSeq++
c.markRecvPosition(epoch, seq)
c.readBuf.Write(plaintext)
return c.readBuf.Read(p)

View File

@@ -0,0 +1,86 @@
package crypto
import (
"bytes"
"crypto/sha256"
"encoding/binary"
"io"
"net"
"testing"
"time"
)
type captureConn struct {
bytes.Buffer
}
func (c *captureConn) Read(_ []byte) (int, error) { return 0, io.EOF }
func (c *captureConn) Write(p []byte) (int, error) { return c.Buffer.Write(p) }
func (c *captureConn) Close() error { return nil }
func (c *captureConn) LocalAddr() net.Addr { return nil }
func (c *captureConn) RemoteAddr() net.Addr { return nil }
func (c *captureConn) SetDeadline(time.Time) error { return nil }
func (c *captureConn) SetReadDeadline(time.Time) error { return nil }
func (c *captureConn) SetWriteDeadline(time.Time) error { return nil }
type replayConn struct {
reader *bytes.Reader
}
func (c *replayConn) Read(p []byte) (int, error) { return c.reader.Read(p) }
func (c *replayConn) Write(p []byte) (int, error) { return len(p), nil }
func (c *replayConn) Close() error { return nil }
func (c *replayConn) LocalAddr() net.Addr { return nil }
func (c *replayConn) RemoteAddr() net.Addr { return nil }
func (c *replayConn) SetDeadline(time.Time) error { return nil }
func (c *replayConn) SetReadDeadline(time.Time) error { return nil }
func (c *replayConn) SetWriteDeadline(time.Time) error { return nil }
func TestRecordConn_FirstFrameUsesRandomizedCounters(t *testing.T) {
pskSend := sha256.Sum256([]byte("record-send"))
pskRecv := sha256.Sum256([]byte("record-recv"))
raw := &captureConn{}
writer, err := NewRecordConn(raw, "chacha20-poly1305", pskSend[:], pskRecv[:])
if err != nil {
t.Fatalf("new writer: %v", err)
}
if writer.sendEpoch == 0 || writer.sendSeq == 0 {
t.Fatalf("expected non-zero randomized counters, got epoch=%d seq=%d", writer.sendEpoch, writer.sendSeq)
}
want := []byte("record prefix camouflage")
if _, err := writer.Write(want); err != nil {
t.Fatalf("write: %v", err)
}
wire := raw.Bytes()
if len(wire) < 2+recordHeaderSize {
t.Fatalf("short frame: %d", len(wire))
}
bodyLen := int(binary.BigEndian.Uint16(wire[:2]))
if bodyLen != len(wire)-2 {
t.Fatalf("body len mismatch: got %d want %d", bodyLen, len(wire)-2)
}
epoch := binary.BigEndian.Uint32(wire[2:6])
seq := binary.BigEndian.Uint64(wire[6:14])
if epoch == 0 || seq == 0 {
t.Fatalf("wire header still starts from zero: epoch=%d seq=%d", epoch, seq)
}
reader, err := NewRecordConn(&replayConn{reader: bytes.NewReader(wire)}, "chacha20-poly1305", pskRecv[:], pskSend[:])
if err != nil {
t.Fatalf("new reader: %v", err)
}
got := make([]byte, len(want))
if _, err := io.ReadFull(reader, got); err != nil {
t.Fatalf("read: %v", err)
}
if !bytes.Equal(got, want) {
t.Fatalf("plaintext mismatch: got %q want %q", got, want)
}
}

View File

@@ -0,0 +1,345 @@
package sudoku
import (
"bytes"
"crypto/ecdh"
"crypto/rand"
"encoding/hex"
"fmt"
"net"
"time"
"github.com/metacubex/mihomo/transport/sudoku/crypto"
httpmaskobfs "github.com/metacubex/mihomo/transport/sudoku/obfs/httpmask"
sudokuobfs "github.com/metacubex/mihomo/transport/sudoku/obfs/sudoku"
)
const earlyKIPHandshakeTTL = 60 * time.Second
type EarlyCodecConfig struct {
PSK string
AEAD string
EnablePureDownlink bool
PaddingMin int
PaddingMax int
}
type EarlyClientState struct {
RequestPayload []byte
cfg EarlyCodecConfig
table *sudokuobfs.Table
nonce [kipHelloNonceSize]byte
ephemeral *ecdh.PrivateKey
sessionC2S []byte
sessionS2C []byte
responseSet bool
}
type EarlyServerState struct {
ResponsePayload []byte
UserHash string
cfg EarlyCodecConfig
table *sudokuobfs.Table
sessionC2S []byte
sessionS2C []byte
}
type ReplayAllowFunc func(userHash string, nonce [kipHelloNonceSize]byte, now time.Time) bool
type earlyMemoryConn struct {
reader *bytes.Reader
write bytes.Buffer
}
func newEarlyMemoryConn(readBuf []byte) *earlyMemoryConn {
return &earlyMemoryConn{reader: bytes.NewReader(readBuf)}
}
func (c *earlyMemoryConn) Read(p []byte) (int, error) {
if c == nil || c.reader == nil {
return 0, net.ErrClosed
}
return c.reader.Read(p)
}
func (c *earlyMemoryConn) Write(p []byte) (int, error) {
if c == nil {
return 0, net.ErrClosed
}
return c.write.Write(p)
}
func (c *earlyMemoryConn) Close() error { return nil }
func (c *earlyMemoryConn) LocalAddr() net.Addr { return earlyDummyAddr("local") }
func (c *earlyMemoryConn) RemoteAddr() net.Addr { return earlyDummyAddr("remote") }
func (c *earlyMemoryConn) SetDeadline(time.Time) error { return nil }
func (c *earlyMemoryConn) SetReadDeadline(time.Time) error { return nil }
func (c *earlyMemoryConn) SetWriteDeadline(time.Time) error { return nil }
func (c *earlyMemoryConn) Written() []byte { return append([]byte(nil), c.write.Bytes()...) }
type earlyDummyAddr string
func (a earlyDummyAddr) Network() string { return string(a) }
func (a earlyDummyAddr) String() string { return string(a) }
func buildEarlyClientObfsConn(raw net.Conn, cfg EarlyCodecConfig, table *sudokuobfs.Table) net.Conn {
base := sudokuobfs.NewConn(raw, table, cfg.PaddingMin, cfg.PaddingMax, false)
if cfg.EnablePureDownlink {
return base
}
packed := sudokuobfs.NewPackedConn(raw, table, cfg.PaddingMin, cfg.PaddingMax)
return newDirectionalConn(raw, packed, base)
}
func buildEarlyServerObfsConn(raw net.Conn, cfg EarlyCodecConfig, table *sudokuobfs.Table) net.Conn {
uplink := sudokuobfs.NewConn(raw, table, cfg.PaddingMin, cfg.PaddingMax, false)
if cfg.EnablePureDownlink {
return uplink
}
packed := sudokuobfs.NewPackedConn(raw, table, cfg.PaddingMin, cfg.PaddingMax)
return newDirectionalConn(raw, uplink, packed, packed.Flush)
}
func NewEarlyClientState(cfg EarlyCodecConfig, table *sudokuobfs.Table, userHash [kipHelloUserHashSize]byte, feats uint32) (*EarlyClientState, error) {
if table == nil {
return nil, fmt.Errorf("nil table")
}
curve := ecdh.X25519()
ephemeral, err := curve.GenerateKey(rand.Reader)
if err != nil {
return nil, fmt.Errorf("ecdh generate failed: %w", err)
}
var nonce [kipHelloNonceSize]byte
if _, err := rand.Read(nonce[:]); err != nil {
return nil, fmt.Errorf("nonce generate failed: %w", err)
}
var clientPub [kipHelloPubSize]byte
copy(clientPub[:], ephemeral.PublicKey().Bytes())
hello := &KIPClientHello{
Timestamp: time.Now(),
UserHash: userHash,
Nonce: nonce,
ClientPub: clientPub,
Features: feats,
}
mem := newEarlyMemoryConn(nil)
obfsConn := buildEarlyClientObfsConn(mem, cfg, table)
pskC2S, pskS2C := derivePSKDirectionalBases(cfg.PSK)
rc, err := crypto.NewRecordConn(obfsConn, cfg.AEAD, pskC2S, pskS2C)
if err != nil {
return nil, fmt.Errorf("client early crypto setup failed: %w", err)
}
if err := WriteKIPMessage(rc, KIPTypeClientHello, hello.EncodePayload()); err != nil {
return nil, fmt.Errorf("write early client hello failed: %w", err)
}
return &EarlyClientState{
RequestPayload: mem.Written(),
cfg: cfg,
table: table,
nonce: nonce,
ephemeral: ephemeral,
}, nil
}
func (s *EarlyClientState) ProcessResponse(payload []byte) error {
if s == nil {
return fmt.Errorf("nil client state")
}
mem := newEarlyMemoryConn(payload)
obfsConn := buildEarlyClientObfsConn(mem, s.cfg, s.table)
pskC2S, pskS2C := derivePSKDirectionalBases(s.cfg.PSK)
rc, err := crypto.NewRecordConn(obfsConn, s.cfg.AEAD, pskC2S, pskS2C)
if err != nil {
return fmt.Errorf("client early crypto setup failed: %w", err)
}
msg, err := ReadKIPMessage(rc)
if err != nil {
return fmt.Errorf("read early server hello failed: %w", err)
}
if msg.Type != KIPTypeServerHello {
return fmt.Errorf("unexpected early handshake message: %d", msg.Type)
}
sh, err := DecodeKIPServerHelloPayload(msg.Payload)
if err != nil {
return fmt.Errorf("decode early server hello failed: %w", err)
}
if sh.Nonce != s.nonce {
return fmt.Errorf("early handshake nonce mismatch")
}
shared, err := x25519SharedSecret(s.ephemeral, sh.ServerPub[:])
if err != nil {
return fmt.Errorf("ecdh failed: %w", err)
}
s.sessionC2S, s.sessionS2C, err = deriveSessionDirectionalBases(s.cfg.PSK, shared, s.nonce)
if err != nil {
return fmt.Errorf("derive session keys failed: %w", err)
}
s.responseSet = true
return nil
}
func (s *EarlyClientState) WrapConn(raw net.Conn) (net.Conn, error) {
if s == nil {
return nil, fmt.Errorf("nil client state")
}
if !s.responseSet {
return nil, fmt.Errorf("early handshake not completed")
}
obfsConn := buildEarlyClientObfsConn(raw, s.cfg, s.table)
rc, err := crypto.NewRecordConn(obfsConn, s.cfg.AEAD, s.sessionC2S, s.sessionS2C)
if err != nil {
return nil, fmt.Errorf("setup client session crypto failed: %w", err)
}
return rc, nil
}
func (s *EarlyClientState) Ready() bool {
return s != nil && s.responseSet
}
func NewHTTPMaskClientEarlyHandshake(cfg EarlyCodecConfig, table *sudokuobfs.Table, userHash [kipHelloUserHashSize]byte, feats uint32) (*httpmaskobfs.ClientEarlyHandshake, error) {
state, err := NewEarlyClientState(cfg, table, userHash, feats)
if err != nil {
return nil, err
}
return &httpmaskobfs.ClientEarlyHandshake{
RequestPayload: state.RequestPayload,
HandleResponse: state.ProcessResponse,
Ready: state.Ready,
WrapConn: state.WrapConn,
}, nil
}
func ProcessEarlyClientPayload(cfg EarlyCodecConfig, tables []*sudokuobfs.Table, payload []byte, allowReplay ReplayAllowFunc) (*EarlyServerState, error) {
if len(payload) == 0 {
return nil, fmt.Errorf("empty early payload")
}
if len(tables) == 0 {
return nil, fmt.Errorf("no tables configured")
}
var firstErr error
for _, table := range tables {
state, err := processEarlyClientPayloadForTable(cfg, table, payload, allowReplay)
if err == nil {
return state, nil
}
if firstErr == nil {
firstErr = err
}
}
if firstErr == nil {
firstErr = fmt.Errorf("early handshake probe failed")
}
return nil, firstErr
}
func processEarlyClientPayloadForTable(cfg EarlyCodecConfig, table *sudokuobfs.Table, payload []byte, allowReplay ReplayAllowFunc) (*EarlyServerState, error) {
mem := newEarlyMemoryConn(payload)
obfsConn := buildEarlyServerObfsConn(mem, cfg, table)
pskC2S, pskS2C := derivePSKDirectionalBases(cfg.PSK)
rc, err := crypto.NewRecordConn(obfsConn, cfg.AEAD, pskS2C, pskC2S)
if err != nil {
return nil, err
}
msg, err := ReadKIPMessage(rc)
if err != nil {
return nil, err
}
if msg.Type != KIPTypeClientHello {
return nil, fmt.Errorf("unexpected handshake message: %d", msg.Type)
}
ch, err := DecodeKIPClientHelloPayload(msg.Payload)
if err != nil {
return nil, err
}
if absInt64(time.Now().Unix()-ch.Timestamp.Unix()) > int64(earlyKIPHandshakeTTL.Seconds()) {
return nil, fmt.Errorf("time skew/replay")
}
userHash := hex.EncodeToString(ch.UserHash[:])
if allowReplay != nil && !allowReplay(userHash, ch.Nonce, time.Now()) {
return nil, fmt.Errorf("replay detected")
}
curve := ecdh.X25519()
serverEphemeral, err := curve.GenerateKey(rand.Reader)
if err != nil {
return nil, fmt.Errorf("ecdh generate failed: %w", err)
}
shared, err := x25519SharedSecret(serverEphemeral, ch.ClientPub[:])
if err != nil {
return nil, fmt.Errorf("ecdh failed: %w", err)
}
sessionC2S, sessionS2C, err := deriveSessionDirectionalBases(cfg.PSK, shared, ch.Nonce)
if err != nil {
return nil, fmt.Errorf("derive session keys failed: %w", err)
}
var serverPub [kipHelloPubSize]byte
copy(serverPub[:], serverEphemeral.PublicKey().Bytes())
serverHello := &KIPServerHello{
Nonce: ch.Nonce,
ServerPub: serverPub,
SelectedFeats: ch.Features & KIPFeatAll,
}
respMem := newEarlyMemoryConn(nil)
respObfs := buildEarlyServerObfsConn(respMem, cfg, table)
respConn, err := crypto.NewRecordConn(respObfs, cfg.AEAD, pskS2C, pskC2S)
if err != nil {
return nil, fmt.Errorf("server early crypto setup failed: %w", err)
}
if err := WriteKIPMessage(respConn, KIPTypeServerHello, serverHello.EncodePayload()); err != nil {
return nil, fmt.Errorf("write early server hello failed: %w", err)
}
return &EarlyServerState{
ResponsePayload: respMem.Written(),
UserHash: userHash,
cfg: cfg,
table: table,
sessionC2S: sessionC2S,
sessionS2C: sessionS2C,
}, nil
}
func (s *EarlyServerState) WrapConn(raw net.Conn) (net.Conn, error) {
if s == nil {
return nil, fmt.Errorf("nil server state")
}
obfsConn := buildEarlyServerObfsConn(raw, s.cfg, s.table)
rc, err := crypto.NewRecordConn(obfsConn, s.cfg.AEAD, s.sessionS2C, s.sessionC2S)
if err != nil {
return nil, fmt.Errorf("setup server session crypto failed: %w", err)
}
return rc, nil
}
func NewHTTPMaskServerEarlyHandshake(cfg EarlyCodecConfig, tables []*sudokuobfs.Table, allowReplay ReplayAllowFunc) *httpmaskobfs.TunnelServerEarlyHandshake {
return &httpmaskobfs.TunnelServerEarlyHandshake{
Prepare: func(payload []byte) (*httpmaskobfs.PreparedServerEarlyHandshake, error) {
state, err := ProcessEarlyClientPayload(cfg, tables, payload, allowReplay)
if err != nil {
return nil, err
}
return &httpmaskobfs.PreparedServerEarlyHandshake{
ResponsePayload: state.ResponsePayload,
WrapConn: state.WrapConn,
UserHash: state.UserHash,
}, nil
},
}
}

View File

@@ -337,6 +337,9 @@ func ServerHandshake(rawConn net.Conn, cfg *ProtocolConfig) (net.Conn, *Handshak
if err := cfg.Validate(); err != nil {
return nil, nil, fmt.Errorf("invalid config: %w", err)
}
if userHash, ok := httpmask.EarlyHandshakeUserHash(rawConn); ok {
return rawConn, &HandshakeMeta{UserHash: userHash}, nil
}
handshakeTimeout := time.Duration(cfg.HandshakeTimeoutSeconds) * time.Second
if handshakeTimeout <= 0 {

View File

@@ -14,6 +14,30 @@ type HTTPMaskTunnelServer struct {
ts *httpmask.TunnelServer
}
func newHTTPMaskEarlyCodecConfig(cfg *ProtocolConfig, psk string) EarlyCodecConfig {
return EarlyCodecConfig{
PSK: psk,
AEAD: cfg.AEADMethod,
EnablePureDownlink: cfg.EnablePureDownlink,
PaddingMin: cfg.PaddingMin,
PaddingMax: cfg.PaddingMax,
}
}
func newClientHTTPMaskEarlyHandshake(cfg *ProtocolConfig) (*httpmask.ClientEarlyHandshake, error) {
table, err := pickClientTable(cfg)
if err != nil {
return nil, err
}
return NewHTTPMaskClientEarlyHandshake(
newHTTPMaskEarlyCodecConfig(cfg, ClientAEADSeed(cfg.Key)),
table,
kipUserHashFromKey(cfg.Key),
KIPFeatAll,
)
}
func NewHTTPMaskTunnelServer(cfg *ProtocolConfig) *HTTPMaskTunnelServer {
return newHTTPMaskTunnelServer(cfg, false)
}
@@ -35,6 +59,11 @@ func newHTTPMaskTunnelServer(cfg *ProtocolConfig, passThroughOnReject bool) *HTT
Mode: cfg.HTTPMaskMode,
PathRoot: cfg.HTTPMaskPathRoot,
AuthKey: ServerAEADSeed(cfg.Key),
EarlyHandshake: NewHTTPMaskServerEarlyHandshake(
newHTTPMaskEarlyCodecConfig(cfg, ServerAEADSeed(cfg.Key)),
cfg.tableCandidates(),
globalHandshakeReplay.allow,
),
// When upstream fallback is enabled, preserve rejected HTTP requests for the caller.
PassThroughOnReject: passThroughOnReject,
})
@@ -101,12 +130,23 @@ func DialHTTPMaskTunnel(ctx context.Context, serverAddress string, cfg *Protocol
default:
return nil, fmt.Errorf("http-mask-mode=%q does not use http tunnel", cfg.HTTPMaskMode)
}
var (
earlyHandshake *httpmask.ClientEarlyHandshake
err error
)
if upgrade != nil {
earlyHandshake, err = newClientHTTPMaskEarlyHandshake(cfg)
if err != nil {
return nil, err
}
}
return httpmask.DialTunnel(ctx, serverAddress, httpmask.TunnelDialOptions{
Mode: cfg.HTTPMaskMode,
TLSEnabled: cfg.HTTPMaskTLSEnabled,
HostOverride: cfg.HTTPMaskHost,
PathRoot: cfg.HTTPMaskPathRoot,
AuthKey: ClientAEADSeed(cfg.Key),
EarlyHandshake: earlyHandshake,
Upgrade: upgrade,
Multiplex: cfg.HTTPMaskMultiplex,
DialContext: dial,

View File

@@ -389,6 +389,68 @@ func TestHTTPMaskTunnel_WS_TCPRoundTrip(t *testing.T) {
}
}
func TestHTTPMaskTunnel_EarlyHandshake_TCPRoundTrip(t *testing.T) {
modes := []string{"stream", "poll", "ws"}
for _, mode := range modes {
t.Run(mode, func(t *testing.T) {
key := "tunnel-early-" + mode
target := "1.1.1.1:80"
serverCfg := newTunnelTestTable(t, key)
serverCfg.HTTPMaskMode = mode
addr, stop, errCh := startTunnelServer(t, serverCfg, func(s *ServerSession) error {
if s.Type != SessionTypeTCP {
return fmt.Errorf("unexpected session type: %v", s.Type)
}
if s.Target != target {
return fmt.Errorf("target mismatch: %s", s.Target)
}
_, _ = s.Conn.Write([]byte("ok"))
return nil
})
defer stop()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
clientCfg := *serverCfg
clientCfg.ServerAddress = addr
handshakeCfg := clientCfg
handshakeCfg.DisableHTTPMask = true
tunnelConn, err := DialHTTPMaskTunnel(ctx, clientCfg.ServerAddress, &clientCfg, (&net.Dialer{}).DialContext, func(raw net.Conn) (net.Conn, error) {
return ClientHandshake(raw, &handshakeCfg)
})
if err != nil {
t.Fatalf("dial tunnel: %v", err)
}
defer tunnelConn.Close()
addrBuf, err := EncodeAddress(target)
if err != nil {
t.Fatalf("encode addr: %v", err)
}
if err := WriteKIPMessage(tunnelConn, KIPTypeOpenTCP, addrBuf); err != nil {
t.Fatalf("write addr: %v", err)
}
buf := make([]byte, 2)
if _, err := io.ReadFull(tunnelConn, buf); err != nil {
t.Fatalf("read: %v", err)
}
if string(buf) != "ok" {
t.Fatalf("unexpected payload: %q", buf)
}
stop()
for err := range errCh {
t.Fatalf("server error: %v", err)
}
})
}
}
func TestHTTPMaskTunnel_Validation(t *testing.T) {
cfg := DefaultConfig()
cfg.Key = "k"

View File

@@ -0,0 +1,174 @@
package httpmask
import (
"encoding/base64"
"errors"
"fmt"
"net"
"net/url"
"strings"
)
const (
tunnelEarlyDataQueryKey = "ed"
tunnelEarlyDataHeader = "X-Sudoku-Early"
)
type ClientEarlyHandshake struct {
RequestPayload []byte
HandleResponse func(payload []byte) error
Ready func() bool
WrapConn func(raw net.Conn) (net.Conn, error)
}
type TunnelServerEarlyHandshake struct {
Prepare func(payload []byte) (*PreparedServerEarlyHandshake, error)
}
type PreparedServerEarlyHandshake struct {
ResponsePayload []byte
WrapConn func(raw net.Conn) (net.Conn, error)
UserHash string
}
type earlyHandshakeMeta interface {
HTTPMaskEarlyHandshakeUserHash() string
}
type earlyHandshakeConn struct {
net.Conn
userHash string
}
func (c *earlyHandshakeConn) HTTPMaskEarlyHandshakeUserHash() string {
if c == nil {
return ""
}
return c.userHash
}
func wrapEarlyHandshakeConn(conn net.Conn, userHash string) net.Conn {
if conn == nil {
return nil
}
return &earlyHandshakeConn{Conn: conn, userHash: userHash}
}
func EarlyHandshakeUserHash(conn net.Conn) (string, bool) {
if conn == nil {
return "", false
}
v, ok := conn.(earlyHandshakeMeta)
if !ok {
return "", false
}
return v.HTTPMaskEarlyHandshakeUserHash(), true
}
type authorizeResponse struct {
token string
earlyPayload []byte
}
func isTunnelTokenByte(c byte) bool {
return (c >= 'a' && c <= 'z') ||
(c >= 'A' && c <= 'Z') ||
(c >= '0' && c <= '9') ||
c == '-' ||
c == '_'
}
func parseAuthorizeResponse(body []byte) (*authorizeResponse, error) {
s := strings.TrimSpace(string(body))
idx := strings.Index(s, "token=")
if idx < 0 {
return nil, errors.New("missing token")
}
s = s[idx+len("token="):]
if s == "" {
return nil, errors.New("empty token")
}
var b strings.Builder
for i := 0; i < len(s); i++ {
c := s[i]
if isTunnelTokenByte(c) {
b.WriteByte(c)
continue
}
break
}
token := b.String()
if token == "" {
return nil, errors.New("empty token")
}
out := &authorizeResponse{token: token}
if earlyLine := findAuthorizeField(body, "ed="); earlyLine != "" {
decoded, err := base64.RawURLEncoding.DecodeString(earlyLine)
if err != nil {
return nil, fmt.Errorf("decode early authorize payload failed: %w", err)
}
out.earlyPayload = decoded
}
return out, nil
}
func findAuthorizeField(body []byte, prefix string) string {
for _, line := range strings.Split(strings.TrimSpace(string(body)), "\n") {
line = strings.TrimSpace(line)
if strings.HasPrefix(line, prefix) {
return strings.TrimSpace(strings.TrimPrefix(line, prefix))
}
}
return ""
}
func setEarlyDataQuery(rawURL string, payload []byte) (string, error) {
if len(payload) == 0 {
return rawURL, nil
}
u, err := url.Parse(rawURL)
if err != nil {
return "", err
}
q := u.Query()
q.Set(tunnelEarlyDataQueryKey, base64.RawURLEncoding.EncodeToString(payload))
u.RawQuery = q.Encode()
return u.String(), nil
}
func parseEarlyDataQuery(u *url.URL) ([]byte, error) {
if u == nil {
return nil, nil
}
val := strings.TrimSpace(u.Query().Get(tunnelEarlyDataQueryKey))
if val == "" {
return nil, nil
}
return base64.RawURLEncoding.DecodeString(val)
}
func applyEarlyHandshakeOrUpgrade(raw net.Conn, opts TunnelDialOptions) (net.Conn, error) {
out := raw
if opts.EarlyHandshake != nil && opts.EarlyHandshake.WrapConn != nil && (opts.EarlyHandshake.Ready == nil || opts.EarlyHandshake.Ready()) {
wrapped, err := opts.EarlyHandshake.WrapConn(raw)
if err != nil {
return nil, err
}
if wrapped != nil {
out = wrapped
}
return out, nil
}
if opts.Upgrade != nil {
wrapped, err := opts.Upgrade(raw)
if err != nil {
return nil, err
}
if wrapped != nil {
out = wrapped
}
}
return out, nil
}

View File

@@ -72,6 +72,10 @@ type TunnelDialOptions struct {
// AuthKey enables short-term HMAC auth for HTTP tunnel requests (anti-probing).
// When set (non-empty), each HTTP request carries an Authorization bearer token derived from AuthKey.
AuthKey string
// EarlyHandshake folds the protocol handshake into the HTTP/WS setup round trip.
// When the server accepts the early payload, DialTunnel returns a conn that is already post-handshake.
// When the server does not echo early data, DialTunnel falls back to Upgrade.
EarlyHandshake *ClientEarlyHandshake
// Upgrade optionally wraps the raw tunnel conn and/or writes a small prelude before DialTunnel returns.
// It is called with the raw tunnel conn; if it returns a non-nil conn, that conn is returned by DialTunnel.
Upgrade func(raw net.Conn) (net.Conn, error)
@@ -225,30 +229,11 @@ func canonicalHeaderHost(urlHost, scheme string) string {
}
func parseTunnelToken(body []byte) (string, error) {
s := strings.TrimSpace(string(body))
idx := strings.Index(s, "token=")
if idx < 0 {
return "", errors.New("missing token")
resp, err := parseAuthorizeResponse(body)
if err != nil {
return "", err
}
s = s[idx+len("token="):]
if s == "" {
return "", errors.New("empty token")
}
// Token is base64.RawURLEncoding (A-Z a-z 0-9 - _). Strip any trailing bytes (e.g. from CDN compression).
var b strings.Builder
for i := 0; i < len(s); i++ {
c := s[i]
if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '-' || c == '_' {
b.WriteByte(c)
continue
}
break
}
token := b.String()
if token == "" {
return "", errors.New("empty token")
}
return token, nil
return resp.token, nil
}
type httpClientTarget struct {
@@ -353,6 +338,13 @@ func dialSessionWithClient(ctx context.Context, client *http.Client, target http
auth := newTunnelAuth(opts.AuthKey, 0)
authorizeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/session")}).String()
if opts.EarlyHandshake != nil && len(opts.EarlyHandshake.RequestPayload) > 0 {
var err error
authorizeURL, err = setEarlyDataQuery(authorizeURL, opts.EarlyHandshake.RequestPayload)
if err != nil {
return nil, err
}
}
var bodyBytes []byte
for attempt := 0; ; attempt++ {
@@ -410,13 +402,19 @@ func dialSessionWithClient(ctx context.Context, client *http.Client, target http
break
}
token, err := parseTunnelToken(bodyBytes)
authResp, err := parseAuthorizeResponse(bodyBytes)
if err != nil {
return nil, fmt.Errorf("%s authorize failed: %q", mode, strings.TrimSpace(string(bodyBytes)))
}
token := authResp.token
if token == "" {
return nil, fmt.Errorf("%s authorize empty token", mode)
}
if opts.EarlyHandshake != nil && len(authResp.earlyPayload) > 0 && opts.EarlyHandshake.HandleResponse != nil {
if err := opts.EarlyHandshake.HandleResponse(authResp.earlyPayload); err != nil {
return nil, err
}
}
pushURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/api/v1/upload"), RawQuery: "token=" + url.QueryEscape(token)}).String()
pullURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/stream"), RawQuery: "token=" + url.QueryEscape(token)}).String()
@@ -671,17 +669,11 @@ func dialStreamSplitWithClient(ctx context.Context, client *http.Client, target
if c == nil {
return nil, fmt.Errorf("failed to build stream split conn")
}
outConn := net.Conn(c)
if opts.Upgrade != nil {
upgraded, err := opts.Upgrade(c)
outConn, err := applyEarlyHandshakeOrUpgrade(c, opts)
if err != nil {
_ = c.Close()
return nil, err
}
if upgraded != nil {
outConn = upgraded
}
}
return outConn, nil
}
@@ -694,17 +686,11 @@ func dialStreamSplit(ctx context.Context, serverAddress string, opts TunnelDialO
if c == nil {
return nil, fmt.Errorf("failed to build stream split conn")
}
outConn := net.Conn(c)
if opts.Upgrade != nil {
upgraded, err := opts.Upgrade(c)
outConn, err := applyEarlyHandshakeOrUpgrade(c, opts)
if err != nil {
_ = c.Close()
return nil, err
}
if upgraded != nil {
outConn = upgraded
}
}
return outConn, nil
}
@@ -1120,17 +1106,11 @@ func dialPollWithClient(ctx context.Context, client *http.Client, target httpCli
if c == nil {
return nil, fmt.Errorf("failed to build poll conn")
}
outConn := net.Conn(c)
if opts.Upgrade != nil {
upgraded, err := opts.Upgrade(c)
outConn, err := applyEarlyHandshakeOrUpgrade(c, opts)
if err != nil {
_ = c.Close()
return nil, err
}
if upgraded != nil {
outConn = upgraded
}
}
return outConn, nil
}
@@ -1143,17 +1123,11 @@ func dialPoll(ctx context.Context, serverAddress string, opts TunnelDialOptions)
if c == nil {
return nil, fmt.Errorf("failed to build poll conn")
}
outConn := net.Conn(c)
if opts.Upgrade != nil {
upgraded, err := opts.Upgrade(c)
outConn, err := applyEarlyHandshakeOrUpgrade(c, opts)
if err != nil {
_ = c.Close()
return nil, err
}
if upgraded != nil {
outConn = upgraded
}
}
return outConn, nil
}
@@ -1528,6 +1502,8 @@ type TunnelServerOptions struct {
PullReadTimeout time.Duration
// SessionTTL is a best-effort TTL to prevent leaked sessions. 0 uses a conservative default.
SessionTTL time.Duration
// EarlyHandshake optionally folds the protocol handshake into the initial HTTP/WS round trip.
EarlyHandshake *TunnelServerEarlyHandshake
}
type TunnelServer struct {
@@ -1538,6 +1514,7 @@ type TunnelServer struct {
pullReadTimeout time.Duration
sessionTTL time.Duration
earlyHandshake *TunnelServerEarlyHandshake
mu sync.Mutex
sessions map[string]*tunnelSession
@@ -1570,6 +1547,7 @@ func NewTunnelServer(opts TunnelServerOptions) *TunnelServer {
passThroughOnReject: opts.PassThroughOnReject,
pullReadTimeout: timeout,
sessionTTL: ttl,
earlyHandshake: opts.EarlyHandshake,
sessions: make(map[string]*tunnelSession),
}
}
@@ -1925,9 +1903,12 @@ func (s *TunnelServer) handleStream(rawConn net.Conn, req *httpRequestHeader, he
switch strings.ToUpper(req.method) {
case http.MethodGet:
// Stream split-session: GET /session (no token) => token + start tunnel on a server-side pipe.
if token == "" && path == "/session" {
return s.sessionAuthorize(rawConn)
earlyPayload, err := parseEarlyDataQuery(u)
if err != nil {
return rejectOrReply(http.StatusBadRequest, "bad request")
}
return s.sessionAuthorize(rawConn, earlyPayload)
}
// Stream split-session: GET /stream?token=... => downlink poll.
if token != "" && path == "/stream" {
@@ -2045,10 +2026,18 @@ func writeSimpleHTTPResponse(w io.Writer, code int, body string) error {
func writeTokenHTTPResponse(w io.Writer, token string) error {
token = strings.TrimRight(token, "\r\n")
// Use application/octet-stream to avoid CDN auto-compression (e.g. brotli) breaking clients that expect a plain token string.
return writeTokenHTTPResponseWithEarlyData(w, token, nil)
}
func writeTokenHTTPResponseWithEarlyData(w io.Writer, token string, earlyPayload []byte) error {
token = strings.TrimRight(token, "\r\n")
body := "token=" + token
if len(earlyPayload) > 0 {
body += "\ned=" + base64.RawURLEncoding.EncodeToString(earlyPayload)
}
_, err := io.WriteString(w,
fmt.Sprintf("HTTP/1.1 200 OK\r\nContent-Type: application/octet-stream\r\nCache-Control: no-store\r\nPragma: no-cache\r\nContent-Length: %d\r\nConnection: close\r\n\r\ntoken=%s",
len("token=")+len(token), token))
fmt.Sprintf("HTTP/1.1 200 OK\r\nContent-Type: application/octet-stream\r\nCache-Control: no-store\r\nPragma: no-cache\r\nContent-Length: %d\r\nConnection: close\r\n\r\n%s",
len(body), body))
return err
}
@@ -2088,7 +2077,11 @@ func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, head
switch strings.ToUpper(req.method) {
case http.MethodGet:
if token == "" && path == "/session" {
return s.sessionAuthorize(rawConn)
earlyPayload, err := parseEarlyDataQuery(u)
if err != nil {
return rejectOrReply(http.StatusBadRequest, "bad request")
}
return s.sessionAuthorize(rawConn, earlyPayload)
}
if token != "" && path == "/stream" {
if s.passThroughOnReject && !s.sessionHas(token) {
@@ -2128,7 +2121,7 @@ func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, head
}
}
func (s *TunnelServer) sessionAuthorize(rawConn net.Conn) (HandleResult, net.Conn, error) {
func (s *TunnelServer) sessionAuthorize(rawConn net.Conn, earlyPayload []byte) (HandleResult, net.Conn, error) {
token, err := newSessionToken()
if err != nil {
_ = writeSimpleHTTPResponse(rawConn, http.StatusInternalServerError, "internal error")
@@ -2137,6 +2130,37 @@ func (s *TunnelServer) sessionAuthorize(rawConn net.Conn) (HandleResult, net.Con
}
c1, c2 := newHalfPipe()
outConn := net.Conn(c1)
var responsePayload []byte
var userHash string
if len(earlyPayload) > 0 && s.earlyHandshake != nil && s.earlyHandshake.Prepare != nil {
prepared, err := s.earlyHandshake.Prepare(earlyPayload)
if err != nil {
_ = c1.Close()
_ = c2.Close()
if s.passThroughOnReject {
return HandlePassThrough, newRejectedPreBufferedConn(rawConn, nil), nil
}
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
_ = rawConn.Close()
return HandleDone, nil, nil
}
responsePayload = prepared.ResponsePayload
userHash = prepared.UserHash
if prepared.WrapConn != nil {
wrapped, err := prepared.WrapConn(c1)
if err != nil {
_ = c1.Close()
_ = c2.Close()
_ = writeSimpleHTTPResponse(rawConn, http.StatusInternalServerError, "internal error")
_ = rawConn.Close()
return HandleDone, nil, nil
}
if wrapped != nil {
outConn = wrapEarlyHandshakeConn(wrapped, userHash)
}
}
}
s.mu.Lock()
s.sessions[token] = &tunnelSession{conn: c2, lastActive: time.Now()}
@@ -2144,9 +2168,9 @@ func (s *TunnelServer) sessionAuthorize(rawConn net.Conn) (HandleResult, net.Con
go s.reapLater(token)
_ = writeTokenHTTPResponse(rawConn, token)
_ = writeTokenHTTPResponseWithEarlyData(rawConn, token, responsePayload)
_ = rawConn.Close()
return HandleStartTunnel, c1, nil
return HandleStartTunnel, outConn, nil
}
func newSessionToken() (string, error) {

View File

@@ -2,6 +2,7 @@ package httpmask
import (
"context"
"encoding/base64"
"fmt"
"io"
mrand "math/rand"
@@ -115,6 +116,16 @@ func dialWS(ctx context.Context, serverAddress string, opts TunnelDialOptions) (
Host: urlHost,
Path: joinPathRoot(opts.PathRoot, "/ws"),
}
if opts.EarlyHandshake != nil && len(opts.EarlyHandshake.RequestPayload) > 0 {
rawURL, err := setEarlyDataQuery(u.String(), opts.EarlyHandshake.RequestPayload)
if err != nil {
return nil, err
}
u, err = url.Parse(rawURL)
if err != nil {
return nil, err
}
}
header := make(stdhttp.Header)
applyWSHeaders(header, headerHost)
@@ -132,6 +143,16 @@ func dialWS(ctx context.Context, serverAddress string, opts TunnelDialOptions) (
d := ws.Dialer{
Host: headerHost,
Header: ws.HandshakeHeaderHTTP(header),
OnHeader: func(key, value []byte) error {
if !strings.EqualFold(string(key), tunnelEarlyDataHeader) || opts.EarlyHandshake == nil || opts.EarlyHandshake.HandleResponse == nil {
return nil
}
decoded, err := base64.RawURLEncoding.DecodeString(strings.TrimSpace(string(value)))
if err != nil {
return err
}
return opts.EarlyHandshake.HandleResponse(decoded)
},
NetDial: func(dialCtx context.Context, network, addr string) (net.Conn, error) {
if addr == urlHost {
addr = dialAddr
@@ -161,16 +182,10 @@ func dialWS(ctx context.Context, serverAddress string, opts TunnelDialOptions) (
}
wsConn := newWSStreamConn(conn, ws.StateClientSide)
if opts.Upgrade == nil {
return wsConn, nil
}
upgraded, err := opts.Upgrade(wsConn)
upgraded, err := applyEarlyHandshakeOrUpgrade(wsConn, opts)
if err != nil {
_ = wsConn.Close()
return nil, err
}
if upgraded != nil {
return upgraded, nil
}
return wsConn, nil
}

View File

@@ -1,6 +1,7 @@
package httpmask
import (
"encoding/base64"
"net"
"net/http"
"net/url"
@@ -63,15 +64,46 @@ func (s *TunnelServer) handleWS(rawConn net.Conn, req *httpRequestHeader, header
return rejectOrReply(http.StatusNotFound, "not found")
}
earlyPayload, err := parseEarlyDataQuery(u)
if err != nil {
return rejectOrReply(http.StatusBadRequest, "bad request")
}
var prepared *PreparedServerEarlyHandshake
if len(earlyPayload) > 0 && s.earlyHandshake != nil && s.earlyHandshake.Prepare != nil {
prepared, err = s.earlyHandshake.Prepare(earlyPayload)
if err != nil {
return rejectOrReply(http.StatusNotFound, "not found")
}
}
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
prefix = append(prefix, headerBytes...)
prefix = append(prefix, buffered...)
wsConnRaw := newPreBufferedConn(rawConn, prefix)
if _, err := ws.Upgrade(wsConnRaw); err != nil {
upgrader := ws.Upgrader{}
if prepared != nil && len(prepared.ResponsePayload) > 0 {
upgrader.OnBeforeUpgrade = func() (ws.HandshakeHeader, error) {
h := http.Header{}
h.Set(tunnelEarlyDataHeader, base64.RawURLEncoding.EncodeToString(prepared.ResponsePayload))
return ws.HandshakeHeaderHTTP(h), nil
}
}
if _, err := upgrader.Upgrade(wsConnRaw); err != nil {
_ = rawConn.Close()
return HandleDone, nil, nil
}
return HandleStartTunnel, newWSStreamConn(wsConnRaw, ws.StateServerSide), nil
outConn := net.Conn(newWSStreamConn(wsConnRaw, ws.StateServerSide))
if prepared != nil && prepared.WrapConn != nil {
wrapped, err := prepared.WrapConn(outConn)
if err != nil {
_ = outConn.Close()
return HandleDone, nil, nil
}
if wrapped != nil {
outConn = wrapEarlyHandshakeConn(wrapped, prepared.UserHash)
}
}
return HandleStartTunnel, outConn, nil
}

View File

@@ -11,35 +11,35 @@ import (
)
const (
// 每次从 RNG 获取批量随机数的缓存大小,减少 RNG 函数调用开销
RngBatchSize = 128
packedProtectedPrefixBytes = 14
)
// 1. 使用 12字节->16组 的块处理优化 Write (减少循环开销)
// 2. 使用整数阈值随机概率判断 Padding与纯 Sudoku 保持流量特征一致
// 3. Read 使用 copy 移动避免底层数组泄漏
// PackedConn encodes traffic with the packed Sudoku layout while preserving
// the same padding model as the regular connection.
type PackedConn struct {
net.Conn
table *Table
reader *bufio.Reader
// 读缓冲
// Read-side buffers.
rawBuf []byte
pendingData []byte // 解码后尚未被 Read 取走的字节
pendingData []byte
// 写缓冲与状态
// Write-side state.
writeMu sync.Mutex
writeBuf []byte
bitBuf uint64 // 暂存的位数据
bitCount int // 暂存的位数
bitBuf uint64
bitCount int
// 读状态
// Read-side bit accumulator.
readBitBuf uint64
readBits int
// 随机数与填充控制 - 使用整数阈值随机,与 Conn 一致
// Padding selection matches Conn's threshold-based model.
rng *rand.Rand
paddingThreshold uint64 // 与 Conn 保持一致的随机概率模型
paddingThreshold uint64
padMarker byte
padPool []byte
}
@@ -95,7 +95,6 @@ func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn {
return pc
}
// maybeAddPadding 内联辅助:根据概率阈值插入 padding
func (pc *PackedConn) maybeAddPadding(out []byte) []byte {
if shouldPad(pc.rng, pc.paddingThreshold) {
out = append(out, pc.getPaddingByte())
@@ -103,7 +102,73 @@ func (pc *PackedConn) maybeAddPadding(out []byte) []byte {
return out
}
// Write 极致优化版 - 批量处理 12 字节
func (pc *PackedConn) appendGroup(out []byte, group byte) []byte {
out = pc.maybeAddPadding(out)
return append(out, pc.encodeGroup(group))
}
func (pc *PackedConn) appendForcedPadding(out []byte) []byte {
return append(out, pc.getPaddingByte())
}
func (pc *PackedConn) nextProtectedPrefixGap() int {
return 1 + pc.rng.Intn(2)
}
func (pc *PackedConn) writeProtectedPrefix(out []byte, p []byte) ([]byte, int) {
if len(p) == 0 {
return out, 0
}
limit := len(p)
if limit > packedProtectedPrefixBytes {
limit = packedProtectedPrefixBytes
}
for padCount := 0; padCount < 1+pc.rng.Intn(2); padCount++ {
out = pc.appendForcedPadding(out)
}
gap := pc.nextProtectedPrefixGap()
effective := 0
for i := 0; i < limit; i++ {
pc.bitBuf = (pc.bitBuf << 8) | uint64(p[i])
pc.bitCount += 8
for pc.bitCount >= 6 {
pc.bitCount -= 6
group := byte(pc.bitBuf >> pc.bitCount)
if pc.bitCount == 0 {
pc.bitBuf = 0
} else {
pc.bitBuf &= (1 << pc.bitCount) - 1
}
out = pc.appendGroup(out, group&0x3F)
}
effective++
if effective >= gap {
out = pc.appendForcedPadding(out)
effective = 0
gap = pc.nextProtectedPrefixGap()
}
}
return out, limit
}
func (pc *PackedConn) drainPendingData(dst []byte) int {
n := copy(dst, pc.pendingData)
if n == len(pc.pendingData) {
pc.pendingData = pc.pendingData[:0]
return n
}
remaining := len(pc.pendingData) - n
copy(pc.pendingData, pc.pendingData[n:])
pc.pendingData = pc.pendingData[:remaining]
return n
}
func (pc *PackedConn) Write(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
@@ -112,20 +177,19 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
pc.writeMu.Lock()
defer pc.writeMu.Unlock()
// 1. 预分配内存,避免 append 导致的多次扩容
// 预估:原数据 * 1.5 (4/3 + padding 余量)
needed := len(p)*3/2 + 32
if cap(pc.writeBuf) < needed {
pc.writeBuf = make([]byte, 0, needed)
}
out := pc.writeBuf[:0]
i := 0
var prefixN int
out, prefixN = pc.writeProtectedPrefix(out, p)
i := prefixN
n := len(p)
// 2. 头部对齐处理 (Slow Path)
for pc.bitCount > 0 && i < n {
out = pc.maybeAddPadding(out)
b := p[i]
i++
pc.bitBuf = (pc.bitBuf << 8) | uint64(b)
@@ -138,14 +202,11 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
} else {
pc.bitBuf &= (1 << pc.bitCount) - 1
}
out = pc.maybeAddPadding(out)
out = append(out, pc.encodeGroup(group&0x3F))
out = pc.appendGroup(out, group&0x3F)
}
}
// 3. 极速批量处理 (Fast Path) - 每次处理 12 字节 → 生成 16 个编码组
for i+11 < n {
// 处理 4 组,每组 3 字节
for batch := 0; batch < 4; batch++ {
b1, b2, b3 := p[i], p[i+1], p[i+2]
i += 3
@@ -155,19 +216,13 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
g3 := ((b2 & 0x0F) << 2) | ((b3 >> 6) & 0x03)
g4 := b3 & 0x3F
// 每个组之前都有概率插入 padding
out = pc.maybeAddPadding(out)
out = append(out, pc.encodeGroup(g1))
out = pc.maybeAddPadding(out)
out = append(out, pc.encodeGroup(g2))
out = pc.maybeAddPadding(out)
out = append(out, pc.encodeGroup(g3))
out = pc.maybeAddPadding(out)
out = append(out, pc.encodeGroup(g4))
out = pc.appendGroup(out, g1)
out = pc.appendGroup(out, g2)
out = pc.appendGroup(out, g3)
out = pc.appendGroup(out, g4)
}
}
// 4. 处理剩余的 3 字节块
for i+2 < n {
b1, b2, b3 := p[i], p[i+1], p[i+2]
i += 3
@@ -177,17 +232,12 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
g3 := ((b2 & 0x0F) << 2) | ((b3 >> 6) & 0x03)
g4 := b3 & 0x3F
out = pc.maybeAddPadding(out)
out = append(out, pc.encodeGroup(g1))
out = pc.maybeAddPadding(out)
out = append(out, pc.encodeGroup(g2))
out = pc.maybeAddPadding(out)
out = append(out, pc.encodeGroup(g3))
out = pc.maybeAddPadding(out)
out = append(out, pc.encodeGroup(g4))
out = pc.appendGroup(out, g1)
out = pc.appendGroup(out, g2)
out = pc.appendGroup(out, g3)
out = pc.appendGroup(out, g4)
}
// 5. 尾部处理 (Tail Path) - 处理剩余的 1 或 2 个字节
for ; i < n; i++ {
b := p[i]
pc.bitBuf = (pc.bitBuf << 8) | uint64(b)
@@ -200,35 +250,28 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
} else {
pc.bitBuf &= (1 << pc.bitCount) - 1
}
out = pc.maybeAddPadding(out)
out = append(out, pc.encodeGroup(group&0x3F))
out = pc.appendGroup(out, group&0x3F)
}
}
// 6. 处理残留位
if pc.bitCount > 0 {
out = pc.maybeAddPadding(out)
group := byte(pc.bitBuf << (6 - pc.bitCount))
pc.bitBuf = 0
pc.bitCount = 0
out = append(out, pc.encodeGroup(group&0x3F))
out = pc.appendGroup(out, group&0x3F)
out = append(out, pc.padMarker)
}
// 尾部可能添加 padding
out = pc.maybeAddPadding(out)
// 发送数据
if len(out) > 0 {
_, err := pc.Conn.Write(out)
pc.writeBuf = out[:0]
return len(p), err
return len(p), writeFull(pc.Conn, out)
}
pc.writeBuf = out[:0]
return len(p), nil
}
// Flush 处理最后不足 6 bit 的情况
func (pc *PackedConn) Flush() error {
pc.writeMu.Lock()
defer pc.writeMu.Unlock()
@@ -243,38 +286,34 @@ func (pc *PackedConn) Flush() error {
out = append(out, pc.padMarker)
}
// 尾部随机添加 padding
out = pc.maybeAddPadding(out)
if len(out) > 0 {
_, err := pc.Conn.Write(out)
pc.writeBuf = out[:0]
return err
return writeFull(pc.Conn, out)
}
return nil
}
// Read 优化版:减少切片操作,避免内存泄漏
func (pc *PackedConn) Read(p []byte) (int, error) {
// 1. 优先返回待处理区的数据
if len(pc.pendingData) > 0 {
n := copy(p, pc.pendingData)
if n == len(pc.pendingData) {
pc.pendingData = pc.pendingData[:0]
} else {
// 优化:移动剩余数据到数组头部,避免切片指向中间导致内存泄漏
remaining := len(pc.pendingData) - n
copy(pc.pendingData, pc.pendingData[n:])
pc.pendingData = pc.pendingData[:remaining]
func writeFull(w io.Writer, b []byte) error {
for len(b) > 0 {
n, err := w.Write(b)
if err != nil {
return err
}
return n, nil
b = b[n:]
}
return nil
}
func (pc *PackedConn) Read(p []byte) (int, error) {
if len(pc.pendingData) > 0 {
return pc.drainPendingData(p), nil
}
// 2. 循环读取直到解出数据或出错
for {
nr, rErr := pc.reader.Read(pc.rawBuf)
if nr > 0 {
// 缓存频繁访问的变量
rBuf := pc.readBitBuf
rBits := pc.readBits
padMarker := pc.padMarker
@@ -324,24 +363,13 @@ func (pc *PackedConn) Read(p []byte) (int, error) {
}
}
// 3. 返回解码后的数据 - 优化:避免底层数组泄漏
n := copy(p, pc.pendingData)
if n == len(pc.pendingData) {
pc.pendingData = pc.pendingData[:0]
} else {
remaining := len(pc.pendingData) - n
copy(pc.pendingData, pc.pendingData[n:])
pc.pendingData = pc.pendingData[:remaining]
}
return n, nil
return pc.drainPendingData(p), nil
}
// getPaddingByte 从 Pool 中随机取 Padding 字节
func (pc *PackedConn) getPaddingByte() byte {
return pc.padPool[pc.rng.Intn(len(pc.padPool))]
}
// encodeGroup 编码 6-bit 组
func (pc *PackedConn) encodeGroup(group byte) byte {
return pc.table.layout.encodeGroup(group)
}

View File

@@ -0,0 +1,91 @@
package sudoku
import (
"bytes"
"io"
"math/rand"
"net"
"testing"
"time"
)
type mockConn struct {
readBuf []byte
writeBuf []byte
}
func (c *mockConn) Read(p []byte) (int, error) {
if len(c.readBuf) == 0 {
return 0, io.EOF
}
n := copy(p, c.readBuf)
c.readBuf = c.readBuf[n:]
return n, nil
}
func (c *mockConn) Write(p []byte) (int, error) {
c.writeBuf = append(c.writeBuf, p...)
return len(p), nil
}
func (c *mockConn) Close() error { return nil }
func (c *mockConn) LocalAddr() net.Addr { return nil }
func (c *mockConn) RemoteAddr() net.Addr { return nil }
func (c *mockConn) SetDeadline(time.Time) error { return nil }
func (c *mockConn) SetReadDeadline(time.Time) error { return nil }
func (c *mockConn) SetWriteDeadline(time.Time) error { return nil }
func TestPackedConn_ProtectedPrefixPadding(t *testing.T) {
table := NewTable("packed-prefix-seed", "prefer_ascii")
mock := &mockConn{}
writer := NewPackedConn(mock, table, 0, 0)
writer.rng = rand.New(rand.NewSource(1))
payload := bytes.Repeat([]byte{0}, 32)
if _, err := writer.Write(payload); err != nil {
t.Fatalf("write: %v", err)
}
wire := append([]byte(nil), mock.writeBuf...)
if len(wire) < 20 {
t.Fatalf("wire too short: %d", len(wire))
}
firstHint := -1
nonHintCount := 0
maxHintRun := 0
currentHintRun := 0
for i, b := range wire[:20] {
if table.layout.isHint(b) {
if firstHint == -1 {
firstHint = i
}
currentHintRun++
if currentHintRun > maxHintRun {
maxHintRun = currentHintRun
}
continue
}
nonHintCount++
currentHintRun = 0
}
if firstHint < 1 || firstHint > 2 {
t.Fatalf("expected 1-2 leading padding bytes, first hint index=%d", firstHint)
}
if nonHintCount < 6 {
t.Fatalf("expected dense prefix padding, got only %d non-hint bytes in first 20", nonHintCount)
}
if maxHintRun > 3 {
t.Fatalf("prefix still exposes long hint run: %d", maxHintRun)
}
reader := NewPackedConn(&mockConn{readBuf: wire}, table, 0, 0)
decoded := make([]byte, len(payload))
if _, err := io.ReadFull(reader, decoded); err != nil {
t.Fatalf("read back: %v", err)
}
if !bytes.Equal(decoded, payload) {
t.Fatalf("roundtrip mismatch")
}
}

View File

@@ -11,7 +11,7 @@ func forceCloseAllConnections(roundTripper RoundTripper) {
roundTripper.CloseIdleConnections()
switch tr := roundTripper.(type) {
case *http.Http2Transport:
gun.CloseTransport(tr)
gun.CloseHttp2Transport(tr)
case *http3.Transport:
_ = tr.Close()
}