chore: reduce the inherent 1rtt in httpmask mode for sudoku (#2610)

This commit is contained in:
saba-futai
2026-03-11 00:00:32 +08:00
committed by GitHub
parent 6517d2a9b2
commit b3c81602a2
12 changed files with 1188 additions and 210 deletions

View File

@@ -5,6 +5,7 @@ import (
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"crypto/hmac" "crypto/hmac"
"crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/binary" "encoding/binary"
"errors" "errors"
@@ -12,6 +13,7 @@ import (
"io" "io"
"net" "net"
"sync" "sync"
"sync/atomic"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
) )
@@ -55,13 +57,15 @@ type RecordConn struct {
recvAEADEpoch uint32 recvAEADEpoch uint32
// Send direction state. // Send direction state.
sendEpoch uint32 sendEpoch uint32
sendSeq uint64 sendSeq uint64
sendBytes int64 sendBytes int64
sendEpochUpdates uint32
// Receive direction state. // Receive direction state.
recvEpoch uint32 recvEpoch uint32
recvSeq uint64 recvSeq uint64
recvInitialized bool
readBuf bytes.Buffer readBuf bytes.Buffer
@@ -105,6 +109,9 @@ func NewRecordConn(conn net.Conn, method string, baseSend, baseRecv []byte) (*Re
} }
rc := &RecordConn{Conn: conn, method: method} rc := &RecordConn{Conn: conn, method: method}
rc.keys = recordKeys{baseSend: cloneBytes(baseSend), baseRecv: cloneBytes(baseRecv)} rc.keys = recordKeys{baseSend: cloneBytes(baseSend), baseRecv: cloneBytes(baseRecv)}
if err := rc.resetTrafficState(); err != nil {
return nil, err
}
return rc, nil return rc, nil
} }
@@ -127,11 +134,9 @@ func (c *RecordConn) Rekey(baseSend, baseRecv []byte) error {
defer c.writeMu.Unlock() defer c.writeMu.Unlock()
c.keys = recordKeys{baseSend: cloneBytes(baseSend), baseRecv: cloneBytes(baseRecv)} c.keys = recordKeys{baseSend: cloneBytes(baseSend), baseRecv: cloneBytes(baseRecv)}
c.sendEpoch = 0 if err := c.resetTrafficState(); err != nil {
c.sendSeq = 0 return err
c.sendBytes = 0 }
c.recvEpoch = 0
c.recvSeq = 0
c.readBuf.Reset() c.readBuf.Reset()
c.sendAEAD = nil c.sendAEAD = nil
@@ -141,6 +146,21 @@ func (c *RecordConn) Rekey(baseSend, baseRecv []byte) error {
return nil return nil
} }
func (c *RecordConn) resetTrafficState() error {
sendEpoch, sendSeq, err := randomRecordCounters()
if err != nil {
return fmt.Errorf("initialize record counters: %w", err)
}
c.sendEpoch = sendEpoch
c.sendSeq = sendSeq
c.sendBytes = 0
c.sendEpochUpdates = 0
c.recvEpoch = 0
c.recvSeq = 0
c.recvInitialized = false
return nil
}
func normalizeAEADMethod(method string) string { func normalizeAEADMethod(method string) string {
switch method { switch method {
case "", "chacha20-poly1305": case "", "chacha20-poly1305":
@@ -166,6 +186,44 @@ func cloneBytes(b []byte) []byte {
return append([]byte(nil), b...) return append([]byte(nil), b...)
} }
func randomRecordCounters() (uint32, uint64, error) {
epoch, err := randomNonZeroUint32()
if err != nil {
return 0, 0, err
}
seq, err := randomNonZeroUint64()
if err != nil {
return 0, 0, err
}
return epoch, seq, nil
}
func randomNonZeroUint32() (uint32, error) {
var b [4]byte
for {
if _, err := io.ReadFull(rand.Reader, b[:]); err != nil {
return 0, err
}
v := binary.BigEndian.Uint32(b[:])
if v != 0 && v != ^uint32(0) {
return v, nil
}
}
}
func randomNonZeroUint64() (uint64, error) {
var b [8]byte
for {
if _, err := io.ReadFull(rand.Reader, b[:]); err != nil {
return 0, err
}
v := binary.BigEndian.Uint64(b[:])
if v != 0 && v != ^uint64(0) {
return v, nil
}
}
}
func (c *RecordConn) newAEADFor(base []byte, epoch uint32) (cipher.AEAD, error) { func (c *RecordConn) newAEADFor(base []byte, epoch uint32) (cipher.AEAD, error) {
if c.method == "none" { if c.method == "none" {
return nil, nil return nil, nil
@@ -209,17 +267,49 @@ func deriveEpochKey(base []byte, epoch uint32, method string) []byte {
return mac.Sum(nil) return mac.Sum(nil)
} }
func (c *RecordConn) maybeBumpSendEpochLocked(addedPlain int) { func (c *RecordConn) maybeBumpSendEpochLocked(addedPlain int) error {
if KeyUpdateAfterBytes <= 0 || c.method == "none" { ku := atomic.LoadInt64(&KeyUpdateAfterBytes)
return if ku <= 0 || c.method == "none" {
return nil
} }
c.sendBytes += int64(addedPlain) c.sendBytes += int64(addedPlain)
threshold := KeyUpdateAfterBytes * int64(c.sendEpoch+1) threshold := ku * int64(c.sendEpochUpdates+1)
if c.sendBytes < threshold { if c.sendBytes < threshold {
return return nil
} }
c.sendEpoch++ c.sendEpoch++
c.sendSeq = 0 c.sendEpochUpdates++
nextSeq, err := randomNonZeroUint64()
if err != nil {
return fmt.Errorf("rotate record seq: %w", err)
}
c.sendSeq = nextSeq
return nil
}
func (c *RecordConn) validateRecvPosition(epoch uint32, seq uint64) error {
if !c.recvInitialized {
return nil
}
if epoch < c.recvEpoch {
return fmt.Errorf("replayed epoch: got %d want >=%d", epoch, c.recvEpoch)
}
if epoch == c.recvEpoch && seq != c.recvSeq {
return fmt.Errorf("out of order: epoch=%d got=%d want=%d", epoch, seq, c.recvSeq)
}
if epoch > c.recvEpoch {
const maxJump = 8
if epoch-c.recvEpoch > maxJump {
return fmt.Errorf("epoch jump too large: got=%d want<=%d", epoch-c.recvEpoch, maxJump)
}
}
return nil
}
func (c *RecordConn) markRecvPosition(epoch uint32, seq uint64) {
c.recvEpoch = epoch
c.recvSeq = seq + 1
c.recvInitialized = true
} }
func (c *RecordConn) Write(p []byte) (int, error) { func (c *RecordConn) Write(p []byte) (int, error) {
@@ -282,7 +372,9 @@ func (c *RecordConn) Write(p []byte) (int, error) {
} }
total += n total += n
c.maybeBumpSendEpochLocked(n) if err := c.maybeBumpSendEpochLocked(n); err != nil {
return total, err
}
} }
return total, nil return total, nil
} }
@@ -324,31 +416,17 @@ func (c *RecordConn) Read(p []byte) (int, error) {
epoch := binary.BigEndian.Uint32(header[:4]) epoch := binary.BigEndian.Uint32(header[:4])
seq := binary.BigEndian.Uint64(header[4:]) seq := binary.BigEndian.Uint64(header[4:])
if epoch < c.recvEpoch { if err := c.validateRecvPosition(epoch, seq); err != nil {
return 0, fmt.Errorf("replayed epoch: got %d want >=%d", epoch, c.recvEpoch) return 0, err
}
if epoch == c.recvEpoch && seq != c.recvSeq {
return 0, fmt.Errorf("out of order: epoch=%d got=%d want=%d", epoch, seq, c.recvSeq)
}
if epoch > c.recvEpoch {
const maxJump = 8
if epoch-c.recvEpoch > maxJump {
return 0, fmt.Errorf("epoch jump too large: got=%d want<=%d", epoch-c.recvEpoch, maxJump)
}
c.recvEpoch = epoch
c.recvSeq = 0
if seq != 0 {
return 0, fmt.Errorf("out of order: epoch advanced to %d but seq=%d", epoch, seq)
}
} }
if c.recvAEAD == nil || c.recvAEADEpoch != c.recvEpoch { if c.recvAEAD == nil || c.recvAEADEpoch != epoch {
a, err := c.newAEADFor(c.keys.baseRecv, c.recvEpoch) a, err := c.newAEADFor(c.keys.baseRecv, epoch)
if err != nil { if err != nil {
return 0, err return 0, err
} }
c.recvAEAD = a c.recvAEAD = a
c.recvAEADEpoch = c.recvEpoch c.recvAEADEpoch = epoch
} }
aead := c.recvAEAD aead := c.recvAEAD
@@ -356,7 +434,7 @@ func (c *RecordConn) Read(p []byte) (int, error) {
if err != nil { if err != nil {
return 0, fmt.Errorf("decryption failed: epoch=%d seq=%d: %w", epoch, seq, err) return 0, fmt.Errorf("decryption failed: epoch=%d seq=%d: %w", epoch, seq, err)
} }
c.recvSeq++ c.markRecvPosition(epoch, seq)
c.readBuf.Write(plaintext) c.readBuf.Write(plaintext)
return c.readBuf.Read(p) return c.readBuf.Read(p)

View File

@@ -0,0 +1,86 @@
package crypto
import (
"bytes"
"crypto/sha256"
"encoding/binary"
"io"
"net"
"testing"
"time"
)
type captureConn struct {
bytes.Buffer
}
func (c *captureConn) Read(_ []byte) (int, error) { return 0, io.EOF }
func (c *captureConn) Write(p []byte) (int, error) { return c.Buffer.Write(p) }
func (c *captureConn) Close() error { return nil }
func (c *captureConn) LocalAddr() net.Addr { return nil }
func (c *captureConn) RemoteAddr() net.Addr { return nil }
func (c *captureConn) SetDeadline(time.Time) error { return nil }
func (c *captureConn) SetReadDeadline(time.Time) error { return nil }
func (c *captureConn) SetWriteDeadline(time.Time) error { return nil }
type replayConn struct {
reader *bytes.Reader
}
func (c *replayConn) Read(p []byte) (int, error) { return c.reader.Read(p) }
func (c *replayConn) Write(p []byte) (int, error) { return len(p), nil }
func (c *replayConn) Close() error { return nil }
func (c *replayConn) LocalAddr() net.Addr { return nil }
func (c *replayConn) RemoteAddr() net.Addr { return nil }
func (c *replayConn) SetDeadline(time.Time) error { return nil }
func (c *replayConn) SetReadDeadline(time.Time) error { return nil }
func (c *replayConn) SetWriteDeadline(time.Time) error { return nil }
func TestRecordConn_FirstFrameUsesRandomizedCounters(t *testing.T) {
pskSend := sha256.Sum256([]byte("record-send"))
pskRecv := sha256.Sum256([]byte("record-recv"))
raw := &captureConn{}
writer, err := NewRecordConn(raw, "chacha20-poly1305", pskSend[:], pskRecv[:])
if err != nil {
t.Fatalf("new writer: %v", err)
}
if writer.sendEpoch == 0 || writer.sendSeq == 0 {
t.Fatalf("expected non-zero randomized counters, got epoch=%d seq=%d", writer.sendEpoch, writer.sendSeq)
}
want := []byte("record prefix camouflage")
if _, err := writer.Write(want); err != nil {
t.Fatalf("write: %v", err)
}
wire := raw.Bytes()
if len(wire) < 2+recordHeaderSize {
t.Fatalf("short frame: %d", len(wire))
}
bodyLen := int(binary.BigEndian.Uint16(wire[:2]))
if bodyLen != len(wire)-2 {
t.Fatalf("body len mismatch: got %d want %d", bodyLen, len(wire)-2)
}
epoch := binary.BigEndian.Uint32(wire[2:6])
seq := binary.BigEndian.Uint64(wire[6:14])
if epoch == 0 || seq == 0 {
t.Fatalf("wire header still starts from zero: epoch=%d seq=%d", epoch, seq)
}
reader, err := NewRecordConn(&replayConn{reader: bytes.NewReader(wire)}, "chacha20-poly1305", pskRecv[:], pskSend[:])
if err != nil {
t.Fatalf("new reader: %v", err)
}
got := make([]byte, len(want))
if _, err := io.ReadFull(reader, got); err != nil {
t.Fatalf("read: %v", err)
}
if !bytes.Equal(got, want) {
t.Fatalf("plaintext mismatch: got %q want %q", got, want)
}
}

View File

@@ -0,0 +1,345 @@
package sudoku
import (
"bytes"
"crypto/ecdh"
"crypto/rand"
"encoding/hex"
"fmt"
"net"
"time"
"github.com/metacubex/mihomo/transport/sudoku/crypto"
httpmaskobfs "github.com/metacubex/mihomo/transport/sudoku/obfs/httpmask"
sudokuobfs "github.com/metacubex/mihomo/transport/sudoku/obfs/sudoku"
)
const earlyKIPHandshakeTTL = 60 * time.Second
type EarlyCodecConfig struct {
PSK string
AEAD string
EnablePureDownlink bool
PaddingMin int
PaddingMax int
}
type EarlyClientState struct {
RequestPayload []byte
cfg EarlyCodecConfig
table *sudokuobfs.Table
nonce [kipHelloNonceSize]byte
ephemeral *ecdh.PrivateKey
sessionC2S []byte
sessionS2C []byte
responseSet bool
}
type EarlyServerState struct {
ResponsePayload []byte
UserHash string
cfg EarlyCodecConfig
table *sudokuobfs.Table
sessionC2S []byte
sessionS2C []byte
}
type ReplayAllowFunc func(userHash string, nonce [kipHelloNonceSize]byte, now time.Time) bool
type earlyMemoryConn struct {
reader *bytes.Reader
write bytes.Buffer
}
func newEarlyMemoryConn(readBuf []byte) *earlyMemoryConn {
return &earlyMemoryConn{reader: bytes.NewReader(readBuf)}
}
func (c *earlyMemoryConn) Read(p []byte) (int, error) {
if c == nil || c.reader == nil {
return 0, net.ErrClosed
}
return c.reader.Read(p)
}
func (c *earlyMemoryConn) Write(p []byte) (int, error) {
if c == nil {
return 0, net.ErrClosed
}
return c.write.Write(p)
}
func (c *earlyMemoryConn) Close() error { return nil }
func (c *earlyMemoryConn) LocalAddr() net.Addr { return earlyDummyAddr("local") }
func (c *earlyMemoryConn) RemoteAddr() net.Addr { return earlyDummyAddr("remote") }
func (c *earlyMemoryConn) SetDeadline(time.Time) error { return nil }
func (c *earlyMemoryConn) SetReadDeadline(time.Time) error { return nil }
func (c *earlyMemoryConn) SetWriteDeadline(time.Time) error { return nil }
func (c *earlyMemoryConn) Written() []byte { return append([]byte(nil), c.write.Bytes()...) }
type earlyDummyAddr string
func (a earlyDummyAddr) Network() string { return string(a) }
func (a earlyDummyAddr) String() string { return string(a) }
func buildEarlyClientObfsConn(raw net.Conn, cfg EarlyCodecConfig, table *sudokuobfs.Table) net.Conn {
base := sudokuobfs.NewConn(raw, table, cfg.PaddingMin, cfg.PaddingMax, false)
if cfg.EnablePureDownlink {
return base
}
packed := sudokuobfs.NewPackedConn(raw, table, cfg.PaddingMin, cfg.PaddingMax)
return newDirectionalConn(raw, packed, base)
}
func buildEarlyServerObfsConn(raw net.Conn, cfg EarlyCodecConfig, table *sudokuobfs.Table) net.Conn {
uplink := sudokuobfs.NewConn(raw, table, cfg.PaddingMin, cfg.PaddingMax, false)
if cfg.EnablePureDownlink {
return uplink
}
packed := sudokuobfs.NewPackedConn(raw, table, cfg.PaddingMin, cfg.PaddingMax)
return newDirectionalConn(raw, uplink, packed, packed.Flush)
}
func NewEarlyClientState(cfg EarlyCodecConfig, table *sudokuobfs.Table, userHash [kipHelloUserHashSize]byte, feats uint32) (*EarlyClientState, error) {
if table == nil {
return nil, fmt.Errorf("nil table")
}
curve := ecdh.X25519()
ephemeral, err := curve.GenerateKey(rand.Reader)
if err != nil {
return nil, fmt.Errorf("ecdh generate failed: %w", err)
}
var nonce [kipHelloNonceSize]byte
if _, err := rand.Read(nonce[:]); err != nil {
return nil, fmt.Errorf("nonce generate failed: %w", err)
}
var clientPub [kipHelloPubSize]byte
copy(clientPub[:], ephemeral.PublicKey().Bytes())
hello := &KIPClientHello{
Timestamp: time.Now(),
UserHash: userHash,
Nonce: nonce,
ClientPub: clientPub,
Features: feats,
}
mem := newEarlyMemoryConn(nil)
obfsConn := buildEarlyClientObfsConn(mem, cfg, table)
pskC2S, pskS2C := derivePSKDirectionalBases(cfg.PSK)
rc, err := crypto.NewRecordConn(obfsConn, cfg.AEAD, pskC2S, pskS2C)
if err != nil {
return nil, fmt.Errorf("client early crypto setup failed: %w", err)
}
if err := WriteKIPMessage(rc, KIPTypeClientHello, hello.EncodePayload()); err != nil {
return nil, fmt.Errorf("write early client hello failed: %w", err)
}
return &EarlyClientState{
RequestPayload: mem.Written(),
cfg: cfg,
table: table,
nonce: nonce,
ephemeral: ephemeral,
}, nil
}
func (s *EarlyClientState) ProcessResponse(payload []byte) error {
if s == nil {
return fmt.Errorf("nil client state")
}
mem := newEarlyMemoryConn(payload)
obfsConn := buildEarlyClientObfsConn(mem, s.cfg, s.table)
pskC2S, pskS2C := derivePSKDirectionalBases(s.cfg.PSK)
rc, err := crypto.NewRecordConn(obfsConn, s.cfg.AEAD, pskC2S, pskS2C)
if err != nil {
return fmt.Errorf("client early crypto setup failed: %w", err)
}
msg, err := ReadKIPMessage(rc)
if err != nil {
return fmt.Errorf("read early server hello failed: %w", err)
}
if msg.Type != KIPTypeServerHello {
return fmt.Errorf("unexpected early handshake message: %d", msg.Type)
}
sh, err := DecodeKIPServerHelloPayload(msg.Payload)
if err != nil {
return fmt.Errorf("decode early server hello failed: %w", err)
}
if sh.Nonce != s.nonce {
return fmt.Errorf("early handshake nonce mismatch")
}
shared, err := x25519SharedSecret(s.ephemeral, sh.ServerPub[:])
if err != nil {
return fmt.Errorf("ecdh failed: %w", err)
}
s.sessionC2S, s.sessionS2C, err = deriveSessionDirectionalBases(s.cfg.PSK, shared, s.nonce)
if err != nil {
return fmt.Errorf("derive session keys failed: %w", err)
}
s.responseSet = true
return nil
}
func (s *EarlyClientState) WrapConn(raw net.Conn) (net.Conn, error) {
if s == nil {
return nil, fmt.Errorf("nil client state")
}
if !s.responseSet {
return nil, fmt.Errorf("early handshake not completed")
}
obfsConn := buildEarlyClientObfsConn(raw, s.cfg, s.table)
rc, err := crypto.NewRecordConn(obfsConn, s.cfg.AEAD, s.sessionC2S, s.sessionS2C)
if err != nil {
return nil, fmt.Errorf("setup client session crypto failed: %w", err)
}
return rc, nil
}
func (s *EarlyClientState) Ready() bool {
return s != nil && s.responseSet
}
func NewHTTPMaskClientEarlyHandshake(cfg EarlyCodecConfig, table *sudokuobfs.Table, userHash [kipHelloUserHashSize]byte, feats uint32) (*httpmaskobfs.ClientEarlyHandshake, error) {
state, err := NewEarlyClientState(cfg, table, userHash, feats)
if err != nil {
return nil, err
}
return &httpmaskobfs.ClientEarlyHandshake{
RequestPayload: state.RequestPayload,
HandleResponse: state.ProcessResponse,
Ready: state.Ready,
WrapConn: state.WrapConn,
}, nil
}
func ProcessEarlyClientPayload(cfg EarlyCodecConfig, tables []*sudokuobfs.Table, payload []byte, allowReplay ReplayAllowFunc) (*EarlyServerState, error) {
if len(payload) == 0 {
return nil, fmt.Errorf("empty early payload")
}
if len(tables) == 0 {
return nil, fmt.Errorf("no tables configured")
}
var firstErr error
for _, table := range tables {
state, err := processEarlyClientPayloadForTable(cfg, table, payload, allowReplay)
if err == nil {
return state, nil
}
if firstErr == nil {
firstErr = err
}
}
if firstErr == nil {
firstErr = fmt.Errorf("early handshake probe failed")
}
return nil, firstErr
}
func processEarlyClientPayloadForTable(cfg EarlyCodecConfig, table *sudokuobfs.Table, payload []byte, allowReplay ReplayAllowFunc) (*EarlyServerState, error) {
mem := newEarlyMemoryConn(payload)
obfsConn := buildEarlyServerObfsConn(mem, cfg, table)
pskC2S, pskS2C := derivePSKDirectionalBases(cfg.PSK)
rc, err := crypto.NewRecordConn(obfsConn, cfg.AEAD, pskS2C, pskC2S)
if err != nil {
return nil, err
}
msg, err := ReadKIPMessage(rc)
if err != nil {
return nil, err
}
if msg.Type != KIPTypeClientHello {
return nil, fmt.Errorf("unexpected handshake message: %d", msg.Type)
}
ch, err := DecodeKIPClientHelloPayload(msg.Payload)
if err != nil {
return nil, err
}
if absInt64(time.Now().Unix()-ch.Timestamp.Unix()) > int64(earlyKIPHandshakeTTL.Seconds()) {
return nil, fmt.Errorf("time skew/replay")
}
userHash := hex.EncodeToString(ch.UserHash[:])
if allowReplay != nil && !allowReplay(userHash, ch.Nonce, time.Now()) {
return nil, fmt.Errorf("replay detected")
}
curve := ecdh.X25519()
serverEphemeral, err := curve.GenerateKey(rand.Reader)
if err != nil {
return nil, fmt.Errorf("ecdh generate failed: %w", err)
}
shared, err := x25519SharedSecret(serverEphemeral, ch.ClientPub[:])
if err != nil {
return nil, fmt.Errorf("ecdh failed: %w", err)
}
sessionC2S, sessionS2C, err := deriveSessionDirectionalBases(cfg.PSK, shared, ch.Nonce)
if err != nil {
return nil, fmt.Errorf("derive session keys failed: %w", err)
}
var serverPub [kipHelloPubSize]byte
copy(serverPub[:], serverEphemeral.PublicKey().Bytes())
serverHello := &KIPServerHello{
Nonce: ch.Nonce,
ServerPub: serverPub,
SelectedFeats: ch.Features & KIPFeatAll,
}
respMem := newEarlyMemoryConn(nil)
respObfs := buildEarlyServerObfsConn(respMem, cfg, table)
respConn, err := crypto.NewRecordConn(respObfs, cfg.AEAD, pskS2C, pskC2S)
if err != nil {
return nil, fmt.Errorf("server early crypto setup failed: %w", err)
}
if err := WriteKIPMessage(respConn, KIPTypeServerHello, serverHello.EncodePayload()); err != nil {
return nil, fmt.Errorf("write early server hello failed: %w", err)
}
return &EarlyServerState{
ResponsePayload: respMem.Written(),
UserHash: userHash,
cfg: cfg,
table: table,
sessionC2S: sessionC2S,
sessionS2C: sessionS2C,
}, nil
}
func (s *EarlyServerState) WrapConn(raw net.Conn) (net.Conn, error) {
if s == nil {
return nil, fmt.Errorf("nil server state")
}
obfsConn := buildEarlyServerObfsConn(raw, s.cfg, s.table)
rc, err := crypto.NewRecordConn(obfsConn, s.cfg.AEAD, s.sessionS2C, s.sessionC2S)
if err != nil {
return nil, fmt.Errorf("setup server session crypto failed: %w", err)
}
return rc, nil
}
func NewHTTPMaskServerEarlyHandshake(cfg EarlyCodecConfig, tables []*sudokuobfs.Table, allowReplay ReplayAllowFunc) *httpmaskobfs.TunnelServerEarlyHandshake {
return &httpmaskobfs.TunnelServerEarlyHandshake{
Prepare: func(payload []byte) (*httpmaskobfs.PreparedServerEarlyHandshake, error) {
state, err := ProcessEarlyClientPayload(cfg, tables, payload, allowReplay)
if err != nil {
return nil, err
}
return &httpmaskobfs.PreparedServerEarlyHandshake{
ResponsePayload: state.ResponsePayload,
WrapConn: state.WrapConn,
UserHash: state.UserHash,
}, nil
},
}
}

View File

@@ -337,6 +337,9 @@ func ServerHandshake(rawConn net.Conn, cfg *ProtocolConfig) (net.Conn, *Handshak
if err := cfg.Validate(); err != nil { if err := cfg.Validate(); err != nil {
return nil, nil, fmt.Errorf("invalid config: %w", err) return nil, nil, fmt.Errorf("invalid config: %w", err)
} }
if userHash, ok := httpmask.EarlyHandshakeUserHash(rawConn); ok {
return rawConn, &HandshakeMeta{UserHash: userHash}, nil
}
handshakeTimeout := time.Duration(cfg.HandshakeTimeoutSeconds) * time.Second handshakeTimeout := time.Duration(cfg.HandshakeTimeoutSeconds) * time.Second
if handshakeTimeout <= 0 { if handshakeTimeout <= 0 {

View File

@@ -14,6 +14,30 @@ type HTTPMaskTunnelServer struct {
ts *httpmask.TunnelServer ts *httpmask.TunnelServer
} }
func newHTTPMaskEarlyCodecConfig(cfg *ProtocolConfig, psk string) EarlyCodecConfig {
return EarlyCodecConfig{
PSK: psk,
AEAD: cfg.AEADMethod,
EnablePureDownlink: cfg.EnablePureDownlink,
PaddingMin: cfg.PaddingMin,
PaddingMax: cfg.PaddingMax,
}
}
func newClientHTTPMaskEarlyHandshake(cfg *ProtocolConfig) (*httpmask.ClientEarlyHandshake, error) {
table, err := pickClientTable(cfg)
if err != nil {
return nil, err
}
return NewHTTPMaskClientEarlyHandshake(
newHTTPMaskEarlyCodecConfig(cfg, ClientAEADSeed(cfg.Key)),
table,
kipUserHashFromKey(cfg.Key),
KIPFeatAll,
)
}
func NewHTTPMaskTunnelServer(cfg *ProtocolConfig) *HTTPMaskTunnelServer { func NewHTTPMaskTunnelServer(cfg *ProtocolConfig) *HTTPMaskTunnelServer {
return newHTTPMaskTunnelServer(cfg, false) return newHTTPMaskTunnelServer(cfg, false)
} }
@@ -35,6 +59,11 @@ func newHTTPMaskTunnelServer(cfg *ProtocolConfig, passThroughOnReject bool) *HTT
Mode: cfg.HTTPMaskMode, Mode: cfg.HTTPMaskMode,
PathRoot: cfg.HTTPMaskPathRoot, PathRoot: cfg.HTTPMaskPathRoot,
AuthKey: ServerAEADSeed(cfg.Key), AuthKey: ServerAEADSeed(cfg.Key),
EarlyHandshake: NewHTTPMaskServerEarlyHandshake(
newHTTPMaskEarlyCodecConfig(cfg, ServerAEADSeed(cfg.Key)),
cfg.tableCandidates(),
globalHandshakeReplay.allow,
),
// When upstream fallback is enabled, preserve rejected HTTP requests for the caller. // When upstream fallback is enabled, preserve rejected HTTP requests for the caller.
PassThroughOnReject: passThroughOnReject, PassThroughOnReject: passThroughOnReject,
}) })
@@ -101,14 +130,25 @@ func DialHTTPMaskTunnel(ctx context.Context, serverAddress string, cfg *Protocol
default: default:
return nil, fmt.Errorf("http-mask-mode=%q does not use http tunnel", cfg.HTTPMaskMode) return nil, fmt.Errorf("http-mask-mode=%q does not use http tunnel", cfg.HTTPMaskMode)
} }
var (
earlyHandshake *httpmask.ClientEarlyHandshake
err error
)
if upgrade != nil {
earlyHandshake, err = newClientHTTPMaskEarlyHandshake(cfg)
if err != nil {
return nil, err
}
}
return httpmask.DialTunnel(ctx, serverAddress, httpmask.TunnelDialOptions{ return httpmask.DialTunnel(ctx, serverAddress, httpmask.TunnelDialOptions{
Mode: cfg.HTTPMaskMode, Mode: cfg.HTTPMaskMode,
TLSEnabled: cfg.HTTPMaskTLSEnabled, TLSEnabled: cfg.HTTPMaskTLSEnabled,
HostOverride: cfg.HTTPMaskHost, HostOverride: cfg.HTTPMaskHost,
PathRoot: cfg.HTTPMaskPathRoot, PathRoot: cfg.HTTPMaskPathRoot,
AuthKey: ClientAEADSeed(cfg.Key), AuthKey: ClientAEADSeed(cfg.Key),
Upgrade: upgrade, EarlyHandshake: earlyHandshake,
Multiplex: cfg.HTTPMaskMultiplex, Upgrade: upgrade,
DialContext: dial, Multiplex: cfg.HTTPMaskMultiplex,
DialContext: dial,
}) })
} }

