mirror of
https://github.com/MetaCubeX/mihomo.git
synced 2026-02-26 16:57:08 +00:00
chore: align sudoku with upstream v0.2.0 (#2549)
This commit is contained in:
@@ -7,7 +7,6 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
N "github.com/metacubex/mihomo/common/net"
|
N "github.com/metacubex/mihomo/common/net"
|
||||||
C "github.com/metacubex/mihomo/constant"
|
C "github.com/metacubex/mihomo/constant"
|
||||||
@@ -22,10 +21,8 @@ type Sudoku struct {
|
|||||||
httpMaskMu sync.Mutex
|
httpMaskMu sync.Mutex
|
||||||
httpMaskClient *sudoku.HTTPMaskTunnelClient
|
httpMaskClient *sudoku.HTTPMaskTunnelClient
|
||||||
|
|
||||||
muxMu sync.Mutex
|
muxMu sync.Mutex
|
||||||
muxClient *sudoku.MultiplexClient
|
muxClient *sudoku.MultiplexClient
|
||||||
muxBackoffUntil time.Time
|
|
||||||
muxLastErr error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type SudokuOption struct {
|
type SudokuOption struct {
|
||||||
@@ -58,7 +55,7 @@ func (s *Sudoku) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Con
|
|||||||
|
|
||||||
muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex)
|
muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex)
|
||||||
if muxMode == "on" && !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) {
|
if muxMode == "on" && !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) {
|
||||||
stream, muxErr := s.dialMultiplex(ctx, cfg.TargetAddress, muxMode)
|
stream, muxErr := s.dialMultiplex(ctx, cfg.TargetAddress)
|
||||||
if muxErr == nil {
|
if muxErr == nil {
|
||||||
return NewConn(stream, s), nil
|
return NewConn(stream, s), nil
|
||||||
}
|
}
|
||||||
@@ -312,9 +309,9 @@ func (s *Sudoku) dialAndHandshake(ctx context.Context, cfg *sudoku.ProtocolConfi
|
|||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Sudoku) dialMultiplex(ctx context.Context, targetAddress string, mode string) (net.Conn, error) {
|
func (s *Sudoku) dialMultiplex(ctx context.Context, targetAddress string) (net.Conn, error) {
|
||||||
for attempt := 0; attempt < 2; attempt++ {
|
for attempt := 0; attempt < 2; attempt++ {
|
||||||
client, err := s.getOrCreateMuxClient(ctx, mode)
|
client, err := s.getOrCreateMuxClient(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -330,21 +327,11 @@ func (s *Sudoku) dialMultiplex(ctx context.Context, targetAddress string, mode s
|
|||||||
return nil, fmt.Errorf("multiplex open stream failed")
|
return nil, fmt.Errorf("multiplex open stream failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Sudoku) getOrCreateMuxClient(ctx context.Context, mode string) (*sudoku.MultiplexClient, error) {
|
func (s *Sudoku) getOrCreateMuxClient(ctx context.Context) (*sudoku.MultiplexClient, error) {
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return nil, fmt.Errorf("nil adapter")
|
return nil, fmt.Errorf("nil adapter")
|
||||||
}
|
}
|
||||||
|
|
||||||
if mode == "auto" {
|
|
||||||
s.muxMu.Lock()
|
|
||||||
backoffUntil := s.muxBackoffUntil
|
|
||||||
lastErr := s.muxLastErr
|
|
||||||
s.muxMu.Unlock()
|
|
||||||
if time.Now().Before(backoffUntil) {
|
|
||||||
return nil, fmt.Errorf("multiplex temporarily disabled: %v", lastErr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.muxMu.Lock()
|
s.muxMu.Lock()
|
||||||
if s.muxClient != nil && !s.muxClient.IsClosed() {
|
if s.muxClient != nil && !s.muxClient.IsClosed() {
|
||||||
client := s.muxClient
|
client := s.muxClient
|
||||||
@@ -363,20 +350,12 @@ func (s *Sudoku) getOrCreateMuxClient(ctx context.Context, mode string) (*sudoku
|
|||||||
baseCfg := s.baseConf
|
baseCfg := s.baseConf
|
||||||
baseConn, err := s.dialAndHandshake(ctx, &baseCfg)
|
baseConn, err := s.dialAndHandshake(ctx, &baseCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if mode == "auto" {
|
|
||||||
s.muxLastErr = err
|
|
||||||
s.muxBackoffUntil = time.Now().Add(45 * time.Second)
|
|
||||||
}
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := sudoku.StartMultiplexClient(baseConn)
|
client, err := sudoku.StartMultiplexClient(baseConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = baseConn.Close()
|
_ = baseConn.Close()
|
||||||
if mode == "auto" {
|
|
||||||
s.muxLastErr = err
|
|
||||||
s.muxBackoffUntil = time.Now().Add(45 * time.Second)
|
|
||||||
}
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -384,16 +363,6 @@ func (s *Sudoku) getOrCreateMuxClient(ctx context.Context, mode string) (*sudoku
|
|||||||
return client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Sudoku) noteMuxFailure(mode string, err error) {
|
|
||||||
if mode != "auto" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.muxMu.Lock()
|
|
||||||
s.muxLastErr = err
|
|
||||||
s.muxBackoffUntil = time.Now().Add(45 * time.Second)
|
|
||||||
s.muxMu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sudoku) resetMuxClient() {
|
func (s *Sudoku) resetMuxClient() {
|
||||||
s.muxMu.Lock()
|
s.muxMu.Lock()
|
||||||
defer s.muxMu.Unlock()
|
defer s.muxMu.Unlock()
|
||||||
|
|||||||
@@ -1082,19 +1082,19 @@ proxies: # socks5
|
|||||||
server: server_ip/domain # 1.2.3.4 or domain
|
server: server_ip/domain # 1.2.3.4 or domain
|
||||||
port: 443
|
port: 443
|
||||||
key: "<client_key>" # 如果你使用sudoku生成的ED25519密钥对,请填写密钥对中的私钥,否则填入和服务端相同的uuid
|
key: "<client_key>" # 如果你使用sudoku生成的ED25519密钥对,请填写密钥对中的私钥,否则填入和服务端相同的uuid
|
||||||
aead-method: chacha20-poly1305 # 可选值:chacha20-poly1305、aes-128-gcm、none 我们保证在none的情况下sudoku混淆层仍然确保安全
|
aead-method: chacha20-poly1305 # 可选:chacha20-poly1305、aes-128-gcm、none(不建议;且 enable-pure-downlink=false 时不可用)
|
||||||
padding-min: 2 # 最小填充字节数
|
padding-min: 2 # 最小填充率(0-100)
|
||||||
padding-max: 7 # 最大填充字节数
|
padding-max: 7 # 最大填充率(0-100,必须 >= padding-min)
|
||||||
table-type: prefer_ascii # 可选值:prefer_ascii、prefer_entropy 前者全ascii映射,后者保证熵值(汉明1)低于3
|
table-type: prefer_ascii # 可选值:prefer_ascii、prefer_entropy 前者全ascii映射,后者保证熵值(汉明1)低于3
|
||||||
# custom-table: xpxvvpvv # 可选,自定义字节布局,必须包含2个x、2个p、4个v,可随意组合。启用此处则需配置`table-type`为`prefer_entropy`
|
# custom-table: xpxvvpvv # 可选,自定义字节布局,必须包含2个x、2个p、4个v,可随意组合。启用此处则需配置`table-type`为`prefer_entropy`
|
||||||
# custom-tables: ["xpxvvpvv", "vxpvxvvp"] # 可选,自定义字节布局列表(x/v/p),用于 xvp 模式轮换;非空时覆盖 custom-table
|
# custom-tables: ["xpxvvpvv", "vxpvxvvp"] # 可选,自定义字节布局列表(x/v/p),用于 xvp 模式轮换;非空时覆盖 custom-table
|
||||||
http-mask: true # 是否启用http掩码
|
http-mask: true # 是否启用http掩码
|
||||||
# http-mask-mode: legacy # 可选:legacy(默认)、stream、poll、auto;stream/poll/auto 支持走 CDN/反代
|
# http-mask-mode: legacy # 可选:legacy(默认)、stream(split-stream)、poll、auto(先 stream 再 poll);stream/poll/auto 支持走 CDN/反代
|
||||||
# http-mask-tls: true # 可选:仅在 http-mask-mode 为 stream/poll/auto 时生效;true 强制 https;false 强制 http(不会根据端口自动推断)
|
# http-mask-tls: true # 可选:仅在 http-mask-mode 为 stream/poll/auto 时生效;true 强制 https;false 强制 http(不会根据端口自动推断)
|
||||||
# http-mask-host: "" # 可选:覆盖 Host/SNI(支持 example.com 或 example.com:443);仅在 http-mask-mode 为 stream/poll/auto 时生效
|
# http-mask-host: "" # 可选:覆盖 Host/SNI(支持 example.com 或 example.com:443);仅在 http-mask-mode 为 stream/poll/auto 时生效
|
||||||
# path-root: "" # 可选:HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload
|
# path-root: "" # 可选:HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" 或 "/aabbcc/" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload
|
||||||
# http-mask-multiplex: off # 可选:off(默认)、auto(复用 h1.1 keep-alive / h2 连接,减少每次建链 RTT)、on(单条隧道内多路复用多个目标连接;仅在 http-mask-mode=stream/poll/auto 生效)
|
# http-mask-multiplex: off # 可选:off(默认)、auto(复用底层 HTTP 连接,减少建链 RTT)、on(Sudoku mux 单隧道多目标;仅在 http-mask-mode=stream/poll/auto 生效)
|
||||||
enable-pure-downlink: false # 是否启用混淆下行,false的情况下能在保证数据安全的前提下极大提升下行速度,与服务端端保持相同(如果此处为false,则要求aead不可为none)
|
enable-pure-downlink: false # 可选:false=带宽优化下行(更快,要求 aead-method != none);true=纯 Sudoku 下行
|
||||||
|
|
||||||
# anytls
|
# anytls
|
||||||
- name: anytls
|
- name: anytls
|
||||||
@@ -1632,17 +1632,17 @@ listeners:
|
|||||||
port: 8443 # 仅支持单端口
|
port: 8443 # 仅支持单端口
|
||||||
listen: 0.0.0.0
|
listen: 0.0.0.0
|
||||||
key: "<server_key>" # 如果你使用sudoku生成的ED25519密钥对,此处是密钥对中的公钥,当然,你也可以仅仅使用任意uuid充当key
|
key: "<server_key>" # 如果你使用sudoku生成的ED25519密钥对,此处是密钥对中的公钥,当然,你也可以仅仅使用任意uuid充当key
|
||||||
aead-method: chacha20-poly1305 # 支持chacha20-poly1305或者aes-128-gcm以及none,sudoku的混淆层可以确保none情况下数据安全
|
aead-method: chacha20-poly1305 # 可选:chacha20-poly1305、aes-128-gcm、none(不建议;且 enable-pure-downlink=false 时不可用)
|
||||||
padding-min: 1 # 填充最小长度
|
padding-min: 1 # 最小填充率(0-100)
|
||||||
padding-max: 15 # 填充最大长度,均不建议过大
|
padding-max: 15 # 最大填充率(0-100,必须 >= padding-min)
|
||||||
table-type: prefer_ascii # 可选值:prefer_ascii、prefer_entropy 前者全ascii映射,后者保证熵值(汉明1)低于3
|
table-type: prefer_ascii # 可选值:prefer_ascii、prefer_entropy 前者全ascii映射,后者保证熵值(汉明1)低于3
|
||||||
# custom-table: xpxvvpvv # 可选,自定义字节布局,必须包含2个x、2个p、4个v,可随意组合。启用此处则需配置`table-type`为`prefer_entropy`
|
# custom-table: xpxvvpvv # 可选,自定义字节布局,必须包含2个x、2个p、4个v,可随意组合。启用此处则需配置`table-type`为`prefer_entropy`
|
||||||
# custom-tables: ["xpxvvpvv", "vxpvxvvp"] # 可选,自定义字节布局列表(x/v/p),用于 xvp 模式轮换;非空时覆盖 custom-table
|
# custom-tables: ["xpxvvpvv", "vxpvxvvp"] # 可选,自定义字节布局列表(x/v/p),用于 xvp 模式轮换;非空时覆盖 custom-table
|
||||||
handshake-timeout: 5 # optional
|
handshake-timeout: 5 # 可选(秒)
|
||||||
enable-pure-downlink: false # 是否启用混淆下行,false的情况下能在保证数据安全的前提下极大提升下行速度,与客户端保持相同(如果此处为false,则要求aead不可为none)
|
enable-pure-downlink: false # 可选:false=带宽优化下行(更快,要求 aead-method != none);true=纯 Sudoku 下行
|
||||||
disable-http-mask: false # 可选:禁用 http 掩码/隧道(默认为 false)
|
disable-http-mask: false # 可选:禁用 http 掩码/隧道(默认为 false)
|
||||||
# http-mask-mode: legacy # 可选:legacy(默认)、stream、poll、auto;stream/poll/auto 支持走 CDN/反代
|
# http-mask-mode: legacy # 可选:legacy(默认)、stream(split-stream)、poll、auto(先 stream 再 poll);stream/poll/auto 支持走 CDN/反代
|
||||||
# path-root: "" # 可选:HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload
|
# path-root: "" # 可选:HTTP 隧道端点一级路径前缀(双方需一致),例如 "aabbcc" 或 "/aabbcc/" => /aabbcc/session、/aabbcc/stream、/aabbcc/api/v1/upload
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,8 @@ type ProtocolConfig struct {
|
|||||||
PaddingMin int
|
PaddingMin int
|
||||||
PaddingMax int
|
PaddingMax int
|
||||||
|
|
||||||
// EnablePureDownlink toggles the bandwidth-optimized downlink mode.
|
// EnablePureDownlink enables the pure Sudoku downlink mode.
|
||||||
|
// When false, the connection uses the bandwidth-optimized packed downlink (requires AEAD).
|
||||||
EnablePureDownlink bool
|
EnablePureDownlink bool
|
||||||
|
|
||||||
// Client-only: final target "host:port".
|
// Client-only: final target "host:port".
|
||||||
@@ -46,7 +47,7 @@ type ProtocolConfig struct {
|
|||||||
|
|
||||||
// HTTPMaskMode controls how the HTTP layer behaves:
|
// HTTPMaskMode controls how the HTTP layer behaves:
|
||||||
// - "legacy": write a fake HTTP/1.1 header then switch to raw stream (default, not CDN-compatible)
|
// - "legacy": write a fake HTTP/1.1 header then switch to raw stream (default, not CDN-compatible)
|
||||||
// - "stream": real HTTP tunnel (stream-one or split), CDN-compatible
|
// - "stream": real HTTP tunnel (split-stream), CDN-compatible
|
||||||
// - "poll": plain HTTP tunnel (authorize/push/pull), strong restricted-network pass-through
|
// - "poll": plain HTTP tunnel (authorize/push/pull), strong restricted-network pass-through
|
||||||
// - "auto": try stream then fall back to poll
|
// - "auto": try stream then fall back to poll
|
||||||
HTTPMaskMode string
|
HTTPMaskMode string
|
||||||
@@ -114,7 +115,8 @@ func (c *ProtocolConfig) Validate() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if v := strings.TrimSpace(c.HTTPMaskPathRoot); v != "" {
|
if v := strings.TrimSpace(c.HTTPMaskPathRoot); v != "" {
|
||||||
if strings.Contains(v, "/") {
|
v = strings.Trim(v, "/")
|
||||||
|
if v == "" || strings.Contains(v, "/") {
|
||||||
return fmt.Errorf("invalid http-mask-path-root: must be a single path segment")
|
return fmt.Errorf("invalid http-mask-path-root: must be a single path segment")
|
||||||
}
|
}
|
||||||
for i := 0; i < len(v); i++ {
|
for i := 0; i < len(v); i++ {
|
||||||
|
|||||||
@@ -22,6 +22,26 @@ type AEADConn struct {
|
|||||||
nonceSize int
|
nonceSize int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cc *AEADConn) CloseWrite() error {
|
||||||
|
if cc == nil || cc.Conn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if cw, ok := cc.Conn.(interface{ CloseWrite() error }); ok {
|
||||||
|
return cw.CloseWrite()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cc *AEADConn) CloseRead() error {
|
||||||
|
if cc == nil || cc.Conn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if cr, ok := cc.Conn.(interface{ CloseRead() error }); ok {
|
||||||
|
return cr.CloseRead()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func NewAEADConn(c net.Conn, key string, method string) (*AEADConn, error) {
|
func NewAEADConn(c net.Conn, key string, method string) (*AEADConn, error) {
|
||||||
if method == "none" {
|
if method == "none" {
|
||||||
return &AEADConn{Conn: c, aead: nil}, nil
|
return &AEADConn{Conn: c, aead: nil}, nil
|
||||||
|
|||||||
@@ -383,7 +383,11 @@ func (c *stream) enqueue(payload []byte) {
|
|||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.queue = append(c.queue, payload)
|
if len(c.readBuf) == 0 && len(c.queue) == 0 {
|
||||||
|
c.readBuf = payload
|
||||||
|
} else {
|
||||||
|
c.queue = append(c.queue, payload)
|
||||||
|
}
|
||||||
c.cond.Signal()
|
c.cond.Signal()
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
}
|
}
|
||||||
@@ -491,6 +495,9 @@ func (c *stream) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *stream) CloseWrite() error { return c.Close() }
|
||||||
|
func (c *stream) CloseRead() error { return c.Close() }
|
||||||
|
|
||||||
func (c *stream) LocalAddr() net.Addr { return c.localAddr }
|
func (c *stream) LocalAddr() net.Addr { return c.localAddr }
|
||||||
func (c *stream) RemoteAddr() net.Addr { return c.remoteAddr }
|
func (c *stream) RemoteAddr() net.Addr { return c.remoteAddr }
|
||||||
|
|
||||||
@@ -501,4 +508,3 @@ func (c *stream) SetDeadline(t time.Time) error {
|
|||||||
}
|
}
|
||||||
func (c *stream) SetReadDeadline(time.Time) error { return nil }
|
func (c *stream) SetReadDeadline(time.Time) error { return nil }
|
||||||
func (c *stream) SetWriteDeadline(time.Time) error { return nil }
|
func (c *stream) SetWriteDeadline(time.Time) error { return nil }
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"github.com/metacubex/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -13,6 +14,7 @@ import (
|
|||||||
const (
|
const (
|
||||||
tunnelAuthHeaderKey = "Authorization"
|
tunnelAuthHeaderKey = "Authorization"
|
||||||
tunnelAuthHeaderPrefix = "Bearer "
|
tunnelAuthHeaderPrefix = "Bearer "
|
||||||
|
tunnelAuthQueryKey = "auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
type tunnelAuth struct {
|
type tunnelAuth struct {
|
||||||
@@ -61,8 +63,15 @@ func (a *tunnelAuth) verify(headers map[string]string, mode TunnelMode, method,
|
|||||||
if headers == nil {
|
if headers == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
return a.verifyValue(headers["authorization"], mode, method, path, now)
|
||||||
|
}
|
||||||
|
|
||||||
val := strings.TrimSpace(headers["authorization"])
|
func (a *tunnelAuth) verifyValue(val string, mode TunnelMode, method, path string, now time.Time) bool {
|
||||||
|
if a == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
val = strings.TrimSpace(val)
|
||||||
if val == "" {
|
if val == "" {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -121,11 +130,9 @@ func (a *tunnelAuth) sign(mode TunnelMode, method, path string, ts int64) [16]by
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
type headerSetter interface {
|
type httpHeaderSetter = http.Header
|
||||||
Set(key, value string)
|
|
||||||
}
|
|
||||||
|
|
||||||
func applyTunnelAuthHeader(h headerSetter, auth *tunnelAuth, mode TunnelMode, method, path string) {
|
func applyTunnelAuthHeader(h httpHeaderSetter, auth *tunnelAuth, mode TunnelMode, method, path string) {
|
||||||
if auth == nil || h == nil {
|
if auth == nil || h == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -135,3 +142,19 @@ func applyTunnelAuthHeader(h headerSetter, auth *tunnelAuth, mode TunnelMode, me
|
|||||||
}
|
}
|
||||||
h.Set(tunnelAuthHeaderKey, tunnelAuthHeaderPrefix+token)
|
h.Set(tunnelAuthHeaderKey, tunnelAuthHeaderPrefix+token)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func applyTunnelAuth(req *http.Request, auth *tunnelAuth, mode TunnelMode, method, path string) {
|
||||||
|
if auth == nil || req == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
token := auth.token(mode, method, path, time.Now())
|
||||||
|
if token == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
req.Header.Set(tunnelAuthHeaderKey, tunnelAuthHeaderPrefix+token)
|
||||||
|
if req.URL != nil {
|
||||||
|
q := req.URL.Query()
|
||||||
|
q.Set(tunnelAuthQueryKey, token)
|
||||||
|
req.URL.RawQuery = q.Encode()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
229
transport/sudoku/obfs/httpmask/halfpipe.go
Normal file
229
transport/sudoku/obfs/httpmask/halfpipe.go
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
package httpmask
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type pipeDeadline struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
timer *time.Timer
|
||||||
|
cancel chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func makePipeDeadline() pipeDeadline {
|
||||||
|
return pipeDeadline{cancel: make(chan struct{})}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *pipeDeadline) set(t time.Time) {
|
||||||
|
d.mu.Lock()
|
||||||
|
defer d.mu.Unlock()
|
||||||
|
|
||||||
|
if d.timer != nil && !d.timer.Stop() {
|
||||||
|
<-d.cancel
|
||||||
|
}
|
||||||
|
d.timer = nil
|
||||||
|
|
||||||
|
closed := isClosedPipeChan(d.cancel)
|
||||||
|
if t.IsZero() {
|
||||||
|
if closed {
|
||||||
|
d.cancel = make(chan struct{})
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if dur := time.Until(t); dur > 0 {
|
||||||
|
if closed {
|
||||||
|
d.cancel = make(chan struct{})
|
||||||
|
}
|
||||||
|
d.timer = time.AfterFunc(dur, func() {
|
||||||
|
close(d.cancel)
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !closed {
|
||||||
|
close(d.cancel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *pipeDeadline) wait() <-chan struct{} {
|
||||||
|
d.mu.Lock()
|
||||||
|
ch := d.cancel
|
||||||
|
d.mu.Unlock()
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
|
||||||
|
func isClosedPipeChan(ch <-chan struct{}) bool {
|
||||||
|
select {
|
||||||
|
case <-ch:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type halfPipeAddr struct{}
|
||||||
|
|
||||||
|
func (halfPipeAddr) Network() string { return "pipe" }
|
||||||
|
func (halfPipeAddr) String() string { return "pipe" }
|
||||||
|
|
||||||
|
type halfPipeConn struct {
|
||||||
|
wrMu sync.Mutex
|
||||||
|
|
||||||
|
rdRx <-chan []byte
|
||||||
|
rdTx chan<- int
|
||||||
|
|
||||||
|
wrTx chan<- []byte
|
||||||
|
wrRx <-chan int
|
||||||
|
|
||||||
|
readOnce sync.Once
|
||||||
|
writeOnce sync.Once
|
||||||
|
|
||||||
|
localReadDone chan struct{}
|
||||||
|
localWriteDone chan struct{}
|
||||||
|
|
||||||
|
remoteReadDone <-chan struct{}
|
||||||
|
remoteWriteDone <-chan struct{}
|
||||||
|
|
||||||
|
readDeadline pipeDeadline
|
||||||
|
writeDeadline pipeDeadline
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHalfPipe() (net.Conn, net.Conn) {
|
||||||
|
cb1 := make(chan []byte)
|
||||||
|
cb2 := make(chan []byte)
|
||||||
|
cn1 := make(chan int)
|
||||||
|
cn2 := make(chan int)
|
||||||
|
|
||||||
|
r1 := make(chan struct{})
|
||||||
|
w1 := make(chan struct{})
|
||||||
|
r2 := make(chan struct{})
|
||||||
|
w2 := make(chan struct{})
|
||||||
|
|
||||||
|
c1 := &halfPipeConn{
|
||||||
|
rdRx: cb1,
|
||||||
|
rdTx: cn1,
|
||||||
|
wrTx: cb2,
|
||||||
|
wrRx: cn2,
|
||||||
|
|
||||||
|
localReadDone: r1,
|
||||||
|
localWriteDone: w1,
|
||||||
|
remoteReadDone: r2,
|
||||||
|
remoteWriteDone: w2,
|
||||||
|
|
||||||
|
readDeadline: makePipeDeadline(),
|
||||||
|
writeDeadline: makePipeDeadline(),
|
||||||
|
}
|
||||||
|
c2 := &halfPipeConn{
|
||||||
|
rdRx: cb2,
|
||||||
|
rdTx: cn2,
|
||||||
|
wrTx: cb1,
|
||||||
|
wrRx: cn1,
|
||||||
|
|
||||||
|
localReadDone: r2,
|
||||||
|
localWriteDone: w2,
|
||||||
|
remoteReadDone: r1,
|
||||||
|
remoteWriteDone: w1,
|
||||||
|
|
||||||
|
readDeadline: makePipeDeadline(),
|
||||||
|
writeDeadline: makePipeDeadline(),
|
||||||
|
}
|
||||||
|
return c1, c2
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*halfPipeConn) LocalAddr() net.Addr { return halfPipeAddr{} }
|
||||||
|
func (*halfPipeConn) RemoteAddr() net.Addr { return halfPipeAddr{} }
|
||||||
|
|
||||||
|
func (c *halfPipeConn) Read(p []byte) (int, error) {
|
||||||
|
switch {
|
||||||
|
case isClosedPipeChan(c.localReadDone):
|
||||||
|
return 0, io.ErrClosedPipe
|
||||||
|
case isClosedPipeChan(c.remoteWriteDone):
|
||||||
|
return 0, io.EOF
|
||||||
|
case isClosedPipeChan(c.readDeadline.wait()):
|
||||||
|
return 0, os.ErrDeadlineExceeded
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case b := <-c.rdRx:
|
||||||
|
n := copy(p, b)
|
||||||
|
c.rdTx <- n
|
||||||
|
return n, nil
|
||||||
|
case <-c.localReadDone:
|
||||||
|
return 0, io.ErrClosedPipe
|
||||||
|
case <-c.remoteWriteDone:
|
||||||
|
return 0, io.EOF
|
||||||
|
case <-c.readDeadline.wait():
|
||||||
|
return 0, os.ErrDeadlineExceeded
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *halfPipeConn) Write(p []byte) (int, error) {
|
||||||
|
switch {
|
||||||
|
case isClosedPipeChan(c.localWriteDone):
|
||||||
|
return 0, io.ErrClosedPipe
|
||||||
|
case isClosedPipeChan(c.remoteReadDone):
|
||||||
|
return 0, io.ErrClosedPipe
|
||||||
|
case isClosedPipeChan(c.writeDeadline.wait()):
|
||||||
|
return 0, os.ErrDeadlineExceeded
|
||||||
|
}
|
||||||
|
|
||||||
|
c.wrMu.Lock()
|
||||||
|
defer c.wrMu.Unlock()
|
||||||
|
|
||||||
|
var (
|
||||||
|
total int
|
||||||
|
rest = p
|
||||||
|
)
|
||||||
|
for once := true; once || len(rest) > 0; once = false {
|
||||||
|
select {
|
||||||
|
case c.wrTx <- rest:
|
||||||
|
n := <-c.wrRx
|
||||||
|
rest = rest[n:]
|
||||||
|
total += n
|
||||||
|
case <-c.localWriteDone:
|
||||||
|
return total, io.ErrClosedPipe
|
||||||
|
case <-c.remoteReadDone:
|
||||||
|
return total, io.ErrClosedPipe
|
||||||
|
case <-c.writeDeadline.wait():
|
||||||
|
return total, os.ErrDeadlineExceeded
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *halfPipeConn) CloseWrite() error {
|
||||||
|
c.writeOnce.Do(func() { close(c.localWriteDone) })
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *halfPipeConn) CloseRead() error {
|
||||||
|
c.readOnce.Do(func() { close(c.localReadDone) })
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *halfPipeConn) Close() error {
|
||||||
|
_ = c.CloseRead()
|
||||||
|
_ = c.CloseWrite()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *halfPipeConn) SetDeadline(t time.Time) error {
|
||||||
|
c.readDeadline.set(t)
|
||||||
|
c.writeDeadline.set(t)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *halfPipeConn) SetReadDeadline(t time.Time) error {
|
||||||
|
c.readDeadline.set(t)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *halfPipeConn) SetWriteDeadline(t time.Time) error {
|
||||||
|
c.writeDeadline.set(t)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -241,38 +241,6 @@ func parseTunnelToken(body []byte) (string, error) {
|
|||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type httpStreamConn struct {
|
|
||||||
reader io.ReadCloser
|
|
||||||
writer *io.PipeWriter
|
|
||||||
cancel context.CancelFunc
|
|
||||||
|
|
||||||
localAddr net.Addr
|
|
||||||
remoteAddr net.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *httpStreamConn) Read(p []byte) (int, error) { return c.reader.Read(p) }
|
|
||||||
func (c *httpStreamConn) Write(p []byte) (int, error) { return c.writer.Write(p) }
|
|
||||||
|
|
||||||
func (c *httpStreamConn) Close() error {
|
|
||||||
if c.cancel != nil {
|
|
||||||
c.cancel()
|
|
||||||
}
|
|
||||||
if c.writer != nil {
|
|
||||||
_ = c.writer.CloseWithError(io.ErrClosedPipe)
|
|
||||||
}
|
|
||||||
if c.reader != nil {
|
|
||||||
return c.reader.Close()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *httpStreamConn) LocalAddr() net.Addr { return c.localAddr }
|
|
||||||
func (c *httpStreamConn) RemoteAddr() net.Addr { return c.remoteAddr }
|
|
||||||
|
|
||||||
func (c *httpStreamConn) SetDeadline(time.Time) error { return nil }
|
|
||||||
func (c *httpStreamConn) SetReadDeadline(time.Time) error { return nil }
|
|
||||||
func (c *httpStreamConn) SetWriteDeadline(time.Time) error { return nil }
|
|
||||||
|
|
||||||
type httpClientTarget struct {
|
type httpClientTarget struct {
|
||||||
scheme string
|
scheme string
|
||||||
urlHost string
|
urlHost string
|
||||||
@@ -332,6 +300,7 @@ type sessionDialInfo struct {
|
|||||||
client *http.Client
|
client *http.Client
|
||||||
pushURL string
|
pushURL string
|
||||||
pullURL string
|
pullURL string
|
||||||
|
finURL string
|
||||||
closeURL string
|
closeURL string
|
||||||
headerHost string
|
headerHost string
|
||||||
auth *tunnelAuth
|
auth *tunnelAuth
|
||||||
@@ -350,7 +319,7 @@ func dialSessionWithClient(ctx context.Context, client *http.Client, target http
|
|||||||
}
|
}
|
||||||
req.Host = target.headerHost
|
req.Host = target.headerHost
|
||||||
applyTunnelHeaders(req.Header, target.headerHost, mode)
|
applyTunnelHeaders(req.Header, target.headerHost, mode)
|
||||||
applyTunnelAuthHeader(req.Header, auth, mode, http.MethodGet, "/session")
|
applyTunnelAuth(req, auth, mode, http.MethodGet, "/session")
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -375,12 +344,14 @@ func dialSessionWithClient(ctx context.Context, client *http.Client, target http
|
|||||||
|
|
||||||
pushURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/api/v1/upload"), RawQuery: "token=" + url.QueryEscape(token)}).String()
|
pushURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/api/v1/upload"), RawQuery: "token=" + url.QueryEscape(token)}).String()
|
||||||
pullURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/stream"), RawQuery: "token=" + url.QueryEscape(token)}).String()
|
pullURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/stream"), RawQuery: "token=" + url.QueryEscape(token)}).String()
|
||||||
|
finURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/api/v1/upload"), RawQuery: "token=" + url.QueryEscape(token) + "&fin=1"}).String()
|
||||||
closeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/api/v1/upload"), RawQuery: "token=" + url.QueryEscape(token) + "&close=1"}).String()
|
closeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: joinPathRoot(opts.PathRoot, "/api/v1/upload"), RawQuery: "token=" + url.QueryEscape(token) + "&close=1"}).String()
|
||||||
|
|
||||||
return &sessionDialInfo{
|
return &sessionDialInfo{
|
||||||
client: client,
|
client: client,
|
||||||
pushURL: pushURL,
|
pushURL: pushURL,
|
||||||
pullURL: pullURL,
|
pullURL: pullURL,
|
||||||
|
finURL: finURL,
|
||||||
closeURL: closeURL,
|
closeURL: closeURL,
|
||||||
headerHost: target.headerHost,
|
headerHost: target.headerHost,
|
||||||
auth: auth,
|
auth: auth,
|
||||||
@@ -409,7 +380,31 @@ func bestEffortCloseSession(client *http.Client, closeURL, headerHost string, mo
|
|||||||
}
|
}
|
||||||
req.Host = headerHost
|
req.Host = headerHost
|
||||||
applyTunnelHeaders(req.Header, headerHost, mode)
|
applyTunnelHeaders(req.Header, headerHost, mode)
|
||||||
applyTunnelAuthHeader(req.Header, auth, mode, http.MethodPost, "/api/v1/upload")
|
applyTunnelAuth(req, auth, mode, http.MethodPost, "/api/v1/upload")
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil || resp == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024))
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func bestEffortCloseWriteSession(client *http.Client, finURL, headerHost string, mode TunnelMode, auth *tunnelAuth) {
|
||||||
|
if client == nil || finURL == "" || headerHost == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
closeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(closeCtx, http.MethodPost, finURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
req.Host = headerHost
|
||||||
|
applyTunnelHeaders(req.Header, headerHost, mode)
|
||||||
|
applyTunnelAuth(req, auth, mode, http.MethodPost, "/api/v1/upload")
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil || resp == nil {
|
if err != nil || resp == nil {
|
||||||
@@ -420,160 +415,13 @@ func bestEffortCloseSession(client *http.Client, closeURL, headerHost string, mo
|
|||||||
}
|
}
|
||||||
|
|
||||||
func dialStreamWithClient(ctx context.Context, client *http.Client, target httpClientTarget, opts TunnelDialOptions) (net.Conn, error) {
|
func dialStreamWithClient(ctx context.Context, client *http.Client, target httpClientTarget, opts TunnelDialOptions) (net.Conn, error) {
|
||||||
// Prefer split-session (Cloudflare-friendly). Fall back to stream-one for older servers / environments.
|
// "stream" mode uses split-stream to stay CDN-friendly by default.
|
||||||
c, errSplit := dialStreamSplitWithClient(ctx, client, target, opts)
|
return dialStreamSplitWithClient(ctx, client, target, opts)
|
||||||
if errSplit == nil {
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
c2, errOne := dialStreamOneWithClient(ctx, client, target, opts)
|
|
||||||
if errOne == nil {
|
|
||||||
return c2, nil
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("dial stream failed: split: %v; stream-one: %w", errSplit, errOne)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func dialStream(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
|
func dialStream(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
|
||||||
// Prefer split-session (Cloudflare-friendly). Fall back to stream-one for older servers / environments.
|
// "stream" mode uses split-stream to stay CDN-friendly by default.
|
||||||
c, errSplit := dialStreamSplit(ctx, serverAddress, opts)
|
return dialStreamSplit(ctx, serverAddress, opts)
|
||||||
if errSplit == nil {
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
c2, errOne := dialStreamOne(ctx, serverAddress, opts)
|
|
||||||
if errOne == nil {
|
|
||||||
return c2, nil
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("dial stream failed: split: %v; stream-one: %w", errSplit, errOne)
|
|
||||||
}
|
|
||||||
|
|
||||||
func dialStreamOneWithClient(ctx context.Context, client *http.Client, target httpClientTarget, opts TunnelDialOptions) (net.Conn, error) {
|
|
||||||
if client == nil {
|
|
||||||
return nil, fmt.Errorf("nil http client")
|
|
||||||
}
|
|
||||||
|
|
||||||
auth := newTunnelAuth(opts.AuthKey, 0)
|
|
||||||
r := rngPool.Get().(*mrand.Rand)
|
|
||||||
basePath := paths[r.Intn(len(paths))]
|
|
||||||
path := joinPathRoot(opts.PathRoot, basePath)
|
|
||||||
ctype := contentTypes[r.Intn(len(contentTypes))]
|
|
||||||
rngPool.Put(r)
|
|
||||||
|
|
||||||
u := url.URL{
|
|
||||||
Scheme: target.scheme,
|
|
||||||
Host: target.urlHost,
|
|
||||||
Path: path,
|
|
||||||
}
|
|
||||||
|
|
||||||
reqBodyR, reqBodyW := io.Pipe()
|
|
||||||
|
|
||||||
connCtx, connCancel := context.WithCancel(context.Background())
|
|
||||||
req, err := http.NewRequestWithContext(connCtx, http.MethodPost, u.String(), reqBodyR)
|
|
||||||
if err != nil {
|
|
||||||
connCancel()
|
|
||||||
_ = reqBodyW.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req.Host = target.headerHost
|
|
||||||
|
|
||||||
applyTunnelHeaders(req.Header, target.headerHost, TunnelModeStream)
|
|
||||||
applyTunnelAuthHeader(req.Header, auth, TunnelModeStream, http.MethodPost, basePath)
|
|
||||||
req.Header.Set("Content-Type", ctype)
|
|
||||||
|
|
||||||
type doResult struct {
|
|
||||||
resp *http.Response
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
doCh := make(chan doResult, 1)
|
|
||||||
go func() {
|
|
||||||
resp, doErr := client.Do(req)
|
|
||||||
doCh <- doResult{resp: resp, err: doErr}
|
|
||||||
}()
|
|
||||||
|
|
||||||
streamConn := &httpStreamConn{
|
|
||||||
writer: reqBodyW,
|
|
||||||
cancel: connCancel,
|
|
||||||
localAddr: &net.TCPAddr{},
|
|
||||||
remoteAddr: &net.TCPAddr{},
|
|
||||||
}
|
|
||||||
|
|
||||||
type upgradeResult struct {
|
|
||||||
conn net.Conn
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
upgradeCh := make(chan upgradeResult, 1)
|
|
||||||
if opts.Upgrade == nil {
|
|
||||||
upgradeCh <- upgradeResult{conn: streamConn, err: nil}
|
|
||||||
} else {
|
|
||||||
go func() {
|
|
||||||
upgradeConn, err := opts.Upgrade(streamConn)
|
|
||||||
if err != nil {
|
|
||||||
upgradeCh <- upgradeResult{conn: nil, err: err}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if upgradeConn == nil {
|
|
||||||
upgradeConn = streamConn
|
|
||||||
}
|
|
||||||
upgradeCh <- upgradeResult{conn: upgradeConn, err: nil}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
outConn net.Conn
|
|
||||||
upgradeDone bool
|
|
||||||
responseReady bool
|
|
||||||
)
|
|
||||||
|
|
||||||
for !(upgradeDone && responseReady) {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
_ = streamConn.Close()
|
|
||||||
if outConn != nil && outConn != streamConn {
|
|
||||||
_ = outConn.Close()
|
|
||||||
}
|
|
||||||
return nil, ctx.Err()
|
|
||||||
|
|
||||||
case u := <-upgradeCh:
|
|
||||||
if u.err != nil {
|
|
||||||
_ = streamConn.Close()
|
|
||||||
return nil, u.err
|
|
||||||
}
|
|
||||||
outConn = u.conn
|
|
||||||
if outConn == nil {
|
|
||||||
outConn = streamConn
|
|
||||||
}
|
|
||||||
upgradeDone = true
|
|
||||||
|
|
||||||
case r := <-doCh:
|
|
||||||
if r.err != nil {
|
|
||||||
_ = streamConn.Close()
|
|
||||||
if outConn != nil && outConn != streamConn {
|
|
||||||
_ = outConn.Close()
|
|
||||||
}
|
|
||||||
return nil, r.err
|
|
||||||
}
|
|
||||||
if r.resp.StatusCode != http.StatusOK {
|
|
||||||
body, _ := io.ReadAll(io.LimitReader(r.resp.Body, 4*1024))
|
|
||||||
_ = r.resp.Body.Close()
|
|
||||||
_ = streamConn.Close()
|
|
||||||
if outConn != nil && outConn != streamConn {
|
|
||||||
_ = outConn.Close()
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("stream bad status: %s (%s)", r.resp.Status, strings.TrimSpace(string(body)))
|
|
||||||
}
|
|
||||||
|
|
||||||
streamConn.reader = r.resp.Body
|
|
||||||
responseReady = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return outConn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func dialStreamOne(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
|
|
||||||
client, target, err := newHTTPClient(serverAddress, opts, 32)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return dialStreamOneWithClient(ctx, client, target, opts)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type queuedConn struct {
|
type queuedConn struct {
|
||||||
@@ -581,6 +429,9 @@ type queuedConn struct {
|
|||||||
closed chan struct{}
|
closed chan struct{}
|
||||||
|
|
||||||
writeCh chan []byte
|
writeCh chan []byte
|
||||||
|
// writeClosed is closed by CloseWrite to stop accepting new payloads.
|
||||||
|
// When closed, Write returns io.ErrClosedPipe, but Read is unaffected.
|
||||||
|
writeClosed chan struct{}
|
||||||
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
readBuf []byte
|
readBuf []byte
|
||||||
@@ -589,6 +440,18 @@ type queuedConn struct {
|
|||||||
remoteAddr net.Addr
|
remoteAddr net.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *queuedConn) CloseWrite() error {
|
||||||
|
if c == nil || c.writeClosed == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
if !isClosedPipeChan(c.writeClosed) {
|
||||||
|
close(c.writeClosed)
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *queuedConn) closeWithError(err error) error {
|
func (c *queuedConn) closeWithError(err error) error {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
select {
|
select {
|
||||||
@@ -640,6 +503,9 @@ func (c *queuedConn) Write(b []byte) (n int, err error) {
|
|||||||
case <-c.closed:
|
case <-c.closed:
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
return 0, c.closedErr()
|
return 0, c.closedErr()
|
||||||
|
case <-c.writeClosed:
|
||||||
|
c.mu.Unlock()
|
||||||
|
return 0, io.ErrClosedPipe
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
@@ -651,6 +517,8 @@ func (c *queuedConn) Write(b []byte) (n int, err error) {
|
|||||||
return len(b), nil
|
return len(b), nil
|
||||||
case <-c.closed:
|
case <-c.closed:
|
||||||
return 0, c.closedErr()
|
return 0, c.closedErr()
|
||||||
|
case <-c.writeClosed:
|
||||||
|
return 0, io.ErrClosedPipe
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -670,6 +538,7 @@ type streamSplitConn struct {
|
|||||||
client *http.Client
|
client *http.Client
|
||||||
pushURL string
|
pushURL string
|
||||||
pullURL string
|
pullURL string
|
||||||
|
finURL string
|
||||||
closeURL string
|
closeURL string
|
||||||
headerHost string
|
headerHost string
|
||||||
auth *tunnelAuth
|
auth *tunnelAuth
|
||||||
@@ -697,15 +566,17 @@ func newStreamSplitConnFromInfo(info *sessionDialInfo) *streamSplitConn {
|
|||||||
client: info.client,
|
client: info.client,
|
||||||
pushURL: info.pushURL,
|
pushURL: info.pushURL,
|
||||||
pullURL: info.pullURL,
|
pullURL: info.pullURL,
|
||||||
|
finURL: info.finURL,
|
||||||
closeURL: info.closeURL,
|
closeURL: info.closeURL,
|
||||||
headerHost: info.headerHost,
|
headerHost: info.headerHost,
|
||||||
auth: info.auth,
|
auth: info.auth,
|
||||||
queuedConn: queuedConn{
|
queuedConn: queuedConn{
|
||||||
rxc: make(chan []byte, 256),
|
rxc: make(chan []byte, 256),
|
||||||
closed: make(chan struct{}),
|
closed: make(chan struct{}),
|
||||||
writeCh: make(chan []byte, 256),
|
writeCh: make(chan []byte, 256),
|
||||||
localAddr: &net.TCPAddr{},
|
writeClosed: make(chan struct{}),
|
||||||
remoteAddr: &net.TCPAddr{},
|
localAddr: &net.TCPAddr{},
|
||||||
|
remoteAddr: &net.TCPAddr{},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -793,7 +664,7 @@ func (c *streamSplitConn) pullLoop() {
|
|||||||
}
|
}
|
||||||
req.Host = c.headerHost
|
req.Host = c.headerHost
|
||||||
applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream)
|
applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream)
|
||||||
applyTunnelAuthHeader(req.Header, c.auth, TunnelModeStream, http.MethodGet, "/stream")
|
applyTunnelAuth(req, c.auth, TunnelModeStream, http.MethodGet, "/stream")
|
||||||
|
|
||||||
resp, err := c.client.Do(req)
|
resp, err := c.client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -891,7 +762,7 @@ func (c *streamSplitConn) pushLoop() {
|
|||||||
}
|
}
|
||||||
req.Host = c.headerHost
|
req.Host = c.headerHost
|
||||||
applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream)
|
applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream)
|
||||||
applyTunnelAuthHeader(req.Header, c.auth, TunnelModeStream, http.MethodPost, "/api/v1/upload")
|
applyTunnelAuth(req, c.auth, TunnelModeStream, http.MethodPost, "/api/v1/upload")
|
||||||
req.Header.Set("Content-Type", "application/octet-stream")
|
req.Header.Set("Content-Type", "application/octet-stream")
|
||||||
|
|
||||||
resp, err := c.client.Do(req)
|
resp, err := c.client.Do(req)
|
||||||
@@ -977,6 +848,27 @@ func (c *streamSplitConn) pushLoop() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
resetTimer()
|
resetTimer()
|
||||||
|
case <-c.writeClosed:
|
||||||
|
// Drain any already-accepted writes so CloseWrite does not lose data.
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case b := <-c.writeCh:
|
||||||
|
if len(b) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if buf.Len()+len(b) > maxBatchBytes {
|
||||||
|
if err := flushWithRetry(); err != nil {
|
||||||
|
_ = c.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, _ = buf.Write(b)
|
||||||
|
default:
|
||||||
|
_ = flushWithRetry()
|
||||||
|
bestEffortCloseWriteSession(c.client, c.finURL, c.headerHost, TunnelModeStream, c.auth)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
case <-c.closed:
|
case <-c.closed:
|
||||||
_ = flushWithRetry()
|
_ = flushWithRetry()
|
||||||
return
|
return
|
||||||
@@ -993,6 +885,7 @@ type pollConn struct {
|
|||||||
client *http.Client
|
client *http.Client
|
||||||
pushURL string
|
pushURL string
|
||||||
pullURL string
|
pullURL string
|
||||||
|
finURL string
|
||||||
closeURL string
|
closeURL string
|
||||||
headerHost string
|
headerHost string
|
||||||
auth *tunnelAuth
|
auth *tunnelAuth
|
||||||
@@ -1037,15 +930,17 @@ func newPollConnFromInfo(info *sessionDialInfo) *pollConn {
|
|||||||
client: info.client,
|
client: info.client,
|
||||||
pushURL: info.pushURL,
|
pushURL: info.pushURL,
|
||||||
pullURL: info.pullURL,
|
pullURL: info.pullURL,
|
||||||
|
finURL: info.finURL,
|
||||||
closeURL: info.closeURL,
|
closeURL: info.closeURL,
|
||||||
headerHost: info.headerHost,
|
headerHost: info.headerHost,
|
||||||
auth: info.auth,
|
auth: info.auth,
|
||||||
queuedConn: queuedConn{
|
queuedConn: queuedConn{
|
||||||
rxc: make(chan []byte, 128),
|
rxc: make(chan []byte, 128),
|
||||||
closed: make(chan struct{}),
|
closed: make(chan struct{}),
|
||||||
writeCh: make(chan []byte, 256),
|
writeCh: make(chan []byte, 256),
|
||||||
localAddr: &net.TCPAddr{},
|
writeClosed: make(chan struct{}),
|
||||||
remoteAddr: &net.TCPAddr{},
|
localAddr: &net.TCPAddr{},
|
||||||
|
remoteAddr: &net.TCPAddr{},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1117,14 +1012,14 @@ func (c *pollConn) pullLoop() {
|
|||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequest(http.MethodGet, c.pullURL, nil)
|
req, err := http.NewRequestWithContext(c.ctx, http.MethodGet, c.pullURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = c.Close()
|
_ = c.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
req.Host = c.headerHost
|
req.Host = c.headerHost
|
||||||
applyTunnelHeaders(req.Header, c.headerHost, TunnelModePoll)
|
applyTunnelHeaders(req.Header, c.headerHost, TunnelModePoll)
|
||||||
applyTunnelAuthHeader(req.Header, c.auth, TunnelModePoll, http.MethodGet, "/stream")
|
applyTunnelAuth(req, c.auth, TunnelModePoll, http.MethodGet, "/stream")
|
||||||
|
|
||||||
resp, err := c.client.Do(req)
|
resp, err := c.client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1202,21 +1097,25 @@ func (c *pollConn) pushLoop() {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequest(http.MethodPost, c.pushURL, bytes.NewReader(buf.Bytes()))
|
reqCtx, cancel := context.WithTimeout(c.ctx, 20*time.Second)
|
||||||
|
req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.pushURL, bytes.NewReader(buf.Bytes()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
cancel()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
req.Host = c.headerHost
|
req.Host = c.headerHost
|
||||||
applyTunnelHeaders(req.Header, c.headerHost, TunnelModePoll)
|
applyTunnelHeaders(req.Header, c.headerHost, TunnelModePoll)
|
||||||
applyTunnelAuthHeader(req.Header, c.auth, TunnelModePoll, http.MethodPost, "/api/v1/upload")
|
applyTunnelAuth(req, c.auth, TunnelModePoll, http.MethodPost, "/api/v1/upload")
|
||||||
req.Header.Set("Content-Type", "text/plain")
|
req.Header.Set("Content-Type", "text/plain")
|
||||||
|
|
||||||
resp, err := c.client.Do(req)
|
resp, err := c.client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
cancel()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024))
|
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024))
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
|
cancel()
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return fmt.Errorf("bad status: %s", resp.Status)
|
return fmt.Errorf("bad status: %s", resp.Status)
|
||||||
}
|
}
|
||||||
@@ -1309,6 +1208,41 @@ func (c *pollConn) pushLoop() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
resetTimer()
|
resetTimer()
|
||||||
|
case <-c.writeClosed:
|
||||||
|
// Drain any already-accepted writes so CloseWrite does not lose data.
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case b := <-c.writeCh:
|
||||||
|
if len(b) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for len(b) > 0 {
|
||||||
|
chunk := b
|
||||||
|
if len(chunk) > maxLineRawBytes {
|
||||||
|
chunk = b[:maxLineRawBytes]
|
||||||
|
}
|
||||||
|
b = b[len(chunk):]
|
||||||
|
|
||||||
|
encLen := base64.StdEncoding.EncodedLen(len(chunk))
|
||||||
|
if pendingRaw+len(chunk) > maxBatchBytes || buf.Len()+encLen+1 > maxBatchBytes*2 {
|
||||||
|
if err := flushWithRetry(); err != nil {
|
||||||
|
_ = c.closeWithError(fmt.Errorf("poll push flush failed: %w", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tmp := make([]byte, base64.StdEncoding.EncodedLen(len(chunk)))
|
||||||
|
base64.StdEncoding.Encode(tmp, chunk)
|
||||||
|
buf.Write(tmp)
|
||||||
|
buf.WriteByte('\n')
|
||||||
|
pendingRaw += len(chunk)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
_ = flushWithRetry()
|
||||||
|
bestEffortCloseWriteSession(c.client, c.finURL, c.headerHost, TunnelModePoll, c.auth)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
case <-c.closed:
|
case <-c.closed:
|
||||||
_ = flushWithRetry()
|
_ = flushWithRetry()
|
||||||
return
|
return
|
||||||
@@ -1478,19 +1412,50 @@ func (s *TunnelServer) HandleConn(rawConn net.Conn) (HandleResult, net.Conn, err
|
|||||||
|
|
||||||
tunnelHeader := strings.ToLower(strings.TrimSpace(req.headers["x-sudoku-tunnel"]))
|
tunnelHeader := strings.ToLower(strings.TrimSpace(req.headers["x-sudoku-tunnel"]))
|
||||||
if tunnelHeader == "" {
|
if tunnelHeader == "" {
|
||||||
// Not our tunnel; replay full bytes to legacy handler.
|
// Some CDNs / forward proxies may strip unknown headers. When AuthKey is enabled, we can
|
||||||
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
|
// safely infer the intended tunnel mode by verifying the Authorization token against
|
||||||
prefix = append(prefix, headerBytes...)
|
// both stream/poll modes and picking the one that matches.
|
||||||
prefix = append(prefix, buffered...)
|
if s.auth != nil {
|
||||||
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
|
u, err := url.ParseRequestURI(req.target)
|
||||||
}
|
if err == nil {
|
||||||
if s.mode == TunnelModeLegacy {
|
path, ok := stripPathRoot(s.pathRoot, u.Path)
|
||||||
if s.passThroughOnReject {
|
if ok && s.isAllowedBasePath(path) {
|
||||||
|
authVal := req.headers["authorization"]
|
||||||
|
if authVal == "" {
|
||||||
|
authVal = u.Query().Get(tunnelAuthQueryKey)
|
||||||
|
}
|
||||||
|
streamOK := s.auth.verifyValue(authVal, TunnelModeStream, req.method, path, time.Now())
|
||||||
|
pollOK := s.auth.verifyValue(authVal, TunnelModePoll, req.method, path, time.Now())
|
||||||
|
switch {
|
||||||
|
case streamOK && !pollOK:
|
||||||
|
tunnelHeader = string(TunnelModeStream)
|
||||||
|
case pollOK && !streamOK:
|
||||||
|
tunnelHeader = string(TunnelModePoll)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tunnelHeader == "" {
|
||||||
|
// Not our tunnel; replay full bytes to legacy handler.
|
||||||
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
|
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
|
||||||
prefix = append(prefix, headerBytes...)
|
prefix = append(prefix, headerBytes...)
|
||||||
prefix = append(prefix, buffered...)
|
prefix = append(prefix, buffered...)
|
||||||
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
|
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
reject := func() (HandleResult, net.Conn, error) {
|
||||||
|
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
|
||||||
|
prefix = append(prefix, headerBytes...)
|
||||||
|
prefix = append(prefix, buffered...)
|
||||||
|
return HandlePassThrough, newRejectedPreBufferedConn(rawConn, prefix), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.mode == TunnelModeLegacy {
|
||||||
|
if s.passThroughOnReject {
|
||||||
|
return reject()
|
||||||
|
}
|
||||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
|
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
|
||||||
_ = rawConn.Close()
|
_ = rawConn.Close()
|
||||||
return HandleDone, nil, nil
|
return HandleDone, nil, nil
|
||||||
@@ -1500,10 +1465,7 @@ func (s *TunnelServer) HandleConn(rawConn net.Conn) (HandleResult, net.Conn, err
|
|||||||
case TunnelModeStream:
|
case TunnelModeStream:
|
||||||
if s.mode != TunnelModeStream && s.mode != TunnelModeAuto {
|
if s.mode != TunnelModeStream && s.mode != TunnelModeAuto {
|
||||||
if s.passThroughOnReject {
|
if s.passThroughOnReject {
|
||||||
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
|
return reject()
|
||||||
prefix = append(prefix, headerBytes...)
|
|
||||||
prefix = append(prefix, buffered...)
|
|
||||||
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
|
|
||||||
}
|
}
|
||||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
|
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
|
||||||
_ = rawConn.Close()
|
_ = rawConn.Close()
|
||||||
@@ -1513,10 +1475,7 @@ func (s *TunnelServer) HandleConn(rawConn net.Conn) (HandleResult, net.Conn, err
|
|||||||
case TunnelModePoll:
|
case TunnelModePoll:
|
||||||
if s.mode != TunnelModePoll && s.mode != TunnelModeAuto {
|
if s.mode != TunnelModePoll && s.mode != TunnelModeAuto {
|
||||||
if s.passThroughOnReject {
|
if s.passThroughOnReject {
|
||||||
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
|
return reject()
|
||||||
prefix = append(prefix, headerBytes...)
|
|
||||||
prefix = append(prefix, buffered...)
|
|
||||||
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
|
|
||||||
}
|
}
|
||||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
|
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
|
||||||
_ = rawConn.Close()
|
_ = rawConn.Close()
|
||||||
@@ -1525,10 +1484,7 @@ func (s *TunnelServer) HandleConn(rawConn net.Conn) (HandleResult, net.Conn, err
|
|||||||
return s.handlePoll(rawConn, req, headerBytes, buffered)
|
return s.handlePoll(rawConn, req, headerBytes, buffered)
|
||||||
default:
|
default:
|
||||||
if s.passThroughOnReject {
|
if s.passThroughOnReject {
|
||||||
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
|
return reject()
|
||||||
prefix = append(prefix, headerBytes...)
|
|
||||||
prefix = append(prefix, buffered...)
|
|
||||||
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
|
|
||||||
}
|
}
|
||||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
|
_ = writeSimpleHTTPResponse(rawConn, http.StatusNotFound, "not found")
|
||||||
_ = rawConn.Close()
|
_ = rawConn.Close()
|
||||||
@@ -1619,13 +1575,52 @@ func readAllBuffered(r *bufio.Reader) []byte {
|
|||||||
|
|
||||||
type preBufferedConn struct {
|
type preBufferedConn struct {
|
||||||
net.Conn
|
net.Conn
|
||||||
buf []byte
|
buf []byte
|
||||||
|
recorded []byte
|
||||||
|
rejected bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newPreBufferedConn(conn net.Conn, pre []byte) net.Conn {
|
func (p *preBufferedConn) CloseWrite() error {
|
||||||
|
if p == nil || p.Conn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if cw, ok := p.Conn.(interface{ CloseWrite() error }); ok {
|
||||||
|
return cw.CloseWrite()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *preBufferedConn) CloseRead() error {
|
||||||
|
if p == nil || p.Conn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if cr, ok := p.Conn.(interface{ CloseRead() error }); ok {
|
||||||
|
return cr.CloseRead()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPreBufferedConn(conn net.Conn, pre []byte) *preBufferedConn {
|
||||||
cpy := make([]byte, len(pre))
|
cpy := make([]byte, len(pre))
|
||||||
copy(cpy, pre)
|
copy(cpy, pre)
|
||||||
return &preBufferedConn{Conn: conn, buf: cpy}
|
return &preBufferedConn{Conn: conn, buf: cpy, recorded: cpy}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRejectedPreBufferedConn(conn net.Conn, pre []byte) *preBufferedConn {
|
||||||
|
c := newPreBufferedConn(conn, pre)
|
||||||
|
c.rejected = true
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *preBufferedConn) IsHTTPMaskRejected() bool { return p.rejected }
|
||||||
|
|
||||||
|
func (p *preBufferedConn) GetBufferedAndRecorded() []byte {
|
||||||
|
if len(p.recorded) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]byte, len(p.recorded))
|
||||||
|
copy(out, p.recorded)
|
||||||
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *preBufferedConn) Read(b []byte) (int, error) {
|
func (p *preBufferedConn) Read(b []byte) (int, error) {
|
||||||
@@ -1682,7 +1677,7 @@ func (s *TunnelServer) handleStream(rawConn net.Conn, req *httpRequestHeader, he
|
|||||||
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
|
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
|
||||||
prefix = append(prefix, headerBytes...)
|
prefix = append(prefix, headerBytes...)
|
||||||
prefix = append(prefix, buffered...)
|
prefix = append(prefix, buffered...)
|
||||||
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
|
return HandlePassThrough, newRejectedPreBufferedConn(rawConn, prefix), nil
|
||||||
}
|
}
|
||||||
_ = writeSimpleHTTPResponse(rawConn, code, body)
|
_ = writeSimpleHTTPResponse(rawConn, code, body)
|
||||||
_ = rawConn.Close()
|
_ = rawConn.Close()
|
||||||
@@ -1699,21 +1694,29 @@ func (s *TunnelServer) handleStream(rawConn net.Conn, req *httpRequestHeader, he
|
|||||||
if !ok || !s.isAllowedBasePath(path) {
|
if !ok || !s.isAllowedBasePath(path) {
|
||||||
return rejectOrReply(http.StatusNotFound, "not found")
|
return rejectOrReply(http.StatusNotFound, "not found")
|
||||||
}
|
}
|
||||||
if !s.auth.verify(req.headers, TunnelModeStream, req.method, path, time.Now()) {
|
authVal := req.headers["authorization"]
|
||||||
|
if authVal == "" {
|
||||||
|
authVal = u.Query().Get(tunnelAuthQueryKey)
|
||||||
|
}
|
||||||
|
if !s.auth.verifyValue(authVal, TunnelModeStream, req.method, path, time.Now()) {
|
||||||
return rejectOrReply(http.StatusNotFound, "not found")
|
return rejectOrReply(http.StatusNotFound, "not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
token := u.Query().Get("token")
|
token := u.Query().Get("token")
|
||||||
closeFlag := u.Query().Get("close") == "1"
|
closeFlag := u.Query().Get("close") == "1"
|
||||||
|
finFlag := u.Query().Get("fin") == "1"
|
||||||
|
|
||||||
switch strings.ToUpper(req.method) {
|
switch strings.ToUpper(req.method) {
|
||||||
case http.MethodGet:
|
case http.MethodGet:
|
||||||
// Stream split-session: GET /session (no token) => token + start tunnel on a server-side pipe.
|
// Stream split-session: GET /session (no token) => token + start tunnel on a server-side pipe.
|
||||||
if token == "" && path == "/session" {
|
if token == "" && path == "/session" {
|
||||||
return s.authorizeSession(rawConn)
|
return s.sessionAuthorize(rawConn)
|
||||||
}
|
}
|
||||||
// Stream split-session: GET /stream?token=... => downlink poll.
|
// Stream split-session: GET /stream?token=... => downlink poll.
|
||||||
if token != "" && path == "/stream" {
|
if token != "" && path == "/stream" {
|
||||||
|
if s.passThroughOnReject && !s.sessionHas(token) {
|
||||||
|
return rejectOrReply(http.StatusNotFound, "not found")
|
||||||
|
}
|
||||||
return s.streamPull(rawConn, token)
|
return s.streamPull(rawConn, token)
|
||||||
}
|
}
|
||||||
return rejectOrReply(http.StatusBadRequest, "bad request")
|
return rejectOrReply(http.StatusBadRequest, "bad request")
|
||||||
@@ -1721,13 +1724,26 @@ func (s *TunnelServer) handleStream(rawConn net.Conn, req *httpRequestHeader, he
|
|||||||
case http.MethodPost:
|
case http.MethodPost:
|
||||||
// Stream split-session: POST /api/v1/upload?token=... => uplink push.
|
// Stream split-session: POST /api/v1/upload?token=... => uplink push.
|
||||||
if token != "" && path == "/api/v1/upload" {
|
if token != "" && path == "/api/v1/upload" {
|
||||||
|
if s.passThroughOnReject && !s.sessionHas(token) {
|
||||||
|
return rejectOrReply(http.StatusNotFound, "not found")
|
||||||
|
}
|
||||||
if closeFlag {
|
if closeFlag {
|
||||||
s.closeSession(token)
|
s.sessionClose(token)
|
||||||
return rejectOrReply(http.StatusOK, "")
|
_ = writeSimpleHTTPResponse(rawConn, http.StatusOK, "")
|
||||||
|
_ = rawConn.Close()
|
||||||
|
return HandleDone, nil, nil
|
||||||
|
}
|
||||||
|
if finFlag {
|
||||||
|
s.sessionCloseWrite(token)
|
||||||
|
_ = writeSimpleHTTPResponse(rawConn, http.StatusOK, "")
|
||||||
|
_ = rawConn.Close()
|
||||||
|
return HandleDone, nil, nil
|
||||||
}
|
}
|
||||||
bodyReader, err := newRequestBodyReader(newPreBufferedConn(rawConn, buffered), req.headers)
|
bodyReader, err := newRequestBodyReader(newPreBufferedConn(rawConn, buffered), req.headers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return rejectOrReply(http.StatusBadRequest, "bad request")
|
_ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request")
|
||||||
|
_ = rawConn.Close()
|
||||||
|
return HandleDone, nil, nil
|
||||||
}
|
}
|
||||||
return s.streamPush(rawConn, token, bodyReader)
|
return s.streamPush(rawConn, token, bodyReader)
|
||||||
}
|
}
|
||||||
@@ -1825,7 +1841,7 @@ func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, head
|
|||||||
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
|
prefix := make([]byte, 0, len(headerBytes)+len(buffered))
|
||||||
prefix = append(prefix, headerBytes...)
|
prefix = append(prefix, headerBytes...)
|
||||||
prefix = append(prefix, buffered...)
|
prefix = append(prefix, buffered...)
|
||||||
return HandlePassThrough, newPreBufferedConn(rawConn, prefix), nil
|
return HandlePassThrough, newRejectedPreBufferedConn(rawConn, prefix), nil
|
||||||
}
|
}
|
||||||
_ = writeSimpleHTTPResponse(rawConn, code, body)
|
_ = writeSimpleHTTPResponse(rawConn, code, body)
|
||||||
_ = rawConn.Close()
|
_ = rawConn.Close()
|
||||||
@@ -1841,18 +1857,26 @@ func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, head
|
|||||||
if !ok || !s.isAllowedBasePath(path) {
|
if !ok || !s.isAllowedBasePath(path) {
|
||||||
return rejectOrReply(http.StatusNotFound, "not found")
|
return rejectOrReply(http.StatusNotFound, "not found")
|
||||||
}
|
}
|
||||||
if !s.auth.verify(req.headers, TunnelModePoll, req.method, path, time.Now()) {
|
authVal := req.headers["authorization"]
|
||||||
|
if authVal == "" {
|
||||||
|
authVal = u.Query().Get(tunnelAuthQueryKey)
|
||||||
|
}
|
||||||
|
if !s.auth.verifyValue(authVal, TunnelModePoll, req.method, path, time.Now()) {
|
||||||
return rejectOrReply(http.StatusNotFound, "not found")
|
return rejectOrReply(http.StatusNotFound, "not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
token := u.Query().Get("token")
|
token := u.Query().Get("token")
|
||||||
closeFlag := u.Query().Get("close") == "1"
|
closeFlag := u.Query().Get("close") == "1"
|
||||||
|
finFlag := u.Query().Get("fin") == "1"
|
||||||
switch strings.ToUpper(req.method) {
|
switch strings.ToUpper(req.method) {
|
||||||
case http.MethodGet:
|
case http.MethodGet:
|
||||||
if token == "" && path == "/session" {
|
if token == "" && path == "/session" {
|
||||||
return s.authorizeSession(rawConn)
|
return s.sessionAuthorize(rawConn)
|
||||||
}
|
}
|
||||||
if token != "" && path == "/stream" {
|
if token != "" && path == "/stream" {
|
||||||
|
if s.passThroughOnReject && !s.sessionHas(token) {
|
||||||
|
return rejectOrReply(http.StatusNotFound, "not found")
|
||||||
|
}
|
||||||
return s.pollPull(rawConn, token)
|
return s.pollPull(rawConn, token)
|
||||||
}
|
}
|
||||||
return rejectOrReply(http.StatusBadRequest, "bad request")
|
return rejectOrReply(http.StatusBadRequest, "bad request")
|
||||||
@@ -1860,13 +1884,26 @@ func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, head
|
|||||||
if token == "" || path != "/api/v1/upload" {
|
if token == "" || path != "/api/v1/upload" {
|
||||||
return rejectOrReply(http.StatusBadRequest, "bad request")
|
return rejectOrReply(http.StatusBadRequest, "bad request")
|
||||||
}
|
}
|
||||||
|
if s.passThroughOnReject && !s.sessionHas(token) {
|
||||||
|
return rejectOrReply(http.StatusNotFound, "not found")
|
||||||
|
}
|
||||||
if closeFlag {
|
if closeFlag {
|
||||||
s.closeSession(token)
|
s.sessionClose(token)
|
||||||
return rejectOrReply(http.StatusOK, "")
|
_ = writeSimpleHTTPResponse(rawConn, http.StatusOK, "")
|
||||||
|
_ = rawConn.Close()
|
||||||
|
return HandleDone, nil, nil
|
||||||
|
}
|
||||||
|
if finFlag {
|
||||||
|
s.sessionCloseWrite(token)
|
||||||
|
_ = writeSimpleHTTPResponse(rawConn, http.StatusOK, "")
|
||||||
|
_ = rawConn.Close()
|
||||||
|
return HandleDone, nil, nil
|
||||||
}
|
}
|
||||||
bodyReader, err := newRequestBodyReader(newPreBufferedConn(rawConn, buffered), req.headers)
|
bodyReader, err := newRequestBodyReader(newPreBufferedConn(rawConn, buffered), req.headers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return rejectOrReply(http.StatusBadRequest, "bad request")
|
_ = writeSimpleHTTPResponse(rawConn, http.StatusBadRequest, "bad request")
|
||||||
|
_ = rawConn.Close()
|
||||||
|
return HandleDone, nil, nil
|
||||||
}
|
}
|
||||||
return s.pollPush(rawConn, token, bodyReader)
|
return s.pollPush(rawConn, token, bodyReader)
|
||||||
default:
|
default:
|
||||||
@@ -1874,7 +1911,7 @@ func (s *TunnelServer) handlePoll(rawConn net.Conn, req *httpRequestHeader, head
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *TunnelServer) authorizeSession(rawConn net.Conn) (HandleResult, net.Conn, error) {
|
func (s *TunnelServer) sessionAuthorize(rawConn net.Conn) (HandleResult, net.Conn, error) {
|
||||||
token, err := newSessionToken()
|
token, err := newSessionToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusInternalServerError, "internal error")
|
_ = writeSimpleHTTPResponse(rawConn, http.StatusInternalServerError, "internal error")
|
||||||
@@ -1882,13 +1919,13 @@ func (s *TunnelServer) authorizeSession(rawConn net.Conn) (HandleResult, net.Con
|
|||||||
return HandleDone, nil, nil
|
return HandleDone, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
c1, c2 := net.Pipe()
|
c1, c2 := newHalfPipe()
|
||||||
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
s.sessions[token] = &tunnelSession{conn: c2, lastActive: time.Now()}
|
s.sessions[token] = &tunnelSession{conn: c2, lastActive: time.Now()}
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
go s.reapSessionLater(token)
|
go s.reapLater(token)
|
||||||
|
|
||||||
_ = writeTokenHTTPResponse(rawConn, token)
|
_ = writeTokenHTTPResponse(rawConn, token)
|
||||||
_ = rawConn.Close()
|
_ = rawConn.Close()
|
||||||
@@ -1903,31 +1940,50 @@ func newSessionToken() (string, error) {
|
|||||||
return base64.RawURLEncoding.EncodeToString(b[:]), nil
|
return base64.RawURLEncoding.EncodeToString(b[:]), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *TunnelServer) reapSessionLater(token string) {
|
func (s *TunnelServer) reapLater(token string) {
|
||||||
ttl := s.sessionTTL
|
ttl := s.sessionTTL
|
||||||
if ttl <= 0 {
|
if ttl <= 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
timer := time.NewTimer(ttl)
|
timer := time.NewTimer(ttl)
|
||||||
defer timer.Stop()
|
defer timer.Stop()
|
||||||
<-timer.C
|
|
||||||
|
|
||||||
s.mu.Lock()
|
for {
|
||||||
sess, ok := s.sessions[token]
|
<-timer.C
|
||||||
if !ok {
|
|
||||||
|
s.mu.Lock()
|
||||||
|
sess, ok := s.sessions[token]
|
||||||
|
if !ok {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
idle := time.Since(sess.lastActive)
|
||||||
|
if idle >= ttl {
|
||||||
|
delete(s.sessions, token)
|
||||||
|
s.mu.Unlock()
|
||||||
|
_ = sess.conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next := ttl - idle
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
return
|
|
||||||
|
// Avoid a tight loop under high-frequency activity; we only need best-effort cleanup.
|
||||||
|
if next < 50*time.Millisecond {
|
||||||
|
next = 50 * time.Millisecond
|
||||||
|
}
|
||||||
|
timer.Reset(next)
|
||||||
}
|
}
|
||||||
if time.Since(sess.lastActive) < ttl {
|
|
||||||
s.mu.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
delete(s.sessions, token)
|
|
||||||
s.mu.Unlock()
|
|
||||||
_ = sess.conn.Close()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *TunnelServer) getSession(token string) (*tunnelSession, bool) {
|
func (s *TunnelServer) sessionHas(token string) bool {
|
||||||
|
s.mu.Lock()
|
||||||
|
_, ok := s.sessions[token]
|
||||||
|
s.mu.Unlock()
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *TunnelServer) sessionGet(token string) (*tunnelSession, bool) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
sess, ok := s.sessions[token]
|
sess, ok := s.sessions[token]
|
||||||
@@ -1938,7 +1994,7 @@ func (s *TunnelServer) getSession(token string) (*tunnelSession, bool) {
|
|||||||
return sess, true
|
return sess, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *TunnelServer) closeSession(token string) {
|
func (s *TunnelServer) sessionClose(token string) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
sess, ok := s.sessions[token]
|
sess, ok := s.sessions[token]
|
||||||
if ok {
|
if ok {
|
||||||
@@ -1950,8 +2006,20 @@ func (s *TunnelServer) closeSession(token string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *TunnelServer) sessionCloseWrite(token string) {
|
||||||
|
sess, ok := s.sessionGet(token)
|
||||||
|
if !ok || sess == nil || sess.conn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if cw, ok := sess.conn.(interface{ CloseWrite() error }); ok {
|
||||||
|
_ = cw.CloseWrite()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = sess.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
func (s *TunnelServer) pollPush(rawConn net.Conn, token string, body io.Reader) (HandleResult, net.Conn, error) {
|
func (s *TunnelServer) pollPush(rawConn net.Conn, token string, body io.Reader) (HandleResult, net.Conn, error) {
|
||||||
sess, ok := s.getSession(token)
|
sess, ok := s.sessionGet(token)
|
||||||
if !ok {
|
if !ok {
|
||||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden")
|
_ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden")
|
||||||
_ = rawConn.Close()
|
_ = rawConn.Close()
|
||||||
@@ -1985,7 +2053,7 @@ func (s *TunnelServer) pollPush(rawConn net.Conn, token string, body io.Reader)
|
|||||||
_, werr := sess.conn.Write(decoded[:n])
|
_, werr := sess.conn.Write(decoded[:n])
|
||||||
_ = sess.conn.SetWriteDeadline(time.Time{})
|
_ = sess.conn.SetWriteDeadline(time.Time{})
|
||||||
if werr != nil {
|
if werr != nil {
|
||||||
s.closeSession(token)
|
s.sessionClose(token)
|
||||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusGone, "gone")
|
_ = writeSimpleHTTPResponse(rawConn, http.StatusGone, "gone")
|
||||||
_ = rawConn.Close()
|
_ = rawConn.Close()
|
||||||
return HandleDone, nil, nil
|
return HandleDone, nil, nil
|
||||||
@@ -1998,7 +2066,7 @@ func (s *TunnelServer) pollPush(rawConn net.Conn, token string, body io.Reader)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *TunnelServer) streamPush(rawConn net.Conn, token string, body io.Reader) (HandleResult, net.Conn, error) {
|
func (s *TunnelServer) streamPush(rawConn net.Conn, token string, body io.Reader) (HandleResult, net.Conn, error) {
|
||||||
sess, ok := s.getSession(token)
|
sess, ok := s.sessionGet(token)
|
||||||
if !ok {
|
if !ok {
|
||||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden")
|
_ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden")
|
||||||
_ = rawConn.Close()
|
_ = rawConn.Close()
|
||||||
@@ -2023,7 +2091,7 @@ func (s *TunnelServer) streamPush(rawConn net.Conn, token string, body io.Reader
|
|||||||
_, werr := sess.conn.Write(payload)
|
_, werr := sess.conn.Write(payload)
|
||||||
_ = sess.conn.SetWriteDeadline(time.Time{})
|
_ = sess.conn.SetWriteDeadline(time.Time{})
|
||||||
if werr != nil {
|
if werr != nil {
|
||||||
s.closeSession(token)
|
s.sessionClose(token)
|
||||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusGone, "gone")
|
_ = writeSimpleHTTPResponse(rawConn, http.StatusGone, "gone")
|
||||||
_ = rawConn.Close()
|
_ = rawConn.Close()
|
||||||
return HandleDone, nil, nil
|
return HandleDone, nil, nil
|
||||||
@@ -2036,7 +2104,7 @@ func (s *TunnelServer) streamPush(rawConn net.Conn, token string, body io.Reader
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *TunnelServer) streamPull(rawConn net.Conn, token string) (HandleResult, net.Conn, error) {
|
func (s *TunnelServer) streamPull(rawConn net.Conn, token string) (HandleResult, net.Conn, error) {
|
||||||
sess, ok := s.getSession(token)
|
sess, ok := s.sessionGet(token)
|
||||||
if !ok {
|
if !ok {
|
||||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden")
|
_ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden")
|
||||||
_ = rawConn.Close()
|
_ = rawConn.Close()
|
||||||
@@ -2074,14 +2142,14 @@ func (s *TunnelServer) streamPull(rawConn net.Conn, token string) (HandleResult,
|
|||||||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, net.ErrClosed) {
|
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, net.ErrClosed) {
|
||||||
return HandleDone, nil, nil
|
return HandleDone, nil, nil
|
||||||
}
|
}
|
||||||
s.closeSession(token)
|
s.sessionClose(token)
|
||||||
return HandleDone, nil, nil
|
return HandleDone, nil, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *TunnelServer) pollPull(rawConn net.Conn, token string) (HandleResult, net.Conn, error) {
|
func (s *TunnelServer) pollPull(rawConn net.Conn, token string) (HandleResult, net.Conn, error) {
|
||||||
sess, ok := s.getSession(token)
|
sess, ok := s.sessionGet(token)
|
||||||
if !ok {
|
if !ok {
|
||||||
_ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden")
|
_ = writeSimpleHTTPResponse(rawConn, http.StatusForbidden, "forbidden")
|
||||||
_ = rawConn.Close()
|
_ = rawConn.Close()
|
||||||
@@ -2123,7 +2191,7 @@ func (s *TunnelServer) pollPull(rawConn net.Conn, token string) (HandleResult, n
|
|||||||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, net.ErrClosed) {
|
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, net.ErrClosed) {
|
||||||
return HandleDone, nil, nil
|
return HandleDone, nil, nil
|
||||||
}
|
}
|
||||||
s.closeSession(token)
|
s.sessionClose(token)
|
||||||
return HandleDone, nil, nil
|
return HandleDone, nil, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,8 +52,28 @@ type Conn struct {
|
|||||||
pendingData []byte
|
pendingData []byte
|
||||||
hintBuf []byte
|
hintBuf []byte
|
||||||
|
|
||||||
rng *rand.Rand
|
rng *rand.Rand
|
||||||
paddingRate float32
|
paddingThreshold uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *Conn) CloseWrite() error {
|
||||||
|
if sc == nil || sc.Conn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if cw, ok := sc.Conn.(interface{ CloseWrite() error }); ok {
|
||||||
|
return cw.CloseWrite()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sc *Conn) CloseRead() error {
|
||||||
|
if sc == nil || sc.Conn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if cr, ok := sc.Conn.(interface{ CloseRead() error }); ok {
|
||||||
|
return cr.CloseRead()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewConn(c net.Conn, table *Table, pMin, pMax int, record bool) *Conn {
|
func NewConn(c net.Conn, table *Table, pMin, pMax int, record bool) *Conn {
|
||||||
@@ -64,19 +84,15 @@ func NewConn(c net.Conn, table *Table, pMin, pMax int, record bool) *Conn {
|
|||||||
seed := int64(binary.BigEndian.Uint64(seedBytes[:]))
|
seed := int64(binary.BigEndian.Uint64(seedBytes[:]))
|
||||||
localRng := rand.New(rand.NewSource(seed))
|
localRng := rand.New(rand.NewSource(seed))
|
||||||
|
|
||||||
min := float32(pMin) / 100.0
|
|
||||||
rng := float32(pMax-pMin) / 100.0
|
|
||||||
rate := min + localRng.Float32()*rng
|
|
||||||
|
|
||||||
sc := &Conn{
|
sc := &Conn{
|
||||||
Conn: c,
|
Conn: c,
|
||||||
table: table,
|
table: table,
|
||||||
reader: bufio.NewReaderSize(c, IOBufferSize),
|
reader: bufio.NewReaderSize(c, IOBufferSize),
|
||||||
rawBuf: make([]byte, IOBufferSize),
|
rawBuf: make([]byte, IOBufferSize),
|
||||||
pendingData: make([]byte, 0, 4096),
|
pendingData: make([]byte, 0, 4096),
|
||||||
hintBuf: make([]byte, 0, 4),
|
hintBuf: make([]byte, 0, 4),
|
||||||
rng: localRng,
|
rng: localRng,
|
||||||
paddingRate: rate,
|
paddingThreshold: pickPaddingThreshold(localRng, pMin, pMax),
|
||||||
}
|
}
|
||||||
if record {
|
if record {
|
||||||
sc.recorder = new(bytes.Buffer)
|
sc.recorder = new(bytes.Buffer)
|
||||||
@@ -127,7 +143,7 @@ func (sc *Conn) Write(p []byte) (n int, err error) {
|
|||||||
padLen := len(pads)
|
padLen := len(pads)
|
||||||
|
|
||||||
for _, b := range p {
|
for _, b := range p {
|
||||||
if sc.rng.Float32() < sc.paddingRate {
|
if shouldPad(sc.rng, sc.paddingThreshold) {
|
||||||
out = append(out, pads[sc.rng.Intn(padLen)])
|
out = append(out, pads[sc.rng.Intn(padLen)])
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,14 +152,14 @@ func (sc *Conn) Write(p []byte) (n int, err error) {
|
|||||||
|
|
||||||
perm := perm4[sc.rng.Intn(len(perm4))]
|
perm := perm4[sc.rng.Intn(len(perm4))]
|
||||||
for _, idx := range perm {
|
for _, idx := range perm {
|
||||||
if sc.rng.Float32() < sc.paddingRate {
|
if shouldPad(sc.rng, sc.paddingThreshold) {
|
||||||
out = append(out, pads[sc.rng.Intn(padLen)])
|
out = append(out, pads[sc.rng.Intn(padLen)])
|
||||||
}
|
}
|
||||||
out = append(out, puzzle[idx])
|
out = append(out, puzzle[idx])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if sc.rng.Float32() < sc.paddingRate {
|
if shouldPad(sc.rng, sc.paddingThreshold) {
|
||||||
out = append(out, pads[sc.rng.Intn(padLen)])
|
out = append(out, pads[sc.rng.Intn(padLen)])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// 1. 使用 12字节->16组 的块处理优化 Write (减少循环开销)
|
// 1. 使用 12字节->16组 的块处理优化 Write (减少循环开销)
|
||||||
// 2. 使用浮点随机概率判断 Padding,与纯 Sudoku 保持流量特征一致
|
// 2. 使用整数阈值随机概率判断 Padding,与纯 Sudoku 保持流量特征一致
|
||||||
// 3. Read 使用 copy 移动避免底层数组泄漏
|
// 3. Read 使用 copy 移动避免底层数组泄漏
|
||||||
type PackedConn struct {
|
type PackedConn struct {
|
||||||
net.Conn
|
net.Conn
|
||||||
@@ -37,11 +37,31 @@ type PackedConn struct {
|
|||||||
readBitBuf uint64
|
readBitBuf uint64
|
||||||
readBits int
|
readBits int
|
||||||
|
|
||||||
// 随机数与填充控制 - 使用浮点随机,与 Conn 一致
|
// 随机数与填充控制 - 使用整数阈值随机,与 Conn 一致
|
||||||
rng *rand.Rand
|
rng *rand.Rand
|
||||||
paddingRate float32 // 与 Conn 保持一致的随机概率模型
|
paddingThreshold uint64 // 与 Conn 保持一致的随机概率模型
|
||||||
padMarker byte
|
padMarker byte
|
||||||
padPool []byte
|
padPool []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pc *PackedConn) CloseWrite() error {
|
||||||
|
if pc == nil || pc.Conn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if cw, ok := pc.Conn.(interface{ CloseWrite() error }); ok {
|
||||||
|
return cw.CloseWrite()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pc *PackedConn) CloseRead() error {
|
||||||
|
if pc == nil || pc.Conn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if cr, ok := pc.Conn.(interface{ CloseRead() error }); ok {
|
||||||
|
return cr.CloseRead()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn {
|
func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn {
|
||||||
@@ -52,20 +72,15 @@ func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn {
|
|||||||
seed := int64(binary.BigEndian.Uint64(seedBytes[:]))
|
seed := int64(binary.BigEndian.Uint64(seedBytes[:]))
|
||||||
localRng := rand.New(rand.NewSource(seed))
|
localRng := rand.New(rand.NewSource(seed))
|
||||||
|
|
||||||
// 与 Conn 保持一致的 padding 概率计算
|
|
||||||
min := float32(pMin) / 100.0
|
|
||||||
rng := float32(pMax-pMin) / 100.0
|
|
||||||
rate := min + localRng.Float32()*rng
|
|
||||||
|
|
||||||
pc := &PackedConn{
|
pc := &PackedConn{
|
||||||
Conn: c,
|
Conn: c,
|
||||||
table: table,
|
table: table,
|
||||||
reader: bufio.NewReaderSize(c, IOBufferSize),
|
reader: bufio.NewReaderSize(c, IOBufferSize),
|
||||||
rawBuf: make([]byte, IOBufferSize),
|
rawBuf: make([]byte, IOBufferSize),
|
||||||
pendingData: make([]byte, 0, 4096),
|
pendingData: make([]byte, 0, 4096),
|
||||||
writeBuf: make([]byte, 0, 4096),
|
writeBuf: make([]byte, 0, 4096),
|
||||||
rng: localRng,
|
rng: localRng,
|
||||||
paddingRate: rate,
|
paddingThreshold: pickPaddingThreshold(localRng, pMin, pMax),
|
||||||
}
|
}
|
||||||
|
|
||||||
pc.padMarker = table.layout.padMarker
|
pc.padMarker = table.layout.padMarker
|
||||||
@@ -80,9 +95,9 @@ func NewPackedConn(c net.Conn, table *Table, pMin, pMax int) *PackedConn {
|
|||||||
return pc
|
return pc
|
||||||
}
|
}
|
||||||
|
|
||||||
// maybeAddPadding 内联辅助:根据浮点概率插入 padding
|
// maybeAddPadding 内联辅助:根据概率阈值插入 padding
|
||||||
func (pc *PackedConn) maybeAddPadding(out []byte) []byte {
|
func (pc *PackedConn) maybeAddPadding(out []byte) []byte {
|
||||||
if pc.rng.Float32() < pc.paddingRate {
|
if shouldPad(pc.rng, pc.paddingThreshold) {
|
||||||
out = append(out, pc.getPaddingByte())
|
out = append(out, pc.getPaddingByte())
|
||||||
}
|
}
|
||||||
return out
|
return out
|
||||||
|
|||||||
44
transport/sudoku/obfs/sudoku/padding_prob.go
Normal file
44
transport/sudoku/obfs/sudoku/padding_prob.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package sudoku
|
||||||
|
|
||||||
|
import "math/rand"
|
||||||
|
|
||||||
|
const probOne = uint64(1) << 32
|
||||||
|
|
||||||
|
func pickPaddingThreshold(r *rand.Rand, pMin, pMax int) uint64 {
|
||||||
|
if r == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if pMin < 0 {
|
||||||
|
pMin = 0
|
||||||
|
}
|
||||||
|
if pMax < pMin {
|
||||||
|
pMax = pMin
|
||||||
|
}
|
||||||
|
if pMax > 100 {
|
||||||
|
pMax = 100
|
||||||
|
}
|
||||||
|
if pMin > 100 {
|
||||||
|
pMin = 100
|
||||||
|
}
|
||||||
|
|
||||||
|
min := uint64(pMin) * probOne / 100
|
||||||
|
max := uint64(pMax) * probOne / 100
|
||||||
|
if max <= min {
|
||||||
|
return min
|
||||||
|
}
|
||||||
|
u := uint64(r.Uint32())
|
||||||
|
return min + (u * (max - min) >> 32)
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldPad(r *rand.Rand, threshold uint64) bool {
|
||||||
|
if threshold == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if threshold >= probOne {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if r == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return uint64(r.Uint32()) < threshold
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user