mirror of
https://github.com/MetaCubeX/mihomo.git
synced 2026-03-11 02:49:57 +00:00
chore: reduce the inherent 1rtt in httpmask mode for sudoku (#2610)
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
)
|
||||
@@ -55,13 +57,15 @@ type RecordConn struct {
|
||||
recvAEADEpoch uint32
|
||||
|
||||
// Send direction state.
|
||||
sendEpoch uint32
|
||||
sendSeq uint64
|
||||
sendBytes int64
|
||||
sendEpoch uint32
|
||||
sendSeq uint64
|
||||
sendBytes int64
|
||||
sendEpochUpdates uint32
|
||||
|
||||
// Receive direction state.
|
||||
recvEpoch uint32
|
||||
recvSeq uint64
|
||||
recvEpoch uint32
|
||||
recvSeq uint64
|
||||
recvInitialized bool
|
||||
|
||||
readBuf bytes.Buffer
|
||||
|
||||
@@ -105,6 +109,9 @@ func NewRecordConn(conn net.Conn, method string, baseSend, baseRecv []byte) (*Re
|
||||
}
|
||||
rc := &RecordConn{Conn: conn, method: method}
|
||||
rc.keys = recordKeys{baseSend: cloneBytes(baseSend), baseRecv: cloneBytes(baseRecv)}
|
||||
if err := rc.resetTrafficState(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rc, nil
|
||||
}
|
||||
|
||||
@@ -127,11 +134,9 @@ func (c *RecordConn) Rekey(baseSend, baseRecv []byte) error {
|
||||
defer c.writeMu.Unlock()
|
||||
|
||||
c.keys = recordKeys{baseSend: cloneBytes(baseSend), baseRecv: cloneBytes(baseRecv)}
|
||||
c.sendEpoch = 0
|
||||
c.sendSeq = 0
|
||||
c.sendBytes = 0
|
||||
c.recvEpoch = 0
|
||||
c.recvSeq = 0
|
||||
if err := c.resetTrafficState(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.readBuf.Reset()
|
||||
|
||||
c.sendAEAD = nil
|
||||
@@ -141,6 +146,21 @@ func (c *RecordConn) Rekey(baseSend, baseRecv []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *RecordConn) resetTrafficState() error {
|
||||
sendEpoch, sendSeq, err := randomRecordCounters()
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialize record counters: %w", err)
|
||||
}
|
||||
c.sendEpoch = sendEpoch
|
||||
c.sendSeq = sendSeq
|
||||
c.sendBytes = 0
|
||||
c.sendEpochUpdates = 0
|
||||
c.recvEpoch = 0
|
||||
c.recvSeq = 0
|
||||
c.recvInitialized = false
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeAEADMethod(method string) string {
|
||||
switch method {
|
||||
case "", "chacha20-poly1305":
|
||||
@@ -166,6 +186,44 @@ func cloneBytes(b []byte) []byte {
|
||||
return append([]byte(nil), b...)
|
||||
}
|
||||
|
||||
func randomRecordCounters() (uint32, uint64, error) {
|
||||
epoch, err := randomNonZeroUint32()
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
seq, err := randomNonZeroUint64()
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
return epoch, seq, nil
|
||||
}
|
||||
|
||||
func randomNonZeroUint32() (uint32, error) {
|
||||
var b [4]byte
|
||||
for {
|
||||
if _, err := io.ReadFull(rand.Reader, b[:]); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
v := binary.BigEndian.Uint32(b[:])
|
||||
if v != 0 && v != ^uint32(0) {
|
||||
return v, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func randomNonZeroUint64() (uint64, error) {
|
||||
var b [8]byte
|
||||
for {
|
||||
if _, err := io.ReadFull(rand.Reader, b[:]); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
v := binary.BigEndian.Uint64(b[:])
|
||||
if v != 0 && v != ^uint64(0) {
|
||||
return v, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RecordConn) newAEADFor(base []byte, epoch uint32) (cipher.AEAD, error) {
|
||||
if c.method == "none" {
|
||||
return nil, nil
|
||||
@@ -209,17 +267,49 @@ func deriveEpochKey(base []byte, epoch uint32, method string) []byte {
|
||||
return mac.Sum(nil)
|
||||
}
|
||||
|
||||
func (c *RecordConn) maybeBumpSendEpochLocked(addedPlain int) {
|
||||
if KeyUpdateAfterBytes <= 0 || c.method == "none" {
|
||||
return
|
||||
func (c *RecordConn) maybeBumpSendEpochLocked(addedPlain int) error {
|
||||
ku := atomic.LoadInt64(&KeyUpdateAfterBytes)
|
||||
if ku <= 0 || c.method == "none" {
|
||||
return nil
|
||||
}
|
||||
c.sendBytes += int64(addedPlain)
|
||||
threshold := KeyUpdateAfterBytes * int64(c.sendEpoch+1)
|
||||
threshold := ku * int64(c.sendEpochUpdates+1)
|
||||
if c.sendBytes < threshold {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
c.sendEpoch++
|
||||
c.sendSeq = 0
|
||||
c.sendEpochUpdates++
|
||||
nextSeq, err := randomNonZeroUint64()
|
||||
if err != nil {
|
||||
return fmt.Errorf("rotate record seq: %w", err)
|
||||
}
|
||||
c.sendSeq = nextSeq
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *RecordConn) validateRecvPosition(epoch uint32, seq uint64) error {
|
||||
if !c.recvInitialized {
|
||||
return nil
|
||||
}
|
||||
if epoch < c.recvEpoch {
|
||||
return fmt.Errorf("replayed epoch: got %d want >=%d", epoch, c.recvEpoch)
|
||||
}
|
||||
if epoch == c.recvEpoch && seq != c.recvSeq {
|
||||
return fmt.Errorf("out of order: epoch=%d got=%d want=%d", epoch, seq, c.recvSeq)
|
||||
}
|
||||
if epoch > c.recvEpoch {
|
||||
const maxJump = 8
|
||||
if epoch-c.recvEpoch > maxJump {
|
||||
return fmt.Errorf("epoch jump too large: got=%d want<=%d", epoch-c.recvEpoch, maxJump)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *RecordConn) markRecvPosition(epoch uint32, seq uint64) {
|
||||
c.recvEpoch = epoch
|
||||
c.recvSeq = seq + 1
|
||||
c.recvInitialized = true
|
||||
}
|
||||
|
||||
func (c *RecordConn) Write(p []byte) (int, error) {
|
||||
@@ -282,7 +372,9 @@ func (c *RecordConn) Write(p []byte) (int, error) {
|
||||
}
|
||||
|
||||
total += n
|
||||
c.maybeBumpSendEpochLocked(n)
|
||||
if err := c.maybeBumpSendEpochLocked(n); err != nil {
|
||||
return total, err
|
||||
}
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
@@ -324,31 +416,17 @@ func (c *RecordConn) Read(p []byte) (int, error) {
|
||||
epoch := binary.BigEndian.Uint32(header[:4])
|
||||
seq := binary.BigEndian.Uint64(header[4:])
|
||||
|
||||
if epoch < c.recvEpoch {
|
||||
return 0, fmt.Errorf("replayed epoch: got %d want >=%d", epoch, c.recvEpoch)
|
||||
}
|
||||
if epoch == c.recvEpoch && seq != c.recvSeq {
|
||||
return 0, fmt.Errorf("out of order: epoch=%d got=%d want=%d", epoch, seq, c.recvSeq)
|
||||
}
|
||||
if epoch > c.recvEpoch {
|
||||
const maxJump = 8
|
||||
if epoch-c.recvEpoch > maxJump {
|
||||
return 0, fmt.Errorf("epoch jump too large: got=%d want<=%d", epoch-c.recvEpoch, maxJump)
|
||||
}
|
||||
c.recvEpoch = epoch
|
||||
c.recvSeq = 0
|
||||
if seq != 0 {
|
||||
return 0, fmt.Errorf("out of order: epoch advanced to %d but seq=%d", epoch, seq)
|
||||
}
|
||||
if err := c.validateRecvPosition(epoch, seq); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if c.recvAEAD == nil || c.recvAEADEpoch != c.recvEpoch {
|
||||
a, err := c.newAEADFor(c.keys.baseRecv, c.recvEpoch)
|
||||
if c.recvAEAD == nil || c.recvAEADEpoch != epoch {
|
||||
a, err := c.newAEADFor(c.keys.baseRecv, epoch)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
c.recvAEAD = a
|
||||
c.recvAEADEpoch = c.recvEpoch
|
||||
c.recvAEADEpoch = epoch
|
||||
}
|
||||
aead := c.recvAEAD
|
||||
|
||||
@@ -356,7 +434,7 @@ func (c *RecordConn) Read(p []byte) (int, error) {
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("decryption failed: epoch=%d seq=%d: %w", epoch, seq, err)
|
||||
}
|
||||
c.recvSeq++
|
||||
c.markRecvPosition(epoch, seq)
|
||||
|
||||
c.readBuf.Write(plaintext)
|
||||
return c.readBuf.Read(p)
|
||||
|
||||
86
transport/sudoku/crypto/record_conn_test.go
Normal file
86
transport/sudoku/crypto/record_conn_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type captureConn struct {
|
||||
bytes.Buffer
|
||||
}
|
||||
|
||||
func (c *captureConn) Read(_ []byte) (int, error) { return 0, io.EOF }
|
||||
func (c *captureConn) Write(p []byte) (int, error) { return c.Buffer.Write(p) }
|
||||
func (c *captureConn) Close() error { return nil }
|
||||
func (c *captureConn) LocalAddr() net.Addr { return nil }
|
||||
func (c *captureConn) RemoteAddr() net.Addr { return nil }
|
||||
func (c *captureConn) SetDeadline(time.Time) error { return nil }
|
||||
func (c *captureConn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *captureConn) SetWriteDeadline(time.Time) error { return nil }
|
||||
|
||||
type replayConn struct {
|
||||
reader *bytes.Reader
|
||||
}
|
||||
|
||||
func (c *replayConn) Read(p []byte) (int, error) { return c.reader.Read(p) }
|
||||
func (c *replayConn) Write(p []byte) (int, error) { return len(p), nil }
|
||||
func (c *replayConn) Close() error { return nil }
|
||||
func (c *replayConn) LocalAddr() net.Addr { return nil }
|
||||
func (c *replayConn) RemoteAddr() net.Addr { return nil }
|
||||
func (c *replayConn) SetDeadline(time.Time) error { return nil }
|
||||
func (c *replayConn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *replayConn) SetWriteDeadline(time.Time) error { return nil }
|
||||
|
||||
func TestRecordConn_FirstFrameUsesRandomizedCounters(t *testing.T) {
|
||||
pskSend := sha256.Sum256([]byte("record-send"))
|
||||
pskRecv := sha256.Sum256([]byte("record-recv"))
|
||||
|
||||
raw := &captureConn{}
|
||||
writer, err := NewRecordConn(raw, "chacha20-poly1305", pskSend[:], pskRecv[:])
|
||||
if err != nil {
|
||||
t.Fatalf("new writer: %v", err)
|
||||
}
|
||||
|
||||
if writer.sendEpoch == 0 || writer.sendSeq == 0 {
|
||||
t.Fatalf("expected non-zero randomized counters, got epoch=%d seq=%d", writer.sendEpoch, writer.sendSeq)
|
||||
}
|
||||
|
||||
want := []byte("record prefix camouflage")
|
||||
if _, err := writer.Write(want); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
wire := raw.Bytes()
|
||||
if len(wire) < 2+recordHeaderSize {
|
||||
t.Fatalf("short frame: %d", len(wire))
|
||||
}
|
||||
|
||||
bodyLen := int(binary.BigEndian.Uint16(wire[:2]))
|
||||
if bodyLen != len(wire)-2 {
|
||||
t.Fatalf("body len mismatch: got %d want %d", bodyLen, len(wire)-2)
|
||||
}
|
||||
|
||||
epoch := binary.BigEndian.Uint32(wire[2:6])
|
||||
seq := binary.BigEndian.Uint64(wire[6:14])
|
||||
if epoch == 0 || seq == 0 {
|
||||
t.Fatalf("wire header still starts from zero: epoch=%d seq=%d", epoch, seq)
|
||||
}
|
||||
|
||||
reader, err := NewRecordConn(&replayConn{reader: bytes.NewReader(wire)}, "chacha20-poly1305", pskRecv[:], pskSend[:])
|
||||
if err != nil {
|
||||
t.Fatalf("new reader: %v", err)
|
||||
}
|
||||
|
||||
got := make([]byte, len(want))
|
||||
if _, err := io.ReadFull(reader, got); err != nil {
|
||||
t.Fatalf("read: %v", err)
|
||||
}
|
||||
if !bytes.Equal(got, want) {
|
||||
t.Fatalf("plaintext mismatch: got %q want %q", got, want)
|
||||
}
|
||||
}
|
||||
345
transport/sudoku/early_handshake.go
Normal file
345
transport/sudoku/early_handshake.go
Normal file
@@ -0,0 +1,345 @@
|
||||
package sudoku
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdh"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/metacubex/mihomo/transport/sudoku/crypto"
|
||||
httpmaskobfs "github.com/metacubex/mihomo/transport/sudoku/obfs/httpmask"
|
||||
sudokuobfs "github.com/metacubex/mihomo/transport/sudoku/obfs/sudoku"
|
||||
)
|
||||
|
||||
const earlyKIPHandshakeTTL = 60 * time.Second
|
||||
|
||||
type EarlyCodecConfig struct {
|
||||
PSK string
|
||||
AEAD string
|
||||
EnablePureDownlink bool
|
||||
PaddingMin int
|
||||
PaddingMax int
|
||||
}
|
||||
|
||||
type EarlyClientState struct {
|
||||
RequestPayload []byte
|
||||
|
||||
cfg EarlyCodecConfig
|
||||
table *sudokuobfs.Table
|
||||
nonce [kipHelloNonceSize]byte
|
||||
ephemeral *ecdh.PrivateKey
|
||||
sessionC2S []byte
|
||||
sessionS2C []byte
|
||||
responseSet bool
|
||||
}
|
||||
|
||||
type EarlyServerState struct {
|
||||
ResponsePayload []byte
|
||||
UserHash string
|
||||
|
||||
cfg EarlyCodecConfig
|
||||
table *sudokuobfs.Table
|
||||
sessionC2S []byte
|
||||
sessionS2C []byte
|
||||
}
|
||||
|
||||
type ReplayAllowFunc func(userHash string, nonce [kipHelloNonceSize]byte, now time.Time) bool
|
||||
|
||||
type earlyMemoryConn struct {
|
||||
reader *bytes.Reader
|
||||
write bytes.Buffer
|
||||
}
|
||||
|
||||
func newEarlyMemoryConn(readBuf []byte) *earlyMemoryConn {
|
||||
return &earlyMemoryConn{reader: bytes.NewReader(readBuf)}
|
||||
}
|
||||
|
||||
func (c *earlyMemoryConn) Read(p []byte) (int, error) {
|
||||
if c == nil || c.reader == nil {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
func (c *earlyMemoryConn) Write(p []byte) (int, error) {
|
||||
if c == nil {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
return c.write.Write(p)
|
||||
}
|
||||
|
||||
func (c *earlyMemoryConn) Close() error { return nil }
|
||||
func (c *earlyMemoryConn) LocalAddr() net.Addr { return earlyDummyAddr("local") }
|
||||
func (c *earlyMemoryConn) RemoteAddr() net.Addr { return earlyDummyAddr("remote") }
|
||||
func (c *earlyMemoryConn) SetDeadline(time.Time) error { return nil }
|
||||
func (c *earlyMemoryConn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *earlyMemoryConn) SetWriteDeadline(time.Time) error { return nil }
|
||||
func (c *earlyMemoryConn) Written() []byte { return append([]byte(nil), c.write.Bytes()...) }
|
||||
|
||||
type earlyDummyAddr string
|
||||
|
||||
func (a earlyDummyAddr) Network() string { return string(a) }
|
||||
func (a earlyDummyAddr) String() string { return string(a) }
|
||||
|
||||
func buildEarlyClientObfsConn(raw net.Conn, cfg EarlyCodecConfig, table *sudokuobfs.Table) net.Conn {
|
||||
base := sudokuobfs.NewConn(raw, table, cfg.PaddingMin, cfg.PaddingMax, false)
|
||||
if cfg.EnablePureDownlink {
|
||||
return base
|
||||
}
|
||||
packed := sudokuobfs.NewPackedConn(raw, table, cfg.PaddingMin, cfg.PaddingMax)
|
||||
return newDirectionalConn(raw, packed, base)
|
||||
}
|
||||
|
||||
func buildEarlyServerObfsConn(raw net.Conn, cfg EarlyCodecConfig, table *sudokuobfs.Table) net.Conn {
|
||||
uplink := sudokuobfs.NewConn(raw, table, cfg.PaddingMin, cfg.PaddingMax, false)
|
||||
if cfg.EnablePureDownlink {
|
||||
return uplink
|
||||
}
|
||||
packed := sudokuobfs.NewPackedConn(raw, table, cfg.PaddingMin, cfg.PaddingMax)
|
||||
return newDirectionalConn(raw, uplink, packed, packed.Flush)
|
||||
}
|
||||
|
||||
func NewEarlyClientState(cfg EarlyCodecConfig, table *sudokuobfs.Table, userHash [kipHelloUserHashSize]byte, feats uint32) (*EarlyClientState, error) {
|
||||
if table == nil {
|
||||
return nil, fmt.Errorf("nil table")
|
||||
}
|
||||
|
||||
curve := ecdh.X25519()
|
||||
ephemeral, err := curve.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ecdh generate failed: %w", err)
|
||||
}
|
||||
|
||||
var nonce [kipHelloNonceSize]byte
|
||||
if _, err := rand.Read(nonce[:]); err != nil {
|
||||
return nil, fmt.Errorf("nonce generate failed: %w", err)
|
||||
}
|
||||
|
||||
var clientPub [kipHelloPubSize]byte
|
||||
copy(clientPub[:], ephemeral.PublicKey().Bytes())
|
||||
hello := &KIPClientHello{
|
||||
Timestamp: time.Now(),
|
||||
UserHash: userHash,
|
||||
Nonce: nonce,
|
||||
ClientPub: clientPub,
|
||||
Features: feats,
|
||||
}
|
||||
|
||||
mem := newEarlyMemoryConn(nil)
|
||||
obfsConn := buildEarlyClientObfsConn(mem, cfg, table)
|
||||
pskC2S, pskS2C := derivePSKDirectionalBases(cfg.PSK)
|
||||
rc, err := crypto.NewRecordConn(obfsConn, cfg.AEAD, pskC2S, pskS2C)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("client early crypto setup failed: %w", err)
|
||||
}
|
||||
if err := WriteKIPMessage(rc, KIPTypeClientHello, hello.EncodePayload()); err != nil {
|
||||
return nil, fmt.Errorf("write early client hello failed: %w", err)
|
||||
}
|
||||
|
||||
return &EarlyClientState{
|
||||
RequestPayload: mem.Written(),
|
||||
cfg: cfg,
|
||||
table: table,
|
||||
nonce: nonce,
|
||||
ephemeral: ephemeral,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *EarlyClientState) ProcessResponse(payload []byte) error {
|
||||
if s == nil {
|
||||
return fmt.Errorf("nil client state")
|
||||
}
|
||||
|
||||
mem := newEarlyMemoryConn(payload)
|
||||
obfsConn := buildEarlyClientObfsConn(mem, s.cfg, s.table)
|
||||
pskC2S, pskS2C := derivePSKDirectionalBases(s.cfg.PSK)
|
||||
rc, err := crypto.NewRecordConn(obfsConn, s.cfg.AEAD, pskC2S, pskS2C)
|
||||
if err != nil {
|
||||
return fmt.Errorf("client early crypto setup failed: %w", err)
|
||||
}
|
||||
|
||||
msg, err := ReadKIPMessage(rc)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read early server hello failed: %w", err)
|
||||
}
|
||||
if msg.Type != KIPTypeServerHello {
|
||||
return fmt.Errorf("unexpected early handshake message: %d", msg.Type)
|
||||
}
|
||||
sh, err := DecodeKIPServerHelloPayload(msg.Payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decode early server hello failed: %w", err)
|
||||
}
|
||||
if sh.Nonce != s.nonce {
|
||||
return fmt.Errorf("early handshake nonce mismatch")
|
||||
}
|
||||
|
||||
shared, err := x25519SharedSecret(s.ephemeral, sh.ServerPub[:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("ecdh failed: %w", err)
|
||||
}
|
||||
s.sessionC2S, s.sessionS2C, err = deriveSessionDirectionalBases(s.cfg.PSK, shared, s.nonce)
|
||||
if err != nil {
|
||||
return fmt.Errorf("derive session keys failed: %w", err)
|
||||
}
|
||||
s.responseSet = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *EarlyClientState) WrapConn(raw net.Conn) (net.Conn, error) {
|
||||
if s == nil {
|
||||
return nil, fmt.Errorf("nil client state")
|
||||
}
|
||||
if !s.responseSet {
|
||||
return nil, fmt.Errorf("early handshake not completed")
|
||||
}
|
||||
|
||||
obfsConn := buildEarlyClientObfsConn(raw, s.cfg, s.table)
|
||||
rc, err := crypto.NewRecordConn(obfsConn, s.cfg.AEAD, s.sessionC2S, s.sessionS2C)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setup client session crypto failed: %w", err)
|
||||
}
|
||||
return rc, nil
|
||||
}
|
||||
|
||||
func (s *EarlyClientState) Ready() bool {
|
||||
return s != nil && s.responseSet
|
||||
}
|
||||
|
||||
func NewHTTPMaskClientEarlyHandshake(cfg EarlyCodecConfig, table *sudokuobfs.Table, userHash [kipHelloUserHashSize]byte, feats uint32) (*httpmaskobfs.ClientEarlyHandshake, error) {
|
||||
state, err := NewEarlyClientState(cfg, table, userHash, feats)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &httpmaskobfs.ClientEarlyHandshake{
|
||||
RequestPayload: state.RequestPayload,
|
||||
HandleResponse: state.ProcessResponse,
|
||||
Ready: state.Ready,
|
||||
WrapConn: state.WrapConn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ProcessEarlyClientPayload(cfg EarlyCodecConfig, tables []*sudokuobfs.Table, payload []byte, allowReplay ReplayAllowFunc) (*EarlyServerState, error) {
|
||||
if len(payload) == 0 {
|
||||
return nil, fmt.Errorf("empty early payload")
|
||||
}
|
||||
if len(tables) == 0 {
|
||||
return nil, fmt.Errorf("no tables configured")
|
||||
}
|
||||
|
||||
var firstErr error
|
||||
for _, table := range tables {
|
||||
state, err := processEarlyClientPayloadForTable(cfg, table, payload, allowReplay)
|
||||
if err == nil {
|
||||
return state, nil
|
||||
}
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
if firstErr == nil {
|
||||
firstErr = fmt.Errorf("early handshake probe failed")
|
||||
}
|
||||
return nil, firstErr
|
||||
}
|
||||
|
||||
func processEarlyClientPayloadForTable(cfg EarlyCodecConfig, table *sudokuobfs.Table, payload []byte, allowReplay ReplayAllowFunc) (*EarlyServerState, error) {
|
||||
mem := newEarlyMemoryConn(payload)
|
||||
obfsConn := buildEarlyServerObfsConn(mem, cfg, table)
|
||||
pskC2S, pskS2C := derivePSKDirectionalBases(cfg.PSK)
|
||||
rc, err := crypto.NewRecordConn(obfsConn, cfg.AEAD, pskS2C, pskC2S)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err := ReadKIPMessage(rc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if msg.Type != KIPTypeClientHello {
|
||||
return nil, fmt.Errorf("unexpected handshake message: %d", msg.Type)
|
||||
}
|
||||
ch, err := DecodeKIPClientHelloPayload(msg.Payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if absInt64(time.Now().Unix()-ch.Timestamp.Unix()) > int64(earlyKIPHandshakeTTL.Seconds()) {
|
||||
return nil, fmt.Errorf("time skew/replay")
|
||||
}
|
||||
|
||||
userHash := hex.EncodeToString(ch.UserHash[:])
|
||||
if allowReplay != nil && !allowReplay(userHash, ch.Nonce, time.Now()) {
|
||||
return nil, fmt.Errorf("replay detected")
|
||||
}
|
||||
|
||||
curve := ecdh.X25519()
|
||||
serverEphemeral, err := curve.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ecdh generate failed: %w", err)
|
||||
}
|
||||
shared, err := x25519SharedSecret(serverEphemeral, ch.ClientPub[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ecdh failed: %w", err)
|
||||
}
|
||||
sessionC2S, sessionS2C, err := deriveSessionDirectionalBases(cfg.PSK, shared, ch.Nonce)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("derive session keys failed: %w", err)
|
||||
}
|
||||
|
||||
var serverPub [kipHelloPubSize]byte
|
||||
copy(serverPub[:], serverEphemeral.PublicKey().Bytes())
|
||||
serverHello := &KIPServerHello{
|
||||
Nonce: ch.Nonce,
|
||||
ServerPub: serverPub,
|
||||
SelectedFeats: ch.Features & KIPFeatAll,
|
||||
}
|
||||
|
||||
respMem := newEarlyMemoryConn(nil)
|
||||
respObfs := buildEarlyServerObfsConn(respMem, cfg, table)
|
||||
respConn, err := crypto.NewRecordConn(respObfs, cfg.AEAD, pskS2C, pskC2S)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("server early crypto setup failed: %w", err)
|
||||
}
|
||||
if err := WriteKIPMessage(respConn, KIPTypeServerHello, serverHello.EncodePayload()); err != nil {
|
||||
return nil, fmt.Errorf("write early server hello failed: %w", err)
|
||||
}
|
||||
|
||||
return &EarlyServerState{
|
||||
ResponsePayload: respMem.Written(),
|
||||
UserHash: userHash,
|
||||
cfg: cfg,
|
||||
table: table,
|
||||
sessionC2S: sessionC2S,
|
||||
sessionS2C: sessionS2C,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *EarlyServerState) WrapConn(raw net.Conn) (net.Conn, error) {
|
||||
if s == nil {
|
||||
return nil, fmt.Errorf("nil server state")
|
||||
}
|
||||
obfsConn := buildEarlyServerObfsConn(raw, s.cfg, s.table)
|
||||
rc, err := crypto.NewRecordConn(obfsConn, s.cfg.AEAD, s.sessionS2C, s.sessionC2S)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setup server session crypto failed: %w", err)
|
||||
}
|
||||
return rc, nil
|
||||
}
|
||||
|
||||
func NewHTTPMaskServerEarlyHandshake(cfg EarlyCodecConfig, tables []*sudokuobfs.Table, allowReplay ReplayAllowFunc) *httpmaskobfs.TunnelServerEarlyHandshake {
|
||||
return &httpmaskobfs.TunnelServerEarlyHandshake{
|
||||
Prepare: func(payload []byte) (*httpmaskobfs.PreparedServerEarlyHandshake, error) {
|
||||
state, err := ProcessEarlyClientPayload(cfg, tables, payload, allowReplay)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &httpmaskobfs.PreparedServerEarlyHandshake{
|
||||
ResponsePayload: state.ResponsePayload,
|
||||
WrapConn: state.WrapConn,
|
||||
UserHash: state.UserHash,
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -337,6 +337,9 @@ func ServerHandshake(rawConn net.Conn, cfg *ProtocolConfig) (net.Conn, *Handshak
|
||||
if err := cfg.Validate(); err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid config: %w", err)
|
||||
}
|
||||
if userHash, ok := httpmask.EarlyHandshakeUserHash(rawConn); ok {
|
||||
return rawConn, &HandshakeMeta{UserHash: userHash}, nil
|
||||
}
|
||||
|
||||
handshakeTimeout := time.Duration(cfg.HandshakeTimeoutSeconds) * time.Second
|
||||
if handshakeTimeout <= 0 {
|
||||
|
||||
@@ -14,6 +14,30 @@ type HTTPMaskTunnelServer struct {
|
||||
ts *httpmask.TunnelServer
|
||||
}
|
||||
|
||||
func newHTTPMaskEarlyCodecConfig(cfg *ProtocolConfig, psk string) EarlyCodecConfig {
|
||||
return EarlyCodecConfig{
|
||||
PSK: psk,
|
||||
AEAD: cfg.AEADMethod,
|
||||
EnablePureDownlink: cfg.EnablePureDownlink,
|
||||
PaddingMin: cfg.PaddingMin,
|
||||
PaddingMax: cfg.PaddingMax,
|
||||
}
|
||||
}
|
||||
|
||||
func newClientHTTPMaskEarlyHandshake(cfg *ProtocolConfig) (*httpmask.ClientEarlyHandshake, error) {
|
||||
table, err := pickClientTable(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewHTTPMaskClientEarlyHandshake(
|
||||
newHTTPMaskEarlyCodecConfig(cfg, ClientAEADSeed(cfg.Key)),
|
||||
table,
|
||||
kipUserHashFromKey(cfg.Key),
|
||||
KIPFeatAll,
|
||||
)
|
||||
}
|
||||
|
||||
func NewHTTPMaskTunnelServer(cfg *ProtocolConfig) *HTTPMaskTunnelServer {
|
||||
return newHTTPMaskTunnelServer(cfg, false)
|
||||
}
|
||||
@@ -35,6 +59,11 @@ func newHTTPMaskTunnelServer(cfg *ProtocolConfig, passThroughOnReject bool) *HTT
|
||||
Mode: cfg.HTTPMaskMode,
|
||||
PathRoot: cfg.HTTPMaskPathRoot,
|
||||
AuthKey: ServerAEADSeed(cfg.Key),
|
||||
EarlyHandshake: NewHTTPMaskServerEarlyHandshake(
|
||||
newHTTPMaskEarlyCodecConfig(cfg, ServerAEADSeed(cfg.Key)),
|
||||
cfg.tableCandidates(),
|
||||
globalHandshakeReplay.allow,
|
||||
),
|
||||
// When upstream fallback is enabled, preserve rejected HTTP requests for the caller.
|
||||
PassThroughOnReject: passThroughOnReject,
|
||||
})
|
||||
@@ -101,14 +130,25 @@ func DialHTTPMaskTunnel(ctx context.Context, serverAddress string, cfg *Protocol
|
||||
default:
|
||||
return nil, fmt.Errorf("http-mask-mode=%q does not use http tunnel", cfg.HTTPMaskMode)
|
||||
}
|
||||
var (
|
||||
earlyHandshake *httpmask.ClientEarlyHandshake
|
||||
err error
|
||||
)
|
||||
if upgrade != nil {
|
||||
earlyHandshake, err = newClientHTTPMaskEarlyHandshake(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return httpmask.DialTunnel(ctx, serverAddress, httpmask.TunnelDialOptions{
|
||||
Mode: cfg.HTTPMaskMode,
|
||||
TLSEnabled: cfg.HTTPMaskTLSEnabled,
|
||||
HostOverride: cfg.HTTPMaskHost,
|
||||
PathRoot: cfg.HTTPMaskPathRoot,
|
||||
AuthKey: ClientAEADSeed(cfg.Key),
|
||||
Upgrade: upgrade,
|
||||
Multiplex: cfg.HTTPMaskMultiplex,
|
||||
DialContext: dial,
|
||||
Mode: cfg.HTTPMaskMode,
|
||||
TLSEnabled: cfg.HTTPMaskTLSEnabled,
|
||||
HostOverride: cfg.HTTPMaskHost,
|
||||
PathRoot: cfg.HTTPMaskPathRoot,
|
||||
AuthKey: ClientAEADSeed(cfg.Key),
|
||||
EarlyHandshake: earlyHandshake,
|
||||
Upgrade: upgrade,
|
||||
Multiplex: cfg.HTTPMaskMultiplex,
|
||||
DialContext: dial,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -389,6 +389,68 @@ func TestHTTPMaskTunnel_WS_TCPRoundTrip(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPMaskTunnel_EarlyHandshake_TCPRoundTrip(t *testing.T) {
|
||||
modes := []string{"stream", "poll", "ws"}
|
||||
for _, mode := range modes {
|
||||
t.Run(mode, func(t *testing.T) {
|
||||
key := "tunnel-early-" + mode
|
||||
target := "1.1.1.1:80"
|
||||
|
||||
serverCfg := newTunnelTestTable(t, key)
|
||||
serverCfg.HTTPMaskMode = mode
|
||||
|
||||
addr, stop, errCh := startTunnelServer(t, serverCfg, func(s *ServerSession) error {
|
||||
if s.Type != SessionTypeTCP {
|
||||
return fmt.Errorf("unexpected session type: %v", s.Type)
|
||||
}
|
||||
if s.Target != target {
|
||||
return fmt.Errorf("target mismatch: %s", s.Target)
|
||||
}
|
||||
_, _ = s.Conn.Write([]byte("ok"))
|
||||
return nil
|
||||
})
|
||||
defer stop()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
clientCfg := *serverCfg
|
||||
clientCfg.ServerAddress = addr
|
||||
|
||||
handshakeCfg := clientCfg
|
||||
handshakeCfg.DisableHTTPMask = true
|
||||
tunnelConn, err := DialHTTPMaskTunnel(ctx, clientCfg.ServerAddress, &clientCfg, (&net.Dialer{}).DialContext, func(raw net.Conn) (net.Conn, error) {
|
||||
return ClientHandshake(raw, &handshakeCfg)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("dial tunnel: %v", err)
|
||||
}
|
||||
defer tunnelConn.Close()
|
||||
|
||||
addrBuf, err := EncodeAddress(target)
|
||||
if err != nil {
|
||||
t.Fatalf("encode addr: %v", err)
|
||||
}
|
||||
if err := WriteKIPMessage(tunnelConn, KIPTypeOpenTCP, addrBuf); err != nil {
|
||||
t.Fatalf("write addr: %v", err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 2)
|
||||
if _, err := io.ReadFull(tunnelConn, buf); err != nil {
|
||||
t.Fatalf("read: %v", err)
|
||||
}
|
||||
if string(buf) != "ok" {
|
||||
t.Fatalf("unexpected payload: %q", buf)
|
||||
}
|
||||
|
||||
stop()
|
||||
for err := range errCh {
|
||||
t.Fatalf("server error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPMaskTunnel_Validation(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
cfg.Key = "k"
|
||||
|
||||
174
transport/sudoku/obfs/httpmask/early_handshake.go
Normal file
174
transport/sudoku/obfs/httpmask/early_handshake.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package httpmask
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
tunnelEarlyDataQueryKey = "ed"
|
||||
tunnelEarlyDataHeader = "X-Sudoku-Early"
|
||||
)
|
||||
|
||||
type ClientEarlyHandshake struct {
|
||||
RequestPayload []byte
|
||||
HandleResponse func(payload []byte) error
|
||||
Ready func() bool
|
||||
WrapConn func(raw net.Conn) (net.Conn, error)
|
||||
}
|
||||
|
||||
type TunnelServerEarlyHandshake struct {
|
||||
Prepare func(payload []byte) (*PreparedServerEarlyHandshake, error)
|
||||
}
|
||||
|
||||
type PreparedServerEarlyHandshake struct {
|
||||
ResponsePayload []byte
|
||||
WrapConn func(raw net.Conn) (net.Conn, error)
|
||||
UserHash string
|
||||
}
|
||||
|
||||
type earlyHandshakeMeta interface {
|
||||
HTTPMaskEarlyHandshakeUserHash() string
|
||||
}
|
||||
|
||||
type earlyHandshakeConn struct {
|
||||
net.Conn
|
||||
userHash string
|
||||
}
|
||||
|
||||
func (c *earlyHandshakeConn) HTTPMaskEarlyHandshakeUserHash() string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
return c.userHash
|
||||
}
|
||||
|
||||
func wrapEarlyHandshakeConn(conn net.Conn, userHash string) net.Conn {
|
||||
if conn == nil {
|
||||
return nil
|
||||
}
|
||||
return &earlyHandshakeConn{Conn: conn, userHash: userHash}
|
||||
}
|
||||
|
||||
func EarlyHandshakeUserHash(conn net.Conn) (string, bool) {
|
||||
if conn == nil {
|
||||
return "", false
|
||||
}
|
||||
v, ok := conn.(earlyHandshakeMeta)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
return v.HTTPMaskEarlyHandshakeUserHash(), true
|
||||
}
|
||||
|
||||
type authorizeResponse struct {
|
||||
token string
|
||||
earlyPayload []byte
|
||||
}
|
||||
|
||||
func isTunnelTokenByte(c byte) bool {
|
||||
return (c >= 'a' && c <= 'z') ||
|
||||
(c >= 'A' && c <= 'Z') ||
|
||||
(c >= '0' && c <= '9') ||
|
||||
c == '-' ||
|
||||
c == '_'
|
||||
}
|
||||
|
||||
func parseAuthorizeResponse(body []byte) (*authorizeResponse, error) {
|
||||
s := strings.TrimSpace(string(body))
|
||||
idx := strings.Index(s, "token=")
|
||||
if idx < 0 {
|
||||
return nil, errors.New("missing token")
|
||||
}
|
||||
s = s[idx+len("token="):]
|
||||
if s == "" {
|
||||
return nil, errors.New("empty token")
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
for i := 0; i < len(s); i++ {
|
||||
c := s[i]
|
||||
if isTunnelTokenByte(c) {
|
||||
b.WriteByte(c)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
token := b.String()
|
||||
if token == "" {
|
||||
return nil, errors.New("empty token")
|
||||
}
|
||||
|
||||
out := &authorizeResponse{token: token}
|
||||
if earlyLine := findAuthorizeField(body, "ed="); earlyLine != "" {
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(earlyLine)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode early authorize payload failed: %w", err)
|
||||
}
|
||||
out.earlyPayload = decoded
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func findAuthorizeField(body []byte, prefix string) string {
|
||||
for _, line := range strings.Split(strings.TrimSpace(string(body)), "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasPrefix(line, prefix) {
|
||||
return strings.TrimSpace(strings.TrimPrefix(line, prefix))
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func setEarlyDataQuery(rawURL string, payload []byte) (string, error) {
|
||||
if len(payload) == 0 {
|
||||
return rawURL, nil
|
||||
}
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set(tunnelEarlyDataQueryKey, base64.RawURLEncoding.EncodeToString(payload))
|
||||
u.RawQuery = q.Encode()
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
func parseEarlyDataQuery(u *url.URL) ([]byte, error) {
|
||||
if u == nil {
|
||||
return nil, nil
|
||||
}
|
||||
val := strings.TrimSpace(u.Query().Get(tunnelEarlyDataQueryKey))
|
||||
if val == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return base64.RawURLEncoding.DecodeString(val)
|
||||
}
|
||||
|
||||
func applyEarlyHandshakeOrUpgrade(raw net.Conn, opts TunnelDialOptions) (net.Conn, error) {
|
||||
out := raw
|
||||
if opts.EarlyHandshake != nil && opts.EarlyHandshake.WrapConn != nil && (opts.EarlyHandshake.Ready == nil || opts.EarlyHandshake.Ready()) {
|
||||
wrapped, err := opts.EarlyHandshake.WrapConn(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if wrapped != nil {
|
||||
out = wrapped
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
if opts.Upgrade != nil {
|
||||
wrapped, err := opts.Upgrade(raw)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if wrapped != nil {
|
||||
out = wrapped
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
@@ -72,6 +72,10 @@ type TunnelDialOptions struct {
|
||||
// AuthKey enables short-term HMAC auth for HTTP tunnel requests (anti-probing).
|
||||
// When set (non-empty), each HTTP request carries an Authorization bearer token derived from AuthKey.
|
||||
AuthKey string
|
||||
// EarlyHandshake folds the protocol handshake into the HTTP/WS setup round trip.
|
||||
// When the server accepts the early payload, DialTunnel returns a conn that is already post-handshake.
|
||||
// When the server does not echo early data, DialTunnel falls back to Upgrade.
|
||||
EarlyHandshake *ClientEarlyHandshake
|
||||
// Upgrade optionally wraps the raw tunnel conn and/or writes a small prelude before DialTunnel returns.
|
||||
// It is called with the raw tunnel conn; if it returns a non-nil conn, that conn is returned by DialTunnel.
|
||||
Upgrade func(raw net.Conn) (net.Conn, error)
|
||||
@@ -225,30 +229,11 @@ func canonicalHeaderHost(urlHost, scheme string) string {
|
||||
}
|
||||
|
||||
func parseTunnelToken(body []byte) (string, error) {
|
||||
s := strings.TrimSpace(string(body))
|
||||
idx := strings.Index(s, "token=")
|
||||
if idx < 0 {
|
||||
return "", errors.New("missing token")
|
||||
resp, err := parseAuthorizeResponse(body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
s = s[idx+len("token="):]
|
||||
if s == "" {
|
||||
return "", errors.New("empty token")
|
||||
}
|
||||
// Token is base64.RawURLEncoding (A-Z a-z 0-9 - _). Strip any trailing bytes (e.g. from CDN compression).
|
||||
var b strings.Builder
|
||||
for i := 0; i < len(s); i++ {
|
||||
c := s[i]
|
||||
if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '-' || c == '_' {
|
||||
b.WriteByte(c)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
token := b.String()
|
||||
if token == "" {
|
||||
return "", errors.New("empty token")
|
||||
}
|
||||
return token, nil
|
||||
return resp.token, nil
|
||||
}
|
||||
|
||||
type httpClientTarget struct {
|
||||
@@ -353,6 +338,13 @@ func dialSessionWithClient(ctx context.Context, client *http.Client, target http
|
||||
|
||||
auth := newTunnelAuth(opts.AuthKey, 0)
|
||||
authorizeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/session")}).String()
|
||||
if opts.EarlyHandshake != nil && len(opts.EarlyHandshake.RequestPayload) > 0 {
|
||||
var err error
|
||||
authorizeURL, err = setEarlyDataQuery(authorizeURL, opts.EarlyHandshake.RequestPayload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var bodyBytes []byte
|
||||
for attempt := 0; ; attempt++ {
|
||||
@@ -410,13 +402,19 @@ func dialSessionWithClient(ctx context.Context, client *http.Client, target http
|
||||
break
|
||||
}
|
||||
|
||||
token, err := parseTunnelToken(bodyBytes)
|
||||
authResp, err := parseAuthorizeResponse(bodyBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s authorize failed: %q", mode, strings.TrimSpace(string(bodyBytes)))
|
||||
}
|
||||
token := authResp.token
|
||||
if token == "" {
|
||||
return nil, fmt.Errorf("%s authorize empty token", mode)
|
||||
}
|
||||
if opts.EarlyHandshake != nil && len(authResp.earlyPayload) > 0 && opts.EarlyHandshake.HandleResponse != nil {
|
||||
if err := opts.EarlyHandshake.HandleResponse(authResp.earlyPayload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
pushURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/api/v1/upload"), RawQuery: "token=" + url.QueryEscape(token)}).String()
|
||||
pullURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/stream"), RawQuery: "token=" + url.QueryEscape(token)}).String()
|
||||
@@ -671,16 +669,10 @@ func dialStreamSplitWithClient(ctx context.Context, client *http.Client, target
|
||||
if c == nil {
|
||||
return nil, fmt.Errorf("failed to build stream split conn")
|
||||
}
|
||||
outConn := net.Conn(c)
|
||||
if opts.Upgrade != nil {
|
||||
upgraded, err := opts.Upgrade(c)
|
||||
if err != nil {
|
||||
_ = c.Close()
|
||||
return nil, err
|
||||
}
|
||||
if upgraded != nil {
|
||||
outConn = upgraded
|
||||
}
|
||||
outConn, err := applyEarlyHandshakeOrUpgrade(c, opts)
|
||||
if err != nil {
|
||||
_ = c.Close()
|
||||
return nil, err
|
||||
}
|
||||
return outConn, nil
|
||||
}
|
||||
@@ -694,16 +686,10 @@ func dialStreamSplit(ctx context.Context, serverAddress string, opts TunnelDialO
|
||||
if c == nil {
|
||||
return nil, fmt.Errorf("failed to build stream split conn")
|
||||
}
|
||||
outConn := net.Conn(c)
|
||||
if opts.Upgrade != nil {
|
||||
upgraded, err := opts.Upgrade(c)
|
||||
if err != nil {
|
||||
_ = c.Close()
|
||||
return nil, err
|
||||
}
|
||||
if upgraded != nil {
|
||||
outConn = upgraded
|
||||
}
|
||||
outConn, err := applyEarlyHandshakeOrUpgrade(c, opts)
|
||||
if err != nil {
|
||||
_ = c.Close()
|
||||
return nil, err
|
||||
}
|
||||
return outConn, nil
|
||||
}
|
||||
@@ -1120,16 +1106,10 @@ func dialPollWithClient(ctx context.Context, client *http.Client, target httpCli
|
||||
if c == nil {
|
||||
return nil, fmt.Errorf("failed to build poll conn")
|
||||
}
|
||||
outConn := net.Conn(c)
|
||||
if opts.Upgrade != nil {
|
||||
upgraded, err := opts.Upgrade(c)
|
||||
if err != nil {
|
||||
_ = c.Close()
|
||||
return nil, err
|
||||
}
|
||||
if upgraded != nil {
|
||||
outConn = upgraded
|
||||
}
|
||||
outConn, err := applyEarlyHandshakeOrUpgrade(c, opts)
|
||||
if err != nil {
|
||||
_ = c.Close()
|
||||
return nil, err
|
||||
}
|
||||
return outConn, nil
|
||||
}
|
||||
@@ -1143,16 +1123,10 @@ func dialPoll(ctx context.Context, serverAddress string, opts TunnelDialOptions)
|
||||
if c == nil {
|
||||
return nil, fmt.Errorf("failed to build poll conn")
|
||||
}
|
||||
outConn := net.Conn(c)
|
||||
if opts.Upgrade != nil {
|
||||
upgraded, err := opts.Upgrade(c)
|
||||
if err != nil {
|
||||
_ = c.Close()
|
||||
return nil, err
|
||||
}
|
||||
if upgraded != nil {
|
||||
outConn = upgraded
|
||||
}
|
||||
outConn, err := applyEarlyHandshakeOrUpgrade(c, opts)
|
||||
if err != nil {
|
||||
_ = c.Close()
|
||||
return nil, err
|
||||
}
|
||||
return outConn, nil
|
||||
}
|
||||
@@ -1528,6 +1502,8 @@ type TunnelServerOptions struct {
|
||||
PullReadTimeout time.Duration
|
||||
// SessionTTL is a best-effort TTL to prevent leaked sessions. 0 uses a conservative default.
|
||||
SessionTTL time.Duration
|
||||
// EarlyHandshake optionally folds the protocol handshake into the initial HTTP/WS round trip.
|
||||
EarlyHandshake *TunnelServerEarlyHandshake
|
||||
}
|
||||
|
||||
type TunnelServer struct {
|
||||
@@ -1538,6 +1514,7 @@ type TunnelServer struct {
|
||||
|
||||
pullReadTimeout time.Duration
|
||||
sessionTTL time.Duration
|
||||
earlyHandshake *TunnelServerEarlyHandshake
|
||||
|
||||
mu sync.Mutex
|
||||
sessions map[string]*tunnelSession
|
||||
@@ -1570,6 +1547,7 @@ func NewTunnelServer(opts TunnelServerOptions) *TunnelServer {
|
||||
passThroughOnReject: opts.PassThroughOnReject,
|
||||
pullReadTimeout: timeout,
|
||||
sessionTTL: ttl,
|
||||
earlyHandshake: opts.EarlyHandshake,
|
||||
sessions: make(map[string]*tunnelSession),
|
||||
}
|
||||
}
|
||||
@@ -1925,9 +1903,12 @@ func (s *TunnelServer) handleStream(rawConn net.Conn, req *httpRequestHeader, he
|
||||
|
||||
switch strings.ToUpper(req.method) {
|
||||
case http.MethodGet:
|
||||
// Stream split-session: GET /session (no token) => token + start tunnel on a server-side pipe.
|
||||
if token == "" && path == "/session" {
|
||||
return s.sessionAuthorize(rawConn)
|
||||
earlyPayload, err := parseEarlyDataQuery(u)
|
||||
if err != nil {
|
||||
return rejectOrReply(http.StatusBadRequest, "bad request")
|
||||
}
|
||||
return s.sessionAuthorize(rawConn, earlyPayload)
|
||||
}
|
||||
// Stream split-session: GET /stream?token=... => downlink poll.
|
||||
if token != "" && path == "/stream" {
|
||||
@@ -2045,10 +2026,18 @@ func writeSimpleHTTPResponse(w io.Writer, code int, body string) error {
|
||||
|
||||
func writeTokenHTTPResponse(w io.Writer, token string) error {
|
||||
token = strings.TrimRight(token, "\r\n")
|
||||
// Use application/octet-stream to avoid CDN auto-compression (e.g. brotli) breaking clients that expect a plain token string.
|
||||
return writeTokenHTTPResponseWithEarlyData(w, token, nil)
|
||||
}
|
||||
|
||||
func writeTokenHTTPResponseWithEarlyData(w io.Writer, token string, earlyPayload []byte) error {
|
||||
token = strings.TrimRight(token, "\r\n")
|
||||
body := "token=" + token
|
||||
if len(earlyPayload) > 0 {
|
||||
body += "\ned=" + base64.RawURLEncoding.EncodeToString(earlyPayload)
|
||||
}
|
||||
_, err := io.WriteString(w,
|
||||
fmt.Sprintf("HTTP/1.1 200 OK\r\nContent-Type: application/octet-stream\r\nCache-Control: no-store\r\nPragma: no-cache\r\nContent-Length: %d\r\nConnection: close\r\n\r\ntoken=%s",
|
||||
len("token=")+len(token), token))
|
||||
fmt.Sprintf("HTTP/1.1 200 OK\r\nContent-Type: application/octet-stream\r\nCache-Control: no-store\r\nPragma: no-cache\r\nContent-Length: %d\r\nConnection: close\r\n\r\n%s",
|
||||
len(body), body))
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2088,7 +2077,11 @@ func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, head
|
||||
switch strings.ToUpper(req.method) {
|
||||
case http.MethodGet:
|
||||
if token == "" && path == "/session" {
|
||||
return s.sessionAuthorize(rawConn)
|
||||
earlyPayload, err := parseEarlyDataQuery(u)
|
||||
if err != nil {
|
||||
return rejectOrReply(http.StatusBadRequest, "bad request")
|
||||
}
|
||||
return s.sessionAuthorize(rawConn, earlyPayload)
|
||||
}
|
||||
if token != "" && path == "/stream" {
|
||||
if s.passThroughOnReject && !s.sessionHas(token) {
|
||||
@@ -2128,7 +2121,7 @@ func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, head
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TunnelServer) sessionAuthorize(rawConn net.Conn) (HandleResult, net.Conn, error) {
|
||||
func (s *TunnelServer) sessionAuthorize(rawConn net.Conn, earlyPayload []byte) (HandleResult, net.Conn, error) {
|
||||
token, err := newSessionToken()
|
||||
if err != nil {
|
||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusInternalServerError, "internal error")
|
||||
@@ -2137,6 +2130,37 @@ func (s *TunnelServer) sessionAuthorize(rawConn net.Conn) (HandleResult, net.Con
|
||||
}
|
||||
|
||||
c1, c2 := newHalfPipe()
|
||||
outConn := net.Conn(c1)
|
||||
var responsePayload []byte
|
||||
var userHash string
|
||||
if len(earlyPayload) > 0 && s.earlyHandshake != nil && s.earlyHandshake.Prepare != nil {
|
||||
prepared, err := s.earlyHandshake.Prepare(earlyPayload)
|
||||
if err != nil {
|
||||
_ = c1.Close()
|
||||
_ = c2.Close()
|
||||
if s.passThroughOnReject {
|
||||
return HandlePassThrough, newRejectedPreBufferedConn(rawConn, nil), nil
|
||||
}
|
||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
|
||||
_ = rawConn.Close()
|
||||
return HandleDone, nil, nil
|
||||
}
|
||||
responsePayload = prepared.ResponsePayload
|
||||
userHash = prepared.UserHash
|
||||
if prepared.WrapConn != nil {
|
||||
wrapped, err := prepared.WrapConn(c1)
|
||||
if err != nil {
|
||||
_ = c1.Close()
|
||||
_ = c2.Close()
|
||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusInternalServerError, "internal error")
|
||||
_ = rawConn.Close()
|
||||
return HandleDone, nil, nil
|
||||
}
|
||||
if wrapped != nil {
|
||||
outConn = wrapEarlyHandshakeConn(wrapped, userHash)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.sessions[token] = &tunnelSession{conn: c2, lastActive: time.Now()}
|
||||
@@ -2144,9 +2168,9 @@ func (s *TunnelServer) sessionAuthorize(rawConn net.Conn) (HandleResult, net.Con
|
||||
|
||||
go s.reapLater(token)
|
||||
|
||||
_ = writeTokenHTTPResponse(rawConn, token)
|
||||
_ = writeTokenHTTPResponseWithEarlyData(rawConn, token, responsePayload)
|
||||
_ = rawConn.Close()
|
||||
return HandleStartTunnel, c1, nil
|
||||
return HandleStartTunnel, outConn, nil
|
||||
}
|
||||
|
||||
func newSessionToken() (string, error) {
|
||||
|
||||
@@ -2,6 +2,7 @@ package httpmask
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
mrand "math/rand"
|
||||
@@ -115,6 +116,16 @@ func dialWS(ctx context.Context, serverAddress string, opts TunnelDialOptions) (
|
||||
Host: urlHost,
|
||||
Path: joinPathRoot(opts.PathRoot, "/ws"),
|
||||
}
|
||||
if opts.EarlyHandshake != nil && len(opts.EarlyHandshake.RequestPayload) > 0 {
|
||||
rawURL, err := setEarlyDataQuery(u.String(), opts.EarlyHandshake.RequestPayload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u, err = url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
header := make(stdhttp.Header)
|
||||
applyWSHeaders(header, headerHost)
|
||||
@@ -132,6 +143,16 @@ func dialWS(ctx context.Context, serverAddress string, opts TunnelDialOptions) (
|
||||
d := ws.Dialer{
|
||||
Host: headerHost,
|
||||
Header: ws.HandshakeHeaderHTTP(header),
|
||||
OnHeader: func(key, value []byte) error {
|
||||
if !strings.EqualFold(string(key), tunnelEarlyDataHeader) || opts.EarlyHandshake == nil || opts.EarlyHandshake.HandleResponse == nil {
|
||||
return nil
|
||||
}
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(strings.TrimSpace(string(value)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return opts.EarlyHandshake.HandleResponse(decoded)
|
||||
},
|
||||
NetDial: func(dialCtx context.Context, network, addr string) (net.Conn, error) {
|
||||
if addr == urlHost {
|
||||
addr = dialAddr
|
||||
@@ -161,16 +182,10 @@ func dialWS(ctx context.Context, serverAddress string, opts TunnelDialOptions) (
|
||||
}
|
||||
|
||||
wsConn := newWSStreamConn(conn, ws.StateClientSide)
|
||||
if opts.Upgrade == nil {
|
||||
return wsConn, nil
|
||||
}
|
||||
upgraded, err := opts.Upgrade(wsConn)
|
||||
upgraded, err := applyEarlyHandshakeOrUpgrade(wsConn, opts)
|
||||
if err != nil {
|
||||
_ = wsConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
if upgraded != nil {
|
||||
return upgraded, nil
|
||||
}
|
||||
return wsConn, nil
|
||||
return upgraded, nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package httpmask
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -63,15 +64,46 @@ func (s *TunnelServer) handleWS(rawConn net.Conn, req *httpRequestHeader, header
|
||||
return rejectOrReply(http.StatusNotFound, "not found")
|
||||
}
|
||||
|
||||
earlyPayload, err := parseEarlyDataQuery(u)
|
||||
if err != nil {
|
||||
return rejectOrReply(http.StatusBadRequest, "bad request")
|
||||
}
|
||||
var prepared *PreparedServerEarlyHandshake
|
||||
if len(earlyPayload) > 0 && s.earlyHandshake != nil && s.earlyHandshake.Prepare != nil {
|
||||
prepared, err = s.earlyHandshake.Prepare(earlyPayload)
|
||||
if err != nil {
|
||||
return rejectOrReply(http.StatusNotFound, "not found")
|
||||
}
|
||||
}
|
||||
|
||||
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
|
||||
prefix = append(prefix, headerBytes...)
|
||||
prefix = append(prefix, buffered...)
|
||||
wsConnRaw := newPreBufferedConn(rawConn, prefix)
|
||||
|
||||
if _, err := ws.Upgrade(wsConnRaw); err != nil {
|
||||
upgrader := ws.Upgrader{}
|
||||
if prepared != nil && len(prepared.ResponsePayload) > 0 {
|
||||
upgrader.OnBeforeUpgrade = func() (ws.HandshakeHeader, error) {
|
||||
h := http.Header{}
|
||||
h.Set(tunnelEarlyDataHeader, base64.RawURLEncoding.EncodeToString(prepared.ResponsePayload))
|
||||
return ws.HandshakeHeaderHTTP(h), nil
|
||||
}
|
||||
}
|
||||
if _, err := upgrader.Upgrade(wsConnRaw); err != nil {
|
||||
_ = rawConn.Close()
|
||||
return HandleDone, nil, nil
|
||||
}
|
||||
|
||||
return HandleStartTunnel, newWSStreamConn(wsConnRaw, ws.StateServerSide), nil
|
||||
outConn := net.Conn(newWSStreamConn(wsConnRaw, ws.StateServerSide))
|
||||
if prepared != nil && prepared.WrapConn != nil {
|
||||
wrapped, err := prepared.WrapConn(outConn)
|
||||
if err != nil {
|
||||
_ = outConn.Close()
|
||||
return HandleDone, nil, nil
|
||||
}
|
||||
if wrapped != nil {
|
||||
outConn = wrapEarlyHandshakeConn(wrapped, prepared.UserHash)
|
||||
}
|
||||
}
|
||||
return HandleStartTunnel, outConn, nil
|
||||
}
|
||||
|
||||
@@ -11,35 +11,35 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// 每次从 RNG 获取批量随机数的缓存大小,减少 RNG 函数调用开销
|
||||
RngBatchSize = 128
|
||||
|
||||
packedProtectedPrefixBytes = 14
|
||||
)
|
||||
|
||||
// 1. 使用 12字节->16组 的块处理优化 Write (减少循环开销)
|
||||
// 2. 使用整数阈值随机概率判断 Padding,与纯 Sudoku 保持流量特征一致
|
||||
// 3. Read 使用 copy 移动避免底层数组泄漏
|
||||
// PackedConn encodes traffic with the packed Sudoku layout while preserving
|
||||
// the same padding model as the regular connection.
|
||||
type PackedConn struct {
|
||||
net.Conn
|
||||
table *Table
|
||||
reader *bufio.Reader
|
||||
|
||||
// 读缓冲
|
||||
// Read-side buffers.
|
||||
rawBuf []byte
|
||||
pendingData []byte // 解码后尚未被 Read 取走的字节
|
||||
pendingData []byte
|
||||
|
||||
// 写缓冲与状态
|
||||
// Write-side state.
|
||||
writeMu sync.Mutex
|
||||
writeBuf []byte
|
||||
bitBuf uint64 // 暂存的位数据
|
||||
bitCount int // 暂存的位数
|
||||
bitBuf uint64
|
||||
bitCount int
|
||||
|
||||
// 读状态
|
||||
// Read-side bit accumulator.
|
||||
readBitBuf uint64
|
||||
readBits int
|
||||
|
||||
// 随机数与填充控制 - 使用整数阈值随机,与 Conn 一致
|
||||
// Padding selection matches Conn's threshold-based model.
|
||||
rng *rand.Rand
|
||||
paddingThreshold uint64 // 与 Conn 保持一致的随机概率模型
|
||||
paddingThreshold uint64
|
||||
padMarker byte
|
||||
padPool []byte
|
||||
}
|
||||
@@ -95,7 +95,6 @@ func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn {
|
||||
return pc
|
||||
}
|
||||
|
||||
// maybeAddPadding 内联辅助:根据概率阈值插入 padding
|
||||
func (pc *PackedConn) maybeAddPadding(out []byte) []byte {
|
||||
if shouldPad(pc.rng, pc.paddingThreshold) {
|
||||
out = append(out, pc.getPaddingByte())
|
||||
@@ -103,7 +102,73 @@ func (pc *PackedConn) maybeAddPadding(out []byte) []byte {
|
||||
return out
|
||||
}
|
||||
|
||||
// Write 极致优化版 - 批量处理 12 字节
|
||||
func (pc *PackedConn) appendGroup(out []byte, group byte) []byte {
|
||||
out = pc.maybeAddPadding(out)
|
||||
return append(out, pc.encodeGroup(group))
|
||||
}
|
||||
|
||||
func (pc *PackedConn) appendForcedPadding(out []byte) []byte {
|
||||
return append(out, pc.getPaddingByte())
|
||||
}
|
||||
|
||||
func (pc *PackedConn) nextProtectedPrefixGap() int {
|
||||
return 1 + pc.rng.Intn(2)
|
||||
}
|
||||
|
||||
func (pc *PackedConn) writeProtectedPrefix(out []byte, p []byte) ([]byte, int) {
|
||||
if len(p) == 0 {
|
||||
return out, 0
|
||||
}
|
||||
|
||||
limit := len(p)
|
||||
if limit > packedProtectedPrefixBytes {
|
||||
limit = packedProtectedPrefixBytes
|
||||
}
|
||||
|
||||
for padCount := 0; padCount < 1+pc.rng.Intn(2); padCount++ {
|
||||
out = pc.appendForcedPadding(out)
|
||||
}
|
||||
|
||||
gap := pc.nextProtectedPrefixGap()
|
||||
effective := 0
|
||||
for i := 0; i < limit; i++ {
|
||||
pc.bitBuf = (pc.bitBuf << 8) | uint64(p[i])
|
||||
pc.bitCount += 8
|
||||
for pc.bitCount >= 6 {
|
||||
pc.bitCount -= 6
|
||||
group := byte(pc.bitBuf >> pc.bitCount)
|
||||
if pc.bitCount == 0 {
|
||||
pc.bitBuf = 0
|
||||
} else {
|
||||
pc.bitBuf &= (1 << pc.bitCount) - 1
|
||||
}
|
||||
out = pc.appendGroup(out, group&0x3F)
|
||||
}
|
||||
|
||||
effective++
|
||||
if effective >= gap {
|
||||
out = pc.appendForcedPadding(out)
|
||||
effective = 0
|
||||
gap = pc.nextProtectedPrefixGap()
|
||||
}
|
||||
}
|
||||
|
||||
return out, limit
|
||||
}
|
||||
|
||||
func (pc *PackedConn) drainPendingData(dst []byte) int {
|
||||
n := copy(dst, pc.pendingData)
|
||||
if n == len(pc.pendingData) {
|
||||
pc.pendingData = pc.pendingData[:0]
|
||||
return n
|
||||
}
|
||||
|
||||
remaining := len(pc.pendingData) - n
|
||||
copy(pc.pendingData, pc.pendingData[n:])
|
||||
pc.pendingData = pc.pendingData[:remaining]
|
||||
return n
|
||||
}
|
||||
|
||||
func (pc *PackedConn) Write(p []byte) (int, error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
@@ -112,20 +177,19 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
|
||||
pc.writeMu.Lock()
|
||||
defer pc.writeMu.Unlock()
|
||||
|
||||
// 1. 预分配内存,避免 append 导致的多次扩容
|
||||
// 预估:原数据 * 1.5 (4/3 + padding 余量)
|
||||
needed := len(p)*3/2 + 32
|
||||
if cap(pc.writeBuf) < needed {
|
||||
pc.writeBuf = make([]byte, 0, needed)
|
||||
}
|
||||
out := pc.writeBuf[:0]
|
||||
|
||||
i := 0
|
||||
var prefixN int
|
||||
out, prefixN = pc.writeProtectedPrefix(out, p)
|
||||
|
||||
i := prefixN
|
||||
n := len(p)
|
||||
|
||||
// 2. 头部对齐处理 (Slow Path)
|
||||
for pc.bitCount > 0 && i < n {
|
||||
out = pc.maybeAddPadding(out)
|
||||
b := p[i]
|
||||
i++
|
||||
pc.bitBuf = (pc.bitBuf << 8) | uint64(b)
|
||||
@@ -138,14 +202,11 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
|
||||
} else {
|
||||
pc.bitBuf &= (1 << pc.bitCount) - 1
|
||||
}
|
||||
out = pc.maybeAddPadding(out)
|
||||
out = append(out, pc.encodeGroup(group&0x3F))
|
||||
out = pc.appendGroup(out, group&0x3F)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 极速批量处理 (Fast Path) - 每次处理 12 字节 → 生成 16 个编码组
|
||||
for i+11 < n {
|
||||
// 处理 4 组,每组 3 字节
|
||||
for batch := 0; batch < 4; batch++ {
|
||||
b1, b2, b3 := p[i], p[i+1], p[i+2]
|
||||
i += 3
|
||||
@@ -155,19 +216,13 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
|
||||
g3 := ((b2 & 0x0F) << 2) | ((b3 >> 6) & 0x03)
|
||||
g4 := b3 & 0x3F
|
||||
|
||||
// 每个组之前都有概率插入 padding
|
||||
out = pc.maybeAddPadding(out)
|
||||
out = append(out, pc.encodeGroup(g1))
|
||||
out = pc.maybeAddPadding(out)
|
||||
out = append(out, pc.encodeGroup(g2))
|
||||
out = pc.maybeAddPadding(out)
|
||||
out = append(out, pc.encodeGroup(g3))
|
||||
out = pc.maybeAddPadding(out)
|
||||
out = append(out, pc.encodeGroup(g4))
|
||||
out = pc.appendGroup(out, g1)
|
||||
out = pc.appendGroup(out, g2)
|
||||
out = pc.appendGroup(out, g3)
|
||||
out = pc.appendGroup(out, g4)
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 处理剩余的 3 字节块
|
||||
for i+2 < n {
|
||||
b1, b2, b3 := p[i], p[i+1], p[i+2]
|
||||
i += 3
|
||||
@@ -177,17 +232,12 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
|
||||
g3 := ((b2 & 0x0F) << 2) | ((b3 >> 6) & 0x03)
|
||||
g4 := b3 & 0x3F
|
||||
|
||||
out = pc.maybeAddPadding(out)
|
||||
out = append(out, pc.encodeGroup(g1))
|
||||
out = pc.maybeAddPadding(out)
|
||||
out = append(out, pc.encodeGroup(g2))
|
||||
out = pc.maybeAddPadding(out)
|
||||
out = append(out, pc.encodeGroup(g3))
|
||||
out = pc.maybeAddPadding(out)
|
||||
out = append(out, pc.encodeGroup(g4))
|
||||
out = pc.appendGroup(out, g1)
|
||||
out = pc.appendGroup(out, g2)
|
||||
out = pc.appendGroup(out, g3)
|
||||
out = pc.appendGroup(out, g4)
|
||||
}
|
||||
|
||||
// 5. 尾部处理 (Tail Path) - 处理剩余的 1 或 2 个字节
|
||||
for ; i < n; i++ {
|
||||
b := p[i]
|
||||
pc.bitBuf = (pc.bitBuf << 8) | uint64(b)
|
||||
@@ -200,35 +250,28 @@ func (pc *PackedConn) Write(p []byte) (int, error) {
|
||||
} else {
|
||||
pc.bitBuf &= (1 << pc.bitCount) - 1
|
||||
}
|
||||
out = pc.maybeAddPadding(out)
|
||||
out = append(out, pc.encodeGroup(group&0x3F))
|
||||
out = pc.appendGroup(out, group&0x3F)
|
||||
}
|
||||
}
|
||||
|
||||
// 6. 处理残留位
|
||||
if pc.bitCount > 0 {
|
||||
out = pc.maybeAddPadding(out)
|
||||
group := byte(pc.bitBuf << (6 - pc.bitCount))
|
||||
pc.bitBuf = 0
|
||||
pc.bitCount = 0
|
||||
out = append(out, pc.encodeGroup(group&0x3F))
|
||||
out = pc.appendGroup(out, group&0x3F)
|
||||
out = append(out, pc.padMarker)
|
||||
}
|
||||
|
||||
// 尾部可能添加 padding
|
||||
out = pc.maybeAddPadding(out)
|
||||
|
||||
// 发送数据
|
||||
if len(out) > 0 {
|
||||
_, err := pc.Conn.Write(out)
|
||||
pc.writeBuf = out[:0]
|
||||
return len(p), err
|
||||
return len(p), writeFull(pc.Conn, out)
|
||||
}
|
||||
pc.writeBuf = out[:0]
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// Flush 处理最后不足 6 bit 的情况
|
||||
func (pc *PackedConn) Flush() error {
|
||||
pc.writeMu.Lock()
|
||||
defer pc.writeMu.Unlock()
|
||||
@@ -243,38 +286,34 @@ func (pc *PackedConn) Flush() error {
|
||||
out = append(out, pc.padMarker)
|
||||
}
|
||||
|
||||
// 尾部随机添加 padding
|
||||
out = pc.maybeAddPadding(out)
|
||||
|
||||
if len(out) > 0 {
|
||||
_, err := pc.Conn.Write(out)
|
||||
pc.writeBuf = out[:0]
|
||||
return err
|
||||
return writeFull(pc.Conn, out)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read 优化版:减少切片操作,避免内存泄漏
|
||||
func (pc *PackedConn) Read(p []byte) (int, error) {
|
||||
// 1. 优先返回待处理区的数据
|
||||
if len(pc.pendingData) > 0 {
|
||||
n := copy(p, pc.pendingData)
|
||||
if n == len(pc.pendingData) {
|
||||
pc.pendingData = pc.pendingData[:0]
|
||||
} else {
|
||||
// 优化:移动剩余数据到数组头部,避免切片指向中间导致内存泄漏
|
||||
remaining := len(pc.pendingData) - n
|
||||
copy(pc.pendingData, pc.pendingData[n:])
|
||||
pc.pendingData = pc.pendingData[:remaining]
|
||||
func writeFull(w io.Writer, b []byte) error {
|
||||
for len(b) > 0 {
|
||||
n, err := w.Write(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return n, nil
|
||||
b = b[n:]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pc *PackedConn) Read(p []byte) (int, error) {
|
||||
if len(pc.pendingData) > 0 {
|
||||
return pc.drainPendingData(p), nil
|
||||
}
|
||||
|
||||
// 2. 循环读取直到解出数据或出错
|
||||
for {
|
||||
nr, rErr := pc.reader.Read(pc.rawBuf)
|
||||
if nr > 0 {
|
||||
// 缓存频繁访问的变量
|
||||
rBuf := pc.readBitBuf
|
||||
rBits := pc.readBits
|
||||
padMarker := pc.padMarker
|
||||
@@ -324,24 +363,13 @@ func (pc *PackedConn) Read(p []byte) (int, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 返回解码后的数据 - 优化:避免底层数组泄漏
|
||||
n := copy(p, pc.pendingData)
|
||||
if n == len(pc.pendingData) {
|
||||
pc.pendingData = pc.pendingData[:0]
|
||||
} else {
|
||||
remaining := len(pc.pendingData) - n
|
||||
copy(pc.pendingData, pc.pendingData[n:])
|
||||
pc.pendingData = pc.pendingData[:remaining]
|
||||
}
|
||||
return n, nil
|
||||
return pc.drainPendingData(p), nil
|
||||
}
|
||||
|
||||
// getPaddingByte 从 Pool 中随机取 Padding 字节
|
||||
func (pc *PackedConn) getPaddingByte() byte {
|
||||
return pc.padPool[pc.rng.Intn(len(pc.padPool))]
|
||||
}
|
||||
|
||||
// encodeGroup 编码 6-bit 组
|
||||
func (pc *PackedConn) encodeGroup(group byte) byte {
|
||||
return pc.table.layout.encodeGroup(group)
|
||||
}
|
||||
|
||||
91
transport/sudoku/obfs/sudoku/packed_prefix_test.go
Normal file
91
transport/sudoku/obfs/sudoku/packed_prefix_test.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package sudoku
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type mockConn struct {
|
||||
readBuf []byte
|
||||
writeBuf []byte
|
||||
}
|
||||
|
||||
func (c *mockConn) Read(p []byte) (int, error) {
|
||||
if len(c.readBuf) == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n := copy(p, c.readBuf)
|
||||
c.readBuf = c.readBuf[n:]
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *mockConn) Write(p []byte) (int, error) {
|
||||
c.writeBuf = append(c.writeBuf, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (c *mockConn) Close() error { return nil }
|
||||
func (c *mockConn) LocalAddr() net.Addr { return nil }
|
||||
func (c *mockConn) RemoteAddr() net.Addr { return nil }
|
||||
func (c *mockConn) SetDeadline(time.Time) error { return nil }
|
||||
func (c *mockConn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *mockConn) SetWriteDeadline(time.Time) error { return nil }
|
||||
|
||||
func TestPackedConn_ProtectedPrefixPadding(t *testing.T) {
|
||||
table := NewTable("packed-prefix-seed", "prefer_ascii")
|
||||
mock := &mockConn{}
|
||||
writer := NewPackedConn(mock, table, 0, 0)
|
||||
writer.rng = rand.New(rand.NewSource(1))
|
||||
|
||||
payload := bytes.Repeat([]byte{0}, 32)
|
||||
if _, err := writer.Write(payload); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
wire := append([]byte(nil), mock.writeBuf...)
|
||||
if len(wire) < 20 {
|
||||
t.Fatalf("wire too short: %d", len(wire))
|
||||
}
|
||||
|
||||
firstHint := -1
|
||||
nonHintCount := 0
|
||||
maxHintRun := 0
|
||||
currentHintRun := 0
|
||||
for i, b := range wire[:20] {
|
||||
if table.layout.isHint(b) {
|
||||
if firstHint == -1 {
|
||||
firstHint = i
|
||||
}
|
||||
currentHintRun++
|
||||
if currentHintRun > maxHintRun {
|
||||
maxHintRun = currentHintRun
|
||||
}
|
||||
continue
|
||||
}
|
||||
nonHintCount++
|
||||
currentHintRun = 0
|
||||
}
|
||||
|
||||
if firstHint < 1 || firstHint > 2 {
|
||||
t.Fatalf("expected 1-2 leading padding bytes, first hint index=%d", firstHint)
|
||||
}
|
||||
if nonHintCount < 6 {
|
||||
t.Fatalf("expected dense prefix padding, got only %d non-hint bytes in first 20", nonHintCount)
|
||||
}
|
||||
if maxHintRun > 3 {
|
||||
t.Fatalf("prefix still exposes long hint run: %d", maxHintRun)
|
||||
}
|
||||
|
||||
reader := NewPackedConn(&mockConn{readBuf: wire}, table, 0, 0)
|
||||
decoded := make([]byte, len(payload))
|
||||
if _, err := io.ReadFull(reader, decoded); err != nil {
|
||||
t.Fatalf("read back: %v", err)
|
||||
}
|
||||
if !bytes.Equal(decoded, payload) {
|
||||
t.Fatalf("roundtrip mismatch")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user