View File

@@ -389,6 +389,68 @@ func TestHTTPMaskTunnel_WS_TCPRoundTrip(t *testing.T) {
} }
} }
func TestHTTPMaskTunnel_EarlyHandshake_TCPRoundTrip(t *testing.T) {
modes := []string{"stream", "poll", "ws"}
for _, mode := range modes {
t.Run(mode, func(t *testing.T) {
key := "tunnel-early-" + mode
target := "1.1.1.1:80"
serverCfg := newTunnelTestTable(t, key)
serverCfg.HTTPMaskMode = mode
addr, stop, errCh := startTunnelServer(t, serverCfg, func(s *ServerSession) error {
if s.Type != SessionTypeTCP {
return fmt.Errorf("unexpected session type: %v", s.Type)
}
if s.Target != target {
return fmt.Errorf("target mismatch: %s", s.Target)
}
_, _ = s.Conn.Write([]byte("ok"))
return nil
})
defer stop()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
clientCfg := *serverCfg
clientCfg.ServerAddress = addr
handshakeCfg := clientCfg
handshakeCfg.DisableHTTPMask = true
tunnelConn, err := DialHTTPMaskTunnel(ctx, clientCfg.ServerAddress, &clientCfg, (&net.Dialer{}).DialContext, func(raw net.Conn) (net.Conn, error) {
return ClientHandshake(raw, &handshakeCfg)
})
if err != nil {
t.Fatalf("dial tunnel: %v", err)
}
defer tunnelConn.Close()
addrBuf, err := EncodeAddress(target)
if err != nil {
t.Fatalf("encode addr: %v", err)
}
if err := WriteKIPMessage(tunnelConn, KIPTypeOpenTCP, addrBuf); err != nil {
t.Fatalf("write addr: %v", err)
}
buf := make([]byte, 2)
if _, err := io.ReadFull(tunnelConn, buf); err != nil {
t.Fatalf("read: %v", err)
}
if string(buf) != "ok" {
t.Fatalf("unexpected payload: %q", buf)
}
stop()
for err := range errCh {
t.Fatalf("server error: %v", err)
}
})
}
}
func TestHTTPMaskTunnel_Validation(t *testing.T) { func TestHTTPMaskTunnel_Validation(t *testing.T) {
cfg := DefaultConfig() cfg := DefaultConfig()
cfg.Key = "k" cfg.Key = "k"

View File

@@ -0,0 +1,174 @@
package httpmask
import (
"encoding/base64"
"errors"
"fmt"
"net"
"net/url"
"strings"
)
const (
tunnelEarlyDataQueryKey = "ed"
tunnelEarlyDataHeader = "X-Sudoku-Early"
)
type ClientEarlyHandshake struct {
RequestPayload []byte
HandleResponse func(payload []byte) error
Ready func() bool
WrapConn func(raw net.Conn) (net.Conn, error)
}
type TunnelServerEarlyHandshake struct {
Prepare func(payload []byte) (*PreparedServerEarlyHandshake, error)
}
type PreparedServerEarlyHandshake struct {
ResponsePayload []byte
WrapConn func(raw net.Conn) (net.Conn, error)
UserHash string
}
type earlyHandshakeMeta interface {
HTTPMaskEarlyHandshakeUserHash() string
}
type earlyHandshakeConn struct {
net.Conn
userHash string
}
func (c *earlyHandshakeConn) HTTPMaskEarlyHandshakeUserHash() string {
if c == nil {
return ""
}
return c.userHash
}
func wrapEarlyHandshakeConn(conn net.Conn, userHash string) net.Conn {
if conn == nil {
return nil
}
return &earlyHandshakeConn{Conn: conn, userHash: userHash}
}
func EarlyHandshakeUserHash(conn net.Conn) (string, bool) {
if conn == nil {
return "", false
}
v, ok := conn.(earlyHandshakeMeta)
if !ok {
return "", false
}
return v.HTTPMaskEarlyHandshakeUserHash(), true
}
type authorizeResponse struct {
token string
earlyPayload []byte
}
func isTunnelTokenByte(c byte) bool {
return (c >= 'a' && c <= 'z') ||
(c >= 'A' && c <= 'Z') ||
(c >= '0' && c <= '9') ||
c == '-' ||
c == '_'
}
func parseAuthorizeResponse(body []byte) (*authorizeResponse, error) {
s := strings.TrimSpace(string(body))
idx := strings.Index(s, "token=")
if idx < 0 {
return nil, errors.New("missing token")
}
s = s[idx+len("token="):]
if s == "" {
return nil, errors.New("empty token")
}
var b strings.Builder
for i := 0; i < len(s); i++ {
c := s[i]
if isTunnelTokenByte(c) {
b.WriteByte(c)
continue
}
break
}
token := b.String()
if token == "" {
return nil, errors.New("empty token")
}
out := &authorizeResponse{token: token}
if earlyLine := findAuthorizeField(body, "ed="); earlyLine != "" {
decoded, err := base64.RawURLEncoding.DecodeString(earlyLine)
if err != nil {
return nil, fmt.Errorf("decode early authorize payload failed: %w", err)
}
out.earlyPayload = decoded
}
return out, nil
}
func findAuthorizeField(body []byte, prefix string) string {
for _, line := range strings.Split(strings.TrimSpace(string(body)), "\n") {
line = strings.TrimSpace(line)
if strings.HasPrefix(line, prefix) {
return strings.TrimSpace(strings.TrimPrefix(line, prefix))
}
}
return ""
}
func setEarlyDataQuery(rawURL string, payload []byte) (string, error) {
if len(payload) == 0 {
return rawURL, nil
}
u, err := url.Parse(rawURL)
if err != nil {
return "", err
}
q := u.Query()
q.Set(tunnelEarlyDataQueryKey, base64.RawURLEncoding.EncodeToString(payload))
u.RawQuery = q.Encode()
return u.String(), nil
}
func parseEarlyDataQuery(u *url.URL) ([]byte, error) {
if u == nil {
return nil, nil
}
val := strings.TrimSpace(u.Query().Get(tunnelEarlyDataQueryKey))
if val == "" {
return nil, nil
}
return base64.RawURLEncoding.DecodeString(val)
}
func applyEarlyHandshakeOrUpgrade(raw net.Conn, opts TunnelDialOptions) (net.Conn, error) {
out := raw
if opts.EarlyHandshake != nil && opts.EarlyHandshake.WrapConn != nil && (opts.EarlyHandshake.Ready == nil || opts.EarlyHandshake.Ready()) {
wrapped, err := opts.EarlyHandshake.WrapConn(raw)
if err != nil {
return nil, err
}
if wrapped != nil {
out = wrapped
}
return out, nil
}
if opts.Upgrade != nil {
wrapped, err := opts.Upgrade(raw)
if err != nil {
return nil, err
}
if wrapped != nil {
out = wrapped
}
}
return out, nil
}

View File

@@ -72,6 +72,10 @@ type TunnelDialOptions struct {
// AuthKey enables short-term HMAC auth for HTTP tunnel requests (anti-probing). // AuthKey enables short-term HMAC auth for HTTP tunnel requests (anti-probing).
// When set (non-empty), each HTTP request carries an Authorization bearer token derived from AuthKey. // When set (non-empty), each HTTP request carries an Authorization bearer token derived from AuthKey.
AuthKey string AuthKey string
// EarlyHandshake folds the protocol handshake into the HTTP/WS setup round trip.
// When the server accepts the early payload, DialTunnel returns a conn that is already post-handshake.
// When the server does not echo early data, DialTunnel falls back to Upgrade.
EarlyHandshake *ClientEarlyHandshake
// Upgrade optionally wraps the raw tunnel conn and/or writes a small prelude before DialTunnel returns. // Upgrade optionally wraps the raw tunnel conn and/or writes a small prelude before DialTunnel returns.
// It is called with the raw tunnel conn; if it returns a non-nil conn, that conn is returned by DialTunnel. // It is called with the raw tunnel conn; if it returns a non-nil conn, that conn is returned by DialTunnel.
Upgrade func(raw net.Conn) (net.Conn, error) Upgrade func(raw net.Conn) (net.Conn, error)
@@ -225,30 +229,11 @@ func canonicalHeaderHost(urlHost, scheme string) string {
} }
func parseTunnelToken(body []byte) (string, error) { func parseTunnelToken(body []byte) (string, error) {
s := strings.TrimSpace(string(body)) resp, err := parseAuthorizeResponse(body)
idx := strings.Index(s, "token=") if err != nil {
if idx < 0 { return "", err
return "", errors.New("missing token")
} }
s = s[idx+len("token="):] return resp.token, nil
if s == "" {
return "", errors.New("empty token")
}
// Token is base64.RawURLEncoding (A-Z a-z 0-9 - _). Strip any trailing bytes (e.g. from CDN compression).
var b strings.Builder
for i := 0; i < len(s); i++ {
c := s[i]
if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '-' || c == '_' {
b.WriteByte(c)
continue
}
break
}
token := b.String()
if token == "" {
return "", errors.New("empty token")
}
return token, nil
} }
type httpClientTarget struct { type httpClientTarget struct {
@@ -353,6 +338,13 @@ func dialSessionWithClient(ctx context.Context, client *http.Client, target http
auth := newTunnelAuth(opts.AuthKey, 0) auth := newTunnelAuth(opts.AuthKey, 0)
authorizeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/session")}).String() authorizeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/session")}).String()
if opts.EarlyHandshake != nil && len(opts.EarlyHandshake.RequestPayload) > 0 {
var err error
authorizeURL, err = setEarlyDataQuery(authorizeURL, opts.EarlyHandshake.RequestPayload)
if err != nil {
return nil, err
}
}
var bodyBytes []byte var bodyBytes []byte
for attempt := 0; ; attempt++ { for attempt := 0; ; attempt++ {
@@ -410,13 +402,19 @@ func dialSessionWithClient(ctx context.Context, client *http.Client, target http
break break
} }
token, err := parseTunnelToken(bodyBytes) authResp, err := parseAuthorizeResponse(bodyBytes)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s authorize failed: %q", mode, strings.TrimSpace(string(bodyBytes))) return nil, fmt.Errorf("%s authorize failed: %q", mode, strings.TrimSpace(string(bodyBytes)))
} }
token := authResp.token
if token == "" { if token == "" {
return nil, fmt.Errorf("%s authorize empty token", mode) return nil, fmt.Errorf("%s authorize empty token", mode)
} }
if opts.EarlyHandshake != nil && len(authResp.earlyPayload) > 0 && opts.EarlyHandshake.HandleResponse != nil {
if err := opts.EarlyHandshake.HandleResponse(authResp.earlyPayload); err != nil {
return nil, err
}
}
pushURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/api/v1/upload"), RawQuery: "token=" + url.QueryEscape(token)}).String() pushURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/api/v1/upload"), RawQuery: "token=" + url.QueryEscape(token)}).String()
pullURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/stream"), RawQuery: "token=" + url.QueryEscape(token)}).String() pullURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/stream"), RawQuery: "token=" + url.QueryEscape(token)}).String()
@@ -671,16 +669,10 @@ func dialStreamSplitWithClient(ctx context.Context, client *http.Client, target
if c == nil { if c == nil {
return nil, fmt.Errorf("failed to build stream split conn") return nil, fmt.Errorf("failed to build stream split conn")
} }
outConn := net.Conn(c) outConn, err := applyEarlyHandshakeOrUpgrade(c, opts)
if opts.Upgrade != nil { if err != nil {
upgraded, err := opts.Upgrade(c) _ = c.Close()
if err != nil { return nil, err
_ = c.Close()
return nil, err
}
if upgraded != nil {
outConn = upgraded
}
} }
return outConn, nil return outConn, nil
} }
@@ -694,16 +686,10 @@ func dialStreamSplit(ctx context.Context, serverAddress string, opts TunnelDialO
if c == nil { if c == nil {
return nil, fmt.Errorf("failed to build stream split conn") return nil, fmt.Errorf("failed to build stream split conn")
} }
outConn := net.Conn(c) outConn, err := applyEarlyHandshakeOrUpgrade(c, opts)
if opts.Upgrade != nil { if err != nil {
upgraded, err := opts.Upgrade(c) _ = c.Close()
if err != nil { return nil, err
_ = c.Close()
return nil, err
}
if upgraded != nil {
outConn = upgraded
}
} }
return outConn, nil return outConn, nil
} }
@@ -1120,16 +1106,10 @@ func dialPollWithClient(ctx context.Context, client *http.Client, target httpCli
if c == nil { if c == nil {
return nil, fmt.Errorf("failed to build poll conn") return nil, fmt.Errorf("failed to build poll conn")
} }
outConn := net.Conn(c) outConn, err := applyEarlyHandshakeOrUpgrade(c, opts)
if opts.Upgrade != nil { if err != nil {
upgraded, err := opts.Upgrade(c) _ = c.Close()
if err != nil { return nil, err
_ = c.Close()
return nil, err
}
if upgraded != nil {
outConn = upgraded
}
} }
return outConn, nil return outConn, nil
} }
@@ -1143,16 +1123,10 @@ func dialPoll(ctx context.Context, serverAddress string, opts TunnelDialOptions)
if c == nil { if c == nil {
return nil, fmt.Errorf("failed to build poll conn") return nil, fmt.Errorf("failed to build poll conn")
} }
outConn := net.Conn(c) outConn, err := applyEarlyHandshakeOrUpgrade(c, opts)
if opts.Upgrade != nil { if err != nil {
upgraded, err := opts.Upgrade(c) _ = c.Close()
if err != nil { return nil, err
_ = c.Close()
return nil, err
}
if upgraded != nil {
outConn = upgraded
}
} }
return outConn, nil return outConn, nil
} }
@@ -1528,6 +1502,8 @@ type TunnelServerOptions struct {
PullReadTimeout time.Duration PullReadTimeout time.Duration
// SessionTTL is a best-effort TTL to prevent leaked sessions. 0 uses a conservative default. // SessionTTL is a best-effort TTL to prevent leaked sessions. 0 uses a conservative default.
SessionTTL time.Duration SessionTTL time.Duration
// EarlyHandshake optionally folds the protocol handshake into the initial HTTP/WS round trip.
EarlyHandshake *TunnelServerEarlyHandshake
} }
type TunnelServer struct { type TunnelServer struct {
@@ -1538,6 +1514,7 @@ type TunnelServer struct {
pullReadTimeout time.Duration pullReadTimeout time.Duration
sessionTTL time.Duration sessionTTL time.Duration
earlyHandshake *TunnelServerEarlyHandshake
mu sync.Mutex mu sync.Mutex
sessions map[string]*tunnelSession sessions map[string]*tunnelSession
@@ -1570,6 +1547,7 @@ func NewTunnelServer(opts TunnelServerOptions) *TunnelServer {
passThroughOnReject: opts.PassThroughOnReject, passThroughOnReject: opts.PassThroughOnReject,
pullReadTimeout: timeout, pullReadTimeout: timeout,
sessionTTL: ttl, sessionTTL: ttl,
earlyHandshake: opts.EarlyHandshake,
sessions: make(map[string]*tunnelSession), sessions: make(map[string]*tunnelSession),
} }
} }
@@ -1925,9 +1903,12 @@ func (s *TunnelServer) handleStream(rawConn net.Conn, req *httpRequestHeader, he
switch strings.ToUpper(req.method) { switch strings.ToUpper(req.method) {
case http.MethodGet: case http.MethodGet:
// Stream split-session: GET /session (no token) => token + start tunnel on a server-side pipe.
if token == "" && path == "/session" { if token == "" && path == "/session" {
return s.sessionAuthorize(rawConn) earlyPayload, err := parseEarlyDataQuery(u)
if err != nil {
return rejectOrReply(http.StatusBadRequest, "bad request")
}
return s.sessionAuthorize(rawConn, earlyPayload)
} }
// Stream split-session: GET /stream?token=... => downlink poll. // Stream split-session: GET /stream?token=... => downlink poll.
if token != "" && path == "/stream" { if token != "" && path == "/stream" {
@@ -2045,10 +2026,18 @@ func writeSimpleHTTPResponse(w io.Writer, code int, body string) error {
func writeTokenHTTPResponse(w io.Writer, token string) error { func writeTokenHTTPResponse(w io.Writer, token string) error {
token = strings.TrimRight(token, "\r\n") token = strings.TrimRight(token, "\r\n")
// Use application/octet-stream to avoid CDN auto-compression (e.g. brotli) breaking clients that expect a plain token string. return writeTokenHTTPResponseWithEarlyData(w, token, nil)
}
func writeTokenHTTPResponseWithEarlyData(w io.Writer, token string, earlyPayload []byte) error {
token = strings.TrimRight(token, "\r\n")
body := "token=" + token
if len(earlyPayload) > 0 {
body += "\ned=" + base64.RawURLEncoding.EncodeToString(earlyPayload)
}
_, err := io.WriteString(w, _, err := io.WriteString(w,
fmt.Sprintf("HTTP/1.1 200 OK\r\nContent-Type: application/octet-stream\r\nCache-Control: no-store\r\nPragma: no-cache\r\nContent-Length: %d\r\nConnection: close\r\n\r\ntoken=%s", fmt.Sprintf("HTTP/1.1 200 OK\r\nContent-Type: application/octet-stream\r\nCache-Control: no-store\r\nPragma: no-cache\r\nContent-Length: %d\r\nConnection: close\r\n\r\n%s",
len("token=")+len(token), token)) len(body), body))
return err return err
} }
@@ -2088,7 +2077,11 @@ func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, head
switch strings.ToUpper(req.method) { switch strings.ToUpper(req.method) {
case http.MethodGet: case http.MethodGet:
if token == "" && path == "/session" { if token == "" && path == "/session" {
return s.sessionAuthorize(rawConn) earlyPayload, err := parseEarlyDataQuery(u)
if err != nil {
return rejectOrReply(http.StatusBadRequest, "bad request")
}
return s.sessionAuthorize(rawConn, earlyPayload)
} }
if token != "" && path == "/stream" { if token != "" && path == "/stream" {
if s.passThroughOnReject && !s.sessionHas(token) { if s.passThroughOnReject && !s.sessionHas(token) {
@@ -2128,7 +2121,7 @@ func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, head
} }
} }
func (s *TunnelServer) sessionAuthorize(rawConn net.Conn) (HandleResult, net.Conn, error) { func (s *TunnelServer) sessionAuthorize(rawConn net.Conn, earlyPayload []byte) (HandleResult, net.Conn, error) {
token, err := newSessionToken() token, err := newSessionToken()
if err != nil { if err != nil {
_ = writeSimpleHTTPResponse(rawConn, http.StatusInternalServerError, "internal error") _ = writeSimpleHTTPResponse(rawConn, http.StatusInternalServerError, "internal error")
@@ -2137,6 +2130,37 @@ func (s *TunnelServer) sessionAuthorize(rawConn net.Conn) (HandleResult, net.Con
} }
c1, c2 := newHalfPipe() c1, c2 := newHalfPipe()
outConn := net.Conn(c1)
var responsePayload []byte
var userHash string
if len(earlyPayload) > 0 && s.earlyHandshake != nil && s.earlyHandshake.Prepare != nil {
prepared, err := s.earlyHandshake.Prepare(earlyPayload)
if err != nil {
_ = c1.Close()
_ = c2.Close()
if s.passThroughOnReject {
return HandlePassThrough, newRejectedPreBufferedConn(rawConn, nil), nil
}
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
_ = rawConn.Close()
return HandleDone, nil, nil
}
responsePayload = prepared.ResponsePayload
userHash = prepared.UserHash
if prepared.WrapConn != nil {
wrapped, err := prepared.WrapConn(c1)
if err != nil {
_ = c1.Close()
_ = c2.Close()
_ = writeSimpleHTTPResponse(rawConn, http.StatusInternalServerError, "internal error")
_ = rawConn.Close()
return HandleDone, nil, nil
}
if wrapped != nil {
outConn = wrapEarlyHandshakeConn(wrapped, userHash)
}
}
}
s.mu.Lock() s.mu.Lock()
s.sessions[token] = &tunnelSession{conn: c2, lastActive: time.Now()} s.sessions[token] = &tunnelSession{conn: c2, lastActive: time.Now()}
@@ -2144,9 +2168,9 @@ func (s *TunnelServer) sessionAuthorize(rawConn net.Conn) (HandleResult, net.Con
go s.reapLater(token) go s.reapLater(token)
_ = writeTokenHTTPResponse(rawConn, token) _ = writeTokenHTTPResponseWithEarlyData(rawConn, token, responsePayload)
_ = rawConn.Close() _ = rawConn.Close()
return HandleStartTunnel, c1, nil return HandleStartTunnel, outConn, nil
} }
func newSessionToken() (string, error) { func newSessionToken() (string, error) {

View File

@@ -2,6 +2,7 @@ package httpmask
import ( import (
"context" "context"
"encoding/base64"
"fmt" "fmt"
"io" "io"
mrand "math/rand" mrand "math/rand"
@@ -115,6 +116,16 @@ func dialWS(ctx context.Context, serverAddress string, opts TunnelDialOptions) (
Host: urlHost, Host: urlHost,
Path: joinPathRoot(opts.PathRoot, "/ws"), Path: joinPathRoot(opts.PathRoot, "/ws"),
} }
if opts.EarlyHandshake != nil && len(opts.EarlyHandshake.RequestPayload) > 0 {
rawURL, err := setEarlyDataQuery(u.String(), opts.EarlyHandshake.RequestPayload)
if err != nil {
return nil, err
}
u, err = url.Parse(rawURL)
if err != nil {
return nil, err
}
}
header := make(stdhttp.Header) header := make(stdhttp.Header)
applyWSHeaders(header, headerHost) applyWSHeaders(header, headerHost)
@@ -132,6 +143,16 @@ func dialWS(ctx context.Context, serverAddress string, opts TunnelDialOptions) (
d := ws.Dialer{ d := ws.Dialer{
Host: headerHost, Host: headerHost,
Header: ws.HandshakeHeaderHTTP(header), Header: ws.HandshakeHeaderHTTP(header),
OnHeader: func(key, value []byte) error {
if !strings.EqualFold(string(key), tunnelEarlyDataHeader) || opts.EarlyHandshake == nil || opts.EarlyHandshake.HandleResponse == nil {
return nil
}
decoded, err := base64.RawURLEncoding.DecodeString(strings.TrimSpace(string(value)))
if err != nil {
return err
}
return opts.EarlyHandshake.HandleResponse(decoded)
},
NetDial: func(dialCtx context.Context, network, addr string) (net.Conn, error) { NetDial: func(dialCtx context.Context, network, addr string) (net.Conn, error) {
if addr == urlHost { if addr == urlHost {
addr = dialAddr addr = dialAddr
@@ -161,16 +182,10 @@ func dialWS(ctx context.Context, serverAddress string, opts TunnelDialOptions) (
} }
wsConn := newWSStreamConn(conn, ws.StateClientSide) wsConn := newWSStreamConn(conn, ws.StateClientSide)
if opts.Upgrade == nil { upgraded, err := applyEarlyHandshakeOrUpgrade(wsConn, opts)
return wsConn, nil
}
upgraded, err := opts.Upgrade(wsConn)
if err != nil { if err != nil {
_ = wsConn.Close() _ = wsConn.Close()
return nil, err return nil, err
} }
if upgraded != nil { return upgraded, nil
return upgraded, nil
}
return wsConn, nil
} }

View File

@@ -1,6 +1,7 @@
package httpmask package httpmask
import ( import (
"encoding/base64"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
@@ -63,15 +64,46 @@ func (s *TunnelServer) handleWS(rawConn net.Conn, req *httpRequestHeader, header
return rejectOrReply(http.StatusNotFound, "not found") return rejectOrReply(http.StatusNotFound, "not found")
} }
earlyPayload, err := parseEarlyDataQuery(u)
if err != nil {
return rejectOrReply(http.StatusBadRequest, "bad request")
}
var prepared *PreparedServerEarlyHandshake
if len(earlyPayload) > 0 && s.earlyHandshake != nil && s.earlyHandshake.Prepare != nil {
prepared, err = s.earlyHandshake.Prepare(earlyPayload)
if err != nil {
return rejectOrReply(http.StatusNotFound, "not found")
}
}
prefix := make([]byte, 0, len(headerBytes)+len(buffered)) prefix := make([]byte, 0, len(headerBytes)+len(buffered))
prefix = append(prefix, headerBytes...) prefix = append(prefix, headerBytes...)
prefix = append(prefix, buffered...) prefix = append(prefix, buffered...)
wsConnRaw := newPreBufferedConn(rawConn, prefix) wsConnRaw := newPreBufferedConn(rawConn, prefix)
if _, err := ws.Upgrade(wsConnRaw); err != nil { upgrader := ws.Upgrader{}
if prepared != nil && len(prepared.ResponsePayload) > 0 {
upgrader.OnBeforeUpgrade = func() (ws.HandshakeHeader, error) {
h := http.Header{}
h.Set(tunnelEarlyDataHeader, base64.RawURLEncoding.EncodeToString(prepared.ResponsePayload))
return ws.HandshakeHeaderHTTP(h), nil
}
}
if _, err := upgrader.Upgrade(wsConnRaw); err != nil {
_ = rawConn.Close() _ = rawConn.Close()
return HandleDone, nil, nil return HandleDone, nil, nil
} }
return HandleStartTunnel, newWSStreamConn(wsConnRaw, ws.StateServerSide), nil outConn := net.Conn(newWSStreamConn(wsConnRaw, ws.StateServerSide))
if prepared != nil && prepared.WrapConn != nil {
wrapped, err := prepared.WrapConn(outConn)
if err != nil {
_ = outConn.Close()
return HandleDone, nil, nil
}
if wrapped != nil {
outConn = wrapEarlyHandshakeConn(wrapped, prepared.UserHash)
}
}
return HandleStartTunnel, outConn, nil
} }

View File

@@ -11,35 +11,35 @@ import (
) )
const ( const (
// 每次从 RNG 获取批量随机数的缓存大小,减少 RNG 函数调用开销
RngBatchSize = 128 RngBatchSize = 128
packedProtectedPrefixBytes = 14
) )
// 1. 使用 12字节->16组 的块处理优化 Write (减少循环开销) // PackedConn encodes traffic with the packed Sudoku layout while preserving
// 2. 使用整数阈值随机概率判断 Padding与纯 Sudoku 保持流量特征一致 // the same padding model as the regular connection.
// 3. Read 使用 copy 移动避免底层数组泄漏
type PackedConn struct { type PackedConn struct {
net.Conn net.Conn
table *Table table *Table
reader *bufio.Reader reader *bufio.Reader
// 读缓冲 // Read-side buffers.
rawBuf []byte rawBuf []byte
pendingData []byte // 解码后尚未被 Read 取走的字节 pendingData []byte
// 写缓冲与状态 // Write-side state.
writeMu sync.Mutex writeMu sync.Mutex
writeBuf []byte writeBuf []byte
bitBuf uint64 // 暂存的位数据 bitBuf uint64
bitCount int // 暂存的位数 bitCount int
// 读状态 // Read-side bit accumulator.
readBitBuf uint64 readBitBuf uint64
readBits int readBits int
// 随机数与填充控制 - 使用整数阈值随机,与 Conn 一致 // Padding selection matches Conn's threshold-based model.
rng *rand.Rand rng *rand.Rand
paddingThreshold uint64 // 与 Conn 保持一致的随机概率模型 paddingThreshold uint64
padMarker byte padMarker byte
padPool []byte padPool []byte
} }
@@ -95,7 +95,6 @@ func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn {
return pc return pc
} }
// maybeAddPadding 内联辅助:根据概率阈值插入 padding
func (pc *PackedConn) maybeAddPadding(out []byte) []byte { func (pc *PackedConn) maybeAddPadding(out []byte) []byte {
if shouldPad(pc.rng, pc.paddingThreshold) { if shouldPad(pc.rng, pc.paddingThreshold) {
out = append(out, pc.getPaddingByte()) out = append(out, pc.getPaddingByte())
@@ -103,7 +102,73 @@ func (pc *PackedConn) maybeAddPadding(out []byte) []byte {
return out return out
} }
// Write 极致优化版 - 批量处理 12 字节 func (pc *PackedConn) appendGroup(out []byte, group byte) []byte {
out = pc.maybeAddPadding(out)
return append(out, pc.encodeGroup(group))
}
func (pc *PackedConn) appendForcedPadding(out []byte) []byte {
return append(out, pc.getPaddingByte())
}
func (pc *PackedConn) nextProtectedPrefixGap() int {
return 1 + pc.rng.Intn(2)
}
func (pc *PackedConn) writeProtectedPrefix(out []byte, p []byte) ([]byte, int) {
if len(p) == 0 {
return out, 0
}
limit := len(p)
if limit > packedProtectedPrefixBytes {
limit = packedProtectedPrefixBytes
}
for padCount := 0; padCount < 1+pc.rng.Intn(2); padCount++ {
out = pc.appendForcedPadding(out)
}
gap := pc.nextProtectedPrefixGap()
effective := 0
for i := 0; i < limit; i++ {
pc.bitBuf = (pc.bitBuf << 8) | uint64(p[i])
pc.bitCount += 8
for pc.bitCount >= 6 {
pc.bitCount -= 6
group := byte(pc.bitBuf >> pc.bitCount)
if pc.bitCount == 0 {
pc.bitBuf = 0
} else {
pc.bitBuf &= (1 << pc.bitCount) - 1
}
out = pc.appendGroup(out, group&0x3F)
}
effective++
if effective >= gap {
out = pc.appendForcedPadding(out)
effective = 0
gap = pc.nextProtectedPrefixGap()
}
}
return out, limit
}
func (pc *PackedConn) drainPendingData(dst []byte) int {
n := copy(dst, pc.pendingData)
if n == len(pc.pendingData) {
pc.pendingData = pc.pendingData[:0]
return n
}
remaining := len(pc.pendingData) - n
copy(pc.pendingData, pc.pendingData[n:])
pc.pendingData = pc.pendingData[:remaining]
return n
}
func (pc *PackedConn) Write(p []byte) (int, error) { func (pc *PackedConn) Write(p []byte) (int, error) {
if len(p) == 0 { if len(p) == 0 {
return 0, nil return 0, nil
@@ -112,20 +177,19 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
pc.writeMu.Lock() pc.writeMu.Lock()
defer pc.writeMu.Unlock() defer pc.writeMu.Unlock()
// 1. 预分配内存,避免 append 导致的多次扩容
// 预估:原数据 * 1.5 (4/3 + padding 余量)
needed := len(p)*3/2 + 32 needed := len(p)*3/2 + 32
if cap(pc.writeBuf) < needed { if cap(pc.writeBuf) < needed {
pc.writeBuf = make([]byte, 0, needed) pc.writeBuf = make([]byte, 0, needed)
} }
out := pc.writeBuf[:0] out := pc.writeBuf[:0]
i := 0 var prefixN int
out, prefixN = pc.writeProtectedPrefix(out, p)
i := prefixN
n := len(p) n := len(p)
// 2. 头部对齐处理 (Slow Path)
for pc.bitCount > 0 && i < n { for pc.bitCount > 0 && i < n {
out = pc.maybeAddPadding(out)
b := p[i] b := p[i]
i++ i++
pc.bitBuf = (pc.bitBuf << 8) | uint64(b) pc.bitBuf = (pc.bitBuf << 8) | uint64(b)
@@ -138,14 +202,11 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
} else { } else {
pc.bitBuf &= (1 << pc.bitCount) - 1 pc.bitBuf &= (1 << pc.bitCount) - 1
} }
out = pc.maybeAddPadding(out) out = pc.appendGroup(out, group&0x3F)
out = append(out, pc.encodeGroup(group&0x3F))
} }
} }
// 3. 极速批量处理 (Fast Path) - 每次处理 12 字节 → 生成 16 个编码组
for i+11 < n { for i+11 < n {
// 处理 4 组,每组 3 字节
for batch := 0; batch < 4; batch++ { for batch := 0; batch < 4; batch++ {
b1, b2, b3 := p[i], p[i+1], p[i+2] b1, b2, b3 := p[i], p[i+1], p[i+2]
i += 3 i += 3
@@ -155,19 +216,13 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
g3 := ((b2 & 0x0F) << 2) | ((b3 >> 6) & 0x03) g3 := ((b2 & 0x0F) << 2) | ((b3 >> 6) & 0x03)
g4 := b3 & 0x3F g4 := b3 & 0x3F
// 每个组之前都有概率插入 padding out = pc.appendGroup(out, g1)
out = pc.maybeAddPadding(out) out = pc.appendGroup(out, g2)
out = append(out, pc.encodeGroup(g1)) out = pc.appendGroup(out, g3)
out = pc.maybeAddPadding(out) out = pc.appendGroup(out, g4)
out = append(out, pc.encodeGroup(g2))
out = pc.maybeAddPadding(out)
out = append(out, pc.encodeGroup(g3))
out = pc.maybeAddPadding(out)
out = append(out, pc.encodeGroup(g4))
} }
} }
// 4. 处理剩余的 3 字节块
for i+2 < n { for i+2 < n {
b1, b2, b3 := p[i], p[i+1], p[i+2] b1, b2, b3 := p[i], p[i+1], p[i+2]
i += 3 i += 3
@@ -177,17 +232,12 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
g3 := ((b2 & 0x0F) << 2) | ((b3 >> 6) & 0x03) g3 := ((b2 & 0x0F) << 2) | ((b3 >> 6) & 0x03)
g4 := b3 & 0x3F g4 := b3 & 0x3F
out = pc.maybeAddPadding(out) out = pc.appendGroup(out, g1)
out = append(out, pc.encodeGroup(g1)) out = pc.appendGroup(out, g2)
out = pc.maybeAddPadding(out) out = pc.appendGroup(out, g3)
out = append(out, pc.encodeGroup(g2)) out = pc.appendGroup(out, g4)
out = pc.maybeAddPadding(out)
out = append(out, pc.encodeGroup(g3))
out = pc.maybeAddPadding(out)
out = append(out, pc.encodeGroup(g4))
} }
// 5. 尾部处理 (Tail Path) - 处理剩余的 1 或 2 个字节
for ; i < n; i++ { for ; i < n; i++ {
b := p[i] b := p[i]
pc.bitBuf = (pc.bitBuf << 8) | uint64(b) pc.bitBuf = (pc.bitBuf << 8) | uint64(b)
@@ -200,35 +250,28 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
} else { } else {
pc.bitBuf &= (1 << pc.bitCount) - 1 pc.bitBuf &= (1 << pc.bitCount) - 1
} }
out = pc.maybeAddPadding(out) out = pc.appendGroup(out, group&0x3F)
out = append(out, pc.encodeGroup(group&0x3F))
} }
} }
// 6. 处理残留位
if pc.bitCount > 0 { if pc.bitCount > 0 {
out = pc.maybeAddPadding(out)
group := byte(pc.bitBuf << (6 - pc.bitCount)) group := byte(pc.bitBuf << (6 - pc.bitCount))
pc.bitBuf = 0 pc.bitBuf = 0
pc.bitCount = 0 pc.bitCount = 0
out = append(out, pc.encodeGroup(group&0x3F)) out = pc.appendGroup(out, group&0x3F)
out = append(out, pc.padMarker) out = append(out, pc.padMarker)
} }
// 尾部可能添加 padding
out = pc.maybeAddPadding(out) out = pc.maybeAddPadding(out)
// 发送数据
if len(out) > 0 { if len(out) > 0 {
_, err := pc.Conn.Write(out)
pc.writeBuf = out[:0] pc.writeBuf = out[:0]
return len(p), err return len(p), writeFull(pc.Conn, out)
} }
pc.writeBuf = out[:0] pc.writeBuf = out[:0]
return len(p), nil return len(p), nil
} }
// Flush 处理最后不足 6 bit 的情况
func (pc *PackedConn) Flush() error { func (pc *PackedConn) Flush() error {
pc.writeMu.Lock() pc.writeMu.Lock()
defer pc.writeMu.Unlock() defer pc.writeMu.Unlock()
@@ -243,38 +286,34 @@ func (pc *PackedConn) Flush() error {
out = append(out, pc.padMarker) out = append(out, pc.padMarker)
} }
// 尾部随机添加 padding
out = pc.maybeAddPadding(out) out = pc.maybeAddPadding(out)
if len(out) > 0 { if len(out) > 0 {
_, err := pc.Conn.Write(out)
pc.writeBuf = out[:0] pc.writeBuf = out[:0]
return err return writeFull(pc.Conn, out)
} }
return nil return nil
} }
// Read 优化版:减少切片操作,避免内存泄漏 func writeFull(w io.Writer, b []byte) error {
func (pc *PackedConn) Read(p []byte) (int, error) { for len(b) > 0 {
// 1. 优先返回待处理区的数据 n, err := w.Write(b)
if len(pc.pendingData) > 0 { if err != nil {
n := copy(p, pc.pendingData) return err
if n == len(pc.pendingData) {
pc.pendingData = pc.pendingData[:0]
} else {
// 优化:移动剩余数据到数组头部,避免切片指向中间导致内存泄漏
remaining := len(pc.pendingData) - n
copy(pc.pendingData, pc.pendingData[n:])
pc.pendingData = pc.pendingData[:remaining]
} }
return n, nil b = b[n:]
}
return nil
}
func (pc *PackedConn) Read(p []byte) (int, error) {
if len(pc.pendingData) > 0 {
return pc.drainPendingData(p), nil
} }
// 2. 循环读取直到解出数据或出错
for { for {
nr, rErr := pc.reader.Read(pc.rawBuf) nr, rErr := pc.reader.Read(pc.rawBuf)
if nr > 0 { if nr > 0 {
// 缓存频繁访问的变量
rBuf := pc.readBitBuf rBuf := pc.readBitBuf
rBits := pc.readBits rBits := pc.readBits
padMarker := pc.padMarker padMarker := pc.padMarker
@@ -324,24 +363,13 @@ func (pc *PackedConn) Read(p []byte) (int, error) {
} }
} }
// 3. 返回解码后的数据 - 优化:避免底层数组泄漏 return pc.drainPendingData(p), nil
n := copy(p, pc.pendingData)
if n == len(pc.pendingData) {
pc.pendingData = pc.pendingData[:0]
} else {
remaining := len(pc.pendingData) - n
copy(pc.pendingData, pc.pendingData[n:])
pc.pendingData = pc.pendingData[:remaining]
}
return n, nil
} }
// getPaddingByte 从 Pool 中随机取 Padding 字节
func (pc *PackedConn) getPaddingByte() byte { func (pc *PackedConn) getPaddingByte() byte {
return pc.padPool[pc.rng.Intn(len(pc.padPool))] return pc.padPool[pc.rng.Intn(len(pc.padPool))]
} }
// encodeGroup 编码 6-bit 组
func (pc *PackedConn) encodeGroup(group byte) byte { func (pc *PackedConn) encodeGroup(group byte) byte {
return pc.table.layout.encodeGroup(group) return pc.table.layout.encodeGroup(group)
} }

View File

@@ -0,0 +1,91 @@
package sudoku
import (
"bytes"
"io"
"math/rand"
"net"
"testing"
"time"
)
type mockConn struct {
readBuf []byte
writeBuf []byte
}
func (c *mockConn) Read(p []byte) (int, error) {
if len(c.readBuf) == 0 {
return 0, io.EOF
}
n := copy(p, c.readBuf)
c.readBuf = c.readBuf[n:]
return n, nil
}
func (c *mockConn) Write(p []byte) (int, error) {
c.writeBuf = append(c.writeBuf, p...)
return len(p), nil
}
func (c *mockConn) Close() error { return nil }
func (c *mockConn) LocalAddr() net.Addr { return nil }
func (c *mockConn) RemoteAddr() net.Addr { return nil }
func (c *mockConn) SetDeadline(time.Time) error { return nil }
func (c *mockConn) SetReadDeadline(time.Time) error { return nil }
func (c *mockConn) SetWriteDeadline(time.Time) error { return nil }
func TestPackedConn_ProtectedPrefixPadding(t *testing.T) {
table := NewTable("packed-prefix-seed", "prefer_ascii")
mock := &mockConn{}
writer := NewPackedConn(mock, table, 0, 0)
writer.rng = rand.New(rand.NewSource(1))
payload := bytes.Repeat([]byte{0}, 32)
if _, err := writer.Write(payload); err != nil {
t.Fatalf("write: %v", err)
}
wire := append([]byte(nil), mock.writeBuf...)
if len(wire) < 20 {
t.Fatalf("wire too short: %d", len(wire))
}
firstHint := -1
nonHintCount := 0
maxHintRun := 0
currentHintRun := 0
for i, b := range wire[:20] {
if table.layout.isHint(b) {
if firstHint == -1 {
firstHint = i
}
currentHintRun++
if currentHintRun > maxHintRun {
maxHintRun = currentHintRun
}
continue
}
nonHintCount++
currentHintRun = 0
}
if firstHint < 1 || firstHint > 2 {
t.Fatalf("expected 1-2 leading padding bytes, first hint index=%d", firstHint)
}
if nonHintCount < 6 {
t.Fatalf("expected dense prefix padding, got only %d non-hint bytes in first 20", nonHintCount)
}
if maxHintRun > 3 {
t.Fatalf("prefix still exposes long hint run: %d", maxHintRun)
}
reader := NewPackedConn(&mockConn{readBuf: wire}, table, 0, 0)
decoded := make([]byte, len(payload))
if _, err := io.ReadFull(reader, decoded); err != nil {
t.Fatalf("read back: %v", err)
}
if !bytes.Equal(decoded, payload) {
t.Fatalf("roundtrip mismatch")
}
}