chore: ensures packets can be sent without blocking the tunnel

This commit is contained in:
wwqgtxx
2024-09-26 11:21:07 +08:00
parent 5812a7bdeb
commit 4fa15c6334
4 changed files with 218 additions and 147 deletions

View File

@@ -1,6 +1,7 @@
package tunnel
import (
"context"
"errors"
"net"
"net/netip"
@@ -11,7 +12,78 @@ import (
"github.com/metacubex/mihomo/log"
)
type packetSender struct {
ctx context.Context
cancel context.CancelFunc
ch chan C.PacketAdapter
}
// newPacketSender return a chan based C.PacketSender
// It ensures that packets can be sent sequentially and without blocking
func newPacketSender() C.PacketSender {
ctx, cancel := context.WithCancel(context.Background())
ch := make(chan C.PacketAdapter, senderCapacity)
return &packetSender{
ctx: ctx,
cancel: cancel,
ch: ch,
}
}
func (s *packetSender) Process(pc C.PacketConn, proxy C.WriteBackProxy) {
for {
select {
case <-s.ctx.Done():
return // sender closed
case packet := <-s.ch:
if proxy != nil {
proxy.UpdateWriteBack(packet)
}
_ = handleUDPToRemote(packet, pc, packet.Metadata())
packet.Drop()
}
}
}
func (s *packetSender) dropAll() {
for {
select {
case data := <-s.ch:
data.Drop() // drop all data still in chan
default:
return // no data, exit goroutine
}
}
}
func (s *packetSender) Send(packet C.PacketAdapter) {
select {
case <-s.ctx.Done():
packet.Drop() // sender closed before Send()
return
default:
}
select {
case s.ch <- packet:
// put ok, so don't drop packet, will process by other side of chan
case <-s.ctx.Done():
packet.Drop() // sender closed when putting data to chan
default:
packet.Drop() // chan is full
}
}
func (s *packetSender) Close() {
s.cancel()
s.dropAll()
}
func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata) error {
if err := resolveUDP(metadata); err != nil {
return err
}
addr := metadata.UDPAddr()
if addr == nil {
return errors.New("udp addr invalid")
@@ -26,8 +98,9 @@ func handleUDPToRemote(packet C.UDPPacket, pc C.PacketConn, metadata *C.Metadata
return nil
}
func handleUDPToLocal(writeBack C.WriteBack, pc N.EnhancePacketConn, key string, oAddrPort netip.AddrPort, fAddr netip.Addr) {
func handleUDPToLocal(writeBack C.WriteBack, pc N.EnhancePacketConn, sender C.PacketSender, key string, oAddrPort netip.AddrPort, fAddr netip.Addr) {
defer func() {
sender.Close()
_ = pc.Close()
closeAllLocalCoon(key)
natTable.Delete(key)

View File

@@ -28,11 +28,14 @@ import (
"github.com/metacubex/mihomo/tunnel/statistic"
)
const queueSize = 200
const (
queueCapacity = 64 // chan capacity tcpQueue and udpQueue
senderCapacity = 128 // chan capacity of PacketSender
)
var (
status = newAtomicStatus(Suspend)
tcpQueue = make(chan C.ConnContext, queueSize)
udpInit sync.Once
udpQueues []chan C.PacketAdapter
natTable = nat.New()
rules []C.Rule
@@ -43,6 +46,12 @@ var (
ruleProviders map[string]provider.RuleProvider
configMux sync.RWMutex
// for compatibility, lazy init
tcpQueue chan C.ConnContext
tcpInOnce sync.Once
udpQueue chan C.PacketAdapter
udpInOnce sync.Once
// Outbound Rule
mode = Rule
@@ -70,15 +79,33 @@ func (t tunnel) HandleTCPConn(conn net.Conn, metadata *C.Metadata) {
handleTCPConn(connCtx)
}
func (t tunnel) HandleUDPPacket(packet C.UDPPacket, metadata *C.Metadata) {
packetAdapter := C.NewPacketAdapter(packet, metadata)
func initUDP() {
numUDPWorkers := 4
if num := runtime.GOMAXPROCS(0); num > numUDPWorkers {
numUDPWorkers = num
}
hash := utils.MapHash(metadata.SourceAddress() + "-" + metadata.RemoteAddress())
udpQueues = make([]chan C.PacketAdapter, numUDPWorkers)
for i := 0; i < numUDPWorkers; i++ {
queue := make(chan C.PacketAdapter, queueCapacity)
udpQueues[i] = queue
go processUDP(queue)
}
}
func (t tunnel) HandleUDPPacket(packet C.UDPPacket, metadata *C.Metadata) {
udpInit.Do(initUDP)
packetAdapter := C.NewPacketAdapter(packet, metadata)
key := packetAdapter.Key()
hash := utils.MapHash(key)
queueNo := uint(hash) % uint(len(udpQueues))
select {
case udpQueues[queueNo] <- packetAdapter:
default:
packet.Drop()
}
}
@@ -134,21 +161,32 @@ func IsSniffing() bool {
return sniffingEnable
}
func init() {
go process()
}
// TCPIn return fan-in queue
// Deprecated: using Tunnel instead
func TCPIn() chan<- C.ConnContext {
tcpInOnce.Do(func() {
tcpQueue = make(chan C.ConnContext, queueCapacity)
go func() {
for connCtx := range tcpQueue {
go handleTCPConn(connCtx)
}
}()
})
return tcpQueue
}
// UDPIn return fan-in udp queue
// Deprecated: using Tunnel instead
func UDPIn() chan<- C.PacketAdapter {
// compatibility: first queue is always available for external callers
return udpQueues[0]
udpInOnce.Do(func() {
udpQueue = make(chan C.PacketAdapter, queueCapacity)
go func() {
for packet := range udpQueue {
Tunnel.HandleUDPPacket(packet, packet.Metadata())
}
}()
})
return udpQueue
}
// NatTable return nat table
@@ -249,32 +287,6 @@ func isHandle(t C.Type) bool {
return status == Running || (status == Inner && t == C.INNER)
}
// processUDP starts a loop to handle udp packet
func processUDP(queue chan C.PacketAdapter) {
for conn := range queue {
handleUDPConn(conn)
}
}
func process() {
numUDPWorkers := 4
if num := runtime.GOMAXPROCS(0); num > numUDPWorkers {
numUDPWorkers = num
}
udpQueues = make([]chan C.PacketAdapter, numUDPWorkers)
for i := 0; i < numUDPWorkers; i++ {
queue := make(chan C.PacketAdapter, queueSize)
udpQueues[i] = queue
go processUDP(queue)
}
queue := tcpQueue
for conn := range queue {
go handleTCPConn(conn)
}
}
func needLookupIP(metadata *C.Metadata) bool {
return resolver.MappingEnabled() && metadata.Host == "" && metadata.DstIP.IsValid()
}
@@ -334,6 +346,25 @@ func resolveMetadata(metadata *C.Metadata) (proxy C.Proxy, rule C.Rule, err erro
return
}
func resolveUDP(metadata *C.Metadata) error {
// local resolve UDP dns
if !metadata.Resolved() {
ip, err := resolver.ResolveIP(context.Background(), metadata.Host)
if err != nil {
return err
}
metadata.DstIP = ip
}
return nil
}
// processUDP starts a loop to handle udp packet
func processUDP(queue chan C.PacketAdapter) {
for conn := range queue {
handleUDPConn(conn)
}
}
func handleUDPConn(packet C.PacketAdapter) {
if !isHandle(packet.Metadata().Type) {
packet.Drop()
@@ -363,85 +394,58 @@ func handleUDPConn(packet C.PacketAdapter) {
snifferDispatcher.UDPSniff(packet)
}
// local resolve UDP dns
if !metadata.Resolved() {
ip, err := resolver.ResolveIP(context.Background(), metadata.Host)
if err != nil {
return
}
metadata.DstIP = ip
}
key := packet.LocalAddr().String()
handle := func() bool {
pc, proxy := natTable.Get(key)
if pc != nil {
if proxy != nil {
proxy.UpdateWriteBack(packet)
key := packet.Key()
sender, loaded := natTable.GetOrCreate(key, newPacketSender)
if !loaded {
dial := func() (C.PacketConn, C.WriteBackProxy, error) {
if err := resolveUDP(metadata); err != nil {
log.Warnln("[UDP] Resolve Ip error: %s", err)
return nil, nil, err
}
_ = handleUDPToRemote(packet, pc, metadata)
return true
}
return false
}
if handle() {
packet.Drop()
return
}
proxy, rule, err := resolveMetadata(metadata)
if err != nil {
log.Warnln("[UDP] Parse metadata failed: %s", err.Error())
return nil, nil, err
}
cond, loaded := natTable.GetOrCreateLock(key)
ctx, cancel := context.WithTimeout(context.Background(), C.DefaultUDPTimeout)
defer cancel()
rawPc, err := retry(ctx, func(ctx context.Context) (C.PacketConn, error) {
return proxy.ListenPacketContext(ctx, metadata.Pure())
}, func(err error) {
logMetadataErr(metadata, rule, proxy, err)
})
if err != nil {
return nil, nil, err
}
logMetadata(metadata, rule, rawPc)
go func() {
defer packet.Drop()
pc := statistic.NewUDPTracker(rawPc, statistic.DefaultManager, metadata, rule, 0, 0, true)
if loaded {
cond.L.Lock()
cond.Wait()
handle()
cond.L.Unlock()
return
if rawPc.Chains().Last() == "REJECT-DROP" {
_ = pc.Close()
return nil, nil, errors.New("rejected drop packet")
}
oAddrPort := metadata.AddrPort()
writeBackProxy := nat.NewWriteBackProxy(packet)
go handleUDPToLocal(writeBackProxy, pc, sender, key, oAddrPort, fAddr)
return pc, writeBackProxy, nil
}
defer func() {
natTable.DeleteLock(key)
cond.Broadcast()
go func() {
pc, proxy, err := dial()
if err != nil {
sender.Close()
natTable.Delete(key)
return
}
sender.Process(pc, proxy)
}()
proxy, rule, err := resolveMetadata(metadata)
if err != nil {
log.Warnln("[UDP] Parse metadata failed: %s", err.Error())
return
}
ctx, cancel := context.WithTimeout(context.Background(), C.DefaultUDPTimeout)
defer cancel()
rawPc, err := retry(ctx, func(ctx context.Context) (C.PacketConn, error) {
return proxy.ListenPacketContext(ctx, metadata.Pure())
}, func(err error) {
logMetadataErr(metadata, rule, proxy, err)
})
if err != nil {
return
}
logMetadata(metadata, rule, rawPc)
pc := statistic.NewUDPTracker(rawPc, statistic.DefaultManager, metadata, rule, 0, 0, true)
if rawPc.Chains().Last() == "REJECT-DROP" {
pc.Close()
return
}
oAddrPort := metadata.AddrPort()
writeBackProxy := nat.NewWriteBackProxy(packet)
natTable.Set(key, pc, writeBackProxy)
go handleUDPToLocal(writeBackProxy, pc, key, oAddrPort, fAddr)
handle()
}()
}
sender.Send(packet) // nonblocking
}
func handleTCPConn(connCtx C.ConnContext) {