Compare commits

...

135 Commits

Author SHA1 Message Date
wwqgtxx
97f25250a6 chore: code cleanup 2026-02-08 00:10:07 +08:00
wwqgtxx
8b0bcb6740 chore: better generator 2026-02-08 00:09:46 +08:00
wwqgtxx
022f677385 chore: cleanup hostValue code 2026-02-07 12:07:01 +08:00
wwqgtxx
32799662ad fix: quic data race for crypto/tls 2026-02-07 09:41:08 +08:00
wwqgtxx
86257fc83c chore: remove reflect-based provider override code 2026-02-05 17:20:47 +08:00
wwqgtxx
f2222b5e02 chore: code cleanup 2026-02-05 16:47:24 +08:00
Chenx Dust
3ac62152cb fix: race condition of tcpConcurrent in dialer (#2556) 2026-02-05 15:57:01 +08:00
wwqgtxx
5516ca18fd chore: code cleanup 2026-02-05 10:46:09 +08:00
wwqgtxx
3bca69c745 chore: add some comments for the fingerprint verifier 2026-02-05 10:34:36 +08:00
wwqgtxx
558b3840ea action: update Go 1.26rc3 to test 2026-02-04 23:57:47 +08:00
wwqgtxx
f94da9f2b3 chore: fingerprint verifier handle non-leaf certificate will check the SNI matches the certificate's DNS name 2026-02-04 22:41:33 +08:00
wwqgtxx
2cfc4ba044 fix: CVE-2025-68121 for crypto/tls again and again 2026-02-04 15:19:27 +08:00
wwqgtxx
034f1d1e9b chore: disallow empty proxy-server-nameserver when proxy-server-nameserver-policy is set 2026-02-04 12:06:45 +08:00
wwqgtxx
dede56fe4b feat: add proxy-server-nameserver-policy to dns section 2026-02-03 01:41:00 +08:00
wwqgtxx
5fda87d50e chore: update sing-tun 2026-02-02 14:39:32 +08:00
wwqgtxx
5e1b133e4e chore: more callback support for utls 2026-02-01 01:36:30 +08:00
wwqgtxx
27a3ca6afc chore: converter support fingerprint for vmess/vless/trojan 2026-01-31 21:27:14 +08:00
wwqgtxx
7573affdd4 chore: better logging in masque outbound 2026-01-31 20:25:02 +08:00
wwqgtxx
17fbaf9100 doc: fix typo 2026-01-30 23:16:08 +08:00
sleshep
710772f993 chore: add simple validation for static dialer-proxy config (#2551)
Currently, it can only validate whether a cycle exists in proxies, and cannot determine if it is caused by groups.
2026-01-30 20:39:06 +08:00
saba-futai
d36b024b10 chore: align sudoku with upstream v0.2.0 (#2549) 2026-01-30 10:33:22 +08:00
wwqgtxx
f52c9356c2 fix: CVE-2025-68121 for crypto/tls again 2026-01-29 08:53:48 +08:00
wwqgtxx
e45c896185 feat: support masque outbound 2026-01-28 19:01:18 +08:00
wwqgtxx
d18a14afeb fix: snat key in trojan packet listener 2026-01-28 00:49:27 +08:00
wwqgtxx
6aaabc97ca chore: decrease unneeded string convert in socks5 addr parsing 2026-01-28 00:41:13 +08:00
wwqgtxx
85c024a4a6 fix: snat key in sudoku packet listener 2026-01-28 00:32:44 +08:00
wwqgtxx
c33c90d7af chore: clean up duplicate code in sudoku 2026-01-26 09:28:18 +08:00
wwqgtxx
65c3d3e4e2 chore: remove unreachable code in sudoku 2026-01-26 09:28:18 +08:00
wwqgtxx
98b3060558 chore: optimize timeout control in DoH TLS probe 2026-01-25 22:49:04 +08:00
wwqgtxx
b90100645e chore: using mihomo's global pool in DoQ 2026-01-25 22:48:48 +08:00
wwqgtxx
46ee1649c0 chore: hy2 listener fellow hysteria2's code skip verify in https masquerade 2026-01-25 21:49:02 +08:00
wwqgtxx
e3b8fc2b77 fix: hy2 listener panic with http/https masquerade 2026-01-25 18:12:51 +08:00
wwqgtxx
707fe8b207 chore: remove auto IDNA conversion in domain rules
The original upstream does not support it, and there are many places in the current code that do not support it either. Removing it will help maintain consistency in behavior across different parts.
2026-01-23 09:36:13 +08:00
wwqgtxx
1e1434d1de chore: remove an unnecessary variable 2026-01-20 14:48:59 +08:00
wwqgtxx
26052ba5e5 chore: remove confused varbin using in sing 2026-01-19 23:16:55 +08:00
wwqgtxx
75a0cd5aff fix: file exists when tun start 2026-01-18 16:37:51 +08:00
wwqgtxx
e03ba23f65 chore: update logrus 2026-01-18 11:00:49 +08:00
wwqgtxx
0c995a2479 chore: move proxiesWithProviders to hub/route internal to disallow external usage of this poorly implemented function 2026-01-17 22:26:14 +08:00
wwqgtxx
3c526ae06e feat: add query-server-name for ech-opts 2026-01-17 19:11:57 +08:00
wwqgtxx
0b009b514c doc: add missing params 2026-01-17 18:52:44 +08:00
wwqgtxx
18d139a15d chore: rollback tls to restore the session resumption functionality in quic-go 2026-01-17 18:27:19 +08:00
wwqgtxx
5f250413fe doc: remove deprecated item 2026-01-17 18:27:19 +08:00
wwqgtxx
993595df73 chore: switch to our own common/orderedmap package, remove two unneeded json dependence 2026-01-17 18:27:19 +08:00
H1JK
828fd30dc3 chore: support connection reuse for DoT 2026-01-16 14:20:20 +08:00
wwqgtxx
11000dccd7 chore: add common/deque package 2026-01-16 11:05:15 +08:00
wwqgtxx
0818aa54aa chore: provider a common entrance for YAML package 2026-01-16 11:05:13 +08:00
wwqgtxx
edbfebeacd fix: CVE-2025-68121 for crypto/tls 2026-01-16 08:27:29 +08:00
saba-futai
06f5fbac06 feat: add path-root for sudoku (#2511) 2026-01-14 21:25:05 +08:00
Shaw
f38fc2020f feat: add grpc-user-agent to grpc-opts (#2512) 2026-01-14 21:02:09 +08:00
wwqgtxx
97bce45eba chore: deprecated global-client-fingerprint, please set client-fingerprint directly on the proxy instead 2026-01-14 10:40:26 +08:00
Davoyan
bc28cd486a doc: fix typo in config.yaml (#2459) 2026-01-14 09:01:18 +08:00
wwqgtxx
cdabd1e8b1 chore: update utls 2026-01-14 08:02:37 +08:00
Toby
c5b0f00bb2 fix: logic issues with BBR impl
98872a4f38
2026-01-12 13:34:59 +08:00
wwqgtxx
c128d23dec chore: update quic-go to 0.59.0 2026-01-12 12:48:18 +08:00
wwqgtxx
ee37a353d0 fix: incorrect timestamp conversion in brutal 2026-01-12 12:45:52 +08:00
wwqgtxx
0cf37de1a8 chore: better time storage in rule wrapper 2026-01-12 00:50:55 +08:00
potoo0
ae6069c178 chore: moving rules disabled and hit/miss counts data to extra for restful api (#2503) 2026-01-11 21:11:38 +08:00
wwqgtxx
c8e33a4347 chore: decrease rule wrapper memory usage 2026-01-11 20:57:28 +08:00
potoo0
19a6b5d6f7 feat: support rule disabling and hit/miss count/at tracking in restful api (#2502) 2026-01-11 19:37:08 +08:00
wwqgtxx
efb800866e chore: update quic-go to 0.58.1 2026-01-11 17:19:53 +08:00
wwqgtxx
94c8d60f72 chore: simplified logic rule parsing 2026-01-08 23:42:01 +08:00
saba-futai
0f2baca2de chore: refactored the implementation of suduko mux (#2486) 2026-01-07 00:25:33 +08:00
wwqgtxx
b18a33552c chore: remove unused pointer in rules implements 2026-01-06 09:29:09 +08:00
wwqgtxx
487de9b548 feat: add PROCESS-NAME-WILDCARD and PROCESS-PATH-WILDCARD 2026-01-06 08:52:06 +08:00
enfein
1a6230ec03 chore: update mieru version (#2484)
Fix https://github.com/enfein/mieru/issues/247
2026-01-06 07:48:46 +08:00
wwqgtxx
e6bf56b9af fix: os.(*Process).Wait not working on Windows7 2026-01-05 20:26:19 +08:00
wwqgtxx
0ad9ac325a feat: support aes-128-gcm, ratelimit and framesize for kcptun 2026-01-05 12:25:30 +08:00
saba-futai
d6b1263236 feat: support http-mask-multiplex for suduko (#2482) 2026-01-04 22:24:42 +08:00
wwqgtxx
4d7670339b feat: all dns client support disable-qtype-<int> params 2026-01-02 22:43:58 +08:00
wwqgtxx
0cffc8d76d chore: revert "chore: update quic-go to 0.58.0"
This reverts commit 64015b7634.
2026-01-02 17:09:40 +08:00
wwqgtxx
1f8bee9710 chore: force to disable mptcp for tproxy 2025-12-31 08:43:23 +08:00
wwqgtxx
eb30d3f331 chore: add a code comment for tproxy listener 2025-12-31 02:07:52 +08:00
wwqgtxx
10f4bebdfa fix: only clear dstIP if it is confirmed to be a fake IP 2025-12-30 17:16:09 +08:00
David
06387d5045 feat: support fake-ip-filter-mode: rule mode (#2469) 2025-12-29 08:14:09 +08:00
wwqgtxx
c393e917eb fix: gvisor compatibility on go1.26 2025-12-27 17:57:30 +08:00
wwqgtxx
4f0a6fa117 fix: gvisor panic 2025-12-27 17:16:35 +08:00
wwqgtxx
4f9bfd216f chore: add some comments for the finalizer 2025-12-27 16:38:58 +08:00
joshua
498f81aad3 feat: add header support for rule provider (#2463) 2025-12-24 23:10:38 +08:00
wwqgtxx
9168bee6b7 chore: align internal logic 2025-12-24 18:26:55 +08:00
HolgerHuo
e6c0e3b19c fix: handle geoip:lan when GetRecodeSize() (#2460) 2025-12-24 08:34:19 +08:00
wwqgtxx
287f9e5185 chore: temporarily skip mieru inbound test in go1.26 on windows 2025-12-23 23:49:19 +08:00
wwqgtxx
c456370f4f fix: missing context cancel in pullLoop 2025-12-23 23:26:05 +08:00
wwqgtxx
10ef29f5cd chore: apply global ca in sudoku code 2025-12-23 23:15:52 +08:00
wwqgtxx
85ba7f6a0a chore: change import paths in sudoku code 2025-12-23 23:14:39 +08:00
saba-futai
7daf37bc15 feat: support http-mask-mode, http-mask-tls and http-mask-host for sudoku (#2456) 2025-12-23 23:08:38 +08:00
wwqgtxx
64015b7634 chore: update quic-go to 0.58.0 2025-12-22 17:29:28 +08:00
wwqgtxx
5585304d68 chore: allow custom path for gRPC (grpc-service-name start with /) 2025-12-21 10:28:05 +08:00
wwqgtxx
911211578c action: add Go 1.26rc1 to test 2025-12-21 00:12:03 +08:00
wwqgtxx
abb55199f2 fix: os.RemoveAll not working on Windows7 2025-12-20 23:02:26 +08:00
wwqgtxx
87c3f700e5 chore: add TODO comment to ca.LoadCertificates 2025-12-19 21:43:55 +08:00
wwqgtxx
4a723e8d3f chore: allow automatic reloading when the TLS server's certificate, private-key or ech-key is a local file 2025-12-19 20:23:48 +08:00
wwqgtxx
93cf46e430 chore: remove unused import path 2025-12-19 20:14:02 +08:00
Howard Wu
35a1130c92 chore: use HasPrefix instead of Contains for key checks (#2447) 2025-12-19 18:43:06 +08:00
Howard Wu
1ebcb25e4a fix: typo in sniffer skip-dst-address config parsing (#2446) 2025-12-19 18:16:56 +08:00
wwqgtxx
cbcacdbb8c chore: using tls.Config.GetCertificate/GetClientCertificate to load TLS certificates 2025-12-19 12:24:16 +08:00
wwqgtxx
17966b5418 fix: close sing-tun maybe panic on windows 2025-12-18 10:37:50 +08:00
wwqgtxx
bc8f0dcf77 fix: missing ntp call 2025-12-17 18:50:33 +08:00
wwqgtxx
827cd616e8 chore: cleanup import path 2025-12-17 17:35:58 +08:00
wwqgtxx
e1384e86ab chore: update http2 using in test 2025-12-17 17:19:09 +08:00
wwqgtxx
b92b38701c chore: update ech handling 2025-12-17 17:19:06 +08:00
wwqgtxx
1cab34d257 chore: update quic-go to 0.57.1 2025-12-17 16:13:12 +08:00
saba-futai
a06097c2c4 chore: add xvp rotation andd new header generation strategy for sudoku (#2437) 2025-12-16 18:39:39 +08:00
wwqgtxx
bc9db11cb4 chore: hub/route module handle websocket itself 2025-12-14 19:56:30 +08:00
Ealrang
69e301820c action: fix architecture check for riscv64 in script (#2435) 2025-12-14 17:30:40 +08:00
wwqgtxx
e7a04e0762 chore: don't process msg.Extra in msgToHTTPSRRInfo 2025-12-12 19:42:29 +08:00
Eric Moore
7e8c2876fb chore: improve HTTPS RR logging (#2431) 2025-12-12 17:43:54 +08:00
wwqgtxx
936ebc7718 chore: add echparser package for parse ECHConfigList and ECHConfig 2025-12-12 16:05:11 +08:00
Eric Moore
b753a57e6a fix: ech not work with websocket+clientFingerprint 2025-12-11 23:15:40 +08:00
wwqgtxx
dd99bfc892 doc: fix custom-table doc 2025-12-11 14:53:02 +08:00
wwqgtxx
2a1b3b2aed chore: allow sudoku inbound handle sing-mux request 2025-12-11 14:14:21 +08:00
saba-futai
2211789a7c chore: add customized byte style for sudoku (#2427) 2025-12-10 17:47:59 +08:00
wwqgtxx
e652e277a7 fix: missing ProxyInfo information in wireguard outbound 2025-12-10 17:06:13 +08:00
wwqgtxx
40863d248d chore: add lock in baseProvider for thread-safe 2025-12-10 08:42:40 +08:00
wwqgtxx
17b8eb8772 chore: skip icmp forwarding when destination in tun interface addr range 2025-12-08 09:56:15 +08:00
Vincent Loeng/Leong
6b40072bc5 chore: support find process on freebsd 14 and 15 (#2422) 2025-12-06 14:03:59 +08:00
wwqgtxx
f44aa22d50 chore: add sudoku ed25519key test 2025-12-05 09:06:06 +08:00
wwqgtxx
c33d9ad857 chore: cleanup sudoku internal code 2025-12-05 08:53:18 +08:00
saba-futai
25041b599e chore: sudoku support enable-pure-downlink mode to increase download bandwidth (#2419) 2025-12-05 07:52:49 +08:00
wwqgtxx
6539b509cb chore: restful api contains providerChains for connections 2025-12-04 17:29:01 +08:00
Xi Xu
d2007fdc22 chore: improves thread safety in adapter 2025-12-04 16:02:22 +08:00
wwqgtxx
b5fa3ee99a chore: restful api contains provider-name for proxies 2025-12-04 15:10:13 +08:00
wwqgtxx
91f5593f4e fix: structure ignore tag not working in nest struct 2025-12-04 14:44:34 +08:00
wwqgtxx
90470ac304 chore: cleanup import path for common/net 2025-12-04 13:44:46 +08:00
wwqgtxx
b509affe5b chore: simplify DNSPrefer serialization process 2025-12-04 13:41:44 +08:00
wwqgtxx
32ce513977 chore: discard domain addr input in sudoku uot 2025-12-03 22:54:26 +08:00
wwqgtxx
30891f8781 chore: sharing sudoku internal code 2025-12-03 22:23:37 +08:00
saba-futai
e4cdb9b600 feat: add uot for sudoku (#2415) 2025-12-03 22:11:56 +08:00
wwqgtxx
d33dbbe2f9 fix: QUIC events with session tickets disabled will panic on Go 1.26 2025-12-03 15:40:23 +08:00
wwqgtxx
d8dcaa7500 chore: add upTotal and downTotal data to /traffic restful api 2025-12-03 11:31:13 +08:00
wwqgtxx
9df8392c65 chore: clean up internal interface definitions 2025-12-03 11:08:16 +08:00
wwqgtxx
fdb7cb1f58 chore: allow setting DialerForAPI in adapter.ParseProxy for library user 2025-12-03 00:05:27 +08:00
wwqgtxx
7cd58fbdf6 chore: add DialerForAPI to outbound option for library user 2025-12-02 23:33:07 +08:00
wwqgtxx
bc719eb96d chore: simplify tuic client 2025-12-02 21:07:51 +08:00
wwqgtxx
ac90543548 chore: code cleanup 2025-12-02 17:18:20 +08:00
futai
9a5e506f66 chore: simplify server config and add keygen for sudoku (#2407) 2025-12-01 19:26:41 +08:00
246 changed files with 15433 additions and 2821 deletions

View File

@@ -1,4 +1,6 @@
Subject: [PATCH] Revert "runtime: always use LoadLibraryEx to load system libraries" Subject: [PATCH] Revert "os: remove 5ms sleep on Windows in (*Process).Wait"
Fix os.RemoveAll not working on Windows7
Revert "runtime: always use LoadLibraryEx to load system libraries"
Revert "syscall: remove Windows 7 console handle workaround" Revert "syscall: remove Windows 7 console handle workaround"
Revert "net: remove sysSocket fallback for Windows 7" Revert "net: remove sysSocket fallback for Windows 7"
Revert "crypto/rand,runtime: switch RtlGenRandom for ProcessPrng" Revert "crypto/rand,runtime: switch RtlGenRandom for ProcessPrng"
@@ -655,3 +657,228 @@ diff --git a/src/syscall/dll_windows.go b/src/syscall/dll_windows.go
} else { } else {
h, e = loadlibrary(namep) h, e = loadlibrary(namep)
} }
Index: src/os/removeall_at.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/removeall_at.go b/src/os/removeall_at.go
--- a/src/os/removeall_at.go (revision f56f1e23507e646c85243a71bde7b9629b2f970c)
+++ b/src/os/removeall_at.go (revision 0a52622d2331ff975fb0442617ec19bc352bb2ed)
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build unix || wasip1 || windows
+//go:build unix || wasip1
package os
@@ -175,3 +175,25 @@
}
return newDirFile(fd, name)
}
+
+func rootRemoveAll(r *Root, name string) error {
+ // Consistency with os.RemoveAll: Strip trailing /s from the name,
+ // so RemoveAll("not_a_directory/") succeeds.
+ for len(name) > 0 && IsPathSeparator(name[len(name)-1]) {
+ name = name[:len(name)-1]
+ }
+ if endsWithDot(name) {
+ // Consistency with os.RemoveAll: Return EINVAL when trying to remove .
+ return &PathError{Op: "RemoveAll", Path: name, Err: syscall.EINVAL}
+ }
+ _, err := doInRoot(r, name, nil, func(parent sysfdType, name string) (struct{}, error) {
+ return struct{}{}, removeAllFrom(parent, name)
+ })
+ if IsNotExist(err) {
+ return nil
+ }
+ if err != nil {
+ return &PathError{Op: "RemoveAll", Path: name, Err: underlyingError(err)}
+ }
+ return err
+}
Index: src/os/removeall_noat.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/removeall_noat.go b/src/os/removeall_noat.go
--- a/src/os/removeall_noat.go (revision f56f1e23507e646c85243a71bde7b9629b2f970c)
+++ b/src/os/removeall_noat.go (revision 0a52622d2331ff975fb0442617ec19bc352bb2ed)
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build (js && wasm) || plan9
+//go:build (js && wasm) || plan9 || windows
package os
@@ -140,3 +140,22 @@
}
return err
}
+
+func rootRemoveAll(r *Root, name string) error {
+ if endsWithDot(name) {
+ // Consistency with os.RemoveAll: Return EINVAL when trying to remove .
+ return &PathError{Op: "RemoveAll", Path: name, Err: syscall.EINVAL}
+ }
+ if err := checkPathEscapesLstat(r, name); err != nil {
+ if err == syscall.ENOTDIR {
+ // Some intermediate path component is not a directory.
+ // RemoveAll treats this as success (since the target doesn't exist).
+ return nil
+ }
+ return &PathError{Op: "RemoveAll", Path: name, Err: err}
+ }
+ if err := RemoveAll(joinPath(r.root.name, name)); err != nil {
+ return &PathError{Op: "RemoveAll", Path: name, Err: underlyingError(err)}
+ }
+ return nil
+}
Index: src/os/root_noopenat.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/root_noopenat.go b/src/os/root_noopenat.go
--- a/src/os/root_noopenat.go (revision f56f1e23507e646c85243a71bde7b9629b2f970c)
+++ b/src/os/root_noopenat.go (revision 0a52622d2331ff975fb0442617ec19bc352bb2ed)
@@ -11,7 +11,6 @@
"internal/filepathlite"
"internal/stringslite"
"sync/atomic"
- "syscall"
"time"
)
@@ -185,25 +184,6 @@
}
return nil
}
-
-func rootRemoveAll(r *Root, name string) error {
- if endsWithDot(name) {
- // Consistency with os.RemoveAll: Return EINVAL when trying to remove .
- return &PathError{Op: "RemoveAll", Path: name, Err: syscall.EINVAL}
- }
- if err := checkPathEscapesLstat(r, name); err != nil {
- if err == syscall.ENOTDIR {
- // Some intermediate path component is not a directory.
- // RemoveAll treats this as success (since the target doesn't exist).
- return nil
- }
- return &PathError{Op: "RemoveAll", Path: name, Err: err}
- }
- if err := RemoveAll(joinPath(r.root.name, name)); err != nil {
- return &PathError{Op: "RemoveAll", Path: name, Err: underlyingError(err)}
- }
- return nil
-}
func rootReadlink(r *Root, name string) (string, error) {
if err := checkPathEscapesLstat(r, name); err != nil {
Index: src/os/root_openat.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/root_openat.go b/src/os/root_openat.go
--- a/src/os/root_openat.go (revision f56f1e23507e646c85243a71bde7b9629b2f970c)
+++ b/src/os/root_openat.go (revision 0a52622d2331ff975fb0442617ec19bc352bb2ed)
@@ -194,28 +194,6 @@
return nil
}
-func rootRemoveAll(r *Root, name string) error {
- // Consistency with os.RemoveAll: Strip trailing /s from the name,
- // so RemoveAll("not_a_directory/") succeeds.
- for len(name) > 0 && IsPathSeparator(name[len(name)-1]) {
- name = name[:len(name)-1]
- }
- if endsWithDot(name) {
- // Consistency with os.RemoveAll: Return EINVAL when trying to remove .
- return &PathError{Op: "RemoveAll", Path: name, Err: syscall.EINVAL}
- }
- _, err := doInRoot(r, name, nil, func(parent sysfdType, name string) (struct{}, error) {
- return struct{}{}, removeAllFrom(parent, name)
- })
- if IsNotExist(err) {
- return nil
- }
- if err != nil {
- return &PathError{Op: "RemoveAll", Path: name, Err: underlyingError(err)}
- }
- return err
-}
-
func rootRename(r *Root, oldname, newname string) error {
_, err := doInRoot(r, oldname, nil, func(oldparent sysfdType, oldname string) (struct{}, error) {
_, err := doInRoot(r, newname, nil, func(newparent sysfdType, newname string) (struct{}, error) {
Index: src/os/root_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/root_windows.go b/src/os/root_windows.go
--- a/src/os/root_windows.go (revision f56f1e23507e646c85243a71bde7b9629b2f970c)
+++ b/src/os/root_windows.go (revision 0a52622d2331ff975fb0442617ec19bc352bb2ed)
@@ -402,3 +402,14 @@
}
return fi.Mode(), nil
}
+
+func checkPathEscapes(r *Root, name string) error {
+ if !filepathlite.IsLocal(name) {
+ return errPathEscapes
+ }
+ return nil
+}
+
+func checkPathEscapesLstat(r *Root, name string) error {
+ return checkPathEscapes(r, name)
+}
Index: src/os/exec_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/exec_windows.go b/src/os/exec_windows.go
--- a/src/os/exec_windows.go (revision 0a52622d2331ff975fb0442617ec19bc352bb2ed)
+++ b/src/os/exec_windows.go (revision fb3d09a67fe97008ad76fea97ae88170072cbdbb)
@@ -10,6 +10,7 @@
"runtime"
"syscall"
"time"
+ _ "unsafe"
)
// Note that Process.handle is never nil because Windows always requires
@@ -49,9 +50,23 @@
// than statusDone.
p.doRelease(statusReleased)
+ var maj, min, build uint32
+ rtlGetNtVersionNumbers(&maj, &min, &build)
+ if maj < 10 {
+ // NOTE(brainman): It seems that sometimes process is not dead
+ // when WaitForSingleObject returns. But we do not know any
+ // other way to wait for it. Sleeping for a while seems to do
+ // the trick sometimes.
+ // See https://golang.org/issue/25965 for details.
+ time.Sleep(5 * time.Millisecond)
+ }
+
return &ProcessState{p.Pid, syscall.WaitStatus{ExitCode: ec}, &u}, nil
}
+//go:linkname rtlGetNtVersionNumbers syscall.rtlGetNtVersionNumbers
+func rtlGetNtVersionNumbers(majorVersion *uint32, minorVersion *uint32, buildNumber *uint32)
+
func (p *Process) signal(sig Signal) error {
handle, status := p.handleTransientAcquire()
switch status {

883
.github/patch/go1.26.patch vendored Normal file
View File

@@ -0,0 +1,883 @@
Subject: [PATCH] Revert "os: remove 5ms sleep on Windows in (*Process).Wait"
Fix os.RemoveAll not working on Windows7
Revert "runtime: always use LoadLibraryEx to load system libraries"
Revert "syscall: remove Windows 7 console handle workaround"
Revert "net: remove sysSocket fallback for Windows 7"
Revert "crypto/rand,runtime: switch RtlGenRandom for ProcessPrng"
---
Index: src/crypto/internal/sysrand/rand_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/crypto/internal/sysrand/rand_windows.go b/src/crypto/internal/sysrand/rand_windows.go
--- a/src/crypto/internal/sysrand/rand_windows.go (revision c599a8f2385849a225d02843b3c6389dbfc5aa69)
+++ b/src/crypto/internal/sysrand/rand_windows.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
@@ -7,5 +7,26 @@
import "internal/syscall/windows"
func read(b []byte) error {
- return windows.ProcessPrng(b)
+ // RtlGenRandom only returns 1<<32-1 bytes at a time. We only read at
+ // most 1<<31-1 bytes at a time so that this works the same on 32-bit
+ // and 64-bit systems.
+ return batched(windows.RtlGenRandom, 1<<31-1)(b)
+}
+
+// batched returns a function that calls f to populate a []byte by chunking it
+// into subslices of, at most, readMax bytes.
+func batched(f func([]byte) error, readMax int) func([]byte) error {
+ return func(out []byte) error {
+ for len(out) > 0 {
+ read := len(out)
+ if read > readMax {
+ read = readMax
+ }
+ if err := f(out[:read]); err != nil {
+ return err
+ }
+ out = out[read:]
+ }
+ return nil
+ }
}
Index: src/crypto/rand/rand.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/crypto/rand/rand.go b/src/crypto/rand/rand.go
--- a/src/crypto/rand/rand.go (revision c599a8f2385849a225d02843b3c6389dbfc5aa69)
+++ b/src/crypto/rand/rand.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
@@ -25,7 +25,7 @@
// - On legacy Linux (< 3.17), Reader opens /dev/urandom on first use.
// - On macOS, iOS, and OpenBSD Reader, uses arc4random_buf(3).
// - On NetBSD, Reader uses the kern.arandom sysctl.
-// - On Windows, Reader uses the ProcessPrng API.
+// - On Windows systems, Reader uses the RtlGenRandom API.
// - On js/wasm, Reader uses the Web Crypto API.
// - On wasip1/wasm, Reader uses random_get.
//
Index: src/internal/syscall/windows/syscall_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/internal/syscall/windows/syscall_windows.go b/src/internal/syscall/windows/syscall_windows.go
--- a/src/internal/syscall/windows/syscall_windows.go (revision c599a8f2385849a225d02843b3c6389dbfc5aa69)
+++ b/src/internal/syscall/windows/syscall_windows.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
@@ -421,7 +421,7 @@
//sys DestroyEnvironmentBlock(block *uint16) (err error) = userenv.DestroyEnvironmentBlock
//sys CreateEvent(eventAttrs *SecurityAttributes, manualReset uint32, initialState uint32, name *uint16) (handle syscall.Handle, err error) = kernel32.CreateEventW
-//sys ProcessPrng(buf []byte) (err error) = bcryptprimitives.ProcessPrng
+//sys RtlGenRandom(buf []byte) (err error) = advapi32.SystemFunction036
type FILE_ID_BOTH_DIR_INFO struct {
NextEntryOffset uint32
Index: src/internal/syscall/windows/zsyscall_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/internal/syscall/windows/zsyscall_windows.go b/src/internal/syscall/windows/zsyscall_windows.go
--- a/src/internal/syscall/windows/zsyscall_windows.go (revision c599a8f2385849a225d02843b3c6389dbfc5aa69)
+++ b/src/internal/syscall/windows/zsyscall_windows.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
@@ -38,7 +38,6 @@
var (
modadvapi32 = syscall.NewLazyDLL(sysdll.Add("advapi32.dll"))
- modbcryptprimitives = syscall.NewLazyDLL(sysdll.Add("bcryptprimitives.dll"))
modiphlpapi = syscall.NewLazyDLL(sysdll.Add("iphlpapi.dll"))
modkernel32 = syscall.NewLazyDLL(sysdll.Add("kernel32.dll"))
modnetapi32 = syscall.NewLazyDLL(sysdll.Add("netapi32.dll"))
@@ -63,7 +62,7 @@
procQueryServiceStatus = modadvapi32.NewProc("QueryServiceStatus")
procRevertToSelf = modadvapi32.NewProc("RevertToSelf")
procSetTokenInformation = modadvapi32.NewProc("SetTokenInformation")
- procProcessPrng = modbcryptprimitives.NewProc("ProcessPrng")
+ procSystemFunction036 = modadvapi32.NewProc("SystemFunction036")
procGetAdaptersAddresses = modiphlpapi.NewProc("GetAdaptersAddresses")
procCreateEventW = modkernel32.NewProc("CreateEventW")
procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
@@ -244,12 +243,12 @@
return
}
-func ProcessPrng(buf []byte) (err error) {
+func RtlGenRandom(buf []byte) (err error) {
var _p0 *byte
if len(buf) > 0 {
_p0 = &buf[0]
}
- r1, _, e1 := syscall.SyscallN(procProcessPrng.Addr(), uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)))
+ r1, _, e1 := syscall.SyscallN(procSystemFunction036.Addr(), uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), 0)
if r1 == 0 {
err = errnoErr(e1)
}
Index: src/runtime/os_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/runtime/os_windows.go b/src/runtime/os_windows.go
--- a/src/runtime/os_windows.go (revision c599a8f2385849a225d02843b3c6389dbfc5aa69)
+++ b/src/runtime/os_windows.go (revision ea2726a6fa25fbfa1092e696e522eafca544d24c)
@@ -40,7 +40,8 @@
//go:cgo_import_dynamic runtime._GetSystemInfo GetSystemInfo%1 "kernel32.dll"
//go:cgo_import_dynamic runtime._GetThreadContext GetThreadContext%2 "kernel32.dll"
//go:cgo_import_dynamic runtime._SetThreadContext SetThreadContext%2 "kernel32.dll"
-//go:cgo_import_dynamic runtime._LoadLibraryExW LoadLibraryExW%3 "kernel32.dll"
+//go:cgo_import_dynamic runtime._LoadLibraryW LoadLibraryW%1 "kernel32.dll"
+//go:cgo_import_dynamic runtime._LoadLibraryA LoadLibraryA%1 "kernel32.dll"
//go:cgo_import_dynamic runtime._PostQueuedCompletionStatus PostQueuedCompletionStatus%4 "kernel32.dll"
//go:cgo_import_dynamic runtime._QueryPerformanceCounter QueryPerformanceCounter%1 "kernel32.dll"
//go:cgo_import_dynamic runtime._QueryPerformanceFrequency QueryPerformanceFrequency%1 "kernel32.dll"
@@ -74,7 +75,6 @@
// Following syscalls are available on every Windows PC.
// All these variables are set by the Windows executable
// loader before the Go program starts.
- _AddVectoredContinueHandler,
_AddVectoredExceptionHandler,
_CloseHandle,
_CreateEventA,
@@ -97,7 +97,8 @@
_GetSystemInfo,
_GetThreadContext,
_SetThreadContext,
- _LoadLibraryExW,
+ _LoadLibraryW,
+ _LoadLibraryA,
_PostQueuedCompletionStatus,
_QueryPerformanceCounter,
_QueryPerformanceFrequency,
@@ -126,8 +127,23 @@
_WriteFile,
_ stdFunction
- // Use ProcessPrng to generate cryptographically random data.
- _ProcessPrng stdFunction
+ // Following syscalls are only available on some Windows PCs.
+ // We will load syscalls, if available, before using them.
+ _AddDllDirectory,
+ _AddVectoredContinueHandler,
+ _LoadLibraryExA,
+ _LoadLibraryExW,
+ _ stdFunction
+
+ // Use RtlGenRandom to generate cryptographically random data.
+ // This approach has been recommended by Microsoft (see issue
+ // 15589 for details).
+ // The RtlGenRandom is not listed in advapi32.dll, instead
+ // RtlGenRandom function can be found by searching for SystemFunction036.
+ // Also some versions of Mingw cannot link to SystemFunction036
+ // when building executable as Cgo. So load SystemFunction036
+ // manually during runtime startup.
+ _RtlGenRandom stdFunction
// Load ntdll.dll manually during startup, otherwise Mingw
// links wrong printf function to cgo executable (see issue
@@ -144,13 +160,6 @@
_ stdFunction
)
-var (
- bcryptprimitivesdll = [...]uint16{'b', 'c', 'r', 'y', 'p', 't', 'p', 'r', 'i', 'm', 'i', 't', 'i', 'v', 'e', 's', '.', 'd', 'l', 'l', 0}
- ntdlldll = [...]uint16{'n', 't', 'd', 'l', 'l', '.', 'd', 'l', 'l', 0}
- powrprofdll = [...]uint16{'p', 'o', 'w', 'r', 'p', 'r', 'o', 'f', '.', 'd', 'l', 'l', 0}
- winmmdll = [...]uint16{'w', 'i', 'n', 'm', 'm', '.', 'd', 'l', 'l', 0}
-)
-
// Function to be called by windows CreateThread
// to start new os thread.
func tstart_stdcall(newm *m)
@@ -242,9 +251,40 @@
return unsafe.String(&sysDirectory[0], sysDirectoryLen)
}
-func windowsLoadSystemLib(name []uint16) uintptr {
- const _LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
- return stdcall(_LoadLibraryExW, uintptr(unsafe.Pointer(&name[0])), 0, _LOAD_LIBRARY_SEARCH_SYSTEM32)
+//go:linkname syscall_getSystemDirectory syscall.getSystemDirectory
+func syscall_getSystemDirectory() string {
+ return unsafe.String(&sysDirectory[0], sysDirectoryLen)
+}
+
+func windowsLoadSystemLib(name []byte) uintptr {
+ if useLoadLibraryEx {
+ return stdcall(_LoadLibraryExA, uintptr(unsafe.Pointer(&name[0])), 0, _LOAD_LIBRARY_SEARCH_SYSTEM32)
+ } else {
+ absName := append(sysDirectory[:sysDirectoryLen], name...)
+ return stdcall(_LoadLibraryA, uintptr(unsafe.Pointer(&absName[0])))
+ }
+}
+
+const _LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
+
+// When available, this function will use LoadLibraryEx with the filename
+// parameter and the important SEARCH_SYSTEM32 argument. But on systems that
+// do not have that option, absoluteFilepath should contain a fallback
+// to the full path inside of system32 for use with vanilla LoadLibrary.
+//
+//go:linkname syscall_loadsystemlibrary syscall.loadsystemlibrary
+func syscall_loadsystemlibrary(filename *uint16, absoluteFilepath *uint16) (handle, err uintptr) {
+ if useLoadLibraryEx {
+ handle, _, err = syscall_syscalln(uintptr(unsafe.Pointer(_LoadLibraryExW)), 3, uintptr(unsafe.Pointer(filename)), 0, _LOAD_LIBRARY_SEARCH_SYSTEM32)
+ } else {
+ handle, _, err = syscall_syscalln(uintptr(unsafe.Pointer(_LoadLibraryW)), 1, uintptr(unsafe.Pointer(absoluteFilepath)))
+ }
+ KeepAlive(filename)
+ KeepAlive(absoluteFilepath)
+ if handle != 0 {
+ err = 0
+ }
+ return
}
//go:linkname windows_QueryPerformanceCounter internal/syscall/windows.QueryPerformanceCounter
@@ -262,13 +302,28 @@
}
func loadOptionalSyscalls() {
- bcryptPrimitives := windowsLoadSystemLib(bcryptprimitivesdll[:])
- if bcryptPrimitives == 0 {
- throw("bcryptprimitives.dll not found")
+ var kernel32dll = []byte("kernel32.dll\000")
+ k32 := stdcall(_LoadLibraryA, uintptr(unsafe.Pointer(&kernel32dll[0])))
+ if k32 == 0 {
+ throw("kernel32.dll not found")
}
- _ProcessPrng = windowsFindfunc(bcryptPrimitives, []byte("ProcessPrng\000"))
+ _AddDllDirectory = windowsFindfunc(k32, []byte("AddDllDirectory\000"))
+ _AddVectoredContinueHandler = windowsFindfunc(k32, []byte("AddVectoredContinueHandler\000"))
+ _LoadLibraryExA = windowsFindfunc(k32, []byte("LoadLibraryExA\000"))
+ _LoadLibraryExW = windowsFindfunc(k32, []byte("LoadLibraryExW\000"))
+ useLoadLibraryEx = (_LoadLibraryExW != nil && _LoadLibraryExA != nil && _AddDllDirectory != nil)
+
+ initSysDirectory()
- n32 := windowsLoadSystemLib(ntdlldll[:])
+ var advapi32dll = []byte("advapi32.dll\000")
+ a32 := windowsLoadSystemLib(advapi32dll)
+ if a32 == 0 {
+ throw("advapi32.dll not found")
+ }
+ _RtlGenRandom = windowsFindfunc(a32, []byte("SystemFunction036\000"))
+
+ var ntdll = []byte("ntdll.dll\000")
+ n32 := windowsLoadSystemLib(ntdll)
if n32 == 0 {
throw("ntdll.dll not found")
}
@@ -297,7 +352,7 @@
context uintptr
}
- powrprof := windowsLoadSystemLib(powrprofdll[:])
+ powrprof := windowsLoadSystemLib([]byte("powrprof.dll\000"))
if powrprof == 0 {
return // Running on Windows 7, where we don't need it anyway.
}
@@ -351,6 +406,22 @@
// in sys_windows_386.s and sys_windows_amd64.s:
func getlasterror() uint32
+// When loading DLLs, we prefer to use LoadLibraryEx with
+// LOAD_LIBRARY_SEARCH_* flags, if available. LoadLibraryEx is not
+// available on old Windows, though, and the LOAD_LIBRARY_SEARCH_*
+// flags are not available on some versions of Windows without a
+// security patch.
+//
+// https://msdn.microsoft.com/en-us/library/ms684179(v=vs.85).aspx says:
+// "Windows 7, Windows Server 2008 R2, Windows Vista, and Windows
+// Server 2008: The LOAD_LIBRARY_SEARCH_* flags are available on
+// systems that have KB2533623 installed. To determine whether the
+// flags are available, use GetProcAddress to get the address of the
+// AddDllDirectory, RemoveDllDirectory, or SetDefaultDllDirectories
+// function. If GetProcAddress succeeds, the LOAD_LIBRARY_SEARCH_*
+// flags can be used with LoadLibraryEx."
+var useLoadLibraryEx bool
+
var timeBeginPeriodRetValue uint32
// osRelaxMinNS indicates that sysmon shouldn't osRelax if the next
@@ -417,7 +488,8 @@
// Only load winmm.dll if we need it.
// This avoids a dependency on winmm.dll for Go programs
// that run on new Windows versions.
- m32 := windowsLoadSystemLib(winmmdll[:])
+ var winmmdll = []byte("winmm.dll\000")
+ m32 := windowsLoadSystemLib(winmmdll)
if m32 == 0 {
print("runtime: LoadLibraryExW failed; errno=", getlasterror(), "\n")
throw("winmm.dll not found")
@@ -458,6 +530,28 @@
canUseLongPaths = true
}
+var osVersionInfo struct {
+ majorVersion uint32
+ minorVersion uint32
+ buildNumber uint32
+}
+
+func initOsVersionInfo() {
+ info := windows.OSVERSIONINFOW{}
+ info.OSVersionInfoSize = uint32(unsafe.Sizeof(info))
+ stdcall(_RtlGetVersion, uintptr(unsafe.Pointer(&info)))
+ osVersionInfo.majorVersion = info.MajorVersion
+ osVersionInfo.minorVersion = info.MinorVersion
+ osVersionInfo.buildNumber = info.BuildNumber
+}
+
+//go:linkname rtlGetNtVersionNumbers syscall.rtlGetNtVersionNumbers
+func rtlGetNtVersionNumbers(majorVersion *uint32, minorVersion *uint32, buildNumber *uint32) {
+ *majorVersion = osVersionInfo.majorVersion
+ *minorVersion = osVersionInfo.minorVersion
+ *buildNumber = osVersionInfo.buildNumber
+}
+
func osinit() {
asmstdcallAddr = unsafe.Pointer(windows.AsmStdCallAddr())
@@ -470,8 +564,8 @@
initHighResTimer()
timeBeginPeriodRetValue = osRelax(false)
- initSysDirectory()
initLongPathSupport()
+ initOsVersionInfo()
numCPUStartup = getCPUCount()
@@ -487,7 +581,7 @@
//go:nosplit
func readRandom(r []byte) int {
n := 0
- if stdcall(_ProcessPrng, uintptr(unsafe.Pointer(&r[0])), uintptr(len(r)))&0xff != 0 {
+ if stdcall(_RtlGenRandom, uintptr(unsafe.Pointer(&r[0])), uintptr(len(r)))&0xff != 0 {
n = len(r)
}
return n
Index: src/net/hook_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/hook_windows.go b/src/net/hook_windows.go
--- a/src/net/hook_windows.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
+++ b/src/net/hook_windows.go (revision 44e76f7cf1bc6e04b5da724e0b2e48f393713506)
@@ -13,6 +13,7 @@
hostsFilePath = windows.GetSystemDirectory() + "/Drivers/etc/hosts"
// Placeholders for socket system calls.
+ socketFunc func(int, int, int) (syscall.Handle, error) = syscall.Socket
wsaSocketFunc func(int32, int32, int32, *syscall.WSAProtocolInfo, uint32, uint32) (syscall.Handle, error) = windows.WSASocket
connectFunc func(syscall.Handle, syscall.Sockaddr) error = syscall.Connect
listenFunc func(syscall.Handle, int) error = syscall.Listen
Index: src/net/internal/socktest/main_test.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/internal/socktest/main_test.go b/src/net/internal/socktest/main_test.go
--- a/src/net/internal/socktest/main_test.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
+++ b/src/net/internal/socktest/main_test.go (revision 44e76f7cf1bc6e04b5da724e0b2e48f393713506)
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !plan9 && !wasip1 && !windows
+//go:build !js && !plan9 && !wasip1
package socktest_test
Index: src/net/internal/socktest/main_windows_test.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/internal/socktest/main_windows_test.go b/src/net/internal/socktest/main_windows_test.go
new file mode 100644
--- /dev/null (revision 44e76f7cf1bc6e04b5da724e0b2e48f393713506)
+++ b/src/net/internal/socktest/main_windows_test.go (revision 44e76f7cf1bc6e04b5da724e0b2e48f393713506)
@@ -0,0 +1,22 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package socktest_test
+
+import "syscall"
+
+var (
+ socketFunc func(int, int, int) (syscall.Handle, error)
+ closeFunc func(syscall.Handle) error
+)
+
+func installTestHooks() {
+ socketFunc = sw.Socket
+ closeFunc = sw.Closesocket
+}
+
+func uninstallTestHooks() {
+ socketFunc = syscall.Socket
+ closeFunc = syscall.Closesocket
+}
Index: src/net/internal/socktest/sys_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/internal/socktest/sys_windows.go b/src/net/internal/socktest/sys_windows.go
--- a/src/net/internal/socktest/sys_windows.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
+++ b/src/net/internal/socktest/sys_windows.go (revision 44e76f7cf1bc6e04b5da724e0b2e48f393713506)
@@ -9,6 +9,38 @@
"syscall"
)
+// Socket wraps [syscall.Socket].
+func (sw *Switch) Socket(family, sotype, proto int) (s syscall.Handle, err error) {
+ sw.once.Do(sw.init)
+
+ so := &Status{Cookie: cookie(family, sotype, proto)}
+ sw.fmu.RLock()
+ f, _ := sw.fltab[FilterSocket]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return syscall.InvalidHandle, err
+ }
+ s, so.Err = syscall.Socket(family, sotype, proto)
+ if err = af.apply(so); err != nil {
+ if so.Err == nil {
+ syscall.Closesocket(s)
+ }
+ return syscall.InvalidHandle, err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).OpenFailed++
+ return syscall.InvalidHandle, so.Err
+ }
+ nso := sw.addLocked(s, family, sotype, proto)
+ sw.stats.getLocked(nso.Cookie).Opened++
+ return s, nil
+}
+
// WSASocket wraps [syscall.WSASocket].
func (sw *Switch) WSASocket(family, sotype, proto int32, protinfo *syscall.WSAProtocolInfo, group uint32, flags uint32) (s syscall.Handle, err error) {
sw.once.Do(sw.init)
Index: src/net/main_windows_test.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/main_windows_test.go b/src/net/main_windows_test.go
--- a/src/net/main_windows_test.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
+++ b/src/net/main_windows_test.go (revision 44e76f7cf1bc6e04b5da724e0b2e48f393713506)
@@ -12,6 +12,7 @@
var (
// Placeholders for saving original socket system calls.
+ origSocket = socketFunc
origWSASocket = wsaSocketFunc
origClosesocket = poll.CloseFunc
origConnect = connectFunc
@@ -21,6 +22,7 @@
)
func installTestHooks() {
+ socketFunc = sw.Socket
wsaSocketFunc = sw.WSASocket
poll.CloseFunc = sw.Closesocket
connectFunc = sw.Connect
@@ -30,6 +32,7 @@
}
func uninstallTestHooks() {
+ socketFunc = origSocket
wsaSocketFunc = origWSASocket
poll.CloseFunc = origClosesocket
connectFunc = origConnect
Index: src/net/sock_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/sock_windows.go b/src/net/sock_windows.go
--- a/src/net/sock_windows.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
+++ b/src/net/sock_windows.go (revision 44e76f7cf1bc6e04b5da724e0b2e48f393713506)
@@ -20,6 +20,21 @@
func sysSocket(family, sotype, proto int) (syscall.Handle, error) {
s, err := wsaSocketFunc(int32(family), int32(sotype), int32(proto),
nil, 0, windows.WSA_FLAG_OVERLAPPED|windows.WSA_FLAG_NO_HANDLE_INHERIT)
+ if err == nil {
+ return s, nil
+ }
+ // WSA_FLAG_NO_HANDLE_INHERIT flag is not supported on some
+ // old versions of Windows, see
+ // https://msdn.microsoft.com/en-us/library/windows/desktop/ms742212(v=vs.85).aspx
+ // for details. Just use syscall.Socket, if windows.WSASocket failed.
+
+ // See ../syscall/exec_unix.go for description of ForkLock.
+ syscall.ForkLock.RLock()
+ s, err = socketFunc(family, sotype, proto)
+ if err == nil {
+ syscall.CloseOnExec(s)
+ }
+ syscall.ForkLock.RUnlock()
if err != nil {
return syscall.InvalidHandle, os.NewSyscallError("socket", err)
}
Index: src/syscall/exec_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/syscall/exec_windows.go b/src/syscall/exec_windows.go
--- a/src/syscall/exec_windows.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
+++ b/src/syscall/exec_windows.go (revision b4aece36e51ecce81c3ee9fe03e31db552e90018)
@@ -15,7 +15,6 @@
"unsafe"
)
-// ForkLock is not used on Windows.
var ForkLock sync.RWMutex
// EscapeArg rewrites command line argument s as prescribed
@@ -304,6 +303,9 @@
var zeroProcAttr ProcAttr
var zeroSysProcAttr SysProcAttr
+//go:linkname rtlGetNtVersionNumbers
+func rtlGetNtVersionNumbers(majorVersion *uint32, minorVersion *uint32, buildNumber *uint32)
+
func StartProcess(argv0 string, argv []string, attr *ProcAttr) (pid int, handle uintptr, err error) {
if len(argv0) == 0 {
return 0, 0, EWINDOWS
@@ -367,6 +369,17 @@
}
}
+ var maj, min, build uint32
+ rtlGetNtVersionNumbers(&maj, &min, &build)
+ isWin7 := maj < 6 || (maj == 6 && min <= 1)
+ // NT kernel handles are divisible by 4, with the bottom 3 bits left as
+ // a tag. The fully set tag correlates with the types of handles we're
+ // concerned about here. Except, the kernel will interpret some
+ // special handle values, like -1, -2, and so forth, so kernelbase.dll
+ // checks to see that those bottom three bits are checked, but that top
+ // bit is not checked.
+ isLegacyWin7ConsoleHandle := func(handle Handle) bool { return isWin7 && handle&0x10000003 == 3 }
+
p, _ := GetCurrentProcess()
parentProcess := p
if sys.ParentProcess != 0 {
@@ -375,7 +388,15 @@
fd := make([]Handle, len(attr.Files))
for i := range attr.Files {
if attr.Files[i] > 0 {
- err := DuplicateHandle(p, Handle(attr.Files[i]), parentProcess, &fd[i], 0, true, DUPLICATE_SAME_ACCESS)
+ destinationProcessHandle := parentProcess
+
+ // On Windows 7, console handles aren't real handles, and can only be duplicated
+ // into the current process, not a parent one, which amounts to the same thing.
+ if parentProcess != p && isLegacyWin7ConsoleHandle(Handle(attr.Files[i])) {
+ destinationProcessHandle = p
+ }
+
+ err := DuplicateHandle(p, Handle(attr.Files[i]), destinationProcessHandle, &fd[i], 0, true, DUPLICATE_SAME_ACCESS)
if err != nil {
return 0, 0, err
}
@@ -406,6 +427,14 @@
fd = append(fd, sys.AdditionalInheritedHandles...)
+ // On Windows 7, console handles aren't real handles, so don't pass them
+ // through to PROC_THREAD_ATTRIBUTE_HANDLE_LIST.
+ for i := range fd {
+ if isLegacyWin7ConsoleHandle(fd[i]) {
+ fd[i] = 0
+ }
+ }
+
// The presence of a NULL handle in the list is enough to cause PROC_THREAD_ATTRIBUTE_HANDLE_LIST
// to treat the entire list as empty, so remove NULL handles.
j := 0
Index: src/syscall/dll_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/syscall/dll_windows.go b/src/syscall/dll_windows.go
--- a/src/syscall/dll_windows.go (revision b4aece36e51ecce81c3ee9fe03e31db552e90018)
+++ b/src/syscall/dll_windows.go (revision ea2726a6fa25fbfa1092e696e522eafca544d24c)
@@ -119,14 +119,7 @@
}
//go:linkname loadsystemlibrary
-func loadsystemlibrary(filename *uint16) (uintptr, Errno) {
- const _LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
- handle, _, err := SyscallN(uintptr(__LoadLibraryExW), uintptr(unsafe.Pointer(filename)), 0, _LOAD_LIBRARY_SEARCH_SYSTEM32)
- if handle != 0 {
- err = 0
- }
- return handle, err
-}
+func loadsystemlibrary(filename *uint16, absoluteFilepath *uint16) (handle uintptr, err Errno)
//go:linkname getprocaddress
func getprocaddress(handle uintptr, procname *uint8) (uintptr, Errno) {
@@ -143,6 +136,9 @@
Handle Handle
}
+//go:linkname getSystemDirectory
+func getSystemDirectory() string // Implemented in runtime package.
+
// LoadDLL loads the named DLL file into memory.
//
// If name is not an absolute path and is not a known system DLL used by
@@ -159,7 +155,11 @@
var h uintptr
var e Errno
if sysdll.IsSystemDLL[name] {
- h, e = loadsystemlibrary(namep)
+ absoluteFilepathp, err := UTF16PtrFromString(getSystemDirectory() + name)
+ if err != nil {
+ return nil, err
+ }
+ h, e = loadsystemlibrary(namep, absoluteFilepathp)
} else {
h, e = loadlibrary(namep)
}
Index: src/os/removeall_at.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/removeall_at.go b/src/os/removeall_at.go
--- a/src/os/removeall_at.go (revision ea2726a6fa25fbfa1092e696e522eafca544d24c)
+++ b/src/os/removeall_at.go (revision d47e0d22130d597dcf9daa6b41fd9501274f0cb2)
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build unix || wasip1 || windows
+//go:build unix || wasip1
package os
@@ -175,3 +175,25 @@
}
return newDirFile(fd, name)
}
+
+func rootRemoveAll(r *Root, name string) error {
+ // Consistency with os.RemoveAll: Strip trailing /s from the name,
+ // so RemoveAll("not_a_directory/") succeeds.
+ for len(name) > 0 && IsPathSeparator(name[len(name)-1]) {
+ name = name[:len(name)-1]
+ }
+ if endsWithDot(name) {
+ // Consistency with os.RemoveAll: Return EINVAL when trying to remove .
+ return &PathError{Op: "RemoveAll", Path: name, Err: syscall.EINVAL}
+ }
+ _, err := doInRoot(r, name, nil, func(parent sysfdType, name string) (struct{}, error) {
+ return struct{}{}, removeAllFrom(parent, name)
+ })
+ if IsNotExist(err) {
+ return nil
+ }
+ if err != nil {
+ return &PathError{Op: "RemoveAll", Path: name, Err: underlyingError(err)}
+ }
+ return err
+}
Index: src/os/removeall_noat.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/removeall_noat.go b/src/os/removeall_noat.go
--- a/src/os/removeall_noat.go (revision ea2726a6fa25fbfa1092e696e522eafca544d24c)
+++ b/src/os/removeall_noat.go (revision d47e0d22130d597dcf9daa6b41fd9501274f0cb2)
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build (js && wasm) || plan9
+//go:build (js && wasm) || plan9 || windows
package os
@@ -140,3 +140,22 @@
}
return err
}
+
+func rootRemoveAll(r *Root, name string) error {
+ if endsWithDot(name) {
+ // Consistency with os.RemoveAll: Return EINVAL when trying to remove .
+ return &PathError{Op: "RemoveAll", Path: name, Err: syscall.EINVAL}
+ }
+ if err := checkPathEscapesLstat(r, name); err != nil {
+ if err == syscall.ENOTDIR {
+ // Some intermediate path component is not a directory.
+ // RemoveAll treats this as success (since the target doesn't exist).
+ return nil
+ }
+ return &PathError{Op: "RemoveAll", Path: name, Err: err}
+ }
+ if err := RemoveAll(joinPath(r.root.name, name)); err != nil {
+ return &PathError{Op: "RemoveAll", Path: name, Err: underlyingError(err)}
+ }
+ return nil
+}
Index: src/os/root_noopenat.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/root_noopenat.go b/src/os/root_noopenat.go
--- a/src/os/root_noopenat.go (revision ea2726a6fa25fbfa1092e696e522eafca544d24c)
+++ b/src/os/root_noopenat.go (revision d47e0d22130d597dcf9daa6b41fd9501274f0cb2)
@@ -11,7 +11,6 @@
"internal/filepathlite"
"internal/stringslite"
"sync/atomic"
- "syscall"
"time"
)
@@ -185,25 +184,6 @@
}
return nil
}
-
-func rootRemoveAll(r *Root, name string) error {
- if endsWithDot(name) {
- // Consistency with os.RemoveAll: Return EINVAL when trying to remove .
- return &PathError{Op: "RemoveAll", Path: name, Err: syscall.EINVAL}
- }
- if err := checkPathEscapesLstat(r, name); err != nil {
- if err == syscall.ENOTDIR {
- // Some intermediate path component is not a directory.
- // RemoveAll treats this as success (since the target doesn't exist).
- return nil
- }
- return &PathError{Op: "RemoveAll", Path: name, Err: err}
- }
- if err := RemoveAll(joinPath(r.root.name, name)); err != nil {
- return &PathError{Op: "RemoveAll", Path: name, Err: underlyingError(err)}
- }
- return nil
-}
func rootReadlink(r *Root, name string) (string, error) {
if err := checkPathEscapesLstat(r, name); err != nil {
Index: src/os/root_openat.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/root_openat.go b/src/os/root_openat.go
--- a/src/os/root_openat.go (revision ea2726a6fa25fbfa1092e696e522eafca544d24c)
+++ b/src/os/root_openat.go (revision d47e0d22130d597dcf9daa6b41fd9501274f0cb2)
@@ -196,28 +196,6 @@
return nil
}
-func rootRemoveAll(r *Root, name string) error {
- // Consistency with os.RemoveAll: Strip trailing /s from the name,
- // so RemoveAll("not_a_directory/") succeeds.
- for len(name) > 0 && IsPathSeparator(name[len(name)-1]) {
- name = name[:len(name)-1]
- }
- if endsWithDot(name) {
- // Consistency with os.RemoveAll: Return EINVAL when trying to remove .
- return &PathError{Op: "RemoveAll", Path: name, Err: syscall.EINVAL}
- }
- _, err := doInRoot(r, name, nil, func(parent sysfdType, name string) (struct{}, error) {
- return struct{}{}, removeAllFrom(parent, name)
- })
- if IsNotExist(err) {
- return nil
- }
- if err != nil {
- return &PathError{Op: "RemoveAll", Path: name, Err: underlyingError(err)}
- }
- return err
-}
-
func rootRename(r *Root, oldname, newname string) error {
_, err := doInRoot(r, oldname, nil, func(oldparent sysfdType, oldname string) (struct{}, error) {
_, err := doInRoot(r, newname, nil, func(newparent sysfdType, newname string) (struct{}, error) {
Index: src/os/root_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/root_windows.go b/src/os/root_windows.go
--- a/src/os/root_windows.go (revision ea2726a6fa25fbfa1092e696e522eafca544d24c)
+++ b/src/os/root_windows.go (revision d47e0d22130d597dcf9daa6b41fd9501274f0cb2)
@@ -402,3 +402,14 @@
}
return fi.Mode(), nil
}
+
+func checkPathEscapes(r *Root, name string) error {
+ if !filepathlite.IsLocal(name) {
+ return errPathEscapes
+ }
+ return nil
+}
+
+func checkPathEscapesLstat(r *Root, name string) error {
+ return checkPathEscapes(r, name)
+}
Index: src/os/exec_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/exec_windows.go b/src/os/exec_windows.go
--- a/src/os/exec_windows.go (revision d47e0d22130d597dcf9daa6b41fd9501274f0cb2)
+++ b/src/os/exec_windows.go (revision 00e8daec9a4d88f44a8dc55d3bdb71878e525b41)
@@ -10,6 +10,7 @@
"runtime"
"syscall"
"time"
+ _ "unsafe"
)
// Note that Process.handle is never nil because Windows always requires
@@ -49,9 +50,23 @@
// than statusDone.
p.doRelease(statusReleased)
+ var maj, min, build uint32
+ rtlGetNtVersionNumbers(&maj, &min, &build)
+ if maj < 10 {
+ // NOTE(brainman): It seems that sometimes process is not dead
+ // when WaitForSingleObject returns. But we do not know any
+ // other way to wait for it. Sleeping for a while seems to do
+ // the trick sometimes.
+ // See https://golang.org/issue/25965 for details.
+ time.Sleep(5 * time.Millisecond)
+ }
+
return &ProcessState{p.Pid, syscall.WaitStatus{ExitCode: ec}, &u}, nil
}
+//go:linkname rtlGetNtVersionNumbers syscall.rtlGetNtVersionNumbers
+func rtlGetNtVersionNumbers(majorVersion *uint32, minorVersion *uint32, buildNumber *uint32)
+
func (p *Process) signal(sig Signal) error {
handle, status := p.handleTransientAcquire()
switch status {

View File

@@ -59,6 +59,8 @@ jobs:
- { goos: linux, goarch: s390x, output: s390x, debian: s390x, rpm: s390x } - { goos: linux, goarch: s390x, output: s390x, debian: s390x, rpm: s390x }
- { goos: linux, goarch: ppc64le, output: ppc64le, debian: ppc64el, rpm: ppc64le } - { goos: linux, goarch: ppc64le, output: ppc64le, debian: ppc64el, rpm: ppc64le }
# Go 1.25 with special patch can work on Windows 7
# https://github.com/MetaCubeX/go/commits/release-branch.go1.25/
- { goos: windows, goarch: '386', output: '386' } - { goos: windows, goarch: '386', output: '386' }
- { goos: windows, goarch: amd64, goamd64: v1, output: amd64-compatible } # old style file name will be removed in next released - { goos: windows, goarch: amd64, goamd64: v1, output: amd64-compatible } # old style file name will be removed in next released
- { goos: windows, goarch: amd64, goamd64: v3, output: amd64 } - { goos: windows, goarch: amd64, goamd64: v3, output: amd64 }
@@ -153,12 +155,14 @@ jobs:
uses: actions/setup-go@v6 uses: actions/setup-go@v6
with: with:
go-version: '1.25' go-version: '1.25'
check-latest: true # Always check for the latest patch release
- name: Set up Go - name: Set up Go
if: ${{ matrix.jobs.goversion != '' && matrix.jobs.abi != '1' }} if: ${{ matrix.jobs.goversion != '' && matrix.jobs.abi != '1' }}
uses: actions/setup-go@v6 uses: actions/setup-go@v6
with: with:
go-version: ${{ matrix.jobs.goversion }} go-version: ${{ matrix.jobs.goversion }}
check-latest: true # Always check for the latest patch release
- name: Set up Go1.24 loongarch abi1 - name: Set up Go1.24 loongarch abi1
if: ${{ matrix.jobs.goarch == 'loong64' && matrix.jobs.abi == '1' }} if: ${{ matrix.jobs.goarch == 'loong64' && matrix.jobs.abi == '1' }}
@@ -176,6 +180,9 @@ jobs:
# 7c1157f9544922e96945196b47b95664b1e39108: "net: remove sysSocket fallback for Windows 7" # 7c1157f9544922e96945196b47b95664b1e39108: "net: remove sysSocket fallback for Windows 7"
# 48042aa09c2f878c4faa576948b07fe625c4707a: "syscall: remove Windows 7 console handle workaround" # 48042aa09c2f878c4faa576948b07fe625c4707a: "syscall: remove Windows 7 console handle workaround"
# a17d959debdb04cd550016a3501dd09d50cd62e7: "runtime: always use LoadLibraryEx to load system libraries" # a17d959debdb04cd550016a3501dd09d50cd62e7: "runtime: always use LoadLibraryEx to load system libraries"
# f0894a00f4b756d4b9b4078af2e686b359493583: "os: remove 5ms sleep on Windows in (*Process).Wait"
# sepical fix:
# - os.RemoveAll not working on Windows7
- name: Revert Golang1.25 commit for Windows7/8 - name: Revert Golang1.25 commit for Windows7/8
if: ${{ matrix.jobs.goos == 'windows' && matrix.jobs.goversion == '' }} if: ${{ matrix.jobs.goos == 'windows' && matrix.jobs.goversion == '' }}
run: | run: |

View File

@@ -24,6 +24,7 @@ jobs:
- 'ubuntu-24.04-arm' # arm64 linux - 'ubuntu-24.04-arm' # arm64 linux
- 'macos-15-intel' # amd64 macos - 'macos-15-intel' # amd64 macos
go-version: go-version:
- '1.26.0-rc.3'
- '1.25' - '1.25'
- '1.24' - '1.24'
- '1.23' - '1.23'
@@ -47,13 +48,20 @@ jobs:
uses: actions/setup-go@v6 uses: actions/setup-go@v6
with: with:
go-version: ${{ matrix.go-version }} go-version: ${{ matrix.go-version }}
check-latest: true # Always check for the latest patch release
- name: Revert Golang commit for Windows7/8 - name: Revert Golang commit for Windows7/8
if: ${{ runner.os == 'Windows' && matrix.go-version != '1.20' }} if: ${{ runner.os == 'Windows' && matrix.go-version != '1.20' && matrix.go-version != '1.26.0-rc.3' }}
run: | run: |
cd $(go env GOROOT) cd $(go env GOROOT)
patch --verbose -p 1 < $GITHUB_WORKSPACE/.github/patch/go${{matrix.go-version}}.patch patch --verbose -p 1 < $GITHUB_WORKSPACE/.github/patch/go${{matrix.go-version}}.patch
- name: Revert Golang commit for Windows7/8
if: ${{ runner.os == 'Windows' && matrix.go-version == '1.26.0-rc.3' }}
run: |
cd $(go env GOROOT)
patch --verbose -p 1 < $GITHUB_WORKSPACE/.github/patch/go1.26.patch
- name: Remove inbound test for macOS - name: Remove inbound test for macOS
if: ${{ runner.os == 'macOS' }} if: ${{ runner.os == 'macOS' }}
run: | run: |

View File

@@ -5,7 +5,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net" "net"
"net/http"
"net/url" "net/url"
"strings" "strings"
"time" "time"
@@ -17,6 +16,8 @@ import (
"github.com/metacubex/mihomo/component/ca" "github.com/metacubex/mihomo/component/ca"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/log"
"github.com/metacubex/http"
) )
var UnifiedDelay = atomic.NewBool(false) var UnifiedDelay = atomic.NewBool(false)
@@ -153,8 +154,9 @@ func (p *Proxy) MarshalJSON() ([]byte, error) {
mapping["mptcp"] = proxyInfo.MPTCP mapping["mptcp"] = proxyInfo.MPTCP
mapping["smux"] = proxyInfo.SMUX mapping["smux"] = proxyInfo.SMUX
mapping["interface"] = proxyInfo.Interface mapping["interface"] = proxyInfo.Interface
mapping["dialer-proxy"] = proxyInfo.DialerProxy
mapping["routing-mark"] = proxyInfo.RoutingMark mapping["routing-mark"] = proxyInfo.RoutingMark
mapping["provider-name"] = proxyInfo.ProviderName
mapping["dialer-proxy"] = proxyInfo.DialerProxy
return json.Marshal(mapping) return json.Marshal(mapping)
} }
@@ -177,14 +179,12 @@ func (p *Proxy) URLTest(ctx context.Context, url string, expectedStatus utils.In
p.history.Pop() p.history.Pop()
} }
state, ok := p.extra.Load(url) state, _ := p.extra.LoadOrStoreFn(url, func() *internalProxyState {
if !ok { return &internalProxyState{
state = &internalProxyState{
history: queue.New[C.DelayHistory](defaultHistoriesNum), history: queue.New[C.DelayHistory](defaultHistoriesNum),
alive: atomic.NewBool(true), alive: atomic.NewBool(true),
} }
p.extra.Store(url, state) })
}
if !satisfied { if !satisfied {
record.Delay = 0 record.Delay = 0

View File

@@ -2,9 +2,10 @@ package inbound
import ( import (
"net" "net"
"net/http"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/http"
) )
// NewHTTPS receive CONNECT request and return ConnContext // NewHTTPS receive CONNECT request and return ConnContext

View File

@@ -8,6 +8,7 @@ import (
"sync" "sync"
"github.com/metacubex/mihomo/component/keepalive" "github.com/metacubex/mihomo/component/keepalive"
"github.com/metacubex/mihomo/component/mptcp"
"github.com/metacubex/tfo-go" "github.com/metacubex/tfo-go"
) )
@@ -34,13 +35,13 @@ func Tfo() bool {
func SetMPTCP(open bool) { func SetMPTCP(open bool) {
mutex.Lock() mutex.Lock()
defer mutex.Unlock() defer mutex.Unlock()
setMultiPathTCP(&lc.ListenConfig, open) mptcp.SetNetListenConfig(&lc.ListenConfig, open)
} }
func MPTCP() bool { func MPTCP() bool {
mutex.RLock() mutex.RLock()
defer mutex.RUnlock() defer mutex.RUnlock()
return getMultiPathTCP(&lc.ListenConfig) return mptcp.GetNetListenConfig(&lc.ListenConfig)
} }
func preResolve(network, address string) (string, error) { func preResolve(network, address string) (string, error) {

View File

@@ -1,14 +0,0 @@
//go:build !go1.21
package inbound
import "net"
const multipathTCPAvailable = false
func setMultiPathTCP(listenConfig *net.ListenConfig, open bool) {
}
func getMultiPathTCP(listenConfig *net.ListenConfig) bool {
return false
}

View File

@@ -1,15 +0,0 @@
//go:build go1.21
package inbound
import "net"
const multipathTCPAvailable = true
func setMultiPathTCP(listenConfig *net.ListenConfig, open bool) {
listenConfig.SetMultipathTCP(open)
}
func getMultiPathTCP(listenConfig *net.ListenConfig) bool {
return listenConfig.MultipathTCP()
}

View File

@@ -2,12 +2,13 @@ package inbound
import ( import (
"net" "net"
"net/http"
"net/netip" "net/netip"
"strings" "strings"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/socks5" "github.com/metacubex/mihomo/transport/socks5"
"github.com/metacubex/http"
) )
func parseSocksAddr(target socks5.Addr) *C.Metadata { func parseSocksAddr(target socks5.Addr) *C.Metadata {

View File

@@ -6,8 +6,7 @@ import (
"strconv" "strconv"
"time" "time"
CN "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer" "github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/anytls" "github.com/metacubex/mihomo/transport/anytls"
@@ -20,7 +19,6 @@ import (
type AnyTLS struct { type AnyTLS struct {
*Base *Base
client *anytls.Client client *anytls.Client
dialer proxydialer.SingDialer
option *AnyTLSOption option *AnyTLSOption
} }
@@ -65,7 +63,7 @@ func (t *AnyTLS) ListenPacketContext(ctx context.Context, metadata *C.Metadata)
// create uot on tcp // create uot on tcp
destination := M.SocksaddrFromNet(metadata.UDPAddr()) destination := M.SocksaddrFromNet(metadata.UDPAddr())
return newPacketConn(CN.NewThreadSafePacketConn(uot.NewLazyConn(c, uot.Request{Destination: destination})), t), nil return newPacketConn(N.NewThreadSafePacketConn(uot.NewLazyConn(c, uot.Request{Destination: destination})), t), nil
} }
// SupportUOT implements C.ProxyAdapter // SupportUOT implements C.ProxyAdapter
@@ -92,18 +90,18 @@ func NewAnyTLS(option AnyTLSOption) (*AnyTLS, error) {
name: option.Name, name: option.Name,
addr: addr, addr: addr,
tp: C.AnyTLS, tp: C.AnyTLS,
pdName: option.ProviderName,
udp: option.UDP, udp: option.UDP,
tfo: option.TFO, tfo: option.TFO,
mpTcp: option.MPTCP, mpTcp: option.MPTCP,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion), prefer: option.IPVersion,
}, },
option: &option, option: &option,
} }
outbound.dialer = option.NewDialer(outbound.DialOptions())
singDialer := proxydialer.NewByNameSingDialer(option.DialerProxy, dialer.NewDialer(outbound.DialOptions()...)) singDialer := proxydialer.NewSingDialer(outbound.dialer)
outbound.dialer = singDialer
tOption := anytls.ClientConfig{ tOption := anytls.ClientConfig{
Password: option.Password, Password: option.Password,

View File

@@ -12,6 +12,7 @@ import (
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
"github.com/metacubex/mihomo/component/resolver" "github.com/metacubex/mihomo/component/resolver"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/log"
@@ -26,15 +27,17 @@ type ProxyAdapter interface {
type Base struct { type Base struct {
name string name string
addr string addr string
iface string
tp C.AdapterType tp C.AdapterType
pdName string
udp bool udp bool
xudp bool xudp bool
tfo bool tfo bool
mpTcp bool mpTcp bool
iface string
rmark int rmark int
id string
prefer C.DNSPrefer prefer C.DNSPrefer
dialer C.Dialer
id string
} }
// Name implements C.ProxyAdapter // Name implements C.ProxyAdapter
@@ -56,35 +59,15 @@ func (b *Base) Type() C.AdapterType {
return b.tp return b.tp
} }
// StreamConnContext implements C.ProxyAdapter
func (b *Base) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) {
return c, C.ErrNotSupport
}
func (b *Base) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { func (b *Base) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
return nil, C.ErrNotSupport return nil, C.ErrNotSupport
} }
// DialContextWithDialer implements C.ProxyAdapter
func (b *Base) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
return nil, C.ErrNotSupport
}
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (b *Base) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { func (b *Base) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) {
return nil, C.ErrNotSupport return nil, C.ErrNotSupport
} }
// ListenPacketWithDialer implements C.ProxyAdapter
func (b *Base) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) {
return nil, C.ErrNotSupport
}
// SupportWithDialer implements C.ProxyAdapter
func (b *Base) SupportWithDialer() C.NetWork {
return C.InvalidNet
}
// SupportUOT implements C.ProxyAdapter // SupportUOT implements C.ProxyAdapter
func (b *Base) SupportUOT() bool { func (b *Base) SupportUOT() bool {
return false return false
@@ -103,6 +86,7 @@ func (b *Base) ProxyInfo() (info C.ProxyInfo) {
info.SMUX = false info.SMUX = false
info.Interface = b.iface info.Interface = b.iface
info.RoutingMark = b.rmark info.RoutingMark = b.rmark
info.ProviderName = b.pdName
return return
} }
@@ -178,12 +162,30 @@ func (b *Base) Close() error {
} }
type BasicOption struct { type BasicOption struct {
TFO bool `proxy:"tfo,omitempty"` TFO bool `proxy:"tfo,omitempty"`
MPTCP bool `proxy:"mptcp,omitempty"` MPTCP bool `proxy:"mptcp,omitempty"`
Interface string `proxy:"interface-name,omitempty"` Interface string `proxy:"interface-name,omitempty"`
RoutingMark int `proxy:"routing-mark,omitempty"` RoutingMark int `proxy:"routing-mark,omitempty"`
IPVersion string `proxy:"ip-version,omitempty"` IPVersion C.DNSPrefer `proxy:"ip-version,omitempty"`
DialerProxy string `proxy:"dialer-proxy,omitempty"` // don't apply this option into groups, but can set a group name in a proxy DialerProxy string `proxy:"dialer-proxy,omitempty"` // don't apply this option into groups, but can set a group name in a proxy
//
// The following parameters are used internally, assign value by the structure decoder are disallowed
//
DialerForAPI C.Dialer `proxy:"-"` // the dialer used for API usage has higher priority than all the above configurations.
ProviderName string `proxy:"-"`
}
func (b *BasicOption) NewDialer(opts []dialer.Option) C.Dialer {
cDialer := b.DialerForAPI
if cDialer == nil {
if b.DialerProxy != "" {
cDialer = proxydialer.NewByName(b.DialerProxy)
} else {
cDialer = dialer.NewDialer(opts...)
}
}
return cDialer
} }
type BaseOption struct { type BaseOption struct {
@@ -217,6 +219,7 @@ func NewBase(opt BaseOption) *Base {
type conn struct { type conn struct {
N.ExtendedConn N.ExtendedConn
chain C.Chain chain C.Chain
pdChain C.Chain
adapterAddr string adapterAddr string
} }
@@ -238,9 +241,15 @@ func (c *conn) Chains() C.Chain {
return c.chain return c.chain
} }
// ProviderChains implements C.Connection
func (c *conn) ProviderChains() C.Chain {
return c.pdChain
}
// AppendToChains implements C.Connection // AppendToChains implements C.Connection
func (c *conn) AppendToChains(a C.ProxyAdapter) { func (c *conn) AppendToChains(a C.ProxyAdapter) {
c.chain = append(c.chain, a.Name()) c.chain = append(c.chain, a.Name())
c.pdChain = append(c.pdChain, a.ProxyInfo().ProviderName)
} }
func (c *conn) Upstream() any { func (c *conn) Upstream() any {
@@ -263,7 +272,7 @@ func NewConn(c net.Conn, a C.ProxyAdapter) C.Conn {
if _, ok := c.(syscall.Conn); !ok { // exclusion system conn like *net.TCPConn if _, ok := c.(syscall.Conn); !ok { // exclusion system conn like *net.TCPConn
c = N.NewDeadlineConn(c) // most conn from outbound can't handle readDeadline correctly c = N.NewDeadlineConn(c) // most conn from outbound can't handle readDeadline correctly
} }
cc := &conn{N.NewExtendedConn(c), nil, a.Addr()} cc := &conn{N.NewExtendedConn(c), nil, nil, a.Addr()}
cc.AppendToChains(a) cc.AppendToChains(a)
return cc return cc
} }
@@ -271,6 +280,7 @@ func NewConn(c net.Conn, a C.ProxyAdapter) C.Conn {
type packetConn struct { type packetConn struct {
N.EnhancePacketConn N.EnhancePacketConn
chain C.Chain chain C.Chain
pdChain C.Chain
adapterName string adapterName string
connID string connID string
adapterAddr string adapterAddr string
@@ -291,9 +301,15 @@ func (c *packetConn) Chains() C.Chain {
return c.chain return c.chain
} }
// ProviderChains implements C.Connection
func (c *packetConn) ProviderChains() C.Chain {
return c.pdChain
}
// AppendToChains implements C.Connection // AppendToChains implements C.Connection
func (c *packetConn) AppendToChains(a C.ProxyAdapter) { func (c *packetConn) AppendToChains(a C.ProxyAdapter) {
c.chain = append(c.chain, a.Name()) c.chain = append(c.chain, a.Name())
c.pdChain = append(c.pdChain, a.ProxyInfo().ProviderName)
} }
func (c *packetConn) LocalAddr() net.Addr { func (c *packetConn) LocalAddr() net.Addr {
@@ -322,7 +338,7 @@ func newPacketConn(pc net.PacketConn, a ProxyAdapter) C.PacketConn {
if _, ok := pc.(syscall.Conn); !ok { // exclusion system conn like *net.UDPConn if _, ok := pc.(syscall.Conn); !ok { // exclusion system conn like *net.UDPConn
epc = N.NewDeadlineEnhancePacketConn(epc) // most conn from outbound can't handle readDeadline correctly epc = N.NewDeadlineEnhancePacketConn(epc) // most conn from outbound can't handle readDeadline correctly
} }
cpc := &packetConn{epc, nil, a.Name(), utils.NewUUIDV4().String(), a.Addr(), a.ResolveUDP} cpc := &packetConn{epc, nil, nil, a.Name(), utils.NewUUIDV4().String(), a.Addr(), a.ResolveUDP}
cpc.AppendToChains(a) cpc.AppendToChains(a)
return cpc return cpc
} }
@@ -348,17 +364,6 @@ func (p *autoCloseProxyAdapter) DialContext(ctx context.Context, metadata *C.Met
return c, nil return c, nil
} }
func (p *autoCloseProxyAdapter) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
c, err := p.ProxyAdapter.DialContextWithDialer(ctx, dialer, metadata)
if err != nil {
return nil, err
}
if c, ok := c.(AddRef); ok {
c.AddRef(p)
}
return c, nil
}
func (p *autoCloseProxyAdapter) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) { func (p *autoCloseProxyAdapter) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) {
pc, err := p.ProxyAdapter.ListenPacketContext(ctx, metadata) pc, err := p.ProxyAdapter.ListenPacketContext(ctx, metadata)
if err != nil { if err != nil {
@@ -370,17 +375,6 @@ func (p *autoCloseProxyAdapter) ListenPacketContext(ctx context.Context, metadat
return pc, nil return pc, nil
} }
func (p *autoCloseProxyAdapter) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) {
pc, err := p.ProxyAdapter.ListenPacketWithDialer(ctx, dialer, metadata)
if err != nil {
return nil, err
}
if pc, ok := pc.(AddRef); ok {
pc.AddRef(p)
}
return pc, nil
}
func (p *autoCloseProxyAdapter) Close() error { func (p *autoCloseProxyAdapter) Close() error {
p.closeOnce.Do(func() { p.closeOnce.Do(func() {
log.Debugln("Closing outdated proxy [%s]", p.Name()) log.Debugln("Closing outdated proxy [%s]", p.Name())

View File

@@ -69,12 +69,13 @@ func NewDirectWithOption(option DirectOption) *Direct {
Base: &Base{ Base: &Base{
name: option.Name, name: option.Name,
tp: C.Direct, tp: C.Direct,
pdName: option.ProviderName,
udp: true, udp: true,
tfo: option.TFO, tfo: option.TFO,
mpTcp: option.MPTCP, mpTcp: option.MPTCP,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion), prefer: option.IPVersion,
}, },
loopBack: loopback.NewDetector(), loopBack: loopback.NewDetector(),
} }

View File

@@ -158,12 +158,13 @@ func NewDnsWithOption(option DnsOption) *Dns {
Base: &Base{ Base: &Base{
name: option.Name, name: option.Name,
tp: C.Dns, tp: C.Dns,
pdName: option.ProviderName,
udp: true, udp: true,
tfo: option.TFO, tfo: option.TFO,
mpTcp: option.MPTCP, mpTcp: option.MPTCP,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion), prefer: option.IPVersion,
}, },
} }
} }

View File

@@ -12,6 +12,8 @@ import (
type ECHOptions struct { type ECHOptions struct {
Enable bool `proxy:"enable,omitempty" obfs:"enable,omitempty"` Enable bool `proxy:"enable,omitempty" obfs:"enable,omitempty"`
Config string `proxy:"config,omitempty" obfs:"config,omitempty"` Config string `proxy:"config,omitempty" obfs:"config,omitempty"`
QueryServerName string `proxy:"query-server-name,omitempty" obfs:"query-server-name,omitempty"`
} }
func (o ECHOptions) Parse() (*ech.Config, error) { func (o ECHOptions) Parse() (*ech.Config, error) {
@@ -29,6 +31,9 @@ func (o ECHOptions) Parse() (*ech.Config, error) {
} }
} else { } else {
echConfig.GetEncryptedClientHelloConfigList = func(ctx context.Context, serverName string) ([]byte, error) { echConfig.GetEncryptedClientHelloConfigList = func(ctx context.Context, serverName string) ([]byte, error) {
if o.QueryServerName != "" { // overrides the domain name used for ECH HTTPS record queries
serverName = o.QueryServerName
}
return resolver.ResolveECHWithResolver(ctx, serverName, resolver.ProxyServerHostResolver) return resolver.ResolveECHWithResolver(ctx, serverName, resolver.ProxyServerHostResolver)
} }
} }

View File

@@ -3,19 +3,18 @@ package outbound
import ( import (
"bufio" "bufio"
"context" "context"
"crypto/tls"
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/http"
"strconv" "strconv"
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/ca" "github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/http"
"github.com/metacubex/tls"
) )
type Http struct { type Http struct {
@@ -61,18 +60,7 @@ func (h *Http) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Me
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (h *Http) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) { func (h *Http) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) {
return h.DialContextWithDialer(ctx, dialer.NewDialer(h.DialOptions()...), metadata) c, err := h.dialer.DialContext(ctx, "tcp", h.addr)
}
// DialContextWithDialer implements C.ProxyAdapter
func (h *Http) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
if len(h.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(h.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
c, err := dialer.DialContext(ctx, "tcp", h.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", h.addr, err) return nil, fmt.Errorf("%s connect error: %w", h.addr, err)
} }
@@ -89,11 +77,6 @@ func (h *Http) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metad
return NewConn(c, h), nil return NewConn(c, h), nil
} }
// SupportWithDialer implements C.ProxyAdapter
func (h *Http) SupportWithDialer() C.NetWork {
return C.TCP
}
// ProxyInfo implements C.ProxyAdapter // ProxyInfo implements C.ProxyAdapter
func (h *Http) ProxyInfo() C.ProxyInfo { func (h *Http) ProxyInfo() C.ProxyInfo {
info := h.Base.ProxyInfo() info := h.Base.ProxyInfo()
@@ -183,20 +166,23 @@ func NewHttp(option HttpOption) (*Http, error) {
} }
} }
return &Http{ outbound := &Http{
Base: &Base{ Base: &Base{
name: option.Name, name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)), addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.Http, tp: C.Http,
pdName: option.ProviderName,
tfo: option.TFO, tfo: option.TFO,
mpTcp: option.MPTCP, mpTcp: option.MPTCP,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion), prefer: option.IPVersion,
}, },
user: option.UserName, user: option.UserName,
pass: option.Password, pass: option.Password,
tlsConfig: tlsConfig, tlsConfig: tlsConfig,
option: &option, option: &option,
}, nil }
outbound.dialer = option.NewDialer(outbound.DialOptions())
return outbound, nil
} }

View File

@@ -2,7 +2,6 @@ package outbound
import ( import (
"context" "context"
"crypto/tls"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"net" "net"
@@ -13,8 +12,6 @@ import (
"github.com/metacubex/mihomo/component/ca" "github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/ech" "github.com/metacubex/mihomo/component/ech"
"github.com/metacubex/mihomo/component/proxydialer"
tlsC "github.com/metacubex/mihomo/component/tls"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/log"
hyCongestion "github.com/metacubex/mihomo/transport/hysteria/congestion" hyCongestion "github.com/metacubex/mihomo/transport/hysteria/congestion"
@@ -24,6 +21,8 @@ import (
"github.com/metacubex/mihomo/transport/hysteria/transport" "github.com/metacubex/mihomo/transport/hysteria/transport"
"github.com/metacubex/mihomo/transport/hysteria/utils" "github.com/metacubex/mihomo/transport/hysteria/utils"
"github.com/metacubex/tls"
"github.com/metacubex/quic-go" "github.com/metacubex/quic-go"
"github.com/metacubex/quic-go/congestion" "github.com/metacubex/quic-go/congestion"
M "github.com/metacubex/sing/common/metadata" M "github.com/metacubex/sing/common/metadata"
@@ -46,7 +45,7 @@ type Hysteria struct {
option *HysteriaOption option *HysteriaOption
client *core.Client client *core.Client
tlsConfig *tlsC.Config tlsConfig *tls.Config
echConfig *ech.Config echConfig *ech.Config
} }
@@ -74,16 +73,8 @@ func (h *Hysteria) genHdc(ctx context.Context) utils.PacketDialer {
return &hyDialerWithContext{ return &hyDialerWithContext{
ctx: context.Background(), ctx: context.Background(),
hyDialer: func(network string, rAddr net.Addr) (net.PacketConn, error) { hyDialer: func(network string, rAddr net.Addr) (net.PacketConn, error) {
var err error
var cDialer C.Dialer = dialer.NewDialer(h.DialOptions()...)
if len(h.option.DialerProxy) > 0 {
cDialer, err = proxydialer.NewByName(h.option.DialerProxy, cDialer)
if err != nil {
return nil, err
}
}
rAddrPort, _ := netip.ParseAddrPort(rAddr.String()) rAddrPort, _ := netip.ParseAddrPort(rAddr.String())
return cDialer.ListenPacket(ctx, network, "", rAddrPort) return h.dialer.ListenPacket(ctx, network, "", rAddrPort)
}, },
remoteAddr: func(addr string) (net.Addr, error) { remoteAddr: func(addr string) (net.Addr, error) {
udpAddr, err := resolveUDPAddr(ctx, "udp", addr, h.prefer) udpAddr, err := resolveUDPAddr(ctx, "udp", addr, h.prefer)
@@ -184,7 +175,7 @@ func NewHysteria(option HysteriaOption) (*Hysteria, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
tlsClientConfig := tlsC.UConfig(tlsConfig) tlsClientConfig := tlsConfig
quicConfig := &quic.Config{ quicConfig := &quic.Config{
InitialStreamReceiveWindow: uint64(option.ReceiveWindowConn), InitialStreamReceiveWindow: uint64(option.ReceiveWindowConn),
@@ -252,17 +243,19 @@ func NewHysteria(option HysteriaOption) (*Hysteria, error) {
name: option.Name, name: option.Name,
addr: addr, addr: addr,
tp: C.Hysteria, tp: C.Hysteria,
pdName: option.ProviderName,
udp: true, udp: true,
tfo: option.FastOpen, tfo: option.FastOpen,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion), prefer: option.IPVersion,
}, },
option: &option, option: &option,
client: client, client: client,
tlsConfig: tlsClientConfig, tlsConfig: tlsClientConfig,
echConfig: echConfig, echConfig: echConfig,
} }
outbound.dialer = option.NewDialer(outbound.DialOptions())
return outbound, nil return outbound, nil
} }

View File

@@ -2,19 +2,16 @@ package outbound
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"net" "net"
"strconv" "strconv"
"time" "time"
CN "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/component/ca" "github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer" "github.com/metacubex/mihomo/component/proxydialer"
tlsC "github.com/metacubex/mihomo/component/tls"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/log"
tuicCommon "github.com/metacubex/mihomo/transport/tuic/common" tuicCommon "github.com/metacubex/mihomo/transport/tuic/common"
@@ -22,6 +19,7 @@ import (
"github.com/metacubex/quic-go" "github.com/metacubex/quic-go"
"github.com/metacubex/sing-quic/hysteria2" "github.com/metacubex/sing-quic/hysteria2"
M "github.com/metacubex/sing/common/metadata" M "github.com/metacubex/sing/common/metadata"
"github.com/metacubex/tls"
) )
func init() { func init() {
@@ -36,7 +34,6 @@ type Hysteria2 struct {
option *Hysteria2Option option *Hysteria2Option
client *hysteria2.Client client *hysteria2.Client
dialer proxydialer.SingDialer
} }
type Hysteria2Option struct { type Hysteria2Option struct {
@@ -87,7 +84,7 @@ func (h *Hysteria2) ListenPacketContext(ctx context.Context, metadata *C.Metadat
if pc == nil { if pc == nil {
return nil, errors.New("packetConn is nil") return nil, errors.New("packetConn is nil")
} }
return newPacketConn(CN.NewThreadSafePacketConn(pc), h), nil return newPacketConn(N.NewThreadSafePacketConn(pc), h), nil
} }
// Close implements C.ProxyAdapter // Close implements C.ProxyAdapter
@@ -112,16 +109,16 @@ func NewHysteria2(option Hysteria2Option) (*Hysteria2, error) {
name: option.Name, name: option.Name,
addr: addr, addr: addr,
tp: C.Hysteria2, tp: C.Hysteria2,
pdName: option.ProviderName,
udp: true, udp: true,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion), prefer: option.IPVersion,
}, },
option: &option, option: &option,
} }
outbound.dialer = option.NewDialer(outbound.DialOptions())
singDialer := proxydialer.NewByNameSingDialer(option.DialerProxy, dialer.NewDialer(outbound.DialOptions()...)) singDialer := proxydialer.NewSingDialer(outbound.dialer)
outbound.dialer = singDialer
var salamanderPassword string var salamanderPassword string
if len(option.Obfs) > 0 { if len(option.Obfs) > 0 {
@@ -159,7 +156,7 @@ func NewHysteria2(option Hysteria2Option) (*Hysteria2, error) {
tlsConfig.NextProtos = option.ALPN tlsConfig.NextProtos = option.ALPN
} }
tlsClientConfig := tlsC.UConfig(tlsConfig) tlsClientConfig := tlsConfig
echConfig, err := option.ECHOpts.Parse() echConfig, err := option.ECHOpts.Parse()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -192,7 +189,7 @@ func NewHysteria2(option Hysteria2Option) (*Hysteria2, error) {
CWND: option.CWND, CWND: option.CWND,
UdpMTU: option.UdpMTU, UdpMTU: option.UdpMTU,
ServerAddress: func(ctx context.Context) (*net.UDPAddr, error) { ServerAddress: func(ctx context.Context) (*net.UDPAddr, error) {
udpAddr, err := resolveUDPAddr(ctx, "udp", addr, C.NewDNSPrefer(option.IPVersion)) udpAddr, err := resolveUDPAddr(ctx, "udp", addr, option.IPVersion)
if err != nil { if err != nil {
return nil, err return nil, err
} }

397
adapter/outbound/masque.go Normal file
View File

@@ -0,0 +1,397 @@
package outbound
import (
"context"
"crypto/ecdsa"
"crypto/x509"
"encoding/base64"
"errors"
"fmt"
"net"
"net/netip"
"strconv"
"strings"
"sync"
"time"
"github.com/metacubex/mihomo/common/atomic"
"github.com/metacubex/mihomo/common/contextutils"
"github.com/metacubex/mihomo/common/pool"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/resolver"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/dns"
"github.com/metacubex/mihomo/log"
"github.com/metacubex/mihomo/transport/masque"
"github.com/metacubex/mihomo/transport/tuic/common"
connectip "github.com/metacubex/connect-ip-go"
"github.com/metacubex/quic-go"
wireguard "github.com/metacubex/sing-wireguard"
M "github.com/metacubex/sing/common/metadata"
"github.com/metacubex/tls"
)
type Masque struct {
*Base
tlsConfig *tls.Config
quicConfig *quic.Config
tunDevice wireguard.Device
resolver resolver.Resolver
uri string
runCtx context.Context
runCancel context.CancelFunc
runMutex sync.Mutex
running atomic.Bool
runDevice atomic.Bool
option MasqueOption
}
type MasqueOption struct {
BasicOption
Name string `proxy:"name"`
Server string `proxy:"server"`
Port int `proxy:"port"`
PrivateKey string `proxy:"private-key"`
PublicKey string `proxy:"public-key"`
Ip string `proxy:"ip,omitempty"`
Ipv6 string `proxy:"ipv6,omitempty"`
URI string `proxy:"uri,omitempty"`
SNI string `proxy:"sni,omitempty"`
MTU int `proxy:"mtu,omitempty"`
UDP bool `proxy:"udp,omitempty"`
CongestionController string `proxy:"congestion-controller,omitempty"`
CWND int `proxy:"cwnd,omitempty"`
RemoteDnsResolve bool `proxy:"remote-dns-resolve,omitempty"`
Dns []string `proxy:"dns,omitempty"`
}
func (option MasqueOption) Prefixes() ([]netip.Prefix, error) {
localPrefixes := make([]netip.Prefix, 0, 2)
if len(option.Ip) > 0 {
if !strings.Contains(option.Ip, "/") {
option.Ip = option.Ip + "/32"
}
if prefix, err := netip.ParsePrefix(option.Ip); err == nil {
localPrefixes = append(localPrefixes, prefix)
} else {
return nil, fmt.Errorf("ip address parse error: %w", err)
}
}
if len(option.Ipv6) > 0 {
if !strings.Contains(option.Ipv6, "/") {
option.Ipv6 = option.Ipv6 + "/128"
}
if prefix, err := netip.ParsePrefix(option.Ipv6); err == nil {
localPrefixes = append(localPrefixes, prefix)
} else {
return nil, fmt.Errorf("ipv6 address parse error: %w", err)
}
}
if len(localPrefixes) == 0 {
return nil, errors.New("missing local address")
}
return localPrefixes, nil
}
func NewMasque(option MasqueOption) (*Masque, error) {
outbound := &Masque{
Base: &Base{
name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.Masque,
pdName: option.ProviderName,
udp: option.UDP,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: option.IPVersion,
},
}
outbound.dialer = option.NewDialer(outbound.DialOptions())
ctx, cancel := context.WithCancel(context.Background())
outbound.runCtx = ctx
outbound.runCancel = cancel
privKeyB64, err := base64.StdEncoding.DecodeString(option.PrivateKey)
if err != nil {
return nil, fmt.Errorf("failed to decode private key: %v", err)
}
privKey, err := x509.ParseECPrivateKey(privKeyB64)
if err != nil {
return nil, fmt.Errorf("failed to parse private key: %v", err)
}
endpointPubKeyB64, err := base64.StdEncoding.DecodeString(option.PublicKey)
if err != nil {
return nil, fmt.Errorf("failed to decode public key: %v", err)
}
pubKey, err := x509.ParsePKIXPublicKey(endpointPubKeyB64)
if err != nil {
return nil, fmt.Errorf("failed to parse public key: %v", err)
}
ecPubKey, ok := pubKey.(*ecdsa.PublicKey)
if !ok {
return nil, fmt.Errorf("failed to assert public key as ECDSA")
}
uri := option.URI
if uri == "" {
uri = masque.ConnectURI
}
outbound.uri = uri
sni := option.SNI
if sni == "" {
sni = masque.ConnectSNI
}
tlsConfig, err := masque.PrepareTlsConfig(privKey, ecPubKey, sni)
if err != nil {
return nil, fmt.Errorf("failed to prepare TLS config: %v\n", err)
}
outbound.tlsConfig = tlsConfig
outbound.quicConfig = &quic.Config{
EnableDatagrams: true,
InitialPacketSize: 1242,
KeepAlivePeriod: 30 * time.Second,
}
prefixes, err := option.Prefixes()
if err != nil {
return nil, err
}
outbound.option = option
mtu := option.MTU
if mtu == 0 {
mtu = 1280
}
if len(prefixes) == 0 {
return nil, errors.New("missing local address")
}
outbound.tunDevice, err = wireguard.NewStackDevice(prefixes, uint32(mtu))
if err != nil {
return nil, fmt.Errorf("create device: %w", err)
}
var has6 bool
for _, address := range prefixes {
if !address.Addr().Unmap().Is4() {
has6 = true
break
}
}
if option.RemoteDnsResolve && len(option.Dns) > 0 {
nss, err := dns.ParseNameServer(option.Dns)
if err != nil {
return nil, err
}
for i := range nss {
nss[i].ProxyAdapter = outbound
}
outbound.resolver = dns.NewResolver(dns.Config{
Main: nss,
IPv6: has6,
})
}
return outbound, nil
}
func (w *Masque) run(ctx context.Context) error {
if w.running.Load() {
return nil
}
w.runMutex.Lock()
defer w.runMutex.Unlock()
// double-check like sync.Once
if w.running.Load() {
return nil
}
if w.runCtx.Err() != nil {
return w.runCtx.Err()
}
if !w.runDevice.Load() {
err := w.tunDevice.Start()
if err != nil {
return err
}
w.runDevice.Store(true)
}
udpAddr, err := resolveUDPAddr(ctx, "udp", w.addr, w.prefer)
if err != nil {
return err
}
pc, err := w.dialer.ListenPacket(ctx, "udp", "", udpAddr.AddrPort())
if err != nil {
return err
}
quicConn, err := quic.Dial(ctx, pc, udpAddr, w.tlsConfig, w.quicConfig)
if err != nil {
return err
}
common.SetCongestionController(quicConn, w.option.CongestionController, w.option.CWND)
tr, ipConn, err := masque.ConnectTunnel(ctx, quicConn, w.uri)
if err != nil {
_ = pc.Close()
return err
}
w.running.Store(true)
runCtx, runCancel := context.WithCancel(w.runCtx)
contextutils.AfterFunc(runCtx, func() {
w.running.Store(false)
_ = ipConn.Close()
_ = tr.Close()
_ = pc.Close()
})
go func() {
defer runCancel()
buf := pool.Get(pool.UDPBufferSize)
defer pool.Put(buf)
bufs := [][]byte{buf}
sizes := []int{0}
for runCtx.Err() == nil {
_, err := w.tunDevice.Read(bufs, sizes, 0)
if err != nil {
log.Errorln("[Masque](%s) error reading from TUN device: %v", w.name, err)
return
}
icmp, err := ipConn.WritePacket(buf[:sizes[0]])
if err != nil {
if errors.As(err, new(*connectip.CloseError)) {
log.Errorln("[Masque](%s) connection closed while writing to IP connection: %v", w.name, err)
return
}
log.Warnln("[Masque](%s) error writing to IP connection: %v, continuing...", w.name, err)
continue
}
if len(icmp) > 0 {
if _, err := w.tunDevice.Write([][]byte{icmp}, 0); err != nil {
log.Warnln("[Masque](%s) error writing ICMP to TUN device: %v, continuing...", w.name, err)
}
}
}
}()
go func() {
defer runCancel()
buf := pool.Get(pool.UDPBufferSize)
defer pool.Put(buf)
for runCtx.Err() == nil {
n, err := ipConn.ReadPacket(buf)
if err != nil {
if errors.As(err, new(*connectip.CloseError)) {
log.Errorln("[Masque](%s) connection closed while writing to IP connection: %v", w.name, err)
return
}
log.Warnln("[Masque](%s) error reading from IP connection: %v, continuing...", w.name, err)
continue
}
if _, err := w.tunDevice.Write([][]byte{buf[:n]}, 0); err != nil {
log.Errorln("[Masque](%s) error writing to TUN device: %v", w.name, err)
return
}
}
}()
return nil
}
// Close implements C.ProxyAdapter
func (w *Masque) Close() error {
w.runCancel()
if w.tunDevice != nil {
w.tunDevice.Close()
}
return nil
}
func (w *Masque) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) {
var conn net.Conn
if err = w.run(ctx); err != nil {
return nil, err
}
if !metadata.Resolved() || w.resolver != nil {
r := resolver.DefaultResolver
if w.resolver != nil {
r = w.resolver
}
options := w.DialOptions()
options = append(options, dialer.WithResolver(r))
options = append(options, dialer.WithNetDialer(wgNetDialer{tunDevice: w.tunDevice}))
conn, err = dialer.NewDialer(options...).DialContext(ctx, "tcp", metadata.RemoteAddress())
} else {
conn, err = w.tunDevice.DialContext(ctx, "tcp", M.SocksaddrFrom(metadata.DstIP, metadata.DstPort).Unwrap())
}
if err != nil {
return nil, err
}
if conn == nil {
return nil, errors.New("conn is nil")
}
return NewConn(conn, w), nil
}
func (w *Masque) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) {
var pc net.PacketConn
if err = w.run(ctx); err != nil {
return nil, err
}
if err = w.ResolveUDP(ctx, metadata); err != nil {
return nil, err
}
pc, err = w.tunDevice.ListenPacket(ctx, M.SocksaddrFrom(metadata.DstIP, metadata.DstPort).Unwrap())
if err != nil {
return nil, err
}
if pc == nil {
return nil, errors.New("packetConn is nil")
}
return newPacketConn(pc, w), nil
}
func (w *Masque) ResolveUDP(ctx context.Context, metadata *C.Metadata) error {
if (!metadata.Resolved() || w.resolver != nil) && metadata.Host != "" {
r := resolver.DefaultResolver
if w.resolver != nil {
r = w.resolver
}
ip, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, r)
if err != nil {
return fmt.Errorf("can't resolve ip: %w", err)
}
metadata.DstIP = ip
}
return nil
}
// ProxyInfo implements C.ProxyAdapter
func (w *Masque) ProxyInfo() C.ProxyInfo {
info := w.Base.ProxyInfo()
info.DialerProxy = w.option.DialerProxy
return info
}
// IsL3Protocol implements C.ProxyAdapter
func (w *Masque) IsL3Protocol(metadata *C.Metadata) bool {
return true
}

View File

@@ -8,9 +8,7 @@ import (
"strconv" "strconv"
"sync" "sync"
CN "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
"github.com/metacubex/mihomo/component/resolver" "github.com/metacubex/mihomo/component/resolver"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
@@ -106,7 +104,7 @@ func (m *Mieru) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (
if err != nil { if err != nil {
return nil, fmt.Errorf("dial to %s failed: %w", metadata.UDPAddr(), err) return nil, fmt.Errorf("dial to %s failed: %w", metadata.UDPAddr(), err)
} }
return newPacketConn(CN.NewThreadSafePacketConn(mierucommon.NewUDPAssociateWrapper(mierucommon.NewPacketOverStreamTunnel(c))), m), nil return newPacketConn(N.NewThreadSafePacketConn(mierucommon.NewUDPAssociateWrapper(mierucommon.NewPacketOverStreamTunnel(c))), m), nil
} }
// SupportUOT implements C.ProxyAdapter // SupportUOT implements C.ProxyAdapter
@@ -130,20 +128,12 @@ func (m *Mieru) ensureClientIsRunning() error {
} }
// Create a dialer and add it to the client config, before starting the client. // Create a dialer and add it to the client config, before starting the client.
var dialer C.Dialer = dialer.NewDialer(m.DialOptions()...)
var err error
if len(m.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(m.option.DialerProxy, dialer)
if err != nil {
return err
}
}
config, err := m.client.Load() config, err := m.client.Load()
if err != nil { if err != nil {
return err return err
} }
config.Dialer = dialer config.Dialer = m.dialer
config.PacketDialer = mieruPacketDialer{Dialer: dialer} config.PacketDialer = mieruPacketDialer{Dialer: m.dialer}
config.Resolver = mieruDNSResolver{prefer: m.prefer} config.Resolver = mieruDNSResolver{prefer: m.prefer}
if err := m.client.Store(config); err != nil { if err := m.client.Store(config); err != nil {
return err return err
@@ -177,16 +167,18 @@ func NewMieru(option MieruOption) (*Mieru, error) {
Base: &Base{ Base: &Base{
name: option.Name, name: option.Name,
addr: addr, addr: addr,
iface: option.Interface,
tp: C.Mieru, tp: C.Mieru,
pdName: option.ProviderName,
udp: option.UDP, udp: option.UDP,
xudp: false, xudp: false,
iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion), prefer: option.IPVersion,
}, },
option: &option, option: &option,
client: c, client: c,
} }
outbound.dialer = option.NewDialer(outbound.DialOptions())
return outbound, nil return outbound, nil
} }

View File

@@ -17,6 +17,7 @@ type Reject struct {
} }
type RejectOption struct { type RejectOption struct {
BasicOption
Name string `proxy:"name"` Name string `proxy:"name"`
} }

View File

@@ -8,8 +8,6 @@ import (
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/structure" "github.com/metacubex/mihomo/common/structure"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/ntp" "github.com/metacubex/mihomo/ntp"
gost "github.com/metacubex/mihomo/transport/gost-plugin" gost "github.com/metacubex/mihomo/transport/gost-plugin"
@@ -116,6 +114,7 @@ type kcpTunOption struct {
AutoExpire int `obfs:"autoexpire,omitempty"` AutoExpire int `obfs:"autoexpire,omitempty"`
ScavengeTTL int `obfs:"scavengettl,omitempty"` ScavengeTTL int `obfs:"scavengettl,omitempty"`
MTU int `obfs:"mtu,omitempty"` MTU int `obfs:"mtu,omitempty"`
RateLimit int `obfs:"ratelimit,omitempty"`
SndWnd int `obfs:"sndwnd,omitempty"` SndWnd int `obfs:"sndwnd,omitempty"`
RcvWnd int `obfs:"rcvwnd,omitempty"` RcvWnd int `obfs:"rcvwnd,omitempty"`
DataShard int `obfs:"datashard,omitempty"` DataShard int `obfs:"datashard,omitempty"`
@@ -130,6 +129,7 @@ type kcpTunOption struct {
SockBuf int `obfs:"sockbuf,omitempty"` SockBuf int `obfs:"sockbuf,omitempty"`
SmuxVer int `obfs:"smuxver,omitempty"` SmuxVer int `obfs:"smuxver,omitempty"`
SmuxBuf int `obfs:"smuxbuf,omitempty"` SmuxBuf int `obfs:"smuxbuf,omitempty"`
FrameSize int `obfs:"framesize,omitempty"`
StreamBuf int `obfs:"streambuf,omitempty"` StreamBuf int `obfs:"streambuf,omitempty"`
KeepAlive int `obfs:"keepalive,omitempty"` KeepAlive int `obfs:"keepalive,omitempty"`
} }
@@ -191,17 +191,6 @@ func (ss *ShadowSocks) StreamConnContext(ctx context.Context, c net.Conn, metada
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (ss *ShadowSocks) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) { func (ss *ShadowSocks) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) {
return ss.DialContextWithDialer(ctx, dialer.NewDialer(ss.DialOptions()...), metadata)
}
// DialContextWithDialer implements C.ProxyAdapter
func (ss *ShadowSocks) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
if len(ss.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(ss.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
var c net.Conn var c net.Conn
if ss.kcptunClient != nil { if ss.kcptunClient != nil {
c, err = ss.kcptunClient.OpenStream(ctx, func(ctx context.Context) (net.PacketConn, net.Addr, error) { c, err = ss.kcptunClient.OpenStream(ctx, func(ctx context.Context) (net.PacketConn, net.Addr, error) {
@@ -213,7 +202,7 @@ func (ss *ShadowSocks) DialContextWithDialer(ctx context.Context, dialer C.Diale
return nil, nil, err return nil, nil, err
} }
pc, err := dialer.ListenPacket(ctx, "udp", "", addr.AddrPort()) pc, err := ss.dialer.ListenPacket(ctx, "udp", "", addr.AddrPort())
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -221,7 +210,7 @@ func (ss *ShadowSocks) DialContextWithDialer(ctx context.Context, dialer C.Diale
return pc, addr, nil return pc, addr, nil
}) })
} else { } else {
c, err = dialer.DialContext(ctx, "tcp", ss.addr) c, err = ss.dialer.DialContext(ctx, "tcp", ss.addr)
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", ss.addr, err) return nil, fmt.Errorf("%s connect error: %w", ss.addr, err)
@@ -237,25 +226,14 @@ func (ss *ShadowSocks) DialContextWithDialer(ctx context.Context, dialer C.Diale
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (ss *ShadowSocks) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { func (ss *ShadowSocks) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) {
return ss.ListenPacketWithDialer(ctx, dialer.NewDialer(ss.DialOptions()...), metadata)
}
// ListenPacketWithDialer implements C.ProxyAdapter
func (ss *ShadowSocks) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) {
if ss.option.UDPOverTCP { if ss.option.UDPOverTCP {
tcpConn, err := ss.DialContextWithDialer(ctx, dialer, metadata) tcpConn, err := ss.DialContext(ctx, metadata)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return ss.ListenPacketOnStreamConn(ctx, tcpConn, metadata) return ss.ListenPacketOnStreamConn(ctx, tcpConn, metadata)
} }
if len(ss.option.DialerProxy) > 0 { if err := ss.ResolveUDP(ctx, metadata); err != nil {
dialer, err = proxydialer.NewByName(ss.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
if err = ss.ResolveUDP(ctx, metadata); err != nil {
return nil, err return nil, err
} }
addr, err := resolveUDPAddr(ctx, "udp", ss.addr, ss.prefer) addr, err := resolveUDPAddr(ctx, "udp", ss.addr, ss.prefer)
@@ -263,7 +241,7 @@ func (ss *ShadowSocks) ListenPacketWithDialer(ctx context.Context, dialer C.Dial
return nil, err return nil, err
} }
pc, err := dialer.ListenPacket(ctx, "udp", "", addr.AddrPort()) pc, err := ss.dialer.ListenPacket(ctx, "udp", "", addr.AddrPort())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -271,11 +249,6 @@ func (ss *ShadowSocks) ListenPacketWithDialer(ctx context.Context, dialer C.Dial
return newPacketConn(pc, ss), nil return newPacketConn(pc, ss), nil
} }
// SupportWithDialer implements C.ProxyAdapter
func (ss *ShadowSocks) SupportWithDialer() C.NetWork {
return C.ALLNet
}
// ProxyInfo implements C.ProxyAdapter // ProxyInfo implements C.ProxyAdapter
func (ss *ShadowSocks) ProxyInfo() C.ProxyInfo { func (ss *ShadowSocks) ProxyInfo() C.ProxyInfo {
info := ss.Base.ProxyInfo() info := ss.Base.ProxyInfo()
@@ -455,6 +428,7 @@ func NewShadowSocks(option ShadowSocksOption) (*ShadowSocks, error) {
AutoExpire: kcptunOpt.AutoExpire, AutoExpire: kcptunOpt.AutoExpire,
ScavengeTTL: kcptunOpt.ScavengeTTL, ScavengeTTL: kcptunOpt.ScavengeTTL,
MTU: kcptunOpt.MTU, MTU: kcptunOpt.MTU,
RateLimit: kcptunOpt.RateLimit,
SndWnd: kcptunOpt.SndWnd, SndWnd: kcptunOpt.SndWnd,
RcvWnd: kcptunOpt.RcvWnd, RcvWnd: kcptunOpt.RcvWnd,
DataShard: kcptunOpt.DataShard, DataShard: kcptunOpt.DataShard,
@@ -469,6 +443,7 @@ func NewShadowSocks(option ShadowSocksOption) (*ShadowSocks, error) {
SockBuf: kcptunOpt.SockBuf, SockBuf: kcptunOpt.SockBuf,
SmuxVer: kcptunOpt.SmuxVer, SmuxVer: kcptunOpt.SmuxVer,
SmuxBuf: kcptunOpt.SmuxBuf, SmuxBuf: kcptunOpt.SmuxBuf,
FrameSize: kcptunOpt.FrameSize,
StreamBuf: kcptunOpt.StreamBuf, StreamBuf: kcptunOpt.StreamBuf,
KeepAlive: kcptunOpt.KeepAlive, KeepAlive: kcptunOpt.KeepAlive,
}) })
@@ -482,17 +457,18 @@ func NewShadowSocks(option ShadowSocksOption) (*ShadowSocks, error) {
return nil, fmt.Errorf("ss %s unknown udp over tcp protocol version: %d", addr, option.UDPOverTCPVersion) return nil, fmt.Errorf("ss %s unknown udp over tcp protocol version: %d", addr, option.UDPOverTCPVersion)
} }
return &ShadowSocks{ outbound := &ShadowSocks{
Base: &Base{ Base: &Base{
name: option.Name, name: option.Name,
addr: addr, addr: addr,
tp: C.Shadowsocks, tp: C.Shadowsocks,
pdName: option.ProviderName,
udp: option.UDP, udp: option.UDP,
tfo: option.TFO, tfo: option.TFO,
mpTcp: option.MPTCP, mpTcp: option.MPTCP,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion), prefer: option.IPVersion,
}, },
method: method, method: method,
@@ -504,5 +480,7 @@ func NewShadowSocks(option ShadowSocksOption) (*ShadowSocks, error) {
shadowTLSOption: shadowTLSOpt, shadowTLSOption: shadowTLSOpt,
restlsConfig: restlsConfig, restlsConfig: restlsConfig,
kcptunClient: kcptunClient, kcptunClient: kcptunClient,
}, nil }
outbound.dialer = option.NewDialer(outbound.DialOptions())
return outbound, nil
} }

View File

@@ -8,8 +8,6 @@ import (
"strconv" "strconv"
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/shadowsocks/core" "github.com/metacubex/mihomo/transport/shadowsocks/core"
"github.com/metacubex/mihomo/transport/shadowsocks/shadowaead" "github.com/metacubex/mihomo/transport/shadowsocks/shadowaead"
@@ -68,18 +66,7 @@ func (ssr *ShadowSocksR) StreamConnContext(ctx context.Context, c net.Conn, meta
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (ssr *ShadowSocksR) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) { func (ssr *ShadowSocksR) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) {
return ssr.DialContextWithDialer(ctx, dialer.NewDialer(ssr.DialOptions()...), metadata) c, err := ssr.dialer.DialContext(ctx, "tcp", ssr.addr)
}
// DialContextWithDialer implements C.ProxyAdapter
func (ssr *ShadowSocksR) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
if len(ssr.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(ssr.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
c, err := dialer.DialContext(ctx, "tcp", ssr.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", ssr.addr, err) return nil, fmt.Errorf("%s connect error: %w", ssr.addr, err)
} }
@@ -94,18 +81,7 @@ func (ssr *ShadowSocksR) DialContextWithDialer(ctx context.Context, dialer C.Dia
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (ssr *ShadowSocksR) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { func (ssr *ShadowSocksR) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) {
return ssr.ListenPacketWithDialer(ctx, dialer.NewDialer(ssr.DialOptions()...), metadata) if err := ssr.ResolveUDP(ctx, metadata); err != nil {
}
// ListenPacketWithDialer implements C.ProxyAdapter
func (ssr *ShadowSocksR) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) {
if len(ssr.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(ssr.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
if err = ssr.ResolveUDP(ctx, metadata); err != nil {
return nil, err return nil, err
} }
addr, err := resolveUDPAddr(ctx, "udp", ssr.addr, ssr.prefer) addr, err := resolveUDPAddr(ctx, "udp", ssr.addr, ssr.prefer)
@@ -113,7 +89,7 @@ func (ssr *ShadowSocksR) ListenPacketWithDialer(ctx context.Context, dialer C.Di
return nil, err return nil, err
} }
pc, err := dialer.ListenPacket(ctx, "udp", "", addr.AddrPort()) pc, err := ssr.dialer.ListenPacket(ctx, "udp", "", addr.AddrPort())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -123,11 +99,6 @@ func (ssr *ShadowSocksR) ListenPacketWithDialer(ctx context.Context, dialer C.Di
return newPacketConn(&ssrPacketConn{EnhancePacketConn: epc, rAddr: addr}, ssr), nil return newPacketConn(&ssrPacketConn{EnhancePacketConn: epc, rAddr: addr}, ssr), nil
} }
// SupportWithDialer implements C.ProxyAdapter
func (ssr *ShadowSocksR) SupportWithDialer() C.NetWork {
return C.ALLNet
}
// ProxyInfo implements C.ProxyAdapter // ProxyInfo implements C.ProxyAdapter
func (ssr *ShadowSocksR) ProxyInfo() C.ProxyInfo { func (ssr *ShadowSocksR) ProxyInfo() C.ProxyInfo {
info := ssr.Base.ProxyInfo() info := ssr.Base.ProxyInfo()
@@ -186,23 +157,26 @@ func NewShadowSocksR(option ShadowSocksROption) (*ShadowSocksR, error) {
return nil, fmt.Errorf("ssr %s initialize protocol error: %w", addr, err) return nil, fmt.Errorf("ssr %s initialize protocol error: %w", addr, err)
} }
return &ShadowSocksR{ outbound := &ShadowSocksR{
Base: &Base{ Base: &Base{
name: option.Name, name: option.Name,
addr: addr, addr: addr,
tp: C.ShadowsocksR, tp: C.ShadowsocksR,
pdName: option.ProviderName,
udp: option.UDP, udp: option.UDP,
tfo: option.TFO, tfo: option.TFO,
mpTcp: option.MPTCP, mpTcp: option.MPTCP,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion), prefer: option.IPVersion,
}, },
option: &option, option: &option,
cipher: coreCiph, cipher: coreCiph,
obfs: obfs, obfs: obfs,
protocol: protocol, protocol: protocol,
}, nil }
outbound.dialer = option.NewDialer(outbound.DialOptions())
return outbound, nil
} }
type ssrPacketConn struct { type ssrPacketConn struct {

View File

@@ -3,8 +3,7 @@ package outbound
import ( import (
"context" "context"
CN "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer" "github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/log"
@@ -17,7 +16,6 @@ import (
type SingMux struct { type SingMux struct {
ProxyAdapter ProxyAdapter
client *mux.Client client *mux.Client
dialer proxydialer.SingDialer
onlyTcp bool onlyTcp bool
} }
@@ -61,7 +59,7 @@ func (s *SingMux) ListenPacketContext(ctx context.Context, metadata *C.Metadata)
if pc == nil { if pc == nil {
return nil, E.New("packetConn is nil") return nil, E.New("packetConn is nil")
} }
return newPacketConn(CN.NewThreadSafePacketConn(pc), s), nil return newPacketConn(N.NewThreadSafePacketConn(pc), s), nil
} }
func (s *SingMux) SupportUDP() bool { func (s *SingMux) SupportUDP() bool {
@@ -96,7 +94,7 @@ func NewSingMux(option SingMuxOption, proxy ProxyAdapter) (ProxyAdapter, error)
// TODO // TODO
// "TCP Brutal is only supported on Linux-based systems" // "TCP Brutal is only supported on Linux-based systems"
singDialer := proxydialer.NewSingDialer(proxy, dialer.NewDialer(proxy.DialOptions()...), option.Statistic) singDialer := proxydialer.NewSingDialer(proxydialer.New(proxy, option.Statistic))
client, err := mux.NewClient(mux.Options{ client, err := mux.NewClient(mux.Options{
Dialer: singDialer, Dialer: singDialer,
Logger: log.SingLogger, Logger: log.SingLogger,
@@ -117,7 +115,6 @@ func NewSingMux(option SingMuxOption, proxy ProxyAdapter) (ProxyAdapter, error)
outbound := &SingMux{ outbound := &SingMux{
ProxyAdapter: proxy, ProxyAdapter: proxy,
client: client, client: client,
dialer: singDialer,
onlyTcp: option.OnlyTcp, onlyTcp: option.OnlyTcp,
} }
return outbound, nil return outbound, nil

View File

@@ -8,8 +8,6 @@ import (
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/structure" "github.com/metacubex/mihomo/common/structure"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
obfs "github.com/metacubex/mihomo/transport/simple-obfs" obfs "github.com/metacubex/mihomo/transport/simple-obfs"
"github.com/metacubex/mihomo/transport/snell" "github.com/metacubex/mihomo/transport/snell"
@@ -89,18 +87,7 @@ func (s *Snell) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn
return NewConn(c, s), err return NewConn(c, s), err
} }
return s.DialContextWithDialer(ctx, dialer.NewDialer(s.DialOptions()...), metadata) c, err := s.dialer.DialContext(ctx, "tcp", s.addr)
}
// DialContextWithDialer implements C.ProxyAdapter
func (s *Snell) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
if len(s.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(s.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
c, err := dialer.DialContext(ctx, "tcp", s.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", s.addr, err) return nil, fmt.Errorf("%s connect error: %w", s.addr, err)
} }
@@ -115,22 +102,11 @@ func (s *Snell) DialContextWithDialer(ctx context.Context, dialer C.Dialer, meta
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (s *Snell) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { func (s *Snell) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) {
return s.ListenPacketWithDialer(ctx, dialer.NewDialer(s.DialOptions()...), metadata)
}
// ListenPacketWithDialer implements C.ProxyAdapter
func (s *Snell) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (C.PacketConn, error) {
var err error var err error
if len(s.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(s.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
if err = s.ResolveUDP(ctx, metadata); err != nil { if err = s.ResolveUDP(ctx, metadata); err != nil {
return nil, err return nil, err
} }
c, err := dialer.DialContext(ctx, "tcp", s.addr) c, err := s.dialer.DialContext(ctx, "tcp", s.addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -141,11 +117,6 @@ func (s *Snell) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, met
return newPacketConn(pc, s), nil return newPacketConn(pc, s), nil
} }
// SupportWithDialer implements C.ProxyAdapter
func (s *Snell) SupportWithDialer() C.NetWork {
return C.ALLNet
}
// SupportUOT implements C.ProxyAdapter // SupportUOT implements C.ProxyAdapter
func (s *Snell) SupportUOT() bool { func (s *Snell) SupportUOT() bool {
return true return true
@@ -194,30 +165,24 @@ func NewSnell(option SnellOption) (*Snell, error) {
name: option.Name, name: option.Name,
addr: addr, addr: addr,
tp: C.Snell, tp: C.Snell,
pdName: option.ProviderName,
udp: option.UDP, udp: option.UDP,
tfo: option.TFO, tfo: option.TFO,
mpTcp: option.MPTCP, mpTcp: option.MPTCP,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion), prefer: option.IPVersion,
}, },
option: &option, option: &option,
psk: psk, psk: psk,
obfsOption: obfsOption, obfsOption: obfsOption,
version: option.Version, version: option.Version,
} }
s.dialer = option.NewDialer(s.DialOptions())
if option.Version == snell.Version2 { if option.Version == snell.Version2 {
s.pool = snell.NewPool(func(ctx context.Context) (*snell.Snell, error) { s.pool = snell.NewPool(func(ctx context.Context) (*snell.Snell, error) {
var err error c, err := s.dialer.DialContext(ctx, "tcp", addr)
var cDialer C.Dialer = dialer.NewDialer(s.DialOptions()...)
if len(s.option.DialerProxy) > 0 {
cDialer, err = proxydialer.NewByName(s.option.DialerProxy, cDialer)
if err != nil {
return nil, err
}
}
c, err := cDialer.DialContext(ctx, "tcp", addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -2,7 +2,6 @@ package outbound
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -12,10 +11,10 @@ import (
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/ca" "github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/socks5" "github.com/metacubex/mihomo/transport/socks5"
"github.com/metacubex/tls"
) )
type Socks5 struct { type Socks5 struct {
@@ -69,18 +68,7 @@ func (ss *Socks5) StreamConnContext(ctx context.Context, c net.Conn, metadata *C
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (ss *Socks5) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) { func (ss *Socks5) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) {
return ss.DialContextWithDialer(ctx, dialer.NewDialer(ss.DialOptions()...), metadata) c, err := ss.dialer.DialContext(ctx, "tcp", ss.addr)
}
// DialContextWithDialer implements C.ProxyAdapter
func (ss *Socks5) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
if len(ss.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(ss.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
c, err := dialer.DialContext(ctx, "tcp", ss.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", ss.addr, err) return nil, fmt.Errorf("%s connect error: %w", ss.addr, err)
} }
@@ -97,24 +85,12 @@ func (ss *Socks5) DialContextWithDialer(ctx context.Context, dialer C.Dialer, me
return NewConn(c, ss), nil return NewConn(c, ss), nil
} }
// SupportWithDialer implements C.ProxyAdapter
func (ss *Socks5) SupportWithDialer() C.NetWork {
return C.TCP
}
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (ss *Socks5) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) { func (ss *Socks5) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) {
var cDialer C.Dialer = dialer.NewDialer(ss.DialOptions()...)
if len(ss.option.DialerProxy) > 0 {
cDialer, err = proxydialer.NewByName(ss.option.DialerProxy, cDialer)
if err != nil {
return nil, err
}
}
if err = ss.ResolveUDP(ctx, metadata); err != nil { if err = ss.ResolveUDP(ctx, metadata); err != nil {
return nil, err return nil, err
} }
c, err := cDialer.DialContext(ctx, "tcp", ss.addr) c, err := ss.dialer.DialContext(ctx, "tcp", ss.addr)
if err != nil { if err != nil {
err = fmt.Errorf("%s connect error: %w", ss.addr, err) err = fmt.Errorf("%s connect error: %w", ss.addr, err)
return return
@@ -161,7 +137,7 @@ func (ss *Socks5) ListenPacketContext(ctx context.Context, metadata *C.Metadata)
bindUDPAddr.IP = serverAddr.IP bindUDPAddr.IP = serverAddr.IP
} }
pc, err := cDialer.ListenPacket(ctx, "udp", "", bindUDPAddr.AddrPort()) pc, err := ss.dialer.ListenPacket(ctx, "udp", "", bindUDPAddr.AddrPort())
if err != nil { if err != nil {
return return
} }
@@ -210,17 +186,18 @@ func NewSocks5(option Socks5Option) (*Socks5, error) {
} }
} }
return &Socks5{ outbound := &Socks5{
Base: &Base{ Base: &Base{
name: option.Name, name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)), addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.Socks5, tp: C.Socks5,
pdName: option.ProviderName,
udp: option.UDP, udp: option.UDP,
tfo: option.TFO, tfo: option.TFO,
mpTcp: option.MPTCP, mpTcp: option.MPTCP,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion), prefer: option.IPVersion,
}, },
option: &option, option: &option,
user: option.UserName, user: option.UserName,
@@ -228,7 +205,9 @@ func NewSocks5(option Socks5Option) (*Socks5, error) {
tls: option.TLS, tls: option.TLS,
skipCertVerify: option.SkipCertVerify, skipCertVerify: option.SkipCertVerify,
tlsConfig: tlsConfig, tlsConfig: tlsConfig,
}, nil }
outbound.dialer = option.NewDialer(outbound.DialOptions())
return outbound, nil
} }
type socksPacketConn struct { type socksPacketConn struct {

View File

@@ -12,8 +12,6 @@ import (
"sync" "sync"
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/randv2" "github.com/metacubex/randv2"
@@ -44,14 +42,7 @@ type SshOption struct {
} }
func (s *Ssh) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) { func (s *Ssh) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) {
var cDialer C.Dialer = dialer.NewDialer(s.DialOptions()...) client, err := s.connect(ctx, s.addr)
if len(s.option.DialerProxy) > 0 {
cDialer, err = proxydialer.NewByName(s.option.DialerProxy, cDialer)
if err != nil {
return nil, err
}
}
client, err := s.connect(ctx, cDialer, s.addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -63,13 +54,13 @@ func (s *Ssh) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn,
return NewConn(c, s), nil return NewConn(c, s), nil
} }
func (s *Ssh) connect(ctx context.Context, cDialer C.Dialer, addr string) (client *ssh.Client, err error) { func (s *Ssh) connect(ctx context.Context, addr string) (client *ssh.Client, err error) {
s.cMutex.Lock() s.cMutex.Lock()
defer s.cMutex.Unlock() defer s.cMutex.Unlock()
if s.client != nil { if s.client != nil {
return s.client, nil return s.client, nil
} }
c, err := cDialer.DialContext(ctx, "tcp", addr) c, err := s.dialer.DialContext(ctx, "tcp", addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -195,14 +186,15 @@ func NewSsh(option SshOption) (*Ssh, error) {
name: option.Name, name: option.Name,
addr: addr, addr: addr,
tp: C.Ssh, tp: C.Ssh,
pdName: option.ProviderName,
udp: false, udp: false,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion), prefer: option.IPVersion,
}, },
option: &option, option: &option,
config: &config, config: &config,
} }
outbound.dialer = option.NewDialer(outbound.DialOptions())
return outbound, nil return outbound, nil
} }

View File

@@ -2,102 +2,111 @@ package outbound
import ( import (
"context" "context"
"crypto/sha256"
"encoding/binary"
"fmt" "fmt"
"io"
"net" "net"
"strconv" "strconv"
"strings" "strings"
"time" "sync"
"github.com/metacubex/mihomo/log"
"github.com/saba-futai/sudoku/apis"
"github.com/saba-futai/sudoku/pkg/crypto"
"github.com/saba-futai/sudoku/pkg/obfs/httpmask"
"github.com/saba-futai/sudoku/pkg/obfs/sudoku"
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/sudoku"
) )
type Sudoku struct { type Sudoku struct {
*Base *Base
option *SudokuOption option *SudokuOption
table *sudoku.Table baseConf sudoku.ProtocolConfig
baseConf apis.ProtocolConfig
httpMaskMu sync.Mutex
httpMaskClient *sudoku.HTTPMaskTunnelClient
muxMu sync.Mutex
muxClient *sudoku.MultiplexClient
} }
type SudokuOption struct { type SudokuOption struct {
BasicOption BasicOption
Name string `proxy:"name"` Name string `proxy:"name"`
Server string `proxy:"server"` Server string `proxy:"server"`
Port int `proxy:"port"` Port int `proxy:"port"`
Key string `proxy:"key"` Key string `proxy:"key"`
AEADMethod string `proxy:"aead-method,omitempty"` AEADMethod string `proxy:"aead-method,omitempty"`
PaddingMin *int `proxy:"padding-min,omitempty"` PaddingMin *int `proxy:"padding-min,omitempty"`
PaddingMax *int `proxy:"padding-max,omitempty"` PaddingMax *int `proxy:"padding-max,omitempty"`
TableType string `proxy:"table-type,omitempty"` // "prefer_ascii" or "prefer_entropy" TableType string `proxy:"table-type,omitempty"` // "prefer_ascii" or "prefer_entropy"
HTTPMask bool `proxy:"http-mask,omitempty"` EnablePureDownlink *bool `proxy:"enable-pure-downlink,omitempty"`
HTTPMask bool `proxy:"http-mask,omitempty"`
HTTPMaskMode string `proxy:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto"
HTTPMaskTLS bool `proxy:"http-mask-tls,omitempty"` // only for http-mask-mode stream/poll/auto
HTTPMaskHost string `proxy:"http-mask-host,omitempty"` // optional Host/SNI override (domain or domain:port)
PathRoot string `proxy:"path-root,omitempty"` // optional first-level path prefix for HTTP tunnel endpoints
HTTPMaskMultiplex string `proxy:"http-mask-multiplex,omitempty"` // "off" (default), "auto" (reuse h1/h2), "on" (single tunnel, multi-target)
CustomTable string `proxy:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv
CustomTables []string `proxy:"custom-tables,omitempty"` // optional table rotation patterns, overrides custom-table when non-empty
} }
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (s *Sudoku) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { func (s *Sudoku) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) {
return s.DialContextWithDialer(ctx, dialer.NewDialer(s.DialOptions()...), metadata)
}
// DialContextWithDialer implements C.ProxyAdapter
func (s *Sudoku) DialContextWithDialer(ctx context.Context, d C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
if len(s.option.DialerProxy) > 0 {
d, err = proxydialer.NewByName(s.option.DialerProxy, d)
if err != nil {
return nil, err
}
}
cfg, err := s.buildConfig(metadata) cfg, err := s.buildConfig(metadata)
if err != nil { if err != nil {
return nil, err return nil, err
} }
c, err := d.DialContext(ctx, "tcp", s.addr) muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex)
if err != nil { if muxMode == "on" && !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) {
return nil, fmt.Errorf("%s connect error: %w", s.addr, err) stream, muxErr := s.dialMultiplex(ctx, cfg.TargetAddress)
if muxErr == nil {
return NewConn(stream, s), nil
}
return nil, muxErr
} }
defer func() { c, err := s.dialAndHandshake(ctx, cfg)
safeConnClose(c, err)
}()
if ctx.Done() != nil {
done := N.SetupContextForConn(ctx, c)
defer done(&err)
}
c, err = s.streamConn(c, cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { safeConnClose(c, err) }()
addrBuf, err := sudoku.EncodeAddress(cfg.TargetAddress)
if err != nil {
return nil, fmt.Errorf("encode target address failed: %w", err)
}
if _, err = c.Write(addrBuf); err != nil {
return nil, fmt.Errorf("send target address failed: %w", err)
}
return NewConn(c, s), nil return NewConn(c, s), nil
} }
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (s *Sudoku) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { func (s *Sudoku) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) {
return nil, C.ErrNotSupport if err := s.ResolveUDP(ctx, metadata); err != nil {
return nil, err
}
cfg, err := s.buildConfig(metadata)
if err != nil {
return nil, err
}
c, err := s.dialAndHandshake(ctx, cfg)
if err != nil {
return nil, err
}
if err = sudoku.WritePreface(c); err != nil {
_ = c.Close()
return nil, fmt.Errorf("send uot preface failed: %w", err)
}
return newPacketConn(N.NewThreadSafePacketConn(sudoku.NewUoTPacketConn(c)), s), nil
} }
// SupportUOT implements C.ProxyAdapter // SupportUOT implements C.ProxyAdapter
func (s *Sudoku) SupportUOT() bool { func (s *Sudoku) SupportUOT() bool {
return false // Sudoku protocol only supports TCP return true
}
// SupportWithDialer implements C.ProxyAdapter
func (s *Sudoku) SupportWithDialer() C.NetWork {
return C.TCP
} }
// ProxyInfo implements C.ProxyAdapter // ProxyInfo implements C.ProxyAdapter
@@ -107,7 +116,7 @@ func (s *Sudoku) ProxyInfo() C.ProxyInfo {
return info return info
} }
func (s *Sudoku) buildConfig(metadata *C.Metadata) (*apis.ProtocolConfig, error) { func (s *Sudoku) buildConfig(metadata *C.Metadata) (*sudoku.ProtocolConfig, error) {
if metadata == nil || metadata.DstPort == 0 || !metadata.Valid() { if metadata == nil || metadata.DstPort == 0 || !metadata.Valid() {
return nil, fmt.Errorf("invalid metadata for sudoku outbound") return nil, fmt.Errorf("invalid metadata for sudoku outbound")
} }
@@ -121,33 +130,6 @@ func (s *Sudoku) buildConfig(metadata *C.Metadata) (*apis.ProtocolConfig, error)
return &cfg, nil return &cfg, nil
} }
func (s *Sudoku) streamConn(rawConn net.Conn, cfg *apis.ProtocolConfig) (_ net.Conn, err error) {
if !cfg.DisableHTTPMask {
if err = httpmask.WriteRandomRequestHeader(rawConn, cfg.ServerAddress); err != nil {
return nil, fmt.Errorf("write http mask failed: %w", err)
}
}
obfsConn := sudoku.NewConn(rawConn, cfg.Table, cfg.PaddingMin, cfg.PaddingMax, false)
cConn, err := crypto.NewAEADConn(obfsConn, cfg.Key, cfg.AEADMethod)
if err != nil {
return nil, fmt.Errorf("setup crypto failed: %w", err)
}
handshake := buildSudokuHandshakePayload(cfg.Key)
if _, err = cConn.Write(handshake[:]); err != nil {
cConn.Close()
return nil, fmt.Errorf("send handshake failed: %w", err)
}
if err = writeTargetAddress(cConn, cfg.TargetAddress); err != nil {
cConn.Close()
return nil, fmt.Errorf("send target address failed: %w", err)
}
return cConn, nil
}
func NewSudoku(option SudokuOption) (*Sudoku, error) { func NewSudoku(option SudokuOption) (*Sudoku, error) {
if option.Server == "" { if option.Server == "" {
return nil, fmt.Errorf("server is required") return nil, fmt.Errorf("server is required")
@@ -167,16 +149,7 @@ func NewSudoku(option SudokuOption) (*Sudoku, error) {
return nil, fmt.Errorf("table-type must be prefer_ascii or prefer_entropy") return nil, fmt.Errorf("table-type must be prefer_ascii or prefer_entropy")
} }
seed := option.Key defaultConf := sudoku.DefaultConfig()
if recoveredFromKey, err := crypto.RecoverPublicKey(option.Key); err == nil {
seed = crypto.EncodePoint(recoveredFromKey)
}
start := time.Now()
table := sudoku.NewTable(seed, tableType)
log.Infoln("[Sudoku] Tables initialized (%s) in %v", tableType, time.Since(start))
defaultConf := apis.DefaultConfig()
paddingMin := defaultConf.PaddingMin paddingMin := defaultConf.PaddingMin
paddingMax := defaultConf.PaddingMax paddingMax := defaultConf.PaddingMax
if option.PaddingMin != nil { if option.PaddingMin != nil {
@@ -191,80 +164,242 @@ func NewSudoku(option SudokuOption) (*Sudoku, error) {
if option.PaddingMax == nil && option.PaddingMin != nil && paddingMax < paddingMin { if option.PaddingMax == nil && option.PaddingMin != nil && paddingMax < paddingMin {
paddingMax = paddingMin paddingMax = paddingMin
} }
enablePureDownlink := defaultConf.EnablePureDownlink
if option.EnablePureDownlink != nil {
enablePureDownlink = *option.EnablePureDownlink
}
baseConf := apis.ProtocolConfig{ baseConf := sudoku.ProtocolConfig{
ServerAddress: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)), ServerAddress: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
Key: option.Key, Key: option.Key,
AEADMethod: defaultConf.AEADMethod, AEADMethod: defaultConf.AEADMethod,
Table: table,
PaddingMin: paddingMin, PaddingMin: paddingMin,
PaddingMax: paddingMax, PaddingMax: paddingMax,
EnablePureDownlink: enablePureDownlink,
HandshakeTimeoutSeconds: defaultConf.HandshakeTimeoutSeconds, HandshakeTimeoutSeconds: defaultConf.HandshakeTimeoutSeconds,
DisableHTTPMask: !option.HTTPMask, DisableHTTPMask: !option.HTTPMask,
HTTPMaskMode: defaultConf.HTTPMaskMode,
HTTPMaskTLSEnabled: option.HTTPMaskTLS,
HTTPMaskHost: option.HTTPMaskHost,
HTTPMaskPathRoot: strings.TrimSpace(option.PathRoot),
HTTPMaskMultiplex: defaultConf.HTTPMaskMultiplex,
}
if option.HTTPMaskMode != "" {
baseConf.HTTPMaskMode = option.HTTPMaskMode
}
if option.HTTPMaskMultiplex != "" {
baseConf.HTTPMaskMultiplex = option.HTTPMaskMultiplex
}
tables, err := sudoku.NewTablesWithCustomPatterns(sudoku.ClientAEADSeed(option.Key), tableType, option.CustomTable, option.CustomTables)
if err != nil {
return nil, fmt.Errorf("build table(s) failed: %w", err)
}
if len(tables) == 1 {
baseConf.Table = tables[0]
} else {
baseConf.Tables = tables
} }
if option.AEADMethod != "" { if option.AEADMethod != "" {
baseConf.AEADMethod = option.AEADMethod baseConf.AEADMethod = option.AEADMethod
} }
return &Sudoku{ outbound := &Sudoku{
Base: &Base{ Base: &Base{
name: option.Name, name: option.Name,
addr: baseConf.ServerAddress, addr: baseConf.ServerAddress,
tp: C.Sudoku, tp: C.Sudoku,
udp: false, pdName: option.ProviderName,
udp: true,
tfo: option.TFO, tfo: option.TFO,
mpTcp: option.MPTCP, mpTcp: option.MPTCP,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion), prefer: option.IPVersion,
}, },
option: &option, option: &option,
table: table,
baseConf: baseConf, baseConf: baseConf,
}, nil }
outbound.dialer = option.NewDialer(outbound.DialOptions())
return outbound, nil
} }
func buildSudokuHandshakePayload(key string) [16]byte { func (s *Sudoku) Close() error {
var payload [16]byte s.resetMuxClient()
binary.BigEndian.PutUint64(payload[:8], uint64(time.Now().Unix())) s.resetHTTPMaskClient()
hash := sha256.Sum256([]byte(key)) return s.Base.Close()
copy(payload[8:], hash[:8])
return payload
} }
func writeTargetAddress(w io.Writer, rawAddr string) error { func normalizeHTTPMaskMultiplex(mode string) string {
host, portStr, err := net.SplitHostPort(rawAddr) switch strings.ToLower(strings.TrimSpace(mode)) {
if err != nil { case "", "off":
return err return "off"
case "auto":
return "auto"
case "on":
return "on"
default:
return "off"
}
}
func httpTunnelModeEnabled(mode string) bool {
switch strings.ToLower(strings.TrimSpace(mode)) {
case "stream", "poll", "auto":
return true
default:
return false
}
}
func (s *Sudoku) dialAndHandshake(ctx context.Context, cfg *sudoku.ProtocolConfig) (_ net.Conn, err error) {
if cfg == nil {
return nil, fmt.Errorf("config is required")
} }
portInt, err := net.LookupPort("tcp", portStr) handshakeCfg := *cfg
if err != nil { if !handshakeCfg.DisableHTTPMask && httpTunnelModeEnabled(handshakeCfg.HTTPMaskMode) {
return err handshakeCfg.DisableHTTPMask = true
} }
var buf []byte upgrade := func(raw net.Conn) (net.Conn, error) {
if ip := net.ParseIP(host); ip != nil { return sudoku.ClientHandshake(raw, &handshakeCfg)
if ip4 := ip.To4(); ip4 != nil { }
buf = append(buf, 0x01) // IPv4
buf = append(buf, ip4...) var (
} else { c net.Conn
buf = append(buf, 0x04) // IPv6 handshakeDone bool
buf = append(buf, ip...) )
if !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) {
muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex)
switch muxMode {
case "auto", "on":
client, errX := s.getOrCreateHTTPMaskClient(cfg)
if errX != nil {
return nil, errX
}
c, err = client.Dial(ctx, upgrade)
default:
c, err = sudoku.DialHTTPMaskTunnel(ctx, cfg.ServerAddress, cfg, s.dialer.DialContext, upgrade)
} }
} else { if err == nil && c != nil {
if len(host) > 255 { handshakeDone = true
return fmt.Errorf("domain too long")
} }
buf = append(buf, 0x03) // domain }
buf = append(buf, byte(len(host))) if c == nil && err == nil {
buf = append(buf, host...) c, err = s.dialer.DialContext(ctx, "tcp", s.addr)
}
if err != nil {
return nil, fmt.Errorf("%s connect error: %w", s.addr, err)
} }
var portBytes [2]byte defer func() { safeConnClose(c, err) }()
binary.BigEndian.PutUint16(portBytes[:], uint16(portInt))
buf = append(buf, portBytes[:]...)
_, err = w.Write(buf) if ctx.Done() != nil {
return err done := N.SetupContextForConn(ctx, c)
defer done(&err)
}
if !handshakeDone {
c, err = sudoku.ClientHandshake(c, &handshakeCfg)
if err != nil {
return nil, err
}
}
return c, nil
}
func (s *Sudoku) dialMultiplex(ctx context.Context, targetAddress string) (net.Conn, error) {
for attempt := 0; attempt < 2; attempt++ {
client, err := s.getOrCreateMuxClient(ctx)
if err != nil {
return nil, err
}
stream, err := client.Dial(ctx, targetAddress)
if err != nil {
s.resetMuxClient()
continue
}
return stream, nil
}
return nil, fmt.Errorf("multiplex open stream failed")
}
func (s *Sudoku) getOrCreateMuxClient(ctx context.Context) (*sudoku.MultiplexClient, error) {
if s == nil {
return nil, fmt.Errorf("nil adapter")
}
s.muxMu.Lock()
if s.muxClient != nil && !s.muxClient.IsClosed() {
client := s.muxClient
s.muxMu.Unlock()
return client, nil
}
s.muxMu.Unlock()
s.muxMu.Lock()
defer s.muxMu.Unlock()
if s.muxClient != nil && !s.muxClient.IsClosed() {
return s.muxClient, nil
}
baseCfg := s.baseConf
baseConn, err := s.dialAndHandshake(ctx, &baseCfg)
if err != nil {
return nil, err
}
client, err := sudoku.StartMultiplexClient(baseConn)
if err != nil {
_ = baseConn.Close()
return nil, err
}
s.muxClient = client
return client, nil
}
func (s *Sudoku) resetMuxClient() {
s.muxMu.Lock()
defer s.muxMu.Unlock()
if s.muxClient != nil {
_ = s.muxClient.Close()
s.muxClient = nil
}
}
func (s *Sudoku) getOrCreateHTTPMaskClient(cfg *sudoku.ProtocolConfig) (*sudoku.HTTPMaskTunnelClient, error) {
if s == nil {
return nil, fmt.Errorf("nil adapter")
}
if cfg == nil {
return nil, fmt.Errorf("config is required")
}
s.httpMaskMu.Lock()
defer s.httpMaskMu.Unlock()
if s.httpMaskClient != nil {
return s.httpMaskClient, nil
}
c, err := sudoku.NewHTTPMaskTunnelClient(cfg.ServerAddress, cfg, s.dialer.DialContext)
if err != nil {
return nil, err
}
s.httpMaskClient = c
return c, nil
}
func (s *Sudoku) resetHTTPMaskClient() {
s.httpMaskMu.Lock()
defer s.httpMaskMu.Unlock()
if s.httpMaskClient != nil {
s.httpMaskClient.CloseIdleConnections()
s.httpMaskClient = nil
}
} }

View File

@@ -2,24 +2,23 @@ package outbound
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/http"
"strconv" "strconv"
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/ca" "github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/ech" "github.com/metacubex/mihomo/component/ech"
"github.com/metacubex/mihomo/component/proxydialer"
tlsC "github.com/metacubex/mihomo/component/tls" tlsC "github.com/metacubex/mihomo/component/tls"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/gun" "github.com/metacubex/mihomo/transport/gun"
"github.com/metacubex/mihomo/transport/shadowsocks/core" "github.com/metacubex/mihomo/transport/shadowsocks/core"
"github.com/metacubex/mihomo/transport/trojan" "github.com/metacubex/mihomo/transport/trojan"
"github.com/metacubex/mihomo/transport/vmess" "github.com/metacubex/mihomo/transport/vmess"
"github.com/metacubex/http"
"github.com/metacubex/tls"
) )
type Trojan struct { type Trojan struct {
@@ -196,18 +195,7 @@ func (t *Trojan) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Con
return NewConn(c, t), nil return NewConn(c, t), nil
} }
return t.DialContextWithDialer(ctx, dialer.NewDialer(t.DialOptions()...), metadata) c, err = t.dialer.DialContext(ctx, "tcp", t.addr)
}
// DialContextWithDialer implements C.ProxyAdapter
func (t *Trojan) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
if len(t.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(t.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
c, err := dialer.DialContext(ctx, "tcp", t.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", t.addr, err) return nil, fmt.Errorf("%s connect error: %w", t.addr, err)
} }
@@ -250,21 +238,10 @@ func (t *Trojan) ListenPacketContext(ctx context.Context, metadata *C.Metadata)
pc := trojan.NewPacketConn(c) pc := trojan.NewPacketConn(c)
return newPacketConn(pc, t), err return newPacketConn(pc, t), err
} }
return t.ListenPacketWithDialer(ctx, dialer.NewDialer(t.DialOptions()...), metadata)
}
// ListenPacketWithDialer implements C.ProxyAdapter
func (t *Trojan) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) {
if len(t.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(t.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
if err = t.ResolveUDP(ctx, metadata); err != nil { if err = t.ResolveUDP(ctx, metadata); err != nil {
return nil, err return nil, err
} }
c, err := dialer.DialContext(ctx, "tcp", t.addr) c, err = t.dialer.DialContext(ctx, "tcp", t.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", t.addr, err) return nil, fmt.Errorf("%s connect error: %w", t.addr, err)
} }
@@ -280,11 +257,6 @@ func (t *Trojan) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, me
return newPacketConn(pc, t), err return newPacketConn(pc, t), err
} }
// SupportWithDialer implements C.ProxyAdapter
func (t *Trojan) SupportWithDialer() C.NetWork {
return C.ALLNet
}
// SupportUOT implements C.ProxyAdapter // SupportUOT implements C.ProxyAdapter
func (t *Trojan) SupportUOT() bool { func (t *Trojan) SupportUOT() bool {
return true return true
@@ -317,16 +289,18 @@ func NewTrojan(option TrojanOption) (*Trojan, error) {
name: option.Name, name: option.Name,
addr: addr, addr: addr,
tp: C.Trojan, tp: C.Trojan,
pdName: option.ProviderName,
udp: option.UDP, udp: option.UDP,
tfo: option.TFO, tfo: option.TFO,
mpTcp: option.MPTCP, mpTcp: option.MPTCP,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion), prefer: option.IPVersion,
}, },
option: &option, option: &option,
hexPassword: trojan.Key(option.Password), hexPassword: trojan.Key(option.Password),
} }
t.dialer = option.NewDialer(t.DialOptions())
var err error var err error
t.realityConfig, err = option.RealityOpts.Parse() t.realityConfig, err = option.RealityOpts.Parse()
@@ -355,15 +329,7 @@ func NewTrojan(option TrojanOption) (*Trojan, error) {
if option.Network == "grpc" { if option.Network == "grpc" {
dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) { dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) {
var err error c, err := t.dialer.DialContext(ctx, "tcp", t.addr)
var cDialer C.Dialer = dialer.NewDialer(t.DialOptions()...)
if len(t.option.DialerProxy) > 0 {
cDialer, err = proxydialer.NewByName(t.option.DialerProxy, cDialer)
if err != nil {
return nil, err
}
}
c, err := cDialer.DialContext(ctx, "tcp", t.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %s", t.addr, err.Error()) return nil, fmt.Errorf("%s connect error: %s", t.addr, err.Error())
} }
@@ -390,6 +356,7 @@ func NewTrojan(option TrojanOption) (*Trojan, error) {
t.gunTLSConfig = tlsConfig t.gunTLSConfig = tlsConfig
t.gunConfig = &gun.Config{ t.gunConfig = &gun.Config{
ServiceName: option.GrpcOpts.GrpcServiceName, ServiceName: option.GrpcOpts.GrpcServiceName,
UserAgent: option.GrpcOpts.GrpcUserAgent,
Host: option.SNI, Host: option.SNI,
ClientFingerprint: option.ClientFingerprint, ClientFingerprint: option.ClientFingerprint,
} }

View File

@@ -2,7 +2,6 @@ package outbound
import ( import (
"context" "context"
"crypto/tls"
"fmt" "fmt"
"math" "math"
"net" "net"
@@ -10,10 +9,7 @@ import (
"time" "time"
"github.com/metacubex/mihomo/component/ca" "github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/ech" "github.com/metacubex/mihomo/component/ech"
"github.com/metacubex/mihomo/component/proxydialer"
tlsC "github.com/metacubex/mihomo/component/tls"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/tuic" "github.com/metacubex/mihomo/transport/tuic"
@@ -21,6 +17,7 @@ import (
"github.com/metacubex/quic-go" "github.com/metacubex/quic-go"
M "github.com/metacubex/sing/common/metadata" M "github.com/metacubex/sing/common/metadata"
"github.com/metacubex/sing/common/uot" "github.com/metacubex/sing/common/uot"
"github.com/metacubex/tls"
) )
type Tuic struct { type Tuic struct {
@@ -28,7 +25,7 @@ type Tuic struct {
option *TuicOption option *TuicOption
client *tuic.PoolClient client *tuic.PoolClient
tlsConfig *tlsC.Config tlsConfig *tls.Config
echConfig *ech.Config echConfig *ech.Config
} }
@@ -70,12 +67,7 @@ type TuicOption struct {
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
return t.DialContextWithDialer(ctx, dialer.NewDialer(t.DialOptions()...), metadata) conn, err := t.client.DialContext(ctx, metadata)
}
// DialContextWithDialer implements C.ProxyAdapter
func (t *Tuic) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (C.Conn, error) {
conn, err := t.client.DialContextWithDialer(ctx, metadata, dialer, t.dialWithDialer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -84,11 +76,6 @@ func (t *Tuic) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metad
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (t *Tuic) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) { func (t *Tuic) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) {
return t.ListenPacketWithDialer(ctx, dialer.NewDialer(t.DialOptions()...), metadata)
}
// ListenPacketWithDialer implements C.ProxyAdapter
func (t *Tuic) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) {
if err = t.ResolveUDP(ctx, metadata); err != nil { if err = t.ResolveUDP(ctx, metadata); err != nil {
return nil, err return nil, err
} }
@@ -98,7 +85,7 @@ func (t *Tuic) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, meta
uotMetadata := *metadata uotMetadata := *metadata
uotMetadata.Host = uotDestination.Fqdn uotMetadata.Host = uotDestination.Fqdn
uotMetadata.DstPort = uotDestination.Port uotMetadata.DstPort = uotDestination.Port
c, err := t.DialContextWithDialer(ctx, dialer, &uotMetadata) c, err := t.DialContext(ctx, &uotMetadata)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -112,25 +99,14 @@ func (t *Tuic) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, meta
return newPacketConn(uot.NewLazyConn(c, uot.Request{Destination: destination}), t), nil return newPacketConn(uot.NewLazyConn(c, uot.Request{Destination: destination}), t), nil
} }
} }
pc, err := t.client.ListenPacketWithDialer(ctx, metadata, dialer, t.dialWithDialer) pc, err := t.client.ListenPacket(ctx, metadata)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return newPacketConn(pc, t), nil return newPacketConn(pc, t), nil
} }
// SupportWithDialer implements C.ProxyAdapter func (t *Tuic) dial(ctx context.Context) (transport *quic.Transport, addr net.Addr, err error) {
func (t *Tuic) SupportWithDialer() C.NetWork {
return C.ALLNet
}
func (t *Tuic) dialWithDialer(ctx context.Context, dialer C.Dialer) (transport *quic.Transport, addr net.Addr, err error) {
if len(t.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(t.option.DialerProxy, dialer)
if err != nil {
return nil, nil, err
}
}
udpAddr, err := resolveUDPAddr(ctx, "udp", t.addr, t.prefer) udpAddr, err := resolveUDPAddr(ctx, "udp", t.addr, t.prefer)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@@ -141,7 +117,7 @@ func (t *Tuic) dialWithDialer(ctx context.Context, dialer C.Dialer) (transport *
} }
addr = udpAddr addr = udpAddr
var pc net.PacketConn var pc net.PacketConn
pc, err = dialer.ListenPacket(ctx, "udp", "", udpAddr.AddrPort()) pc, err = t.dialer.ListenPacket(ctx, "udp", "", udpAddr.AddrPort())
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -256,7 +232,7 @@ func NewTuic(option TuicOption) (*Tuic, error) {
tlsConfig.InsecureSkipVerify = true // tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config tlsConfig.InsecureSkipVerify = true // tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config
} }
tlsClientConfig := tlsC.UConfig(tlsConfig) tlsClientConfig := tlsConfig
echConfig, err := option.ECHOpts.Parse() echConfig, err := option.ECHOpts.Parse()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -275,16 +251,18 @@ func NewTuic(option TuicOption) (*Tuic, error) {
name: option.Name, name: option.Name,
addr: addr, addr: addr,
tp: C.Tuic, tp: C.Tuic,
pdName: option.ProviderName,
udp: true, udp: true,
tfo: option.FastOpen, tfo: option.FastOpen,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion), prefer: option.IPVersion,
}, },
option: &option, option: &option,
tlsConfig: tlsClientConfig, tlsConfig: tlsClientConfig,
echConfig: echConfig, echConfig: echConfig,
} }
t.dialer = option.NewDialer(t.DialOptions())
clientMaxOpenStreams := int64(option.MaxOpenStreams) clientMaxOpenStreams := int64(option.MaxOpenStreams)
@@ -313,7 +291,7 @@ func NewTuic(option TuicOption) (*Tuic, error) {
CWND: option.CWND, CWND: option.CWND,
} }
t.client = tuic.NewPoolClientV4(clientOption) t.client = tuic.NewPoolClientV4(clientOption, t.dial)
} else { } else {
maxUdpRelayPacketSize := option.MaxUdpRelayPacketSize maxUdpRelayPacketSize := option.MaxUdpRelayPacketSize
if maxUdpRelayPacketSize > tuic.MaxFragSizeV5 { if maxUdpRelayPacketSize > tuic.MaxFragSizeV5 {
@@ -332,7 +310,7 @@ func NewTuic(option TuicOption) (*Tuic, error) {
CWND: option.CWND, CWND: option.CWND,
} }
t.client = tuic.NewPoolClientV5(clientOption) t.client = tuic.NewPoolClientV5(clientOption, t.dial)
} }
return t, nil return t, nil

View File

@@ -2,19 +2,15 @@ package outbound
import ( import (
"context" "context"
"crypto/tls"
"fmt" "fmt"
"net" "net"
"net/http"
"strconv" "strconv"
"github.com/metacubex/mihomo/common/convert" "github.com/metacubex/mihomo/common/convert"
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/component/ca" "github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/ech" "github.com/metacubex/mihomo/component/ech"
"github.com/metacubex/mihomo/component/proxydialer"
tlsC "github.com/metacubex/mihomo/component/tls" tlsC "github.com/metacubex/mihomo/component/tls"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/gun" "github.com/metacubex/mihomo/transport/gun"
@@ -22,9 +18,11 @@ import (
"github.com/metacubex/mihomo/transport/vless/encryption" "github.com/metacubex/mihomo/transport/vless/encryption"
"github.com/metacubex/mihomo/transport/vmess" "github.com/metacubex/mihomo/transport/vmess"
"github.com/metacubex/http"
vmessSing "github.com/metacubex/sing-vmess" vmessSing "github.com/metacubex/sing-vmess"
"github.com/metacubex/sing-vmess/packetaddr" "github.com/metacubex/sing-vmess/packetaddr"
M "github.com/metacubex/sing/common/metadata" M "github.com/metacubex/sing/common/metadata"
"github.com/metacubex/tls"
) )
type Vless struct { type Vless struct {
@@ -252,18 +250,7 @@ func (v *Vless) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn
return NewConn(c, v), nil return NewConn(c, v), nil
} }
return v.DialContextWithDialer(ctx, dialer.NewDialer(v.DialOptions()...), metadata) c, err = v.dialer.DialContext(ctx, "tcp", v.addr)
}
// DialContextWithDialer implements C.ProxyAdapter
func (v *Vless) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
if len(v.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(v.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
c, err := dialer.DialContext(ctx, "tcp", v.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error()) return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
} }
@@ -301,23 +288,12 @@ func (v *Vless) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (
return v.ListenPacketOnStreamConn(ctx, c, metadata) return v.ListenPacketOnStreamConn(ctx, c, metadata)
} }
return v.ListenPacketWithDialer(ctx, dialer.NewDialer(v.DialOptions()...), metadata)
}
// ListenPacketWithDialer implements C.ProxyAdapter
func (v *Vless) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) {
if len(v.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(v.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
if err = v.ResolveUDP(ctx, metadata); err != nil { if err = v.ResolveUDP(ctx, metadata); err != nil {
return nil, err return nil, err
} }
c, err := dialer.DialContext(ctx, "tcp", v.addr) c, err = v.dialer.DialContext(ctx, "tcp", v.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error()) return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
} }
@@ -333,11 +309,6 @@ func (v *Vless) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, met
return v.ListenPacketOnStreamConn(ctx, c, metadata) return v.ListenPacketOnStreamConn(ctx, c, metadata)
} }
// SupportWithDialer implements C.ProxyAdapter
func (v *Vless) SupportWithDialer() C.NetWork {
return C.ALLNet
}
// ListenPacketOnStreamConn implements C.ProxyAdapter // ListenPacketOnStreamConn implements C.ProxyAdapter
func (v *Vless) ListenPacketOnStreamConn(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ C.PacketConn, err error) { func (v *Vless) ListenPacketOnStreamConn(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ C.PacketConn, err error) {
if err = v.ResolveUDP(ctx, metadata); err != nil { if err = v.ResolveUDP(ctx, metadata); err != nil {
@@ -446,17 +417,19 @@ func NewVless(option VlessOption) (*Vless, error) {
name: option.Name, name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)), addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.Vless, tp: C.Vless,
pdName: option.ProviderName,
udp: option.UDP, udp: option.UDP,
xudp: option.XUDP, xudp: option.XUDP,
tfo: option.TFO, tfo: option.TFO,
mpTcp: option.MPTCP, mpTcp: option.MPTCP,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion), prefer: option.IPVersion,
}, },
client: client, client: client,
option: &option, option: &option,
} }
v.dialer = option.NewDialer(v.DialOptions())
v.encryption, err = encryption.NewClient(option.Encryption) v.encryption, err = encryption.NewClient(option.Encryption)
if err != nil { if err != nil {
@@ -480,15 +453,7 @@ func NewVless(option VlessOption) (*Vless, error) {
} }
case "grpc": case "grpc":
dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) { dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) {
var err error c, err := v.dialer.DialContext(ctx, "tcp", v.addr)
var cDialer C.Dialer = dialer.NewDialer(v.DialOptions()...)
if len(v.option.DialerProxy) > 0 {
cDialer, err = proxydialer.NewByName(v.option.DialerProxy, cDialer)
if err != nil {
return nil, err
}
}
c, err := cDialer.DialContext(ctx, "tcp", v.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error()) return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
} }
@@ -497,6 +462,7 @@ func NewVless(option VlessOption) (*Vless, error) {
gunConfig := &gun.Config{ gunConfig := &gun.Config{
ServiceName: v.option.GrpcOpts.GrpcServiceName, ServiceName: v.option.GrpcOpts.GrpcServiceName,
UserAgent: v.option.GrpcOpts.GrpcUserAgent,
Host: v.option.ServerName, Host: v.option.ServerName,
ClientFingerprint: v.option.ClientFingerprint, ClientFingerprint: v.option.ClientFingerprint,
} }

View File

@@ -2,11 +2,9 @@ package outbound
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/http"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@@ -14,18 +12,18 @@ import (
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/component/ca" "github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/ech" "github.com/metacubex/mihomo/component/ech"
"github.com/metacubex/mihomo/component/proxydialer"
tlsC "github.com/metacubex/mihomo/component/tls" tlsC "github.com/metacubex/mihomo/component/tls"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/ntp" "github.com/metacubex/mihomo/ntp"
"github.com/metacubex/mihomo/transport/gun" "github.com/metacubex/mihomo/transport/gun"
mihomoVMess "github.com/metacubex/mihomo/transport/vmess" mihomoVMess "github.com/metacubex/mihomo/transport/vmess"
"github.com/metacubex/http"
vmess "github.com/metacubex/sing-vmess" vmess "github.com/metacubex/sing-vmess"
"github.com/metacubex/sing-vmess/packetaddr" "github.com/metacubex/sing-vmess/packetaddr"
M "github.com/metacubex/sing/common/metadata" M "github.com/metacubex/sing/common/metadata"
"github.com/metacubex/tls"
) )
var ErrUDPRemoteAddrMismatch = errors.New("udp packet dropped due to mismatched remote address") var ErrUDPRemoteAddrMismatch = errors.New("udp packet dropped due to mismatched remote address")
@@ -88,6 +86,7 @@ type HTTP2Options struct {
type GrpcOptions struct { type GrpcOptions struct {
GrpcServiceName string `proxy:"grpc-service-name,omitempty"` GrpcServiceName string `proxy:"grpc-service-name,omitempty"`
GrpcUserAgent string `proxy:"grpc-user-agent,omitempty"`
} }
type WSOptions struct { type WSOptions struct {
@@ -313,18 +312,7 @@ func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn
return NewConn(c, v), nil return NewConn(c, v), nil
} }
return v.DialContextWithDialer(ctx, dialer.NewDialer(v.DialOptions()...), metadata) c, err = v.dialer.DialContext(ctx, "tcp", v.addr)
}
// DialContextWithDialer implements C.ProxyAdapter
func (v *Vmess) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
if len(v.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(v.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
c, err := dialer.DialContext(ctx, "tcp", v.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error()) return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
} }
@@ -358,23 +346,12 @@ func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (
} }
return v.ListenPacketOnStreamConn(ctx, c, metadata) return v.ListenPacketOnStreamConn(ctx, c, metadata)
} }
return v.ListenPacketWithDialer(ctx, dialer.NewDialer(v.DialOptions()...), metadata)
}
// ListenPacketWithDialer implements C.ProxyAdapter
func (v *Vmess) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) {
if len(v.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(v.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
if err = v.ResolveUDP(ctx, metadata); err != nil { if err = v.ResolveUDP(ctx, metadata); err != nil {
return nil, err return nil, err
} }
c, err := dialer.DialContext(ctx, "tcp", v.addr) c, err = v.dialer.DialContext(ctx, "tcp", v.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error()) return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
} }
@@ -389,11 +366,6 @@ func (v *Vmess) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, met
return v.ListenPacketOnStreamConn(ctx, c, metadata) return v.ListenPacketOnStreamConn(ctx, c, metadata)
} }
// SupportWithDialer implements C.ProxyAdapter
func (v *Vmess) SupportWithDialer() C.NetWork {
return C.ALLNet
}
// ProxyInfo implements C.ProxyAdapter // ProxyInfo implements C.ProxyAdapter
func (v *Vmess) ProxyInfo() C.ProxyInfo { func (v *Vmess) ProxyInfo() C.ProxyInfo {
info := v.Base.ProxyInfo() info := v.Base.ProxyInfo()
@@ -456,17 +428,19 @@ func NewVmess(option VmessOption) (*Vmess, error) {
name: option.Name, name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)), addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.Vmess, tp: C.Vmess,
pdName: option.ProviderName,
udp: option.UDP, udp: option.UDP,
xudp: option.XUDP, xudp: option.XUDP,
tfo: option.TFO, tfo: option.TFO,
mpTcp: option.MPTCP, mpTcp: option.MPTCP,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion), prefer: option.IPVersion,
}, },
client: client, client: client,
option: &option, option: &option,
} }
v.dialer = option.NewDialer(v.DialOptions())
v.realityConfig, err = v.option.RealityOpts.Parse() v.realityConfig, err = v.option.RealityOpts.Parse()
if err != nil { if err != nil {
@@ -485,15 +459,7 @@ func NewVmess(option VmessOption) (*Vmess, error) {
} }
case "grpc": case "grpc":
dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) { dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) {
var err error c, err := v.dialer.DialContext(ctx, "tcp", v.addr)
var cDialer C.Dialer = dialer.NewDialer(v.DialOptions()...)
if len(v.option.DialerProxy) > 0 {
cDialer, err = proxydialer.NewByName(v.option.DialerProxy, cDialer)
if err != nil {
return nil, err
}
}
c, err := cDialer.DialContext(ctx, "tcp", v.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error()) return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
} }
@@ -502,6 +468,7 @@ func NewVmess(option VmessOption) (*Vmess, error) {
gunConfig := &gun.Config{ gunConfig := &gun.Config{
ServiceName: v.option.GrpcOpts.GrpcServiceName, ServiceName: v.option.GrpcOpts.GrpcServiceName,
UserAgent: v.option.GrpcOpts.GrpcUserAgent,
Host: v.option.ServerName, Host: v.option.ServerName,
ClientFingerprint: v.option.ClientFingerprint, ClientFingerprint: v.option.ClientFingerprint,
} }

View File

@@ -40,7 +40,6 @@ type WireGuard struct {
bind *wireguard.ClientBind bind *wireguard.ClientBind
device wireguardGoDevice device wireguardGoDevice
tunDevice wireguard.Device tunDevice wireguard.Device
dialer proxydialer.SingDialer
resolver resolver.Resolver resolver resolver.Resolver
initOk atomic.Bool initOk atomic.Bool
@@ -171,14 +170,15 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
name: option.Name, name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)), addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.WireGuard, tp: C.WireGuard,
pdName: option.ProviderName,
udp: option.UDP, udp: option.UDP,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion), prefer: option.IPVersion,
}, },
} }
singDialer := proxydialer.NewSlowDownSingDialer(proxydialer.NewByNameSingDialer(option.DialerProxy, dialer.NewDialer(outbound.DialOptions()...)), slowdown.New()) outbound.dialer = option.NewDialer(outbound.DialOptions())
outbound.dialer = singDialer singDialer := proxydialer.NewSingDialer(proxydialer.NewSlowDownDialer(outbound.dialer, slowdown.New()))
var reserved [3]uint8 var reserved [3]uint8
if len(option.Reserved) > 0 { if len(option.Reserved) > 0 {
@@ -196,7 +196,7 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
outbound.connectAddr = option.Addr() outbound.connectAddr = option.Addr()
} }
} }
outbound.bind = wireguard.NewClientBind(context.Background(), wgSingErrorHandler{outbound.Name()}, outbound.dialer, isConnect, outbound.connectAddr.AddrPort(), reserved) outbound.bind = wireguard.NewClientBind(context.Background(), wgSingErrorHandler{outbound.Name()}, singDialer, isConnect, outbound.connectAddr.AddrPort(), reserved)
var err error var err error
outbound.localPrefixes, err = option.Prefixes() outbound.localPrefixes, err = option.Prefixes()
@@ -609,6 +609,13 @@ func (w *WireGuard) ResolveUDP(ctx context.Context, metadata *C.Metadata) error
return nil return nil
} }
// ProxyInfo implements C.ProxyAdapter
func (w *WireGuard) ProxyInfo() C.ProxyInfo {
info := w.Base.ProxyInfo()
info.DialerProxy = w.option.DialerProxy
return info
}
// IsL3Protocol implements C.ProxyAdapter // IsL3Protocol implements C.ProxyAdapter
func (w *WireGuard) IsL3Protocol(metadata *C.Metadata) bool { func (w *WireGuard) IsL3Protocol(metadata *C.Metadata) bool {
return true return true

View File

@@ -150,6 +150,14 @@ func (f *Fallback) ForceSet(name string) {
f.selected = name f.selected = name
} }
func (f *Fallback) Providers() []P.ProxyProvider {
return f.providers
}
func (f *Fallback) Proxies() []C.Proxy {
return f.GetProxies(false)
}
func NewFallback(option *GroupCommonOption, providers []P.ProxyProvider) *Fallback { func NewFallback(option *GroupCommonOption, providers []P.ProxyProvider) *Fallback {
return &Fallback{ return &Fallback{
GroupBase: NewGroupBase(GroupBaseOption{ GroupBase: NewGroupBase(GroupBaseOption{

View File

@@ -239,6 +239,18 @@ func (lb *LoadBalance) MarshalJSON() ([]byte, error) {
}) })
} }
func (lb *LoadBalance) Providers() []P.ProxyProvider {
return lb.providers
}
func (lb *LoadBalance) Proxies() []C.Proxy {
return lb.GetProxies(false)
}
func (lb *LoadBalance) Now() string {
return ""
}
func NewLoadBalance(option *GroupCommonOption, providers []P.ProxyProvider, strategy string) (lb *LoadBalance, err error) { func NewLoadBalance(option *GroupCommonOption, providers []P.ProxyProvider, strategy string) (lb *LoadBalance, err error) {
var strategyFn strategyFn var strategyFn strategyFn
switch strategy { switch strategy {

View File

@@ -1,64 +0,0 @@
//go:build android && cmfa
package outboundgroup
import (
C "github.com/metacubex/mihomo/constant"
P "github.com/metacubex/mihomo/constant/provider"
)
type ProxyGroup interface {
C.ProxyAdapter
Providers() []P.ProxyProvider
Proxies() []C.Proxy
Now() string
}
func (f *Fallback) Providers() []P.ProxyProvider {
return f.providers
}
func (lb *LoadBalance) Providers() []P.ProxyProvider {
return lb.providers
}
func (f *Fallback) Proxies() []C.Proxy {
return f.GetProxies(false)
}
func (lb *LoadBalance) Proxies() []C.Proxy {
return lb.GetProxies(false)
}
func (lb *LoadBalance) Now() string {
return ""
}
func (r *Relay) Providers() []P.ProxyProvider {
return r.providers
}
func (r *Relay) Proxies() []C.Proxy {
return r.GetProxies(false)
}
func (r *Relay) Now() string {
return ""
}
func (s *Selector) Providers() []P.ProxyProvider {
return s.providers
}
func (s *Selector) Proxies() []C.Proxy {
return s.GetProxies(false)
}
func (u *URLTest) Providers() []P.ProxyProvider {
return u.providers
}
func (u *URLTest) Proxies() []C.Proxy {
return u.GetProxies(false)
}

View File

@@ -1,163 +0,0 @@
package outboundgroup
import (
"context"
"encoding/json"
"github.com/metacubex/mihomo/adapter/outbound"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant"
P "github.com/metacubex/mihomo/constant/provider"
"github.com/metacubex/mihomo/log"
)
type Relay struct {
*GroupBase
Hidden bool
Icon string
}
// DialContext implements C.ProxyAdapter
func (r *Relay) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
proxies, chainProxies := r.proxies(metadata, true)
switch len(proxies) {
case 0:
return outbound.NewDirect().DialContext(ctx, metadata)
case 1:
return proxies[0].DialContext(ctx, metadata)
}
var d C.Dialer
d = dialer.NewDialer()
for _, proxy := range proxies[:len(proxies)-1] {
d = proxydialer.New(proxy, d, false)
}
last := proxies[len(proxies)-1]
conn, err := last.DialContextWithDialer(ctx, d, metadata)
if err != nil {
return nil, err
}
for i := len(chainProxies) - 2; i >= 0; i-- {
conn.AppendToChains(chainProxies[i])
}
conn.AppendToChains(r)
return conn, nil
}
// ListenPacketContext implements C.ProxyAdapter
func (r *Relay) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) {
proxies, chainProxies := r.proxies(metadata, true)
switch len(proxies) {
case 0:
return outbound.NewDirect().ListenPacketContext(ctx, metadata)
case 1:
return proxies[0].ListenPacketContext(ctx, metadata)
}
var d C.Dialer
d = dialer.NewDialer()
for _, proxy := range proxies[:len(proxies)-1] {
d = proxydialer.New(proxy, d, false)
}
last := proxies[len(proxies)-1]
pc, err := last.ListenPacketWithDialer(ctx, d, metadata)
if err != nil {
return nil, err
}
for i := len(chainProxies) - 2; i >= 0; i-- {
pc.AppendToChains(chainProxies[i])
}
pc.AppendToChains(r)
return pc, nil
}
// SupportUDP implements C.ProxyAdapter
func (r *Relay) SupportUDP() bool {
proxies, _ := r.proxies(nil, false)
if len(proxies) == 0 { // C.Direct
return true
}
for i := len(proxies) - 1; i >= 0; i-- {
proxy := proxies[i]
if !proxy.SupportUDP() {
return false
}
if proxy.SupportUOT() {
return true
}
switch proxy.SupportWithDialer() {
case C.ALLNet:
case C.UDP:
default: // C.TCP and C.InvalidNet
return false
}
}
return true
}
// MarshalJSON implements C.ProxyAdapter
func (r *Relay) MarshalJSON() ([]byte, error) {
all := []string{}
for _, proxy := range r.GetProxies(false) {
all = append(all, proxy.Name())
}
return json.Marshal(map[string]any{
"type": r.Type().String(),
"all": all,
"hidden": r.Hidden,
"icon": r.Icon,
})
}
func (r *Relay) proxies(metadata *C.Metadata, touch bool) ([]C.Proxy, []C.Proxy) {
rawProxies := r.GetProxies(touch)
var proxies []C.Proxy
var chainProxies []C.Proxy
var targetProxies []C.Proxy
for n, proxy := range rawProxies {
proxies = append(proxies, proxy)
chainProxies = append(chainProxies, proxy)
subproxy := proxy.Unwrap(metadata, touch)
for subproxy != nil {
chainProxies = append(chainProxies, subproxy)
proxies[n] = subproxy
subproxy = subproxy.Unwrap(metadata, touch)
}
}
for _, proxy := range proxies {
if proxy.Type() != C.Direct && proxy.Type() != C.Compatible {
targetProxies = append(targetProxies, proxy)
}
}
return targetProxies, chainProxies
}
func (r *Relay) Addr() string {
proxies, _ := r.proxies(nil, false)
return proxies[len(proxies)-1].Addr()
}
func NewRelay(option *GroupCommonOption, providers []P.ProxyProvider) *Relay {
log.Warnln("The group [%s] with relay type is deprecated, please using dialer-proxy instead", option.Name)
return &Relay{
GroupBase: NewGroupBase(GroupBaseOption{
Name: option.Name,
Type: C.Relay,
Providers: providers,
}),
Hidden: option.Hidden,
Icon: option.Icon,
}
}

View File

@@ -108,6 +108,14 @@ func (s *Selector) selectedProxy(touch bool) C.Proxy {
return proxies[0] return proxies[0]
} }
func (s *Selector) Providers() []P.ProxyProvider {
return s.providers
}
func (s *Selector) Proxies() []C.Proxy {
return s.GetProxies(false)
}
func NewSelector(option *GroupCommonOption, providers []P.ProxyProvider) *Selector { func NewSelector(option *GroupCommonOption, providers []P.ProxyProvider) *Selector {
return &Selector{ return &Selector{
GroupBase: NewGroupBase(GroupBaseOption{ GroupBase: NewGroupBase(GroupBaseOption{

View File

@@ -185,6 +185,14 @@ func (u *URLTest) MarshalJSON() ([]byte, error) {
}) })
} }
func (u *URLTest) Providers() []P.ProxyProvider {
return u.providers
}
func (u *URLTest) Proxies() []C.Proxy {
return u.GetProxies(false)
}
func (u *URLTest) URLTest(ctx context.Context, url string, expectedStatus utils.IntRanges[uint16]) (map[string]uint16, error) { func (u *URLTest) URLTest(ctx context.Context, url string, expectedStatus utils.IntRanges[uint16]) (map[string]uint16, error) {
return u.GroupBase.URLTest(ctx, u.testUrl, expectedStatus) return u.GroupBase.URLTest(ctx, u.testUrl, expectedStatus)
} }

View File

@@ -1,5 +1,29 @@
package outboundgroup package outboundgroup
import (
"context"
"github.com/metacubex/mihomo/common/utils"
C "github.com/metacubex/mihomo/constant"
P "github.com/metacubex/mihomo/constant/provider"
)
type ProxyGroup interface {
C.ProxyAdapter
Providers() []P.ProxyProvider
Proxies() []C.Proxy
Now() string
Touch()
URLTest(ctx context.Context, url string, expectedStatus utils.IntRanges[uint16]) (mp map[string]uint16, err error)
}
var _ ProxyGroup = (*Fallback)(nil)
var _ ProxyGroup = (*LoadBalance)(nil)
var _ ProxyGroup = (*URLTest)(nil)
var _ ProxyGroup = (*Selector)(nil)
type SelectAble interface { type SelectAble interface {
Set(string) error Set(string) error
ForceSet(name string) ForceSet(name string)

View File

@@ -8,151 +8,164 @@ import (
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
) )
func ParseProxy(mapping map[string]any) (C.Proxy, error) { func ParseProxy(mapping map[string]any, options ...ProxyOption) (C.Proxy, error) {
decoder := structure.NewDecoder(structure.Option{TagName: "proxy", WeaklyTypedInput: true, KeyReplacer: structure.DefaultKeyReplacer}) decoder := structure.NewDecoder(structure.Option{TagName: "proxy", WeaklyTypedInput: true, KeyReplacer: structure.DefaultKeyReplacer})
proxyType, existType := mapping["type"].(string) proxyType, existType := mapping["type"].(string)
if !existType { if !existType {
return nil, fmt.Errorf("missing type") return nil, fmt.Errorf("missing type")
} }
opt := applyProxyOptions(options...)
basicOption := outbound.BasicOption{
DialerForAPI: opt.DialerForAPI,
ProviderName: opt.ProviderName,
}
var ( var (
proxy outbound.ProxyAdapter proxy outbound.ProxyAdapter
err error err error
) )
switch proxyType { switch proxyType {
case "ss": case "ss":
ssOption := &outbound.ShadowSocksOption{} ssOption := &outbound.ShadowSocksOption{BasicOption: basicOption}
err = decoder.Decode(mapping, ssOption) err = decoder.Decode(mapping, ssOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewShadowSocks(*ssOption) proxy, err = outbound.NewShadowSocks(*ssOption)
case "ssr": case "ssr":
ssrOption := &outbound.ShadowSocksROption{} ssrOption := &outbound.ShadowSocksROption{BasicOption: basicOption}
err = decoder.Decode(mapping, ssrOption) err = decoder.Decode(mapping, ssrOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewShadowSocksR(*ssrOption) proxy, err = outbound.NewShadowSocksR(*ssrOption)
case "socks5": case "socks5":
socksOption := &outbound.Socks5Option{} socksOption := &outbound.Socks5Option{BasicOption: basicOption}
err = decoder.Decode(mapping, socksOption) err = decoder.Decode(mapping, socksOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewSocks5(*socksOption) proxy, err = outbound.NewSocks5(*socksOption)
case "http": case "http":
httpOption := &outbound.HttpOption{} httpOption := &outbound.HttpOption{BasicOption: basicOption}
err = decoder.Decode(mapping, httpOption) err = decoder.Decode(mapping, httpOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewHttp(*httpOption) proxy, err = outbound.NewHttp(*httpOption)
case "vmess": case "vmess":
vmessOption := &outbound.VmessOption{} vmessOption := &outbound.VmessOption{BasicOption: basicOption}
err = decoder.Decode(mapping, vmessOption) err = decoder.Decode(mapping, vmessOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewVmess(*vmessOption) proxy, err = outbound.NewVmess(*vmessOption)
case "vless": case "vless":
vlessOption := &outbound.VlessOption{} vlessOption := &outbound.VlessOption{BasicOption: basicOption}
err = decoder.Decode(mapping, vlessOption) err = decoder.Decode(mapping, vlessOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewVless(*vlessOption) proxy, err = outbound.NewVless(*vlessOption)
case "snell": case "snell":
snellOption := &outbound.SnellOption{} snellOption := &outbound.SnellOption{BasicOption: basicOption}
err = decoder.Decode(mapping, snellOption) err = decoder.Decode(mapping, snellOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewSnell(*snellOption) proxy, err = outbound.NewSnell(*snellOption)
case "trojan": case "trojan":
trojanOption := &outbound.TrojanOption{} trojanOption := &outbound.TrojanOption{BasicOption: basicOption}
err = decoder.Decode(mapping, trojanOption) err = decoder.Decode(mapping, trojanOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewTrojan(*trojanOption) proxy, err = outbound.NewTrojan(*trojanOption)
case "hysteria": case "hysteria":
hyOption := &outbound.HysteriaOption{} hyOption := &outbound.HysteriaOption{BasicOption: basicOption}
err = decoder.Decode(mapping, hyOption) err = decoder.Decode(mapping, hyOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewHysteria(*hyOption) proxy, err = outbound.NewHysteria(*hyOption)
case "hysteria2": case "hysteria2":
hyOption := &outbound.Hysteria2Option{} hyOption := &outbound.Hysteria2Option{BasicOption: basicOption}
err = decoder.Decode(mapping, hyOption) err = decoder.Decode(mapping, hyOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewHysteria2(*hyOption) proxy, err = outbound.NewHysteria2(*hyOption)
case "wireguard": case "wireguard":
wgOption := &outbound.WireGuardOption{} wgOption := &outbound.WireGuardOption{BasicOption: basicOption}
err = decoder.Decode(mapping, wgOption) err = decoder.Decode(mapping, wgOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewWireGuard(*wgOption) proxy, err = outbound.NewWireGuard(*wgOption)
case "tuic": case "tuic":
tuicOption := &outbound.TuicOption{} tuicOption := &outbound.TuicOption{BasicOption: basicOption}
err = decoder.Decode(mapping, tuicOption) err = decoder.Decode(mapping, tuicOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewTuic(*tuicOption) proxy, err = outbound.NewTuic(*tuicOption)
case "direct": case "direct":
directOption := &outbound.DirectOption{} directOption := &outbound.DirectOption{BasicOption: basicOption}
err = decoder.Decode(mapping, directOption) err = decoder.Decode(mapping, directOption)
if err != nil { if err != nil {
break break
} }
proxy = outbound.NewDirectWithOption(*directOption) proxy = outbound.NewDirectWithOption(*directOption)
case "dns": case "dns":
dnsOptions := &outbound.DnsOption{} dnsOptions := &outbound.DnsOption{BasicOption: basicOption}
err = decoder.Decode(mapping, dnsOptions) err = decoder.Decode(mapping, dnsOptions)
if err != nil { if err != nil {
break break
} }
proxy = outbound.NewDnsWithOption(*dnsOptions) proxy = outbound.NewDnsWithOption(*dnsOptions)
case "reject": case "reject":
rejectOption := &outbound.RejectOption{} rejectOption := &outbound.RejectOption{BasicOption: basicOption}
err = decoder.Decode(mapping, rejectOption) err = decoder.Decode(mapping, rejectOption)
if err != nil { if err != nil {
break break
} }
proxy = outbound.NewRejectWithOption(*rejectOption) proxy = outbound.NewRejectWithOption(*rejectOption)
case "ssh": case "ssh":
sshOption := &outbound.SshOption{} sshOption := &outbound.SshOption{BasicOption: basicOption}
err = decoder.Decode(mapping, sshOption) err = decoder.Decode(mapping, sshOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewSsh(*sshOption) proxy, err = outbound.NewSsh(*sshOption)
case "mieru": case "mieru":
mieruOption := &outbound.MieruOption{} mieruOption := &outbound.MieruOption{BasicOption: basicOption}
err = decoder.Decode(mapping, mieruOption) err = decoder.Decode(mapping, mieruOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewMieru(*mieruOption) proxy, err = outbound.NewMieru(*mieruOption)
case "anytls": case "anytls":
anytlsOption := &outbound.AnyTLSOption{} anytlsOption := &outbound.AnyTLSOption{BasicOption: basicOption}
err = decoder.Decode(mapping, anytlsOption) err = decoder.Decode(mapping, anytlsOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewAnyTLS(*anytlsOption) proxy, err = outbound.NewAnyTLS(*anytlsOption)
case "sudoku": case "sudoku":
sudokuOption := &outbound.SudokuOption{} sudokuOption := &outbound.SudokuOption{BasicOption: basicOption}
err = decoder.Decode(mapping, sudokuOption) err = decoder.Decode(mapping, sudokuOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewSudoku(*sudokuOption) proxy, err = outbound.NewSudoku(*sudokuOption)
case "masque":
masqueOption := &outbound.MasqueOption{BasicOption: basicOption}
err = decoder.Decode(mapping, masqueOption)
if err != nil {
break
}
proxy, err = outbound.NewMasque(*masqueOption)
default: default:
return nil, fmt.Errorf("unsupport proxy type: %s", proxyType) return nil, fmt.Errorf("unsupport proxy type: %s", proxyType)
} }
@@ -178,3 +191,30 @@ func ParseProxy(mapping map[string]any) (C.Proxy, error) {
proxy = outbound.NewAutoCloseProxyAdapter(proxy) proxy = outbound.NewAutoCloseProxyAdapter(proxy)
return NewProxy(proxy), nil return NewProxy(proxy), nil
} }
type proxyOption struct {
DialerForAPI C.Dialer
ProviderName string
}
func applyProxyOptions(options ...ProxyOption) proxyOption {
opt := proxyOption{}
for _, o := range options {
o(&opt)
}
return opt
}
type ProxyOption func(opt *proxyOption)
func WithDialerForAPI(dialer C.Dialer) ProxyOption {
return func(opt *proxyOption) {
opt.DialerForAPI = dialer
}
}
func WithProviderName(name string) ProxyOption {
return func(opt *proxyOption) {
opt.ProviderName = name
}
}

View File

@@ -0,0 +1,88 @@
package provider
import (
"encoding"
"fmt"
"github.com/dlclark/regexp2"
)
type overrideSchema struct {
TFO *bool `provider:"tfo,omitempty"`
MPTcp *bool `provider:"mptcp,omitempty"`
UDP *bool `provider:"udp,omitempty"`
UDPOverTCP *bool `provider:"udp-over-tcp,omitempty"`
Up *string `provider:"up,omitempty"`
Down *string `provider:"down,omitempty"`
DialerProxy *string `provider:"dialer-proxy,omitempty"`
SkipCertVerify *bool `provider:"skip-cert-verify,omitempty"`
Interface *string `provider:"interface-name,omitempty"`
RoutingMark *int `provider:"routing-mark,omitempty"`
IPVersion *string `provider:"ip-version,omitempty"`
AdditionalPrefix *string `provider:"additional-prefix,omitempty"`
AdditionalSuffix *string `provider:"additional-suffix,omitempty"`
ProxyName []overrideProxyNameSchema `provider:"proxy-name,omitempty"`
}
type overrideProxyNameSchema struct {
// matching expression for regex replacement
Pattern *regexp2.Regexp `provider:"pattern"`
// the new content after regex matching
Target string `provider:"target"`
}
var _ encoding.TextUnmarshaler = (*regexp2.Regexp)(nil) // ensure *regexp2.Regexp can decode direct by structure package
func (o *overrideSchema) Apply(mapping map[string]any) error {
if o.TFO != nil {
mapping["tfo"] = *o.TFO
}
if o.MPTcp != nil {
mapping["mptcp"] = *o.MPTcp
}
if o.UDP != nil {
mapping["udp"] = *o.UDP
}
if o.UDPOverTCP != nil {
mapping["udp-over-tcp"] = *o.UDPOverTCP
}
if o.Up != nil {
mapping["up"] = *o.Up
}
if o.Down != nil {
mapping["down"] = *o.Down
}
if o.DialerProxy != nil {
mapping["dialer-proxy"] = *o.DialerProxy
}
if o.SkipCertVerify != nil {
mapping["skip-cert-verify"] = *o.SkipCertVerify
}
if o.Interface != nil {
mapping["interface"] = *o.Interface
}
if o.RoutingMark != nil {
mapping["routing-mark"] = *o.RoutingMark
}
if o.IPVersion != nil {
mapping["ip-version"] = *o.IPVersion
}
for _, expr := range o.ProxyName {
name := mapping["name"].(string)
newName, err := expr.Pattern.Replace(name, expr.Target, 0, -1)
if err != nil {
return fmt.Errorf("proxy name replace error: %w", err)
}
mapping["name"] = newName
}
if o.AdditionalPrefix != nil {
mapping["name"] = fmt.Sprintf("%s%s", *o.AdditionalPrefix, mapping["name"])
}
if o.AdditionalSuffix != nil {
mapping["name"] = fmt.Sprintf("%s%s", mapping["name"], *o.AdditionalSuffix)
}
return nil
}

View File

@@ -1,7 +1,6 @@
package provider package provider
import ( import (
"encoding"
"errors" "errors"
"fmt" "fmt"
"time" "time"
@@ -11,8 +10,6 @@ import (
"github.com/metacubex/mihomo/component/resource" "github.com/metacubex/mihomo/component/resource"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
P "github.com/metacubex/mihomo/constant/provider" P "github.com/metacubex/mihomo/constant/provider"
"github.com/dlclark/regexp2"
) )
var ( var (
@@ -28,33 +25,6 @@ type healthCheckSchema struct {
ExpectedStatus string `provider:"expected-status,omitempty"` ExpectedStatus string `provider:"expected-status,omitempty"`
} }
type OverrideProxyNameSchema struct {
// matching expression for regex replacement
Pattern *regexp2.Regexp `provider:"pattern"`
// the new content after regex matching
Target string `provider:"target"`
}
var _ encoding.TextUnmarshaler = (*regexp2.Regexp)(nil) // ensure *regexp2.Regexp can decode direct by structure package
type OverrideSchema struct {
TFO *bool `provider:"tfo,omitempty"`
MPTcp *bool `provider:"mptcp,omitempty"`
UDP *bool `provider:"udp,omitempty"`
UDPOverTCP *bool `provider:"udp-over-tcp,omitempty"`
Up *string `provider:"up,omitempty"`
Down *string `provider:"down,omitempty"`
DialerProxy *string `provider:"dialer-proxy,omitempty"`
SkipCertVerify *bool `provider:"skip-cert-verify,omitempty"`
Interface *string `provider:"interface-name,omitempty"`
RoutingMark *int `provider:"routing-mark,omitempty"`
IPVersion *string `provider:"ip-version,omitempty"`
AdditionalPrefix *string `provider:"additional-prefix,omitempty"`
AdditionalSuffix *string `provider:"additional-suffix,omitempty"`
ProxyName []OverrideProxyNameSchema `provider:"proxy-name,omitempty"`
}
type proxyProviderSchema struct { type proxyProviderSchema struct {
Type string `provider:"type"` Type string `provider:"type"`
Path string `provider:"path,omitempty"` Path string `provider:"path,omitempty"`
@@ -69,7 +39,7 @@ type proxyProviderSchema struct {
Payload []map[string]any `provider:"payload,omitempty"` Payload []map[string]any `provider:"payload,omitempty"`
HealthCheck healthCheckSchema `provider:"health-check,omitempty"` HealthCheck healthCheckSchema `provider:"health-check,omitempty"`
Override OverrideSchema `provider:"override,omitempty"` Override overrideSchema `provider:"override,omitempty"`
Header map[string][]string `provider:"header,omitempty"` Header map[string][]string `provider:"header,omitempty"`
} }
@@ -99,7 +69,7 @@ func ParseProxyProvider(name string, mapping map[string]any) (P.ProxyProvider, e
} }
hc := NewHealthCheck([]C.Proxy{}, schema.HealthCheck.URL, uint(schema.HealthCheck.TestTimeout), hcInterval, schema.HealthCheck.Lazy, expectedStatus) hc := NewHealthCheck([]C.Proxy{}, schema.HealthCheck.URL, uint(schema.HealthCheck.TestTimeout), hcInterval, schema.HealthCheck.Lazy, expectedStatus)
parser, err := NewProxiesParser(schema.Filter, schema.ExcludeFilter, schema.ExcludeType, schema.DialerProxy, schema.Override) parser, err := NewProxiesParser(name, schema.Filter, schema.ExcludeFilter, schema.ExcludeType, schema.DialerProxy, schema.Override)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -4,15 +4,15 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http"
"reflect"
"runtime" "runtime"
"strings" "strings"
"sync"
"time" "time"
"github.com/metacubex/mihomo/adapter" "github.com/metacubex/mihomo/adapter"
"github.com/metacubex/mihomo/common/convert" "github.com/metacubex/mihomo/common/convert"
"github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/common/yaml"
"github.com/metacubex/mihomo/component/profile/cachefile" "github.com/metacubex/mihomo/component/profile/cachefile"
"github.com/metacubex/mihomo/component/resource" "github.com/metacubex/mihomo/component/resource"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
@@ -20,7 +20,7 @@ import (
"github.com/metacubex/mihomo/tunnel/statistic" "github.com/metacubex/mihomo/tunnel/statistic"
"github.com/dlclark/regexp2" "github.com/dlclark/regexp2"
"gopkg.in/yaml.v3" "github.com/metacubex/http"
) )
const ( const (
@@ -43,6 +43,7 @@ type providerForApi struct {
} }
type baseProvider struct { type baseProvider struct {
mutex sync.RWMutex
name string name string
proxies []C.Proxy proxies []C.Proxy
healthCheck *HealthCheck healthCheck *HealthCheck
@@ -54,6 +55,8 @@ func (bp *baseProvider) Name() string {
} }
func (bp *baseProvider) Version() uint32 { func (bp *baseProvider) Version() uint32 {
bp.mutex.RLock()
defer bp.mutex.RUnlock()
return bp.version return bp.version
} }
@@ -73,10 +76,14 @@ func (bp *baseProvider) Type() P.ProviderType {
} }
func (bp *baseProvider) Proxies() []C.Proxy { func (bp *baseProvider) Proxies() []C.Proxy {
bp.mutex.RLock()
defer bp.mutex.RUnlock()
return bp.proxies return bp.proxies
} }
func (bp *baseProvider) Count() int { func (bp *baseProvider) Count() int {
bp.mutex.RLock()
defer bp.mutex.RUnlock()
return len(bp.proxies) return len(bp.proxies)
} }
@@ -93,6 +100,8 @@ func (bp *baseProvider) RegisterHealthCheckTask(url string, expectedStatus utils
} }
func (bp *baseProvider) setProxies(proxies []C.Proxy) { func (bp *baseProvider) setProxies(proxies []C.Proxy) {
bp.mutex.Lock()
defer bp.mutex.Unlock()
bp.proxies = proxies bp.proxies = proxies
bp.version += 1 bp.version += 1
bp.healthCheck.setProxies(proxies) bp.healthCheck.setProxies(proxies)
@@ -156,7 +165,7 @@ func (pp *proxySetProvider) Initial() error {
func (pp *proxySetProvider) closeAllConnections() { func (pp *proxySetProvider) closeAllConnections() {
statistic.DefaultManager.Range(func(c statistic.Tracker) bool { statistic.DefaultManager.Range(func(c statistic.Tracker) bool {
for _, chain := range c.Chains() { for _, chain := range c.ProviderChains() {
if chain == pp.Name() { if chain == pp.Name() {
_ = c.Close() _ = c.Close()
break break
@@ -330,7 +339,7 @@ func (cp *CompatibleProvider) Close() error {
return cp.compatibleProvider.Close() return cp.compatibleProvider.Close()
} }
func NewProxiesParser(filter string, excludeFilter string, excludeType string, dialerProxy string, override OverrideSchema) (resource.Parser[[]C.Proxy], error) { func NewProxiesParser(pdName string, filter string, excludeFilter string, excludeType string, dialerProxy string, override overrideSchema) (resource.Parser[[]C.Proxy], error) {
var excludeTypeArray []string var excludeTypeArray []string
if excludeType != "" { if excludeType != "" {
excludeTypeArray = strings.Split(excludeType, "|") excludeTypeArray = strings.Split(excludeType, "|")
@@ -419,36 +428,12 @@ func NewProxiesParser(filter string, excludeFilter string, excludeType string, d
mapping["dialer-proxy"] = dialerProxy mapping["dialer-proxy"] = dialerProxy
} }
val := reflect.ValueOf(override) err := override.Apply(mapping)
for i := 0; i < val.NumField(); i++ { if err != nil {
field := val.Field(i) return nil, fmt.Errorf("proxy %d override error: %w", idx, err)
if field.IsNil() {
continue
}
fieldName := strings.Split(val.Type().Field(i).Tag.Get("provider"), ",")[0]
switch fieldName {
case "additional-prefix":
name := mapping["name"].(string)
mapping["name"] = *field.Interface().(*string) + name
case "additional-suffix":
name := mapping["name"].(string)
mapping["name"] = name + *field.Interface().(*string)
case "proxy-name":
// Iterate through all naming replacement rules and perform the replacements.
for _, expr := range override.ProxyName {
name := mapping["name"].(string)
newName, err := expr.Pattern.Replace(name, expr.Target, 0, -1)
if err != nil {
return nil, fmt.Errorf("proxy name replace error: %w", err)
}
mapping["name"] = newName
}
default:
mapping[fieldName] = field.Elem().Interface()
}
} }
proxy, err := adapter.ParseProxy(mapping) proxy, err := adapter.ParseProxy(mapping, adapter.WithProviderName(pdName))
if err != nil { if err != nil {
return nil, fmt.Errorf("proxy %d error: %w", idx, err) return nil, fmt.Errorf("proxy %d error: %w", idx, err)
} }

View File

@@ -1,6 +1,8 @@
// Copyright 2014 The Go Authors. All rights reserved. // Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build android && cgo
// +build android,cgo
// kanged from https://github.com/golang/mobile/blob/c713f31d574bb632a93f169b2cc99c9e753fef0e/app/android.go#L89 // kanged from https://github.com/golang/mobile/blob/c713f31d574bb632a93f169b2cc99c9e753fef0e/app/android.go#L89

View File

@@ -201,6 +201,10 @@ func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
trojan["client-fingerprint"] = fingerprint trojan["client-fingerprint"] = fingerprint
} }
if pcs := query.Get("pcs"); pcs != "" {
trojan["fingerprint"] = pcs
}
proxies = append(proxies, trojan) proxies = append(proxies, trojan)
case "vless": case "vless":

View File

@@ -2,12 +2,12 @@ package convert
import ( import (
"encoding/base64" "encoding/base64"
"net/http"
"strings" "strings"
"time" "time"
"github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/http"
"github.com/metacubex/randv2" "github.com/metacubex/randv2"
"github.com/metacubex/sing-shadowsocks/shadowimpl" "github.com/metacubex/sing-shadowsocks/shadowimpl"
) )

View File

@@ -35,6 +35,9 @@ func handleVShareLink(names map[string]int, url *url.URL, scheme string, proxy m
if alpn := query.Get("alpn"); alpn != "" { if alpn := query.Get("alpn"); alpn != "" {
proxy["alpn"] = strings.Split(alpn, ",") proxy["alpn"] = strings.Split(alpn, ",")
} }
if pcs := query.Get("pcs"); pcs != "" {
proxy["fingerprint"] = pcs
}
} }
if sni := query.Get("sni"); sni != "" { if sni := query.Get("sni"); sni != "" {
proxy["servername"] = sni proxy["servername"] = sni

674
common/deque/deque.go Normal file
View File

@@ -0,0 +1,674 @@
package deque
// copy and modified from https://github.com/gammazero/deque/blob/v1.2.0/deque.go
// which is licensed under MIT.
import (
"fmt"
)
// minCapacity is the smallest capacity that deque may have. Must be power of 2
// for bitwise modulus: x % n == x & (n - 1).
const minCapacity = 8
// Deque represents a single instance of the deque data structure. A Deque
// instance contains items of the type specified by the type argument.
//
// For example, to create a Deque that contains strings do one of the
// following:
//
// var stringDeque deque.Deque[string]
// stringDeque := new(deque.Deque[string])
// stringDeque := &deque.Deque[string]{}
//
// To create a Deque that will never resize to have space for less than 64
// items, specify a base capacity:
//
// var d deque.Deque[int]
// d.SetBaseCap(64)
//
// To ensure the Deque can store 1000 items without needing to resize while
// items are added:
//
// d.Grow(1000)
//
// Any values supplied to [SetBaseCap] and [Grow] are rounded up to the nearest
// power of 2, since the Deque grows by powers of 2.
type Deque[T any] struct {
buf []T
head int
tail int
count int
minCap int
}
// Cap returns the current capacity of the Deque. If q is nil, q.Cap() is zero.
func (q *Deque[T]) Cap() int {
if q == nil {
return 0
}
return len(q.buf)
}
// Len returns the number of elements currently stored in the queue. If q is
// nil, q.Len() returns zero.
func (q *Deque[T]) Len() int {
if q == nil {
return 0
}
return q.count
}
// PushBack appends an element to the back of the queue. Implements FIFO when
// elements are removed with [PopFront], and LIFO when elements are removed with
// [PopBack].
func (q *Deque[T]) PushBack(elem T) {
q.growIfFull()
q.buf[q.tail] = elem
// Calculate new tail position.
q.tail = q.next(q.tail)
q.count++
}
// PushFront prepends an element to the front of the queue.
func (q *Deque[T]) PushFront(elem T) {
q.growIfFull()
// Calculate new head position.
q.head = q.prev(q.head)
q.buf[q.head] = elem
q.count++
}
// PopFront removes and returns the element from the front of the queue.
// Implements FIFO when used with [PushBack]. If the queue is empty, the call
// panics.
func (q *Deque[T]) PopFront() T {
if q.count <= 0 {
panic("deque: PopFront() called on empty queue")
}
ret := q.buf[q.head]
var zero T
q.buf[q.head] = zero
// Calculate new head position.
q.head = q.next(q.head)
q.count--
q.shrinkIfExcess()
return ret
}
// IterPopFront returns an iterator that iteratively removes items from the
// front of the deque. This is more efficient than removing items one at a time
// because it avoids intermediate resizing. If a resize is necessary, only one
// is done when iteration ends.
func (q *Deque[T]) IterPopFront() func(yield func(T) bool) {
return func(yield func(T) bool) {
if q.Len() == 0 {
return
}
var zero T
for q.count != 0 {
ret := q.buf[q.head]
q.buf[q.head] = zero
q.head = q.next(q.head)
q.count--
if !yield(ret) {
break
}
}
q.shrinkToFit()
}
}
// PopBack removes and returns the element from the back of the queue.
// Implements LIFO when used with [PushBack]. If the queue is empty, the call
// panics.
func (q *Deque[T]) PopBack() T {
if q.count <= 0 {
panic("deque: PopBack() called on empty queue")
}
// Calculate new tail position
q.tail = q.prev(q.tail)
// Remove value at tail.
ret := q.buf[q.tail]
var zero T
q.buf[q.tail] = zero
q.count--
q.shrinkIfExcess()
return ret
}
// IterPopBack returns an iterator that iteratively removes items from the back
// of the deque. This is more efficient than removing items one at a time
// because it avoids intermediate resizing. If a resize is necessary, only one
// is done when iteration ends.
func (q *Deque[T]) IterPopBack() func(yield func(T) bool) {
return func(yield func(T) bool) {
if q.Len() == 0 {
return
}
var zero T
for q.count != 0 {
q.tail = q.prev(q.tail)
ret := q.buf[q.tail]
q.buf[q.tail] = zero
q.count--
if !yield(ret) {
break
}
}
q.shrinkToFit()
}
}
// Front returns the element at the front of the queue. This is the element
// that would be returned by [PopFront]. This call panics if the queue is
// empty.
func (q *Deque[T]) Front() T {
if q.count <= 0 {
panic("deque: Front() called when empty")
}
return q.buf[q.head]
}
// Back returns the element at the back of the queue. This is the element that
// would be returned by [PopBack]. This call panics if the queue is empty.
func (q *Deque[T]) Back() T {
if q.count <= 0 {
panic("deque: Back() called when empty")
}
return q.buf[q.prev(q.tail)]
}
// At returns the element at index i in the queue without removing the element
// from the queue. This method accepts only non-negative index values. At(0)
// refers to the first element and is the same as [Front]. At(Len()-1) refers
// to the last element and is the same as [Back]. If the index is invalid, the
// call panics.
//
// The purpose of At is to allow Deque to serve as a more general purpose
// circular buffer, where items are only added to and removed from the ends of
// the deque, but may be read from any place within the deque. Consider the
// case of a fixed-size circular log buffer: A new entry is pushed onto one end
// and when full the oldest is popped from the other end. All the log entries
// in the buffer must be readable without altering the buffer contents.
func (q *Deque[T]) At(i int) T {
q.checkRange(i)
// bitwise modulus
return q.buf[(q.head+i)&(len(q.buf)-1)]
}
// Set assigns the item to index i in the queue. Set indexes the deque the same
// as [At] but perform the opposite operation. If the index is invalid, the call
// panics.
func (q *Deque[T]) Set(i int, item T) {
q.checkRange(i)
// bitwise modulus
q.buf[(q.head+i)&(len(q.buf)-1)] = item
}
// Iter returns a go iterator to range over all items in the Deque, yielding
// each item from front (index 0) to back (index Len()-1). Modification of
// Deque during iteration panics.
func (q *Deque[T]) Iter() func(yield func(T) bool) {
return func(yield func(T) bool) {
origHead := q.head
origTail := q.tail
head := origHead
for i := -0; i < q.Len(); i++ {
if q.head != origHead || q.tail != origTail {
panic("deque: modified during iteration")
}
if !yield(q.buf[head]) {
return
}
head = q.next(head)
}
}
}
// RIter returns a reverse go iterator to range over all items in the Deque,
// yielding each item from back (index Len()-1) to front (index 0).
// Modification of Deque during iteration panics.
func (q *Deque[T]) RIter() func(yield func(T) bool) {
return func(yield func(T) bool) {
origHead := q.head
origTail := q.tail
tail := origTail
for i := -0; i < q.Len(); i++ {
if q.head != origHead || q.tail != origTail {
panic("deque: modified during iteration")
}
tail = q.prev(tail)
if !yield(q.buf[tail]) {
return
}
}
}
}
// Clear removes all elements from the queue, but retains the current capacity.
// This is useful when repeatedly reusing the queue at high frequency to avoid
// GC during reuse. The queue will not be resized smaller as long as items are
// only added. Only when items are removed is the queue subject to getting
// resized smaller.
func (q *Deque[T]) Clear() {
if q.Len() == 0 {
return
}
head, tail := q.head, q.tail
q.count = 0
q.head = 0
q.tail = 0
if head >= tail {
// [DEF....ABC]
clearSlice(q.buf[head:])
head = 0
}
clearSlice(q.buf[head:tail])
}
func clearSlice[S ~[]E, E any](s S) {
var zero E
for i := range s {
s[i] = zero
}
}
// Grow grows deque's capacity, if necessary, to guarantee space for another n
// items. After Grow(n), at least n items can be written to the deque without
// another allocation. If n is negative, Grow panics.
func (q *Deque[T]) Grow(n int) {
if n < 0 {
panic("deque.Grow: negative count")
}
c := q.Cap()
l := q.Len()
// If already big enough.
if n <= c-l {
return
}
if c == 0 {
c = minCapacity
}
newLen := l + n
for c < newLen {
c <<= 1
}
if l == 0 {
q.buf = make([]T, c)
q.head = 0
q.tail = 0
} else {
q.resize(c)
}
}
// Copy copies the contents of the given src Deque into this Deque.
//
// n := b.Copy(a)
//
// is an efficient shortcut for
//
// b.Clear()
// n := a.Len()
// b.Grow(n)
// for i := 0; i < n; i++ {
// b.PushBack(a.At(i))
// }
func (q *Deque[T]) Copy(src Deque[T]) int {
q.Clear()
q.Grow(src.Len())
n := src.CopyOutSlice(q.buf)
q.count = n
q.tail = n
q.head = 0
return n
}
// AppendToSlice appends from the Deque to the given slice. If the slice has
// insufficient capacity to store all elements in Deque, then allocate a new
// slice. Returns the resulting slice.
//
// out = q.AppendToSlice(out)
//
// is an efficient shortcut for
//
// for i := 0; i < q.Len(); i++ {
// x = append(out, q.At(i))
// }
func (q *Deque[T]) AppendToSlice(out []T) []T {
if q.count == 0 {
return out
}
head, tail := q.head, q.tail
if head >= tail {
// [DEF....ABC]
out = append(out, q.buf[head:]...)
head = 0
}
return append(out, q.buf[head:tail]...)
}
// CopyInSlice replaces the contents of Deque with all the elements from the
// given slice, in. If len(in) is zero, then this is equivalent to calling
// [Clear].
//
// q.CopyInSlice(in)
//
// is an efficient shortcut for
//
// q.Clear()
// for i := range in {
// q.PushBack(in[i])
// }
func (q *Deque[T]) CopyInSlice(in []T) {
// Allocate new buffer if more space needed.
if len(q.buf) < len(in) {
newCap := len(q.buf)
if newCap == 0 {
newCap = minCapacity
q.minCap = minCapacity
}
for newCap < len(in) {
newCap <<= 1
}
q.buf = make([]T, newCap)
} else if len(q.buf) > len(in) {
q.Clear()
}
n := copy(q.buf, in)
q.count = n
q.tail = n
q.head = 0
}
// CopyOutSlice copies elements from the Deque into the given slice, up to the
// size of the buffer. Returns the number of elements copied, which will be the
// minimum of q.Len() and len(out).
//
// n := q.CopyOutSlice(out)
//
// is an efficient shortcut for
//
// n := min(len(out), q.Len())
// for i := 0; i < n; i++ {
// out[i] = q.At(i)
// }
//
// This function is preferable to one that returns a copy of the internal
// buffer because this allows reuse of memory receiving data, for repeated copy
// operations.
func (q *Deque[T]) CopyOutSlice(out []T) int {
if q.count == 0 || len(out) == 0 {
return 0
}
head, tail := q.head, q.tail
var n int
if head >= tail {
// [DEF....ABC]
n = copy(out, q.buf[head:])
out = out[n:]
if len(out) == 0 {
return n
}
head = 0
}
n += copy(out, q.buf[head:tail])
return n
}
// Rotate rotates the deque n steps front-to-back. If n is negative, rotates
// back-to-front. Having Deque provide Rotate avoids resizing that could happen
// if implementing rotation using only Pop and Push methods. If q.Len() is one
// or less, or q is nil, then Rotate does nothing.
func (q *Deque[T]) Rotate(n int) {
if q.Len() <= 1 {
return
}
// Rotating a multiple of q.count is same as no rotation.
n %= q.count
if n == 0 {
return
}
modBits := len(q.buf) - 1
// If no empty space in buffer, only move head and tail indexes.
if q.head == q.tail {
// Calculate new head and tail using bitwise modulus.
q.head = (q.head + n) & modBits
q.tail = q.head
return
}
var zero T
if n < 0 {
// Rotate back to front.
for ; n < 0; n++ {
// Calculate new head and tail using bitwise modulus.
q.head = (q.head - 1) & modBits
q.tail = (q.tail - 1) & modBits
// Put tail value at head and remove value at tail.
q.buf[q.head] = q.buf[q.tail]
q.buf[q.tail] = zero
}
return
}
// Rotate front to back.
for ; n > 0; n-- {
// Put head value at tail and remove value at head.
q.buf[q.tail] = q.buf[q.head]
q.buf[q.head] = zero
// Calculate new head and tail using bitwise modulus.
q.head = (q.head + 1) & modBits
q.tail = (q.tail + 1) & modBits
}
}
// Index returns the index into the Deque of the first item satisfying f(item),
// or -1 if none do. If q is nil, then -1 is always returned. Search is linear
// starting with index 0.
func (q *Deque[T]) Index(f func(T) bool) int {
if q.Len() > 0 {
modBits := len(q.buf) - 1
for i := 0; i < q.count; i++ {
if f(q.buf[(q.head+i)&modBits]) {
return i
}
}
}
return -1
}
// RIndex is the same as Index, but searches from Back to Front. The index
// returned is from Front to Back, where index 0 is the index of the item
// returned by [Front].
func (q *Deque[T]) RIndex(f func(T) bool) int {
if q.Len() > 0 {
modBits := len(q.buf) - 1
for i := q.count - 1; i >= 0; i-- {
if f(q.buf[(q.head+i)&modBits]) {
return i
}
}
}
return -1
}
// Insert is used to insert an element into the middle of the queue, before the
// element at the specified index. Insert(0,e) is the same as PushFront(e) and
// Insert(Len(),e) is the same as PushBack(e). Out of range indexes result in
// pushing the item onto the front of back of the deque.
//
// Important: Deque is optimized for O(1) operations at the ends of the queue,
// not for operations in the the middle. Complexity of this function is
// constant plus linear in the lesser of the distances between the index and
// either of the ends of the queue.
func (q *Deque[T]) Insert(at int, item T) {
if at <= 0 {
q.PushFront(item)
return
}
if at >= q.Len() {
q.PushBack(item)
return
}
if at*2 < q.count {
q.PushFront(item)
front := q.head
for i := 0; i < at; i++ {
next := q.next(front)
q.buf[front], q.buf[next] = q.buf[next], q.buf[front]
front = next
}
return
}
swaps := q.count - at
q.PushBack(item)
back := q.prev(q.tail)
for i := 0; i < swaps; i++ {
prev := q.prev(back)
q.buf[back], q.buf[prev] = q.buf[prev], q.buf[back]
back = prev
}
}
// Remove removes and returns an element from the middle of the queue, at the
// specified index. Remove(0) is the same as [PopFront] and Remove(Len()-1) is
// the same as [PopBack]. Accepts only non-negative index values, and panics if
// index is out of range.
//
// Important: Deque is optimized for O(1) operations at the ends of the queue,
// not for operations in the the middle. Complexity of this function is
// constant plus linear in the lesser of the distances between the index and
// either of the ends of the queue.
func (q *Deque[T]) Remove(at int) T {
q.checkRange(at)
rm := (q.head + at) & (len(q.buf) - 1)
if at*2 < q.count {
for i := 0; i < at; i++ {
prev := q.prev(rm)
q.buf[prev], q.buf[rm] = q.buf[rm], q.buf[prev]
rm = prev
}
return q.PopFront()
}
swaps := q.count - at - 1
for i := 0; i < swaps; i++ {
next := q.next(rm)
q.buf[rm], q.buf[next] = q.buf[next], q.buf[rm]
rm = next
}
return q.PopBack()
}
// SetBaseCap sets a base capacity so that at least the specified number of
// items can always be stored without resizing.
func (q *Deque[T]) SetBaseCap(baseCap int) {
minCap := minCapacity
for minCap < baseCap {
minCap <<= 1
}
q.minCap = minCap
}
// Swap exchanges the two values at idxA and idxB. It panics if either index is
// out of range.
func (q *Deque[T]) Swap(idxA, idxB int) {
q.checkRange(idxA)
q.checkRange(idxB)
if idxA == idxB {
return
}
realA := (q.head + idxA) & (len(q.buf) - 1)
realB := (q.head + idxB) & (len(q.buf) - 1)
q.buf[realA], q.buf[realB] = q.buf[realB], q.buf[realA]
}
func (q *Deque[T]) checkRange(i int) {
if i < 0 || i >= q.count {
panic(fmt.Sprintf("deque: index out of range %d with length %d", i, q.Len()))
}
}
// prev returns the previous buffer position wrapping around buffer.
func (q *Deque[T]) prev(i int) int {
return (i - 1) & (len(q.buf) - 1) // bitwise modulus
}
// next returns the next buffer position wrapping around buffer.
func (q *Deque[T]) next(i int) int {
return (i + 1) & (len(q.buf) - 1) // bitwise modulus
}
// growIfFull resizes up if the buffer is full.
func (q *Deque[T]) growIfFull() {
if q.count != len(q.buf) {
return
}
if len(q.buf) == 0 {
if q.minCap == 0 {
q.minCap = minCapacity
}
q.buf = make([]T, q.minCap)
return
}
q.resize(q.count << 1)
}
// shrinkIfExcess resize down if the buffer 1/4 full.
func (q *Deque[T]) shrinkIfExcess() {
if len(q.buf) > q.minCap && (q.count<<2) == len(q.buf) {
q.resize(q.count << 1)
}
}
func (q *Deque[T]) shrinkToFit() {
if len(q.buf) > q.minCap && (q.count<<2) <= len(q.buf) {
if q.count == 0 {
q.head = 0
q.tail = 0
q.buf = make([]T, q.minCap)
return
}
c := q.minCap
for c < q.count {
c <<= 1
}
q.resize(c)
}
}
// resize resizes the deque to fit exactly twice its current contents. This is
// used to grow the queue when it is full, and also to shrink it when it is
// only a quarter full.
func (q *Deque[T]) resize(newSize int) {
newBuf := make([]T, newSize)
if q.tail > q.head {
copy(newBuf, q.buf[q.head:q.tail])
} else {
n := copy(newBuf, q.buf[q.head:])
copy(newBuf[n:], q.buf[:q.tail])
}
q.head = 0
q.tail = q.count
q.buf = newBuf
}

View File

@@ -1,6 +1,8 @@
package net package net
import ( import (
"crypto/sha1"
"encoding/base64"
"encoding/binary" "encoding/binary"
"math/bits" "math/bits"
) )
@@ -129,3 +131,13 @@ func MaskWebSocket(key uint32, b []byte) uint32 {
return key return key
} }
func GetWebSocketSecAccept(secKey string) string {
const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
const nonceSize = 24 // base64.StdEncoding.EncodedLen(nonceKeySize)
p := make([]byte, nonceSize+len(magic))
copy(p[:nonceSize], secKey)
copy(p[nonceSize:], magic)
sum := sha1.Sum(p)
return base64.StdEncoding.EncodeToString(sum[:])
}

8
common/orderedmap/doc.go Normal file
View File

@@ -0,0 +1,8 @@
package orderedmap
// copy and modified from https://github.com/wk8/go-ordered-map/tree/v2.1.8
// which is licensed under Apache v2.
//
// mihomo modified:
// 1. remove dependence of mailru/easyjson for MarshalJSON
// 2. remove dependence of buger/jsonparser for UnmarshalJSON

139
common/orderedmap/json.go Normal file
View File

@@ -0,0 +1,139 @@
package orderedmap
import (
"bytes"
"encoding"
"encoding/json"
"errors"
"fmt"
"reflect"
)
var (
_ json.Marshaler = &OrderedMap[int, any]{}
_ json.Unmarshaler = &OrderedMap[int, any]{}
)
// MarshalJSON implements the json.Marshaler interface.
func (om *OrderedMap[K, V]) MarshalJSON() ([]byte, error) { //nolint:funlen
if om == nil || om.list == nil {
return []byte("null"), nil
}
var buf bytes.Buffer
buf.WriteByte('{')
enc := json.NewEncoder(&buf)
for pair, firstIteration := om.Oldest(), true; pair != nil; pair = pair.Next() {
if firstIteration {
firstIteration = false
} else {
buf.WriteByte(',')
}
switch key := any(pair.Key).(type) {
case string, encoding.TextMarshaler:
if err := enc.Encode(pair.Key); err != nil {
return nil, err
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
buf.WriteByte('"')
buf.WriteString(fmt.Sprint(key))
buf.WriteByte('"')
default:
// this switch takes care of wrapper types around primitive types, such as
// type myType string
switch keyValue := reflect.ValueOf(key); keyValue.Type().Kind() {
case reflect.String:
if err := enc.Encode(pair.Key); err != nil {
return nil, err
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
buf.WriteByte('"')
buf.WriteString(fmt.Sprint(key))
buf.WriteByte('"')
default:
return nil, fmt.Errorf("unsupported key type: %T", key)
}
}
buf.WriteByte(':')
if err := enc.Encode(pair.Value); err != nil {
return nil, err
}
}
buf.WriteByte('}')
return buf.Bytes(), nil
}
// UnmarshalJSON implements the json.Unmarshaler interface.
func (om *OrderedMap[K, V]) UnmarshalJSON(data []byte) error {
if om.list == nil {
om.initialize(0)
}
d := json.NewDecoder(bytes.NewReader(data))
tok, err := d.Token()
if err != nil {
return err
}
if tok != json.Delim('{') {
return errors.New("expect JSON object open with '{'")
}
for d.More() {
// key
tok, err = d.Token()
if err != nil {
return err
}
keyStr, ok := tok.(string)
if !ok {
return fmt.Errorf("key must be a string, got %T\n", tok)
}
var key K
switch typedKey := any(&key).(type) {
case *string:
*typedKey = keyStr
case encoding.TextUnmarshaler:
err = typedKey.UnmarshalText([]byte(keyStr))
case *int, *int8, *int16, *int32, *int64, *uint, *uint8, *uint16, *uint32, *uint64:
err = json.Unmarshal([]byte(keyStr), typedKey)
default:
// this switch takes care of wrapper types around primitive types, such as
// type myType string
switch reflect.TypeOf(key).Kind() {
case reflect.String:
convertedKeyData := reflect.ValueOf(keyStr).Convert(reflect.TypeOf(key))
reflect.ValueOf(&key).Elem().Set(convertedKeyData)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
err = json.Unmarshal([]byte(keyStr), &key)
default:
err = fmt.Errorf("unsupported key type: %T", key)
}
}
if err != nil {
return err
}
// value
value, _ := om.Get(key)
err = d.Decode(&value)
if err != nil {
return err
}
om.Set(key, value)
}
tok, err = d.Token()
if err != nil {
return err
}
if tok != json.Delim('}') {
return errors.New("expect JSON object close with '}'")
}
return nil
}

View File

@@ -0,0 +1,117 @@
package orderedmap
// Adapted from https://github.com/dvyukov/go-fuzz-corpus/blob/c42c1b2/json/json.go
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func FuzzRoundTripJSON(f *testing.F) {
f.Fuzz(func(t *testing.T, data []byte) {
for _, testCase := range []struct {
name string
constructor func() any
// should be a function that asserts that 2 objects of the type returned by constructor are equal
equalityAssertion func(*testing.T, any, any) bool
}{
{
name: "with a string -> string map",
constructor: func() any { return &OrderedMap[string, string]{} },
equalityAssertion: assertOrderedMapsEqual[string, string],
},
{
name: "with a string -> int map",
constructor: func() any { return &OrderedMap[string, int]{} },
equalityAssertion: assertOrderedMapsEqual[string, int],
},
{
name: "with a string -> any map",
constructor: func() any { return &OrderedMap[string, any]{} },
equalityAssertion: assertOrderedMapsEqual[string, any],
},
{
name: "with a struct with map fields",
constructor: func() any { return new(testFuzzStruct) },
equalityAssertion: assertTestFuzzStructEqual,
},
} {
t.Run(testCase.name, func(t *testing.T) {
v1 := testCase.constructor()
if json.Unmarshal(data, v1) != nil {
return
}
jsonData, err := json.Marshal(v1)
require.NoError(t, err)
v2 := testCase.constructor()
require.NoError(t, json.Unmarshal(jsonData, v2))
if !assert.True(t, testCase.equalityAssertion(t, v1, v2), "failed with input data %q", string(data)) {
// look at that what the standard lib does with regular map, to help with debugging
var m1 map[string]any
require.NoError(t, json.Unmarshal(data, &m1))
mapJsonData, err := json.Marshal(m1)
require.NoError(t, err)
var m2 map[string]any
require.NoError(t, json.Unmarshal(mapJsonData, &m2))
t.Logf("initial data = %s", string(data))
t.Logf("unmarshalled map = %v", m1)
t.Logf("re-marshalled from map = %s", string(mapJsonData))
t.Logf("re-marshalled from test obj = %s", string(jsonData))
t.Logf("re-unmarshalled map = %s", m2)
}
})
}
})
}
// only works for fairly basic maps, that's why it's just in this file
func assertOrderedMapsEqual[K comparable, V any](t *testing.T, v1, v2 any) bool {
om1, ok1 := v1.(*OrderedMap[K, V])
om2, ok2 := v2.(*OrderedMap[K, V])
if !assert.True(t, ok1, "v1 not an orderedmap") ||
!assert.True(t, ok2, "v2 not an orderedmap") {
return false
}
success := assert.Equal(t, om1.Len(), om2.Len(), "om1 and om2 have different lengths: %d vs %d", om1.Len(), om2.Len())
for i, pair1, pair2 := 0, om1.Oldest(), om2.Oldest(); pair1 != nil && pair2 != nil; i, pair1, pair2 = i+1, pair1.Next(), pair2.Next() {
success = assert.Equal(t, pair1.Key, pair2.Key, "different keys at position %d: %v vs %v", i, pair1.Key, pair2.Key) && success
success = assert.Equal(t, pair1.Value, pair2.Value, "different values at position %d: %v vs %v", i, pair1.Value, pair2.Value) && success
}
return success
}
type testFuzzStruct struct {
M1 *OrderedMap[int, any]
M2 *OrderedMap[int, string]
M3 *OrderedMap[string, string]
}
func assertTestFuzzStructEqual(t *testing.T, v1, v2 any) bool {
s1, ok := v1.(*testFuzzStruct)
s2, ok := v2.(*testFuzzStruct)
if !assert.True(t, ok, "v1 not an testFuzzStruct") ||
!assert.True(t, ok, "v2 not an testFuzzStruct") {
return false
}
success := assertOrderedMapsEqual[int, any](t, s1.M1, s2.M1)
success = assertOrderedMapsEqual[int, string](t, s1.M2, s2.M2) && success
success = assertOrderedMapsEqual[string, string](t, s1.M3, s2.M3) && success
return success
}

View File

@@ -0,0 +1,338 @@
package orderedmap
import (
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// to test marshalling TextMarshalers and unmarshalling TextUnmarshalers
type marshallable int
func (m marshallable) MarshalText() ([]byte, error) {
return []byte(fmt.Sprintf("#%d#", m)), nil
}
func (m *marshallable) UnmarshalText(text []byte) error {
if len(text) < 3 {
return errors.New("too short")
}
if text[0] != '#' || text[len(text)-1] != '#' {
return errors.New("missing prefix or suffix")
}
value, err := strconv.Atoi(string(text[1 : len(text)-1]))
if err != nil {
return err
}
*m = marshallable(value)
return nil
}
func TestMarshalJSON(t *testing.T) {
t.Run("int key", func(t *testing.T) {
om := New[int, any]()
om.Set(1, "bar")
om.Set(7, "baz")
om.Set(2, 28)
om.Set(3, 100)
om.Set(4, "baz")
om.Set(5, "28")
om.Set(6, "100")
om.Set(8, "baz")
om.Set(8, "baz")
om.Set(9, "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Quisque auctor augue accumsan mi maximus, quis viverra massa pretium. Phasellus imperdiet sapien a interdum sollicitudin. Duis at commodo lectus, a lacinia sem.")
b, err := json.Marshal(om)
assert.NoError(t, err)
assert.Equal(t, `{"1":"bar","7":"baz","2":28,"3":100,"4":"baz","5":"28","6":"100","8":"baz","9":"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Quisque auctor augue accumsan mi maximus, quis viverra massa pretium. Phasellus imperdiet sapien a interdum sollicitudin. Duis at commodo lectus, a lacinia sem."}`, string(b))
})
t.Run("string key", func(t *testing.T) {
om := New[string, any]()
om.Set("test", "bar")
om.Set("abc", true)
b, err := json.Marshal(om)
assert.NoError(t, err)
assert.Equal(t, `{"test":"bar","abc":true}`, string(b))
})
t.Run("typed string key", func(t *testing.T) {
type myString string
om := New[myString, any]()
om.Set("test", "bar")
om.Set("abc", true)
b, err := json.Marshal(om)
assert.NoError(t, err)
assert.Equal(t, `{"test":"bar","abc":true}`, string(b))
})
t.Run("typed int key", func(t *testing.T) {
type myInt uint32
om := New[myInt, any]()
om.Set(1, "bar")
om.Set(7, "baz")
om.Set(2, 28)
om.Set(3, 100)
om.Set(4, "baz")
b, err := json.Marshal(om)
assert.NoError(t, err)
assert.Equal(t, `{"1":"bar","7":"baz","2":28,"3":100,"4":"baz"}`, string(b))
})
t.Run("TextMarshaller key", func(t *testing.T) {
om := New[marshallable, any]()
om.Set(marshallable(1), "bar")
om.Set(marshallable(28), true)
b, err := json.Marshal(om)
assert.NoError(t, err)
assert.Equal(t, `{"#1#":"bar","#28#":true}`, string(b))
})
t.Run("empty map", func(t *testing.T) {
om := New[string, any]()
b, err := json.Marshal(om)
assert.NoError(t, err)
assert.Equal(t, `{}`, string(b))
})
}
func TestUnmarshallJSON(t *testing.T) {
t.Run("int key", func(t *testing.T) {
data := `{"1":"bar","7":"baz","2":28,"3":100,"4":"baz","5":"28","6":"100","8":"baz"}`
om := New[int, any]()
require.NoError(t, json.Unmarshal([]byte(data), &om))
assertOrderedPairsEqual(t, om,
[]int{1, 7, 2, 3, 4, 5, 6, 8},
[]any{"bar", "baz", float64(28), float64(100), "baz", "28", "100", "baz"})
})
t.Run("string key", func(t *testing.T) {
data := `{"test":"bar","abc":true}`
om := New[string, any]()
require.NoError(t, json.Unmarshal([]byte(data), &om))
assertOrderedPairsEqual(t, om,
[]string{"test", "abc"},
[]any{"bar", true})
})
t.Run("typed string key", func(t *testing.T) {
data := `{"test":"bar","abc":true}`
type myString string
om := New[myString, any]()
require.NoError(t, json.Unmarshal([]byte(data), &om))
assertOrderedPairsEqual(t, om,
[]myString{"test", "abc"},
[]any{"bar", true})
})
t.Run("typed int key", func(t *testing.T) {
data := `{"1":"bar","7":"baz","2":28,"3":100,"4":"baz","5":"28","6":"100","8":"baz"}`
type myInt uint32
om := New[myInt, any]()
require.NoError(t, json.Unmarshal([]byte(data), &om))
assertOrderedPairsEqual(t, om,
[]myInt{1, 7, 2, 3, 4, 5, 6, 8},
[]any{"bar", "baz", float64(28), float64(100), "baz", "28", "100", "baz"})
})
t.Run("TextUnmarshaler key", func(t *testing.T) {
data := `{"#1#":"bar","#28#":true}`
om := New[marshallable, any]()
require.NoError(t, json.Unmarshal([]byte(data), &om))
assertOrderedPairsEqual(t, om,
[]marshallable{1, 28},
[]any{"bar", true})
})
t.Run("when fed with an input that's not an object", func(t *testing.T) {
for _, data := range []string{"true", `["foo"]`, "42", `"foo"`} {
om := New[int, any]()
require.Error(t, json.Unmarshal([]byte(data), &om))
}
})
t.Run("empty map", func(t *testing.T) {
data := `{}`
om := New[int, any]()
require.NoError(t, json.Unmarshal([]byte(data), &om))
assertLenEqual(t, om, 0)
})
}
// const specialCharacters = "\\\\/\"\b\f\n\r\t\x00\uffff\ufffd世界\u007f\u00ff\U0010FFFF"
const specialCharacters = "\uffff\ufffd世界\u007f\u00ff\U0010FFFF"
func TestJSONSpecialCharacters(t *testing.T) {
baselineMap := map[string]any{specialCharacters: specialCharacters}
baselineData, err := json.Marshal(baselineMap)
require.NoError(t, err) // baseline proves this key is supported by official json library
t.Logf("specialCharacters: %#v as []rune:%v", specialCharacters, []rune(specialCharacters))
t.Logf("baseline json data: %s", baselineData)
t.Run("marshal special characters", func(t *testing.T) {
om := New[string, any]()
om.Set(specialCharacters, specialCharacters)
b, err := json.Marshal(om)
require.NoError(t, err)
require.Equal(t, baselineData, b)
type myString string
om2 := New[myString, myString]()
om2.Set(specialCharacters, specialCharacters)
b, err = json.Marshal(om2)
require.NoError(t, err)
require.Equal(t, baselineData, b)
})
t.Run("unmarshall special characters", func(t *testing.T) {
om := New[string, any]()
require.NoError(t, json.Unmarshal(baselineData, &om))
assertOrderedPairsEqual(t, om,
[]string{specialCharacters},
[]any{specialCharacters})
type myString string
om2 := New[myString, myString]()
require.NoError(t, json.Unmarshal(baselineData, &om2))
assertOrderedPairsEqual(t, om2,
[]myString{specialCharacters},
[]myString{specialCharacters})
})
}
// to test structs that have nested map fields
type nestedMaps struct {
X int `json:"x" yaml:"x"`
M *OrderedMap[string, []*OrderedMap[int, *OrderedMap[string, any]]] `json:"m" yaml:"m"`
}
func TestJSONRoundTrip(t *testing.T) {
for _, testCase := range []struct {
name string
input string
targetFactory func() any
isPrettyPrinted bool
}{
{
name: "",
input: `{
"x": 28,
"m": {
"foo": [
{
"12": {
"i": 12,
"b": true,
"n": null,
"m": {
"a": "b",
"c": 28
}
},
"28": {
"a": false,
"b": [
1,
2,
3
]
}
},
{
"3": {
"c": null,
"d": 87
},
"4": {
"e": true
},
"5": {
"f": 4,
"g": 5,
"h": 6
}
}
],
"bar": [
{
"5": {
"foo": "bar"
}
}
]
}
}`,
targetFactory: func() any { return &nestedMaps{} },
isPrettyPrinted: true,
},
{
name: "with UTF-8 special chars in key",
input: `{"<22>":0}`,
targetFactory: func() any { return &OrderedMap[string, int]{} },
},
} {
t.Run(testCase.name, func(t *testing.T) {
target := testCase.targetFactory()
require.NoError(t, json.Unmarshal([]byte(testCase.input), target))
var (
out []byte
err error
)
if testCase.isPrettyPrinted {
out, err = json.MarshalIndent(target, "", " ")
} else {
out, err = json.Marshal(target)
}
if assert.NoError(t, err) {
assert.Equal(t, strings.TrimSpace(testCase.input), string(out))
}
})
}
}
func BenchmarkMarshalJSON(b *testing.B) {
om := New[int, any]()
om.Set(1, "bar")
om.Set(7, "baz")
om.Set(2, 28)
om.Set(3, 100)
om.Set(4, "baz")
om.Set(5, "28")
om.Set(6, "100")
om.Set(8, "baz")
om.Set(8, "baz")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = json.Marshal(om)
}
}

View File

@@ -0,0 +1,295 @@
// Package orderedmap implements an ordered map, i.e. a map that also keeps track of
// the order in which keys were inserted.
//
// All operations are constant-time.
//
// Github repo: https://github.com/wk8/go-ordered-map
package orderedmap
import (
"fmt"
list "github.com/bahlo/generic-list-go"
)
type Pair[K comparable, V any] struct {
Key K
Value V
element *list.Element[*Pair[K, V]]
}
type OrderedMap[K comparable, V any] struct {
pairs map[K]*Pair[K, V]
list *list.List[*Pair[K, V]]
}
type initConfig[K comparable, V any] struct {
capacity int
initialData []Pair[K, V]
}
type InitOption[K comparable, V any] func(config *initConfig[K, V])
// WithCapacity allows giving a capacity hint for the map, akin to the standard make(map[K]V, capacity).
func WithCapacity[K comparable, V any](capacity int) InitOption[K, V] {
return func(c *initConfig[K, V]) {
c.capacity = capacity
}
}
// WithInitialData allows passing in initial data for the map.
func WithInitialData[K comparable, V any](initialData ...Pair[K, V]) InitOption[K, V] {
return func(c *initConfig[K, V]) {
c.initialData = initialData
if c.capacity < len(initialData) {
c.capacity = len(initialData)
}
}
}
// New creates a new OrderedMap.
// options can either be one or several InitOption[K, V], or a single integer,
// which is then interpreted as a capacity hint, à la make(map[K]V, capacity).
func New[K comparable, V any](options ...any) *OrderedMap[K, V] { //nolint:varnamelen
orderedMap := &OrderedMap[K, V]{}
var config initConfig[K, V]
for _, untypedOption := range options {
switch option := untypedOption.(type) {
case int:
if len(options) != 1 {
invalidOption()
}
config.capacity = option
case InitOption[K, V]:
option(&config)
default:
invalidOption()
}
}
orderedMap.initialize(config.capacity)
orderedMap.AddPairs(config.initialData...)
return orderedMap
}
const invalidOptionMessage = `when using orderedmap.New[K,V]() with options, either provide one or several InitOption[K, V]; or a single integer which is then interpreted as a capacity hint, à la make(map[K]V, capacity).` //nolint:lll
func invalidOption() { panic(invalidOptionMessage) }
func (om *OrderedMap[K, V]) initialize(capacity int) {
om.pairs = make(map[K]*Pair[K, V], capacity)
om.list = list.New[*Pair[K, V]]()
}
// Get looks for the given key, and returns the value associated with it,
// or V's nil value if not found. The boolean it returns says whether the key is present in the map.
func (om *OrderedMap[K, V]) Get(key K) (val V, present bool) {
if pair, present := om.pairs[key]; present {
return pair.Value, true
}
return
}
// Load is an alias for Get, mostly to present an API similar to `sync.Map`'s.
func (om *OrderedMap[K, V]) Load(key K) (V, bool) {
return om.Get(key)
}
// Value returns the value associated with the given key or the zero value.
func (om *OrderedMap[K, V]) Value(key K) (val V) {
if pair, present := om.pairs[key]; present {
val = pair.Value
}
return
}
// GetPair looks for the given key, and returns the pair associated with it,
// or nil if not found. The Pair struct can then be used to iterate over the ordered map
// from that point, either forward or backward.
func (om *OrderedMap[K, V]) GetPair(key K) *Pair[K, V] {
return om.pairs[key]
}
// Set sets the key-value pair, and returns what `Get` would have returned
// on that key prior to the call to `Set`.
func (om *OrderedMap[K, V]) Set(key K, value V) (val V, present bool) {
if pair, present := om.pairs[key]; present {
oldValue := pair.Value
pair.Value = value
return oldValue, true
}
pair := &Pair[K, V]{
Key: key,
Value: value,
}
pair.element = om.list.PushBack(pair)
om.pairs[key] = pair
return
}
// AddPairs allows setting multiple pairs at a time. It's equivalent to calling
// Set on each pair sequentially.
func (om *OrderedMap[K, V]) AddPairs(pairs ...Pair[K, V]) {
for _, pair := range pairs {
om.Set(pair.Key, pair.Value)
}
}
// Store is an alias for Set, mostly to present an API similar to `sync.Map`'s.
func (om *OrderedMap[K, V]) Store(key K, value V) (V, bool) {
return om.Set(key, value)
}
// Delete removes the key-value pair, and returns what `Get` would have returned
// on that key prior to the call to `Delete`.
func (om *OrderedMap[K, V]) Delete(key K) (val V, present bool) {
if pair, present := om.pairs[key]; present {
om.list.Remove(pair.element)
delete(om.pairs, key)
return pair.Value, true
}
return
}
// Len returns the length of the ordered map.
func (om *OrderedMap[K, V]) Len() int {
if om == nil || om.pairs == nil {
return 0
}
return len(om.pairs)
}
// Oldest returns a pointer to the oldest pair. It's meant to be used to iterate on the ordered map's
// pairs from the oldest to the newest, e.g.:
// for pair := orderedMap.Oldest(); pair != nil; pair = pair.Next() { fmt.Printf("%v => %v\n", pair.Key, pair.Value) }
func (om *OrderedMap[K, V]) Oldest() *Pair[K, V] {
if om == nil || om.list == nil {
return nil
}
return listElementToPair(om.list.Front())
}
// Newest returns a pointer to the newest pair. It's meant to be used to iterate on the ordered map's
// pairs from the newest to the oldest, e.g.:
// for pair := orderedMap.Oldest(); pair != nil; pair = pair.Next() { fmt.Printf("%v => %v\n", pair.Key, pair.Value) }
func (om *OrderedMap[K, V]) Newest() *Pair[K, V] {
if om == nil || om.list == nil {
return nil
}
return listElementToPair(om.list.Back())
}
// Next returns a pointer to the next pair.
func (p *Pair[K, V]) Next() *Pair[K, V] {
return listElementToPair(p.element.Next())
}
// Prev returns a pointer to the previous pair.
func (p *Pair[K, V]) Prev() *Pair[K, V] {
return listElementToPair(p.element.Prev())
}
func listElementToPair[K comparable, V any](element *list.Element[*Pair[K, V]]) *Pair[K, V] {
if element == nil {
return nil
}
return element.Value
}
// KeyNotFoundError may be returned by functions in this package when they're called with keys that are not present
// in the map.
type KeyNotFoundError[K comparable] struct {
MissingKey K
}
func (e *KeyNotFoundError[K]) Error() string {
return fmt.Sprintf("missing key: %v", e.MissingKey)
}
// MoveAfter moves the value associated with key to its new position after the one associated with markKey.
// Returns an error iff key or markKey are not present in the map. If an error is returned,
// it will be a KeyNotFoundError.
func (om *OrderedMap[K, V]) MoveAfter(key, markKey K) error {
elements, err := om.getElements(key, markKey)
if err != nil {
return err
}
om.list.MoveAfter(elements[0], elements[1])
return nil
}
// MoveBefore moves the value associated with key to its new position before the one associated with markKey.
// Returns an error iff key or markKey are not present in the map. If an error is returned,
// it will be a KeyNotFoundError.
func (om *OrderedMap[K, V]) MoveBefore(key, markKey K) error {
elements, err := om.getElements(key, markKey)
if err != nil {
return err
}
om.list.MoveBefore(elements[0], elements[1])
return nil
}
func (om *OrderedMap[K, V]) getElements(keys ...K) ([]*list.Element[*Pair[K, V]], error) {
elements := make([]*list.Element[*Pair[K, V]], len(keys))
for i, k := range keys {
pair, present := om.pairs[k]
if !present {
return nil, &KeyNotFoundError[K]{k}
}
elements[i] = pair.element
}
return elements, nil
}
// MoveToBack moves the value associated with key to the back of the ordered map,
// i.e. makes it the newest pair in the map.
// Returns an error iff key is not present in the map. If an error is returned,
// it will be a KeyNotFoundError.
func (om *OrderedMap[K, V]) MoveToBack(key K) error {
_, err := om.GetAndMoveToBack(key)
return err
}
// MoveToFront moves the value associated with key to the front of the ordered map,
// i.e. makes it the oldest pair in the map.
// Returns an error iff key is not present in the map. If an error is returned,
// it will be a KeyNotFoundError.
func (om *OrderedMap[K, V]) MoveToFront(key K) error {
_, err := om.GetAndMoveToFront(key)
return err
}
// GetAndMoveToBack combines Get and MoveToBack in the same call. If an error is returned,
// it will be a KeyNotFoundError.
func (om *OrderedMap[K, V]) GetAndMoveToBack(key K) (val V, err error) {
if pair, present := om.pairs[key]; present {
val = pair.Value
om.list.MoveToBack(pair.element)
} else {
err = &KeyNotFoundError[K]{key}
}
return
}
// GetAndMoveToFront combines Get and MoveToFront in the same call. If an error is returned,
// it will be a KeyNotFoundError.
func (om *OrderedMap[K, V]) GetAndMoveToFront(key K) (val V, err error) {
if pair, present := om.pairs[key]; present {
val = pair.Value
om.list.MoveToFront(pair.element)
} else {
err = &KeyNotFoundError[K]{key}
}
return
}

View File

@@ -0,0 +1,384 @@
package orderedmap
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestBasicFeatures(t *testing.T) {
n := 100
om := New[int, int]()
// set(i, 2 * i)
for i := 0; i < n; i++ {
assertLenEqual(t, om, i)
oldValue, present := om.Set(i, 2*i)
assertLenEqual(t, om, i+1)
assert.Equal(t, 0, oldValue)
assert.False(t, present)
}
// get what we just set
for i := 0; i < n; i++ {
value, present := om.Get(i)
assert.Equal(t, 2*i, value)
assert.Equal(t, value, om.Value(i))
assert.True(t, present)
}
// get pairs of what we just set
for i := 0; i < n; i++ {
pair := om.GetPair(i)
assert.NotNil(t, pair)
assert.Equal(t, 2*i, pair.Value)
}
// forward iteration
i := 0
for pair := om.Oldest(); pair != nil; pair = pair.Next() {
assert.Equal(t, i, pair.Key)
assert.Equal(t, 2*i, pair.Value)
i++
}
// backward iteration
i = n - 1
for pair := om.Newest(); pair != nil; pair = pair.Prev() {
assert.Equal(t, i, pair.Key)
assert.Equal(t, 2*i, pair.Value)
i--
}
// forward iteration starting from known key
i = 42
for pair := om.GetPair(i); pair != nil; pair = pair.Next() {
assert.Equal(t, i, pair.Key)
assert.Equal(t, 2*i, pair.Value)
i++
}
// double values for pairs with even keys
for j := 0; j < n/2; j++ {
i = 2 * j
oldValue, present := om.Set(i, 4*i)
assert.Equal(t, 2*i, oldValue)
assert.True(t, present)
}
// and delete pairs with odd keys
for j := 0; j < n/2; j++ {
i = 2*j + 1
assertLenEqual(t, om, n-j)
value, present := om.Delete(i)
assertLenEqual(t, om, n-j-1)
assert.Equal(t, 2*i, value)
assert.True(t, present)
// deleting again shouldn't change anything
value, present = om.Delete(i)
assertLenEqual(t, om, n-j-1)
assert.Equal(t, 0, value)
assert.False(t, present)
}
// get the whole range
for j := 0; j < n/2; j++ {
i = 2 * j
value, present := om.Get(i)
assert.Equal(t, 4*i, value)
assert.Equal(t, value, om.Value(i))
assert.True(t, present)
i = 2*j + 1
value, present = om.Get(i)
assert.Equal(t, 0, value)
assert.Equal(t, value, om.Value(i))
assert.False(t, present)
}
// check iterations again
i = 0
for pair := om.Oldest(); pair != nil; pair = pair.Next() {
assert.Equal(t, i, pair.Key)
assert.Equal(t, 4*i, pair.Value)
i += 2
}
i = 2 * ((n - 1) / 2)
for pair := om.Newest(); pair != nil; pair = pair.Prev() {
assert.Equal(t, i, pair.Key)
assert.Equal(t, 4*i, pair.Value)
i -= 2
}
}
func TestUpdatingDoesntChangePairsOrder(t *testing.T) {
om := New[string, any]()
om.Set("foo", "bar")
om.Set("wk", 28)
om.Set("po", 100)
om.Set("bar", "baz")
oldValue, present := om.Set("po", 102)
assert.Equal(t, 100, oldValue)
assert.True(t, present)
assertOrderedPairsEqual(t, om,
[]string{"foo", "wk", "po", "bar"},
[]any{"bar", 28, 102, "baz"})
}
func TestDeletingAndReinsertingChangesPairsOrder(t *testing.T) {
om := New[string, any]()
om.Set("foo", "bar")
om.Set("wk", 28)
om.Set("po", 100)
om.Set("bar", "baz")
// delete a pair
oldValue, present := om.Delete("po")
assert.Equal(t, 100, oldValue)
assert.True(t, present)
// re-insert the same pair
oldValue, present = om.Set("po", 100)
assert.Nil(t, oldValue)
assert.False(t, present)
assertOrderedPairsEqual(t, om,
[]string{"foo", "wk", "bar", "po"},
[]any{"bar", 28, "baz", 100})
}
func TestEmptyMapOperations(t *testing.T) {
om := New[string, any]()
oldValue, present := om.Get("foo")
assert.Nil(t, oldValue)
assert.Nil(t, om.Value("foo"))
assert.False(t, present)
oldValue, present = om.Delete("bar")
assert.Nil(t, oldValue)
assert.False(t, present)
assertLenEqual(t, om, 0)
assert.Nil(t, om.Oldest())
assert.Nil(t, om.Newest())
}
type dummyTestStruct struct {
value string
}
func TestPackUnpackStructs(t *testing.T) {
om := New[string, dummyTestStruct]()
om.Set("foo", dummyTestStruct{"foo!"})
om.Set("bar", dummyTestStruct{"bar!"})
value, present := om.Get("foo")
assert.True(t, present)
assert.Equal(t, value, om.Value("foo"))
if assert.NotNil(t, value) {
assert.Equal(t, "foo!", value.value)
}
value, present = om.Set("bar", dummyTestStruct{"baz!"})
assert.True(t, present)
if assert.NotNil(t, value) {
assert.Equal(t, "bar!", value.value)
}
value, present = om.Get("bar")
assert.Equal(t, value, om.Value("bar"))
assert.True(t, present)
if assert.NotNil(t, value) {
assert.Equal(t, "baz!", value.value)
}
}
// shamelessly stolen from https://github.com/python/cpython/blob/e19a91e45fd54a56e39c2d12e6aaf4757030507f/Lib/test/test_ordered_dict.py#L55-L61
func TestShuffle(t *testing.T) {
ranLen := 100
for _, n := range []int{0, 10, 20, 100, 1000, 10000} {
t.Run(fmt.Sprintf("shuffle test with %d items", n), func(t *testing.T) {
om := New[string, string]()
keys := make([]string, n)
values := make([]string, n)
for i := 0; i < n; i++ {
// we prefix with the number to ensure that we don't get any duplicates
keys[i] = fmt.Sprintf("%d_%s", i, randomHexString(t, ranLen))
values[i] = randomHexString(t, ranLen)
value, present := om.Set(keys[i], values[i])
assert.Equal(t, "", value)
assert.False(t, present)
}
assertOrderedPairsEqual(t, om, keys, values)
})
}
}
func TestMove(t *testing.T) {
om := New[int, any]()
om.Set(1, "bar")
om.Set(2, 28)
om.Set(3, 100)
om.Set(4, "baz")
om.Set(5, "28")
om.Set(6, "100")
om.Set(7, "baz")
om.Set(8, "baz")
err := om.MoveAfter(2, 3)
assert.Nil(t, err)
assertOrderedPairsEqual(t, om,
[]int{1, 3, 2, 4, 5, 6, 7, 8},
[]any{"bar", 100, 28, "baz", "28", "100", "baz", "baz"})
err = om.MoveBefore(6, 4)
assert.Nil(t, err)
assertOrderedPairsEqual(t, om,
[]int{1, 3, 2, 6, 4, 5, 7, 8},
[]any{"bar", 100, 28, "100", "baz", "28", "baz", "baz"})
err = om.MoveToBack(3)
assert.Nil(t, err)
assertOrderedPairsEqual(t, om,
[]int{1, 2, 6, 4, 5, 7, 8, 3},
[]any{"bar", 28, "100", "baz", "28", "baz", "baz", 100})
err = om.MoveToFront(5)
assert.Nil(t, err)
assertOrderedPairsEqual(t, om,
[]int{5, 1, 2, 6, 4, 7, 8, 3},
[]any{"28", "bar", 28, "100", "baz", "baz", "baz", 100})
err = om.MoveToFront(100)
assert.Equal(t, &KeyNotFoundError[int]{100}, err)
}
func TestGetAndMove(t *testing.T) {
om := New[int, any]()
om.Set(1, "bar")
om.Set(2, 28)
om.Set(3, 100)
om.Set(4, "baz")
om.Set(5, "28")
om.Set(6, "100")
om.Set(7, "baz")
om.Set(8, "baz")
value, err := om.GetAndMoveToBack(3)
assert.Nil(t, err)
assert.Equal(t, 100, value)
assertOrderedPairsEqual(t, om,
[]int{1, 2, 4, 5, 6, 7, 8, 3},
[]any{"bar", 28, "baz", "28", "100", "baz", "baz", 100})
value, err = om.GetAndMoveToFront(5)
assert.Nil(t, err)
assert.Equal(t, "28", value)
assertOrderedPairsEqual(t, om,
[]int{5, 1, 2, 4, 6, 7, 8, 3},
[]any{"28", "bar", 28, "baz", "100", "baz", "baz", 100})
value, err = om.GetAndMoveToBack(100)
assert.Equal(t, &KeyNotFoundError[int]{100}, err)
}
func TestAddPairs(t *testing.T) {
om := New[int, any]()
om.AddPairs(
Pair[int, any]{
Key: 28,
Value: "foo",
},
Pair[int, any]{
Key: 12,
Value: "bar",
},
Pair[int, any]{
Key: 28,
Value: "baz",
},
)
assertOrderedPairsEqual(t, om,
[]int{28, 12},
[]any{"baz", "bar"})
}
// sadly, we can't test the "actual" capacity here, see https://github.com/golang/go/issues/52157
func TestNewWithCapacity(t *testing.T) {
zero := New[int, string](0)
assert.Empty(t, zero.Len())
assert.PanicsWithValue(t, invalidOptionMessage, func() {
_ = New[int, string](1, 2)
})
assert.PanicsWithValue(t, invalidOptionMessage, func() {
_ = New[int, string](1, 2, 3)
})
om := New[int, string](-1)
om.Set(1337, "quarante-deux")
assert.Equal(t, 1, om.Len())
}
func TestNewWithOptions(t *testing.T) {
t.Run("wih capacity", func(t *testing.T) {
om := New[string, any](WithCapacity[string, any](98))
assert.Equal(t, 0, om.Len())
})
t.Run("with initial data", func(t *testing.T) {
om := New[string, int](WithInitialData(
Pair[string, int]{
Key: "a",
Value: 1,
},
Pair[string, int]{
Key: "b",
Value: 2,
},
Pair[string, int]{
Key: "c",
Value: 3,
},
))
assertOrderedPairsEqual(t, om,
[]string{"a", "b", "c"},
[]int{1, 2, 3})
})
t.Run("with an invalid option type", func(t *testing.T) {
assert.PanicsWithValue(t, invalidOptionMessage, func() {
_ = New[int, string]("foo")
})
})
}
func TestNilMap(t *testing.T) {
// we want certain behaviors of a nil ordered map to be the same as they are for standard nil maps
var om *OrderedMap[int, any]
t.Run("len", func(t *testing.T) {
assert.Equal(t, 0, om.Len())
})
t.Run("iterating - akin to range", func(t *testing.T) {
assert.Nil(t, om.Oldest())
assert.Nil(t, om.Newest())
})
}

View File

@@ -0,0 +1,76 @@
package orderedmap
import (
"crypto/rand"
"encoding/hex"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
// assertOrderedPairsEqual asserts that the map contains the given keys and values
// from oldest to newest.
func assertOrderedPairsEqual[K comparable, V any](
t *testing.T, orderedMap *OrderedMap[K, V], expectedKeys []K, expectedValues []V,
) {
t.Helper()
assertOrderedPairsEqualFromNewest(t, orderedMap, expectedKeys, expectedValues)
assertOrderedPairsEqualFromOldest(t, orderedMap, expectedKeys, expectedValues)
}
func assertOrderedPairsEqualFromNewest[K comparable, V any](
t *testing.T, orderedMap *OrderedMap[K, V], expectedKeys []K, expectedValues []V,
) {
t.Helper()
if assert.Equal(t, len(expectedKeys), len(expectedValues)) && assert.Equal(t, len(expectedKeys), orderedMap.Len()) {
i := orderedMap.Len() - 1
for pair := orderedMap.Newest(); pair != nil; pair = pair.Prev() {
assert.Equal(t, expectedKeys[i], pair.Key, "from newest index=%d on key", i)
assert.Equal(t, expectedValues[i], pair.Value, "from newest index=%d on value", i)
i--
}
}
}
func assertOrderedPairsEqualFromOldest[K comparable, V any](
t *testing.T, orderedMap *OrderedMap[K, V], expectedKeys []K, expectedValues []V,
) {
t.Helper()
if assert.Equal(t, len(expectedKeys), len(expectedValues)) && assert.Equal(t, len(expectedKeys), orderedMap.Len()) {
i := 0
for pair := orderedMap.Oldest(); pair != nil; pair = pair.Next() {
assert.Equal(t, expectedKeys[i], pair.Key, "from oldest index=%d on key", i)
assert.Equal(t, expectedValues[i], pair.Value, "from oldest index=%d on value", i)
i++
}
}
}
func assertLenEqual[K comparable, V any](t *testing.T, orderedMap *OrderedMap[K, V], expectedLen int) {
t.Helper()
assert.Equal(t, expectedLen, orderedMap.Len())
// also check the list length, for good measure
assert.Equal(t, expectedLen, orderedMap.list.Len())
}
func randomHexString(t *testing.T, length int) string {
t.Helper()
b := length / 2 //nolint:gomnd
randBytes := make([]byte, b)
if n, err := rand.Read(randBytes); err != nil || n != b {
if err == nil {
err = fmt.Errorf("only got %v random bytes, expected %v", n, b)
}
t.Fatal(err)
}
return hex.EncodeToString(randBytes)
}

71
common/orderedmap/yaml.go Normal file
View File

@@ -0,0 +1,71 @@
package orderedmap
import (
"fmt"
"gopkg.in/yaml.v3"
)
var (
_ yaml.Marshaler = &OrderedMap[int, any]{}
_ yaml.Unmarshaler = &OrderedMap[int, any]{}
)
// MarshalYAML implements the yaml.Marshaler interface.
func (om *OrderedMap[K, V]) MarshalYAML() (interface{}, error) {
if om == nil {
return []byte("null"), nil
}
node := yaml.Node{
Kind: yaml.MappingNode,
}
for pair := om.Oldest(); pair != nil; pair = pair.Next() {
key, value := pair.Key, pair.Value
keyNode := &yaml.Node{}
// serialize key to yaml, then deserialize it back into the node
// this is a hack to get the correct tag for the key
if err := keyNode.Encode(key); err != nil {
return nil, err
}
valueNode := &yaml.Node{}
if err := valueNode.Encode(value); err != nil {
return nil, err
}
node.Content = append(node.Content, keyNode, valueNode)
}
return &node, nil
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (om *OrderedMap[K, V]) UnmarshalYAML(value *yaml.Node) error {
if value.Kind != yaml.MappingNode {
return fmt.Errorf("pipeline must contain YAML mapping, has %v", value.Kind)
}
if om.list == nil {
om.initialize(0)
}
for index := 0; index < len(value.Content); index += 2 {
var key K
var val V
if err := value.Content[index].Decode(&key); err != nil {
return err
}
if err := value.Content[index+1].Decode(&val); err != nil {
return err
}
om.Set(key, val)
}
return nil
}

View File

@@ -0,0 +1,82 @@
package orderedmap
// Adapted from https://github.com/dvyukov/go-fuzz-corpus/blob/c42c1b2/json/json.go
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
)
func FuzzRoundTripYAML(f *testing.F) {
f.Fuzz(func(t *testing.T, data []byte) {
for _, testCase := range []struct {
name string
constructor func() any
// should be a function that asserts that 2 objects of the type returned by constructor are equal
equalityAssertion func(*testing.T, any, any) bool
}{
{
name: "with a string -> string map",
constructor: func() any { return &OrderedMap[string, string]{} },
equalityAssertion: assertOrderedMapsEqual[string, string],
},
{
name: "with a string -> int map",
constructor: func() any { return &OrderedMap[string, int]{} },
equalityAssertion: assertOrderedMapsEqual[string, int],
},
{
name: "with a string -> any map",
constructor: func() any { return &OrderedMap[string, any]{} },
equalityAssertion: assertOrderedMapsEqual[string, any],
},
{
name: "with a struct with map fields",
constructor: func() any { return new(testFuzzStruct) },
equalityAssertion: assertTestFuzzStructEqual,
},
} {
t.Run(testCase.name, func(t *testing.T) {
v1 := testCase.constructor()
if yaml.Unmarshal(data, v1) != nil {
return
}
t.Log(data)
t.Log(v1)
yamlData, err := yaml.Marshal(v1)
require.NoError(t, err)
t.Log(string(yamlData))
v2 := testCase.constructor()
err = yaml.Unmarshal(yamlData, v2)
if err != nil {
t.Log(string(yamlData))
t.Fatal(err)
}
if !assert.True(t, testCase.equalityAssertion(t, v1, v2), "failed with input data %q", string(data)) {
// look at that what the standard lib does with regular map, to help with debugging
var m1 map[string]any
require.NoError(t, yaml.Unmarshal(data, &m1))
mapJsonData, err := yaml.Marshal(m1)
require.NoError(t, err)
var m2 map[string]any
require.NoError(t, yaml.Unmarshal(mapJsonData, &m2))
t.Logf("initial data = %s", string(data))
t.Logf("unmarshalled map = %v", m1)
t.Logf("re-marshalled from map = %s", string(mapJsonData))
t.Logf("re-marshalled from test obj = %s", string(yamlData))
t.Logf("re-unmarshalled map = %s", m2)
}
})
}
})
}

View File

@@ -0,0 +1,334 @@
package orderedmap
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
)
func TestMarshalYAML(t *testing.T) {
t.Run("int key", func(t *testing.T) {
om := New[int, any]()
om.Set(1, "bar")
om.Set(7, "baz")
om.Set(2, 28)
om.Set(3, 100)
om.Set(4, "baz")
om.Set(5, "28")
om.Set(6, "100")
om.Set(8, "baz")
om.Set(8, "baz")
om.Set(9, "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Quisque auctor augue accumsan mi maximus, quis viverra massa pretium. Phasellus imperdiet sapien a interdum sollicitudin. Duis at commodo lectus, a lacinia sem.")
b, err := yaml.Marshal(om)
expected := `1: bar
7: baz
2: 28
3: 100
4: baz
5: "28"
6: "100"
8: baz
9: Lorem ipsum dolor sit amet, consectetur adipiscing elit. Quisque auctor augue accumsan mi maximus, quis viverra massa pretium. Phasellus imperdiet sapien a interdum sollicitudin. Duis at commodo lectus, a lacinia sem.
`
assert.NoError(t, err)
assert.Equal(t, expected, string(b))
})
t.Run("string key", func(t *testing.T) {
om := New[string, any]()
om.Set("test", "bar")
om.Set("abc", true)
b, err := yaml.Marshal(om)
assert.NoError(t, err)
expected := `test: bar
abc: true
`
assert.Equal(t, expected, string(b))
})
t.Run("typed string key", func(t *testing.T) {
type myString string
om := New[myString, any]()
om.Set("test", "bar")
om.Set("abc", true)
b, err := yaml.Marshal(om)
assert.NoError(t, err)
assert.Equal(t, `test: bar
abc: true
`, string(b))
})
t.Run("typed int key", func(t *testing.T) {
type myInt uint32
om := New[myInt, any]()
om.Set(1, "bar")
om.Set(7, "baz")
om.Set(2, 28)
om.Set(3, 100)
om.Set(4, "baz")
b, err := yaml.Marshal(om)
assert.NoError(t, err)
assert.Equal(t, `1: bar
7: baz
2: 28
3: 100
4: baz
`, string(b))
})
t.Run("TextMarshaller key", func(t *testing.T) {
om := New[marshallable, any]()
om.Set(marshallable(1), "bar")
om.Set(marshallable(28), true)
b, err := yaml.Marshal(om)
assert.NoError(t, err)
assert.Equal(t, `'#1#': bar
'#28#': true
`, string(b))
})
t.Run("empty map with 0 elements", func(t *testing.T) {
om := New[string, any]()
b, err := yaml.Marshal(om)
assert.NoError(t, err)
assert.Equal(t, "{}\n", string(b))
})
t.Run("empty map with no elements (null)", func(t *testing.T) {
om := &OrderedMap[string, string]{}
b, err := yaml.Marshal(om)
assert.NoError(t, err)
assert.Equal(t, "{}\n", string(b))
})
}
func TestUnmarshallYAML(t *testing.T) {
t.Run("int key", func(t *testing.T) {
data := `
1: bar
7: baz
2: 28
3: 100
4: baz
5: "28"
6: "100"
8: baz
`
om := New[int, any]()
require.NoError(t, yaml.Unmarshal([]byte(data), &om))
assertOrderedPairsEqual(t, om,
[]int{1, 7, 2, 3, 4, 5, 6, 8},
[]any{"bar", "baz", 28, 100, "baz", "28", "100", "baz"})
// serialize back to yaml to make sure things are equal
})
t.Run("string key", func(t *testing.T) {
data := `{"test":"bar","abc":true}`
om := New[string, any]()
require.NoError(t, yaml.Unmarshal([]byte(data), &om))
assertOrderedPairsEqual(t, om,
[]string{"test", "abc"},
[]any{"bar", true})
})
t.Run("typed string key", func(t *testing.T) {
data := `{"test":"bar","abc":true}`
type myString string
om := New[myString, any]()
require.NoError(t, yaml.Unmarshal([]byte(data), &om))
assertOrderedPairsEqual(t, om,
[]myString{"test", "abc"},
[]any{"bar", true})
})
t.Run("typed int key", func(t *testing.T) {
data := `
1: bar
7: baz
2: 28
3: 100
4: baz
5: "28"
6: "100"
8: baz
`
type myInt uint32
om := New[myInt, any]()
require.NoError(t, yaml.Unmarshal([]byte(data), &om))
assertOrderedPairsEqual(t, om,
[]myInt{1, 7, 2, 3, 4, 5, 6, 8},
[]any{"bar", "baz", 28, 100, "baz", "28", "100", "baz"})
})
t.Run("TextUnmarshaler key", func(t *testing.T) {
data := `{"#1#":"bar","#28#":true}`
om := New[marshallable, any]()
require.NoError(t, yaml.Unmarshal([]byte(data), &om))
assertOrderedPairsEqual(t, om,
[]marshallable{1, 28},
[]any{"bar", true})
})
t.Run("when fed with an input that's not an object", func(t *testing.T) {
for _, data := range []string{"true", `["foo"]`, "42", `"foo"`} {
om := New[int, any]()
require.Error(t, yaml.Unmarshal([]byte(data), &om))
}
})
t.Run("empty map", func(t *testing.T) {
data := `{}`
om := New[int, any]()
require.NoError(t, yaml.Unmarshal([]byte(data), &om))
assertLenEqual(t, om, 0)
})
}
func TestYAMLSpecialCharacters(t *testing.T) {
baselineMap := map[string]any{specialCharacters: specialCharacters}
baselineData, err := yaml.Marshal(baselineMap)
require.NoError(t, err) // baseline proves this key is supported by official yaml library
t.Logf("specialCharacters: %#v as []rune:%v", specialCharacters, []rune(specialCharacters))
t.Logf("baseline yaml data: %s", baselineData)
t.Run("marshal special characters", func(t *testing.T) {
om := New[string, any]()
om.Set(specialCharacters, specialCharacters)
b, err := yaml.Marshal(om)
require.NoError(t, err)
require.Equal(t, baselineData, b)
type myString string
om2 := New[myString, myString]()
om2.Set(specialCharacters, specialCharacters)
b, err = yaml.Marshal(om2)
require.NoError(t, err)
require.Equal(t, baselineData, b)
})
t.Run("unmarshall special characters", func(t *testing.T) {
om := New[string, any]()
require.NoError(t, yaml.Unmarshal(baselineData, &om))
assertOrderedPairsEqual(t, om,
[]string{specialCharacters},
[]any{specialCharacters})
type myString string
om2 := New[myString, myString]()
require.NoError(t, yaml.Unmarshal(baselineData, &om2))
assertOrderedPairsEqual(t, om2,
[]myString{specialCharacters},
[]myString{specialCharacters})
})
}
func TestYAMLRoundTrip(t *testing.T) {
for _, testCase := range []struct {
name string
input string
targetFactory func() any
}{
{
name: "empty map",
input: "{}\n",
targetFactory: func() any {
return &OrderedMap[string, any]{}
},
},
{
name: "",
input: `x: 28
m:
bar:
- 5:
foo: bar
foo:
- 12:
b: true
i: 12
m:
a: b
c: 28
"n": null
28:
a: false
b:
- 1
- 2
- 3
- 3:
c: null
d: 87
4:
e: true
5:
f: 4
g: 5
h: 6
`,
targetFactory: func() any { return &nestedMaps{} },
},
{
name: "with UTF-8 special chars in key",
input: "<22>: 0\n",
targetFactory: func() any { return &OrderedMap[string, int]{} },
},
} {
t.Run(testCase.name, func(t *testing.T) {
target := testCase.targetFactory()
require.NoError(t, yaml.Unmarshal([]byte(testCase.input), target))
var (
out []byte
err error
)
out, err = yaml.Marshal(target)
if assert.NoError(t, err) {
assert.Equal(t, testCase.input, string(out))
}
})
}
}
func BenchmarkMarshalYAML(b *testing.B) {
om := New[int, any]()
om.Set(1, "bar")
om.Set(7, "baz")
om.Set(2, 28)
om.Set(3, 100)
om.Set(4, "baz")
om.Set(5, "28")
om.Set(6, "100")
om.Set(8, "baz")
om.Set(8, "baz")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = yaml.Marshal(om)
}
}

View File

@@ -517,6 +517,10 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
fieldName = tagValue fieldName = tagValue
} }
if tagValue == "-" {
continue
}
rawMapKey := reflect.ValueOf(fieldName) rawMapKey := reflect.ValueOf(fieldName)
rawMapVal := dataVal.MapIndex(rawMapKey) rawMapVal := dataVal.MapIndex(rawMapKey)
if !rawMapVal.IsValid() { if !rawMapVal.IsValid() {

View File

@@ -308,3 +308,27 @@ func TestStructure_Ignore(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, s.MustIgnore, "oldData") assert.Equal(t, s.MustIgnore, "oldData")
} }
func TestStructure_IgnoreInNest(t *testing.T) {
rawMap := map[string]any{
"-": "newData",
}
type TP struct {
MustIgnore string `test:"-"`
}
s := struct {
TP
}{TP{MustIgnore: "oldData"}}
err := decoder.Decode(rawMap, &s)
assert.Nil(t, err)
assert.Equal(t, s.MustIgnore, "oldData")
// test omitempty
delete(rawMap, "-")
err = decoder.Decode(rawMap, &s)
assert.Nil(t, err)
assert.Equal(t, s.MustIgnore, "oldData")
}

14
common/yaml/yaml.go Normal file
View File

@@ -0,0 +1,14 @@
// Package yaml provides a common entrance for YAML marshaling and unmarshalling.
package yaml
import (
"gopkg.in/yaml.v3"
)
func Unmarshal(in []byte, out any) (err error) {
return yaml.Unmarshal(in, out)
}
func Marshal(in any) (out []byte, err error) {
return yaml.Marshal(in)
}

View File

@@ -1,17 +1,17 @@
package tls package ca
import ( import (
utls "github.com/metacubex/utls" "github.com/metacubex/tls"
) )
type ClientAuthType = utls.ClientAuthType type ClientAuthType = tls.ClientAuthType
const ( const (
NoClientCert = utls.NoClientCert NoClientCert = tls.NoClientCert
RequestClientCert = utls.RequestClientCert RequestClientCert = tls.RequestClientCert
RequireAnyClientCert = utls.RequireAnyClientCert RequireAnyClientCert = tls.RequireAnyClientCert
VerifyClientCertIfGiven = utls.VerifyClientCertIfGiven VerifyClientCertIfGiven = tls.VerifyClientCertIfGiven
RequireAndVerifyClientCert = utls.RequireAndVerifyClientCert RequireAndVerifyClientCert = tls.RequireAndVerifyClientCert
) )
func ClientAuthTypeFromString(s string) ClientAuthType { func ClientAuthTypeFromString(s string) ClientAuthType {

View File

@@ -1,7 +1,6 @@
package ca package ca
import ( import (
"crypto/tls"
"crypto/x509" "crypto/x509"
_ "embed" _ "embed"
"errors" "errors"
@@ -11,8 +10,9 @@ import (
"sync" "sync"
"github.com/metacubex/mihomo/common/once" "github.com/metacubex/mihomo/common/once"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/ntp" "github.com/metacubex/mihomo/ntp"
"github.com/metacubex/tls"
) )
var globalCertPool *x509.CertPool var globalCertPool *x509.CertPool
@@ -98,20 +98,27 @@ func GetTLSConfig(opt Option) (tlsConfig *tls.Config, err error) {
} }
if len(opt.Fingerprint) > 0 { if len(opt.Fingerprint) > 0 {
tlsConfig.VerifyPeerCertificate, err = NewFingerprintVerifier(opt.Fingerprint, tlsConfig.Time) verifier, err := NewFingerprintVerifier(opt.Fingerprint, tlsConfig.Time)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tlsConfig.VerifyConnection = func(state tls.ConnectionState) error {
// [ConnectionState.ServerName] can return the actual ServerName needed for verification,
// avoiding inconsistencies caused by [tlsConfig.ServerName] being modified after the [NewFingerprintVerifier] call.
// https://github.com/golang/go/issues/36736#issuecomment-587925536
return verifier(state.PeerCertificates, state.ServerName)
}
tlsConfig.InsecureSkipVerify = true tlsConfig.InsecureSkipVerify = true
} }
if len(opt.Certificate) > 0 || len(opt.PrivateKey) > 0 { if len(opt.Certificate) > 0 || len(opt.PrivateKey) > 0 {
var cert tls.Certificate certLoader, err := NewTLSKeyPairLoader(opt.Certificate, opt.PrivateKey)
cert, err = LoadTLSKeyPair(opt.Certificate, opt.PrivateKey, C.Path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tlsConfig.Certificates = []tls.Certificate{cert} tlsConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
return certLoader()
}
} }
return tlsConfig, nil return tlsConfig, nil
} }

View File

@@ -11,7 +11,7 @@ import (
) )
// NewFingerprintVerifier returns a function that verifies whether a certificate's SHA-256 fingerprint matches the given one. // NewFingerprintVerifier returns a function that verifies whether a certificate's SHA-256 fingerprint matches the given one.
func NewFingerprintVerifier(fingerprint string, time func() time.Time) (func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error, error) { func NewFingerprintVerifier(fingerprint string, time func() time.Time) (func(certs []*x509.Certificate, serverName string) error, error) {
switch fingerprint { switch fingerprint {
case "chrome", "firefox", "safari", "ios", "android", "edge", "360", "qq", "random", "randomized": // WTF??? case "chrome", "firefox", "safari", "ios", "android", "edge", "360", "qq", "random", "randomized": // WTF???
return nil, fmt.Errorf("`fingerprint` is used for TLS certificate pinning. If you need to specify the browser fingerprint, use `client-fingerprint`") return nil, fmt.Errorf("`fingerprint` is used for TLS certificate pinning. If you need to specify the browser fingerprint, use `client-fingerprint`")
@@ -26,37 +26,24 @@ func NewFingerprintVerifier(fingerprint string, time func() time.Time) (func(raw
return nil, fmt.Errorf("fingerprint string length error,need sha256 fingerprint") return nil, fmt.Errorf("fingerprint string length error,need sha256 fingerprint")
} }
return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { return func(certs []*x509.Certificate, serverName string) error {
// ssl pining // ssl pining
for i, rawCert := range rawCerts { for i, cert := range certs {
hash := sha256.Sum256(rawCert) hash := sha256.Sum256(cert.Raw)
if bytes.Equal(fpByte, hash[:]) { if bytes.Equal(fpByte, hash[:]) {
if i > 0 { if i > 0 {
// When the fingerprint matches a non-leaf certificate, // When the fingerprint matches a non-leaf certificate,
// the certificate chain validity is verified using the certificate as the trusted root certificate. // the certificate chain validity is verified using the certificate as the trusted root certificate.
//
// Currently, we do not verify that the SNI matches the certificate's DNS name,
// but we do verify the validity of the child certificate,
// including the issuance time and whether the child certificate was issued by the parent certificate.
certs := make([]*x509.Certificate, i+1) // stop at i
for j := range certs {
cert, err := x509.ParseCertificate(rawCerts[j])
if err != nil {
return err
}
certs[j] = cert
}
opts := x509.VerifyOptions{ opts := x509.VerifyOptions{
Roots: x509.NewCertPool(), Roots: x509.NewCertPool(),
Intermediates: x509.NewCertPool(), Intermediates: x509.NewCertPool(),
DNSName: serverName,
} }
if time != nil { if time != nil {
opts.CurrentTime = time() opts.CurrentTime = time()
} }
opts.Roots.AddCert(certs[i]) opts.Roots.AddCert(certs[i])
for _, cert := range certs[1:] { for _, cert := range certs[1 : i+1] { // stop at i
opts.Intermediates.AddCert(cert) opts.Intermediates.AddCert(cert)
} }
_, err := certs[0].Verify(opts) _, err := certs[0].Verify(opts)

View File

@@ -1,6 +1,7 @@
package ca package ca
import ( import (
"crypto/x509"
"encoding/pem" "encoding/pem"
"testing" "testing"
"time" "time"
@@ -10,90 +11,203 @@ import (
) )
func TestFingerprintVerifierLeaf(t *testing.T) { func TestFingerprintVerifierLeaf(t *testing.T) {
leafFingerprint := CalculateFingerprint(leafPEM.Bytes) leafFingerprint := CalculateFingerprint(leafCert.Raw)
verifier, err := NewFingerprintVerifier(leafFingerprint, func() time.Time { verifier, err := NewFingerprintVerifier(leafFingerprint, certTime)
return time.Unix(1677615892, 0)
})
require.NoError(t, err) require.NoError(t, err)
err = verifier([][]byte{leafPEM.Bytes, intermediatePEM.Bytes, rootPEM.Bytes}, nil) err = verifier([]*x509.Certificate{leafCert, intermediateCert, rootCert}, "")
assert.NoError(t, err) assert.NoError(t, err)
err = verifier([][]byte{leafPEM.Bytes, smimeIntermediatePEM.Bytes, rootPEM.Bytes}, nil) err = verifier([]*x509.Certificate{leafCert, intermediateCert, rootCert}, leafServerName)
assert.NoError(t, err) assert.NoError(t, err)
err = verifier([][]byte{leafPEM.Bytes, intermediatePEM.Bytes, smimeRootPEM.Bytes}, nil) err = verifier([]*x509.Certificate{leafCert, intermediateCert, rootCert}, wrongLeafServerName)
assert.NoError(t, err) assert.NoError(t, err)
err = verifier([][]byte{leafWithInvalidHashPEM.Bytes, intermediatePEM.Bytes, rootPEM.Bytes}, nil) err = verifier([]*x509.Certificate{leafCert, smimeIntermediateCert, rootCert}, "")
assert.NoError(t, err)
err = verifier([]*x509.Certificate{leafCert, smimeIntermediateCert, rootCert}, leafServerName)
assert.NoError(t, err)
err = verifier([]*x509.Certificate{leafCert, smimeIntermediateCert, rootCert}, wrongLeafServerName)
assert.NoError(t, err)
err = verifier([]*x509.Certificate{leafCert, intermediateCert, smimeRootCert}, "")
assert.NoError(t, err)
err = verifier([]*x509.Certificate{leafCert, intermediateCert, smimeRootCert}, leafServerName)
assert.NoError(t, err)
err = verifier([]*x509.Certificate{leafCert, intermediateCert, smimeRootCert}, wrongLeafServerName)
assert.NoError(t, err)
err = verifier([]*x509.Certificate{leafWithInvalidHashCert, intermediateCert, rootCert}, "")
assert.Error(t, err) assert.Error(t, err)
err = verifier([][]byte{smimeLeafPEM.Bytes, intermediatePEM.Bytes, rootPEM.Bytes}, nil) err = verifier([]*x509.Certificate{leafWithInvalidHashCert, intermediateCert, rootCert}, leafServerName)
assert.Error(t, err) assert.Error(t, err)
err = verifier([][]byte{smimeLeafPEM.Bytes, intermediatePEM.Bytes, smimeRootPEM.Bytes}, nil) err = verifier([]*x509.Certificate{leafWithInvalidHashCert, intermediateCert, rootCert}, wrongLeafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, rootCert}, "")
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, rootCert}, leafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, rootCert}, wrongLeafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, smimeRootCert}, "")
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, smimeRootCert}, leafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, smimeRootCert}, wrongLeafServerName)
assert.Error(t, err) assert.Error(t, err)
} }
func TestFingerprintVerifierIntermediate(t *testing.T) { func TestFingerprintVerifierIntermediate(t *testing.T) {
intermediateFingerprint := CalculateFingerprint(intermediatePEM.Bytes) intermediateFingerprint := CalculateFingerprint(intermediateCert.Raw)
verifier, err := NewFingerprintVerifier(intermediateFingerprint, func() time.Time { verifier, err := NewFingerprintVerifier(intermediateFingerprint, certTime)
return time.Unix(1677615892, 0)
})
require.NoError(t, err) require.NoError(t, err)
err = verifier([][]byte{leafPEM.Bytes, intermediatePEM.Bytes, rootPEM.Bytes}, nil) err = verifier([]*x509.Certificate{leafCert, intermediateCert, rootCert}, "")
assert.NoError(t, err) assert.NoError(t, err)
err = verifier([][]byte{leafPEM.Bytes, smimeIntermediatePEM.Bytes, rootPEM.Bytes}, nil) err = verifier([]*x509.Certificate{leafCert, intermediateCert, rootCert}, leafServerName)
assert.Error(t, err)
err = verifier([][]byte{leafPEM.Bytes, intermediatePEM.Bytes, smimeRootPEM.Bytes}, nil)
assert.NoError(t, err) assert.NoError(t, err)
err = verifier([][]byte{leafWithInvalidHashPEM.Bytes, intermediatePEM.Bytes, rootPEM.Bytes}, nil) err = verifier([]*x509.Certificate{leafCert, intermediateCert, rootCert}, wrongLeafServerName)
assert.Error(t, err) assert.Error(t, err)
err = verifier([][]byte{smimeLeafPEM.Bytes, intermediatePEM.Bytes, rootPEM.Bytes}, nil) err = verifier([]*x509.Certificate{leafCert, smimeIntermediateCert, rootCert}, "")
assert.Error(t, err) assert.Error(t, err)
err = verifier([][]byte{smimeLeafPEM.Bytes, intermediatePEM.Bytes, smimeRootPEM.Bytes}, nil) err = verifier([]*x509.Certificate{leafCert, smimeIntermediateCert, rootCert}, leafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{leafCert, smimeIntermediateCert, rootCert}, wrongLeafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{leafCert, intermediateCert, smimeRootCert}, "")
assert.NoError(t, err)
err = verifier([]*x509.Certificate{leafCert, intermediateCert, smimeRootCert}, leafServerName)
assert.NoError(t, err)
err = verifier([]*x509.Certificate{leafCert, intermediateCert, smimeRootCert}, wrongLeafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{leafWithInvalidHashCert, intermediateCert, rootCert}, "")
assert.Error(t, err)
err = verifier([]*x509.Certificate{leafWithInvalidHashCert, intermediateCert, rootCert}, leafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{leafWithInvalidHashCert, intermediateCert, rootCert}, wrongLeafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, rootCert}, "")
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, rootCert}, leafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, rootCert}, wrongLeafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, smimeRootCert}, "")
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, smimeRootCert}, leafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, smimeRootCert}, wrongLeafServerName)
assert.Error(t, err) assert.Error(t, err)
} }
func TestFingerprintVerifierRoot(t *testing.T) { func TestFingerprintVerifierRoot(t *testing.T) {
rootFingerprint := CalculateFingerprint(rootPEM.Bytes) rootFingerprint := CalculateFingerprint(rootCert.Raw)
verifier, err := NewFingerprintVerifier(rootFingerprint, func() time.Time { verifier, err := NewFingerprintVerifier(rootFingerprint, certTime)
return time.Unix(1677615892, 0)
})
require.NoError(t, err) require.NoError(t, err)
err = verifier([][]byte{leafPEM.Bytes, intermediatePEM.Bytes, rootPEM.Bytes}, nil) err = verifier([]*x509.Certificate{leafCert, intermediateCert, rootCert}, "")
assert.NoError(t, err) assert.NoError(t, err)
err = verifier([][]byte{leafPEM.Bytes, smimeIntermediatePEM.Bytes, rootPEM.Bytes}, nil) err = verifier([]*x509.Certificate{leafCert, intermediateCert, rootCert}, leafServerName)
assert.NoError(t, err)
err = verifier([]*x509.Certificate{leafCert, intermediateCert, rootCert}, wrongLeafServerName)
assert.Error(t, err) assert.Error(t, err)
err = verifier([][]byte{leafPEM.Bytes, intermediatePEM.Bytes, smimeRootPEM.Bytes}, nil) err = verifier([]*x509.Certificate{leafCert, smimeIntermediateCert, rootCert}, "")
assert.Error(t, err) assert.Error(t, err)
err = verifier([][]byte{leafWithInvalidHashPEM.Bytes, intermediatePEM.Bytes, rootPEM.Bytes}, nil) err = verifier([]*x509.Certificate{leafCert, smimeIntermediateCert, rootCert}, leafServerName)
assert.Error(t, err) assert.Error(t, err)
err = verifier([][]byte{smimeLeafPEM.Bytes, intermediatePEM.Bytes, rootPEM.Bytes}, nil) err = verifier([]*x509.Certificate{leafCert, smimeIntermediateCert, rootCert}, wrongLeafServerName)
assert.Error(t, err) assert.Error(t, err)
err = verifier([][]byte{smimeLeafPEM.Bytes, intermediatePEM.Bytes, smimeRootPEM.Bytes}, nil) err = verifier([]*x509.Certificate{leafCert, intermediateCert, smimeRootCert}, "")
assert.Error(t, err)
err = verifier([]*x509.Certificate{leafCert, intermediateCert, smimeRootCert}, leafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{leafCert, intermediateCert, smimeRootCert}, wrongLeafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{leafWithInvalidHashCert, intermediateCert, rootCert}, "")
assert.Error(t, err)
err = verifier([]*x509.Certificate{leafWithInvalidHashCert, intermediateCert, rootCert}, leafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{leafWithInvalidHashCert, intermediateCert, rootCert}, wrongLeafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, rootCert}, "")
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, rootCert}, leafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, rootCert}, wrongLeafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, smimeRootCert}, "")
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, smimeRootCert}, leafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, smimeRootCert}, wrongLeafServerName)
assert.Error(t, err) assert.Error(t, err)
} }
var rootPEM, _ = pem.Decode([]byte(gtsRoot)) var rootPEM, _ = pem.Decode([]byte(gtsRoot))
var rootCert, _ = x509.ParseCertificate(rootPEM.Bytes)
var intermediatePEM, _ = pem.Decode([]byte(gtsIntermediate)) var intermediatePEM, _ = pem.Decode([]byte(gtsIntermediate))
var intermediateCert, _ = x509.ParseCertificate(intermediatePEM.Bytes)
var leafPEM, _ = pem.Decode([]byte(googleLeaf)) var leafPEM, _ = pem.Decode([]byte(googleLeaf))
var leafCert, _ = x509.ParseCertificate(leafPEM.Bytes)
var leafWithInvalidHashPEM, _ = pem.Decode([]byte(googleLeafWithInvalidHash)) var leafWithInvalidHashPEM, _ = pem.Decode([]byte(googleLeafWithInvalidHash))
var leafWithInvalidHashCert, _ = x509.ParseCertificate(leafWithInvalidHashPEM.Bytes)
var smimeRootPEM, _ = pem.Decode([]byte(smimeRoot)) var smimeRootPEM, _ = pem.Decode([]byte(smimeRoot))
var smimeRootCert, _ = x509.ParseCertificate(smimeRootPEM.Bytes)
var smimeIntermediatePEM, _ = pem.Decode([]byte(smimeIntermediate)) var smimeIntermediatePEM, _ = pem.Decode([]byte(smimeIntermediate))
var smimeIntermediateCert, _ = x509.ParseCertificate(smimeIntermediatePEM.Bytes)
var smimeLeafPEM, _ = pem.Decode([]byte(smimeLeaf)) var smimeLeafPEM, _ = pem.Decode([]byte(smimeLeaf))
var smimeLeafCert, _ = x509.ParseCertificate(smimeLeafPEM.Bytes)
var certTime = func() time.Time { return time.Unix(1677615892, 0) }
const leafServerName = "www.google.com"
const wrongLeafServerName = "www.google.com.cn"
const gtsIntermediate = `-----BEGIN CERTIFICATE----- const gtsIntermediate = `-----BEGIN CERTIFICATE-----
MIIFljCCA36gAwIBAgINAgO8U1lrNMcY9QFQZjANBgkqhkiG9w0BAQsFADBHMQsw MIIFljCCA36gAwIBAgINAgO8U1lrNMcY9QFQZjANBgkqhkiG9w0BAQsFADBHMQsw

View File

@@ -7,71 +7,85 @@ import (
"crypto/elliptic" "crypto/elliptic"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"math/big" "math/big"
"os" "os"
"runtime"
"sync"
"time" "time"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/fswatch"
"github.com/metacubex/tls"
) )
type Path interface { // NewTLSKeyPairLoader creates a loader function for TLS key pairs from the provided certificate and private key data or file paths.
Resolve(path string) string
IsSafePath(path string) bool
ErrNotSafePath(path string) error
}
// LoadTLSKeyPair loads a TLS key pair from the provided certificate and private key data or file paths, supporting fallback resolution.
// Returns a tls.Certificate and an error, where the error indicates issues during parsing or file loading.
// If both certificate and privateKey are empty, generates a random TLS RSA key pair. // If both certificate and privateKey are empty, generates a random TLS RSA key pair.
// Accepts a Path interface for resolving file paths when necessary. func NewTLSKeyPairLoader(certificate, privateKey string) (func() (*tls.Certificate, error), error) {
func LoadTLSKeyPair(certificate, privateKey string, path Path) (tls.Certificate, error) {
if certificate == "" && privateKey == "" { if certificate == "" && privateKey == "" {
var err error var err error
certificate, privateKey, _, err = NewRandomTLSKeyPair(KeyPairTypeRSA) certificate, privateKey, _, err = NewRandomTLSKeyPair(KeyPairTypeRSA)
if err != nil { if err != nil {
return tls.Certificate{}, err return nil, err
} }
} }
cert, painTextErr := tls.X509KeyPair([]byte(certificate), []byte(privateKey)) cert, painTextErr := tls.X509KeyPair([]byte(certificate), []byte(privateKey))
if painTextErr == nil { if painTextErr == nil {
return cert, nil return func() (*tls.Certificate, error) {
} return &cert, nil
if path == nil { }, nil
return tls.Certificate{}, painTextErr
} }
certificate = path.Resolve(certificate) certificate = C.Path.Resolve(certificate)
privateKey = path.Resolve(privateKey) privateKey = C.Path.Resolve(privateKey)
var loadErr error var loadErr error
if !path.IsSafePath(certificate) { if !C.Path.IsSafePath(certificate) {
loadErr = path.ErrNotSafePath(certificate) loadErr = C.Path.ErrNotSafePath(certificate)
} else if !path.IsSafePath(privateKey) { } else if !C.Path.IsSafePath(privateKey) {
loadErr = path.ErrNotSafePath(privateKey) loadErr = C.Path.ErrNotSafePath(privateKey)
} else { } else {
cert, loadErr = tls.LoadX509KeyPair(certificate, privateKey) cert, loadErr = tls.LoadX509KeyPair(certificate, privateKey)
} }
if loadErr != nil { if loadErr != nil {
return tls.Certificate{}, fmt.Errorf("parse certificate failed, maybe format error:%s, or path error: %s", painTextErr.Error(), loadErr.Error()) return nil, fmt.Errorf("parse certificate failed, maybe format error:%s, or path error: %s", painTextErr.Error(), loadErr.Error())
} }
return cert, nil gcFlag := new(os.File) // tiny (on the order of 16 bytes or less) and pointer-free objects may never run the finalizer, so we choose new an os.File
updateMutex := sync.RWMutex{}
if watcher, err := fswatch.NewWatcher(fswatch.Options{Path: []string{certificate, privateKey}, Callback: func(path string) {
updateMutex.Lock()
defer updateMutex.Unlock()
if newCert, err := tls.LoadX509KeyPair(certificate, privateKey); err == nil {
cert = newCert
}
}}); err == nil {
if err = watcher.Start(); err == nil {
runtime.SetFinalizer(gcFlag, func(f *os.File) {
_ = watcher.Close()
})
}
}
return func() (*tls.Certificate, error) {
defer runtime.KeepAlive(gcFlag)
updateMutex.RLock()
defer updateMutex.RUnlock()
return &cert, nil
}, nil
} }
func LoadCertificates(certificate string, path Path) (*x509.CertPool, error) { func LoadCertificates(certificate string) (*x509.CertPool, error) {
pool := x509.NewCertPool() pool := x509.NewCertPool()
if pool.AppendCertsFromPEM([]byte(certificate)) { if pool.AppendCertsFromPEM([]byte(certificate)) {
return pool, nil return pool, nil
} }
painTextErr := fmt.Errorf("invalid certificate: %s", certificate) painTextErr := fmt.Errorf("invalid certificate: %s", certificate)
if path == nil {
return nil, painTextErr
}
certificate = path.Resolve(certificate) certificate = C.Path.Resolve(certificate)
var loadErr error var loadErr error
if !path.IsSafePath(certificate) { if !C.Path.IsSafePath(certificate) {
loadErr = path.ErrNotSafePath(certificate) loadErr = C.Path.ErrNotSafePath(certificate)
} else { } else {
certPEMBlock, err := os.ReadFile(certificate) certPEMBlock, err := os.ReadFile(certificate)
if pool.AppendCertsFromPEM(certPEMBlock) { if pool.AppendCertsFromPEM(certPEMBlock) {
@@ -82,6 +96,9 @@ func LoadCertificates(certificate string, path Path) (*x509.CertPool, error) {
if loadErr != nil { if loadErr != nil {
return nil, fmt.Errorf("parse certificate failed, maybe format error:%s, or path error: %s", painTextErr.Error(), loadErr.Error()) return nil, fmt.Errorf("parse certificate failed, maybe format error:%s, or path error: %s", painTextErr.Error(), loadErr.Error())
} }
//TODO: support dynamic update pool too
// blocked by: https://github.com/golang/go/issues/64796
// maybe we can direct add `GetRootCAs` and `GetClientCAs` to ourselves tls fork
return pool, nil return pool, nil
} }

View File

@@ -12,25 +12,31 @@ import (
"syscall" "syscall"
"time" "time"
"github.com/metacubex/mihomo/common/atomic"
"github.com/metacubex/mihomo/component/keepalive" "github.com/metacubex/mihomo/component/keepalive"
"github.com/metacubex/mihomo/component/mptcp"
"github.com/metacubex/mihomo/component/resolver" "github.com/metacubex/mihomo/component/resolver"
) )
const ( const (
DefaultTCPTimeout = 5 * time.Second DefaultTCPTimeout = 5 * time.Second
DefaultUDPTimeout = DefaultTCPTimeout DefaultUDPTimeout = DefaultTCPTimeout
)
type dialFunc func(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) dualStackFallbackTimeout = 300 * time.Millisecond
)
var ( var (
dialMux sync.Mutex tcpConcurrent = atomic.NewBool(false)
actualSingleStackDialContext = serialSingleStackDialContext
actualDualStackDialContext = serialDualStackDialContext
tcpConcurrent = false
fallbackTimeout = 300 * time.Millisecond
) )
func SetTcpConcurrent(concurrent bool) {
tcpConcurrent.Store(concurrent)
}
func GetTcpConcurrent() bool {
return tcpConcurrent.Load()
}
func DialContext(ctx context.Context, network, address string, options ...Option) (net.Conn, error) { func DialContext(ctx context.Context, network, address string, options ...Option) (net.Conn, error) {
opt := applyOptions(options...) opt := applyOptions(options...)
@@ -49,11 +55,22 @@ func DialContext(ctx context.Context, network, address string, options ...Option
return nil, err return nil, err
} }
tcpConcurrent := GetTcpConcurrent()
switch network { switch network {
case "tcp4", "tcp6", "udp4", "udp6": case "tcp4", "tcp6", "udp4", "udp6":
return actualSingleStackDialContext(ctx, network, ips, port, opt) if tcpConcurrent {
return parallelDialContext(ctx, network, ips, port, opt)
}
return serialDialContext(ctx, network, ips, port, opt)
case "tcp", "udp": case "tcp", "udp":
return actualDualStackDialContext(ctx, network, ips, port, opt) if tcpConcurrent {
if opt.prefer != 4 && opt.prefer != 6 {
return parallelDialContext(ctx, network, ips, port, opt)
}
return dualStackDialContext(ctx, parallelDialContext, network, ips, port, opt)
}
return dualStackDialContext(ctx, serialDialContext, network, ips, port, opt)
default: default:
return nil, ErrorInvalidedNetworkStack return nil, ErrorInvalidedNetworkStack
} }
@@ -103,25 +120,6 @@ func ListenPacket(ctx context.Context, network, address string, rAddrPort netip.
return lc.ListenPacket(ctx, network, address) return lc.ListenPacket(ctx, network, address)
} }
func SetTcpConcurrent(concurrent bool) {
dialMux.Lock()
defer dialMux.Unlock()
tcpConcurrent = concurrent
if concurrent {
actualSingleStackDialContext = concurrentSingleStackDialContext
actualDualStackDialContext = concurrentDualStackDialContext
} else {
actualSingleStackDialContext = serialSingleStackDialContext
actualDualStackDialContext = serialDualStackDialContext
}
}
func GetTcpConcurrent() bool {
dialMux.Lock()
defer dialMux.Unlock()
return tcpConcurrent
}
func dialContext(ctx context.Context, network string, destination netip.Addr, port string, opt option) (net.Conn, error) { func dialContext(ctx context.Context, network string, destination netip.Addr, port string, opt option) (net.Conn, error) {
var address string var address string
destination, port = resolver.LookupIP4P(destination, port) destination, port = resolver.LookupIP4P(destination, port)
@@ -140,9 +138,7 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po
dialer := netDialer.(*net.Dialer) dialer := netDialer.(*net.Dialer)
keepalive.SetNetDialer(dialer) keepalive.SetNetDialer(dialer)
if opt.mpTcp { mptcp.SetNetDialer(dialer, opt.mpTcp)
setMultiPathTCP(dialer)
}
if DefaultSocketHook != nil { // ignore interfaceName, routingMark and tfo when DefaultSocketHook not null (in CMFA) if DefaultSocketHook != nil { // ignore interfaceName, routingMark and tfo when DefaultSocketHook not null (in CMFA)
socketHookToToDialer(dialer) socketHookToToDialer(dialer)
@@ -206,24 +202,7 @@ func ICMPControl(destination netip.Addr) func(network, address string, conn sysc
} }
} }
func serialSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) { type dialFunc func(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error)
return serialDialContext(ctx, network, ips, port, opt)
}
func serialDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) {
return dualStackDialContext(ctx, serialDialContext, network, ips, port, opt)
}
func concurrentSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) {
return parallelDialContext(ctx, network, ips, port, opt)
}
func concurrentDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) {
if opt.prefer != 4 && opt.prefer != 6 {
return parallelDialContext(ctx, network, ips, port, opt)
}
return dualStackDialContext(ctx, parallelDialContext, network, ips, port, opt)
}
func dualStackDialContext(ctx context.Context, dialFn dialFunc, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) { func dualStackDialContext(ctx context.Context, dialFn dialFunc, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) {
ipv4s, ipv6s := resolver.SortationAddr(ips) ipv4s, ipv6s := resolver.SortationAddr(ips)
@@ -232,7 +211,7 @@ func dualStackDialContext(ctx context.Context, dialFn dialFunc, network string,
} }
preferIPVersion := opt.prefer preferIPVersion := opt.prefer
fallbackTicker := time.NewTicker(fallbackTimeout) fallbackTicker := time.NewTicker(dualStackFallbackTimeout)
defer fallbackTicker.Stop() defer fallbackTicker.Stop()
results := make(chan dialResult) results := make(chan dialResult)

View File

@@ -1,12 +0,0 @@
//go:build !go1.21
package dialer
import (
"net"
)
const multipathTCPAvailable = false
func setMultiPathTCP(dialer *net.Dialer) {
}

View File

@@ -1,11 +0,0 @@
//go:build go1.21
package dialer
import "net"
const multipathTCPAvailable = true
func setMultiPathTCP(dialer *net.Dialer) {
dialer.SetMultipathTCP(true)
}

View File

@@ -5,13 +5,33 @@ import (
"fmt" "fmt"
tlsC "github.com/metacubex/mihomo/component/tls" tlsC "github.com/metacubex/mihomo/component/tls"
"github.com/metacubex/tls"
) )
type Config struct { type Config struct {
GetEncryptedClientHelloConfigList func(ctx context.Context, serverName string) ([]byte, error) GetEncryptedClientHelloConfigList func(ctx context.Context, serverName string) ([]byte, error)
} }
func (cfg *Config) ClientHandle(ctx context.Context, tlsConfig *tlsC.Config) (err error) { func (cfg *Config) ClientHandle(ctx context.Context, tlsConfig *tls.Config) (err error) {
if cfg == nil {
return nil
}
echConfigList, err := cfg.GetEncryptedClientHelloConfigList(ctx, tlsConfig.ServerName)
if err != nil {
return fmt.Errorf("resolve ECH config error: %w", err)
}
tlsConfig.EncryptedClientHelloConfigList = echConfigList
if tlsConfig.MinVersion != 0 && tlsConfig.MinVersion < tls.VersionTLS13 {
tlsConfig.MinVersion = tls.VersionTLS13
}
if tlsConfig.MaxVersion != 0 && tlsConfig.MaxVersion < tls.VersionTLS13 {
tlsConfig.MaxVersion = tls.VersionTLS13
}
return nil
}
func (cfg *Config) ClientHandleUTLS(ctx context.Context, tlsConfig *tlsC.Config) (err error) {
if cfg == nil { if cfg == nil {
return nil return nil
} }

View File

@@ -0,0 +1,147 @@
package echparser
import (
"errors"
"fmt"
"golang.org/x/crypto/cryptobyte"
)
// export from std's crypto/tls/ech.go
const extensionEncryptedClientHello = 0xfe0d
type ECHCipher struct {
KDFID uint16
AEADID uint16
}
type ECHExtension struct {
Type uint16
Data []byte
}
type ECHConfig struct {
raw []byte
Version uint16
Length uint16
ConfigID uint8
KemID uint16
PublicKey []byte
SymmetricCipherSuite []ECHCipher
MaxNameLength uint8
PublicName []byte
Extensions []ECHExtension
}
var ErrMalformedECHConfigList = errors.New("tls: malformed ECHConfigList")
type EchConfigErr struct {
field string
}
func (e *EchConfigErr) Error() string {
if e.field == "" {
return "tls: malformed ECHConfig"
}
return fmt.Sprintf("tls: malformed ECHConfig, invalid %s field", e.field)
}
func ParseECHConfig(enc []byte) (skip bool, ec ECHConfig, err error) {
s := cryptobyte.String(enc)
ec.raw = []byte(enc)
if !s.ReadUint16(&ec.Version) {
return false, ECHConfig{}, &EchConfigErr{"version"}
}
if !s.ReadUint16(&ec.Length) {
return false, ECHConfig{}, &EchConfigErr{"length"}
}
if len(ec.raw) < int(ec.Length)+4 {
return false, ECHConfig{}, &EchConfigErr{"length"}
}
ec.raw = ec.raw[:ec.Length+4]
if ec.Version != extensionEncryptedClientHello {
s.Skip(int(ec.Length))
return true, ECHConfig{}, nil
}
if !s.ReadUint8(&ec.ConfigID) {
return false, ECHConfig{}, &EchConfigErr{"config_id"}
}
if !s.ReadUint16(&ec.KemID) {
return false, ECHConfig{}, &EchConfigErr{"kem_id"}
}
if !s.ReadUint16LengthPrefixed((*cryptobyte.String)(&ec.PublicKey)) {
return false, ECHConfig{}, &EchConfigErr{"public_key"}
}
var cipherSuites cryptobyte.String
if !s.ReadUint16LengthPrefixed(&cipherSuites) {
return false, ECHConfig{}, &EchConfigErr{"cipher_suites"}
}
for !cipherSuites.Empty() {
var c ECHCipher
if !cipherSuites.ReadUint16(&c.KDFID) {
return false, ECHConfig{}, &EchConfigErr{"cipher_suites kdf_id"}
}
if !cipherSuites.ReadUint16(&c.AEADID) {
return false, ECHConfig{}, &EchConfigErr{"cipher_suites aead_id"}
}
ec.SymmetricCipherSuite = append(ec.SymmetricCipherSuite, c)
}
if !s.ReadUint8(&ec.MaxNameLength) {
return false, ECHConfig{}, &EchConfigErr{"maximum_name_length"}
}
var publicName cryptobyte.String
if !s.ReadUint8LengthPrefixed(&publicName) {
return false, ECHConfig{}, &EchConfigErr{"public_name"}
}
ec.PublicName = publicName
var extensions cryptobyte.String
if !s.ReadUint16LengthPrefixed(&extensions) {
return false, ECHConfig{}, &EchConfigErr{"extensions"}
}
for !extensions.Empty() {
var e ECHExtension
if !extensions.ReadUint16(&e.Type) {
return false, ECHConfig{}, &EchConfigErr{"extensions type"}
}
if !extensions.ReadUint16LengthPrefixed((*cryptobyte.String)(&e.Data)) {
return false, ECHConfig{}, &EchConfigErr{"extensions data"}
}
ec.Extensions = append(ec.Extensions, e)
}
return false, ec, nil
}
// ParseECHConfigList parses a draft-ietf-tls-esni-18 ECHConfigList, returning a
// slice of parsed ECHConfigs, in the same order they were parsed, or an error
// if the list is malformed.
func ParseECHConfigList(data []byte) ([]ECHConfig, error) {
s := cryptobyte.String(data)
var length uint16
if !s.ReadUint16(&length) {
return nil, ErrMalformedECHConfigList
}
if length != uint16(len(data)-2) {
return nil, ErrMalformedECHConfigList
}
var configs []ECHConfig
for len(s) > 0 {
if len(s) < 4 {
return nil, errors.New("tls: malformed ECHConfig")
}
configLen := uint16(s[2])<<8 | uint16(s[3])
skip, ec, err := ParseECHConfig(s)
if err != nil {
return nil, err
}
s = s[configLen+4:]
if !skip {
configs = append(configs, ec)
}
}
return configs, nil
}

View File

@@ -8,10 +8,13 @@ import (
"errors" "errors"
"fmt" "fmt"
"os" "os"
"runtime"
"sync"
"github.com/metacubex/mihomo/component/ca" C "github.com/metacubex/mihomo/constant"
tlsC "github.com/metacubex/mihomo/component/tls"
"github.com/metacubex/fswatch"
"github.com/metacubex/tls"
"golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte"
) )
@@ -85,11 +88,11 @@ func GenECHConfig(publicName string) (configBase64 string, keyPem string, err er
return return
} }
func UnmarshalECHKeys(raw []byte) ([]tlsC.EncryptedClientHelloKey, error) { func UnmarshalECHKeys(raw []byte) ([]tls.EncryptedClientHelloKey, error) {
var keys []tlsC.EncryptedClientHelloKey var keys []tls.EncryptedClientHelloKey
rawString := cryptobyte.String(raw) rawString := cryptobyte.String(raw)
for !rawString.Empty() { for !rawString.Empty() {
var key tlsC.EncryptedClientHelloKey var key tls.EncryptedClientHelloKey
if !rawString.ReadUint16LengthPrefixed((*cryptobyte.String)(&key.PrivateKey)) { if !rawString.ReadUint16LengthPrefixed((*cryptobyte.String)(&key.PrivateKey)) {
return nil, errors.New("error parsing private key") return nil, errors.New("error parsing private key")
} }
@@ -104,40 +107,65 @@ func UnmarshalECHKeys(raw []byte) ([]tlsC.EncryptedClientHelloKey, error) {
return keys, nil return keys, nil
} }
func LoadECHKey(key string, tlsConfig *tlsC.Config, path ca.Path) error { func LoadECHKey(key string, tlsConfig *tls.Config) error {
if key == "" { if key == "" {
return nil return nil
} }
painTextErr := loadECHKey([]byte(key), tlsConfig) echKeys, painTextErr := loadECHKey([]byte(key))
if painTextErr == nil { if painTextErr == nil {
tlsConfig.GetEncryptedClientHelloKeys = func(info *tls.ClientHelloInfo) ([]tls.EncryptedClientHelloKey, error) {
return echKeys, nil
}
return nil return nil
} }
key = path.Resolve(key) key = C.Path.Resolve(key)
var loadErr error var loadErr error
if !path.IsSafePath(key) { if !C.Path.IsSafePath(key) {
loadErr = path.ErrNotSafePath(key) loadErr = C.Path.ErrNotSafePath(key)
} else { } else {
var echKey []byte var echKey []byte
echKey, loadErr = os.ReadFile(key) echKey, loadErr = os.ReadFile(key)
if loadErr == nil { if loadErr == nil {
loadErr = loadECHKey(echKey, tlsConfig) echKeys, loadErr = loadECHKey(echKey)
} }
} }
if loadErr != nil { if loadErr != nil {
return fmt.Errorf("parse ECH keys failed, maybe format error:%s, or path error: %s", painTextErr.Error(), loadErr.Error()) return fmt.Errorf("parse ECH keys failed, maybe format error:%s, or path error: %s", painTextErr.Error(), loadErr.Error())
} }
gcFlag := new(os.File) // tiny (on the order of 16 bytes or less) and pointer-free objects may never run the finalizer, so we choose new an os.File
updateMutex := sync.RWMutex{}
if watcher, err := fswatch.NewWatcher(fswatch.Options{Path: []string{key}, Callback: func(path string) {
updateMutex.Lock()
defer updateMutex.Unlock()
if echKey, err := os.ReadFile(key); err == nil {
if newEchKeys, err := loadECHKey(echKey); err == nil {
echKeys = newEchKeys
}
}
}}); err == nil {
if err = watcher.Start(); err == nil {
runtime.SetFinalizer(gcFlag, func(f *os.File) {
_ = watcher.Close()
})
}
}
tlsConfig.GetEncryptedClientHelloKeys = func(info *tls.ClientHelloInfo) ([]tls.EncryptedClientHelloKey, error) {
defer runtime.KeepAlive(gcFlag)
updateMutex.RLock()
defer updateMutex.RUnlock()
return echKeys, nil
}
return nil return nil
} }
func loadECHKey(echKey []byte, tlsConfig *tlsC.Config) error { func loadECHKey(echKey []byte) ([]tls.EncryptedClientHelloKey, error) {
block, rest := pem.Decode(echKey) block, rest := pem.Decode(echKey)
if block == nil || block.Type != "ECH KEYS" || len(rest) > 0 { if block == nil || block.Type != "ECH KEYS" || len(rest) > 0 {
return errors.New("invalid ECH keys pem") return nil, errors.New("invalid ECH keys pem")
} }
echKeys, err := UnmarshalECHKeys(block.Bytes) echKeys, err := UnmarshalECHKeys(block.Bytes)
if err != nil { if err != nil {
return fmt.Errorf("parse ECH keys: %w", err) return nil, fmt.Errorf("parse ECH keys: %w", err)
} }
tlsConfig.EncryptedClientHelloKeys = echKeys return echKeys, err
return nil
} }

30
component/ech/key_test.go Normal file
View File

@@ -0,0 +1,30 @@
package ech
import (
"encoding/base64"
"testing"
"github.com/metacubex/mihomo/component/ech/echparser"
)
func TestGenECHConfig(t *testing.T) {
domain := "www.example.com"
configBase64, _, err := GenECHConfig(domain)
if err != nil {
t.Error(err)
}
echConfigList, err := base64.StdEncoding.DecodeString(configBase64)
if err != nil {
t.Error(err)
}
echConfigs, err := echparser.ParseECHConfigList(echConfigList)
if err != nil {
t.Error(err)
}
if len(echConfigs) == 0 {
t.Error("no ech config")
}
if publicName := string(echConfigs[0].PublicName); publicName != domain {
t.Error("ech config domain error, expect ", domain, " got", publicName)
}
}

View File

@@ -4,13 +4,29 @@ import (
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
) )
const (
UseFakeIP = "fake-ip"
UseRealIP = "real-ip"
)
type Skipper struct { type Skipper struct {
Host []C.DomainMatcher Rules []C.Rule
Mode C.FilterMode Host []C.DomainMatcher
Mode C.FilterMode
} }
// ShouldSkipped return if domain should be skipped // ShouldSkipped return if domain should be skipped
func (p *Skipper) ShouldSkipped(domain string) bool { func (p *Skipper) ShouldSkipped(domain string) bool {
if len(p.Rules) > 0 {
metadata := &C.Metadata{Host: domain}
for _, rule := range p.Rules {
if matched, action := rule.Match(metadata, C.RuleMatchHelper{}); matched {
return action == UseRealIP
}
}
return false
}
should := p.shouldSkipped(domain) should := p.shouldSkipped(domain)
if p.Mode == C.FilterWhiteList { if p.Mode == C.FilterWhiteList {
return !should return !should

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"github.com/metacubex/mihomo/component/ech" "github.com/metacubex/mihomo/component/ech"
"github.com/metacubex/mihomo/transport/sudoku"
"github.com/metacubex/mihomo/transport/vless/encryption" "github.com/metacubex/mihomo/transport/vless/encryption"
"github.com/gofrs/uuid/v5" "github.com/gofrs/uuid/v5"
@@ -12,7 +13,7 @@ import (
func Main(args []string) { func Main(args []string) {
if len(args) < 1 { if len(args) < 1 {
panic("Using: generate uuid/reality-keypair/wg-keypair/ech-keypair/vless-mlkem768/vless-x25519") panic("Using: generate uuid/reality-keypair/wg-keypair/ech-keypair/vless-mlkem768/vless-x25519/sudoku-keypair")
} }
switch args[0] { switch args[0] {
case "uuid": case "uuid":
@@ -57,6 +58,11 @@ func Main(args []string) {
fmt.Println("Seed: " + seedBase64) fmt.Println("Seed: " + seedBase64)
fmt.Println("Client: " + clientBase64) fmt.Println("Client: " + clientBase64)
fmt.Println("Hash32: " + hash32Base64) fmt.Println("Hash32: " + hash32Base64)
fmt.Println("-----------------------")
fmt.Println(" Lazy-Config ")
fmt.Println("-----------------------")
fmt.Printf("[Server] decryption: \"mlkem768x25519plus.native.600s.%s\"\n", seedBase64)
fmt.Printf("[Client] encryption: \"mlkem768x25519plus.native.0rtt.%s\"\n", clientBase64)
case "vless-x25519": case "vless-x25519":
var privateKey string var privateKey string
if len(args) > 1 { if len(args) > 1 {
@@ -69,5 +75,18 @@ func Main(args []string) {
fmt.Println("PrivateKey: " + privateKeyBase64) fmt.Println("PrivateKey: " + privateKeyBase64)
fmt.Println("Password: " + passwordBase64) fmt.Println("Password: " + passwordBase64)
fmt.Println("Hash32: " + hash32Base64) fmt.Println("Hash32: " + hash32Base64)
fmt.Println("-----------------------")
fmt.Println(" Lazy-Config ")
fmt.Println("-----------------------")
fmt.Printf("[Server] decryption: \"mlkem768x25519plus.native.600s.%s\"\n", privateKeyBase64)
fmt.Printf("[Client] encryption: \"mlkem768x25519plus.native.0rtt.%s\"\n", passwordBase64)
case "sudoku-keypair":
privateKey, publicKey, err := sudoku.GenKeyPair()
if err != nil {
panic(err)
}
// Output: Available Private Key for client, Master Public Key for server
fmt.Println("PrivateKey: " + privateKey)
fmt.Println("PublicKey: " + publicKey)
} }
} }

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"io" "io"
"net/http"
"os" "os"
"sync" "sync"
"time" "time"
@@ -14,6 +13,8 @@ import (
"github.com/metacubex/mihomo/component/mmdb" "github.com/metacubex/mihomo/component/mmdb"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/log"
"github.com/metacubex/http"
) )
var ( var (

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"io" "io"
"net" "net"
"net/http"
URL "net/url" URL "net/url"
"runtime" "runtime"
"strings" "strings"
@@ -13,6 +12,8 @@ import (
"github.com/metacubex/mihomo/component/ca" "github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/listener/inner" "github.com/metacubex/mihomo/listener/inner"
"github.com/metacubex/http"
) )
var ( var (

View File

@@ -0,0 +1,23 @@
//go:build !go1.21
package mptcp
import (
"net"
)
const MultipathTCPAvailable = false
func SetNetDialer(dialer *net.Dialer, open bool) {
}
func GetNetDialer(dialer *net.Dialer) bool {
return false
}
func SetNetListenConfig(listenConfig *net.ListenConfig, open bool) {
}
func GetNetListenConfig(listenConfig *net.ListenConfig) bool {
return false
}

View File

@@ -0,0 +1,23 @@
//go:build go1.21
package mptcp
import "net"
const MultipathTCPAvailable = true
func SetNetDialer(dialer *net.Dialer, open bool) {
dialer.SetMultipathTCP(open)
}
func GetNetDialer(dialer *net.Dialer) bool {
return dialer.MultipathTCP()
}
func SetNetListenConfig(listenConfig *net.ListenConfig, open bool) {
listenConfig.SetMultipathTCP(open)
}
func GetNetListenConfig(listenConfig *net.ListenConfig) bool {
return listenConfig.MultipathTCP()
}

View File

@@ -197,6 +197,10 @@ func newSearcher(major int) *searcher {
case 12: case 12:
fallthrough fallthrough
case 13: case 13:
fallthrough
case 14:
fallthrough
case 15:
s = &searcher{ s = &searcher{
headSize: 64, headSize: 64,
tcpItemSize: 744, tcpItemSize: 744,

View File

@@ -0,0 +1,37 @@
package proxydialer
import (
"context"
"fmt"
"net"
"net/netip"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/tunnel"
)
type byNameProxyDialer struct {
proxyName string
}
func (d byNameProxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
proxies := tunnel.Proxies()
proxy, ok := proxies[d.proxyName]
if !ok {
return nil, fmt.Errorf("proxyName[%s] not found", d.proxyName)
}
return New(proxy, true).DialContext(ctx, network, address)
}
func (d byNameProxyDialer) ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort) (net.PacketConn, error) {
proxies := tunnel.Proxies()
proxy, ok := proxies[d.proxyName]
if !ok {
return nil, fmt.Errorf("proxyName[%s] not found", d.proxyName)
}
return New(proxy, true).ListenPacket(ctx, network, address, rAddrPort)
}
func NewByName(proxyName string) C.Dialer {
return byNameProxyDialer{proxyName: proxyName}
}

View File

@@ -2,34 +2,22 @@ package proxydialer
import ( import (
"context" "context"
"fmt"
"net" "net"
"net/netip" "net/netip"
"strings" "strings"
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/dialer"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/tunnel"
"github.com/metacubex/mihomo/tunnel/statistic" "github.com/metacubex/mihomo/tunnel/statistic"
) )
type proxyDialer struct { type proxyDialer struct {
proxy C.ProxyAdapter proxy C.ProxyAdapter
dialer C.Dialer
statistic bool statistic bool
} }
func New(proxy C.ProxyAdapter, dialer C.Dialer, statistic bool) C.Dialer { func New(proxy C.ProxyAdapter, statistic bool) C.Dialer {
return proxyDialer{proxy: proxy, dialer: dialer, statistic: statistic} return proxyDialer{proxy: proxy, statistic: statistic}
}
func NewByName(proxyName string, dialer C.Dialer) (C.Dialer, error) {
proxies := tunnel.Proxies()
if proxy, ok := proxies[proxyName]; ok {
return New(proxy, dialer, true), nil
}
return nil, fmt.Errorf("proxyName[%s] not found", proxyName)
} }
func (p proxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { func (p proxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
@@ -50,13 +38,7 @@ func (p proxyDialer) DialContext(ctx context.Context, network, address string) (
} }
return N.NewBindPacketConn(pc, currentMeta.UDPAddr()), nil return N.NewBindPacketConn(pc, currentMeta.UDPAddr()), nil
} }
var conn C.Conn conn, err := p.proxy.DialContext(ctx, currentMeta)
var err error
if _, ok := p.dialer.(dialer.Dialer); ok { // first using old function to let mux work
conn, err = p.proxy.DialContext(ctx, currentMeta)
} else {
conn, err = p.proxy.DialContextWithDialer(ctx, p.dialer, currentMeta)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -72,14 +54,8 @@ func (p proxyDialer) ListenPacket(ctx context.Context, network, address string,
} }
func (p proxyDialer) listenPacket(ctx context.Context, currentMeta *C.Metadata) (C.PacketConn, error) { func (p proxyDialer) listenPacket(ctx context.Context, currentMeta *C.Metadata) (C.PacketConn, error) {
var pc C.PacketConn
var err error
currentMeta.NetWork = C.UDP currentMeta.NetWork = C.UDP
if _, ok := p.dialer.(dialer.Dialer); ok { // first using old function to let mux work pc, err := p.proxy.ListenPacketContext(ctx, currentMeta)
pc, err = p.proxy.ListenPacketContext(ctx, currentMeta)
} else {
pc, err = p.proxy.ListenPacketWithDialer(ctx, p.dialer, currentMeta)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -12,71 +12,22 @@ import (
type SingDialer interface { type SingDialer interface {
N.Dialer N.Dialer
SetDialer(dialer C.Dialer)
} }
type singDialer proxyDialer type singDialer struct {
cDialer C.Dialer
}
var _ N.Dialer = (*singDialer)(nil) var _ N.Dialer = (*singDialer)(nil)
func (d *singDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { func (d singDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
return (*proxyDialer)(d).DialContext(ctx, network, destination.String()) return d.cDialer.DialContext(ctx, network, destination.String())
} }
func (d *singDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (d singDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return (*proxyDialer)(d).ListenPacket(ctx, "udp", "", destination.AddrPort()) return d.cDialer.ListenPacket(ctx, "udp", "", destination.AddrPort())
} }
func (d *singDialer) SetDialer(dialer C.Dialer) { func NewSingDialer(cDialer C.Dialer) SingDialer {
(*proxyDialer)(d).dialer = dialer return singDialer{cDialer: cDialer}
}
func NewSingDialer(proxy C.ProxyAdapter, dialer C.Dialer, statistic bool) SingDialer {
return (*singDialer)(&proxyDialer{
proxy: proxy,
dialer: dialer,
statistic: statistic,
})
}
type byNameSingDialer struct {
dialer C.Dialer
proxyName string
}
var _ N.Dialer = (*byNameSingDialer)(nil)
func (d *byNameSingDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
var cDialer C.Dialer = d.dialer
if len(d.proxyName) > 0 {
pd, err := NewByName(d.proxyName, d.dialer)
if err != nil {
return nil, err
}
cDialer = pd
}
return cDialer.DialContext(ctx, network, destination.String())
}
func (d *byNameSingDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
var cDialer C.Dialer = d.dialer
if len(d.proxyName) > 0 {
pd, err := NewByName(d.proxyName, d.dialer)
if err != nil {
return nil, err
}
cDialer = pd
}
return cDialer.ListenPacket(ctx, "udp", "", destination.AddrPort())
}
func (d *byNameSingDialer) SetDialer(dialer C.Dialer) {
d.dialer = dialer
}
func NewByNameSingDialer(proxyName string, dialer C.Dialer) SingDialer {
return &byNameSingDialer{
dialer: dialer,
proxyName: proxyName,
}
} }

View File

@@ -8,7 +8,6 @@ import (
"strings" "strings"
_ "unsafe" _ "unsafe"
"github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/component/resolver/hosts" "github.com/metacubex/mihomo/component/resolver/hosts"
"github.com/metacubex/mihomo/component/trie" "github.com/metacubex/mihomo/component/trie"
"github.com/metacubex/randv2" "github.com/metacubex/randv2"
@@ -66,37 +65,35 @@ type HostValue struct {
Domain string Domain string
} }
func NewHostValue(value any) (HostValue, error) { func NewHostValue(value []string) (HostValue, error) {
isDomain := true isDomain := true
ips := make([]netip.Addr, 0) ips := make([]netip.Addr, 0, len(value))
domain := "" domain := ""
if valueArr, err := utils.ToStringSlice(value); err != nil { switch len(value) {
return HostValue{}, err case 0:
} else { return HostValue{}, errors.New("value is empty")
if len(valueArr) > 1 { case 1:
host := value[0]
if ip, err := netip.ParseAddr(host); err == nil {
ips = append(ips, ip.Unmap())
isDomain = false isDomain = false
for _, str := range valueArr { } else {
if ip, err := netip.ParseAddr(str); err == nil { domain = host
ips = append(ips, ip.Unmap()) }
} else { default: // > 1
return HostValue{}, err isDomain = false
} for _, str := range value {
} if ip, err := netip.ParseAddr(str); err == nil {
} else if len(valueArr) == 1 {
host := valueArr[0]
if ip, err := netip.ParseAddr(host); err == nil {
ips = append(ips, ip.Unmap()) ips = append(ips, ip.Unmap())
isDomain = false
} else { } else {
domain = host return HostValue{}, err
} }
} }
} }
if isDomain { if isDomain {
return NewHostValueByDomain(domain) return NewHostValueByDomain(domain)
} else {
return NewHostValueByIPs(ips)
} }
return NewHostValueByIPs(ips)
} }
func NewHostValueByIPs(ips []netip.Addr) (HostValue, error) { func NewHostValueByIPs(ips []netip.Addr) (HostValue, error) {

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"io" "io"
"net/http"
"os" "os"
"path/filepath" "path/filepath"
"time" "time"
@@ -13,6 +12,8 @@ import (
mihomoHttp "github.com/metacubex/mihomo/component/http" mihomoHttp "github.com/metacubex/mihomo/component/http"
"github.com/metacubex/mihomo/component/profile/cachefile" "github.com/metacubex/mihomo/component/profile/cachefile"
P "github.com/metacubex/mihomo/constant/provider" P "github.com/metacubex/mihomo/constant/provider"
"github.com/metacubex/http"
) )
const ( const (

View File

@@ -3,14 +3,13 @@ package tls
import ( import (
"context" "context"
"net" "net"
"net/http"
"runtime/debug" "runtime/debug"
"time" "time"
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/log"
"golang.org/x/net/http2" "github.com/metacubex/http"
) )
func extractTlsHandshakeTimeoutFromServer(s *http.Server) time.Duration { func extractTlsHandshakeTimeoutFromServer(s *http.Server) time.Duration {
@@ -35,8 +34,8 @@ func extractTlsHandshakeTimeoutFromServer(s *http.Server) time.Duration {
// only do tls handshake and check NegotiatedProtocol with std's *tls.Conn // only do tls handshake and check NegotiatedProtocol with std's *tls.Conn
// so we do the same logic to let http2 (not h2c) work fine // so we do the same logic to let http2 (not h2c) work fine
func NewListenerForHttps(l net.Listener, httpServer *http.Server, tlsConfig *Config) net.Listener { func NewListenerForHttps(l net.Listener, httpServer *http.Server, tlsConfig *Config) net.Listener {
http2Server := &http2.Server{} http2Server := &http.Http2Server{}
_ = http2.ConfigureServer(httpServer, http2Server) _ = http.Http2ConfigureServer(httpServer, http2Server)
return N.NewHandleContextListener(context.Background(), l, func(ctx context.Context, conn net.Conn) (net.Conn, error) { return N.NewHandleContextListener(context.Background(), l, func(ctx context.Context, conn net.Conn) (net.Conn, error) {
c := Server(conn, tlsConfig) c := Server(conn, tlsConfig)
@@ -58,8 +57,8 @@ func NewListenerForHttps(l net.Listener, httpServer *http.Server, tlsConfig *Con
_ = conn.SetWriteDeadline(time.Time{}) _ = conn.SetWriteDeadline(time.Time{})
} }
if c.ConnectionState().NegotiatedProtocol == http2.NextProtoTLS { if c.ConnectionState().NegotiatedProtocol == http.Http2NextProtoTLS {
http2Server.ServeConn(c, &http2.ServeConnOpts{BaseConfig: httpServer}) http2Server.ServeConn(c, &http.Http2ServeConnOpts{BaseConfig: httpServer})
return nil, net.ErrClosed return nil, net.ErrClosed
} }
return c, nil return c, nil

View File

@@ -10,22 +10,21 @@ import (
"crypto/hmac" "crypto/hmac"
"crypto/sha256" "crypto/sha256"
"crypto/sha512" "crypto/sha512"
"crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/binary" "encoding/binary"
"errors" "errors"
"net" "net"
"net/http"
"strings" "strings"
"time" "time"
"github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/log"
"github.com/metacubex/mihomo/ntp" "github.com/metacubex/mihomo/ntp"
"github.com/metacubex/http"
"github.com/metacubex/randv2" "github.com/metacubex/randv2"
"github.com/metacubex/tls"
utls "github.com/metacubex/utls" utls "github.com/metacubex/utls"
"golang.org/x/crypto/hkdf" "golang.org/x/crypto/hkdf"
"golang.org/x/net/http2"
) )
const RealityMaxShortIDLen = 8 const RealityMaxShortIDLen = 8
@@ -37,13 +36,14 @@ type RealityConfig struct {
SupportX25519MLKEM768 bool SupportX25519MLKEM768 bool
} }
func GetRealityConn(ctx context.Context, conn net.Conn, fingerprint UClientHelloID, tlsConfig *Config, realityConfig *RealityConfig) (net.Conn, error) { func GetRealityConn(ctx context.Context, conn net.Conn, fingerprint UClientHelloID, serverName string, realityConfig *RealityConfig) (net.Conn, error) {
for retry := 0; ; retry++ { for retry := 0; ; retry++ {
verifier := &realityVerifier{ verifier := &realityVerifier{
serverName: tlsConfig.ServerName, serverName: serverName,
} }
uConfig := &utls.Config{ uConfig := &utls.Config{
ServerName: tlsConfig.ServerName, Time: ntp.Now,
ServerName: serverName,
InsecureSkipVerify: true, InsecureSkipVerify: true,
SessionTicketsDisabled: true, SessionTicketsDisabled: true,
VerifyPeerCertificate: verifier.VerifyPeerCertificate, VerifyPeerCertificate: verifier.VerifyPeerCertificate,
@@ -132,7 +132,7 @@ func GetRealityConn(ctx context.Context, conn net.Conn, fingerprint UClientHello
func realityClientFallback(uConn net.Conn, serverName string, fingerprint utls.ClientHelloID) { func realityClientFallback(uConn net.Conn, serverName string, fingerprint utls.ClientHelloID) {
defer uConn.Close() defer uConn.Close()
client := http.Client{ client := http.Client{
Transport: &http2.Transport{ Transport: &http.Http2Transport{
DialTLSContext: func(ctx context.Context, network, addr string, config *tls.Config) (net.Conn, error) { DialTLSContext: func(ctx context.Context, network, addr string, config *tls.Config) (net.Conn, error) {
return uConn, nil return uConn, nil
}, },

View File

@@ -1,13 +1,16 @@
package tls package tls
import ( import (
"crypto/tls" "context"
"net" "net"
"reflect"
"unsafe"
"github.com/metacubex/mihomo/common/once" "github.com/metacubex/mihomo/common/once"
"github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/log"
"github.com/metacubex/tls"
utls "github.com/metacubex/utls" utls "github.com/metacubex/utls"
"github.com/mroth/weightedrand/v2" "github.com/mroth/weightedrand/v2"
) )
@@ -124,10 +127,48 @@ func UCertificate(it tls.Certificate) utls.Certificate {
type EncryptedClientHelloKey = utls.EncryptedClientHelloKey type EncryptedClientHelloKey = utls.EncryptedClientHelloKey
func UEncryptedClientHelloKey(it tls.EncryptedClientHelloKey) utls.EncryptedClientHelloKey {
return utls.EncryptedClientHelloKey{
Config: it.Config,
PrivateKey: it.PrivateKey,
SendAsRetry: it.SendAsRetry,
}
}
type Config = utls.Config type Config = utls.Config
var tlsCertificateRequestInfoCtxOffset = utils.MustOK(reflect.TypeOf((*tls.CertificateRequestInfo)(nil)).Elem().FieldByName("ctx")).Offset
var tlsClientHelloInfoCtxOffset = utils.MustOK(reflect.TypeOf((*tls.ClientHelloInfo)(nil)).Elem().FieldByName("ctx")).Offset
var tlsConnectionStateEkmOffset = utils.MustOK(reflect.TypeOf((*tls.ConnectionState)(nil)).Elem().FieldByName("ekm")).Offset
var utlsConnectionStateEkmOffset = utils.MustOK(reflect.TypeOf((*utls.ConnectionState)(nil)).Elem().FieldByName("ekm")).Offset
func tlsConnectionState(state utls.ConnectionState) (tlsState tls.ConnectionState) {
tlsState = tls.ConnectionState{
Version: state.Version,
HandshakeComplete: state.HandshakeComplete,
DidResume: state.DidResume,
CipherSuite: state.CipherSuite,
//CurveID: state.CurveID,
NegotiatedProtocol: state.NegotiatedProtocol,
NegotiatedProtocolIsMutual: state.NegotiatedProtocolIsMutual,
ServerName: state.ServerName,
PeerCertificates: state.PeerCertificates,
VerifiedChains: state.VerifiedChains,
SignedCertificateTimestamps: state.SignedCertificateTimestamps,
OCSPResponse: state.OCSPResponse,
TLSUnique: state.TLSUnique,
ECHAccepted: state.ECHAccepted,
//HelloRetryRequest: state.HelloRetryRequest,
}
// The layout of map, chan, and func types is equivalent to *T.
// state.ekm is a func(label string, context []byte, length int) ([]byte, error)
*(*unsafe.Pointer)(unsafe.Add(unsafe.Pointer(&tlsState), tlsConnectionStateEkmOffset)) =
*(*unsafe.Pointer)(unsafe.Add(unsafe.Pointer(&state), utlsConnectionStateEkmOffset))
return
}
func UConfig(config *tls.Config) *utls.Config { func UConfig(config *tls.Config) *utls.Config {
return &utls.Config{ cfg := &utls.Config{
Rand: config.Rand, Rand: config.Rand,
Time: config.Time, Time: config.Time,
Certificates: utils.Map(config.Certificates, UCertificate), Certificates: utils.Map(config.Certificates, UCertificate),
@@ -146,7 +187,67 @@ func UConfig(config *tls.Config) *utls.Config {
}), }),
SessionTicketsDisabled: config.SessionTicketsDisabled, SessionTicketsDisabled: config.SessionTicketsDisabled,
Renegotiation: utls.RenegotiationSupport(config.Renegotiation), Renegotiation: utls.RenegotiationSupport(config.Renegotiation),
KeyLogWriter: config.KeyLogWriter,
} }
if config.GetClientCertificate != nil {
cfg.GetClientCertificate = func(info *utls.CertificateRequestInfo) (*utls.Certificate, error) {
tlsInfo := &tls.CertificateRequestInfo{
AcceptableCAs: info.AcceptableCAs,
SignatureSchemes: utils.Map(info.SignatureSchemes, func(it utls.SignatureScheme) tls.SignatureScheme {
return tls.SignatureScheme(it)
}),
Version: info.Version,
}
*(*context.Context)(unsafe.Add(unsafe.Pointer(tlsInfo), tlsCertificateRequestInfoCtxOffset)) = info.Context() // for tlsInfo.ctx
cert, err := config.GetClientCertificate(tlsInfo)
if err != nil {
return nil, err
}
uCert := UCertificate(*cert)
return &uCert, err
}
}
if config.GetCertificate != nil {
cfg.GetCertificate = func(info *utls.ClientHelloInfo) (*utls.Certificate, error) {
tlsInfo := &tls.ClientHelloInfo{
CipherSuites: info.CipherSuites,
ServerName: info.ServerName,
SupportedCurves: utils.Map(info.SupportedCurves, func(it utls.CurveID) tls.CurveID {
return tls.CurveID(it)
}),
SupportedPoints: info.SupportedPoints,
SignatureSchemes: utils.Map(info.SignatureSchemes, func(it utls.SignatureScheme) tls.SignatureScheme {
return tls.SignatureScheme(it)
}),
SupportedProtos: info.SupportedProtos,
SupportedVersions: info.SupportedVersions,
Extensions: info.Extensions,
Conn: info.Conn,
//HelloRetryRequest: info.HelloRetryRequest,
}
*(*context.Context)(unsafe.Add(unsafe.Pointer(tlsInfo), tlsClientHelloInfoCtxOffset)) = info.Context() // for tlsInfo.ctx
cert, err := config.GetCertificate(tlsInfo)
if err != nil {
return nil, err
}
uCert := UCertificate(*cert)
return &uCert, err
}
}
if config.VerifyConnection != nil {
cfg.VerifyConnection = func(state utls.ConnectionState) error {
return config.VerifyConnection(tlsConnectionState(state))
}
}
config.EncryptedClientHelloConfigList = cfg.EncryptedClientHelloConfigList
if config.EncryptedClientHelloRejectionVerify != nil {
cfg.EncryptedClientHelloRejectionVerify = func(state utls.ConnectionState) error {
return config.EncryptedClientHelloRejectionVerify(tlsConnectionState(state))
}
}
//cfg.GetEncryptedClientHelloKeys =
cfg.EncryptedClientHelloKeys = utils.Map(config.EncryptedClientHelloKeys, UEncryptedClientHelloKey)
return cfg
} }
// BuildWebsocketHandshakeState it will only send http/1.1 in its ALPN. // BuildWebsocketHandshakeState it will only send http/1.1 in its ALPN.

View File

@@ -6,7 +6,6 @@ import (
"context" "context"
"fmt" "fmt"
"io" "io"
"net/http"
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
@@ -20,6 +19,8 @@ import (
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/constant/features" "github.com/metacubex/mihomo/constant/features"
"github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/log"
"github.com/metacubex/http"
) )
const ( const (

View File

@@ -3,11 +3,12 @@ package updater
import ( import (
"context" "context"
"io" "io"
"net/http"
"os" "os"
"time" "time"
mihomoHttp "github.com/metacubex/mihomo/component/http" mihomoHttp "github.com/metacubex/mihomo/component/http"
"github.com/metacubex/http"
) )
const defaultHttpTimeout = time.Second * 90 const defaultHttpTimeout = time.Second * 90

View File

@@ -15,7 +15,9 @@ import (
"github.com/metacubex/mihomo/adapter/outbound" "github.com/metacubex/mihomo/adapter/outbound"
"github.com/metacubex/mihomo/adapter/outboundgroup" "github.com/metacubex/mihomo/adapter/outboundgroup"
"github.com/metacubex/mihomo/adapter/provider" "github.com/metacubex/mihomo/adapter/provider"
"github.com/metacubex/mihomo/common/orderedmap"
"github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/common/yaml"
"github.com/metacubex/mihomo/component/auth" "github.com/metacubex/mihomo/component/auth"
"github.com/metacubex/mihomo/component/cidr" "github.com/metacubex/mihomo/component/cidr"
"github.com/metacubex/mihomo/component/fakeip" "github.com/metacubex/mihomo/component/fakeip"
@@ -34,11 +36,10 @@ import (
R "github.com/metacubex/mihomo/rules" R "github.com/metacubex/mihomo/rules"
RC "github.com/metacubex/mihomo/rules/common" RC "github.com/metacubex/mihomo/rules/common"
RP "github.com/metacubex/mihomo/rules/provider" RP "github.com/metacubex/mihomo/rules/provider"
RW "github.com/metacubex/mihomo/rules/wrapper"
T "github.com/metacubex/mihomo/tunnel" T "github.com/metacubex/mihomo/tunnel"
orderedmap "github.com/wk8/go-ordered-map/v2"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"gopkg.in/yaml.v3"
) )
// General config // General config
@@ -165,6 +166,7 @@ type DNS struct {
FakeIPTTL int FakeIPTTL int
NameServerPolicy []dns.Policy NameServerPolicy []dns.Policy
ProxyServerNameserver []dns.NameServer ProxyServerNameserver []dns.NameServer
ProxyServerPolicy []dns.Policy
DirectNameServer []dns.NameServer DirectNameServer []dns.NameServer
DirectFollowPolicy bool DirectFollowPolicy bool
} }
@@ -235,6 +237,7 @@ type RawDNS struct {
CacheMaxSize int `yaml:"cache-max-size" json:"cache-max-size"` CacheMaxSize int `yaml:"cache-max-size" json:"cache-max-size"`
NameServerPolicy *orderedmap.OrderedMap[string, any] `yaml:"nameserver-policy" json:"nameserver-policy"` NameServerPolicy *orderedmap.OrderedMap[string, any] `yaml:"nameserver-policy" json:"nameserver-policy"`
ProxyServerNameserver []string `yaml:"proxy-server-nameserver" json:"proxy-server-nameserver"` ProxyServerNameserver []string `yaml:"proxy-server-nameserver" json:"proxy-server-nameserver"`
ProxyServerNameserverPolicy *orderedmap.OrderedMap[string, any] `yaml:"proxy-server-nameserver-policy" json:"proxy-server-nameserver-policy"`
DirectNameServer []string `yaml:"direct-nameserver" json:"direct-nameserver"` DirectNameServer []string `yaml:"direct-nameserver" json:"direct-nameserver"`
DirectNameServerFollowPolicy bool `yaml:"direct-nameserver-follow-policy" json:"direct-nameserver-follow-policy"` DirectNameServerFollowPolicy bool `yaml:"direct-nameserver-follow-policy" json:"direct-nameserver-follow-policy"`
} }
@@ -273,35 +276,36 @@ type RawTun struct {
GSO bool `yaml:"gso" json:"gso,omitempty"` GSO bool `yaml:"gso" json:"gso,omitempty"`
GSOMaxSize uint32 `yaml:"gso-max-size" json:"gso-max-size,omitempty"` GSOMaxSize uint32 `yaml:"gso-max-size" json:"gso-max-size,omitempty"`
//Inet4Address []netip.Prefix `yaml:"inet4-address" json:"inet4-address,omitempty"` //Inet4Address []netip.Prefix `yaml:"inet4-address" json:"inet4-address,omitempty"`
Inet6Address []netip.Prefix `yaml:"inet6-address" json:"inet6-address,omitempty"` Inet6Address []netip.Prefix `yaml:"inet6-address" json:"inet6-address,omitempty"`
IPRoute2TableIndex int `yaml:"iproute2-table-index" json:"iproute2-table-index,omitempty"` IPRoute2TableIndex int `yaml:"iproute2-table-index" json:"iproute2-table-index,omitempty"`
IPRoute2RuleIndex int `yaml:"iproute2-rule-index" json:"iproute2-rule-index,omitempty"` IPRoute2RuleIndex int `yaml:"iproute2-rule-index" json:"iproute2-rule-index,omitempty"`
AutoRedirect bool `yaml:"auto-redirect" json:"auto-redirect,omitempty"` AutoRedirect bool `yaml:"auto-redirect" json:"auto-redirect,omitempty"`
AutoRedirectInputMark uint32 `yaml:"auto-redirect-input-mark" json:"auto-redirect-input-mark,omitempty"` AutoRedirectInputMark uint32 `yaml:"auto-redirect-input-mark" json:"auto-redirect-input-mark,omitempty"`
AutoRedirectOutputMark uint32 `yaml:"auto-redirect-output-mark" json:"auto-redirect-output-mark,omitempty"` AutoRedirectOutputMark uint32 `yaml:"auto-redirect-output-mark" json:"auto-redirect-output-mark,omitempty"`
LoopbackAddress []netip.Addr `yaml:"loopback-address" json:"loopback-address,omitempty"` AutoRedirectIPRoute2FallbackRuleIndex int `yaml:"auto-redirect-iproute2-fallback-rule-index" json:"auto-redirect-iproute2-fallback-rule-index,omitempty"`
StrictRoute bool `yaml:"strict-route" json:"strict-route,omitempty"` LoopbackAddress []netip.Addr `yaml:"loopback-address" json:"loopback-address,omitempty"`
RouteAddress []netip.Prefix `yaml:"route-address" json:"route-address,omitempty"` StrictRoute bool `yaml:"strict-route" json:"strict-route,omitempty"`
RouteAddressSet []string `yaml:"route-address-set" json:"route-address-set,omitempty"` RouteAddress []netip.Prefix `yaml:"route-address" json:"route-address,omitempty"`
RouteExcludeAddress []netip.Prefix `yaml:"route-exclude-address" json:"route-exclude-address,omitempty"` RouteAddressSet []string `yaml:"route-address-set" json:"route-address-set,omitempty"`
RouteExcludeAddressSet []string `yaml:"route-exclude-address-set" json:"route-exclude-address-set,omitempty"` RouteExcludeAddress []netip.Prefix `yaml:"route-exclude-address" json:"route-exclude-address,omitempty"`
IncludeInterface []string `yaml:"include-interface" json:"include-interface,omitempty"` RouteExcludeAddressSet []string `yaml:"route-exclude-address-set" json:"route-exclude-address-set,omitempty"`
ExcludeInterface []string `yaml:"exclude-interface" json:"exclude-interface,omitempty"` IncludeInterface []string `yaml:"include-interface" json:"include-interface,omitempty"`
IncludeUID []uint32 `yaml:"include-uid" json:"include-uid,omitempty"` ExcludeInterface []string `yaml:"exclude-interface" json:"exclude-interface,omitempty"`
IncludeUIDRange []string `yaml:"include-uid-range" json:"include-uid-range,omitempty"` IncludeUID []uint32 `yaml:"include-uid" json:"include-uid,omitempty"`
ExcludeUID []uint32 `yaml:"exclude-uid" json:"exclude-uid,omitempty"` IncludeUIDRange []string `yaml:"include-uid-range" json:"include-uid-range,omitempty"`
ExcludeUIDRange []string `yaml:"exclude-uid-range" json:"exclude-uid-range,omitempty"` ExcludeUID []uint32 `yaml:"exclude-uid" json:"exclude-uid,omitempty"`
ExcludeSrcPort []uint16 `yaml:"exclude-src-port" json:"exclude-src-port,omitempty"` ExcludeUIDRange []string `yaml:"exclude-uid-range" json:"exclude-uid-range,omitempty"`
ExcludeSrcPortRange []string `yaml:"exclude-src-port-range" json:"exclude-src-port-range,omitempty"` ExcludeSrcPort []uint16 `yaml:"exclude-src-port" json:"exclude-src-port,omitempty"`
ExcludeDstPort []uint16 `yaml:"exclude-dst-port" json:"exclude-dst-port,omitempty"` ExcludeSrcPortRange []string `yaml:"exclude-src-port-range" json:"exclude-src-port-range,omitempty"`
ExcludeDstPortRange []string `yaml:"exclude-dst-port-range" json:"exclude-dst-port-range,omitempty"` ExcludeDstPort []uint16 `yaml:"exclude-dst-port" json:"exclude-dst-port,omitempty"`
IncludeAndroidUser []int `yaml:"include-android-user" json:"include-android-user,omitempty"` ExcludeDstPortRange []string `yaml:"exclude-dst-port-range" json:"exclude-dst-port-range,omitempty"`
IncludePackage []string `yaml:"include-package" json:"include-package,omitempty"` IncludeAndroidUser []int `yaml:"include-android-user" json:"include-android-user,omitempty"`
ExcludePackage []string `yaml:"exclude-package" json:"exclude-package,omitempty"` IncludePackage []string `yaml:"include-package" json:"include-package,omitempty"`
EndpointIndependentNat bool `yaml:"endpoint-independent-nat" json:"endpoint-independent-nat,omitempty"` ExcludePackage []string `yaml:"exclude-package" json:"exclude-package,omitempty"`
UDPTimeout int64 `yaml:"udp-timeout" json:"udp-timeout,omitempty"` EndpointIndependentNat bool `yaml:"endpoint-independent-nat" json:"endpoint-independent-nat,omitempty"`
DisableICMPForwarding bool `yaml:"disable-icmp-forwarding" json:"disable-icmp-forwarding,omitempty"` UDPTimeout int64 `yaml:"udp-timeout" json:"udp-timeout,omitempty"`
FileDescriptor int `yaml:"file-descriptor" json:"file-descriptor"` DisableICMPForwarding bool `yaml:"disable-icmp-forwarding" json:"disable-icmp-forwarding,omitempty"`
FileDescriptor int `yaml:"file-descriptor" json:"file-descriptor"`
Inet4RouteAddress []netip.Prefix `yaml:"inet4-route-address" json:"inet4-route-address,omitempty"` Inet4RouteAddress []netip.Prefix `yaml:"inet4-route-address" json:"inet4-route-address,omitempty"`
Inet6RouteAddress []netip.Prefix `yaml:"inet6-route-address" json:"inet6-route-address,omitempty"` Inet6RouteAddress []netip.Prefix `yaml:"inet6-route-address" json:"inet6-route-address,omitempty"`
@@ -954,6 +958,12 @@ func parseProxies(cfg *RawConfig) (proxies map[string]C.Proxy, providersMap map[
) )
proxies["GLOBAL"] = adapter.NewProxy(global) proxies["GLOBAL"] = adapter.NewProxy(global)
} }
// validate dialer-proxy references
if err := validateDialerProxies(proxies); err != nil {
return nil, nil, err
}
return proxies, providersMap, nil return proxies, providersMap, nil
} }
@@ -1083,6 +1093,10 @@ func parseRules(rulesConfig []string, proxies map[string]C.Proxy, ruleProviders
} }
} }
if format == "rules" { // only wrap top level rules
parsed = RW.NewRuleWrapper(parsed)
}
rules = append(rules, parsed) rules = append(rules, parsed)
} }
@@ -1101,22 +1115,23 @@ func parseHosts(cfg *RawConfig) (*trie.DomainTrie[resolver.HostValue], error) {
if len(cfg.Hosts) != 0 { if len(cfg.Hosts) != 0 {
for domain, anyValue := range cfg.Hosts { for domain, anyValue := range cfg.Hosts {
if str, ok := anyValue.(string); ok && str == "lan" { hosts, err := utils.ToStringSlice(anyValue)
if err != nil {
return nil, err
}
if len(hosts) == 1 && hosts[0] == "lan" {
if addrs, err := net.InterfaceAddrs(); err != nil { if addrs, err := net.InterfaceAddrs(); err != nil {
log.Errorln("insert lan to host error: %s", err) log.Errorln("insert lan to host error: %s", err)
} else { } else {
ips := make([]netip.Addr, 0) hosts = make([]string, 0, len(addrs))
for _, addr := range addrs { for _, addr := range addrs {
if ipnet, ok := addr.(*net.IPNet); ok && !ipnet.IP.IsLoopback() && !ipnet.IP.IsLinkLocalUnicast() { if ipnet, ok := addr.(*net.IPNet); ok && !ipnet.IP.IsLoopback() && !ipnet.IP.IsLinkLocalUnicast() {
if ip, err := netip.ParseAddr(ipnet.IP.String()); err == nil { hosts = append(hosts, ipnet.IP.String())
ips = append(ips, ip)
}
} }
} }
anyValue = ips
} }
} }
value, err := resolver.NewHostValue(anyValue) value, err := resolver.NewHostValue(hosts)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s is not a valid value", anyValue) return nil, fmt.Errorf("%s is not a valid value", anyValue)
} }
@@ -1184,7 +1199,7 @@ func parseNameServer(servers []string, respectRules bool, preferH3 bool) ([]dns.
dnsNetType = "tcp" // TCP dnsNetType = "tcp" // TCP
case "tls": case "tls":
addr, err = hostWithDefaultPort(u.Host, "853") addr, err = hostWithDefaultPort(u.Host, "853")
dnsNetType = "tcp-tls" // DNS over TLS dnsNetType = "tls" // DNS over TLS
case "http", "https": case "http", "https":
addr, err = hostWithDefaultPort(u.Host, "443") addr, err = hostWithDefaultPort(u.Host, "443")
dnsNetType = "https" // DNS over HTTPS dnsNetType = "https" // DNS over HTTPS
@@ -1292,7 +1307,7 @@ func parseNameServerPolicy(nsPolicy *orderedmap.OrderedMap[string, any], rulePro
} }
kLower := strings.ToLower(k) kLower := strings.ToLower(k)
if strings.Contains(kLower, ",") { if strings.Contains(kLower, ",") {
if strings.Contains(kLower, "geosite:") { if strings.HasPrefix(kLower, "geosite:") {
subkeys := strings.Split(k, ":") subkeys := strings.Split(k, ":")
subkeys = subkeys[1:] subkeys = subkeys[1:]
subkeys = strings.Split(subkeys[0], ",") subkeys = strings.Split(subkeys[0], ",")
@@ -1300,7 +1315,7 @@ func parseNameServerPolicy(nsPolicy *orderedmap.OrderedMap[string, any], rulePro
newKey := "geosite:" + subkey newKey := "geosite:" + subkey
policy = append(policy, dns.Policy{Domain: newKey, NameServers: nameservers}) policy = append(policy, dns.Policy{Domain: newKey, NameServers: nameservers})
} }
} else if strings.Contains(kLower, "rule-set:") { } else if strings.HasPrefix(kLower, "rule-set:") {
subkeys := strings.Split(k, ":") subkeys := strings.Split(k, ":")
subkeys = subkeys[1:] subkeys = subkeys[1:]
subkeys = strings.Split(subkeys[0], ",") subkeys = strings.Split(subkeys[0], ",")
@@ -1315,9 +1330,9 @@ func parseNameServerPolicy(nsPolicy *orderedmap.OrderedMap[string, any], rulePro
} }
} }
} else { } else {
if strings.Contains(kLower, "geosite:") { if strings.HasPrefix(kLower, "geosite:") {
policy = append(policy, dns.Policy{Domain: "geosite:" + k[8:], NameServers: nameservers}) policy = append(policy, dns.Policy{Domain: "geosite:" + k[8:], NameServers: nameservers})
} else if strings.Contains(kLower, "rule-set:") { } else if strings.HasPrefix(kLower, "rule-set:") {
policy = append(policy, dns.Policy{Domain: "rule-set:" + k[9:], NameServers: nameservers}) policy = append(policy, dns.Policy{Domain: "rule-set:" + k[9:], NameServers: nameservers})
} else { } else {
policy = append(policy, dns.Policy{Domain: k, NameServers: nameservers}) policy = append(policy, dns.Policy{Domain: k, NameServers: nameservers})
@@ -1391,6 +1406,13 @@ func parseDNS(rawCfg *RawConfig, ruleProviders map[string]P.RuleProvider) (*DNS,
return nil, err return nil, err
} }
if dnsCfg.ProxyServerPolicy, err = parseNameServerPolicy(cfg.ProxyServerNameserverPolicy, ruleProviders, false, cfg.PreferH3); err != nil {
return nil, err
}
if len(dnsCfg.ProxyServerPolicy) != 0 && len(dnsCfg.ProxyServerNameserver) == 0 {
return nil, errors.New("disallow empty `proxy-server-nameserver` when `proxy-server-nameserver-policy` is set")
}
if dnsCfg.DirectNameServer, err = parseNameServer(cfg.DirectNameServer, false, cfg.PreferH3); err != nil { if dnsCfg.DirectNameServer, err = parseNameServer(cfg.DirectNameServer, false, cfg.PreferH3); err != nil {
return nil, err return nil, err
} }
@@ -1450,16 +1472,22 @@ func parseDNS(rawCfg *RawConfig, ruleProviders map[string]P.RuleProvider) (*DNS,
} }
} }
// fake ip skip host filter skipper := &fakeip.Skipper{Mode: cfg.FakeIPFilterMode}
host, err := parseDomain(cfg.FakeIPFilter, fakeIPTrie, "dns.fake-ip-filter", ruleProviders)
if err != nil { if cfg.FakeIPFilterMode == C.FilterRule {
return nil, err rules, err := parseFakeIPRules(cfg.FakeIPFilter, ruleProviders)
if err != nil {
return nil, err
}
skipper.Rules = rules
} else {
host, err := parseDomain(cfg.FakeIPFilter, fakeIPTrie, "dns.fake-ip-filter", ruleProviders)
if err != nil {
return nil, err
}
skipper.Host = host
} }
skipper := &fakeip.Skipper{
Host: host,
Mode: cfg.FakeIPFilterMode,
}
dnsCfg.FakeIPSkipper = skipper dnsCfg.FakeIPSkipper = skipper
dnsCfg.FakeIPTTL = cfg.FakeIPTTL dnsCfg.FakeIPTTL = cfg.FakeIPTTL
@@ -1541,6 +1569,55 @@ func parseDNS(rawCfg *RawConfig, ruleProviders map[string]P.RuleProvider) (*DNS,
return dnsCfg, nil return dnsCfg, nil
} }
func parseFakeIPRules(rawRules []string, ruleProviders map[string]P.RuleProvider) ([]C.Rule, error) {
var rules []C.Rule
for idx, line := range rawRules {
tp, payload, action, params := RC.ParseRulePayload(line, true)
action = strings.ToLower(action)
if action != fakeip.UseFakeIP && action != fakeip.UseRealIP {
return nil, fmt.Errorf("dns.fake-ip-filter[%d] [%s] error: invalid action '%s', must be 'fake-ip' or 'real-ip'", idx, line, action)
}
if tp == "RULE-SET" {
if rp, ok := ruleProviders[payload]; !ok {
return nil, fmt.Errorf("dns.fake-ip-filter[%d] [%s] error: rule-set '%s' not found", idx, line, payload)
} else {
switch rp.Behavior() {
case P.IPCIDR:
return nil, fmt.Errorf("dns.fake-ip-filter[%d] [%s] error: rule-set behavior is %s, must be domain or classical", idx, line, rp.Behavior())
case P.Classical:
log.Warnln("%s provider is %s, only matching domain rules in fake-ip-filter", rp.Name(), rp.Behavior())
default:
}
}
}
parsed, err := R.ParseRule(tp, payload, action, params, nil)
if err != nil {
return nil, fmt.Errorf("dns.fake-ip-filter[%d] [%s] error: %w", idx, line, err)
}
if !isDomainRule(parsed.RuleType()) && parsed.RuleType() != C.MATCH {
return nil, fmt.Errorf("dns.fake-ip-filter[%d] [%s] error: rule type '%s' not supported, only domain-based rules allowed", idx, line, tp)
}
rules = append(rules, parsed)
}
return rules, nil
}
func isDomainRule(rt C.RuleType) bool {
switch rt {
case C.Domain, C.DomainSuffix, C.DomainKeyword, C.DomainRegex, C.DomainWildcard, C.GEOSITE, C.RuleSet:
return true
default:
return false
}
}
func parseAuthentication(rawRecords []string) []auth.AuthUser { func parseAuthentication(rawRecords []string) []auth.AuthUser {
var users []auth.AuthUser var users []auth.AuthUser
for _, line := range rawRecords { for _, line := range rawRecords {
@@ -1573,39 +1650,40 @@ func parseTun(rawTun RawTun, dns *DNS, general *General) error {
AutoRoute: rawTun.AutoRoute, AutoRoute: rawTun.AutoRoute,
AutoDetectInterface: rawTun.AutoDetectInterface, AutoDetectInterface: rawTun.AutoDetectInterface,
MTU: rawTun.MTU, MTU: rawTun.MTU,
GSO: rawTun.GSO, GSO: rawTun.GSO,
GSOMaxSize: rawTun.GSOMaxSize, GSOMaxSize: rawTun.GSOMaxSize,
Inet4Address: []netip.Prefix{tunAddressPrefix}, Inet4Address: []netip.Prefix{tunAddressPrefix},
Inet6Address: rawTun.Inet6Address, Inet6Address: rawTun.Inet6Address,
IPRoute2TableIndex: rawTun.IPRoute2TableIndex, IPRoute2TableIndex: rawTun.IPRoute2TableIndex,
IPRoute2RuleIndex: rawTun.IPRoute2RuleIndex, IPRoute2RuleIndex: rawTun.IPRoute2RuleIndex,
AutoRedirect: rawTun.AutoRedirect, AutoRedirect: rawTun.AutoRedirect,
AutoRedirectInputMark: rawTun.AutoRedirectInputMark, AutoRedirectInputMark: rawTun.AutoRedirectInputMark,
AutoRedirectOutputMark: rawTun.AutoRedirectOutputMark, AutoRedirectOutputMark: rawTun.AutoRedirectOutputMark,
LoopbackAddress: rawTun.LoopbackAddress, AutoRedirectIPRoute2FallbackRuleIndex: rawTun.AutoRedirectIPRoute2FallbackRuleIndex,
StrictRoute: rawTun.StrictRoute, LoopbackAddress: rawTun.LoopbackAddress,
RouteAddress: rawTun.RouteAddress, StrictRoute: rawTun.StrictRoute,
RouteAddressSet: rawTun.RouteAddressSet, RouteAddress: rawTun.RouteAddress,
RouteExcludeAddress: rawTun.RouteExcludeAddress, RouteAddressSet: rawTun.RouteAddressSet,
RouteExcludeAddressSet: rawTun.RouteExcludeAddressSet, RouteExcludeAddress: rawTun.RouteExcludeAddress,
IncludeInterface: rawTun.IncludeInterface, RouteExcludeAddressSet: rawTun.RouteExcludeAddressSet,
ExcludeInterface: rawTun.ExcludeInterface, IncludeInterface: rawTun.IncludeInterface,
IncludeUID: rawTun.IncludeUID, ExcludeInterface: rawTun.ExcludeInterface,
IncludeUIDRange: rawTun.IncludeUIDRange, IncludeUID: rawTun.IncludeUID,
ExcludeUID: rawTun.ExcludeUID, IncludeUIDRange: rawTun.IncludeUIDRange,
ExcludeUIDRange: rawTun.ExcludeUIDRange, ExcludeUID: rawTun.ExcludeUID,
ExcludeSrcPort: rawTun.ExcludeSrcPort, ExcludeUIDRange: rawTun.ExcludeUIDRange,
ExcludeSrcPortRange: rawTun.ExcludeSrcPortRange, ExcludeSrcPort: rawTun.ExcludeSrcPort,
ExcludeDstPort: rawTun.ExcludeDstPort, ExcludeSrcPortRange: rawTun.ExcludeSrcPortRange,
ExcludeDstPortRange: rawTun.ExcludeDstPortRange, ExcludeDstPort: rawTun.ExcludeDstPort,
IncludeAndroidUser: rawTun.IncludeAndroidUser, ExcludeDstPortRange: rawTun.ExcludeDstPortRange,
IncludePackage: rawTun.IncludePackage, IncludeAndroidUser: rawTun.IncludeAndroidUser,
ExcludePackage: rawTun.ExcludePackage, IncludePackage: rawTun.IncludePackage,
EndpointIndependentNat: rawTun.EndpointIndependentNat, ExcludePackage: rawTun.ExcludePackage,
UDPTimeout: rawTun.UDPTimeout, EndpointIndependentNat: rawTun.EndpointIndependentNat,
DisableICMPForwarding: rawTun.DisableICMPForwarding, UDPTimeout: rawTun.UDPTimeout,
FileDescriptor: rawTun.FileDescriptor, DisableICMPForwarding: rawTun.DisableICMPForwarding,
FileDescriptor: rawTun.FileDescriptor,
Inet4RouteAddress: rawTun.Inet4RouteAddress, Inet4RouteAddress: rawTun.Inet4RouteAddress,
Inet6RouteAddress: rawTun.Inet6RouteAddress, Inet6RouteAddress: rawTun.Inet6RouteAddress,
@@ -1712,7 +1790,7 @@ func parseSniffer(snifferRaw RawSniffer, ruleProviders map[string]P.RuleProvider
} }
snifferConfig.SkipSrcAddress = skipSrcAddress snifferConfig.SkipSrcAddress = skipSrcAddress
skipDstAddress, err := parseIPCIDR(snifferRaw.SkipDstAddress, nil, "sniffer.skip-src-address", ruleProviders) skipDstAddress, err := parseIPCIDR(snifferRaw.SkipDstAddress, nil, "sniffer.skip-dst-address", ruleProviders)
if err != nil { if err != nil {
return nil, fmt.Errorf("error in skip-dst-address, error:%w", err) return nil, fmt.Errorf("error in skip-dst-address, error:%w", err)
} }
@@ -1731,7 +1809,7 @@ func parseIPCIDR(addresses []string, cidrSet *cidr.IpCidrSet, adapterName string
var matcher C.IpMatcher var matcher C.IpMatcher
for _, ipcidr := range addresses { for _, ipcidr := range addresses {
ipcidrLower := strings.ToLower(ipcidr) ipcidrLower := strings.ToLower(ipcidr)
if strings.Contains(ipcidrLower, "geoip:") { if strings.HasPrefix(ipcidrLower, "geoip:") {
subkeys := strings.Split(ipcidr, ":") subkeys := strings.Split(ipcidr, ":")
subkeys = subkeys[1:] subkeys = subkeys[1:]
subkeys = strings.Split(subkeys[0], ",") subkeys = strings.Split(subkeys[0], ",")
@@ -1742,7 +1820,7 @@ func parseIPCIDR(addresses []string, cidrSet *cidr.IpCidrSet, adapterName string
} }
matchers = append(matchers, matcher) matchers = append(matchers, matcher)
} }
} else if strings.Contains(ipcidrLower, "rule-set:") { } else if strings.HasPrefix(ipcidrLower, "rule-set:") {
subkeys := strings.Split(ipcidr, ":") subkeys := strings.Split(ipcidr, ":")
subkeys = subkeys[1:] subkeys = subkeys[1:]
subkeys = strings.Split(subkeys[0], ",") subkeys = strings.Split(subkeys[0], ",")
@@ -1778,7 +1856,7 @@ func parseDomain(domains []string, domainTrie *trie.DomainTrie[struct{}], adapte
var matcher C.DomainMatcher var matcher C.DomainMatcher
for _, domain := range domains { for _, domain := range domains {
domainLower := strings.ToLower(domain) domainLower := strings.ToLower(domain)
if strings.Contains(domainLower, "geosite:") { if strings.HasPrefix(domainLower, "geosite:") {
subkeys := strings.Split(domain, ":") subkeys := strings.Split(domain, ":")
subkeys = subkeys[1:] subkeys = subkeys[1:]
subkeys = strings.Split(subkeys[0], ",") subkeys = strings.Split(subkeys[0], ",")
@@ -1789,7 +1867,7 @@ func parseDomain(domains []string, domainTrie *trie.DomainTrie[struct{}], adapte
} }
matchers = append(matchers, matcher) matchers = append(matchers, matcher)
} }
} else if strings.Contains(domainLower, "rule-set:") { } else if strings.HasPrefix(domainLower, "rule-set:") {
subkeys := strings.Split(domain, ":") subkeys := strings.Split(domain, ":")
subkeys = subkeys[1:] subkeys = subkeys[1:]
subkeys = strings.Split(subkeys[0], ",") subkeys = strings.Split(subkeys[0], ",")

View File

@@ -9,6 +9,7 @@ import (
"github.com/metacubex/mihomo/adapter/outboundgroup" "github.com/metacubex/mihomo/adapter/outboundgroup"
"github.com/metacubex/mihomo/common/structure" "github.com/metacubex/mihomo/common/structure"
C "github.com/metacubex/mihomo/constant"
) )
// Check if ProxyGroups form DAG(Directed Acyclic Graph), and sort all ProxyGroups by dependency order. // Check if ProxyGroups form DAG(Directed Acyclic Graph), and sort all ProxyGroups by dependency order.
@@ -143,6 +144,64 @@ func proxyGroupsDagSort(groupsConfig []map[string]any) error {
return fmt.Errorf("loop is detected in ProxyGroup, please check following ProxyGroups: %v", loopElements) return fmt.Errorf("loop is detected in ProxyGroup, please check following ProxyGroups: %v", loopElements)
} }
// validateDialerProxies checks if all dialer-proxy references are valid
func validateDialerProxies(proxies map[string]C.Proxy) error {
graph := make(map[string]string) // proxy name -> dialer-proxy name
// collect all proxies with dialer-proxy configured
for name, proxy := range proxies {
dialerProxy := proxy.ProxyInfo().DialerProxy
if dialerProxy != "" {
// validate each dialer-proxy reference
_, exist := proxies[dialerProxy]
if !exist {
return fmt.Errorf("proxy [%s] dialer-proxy [%s] not found", name, dialerProxy)
}
// build dependency graph
graph[name] = dialerProxy
}
}
// perform depth-first search to detect cycles for each proxy
for name := range graph {
visited := make(map[string]bool, len(graph))
path := make([]string, 0, len(graph))
if validateDialerProxiesHasCycle(name, graph, visited, path) {
return fmt.Errorf("proxy [%s] has circular dialer-proxy dependency", name)
}
}
return nil
}
// validateDialerProxiesHasCycle performs DFS to detect if there's a cycle starting from current proxy
func validateDialerProxiesHasCycle(current string, graph map[string]string, visited map[string]bool, path []string) bool {
// check if current is already in path (cycle detected)
for _, p := range path {
if p == current {
return true
}
}
// already visited and no cycle
if visited[current] {
return false
}
visited[current] = true
path = append(path, current)
// check dialer-proxy of current proxy
if dialerProxy, exists := graph[current]; exists {
if validateDialerProxiesHasCycle(dialerProxy, graph, visited, path) {
return true
}
}
return false
}
func verifyIP6() bool { func verifyIP6() bool {
if skip, _ := strconv.ParseBool(os.Getenv("SKIP_SYSTEM_IPV6_CHECK")); skip { if skip, _ := strconv.ParseBool(os.Getenv("SKIP_SYSTEM_IPV6_CHECK")); skip {
return true return true

79
config/utils_test.go Normal file
View File

@@ -0,0 +1,79 @@
package config
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestValidateDialerProxies(t *testing.T) {
testCases := []struct {
testName string
proxy []map[string]any
errContains string
}{
{
testName: "ValidReference",
proxy: []map[string]any{ // create proxy with valid dialer-proxy reference
{"name": "base-proxy", "type": "socks5", "server": "127.0.0.1", "port": 1080},
{"name": "proxy-with-dialer", "type": "socks5", "server": "127.0.0.1", "port": 1081, "dialer-proxy": "base-proxy"},
},
errContains: "",
},
{
testName: "NotFoundReference",
proxy: []map[string]any{ // create proxy with non-existent dialer-proxy reference
{"name": "proxy-with-dialer", "type": "socks5", "server": "127.0.0.1", "port": 1081, "dialer-proxy": "non-existent-proxy"},
},
errContains: "not found",
},
{
testName: "CircularDependency",
proxy: []map[string]any{
// create proxy A that references B
{"name": "proxy-a", "type": "socks5", "server": "127.0.0.1", "port": 1080, "dialer-proxy": "proxy-c"},
// create proxy B that references C
{"name": "proxy-b", "type": "socks5", "server": "127.0.0.1", "port": 1081, "dialer-proxy": "proxy-a"},
// create proxy C that references A (creates cycle)
{"name": "proxy-c", "type": "socks5", "server": "127.0.0.1", "port": 1082, "dialer-proxy": "proxy-a"},
},
errContains: "circular",
},
{
testName: "ComplexChain",
proxy: []map[string]any{ // create a valid chain: proxy-d -> proxy-c -> proxy-b -> proxy-a
{"name": "proxy-a", "type": "socks5", "server": "127.0.0.1", "port": 1080},
{"name": "proxy-b", "type": "socks5", "server": "127.0.0.1", "port": 1081, "dialer-proxy": "proxy-a"},
{"name": "proxy-c", "type": "socks5", "server": "127.0.0.1", "port": 1082, "dialer-proxy": "proxy-b"},
{"name": "proxy-d", "type": "socks5", "server": "127.0.0.1", "port": 1083, "dialer-proxy": "proxy-c"},
},
errContains: "",
},
{
testName: "EmptyDialerProxy",
proxy: []map[string]any{ // create proxy without dialer-proxy
{"name": "simple-proxy", "type": "socks5", "server": "127.0.0.1", "port": 1080},
},
errContains: "",
},
{
testName: "SelfReference",
proxy: []map[string]any{ // create proxy that references itself
{"name": "self-proxy", "type": "socks5", "server": "127.0.0.1", "port": 1080, "dialer-proxy": "self-proxy"},
},
errContains: "circular",
},
}
for _, testCase := range testCases {
t.Run(testCase.testName, func(t *testing.T) {
config := RawConfig{Proxy: testCase.proxy}
_, _, err := parseProxies(&config)
if testCase.errContains == "" {
assert.NoError(t, err, testCase.testName)
} else {
assert.ErrorContains(t, err, testCase.errContains, testCase.testName)
}
})
}
}

View File

@@ -45,6 +45,7 @@ const (
Mieru Mieru
AnyTLS AnyTLS
Sudoku Sudoku
Masque
) )
const ( const (
@@ -59,6 +60,7 @@ var ErrNotSupport = errors.New("no support")
type Connection interface { type Connection interface {
Chains() Chain Chains() Chain
ProviderChains() Chain
AppendToChains(adapter ProxyAdapter) AppendToChains(adapter ProxyAdapter)
RemoteDestination() string RemoteDestination() string
} }
@@ -102,13 +104,14 @@ type Dialer interface {
} }
type ProxyInfo struct { type ProxyInfo struct {
XUDP bool XUDP bool
TFO bool TFO bool
MPTCP bool MPTCP bool
SMUX bool SMUX bool
Interface string Interface string
RoutingMark int RoutingMark int
DialerProxy string ProviderName string
DialerProxy string
} }
type ProxyAdapter interface { type ProxyAdapter interface {
@@ -121,17 +124,6 @@ type ProxyAdapter interface {
ProxyInfo() ProxyInfo ProxyInfo() ProxyInfo
MarshalJSON() ([]byte, error) MarshalJSON() ([]byte, error)
// Deprecated: use DialContextWithDialer and ListenPacketWithDialer instead.
// StreamConn wraps a protocol around net.Conn with Metadata.
//
// Examples:
// conn, _ := net.DialContext(context.Background(), "tcp", "host:port")
// conn, _ = adapter.StreamConnContext(context.Background(), conn, metadata)
//
// It returns a C.Conn with protocol which start with
// a new session (if any)
StreamConnContext(ctx context.Context, c net.Conn, metadata *Metadata) (net.Conn, error)
// DialContext return a C.Conn with protocol which // DialContext return a C.Conn with protocol which
// contains multiplexing-related reuse logic (if any) // contains multiplexing-related reuse logic (if any)
DialContext(ctx context.Context, metadata *Metadata) (Conn, error) DialContext(ctx context.Context, metadata *Metadata) (Conn, error)
@@ -140,13 +132,6 @@ type ProxyAdapter interface {
// SupportUOT return UDP over TCP support // SupportUOT return UDP over TCP support
SupportUOT() bool SupportUOT() bool
// SupportWithDialer only for deprecated relay group, the new protocol does not need to be implemented.
SupportWithDialer() NetWork
// DialContextWithDialer only for deprecated relay group, the new protocol does not need to be implemented.
DialContextWithDialer(ctx context.Context, dialer Dialer, metadata *Metadata) (Conn, error)
// ListenPacketWithDialer only for deprecated relay group, the new protocol does not need to be implemented.
ListenPacketWithDialer(ctx context.Context, dialer Dialer, metadata *Metadata) (PacketConn, error)
// IsL3Protocol return ProxyAdapter working in L3 (tell dns module not pass the domain to avoid loopback) // IsL3Protocol return ProxyAdapter working in L3 (tell dns module not pass the domain to avoid loopback)
IsL3Protocol(metadata *Metadata) bool IsL3Protocol(metadata *Metadata) bool
@@ -157,11 +142,6 @@ type ProxyAdapter interface {
Close() error Close() error
} }
type Group interface {
URLTest(ctx context.Context, url string, expectedStatus utils.IntRanges[uint16]) (mp map[string]uint16, err error)
Touch()
}
type DelayHistory struct { type DelayHistory struct {
Time time.Time `json:"time"` Time time.Time `json:"time"`
Delay uint16 `json:"delay"` Delay uint16 `json:"delay"`
@@ -233,6 +213,8 @@ func (at AdapterType) String() string {
return "AnyTLS" return "AnyTLS"
case Sudoku: case Sudoku:
return "Sudoku" return "Sudoku"
case Masque:
return "Masque"
case Relay: case Relay:
return "Relay" return "Relay"
case Selector: case Selector:

View File

@@ -86,18 +86,24 @@ func (d DNSPrefer) String() string {
} }
} }
func NewDNSPrefer(prefer string) DNSPrefer { func (d DNSPrefer) MarshalText() ([]byte, error) {
if p, ok := dnsPreferMap[prefer]; ok { return []byte(d.String()), nil
return p }
} else {
return DualStack func (d *DNSPrefer) UnmarshalText(data []byte) error {
p, exist := dnsPreferMap[strings.ToLower(string(data))]
if !exist {
p = DualStack
} }
*d = p
return nil
} }
// FilterModeMapping is a mapping for FilterMode enum // FilterModeMapping is a mapping for FilterMode enum
var FilterModeMapping = map[string]FilterMode{ var FilterModeMapping = map[string]FilterMode{
FilterBlackList.String(): FilterBlackList, FilterBlackList.String(): FilterBlackList,
FilterWhiteList.String(): FilterWhiteList, FilterWhiteList.String(): FilterWhiteList,
FilterRule.String(): FilterRule,
} }
type FilterMode int type FilterMode int
@@ -105,6 +111,7 @@ type FilterMode int
const ( const (
FilterBlackList FilterMode = iota FilterBlackList FilterMode = iota
FilterWhiteList FilterWhiteList
FilterRule
) )
func (e FilterMode) String() string { func (e FilterMode) String() string {
@@ -113,6 +120,8 @@ func (e FilterMode) String() string {
return "blacklist" return "blacklist"
case FilterWhiteList: case FilterWhiteList:
return "whitelist" return "whitelist"
case FilterRule:
return "rule"
default: default:
return "unknown" return "unknown"
} }

View File

@@ -1,5 +1,7 @@
package constant package constant
import "time"
// Rule Type // Rule Type
const ( const (
Domain RuleType = iota Domain RuleType = iota
@@ -27,6 +29,8 @@ const (
ProcessPath ProcessPath
ProcessNameRegex ProcessNameRegex
ProcessPathRegex ProcessPathRegex
ProcessNameWildcard
ProcessPathWildcard
RuleSet RuleSet
Network Network
Uid Uid
@@ -89,6 +93,10 @@ func (rt RuleType) String() string {
return "ProcessNameRegex" return "ProcessNameRegex"
case ProcessPathRegex: case ProcessPathRegex:
return "ProcessPathRegex" return "ProcessPathRegex"
case ProcessNameWildcard:
return "ProcessNameWildcard"
case ProcessPathWildcard:
return "ProcessPathWildcard"
case MATCH: case MATCH:
return "Match" return "Match"
case RuleSet: case RuleSet:
@@ -120,6 +128,27 @@ type Rule interface {
ProviderNames() []string ProviderNames() []string
} }
type RuleWrapper interface {
Rule
// SetDisabled to set enable/disable rule
SetDisabled(v bool)
// IsDisabled return rule is disabled or not
IsDisabled() bool
// HitCount for statistics
HitCount() uint64
// HitAt for statistics
HitAt() time.Time
// MissCount for statistics
MissCount() uint64
// MissAt for statistics
MissAt() time.Time
// Unwrap return Rule
Unwrap() Rule
}
type RuleMatchHelper struct { type RuleMatchHelper struct {
ResolveIP func() ResolveIP func()
FindProcess func() FindProcess func()

View File

@@ -2,13 +2,11 @@ package dns
import ( import (
"context" "context"
"crypto/tls"
"fmt" "fmt"
"net" "net"
"strings" "strings"
"time" "time"
"github.com/metacubex/mihomo/component/ca"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/log"
@@ -16,11 +14,10 @@ import (
) )
type client struct { type client struct {
port string port string
host string host string
dialer *dnsDialer dialer *dnsDialer
schema string schema string
skipCertVerify bool
} }
var _ dnsClient = (*client)(nil) var _ dnsClient = (*client)(nil)
@@ -43,23 +40,6 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, error)
} }
defer conn.Close() defer conn.Close()
if c.schema == "tls" {
tlsConfig, err := ca.GetTLSConfig(ca.Option{
TLSConfig: &tls.Config{
ServerName: c.host,
InsecureSkipVerify: c.skipCertVerify,
},
})
if err != nil {
return nil, err
}
tlsConn := tls.Client(conn, tlsConfig)
if err := tlsConn.HandshakeContext(ctx); err != nil {
return nil, err
}
conn = tlsConn
}
// miekg/dns ExchangeContext doesn't respond to context cancel. // miekg/dns ExchangeContext doesn't respond to context cancel.
// this is a workaround // this is a workaround
type result struct { type result struct {
@@ -117,12 +97,6 @@ func newClient(addr string, resolver *Resolver, netType string, params map[strin
} }
if strings.HasPrefix(netType, "tcp") { if strings.HasPrefix(netType, "tcp") {
c.schema = "tcp" c.schema = "tcp"
if strings.HasSuffix(netType, "tls") {
c.schema = "tls"
}
}
if params["skip-cert-verify"] == "true" {
c.skipCertVerify = true
} }
return c return c
} }

View File

@@ -2,13 +2,11 @@ package dns
import ( import (
"context" "context"
"crypto/tls"
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net" "net"
"net/http"
"net/url" "net/url"
"runtime" "runtime"
"strconv" "strconv"
@@ -16,15 +14,15 @@ import (
"time" "time"
"github.com/metacubex/mihomo/component/ca" "github.com/metacubex/mihomo/component/ca"
tlsC "github.com/metacubex/mihomo/component/tls"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/log"
"github.com/metacubex/http"
"github.com/metacubex/quic-go" "github.com/metacubex/quic-go"
"github.com/metacubex/quic-go/http3" "github.com/metacubex/quic-go/http3"
"github.com/metacubex/tls"
D "github.com/miekg/dns" D "github.com/miekg/dns"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"golang.org/x/net/http2"
) )
// Values to configure HTTP and HTTP/2 transport. // Values to configure HTTP and HTTP/2 transport.
@@ -439,8 +437,8 @@ func (doh *dnsOverHTTPS) createTransport(ctx context.Context) (t http.RoundTripp
// Explicitly configure transport to use HTTP/2. // Explicitly configure transport to use HTTP/2.
// //
// See https://github.com/AdguardTeam/dnsproxy/issues/11. // See https://github.com/AdguardTeam/dnsproxy/issues/11.
var transportH2 *http2.Transport var transportH2 *http.Http2Transport
transportH2, err = http2.ConfigureTransports(transport) transportH2, err = http.Http2ConfigureTransports(transport)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -530,20 +528,20 @@ func (doh *dnsOverHTTPS) createTransportH3(
// Ignore the address and always connect to the one that we got // Ignore the address and always connect to the one that we got
// from the bootstrapper. // from the bootstrapper.
_ string, _ string,
tlsCfg *tlsC.Config, tlsCfg *tls.Config,
cfg *quic.Config, cfg *quic.Config,
) (c *quic.Conn, err error) { ) (c *quic.Conn, err error) {
return doh.dialQuic(ctx, addr, tlsCfg, cfg) return doh.dialQuic(ctx, addr, tlsCfg, cfg)
}, },
DisableCompression: true, DisableCompression: true,
TLSClientConfig: tlsC.UConfig(tlsConfig), TLSClientConfig: tlsConfig,
QUICConfig: doh.getQUICConfig(), QUICConfig: doh.getQUICConfig(),
} }
return &http3Transport{baseTransport: rt}, nil return &http3Transport{baseTransport: rt}, nil
} }
func (doh *dnsOverHTTPS) dialQuic(ctx context.Context, addr string, tlsCfg *tlsC.Config, cfg *quic.Config) (*quic.Conn, error) { func (doh *dnsOverHTTPS) dialQuic(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
ip, port, err := net.SplitHostPort(addr) ip, port, err := net.SplitHostPort(addr)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -612,7 +610,7 @@ func (doh *dnsOverHTTPS) probeH3(
// Run probeQUIC and probeTLS in parallel and see which one is faster. // Run probeQUIC and probeTLS in parallel and see which one is faster.
chQuic := make(chan error, 1) chQuic := make(chan error, 1)
chTLS := make(chan error, 1) chTLS := make(chan error, 1)
go doh.probeQUIC(ctx, addr, tlsC.UConfig(probeTLSCfg), chQuic) go doh.probeQUIC(ctx, addr, probeTLSCfg, chQuic)
go doh.probeTLS(ctx, probeTLSCfg, chTLS) go doh.probeTLS(ctx, probeTLSCfg, chTLS)
select { select {
@@ -637,7 +635,7 @@ func (doh *dnsOverHTTPS) probeH3(
// probeQUIC attempts to establish a QUIC connection to the specified address. // probeQUIC attempts to establish a QUIC connection to the specified address.
// We run probeQUIC and probeTLS in parallel and see which one is faster. // We run probeQUIC and probeTLS in parallel and see which one is faster.
func (doh *dnsOverHTTPS) probeQUIC(ctx context.Context, addr string, tlsConfig *tlsC.Config, ch chan error) { func (doh *dnsOverHTTPS) probeQUIC(ctx context.Context, addr string, tlsConfig *tls.Config, ch chan error) {
startTime := time.Now() startTime := time.Now()
conn, err := doh.dialQuic(ctx, addr, tlsConfig, doh.getQUICConfig()) conn, err := doh.dialQuic(ctx, addr, tlsConfig, doh.getQUICConfig())
if err != nil { if err != nil {
@@ -727,14 +725,10 @@ func (doh *dnsOverHTTPS) tlsDial(ctx context.Context, network string, config *tl
// TLS handshake dialTimeout will be used as connection deadLine. // TLS handshake dialTimeout will be used as connection deadLine.
conn := tls.Client(rawConn, config) conn := tls.Client(rawConn, config)
err = conn.SetDeadline(time.Now().Add(dialTimeout)) ctx, cancel := context.WithTimeout(ctx, dialTimeout)
if err != nil { defer cancel()
// Must not happen in normal circumstances.
log.Errorln("cannot set deadline: %v", err)
return nil, err
}
err = conn.Handshake() err = conn.HandshakeContext(ctx)
if err != nil { if err != nil {
defer conn.Close() defer conn.Close()
return nil, err return nil, err

Some files were not shown because too many files have changed in this diff Show More