From b3c81602a287c384521c69dd4b39ca899a43d77c Mon Sep 17 00:00:00 2001 From: saba-futai <120904569+saba-futai@users.noreply.github.com> Date: Wed, 11 Mar 2026 00:00:32 +0800 Subject: [PATCH] chore: reduce the inherent 1rtt in httpmask mode for sudoku (#2610) --- transport/sudoku/crypto/record_conn.go | 152 ++++++-- transport/sudoku/crypto/record_conn_test.go | 86 +++++ transport/sudoku/early_handshake.go | 345 ++++++++++++++++++ transport/sudoku/handshake.go | 3 + transport/sudoku/httpmask_tunnel.go | 56 ++- transport/sudoku/httpmask_tunnel_test.go | 62 ++++ .../sudoku/obfs/httpmask/early_handshake.go | 174 +++++++++ transport/sudoku/obfs/httpmask/tunnel.go | 170 +++++---- transport/sudoku/obfs/httpmask/tunnel_ws.go | 31 +- .../sudoku/obfs/httpmask/tunnel_ws_server.go | 36 +- transport/sudoku/obfs/sudoku/packed.go | 192 +++++----- .../sudoku/obfs/sudoku/packed_prefix_test.go | 91 +++++ 12 files changed, 1188 insertions(+), 210 deletions(-) create mode 100644 transport/sudoku/crypto/record_conn_test.go create mode 100644 transport/sudoku/early_handshake.go create mode 100644 transport/sudoku/obfs/httpmask/early_handshake.go create mode 100644 transport/sudoku/obfs/sudoku/packed_prefix_test.go diff --git a/transport/sudoku/crypto/record_conn.go b/transport/sudoku/crypto/record_conn.go index 7a80c7f5..7c035715 100644 --- a/transport/sudoku/crypto/record_conn.go +++ b/transport/sudoku/crypto/record_conn.go @@ -5,6 +5,7 @@ import ( "crypto/aes" "crypto/cipher" "crypto/hmac" + "crypto/rand" "crypto/sha256" "encoding/binary" "errors" @@ -12,6 +13,7 @@ import ( "io" "net" "sync" + "sync/atomic" "golang.org/x/crypto/chacha20poly1305" ) @@ -55,13 +57,15 @@ type RecordConn struct { recvAEADEpoch uint32 // Send direction state. - sendEpoch uint32 - sendSeq uint64 - sendBytes int64 + sendEpoch uint32 + sendSeq uint64 + sendBytes int64 + sendEpochUpdates uint32 // Receive direction state. - recvEpoch uint32 - recvSeq uint64 + recvEpoch uint32 + recvSeq uint64 + recvInitialized bool readBuf bytes.Buffer @@ -105,6 +109,9 @@ func NewRecordConn(conn net.Conn, method string, baseSend, baseRecv []byte) (*Re } rc := &RecordConn{Conn: conn, method: method} rc.keys = recordKeys{baseSend: cloneBytes(baseSend), baseRecv: cloneBytes(baseRecv)} + if err := rc.resetTrafficState(); err != nil { + return nil, err + } return rc, nil } @@ -127,11 +134,9 @@ func (c *RecordConn) Rekey(baseSend, baseRecv []byte) error { defer c.writeMu.Unlock() c.keys = recordKeys{baseSend: cloneBytes(baseSend), baseRecv: cloneBytes(baseRecv)} - c.sendEpoch = 0 - c.sendSeq = 0 - c.sendBytes = 0 - c.recvEpoch = 0 - c.recvSeq = 0 + if err := c.resetTrafficState(); err != nil { + return err + } c.readBuf.Reset() c.sendAEAD = nil @@ -141,6 +146,21 @@ func (c *RecordConn) Rekey(baseSend, baseRecv []byte) error { return nil } +func (c *RecordConn) resetTrafficState() error { + sendEpoch, sendSeq, err := randomRecordCounters() + if err != nil { + return fmt.Errorf("initialize record counters: %w", err) + } + c.sendEpoch = sendEpoch + c.sendSeq = sendSeq + c.sendBytes = 0 + c.sendEpochUpdates = 0 + c.recvEpoch = 0 + c.recvSeq = 0 + c.recvInitialized = false + return nil +} + func normalizeAEADMethod(method string) string { switch method { case "", "chacha20-poly1305": @@ -166,6 +186,44 @@ func cloneBytes(b []byte) []byte { return append([]byte(nil), b...) } +func randomRecordCounters() (uint32, uint64, error) { + epoch, err := randomNonZeroUint32() + if err != nil { + return 0, 0, err + } + seq, err := randomNonZeroUint64() + if err != nil { + return 0, 0, err + } + return epoch, seq, nil +} + +func randomNonZeroUint32() (uint32, error) { + var b [4]byte + for { + if _, err := io.ReadFull(rand.Reader, b[:]); err != nil { + return 0, err + } + v := binary.BigEndian.Uint32(b[:]) + if v != 0 && v != ^uint32(0) { + return v, nil + } + } +} + +func randomNonZeroUint64() (uint64, error) { + var b [8]byte + for { + if _, err := io.ReadFull(rand.Reader, b[:]); err != nil { + return 0, err + } + v := binary.BigEndian.Uint64(b[:]) + if v != 0 && v != ^uint64(0) { + return v, nil + } + } +} + func (c *RecordConn) newAEADFor(base []byte, epoch uint32) (cipher.AEAD, error) { if c.method == "none" { return nil, nil @@ -209,17 +267,49 @@ func deriveEpochKey(base []byte, epoch uint32, method string) []byte { return mac.Sum(nil) } -func (c *RecordConn) maybeBumpSendEpochLocked(addedPlain int) { - if KeyUpdateAfterBytes <= 0 || c.method == "none" { - return +func (c *RecordConn) maybeBumpSendEpochLocked(addedPlain int) error { + ku := atomic.LoadInt64(&KeyUpdateAfterBytes) + if ku <= 0 || c.method == "none" { + return nil } c.sendBytes += int64(addedPlain) - threshold := KeyUpdateAfterBytes * int64(c.sendEpoch+1) + threshold := ku * int64(c.sendEpochUpdates+1) if c.sendBytes < threshold { - return + return nil } c.sendEpoch++ - c.sendSeq = 0 + c.sendEpochUpdates++ + nextSeq, err := randomNonZeroUint64() + if err != nil { + return fmt.Errorf("rotate record seq: %w", err) + } + c.sendSeq = nextSeq + return nil +} + +func (c *RecordConn) validateRecvPosition(epoch uint32, seq uint64) error { + if !c.recvInitialized { + return nil + } + if epoch < c.recvEpoch { + return fmt.Errorf("replayed epoch: got %d want >=%d", epoch, c.recvEpoch) + } + if epoch == c.recvEpoch && seq != c.recvSeq { + return fmt.Errorf("out of order: epoch=%d got=%d want=%d", epoch, seq, c.recvSeq) + } + if epoch > c.recvEpoch { + const maxJump = 8 + if epoch-c.recvEpoch > maxJump { + return fmt.Errorf("epoch jump too large: got=%d want<=%d", epoch-c.recvEpoch, maxJump) + } + } + return nil +} + +func (c *RecordConn) markRecvPosition(epoch uint32, seq uint64) { + c.recvEpoch = epoch + c.recvSeq = seq + 1 + c.recvInitialized = true } func (c *RecordConn) Write(p []byte) (int, error) { @@ -282,7 +372,9 @@ func (c *RecordConn) Write(p []byte) (int, error) { } total += n - c.maybeBumpSendEpochLocked(n) + if err := c.maybeBumpSendEpochLocked(n); err != nil { + return total, err + } } return total, nil } @@ -324,31 +416,17 @@ func (c *RecordConn) Read(p []byte) (int, error) { epoch := binary.BigEndian.Uint32(header[:4]) seq := binary.BigEndian.Uint64(header[4:]) - if epoch < c.recvEpoch { - return 0, fmt.Errorf("replayed epoch: got %d want >=%d", epoch, c.recvEpoch) - } - if epoch == c.recvEpoch && seq != c.recvSeq { - return 0, fmt.Errorf("out of order: epoch=%d got=%d want=%d", epoch, seq, c.recvSeq) - } - if epoch > c.recvEpoch { - const maxJump = 8 - if epoch-c.recvEpoch > maxJump { - return 0, fmt.Errorf("epoch jump too large: got=%d want<=%d", epoch-c.recvEpoch, maxJump) - } - c.recvEpoch = epoch - c.recvSeq = 0 - if seq != 0 { - return 0, fmt.Errorf("out of order: epoch advanced to %d but seq=%d", epoch, seq) - } + if err := c.validateRecvPosition(epoch, seq); err != nil { + return 0, err } - if c.recvAEAD == nil || c.recvAEADEpoch != c.recvEpoch { - a, err := c.newAEADFor(c.keys.baseRecv, c.recvEpoch) + if c.recvAEAD == nil || c.recvAEADEpoch != epoch { + a, err := c.newAEADFor(c.keys.baseRecv, epoch) if err != nil { return 0, err } c.recvAEAD = a - c.recvAEADEpoch = c.recvEpoch + c.recvAEADEpoch = epoch } aead := c.recvAEAD @@ -356,7 +434,7 @@ func (c *RecordConn) Read(p []byte) (int, error) { if err != nil { return 0, fmt.Errorf("decryption failed: epoch=%d seq=%d: %w", epoch, seq, err) } - c.recvSeq++ + c.markRecvPosition(epoch, seq) c.readBuf.Write(plaintext) return c.readBuf.Read(p) diff --git a/transport/sudoku/crypto/record_conn_test.go b/transport/sudoku/crypto/record_conn_test.go new file mode 100644 index 00000000..4ea0b9b8 --- /dev/null +++ b/transport/sudoku/crypto/record_conn_test.go @@ -0,0 +1,86 @@ +package crypto + +import ( + "bytes" + "crypto/sha256" + "encoding/binary" + "io" + "net" + "testing" + "time" +) + +type captureConn struct { + bytes.Buffer +} + +func (c *captureConn) Read(_ []byte) (int, error) { return 0, io.EOF } +func (c *captureConn) Write(p []byte) (int, error) { return c.Buffer.Write(p) } +func (c *captureConn) Close() error { return nil } +func (c *captureConn) LocalAddr() net.Addr { return nil } +func (c *captureConn) RemoteAddr() net.Addr { return nil } +func (c *captureConn) SetDeadline(time.Time) error { return nil } +func (c *captureConn) SetReadDeadline(time.Time) error { return nil } +func (c *captureConn) SetWriteDeadline(time.Time) error { return nil } + +type replayConn struct { + reader *bytes.Reader +} + +func (c *replayConn) Read(p []byte) (int, error) { return c.reader.Read(p) } +func (c *replayConn) Write(p []byte) (int, error) { return len(p), nil } +func (c *replayConn) Close() error { return nil } +func (c *replayConn) LocalAddr() net.Addr { return nil } +func (c *replayConn) RemoteAddr() net.Addr { return nil } +func (c *replayConn) SetDeadline(time.Time) error { return nil } +func (c *replayConn) SetReadDeadline(time.Time) error { return nil } +func (c *replayConn) SetWriteDeadline(time.Time) error { return nil } + +func TestRecordConn_FirstFrameUsesRandomizedCounters(t *testing.T) { + pskSend := sha256.Sum256([]byte("record-send")) + pskRecv := sha256.Sum256([]byte("record-recv")) + + raw := &captureConn{} + writer, err := NewRecordConn(raw, "chacha20-poly1305", pskSend[:], pskRecv[:]) + if err != nil { + t.Fatalf("new writer: %v", err) + } + + if writer.sendEpoch == 0 || writer.sendSeq == 0 { + t.Fatalf("expected non-zero randomized counters, got epoch=%d seq=%d", writer.sendEpoch, writer.sendSeq) + } + + want := []byte("record prefix camouflage") + if _, err := writer.Write(want); err != nil { + t.Fatalf("write: %v", err) + } + + wire := raw.Bytes() + if len(wire) < 2+recordHeaderSize { + t.Fatalf("short frame: %d", len(wire)) + } + + bodyLen := int(binary.BigEndian.Uint16(wire[:2])) + if bodyLen != len(wire)-2 { + t.Fatalf("body len mismatch: got %d want %d", bodyLen, len(wire)-2) + } + + epoch := binary.BigEndian.Uint32(wire[2:6]) + seq := binary.BigEndian.Uint64(wire[6:14]) + if epoch == 0 || seq == 0 { + t.Fatalf("wire header still starts from zero: epoch=%d seq=%d", epoch, seq) + } + + reader, err := NewRecordConn(&replayConn{reader: bytes.NewReader(wire)}, "chacha20-poly1305", pskRecv[:], pskSend[:]) + if err != nil { + t.Fatalf("new reader: %v", err) + } + + got := make([]byte, len(want)) + if _, err := io.ReadFull(reader, got); err != nil { + t.Fatalf("read: %v", err) + } + if !bytes.Equal(got, want) { + t.Fatalf("plaintext mismatch: got %q want %q", got, want) + } +} diff --git a/transport/sudoku/early_handshake.go b/transport/sudoku/early_handshake.go new file mode 100644 index 00000000..803a5293 --- /dev/null +++ b/transport/sudoku/early_handshake.go @@ -0,0 +1,345 @@ +package sudoku + +import ( + "bytes" + "crypto/ecdh" + "crypto/rand" + "encoding/hex" + "fmt" + "net" + "time" + + "github.com/metacubex/mihomo/transport/sudoku/crypto" + httpmaskobfs "github.com/metacubex/mihomo/transport/sudoku/obfs/httpmask" + sudokuobfs "github.com/metacubex/mihomo/transport/sudoku/obfs/sudoku" +) + +const earlyKIPHandshakeTTL = 60 * time.Second + +type EarlyCodecConfig struct { + PSK string + AEAD string + EnablePureDownlink bool + PaddingMin int + PaddingMax int +} + +type EarlyClientState struct { + RequestPayload []byte + + cfg EarlyCodecConfig + table *sudokuobfs.Table + nonce [kipHelloNonceSize]byte + ephemeral *ecdh.PrivateKey + sessionC2S []byte + sessionS2C []byte + responseSet bool +} + +type EarlyServerState struct { + ResponsePayload []byte + UserHash string + + cfg EarlyCodecConfig + table *sudokuobfs.Table + sessionC2S []byte + sessionS2C []byte +} + +type ReplayAllowFunc func(userHash string, nonce [kipHelloNonceSize]byte, now time.Time) bool + +type earlyMemoryConn struct { + reader *bytes.Reader + write bytes.Buffer +} + +func newEarlyMemoryConn(readBuf []byte) *earlyMemoryConn { + return &earlyMemoryConn{reader: bytes.NewReader(readBuf)} +} + +func (c *earlyMemoryConn) Read(p []byte) (int, error) { + if c == nil || c.reader == nil { + return 0, net.ErrClosed + } + return c.reader.Read(p) +} + +func (c *earlyMemoryConn) Write(p []byte) (int, error) { + if c == nil { + return 0, net.ErrClosed + } + return c.write.Write(p) +} + +func (c *earlyMemoryConn) Close() error { return nil } +func (c *earlyMemoryConn) LocalAddr() net.Addr { return earlyDummyAddr("local") } +func (c *earlyMemoryConn) RemoteAddr() net.Addr { return earlyDummyAddr("remote") } +func (c *earlyMemoryConn) SetDeadline(time.Time) error { return nil } +func (c *earlyMemoryConn) SetReadDeadline(time.Time) error { return nil } +func (c *earlyMemoryConn) SetWriteDeadline(time.Time) error { return nil } +func (c *earlyMemoryConn) Written() []byte { return append([]byte(nil), c.write.Bytes()...) } + +type earlyDummyAddr string + +func (a earlyDummyAddr) Network() string { return string(a) } +func (a earlyDummyAddr) String() string { return string(a) } + +func buildEarlyClientObfsConn(raw net.Conn, cfg EarlyCodecConfig, table *sudokuobfs.Table) net.Conn { + base := sudokuobfs.NewConn(raw, table, cfg.PaddingMin, cfg.PaddingMax, false) + if cfg.EnablePureDownlink { + return base + } + packed := sudokuobfs.NewPackedConn(raw, table, cfg.PaddingMin, cfg.PaddingMax) + return newDirectionalConn(raw, packed, base) +} + +func buildEarlyServerObfsConn(raw net.Conn, cfg EarlyCodecConfig, table *sudokuobfs.Table) net.Conn { + uplink := sudokuobfs.NewConn(raw, table, cfg.PaddingMin, cfg.PaddingMax, false) + if cfg.EnablePureDownlink { + return uplink + } + packed := sudokuobfs.NewPackedConn(raw, table, cfg.PaddingMin, cfg.PaddingMax) + return newDirectionalConn(raw, uplink, packed, packed.Flush) +} + +func NewEarlyClientState(cfg EarlyCodecConfig, table *sudokuobfs.Table, userHash [kipHelloUserHashSize]byte, feats uint32) (*EarlyClientState, error) { + if table == nil { + return nil, fmt.Errorf("nil table") + } + + curve := ecdh.X25519() + ephemeral, err := curve.GenerateKey(rand.Reader) + if err != nil { + return nil, fmt.Errorf("ecdh generate failed: %w", err) + } + + var nonce [kipHelloNonceSize]byte + if _, err := rand.Read(nonce[:]); err != nil { + return nil, fmt.Errorf("nonce generate failed: %w", err) + } + + var clientPub [kipHelloPubSize]byte + copy(clientPub[:], ephemeral.PublicKey().Bytes()) + hello := &KIPClientHello{ + Timestamp: time.Now(), + UserHash: userHash, + Nonce: nonce, + ClientPub: clientPub, + Features: feats, + } + + mem := newEarlyMemoryConn(nil) + obfsConn := buildEarlyClientObfsConn(mem, cfg, table) + pskC2S, pskS2C := derivePSKDirectionalBases(cfg.PSK) + rc, err := crypto.NewRecordConn(obfsConn, cfg.AEAD, pskC2S, pskS2C) + if err != nil { + return nil, fmt.Errorf("client early crypto setup failed: %w", err) + } + if err := WriteKIPMessage(rc, KIPTypeClientHello, hello.EncodePayload()); err != nil { + return nil, fmt.Errorf("write early client hello failed: %w", err) + } + + return &EarlyClientState{ + RequestPayload: mem.Written(), + cfg: cfg, + table: table, + nonce: nonce, + ephemeral: ephemeral, + }, nil +} + +func (s *EarlyClientState) ProcessResponse(payload []byte) error { + if s == nil { + return fmt.Errorf("nil client state") + } + + mem := newEarlyMemoryConn(payload) + obfsConn := buildEarlyClientObfsConn(mem, s.cfg, s.table) + pskC2S, pskS2C := derivePSKDirectionalBases(s.cfg.PSK) + rc, err := crypto.NewRecordConn(obfsConn, s.cfg.AEAD, pskC2S, pskS2C) + if err != nil { + return fmt.Errorf("client early crypto setup failed: %w", err) + } + + msg, err := ReadKIPMessage(rc) + if err != nil { + return fmt.Errorf("read early server hello failed: %w", err) + } + if msg.Type != KIPTypeServerHello { + return fmt.Errorf("unexpected early handshake message: %d", msg.Type) + } + sh, err := DecodeKIPServerHelloPayload(msg.Payload) + if err != nil { + return fmt.Errorf("decode early server hello failed: %w", err) + } + if sh.Nonce != s.nonce { + return fmt.Errorf("early handshake nonce mismatch") + } + + shared, err := x25519SharedSecret(s.ephemeral, sh.ServerPub[:]) + if err != nil { + return fmt.Errorf("ecdh failed: %w", err) + } + s.sessionC2S, s.sessionS2C, err = deriveSessionDirectionalBases(s.cfg.PSK, shared, s.nonce) + if err != nil { + return fmt.Errorf("derive session keys failed: %w", err) + } + s.responseSet = true + return nil +} + +func (s *EarlyClientState) WrapConn(raw net.Conn) (net.Conn, error) { + if s == nil { + return nil, fmt.Errorf("nil client state") + } + if !s.responseSet { + return nil, fmt.Errorf("early handshake not completed") + } + + obfsConn := buildEarlyClientObfsConn(raw, s.cfg, s.table) + rc, err := crypto.NewRecordConn(obfsConn, s.cfg.AEAD, s.sessionC2S, s.sessionS2C) + if err != nil { + return nil, fmt.Errorf("setup client session crypto failed: %w", err) + } + return rc, nil +} + +func (s *EarlyClientState) Ready() bool { + return s != nil && s.responseSet +} + +func NewHTTPMaskClientEarlyHandshake(cfg EarlyCodecConfig, table *sudokuobfs.Table, userHash [kipHelloUserHashSize]byte, feats uint32) (*httpmaskobfs.ClientEarlyHandshake, error) { + state, err := NewEarlyClientState(cfg, table, userHash, feats) + if err != nil { + return nil, err + } + return &httpmaskobfs.ClientEarlyHandshake{ + RequestPayload: state.RequestPayload, + HandleResponse: state.ProcessResponse, + Ready: state.Ready, + WrapConn: state.WrapConn, + }, nil +} + +func ProcessEarlyClientPayload(cfg EarlyCodecConfig, tables []*sudokuobfs.Table, payload []byte, allowReplay ReplayAllowFunc) (*EarlyServerState, error) { + if len(payload) == 0 { + return nil, fmt.Errorf("empty early payload") + } + if len(tables) == 0 { + return nil, fmt.Errorf("no tables configured") + } + + var firstErr error + for _, table := range tables { + state, err := processEarlyClientPayloadForTable(cfg, table, payload, allowReplay) + if err == nil { + return state, nil + } + if firstErr == nil { + firstErr = err + } + } + if firstErr == nil { + firstErr = fmt.Errorf("early handshake probe failed") + } + return nil, firstErr +} + +func processEarlyClientPayloadForTable(cfg EarlyCodecConfig, table *sudokuobfs.Table, payload []byte, allowReplay ReplayAllowFunc) (*EarlyServerState, error) { + mem := newEarlyMemoryConn(payload) + obfsConn := buildEarlyServerObfsConn(mem, cfg, table) + pskC2S, pskS2C := derivePSKDirectionalBases(cfg.PSK) + rc, err := crypto.NewRecordConn(obfsConn, cfg.AEAD, pskS2C, pskC2S) + if err != nil { + return nil, err + } + + msg, err := ReadKIPMessage(rc) + if err != nil { + return nil, err + } + if msg.Type != KIPTypeClientHello { + return nil, fmt.Errorf("unexpected handshake message: %d", msg.Type) + } + ch, err := DecodeKIPClientHelloPayload(msg.Payload) + if err != nil { + return nil, err + } + if absInt64(time.Now().Unix()-ch.Timestamp.Unix()) > int64(earlyKIPHandshakeTTL.Seconds()) { + return nil, fmt.Errorf("time skew/replay") + } + + userHash := hex.EncodeToString(ch.UserHash[:]) + if allowReplay != nil && !allowReplay(userHash, ch.Nonce, time.Now()) { + return nil, fmt.Errorf("replay detected") + } + + curve := ecdh.X25519() + serverEphemeral, err := curve.GenerateKey(rand.Reader) + if err != nil { + return nil, fmt.Errorf("ecdh generate failed: %w", err) + } + shared, err := x25519SharedSecret(serverEphemeral, ch.ClientPub[:]) + if err != nil { + return nil, fmt.Errorf("ecdh failed: %w", err) + } + sessionC2S, sessionS2C, err := deriveSessionDirectionalBases(cfg.PSK, shared, ch.Nonce) + if err != nil { + return nil, fmt.Errorf("derive session keys failed: %w", err) + } + + var serverPub [kipHelloPubSize]byte + copy(serverPub[:], serverEphemeral.PublicKey().Bytes()) + serverHello := &KIPServerHello{ + Nonce: ch.Nonce, + ServerPub: serverPub, + SelectedFeats: ch.Features & KIPFeatAll, + } + + respMem := newEarlyMemoryConn(nil) + respObfs := buildEarlyServerObfsConn(respMem, cfg, table) + respConn, err := crypto.NewRecordConn(respObfs, cfg.AEAD, pskS2C, pskC2S) + if err != nil { + return nil, fmt.Errorf("server early crypto setup failed: %w", err) + } + if err := WriteKIPMessage(respConn, KIPTypeServerHello, serverHello.EncodePayload()); err != nil { + return nil, fmt.Errorf("write early server hello failed: %w", err) + } + + return &EarlyServerState{ + ResponsePayload: respMem.Written(), + UserHash: userHash, + cfg: cfg, + table: table, + sessionC2S: sessionC2S, + sessionS2C: sessionS2C, + }, nil +} + +func (s *EarlyServerState) WrapConn(raw net.Conn) (net.Conn, error) { + if s == nil { + return nil, fmt.Errorf("nil server state") + } + obfsConn := buildEarlyServerObfsConn(raw, s.cfg, s.table) + rc, err := crypto.NewRecordConn(obfsConn, s.cfg.AEAD, s.sessionS2C, s.sessionC2S) + if err != nil { + return nil, fmt.Errorf("setup server session crypto failed: %w", err) + } + return rc, nil +} + +func NewHTTPMaskServerEarlyHandshake(cfg EarlyCodecConfig, tables []*sudokuobfs.Table, allowReplay ReplayAllowFunc) *httpmaskobfs.TunnelServerEarlyHandshake { + return &httpmaskobfs.TunnelServerEarlyHandshake{ + Prepare: func(payload []byte) (*httpmaskobfs.PreparedServerEarlyHandshake, error) { + state, err := ProcessEarlyClientPayload(cfg, tables, payload, allowReplay) + if err != nil { + return nil, err + } + return &httpmaskobfs.PreparedServerEarlyHandshake{ + ResponsePayload: state.ResponsePayload, + WrapConn: state.WrapConn, + UserHash: state.UserHash, + }, nil + }, + } +} diff --git a/transport/sudoku/handshake.go b/transport/sudoku/handshake.go index 971d47fd..bac688e8 100644 --- a/transport/sudoku/handshake.go +++ b/transport/sudoku/handshake.go @@ -337,6 +337,9 @@ func ServerHandshake(rawConn net.Conn, cfg *ProtocolConfig) (net.Conn, *Handshak if err := cfg.Validate(); err != nil { return nil, nil, fmt.Errorf("invalid config: %w", err) } + if userHash, ok := httpmask.EarlyHandshakeUserHash(rawConn); ok { + return rawConn, &HandshakeMeta{UserHash: userHash}, nil + } handshakeTimeout := time.Duration(cfg.HandshakeTimeoutSeconds) * time.Second if handshakeTimeout <= 0 { diff --git a/transport/sudoku/httpmask_tunnel.go b/transport/sudoku/httpmask_tunnel.go index 1ff2bb38..c066eb1c 100644 --- a/transport/sudoku/httpmask_tunnel.go +++ b/transport/sudoku/httpmask_tunnel.go @@ -14,6 +14,30 @@ type HTTPMaskTunnelServer struct { ts *httpmask.TunnelServer } +func newHTTPMaskEarlyCodecConfig(cfg *ProtocolConfig, psk string) EarlyCodecConfig { + return EarlyCodecConfig{ + PSK: psk, + AEAD: cfg.AEADMethod, + EnablePureDownlink: cfg.EnablePureDownlink, + PaddingMin: cfg.PaddingMin, + PaddingMax: cfg.PaddingMax, + } +} + +func newClientHTTPMaskEarlyHandshake(cfg *ProtocolConfig) (*httpmask.ClientEarlyHandshake, error) { + table, err := pickClientTable(cfg) + if err != nil { + return nil, err + } + + return NewHTTPMaskClientEarlyHandshake( + newHTTPMaskEarlyCodecConfig(cfg, ClientAEADSeed(cfg.Key)), + table, + kipUserHashFromKey(cfg.Key), + KIPFeatAll, + ) +} + func NewHTTPMaskTunnelServer(cfg *ProtocolConfig) *HTTPMaskTunnelServer { return newHTTPMaskTunnelServer(cfg, false) } @@ -35,6 +59,11 @@ func newHTTPMaskTunnelServer(cfg *ProtocolConfig, passThroughOnReject bool) *HTT Mode: cfg.HTTPMaskMode, PathRoot: cfg.HTTPMaskPathRoot, AuthKey: ServerAEADSeed(cfg.Key), + EarlyHandshake: NewHTTPMaskServerEarlyHandshake( + newHTTPMaskEarlyCodecConfig(cfg, ServerAEADSeed(cfg.Key)), + cfg.tableCandidates(), + globalHandshakeReplay.allow, + ), // When upstream fallback is enabled, preserve rejected HTTP requests for the caller. PassThroughOnReject: passThroughOnReject, }) @@ -101,14 +130,25 @@ func DialHTTPMaskTunnel(ctx context.Context, serverAddress string, cfg *Protocol default: return nil, fmt.Errorf("http-mask-mode=%q does not use http tunnel", cfg.HTTPMaskMode) } + var ( + earlyHandshake *httpmask.ClientEarlyHandshake + err error + ) + if upgrade != nil { + earlyHandshake, err = newClientHTTPMaskEarlyHandshake(cfg) + if err != nil { + return nil, err + } + } return httpmask.DialTunnel(ctx, serverAddress, httpmask.TunnelDialOptions{ - Mode: cfg.HTTPMaskMode, - TLSEnabled: cfg.HTTPMaskTLSEnabled, - HostOverride: cfg.HTTPMaskHost, - PathRoot: cfg.HTTPMaskPathRoot, - AuthKey: ClientAEADSeed(cfg.Key), - Upgrade: upgrade, - Multiplex: cfg.HTTPMaskMultiplex, - DialContext: dial, + Mode: cfg.HTTPMaskMode, + TLSEnabled: cfg.HTTPMaskTLSEnabled, + HostOverride: cfg.HTTPMaskHost, + PathRoot: cfg.HTTPMaskPathRoot, + AuthKey: ClientAEADSeed(cfg.Key), + EarlyHandshake: earlyHandshake, + Upgrade: upgrade, + Multiplex: cfg.HTTPMaskMultiplex, + DialContext: dial, }) } diff --git a/transport/sudoku/httpmask_tunnel_test.go b/transport/sudoku/httpmask_tunnel_test.go index 8894882e..01eb3a50 100644 --- a/transport/sudoku/httpmask_tunnel_test.go +++ b/transport/sudoku/httpmask_tunnel_test.go @@ -389,6 +389,68 @@ func TestHTTPMaskTunnel_WS_TCPRoundTrip(t *testing.T) { } } +func TestHTTPMaskTunnel_EarlyHandshake_TCPRoundTrip(t *testing.T) { + modes := []string{"stream", "poll", "ws"} + for _, mode := range modes { + t.Run(mode, func(t *testing.T) { + key := "tunnel-early-" + mode + target := "1.1.1.1:80" + + serverCfg := newTunnelTestTable(t, key) + serverCfg.HTTPMaskMode = mode + + addr, stop, errCh := startTunnelServer(t, serverCfg, func(s *ServerSession) error { + if s.Type != SessionTypeTCP { + return fmt.Errorf("unexpected session type: %v", s.Type) + } + if s.Target != target { + return fmt.Errorf("target mismatch: %s", s.Target) + } + _, _ = s.Conn.Write([]byte("ok")) + return nil + }) + defer stop() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + clientCfg := *serverCfg + clientCfg.ServerAddress = addr + + handshakeCfg := clientCfg + handshakeCfg.DisableHTTPMask = true + tunnelConn, err := DialHTTPMaskTunnel(ctx, clientCfg.ServerAddress, &clientCfg, (&net.Dialer{}).DialContext, func(raw net.Conn) (net.Conn, error) { + return ClientHandshake(raw, &handshakeCfg) + }) + if err != nil { + t.Fatalf("dial tunnel: %v", err) + } + defer tunnelConn.Close() + + addrBuf, err := EncodeAddress(target) + if err != nil { + t.Fatalf("encode addr: %v", err) + } + if err := WriteKIPMessage(tunnelConn, KIPTypeOpenTCP, addrBuf); err != nil { + t.Fatalf("write addr: %v", err) + } + + buf := make([]byte, 2) + if _, err := io.ReadFull(tunnelConn, buf); err != nil { + t.Fatalf("read: %v", err) + } + if string(buf) != "ok" { + t.Fatalf("unexpected payload: %q", buf) + } + + stop() + for err := range errCh { + t.Fatalf("server error: %v", err) + } + }) + } +} + func TestHTTPMaskTunnel_Validation(t *testing.T) { cfg := DefaultConfig() cfg.Key = "k" diff --git a/transport/sudoku/obfs/httpmask/early_handshake.go b/transport/sudoku/obfs/httpmask/early_handshake.go new file mode 100644 index 00000000..54158577 --- /dev/null +++ b/transport/sudoku/obfs/httpmask/early_handshake.go @@ -0,0 +1,174 @@ +package httpmask + +import ( + "encoding/base64" + "errors" + "fmt" + "net" + "net/url" + "strings" +) + +const ( + tunnelEarlyDataQueryKey = "ed" + tunnelEarlyDataHeader = "X-Sudoku-Early" +) + +type ClientEarlyHandshake struct { + RequestPayload []byte + HandleResponse func(payload []byte) error + Ready func() bool + WrapConn func(raw net.Conn) (net.Conn, error) +} + +type TunnelServerEarlyHandshake struct { + Prepare func(payload []byte) (*PreparedServerEarlyHandshake, error) +} + +type PreparedServerEarlyHandshake struct { + ResponsePayload []byte + WrapConn func(raw net.Conn) (net.Conn, error) + UserHash string +} + +type earlyHandshakeMeta interface { + HTTPMaskEarlyHandshakeUserHash() string +} + +type earlyHandshakeConn struct { + net.Conn + userHash string +} + +func (c *earlyHandshakeConn) HTTPMaskEarlyHandshakeUserHash() string { + if c == nil { + return "" + } + return c.userHash +} + +func wrapEarlyHandshakeConn(conn net.Conn, userHash string) net.Conn { + if conn == nil { + return nil + } + return &earlyHandshakeConn{Conn: conn, userHash: userHash} +} + +func EarlyHandshakeUserHash(conn net.Conn) (string, bool) { + if conn == nil { + return "", false + } + v, ok := conn.(earlyHandshakeMeta) + if !ok { + return "", false + } + return v.HTTPMaskEarlyHandshakeUserHash(), true +} + +type authorizeResponse struct { + token string + earlyPayload []byte +} + +func isTunnelTokenByte(c byte) bool { + return (c >= 'a' && c <= 'z') || + (c >= 'A' && c <= 'Z') || + (c >= '0' && c <= '9') || + c == '-' || + c == '_' +} + +func parseAuthorizeResponse(body []byte) (*authorizeResponse, error) { + s := strings.TrimSpace(string(body)) + idx := strings.Index(s, "token=") + if idx < 0 { + return nil, errors.New("missing token") + } + s = s[idx+len("token="):] + if s == "" { + return nil, errors.New("empty token") + } + + var b strings.Builder + for i := 0; i < len(s); i++ { + c := s[i] + if isTunnelTokenByte(c) { + b.WriteByte(c) + continue + } + break + } + token := b.String() + if token == "" { + return nil, errors.New("empty token") + } + + out := &authorizeResponse{token: token} + if earlyLine := findAuthorizeField(body, "ed="); earlyLine != "" { + decoded, err := base64.RawURLEncoding.DecodeString(earlyLine) + if err != nil { + return nil, fmt.Errorf("decode early authorize payload failed: %w", err) + } + out.earlyPayload = decoded + } + return out, nil +} + +func findAuthorizeField(body []byte, prefix string) string { + for _, line := range strings.Split(strings.TrimSpace(string(body)), "\n") { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, prefix) { + return strings.TrimSpace(strings.TrimPrefix(line, prefix)) + } + } + return "" +} + +func setEarlyDataQuery(rawURL string, payload []byte) (string, error) { + if len(payload) == 0 { + return rawURL, nil + } + u, err := url.Parse(rawURL) + if err != nil { + return "", err + } + q := u.Query() + q.Set(tunnelEarlyDataQueryKey, base64.RawURLEncoding.EncodeToString(payload)) + u.RawQuery = q.Encode() + return u.String(), nil +} + +func parseEarlyDataQuery(u *url.URL) ([]byte, error) { + if u == nil { + return nil, nil + } + val := strings.TrimSpace(u.Query().Get(tunnelEarlyDataQueryKey)) + if val == "" { + return nil, nil + } + return base64.RawURLEncoding.DecodeString(val) +} + +func applyEarlyHandshakeOrUpgrade(raw net.Conn, opts TunnelDialOptions) (net.Conn, error) { + out := raw + if opts.EarlyHandshake != nil && opts.EarlyHandshake.WrapConn != nil && (opts.EarlyHandshake.Ready == nil || opts.EarlyHandshake.Ready()) { + wrapped, err := opts.EarlyHandshake.WrapConn(raw) + if err != nil { + return nil, err + } + if wrapped != nil { + out = wrapped + } + return out, nil + } + if opts.Upgrade != nil { + wrapped, err := opts.Upgrade(raw) + if err != nil { + return nil, err + } + if wrapped != nil { + out = wrapped + } + } + return out, nil +} diff --git a/transport/sudoku/obfs/httpmask/tunnel.go b/transport/sudoku/obfs/httpmask/tunnel.go index 20981c39..a100c620 100644 --- a/transport/sudoku/obfs/httpmask/tunnel.go +++ b/transport/sudoku/obfs/httpmask/tunnel.go @@ -72,6 +72,10 @@ type TunnelDialOptions struct { // AuthKey enables short-term HMAC auth for HTTP tunnel requests (anti-probing). // When set (non-empty), each HTTP request carries an Authorization bearer token derived from AuthKey. AuthKey string + // EarlyHandshake folds the protocol handshake into the HTTP/WS setup round trip. + // When the server accepts the early payload, DialTunnel returns a conn that is already post-handshake. + // When the server does not echo early data, DialTunnel falls back to Upgrade. + EarlyHandshake *ClientEarlyHandshake // Upgrade optionally wraps the raw tunnel conn and/or writes a small prelude before DialTunnel returns. // It is called with the raw tunnel conn; if it returns a non-nil conn, that conn is returned by DialTunnel. Upgrade func(raw net.Conn) (net.Conn, error) @@ -225,30 +229,11 @@ func canonicalHeaderHost(urlHost, scheme string) string { } func parseTunnelToken(body []byte) (string, error) { - s := strings.TrimSpace(string(body)) - idx := strings.Index(s, "token=") - if idx < 0 { - return "", errors.New("missing token") + resp, err := parseAuthorizeResponse(body) + if err != nil { + return "", err } - s = s[idx+len("token="):] - if s == "" { - return "", errors.New("empty token") - } - // Token is base64.RawURLEncoding (A-Z a-z 0-9 - _). Strip any trailing bytes (e.g. from CDN compression). - var b strings.Builder - for i := 0; i < len(s); i++ { - c := s[i] - if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '-' || c == '_' { - b.WriteByte(c) - continue - } - break - } - token := b.String() - if token == "" { - return "", errors.New("empty token") - } - return token, nil + return resp.token, nil } type httpClientTarget struct { @@ -353,6 +338,13 @@ func dialSessionWithClient(ctx context.Context, client *http.Client, target http auth := newTunnelAuth(opts.AuthKey, 0) authorizeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/session")}).String() + if opts.EarlyHandshake != nil && len(opts.EarlyHandshake.RequestPayload) > 0 { + var err error + authorizeURL, err = setEarlyDataQuery(authorizeURL, opts.EarlyHandshake.RequestPayload) + if err != nil { + return nil, err + } + } var bodyBytes []byte for attempt := 0; ; attempt++ { @@ -410,13 +402,19 @@ func dialSessionWithClient(ctx context.Context, client *http.Client, target http break } - token, err := parseTunnelToken(bodyBytes) + authResp, err := parseAuthorizeResponse(bodyBytes) if err != nil { return nil, fmt.Errorf("%s authorize failed: %q", mode, strings.TrimSpace(string(bodyBytes))) } + token := authResp.token if token == "" { return nil, fmt.Errorf("%s authorize empty token", mode) } + if opts.EarlyHandshake != nil && len(authResp.earlyPayload) > 0 && opts.EarlyHandshake.HandleResponse != nil { + if err := opts.EarlyHandshake.HandleResponse(authResp.earlyPayload); err != nil { + return nil, err + } + } pushURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/api/v1/upload"), RawQuery: "token=" + url.QueryEscape(token)}).String() pullURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/stream"), RawQuery: "token=" + url.QueryEscape(token)}).String() @@ -671,16 +669,10 @@ func dialStreamSplitWithClient(ctx context.Context, client *http.Client, target if c == nil { return nil, fmt.Errorf("failed to build stream split conn") } - outConn := net.Conn(c) - if opts.Upgrade != nil { - upgraded, err := opts.Upgrade(c) - if err != nil { - _ = c.Close() - return nil, err - } - if upgraded != nil { - outConn = upgraded - } + outConn, err := applyEarlyHandshakeOrUpgrade(c, opts) + if err != nil { + _ = c.Close() + return nil, err } return outConn, nil } @@ -694,16 +686,10 @@ func dialStreamSplit(ctx context.Context, serverAddress string, opts TunnelDialO if c == nil { return nil, fmt.Errorf("failed to build stream split conn") } - outConn := net.Conn(c) - if opts.Upgrade != nil { - upgraded, err := opts.Upgrade(c) - if err != nil { - _ = c.Close() - return nil, err - } - if upgraded != nil { - outConn = upgraded - } + outConn, err := applyEarlyHandshakeOrUpgrade(c, opts) + if err != nil { + _ = c.Close() + return nil, err } return outConn, nil } @@ -1120,16 +1106,10 @@ func dialPollWithClient(ctx context.Context, client *http.Client, target httpCli if c == nil { return nil, fmt.Errorf("failed to build poll conn") } - outConn := net.Conn(c) - if opts.Upgrade != nil { - upgraded, err := opts.Upgrade(c) - if err != nil { - _ = c.Close() - return nil, err - } - if upgraded != nil { - outConn = upgraded - } + outConn, err := applyEarlyHandshakeOrUpgrade(c, opts) + if err != nil { + _ = c.Close() + return nil, err } return outConn, nil } @@ -1143,16 +1123,10 @@ func dialPoll(ctx context.Context, serverAddress string, opts TunnelDialOptions) if c == nil { return nil, fmt.Errorf("failed to build poll conn") } - outConn := net.Conn(c) - if opts.Upgrade != nil { - upgraded, err := opts.Upgrade(c) - if err != nil { - _ = c.Close() - return nil, err - } - if upgraded != nil { - outConn = upgraded - } + outConn, err := applyEarlyHandshakeOrUpgrade(c, opts) + if err != nil { + _ = c.Close() + return nil, err } return outConn, nil } @@ -1528,6 +1502,8 @@ type TunnelServerOptions struct { PullReadTimeout time.Duration // SessionTTL is a best-effort TTL to prevent leaked sessions. 0 uses a conservative default. SessionTTL time.Duration + // EarlyHandshake optionally folds the protocol handshake into the initial HTTP/WS round trip. + EarlyHandshake *TunnelServerEarlyHandshake } type TunnelServer struct { @@ -1538,6 +1514,7 @@ type TunnelServer struct { pullReadTimeout time.Duration sessionTTL time.Duration + earlyHandshake *TunnelServerEarlyHandshake mu sync.Mutex sessions map[string]*tunnelSession @@ -1570,6 +1547,7 @@ func NewTunnelServer(opts TunnelServerOptions) *TunnelServer { passThroughOnReject: opts.PassThroughOnReject, pullReadTimeout: timeout, sessionTTL: ttl, + earlyHandshake: opts.EarlyHandshake, sessions: make(map[string]*tunnelSession), } } @@ -1925,9 +1903,12 @@ func (s *TunnelServer) handleStream(rawConn net.Conn, req *httpRequestHeader, he switch strings.ToUpper(req.method) { case http.MethodGet: - // Stream split-session: GET /session (no token) => token + start tunnel on a server-side pipe. if token == "" && path == "/session" { - return s.sessionAuthorize(rawConn) + earlyPayload, err := parseEarlyDataQuery(u) + if err != nil { + return rejectOrReply(http.StatusBadRequest, "bad request") + } + return s.sessionAuthorize(rawConn, earlyPayload) } // Stream split-session: GET /stream?token=... => downlink poll. if token != "" && path == "/stream" { @@ -2045,10 +2026,18 @@ func writeSimpleHTTPResponse(w io.Writer, code int, body string) error { func writeTokenHTTPResponse(w io.Writer, token string) error { token = strings.TrimRight(token, "\r\n") - // Use application/octet-stream to avoid CDN auto-compression (e.g. brotli) breaking clients that expect a plain token string. + return writeTokenHTTPResponseWithEarlyData(w, token, nil) +} + +func writeTokenHTTPResponseWithEarlyData(w io.Writer, token string, earlyPayload []byte) error { + token = strings.TrimRight(token, "\r\n") + body := "token=" + token + if len(earlyPayload) > 0 { + body += "\ned=" + base64.RawURLEncoding.EncodeToString(earlyPayload) + } _, err := io.WriteString(w, - fmt.Sprintf("HTTP/1.1 200 OK\r\nContent-Type: application/octet-stream\r\nCache-Control: no-store\r\nPragma: no-cache\r\nContent-Length: %d\r\nConnection: close\r\n\r\ntoken=%s", - len("token=")+len(token), token)) + fmt.Sprintf("HTTP/1.1 200 OK\r\nContent-Type: application/octet-stream\r\nCache-Control: no-store\r\nPragma: no-cache\r\nContent-Length: %d\r\nConnection: close\r\n\r\n%s", + len(body), body)) return err } @@ -2088,7 +2077,11 @@ func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, head switch strings.ToUpper(req.method) { case http.MethodGet: if token == "" && path == "/session" { - return s.sessionAuthorize(rawConn) + earlyPayload, err := parseEarlyDataQuery(u) + if err != nil { + return rejectOrReply(http.StatusBadRequest, "bad request") + } + return s.sessionAuthorize(rawConn, earlyPayload) } if token != "" && path == "/stream" { if s.passThroughOnReject && !s.sessionHas(token) { @@ -2128,7 +2121,7 @@ func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, head } } -func (s *TunnelServer) sessionAuthorize(rawConn net.Conn) (HandleResult, net.Conn, error) { +func (s *TunnelServer) sessionAuthorize(rawConn net.Conn, earlyPayload []byte) (HandleResult, net.Conn, error) { token, err := newSessionToken() if err != nil { _ = writeSimpleHTTPResponse(rawConn, http.StatusInternalServerError, "internal error") @@ -2137,6 +2130,37 @@ func (s *TunnelServer) sessionAuthorize(rawConn net.Conn) (HandleResult, net.Con } c1, c2 := newHalfPipe() + outConn := net.Conn(c1) + var responsePayload []byte + var userHash string + if len(earlyPayload) > 0 && s.earlyHandshake != nil && s.earlyHandshake.Prepare != nil { + prepared, err := s.earlyHandshake.Prepare(earlyPayload) + if err != nil { + _ = c1.Close() + _ = c2.Close() + if s.passThroughOnReject { + return HandlePassThrough, newRejectedPreBufferedConn(rawConn, nil), nil + } + _ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found") + _ = rawConn.Close() + return HandleDone, nil, nil + } + responsePayload = prepared.ResponsePayload + userHash = prepared.UserHash + if prepared.WrapConn != nil { + wrapped, err := prepared.WrapConn(c1) + if err != nil { + _ = c1.Close() + _ = c2.Close() + _ = writeSimpleHTTPResponse(rawConn, http.StatusInternalServerError, "internal error") + _ = rawConn.Close() + return HandleDone, nil, nil + } + if wrapped != nil { + outConn = wrapEarlyHandshakeConn(wrapped, userHash) + } + } + } s.mu.Lock() s.sessions[token] = &tunnelSession{conn: c2, lastActive: time.Now()} @@ -2144,9 +2168,9 @@ func (s *TunnelServer) sessionAuthorize(rawConn net.Conn) (HandleResult, net.Con go s.reapLater(token) - _ = writeTokenHTTPResponse(rawConn, token) + _ = writeTokenHTTPResponseWithEarlyData(rawConn, token, responsePayload) _ = rawConn.Close() - return HandleStartTunnel, c1, nil + return HandleStartTunnel, outConn, nil } func newSessionToken() (string, error) { diff --git a/transport/sudoku/obfs/httpmask/tunnel_ws.go b/transport/sudoku/obfs/httpmask/tunnel_ws.go index e1299e3d..8ef8d5c3 100644 --- a/transport/sudoku/obfs/httpmask/tunnel_ws.go +++ b/transport/sudoku/obfs/httpmask/tunnel_ws.go @@ -2,6 +2,7 @@ package httpmask import ( "context" + "encoding/base64" "fmt" "io" mrand "math/rand" @@ -115,6 +116,16 @@ func dialWS(ctx context.Context, serverAddress string, opts TunnelDialOptions) ( Host: urlHost, Path: joinPathRoot(opts.PathRoot, "/ws"), } + if opts.EarlyHandshake != nil && len(opts.EarlyHandshake.RequestPayload) > 0 { + rawURL, err := setEarlyDataQuery(u.String(), opts.EarlyHandshake.RequestPayload) + if err != nil { + return nil, err + } + u, err = url.Parse(rawURL) + if err != nil { + return nil, err + } + } header := make(stdhttp.Header) applyWSHeaders(header, headerHost) @@ -132,6 +143,16 @@ func dialWS(ctx context.Context, serverAddress string, opts TunnelDialOptions) ( d := ws.Dialer{ Host: headerHost, Header: ws.HandshakeHeaderHTTP(header), + OnHeader: func(key, value []byte) error { + if !strings.EqualFold(string(key), tunnelEarlyDataHeader) || opts.EarlyHandshake == nil || opts.EarlyHandshake.HandleResponse == nil { + return nil + } + decoded, err := base64.RawURLEncoding.DecodeString(strings.TrimSpace(string(value))) + if err != nil { + return err + } + return opts.EarlyHandshake.HandleResponse(decoded) + }, NetDial: func(dialCtx context.Context, network, addr string) (net.Conn, error) { if addr == urlHost { addr = dialAddr @@ -161,16 +182,10 @@ func dialWS(ctx context.Context, serverAddress string, opts TunnelDialOptions) ( } wsConn := newWSStreamConn(conn, ws.StateClientSide) - if opts.Upgrade == nil { - return wsConn, nil - } - upgraded, err := opts.Upgrade(wsConn) + upgraded, err := applyEarlyHandshakeOrUpgrade(wsConn, opts) if err != nil { _ = wsConn.Close() return nil, err } - if upgraded != nil { - return upgraded, nil - } - return wsConn, nil + return upgraded, nil } diff --git a/transport/sudoku/obfs/httpmask/tunnel_ws_server.go b/transport/sudoku/obfs/httpmask/tunnel_ws_server.go index 3e79e58a..b17b1ded 100644 --- a/transport/sudoku/obfs/httpmask/tunnel_ws_server.go +++ b/transport/sudoku/obfs/httpmask/tunnel_ws_server.go @@ -1,6 +1,7 @@ package httpmask import ( + "encoding/base64" "net" "net/http" "net/url" @@ -63,15 +64,46 @@ func (s *TunnelServer) handleWS(rawConn net.Conn, req *httpRequestHeader, header return rejectOrReply(http.StatusNotFound, "not found") } + earlyPayload, err := parseEarlyDataQuery(u) + if err != nil { + return rejectOrReply(http.StatusBadRequest, "bad request") + } + var prepared *PreparedServerEarlyHandshake + if len(earlyPayload) > 0 && s.earlyHandshake != nil && s.earlyHandshake.Prepare != nil { + prepared, err = s.earlyHandshake.Prepare(earlyPayload) + if err != nil { + return rejectOrReply(http.StatusNotFound, "not found") + } + } + prefix := make([]byte, 0, len(headerBytes)+len(buffered)) prefix = append(prefix, headerBytes...) prefix = append(prefix, buffered...) wsConnRaw := newPreBufferedConn(rawConn, prefix) - if _, err := ws.Upgrade(wsConnRaw); err != nil { + upgrader := ws.Upgrader{} + if prepared != nil && len(prepared.ResponsePayload) > 0 { + upgrader.OnBeforeUpgrade = func() (ws.HandshakeHeader, error) { + h := http.Header{} + h.Set(tunnelEarlyDataHeader, base64.RawURLEncoding.EncodeToString(prepared.ResponsePayload)) + return ws.HandshakeHeaderHTTP(h), nil + } + } + if _, err := upgrader.Upgrade(wsConnRaw); err != nil { _ = rawConn.Close() return HandleDone, nil, nil } - return HandleStartTunnel, newWSStreamConn(wsConnRaw, ws.StateServerSide), nil + outConn := net.Conn(newWSStreamConn(wsConnRaw, ws.StateServerSide)) + if prepared != nil && prepared.WrapConn != nil { + wrapped, err := prepared.WrapConn(outConn) + if err != nil { + _ = outConn.Close() + return HandleDone, nil, nil + } + if wrapped != nil { + outConn = wrapEarlyHandshakeConn(wrapped, prepared.UserHash) + } + } + return HandleStartTunnel, outConn, nil } diff --git a/transport/sudoku/obfs/sudoku/packed.go b/transport/sudoku/obfs/sudoku/packed.go index 346314a3..0edf4f32 100644 --- a/transport/sudoku/obfs/sudoku/packed.go +++ b/transport/sudoku/obfs/sudoku/packed.go @@ -11,35 +11,35 @@ import ( ) const ( - // 每次从 RNG 获取批量随机数的缓存大小,减少 RNG 函数调用开销 RngBatchSize = 128 + + packedProtectedPrefixBytes = 14 ) -// 1. 使用 12字节->16组 的块处理优化 Write (减少循环开销) -// 2. 使用整数阈值随机概率判断 Padding,与纯 Sudoku 保持流量特征一致 -// 3. Read 使用 copy 移动避免底层数组泄漏 +// PackedConn encodes traffic with the packed Sudoku layout while preserving +// the same padding model as the regular connection. type PackedConn struct { net.Conn table *Table reader *bufio.Reader - // 读缓冲 + // Read-side buffers. rawBuf []byte - pendingData []byte // 解码后尚未被 Read 取走的字节 + pendingData []byte - // 写缓冲与状态 + // Write-side state. writeMu sync.Mutex writeBuf []byte - bitBuf uint64 // 暂存的位数据 - bitCount int // 暂存的位数 + bitBuf uint64 + bitCount int - // 读状态 + // Read-side bit accumulator. readBitBuf uint64 readBits int - // 随机数与填充控制 - 使用整数阈值随机,与 Conn 一致 + // Padding selection matches Conn's threshold-based model. rng *rand.Rand - paddingThreshold uint64 // 与 Conn 保持一致的随机概率模型 + paddingThreshold uint64 padMarker byte padPool []byte } @@ -95,7 +95,6 @@ func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn { return pc } -// maybeAddPadding 内联辅助:根据概率阈值插入 padding func (pc *PackedConn) maybeAddPadding(out []byte) []byte { if shouldPad(pc.rng, pc.paddingThreshold) { out = append(out, pc.getPaddingByte()) @@ -103,7 +102,73 @@ func (pc *PackedConn) maybeAddPadding(out []byte) []byte { return out } -// Write 极致优化版 - 批量处理 12 字节 +func (pc *PackedConn) appendGroup(out []byte, group byte) []byte { + out = pc.maybeAddPadding(out) + return append(out, pc.encodeGroup(group)) +} + +func (pc *PackedConn) appendForcedPadding(out []byte) []byte { + return append(out, pc.getPaddingByte()) +} + +func (pc *PackedConn) nextProtectedPrefixGap() int { + return 1 + pc.rng.Intn(2) +} + +func (pc *PackedConn) writeProtectedPrefix(out []byte, p []byte) ([]byte, int) { + if len(p) == 0 { + return out, 0 + } + + limit := len(p) + if limit > packedProtectedPrefixBytes { + limit = packedProtectedPrefixBytes + } + + for padCount := 0; padCount < 1+pc.rng.Intn(2); padCount++ { + out = pc.appendForcedPadding(out) + } + + gap := pc.nextProtectedPrefixGap() + effective := 0 + for i := 0; i < limit; i++ { + pc.bitBuf = (pc.bitBuf << 8) | uint64(p[i]) + pc.bitCount += 8 + for pc.bitCount >= 6 { + pc.bitCount -= 6 + group := byte(pc.bitBuf >> pc.bitCount) + if pc.bitCount == 0 { + pc.bitBuf = 0 + } else { + pc.bitBuf &= (1 << pc.bitCount) - 1 + } + out = pc.appendGroup(out, group&0x3F) + } + + effective++ + if effective >= gap { + out = pc.appendForcedPadding(out) + effective = 0 + gap = pc.nextProtectedPrefixGap() + } + } + + return out, limit +} + +func (pc *PackedConn) drainPendingData(dst []byte) int { + n := copy(dst, pc.pendingData) + if n == len(pc.pendingData) { + pc.pendingData = pc.pendingData[:0] + return n + } + + remaining := len(pc.pendingData) - n + copy(pc.pendingData, pc.pendingData[n:]) + pc.pendingData = pc.pendingData[:remaining] + return n +} + func (pc *PackedConn) Write(p []byte) (int, error) { if len(p) == 0 { return 0, nil @@ -112,20 +177,19 @@ func (pc *PackedConn) Write(p []byte) (int, error) { pc.writeMu.Lock() defer pc.writeMu.Unlock() - // 1. 预分配内存,避免 append 导致的多次扩容 - // 预估:原数据 * 1.5 (4/3 + padding 余量) needed := len(p)*3/2 + 32 if cap(pc.writeBuf) < needed { pc.writeBuf = make([]byte, 0, needed) } out := pc.writeBuf[:0] - i := 0 + var prefixN int + out, prefixN = pc.writeProtectedPrefix(out, p) + + i := prefixN n := len(p) - // 2. 头部对齐处理 (Slow Path) for pc.bitCount > 0 && i < n { - out = pc.maybeAddPadding(out) b := p[i] i++ pc.bitBuf = (pc.bitBuf << 8) | uint64(b) @@ -138,14 +202,11 @@ func (pc *PackedConn) Write(p []byte) (int, error) { } else { pc.bitBuf &= (1 << pc.bitCount) - 1 } - out = pc.maybeAddPadding(out) - out = append(out, pc.encodeGroup(group&0x3F)) + out = pc.appendGroup(out, group&0x3F) } } - // 3. 极速批量处理 (Fast Path) - 每次处理 12 字节 → 生成 16 个编码组 for i+11 < n { - // 处理 4 组,每组 3 字节 for batch := 0; batch < 4; batch++ { b1, b2, b3 := p[i], p[i+1], p[i+2] i += 3 @@ -155,19 +216,13 @@ func (pc *PackedConn) Write(p []byte) (int, error) { g3 := ((b2 & 0x0F) << 2) | ((b3 >> 6) & 0x03) g4 := b3 & 0x3F - // 每个组之前都有概率插入 padding - out = pc.maybeAddPadding(out) - out = append(out, pc.encodeGroup(g1)) - out = pc.maybeAddPadding(out) - out = append(out, pc.encodeGroup(g2)) - out = pc.maybeAddPadding(out) - out = append(out, pc.encodeGroup(g3)) - out = pc.maybeAddPadding(out) - out = append(out, pc.encodeGroup(g4)) + out = pc.appendGroup(out, g1) + out = pc.appendGroup(out, g2) + out = pc.appendGroup(out, g3) + out = pc.appendGroup(out, g4) } } - // 4. 处理剩余的 3 字节块 for i+2 < n { b1, b2, b3 := p[i], p[i+1], p[i+2] i += 3 @@ -177,17 +232,12 @@ func (pc *PackedConn) Write(p []byte) (int, error) { g3 := ((b2 & 0x0F) << 2) | ((b3 >> 6) & 0x03) g4 := b3 & 0x3F - out = pc.maybeAddPadding(out) - out = append(out, pc.encodeGroup(g1)) - out = pc.maybeAddPadding(out) - out = append(out, pc.encodeGroup(g2)) - out = pc.maybeAddPadding(out) - out = append(out, pc.encodeGroup(g3)) - out = pc.maybeAddPadding(out) - out = append(out, pc.encodeGroup(g4)) + out = pc.appendGroup(out, g1) + out = pc.appendGroup(out, g2) + out = pc.appendGroup(out, g3) + out = pc.appendGroup(out, g4) } - // 5. 尾部处理 (Tail Path) - 处理剩余的 1 或 2 个字节 for ; i < n; i++ { b := p[i] pc.bitBuf = (pc.bitBuf << 8) | uint64(b) @@ -200,35 +250,28 @@ func (pc *PackedConn) Write(p []byte) (int, error) { } else { pc.bitBuf &= (1 << pc.bitCount) - 1 } - out = pc.maybeAddPadding(out) - out = append(out, pc.encodeGroup(group&0x3F)) + out = pc.appendGroup(out, group&0x3F) } } - // 6. 处理残留位 if pc.bitCount > 0 { - out = pc.maybeAddPadding(out) group := byte(pc.bitBuf << (6 - pc.bitCount)) pc.bitBuf = 0 pc.bitCount = 0 - out = append(out, pc.encodeGroup(group&0x3F)) + out = pc.appendGroup(out, group&0x3F) out = append(out, pc.padMarker) } - // 尾部可能添加 padding out = pc.maybeAddPadding(out) - // 发送数据 if len(out) > 0 { - _, err := pc.Conn.Write(out) pc.writeBuf = out[:0] - return len(p), err + return len(p), writeFull(pc.Conn, out) } pc.writeBuf = out[:0] return len(p), nil } -// Flush 处理最后不足 6 bit 的情况 func (pc *PackedConn) Flush() error { pc.writeMu.Lock() defer pc.writeMu.Unlock() @@ -243,38 +286,34 @@ func (pc *PackedConn) Flush() error { out = append(out, pc.padMarker) } - // 尾部随机添加 padding out = pc.maybeAddPadding(out) if len(out) > 0 { - _, err := pc.Conn.Write(out) pc.writeBuf = out[:0] - return err + return writeFull(pc.Conn, out) } return nil } -// Read 优化版:减少切片操作,避免内存泄漏 -func (pc *PackedConn) Read(p []byte) (int, error) { - // 1. 优先返回待处理区的数据 - if len(pc.pendingData) > 0 { - n := copy(p, pc.pendingData) - if n == len(pc.pendingData) { - pc.pendingData = pc.pendingData[:0] - } else { - // 优化:移动剩余数据到数组头部,避免切片指向中间导致内存泄漏 - remaining := len(pc.pendingData) - n - copy(pc.pendingData, pc.pendingData[n:]) - pc.pendingData = pc.pendingData[:remaining] +func writeFull(w io.Writer, b []byte) error { + for len(b) > 0 { + n, err := w.Write(b) + if err != nil { + return err } - return n, nil + b = b[n:] + } + return nil +} + +func (pc *PackedConn) Read(p []byte) (int, error) { + if len(pc.pendingData) > 0 { + return pc.drainPendingData(p), nil } - // 2. 循环读取直到解出数据或出错 for { nr, rErr := pc.reader.Read(pc.rawBuf) if nr > 0 { - // 缓存频繁访问的变量 rBuf := pc.readBitBuf rBits := pc.readBits padMarker := pc.padMarker @@ -324,24 +363,13 @@ func (pc *PackedConn) Read(p []byte) (int, error) { } } - // 3. 返回解码后的数据 - 优化:避免底层数组泄漏 - n := copy(p, pc.pendingData) - if n == len(pc.pendingData) { - pc.pendingData = pc.pendingData[:0] - } else { - remaining := len(pc.pendingData) - n - copy(pc.pendingData, pc.pendingData[n:]) - pc.pendingData = pc.pendingData[:remaining] - } - return n, nil + return pc.drainPendingData(p), nil } -// getPaddingByte 从 Pool 中随机取 Padding 字节 func (pc *PackedConn) getPaddingByte() byte { return pc.padPool[pc.rng.Intn(len(pc.padPool))] } -// encodeGroup 编码 6-bit 组 func (pc *PackedConn) encodeGroup(group byte) byte { return pc.table.layout.encodeGroup(group) } diff --git a/transport/sudoku/obfs/sudoku/packed_prefix_test.go b/transport/sudoku/obfs/sudoku/packed_prefix_test.go new file mode 100644 index 00000000..f041c0f5 --- /dev/null +++ b/transport/sudoku/obfs/sudoku/packed_prefix_test.go @@ -0,0 +1,91 @@ +package sudoku + +import ( + "bytes" + "io" + "math/rand" + "net" + "testing" + "time" +) + +type mockConn struct { + readBuf []byte + writeBuf []byte +} + +func (c *mockConn) Read(p []byte) (int, error) { + if len(c.readBuf) == 0 { + return 0, io.EOF + } + n := copy(p, c.readBuf) + c.readBuf = c.readBuf[n:] + return n, nil +} + +func (c *mockConn) Write(p []byte) (int, error) { + c.writeBuf = append(c.writeBuf, p...) + return len(p), nil +} + +func (c *mockConn) Close() error { return nil } +func (c *mockConn) LocalAddr() net.Addr { return nil } +func (c *mockConn) RemoteAddr() net.Addr { return nil } +func (c *mockConn) SetDeadline(time.Time) error { return nil } +func (c *mockConn) SetReadDeadline(time.Time) error { return nil } +func (c *mockConn) SetWriteDeadline(time.Time) error { return nil } + +func TestPackedConn_ProtectedPrefixPadding(t *testing.T) { + table := NewTable("packed-prefix-seed", "prefer_ascii") + mock := &mockConn{} + writer := NewPackedConn(mock, table, 0, 0) + writer.rng = rand.New(rand.NewSource(1)) + + payload := bytes.Repeat([]byte{0}, 32) + if _, err := writer.Write(payload); err != nil { + t.Fatalf("write: %v", err) + } + + wire := append([]byte(nil), mock.writeBuf...) + if len(wire) < 20 { + t.Fatalf("wire too short: %d", len(wire)) + } + + firstHint := -1 + nonHintCount := 0 + maxHintRun := 0 + currentHintRun := 0 + for i, b := range wire[:20] { + if table.layout.isHint(b) { + if firstHint == -1 { + firstHint = i + } + currentHintRun++ + if currentHintRun > maxHintRun { + maxHintRun = currentHintRun + } + continue + } + nonHintCount++ + currentHintRun = 0 + } + + if firstHint < 1 || firstHint > 2 { + t.Fatalf("expected 1-2 leading padding bytes, first hint index=%d", firstHint) + } + if nonHintCount < 6 { + t.Fatalf("expected dense prefix padding, got only %d non-hint bytes in first 20", nonHintCount) + } + if maxHintRun > 3 { + t.Fatalf("prefix still exposes long hint run: %d", maxHintRun) + } + + reader := NewPackedConn(&mockConn{readBuf: wire}, table, 0, 0) + decoded := make([]byte, len(payload)) + if _, err := io.ReadFull(reader, decoded); err != nil { + t.Fatalf("read back: %v", err) + } + if !bytes.Equal(decoded, payload) { + t.Fatalf("roundtrip mismatch") + } +}