diff --git a/transport/sudoku/features_test.go b/transport/sudoku/features_test.go index 39ec37c7..68baab45 100644 --- a/transport/sudoku/features_test.go +++ b/transport/sudoku/features_test.go @@ -5,41 +5,10 @@ import ( "io" "net" "testing" - "time" sudokuobfs "github.com/metacubex/mihomo/transport/sudoku/obfs/sudoku" ) -type discardConn struct{} - -func (discardConn) Read([]byte) (int, error) { return 0, io.EOF } -func (discardConn) Write(p []byte) (int, error) { return len(p), nil } -func (discardConn) Close() error { return nil } -func (discardConn) LocalAddr() net.Addr { return nil } -func (discardConn) RemoteAddr() net.Addr { return nil } -func (discardConn) SetDeadline(time.Time) error { return nil } -func (discardConn) SetReadDeadline(time.Time) error { return nil } -func (discardConn) SetWriteDeadline(time.Time) error { return nil } - -func TestSudokuObfsWriter_ReducesWriteAllocs(t *testing.T) { - table := sudokuobfs.NewTable("alloc-seed", "prefer_ascii") - w := newSudokuObfsWriter(discardConn{}, table, 0, 0) - - payload := bytes.Repeat([]byte{0x42}, 2048) - if _, err := w.Write(payload); err != nil { - t.Fatalf("warmup write: %v", err) - } - - allocs := testing.AllocsPerRun(100, func() { - if _, err := w.Write(payload); err != nil { - t.Fatalf("write: %v", err) - } - }) - if allocs != 0 { - t.Fatalf("expected 0 allocs/run, got %.2f", allocs) - } -} - func TestCustomTablesRotation_ProbedByServer(t *testing.T) { key := "rotate-test-key" target := "8.8.8.8:53" diff --git a/transport/sudoku/handshake.go b/transport/sudoku/handshake.go index 2b6c73dd..1cf119c9 100644 --- a/transport/sudoku/handshake.go +++ b/transport/sudoku/handshake.go @@ -68,6 +68,15 @@ type directionalConn struct { closers []func() error } +func newDirectionalConn(base net.Conn, reader io.Reader, writer io.Writer, closers ...func() error) net.Conn { + return &directionalConn{ + Conn: base, + reader: reader, + writer: writer, + closers: closers, + } +} + func (c *directionalConn) Read(p []byte) (int, error) { return c.reader.Read(p) } @@ -112,40 +121,21 @@ func downlinkMode(cfg *ProtocolConfig) byte { } func buildClientObfsConn(raw net.Conn, cfg *ProtocolConfig, table *sudoku.Table) net.Conn { - baseReader := sudoku.NewConn(raw, table, cfg.PaddingMin, cfg.PaddingMax, false) - baseWriter := newSudokuObfsWriter(raw, table, cfg.PaddingMin, cfg.PaddingMax) + baseSudoku := sudoku.NewConn(raw, table, cfg.PaddingMin, cfg.PaddingMax, false) if cfg.EnablePureDownlink { - return &directionalConn{ - Conn: raw, - reader: baseReader, - writer: baseWriter, - } + return baseSudoku } packed := sudoku.NewPackedConn(raw, table, cfg.PaddingMin, cfg.PaddingMax) - return &directionalConn{ - Conn: raw, - reader: packed, - writer: baseWriter, - } + return newDirectionalConn(raw, packed, baseSudoku) } func buildServerObfsConn(raw net.Conn, cfg *ProtocolConfig, table *sudoku.Table, record bool) (*sudoku.Conn, net.Conn) { - uplink := sudoku.NewConn(raw, table, cfg.PaddingMin, cfg.PaddingMax, record) + uplinkSudoku := sudoku.NewConn(raw, table, cfg.PaddingMin, cfg.PaddingMax, record) if cfg.EnablePureDownlink { - downlink := &directionalConn{ - Conn: raw, - reader: uplink, - writer: newSudokuObfsWriter(raw, table, cfg.PaddingMin, cfg.PaddingMax), - } - return uplink, downlink + return uplinkSudoku, uplinkSudoku } packed := sudoku.NewPackedConn(raw, table, cfg.PaddingMin, cfg.PaddingMax) - return uplink, &directionalConn{ - Conn: raw, - reader: uplink, - writer: packed, - closers: []func() error{packed.Flush}, - } + return uplinkSudoku, newDirectionalConn(raw, uplinkSudoku, packed, packed.Flush) } func buildHandshakePayload(key string) [16]byte { diff --git a/transport/sudoku/obfs_writer.go b/transport/sudoku/obfs_writer.go deleted file mode 100644 index 3dc94b4e..00000000 --- a/transport/sudoku/obfs_writer.go +++ /dev/null @@ -1,113 +0,0 @@ -package sudoku - -import ( - crypto_rand "crypto/rand" - "encoding/binary" - "math/rand" - "net" - - "github.com/metacubex/mihomo/transport/sudoku/obfs/sudoku" -) - -// perm4 matches github.com/saba-futai/sudoku/pkg/obfs/sudoku perm4. -var perm4 = [24][4]byte{ - {0, 1, 2, 3}, - {0, 1, 3, 2}, - {0, 2, 1, 3}, - {0, 2, 3, 1}, - {0, 3, 1, 2}, - {0, 3, 2, 1}, - {1, 0, 2, 3}, - {1, 0, 3, 2}, - {1, 2, 0, 3}, - {1, 2, 3, 0}, - {1, 3, 0, 2}, - {1, 3, 2, 0}, - {2, 0, 1, 3}, - {2, 0, 3, 1}, - {2, 1, 0, 3}, - {2, 1, 3, 0}, - {2, 3, 0, 1}, - {2, 3, 1, 0}, - {3, 0, 1, 2}, - {3, 0, 2, 1}, - {3, 1, 0, 2}, - {3, 1, 2, 0}, - {3, 2, 0, 1}, - {3, 2, 1, 0}, -} - -type sudokuObfsWriter struct { - conn net.Conn - table *sudoku.Table - rng *rand.Rand - paddingRate float32 - - outBuf []byte - pads []byte - padLen int -} - -func newSudokuObfsWriter(conn net.Conn, table *sudoku.Table, pMin, pMax int) *sudokuObfsWriter { - var seedBytes [8]byte - if _, err := crypto_rand.Read(seedBytes[:]); err != nil { - binary.BigEndian.PutUint64(seedBytes[:], uint64(rand.Int63())) - } - seed := int64(binary.BigEndian.Uint64(seedBytes[:])) - localRng := rand.New(rand.NewSource(seed)) - - min := float32(pMin) / 100.0 - span := float32(pMax-pMin) / 100.0 - rate := min + localRng.Float32()*span - - w := &sudokuObfsWriter{ - conn: conn, - table: table, - rng: localRng, - paddingRate: rate, - } - w.pads = table.PaddingPool - w.padLen = len(w.pads) - return w -} - -func (w *sudokuObfsWriter) Write(p []byte) (int, error) { - if len(p) == 0 { - return 0, nil - } - - // Worst-case: 4 hints + up to 6 paddings per input byte. - needed := len(p)*10 + 1 - if cap(w.outBuf) < needed { - w.outBuf = make([]byte, 0, needed) - } - out := w.outBuf[:0] - - pads := w.pads - padLen := w.padLen - - for _, b := range p { - if padLen > 0 && w.rng.Float32() < w.paddingRate { - out = append(out, pads[w.rng.Intn(padLen)]) - } - - puzzles := w.table.EncodeTable[b] - puzzle := puzzles[w.rng.Intn(len(puzzles))] - - perm := perm4[w.rng.Intn(len(perm4))] - for _, idx := range perm { - if padLen > 0 && w.rng.Float32() < w.paddingRate { - out = append(out, pads[w.rng.Intn(padLen)]) - } - out = append(out, puzzle[idx]) - } - } - - if padLen > 0 && w.rng.Float32() < w.paddingRate { - out = append(out, pads[w.rng.Intn(padLen)]) - } - - w.outBuf = out - _, err := w.conn.Write(out) - return len(p), err -}