Compare commits

...

135 Commits

Author SHA1 Message Date
wwqgtxx
97f25250a6 chore: code cleanup 2026-02-08 00:10:07 +08:00
wwqgtxx
8b0bcb6740 chore: better generator 2026-02-08 00:09:46 +08:00
wwqgtxx
022f677385 chore: cleanup hostValue code 2026-02-07 12:07:01 +08:00
wwqgtxx
32799662ad fix: quic data race for crypto/tls 2026-02-07 09:41:08 +08:00
wwqgtxx
86257fc83c chore: remove reflect-based provider override code 2026-02-05 17:20:47 +08:00
wwqgtxx
f2222b5e02 chore: code cleanup 2026-02-05 16:47:24 +08:00
Chenx Dust
3ac62152cb fix: race condition of tcpConcurrent in dialer (#2556) 2026-02-05 15:57:01 +08:00
wwqgtxx
5516ca18fd chore: code cleanup 2026-02-05 10:46:09 +08:00
wwqgtxx
3bca69c745 chore: add some comments for the fingerprint verifier 2026-02-05 10:34:36 +08:00
wwqgtxx
558b3840ea action: update Go 1.26rc3 to test 2026-02-04 23:57:47 +08:00
wwqgtxx
f94da9f2b3 chore: fingerprint verifier handle non-leaf certificate will check the SNI matches the certificate's DNS name 2026-02-04 22:41:33 +08:00
wwqgtxx
2cfc4ba044 fix: CVE-2025-68121 for crypto/tls again and again 2026-02-04 15:19:27 +08:00
wwqgtxx
034f1d1e9b chore: disallow empty proxy-server-nameserver when proxy-server-nameserver-policy is set 2026-02-04 12:06:45 +08:00
wwqgtxx
dede56fe4b feat: add proxy-server-nameserver-policy to dns section 2026-02-03 01:41:00 +08:00
wwqgtxx
5fda87d50e chore: update sing-tun 2026-02-02 14:39:32 +08:00
wwqgtxx
5e1b133e4e chore: more callback support for utls 2026-02-01 01:36:30 +08:00
wwqgtxx
27a3ca6afc chore: converter support fingerprint for vmess/vless/trojan 2026-01-31 21:27:14 +08:00
wwqgtxx
7573affdd4 chore: better logging in masque outbound 2026-01-31 20:25:02 +08:00
wwqgtxx
17fbaf9100 doc: fix typo 2026-01-30 23:16:08 +08:00
sleshep
710772f993 chore: add simple validation for static dialer-proxy config (#2551)
Currently, it can only validate whether a cycle exists in proxies, and cannot determine if it is caused by groups.
2026-01-30 20:39:06 +08:00
saba-futai
d36b024b10 chore: align sudoku with upstream v0.2.0 (#2549) 2026-01-30 10:33:22 +08:00
wwqgtxx
f52c9356c2 fix: CVE-2025-68121 for crypto/tls again 2026-01-29 08:53:48 +08:00
wwqgtxx
e45c896185 feat: support masque outbound 2026-01-28 19:01:18 +08:00
wwqgtxx
d18a14afeb fix: snat key in trojan packet listener 2026-01-28 00:49:27 +08:00
wwqgtxx
6aaabc97ca chore: decrease unneeded string convert in socks5 addr parsing 2026-01-28 00:41:13 +08:00
wwqgtxx
85c024a4a6 fix: snat key in sudoku packet listener 2026-01-28 00:32:44 +08:00
wwqgtxx
c33c90d7af chore: clean up duplicate code in sudoku 2026-01-26 09:28:18 +08:00
wwqgtxx
65c3d3e4e2 chore: remove unreachable code in sudoku 2026-01-26 09:28:18 +08:00
wwqgtxx
98b3060558 chore: optimize timeout control in DoH TLS probe 2026-01-25 22:49:04 +08:00
wwqgtxx
b90100645e chore: using mihomo's global pool in DoQ 2026-01-25 22:48:48 +08:00
wwqgtxx
46ee1649c0 chore: hy2 listener fellow hysteria2's code skip verify in https masquerade 2026-01-25 21:49:02 +08:00
wwqgtxx
e3b8fc2b77 fix: hy2 listener panic with http/https masquerade 2026-01-25 18:12:51 +08:00
wwqgtxx
707fe8b207 chore: remove auto IDNA conversion in domain rules
The original upstream does not support it, and there are many places in the current code that do not support it either. Removing it will help maintain consistency in behavior across different parts.
2026-01-23 09:36:13 +08:00
wwqgtxx
1e1434d1de chore: remove an unnecessary variable 2026-01-20 14:48:59 +08:00
wwqgtxx
26052ba5e5 chore: remove confused varbin using in sing 2026-01-19 23:16:55 +08:00
wwqgtxx
75a0cd5aff fix: file exists when tun start 2026-01-18 16:37:51 +08:00
wwqgtxx
e03ba23f65 chore: update logrus 2026-01-18 11:00:49 +08:00
wwqgtxx
0c995a2479 chore: move proxiesWithProviders to hub/route internal to disallow external usage of this poorly implemented function 2026-01-17 22:26:14 +08:00
wwqgtxx
3c526ae06e feat: add query-server-name for ech-opts 2026-01-17 19:11:57 +08:00
wwqgtxx
0b009b514c doc: add missing params 2026-01-17 18:52:44 +08:00
wwqgtxx
18d139a15d chore: rollback tls to restore the session resumption functionality in quic-go 2026-01-17 18:27:19 +08:00
wwqgtxx
5f250413fe doc: remove deprecated item 2026-01-17 18:27:19 +08:00
wwqgtxx
993595df73 chore: switch to our own common/orderedmap package, remove two unneeded json dependence 2026-01-17 18:27:19 +08:00
H1JK
828fd30dc3 chore: support connection reuse for DoT 2026-01-16 14:20:20 +08:00
wwqgtxx
11000dccd7 chore: add common/deque package 2026-01-16 11:05:15 +08:00
wwqgtxx
0818aa54aa chore: provider a common entrance for YAML package 2026-01-16 11:05:13 +08:00
wwqgtxx
edbfebeacd fix: CVE-2025-68121 for crypto/tls 2026-01-16 08:27:29 +08:00
saba-futai
06f5fbac06 feat: add path-root for sudoku (#2511) 2026-01-14 21:25:05 +08:00
Shaw
f38fc2020f feat: add grpc-user-agent to grpc-opts (#2512) 2026-01-14 21:02:09 +08:00
wwqgtxx
97bce45eba chore: deprecated global-client-fingerprint, please set client-fingerprint directly on the proxy instead 2026-01-14 10:40:26 +08:00
Davoyan
bc28cd486a doc: fix typo in config.yaml (#2459) 2026-01-14 09:01:18 +08:00
wwqgtxx
cdabd1e8b1 chore: update utls 2026-01-14 08:02:37 +08:00
Toby
c5b0f00bb2 fix: logic issues with BBR impl
98872a4f38
2026-01-12 13:34:59 +08:00
wwqgtxx
c128d23dec chore: update quic-go to 0.59.0 2026-01-12 12:48:18 +08:00
wwqgtxx
ee37a353d0 fix: incorrect timestamp conversion in brutal 2026-01-12 12:45:52 +08:00
wwqgtxx
0cf37de1a8 chore: better time storage in rule wrapper 2026-01-12 00:50:55 +08:00
potoo0
ae6069c178 chore: moving rules disabled and hit/miss counts data to extra for restful api (#2503) 2026-01-11 21:11:38 +08:00
wwqgtxx
c8e33a4347 chore: decrease rule wrapper memory usage 2026-01-11 20:57:28 +08:00
potoo0
19a6b5d6f7 feat: support rule disabling and hit/miss count/at tracking in restful api (#2502) 2026-01-11 19:37:08 +08:00
wwqgtxx
efb800866e chore: update quic-go to 0.58.1 2026-01-11 17:19:53 +08:00
wwqgtxx
94c8d60f72 chore: simplified logic rule parsing 2026-01-08 23:42:01 +08:00
saba-futai
0f2baca2de chore: refactored the implementation of suduko mux (#2486) 2026-01-07 00:25:33 +08:00
wwqgtxx
b18a33552c chore: remove unused pointer in rules implements 2026-01-06 09:29:09 +08:00
wwqgtxx
487de9b548 feat: add PROCESS-NAME-WILDCARD and PROCESS-PATH-WILDCARD 2026-01-06 08:52:06 +08:00
enfein
1a6230ec03 chore: update mieru version (#2484)
Fix https://github.com/enfein/mieru/issues/247
2026-01-06 07:48:46 +08:00
wwqgtxx
e6bf56b9af fix: os.(*Process).Wait not working on Windows7 2026-01-05 20:26:19 +08:00
wwqgtxx
0ad9ac325a feat: support aes-128-gcm, ratelimit and framesize for kcptun 2026-01-05 12:25:30 +08:00
saba-futai
d6b1263236 feat: support http-mask-multiplex for suduko (#2482) 2026-01-04 22:24:42 +08:00
wwqgtxx
4d7670339b feat: all dns client support disable-qtype-<int> params 2026-01-02 22:43:58 +08:00
wwqgtxx
0cffc8d76d chore: revert "chore: update quic-go to 0.58.0"
This reverts commit 64015b7634.
2026-01-02 17:09:40 +08:00
wwqgtxx
1f8bee9710 chore: force to disable mptcp for tproxy 2025-12-31 08:43:23 +08:00
wwqgtxx
eb30d3f331 chore: add a code comment for tproxy listener 2025-12-31 02:07:52 +08:00
wwqgtxx
10f4bebdfa fix: only clear dstIP if it is confirmed to be a fake IP 2025-12-30 17:16:09 +08:00
David
06387d5045 feat: support fake-ip-filter-mode: rule mode (#2469) 2025-12-29 08:14:09 +08:00
wwqgtxx
c393e917eb fix: gvisor compatibility on go1.26 2025-12-27 17:57:30 +08:00
wwqgtxx
4f0a6fa117 fix: gvisor panic 2025-12-27 17:16:35 +08:00
wwqgtxx
4f9bfd216f chore: add some comments for the finalizer 2025-12-27 16:38:58 +08:00
joshua
498f81aad3 feat: add header support for rule provider (#2463) 2025-12-24 23:10:38 +08:00
wwqgtxx
9168bee6b7 chore: align internal logic 2025-12-24 18:26:55 +08:00
HolgerHuo
e6c0e3b19c fix: handle geoip:lan when GetRecodeSize() (#2460) 2025-12-24 08:34:19 +08:00
wwqgtxx
287f9e5185 chore: temporarily skip mieru inbound test in go1.26 on windows 2025-12-23 23:49:19 +08:00
wwqgtxx
c456370f4f fix: missing context cancel in pullLoop 2025-12-23 23:26:05 +08:00
wwqgtxx
10ef29f5cd chore: apply global ca in sudoku code 2025-12-23 23:15:52 +08:00
wwqgtxx
85ba7f6a0a chore: change import paths in sudoku code 2025-12-23 23:14:39 +08:00
saba-futai
7daf37bc15 feat: support http-mask-mode, http-mask-tls and http-mask-host for sudoku (#2456) 2025-12-23 23:08:38 +08:00
wwqgtxx
64015b7634 chore: update quic-go to 0.58.0 2025-12-22 17:29:28 +08:00
wwqgtxx
5585304d68 chore: allow custom path for gRPC (grpc-service-name start with /) 2025-12-21 10:28:05 +08:00
wwqgtxx
911211578c action: add Go 1.26rc1 to test 2025-12-21 00:12:03 +08:00
wwqgtxx
abb55199f2 fix: os.RemoveAll not working on Windows7 2025-12-20 23:02:26 +08:00
wwqgtxx
87c3f700e5 chore: add TODO comment to ca.LoadCertificates 2025-12-19 21:43:55 +08:00
wwqgtxx
4a723e8d3f chore: allow automatic reloading when the TLS server's certificate, private-key or ech-key is a local file 2025-12-19 20:23:48 +08:00
wwqgtxx
93cf46e430 chore: remove unused import path 2025-12-19 20:14:02 +08:00
Howard Wu
35a1130c92 chore: use HasPrefix instead of Contains for key checks (#2447) 2025-12-19 18:43:06 +08:00
Howard Wu
1ebcb25e4a fix: typo in sniffer skip-dst-address config parsing (#2446) 2025-12-19 18:16:56 +08:00
wwqgtxx
cbcacdbb8c chore: using tls.Config.GetCertificate/GetClientCertificate to load TLS certificates 2025-12-19 12:24:16 +08:00
wwqgtxx
17966b5418 fix: close sing-tun maybe panic on windows 2025-12-18 10:37:50 +08:00
wwqgtxx
bc8f0dcf77 fix: missing ntp call 2025-12-17 18:50:33 +08:00
wwqgtxx
827cd616e8 chore: cleanup import path 2025-12-17 17:35:58 +08:00
wwqgtxx
e1384e86ab chore: update http2 using in test 2025-12-17 17:19:09 +08:00
wwqgtxx
b92b38701c chore: update ech handling 2025-12-17 17:19:06 +08:00
wwqgtxx
1cab34d257 chore: update quic-go to 0.57.1 2025-12-17 16:13:12 +08:00
saba-futai
a06097c2c4 chore: add xvp rotation andd new header generation strategy for sudoku (#2437) 2025-12-16 18:39:39 +08:00
wwqgtxx
bc9db11cb4 chore: hub/route module handle websocket itself 2025-12-14 19:56:30 +08:00
Ealrang
69e301820c action: fix architecture check for riscv64 in script (#2435) 2025-12-14 17:30:40 +08:00
wwqgtxx
e7a04e0762 chore: don't process msg.Extra in msgToHTTPSRRInfo 2025-12-12 19:42:29 +08:00
Eric Moore
7e8c2876fb chore: improve HTTPS RR logging (#2431) 2025-12-12 17:43:54 +08:00
wwqgtxx
936ebc7718 chore: add echparser package for parse ECHConfigList and ECHConfig 2025-12-12 16:05:11 +08:00
Eric Moore
b753a57e6a fix: ech not work with websocket+clientFingerprint 2025-12-11 23:15:40 +08:00
wwqgtxx
dd99bfc892 doc: fix custom-table doc 2025-12-11 14:53:02 +08:00
wwqgtxx
2a1b3b2aed chore: allow sudoku inbound handle sing-mux request 2025-12-11 14:14:21 +08:00
saba-futai
2211789a7c chore: add customized byte style for sudoku (#2427) 2025-12-10 17:47:59 +08:00
wwqgtxx
e652e277a7 fix: missing ProxyInfo information in wireguard outbound 2025-12-10 17:06:13 +08:00
wwqgtxx
40863d248d chore: add lock in baseProvider for thread-safe 2025-12-10 08:42:40 +08:00
wwqgtxx
17b8eb8772 chore: skip icmp forwarding when destination in tun interface addr range 2025-12-08 09:56:15 +08:00
Vincent Loeng/Leong
6b40072bc5 chore: support find process on freebsd 14 and 15 (#2422) 2025-12-06 14:03:59 +08:00
wwqgtxx
f44aa22d50 chore: add sudoku ed25519key test 2025-12-05 09:06:06 +08:00
wwqgtxx
c33d9ad857 chore: cleanup sudoku internal code 2025-12-05 08:53:18 +08:00
saba-futai
25041b599e chore: sudoku support enable-pure-downlink mode to increase download bandwidth (#2419) 2025-12-05 07:52:49 +08:00
wwqgtxx
6539b509cb chore: restful api contains providerChains for connections 2025-12-04 17:29:01 +08:00
Xi Xu
d2007fdc22 chore: improves thread safety in adapter 2025-12-04 16:02:22 +08:00
wwqgtxx
b5fa3ee99a chore: restful api contains provider-name for proxies 2025-12-04 15:10:13 +08:00
wwqgtxx
91f5593f4e fix: structure ignore tag not working in nest struct 2025-12-04 14:44:34 +08:00
wwqgtxx
90470ac304 chore: cleanup import path for common/net 2025-12-04 13:44:46 +08:00
wwqgtxx
b509affe5b chore: simplify DNSPrefer serialization process 2025-12-04 13:41:44 +08:00
wwqgtxx
32ce513977 chore: discard domain addr input in sudoku uot 2025-12-03 22:54:26 +08:00
wwqgtxx
30891f8781 chore: sharing sudoku internal code 2025-12-03 22:23:37 +08:00
saba-futai
e4cdb9b600 feat: add uot for sudoku (#2415) 2025-12-03 22:11:56 +08:00
wwqgtxx
d33dbbe2f9 fix: QUIC events with session tickets disabled will panic on Go 1.26 2025-12-03 15:40:23 +08:00
wwqgtxx
d8dcaa7500 chore: add upTotal and downTotal data to /traffic restful api 2025-12-03 11:31:13 +08:00
wwqgtxx
9df8392c65 chore: clean up internal interface definitions 2025-12-03 11:08:16 +08:00
wwqgtxx
fdb7cb1f58 chore: allow setting DialerForAPI in adapter.ParseProxy for library user 2025-12-03 00:05:27 +08:00
wwqgtxx
7cd58fbdf6 chore: add DialerForAPI to outbound option for library user 2025-12-02 23:33:07 +08:00
wwqgtxx
bc719eb96d chore: simplify tuic client 2025-12-02 21:07:51 +08:00
wwqgtxx
ac90543548 chore: code cleanup 2025-12-02 17:18:20 +08:00
futai
9a5e506f66 chore: simplify server config and add keygen for sudoku (#2407) 2025-12-01 19:26:41 +08:00
246 changed files with 15433 additions and 2821 deletions

View File

@@ -1,4 +1,6 @@
Subject: [PATCH] Revert "runtime: always use LoadLibraryEx to load system libraries"
Subject: [PATCH] Revert "os: remove 5ms sleep on Windows in (*Process).Wait"
Fix os.RemoveAll not working on Windows7
Revert "runtime: always use LoadLibraryEx to load system libraries"
Revert "syscall: remove Windows 7 console handle workaround"
Revert "net: remove sysSocket fallback for Windows 7"
Revert "crypto/rand,runtime: switch RtlGenRandom for ProcessPrng"
@@ -655,3 +657,228 @@ diff --git a/src/syscall/dll_windows.go b/src/syscall/dll_windows.go
} else {
h, e = loadlibrary(namep)
}
Index: src/os/removeall_at.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/removeall_at.go b/src/os/removeall_at.go
--- a/src/os/removeall_at.go (revision f56f1e23507e646c85243a71bde7b9629b2f970c)
+++ b/src/os/removeall_at.go (revision 0a52622d2331ff975fb0442617ec19bc352bb2ed)
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build unix || wasip1 || windows
+//go:build unix || wasip1
package os
@@ -175,3 +175,25 @@
}
return newDirFile(fd, name)
}
+
+func rootRemoveAll(r *Root, name string) error {
+ // Consistency with os.RemoveAll: Strip trailing /s from the name,
+ // so RemoveAll("not_a_directory/") succeeds.
+ for len(name) > 0 && IsPathSeparator(name[len(name)-1]) {
+ name = name[:len(name)-1]
+ }
+ if endsWithDot(name) {
+ // Consistency with os.RemoveAll: Return EINVAL when trying to remove .
+ return &PathError{Op: "RemoveAll", Path: name, Err: syscall.EINVAL}
+ }
+ _, err := doInRoot(r, name, nil, func(parent sysfdType, name string) (struct{}, error) {
+ return struct{}{}, removeAllFrom(parent, name)
+ })
+ if IsNotExist(err) {
+ return nil
+ }
+ if err != nil {
+ return &PathError{Op: "RemoveAll", Path: name, Err: underlyingError(err)}
+ }
+ return err
+}
Index: src/os/removeall_noat.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/removeall_noat.go b/src/os/removeall_noat.go
--- a/src/os/removeall_noat.go (revision f56f1e23507e646c85243a71bde7b9629b2f970c)
+++ b/src/os/removeall_noat.go (revision 0a52622d2331ff975fb0442617ec19bc352bb2ed)
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build (js && wasm) || plan9
+//go:build (js && wasm) || plan9 || windows
package os
@@ -140,3 +140,22 @@
}
return err
}
+
+func rootRemoveAll(r *Root, name string) error {
+ if endsWithDot(name) {
+ // Consistency with os.RemoveAll: Return EINVAL when trying to remove .
+ return &PathError{Op: "RemoveAll", Path: name, Err: syscall.EINVAL}
+ }
+ if err := checkPathEscapesLstat(r, name); err != nil {
+ if err == syscall.ENOTDIR {
+ // Some intermediate path component is not a directory.
+ // RemoveAll treats this as success (since the target doesn't exist).
+ return nil
+ }
+ return &PathError{Op: "RemoveAll", Path: name, Err: err}
+ }
+ if err := RemoveAll(joinPath(r.root.name, name)); err != nil {
+ return &PathError{Op: "RemoveAll", Path: name, Err: underlyingError(err)}
+ }
+ return nil
+}
Index: src/os/root_noopenat.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/root_noopenat.go b/src/os/root_noopenat.go
--- a/src/os/root_noopenat.go (revision f56f1e23507e646c85243a71bde7b9629b2f970c)
+++ b/src/os/root_noopenat.go (revision 0a52622d2331ff975fb0442617ec19bc352bb2ed)
@@ -11,7 +11,6 @@
"internal/filepathlite"
"internal/stringslite"
"sync/atomic"
- "syscall"
"time"
)
@@ -185,25 +184,6 @@
}
return nil
}
-
-func rootRemoveAll(r *Root, name string) error {
- if endsWithDot(name) {
- // Consistency with os.RemoveAll: Return EINVAL when trying to remove .
- return &PathError{Op: "RemoveAll", Path: name, Err: syscall.EINVAL}
- }
- if err := checkPathEscapesLstat(r, name); err != nil {
- if err == syscall.ENOTDIR {
- // Some intermediate path component is not a directory.
- // RemoveAll treats this as success (since the target doesn't exist).
- return nil
- }
- return &PathError{Op: "RemoveAll", Path: name, Err: err}
- }
- if err := RemoveAll(joinPath(r.root.name, name)); err != nil {
- return &PathError{Op: "RemoveAll", Path: name, Err: underlyingError(err)}
- }
- return nil
-}
func rootReadlink(r *Root, name string) (string, error) {
if err := checkPathEscapesLstat(r, name); err != nil {
Index: src/os/root_openat.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/root_openat.go b/src/os/root_openat.go
--- a/src/os/root_openat.go (revision f56f1e23507e646c85243a71bde7b9629b2f970c)
+++ b/src/os/root_openat.go (revision 0a52622d2331ff975fb0442617ec19bc352bb2ed)
@@ -194,28 +194,6 @@
return nil
}
-func rootRemoveAll(r *Root, name string) error {
- // Consistency with os.RemoveAll: Strip trailing /s from the name,
- // so RemoveAll("not_a_directory/") succeeds.
- for len(name) > 0 && IsPathSeparator(name[len(name)-1]) {
- name = name[:len(name)-1]
- }
- if endsWithDot(name) {
- // Consistency with os.RemoveAll: Return EINVAL when trying to remove .
- return &PathError{Op: "RemoveAll", Path: name, Err: syscall.EINVAL}
- }
- _, err := doInRoot(r, name, nil, func(parent sysfdType, name string) (struct{}, error) {
- return struct{}{}, removeAllFrom(parent, name)
- })
- if IsNotExist(err) {
- return nil
- }
- if err != nil {
- return &PathError{Op: "RemoveAll", Path: name, Err: underlyingError(err)}
- }
- return err
-}
-
func rootRename(r *Root, oldname, newname string) error {
_, err := doInRoot(r, oldname, nil, func(oldparent sysfdType, oldname string) (struct{}, error) {
_, err := doInRoot(r, newname, nil, func(newparent sysfdType, newname string) (struct{}, error) {
Index: src/os/root_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/root_windows.go b/src/os/root_windows.go
--- a/src/os/root_windows.go (revision f56f1e23507e646c85243a71bde7b9629b2f970c)
+++ b/src/os/root_windows.go (revision 0a52622d2331ff975fb0442617ec19bc352bb2ed)
@@ -402,3 +402,14 @@
}
return fi.Mode(), nil
}
+
+func checkPathEscapes(r *Root, name string) error {
+ if !filepathlite.IsLocal(name) {
+ return errPathEscapes
+ }
+ return nil
+}
+
+func checkPathEscapesLstat(r *Root, name string) error {
+ return checkPathEscapes(r, name)
+}
Index: src/os/exec_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/exec_windows.go b/src/os/exec_windows.go
--- a/src/os/exec_windows.go (revision 0a52622d2331ff975fb0442617ec19bc352bb2ed)
+++ b/src/os/exec_windows.go (revision fb3d09a67fe97008ad76fea97ae88170072cbdbb)
@@ -10,6 +10,7 @@
"runtime"
"syscall"
"time"
+ _ "unsafe"
)
// Note that Process.handle is never nil because Windows always requires
@@ -49,9 +50,23 @@
// than statusDone.
p.doRelease(statusReleased)
+ var maj, min, build uint32
+ rtlGetNtVersionNumbers(&maj, &min, &build)
+ if maj < 10 {
+ // NOTE(brainman): It seems that sometimes process is not dead
+ // when WaitForSingleObject returns. But we do not know any
+ // other way to wait for it. Sleeping for a while seems to do
+ // the trick sometimes.
+ // See https://golang.org/issue/25965 for details.
+ time.Sleep(5 * time.Millisecond)
+ }
+
return &ProcessState{p.Pid, syscall.WaitStatus{ExitCode: ec}, &u}, nil
}
+//go:linkname rtlGetNtVersionNumbers syscall.rtlGetNtVersionNumbers
+func rtlGetNtVersionNumbers(majorVersion *uint32, minorVersion *uint32, buildNumber *uint32)
+
func (p *Process) signal(sig Signal) error {
handle, status := p.handleTransientAcquire()
switch status {

883
.github/patch/go1.26.patch vendored Normal file
View File

@@ -0,0 +1,883 @@
Subject: [PATCH] Revert "os: remove 5ms sleep on Windows in (*Process).Wait"
Fix os.RemoveAll not working on Windows7
Revert "runtime: always use LoadLibraryEx to load system libraries"
Revert "syscall: remove Windows 7 console handle workaround"
Revert "net: remove sysSocket fallback for Windows 7"
Revert "crypto/rand,runtime: switch RtlGenRandom for ProcessPrng"
---
Index: src/crypto/internal/sysrand/rand_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/crypto/internal/sysrand/rand_windows.go b/src/crypto/internal/sysrand/rand_windows.go
--- a/src/crypto/internal/sysrand/rand_windows.go (revision c599a8f2385849a225d02843b3c6389dbfc5aa69)
+++ b/src/crypto/internal/sysrand/rand_windows.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
@@ -7,5 +7,26 @@
import "internal/syscall/windows"
func read(b []byte) error {
- return windows.ProcessPrng(b)
+ // RtlGenRandom only returns 1<<32-1 bytes at a time. We only read at
+ // most 1<<31-1 bytes at a time so that this works the same on 32-bit
+ // and 64-bit systems.
+ return batched(windows.RtlGenRandom, 1<<31-1)(b)
+}
+
+// batched returns a function that calls f to populate a []byte by chunking it
+// into subslices of, at most, readMax bytes.
+func batched(f func([]byte) error, readMax int) func([]byte) error {
+ return func(out []byte) error {
+ for len(out) > 0 {
+ read := len(out)
+ if read > readMax {
+ read = readMax
+ }
+ if err := f(out[:read]); err != nil {
+ return err
+ }
+ out = out[read:]
+ }
+ return nil
+ }
}
Index: src/crypto/rand/rand.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/crypto/rand/rand.go b/src/crypto/rand/rand.go
--- a/src/crypto/rand/rand.go (revision c599a8f2385849a225d02843b3c6389dbfc5aa69)
+++ b/src/crypto/rand/rand.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
@@ -25,7 +25,7 @@
// - On legacy Linux (< 3.17), Reader opens /dev/urandom on first use.
// - On macOS, iOS, and OpenBSD Reader, uses arc4random_buf(3).
// - On NetBSD, Reader uses the kern.arandom sysctl.
-// - On Windows, Reader uses the ProcessPrng API.
+// - On Windows systems, Reader uses the RtlGenRandom API.
// - On js/wasm, Reader uses the Web Crypto API.
// - On wasip1/wasm, Reader uses random_get.
//
Index: src/internal/syscall/windows/syscall_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/internal/syscall/windows/syscall_windows.go b/src/internal/syscall/windows/syscall_windows.go
--- a/src/internal/syscall/windows/syscall_windows.go (revision c599a8f2385849a225d02843b3c6389dbfc5aa69)
+++ b/src/internal/syscall/windows/syscall_windows.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
@@ -421,7 +421,7 @@
//sys DestroyEnvironmentBlock(block *uint16) (err error) = userenv.DestroyEnvironmentBlock
//sys CreateEvent(eventAttrs *SecurityAttributes, manualReset uint32, initialState uint32, name *uint16) (handle syscall.Handle, err error) = kernel32.CreateEventW
-//sys ProcessPrng(buf []byte) (err error) = bcryptprimitives.ProcessPrng
+//sys RtlGenRandom(buf []byte) (err error) = advapi32.SystemFunction036
type FILE_ID_BOTH_DIR_INFO struct {
NextEntryOffset uint32
Index: src/internal/syscall/windows/zsyscall_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/internal/syscall/windows/zsyscall_windows.go b/src/internal/syscall/windows/zsyscall_windows.go
--- a/src/internal/syscall/windows/zsyscall_windows.go (revision c599a8f2385849a225d02843b3c6389dbfc5aa69)
+++ b/src/internal/syscall/windows/zsyscall_windows.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
@@ -38,7 +38,6 @@
var (
modadvapi32 = syscall.NewLazyDLL(sysdll.Add("advapi32.dll"))
- modbcryptprimitives = syscall.NewLazyDLL(sysdll.Add("bcryptprimitives.dll"))
modiphlpapi = syscall.NewLazyDLL(sysdll.Add("iphlpapi.dll"))
modkernel32 = syscall.NewLazyDLL(sysdll.Add("kernel32.dll"))
modnetapi32 = syscall.NewLazyDLL(sysdll.Add("netapi32.dll"))
@@ -63,7 +62,7 @@
procQueryServiceStatus = modadvapi32.NewProc("QueryServiceStatus")
procRevertToSelf = modadvapi32.NewProc("RevertToSelf")
procSetTokenInformation = modadvapi32.NewProc("SetTokenInformation")
- procProcessPrng = modbcryptprimitives.NewProc("ProcessPrng")
+ procSystemFunction036 = modadvapi32.NewProc("SystemFunction036")
procGetAdaptersAddresses = modiphlpapi.NewProc("GetAdaptersAddresses")
procCreateEventW = modkernel32.NewProc("CreateEventW")
procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
@@ -244,12 +243,12 @@
return
}
-func ProcessPrng(buf []byte) (err error) {
+func RtlGenRandom(buf []byte) (err error) {
var _p0 *byte
if len(buf) > 0 {
_p0 = &buf[0]
}
- r1, _, e1 := syscall.SyscallN(procProcessPrng.Addr(), uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)))
+ r1, _, e1 := syscall.SyscallN(procSystemFunction036.Addr(), uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), 0)
if r1 == 0 {
err = errnoErr(e1)
}
Index: src/runtime/os_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/runtime/os_windows.go b/src/runtime/os_windows.go
--- a/src/runtime/os_windows.go (revision c599a8f2385849a225d02843b3c6389dbfc5aa69)
+++ b/src/runtime/os_windows.go (revision ea2726a6fa25fbfa1092e696e522eafca544d24c)
@@ -40,7 +40,8 @@
//go:cgo_import_dynamic runtime._GetSystemInfo GetSystemInfo%1 "kernel32.dll"
//go:cgo_import_dynamic runtime._GetThreadContext GetThreadContext%2 "kernel32.dll"
//go:cgo_import_dynamic runtime._SetThreadContext SetThreadContext%2 "kernel32.dll"
-//go:cgo_import_dynamic runtime._LoadLibraryExW LoadLibraryExW%3 "kernel32.dll"
+//go:cgo_import_dynamic runtime._LoadLibraryW LoadLibraryW%1 "kernel32.dll"
+//go:cgo_import_dynamic runtime._LoadLibraryA LoadLibraryA%1 "kernel32.dll"
//go:cgo_import_dynamic runtime._PostQueuedCompletionStatus PostQueuedCompletionStatus%4 "kernel32.dll"
//go:cgo_import_dynamic runtime._QueryPerformanceCounter QueryPerformanceCounter%1 "kernel32.dll"
//go:cgo_import_dynamic runtime._QueryPerformanceFrequency QueryPerformanceFrequency%1 "kernel32.dll"
@@ -74,7 +75,6 @@
// Following syscalls are available on every Windows PC.
// All these variables are set by the Windows executable
// loader before the Go program starts.
- _AddVectoredContinueHandler,
_AddVectoredExceptionHandler,
_CloseHandle,
_CreateEventA,
@@ -97,7 +97,8 @@
_GetSystemInfo,
_GetThreadContext,
_SetThreadContext,
- _LoadLibraryExW,
+ _LoadLibraryW,
+ _LoadLibraryA,
_PostQueuedCompletionStatus,
_QueryPerformanceCounter,
_QueryPerformanceFrequency,
@@ -126,8 +127,23 @@
_WriteFile,
_ stdFunction
- // Use ProcessPrng to generate cryptographically random data.
- _ProcessPrng stdFunction
+ // Following syscalls are only available on some Windows PCs.
+ // We will load syscalls, if available, before using them.
+ _AddDllDirectory,
+ _AddVectoredContinueHandler,
+ _LoadLibraryExA,
+ _LoadLibraryExW,
+ _ stdFunction
+
+ // Use RtlGenRandom to generate cryptographically random data.
+ // This approach has been recommended by Microsoft (see issue
+ // 15589 for details).
+ // The RtlGenRandom is not listed in advapi32.dll, instead
+ // RtlGenRandom function can be found by searching for SystemFunction036.
+ // Also some versions of Mingw cannot link to SystemFunction036
+ // when building executable as Cgo. So load SystemFunction036
+ // manually during runtime startup.
+ _RtlGenRandom stdFunction
// Load ntdll.dll manually during startup, otherwise Mingw
// links wrong printf function to cgo executable (see issue
@@ -144,13 +160,6 @@
_ stdFunction
)
-var (
- bcryptprimitivesdll = [...]uint16{'b', 'c', 'r', 'y', 'p', 't', 'p', 'r', 'i', 'm', 'i', 't', 'i', 'v', 'e', 's', '.', 'd', 'l', 'l', 0}
- ntdlldll = [...]uint16{'n', 't', 'd', 'l', 'l', '.', 'd', 'l', 'l', 0}
- powrprofdll = [...]uint16{'p', 'o', 'w', 'r', 'p', 'r', 'o', 'f', '.', 'd', 'l', 'l', 0}
- winmmdll = [...]uint16{'w', 'i', 'n', 'm', 'm', '.', 'd', 'l', 'l', 0}
-)
-
// Function to be called by windows CreateThread
// to start new os thread.
func tstart_stdcall(newm *m)
@@ -242,9 +251,40 @@
return unsafe.String(&sysDirectory[0], sysDirectoryLen)
}
-func windowsLoadSystemLib(name []uint16) uintptr {
- const _LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
- return stdcall(_LoadLibraryExW, uintptr(unsafe.Pointer(&name[0])), 0, _LOAD_LIBRARY_SEARCH_SYSTEM32)
+//go:linkname syscall_getSystemDirectory syscall.getSystemDirectory
+func syscall_getSystemDirectory() string {
+ return unsafe.String(&sysDirectory[0], sysDirectoryLen)
+}
+
+func windowsLoadSystemLib(name []byte) uintptr {
+ if useLoadLibraryEx {
+ return stdcall(_LoadLibraryExA, uintptr(unsafe.Pointer(&name[0])), 0, _LOAD_LIBRARY_SEARCH_SYSTEM32)
+ } else {
+ absName := append(sysDirectory[:sysDirectoryLen], name...)
+ return stdcall(_LoadLibraryA, uintptr(unsafe.Pointer(&absName[0])))
+ }
+}
+
+const _LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
+
+// When available, this function will use LoadLibraryEx with the filename
+// parameter and the important SEARCH_SYSTEM32 argument. But on systems that
+// do not have that option, absoluteFilepath should contain a fallback
+// to the full path inside of system32 for use with vanilla LoadLibrary.
+//
+//go:linkname syscall_loadsystemlibrary syscall.loadsystemlibrary
+func syscall_loadsystemlibrary(filename *uint16, absoluteFilepath *uint16) (handle, err uintptr) {
+ if useLoadLibraryEx {
+ handle, _, err = syscall_syscalln(uintptr(unsafe.Pointer(_LoadLibraryExW)), 3, uintptr(unsafe.Pointer(filename)), 0, _LOAD_LIBRARY_SEARCH_SYSTEM32)
+ } else {
+ handle, _, err = syscall_syscalln(uintptr(unsafe.Pointer(_LoadLibraryW)), 1, uintptr(unsafe.Pointer(absoluteFilepath)))
+ }
+ KeepAlive(filename)
+ KeepAlive(absoluteFilepath)
+ if handle != 0 {
+ err = 0
+ }
+ return
}
//go:linkname windows_QueryPerformanceCounter internal/syscall/windows.QueryPerformanceCounter
@@ -262,13 +302,28 @@
}
func loadOptionalSyscalls() {
- bcryptPrimitives := windowsLoadSystemLib(bcryptprimitivesdll[:])
- if bcryptPrimitives == 0 {
- throw("bcryptprimitives.dll not found")
+ var kernel32dll = []byte("kernel32.dll\000")
+ k32 := stdcall(_LoadLibraryA, uintptr(unsafe.Pointer(&kernel32dll[0])))
+ if k32 == 0 {
+ throw("kernel32.dll not found")
}
- _ProcessPrng = windowsFindfunc(bcryptPrimitives, []byte("ProcessPrng\000"))
+ _AddDllDirectory = windowsFindfunc(k32, []byte("AddDllDirectory\000"))
+ _AddVectoredContinueHandler = windowsFindfunc(k32, []byte("AddVectoredContinueHandler\000"))
+ _LoadLibraryExA = windowsFindfunc(k32, []byte("LoadLibraryExA\000"))
+ _LoadLibraryExW = windowsFindfunc(k32, []byte("LoadLibraryExW\000"))
+ useLoadLibraryEx = (_LoadLibraryExW != nil && _LoadLibraryExA != nil && _AddDllDirectory != nil)
+
+ initSysDirectory()
- n32 := windowsLoadSystemLib(ntdlldll[:])
+ var advapi32dll = []byte("advapi32.dll\000")
+ a32 := windowsLoadSystemLib(advapi32dll)
+ if a32 == 0 {
+ throw("advapi32.dll not found")
+ }
+ _RtlGenRandom = windowsFindfunc(a32, []byte("SystemFunction036\000"))
+
+ var ntdll = []byte("ntdll.dll\000")
+ n32 := windowsLoadSystemLib(ntdll)
if n32 == 0 {
throw("ntdll.dll not found")
}
@@ -297,7 +352,7 @@
context uintptr
}
- powrprof := windowsLoadSystemLib(powrprofdll[:])
+ powrprof := windowsLoadSystemLib([]byte("powrprof.dll\000"))
if powrprof == 0 {
return // Running on Windows 7, where we don't need it anyway.
}
@@ -351,6 +406,22 @@
// in sys_windows_386.s and sys_windows_amd64.s:
func getlasterror() uint32
+// When loading DLLs, we prefer to use LoadLibraryEx with
+// LOAD_LIBRARY_SEARCH_* flags, if available. LoadLibraryEx is not
+// available on old Windows, though, and the LOAD_LIBRARY_SEARCH_*
+// flags are not available on some versions of Windows without a
+// security patch.
+//
+// https://msdn.microsoft.com/en-us/library/ms684179(v=vs.85).aspx says:
+// "Windows 7, Windows Server 2008 R2, Windows Vista, and Windows
+// Server 2008: The LOAD_LIBRARY_SEARCH_* flags are available on
+// systems that have KB2533623 installed. To determine whether the
+// flags are available, use GetProcAddress to get the address of the
+// AddDllDirectory, RemoveDllDirectory, or SetDefaultDllDirectories
+// function. If GetProcAddress succeeds, the LOAD_LIBRARY_SEARCH_*
+// flags can be used with LoadLibraryEx."
+var useLoadLibraryEx bool
+
var timeBeginPeriodRetValue uint32
// osRelaxMinNS indicates that sysmon shouldn't osRelax if the next
@@ -417,7 +488,8 @@
// Only load winmm.dll if we need it.
// This avoids a dependency on winmm.dll for Go programs
// that run on new Windows versions.
- m32 := windowsLoadSystemLib(winmmdll[:])
+ var winmmdll = []byte("winmm.dll\000")
+ m32 := windowsLoadSystemLib(winmmdll)
if m32 == 0 {
print("runtime: LoadLibraryExW failed; errno=", getlasterror(), "\n")
throw("winmm.dll not found")
@@ -458,6 +530,28 @@
canUseLongPaths = true
}
+var osVersionInfo struct {
+ majorVersion uint32
+ minorVersion uint32
+ buildNumber uint32
+}
+
+func initOsVersionInfo() {
+ info := windows.OSVERSIONINFOW{}
+ info.OSVersionInfoSize = uint32(unsafe.Sizeof(info))
+ stdcall(_RtlGetVersion, uintptr(unsafe.Pointer(&info)))
+ osVersionInfo.majorVersion = info.MajorVersion
+ osVersionInfo.minorVersion = info.MinorVersion
+ osVersionInfo.buildNumber = info.BuildNumber
+}
+
+//go:linkname rtlGetNtVersionNumbers syscall.rtlGetNtVersionNumbers
+func rtlGetNtVersionNumbers(majorVersion *uint32, minorVersion *uint32, buildNumber *uint32) {
+ *majorVersion = osVersionInfo.majorVersion
+ *minorVersion = osVersionInfo.minorVersion
+ *buildNumber = osVersionInfo.buildNumber
+}
+
func osinit() {
asmstdcallAddr = unsafe.Pointer(windows.AsmStdCallAddr())
@@ -470,8 +564,8 @@
initHighResTimer()
timeBeginPeriodRetValue = osRelax(false)
- initSysDirectory()
initLongPathSupport()
+ initOsVersionInfo()
numCPUStartup = getCPUCount()
@@ -487,7 +581,7 @@
//go:nosplit
func readRandom(r []byte) int {
n := 0
- if stdcall(_ProcessPrng, uintptr(unsafe.Pointer(&r[0])), uintptr(len(r)))&0xff != 0 {
+ if stdcall(_RtlGenRandom, uintptr(unsafe.Pointer(&r[0])), uintptr(len(r)))&0xff != 0 {
n = len(r)
}
return n
Index: src/net/hook_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/hook_windows.go b/src/net/hook_windows.go
--- a/src/net/hook_windows.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
+++ b/src/net/hook_windows.go (revision 44e76f7cf1bc6e04b5da724e0b2e48f393713506)
@@ -13,6 +13,7 @@
hostsFilePath = windows.GetSystemDirectory() + "/Drivers/etc/hosts"
// Placeholders for socket system calls.
+ socketFunc func(int, int, int) (syscall.Handle, error) = syscall.Socket
wsaSocketFunc func(int32, int32, int32, *syscall.WSAProtocolInfo, uint32, uint32) (syscall.Handle, error) = windows.WSASocket
connectFunc func(syscall.Handle, syscall.Sockaddr) error = syscall.Connect
listenFunc func(syscall.Handle, int) error = syscall.Listen
Index: src/net/internal/socktest/main_test.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/internal/socktest/main_test.go b/src/net/internal/socktest/main_test.go
--- a/src/net/internal/socktest/main_test.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
+++ b/src/net/internal/socktest/main_test.go (revision 44e76f7cf1bc6e04b5da724e0b2e48f393713506)
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !plan9 && !wasip1 && !windows
+//go:build !js && !plan9 && !wasip1
package socktest_test
Index: src/net/internal/socktest/main_windows_test.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/internal/socktest/main_windows_test.go b/src/net/internal/socktest/main_windows_test.go
new file mode 100644
--- /dev/null (revision 44e76f7cf1bc6e04b5da724e0b2e48f393713506)
+++ b/src/net/internal/socktest/main_windows_test.go (revision 44e76f7cf1bc6e04b5da724e0b2e48f393713506)
@@ -0,0 +1,22 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package socktest_test
+
+import "syscall"
+
+var (
+ socketFunc func(int, int, int) (syscall.Handle, error)
+ closeFunc func(syscall.Handle) error
+)
+
+func installTestHooks() {
+ socketFunc = sw.Socket
+ closeFunc = sw.Closesocket
+}
+
+func uninstallTestHooks() {
+ socketFunc = syscall.Socket
+ closeFunc = syscall.Closesocket
+}
Index: src/net/internal/socktest/sys_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/internal/socktest/sys_windows.go b/src/net/internal/socktest/sys_windows.go
--- a/src/net/internal/socktest/sys_windows.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
+++ b/src/net/internal/socktest/sys_windows.go (revision 44e76f7cf1bc6e04b5da724e0b2e48f393713506)
@@ -9,6 +9,38 @@
"syscall"
)
+// Socket wraps [syscall.Socket].
+func (sw *Switch) Socket(family, sotype, proto int) (s syscall.Handle, err error) {
+ sw.once.Do(sw.init)
+
+ so := &Status{Cookie: cookie(family, sotype, proto)}
+ sw.fmu.RLock()
+ f, _ := sw.fltab[FilterSocket]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return syscall.InvalidHandle, err
+ }
+ s, so.Err = syscall.Socket(family, sotype, proto)
+ if err = af.apply(so); err != nil {
+ if so.Err == nil {
+ syscall.Closesocket(s)
+ }
+ return syscall.InvalidHandle, err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).OpenFailed++
+ return syscall.InvalidHandle, so.Err
+ }
+ nso := sw.addLocked(s, family, sotype, proto)
+ sw.stats.getLocked(nso.Cookie).Opened++
+ return s, nil
+}
+
// WSASocket wraps [syscall.WSASocket].
func (sw *Switch) WSASocket(family, sotype, proto int32, protinfo *syscall.WSAProtocolInfo, group uint32, flags uint32) (s syscall.Handle, err error) {
sw.once.Do(sw.init)
Index: src/net/main_windows_test.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/main_windows_test.go b/src/net/main_windows_test.go
--- a/src/net/main_windows_test.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
+++ b/src/net/main_windows_test.go (revision 44e76f7cf1bc6e04b5da724e0b2e48f393713506)
@@ -12,6 +12,7 @@
var (
// Placeholders for saving original socket system calls.
+ origSocket = socketFunc
origWSASocket = wsaSocketFunc
origClosesocket = poll.CloseFunc
origConnect = connectFunc
@@ -21,6 +22,7 @@
)
func installTestHooks() {
+ socketFunc = sw.Socket
wsaSocketFunc = sw.WSASocket
poll.CloseFunc = sw.Closesocket
connectFunc = sw.Connect
@@ -30,6 +32,7 @@
}
func uninstallTestHooks() {
+ socketFunc = origSocket
wsaSocketFunc = origWSASocket
poll.CloseFunc = origClosesocket
connectFunc = origConnect
Index: src/net/sock_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/sock_windows.go b/src/net/sock_windows.go
--- a/src/net/sock_windows.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
+++ b/src/net/sock_windows.go (revision 44e76f7cf1bc6e04b5da724e0b2e48f393713506)
@@ -20,6 +20,21 @@
func sysSocket(family, sotype, proto int) (syscall.Handle, error) {
s, err := wsaSocketFunc(int32(family), int32(sotype), int32(proto),
nil, 0, windows.WSA_FLAG_OVERLAPPED|windows.WSA_FLAG_NO_HANDLE_INHERIT)
+ if err == nil {
+ return s, nil
+ }
+ // WSA_FLAG_NO_HANDLE_INHERIT flag is not supported on some
+ // old versions of Windows, see
+ // https://msdn.microsoft.com/en-us/library/windows/desktop/ms742212(v=vs.85).aspx
+ // for details. Just use syscall.Socket, if windows.WSASocket failed.
+
+ // See ../syscall/exec_unix.go for description of ForkLock.
+ syscall.ForkLock.RLock()
+ s, err = socketFunc(family, sotype, proto)
+ if err == nil {
+ syscall.CloseOnExec(s)
+ }
+ syscall.ForkLock.RUnlock()
if err != nil {
return syscall.InvalidHandle, os.NewSyscallError("socket", err)
}
Index: src/syscall/exec_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/syscall/exec_windows.go b/src/syscall/exec_windows.go
--- a/src/syscall/exec_windows.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
+++ b/src/syscall/exec_windows.go (revision b4aece36e51ecce81c3ee9fe03e31db552e90018)
@@ -15,7 +15,6 @@
"unsafe"
)
-// ForkLock is not used on Windows.
var ForkLock sync.RWMutex
// EscapeArg rewrites command line argument s as prescribed
@@ -304,6 +303,9 @@
var zeroProcAttr ProcAttr
var zeroSysProcAttr SysProcAttr
+//go:linkname rtlGetNtVersionNumbers
+func rtlGetNtVersionNumbers(majorVersion *uint32, minorVersion *uint32, buildNumber *uint32)
+
func StartProcess(argv0 string, argv []string, attr *ProcAttr) (pid int, handle uintptr, err error) {
if len(argv0) == 0 {
return 0, 0, EWINDOWS
@@ -367,6 +369,17 @@
}
}
+ var maj, min, build uint32
+ rtlGetNtVersionNumbers(&maj, &min, &build)
+ isWin7 := maj < 6 || (maj == 6 && min <= 1)
+ // NT kernel handles are divisible by 4, with the bottom 3 bits left as
+ // a tag. The fully set tag correlates with the types of handles we're
+ // concerned about here. Except, the kernel will interpret some
+ // special handle values, like -1, -2, and so forth, so kernelbase.dll
+ // checks to see that those bottom three bits are checked, but that top
+ // bit is not checked.
+ isLegacyWin7ConsoleHandle := func(handle Handle) bool { return isWin7 && handle&0x10000003 == 3 }
+
p, _ := GetCurrentProcess()
parentProcess := p
if sys.ParentProcess != 0 {
@@ -375,7 +388,15 @@
fd := make([]Handle, len(attr.Files))
for i := range attr.Files {
if attr.Files[i] > 0 {
- err := DuplicateHandle(p, Handle(attr.Files[i]), parentProcess, &fd[i], 0, true, DUPLICATE_SAME_ACCESS)
+ destinationProcessHandle := parentProcess
+
+ // On Windows 7, console handles aren't real handles, and can only be duplicated
+ // into the current process, not a parent one, which amounts to the same thing.
+ if parentProcess != p && isLegacyWin7ConsoleHandle(Handle(attr.Files[i])) {
+ destinationProcessHandle = p
+ }
+
+ err := DuplicateHandle(p, Handle(attr.Files[i]), destinationProcessHandle, &fd[i], 0, true, DUPLICATE_SAME_ACCESS)
if err != nil {
return 0, 0, err
}
@@ -406,6 +427,14 @@
fd = append(fd, sys.AdditionalInheritedHandles...)
+ // On Windows 7, console handles aren't real handles, so don't pass them
+ // through to PROC_THREAD_ATTRIBUTE_HANDLE_LIST.
+ for i := range fd {
+ if isLegacyWin7ConsoleHandle(fd[i]) {
+ fd[i] = 0
+ }
+ }
+
// The presence of a NULL handle in the list is enough to cause PROC_THREAD_ATTRIBUTE_HANDLE_LIST
// to treat the entire list as empty, so remove NULL handles.
j := 0
Index: src/syscall/dll_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/syscall/dll_windows.go b/src/syscall/dll_windows.go
--- a/src/syscall/dll_windows.go (revision b4aece36e51ecce81c3ee9fe03e31db552e90018)
+++ b/src/syscall/dll_windows.go (revision ea2726a6fa25fbfa1092e696e522eafca544d24c)
@@ -119,14 +119,7 @@
}
//go:linkname loadsystemlibrary
-func loadsystemlibrary(filename *uint16) (uintptr, Errno) {
- const _LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
- handle, _, err := SyscallN(uintptr(__LoadLibraryExW), uintptr(unsafe.Pointer(filename)), 0, _LOAD_LIBRARY_SEARCH_SYSTEM32)
- if handle != 0 {
- err = 0
- }
- return handle, err
-}
+func loadsystemlibrary(filename *uint16, absoluteFilepath *uint16) (handle uintptr, err Errno)
//go:linkname getprocaddress
func getprocaddress(handle uintptr, procname *uint8) (uintptr, Errno) {
@@ -143,6 +136,9 @@
Handle Handle
}
+//go:linkname getSystemDirectory
+func getSystemDirectory() string // Implemented in runtime package.
+
// LoadDLL loads the named DLL file into memory.
//
// If name is not an absolute path and is not a known system DLL used by
@@ -159,7 +155,11 @@
var h uintptr
var e Errno
if sysdll.IsSystemDLL[name] {
- h, e = loadsystemlibrary(namep)
+ absoluteFilepathp, err := UTF16PtrFromString(getSystemDirectory() + name)
+ if err != nil {
+ return nil, err
+ }
+ h, e = loadsystemlibrary(namep, absoluteFilepathp)
} else {
h, e = loadlibrary(namep)
}
Index: src/os/removeall_at.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/removeall_at.go b/src/os/removeall_at.go
--- a/src/os/removeall_at.go (revision ea2726a6fa25fbfa1092e696e522eafca544d24c)
+++ b/src/os/removeall_at.go (revision d47e0d22130d597dcf9daa6b41fd9501274f0cb2)
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build unix || wasip1 || windows
+//go:build unix || wasip1
package os
@@ -175,3 +175,25 @@
}
return newDirFile(fd, name)
}
+
+func rootRemoveAll(r *Root, name string) error {
+ // Consistency with os.RemoveAll: Strip trailing /s from the name,
+ // so RemoveAll("not_a_directory/") succeeds.
+ for len(name) > 0 && IsPathSeparator(name[len(name)-1]) {
+ name = name[:len(name)-1]
+ }
+ if endsWithDot(name) {
+ // Consistency with os.RemoveAll: Return EINVAL when trying to remove .
+ return &PathError{Op: "RemoveAll", Path: name, Err: syscall.EINVAL}
+ }
+ _, err := doInRoot(r, name, nil, func(parent sysfdType, name string) (struct{}, error) {
+ return struct{}{}, removeAllFrom(parent, name)
+ })
+ if IsNotExist(err) {
+ return nil
+ }
+ if err != nil {
+ return &PathError{Op: "RemoveAll", Path: name, Err: underlyingError(err)}
+ }
+ return err
+}
Index: src/os/removeall_noat.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/removeall_noat.go b/src/os/removeall_noat.go
--- a/src/os/removeall_noat.go (revision ea2726a6fa25fbfa1092e696e522eafca544d24c)
+++ b/src/os/removeall_noat.go (revision d47e0d22130d597dcf9daa6b41fd9501274f0cb2)
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build (js && wasm) || plan9
+//go:build (js && wasm) || plan9 || windows
package os
@@ -140,3 +140,22 @@
}
return err
}
+
+func rootRemoveAll(r *Root, name string) error {
+ if endsWithDot(name) {
+ // Consistency with os.RemoveAll: Return EINVAL when trying to remove .
+ return &PathError{Op: "RemoveAll", Path: name, Err: syscall.EINVAL}
+ }
+ if err := checkPathEscapesLstat(r, name); err != nil {
+ if err == syscall.ENOTDIR {
+ // Some intermediate path component is not a directory.
+ // RemoveAll treats this as success (since the target doesn't exist).
+ return nil
+ }
+ return &PathError{Op: "RemoveAll", Path: name, Err: err}
+ }
+ if err := RemoveAll(joinPath(r.root.name, name)); err != nil {
+ return &PathError{Op: "RemoveAll", Path: name, Err: underlyingError(err)}
+ }
+ return nil
+}
Index: src/os/root_noopenat.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/root_noopenat.go b/src/os/root_noopenat.go
--- a/src/os/root_noopenat.go (revision ea2726a6fa25fbfa1092e696e522eafca544d24c)
+++ b/src/os/root_noopenat.go (revision d47e0d22130d597dcf9daa6b41fd9501274f0cb2)
@@ -11,7 +11,6 @@
"internal/filepathlite"
"internal/stringslite"
"sync/atomic"
- "syscall"
"time"
)
@@ -185,25 +184,6 @@
}
return nil
}
-
-func rootRemoveAll(r *Root, name string) error {
- if endsWithDot(name) {
- // Consistency with os.RemoveAll: Return EINVAL when trying to remove .
- return &PathError{Op: "RemoveAll", Path: name, Err: syscall.EINVAL}
- }
- if err := checkPathEscapesLstat(r, name); err != nil {
- if err == syscall.ENOTDIR {
- // Some intermediate path component is not a directory.
- // RemoveAll treats this as success (since the target doesn't exist).
- return nil
- }
- return &PathError{Op: "RemoveAll", Path: name, Err: err}
- }
- if err := RemoveAll(joinPath(r.root.name, name)); err != nil {
- return &PathError{Op: "RemoveAll", Path: name, Err: underlyingError(err)}
- }
- return nil
-}
func rootReadlink(r *Root, name string) (string, error) {
if err := checkPathEscapesLstat(r, name); err != nil {
Index: src/os/root_openat.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/root_openat.go b/src/os/root_openat.go
--- a/src/os/root_openat.go (revision ea2726a6fa25fbfa1092e696e522eafca544d24c)
+++ b/src/os/root_openat.go (revision d47e0d22130d597dcf9daa6b41fd9501274f0cb2)
@@ -196,28 +196,6 @@
return nil
}
-func rootRemoveAll(r *Root, name string) error {
- // Consistency with os.RemoveAll: Strip trailing /s from the name,
- // so RemoveAll("not_a_directory/") succeeds.
- for len(name) > 0 && IsPathSeparator(name[len(name)-1]) {
- name = name[:len(name)-1]
- }
- if endsWithDot(name) {
- // Consistency with os.RemoveAll: Return EINVAL when trying to remove .
- return &PathError{Op: "RemoveAll", Path: name, Err: syscall.EINVAL}
- }
- _, err := doInRoot(r, name, nil, func(parent sysfdType, name string) (struct{}, error) {
- return struct{}{}, removeAllFrom(parent, name)
- })
- if IsNotExist(err) {
- return nil
- }
- if err != nil {
- return &PathError{Op: "RemoveAll", Path: name, Err: underlyingError(err)}
- }
- return err
-}
-
func rootRename(r *Root, oldname, newname string) error {
_, err := doInRoot(r, oldname, nil, func(oldparent sysfdType, oldname string) (struct{}, error) {
_, err := doInRoot(r, newname, nil, func(newparent sysfdType, newname string) (struct{}, error) {
Index: src/os/root_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/root_windows.go b/src/os/root_windows.go
--- a/src/os/root_windows.go (revision ea2726a6fa25fbfa1092e696e522eafca544d24c)
+++ b/src/os/root_windows.go (revision d47e0d22130d597dcf9daa6b41fd9501274f0cb2)
@@ -402,3 +402,14 @@
}
return fi.Mode(), nil
}
+
+func checkPathEscapes(r *Root, name string) error {
+ if !filepathlite.IsLocal(name) {
+ return errPathEscapes
+ }
+ return nil
+}
+
+func checkPathEscapesLstat(r *Root, name string) error {
+ return checkPathEscapes(r, name)
+}
Index: src/os/exec_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/exec_windows.go b/src/os/exec_windows.go
--- a/src/os/exec_windows.go (revision d47e0d22130d597dcf9daa6b41fd9501274f0cb2)
+++ b/src/os/exec_windows.go (revision 00e8daec9a4d88f44a8dc55d3bdb71878e525b41)
@@ -10,6 +10,7 @@
"runtime"
"syscall"
"time"
+ _ "unsafe"
)
// Note that Process.handle is never nil because Windows always requires
@@ -49,9 +50,23 @@
// than statusDone.
p.doRelease(statusReleased)
+ var maj, min, build uint32
+ rtlGetNtVersionNumbers(&maj, &min, &build)
+ if maj < 10 {
+ // NOTE(brainman): It seems that sometimes process is not dead
+ // when WaitForSingleObject returns. But we do not know any
+ // other way to wait for it. Sleeping for a while seems to do
+ // the trick sometimes.
+ // See https://golang.org/issue/25965 for details.
+ time.Sleep(5 * time.Millisecond)
+ }
+
return &ProcessState{p.Pid, syscall.WaitStatus{ExitCode: ec}, &u}, nil
}
+//go:linkname rtlGetNtVersionNumbers syscall.rtlGetNtVersionNumbers
+func rtlGetNtVersionNumbers(majorVersion *uint32, minorVersion *uint32, buildNumber *uint32)
+
func (p *Process) signal(sig Signal) error {
handle, status := p.handleTransientAcquire()
switch status {

View File

@@ -59,6 +59,8 @@ jobs:
- { goos: linux, goarch: s390x, output: s390x, debian: s390x, rpm: s390x }
- { goos: linux, goarch: ppc64le, output: ppc64le, debian: ppc64el, rpm: ppc64le }
# Go 1.25 with special patch can work on Windows 7
# https://github.com/MetaCubeX/go/commits/release-branch.go1.25/
- { goos: windows, goarch: '386', output: '386' }
- { goos: windows, goarch: amd64, goamd64: v1, output: amd64-compatible } # old style file name will be removed in next released
- { goos: windows, goarch: amd64, goamd64: v3, output: amd64 }
@@ -153,12 +155,14 @@ jobs:
uses: actions/setup-go@v6
with:
go-version: '1.25'
check-latest: true # Always check for the latest patch release
- name: Set up Go
if: ${{ matrix.jobs.goversion != '' && matrix.jobs.abi != '1' }}
uses: actions/setup-go@v6
with:
go-version: ${{ matrix.jobs.goversion }}
check-latest: true # Always check for the latest patch release
- name: Set up Go1.24 loongarch abi1
if: ${{ matrix.jobs.goarch == 'loong64' && matrix.jobs.abi == '1' }}
@@ -176,6 +180,9 @@ jobs:
# 7c1157f9544922e96945196b47b95664b1e39108: "net: remove sysSocket fallback for Windows 7"
# 48042aa09c2f878c4faa576948b07fe625c4707a: "syscall: remove Windows 7 console handle workaround"
# a17d959debdb04cd550016a3501dd09d50cd62e7: "runtime: always use LoadLibraryEx to load system libraries"
# f0894a00f4b756d4b9b4078af2e686b359493583: "os: remove 5ms sleep on Windows in (*Process).Wait"
# sepical fix:
# - os.RemoveAll not working on Windows7
- name: Revert Golang1.25 commit for Windows7/8
if: ${{ matrix.jobs.goos == 'windows' && matrix.jobs.goversion == '' }}
run: |

View File

@@ -24,6 +24,7 @@ jobs:
- 'ubuntu-24.04-arm' # arm64 linux
- 'macos-15-intel' # amd64 macos
go-version:
- '1.26.0-rc.3'
- '1.25'
- '1.24'
- '1.23'
@@ -47,13 +48,20 @@ jobs:
uses: actions/setup-go@v6
with:
go-version: ${{ matrix.go-version }}
check-latest: true # Always check for the latest patch release
- name: Revert Golang commit for Windows7/8
if: ${{ runner.os == 'Windows' && matrix.go-version != '1.20' }}
if: ${{ runner.os == 'Windows' && matrix.go-version != '1.20' && matrix.go-version != '1.26.0-rc.3' }}
run: |
cd $(go env GOROOT)
patch --verbose -p 1 < $GITHUB_WORKSPACE/.github/patch/go${{matrix.go-version}}.patch
- name: Revert Golang commit for Windows7/8
if: ${{ runner.os == 'Windows' && matrix.go-version == '1.26.0-rc.3' }}
run: |
cd $(go env GOROOT)
patch --verbose -p 1 < $GITHUB_WORKSPACE/.github/patch/go1.26.patch
- name: Remove inbound test for macOS
if: ${{ runner.os == 'macOS' }}
run: |

View File

@@ -5,7 +5,6 @@ import (
"encoding/json"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"time"
@@ -17,6 +16,8 @@ import (
"github.com/metacubex/mihomo/component/ca"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log"
"github.com/metacubex/http"
)
var UnifiedDelay = atomic.NewBool(false)
@@ -153,8 +154,9 @@ func (p *Proxy) MarshalJSON() ([]byte, error) {
mapping["mptcp"] = proxyInfo.MPTCP
mapping["smux"] = proxyInfo.SMUX
mapping["interface"] = proxyInfo.Interface
mapping["dialer-proxy"] = proxyInfo.DialerProxy
mapping["routing-mark"] = proxyInfo.RoutingMark
mapping["provider-name"] = proxyInfo.ProviderName
mapping["dialer-proxy"] = proxyInfo.DialerProxy
return json.Marshal(mapping)
}
@@ -177,14 +179,12 @@ func (p *Proxy) URLTest(ctx context.Context, url string, expectedStatus utils.In
p.history.Pop()
}
state, ok := p.extra.Load(url)
if !ok {
state = &internalProxyState{
state, _ := p.extra.LoadOrStoreFn(url, func() *internalProxyState {
return &internalProxyState{
history: queue.New[C.DelayHistory](defaultHistoriesNum),
alive: atomic.NewBool(true),
}
p.extra.Store(url, state)
}
})
if !satisfied {
record.Delay = 0

View File

@@ -2,9 +2,10 @@ package inbound
import (
"net"
"net/http"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/http"
)
// NewHTTPS receive CONNECT request and return ConnContext

View File

@@ -8,6 +8,7 @@ import (
"sync"
"github.com/metacubex/mihomo/component/keepalive"
"github.com/metacubex/mihomo/component/mptcp"
"github.com/metacubex/tfo-go"
)
@@ -34,13 +35,13 @@ func Tfo() bool {
func SetMPTCP(open bool) {
mutex.Lock()
defer mutex.Unlock()
setMultiPathTCP(&lc.ListenConfig, open)
mptcp.SetNetListenConfig(&lc.ListenConfig, open)
}
func MPTCP() bool {
mutex.RLock()
defer mutex.RUnlock()
return getMultiPathTCP(&lc.ListenConfig)
return mptcp.GetNetListenConfig(&lc.ListenConfig)
}
func preResolve(network, address string) (string, error) {

View File

@@ -1,14 +0,0 @@
//go:build !go1.21
package inbound
import "net"
const multipathTCPAvailable = false
func setMultiPathTCP(listenConfig *net.ListenConfig, open bool) {
}
func getMultiPathTCP(listenConfig *net.ListenConfig) bool {
return false
}

View File

@@ -1,15 +0,0 @@
//go:build go1.21
package inbound
import "net"
const multipathTCPAvailable = true
func setMultiPathTCP(listenConfig *net.ListenConfig, open bool) {
listenConfig.SetMultipathTCP(open)
}
func getMultiPathTCP(listenConfig *net.ListenConfig) bool {
return listenConfig.MultipathTCP()
}

View File

@@ -2,12 +2,13 @@ package inbound
import (
"net"
"net/http"
"net/netip"
"strings"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/socks5"
"github.com/metacubex/http"
)
func parseSocksAddr(target socks5.Addr) *C.Metadata {

View File

@@ -6,8 +6,7 @@ import (
"strconv"
"time"
CN "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/dialer"
N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/anytls"
@@ -20,7 +19,6 @@ import (
type AnyTLS struct {
*Base
client *anytls.Client
dialer proxydialer.SingDialer
option *AnyTLSOption
}
@@ -65,7 +63,7 @@ func (t *AnyTLS) ListenPacketContext(ctx context.Context, metadata *C.Metadata)
// create uot on tcp
destination := M.SocksaddrFromNet(metadata.UDPAddr())
return newPacketConn(CN.NewThreadSafePacketConn(uot.NewLazyConn(c, uot.Request{Destination: destination})), t), nil
return newPacketConn(N.NewThreadSafePacketConn(uot.NewLazyConn(c, uot.Request{Destination: destination})), t), nil
}
// SupportUOT implements C.ProxyAdapter
@@ -92,18 +90,18 @@ func NewAnyTLS(option AnyTLSOption) (*AnyTLS, error) {
name: option.Name,
addr: addr,
tp: C.AnyTLS,
pdName: option.ProviderName,
udp: option.UDP,
tfo: option.TFO,
mpTcp: option.MPTCP,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
prefer: option.IPVersion,
},
option: &option,
}
singDialer := proxydialer.NewByNameSingDialer(option.DialerProxy, dialer.NewDialer(outbound.DialOptions()...))
outbound.dialer = singDialer
outbound.dialer = option.NewDialer(outbound.DialOptions())
singDialer := proxydialer.NewSingDialer(outbound.dialer)
tOption := anytls.ClientConfig{
Password: option.Password,

View File

@@ -12,6 +12,7 @@ import (
N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
"github.com/metacubex/mihomo/component/resolver"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log"
@@ -26,15 +27,17 @@ type ProxyAdapter interface {
type Base struct {
name string
addr string
iface string
tp C.AdapterType
pdName string
udp bool
xudp bool
tfo bool
mpTcp bool
iface string
rmark int
id string
prefer C.DNSPrefer
dialer C.Dialer
id string
}
// Name implements C.ProxyAdapter
@@ -56,35 +59,15 @@ func (b *Base) Type() C.AdapterType {
return b.tp
}
// StreamConnContext implements C.ProxyAdapter
func (b *Base) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) {
return c, C.ErrNotSupport
}
func (b *Base) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
return nil, C.ErrNotSupport
}
// DialContextWithDialer implements C.ProxyAdapter
func (b *Base) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
return nil, C.ErrNotSupport
}
// ListenPacketContext implements C.ProxyAdapter
func (b *Base) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) {
return nil, C.ErrNotSupport
}
// ListenPacketWithDialer implements C.ProxyAdapter
func (b *Base) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) {
return nil, C.ErrNotSupport
}
// SupportWithDialer implements C.ProxyAdapter
func (b *Base) SupportWithDialer() C.NetWork {
return C.InvalidNet
}
// SupportUOT implements C.ProxyAdapter
func (b *Base) SupportUOT() bool {
return false
@@ -103,6 +86,7 @@ func (b *Base) ProxyInfo() (info C.ProxyInfo) {
info.SMUX = false
info.Interface = b.iface
info.RoutingMark = b.rmark
info.ProviderName = b.pdName
return
}
@@ -178,12 +162,30 @@ func (b *Base) Close() error {
}
type BasicOption struct {
TFO bool `proxy:"tfo,omitempty"`
MPTCP bool `proxy:"mptcp,omitempty"`
Interface string `proxy:"interface-name,omitempty"`
RoutingMark int `proxy:"routing-mark,omitempty"`
IPVersion string `proxy:"ip-version,omitempty"`
DialerProxy string `proxy:"dialer-proxy,omitempty"` // don't apply this option into groups, but can set a group name in a proxy
TFO bool `proxy:"tfo,omitempty"`
MPTCP bool `proxy:"mptcp,omitempty"`
Interface string `proxy:"interface-name,omitempty"`
RoutingMark int `proxy:"routing-mark,omitempty"`
IPVersion C.DNSPrefer `proxy:"ip-version,omitempty"`
DialerProxy string `proxy:"dialer-proxy,omitempty"` // don't apply this option into groups, but can set a group name in a proxy
//
// The following parameters are used internally, assign value by the structure decoder are disallowed
//
DialerForAPI C.Dialer `proxy:"-"` // the dialer used for API usage has higher priority than all the above configurations.
ProviderName string `proxy:"-"`
}
func (b *BasicOption) NewDialer(opts []dialer.Option) C.Dialer {
cDialer := b.DialerForAPI
if cDialer == nil {
if b.DialerProxy != "" {
cDialer = proxydialer.NewByName(b.DialerProxy)
} else {
cDialer = dialer.NewDialer(opts...)
}
}
return cDialer
}
type BaseOption struct {
@@ -217,6 +219,7 @@ func NewBase(opt BaseOption) *Base {
type conn struct {
N.ExtendedConn
chain C.Chain
pdChain C.Chain
adapterAddr string
}
@@ -238,9 +241,15 @@ func (c *conn) Chains() C.Chain {
return c.chain
}
// ProviderChains implements C.Connection
func (c *conn) ProviderChains() C.Chain {
return c.pdChain
}
// AppendToChains implements C.Connection
func (c *conn) AppendToChains(a C.ProxyAdapter) {
c.chain = append(c.chain, a.Name())
c.pdChain = append(c.pdChain, a.ProxyInfo().ProviderName)
}
func (c *conn) Upstream() any {
@@ -263,7 +272,7 @@ func NewConn(c net.Conn, a C.ProxyAdapter) C.Conn {
if _, ok := c.(syscall.Conn); !ok { // exclusion system conn like *net.TCPConn
c = N.NewDeadlineConn(c) // most conn from outbound can't handle readDeadline correctly
}
cc := &conn{N.NewExtendedConn(c), nil, a.Addr()}
cc := &conn{N.NewExtendedConn(c), nil, nil, a.Addr()}
cc.AppendToChains(a)
return cc
}
@@ -271,6 +280,7 @@ func NewConn(c net.Conn, a C.ProxyAdapter) C.Conn {
type packetConn struct {
N.EnhancePacketConn
chain C.Chain
pdChain C.Chain
adapterName string
connID string
adapterAddr string
@@ -291,9 +301,15 @@ func (c *packetConn) Chains() C.Chain {
return c.chain
}
// ProviderChains implements C.Connection
func (c *packetConn) ProviderChains() C.Chain {
return c.pdChain
}
// AppendToChains implements C.Connection
func (c *packetConn) AppendToChains(a C.ProxyAdapter) {
c.chain = append(c.chain, a.Name())
c.pdChain = append(c.pdChain, a.ProxyInfo().ProviderName)
}
func (c *packetConn) LocalAddr() net.Addr {
@@ -322,7 +338,7 @@ func newPacketConn(pc net.PacketConn, a ProxyAdapter) C.PacketConn {
if _, ok := pc.(syscall.Conn); !ok { // exclusion system conn like *net.UDPConn
epc = N.NewDeadlineEnhancePacketConn(epc) // most conn from outbound can't handle readDeadline correctly
}
cpc := &packetConn{epc, nil, a.Name(), utils.NewUUIDV4().String(), a.Addr(), a.ResolveUDP}
cpc := &packetConn{epc, nil, nil, a.Name(), utils.NewUUIDV4().String(), a.Addr(), a.ResolveUDP}
cpc.AppendToChains(a)
return cpc
}
@@ -348,17 +364,6 @@ func (p *autoCloseProxyAdapter) DialContext(ctx context.Context, metadata *C.Met
return c, nil
}
func (p *autoCloseProxyAdapter) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
c, err := p.ProxyAdapter.DialContextWithDialer(ctx, dialer, metadata)
if err != nil {
return nil, err
}
if c, ok := c.(AddRef); ok {
c.AddRef(p)
}
return c, nil
}
func (p *autoCloseProxyAdapter) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) {
pc, err := p.ProxyAdapter.ListenPacketContext(ctx, metadata)
if err != nil {
@@ -370,17 +375,6 @@ func (p *autoCloseProxyAdapter) ListenPacketContext(ctx context.Context, metadat
return pc, nil
}
func (p *autoCloseProxyAdapter) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) {
pc, err := p.ProxyAdapter.ListenPacketWithDialer(ctx, dialer, metadata)
if err != nil {
return nil, err
}
if pc, ok := pc.(AddRef); ok {
pc.AddRef(p)
}
return pc, nil
}
func (p *autoCloseProxyAdapter) Close() error {
p.closeOnce.Do(func() {
log.Debugln("Closing outdated proxy [%s]", p.Name())

View File

@@ -69,12 +69,13 @@ func NewDirectWithOption(option DirectOption) *Direct {
Base: &Base{
name: option.Name,
tp: C.Direct,
pdName: option.ProviderName,
udp: true,
tfo: option.TFO,
mpTcp: option.MPTCP,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
prefer: option.IPVersion,
},
loopBack: loopback.NewDetector(),
}

View File

@@ -158,12 +158,13 @@ func NewDnsWithOption(option DnsOption) *Dns {
Base: &Base{
name: option.Name,
tp: C.Dns,
pdName: option.ProviderName,
udp: true,
tfo: option.TFO,
mpTcp: option.MPTCP,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
prefer: option.IPVersion,
},
}
}

View File

@@ -12,6 +12,8 @@ import (
type ECHOptions struct {
Enable bool `proxy:"enable,omitempty" obfs:"enable,omitempty"`
Config string `proxy:"config,omitempty" obfs:"config,omitempty"`
QueryServerName string `proxy:"query-server-name,omitempty" obfs:"query-server-name,omitempty"`
}
func (o ECHOptions) Parse() (*ech.Config, error) {
@@ -29,6 +31,9 @@ func (o ECHOptions) Parse() (*ech.Config, error) {
}
} else {
echConfig.GetEncryptedClientHelloConfigList = func(ctx context.Context, serverName string) ([]byte, error) {
if o.QueryServerName != "" { // overrides the domain name used for ECH HTTPS record queries
serverName = o.QueryServerName
}
return resolver.ResolveECHWithResolver(ctx, serverName, resolver.ProxyServerHostResolver)
}
}

View File

@@ -3,19 +3,18 @@ package outbound
import (
"bufio"
"context"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"net"
"net/http"
"strconv"
N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/http"
"github.com/metacubex/tls"
)
type Http struct {
@@ -61,18 +60,7 @@ func (h *Http) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Me
// DialContext implements C.ProxyAdapter
func (h *Http) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) {
return h.DialContextWithDialer(ctx, dialer.NewDialer(h.DialOptions()...), metadata)
}
// DialContextWithDialer implements C.ProxyAdapter
func (h *Http) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
if len(h.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(h.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
c, err := dialer.DialContext(ctx, "tcp", h.addr)
c, err := h.dialer.DialContext(ctx, "tcp", h.addr)
if err != nil {
return nil, fmt.Errorf("%s connect error: %w", h.addr, err)
}
@@ -89,11 +77,6 @@ func (h *Http) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metad
return NewConn(c, h), nil
}
// SupportWithDialer implements C.ProxyAdapter
func (h *Http) SupportWithDialer() C.NetWork {
return C.TCP
}
// ProxyInfo implements C.ProxyAdapter
func (h *Http) ProxyInfo() C.ProxyInfo {
info := h.Base.ProxyInfo()
@@ -183,20 +166,23 @@ func NewHttp(option HttpOption) (*Http, error) {
}
}
return &Http{
outbound := &Http{
Base: &Base{
name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.Http,
pdName: option.ProviderName,
tfo: option.TFO,
mpTcp: option.MPTCP,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
prefer: option.IPVersion,
},
user: option.UserName,
pass: option.Password,
tlsConfig: tlsConfig,
option: &option,
}, nil
}
outbound.dialer = option.NewDialer(outbound.DialOptions())
return outbound, nil
}

View File

@@ -2,7 +2,6 @@ package outbound
import (
"context"
"crypto/tls"
"encoding/base64"
"fmt"
"net"
@@ -13,8 +12,6 @@ import (
"github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/ech"
"github.com/metacubex/mihomo/component/proxydialer"
tlsC "github.com/metacubex/mihomo/component/tls"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log"
hyCongestion "github.com/metacubex/mihomo/transport/hysteria/congestion"
@@ -24,6 +21,8 @@ import (
"github.com/metacubex/mihomo/transport/hysteria/transport"
"github.com/metacubex/mihomo/transport/hysteria/utils"
"github.com/metacubex/tls"
"github.com/metacubex/quic-go"
"github.com/metacubex/quic-go/congestion"
M "github.com/metacubex/sing/common/metadata"
@@ -46,7 +45,7 @@ type Hysteria struct {
option *HysteriaOption
client *core.Client
tlsConfig *tlsC.Config
tlsConfig *tls.Config
echConfig *ech.Config
}
@@ -74,16 +73,8 @@ func (h *Hysteria) genHdc(ctx context.Context) utils.PacketDialer {
return &hyDialerWithContext{
ctx: context.Background(),
hyDialer: func(network string, rAddr net.Addr) (net.PacketConn, error) {
var err error
var cDialer C.Dialer = dialer.NewDialer(h.DialOptions()...)
if len(h.option.DialerProxy) > 0 {
cDialer, err = proxydialer.NewByName(h.option.DialerProxy, cDialer)
if err != nil {
return nil, err
}
}
rAddrPort, _ := netip.ParseAddrPort(rAddr.String())
return cDialer.ListenPacket(ctx, network, "", rAddrPort)
return h.dialer.ListenPacket(ctx, network, "", rAddrPort)
},
remoteAddr: func(addr string) (net.Addr, error) {
udpAddr, err := resolveUDPAddr(ctx, "udp", addr, h.prefer)
@@ -184,7 +175,7 @@ func NewHysteria(option HysteriaOption) (*Hysteria, error) {
if err != nil {
return nil, err
}
tlsClientConfig := tlsC.UConfig(tlsConfig)
tlsClientConfig := tlsConfig
quicConfig := &quic.Config{
InitialStreamReceiveWindow: uint64(option.ReceiveWindowConn),
@@ -252,17 +243,19 @@ func NewHysteria(option HysteriaOption) (*Hysteria, error) {
name: option.Name,
addr: addr,
tp: C.Hysteria,
pdName: option.ProviderName,
udp: true,
tfo: option.FastOpen,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
prefer: option.IPVersion,
},
option: &option,
client: client,
tlsConfig: tlsClientConfig,
echConfig: echConfig,
}
outbound.dialer = option.NewDialer(outbound.DialOptions())
return outbound, nil
}

View File

@@ -2,19 +2,16 @@ package outbound
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"strconv"
"time"
CN "github.com/metacubex/mihomo/common/net"
N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
tlsC "github.com/metacubex/mihomo/component/tls"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log"
tuicCommon "github.com/metacubex/mihomo/transport/tuic/common"
@@ -22,6 +19,7 @@ import (
"github.com/metacubex/quic-go"
"github.com/metacubex/sing-quic/hysteria2"
M "github.com/metacubex/sing/common/metadata"
"github.com/metacubex/tls"
)
func init() {
@@ -36,7 +34,6 @@ type Hysteria2 struct {
option *Hysteria2Option
client *hysteria2.Client
dialer proxydialer.SingDialer
}
type Hysteria2Option struct {
@@ -87,7 +84,7 @@ func (h *Hysteria2) ListenPacketContext(ctx context.Context, metadata *C.Metadat
if pc == nil {
return nil, errors.New("packetConn is nil")
}
return newPacketConn(CN.NewThreadSafePacketConn(pc), h), nil
return newPacketConn(N.NewThreadSafePacketConn(pc), h), nil
}
// Close implements C.ProxyAdapter
@@ -112,16 +109,16 @@ func NewHysteria2(option Hysteria2Option) (*Hysteria2, error) {
name: option.Name,
addr: addr,
tp: C.Hysteria2,
pdName: option.ProviderName,
udp: true,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
prefer: option.IPVersion,
},
option: &option,
}
singDialer := proxydialer.NewByNameSingDialer(option.DialerProxy, dialer.NewDialer(outbound.DialOptions()...))
outbound.dialer = singDialer
outbound.dialer = option.NewDialer(outbound.DialOptions())
singDialer := proxydialer.NewSingDialer(outbound.dialer)
var salamanderPassword string
if len(option.Obfs) > 0 {
@@ -159,7 +156,7 @@ func NewHysteria2(option Hysteria2Option) (*Hysteria2, error) {
tlsConfig.NextProtos = option.ALPN
}
tlsClientConfig := tlsC.UConfig(tlsConfig)
tlsClientConfig := tlsConfig
echConfig, err := option.ECHOpts.Parse()
if err != nil {
return nil, err
@@ -192,7 +189,7 @@ func NewHysteria2(option Hysteria2Option) (*Hysteria2, error) {
CWND: option.CWND,
UdpMTU: option.UdpMTU,
ServerAddress: func(ctx context.Context) (*net.UDPAddr, error) {
udpAddr, err := resolveUDPAddr(ctx, "udp", addr, C.NewDNSPrefer(option.IPVersion))
udpAddr, err := resolveUDPAddr(ctx, "udp", addr, option.IPVersion)
if err != nil {
return nil, err
}

397
adapter/outbound/masque.go Normal file
View File

@@ -0,0 +1,397 @@
package outbound
import (
"context"
"crypto/ecdsa"
"crypto/x509"
"encoding/base64"
"errors"
"fmt"
"net"
"net/netip"
"strconv"
"strings"
"sync"
"time"
"github.com/metacubex/mihomo/common/atomic"
"github.com/metacubex/mihomo/common/contextutils"
"github.com/metacubex/mihomo/common/pool"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/resolver"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/dns"
"github.com/metacubex/mihomo/log"
"github.com/metacubex/mihomo/transport/masque"
"github.com/metacubex/mihomo/transport/tuic/common"
connectip "github.com/metacubex/connect-ip-go"
"github.com/metacubex/quic-go"
wireguard "github.com/metacubex/sing-wireguard"
M "github.com/metacubex/sing/common/metadata"
"github.com/metacubex/tls"
)
type Masque struct {
*Base
tlsConfig *tls.Config
quicConfig *quic.Config
tunDevice wireguard.Device
resolver resolver.Resolver
uri string
runCtx context.Context
runCancel context.CancelFunc
runMutex sync.Mutex
running atomic.Bool
runDevice atomic.Bool
option MasqueOption
}
type MasqueOption struct {
BasicOption
Name string `proxy:"name"`
Server string `proxy:"server"`
Port int `proxy:"port"`
PrivateKey string `proxy:"private-key"`
PublicKey string `proxy:"public-key"`
Ip string `proxy:"ip,omitempty"`
Ipv6 string `proxy:"ipv6,omitempty"`
URI string `proxy:"uri,omitempty"`
SNI string `proxy:"sni,omitempty"`
MTU int `proxy:"mtu,omitempty"`
UDP bool `proxy:"udp,omitempty"`
CongestionController string `proxy:"congestion-controller,omitempty"`
CWND int `proxy:"cwnd,omitempty"`
RemoteDnsResolve bool `proxy:"remote-dns-resolve,omitempty"`
Dns []string `proxy:"dns,omitempty"`
}
func (option MasqueOption) Prefixes() ([]netip.Prefix, error) {
localPrefixes := make([]netip.Prefix, 0, 2)
if len(option.Ip) > 0 {
if !strings.Contains(option.Ip, "/") {
option.Ip = option.Ip + "/32"
}
if prefix, err := netip.ParsePrefix(option.Ip); err == nil {
localPrefixes = append(localPrefixes, prefix)
} else {
return nil, fmt.Errorf("ip address parse error: %w", err)
}
}
if len(option.Ipv6) > 0 {
if !strings.Contains(option.Ipv6, "/") {
option.Ipv6 = option.Ipv6 + "/128"
}
if prefix, err := netip.ParsePrefix(option.Ipv6); err == nil {
localPrefixes = append(localPrefixes, prefix)
} else {
return nil, fmt.Errorf("ipv6 address parse error: %w", err)
}
}
if len(localPrefixes) == 0 {
return nil, errors.New("missing local address")
}
return localPrefixes, nil
}
func NewMasque(option MasqueOption) (*Masque, error) {
outbound := &Masque{
Base: &Base{
name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.Masque,
pdName: option.ProviderName,
udp: option.UDP,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: option.IPVersion,
},
}
outbound.dialer = option.NewDialer(outbound.DialOptions())
ctx, cancel := context.WithCancel(context.Background())
outbound.runCtx = ctx
outbound.runCancel = cancel
privKeyB64, err := base64.StdEncoding.DecodeString(option.PrivateKey)
if err != nil {
return nil, fmt.Errorf("failed to decode private key: %v", err)
}
privKey, err := x509.ParseECPrivateKey(privKeyB64)
if err != nil {
return nil, fmt.Errorf("failed to parse private key: %v", err)
}
endpointPubKeyB64, err := base64.StdEncoding.DecodeString(option.PublicKey)
if err != nil {
return nil, fmt.Errorf("failed to decode public key: %v", err)
}
pubKey, err := x509.ParsePKIXPublicKey(endpointPubKeyB64)
if err != nil {
return nil, fmt.Errorf("failed to parse public key: %v", err)
}
ecPubKey, ok := pubKey.(*ecdsa.PublicKey)
if !ok {
return nil, fmt.Errorf("failed to assert public key as ECDSA")
}
uri := option.URI
if uri == "" {
uri = masque.ConnectURI
}
outbound.uri = uri
sni := option.SNI
if sni == "" {
sni = masque.ConnectSNI
}
tlsConfig, err := masque.PrepareTlsConfig(privKey, ecPubKey, sni)
if err != nil {
return nil, fmt.Errorf("failed to prepare TLS config: %v\n", err)
}
outbound.tlsConfig = tlsConfig
outbound.quicConfig = &quic.Config{
EnableDatagrams: true,
InitialPacketSize: 1242,
KeepAlivePeriod: 30 * time.Second,
}
prefixes, err := option.Prefixes()
if err != nil {
return nil, err
}
outbound.option = option
mtu := option.MTU
if mtu == 0 {
mtu = 1280
}
if len(prefixes) == 0 {
return nil, errors.New("missing local address")
}
outbound.tunDevice, err = wireguard.NewStackDevice(prefixes, uint32(mtu))
if err != nil {
return nil, fmt.Errorf("create device: %w", err)
}
var has6 bool
for _, address := range prefixes {
if !address.Addr().Unmap().Is4() {
has6 = true
break
}
}
if option.RemoteDnsResolve && len(option.Dns) > 0 {
nss, err := dns.ParseNameServer(option.Dns)
if err != nil {
return nil, err
}
for i := range nss {
nss[i].ProxyAdapter = outbound
}
outbound.resolver = dns.NewResolver(dns.Config{
Main: nss,
IPv6: has6,
})
}
return outbound, nil
}
func (w *Masque) run(ctx context.Context) error {
if w.running.Load() {
return nil
}
w.runMutex.Lock()
defer w.runMutex.Unlock()
// double-check like sync.Once
if w.running.Load() {
return nil
}
if w.runCtx.Err() != nil {
return w.runCtx.Err()
}
if !w.runDevice.Load() {
err := w.tunDevice.Start()
if err != nil {
return err
}
w.runDevice.Store(true)
}
udpAddr, err := resolveUDPAddr(ctx, "udp", w.addr, w.prefer)
if err != nil {
return err
}
pc, err := w.dialer.ListenPacket(ctx, "udp", "", udpAddr.AddrPort())
if err != nil {
return err
}
quicConn, err := quic.Dial(ctx, pc, udpAddr, w.tlsConfig, w.quicConfig)
if err != nil {
return err
}
common.SetCongestionController(quicConn, w.option.CongestionController, w.option.CWND)
tr, ipConn, err := masque.ConnectTunnel(ctx, quicConn, w.uri)
if err != nil {
_ = pc.Close()
return err
}
w.running.Store(true)
runCtx, runCancel := context.WithCancel(w.runCtx)
contextutils.AfterFunc(runCtx, func() {
w.running.Store(false)
_ = ipConn.Close()
_ = tr.Close()
_ = pc.Close()
})
go func() {
defer runCancel()
buf := pool.Get(pool.UDPBufferSize)
defer pool.Put(buf)
bufs := [][]byte{buf}
sizes := []int{0}
for runCtx.Err() == nil {
_, err := w.tunDevice.Read(bufs, sizes, 0)
if err != nil {
log.Errorln("[Masque](%s) error reading from TUN device: %v", w.name, err)
return
}
icmp, err := ipConn.WritePacket(buf[:sizes[0]])
if err != nil {
if errors.As(err, new(*connectip.CloseError)) {
log.Errorln("[Masque](%s) connection closed while writing to IP connection: %v", w.name, err)
return
}
log.Warnln("[Masque](%s) error writing to IP connection: %v, continuing...", w.name, err)
continue
}
if len(icmp) > 0 {
if _, err := w.tunDevice.Write([][]byte{icmp}, 0); err != nil {
log.Warnln("[Masque](%s) error writing ICMP to TUN device: %v, continuing...", w.name, err)
}
}
}
}()
go func() {
defer runCancel()
buf := pool.Get(pool.UDPBufferSize)
defer pool.Put(buf)
for runCtx.Err() == nil {
n, err := ipConn.ReadPacket(buf)
if err != nil {
if errors.As(err, new(*connectip.CloseError)) {
log.Errorln("[Masque](%s) connection closed while writing to IP connection: %v", w.name, err)
return
}
log.Warnln("[Masque](%s) error reading from IP connection: %v, continuing...", w.name, err)
continue
}
if _, err := w.tunDevice.Write([][]byte{buf[:n]}, 0); err != nil {
log.Errorln("[Masque](%s) error writing to TUN device: %v", w.name, err)
return
}
}
}()
return nil
}
// Close implements C.ProxyAdapter
func (w *Masque) Close() error {
w.runCancel()
if w.tunDevice != nil {
w.tunDevice.Close()
}
return nil
}
func (w *Masque) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) {
var conn net.Conn
if err = w.run(ctx); err != nil {
return nil, err
}
if !metadata.Resolved() || w.resolver != nil {
r := resolver.DefaultResolver
if w.resolver != nil {
r = w.resolver
}
options := w.DialOptions()
options = append(options, dialer.WithResolver(r))
options = append(options, dialer.WithNetDialer(wgNetDialer{tunDevice: w.tunDevice}))
conn, err = dialer.NewDialer(options...).DialContext(ctx, "tcp", metadata.RemoteAddress())
} else {
conn, err = w.tunDevice.DialContext(ctx, "tcp", M.SocksaddrFrom(metadata.DstIP, metadata.DstPort).Unwrap())
}
if err != nil {
return nil, err
}
if conn == nil {
return nil, errors.New("conn is nil")
}
return NewConn(conn, w), nil
}
func (w *Masque) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) {
var pc net.PacketConn
if err = w.run(ctx); err != nil {
return nil, err
}
if err = w.ResolveUDP(ctx, metadata); err != nil {
return nil, err
}
pc, err = w.tunDevice.ListenPacket(ctx, M.SocksaddrFrom(metadata.DstIP, metadata.DstPort).Unwrap())
if err != nil {
return nil, err
}
if pc == nil {
return nil, errors.New("packetConn is nil")
}
return newPacketConn(pc, w), nil
}
func (w *Masque) ResolveUDP(ctx context.Context, metadata *C.Metadata) error {
if (!metadata.Resolved() || w.resolver != nil) && metadata.Host != "" {
r := resolver.DefaultResolver
if w.resolver != nil {
r = w.resolver
}
ip, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, r)
if err != nil {
return fmt.Errorf("can't resolve ip: %w", err)
}
metadata.DstIP = ip
}
return nil
}
// ProxyInfo implements C.ProxyAdapter
func (w *Masque) ProxyInfo() C.ProxyInfo {
info := w.Base.ProxyInfo()
info.DialerProxy = w.option.DialerProxy
return info
}
// IsL3Protocol implements C.ProxyAdapter
func (w *Masque) IsL3Protocol(metadata *C.Metadata) bool {
return true
}

View File

@@ -8,9 +8,7 @@ import (
"strconv"
"sync"
CN "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/resolver"
C "github.com/metacubex/mihomo/constant"
@@ -106,7 +104,7 @@ func (m *Mieru) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (
if err != nil {
return nil, fmt.Errorf("dial to %s failed: %w", metadata.UDPAddr(), err)
}
return newPacketConn(CN.NewThreadSafePacketConn(mierucommon.NewUDPAssociateWrapper(mierucommon.NewPacketOverStreamTunnel(c))), m), nil
return newPacketConn(N.NewThreadSafePacketConn(mierucommon.NewUDPAssociateWrapper(mierucommon.NewPacketOverStreamTunnel(c))), m), nil
}
// SupportUOT implements C.ProxyAdapter
@@ -130,20 +128,12 @@ func (m *Mieru) ensureClientIsRunning() error {
}
// Create a dialer and add it to the client config, before starting the client.
var dialer C.Dialer = dialer.NewDialer(m.DialOptions()...)
var err error
if len(m.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(m.option.DialerProxy, dialer)
if err != nil {
return err
}
}
config, err := m.client.Load()
if err != nil {
return err
}
config.Dialer = dialer
config.PacketDialer = mieruPacketDialer{Dialer: dialer}
config.Dialer = m.dialer
config.PacketDialer = mieruPacketDialer{Dialer: m.dialer}
config.Resolver = mieruDNSResolver{prefer: m.prefer}
if err := m.client.Store(config); err != nil {
return err
@@ -177,16 +167,18 @@ func NewMieru(option MieruOption) (*Mieru, error) {
Base: &Base{
name: option.Name,
addr: addr,
iface: option.Interface,
tp: C.Mieru,
pdName: option.ProviderName,
udp: option.UDP,
xudp: false,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
prefer: option.IPVersion,
},
option: &option,
client: c,
}
outbound.dialer = option.NewDialer(outbound.DialOptions())
return outbound, nil
}

View File

@@ -17,6 +17,7 @@ type Reject struct {
}
type RejectOption struct {
BasicOption
Name string `proxy:"name"`
}

View File

@@ -8,8 +8,6 @@ import (
N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/structure"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/ntp"
gost "github.com/metacubex/mihomo/transport/gost-plugin"
@@ -116,6 +114,7 @@ type kcpTunOption struct {
AutoExpire int `obfs:"autoexpire,omitempty"`
ScavengeTTL int `obfs:"scavengettl,omitempty"`
MTU int `obfs:"mtu,omitempty"`
RateLimit int `obfs:"ratelimit,omitempty"`
SndWnd int `obfs:"sndwnd,omitempty"`
RcvWnd int `obfs:"rcvwnd,omitempty"`
DataShard int `obfs:"datashard,omitempty"`
@@ -130,6 +129,7 @@ type kcpTunOption struct {
SockBuf int `obfs:"sockbuf,omitempty"`
SmuxVer int `obfs:"smuxver,omitempty"`
SmuxBuf int `obfs:"smuxbuf,omitempty"`
FrameSize int `obfs:"framesize,omitempty"`
StreamBuf int `obfs:"streambuf,omitempty"`
KeepAlive int `obfs:"keepalive,omitempty"`
}
@@ -191,17 +191,6 @@ func (ss *ShadowSocks) StreamConnContext(ctx context.Context, c net.Conn, metada
// DialContext implements C.ProxyAdapter
func (ss *ShadowSocks) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) {
return ss.DialContextWithDialer(ctx, dialer.NewDialer(ss.DialOptions()...), metadata)
}
// DialContextWithDialer implements C.ProxyAdapter
func (ss *ShadowSocks) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
if len(ss.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(ss.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
var c net.Conn
if ss.kcptunClient != nil {
c, err = ss.kcptunClient.OpenStream(ctx, func(ctx context.Context) (net.PacketConn, net.Addr, error) {
@@ -213,7 +202,7 @@ func (ss *ShadowSocks) DialContextWithDialer(ctx context.Context, dialer C.Diale
return nil, nil, err
}
pc, err := dialer.ListenPacket(ctx, "udp", "", addr.AddrPort())
pc, err := ss.dialer.ListenPacket(ctx, "udp", "", addr.AddrPort())
if err != nil {
return nil, nil, err
}
@@ -221,7 +210,7 @@ func (ss *ShadowSocks) DialContextWithDialer(ctx context.Context, dialer C.Diale
return pc, addr, nil
})
} else {
c, err = dialer.DialContext(ctx, "tcp", ss.addr)
c, err = ss.dialer.DialContext(ctx, "tcp", ss.addr)
}
if err != nil {
return nil, fmt.Errorf("%s connect error: %w", ss.addr, err)
@@ -237,25 +226,14 @@ func (ss *ShadowSocks) DialContextWithDialer(ctx context.Context, dialer C.Diale
// ListenPacketContext implements C.ProxyAdapter
func (ss *ShadowSocks) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) {
return ss.ListenPacketWithDialer(ctx, dialer.NewDialer(ss.DialOptions()...), metadata)
}
// ListenPacketWithDialer implements C.ProxyAdapter
func (ss *ShadowSocks) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) {
if ss.option.UDPOverTCP {
tcpConn, err := ss.DialContextWithDialer(ctx, dialer, metadata)
tcpConn, err := ss.DialContext(ctx, metadata)
if err != nil {
return nil, err
}
return ss.ListenPacketOnStreamConn(ctx, tcpConn, metadata)
}
if len(ss.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(ss.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
if err = ss.ResolveUDP(ctx, metadata); err != nil {
if err := ss.ResolveUDP(ctx, metadata); err != nil {
return nil, err
}
addr, err := resolveUDPAddr(ctx, "udp", ss.addr, ss.prefer)
@@ -263,7 +241,7 @@ func (ss *ShadowSocks) ListenPacketWithDialer(ctx context.Context, dialer C.Dial
return nil, err
}
pc, err := dialer.ListenPacket(ctx, "udp", "", addr.AddrPort())
pc, err := ss.dialer.ListenPacket(ctx, "udp", "", addr.AddrPort())
if err != nil {
return nil, err
}
@@ -271,11 +249,6 @@ func (ss *ShadowSocks) ListenPacketWithDialer(ctx context.Context, dialer C.Dial
return newPacketConn(pc, ss), nil
}
// SupportWithDialer implements C.ProxyAdapter
func (ss *ShadowSocks) SupportWithDialer() C.NetWork {
return C.ALLNet
}
// ProxyInfo implements C.ProxyAdapter
func (ss *ShadowSocks) ProxyInfo() C.ProxyInfo {
info := ss.Base.ProxyInfo()
@@ -455,6 +428,7 @@ func NewShadowSocks(option ShadowSocksOption) (*ShadowSocks, error) {
AutoExpire: kcptunOpt.AutoExpire,
ScavengeTTL: kcptunOpt.ScavengeTTL,
MTU: kcptunOpt.MTU,
RateLimit: kcptunOpt.RateLimit,
SndWnd: kcptunOpt.SndWnd,
RcvWnd: kcptunOpt.RcvWnd,
DataShard: kcptunOpt.DataShard,
@@ -469,6 +443,7 @@ func NewShadowSocks(option ShadowSocksOption) (*ShadowSocks, error) {
SockBuf: kcptunOpt.SockBuf,
SmuxVer: kcptunOpt.SmuxVer,
SmuxBuf: kcptunOpt.SmuxBuf,
FrameSize: kcptunOpt.FrameSize,
StreamBuf: kcptunOpt.StreamBuf,
KeepAlive: kcptunOpt.KeepAlive,
})
@@ -482,17 +457,18 @@ func NewShadowSocks(option ShadowSocksOption) (*ShadowSocks, error) {
return nil, fmt.Errorf("ss %s unknown udp over tcp protocol version: %d", addr, option.UDPOverTCPVersion)
}
return &ShadowSocks{
outbound := &ShadowSocks{
Base: &Base{
name: option.Name,
addr: addr,
tp: C.Shadowsocks,
pdName: option.ProviderName,
udp: option.UDP,
tfo: option.TFO,
mpTcp: option.MPTCP,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
prefer: option.IPVersion,
},
method: method,
@@ -504,5 +480,7 @@ func NewShadowSocks(option ShadowSocksOption) (*ShadowSocks, error) {
shadowTLSOption: shadowTLSOpt,
restlsConfig: restlsConfig,
kcptunClient: kcptunClient,
}, nil
}
outbound.dialer = option.NewDialer(outbound.DialOptions())
return outbound, nil
}

View File

@@ -8,8 +8,6 @@ import (
"strconv"
N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/shadowsocks/core"
"github.com/metacubex/mihomo/transport/shadowsocks/shadowaead"
@@ -68,18 +66,7 @@ func (ssr *ShadowSocksR) StreamConnContext(ctx context.Context, c net.Conn, meta
// DialContext implements C.ProxyAdapter
func (ssr *ShadowSocksR) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) {
return ssr.DialContextWithDialer(ctx, dialer.NewDialer(ssr.DialOptions()...), metadata)
}
// DialContextWithDialer implements C.ProxyAdapter
func (ssr *ShadowSocksR) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
if len(ssr.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(ssr.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
c, err := dialer.DialContext(ctx, "tcp", ssr.addr)
c, err := ssr.dialer.DialContext(ctx, "tcp", ssr.addr)
if err != nil {
return nil, fmt.Errorf("%s connect error: %w", ssr.addr, err)
}
@@ -94,18 +81,7 @@ func (ssr *ShadowSocksR) DialContextWithDialer(ctx context.Context, dialer C.Dia
// ListenPacketContext implements C.ProxyAdapter
func (ssr *ShadowSocksR) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) {
return ssr.ListenPacketWithDialer(ctx, dialer.NewDialer(ssr.DialOptions()...), metadata)
}
// ListenPacketWithDialer implements C.ProxyAdapter
func (ssr *ShadowSocksR) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) {
if len(ssr.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(ssr.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
if err = ssr.ResolveUDP(ctx, metadata); err != nil {
if err := ssr.ResolveUDP(ctx, metadata); err != nil {
return nil, err
}
addr, err := resolveUDPAddr(ctx, "udp", ssr.addr, ssr.prefer)
@@ -113,7 +89,7 @@ func (ssr *ShadowSocksR) ListenPacketWithDialer(ctx context.Context, dialer C.Di
return nil, err
}
pc, err := dialer.ListenPacket(ctx, "udp", "", addr.AddrPort())
pc, err := ssr.dialer.ListenPacket(ctx, "udp", "", addr.AddrPort())
if err != nil {
return nil, err
}
@@ -123,11 +99,6 @@ func (ssr *ShadowSocksR) ListenPacketWithDialer(ctx context.Context, dialer C.Di
return newPacketConn(&ssrPacketConn{EnhancePacketConn: epc, rAddr: addr}, ssr), nil
}
// SupportWithDialer implements C.ProxyAdapter
func (ssr *ShadowSocksR) SupportWithDialer() C.NetWork {
return C.ALLNet
}
// ProxyInfo implements C.ProxyAdapter
func (ssr *ShadowSocksR) ProxyInfo() C.ProxyInfo {
info := ssr.Base.ProxyInfo()
@@ -186,23 +157,26 @@ func NewShadowSocksR(option ShadowSocksROption) (*ShadowSocksR, error) {
return nil, fmt.Errorf("ssr %s initialize protocol error: %w", addr, err)
}
return &ShadowSocksR{
outbound := &ShadowSocksR{
Base: &Base{
name: option.Name,
addr: addr,
tp: C.ShadowsocksR,
pdName: option.ProviderName,
udp: option.UDP,
tfo: option.TFO,
mpTcp: option.MPTCP,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
prefer: option.IPVersion,
},
option: &option,
cipher: coreCiph,
obfs: obfs,
protocol: protocol,
}, nil
}
outbound.dialer = option.NewDialer(outbound.DialOptions())
return outbound, nil
}
type ssrPacketConn struct {

View File

@@ -3,8 +3,7 @@ package outbound
import (
"context"
CN "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/dialer"
N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log"
@@ -17,7 +16,6 @@ import (
type SingMux struct {
ProxyAdapter
client *mux.Client
dialer proxydialer.SingDialer
onlyTcp bool
}
@@ -61,7 +59,7 @@ func (s *SingMux) ListenPacketContext(ctx context.Context, metadata *C.Metadata)
if pc == nil {
return nil, E.New("packetConn is nil")
}
return newPacketConn(CN.NewThreadSafePacketConn(pc), s), nil
return newPacketConn(N.NewThreadSafePacketConn(pc), s), nil
}
func (s *SingMux) SupportUDP() bool {
@@ -96,7 +94,7 @@ func NewSingMux(option SingMuxOption, proxy ProxyAdapter) (ProxyAdapter, error)
// TODO
// "TCP Brutal is only supported on Linux-based systems"
singDialer := proxydialer.NewSingDialer(proxy, dialer.NewDialer(proxy.DialOptions()...), option.Statistic)
singDialer := proxydialer.NewSingDialer(proxydialer.New(proxy, option.Statistic))
client, err := mux.NewClient(mux.Options{
Dialer: singDialer,
Logger: log.SingLogger,
@@ -117,7 +115,6 @@ func NewSingMux(option SingMuxOption, proxy ProxyAdapter) (ProxyAdapter, error)
outbound := &SingMux{
ProxyAdapter: proxy,
client: client,
dialer: singDialer,
onlyTcp: option.OnlyTcp,
}
return outbound, nil

View File

@@ -8,8 +8,6 @@ import (
N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/structure"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant"
obfs "github.com/metacubex/mihomo/transport/simple-obfs"
"github.com/metacubex/mihomo/transport/snell"
@@ -89,18 +87,7 @@ func (s *Snell) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn
return NewConn(c, s), err
}
return s.DialContextWithDialer(ctx, dialer.NewDialer(s.DialOptions()...), metadata)
}
// DialContextWithDialer implements C.ProxyAdapter
func (s *Snell) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
if len(s.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(s.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
c, err := dialer.DialContext(ctx, "tcp", s.addr)
c, err := s.dialer.DialContext(ctx, "tcp", s.addr)
if err != nil {
return nil, fmt.Errorf("%s connect error: %w", s.addr, err)
}
@@ -115,22 +102,11 @@ func (s *Snell) DialContextWithDialer(ctx context.Context, dialer C.Dialer, meta
// ListenPacketContext implements C.ProxyAdapter
func (s *Snell) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) {
return s.ListenPacketWithDialer(ctx, dialer.NewDialer(s.DialOptions()...), metadata)
}
// ListenPacketWithDialer implements C.ProxyAdapter
func (s *Snell) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (C.PacketConn, error) {
var err error
if len(s.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(s.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
if err = s.ResolveUDP(ctx, metadata); err != nil {
return nil, err
}
c, err := dialer.DialContext(ctx, "tcp", s.addr)
c, err := s.dialer.DialContext(ctx, "tcp", s.addr)
if err != nil {
return nil, err
}
@@ -141,11 +117,6 @@ func (s *Snell) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, met
return newPacketConn(pc, s), nil
}
// SupportWithDialer implements C.ProxyAdapter
func (s *Snell) SupportWithDialer() C.NetWork {
return C.ALLNet
}
// SupportUOT implements C.ProxyAdapter
func (s *Snell) SupportUOT() bool {
return true
@@ -194,30 +165,24 @@ func NewSnell(option SnellOption) (*Snell, error) {
name: option.Name,
addr: addr,
tp: C.Snell,
pdName: option.ProviderName,
udp: option.UDP,
tfo: option.TFO,
mpTcp: option.MPTCP,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
prefer: option.IPVersion,
},
option: &option,
psk: psk,
obfsOption: obfsOption,
version: option.Version,
}
s.dialer = option.NewDialer(s.DialOptions())
if option.Version == snell.Version2 {
s.pool = snell.NewPool(func(ctx context.Context) (*snell.Snell, error) {
var err error
var cDialer C.Dialer = dialer.NewDialer(s.DialOptions()...)
if len(s.option.DialerProxy) > 0 {
cDialer, err = proxydialer.NewByName(s.option.DialerProxy, cDialer)
if err != nil {
return nil, err
}
}
c, err := cDialer.DialContext(ctx, "tcp", addr)
c, err := s.dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return nil, err
}

View File

@@ -2,7 +2,6 @@ package outbound
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
@@ -12,10 +11,10 @@ import (
N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/socks5"
"github.com/metacubex/tls"
)
type Socks5 struct {
@@ -69,18 +68,7 @@ func (ss *Socks5) StreamConnContext(ctx context.Context, c net.Conn, metadata *C
// DialContext implements C.ProxyAdapter
func (ss *Socks5) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) {
return ss.DialContextWithDialer(ctx, dialer.NewDialer(ss.DialOptions()...), metadata)
}
// DialContextWithDialer implements C.ProxyAdapter
func (ss *Socks5) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
if len(ss.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(ss.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
c, err := dialer.DialContext(ctx, "tcp", ss.addr)
c, err := ss.dialer.DialContext(ctx, "tcp", ss.addr)
if err != nil {
return nil, fmt.Errorf("%s connect error: %w", ss.addr, err)
}
@@ -97,24 +85,12 @@ func (ss *Socks5) DialContextWithDialer(ctx context.Context, dialer C.Dialer, me
return NewConn(c, ss), nil
}
// SupportWithDialer implements C.ProxyAdapter
func (ss *Socks5) SupportWithDialer() C.NetWork {
return C.TCP
}
// ListenPacketContext implements C.ProxyAdapter
func (ss *Socks5) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) {
var cDialer C.Dialer = dialer.NewDialer(ss.DialOptions()...)
if len(ss.option.DialerProxy) > 0 {
cDialer, err = proxydialer.NewByName(ss.option.DialerProxy, cDialer)
if err != nil {
return nil, err
}
}
if err = ss.ResolveUDP(ctx, metadata); err != nil {
return nil, err
}
c, err := cDialer.DialContext(ctx, "tcp", ss.addr)
c, err := ss.dialer.DialContext(ctx, "tcp", ss.addr)
if err != nil {
err = fmt.Errorf("%s connect error: %w", ss.addr, err)
return
@@ -161,7 +137,7 @@ func (ss *Socks5) ListenPacketContext(ctx context.Context, metadata *C.Metadata)
bindUDPAddr.IP = serverAddr.IP
}
pc, err := cDialer.ListenPacket(ctx, "udp", "", bindUDPAddr.AddrPort())
pc, err := ss.dialer.ListenPacket(ctx, "udp", "", bindUDPAddr.AddrPort())
if err != nil {
return
}
@@ -210,17 +186,18 @@ func NewSocks5(option Socks5Option) (*Socks5, error) {
}
}
return &Socks5{
outbound := &Socks5{
Base: &Base{
name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.Socks5,
pdName: option.ProviderName,
udp: option.UDP,
tfo: option.TFO,
mpTcp: option.MPTCP,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
prefer: option.IPVersion,
},
option: &option,
user: option.UserName,
@@ -228,7 +205,9 @@ func NewSocks5(option Socks5Option) (*Socks5, error) {
tls: option.TLS,
skipCertVerify: option.SkipCertVerify,
tlsConfig: tlsConfig,
}, nil
}
outbound.dialer = option.NewDialer(outbound.DialOptions())
return outbound, nil
}
type socksPacketConn struct {

View File

@@ -12,8 +12,6 @@ import (
"sync"
N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/randv2"
@@ -44,14 +42,7 @@ type SshOption struct {
}
func (s *Ssh) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) {
var cDialer C.Dialer = dialer.NewDialer(s.DialOptions()...)
if len(s.option.DialerProxy) > 0 {
cDialer, err = proxydialer.NewByName(s.option.DialerProxy, cDialer)
if err != nil {
return nil, err
}
}
client, err := s.connect(ctx, cDialer, s.addr)
client, err := s.connect(ctx, s.addr)
if err != nil {
return nil, err
}
@@ -63,13 +54,13 @@ func (s *Ssh) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn,
return NewConn(c, s), nil
}
func (s *Ssh) connect(ctx context.Context, cDialer C.Dialer, addr string) (client *ssh.Client, err error) {
func (s *Ssh) connect(ctx context.Context, addr string) (client *ssh.Client, err error) {
s.cMutex.Lock()
defer s.cMutex.Unlock()
if s.client != nil {
return s.client, nil
}
c, err := cDialer.DialContext(ctx, "tcp", addr)
c, err := s.dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return nil, err
}
@@ -195,14 +186,15 @@ func NewSsh(option SshOption) (*Ssh, error) {
name: option.Name,
addr: addr,
tp: C.Ssh,
pdName: option.ProviderName,
udp: false,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
prefer: option.IPVersion,
},
option: &option,
config: &config,
}
outbound.dialer = option.NewDialer(outbound.DialOptions())
return outbound, nil
}

View File

@@ -2,102 +2,111 @@ package outbound
import (
"context"
"crypto/sha256"
"encoding/binary"
"fmt"
"io"
"net"
"strconv"
"strings"
"time"
"github.com/metacubex/mihomo/log"
"github.com/saba-futai/sudoku/apis"
"github.com/saba-futai/sudoku/pkg/crypto"
"github.com/saba-futai/sudoku/pkg/obfs/httpmask"
"github.com/saba-futai/sudoku/pkg/obfs/sudoku"
"sync"
N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/sudoku"
)
type Sudoku struct {
*Base
option *SudokuOption
table *sudoku.Table
baseConf apis.ProtocolConfig
baseConf sudoku.ProtocolConfig
httpMaskMu sync.Mutex
httpMaskClient *sudoku.HTTPMaskTunnelClient
muxMu sync.Mutex
muxClient *sudoku.MultiplexClient
}
type SudokuOption struct {
BasicOption
Name string `proxy:"name"`
Server string `proxy:"server"`
Port int `proxy:"port"`
Key string `proxy:"key"`
AEADMethod string `proxy:"aead-method,omitempty"`
PaddingMin *int `proxy:"padding-min,omitempty"`
PaddingMax *int `proxy:"padding-max,omitempty"`
TableType string `proxy:"table-type,omitempty"` // "prefer_ascii" or "prefer_entropy"
HTTPMask bool `proxy:"http-mask,omitempty"`
Name string `proxy:"name"`
Server string `proxy:"server"`
Port int `proxy:"port"`
Key string `proxy:"key"`
AEADMethod string `proxy:"aead-method,omitempty"`
PaddingMin *int `proxy:"padding-min,omitempty"`
PaddingMax *int `proxy:"padding-max,omitempty"`
TableType string `proxy:"table-type,omitempty"` // "prefer_ascii" or "prefer_entropy"
EnablePureDownlink *bool `proxy:"enable-pure-downlink,omitempty"`
HTTPMask bool `proxy:"http-mask,omitempty"`
HTTPMaskMode string `proxy:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto"
HTTPMaskTLS bool `proxy:"http-mask-tls,omitempty"` // only for http-mask-mode stream/poll/auto
HTTPMaskHost string `proxy:"http-mask-host,omitempty"` // optional Host/SNI override (domain or domain:port)
PathRoot string `proxy:"path-root,omitempty"` // optional first-level path prefix for HTTP tunnel endpoints
HTTPMaskMultiplex string `proxy:"http-mask-multiplex,omitempty"` // "off" (default), "auto" (reuse h1/h2), "on" (single tunnel, multi-target)
CustomTable string `proxy:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv
CustomTables []string `proxy:"custom-tables,omitempty"` // optional table rotation patterns, overrides custom-table when non-empty
}
// DialContext implements C.ProxyAdapter
func (s *Sudoku) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
return s.DialContextWithDialer(ctx, dialer.NewDialer(s.DialOptions()...), metadata)
}
// DialContextWithDialer implements C.ProxyAdapter
func (s *Sudoku) DialContextWithDialer(ctx context.Context, d C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
if len(s.option.DialerProxy) > 0 {
d, err = proxydialer.NewByName(s.option.DialerProxy, d)
if err != nil {
return nil, err
}
}
func (s *Sudoku) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) {
cfg, err := s.buildConfig(metadata)
if err != nil {
return nil, err
}
c, err := d.DialContext(ctx, "tcp", s.addr)
if err != nil {
return nil, fmt.Errorf("%s connect error: %w", s.addr, err)
muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex)
if muxMode == "on" && !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) {
stream, muxErr := s.dialMultiplex(ctx, cfg.TargetAddress)
if muxErr == nil {
return NewConn(stream, s), nil
}
return nil, muxErr
}
defer func() {
safeConnClose(c, err)
}()
if ctx.Done() != nil {
done := N.SetupContextForConn(ctx, c)
defer done(&err)
}
c, err = s.streamConn(c, cfg)
c, err := s.dialAndHandshake(ctx, cfg)
if err != nil {
return nil, err
}
defer func() { safeConnClose(c, err) }()
addrBuf, err := sudoku.EncodeAddress(cfg.TargetAddress)
if err != nil {
return nil, fmt.Errorf("encode target address failed: %w", err)
}
if _, err = c.Write(addrBuf); err != nil {
return nil, fmt.Errorf("send target address failed: %w", err)
}
return NewConn(c, s), nil
}
// ListenPacketContext implements C.ProxyAdapter
func (s *Sudoku) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) {
return nil, C.ErrNotSupport
if err := s.ResolveUDP(ctx, metadata); err != nil {
return nil, err
}
cfg, err := s.buildConfig(metadata)
if err != nil {
return nil, err
}
c, err := s.dialAndHandshake(ctx, cfg)
if err != nil {
return nil, err
}
if err = sudoku.WritePreface(c); err != nil {
_ = c.Close()
return nil, fmt.Errorf("send uot preface failed: %w", err)
}
return newPacketConn(N.NewThreadSafePacketConn(sudoku.NewUoTPacketConn(c)), s), nil
}
// SupportUOT implements C.ProxyAdapter
func (s *Sudoku) SupportUOT() bool {
return false // Sudoku protocol only supports TCP
}
// SupportWithDialer implements C.ProxyAdapter
func (s *Sudoku) SupportWithDialer() C.NetWork {
return C.TCP
return true
}
// ProxyInfo implements C.ProxyAdapter
@@ -107,7 +116,7 @@ func (s *Sudoku) ProxyInfo() C.ProxyInfo {
return info
}
func (s *Sudoku) buildConfig(metadata *C.Metadata) (*apis.ProtocolConfig, error) {
func (s *Sudoku) buildConfig(metadata *C.Metadata) (*sudoku.ProtocolConfig, error) {
if metadata == nil || metadata.DstPort == 0 || !metadata.Valid() {
return nil, fmt.Errorf("invalid metadata for sudoku outbound")
}
@@ -121,33 +130,6 @@ func (s *Sudoku) buildConfig(metadata *C.Metadata) (*apis.ProtocolConfig, error)
return &cfg, nil
}
func (s *Sudoku) streamConn(rawConn net.Conn, cfg *apis.ProtocolConfig) (_ net.Conn, err error) {
if !cfg.DisableHTTPMask {
if err = httpmask.WriteRandomRequestHeader(rawConn, cfg.ServerAddress); err != nil {
return nil, fmt.Errorf("write http mask failed: %w", err)
}
}
obfsConn := sudoku.NewConn(rawConn, cfg.Table, cfg.PaddingMin, cfg.PaddingMax, false)
cConn, err := crypto.NewAEADConn(obfsConn, cfg.Key, cfg.AEADMethod)
if err != nil {
return nil, fmt.Errorf("setup crypto failed: %w", err)
}
handshake := buildSudokuHandshakePayload(cfg.Key)
if _, err = cConn.Write(handshake[:]); err != nil {
cConn.Close()
return nil, fmt.Errorf("send handshake failed: %w", err)
}
if err = writeTargetAddress(cConn, cfg.TargetAddress); err != nil {
cConn.Close()
return nil, fmt.Errorf("send target address failed: %w", err)
}
return cConn, nil
}
func NewSudoku(option SudokuOption) (*Sudoku, error) {
if option.Server == "" {
return nil, fmt.Errorf("server is required")
@@ -167,16 +149,7 @@ func NewSudoku(option SudokuOption) (*Sudoku, error) {
return nil, fmt.Errorf("table-type must be prefer_ascii or prefer_entropy")
}
seed := option.Key
if recoveredFromKey, err := crypto.RecoverPublicKey(option.Key); err == nil {
seed = crypto.EncodePoint(recoveredFromKey)
}
start := time.Now()
table := sudoku.NewTable(seed, tableType)
log.Infoln("[Sudoku] Tables initialized (%s) in %v", tableType, time.Since(start))
defaultConf := apis.DefaultConfig()
defaultConf := sudoku.DefaultConfig()
paddingMin := defaultConf.PaddingMin
paddingMax := defaultConf.PaddingMax
if option.PaddingMin != nil {
@@ -191,80 +164,242 @@ func NewSudoku(option SudokuOption) (*Sudoku, error) {
if option.PaddingMax == nil && option.PaddingMin != nil && paddingMax < paddingMin {
paddingMax = paddingMin
}
enablePureDownlink := defaultConf.EnablePureDownlink
if option.EnablePureDownlink != nil {
enablePureDownlink = *option.EnablePureDownlink
}
baseConf := apis.ProtocolConfig{
baseConf := sudoku.ProtocolConfig{
ServerAddress: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
Key: option.Key,
AEADMethod: defaultConf.AEADMethod,
Table: table,
PaddingMin: paddingMin,
PaddingMax: paddingMax,
EnablePureDownlink: enablePureDownlink,
HandshakeTimeoutSeconds: defaultConf.HandshakeTimeoutSeconds,
DisableHTTPMask: !option.HTTPMask,
HTTPMaskMode: defaultConf.HTTPMaskMode,
HTTPMaskTLSEnabled: option.HTTPMaskTLS,
HTTPMaskHost: option.HTTPMaskHost,
HTTPMaskPathRoot: strings.TrimSpace(option.PathRoot),
HTTPMaskMultiplex: defaultConf.HTTPMaskMultiplex,
}
if option.HTTPMaskMode != "" {
baseConf.HTTPMaskMode = option.HTTPMaskMode
}
if option.HTTPMaskMultiplex != "" {
baseConf.HTTPMaskMultiplex = option.HTTPMaskMultiplex
}
tables, err := sudoku.NewTablesWithCustomPatterns(sudoku.ClientAEADSeed(option.Key), tableType, option.CustomTable, option.CustomTables)
if err != nil {
return nil, fmt.Errorf("build table(s) failed: %w", err)
}
if len(tables) == 1 {
baseConf.Table = tables[0]
} else {
baseConf.Tables = tables
}
if option.AEADMethod != "" {
baseConf.AEADMethod = option.AEADMethod
}
return &Sudoku{
outbound := &Sudoku{
Base: &Base{
name: option.Name,
addr: baseConf.ServerAddress,
tp: C.Sudoku,
udp: false,
pdName: option.ProviderName,
udp: true,
tfo: option.TFO,
mpTcp: option.MPTCP,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
prefer: option.IPVersion,
},
option: &option,
table: table,
baseConf: baseConf,
}, nil
}
outbound.dialer = option.NewDialer(outbound.DialOptions())
return outbound, nil
}
func buildSudokuHandshakePayload(key string) [16]byte {
var payload [16]byte
binary.BigEndian.PutUint64(payload[:8], uint64(time.Now().Unix()))
hash := sha256.Sum256([]byte(key))
copy(payload[8:], hash[:8])
return payload
func (s *Sudoku) Close() error {
s.resetMuxClient()
s.resetHTTPMaskClient()
return s.Base.Close()
}
func writeTargetAddress(w io.Writer, rawAddr string) error {
host, portStr, err := net.SplitHostPort(rawAddr)
if err != nil {
return err
func normalizeHTTPMaskMultiplex(mode string) string {
switch strings.ToLower(strings.TrimSpace(mode)) {
case "", "off":
return "off"
case "auto":
return "auto"
case "on":
return "on"
default:
return "off"
}
}
func httpTunnelModeEnabled(mode string) bool {
switch strings.ToLower(strings.TrimSpace(mode)) {
case "stream", "poll", "auto":
return true
default:
return false
}
}
func (s *Sudoku) dialAndHandshake(ctx context.Context, cfg *sudoku.ProtocolConfig) (_ net.Conn, err error) {
if cfg == nil {
return nil, fmt.Errorf("config is required")
}
portInt, err := net.LookupPort("tcp", portStr)
if err != nil {
return err
handshakeCfg := *cfg
if !handshakeCfg.DisableHTTPMask && httpTunnelModeEnabled(handshakeCfg.HTTPMaskMode) {
handshakeCfg.DisableHTTPMask = true
}
var buf []byte
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
buf = append(buf, 0x01) // IPv4
buf = append(buf, ip4...)
} else {
buf = append(buf, 0x04) // IPv6
buf = append(buf, ip...)
upgrade := func(raw net.Conn) (net.Conn, error) {
return sudoku.ClientHandshake(raw, &handshakeCfg)
}
var (
c net.Conn
handshakeDone bool
)
if !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) {
muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex)
switch muxMode {
case "auto", "on":
client, errX := s.getOrCreateHTTPMaskClient(cfg)
if errX != nil {
return nil, errX
}
c, err = client.Dial(ctx, upgrade)
default:
c, err = sudoku.DialHTTPMaskTunnel(ctx, cfg.ServerAddress, cfg, s.dialer.DialContext, upgrade)
}
} else {
if len(host) > 255 {
return fmt.Errorf("domain too long")
if err == nil && c != nil {
handshakeDone = true
}
buf = append(buf, 0x03) // domain
buf = append(buf, byte(len(host)))
buf = append(buf, host...)
}
if c == nil && err == nil {
c, err = s.dialer.DialContext(ctx, "tcp", s.addr)
}
if err != nil {
return nil, fmt.Errorf("%s connect error: %w", s.addr, err)
}
var portBytes [2]byte
binary.BigEndian.PutUint16(portBytes[:], uint16(portInt))
buf = append(buf, portBytes[:]...)
defer func() { safeConnClose(c, err) }()
_, err = w.Write(buf)
return err
if ctx.Done() != nil {
done := N.SetupContextForConn(ctx, c)
defer done(&err)
}
if !handshakeDone {
c, err = sudoku.ClientHandshake(c, &handshakeCfg)
if err != nil {
return nil, err
}
}
return c, nil
}
func (s *Sudoku) dialMultiplex(ctx context.Context, targetAddress string) (net.Conn, error) {
for attempt := 0; attempt < 2; attempt++ {
client, err := s.getOrCreateMuxClient(ctx)
if err != nil {
return nil, err
}
stream, err := client.Dial(ctx, targetAddress)
if err != nil {
s.resetMuxClient()
continue
}
return stream, nil
}
return nil, fmt.Errorf("multiplex open stream failed")
}
func (s *Sudoku) getOrCreateMuxClient(ctx context.Context) (*sudoku.MultiplexClient, error) {
if s == nil {
return nil, fmt.Errorf("nil adapter")
}
s.muxMu.Lock()
if s.muxClient != nil && !s.muxClient.IsClosed() {
client := s.muxClient
s.muxMu.Unlock()
return client, nil
}
s.muxMu.Unlock()
s.muxMu.Lock()
defer s.muxMu.Unlock()
if s.muxClient != nil && !s.muxClient.IsClosed() {
return s.muxClient, nil
}
baseCfg := s.baseConf
baseConn, err := s.dialAndHandshake(ctx, &baseCfg)
if err != nil {
return nil, err
}
client, err := sudoku.StartMultiplexClient(baseConn)
if err != nil {
_ = baseConn.Close()
return nil, err
}
s.muxClient = client
return client, nil
}
func (s *Sudoku) resetMuxClient() {
s.muxMu.Lock()
defer s.muxMu.Unlock()
if s.muxClient != nil {
_ = s.muxClient.Close()
s.muxClient = nil
}
}
func (s *Sudoku) getOrCreateHTTPMaskClient(cfg *sudoku.ProtocolConfig) (*sudoku.HTTPMaskTunnelClient, error) {
if s == nil {
return nil, fmt.Errorf("nil adapter")
}
if cfg == nil {
return nil, fmt.Errorf("config is required")
}
s.httpMaskMu.Lock()
defer s.httpMaskMu.Unlock()
if s.httpMaskClient != nil {
return s.httpMaskClient, nil
}
c, err := sudoku.NewHTTPMaskTunnelClient(cfg.ServerAddress, cfg, s.dialer.DialContext)
if err != nil {
return nil, err
}
s.httpMaskClient = c
return c, nil
}
func (s *Sudoku) resetHTTPMaskClient() {
s.httpMaskMu.Lock()
defer s.httpMaskMu.Unlock()
if s.httpMaskClient != nil {
s.httpMaskClient.CloseIdleConnections()
s.httpMaskClient = nil
}
}

View File

@@ -2,24 +2,23 @@ package outbound
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"strconv"
N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/ech"
"github.com/metacubex/mihomo/component/proxydialer"
tlsC "github.com/metacubex/mihomo/component/tls"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/gun"
"github.com/metacubex/mihomo/transport/shadowsocks/core"
"github.com/metacubex/mihomo/transport/trojan"
"github.com/metacubex/mihomo/transport/vmess"
"github.com/metacubex/http"
"github.com/metacubex/tls"
)
type Trojan struct {
@@ -196,18 +195,7 @@ func (t *Trojan) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Con
return NewConn(c, t), nil
}
return t.DialContextWithDialer(ctx, dialer.NewDialer(t.DialOptions()...), metadata)
}
// DialContextWithDialer implements C.ProxyAdapter
func (t *Trojan) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
if len(t.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(t.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
c, err := dialer.DialContext(ctx, "tcp", t.addr)
c, err = t.dialer.DialContext(ctx, "tcp", t.addr)
if err != nil {
return nil, fmt.Errorf("%s connect error: %w", t.addr, err)
}
@@ -250,21 +238,10 @@ func (t *Trojan) ListenPacketContext(ctx context.Context, metadata *C.Metadata)
pc := trojan.NewPacketConn(c)
return newPacketConn(pc, t), err
}
return t.ListenPacketWithDialer(ctx, dialer.NewDialer(t.DialOptions()...), metadata)
}
// ListenPacketWithDialer implements C.ProxyAdapter
func (t *Trojan) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) {
if len(t.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(t.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
if err = t.ResolveUDP(ctx, metadata); err != nil {
return nil, err
}
c, err := dialer.DialContext(ctx, "tcp", t.addr)
c, err = t.dialer.DialContext(ctx, "tcp", t.addr)
if err != nil {
return nil, fmt.Errorf("%s connect error: %w", t.addr, err)
}
@@ -280,11 +257,6 @@ func (t *Trojan) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, me
return newPacketConn(pc, t), err
}
// SupportWithDialer implements C.ProxyAdapter
func (t *Trojan) SupportWithDialer() C.NetWork {
return C.ALLNet
}
// SupportUOT implements C.ProxyAdapter
func (t *Trojan) SupportUOT() bool {
return true
@@ -317,16 +289,18 @@ func NewTrojan(option TrojanOption) (*Trojan, error) {
name: option.Name,
addr: addr,
tp: C.Trojan,
pdName: option.ProviderName,
udp: option.UDP,
tfo: option.TFO,
mpTcp: option.MPTCP,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
prefer: option.IPVersion,
},
option: &option,
hexPassword: trojan.Key(option.Password),
}
t.dialer = option.NewDialer(t.DialOptions())
var err error
t.realityConfig, err = option.RealityOpts.Parse()
@@ -355,15 +329,7 @@ func NewTrojan(option TrojanOption) (*Trojan, error) {
if option.Network == "grpc" {
dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) {
var err error
var cDialer C.Dialer = dialer.NewDialer(t.DialOptions()...)
if len(t.option.DialerProxy) > 0 {
cDialer, err = proxydialer.NewByName(t.option.DialerProxy, cDialer)
if err != nil {
return nil, err
}
}
c, err := cDialer.DialContext(ctx, "tcp", t.addr)
c, err := t.dialer.DialContext(ctx, "tcp", t.addr)
if err != nil {
return nil, fmt.Errorf("%s connect error: %s", t.addr, err.Error())
}
@@ -390,6 +356,7 @@ func NewTrojan(option TrojanOption) (*Trojan, error) {
t.gunTLSConfig = tlsConfig
t.gunConfig = &gun.Config{
ServiceName: option.GrpcOpts.GrpcServiceName,
UserAgent: option.GrpcOpts.GrpcUserAgent,
Host: option.SNI,
ClientFingerprint: option.ClientFingerprint,
}

View File

@@ -2,7 +2,6 @@ package outbound
import (
"context"
"crypto/tls"
"fmt"
"math"
"net"
@@ -10,10 +9,7 @@ import (
"time"
"github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/ech"
"github.com/metacubex/mihomo/component/proxydialer"
tlsC "github.com/metacubex/mihomo/component/tls"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/tuic"
@@ -21,6 +17,7 @@ import (
"github.com/metacubex/quic-go"
M "github.com/metacubex/sing/common/metadata"
"github.com/metacubex/sing/common/uot"
"github.com/metacubex/tls"
)
type Tuic struct {
@@ -28,7 +25,7 @@ type Tuic struct {
option *TuicOption
client *tuic.PoolClient
tlsConfig *tlsC.Config
tlsConfig *tls.Config
echConfig *ech.Config
}
@@ -70,12 +67,7 @@ type TuicOption struct {
// DialContext implements C.ProxyAdapter
func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
return t.DialContextWithDialer(ctx, dialer.NewDialer(t.DialOptions()...), metadata)
}
// DialContextWithDialer implements C.ProxyAdapter
func (t *Tuic) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (C.Conn, error) {
conn, err := t.client.DialContextWithDialer(ctx, metadata, dialer, t.dialWithDialer)
conn, err := t.client.DialContext(ctx, metadata)
if err != nil {
return nil, err
}
@@ -84,11 +76,6 @@ func (t *Tuic) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metad
// ListenPacketContext implements C.ProxyAdapter
func (t *Tuic) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) {
return t.ListenPacketWithDialer(ctx, dialer.NewDialer(t.DialOptions()...), metadata)
}
// ListenPacketWithDialer implements C.ProxyAdapter
func (t *Tuic) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) {
if err = t.ResolveUDP(ctx, metadata); err != nil {
return nil, err
}
@@ -98,7 +85,7 @@ func (t *Tuic) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, meta
uotMetadata := *metadata
uotMetadata.Host = uotDestination.Fqdn
uotMetadata.DstPort = uotDestination.Port
c, err := t.DialContextWithDialer(ctx, dialer, &uotMetadata)
c, err := t.DialContext(ctx, &uotMetadata)
if err != nil {
return nil, err
}
@@ -112,25 +99,14 @@ func (t *Tuic) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, meta
return newPacketConn(uot.NewLazyConn(c, uot.Request{Destination: destination}), t), nil
}
}
pc, err := t.client.ListenPacketWithDialer(ctx, metadata, dialer, t.dialWithDialer)
pc, err := t.client.ListenPacket(ctx, metadata)
if err != nil {
return nil, err
}
return newPacketConn(pc, t), nil
}
// SupportWithDialer implements C.ProxyAdapter
func (t *Tuic) SupportWithDialer() C.NetWork {
return C.ALLNet
}
func (t *Tuic) dialWithDialer(ctx context.Context, dialer C.Dialer) (transport *quic.Transport, addr net.Addr, err error) {
if len(t.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(t.option.DialerProxy, dialer)
if err != nil {
return nil, nil, err
}
}
func (t *Tuic) dial(ctx context.Context) (transport *quic.Transport, addr net.Addr, err error) {
udpAddr, err := resolveUDPAddr(ctx, "udp", t.addr, t.prefer)
if err != nil {
return nil, nil, err
@@ -141,7 +117,7 @@ func (t *Tuic) dialWithDialer(ctx context.Context, dialer C.Dialer) (transport *
}
addr = udpAddr
var pc net.PacketConn
pc, err = dialer.ListenPacket(ctx, "udp", "", udpAddr.AddrPort())
pc, err = t.dialer.ListenPacket(ctx, "udp", "", udpAddr.AddrPort())
if err != nil {
return nil, nil, err
}
@@ -256,7 +232,7 @@ func NewTuic(option TuicOption) (*Tuic, error) {
tlsConfig.InsecureSkipVerify = true // tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config
}
tlsClientConfig := tlsC.UConfig(tlsConfig)
tlsClientConfig := tlsConfig
echConfig, err := option.ECHOpts.Parse()
if err != nil {
return nil, err
@@ -275,16 +251,18 @@ func NewTuic(option TuicOption) (*Tuic, error) {
name: option.Name,
addr: addr,
tp: C.Tuic,
pdName: option.ProviderName,
udp: true,
tfo: option.FastOpen,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
prefer: option.IPVersion,
},
option: &option,
tlsConfig: tlsClientConfig,
echConfig: echConfig,
}
t.dialer = option.NewDialer(t.DialOptions())
clientMaxOpenStreams := int64(option.MaxOpenStreams)
@@ -313,7 +291,7 @@ func NewTuic(option TuicOption) (*Tuic, error) {
CWND: option.CWND,
}
t.client = tuic.NewPoolClientV4(clientOption)
t.client = tuic.NewPoolClientV4(clientOption, t.dial)
} else {
maxUdpRelayPacketSize := option.MaxUdpRelayPacketSize
if maxUdpRelayPacketSize > tuic.MaxFragSizeV5 {
@@ -332,7 +310,7 @@ func NewTuic(option TuicOption) (*Tuic, error) {
CWND: option.CWND,
}
t.client = tuic.NewPoolClientV5(clientOption)
t.client = tuic.NewPoolClientV5(clientOption, t.dial)
}
return t, nil

View File

@@ -2,19 +2,15 @@ package outbound
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"strconv"
"github.com/metacubex/mihomo/common/convert"
N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/ech"
"github.com/metacubex/mihomo/component/proxydialer"
tlsC "github.com/metacubex/mihomo/component/tls"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/gun"
@@ -22,9 +18,11 @@ import (
"github.com/metacubex/mihomo/transport/vless/encryption"
"github.com/metacubex/mihomo/transport/vmess"
"github.com/metacubex/http"
vmessSing "github.com/metacubex/sing-vmess"
"github.com/metacubex/sing-vmess/packetaddr"
M "github.com/metacubex/sing/common/metadata"
"github.com/metacubex/tls"
)
type Vless struct {
@@ -252,18 +250,7 @@ func (v *Vless) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn
return NewConn(c, v), nil
}
return v.DialContextWithDialer(ctx, dialer.NewDialer(v.DialOptions()...), metadata)
}
// DialContextWithDialer implements C.ProxyAdapter
func (v *Vless) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
if len(v.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(v.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
c, err := dialer.DialContext(ctx, "tcp", v.addr)
c, err = v.dialer.DialContext(ctx, "tcp", v.addr)
if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
}
@@ -301,23 +288,12 @@ func (v *Vless) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (
return v.ListenPacketOnStreamConn(ctx, c, metadata)
}
return v.ListenPacketWithDialer(ctx, dialer.NewDialer(v.DialOptions()...), metadata)
}
// ListenPacketWithDialer implements C.ProxyAdapter
func (v *Vless) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) {
if len(v.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(v.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
if err = v.ResolveUDP(ctx, metadata); err != nil {
return nil, err
}
c, err := dialer.DialContext(ctx, "tcp", v.addr)
c, err = v.dialer.DialContext(ctx, "tcp", v.addr)
if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
}
@@ -333,11 +309,6 @@ func (v *Vless) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, met
return v.ListenPacketOnStreamConn(ctx, c, metadata)
}
// SupportWithDialer implements C.ProxyAdapter
func (v *Vless) SupportWithDialer() C.NetWork {
return C.ALLNet
}
// ListenPacketOnStreamConn implements C.ProxyAdapter
func (v *Vless) ListenPacketOnStreamConn(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ C.PacketConn, err error) {
if err = v.ResolveUDP(ctx, metadata); err != nil {
@@ -446,17 +417,19 @@ func NewVless(option VlessOption) (*Vless, error) {
name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.Vless,
pdName: option.ProviderName,
udp: option.UDP,
xudp: option.XUDP,
tfo: option.TFO,
mpTcp: option.MPTCP,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
prefer: option.IPVersion,
},
client: client,
option: &option,
}
v.dialer = option.NewDialer(v.DialOptions())
v.encryption, err = encryption.NewClient(option.Encryption)
if err != nil {
@@ -480,15 +453,7 @@ func NewVless(option VlessOption) (*Vless, error) {
}
case "grpc":
dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) {
var err error
var cDialer C.Dialer = dialer.NewDialer(v.DialOptions()...)
if len(v.option.DialerProxy) > 0 {
cDialer, err = proxydialer.NewByName(v.option.DialerProxy, cDialer)
if err != nil {
return nil, err
}
}
c, err := cDialer.DialContext(ctx, "tcp", v.addr)
c, err := v.dialer.DialContext(ctx, "tcp", v.addr)
if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
}
@@ -497,6 +462,7 @@ func NewVless(option VlessOption) (*Vless, error) {
gunConfig := &gun.Config{
ServiceName: v.option.GrpcOpts.GrpcServiceName,
UserAgent: v.option.GrpcOpts.GrpcUserAgent,
Host: v.option.ServerName,
ClientFingerprint: v.option.ClientFingerprint,
}

View File

@@ -2,11 +2,9 @@ package outbound
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"strconv"
"strings"
"sync"
@@ -14,18 +12,18 @@ import (
N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/ech"
"github.com/metacubex/mihomo/component/proxydialer"
tlsC "github.com/metacubex/mihomo/component/tls"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/ntp"
"github.com/metacubex/mihomo/transport/gun"
mihomoVMess "github.com/metacubex/mihomo/transport/vmess"
"github.com/metacubex/http"
vmess "github.com/metacubex/sing-vmess"
"github.com/metacubex/sing-vmess/packetaddr"
M "github.com/metacubex/sing/common/metadata"
"github.com/metacubex/tls"
)
var ErrUDPRemoteAddrMismatch = errors.New("udp packet dropped due to mismatched remote address")
@@ -88,6 +86,7 @@ type HTTP2Options struct {
type GrpcOptions struct {
GrpcServiceName string `proxy:"grpc-service-name,omitempty"`
GrpcUserAgent string `proxy:"grpc-user-agent,omitempty"`
}
type WSOptions struct {
@@ -313,18 +312,7 @@ func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn
return NewConn(c, v), nil
}
return v.DialContextWithDialer(ctx, dialer.NewDialer(v.DialOptions()...), metadata)
}
// DialContextWithDialer implements C.ProxyAdapter
func (v *Vmess) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
if len(v.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(v.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
c, err := dialer.DialContext(ctx, "tcp", v.addr)
c, err = v.dialer.DialContext(ctx, "tcp", v.addr)
if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
}
@@ -358,23 +346,12 @@ func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (
}
return v.ListenPacketOnStreamConn(ctx, c, metadata)
}
return v.ListenPacketWithDialer(ctx, dialer.NewDialer(v.DialOptions()...), metadata)
}
// ListenPacketWithDialer implements C.ProxyAdapter
func (v *Vmess) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) {
if len(v.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(v.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
if err = v.ResolveUDP(ctx, metadata); err != nil {
return nil, err
}
c, err := dialer.DialContext(ctx, "tcp", v.addr)
c, err = v.dialer.DialContext(ctx, "tcp", v.addr)
if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
}
@@ -389,11 +366,6 @@ func (v *Vmess) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, met
return v.ListenPacketOnStreamConn(ctx, c, metadata)
}
// SupportWithDialer implements C.ProxyAdapter
func (v *Vmess) SupportWithDialer() C.NetWork {
return C.ALLNet
}
// ProxyInfo implements C.ProxyAdapter
func (v *Vmess) ProxyInfo() C.ProxyInfo {
info := v.Base.ProxyInfo()
@@ -456,17 +428,19 @@ func NewVmess(option VmessOption) (*Vmess, error) {
name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.Vmess,
pdName: option.ProviderName,
udp: option.UDP,
xudp: option.XUDP,
tfo: option.TFO,
mpTcp: option.MPTCP,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
prefer: option.IPVersion,
},
client: client,
option: &option,
}
v.dialer = option.NewDialer(v.DialOptions())
v.realityConfig, err = v.option.RealityOpts.Parse()
if err != nil {
@@ -485,15 +459,7 @@ func NewVmess(option VmessOption) (*Vmess, error) {
}
case "grpc":
dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) {
var err error
var cDialer C.Dialer = dialer.NewDialer(v.DialOptions()...)
if len(v.option.DialerProxy) > 0 {
cDialer, err = proxydialer.NewByName(v.option.DialerProxy, cDialer)
if err != nil {
return nil, err
}
}
c, err := cDialer.DialContext(ctx, "tcp", v.addr)
c, err := v.dialer.DialContext(ctx, "tcp", v.addr)
if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
}
@@ -502,6 +468,7 @@ func NewVmess(option VmessOption) (*Vmess, error) {
gunConfig := &gun.Config{
ServiceName: v.option.GrpcOpts.GrpcServiceName,
UserAgent: v.option.GrpcOpts.GrpcUserAgent,
Host: v.option.ServerName,
ClientFingerprint: v.option.ClientFingerprint,
}

View File

@@ -40,7 +40,6 @@ type WireGuard struct {
bind *wireguard.ClientBind
device wireguardGoDevice
tunDevice wireguard.Device
dialer proxydialer.SingDialer
resolver resolver.Resolver
initOk atomic.Bool
@@ -171,14 +170,15 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.WireGuard,
pdName: option.ProviderName,
udp: option.UDP,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
prefer: option.IPVersion,
},
}
singDialer := proxydialer.NewSlowDownSingDialer(proxydialer.NewByNameSingDialer(option.DialerProxy, dialer.NewDialer(outbound.DialOptions()...)), slowdown.New())
outbound.dialer = singDialer
outbound.dialer = option.NewDialer(outbound.DialOptions())
singDialer := proxydialer.NewSingDialer(proxydialer.NewSlowDownDialer(outbound.dialer, slowdown.New()))
var reserved [3]uint8
if len(option.Reserved) > 0 {
@@ -196,7 +196,7 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
outbound.connectAddr = option.Addr()
}
}
outbound.bind = wireguard.NewClientBind(context.Background(), wgSingErrorHandler{outbound.Name()}, outbound.dialer, isConnect, outbound.connectAddr.AddrPort(), reserved)
outbound.bind = wireguard.NewClientBind(context.Background(), wgSingErrorHandler{outbound.Name()}, singDialer, isConnect, outbound.connectAddr.AddrPort(), reserved)
var err error
outbound.localPrefixes, err = option.Prefixes()
@@ -609,6 +609,13 @@ func (w *WireGuard) ResolveUDP(ctx context.Context, metadata *C.Metadata) error
return nil
}
// ProxyInfo implements C.ProxyAdapter
func (w *WireGuard) ProxyInfo() C.ProxyInfo {
info := w.Base.ProxyInfo()
info.DialerProxy = w.option.DialerProxy
return info
}
// IsL3Protocol implements C.ProxyAdapter
func (w *WireGuard) IsL3Protocol(metadata *C.Metadata) bool {
return true

View File

@@ -150,6 +150,14 @@ func (f *Fallback) ForceSet(name string) {
f.selected = name
}
func (f *Fallback) Providers() []P.ProxyProvider {
return f.providers
}
func (f *Fallback) Proxies() []C.Proxy {
return f.GetProxies(false)
}
func NewFallback(option *GroupCommonOption, providers []P.ProxyProvider) *Fallback {
return &Fallback{
GroupBase: NewGroupBase(GroupBaseOption{

View File

@@ -239,6 +239,18 @@ func (lb *LoadBalance) MarshalJSON() ([]byte, error) {
})
}
func (lb *LoadBalance) Providers() []P.ProxyProvider {
return lb.providers
}
func (lb *LoadBalance) Proxies() []C.Proxy {
return lb.GetProxies(false)
}
func (lb *LoadBalance) Now() string {
return ""
}
func NewLoadBalance(option *GroupCommonOption, providers []P.ProxyProvider, strategy string) (lb *LoadBalance, err error) {
var strategyFn strategyFn
switch strategy {

View File

@@ -1,64 +0,0 @@
//go:build android && cmfa
package outboundgroup
import (
C "github.com/metacubex/mihomo/constant"
P "github.com/metacubex/mihomo/constant/provider"
)
type ProxyGroup interface {
C.ProxyAdapter
Providers() []P.ProxyProvider
Proxies() []C.Proxy
Now() string
}
func (f *Fallback) Providers() []P.ProxyProvider {
return f.providers
}
func (lb *LoadBalance) Providers() []P.ProxyProvider {
return lb.providers
}
func (f *Fallback) Proxies() []C.Proxy {
return f.GetProxies(false)
}
func (lb *LoadBalance) Proxies() []C.Proxy {
return lb.GetProxies(false)
}
func (lb *LoadBalance) Now() string {
return ""
}
func (r *Relay) Providers() []P.ProxyProvider {
return r.providers
}
func (r *Relay) Proxies() []C.Proxy {
return r.GetProxies(false)
}
func (r *Relay) Now() string {
return ""
}
func (s *Selector) Providers() []P.ProxyProvider {
return s.providers
}
func (s *Selector) Proxies() []C.Proxy {
return s.GetProxies(false)
}
func (u *URLTest) Providers() []P.ProxyProvider {
return u.providers
}
func (u *URLTest) Proxies() []C.Proxy {
return u.GetProxies(false)
}

View File

@@ -1,163 +0,0 @@
package outboundgroup
import (
"context"
"encoding/json"
"github.com/metacubex/mihomo/adapter/outbound"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant"
P "github.com/metacubex/mihomo/constant/provider"
"github.com/metacubex/mihomo/log"
)
type Relay struct {
*GroupBase
Hidden bool
Icon string
}
// DialContext implements C.ProxyAdapter
func (r *Relay) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
proxies, chainProxies := r.proxies(metadata, true)
switch len(proxies) {
case 0:
return outbound.NewDirect().DialContext(ctx, metadata)
case 1:
return proxies[0].DialContext(ctx, metadata)
}
var d C.Dialer
d = dialer.NewDialer()
for _, proxy := range proxies[:len(proxies)-1] {
d = proxydialer.New(proxy, d, false)
}
last := proxies[len(proxies)-1]
conn, err := last.DialContextWithDialer(ctx, d, metadata)
if err != nil {
return nil, err
}
for i := len(chainProxies) - 2; i >= 0; i-- {
conn.AppendToChains(chainProxies[i])
}
conn.AppendToChains(r)
return conn, nil
}
// ListenPacketContext implements C.ProxyAdapter
func (r *Relay) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) {
proxies, chainProxies := r.proxies(metadata, true)
switch len(proxies) {
case 0:
return outbound.NewDirect().ListenPacketContext(ctx, metadata)
case 1:
return proxies[0].ListenPacketContext(ctx, metadata)
}
var d C.Dialer
d = dialer.NewDialer()
for _, proxy := range proxies[:len(proxies)-1] {
d = proxydialer.New(proxy, d, false)
}
last := proxies[len(proxies)-1]
pc, err := last.ListenPacketWithDialer(ctx, d, metadata)
if err != nil {
return nil, err
}
for i := len(chainProxies) - 2; i >= 0; i-- {
pc.AppendToChains(chainProxies[i])
}
pc.AppendToChains(r)
return pc, nil
}
// SupportUDP implements C.ProxyAdapter
func (r *Relay) SupportUDP() bool {
proxies, _ := r.proxies(nil, false)
if len(proxies) == 0 { // C.Direct
return true
}
for i := len(proxies) - 1; i >= 0; i-- {
proxy := proxies[i]
if !proxy.SupportUDP() {
return false
}
if proxy.SupportUOT() {
return true
}
switch proxy.SupportWithDialer() {
case C.ALLNet:
case C.UDP:
default: // C.TCP and C.InvalidNet
return false
}
}
return true
}
// MarshalJSON implements C.ProxyAdapter
func (r *Relay) MarshalJSON() ([]byte, error) {
all := []string{}
for _, proxy := range r.GetProxies(false) {
all = append(all, proxy.Name())
}
return json.Marshal(map[string]any{
"type": r.Type().String(),
"all": all,
"hidden": r.Hidden,
"icon": r.Icon,
})
}
func (r *Relay) proxies(metadata *C.Metadata, touch bool) ([]C.Proxy, []C.Proxy) {
rawProxies := r.GetProxies(touch)
var proxies []C.Proxy
var chainProxies []C.Proxy
var targetProxies []C.Proxy
for n, proxy := range rawProxies {
proxies = append(proxies, proxy)
chainProxies = append(chainProxies, proxy)
subproxy := proxy.Unwrap(metadata, touch)
for subproxy != nil {
chainProxies = append(chainProxies, subproxy)
proxies[n] = subproxy
subproxy = subproxy.Unwrap(metadata, touch)
}
}
for _, proxy := range proxies {
if proxy.Type() != C.Direct && proxy.Type() != C.Compatible {
targetProxies = append(targetProxies, proxy)
}
}
return targetProxies, chainProxies
}
func (r *Relay) Addr() string {
proxies, _ := r.proxies(nil, false)
return proxies[len(proxies)-1].Addr()
}
func NewRelay(option *GroupCommonOption, providers []P.ProxyProvider) *Relay {
log.Warnln("The group [%s] with relay type is deprecated, please using dialer-proxy instead", option.Name)
return &Relay{
GroupBase: NewGroupBase(GroupBaseOption{
Name: option.Name,
Type: C.Relay,
Providers: providers,
}),
Hidden: option.Hidden,
Icon: option.Icon,
}
}

View File

@@ -108,6 +108,14 @@ func (s *Selector) selectedProxy(touch bool) C.Proxy {
return proxies[0]
}
func (s *Selector) Providers() []P.ProxyProvider {
return s.providers
}
func (s *Selector) Proxies() []C.Proxy {
return s.GetProxies(false)
}
func NewSelector(option *GroupCommonOption, providers []P.ProxyProvider) *Selector {
return &Selector{
GroupBase: NewGroupBase(GroupBaseOption{

View File

@@ -185,6 +185,14 @@ func (u *URLTest) MarshalJSON() ([]byte, error) {
})
}
func (u *URLTest) Providers() []P.ProxyProvider {
return u.providers
}
func (u *URLTest) Proxies() []C.Proxy {
return u.GetProxies(false)
}
func (u *URLTest) URLTest(ctx context.Context, url string, expectedStatus utils.IntRanges[uint16]) (map[string]uint16, error) {
return u.GroupBase.URLTest(ctx, u.testUrl, expectedStatus)
}

View File

@@ -1,5 +1,29 @@
package outboundgroup
import (
"context"
"github.com/metacubex/mihomo/common/utils"
C "github.com/metacubex/mihomo/constant"
P "github.com/metacubex/mihomo/constant/provider"
)
type ProxyGroup interface {
C.ProxyAdapter
Providers() []P.ProxyProvider
Proxies() []C.Proxy
Now() string
Touch()
URLTest(ctx context.Context, url string, expectedStatus utils.IntRanges[uint16]) (mp map[string]uint16, err error)
}
var _ ProxyGroup = (*Fallback)(nil)
var _ ProxyGroup = (*LoadBalance)(nil)
var _ ProxyGroup = (*URLTest)(nil)
var _ ProxyGroup = (*Selector)(nil)
type SelectAble interface {
Set(string) error
ForceSet(name string)

View File

@@ -8,151 +8,164 @@ import (
C "github.com/metacubex/mihomo/constant"
)
func ParseProxy(mapping map[string]any) (C.Proxy, error) {
func ParseProxy(mapping map[string]any, options ...ProxyOption) (C.Proxy, error) {
decoder := structure.NewDecoder(structure.Option{TagName: "proxy", WeaklyTypedInput: true, KeyReplacer: structure.DefaultKeyReplacer})
proxyType, existType := mapping["type"].(string)
if !existType {
return nil, fmt.Errorf("missing type")
}
opt := applyProxyOptions(options...)
basicOption := outbound.BasicOption{
DialerForAPI: opt.DialerForAPI,
ProviderName: opt.ProviderName,
}
var (
proxy outbound.ProxyAdapter
err error
)
switch proxyType {
case "ss":
ssOption := &outbound.ShadowSocksOption{}
ssOption := &outbound.ShadowSocksOption{BasicOption: basicOption}
err = decoder.Decode(mapping, ssOption)
if err != nil {
break
}
proxy, err = outbound.NewShadowSocks(*ssOption)
case "ssr":
ssrOption := &outbound.ShadowSocksROption{}
ssrOption := &outbound.ShadowSocksROption{BasicOption: basicOption}
err = decoder.Decode(mapping, ssrOption)
if err != nil {
break
}
proxy, err = outbound.NewShadowSocksR(*ssrOption)
case "socks5":
socksOption := &outbound.Socks5Option{}
socksOption := &outbound.Socks5Option{BasicOption: basicOption}
err = decoder.Decode(mapping, socksOption)
if err != nil {
break
}
proxy, err = outbound.NewSocks5(*socksOption)
case "http":
httpOption := &outbound.HttpOption{}
httpOption := &outbound.HttpOption{BasicOption: basicOption}
err = decoder.Decode(mapping, httpOption)
if err != nil {
break
}
proxy, err = outbound.NewHttp(*httpOption)
case "vmess":
vmessOption := &outbound.VmessOption{}
vmessOption := &outbound.VmessOption{BasicOption: basicOption}
err = decoder.Decode(mapping, vmessOption)
if err != nil {
break
}
proxy, err = outbound.NewVmess(*vmessOption)
case "vless":
vlessOption := &outbound.VlessOption{}
vlessOption := &outbound.VlessOption{BasicOption: basicOption}
err = decoder.Decode(mapping, vlessOption)
if err != nil {
break
}
proxy, err = outbound.NewVless(*vlessOption)
case "snell":
snellOption := &outbound.SnellOption{}
snellOption := &outbound.SnellOption{BasicOption: basicOption}
err = decoder.Decode(mapping, snellOption)
if err != nil {
break
}
proxy, err = outbound.NewSnell(*snellOption)
case "trojan":
trojanOption := &outbound.TrojanOption{}
trojanOption := &outbound.TrojanOption{BasicOption: basicOption}
err = decoder.Decode(mapping, trojanOption)
if err != nil {
break
}
proxy, err = outbound.NewTrojan(*trojanOption)
case "hysteria":
hyOption := &outbound.HysteriaOption{}
hyOption := &outbound.HysteriaOption{BasicOption: basicOption}
err = decoder.Decode(mapping, hyOption)
if err != nil {
break
}
proxy, err = outbound.NewHysteria(*hyOption)
case "hysteria2":
hyOption := &outbound.Hysteria2Option{}
hyOption := &outbound.Hysteria2Option{BasicOption: basicOption}
err = decoder.Decode(mapping, hyOption)
if err != nil {
break
}
proxy, err = outbound.NewHysteria2(*hyOption)
case "wireguard":
wgOption := &outbound.WireGuardOption{}
wgOption := &outbound.WireGuardOption{BasicOption: basicOption}
err = decoder.Decode(mapping, wgOption)
if err != nil {
break
}
proxy, err = outbound.NewWireGuard(*wgOption)
case "tuic":
tuicOption := &outbound.TuicOption{}
tuicOption := &outbound.TuicOption{BasicOption: basicOption}
err = decoder.Decode(mapping, tuicOption)
if err != nil {
break
}
proxy, err = outbound.NewTuic(*tuicOption)
case "direct":
directOption := &outbound.DirectOption{}
directOption := &outbound.DirectOption{BasicOption: basicOption}
err = decoder.Decode(mapping, directOption)
if err != nil {
break
}
proxy = outbound.NewDirectWithOption(*directOption)
case "dns":
dnsOptions := &outbound.DnsOption{}
dnsOptions := &outbound.DnsOption{BasicOption: basicOption}
err = decoder.Decode(mapping, dnsOptions)
if err != nil {
break
}
proxy = outbound.NewDnsWithOption(*dnsOptions)
case "reject":
rejectOption := &outbound.RejectOption{}
rejectOption := &outbound.RejectOption{BasicOption: basicOption}
err = decoder.Decode(mapping, rejectOption)
if err != nil {
break
}
proxy = outbound.NewRejectWithOption(*rejectOption)
case "ssh":
sshOption := &outbound.SshOption{}
sshOption := &outbound.SshOption{BasicOption: basicOption}
err = decoder.Decode(mapping, sshOption)
if err != nil {
break
}
proxy, err = outbound.NewSsh(*sshOption)
case "mieru":
mieruOption := &outbound.MieruOption{}
mieruOption := &outbound.MieruOption{BasicOption: basicOption}
err = decoder.Decode(mapping, mieruOption)
if err != nil {
break
}
proxy, err = outbound.NewMieru(*mieruOption)
case "anytls":
anytlsOption := &outbound.AnyTLSOption{}
anytlsOption := &outbound.AnyTLSOption{BasicOption: basicOption}
err = decoder.Decode(mapping, anytlsOption)
if err != nil {
break
}
proxy, err = outbound.NewAnyTLS(*anytlsOption)
case "sudoku":
sudokuOption := &outbound.SudokuOption{}
sudokuOption := &outbound.SudokuOption{BasicOption: basicOption}
err = decoder.Decode(mapping, sudokuOption)
if err != nil {
break
}
proxy, err = outbound.NewSudoku(*sudokuOption)
case "masque":
masqueOption := &outbound.MasqueOption{BasicOption: basicOption}
err = decoder.Decode(mapping, masqueOption)
if err != nil {
break
}
proxy, err = outbound.NewMasque(*masqueOption)
default:
return nil, fmt.Errorf("unsupport proxy type: %s", proxyType)
}
@@ -178,3 +191,30 @@ func ParseProxy(mapping map[string]any) (C.Proxy, error) {
proxy = outbound.NewAutoCloseProxyAdapter(proxy)
return NewProxy(proxy), nil
}
type proxyOption struct {
DialerForAPI C.Dialer
ProviderName string
}
func applyProxyOptions(options ...ProxyOption) proxyOption {
opt := proxyOption{}
for _, o := range options {
o(&opt)
}
return opt
}
type ProxyOption func(opt *proxyOption)
func WithDialerForAPI(dialer C.Dialer) ProxyOption {
return func(opt *proxyOption) {
opt.DialerForAPI = dialer
}
}
func WithProviderName(name string) ProxyOption {
return func(opt *proxyOption) {
opt.ProviderName = name
}
}

View File

@@ -0,0 +1,88 @@
package provider
import (
"encoding"
"fmt"
"github.com/dlclark/regexp2"
)
type overrideSchema struct {
TFO *bool `provider:"tfo,omitempty"`
MPTcp *bool `provider:"mptcp,omitempty"`
UDP *bool `provider:"udp,omitempty"`
UDPOverTCP *bool `provider:"udp-over-tcp,omitempty"`
Up *string `provider:"up,omitempty"`
Down *string `provider:"down,omitempty"`
DialerProxy *string `provider:"dialer-proxy,omitempty"`
SkipCertVerify *bool `provider:"skip-cert-verify,omitempty"`
Interface *string `provider:"interface-name,omitempty"`
RoutingMark *int `provider:"routing-mark,omitempty"`
IPVersion *string `provider:"ip-version,omitempty"`
AdditionalPrefix *string `provider:"additional-prefix,omitempty"`
AdditionalSuffix *string `provider:"additional-suffix,omitempty"`
ProxyName []overrideProxyNameSchema `provider:"proxy-name,omitempty"`
}
type overrideProxyNameSchema struct {
// matching expression for regex replacement
Pattern *regexp2.Regexp `provider:"pattern"`
// the new content after regex matching
Target string `provider:"target"`
}
var _ encoding.TextUnmarshaler = (*regexp2.Regexp)(nil) // ensure *regexp2.Regexp can decode direct by structure package
func (o *overrideSchema) Apply(mapping map[string]any) error {
if o.TFO != nil {
mapping["tfo"] = *o.TFO
}
if o.MPTcp != nil {
mapping["mptcp"] = *o.MPTcp
}
if o.UDP != nil {
mapping["udp"] = *o.UDP
}
if o.UDPOverTCP != nil {
mapping["udp-over-tcp"] = *o.UDPOverTCP
}
if o.Up != nil {
mapping["up"] = *o.Up
}
if o.Down != nil {
mapping["down"] = *o.Down
}
if o.DialerProxy != nil {
mapping["dialer-proxy"] = *o.DialerProxy
}
if o.SkipCertVerify != nil {
mapping["skip-cert-verify"] = *o.SkipCertVerify
}
if o.Interface != nil {
mapping["interface"] = *o.Interface
}
if o.RoutingMark != nil {
mapping["routing-mark"] = *o.RoutingMark
}
if o.IPVersion != nil {
mapping["ip-version"] = *o.IPVersion
}
for _, expr := range o.ProxyName {
name := mapping["name"].(string)
newName, err := expr.Pattern.Replace(name, expr.Target, 0, -1)
if err != nil {
return fmt.Errorf("proxy name replace error: %w", err)
}
mapping["name"] = newName
}
if o.AdditionalPrefix != nil {
mapping["name"] = fmt.Sprintf("%s%s", *o.AdditionalPrefix, mapping["name"])
}
if o.AdditionalSuffix != nil {
mapping["name"] = fmt.Sprintf("%s%s", mapping["name"], *o.AdditionalSuffix)
}
return nil
}

View File

@@ -1,7 +1,6 @@
package provider
import (
"encoding"
"errors"
"fmt"
"time"
@@ -11,8 +10,6 @@ import (
"github.com/metacubex/mihomo/component/resource"
C "github.com/metacubex/mihomo/constant"
P "github.com/metacubex/mihomo/constant/provider"
"github.com/dlclark/regexp2"
)
var (
@@ -28,33 +25,6 @@ type healthCheckSchema struct {
ExpectedStatus string `provider:"expected-status,omitempty"`
}
type OverrideProxyNameSchema struct {
// matching expression for regex replacement
Pattern *regexp2.Regexp `provider:"pattern"`
// the new content after regex matching
Target string `provider:"target"`
}
var _ encoding.TextUnmarshaler = (*regexp2.Regexp)(nil) // ensure *regexp2.Regexp can decode direct by structure package
type OverrideSchema struct {
TFO *bool `provider:"tfo,omitempty"`
MPTcp *bool `provider:"mptcp,omitempty"`
UDP *bool `provider:"udp,omitempty"`
UDPOverTCP *bool `provider:"udp-over-tcp,omitempty"`
Up *string `provider:"up,omitempty"`
Down *string `provider:"down,omitempty"`
DialerProxy *string `provider:"dialer-proxy,omitempty"`
SkipCertVerify *bool `provider:"skip-cert-verify,omitempty"`
Interface *string `provider:"interface-name,omitempty"`
RoutingMark *int `provider:"routing-mark,omitempty"`
IPVersion *string `provider:"ip-version,omitempty"`
AdditionalPrefix *string `provider:"additional-prefix,omitempty"`
AdditionalSuffix *string `provider:"additional-suffix,omitempty"`
ProxyName []OverrideProxyNameSchema `provider:"proxy-name,omitempty"`
}
type proxyProviderSchema struct {
Type string `provider:"type"`
Path string `provider:"path,omitempty"`
@@ -69,7 +39,7 @@ type proxyProviderSchema struct {
Payload []map[string]any `provider:"payload,omitempty"`
HealthCheck healthCheckSchema `provider:"health-check,omitempty"`
Override OverrideSchema `provider:"override,omitempty"`
Override overrideSchema `provider:"override,omitempty"`
Header map[string][]string `provider:"header,omitempty"`
}
@@ -99,7 +69,7 @@ func ParseProxyProvider(name string, mapping map[string]any) (P.ProxyProvider, e
}
hc := NewHealthCheck([]C.Proxy{}, schema.HealthCheck.URL, uint(schema.HealthCheck.TestTimeout), hcInterval, schema.HealthCheck.Lazy, expectedStatus)
parser, err := NewProxiesParser(schema.Filter, schema.ExcludeFilter, schema.ExcludeType, schema.DialerProxy, schema.Override)
parser, err := NewProxiesParser(name, schema.Filter, schema.ExcludeFilter, schema.ExcludeType, schema.DialerProxy, schema.Override)
if err != nil {
return nil, err
}

View File

@@ -4,15 +4,15 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"reflect"
"runtime"
"strings"
"sync"
"time"
"github.com/metacubex/mihomo/adapter"
"github.com/metacubex/mihomo/common/convert"
"github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/common/yaml"
"github.com/metacubex/mihomo/component/profile/cachefile"
"github.com/metacubex/mihomo/component/resource"
C "github.com/metacubex/mihomo/constant"
@@ -20,7 +20,7 @@ import (
"github.com/metacubex/mihomo/tunnel/statistic"
"github.com/dlclark/regexp2"
"gopkg.in/yaml.v3"
"github.com/metacubex/http"
)
const (
@@ -43,6 +43,7 @@ type providerForApi struct {
}
type baseProvider struct {
mutex sync.RWMutex
name string
proxies []C.Proxy
healthCheck *HealthCheck
@@ -54,6 +55,8 @@ func (bp *baseProvider) Name() string {
}
func (bp *baseProvider) Version() uint32 {
bp.mutex.RLock()
defer bp.mutex.RUnlock()
return bp.version
}
@@ -73,10 +76,14 @@ func (bp *baseProvider) Type() P.ProviderType {
}
func (bp *baseProvider) Proxies() []C.Proxy {
bp.mutex.RLock()
defer bp.mutex.RUnlock()
return bp.proxies
}
func (bp *baseProvider) Count() int {
bp.mutex.RLock()
defer bp.mutex.RUnlock()
return len(bp.proxies)
}
@@ -93,6 +100,8 @@ func (bp *baseProvider) RegisterHealthCheckTask(url string, expectedStatus utils
}
func (bp *baseProvider) setProxies(proxies []C.Proxy) {
bp.mutex.Lock()
defer bp.mutex.Unlock()
bp.proxies = proxies
bp.version += 1
bp.healthCheck.setProxies(proxies)
@@ -156,7 +165,7 @@ func (pp *proxySetProvider) Initial() error {
func (pp *proxySetProvider) closeAllConnections() {
statistic.DefaultManager.Range(func(c statistic.Tracker) bool {
for _, chain := range c.Chains() {
for _, chain := range c.ProviderChains() {
if chain == pp.Name() {
_ = c.Close()
break
@@ -330,7 +339,7 @@ func (cp *CompatibleProvider) Close() error {
return cp.compatibleProvider.Close()
}
func NewProxiesParser(filter string, excludeFilter string, excludeType string, dialerProxy string, override OverrideSchema) (resource.Parser[[]C.Proxy], error) {
func NewProxiesParser(pdName string, filter string, excludeFilter string, excludeType string, dialerProxy string, override overrideSchema) (resource.Parser[[]C.Proxy], error) {
var excludeTypeArray []string
if excludeType != "" {
excludeTypeArray = strings.Split(excludeType, "|")
@@ -419,36 +428,12 @@ func NewProxiesParser(filter string, excludeFilter string, excludeType string, d
mapping["dialer-proxy"] = dialerProxy
}
val := reflect.ValueOf(override)
for i := 0; i < val.NumField(); i++ {
field := val.Field(i)
if field.IsNil() {
continue
}
fieldName := strings.Split(val.Type().Field(i).Tag.Get("provider"), ",")[0]
switch fieldName {
case "additional-prefix":
name := mapping["name"].(string)
mapping["name"] = *field.Interface().(*string) + name
case "additional-suffix":
name := mapping["name"].(string)
mapping["name"] = name + *field.Interface().(*string)
case "proxy-name":
// Iterate through all naming replacement rules and perform the replacements.
for _, expr := range override.ProxyName {
name := mapping["name"].(string)
newName, err := expr.Pattern.Replace(name, expr.Target, 0, -1)
if err != nil {
return nil, fmt.Errorf("proxy name replace error: %w", err)
}
mapping["name"] = newName
}
default:
mapping[fieldName] = field.Elem().Interface()
}
err := override.Apply(mapping)
if err != nil {
return nil, fmt.Errorf("proxy %d override error: %w", idx, err)
}
proxy, err := adapter.ParseProxy(mapping)
proxy, err := adapter.ParseProxy(mapping, adapter.WithProviderName(pdName))
if err != nil {
return nil, fmt.Errorf("proxy %d error: %w", idx, err)
}

View File

@@ -1,6 +1,8 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build android && cgo
// +build android,cgo
// kanged from https://github.com/golang/mobile/blob/c713f31d574bb632a93f169b2cc99c9e753fef0e/app/android.go#L89

View File

@@ -201,6 +201,10 @@ func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
trojan["client-fingerprint"] = fingerprint
}
if pcs := query.Get("pcs"); pcs != "" {
trojan["fingerprint"] = pcs
}
proxies = append(proxies, trojan)
case "vless":

View File

@@ -2,12 +2,12 @@ package convert
import (
"encoding/base64"
"net/http"
"strings"
"time"
"github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/http"
"github.com/metacubex/randv2"
"github.com/metacubex/sing-shadowsocks/shadowimpl"
)

View File

@@ -35,6 +35,9 @@ func handleVShareLink(names map[string]int, url *url.URL, scheme string, proxy m
if alpn := query.Get("alpn"); alpn != "" {
proxy["alpn"] = strings.Split(alpn, ",")
}
if pcs := query.Get("pcs"); pcs != "" {
proxy["fingerprint"] = pcs
}
}
if sni := query.Get("sni"); sni != "" {
proxy["servername"] = sni

674
common/deque/deque.go Normal file
View File

@@ -0,0 +1,674 @@
package deque
// copy and modified from https://github.com/gammazero/deque/blob/v1.2.0/deque.go
// which is licensed under MIT.
import (
"fmt"
)
// minCapacity is the smallest capacity that deque may have. Must be power of 2
// for bitwise modulus: x % n == x & (n - 1).
const minCapacity = 8
// Deque represents a single instance of the deque data structure. A Deque
// instance contains items of the type specified by the type argument.
//
// For example, to create a Deque that contains strings do one of the
// following:
//
// var stringDeque deque.Deque[string]
// stringDeque := new(deque.Deque[string])
// stringDeque := &deque.Deque[string]{}
//
// To create a Deque that will never resize to have space for less than 64
// items, specify a base capacity:
//
// var d deque.Deque[int]
// d.SetBaseCap(64)
//
// To ensure the Deque can store 1000 items without needing to resize while
// items are added:
//
// d.Grow(1000)
//
// Any values supplied to [SetBaseCap] and [Grow] are rounded up to the nearest
// power of 2, since the Deque grows by powers of 2.
type Deque[T any] struct {
buf []T
head int
tail int
count int
minCap int
}
// Cap returns the current capacity of the Deque. If q is nil, q.Cap() is zero.
func (q *Deque[T]) Cap() int {
if q == nil {
return 0
}
return len(q.buf)
}
// Len returns the number of elements currently stored in the queue. If q is
// nil, q.Len() returns zero.
func (q *Deque[T]) Len() int {
if q == nil {
return 0
}
return q.count
}
// PushBack appends an element to the back of the queue. Implements FIFO when
// elements are removed with [PopFront], and LIFO when elements are removed with
// [PopBack].
func (q *Deque[T]) PushBack(elem T) {
q.growIfFull()
q.buf[q.tail] = elem
// Calculate new tail position.
q.tail = q.next(q.tail)
q.count++
}
// PushFront prepends an element to the front of the queue.
func (q *Deque[T]) PushFront(elem T) {
q.growIfFull()
// Calculate new head position.
q.head = q.prev(q.head)
q.buf[q.head] = elem
q.count++
}
// PopFront removes and returns the element from the front of the queue.
// Implements FIFO when used with [PushBack]. If the queue is empty, the call
// panics.
func (q *Deque[T]) PopFront() T {
if q.count <= 0 {
panic("deque: PopFront() called on empty queue")
}
ret := q.buf[q.head]
var zero T
q.buf[q.head] = zero
// Calculate new head position.
q.head = q.next(q.head)
q.count--
q.shrinkIfExcess()
return ret
}
// IterPopFront returns an iterator that iteratively removes items from the
// front of the deque. This is more efficient than removing items one at a time
// because it avoids intermediate resizing. If a resize is necessary, only one
// is done when iteration ends.
func (q *Deque[T]) IterPopFront() func(yield func(T) bool) {
return func(yield func(T) bool) {
if q.Len() == 0 {
return
}
var zero T
for q.count != 0 {
ret := q.buf[q.head]
q.buf[q.head] = zero
q.head = q.next(q.head)
q.count--
if !yield(ret) {
break
}
}
q.shrinkToFit()
}
}
// PopBack removes and returns the element from the back of the queue.
// Implements LIFO when used with [PushBack]. If the queue is empty, the call
// panics.
func (q *Deque[T]) PopBack() T {
if q.count <= 0 {
panic("deque: PopBack() called on empty queue")
}
// Calculate new tail position
q.tail = q.prev(q.tail)
// Remove value at tail.
ret := q.buf[q.tail]
var zero T
q.buf[q.tail] = zero
q.count--
q.shrinkIfExcess()
return ret
}
// IterPopBack returns an iterator that iteratively removes items from the back
// of the deque. This is more efficient than removing items one at a time
// because it avoids intermediate resizing. If a resize is necessary, only one
// is done when iteration ends.
func (q *Deque[T]) IterPopBack() func(yield func(T) bool) {
return func(yield func(T) bool) {
if q.Len() == 0 {
return
}
var zero T
for q.count != 0 {
q.tail = q.prev(q.tail)
ret := q.buf[q.tail]
q.buf[q.tail] = zero
q.count--
if !yield(ret) {
break
}
}
q.shrinkToFit()
}
}
// Front returns the element at the front of the queue. This is the element
// that would be returned by [PopFront]. This call panics if the queue is
// empty.
func (q *Deque[T]) Front() T {
if q.count <= 0 {
panic("deque: Front() called when empty")
}
return q.buf[q.head]
}
// Back returns the element at the back of the queue. This is the element that
// would be returned by [PopBack]. This call panics if the queue is empty.
func (q *Deque[T]) Back() T {
if q.count <= 0 {
panic("deque: Back() called when empty")
}
return q.buf[q.prev(q.tail)]
}
// At returns the element at index i in the queue without removing the element
// from the queue. This method accepts only non-negative index values. At(0)
// refers to the first element and is the same as [Front]. At(Len()-1) refers
// to the last element and is the same as [Back]. If the index is invalid, the
// call panics.
//
// The purpose of At is to allow Deque to serve as a more general purpose
// circular buffer, where items are only added to and removed from the ends of
// the deque, but may be read from any place within the deque. Consider the
// case of a fixed-size circular log buffer: A new entry is pushed onto one end
// and when full the oldest is popped from the other end. All the log entries
// in the buffer must be readable without altering the buffer contents.
func (q *Deque[T]) At(i int) T {
q.checkRange(i)
// bitwise modulus
return q.buf[(q.head+i)&(len(q.buf)-1)]
}
// Set assigns the item to index i in the queue. Set indexes the deque the same
// as [At] but perform the opposite operation. If the index is invalid, the call
// panics.
func (q *Deque[T]) Set(i int, item T) {
q.checkRange(i)
// bitwise modulus
q.buf[(q.head+i)&(len(q.buf)-1)] = item
}
// Iter returns a go iterator to range over all items in the Deque, yielding
// each item from front (index 0) to back (index Len()-1). Modification of
// Deque during iteration panics.
func (q *Deque[T]) Iter() func(yield func(T) bool) {
return func(yield func(T) bool) {
origHead := q.head
origTail := q.tail
head := origHead
for i := -0; i < q.Len(); i++ {
if q.head != origHead || q.tail != origTail {
panic("deque: modified during iteration")
}
if !yield(q.buf[head]) {
return
}
head = q.next(head)
}
}
}
// RIter returns a reverse go iterator to range over all items in the Deque,
// yielding each item from back (index Len()-1) to front (index 0).
// Modification of Deque during iteration panics.
func (q *Deque[T]) RIter() func(yield func(T) bool) {
return func(yield func(T) bool) {
origHead := q.head
origTail := q.tail
tail := origTail
for i := -0; i < q.Len(); i++ {
if q.head != origHead || q.tail != origTail {
panic("deque: modified during iteration")
}
tail = q.prev(tail)
if !yield(q.buf[tail]) {
return
}
}
}
}
// Clear removes all elements from the queue, but retains the current capacity.
// This is useful when repeatedly reusing the queue at high frequency to avoid
// GC during reuse. The queue will not be resized smaller as long as items are
// only added. Only when items are removed is the queue subject to getting
// resized smaller.
func (q *Deque[T]) Clear() {
if q.Len() == 0 {
return
}
head, tail := q.head, q.tail
q.count = 0
q.head = 0
q.tail = 0
if head >= tail {
// [DEF....ABC]
clearSlice(q.buf[head:])
head = 0
}
clearSlice(q.buf[head:tail])
}
func clearSlice[S ~[]E, E any](s S) {
var zero E
for i := range s {
s[i] = zero
}
}
// Grow grows deque's capacity, if necessary, to guarantee space for another n
// items. After Grow(n), at least n items can be written to the deque without
// another allocation. If n is negative, Grow panics.
func (q *Deque[T]) Grow(n int) {
if n < 0 {
panic("deque.Grow: negative count")
}
c := q.Cap()
l := q.Len()
// If already big enough.
if n <= c-l {
return
}
if c == 0 {
c = minCapacity
}
newLen := l + n
for c < newLen {
c <<= 1
}
if l == 0 {
q.buf = make([]T, c)
q.head = 0
q.tail = 0
} else {
q.resize(c)
}
}
// Copy copies the contents of the given src Deque into this Deque.
//
// n := b.Copy(a)
//
// is an efficient shortcut for
//
// b.Clear()
// n := a.Len()
// b.Grow(n)
// for i := 0; i < n; i++ {
// b.PushBack(a.At(i))
// }
func (q *Deque[T]) Copy(src Deque[T]) int {
q.Clear()
q.Grow(src.Len())
n := src.CopyOutSlice(q.buf)
q.count = n
q.tail = n
q.head = 0
return n
}
// AppendToSlice appends from the Deque to the given slice. If the slice has
// insufficient capacity to store all elements in Deque, then allocate a new
// slice. Returns the resulting slice.
//
// out = q.AppendToSlice(out)
//
// is an efficient shortcut for
//
// for i := 0; i < q.Len(); i++ {
// x = append(out, q.At(i))
// }
func (q *Deque[T]) AppendToSlice(out []T) []T {
if q.count == 0 {
return out
}
head, tail := q.head, q.tail
if head >= tail {
// [DEF....ABC]
out = append(out, q.buf[head:]...)
head = 0
}
return append(out, q.buf[head:tail]...)
}
// CopyInSlice replaces the contents of Deque with all the elements from the
// given slice, in. If len(in) is zero, then this is equivalent to calling
// [Clear].
//
// q.CopyInSlice(in)
//
// is an efficient shortcut for
//
// q.Clear()
// for i := range in {
// q.PushBack(in[i])
// }
func (q *Deque[T]) CopyInSlice(in []T) {
// Allocate new buffer if more space needed.
if len(q.buf) < len(in) {
newCap := len(q.buf)
if newCap == 0 {
newCap = minCapacity
q.minCap = minCapacity
}
for newCap < len(in) {
newCap <<= 1
}
q.buf = make([]T, newCap)
} else if len(q.buf) > len(in) {
q.Clear()
}
n := copy(q.buf, in)
q.count = n
q.tail = n
q.head = 0
}
// CopyOutSlice copies elements from the Deque into the given slice, up to the
// size of the buffer. Returns the number of elements copied, which will be the
// minimum of q.Len() and len(out).
//
// n := q.CopyOutSlice(out)
//
// is an efficient shortcut for
//
// n := min(len(out), q.Len())
// for i := 0; i < n; i++ {
// out[i] = q.At(i)
// }
//
// This function is preferable to one that returns a copy of the internal
// buffer because this allows reuse of memory receiving data, for repeated copy
// operations.
func (q *Deque[T]) CopyOutSlice(out []T) int {
if q.count == 0 || len(out) == 0 {
return 0
}
head, tail := q.head, q.tail
var n int
if head >= tail {
// [DEF....ABC]
n = copy(out, q.buf[head:])
out = out[n:]
if len(out) == 0 {
return n
}
head = 0
}
n += copy(out, q.buf[head:tail])
return n
}
// Rotate rotates the deque n steps front-to-back. If n is negative, rotates
// back-to-front. Having Deque provide Rotate avoids resizing that could happen
// if implementing rotation using only Pop and Push methods. If q.Len() is one
// or less, or q is nil, then Rotate does nothing.
func (q *Deque[T]) Rotate(n int) {
if q.Len() <= 1 {
return
}
// Rotating a multiple of q.count is same as no rotation.
n %= q.count
if n == 0 {
return
}
modBits := len(q.buf) - 1
// If no empty space in buffer, only move head and tail indexes.
if q.head == q.tail {
// Calculate new head and tail using bitwise modulus.
q.head = (q.head + n) & modBits
q.tail = q.head
return
}
var zero T
if n < 0 {
// Rotate back to front.
for ; n < 0; n++ {
// Calculate new head and tail using bitwise modulus.
q.head = (q.head - 1) & modBits
q.tail = (q.tail - 1) & modBits
// Put tail value at head and remove value at tail.
q.buf[q.head] = q.buf[q.tail]
q.buf[q.tail] = zero
}
return
}
// Rotate front to back.
for ; n > 0; n-- {
// Put head value at tail and remove value at head.
q.buf[q.tail] = q.buf[q.head]
q.buf[q.head] = zero
// Calculate new head and tail using bitwise modulus.
q.head = (q.head + 1) & modBits
q.tail = (q.tail + 1) & modBits
}
}
// Index returns the index into the Deque of the first item satisfying f(item),
// or -1 if none do. If q is nil, then -1 is always returned. Search is linear
// starting with index 0.
func (q *Deque[T]) Index(f func(T) bool) int {
if q.Len() > 0 {
modBits := len(q.buf) - 1
for i := 0; i < q.count; i++ {
if f(q.buf[(q.head+i)&modBits]) {
return i
}
}
}
return -1
}
// RIndex is the same as Index, but searches from Back to Front. The index
// returned is from Front to Back, where index 0 is the index of the item
// returned by [Front].
func (q *Deque[T]) RIndex(f func(T) bool) int {
if q.Len() > 0 {
modBits := len(q.buf) - 1
for i := q.count - 1; i >= 0; i-- {
if f(q.buf[(q.head+i)&modBits]) {
return i
}
}
}
return -1
}
// Insert is used to insert an element into the middle of the queue, before the
// element at the specified index. Insert(0,e) is the same as PushFront(e) and
// Insert(Len(),e) is the same as PushBack(e). Out of range indexes result in
// pushing the item onto the front of back of the deque.
//
// Important: Deque is optimized for O(1) operations at the ends of the queue,
// not for operations in the the middle. Complexity of this function is
// constant plus linear in the lesser of the distances between the index and
// either of the ends of the queue.
func (q *Deque[T]) Insert(at int, item T) {
if at <= 0 {
q.PushFront(item)
return
}
if at >= q.Len() {
q.PushBack(item)
return
}
if at*2 < q.count {
q.PushFront(item)
front := q.head
for i := 0; i < at; i++ {
next := q.next(front)
q.buf[front], q.buf[next] = q.buf[next], q.buf[front]
front = next
}
return
}
swaps := q.count - at
q.PushBack(item)
back := q.prev(q.tail)
for i := 0; i < swaps; i++ {
prev := q.prev(back)
q.buf[back], q.buf[prev] = q.buf[prev], q.buf[back]
back = prev
}
}
// Remove removes and returns an element from the middle of the queue, at the
// specified index. Remove(0) is the same as [PopFront] and Remove(Len()-1) is
// the same as [PopBack]. Accepts only non-negative index values, and panics if
// index is out of range.
//
// Important: Deque is optimized for O(1) operations at the ends of the queue,
// not for operations in the the middle. Complexity of this function is
// constant plus linear in the lesser of the distances between the index and
// either of the ends of the queue.
func (q *Deque[T]) Remove(at int) T {
q.checkRange(at)
rm := (q.head + at) & (len(q.buf) - 1)
if at*2 < q.count {
for i := 0; i < at; i++ {
prev := q.prev(rm)
q.buf[prev], q.buf[rm] = q.buf[rm], q.buf[prev]
rm = prev
}
return q.PopFront()
}
swaps := q.count - at - 1
for i := 0; i < swaps; i++ {
next := q.next(rm)
q.buf[rm], q.buf[next] = q.buf[next], q.buf[rm]
rm = next
}
return q.PopBack()
}
// SetBaseCap sets a base capacity so that at least the specified number of
// items can always be stored without resizing.
func (q *Deque[T]) SetBaseCap(baseCap int) {
minCap := minCapacity
for minCap < baseCap {
minCap <<= 1
}
q.minCap = minCap
}
// Swap exchanges the two values at idxA and idxB. It panics if either index is
// out of range.
func (q *Deque[T]) Swap(idxA, idxB int) {
q.checkRange(idxA)
q.checkRange(idxB)
if idxA == idxB {
return
}
realA := (q.head + idxA) & (len(q.buf) - 1)
realB := (q.head + idxB) & (len(q.buf) - 1)
q.buf[realA], q.buf[realB] = q.buf[realB], q.buf[realA]
}
func (q *Deque[T]) checkRange(i int) {
if i < 0 || i >= q.count {
panic(fmt.Sprintf("deque: index out of range %d with length %d", i, q.Len()))
}
}
// prev returns the previous buffer position wrapping around buffer.
func (q *Deque[T]) prev(i int) int {
return (i - 1) & (len(q.buf) - 1) // bitwise modulus
}
// next returns the next buffer position wrapping around buffer.
func (q *Deque[T]) next(i int) int {
return (i + 1) & (len(q.buf) - 1) // bitwise modulus
}
// growIfFull resizes up if the buffer is full.
func (q *Deque[T]) growIfFull() {
if q.count != len(q.buf) {
return
}
if len(q.buf) == 0 {
if q.minCap == 0 {
q.minCap = minCapacity
}
q.buf = make([]T, q.minCap)
return
}
q.resize(q.count << 1)
}
// shrinkIfExcess resize down if the buffer 1/4 full.
func (q *Deque[T]) shrinkIfExcess() {
if len(q.buf) > q.minCap && (q.count<<2) == len(q.buf) {
q.resize(q.count << 1)
}
}
func (q *Deque[T]) shrinkToFit() {
if len(q.buf) > q.minCap && (q.count<<2) <= len(q.buf) {
if q.count == 0 {
q.head = 0
q.tail = 0
q.buf = make([]T, q.minCap)
return
}
c := q.minCap
for c < q.count {
c <<= 1
}
q.resize(c)
}
}
// resize resizes the deque to fit exactly twice its current contents. This is
// used to grow the queue when it is full, and also to shrink it when it is
// only a quarter full.
func (q *Deque[T]) resize(newSize int) {
newBuf := make([]T, newSize)
if q.tail > q.head {
copy(newBuf, q.buf[q.head:q.tail])
} else {
n := copy(newBuf, q.buf[q.head:])
copy(newBuf[n:], q.buf[:q.tail])
}
q.head = 0
q.tail = q.count
q.buf = newBuf
}

View File

@@ -1,6 +1,8 @@
package net
import (
"crypto/sha1"
"encoding/base64"
"encoding/binary"
"math/bits"
)
@@ -129,3 +131,13 @@ func MaskWebSocket(key uint32, b []byte) uint32 {
return key
}
func GetWebSocketSecAccept(secKey string) string {
const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
const nonceSize = 24 // base64.StdEncoding.EncodedLen(nonceKeySize)
p := make([]byte, nonceSize+len(magic))
copy(p[:nonceSize], secKey)
copy(p[nonceSize:], magic)
sum := sha1.Sum(p)
return base64.StdEncoding.EncodeToString(sum[:])
}

8
common/orderedmap/doc.go Normal file
View File

@@ -0,0 +1,8 @@
package orderedmap
// copy and modified from https://github.com/wk8/go-ordered-map/tree/v2.1.8
// which is licensed under Apache v2.
//
// mihomo modified:
// 1. remove dependence of mailru/easyjson for MarshalJSON
// 2. remove dependence of buger/jsonparser for UnmarshalJSON

139
common/orderedmap/json.go Normal file
View File

@@ -0,0 +1,139 @@
package orderedmap
import (
"bytes"
"encoding"
"encoding/json"
"errors"
"fmt"
"reflect"
)
var (
_ json.Marshaler = &OrderedMap[int, any]{}
_ json.Unmarshaler = &OrderedMap[int, any]{}
)
// MarshalJSON implements the json.Marshaler interface.
func (om *OrderedMap[K, V]) MarshalJSON() ([]byte, error) { //nolint:funlen
if om == nil || om.list == nil {
return []byte("null"), nil
}
var buf bytes.Buffer
buf.WriteByte('{')
enc := json.NewEncoder(&buf)
for pair, firstIteration := om.Oldest(), true; pair != nil; pair = pair.Next() {
if firstIteration {
firstIteration = false
} else {
buf.WriteByte(',')
}
switch key := any(pair.Key).(type) {
case string, encoding.TextMarshaler:
if err := enc.Encode(pair.Key); err != nil {
return nil, err
}
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
buf.WriteByte('"')
buf.WriteString(fmt.Sprint(key))
buf.WriteByte('"')
default:
// this switch takes care of wrapper types around primitive types, such as
// type myType string
switch keyValue := reflect.ValueOf(key); keyValue.Type().Kind() {
case reflect.String:
if err := enc.Encode(pair.Key); err != nil {
return nil, err
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
buf.WriteByte('"')
buf.WriteString(fmt.Sprint(key))
buf.WriteByte('"')
default:
return nil, fmt.Errorf("unsupported key type: %T", key)
}
}
buf.WriteByte(':')
if err := enc.Encode(pair.Value); err != nil {
return nil, err
}
}
buf.WriteByte('}')
return buf.Bytes(), nil
}
// UnmarshalJSON implements the json.Unmarshaler interface.
func (om *OrderedMap[K, V]) UnmarshalJSON(data []byte) error {
if om.list == nil {
om.initialize(0)
}
d := json.NewDecoder(bytes.NewReader(data))
tok, err := d.Token()
if err != nil {
return err
}
if tok != json.Delim('{') {
return errors.New("expect JSON object open with '{'")
}
for d.More() {
// key
tok, err = d.Token()
if err != nil {
return err
}
keyStr, ok := tok.(string)
if !ok {
return fmt.Errorf("key must be a string, got %T\n", tok)
}
var key K
switch typedKey := any(&key).(type) {
case *string:
*typedKey = keyStr
case encoding.TextUnmarshaler:
err = typedKey.UnmarshalText([]byte(keyStr))
case *int, *int8, *int16, *int32, *int64, *uint, *uint8, *uint16, *uint32, *uint64:
err = json.Unmarshal([]byte(keyStr), typedKey)
default:
// this switch takes care of wrapper types around primitive types, such as
// type myType string
switch reflect.TypeOf(key).Kind() {
case reflect.String:
convertedKeyData := reflect.ValueOf(keyStr).Convert(reflect.TypeOf(key))
reflect.ValueOf(&key).Elem().Set(convertedKeyData)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
err = json.Unmarshal([]byte(keyStr), &key)
default:
err = fmt.Errorf("unsupported key type: %T", key)
}
}
if err != nil {
return err
}
// value
value, _ := om.Get(key)
err = d.Decode(&value)
if err != nil {
return err
}
om.Set(key, value)
}
tok, err = d.Token()
if err != nil {
return err
}
if tok != json.Delim('}') {
return errors.New("expect JSON object close with '}'")
}
return nil
}

View File

@@ -0,0 +1,117 @@
package orderedmap
// Adapted from https://github.com/dvyukov/go-fuzz-corpus/blob/c42c1b2/json/json.go
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func FuzzRoundTripJSON(f *testing.F) {
f.Fuzz(func(t *testing.T, data []byte) {
for _, testCase := range []struct {
name string
constructor func() any
// should be a function that asserts that 2 objects of the type returned by constructor are equal
equalityAssertion func(*testing.T, any, any) bool
}{
{
name: "with a string -> string map",
constructor: func() any { return &OrderedMap[string, string]{} },
equalityAssertion: assertOrderedMapsEqual[string, string],
},
{
name: "with a string -> int map",
constructor: func() any { return &OrderedMap[string, int]{} },
equalityAssertion: assertOrderedMapsEqual[string, int],
},
{
name: "with a string -> any map",
constructor: func() any { return &OrderedMap[string, any]{} },
equalityAssertion: assertOrderedMapsEqual[string, any],
},
{
name: "with a struct with map fields",
constructor: func() any { return new(testFuzzStruct) },
equalityAssertion: assertTestFuzzStructEqual,
},
} {
t.Run(testCase.name, func(t *testing.T) {
v1 := testCase.constructor()
if json.Unmarshal(data, v1) != nil {
return
}
jsonData, err := json.Marshal(v1)
require.NoError(t, err)
v2 := testCase.constructor()
require.NoError(t, json.Unmarshal(jsonData, v2))
if !assert.True(t, testCase.equalityAssertion(t, v1, v2), "failed with input data %q", string(data)) {
// look at that what the standard lib does with regular map, to help with debugging
var m1 map[string]any
require.NoError(t, json.Unmarshal(data, &m1))
mapJsonData, err := json.Marshal(m1)
require.NoError(t, err)
var m2 map[string]any
require.NoError(t, json.Unmarshal(mapJsonData, &m2))
t.Logf("initial data = %s", string(data))
t.Logf("unmarshalled map = %v", m1)
t.Logf("re-marshalled from map = %s", string(mapJsonData))
t.Logf("re-marshalled from test obj = %s", string(jsonData))
t.Logf("re-unmarshalled map = %s", m2)
}
})
}
})
}
// only works for fairly basic maps, that's why it's just in this file
func assertOrderedMapsEqual[K comparable, V any](t *testing.T, v1, v2 any) bool {
om1, ok1 := v1.(*OrderedMap[K, V])
om2, ok2 := v2.(*OrderedMap[K, V])
if !assert.True(t, ok1, "v1 not an orderedmap") ||
!assert.True(t, ok2, "v2 not an orderedmap") {
return false
}
success := assert.Equal(t, om1.Len(), om2.Len(), "om1 and om2 have different lengths: %d vs %d", om1.Len(), om2.Len())
for i, pair1, pair2 := 0, om1.Oldest(), om2.Oldest(); pair1 != nil && pair2 != nil; i, pair1, pair2 = i+1, pair1.Next(), pair2.Next() {
success = assert.Equal(t, pair1.Key, pair2.Key, "different keys at position %d: %v vs %v", i, pair1.Key, pair2.Key) && success
success = assert.Equal(t, pair1.Value, pair2.Value, "different values at position %d: %v vs %v", i, pair1.Value, pair2.Value) && success
}
return success
}
type testFuzzStruct struct {
M1 *OrderedMap[int, any]
M2 *OrderedMap[int, string]
M3 *OrderedMap[string, string]
}
func assertTestFuzzStructEqual(t *testing.T, v1, v2 any) bool {
s1, ok := v1.(*testFuzzStruct)
s2, ok := v2.(*testFuzzStruct)
if !assert.True(t, ok, "v1 not an testFuzzStruct") ||
!assert.True(t, ok, "v2 not an testFuzzStruct") {
return false
}
success := assertOrderedMapsEqual[int, any](t, s1.M1, s2.M1)
success = assertOrderedMapsEqual[int, string](t, s1.M2, s2.M2) && success
success = assertOrderedMapsEqual[string, string](t, s1.M3, s2.M3) && success
return success
}

View File

@@ -0,0 +1,338 @@
package orderedmap
import (
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// to test marshalling TextMarshalers and unmarshalling TextUnmarshalers
type marshallable int
func (m marshallable) MarshalText() ([]byte, error) {
return []byte(fmt.Sprintf("#%d#", m)), nil
}
func (m *marshallable) UnmarshalText(text []byte) error {
if len(text) < 3 {
return errors.New("too short")
}
if text[0] != '#' || text[len(text)-1] != '#' {
return errors.New("missing prefix or suffix")
}
value, err := strconv.Atoi(string(text[1 : len(text)-1]))
if err != nil {
return err
}
*m = marshallable(value)
return nil
}
func TestMarshalJSON(t *testing.T) {
t.Run("int key", func(t *testing.T) {
om := New[int, any]()
om.Set(1, "bar")
om.Set(7, "baz")
om.Set(2, 28)
om.Set(3, 100)
om.Set(4, "baz")
om.Set(5, "28")
om.Set(6, "100")
om.Set(8, "baz")
om.Set(8, "baz")
om.Set(9, "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Quisque auctor augue accumsan mi maximus, quis viverra massa pretium. Phasellus imperdiet sapien a interdum sollicitudin. Duis at commodo lectus, a lacinia sem.")
b, err := json.Marshal(om)
assert.NoError(t, err)
assert.Equal(t, `{"1":"bar","7":"baz","2":28,"3":100,"4":"baz","5":"28","6":"100","8":"baz","9":"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Quisque auctor augue accumsan mi maximus, quis viverra massa pretium. Phasellus imperdiet sapien a interdum sollicitudin. Duis at commodo lectus, a lacinia sem."}`, string(b))
})
t.Run("string key", func(t *testing.T) {
om := New[string, any]()
om.Set("test", "bar")
om.Set("abc", true)
b, err := json.Marshal(om)
assert.NoError(t, err)
assert.Equal(t, `{"test":"bar","abc":true}`, string(b))
})
t.Run("typed string key", func(t *testing.T) {
type myString string
om := New[myString, any]()
om.Set("test", "bar")
om.Set("abc", true)
b, err := json.Marshal(om)
assert.NoError(t, err)
assert.Equal(t, `{"test":"bar","abc":true}`, string(b))
})
t.Run("typed int key", func(t *testing.T) {
type myInt uint32
om := New[myInt, any]()
om.Set(1, "bar")
om.Set(7, "baz")
om.Set(2, 28)
om.Set(3, 100)
om.Set(4, "baz")
b, err := json.Marshal(om)
assert.NoError(t, err)
assert.Equal(t, `{"1":"bar","7":"baz","2":28,"3":100,"4":"baz"}`, string(b))
})
t.Run("TextMarshaller key", func(t *testing.T) {
om := New[marshallable, any]()
om.Set(marshallable(1), "bar")
om.Set(marshallable(28), true)
b, err := json.Marshal(om)
assert.NoError(t, err)
assert.Equal(t, `{"#1#":"bar","#28#":true}`, string(b))
})
t.Run("empty map", func(t *testing.T) {
om := New[string, any]()
b, err := json.Marshal(om)
assert.NoError(t, err)
assert.Equal(t, `{}`, string(b))
})
}
func TestUnmarshallJSON(t *testing.T) {
t.Run("int key", func(t *testing.T) {
data := `{"1":"bar","7":"baz","2":28,"3":100,"4":"baz","5":"28","6":"100","8":"baz"}`
om := New[int, any]()
require.NoError(t, json.Unmarshal([]byte(data), &om))
assertOrderedPairsEqual(t, om,
[]int{1, 7, 2, 3, 4, 5, 6, 8},
[]any{"bar", "baz", float64(28), float64(100), "baz", "28", "100", "baz"})
})
t.Run("string key", func(t *testing.T) {
data := `{"test":"bar","abc":true}`
om := New[string, any]()
require.NoError(t, json.Unmarshal([]byte(data), &om))
assertOrderedPairsEqual(t, om,
[]string{"test", "abc"},
[]any{"bar", true})
})
t.Run("typed string key", func(t *testing.T) {
data := `{"test":"bar","abc":true}`
type myString string
om := New[myString, any]()
require.NoError(t, json.Unmarshal([]byte(data), &om))
assertOrderedPairsEqual(t, om,
[]myString{"test", "abc"},
[]any{"bar", true})
})
t.Run("typed int key", func(t *testing.T) {
data := `{"1":"bar","7":"baz","2":28,"3":100,"4":"baz","5":"28","6":"100","8":"baz"}`
type myInt uint32
om := New[myInt, any]()
require.NoError(t, json.Unmarshal([]byte(data), &om))
assertOrderedPairsEqual(t, om,
[]myInt{1, 7, 2, 3, 4, 5, 6, 8},
[]any{"bar", "baz", float64(28), float64(100), "baz", "28", "100", "baz"})
})
t.Run("TextUnmarshaler key", func(t *testing.T) {
data := `{"#1#":"bar","#28#":true}`
om := New[marshallable, any]()
require.NoError(t, json.Unmarshal([]byte(data), &om))
assertOrderedPairsEqual(t, om,
[]marshallable{1, 28},
[]any{"bar", true})
})
t.Run("when fed with an input that's not an object", func(t *testing.T) {
for _, data := range []string{"true", `["foo"]`, "42", `"foo"`} {
om := New[int, any]()
require.Error(t, json.Unmarshal([]byte(data), &om))
}
})
t.Run("empty map", func(t *testing.T) {
data := `{}`
om := New[int, any]()
require.NoError(t, json.Unmarshal([]byte(data), &om))
assertLenEqual(t, om, 0)
})
}
// const specialCharacters = "\\\\/\"\b\f\n\r\t\x00\uffff\ufffd世界\u007f\u00ff\U0010FFFF"
const specialCharacters = "\uffff\ufffd世界\u007f\u00ff\U0010FFFF"
func TestJSONSpecialCharacters(t *testing.T) {
baselineMap := map[string]any{specialCharacters: specialCharacters}
baselineData, err := json.Marshal(baselineMap)
require.NoError(t, err) // baseline proves this key is supported by official json library
t.Logf("specialCharacters: %#v as []rune:%v", specialCharacters, []rune(specialCharacters))
t.Logf("baseline json data: %s", baselineData)
t.Run("marshal special characters", func(t *testing.T) {
om := New[string, any]()
om.Set(specialCharacters, specialCharacters)
b, err := json.Marshal(om)
require.NoError(t, err)
require.Equal(t, baselineData, b)
type myString string
om2 := New[myString, myString]()
om2.Set(specialCharacters, specialCharacters)
b, err = json.Marshal(om2)
require.NoError(t, err)
require.Equal(t, baselineData, b)
})
t.Run("unmarshall special characters", func(t *testing.T) {
om := New[string, any]()
require.NoError(t, json.Unmarshal(baselineData, &om))
assertOrderedPairsEqual(t, om,
[]string{specialCharacters},
[]any{specialCharacters})
type myString string
om2 := New[myString, myString]()
require.NoError(t, json.Unmarshal(baselineData, &om2))
assertOrderedPairsEqual(t, om2,
[]myString{specialCharacters},
[]myString{specialCharacters})
})
}
// to test structs that have nested map fields
type nestedMaps struct {
X int `json:"x" yaml:"x"`
M *OrderedMap[string, []*OrderedMap[int, *OrderedMap[string, any]]] `json:"m" yaml:"m"`
}
func TestJSONRoundTrip(t *testing.T) {
for _, testCase := range []struct {
name string
input string
targetFactory func() any
isPrettyPrinted bool
}{
{
name: "",
input: `{
"x": 28,
"m": {
"foo": [
{
"12": {
"i": 12,
"b": true,
"n": null,
"m": {
"a": "b",
"c": 28
}
},
"28": {
"a": false,
"b": [
1,
2,
3
]
}
},
{
"3": {
"c": null,
"d": 87
},
"4": {
"e": true
},
"5": {
"f": 4,
"g": 5,
"h": 6
}
}
],
"bar": [
{
"5": {
"foo": "bar"
}
}
]
}
}`,
targetFactory: func() any { return &nestedMaps{} },
isPrettyPrinted: true,
},
{
name: "with UTF-8 special chars in key",
input: `{"<22>":0}`,
targetFactory: func() any { return &OrderedMap[string, int]{} },
},
} {
t.Run(testCase.name, func(t *testing.T) {
target := testCase.targetFactory()
require.NoError(t, json.Unmarshal([]byte(testCase.input), target))
var (
out []byte
err error
)
if testCase.isPrettyPrinted {
out, err = json.MarshalIndent(target, "", " ")
} else {
out, err = json.Marshal(target)
}
if assert.NoError(t, err) {
assert.Equal(t, strings.TrimSpace(testCase.input), string(out))
}
})
}
}
func BenchmarkMarshalJSON(b *testing.B) {
om := New[int, any]()
om.Set(1, "bar")
om.Set(7, "baz")
om.Set(2, 28)
om.Set(3, 100)
om.Set(4, "baz")
om.Set(5, "28")
om.Set(6, "100")
om.Set(8, "baz")
om.Set(8, "baz")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = json.Marshal(om)
}
}

View File

@@ -0,0 +1,295 @@
// Package orderedmap implements an ordered map, i.e. a map that also keeps track of
// the order in which keys were inserted.
//
// All operations are constant-time.
//
// Github repo: https://github.com/wk8/go-ordered-map
package orderedmap
import (
"fmt"
list "github.com/bahlo/generic-list-go"
)
type Pair[K comparable, V any] struct {
Key K
Value V
element *list.Element[*Pair[K, V]]
}
type OrderedMap[K comparable, V any] struct {
pairs map[K]*Pair[K, V]
list *list.List[*Pair[K, V]]
}
type initConfig[K comparable, V any] struct {
capacity int
initialData []Pair[K, V]
}
type InitOption[K comparable, V any] func(config *initConfig[K, V])
// WithCapacity allows giving a capacity hint for the map, akin to the standard make(map[K]V, capacity).
func WithCapacity[K comparable, V any](capacity int) InitOption[K, V] {
return func(c *initConfig[K, V]) {
c.capacity = capacity
}
}
// WithInitialData allows passing in initial data for the map.
func WithInitialData[K comparable, V any](initialData ...Pair[K, V]) InitOption[K, V] {
return func(c *initConfig[K, V]) {
c.initialData = initialData
if c.capacity < len(initialData) {
c.capacity = len(initialData)
}
}
}
// New creates a new OrderedMap.
// options can either be one or several InitOption[K, V], or a single integer,
// which is then interpreted as a capacity hint, à la make(map[K]V, capacity).
func New[K comparable, V any](options ...any) *OrderedMap[K, V] { //nolint:varnamelen
orderedMap := &OrderedMap[K, V]{}
var config initConfig[K, V]
for _, untypedOption := range options {
switch option := untypedOption.(type) {
case int:
if len(options) != 1 {
invalidOption()
}
config.capacity = option
case InitOption[K, V]:
option(&config)
default:
invalidOption()
}
}
orderedMap.initialize(config.capacity)
orderedMap.AddPairs(config.initialData...)
return orderedMap
}
const invalidOptionMessage = `when using orderedmap.New[K,V]() with options, either provide one or several InitOption[K, V]; or a single integer which is then interpreted as a capacity hint, à la make(map[K]V, capacity).` //nolint:lll
func invalidOption() { panic(invalidOptionMessage) }
func (om *OrderedMap[K, V]) initialize(capacity int) {
om.pairs = make(map[K]*Pair[K, V], capacity)
om.list = list.New[*Pair[K, V]]()
}
// Get looks for the given key, and returns the value associated with it,
// or V's nil value if not found. The boolean it returns says whether the key is present in the map.
func (om *OrderedMap[K, V]) Get(key K) (val V, present bool) {
if pair, present := om.pairs[key]; present {
return pair.Value, true
}
return
}
// Load is an alias for Get, mostly to present an API similar to `sync.Map`'s.
func (om *OrderedMap[K, V]) Load(key K) (V, bool) {
return om.Get(key)
}
// Value returns the value associated with the given key or the zero value.
func (om *OrderedMap[K, V]) Value(key K) (val V) {
if pair, present := om.pairs[key]; present {
val = pair.Value
}
return
}
// GetPair looks for the given key, and returns the pair associated with it,
// or nil if not found. The Pair struct can then be used to iterate over the ordered map
// from that point, either forward or backward.
func (om *OrderedMap[K, V]) GetPair(key K) *Pair[K, V] {
return om.pairs[key]
}
// Set sets the key-value pair, and returns what `Get` would have returned
// on that key prior to the call to `Set`.
func (om *OrderedMap[K, V]) Set(key K, value V) (val V, present bool) {
if pair, present := om.pairs[key]; present {
oldValue := pair.Value
pair.Value = value
return oldValue, true
}
pair := &Pair[K, V]{
Key: key,
Value: value,
}
pair.element = om.list.PushBack(pair)
om.pairs[key] = pair
return
}
// AddPairs allows setting multiple pairs at a time. It's equivalent to calling
// Set on each pair sequentially.
func (om *OrderedMap[K, V]) AddPairs(pairs ...Pair[K, V]) {
for _, pair := range pairs {
om.Set(pair.Key, pair.Value)
}
}
// Store is an alias for Set, mostly to present an API similar to `sync.Map`'s.
func (om *OrderedMap[K, V]) Store(key K, value V) (V, bool) {
return om.Set(key, value)
}
// Delete removes the key-value pair, and returns what `Get` would have returned
// on that key prior to the call to `Delete`.
func (om *OrderedMap[K, V]) Delete(key K) (val V, present bool) {
if pair, present := om.pairs[key]; present {
om.list.Remove(pair.element)
delete(om.pairs, key)
return pair.Value, true
}
return
}
// Len returns the length of the ordered map.
func (om *OrderedMap[K, V]) Len() int {
if om == nil || om.pairs == nil {
return 0
}
return len(om.pairs)
}
// Oldest returns a pointer to the oldest pair. It's meant to be used to iterate on the ordered map's
// pairs from the oldest to the newest, e.g.:
// for pair := orderedMap.Oldest(); pair != nil; pair = pair.Next() { fmt.Printf("%v => %v\n", pair.Key, pair.Value) }
func (om *OrderedMap[K, V]) Oldest() *Pair[K, V] {
if om == nil || om.list == nil {
return nil
}
return listElementToPair(om.list.Front())
}
// Newest returns a pointer to the newest pair. It's meant to be used to iterate on the ordered map's
// pairs from the newest to the oldest, e.g.:
// for pair := orderedMap.Oldest(); pair != nil; pair = pair.Next() { fmt.Printf("%v => %v\n", pair.Key, pair.Value) }
func (om *OrderedMap[K, V]) Newest() *Pair[K, V] {
if om == nil || om.list == nil {
return nil
}
return listElementToPair(om.list.Back())
}
// Next returns a pointer to the next pair.
func (p *Pair[K, V]) Next() *Pair[K, V] {
return listElementToPair(p.element.Next())
}
// Prev returns a pointer to the previous pair.
func (p *Pair[K, V]) Prev() *Pair[K, V] {
return listElementToPair(p.element.Prev())
}
func listElementToPair[K comparable, V any](element *list.Element[*Pair[K, V]]) *Pair[K, V] {
if element == nil {
return nil
}
return element.Value
}
// KeyNotFoundError may be returned by functions in this package when they're called with keys that are not present
// in the map.
type KeyNotFoundError[K comparable] struct {
MissingKey K
}
func (e *KeyNotFoundError[K]) Error() string {
return fmt.Sprintf("missing key: %v", e.MissingKey)
}
// MoveAfter moves the value associated with key to its new position after the one associated with markKey.
// Returns an error iff key or markKey are not present in the map. If an error is returned,
// it will be a KeyNotFoundError.
func (om *OrderedMap[K, V]) MoveAfter(key, markKey K) error {
elements, err := om.getElements(key, markKey)
if err != nil {
return err
}
om.list.MoveAfter(elements[0], elements[1])
return nil
}
// MoveBefore moves the value associated with key to its new position before the one associated with markKey.
// Returns an error iff key or markKey are not present in the map. If an error is returned,
// it will be a KeyNotFoundError.
func (om *OrderedMap[K, V]) MoveBefore(key, markKey K) error {
elements, err := om.getElements(key, markKey)
if err != nil {
return err
}
om.list.MoveBefore(elements[0], elements[1])
return nil
}
func (om *OrderedMap[K, V]) getElements(keys ...K) ([]*list.Element[*Pair[K, V]], error) {
elements := make([]*list.Element[*Pair[K, V]], len(keys))
for i, k := range keys {
pair, present := om.pairs[k]
if !present {
return nil, &KeyNotFoundError[K]{k}
}
elements[i] = pair.element
}
return elements, nil
}
// MoveToBack moves the value associated with key to the back of the ordered map,
// i.e. makes it the newest pair in the map.
// Returns an error iff key is not present in the map. If an error is returned,
// it will be a KeyNotFoundError.
func (om *OrderedMap[K, V]) MoveToBack(key K) error {
_, err := om.GetAndMoveToBack(key)
return err
}
// MoveToFront moves the value associated with key to the front of the ordered map,
// i.e. makes it the oldest pair in the map.
// Returns an error iff key is not present in the map. If an error is returned,
// it will be a KeyNotFoundError.
func (om *OrderedMap[K, V]) MoveToFront(key K) error {
_, err := om.GetAndMoveToFront(key)
return err
}
// GetAndMoveToBack combines Get and MoveToBack in the same call. If an error is returned,
// it will be a KeyNotFoundError.
func (om *OrderedMap[K, V]) GetAndMoveToBack(key K) (val V, err error) {
if pair, present := om.pairs[key]; present {
val = pair.Value
om.list.MoveToBack(pair.element)
} else {
err = &KeyNotFoundError[K]{key}
}
return
}
// GetAndMoveToFront combines Get and MoveToFront in the same call. If an error is returned,
// it will be a KeyNotFoundError.
func (om *OrderedMap[K, V]) GetAndMoveToFront(key K) (val V, err error) {
if pair, present := om.pairs[key]; present {
val = pair.Value
om.list.MoveToFront(pair.element)
} else {
err = &KeyNotFoundError[K]{key}
}
return
}

View File

@@ -0,0 +1,384 @@
package orderedmap
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestBasicFeatures(t *testing.T) {
n := 100
om := New[int, int]()
// set(i, 2 * i)
for i := 0; i < n; i++ {
assertLenEqual(t, om, i)
oldValue, present := om.Set(i, 2*i)
assertLenEqual(t, om, i+1)
assert.Equal(t, 0, oldValue)
assert.False(t, present)
}
// get what we just set
for i := 0; i < n; i++ {
value, present := om.Get(i)
assert.Equal(t, 2*i, value)
assert.Equal(t, value, om.Value(i))
assert.True(t, present)
}
// get pairs of what we just set
for i := 0; i < n; i++ {
pair := om.GetPair(i)
assert.NotNil(t, pair)
assert.Equal(t, 2*i, pair.Value)
}
// forward iteration
i := 0
for pair := om.Oldest(); pair != nil; pair = pair.Next() {
assert.Equal(t, i, pair.Key)
assert.Equal(t, 2*i, pair.Value)
i++
}
// backward iteration
i = n - 1
for pair := om.Newest(); pair != nil; pair = pair.Prev() {
assert.Equal(t, i, pair.Key)
assert.Equal(t, 2*i, pair.Value)
i--
}
// forward iteration starting from known key
i = 42
for pair := om.GetPair(i); pair != nil; pair = pair.Next() {
assert.Equal(t, i, pair.Key)
assert.Equal(t, 2*i, pair.Value)
i++
}
// double values for pairs with even keys
for j := 0; j < n/2; j++ {
i = 2 * j
oldValue, present := om.Set(i, 4*i)
assert.Equal(t, 2*i, oldValue)
assert.True(t, present)
}
// and delete pairs with odd keys
for j := 0; j < n/2; j++ {
i = 2*j + 1
assertLenEqual(t, om, n-j)
value, present := om.Delete(i)
assertLenEqual(t, om, n-j-1)
assert.Equal(t, 2*i, value)
assert.True(t, present)
// deleting again shouldn't change anything
value, present = om.Delete(i)
assertLenEqual(t, om, n-j-1)
assert.Equal(t, 0, value)
assert.False(t, present)
}
// get the whole range
for j := 0; j < n/2; j++ {
i = 2 * j
value, present := om.Get(i)
assert.Equal(t, 4*i, value)
assert.Equal(t, value, om.Value(i))
assert.True(t, present)
i = 2*j + 1
value, present = om.Get(i)
assert.Equal(t, 0, value)
assert.Equal(t, value, om.Value(i))
assert.False(t, present)
}
// check iterations again
i = 0
for pair := om.Oldest(); pair != nil; pair = pair.Next() {
assert.Equal(t, i, pair.Key)
assert.Equal(t, 4*i, pair.Value)
i += 2
}
i = 2 * ((n - 1) / 2)
for pair := om.Newest(); pair != nil; pair = pair.Prev() {
assert.Equal(t, i, pair.Key)
assert.Equal(t, 4*i, pair.Value)
i -= 2
}
}
func TestUpdatingDoesntChangePairsOrder(t *testing.T) {
om := New[string, any]()
om.Set("foo", "bar")
om.Set("wk", 28)
om.Set("po", 100)
om.Set("bar", "baz")
oldValue, present := om.Set("po", 102)
assert.Equal(t, 100, oldValue)
assert.True(t, present)
assertOrderedPairsEqual(t, om,
[]string{"foo", "wk", "po", "bar"},
[]any{"bar", 28, 102, "baz"})
}
func TestDeletingAndReinsertingChangesPairsOrder(t *testing.T) {
om := New[string, any]()
om.Set("foo", "bar")
om.Set("wk", 28)
om.Set("po", 100)
om.Set("bar", "baz")
// delete a pair
oldValue, present := om.Delete("po")
assert.Equal(t, 100, oldValue)
assert.True(t, present)
// re-insert the same pair
oldValue, present = om.Set("po", 100)
assert.Nil(t, oldValue)
assert.False(t, present)
assertOrderedPairsEqual(t, om,
[]string{"foo", "wk", "bar", "po"},
[]any{"bar", 28, "baz", 100})
}
func TestEmptyMapOperations(t *testing.T) {
om := New[string, any]()
oldValue, present := om.Get("foo")
assert.Nil(t, oldValue)
assert.Nil(t, om.Value("foo"))
assert.False(t, present)
oldValue, present = om.Delete("bar")
assert.Nil(t, oldValue)
assert.False(t, present)
assertLenEqual(t, om, 0)
assert.Nil(t, om.Oldest())
assert.Nil(t, om.Newest())
}
type dummyTestStruct struct {
value string
}
func TestPackUnpackStructs(t *testing.T) {
om := New[string, dummyTestStruct]()
om.Set("foo", dummyTestStruct{"foo!"})
om.Set("bar", dummyTestStruct{"bar!"})
value, present := om.Get("foo")
assert.True(t, present)
assert.Equal(t, value, om.Value("foo"))
if assert.NotNil(t, value) {
assert.Equal(t, "foo!", value.value)
}
value, present = om.Set("bar", dummyTestStruct{"baz!"})
assert.True(t, present)
if assert.NotNil(t, value) {
assert.Equal(t, "bar!", value.value)
}
value, present = om.Get("bar")
assert.Equal(t, value, om.Value("bar"))
assert.True(t, present)
if assert.NotNil(t, value) {
assert.Equal(t, "baz!", value.value)
}
}
// shamelessly stolen from https://github.com/python/cpython/blob/e19a91e45fd54a56e39c2d12e6aaf4757030507f/Lib/test/test_ordered_dict.py#L55-L61
func TestShuffle(t *testing.T) {
ranLen := 100
for _, n := range []int{0, 10, 20, 100, 1000, 10000} {
t.Run(fmt.Sprintf("shuffle test with %d items", n), func(t *testing.T) {
om := New[string, string]()
keys := make([]string, n)
values := make([]string, n)
for i := 0; i < n; i++ {
// we prefix with the number to ensure that we don't get any duplicates
keys[i] = fmt.Sprintf("%d_%s", i, randomHexString(t, ranLen))
values[i] = randomHexString(t, ranLen)
value, present := om.Set(keys[i], values[i])
assert.Equal(t, "", value)
assert.False(t, present)
}
assertOrderedPairsEqual(t, om, keys, values)
})
}
}
func TestMove(t *testing.T) {
om := New[int, any]()
om.Set(1, "bar")
om.Set(2, 28)
om.Set(3, 100)
om.Set(4, "baz")
om.Set(5, "28")
om.Set(6, "100")
om.Set(7, "baz")
om.Set(8, "baz")
err := om.MoveAfter(2, 3)
assert.Nil(t, err)
assertOrderedPairsEqual(t, om,
[]int{1, 3, 2, 4, 5, 6, 7, 8},
[]any{"bar", 100, 28, "baz", "28", "100", "baz", "baz"})
err = om.MoveBefore(6, 4)
assert.Nil(t, err)
assertOrderedPairsEqual(t, om,
[]int{1, 3, 2, 6, 4, 5, 7, 8},
[]any{"bar", 100, 28, "100", "baz", "28", "baz", "baz"})
err = om.MoveToBack(3)
assert.Nil(t, err)
assertOrderedPairsEqual(t, om,
[]int{1, 2, 6, 4, 5, 7, 8, 3},
[]any{"bar", 28, "100", "baz", "28", "baz", "baz", 100})
err = om.MoveToFront(5)
assert.Nil(t, err)
assertOrderedPairsEqual(t, om,
[]int{5, 1, 2, 6, 4, 7, 8, 3},
[]any{"28", "bar", 28, "100", "baz", "baz", "baz", 100})
err = om.MoveToFront(100)
assert.Equal(t, &KeyNotFoundError[int]{100}, err)
}
func TestGetAndMove(t *testing.T) {
om := New[int, any]()
om.Set(1, "bar")
om.Set(2, 28)
om.Set(3, 100)
om.Set(4, "baz")
om.Set(5, "28")
om.Set(6, "100")
om.Set(7, "baz")
om.Set(8, "baz")
value, err := om.GetAndMoveToBack(3)
assert.Nil(t, err)
assert.Equal(t, 100, value)
assertOrderedPairsEqual(t, om,
[]int{1, 2, 4, 5, 6, 7, 8, 3},
[]any{"bar", 28, "baz", "28", "100", "baz", "baz", 100})
value, err = om.GetAndMoveToFront(5)
assert.Nil(t, err)
assert.Equal(t, "28", value)
assertOrderedPairsEqual(t, om,
[]int{5, 1, 2, 4, 6, 7, 8, 3},
[]any{"28", "bar", 28, "baz", "100", "baz", "baz", 100})
value, err = om.GetAndMoveToBack(100)
assert.Equal(t, &KeyNotFoundError[int]{100}, err)
}
func TestAddPairs(t *testing.T) {
om := New[int, any]()
om.AddPairs(
Pair[int, any]{
Key: 28,
Value: "foo",
},
Pair[int, any]{
Key: 12,
Value: "bar",
},
Pair[int, any]{
Key: 28,
Value: "baz",
},
)
assertOrderedPairsEqual(t, om,
[]int{28, 12},
[]any{"baz", "bar"})
}
// sadly, we can't test the "actual" capacity here, see https://github.com/golang/go/issues/52157
func TestNewWithCapacity(t *testing.T) {
zero := New[int, string](0)
assert.Empty(t, zero.Len())
assert.PanicsWithValue(t, invalidOptionMessage, func() {
_ = New[int, string](1, 2)
})
assert.PanicsWithValue(t, invalidOptionMessage, func() {
_ = New[int, string](1, 2, 3)
})
om := New[int, string](-1)
om.Set(1337, "quarante-deux")
assert.Equal(t, 1, om.Len())
}
func TestNewWithOptions(t *testing.T) {
t.Run("wih capacity", func(t *testing.T) {
om := New[string, any](WithCapacity[string, any](98))
assert.Equal(t, 0, om.Len())
})
t.Run("with initial data", func(t *testing.T) {
om := New[string, int](WithInitialData(
Pair[string, int]{
Key: "a",
Value: 1,
},
Pair[string, int]{
Key: "b",
Value: 2,
},
Pair[string, int]{
Key: "c",
Value: 3,
},
))
assertOrderedPairsEqual(t, om,
[]string{"a", "b", "c"},
[]int{1, 2, 3})
})
t.Run("with an invalid option type", func(t *testing.T) {
assert.PanicsWithValue(t, invalidOptionMessage, func() {
_ = New[int, string]("foo")
})
})
}
func TestNilMap(t *testing.T) {
// we want certain behaviors of a nil ordered map to be the same as they are for standard nil maps
var om *OrderedMap[int, any]
t.Run("len", func(t *testing.T) {
assert.Equal(t, 0, om.Len())
})
t.Run("iterating - akin to range", func(t *testing.T) {
assert.Nil(t, om.Oldest())
assert.Nil(t, om.Newest())
})
}

View File

@@ -0,0 +1,76 @@
package orderedmap
import (
"crypto/rand"
"encoding/hex"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
// assertOrderedPairsEqual asserts that the map contains the given keys and values
// from oldest to newest.
func assertOrderedPairsEqual[K comparable, V any](
t *testing.T, orderedMap *OrderedMap[K, V], expectedKeys []K, expectedValues []V,
) {
t.Helper()
assertOrderedPairsEqualFromNewest(t, orderedMap, expectedKeys, expectedValues)
assertOrderedPairsEqualFromOldest(t, orderedMap, expectedKeys, expectedValues)
}
func assertOrderedPairsEqualFromNewest[K comparable, V any](
t *testing.T, orderedMap *OrderedMap[K, V], expectedKeys []K, expectedValues []V,
) {
t.Helper()
if assert.Equal(t, len(expectedKeys), len(expectedValues)) && assert.Equal(t, len(expectedKeys), orderedMap.Len()) {
i := orderedMap.Len() - 1
for pair := orderedMap.Newest(); pair != nil; pair = pair.Prev() {
assert.Equal(t, expectedKeys[i], pair.Key, "from newest index=%d on key", i)
assert.Equal(t, expectedValues[i], pair.Value, "from newest index=%d on value", i)
i--
}
}
}
func assertOrderedPairsEqualFromOldest[K comparable, V any](
t *testing.T, orderedMap *OrderedMap[K, V], expectedKeys []K, expectedValues []V,
) {
t.Helper()
if assert.Equal(t, len(expectedKeys), len(expectedValues)) && assert.Equal(t, len(expectedKeys), orderedMap.Len()) {
i := 0
for pair := orderedMap.Oldest(); pair != nil; pair = pair.Next() {
assert.Equal(t, expectedKeys[i], pair.Key, "from oldest index=%d on key", i)
assert.Equal(t, expectedValues[i], pair.Value, "from oldest index=%d on value", i)
i++
}
}
}
func assertLenEqual[K comparable, V any](t *testing.T, orderedMap *OrderedMap[K, V], expectedLen int) {
t.Helper()
assert.Equal(t, expectedLen, orderedMap.Len())
// also check the list length, for good measure
assert.Equal(t, expectedLen, orderedMap.list.Len())
}
func randomHexString(t *testing.T, length int) string {
t.Helper()
b := length / 2 //nolint:gomnd
randBytes := make([]byte, b)
if n, err := rand.Read(randBytes); err != nil || n != b {
if err == nil {
err = fmt.Errorf("only got %v random bytes, expected %v", n, b)
}
t.Fatal(err)
}
return hex.EncodeToString(randBytes)
}

71
common/orderedmap/yaml.go Normal file
View File

@@ -0,0 +1,71 @@
package orderedmap
import (
"fmt"
"gopkg.in/yaml.v3"
)
var (
_ yaml.Marshaler = &OrderedMap[int, any]{}
_ yaml.Unmarshaler = &OrderedMap[int, any]{}
)
// MarshalYAML implements the yaml.Marshaler interface.
func (om *OrderedMap[K, V]) MarshalYAML() (interface{}, error) {
if om == nil {
return []byte("null"), nil
}
node := yaml.Node{
Kind: yaml.MappingNode,
}
for pair := om.Oldest(); pair != nil; pair = pair.Next() {
key, value := pair.Key, pair.Value
keyNode := &yaml.Node{}
// serialize key to yaml, then deserialize it back into the node
// this is a hack to get the correct tag for the key
if err := keyNode.Encode(key); err != nil {
return nil, err
}
valueNode := &yaml.Node{}
if err := valueNode.Encode(value); err != nil {
return nil, err
}
node.Content = append(node.Content, keyNode, valueNode)
}
return &node, nil
}
// UnmarshalYAML implements the yaml.Unmarshaler interface.
func (om *OrderedMap[K, V]) UnmarshalYAML(value *yaml.Node) error {
if value.Kind != yaml.MappingNode {
return fmt.Errorf("pipeline must contain YAML mapping, has %v", value.Kind)
}
if om.list == nil {
om.initialize(0)
}
for index := 0; index < len(value.Content); index += 2 {
var key K
var val V
if err := value.Content[index].Decode(&key); err != nil {
return err
}
if err := value.Content[index+1].Decode(&val); err != nil {
return err
}
om.Set(key, val)
}
return nil
}

View File

@@ -0,0 +1,82 @@
package orderedmap
// Adapted from https://github.com/dvyukov/go-fuzz-corpus/blob/c42c1b2/json/json.go
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
)
func FuzzRoundTripYAML(f *testing.F) {
f.Fuzz(func(t *testing.T, data []byte) {
for _, testCase := range []struct {
name string
constructor func() any
// should be a function that asserts that 2 objects of the type returned by constructor are equal
equalityAssertion func(*testing.T, any, any) bool
}{
{
name: "with a string -> string map",
constructor: func() any { return &OrderedMap[string, string]{} },
equalityAssertion: assertOrderedMapsEqual[string, string],
},
{
name: "with a string -> int map",
constructor: func() any { return &OrderedMap[string, int]{} },
equalityAssertion: assertOrderedMapsEqual[string, int],
},
{
name: "with a string -> any map",
constructor: func() any { return &OrderedMap[string, any]{} },
equalityAssertion: assertOrderedMapsEqual[string, any],
},
{
name: "with a struct with map fields",
constructor: func() any { return new(testFuzzStruct) },
equalityAssertion: assertTestFuzzStructEqual,
},
} {
t.Run(testCase.name, func(t *testing.T) {
v1 := testCase.constructor()
if yaml.Unmarshal(data, v1) != nil {
return
}
t.Log(data)
t.Log(v1)
yamlData, err := yaml.Marshal(v1)
require.NoError(t, err)
t.Log(string(yamlData))
v2 := testCase.constructor()
err = yaml.Unmarshal(yamlData, v2)
if err != nil {
t.Log(string(yamlData))
t.Fatal(err)
}
if !assert.True(t, testCase.equalityAssertion(t, v1, v2), "failed with input data %q", string(data)) {
// look at that what the standard lib does with regular map, to help with debugging
var m1 map[string]any
require.NoError(t, yaml.Unmarshal(data, &m1))
mapJsonData, err := yaml.Marshal(m1)
require.NoError(t, err)
var m2 map[string]any
require.NoError(t, yaml.Unmarshal(mapJsonData, &m2))
t.Logf("initial data = %s", string(data))
t.Logf("unmarshalled map = %v", m1)
t.Logf("re-marshalled from map = %s", string(mapJsonData))
t.Logf("re-marshalled from test obj = %s", string(yamlData))
t.Logf("re-unmarshalled map = %s", m2)
}
})
}
})
}

View File

@@ -0,0 +1,334 @@
package orderedmap
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
)
func TestMarshalYAML(t *testing.T) {
t.Run("int key", func(t *testing.T) {
om := New[int, any]()
om.Set(1, "bar")
om.Set(7, "baz")
om.Set(2, 28)
om.Set(3, 100)
om.Set(4, "baz")
om.Set(5, "28")
om.Set(6, "100")
om.Set(8, "baz")
om.Set(8, "baz")
om.Set(9, "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Quisque auctor augue accumsan mi maximus, quis viverra massa pretium. Phasellus imperdiet sapien a interdum sollicitudin. Duis at commodo lectus, a lacinia sem.")
b, err := yaml.Marshal(om)
expected := `1: bar
7: baz
2: 28
3: 100
4: baz
5: "28"
6: "100"
8: baz
9: Lorem ipsum dolor sit amet, consectetur adipiscing elit. Quisque auctor augue accumsan mi maximus, quis viverra massa pretium. Phasellus imperdiet sapien a interdum sollicitudin. Duis at commodo lectus, a lacinia sem.
`
assert.NoError(t, err)
assert.Equal(t, expected, string(b))
})
t.Run("string key", func(t *testing.T) {
om := New[string, any]()
om.Set("test", "bar")
om.Set("abc", true)
b, err := yaml.Marshal(om)
assert.NoError(t, err)
expected := `test: bar
abc: true
`
assert.Equal(t, expected, string(b))
})
t.Run("typed string key", func(t *testing.T) {
type myString string
om := New[myString, any]()
om.Set("test", "bar")
om.Set("abc", true)
b, err := yaml.Marshal(om)
assert.NoError(t, err)
assert.Equal(t, `test: bar
abc: true
`, string(b))
})
t.Run("typed int key", func(t *testing.T) {
type myInt uint32
om := New[myInt, any]()
om.Set(1, "bar")
om.Set(7, "baz")
om.Set(2, 28)
om.Set(3, 100)
om.Set(4, "baz")
b, err := yaml.Marshal(om)
assert.NoError(t, err)
assert.Equal(t, `1: bar
7: baz
2: 28
3: 100
4: baz
`, string(b))
})
t.Run("TextMarshaller key", func(t *testing.T) {
om := New[marshallable, any]()
om.Set(marshallable(1), "bar")
om.Set(marshallable(28), true)
b, err := yaml.Marshal(om)
assert.NoError(t, err)
assert.Equal(t, `'#1#': bar
'#28#': true
`, string(b))
})
t.Run("empty map with 0 elements", func(t *testing.T) {
om := New[string, any]()
b, err := yaml.Marshal(om)
assert.NoError(t, err)
assert.Equal(t, "{}\n", string(b))
})
t.Run("empty map with no elements (null)", func(t *testing.T) {
om := &OrderedMap[string, string]{}
b, err := yaml.Marshal(om)
assert.NoError(t, err)
assert.Equal(t, "{}\n", string(b))
})
}
func TestUnmarshallYAML(t *testing.T) {
t.Run("int key", func(t *testing.T) {
data := `
1: bar
7: baz
2: 28
3: 100
4: baz
5: "28"
6: "100"
8: baz
`
om := New[int, any]()
require.NoError(t, yaml.Unmarshal([]byte(data), &om))
assertOrderedPairsEqual(t, om,
[]int{1, 7, 2, 3, 4, 5, 6, 8},
[]any{"bar", "baz", 28, 100, "baz", "28", "100", "baz"})
// serialize back to yaml to make sure things are equal
})
t.Run("string key", func(t *testing.T) {
data := `{"test":"bar","abc":true}`
om := New[string, any]()
require.NoError(t, yaml.Unmarshal([]byte(data), &om))
assertOrderedPairsEqual(t, om,
[]string{"test", "abc"},
[]any{"bar", true})
})
t.Run("typed string key", func(t *testing.T) {
data := `{"test":"bar","abc":true}`
type myString string
om := New[myString, any]()
require.NoError(t, yaml.Unmarshal([]byte(data), &om))
assertOrderedPairsEqual(t, om,
[]myString{"test", "abc"},
[]any{"bar", true})
})
t.Run("typed int key", func(t *testing.T) {
data := `
1: bar
7: baz
2: 28
3: 100
4: baz
5: "28"
6: "100"
8: baz
`
type myInt uint32
om := New[myInt, any]()
require.NoError(t, yaml.Unmarshal([]byte(data), &om))
assertOrderedPairsEqual(t, om,
[]myInt{1, 7, 2, 3, 4, 5, 6, 8},
[]any{"bar", "baz", 28, 100, "baz", "28", "100", "baz"})
})
t.Run("TextUnmarshaler key", func(t *testing.T) {
data := `{"#1#":"bar","#28#":true}`
om := New[marshallable, any]()
require.NoError(t, yaml.Unmarshal([]byte(data), &om))
assertOrderedPairsEqual(t, om,
[]marshallable{1, 28},
[]any{"bar", true})
})
t.Run("when fed with an input that's not an object", func(t *testing.T) {
for _, data := range []string{"true", `["foo"]`, "42", `"foo"`} {
om := New[int, any]()
require.Error(t, yaml.Unmarshal([]byte(data), &om))
}
})
t.Run("empty map", func(t *testing.T) {
data := `{}`
om := New[int, any]()
require.NoError(t, yaml.Unmarshal([]byte(data), &om))
assertLenEqual(t, om, 0)
})
}
func TestYAMLSpecialCharacters(t *testing.T) {
baselineMap := map[string]any{specialCharacters: specialCharacters}
baselineData, err := yaml.Marshal(baselineMap)
require.NoError(t, err) // baseline proves this key is supported by official yaml library
t.Logf("specialCharacters: %#v as []rune:%v", specialCharacters, []rune(specialCharacters))
t.Logf("baseline yaml data: %s", baselineData)
t.Run("marshal special characters", func(t *testing.T) {
om := New[string, any]()
om.Set(specialCharacters, specialCharacters)
b, err := yaml.Marshal(om)
require.NoError(t, err)
require.Equal(t, baselineData, b)
type myString string
om2 := New[myString, myString]()
om2.Set(specialCharacters, specialCharacters)
b, err = yaml.Marshal(om2)
require.NoError(t, err)
require.Equal(t, baselineData, b)
})
t.Run("unmarshall special characters", func(t *testing.T) {
om := New[string, any]()
require.NoError(t, yaml.Unmarshal(baselineData, &om))
assertOrderedPairsEqual(t, om,
[]string{specialCharacters},
[]any{specialCharacters})
type myString string
om2 := New[myString, myString]()
require.NoError(t, yaml.Unmarshal(baselineData, &om2))
assertOrderedPairsEqual(t, om2,
[]myString{specialCharacters},
[]myString{specialCharacters})
})
}
func TestYAMLRoundTrip(t *testing.T) {
for _, testCase := range []struct {
name string
input string
targetFactory func() any
}{
{
name: "empty map",
input: "{}\n",
targetFactory: func() any {
return &OrderedMap[string, any]{}
},
},
{
name: "",
input: `x: 28
m:
bar:
- 5:
foo: bar
foo:
- 12:
b: true
i: 12
m:
a: b
c: 28
"n": null
28:
a: false
b:
- 1
- 2
- 3
- 3:
c: null
d: 87
4:
e: true
5:
f: 4
g: 5
h: 6
`,
targetFactory: func() any { return &nestedMaps{} },
},
{
name: "with UTF-8 special chars in key",
input: "<22>: 0\n",
targetFactory: func() any { return &OrderedMap[string, int]{} },
},
} {
t.Run(testCase.name, func(t *testing.T) {
target := testCase.targetFactory()
require.NoError(t, yaml.Unmarshal([]byte(testCase.input), target))
var (
out []byte
err error
)
out, err = yaml.Marshal(target)
if assert.NoError(t, err) {
assert.Equal(t, testCase.input, string(out))
}
})
}
}
func BenchmarkMarshalYAML(b *testing.B) {
om := New[int, any]()
om.Set(1, "bar")
om.Set(7, "baz")
om.Set(2, 28)
om.Set(3, 100)
om.Set(4, "baz")
om.Set(5, "28")
om.Set(6, "100")
om.Set(8, "baz")
om.Set(8, "baz")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = yaml.Marshal(om)
}
}

View File

@@ -517,6 +517,10 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e
fieldName = tagValue
}
if tagValue == "-" {
continue
}
rawMapKey := reflect.ValueOf(fieldName)
rawMapVal := dataVal.MapIndex(rawMapKey)
if !rawMapVal.IsValid() {

View File

@@ -308,3 +308,27 @@ func TestStructure_Ignore(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, s.MustIgnore, "oldData")
}
func TestStructure_IgnoreInNest(t *testing.T) {
rawMap := map[string]any{
"-": "newData",
}
type TP struct {
MustIgnore string `test:"-"`
}
s := struct {
TP
}{TP{MustIgnore: "oldData"}}
err := decoder.Decode(rawMap, &s)
assert.Nil(t, err)
assert.Equal(t, s.MustIgnore, "oldData")
// test omitempty
delete(rawMap, "-")
err = decoder.Decode(rawMap, &s)
assert.Nil(t, err)
assert.Equal(t, s.MustIgnore, "oldData")
}

14
common/yaml/yaml.go Normal file
View File

@@ -0,0 +1,14 @@
// Package yaml provides a common entrance for YAML marshaling and unmarshalling.
package yaml
import (
"gopkg.in/yaml.v3"
)
func Unmarshal(in []byte, out any) (err error) {
return yaml.Unmarshal(in, out)
}
func Marshal(in any) (out []byte, err error) {
return yaml.Marshal(in)
}

View File

@@ -1,17 +1,17 @@
package tls
package ca
import (
utls "github.com/metacubex/utls"
"github.com/metacubex/tls"
)
type ClientAuthType = utls.ClientAuthType
type ClientAuthType = tls.ClientAuthType
const (
NoClientCert = utls.NoClientCert
RequestClientCert = utls.RequestClientCert
RequireAnyClientCert = utls.RequireAnyClientCert
VerifyClientCertIfGiven = utls.VerifyClientCertIfGiven
RequireAndVerifyClientCert = utls.RequireAndVerifyClientCert
NoClientCert = tls.NoClientCert
RequestClientCert = tls.RequestClientCert
RequireAnyClientCert = tls.RequireAnyClientCert
VerifyClientCertIfGiven = tls.VerifyClientCertIfGiven
RequireAndVerifyClientCert = tls.RequireAndVerifyClientCert
)
func ClientAuthTypeFromString(s string) ClientAuthType {

View File

@@ -1,7 +1,6 @@
package ca
import (
"crypto/tls"
"crypto/x509"
_ "embed"
"errors"
@@ -11,8 +10,9 @@ import (
"sync"
"github.com/metacubex/mihomo/common/once"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/ntp"
"github.com/metacubex/tls"
)
var globalCertPool *x509.CertPool
@@ -98,20 +98,27 @@ func GetTLSConfig(opt Option) (tlsConfig *tls.Config, err error) {
}
if len(opt.Fingerprint) > 0 {
tlsConfig.VerifyPeerCertificate, err = NewFingerprintVerifier(opt.Fingerprint, tlsConfig.Time)
verifier, err := NewFingerprintVerifier(opt.Fingerprint, tlsConfig.Time)
if err != nil {
return nil, err
}
tlsConfig.VerifyConnection = func(state tls.ConnectionState) error {
// [ConnectionState.ServerName] can return the actual ServerName needed for verification,
// avoiding inconsistencies caused by [tlsConfig.ServerName] being modified after the [NewFingerprintVerifier] call.
// https://github.com/golang/go/issues/36736#issuecomment-587925536
return verifier(state.PeerCertificates, state.ServerName)
}
tlsConfig.InsecureSkipVerify = true
}
if len(opt.Certificate) > 0 || len(opt.PrivateKey) > 0 {
var cert tls.Certificate
cert, err = LoadTLSKeyPair(opt.Certificate, opt.PrivateKey, C.Path)
certLoader, err := NewTLSKeyPairLoader(opt.Certificate, opt.PrivateKey)
if err != nil {
return nil, err
}
tlsConfig.Certificates = []tls.Certificate{cert}
tlsConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
return certLoader()
}
}
return tlsConfig, nil
}

View File

@@ -11,7 +11,7 @@ import (
)
// NewFingerprintVerifier returns a function that verifies whether a certificate's SHA-256 fingerprint matches the given one.
func NewFingerprintVerifier(fingerprint string, time func() time.Time) (func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error, error) {
func NewFingerprintVerifier(fingerprint string, time func() time.Time) (func(certs []*x509.Certificate, serverName string) error, error) {
switch fingerprint {
case "chrome", "firefox", "safari", "ios", "android", "edge", "360", "qq", "random", "randomized": // WTF???
return nil, fmt.Errorf("`fingerprint` is used for TLS certificate pinning. If you need to specify the browser fingerprint, use `client-fingerprint`")
@@ -26,37 +26,24 @@ func NewFingerprintVerifier(fingerprint string, time func() time.Time) (func(raw
return nil, fmt.Errorf("fingerprint string length error,need sha256 fingerprint")
}
return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
return func(certs []*x509.Certificate, serverName string) error {
// ssl pining
for i, rawCert := range rawCerts {
hash := sha256.Sum256(rawCert)
for i, cert := range certs {
hash := sha256.Sum256(cert.Raw)
if bytes.Equal(fpByte, hash[:]) {
if i > 0 {
// When the fingerprint matches a non-leaf certificate,
// the certificate chain validity is verified using the certificate as the trusted root certificate.
//
// Currently, we do not verify that the SNI matches the certificate's DNS name,
// but we do verify the validity of the child certificate,
// including the issuance time and whether the child certificate was issued by the parent certificate.
certs := make([]*x509.Certificate, i+1) // stop at i
for j := range certs {
cert, err := x509.ParseCertificate(rawCerts[j])
if err != nil {
return err
}
certs[j] = cert
}
opts := x509.VerifyOptions{
Roots: x509.NewCertPool(),
Intermediates: x509.NewCertPool(),
DNSName: serverName,
}
if time != nil {
opts.CurrentTime = time()
}
opts.Roots.AddCert(certs[i])
for _, cert := range certs[1:] {
for _, cert := range certs[1 : i+1] { // stop at i
opts.Intermediates.AddCert(cert)
}
_, err := certs[0].Verify(opts)

View File

@@ -1,6 +1,7 @@
package ca
import (
"crypto/x509"
"encoding/pem"
"testing"
"time"
@@ -10,90 +11,203 @@ import (
)
func TestFingerprintVerifierLeaf(t *testing.T) {
leafFingerprint := CalculateFingerprint(leafPEM.Bytes)
verifier, err := NewFingerprintVerifier(leafFingerprint, func() time.Time {
return time.Unix(1677615892, 0)
})
leafFingerprint := CalculateFingerprint(leafCert.Raw)
verifier, err := NewFingerprintVerifier(leafFingerprint, certTime)
require.NoError(t, err)
err = verifier([][]byte{leafPEM.Bytes, intermediatePEM.Bytes, rootPEM.Bytes}, nil)
err = verifier([]*x509.Certificate{leafCert, intermediateCert, rootCert}, "")
assert.NoError(t, err)
err = verifier([][]byte{leafPEM.Bytes, smimeIntermediatePEM.Bytes, rootPEM.Bytes}, nil)
err = verifier([]*x509.Certificate{leafCert, intermediateCert, rootCert}, leafServerName)
assert.NoError(t, err)
err = verifier([][]byte{leafPEM.Bytes, intermediatePEM.Bytes, smimeRootPEM.Bytes}, nil)
err = verifier([]*x509.Certificate{leafCert, intermediateCert, rootCert}, wrongLeafServerName)
assert.NoError(t, err)
err = verifier([][]byte{leafWithInvalidHashPEM.Bytes, intermediatePEM.Bytes, rootPEM.Bytes}, nil)
err = verifier([]*x509.Certificate{leafCert, smimeIntermediateCert, rootCert}, "")
assert.NoError(t, err)
err = verifier([]*x509.Certificate{leafCert, smimeIntermediateCert, rootCert}, leafServerName)
assert.NoError(t, err)
err = verifier([]*x509.Certificate{leafCert, smimeIntermediateCert, rootCert}, wrongLeafServerName)
assert.NoError(t, err)
err = verifier([]*x509.Certificate{leafCert, intermediateCert, smimeRootCert}, "")
assert.NoError(t, err)
err = verifier([]*x509.Certificate{leafCert, intermediateCert, smimeRootCert}, leafServerName)
assert.NoError(t, err)
err = verifier([]*x509.Certificate{leafCert, intermediateCert, smimeRootCert}, wrongLeafServerName)
assert.NoError(t, err)
err = verifier([]*x509.Certificate{leafWithInvalidHashCert, intermediateCert, rootCert}, "")
assert.Error(t, err)
err = verifier([][]byte{smimeLeafPEM.Bytes, intermediatePEM.Bytes, rootPEM.Bytes}, nil)
err = verifier([]*x509.Certificate{leafWithInvalidHashCert, intermediateCert, rootCert}, leafServerName)
assert.Error(t, err)
err = verifier([][]byte{smimeLeafPEM.Bytes, intermediatePEM.Bytes, smimeRootPEM.Bytes}, nil)
err = verifier([]*x509.Certificate{leafWithInvalidHashCert, intermediateCert, rootCert}, wrongLeafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, rootCert}, "")
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, rootCert}, leafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, rootCert}, wrongLeafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, smimeRootCert}, "")
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, smimeRootCert}, leafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, smimeRootCert}, wrongLeafServerName)
assert.Error(t, err)
}
func TestFingerprintVerifierIntermediate(t *testing.T) {
intermediateFingerprint := CalculateFingerprint(intermediatePEM.Bytes)
verifier, err := NewFingerprintVerifier(intermediateFingerprint, func() time.Time {
return time.Unix(1677615892, 0)
})
intermediateFingerprint := CalculateFingerprint(intermediateCert.Raw)
verifier, err := NewFingerprintVerifier(intermediateFingerprint, certTime)
require.NoError(t, err)
err = verifier([][]byte{leafPEM.Bytes, intermediatePEM.Bytes, rootPEM.Bytes}, nil)
err = verifier([]*x509.Certificate{leafCert, intermediateCert, rootCert}, "")
assert.NoError(t, err)
err = verifier([][]byte{leafPEM.Bytes, smimeIntermediatePEM.Bytes, rootPEM.Bytes}, nil)
assert.Error(t, err)
err = verifier([][]byte{leafPEM.Bytes, intermediatePEM.Bytes, smimeRootPEM.Bytes}, nil)
err = verifier([]*x509.Certificate{leafCert, intermediateCert, rootCert}, leafServerName)
assert.NoError(t, err)
err = verifier([][]byte{leafWithInvalidHashPEM.Bytes, intermediatePEM.Bytes, rootPEM.Bytes}, nil)
err = verifier([]*x509.Certificate{leafCert, intermediateCert, rootCert}, wrongLeafServerName)
assert.Error(t, err)
err = verifier([][]byte{smimeLeafPEM.Bytes, intermediatePEM.Bytes, rootPEM.Bytes}, nil)
err = verifier([]*x509.Certificate{leafCert, smimeIntermediateCert, rootCert}, "")
assert.Error(t, err)
err = verifier([][]byte{smimeLeafPEM.Bytes, intermediatePEM.Bytes, smimeRootPEM.Bytes}, nil)
err = verifier([]*x509.Certificate{leafCert, smimeIntermediateCert, rootCert}, leafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{leafCert, smimeIntermediateCert, rootCert}, wrongLeafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{leafCert, intermediateCert, smimeRootCert}, "")
assert.NoError(t, err)
err = verifier([]*x509.Certificate{leafCert, intermediateCert, smimeRootCert}, leafServerName)
assert.NoError(t, err)
err = verifier([]*x509.Certificate{leafCert, intermediateCert, smimeRootCert}, wrongLeafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{leafWithInvalidHashCert, intermediateCert, rootCert}, "")
assert.Error(t, err)
err = verifier([]*x509.Certificate{leafWithInvalidHashCert, intermediateCert, rootCert}, leafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{leafWithInvalidHashCert, intermediateCert, rootCert}, wrongLeafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, rootCert}, "")
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, rootCert}, leafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, rootCert}, wrongLeafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, smimeRootCert}, "")
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, smimeRootCert}, leafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, smimeRootCert}, wrongLeafServerName)
assert.Error(t, err)
}
func TestFingerprintVerifierRoot(t *testing.T) {
rootFingerprint := CalculateFingerprint(rootPEM.Bytes)
verifier, err := NewFingerprintVerifier(rootFingerprint, func() time.Time {
return time.Unix(1677615892, 0)
})
rootFingerprint := CalculateFingerprint(rootCert.Raw)
verifier, err := NewFingerprintVerifier(rootFingerprint, certTime)
require.NoError(t, err)
err = verifier([][]byte{leafPEM.Bytes, intermediatePEM.Bytes, rootPEM.Bytes}, nil)
err = verifier([]*x509.Certificate{leafCert, intermediateCert, rootCert}, "")
assert.NoError(t, err)
err = verifier([][]byte{leafPEM.Bytes, smimeIntermediatePEM.Bytes, rootPEM.Bytes}, nil)
err = verifier([]*x509.Certificate{leafCert, intermediateCert, rootCert}, leafServerName)
assert.NoError(t, err)
err = verifier([]*x509.Certificate{leafCert, intermediateCert, rootCert}, wrongLeafServerName)
assert.Error(t, err)
err = verifier([][]byte{leafPEM.Bytes, intermediatePEM.Bytes, smimeRootPEM.Bytes}, nil)
err = verifier([]*x509.Certificate{leafCert, smimeIntermediateCert, rootCert}, "")
assert.Error(t, err)
err = verifier([][]byte{leafWithInvalidHashPEM.Bytes, intermediatePEM.Bytes, rootPEM.Bytes}, nil)
err = verifier([]*x509.Certificate{leafCert, smimeIntermediateCert, rootCert}, leafServerName)
assert.Error(t, err)
err = verifier([][]byte{smimeLeafPEM.Bytes, intermediatePEM.Bytes, rootPEM.Bytes}, nil)
err = verifier([]*x509.Certificate{leafCert, smimeIntermediateCert, rootCert}, wrongLeafServerName)
assert.Error(t, err)
err = verifier([][]byte{smimeLeafPEM.Bytes, intermediatePEM.Bytes, smimeRootPEM.Bytes}, nil)
err = verifier([]*x509.Certificate{leafCert, intermediateCert, smimeRootCert}, "")
assert.Error(t, err)
err = verifier([]*x509.Certificate{leafCert, intermediateCert, smimeRootCert}, leafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{leafCert, intermediateCert, smimeRootCert}, wrongLeafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{leafWithInvalidHashCert, intermediateCert, rootCert}, "")
assert.Error(t, err)
err = verifier([]*x509.Certificate{leafWithInvalidHashCert, intermediateCert, rootCert}, leafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{leafWithInvalidHashCert, intermediateCert, rootCert}, wrongLeafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, rootCert}, "")
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, rootCert}, leafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, rootCert}, wrongLeafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, smimeRootCert}, "")
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, smimeRootCert}, leafServerName)
assert.Error(t, err)
err = verifier([]*x509.Certificate{smimeLeafCert, intermediateCert, smimeRootCert}, wrongLeafServerName)
assert.Error(t, err)
}
var rootPEM, _ = pem.Decode([]byte(gtsRoot))
var rootCert, _ = x509.ParseCertificate(rootPEM.Bytes)
var intermediatePEM, _ = pem.Decode([]byte(gtsIntermediate))
var intermediateCert, _ = x509.ParseCertificate(intermediatePEM.Bytes)
var leafPEM, _ = pem.Decode([]byte(googleLeaf))
var leafCert, _ = x509.ParseCertificate(leafPEM.Bytes)
var leafWithInvalidHashPEM, _ = pem.Decode([]byte(googleLeafWithInvalidHash))
var leafWithInvalidHashCert, _ = x509.ParseCertificate(leafWithInvalidHashPEM.Bytes)
var smimeRootPEM, _ = pem.Decode([]byte(smimeRoot))
var smimeRootCert, _ = x509.ParseCertificate(smimeRootPEM.Bytes)
var smimeIntermediatePEM, _ = pem.Decode([]byte(smimeIntermediate))
var smimeIntermediateCert, _ = x509.ParseCertificate(smimeIntermediatePEM.Bytes)
var smimeLeafPEM, _ = pem.Decode([]byte(smimeLeaf))
var smimeLeafCert, _ = x509.ParseCertificate(smimeLeafPEM.Bytes)
var certTime = func() time.Time { return time.Unix(1677615892, 0) }
const leafServerName = "www.google.com"
const wrongLeafServerName = "www.google.com.cn"
const gtsIntermediate = `-----BEGIN CERTIFICATE-----
MIIFljCCA36gAwIBAgINAgO8U1lrNMcY9QFQZjANBgkqhkiG9w0BAQsFADBHMQsw

View File

@@ -7,71 +7,85 @@ import (
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"math/big"
"os"
"runtime"
"sync"
"time"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/fswatch"
"github.com/metacubex/tls"
)
type Path interface {
Resolve(path string) string
IsSafePath(path string) bool
ErrNotSafePath(path string) error
}
// LoadTLSKeyPair loads a TLS key pair from the provided certificate and private key data or file paths, supporting fallback resolution.
// Returns a tls.Certificate and an error, where the error indicates issues during parsing or file loading.
// NewTLSKeyPairLoader creates a loader function for TLS key pairs from the provided certificate and private key data or file paths.
// If both certificate and privateKey are empty, generates a random TLS RSA key pair.
// Accepts a Path interface for resolving file paths when necessary.
func LoadTLSKeyPair(certificate, privateKey string, path Path) (tls.Certificate, error) {
func NewTLSKeyPairLoader(certificate, privateKey string) (func() (*tls.Certificate, error), error) {
if certificate == "" && privateKey == "" {
var err error
certificate, privateKey, _, err = NewRandomTLSKeyPair(KeyPairTypeRSA)
if err != nil {
return tls.Certificate{}, err
return nil, err
}
}
cert, painTextErr := tls.X509KeyPair([]byte(certificate), []byte(privateKey))
if painTextErr == nil {
return cert, nil
}
if path == nil {
return tls.Certificate{}, painTextErr
return func() (*tls.Certificate, error) {
return &cert, nil
}, nil
}
certificate = path.Resolve(certificate)
privateKey = path.Resolve(privateKey)
certificate = C.Path.Resolve(certificate)
privateKey = C.Path.Resolve(privateKey)
var loadErr error
if !path.IsSafePath(certificate) {
loadErr = path.ErrNotSafePath(certificate)
} else if !path.IsSafePath(privateKey) {
loadErr = path.ErrNotSafePath(privateKey)
if !C.Path.IsSafePath(certificate) {
loadErr = C.Path.ErrNotSafePath(certificate)
} else if !C.Path.IsSafePath(privateKey) {
loadErr = C.Path.ErrNotSafePath(privateKey)
} else {
cert, loadErr = tls.LoadX509KeyPair(certificate, privateKey)
}
if loadErr != nil {
return tls.Certificate{}, fmt.Errorf("parse certificate failed, maybe format error:%s, or path error: %s", painTextErr.Error(), loadErr.Error())
return nil, fmt.Errorf("parse certificate failed, maybe format error:%s, or path error: %s", painTextErr.Error(), loadErr.Error())
}
return cert, nil
gcFlag := new(os.File) // tiny (on the order of 16 bytes or less) and pointer-free objects may never run the finalizer, so we choose new an os.File
updateMutex := sync.RWMutex{}
if watcher, err := fswatch.NewWatcher(fswatch.Options{Path: []string{certificate, privateKey}, Callback: func(path string) {
updateMutex.Lock()
defer updateMutex.Unlock()
if newCert, err := tls.LoadX509KeyPair(certificate, privateKey); err == nil {
cert = newCert
}
}}); err == nil {
if err = watcher.Start(); err == nil {
runtime.SetFinalizer(gcFlag, func(f *os.File) {
_ = watcher.Close()
})
}
}
return func() (*tls.Certificate, error) {
defer runtime.KeepAlive(gcFlag)
updateMutex.RLock()
defer updateMutex.RUnlock()
return &cert, nil
}, nil
}
func LoadCertificates(certificate string, path Path) (*x509.CertPool, error) {
func LoadCertificates(certificate string) (*x509.CertPool, error) {
pool := x509.NewCertPool()
if pool.AppendCertsFromPEM([]byte(certificate)) {
return pool, nil
}
painTextErr := fmt.Errorf("invalid certificate: %s", certificate)
if path == nil {
return nil, painTextErr
}
certificate = path.Resolve(certificate)
certificate = C.Path.Resolve(certificate)
var loadErr error
if !path.IsSafePath(certificate) {
loadErr = path.ErrNotSafePath(certificate)
if !C.Path.IsSafePath(certificate) {
loadErr = C.Path.ErrNotSafePath(certificate)
} else {
certPEMBlock, err := os.ReadFile(certificate)
if pool.AppendCertsFromPEM(certPEMBlock) {
@@ -82,6 +96,9 @@ func LoadCertificates(certificate string, path Path) (*x509.CertPool, error) {
if loadErr != nil {
return nil, fmt.Errorf("parse certificate failed, maybe format error:%s, or path error: %s", painTextErr.Error(), loadErr.Error())
}
//TODO: support dynamic update pool too
// blocked by: https://github.com/golang/go/issues/64796
// maybe we can direct add `GetRootCAs` and `GetClientCAs` to ourselves tls fork
return pool, nil
}

View File

@@ -12,25 +12,31 @@ import (
"syscall"
"time"
"github.com/metacubex/mihomo/common/atomic"
"github.com/metacubex/mihomo/component/keepalive"
"github.com/metacubex/mihomo/component/mptcp"
"github.com/metacubex/mihomo/component/resolver"
)
const (
DefaultTCPTimeout = 5 * time.Second
DefaultUDPTimeout = DefaultTCPTimeout
)
type dialFunc func(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error)
dualStackFallbackTimeout = 300 * time.Millisecond
)
var (
dialMux sync.Mutex
actualSingleStackDialContext = serialSingleStackDialContext
actualDualStackDialContext = serialDualStackDialContext
tcpConcurrent = false
fallbackTimeout = 300 * time.Millisecond
tcpConcurrent = atomic.NewBool(false)
)
func SetTcpConcurrent(concurrent bool) {
tcpConcurrent.Store(concurrent)
}
func GetTcpConcurrent() bool {
return tcpConcurrent.Load()
}
func DialContext(ctx context.Context, network, address string, options ...Option) (net.Conn, error) {
opt := applyOptions(options...)
@@ -49,11 +55,22 @@ func DialContext(ctx context.Context, network, address string, options ...Option
return nil, err
}
tcpConcurrent := GetTcpConcurrent()
switch network {
case "tcp4", "tcp6", "udp4", "udp6":
return actualSingleStackDialContext(ctx, network, ips, port, opt)
if tcpConcurrent {
return parallelDialContext(ctx, network, ips, port, opt)
}
return serialDialContext(ctx, network, ips, port, opt)
case "tcp", "udp":
return actualDualStackDialContext(ctx, network, ips, port, opt)
if tcpConcurrent {
if opt.prefer != 4 && opt.prefer != 6 {
return parallelDialContext(ctx, network, ips, port, opt)
}
return dualStackDialContext(ctx, parallelDialContext, network, ips, port, opt)
}
return dualStackDialContext(ctx, serialDialContext, network, ips, port, opt)
default:
return nil, ErrorInvalidedNetworkStack
}
@@ -103,25 +120,6 @@ func ListenPacket(ctx context.Context, network, address string, rAddrPort netip.
return lc.ListenPacket(ctx, network, address)
}
func SetTcpConcurrent(concurrent bool) {
dialMux.Lock()
defer dialMux.Unlock()
tcpConcurrent = concurrent
if concurrent {
actualSingleStackDialContext = concurrentSingleStackDialContext
actualDualStackDialContext = concurrentDualStackDialContext
} else {
actualSingleStackDialContext = serialSingleStackDialContext
actualDualStackDialContext = serialDualStackDialContext
}
}
func GetTcpConcurrent() bool {
dialMux.Lock()
defer dialMux.Unlock()
return tcpConcurrent
}
func dialContext(ctx context.Context, network string, destination netip.Addr, port string, opt option) (net.Conn, error) {
var address string
destination, port = resolver.LookupIP4P(destination, port)
@@ -140,9 +138,7 @@ func dialContext(ctx context.Context, network string, destination netip.Addr, po
dialer := netDialer.(*net.Dialer)
keepalive.SetNetDialer(dialer)
if opt.mpTcp {
setMultiPathTCP(dialer)
}
mptcp.SetNetDialer(dialer, opt.mpTcp)
if DefaultSocketHook != nil { // ignore interfaceName, routingMark and tfo when DefaultSocketHook not null (in CMFA)
socketHookToToDialer(dialer)
@@ -206,24 +202,7 @@ func ICMPControl(destination netip.Addr) func(network, address string, conn sysc
}
}
func serialSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) {
return serialDialContext(ctx, network, ips, port, opt)
}
func serialDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) {
return dualStackDialContext(ctx, serialDialContext, network, ips, port, opt)
}
func concurrentSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) {
return parallelDialContext(ctx, network, ips, port, opt)
}
func concurrentDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) {
if opt.prefer != 4 && opt.prefer != 6 {
return parallelDialContext(ctx, network, ips, port, opt)
}
return dualStackDialContext(ctx, parallelDialContext, network, ips, port, opt)
}
type dialFunc func(ctx context.Context, network string, ips []netip.Addr, port string, opt option) (net.Conn, error)
func dualStackDialContext(ctx context.Context, dialFn dialFunc, network string, ips []netip.Addr, port string, opt option) (net.Conn, error) {
ipv4s, ipv6s := resolver.SortationAddr(ips)
@@ -232,7 +211,7 @@ func dualStackDialContext(ctx context.Context, dialFn dialFunc, network string,
}
preferIPVersion := opt.prefer
fallbackTicker := time.NewTicker(fallbackTimeout)
fallbackTicker := time.NewTicker(dualStackFallbackTimeout)
defer fallbackTicker.Stop()
results := make(chan dialResult)

View File

@@ -1,12 +0,0 @@
//go:build !go1.21
package dialer
import (
"net"
)
const multipathTCPAvailable = false
func setMultiPathTCP(dialer *net.Dialer) {
}

View File

@@ -1,11 +0,0 @@
//go:build go1.21
package dialer
import "net"
const multipathTCPAvailable = true
func setMultiPathTCP(dialer *net.Dialer) {
dialer.SetMultipathTCP(true)
}

View File

@@ -5,13 +5,33 @@ import (
"fmt"
tlsC "github.com/metacubex/mihomo/component/tls"
"github.com/metacubex/tls"
)
type Config struct {
GetEncryptedClientHelloConfigList func(ctx context.Context, serverName string) ([]byte, error)
}
func (cfg *Config) ClientHandle(ctx context.Context, tlsConfig *tlsC.Config) (err error) {
func (cfg *Config) ClientHandle(ctx context.Context, tlsConfig *tls.Config) (err error) {
if cfg == nil {
return nil
}
echConfigList, err := cfg.GetEncryptedClientHelloConfigList(ctx, tlsConfig.ServerName)
if err != nil {
return fmt.Errorf("resolve ECH config error: %w", err)
}
tlsConfig.EncryptedClientHelloConfigList = echConfigList
if tlsConfig.MinVersion != 0 && tlsConfig.MinVersion < tls.VersionTLS13 {
tlsConfig.MinVersion = tls.VersionTLS13
}
if tlsConfig.MaxVersion != 0 && tlsConfig.MaxVersion < tls.VersionTLS13 {
tlsConfig.MaxVersion = tls.VersionTLS13
}
return nil
}
func (cfg *Config) ClientHandleUTLS(ctx context.Context, tlsConfig *tlsC.Config) (err error) {
if cfg == nil {
return nil
}

View File

@@ -0,0 +1,147 @@
package echparser
import (
"errors"
"fmt"
"golang.org/x/crypto/cryptobyte"
)
// export from std's crypto/tls/ech.go
const extensionEncryptedClientHello = 0xfe0d
type ECHCipher struct {
KDFID uint16
AEADID uint16
}
type ECHExtension struct {
Type uint16
Data []byte
}
type ECHConfig struct {
raw []byte
Version uint16
Length uint16
ConfigID uint8
KemID uint16
PublicKey []byte
SymmetricCipherSuite []ECHCipher
MaxNameLength uint8
PublicName []byte
Extensions []ECHExtension
}
var ErrMalformedECHConfigList = errors.New("tls: malformed ECHConfigList")
type EchConfigErr struct {
field string
}
func (e *EchConfigErr) Error() string {
if e.field == "" {
return "tls: malformed ECHConfig"
}
return fmt.Sprintf("tls: malformed ECHConfig, invalid %s field", e.field)
}
func ParseECHConfig(enc []byte) (skip bool, ec ECHConfig, err error) {
s := cryptobyte.String(enc)
ec.raw = []byte(enc)
if !s.ReadUint16(&ec.Version) {
return false, ECHConfig{}, &EchConfigErr{"version"}
}
if !s.ReadUint16(&ec.Length) {
return false, ECHConfig{}, &EchConfigErr{"length"}
}
if len(ec.raw) < int(ec.Length)+4 {
return false, ECHConfig{}, &EchConfigErr{"length"}
}
ec.raw = ec.raw[:ec.Length+4]
if ec.Version != extensionEncryptedClientHello {
s.Skip(int(ec.Length))
return true, ECHConfig{}, nil
}
if !s.ReadUint8(&ec.ConfigID) {
return false, ECHConfig{}, &EchConfigErr{"config_id"}
}
if !s.ReadUint16(&ec.KemID) {
return false, ECHConfig{}, &EchConfigErr{"kem_id"}
}
if !s.ReadUint16LengthPrefixed((*cryptobyte.String)(&ec.PublicKey)) {
return false, ECHConfig{}, &EchConfigErr{"public_key"}
}
var cipherSuites cryptobyte.String
if !s.ReadUint16LengthPrefixed(&cipherSuites) {
return false, ECHConfig{}, &EchConfigErr{"cipher_suites"}
}
for !cipherSuites.Empty() {
var c ECHCipher
if !cipherSuites.ReadUint16(&c.KDFID) {
return false, ECHConfig{}, &EchConfigErr{"cipher_suites kdf_id"}
}
if !cipherSuites.ReadUint16(&c.AEADID) {
return false, ECHConfig{}, &EchConfigErr{"cipher_suites aead_id"}
}
ec.SymmetricCipherSuite = append(ec.SymmetricCipherSuite, c)
}
if !s.ReadUint8(&ec.MaxNameLength) {
return false, ECHConfig{}, &EchConfigErr{"maximum_name_length"}
}
var publicName cryptobyte.String
if !s.ReadUint8LengthPrefixed(&publicName) {
return false, ECHConfig{}, &EchConfigErr{"public_name"}
}
ec.PublicName = publicName
var extensions cryptobyte.String
if !s.ReadUint16LengthPrefixed(&extensions) {
return false, ECHConfig{}, &EchConfigErr{"extensions"}
}
for !extensions.Empty() {
var e ECHExtension
if !extensions.ReadUint16(&e.Type) {
return false, ECHConfig{}, &EchConfigErr{"extensions type"}
}
if !extensions.ReadUint16LengthPrefixed((*cryptobyte.String)(&e.Data)) {
return false, ECHConfig{}, &EchConfigErr{"extensions data"}
}
ec.Extensions = append(ec.Extensions, e)
}
return false, ec, nil
}
// ParseECHConfigList parses a draft-ietf-tls-esni-18 ECHConfigList, returning a
// slice of parsed ECHConfigs, in the same order they were parsed, or an error
// if the list is malformed.
func ParseECHConfigList(data []byte) ([]ECHConfig, error) {
s := cryptobyte.String(data)
var length uint16
if !s.ReadUint16(&length) {
return nil, ErrMalformedECHConfigList
}
if length != uint16(len(data)-2) {
return nil, ErrMalformedECHConfigList
}
var configs []ECHConfig
for len(s) > 0 {
if len(s) < 4 {
return nil, errors.New("tls: malformed ECHConfig")
}
configLen := uint16(s[2])<<8 | uint16(s[3])
skip, ec, err := ParseECHConfig(s)
if err != nil {
return nil, err
}
s = s[configLen+4:]
if !skip {
configs = append(configs, ec)
}
}
return configs, nil
}

View File

@@ -8,10 +8,13 @@ import (
"errors"
"fmt"
"os"
"runtime"
"sync"
"github.com/metacubex/mihomo/component/ca"
tlsC "github.com/metacubex/mihomo/component/tls"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/fswatch"
"github.com/metacubex/tls"
"golang.org/x/crypto/cryptobyte"
)
@@ -85,11 +88,11 @@ func GenECHConfig(publicName string) (configBase64 string, keyPem string, err er
return
}
func UnmarshalECHKeys(raw []byte) ([]tlsC.EncryptedClientHelloKey, error) {
var keys []tlsC.EncryptedClientHelloKey
func UnmarshalECHKeys(raw []byte) ([]tls.EncryptedClientHelloKey, error) {
var keys []tls.EncryptedClientHelloKey
rawString := cryptobyte.String(raw)
for !rawString.Empty() {
var key tlsC.EncryptedClientHelloKey
var key tls.EncryptedClientHelloKey
if !rawString.ReadUint16LengthPrefixed((*cryptobyte.String)(&key.PrivateKey)) {
return nil, errors.New("error parsing private key")
}
@@ -104,40 +107,65 @@ func UnmarshalECHKeys(raw []byte) ([]tlsC.EncryptedClientHelloKey, error) {
return keys, nil
}
func LoadECHKey(key string, tlsConfig *tlsC.Config, path ca.Path) error {
func LoadECHKey(key string, tlsConfig *tls.Config) error {
if key == "" {
return nil
}
painTextErr := loadECHKey([]byte(key), tlsConfig)
echKeys, painTextErr := loadECHKey([]byte(key))
if painTextErr == nil {
tlsConfig.GetEncryptedClientHelloKeys = func(info *tls.ClientHelloInfo) ([]tls.EncryptedClientHelloKey, error) {
return echKeys, nil
}
return nil
}
key = path.Resolve(key)
key = C.Path.Resolve(key)
var loadErr error
if !path.IsSafePath(key) {
loadErr = path.ErrNotSafePath(key)
if !C.Path.IsSafePath(key) {
loadErr = C.Path.ErrNotSafePath(key)
} else {
var echKey []byte
echKey, loadErr = os.ReadFile(key)
if loadErr == nil {
loadErr = loadECHKey(echKey, tlsConfig)
echKeys, loadErr = loadECHKey(echKey)
}
}
if loadErr != nil {
return fmt.Errorf("parse ECH keys failed, maybe format error:%s, or path error: %s", painTextErr.Error(), loadErr.Error())
}
gcFlag := new(os.File) // tiny (on the order of 16 bytes or less) and pointer-free objects may never run the finalizer, so we choose new an os.File
updateMutex := sync.RWMutex{}
if watcher, err := fswatch.NewWatcher(fswatch.Options{Path: []string{key}, Callback: func(path string) {
updateMutex.Lock()
defer updateMutex.Unlock()
if echKey, err := os.ReadFile(key); err == nil {
if newEchKeys, err := loadECHKey(echKey); err == nil {
echKeys = newEchKeys
}
}
}}); err == nil {
if err = watcher.Start(); err == nil {
runtime.SetFinalizer(gcFlag, func(f *os.File) {
_ = watcher.Close()
})
}
}
tlsConfig.GetEncryptedClientHelloKeys = func(info *tls.ClientHelloInfo) ([]tls.EncryptedClientHelloKey, error) {
defer runtime.KeepAlive(gcFlag)
updateMutex.RLock()
defer updateMutex.RUnlock()
return echKeys, nil
}
return nil
}
func loadECHKey(echKey []byte, tlsConfig *tlsC.Config) error {
func loadECHKey(echKey []byte) ([]tls.EncryptedClientHelloKey, error) {
block, rest := pem.Decode(echKey)
if block == nil || block.Type != "ECH KEYS" || len(rest) > 0 {
return errors.New("invalid ECH keys pem")
return nil, errors.New("invalid ECH keys pem")
}
echKeys, err := UnmarshalECHKeys(block.Bytes)
if err != nil {
return fmt.Errorf("parse ECH keys: %w", err)
return nil, fmt.Errorf("parse ECH keys: %w", err)
}
tlsConfig.EncryptedClientHelloKeys = echKeys
return nil
return echKeys, err
}

30
component/ech/key_test.go Normal file
View File

@@ -0,0 +1,30 @@
package ech
import (
"encoding/base64"
"testing"
"github.com/metacubex/mihomo/component/ech/echparser"
)
func TestGenECHConfig(t *testing.T) {
domain := "www.example.com"
configBase64, _, err := GenECHConfig(domain)
if err != nil {
t.Error(err)
}
echConfigList, err := base64.StdEncoding.DecodeString(configBase64)
if err != nil {
t.Error(err)
}
echConfigs, err := echparser.ParseECHConfigList(echConfigList)
if err != nil {
t.Error(err)
}
if len(echConfigs) == 0 {
t.Error("no ech config")
}
if publicName := string(echConfigs[0].PublicName); publicName != domain {
t.Error("ech config domain error, expect ", domain, " got", publicName)
}
}

View File

@@ -4,13 +4,29 @@ import (
C "github.com/metacubex/mihomo/constant"
)
const (
UseFakeIP = "fake-ip"
UseRealIP = "real-ip"
)
type Skipper struct {
Host []C.DomainMatcher
Mode C.FilterMode
Rules []C.Rule
Host []C.DomainMatcher
Mode C.FilterMode
}
// ShouldSkipped return if domain should be skipped
func (p *Skipper) ShouldSkipped(domain string) bool {
if len(p.Rules) > 0 {
metadata := &C.Metadata{Host: domain}
for _, rule := range p.Rules {
if matched, action := rule.Match(metadata, C.RuleMatchHelper{}); matched {
return action == UseRealIP
}
}
return false
}
should := p.shouldSkipped(domain)
if p.Mode == C.FilterWhiteList {
return !should

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"github.com/metacubex/mihomo/component/ech"
"github.com/metacubex/mihomo/transport/sudoku"
"github.com/metacubex/mihomo/transport/vless/encryption"
"github.com/gofrs/uuid/v5"
@@ -12,7 +13,7 @@ import (
func Main(args []string) {
if len(args) < 1 {
panic("Using: generate uuid/reality-keypair/wg-keypair/ech-keypair/vless-mlkem768/vless-x25519")
panic("Using: generate uuid/reality-keypair/wg-keypair/ech-keypair/vless-mlkem768/vless-x25519/sudoku-keypair")
}
switch args[0] {
case "uuid":
@@ -57,6 +58,11 @@ func Main(args []string) {
fmt.Println("Seed: " + seedBase64)
fmt.Println("Client: " + clientBase64)
fmt.Println("Hash32: " + hash32Base64)
fmt.Println("-----------------------")
fmt.Println(" Lazy-Config ")
fmt.Println("-----------------------")
fmt.Printf("[Server] decryption: \"mlkem768x25519plus.native.600s.%s\"\n", seedBase64)
fmt.Printf("[Client] encryption: \"mlkem768x25519plus.native.0rtt.%s\"\n", clientBase64)
case "vless-x25519":
var privateKey string
if len(args) > 1 {
@@ -69,5 +75,18 @@ func Main(args []string) {
fmt.Println("PrivateKey: " + privateKeyBase64)
fmt.Println("Password: " + passwordBase64)
fmt.Println("Hash32: " + hash32Base64)
fmt.Println("-----------------------")
fmt.Println(" Lazy-Config ")
fmt.Println("-----------------------")
fmt.Printf("[Server] decryption: \"mlkem768x25519plus.native.600s.%s\"\n", privateKeyBase64)
fmt.Printf("[Client] encryption: \"mlkem768x25519plus.native.0rtt.%s\"\n", passwordBase64)
case "sudoku-keypair":
privateKey, publicKey, err := sudoku.GenKeyPair()
if err != nil {
panic(err)
}
// Output: Available Private Key for client, Master Public Key for server
fmt.Println("PrivateKey: " + privateKey)
fmt.Println("PublicKey: " + publicKey)
}
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"io"
"net/http"
"os"
"sync"
"time"
@@ -14,6 +13,8 @@ import (
"github.com/metacubex/mihomo/component/mmdb"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log"
"github.com/metacubex/http"
)
var (

View File

@@ -4,7 +4,6 @@ import (
"context"
"io"
"net"
"net/http"
URL "net/url"
"runtime"
"strings"
@@ -13,6 +12,8 @@ import (
"github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/listener/inner"
"github.com/metacubex/http"
)
var (

View File

@@ -0,0 +1,23 @@
//go:build !go1.21
package mptcp
import (
"net"
)
const MultipathTCPAvailable = false
func SetNetDialer(dialer *net.Dialer, open bool) {
}
func GetNetDialer(dialer *net.Dialer) bool {
return false
}
func SetNetListenConfig(listenConfig *net.ListenConfig, open bool) {
}
func GetNetListenConfig(listenConfig *net.ListenConfig) bool {
return false
}

View File

@@ -0,0 +1,23 @@
//go:build go1.21
package mptcp
import "net"
const MultipathTCPAvailable = true
func SetNetDialer(dialer *net.Dialer, open bool) {
dialer.SetMultipathTCP(open)
}
func GetNetDialer(dialer *net.Dialer) bool {
return dialer.MultipathTCP()
}
func SetNetListenConfig(listenConfig *net.ListenConfig, open bool) {
listenConfig.SetMultipathTCP(open)
}
func GetNetListenConfig(listenConfig *net.ListenConfig) bool {
return listenConfig.MultipathTCP()
}

View File

@@ -197,6 +197,10 @@ func newSearcher(major int) *searcher {
case 12:
fallthrough
case 13:
fallthrough
case 14:
fallthrough
case 15:
s = &searcher{
headSize: 64,
tcpItemSize: 744,

View File

@@ -0,0 +1,37 @@
package proxydialer
import (
"context"
"fmt"
"net"
"net/netip"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/tunnel"
)
type byNameProxyDialer struct {
proxyName string
}
func (d byNameProxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
proxies := tunnel.Proxies()
proxy, ok := proxies[d.proxyName]
if !ok {
return nil, fmt.Errorf("proxyName[%s] not found", d.proxyName)
}
return New(proxy, true).DialContext(ctx, network, address)
}
func (d byNameProxyDialer) ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort) (net.PacketConn, error) {
proxies := tunnel.Proxies()
proxy, ok := proxies[d.proxyName]
if !ok {
return nil, fmt.Errorf("proxyName[%s] not found", d.proxyName)
}
return New(proxy, true).ListenPacket(ctx, network, address, rAddrPort)
}
func NewByName(proxyName string) C.Dialer {
return byNameProxyDialer{proxyName: proxyName}
}

View File

@@ -2,34 +2,22 @@ package proxydialer
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/dialer"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/tunnel"
"github.com/metacubex/mihomo/tunnel/statistic"
)
type proxyDialer struct {
proxy C.ProxyAdapter
dialer C.Dialer
statistic bool
}
func New(proxy C.ProxyAdapter, dialer C.Dialer, statistic bool) C.Dialer {
return proxyDialer{proxy: proxy, dialer: dialer, statistic: statistic}
}
func NewByName(proxyName string, dialer C.Dialer) (C.Dialer, error) {
proxies := tunnel.Proxies()
if proxy, ok := proxies[proxyName]; ok {
return New(proxy, dialer, true), nil
}
return nil, fmt.Errorf("proxyName[%s] not found", proxyName)
func New(proxy C.ProxyAdapter, statistic bool) C.Dialer {
return proxyDialer{proxy: proxy, statistic: statistic}
}
func (p proxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
@@ -50,13 +38,7 @@ func (p proxyDialer) DialContext(ctx context.Context, network, address string) (
}
return N.NewBindPacketConn(pc, currentMeta.UDPAddr()), nil
}
var conn C.Conn
var err error
if _, ok := p.dialer.(dialer.Dialer); ok { // first using old function to let mux work
conn, err = p.proxy.DialContext(ctx, currentMeta)
} else {
conn, err = p.proxy.DialContextWithDialer(ctx, p.dialer, currentMeta)
}
conn, err := p.proxy.DialContext(ctx, currentMeta)
if err != nil {
return nil, err
}
@@ -72,14 +54,8 @@ func (p proxyDialer) ListenPacket(ctx context.Context, network, address string,
}
func (p proxyDialer) listenPacket(ctx context.Context, currentMeta *C.Metadata) (C.PacketConn, error) {
var pc C.PacketConn
var err error
currentMeta.NetWork = C.UDP
if _, ok := p.dialer.(dialer.Dialer); ok { // first using old function to let mux work
pc, err = p.proxy.ListenPacketContext(ctx, currentMeta)
} else {
pc, err = p.proxy.ListenPacketWithDialer(ctx, p.dialer, currentMeta)
}
pc, err := p.proxy.ListenPacketContext(ctx, currentMeta)
if err != nil {
return nil, err
}

View File

@@ -12,71 +12,22 @@ import (
type SingDialer interface {
N.Dialer
SetDialer(dialer C.Dialer)
}
type singDialer proxyDialer
type singDialer struct {
cDialer C.Dialer
}
var _ N.Dialer = (*singDialer)(nil)
func (d *singDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
return (*proxyDialer)(d).DialContext(ctx, network, destination.String())
func (d singDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
return d.cDialer.DialContext(ctx, network, destination.String())
}
func (d *singDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return (*proxyDialer)(d).ListenPacket(ctx, "udp", "", destination.AddrPort())
func (d singDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return d.cDialer.ListenPacket(ctx, "udp", "", destination.AddrPort())
}
func (d *singDialer) SetDialer(dialer C.Dialer) {
(*proxyDialer)(d).dialer = dialer
}
func NewSingDialer(proxy C.ProxyAdapter, dialer C.Dialer, statistic bool) SingDialer {
return (*singDialer)(&proxyDialer{
proxy: proxy,
dialer: dialer,
statistic: statistic,
})
}
type byNameSingDialer struct {
dialer C.Dialer
proxyName string
}
var _ N.Dialer = (*byNameSingDialer)(nil)
func (d *byNameSingDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
var cDialer C.Dialer = d.dialer
if len(d.proxyName) > 0 {
pd, err := NewByName(d.proxyName, d.dialer)
if err != nil {
return nil, err
}
cDialer = pd
}
return cDialer.DialContext(ctx, network, destination.String())
}
func (d *byNameSingDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
var cDialer C.Dialer = d.dialer
if len(d.proxyName) > 0 {
pd, err := NewByName(d.proxyName, d.dialer)
if err != nil {
return nil, err
}
cDialer = pd
}
return cDialer.ListenPacket(ctx, "udp", "", destination.AddrPort())
}
func (d *byNameSingDialer) SetDialer(dialer C.Dialer) {
d.dialer = dialer
}
func NewByNameSingDialer(proxyName string, dialer C.Dialer) SingDialer {
return &byNameSingDialer{
dialer: dialer,
proxyName: proxyName,
}
func NewSingDialer(cDialer C.Dialer) SingDialer {
return singDialer{cDialer: cDialer}
}

View File

@@ -8,7 +8,6 @@ import (
"strings"
_ "unsafe"
"github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/component/resolver/hosts"
"github.com/metacubex/mihomo/component/trie"
"github.com/metacubex/randv2"
@@ -66,37 +65,35 @@ type HostValue struct {
Domain string
}
func NewHostValue(value any) (HostValue, error) {
func NewHostValue(value []string) (HostValue, error) {
isDomain := true
ips := make([]netip.Addr, 0)
ips := make([]netip.Addr, 0, len(value))
domain := ""
if valueArr, err := utils.ToStringSlice(value); err != nil {
return HostValue{}, err
} else {
if len(valueArr) > 1 {
switch len(value) {
case 0:
return HostValue{}, errors.New("value is empty")
case 1:
host := value[0]
if ip, err := netip.ParseAddr(host); err == nil {
ips = append(ips, ip.Unmap())
isDomain = false
for _, str := range valueArr {
if ip, err := netip.ParseAddr(str); err == nil {
ips = append(ips, ip.Unmap())
} else {
return HostValue{}, err
}
}
} else if len(valueArr) == 1 {
host := valueArr[0]
if ip, err := netip.ParseAddr(host); err == nil {
} else {
domain = host
}
default: // > 1
isDomain = false
for _, str := range value {
if ip, err := netip.ParseAddr(str); err == nil {
ips = append(ips, ip.Unmap())
isDomain = false
} else {
domain = host
return HostValue{}, err
}
}
}
if isDomain {
return NewHostValueByDomain(domain)
} else {
return NewHostValueByIPs(ips)
}
return NewHostValueByIPs(ips)
}
func NewHostValueByIPs(ips []netip.Addr) (HostValue, error) {

View File

@@ -4,7 +4,6 @@ import (
"context"
"errors"
"io"
"net/http"
"os"
"path/filepath"
"time"
@@ -13,6 +12,8 @@ import (
mihomoHttp "github.com/metacubex/mihomo/component/http"
"github.com/metacubex/mihomo/component/profile/cachefile"
P "github.com/metacubex/mihomo/constant/provider"
"github.com/metacubex/http"
)
const (

View File

@@ -3,14 +3,13 @@ package tls
import (
"context"
"net"
"net/http"
"runtime/debug"
"time"
N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/log"
"golang.org/x/net/http2"
"github.com/metacubex/http"
)
func extractTlsHandshakeTimeoutFromServer(s *http.Server) time.Duration {
@@ -35,8 +34,8 @@ func extractTlsHandshakeTimeoutFromServer(s *http.Server) time.Duration {
// only do tls handshake and check NegotiatedProtocol with std's *tls.Conn
// so we do the same logic to let http2 (not h2c) work fine
func NewListenerForHttps(l net.Listener, httpServer *http.Server, tlsConfig *Config) net.Listener {
http2Server := &http2.Server{}
_ = http2.ConfigureServer(httpServer, http2Server)
http2Server := &http.Http2Server{}
_ = http.Http2ConfigureServer(httpServer, http2Server)
return N.NewHandleContextListener(context.Background(), l, func(ctx context.Context, conn net.Conn) (net.Conn, error) {
c := Server(conn, tlsConfig)
@@ -58,8 +57,8 @@ func NewListenerForHttps(l net.Listener, httpServer *http.Server, tlsConfig *Con
_ = conn.SetWriteDeadline(time.Time{})
}
if c.ConnectionState().NegotiatedProtocol == http2.NextProtoTLS {
http2Server.ServeConn(c, &http2.ServeConnOpts{BaseConfig: httpServer})
if c.ConnectionState().NegotiatedProtocol == http.Http2NextProtoTLS {
http2Server.ServeConn(c, &http.Http2ServeConnOpts{BaseConfig: httpServer})
return nil, net.ErrClosed
}
return c, nil

View File

@@ -10,22 +10,21 @@ import (
"crypto/hmac"
"crypto/sha256"
"crypto/sha512"
"crypto/tls"
"crypto/x509"
"encoding/binary"
"errors"
"net"
"net/http"
"strings"
"time"
"github.com/metacubex/mihomo/log"
"github.com/metacubex/mihomo/ntp"
"github.com/metacubex/http"
"github.com/metacubex/randv2"
"github.com/metacubex/tls"
utls "github.com/metacubex/utls"
"golang.org/x/crypto/hkdf"
"golang.org/x/net/http2"
)
const RealityMaxShortIDLen = 8
@@ -37,13 +36,14 @@ type RealityConfig struct {
SupportX25519MLKEM768 bool
}
func GetRealityConn(ctx context.Context, conn net.Conn, fingerprint UClientHelloID, tlsConfig *Config, realityConfig *RealityConfig) (net.Conn, error) {
func GetRealityConn(ctx context.Context, conn net.Conn, fingerprint UClientHelloID, serverName string, realityConfig *RealityConfig) (net.Conn, error) {
for retry := 0; ; retry++ {
verifier := &realityVerifier{
serverName: tlsConfig.ServerName,
serverName: serverName,
}
uConfig := &utls.Config{
ServerName: tlsConfig.ServerName,
Time: ntp.Now,
ServerName: serverName,
InsecureSkipVerify: true,
SessionTicketsDisabled: true,
VerifyPeerCertificate: verifier.VerifyPeerCertificate,
@@ -132,7 +132,7 @@ func GetRealityConn(ctx context.Context, conn net.Conn, fingerprint UClientHello
func realityClientFallback(uConn net.Conn, serverName string, fingerprint utls.ClientHelloID) {
defer uConn.Close()
client := http.Client{
Transport: &http2.Transport{
Transport: &http.Http2Transport{
DialTLSContext: func(ctx context.Context, network, addr string, config *tls.Config) (net.Conn, error) {
return uConn, nil
},

View File

@@ -1,13 +1,16 @@
package tls
import (
"crypto/tls"
"context"
"net"
"reflect"
"unsafe"
"github.com/metacubex/mihomo/common/once"
"github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/log"
"github.com/metacubex/tls"
utls "github.com/metacubex/utls"
"github.com/mroth/weightedrand/v2"
)
@@ -124,10 +127,48 @@ func UCertificate(it tls.Certificate) utls.Certificate {
type EncryptedClientHelloKey = utls.EncryptedClientHelloKey
func UEncryptedClientHelloKey(it tls.EncryptedClientHelloKey) utls.EncryptedClientHelloKey {
return utls.EncryptedClientHelloKey{
Config: it.Config,
PrivateKey: it.PrivateKey,
SendAsRetry: it.SendAsRetry,
}
}
type Config = utls.Config
var tlsCertificateRequestInfoCtxOffset = utils.MustOK(reflect.TypeOf((*tls.CertificateRequestInfo)(nil)).Elem().FieldByName("ctx")).Offset
var tlsClientHelloInfoCtxOffset = utils.MustOK(reflect.TypeOf((*tls.ClientHelloInfo)(nil)).Elem().FieldByName("ctx")).Offset
var tlsConnectionStateEkmOffset = utils.MustOK(reflect.TypeOf((*tls.ConnectionState)(nil)).Elem().FieldByName("ekm")).Offset
var utlsConnectionStateEkmOffset = utils.MustOK(reflect.TypeOf((*utls.ConnectionState)(nil)).Elem().FieldByName("ekm")).Offset
func tlsConnectionState(state utls.ConnectionState) (tlsState tls.ConnectionState) {
tlsState = tls.ConnectionState{
Version: state.Version,
HandshakeComplete: state.HandshakeComplete,
DidResume: state.DidResume,
CipherSuite: state.CipherSuite,
//CurveID: state.CurveID,
NegotiatedProtocol: state.NegotiatedProtocol,
NegotiatedProtocolIsMutual: state.NegotiatedProtocolIsMutual,
ServerName: state.ServerName,
PeerCertificates: state.PeerCertificates,
VerifiedChains: state.VerifiedChains,
SignedCertificateTimestamps: state.SignedCertificateTimestamps,
OCSPResponse: state.OCSPResponse,
TLSUnique: state.TLSUnique,
ECHAccepted: state.ECHAccepted,
//HelloRetryRequest: state.HelloRetryRequest,
}
// The layout of map, chan, and func types is equivalent to *T.
// state.ekm is a func(label string, context []byte, length int) ([]byte, error)
*(*unsafe.Pointer)(unsafe.Add(unsafe.Pointer(&tlsState), tlsConnectionStateEkmOffset)) =
*(*unsafe.Pointer)(unsafe.Add(unsafe.Pointer(&state), utlsConnectionStateEkmOffset))
return
}
func UConfig(config *tls.Config) *utls.Config {
return &utls.Config{
cfg := &utls.Config{
Rand: config.Rand,
Time: config.Time,
Certificates: utils.Map(config.Certificates, UCertificate),
@@ -146,7 +187,67 @@ func UConfig(config *tls.Config) *utls.Config {
}),
SessionTicketsDisabled: config.SessionTicketsDisabled,
Renegotiation: utls.RenegotiationSupport(config.Renegotiation),
KeyLogWriter: config.KeyLogWriter,
}
if config.GetClientCertificate != nil {
cfg.GetClientCertificate = func(info *utls.CertificateRequestInfo) (*utls.Certificate, error) {
tlsInfo := &tls.CertificateRequestInfo{
AcceptableCAs: info.AcceptableCAs,
SignatureSchemes: utils.Map(info.SignatureSchemes, func(it utls.SignatureScheme) tls.SignatureScheme {
return tls.SignatureScheme(it)
}),
Version: info.Version,
}
*(*context.Context)(unsafe.Add(unsafe.Pointer(tlsInfo), tlsCertificateRequestInfoCtxOffset)) = info.Context() // for tlsInfo.ctx
cert, err := config.GetClientCertificate(tlsInfo)
if err != nil {
return nil, err
}
uCert := UCertificate(*cert)
return &uCert, err
}
}
if config.GetCertificate != nil {
cfg.GetCertificate = func(info *utls.ClientHelloInfo) (*utls.Certificate, error) {
tlsInfo := &tls.ClientHelloInfo{
CipherSuites: info.CipherSuites,
ServerName: info.ServerName,
SupportedCurves: utils.Map(info.SupportedCurves, func(it utls.CurveID) tls.CurveID {
return tls.CurveID(it)
}),
SupportedPoints: info.SupportedPoints,
SignatureSchemes: utils.Map(info.SignatureSchemes, func(it utls.SignatureScheme) tls.SignatureScheme {
return tls.SignatureScheme(it)
}),
SupportedProtos: info.SupportedProtos,
SupportedVersions: info.SupportedVersions,
Extensions: info.Extensions,
Conn: info.Conn,
//HelloRetryRequest: info.HelloRetryRequest,
}
*(*context.Context)(unsafe.Add(unsafe.Pointer(tlsInfo), tlsClientHelloInfoCtxOffset)) = info.Context() // for tlsInfo.ctx
cert, err := config.GetCertificate(tlsInfo)
if err != nil {
return nil, err
}
uCert := UCertificate(*cert)
return &uCert, err
}
}
if config.VerifyConnection != nil {
cfg.VerifyConnection = func(state utls.ConnectionState) error {
return config.VerifyConnection(tlsConnectionState(state))
}
}
config.EncryptedClientHelloConfigList = cfg.EncryptedClientHelloConfigList
if config.EncryptedClientHelloRejectionVerify != nil {
cfg.EncryptedClientHelloRejectionVerify = func(state utls.ConnectionState) error {
return config.EncryptedClientHelloRejectionVerify(tlsConnectionState(state))
}
}
//cfg.GetEncryptedClientHelloKeys =
cfg.EncryptedClientHelloKeys = utils.Map(config.EncryptedClientHelloKeys, UEncryptedClientHelloKey)
return cfg
}
// BuildWebsocketHandshakeState it will only send http/1.1 in its ALPN.

View File

@@ -6,7 +6,6 @@ import (
"context"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"path/filepath"
@@ -20,6 +19,8 @@ import (
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/constant/features"
"github.com/metacubex/mihomo/log"
"github.com/metacubex/http"
)
const (

View File

@@ -3,11 +3,12 @@ package updater
import (
"context"
"io"
"net/http"
"os"
"time"
mihomoHttp "github.com/metacubex/mihomo/component/http"
"github.com/metacubex/http"
)
const defaultHttpTimeout = time.Second * 90

View File

@@ -15,7 +15,9 @@ import (
"github.com/metacubex/mihomo/adapter/outbound"
"github.com/metacubex/mihomo/adapter/outboundgroup"
"github.com/metacubex/mihomo/adapter/provider"
"github.com/metacubex/mihomo/common/orderedmap"
"github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/common/yaml"
"github.com/metacubex/mihomo/component/auth"
"github.com/metacubex/mihomo/component/cidr"
"github.com/metacubex/mihomo/component/fakeip"
@@ -34,11 +36,10 @@ import (
R "github.com/metacubex/mihomo/rules"
RC "github.com/metacubex/mihomo/rules/common"
RP "github.com/metacubex/mihomo/rules/provider"
RW "github.com/metacubex/mihomo/rules/wrapper"
T "github.com/metacubex/mihomo/tunnel"
orderedmap "github.com/wk8/go-ordered-map/v2"
"golang.org/x/exp/slices"
"gopkg.in/yaml.v3"
)
// General config
@@ -165,6 +166,7 @@ type DNS struct {
FakeIPTTL int
NameServerPolicy []dns.Policy
ProxyServerNameserver []dns.NameServer
ProxyServerPolicy []dns.Policy
DirectNameServer []dns.NameServer
DirectFollowPolicy bool
}
@@ -235,6 +237,7 @@ type RawDNS struct {
CacheMaxSize int `yaml:"cache-max-size" json:"cache-max-size"`
NameServerPolicy *orderedmap.OrderedMap[string, any] `yaml:"nameserver-policy" json:"nameserver-policy"`
ProxyServerNameserver []string `yaml:"proxy-server-nameserver" json:"proxy-server-nameserver"`
ProxyServerNameserverPolicy *orderedmap.OrderedMap[string, any] `yaml:"proxy-server-nameserver-policy" json:"proxy-server-nameserver-policy"`
DirectNameServer []string `yaml:"direct-nameserver" json:"direct-nameserver"`
DirectNameServerFollowPolicy bool `yaml:"direct-nameserver-follow-policy" json:"direct-nameserver-follow-policy"`
}
@@ -273,35 +276,36 @@ type RawTun struct {
GSO bool `yaml:"gso" json:"gso,omitempty"`
GSOMaxSize uint32 `yaml:"gso-max-size" json:"gso-max-size,omitempty"`
//Inet4Address []netip.Prefix `yaml:"inet4-address" json:"inet4-address,omitempty"`
Inet6Address []netip.Prefix `yaml:"inet6-address" json:"inet6-address,omitempty"`
IPRoute2TableIndex int `yaml:"iproute2-table-index" json:"iproute2-table-index,omitempty"`
IPRoute2RuleIndex int `yaml:"iproute2-rule-index" json:"iproute2-rule-index,omitempty"`
AutoRedirect bool `yaml:"auto-redirect" json:"auto-redirect,omitempty"`
AutoRedirectInputMark uint32 `yaml:"auto-redirect-input-mark" json:"auto-redirect-input-mark,omitempty"`
AutoRedirectOutputMark uint32 `yaml:"auto-redirect-output-mark" json:"auto-redirect-output-mark,omitempty"`
LoopbackAddress []netip.Addr `yaml:"loopback-address" json:"loopback-address,omitempty"`
StrictRoute bool `yaml:"strict-route" json:"strict-route,omitempty"`
RouteAddress []netip.Prefix `yaml:"route-address" json:"route-address,omitempty"`
RouteAddressSet []string `yaml:"route-address-set" json:"route-address-set,omitempty"`
RouteExcludeAddress []netip.Prefix `yaml:"route-exclude-address" json:"route-exclude-address,omitempty"`
RouteExcludeAddressSet []string `yaml:"route-exclude-address-set" json:"route-exclude-address-set,omitempty"`
IncludeInterface []string `yaml:"include-interface" json:"include-interface,omitempty"`
ExcludeInterface []string `yaml:"exclude-interface" json:"exclude-interface,omitempty"`
IncludeUID []uint32 `yaml:"include-uid" json:"include-uid,omitempty"`
IncludeUIDRange []string `yaml:"include-uid-range" json:"include-uid-range,omitempty"`
ExcludeUID []uint32 `yaml:"exclude-uid" json:"exclude-uid,omitempty"`
ExcludeUIDRange []string `yaml:"exclude-uid-range" json:"exclude-uid-range,omitempty"`
ExcludeSrcPort []uint16 `yaml:"exclude-src-port" json:"exclude-src-port,omitempty"`
ExcludeSrcPortRange []string `yaml:"exclude-src-port-range" json:"exclude-src-port-range,omitempty"`
ExcludeDstPort []uint16 `yaml:"exclude-dst-port" json:"exclude-dst-port,omitempty"`
ExcludeDstPortRange []string `yaml:"exclude-dst-port-range" json:"exclude-dst-port-range,omitempty"`
IncludeAndroidUser []int `yaml:"include-android-user" json:"include-android-user,omitempty"`
IncludePackage []string `yaml:"include-package" json:"include-package,omitempty"`
ExcludePackage []string `yaml:"exclude-package" json:"exclude-package,omitempty"`
EndpointIndependentNat bool `yaml:"endpoint-independent-nat" json:"endpoint-independent-nat,omitempty"`
UDPTimeout int64 `yaml:"udp-timeout" json:"udp-timeout,omitempty"`
DisableICMPForwarding bool `yaml:"disable-icmp-forwarding" json:"disable-icmp-forwarding,omitempty"`
FileDescriptor int `yaml:"file-descriptor" json:"file-descriptor"`
Inet6Address []netip.Prefix `yaml:"inet6-address" json:"inet6-address,omitempty"`
IPRoute2TableIndex int `yaml:"iproute2-table-index" json:"iproute2-table-index,omitempty"`
IPRoute2RuleIndex int `yaml:"iproute2-rule-index" json:"iproute2-rule-index,omitempty"`
AutoRedirect bool `yaml:"auto-redirect" json:"auto-redirect,omitempty"`
AutoRedirectInputMark uint32 `yaml:"auto-redirect-input-mark" json:"auto-redirect-input-mark,omitempty"`
AutoRedirectOutputMark uint32 `yaml:"auto-redirect-output-mark" json:"auto-redirect-output-mark,omitempty"`
AutoRedirectIPRoute2FallbackRuleIndex int `yaml:"auto-redirect-iproute2-fallback-rule-index" json:"auto-redirect-iproute2-fallback-rule-index,omitempty"`
LoopbackAddress []netip.Addr `yaml:"loopback-address" json:"loopback-address,omitempty"`
StrictRoute bool `yaml:"strict-route" json:"strict-route,omitempty"`
RouteAddress []netip.Prefix `yaml:"route-address" json:"route-address,omitempty"`
RouteAddressSet []string `yaml:"route-address-set" json:"route-address-set,omitempty"`
RouteExcludeAddress []netip.Prefix `yaml:"route-exclude-address" json:"route-exclude-address,omitempty"`
RouteExcludeAddressSet []string `yaml:"route-exclude-address-set" json:"route-exclude-address-set,omitempty"`
IncludeInterface []string `yaml:"include-interface" json:"include-interface,omitempty"`
ExcludeInterface []string `yaml:"exclude-interface" json:"exclude-interface,omitempty"`
IncludeUID []uint32 `yaml:"include-uid" json:"include-uid,omitempty"`
IncludeUIDRange []string `yaml:"include-uid-range" json:"include-uid-range,omitempty"`
ExcludeUID []uint32 `yaml:"exclude-uid" json:"exclude-uid,omitempty"`
ExcludeUIDRange []string `yaml:"exclude-uid-range" json:"exclude-uid-range,omitempty"`
ExcludeSrcPort []uint16 `yaml:"exclude-src-port" json:"exclude-src-port,omitempty"`
ExcludeSrcPortRange []string `yaml:"exclude-src-port-range" json:"exclude-src-port-range,omitempty"`
ExcludeDstPort []uint16 `yaml:"exclude-dst-port" json:"exclude-dst-port,omitempty"`
ExcludeDstPortRange []string `yaml:"exclude-dst-port-range" json:"exclude-dst-port-range,omitempty"`
IncludeAndroidUser []int `yaml:"include-android-user" json:"include-android-user,omitempty"`
IncludePackage []string `yaml:"include-package" json:"include-package,omitempty"`
ExcludePackage []string `yaml:"exclude-package" json:"exclude-package,omitempty"`
EndpointIndependentNat bool `yaml:"endpoint-independent-nat" json:"endpoint-independent-nat,omitempty"`
UDPTimeout int64 `yaml:"udp-timeout" json:"udp-timeout,omitempty"`
DisableICMPForwarding bool `yaml:"disable-icmp-forwarding" json:"disable-icmp-forwarding,omitempty"`
FileDescriptor int `yaml:"file-descriptor" json:"file-descriptor"`
Inet4RouteAddress []netip.Prefix `yaml:"inet4-route-address" json:"inet4-route-address,omitempty"`
Inet6RouteAddress []netip.Prefix `yaml:"inet6-route-address" json:"inet6-route-address,omitempty"`
@@ -954,6 +958,12 @@ func parseProxies(cfg *RawConfig) (proxies map[string]C.Proxy, providersMap map[
)
proxies["GLOBAL"] = adapter.NewProxy(global)
}
// validate dialer-proxy references
if err := validateDialerProxies(proxies); err != nil {
return nil, nil, err
}
return proxies, providersMap, nil
}
@@ -1083,6 +1093,10 @@ func parseRules(rulesConfig []string, proxies map[string]C.Proxy, ruleProviders
}
}
if format == "rules" { // only wrap top level rules
parsed = RW.NewRuleWrapper(parsed)
}
rules = append(rules, parsed)
}
@@ -1101,22 +1115,23 @@ func parseHosts(cfg *RawConfig) (*trie.DomainTrie[resolver.HostValue], error) {
if len(cfg.Hosts) != 0 {
for domain, anyValue := range cfg.Hosts {
if str, ok := anyValue.(string); ok && str == "lan" {
hosts, err := utils.ToStringSlice(anyValue)
if err != nil {
return nil, err
}
if len(hosts) == 1 && hosts[0] == "lan" {
if addrs, err := net.InterfaceAddrs(); err != nil {
log.Errorln("insert lan to host error: %s", err)
} else {
ips := make([]netip.Addr, 0)
hosts = make([]string, 0, len(addrs))
for _, addr := range addrs {
if ipnet, ok := addr.(*net.IPNet); ok && !ipnet.IP.IsLoopback() && !ipnet.IP.IsLinkLocalUnicast() {
if ip, err := netip.ParseAddr(ipnet.IP.String()); err == nil {
ips = append(ips, ip)
}
hosts = append(hosts, ipnet.IP.String())
}
}
anyValue = ips
}
}
value, err := resolver.NewHostValue(anyValue)
value, err := resolver.NewHostValue(hosts)
if err != nil {
return nil, fmt.Errorf("%s is not a valid value", anyValue)
}
@@ -1184,7 +1199,7 @@ func parseNameServer(servers []string, respectRules bool, preferH3 bool) ([]dns.
dnsNetType = "tcp" // TCP
case "tls":
addr, err = hostWithDefaultPort(u.Host, "853")
dnsNetType = "tcp-tls" // DNS over TLS
dnsNetType = "tls" // DNS over TLS
case "http", "https":
addr, err = hostWithDefaultPort(u.Host, "443")
dnsNetType = "https" // DNS over HTTPS
@@ -1292,7 +1307,7 @@ func parseNameServerPolicy(nsPolicy *orderedmap.OrderedMap[string, any], rulePro
}
kLower := strings.ToLower(k)
if strings.Contains(kLower, ",") {
if strings.Contains(kLower, "geosite:") {
if strings.HasPrefix(kLower, "geosite:") {
subkeys := strings.Split(k, ":")
subkeys = subkeys[1:]
subkeys = strings.Split(subkeys[0], ",")
@@ -1300,7 +1315,7 @@ func parseNameServerPolicy(nsPolicy *orderedmap.OrderedMap[string, any], rulePro
newKey := "geosite:" + subkey
policy = append(policy, dns.Policy{Domain: newKey, NameServers: nameservers})
}
} else if strings.Contains(kLower, "rule-set:") {
} else if strings.HasPrefix(kLower, "rule-set:") {
subkeys := strings.Split(k, ":")
subkeys = subkeys[1:]
subkeys = strings.Split(subkeys[0], ",")
@@ -1315,9 +1330,9 @@ func parseNameServerPolicy(nsPolicy *orderedmap.OrderedMap[string, any], rulePro
}
}
} else {
if strings.Contains(kLower, "geosite:") {
if strings.HasPrefix(kLower, "geosite:") {
policy = append(policy, dns.Policy{Domain: "geosite:" + k[8:], NameServers: nameservers})
} else if strings.Contains(kLower, "rule-set:") {
} else if strings.HasPrefix(kLower, "rule-set:") {
policy = append(policy, dns.Policy{Domain: "rule-set:" + k[9:], NameServers: nameservers})
} else {
policy = append(policy, dns.Policy{Domain: k, NameServers: nameservers})
@@ -1391,6 +1406,13 @@ func parseDNS(rawCfg *RawConfig, ruleProviders map[string]P.RuleProvider) (*DNS,
return nil, err
}
if dnsCfg.ProxyServerPolicy, err = parseNameServerPolicy(cfg.ProxyServerNameserverPolicy, ruleProviders, false, cfg.PreferH3); err != nil {
return nil, err
}
if len(dnsCfg.ProxyServerPolicy) != 0 && len(dnsCfg.ProxyServerNameserver) == 0 {
return nil, errors.New("disallow empty `proxy-server-nameserver` when `proxy-server-nameserver-policy` is set")
}
if dnsCfg.DirectNameServer, err = parseNameServer(cfg.DirectNameServer, false, cfg.PreferH3); err != nil {
return nil, err
}
@@ -1450,16 +1472,22 @@ func parseDNS(rawCfg *RawConfig, ruleProviders map[string]P.RuleProvider) (*DNS,
}
}
// fake ip skip host filter
host, err := parseDomain(cfg.FakeIPFilter, fakeIPTrie, "dns.fake-ip-filter", ruleProviders)
if err != nil {
return nil, err
skipper := &fakeip.Skipper{Mode: cfg.FakeIPFilterMode}
if cfg.FakeIPFilterMode == C.FilterRule {
rules, err := parseFakeIPRules(cfg.FakeIPFilter, ruleProviders)
if err != nil {
return nil, err
}
skipper.Rules = rules
} else {
host, err := parseDomain(cfg.FakeIPFilter, fakeIPTrie, "dns.fake-ip-filter", ruleProviders)
if err != nil {
return nil, err
}
skipper.Host = host
}
skipper := &fakeip.Skipper{
Host: host,
Mode: cfg.FakeIPFilterMode,
}
dnsCfg.FakeIPSkipper = skipper
dnsCfg.FakeIPTTL = cfg.FakeIPTTL
@@ -1541,6 +1569,55 @@ func parseDNS(rawCfg *RawConfig, ruleProviders map[string]P.RuleProvider) (*DNS,
return dnsCfg, nil
}
func parseFakeIPRules(rawRules []string, ruleProviders map[string]P.RuleProvider) ([]C.Rule, error) {
var rules []C.Rule
for idx, line := range rawRules {
tp, payload, action, params := RC.ParseRulePayload(line, true)
action = strings.ToLower(action)
if action != fakeip.UseFakeIP && action != fakeip.UseRealIP {
return nil, fmt.Errorf("dns.fake-ip-filter[%d] [%s] error: invalid action '%s', must be 'fake-ip' or 'real-ip'", idx, line, action)
}
if tp == "RULE-SET" {
if rp, ok := ruleProviders[payload]; !ok {
return nil, fmt.Errorf("dns.fake-ip-filter[%d] [%s] error: rule-set '%s' not found", idx, line, payload)
} else {
switch rp.Behavior() {
case P.IPCIDR:
return nil, fmt.Errorf("dns.fake-ip-filter[%d] [%s] error: rule-set behavior is %s, must be domain or classical", idx, line, rp.Behavior())
case P.Classical:
log.Warnln("%s provider is %s, only matching domain rules in fake-ip-filter", rp.Name(), rp.Behavior())
default:
}
}
}
parsed, err := R.ParseRule(tp, payload, action, params, nil)
if err != nil {
return nil, fmt.Errorf("dns.fake-ip-filter[%d] [%s] error: %w", idx, line, err)
}
if !isDomainRule(parsed.RuleType()) && parsed.RuleType() != C.MATCH {
return nil, fmt.Errorf("dns.fake-ip-filter[%d] [%s] error: rule type '%s' not supported, only domain-based rules allowed", idx, line, tp)
}
rules = append(rules, parsed)
}
return rules, nil
}
func isDomainRule(rt C.RuleType) bool {
switch rt {
case C.Domain, C.DomainSuffix, C.DomainKeyword, C.DomainRegex, C.DomainWildcard, C.GEOSITE, C.RuleSet:
return true
default:
return false
}
}
func parseAuthentication(rawRecords []string) []auth.AuthUser {
var users []auth.AuthUser
for _, line := range rawRecords {
@@ -1573,39 +1650,40 @@ func parseTun(rawTun RawTun, dns *DNS, general *General) error {
AutoRoute: rawTun.AutoRoute,
AutoDetectInterface: rawTun.AutoDetectInterface,
MTU: rawTun.MTU,
GSO: rawTun.GSO,
GSOMaxSize: rawTun.GSOMaxSize,
Inet4Address: []netip.Prefix{tunAddressPrefix},
Inet6Address: rawTun.Inet6Address,
IPRoute2TableIndex: rawTun.IPRoute2TableIndex,
IPRoute2RuleIndex: rawTun.IPRoute2RuleIndex,
AutoRedirect: rawTun.AutoRedirect,
AutoRedirectInputMark: rawTun.AutoRedirectInputMark,
AutoRedirectOutputMark: rawTun.AutoRedirectOutputMark,
LoopbackAddress: rawTun.LoopbackAddress,
StrictRoute: rawTun.StrictRoute,
RouteAddress: rawTun.RouteAddress,
RouteAddressSet: rawTun.RouteAddressSet,
RouteExcludeAddress: rawTun.RouteExcludeAddress,
RouteExcludeAddressSet: rawTun.RouteExcludeAddressSet,
IncludeInterface: rawTun.IncludeInterface,
ExcludeInterface: rawTun.ExcludeInterface,
IncludeUID: rawTun.IncludeUID,
IncludeUIDRange: rawTun.IncludeUIDRange,
ExcludeUID: rawTun.ExcludeUID,
ExcludeUIDRange: rawTun.ExcludeUIDRange,
ExcludeSrcPort: rawTun.ExcludeSrcPort,
ExcludeSrcPortRange: rawTun.ExcludeSrcPortRange,
ExcludeDstPort: rawTun.ExcludeDstPort,
ExcludeDstPortRange: rawTun.ExcludeDstPortRange,
IncludeAndroidUser: rawTun.IncludeAndroidUser,
IncludePackage: rawTun.IncludePackage,
ExcludePackage: rawTun.ExcludePackage,
EndpointIndependentNat: rawTun.EndpointIndependentNat,
UDPTimeout: rawTun.UDPTimeout,
DisableICMPForwarding: rawTun.DisableICMPForwarding,
FileDescriptor: rawTun.FileDescriptor,
MTU: rawTun.MTU,
GSO: rawTun.GSO,
GSOMaxSize: rawTun.GSOMaxSize,
Inet4Address: []netip.Prefix{tunAddressPrefix},
Inet6Address: rawTun.Inet6Address,
IPRoute2TableIndex: rawTun.IPRoute2TableIndex,
IPRoute2RuleIndex: rawTun.IPRoute2RuleIndex,
AutoRedirect: rawTun.AutoRedirect,
AutoRedirectInputMark: rawTun.AutoRedirectInputMark,
AutoRedirectOutputMark: rawTun.AutoRedirectOutputMark,
AutoRedirectIPRoute2FallbackRuleIndex: rawTun.AutoRedirectIPRoute2FallbackRuleIndex,
LoopbackAddress: rawTun.LoopbackAddress,
StrictRoute: rawTun.StrictRoute,
RouteAddress: rawTun.RouteAddress,
RouteAddressSet: rawTun.RouteAddressSet,
RouteExcludeAddress: rawTun.RouteExcludeAddress,
RouteExcludeAddressSet: rawTun.RouteExcludeAddressSet,
IncludeInterface: rawTun.IncludeInterface,
ExcludeInterface: rawTun.ExcludeInterface,
IncludeUID: rawTun.IncludeUID,
IncludeUIDRange: rawTun.IncludeUIDRange,
ExcludeUID: rawTun.ExcludeUID,
ExcludeUIDRange: rawTun.ExcludeUIDRange,
ExcludeSrcPort: rawTun.ExcludeSrcPort,
ExcludeSrcPortRange: rawTun.ExcludeSrcPortRange,
ExcludeDstPort: rawTun.ExcludeDstPort,
ExcludeDstPortRange: rawTun.ExcludeDstPortRange,
IncludeAndroidUser: rawTun.IncludeAndroidUser,
IncludePackage: rawTun.IncludePackage,
ExcludePackage: rawTun.ExcludePackage,
EndpointIndependentNat: rawTun.EndpointIndependentNat,
UDPTimeout: rawTun.UDPTimeout,
DisableICMPForwarding: rawTun.DisableICMPForwarding,
FileDescriptor: rawTun.FileDescriptor,
Inet4RouteAddress: rawTun.Inet4RouteAddress,
Inet6RouteAddress: rawTun.Inet6RouteAddress,
@@ -1712,7 +1790,7 @@ func parseSniffer(snifferRaw RawSniffer, ruleProviders map[string]P.RuleProvider
}
snifferConfig.SkipSrcAddress = skipSrcAddress
skipDstAddress, err := parseIPCIDR(snifferRaw.SkipDstAddress, nil, "sniffer.skip-src-address", ruleProviders)
skipDstAddress, err := parseIPCIDR(snifferRaw.SkipDstAddress, nil, "sniffer.skip-dst-address", ruleProviders)
if err != nil {
return nil, fmt.Errorf("error in skip-dst-address, error:%w", err)
}
@@ -1731,7 +1809,7 @@ func parseIPCIDR(addresses []string, cidrSet *cidr.IpCidrSet, adapterName string
var matcher C.IpMatcher
for _, ipcidr := range addresses {
ipcidrLower := strings.ToLower(ipcidr)
if strings.Contains(ipcidrLower, "geoip:") {
if strings.HasPrefix(ipcidrLower, "geoip:") {
subkeys := strings.Split(ipcidr, ":")
subkeys = subkeys[1:]
subkeys = strings.Split(subkeys[0], ",")
@@ -1742,7 +1820,7 @@ func parseIPCIDR(addresses []string, cidrSet *cidr.IpCidrSet, adapterName string
}
matchers = append(matchers, matcher)
}
} else if strings.Contains(ipcidrLower, "rule-set:") {
} else if strings.HasPrefix(ipcidrLower, "rule-set:") {
subkeys := strings.Split(ipcidr, ":")
subkeys = subkeys[1:]
subkeys = strings.Split(subkeys[0], ",")
@@ -1778,7 +1856,7 @@ func parseDomain(domains []string, domainTrie *trie.DomainTrie[struct{}], adapte
var matcher C.DomainMatcher
for _, domain := range domains {
domainLower := strings.ToLower(domain)
if strings.Contains(domainLower, "geosite:") {
if strings.HasPrefix(domainLower, "geosite:") {
subkeys := strings.Split(domain, ":")
subkeys = subkeys[1:]
subkeys = strings.Split(subkeys[0], ",")
@@ -1789,7 +1867,7 @@ func parseDomain(domains []string, domainTrie *trie.DomainTrie[struct{}], adapte
}
matchers = append(matchers, matcher)
}
} else if strings.Contains(domainLower, "rule-set:") {
} else if strings.HasPrefix(domainLower, "rule-set:") {
subkeys := strings.Split(domain, ":")
subkeys = subkeys[1:]
subkeys = strings.Split(subkeys[0], ",")

View File

@@ -9,6 +9,7 @@ import (
"github.com/metacubex/mihomo/adapter/outboundgroup"
"github.com/metacubex/mihomo/common/structure"
C "github.com/metacubex/mihomo/constant"
)
// Check if ProxyGroups form DAG(Directed Acyclic Graph), and sort all ProxyGroups by dependency order.
@@ -143,6 +144,64 @@ func proxyGroupsDagSort(groupsConfig []map[string]any) error {
return fmt.Errorf("loop is detected in ProxyGroup, please check following ProxyGroups: %v", loopElements)
}
// validateDialerProxies checks if all dialer-proxy references are valid
func validateDialerProxies(proxies map[string]C.Proxy) error {
graph := make(map[string]string) // proxy name -> dialer-proxy name
// collect all proxies with dialer-proxy configured
for name, proxy := range proxies {
dialerProxy := proxy.ProxyInfo().DialerProxy
if dialerProxy != "" {
// validate each dialer-proxy reference
_, exist := proxies[dialerProxy]
if !exist {
return fmt.Errorf("proxy [%s] dialer-proxy [%s] not found", name, dialerProxy)
}
// build dependency graph
graph[name] = dialerProxy
}
}
// perform depth-first search to detect cycles for each proxy
for name := range graph {
visited := make(map[string]bool, len(graph))
path := make([]string, 0, len(graph))
if validateDialerProxiesHasCycle(name, graph, visited, path) {
return fmt.Errorf("proxy [%s] has circular dialer-proxy dependency", name)
}
}
return nil
}
// validateDialerProxiesHasCycle performs DFS to detect if there's a cycle starting from current proxy
func validateDialerProxiesHasCycle(current string, graph map[string]string, visited map[string]bool, path []string) bool {
// check if current is already in path (cycle detected)
for _, p := range path {
if p == current {
return true
}
}
// already visited and no cycle
if visited[current] {
return false
}
visited[current] = true
path = append(path, current)
// check dialer-proxy of current proxy
if dialerProxy, exists := graph[current]; exists {
if validateDialerProxiesHasCycle(dialerProxy, graph, visited, path) {
return true
}
}
return false
}
func verifyIP6() bool {
if skip, _ := strconv.ParseBool(os.Getenv("SKIP_SYSTEM_IPV6_CHECK")); skip {
return true

79
config/utils_test.go Normal file
View File

@@ -0,0 +1,79 @@
package config
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestValidateDialerProxies(t *testing.T) {
testCases := []struct {
testName string
proxy []map[string]any
errContains string
}{
{
testName: "ValidReference",
proxy: []map[string]any{ // create proxy with valid dialer-proxy reference
{"name": "base-proxy", "type": "socks5", "server": "127.0.0.1", "port": 1080},
{"name": "proxy-with-dialer", "type": "socks5", "server": "127.0.0.1", "port": 1081, "dialer-proxy": "base-proxy"},
},
errContains: "",
},
{
testName: "NotFoundReference",
proxy: []map[string]any{ // create proxy with non-existent dialer-proxy reference
{"name": "proxy-with-dialer", "type": "socks5", "server": "127.0.0.1", "port": 1081, "dialer-proxy": "non-existent-proxy"},
},
errContains: "not found",
},
{
testName: "CircularDependency",
proxy: []map[string]any{
// create proxy A that references B
{"name": "proxy-a", "type": "socks5", "server": "127.0.0.1", "port": 1080, "dialer-proxy": "proxy-c"},
// create proxy B that references C
{"name": "proxy-b", "type": "socks5", "server": "127.0.0.1", "port": 1081, "dialer-proxy": "proxy-a"},
// create proxy C that references A (creates cycle)
{"name": "proxy-c", "type": "socks5", "server": "127.0.0.1", "port": 1082, "dialer-proxy": "proxy-a"},
},
errContains: "circular",
},
{
testName: "ComplexChain",
proxy: []map[string]any{ // create a valid chain: proxy-d -> proxy-c -> proxy-b -> proxy-a
{"name": "proxy-a", "type": "socks5", "server": "127.0.0.1", "port": 1080},
{"name": "proxy-b", "type": "socks5", "server": "127.0.0.1", "port": 1081, "dialer-proxy": "proxy-a"},
{"name": "proxy-c", "type": "socks5", "server": "127.0.0.1", "port": 1082, "dialer-proxy": "proxy-b"},
{"name": "proxy-d", "type": "socks5", "server": "127.0.0.1", "port": 1083, "dialer-proxy": "proxy-c"},
},
errContains: "",
},
{
testName: "EmptyDialerProxy",
proxy: []map[string]any{ // create proxy without dialer-proxy
{"name": "simple-proxy", "type": "socks5", "server": "127.0.0.1", "port": 1080},
},
errContains: "",
},
{
testName: "SelfReference",
proxy: []map[string]any{ // create proxy that references itself
{"name": "self-proxy", "type": "socks5", "server": "127.0.0.1", "port": 1080, "dialer-proxy": "self-proxy"},
},
errContains: "circular",
},
}
for _, testCase := range testCases {
t.Run(testCase.testName, func(t *testing.T) {
config := RawConfig{Proxy: testCase.proxy}
_, _, err := parseProxies(&config)
if testCase.errContains == "" {
assert.NoError(t, err, testCase.testName)
} else {
assert.ErrorContains(t, err, testCase.errContains, testCase.testName)
}
})
}
}

View File

@@ -45,6 +45,7 @@ const (
Mieru
AnyTLS
Sudoku
Masque
)
const (
@@ -59,6 +60,7 @@ var ErrNotSupport = errors.New("no support")
type Connection interface {
Chains() Chain
ProviderChains() Chain
AppendToChains(adapter ProxyAdapter)
RemoteDestination() string
}
@@ -102,13 +104,14 @@ type Dialer interface {
}
type ProxyInfo struct {
XUDP bool
TFO bool
MPTCP bool
SMUX bool
Interface string
RoutingMark int
DialerProxy string
XUDP bool
TFO bool
MPTCP bool
SMUX bool
Interface string
RoutingMark int
ProviderName string
DialerProxy string
}
type ProxyAdapter interface {
@@ -121,17 +124,6 @@ type ProxyAdapter interface {
ProxyInfo() ProxyInfo
MarshalJSON() ([]byte, error)
// Deprecated: use DialContextWithDialer and ListenPacketWithDialer instead.
// StreamConn wraps a protocol around net.Conn with Metadata.
//
// Examples:
// conn, _ := net.DialContext(context.Background(), "tcp", "host:port")
// conn, _ = adapter.StreamConnContext(context.Background(), conn, metadata)
//
// It returns a C.Conn with protocol which start with
// a new session (if any)
StreamConnContext(ctx context.Context, c net.Conn, metadata *Metadata) (net.Conn, error)
// DialContext return a C.Conn with protocol which
// contains multiplexing-related reuse logic (if any)
DialContext(ctx context.Context, metadata *Metadata) (Conn, error)
@@ -140,13 +132,6 @@ type ProxyAdapter interface {
// SupportUOT return UDP over TCP support
SupportUOT() bool
// SupportWithDialer only for deprecated relay group, the new protocol does not need to be implemented.
SupportWithDialer() NetWork
// DialContextWithDialer only for deprecated relay group, the new protocol does not need to be implemented.
DialContextWithDialer(ctx context.Context, dialer Dialer, metadata *Metadata) (Conn, error)
// ListenPacketWithDialer only for deprecated relay group, the new protocol does not need to be implemented.
ListenPacketWithDialer(ctx context.Context, dialer Dialer, metadata *Metadata) (PacketConn, error)
// IsL3Protocol return ProxyAdapter working in L3 (tell dns module not pass the domain to avoid loopback)
IsL3Protocol(metadata *Metadata) bool
@@ -157,11 +142,6 @@ type ProxyAdapter interface {
Close() error
}
type Group interface {
URLTest(ctx context.Context, url string, expectedStatus utils.IntRanges[uint16]) (mp map[string]uint16, err error)
Touch()
}
type DelayHistory struct {
Time time.Time `json:"time"`
Delay uint16 `json:"delay"`
@@ -233,6 +213,8 @@ func (at AdapterType) String() string {
return "AnyTLS"
case Sudoku:
return "Sudoku"
case Masque:
return "Masque"
case Relay:
return "Relay"
case Selector:

View File

@@ -86,18 +86,24 @@ func (d DNSPrefer) String() string {
}
}
func NewDNSPrefer(prefer string) DNSPrefer {
if p, ok := dnsPreferMap[prefer]; ok {
return p
} else {
return DualStack
func (d DNSPrefer) MarshalText() ([]byte, error) {
return []byte(d.String()), nil
}
func (d *DNSPrefer) UnmarshalText(data []byte) error {
p, exist := dnsPreferMap[strings.ToLower(string(data))]
if !exist {
p = DualStack
}
*d = p
return nil
}
// FilterModeMapping is a mapping for FilterMode enum
var FilterModeMapping = map[string]FilterMode{
FilterBlackList.String(): FilterBlackList,
FilterWhiteList.String(): FilterWhiteList,
FilterRule.String(): FilterRule,
}
type FilterMode int
@@ -105,6 +111,7 @@ type FilterMode int
const (
FilterBlackList FilterMode = iota
FilterWhiteList
FilterRule
)
func (e FilterMode) String() string {
@@ -113,6 +120,8 @@ func (e FilterMode) String() string {
return "blacklist"
case FilterWhiteList:
return "whitelist"
case FilterRule:
return "rule"
default:
return "unknown"
}

View File

@@ -1,5 +1,7 @@
package constant
import "time"
// Rule Type
const (
Domain RuleType = iota
@@ -27,6 +29,8 @@ const (
ProcessPath
ProcessNameRegex
ProcessPathRegex
ProcessNameWildcard
ProcessPathWildcard
RuleSet
Network
Uid
@@ -89,6 +93,10 @@ func (rt RuleType) String() string {
return "ProcessNameRegex"
case ProcessPathRegex:
return "ProcessPathRegex"
case ProcessNameWildcard:
return "ProcessNameWildcard"
case ProcessPathWildcard:
return "ProcessPathWildcard"
case MATCH:
return "Match"
case RuleSet:
@@ -120,6 +128,27 @@ type Rule interface {
ProviderNames() []string
}
type RuleWrapper interface {
Rule
// SetDisabled to set enable/disable rule
SetDisabled(v bool)
// IsDisabled return rule is disabled or not
IsDisabled() bool
// HitCount for statistics
HitCount() uint64
// HitAt for statistics
HitAt() time.Time
// MissCount for statistics
MissCount() uint64
// MissAt for statistics
MissAt() time.Time
// Unwrap return Rule
Unwrap() Rule
}
type RuleMatchHelper struct {
ResolveIP func()
FindProcess func()

View File

@@ -2,13 +2,11 @@ package dns
import (
"context"
"crypto/tls"
"fmt"
"net"
"strings"
"time"
"github.com/metacubex/mihomo/component/ca"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log"
@@ -16,11 +14,10 @@ import (
)
type client struct {
port string
host string
dialer *dnsDialer
schema string
skipCertVerify bool
port string
host string
dialer *dnsDialer
schema string
}
var _ dnsClient = (*client)(nil)
@@ -43,23 +40,6 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, error)
}
defer conn.Close()
if c.schema == "tls" {
tlsConfig, err := ca.GetTLSConfig(ca.Option{
TLSConfig: &tls.Config{
ServerName: c.host,
InsecureSkipVerify: c.skipCertVerify,
},
})
if err != nil {
return nil, err
}
tlsConn := tls.Client(conn, tlsConfig)
if err := tlsConn.HandshakeContext(ctx); err != nil {
return nil, err
}
conn = tlsConn
}
// miekg/dns ExchangeContext doesn't respond to context cancel.
// this is a workaround
type result struct {
@@ -117,12 +97,6 @@ func newClient(addr string, resolver *Resolver, netType string, params map[strin
}
if strings.HasPrefix(netType, "tcp") {
c.schema = "tcp"
if strings.HasSuffix(netType, "tls") {
c.schema = "tls"
}
}
if params["skip-cert-verify"] == "true" {
c.skipCertVerify = true
}
return c
}

View File

@@ -2,13 +2,11 @@ package dns
import (
"context"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"runtime"
"strconv"
@@ -16,15 +14,15 @@ import (
"time"
"github.com/metacubex/mihomo/component/ca"
tlsC "github.com/metacubex/mihomo/component/tls"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log"
"github.com/metacubex/http"
"github.com/metacubex/quic-go"
"github.com/metacubex/quic-go/http3"
"github.com/metacubex/tls"
D "github.com/miekg/dns"
"golang.org/x/exp/slices"
"golang.org/x/net/http2"
)
// Values to configure HTTP and HTTP/2 transport.
@@ -439,8 +437,8 @@ func (doh *dnsOverHTTPS) createTransport(ctx context.Context) (t http.RoundTripp
// Explicitly configure transport to use HTTP/2.
//
// See https://github.com/AdguardTeam/dnsproxy/issues/11.
var transportH2 *http2.Transport
transportH2, err = http2.ConfigureTransports(transport)
var transportH2 *http.Http2Transport
transportH2, err = http.Http2ConfigureTransports(transport)
if err != nil {
return nil, err
}
@@ -530,20 +528,20 @@ func (doh *dnsOverHTTPS) createTransportH3(
// Ignore the address and always connect to the one that we got
// from the bootstrapper.
_ string,
tlsCfg *tlsC.Config,
tlsCfg *tls.Config,
cfg *quic.Config,
) (c *quic.Conn, err error) {
return doh.dialQuic(ctx, addr, tlsCfg, cfg)
},
DisableCompression: true,
TLSClientConfig: tlsC.UConfig(tlsConfig),
TLSClientConfig: tlsConfig,
QUICConfig: doh.getQUICConfig(),
}
return &http3Transport{baseTransport: rt}, nil
}
func (doh *dnsOverHTTPS) dialQuic(ctx context.Context, addr string, tlsCfg *tlsC.Config, cfg *quic.Config) (*quic.Conn, error) {
func (doh *dnsOverHTTPS) dialQuic(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (*quic.Conn, error) {
ip, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
@@ -612,7 +610,7 @@ func (doh *dnsOverHTTPS) probeH3(
// Run probeQUIC and probeTLS in parallel and see which one is faster.
chQuic := make(chan error, 1)
chTLS := make(chan error, 1)
go doh.probeQUIC(ctx, addr, tlsC.UConfig(probeTLSCfg), chQuic)
go doh.probeQUIC(ctx, addr, probeTLSCfg, chQuic)
go doh.probeTLS(ctx, probeTLSCfg, chTLS)
select {
@@ -637,7 +635,7 @@ func (doh *dnsOverHTTPS) probeH3(
// probeQUIC attempts to establish a QUIC connection to the specified address.
// We run probeQUIC and probeTLS in parallel and see which one is faster.
func (doh *dnsOverHTTPS) probeQUIC(ctx context.Context, addr string, tlsConfig *tlsC.Config, ch chan error) {
func (doh *dnsOverHTTPS) probeQUIC(ctx context.Context, addr string, tlsConfig *tls.Config, ch chan error) {
startTime := time.Now()
conn, err := doh.dialQuic(ctx, addr, tlsConfig, doh.getQUICConfig())
if err != nil {
@@ -727,14 +725,10 @@ func (doh *dnsOverHTTPS) tlsDial(ctx context.Context, network string, config *tl
// TLS handshake dialTimeout will be used as connection deadLine.
conn := tls.Client(rawConn, config)
err = conn.SetDeadline(time.Now().Add(dialTimeout))
if err != nil {
// Must not happen in normal circumstances.
log.Errorln("cannot set deadline: %v", err)
return nil, err
}
ctx, cancel := context.WithTimeout(ctx, dialTimeout)
defer cancel()
err = conn.Handshake()
err = conn.HandshakeContext(ctx)
if err != nil {
defer conn.Close()
return nil, err

Some files were not shown because too many files have changed in this diff Show More