Files
mihomo/transport/sudoku/multiplex/session.go
2026-01-30 10:33:22 +08:00

511 lines
9.3 KiB
Go

package multiplex
import (
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"sync"
"time"
)
// Frame type identifiers carried in the first byte of every mux frame header.
const (
// frameOpen asks the receiving session to create a new stream; its payload
// is delivered to the caller of AcceptStream together with the stream.
frameOpen byte = 0x01
// frameData carries payload bytes for an already registered stream.
frameData byte = 0x02
// frameClose ends a stream; the receiver closes it with io.EOF.
frameClose byte = 0x03
// frameReset aborts a stream; its payload is text used as the error message.
frameReset byte = 0x04
)
// Frame size limits.
const (
// headerSize is the fixed header length: 1 byte frame type, 4 bytes
// big-endian stream ID and 4 bytes big-endian payload length.
headerSize = 1 + 4 + 4
// maxFrameSize is the largest payload length accepted when sending or
// receiving a frame.
maxFrameSize = 256 * 1024
// maxDataPayload is the chunk size stream.Write uses when splitting a write
// into data frames.
maxDataPayload = 32 * 1024
)
// acceptEvent is the value passed from readLoop to AcceptStream for every
// successfully opened inbound stream.
type acceptEvent struct {
// stream is the newly registered inbound stream.
stream *stream
// payload is the payload of the open frame (nil when the frame had none).
payload []byte
}
// Session multiplexes logical streams over a single net.Conn using
// length-prefixed frames.
type Session struct {
// conn is the underlying connection all frames are read from and written to.
conn net.Conn
// writeMu serializes frame writes so header and payload stay contiguous.
writeMu sync.Mutex
// streamsMu guards streams, nextID and closeErr.
streamsMu sync.Mutex
// streams maps stream IDs to the currently registered streams.
streams map[uint32]*stream
// nextID is the last stream ID handed out by nextStreamID.
nextID uint32
// acceptCh delivers inbound streams to AcceptStream; it is nil on sessions
// created by NewClientSession.
acceptCh chan acceptEvent
// closed is closed exactly once when the session shuts down.
closed chan struct{}
// closeOnce makes the shutdown sequence run a single time.
closeOnce sync.Once
// closeErr is the error the session was closed with.
closeErr error
}
// NewClientSession wraps conn in a Session that can open streams but cannot
// accept them (its accept channel is left nil). It starts the background read
// loop before returning.
//
// It returns an error when conn is nil.
func NewClientSession(conn net.Conn) (*Session, error) {
	if conn == nil {
		return nil, fmt.Errorf("nil conn")
	}

	sess := new(Session)
	sess.conn = conn
	sess.streams = map[uint32]*stream{}
	sess.closed = make(chan struct{})

	go sess.readLoop()
	return sess, nil
}
// NewServerSession wraps conn in a Session that can accept inbound streams
// through AcceptStream. Up to 256 pending streams are buffered in the accept
// channel. It starts the background read loop before returning.
//
// It returns an error when conn is nil.
func NewServerSession(conn net.Conn) (*Session, error) {
	if conn == nil {
		return nil, fmt.Errorf("nil conn")
	}

	sess := new(Session)
	sess.conn = conn
	sess.streams = map[uint32]*stream{}
	sess.acceptCh = make(chan acceptEvent, 256)
	sess.closed = make(chan struct{})

	go sess.readLoop()
	return sess, nil
}
// IsClosed reports whether the session has been shut down. A nil session is
// reported as closed.
func (s *Session) IsClosed() bool {
	if s == nil {
		return true
	}
	// Non-blocking probe of the closed channel.
	select {
	case <-s.closed:
		return true
	default:
	}
	return false
}
// closedErr returns the error the session was closed with, falling back to
// io.ErrClosedPipe when none has been recorded.
func (s *Session) closedErr() error {
	s.streamsMu.Lock()
	defer s.streamsMu.Unlock()

	if s.closeErr != nil {
		return s.closeErr
	}
	return io.ErrClosedPipe
}
// closeWithError shuts the session down exactly once: it records err as the
// close reason, marks the session closed, closes every registered stream with
// err (without sending frames) and closes the underlying connection. A nil err
// is replaced by io.ErrClosedPipe. Subsequent calls are no-ops.
func (s *Session) closeWithError(err error) {
	if err == nil {
		err = io.ErrClosedPipe
	}
	s.closeOnce.Do(func() {
		s.streamsMu.Lock()
		if s.closeErr == nil {
			s.closeErr = err
		}
		// Publish the closed state while still holding streamsMu and before the
		// stream table is snapshotted. A stream registered after this critical
		// section is then guaranteed to observe IsClosed() in sendFrame and be
		// torn down by its opener, instead of slipping in between the snapshot
		// and the close signal and being left open forever.
		close(s.closed)
		pending := make([]*stream, 0, len(s.streams))
		for _, st := range s.streams {
			pending = append(pending, st)
		}
		s.streams = make(map[uint32]*stream)
		s.streamsMu.Unlock()

		// Close streams outside the lock; closeNoSend takes each stream's own mutex.
		for _, st := range pending {
			st.closeNoSend(err)
		}
		// The session is already being torn down, so a close error is not actionable.
		_ = s.conn.Close()
	})
}
// Close shuts the session down with io.ErrClosedPipe as the close reason.
// It is safe to call on a nil session and always returns nil.
func (s *Session) Close() error {
	if s != nil {
		s.closeWithError(io.ErrClosedPipe)
	}
	return nil
}
// registerStream stores st in the stream table under its ID, replacing any
// entry already present for that ID.
func (s *Session) registerStream(st *stream) {
	s.streamsMu.Lock()
	defer s.streamsMu.Unlock()
	s.streams[st.id] = st
}
// getStream returns the stream registered under id, or nil when there is none.
func (s *Session) getStream(id uint32) *stream {
	s.streamsMu.Lock()
	defer s.streamsMu.Unlock()
	return s.streams[id]
}
// removeStream drops the stream table entry for id, if any.
func (s *Session) removeStream(id uint32) {
	s.streamsMu.Lock()
	defer s.streamsMu.Unlock()
	delete(s.streams, id)
}
// nextStreamID returns the next free stream ID. IDs are handed out in
// increasing order and wrap around after the uint32 range is exhausted.
// ID 0 is never returned, and neither is an ID that is still present in the
// stream table, so a long-lived stream cannot be overwritten after wraparound.
func (s *Session) nextStreamID() uint32 {
	s.streamsMu.Lock()
	defer s.streamsMu.Unlock()

	for {
		s.nextID++
		id := s.nextID
		if id == 0 {
			// 0 is rejected by the peer as an invalid stream ID.
			continue
		}
		if _, inUse := s.streams[id]; inUse {
			continue
		}
		return id
	}
}
// sendFrame writes one frame (header followed by payload) to the connection.
//
// It fails fast with the session's close error when the session is already
// closed, and rejects payloads larger than maxFrameSize. Writes are serialized
// by writeMu. Any write error closes the whole session and is returned.
func (s *Session) sendFrame(frameType byte, streamID uint32, payload []byte) error {
	if s.IsClosed() {
		return s.closedErr()
	}
	size := len(payload)
	if size > maxFrameSize {
		return fmt.Errorf("mux payload too large: %d", size)
	}

	// Header layout: type (1) | stream ID (4, big-endian) | payload length (4, big-endian).
	var hdr [headerSize]byte
	hdr[0] = frameType
	binary.BigEndian.PutUint32(hdr[1:5], streamID)
	binary.BigEndian.PutUint32(hdr[5:9], uint32(size))

	s.writeMu.Lock()
	defer s.writeMu.Unlock()

	for _, part := range [][]byte{hdr[:], payload} {
		if len(part) == 0 {
			continue
		}
		if err := writeFull(s.conn, part); err != nil {
			s.closeWithError(err)
			return err
		}
	}
	return nil
}
// sendReset tells the peer to abort streamID: it sends a reset frame carrying
// msg (or "reset" when msg is empty) followed by a close frame.
func (s *Session) sendReset(streamID uint32, msg string) {
	reason := msg
	if reason == "" {
		reason = "reset"
	}
	// Best effort: sendFrame already closes the session on a write failure,
	// so there is nothing further to do with the returned errors here.
	_ = s.sendFrame(frameReset, streamID, []byte(reason))
	_ = s.sendFrame(frameClose, streamID, nil)
}
// OpenStream creates a new outbound stream and announces it to the peer with
// an open frame carrying openPayload.
//
// It returns an error when the session is nil or closed, or when the open
// frame cannot be sent; in the latter case the stream is closed and
// unregistered again.
func (s *Session) OpenStream(openPayload []byte) (net.Conn, error) {
	switch {
	case s == nil:
		return nil, fmt.Errorf("nil session")
	case s.IsClosed():
		return nil, s.closedErr()
	}

	id := s.nextStreamID()
	st := newStream(s, id)
	// Register before sending so frames for this ID can be routed to the stream.
	s.registerStream(st)

	err := s.sendFrame(frameOpen, id, openPayload)
	if err == nil {
		return st, nil
	}
	st.closeNoSend(err)
	s.removeStream(id)
	return nil, fmt.Errorf("mux open failed: %w", err)
}
// AcceptStream blocks until the peer opens a stream or the session closes.
// On success it returns the stream and the payload of its open frame.
//
// It returns an error when the session is nil, when the session has no accept
// channel (client sessions), or when the session is closed.
func (s *Session) AcceptStream() (net.Conn, []byte, error) {
	switch {
	case s == nil:
		return nil, nil, fmt.Errorf("nil session")
	case s.acceptCh == nil:
		return nil, nil, fmt.Errorf("accept is not supported on client sessions")
	}

	select {
	case <-s.closed:
		return nil, nil, s.closedErr()
	case incoming := <-s.acceptCh:
		return incoming.stream, incoming.payload, nil
	}
}
// readLoop reads frames from the connection until a read fails or a protocol
// violation is detected, dispatching each frame to the stream it addresses.
// Every exit path closes the session with the error that ended the loop.
func (s *Session) readLoop() {
var header [headerSize]byte
for {
// Read the fixed-size header: type, stream ID, payload length.
if _, err := io.ReadFull(s.conn, header[:]); err != nil {
s.closeWithError(err)
return
}
frameType := header[0]
streamID := binary.BigEndian.Uint32(header[1:5])
n := int(binary.BigEndian.Uint32(header[5:9]))
// Reject oversized lengths before allocating; the n < 0 test covers the
// uint32-to-int conversion yielding a negative value.
if n < 0 || n > maxFrameSize {
s.closeWithError(fmt.Errorf("invalid mux frame length: %d", n))
return
}
// payload is freshly allocated per frame, so it can be handed to a stream
// or to the accept goroutine without copying.
var payload []byte
if n > 0 {
payload = make([]byte, n)
if _, err := io.ReadFull(s.conn, payload); err != nil {
s.closeWithError(err)
return
}
}
switch frameType {
case frameOpen:
// Sessions without an accept channel (client sessions) refuse inbound opens.
if s.acceptCh == nil {
s.sendReset(streamID, "unexpected open")
continue
}
if streamID == 0 {
s.sendReset(streamID, "invalid stream id")
continue
}
if existing := s.getStream(streamID); existing != nil {
s.sendReset(streamID, "stream already exists")
continue
}
// Register first so data frames that follow the open are routed to the
// stream even before AcceptStream has picked it up.
st := newStream(s, streamID)
s.registerStream(st)
// Deliver from a separate goroutine so a full accept channel does not
// block frame processing; if the session closes first, drop the stream.
go func() {
select {
case s.acceptCh <- acceptEvent{stream: st, payload: payload}:
case <-s.closed:
st.closeNoSend(io.ErrClosedPipe)
s.removeStream(streamID)
}
}()
case frameData:
// Data for unknown streams and empty data frames are silently ignored.
st := s.getStream(streamID)
if st == nil {
continue
}
if len(payload) == 0 {
continue
}
st.enqueue(payload)
case frameClose:
// Peer finished the stream: readers see io.EOF once buffered data is drained.
st := s.getStream(streamID)
if st == nil {
continue
}
st.closeNoSend(io.EOF)
s.removeStream(streamID)
case frameReset:
// Peer aborted the stream: the trimmed payload text becomes the stream error.
st := s.getStream(streamID)
if st == nil {
continue
}
msg := trimASCII(payload)
if msg == "" {
msg = "reset"
}
st.closeNoSend(errors.New(msg))
s.removeStream(streamID)
default:
// An unknown frame type is a protocol violation that closes the session.
s.closeWithError(fmt.Errorf("unknown mux frame type: %d", frameType))
return
}
}
}
// writeFull writes all of b to w, retrying after partial writes.
//
// It returns the first error reported by w. If w reports no progress without
// an error it returns io.ErrShortWrite instead of spinning forever, and if w
// reports an impossible byte count it returns an error instead of panicking
// on the reslice.
func writeFull(w io.Writer, b []byte) error {
	for len(b) > 0 {
		n, err := w.Write(b)
		if err != nil {
			return err
		}
		switch {
		case n < 0 || n > len(b):
			// The writer violated the io.Writer contract.
			return errors.New("invalid write result")
		case n == 0:
			// No progress and no error: retrying would loop forever.
			return io.ErrShortWrite
		}
		b = b[n:]
	}
	return nil
}
// trimASCII returns b as a string with leading and trailing ASCII space, tab,
// carriage return and newline bytes removed. It returns "" when nothing else
// remains. The result does not alias b.
func trimASCII(b []byte) string {
	isBlank := func(c byte) bool {
		return c == ' ' || c == '\n' || c == '\r' || c == '\t'
	}

	start, end := 0, len(b)
	for start < end && isBlank(b[start]) {
		start++
	}
	for end > start && isBlank(b[end-1]) {
		end--
	}
	// The string conversion already copies, so no intermediate buffer is needed.
	return string(b[start:end])
}
// stream is one logical connection multiplexed over a Session. Inbound data
// is buffered in memory until Read consumes it.
type stream struct {
// session is the owning session used to send frames.
session *Session
// id is the stream ID carried in every frame for this stream.
id uint32
// mu guards closed, closeErr, readBuf and queue.
mu sync.Mutex
// cond is bound to mu and wakes readers when data arrives or the stream closes.
cond *sync.Cond
// closed is set once the stream has been closed locally or by the peer.
closed bool
// closeErr is the error returned to readers after buffered data is drained.
closeErr error
// readBuf is the payload currently being consumed by Read.
readBuf []byte
// queue holds further received payloads in arrival order.
queue [][]byte
// localAddr and remoteAddr are the values returned by LocalAddr and RemoteAddr.
localAddr net.Addr
remoteAddr net.Addr
}
// newStream builds a stream bound to session with the given ID. Both addresses
// are set to empty *net.TCPAddr placeholders, and the condition variable is
// tied to the stream's own mutex. The stream is not registered with the session.
func newStream(session *Session, id uint32) *stream {
	st := new(stream)
	st.session = session
	st.id = id
	st.localAddr = &net.TCPAddr{}
	st.remoteAddr = &net.TCPAddr{}
	st.cond = sync.NewCond(&st.mu)
	return st
}
// enqueue hands a received payload to the stream and wakes one waiting reader.
// The payload is dropped when the stream is already closed. The slice is stored
// as is, without copying.
func (c *stream) enqueue(payload []byte) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.closed {
		return
	}
	// With nothing buffered the payload becomes the active read buffer
	// directly; otherwise it waits its turn in the queue.
	idle := len(c.readBuf) == 0 && len(c.queue) == 0
	if idle {
		c.readBuf = payload
	} else {
		c.queue = append(c.queue, payload)
	}
	c.cond.Signal()
}
// closeNoSend marks the stream closed with err (io.EOF when err is nil) and
// wakes all waiting readers, without sending any frame to the peer. It does
// nothing when the stream is already closed.
func (c *stream) closeNoSend(err error) {
	cause := err
	if cause == nil {
		cause = io.EOF
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	if c.closed {
		return
	}
	c.closed = true
	if c.closeErr == nil {
		c.closeErr = cause
	}
	c.cond.Broadcast()
}
// closedErr returns the error the stream was closed with, or io.ErrClosedPipe
// when none has been recorded.
func (c *stream) closedErr() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.closeErr != nil {
		return c.closeErr
	}
	return io.ErrClosedPipe
}
// Read copies buffered inbound data into p, blocking until data is available
// or the stream is closed.
//
// A zero-length p returns (0, nil) immediately. Data buffered before the
// stream was closed is still delivered; once it is drained, Read returns the
// stream's close error (io.ErrClosedPipe when none was recorded). At most one
// received payload is consumed per call.
func (c *stream) Read(p []byte) (int, error) {
	if len(p) == 0 {
		return 0, nil
	}
	c.mu.Lock()
	defer c.mu.Unlock()

	for len(c.readBuf) == 0 && len(c.queue) == 0 && !c.closed {
		c.cond.Wait()
	}

	if len(c.readBuf) == 0 && len(c.queue) > 0 {
		c.readBuf = c.queue[0]
		// Clear the popped slot so the backing array does not keep the
		// payload reachable after it has been consumed.
		c.queue[0] = nil
		c.queue = c.queue[1:]
		if len(c.queue) == 0 {
			// Release the backing array entirely once the queue is drained.
			c.queue = nil
		}
	}

	if len(c.readBuf) == 0 && c.closed {
		if c.closeErr == nil {
			return 0, io.ErrClosedPipe
		}
		return 0, c.closeErr
	}

	n := copy(p, c.readBuf)
	c.readBuf = c.readBuf[n:]
	if len(c.readBuf) == 0 {
		// Drop the reference to the fully consumed payload buffer.
		c.readBuf = nil
	}
	return n, nil
}
// Write sends p to the peer as one or more data frames of at most
// maxDataPayload bytes each.
//
// A zero-length p returns (0, nil). Write fails without sending anything when
// the stream has no session, the session is closed, or the stream itself is
// closed. If a frame fails to send, it returns the number of bytes already
// sent together with the error.
func (c *stream) Write(p []byte) (int, error) {
	if len(p) == 0 {
		return 0, nil
	}
	if c.session == nil {
		return 0, io.ErrClosedPipe
	}
	if c.session.IsClosed() {
		return 0, c.session.closedErr()
	}

	c.mu.Lock()
	streamClosed := c.closed
	c.mu.Unlock()
	if streamClosed {
		return 0, c.closedErr()
	}

	sent := 0
	for rest := p; len(rest) > 0; {
		n := len(rest)
		if n > maxDataPayload {
			n = maxDataPayload
		}
		if err := c.session.sendFrame(frameData, c.id, rest[:n]); err != nil {
			return sent, err
		}
		sent += n
		rest = rest[n:]
	}
	return sent, nil
}
// Close closes the stream locally, wakes blocked readers, notifies the peer
// with a close frame and unregisters the stream from its session. Closing an
// already closed stream does nothing. It always returns nil.
func (c *stream) Close() error {
	// markClosed flips the stream to closed and reports whether this call
	// was the one that did it.
	markClosed := func() bool {
		c.mu.Lock()
		defer c.mu.Unlock()
		if c.closed {
			return false
		}
		c.closed = true
		if c.closeErr == nil {
			c.closeErr = io.ErrClosedPipe
		}
		c.cond.Broadcast()
		return true
	}

	if !markClosed() {
		return nil
	}
	// Best effort: sendFrame closes the session itself on a write failure,
	// so the error needs no further handling here.
	_ = c.session.sendFrame(frameClose, c.id, nil)
	c.session.removeStream(c.id)
	return nil
}
// CloseWrite delegates to Close, so it closes the whole stream rather than
// only the write side.
func (c *stream) CloseWrite() error {
	return c.Close()
}
// CloseRead delegates to Close, so it closes the whole stream rather than
// only the read side.
func (c *stream) CloseRead() error {
	return c.Close()
}
// LocalAddr returns the local address stored on the stream.
func (c *stream) LocalAddr() net.Addr {
	return c.localAddr
}
// RemoteAddr returns the remote address stored on the stream.
func (c *stream) RemoteAddr() net.Addr {
	return c.remoteAddr
}
// SetDeadline forwards t to SetReadDeadline and SetWriteDeadline and always
// returns nil.
func (c *stream) SetDeadline(t time.Time) error {
	// Both setters' results are discarded, matching the unconditional nil return.
	_, _ = c.SetReadDeadline(t), c.SetWriteDeadline(t)
	return nil
}
// SetReadDeadline is a no-op: the deadline is ignored and nil is returned, so
// Read is never interrupted by a timeout.
func (c *stream) SetReadDeadline(_ time.Time) error {
	return nil
}
// SetWriteDeadline is a no-op: the deadline is ignored and nil is returned, so
// Write is never interrupted by a timeout.
func (c *stream) SetWriteDeadline(_ time.Time) error {
	return nil
}