diff --git a/adapter/outbound/masque.go b/adapter/outbound/masque.go new file mode 100644 index 00000000..c6ef37dc --- /dev/null +++ b/adapter/outbound/masque.go @@ -0,0 +1,397 @@ +package outbound + +import ( + "context" + "crypto/ecdsa" + "crypto/x509" + "encoding/base64" + "errors" + "fmt" + "net" + "net/netip" + "strconv" + "strings" + "sync" + "time" + + "github.com/metacubex/mihomo/common/atomic" + "github.com/metacubex/mihomo/common/contextutils" + "github.com/metacubex/mihomo/common/pool" + "github.com/metacubex/mihomo/component/dialer" + "github.com/metacubex/mihomo/component/resolver" + C "github.com/metacubex/mihomo/constant" + "github.com/metacubex/mihomo/dns" + "github.com/metacubex/mihomo/log" + "github.com/metacubex/mihomo/transport/masque" + "github.com/metacubex/mihomo/transport/tuic/common" + + connectip "github.com/metacubex/connect-ip-go" + "github.com/metacubex/quic-go" + wireguard "github.com/metacubex/sing-wireguard" + M "github.com/metacubex/sing/common/metadata" + "github.com/metacubex/tls" +) + +type Masque struct { + *Base + tlsConfig *tls.Config + quicConfig *quic.Config + tunDevice wireguard.Device + resolver resolver.Resolver + uri string + + runCtx context.Context + runCancel context.CancelFunc + runMutex sync.Mutex + running atomic.Bool + runDevice atomic.Bool + + option MasqueOption +} + +type MasqueOption struct { + BasicOption + Name string `proxy:"name"` + Server string `proxy:"server"` + Port int `proxy:"port"` + PrivateKey string `proxy:"private-key"` + PublicKey string `proxy:"public-key"` + Ip string `proxy:"ip,omitempty"` + Ipv6 string `proxy:"ipv6,omitempty"` + URI string `proxy:"uri,omitempty"` + SNI string `proxy:"sni,omitempty"` + MTU int `proxy:"mtu,omitempty"` + UDP bool `proxy:"udp,omitempty"` + + CongestionController string `proxy:"congestion-controller,omitempty"` + CWND int `proxy:"cwnd,omitempty"` + + RemoteDnsResolve bool `proxy:"remote-dns-resolve,omitempty"` + Dns []string `proxy:"dns,omitempty"` +} + +func (option MasqueOption) Prefixes() ([]netip.Prefix, error) { + localPrefixes := make([]netip.Prefix, 0, 2) + if len(option.Ip) > 0 { + if !strings.Contains(option.Ip, "/") { + option.Ip = option.Ip + "/32" + } + if prefix, err := netip.ParsePrefix(option.Ip); err == nil { + localPrefixes = append(localPrefixes, prefix) + } else { + return nil, fmt.Errorf("ip address parse error: %w", err) + } + } + if len(option.Ipv6) > 0 { + if !strings.Contains(option.Ipv6, "/") { + option.Ipv6 = option.Ipv6 + "/128" + } + if prefix, err := netip.ParsePrefix(option.Ipv6); err == nil { + localPrefixes = append(localPrefixes, prefix) + } else { + return nil, fmt.Errorf("ipv6 address parse error: %w", err) + } + } + if len(localPrefixes) == 0 { + return nil, errors.New("missing local address") + } + return localPrefixes, nil +} + +func NewMasque(option MasqueOption) (*Masque, error) { + outbound := &Masque{ + Base: &Base{ + name: option.Name, + addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)), + tp: C.Masque, + pdName: option.ProviderName, + udp: option.UDP, + iface: option.Interface, + rmark: option.RoutingMark, + prefer: option.IPVersion, + }, + } + outbound.dialer = option.NewDialer(outbound.DialOptions()) + + ctx, cancel := context.WithCancel(context.Background()) + outbound.runCtx = ctx + outbound.runCancel = cancel + + privKeyB64, err := base64.StdEncoding.DecodeString(option.PrivateKey) + if err != nil { + return nil, fmt.Errorf("failed to decode private key: %v", err) + } + privKey, err := x509.ParseECPrivateKey(privKeyB64) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %v", err) + } + + endpointPubKeyB64, err := base64.StdEncoding.DecodeString(option.PublicKey) + if err != nil { + return nil, fmt.Errorf("failed to decode public key: %v", err) + } + pubKey, err := x509.ParsePKIXPublicKey(endpointPubKeyB64) + if err != nil { + return nil, fmt.Errorf("failed to parse public key: %v", err) + } + ecPubKey, ok := pubKey.(*ecdsa.PublicKey) + if !ok { + return nil, fmt.Errorf("failed to assert public key as ECDSA") + } + + uri := option.URI + if uri == "" { + uri = masque.ConnectURI + } + outbound.uri = uri + + sni := option.SNI + if sni == "" { + sni = masque.ConnectSNI + } + + tlsConfig, err := masque.PrepareTlsConfig(privKey, ecPubKey, sni) + if err != nil { + return nil, fmt.Errorf("failed to prepare TLS config: %v\n", err) + } + outbound.tlsConfig = tlsConfig + + outbound.quicConfig = &quic.Config{ + EnableDatagrams: true, + InitialPacketSize: 1242, + KeepAlivePeriod: 30 * time.Second, + } + + prefixes, err := option.Prefixes() + if err != nil { + return nil, err + } + + outbound.option = option + + mtu := option.MTU + if mtu == 0 { + mtu = 1280 + } + if len(prefixes) == 0 { + return nil, errors.New("missing local address") + } + outbound.tunDevice, err = wireguard.NewStackDevice(prefixes, uint32(mtu)) + if err != nil { + return nil, fmt.Errorf("create device: %w", err) + } + + var has6 bool + for _, address := range prefixes { + if !address.Addr().Unmap().Is4() { + has6 = true + break + } + } + + if option.RemoteDnsResolve && len(option.Dns) > 0 { + nss, err := dns.ParseNameServer(option.Dns) + if err != nil { + return nil, err + } + for i := range nss { + nss[i].ProxyAdapter = outbound + } + outbound.resolver = dns.NewResolver(dns.Config{ + Main: nss, + IPv6: has6, + }) + } + + return outbound, nil +} + +func (w *Masque) run(ctx context.Context) error { + if w.running.Load() { + return nil + } + w.runMutex.Lock() + defer w.runMutex.Unlock() + // double-check like sync.Once + if w.running.Load() { + return nil + } + + if w.runCtx.Err() != nil { + return w.runCtx.Err() + } + + if !w.runDevice.Load() { + err := w.tunDevice.Start() + if err != nil { + return err + } + w.runDevice.Store(true) + } + + udpAddr, err := resolveUDPAddr(ctx, "udp", w.addr, w.prefer) + if err != nil { + return err + } + + pc, err := w.dialer.ListenPacket(ctx, "udp", "", udpAddr.AddrPort()) + if err != nil { + return err + } + + quicConn, err := quic.Dial(ctx, pc, udpAddr, w.tlsConfig, w.quicConfig) + if err != nil { + return err + } + + common.SetCongestionController(quicConn, w.option.CongestionController, w.option.CWND) + + tr, ipConn, err := masque.ConnectTunnel(ctx, quicConn, w.uri) + if err != nil { + _ = pc.Close() + return err + } + + w.running.Store(true) + + runCtx, runCancel := context.WithCancel(w.runCtx) + contextutils.AfterFunc(runCtx, func() { + w.running.Store(false) + _ = ipConn.Close() + _ = tr.Close() + _ = pc.Close() + }) + + go func() { + defer runCancel() + buf := pool.Get(pool.UDPBufferSize) + defer pool.Put(buf) + bufs := [][]byte{buf} + sizes := []int{0} + for runCtx.Err() == nil { + _, err := w.tunDevice.Read(bufs, sizes, 0) + if err != nil { + log.Errorln("Error reading from TUN device: %v", err) + return + } + icmp, err := ipConn.WritePacket(buf[:sizes[0]]) + if err != nil { + if errors.As(err, new(*connectip.CloseError)) { + log.Errorln("connection closed while writing to IP connection: %v", err) + return + } + log.Warnln("Error writing to IP connection: %v, continuing...", err) + continue + } + + if len(icmp) > 0 { + if _, err := w.tunDevice.Write([][]byte{icmp}, 0); err != nil { + log.Warnln("Error writing ICMP to TUN device: %v, continuing...", err) + } + } + } + }() + + go func() { + defer runCancel() + buf := pool.Get(pool.UDPBufferSize) + defer pool.Put(buf) + for runCtx.Err() == nil { + n, err := ipConn.ReadPacket(buf) + if err != nil { + if errors.As(err, new(*connectip.CloseError)) { + log.Errorln("connection closed while writing to IP connection: %v", err) + return + } + log.Warnln("Error reading from IP connection: %v, continuing...", err) + continue + } + if _, err := w.tunDevice.Write([][]byte{buf[:n]}, 0); err != nil { + log.Errorln("Error writing to TUN device: %v", err) + return + } + } + }() + + return nil +} + +// Close implements C.ProxyAdapter +func (w *Masque) Close() error { + w.runCancel() + if w.tunDevice != nil { + w.tunDevice.Close() + } + return nil +} + +func (w *Masque) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) { + var conn net.Conn + if err = w.run(ctx); err != nil { + return nil, err + } + if !metadata.Resolved() || w.resolver != nil { + r := resolver.DefaultResolver + if w.resolver != nil { + r = w.resolver + } + options := w.DialOptions() + options = append(options, dialer.WithResolver(r)) + options = append(options, dialer.WithNetDialer(wgNetDialer{tunDevice: w.tunDevice})) + conn, err = dialer.NewDialer(options...).DialContext(ctx, "tcp", metadata.RemoteAddress()) + } else { + conn, err = w.tunDevice.DialContext(ctx, "tcp", M.SocksaddrFrom(metadata.DstIP, metadata.DstPort).Unwrap()) + } + if err != nil { + return nil, err + } + if conn == nil { + return nil, errors.New("conn is nil") + } + return NewConn(conn, w), nil +} + +func (w *Masque) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) { + var pc net.PacketConn + if err = w.run(ctx); err != nil { + return nil, err + } + if err = w.ResolveUDP(ctx, metadata); err != nil { + return nil, err + } + pc, err = w.tunDevice.ListenPacket(ctx, M.SocksaddrFrom(metadata.DstIP, metadata.DstPort).Unwrap()) + if err != nil { + return nil, err + } + if pc == nil { + return nil, errors.New("packetConn is nil") + } + return newPacketConn(pc, w), nil +} + +func (w *Masque) ResolveUDP(ctx context.Context, metadata *C.Metadata) error { + if (!metadata.Resolved() || w.resolver != nil) && metadata.Host != "" { + r := resolver.DefaultResolver + if w.resolver != nil { + r = w.resolver + } + ip, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, r) + if err != nil { + return fmt.Errorf("can't resolve ip: %w", err) + } + metadata.DstIP = ip + } + return nil +} + +// ProxyInfo implements C.ProxyAdapter +func (w *Masque) ProxyInfo() C.ProxyInfo { + info := w.Base.ProxyInfo() + info.DialerProxy = w.option.DialerProxy + return info +} + +// IsL3Protocol implements C.ProxyAdapter +func (w *Masque) IsL3Protocol(metadata *C.Metadata) bool { + return true +} diff --git a/adapter/parser.go b/adapter/parser.go index 08f90afe..290a4a8f 100644 --- a/adapter/parser.go +++ b/adapter/parser.go @@ -159,6 +159,13 @@ func ParseProxy(mapping map[string]any, options ...ProxyOption) (C.Proxy, error) break } proxy, err = outbound.NewSudoku(*sudokuOption) + case "masque": + masqueOption := &outbound.MasqueOption{BasicOption: basicOption} + err = decoder.Decode(mapping, masqueOption) + if err != nil { + break + } + proxy, err = outbound.NewMasque(*masqueOption) default: return nil, fmt.Errorf("unsupport proxy type: %s", proxyType) } diff --git a/constant/adapters.go b/constant/adapters.go index 07ae5de1..e451dc92 100644 --- a/constant/adapters.go +++ b/constant/adapters.go @@ -45,6 +45,7 @@ const ( Mieru AnyTLS Sudoku + Masque ) const ( @@ -212,6 +213,8 @@ func (at AdapterType) String() string { return "AnyTLS" case Sudoku: return "Sudoku" + case Masque: + return "Masque" case Relay: return "Relay" case Selector: diff --git a/docs/config.yaml b/docs/config.yaml index 0170f08f..8746ae20 100644 --- a/docs/config.yaml +++ b/docs/config.yaml @@ -980,6 +980,23 @@ proxies: # socks5 # j3: # AmneziaWG v1.5 only (removed in v2) # itime: 60 # AmneziaWG v1.5 only (removed in v2) + # masque + - name: "masque" + type: masque + server: 162.159.198.1 + port: 443 + private-key: MHcCAQEEILI1eOtnbEIh89Fj4yNDuFR6UjayCKI3NdLl3DhetimWoAoGCCqGSM49AwEHoUQDQgAEgyXrE8v+hHsHy3ewSb3WcRjYgCrM9T9hiE0Uv6k2DZ1+4kefrDT9v1Q/8wdRigTf6t6gGNUV8W+IUMdrfUt+9g== + public-key: MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEIaU7MToJm9NKp8YfGxR6r+/h4mcG\n7SxI8tsW8OR1A5tv/zCzVbCRRh2t87/kxnP6lAy0lkr7qYwu+ox+k3dr6w== + ip: 172.16.0.2 + ipv6: 2606:4700:110:84c0:163a:4914:a0ad:3342 + mtu: 1280 + udp: true + # 一个出站代理的标识。当值不为空时,将使用指定的 proxy 发出连接 + # dialer-proxy: "ss1" + # remote-dns-resolve: true # 强制 dns 远程解析,默认值为 false + # dns: [ 1.1.1.1, 8.8.8.8 ] # 仅在 remote-dns-resolve 为 true 时生效 + # congestion-controller: bbr # 默认不开启 + # tuic - name: tuic server: www.example.com diff --git a/go.mod b/go.mod index 33bb2cd0..17de7f5b 100644 --- a/go.mod +++ b/go.mod @@ -17,13 +17,14 @@ require ( github.com/metacubex/blake3 v0.1.0 github.com/metacubex/chacha v0.1.5 github.com/metacubex/chi v0.1.0 + github.com/metacubex/connect-ip-go v0.0.0-20260128031117-1cad62060727 github.com/metacubex/cpu v0.1.0 github.com/metacubex/fswatch v0.1.1 github.com/metacubex/gopacket v1.1.20-0.20230608035415-7e2f98a3e759 github.com/metacubex/http v0.1.0 github.com/metacubex/kcp-go v0.0.0-20260105040817-550693377604 github.com/metacubex/mlkem v0.1.0 - github.com/metacubex/quic-go v0.59.1-0.20260112033758-aa29579f2001 + github.com/metacubex/quic-go v0.59.1-0.20260128071132-0f3233b973af github.com/metacubex/randv2 v0.2.0 github.com/metacubex/restls-client-go v0.1.7 github.com/metacubex/sing v0.5.7 @@ -47,6 +48,7 @@ require ( github.com/sirupsen/logrus v1.9.4 github.com/stretchr/testify v1.11.1 github.com/vmihailenco/msgpack/v5 v5.4.1 + github.com/yosida95/uritemplate/v3 v3.0.2 gitlab.com/go-extension/aes-ccm v0.0.0-20230221065045-e58665ef23c7 go.uber.org/automaxprocs v1.6.0 go4.org/netipx v0.0.0-20231129151722-fdeea329fbba @@ -74,6 +76,7 @@ require ( github.com/ajg/form v1.5.1 // indirect github.com/andybalholm/brotli v1.0.6 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dunglas/httpsfv v1.0.2 // indirect github.com/ericlagergren/aegis v0.0.0-20250325060835-cd0defd64358 // indirect github.com/ericlagergren/polyval v0.0.0-20220411101811-e25bc10ba391 // indirect github.com/ericlagergren/siv v0.0.0-20220507050439-0b757b3aa5f1 // indirect diff --git a/go.sum b/go.sum index 805085a3..d317f362 100644 --- a/go.sum +++ b/go.sum @@ -20,6 +20,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ= github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dunglas/httpsfv v1.0.2 h1:iERDp/YAfnojSDJ7PW3dj1AReJz4MrwbECSSE59JWL0= +github.com/dunglas/httpsfv v1.0.2/go.mod h1:zID2mqw9mFsnt7YC3vYQ9/cjq30q41W+1AnDwH8TiMg= github.com/enfein/mieru/v3 v3.26.2 h1:U/2XJc+3vrJD9r815FoFdwToQFEcqSOzzzWIPPhjfEU= github.com/enfein/mieru/v3 v3.26.2/go.mod h1:zJBUCsi5rxyvHM8fjFf+GLaEl4OEjjBXr1s5F6Qd3hM= github.com/ericlagergren/aegis v0.0.0-20250325060835-cd0defd64358 h1:kXYqH/sL8dS/FdoFjr12ePjnLPorPo2FsnrHNuXSDyo= @@ -87,6 +89,8 @@ github.com/metacubex/chacha v0.1.5 h1:fKWMb/5c7ZrY8Uoqi79PPFxl+qwR7X/q0OrsAubyX2 github.com/metacubex/chacha v0.1.5/go.mod h1:Djn9bPZxLTXbJFSeyo0/qzEzQI+gUSSzttuzZM75GH8= github.com/metacubex/chi v0.1.0 h1:rjNDyDj50nRpicG43CNkIw4ssiCbmDL8d7wJXKlUCsg= github.com/metacubex/chi v0.1.0/go.mod h1:zM5u5oMQt8b2DjvDHvzadKrP6B2ztmasL1YHRMbVV+g= +github.com/metacubex/connect-ip-go v0.0.0-20260128031117-1cad62060727 h1:qbZQ0sO0bDBKPvTd/qNQK6513300WJ5GRsHnw3PO4Ho= +github.com/metacubex/connect-ip-go v0.0.0-20260128031117-1cad62060727/go.mod h1:xYC8Ik7/rN6no+vTRuWMEziGwm3brA0wNM/zZP9qhOQ= github.com/metacubex/cpu v0.1.0 h1:8PeTdV9j6UKbN1K5Jvtbi/Jock7dknvzyYuLb8Conmk= github.com/metacubex/cpu v0.1.0/go.mod h1:09VEt4dSRLR+bOA8l4w4NDuzGZ8n5dkMv7e8axgEeTU= github.com/metacubex/fswatch v0.1.1 h1:jqU7C/v+g0qc2RUFgmAOPoVvfl2BXXUXEumn6oQuxhU= @@ -109,8 +113,8 @@ github.com/metacubex/nftables v0.0.0-20250503052935-30a69ab87793 h1:1Qpuy+sU3Dmy github.com/metacubex/nftables v0.0.0-20250503052935-30a69ab87793/go.mod h1:RjRNb4G52yAgfR+Oe/kp9G4PJJ97Fnj89eY1BFO3YyA= github.com/metacubex/qpack v0.6.0 h1:YqClGIMOpiRYLjV1qOs483Od08MdPgRnHjt90FuaAKw= github.com/metacubex/qpack v0.6.0/go.mod h1:lKGSi7Xk94IMvHGOmxS9eIei3bvIqpOAImEBsaOwTkA= -github.com/metacubex/quic-go v0.59.1-0.20260112033758-aa29579f2001 h1:RlT3bFCIDM/NR9GWaDbFCrweOwpHRfgaT9c0zuRlPhY= -github.com/metacubex/quic-go v0.59.1-0.20260112033758-aa29579f2001/go.mod h1:oNzMrmylS897M3zSMuapIdwSwfq6F2qW01Z3NhVRJhk= +github.com/metacubex/quic-go v0.59.1-0.20260128071132-0f3233b973af h1:do5o1rzn64NEN5oGswo7VruDkbz2055fhVT3rXehA8E= +github.com/metacubex/quic-go v0.59.1-0.20260128071132-0f3233b973af/go.mod h1:oNzMrmylS897M3zSMuapIdwSwfq6F2qW01Z3NhVRJhk= github.com/metacubex/randv2 v0.2.0 h1:uP38uBvV2SxYfLj53kuvAjbND4RUDfFJjwr4UigMiLs= github.com/metacubex/randv2 v0.2.0/go.mod h1:kFi2SzrQ5WuneuoLLCMkABtiBu6VRrMrWFqSPyj2cxY= github.com/metacubex/restls-client-go v0.1.7 h1:eCwiXCTQb5WJu9IlgYvDBA1OgrINv58dEe7hcN5H15k= @@ -194,6 +198,8 @@ github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21 github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/xtaci/lossyconn v0.0.0-20190602105132-8df528c0c9ae h1:J0GxkO96kL4WF+AIT3M4mfUVinOCPgf2uUWYFUzN0sM= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= gitlab.com/go-extension/aes-ccm v0.0.0-20230221065045-e58665ef23c7 h1:UNrDfkQqiEYzdMlNsVvBYOAJWZjdktqFE9tQh5BT2+4= gitlab.com/go-extension/aes-ccm v0.0.0-20230221065045-e58665ef23c7/go.mod h1:E+rxHvJG9H6PUdzq9NRG6csuLN3XUx98BfGOVWNYnXs= gitlab.com/yawning/bsaes.git v0.0.0-20190805113838-0a714cd429ec h1:FpfFs4EhNehiVfzQttTuxanPIT43FtkkCFypIod8LHo= diff --git a/transport/masque/masque.go b/transport/masque/masque.go new file mode 100644 index 00000000..df11eb19 --- /dev/null +++ b/transport/masque/masque.go @@ -0,0 +1,226 @@ +// Package masque +// copy and modify from https://github.com/Diniboy1123/usque/blob/d0eb96e7e5c56cce6cf34a7f8d75abbedba58fef/api/masque.go +package masque + +import ( + "context" + "crypto/ecdsa" + "crypto/rand" + "crypto/x509" + "errors" + "fmt" + "math/big" + "net/netip" + "net/url" + "time" + + connectip "github.com/metacubex/connect-ip-go" + "github.com/metacubex/http" + "github.com/metacubex/quic-go" + "github.com/metacubex/quic-go/http3" + "github.com/metacubex/tls" + "github.com/yosida95/uritemplate/v3" +) + +const ( + ConnectSNI = "consumer-masque.cloudflareclient.com" + ConnectURI = "https://cloudflareaccess.com" +) + +// PrepareTlsConfig creates a TLS configuration using the provided certificate and SNI (Server Name Indication). +// It also verifies the peer's public key against the provided public key. +func PrepareTlsConfig(privKey *ecdsa.PrivateKey, peerPubKey *ecdsa.PublicKey, sni string) (*tls.Config, error) { + verfiyCert := func(cert *x509.Certificate) error { + if _, ok := cert.PublicKey.(*ecdsa.PublicKey); !ok { + // we only support ECDSA + // TODO: don't hardcode cert type in the future + // as backend can start using different cert types + return x509.ErrUnsupportedAlgorithm + } + + if !cert.PublicKey.(*ecdsa.PublicKey).Equal(peerPubKey) { + // reason is incorrect, but the best I could figure + // detail explains the actual reason + + //10 is NoValidChains, but we support go1.22 where it's not defined + return x509.CertificateInvalidError{Cert: cert, Reason: 10, Detail: "remote endpoint has a different public key than what we trust"} + } + + return nil + } + + cert, err := GenerateCert(privKey) + if err != nil { + return nil, fmt.Errorf("failed to generate cert: %v", err) + } + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{ + { + Certificate: cert, + PrivateKey: privKey, + }, + }, + ServerName: sni, + NextProtos: []string{http3.NextProtoH3}, + // WARN: SNI is usually not for the endpoint, so we must skip verification + InsecureSkipVerify: true, + // we pin to the endpoint public key + VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + if len(rawCerts) == 0 { + return nil + } + + var err error + for _, v := range rawCerts { + cert, er := x509.ParseCertificate(v) + if er != nil { + err = errors.Join(err, er) + continue + } + + if er = verfiyCert(cert); er != nil { + err = errors.Join(err, er) + continue + } + } + + return err + }, + } + + return tlsConfig, nil +} + +func GenerateCert(privKey *ecdsa.PrivateKey) ([][]byte, error) { + cert, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{ + SerialNumber: big.NewInt(0), + NotBefore: time.Now(), + NotAfter: time.Now().Add(1 * 24 * time.Hour), + }, &x509.Certificate{}, &privKey.PublicKey, privKey) + if err != nil { + return nil, err + } + + return [][]byte{cert}, nil +} + +// ConnectTunnel establishes a QUIC connection and sets up a Connect-IP tunnel with the provided endpoint. +// Endpoint address is used to check whether the authentication/connection is successful or not. +// Requires modified connect-ip-go for now to support Cloudflare's non RFC compliant implementation. +func ConnectTunnel(ctx context.Context, quicConn *quic.Conn, connectUri string) (*http3.Transport, *connectip.Conn, error) { + tr := &http3.Transport{ + EnableDatagrams: true, + AdditionalSettings: map[uint64]uint64{ + // official client still sends this out as well, even though + // it's deprecated, see https://datatracker.ietf.org/doc/draft-ietf-masque-h3-datagram/00/ + // SETTINGS_H3_DATAGRAM_00 = 0x0000000000000276 + // https://github.com/cloudflare/quiche/blob/7c66757dbc55b8d0c3653d4b345c6785a181f0b7/quiche/src/h3/frame.rs#L46 + 0x276: 1, + }, + DisableCompression: true, + } + + hconn := tr.NewClientConn(quicConn) + + additionalHeaders := http.Header{ + "User-Agent": []string{""}, + } + + template := uritemplate.MustNew(connectUri) + ipConn, rsp, err := dialEx(ctx, hconn, template, "cf-connect-ip", additionalHeaders, true) + if err != nil { + _ = tr.Close() + if err.Error() == "CRYPTO_ERROR 0x131 (remote): tls: access denied" { + return nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service") + } + return nil, nil, fmt.Errorf("failed to dial connect-ip: %v", err) + } + + err = ipConn.AdvertiseRoute(ctx, []connectip.IPRoute{ + { + IPProtocol: 0, + StartIP: netip.AddrFrom4([4]byte{}), + EndIP: netip.AddrFrom4([4]byte{255, 255, 255, 255}), + }, + { + IPProtocol: 0, + StartIP: netip.AddrFrom16([16]byte{}), + EndIP: netip.AddrFrom16([16]byte{ + 255, 255, 255, 255, + 255, 255, 255, 255, + 255, 255, 255, 255, + 255, 255, 255, 255, + }), + }, + }) + if err != nil { + _ = ipConn.Close() + _ = tr.Close() + return nil, nil, err + } + + if rsp.StatusCode != http.StatusOK { + _ = ipConn.Close() + _ = tr.Close() + return nil, nil, fmt.Errorf("failed to dial connect-ip: %v", rsp.Status) + } + + return tr, ipConn, nil +} + +// dialEx dials a proxied connection to a target server. +func dialEx(ctx context.Context, conn *http3.ClientConn, template *uritemplate.Template, requestProtocol string, additionalHeaders http.Header, ignoreExtendedConnect bool) (*connectip.Conn, *http.Response, error) { + if len(template.Varnames()) > 0 { + return nil, nil, errors.New("connect-ip: IP flow forwarding not supported") + } + + u, err := url.Parse(template.Raw()) + if err != nil { + return nil, nil, fmt.Errorf("connect-ip: failed to parse URI: %w", err) + } + + select { + case <-ctx.Done(): + return nil, nil, context.Cause(ctx) + case <-conn.Context().Done(): + return nil, nil, context.Cause(conn.Context()) + case <-conn.ReceivedSettings(): + } + settings := conn.Settings() + if !ignoreExtendedConnect && !settings.EnableExtendedConnect { + return nil, nil, errors.New("connect-ip: server didn't enable Extended CONNECT") + } + if !settings.EnableDatagrams { + return nil, nil, errors.New("connect-ip: server didn't enable datagrams") + } + + const capsuleProtocolHeaderValue = "?1" + headers := http.Header{http3.CapsuleProtocolHeader: []string{capsuleProtocolHeaderValue}} + for k, v := range additionalHeaders { + headers[k] = v + } + + rstr, err := conn.OpenRequestStream(ctx) + if err != nil { + return nil, nil, fmt.Errorf("connect-ip: failed to open request stream: %w", err) + } + if err := rstr.SendRequestHeader(&http.Request{ + Method: http.MethodConnect, + Proto: requestProtocol, + Host: u.Host, + Header: headers, + URL: u, + }); err != nil { + return nil, nil, fmt.Errorf("connect-ip: failed to send request: %w", err) + } + // TODO: optimistically return the connection + rsp, err := rstr.ReadResponse() + if err != nil { + return nil, nil, fmt.Errorf("connect-ip: failed to read response: %w", err) + } + if rsp.StatusCode < 200 || rsp.StatusCode > 299 { + return nil, rsp, fmt.Errorf("connect-ip: server responded with %d", rsp.StatusCode) + } + return connectip.NewProxiedConn(rstr), rsp, nil +}