diff --git a/adapter/outbound/sudoku.go b/adapter/outbound/sudoku.go index cdf7647d..1f6e9781 100644 --- a/adapter/outbound/sudoku.go +++ b/adapter/outbound/sudoku.go @@ -7,7 +7,6 @@ import ( "strconv" "strings" "sync" - "time" N "github.com/metacubex/mihomo/common/net" C "github.com/metacubex/mihomo/constant" @@ -22,10 +21,8 @@ type Sudoku struct { httpMaskMu sync.Mutex httpMaskClient *sudoku.HTTPMaskTunnelClient - muxMu sync.Mutex - muxClient *sudoku.MultiplexClient - muxBackoffUntil time.Time - muxLastErr error + muxMu sync.Mutex + muxClient *sudoku.MultiplexClient } type SudokuOption struct { @@ -58,7 +55,7 @@ func (s *Sudoku) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Con muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex) if muxMode == "on" && !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) { - stream, muxErr := s.dialMultiplex(ctx, cfg.TargetAddress, muxMode) + stream, muxErr := s.dialMultiplex(ctx, cfg.TargetAddress) if muxErr == nil { return NewConn(stream, s), nil } @@ -312,9 +309,9 @@ func (s *Sudoku) dialAndHandshake(ctx context.Context, cfg *sudoku.ProtocolConfi return c, nil } -func (s *Sudoku) dialMultiplex(ctx context.Context, targetAddress string, mode string) (net.Conn, error) { +func (s *Sudoku) dialMultiplex(ctx context.Context, targetAddress string) (net.Conn, error) { for attempt := 0; attempt < 2; attempt++ { - client, err := s.getOrCreateMuxClient(ctx, mode) + client, err := s.getOrCreateMuxClient(ctx) if err != nil { return nil, err } @@ -330,21 +327,11 @@ func (s *Sudoku) dialMultiplex(ctx context.Context, targetAddress string, mode s return nil, fmt.Errorf("multiplex open stream failed") } -func (s *Sudoku) getOrCreateMuxClient(ctx context.Context, mode string) (*sudoku.MultiplexClient, error) { +func (s *Sudoku) getOrCreateMuxClient(ctx context.Context) (*sudoku.MultiplexClient, error) { if s == nil { return nil, fmt.Errorf("nil adapter") } - if mode == "auto" { - s.muxMu.Lock() - backoffUntil := s.muxBackoffUntil - lastErr := s.muxLastErr - s.muxMu.Unlock() - if time.Now().Before(backoffUntil) { - return nil, fmt.Errorf("multiplex temporarily disabled: %v", lastErr) - } - } - s.muxMu.Lock() if s.muxClient != nil && !s.muxClient.IsClosed() { client := s.muxClient @@ -363,20 +350,12 @@ func (s *Sudoku) getOrCreateMuxClient(ctx context.Context, mode string) (*sudoku baseCfg := s.baseConf baseConn, err := s.dialAndHandshake(ctx, &baseCfg) if err != nil { - if mode == "auto" { - s.muxLastErr = err - s.muxBackoffUntil = time.Now().Add(45 * time.Second) - } return nil, err } client, err := sudoku.StartMultiplexClient(baseConn) if err != nil { _ = baseConn.Close() - if mode == "auto" { - s.muxLastErr = err - s.muxBackoffUntil = time.Now().Add(45 * time.Second) - } return nil, err } @@ -384,16 +363,6 @@ func (s *Sudoku) getOrCreateMuxClient(ctx context.Context, mode string) (*sudoku return client, nil } -func (s *Sudoku) noteMuxFailure(mode string, err error) { - if mode != "auto" { - return - } - s.muxMu.Lock() - s.muxLastErr = err - s.muxBackoffUntil = time.Now().Add(45 * time.Second) - s.muxMu.Unlock() -} - func (s *Sudoku) resetMuxClient() { s.muxMu.Lock() defer s.muxMu.Unlock() diff --git a/docs/config.yaml b/docs/config.yaml index 8746ae20..aae42c35 100644 --- a/docs/config.yaml +++ b/docs/config.yaml @@ -1082,19 +1082,19 @@ proxies: # socks5 server: server_ip/domain # 1.2.3.4 or domain port: 443 key: "" # 如果你使用sudoku生成的ED25519密钥对,请填写密钥对中的私钥,否则填入和服务端相同的uuid - aead-method: chacha20-poly1305 # 可选值:chacha20-poly1305、aes-128-gcm、none 我们保证在none的情况下sudoku混淆层仍然确保安全 - padding-min: 2 # 最小填充字节数 - padding-max: 7 # 最大填充字节数 + aead-method: chacha20-poly1305 # 可选:chacha20-poly1305、aes-128-gcm、none(不建议;且 enable-pure-downlink=false 时不可用) + padding-min: 2 # 最小填充率(0-100) + padding-max: 7 # 最大填充率(0-100,必须 >= padding-min) table-type: prefer_ascii # 可选值:prefer_ascii、prefer_entropy 前者全ascii映射,后者保证熵值(汉明1)低于3 # custom-table: xpxvvpvv # 可选,自定义字节布局,必须包含2个x、2个p、4个v,可随意组合。启用此处则需配置`table-type`为`prefer_entropy` # custom-tables: ["xpxvvpvv", "vxpvxvvp"] # 可选,自定义字节布局列表(x/v/p),用于 xvp 模式轮换;非空时覆盖 custom-table http-mask: true # 是否启用http掩码 - # http-mask-mode: legacy # 可选:legacy(默认)、stream、poll、auto;stream/poll/auto 支持走 CDN/反代 + # http-mask-mode: legacy # 可选:legacy(默认)、stream(split-stream)、poll、auto(先 stream 再 poll);stream/poll/auto 支持走 CDN/反代 # http-mask-tls: true # 可选:仅在 http-mask-mode 为 stream/poll/auto 时生效;true 强制 https;false 强制 http(不会根据端口自动推断) # http-mask-host: "" # 可选:覆盖 Host/SNI(支持 example.com 或 example.com:443);仅在 http-mask-mode 为 stream/poll/auto 时生效 - # path-root: "" # 可选:HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload - # http-mask-multiplex: off # 可选:off(默认)、auto(复用 h1.1 keep-alive / h2 连接,减少每次建链 RTT)、on(单条隧道内多路复用多个目标连接;仅在 http-mask-mode=stream/poll/auto 生效) - enable-pure-downlink: false # 是否启用混淆下行,false的情况下能在保证数据安全的前提下极大提升下行速度,与服务端端保持相同(如果此处为false,则要求aead不可为none) + # path-root: "" # 可选:HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" 或 "/aabbcc/" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload + # http-mask-multiplex: off # 可选:off(默认)、auto(复用底层 HTTP 连接,减少建链 RTT)、on(Sudoku mux 单隧道多目标;仅在 http-mask-mode=stream/poll/auto 生效) + enable-pure-downlink: false # 可选:false=带宽优化下行(更快,要求 aead-method != none);true=纯 Sudoku 下行 # anytls - name: anytls @@ -1632,17 +1632,17 @@ listeners: port: 8443 # 仅支持单端口 listen: 0.0.0.0 key: "" # 如果你使用sudoku生成的ED25519密钥对,此处是密钥对中的公钥,当然,你也可以仅仅使用任意uuid充当key - aead-method: chacha20-poly1305 # 支持chacha20-poly1305或者aes-128-gcm以及none,sudoku的混淆层可以确保none情况下数据安全 - padding-min: 1 # 填充最小长度 - padding-max: 15 # 填充最大长度,均不建议过大 + aead-method: chacha20-poly1305 # 可选:chacha20-poly1305、aes-128-gcm、none(不建议;且 enable-pure-downlink=false 时不可用) + padding-min: 1 # 最小填充率(0-100) + padding-max: 15 # 最大填充率(0-100,必须 >= padding-min) table-type: prefer_ascii # 可选值:prefer_ascii、prefer_entropy 前者全ascii映射,后者保证熵值(汉明1)低于3 # custom-table: xpxvvpvv # 可选,自定义字节布局,必须包含2个x、2个p、4个v,可随意组合。启用此处则需配置`table-type`为`prefer_entropy` # custom-tables: ["xpxvvpvv", "vxpvxvvp"] # 可选,自定义字节布局列表(x/v/p),用于 xvp 模式轮换;非空时覆盖 custom-table - handshake-timeout: 5 # optional - enable-pure-downlink: false # 是否启用混淆下行,false的情况下能在保证数据安全的前提下极大提升下行速度,与客户端保持相同(如果此处为false,则要求aead不可为none) + handshake-timeout: 5 # 可选(秒) + enable-pure-downlink: false # 可选:false=带宽优化下行(更快,要求 aead-method != none);true=纯 Sudoku 下行 disable-http-mask: false # 可选:禁用 http 掩码/隧道(默认为 false) - # http-mask-mode: legacy # 可选:legacy(默认)、stream、poll、auto;stream/poll/auto 支持走 CDN/反代 - # path-root: "" # 可选:HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload + # http-mask-mode: legacy # 可选:legacy(默认)、stream(split-stream)、poll、auto(先 stream 再 poll);stream/poll/auto 支持走 CDN/反代 + # path-root: "" # 可选:HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" 或 "/aabbcc/" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload diff --git a/transport/sudoku/config.go b/transport/sudoku/config.go index d13eab43..6a1a465a 100644 --- a/transport/sudoku/config.go +++ b/transport/sudoku/config.go @@ -32,7 +32,8 @@ type ProtocolConfig struct { PaddingMin int PaddingMax int - // EnablePureDownlink toggles the bandwidth-optimized downlink mode. + // EnablePureDownlink enables the pure Sudoku downlink mode. + // When false, the connection uses the bandwidth-optimized packed downlink (requires AEAD). EnablePureDownlink bool // Client-only: final target "host:port". @@ -46,7 +47,7 @@ type ProtocolConfig struct { // HTTPMaskMode controls how the HTTP layer behaves: // - "legacy": write a fake HTTP/1.1 header then switch to raw stream (default, not CDN-compatible) - // - "stream": real HTTP tunnel (stream-one or split), CDN-compatible + // - "stream": real HTTP tunnel (split-stream), CDN-compatible // - "poll": plain HTTP tunnel (authorize/push/pull), strong restricted-network pass-through // - "auto": try stream then fall back to poll HTTPMaskMode string @@ -114,7 +115,8 @@ func (c *ProtocolConfig) Validate() error { } if v := strings.TrimSpace(c.HTTPMaskPathRoot); v != "" { - if strings.Contains(v, "/") { + v = strings.Trim(v, "/") + if v == "" || strings.Contains(v, "/") { return fmt.Errorf("invalid http-mask-path-root: must be a single path segment") } for i := 0; i < len(v); i++ { diff --git a/transport/sudoku/crypto/aead.go b/transport/sudoku/crypto/aead.go index b5f574d9..368caea3 100644 --- a/transport/sudoku/crypto/aead.go +++ b/transport/sudoku/crypto/aead.go @@ -22,6 +22,26 @@ type AEADConn struct { nonceSize int } +func (cc *AEADConn) CloseWrite() error { + if cc == nil || cc.Conn == nil { + return nil + } + if cw, ok := cc.Conn.(interface{ CloseWrite() error }); ok { + return cw.CloseWrite() + } + return nil +} + +func (cc *AEADConn) CloseRead() error { + if cc == nil || cc.Conn == nil { + return nil + } + if cr, ok := cc.Conn.(interface{ CloseRead() error }); ok { + return cr.CloseRead() + } + return nil +} + func NewAEADConn(c net.Conn, key string, method string) (*AEADConn, error) { if method == "none" { return &AEADConn{Conn: c, aead: nil}, nil diff --git a/transport/sudoku/multiplex/session.go b/transport/sudoku/multiplex/session.go index 64205dec..779fcafb 100644 --- a/transport/sudoku/multiplex/session.go +++ b/transport/sudoku/multiplex/session.go @@ -383,7 +383,11 @@ func (c *stream) enqueue(payload []byte) { c.mu.Unlock() return } - c.queue = append(c.queue, payload) + if len(c.readBuf) == 0 && len(c.queue) == 0 { + c.readBuf = payload + } else { + c.queue = append(c.queue, payload) + } c.cond.Signal() c.mu.Unlock() } @@ -491,6 +495,9 @@ func (c *stream) Close() error { return nil } +func (c *stream) CloseWrite() error { return c.Close() } +func (c *stream) CloseRead() error { return c.Close() } + func (c *stream) LocalAddr() net.Addr { return c.localAddr } func (c *stream) RemoteAddr() net.Addr { return c.remoteAddr } @@ -501,4 +508,3 @@ func (c *stream) SetDeadline(t time.Time) error { } func (c *stream) SetReadDeadline(time.Time) error { return nil } func (c *stream) SetWriteDeadline(time.Time) error { return nil } - diff --git a/transport/sudoku/obfs/httpmask/auth.go b/transport/sudoku/obfs/httpmask/auth.go index 3810cbbf..f59d6958 100644 --- a/transport/sudoku/obfs/httpmask/auth.go +++ b/transport/sudoku/obfs/httpmask/auth.go @@ -6,6 +6,7 @@ import ( "crypto/subtle" "encoding/base64" "encoding/binary" + "github.com/metacubex/http" "strings" "time" ) @@ -13,6 +14,7 @@ import ( const ( tunnelAuthHeaderKey = "Authorization" tunnelAuthHeaderPrefix = "Bearer " + tunnelAuthQueryKey = "auth" ) type tunnelAuth struct { @@ -61,8 +63,15 @@ func (a *tunnelAuth) verify(headers map[string]string, mode TunnelMode, method, if headers == nil { return false } + return a.verifyValue(headers["authorization"], mode, method, path, now) +} - val := strings.TrimSpace(headers["authorization"]) +func (a *tunnelAuth) verifyValue(val string, mode TunnelMode, method, path string, now time.Time) bool { + if a == nil { + return true + } + + val = strings.TrimSpace(val) if val == "" { return false } @@ -121,11 +130,9 @@ func (a *tunnelAuth) sign(mode TunnelMode, method, path string, ts int64) [16]by return out } -type headerSetter interface { - Set(key, value string) -} +type httpHeaderSetter = http.Header -func applyTunnelAuthHeader(h headerSetter, auth *tunnelAuth, mode TunnelMode, method, path string) { +func applyTunnelAuthHeader(h httpHeaderSetter, auth *tunnelAuth, mode TunnelMode, method, path string) { if auth == nil || h == nil { return } @@ -135,3 +142,19 @@ func applyTunnelAuthHeader(h headerSetter, auth *tunnelAuth, mode TunnelMode, me } h.Set(tunnelAuthHeaderKey, tunnelAuthHeaderPrefix+token) } + +func applyTunnelAuth(req *http.Request, auth *tunnelAuth, mode TunnelMode, method, path string) { + if auth == nil || req == nil { + return + } + token := auth.token(mode, method, path, time.Now()) + if token == "" { + return + } + req.Header.Set(tunnelAuthHeaderKey, tunnelAuthHeaderPrefix+token) + if req.URL != nil { + q := req.URL.Query() + q.Set(tunnelAuthQueryKey, token) + req.URL.RawQuery = q.Encode() + } +} diff --git a/transport/sudoku/obfs/httpmask/halfpipe.go b/transport/sudoku/obfs/httpmask/halfpipe.go new file mode 100644 index 00000000..afbe0bcc --- /dev/null +++ b/transport/sudoku/obfs/httpmask/halfpipe.go @@ -0,0 +1,229 @@ +package httpmask + +import ( + "io" + "net" + "os" + "sync" + "time" +) + +type pipeDeadline struct { + mu sync.Mutex + timer *time.Timer + cancel chan struct{} +} + +func makePipeDeadline() pipeDeadline { + return pipeDeadline{cancel: make(chan struct{})} +} + +func (d *pipeDeadline) set(t time.Time) { + d.mu.Lock() + defer d.mu.Unlock() + + if d.timer != nil && !d.timer.Stop() { + <-d.cancel + } + d.timer = nil + + closed := isClosedPipeChan(d.cancel) + if t.IsZero() { + if closed { + d.cancel = make(chan struct{}) + } + return + } + + if dur := time.Until(t); dur > 0 { + if closed { + d.cancel = make(chan struct{}) + } + d.timer = time.AfterFunc(dur, func() { + close(d.cancel) + }) + return + } + + if !closed { + close(d.cancel) + } +} + +func (d *pipeDeadline) wait() <-chan struct{} { + d.mu.Lock() + ch := d.cancel + d.mu.Unlock() + return ch +} + +func isClosedPipeChan(ch <-chan struct{}) bool { + select { + case <-ch: + return true + default: + return false + } +} + +type halfPipeAddr struct{} + +func (halfPipeAddr) Network() string { return "pipe" } +func (halfPipeAddr) String() string { return "pipe" } + +type halfPipeConn struct { + wrMu sync.Mutex + + rdRx <-chan []byte + rdTx chan<- int + + wrTx chan<- []byte + wrRx <-chan int + + readOnce sync.Once + writeOnce sync.Once + + localReadDone chan struct{} + localWriteDone chan struct{} + + remoteReadDone <-chan struct{} + remoteWriteDone <-chan struct{} + + readDeadline pipeDeadline + writeDeadline pipeDeadline +} + +func newHalfPipe() (net.Conn, net.Conn) { + cb1 := make(chan []byte) + cb2 := make(chan []byte) + cn1 := make(chan int) + cn2 := make(chan int) + + r1 := make(chan struct{}) + w1 := make(chan struct{}) + r2 := make(chan struct{}) + w2 := make(chan struct{}) + + c1 := &halfPipeConn{ + rdRx: cb1, + rdTx: cn1, + wrTx: cb2, + wrRx: cn2, + + localReadDone: r1, + localWriteDone: w1, + remoteReadDone: r2, + remoteWriteDone: w2, + + readDeadline: makePipeDeadline(), + writeDeadline: makePipeDeadline(), + } + c2 := &halfPipeConn{ + rdRx: cb2, + rdTx: cn2, + wrTx: cb1, + wrRx: cn1, + + localReadDone: r2, + localWriteDone: w2, + remoteReadDone: r1, + remoteWriteDone: w1, + + readDeadline: makePipeDeadline(), + writeDeadline: makePipeDeadline(), + } + return c1, c2 +} + +func (*halfPipeConn) LocalAddr() net.Addr { return halfPipeAddr{} } +func (*halfPipeConn) RemoteAddr() net.Addr { return halfPipeAddr{} } + +func (c *halfPipeConn) Read(p []byte) (int, error) { + switch { + case isClosedPipeChan(c.localReadDone): + return 0, io.ErrClosedPipe + case isClosedPipeChan(c.remoteWriteDone): + return 0, io.EOF + case isClosedPipeChan(c.readDeadline.wait()): + return 0, os.ErrDeadlineExceeded + } + + select { + case b := <-c.rdRx: + n := copy(p, b) + c.rdTx <- n + return n, nil + case <-c.localReadDone: + return 0, io.ErrClosedPipe + case <-c.remoteWriteDone: + return 0, io.EOF + case <-c.readDeadline.wait(): + return 0, os.ErrDeadlineExceeded + } +} + +func (c *halfPipeConn) Write(p []byte) (int, error) { + switch { + case isClosedPipeChan(c.localWriteDone): + return 0, io.ErrClosedPipe + case isClosedPipeChan(c.remoteReadDone): + return 0, io.ErrClosedPipe + case isClosedPipeChan(c.writeDeadline.wait()): + return 0, os.ErrDeadlineExceeded + } + + c.wrMu.Lock() + defer c.wrMu.Unlock() + + var ( + total int + rest = p + ) + for once := true; once || len(rest) > 0; once = false { + select { + case c.wrTx <- rest: + n := <-c.wrRx + rest = rest[n:] + total += n + case <-c.localWriteDone: + return total, io.ErrClosedPipe + case <-c.remoteReadDone: + return total, io.ErrClosedPipe + case <-c.writeDeadline.wait(): + return total, os.ErrDeadlineExceeded + } + } + return total, nil +} + +func (c *halfPipeConn) CloseWrite() error { + c.writeOnce.Do(func() { close(c.localWriteDone) }) + return nil +} + +func (c *halfPipeConn) CloseRead() error { + c.readOnce.Do(func() { close(c.localReadDone) }) + return nil +} + +func (c *halfPipeConn) Close() error { + _ = c.CloseRead() + _ = c.CloseWrite() + return nil +} + +func (c *halfPipeConn) SetDeadline(t time.Time) error { + c.readDeadline.set(t) + c.writeDeadline.set(t) + return nil +} + +func (c *halfPipeConn) SetReadDeadline(t time.Time) error { + c.readDeadline.set(t) + return nil +} + +func (c *halfPipeConn) SetWriteDeadline(t time.Time) error { + c.writeDeadline.set(t) + return nil +} diff --git a/transport/sudoku/obfs/httpmask/tunnel.go b/transport/sudoku/obfs/httpmask/tunnel.go index 1d8fe905..7f9f786e 100644 --- a/transport/sudoku/obfs/httpmask/tunnel.go +++ b/transport/sudoku/obfs/httpmask/tunnel.go @@ -241,38 +241,6 @@ func parseTunnelToken(body []byte) (string, error) { return token, nil } -type httpStreamConn struct { - reader io.ReadCloser - writer *io.PipeWriter - cancel context.CancelFunc - - localAddr net.Addr - remoteAddr net.Addr -} - -func (c *httpStreamConn) Read(p []byte) (int, error) { return c.reader.Read(p) } -func (c *httpStreamConn) Write(p []byte) (int, error) { return c.writer.Write(p) } - -func (c *httpStreamConn) Close() error { - if c.cancel != nil { - c.cancel() - } - if c.writer != nil { - _ = c.writer.CloseWithError(io.ErrClosedPipe) - } - if c.reader != nil { - return c.reader.Close() - } - return nil -} - -func (c *httpStreamConn) LocalAddr() net.Addr { return c.localAddr } -func (c *httpStreamConn) RemoteAddr() net.Addr { return c.remoteAddr } - -func (c *httpStreamConn) SetDeadline(time.Time) error { return nil } -func (c *httpStreamConn) SetReadDeadline(time.Time) error { return nil } -func (c *httpStreamConn) SetWriteDeadline(time.Time) error { return nil } - type httpClientTarget struct { scheme string urlHost string @@ -332,6 +300,7 @@ type sessionDialInfo struct { client *http.Client pushURL string pullURL string + finURL string closeURL string headerHost string auth *tunnelAuth @@ -350,7 +319,7 @@ func dialSessionWithClient(ctx context.Context, client *http.Client, target http } req.Host = target.headerHost applyTunnelHeaders(req.Header, target.headerHost, mode) - applyTunnelAuthHeader(req.Header, auth, mode, http.MethodGet, "/session") + applyTunnelAuth(req, auth, mode, http.MethodGet, "/session") resp, err := client.Do(req) if err != nil { @@ -375,12 +344,14 @@ func dialSessionWithClient(ctx context.Context, client *http.Client, target http pushURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/api/v1/upload"), RawQuery: "token=" + url.QueryEscape(token)}).String() pullURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/stream"), RawQuery: "token=" + url.QueryEscape(token)}).String() + finURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/api/v1/upload"), RawQuery: "token=" + url.QueryEscape(token) + "&fin=1"}).String() closeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/api/v1/upload"), RawQuery: "token=" + url.QueryEscape(token) + "&close=1"}).String() return &sessionDialInfo{ client: client, pushURL: pushURL, pullURL: pullURL, + finURL: finURL, closeURL: closeURL, headerHost: target.headerHost, auth: auth, @@ -409,7 +380,31 @@ func bestEffortCloseSession(client *http.Client, closeURL, headerHost string, mo } req.Host = headerHost applyTunnelHeaders(req.Header, headerHost, mode) - applyTunnelAuthHeader(req.Header, auth, mode, http.MethodPost, "/api/v1/upload") + applyTunnelAuth(req, auth, mode, http.MethodPost, "/api/v1/upload") + + resp, err := client.Do(req) + if err != nil || resp == nil { + return + } + _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024)) + _ = resp.Body.Close() +} + +func bestEffortCloseWriteSession(client *http.Client, finURL, headerHost string, mode TunnelMode, auth *tunnelAuth) { + if client == nil || finURL == "" || headerHost == "" { + return + } + + closeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(closeCtx, http.MethodPost, finURL, nil) + if err != nil { + return + } + req.Host = headerHost + applyTunnelHeaders(req.Header, headerHost, mode) + applyTunnelAuth(req, auth, mode, http.MethodPost, "/api/v1/upload") resp, err := client.Do(req) if err != nil || resp == nil { @@ -420,160 +415,13 @@ func bestEffortCloseSession(client *http.Client, closeURL, headerHost string, mo } func dialStreamWithClient(ctx context.Context, client *http.Client, target httpClientTarget, opts TunnelDialOptions) (net.Conn, error) { - // Prefer split-session (Cloudflare-friendly). Fall back to stream-one for older servers / environments. - c, errSplit := dialStreamSplitWithClient(ctx, client, target, opts) - if errSplit == nil { - return c, nil - } - c2, errOne := dialStreamOneWithClient(ctx, client, target, opts) - if errOne == nil { - return c2, nil - } - return nil, fmt.Errorf("dial stream failed: split: %v; stream-one: %w", errSplit, errOne) + // "stream" mode uses split-stream to stay CDN-friendly by default. + return dialStreamSplitWithClient(ctx, client, target, opts) } func dialStream(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) { - // Prefer split-session (Cloudflare-friendly). Fall back to stream-one for older servers / environments. - c, errSplit := dialStreamSplit(ctx, serverAddress, opts) - if errSplit == nil { - return c, nil - } - c2, errOne := dialStreamOne(ctx, serverAddress, opts) - if errOne == nil { - return c2, nil - } - return nil, fmt.Errorf("dial stream failed: split: %v; stream-one: %w", errSplit, errOne) -} - -func dialStreamOneWithClient(ctx context.Context, client *http.Client, target httpClientTarget, opts TunnelDialOptions) (net.Conn, error) { - if client == nil { - return nil, fmt.Errorf("nil http client") - } - - auth := newTunnelAuth(opts.AuthKey, 0) - r := rngPool.Get().(*mrand.Rand) - basePath := paths[r.Intn(len(paths))] - path := joinPathRoot(opts.PathRoot, basePath) - ctype := contentTypes[r.Intn(len(contentTypes))] - rngPool.Put(r) - - u := url.URL{ - Scheme: target.scheme, - Host: target.urlHost, - Path: path, - } - - reqBodyR, reqBodyW := io.Pipe() - - connCtx, connCancel := context.WithCancel(context.Background()) - req, err := http.NewRequestWithContext(connCtx, http.MethodPost, u.String(), reqBodyR) - if err != nil { - connCancel() - _ = reqBodyW.Close() - return nil, err - } - req.Host = target.headerHost - - applyTunnelHeaders(req.Header, target.headerHost, TunnelModeStream) - applyTunnelAuthHeader(req.Header, auth, TunnelModeStream, http.MethodPost, basePath) - req.Header.Set("Content-Type", ctype) - - type doResult struct { - resp *http.Response - err error - } - doCh := make(chan doResult, 1) - go func() { - resp, doErr := client.Do(req) - doCh <- doResult{resp: resp, err: doErr} - }() - - streamConn := &httpStreamConn{ - writer: reqBodyW, - cancel: connCancel, - localAddr: &net.TCPAddr{}, - remoteAddr: &net.TCPAddr{}, - } - - type upgradeResult struct { - conn net.Conn - err error - } - upgradeCh := make(chan upgradeResult, 1) - if opts.Upgrade == nil { - upgradeCh <- upgradeResult{conn: streamConn, err: nil} - } else { - go func() { - upgradeConn, err := opts.Upgrade(streamConn) - if err != nil { - upgradeCh <- upgradeResult{conn: nil, err: err} - return - } - if upgradeConn == nil { - upgradeConn = streamConn - } - upgradeCh <- upgradeResult{conn: upgradeConn, err: nil} - }() - } - - var ( - outConn net.Conn - upgradeDone bool - responseReady bool - ) - - for !(upgradeDone && responseReady) { - select { - case <-ctx.Done(): - _ = streamConn.Close() - if outConn != nil && outConn != streamConn { - _ = outConn.Close() - } - return nil, ctx.Err() - - case u := <-upgradeCh: - if u.err != nil { - _ = streamConn.Close() - return nil, u.err - } - outConn = u.conn - if outConn == nil { - outConn = streamConn - } - upgradeDone = true - - case r := <-doCh: - if r.err != nil { - _ = streamConn.Close() - if outConn != nil && outConn != streamConn { - _ = outConn.Close() - } - return nil, r.err - } - if r.resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(io.LimitReader(r.resp.Body, 4*1024)) - _ = r.resp.Body.Close() - _ = streamConn.Close() - if outConn != nil && outConn != streamConn { - _ = outConn.Close() - } - return nil, fmt.Errorf("stream bad status: %s (%s)", r.resp.Status, strings.TrimSpace(string(body))) - } - - streamConn.reader = r.resp.Body - responseReady = true - } - } - - return outConn, nil -} - -func dialStreamOne(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) { - client, target, err := newHTTPClient(serverAddress, opts, 32) - if err != nil { - return nil, err - } - return dialStreamOneWithClient(ctx, client, target, opts) + // "stream" mode uses split-stream to stay CDN-friendly by default. + return dialStreamSplit(ctx, serverAddress, opts) } type queuedConn struct { @@ -581,6 +429,9 @@ type queuedConn struct { closed chan struct{} writeCh chan []byte + // writeClosed is closed by CloseWrite to stop accepting new payloads. + // When closed, Write returns io.ErrClosedPipe, but Read is unaffected. + writeClosed chan struct{} mu sync.Mutex readBuf []byte @@ -589,6 +440,18 @@ type queuedConn struct { remoteAddr net.Addr } +func (c *queuedConn) CloseWrite() error { + if c == nil || c.writeClosed == nil { + return nil + } + c.mu.Lock() + if !isClosedPipeChan(c.writeClosed) { + close(c.writeClosed) + } + c.mu.Unlock() + return nil +} + func (c *queuedConn) closeWithError(err error) error { c.mu.Lock() select { @@ -640,6 +503,9 @@ func (c *queuedConn) Write(b []byte) (n int, err error) { case <-c.closed: c.mu.Unlock() return 0, c.closedErr() + case <-c.writeClosed: + c.mu.Unlock() + return 0, io.ErrClosedPipe default: } c.mu.Unlock() @@ -651,6 +517,8 @@ func (c *queuedConn) Write(b []byte) (n int, err error) { return len(b), nil case <-c.closed: return 0, c.closedErr() + case <-c.writeClosed: + return 0, io.ErrClosedPipe } } @@ -670,6 +538,7 @@ type streamSplitConn struct { client *http.Client pushURL string pullURL string + finURL string closeURL string headerHost string auth *tunnelAuth @@ -697,15 +566,17 @@ func newStreamSplitConnFromInfo(info *sessionDialInfo) *streamSplitConn { client: info.client, pushURL: info.pushURL, pullURL: info.pullURL, + finURL: info.finURL, closeURL: info.closeURL, headerHost: info.headerHost, auth: info.auth, queuedConn: queuedConn{ - rxc: make(chan []byte, 256), - closed: make(chan struct{}), - writeCh: make(chan []byte, 256), - localAddr: &net.TCPAddr{}, - remoteAddr: &net.TCPAddr{}, + rxc: make(chan []byte, 256), + closed: make(chan struct{}), + writeCh: make(chan []byte, 256), + writeClosed: make(chan struct{}), + localAddr: &net.TCPAddr{}, + remoteAddr: &net.TCPAddr{}, }, } @@ -793,7 +664,7 @@ func (c *streamSplitConn) pullLoop() { } req.Host = c.headerHost applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream) - applyTunnelAuthHeader(req.Header, c.auth, TunnelModeStream, http.MethodGet, "/stream") + applyTunnelAuth(req, c.auth, TunnelModeStream, http.MethodGet, "/stream") resp, err := c.client.Do(req) if err != nil { @@ -891,7 +762,7 @@ func (c *streamSplitConn) pushLoop() { } req.Host = c.headerHost applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream) - applyTunnelAuthHeader(req.Header, c.auth, TunnelModeStream, http.MethodPost, "/api/v1/upload") + applyTunnelAuth(req, c.auth, TunnelModeStream, http.MethodPost, "/api/v1/upload") req.Header.Set("Content-Type", "application/octet-stream") resp, err := c.client.Do(req) @@ -977,6 +848,27 @@ func (c *streamSplitConn) pushLoop() { return } resetTimer() + case <-c.writeClosed: + // Drain any already-accepted writes so CloseWrite does not lose data. + for { + select { + case b := <-c.writeCh: + if len(b) == 0 { + continue + } + if buf.Len()+len(b) > maxBatchBytes { + if err := flushWithRetry(); err != nil { + _ = c.Close() + return + } + } + _, _ = buf.Write(b) + default: + _ = flushWithRetry() + bestEffortCloseWriteSession(c.client, c.finURL, c.headerHost, TunnelModeStream, c.auth) + return + } + } case <-c.closed: _ = flushWithRetry() return @@ -993,6 +885,7 @@ type pollConn struct { client *http.Client pushURL string pullURL string + finURL string closeURL string headerHost string auth *tunnelAuth @@ -1037,15 +930,17 @@ func newPollConnFromInfo(info *sessionDialInfo) *pollConn { client: info.client, pushURL: info.pushURL, pullURL: info.pullURL, + finURL: info.finURL, closeURL: info.closeURL, headerHost: info.headerHost, auth: info.auth, queuedConn: queuedConn{ - rxc: make(chan []byte, 128), - closed: make(chan struct{}), - writeCh: make(chan []byte, 256), - localAddr: &net.TCPAddr{}, - remoteAddr: &net.TCPAddr{}, + rxc: make(chan []byte, 128), + closed: make(chan struct{}), + writeCh: make(chan []byte, 256), + writeClosed: make(chan struct{}), + localAddr: &net.TCPAddr{}, + remoteAddr: &net.TCPAddr{}, }, } @@ -1117,14 +1012,14 @@ func (c *pollConn) pullLoop() { default: } - req, err := http.NewRequest(http.MethodGet, c.pullURL, nil) + req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, c.pullURL, nil) if err != nil { _ = c.Close() return } req.Host = c.headerHost applyTunnelHeaders(req.Header, c.headerHost, TunnelModePoll) - applyTunnelAuthHeader(req.Header, c.auth, TunnelModePoll, http.MethodGet, "/stream") + applyTunnelAuth(req, c.auth, TunnelModePoll, http.MethodGet, "/stream") resp, err := c.client.Do(req) if err != nil { @@ -1202,21 +1097,25 @@ func (c *pollConn) pushLoop() { return nil } - req, err := http.NewRequest(http.MethodPost, c.pushURL, bytes.NewReader(buf.Bytes())) + reqCtx, cancel := context.WithTimeout(c.ctx, 20*time.Second) + req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.pushURL, bytes.NewReader(buf.Bytes())) if err != nil { + cancel() return err } req.Host = c.headerHost applyTunnelHeaders(req.Header, c.headerHost, TunnelModePoll) - applyTunnelAuthHeader(req.Header, c.auth, TunnelModePoll, http.MethodPost, "/api/v1/upload") + applyTunnelAuth(req, c.auth, TunnelModePoll, http.MethodPost, "/api/v1/upload") req.Header.Set("Content-Type", "text/plain") resp, err := c.client.Do(req) if err != nil { + cancel() return err } _, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024)) _ = resp.Body.Close() + cancel() if resp.StatusCode != http.StatusOK { return fmt.Errorf("bad status: %s", resp.Status) } @@ -1309,6 +1208,41 @@ func (c *pollConn) pushLoop() { return } resetTimer() + case <-c.writeClosed: + // Drain any already-accepted writes so CloseWrite does not lose data. + for { + select { + case b := <-c.writeCh: + if len(b) == 0 { + continue + } + for len(b) > 0 { + chunk := b + if len(chunk) > maxLineRawBytes { + chunk = b[:maxLineRawBytes] + } + b = b[len(chunk):] + + encLen := base64.StdEncoding.EncodedLen(len(chunk)) + if pendingRaw+len(chunk) > maxBatchBytes || buf.Len()+encLen+1 > maxBatchBytes*2 { + if err := flushWithRetry(); err != nil { + _ = c.closeWithError(fmt.Errorf("poll push flush failed: %w", err)) + return + } + } + + tmp := make([]byte, base64.StdEncoding.EncodedLen(len(chunk))) + base64.StdEncoding.Encode(tmp, chunk) + buf.Write(tmp) + buf.WriteByte('\n') + pendingRaw += len(chunk) + } + default: + _ = flushWithRetry() + bestEffortCloseWriteSession(c.client, c.finURL, c.headerHost, TunnelModePoll, c.auth) + return + } + } case <-c.closed: _ = flushWithRetry() return @@ -1478,19 +1412,50 @@ func (s *TunnelServer) HandleConn(rawConn net.Conn) (HandleResult, net.Conn, err tunnelHeader := strings.ToLower(strings.TrimSpace(req.headers["x-sudoku-tunnel"])) if tunnelHeader == "" { - // Not our tunnel; replay full bytes to legacy handler. - prefix := make([]byte, 0, len(headerBytes)+len(buffered)) - prefix = append(prefix, headerBytes...) - prefix = append(prefix, buffered...) - return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil - } - if s.mode == TunnelModeLegacy { - if s.passThroughOnReject { + // Some CDNs / forward proxies may strip unknown headers. When AuthKey is enabled, we can + // safely infer the intended tunnel mode by verifying the Authorization token against + // both stream/poll modes and picking the one that matches. + if s.auth != nil { + u, err := url.ParseRequestURI(req.target) + if err == nil { + path, ok := stripPathRoot(s.pathRoot, u.Path) + if ok && s.isAllowedBasePath(path) { + authVal := req.headers["authorization"] + if authVal == "" { + authVal = u.Query().Get(tunnelAuthQueryKey) + } + streamOK := s.auth.verifyValue(authVal, TunnelModeStream, req.method, path, time.Now()) + pollOK := s.auth.verifyValue(authVal, TunnelModePoll, req.method, path, time.Now()) + switch { + case streamOK && !pollOK: + tunnelHeader = string(TunnelModeStream) + case pollOK && !streamOK: + tunnelHeader = string(TunnelModePoll) + } + } + } + } + + if tunnelHeader == "" { + // Not our tunnel; replay full bytes to legacy handler. prefix := make([]byte, 0, len(headerBytes)+len(buffered)) prefix = append(prefix, headerBytes...) prefix = append(prefix, buffered...) return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil } + } + + reject := func() (HandleResult, net.Conn, error) { + prefix := make([]byte, 0, len(headerBytes)+len(buffered)) + prefix = append(prefix, headerBytes...) + prefix = append(prefix, buffered...) + return HandlePassThrough, newRejectedPreBufferedConn(rawConn, prefix), nil + } + + if s.mode == TunnelModeLegacy { + if s.passThroughOnReject { + return reject() + } _ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found") _ = rawConn.Close() return HandleDone, nil, nil @@ -1500,10 +1465,7 @@ func (s *TunnelServer) HandleConn(rawConn net.Conn) (HandleResult, net.Conn, err case TunnelModeStream: if s.mode != TunnelModeStream && s.mode != TunnelModeAuto { if s.passThroughOnReject { - prefix := make([]byte, 0, len(headerBytes)+len(buffered)) - prefix = append(prefix, headerBytes...) - prefix = append(prefix, buffered...) - return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil + return reject() } _ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found") _ = rawConn.Close() @@ -1513,10 +1475,7 @@ func (s *TunnelServer) HandleConn(rawConn net.Conn) (HandleResult, net.Conn, err case TunnelModePoll: if s.mode != TunnelModePoll && s.mode != TunnelModeAuto { if s.passThroughOnReject { - prefix := make([]byte, 0, len(headerBytes)+len(buffered)) - prefix = append(prefix, headerBytes...) - prefix = append(prefix, buffered...) - return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil + return reject() } _ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found") _ = rawConn.Close() @@ -1525,10 +1484,7 @@ func (s *TunnelServer) HandleConn(rawConn net.Conn) (HandleResult, net.Conn, err return s.handlePoll(rawConn, req, headerBytes, buffered) default: if s.passThroughOnReject { - prefix := make([]byte, 0, len(headerBytes)+len(buffered)) - prefix = append(prefix, headerBytes...) - prefix = append(prefix, buffered...) - return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil + return reject() } _ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found") _ = rawConn.Close() @@ -1619,13 +1575,52 @@ func readAllBuffered(r *bufio.Reader) []byte { type preBufferedConn struct { net.Conn - buf []byte + buf []byte + recorded []byte + rejected bool } -func newPreBufferedConn(conn net.Conn, pre []byte) net.Conn { +func (p *preBufferedConn) CloseWrite() error { + if p == nil || p.Conn == nil { + return nil + } + if cw, ok := p.Conn.(interface{ CloseWrite() error }); ok { + return cw.CloseWrite() + } + return nil +} + +func (p *preBufferedConn) CloseRead() error { + if p == nil || p.Conn == nil { + return nil + } + if cr, ok := p.Conn.(interface{ CloseRead() error }); ok { + return cr.CloseRead() + } + return nil +} + +func newPreBufferedConn(conn net.Conn, pre []byte) *preBufferedConn { cpy := make([]byte, len(pre)) copy(cpy, pre) - return &preBufferedConn{Conn: conn, buf: cpy} + return &preBufferedConn{Conn: conn, buf: cpy, recorded: cpy} +} + +func newRejectedPreBufferedConn(conn net.Conn, pre []byte) *preBufferedConn { + c := newPreBufferedConn(conn, pre) + c.rejected = true + return c +} + +func (p *preBufferedConn) IsHTTPMaskRejected() bool { return p.rejected } + +func (p *preBufferedConn) GetBufferedAndRecorded() []byte { + if len(p.recorded) == 0 { + return nil + } + out := make([]byte, len(p.recorded)) + copy(out, p.recorded) + return out } func (p *preBufferedConn) Read(b []byte) (int, error) { @@ -1682,7 +1677,7 @@ func (s *TunnelServer) handleStream(rawConn net.Conn, req *httpRequestHeader, he prefix := make([]byte, 0, len(headerBytes)+len(buffered)) prefix = append(prefix, headerBytes...) prefix = append(prefix, buffered...) - return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil + return HandlePassThrough, newRejectedPreBufferedConn(rawConn, prefix), nil } _ = writeSimpleHTTPResponse(rawConn, code, body) _ = rawConn.Close() @@ -1699,21 +1694,29 @@ func (s *TunnelServer) handleStream(rawConn net.Conn, req *httpRequestHeader, he if !ok || !s.isAllowedBasePath(path) { return rejectOrReply(http.StatusNotFound, "not found") } - if !s.auth.verify(req.headers, TunnelModeStream, req.method, path, time.Now()) { + authVal := req.headers["authorization"] + if authVal == "" { + authVal = u.Query().Get(tunnelAuthQueryKey) + } + if !s.auth.verifyValue(authVal, TunnelModeStream, req.method, path, time.Now()) { return rejectOrReply(http.StatusNotFound, "not found") } token := u.Query().Get("token") closeFlag := u.Query().Get("close") == "1" + finFlag := u.Query().Get("fin") == "1" switch strings.ToUpper(req.method) { case http.MethodGet: // Stream split-session: GET /session (no token) => token + start tunnel on a server-side pipe. if token == "" && path == "/session" { - return s.authorizeSession(rawConn) + return s.sessionAuthorize(rawConn) } // Stream split-session: GET /stream?token=... => downlink poll. if token != "" && path == "/stream" { + if s.passThroughOnReject && !s.sessionHas(token) { + return rejectOrReply(http.StatusNotFound, "not found") + } return s.streamPull(rawConn, token) } return rejectOrReply(http.StatusBadRequest, "bad request") @@ -1721,13 +1724,26 @@ func (s *TunnelServer) handleStream(rawConn net.Conn, req *httpRequestHeader, he case http.MethodPost: // Stream split-session: POST /api/v1/upload?token=... => uplink push. if token != "" && path == "/api/v1/upload" { + if s.passThroughOnReject && !s.sessionHas(token) { + return rejectOrReply(http.StatusNotFound, "not found") + } if closeFlag { - s.closeSession(token) - return rejectOrReply(http.StatusOK, "") + s.sessionClose(token) + _ = writeSimpleHTTPResponse(rawConn, http.StatusOK, "") + _ = rawConn.Close() + return HandleDone, nil, nil + } + if finFlag { + s.sessionCloseWrite(token) + _ = writeSimpleHTTPResponse(rawConn, http.StatusOK, "") + _ = rawConn.Close() + return HandleDone, nil, nil } bodyReader, err := newRequestBodyReader(newPreBufferedConn(rawConn, buffered), req.headers) if err != nil { - return rejectOrReply(http.StatusBadRequest, "bad request") + _ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request") + _ = rawConn.Close() + return HandleDone, nil, nil } return s.streamPush(rawConn, token, bodyReader) } @@ -1825,7 +1841,7 @@ func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, head prefix := make([]byte, 0, len(headerBytes)+len(buffered)) prefix = append(prefix, headerBytes...) prefix = append(prefix, buffered...) - return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil + return HandlePassThrough, newRejectedPreBufferedConn(rawConn, prefix), nil } _ = writeSimpleHTTPResponse(rawConn, code, body) _ = rawConn.Close() @@ -1841,18 +1857,26 @@ func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, head if !ok || !s.isAllowedBasePath(path) { return rejectOrReply(http.StatusNotFound, "not found") } - if !s.auth.verify(req.headers, TunnelModePoll, req.method, path, time.Now()) { + authVal := req.headers["authorization"] + if authVal == "" { + authVal = u.Query().Get(tunnelAuthQueryKey) + } + if !s.auth.verifyValue(authVal, TunnelModePoll, req.method, path, time.Now()) { return rejectOrReply(http.StatusNotFound, "not found") } token := u.Query().Get("token") closeFlag := u.Query().Get("close") == "1" + finFlag := u.Query().Get("fin") == "1" switch strings.ToUpper(req.method) { case http.MethodGet: if token == "" && path == "/session" { - return s.authorizeSession(rawConn) + return s.sessionAuthorize(rawConn) } if token != "" && path == "/stream" { + if s.passThroughOnReject && !s.sessionHas(token) { + return rejectOrReply(http.StatusNotFound, "not found") + } return s.pollPull(rawConn, token) } return rejectOrReply(http.StatusBadRequest, "bad request") @@ -1860,13 +1884,26 @@ func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, head if token == "" || path != "/api/v1/upload" { return rejectOrReply(http.StatusBadRequest, "bad request") } + if s.passThroughOnReject && !s.sessionHas(token) { + return rejectOrReply(http.StatusNotFound, "not found") + } if closeFlag { - s.closeSession(token) - return rejectOrReply(http.StatusOK, "") + s.sessionClose(token) + _ = writeSimpleHTTPResponse(rawConn, http.StatusOK, "") + _ = rawConn.Close() + return HandleDone, nil, nil + } + if finFlag { + s.sessionCloseWrite(token) + _ = writeSimpleHTTPResponse(rawConn, http.StatusOK, "") + _ = rawConn.Close() + return HandleDone, nil, nil } bodyReader, err := newRequestBodyReader(newPreBufferedConn(rawConn, buffered), req.headers) if err != nil { - return rejectOrReply(http.StatusBadRequest, "bad request") + _ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request") + _ = rawConn.Close() + return HandleDone, nil, nil } return s.pollPush(rawConn, token, bodyReader) default: @@ -1874,7 +1911,7 @@ func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, head } } -func (s *TunnelServer) authorizeSession(rawConn net.Conn) (HandleResult, net.Conn, error) { +func (s *TunnelServer) sessionAuthorize(rawConn net.Conn) (HandleResult, net.Conn, error) { token, err := newSessionToken() if err != nil { _ = writeSimpleHTTPResponse(rawConn, http.StatusInternalServerError, "internal error") @@ -1882,13 +1919,13 @@ func (s *TunnelServer) authorizeSession(rawConn net.Conn) (HandleResult, net.Con return HandleDone, nil, nil } - c1, c2 := net.Pipe() + c1, c2 := newHalfPipe() s.mu.Lock() s.sessions[token] = &tunnelSession{conn: c2, lastActive: time.Now()} s.mu.Unlock() - go s.reapSessionLater(token) + go s.reapLater(token) _ = writeTokenHTTPResponse(rawConn, token) _ = rawConn.Close() @@ -1903,31 +1940,50 @@ func newSessionToken() (string, error) { return base64.RawURLEncoding.EncodeToString(b[:]), nil } -func (s *TunnelServer) reapSessionLater(token string) { +func (s *TunnelServer) reapLater(token string) { ttl := s.sessionTTL if ttl <= 0 { return } + timer := time.NewTimer(ttl) defer timer.Stop() - <-timer.C - s.mu.Lock() - sess, ok := s.sessions[token] - if !ok { + for { + <-timer.C + + s.mu.Lock() + sess, ok := s.sessions[token] + if !ok { + s.mu.Unlock() + return + } + idle := time.Since(sess.lastActive) + if idle >= ttl { + delete(s.sessions, token) + s.mu.Unlock() + _ = sess.conn.Close() + return + } + next := ttl - idle s.mu.Unlock() - return + + // Avoid a tight loop under high-frequency activity; we only need best-effort cleanup. + if next < 50*time.Millisecond { + next = 50 * time.Millisecond + } + timer.Reset(next) } - if time.Since(sess.lastActive) < ttl { - s.mu.Unlock() - return - } - delete(s.sessions, token) - s.mu.Unlock() - _ = sess.conn.Close() } -func (s *TunnelServer) getSession(token string) (*tunnelSession, bool) { +func (s *TunnelServer) sessionHas(token string) bool { + s.mu.Lock() + _, ok := s.sessions[token] + s.mu.Unlock() + return ok +} + +func (s *TunnelServer) sessionGet(token string) (*tunnelSession, bool) { s.mu.Lock() defer s.mu.Unlock() sess, ok := s.sessions[token] @@ -1938,7 +1994,7 @@ func (s *TunnelServer) getSession(token string) (*tunnelSession, bool) { return sess, true } -func (s *TunnelServer) closeSession(token string) { +func (s *TunnelServer) sessionClose(token string) { s.mu.Lock() sess, ok := s.sessions[token] if ok { @@ -1950,8 +2006,20 @@ func (s *TunnelServer) closeSession(token string) { } } +func (s *TunnelServer) sessionCloseWrite(token string) { + sess, ok := s.sessionGet(token) + if !ok || sess == nil || sess.conn == nil { + return + } + if cw, ok := sess.conn.(interface{ CloseWrite() error }); ok { + _ = cw.CloseWrite() + return + } + _ = sess.conn.Close() +} + func (s *TunnelServer) pollPush(rawConn net.Conn, token string, body io.Reader) (HandleResult, net.Conn, error) { - sess, ok := s.getSession(token) + sess, ok := s.sessionGet(token) if !ok { _ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden") _ = rawConn.Close() @@ -1985,7 +2053,7 @@ func (s *TunnelServer) pollPush(rawConn net.Conn, token string, body io.Reader) _, werr := sess.conn.Write(decoded[:n]) _ = sess.conn.SetWriteDeadline(time.Time{}) if werr != nil { - s.closeSession(token) + s.sessionClose(token) _ = writeSimpleHTTPResponse(rawConn, http.StatusGone, "gone") _ = rawConn.Close() return HandleDone, nil, nil @@ -1998,7 +2066,7 @@ func (s *TunnelServer) pollPush(rawConn net.Conn, token string, body io.Reader) } func (s *TunnelServer) streamPush(rawConn net.Conn, token string, body io.Reader) (HandleResult, net.Conn, error) { - sess, ok := s.getSession(token) + sess, ok := s.sessionGet(token) if !ok { _ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden") _ = rawConn.Close() @@ -2023,7 +2091,7 @@ func (s *TunnelServer) streamPush(rawConn net.Conn, token string, body io.Reader _, werr := sess.conn.Write(payload) _ = sess.conn.SetWriteDeadline(time.Time{}) if werr != nil { - s.closeSession(token) + s.sessionClose(token) _ = writeSimpleHTTPResponse(rawConn, http.StatusGone, "gone") _ = rawConn.Close() return HandleDone, nil, nil @@ -2036,7 +2104,7 @@ func (s *TunnelServer) streamPush(rawConn net.Conn, token string, body io.Reader } func (s *TunnelServer) streamPull(rawConn net.Conn, token string) (HandleResult, net.Conn, error) { - sess, ok := s.getSession(token) + sess, ok := s.sessionGet(token) if !ok { _ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden") _ = rawConn.Close() @@ -2074,14 +2142,14 @@ func (s *TunnelServer) streamPull(rawConn net.Conn, token string) (HandleResult, if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, net.ErrClosed) { return HandleDone, nil, nil } - s.closeSession(token) + s.sessionClose(token) return HandleDone, nil, nil } } } func (s *TunnelServer) pollPull(rawConn net.Conn, token string) (HandleResult, net.Conn, error) { - sess, ok := s.getSession(token) + sess, ok := s.sessionGet(token) if !ok { _ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden") _ = rawConn.Close() @@ -2123,7 +2191,7 @@ func (s *TunnelServer) pollPull(rawConn net.Conn, token string) (HandleResult, n if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, net.ErrClosed) { return HandleDone, nil, nil } - s.closeSession(token) + s.sessionClose(token) return HandleDone, nil, nil } } diff --git a/transport/sudoku/obfs/sudoku/conn.go b/transport/sudoku/obfs/sudoku/conn.go index d09c8a68..fd13d74a 100644 --- a/transport/sudoku/obfs/sudoku/conn.go +++ b/transport/sudoku/obfs/sudoku/conn.go @@ -52,8 +52,28 @@ type Conn struct { pendingData []byte hintBuf []byte - rng *rand.Rand - paddingRate float32 + rng *rand.Rand + paddingThreshold uint64 +} + +func (sc *Conn) CloseWrite() error { + if sc == nil || sc.Conn == nil { + return nil + } + if cw, ok := sc.Conn.(interface{ CloseWrite() error }); ok { + return cw.CloseWrite() + } + return nil +} + +func (sc *Conn) CloseRead() error { + if sc == nil || sc.Conn == nil { + return nil + } + if cr, ok := sc.Conn.(interface{ CloseRead() error }); ok { + return cr.CloseRead() + } + return nil } func NewConn(c net.Conn, table *Table, pMin, pMax int, record bool) *Conn { @@ -64,19 +84,15 @@ func NewConn(c net.Conn, table *Table, pMin, pMax int, record bool) *Conn { seed := int64(binary.BigEndian.Uint64(seedBytes[:])) localRng := rand.New(rand.NewSource(seed)) - min := float32(pMin) / 100.0 - rng := float32(pMax-pMin) / 100.0 - rate := min + localRng.Float32()*rng - sc := &Conn{ - Conn: c, - table: table, - reader: bufio.NewReaderSize(c, IOBufferSize), - rawBuf: make([]byte, IOBufferSize), - pendingData: make([]byte, 0, 4096), - hintBuf: make([]byte, 0, 4), - rng: localRng, - paddingRate: rate, + Conn: c, + table: table, + reader: bufio.NewReaderSize(c, IOBufferSize), + rawBuf: make([]byte, IOBufferSize), + pendingData: make([]byte, 0, 4096), + hintBuf: make([]byte, 0, 4), + rng: localRng, + paddingThreshold: pickPaddingThreshold(localRng, pMin, pMax), } if record { sc.recorder = new(bytes.Buffer) @@ -127,7 +143,7 @@ func (sc *Conn) Write(p []byte) (n int, err error) { padLen := len(pads) for _, b := range p { - if sc.rng.Float32() < sc.paddingRate { + if shouldPad(sc.rng, sc.paddingThreshold) { out = append(out, pads[sc.rng.Intn(padLen)]) } @@ -136,14 +152,14 @@ func (sc *Conn) Write(p []byte) (n int, err error) { perm := perm4[sc.rng.Intn(len(perm4))] for _, idx := range perm { - if sc.rng.Float32() < sc.paddingRate { + if shouldPad(sc.rng, sc.paddingThreshold) { out = append(out, pads[sc.rng.Intn(padLen)]) } out = append(out, puzzle[idx]) } } - if sc.rng.Float32() < sc.paddingRate { + if shouldPad(sc.rng, sc.paddingThreshold) { out = append(out, pads[sc.rng.Intn(padLen)]) } diff --git a/transport/sudoku/obfs/sudoku/packed.go b/transport/sudoku/obfs/sudoku/packed.go index 567afe73..346314a3 100644 --- a/transport/sudoku/obfs/sudoku/packed.go +++ b/transport/sudoku/obfs/sudoku/packed.go @@ -16,7 +16,7 @@ const ( ) // 1. 使用 12字节->16组 的块处理优化 Write (减少循环开销) -// 2. 使用浮点随机概率判断 Padding,与纯 Sudoku 保持流量特征一致 +// 2. 使用整数阈值随机概率判断 Padding,与纯 Sudoku 保持流量特征一致 // 3. Read 使用 copy 移动避免底层数组泄漏 type PackedConn struct { net.Conn @@ -37,11 +37,31 @@ type PackedConn struct { readBitBuf uint64 readBits int - // 随机数与填充控制 - 使用浮点随机,与 Conn 一致 - rng *rand.Rand - paddingRate float32 // 与 Conn 保持一致的随机概率模型 - padMarker byte - padPool []byte + // 随机数与填充控制 - 使用整数阈值随机,与 Conn 一致 + rng *rand.Rand + paddingThreshold uint64 // 与 Conn 保持一致的随机概率模型 + padMarker byte + padPool []byte +} + +func (pc *PackedConn) CloseWrite() error { + if pc == nil || pc.Conn == nil { + return nil + } + if cw, ok := pc.Conn.(interface{ CloseWrite() error }); ok { + return cw.CloseWrite() + } + return nil +} + +func (pc *PackedConn) CloseRead() error { + if pc == nil || pc.Conn == nil { + return nil + } + if cr, ok := pc.Conn.(interface{ CloseRead() error }); ok { + return cr.CloseRead() + } + return nil } func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn { @@ -52,20 +72,15 @@ func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn { seed := int64(binary.BigEndian.Uint64(seedBytes[:])) localRng := rand.New(rand.NewSource(seed)) - // 与 Conn 保持一致的 padding 概率计算 - min := float32(pMin) / 100.0 - rng := float32(pMax-pMin) / 100.0 - rate := min + localRng.Float32()*rng - pc := &PackedConn{ - Conn: c, - table: table, - reader: bufio.NewReaderSize(c, IOBufferSize), - rawBuf: make([]byte, IOBufferSize), - pendingData: make([]byte, 0, 4096), - writeBuf: make([]byte, 0, 4096), - rng: localRng, - paddingRate: rate, + Conn: c, + table: table, + reader: bufio.NewReaderSize(c, IOBufferSize), + rawBuf: make([]byte, IOBufferSize), + pendingData: make([]byte, 0, 4096), + writeBuf: make([]byte, 0, 4096), + rng: localRng, + paddingThreshold: pickPaddingThreshold(localRng, pMin, pMax), } pc.padMarker = table.layout.padMarker @@ -80,9 +95,9 @@ func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn { return pc } -// maybeAddPadding 内联辅助:根据浮点概率插入 padding +// maybeAddPadding 内联辅助:根据概率阈值插入 padding func (pc *PackedConn) maybeAddPadding(out []byte) []byte { - if pc.rng.Float32() < pc.paddingRate { + if shouldPad(pc.rng, pc.paddingThreshold) { out = append(out, pc.getPaddingByte()) } return out diff --git a/transport/sudoku/obfs/sudoku/padding_prob.go b/transport/sudoku/obfs/sudoku/padding_prob.go new file mode 100644 index 00000000..f23a883b --- /dev/null +++ b/transport/sudoku/obfs/sudoku/padding_prob.go @@ -0,0 +1,44 @@ +package sudoku + +import "math/rand" + +const probOne = uint64(1) << 32 + +func pickPaddingThreshold(r *rand.Rand, pMin, pMax int) uint64 { + if r == nil { + return 0 + } + if pMin < 0 { + pMin = 0 + } + if pMax < pMin { + pMax = pMin + } + if pMax > 100 { + pMax = 100 + } + if pMin > 100 { + pMin = 100 + } + + min := uint64(pMin) * probOne / 100 + max := uint64(pMax) * probOne / 100 + if max <= min { + return min + } + u := uint64(r.Uint32()) + return min + (u * (max - min) >> 32) +} + +func shouldPad(r *rand.Rand, threshold uint64) bool { + if threshold == 0 { + return false + } + if threshold >= probOne { + return true + } + if r == nil { + return false + } + return uint64(r.Uint32()) < threshold +}