mirror of
https://github.com/MetaCubeX/mihomo.git
synced 2026-02-26 16:57:08 +00:00
chore: refactored the implementation of suduko mux (#2486)
This commit is contained in:
@@ -19,6 +19,9 @@ type Sudoku struct {
|
|||||||
option *SudokuOption
|
option *SudokuOption
|
||||||
baseConf sudoku.ProtocolConfig
|
baseConf sudoku.ProtocolConfig
|
||||||
|
|
||||||
|
httpMaskMu sync.Mutex
|
||||||
|
httpMaskClient *sudoku.HTTPMaskTunnelClient
|
||||||
|
|
||||||
muxMu sync.Mutex
|
muxMu sync.Mutex
|
||||||
muxClient *sudoku.MultiplexClient
|
muxClient *sudoku.MultiplexClient
|
||||||
muxBackoffUntil time.Time
|
muxBackoffUntil time.Time
|
||||||
@@ -40,7 +43,7 @@ type SudokuOption struct {
|
|||||||
HTTPMaskMode string `proxy:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto"
|
HTTPMaskMode string `proxy:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto"
|
||||||
HTTPMaskTLS bool `proxy:"http-mask-tls,omitempty"` // only for http-mask-mode stream/poll/auto
|
HTTPMaskTLS bool `proxy:"http-mask-tls,omitempty"` // only for http-mask-mode stream/poll/auto
|
||||||
HTTPMaskHost string `proxy:"http-mask-host,omitempty"` // optional Host/SNI override (domain or domain:port)
|
HTTPMaskHost string `proxy:"http-mask-host,omitempty"` // optional Host/SNI override (domain or domain:port)
|
||||||
HTTPMaskMultiplex string `proxy:"http-mask-multiplex,omitempty"` // "off" (default), "auto", "on"
|
HTTPMaskMultiplex string `proxy:"http-mask-multiplex,omitempty"` // "off" (default), "auto" (reuse h1/h2), "on" (single tunnel, multi-target)
|
||||||
CustomTable string `proxy:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv
|
CustomTable string `proxy:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv
|
||||||
CustomTables []string `proxy:"custom-tables,omitempty"` // optional table rotation patterns, overrides custom-table when non-empty
|
CustomTables []string `proxy:"custom-tables,omitempty"` // optional table rotation patterns, overrides custom-table when non-empty
|
||||||
}
|
}
|
||||||
@@ -53,18 +56,12 @@ func (s *Sudoku) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Con
|
|||||||
}
|
}
|
||||||
|
|
||||||
muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex)
|
muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex)
|
||||||
if !cfg.DisableHTTPMask && muxMode != "off" {
|
if muxMode == "on" && !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) {
|
||||||
shouldTry := muxMode == "on" || (muxMode == "auto" && httpTunnelModeEnabled(cfg.HTTPMaskMode))
|
stream, muxErr := s.dialMultiplex(ctx, cfg.TargetAddress, muxMode)
|
||||||
if shouldTry {
|
if muxErr == nil {
|
||||||
stream, muxErr := s.dialMultiplex(ctx, cfg.TargetAddress, muxMode)
|
return NewConn(stream, s), nil
|
||||||
if muxErr == nil {
|
|
||||||
return NewConn(stream, s), nil
|
|
||||||
}
|
|
||||||
if muxMode != "auto" {
|
|
||||||
return nil, muxErr
|
|
||||||
}
|
|
||||||
s.noteMuxFailure(muxMode, muxErr)
|
|
||||||
}
|
}
|
||||||
|
return nil, muxErr
|
||||||
}
|
}
|
||||||
|
|
||||||
c, err := s.dialAndHandshake(ctx, cfg)
|
c, err := s.dialAndHandshake(ctx, cfg)
|
||||||
@@ -229,6 +226,7 @@ func NewSudoku(option SudokuOption) (*Sudoku, error) {
|
|||||||
|
|
||||||
func (s *Sudoku) Close() error {
|
func (s *Sudoku) Close() error {
|
||||||
s.resetMuxClient()
|
s.resetMuxClient()
|
||||||
|
s.resetHTTPMaskClient()
|
||||||
return s.Base.Close()
|
return s.Base.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -261,7 +259,17 @@ func (s *Sudoku) dialAndHandshake(ctx context.Context, cfg *sudoku.ProtocolConfi
|
|||||||
|
|
||||||
var c net.Conn
|
var c net.Conn
|
||||||
if !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) {
|
if !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) {
|
||||||
c, err = sudoku.DialHTTPMaskTunnel(ctx, cfg.ServerAddress, cfg, s.dialer.DialContext)
|
muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex)
|
||||||
|
switch muxMode {
|
||||||
|
case "auto", "on":
|
||||||
|
client, errX := s.getOrCreateHTTPMaskClient(cfg)
|
||||||
|
if errX != nil {
|
||||||
|
return nil, errX
|
||||||
|
}
|
||||||
|
c, err = client.Dial(ctx)
|
||||||
|
default:
|
||||||
|
c, err = sudoku.DialHTTPMaskTunnel(ctx, cfg.ServerAddress, cfg, s.dialer.DialContext)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if c == nil && err == nil {
|
if c == nil && err == nil {
|
||||||
c, err = s.dialer.DialContext(ctx, "tcp", s.addr)
|
c, err = s.dialer.DialContext(ctx, "tcp", s.addr)
|
||||||
@@ -380,3 +388,35 @@ func (s *Sudoku) resetMuxClient() {
|
|||||||
s.muxClient = nil
|
s.muxClient = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Sudoku) getOrCreateHTTPMaskClient(cfg *sudoku.ProtocolConfig) (*sudoku.HTTPMaskTunnelClient, error) {
|
||||||
|
if s == nil {
|
||||||
|
return nil, fmt.Errorf("nil adapter")
|
||||||
|
}
|
||||||
|
if cfg == nil {
|
||||||
|
return nil, fmt.Errorf("config is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.httpMaskMu.Lock()
|
||||||
|
defer s.httpMaskMu.Unlock()
|
||||||
|
|
||||||
|
if s.httpMaskClient != nil {
|
||||||
|
return s.httpMaskClient, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := sudoku.NewHTTPMaskTunnelClient(cfg.ServerAddress, cfg, s.dialer.DialContext)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s.httpMaskClient = c
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Sudoku) resetHTTPMaskClient() {
|
||||||
|
s.httpMaskMu.Lock()
|
||||||
|
defer s.httpMaskMu.Unlock()
|
||||||
|
if s.httpMaskClient != nil {
|
||||||
|
s.httpMaskClient.CloseIdleConnections()
|
||||||
|
s.httpMaskClient = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1068,7 +1068,7 @@ proxies: # socks5
|
|||||||
# http-mask-mode: legacy # 可选:legacy(默认)、stream、poll、auto;stream/poll/auto 支持走 CDN/反代
|
# http-mask-mode: legacy # 可选:legacy(默认)、stream、poll、auto;stream/poll/auto 支持走 CDN/反代
|
||||||
# http-mask-tls: true # 可选:仅在 http-mask-mode 为 stream/poll/auto 时生效;true 强制 https;false 强制 http(不会根据端口自动推断)
|
# http-mask-tls: true # 可选:仅在 http-mask-mode 为 stream/poll/auto 时生效;true 强制 https;false 强制 http(不会根据端口自动推断)
|
||||||
# http-mask-host: "" # 可选:覆盖 Host/SNI(支持 example.com 或 example.com:443);仅在 http-mask-mode 为 stream/poll/auto 时生效
|
# http-mask-host: "" # 可选:覆盖 Host/SNI(支持 example.com 或 example.com:443);仅在 http-mask-mode 为 stream/poll/auto 时生效
|
||||||
# http-mask-multiplex: off # 可选:off(默认)、auto、on;复用单条隧道并在其内多路复用多个目标连接
|
# http-mask-multiplex: off # 可选:off(默认)、auto(复用 h1.1 keep-alive / h2 连接,减少每次建链 RTT)、on(单条隧道内多路复用多个目标连接;仅在 http-mask-mode=stream/poll/auto 生效)
|
||||||
enable-pure-downlink: false # 是否启用混淆下行,false的情况下能在保证数据安全的前提下极大提升下行速度,与服务端端保持相同(如果此处为false,则要求aead不可为none)
|
enable-pure-downlink: false # 是否启用混淆下行,false的情况下能在保证数据安全的前提下极大提升下行速度,与服务端端保持相同(如果此处为false,则要求aead不可为none)
|
||||||
|
|
||||||
# anytls
|
# anytls
|
||||||
|
|||||||
@@ -58,9 +58,10 @@ type ProtocolConfig struct {
|
|||||||
// HTTPMaskHost optionally overrides the HTTP Host header / SNI host for HTTP tunnel modes (client-side).
|
// HTTPMaskHost optionally overrides the HTTP Host header / SNI host for HTTP tunnel modes (client-side).
|
||||||
HTTPMaskHost string
|
HTTPMaskHost string
|
||||||
|
|
||||||
// HTTPMaskMultiplex controls whether the client reuses a single (HTTP-masked) tunnel connection and
|
// HTTPMaskMultiplex controls multiplex behavior when HTTPMask tunnel modes are enabled:
|
||||||
// opens multiple logical target streams inside it (reduces RTT for subsequent connections).
|
// - "off": disable reuse; each Dial establishes its own HTTPMask tunnel
|
||||||
// Values: "off" / "auto" / "on".
|
// - "auto": reuse underlying HTTP connections across multiple tunnel dials (HTTP/1.1 keep-alive / HTTP/2)
|
||||||
|
// - "on": enable "single tunnel, multi-target" mux (Sudoku-level multiplex; Dial behaves like "auto" otherwise)
|
||||||
HTTPMaskMultiplex string
|
HTTPMaskMultiplex string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -83,6 +83,59 @@ func DialHTTPMaskTunnel(ctx context.Context, serverAddress string, cfg *Protocol
|
|||||||
Mode: cfg.HTTPMaskMode,
|
Mode: cfg.HTTPMaskMode,
|
||||||
TLSEnabled: cfg.HTTPMaskTLSEnabled,
|
TLSEnabled: cfg.HTTPMaskTLSEnabled,
|
||||||
HostOverride: cfg.HTTPMaskHost,
|
HostOverride: cfg.HTTPMaskHost,
|
||||||
|
Multiplex: cfg.HTTPMaskMultiplex,
|
||||||
DialContext: dial,
|
DialContext: dial,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type HTTPMaskTunnelClient struct {
|
||||||
|
mode string
|
||||||
|
client *httpmask.TunnelClient
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHTTPMaskTunnelClient(serverAddress string, cfg *ProtocolConfig, dial TunnelDialer) (*HTTPMaskTunnelClient, error) {
|
||||||
|
if cfg == nil {
|
||||||
|
return nil, fmt.Errorf("config is required")
|
||||||
|
}
|
||||||
|
if cfg.DisableHTTPMask {
|
||||||
|
return nil, fmt.Errorf("http mask is disabled")
|
||||||
|
}
|
||||||
|
switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) {
|
||||||
|
case "stream", "poll", "auto":
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("http-mask-mode=%q does not use http tunnel", cfg.HTTPMaskMode)
|
||||||
|
}
|
||||||
|
switch strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMultiplex)) {
|
||||||
|
case "auto", "on":
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("http-mask-multiplex=%q does not enable reuse", cfg.HTTPMaskMultiplex)
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := httpmask.NewTunnelClient(serverAddress, httpmask.TunnelClientOptions{
|
||||||
|
TLSEnabled: cfg.HTTPMaskTLSEnabled,
|
||||||
|
HostOverride: cfg.HTTPMaskHost,
|
||||||
|
DialContext: dial,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &HTTPMaskTunnelClient{
|
||||||
|
mode: cfg.HTTPMaskMode,
|
||||||
|
client: c,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *HTTPMaskTunnelClient) Dial(ctx context.Context) (net.Conn, error) {
|
||||||
|
if c == nil || c.client == nil {
|
||||||
|
return nil, fmt.Errorf("nil httpmask tunnel client")
|
||||||
|
}
|
||||||
|
return c.client.DialTunnel(ctx, c.mode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *HTTPMaskTunnelClient) CloseIdleConnections() {
|
||||||
|
if c == nil || c.client == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.client.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
package sudoku
|
package sudoku
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/metacubex/mihomo/transport/sudoku/multiplex"
|
"github.com/metacubex/mihomo/transport/sudoku/multiplex"
|
||||||
)
|
)
|
||||||
@@ -46,26 +46,19 @@ func (c *MultiplexClient) Dial(ctx context.Context, targetAddress string) (net.C
|
|||||||
return nil, fmt.Errorf("target address cannot be empty")
|
return nil, fmt.Errorf("target address cannot be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
stream, err := c.sess.OpenStream()
|
addrBuf, err := EncodeAddress(targetAddress)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("encode target address failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx != nil && ctx.Err() != nil {
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
stream, err := c.sess.OpenStream(addrBuf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if deadline, ok := ctx.Deadline(); ok {
|
|
||||||
_ = stream.SetWriteDeadline(deadline)
|
|
||||||
defer stream.SetWriteDeadline(time.Time{})
|
|
||||||
}
|
|
||||||
|
|
||||||
addrBuf, err := EncodeAddress(targetAddress)
|
|
||||||
if err != nil {
|
|
||||||
_ = stream.Close()
|
|
||||||
return nil, fmt.Errorf("encode target address failed: %w", err)
|
|
||||||
}
|
|
||||||
if _, err := stream.Write(addrBuf); err != nil {
|
|
||||||
_ = stream.Close()
|
|
||||||
return nil, fmt.Errorf("send target address failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return stream, nil
|
return stream, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -114,18 +107,21 @@ func (s *MultiplexServer) AcceptStream() (net.Conn, error) {
|
|||||||
if s == nil || s.sess == nil {
|
if s == nil || s.sess == nil {
|
||||||
return nil, fmt.Errorf("nil session")
|
return nil, fmt.Errorf("nil session")
|
||||||
}
|
}
|
||||||
return s.sess.AcceptStream()
|
c, _, err := s.sess.AcceptStream()
|
||||||
|
return c, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// AcceptTCP accepts a multiplex stream and reads the target address preface, returning the stream positioned at
|
// AcceptTCP accepts a multiplex stream and returns the target address declared in the open frame.
|
||||||
// application data.
|
|
||||||
func (s *MultiplexServer) AcceptTCP() (net.Conn, string, error) {
|
func (s *MultiplexServer) AcceptTCP() (net.Conn, string, error) {
|
||||||
stream, err := s.AcceptStream()
|
if s == nil || s.sess == nil {
|
||||||
|
return nil, "", fmt.Errorf("nil session")
|
||||||
|
}
|
||||||
|
stream, payload, err := s.sess.AcceptStream()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
target, err := DecodeAddress(stream)
|
target, err := DecodeAddress(bytes.NewReader(payload))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = stream.Close()
|
_ = stream.Close()
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
@@ -147,4 +143,3 @@ func (s *MultiplexServer) IsClosed() bool {
|
|||||||
}
|
}
|
||||||
return s.sess.IsClosed()
|
return s.sess.IsClosed()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
39
transport/sudoku/multiplex/mux.go
Normal file
39
transport/sudoku/multiplex/mux.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package multiplex
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// MagicByte marks a Sudoku tunnel connection that will switch into multiplex mode.
|
||||||
|
// It is sent after the Sudoku handshake + downlink mode byte.
|
||||||
|
//
|
||||||
|
// Keep it distinct from UoTMagicByte and address type bytes.
|
||||||
|
MagicByte byte = 0xED
|
||||||
|
Version byte = 0x01
|
||||||
|
)
|
||||||
|
|
||||||
|
func WritePreface(w io.Writer) error {
|
||||||
|
if w == nil {
|
||||||
|
return fmt.Errorf("nil writer")
|
||||||
|
}
|
||||||
|
_, err := w.Write([]byte{MagicByte, Version})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReadVersion(r io.Reader) (byte, error) {
|
||||||
|
var b [1]byte
|
||||||
|
if _, err := io.ReadFull(r, b[:]); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return b[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidateVersion(v byte) error {
|
||||||
|
if v != Version {
|
||||||
|
return fmt.Errorf("unsupported multiplex version: %d", v)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
@@ -1,101 +1,504 @@
|
|||||||
package multiplex
|
package multiplex
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/metacubex/smux"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// MagicByte marks a Sudoku tunnel connection that will switch into multiplex mode.
|
frameOpen byte = 0x01
|
||||||
// It is sent after the Sudoku handshake + downlink mode byte.
|
frameData byte = 0x02
|
||||||
MagicByte byte = 0xEF
|
frameClose byte = 0x03
|
||||||
Version = 0x01
|
frameReset byte = 0x04
|
||||||
)
|
)
|
||||||
|
|
||||||
func WritePreface(w io.Writer) error {
|
const (
|
||||||
_, err := w.Write([]byte{MagicByte, Version})
|
headerSize = 1 + 4 + 4
|
||||||
return err
|
maxFrameSize = 256 * 1024
|
||||||
}
|
maxDataPayload = 32 * 1024
|
||||||
|
)
|
||||||
|
|
||||||
func ReadVersion(r io.Reader) (byte, error) {
|
type acceptEvent struct {
|
||||||
var b [1]byte
|
stream *stream
|
||||||
if _, err := io.ReadFull(r, b[:]); err != nil {
|
payload []byte
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return b[0], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ValidateVersion(v byte) error {
|
|
||||||
if v != Version {
|
|
||||||
return fmt.Errorf("unsupported multiplex version: %d", v)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func defaultSmuxConfig() *smux.Config {
|
|
||||||
cfg := smux.DefaultConfig()
|
|
||||||
cfg.KeepAliveInterval = 15 * time.Second
|
|
||||||
cfg.KeepAliveTimeout = 45 * time.Second
|
|
||||||
return cfg
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Session struct {
|
type Session struct {
|
||||||
sess *smux.Session
|
conn net.Conn
|
||||||
|
|
||||||
|
writeMu sync.Mutex
|
||||||
|
|
||||||
|
streamsMu sync.Mutex
|
||||||
|
streams map[uint32]*stream
|
||||||
|
nextID uint32
|
||||||
|
|
||||||
|
acceptCh chan acceptEvent
|
||||||
|
|
||||||
|
closed chan struct{}
|
||||||
|
closeOnce sync.Once
|
||||||
|
closeErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClientSession(conn net.Conn) (*Session, error) {
|
func NewClientSession(conn net.Conn) (*Session, error) {
|
||||||
if conn == nil {
|
if conn == nil {
|
||||||
return nil, fmt.Errorf("nil conn")
|
return nil, fmt.Errorf("nil conn")
|
||||||
}
|
}
|
||||||
s, err := smux.Client(conn, defaultSmuxConfig())
|
s := &Session{
|
||||||
if err != nil {
|
conn: conn,
|
||||||
_ = conn.Close()
|
streams: make(map[uint32]*stream),
|
||||||
return nil, err
|
closed: make(chan struct{}),
|
||||||
}
|
}
|
||||||
return &Session{sess: s}, nil
|
go s.readLoop()
|
||||||
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServerSession(conn net.Conn) (*Session, error) {
|
func NewServerSession(conn net.Conn) (*Session, error) {
|
||||||
if conn == nil {
|
if conn == nil {
|
||||||
return nil, fmt.Errorf("nil conn")
|
return nil, fmt.Errorf("nil conn")
|
||||||
}
|
}
|
||||||
s, err := smux.Server(conn, defaultSmuxConfig())
|
s := &Session{
|
||||||
if err != nil {
|
conn: conn,
|
||||||
_ = conn.Close()
|
streams: make(map[uint32]*stream),
|
||||||
return nil, err
|
acceptCh: make(chan acceptEvent, 256),
|
||||||
|
closed: make(chan struct{}),
|
||||||
}
|
}
|
||||||
return &Session{sess: s}, nil
|
go s.readLoop()
|
||||||
}
|
return s, nil
|
||||||
|
|
||||||
func (s *Session) OpenStream() (net.Conn, error) {
|
|
||||||
if s == nil || s.sess == nil {
|
|
||||||
return nil, fmt.Errorf("nil session")
|
|
||||||
}
|
|
||||||
return s.sess.OpenStream()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) AcceptStream() (net.Conn, error) {
|
|
||||||
if s == nil || s.sess == nil {
|
|
||||||
return nil, fmt.Errorf("nil session")
|
|
||||||
}
|
|
||||||
return s.sess.AcceptStream()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Session) Close() error {
|
|
||||||
if s == nil || s.sess == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return s.sess.Close()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) IsClosed() bool {
|
func (s *Session) IsClosed() bool {
|
||||||
if s == nil || s.sess == nil {
|
if s == nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return s.sess.IsClosed()
|
select {
|
||||||
|
case <-s.closed:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Session) closedErr() error {
|
||||||
|
s.streamsMu.Lock()
|
||||||
|
err := s.closeErr
|
||||||
|
s.streamsMu.Unlock()
|
||||||
|
if err == nil {
|
||||||
|
return io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) closeWithError(err error) {
|
||||||
|
if err == nil {
|
||||||
|
err = io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
s.closeOnce.Do(func() {
|
||||||
|
s.streamsMu.Lock()
|
||||||
|
if s.closeErr == nil {
|
||||||
|
s.closeErr = err
|
||||||
|
}
|
||||||
|
streams := make([]*stream, 0, len(s.streams))
|
||||||
|
for _, st := range s.streams {
|
||||||
|
streams = append(streams, st)
|
||||||
|
}
|
||||||
|
s.streams = make(map[uint32]*stream)
|
||||||
|
s.streamsMu.Unlock()
|
||||||
|
|
||||||
|
for _, st := range streams {
|
||||||
|
st.closeNoSend(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
close(s.closed)
|
||||||
|
_ = s.conn.Close()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) Close() error {
|
||||||
|
if s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s.closeWithError(io.ErrClosedPipe)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) registerStream(st *stream) {
|
||||||
|
s.streamsMu.Lock()
|
||||||
|
s.streams[st.id] = st
|
||||||
|
s.streamsMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) getStream(id uint32) *stream {
|
||||||
|
s.streamsMu.Lock()
|
||||||
|
st := s.streams[id]
|
||||||
|
s.streamsMu.Unlock()
|
||||||
|
return st
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) removeStream(id uint32) {
|
||||||
|
s.streamsMu.Lock()
|
||||||
|
delete(s.streams, id)
|
||||||
|
s.streamsMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) nextStreamID() uint32 {
|
||||||
|
s.streamsMu.Lock()
|
||||||
|
s.nextID++
|
||||||
|
id := s.nextID
|
||||||
|
if id == 0 {
|
||||||
|
s.nextID++
|
||||||
|
id = s.nextID
|
||||||
|
}
|
||||||
|
s.streamsMu.Unlock()
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) sendFrame(frameType byte, streamID uint32, payload []byte) error {
|
||||||
|
if s.IsClosed() {
|
||||||
|
return s.closedErr()
|
||||||
|
}
|
||||||
|
if len(payload) > maxFrameSize {
|
||||||
|
return fmt.Errorf("mux payload too large: %d", len(payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
var header [headerSize]byte
|
||||||
|
header[0] = frameType
|
||||||
|
binary.BigEndian.PutUint32(header[1:5], streamID)
|
||||||
|
binary.BigEndian.PutUint32(header[5:9], uint32(len(payload)))
|
||||||
|
|
||||||
|
s.writeMu.Lock()
|
||||||
|
defer s.writeMu.Unlock()
|
||||||
|
|
||||||
|
if err := writeFull(s.conn, header[:]); err != nil {
|
||||||
|
s.closeWithError(err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(payload) > 0 {
|
||||||
|
if err := writeFull(s.conn, payload); err != nil {
|
||||||
|
s.closeWithError(err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) sendReset(streamID uint32, msg string) {
|
||||||
|
if msg == "" {
|
||||||
|
msg = "reset"
|
||||||
|
}
|
||||||
|
_ = s.sendFrame(frameReset, streamID, []byte(msg))
|
||||||
|
_ = s.sendFrame(frameClose, streamID, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) OpenStream(openPayload []byte) (net.Conn, error) {
|
||||||
|
if s == nil {
|
||||||
|
return nil, fmt.Errorf("nil session")
|
||||||
|
}
|
||||||
|
if s.IsClosed() {
|
||||||
|
return nil, s.closedErr()
|
||||||
|
}
|
||||||
|
|
||||||
|
streamID := s.nextStreamID()
|
||||||
|
st := newStream(s, streamID)
|
||||||
|
s.registerStream(st)
|
||||||
|
|
||||||
|
if err := s.sendFrame(frameOpen, streamID, openPayload); err != nil {
|
||||||
|
st.closeNoSend(err)
|
||||||
|
s.removeStream(streamID)
|
||||||
|
return nil, fmt.Errorf("mux open failed: %w", err)
|
||||||
|
}
|
||||||
|
return st, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) AcceptStream() (net.Conn, []byte, error) {
|
||||||
|
if s == nil {
|
||||||
|
return nil, nil, fmt.Errorf("nil session")
|
||||||
|
}
|
||||||
|
if s.acceptCh == nil {
|
||||||
|
return nil, nil, fmt.Errorf("accept is not supported on client sessions")
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case ev := <-s.acceptCh:
|
||||||
|
return ev.stream, ev.payload, nil
|
||||||
|
case <-s.closed:
|
||||||
|
return nil, nil, s.closedErr()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) readLoop() {
|
||||||
|
var header [headerSize]byte
|
||||||
|
for {
|
||||||
|
if _, err := io.ReadFull(s.conn, header[:]); err != nil {
|
||||||
|
s.closeWithError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
frameType := header[0]
|
||||||
|
streamID := binary.BigEndian.Uint32(header[1:5])
|
||||||
|
n := int(binary.BigEndian.Uint32(header[5:9]))
|
||||||
|
if n < 0 || n > maxFrameSize {
|
||||||
|
s.closeWithError(fmt.Errorf("invalid mux frame length: %d", n))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var payload []byte
|
||||||
|
if n > 0 {
|
||||||
|
payload = make([]byte, n)
|
||||||
|
if _, err := io.ReadFull(s.conn, payload); err != nil {
|
||||||
|
s.closeWithError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch frameType {
|
||||||
|
case frameOpen:
|
||||||
|
if s.acceptCh == nil {
|
||||||
|
s.sendReset(streamID, "unexpected open")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if streamID == 0 {
|
||||||
|
s.sendReset(streamID, "invalid stream id")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if existing := s.getStream(streamID); existing != nil {
|
||||||
|
s.sendReset(streamID, "stream already exists")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
st := newStream(s, streamID)
|
||||||
|
s.registerStream(st)
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case s.acceptCh <- acceptEvent{stream: st, payload: payload}:
|
||||||
|
case <-s.closed:
|
||||||
|
st.closeNoSend(io.ErrClosedPipe)
|
||||||
|
s.removeStream(streamID)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
case frameData:
|
||||||
|
st := s.getStream(streamID)
|
||||||
|
if st == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if len(payload) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
st.enqueue(payload)
|
||||||
|
|
||||||
|
case frameClose:
|
||||||
|
st := s.getStream(streamID)
|
||||||
|
if st == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
st.closeNoSend(io.EOF)
|
||||||
|
s.removeStream(streamID)
|
||||||
|
|
||||||
|
case frameReset:
|
||||||
|
st := s.getStream(streamID)
|
||||||
|
if st == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
msg := trimASCII(payload)
|
||||||
|
if msg == "" {
|
||||||
|
msg = "reset"
|
||||||
|
}
|
||||||
|
st.closeNoSend(errors.New(msg))
|
||||||
|
s.removeStream(streamID)
|
||||||
|
|
||||||
|
default:
|
||||||
|
s.closeWithError(fmt.Errorf("unknown mux frame type: %d", frameType))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeFull(w io.Writer, b []byte) error {
|
||||||
|
for len(b) > 0 {
|
||||||
|
n, err := w.Write(b)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
b = b[n:]
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func trimASCII(b []byte) string {
|
||||||
|
i := 0
|
||||||
|
j := len(b)
|
||||||
|
for i < j {
|
||||||
|
c := b[i]
|
||||||
|
if c != ' ' && c != '\n' && c != '\r' && c != '\t' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
for j > i {
|
||||||
|
c := b[j-1]
|
||||||
|
if c != ' ' && c != '\n' && c != '\r' && c != '\t' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
j--
|
||||||
|
}
|
||||||
|
if i >= j {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
out := make([]byte, j-i)
|
||||||
|
copy(out, b[i:j])
|
||||||
|
return string(out)
|
||||||
|
}
|
||||||
|
|
||||||
|
type stream struct {
|
||||||
|
session *Session
|
||||||
|
id uint32
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
cond *sync.Cond
|
||||||
|
closed bool
|
||||||
|
closeErr error
|
||||||
|
readBuf []byte
|
||||||
|
queue [][]byte
|
||||||
|
|
||||||
|
localAddr net.Addr
|
||||||
|
remoteAddr net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStream(session *Session, id uint32) *stream {
|
||||||
|
st := &stream{
|
||||||
|
session: session,
|
||||||
|
id: id,
|
||||||
|
localAddr: &net.TCPAddr{},
|
||||||
|
remoteAddr: &net.TCPAddr{},
|
||||||
|
}
|
||||||
|
st.cond = sync.NewCond(&st.mu)
|
||||||
|
return st
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *stream) enqueue(payload []byte) {
|
||||||
|
c.mu.Lock()
|
||||||
|
if c.closed {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.queue = append(c.queue, payload)
|
||||||
|
c.cond.Signal()
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *stream) closeNoSend(err error) {
|
||||||
|
if err == nil {
|
||||||
|
err = io.EOF
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
if c.closed {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.closed = true
|
||||||
|
if c.closeErr == nil {
|
||||||
|
c.closeErr = err
|
||||||
|
}
|
||||||
|
c.cond.Broadcast()
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *stream) closedErr() error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
if c.closeErr == nil {
|
||||||
|
return io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
return c.closeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *stream) Read(p []byte) (int, error) {
|
||||||
|
if len(p) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
for len(c.readBuf) == 0 && len(c.queue) == 0 && !c.closed {
|
||||||
|
c.cond.Wait()
|
||||||
|
}
|
||||||
|
if len(c.readBuf) == 0 && len(c.queue) > 0 {
|
||||||
|
c.readBuf = c.queue[0]
|
||||||
|
c.queue = c.queue[1:]
|
||||||
|
}
|
||||||
|
if len(c.readBuf) == 0 && c.closed {
|
||||||
|
if c.closeErr == nil {
|
||||||
|
return 0, io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
return 0, c.closeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
n := copy(p, c.readBuf)
|
||||||
|
c.readBuf = c.readBuf[n:]
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *stream) Write(p []byte) (int, error) {
|
||||||
|
if len(p) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
if c.session == nil || c.session.IsClosed() {
|
||||||
|
if c.session != nil {
|
||||||
|
return 0, c.session.closedErr()
|
||||||
|
}
|
||||||
|
return 0, io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
closed := c.closed
|
||||||
|
c.mu.Unlock()
|
||||||
|
if closed {
|
||||||
|
return 0, c.closedErr()
|
||||||
|
}
|
||||||
|
|
||||||
|
written := 0
|
||||||
|
for len(p) > 0 {
|
||||||
|
chunk := p
|
||||||
|
if len(chunk) > maxDataPayload {
|
||||||
|
chunk = p[:maxDataPayload]
|
||||||
|
}
|
||||||
|
if err := c.session.sendFrame(frameData, c.id, chunk); err != nil {
|
||||||
|
return written, err
|
||||||
|
}
|
||||||
|
written += len(chunk)
|
||||||
|
p = p[len(chunk):]
|
||||||
|
}
|
||||||
|
return written, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *stream) Close() error {
|
||||||
|
c.mu.Lock()
|
||||||
|
if c.closed {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
c.closed = true
|
||||||
|
if c.closeErr == nil {
|
||||||
|
c.closeErr = io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
c.cond.Broadcast()
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
_ = c.session.sendFrame(frameClose, c.id, nil)
|
||||||
|
c.session.removeStream(c.id)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *stream) LocalAddr() net.Addr { return c.localAddr }
|
||||||
|
func (c *stream) RemoteAddr() net.Addr { return c.remoteAddr }
|
||||||
|
|
||||||
|
func (c *stream) SetDeadline(t time.Time) error {
|
||||||
|
_ = c.SetReadDeadline(t)
|
||||||
|
_ = c.SetWriteDeadline(t)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (c *stream) SetReadDeadline(time.Time) error { return nil }
|
||||||
|
func (c *stream) SetWriteDeadline(time.Time) error { return nil }
|
||||||
|
|
||||||
|
|||||||
@@ -62,11 +62,84 @@ type TunnelDialOptions struct {
|
|||||||
Mode string
|
Mode string
|
||||||
TLSEnabled bool // when true, use HTTPS; otherwise, use HTTP (no port-based inference)
|
TLSEnabled bool // when true, use HTTPS; otherwise, use HTTP (no port-based inference)
|
||||||
HostOverride string // optional Host header / SNI host (without scheme); accepts "example.com" or "example.com:443"
|
HostOverride string // optional Host header / SNI host (without scheme); accepts "example.com" or "example.com:443"
|
||||||
|
// Multiplex controls whether the caller should reuse underlying HTTP connections (HTTP/1.1 keep-alive / HTTP/2).
|
||||||
|
// To reuse across multiple dials, create a TunnelClient per proxy and reuse it.
|
||||||
|
// Values: "off" disables reuse; "auto"/"on" enables it.
|
||||||
|
Multiplex string
|
||||||
// DialContext overrides how the HTTP tunnel dials raw TCP/TLS connections.
|
// DialContext overrides how the HTTP tunnel dials raw TCP/TLS connections.
|
||||||
// It must not be nil; passing nil is a programming error.
|
// It must not be nil; passing nil is a programming error.
|
||||||
DialContext func(ctx context.Context, network, addr string) (net.Conn, error)
|
DialContext func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TunnelClientOptions struct {
|
||||||
|
TLSEnabled bool
|
||||||
|
HostOverride string
|
||||||
|
DialContext func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||||
|
MaxIdleConns int
|
||||||
|
}
|
||||||
|
|
||||||
|
type TunnelClient struct {
|
||||||
|
client *http.Client
|
||||||
|
transport *http.Transport
|
||||||
|
target httpClientTarget
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTunnelClient(serverAddress string, opts TunnelClientOptions) (*TunnelClient, error) {
|
||||||
|
maxIdle := opts.MaxIdleConns
|
||||||
|
if maxIdle <= 0 {
|
||||||
|
maxIdle = 32
|
||||||
|
}
|
||||||
|
|
||||||
|
transport, target, err := buildHTTPTransport(serverAddress, opts.TLSEnabled, opts.HostOverride, opts.DialContext, maxIdle)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &TunnelClient{
|
||||||
|
client: &http.Client{Transport: transport},
|
||||||
|
transport: transport,
|
||||||
|
target: target,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TunnelClient) CloseIdleConnections() {
|
||||||
|
if c == nil || c.transport == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.transport.CloseIdleConnections()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *TunnelClient) DialTunnel(ctx context.Context, mode string) (net.Conn, error) {
|
||||||
|
if c == nil || c.client == nil {
|
||||||
|
return nil, fmt.Errorf("nil tunnel client")
|
||||||
|
}
|
||||||
|
tm := normalizeTunnelMode(mode)
|
||||||
|
if tm == TunnelModeLegacy {
|
||||||
|
return nil, fmt.Errorf("legacy mode does not use http tunnel")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch tm {
|
||||||
|
case TunnelModeStream:
|
||||||
|
return dialStreamWithClient(ctx, c.client, c.target)
|
||||||
|
case TunnelModePoll:
|
||||||
|
return dialPollWithClient(ctx, c.client, c.target)
|
||||||
|
case TunnelModeAuto:
|
||||||
|
streamCtx, cancelX := context.WithTimeout(ctx, 3*time.Second)
|
||||||
|
c1, errX := dialStreamWithClient(streamCtx, c.client, c.target)
|
||||||
|
cancelX()
|
||||||
|
if errX == nil {
|
||||||
|
return c1, nil
|
||||||
|
}
|
||||||
|
c2, errP := dialPollWithClient(ctx, c.client, c.target)
|
||||||
|
if errP == nil {
|
||||||
|
return c2, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("auto tunnel failed: stream: %v; poll: %w", errX, errP)
|
||||||
|
default:
|
||||||
|
return dialStreamWithClient(ctx, c.client, c.target)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// DialTunnel establishes a bidirectional stream over HTTP:
|
// DialTunnel establishes a bidirectional stream over HTTP:
|
||||||
// - stream: a single streaming POST (request body uplink, response body downlink)
|
// - stream: a single streaming POST (request body uplink, response body downlink)
|
||||||
// - poll: authorize + push/pull polling tunnel (base64 framed)
|
// - poll: authorize + push/pull polling tunnel (base64 framed)
|
||||||
@@ -192,43 +265,154 @@ type httpClientTarget struct {
|
|||||||
headerHost string
|
headerHost string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHTTPClient(serverAddress string, opts TunnelDialOptions, maxIdleConns int) (*http.Client, httpClientTarget, error) {
|
func buildHTTPTransport(serverAddress string, tlsEnabled bool, hostOverride string, dialContext func(ctx context.Context, network, addr string) (net.Conn, error), maxIdleConns int) (*http.Transport, httpClientTarget, error) {
|
||||||
if opts.DialContext == nil {
|
if dialContext == nil {
|
||||||
panic("httpmask: DialContext is nil")
|
panic("httpmask: DialContext is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
scheme, urlHost, dialAddr, serverName, err := normalizeHTTPDialTarget(serverAddress, opts.TLSEnabled, opts.HostOverride)
|
scheme, urlHost, dialAddr, serverName, err := normalizeHTTPDialTarget(serverAddress, tlsEnabled, hostOverride)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, httpClientTarget{}, err
|
return nil, httpClientTarget{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
transport := &http.Transport{
|
transport := &http.Transport{
|
||||||
ForceAttemptHTTP2: scheme == "https",
|
ForceAttemptHTTP2: scheme == "https",
|
||||||
DisableCompression: true,
|
DisableCompression: true,
|
||||||
MaxIdleConns: maxIdleConns,
|
MaxIdleConns: maxIdleConns,
|
||||||
IdleConnTimeout: 30 * time.Second,
|
MaxIdleConnsPerHost: maxIdleConns,
|
||||||
TLSHandshakeTimeout: 10 * time.Second,
|
IdleConnTimeout: 30 * time.Second,
|
||||||
|
ResponseHeaderTimeout: 20 * time.Second,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
DialContext: func(dialCtx context.Context, network, _ string) (net.Conn, error) {
|
DialContext: func(dialCtx context.Context, network, _ string) (net.Conn, error) {
|
||||||
return opts.DialContext(dialCtx, network, dialAddr)
|
return dialContext(dialCtx, network, dialAddr)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if scheme == "https" {
|
if scheme == "https" {
|
||||||
transport.TLSClientConfig, err = ca.GetTLSConfig(ca.Option{TLSConfig: &tls.Config{
|
var tlsConf *tls.Config
|
||||||
|
tlsConf, err = ca.GetTLSConfig(ca.Option{TLSConfig: &tls.Config{
|
||||||
ServerName: serverName,
|
ServerName: serverName,
|
||||||
MinVersion: tls.VersionTLS12,
|
MinVersion: tls.VersionTLS12,
|
||||||
}})
|
}})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, httpClientTarget{}, err
|
return nil, httpClientTarget{}, err
|
||||||
}
|
}
|
||||||
|
transport.TLSClientConfig = tlsConf
|
||||||
}
|
}
|
||||||
|
|
||||||
return &http.Client{Transport: transport}, httpClientTarget{
|
return transport, httpClientTarget{
|
||||||
scheme: scheme,
|
scheme: scheme,
|
||||||
urlHost: urlHost,
|
urlHost: urlHost,
|
||||||
headerHost: canonicalHeaderHost(urlHost, scheme),
|
headerHost: canonicalHeaderHost(urlHost, scheme),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newHTTPClient(serverAddress string, opts TunnelDialOptions, maxIdleConns int) (*http.Client, httpClientTarget, error) {
|
||||||
|
transport, target, err := buildHTTPTransport(serverAddress, opts.TLSEnabled, opts.HostOverride, opts.DialContext, maxIdleConns)
|
||||||
|
if err != nil {
|
||||||
|
return nil, httpClientTarget{}, err
|
||||||
|
}
|
||||||
|
return &http.Client{Transport: transport}, target, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type sessionDialInfo struct {
|
||||||
|
client *http.Client
|
||||||
|
pushURL string
|
||||||
|
pullURL string
|
||||||
|
closeURL string
|
||||||
|
headerHost string
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialSessionWithClient(ctx context.Context, client *http.Client, target httpClientTarget, mode TunnelMode) (*sessionDialInfo, error) {
|
||||||
|
if client == nil {
|
||||||
|
return nil, fmt.Errorf("nil http client")
|
||||||
|
}
|
||||||
|
|
||||||
|
authorizeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/session"}).String()
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, authorizeURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Host = target.headerHost
|
||||||
|
applyTunnelHeaders(req.Header, target.headerHost, mode)
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, 4*1024))
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("%s authorize bad status: %s (%s)", mode, resp.Status, strings.TrimSpace(string(bodyBytes)))
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := parseTunnelToken(bodyBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("%s authorize failed: %q", mode, strings.TrimSpace(string(bodyBytes)))
|
||||||
|
}
|
||||||
|
if token == "" {
|
||||||
|
return nil, fmt.Errorf("%s authorize empty token", mode)
|
||||||
|
}
|
||||||
|
|
||||||
|
pushURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/api/v1/upload", RawQuery: "token=" + url.QueryEscape(token)}).String()
|
||||||
|
pullURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/stream", RawQuery: "token=" + url.QueryEscape(token)}).String()
|
||||||
|
closeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/api/v1/upload", RawQuery: "token=" + url.QueryEscape(token) + "&close=1"}).String()
|
||||||
|
|
||||||
|
return &sessionDialInfo{
|
||||||
|
client: client,
|
||||||
|
pushURL: pushURL,
|
||||||
|
pullURL: pullURL,
|
||||||
|
closeURL: closeURL,
|
||||||
|
headerHost: target.headerHost,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialSession(ctx context.Context, serverAddress string, opts TunnelDialOptions, mode TunnelMode) (*sessionDialInfo, error) {
|
||||||
|
client, target, err := newHTTPClient(serverAddress, opts, 32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return dialSessionWithClient(ctx, client, target, mode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func bestEffortCloseSession(client *http.Client, closeURL, headerHost string, mode TunnelMode) {
|
||||||
|
if client == nil || closeURL == "" || headerHost == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
closeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(closeCtx, http.MethodPost, closeURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
req.Host = headerHost
|
||||||
|
applyTunnelHeaders(req.Header, headerHost, mode)
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil || resp == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024))
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialStreamWithClient(ctx context.Context, client *http.Client, target httpClientTarget) (net.Conn, error) {
|
||||||
|
// Prefer split session (Cloudflare-friendly). Fall back to stream-one for older servers / environments.
|
||||||
|
c, errSplit := dialStreamSplitWithClient(ctx, client, target)
|
||||||
|
if errSplit == nil {
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
c2, errOne := dialStreamOneWithClient(ctx, client, target)
|
||||||
|
if errOne == nil {
|
||||||
|
return c2, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("dial stream failed: split: %v; stream-one: %w", errSplit, errOne)
|
||||||
|
}
|
||||||
|
|
||||||
func dialStream(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
|
func dialStream(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
|
||||||
// Prefer split session (Cloudflare-friendly). Fall back to stream-one for older servers / environments.
|
// Prefer split session (Cloudflare-friendly). Fall back to stream-one for older servers / environments.
|
||||||
c, errSplit := dialStreamSplit(ctx, serverAddress, opts)
|
c, errSplit := dialStreamSplit(ctx, serverAddress, opts)
|
||||||
@@ -242,10 +426,9 @@ func dialStream(ctx context.Context, serverAddress string, opts TunnelDialOption
|
|||||||
return nil, fmt.Errorf("dial stream failed: split: %v; stream-one: %w", errSplit, errOne)
|
return nil, fmt.Errorf("dial stream failed: split: %v; stream-one: %w", errSplit, errOne)
|
||||||
}
|
}
|
||||||
|
|
||||||
func dialStreamOne(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
|
func dialStreamOneWithClient(ctx context.Context, client *http.Client, target httpClientTarget) (net.Conn, error) {
|
||||||
client, target, err := newHTTPClient(serverAddress, opts, 16)
|
if client == nil {
|
||||||
if err != nil {
|
return nil, fmt.Errorf("nil http client")
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
r := rngPool.Get().(*mrand.Rand)
|
r := rngPool.Get().(*mrand.Rand)
|
||||||
@@ -312,16 +495,15 @@ func dialStreamOne(ctx context.Context, serverAddress string, opts TunnelDialOpt
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type streamSplitConn struct {
|
func dialStreamOne(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
|
||||||
ctx context.Context
|
client, target, err := newHTTPClient(serverAddress, opts, 32)
|
||||||
cancel context.CancelFunc
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
client *http.Client
|
}
|
||||||
pushURL string
|
return dialStreamOneWithClient(ctx, client, target)
|
||||||
pullURL string
|
}
|
||||||
closeURL string
|
|
||||||
headerHost string
|
|
||||||
|
|
||||||
|
type queuedConn struct {
|
||||||
rxc chan []byte
|
rxc chan []byte
|
||||||
closed chan struct{}
|
closed chan struct{}
|
||||||
|
|
||||||
@@ -329,16 +511,46 @@ type streamSplitConn struct {
|
|||||||
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
readBuf []byte
|
readBuf []byte
|
||||||
|
closeErr error
|
||||||
localAddr net.Addr
|
localAddr net.Addr
|
||||||
remoteAddr net.Addr
|
remoteAddr net.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *streamSplitConn) Read(b []byte) (n int, err error) {
|
func (c *queuedConn) closeWithError(err error) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
select {
|
||||||
|
case <-c.closed:
|
||||||
|
c.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
if err == nil {
|
||||||
|
err = io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
if c.closeErr == nil {
|
||||||
|
c.closeErr = err
|
||||||
|
}
|
||||||
|
close(c.closed)
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *queuedConn) closedErr() error {
|
||||||
|
c.mu.Lock()
|
||||||
|
err := c.closeErr
|
||||||
|
c.mu.Unlock()
|
||||||
|
if err == nil {
|
||||||
|
return io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *queuedConn) Read(b []byte) (n int, err error) {
|
||||||
if len(c.readBuf) == 0 {
|
if len(c.readBuf) == 0 {
|
||||||
select {
|
select {
|
||||||
case c.readBuf = <-c.rxc:
|
case c.readBuf = <-c.rxc:
|
||||||
case <-c.closed:
|
case <-c.closed:
|
||||||
return 0, io.ErrClosedPipe
|
return 0, c.closedErr()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
n = copy(b, c.readBuf)
|
n = copy(b, c.readBuf)
|
||||||
@@ -346,7 +558,7 @@ func (c *streamSplitConn) Read(b []byte) (n int, err error) {
|
|||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *streamSplitConn) Write(b []byte) (n int, err error) {
|
func (c *queuedConn) Write(b []byte) (n int, err error) {
|
||||||
if len(b) == 0 {
|
if len(b) == 0 {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
@@ -354,7 +566,7 @@ func (c *streamSplitConn) Write(b []byte) (n int, err error) {
|
|||||||
select {
|
select {
|
||||||
case <-c.closed:
|
case <-c.closed:
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
return 0, io.ErrClosedPipe
|
return 0, c.closedErr()
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
@@ -365,111 +577,97 @@ func (c *streamSplitConn) Write(b []byte) (n int, err error) {
|
|||||||
case c.writeCh <- payload:
|
case c.writeCh <- payload:
|
||||||
return len(b), nil
|
return len(b), nil
|
||||||
case <-c.closed:
|
case <-c.closed:
|
||||||
return 0, io.ErrClosedPipe
|
return 0, c.closedErr()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *queuedConn) LocalAddr() net.Addr { return c.localAddr }
|
||||||
|
func (c *queuedConn) RemoteAddr() net.Addr { return c.remoteAddr }
|
||||||
|
|
||||||
|
func (c *queuedConn) SetDeadline(time.Time) error { return nil }
|
||||||
|
func (c *queuedConn) SetReadDeadline(time.Time) error { return nil }
|
||||||
|
func (c *queuedConn) SetWriteDeadline(time.Time) error { return nil }
|
||||||
|
|
||||||
|
type streamSplitConn struct {
|
||||||
|
queuedConn
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
|
||||||
|
client *http.Client
|
||||||
|
pushURL string
|
||||||
|
pullURL string
|
||||||
|
closeURL string
|
||||||
|
headerHost string
|
||||||
|
}
|
||||||
|
|
||||||
func (c *streamSplitConn) Close() error {
|
func (c *streamSplitConn) Close() error {
|
||||||
c.mu.Lock()
|
_ = c.closeWithError(io.ErrClosedPipe)
|
||||||
select {
|
|
||||||
case <-c.closed:
|
|
||||||
c.mu.Unlock()
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
close(c.closed)
|
|
||||||
}
|
|
||||||
c.mu.Unlock()
|
|
||||||
|
|
||||||
if c.cancel != nil {
|
if c.cancel != nil {
|
||||||
c.cancel()
|
c.cancel()
|
||||||
}
|
}
|
||||||
|
bestEffortCloseSession(c.client, c.closeURL, c.headerHost, TunnelModeStream)
|
||||||
// Best-effort session close signal (avoid leaking server-side sessions).
|
|
||||||
closeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
req, err := http.NewRequestWithContext(closeCtx, http.MethodPost, c.closeURL, nil)
|
|
||||||
if err == nil {
|
|
||||||
req.Host = c.headerHost
|
|
||||||
applyTunnelHeaders(req.Header, c.headerHost, TunnelModeStream)
|
|
||||||
if resp, doErr := c.client.Do(req); doErr == nil && resp != nil {
|
|
||||||
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024))
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *streamSplitConn) LocalAddr() net.Addr { return c.localAddr }
|
func newStreamSplitConnFromInfo(info *sessionDialInfo) *streamSplitConn {
|
||||||
func (c *streamSplitConn) RemoteAddr() net.Addr { return c.remoteAddr }
|
if info == nil {
|
||||||
|
return nil
|
||||||
func (c *streamSplitConn) SetDeadline(time.Time) error { return nil }
|
|
||||||
func (c *streamSplitConn) SetReadDeadline(time.Time) error { return nil }
|
|
||||||
func (c *streamSplitConn) SetWriteDeadline(time.Time) error { return nil }
|
|
||||||
|
|
||||||
func dialStreamSplit(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
|
|
||||||
client, target, err := newHTTPClient(serverAddress, opts, 32)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
authorizeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/session"}).String()
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, authorizeURL, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req.Host = target.headerHost
|
|
||||||
applyTunnelHeaders(req.Header, target.headerHost, TunnelModeStream)
|
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, 4*1024))
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return nil, fmt.Errorf("stream authorize bad status: %s (%s)", resp.Status, strings.TrimSpace(string(bodyBytes)))
|
|
||||||
}
|
|
||||||
|
|
||||||
token, err := parseTunnelToken(bodyBytes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("stream authorize failed: %q", strings.TrimSpace(string(bodyBytes)))
|
|
||||||
}
|
|
||||||
if token == "" {
|
|
||||||
return nil, fmt.Errorf("stream authorize empty token")
|
|
||||||
}
|
|
||||||
|
|
||||||
pushURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/api/v1/upload", RawQuery: "token=" + url.QueryEscape(token)}).String()
|
|
||||||
pullURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/stream", RawQuery: "token=" + url.QueryEscape(token)}).String()
|
|
||||||
closeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/api/v1/upload", RawQuery: "token=" + url.QueryEscape(token) + "&close=1"}).String()
|
|
||||||
|
|
||||||
connCtx, cancel := context.WithCancel(context.Background())
|
connCtx, cancel := context.WithCancel(context.Background())
|
||||||
c := &streamSplitConn{
|
c := &streamSplitConn{
|
||||||
ctx: connCtx,
|
ctx: connCtx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
client: client,
|
client: info.client,
|
||||||
pushURL: pushURL,
|
pushURL: info.pushURL,
|
||||||
pullURL: pullURL,
|
pullURL: info.pullURL,
|
||||||
closeURL: closeURL,
|
closeURL: info.closeURL,
|
||||||
headerHost: target.headerHost,
|
headerHost: info.headerHost,
|
||||||
rxc: make(chan []byte, 256),
|
queuedConn: queuedConn{
|
||||||
closed: make(chan struct{}),
|
rxc: make(chan []byte, 256),
|
||||||
writeCh: make(chan []byte, 256),
|
closed: make(chan struct{}),
|
||||||
localAddr: &net.TCPAddr{},
|
writeCh: make(chan []byte, 256),
|
||||||
remoteAddr: &net.TCPAddr{},
|
localAddr: &net.TCPAddr{},
|
||||||
|
remoteAddr: &net.TCPAddr{},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
go c.pullLoop()
|
go c.pullLoop()
|
||||||
go c.pushLoop()
|
go c.pushLoop()
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialStreamSplitWithClient(ctx context.Context, client *http.Client, target httpClientTarget) (net.Conn, error) {
|
||||||
|
info, err := dialSessionWithClient(ctx, client, target, TunnelModeStream)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
c := newStreamSplitConnFromInfo(info)
|
||||||
|
if c == nil {
|
||||||
|
return nil, fmt.Errorf("failed to build stream split conn")
|
||||||
|
}
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialStreamSplit(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
|
||||||
|
info, err := dialSession(ctx, serverAddress, opts, TunnelModeStream)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
c := newStreamSplitConnFromInfo(info)
|
||||||
|
if c == nil {
|
||||||
|
return nil, fmt.Errorf("failed to build stream split conn")
|
||||||
|
}
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *streamSplitConn) pullLoop() {
|
func (c *streamSplitConn) pullLoop() {
|
||||||
const (
|
const (
|
||||||
requestTimeout = 30 * time.Second
|
// requestTimeout must be long enough for continuous high-throughput streams (e.g. mux + large downloads).
|
||||||
|
// If it is too short, the client cancels the response mid-body and corrupts the byte stream.
|
||||||
|
requestTimeout = 2 * time.Minute
|
||||||
readChunkSize = 32 * 1024
|
readChunkSize = 32 * 1024
|
||||||
idleBackoff = 25 * time.Millisecond
|
idleBackoff = 25 * time.Millisecond
|
||||||
maxDialRetry = 12
|
maxDialRetry = 12
|
||||||
@@ -688,22 +886,16 @@ func (c *streamSplitConn) pushLoop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type pollConn struct {
|
type pollConn struct {
|
||||||
|
queuedConn
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
|
||||||
client *http.Client
|
client *http.Client
|
||||||
pushURL string
|
pushURL string
|
||||||
pullURL string
|
pullURL string
|
||||||
closeURL string
|
closeURL string
|
||||||
headerHost string
|
headerHost string
|
||||||
|
|
||||||
rxc chan []byte
|
|
||||||
closed chan struct{}
|
|
||||||
|
|
||||||
writeCh chan []byte
|
|
||||||
|
|
||||||
mu sync.Mutex
|
|
||||||
readBuf []byte
|
|
||||||
closeErr error
|
|
||||||
localAddr net.Addr
|
|
||||||
remoteAddr net.Addr
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func isDialError(err error) bool {
|
func isDialError(err error) bool {
|
||||||
@@ -721,147 +913,67 @@ func isDialError(err error) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *pollConn) closeWithError(err error) error {
|
func (c *pollConn) closeWithError(err error) error {
|
||||||
c.mu.Lock()
|
_ = c.queuedConn.closeWithError(err)
|
||||||
select {
|
if c.cancel != nil {
|
||||||
case <-c.closed:
|
c.cancel()
|
||||||
c.mu.Unlock()
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
if err == nil {
|
|
||||||
err = io.ErrClosedPipe
|
|
||||||
}
|
|
||||||
if c.closeErr == nil {
|
|
||||||
c.closeErr = err
|
|
||||||
}
|
|
||||||
close(c.closed)
|
|
||||||
}
|
|
||||||
c.mu.Unlock()
|
|
||||||
|
|
||||||
// Best-effort session close signal (avoid leaking server-side sessions).
|
|
||||||
req, reqErr := http.NewRequest(http.MethodPost, c.closeURL, nil)
|
|
||||||
if reqErr == nil {
|
|
||||||
req.Host = c.headerHost
|
|
||||||
req.Header.Set("X-Sudoku-Tunnel", string(TunnelModePoll))
|
|
||||||
req.Header.Set("X-Sudoku-Version", "1")
|
|
||||||
if resp, doErr := c.client.Do(req); doErr == nil && resp != nil {
|
|
||||||
_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4*1024))
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
bestEffortCloseSession(c.client, c.closeURL, c.headerHost, TunnelModePoll)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *pollConn) closedErr() error {
|
|
||||||
c.mu.Lock()
|
|
||||||
err := c.closeErr
|
|
||||||
c.mu.Unlock()
|
|
||||||
if err == nil {
|
|
||||||
return io.ErrClosedPipe
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *pollConn) Read(b []byte) (n int, err error) {
|
|
||||||
if len(c.readBuf) == 0 {
|
|
||||||
select {
|
|
||||||
case c.readBuf = <-c.rxc:
|
|
||||||
case <-c.closed:
|
|
||||||
return 0, c.closedErr()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
n = copy(b, c.readBuf)
|
|
||||||
c.readBuf = c.readBuf[n:]
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *pollConn) Write(b []byte) (n int, err error) {
|
|
||||||
if len(b) == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
c.mu.Lock()
|
|
||||||
select {
|
|
||||||
case <-c.closed:
|
|
||||||
c.mu.Unlock()
|
|
||||||
return 0, c.closedErr()
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
c.mu.Unlock()
|
|
||||||
|
|
||||||
payload := make([]byte, len(b))
|
|
||||||
copy(payload, b)
|
|
||||||
select {
|
|
||||||
case c.writeCh <- payload:
|
|
||||||
return len(b), nil
|
|
||||||
case <-c.closed:
|
|
||||||
return 0, c.closedErr()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *pollConn) Close() error {
|
func (c *pollConn) Close() error {
|
||||||
return c.closeWithError(io.ErrClosedPipe)
|
return c.closeWithError(io.ErrClosedPipe)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *pollConn) LocalAddr() net.Addr { return c.localAddr }
|
func newPollConnFromInfo(info *sessionDialInfo) *pollConn {
|
||||||
func (c *pollConn) RemoteAddr() net.Addr { return c.remoteAddr }
|
if info == nil {
|
||||||
|
return nil
|
||||||
func (c *pollConn) SetDeadline(time.Time) error { return nil }
|
|
||||||
func (c *pollConn) SetReadDeadline(time.Time) error { return nil }
|
|
||||||
func (c *pollConn) SetWriteDeadline(time.Time) error { return nil }
|
|
||||||
|
|
||||||
func dialPoll(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
|
|
||||||
client, target, err := newHTTPClient(serverAddress, opts, 32)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
authorizeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/session"}).String()
|
connCtx, cancel := context.WithCancel(context.Background())
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, authorizeURL, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
req.Host = target.headerHost
|
|
||||||
applyTunnelHeaders(req.Header, target.headerHost, TunnelModePoll)
|
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, 4*1024))
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return nil, fmt.Errorf("poll authorize bad status: %s (%s)", resp.Status, strings.TrimSpace(string(bodyBytes)))
|
|
||||||
}
|
|
||||||
|
|
||||||
token, err := parseTunnelToken(bodyBytes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("poll authorize failed: %q", strings.TrimSpace(string(bodyBytes)))
|
|
||||||
}
|
|
||||||
if token == "" {
|
|
||||||
return nil, fmt.Errorf("poll authorize empty token")
|
|
||||||
}
|
|
||||||
|
|
||||||
pushURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/api/v1/upload", RawQuery: "token=" + url.QueryEscape(token)}).String()
|
|
||||||
pullURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/stream", RawQuery: "token=" + url.QueryEscape(token)}).String()
|
|
||||||
closeURL := (&url.URL{Scheme: target.scheme, Host: target.urlHost, Path: "/api/v1/upload", RawQuery: "token=" + url.QueryEscape(token) + "&close=1"}).String()
|
|
||||||
|
|
||||||
c := &pollConn{
|
c := &pollConn{
|
||||||
client: client,
|
ctx: connCtx,
|
||||||
pushURL: pushURL,
|
cancel: cancel,
|
||||||
pullURL: pullURL,
|
client: info.client,
|
||||||
closeURL: closeURL,
|
pushURL: info.pushURL,
|
||||||
headerHost: target.headerHost,
|
pullURL: info.pullURL,
|
||||||
rxc: make(chan []byte, 128),
|
closeURL: info.closeURL,
|
||||||
closed: make(chan struct{}),
|
headerHost: info.headerHost,
|
||||||
writeCh: make(chan []byte, 256),
|
queuedConn: queuedConn{
|
||||||
localAddr: &net.TCPAddr{},
|
rxc: make(chan []byte, 128),
|
||||||
remoteAddr: &net.TCPAddr{},
|
closed: make(chan struct{}),
|
||||||
|
writeCh: make(chan []byte, 256),
|
||||||
|
localAddr: &net.TCPAddr{},
|
||||||
|
remoteAddr: &net.TCPAddr{},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
go c.pullLoop()
|
go c.pullLoop()
|
||||||
go c.pushLoop()
|
go c.pushLoop()
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialPollWithClient(ctx context.Context, client *http.Client, target httpClientTarget) (net.Conn, error) {
|
||||||
|
info, err := dialSessionWithClient(ctx, client, target, TunnelModePoll)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
c := newPollConnFromInfo(info)
|
||||||
|
if c == nil {
|
||||||
|
return nil, fmt.Errorf("failed to build poll conn")
|
||||||
|
}
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialPoll(ctx context.Context, serverAddress string, opts TunnelDialOptions) (net.Conn, error) {
|
||||||
|
info, err := dialSession(ctx, serverAddress, opts, TunnelModePoll)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
c := newPollConnFromInfo(info)
|
||||||
|
if c == nil {
|
||||||
|
return nil, fmt.Errorf("failed to build poll conn")
|
||||||
|
}
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user