chore: align sudoku with upstream v0.2.0 (#2549)

This commit is contained in:
saba-futai
2026-01-30 10:33:22 +08:00
committed by GitHub
parent f52c9356c2
commit d36b024b10
11 changed files with 753 additions and 361 deletions

View File

@@ -7,7 +7,6 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time"
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
@@ -22,10 +21,8 @@ type Sudoku struct {
httpMaskMu sync.Mutex httpMaskMu sync.Mutex
httpMaskClient *sudoku.HTTPMaskTunnelClient httpMaskClient *sudoku.HTTPMaskTunnelClient
muxMu sync.Mutex muxMu sync.Mutex
muxClient *sudoku.MultiplexClient muxClient *sudoku.MultiplexClient
muxBackoffUntil time.Time
muxLastErr error
} }
type SudokuOption struct { type SudokuOption struct {
@@ -58,7 +55,7 @@ func (s *Sudoku) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Con
muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex) muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex)
if muxMode == "on" && !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) { if muxMode == "on" && !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) {
stream, muxErr := s.dialMultiplex(ctx, cfg.TargetAddress, muxMode) stream, muxErr := s.dialMultiplex(ctx, cfg.TargetAddress)
if muxErr == nil { if muxErr == nil {
return NewConn(stream, s), nil return NewConn(stream, s), nil
} }
@@ -312,9 +309,9 @@ func (s *Sudoku) dialAndHandshake(ctx context.Context, cfg *sudoku.ProtocolConfi
return c, nil return c, nil
} }
func (s *Sudoku) dialMultiplex(ctx context.Context, targetAddress string, mode string) (net.Conn, error) { func (s *Sudoku) dialMultiplex(ctx context.Context, targetAddress string) (net.Conn, error) {
for attempt := 0; attempt < 2; attempt++ { for attempt := 0; attempt < 2; attempt++ {
client, err := s.getOrCreateMuxClient(ctx, mode) client, err := s.getOrCreateMuxClient(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -330,21 +327,11 @@ func (s *Sudoku) dialMultiplex(ctx context.Context, targetAddress string, mode s
return nil, fmt.Errorf("multiplex open stream failed") return nil, fmt.Errorf("multiplex open stream failed")
} }
func (s *Sudoku) getOrCreateMuxClient(ctx context.Context, mode string) (*sudoku.MultiplexClient, error) { func (s *Sudoku) getOrCreateMuxClient(ctx context.Context) (*sudoku.MultiplexClient, error) {
if s == nil { if s == nil {
return nil, fmt.Errorf("nil adapter") return nil, fmt.Errorf("nil adapter")
} }
if mode == "auto" {
s.muxMu.Lock()
backoffUntil := s.muxBackoffUntil
lastErr := s.muxLastErr
s.muxMu.Unlock()
if time.Now().Before(backoffUntil) {
return nil, fmt.Errorf("multiplex temporarily disabled: %v", lastErr)
}
}
s.muxMu.Lock() s.muxMu.Lock()
if s.muxClient != nil && !s.muxClient.IsClosed() { if s.muxClient != nil && !s.muxClient.IsClosed() {
client := s.muxClient client := s.muxClient
@@ -363,20 +350,12 @@ func (s *Sudoku) getOrCreateMuxClient(ctx context.Context, mode string) (*sudoku
baseCfg := s.baseConf baseCfg := s.baseConf
baseConn, err := s.dialAndHandshake(ctx, &baseCfg) baseConn, err := s.dialAndHandshake(ctx, &baseCfg)
if err != nil { if err != nil {
if mode == "auto" {
s.muxLastErr = err
s.muxBackoffUntil = time.Now().Add(45 * time.Second)
}
return nil, err return nil, err
} }
client, err := sudoku.StartMultiplexClient(baseConn) client, err := sudoku.StartMultiplexClient(baseConn)
if err != nil { if err != nil {
_ = baseConn.Close() _ = baseConn.Close()
if mode == "auto" {
s.muxLastErr = err
s.muxBackoffUntil = time.Now().Add(45 * time.Second)
}
return nil, err return nil, err
} }
@@ -384,16 +363,6 @@ func (s *Sudoku) getOrCreateMuxClient(ctx context.Context, mode string) (*sudoku
return client, nil return client, nil
} }
func (s *Sudoku) noteMuxFailure(mode string, err error) {
if mode != "auto" {
return
}
s.muxMu.Lock()
s.muxLastErr = err
s.muxBackoffUntil = time.Now().Add(45 * time.Second)
s.muxMu.Unlock()
}
func (s *Sudoku) resetMuxClient() { func (s *Sudoku) resetMuxClient() {
s.muxMu.Lock() s.muxMu.Lock()
defer s.muxMu.Unlock() defer s.muxMu.Unlock()

View File

@@ -1082,19 +1082,19 @@ proxies: # socks5
server: server_ip/domain # 1.2.3.4 or domain server: server_ip/domain # 1.2.3.4 or domain
port: 443 port: 443
key: "<client_key>" # 如果你使用sudoku生成的ED25519密钥对请填写密钥对中的私钥否则填入和服务端相同的uuid key: "<client_key>" # 如果你使用sudoku生成的ED25519密钥对请填写密钥对中的私钥否则填入和服务端相同的uuid
aead-method: chacha20-poly1305 # 可选chacha20-poly1305、aes-128-gcm、none 我们保证在none的情况下sudoku混淆层仍然确保安全 aead-method: chacha20-poly1305 # 可选chacha20-poly1305、aes-128-gcm、none(不建议;且 enable-pure-downlink=false 时不可用)
padding-min: 2 # 最小填充字节数 padding-min: 2 # 最小填充0-100
padding-max: 7 # 最大填充字节数 padding-max: 7 # 最大填充0-100必须 >= padding-min
table-type: prefer_ascii # 可选值prefer_ascii、prefer_entropy 前者全ascii映射后者保证熵值汉明1低于3 table-type: prefer_ascii # 可选值prefer_ascii、prefer_entropy 前者全ascii映射后者保证熵值汉明1低于3
# custom-table: xpxvvpvv # 可选自定义字节布局必须包含2个x、2个p、4个v可随意组合。启用此处则需配置`table-type`为`prefer_entropy` # custom-table: xpxvvpvv # 可选自定义字节布局必须包含2个x、2个p、4个v可随意组合。启用此处则需配置`table-type`为`prefer_entropy`
# custom-tables: ["xpxvvpvv", "vxpvxvvp"] # 可选自定义字节布局列表x/v/p用于 xvp 模式轮换;非空时覆盖 custom-table # custom-tables: ["xpxvvpvv", "vxpvxvvp"] # 可选自定义字节布局列表x/v/p用于 xvp 模式轮换;非空时覆盖 custom-table
http-mask: true # 是否启用http掩码 http-mask: true # 是否启用http掩码
# http-mask-mode: legacy # 可选legacy默认、stream、poll、autostream/poll/auto 支持走 CDN/反代 # http-mask-mode: legacy # 可选legacy默认、streamsplit-stream、poll、auto先 stream 再 pollstream/poll/auto 支持走 CDN/反代
# http-mask-tls: true # 可选:仅在 http-mask-mode 为 stream/poll/auto 时生效true 强制 httpsfalse 强制 http不会根据端口自动推断 # http-mask-tls: true # 可选:仅在 http-mask-mode 为 stream/poll/auto 时生效true 强制 httpsfalse 强制 http不会根据端口自动推断
# http-mask-host: "" # 可选:覆盖 Host/SNI支持 example.com 或 example.com:443仅在 http-mask-mode 为 stream/poll/auto 时生效 # http-mask-host: "" # 可选:覆盖 Host/SNI支持 example.com 或 example.com:443仅在 http-mask-mode 为 stream/poll/auto 时生效
# path-root: "" # 可选HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload # path-root: "" # 可选HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" 或 "/aabbcc/" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload
# http-mask-multiplex: off # 可选off默认、auto复用 h1.1 keep-alive / h2 连接,减少每次建链 RTT、on单条隧道内多路复用多个目标连接;仅在 http-mask-mode=stream/poll/auto 生效) # http-mask-multiplex: off # 可选off默认、auto复用底层 HTTP 连接,减少建链 RTT、onSudoku mux 单隧道多目标;仅在 http-mask-mode=stream/poll/auto 生效)
enable-pure-downlink: false # 是否启用混淆下行false的情况下能在保证数据安全的前提下极大提升下行速度与服务端端保持相同(如果此处为false要求aead不可为none) enable-pure-downlink: false # 可选false=带宽优化下行(更快,要求 aead-method != nonetrue=纯 Sudoku 下行
# anytls # anytls
- name: anytls - name: anytls
@@ -1632,17 +1632,17 @@ listeners:
port: 8443 # 仅支持单端口 port: 8443 # 仅支持单端口
listen: 0.0.0.0 listen: 0.0.0.0
key: "<server_key>" # 如果你使用sudoku生成的ED25519密钥对此处是密钥对中的公钥当然你也可以仅仅使用任意uuid充当key key: "<server_key>" # 如果你使用sudoku生成的ED25519密钥对此处是密钥对中的公钥当然你也可以仅仅使用任意uuid充当key
aead-method: chacha20-poly1305 # 支持chacha20-poly1305或者aes-128-gcm以及nonesudoku的混淆层可以确保none情况下数据安全 aead-method: chacha20-poly1305 # 可选:chacha20-poly1305aes-128-gcmnone(不建议;且 enable-pure-downlink=false 时不可用)
padding-min: 1 # 填充最小长度 padding-min: 1 # 最小填充率0-100
padding-max: 15 # 填充最大长度,均不建议过大 padding-max: 15 # 最大填充率0-100必须 >= padding-min
table-type: prefer_ascii # 可选值prefer_ascii、prefer_entropy 前者全ascii映射后者保证熵值汉明1低于3 table-type: prefer_ascii # 可选值prefer_ascii、prefer_entropy 前者全ascii映射后者保证熵值汉明1低于3
# custom-table: xpxvvpvv # 可选自定义字节布局必须包含2个x、2个p、4个v可随意组合。启用此处则需配置`table-type`为`prefer_entropy` # custom-table: xpxvvpvv # 可选自定义字节布局必须包含2个x、2个p、4个v可随意组合。启用此处则需配置`table-type`为`prefer_entropy`
# custom-tables: ["xpxvvpvv", "vxpvxvvp"] # 可选自定义字节布局列表x/v/p用于 xvp 模式轮换;非空时覆盖 custom-table # custom-tables: ["xpxvvpvv", "vxpvxvvp"] # 可选自定义字节布局列表x/v/p用于 xvp 模式轮换;非空时覆盖 custom-table
handshake-timeout: 5 # optional handshake-timeout: 5 # 可选(秒)
enable-pure-downlink: false # 是否启用混淆下行false的情况下能在保证数据安全的前提下极大提升下行速度与客户端保持相同(如果此处为false要求aead不可为none) enable-pure-downlink: false # 可选false=带宽优化下行(更快,要求 aead-method != nonetrue=纯 Sudoku 下行
disable-http-mask: false # 可选:禁用 http 掩码/隧道(默认为 false disable-http-mask: false # 可选:禁用 http 掩码/隧道(默认为 false
# http-mask-mode: legacy # 可选legacy默认、stream、poll、autostream/poll/auto 支持走 CDN/反代 # http-mask-mode: legacy # 可选legacy默认、streamsplit-stream、poll、auto先 stream 再 pollstream/poll/auto 支持走 CDN/反代
# path-root: "" # 可选HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload # path-root: "" # 可选HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" 或 "/aabbcc/" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload

View File

@@ -32,7 +32,8 @@ type ProtocolConfig struct {
PaddingMin int PaddingMin int
PaddingMax int PaddingMax int
// EnablePureDownlink toggles the bandwidth-optimized downlink mode. // EnablePureDownlink enables the pure Sudoku downlink mode.
// When false, the connection uses the bandwidth-optimized packed downlink (requires AEAD).
EnablePureDownlink bool EnablePureDownlink bool
// Client-only: final target "host:port". // Client-only: final target "host:port".
@@ -46,7 +47,7 @@ type ProtocolConfig struct {
// HTTPMaskMode controls how the HTTP layer behaves: // HTTPMaskMode controls how the HTTP layer behaves:
// - "legacy": write a fake HTTP/1.1 header then switch to raw stream (default, not CDN-compatible) // - "legacy": write a fake HTTP/1.1 header then switch to raw stream (default, not CDN-compatible)
// - "stream": real HTTP tunnel (stream-one or split), CDN-compatible // - "stream": real HTTP tunnel (split-stream), CDN-compatible
// - "poll": plain HTTP tunnel (authorize/push/pull), strong restricted-network pass-through // - "poll": plain HTTP tunnel (authorize/push/pull), strong restricted-network pass-through
// - "auto": try stream then fall back to poll // - "auto": try stream then fall back to poll
HTTPMaskMode string HTTPMaskMode string
@@ -114,7 +115,8 @@ func (c *ProtocolConfig) Validate() error {
} }
if v := strings.TrimSpace(c.HTTPMaskPathRoot); v != "" { if v := strings.TrimSpace(c.HTTPMaskPathRoot); v != "" {
if strings.Contains(v, "/") { v = strings.Trim(v, "/")
if v == "" || strings.Contains(v, "/") {
return fmt.Errorf("invalid http-mask-path-root: must be a single path segment") return fmt.Errorf("invalid http-mask-path-root: must be a single path segment")
} }
for i := 0; i < len(v); i++ { for i := 0; i < len(v); i++ {

View File

@@ -22,6 +22,26 @@ type AEADConn struct {
nonceSize int nonceSize int
} }
func (cc *AEADConn) CloseWrite() error {
if cc == nil || cc.Conn == nil {
return nil
}
if cw, ok := cc.Conn.(interface{ CloseWrite() error }); ok {
return cw.CloseWrite()
}
return nil
}
func (cc *AEADConn) CloseRead() error {
if cc == nil || cc.Conn == nil {
return nil
}
if cr, ok := cc.Conn.(interface{ CloseRead() error }); ok {
return cr.CloseRead()
}
return nil
}
func NewAEADConn(c net.Conn, key string, method string) (*AEADConn, error) { func NewAEADConn(c net.Conn, key string, method string) (*AEADConn, error) {
if method == "none" { if method == "none" {
return &AEADConn{Conn: c, aead: nil}, nil return &AEADConn{Conn: c, aead: nil}, nil

View File

@@ -383,7 +383,11 @@ func (c *stream) enqueue(payload []byte) {
c.mu.Unlock() c.mu.Unlock()
return return
} }
c.queue = append(c.queue, payload) if len(c.readBuf) == 0 && len(c.queue) == 0 {
c.readBuf = payload
} else {
c.queue = append(c.queue, payload)
}
c.cond.Signal() c.cond.Signal()
c.mu.Unlock() c.mu.Unlock()
} }
@@ -491,6 +495,9 @@ func (c *stream) Close() error {
return nil return nil
} }
func (c *stream) CloseWrite() error { return c.Close() }
func (c *stream) CloseRead() error { return c.Close() }
func (c *stream) LocalAddr() net.Addr { return c.localAddr } func (c *stream) LocalAddr() net.Addr { return c.localAddr }
func (c *stream) RemoteAddr() net.Addr { return c.remoteAddr } func (c *stream) RemoteAddr() net.Addr { return c.remoteAddr }
@@ -501,4 +508,3 @@ func (c *stream) SetDeadline(t time.Time) error {
} }
func (c *stream) SetReadDeadline(time.Time) error { return nil } func (c *stream) SetReadDeadline(time.Time) error { return nil }
func (c *stream) SetWriteDeadline(time.Time) error { return nil } func (c *stream) SetWriteDeadline(time.Time) error { return nil }

View File

@@ -6,6 +6,7 @@ import (
"crypto/subtle" "crypto/subtle"
"encoding/base64" "encoding/base64"
"encoding/binary" "encoding/binary"
"github.com/metacubex/http"
"strings" "strings"
"time" "time"
) )
@@ -13,6 +14,7 @@ import (
const ( const (
tunnelAuthHeaderKey = "Authorization" tunnelAuthHeaderKey = "Authorization"
tunnelAuthHeaderPrefix = "Bearer " tunnelAuthHeaderPrefix = "Bearer "
tunnelAuthQueryKey = "auth"
) )
type tunnelAuth struct { type tunnelAuth struct {
@@ -61,8 +63,15 @@ func (a *tunnelAuth) verify(headers map[string]string, mode TunnelMode, method,
if headers == nil { if headers == nil {
return false return false
} }
return a.verifyValue(headers["authorization"], mode, method, path, now)
}
val := strings.TrimSpace(headers["authorization"]) func (a *tunnelAuth) verifyValue(val string, mode TunnelMode, method, path string, now time.Time) bool {
if a == nil {
return true
}
val = strings.TrimSpace(val)
if val == "" { if val == "" {
return false return false
} }
@@ -121,11 +130,9 @@ func (a *tunnelAuth) sign(mode TunnelMode, method, path string, ts int64) [16]by
return out return out
} }
type headerSetter interface { type httpHeaderSetter = http.Header
Set(key, value string)
}
func applyTunnelAuthHeader(h headerSetter, auth *tunnelAuth, mode TunnelMode, method, path string) { func applyTunnelAuthHeader(h httpHeaderSetter, auth *tunnelAuth, mode TunnelMode, method, path string) {
if auth == nil || h == nil { if auth == nil || h == nil {
return return
} }
@@ -135,3 +142,19 @@ func applyTunnelAuthHeader(h headerSetter, auth *tunnelAuth, mode TunnelMode, me
} }
h.Set(tunnelAuthHeaderKey, tunnelAuthHeaderPrefix+token) h.Set(tunnelAuthHeaderKey, tunnelAuthHeaderPrefix+token)
} }
func applyTunnelAuth(req *http.Request, auth *tunnelAuth, mode TunnelMode, method, path string) {
if auth == nil || req == nil {
return
}
token := auth.token(mode, method, path, time.Now())
if token == "" {
return
}
req.Header.Set(tunnelAuthHeaderKey, tunnelAuthHeaderPrefix+token)
if req.URL != nil {
q := req.URL.Query()
q.Set(tunnelAuthQueryKey, token)
req.URL.RawQuery = q.Encode()
}
}

View File

@@ -0,0 +1,229 @@
package httpmask
import (
"io"
"net"
"os"
"sync"
"time"
)
type pipeDeadline struct {
mu sync.Mutex
timer *time.Timer
cancel chan struct{}
}
func makePipeDeadline() pipeDeadline {
return pipeDeadline{cancel: make(chan struct{})}
}
func (d *pipeDeadline) set(t time.Time) {
d.mu.Lock()
defer d.mu.Unlock()
if d.timer != nil && !d.timer.Stop() {
<-d.cancel
}
d.timer = nil
closed := isClosedPipeChan(d.cancel)
if t.IsZero() {
if closed {
d.cancel = make(chan struct{})
}
return
}
if dur := time.Until(t); dur > 0 {
if closed {
d.cancel = make(chan struct{})
}
d.timer = time.AfterFunc(dur, func() {
close(d.cancel)
})
return
}
if !closed {
close(d.cancel)
}
}
func (d *pipeDeadline) wait() <-chan struct{} {
d.mu.Lock()
ch := d.cancel
d.mu.Unlock()
return ch
}
func isClosedPipeChan(ch <-chan struct{}) bool {
select {
case <-ch:
return true
default:
return false
}
}
type halfPipeAddr struct{}
func (halfPipeAddr) Network() string { return "pipe" }
func (halfPipeAddr) String() string { return "pipe" }
type halfPipeConn struct {
wrMu sync.Mutex
rdRx <-chan []byte
rdTx chan<- int
wrTx chan<- []byte
wrRx <-chan int
readOnce sync.Once
writeOnce sync.Once
localReadDone chan struct{}
localWriteDone chan struct{}
remoteReadDone <-chan struct{}
remoteWriteDone <-chan struct{}
readDeadline pipeDeadline
writeDeadline pipeDeadline
}
func newHalfPipe() (net.Conn, net.Conn) {
cb1 := make(chan []byte)
cb2 := make(chan []byte)
cn1 := make(chan int)
cn2 := make(chan int)
r1 := make(chan struct{})
w1 := make(chan struct{})
r2 := make(chan struct{})
w2 := make(chan struct{})
c1 := &halfPipeConn{
rdRx: cb1,
rdTx: cn1,
wrTx: cb2,
wrRx: cn2,
localReadDone: r1,
localWriteDone: w1,
remoteReadDone: r2,
remoteWriteDone: w2,
readDeadline: makePipeDeadline(),
writeDeadline: makePipeDeadline(),
}
c2 := &halfPipeConn{
rdRx: cb2,
rdTx: cn2,
wrTx: cb1,
wrRx: cn1,
localReadDone: r2,
localWriteDone: w2,
remoteReadDone: r1,
remoteWriteDone: w1,
readDeadline: makePipeDeadline(),
writeDeadline: makePipeDeadline(),
}
return c1, c2
}
func (*halfPipeConn) LocalAddr() net.Addr { return halfPipeAddr{} }
func (*halfPipeConn) RemoteAddr() net.Addr { return halfPipeAddr{} }
func (c *halfPipeConn) Read(p []byte) (int, error) {
switch {
case isClosedPipeChan(c.localReadDone):
return 0, io.ErrClosedPipe
case isClosedPipeChan(c.remoteWriteDone):
return 0, io.EOF
case isClosedPipeChan(c.readDeadline.wait()):
return 0, os.ErrDeadlineExceeded
}
select {
case b := <-c.rdRx:
n := copy(p, b)
c.rdTx <- n
return n, nil
case <-c.localReadDone:
return 0, io.ErrClosedPipe
case <-c.remoteWriteDone:
return 0, io.EOF
case <-c.readDeadline.wait():
return 0, os.ErrDeadlineExceeded
}
}
func (c *halfPipeConn) Write(p []byte) (int, error) {
switch {
case isClosedPipeChan(c.localWriteDone):
return 0, io.ErrClosedPipe
case isClosedPipeChan(c.remoteReadDone):
return 0, io.ErrClosedPipe
case isClosedPipeChan(c.writeDeadline.wait()):
return 0, os.ErrDeadlineExceeded
}
c.wrMu.Lock()
defer c.wrMu.Unlock()
var (
total int
rest = p
)
for once := true; once || len(rest) > 0; once = false {
select {
case c.wrTx <- rest:
n := <-c.wrRx
rest = rest[n:]
total += n
case <-c.localWriteDone:
return total, io.ErrClosedPipe
case <-c.remoteReadDone:
return total, io.ErrClosedPipe
case <-c.writeDeadline.wait():
return total, os.ErrDeadlineExceeded
}
}
return total, nil
}
func (c *halfPipeConn) CloseWrite() error {
c.writeOnce.Do(func() { close(c.localWriteDone) })
return nil
}
func (c *halfPipeConn) CloseRead() error {
c.readOnce.Do(func() { close(c.localReadDone) })
return nil
}
func (c *halfPipeConn) Close() error {
_ = c.CloseRead()
_ = c.CloseWrite()
return nil
}
func (c *halfPipeConn) SetDeadline(t time.Time) error {
c.readDeadline.set(t)
c.writeDeadline.set(t)
return nil
}
func (c *halfPipeConn) SetReadDeadline(t time.Time) error {
c.readDeadline.set(t)
return nil
}
func (c *halfPipeConn) SetWriteDeadline(t time.Time) error {
c.writeDeadline.set(t)
return nil
}

View File

@@ -241,38 +241,6 @@ func parseTunnelToken(body []byte) (string, error) {
return token, nil return token, nil
} }
type httpStreamConn struct {
reader io.ReadCloser
writer *io.PipeWriter
cancel context.CancelFunc
localAddr net.Addr
remoteAddr net.Addr
}
func (c *httpStreamConn) Read(p []byte) (int, error) { return c.reader.Read(p) }
func (c *httpStreamConn) Write(p []byte) (int, error) { return c.writer.Write(p) }
func (c *httpStreamConn) Close() error {
if c.cancel != nil {
c.cancel()
}
if c.writer != nil {
_ = c.writer.CloseWithError(io.ErrClosedPipe)
}
if c.reader != nil {
return c.reader.Close()
}
return nil
}
func (c *httpStreamConn) LocalAddr() net.Addr { return c.localAddr }
func (c *httpStreamConn) RemoteAddr() net.Addr { return c.remoteAddr }
func (c *httpStreamConn) SetDeadline(time.Time) error { return nil }
func (c *httpStreamConn) SetReadDeadline(time.Time) error { return nil }
func (c *httpStreamConn) SetWriteDeadline(time.Time) error { return nil }
type httpClientTarget struct { type httpClientTarget struct {
scheme string scheme string
urlHost string urlHost string
@@ -332,6 +300,7 @@ type sessionDialInfo struct {
client *http.Client client *http.Client
pushURL string pushURL string
pullURL string pullURL string
finURL string
closeURL string closeURL string
headerHost string headerHost string
auth *tunnelAuth auth *tunnelAuth
@@ -350,7 +319,7 @@ func dialSessionWithClient(ctx context.Context, client *http.Client, target http
} }
req.Host = target.headerHost req.Host = target.headerHost
applyTunnelHeaders(req.Header, target.headerHost, mode) applyTunnelHeaders(req.Header, target.headerHost, mode)
applyTunnelAuthHeader(req.Header, auth, mode, http.MethodGet, "/session") applyTunnelAuth(req, auth, mode, http.MethodGet, "/session")
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
@@ -375,12 +344,14 @@ func dialSessionWithClient(ctx context.Context, client *http.Client, target http
pushURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/api/v1/upload"), RawQuery: "token=" + url.QueryEscape(token)}).String() pushURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/api/v1/upload"), RawQuery: "token=" + url.QueryEscape(token)}).String()
pullURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/stream"), RawQuery: "token=" + url.QueryEscape(token)}).String() pullURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/stream"), RawQuery: "token=" + url.QueryEscape(token)}).String()
finURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/api/v1/upload"), RawQuery: "token=" + url.QueryEscape(token) + "&fin=1"}).String()
closeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/api/v1/upload"), RawQuery: "token=" + url.QueryEscape(token) + "&close=1"}).String() closeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/api/v1/upload"), RawQuery: "token=" + url.QueryEscape(token) + "&close=1"}).String()
return &sessionDialInfo{ return &sessionDialInfo{
client: client, client: client,
pushURL: pushURL, pushURL: pushURL,
pullURL: pullURL, pullURL: pullURL,
finURL: finURL,
closeURL: closeURL, closeURL: closeURL,
headerHost: target.headerHost, headerHost: target.headerHost,
auth: auth, auth: auth,
@@ -409,7 +380,31 @@ func bestEffortCloseSession(client *http.Client, closeURL, headerHost string, mo
} }
req.Host = headerHost req.Host = headerHost
applyTunnelHeaders(req.Header, headerHost, mode) applyTunnelHeaders(req.Header, headerHost, mode)
applyTunnelAuthHeader(req.Header, auth, mode, http.MethodPost, "/api/v1/upload") applyTunnelAuth(req, auth, mode, http.MethodPost, "/api/v1/upload")
resp, err := client.Do(req)
if err != nil || resp == nil {
return
}
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024))
_ = resp.Body.Close()
}
func bestEffortCloseWriteSession(client *http.Client, finURL, headerHost string, mode TunnelMode, auth *tunnelAuth) {
if client == nil || finURL == "" || headerHost == "" {
return
}
closeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(closeCtx, http.MethodPost, finURL, nil)
if err != nil {
return
}
req.Host = headerHost
applyTunnelHeaders(req.Header, headerHost, mode)
applyTunnelAuth(req, auth, mode, http.MethodPost, "/api/v1/upload")
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil || resp == nil { if err != nil || resp == nil {
@@ -420,160 +415,13 @@ func bestEffortCloseSession(client *http.Client, closeURL, headerHost string, mo
} }
func dialStreamWithClient(ctx context.Context, client *http.Client, target httpClientTarget, opts TunnelDialOptions) (net.Conn, error) { func dialStreamWithClient(ctx context.Context, client *http.Client, target httpClientTarget, opts TunnelDialOptions) (net.Conn, error) {
// Prefer split-session (Cloudflare-friendly). Fall back to stream-one for older servers / environments. // "stream" mode uses split-stream to stay CDN-friendly by default.
c, errSplit := dialStreamSplitWithClient(ctx, client, target, opts) return dialStreamSplitWithClient(ctx, client, target, opts)
if errSplit == nil {
return c, nil
}
c2, errOne := dialStreamOneWithClient(ctx, client, target, opts)
if errOne == nil {
return c2, nil
}
return nil, fmt.Errorf("dial stream failed: split: %v; stream-one: %w", errSplit, errOne)
} }
func dialStream(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) { func dialStream(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
// Prefer split-session (Cloudflare-friendly). Fall back to stream-one for older servers / environments. // "stream" mode uses split-stream to stay CDN-friendly by default.
c, errSplit := dialStreamSplit(ctx, serverAddress, opts) return dialStreamSplit(ctx, serverAddress, opts)
if errSplit == nil {
return c, nil
}
c2, errOne := dialStreamOne(ctx, serverAddress, opts)
if errOne == nil {
return c2, nil
}
return nil, fmt.Errorf("dial stream failed: split: %v; stream-one: %w", errSplit, errOne)
}
func dialStreamOneWithClient(ctx context.Context, client *http.Client, target httpClientTarget, opts TunnelDialOptions) (net.Conn, error) {
if client == nil {
return nil, fmt.Errorf("nil http client")
}
auth := newTunnelAuth(opts.AuthKey, 0)
r := rngPool.Get().(*mrand.Rand)
basePath := paths[r.Intn(len(paths))]
path := joinPathRoot(opts.PathRoot, basePath)
ctype := contentTypes[r.Intn(len(contentTypes))]
rngPool.Put(r)
u := url.URL{
Scheme: target.scheme,
Host: target.urlHost,
Path: path,
}
reqBodyR, reqBodyW := io.Pipe()
connCtx, connCancel := context.WithCancel(context.Background())
req, err := http.NewRequestWithContext(connCtx, http.MethodPost, u.String(), reqBodyR)
if err != nil {
connCancel()
_ = reqBodyW.Close()
return nil, err
}
req.Host = target.headerHost
applyTunnelHeaders(req.Header, target.headerHost, TunnelModeStream)
applyTunnelAuthHeader(req.Header, auth, TunnelModeStream, http.MethodPost, basePath)
req.Header.Set("Content-Type", ctype)
type doResult struct {
resp *http.Response
err error
}
doCh := make(chan doResult, 1)
go func() {
resp, doErr := client.Do(req)
doCh <- doResult{resp: resp, err: doErr}
}()
streamConn := &httpStreamConn{
writer: reqBodyW,
cancel: connCancel,
localAddr: &net.TCPAddr{},
remoteAddr: &net.TCPAddr{},
}
type upgradeResult struct {
conn net.Conn
err error
}
upgradeCh := make(chan upgradeResult, 1)
if opts.Upgrade == nil {
upgradeCh <- upgradeResult{conn: streamConn, err: nil}
} else {
go func() {
upgradeConn, err := opts.Upgrade(streamConn)
if err != nil {
upgradeCh <- upgradeResult{conn: nil, err: err}
return
}
if upgradeConn == nil {
upgradeConn = streamConn
}
upgradeCh <- upgradeResult{conn: upgradeConn, err: nil}
}()
}
var (
outConn net.Conn
upgradeDone bool
responseReady bool
)
for !(upgradeDone && responseReady) {
select {
case <-ctx.Done():
_ = streamConn.Close()
if outConn != nil && outConn != streamConn {
_ = outConn.Close()
}
return nil, ctx.Err()
case u := <-upgradeCh:
if u.err != nil {
_ = streamConn.Close()
return nil, u.err
}
outConn = u.conn
if outConn == nil {
outConn = streamConn
}
upgradeDone = true
case r := <-doCh:
if r.err != nil {
_ = streamConn.Close()
if outConn != nil && outConn != streamConn {
_ = outConn.Close()
}
return nil, r.err
}
if r.resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(r.resp.Body, 4*1024))
_ = r.resp.Body.Close()
_ = streamConn.Close()
if outConn != nil && outConn != streamConn {
_ = outConn.Close()
}
return nil, fmt.Errorf("stream bad status: %s (%s)", r.resp.Status, strings.TrimSpace(string(body)))
}
streamConn.reader = r.resp.Body
responseReady = true
}
}
return outConn, nil
}
func dialStreamOne(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
client, target, err := newHTTPClient(serverAddress, opts, 32)
if err != nil {
return nil, err
}
return dialStreamOneWithClient(ctx, client, target, opts)
} }
type queuedConn struct { type queuedConn struct {
@@ -581,6 +429,9 @@ type queuedConn struct {
closed chan struct{} closed chan struct{}
writeCh chan []byte writeCh chan []byte
// writeClosed is closed by CloseWrite to stop accepting new payloads.
// When closed, Write returns io.ErrClosedPipe, but Read is unaffected.
writeClosed chan struct{}
mu sync.Mutex mu sync.Mutex
readBuf []byte readBuf []byte
@@ -589,6 +440,18 @@ type queuedConn struct {
remoteAddr net.Addr remoteAddr net.Addr
} }
func (c *queuedConn) CloseWrite() error {
if c == nil || c.writeClosed == nil {
return nil
}
c.mu.Lock()
if !isClosedPipeChan(c.writeClosed) {
close(c.writeClosed)
}
c.mu.Unlock()
return nil
}
func (c *queuedConn) closeWithError(err error) error { func (c *queuedConn) closeWithError(err error) error {
c.mu.Lock() c.mu.Lock()
select { select {
@@ -640,6 +503,9 @@ func (c *queuedConn) Write(b []byte) (n int, err error) {
case <-c.closed: case <-c.closed:
c.mu.Unlock() c.mu.Unlock()
return 0, c.closedErr() return 0, c.closedErr()
case <-c.writeClosed:
c.mu.Unlock()
return 0, io.ErrClosedPipe
default: default:
} }
c.mu.Unlock() c.mu.Unlock()
@@ -651,6 +517,8 @@ func (c *queuedConn) Write(b []byte) (n int, err error) {
return len(b), nil return len(b), nil
case <-c.closed: case <-c.closed:
return 0, c.closedErr() return 0, c.closedErr()
case <-c.writeClosed:
return 0, io.ErrClosedPipe
} }
} }
@@ -670,6 +538,7 @@ type streamSplitConn struct {
client *http.Client client *http.Client
pushURL string pushURL string
pullURL string pullURL string
finURL string
closeURL string closeURL string
headerHost string headerHost string
auth *tunnelAuth auth *tunnelAuth
@@ -697,15 +566,17 @@ func newStreamSplitConnFromInfo(info *sessionDialInfo) *streamSplitConn {
client: info.client, client: info.client,
pushURL: info.pushURL, pushURL: info.pushURL,
pullURL: info.pullURL, pullURL: info.pullURL,
finURL: info.finURL,
closeURL: info.closeURL, closeURL: info.closeURL,
headerHost: info.headerHost, headerHost: info.headerHost,
auth: info.auth, auth: info.auth,
queuedConn: queuedConn{ queuedConn: queuedConn{
rxc: make(chan []byte, 256), rxc: make(chan []byte, 256),
closed: make(chan struct{}), closed: make(chan struct{}),
writeCh: make(chan []byte, 256), writeCh: make(chan []byte, 256),
localAddr: &net.TCPAddr{}, writeClosed: make(chan struct{}),
remoteAddr: &net.TCPAddr{}, localAddr: &net.TCPAddr{},
remoteAddr: &net.TCPAddr{},
}, },
} }
@@ -793,7 +664,7 @@ func (c *streamSplitConn) pullLoop() {
} }
req.Host = c.headerHost req.Host = c.headerHost
applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream) applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream)
applyTunnelAuthHeader(req.Header, c.auth, TunnelModeStream, http.MethodGet, "/stream") applyTunnelAuth(req, c.auth, TunnelModeStream, http.MethodGet, "/stream")
resp, err := c.client.Do(req) resp, err := c.client.Do(req)
if err != nil { if err != nil {
@@ -891,7 +762,7 @@ func (c *streamSplitConn) pushLoop() {
} }
req.Host = c.headerHost req.Host = c.headerHost
applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream) applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream)
applyTunnelAuthHeader(req.Header, c.auth, TunnelModeStream, http.MethodPost, "/api/v1/upload") applyTunnelAuth(req, c.auth, TunnelModeStream, http.MethodPost, "/api/v1/upload")
req.Header.Set("Content-Type", "application/octet-stream") req.Header.Set("Content-Type", "application/octet-stream")
resp, err := c.client.Do(req) resp, err := c.client.Do(req)
@@ -977,6 +848,27 @@ func (c *streamSplitConn) pushLoop() {
return return
} }
resetTimer() resetTimer()
case <-c.writeClosed:
// Drain any already-accepted writes so CloseWrite does not lose data.
for {
select {
case b := <-c.writeCh:
if len(b) == 0 {
continue
}
if buf.Len()+len(b) > maxBatchBytes {
if err := flushWithRetry(); err != nil {
_ = c.Close()
return
}
}
_, _ = buf.Write(b)
default:
_ = flushWithRetry()
bestEffortCloseWriteSession(c.client, c.finURL, c.headerHost, TunnelModeStream, c.auth)
return
}
}
case <-c.closed: case <-c.closed:
_ = flushWithRetry() _ = flushWithRetry()
return return
@@ -993,6 +885,7 @@ type pollConn struct {
client *http.Client client *http.Client
pushURL string pushURL string
pullURL string pullURL string
finURL string
closeURL string closeURL string
headerHost string headerHost string
auth *tunnelAuth auth *tunnelAuth
@@ -1037,15 +930,17 @@ func newPollConnFromInfo(info *sessionDialInfo) *pollConn {
client: info.client, client: info.client,
pushURL: info.pushURL, pushURL: info.pushURL,
pullURL: info.pullURL, pullURL: info.pullURL,
finURL: info.finURL,
closeURL: info.closeURL, closeURL: info.closeURL,
headerHost: info.headerHost, headerHost: info.headerHost,
auth: info.auth, auth: info.auth,
queuedConn: queuedConn{ queuedConn: queuedConn{
rxc: make(chan []byte, 128), rxc: make(chan []byte, 128),
closed: make(chan struct{}), closed: make(chan struct{}),
writeCh: make(chan []byte, 256), writeCh: make(chan []byte, 256),
localAddr: &net.TCPAddr{}, writeClosed: make(chan struct{}),
remoteAddr: &net.TCPAddr{}, localAddr: &net.TCPAddr{},
remoteAddr: &net.TCPAddr{},
}, },
} }
@@ -1117,14 +1012,14 @@ func (c *pollConn) pullLoop() {
default: default:
} }
req, err := http.NewRequest(http.MethodGet, c.pullURL, nil) req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, c.pullURL, nil)
if err != nil { if err != nil {
_ = c.Close() _ = c.Close()
return return
} }
req.Host = c.headerHost req.Host = c.headerHost
applyTunnelHeaders(req.Header, c.headerHost, TunnelModePoll) applyTunnelHeaders(req.Header, c.headerHost, TunnelModePoll)
applyTunnelAuthHeader(req.Header, c.auth, TunnelModePoll, http.MethodGet, "/stream") applyTunnelAuth(req, c.auth, TunnelModePoll, http.MethodGet, "/stream")
resp, err := c.client.Do(req) resp, err := c.client.Do(req)
if err != nil { if err != nil {
@@ -1202,21 +1097,25 @@ func (c *pollConn) pushLoop() {
return nil return nil
} }
req, err := http.NewRequest(http.MethodPost, c.pushURL, bytes.NewReader(buf.Bytes())) reqCtx, cancel := context.WithTimeout(c.ctx, 20*time.Second)
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.pushURL, bytes.NewReader(buf.Bytes()))
if err != nil { if err != nil {
cancel()
return err return err
} }
req.Host = c.headerHost req.Host = c.headerHost
applyTunnelHeaders(req.Header, c.headerHost, TunnelModePoll) applyTunnelHeaders(req.Header, c.headerHost, TunnelModePoll)
applyTunnelAuthHeader(req.Header, c.auth, TunnelModePoll, http.MethodPost, "/api/v1/upload") applyTunnelAuth(req, c.auth, TunnelModePoll, http.MethodPost, "/api/v1/upload")
req.Header.Set("Content-Type", "text/plain") req.Header.Set("Content-Type", "text/plain")
resp, err := c.client.Do(req) resp, err := c.client.Do(req)
if err != nil { if err != nil {
cancel()
return err return err
} }
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024)) _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024))
_ = resp.Body.Close() _ = resp.Body.Close()
cancel()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return fmt.Errorf("bad status: %s", resp.Status) return fmt.Errorf("bad status: %s", resp.Status)
} }
@@ -1309,6 +1208,41 @@ func (c *pollConn) pushLoop() {
return return
} }
resetTimer() resetTimer()
case <-c.writeClosed:
// Drain any already-accepted writes so CloseWrite does not lose data.
for {
select {
case b := <-c.writeCh:
if len(b) == 0 {
continue
}
for len(b) > 0 {
chunk := b
if len(chunk) > maxLineRawBytes {
chunk = b[:maxLineRawBytes]
}
b = b[len(chunk):]
encLen := base64.StdEncoding.EncodedLen(len(chunk))
if pendingRaw+len(chunk) > maxBatchBytes || buf.Len()+encLen+1 > maxBatchBytes*2 {
if err := flushWithRetry(); err != nil {
_ = c.closeWithError(fmt.Errorf("poll push flush failed: %w", err))
return
}
}
tmp := make([]byte, base64.StdEncoding.EncodedLen(len(chunk)))
base64.StdEncoding.Encode(tmp, chunk)
buf.Write(tmp)
buf.WriteByte('\n')
pendingRaw += len(chunk)
}
default:
_ = flushWithRetry()
bestEffortCloseWriteSession(c.client, c.finURL, c.headerHost, TunnelModePoll, c.auth)
return
}
}
case <-c.closed: case <-c.closed:
_ = flushWithRetry() _ = flushWithRetry()
return return
@@ -1478,19 +1412,50 @@ func (s *TunnelServer) HandleConn(rawConn net.Conn) (HandleResult, net.Conn, err
tunnelHeader := strings.ToLower(strings.TrimSpace(req.headers["x-sudoku-tunnel"])) tunnelHeader := strings.ToLower(strings.TrimSpace(req.headers["x-sudoku-tunnel"]))
if tunnelHeader == "" { if tunnelHeader == "" {
// Not our tunnel; replay full bytes to legacy handler. // Some CDNs / forward proxies may strip unknown headers. When AuthKey is enabled, we can
prefix := make([]byte, 0, len(headerBytes)+len(buffered)) // safely infer the intended tunnel mode by verifying the Authorization token against
prefix = append(prefix, headerBytes...) // both stream/poll modes and picking the one that matches.
prefix = append(prefix, buffered...) if s.auth != nil {
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil u, err := url.ParseRequestURI(req.target)
} if err == nil {
if s.mode == TunnelModeLegacy { path, ok := stripPathRoot(s.pathRoot, u.Path)
if s.passThroughOnReject { if ok && s.isAllowedBasePath(path) {
authVal := req.headers["authorization"]
if authVal == "" {
authVal = u.Query().Get(tunnelAuthQueryKey)
}
streamOK := s.auth.verifyValue(authVal, TunnelModeStream, req.method, path, time.Now())
pollOK := s.auth.verifyValue(authVal, TunnelModePoll, req.method, path, time.Now())
switch {
case streamOK && !pollOK:
tunnelHeader = string(TunnelModeStream)
case pollOK && !streamOK:
tunnelHeader = string(TunnelModePoll)
}
}
}
}
if tunnelHeader == "" {
// Not our tunnel; replay full bytes to legacy handler.
prefix := make([]byte, 0, len(headerBytes)+len(buffered)) prefix := make([]byte, 0, len(headerBytes)+len(buffered))
prefix = append(prefix, headerBytes...) prefix = append(prefix, headerBytes...)
prefix = append(prefix, buffered...) prefix = append(prefix, buffered...)
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
} }
}
reject := func() (HandleResult, net.Conn, error) {
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
prefix = append(prefix, headerBytes...)
prefix = append(prefix, buffered...)
return HandlePassThrough, newRejectedPreBufferedConn(rawConn, prefix), nil
}
if s.mode == TunnelModeLegacy {
if s.passThroughOnReject {
return reject()
}
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found") _ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
_ = rawConn.Close() _ = rawConn.Close()
return HandleDone, nil, nil return HandleDone, nil, nil
@@ -1500,10 +1465,7 @@ func (s *TunnelServer) HandleConn(rawConn net.Conn) (HandleResult, net.Conn, err
case TunnelModeStream: case TunnelModeStream:
if s.mode != TunnelModeStream && s.mode != TunnelModeAuto { if s.mode != TunnelModeStream && s.mode != TunnelModeAuto {
if s.passThroughOnReject { if s.passThroughOnReject {
prefix := make([]byte, 0, len(headerBytes)+len(buffered)) return reject()
prefix = append(prefix, headerBytes...)
prefix = append(prefix, buffered...)
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
} }
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found") _ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
_ = rawConn.Close() _ = rawConn.Close()
@@ -1513,10 +1475,7 @@ func (s *TunnelServer) HandleConn(rawConn net.Conn) (HandleResult, net.Conn, err
case TunnelModePoll: case TunnelModePoll:
if s.mode != TunnelModePoll && s.mode != TunnelModeAuto { if s.mode != TunnelModePoll && s.mode != TunnelModeAuto {
if s.passThroughOnReject { if s.passThroughOnReject {
prefix := make([]byte, 0, len(headerBytes)+len(buffered)) return reject()
prefix = append(prefix, headerBytes...)
prefix = append(prefix, buffered...)
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
} }
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found") _ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
_ = rawConn.Close() _ = rawConn.Close()
@@ -1525,10 +1484,7 @@ func (s *TunnelServer) HandleConn(rawConn net.Conn) (HandleResult, net.Conn, err
return s.handlePoll(rawConn, req, headerBytes, buffered) return s.handlePoll(rawConn, req, headerBytes, buffered)
default: default:
if s.passThroughOnReject { if s.passThroughOnReject {
prefix := make([]byte, 0, len(headerBytes)+len(buffered)) return reject()
prefix = append(prefix, headerBytes...)
prefix = append(prefix, buffered...)
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
} }
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found") _ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
_ = rawConn.Close() _ = rawConn.Close()
@@ -1619,13 +1575,52 @@ func readAllBuffered(r *bufio.Reader) []byte {
type preBufferedConn struct { type preBufferedConn struct {
net.Conn net.Conn
buf []byte buf []byte
recorded []byte
rejected bool
} }
func newPreBufferedConn(conn net.Conn, pre []byte) net.Conn { func (p *preBufferedConn) CloseWrite() error {
if p == nil || p.Conn == nil {
return nil
}
if cw, ok := p.Conn.(interface{ CloseWrite() error }); ok {
return cw.CloseWrite()
}
return nil
}
func (p *preBufferedConn) CloseRead() error {
if p == nil || p.Conn == nil {
return nil
}
if cr, ok := p.Conn.(interface{ CloseRead() error }); ok {
return cr.CloseRead()
}
return nil
}
func newPreBufferedConn(conn net.Conn, pre []byte) *preBufferedConn {
cpy := make([]byte, len(pre)) cpy := make([]byte, len(pre))
copy(cpy, pre) copy(cpy, pre)
return &preBufferedConn{Conn: conn, buf: cpy} return &preBufferedConn{Conn: conn, buf: cpy, recorded: cpy}
}
func newRejectedPreBufferedConn(conn net.Conn, pre []byte) *preBufferedConn {
c := newPreBufferedConn(conn, pre)
c.rejected = true
return c
}
func (p *preBufferedConn) IsHTTPMaskRejected() bool { return p.rejected }
func (p *preBufferedConn) GetBufferedAndRecorded() []byte {
if len(p.recorded) == 0 {
return nil
}
out := make([]byte, len(p.recorded))
copy(out, p.recorded)
return out
} }
func (p *preBufferedConn) Read(b []byte) (int, error) { func (p *preBufferedConn) Read(b []byte) (int, error) {
@@ -1682,7 +1677,7 @@ func (s *TunnelServer) handleStream(rawConn net.Conn, req *httpRequestHeader, he
prefix := make([]byte, 0, len(headerBytes)+len(buffered)) prefix := make([]byte, 0, len(headerBytes)+len(buffered))
prefix = append(prefix, headerBytes...) prefix = append(prefix, headerBytes...)
prefix = append(prefix, buffered...) prefix = append(prefix, buffered...)
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil return HandlePassThrough, newRejectedPreBufferedConn(rawConn, prefix), nil
} }
_ = writeSimpleHTTPResponse(rawConn, code, body) _ = writeSimpleHTTPResponse(rawConn, code, body)
_ = rawConn.Close() _ = rawConn.Close()
@@ -1699,21 +1694,29 @@ func (s *TunnelServer) handleStream(rawConn net.Conn, req *httpRequestHeader, he
if !ok || !s.isAllowedBasePath(path) { if !ok || !s.isAllowedBasePath(path) {
return rejectOrReply(http.StatusNotFound, "not found") return rejectOrReply(http.StatusNotFound, "not found")
} }
if !s.auth.verify(req.headers, TunnelModeStream, req.method, path, time.Now()) { authVal := req.headers["authorization"]
if authVal == "" {
authVal = u.Query().Get(tunnelAuthQueryKey)
}
if !s.auth.verifyValue(authVal, TunnelModeStream, req.method, path, time.Now()) {
return rejectOrReply(http.StatusNotFound, "not found") return rejectOrReply(http.StatusNotFound, "not found")
} }
token := u.Query().Get("token") token := u.Query().Get("token")
closeFlag := u.Query().Get("close") == "1" closeFlag := u.Query().Get("close") == "1"
finFlag := u.Query().Get("fin") == "1"
switch strings.ToUpper(req.method) { switch strings.ToUpper(req.method) {
case http.MethodGet: case http.MethodGet:
// Stream split-session: GET /session (no token) => token + start tunnel on a server-side pipe. // Stream split-session: GET /session (no token) => token + start tunnel on a server-side pipe.
if token == "" && path == "/session" { if token == "" && path == "/session" {
return s.authorizeSession(rawConn) return s.sessionAuthorize(rawConn)
} }
// Stream split-session: GET /stream?token=... => downlink poll. // Stream split-session: GET /stream?token=... => downlink poll.
if token != "" && path == "/stream" { if token != "" && path == "/stream" {
if s.passThroughOnReject && !s.sessionHas(token) {
return rejectOrReply(http.StatusNotFound, "not found")
}
return s.streamPull(rawConn, token) return s.streamPull(rawConn, token)
} }
return rejectOrReply(http.StatusBadRequest, "bad request") return rejectOrReply(http.StatusBadRequest, "bad request")
@@ -1721,13 +1724,26 @@ func (s *TunnelServer) handleStream(rawConn net.Conn, req *httpRequestHeader, he
case http.MethodPost: case http.MethodPost:
// Stream split-session: POST /api/v1/upload?token=... => uplink push. // Stream split-session: POST /api/v1/upload?token=... => uplink push.
if token != "" && path == "/api/v1/upload" { if token != "" && path == "/api/v1/upload" {
if s.passThroughOnReject && !s.sessionHas(token) {
return rejectOrReply(http.StatusNotFound, "not found")
}
if closeFlag { if closeFlag {
s.closeSession(token) s.sessionClose(token)
return rejectOrReply(http.StatusOK, "") _ = writeSimpleHTTPResponse(rawConn, http.StatusOK, "")
_ = rawConn.Close()
return HandleDone, nil, nil
}
if finFlag {
s.sessionCloseWrite(token)
_ = writeSimpleHTTPResponse(rawConn, http.StatusOK, "")
_ = rawConn.Close()
return HandleDone, nil, nil
} }
bodyReader, err := newRequestBodyReader(newPreBufferedConn(rawConn, buffered), req.headers) bodyReader, err := newRequestBodyReader(newPreBufferedConn(rawConn, buffered), req.headers)
if err != nil { if err != nil {
return rejectOrReply(http.StatusBadRequest, "bad request") _ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request")
_ = rawConn.Close()
return HandleDone, nil, nil
} }
return s.streamPush(rawConn, token, bodyReader) return s.streamPush(rawConn, token, bodyReader)
} }
@@ -1825,7 +1841,7 @@ func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, head
prefix := make([]byte, 0, len(headerBytes)+len(buffered)) prefix := make([]byte, 0, len(headerBytes)+len(buffered))
prefix = append(prefix, headerBytes...) prefix = append(prefix, headerBytes...)
prefix = append(prefix, buffered...) prefix = append(prefix, buffered...)
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil return HandlePassThrough, newRejectedPreBufferedConn(rawConn, prefix), nil
} }
_ = writeSimpleHTTPResponse(rawConn, code, body) _ = writeSimpleHTTPResponse(rawConn, code, body)
_ = rawConn.Close() _ = rawConn.Close()
@@ -1841,18 +1857,26 @@ func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, head
if !ok || !s.isAllowedBasePath(path) { if !ok || !s.isAllowedBasePath(path) {
return rejectOrReply(http.StatusNotFound, "not found") return rejectOrReply(http.StatusNotFound, "not found")
} }
if !s.auth.verify(req.headers, TunnelModePoll, req.method, path, time.Now()) { authVal := req.headers["authorization"]
if authVal == "" {
authVal = u.Query().Get(tunnelAuthQueryKey)
}
if !s.auth.verifyValue(authVal, TunnelModePoll, req.method, path, time.Now()) {
return rejectOrReply(http.StatusNotFound, "not found") return rejectOrReply(http.StatusNotFound, "not found")
} }
token := u.Query().Get("token") token := u.Query().Get("token")
closeFlag := u.Query().Get("close") == "1" closeFlag := u.Query().Get("close") == "1"
finFlag := u.Query().Get("fin") == "1"
switch strings.ToUpper(req.method) { switch strings.ToUpper(req.method) {
case http.MethodGet: case http.MethodGet:
if token == "" && path == "/session" { if token == "" && path == "/session" {
return s.authorizeSession(rawConn) return s.sessionAuthorize(rawConn)
} }
if token != "" && path == "/stream" { if token != "" && path == "/stream" {
if s.passThroughOnReject && !s.sessionHas(token) {
return rejectOrReply(http.StatusNotFound, "not found")
}
return s.pollPull(rawConn, token) return s.pollPull(rawConn, token)
} }
return rejectOrReply(http.StatusBadRequest, "bad request") return rejectOrReply(http.StatusBadRequest, "bad request")
@@ -1860,13 +1884,26 @@ func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, head
if token == "" || path != "/api/v1/upload" { if token == "" || path != "/api/v1/upload" {
return rejectOrReply(http.StatusBadRequest, "bad request") return rejectOrReply(http.StatusBadRequest, "bad request")
} }
if s.passThroughOnReject && !s.sessionHas(token) {
return rejectOrReply(http.StatusNotFound, "not found")
}
if closeFlag { if closeFlag {
s.closeSession(token) s.sessionClose(token)
return rejectOrReply(http.StatusOK, "") _ = writeSimpleHTTPResponse(rawConn, http.StatusOK, "")
_ = rawConn.Close()
return HandleDone, nil, nil
}
if finFlag {
s.sessionCloseWrite(token)
_ = writeSimpleHTTPResponse(rawConn, http.StatusOK, "")
_ = rawConn.Close()
return HandleDone, nil, nil
} }
bodyReader, err := newRequestBodyReader(newPreBufferedConn(rawConn, buffered), req.headers) bodyReader, err := newRequestBodyReader(newPreBufferedConn(rawConn, buffered), req.headers)
if err != nil { if err != nil {
return rejectOrReply(http.StatusBadRequest, "bad request") _ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request")
_ = rawConn.Close()
return HandleDone, nil, nil
} }
return s.pollPush(rawConn, token, bodyReader) return s.pollPush(rawConn, token, bodyReader)
default: default:
@@ -1874,7 +1911,7 @@ func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, head
} }
} }
func (s *TunnelServer) authorizeSession(rawConn net.Conn) (HandleResult, net.Conn, error) { func (s *TunnelServer) sessionAuthorize(rawConn net.Conn) (HandleResult, net.Conn, error) {
token, err := newSessionToken() token, err := newSessionToken()
if err != nil { if err != nil {
_ = writeSimpleHTTPResponse(rawConn, http.StatusInternalServerError, "internal error") _ = writeSimpleHTTPResponse(rawConn, http.StatusInternalServerError, "internal error")
@@ -1882,13 +1919,13 @@ func (s *TunnelServer) authorizeSession(rawConn net.Conn) (HandleResult, net.Con
return HandleDone, nil, nil return HandleDone, nil, nil
} }
c1, c2 := net.Pipe() c1, c2 := newHalfPipe()
s.mu.Lock() s.mu.Lock()
s.sessions[token] = &tunnelSession{conn: c2, lastActive: time.Now()} s.sessions[token] = &tunnelSession{conn: c2, lastActive: time.Now()}
s.mu.Unlock() s.mu.Unlock()
go s.reapSessionLater(token) go s.reapLater(token)
_ = writeTokenHTTPResponse(rawConn, token) _ = writeTokenHTTPResponse(rawConn, token)
_ = rawConn.Close() _ = rawConn.Close()
@@ -1903,31 +1940,50 @@ func newSessionToken() (string, error) {
return base64.RawURLEncoding.EncodeToString(b[:]), nil return base64.RawURLEncoding.EncodeToString(b[:]), nil
} }
func (s *TunnelServer) reapSessionLater(token string) { func (s *TunnelServer) reapLater(token string) {
ttl := s.sessionTTL ttl := s.sessionTTL
if ttl <= 0 { if ttl <= 0 {
return return
} }
timer := time.NewTimer(ttl) timer := time.NewTimer(ttl)
defer timer.Stop() defer timer.Stop()
<-timer.C
s.mu.Lock() for {
sess, ok := s.sessions[token] <-timer.C
if !ok {
s.mu.Lock()
sess, ok := s.sessions[token]
if !ok {
s.mu.Unlock()
return
}
idle := time.Since(sess.lastActive)
if idle >= ttl {
delete(s.sessions, token)
s.mu.Unlock()
_ = sess.conn.Close()
return
}
next := ttl - idle
s.mu.Unlock() s.mu.Unlock()
return
// Avoid a tight loop under high-frequency activity; we only need best-effort cleanup.
if next < 50*time.Millisecond {
next = 50 * time.Millisecond
}
timer.Reset(next)
} }
if time.Since(sess.lastActive) < ttl {
s.mu.Unlock()
return
}
delete(s.sessions, token)
s.mu.Unlock()
_ = sess.conn.Close()
} }
func (s *TunnelServer) getSession(token string) (*tunnelSession, bool) { func (s *TunnelServer) sessionHas(token string) bool {
s.mu.Lock()
_, ok := s.sessions[token]
s.mu.Unlock()
return ok
}
func (s *TunnelServer) sessionGet(token string) (*tunnelSession, bool) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
sess, ok := s.sessions[token] sess, ok := s.sessions[token]
@@ -1938,7 +1994,7 @@ func (s *TunnelServer) getSession(token string) (*tunnelSession, bool) {
return sess, true return sess, true
} }
func (s *TunnelServer) closeSession(token string) { func (s *TunnelServer) sessionClose(token string) {
s.mu.Lock() s.mu.Lock()
sess, ok := s.sessions[token] sess, ok := s.sessions[token]
if ok { if ok {
@@ -1950,8 +2006,20 @@ func (s *TunnelServer) closeSession(token string) {
} }
} }
func (s *TunnelServer) sessionCloseWrite(token string) {
sess, ok := s.sessionGet(token)
if !ok || sess == nil || sess.conn == nil {
return
}
if cw, ok := sess.conn.(interface{ CloseWrite() error }); ok {
_ = cw.CloseWrite()
return
}
_ = sess.conn.Close()
}
func (s *TunnelServer) pollPush(rawConn net.Conn, token string, body io.Reader) (HandleResult, net.Conn, error) { func (s *TunnelServer) pollPush(rawConn net.Conn, token string, body io.Reader) (HandleResult, net.Conn, error) {
sess, ok := s.getSession(token) sess, ok := s.sessionGet(token)
if !ok { if !ok {
_ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden") _ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden")
_ = rawConn.Close() _ = rawConn.Close()
@@ -1985,7 +2053,7 @@ func (s *TunnelServer) pollPush(rawConn net.Conn, token string, body io.Reader)
_, werr := sess.conn.Write(decoded[:n]) _, werr := sess.conn.Write(decoded[:n])
_ = sess.conn.SetWriteDeadline(time.Time{}) _ = sess.conn.SetWriteDeadline(time.Time{})
if werr != nil { if werr != nil {
s.closeSession(token) s.sessionClose(token)
_ = writeSimpleHTTPResponse(rawConn, http.StatusGone, "gone") _ = writeSimpleHTTPResponse(rawConn, http.StatusGone, "gone")
_ = rawConn.Close() _ = rawConn.Close()
return HandleDone, nil, nil return HandleDone, nil, nil
@@ -1998,7 +2066,7 @@ func (s *TunnelServer) pollPush(rawConn net.Conn, token string, body io.Reader)
} }
func (s *TunnelServer) streamPush(rawConn net.Conn, token string, body io.Reader) (HandleResult, net.Conn, error) { func (s *TunnelServer) streamPush(rawConn net.Conn, token string, body io.Reader) (HandleResult, net.Conn, error) {
sess, ok := s.getSession(token) sess, ok := s.sessionGet(token)
if !ok { if !ok {
_ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden") _ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden")
_ = rawConn.Close() _ = rawConn.Close()
@@ -2023,7 +2091,7 @@ func (s *TunnelServer) streamPush(rawConn net.Conn, token string, body io.Reader
_, werr := sess.conn.Write(payload) _, werr := sess.conn.Write(payload)
_ = sess.conn.SetWriteDeadline(time.Time{}) _ = sess.conn.SetWriteDeadline(time.Time{})
if werr != nil { if werr != nil {
s.closeSession(token) s.sessionClose(token)
_ = writeSimpleHTTPResponse(rawConn, http.StatusGone, "gone") _ = writeSimpleHTTPResponse(rawConn, http.StatusGone, "gone")
_ = rawConn.Close() _ = rawConn.Close()
return HandleDone, nil, nil return HandleDone, nil, nil
@@ -2036,7 +2104,7 @@ func (s *TunnelServer) streamPush(rawConn net.Conn, token string, body io.Reader
} }
func (s *TunnelServer) streamPull(rawConn net.Conn, token string) (HandleResult, net.Conn, error) { func (s *TunnelServer) streamPull(rawConn net.Conn, token string) (HandleResult, net.Conn, error) {
sess, ok := s.getSession(token) sess, ok := s.sessionGet(token)
if !ok { if !ok {
_ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden") _ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden")
_ = rawConn.Close() _ = rawConn.Close()
@@ -2074,14 +2142,14 @@ func (s *TunnelServer) streamPull(rawConn net.Conn, token string) (HandleResult,
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, net.ErrClosed) { if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, net.ErrClosed) {
return HandleDone, nil, nil return HandleDone, nil, nil
} }
s.closeSession(token) s.sessionClose(token)
return HandleDone, nil, nil return HandleDone, nil, nil
} }
} }
} }
func (s *TunnelServer) pollPull(rawConn net.Conn, token string) (HandleResult, net.Conn, error) { func (s *TunnelServer) pollPull(rawConn net.Conn, token string) (HandleResult, net.Conn, error) {
sess, ok := s.getSession(token) sess, ok := s.sessionGet(token)
if !ok { if !ok {
_ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden") _ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden")
_ = rawConn.Close() _ = rawConn.Close()
@@ -2123,7 +2191,7 @@ func (s *TunnelServer) pollPull(rawConn net.Conn, token string) (HandleResult, n
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, net.ErrClosed) { if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, net.ErrClosed) {
return HandleDone, nil, nil return HandleDone, nil, nil
} }
s.closeSession(token) s.sessionClose(token)
return HandleDone, nil, nil return HandleDone, nil, nil
} }
} }

View File

@@ -52,8 +52,28 @@ type Conn struct {
pendingData []byte pendingData []byte
hintBuf []byte hintBuf []byte
rng *rand.Rand rng *rand.Rand
paddingRate float32 paddingThreshold uint64
}
func (sc *Conn) CloseWrite() error {
if sc == nil || sc.Conn == nil {
return nil
}
if cw, ok := sc.Conn.(interface{ CloseWrite() error }); ok {
return cw.CloseWrite()
}
return nil
}
func (sc *Conn) CloseRead() error {
if sc == nil || sc.Conn == nil {
return nil
}
if cr, ok := sc.Conn.(interface{ CloseRead() error }); ok {
return cr.CloseRead()
}
return nil
} }
func NewConn(c net.Conn, table *Table, pMin, pMax int, record bool) *Conn { func NewConn(c net.Conn, table *Table, pMin, pMax int, record bool) *Conn {
@@ -64,19 +84,15 @@ func NewConn(c net.Conn, table *Table, pMin, pMax int, record bool) *Conn {
seed := int64(binary.BigEndian.Uint64(seedBytes[:])) seed := int64(binary.BigEndian.Uint64(seedBytes[:]))
localRng := rand.New(rand.NewSource(seed)) localRng := rand.New(rand.NewSource(seed))
min := float32(pMin) / 100.0
rng := float32(pMax-pMin) / 100.0
rate := min + localRng.Float32()*rng
sc := &Conn{ sc := &Conn{
Conn: c, Conn: c,
table: table, table: table,
reader: bufio.NewReaderSize(c, IOBufferSize), reader: bufio.NewReaderSize(c, IOBufferSize),
rawBuf: make([]byte, IOBufferSize), rawBuf: make([]byte, IOBufferSize),
pendingData: make([]byte, 0, 4096), pendingData: make([]byte, 0, 4096),
hintBuf: make([]byte, 0, 4), hintBuf: make([]byte, 0, 4),
rng: localRng, rng: localRng,
paddingRate: rate, paddingThreshold: pickPaddingThreshold(localRng, pMin, pMax),
} }
if record { if record {
sc.recorder = new(bytes.Buffer) sc.recorder = new(bytes.Buffer)
@@ -127,7 +143,7 @@ func (sc *Conn) Write(p []byte) (n int, err error) {
padLen := len(pads) padLen := len(pads)
for _, b := range p { for _, b := range p {
if sc.rng.Float32() < sc.paddingRate { if shouldPad(sc.rng, sc.paddingThreshold) {
out = append(out, pads[sc.rng.Intn(padLen)]) out = append(out, pads[sc.rng.Intn(padLen)])
} }
@@ -136,14 +152,14 @@ func (sc *Conn) Write(p []byte) (n int, err error) {
perm := perm4[sc.rng.Intn(len(perm4))] perm := perm4[sc.rng.Intn(len(perm4))]
for _, idx := range perm { for _, idx := range perm {
if sc.rng.Float32() < sc.paddingRate { if shouldPad(sc.rng, sc.paddingThreshold) {
out = append(out, pads[sc.rng.Intn(padLen)]) out = append(out, pads[sc.rng.Intn(padLen)])
} }
out = append(out, puzzle[idx]) out = append(out, puzzle[idx])
} }
} }
if sc.rng.Float32() < sc.paddingRate { if shouldPad(sc.rng, sc.paddingThreshold) {
out = append(out, pads[sc.rng.Intn(padLen)]) out = append(out, pads[sc.rng.Intn(padLen)])
} }

View File

@@ -16,7 +16,7 @@ const (
) )
// 1. 使用 12字节->16组 的块处理优化 Write (减少循环开销) // 1. 使用 12字节->16组 的块处理优化 Write (减少循环开销)
// 2. 使用浮点随机概率判断 Padding与纯 Sudoku 保持流量特征一致 // 2. 使用整数阈值随机概率判断 Padding与纯 Sudoku 保持流量特征一致
// 3. Read 使用 copy 移动避免底层数组泄漏 // 3. Read 使用 copy 移动避免底层数组泄漏
type PackedConn struct { type PackedConn struct {
net.Conn net.Conn
@@ -37,11 +37,31 @@ type PackedConn struct {
readBitBuf uint64 readBitBuf uint64
readBits int readBits int
// 随机数与填充控制 - 使用浮点随机,与 Conn 一致 // 随机数与填充控制 - 使用整数阈值随机,与 Conn 一致
rng *rand.Rand rng *rand.Rand
paddingRate float32 // 与 Conn 保持一致的随机概率模型 paddingThreshold uint64 // 与 Conn 保持一致的随机概率模型
padMarker byte padMarker byte
padPool []byte padPool []byte
}
func (pc *PackedConn) CloseWrite() error {
if pc == nil || pc.Conn == nil {
return nil
}
if cw, ok := pc.Conn.(interface{ CloseWrite() error }); ok {
return cw.CloseWrite()
}
return nil
}
func (pc *PackedConn) CloseRead() error {
if pc == nil || pc.Conn == nil {
return nil
}
if cr, ok := pc.Conn.(interface{ CloseRead() error }); ok {
return cr.CloseRead()
}
return nil
} }
func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn { func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn {
@@ -52,20 +72,15 @@ func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn {
seed := int64(binary.BigEndian.Uint64(seedBytes[:])) seed := int64(binary.BigEndian.Uint64(seedBytes[:]))
localRng := rand.New(rand.NewSource(seed)) localRng := rand.New(rand.NewSource(seed))
// 与 Conn 保持一致的 padding 概率计算
min := float32(pMin) / 100.0
rng := float32(pMax-pMin) / 100.0
rate := min + localRng.Float32()*rng
pc := &PackedConn{ pc := &PackedConn{
Conn: c, Conn: c,
table: table, table: table,
reader: bufio.NewReaderSize(c, IOBufferSize), reader: bufio.NewReaderSize(c, IOBufferSize),
rawBuf: make([]byte, IOBufferSize), rawBuf: make([]byte, IOBufferSize),
pendingData: make([]byte, 0, 4096), pendingData: make([]byte, 0, 4096),
writeBuf: make([]byte, 0, 4096), writeBuf: make([]byte, 0, 4096),
rng: localRng, rng: localRng,
paddingRate: rate, paddingThreshold: pickPaddingThreshold(localRng, pMin, pMax),
} }
pc.padMarker = table.layout.padMarker pc.padMarker = table.layout.padMarker
@@ -80,9 +95,9 @@ func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn {
return pc return pc
} }
// maybeAddPadding 内联辅助:根据浮点概率插入 padding // maybeAddPadding 内联辅助:根据概率阈值插入 padding
func (pc *PackedConn) maybeAddPadding(out []byte) []byte { func (pc *PackedConn) maybeAddPadding(out []byte) []byte {
if pc.rng.Float32() < pc.paddingRate { if shouldPad(pc.rng, pc.paddingThreshold) {
out = append(out, pc.getPaddingByte()) out = append(out, pc.getPaddingByte())
} }
return out return out

View File

@@ -0,0 +1,44 @@
package sudoku
import "math/rand"
const probOne = uint64(1) << 32
func pickPaddingThreshold(r *rand.Rand, pMin, pMax int) uint64 {
if r == nil {
return 0
}
if pMin < 0 {
pMin = 0
}
if pMax < pMin {
pMax = pMin
}
if pMax > 100 {
pMax = 100
}
if pMin > 100 {
pMin = 100
}
min := uint64(pMin) * probOne / 100
max := uint64(pMax) * probOne / 100
if max <= min {
return min
}
u := uint64(r.Uint32())
return min + (u * (max - min) >> 32)
}
func shouldPad(r *rand.Rand, threshold uint64) bool {
if threshold == 0 {
return false
}
if threshold >= probOne {
return true
}
if r == nil {
return false
}
return uint64(r.Uint32()) < threshold
}