mirror of
https://github.com/MetaCubeX/mihomo.git
synced 2026-02-26 08:47:09 +00:00
chore: align sudoku with upstream v0.2.0 (#2549)
This commit is contained in:
@@ -7,7 +7,6 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
N "github.com/metacubex/mihomo/common/net"
|
||||
C "github.com/metacubex/mihomo/constant"
|
||||
@@ -22,10 +21,8 @@ type Sudoku struct {
|
||||
httpMaskMu sync.Mutex
|
||||
httpMaskClient *sudoku.HTTPMaskTunnelClient
|
||||
|
||||
muxMu sync.Mutex
|
||||
muxClient *sudoku.MultiplexClient
|
||||
muxBackoffUntil time.Time
|
||||
muxLastErr error
|
||||
muxMu sync.Mutex
|
||||
muxClient *sudoku.MultiplexClient
|
||||
}
|
||||
|
||||
type SudokuOption struct {
|
||||
@@ -58,7 +55,7 @@ func (s *Sudoku) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Con
|
||||
|
||||
muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex)
|
||||
if muxMode == "on" && !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) {
|
||||
stream, muxErr := s.dialMultiplex(ctx, cfg.TargetAddress, muxMode)
|
||||
stream, muxErr := s.dialMultiplex(ctx, cfg.TargetAddress)
|
||||
if muxErr == nil {
|
||||
return NewConn(stream, s), nil
|
||||
}
|
||||
@@ -312,9 +309,9 @@ func (s *Sudoku) dialAndHandshake(ctx context.Context, cfg *sudoku.ProtocolConfi
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (s *Sudoku) dialMultiplex(ctx context.Context, targetAddress string, mode string) (net.Conn, error) {
|
||||
func (s *Sudoku) dialMultiplex(ctx context.Context, targetAddress string) (net.Conn, error) {
|
||||
for attempt := 0; attempt < 2; attempt++ {
|
||||
client, err := s.getOrCreateMuxClient(ctx, mode)
|
||||
client, err := s.getOrCreateMuxClient(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -330,21 +327,11 @@ func (s *Sudoku) dialMultiplex(ctx context.Context, targetAddress string, mode s
|
||||
return nil, fmt.Errorf("multiplex open stream failed")
|
||||
}
|
||||
|
||||
func (s *Sudoku) getOrCreateMuxClient(ctx context.Context, mode string) (*sudoku.MultiplexClient, error) {
|
||||
func (s *Sudoku) getOrCreateMuxClient(ctx context.Context) (*sudoku.MultiplexClient, error) {
|
||||
if s == nil {
|
||||
return nil, fmt.Errorf("nil adapter")
|
||||
}
|
||||
|
||||
if mode == "auto" {
|
||||
s.muxMu.Lock()
|
||||
backoffUntil := s.muxBackoffUntil
|
||||
lastErr := s.muxLastErr
|
||||
s.muxMu.Unlock()
|
||||
if time.Now().Before(backoffUntil) {
|
||||
return nil, fmt.Errorf("multiplex temporarily disabled: %v", lastErr)
|
||||
}
|
||||
}
|
||||
|
||||
s.muxMu.Lock()
|
||||
if s.muxClient != nil && !s.muxClient.IsClosed() {
|
||||
client := s.muxClient
|
||||
@@ -363,20 +350,12 @@ func (s *Sudoku) getOrCreateMuxClient(ctx context.Context, mode string) (*sudoku
|
||||
baseCfg := s.baseConf
|
||||
baseConn, err := s.dialAndHandshake(ctx, &baseCfg)
|
||||
if err != nil {
|
||||
if mode == "auto" {
|
||||
s.muxLastErr = err
|
||||
s.muxBackoffUntil = time.Now().Add(45 * time.Second)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client, err := sudoku.StartMultiplexClient(baseConn)
|
||||
if err != nil {
|
||||
_ = baseConn.Close()
|
||||
if mode == "auto" {
|
||||
s.muxLastErr = err
|
||||
s.muxBackoffUntil = time.Now().Add(45 * time.Second)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -384,16 +363,6 @@ func (s *Sudoku) getOrCreateMuxClient(ctx context.Context, mode string) (*sudoku
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (s *Sudoku) noteMuxFailure(mode string, err error) {
|
||||
if mode != "auto" {
|
||||
return
|
||||
}
|
||||
s.muxMu.Lock()
|
||||
s.muxLastErr = err
|
||||
s.muxBackoffUntil = time.Now().Add(45 * time.Second)
|
||||
s.muxMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *Sudoku) resetMuxClient() {
|
||||
s.muxMu.Lock()
|
||||
defer s.muxMu.Unlock()
|
||||
|
||||
@@ -1082,19 +1082,19 @@ proxies: # socks5
|
||||
server: server_ip/domain # 1.2.3.4 or domain
|
||||
port: 443
|
||||
key: "<client_key>" # 如果你使用sudoku生成的ED25519密钥对,请填写密钥对中的私钥,否则填入和服务端相同的uuid
|
||||
aead-method: chacha20-poly1305 # 可选值:chacha20-poly1305、aes-128-gcm、none 我们保证在none的情况下sudoku混淆层仍然确保安全
|
||||
padding-min: 2 # 最小填充字节数
|
||||
padding-max: 7 # 最大填充字节数
|
||||
aead-method: chacha20-poly1305 # 可选:chacha20-poly1305、aes-128-gcm、none(不建议;且 enable-pure-downlink=false 时不可用)
|
||||
padding-min: 2 # 最小填充率(0-100)
|
||||
padding-max: 7 # 最大填充率(0-100,必须 >= padding-min)
|
||||
table-type: prefer_ascii # 可选值:prefer_ascii、prefer_entropy 前者全ascii映射,后者保证熵值(汉明1)低于3
|
||||
# custom-table: xpxvvpvv # 可选,自定义字节布局,必须包含2个x、2个p、4个v,可随意组合。启用此处则需配置`table-type`为`prefer_entropy`
|
||||
# custom-tables: ["xpxvvpvv", "vxpvxvvp"] # 可选,自定义字节布局列表(x/v/p),用于 xvp 模式轮换;非空时覆盖 custom-table
|
||||
http-mask: true # 是否启用http掩码
|
||||
# http-mask-mode: legacy # 可选:legacy(默认)、stream、poll、auto;stream/poll/auto 支持走 CDN/反代
|
||||
# http-mask-mode: legacy # 可选:legacy(默认)、stream(split-stream)、poll、auto(先 stream 再 poll);stream/poll/auto 支持走 CDN/反代
|
||||
# http-mask-tls: true # 可选:仅在 http-mask-mode 为 stream/poll/auto 时生效;true 强制 https;false 强制 http(不会根据端口自动推断)
|
||||
# http-mask-host: "" # 可选:覆盖 Host/SNI(支持 example.com 或 example.com:443);仅在 http-mask-mode 为 stream/poll/auto 时生效
|
||||
# path-root: "" # 可选:HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload
|
||||
# http-mask-multiplex: off # 可选:off(默认)、auto(复用 h1.1 keep-alive / h2 连接,减少每次建链 RTT)、on(单条隧道内多路复用多个目标连接;仅在 http-mask-mode=stream/poll/auto 生效)
|
||||
enable-pure-downlink: false # 是否启用混淆下行,false的情况下能在保证数据安全的前提下极大提升下行速度,与服务端端保持相同(如果此处为false,则要求aead不可为none)
|
||||
# path-root: "" # 可选:HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" 或 "/aabbcc/" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload
|
||||
# http-mask-multiplex: off # 可选:off(默认)、auto(复用底层 HTTP 连接,减少建链 RTT)、on(Sudoku mux 单隧道多目标;仅在 http-mask-mode=stream/poll/auto 生效)
|
||||
enable-pure-downlink: false # 可选:false=带宽优化下行(更快,要求 aead-method != none);true=纯 Sudoku 下行
|
||||
|
||||
# anytls
|
||||
- name: anytls
|
||||
@@ -1632,17 +1632,17 @@ listeners:
|
||||
port: 8443 # 仅支持单端口
|
||||
listen: 0.0.0.0
|
||||
key: "<server_key>" # 如果你使用sudoku生成的ED25519密钥对,此处是密钥对中的公钥,当然,你也可以仅仅使用任意uuid充当key
|
||||
aead-method: chacha20-poly1305 # 支持chacha20-poly1305或者aes-128-gcm以及none,sudoku的混淆层可以确保none情况下数据安全
|
||||
padding-min: 1 # 填充最小长度
|
||||
padding-max: 15 # 填充最大长度,均不建议过大
|
||||
aead-method: chacha20-poly1305 # 可选:chacha20-poly1305、aes-128-gcm、none(不建议;且 enable-pure-downlink=false 时不可用)
|
||||
padding-min: 1 # 最小填充率(0-100)
|
||||
padding-max: 15 # 最大填充率(0-100,必须 >= padding-min)
|
||||
table-type: prefer_ascii # 可选值:prefer_ascii、prefer_entropy 前者全ascii映射,后者保证熵值(汉明1)低于3
|
||||
# custom-table: xpxvvpvv # 可选,自定义字节布局,必须包含2个x、2个p、4个v,可随意组合。启用此处则需配置`table-type`为`prefer_entropy`
|
||||
# custom-tables: ["xpxvvpvv", "vxpvxvvp"] # 可选,自定义字节布局列表(x/v/p),用于 xvp 模式轮换;非空时覆盖 custom-table
|
||||
handshake-timeout: 5 # optional
|
||||
enable-pure-downlink: false # 是否启用混淆下行,false的情况下能在保证数据安全的前提下极大提升下行速度,与客户端保持相同(如果此处为false,则要求aead不可为none)
|
||||
handshake-timeout: 5 # 可选(秒)
|
||||
enable-pure-downlink: false # 可选:false=带宽优化下行(更快,要求 aead-method != none);true=纯 Sudoku 下行
|
||||
disable-http-mask: false # 可选:禁用 http 掩码/隧道(默认为 false)
|
||||
# http-mask-mode: legacy # 可选:legacy(默认)、stream、poll、auto;stream/poll/auto 支持走 CDN/反代
|
||||
# path-root: "" # 可选:HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload
|
||||
# http-mask-mode: legacy # 可选:legacy(默认)、stream(split-stream)、poll、auto(先 stream 再 poll);stream/poll/auto 支持走 CDN/反代
|
||||
# path-root: "" # 可选:HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" 或 "/aabbcc/" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -32,7 +32,8 @@ type ProtocolConfig struct {
|
||||
PaddingMin int
|
||||
PaddingMax int
|
||||
|
||||
// EnablePureDownlink toggles the bandwidth-optimized downlink mode.
|
||||
// EnablePureDownlink enables the pure Sudoku downlink mode.
|
||||
// When false, the connection uses the bandwidth-optimized packed downlink (requires AEAD).
|
||||
EnablePureDownlink bool
|
||||
|
||||
// Client-only: final target "host:port".
|
||||
@@ -46,7 +47,7 @@ type ProtocolConfig struct {
|
||||
|
||||
// HTTPMaskMode controls how the HTTP layer behaves:
|
||||
// - "legacy": write a fake HTTP/1.1 header then switch to raw stream (default, not CDN-compatible)
|
||||
// - "stream": real HTTP tunnel (stream-one or split), CDN-compatible
|
||||
// - "stream": real HTTP tunnel (split-stream), CDN-compatible
|
||||
// - "poll": plain HTTP tunnel (authorize/push/pull), strong restricted-network pass-through
|
||||
// - "auto": try stream then fall back to poll
|
||||
HTTPMaskMode string
|
||||
@@ -114,7 +115,8 @@ func (c *ProtocolConfig) Validate() error {
|
||||
}
|
||||
|
||||
if v := strings.TrimSpace(c.HTTPMaskPathRoot); v != "" {
|
||||
if strings.Contains(v, "/") {
|
||||
v = strings.Trim(v, "/")
|
||||
if v == "" || strings.Contains(v, "/") {
|
||||
return fmt.Errorf("invalid http-mask-path-root: must be a single path segment")
|
||||
}
|
||||
for i := 0; i < len(v); i++ {
|
||||
|
||||
@@ -22,6 +22,26 @@ type AEADConn struct {
|
||||
nonceSize int
|
||||
}
|
||||
|
||||
func (cc *AEADConn) CloseWrite() error {
|
||||
if cc == nil || cc.Conn == nil {
|
||||
return nil
|
||||
}
|
||||
if cw, ok := cc.Conn.(interface{ CloseWrite() error }); ok {
|
||||
return cw.CloseWrite()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cc *AEADConn) CloseRead() error {
|
||||
if cc == nil || cc.Conn == nil {
|
||||
return nil
|
||||
}
|
||||
if cr, ok := cc.Conn.(interface{ CloseRead() error }); ok {
|
||||
return cr.CloseRead()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewAEADConn(c net.Conn, key string, method string) (*AEADConn, error) {
|
||||
if method == "none" {
|
||||
return &AEADConn{Conn: c, aead: nil}, nil
|
||||
|
||||
@@ -383,7 +383,11 @@ func (c *stream) enqueue(payload []byte) {
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
c.queue = append(c.queue, payload)
|
||||
if len(c.readBuf) == 0 && len(c.queue) == 0 {
|
||||
c.readBuf = payload
|
||||
} else {
|
||||
c.queue = append(c.queue, payload)
|
||||
}
|
||||
c.cond.Signal()
|
||||
c.mu.Unlock()
|
||||
}
|
||||
@@ -491,6 +495,9 @@ func (c *stream) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stream) CloseWrite() error { return c.Close() }
|
||||
func (c *stream) CloseRead() error { return c.Close() }
|
||||
|
||||
func (c *stream) LocalAddr() net.Addr { return c.localAddr }
|
||||
func (c *stream) RemoteAddr() net.Addr { return c.remoteAddr }
|
||||
|
||||
@@ -501,4 +508,3 @@ func (c *stream) SetDeadline(t time.Time) error {
|
||||
}
|
||||
func (c *stream) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *stream) SetWriteDeadline(time.Time) error { return nil }
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"github.com/metacubex/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
const (
|
||||
tunnelAuthHeaderKey = "Authorization"
|
||||
tunnelAuthHeaderPrefix = "Bearer "
|
||||
tunnelAuthQueryKey = "auth"
|
||||
)
|
||||
|
||||
type tunnelAuth struct {
|
||||
@@ -61,8 +63,15 @@ func (a *tunnelAuth) verify(headers map[string]string, mode TunnelMode, method,
|
||||
if headers == nil {
|
||||
return false
|
||||
}
|
||||
return a.verifyValue(headers["authorization"], mode, method, path, now)
|
||||
}
|
||||
|
||||
val := strings.TrimSpace(headers["authorization"])
|
||||
func (a *tunnelAuth) verifyValue(val string, mode TunnelMode, method, path string, now time.Time) bool {
|
||||
if a == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
val = strings.TrimSpace(val)
|
||||
if val == "" {
|
||||
return false
|
||||
}
|
||||
@@ -121,11 +130,9 @@ func (a *tunnelAuth) sign(mode TunnelMode, method, path string, ts int64) [16]by
|
||||
return out
|
||||
}
|
||||
|
||||
type headerSetter interface {
|
||||
Set(key, value string)
|
||||
}
|
||||
type httpHeaderSetter = http.Header
|
||||
|
||||
func applyTunnelAuthHeader(h headerSetter, auth *tunnelAuth, mode TunnelMode, method, path string) {
|
||||
func applyTunnelAuthHeader(h httpHeaderSetter, auth *tunnelAuth, mode TunnelMode, method, path string) {
|
||||
if auth == nil || h == nil {
|
||||
return
|
||||
}
|
||||
@@ -135,3 +142,19 @@ func applyTunnelAuthHeader(h headerSetter, auth *tunnelAuth, mode TunnelMode, me
|
||||
}
|
||||
h.Set(tunnelAuthHeaderKey, tunnelAuthHeaderPrefix+token)
|
||||
}
|
||||
|
||||
func applyTunnelAuth(req *http.Request, auth *tunnelAuth, mode TunnelMode, method, path string) {
|
||||
if auth == nil || req == nil {
|
||||
return
|
||||
}
|
||||
token := auth.token(mode, method, path, time.Now())
|
||||
if token == "" {
|
||||
return
|
||||
}
|
||||
req.Header.Set(tunnelAuthHeaderKey, tunnelAuthHeaderPrefix+token)
|
||||
if req.URL != nil {
|
||||
q := req.URL.Query()
|
||||
q.Set(tunnelAuthQueryKey, token)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
}
|
||||
}
|
||||
|
||||
229
transport/sudoku/obfs/httpmask/halfpipe.go
Normal file
229
transport/sudoku/obfs/httpmask/halfpipe.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package httpmask
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type pipeDeadline struct {
|
||||
mu sync.Mutex
|
||||
timer *time.Timer
|
||||
cancel chan struct{}
|
||||
}
|
||||
|
||||
func makePipeDeadline() pipeDeadline {
|
||||
return pipeDeadline{cancel: make(chan struct{})}
|
||||
}
|
||||
|
||||
func (d *pipeDeadline) set(t time.Time) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
if d.timer != nil && !d.timer.Stop() {
|
||||
<-d.cancel
|
||||
}
|
||||
d.timer = nil
|
||||
|
||||
closed := isClosedPipeChan(d.cancel)
|
||||
if t.IsZero() {
|
||||
if closed {
|
||||
d.cancel = make(chan struct{})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if dur := time.Until(t); dur > 0 {
|
||||
if closed {
|
||||
d.cancel = make(chan struct{})
|
||||
}
|
||||
d.timer = time.AfterFunc(dur, func() {
|
||||
close(d.cancel)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if !closed {
|
||||
close(d.cancel)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *pipeDeadline) wait() <-chan struct{} {
|
||||
d.mu.Lock()
|
||||
ch := d.cancel
|
||||
d.mu.Unlock()
|
||||
return ch
|
||||
}
|
||||
|
||||
func isClosedPipeChan(ch <-chan struct{}) bool {
|
||||
select {
|
||||
case <-ch:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
type halfPipeAddr struct{}
|
||||
|
||||
func (halfPipeAddr) Network() string { return "pipe" }
|
||||
func (halfPipeAddr) String() string { return "pipe" }
|
||||
|
||||
type halfPipeConn struct {
|
||||
wrMu sync.Mutex
|
||||
|
||||
rdRx <-chan []byte
|
||||
rdTx chan<- int
|
||||
|
||||
wrTx chan<- []byte
|
||||
wrRx <-chan int
|
||||
|
||||
readOnce sync.Once
|
||||
writeOnce sync.Once
|
||||
|
||||
localReadDone chan struct{}
|
||||
localWriteDone chan struct{}
|
||||
|
||||
remoteReadDone <-chan struct{}
|
||||
remoteWriteDone <-chan struct{}
|
||||
|
||||
readDeadline pipeDeadline
|
||||
writeDeadline pipeDeadline
|
||||
}
|
||||
|
||||
func newHalfPipe() (net.Conn, net.Conn) {
|
||||
cb1 := make(chan []byte)
|
||||
cb2 := make(chan []byte)
|
||||
cn1 := make(chan int)
|
||||
cn2 := make(chan int)
|
||||
|
||||
r1 := make(chan struct{})
|
||||
w1 := make(chan struct{})
|
||||
r2 := make(chan struct{})
|
||||
w2 := make(chan struct{})
|
||||
|
||||
c1 := &halfPipeConn{
|
||||
rdRx: cb1,
|
||||
rdTx: cn1,
|
||||
wrTx: cb2,
|
||||
wrRx: cn2,
|
||||
|
||||
localReadDone: r1,
|
||||
localWriteDone: w1,
|
||||
remoteReadDone: r2,
|
||||
remoteWriteDone: w2,
|
||||
|
||||
readDeadline: makePipeDeadline(),
|
||||
writeDeadline: makePipeDeadline(),
|
||||
}
|
||||
c2 := &halfPipeConn{
|
||||
rdRx: cb2,
|
||||
rdTx: cn2,
|
||||
wrTx: cb1,
|
||||
wrRx: cn1,
|
||||
|
||||
localReadDone: r2,
|
||||
localWriteDone: w2,
|
||||
remoteReadDone: r1,
|
||||
remoteWriteDone: w1,
|
||||
|
||||
readDeadline: makePipeDeadline(),
|
||||
writeDeadline: makePipeDeadline(),
|
||||
}
|
||||
return c1, c2
|
||||
}
|
||||
|
||||
func (*halfPipeConn) LocalAddr() net.Addr { return halfPipeAddr{} }
|
||||
func (*halfPipeConn) RemoteAddr() net.Addr { return halfPipeAddr{} }
|
||||
|
||||
func (c *halfPipeConn) Read(p []byte) (int, error) {
|
||||
switch {
|
||||
case isClosedPipeChan(c.localReadDone):
|
||||
return 0, io.ErrClosedPipe
|
||||
case isClosedPipeChan(c.remoteWriteDone):
|
||||
return 0, io.EOF
|
||||
case isClosedPipeChan(c.readDeadline.wait()):
|
||||
return 0, os.ErrDeadlineExceeded
|
||||
}
|
||||
|
||||
select {
|
||||
case b := <-c.rdRx:
|
||||
n := copy(p, b)
|
||||
c.rdTx <- n
|
||||
return n, nil
|
||||
case <-c.localReadDone:
|
||||
return 0, io.ErrClosedPipe
|
||||
case <-c.remoteWriteDone:
|
||||
return 0, io.EOF
|
||||
case <-c.readDeadline.wait():
|
||||
return 0, os.ErrDeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
func (c *halfPipeConn) Write(p []byte) (int, error) {
|
||||
switch {
|
||||
case isClosedPipeChan(c.localWriteDone):
|
||||
return 0, io.ErrClosedPipe
|
||||
case isClosedPipeChan(c.remoteReadDone):
|
||||
return 0, io.ErrClosedPipe
|
||||
case isClosedPipeChan(c.writeDeadline.wait()):
|
||||
return 0, os.ErrDeadlineExceeded
|
||||
}
|
||||
|
||||
c.wrMu.Lock()
|
||||
defer c.wrMu.Unlock()
|
||||
|
||||
var (
|
||||
total int
|
||||
rest = p
|
||||
)
|
||||
for once := true; once || len(rest) > 0; once = false {
|
||||
select {
|
||||
case c.wrTx <- rest:
|
||||
n := <-c.wrRx
|
||||
rest = rest[n:]
|
||||
total += n
|
||||
case <-c.localWriteDone:
|
||||
return total, io.ErrClosedPipe
|
||||
case <-c.remoteReadDone:
|
||||
return total, io.ErrClosedPipe
|
||||
case <-c.writeDeadline.wait():
|
||||
return total, os.ErrDeadlineExceeded
|
||||
}
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (c *halfPipeConn) CloseWrite() error {
|
||||
c.writeOnce.Do(func() { close(c.localWriteDone) })
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *halfPipeConn) CloseRead() error {
|
||||
c.readOnce.Do(func() { close(c.localReadDone) })
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *halfPipeConn) Close() error {
|
||||
_ = c.CloseRead()
|
||||
_ = c.CloseWrite()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *halfPipeConn) SetDeadline(t time.Time) error {
|
||||
c.readDeadline.set(t)
|
||||
c.writeDeadline.set(t)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *halfPipeConn) SetReadDeadline(t time.Time) error {
|
||||
c.readDeadline.set(t)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *halfPipeConn) SetWriteDeadline(t time.Time) error {
|
||||
c.writeDeadline.set(t)
|
||||
return nil
|
||||
}
|
||||
@@ -241,38 +241,6 @@ func parseTunnelToken(body []byte) (string, error) {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
type httpStreamConn struct {
|
||||
reader io.ReadCloser
|
||||
writer *io.PipeWriter
|
||||
cancel context.CancelFunc
|
||||
|
||||
localAddr net.Addr
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
func (c *httpStreamConn) Read(p []byte) (int, error) { return c.reader.Read(p) }
|
||||
func (c *httpStreamConn) Write(p []byte) (int, error) { return c.writer.Write(p) }
|
||||
|
||||
func (c *httpStreamConn) Close() error {
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
if c.writer != nil {
|
||||
_ = c.writer.CloseWithError(io.ErrClosedPipe)
|
||||
}
|
||||
if c.reader != nil {
|
||||
return c.reader.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *httpStreamConn) LocalAddr() net.Addr { return c.localAddr }
|
||||
func (c *httpStreamConn) RemoteAddr() net.Addr { return c.remoteAddr }
|
||||
|
||||
func (c *httpStreamConn) SetDeadline(time.Time) error { return nil }
|
||||
func (c *httpStreamConn) SetReadDeadline(time.Time) error { return nil }
|
||||
func (c *httpStreamConn) SetWriteDeadline(time.Time) error { return nil }
|
||||
|
||||
type httpClientTarget struct {
|
||||
scheme string
|
||||
urlHost string
|
||||
@@ -332,6 +300,7 @@ type sessionDialInfo struct {
|
||||
client *http.Client
|
||||
pushURL string
|
||||
pullURL string
|
||||
finURL string
|
||||
closeURL string
|
||||
headerHost string
|
||||
auth *tunnelAuth
|
||||
@@ -350,7 +319,7 @@ func dialSessionWithClient(ctx context.Context, client *http.Client, target http
|
||||
}
|
||||
req.Host = target.headerHost
|
||||
applyTunnelHeaders(req.Header, target.headerHost, mode)
|
||||
applyTunnelAuthHeader(req.Header, auth, mode, http.MethodGet, "/session")
|
||||
applyTunnelAuth(req, auth, mode, http.MethodGet, "/session")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
@@ -375,12 +344,14 @@ func dialSessionWithClient(ctx context.Context, client *http.Client, target http
|
||||
|
||||
pushURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/api/v1/upload"), RawQuery: "token=" + url.QueryEscape(token)}).String()
|
||||
pullURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/stream"), RawQuery: "token=" + url.QueryEscape(token)}).String()
|
||||
finURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/api/v1/upload"), RawQuery: "token=" + url.QueryEscape(token) + "&fin=1"}).String()
|
||||
closeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/api/v1/upload"), RawQuery: "token=" + url.QueryEscape(token) + "&close=1"}).String()
|
||||
|
||||
return &sessionDialInfo{
|
||||
client: client,
|
||||
pushURL: pushURL,
|
||||
pullURL: pullURL,
|
||||
finURL: finURL,
|
||||
closeURL: closeURL,
|
||||
headerHost: target.headerHost,
|
||||
auth: auth,
|
||||
@@ -409,7 +380,31 @@ func bestEffortCloseSession(client *http.Client, closeURL, headerHost string, mo
|
||||
}
|
||||
req.Host = headerHost
|
||||
applyTunnelHeaders(req.Header, headerHost, mode)
|
||||
applyTunnelAuthHeader(req.Header, auth, mode, http.MethodPost, "/api/v1/upload")
|
||||
applyTunnelAuth(req, auth, mode, http.MethodPost, "/api/v1/upload")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil || resp == nil {
|
||||
return
|
||||
}
|
||||
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024))
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
|
||||
func bestEffortCloseWriteSession(client *http.Client, finURL, headerHost string, mode TunnelMode, auth *tunnelAuth) {
|
||||
if client == nil || finURL == "" || headerHost == "" {
|
||||
return
|
||||
}
|
||||
|
||||
closeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(closeCtx, http.MethodPost, finURL, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Host = headerHost
|
||||
applyTunnelHeaders(req.Header, headerHost, mode)
|
||||
applyTunnelAuth(req, auth, mode, http.MethodPost, "/api/v1/upload")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil || resp == nil {
|
||||
@@ -420,160 +415,13 @@ func bestEffortCloseSession(client *http.Client, closeURL, headerHost string, mo
|
||||
}
|
||||
|
||||
func dialStreamWithClient(ctx context.Context, client *http.Client, target httpClientTarget, opts TunnelDialOptions) (net.Conn, error) {
|
||||
// Prefer split-session (Cloudflare-friendly). Fall back to stream-one for older servers / environments.
|
||||
c, errSplit := dialStreamSplitWithClient(ctx, client, target, opts)
|
||||
if errSplit == nil {
|
||||
return c, nil
|
||||
}
|
||||
c2, errOne := dialStreamOneWithClient(ctx, client, target, opts)
|
||||
if errOne == nil {
|
||||
return c2, nil
|
||||
}
|
||||
return nil, fmt.Errorf("dial stream failed: split: %v; stream-one: %w", errSplit, errOne)
|
||||
// "stream" mode uses split-stream to stay CDN-friendly by default.
|
||||
return dialStreamSplitWithClient(ctx, client, target, opts)
|
||||
}
|
||||
|
||||
func dialStream(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
|
||||
// Prefer split-session (Cloudflare-friendly). Fall back to stream-one for older servers / environments.
|
||||
c, errSplit := dialStreamSplit(ctx, serverAddress, opts)
|
||||
if errSplit == nil {
|
||||
return c, nil
|
||||
}
|
||||
c2, errOne := dialStreamOne(ctx, serverAddress, opts)
|
||||
if errOne == nil {
|
||||
return c2, nil
|
||||
}
|
||||
return nil, fmt.Errorf("dial stream failed: split: %v; stream-one: %w", errSplit, errOne)
|
||||
}
|
||||
|
||||
func dialStreamOneWithClient(ctx context.Context, client *http.Client, target httpClientTarget, opts TunnelDialOptions) (net.Conn, error) {
|
||||
if client == nil {
|
||||
return nil, fmt.Errorf("nil http client")
|
||||
}
|
||||
|
||||
auth := newTunnelAuth(opts.AuthKey, 0)
|
||||
r := rngPool.Get().(*mrand.Rand)
|
||||
basePath := paths[r.Intn(len(paths))]
|
||||
path := joinPathRoot(opts.PathRoot, basePath)
|
||||
ctype := contentTypes[r.Intn(len(contentTypes))]
|
||||
rngPool.Put(r)
|
||||
|
||||
u := url.URL{
|
||||
Scheme: target.scheme,
|
||||
Host: target.urlHost,
|
||||
Path: path,
|
||||
}
|
||||
|
||||
reqBodyR, reqBodyW := io.Pipe()
|
||||
|
||||
connCtx, connCancel := context.WithCancel(context.Background())
|
||||
req, err := http.NewRequestWithContext(connCtx, http.MethodPost, u.String(), reqBodyR)
|
||||
if err != nil {
|
||||
connCancel()
|
||||
_ = reqBodyW.Close()
|
||||
return nil, err
|
||||
}
|
||||
req.Host = target.headerHost
|
||||
|
||||
applyTunnelHeaders(req.Header, target.headerHost, TunnelModeStream)
|
||||
applyTunnelAuthHeader(req.Header, auth, TunnelModeStream, http.MethodPost, basePath)
|
||||
req.Header.Set("Content-Type", ctype)
|
||||
|
||||
type doResult struct {
|
||||
resp *http.Response
|
||||
err error
|
||||
}
|
||||
doCh := make(chan doResult, 1)
|
||||
go func() {
|
||||
resp, doErr := client.Do(req)
|
||||
doCh <- doResult{resp: resp, err: doErr}
|
||||
}()
|
||||
|
||||
streamConn := &httpStreamConn{
|
||||
writer: reqBodyW,
|
||||
cancel: connCancel,
|
||||
localAddr: &net.TCPAddr{},
|
||||
remoteAddr: &net.TCPAddr{},
|
||||
}
|
||||
|
||||
type upgradeResult struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
upgradeCh := make(chan upgradeResult, 1)
|
||||
if opts.Upgrade == nil {
|
||||
upgradeCh <- upgradeResult{conn: streamConn, err: nil}
|
||||
} else {
|
||||
go func() {
|
||||
upgradeConn, err := opts.Upgrade(streamConn)
|
||||
if err != nil {
|
||||
upgradeCh <- upgradeResult{conn: nil, err: err}
|
||||
return
|
||||
}
|
||||
if upgradeConn == nil {
|
||||
upgradeConn = streamConn
|
||||
}
|
||||
upgradeCh <- upgradeResult{conn: upgradeConn, err: nil}
|
||||
}()
|
||||
}
|
||||
|
||||
var (
|
||||
outConn net.Conn
|
||||
upgradeDone bool
|
||||
responseReady bool
|
||||
)
|
||||
|
||||
for !(upgradeDone && responseReady) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = streamConn.Close()
|
||||
if outConn != nil && outConn != streamConn {
|
||||
_ = outConn.Close()
|
||||
}
|
||||
return nil, ctx.Err()
|
||||
|
||||
case u := <-upgradeCh:
|
||||
if u.err != nil {
|
||||
_ = streamConn.Close()
|
||||
return nil, u.err
|
||||
}
|
||||
outConn = u.conn
|
||||
if outConn == nil {
|
||||
outConn = streamConn
|
||||
}
|
||||
upgradeDone = true
|
||||
|
||||
case r := <-doCh:
|
||||
if r.err != nil {
|
||||
_ = streamConn.Close()
|
||||
if outConn != nil && outConn != streamConn {
|
||||
_ = outConn.Close()
|
||||
}
|
||||
return nil, r.err
|
||||
}
|
||||
if r.resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(r.resp.Body, 4*1024))
|
||||
_ = r.resp.Body.Close()
|
||||
_ = streamConn.Close()
|
||||
if outConn != nil && outConn != streamConn {
|
||||
_ = outConn.Close()
|
||||
}
|
||||
return nil, fmt.Errorf("stream bad status: %s (%s)", r.resp.Status, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
streamConn.reader = r.resp.Body
|
||||
responseReady = true
|
||||
}
|
||||
}
|
||||
|
||||
return outConn, nil
|
||||
}
|
||||
|
||||
func dialStreamOne(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
|
||||
client, target, err := newHTTPClient(serverAddress, opts, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dialStreamOneWithClient(ctx, client, target, opts)
|
||||
// "stream" mode uses split-stream to stay CDN-friendly by default.
|
||||
return dialStreamSplit(ctx, serverAddress, opts)
|
||||
}
|
||||
|
||||
type queuedConn struct {
|
||||
@@ -581,6 +429,9 @@ type queuedConn struct {
|
||||
closed chan struct{}
|
||||
|
||||
writeCh chan []byte
|
||||
// writeClosed is closed by CloseWrite to stop accepting new payloads.
|
||||
// When closed, Write returns io.ErrClosedPipe, but Read is unaffected.
|
||||
writeClosed chan struct{}
|
||||
|
||||
mu sync.Mutex
|
||||
readBuf []byte
|
||||
@@ -589,6 +440,18 @@ type queuedConn struct {
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
func (c *queuedConn) CloseWrite() error {
|
||||
if c == nil || c.writeClosed == nil {
|
||||
return nil
|
||||
}
|
||||
c.mu.Lock()
|
||||
if !isClosedPipeChan(c.writeClosed) {
|
||||
close(c.writeClosed)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *queuedConn) closeWithError(err error) error {
|
||||
c.mu.Lock()
|
||||
select {
|
||||
@@ -640,6 +503,9 @@ func (c *queuedConn) Write(b []byte) (n int, err error) {
|
||||
case <-c.closed:
|
||||
c.mu.Unlock()
|
||||
return 0, c.closedErr()
|
||||
case <-c.writeClosed:
|
||||
c.mu.Unlock()
|
||||
return 0, io.ErrClosedPipe
|
||||
default:
|
||||
}
|
||||
c.mu.Unlock()
|
||||
@@ -651,6 +517,8 @@ func (c *queuedConn) Write(b []byte) (n int, err error) {
|
||||
return len(b), nil
|
||||
case <-c.closed:
|
||||
return 0, c.closedErr()
|
||||
case <-c.writeClosed:
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
}
|
||||
|
||||
@@ -670,6 +538,7 @@ type streamSplitConn struct {
|
||||
client *http.Client
|
||||
pushURL string
|
||||
pullURL string
|
||||
finURL string
|
||||
closeURL string
|
||||
headerHost string
|
||||
auth *tunnelAuth
|
||||
@@ -697,15 +566,17 @@ func newStreamSplitConnFromInfo(info *sessionDialInfo) *streamSplitConn {
|
||||
client: info.client,
|
||||
pushURL: info.pushURL,
|
||||
pullURL: info.pullURL,
|
||||
finURL: info.finURL,
|
||||
closeURL: info.closeURL,
|
||||
headerHost: info.headerHost,
|
||||
auth: info.auth,
|
||||
queuedConn: queuedConn{
|
||||
rxc: make(chan []byte, 256),
|
||||
closed: make(chan struct{}),
|
||||
writeCh: make(chan []byte, 256),
|
||||
localAddr: &net.TCPAddr{},
|
||||
remoteAddr: &net.TCPAddr{},
|
||||
rxc: make(chan []byte, 256),
|
||||
closed: make(chan struct{}),
|
||||
writeCh: make(chan []byte, 256),
|
||||
writeClosed: make(chan struct{}),
|
||||
localAddr: &net.TCPAddr{},
|
||||
remoteAddr: &net.TCPAddr{},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -793,7 +664,7 @@ func (c *streamSplitConn) pullLoop() {
|
||||
}
|
||||
req.Host = c.headerHost
|
||||
applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream)
|
||||
applyTunnelAuthHeader(req.Header, c.auth, TunnelModeStream, http.MethodGet, "/stream")
|
||||
applyTunnelAuth(req, c.auth, TunnelModeStream, http.MethodGet, "/stream")
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
@@ -891,7 +762,7 @@ func (c *streamSplitConn) pushLoop() {
|
||||
}
|
||||
req.Host = c.headerHost
|
||||
applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream)
|
||||
applyTunnelAuthHeader(req.Header, c.auth, TunnelModeStream, http.MethodPost, "/api/v1/upload")
|
||||
applyTunnelAuth(req, c.auth, TunnelModeStream, http.MethodPost, "/api/v1/upload")
|
||||
req.Header.Set("Content-Type", "application/octet-stream")
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
@@ -977,6 +848,27 @@ func (c *streamSplitConn) pushLoop() {
|
||||
return
|
||||
}
|
||||
resetTimer()
|
||||
case <-c.writeClosed:
|
||||
// Drain any already-accepted writes so CloseWrite does not lose data.
|
||||
for {
|
||||
select {
|
||||
case b := <-c.writeCh:
|
||||
if len(b) == 0 {
|
||||
continue
|
||||
}
|
||||
if buf.Len()+len(b) > maxBatchBytes {
|
||||
if err := flushWithRetry(); err != nil {
|
||||
_ = c.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
_, _ = buf.Write(b)
|
||||
default:
|
||||
_ = flushWithRetry()
|
||||
bestEffortCloseWriteSession(c.client, c.finURL, c.headerHost, TunnelModeStream, c.auth)
|
||||
return
|
||||
}
|
||||
}
|
||||
case <-c.closed:
|
||||
_ = flushWithRetry()
|
||||
return
|
||||
@@ -993,6 +885,7 @@ type pollConn struct {
|
||||
client *http.Client
|
||||
pushURL string
|
||||
pullURL string
|
||||
finURL string
|
||||
closeURL string
|
||||
headerHost string
|
||||
auth *tunnelAuth
|
||||
@@ -1037,15 +930,17 @@ func newPollConnFromInfo(info *sessionDialInfo) *pollConn {
|
||||
client: info.client,
|
||||
pushURL: info.pushURL,
|
||||
pullURL: info.pullURL,
|
||||
finURL: info.finURL,
|
||||
closeURL: info.closeURL,
|
||||
headerHost: info.headerHost,
|
||||
auth: info.auth,
|
||||
queuedConn: queuedConn{
|
||||
rxc: make(chan []byte, 128),
|
||||
closed: make(chan struct{}),
|
||||
writeCh: make(chan []byte, 256),
|
||||
localAddr: &net.TCPAddr{},
|
||||
remoteAddr: &net.TCPAddr{},
|
||||
rxc: make(chan []byte, 128),
|
||||
closed: make(chan struct{}),
|
||||
writeCh: make(chan []byte, 256),
|
||||
writeClosed: make(chan struct{}),
|
||||
localAddr: &net.TCPAddr{},
|
||||
remoteAddr: &net.TCPAddr{},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1117,14 +1012,14 @@ func (c *pollConn) pullLoop() {
|
||||
default:
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, c.pullURL, nil)
|
||||
req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, c.pullURL, nil)
|
||||
if err != nil {
|
||||
_ = c.Close()
|
||||
return
|
||||
}
|
||||
req.Host = c.headerHost
|
||||
applyTunnelHeaders(req.Header, c.headerHost, TunnelModePoll)
|
||||
applyTunnelAuthHeader(req.Header, c.auth, TunnelModePoll, http.MethodGet, "/stream")
|
||||
applyTunnelAuth(req, c.auth, TunnelModePoll, http.MethodGet, "/stream")
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
@@ -1202,21 +1097,25 @@ func (c *pollConn) pushLoop() {
|
||||
return nil
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, c.pushURL, bytes.NewReader(buf.Bytes()))
|
||||
reqCtx, cancel := context.WithTimeout(c.ctx, 20*time.Second)
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.pushURL, bytes.NewReader(buf.Bytes()))
|
||||
if err != nil {
|
||||
cancel()
|
||||
return err
|
||||
}
|
||||
req.Host = c.headerHost
|
||||
applyTunnelHeaders(req.Header, c.headerHost, TunnelModePoll)
|
||||
applyTunnelAuthHeader(req.Header, c.auth, TunnelModePoll, http.MethodPost, "/api/v1/upload")
|
||||
applyTunnelAuth(req, c.auth, TunnelModePoll, http.MethodPost, "/api/v1/upload")
|
||||
req.Header.Set("Content-Type", "text/plain")
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return err
|
||||
}
|
||||
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024))
|
||||
_ = resp.Body.Close()
|
||||
cancel()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("bad status: %s", resp.Status)
|
||||
}
|
||||
@@ -1309,6 +1208,41 @@ func (c *pollConn) pushLoop() {
|
||||
return
|
||||
}
|
||||
resetTimer()
|
||||
case <-c.writeClosed:
|
||||
// Drain any already-accepted writes so CloseWrite does not lose data.
|
||||
for {
|
||||
select {
|
||||
case b := <-c.writeCh:
|
||||
if len(b) == 0 {
|
||||
continue
|
||||
}
|
||||
for len(b) > 0 {
|
||||
chunk := b
|
||||
if len(chunk) > maxLineRawBytes {
|
||||
chunk = b[:maxLineRawBytes]
|
||||
}
|
||||
b = b[len(chunk):]
|
||||
|
||||
encLen := base64.StdEncoding.EncodedLen(len(chunk))
|
||||
if pendingRaw+len(chunk) > maxBatchBytes || buf.Len()+encLen+1 > maxBatchBytes*2 {
|
||||
if err := flushWithRetry(); err != nil {
|
||||
_ = c.closeWithError(fmt.Errorf("poll push flush failed: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
tmp := make([]byte, base64.StdEncoding.EncodedLen(len(chunk)))
|
||||
base64.StdEncoding.Encode(tmp, chunk)
|
||||
buf.Write(tmp)
|
||||
buf.WriteByte('\n')
|
||||
pendingRaw += len(chunk)
|
||||
}
|
||||
default:
|
||||
_ = flushWithRetry()
|
||||
bestEffortCloseWriteSession(c.client, c.finURL, c.headerHost, TunnelModePoll, c.auth)
|
||||
return
|
||||
}
|
||||
}
|
||||
case <-c.closed:
|
||||
_ = flushWithRetry()
|
||||
return
|
||||
@@ -1478,19 +1412,50 @@ func (s *TunnelServer) HandleConn(rawConn net.Conn) (HandleResult, net.Conn, err
|
||||
|
||||
tunnelHeader := strings.ToLower(strings.TrimSpace(req.headers["x-sudoku-tunnel"]))
|
||||
if tunnelHeader == "" {
|
||||
// Not our tunnel; replay full bytes to legacy handler.
|
||||
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
|
||||
prefix = append(prefix, headerBytes...)
|
||||
prefix = append(prefix, buffered...)
|
||||
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
|
||||
}
|
||||
if s.mode == TunnelModeLegacy {
|
||||
if s.passThroughOnReject {
|
||||
// Some CDNs / forward proxies may strip unknown headers. When AuthKey is enabled, we can
|
||||
// safely infer the intended tunnel mode by verifying the Authorization token against
|
||||
// both stream/poll modes and picking the one that matches.
|
||||
if s.auth != nil {
|
||||
u, err := url.ParseRequestURI(req.target)
|
||||
if err == nil {
|
||||
path, ok := stripPathRoot(s.pathRoot, u.Path)
|
||||
if ok && s.isAllowedBasePath(path) {
|
||||
authVal := req.headers["authorization"]
|
||||
if authVal == "" {
|
||||
authVal = u.Query().Get(tunnelAuthQueryKey)
|
||||
}
|
||||
streamOK := s.auth.verifyValue(authVal, TunnelModeStream, req.method, path, time.Now())
|
||||
pollOK := s.auth.verifyValue(authVal, TunnelModePoll, req.method, path, time.Now())
|
||||
switch {
|
||||
case streamOK && !pollOK:
|
||||
tunnelHeader = string(TunnelModeStream)
|
||||
case pollOK && !streamOK:
|
||||
tunnelHeader = string(TunnelModePoll)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if tunnelHeader == "" {
|
||||
// Not our tunnel; replay full bytes to legacy handler.
|
||||
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
|
||||
prefix = append(prefix, headerBytes...)
|
||||
prefix = append(prefix, buffered...)
|
||||
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
|
||||
}
|
||||
}
|
||||
|
||||
reject := func() (HandleResult, net.Conn, error) {
|
||||
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
|
||||
prefix = append(prefix, headerBytes...)
|
||||
prefix = append(prefix, buffered...)
|
||||
return HandlePassThrough, newRejectedPreBufferedConn(rawConn, prefix), nil
|
||||
}
|
||||
|
||||
if s.mode == TunnelModeLegacy {
|
||||
if s.passThroughOnReject {
|
||||
return reject()
|
||||
}
|
||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
|
||||
_ = rawConn.Close()
|
||||
return HandleDone, nil, nil
|
||||
@@ -1500,10 +1465,7 @@ func (s *TunnelServer) HandleConn(rawConn net.Conn) (HandleResult, net.Conn, err
|
||||
case TunnelModeStream:
|
||||
if s.mode != TunnelModeStream && s.mode != TunnelModeAuto {
|
||||
if s.passThroughOnReject {
|
||||
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
|
||||
prefix = append(prefix, headerBytes...)
|
||||
prefix = append(prefix, buffered...)
|
||||
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
|
||||
return reject()
|
||||
}
|
||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
|
||||
_ = rawConn.Close()
|
||||
@@ -1513,10 +1475,7 @@ func (s *TunnelServer) HandleConn(rawConn net.Conn) (HandleResult, net.Conn, err
|
||||
case TunnelModePoll:
|
||||
if s.mode != TunnelModePoll && s.mode != TunnelModeAuto {
|
||||
if s.passThroughOnReject {
|
||||
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
|
||||
prefix = append(prefix, headerBytes...)
|
||||
prefix = append(prefix, buffered...)
|
||||
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
|
||||
return reject()
|
||||
}
|
||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
|
||||
_ = rawConn.Close()
|
||||
@@ -1525,10 +1484,7 @@ func (s *TunnelServer) HandleConn(rawConn net.Conn) (HandleResult, net.Conn, err
|
||||
return s.handlePoll(rawConn, req, headerBytes, buffered)
|
||||
default:
|
||||
if s.passThroughOnReject {
|
||||
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
|
||||
prefix = append(prefix, headerBytes...)
|
||||
prefix = append(prefix, buffered...)
|
||||
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
|
||||
return reject()
|
||||
}
|
||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
|
||||
_ = rawConn.Close()
|
||||
@@ -1619,13 +1575,52 @@ func readAllBuffered(r *bufio.Reader) []byte {
|
||||
|
||||
type preBufferedConn struct {
|
||||
net.Conn
|
||||
buf []byte
|
||||
buf []byte
|
||||
recorded []byte
|
||||
rejected bool
|
||||
}
|
||||
|
||||
func newPreBufferedConn(conn net.Conn, pre []byte) net.Conn {
|
||||
func (p *preBufferedConn) CloseWrite() error {
|
||||
if p == nil || p.Conn == nil {
|
||||
return nil
|
||||
}
|
||||
if cw, ok := p.Conn.(interface{ CloseWrite() error }); ok {
|
||||
return cw.CloseWrite()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *preBufferedConn) CloseRead() error {
|
||||
if p == nil || p.Conn == nil {
|
||||
return nil
|
||||
}
|
||||
if cr, ok := p.Conn.(interface{ CloseRead() error }); ok {
|
||||
return cr.CloseRead()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newPreBufferedConn(conn net.Conn, pre []byte) *preBufferedConn {
|
||||
cpy := make([]byte, len(pre))
|
||||
copy(cpy, pre)
|
||||
return &preBufferedConn{Conn: conn, buf: cpy}
|
||||
return &preBufferedConn{Conn: conn, buf: cpy, recorded: cpy}
|
||||
}
|
||||
|
||||
func newRejectedPreBufferedConn(conn net.Conn, pre []byte) *preBufferedConn {
|
||||
c := newPreBufferedConn(conn, pre)
|
||||
c.rejected = true
|
||||
return c
|
||||
}
|
||||
|
||||
func (p *preBufferedConn) IsHTTPMaskRejected() bool { return p.rejected }
|
||||
|
||||
func (p *preBufferedConn) GetBufferedAndRecorded() []byte {
|
||||
if len(p.recorded) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]byte, len(p.recorded))
|
||||
copy(out, p.recorded)
|
||||
return out
|
||||
}
|
||||
|
||||
func (p *preBufferedConn) Read(b []byte) (int, error) {
|
||||
@@ -1682,7 +1677,7 @@ func (s *TunnelServer) handleStream(rawConn net.Conn, req *httpRequestHeader, he
|
||||
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
|
||||
prefix = append(prefix, headerBytes...)
|
||||
prefix = append(prefix, buffered...)
|
||||
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
|
||||
return HandlePassThrough, newRejectedPreBufferedConn(rawConn, prefix), nil
|
||||
}
|
||||
_ = writeSimpleHTTPResponse(rawConn, code, body)
|
||||
_ = rawConn.Close()
|
||||
@@ -1699,21 +1694,29 @@ func (s *TunnelServer) handleStream(rawConn net.Conn, req *httpRequestHeader, he
|
||||
if !ok || !s.isAllowedBasePath(path) {
|
||||
return rejectOrReply(http.StatusNotFound, "not found")
|
||||
}
|
||||
if !s.auth.verify(req.headers, TunnelModeStream, req.method, path, time.Now()) {
|
||||
authVal := req.headers["authorization"]
|
||||
if authVal == "" {
|
||||
authVal = u.Query().Get(tunnelAuthQueryKey)
|
||||
}
|
||||
if !s.auth.verifyValue(authVal, TunnelModeStream, req.method, path, time.Now()) {
|
||||
return rejectOrReply(http.StatusNotFound, "not found")
|
||||
}
|
||||
|
||||
token := u.Query().Get("token")
|
||||
closeFlag := u.Query().Get("close") == "1"
|
||||
finFlag := u.Query().Get("fin") == "1"
|
||||
|
||||
switch strings.ToUpper(req.method) {
|
||||
case http.MethodGet:
|
||||
// Stream split-session: GET /session (no token) => token + start tunnel on a server-side pipe.
|
||||
if token == "" && path == "/session" {
|
||||
return s.authorizeSession(rawConn)
|
||||
return s.sessionAuthorize(rawConn)
|
||||
}
|
||||
// Stream split-session: GET /stream?token=... => downlink poll.
|
||||
if token != "" && path == "/stream" {
|
||||
if s.passThroughOnReject && !s.sessionHas(token) {
|
||||
return rejectOrReply(http.StatusNotFound, "not found")
|
||||
}
|
||||
return s.streamPull(rawConn, token)
|
||||
}
|
||||
return rejectOrReply(http.StatusBadRequest, "bad request")
|
||||
@@ -1721,13 +1724,26 @@ func (s *TunnelServer) handleStream(rawConn net.Conn, req *httpRequestHeader, he
|
||||
case http.MethodPost:
|
||||
// Stream split-session: POST /api/v1/upload?token=... => uplink push.
|
||||
if token != "" && path == "/api/v1/upload" {
|
||||
if s.passThroughOnReject && !s.sessionHas(token) {
|
||||
return rejectOrReply(http.StatusNotFound, "not found")
|
||||
}
|
||||
if closeFlag {
|
||||
s.closeSession(token)
|
||||
return rejectOrReply(http.StatusOK, "")
|
||||
s.sessionClose(token)
|
||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusOK, "")
|
||||
_ = rawConn.Close()
|
||||
return HandleDone, nil, nil
|
||||
}
|
||||
if finFlag {
|
||||
s.sessionCloseWrite(token)
|
||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusOK, "")
|
||||
_ = rawConn.Close()
|
||||
return HandleDone, nil, nil
|
||||
}
|
||||
bodyReader, err := newRequestBodyReader(newPreBufferedConn(rawConn, buffered), req.headers)
|
||||
if err != nil {
|
||||
return rejectOrReply(http.StatusBadRequest, "bad request")
|
||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request")
|
||||
_ = rawConn.Close()
|
||||
return HandleDone, nil, nil
|
||||
}
|
||||
return s.streamPush(rawConn, token, bodyReader)
|
||||
}
|
||||
@@ -1825,7 +1841,7 @@ func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, head
|
||||
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
|
||||
prefix = append(prefix, headerBytes...)
|
||||
prefix = append(prefix, buffered...)
|
||||
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
|
||||
return HandlePassThrough, newRejectedPreBufferedConn(rawConn, prefix), nil
|
||||
}
|
||||
_ = writeSimpleHTTPResponse(rawConn, code, body)
|
||||
_ = rawConn.Close()
|
||||
@@ -1841,18 +1857,26 @@ func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, head
|
||||
if !ok || !s.isAllowedBasePath(path) {
|
||||
return rejectOrReply(http.StatusNotFound, "not found")
|
||||
}
|
||||
if !s.auth.verify(req.headers, TunnelModePoll, req.method, path, time.Now()) {
|
||||
authVal := req.headers["authorization"]
|
||||
if authVal == "" {
|
||||
authVal = u.Query().Get(tunnelAuthQueryKey)
|
||||
}
|
||||
if !s.auth.verifyValue(authVal, TunnelModePoll, req.method, path, time.Now()) {
|
||||
return rejectOrReply(http.StatusNotFound, "not found")
|
||||
}
|
||||
|
||||
token := u.Query().Get("token")
|
||||
closeFlag := u.Query().Get("close") == "1"
|
||||
finFlag := u.Query().Get("fin") == "1"
|
||||
switch strings.ToUpper(req.method) {
|
||||
case http.MethodGet:
|
||||
if token == "" && path == "/session" {
|
||||
return s.authorizeSession(rawConn)
|
||||
return s.sessionAuthorize(rawConn)
|
||||
}
|
||||
if token != "" && path == "/stream" {
|
||||
if s.passThroughOnReject && !s.sessionHas(token) {
|
||||
return rejectOrReply(http.StatusNotFound, "not found")
|
||||
}
|
||||
return s.pollPull(rawConn, token)
|
||||
}
|
||||
return rejectOrReply(http.StatusBadRequest, "bad request")
|
||||
@@ -1860,13 +1884,26 @@ func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, head
|
||||
if token == "" || path != "/api/v1/upload" {
|
||||
return rejectOrReply(http.StatusBadRequest, "bad request")
|
||||
}
|
||||
if s.passThroughOnReject && !s.sessionHas(token) {
|
||||
return rejectOrReply(http.StatusNotFound, "not found")
|
||||
}
|
||||
if closeFlag {
|
||||
s.closeSession(token)
|
||||
return rejectOrReply(http.StatusOK, "")
|
||||
s.sessionClose(token)
|
||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusOK, "")
|
||||
_ = rawConn.Close()
|
||||
return HandleDone, nil, nil
|
||||
}
|
||||
if finFlag {
|
||||
s.sessionCloseWrite(token)
|
||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusOK, "")
|
||||
_ = rawConn.Close()
|
||||
return HandleDone, nil, nil
|
||||
}
|
||||
bodyReader, err := newRequestBodyReader(newPreBufferedConn(rawConn, buffered), req.headers)
|
||||
if err != nil {
|
||||
return rejectOrReply(http.StatusBadRequest, "bad request")
|
||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request")
|
||||
_ = rawConn.Close()
|
||||
return HandleDone, nil, nil
|
||||
}
|
||||
return s.pollPush(rawConn, token, bodyReader)
|
||||
default:
|
||||
@@ -1874,7 +1911,7 @@ func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, head
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TunnelServer) authorizeSession(rawConn net.Conn) (HandleResult, net.Conn, error) {
|
||||
func (s *TunnelServer) sessionAuthorize(rawConn net.Conn) (HandleResult, net.Conn, error) {
|
||||
token, err := newSessionToken()
|
||||
if err != nil {
|
||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusInternalServerError, "internal error")
|
||||
@@ -1882,13 +1919,13 @@ func (s *TunnelServer) authorizeSession(rawConn net.Conn) (HandleResult, net.Con
|
||||
return HandleDone, nil, nil
|
||||
}
|
||||
|
||||
c1, c2 := net.Pipe()
|
||||
c1, c2 := newHalfPipe()
|
||||
|
||||
s.mu.Lock()
|
||||
s.sessions[token] = &tunnelSession{conn: c2, lastActive: time.Now()}
|
||||
s.mu.Unlock()
|
||||
|
||||
go s.reapSessionLater(token)
|
||||
go s.reapLater(token)
|
||||
|
||||
_ = writeTokenHTTPResponse(rawConn, token)
|
||||
_ = rawConn.Close()
|
||||
@@ -1903,31 +1940,50 @@ func newSessionToken() (string, error) {
|
||||
return base64.RawURLEncoding.EncodeToString(b[:]), nil
|
||||
}
|
||||
|
||||
func (s *TunnelServer) reapSessionLater(token string) {
|
||||
func (s *TunnelServer) reapLater(token string) {
|
||||
ttl := s.sessionTTL
|
||||
if ttl <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
timer := time.NewTimer(ttl)
|
||||
defer timer.Stop()
|
||||
<-timer.C
|
||||
|
||||
s.mu.Lock()
|
||||
sess, ok := s.sessions[token]
|
||||
if !ok {
|
||||
for {
|
||||
<-timer.C
|
||||
|
||||
s.mu.Lock()
|
||||
sess, ok := s.sessions[token]
|
||||
if !ok {
|
||||
s.mu.Unlock()
|
||||
return
|
||||
}
|
||||
idle := time.Since(sess.lastActive)
|
||||
if idle >= ttl {
|
||||
delete(s.sessions, token)
|
||||
s.mu.Unlock()
|
||||
_ = sess.conn.Close()
|
||||
return
|
||||
}
|
||||
next := ttl - idle
|
||||
s.mu.Unlock()
|
||||
return
|
||||
|
||||
// Avoid a tight loop under high-frequency activity; we only need best-effort cleanup.
|
||||
if next < 50*time.Millisecond {
|
||||
next = 50 * time.Millisecond
|
||||
}
|
||||
timer.Reset(next)
|
||||
}
|
||||
if time.Since(sess.lastActive) < ttl {
|
||||
s.mu.Unlock()
|
||||
return
|
||||
}
|
||||
delete(s.sessions, token)
|
||||
s.mu.Unlock()
|
||||
_ = sess.conn.Close()
|
||||
}
|
||||
|
||||
func (s *TunnelServer) getSession(token string) (*tunnelSession, bool) {
|
||||
func (s *TunnelServer) sessionHas(token string) bool {
|
||||
s.mu.Lock()
|
||||
_, ok := s.sessions[token]
|
||||
s.mu.Unlock()
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s *TunnelServer) sessionGet(token string) (*tunnelSession, bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
sess, ok := s.sessions[token]
|
||||
@@ -1938,7 +1994,7 @@ func (s *TunnelServer) getSession(token string) (*tunnelSession, bool) {
|
||||
return sess, true
|
||||
}
|
||||
|
||||
func (s *TunnelServer) closeSession(token string) {
|
||||
func (s *TunnelServer) sessionClose(token string) {
|
||||
s.mu.Lock()
|
||||
sess, ok := s.sessions[token]
|
||||
if ok {
|
||||
@@ -1950,8 +2006,20 @@ func (s *TunnelServer) closeSession(token string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TunnelServer) sessionCloseWrite(token string) {
|
||||
sess, ok := s.sessionGet(token)
|
||||
if !ok || sess == nil || sess.conn == nil {
|
||||
return
|
||||
}
|
||||
if cw, ok := sess.conn.(interface{ CloseWrite() error }); ok {
|
||||
_ = cw.CloseWrite()
|
||||
return
|
||||
}
|
||||
_ = sess.conn.Close()
|
||||
}
|
||||
|
||||
func (s *TunnelServer) pollPush(rawConn net.Conn, token string, body io.Reader) (HandleResult, net.Conn, error) {
|
||||
sess, ok := s.getSession(token)
|
||||
sess, ok := s.sessionGet(token)
|
||||
if !ok {
|
||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden")
|
||||
_ = rawConn.Close()
|
||||
@@ -1985,7 +2053,7 @@ func (s *TunnelServer) pollPush(rawConn net.Conn, token string, body io.Reader)
|
||||
_, werr := sess.conn.Write(decoded[:n])
|
||||
_ = sess.conn.SetWriteDeadline(time.Time{})
|
||||
if werr != nil {
|
||||
s.closeSession(token)
|
||||
s.sessionClose(token)
|
||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusGone, "gone")
|
||||
_ = rawConn.Close()
|
||||
return HandleDone, nil, nil
|
||||
@@ -1998,7 +2066,7 @@ func (s *TunnelServer) pollPush(rawConn net.Conn, token string, body io.Reader)
|
||||
}
|
||||
|
||||
func (s *TunnelServer) streamPush(rawConn net.Conn, token string, body io.Reader) (HandleResult, net.Conn, error) {
|
||||
sess, ok := s.getSession(token)
|
||||
sess, ok := s.sessionGet(token)
|
||||
if !ok {
|
||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden")
|
||||
_ = rawConn.Close()
|
||||
@@ -2023,7 +2091,7 @@ func (s *TunnelServer) streamPush(rawConn net.Conn, token string, body io.Reader
|
||||
_, werr := sess.conn.Write(payload)
|
||||
_ = sess.conn.SetWriteDeadline(time.Time{})
|
||||
if werr != nil {
|
||||
s.closeSession(token)
|
||||
s.sessionClose(token)
|
||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusGone, "gone")
|
||||
_ = rawConn.Close()
|
||||
return HandleDone, nil, nil
|
||||
@@ -2036,7 +2104,7 @@ func (s *TunnelServer) streamPush(rawConn net.Conn, token string, body io.Reader
|
||||
}
|
||||
|
||||
func (s *TunnelServer) streamPull(rawConn net.Conn, token string) (HandleResult, net.Conn, error) {
|
||||
sess, ok := s.getSession(token)
|
||||
sess, ok := s.sessionGet(token)
|
||||
if !ok {
|
||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden")
|
||||
_ = rawConn.Close()
|
||||
@@ -2074,14 +2142,14 @@ func (s *TunnelServer) streamPull(rawConn net.Conn, token string) (HandleResult,
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, net.ErrClosed) {
|
||||
return HandleDone, nil, nil
|
||||
}
|
||||
s.closeSession(token)
|
||||
s.sessionClose(token)
|
||||
return HandleDone, nil, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TunnelServer) pollPull(rawConn net.Conn, token string) (HandleResult, net.Conn, error) {
|
||||
sess, ok := s.getSession(token)
|
||||
sess, ok := s.sessionGet(token)
|
||||
if !ok {
|
||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden")
|
||||
_ = rawConn.Close()
|
||||
@@ -2123,7 +2191,7 @@ func (s *TunnelServer) pollPull(rawConn net.Conn, token string) (HandleResult, n
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, net.ErrClosed) {
|
||||
return HandleDone, nil, nil
|
||||
}
|
||||
s.closeSession(token)
|
||||
s.sessionClose(token)
|
||||
return HandleDone, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,8 +52,28 @@ type Conn struct {
|
||||
pendingData []byte
|
||||
hintBuf []byte
|
||||
|
||||
rng *rand.Rand
|
||||
paddingRate float32
|
||||
rng *rand.Rand
|
||||
paddingThreshold uint64
|
||||
}
|
||||
|
||||
func (sc *Conn) CloseWrite() error {
|
||||
if sc == nil || sc.Conn == nil {
|
||||
return nil
|
||||
}
|
||||
if cw, ok := sc.Conn.(interface{ CloseWrite() error }); ok {
|
||||
return cw.CloseWrite()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sc *Conn) CloseRead() error {
|
||||
if sc == nil || sc.Conn == nil {
|
||||
return nil
|
||||
}
|
||||
if cr, ok := sc.Conn.(interface{ CloseRead() error }); ok {
|
||||
return cr.CloseRead()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewConn(c net.Conn, table *Table, pMin, pMax int, record bool) *Conn {
|
||||
@@ -64,19 +84,15 @@ func NewConn(c net.Conn, table *Table, pMin, pMax int, record bool) *Conn {
|
||||
seed := int64(binary.BigEndian.Uint64(seedBytes[:]))
|
||||
localRng := rand.New(rand.NewSource(seed))
|
||||
|
||||
min := float32(pMin) / 100.0
|
||||
rng := float32(pMax-pMin) / 100.0
|
||||
rate := min + localRng.Float32()*rng
|
||||
|
||||
sc := &Conn{
|
||||
Conn: c,
|
||||
table: table,
|
||||
reader: bufio.NewReaderSize(c, IOBufferSize),
|
||||
rawBuf: make([]byte, IOBufferSize),
|
||||
pendingData: make([]byte, 0, 4096),
|
||||
hintBuf: make([]byte, 0, 4),
|
||||
rng: localRng,
|
||||
paddingRate: rate,
|
||||
Conn: c,
|
||||
table: table,
|
||||
reader: bufio.NewReaderSize(c, IOBufferSize),
|
||||
rawBuf: make([]byte, IOBufferSize),
|
||||
pendingData: make([]byte, 0, 4096),
|
||||
hintBuf: make([]byte, 0, 4),
|
||||
rng: localRng,
|
||||
paddingThreshold: pickPaddingThreshold(localRng, pMin, pMax),
|
||||
}
|
||||
if record {
|
||||
sc.recorder = new(bytes.Buffer)
|
||||
@@ -127,7 +143,7 @@ func (sc *Conn) Write(p []byte) (n int, err error) {
|
||||
padLen := len(pads)
|
||||
|
||||
for _, b := range p {
|
||||
if sc.rng.Float32() < sc.paddingRate {
|
||||
if shouldPad(sc.rng, sc.paddingThreshold) {
|
||||
out = append(out, pads[sc.rng.Intn(padLen)])
|
||||
}
|
||||
|
||||
@@ -136,14 +152,14 @@ func (sc *Conn) Write(p []byte) (n int, err error) {
|
||||
|
||||
perm := perm4[sc.rng.Intn(len(perm4))]
|
||||
for _, idx := range perm {
|
||||
if sc.rng.Float32() < sc.paddingRate {
|
||||
if shouldPad(sc.rng, sc.paddingThreshold) {
|
||||
out = append(out, pads[sc.rng.Intn(padLen)])
|
||||
}
|
||||
out = append(out, puzzle[idx])
|
||||
}
|
||||
}
|
||||
|
||||
if sc.rng.Float32() < sc.paddingRate {
|
||||
if shouldPad(sc.rng, sc.paddingThreshold) {
|
||||
out = append(out, pads[sc.rng.Intn(padLen)])
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ const (
|
||||
)
|
||||
|
||||
// 1. 使用 12字节->16组 的块处理优化 Write (减少循环开销)
|
||||
// 2. 使用浮点随机概率判断 Padding,与纯 Sudoku 保持流量特征一致
|
||||
// 2. 使用整数阈值随机概率判断 Padding,与纯 Sudoku 保持流量特征一致
|
||||
// 3. Read 使用 copy 移动避免底层数组泄漏
|
||||
type PackedConn struct {
|
||||
net.Conn
|
||||
@@ -37,11 +37,31 @@ type PackedConn struct {
|
||||
readBitBuf uint64
|
||||
readBits int
|
||||
|
||||
// 随机数与填充控制 - 使用浮点随机,与 Conn 一致
|
||||
rng *rand.Rand
|
||||
paddingRate float32 // 与 Conn 保持一致的随机概率模型
|
||||
padMarker byte
|
||||
padPool []byte
|
||||
// 随机数与填充控制 - 使用整数阈值随机,与 Conn 一致
|
||||
rng *rand.Rand
|
||||
paddingThreshold uint64 // 与 Conn 保持一致的随机概率模型
|
||||
padMarker byte
|
||||
padPool []byte
|
||||
}
|
||||
|
||||
func (pc *PackedConn) CloseWrite() error {
|
||||
if pc == nil || pc.Conn == nil {
|
||||
return nil
|
||||
}
|
||||
if cw, ok := pc.Conn.(interface{ CloseWrite() error }); ok {
|
||||
return cw.CloseWrite()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pc *PackedConn) CloseRead() error {
|
||||
if pc == nil || pc.Conn == nil {
|
||||
return nil
|
||||
}
|
||||
if cr, ok := pc.Conn.(interface{ CloseRead() error }); ok {
|
||||
return cr.CloseRead()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn {
|
||||
@@ -52,20 +72,15 @@ func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn {
|
||||
seed := int64(binary.BigEndian.Uint64(seedBytes[:]))
|
||||
localRng := rand.New(rand.NewSource(seed))
|
||||
|
||||
// 与 Conn 保持一致的 padding 概率计算
|
||||
min := float32(pMin) / 100.0
|
||||
rng := float32(pMax-pMin) / 100.0
|
||||
rate := min + localRng.Float32()*rng
|
||||
|
||||
pc := &PackedConn{
|
||||
Conn: c,
|
||||
table: table,
|
||||
reader: bufio.NewReaderSize(c, IOBufferSize),
|
||||
rawBuf: make([]byte, IOBufferSize),
|
||||
pendingData: make([]byte, 0, 4096),
|
||||
writeBuf: make([]byte, 0, 4096),
|
||||
rng: localRng,
|
||||
paddingRate: rate,
|
||||
Conn: c,
|
||||
table: table,
|
||||
reader: bufio.NewReaderSize(c, IOBufferSize),
|
||||
rawBuf: make([]byte, IOBufferSize),
|
||||
pendingData: make([]byte, 0, 4096),
|
||||
writeBuf: make([]byte, 0, 4096),
|
||||
rng: localRng,
|
||||
paddingThreshold: pickPaddingThreshold(localRng, pMin, pMax),
|
||||
}
|
||||
|
||||
pc.padMarker = table.layout.padMarker
|
||||
@@ -80,9 +95,9 @@ func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn {
|
||||
return pc
|
||||
}
|
||||
|
||||
// maybeAddPadding 内联辅助:根据浮点概率插入 padding
|
||||
// maybeAddPadding 内联辅助:根据概率阈值插入 padding
|
||||
func (pc *PackedConn) maybeAddPadding(out []byte) []byte {
|
||||
if pc.rng.Float32() < pc.paddingRate {
|
||||
if shouldPad(pc.rng, pc.paddingThreshold) {
|
||||
out = append(out, pc.getPaddingByte())
|
||||
}
|
||||
return out
|
||||
|
||||
44
transport/sudoku/obfs/sudoku/padding_prob.go
Normal file
44
transport/sudoku/obfs/sudoku/padding_prob.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package sudoku
|
||||
|
||||
import "math/rand"
|
||||
|
||||
const probOne = uint64(1) << 32
|
||||
|
||||
func pickPaddingThreshold(r *rand.Rand, pMin, pMax int) uint64 {
|
||||
if r == nil {
|
||||
return 0
|
||||
}
|
||||
if pMin < 0 {
|
||||
pMin = 0
|
||||
}
|
||||
if pMax < pMin {
|
||||
pMax = pMin
|
||||
}
|
||||
if pMax > 100 {
|
||||
pMax = 100
|
||||
}
|
||||
if pMin > 100 {
|
||||
pMin = 100
|
||||
}
|
||||
|
||||
min := uint64(pMin) * probOne / 100
|
||||
max := uint64(pMax) * probOne / 100
|
||||
if max <= min {
|
||||
return min
|
||||
}
|
||||
u := uint64(r.Uint32())
|
||||
return min + (u * (max - min) >> 32)
|
||||
}
|
||||
|
||||
func shouldPad(r *rand.Rand, threshold uint64) bool {
|
||||
if threshold == 0 {
|
||||
return false
|
||||
}
|
||||
if threshold >= probOne {
|
||||
return true
|
||||
}
|
||||
if r == nil {
|
||||
return false
|
||||
}
|
||||
return uint64(r.Uint32()) < threshold
|
||||
}
|
||||
Reference in New Issue
Block a user