Compare commits

..

35 Commits

Author SHA1 Message Date
Larvan2
e27d7c010f Merge branch 'Alpha' into Meta 2024-03-29 19:50:16 +08:00
Larvan2
582e94ff4d Merge branch 'Alpha' into Meta 2024-03-27 18:43:06 +08:00
Larvan2
a034421a58 Merge branch 'Alpha' into Meta 2024-02-02 20:15:27 +08:00
Larvan2
dc2108c174 Merge branch 'Alpha' into Meta 2024-01-02 15:21:37 +08:00
wwqgtxx
bda71dbfa1 Merge branch 'Alpha' into Meta 2023-12-03 08:44:30 +08:00
wwqgtxx
253b023442 Merge branch 'Alpha' into Meta 2023-11-03 21:59:44 +08:00
wwqgtxx
8293b7fdae Merge branch 'Beta' into Meta
# Conflicts:
#	.github/workflows/docker.yaml
#	.github/workflows/prerelease.yml
2023-02-19 01:25:34 +08:00
wwqgtxx
0ba415866e Merge branch 'Alpha' into Beta 2023-02-19 01:25:01 +08:00
Larvan2
53b41ca166 Chore: Add action for deleting old workflow 2023-01-30 18:17:22 +08:00
metacubex
8a75f78e63 chore: adjust Dockerfile 2023-01-12 02:24:12 +08:00
metacubex
d9692c6366 Merge branch 'Beta' into Meta 2023-01-12 02:15:14 +08:00
metacubex
f4b0062dfc Merge branch 'Alpha' into Beta 2023-01-12 02:14:49 +08:00
metacubex
b9ffc82e53 Merge branch 'Beta' into Meta 2023-01-12 01:33:56 +08:00
metacubex
78aaea6a45 Merge branch 'Alpha' into Beta 2023-01-12 01:33:16 +08:00
cubemaze
3645fbf161 Merge pull request #327 from Rasphino/Meta
Update flake.nix hash
2023-01-08 00:29:29 +08:00
Rasphino
a1d0f22132 fix: update flake.nix hash 2023-01-07 23:38:32 +08:00
metacubex
fa73b0f4bf Merge remote-tracking branch 'origin/Beta' into Meta 2023-01-01 19:41:36 +08:00
metacubex
3b76a8b839 Merge remote-tracking branch 'origin/Alpha' into Beta 2023-01-01 19:40:36 +08:00
cubemaze
667f42dcdc Merge pull request #282 from tdjnodj/Meta
Update README.md
2022-12-03 17:23:51 +08:00
tdjnodj
dfbe09860f Update README.md 2022-12-03 17:17:08 +08:00
metacubex
9e20f9c26a chore: update dependencies 2022-11-28 20:33:10 +08:00
Skimmle
f968d0cb82 chore: update github action 2022-11-26 20:16:12 +08:00
metacubex
2ad84f4379 Merge branch 'Beta' into Meta 2022-11-02 18:08:22 +08:00
metacubex
c7aa16426f Merge branch 'Alpha' into Beta 2022-11-02 18:07:29 +08:00
Skyxim
5987f8e3b5 Merge branch 'Beta' into Meta 2022-08-29 13:08:29 +08:00
Skyxim
3a8eb72de2 Merge branch 'Alpha' into Beta 2022-08-29 13:08:22 +08:00
zhudan
33abbdfd24 Merge pull request #174 from MetaCubeX/Alpha
Alpha
2022-08-29 11:24:07 +08:00
zhudan
0703d6cbff Merge pull request #173 from MetaCubeX/Alpha
Alpha
2022-08-29 11:22:14 +08:00
Skyxim
10d2d14938 Merge branch 'Beta' into Meta
# Conflicts:
#	rules/provider/classical_strategy.go
2022-07-02 10:41:41 +08:00
wwqgtxx
691cf1d8d6 Merge pull request #94 from bash99/Meta
Update README.md
2022-06-15 19:15:51 +08:00
bash99
d1decb8e58 Update README.md
add permissions for systemctl services
clash-dashboard change to updated one
2022-06-15 14:00:05 +08:00
Skyxim
7d04904109 fix: leak dns when domain in hosts list 2022-06-11 18:51:26 +08:00
Skyxim
a5acd3aa97 refactor: clear linkname,reduce cycle dependencies,transport init geosite function 2022-06-11 18:51:22 +08:00
Skyxim
eea9a12560 fix: 规则匹配默认策略组返回错误 2022-06-09 14:18:35 +08:00
adlyq
0a4570b55c fix: group filter touch provider 2022-06-09 14:18:29 +08:00
660 changed files with 18213 additions and 52660 deletions

17
.github/mihomo.service vendored Normal file
View File

@@ -0,0 +1,17 @@
[Unit]
Description=mihomo Daemon, Another Clash Kernel.
After=network.target NetworkManager.service systemd-networkd.service iwd.service
[Service]
Type=simple
LimitNPROC=500
LimitNOFILE=1000000
CapabilityBoundingSet=CAP_NET_ADMIN CAP_NET_RAW CAP_NET_BIND_SERVICE CAP_SYS_TIME CAP_SYS_PTRACE CAP_DAC_READ_SEARCH
AmbientCapabilities=CAP_NET_ADMIN CAP_NET_RAW CAP_NET_BIND_SERVICE CAP_SYS_TIME CAP_SYS_PTRACE CAP_DAC_READ_SEARCH
Restart=always
ExecStartPre=/usr/bin/sleep 2s
ExecStart=/usr/bin/mihomo -d /etc/mihomo
ExecReload=/bin/kill -HUP $MAINPID
[Install]
WantedBy=multi-user.target

View File

@@ -1,182 +0,0 @@
Subject: [PATCH] Revert "[release-branch.go1.21] crypto/rand,runtime: switch RtlGenRandom for ProcessPrng"
---
Index: src/crypto/rand/rand.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/crypto/rand/rand.go b/src/crypto/rand/rand.go
--- a/src/crypto/rand/rand.go (revision 8bba868de983dd7bf55fcd121495ba8d6e2734e7)
+++ b/src/crypto/rand/rand.go (revision 7e6c963d81e14ee394402671d4044b2940c8d2c1)
@@ -15,7 +15,7 @@
// available, /dev/urandom otherwise.
// On OpenBSD and macOS, Reader uses getentropy(2).
// On other Unix-like systems, Reader reads from /dev/urandom.
-// On Windows systems, Reader uses the ProcessPrng API.
+// On Windows systems, Reader uses the RtlGenRandom API.
// On JS/Wasm, Reader uses the Web Crypto API.
// On WASIP1/Wasm, Reader uses random_get from wasi_snapshot_preview1.
var Reader io.Reader
Index: src/crypto/rand/rand_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/crypto/rand/rand_windows.go b/src/crypto/rand/rand_windows.go
--- a/src/crypto/rand/rand_windows.go (revision 8bba868de983dd7bf55fcd121495ba8d6e2734e7)
+++ b/src/crypto/rand/rand_windows.go (revision 7e6c963d81e14ee394402671d4044b2940c8d2c1)
@@ -15,8 +15,11 @@
type rngReader struct{}
-func (r *rngReader) Read(b []byte) (int, error) {
- if err := windows.ProcessPrng(b); err != nil {
+func (r *rngReader) Read(b []byte) (n int, err error) {
+ // RtlGenRandom only returns 1<<32-1 bytes at a time. We only read at
+ // most 1<<31-1 bytes at a time so that this works the same on 32-bit
+ // and 64-bit systems.
+ if err := batched(windows.RtlGenRandom, 1<<31-1)(b); err != nil {
return 0, err
}
return len(b), nil
Index: src/internal/syscall/windows/syscall_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/internal/syscall/windows/syscall_windows.go b/src/internal/syscall/windows/syscall_windows.go
--- a/src/internal/syscall/windows/syscall_windows.go (revision 8bba868de983dd7bf55fcd121495ba8d6e2734e7)
+++ b/src/internal/syscall/windows/syscall_windows.go (revision 7e6c963d81e14ee394402671d4044b2940c8d2c1)
@@ -384,7 +384,7 @@
//sys DestroyEnvironmentBlock(block *uint16) (err error) = userenv.DestroyEnvironmentBlock
//sys CreateEvent(eventAttrs *SecurityAttributes, manualReset uint32, initialState uint32, name *uint16) (handle syscall.Handle, err error) = kernel32.CreateEventW
-//sys ProcessPrng(buf []byte) (err error) = bcryptprimitives.ProcessPrng
+//sys RtlGenRandom(buf []byte) (err error) = advapi32.SystemFunction036
//sys RtlLookupFunctionEntry(pc uintptr, baseAddress *uintptr, table *byte) (ret uintptr) = kernel32.RtlLookupFunctionEntry
//sys RtlVirtualUnwind(handlerType uint32, baseAddress uintptr, pc uintptr, entry uintptr, ctxt uintptr, data *uintptr, frame *uintptr, ctxptrs *byte) (ret uintptr) = kernel32.RtlVirtualUnwind
Index: src/internal/syscall/windows/zsyscall_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/internal/syscall/windows/zsyscall_windows.go b/src/internal/syscall/windows/zsyscall_windows.go
--- a/src/internal/syscall/windows/zsyscall_windows.go (revision 8bba868de983dd7bf55fcd121495ba8d6e2734e7)
+++ b/src/internal/syscall/windows/zsyscall_windows.go (revision 7e6c963d81e14ee394402671d4044b2940c8d2c1)
@@ -37,14 +37,13 @@
}
var (
- modadvapi32 = syscall.NewLazyDLL(sysdll.Add("advapi32.dll"))
- modbcryptprimitives = syscall.NewLazyDLL(sysdll.Add("bcryptprimitives.dll"))
- modiphlpapi = syscall.NewLazyDLL(sysdll.Add("iphlpapi.dll"))
- modkernel32 = syscall.NewLazyDLL(sysdll.Add("kernel32.dll"))
- modnetapi32 = syscall.NewLazyDLL(sysdll.Add("netapi32.dll"))
- modpsapi = syscall.NewLazyDLL(sysdll.Add("psapi.dll"))
- moduserenv = syscall.NewLazyDLL(sysdll.Add("userenv.dll"))
- modws2_32 = syscall.NewLazyDLL(sysdll.Add("ws2_32.dll"))
+ modadvapi32 = syscall.NewLazyDLL(sysdll.Add("advapi32.dll"))
+ modiphlpapi = syscall.NewLazyDLL(sysdll.Add("iphlpapi.dll"))
+ modkernel32 = syscall.NewLazyDLL(sysdll.Add("kernel32.dll"))
+ modnetapi32 = syscall.NewLazyDLL(sysdll.Add("netapi32.dll"))
+ modpsapi = syscall.NewLazyDLL(sysdll.Add("psapi.dll"))
+ moduserenv = syscall.NewLazyDLL(sysdll.Add("userenv.dll"))
+ modws2_32 = syscall.NewLazyDLL(sysdll.Add("ws2_32.dll"))
procAdjustTokenPrivileges = modadvapi32.NewProc("AdjustTokenPrivileges")
procDuplicateTokenEx = modadvapi32.NewProc("DuplicateTokenEx")
@@ -53,7 +52,7 @@
procOpenThreadToken = modadvapi32.NewProc("OpenThreadToken")
procRevertToSelf = modadvapi32.NewProc("RevertToSelf")
procSetTokenInformation = modadvapi32.NewProc("SetTokenInformation")
- procProcessPrng = modbcryptprimitives.NewProc("ProcessPrng")
+ procSystemFunction036 = modadvapi32.NewProc("SystemFunction036")
procGetAdaptersAddresses = modiphlpapi.NewProc("GetAdaptersAddresses")
procCreateEventW = modkernel32.NewProc("CreateEventW")
procGetACP = modkernel32.NewProc("GetACP")
@@ -149,12 +148,12 @@
return
}
-func ProcessPrng(buf []byte) (err error) {
+func RtlGenRandom(buf []byte) (err error) {
var _p0 *byte
if len(buf) > 0 {
_p0 = &buf[0]
}
- r1, _, e1 := syscall.Syscall(procProcessPrng.Addr(), 2, uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), 0)
+ r1, _, e1 := syscall.Syscall(procSystemFunction036.Addr(), 2, uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), 0)
if r1 == 0 {
err = errnoErr(e1)
}
Index: src/runtime/os_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/runtime/os_windows.go b/src/runtime/os_windows.go
--- a/src/runtime/os_windows.go (revision 8bba868de983dd7bf55fcd121495ba8d6e2734e7)
+++ b/src/runtime/os_windows.go (revision 7e6c963d81e14ee394402671d4044b2940c8d2c1)
@@ -127,8 +127,15 @@
_AddVectoredContinueHandler,
_ stdFunction
- // Use ProcessPrng to generate cryptographically random data.
- _ProcessPrng stdFunction
+ // Use RtlGenRandom to generate cryptographically random data.
+ // This approach has been recommended by Microsoft (see issue
+ // 15589 for details).
+ // The RtlGenRandom is not listed in advapi32.dll, instead
+ // RtlGenRandom function can be found by searching for SystemFunction036.
+ // Also some versions of Mingw cannot link to SystemFunction036
+ // when building executable as Cgo. So load SystemFunction036
+ // manually during runtime startup.
+ _RtlGenRandom stdFunction
// Load ntdll.dll manually during startup, otherwise Mingw
// links wrong printf function to cgo executable (see issue
@@ -145,12 +152,12 @@
)
var (
- bcryptprimitivesdll = [...]uint16{'b', 'c', 'r', 'y', 'p', 't', 'p', 'r', 'i', 'm', 'i', 't', 'i', 'v', 'e', 's', '.', 'd', 'l', 'l', 0}
- kernel32dll = [...]uint16{'k', 'e', 'r', 'n', 'e', 'l', '3', '2', '.', 'd', 'l', 'l', 0}
- ntdlldll = [...]uint16{'n', 't', 'd', 'l', 'l', '.', 'd', 'l', 'l', 0}
- powrprofdll = [...]uint16{'p', 'o', 'w', 'r', 'p', 'r', 'o', 'f', '.', 'd', 'l', 'l', 0}
- winmmdll = [...]uint16{'w', 'i', 'n', 'm', 'm', '.', 'd', 'l', 'l', 0}
- ws2_32dll = [...]uint16{'w', 's', '2', '_', '3', '2', '.', 'd', 'l', 'l', 0}
+ advapi32dll = [...]uint16{'a', 'd', 'v', 'a', 'p', 'i', '3', '2', '.', 'd', 'l', 'l', 0}
+ kernel32dll = [...]uint16{'k', 'e', 'r', 'n', 'e', 'l', '3', '2', '.', 'd', 'l', 'l', 0}
+ ntdlldll = [...]uint16{'n', 't', 'd', 'l', 'l', '.', 'd', 'l', 'l', 0}
+ powrprofdll = [...]uint16{'p', 'o', 'w', 'r', 'p', 'r', 'o', 'f', '.', 'd', 'l', 'l', 0}
+ winmmdll = [...]uint16{'w', 'i', 'n', 'm', 'm', '.', 'd', 'l', 'l', 0}
+ ws2_32dll = [...]uint16{'w', 's', '2', '_', '3', '2', '.', 'd', 'l', 'l', 0}
)
// Function to be called by windows CreateThread
@@ -249,11 +256,11 @@
}
_AddVectoredContinueHandler = windowsFindfunc(k32, []byte("AddVectoredContinueHandler\000"))
- bcryptPrimitives := windowsLoadSystemLib(bcryptprimitivesdll[:])
- if bcryptPrimitives == 0 {
- throw("bcryptprimitives.dll not found")
+ a32 := windowsLoadSystemLib(advapi32dll[:])
+ if a32 == 0 {
+ throw("advapi32.dll not found")
}
- _ProcessPrng = windowsFindfunc(bcryptPrimitives, []byte("ProcessPrng\000"))
+ _RtlGenRandom = windowsFindfunc(a32, []byte("SystemFunction036\000"))
n32 := windowsLoadSystemLib(ntdlldll[:])
if n32 == 0 {
@@ -610,7 +617,7 @@
//go:nosplit
func getRandomData(r []byte) {
n := 0
- if stdcall2(_ProcessPrng, uintptr(unsafe.Pointer(&r[0])), uintptr(len(r)))&0xff != 0 {
+ if stdcall2(_RtlGenRandom, uintptr(unsafe.Pointer(&r[0])), uintptr(len(r)))&0xff != 0 {
n = len(r)
}
extendRandom(r, n)

View File

@@ -1,645 +0,0 @@
Subject: [PATCH] Revert "runtime: always use LoadLibraryEx to load system libraries"
Revert "syscall: remove Windows 7 console handle workaround"
Revert "net: remove sysSocket fallback for Windows 7"
Revert "crypto/rand,runtime: switch RtlGenRandom for ProcessPrng"
---
Index: src/crypto/rand/rand.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/crypto/rand/rand.go b/src/crypto/rand/rand.go
--- a/src/crypto/rand/rand.go (revision cb4eee693c382bea4222f20837e26501d40ed892)
+++ b/src/crypto/rand/rand.go (revision 9779155f18b6556a034f7bb79fb7fb2aad1e26a9)
@@ -15,7 +15,7 @@
// available, /dev/urandom otherwise.
// On OpenBSD and macOS, Reader uses getentropy(2).
// On other Unix-like systems, Reader reads from /dev/urandom.
-// On Windows systems, Reader uses the ProcessPrng API.
+// On Windows systems, Reader uses the RtlGenRandom API.
// On JS/Wasm, Reader uses the Web Crypto API.
// On WASIP1/Wasm, Reader uses random_get from wasi_snapshot_preview1.
var Reader io.Reader
Index: src/crypto/rand/rand_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/crypto/rand/rand_windows.go b/src/crypto/rand/rand_windows.go
--- a/src/crypto/rand/rand_windows.go (revision cb4eee693c382bea4222f20837e26501d40ed892)
+++ b/src/crypto/rand/rand_windows.go (revision 9779155f18b6556a034f7bb79fb7fb2aad1e26a9)
@@ -15,8 +15,11 @@
type rngReader struct{}
-func (r *rngReader) Read(b []byte) (int, error) {
- if err := windows.ProcessPrng(b); err != nil {
+func (r *rngReader) Read(b []byte) (n int, err error) {
+ // RtlGenRandom only returns 1<<32-1 bytes at a time. We only read at
+ // most 1<<31-1 bytes at a time so that this works the same on 32-bit
+ // and 64-bit systems.
+ if err := batched(windows.RtlGenRandom, 1<<31-1)(b); err != nil {
return 0, err
}
return len(b), nil
Index: src/internal/syscall/windows/syscall_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/internal/syscall/windows/syscall_windows.go b/src/internal/syscall/windows/syscall_windows.go
--- a/src/internal/syscall/windows/syscall_windows.go (revision cb4eee693c382bea4222f20837e26501d40ed892)
+++ b/src/internal/syscall/windows/syscall_windows.go (revision 9779155f18b6556a034f7bb79fb7fb2aad1e26a9)
@@ -384,7 +384,7 @@
//sys DestroyEnvironmentBlock(block *uint16) (err error) = userenv.DestroyEnvironmentBlock
//sys CreateEvent(eventAttrs *SecurityAttributes, manualReset uint32, initialState uint32, name *uint16) (handle syscall.Handle, err error) = kernel32.CreateEventW
-//sys ProcessPrng(buf []byte) (err error) = bcryptprimitives.ProcessPrng
+//sys RtlGenRandom(buf []byte) (err error) = advapi32.SystemFunction036
type FILE_ID_BOTH_DIR_INFO struct {
NextEntryOffset uint32
Index: src/internal/syscall/windows/zsyscall_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/internal/syscall/windows/zsyscall_windows.go b/src/internal/syscall/windows/zsyscall_windows.go
--- a/src/internal/syscall/windows/zsyscall_windows.go (revision cb4eee693c382bea4222f20837e26501d40ed892)
+++ b/src/internal/syscall/windows/zsyscall_windows.go (revision 9779155f18b6556a034f7bb79fb7fb2aad1e26a9)
@@ -37,14 +37,13 @@
}
var (
- modadvapi32 = syscall.NewLazyDLL(sysdll.Add("advapi32.dll"))
- modbcryptprimitives = syscall.NewLazyDLL(sysdll.Add("bcryptprimitives.dll"))
- modiphlpapi = syscall.NewLazyDLL(sysdll.Add("iphlpapi.dll"))
- modkernel32 = syscall.NewLazyDLL(sysdll.Add("kernel32.dll"))
- modnetapi32 = syscall.NewLazyDLL(sysdll.Add("netapi32.dll"))
- modpsapi = syscall.NewLazyDLL(sysdll.Add("psapi.dll"))
- moduserenv = syscall.NewLazyDLL(sysdll.Add("userenv.dll"))
- modws2_32 = syscall.NewLazyDLL(sysdll.Add("ws2_32.dll"))
+ modadvapi32 = syscall.NewLazyDLL(sysdll.Add("advapi32.dll"))
+ modiphlpapi = syscall.NewLazyDLL(sysdll.Add("iphlpapi.dll"))
+ modkernel32 = syscall.NewLazyDLL(sysdll.Add("kernel32.dll"))
+ modnetapi32 = syscall.NewLazyDLL(sysdll.Add("netapi32.dll"))
+ modpsapi = syscall.NewLazyDLL(sysdll.Add("psapi.dll"))
+ moduserenv = syscall.NewLazyDLL(sysdll.Add("userenv.dll"))
+ modws2_32 = syscall.NewLazyDLL(sysdll.Add("ws2_32.dll"))
procAdjustTokenPrivileges = modadvapi32.NewProc("AdjustTokenPrivileges")
procDuplicateTokenEx = modadvapi32.NewProc("DuplicateTokenEx")
@@ -56,7 +55,7 @@
procQueryServiceStatus = modadvapi32.NewProc("QueryServiceStatus")
procRevertToSelf = modadvapi32.NewProc("RevertToSelf")
procSetTokenInformation = modadvapi32.NewProc("SetTokenInformation")
- procProcessPrng = modbcryptprimitives.NewProc("ProcessPrng")
+ procSystemFunction036 = modadvapi32.NewProc("SystemFunction036")
procGetAdaptersAddresses = modiphlpapi.NewProc("GetAdaptersAddresses")
procCreateEventW = modkernel32.NewProc("CreateEventW")
procGetACP = modkernel32.NewProc("GetACP")
@@ -180,12 +179,12 @@
return
}
-func ProcessPrng(buf []byte) (err error) {
+func RtlGenRandom(buf []byte) (err error) {
var _p0 *byte
if len(buf) > 0 {
_p0 = &buf[0]
}
- r1, _, e1 := syscall.Syscall(procProcessPrng.Addr(), 2, uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), 0)
+ r1, _, e1 := syscall.Syscall(procSystemFunction036.Addr(), 2, uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), 0)
if r1 == 0 {
err = errnoErr(e1)
}
Index: src/runtime/os_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/runtime/os_windows.go b/src/runtime/os_windows.go
--- a/src/runtime/os_windows.go (revision cb4eee693c382bea4222f20837e26501d40ed892)
+++ b/src/runtime/os_windows.go (revision 83ff9782e024cb328b690cbf0da4e7848a327f4f)
@@ -40,8 +40,8 @@
//go:cgo_import_dynamic runtime._GetSystemInfo GetSystemInfo%1 "kernel32.dll"
//go:cgo_import_dynamic runtime._GetThreadContext GetThreadContext%2 "kernel32.dll"
//go:cgo_import_dynamic runtime._SetThreadContext SetThreadContext%2 "kernel32.dll"
-//go:cgo_import_dynamic runtime._LoadLibraryExW LoadLibraryExW%3 "kernel32.dll"
//go:cgo_import_dynamic runtime._LoadLibraryW LoadLibraryW%1 "kernel32.dll"
+//go:cgo_import_dynamic runtime._LoadLibraryA LoadLibraryA%1 "kernel32.dll"
//go:cgo_import_dynamic runtime._PostQueuedCompletionStatus PostQueuedCompletionStatus%4 "kernel32.dll"
//go:cgo_import_dynamic runtime._QueryPerformanceCounter QueryPerformanceCounter%1 "kernel32.dll"
//go:cgo_import_dynamic runtime._RaiseFailFastException RaiseFailFastException%3 "kernel32.dll"
@@ -74,7 +74,6 @@
// Following syscalls are available on every Windows PC.
// All these variables are set by the Windows executable
// loader before the Go program starts.
- _AddVectoredContinueHandler,
_AddVectoredExceptionHandler,
_CloseHandle,
_CreateEventA,
@@ -98,8 +97,8 @@
_GetSystemInfo,
_GetThreadContext,
_SetThreadContext,
- _LoadLibraryExW,
_LoadLibraryW,
+ _LoadLibraryA,
_PostQueuedCompletionStatus,
_QueryPerformanceCounter,
_RaiseFailFastException,
@@ -127,8 +126,23 @@
_WriteFile,
_ stdFunction
- // Use ProcessPrng to generate cryptographically random data.
- _ProcessPrng stdFunction
+ // Following syscalls are only available on some Windows PCs.
+ // We will load syscalls, if available, before using them.
+ _AddDllDirectory,
+ _AddVectoredContinueHandler,
+ _LoadLibraryExA,
+ _LoadLibraryExW,
+ _ stdFunction
+
+ // Use RtlGenRandom to generate cryptographically random data.
+ // This approach has been recommended by Microsoft (see issue
+ // 15589 for details).
+ // The RtlGenRandom is not listed in advapi32.dll, instead
+ // RtlGenRandom function can be found by searching for SystemFunction036.
+ // Also some versions of Mingw cannot link to SystemFunction036
+ // when building executable as Cgo. So load SystemFunction036
+ // manually during runtime startup.
+ _RtlGenRandom stdFunction
// Load ntdll.dll manually during startup, otherwise Mingw
// links wrong printf function to cgo executable (see issue
@@ -143,14 +157,6 @@
_ stdFunction
)
-var (
- bcryptprimitivesdll = [...]uint16{'b', 'c', 'r', 'y', 'p', 't', 'p', 'r', 'i', 'm', 'i', 't', 'i', 'v', 'e', 's', '.', 'd', 'l', 'l', 0}
- ntdlldll = [...]uint16{'n', 't', 'd', 'l', 'l', '.', 'd', 'l', 'l', 0}
- powrprofdll = [...]uint16{'p', 'o', 'w', 'r', 'p', 'r', 'o', 'f', '.', 'd', 'l', 'l', 0}
- winmmdll = [...]uint16{'w', 'i', 'n', 'm', 'm', '.', 'd', 'l', 'l', 0}
- ws2_32dll = [...]uint16{'w', 's', '2', '_', '3', '2', '.', 'd', 'l', 'l', 0}
-)
-
// Function to be called by windows CreateThread
// to start new os thread.
func tstart_stdcall(newm *m)
@@ -239,25 +245,51 @@
return unsafe.String(&sysDirectory[0], sysDirectoryLen)
}
-func windowsLoadSystemLib(name []uint16) uintptr {
- return stdcall3(_LoadLibraryExW, uintptr(unsafe.Pointer(&name[0])), 0, _LOAD_LIBRARY_SEARCH_SYSTEM32)
+//go:linkname syscall_getSystemDirectory syscall.getSystemDirectory
+func syscall_getSystemDirectory() string {
+ return unsafe.String(&sysDirectory[0], sysDirectoryLen)
+}
+
+func windowsLoadSystemLib(name []byte) uintptr {
+ if useLoadLibraryEx {
+ return stdcall3(_LoadLibraryExA, uintptr(unsafe.Pointer(&name[0])), 0, _LOAD_LIBRARY_SEARCH_SYSTEM32)
+ } else {
+ absName := append(sysDirectory[:sysDirectoryLen], name...)
+ return stdcall1(_LoadLibraryA, uintptr(unsafe.Pointer(&absName[0])))
+ }
}
func loadOptionalSyscalls() {
- bcryptPrimitives := windowsLoadSystemLib(bcryptprimitivesdll[:])
- if bcryptPrimitives == 0 {
- throw("bcryptprimitives.dll not found")
+ var kernel32dll = []byte("kernel32.dll\000")
+ k32 := stdcall1(_LoadLibraryA, uintptr(unsafe.Pointer(&kernel32dll[0])))
+ if k32 == 0 {
+ throw("kernel32.dll not found")
}
- _ProcessPrng = windowsFindfunc(bcryptPrimitives, []byte("ProcessPrng\000"))
+ _AddDllDirectory = windowsFindfunc(k32, []byte("AddDllDirectory\000"))
+ _AddVectoredContinueHandler = windowsFindfunc(k32, []byte("AddVectoredContinueHandler\000"))
+ _LoadLibraryExA = windowsFindfunc(k32, []byte("LoadLibraryExA\000"))
+ _LoadLibraryExW = windowsFindfunc(k32, []byte("LoadLibraryExW\000"))
+ useLoadLibraryEx = (_LoadLibraryExW != nil && _LoadLibraryExA != nil && _AddDllDirectory != nil)
+
+ initSysDirectory()
- n32 := windowsLoadSystemLib(ntdlldll[:])
+ var advapi32dll = []byte("advapi32.dll\000")
+ a32 := windowsLoadSystemLib(advapi32dll)
+ if a32 == 0 {
+ throw("advapi32.dll not found")
+ }
+ _RtlGenRandom = windowsFindfunc(a32, []byte("SystemFunction036\000"))
+
+ var ntdll = []byte("ntdll.dll\000")
+ n32 := windowsLoadSystemLib(ntdll)
if n32 == 0 {
throw("ntdll.dll not found")
}
_RtlGetCurrentPeb = windowsFindfunc(n32, []byte("RtlGetCurrentPeb\000"))
_RtlGetNtVersionNumbers = windowsFindfunc(n32, []byte("RtlGetNtVersionNumbers\000"))
- m32 := windowsLoadSystemLib(winmmdll[:])
+ var winmmdll = []byte("winmm.dll\000")
+ m32 := windowsLoadSystemLib(winmmdll)
if m32 == 0 {
throw("winmm.dll not found")
}
@@ -267,7 +299,8 @@
throw("timeBegin/EndPeriod not found")
}
- ws232 := windowsLoadSystemLib(ws2_32dll[:])
+ var ws232dll = []byte("ws2_32.dll\000")
+ ws232 := windowsLoadSystemLib(ws232dll)
if ws232 == 0 {
throw("ws2_32.dll not found")
}
@@ -286,7 +319,7 @@
context uintptr
}
- powrprof := windowsLoadSystemLib(powrprofdll[:])
+ powrprof := windowsLoadSystemLib([]byte("powrprof.dll\000"))
if powrprof == 0 {
return // Running on Windows 7, where we don't need it anyway.
}
@@ -360,6 +393,22 @@
// in sys_windows_386.s and sys_windows_amd64.s:
func getlasterror() uint32
+// When loading DLLs, we prefer to use LoadLibraryEx with
+// LOAD_LIBRARY_SEARCH_* flags, if available. LoadLibraryEx is not
+// available on old Windows, though, and the LOAD_LIBRARY_SEARCH_*
+// flags are not available on some versions of Windows without a
+// security patch.
+//
+// https://msdn.microsoft.com/en-us/library/ms684179(v=vs.85).aspx says:
+// "Windows 7, Windows Server 2008 R2, Windows Vista, and Windows
+// Server 2008: The LOAD_LIBRARY_SEARCH_* flags are available on
+// systems that have KB2533623 installed. To determine whether the
+// flags are available, use GetProcAddress to get the address of the
+// AddDllDirectory, RemoveDllDirectory, or SetDefaultDllDirectories
+// function. If GetProcAddress succeeds, the LOAD_LIBRARY_SEARCH_*
+// flags can be used with LoadLibraryEx."
+var useLoadLibraryEx bool
+
var timeBeginPeriodRetValue uint32
// osRelaxMinNS indicates that sysmon shouldn't osRelax if the next
@@ -507,7 +556,6 @@
initHighResTimer()
timeBeginPeriodRetValue = osRelax(false)
- initSysDirectory()
initLongPathSupport()
ncpu = getproccount()
@@ -524,7 +572,7 @@
//go:nosplit
func readRandom(r []byte) int {
n := 0
- if stdcall2(_ProcessPrng, uintptr(unsafe.Pointer(&r[0])), uintptr(len(r)))&0xff != 0 {
+ if stdcall2(_RtlGenRandom, uintptr(unsafe.Pointer(&r[0])), uintptr(len(r)))&0xff != 0 {
n = len(r)
}
return n
Index: src/net/hook_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/hook_windows.go b/src/net/hook_windows.go
--- a/src/net/hook_windows.go (revision 9779155f18b6556a034f7bb79fb7fb2aad1e26a9)
+++ b/src/net/hook_windows.go (revision ef0606261340e608017860b423ffae5c1ce78239)
@@ -13,6 +13,7 @@
hostsFilePath = windows.GetSystemDirectory() + "/Drivers/etc/hosts"
// Placeholders for socket system calls.
+ socketFunc func(int, int, int) (syscall.Handle, error) = syscall.Socket
wsaSocketFunc func(int32, int32, int32, *syscall.WSAProtocolInfo, uint32, uint32) (syscall.Handle, error) = windows.WSASocket
connectFunc func(syscall.Handle, syscall.Sockaddr) error = syscall.Connect
listenFunc func(syscall.Handle, int) error = syscall.Listen
Index: src/net/internal/socktest/main_test.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/internal/socktest/main_test.go b/src/net/internal/socktest/main_test.go
--- a/src/net/internal/socktest/main_test.go (revision 9779155f18b6556a034f7bb79fb7fb2aad1e26a9)
+++ b/src/net/internal/socktest/main_test.go (revision ef0606261340e608017860b423ffae5c1ce78239)
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !plan9 && !wasip1 && !windows
+//go:build !js && !plan9 && !wasip1
package socktest_test
Index: src/net/internal/socktest/main_windows_test.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/internal/socktest/main_windows_test.go b/src/net/internal/socktest/main_windows_test.go
new file mode 100644
--- /dev/null (revision ef0606261340e608017860b423ffae5c1ce78239)
+++ b/src/net/internal/socktest/main_windows_test.go (revision ef0606261340e608017860b423ffae5c1ce78239)
@@ -0,0 +1,22 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package socktest_test
+
+import "syscall"
+
+var (
+ socketFunc func(int, int, int) (syscall.Handle, error)
+ closeFunc func(syscall.Handle) error
+)
+
+func installTestHooks() {
+ socketFunc = sw.Socket
+ closeFunc = sw.Closesocket
+}
+
+func uninstallTestHooks() {
+ socketFunc = syscall.Socket
+ closeFunc = syscall.Closesocket
+}
Index: src/net/internal/socktest/sys_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/internal/socktest/sys_windows.go b/src/net/internal/socktest/sys_windows.go
--- a/src/net/internal/socktest/sys_windows.go (revision 9779155f18b6556a034f7bb79fb7fb2aad1e26a9)
+++ b/src/net/internal/socktest/sys_windows.go (revision ef0606261340e608017860b423ffae5c1ce78239)
@@ -9,6 +9,38 @@
"syscall"
)
+// Socket wraps [syscall.Socket].
+func (sw *Switch) Socket(family, sotype, proto int) (s syscall.Handle, err error) {
+ sw.once.Do(sw.init)
+
+ so := &Status{Cookie: cookie(family, sotype, proto)}
+ sw.fmu.RLock()
+ f, _ := sw.fltab[FilterSocket]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return syscall.InvalidHandle, err
+ }
+ s, so.Err = syscall.Socket(family, sotype, proto)
+ if err = af.apply(so); err != nil {
+ if so.Err == nil {
+ syscall.Closesocket(s)
+ }
+ return syscall.InvalidHandle, err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).OpenFailed++
+ return syscall.InvalidHandle, so.Err
+ }
+ nso := sw.addLocked(s, family, sotype, proto)
+ sw.stats.getLocked(nso.Cookie).Opened++
+ return s, nil
+}
+
// WSASocket wraps [syscall.WSASocket].
func (sw *Switch) WSASocket(family, sotype, proto int32, protinfo *syscall.WSAProtocolInfo, group uint32, flags uint32) (s syscall.Handle, err error) {
sw.once.Do(sw.init)
Index: src/net/main_windows_test.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/main_windows_test.go b/src/net/main_windows_test.go
--- a/src/net/main_windows_test.go (revision 9779155f18b6556a034f7bb79fb7fb2aad1e26a9)
+++ b/src/net/main_windows_test.go (revision ef0606261340e608017860b423ffae5c1ce78239)
@@ -8,6 +8,7 @@
var (
// Placeholders for saving original socket system calls.
+ origSocket = socketFunc
origWSASocket = wsaSocketFunc
origClosesocket = poll.CloseFunc
origConnect = connectFunc
@@ -17,6 +18,7 @@
)
func installTestHooks() {
+ socketFunc = sw.Socket
wsaSocketFunc = sw.WSASocket
poll.CloseFunc = sw.Closesocket
connectFunc = sw.Connect
@@ -26,6 +28,7 @@
}
func uninstallTestHooks() {
+ socketFunc = origSocket
wsaSocketFunc = origWSASocket
poll.CloseFunc = origClosesocket
connectFunc = origConnect
Index: src/net/sock_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/sock_windows.go b/src/net/sock_windows.go
--- a/src/net/sock_windows.go (revision 9779155f18b6556a034f7bb79fb7fb2aad1e26a9)
+++ b/src/net/sock_windows.go (revision ef0606261340e608017860b423ffae5c1ce78239)
@@ -20,6 +20,21 @@
func sysSocket(family, sotype, proto int) (syscall.Handle, error) {
s, err := wsaSocketFunc(int32(family), int32(sotype), int32(proto),
nil, 0, windows.WSA_FLAG_OVERLAPPED|windows.WSA_FLAG_NO_HANDLE_INHERIT)
+ if err == nil {
+ return s, nil
+ }
+ // WSA_FLAG_NO_HANDLE_INHERIT flag is not supported on some
+ // old versions of Windows, see
+ // https://msdn.microsoft.com/en-us/library/windows/desktop/ms742212(v=vs.85).aspx
+ // for details. Just use syscall.Socket, if windows.WSASocket failed.
+
+ // See ../syscall/exec_unix.go for description of ForkLock.
+ syscall.ForkLock.RLock()
+ s, err = socketFunc(family, sotype, proto)
+ if err == nil {
+ syscall.CloseOnExec(s)
+ }
+ syscall.ForkLock.RUnlock()
if err != nil {
return syscall.InvalidHandle, os.NewSyscallError("socket", err)
}
Index: src/syscall/exec_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/syscall/exec_windows.go b/src/syscall/exec_windows.go
--- a/src/syscall/exec_windows.go (revision 9779155f18b6556a034f7bb79fb7fb2aad1e26a9)
+++ b/src/syscall/exec_windows.go (revision 7f83badcb925a7e743188041cb6e561fc9b5b642)
@@ -14,7 +14,6 @@
"unsafe"
)
-// ForkLock is not used on Windows.
var ForkLock sync.RWMutex
// EscapeArg rewrites command line argument s as prescribed
@@ -317,6 +316,17 @@
}
}
+ var maj, min, build uint32
+ rtlGetNtVersionNumbers(&maj, &min, &build)
+ isWin7 := maj < 6 || (maj == 6 && min <= 1)
+ // NT kernel handles are divisible by 4, with the bottom 3 bits left as
+ // a tag. The fully set tag correlates with the types of handles we're
+ // concerned about here. Except, the kernel will interpret some
+ // special handle values, like -1, -2, and so forth, so kernelbase.dll
+ // checks to see that those bottom three bits are checked, but that top
+ // bit is not checked.
+ isLegacyWin7ConsoleHandle := func(handle Handle) bool { return isWin7 && handle&0x10000003 == 3 }
+
p, _ := GetCurrentProcess()
parentProcess := p
if sys.ParentProcess != 0 {
@@ -325,7 +335,15 @@
fd := make([]Handle, len(attr.Files))
for i := range attr.Files {
if attr.Files[i] > 0 {
- err := DuplicateHandle(p, Handle(attr.Files[i]), parentProcess, &fd[i], 0, true, DUPLICATE_SAME_ACCESS)
+ destinationProcessHandle := parentProcess
+
+ // On Windows 7, console handles aren't real handles, and can only be duplicated
+ // into the current process, not a parent one, which amounts to the same thing.
+ if parentProcess != p && isLegacyWin7ConsoleHandle(Handle(attr.Files[i])) {
+ destinationProcessHandle = p
+ }
+
+ err := DuplicateHandle(p, Handle(attr.Files[i]), destinationProcessHandle, &fd[i], 0, true, DUPLICATE_SAME_ACCESS)
if err != nil {
return 0, 0, err
}
@@ -356,6 +374,14 @@
fd = append(fd, sys.AdditionalInheritedHandles...)
+ // On Windows 7, console handles aren't real handles, so don't pass them
+ // through to PROC_THREAD_ATTRIBUTE_HANDLE_LIST.
+ for i := range fd {
+ if isLegacyWin7ConsoleHandle(fd[i]) {
+ fd[i] = 0
+ }
+ }
+
// The presence of a NULL handle in the list is enough to cause PROC_THREAD_ATTRIBUTE_HANDLE_LIST
// to treat the entire list as empty, so remove NULL handles.
j := 0
Index: src/runtime/syscall_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/runtime/syscall_windows.go b/src/runtime/syscall_windows.go
--- a/src/runtime/syscall_windows.go (revision 7f83badcb925a7e743188041cb6e561fc9b5b642)
+++ b/src/runtime/syscall_windows.go (revision 83ff9782e024cb328b690cbf0da4e7848a327f4f)
@@ -413,23 +413,36 @@
const _LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
+// When available, this function will use LoadLibraryEx with the filename
+// parameter and the important SEARCH_SYSTEM32 argument. But on systems that
+// do not have that option, absoluteFilepath should contain a fallback
+// to the full path inside of system32 for use with vanilla LoadLibrary.
+//
//go:linkname syscall_loadsystemlibrary syscall.loadsystemlibrary
//go:nosplit
//go:cgo_unsafe_args
-func syscall_loadsystemlibrary(filename *uint16) (handle, err uintptr) {
+func syscall_loadsystemlibrary(filename *uint16, absoluteFilepath *uint16) (handle, err uintptr) {
lockOSThread()
c := &getg().m.syscall
- c.fn = getLoadLibraryEx()
- c.n = 3
- args := struct {
- lpFileName *uint16
- hFile uintptr // always 0
- flags uint32
- }{filename, 0, _LOAD_LIBRARY_SEARCH_SYSTEM32}
- c.args = uintptr(noescape(unsafe.Pointer(&args)))
+
+ if useLoadLibraryEx {
+ c.fn = getLoadLibraryEx()
+ c.n = 3
+ args := struct {
+ lpFileName *uint16
+ hFile uintptr // always 0
+ flags uint32
+ }{filename, 0, _LOAD_LIBRARY_SEARCH_SYSTEM32}
+ c.args = uintptr(noescape(unsafe.Pointer(&args)))
+ } else {
+ c.fn = getLoadLibrary()
+ c.n = 1
+ c.args = uintptr(noescape(unsafe.Pointer(&absoluteFilepath)))
+ }
cgocall(asmstdcallAddr, unsafe.Pointer(c))
KeepAlive(filename)
+ KeepAlive(absoluteFilepath)
handle = c.r1
if handle == 0 {
err = c.err
Index: src/syscall/dll_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/syscall/dll_windows.go b/src/syscall/dll_windows.go
--- a/src/syscall/dll_windows.go (revision 7f83badcb925a7e743188041cb6e561fc9b5b642)
+++ b/src/syscall/dll_windows.go (revision 83ff9782e024cb328b690cbf0da4e7848a327f4f)
@@ -44,7 +44,7 @@
func SyscallN(trap uintptr, args ...uintptr) (r1, r2 uintptr, err Errno)
func loadlibrary(filename *uint16) (handle uintptr, err Errno)
-func loadsystemlibrary(filename *uint16) (handle uintptr, err Errno)
+func loadsystemlibrary(filename *uint16, absoluteFilepath *uint16) (handle uintptr, err Errno)
func getprocaddress(handle uintptr, procname *uint8) (proc uintptr, err Errno)
// A DLL implements access to a single DLL.
@@ -53,6 +53,9 @@
Handle Handle
}
+//go:linkname getSystemDirectory
+func getSystemDirectory() string // Implemented in runtime package.
+
// LoadDLL loads the named DLL file into memory.
//
// If name is not an absolute path and is not a known system DLL used by
@@ -69,7 +72,11 @@
var h uintptr
var e Errno
if sysdll.IsSystemDLL[name] {
- h, e = loadsystemlibrary(namep)
+ absoluteFilepathp, err := UTF16PtrFromString(getSystemDirectory() + name)
+ if err != nil {
+ return nil, err
+ }
+ h, e = loadsystemlibrary(namep, absoluteFilepathp)
} else {
h, e = loadlibrary(namep)
}

View File

@@ -1,643 +0,0 @@
Subject: [PATCH] Revert "runtime: always use LoadLibraryEx to load system libraries"
Revert "syscall: remove Windows 7 console handle workaround"
Revert "net: remove sysSocket fallback for Windows 7"
Revert "crypto/rand,runtime: switch RtlGenRandom for ProcessPrng"
---
Index: src/crypto/rand/rand.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/crypto/rand/rand.go b/src/crypto/rand/rand.go
--- a/src/crypto/rand/rand.go (revision 6885bad7dd86880be6929c02085e5c7a67ff2887)
+++ b/src/crypto/rand/rand.go (revision 9ac42137ef6730e8b7daca016ece831297a1d75b)
@@ -16,7 +16,7 @@
// - On macOS and iOS, Reader uses arc4random_buf(3).
// - On OpenBSD and NetBSD, Reader uses getentropy(2).
// - On other Unix-like systems, Reader reads from /dev/urandom.
-// - On Windows, Reader uses the ProcessPrng API.
+// - On Windows systems, Reader uses the RtlGenRandom API.
// - On js/wasm, Reader uses the Web Crypto API.
// - On wasip1/wasm, Reader uses random_get from wasi_snapshot_preview1.
var Reader io.Reader
Index: src/crypto/rand/rand_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/crypto/rand/rand_windows.go b/src/crypto/rand/rand_windows.go
--- a/src/crypto/rand/rand_windows.go (revision 6885bad7dd86880be6929c02085e5c7a67ff2887)
+++ b/src/crypto/rand/rand_windows.go (revision 9ac42137ef6730e8b7daca016ece831297a1d75b)
@@ -15,8 +15,11 @@
type rngReader struct{}
-func (r *rngReader) Read(b []byte) (int, error) {
- if err := windows.ProcessPrng(b); err != nil {
+func (r *rngReader) Read(b []byte) (n int, err error) {
+ // RtlGenRandom only returns 1<<32-1 bytes at a time. We only read at
+ // most 1<<31-1 bytes at a time so that this works the same on 32-bit
+ // and 64-bit systems.
+ if err := batched(windows.RtlGenRandom, 1<<31-1)(b); err != nil {
return 0, err
}
return len(b), nil
Index: src/internal/syscall/windows/syscall_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/internal/syscall/windows/syscall_windows.go b/src/internal/syscall/windows/syscall_windows.go
--- a/src/internal/syscall/windows/syscall_windows.go (revision 6885bad7dd86880be6929c02085e5c7a67ff2887)
+++ b/src/internal/syscall/windows/syscall_windows.go (revision 9ac42137ef6730e8b7daca016ece831297a1d75b)
@@ -414,7 +414,7 @@
//sys DestroyEnvironmentBlock(block *uint16) (err error) = userenv.DestroyEnvironmentBlock
//sys CreateEvent(eventAttrs *SecurityAttributes, manualReset uint32, initialState uint32, name *uint16) (handle syscall.Handle, err error) = kernel32.CreateEventW
-//sys ProcessPrng(buf []byte) (err error) = bcryptprimitives.ProcessPrng
+//sys RtlGenRandom(buf []byte) (err error) = advapi32.SystemFunction036
type FILE_ID_BOTH_DIR_INFO struct {
NextEntryOffset uint32
Index: src/internal/syscall/windows/zsyscall_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/internal/syscall/windows/zsyscall_windows.go b/src/internal/syscall/windows/zsyscall_windows.go
--- a/src/internal/syscall/windows/zsyscall_windows.go (revision 6885bad7dd86880be6929c02085e5c7a67ff2887)
+++ b/src/internal/syscall/windows/zsyscall_windows.go (revision 9ac42137ef6730e8b7daca016ece831297a1d75b)
@@ -38,7 +38,6 @@
var (
modadvapi32 = syscall.NewLazyDLL(sysdll.Add("advapi32.dll"))
- modbcryptprimitives = syscall.NewLazyDLL(sysdll.Add("bcryptprimitives.dll"))
modiphlpapi = syscall.NewLazyDLL(sysdll.Add("iphlpapi.dll"))
modkernel32 = syscall.NewLazyDLL(sysdll.Add("kernel32.dll"))
modnetapi32 = syscall.NewLazyDLL(sysdll.Add("netapi32.dll"))
@@ -57,7 +56,7 @@
procQueryServiceStatus = modadvapi32.NewProc("QueryServiceStatus")
procRevertToSelf = modadvapi32.NewProc("RevertToSelf")
procSetTokenInformation = modadvapi32.NewProc("SetTokenInformation")
- procProcessPrng = modbcryptprimitives.NewProc("ProcessPrng")
+ procSystemFunction036 = modadvapi32.NewProc("SystemFunction036")
procGetAdaptersAddresses = modiphlpapi.NewProc("GetAdaptersAddresses")
procCreateEventW = modkernel32.NewProc("CreateEventW")
procGetACP = modkernel32.NewProc("GetACP")
@@ -183,12 +182,12 @@
return
}
-func ProcessPrng(buf []byte) (err error) {
+func RtlGenRandom(buf []byte) (err error) {
var _p0 *byte
if len(buf) > 0 {
_p0 = &buf[0]
}
- r1, _, e1 := syscall.Syscall(procProcessPrng.Addr(), 2, uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), 0)
+ r1, _, e1 := syscall.Syscall(procSystemFunction036.Addr(), 2, uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), 0)
if r1 == 0 {
err = errnoErr(e1)
}
Index: src/runtime/os_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/runtime/os_windows.go b/src/runtime/os_windows.go
--- a/src/runtime/os_windows.go (revision 6885bad7dd86880be6929c02085e5c7a67ff2887)
+++ b/src/runtime/os_windows.go (revision 69e2eed6dd0f6d815ebf15797761c13f31213dd6)
@@ -39,8 +39,8 @@
//go:cgo_import_dynamic runtime._GetSystemInfo GetSystemInfo%1 "kernel32.dll"
//go:cgo_import_dynamic runtime._GetThreadContext GetThreadContext%2 "kernel32.dll"
//go:cgo_import_dynamic runtime._SetThreadContext SetThreadContext%2 "kernel32.dll"
-//go:cgo_import_dynamic runtime._LoadLibraryExW LoadLibraryExW%3 "kernel32.dll"
//go:cgo_import_dynamic runtime._LoadLibraryW LoadLibraryW%1 "kernel32.dll"
+//go:cgo_import_dynamic runtime._LoadLibraryA LoadLibraryA%1 "kernel32.dll"
//go:cgo_import_dynamic runtime._PostQueuedCompletionStatus PostQueuedCompletionStatus%4 "kernel32.dll"
//go:cgo_import_dynamic runtime._QueryPerformanceCounter QueryPerformanceCounter%1 "kernel32.dll"
//go:cgo_import_dynamic runtime._QueryPerformanceFrequency QueryPerformanceFrequency%1 "kernel32.dll"
@@ -74,7 +74,6 @@
// Following syscalls are available on every Windows PC.
// All these variables are set by the Windows executable
// loader before the Go program starts.
- _AddVectoredContinueHandler,
_AddVectoredExceptionHandler,
_CloseHandle,
_CreateEventA,
@@ -97,8 +96,8 @@
_GetSystemInfo,
_GetThreadContext,
_SetThreadContext,
- _LoadLibraryExW,
_LoadLibraryW,
+ _LoadLibraryA,
_PostQueuedCompletionStatus,
_QueryPerformanceCounter,
_QueryPerformanceFrequency,
@@ -127,8 +126,23 @@
_WriteFile,
_ stdFunction
- // Use ProcessPrng to generate cryptographically random data.
- _ProcessPrng stdFunction
+ // Following syscalls are only available on some Windows PCs.
+ // We will load syscalls, if available, before using them.
+ _AddDllDirectory,
+ _AddVectoredContinueHandler,
+ _LoadLibraryExA,
+ _LoadLibraryExW,
+ _ stdFunction
+
+ // Use RtlGenRandom to generate cryptographically random data.
+ // This approach has been recommended by Microsoft (see issue
+ // 15589 for details).
+ // The RtlGenRandom is not listed in advapi32.dll, instead
+ // RtlGenRandom function can be found by searching for SystemFunction036.
+ // Also some versions of Mingw cannot link to SystemFunction036
+ // when building executable as Cgo. So load SystemFunction036
+ // manually during runtime startup.
+ _RtlGenRandom stdFunction
// Load ntdll.dll manually during startup, otherwise Mingw
// links wrong printf function to cgo executable (see issue
@@ -145,13 +159,6 @@
_ stdFunction
)
-var (
- bcryptprimitivesdll = [...]uint16{'b', 'c', 'r', 'y', 'p', 't', 'p', 'r', 'i', 'm', 'i', 't', 'i', 'v', 'e', 's', '.', 'd', 'l', 'l', 0}
- ntdlldll = [...]uint16{'n', 't', 'd', 'l', 'l', '.', 'd', 'l', 'l', 0}
- powrprofdll = [...]uint16{'p', 'o', 'w', 'r', 'p', 'r', 'o', 'f', '.', 'd', 'l', 'l', 0}
- winmmdll = [...]uint16{'w', 'i', 'n', 'm', 'm', '.', 'd', 'l', 'l', 0}
-)
-
// Function to be called by windows CreateThread
// to start new os thread.
func tstart_stdcall(newm *m)
@@ -244,8 +251,18 @@
return unsafe.String(&sysDirectory[0], sysDirectoryLen)
}
-func windowsLoadSystemLib(name []uint16) uintptr {
- return stdcall3(_LoadLibraryExW, uintptr(unsafe.Pointer(&name[0])), 0, _LOAD_LIBRARY_SEARCH_SYSTEM32)
+//go:linkname syscall_getSystemDirectory syscall.getSystemDirectory
+func syscall_getSystemDirectory() string {
+ return unsafe.String(&sysDirectory[0], sysDirectoryLen)
+}
+
+func windowsLoadSystemLib(name []byte) uintptr {
+ if useLoadLibraryEx {
+ return stdcall3(_LoadLibraryExA, uintptr(unsafe.Pointer(&name[0])), 0, _LOAD_LIBRARY_SEARCH_SYSTEM32)
+ } else {
+ absName := append(sysDirectory[:sysDirectoryLen], name...)
+ return stdcall1(_LoadLibraryA, uintptr(unsafe.Pointer(&absName[0])))
+ }
}
//go:linkname windows_QueryPerformanceCounter internal/syscall/windows.QueryPerformanceCounter
@@ -263,13 +280,28 @@
}
func loadOptionalSyscalls() {
- bcryptPrimitives := windowsLoadSystemLib(bcryptprimitivesdll[:])
- if bcryptPrimitives == 0 {
- throw("bcryptprimitives.dll not found")
+ var kernel32dll = []byte("kernel32.dll\000")
+ k32 := stdcall1(_LoadLibraryA, uintptr(unsafe.Pointer(&kernel32dll[0])))
+ if k32 == 0 {
+ throw("kernel32.dll not found")
}
- _ProcessPrng = windowsFindfunc(bcryptPrimitives, []byte("ProcessPrng\000"))
+ _AddDllDirectory = windowsFindfunc(k32, []byte("AddDllDirectory\000"))
+ _AddVectoredContinueHandler = windowsFindfunc(k32, []byte("AddVectoredContinueHandler\000"))
+ _LoadLibraryExA = windowsFindfunc(k32, []byte("LoadLibraryExA\000"))
+ _LoadLibraryExW = windowsFindfunc(k32, []byte("LoadLibraryExW\000"))
+ useLoadLibraryEx = (_LoadLibraryExW != nil && _LoadLibraryExA != nil && _AddDllDirectory != nil)
+
+ initSysDirectory()
- n32 := windowsLoadSystemLib(ntdlldll[:])
+ var advapi32dll = []byte("advapi32.dll\000")
+ a32 := windowsLoadSystemLib(advapi32dll)
+ if a32 == 0 {
+ throw("advapi32.dll not found")
+ }
+ _RtlGenRandom = windowsFindfunc(a32, []byte("SystemFunction036\000"))
+
+ var ntdll = []byte("ntdll.dll\000")
+ n32 := windowsLoadSystemLib(ntdll)
if n32 == 0 {
throw("ntdll.dll not found")
}
@@ -298,7 +330,7 @@
context uintptr
}
- powrprof := windowsLoadSystemLib(powrprofdll[:])
+ powrprof := windowsLoadSystemLib([]byte("powrprof.dll\000"))
if powrprof == 0 {
return // Running on Windows 7, where we don't need it anyway.
}
@@ -357,6 +389,22 @@
// in sys_windows_386.s and sys_windows_amd64.s:
func getlasterror() uint32
+// When loading DLLs, we prefer to use LoadLibraryEx with
+// LOAD_LIBRARY_SEARCH_* flags, if available. LoadLibraryEx is not
+// available on old Windows, though, and the LOAD_LIBRARY_SEARCH_*
+// flags are not available on some versions of Windows without a
+// security patch.
+//
+// https://msdn.microsoft.com/en-us/library/ms684179(v=vs.85).aspx says:
+// "Windows 7, Windows Server 2008 R2, Windows Vista, and Windows
+// Server 2008: The LOAD_LIBRARY_SEARCH_* flags are available on
+// systems that have KB2533623 installed. To determine whether the
+// flags are available, use GetProcAddress to get the address of the
+// AddDllDirectory, RemoveDllDirectory, or SetDefaultDllDirectories
+// function. If GetProcAddress succeeds, the LOAD_LIBRARY_SEARCH_*
+// flags can be used with LoadLibraryEx."
+var useLoadLibraryEx bool
+
var timeBeginPeriodRetValue uint32
// osRelaxMinNS indicates that sysmon shouldn't osRelax if the next
@@ -430,7 +478,8 @@
// Only load winmm.dll if we need it.
// This avoids a dependency on winmm.dll for Go programs
// that run on new Windows versions.
- m32 := windowsLoadSystemLib(winmmdll[:])
+ var winmmdll = []byte("winmm.dll\000")
+ m32 := windowsLoadSystemLib(winmmdll)
if m32 == 0 {
print("runtime: LoadLibraryExW failed; errno=", getlasterror(), "\n")
throw("winmm.dll not found")
@@ -471,6 +520,28 @@
canUseLongPaths = true
}
+var osVersionInfo struct {
+ majorVersion uint32
+ minorVersion uint32
+ buildNumber uint32
+}
+
+func initOsVersionInfo() {
+ info := _OSVERSIONINFOW{}
+ info.osVersionInfoSize = uint32(unsafe.Sizeof(info))
+ stdcall1(_RtlGetVersion, uintptr(unsafe.Pointer(&info)))
+ osVersionInfo.majorVersion = info.majorVersion
+ osVersionInfo.minorVersion = info.minorVersion
+ osVersionInfo.buildNumber = info.buildNumber
+}
+
+//go:linkname rtlGetNtVersionNumbers syscall.rtlGetNtVersionNumbers
+func rtlGetNtVersionNumbers(majorVersion *uint32, minorVersion *uint32, buildNumber *uint32) {
+ *majorVersion = osVersionInfo.majorVersion
+ *minorVersion = osVersionInfo.minorVersion
+ *buildNumber = osVersionInfo.buildNumber
+}
+
func osinit() {
asmstdcallAddr = unsafe.Pointer(abi.FuncPCABI0(asmstdcall))
@@ -483,8 +554,8 @@
initHighResTimer()
timeBeginPeriodRetValue = osRelax(false)
- initSysDirectory()
initLongPathSupport()
+ initOsVersionInfo()
ncpu = getproccount()
@@ -500,7 +571,7 @@
//go:nosplit
func readRandom(r []byte) int {
n := 0
- if stdcall2(_ProcessPrng, uintptr(unsafe.Pointer(&r[0])), uintptr(len(r)))&0xff != 0 {
+ if stdcall2(_RtlGenRandom, uintptr(unsafe.Pointer(&r[0])), uintptr(len(r)))&0xff != 0 {
n = len(r)
}
return n
Index: src/net/hook_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/hook_windows.go b/src/net/hook_windows.go
--- a/src/net/hook_windows.go (revision 9ac42137ef6730e8b7daca016ece831297a1d75b)
+++ b/src/net/hook_windows.go (revision 21290de8a4c91408de7c2b5b68757b1e90af49dd)
@@ -13,6 +13,7 @@
hostsFilePath = windows.GetSystemDirectory() + "/Drivers/etc/hosts"
// Placeholders for socket system calls.
+ socketFunc func(int, int, int) (syscall.Handle, error) = syscall.Socket
wsaSocketFunc func(int32, int32, int32, *syscall.WSAProtocolInfo, uint32, uint32) (syscall.Handle, error) = windows.WSASocket
connectFunc func(syscall.Handle, syscall.Sockaddr) error = syscall.Connect
listenFunc func(syscall.Handle, int) error = syscall.Listen
Index: src/net/internal/socktest/main_test.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/internal/socktest/main_test.go b/src/net/internal/socktest/main_test.go
--- a/src/net/internal/socktest/main_test.go (revision 9ac42137ef6730e8b7daca016ece831297a1d75b)
+++ b/src/net/internal/socktest/main_test.go (revision 21290de8a4c91408de7c2b5b68757b1e90af49dd)
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !plan9 && !wasip1 && !windows
+//go:build !js && !plan9 && !wasip1
package socktest_test
Index: src/net/internal/socktest/main_windows_test.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/internal/socktest/main_windows_test.go b/src/net/internal/socktest/main_windows_test.go
new file mode 100644
--- /dev/null (revision 21290de8a4c91408de7c2b5b68757b1e90af49dd)
+++ b/src/net/internal/socktest/main_windows_test.go (revision 21290de8a4c91408de7c2b5b68757b1e90af49dd)
@@ -0,0 +1,22 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package socktest_test
+
+import "syscall"
+
+var (
+ socketFunc func(int, int, int) (syscall.Handle, error)
+ closeFunc func(syscall.Handle) error
+)
+
+func installTestHooks() {
+ socketFunc = sw.Socket
+ closeFunc = sw.Closesocket
+}
+
+func uninstallTestHooks() {
+ socketFunc = syscall.Socket
+ closeFunc = syscall.Closesocket
+}
Index: src/net/internal/socktest/sys_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/internal/socktest/sys_windows.go b/src/net/internal/socktest/sys_windows.go
--- a/src/net/internal/socktest/sys_windows.go (revision 9ac42137ef6730e8b7daca016ece831297a1d75b)
+++ b/src/net/internal/socktest/sys_windows.go (revision 21290de8a4c91408de7c2b5b68757b1e90af49dd)
@@ -9,6 +9,38 @@
"syscall"
)
+// Socket wraps [syscall.Socket].
+func (sw *Switch) Socket(family, sotype, proto int) (s syscall.Handle, err error) {
+ sw.once.Do(sw.init)
+
+ so := &Status{Cookie: cookie(family, sotype, proto)}
+ sw.fmu.RLock()
+ f, _ := sw.fltab[FilterSocket]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return syscall.InvalidHandle, err
+ }
+ s, so.Err = syscall.Socket(family, sotype, proto)
+ if err = af.apply(so); err != nil {
+ if so.Err == nil {
+ syscall.Closesocket(s)
+ }
+ return syscall.InvalidHandle, err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).OpenFailed++
+ return syscall.InvalidHandle, so.Err
+ }
+ nso := sw.addLocked(s, family, sotype, proto)
+ sw.stats.getLocked(nso.Cookie).Opened++
+ return s, nil
+}
+
// WSASocket wraps [syscall.WSASocket].
func (sw *Switch) WSASocket(family, sotype, proto int32, protinfo *syscall.WSAProtocolInfo, group uint32, flags uint32) (s syscall.Handle, err error) {
sw.once.Do(sw.init)
Index: src/net/main_windows_test.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/main_windows_test.go b/src/net/main_windows_test.go
--- a/src/net/main_windows_test.go (revision 9ac42137ef6730e8b7daca016ece831297a1d75b)
+++ b/src/net/main_windows_test.go (revision 21290de8a4c91408de7c2b5b68757b1e90af49dd)
@@ -8,6 +8,7 @@
var (
// Placeholders for saving original socket system calls.
+ origSocket = socketFunc
origWSASocket = wsaSocketFunc
origClosesocket = poll.CloseFunc
origConnect = connectFunc
@@ -17,6 +18,7 @@
)
func installTestHooks() {
+ socketFunc = sw.Socket
wsaSocketFunc = sw.WSASocket
poll.CloseFunc = sw.Closesocket
connectFunc = sw.Connect
@@ -26,6 +28,7 @@
}
func uninstallTestHooks() {
+ socketFunc = origSocket
wsaSocketFunc = origWSASocket
poll.CloseFunc = origClosesocket
connectFunc = origConnect
Index: src/net/sock_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/sock_windows.go b/src/net/sock_windows.go
--- a/src/net/sock_windows.go (revision 9ac42137ef6730e8b7daca016ece831297a1d75b)
+++ b/src/net/sock_windows.go (revision 21290de8a4c91408de7c2b5b68757b1e90af49dd)
@@ -20,6 +20,21 @@
func sysSocket(family, sotype, proto int) (syscall.Handle, error) {
s, err := wsaSocketFunc(int32(family), int32(sotype), int32(proto),
nil, 0, windows.WSA_FLAG_OVERLAPPED|windows.WSA_FLAG_NO_HANDLE_INHERIT)
+ if err == nil {
+ return s, nil
+ }
+ // WSA_FLAG_NO_HANDLE_INHERIT flag is not supported on some
+ // old versions of Windows, see
+ // https://msdn.microsoft.com/en-us/library/windows/desktop/ms742212(v=vs.85).aspx
+ // for details. Just use syscall.Socket, if windows.WSASocket failed.
+
+ // See ../syscall/exec_unix.go for description of ForkLock.
+ syscall.ForkLock.RLock()
+ s, err = socketFunc(family, sotype, proto)
+ if err == nil {
+ syscall.CloseOnExec(s)
+ }
+ syscall.ForkLock.RUnlock()
if err != nil {
return syscall.InvalidHandle, os.NewSyscallError("socket", err)
}
Index: src/syscall/exec_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/syscall/exec_windows.go b/src/syscall/exec_windows.go
--- a/src/syscall/exec_windows.go (revision 9ac42137ef6730e8b7daca016ece831297a1d75b)
+++ b/src/syscall/exec_windows.go (revision 6a31d3fa8e47ddabc10bd97bff10d9a85f4cfb76)
@@ -14,7 +14,6 @@
"unsafe"
)
-// ForkLock is not used on Windows.
var ForkLock sync.RWMutex
// EscapeArg rewrites command line argument s as prescribed
@@ -254,6 +253,9 @@
var zeroProcAttr ProcAttr
var zeroSysProcAttr SysProcAttr
+//go:linkname rtlGetNtVersionNumbers
+func rtlGetNtVersionNumbers(majorVersion *uint32, minorVersion *uint32, buildNumber *uint32)
+
func StartProcess(argv0 string, argv []string, attr *ProcAttr) (pid int, handle uintptr, err error) {
if len(argv0) == 0 {
return 0, 0, EWINDOWS
@@ -317,6 +319,17 @@
}
}
+ var maj, min, build uint32
+ rtlGetNtVersionNumbers(&maj, &min, &build)
+ isWin7 := maj < 6 || (maj == 6 && min <= 1)
+ // NT kernel handles are divisible by 4, with the bottom 3 bits left as
+ // a tag. The fully set tag correlates with the types of handles we're
+ // concerned about here. Except, the kernel will interpret some
+ // special handle values, like -1, -2, and so forth, so kernelbase.dll
+ // checks to see that those bottom three bits are checked, but that top
+ // bit is not checked.
+ isLegacyWin7ConsoleHandle := func(handle Handle) bool { return isWin7 && handle&0x10000003 == 3 }
+
p, _ := GetCurrentProcess()
parentProcess := p
if sys.ParentProcess != 0 {
@@ -325,7 +338,15 @@
fd := make([]Handle, len(attr.Files))
for i := range attr.Files {
if attr.Files[i] > 0 {
- err := DuplicateHandle(p, Handle(attr.Files[i]), parentProcess, &fd[i], 0, true, DUPLICATE_SAME_ACCESS)
+ destinationProcessHandle := parentProcess
+
+ // On Windows 7, console handles aren't real handles, and can only be duplicated
+ // into the current process, not a parent one, which amounts to the same thing.
+ if parentProcess != p && isLegacyWin7ConsoleHandle(Handle(attr.Files[i])) {
+ destinationProcessHandle = p
+ }
+
+ err := DuplicateHandle(p, Handle(attr.Files[i]), destinationProcessHandle, &fd[i], 0, true, DUPLICATE_SAME_ACCESS)
if err != nil {
return 0, 0, err
}
@@ -356,6 +377,14 @@
fd = append(fd, sys.AdditionalInheritedHandles...)
+ // On Windows 7, console handles aren't real handles, so don't pass them
+ // through to PROC_THREAD_ATTRIBUTE_HANDLE_LIST.
+ for i := range fd {
+ if isLegacyWin7ConsoleHandle(fd[i]) {
+ fd[i] = 0
+ }
+ }
+
// The presence of a NULL handle in the list is enough to cause PROC_THREAD_ATTRIBUTE_HANDLE_LIST
// to treat the entire list as empty, so remove NULL handles.
j := 0
Index: src/runtime/syscall_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/runtime/syscall_windows.go b/src/runtime/syscall_windows.go
--- a/src/runtime/syscall_windows.go (revision 6a31d3fa8e47ddabc10bd97bff10d9a85f4cfb76)
+++ b/src/runtime/syscall_windows.go (revision 69e2eed6dd0f6d815ebf15797761c13f31213dd6)
@@ -413,10 +413,20 @@
const _LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
+// When available, this function will use LoadLibraryEx with the filename
+// parameter and the important SEARCH_SYSTEM32 argument. But on systems that
+// do not have that option, absoluteFilepath should contain a fallback
+// to the full path inside of system32 for use with vanilla LoadLibrary.
+//
//go:linkname syscall_loadsystemlibrary syscall.loadsystemlibrary
-func syscall_loadsystemlibrary(filename *uint16) (handle, err uintptr) {
- handle, _, err = syscall_SyscallN(uintptr(unsafe.Pointer(_LoadLibraryExW)), uintptr(unsafe.Pointer(filename)), 0, _LOAD_LIBRARY_SEARCH_SYSTEM32)
+func syscall_loadsystemlibrary(filename *uint16, absoluteFilepath *uint16) (handle, err uintptr) {
+ if useLoadLibraryEx {
+ handle, _, err = syscall_SyscallN(uintptr(unsafe.Pointer(_LoadLibraryExW)), uintptr(unsafe.Pointer(filename)), 0, _LOAD_LIBRARY_SEARCH_SYSTEM32)
+ } else {
+ handle, _, err = syscall_SyscallN(uintptr(unsafe.Pointer(_LoadLibraryW)), uintptr(unsafe.Pointer(absoluteFilepath)))
+ }
KeepAlive(filename)
+ KeepAlive(absoluteFilepath)
if handle != 0 {
err = 0
}
Index: src/syscall/dll_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/syscall/dll_windows.go b/src/syscall/dll_windows.go
--- a/src/syscall/dll_windows.go (revision 6a31d3fa8e47ddabc10bd97bff10d9a85f4cfb76)
+++ b/src/syscall/dll_windows.go (revision 69e2eed6dd0f6d815ebf15797761c13f31213dd6)
@@ -44,7 +44,7 @@
func SyscallN(trap uintptr, args ...uintptr) (r1, r2 uintptr, err Errno)
func loadlibrary(filename *uint16) (handle uintptr, err Errno)
-func loadsystemlibrary(filename *uint16) (handle uintptr, err Errno)
+func loadsystemlibrary(filename *uint16, absoluteFilepath *uint16) (handle uintptr, err Errno)
func getprocaddress(handle uintptr, procname *uint8) (proc uintptr, err Errno)
// A DLL implements access to a single DLL.
@@ -53,6 +53,9 @@
Handle Handle
}
+//go:linkname getSystemDirectory
+func getSystemDirectory() string // Implemented in runtime package.
+
// LoadDLL loads the named DLL file into memory.
//
// If name is not an absolute path and is not a known system DLL used by
@@ -69,7 +72,11 @@
var h uintptr
var e Errno
if sysdll.IsSystemDLL[name] {
- h, e = loadsystemlibrary(namep)
+ absoluteFilepathp, err := UTF16PtrFromString(getSystemDirectory() + name)
+ if err != nil {
+ return nil, err
+ }
+ h, e = loadsystemlibrary(namep, absoluteFilepathp)
} else {
h, e = loadlibrary(namep)
}

View File

@@ -1,657 +0,0 @@
Subject: [PATCH] Revert "runtime: always use LoadLibraryEx to load system libraries"
Revert "syscall: remove Windows 7 console handle workaround"
Revert "net: remove sysSocket fallback for Windows 7"
Revert "crypto/rand,runtime: switch RtlGenRandom for ProcessPrng"
---
Index: src/crypto/internal/sysrand/rand_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/crypto/internal/sysrand/rand_windows.go b/src/crypto/internal/sysrand/rand_windows.go
--- a/src/crypto/internal/sysrand/rand_windows.go (revision 3901409b5d0fb7c85a3e6730a59943cc93b2835c)
+++ b/src/crypto/internal/sysrand/rand_windows.go (revision 2a406dc9f1ea7323d6ca9fccb2fe9ddebb6b1cc8)
@@ -7,5 +7,26 @@
import "internal/syscall/windows"
func read(b []byte) error {
- return windows.ProcessPrng(b)
+ // RtlGenRandom only returns 1<<32-1 bytes at a time. We only read at
+ // most 1<<31-1 bytes at a time so that this works the same on 32-bit
+ // and 64-bit systems.
+ return batched(windows.RtlGenRandom, 1<<31-1)(b)
+}
+
+// batched returns a function that calls f to populate a []byte by chunking it
+// into subslices of, at most, readMax bytes.
+func batched(f func([]byte) error, readMax int) func([]byte) error {
+ return func(out []byte) error {
+ for len(out) > 0 {
+ read := len(out)
+ if read > readMax {
+ read = readMax
+ }
+ if err := f(out[:read]); err != nil {
+ return err
+ }
+ out = out[read:]
+ }
+ return nil
+ }
}
Index: src/crypto/rand/rand.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/crypto/rand/rand.go b/src/crypto/rand/rand.go
--- a/src/crypto/rand/rand.go (revision 3901409b5d0fb7c85a3e6730a59943cc93b2835c)
+++ b/src/crypto/rand/rand.go (revision 2a406dc9f1ea7323d6ca9fccb2fe9ddebb6b1cc8)
@@ -22,7 +22,7 @@
// - On legacy Linux (< 3.17), Reader opens /dev/urandom on first use.
// - On macOS, iOS, and OpenBSD Reader, uses arc4random_buf(3).
// - On NetBSD, Reader uses the kern.arandom sysctl.
-// - On Windows, Reader uses the ProcessPrng API.
+// - On Windows systems, Reader uses the RtlGenRandom API.
// - On js/wasm, Reader uses the Web Crypto API.
// - On wasip1/wasm, Reader uses random_get.
//
Index: src/internal/syscall/windows/syscall_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/internal/syscall/windows/syscall_windows.go b/src/internal/syscall/windows/syscall_windows.go
--- a/src/internal/syscall/windows/syscall_windows.go (revision 3901409b5d0fb7c85a3e6730a59943cc93b2835c)
+++ b/src/internal/syscall/windows/syscall_windows.go (revision 2a406dc9f1ea7323d6ca9fccb2fe9ddebb6b1cc8)
@@ -416,7 +416,7 @@
//sys DestroyEnvironmentBlock(block *uint16) (err error) = userenv.DestroyEnvironmentBlock
//sys CreateEvent(eventAttrs *SecurityAttributes, manualReset uint32, initialState uint32, name *uint16) (handle syscall.Handle, err error) = kernel32.CreateEventW
-//sys ProcessPrng(buf []byte) (err error) = bcryptprimitives.ProcessPrng
+//sys RtlGenRandom(buf []byte) (err error) = advapi32.SystemFunction036
type FILE_ID_BOTH_DIR_INFO struct {
NextEntryOffset uint32
Index: src/internal/syscall/windows/zsyscall_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/internal/syscall/windows/zsyscall_windows.go b/src/internal/syscall/windows/zsyscall_windows.go
--- a/src/internal/syscall/windows/zsyscall_windows.go (revision 3901409b5d0fb7c85a3e6730a59943cc93b2835c)
+++ b/src/internal/syscall/windows/zsyscall_windows.go (revision 2a406dc9f1ea7323d6ca9fccb2fe9ddebb6b1cc8)
@@ -38,7 +38,6 @@
var (
modadvapi32 = syscall.NewLazyDLL(sysdll.Add("advapi32.dll"))
- modbcryptprimitives = syscall.NewLazyDLL(sysdll.Add("bcryptprimitives.dll"))
modiphlpapi = syscall.NewLazyDLL(sysdll.Add("iphlpapi.dll"))
modkernel32 = syscall.NewLazyDLL(sysdll.Add("kernel32.dll"))
modnetapi32 = syscall.NewLazyDLL(sysdll.Add("netapi32.dll"))
@@ -63,7 +62,7 @@
procQueryServiceStatus = modadvapi32.NewProc("QueryServiceStatus")
procRevertToSelf = modadvapi32.NewProc("RevertToSelf")
procSetTokenInformation = modadvapi32.NewProc("SetTokenInformation")
- procProcessPrng = modbcryptprimitives.NewProc("ProcessPrng")
+ procSystemFunction036 = modadvapi32.NewProc("SystemFunction036")
procGetAdaptersAddresses = modiphlpapi.NewProc("GetAdaptersAddresses")
procCreateEventW = modkernel32.NewProc("CreateEventW")
procGetACP = modkernel32.NewProc("GetACP")
@@ -236,12 +235,12 @@
return
}
-func ProcessPrng(buf []byte) (err error) {
+func RtlGenRandom(buf []byte) (err error) {
var _p0 *byte
if len(buf) > 0 {
_p0 = &buf[0]
}
- r1, _, e1 := syscall.Syscall(procProcessPrng.Addr(), 2, uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), 0)
+ r1, _, e1 := syscall.Syscall(procSystemFunction036.Addr(), 2, uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), 0)
if r1 == 0 {
err = errnoErr(e1)
}
Index: src/runtime/os_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/runtime/os_windows.go b/src/runtime/os_windows.go
--- a/src/runtime/os_windows.go (revision 3901409b5d0fb7c85a3e6730a59943cc93b2835c)
+++ b/src/runtime/os_windows.go (revision ac3e93c061779dfefc0dd13a5b6e6f764a25621e)
@@ -40,8 +40,8 @@
//go:cgo_import_dynamic runtime._GetSystemInfo GetSystemInfo%1 "kernel32.dll"
//go:cgo_import_dynamic runtime._GetThreadContext GetThreadContext%2 "kernel32.dll"
//go:cgo_import_dynamic runtime._SetThreadContext SetThreadContext%2 "kernel32.dll"
-//go:cgo_import_dynamic runtime._LoadLibraryExW LoadLibraryExW%3 "kernel32.dll"
//go:cgo_import_dynamic runtime._LoadLibraryW LoadLibraryW%1 "kernel32.dll"
+//go:cgo_import_dynamic runtime._LoadLibraryA LoadLibraryA%1 "kernel32.dll"
//go:cgo_import_dynamic runtime._PostQueuedCompletionStatus PostQueuedCompletionStatus%4 "kernel32.dll"
//go:cgo_import_dynamic runtime._QueryPerformanceCounter QueryPerformanceCounter%1 "kernel32.dll"
//go:cgo_import_dynamic runtime._QueryPerformanceFrequency QueryPerformanceFrequency%1 "kernel32.dll"
@@ -75,7 +75,6 @@
// Following syscalls are available on every Windows PC.
// All these variables are set by the Windows executable
// loader before the Go program starts.
- _AddVectoredContinueHandler,
_AddVectoredExceptionHandler,
_CloseHandle,
_CreateEventA,
@@ -98,8 +97,8 @@
_GetSystemInfo,
_GetThreadContext,
_SetThreadContext,
- _LoadLibraryExW,
_LoadLibraryW,
+ _LoadLibraryA,
_PostQueuedCompletionStatus,
_QueryPerformanceCounter,
_QueryPerformanceFrequency,
@@ -128,8 +127,23 @@
_WriteFile,
_ stdFunction
- // Use ProcessPrng to generate cryptographically random data.
- _ProcessPrng stdFunction
+ // Following syscalls are only available on some Windows PCs.
+ // We will load syscalls, if available, before using them.
+ _AddDllDirectory,
+ _AddVectoredContinueHandler,
+ _LoadLibraryExA,
+ _LoadLibraryExW,
+ _ stdFunction
+
+ // Use RtlGenRandom to generate cryptographically random data.
+ // This approach has been recommended by Microsoft (see issue
+ // 15589 for details).
+ // The RtlGenRandom is not listed in advapi32.dll, instead
+ // RtlGenRandom function can be found by searching for SystemFunction036.
+ // Also some versions of Mingw cannot link to SystemFunction036
+ // when building executable as Cgo. So load SystemFunction036
+ // manually during runtime startup.
+ _RtlGenRandom stdFunction
// Load ntdll.dll manually during startup, otherwise Mingw
// links wrong printf function to cgo executable (see issue
@@ -146,13 +160,6 @@
_ stdFunction
)
-var (
- bcryptprimitivesdll = [...]uint16{'b', 'c', 'r', 'y', 'p', 't', 'p', 'r', 'i', 'm', 'i', 't', 'i', 'v', 'e', 's', '.', 'd', 'l', 'l', 0}
- ntdlldll = [...]uint16{'n', 't', 'd', 'l', 'l', '.', 'd', 'l', 'l', 0}
- powrprofdll = [...]uint16{'p', 'o', 'w', 'r', 'p', 'r', 'o', 'f', '.', 'd', 'l', 'l', 0}
- winmmdll = [...]uint16{'w', 'i', 'n', 'm', 'm', '.', 'd', 'l', 'l', 0}
-)
-
// Function to be called by windows CreateThread
// to start new os thread.
func tstart_stdcall(newm *m)
@@ -245,8 +252,18 @@
return unsafe.String(&sysDirectory[0], sysDirectoryLen)
}
-func windowsLoadSystemLib(name []uint16) uintptr {
- return stdcall3(_LoadLibraryExW, uintptr(unsafe.Pointer(&name[0])), 0, _LOAD_LIBRARY_SEARCH_SYSTEM32)
+//go:linkname syscall_getSystemDirectory syscall.getSystemDirectory
+func syscall_getSystemDirectory() string {
+ return unsafe.String(&sysDirectory[0], sysDirectoryLen)
+}
+
+func windowsLoadSystemLib(name []byte) uintptr {
+ if useLoadLibraryEx {
+ return stdcall3(_LoadLibraryExA, uintptr(unsafe.Pointer(&name[0])), 0, _LOAD_LIBRARY_SEARCH_SYSTEM32)
+ } else {
+ absName := append(sysDirectory[:sysDirectoryLen], name...)
+ return stdcall1(_LoadLibraryA, uintptr(unsafe.Pointer(&absName[0])))
+ }
}
//go:linkname windows_QueryPerformanceCounter internal/syscall/windows.QueryPerformanceCounter
@@ -264,13 +281,28 @@
}
func loadOptionalSyscalls() {
- bcryptPrimitives := windowsLoadSystemLib(bcryptprimitivesdll[:])
- if bcryptPrimitives == 0 {
- throw("bcryptprimitives.dll not found")
+ var kernel32dll = []byte("kernel32.dll\000")
+ k32 := stdcall1(_LoadLibraryA, uintptr(unsafe.Pointer(&kernel32dll[0])))
+ if k32 == 0 {
+ throw("kernel32.dll not found")
}
- _ProcessPrng = windowsFindfunc(bcryptPrimitives, []byte("ProcessPrng\000"))
+ _AddDllDirectory = windowsFindfunc(k32, []byte("AddDllDirectory\000"))
+ _AddVectoredContinueHandler = windowsFindfunc(k32, []byte("AddVectoredContinueHandler\000"))
+ _LoadLibraryExA = windowsFindfunc(k32, []byte("LoadLibraryExA\000"))
+ _LoadLibraryExW = windowsFindfunc(k32, []byte("LoadLibraryExW\000"))
+ useLoadLibraryEx = (_LoadLibraryExW != nil && _LoadLibraryExA != nil && _AddDllDirectory != nil)
+
+ initSysDirectory()
- n32 := windowsLoadSystemLib(ntdlldll[:])
+ var advapi32dll = []byte("advapi32.dll\000")
+ a32 := windowsLoadSystemLib(advapi32dll)
+ if a32 == 0 {
+ throw("advapi32.dll not found")
+ }
+ _RtlGenRandom = windowsFindfunc(a32, []byte("SystemFunction036\000"))
+
+ var ntdll = []byte("ntdll.dll\000")
+ n32 := windowsLoadSystemLib(ntdll)
if n32 == 0 {
throw("ntdll.dll not found")
}
@@ -299,7 +331,7 @@
context uintptr
}
- powrprof := windowsLoadSystemLib(powrprofdll[:])
+ powrprof := windowsLoadSystemLib([]byte("powrprof.dll\000"))
if powrprof == 0 {
return // Running on Windows 7, where we don't need it anyway.
}
@@ -358,6 +390,22 @@
// in sys_windows_386.s and sys_windows_amd64.s:
func getlasterror() uint32
+// When loading DLLs, we prefer to use LoadLibraryEx with
+// LOAD_LIBRARY_SEARCH_* flags, if available. LoadLibraryEx is not
+// available on old Windows, though, and the LOAD_LIBRARY_SEARCH_*
+// flags are not available on some versions of Windows without a
+// security patch.
+//
+// https://msdn.microsoft.com/en-us/library/ms684179(v=vs.85).aspx says:
+// "Windows 7, Windows Server 2008 R2, Windows Vista, and Windows
+// Server 2008: The LOAD_LIBRARY_SEARCH_* flags are available on
+// systems that have KB2533623 installed. To determine whether the
+// flags are available, use GetProcAddress to get the address of the
+// AddDllDirectory, RemoveDllDirectory, or SetDefaultDllDirectories
+// function. If GetProcAddress succeeds, the LOAD_LIBRARY_SEARCH_*
+// flags can be used with LoadLibraryEx."
+var useLoadLibraryEx bool
+
var timeBeginPeriodRetValue uint32
// osRelaxMinNS indicates that sysmon shouldn't osRelax if the next
@@ -431,7 +479,8 @@
// Only load winmm.dll if we need it.
// This avoids a dependency on winmm.dll for Go programs
// that run on new Windows versions.
- m32 := windowsLoadSystemLib(winmmdll[:])
+ var winmmdll = []byte("winmm.dll\000")
+ m32 := windowsLoadSystemLib(winmmdll)
if m32 == 0 {
print("runtime: LoadLibraryExW failed; errno=", getlasterror(), "\n")
throw("winmm.dll not found")
@@ -472,6 +521,28 @@
canUseLongPaths = true
}
+var osVersionInfo struct {
+ majorVersion uint32
+ minorVersion uint32
+ buildNumber uint32
+}
+
+func initOsVersionInfo() {
+ info := _OSVERSIONINFOW{}
+ info.osVersionInfoSize = uint32(unsafe.Sizeof(info))
+ stdcall1(_RtlGetVersion, uintptr(unsafe.Pointer(&info)))
+ osVersionInfo.majorVersion = info.majorVersion
+ osVersionInfo.minorVersion = info.minorVersion
+ osVersionInfo.buildNumber = info.buildNumber
+}
+
+//go:linkname rtlGetNtVersionNumbers syscall.rtlGetNtVersionNumbers
+func rtlGetNtVersionNumbers(majorVersion *uint32, minorVersion *uint32, buildNumber *uint32) {
+ *majorVersion = osVersionInfo.majorVersion
+ *minorVersion = osVersionInfo.minorVersion
+ *buildNumber = osVersionInfo.buildNumber
+}
+
func osinit() {
asmstdcallAddr = unsafe.Pointer(abi.FuncPCABI0(asmstdcall))
@@ -484,8 +555,8 @@
initHighResTimer()
timeBeginPeriodRetValue = osRelax(false)
- initSysDirectory()
initLongPathSupport()
+ initOsVersionInfo()
ncpu = getproccount()
@@ -501,7 +572,7 @@
//go:nosplit
func readRandom(r []byte) int {
n := 0
- if stdcall2(_ProcessPrng, uintptr(unsafe.Pointer(&r[0])), uintptr(len(r)))&0xff != 0 {
+ if stdcall2(_RtlGenRandom, uintptr(unsafe.Pointer(&r[0])), uintptr(len(r)))&0xff != 0 {
n = len(r)
}
return n
Index: src/net/hook_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/hook_windows.go b/src/net/hook_windows.go
--- a/src/net/hook_windows.go (revision 2a406dc9f1ea7323d6ca9fccb2fe9ddebb6b1cc8)
+++ b/src/net/hook_windows.go (revision 7b1fd7d39c6be0185fbe1d929578ab372ac5c632)
@@ -13,6 +13,7 @@
hostsFilePath = windows.GetSystemDirectory() + "/Drivers/etc/hosts"
// Placeholders for socket system calls.
+ socketFunc func(int, int, int) (syscall.Handle, error) = syscall.Socket
wsaSocketFunc func(int32, int32, int32, *syscall.WSAProtocolInfo, uint32, uint32) (syscall.Handle, error) = windows.WSASocket
connectFunc func(syscall.Handle, syscall.Sockaddr) error = syscall.Connect
listenFunc func(syscall.Handle, int) error = syscall.Listen
Index: src/net/internal/socktest/main_test.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/internal/socktest/main_test.go b/src/net/internal/socktest/main_test.go
--- a/src/net/internal/socktest/main_test.go (revision 2a406dc9f1ea7323d6ca9fccb2fe9ddebb6b1cc8)
+++ b/src/net/internal/socktest/main_test.go (revision 7b1fd7d39c6be0185fbe1d929578ab372ac5c632)
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !plan9 && !wasip1 && !windows
+//go:build !js && !plan9 && !wasip1
package socktest_test
Index: src/net/internal/socktest/main_windows_test.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/internal/socktest/main_windows_test.go b/src/net/internal/socktest/main_windows_test.go
new file mode 100644
--- /dev/null (revision 7b1fd7d39c6be0185fbe1d929578ab372ac5c632)
+++ b/src/net/internal/socktest/main_windows_test.go (revision 7b1fd7d39c6be0185fbe1d929578ab372ac5c632)
@@ -0,0 +1,22 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package socktest_test
+
+import "syscall"
+
+var (
+ socketFunc func(int, int, int) (syscall.Handle, error)
+ closeFunc func(syscall.Handle) error
+)
+
+func installTestHooks() {
+ socketFunc = sw.Socket
+ closeFunc = sw.Closesocket
+}
+
+func uninstallTestHooks() {
+ socketFunc = syscall.Socket
+ closeFunc = syscall.Closesocket
+}
Index: src/net/internal/socktest/sys_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/internal/socktest/sys_windows.go b/src/net/internal/socktest/sys_windows.go
--- a/src/net/internal/socktest/sys_windows.go (revision 2a406dc9f1ea7323d6ca9fccb2fe9ddebb6b1cc8)
+++ b/src/net/internal/socktest/sys_windows.go (revision 7b1fd7d39c6be0185fbe1d929578ab372ac5c632)
@@ -9,6 +9,38 @@
"syscall"
)
+// Socket wraps [syscall.Socket].
+func (sw *Switch) Socket(family, sotype, proto int) (s syscall.Handle, err error) {
+ sw.once.Do(sw.init)
+
+ so := &Status{Cookie: cookie(family, sotype, proto)}
+ sw.fmu.RLock()
+ f, _ := sw.fltab[FilterSocket]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return syscall.InvalidHandle, err
+ }
+ s, so.Err = syscall.Socket(family, sotype, proto)
+ if err = af.apply(so); err != nil {
+ if so.Err == nil {
+ syscall.Closesocket(s)
+ }
+ return syscall.InvalidHandle, err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).OpenFailed++
+ return syscall.InvalidHandle, so.Err
+ }
+ nso := sw.addLocked(s, family, sotype, proto)
+ sw.stats.getLocked(nso.Cookie).Opened++
+ return s, nil
+}
+
// WSASocket wraps [syscall.WSASocket].
func (sw *Switch) WSASocket(family, sotype, proto int32, protinfo *syscall.WSAProtocolInfo, group uint32, flags uint32) (s syscall.Handle, err error) {
sw.once.Do(sw.init)
Index: src/net/main_windows_test.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/main_windows_test.go b/src/net/main_windows_test.go
--- a/src/net/main_windows_test.go (revision 2a406dc9f1ea7323d6ca9fccb2fe9ddebb6b1cc8)
+++ b/src/net/main_windows_test.go (revision 7b1fd7d39c6be0185fbe1d929578ab372ac5c632)
@@ -8,6 +8,7 @@
var (
// Placeholders for saving original socket system calls.
+ origSocket = socketFunc
origWSASocket = wsaSocketFunc
origClosesocket = poll.CloseFunc
origConnect = connectFunc
@@ -17,6 +18,7 @@
)
func installTestHooks() {
+ socketFunc = sw.Socket
wsaSocketFunc = sw.WSASocket
poll.CloseFunc = sw.Closesocket
connectFunc = sw.Connect
@@ -26,6 +28,7 @@
}
func uninstallTestHooks() {
+ socketFunc = origSocket
wsaSocketFunc = origWSASocket
poll.CloseFunc = origClosesocket
connectFunc = origConnect
Index: src/net/sock_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/sock_windows.go b/src/net/sock_windows.go
--- a/src/net/sock_windows.go (revision 2a406dc9f1ea7323d6ca9fccb2fe9ddebb6b1cc8)
+++ b/src/net/sock_windows.go (revision 7b1fd7d39c6be0185fbe1d929578ab372ac5c632)
@@ -20,6 +20,21 @@
func sysSocket(family, sotype, proto int) (syscall.Handle, error) {
s, err := wsaSocketFunc(int32(family), int32(sotype), int32(proto),
nil, 0, windows.WSA_FLAG_OVERLAPPED|windows.WSA_FLAG_NO_HANDLE_INHERIT)
+ if err == nil {
+ return s, nil
+ }
+ // WSA_FLAG_NO_HANDLE_INHERIT flag is not supported on some
+ // old versions of Windows, see
+ // https://msdn.microsoft.com/en-us/library/windows/desktop/ms742212(v=vs.85).aspx
+ // for details. Just use syscall.Socket, if windows.WSASocket failed.
+
+ // See ../syscall/exec_unix.go for description of ForkLock.
+ syscall.ForkLock.RLock()
+ s, err = socketFunc(family, sotype, proto)
+ if err == nil {
+ syscall.CloseOnExec(s)
+ }
+ syscall.ForkLock.RUnlock()
if err != nil {
return syscall.InvalidHandle, os.NewSyscallError("socket", err)
}
Index: src/syscall/exec_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/syscall/exec_windows.go b/src/syscall/exec_windows.go
--- a/src/syscall/exec_windows.go (revision 2a406dc9f1ea7323d6ca9fccb2fe9ddebb6b1cc8)
+++ b/src/syscall/exec_windows.go (revision 979d6d8bab3823ff572ace26767fd2ce3cf351ae)
@@ -14,7 +14,6 @@
"unsafe"
)
-// ForkLock is not used on Windows.
var ForkLock sync.RWMutex
// EscapeArg rewrites command line argument s as prescribed
@@ -254,6 +253,9 @@
var zeroProcAttr ProcAttr
var zeroSysProcAttr SysProcAttr
+//go:linkname rtlGetNtVersionNumbers
+func rtlGetNtVersionNumbers(majorVersion *uint32, minorVersion *uint32, buildNumber *uint32)
+
func StartProcess(argv0 string, argv []string, attr *ProcAttr) (pid int, handle uintptr, err error) {
if len(argv0) == 0 {
return 0, 0, EWINDOWS
@@ -317,6 +319,17 @@
}
}
+ var maj, min, build uint32
+ rtlGetNtVersionNumbers(&maj, &min, &build)
+ isWin7 := maj < 6 || (maj == 6 && min <= 1)
+ // NT kernel handles are divisible by 4, with the bottom 3 bits left as
+ // a tag. The fully set tag correlates with the types of handles we're
+ // concerned about here. Except, the kernel will interpret some
+ // special handle values, like -1, -2, and so forth, so kernelbase.dll
+ // checks to see that those bottom three bits are checked, but that top
+ // bit is not checked.
+ isLegacyWin7ConsoleHandle := func(handle Handle) bool { return isWin7 && handle&0x10000003 == 3 }
+
p, _ := GetCurrentProcess()
parentProcess := p
if sys.ParentProcess != 0 {
@@ -325,7 +338,15 @@
fd := make([]Handle, len(attr.Files))
for i := range attr.Files {
if attr.Files[i] > 0 {
- err := DuplicateHandle(p, Handle(attr.Files[i]), parentProcess, &fd[i], 0, true, DUPLICATE_SAME_ACCESS)
+ destinationProcessHandle := parentProcess
+
+ // On Windows 7, console handles aren't real handles, and can only be duplicated
+ // into the current process, not a parent one, which amounts to the same thing.
+ if parentProcess != p && isLegacyWin7ConsoleHandle(Handle(attr.Files[i])) {
+ destinationProcessHandle = p
+ }
+
+ err := DuplicateHandle(p, Handle(attr.Files[i]), destinationProcessHandle, &fd[i], 0, true, DUPLICATE_SAME_ACCESS)
if err != nil {
return 0, 0, err
}
@@ -356,6 +377,14 @@
fd = append(fd, sys.AdditionalInheritedHandles...)
+ // On Windows 7, console handles aren't real handles, so don't pass them
+ // through to PROC_THREAD_ATTRIBUTE_HANDLE_LIST.
+ for i := range fd {
+ if isLegacyWin7ConsoleHandle(fd[i]) {
+ fd[i] = 0
+ }
+ }
+
// The presence of a NULL handle in the list is enough to cause PROC_THREAD_ATTRIBUTE_HANDLE_LIST
// to treat the entire list as empty, so remove NULL handles.
j := 0
Index: src/runtime/syscall_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/runtime/syscall_windows.go b/src/runtime/syscall_windows.go
--- a/src/runtime/syscall_windows.go (revision 979d6d8bab3823ff572ace26767fd2ce3cf351ae)
+++ b/src/runtime/syscall_windows.go (revision ac3e93c061779dfefc0dd13a5b6e6f764a25621e)
@@ -413,10 +413,20 @@
const _LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
+// When available, this function will use LoadLibraryEx with the filename
+// parameter and the important SEARCH_SYSTEM32 argument. But on systems that
+// do not have that option, absoluteFilepath should contain a fallback
+// to the full path inside of system32 for use with vanilla LoadLibrary.
+//
//go:linkname syscall_loadsystemlibrary syscall.loadsystemlibrary
-func syscall_loadsystemlibrary(filename *uint16) (handle, err uintptr) {
- handle, _, err = syscall_SyscallN(uintptr(unsafe.Pointer(_LoadLibraryExW)), uintptr(unsafe.Pointer(filename)), 0, _LOAD_LIBRARY_SEARCH_SYSTEM32)
+func syscall_loadsystemlibrary(filename *uint16, absoluteFilepath *uint16) (handle, err uintptr) {
+ if useLoadLibraryEx {
+ handle, _, err = syscall_SyscallN(uintptr(unsafe.Pointer(_LoadLibraryExW)), uintptr(unsafe.Pointer(filename)), 0, _LOAD_LIBRARY_SEARCH_SYSTEM32)
+ } else {
+ handle, _, err = syscall_SyscallN(uintptr(unsafe.Pointer(_LoadLibraryW)), uintptr(unsafe.Pointer(absoluteFilepath)))
+ }
KeepAlive(filename)
+ KeepAlive(absoluteFilepath)
if handle != 0 {
err = 0
}
Index: src/syscall/dll_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/syscall/dll_windows.go b/src/syscall/dll_windows.go
--- a/src/syscall/dll_windows.go (revision 979d6d8bab3823ff572ace26767fd2ce3cf351ae)
+++ b/src/syscall/dll_windows.go (revision ac3e93c061779dfefc0dd13a5b6e6f764a25621e)
@@ -45,7 +45,7 @@
//go:noescape
func SyscallN(trap uintptr, args ...uintptr) (r1, r2 uintptr, err Errno)
func loadlibrary(filename *uint16) (handle uintptr, err Errno)
-func loadsystemlibrary(filename *uint16) (handle uintptr, err Errno)
+func loadsystemlibrary(filename *uint16, absoluteFilepath *uint16) (handle uintptr, err Errno)
func getprocaddress(handle uintptr, procname *uint8) (proc uintptr, err Errno)
// A DLL implements access to a single DLL.
@@ -54,6 +54,9 @@
Handle Handle
}
+//go:linkname getSystemDirectory
+func getSystemDirectory() string // Implemented in runtime package.
+
// LoadDLL loads the named DLL file into memory.
//
// If name is not an absolute path and is not a known system DLL used by
@@ -70,7 +73,11 @@
var h uintptr
var e Errno
if sysdll.IsSystemDLL[name] {
- h, e = loadsystemlibrary(namep)
+ absoluteFilepathp, err := UTF16PtrFromString(getSystemDirectory() + name)
+ if err != nil {
+ return nil, err
+ }
+ h, e = loadsystemlibrary(namep, absoluteFilepathp)
} else {
h, e = loadlibrary(namep)
}

View File

@@ -1,884 +0,0 @@
Subject: [PATCH] Revert "os: remove 5ms sleep on Windows in (*Process).Wait"
Fix os.RemoveAll not working on Windows7
Revert "runtime: always use LoadLibraryEx to load system libraries"
Revert "syscall: remove Windows 7 console handle workaround"
Revert "net: remove sysSocket fallback for Windows 7"
Revert "crypto/rand,runtime: switch RtlGenRandom for ProcessPrng"
---
Index: src/crypto/internal/sysrand/rand_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/crypto/internal/sysrand/rand_windows.go b/src/crypto/internal/sysrand/rand_windows.go
--- a/src/crypto/internal/sysrand/rand_windows.go (revision 6e676ab2b809d46623acb5988248d95d1eb7939c)
+++ b/src/crypto/internal/sysrand/rand_windows.go (revision 8cb5472d94c34b88733a81091bd328e70ee565a4)
@@ -7,5 +7,26 @@
import "internal/syscall/windows"
func read(b []byte) error {
- return windows.ProcessPrng(b)
+ // RtlGenRandom only returns 1<<32-1 bytes at a time. We only read at
+ // most 1<<31-1 bytes at a time so that this works the same on 32-bit
+ // and 64-bit systems.
+ return batched(windows.RtlGenRandom, 1<<31-1)(b)
+}
+
+// batched returns a function that calls f to populate a []byte by chunking it
+// into subslices of, at most, readMax bytes.
+func batched(f func([]byte) error, readMax int) func([]byte) error {
+ return func(out []byte) error {
+ for len(out) > 0 {
+ read := len(out)
+ if read > readMax {
+ read = readMax
+ }
+ if err := f(out[:read]); err != nil {
+ return err
+ }
+ out = out[read:]
+ }
+ return nil
+ }
}
Index: src/crypto/rand/rand.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/crypto/rand/rand.go b/src/crypto/rand/rand.go
--- a/src/crypto/rand/rand.go (revision 6e676ab2b809d46623acb5988248d95d1eb7939c)
+++ b/src/crypto/rand/rand.go (revision 8cb5472d94c34b88733a81091bd328e70ee565a4)
@@ -22,7 +22,7 @@
// - On legacy Linux (< 3.17), Reader opens /dev/urandom on first use.
// - On macOS, iOS, and OpenBSD Reader, uses arc4random_buf(3).
// - On NetBSD, Reader uses the kern.arandom sysctl.
-// - On Windows, Reader uses the ProcessPrng API.
+// - On Windows systems, Reader uses the RtlGenRandom API.
// - On js/wasm, Reader uses the Web Crypto API.
// - On wasip1/wasm, Reader uses random_get.
//
Index: src/internal/syscall/windows/syscall_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/internal/syscall/windows/syscall_windows.go b/src/internal/syscall/windows/syscall_windows.go
--- a/src/internal/syscall/windows/syscall_windows.go (revision 6e676ab2b809d46623acb5988248d95d1eb7939c)
+++ b/src/internal/syscall/windows/syscall_windows.go (revision 8cb5472d94c34b88733a81091bd328e70ee565a4)
@@ -419,7 +419,7 @@
//sys DestroyEnvironmentBlock(block *uint16) (err error) = userenv.DestroyEnvironmentBlock
//sys CreateEvent(eventAttrs *SecurityAttributes, manualReset uint32, initialState uint32, name *uint16) (handle syscall.Handle, err error) = kernel32.CreateEventW
-//sys ProcessPrng(buf []byte) (err error) = bcryptprimitives.ProcessPrng
+//sys RtlGenRandom(buf []byte) (err error) = advapi32.SystemFunction036
type FILE_ID_BOTH_DIR_INFO struct {
NextEntryOffset uint32
Index: src/internal/syscall/windows/zsyscall_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/internal/syscall/windows/zsyscall_windows.go b/src/internal/syscall/windows/zsyscall_windows.go
--- a/src/internal/syscall/windows/zsyscall_windows.go (revision 6e676ab2b809d46623acb5988248d95d1eb7939c)
+++ b/src/internal/syscall/windows/zsyscall_windows.go (revision 8cb5472d94c34b88733a81091bd328e70ee565a4)
@@ -38,7 +38,6 @@
var (
modadvapi32 = syscall.NewLazyDLL(sysdll.Add("advapi32.dll"))
- modbcryptprimitives = syscall.NewLazyDLL(sysdll.Add("bcryptprimitives.dll"))
modiphlpapi = syscall.NewLazyDLL(sysdll.Add("iphlpapi.dll"))
modkernel32 = syscall.NewLazyDLL(sysdll.Add("kernel32.dll"))
modnetapi32 = syscall.NewLazyDLL(sysdll.Add("netapi32.dll"))
@@ -63,7 +62,7 @@
procQueryServiceStatus = modadvapi32.NewProc("QueryServiceStatus")
procRevertToSelf = modadvapi32.NewProc("RevertToSelf")
procSetTokenInformation = modadvapi32.NewProc("SetTokenInformation")
- procProcessPrng = modbcryptprimitives.NewProc("ProcessPrng")
+ procSystemFunction036 = modadvapi32.NewProc("SystemFunction036")
procGetAdaptersAddresses = modiphlpapi.NewProc("GetAdaptersAddresses")
procCreateEventW = modkernel32.NewProc("CreateEventW")
procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
@@ -242,12 +241,12 @@
return
}
-func ProcessPrng(buf []byte) (err error) {
+func RtlGenRandom(buf []byte) (err error) {
var _p0 *byte
if len(buf) > 0 {
_p0 = &buf[0]
}
- r1, _, e1 := syscall.Syscall(procProcessPrng.Addr(), 2, uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), 0)
+ r1, _, e1 := syscall.Syscall(procSystemFunction036.Addr(), 2, uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), 0)
if r1 == 0 {
err = errnoErr(e1)
}
Index: src/runtime/os_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/runtime/os_windows.go b/src/runtime/os_windows.go
--- a/src/runtime/os_windows.go (revision 6e676ab2b809d46623acb5988248d95d1eb7939c)
+++ b/src/runtime/os_windows.go (revision f56f1e23507e646c85243a71bde7b9629b2f970c)
@@ -39,8 +39,8 @@
//go:cgo_import_dynamic runtime._GetSystemInfo GetSystemInfo%1 "kernel32.dll"
//go:cgo_import_dynamic runtime._GetThreadContext GetThreadContext%2 "kernel32.dll"
//go:cgo_import_dynamic runtime._SetThreadContext SetThreadContext%2 "kernel32.dll"
-//go:cgo_import_dynamic runtime._LoadLibraryExW LoadLibraryExW%3 "kernel32.dll"
//go:cgo_import_dynamic runtime._LoadLibraryW LoadLibraryW%1 "kernel32.dll"
+//go:cgo_import_dynamic runtime._LoadLibraryA LoadLibraryA%1 "kernel32.dll"
//go:cgo_import_dynamic runtime._PostQueuedCompletionStatus PostQueuedCompletionStatus%4 "kernel32.dll"
//go:cgo_import_dynamic runtime._QueryPerformanceCounter QueryPerformanceCounter%1 "kernel32.dll"
//go:cgo_import_dynamic runtime._QueryPerformanceFrequency QueryPerformanceFrequency%1 "kernel32.dll"
@@ -74,7 +74,6 @@
// Following syscalls are available on every Windows PC.
// All these variables are set by the Windows executable
// loader before the Go program starts.
- _AddVectoredContinueHandler,
_AddVectoredExceptionHandler,
_CloseHandle,
_CreateEventA,
@@ -97,8 +96,8 @@
_GetSystemInfo,
_GetThreadContext,
_SetThreadContext,
- _LoadLibraryExW,
_LoadLibraryW,
+ _LoadLibraryA,
_PostQueuedCompletionStatus,
_QueryPerformanceCounter,
_QueryPerformanceFrequency,
@@ -127,8 +126,23 @@
_WriteFile,
_ stdFunction
- // Use ProcessPrng to generate cryptographically random data.
- _ProcessPrng stdFunction
+ // Following syscalls are only available on some Windows PCs.
+ // We will load syscalls, if available, before using them.
+ _AddDllDirectory,
+ _AddVectoredContinueHandler,
+ _LoadLibraryExA,
+ _LoadLibraryExW,
+ _ stdFunction
+
+ // Use RtlGenRandom to generate cryptographically random data.
+ // This approach has been recommended by Microsoft (see issue
+ // 15589 for details).
+ // The RtlGenRandom is not listed in advapi32.dll, instead
+ // RtlGenRandom function can be found by searching for SystemFunction036.
+ // Also some versions of Mingw cannot link to SystemFunction036
+ // when building executable as Cgo. So load SystemFunction036
+ // manually during runtime startup.
+ _RtlGenRandom stdFunction
// Load ntdll.dll manually during startup, otherwise Mingw
// links wrong printf function to cgo executable (see issue
@@ -145,13 +159,6 @@
_ stdFunction
)
-var (
- bcryptprimitivesdll = [...]uint16{'b', 'c', 'r', 'y', 'p', 't', 'p', 'r', 'i', 'm', 'i', 't', 'i', 'v', 'e', 's', '.', 'd', 'l', 'l', 0}
- ntdlldll = [...]uint16{'n', 't', 'd', 'l', 'l', '.', 'd', 'l', 'l', 0}
- powrprofdll = [...]uint16{'p', 'o', 'w', 'r', 'p', 'r', 'o', 'f', '.', 'd', 'l', 'l', 0}
- winmmdll = [...]uint16{'w', 'i', 'n', 'm', 'm', '.', 'd', 'l', 'l', 0}
-)
-
// Function to be called by windows CreateThread
// to start new os thread.
func tstart_stdcall(newm *m)
@@ -244,8 +251,18 @@
return unsafe.String(&sysDirectory[0], sysDirectoryLen)
}
-func windowsLoadSystemLib(name []uint16) uintptr {
- return stdcall3(_LoadLibraryExW, uintptr(unsafe.Pointer(&name[0])), 0, _LOAD_LIBRARY_SEARCH_SYSTEM32)
+//go:linkname syscall_getSystemDirectory syscall.getSystemDirectory
+func syscall_getSystemDirectory() string {
+ return unsafe.String(&sysDirectory[0], sysDirectoryLen)
+}
+
+func windowsLoadSystemLib(name []byte) uintptr {
+ if useLoadLibraryEx {
+ return stdcall3(_LoadLibraryExA, uintptr(unsafe.Pointer(&name[0])), 0, _LOAD_LIBRARY_SEARCH_SYSTEM32)
+ } else {
+ absName := append(sysDirectory[:sysDirectoryLen], name...)
+ return stdcall1(_LoadLibraryA, uintptr(unsafe.Pointer(&absName[0])))
+ }
}
//go:linkname windows_QueryPerformanceCounter internal/syscall/windows.QueryPerformanceCounter
@@ -263,13 +280,28 @@
}
func loadOptionalSyscalls() {
- bcryptPrimitives := windowsLoadSystemLib(bcryptprimitivesdll[:])
- if bcryptPrimitives == 0 {
- throw("bcryptprimitives.dll not found")
+ var kernel32dll = []byte("kernel32.dll\000")
+ k32 := stdcall1(_LoadLibraryA, uintptr(unsafe.Pointer(&kernel32dll[0])))
+ if k32 == 0 {
+ throw("kernel32.dll not found")
}
- _ProcessPrng = windowsFindfunc(bcryptPrimitives, []byte("ProcessPrng\000"))
+ _AddDllDirectory = windowsFindfunc(k32, []byte("AddDllDirectory\000"))
+ _AddVectoredContinueHandler = windowsFindfunc(k32, []byte("AddVectoredContinueHandler\000"))
+ _LoadLibraryExA = windowsFindfunc(k32, []byte("LoadLibraryExA\000"))
+ _LoadLibraryExW = windowsFindfunc(k32, []byte("LoadLibraryExW\000"))
+ useLoadLibraryEx = (_LoadLibraryExW != nil && _LoadLibraryExA != nil && _AddDllDirectory != nil)
+
+ initSysDirectory()
- n32 := windowsLoadSystemLib(ntdlldll[:])
+ var advapi32dll = []byte("advapi32.dll\000")
+ a32 := windowsLoadSystemLib(advapi32dll)
+ if a32 == 0 {
+ throw("advapi32.dll not found")
+ }
+ _RtlGenRandom = windowsFindfunc(a32, []byte("SystemFunction036\000"))
+
+ var ntdll = []byte("ntdll.dll\000")
+ n32 := windowsLoadSystemLib(ntdll)
if n32 == 0 {
throw("ntdll.dll not found")
}
@@ -298,7 +330,7 @@
context uintptr
}
- powrprof := windowsLoadSystemLib(powrprofdll[:])
+ powrprof := windowsLoadSystemLib([]byte("powrprof.dll\000"))
if powrprof == 0 {
return // Running on Windows 7, where we don't need it anyway.
}
@@ -357,6 +389,22 @@
// in sys_windows_386.s and sys_windows_amd64.s:
func getlasterror() uint32
+// When loading DLLs, we prefer to use LoadLibraryEx with
+// LOAD_LIBRARY_SEARCH_* flags, if available. LoadLibraryEx is not
+// available on old Windows, though, and the LOAD_LIBRARY_SEARCH_*
+// flags are not available on some versions of Windows without a
+// security patch.
+//
+// https://msdn.microsoft.com/en-us/library/ms684179(v=vs.85).aspx says:
+// "Windows 7, Windows Server 2008 R2, Windows Vista, and Windows
+// Server 2008: The LOAD_LIBRARY_SEARCH_* flags are available on
+// systems that have KB2533623 installed. To determine whether the
+// flags are available, use GetProcAddress to get the address of the
+// AddDllDirectory, RemoveDllDirectory, or SetDefaultDllDirectories
+// function. If GetProcAddress succeeds, the LOAD_LIBRARY_SEARCH_*
+// flags can be used with LoadLibraryEx."
+var useLoadLibraryEx bool
+
var timeBeginPeriodRetValue uint32
// osRelaxMinNS indicates that sysmon shouldn't osRelax if the next
@@ -430,7 +478,8 @@
// Only load winmm.dll if we need it.
// This avoids a dependency on winmm.dll for Go programs
// that run on new Windows versions.
- m32 := windowsLoadSystemLib(winmmdll[:])
+ var winmmdll = []byte("winmm.dll\000")
+ m32 := windowsLoadSystemLib(winmmdll)
if m32 == 0 {
print("runtime: LoadLibraryExW failed; errno=", getlasterror(), "\n")
throw("winmm.dll not found")
@@ -471,6 +520,28 @@
canUseLongPaths = true
}
+var osVersionInfo struct {
+ majorVersion uint32
+ minorVersion uint32
+ buildNumber uint32
+}
+
+func initOsVersionInfo() {
+ info := _OSVERSIONINFOW{}
+ info.osVersionInfoSize = uint32(unsafe.Sizeof(info))
+ stdcall1(_RtlGetVersion, uintptr(unsafe.Pointer(&info)))
+ osVersionInfo.majorVersion = info.majorVersion
+ osVersionInfo.minorVersion = info.minorVersion
+ osVersionInfo.buildNumber = info.buildNumber
+}
+
+//go:linkname rtlGetNtVersionNumbers syscall.rtlGetNtVersionNumbers
+func rtlGetNtVersionNumbers(majorVersion *uint32, minorVersion *uint32, buildNumber *uint32) {
+ *majorVersion = osVersionInfo.majorVersion
+ *minorVersion = osVersionInfo.minorVersion
+ *buildNumber = osVersionInfo.buildNumber
+}
+
func osinit() {
asmstdcallAddr = unsafe.Pointer(abi.FuncPCABI0(asmstdcall))
@@ -483,8 +554,8 @@
initHighResTimer()
timeBeginPeriodRetValue = osRelax(false)
- initSysDirectory()
initLongPathSupport()
+ initOsVersionInfo()
numCPUStartup = getCPUCount()
@@ -500,7 +571,7 @@
//go:nosplit
func readRandom(r []byte) int {
n := 0
- if stdcall2(_ProcessPrng, uintptr(unsafe.Pointer(&r[0])), uintptr(len(r)))&0xff != 0 {
+ if stdcall2(_RtlGenRandom, uintptr(unsafe.Pointer(&r[0])), uintptr(len(r)))&0xff != 0 {
n = len(r)
}
return n
Index: src/net/hook_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/hook_windows.go b/src/net/hook_windows.go
--- a/src/net/hook_windows.go (revision 8cb5472d94c34b88733a81091bd328e70ee565a4)
+++ b/src/net/hook_windows.go (revision 6788c4c6f9fafb56729bad6b660f7ee2272d699f)
@@ -13,6 +13,7 @@
hostsFilePath = windows.GetSystemDirectory() + "/Drivers/etc/hosts"
// Placeholders for socket system calls.
+ socketFunc func(int, int, int) (syscall.Handle, error) = syscall.Socket
wsaSocketFunc func(int32, int32, int32, *syscall.WSAProtocolInfo, uint32, uint32) (syscall.Handle, error) = windows.WSASocket
connectFunc func(syscall.Handle, syscall.Sockaddr) error = syscall.Connect
listenFunc func(syscall.Handle, int) error = syscall.Listen
Index: src/net/internal/socktest/main_test.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/internal/socktest/main_test.go b/src/net/internal/socktest/main_test.go
--- a/src/net/internal/socktest/main_test.go (revision 8cb5472d94c34b88733a81091bd328e70ee565a4)
+++ b/src/net/internal/socktest/main_test.go (revision 6788c4c6f9fafb56729bad6b660f7ee2272d699f)
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !plan9 && !wasip1 && !windows
+//go:build !js && !plan9 && !wasip1
package socktest_test
Index: src/net/internal/socktest/main_windows_test.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/internal/socktest/main_windows_test.go b/src/net/internal/socktest/main_windows_test.go
new file mode 100644
--- /dev/null (revision 6788c4c6f9fafb56729bad6b660f7ee2272d699f)
+++ b/src/net/internal/socktest/main_windows_test.go (revision 6788c4c6f9fafb56729bad6b660f7ee2272d699f)
@@ -0,0 +1,22 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package socktest_test
+
+import "syscall"
+
+var (
+ socketFunc func(int, int, int) (syscall.Handle, error)
+ closeFunc func(syscall.Handle) error
+)
+
+func installTestHooks() {
+ socketFunc = sw.Socket
+ closeFunc = sw.Closesocket
+}
+
+func uninstallTestHooks() {
+ socketFunc = syscall.Socket
+ closeFunc = syscall.Closesocket
+}
Index: src/net/internal/socktest/sys_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/internal/socktest/sys_windows.go b/src/net/internal/socktest/sys_windows.go
--- a/src/net/internal/socktest/sys_windows.go (revision 8cb5472d94c34b88733a81091bd328e70ee565a4)
+++ b/src/net/internal/socktest/sys_windows.go (revision 6788c4c6f9fafb56729bad6b660f7ee2272d699f)
@@ -9,6 +9,38 @@
"syscall"
)
+// Socket wraps [syscall.Socket].
+func (sw *Switch) Socket(family, sotype, proto int) (s syscall.Handle, err error) {
+ sw.once.Do(sw.init)
+
+ so := &Status{Cookie: cookie(family, sotype, proto)}
+ sw.fmu.RLock()
+ f, _ := sw.fltab[FilterSocket]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return syscall.InvalidHandle, err
+ }
+ s, so.Err = syscall.Socket(family, sotype, proto)
+ if err = af.apply(so); err != nil {
+ if so.Err == nil {
+ syscall.Closesocket(s)
+ }
+ return syscall.InvalidHandle, err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).OpenFailed++
+ return syscall.InvalidHandle, so.Err
+ }
+ nso := sw.addLocked(s, family, sotype, proto)
+ sw.stats.getLocked(nso.Cookie).Opened++
+ return s, nil
+}
+
// WSASocket wraps [syscall.WSASocket].
func (sw *Switch) WSASocket(family, sotype, proto int32, protinfo *syscall.WSAProtocolInfo, group uint32, flags uint32) (s syscall.Handle, err error) {
sw.once.Do(sw.init)
Index: src/net/main_windows_test.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/main_windows_test.go b/src/net/main_windows_test.go
--- a/src/net/main_windows_test.go (revision 8cb5472d94c34b88733a81091bd328e70ee565a4)
+++ b/src/net/main_windows_test.go (revision 6788c4c6f9fafb56729bad6b660f7ee2272d699f)
@@ -12,6 +12,7 @@
var (
// Placeholders for saving original socket system calls.
+ origSocket = socketFunc
origWSASocket = wsaSocketFunc
origClosesocket = poll.CloseFunc
origConnect = connectFunc
@@ -21,6 +22,7 @@
)
func installTestHooks() {
+ socketFunc = sw.Socket
wsaSocketFunc = sw.WSASocket
poll.CloseFunc = sw.Closesocket
connectFunc = sw.Connect
@@ -30,6 +32,7 @@
}
func uninstallTestHooks() {
+ socketFunc = origSocket
wsaSocketFunc = origWSASocket
poll.CloseFunc = origClosesocket
connectFunc = origConnect
Index: src/net/sock_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/sock_windows.go b/src/net/sock_windows.go
--- a/src/net/sock_windows.go (revision 8cb5472d94c34b88733a81091bd328e70ee565a4)
+++ b/src/net/sock_windows.go (revision 6788c4c6f9fafb56729bad6b660f7ee2272d699f)
@@ -20,6 +20,21 @@
func sysSocket(family, sotype, proto int) (syscall.Handle, error) {
s, err := wsaSocketFunc(int32(family), int32(sotype), int32(proto),
nil, 0, windows.WSA_FLAG_OVERLAPPED|windows.WSA_FLAG_NO_HANDLE_INHERIT)
+ if err == nil {
+ return s, nil
+ }
+ // WSA_FLAG_NO_HANDLE_INHERIT flag is not supported on some
+ // old versions of Windows, see
+ // https://msdn.microsoft.com/en-us/library/windows/desktop/ms742212(v=vs.85).aspx
+ // for details. Just use syscall.Socket, if windows.WSASocket failed.
+
+ // See ../syscall/exec_unix.go for description of ForkLock.
+ syscall.ForkLock.RLock()
+ s, err = socketFunc(family, sotype, proto)
+ if err == nil {
+ syscall.CloseOnExec(s)
+ }
+ syscall.ForkLock.RUnlock()
if err != nil {
return syscall.InvalidHandle, os.NewSyscallError("socket", err)
}
Index: src/syscall/exec_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/syscall/exec_windows.go b/src/syscall/exec_windows.go
--- a/src/syscall/exec_windows.go (revision 8cb5472d94c34b88733a81091bd328e70ee565a4)
+++ b/src/syscall/exec_windows.go (revision a5b2168bb836ed9d6601c626f95e56c07923f906)
@@ -14,7 +14,6 @@
"unsafe"
)
-// ForkLock is not used on Windows.
var ForkLock sync.RWMutex
// EscapeArg rewrites command line argument s as prescribed
@@ -254,6 +253,9 @@
var zeroProcAttr ProcAttr
var zeroSysProcAttr SysProcAttr
+//go:linkname rtlGetNtVersionNumbers
+func rtlGetNtVersionNumbers(majorVersion *uint32, minorVersion *uint32, buildNumber *uint32)
+
func StartProcess(argv0 string, argv []string, attr *ProcAttr) (pid int, handle uintptr, err error) {
if len(argv0) == 0 {
return 0, 0, EWINDOWS
@@ -317,6 +319,17 @@
}
}
+ var maj, min, build uint32
+ rtlGetNtVersionNumbers(&maj, &min, &build)
+ isWin7 := maj < 6 || (maj == 6 && min <= 1)
+ // NT kernel handles are divisible by 4, with the bottom 3 bits left as
+ // a tag. The fully set tag correlates with the types of handles we're
+ // concerned about here. Except, the kernel will interpret some
+ // special handle values, like -1, -2, and so forth, so kernelbase.dll
+ // checks to see that those bottom three bits are checked, but that top
+ // bit is not checked.
+ isLegacyWin7ConsoleHandle := func(handle Handle) bool { return isWin7 && handle&0x10000003 == 3 }
+
p, _ := GetCurrentProcess()
parentProcess := p
if sys.ParentProcess != 0 {
@@ -325,7 +338,15 @@
fd := make([]Handle, len(attr.Files))
for i := range attr.Files {
if attr.Files[i] > 0 {
- err := DuplicateHandle(p, Handle(attr.Files[i]), parentProcess, &fd[i], 0, true, DUPLICATE_SAME_ACCESS)
+ destinationProcessHandle := parentProcess
+
+ // On Windows 7, console handles aren't real handles, and can only be duplicated
+ // into the current process, not a parent one, which amounts to the same thing.
+ if parentProcess != p && isLegacyWin7ConsoleHandle(Handle(attr.Files[i])) {
+ destinationProcessHandle = p
+ }
+
+ err := DuplicateHandle(p, Handle(attr.Files[i]), destinationProcessHandle, &fd[i], 0, true, DUPLICATE_SAME_ACCESS)
if err != nil {
return 0, 0, err
}
@@ -356,6 +377,14 @@
fd = append(fd, sys.AdditionalInheritedHandles...)
+ // On Windows 7, console handles aren't real handles, so don't pass them
+ // through to PROC_THREAD_ATTRIBUTE_HANDLE_LIST.
+ for i := range fd {
+ if isLegacyWin7ConsoleHandle(fd[i]) {
+ fd[i] = 0
+ }
+ }
+
// The presence of a NULL handle in the list is enough to cause PROC_THREAD_ATTRIBUTE_HANDLE_LIST
// to treat the entire list as empty, so remove NULL handles.
j := 0
Index: src/runtime/syscall_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/runtime/syscall_windows.go b/src/runtime/syscall_windows.go
--- a/src/runtime/syscall_windows.go (revision a5b2168bb836ed9d6601c626f95e56c07923f906)
+++ b/src/runtime/syscall_windows.go (revision f56f1e23507e646c85243a71bde7b9629b2f970c)
@@ -413,10 +413,20 @@
const _LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
+// When available, this function will use LoadLibraryEx with the filename
+// parameter and the important SEARCH_SYSTEM32 argument. But on systems that
+// do not have that option, absoluteFilepath should contain a fallback
+// to the full path inside of system32 for use with vanilla LoadLibrary.
+//
//go:linkname syscall_loadsystemlibrary syscall.loadsystemlibrary
-func syscall_loadsystemlibrary(filename *uint16) (handle, err uintptr) {
- handle, _, err = syscall_SyscallN(uintptr(unsafe.Pointer(_LoadLibraryExW)), uintptr(unsafe.Pointer(filename)), 0, _LOAD_LIBRARY_SEARCH_SYSTEM32)
+func syscall_loadsystemlibrary(filename *uint16, absoluteFilepath *uint16) (handle, err uintptr) {
+ if useLoadLibraryEx {
+ handle, _, err = syscall_SyscallN(uintptr(unsafe.Pointer(_LoadLibraryExW)), uintptr(unsafe.Pointer(filename)), 0, _LOAD_LIBRARY_SEARCH_SYSTEM32)
+ } else {
+ handle, _, err = syscall_SyscallN(uintptr(unsafe.Pointer(_LoadLibraryW)), uintptr(unsafe.Pointer(absoluteFilepath)))
+ }
KeepAlive(filename)
+ KeepAlive(absoluteFilepath)
if handle != 0 {
err = 0
}
Index: src/syscall/dll_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/syscall/dll_windows.go b/src/syscall/dll_windows.go
--- a/src/syscall/dll_windows.go (revision a5b2168bb836ed9d6601c626f95e56c07923f906)
+++ b/src/syscall/dll_windows.go (revision f56f1e23507e646c85243a71bde7b9629b2f970c)
@@ -45,7 +45,7 @@
//go:noescape
func SyscallN(trap uintptr, args ...uintptr) (r1, r2 uintptr, err Errno)
func loadlibrary(filename *uint16) (handle uintptr, err Errno)
-func loadsystemlibrary(filename *uint16) (handle uintptr, err Errno)
+func loadsystemlibrary(filename *uint16, absoluteFilepath *uint16) (handle uintptr, err Errno)
func getprocaddress(handle uintptr, procname *uint8) (proc uintptr, err Errno)
// A DLL implements access to a single DLL.
@@ -54,6 +54,9 @@
Handle Handle
}
+//go:linkname getSystemDirectory
+func getSystemDirectory() string // Implemented in runtime package.
+
// LoadDLL loads the named DLL file into memory.
//
// If name is not an absolute path and is not a known system DLL used by
@@ -70,7 +73,11 @@
var h uintptr
var e Errno
if sysdll.IsSystemDLL[name] {
- h, e = loadsystemlibrary(namep)
+ absoluteFilepathp, err := UTF16PtrFromString(getSystemDirectory() + name)
+ if err != nil {
+ return nil, err
+ }
+ h, e = loadsystemlibrary(namep, absoluteFilepathp)
} else {
h, e = loadlibrary(namep)
}
Index: src/os/removeall_at.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/removeall_at.go b/src/os/removeall_at.go
--- a/src/os/removeall_at.go (revision f56f1e23507e646c85243a71bde7b9629b2f970c)
+++ b/src/os/removeall_at.go (revision 0a52622d2331ff975fb0442617ec19bc352bb2ed)
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build unix || wasip1 || windows
+//go:build unix || wasip1
package os
@@ -175,3 +175,25 @@
}
return newDirFile(fd, name)
}
+
+func rootRemoveAll(r *Root, name string) error {
+ // Consistency with os.RemoveAll: Strip trailing /s from the name,
+ // so RemoveAll("not_a_directory/") succeeds.
+ for len(name) > 0 && IsPathSeparator(name[len(name)-1]) {
+ name = name[:len(name)-1]
+ }
+ if endsWithDot(name) {
+ // Consistency with os.RemoveAll: Return EINVAL when trying to remove .
+ return &PathError{Op: "RemoveAll", Path: name, Err: syscall.EINVAL}
+ }
+ _, err := doInRoot(r, name, nil, func(parent sysfdType, name string) (struct{}, error) {
+ return struct{}{}, removeAllFrom(parent, name)
+ })
+ if IsNotExist(err) {
+ return nil
+ }
+ if err != nil {
+ return &PathError{Op: "RemoveAll", Path: name, Err: underlyingError(err)}
+ }
+ return err
+}
Index: src/os/removeall_noat.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/removeall_noat.go b/src/os/removeall_noat.go
--- a/src/os/removeall_noat.go (revision f56f1e23507e646c85243a71bde7b9629b2f970c)
+++ b/src/os/removeall_noat.go (revision 0a52622d2331ff975fb0442617ec19bc352bb2ed)
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build (js && wasm) || plan9
+//go:build (js && wasm) || plan9 || windows
package os
@@ -140,3 +140,22 @@
}
return err
}
+
+func rootRemoveAll(r *Root, name string) error {
+ if endsWithDot(name) {
+ // Consistency with os.RemoveAll: Return EINVAL when trying to remove .
+ return &PathError{Op: "RemoveAll", Path: name, Err: syscall.EINVAL}
+ }
+ if err := checkPathEscapesLstat(r, name); err != nil {
+ if err == syscall.ENOTDIR {
+ // Some intermediate path component is not a directory.
+ // RemoveAll treats this as success (since the target doesn't exist).
+ return nil
+ }
+ return &PathError{Op: "RemoveAll", Path: name, Err: err}
+ }
+ if err := RemoveAll(joinPath(r.root.name, name)); err != nil {
+ return &PathError{Op: "RemoveAll", Path: name, Err: underlyingError(err)}
+ }
+ return nil
+}
Index: src/os/root_noopenat.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/root_noopenat.go b/src/os/root_noopenat.go
--- a/src/os/root_noopenat.go (revision f56f1e23507e646c85243a71bde7b9629b2f970c)
+++ b/src/os/root_noopenat.go (revision 0a52622d2331ff975fb0442617ec19bc352bb2ed)
@@ -11,7 +11,6 @@
"internal/filepathlite"
"internal/stringslite"
"sync/atomic"
- "syscall"
"time"
)
@@ -185,25 +184,6 @@
}
return nil
}
-
-func rootRemoveAll(r *Root, name string) error {
- if endsWithDot(name) {
- // Consistency with os.RemoveAll: Return EINVAL when trying to remove .
- return &PathError{Op: "RemoveAll", Path: name, Err: syscall.EINVAL}
- }
- if err := checkPathEscapesLstat(r, name); err != nil {
- if err == syscall.ENOTDIR {
- // Some intermediate path component is not a directory.
- // RemoveAll treats this as success (since the target doesn't exist).
- return nil
- }
- return &PathError{Op: "RemoveAll", Path: name, Err: err}
- }
- if err := RemoveAll(joinPath(r.root.name, name)); err != nil {
- return &PathError{Op: "RemoveAll", Path: name, Err: underlyingError(err)}
- }
- return nil
-}
func rootReadlink(r *Root, name string) (string, error) {
if err := checkPathEscapesLstat(r, name); err != nil {
Index: src/os/root_openat.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/root_openat.go b/src/os/root_openat.go
--- a/src/os/root_openat.go (revision f56f1e23507e646c85243a71bde7b9629b2f970c)
+++ b/src/os/root_openat.go (revision 0a52622d2331ff975fb0442617ec19bc352bb2ed)
@@ -194,28 +194,6 @@
return nil
}
-func rootRemoveAll(r *Root, name string) error {
- // Consistency with os.RemoveAll: Strip trailing /s from the name,
- // so RemoveAll("not_a_directory/") succeeds.
- for len(name) > 0 && IsPathSeparator(name[len(name)-1]) {
- name = name[:len(name)-1]
- }
- if endsWithDot(name) {
- // Consistency with os.RemoveAll: Return EINVAL when trying to remove .
- return &PathError{Op: "RemoveAll", Path: name, Err: syscall.EINVAL}
- }
- _, err := doInRoot(r, name, nil, func(parent sysfdType, name string) (struct{}, error) {
- return struct{}{}, removeAllFrom(parent, name)
- })
- if IsNotExist(err) {
- return nil
- }
- if err != nil {
- return &PathError{Op: "RemoveAll", Path: name, Err: underlyingError(err)}
- }
- return err
-}
-
func rootRename(r *Root, oldname, newname string) error {
_, err := doInRoot(r, oldname, nil, func(oldparent sysfdType, oldname string) (struct{}, error) {
_, err := doInRoot(r, newname, nil, func(newparent sysfdType, newname string) (struct{}, error) {
Index: src/os/root_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/root_windows.go b/src/os/root_windows.go
--- a/src/os/root_windows.go (revision f56f1e23507e646c85243a71bde7b9629b2f970c)
+++ b/src/os/root_windows.go (revision 0a52622d2331ff975fb0442617ec19bc352bb2ed)
@@ -402,3 +402,14 @@
}
return fi.Mode(), nil
}
+
+func checkPathEscapes(r *Root, name string) error {
+ if !filepathlite.IsLocal(name) {
+ return errPathEscapes
+ }
+ return nil
+}
+
+func checkPathEscapesLstat(r *Root, name string) error {
+ return checkPathEscapes(r, name)
+}
Index: src/os/exec_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/exec_windows.go b/src/os/exec_windows.go
--- a/src/os/exec_windows.go (revision 0a52622d2331ff975fb0442617ec19bc352bb2ed)
+++ b/src/os/exec_windows.go (revision fb3d09a67fe97008ad76fea97ae88170072cbdbb)
@@ -10,6 +10,7 @@
"runtime"
"syscall"
"time"
+ _ "unsafe"
)
// Note that Process.handle is never nil because Windows always requires
@@ -49,9 +50,23 @@
// than statusDone.
p.doRelease(statusReleased)
+ var maj, min, build uint32
+ rtlGetNtVersionNumbers(&maj, &min, &build)
+ if maj < 10 {
+ // NOTE(brainman): It seems that sometimes process is not dead
+ // when WaitForSingleObject returns. But we do not know any
+ // other way to wait for it. Sleeping for a while seems to do
+ // the trick sometimes.
+ // See https://golang.org/issue/25965 for details.
+ time.Sleep(5 * time.Millisecond)
+ }
+
return &ProcessState{p.Pid, syscall.WaitStatus{ExitCode: ec}, &u}, nil
}
+//go:linkname rtlGetNtVersionNumbers syscall.rtlGetNtVersionNumbers
+func rtlGetNtVersionNumbers(majorVersion *uint32, minorVersion *uint32, buildNumber *uint32)
+
func (p *Process) signal(sig Signal) error {
handle, status := p.handleTransientAcquire()
switch status {

View File

@@ -1,883 +0,0 @@
Subject: [PATCH] Revert "os: remove 5ms sleep on Windows in (*Process).Wait"
Fix os.RemoveAll not working on Windows7
Revert "runtime: always use LoadLibraryEx to load system libraries"
Revert "syscall: remove Windows 7 console handle workaround"
Revert "net: remove sysSocket fallback for Windows 7"
Revert "crypto/rand,runtime: switch RtlGenRandom for ProcessPrng"
---
Index: src/crypto/internal/sysrand/rand_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/crypto/internal/sysrand/rand_windows.go b/src/crypto/internal/sysrand/rand_windows.go
--- a/src/crypto/internal/sysrand/rand_windows.go (revision c599a8f2385849a225d02843b3c6389dbfc5aa69)
+++ b/src/crypto/internal/sysrand/rand_windows.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
@@ -7,5 +7,26 @@
import "internal/syscall/windows"
func read(b []byte) error {
- return windows.ProcessPrng(b)
+ // RtlGenRandom only returns 1<<32-1 bytes at a time. We only read at
+ // most 1<<31-1 bytes at a time so that this works the same on 32-bit
+ // and 64-bit systems.
+ return batched(windows.RtlGenRandom, 1<<31-1)(b)
+}
+
+// batched returns a function that calls f to populate a []byte by chunking it
+// into subslices of, at most, readMax bytes.
+func batched(f func([]byte) error, readMax int) func([]byte) error {
+ return func(out []byte) error {
+ for len(out) > 0 {
+ read := len(out)
+ if read > readMax {
+ read = readMax
+ }
+ if err := f(out[:read]); err != nil {
+ return err
+ }
+ out = out[read:]
+ }
+ return nil
+ }
}
Index: src/crypto/rand/rand.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/crypto/rand/rand.go b/src/crypto/rand/rand.go
--- a/src/crypto/rand/rand.go (revision c599a8f2385849a225d02843b3c6389dbfc5aa69)
+++ b/src/crypto/rand/rand.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
@@ -25,7 +25,7 @@
// - On legacy Linux (< 3.17), Reader opens /dev/urandom on first use.
// - On macOS, iOS, and OpenBSD Reader, uses arc4random_buf(3).
// - On NetBSD, Reader uses the kern.arandom sysctl.
-// - On Windows, Reader uses the ProcessPrng API.
+// - On Windows systems, Reader uses the RtlGenRandom API.
// - On js/wasm, Reader uses the Web Crypto API.
// - On wasip1/wasm, Reader uses random_get.
//
Index: src/internal/syscall/windows/syscall_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/internal/syscall/windows/syscall_windows.go b/src/internal/syscall/windows/syscall_windows.go
--- a/src/internal/syscall/windows/syscall_windows.go (revision c599a8f2385849a225d02843b3c6389dbfc5aa69)
+++ b/src/internal/syscall/windows/syscall_windows.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
@@ -421,7 +421,7 @@
//sys DestroyEnvironmentBlock(block *uint16) (err error) = userenv.DestroyEnvironmentBlock
//sys CreateEvent(eventAttrs *SecurityAttributes, manualReset uint32, initialState uint32, name *uint16) (handle syscall.Handle, err error) = kernel32.CreateEventW
-//sys ProcessPrng(buf []byte) (err error) = bcryptprimitives.ProcessPrng
+//sys RtlGenRandom(buf []byte) (err error) = advapi32.SystemFunction036
type FILE_ID_BOTH_DIR_INFO struct {
NextEntryOffset uint32
Index: src/internal/syscall/windows/zsyscall_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/internal/syscall/windows/zsyscall_windows.go b/src/internal/syscall/windows/zsyscall_windows.go
--- a/src/internal/syscall/windows/zsyscall_windows.go (revision c599a8f2385849a225d02843b3c6389dbfc5aa69)
+++ b/src/internal/syscall/windows/zsyscall_windows.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
@@ -38,7 +38,6 @@
var (
modadvapi32 = syscall.NewLazyDLL(sysdll.Add("advapi32.dll"))
- modbcryptprimitives = syscall.NewLazyDLL(sysdll.Add("bcryptprimitives.dll"))
modiphlpapi = syscall.NewLazyDLL(sysdll.Add("iphlpapi.dll"))
modkernel32 = syscall.NewLazyDLL(sysdll.Add("kernel32.dll"))
modnetapi32 = syscall.NewLazyDLL(sysdll.Add("netapi32.dll"))
@@ -63,7 +62,7 @@
procQueryServiceStatus = modadvapi32.NewProc("QueryServiceStatus")
procRevertToSelf = modadvapi32.NewProc("RevertToSelf")
procSetTokenInformation = modadvapi32.NewProc("SetTokenInformation")
- procProcessPrng = modbcryptprimitives.NewProc("ProcessPrng")
+ procSystemFunction036 = modadvapi32.NewProc("SystemFunction036")
procGetAdaptersAddresses = modiphlpapi.NewProc("GetAdaptersAddresses")
procCreateEventW = modkernel32.NewProc("CreateEventW")
procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
@@ -244,12 +243,12 @@
return
}
-func ProcessPrng(buf []byte) (err error) {
+func RtlGenRandom(buf []byte) (err error) {
var _p0 *byte
if len(buf) > 0 {
_p0 = &buf[0]
}
- r1, _, e1 := syscall.SyscallN(procProcessPrng.Addr(), uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)))
+ r1, _, e1 := syscall.SyscallN(procSystemFunction036.Addr(), uintptr(unsafe.Pointer(_p0)), uintptr(len(buf)), 0)
if r1 == 0 {
err = errnoErr(e1)
}
Index: src/runtime/os_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/runtime/os_windows.go b/src/runtime/os_windows.go
--- a/src/runtime/os_windows.go (revision c599a8f2385849a225d02843b3c6389dbfc5aa69)
+++ b/src/runtime/os_windows.go (revision ea2726a6fa25fbfa1092e696e522eafca544d24c)
@@ -40,7 +40,8 @@
//go:cgo_import_dynamic runtime._GetSystemInfo GetSystemInfo%1 "kernel32.dll"
//go:cgo_import_dynamic runtime._GetThreadContext GetThreadContext%2 "kernel32.dll"
//go:cgo_import_dynamic runtime._SetThreadContext SetThreadContext%2 "kernel32.dll"
-//go:cgo_import_dynamic runtime._LoadLibraryExW LoadLibraryExW%3 "kernel32.dll"
+//go:cgo_import_dynamic runtime._LoadLibraryW LoadLibraryW%1 "kernel32.dll"
+//go:cgo_import_dynamic runtime._LoadLibraryA LoadLibraryA%1 "kernel32.dll"
//go:cgo_import_dynamic runtime._PostQueuedCompletionStatus PostQueuedCompletionStatus%4 "kernel32.dll"
//go:cgo_import_dynamic runtime._QueryPerformanceCounter QueryPerformanceCounter%1 "kernel32.dll"
//go:cgo_import_dynamic runtime._QueryPerformanceFrequency QueryPerformanceFrequency%1 "kernel32.dll"
@@ -74,7 +75,6 @@
// Following syscalls are available on every Windows PC.
// All these variables are set by the Windows executable
// loader before the Go program starts.
- _AddVectoredContinueHandler,
_AddVectoredExceptionHandler,
_CloseHandle,
_CreateEventA,
@@ -97,7 +97,8 @@
_GetSystemInfo,
_GetThreadContext,
_SetThreadContext,
- _LoadLibraryExW,
+ _LoadLibraryW,
+ _LoadLibraryA,
_PostQueuedCompletionStatus,
_QueryPerformanceCounter,
_QueryPerformanceFrequency,
@@ -126,8 +127,23 @@
_WriteFile,
_ stdFunction
- // Use ProcessPrng to generate cryptographically random data.
- _ProcessPrng stdFunction
+ // Following syscalls are only available on some Windows PCs.
+ // We will load syscalls, if available, before using them.
+ _AddDllDirectory,
+ _AddVectoredContinueHandler,
+ _LoadLibraryExA,
+ _LoadLibraryExW,
+ _ stdFunction
+
+ // Use RtlGenRandom to generate cryptographically random data.
+ // This approach has been recommended by Microsoft (see issue
+ // 15589 for details).
+ // The RtlGenRandom is not listed in advapi32.dll, instead
+ // RtlGenRandom function can be found by searching for SystemFunction036.
+ // Also some versions of Mingw cannot link to SystemFunction036
+ // when building executable as Cgo. So load SystemFunction036
+ // manually during runtime startup.
+ _RtlGenRandom stdFunction
// Load ntdll.dll manually during startup, otherwise Mingw
// links wrong printf function to cgo executable (see issue
@@ -144,13 +160,6 @@
_ stdFunction
)
-var (
- bcryptprimitivesdll = [...]uint16{'b', 'c', 'r', 'y', 'p', 't', 'p', 'r', 'i', 'm', 'i', 't', 'i', 'v', 'e', 's', '.', 'd', 'l', 'l', 0}
- ntdlldll = [...]uint16{'n', 't', 'd', 'l', 'l', '.', 'd', 'l', 'l', 0}
- powrprofdll = [...]uint16{'p', 'o', 'w', 'r', 'p', 'r', 'o', 'f', '.', 'd', 'l', 'l', 0}
- winmmdll = [...]uint16{'w', 'i', 'n', 'm', 'm', '.', 'd', 'l', 'l', 0}
-)
-
// Function to be called by windows CreateThread
// to start new os thread.
func tstart_stdcall(newm *m)
@@ -242,9 +251,40 @@
return unsafe.String(&sysDirectory[0], sysDirectoryLen)
}
-func windowsLoadSystemLib(name []uint16) uintptr {
- const _LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
- return stdcall(_LoadLibraryExW, uintptr(unsafe.Pointer(&name[0])), 0, _LOAD_LIBRARY_SEARCH_SYSTEM32)
+//go:linkname syscall_getSystemDirectory syscall.getSystemDirectory
+func syscall_getSystemDirectory() string {
+ return unsafe.String(&sysDirectory[0], sysDirectoryLen)
+}
+
+func windowsLoadSystemLib(name []byte) uintptr {
+ if useLoadLibraryEx {
+ return stdcall(_LoadLibraryExA, uintptr(unsafe.Pointer(&name[0])), 0, _LOAD_LIBRARY_SEARCH_SYSTEM32)
+ } else {
+ absName := append(sysDirectory[:sysDirectoryLen], name...)
+ return stdcall(_LoadLibraryA, uintptr(unsafe.Pointer(&absName[0])))
+ }
+}
+
+const _LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
+
+// When available, this function will use LoadLibraryEx with the filename
+// parameter and the important SEARCH_SYSTEM32 argument. But on systems that
+// do not have that option, absoluteFilepath should contain a fallback
+// to the full path inside of system32 for use with vanilla LoadLibrary.
+//
+//go:linkname syscall_loadsystemlibrary syscall.loadsystemlibrary
+func syscall_loadsystemlibrary(filename *uint16, absoluteFilepath *uint16) (handle, err uintptr) {
+ if useLoadLibraryEx {
+ handle, _, err = syscall_syscalln(uintptr(unsafe.Pointer(_LoadLibraryExW)), 3, uintptr(unsafe.Pointer(filename)), 0, _LOAD_LIBRARY_SEARCH_SYSTEM32)
+ } else {
+ handle, _, err = syscall_syscalln(uintptr(unsafe.Pointer(_LoadLibraryW)), 1, uintptr(unsafe.Pointer(absoluteFilepath)))
+ }
+ KeepAlive(filename)
+ KeepAlive(absoluteFilepath)
+ if handle != 0 {
+ err = 0
+ }
+ return
}
//go:linkname windows_QueryPerformanceCounter internal/syscall/windows.QueryPerformanceCounter
@@ -262,13 +302,28 @@
}
func loadOptionalSyscalls() {
- bcryptPrimitives := windowsLoadSystemLib(bcryptprimitivesdll[:])
- if bcryptPrimitives == 0 {
- throw("bcryptprimitives.dll not found")
+ var kernel32dll = []byte("kernel32.dll\000")
+ k32 := stdcall(_LoadLibraryA, uintptr(unsafe.Pointer(&kernel32dll[0])))
+ if k32 == 0 {
+ throw("kernel32.dll not found")
}
- _ProcessPrng = windowsFindfunc(bcryptPrimitives, []byte("ProcessPrng\000"))
+ _AddDllDirectory = windowsFindfunc(k32, []byte("AddDllDirectory\000"))
+ _AddVectoredContinueHandler = windowsFindfunc(k32, []byte("AddVectoredContinueHandler\000"))
+ _LoadLibraryExA = windowsFindfunc(k32, []byte("LoadLibraryExA\000"))
+ _LoadLibraryExW = windowsFindfunc(k32, []byte("LoadLibraryExW\000"))
+ useLoadLibraryEx = (_LoadLibraryExW != nil && _LoadLibraryExA != nil && _AddDllDirectory != nil)
+
+ initSysDirectory()
- n32 := windowsLoadSystemLib(ntdlldll[:])
+ var advapi32dll = []byte("advapi32.dll\000")
+ a32 := windowsLoadSystemLib(advapi32dll)
+ if a32 == 0 {
+ throw("advapi32.dll not found")
+ }
+ _RtlGenRandom = windowsFindfunc(a32, []byte("SystemFunction036\000"))
+
+ var ntdll = []byte("ntdll.dll\000")
+ n32 := windowsLoadSystemLib(ntdll)
if n32 == 0 {
throw("ntdll.dll not found")
}
@@ -297,7 +352,7 @@
context uintptr
}
- powrprof := windowsLoadSystemLib(powrprofdll[:])
+ powrprof := windowsLoadSystemLib([]byte("powrprof.dll\000"))
if powrprof == 0 {
return // Running on Windows 7, where we don't need it anyway.
}
@@ -351,6 +406,22 @@
// in sys_windows_386.s and sys_windows_amd64.s:
func getlasterror() uint32
+// When loading DLLs, we prefer to use LoadLibraryEx with
+// LOAD_LIBRARY_SEARCH_* flags, if available. LoadLibraryEx is not
+// available on old Windows, though, and the LOAD_LIBRARY_SEARCH_*
+// flags are not available on some versions of Windows without a
+// security patch.
+//
+// https://msdn.microsoft.com/en-us/library/ms684179(v=vs.85).aspx says:
+// "Windows 7, Windows Server 2008 R2, Windows Vista, and Windows
+// Server 2008: The LOAD_LIBRARY_SEARCH_* flags are available on
+// systems that have KB2533623 installed. To determine whether the
+// flags are available, use GetProcAddress to get the address of the
+// AddDllDirectory, RemoveDllDirectory, or SetDefaultDllDirectories
+// function. If GetProcAddress succeeds, the LOAD_LIBRARY_SEARCH_*
+// flags can be used with LoadLibraryEx."
+var useLoadLibraryEx bool
+
var timeBeginPeriodRetValue uint32
// osRelaxMinNS indicates that sysmon shouldn't osRelax if the next
@@ -417,7 +488,8 @@
// Only load winmm.dll if we need it.
// This avoids a dependency on winmm.dll for Go programs
// that run on new Windows versions.
- m32 := windowsLoadSystemLib(winmmdll[:])
+ var winmmdll = []byte("winmm.dll\000")
+ m32 := windowsLoadSystemLib(winmmdll)
if m32 == 0 {
print("runtime: LoadLibraryExW failed; errno=", getlasterror(), "\n")
throw("winmm.dll not found")
@@ -458,6 +530,28 @@
canUseLongPaths = true
}
+var osVersionInfo struct {
+ majorVersion uint32
+ minorVersion uint32
+ buildNumber uint32
+}
+
+func initOsVersionInfo() {
+ info := windows.OSVERSIONINFOW{}
+ info.OSVersionInfoSize = uint32(unsafe.Sizeof(info))
+ stdcall(_RtlGetVersion, uintptr(unsafe.Pointer(&info)))
+ osVersionInfo.majorVersion = info.MajorVersion
+ osVersionInfo.minorVersion = info.MinorVersion
+ osVersionInfo.buildNumber = info.BuildNumber
+}
+
+//go:linkname rtlGetNtVersionNumbers syscall.rtlGetNtVersionNumbers
+func rtlGetNtVersionNumbers(majorVersion *uint32, minorVersion *uint32, buildNumber *uint32) {
+ *majorVersion = osVersionInfo.majorVersion
+ *minorVersion = osVersionInfo.minorVersion
+ *buildNumber = osVersionInfo.buildNumber
+}
+
func osinit() {
asmstdcallAddr = unsafe.Pointer(windows.AsmStdCallAddr())
@@ -470,8 +564,8 @@
initHighResTimer()
timeBeginPeriodRetValue = osRelax(false)
- initSysDirectory()
initLongPathSupport()
+ initOsVersionInfo()
numCPUStartup = getCPUCount()
@@ -487,7 +581,7 @@
//go:nosplit
func readRandom(r []byte) int {
n := 0
- if stdcall(_ProcessPrng, uintptr(unsafe.Pointer(&r[0])), uintptr(len(r)))&0xff != 0 {
+ if stdcall(_RtlGenRandom, uintptr(unsafe.Pointer(&r[0])), uintptr(len(r)))&0xff != 0 {
n = len(r)
}
return n
Index: src/net/hook_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/hook_windows.go b/src/net/hook_windows.go
--- a/src/net/hook_windows.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
+++ b/src/net/hook_windows.go (revision 44e76f7cf1bc6e04b5da724e0b2e48f393713506)
@@ -13,6 +13,7 @@
hostsFilePath = windows.GetSystemDirectory() + "/Drivers/etc/hosts"
// Placeholders for socket system calls.
+ socketFunc func(int, int, int) (syscall.Handle, error) = syscall.Socket
wsaSocketFunc func(int32, int32, int32, *syscall.WSAProtocolInfo, uint32, uint32) (syscall.Handle, error) = windows.WSASocket
connectFunc func(syscall.Handle, syscall.Sockaddr) error = syscall.Connect
listenFunc func(syscall.Handle, int) error = syscall.Listen
Index: src/net/internal/socktest/main_test.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/internal/socktest/main_test.go b/src/net/internal/socktest/main_test.go
--- a/src/net/internal/socktest/main_test.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
+++ b/src/net/internal/socktest/main_test.go (revision 44e76f7cf1bc6e04b5da724e0b2e48f393713506)
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build !js && !plan9 && !wasip1 && !windows
+//go:build !js && !plan9 && !wasip1
package socktest_test
Index: src/net/internal/socktest/main_windows_test.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/internal/socktest/main_windows_test.go b/src/net/internal/socktest/main_windows_test.go
new file mode 100644
--- /dev/null (revision 44e76f7cf1bc6e04b5da724e0b2e48f393713506)
+++ b/src/net/internal/socktest/main_windows_test.go (revision 44e76f7cf1bc6e04b5da724e0b2e48f393713506)
@@ -0,0 +1,22 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package socktest_test
+
+import "syscall"
+
+var (
+ socketFunc func(int, int, int) (syscall.Handle, error)
+ closeFunc func(syscall.Handle) error
+)
+
+func installTestHooks() {
+ socketFunc = sw.Socket
+ closeFunc = sw.Closesocket
+}
+
+func uninstallTestHooks() {
+ socketFunc = syscall.Socket
+ closeFunc = syscall.Closesocket
+}
Index: src/net/internal/socktest/sys_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/internal/socktest/sys_windows.go b/src/net/internal/socktest/sys_windows.go
--- a/src/net/internal/socktest/sys_windows.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
+++ b/src/net/internal/socktest/sys_windows.go (revision 44e76f7cf1bc6e04b5da724e0b2e48f393713506)
@@ -9,6 +9,38 @@
"syscall"
)
+// Socket wraps [syscall.Socket].
+func (sw *Switch) Socket(family, sotype, proto int) (s syscall.Handle, err error) {
+ sw.once.Do(sw.init)
+
+ so := &Status{Cookie: cookie(family, sotype, proto)}
+ sw.fmu.RLock()
+ f, _ := sw.fltab[FilterSocket]
+ sw.fmu.RUnlock()
+
+ af, err := f.apply(so)
+ if err != nil {
+ return syscall.InvalidHandle, err
+ }
+ s, so.Err = syscall.Socket(family, sotype, proto)
+ if err = af.apply(so); err != nil {
+ if so.Err == nil {
+ syscall.Closesocket(s)
+ }
+ return syscall.InvalidHandle, err
+ }
+
+ sw.smu.Lock()
+ defer sw.smu.Unlock()
+ if so.Err != nil {
+ sw.stats.getLocked(so.Cookie).OpenFailed++
+ return syscall.InvalidHandle, so.Err
+ }
+ nso := sw.addLocked(s, family, sotype, proto)
+ sw.stats.getLocked(nso.Cookie).Opened++
+ return s, nil
+}
+
// WSASocket wraps [syscall.WSASocket].
func (sw *Switch) WSASocket(family, sotype, proto int32, protinfo *syscall.WSAProtocolInfo, group uint32, flags uint32) (s syscall.Handle, err error) {
sw.once.Do(sw.init)
Index: src/net/main_windows_test.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/main_windows_test.go b/src/net/main_windows_test.go
--- a/src/net/main_windows_test.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
+++ b/src/net/main_windows_test.go (revision 44e76f7cf1bc6e04b5da724e0b2e48f393713506)
@@ -12,6 +12,7 @@
var (
// Placeholders for saving original socket system calls.
+ origSocket = socketFunc
origWSASocket = wsaSocketFunc
origClosesocket = poll.CloseFunc
origConnect = connectFunc
@@ -21,6 +22,7 @@
)
func installTestHooks() {
+ socketFunc = sw.Socket
wsaSocketFunc = sw.WSASocket
poll.CloseFunc = sw.Closesocket
connectFunc = sw.Connect
@@ -30,6 +32,7 @@
}
func uninstallTestHooks() {
+ socketFunc = origSocket
wsaSocketFunc = origWSASocket
poll.CloseFunc = origClosesocket
connectFunc = origConnect
Index: src/net/sock_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/net/sock_windows.go b/src/net/sock_windows.go
--- a/src/net/sock_windows.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
+++ b/src/net/sock_windows.go (revision 44e76f7cf1bc6e04b5da724e0b2e48f393713506)
@@ -20,6 +20,21 @@
func sysSocket(family, sotype, proto int) (syscall.Handle, error) {
s, err := wsaSocketFunc(int32(family), int32(sotype), int32(proto),
nil, 0, windows.WSA_FLAG_OVERLAPPED|windows.WSA_FLAG_NO_HANDLE_INHERIT)
+ if err == nil {
+ return s, nil
+ }
+ // WSA_FLAG_NO_HANDLE_INHERIT flag is not supported on some
+ // old versions of Windows, see
+ // https://msdn.microsoft.com/en-us/library/windows/desktop/ms742212(v=vs.85).aspx
+ // for details. Just use syscall.Socket, if windows.WSASocket failed.
+
+ // See ../syscall/exec_unix.go for description of ForkLock.
+ syscall.ForkLock.RLock()
+ s, err = socketFunc(family, sotype, proto)
+ if err == nil {
+ syscall.CloseOnExec(s)
+ }
+ syscall.ForkLock.RUnlock()
if err != nil {
return syscall.InvalidHandle, os.NewSyscallError("socket", err)
}
Index: src/syscall/exec_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/syscall/exec_windows.go b/src/syscall/exec_windows.go
--- a/src/syscall/exec_windows.go (revision b0d48afabb9fd14976c27221cb525c5d2ebbfe79)
+++ b/src/syscall/exec_windows.go (revision b4aece36e51ecce81c3ee9fe03e31db552e90018)
@@ -15,7 +15,6 @@
"unsafe"
)
-// ForkLock is not used on Windows.
var ForkLock sync.RWMutex
// EscapeArg rewrites command line argument s as prescribed
@@ -304,6 +303,9 @@
var zeroProcAttr ProcAttr
var zeroSysProcAttr SysProcAttr
+//go:linkname rtlGetNtVersionNumbers
+func rtlGetNtVersionNumbers(majorVersion *uint32, minorVersion *uint32, buildNumber *uint32)
+
func StartProcess(argv0 string, argv []string, attr *ProcAttr) (pid int, handle uintptr, err error) {
if len(argv0) == 0 {
return 0, 0, EWINDOWS
@@ -367,6 +369,17 @@
}
}
+ var maj, min, build uint32
+ rtlGetNtVersionNumbers(&maj, &min, &build)
+ isWin7 := maj < 6 || (maj == 6 && min <= 1)
+ // NT kernel handles are divisible by 4, with the bottom 3 bits left as
+ // a tag. The fully set tag correlates with the types of handles we're
+ // concerned about here. Except, the kernel will interpret some
+ // special handle values, like -1, -2, and so forth, so kernelbase.dll
+ // checks to see that those bottom three bits are checked, but that top
+ // bit is not checked.
+ isLegacyWin7ConsoleHandle := func(handle Handle) bool { return isWin7 && handle&0x10000003 == 3 }
+
p, _ := GetCurrentProcess()
parentProcess := p
if sys.ParentProcess != 0 {
@@ -375,7 +388,15 @@
fd := make([]Handle, len(attr.Files))
for i := range attr.Files {
if attr.Files[i] > 0 {
- err := DuplicateHandle(p, Handle(attr.Files[i]), parentProcess, &fd[i], 0, true, DUPLICATE_SAME_ACCESS)
+ destinationProcessHandle := parentProcess
+
+ // On Windows 7, console handles aren't real handles, and can only be duplicated
+ // into the current process, not a parent one, which amounts to the same thing.
+ if parentProcess != p && isLegacyWin7ConsoleHandle(Handle(attr.Files[i])) {
+ destinationProcessHandle = p
+ }
+
+ err := DuplicateHandle(p, Handle(attr.Files[i]), destinationProcessHandle, &fd[i], 0, true, DUPLICATE_SAME_ACCESS)
if err != nil {
return 0, 0, err
}
@@ -406,6 +427,14 @@
fd = append(fd, sys.AdditionalInheritedHandles...)
+ // On Windows 7, console handles aren't real handles, so don't pass them
+ // through to PROC_THREAD_ATTRIBUTE_HANDLE_LIST.
+ for i := range fd {
+ if isLegacyWin7ConsoleHandle(fd[i]) {
+ fd[i] = 0
+ }
+ }
+
// The presence of a NULL handle in the list is enough to cause PROC_THREAD_ATTRIBUTE_HANDLE_LIST
// to treat the entire list as empty, so remove NULL handles.
j := 0
Index: src/syscall/dll_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/syscall/dll_windows.go b/src/syscall/dll_windows.go
--- a/src/syscall/dll_windows.go (revision b4aece36e51ecce81c3ee9fe03e31db552e90018)
+++ b/src/syscall/dll_windows.go (revision ea2726a6fa25fbfa1092e696e522eafca544d24c)
@@ -119,14 +119,7 @@
}
//go:linkname loadsystemlibrary
-func loadsystemlibrary(filename *uint16) (uintptr, Errno) {
- const _LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
- handle, _, err := SyscallN(uintptr(__LoadLibraryExW), uintptr(unsafe.Pointer(filename)), 0, _LOAD_LIBRARY_SEARCH_SYSTEM32)
- if handle != 0 {
- err = 0
- }
- return handle, err
-}
+func loadsystemlibrary(filename *uint16, absoluteFilepath *uint16) (handle uintptr, err Errno)
//go:linkname getprocaddress
func getprocaddress(handle uintptr, procname *uint8) (uintptr, Errno) {
@@ -143,6 +136,9 @@
Handle Handle
}
+//go:linkname getSystemDirectory
+func getSystemDirectory() string // Implemented in runtime package.
+
// LoadDLL loads the named DLL file into memory.
//
// If name is not an absolute path and is not a known system DLL used by
@@ -159,7 +155,11 @@
var h uintptr
var e Errno
if sysdll.IsSystemDLL[name] {
- h, e = loadsystemlibrary(namep)
+ absoluteFilepathp, err := UTF16PtrFromString(getSystemDirectory() + name)
+ if err != nil {
+ return nil, err
+ }
+ h, e = loadsystemlibrary(namep, absoluteFilepathp)
} else {
h, e = loadlibrary(namep)
}
Index: src/os/removeall_at.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/removeall_at.go b/src/os/removeall_at.go
--- a/src/os/removeall_at.go (revision ea2726a6fa25fbfa1092e696e522eafca544d24c)
+++ b/src/os/removeall_at.go (revision d47e0d22130d597dcf9daa6b41fd9501274f0cb2)
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build unix || wasip1 || windows
+//go:build unix || wasip1
package os
@@ -175,3 +175,25 @@
}
return newDirFile(fd, name)
}
+
+func rootRemoveAll(r *Root, name string) error {
+ // Consistency with os.RemoveAll: Strip trailing /s from the name,
+ // so RemoveAll("not_a_directory/") succeeds.
+ for len(name) > 0 && IsPathSeparator(name[len(name)-1]) {
+ name = name[:len(name)-1]
+ }
+ if endsWithDot(name) {
+ // Consistency with os.RemoveAll: Return EINVAL when trying to remove .
+ return &PathError{Op: "RemoveAll", Path: name, Err: syscall.EINVAL}
+ }
+ _, err := doInRoot(r, name, nil, func(parent sysfdType, name string) (struct{}, error) {
+ return struct{}{}, removeAllFrom(parent, name)
+ })
+ if IsNotExist(err) {
+ return nil
+ }
+ if err != nil {
+ return &PathError{Op: "RemoveAll", Path: name, Err: underlyingError(err)}
+ }
+ return err
+}
Index: src/os/removeall_noat.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/removeall_noat.go b/src/os/removeall_noat.go
--- a/src/os/removeall_noat.go (revision ea2726a6fa25fbfa1092e696e522eafca544d24c)
+++ b/src/os/removeall_noat.go (revision d47e0d22130d597dcf9daa6b41fd9501274f0cb2)
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build (js && wasm) || plan9
+//go:build (js && wasm) || plan9 || windows
package os
@@ -140,3 +140,22 @@
}
return err
}
+
+func rootRemoveAll(r *Root, name string) error {
+ if endsWithDot(name) {
+ // Consistency with os.RemoveAll: Return EINVAL when trying to remove .
+ return &PathError{Op: "RemoveAll", Path: name, Err: syscall.EINVAL}
+ }
+ if err := checkPathEscapesLstat(r, name); err != nil {
+ if err == syscall.ENOTDIR {
+ // Some intermediate path component is not a directory.
+ // RemoveAll treats this as success (since the target doesn't exist).
+ return nil
+ }
+ return &PathError{Op: "RemoveAll", Path: name, Err: err}
+ }
+ if err := RemoveAll(joinPath(r.root.name, name)); err != nil {
+ return &PathError{Op: "RemoveAll", Path: name, Err: underlyingError(err)}
+ }
+ return nil
+}
Index: src/os/root_noopenat.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/root_noopenat.go b/src/os/root_noopenat.go
--- a/src/os/root_noopenat.go (revision ea2726a6fa25fbfa1092e696e522eafca544d24c)
+++ b/src/os/root_noopenat.go (revision d47e0d22130d597dcf9daa6b41fd9501274f0cb2)
@@ -11,7 +11,6 @@
"internal/filepathlite"
"internal/stringslite"
"sync/atomic"
- "syscall"
"time"
)
@@ -185,25 +184,6 @@
}
return nil
}
-
-func rootRemoveAll(r *Root, name string) error {
- if endsWithDot(name) {
- // Consistency with os.RemoveAll: Return EINVAL when trying to remove .
- return &PathError{Op: "RemoveAll", Path: name, Err: syscall.EINVAL}
- }
- if err := checkPathEscapesLstat(r, name); err != nil {
- if err == syscall.ENOTDIR {
- // Some intermediate path component is not a directory.
- // RemoveAll treats this as success (since the target doesn't exist).
- return nil
- }
- return &PathError{Op: "RemoveAll", Path: name, Err: err}
- }
- if err := RemoveAll(joinPath(r.root.name, name)); err != nil {
- return &PathError{Op: "RemoveAll", Path: name, Err: underlyingError(err)}
- }
- return nil
-}
func rootReadlink(r *Root, name string) (string, error) {
if err := checkPathEscapesLstat(r, name); err != nil {
Index: src/os/root_openat.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/root_openat.go b/src/os/root_openat.go
--- a/src/os/root_openat.go (revision ea2726a6fa25fbfa1092e696e522eafca544d24c)
+++ b/src/os/root_openat.go (revision d47e0d22130d597dcf9daa6b41fd9501274f0cb2)
@@ -196,28 +196,6 @@
return nil
}
-func rootRemoveAll(r *Root, name string) error {
- // Consistency with os.RemoveAll: Strip trailing /s from the name,
- // so RemoveAll("not_a_directory/") succeeds.
- for len(name) > 0 && IsPathSeparator(name[len(name)-1]) {
- name = name[:len(name)-1]
- }
- if endsWithDot(name) {
- // Consistency with os.RemoveAll: Return EINVAL when trying to remove .
- return &PathError{Op: "RemoveAll", Path: name, Err: syscall.EINVAL}
- }
- _, err := doInRoot(r, name, nil, func(parent sysfdType, name string) (struct{}, error) {
- return struct{}{}, removeAllFrom(parent, name)
- })
- if IsNotExist(err) {
- return nil
- }
- if err != nil {
- return &PathError{Op: "RemoveAll", Path: name, Err: underlyingError(err)}
- }
- return err
-}
-
func rootRename(r *Root, oldname, newname string) error {
_, err := doInRoot(r, oldname, nil, func(oldparent sysfdType, oldname string) (struct{}, error) {
_, err := doInRoot(r, newname, nil, func(newparent sysfdType, newname string) (struct{}, error) {
Index: src/os/root_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/root_windows.go b/src/os/root_windows.go
--- a/src/os/root_windows.go (revision ea2726a6fa25fbfa1092e696e522eafca544d24c)
+++ b/src/os/root_windows.go (revision d47e0d22130d597dcf9daa6b41fd9501274f0cb2)
@@ -402,3 +402,14 @@
}
return fi.Mode(), nil
}
+
+func checkPathEscapes(r *Root, name string) error {
+ if !filepathlite.IsLocal(name) {
+ return errPathEscapes
+ }
+ return nil
+}
+
+func checkPathEscapesLstat(r *Root, name string) error {
+ return checkPathEscapes(r, name)
+}
Index: src/os/exec_windows.go
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/src/os/exec_windows.go b/src/os/exec_windows.go
--- a/src/os/exec_windows.go (revision d47e0d22130d597dcf9daa6b41fd9501274f0cb2)
+++ b/src/os/exec_windows.go (revision 00e8daec9a4d88f44a8dc55d3bdb71878e525b41)
@@ -10,6 +10,7 @@
"runtime"
"syscall"
"time"
+ _ "unsafe"
)
// Note that Process.handle is never nil because Windows always requires
@@ -49,9 +50,23 @@
// than statusDone.
p.doRelease(statusReleased)
+ var maj, min, build uint32
+ rtlGetNtVersionNumbers(&maj, &min, &build)
+ if maj < 10 {
+ // NOTE(brainman): It seems that sometimes process is not dead
+ // when WaitForSingleObject returns. But we do not know any
+ // other way to wait for it. Sleeping for a while seems to do
+ // the trick sometimes.
+ // See https://golang.org/issue/25965 for details.
+ time.Sleep(5 * time.Millisecond)
+ }
+
return &ProcessState{p.Pid, syscall.WaitStatus{ExitCode: ec}, &u}, nil
}
+//go:linkname rtlGetNtVersionNumbers syscall.rtlGetNtVersionNumbers
+func rtlGetNtVersionNumbers(majorVersion *uint32, minorVersion *uint32, buildNumber *uint32)
+
func (p *Process) signal(sig Signal) error {
handle, status := p.handleTransientAcquire()
switch status {

View File

@@ -1,18 +0,0 @@
-s dir
--name mihomo
--category net
--license GPL-3.0-or-later
--description "The universal proxy platform."
--url "https://wiki.metacubex.one/"
--maintainer "MetaCubeX <none@example.com>"
--deb-field "Bug: https://github.com/MetaCubeX/mihomo/issues"
--no-deb-generate-changes
--config-files /etc/mihomo/config.yaml
.github/release/config.yaml=/etc/mihomo/config.yaml
.github/release/mihomo.service=/usr/lib/systemd/system/mihomo.service
.github/release/mihomo@.service=/usr/lib/systemd/system/mihomo@.service
LICENSE=/usr/share/licenses/mihomo/LICENSE

View File

@@ -1,15 +0,0 @@
mixed-port: 7890
dns:
enable: true
ipv6: true
enhanced-mode: fake-ip
fake-ip-filter:
- "*"
- "+.lan"
- "+.local"
nameserver:
- system
rules:
- MATCH,DIRECT

View File

@@ -1,17 +0,0 @@
[Unit]
Description=mihomo Daemon, Another Clash Kernel.
Documentation=https://wiki.metacubex.one
After=network.target nss-lookup.target network-online.target
[Service]
Type=simple
CapabilityBoundingSet=CAP_NET_ADMIN CAP_NET_RAW CAP_NET_BIND_SERVICE CAP_SYS_TIME CAP_SYS_PTRACE CAP_DAC_READ_SEARCH CAP_DAC_OVERRIDE
AmbientCapabilities=CAP_NET_ADMIN CAP_NET_RAW CAP_NET_BIND_SERVICE CAP_SYS_TIME CAP_SYS_PTRACE CAP_DAC_READ_SEARCH CAP_DAC_OVERRIDE
ExecStart=/usr/bin/mihomo -d /etc/mihomo
ExecReload=/bin/kill -HUP $MAINPID
Restart=on-failure
RestartSec=10
LimitNOFILE=infinity
[Install]
WantedBy=multi-user.target

View File

@@ -1,17 +0,0 @@
[Unit]
Description=mihomo Daemon, Another Clash Kernel.
Documentation=https://wiki.metacubex.one
After=network.target nss-lookup.target network-online.target
[Service]
Type=simple
CapabilityBoundingSet=CAP_NET_ADMIN CAP_NET_RAW CAP_NET_BIND_SERVICE CAP_SYS_TIME CAP_SYS_PTRACE CAP_DAC_READ_SEARCH CAP_DAC_OVERRIDE
AmbientCapabilities=CAP_NET_ADMIN CAP_NET_RAW CAP_NET_BIND_SERVICE CAP_SYS_TIME CAP_SYS_PTRACE CAP_DAC_READ_SEARCH CAP_DAC_OVERRIDE
ExecStart=/usr/bin/mihomo -d /etc/mihomo
ExecReload=/bin/kill -HUP $MAINPID
Restart=on-failure
RestartSec=10
LimitNOFILE=infinity
[Install]
WantedBy=multi-user.target

16
.github/workflows/Delete.yml vendored Normal file
View File

@@ -0,0 +1,16 @@
name: Delete old workflow runs
on:
schedule:
- cron: '0 0 1 * *'
# Run monthly, at 00:00 on the 1st day of month.
jobs:
del_runs:
runs-on: ubuntu-latest
steps:
- name: Delete workflow runs
uses: GitRML/delete-workflow-runs@main
with:
token: ${{ secrets.AUTH_PAT }}
repository: ${{ github.repository }}
retain_days: 30

View File

@@ -1,10 +1,6 @@
name: Build name: Build
on: on:
workflow_dispatch: workflow_dispatch:
inputs:
version:
description: "Tag version to release"
required: true
push: push:
paths-ignore: paths-ignore:
- "docs/**" - "docs/**"
@@ -14,11 +10,12 @@ on:
- Alpha - Alpha
tags: tags:
- "v*" - "v*"
pull_request: pull_request_target:
branches: branches:
- Alpha - Alpha
concurrency: concurrency:
group: "${{ github.workflow }}-${{ github.ref }}" group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true cancel-in-progress: true
env: env:
@@ -29,52 +26,35 @@ jobs:
strategy: strategy:
matrix: matrix:
jobs: jobs:
- { goos: darwin, goarch: amd64, goamd64: v1, output: amd64-compatible } # old style file name will be removed in next released
- { goos: darwin, goarch: amd64, goamd64: v3, output: amd64 }
- { goos: darwin, goarch: amd64, goamd64: v1, output: amd64-v1 }
- { goos: darwin, goarch: amd64, goamd64: v2, output: amd64-v2 }
- { goos: darwin, goarch: amd64, goamd64: v3, output: amd64-v3 }
- { goos: darwin, goarch: arm64, output: arm64 } - { goos: darwin, goarch: arm64, output: arm64 }
- { goos: darwin, goarch: amd64, goamd64: v1, output: amd64-compatible }
- { goos: darwin, goarch: amd64, goamd64: v3, output: amd64 }
- { goos: linux, goarch: '386', go386: sse2, output: '386', debian: i386, rpm: i386} - { goos: linux, goarch: '386', output: '386' }
- { goos: linux, goarch: '386', go386: softfloat, output: '386-softfloat' } - { goos: linux, goarch: amd64, goamd64: v1, output: amd64-compatible, test: test }
- { goos: linux, goarch: amd64, goamd64: v1, output: amd64-compatible} # old style file name will be removed in next released - { goos: linux, goarch: amd64, goamd64: v3, output: amd64 }
- { goos: linux, goarch: amd64, goamd64: v3, output: amd64, debian: amd64, rpm: x86_64, pacman: x86_64} - { goos: linux, goarch: arm64, output: arm64 }
- { goos: linux, goarch: amd64, goamd64: v1, output: amd64-v1, debian: amd64, rpm: x86_64, pacman: x86_64, test: test } - { goos: linux, goarch: arm, goarm: '7', output: armv7 }
- { goos: linux, goarch: amd64, goamd64: v2, output: amd64-v2, debian: amd64, rpm: x86_64, pacman: x86_64} - { goos: linux, goarch: mips, mips: hardfloat, output: mips-hardfloat }
- { goos: linux, goarch: amd64, goamd64: v3, output: amd64-v3, debian: amd64, rpm: x86_64, pacman: x86_64} - { goos: linux, goarch: mips, mips: softfloat, output: mips-softfloat }
- { goos: linux, goarch: arm64, output: arm64, debian: arm64, rpm: aarch64, pacman: aarch64} - { goos: linux, goarch: mipsle, mips: hardfloat, output: mipsle-hardfloat }
- { goos: linux, goarch: arm, goarm: '5', output: armv5 } - { goos: linux, goarch: mipsle, mips: softfloat, output: mipsle-softfloat }
- { goos: linux, goarch: arm, goarm: '6', output: armv6, debian: armel, rpm: armv6hl}
- { goos: linux, goarch: arm, goarm: '7', output: armv7, debian: armhf, rpm: armv7hl, pacman: armv7hl}
- { goos: linux, goarch: mips, gomips: hardfloat, output: mips-hardfloat }
- { goos: linux, goarch: mips, gomips: softfloat, output: mips-softfloat }
- { goos: linux, goarch: mipsle, gomips: hardfloat, output: mipsle-hardfloat }
- { goos: linux, goarch: mipsle, gomips: softfloat, output: mipsle-softfloat }
- { goos: linux, goarch: mips64, output: mips64 } - { goos: linux, goarch: mips64, output: mips64 }
- { goos: linux, goarch: mips64le, output: mips64le, debian: mips64el, rpm: mips64el } - { goos: linux, goarch: mips64le, output: mips64le }
- { goos: linux, goarch: loong64, output: loong64-abi1, abi: '1', debian: loongarch64, rpm: loongarch64 } - { goos: linux, goarch: loong64, output: loong64-abi1, abi: '1' }
- { goos: linux, goarch: loong64, output: loong64-abi2, abi: '2', debian: loong64, rpm: loong64 } - { goos: linux, goarch: loong64, output: loong64-abi2, abi: '2' }
- { goos: linux, goarch: riscv64, output: riscv64, debian: riscv64, rpm: riscv64 } - { goos: linux, goarch: riscv64, output: riscv64 }
- { goos: linux, goarch: s390x, output: s390x, debian: s390x, rpm: s390x } - { goos: linux, goarch: s390x, output: s390x }
- { goos: linux, goarch: ppc64le, output: ppc64le, debian: ppc64el, rpm: ppc64le }
# Go 1.25 with special patch can work on Windows 7
# https://github.com/MetaCubeX/go/commits/release-branch.go1.25/
- { goos: windows, goarch: '386', output: '386' } - { goos: windows, goarch: '386', output: '386' }
- { goos: windows, goarch: amd64, goamd64: v1, output: amd64-compatible } # old style file name will be removed in next released - { goos: windows, goarch: amd64, goamd64: v1, output: amd64-compatible }
- { goos: windows, goarch: amd64, goamd64: v3, output: amd64 } - { goos: windows, goarch: amd64, goamd64: v3, output: amd64 }
- { goos: windows, goarch: amd64, goamd64: v1, output: amd64-v1 } - { goos: windows, goarch: arm, goarm: '7', output: armv7 }
- { goos: windows, goarch: amd64, goamd64: v2, output: amd64-v2 }
- { goos: windows, goarch: amd64, goamd64: v3, output: amd64-v3 }
- { goos: windows, goarch: arm64, output: arm64 } - { goos: windows, goarch: arm64, output: arm64 }
- { goos: freebsd, goarch: '386', output: '386' } - { goos: freebsd, goarch: '386', output: '386' }
- { goos: freebsd, goarch: amd64, goamd64: v1, output: amd64-compatible } # old style file name will be removed in next released - { goos: freebsd, goarch: amd64, goamd64: v1, output: amd64-compatible }
- { goos: freebsd, goarch: amd64, goamd64: v3, output: amd64 } - { goos: freebsd, goarch: amd64, goamd64: v3, output: amd64 }
- { goos: freebsd, goarch: amd64, goamd64: v1, output: amd64-v1 }
- { goos: freebsd, goarch: amd64, goamd64: v2, output: amd64-v2 }
- { goos: freebsd, goarch: amd64, goamd64: v3, output: amd64-v3 }
- { goos: freebsd, goarch: arm64, output: arm64 } - { goos: freebsd, goarch: arm64, output: arm64 }
- { goos: android, goarch: '386', ndk: i686-linux-android34, output: '386' } - { goos: android, goarch: '386', ndk: i686-linux-android34, output: '386' }
@@ -82,188 +62,72 @@ jobs:
- { goos: android, goarch: arm, ndk: armv7a-linux-androideabi34, output: armv7 } - { goos: android, goarch: arm, ndk: armv7a-linux-androideabi34, output: armv7 }
- { goos: android, goarch: arm64, ndk: aarch64-linux-android34, output: arm64-v8 } - { goos: android, goarch: arm64, ndk: aarch64-linux-android34, output: arm64-v8 }
# Go 1.24 with special patch can work on Windows 7
# https://github.com/MetaCubeX/go/commits/release-branch.go1.24/
- { goos: windows, goarch: '386', output: '386-go124', goversion: '1.24' }
- { goos: windows, goarch: amd64, goamd64: v1, output: amd64-v1-go124, goversion: '1.24' }
- { goos: windows, goarch: amd64, goamd64: v2, output: amd64-v2-go124, goversion: '1.24' }
- { goos: windows, goarch: amd64, goamd64: v3, output: amd64-v3-go124, goversion: '1.24' }
# Go 1.23 with special patch can work on Windows 7
# https://github.com/MetaCubeX/go/commits/release-branch.go1.23/
- { goos: windows, goarch: '386', output: '386-go123', goversion: '1.23' }
- { goos: windows, goarch: amd64, goamd64: v1, output: amd64-v1-go123, goversion: '1.23' }
- { goos: windows, goarch: amd64, goamd64: v2, output: amd64-v2-go123, goversion: '1.23' }
- { goos: windows, goarch: amd64, goamd64: v3, output: amd64-v3-go123, goversion: '1.23' }
# Go 1.22 with special patch can work on Windows 7
# https://github.com/MetaCubeX/go/commits/release-branch.go1.22/
- { goos: windows, goarch: '386', output: '386-go122', goversion: '1.22' }
- { goos: windows, goarch: amd64, goamd64: v1, output: amd64-v1-go122, goversion: '1.22' }
- { goos: windows, goarch: amd64, goamd64: v2, output: amd64-v2-go122, goversion: '1.22' }
- { goos: windows, goarch: amd64, goamd64: v3, output: amd64-v3-go122, goversion: '1.22' }
# Go 1.21 can revert commit `9e4385` to work on Windows 7
# https://github.com/golang/go/issues/64622#issuecomment-1847475161
# (OR we can just use golang1.21.4 which unneeded any patch)
- { goos: windows, goarch: '386', output: '386-go121', goversion: '1.21' }
- { goos: windows, goarch: amd64, goamd64: v1, output: amd64-v1-go121, goversion: '1.21' }
- { goos: windows, goarch: amd64, goamd64: v2, output: amd64-v2-go121, goversion: '1.21' }
- { goos: windows, goarch: amd64, goamd64: v3, output: amd64-v3-go121, goversion: '1.21' }
# Go 1.20 is the last release that will run on any release of Windows 7, 8, Server 2008 and Server 2012. Go 1.21 will require at least Windows 10 or Server 2016. # Go 1.20 is the last release that will run on any release of Windows 7, 8, Server 2008 and Server 2012. Go 1.21 will require at least Windows 10 or Server 2016.
- { goos: windows, goarch: '386', output: '386-go120', goversion: '1.20' } - { goos: windows, goarch: '386', output: '386-go120', goversion: '1.20' }
- { goos: windows, goarch: amd64, goamd64: v1, output: amd64-v1-go120, goversion: '1.20' } - { goos: windows, goarch: amd64, goamd64: v1, output: amd64-compatible-go120, goversion: '1.20' }
- { goos: windows, goarch: amd64, goamd64: v2, output: amd64-v2-go120, goversion: '1.20' } - { goos: windows, goarch: amd64, goamd64: v3, output: amd64-go120, goversion: '1.20' }
- { goos: windows, goarch: amd64, goamd64: v3, output: amd64-v3-go120, goversion: '1.20' }
# Go 1.24 is the last release that will run on macOS 11 Big Sur. Go 1.25 will require macOS 12 Monterey or later.
- { goos: darwin, goarch: arm64, output: arm64-go124, goversion: '1.24' }
- { goos: darwin, goarch: amd64, goamd64: v1, output: amd64-v1-go124, goversion: '1.24' }
- { goos: darwin, goarch: amd64, goamd64: v2, output: amd64-v2-go124, goversion: '1.24' }
- { goos: darwin, goarch: amd64, goamd64: v3, output: amd64-v3-go124, goversion: '1.24' }
# Go 1.22 is the last release that will run on macOS 10.15 Catalina. Go 1.23 will require macOS 11 Big Sur or later.
- { goos: darwin, goarch: arm64, output: arm64-go122, goversion: '1.22' }
- { goos: darwin, goarch: amd64, goamd64: v1, output: amd64-v1-go122, goversion: '1.22' }
- { goos: darwin, goarch: amd64, goamd64: v2, output: amd64-v2-go122, goversion: '1.22' }
- { goos: darwin, goarch: amd64, goamd64: v3, output: amd64-v3-go122, goversion: '1.22' }
# Go 1.20 is the last release that will run on macOS 10.13 High Sierra or 10.14 Mojave. Go 1.21 will require macOS 10.15 Catalina or later. # Go 1.20 is the last release that will run on macOS 10.13 High Sierra or 10.14 Mojave. Go 1.21 will require macOS 10.15 Catalina or later.
- { goos: darwin, goarch: arm64, output: arm64-go120, goversion: '1.20' } - { goos: darwin, goarch: arm64, output: arm64-go120, goversion: '1.20' }
- { goos: darwin, goarch: amd64, goamd64: v1, output: amd64-v1-go120, goversion: '1.20' } - { goos: darwin, goarch: amd64, goamd64: v1, output: amd64-compatible-go120, goversion: '1.20' }
- { goos: darwin, goarch: amd64, goamd64: v2, output: amd64-v2-go120, goversion: '1.20' } - { goos: darwin, goarch: amd64, goamd64: v3, output: amd64-go120, goversion: '1.20' }
- { goos: darwin, goarch: amd64, goamd64: v3, output: amd64-v3-go120, goversion: '1.20' }
# Go 1.23 is the last release that requires Linux kernel version 2.6.32 or later. Go 1.24 will require Linux kernel version 3.2 or later.
- { goos: linux, goarch: '386', output: '386-go123', goversion: '1.23' }
- { goos: linux, goarch: amd64, goamd64: v1, output: amd64-v1-go123, goversion: '1.23', test: test }
- { goos: linux, goarch: amd64, goamd64: v2, output: amd64-v2-go123, goversion: '1.23' }
- { goos: linux, goarch: amd64, goamd64: v3, output: amd64-v3-go123, goversion: '1.23' }
# only for test # only for test
- { goos: linux, goarch: '386', output: '386-go120', goversion: '1.20' } - { goos: linux, goarch: '386', output: '386-go120', goversion: '1.20' }
- { goos: linux, goarch: amd64, goamd64: v1, output: amd64-v1-go120, goversion: '1.20', test: test } - { goos: linux, goarch: amd64, goamd64: v1, output: amd64-compatible-go120, goversion: '1.20', test: test }
- { goos: linux, goarch: amd64, goamd64: v2, output: amd64-v2-go120, goversion: '1.20' } - { goos: linux, goarch: amd64, goamd64: v3, output: amd64-go120, goversion: '1.20' }
- { goos: linux, goarch: amd64, goamd64: v3, output: amd64-v3-go120, goversion: '1.20' }
steps: steps:
- uses: actions/checkout@v5 - uses: actions/checkout@v4
- name: Set up Go - name: Set up Go
if: ${{ matrix.jobs.goversion == '' && matrix.jobs.abi != '1' }} if: ${{ matrix.jobs.goversion == '' && matrix.jobs.goarch != 'loong64' }}
uses: actions/setup-go@v6 uses: actions/setup-go@v5
with: with:
go-version: '1.25' go-version: '1.22'
check-latest: true # Always check for the latest patch release
- name: Set up Go - name: Set up Go
if: ${{ matrix.jobs.goversion != '' && matrix.jobs.abi != '1' }} if: ${{ matrix.jobs.goversion != '' && matrix.jobs.goarch != 'loong64' }}
uses: actions/setup-go@v6 uses: actions/setup-go@v5
with: with:
go-version: ${{ matrix.jobs.goversion }} go-version: ${{ matrix.jobs.goversion }}
check-latest: true # Always check for the latest patch release
- name: Set up Go1.24 loongarch abi1 - name: Set up Go1.21 loongarch abi1
if: ${{ matrix.jobs.goarch == 'loong64' && matrix.jobs.abi == '1' }} if: ${{ matrix.jobs.goarch == 'loong64' && matrix.jobs.abi == '1' }}
run: | run: |
wget -q https://github.com/MetaCubeX/loongarch64-golang/releases/download/1.24.0/go1.24.0.linux-amd64-abi1.tar.gz wget -q https://github.com/xishang0128/loongarch64-golang/releases/download/1.21.5/go1.21.5.linux-amd64-abi1.tar.gz
sudo tar zxf go1.24.0.linux-amd64-abi1.tar.gz -C /usr/local sudo tar zxf go1.21.5.linux-amd64-abi1.tar.gz -C /usr/local
echo "/usr/local/go/bin" >> $GITHUB_PATH echo "/usr/local/go/bin" >> $GITHUB_PATH
# modify from https://github.com/restic/restic/issues/4636#issuecomment-1896455557 - name: Set up Go1.21 loongarch abi2
# this patch file only works on golang1.25.x if: ${{ matrix.jobs.goarch == 'loong64' && matrix.jobs.abi == '2' }}
# that means after golang1.26 release it must be changed
# see: https://github.com/MetaCubeX/go/commits/release-branch.go1.25/
# revert:
# 693def151adff1af707d82d28f55dba81ceb08e1: "crypto/rand,runtime: switch RtlGenRandom for ProcessPrng"
# 7c1157f9544922e96945196b47b95664b1e39108: "net: remove sysSocket fallback for Windows 7"
# 48042aa09c2f878c4faa576948b07fe625c4707a: "syscall: remove Windows 7 console handle workaround"
# a17d959debdb04cd550016a3501dd09d50cd62e7: "runtime: always use LoadLibraryEx to load system libraries"
# f0894a00f4b756d4b9b4078af2e686b359493583: "os: remove 5ms sleep on Windows in (*Process).Wait"
# sepical fix:
# - os.RemoveAll not working on Windows7
- name: Revert Golang1.25 commit for Windows7/8
if: ${{ matrix.jobs.goos == 'windows' && matrix.jobs.goversion == '' }}
run: | run: |
cd $(go env GOROOT) wget -q https://github.com/xishang0128/loongarch64-golang/releases/download/1.21.5/go1.21.5.linux-amd64-abi2.tar.gz
patch --verbose -p 1 < $GITHUB_WORKSPACE/.github/patch/go1.25.patch sudo tar zxf go1.21.5.linux-amd64-abi2.tar.gz -C /usr/local
echo "/usr/local/go/bin" >> $GITHUB_PATH
# modify from https://github.com/restic/restic/issues/4636#issuecomment-1896455557
# this patch file only works on golang1.24.x
# that means after golang1.25 release it must be changed
# see: https://github.com/MetaCubeX/go/commits/release-branch.go1.24/
# revert:
# 693def151adff1af707d82d28f55dba81ceb08e1: "crypto/rand,runtime: switch RtlGenRandom for ProcessPrng"
# 7c1157f9544922e96945196b47b95664b1e39108: "net: remove sysSocket fallback for Windows 7"
# 48042aa09c2f878c4faa576948b07fe625c4707a: "syscall: remove Windows 7 console handle workaround"
# a17d959debdb04cd550016a3501dd09d50cd62e7: "runtime: always use LoadLibraryEx to load system libraries"
- name: Revert Golang1.24 commit for Windows7/8
if: ${{ matrix.jobs.goos == 'windows' && matrix.jobs.goversion == '1.24' }}
run: |
cd $(go env GOROOT)
patch --verbose -p 1 < $GITHUB_WORKSPACE/.github/patch/go1.24.patch
# modify from https://github.com/restic/restic/issues/4636#issuecomment-1896455557
# this patch file only works on golang1.23.x
# that means after golang1.24 release it must be changed
# see: https://github.com/MetaCubeX/go/commits/release-branch.go1.23/
# revert:
# 693def151adff1af707d82d28f55dba81ceb08e1: "crypto/rand,runtime: switch RtlGenRandom for ProcessPrng"
# 7c1157f9544922e96945196b47b95664b1e39108: "net: remove sysSocket fallback for Windows 7"
# 48042aa09c2f878c4faa576948b07fe625c4707a: "syscall: remove Windows 7 console handle workaround"
# a17d959debdb04cd550016a3501dd09d50cd62e7: "runtime: always use LoadLibraryEx to load system libraries"
- name: Revert Golang1.23 commit for Windows7/8
if: ${{ matrix.jobs.goos == 'windows' && matrix.jobs.goversion == '1.23' }}
run: |
cd $(go env GOROOT)
patch --verbose -p 1 < $GITHUB_WORKSPACE/.github/patch/go1.23.patch
# modify from https://github.com/restic/restic/issues/4636#issuecomment-1896455557
# this patch file only works on golang1.22.x
# that means after golang1.23 release it must be changed
# see: https://github.com/MetaCubeX/go/commits/release-branch.go1.22/
# revert:
# 693def151adff1af707d82d28f55dba81ceb08e1: "crypto/rand,runtime: switch RtlGenRandom for ProcessPrng"
# 7c1157f9544922e96945196b47b95664b1e39108: "net: remove sysSocket fallback for Windows 7"
# 48042aa09c2f878c4faa576948b07fe625c4707a: "syscall: remove Windows 7 console handle workaround"
# a17d959debdb04cd550016a3501dd09d50cd62e7: "runtime: always use LoadLibraryEx to load system libraries"
- name: Revert Golang1.22 commit for Windows7/8
if: ${{ matrix.jobs.goos == 'windows' && matrix.jobs.goversion == '1.22' }}
run: |
cd $(go env GOROOT)
patch --verbose -p 1 < $GITHUB_WORKSPACE/.github/patch/go1.22.patch
# modify from https://github.com/restic/restic/issues/4636#issuecomment-1896455557
- name: Revert Golang1.21 commit for Windows7/8
if: ${{ matrix.jobs.goos == 'windows' && matrix.jobs.goversion == '1.21' }}
run: |
cd $(go env GOROOT)
patch --verbose -p 1 < $GITHUB_WORKSPACE/.github/patch/go1.21.patch
- name: Set variables - name: Set variables
run: | if: ${{github.ref_name=='Alpha'}}
VERSION="${GITHUB_REF_NAME,,}-$(git rev-parse --short HEAD)" run: echo "VERSION=alpha-$(git rev-parse --short HEAD)" >> $GITHUB_ENV
VERSION="${VERSION//\//-}" shell: bash
PackageVersion="$(curl -s "https://api.github.com/repos/MetaCubeX/mihomo/releases/latest" | jq -r '.tag_name' | sed 's/v//g' | awk -F '.' '{$NF = $NF + 1; print}' OFS='.').${VERSION/-/.}"
if [ -n "${{ github.event.inputs.version }}" ]; then
VERSION=${{ github.event.inputs.version }}
PackageVersion="${VERSION#v}"
fi
echo "VERSION=${VERSION}" >> $GITHUB_ENV
echo "PackageVersion=${PackageVersion}" >> $GITHUB_ENV
- name: Set variables
if: ${{github.ref_name=='' || github.ref_type=='tag'}}
run: echo "VERSION=$(git describe --tags)" >> $GITHUB_ENV
shell: bash
- name: Set Time Variable
run: |
echo "BUILDTIME=$(date)" >> $GITHUB_ENV echo "BUILDTIME=$(date)" >> $GITHUB_ENV
echo "CGO_ENABLED=0" >> $GITHUB_ENV echo "CGO_ENABLED=0" >> $GITHUB_ENV
echo "BUILDTAG=-extldflags --static" >> $GITHUB_ENV echo "BUILDTAG=-extldflags --static" >> $GITHUB_ENV
echo "GOTOOLCHAIN=local" >> $GITHUB_ENV
- name: Setup NDK - name: Setup NDK
if: ${{ matrix.jobs.goos == 'android' }} if: ${{ matrix.jobs.goos == 'android' }}
uses: nttld/setup-ndk@v1 uses: nttld/setup-ndk@v1
id: setup-ndk id: setup-ndk
with: with:
ndk-version: r29-beta1 ndk-version: r26c
- name: Set NDK path - name: Set NDK path
if: ${{ matrix.jobs.goos == 'android' }} if: ${{ matrix.jobs.goos == 'android' }}
@@ -276,12 +140,10 @@ jobs:
if: ${{ matrix.jobs.test == 'test' }} if: ${{ matrix.jobs.test == 'test' }}
run: | run: |
go test ./... go test ./...
echo "---test with_gvisor---"
go test ./... -tags "with_gvisor" -count=1
- name: Update CA - name: Update UA
run: | run: |
sudo apt-get update && sudo apt-get install ca-certificates sudo apt-get install ca-certificates
sudo update-ca-certificates sudo update-ca-certificates
cp -f /etc/ssl/certs/ca-certificates.crt component/ca/ca-certificates.crt cp -f /etc/ssl/certs/ca-certificates.crt component/ca/ca-certificates.crt
@@ -290,11 +152,10 @@ jobs:
GOOS: ${{matrix.jobs.goos}} GOOS: ${{matrix.jobs.goos}}
GOARCH: ${{matrix.jobs.goarch}} GOARCH: ${{matrix.jobs.goarch}}
GOAMD64: ${{matrix.jobs.goamd64}} GOAMD64: ${{matrix.jobs.goamd64}}
GO386: ${{matrix.jobs.go386}} GOARM: ${{matrix.jobs.arm}}
GOARM: ${{matrix.jobs.goarm}} GOMIPS: ${{matrix.jobs.mips}}
GOMIPS: ${{matrix.jobs.gomips}}
run: | run: |
go env echo $CGO_ENABLED
go build -v -tags "with_gvisor" -trimpath -ldflags "${BUILDTAG} -X 'github.com/metacubex/mihomo/constant.Version=${VERSION}' -X 'github.com/metacubex/mihomo/constant.BuildTime=${BUILDTIME}' -w -s -buildid=" go build -v -tags "with_gvisor" -trimpath -ldflags "${BUILDTAG} -X 'github.com/metacubex/mihomo/constant.Version=${VERSION}' -X 'github.com/metacubex/mihomo/constant.BuildTime=${BUILDTIME}' -w -s -buildid="
if [ "${{matrix.jobs.goos}}" = "windows" ]; then if [ "${{matrix.jobs.goos}}" = "windows" ]; then
cp mihomo.exe mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}.exe cp mihomo.exe mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}.exe
@@ -305,45 +166,62 @@ jobs:
rm mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}} rm mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}
fi fi
- name: Package DEB - name: Create DEB package
if: matrix.jobs.debian != '' if: ${{ matrix.jobs.goos == 'linux' && !contains(matrix.jobs.goarch, 'mips') }}
run: | run: |
set -xeuo pipefail sudo apt-get install dpkg
sudo gem install fpm if [ "${{matrix.jobs.abi}}" = "1" ]; then
cp .github/release/.fpm_systemd .fpm ARCH=loongarch64
else
ARCH=${{matrix.jobs.goarch}}
fi
mkdir -p mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/DEBIAN
mkdir -p mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/usr/bin
mkdir -p mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/etc/mihomo
mkdir -p mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/etc/systemd/system/
mkdir -p mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/usr/share/licenses/mihomo
fpm -t deb \ cp mihomo mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/usr/bin/mihomo
-v "${PackageVersion}" \ cp LICENSE mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/usr/share/licenses/mihomo/
-p "mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}.deb" \ cp .github/mihomo.service mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/etc/systemd/system/
--architecture ${{ matrix.jobs.debian }} \
mihomo=/usr/bin/mihomo
- name: Package RPM cat > mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/etc/mihomo/config.yaml <<EOF
if: matrix.jobs.rpm != '' mixed-port: 7890
external-controller: 127.0.0.1:9090
EOF
cat > mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}/DEBIAN/control <<EOF
Package: mihomo
Version: 1.18.2-${VERSION}
Section:
Priority: extra
Architecture: ${ARCH}
Maintainer: MetaCubeX <none@example.com>
Homepage: https://wiki.metacubex.one/
Description: The universal proxy platform.
EOF
dpkg-deb -Z gzip --build mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}
- name: Convert DEB to RPM
if: ${{ matrix.jobs.goos == 'linux' && !contains(matrix.jobs.goarch, 'mips') }}
run: | run: |
set -xeuo pipefail sudo apt-get install -y alien
sudo gem install fpm alien --to-rpm --scripts mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}.deb
cp .github/release/.fpm_systemd .fpm mv mihomo*.rpm mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}.rpm
fpm -t rpm \ # - name: Convert DEB to PKG
-v "${PackageVersion}" \ # if: ${{ matrix.jobs.goos == 'linux' && !contains(matrix.jobs.goarch, 'mips') && !contains(matrix.jobs.goarch, 'loong64') }}
-p "mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}.rpm" \ # run: |
--architecture ${{ matrix.jobs.rpm }} \ # docker pull archlinux
mihomo=/usr/bin/mihomo # docker run --rm -v ./:/mnt archlinux bash -c "
# pacman -Syu pkgfile base-devel --noconfirm
- name: Package Pacman # curl -L https://github.com/helixarch/debtap/raw/master/debtap > /usr/bin/debtap
if: matrix.jobs.pacman != '' # chmod 755 /usr/bin/debtap
run: | # debtap -u
set -xeuo pipefail # debtap -Q /mnt/mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}.deb
sudo gem install fpm # "
sudo apt-get update && sudo apt-get install -y libarchive-tools # mv mihomo*.pkg.tar.zst mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}.pkg.tar.zst
cp .github/release/.fpm_systemd .fpm
fpm -t pacman \
-v "${PackageVersion}" \
-p "mihomo-${{matrix.jobs.goos}}-${{matrix.jobs.output}}-${VERSION}.pkg.tar.zst" \
--architecture ${{ matrix.jobs.pacman }} \
mihomo=/usr/bin/mihomo
- name: Save version - name: Save version
run: | run: |
@@ -353,19 +231,17 @@ jobs:
- name: Archive production artifacts - name: Archive production artifacts
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: "${{ matrix.jobs.goos }}-${{ matrix.jobs.output }}" name: ${{ matrix.jobs.goos }}-${{ matrix.jobs.output }}
path: | path: |
mihomo*.gz mihomo*.gz
mihomo*.deb mihomo*.deb
mihomo*.rpm mihomo*.rpm
mihomo*.pkg.tar.zst
mihomo*.zip mihomo*.zip
version.txt version.txt
checksums.txt
Upload-Prerelease: Upload-Prerelease:
permissions: write-all permissions: write-all
if: ${{ github.event_name != 'workflow_dispatch' && github.ref_type == 'branch' && !startsWith(github.event_name, 'pull_request') }} if: ${{ github.ref_type == 'branch' && !startsWith(github.event_name, 'pull_request') }}
needs: [build] needs: [build]
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
@@ -375,13 +251,6 @@ jobs:
path: bin/ path: bin/
merge-multiple: true merge-multiple: true
- name: Calculate checksums
run: |
cd bin/
find . -type f -not -name "checksums.*" -not -name "version.txt" | sort | xargs sha256sum > checksums.txt
cat checksums.txt
shell: bash
- name: Delete current release assets - name: Delete current release assets
uses: 8Mi-Tech/delete-release-assets-action@main uses: 8Mi-Tech/delete-release-assets-action@main
with: with:
@@ -423,62 +292,44 @@ jobs:
Upload-Release: Upload-Release:
permissions: write-all permissions: write-all
if: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.version != '' }} if: ${{ github.ref_type=='tag' }}
needs: [build] needs: [build]
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v5 uses: actions/checkout@v4
with: with:
ref: Meta fetch-depth: 0
fetch-depth: '0'
fetch-tags: 'true'
- name: Get tags - name: Get tags
run: | run: |
echo "CURRENTVERSION=${{ github.event.inputs.version }}" >> $GITHUB_ENV echo "CURRENTVERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV
git fetch --tags git fetch --tags
echo "PREVERSION=$(git describe --tags --abbrev=0 HEAD)" >> $GITHUB_ENV echo "PREVERSION=$(git describe --tags --abbrev=0 HEAD^)" >> $GITHUB_ENV
- name: Force push Alpha branch to Meta - name: Generate release notes
run: | run: |
git config --global user.email "github-actions[bot]@users.noreply.github.com" cp ./.github/genReleaseNote.sh ./
git config --global user.name "github-actions[bot]" bash ./genReleaseNote.sh -v ${PREVERSION}...${CURRENTVERSION}
git fetch origin Alpha:Alpha rm ./genReleaseNote.sh
git push origin Alpha:Meta --force
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Tag the commit on Alpha - uses: actions/download-artifact@v4
run: | with:
git checkout Alpha path: bin/
git tag ${{ github.event.inputs.version }} merge-multiple: true
git push origin ${{ github.event.inputs.version }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Generate release notes - name: Display structure of downloaded files
run: | run: ls -R
cp ./.github/genReleaseNote.sh ./ working-directory: bin
bash ./genReleaseNote.sh -v ${PREVERSION}...${CURRENTVERSION}
rm ./genReleaseNote.sh
- uses: actions/download-artifact@v4 - name: Upload Release
with: uses: softprops/action-gh-release@v1
path: bin/ if: ${{ success() }}
merge-multiple: true with:
tag_name: ${{ github.ref_name }}
- name: Display structure of downloaded files files: bin/*
run: ls -R generate_release_notes: true
working-directory: bin body_path: release.md
- name: Upload Release
uses: softprops/action-gh-release@v2
if: ${{ success() }}
with:
tag_name: ${{ github.event.inputs.version }}
files: bin/*
body_path: release.md
Docker: Docker:
if: ${{ !startsWith(github.event_name, 'pull_request') }} if: ${{ !startsWith(github.event_name, 'pull_request') }}
@@ -487,7 +338,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v5 uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
@@ -511,25 +362,10 @@ jobs:
# Extract metadata (tags, labels) for Docker # Extract metadata (tags, labels) for Docker
# https://github.com/docker/metadata-action # https://github.com/docker/metadata-action
- name: Extract Docker metadata - name: Extract Docker metadata
if: ${{ github.event_name != 'workflow_dispatch' }} id: meta
id: meta_alpha
uses: docker/metadata-action@v5 uses: docker/metadata-action@v5
with: with:
images: '${{ env.REGISTRY }}/${{ github.repository }}' images: ${{ env.REGISTRY }}/${{ github.repository }}
# Extract metadata (tags, labels) for Docker
# https://github.com/docker/metadata-action
- name: Extract Docker metadata
if: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.version != '' }}
id: meta_release
uses: docker/metadata-action@v5
with:
images: '${{ env.REGISTRY }}/${{ github.repository }}'
tags: |
${{ github.event.inputs.version }}
flavor: |
latest=true
labels: org.opencontainers.image.version=${{ github.event.inputs.version }}
- name: Show files - name: Show files
run: | run: |
@@ -546,7 +382,7 @@ jobs:
# Build and push Docker image with Buildx (don't push on PR) # Build and push Docker image with Buildx (don't push on PR)
# https://github.com/docker/build-push-action # https://github.com/docker/build-push-action
- name: Build and push Docker image - name: Build and push Docker image
if: ${{ github.event_name != 'workflow_dispatch' }} id: build-and-push
uses: docker/build-push-action@v5 uses: docker/build-push-action@v5
with: with:
context: . context: .
@@ -557,20 +393,5 @@ jobs:
linux/amd64 linux/amd64
linux/arm64 linux/arm64
linux/arm/v7 linux/arm/v7
tags: ${{ steps.meta_alpha.outputs.tags }} tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta_alpha.outputs.labels }} labels: ${{ steps.meta.outputs.labels }}
- name: Build and push Docker image
if: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.version != '' }}
uses: docker/build-push-action@v5
with:
context: .
file: ./Dockerfile
push: ${{ github.event_name != 'pull_request' }}
platforms: |
linux/386
linux/amd64
linux/arm64
linux/arm/v7
tags: ${{ steps.meta_release.outputs.tags }}
labels: ${{ steps.meta_release.outputs.labels }}

View File

@@ -1,74 +0,0 @@
name: Test
on:
push:
paths-ignore:
- "docs/**"
- "README.md"
- ".github/ISSUE_TEMPLATE/**"
branches:
- Alpha
tags:
- "v*"
pull_request:
branches:
- Alpha
jobs:
test:
strategy:
matrix:
os:
- 'ubuntu-latest' # amd64 linux
- 'windows-latest' # amd64 windows
- 'macos-latest' # arm64 macos
- 'ubuntu-24.04-arm' # arm64 linux
- 'macos-15-intel' # amd64 macos
go-version:
- '1.26.0-rc.3'
- '1.25'
- '1.24'
- '1.23'
- '1.22'
- '1.21'
- '1.20'
fail-fast: false
runs-on: ${{ matrix.os }}
defaults:
run:
shell: bash
env:
CGO_ENABLED: 0
GOTOOLCHAIN: local
# Fix mingw trying to be smart and converting paths https://github.com/moby/moby/issues/24029#issuecomment-250412919
MSYS_NO_PATHCONV: true
steps:
- uses: actions/checkout@v5
- name: Setup Go
uses: actions/setup-go@v6
with:
go-version: ${{ matrix.go-version }}
check-latest: true # Always check for the latest patch release
- name: Revert Golang commit for Windows7/8
if: ${{ runner.os == 'Windows' && matrix.go-version != '1.20' && matrix.go-version != '1.26.0-rc.3' }}
run: |
cd $(go env GOROOT)
patch --verbose -p 1 < $GITHUB_WORKSPACE/.github/patch/go${{matrix.go-version}}.patch
- name: Revert Golang commit for Windows7/8
if: ${{ runner.os == 'Windows' && matrix.go-version == '1.26.0-rc.3' }}
run: |
cd $(go env GOROOT)
patch --verbose -p 1 < $GITHUB_WORKSPACE/.github/patch/go1.26.patch
- name: Remove inbound test for macOS
if: ${{ runner.os == 'macOS' }}
run: |
rm -rf listener/inbound/*_test.go
- name: Test
run: go test ./... -v -count=1
- name: Test with tag with_gvisor
run: go test ./... -v -count=1 -tags "with_gvisor"

View File

@@ -10,6 +10,9 @@ on:
- Alpha - Alpha
tags: tags:
- "v*" - "v*"
pull_request_target:
branches:
- Alpha
jobs: jobs:
# Send "core-updated" to MetaCubeX/ClashMetaForAndroid to trigger update-dependencies # Send "core-updated" to MetaCubeX/ClashMetaForAndroid to trigger update-dependencies

View File

@@ -4,16 +4,16 @@ RUN echo "I'm building for $TARGETPLATFORM"
RUN apk add --no-cache gzip && \ RUN apk add --no-cache gzip && \
mkdir /mihomo-config && \ mkdir /mihomo-config && \
wget -O /mihomo-config/geoip.metadb https://github.com/MetaCubeX/meta-rules-dat/releases/download/latest/geoip.metadb && \ wget -O /mihomo-config/geoip.metadb https://fastly.jsdelivr.net/gh/MetaCubeX/meta-rules-dat@release/geoip.metadb && \
wget -O /mihomo-config/geosite.dat https://github.com/MetaCubeX/meta-rules-dat/releases/download/latest/geosite.dat && \ wget -O /mihomo-config/geosite.dat https://fastly.jsdelivr.net/gh/MetaCubeX/meta-rules-dat@release/geosite.dat && \
wget -O /mihomo-config/geoip.dat https://github.com/MetaCubeX/meta-rules-dat/releases/download/latest/geoip.dat wget -O /mihomo-config/geoip.dat https://fastly.jsdelivr.net/gh/MetaCubeX/meta-rules-dat@release/geoip.dat
COPY docker/file-name.sh /mihomo/file-name.sh COPY docker/file-name.sh /mihomo/file-name.sh
WORKDIR /mihomo WORKDIR /mihomo
COPY bin/ bin/ COPY bin/ bin/
RUN FILE_NAME=`sh file-name.sh` && echo $FILE_NAME && \ RUN FILE_NAME=`sh file-name.sh` && echo $FILE_NAME && \
FILE_NAME=`ls bin/ | egrep "$FILE_NAME.gz"|awk NR==1` && echo $FILE_NAME && \ FILE_NAME=`ls bin/ | egrep "$FILE_NAME.gz"|awk NR==1` && echo $FILE_NAME && \
mv bin/$FILE_NAME mihomo.gz && gzip -d mihomo.gz && chmod +x mihomo && echo "$FILE_NAME" > /mihomo-config/test mv bin/$FILE_NAME mihomo.gz && gzip -d mihomo.gz && echo "$FILE_NAME" > /mihomo-config/test
FROM alpine:latest FROM alpine:latest
LABEL org.opencontainers.image.source="https://github.com/MetaCubeX/mihomo" LABEL org.opencontainers.image.source="https://github.com/MetaCubeX/mihomo"
@@ -23,4 +23,5 @@ VOLUME ["/root/.config/mihomo/"]
COPY --from=builder /mihomo-config/ /root/.config/mihomo/ COPY --from=builder /mihomo-config/ /root/.config/mihomo/
COPY --from=builder /mihomo/mihomo /mihomo COPY --from=builder /mihomo/mihomo /mihomo
RUN chmod +x /mihomo
ENTRYPOINT [ "/mihomo" ] ENTRYPOINT [ "/mihomo" ]

View File

@@ -17,19 +17,11 @@ GOBUILD=CGO_ENABLED=0 go build -tags with_gvisor -trimpath -ldflags '-X "github.
-w -s -buildid=' -w -s -buildid='
PLATFORM_LIST = \ PLATFORM_LIST = \
darwin-386 \
darwin-amd64-compatible \ darwin-amd64-compatible \
darwin-amd64 \ darwin-amd64 \
darwin-amd64-v1 \
darwin-amd64-v2 \
darwin-amd64-v3 \
darwin-arm64 \ darwin-arm64 \
linux-386 \
linux-amd64-compatible \ linux-amd64-compatible \
linux-amd64 \ linux-amd64 \
linux-amd64-v1 \
linux-amd64-v2 \
linux-amd64-v3 \
linux-armv5 \ linux-armv5 \
linux-armv6 \ linux-armv6 \
linux-armv7 \ linux-armv7 \
@@ -51,61 +43,37 @@ WINDOWS_ARCH_LIST = \
windows-386 \ windows-386 \
windows-amd64-compatible \ windows-amd64-compatible \
windows-amd64 \ windows-amd64 \
windows-amd64-v1 \
windows-amd64-v2 \
windows-amd64-v3 \
windows-arm64 \ windows-arm64 \
windows-arm32v7 windows-arm32v7
all:linux-amd64-v3 linux-arm64\ all:linux-amd64 linux-arm64\
darwin-amd64-v3 darwin-arm64\ darwin-amd64 darwin-arm64\
windows-amd64-v3 windows-arm64\ windows-amd64 windows-arm64\
darwin-all: darwin-amd64-v3 darwin-arm64 darwin-all: darwin-amd64 darwin-arm64
docker: docker:
GOAMD64=v1 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ GOAMD64=v1 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@
darwin-386:
GOARCH=386 GOOS=darwin $(GOBUILD) -o $(BINDIR)/$(NAME)-$@
darwin-amd64-compatible:
GOARCH=amd64 GOOS=darwin GOAMD64=v1 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@
darwin-amd64: darwin-amd64:
GOARCH=amd64 GOOS=darwin GOAMD64=v3 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ GOARCH=amd64 GOOS=darwin GOAMD64=v3 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@
darwin-amd64-v1: darwin-amd64-compatible:
GOARCH=amd64 GOOS=darwin GOAMD64=v1 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ GOARCH=amd64 GOOS=darwin GOAMD64=v1 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@
darwin-amd64-v2:
GOARCH=amd64 GOOS=darwin GOAMD64=v2 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@
darwin-amd64-v3:
GOARCH=amd64 GOOS=darwin GOAMD64=v3 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@
darwin-arm64: darwin-arm64:
GOARCH=arm64 GOOS=darwin $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ GOARCH=arm64 GOOS=darwin $(GOBUILD) -o $(BINDIR)/$(NAME)-$@
linux-386: linux-386:
GOARCH=386 GOOS=linux $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ GOARCH=386 GOOS=linux $(GOBUILD) -o $(BINDIR)/$(NAME)-$@
linux-amd64-compatible:
GOARCH=amd64 GOOS=linux GOAMD64=v1 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@
linux-amd64: linux-amd64:
GOARCH=amd64 GOOS=linux GOAMD64=v3 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ GOARCH=amd64 GOOS=linux GOAMD64=v3 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@
linux-amd64-v1: linux-amd64-compatible:
GOARCH=amd64 GOOS=linux GOAMD64=v1 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ GOARCH=amd64 GOOS=linux GOAMD64=v1 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@
linux-amd64-v2:
GOARCH=amd64 GOOS=linux GOAMD64=v2 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@
linux-amd64-v3:
GOARCH=amd64 GOOS=linux GOAMD64=v3 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@
linux-arm64: linux-arm64:
GOARCH=arm64 GOOS=linux $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ GOARCH=arm64 GOOS=linux $(GOBUILD) -o $(BINDIR)/$(NAME)-$@
@@ -157,21 +125,12 @@ freebsd-arm64:
windows-386: windows-386:
GOARCH=386 GOOS=windows $(GOBUILD) -o $(BINDIR)/$(NAME)-$@.exe GOARCH=386 GOOS=windows $(GOBUILD) -o $(BINDIR)/$(NAME)-$@.exe
windows-amd64-compatible:
GOARCH=amd64 GOOS=windows GOAMD64=v1 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@.exe
windows-amd64: windows-amd64:
GOARCH=amd64 GOOS=windows GOAMD64=v3 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@.exe GOARCH=amd64 GOOS=windows GOAMD64=v3 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@.exe
windows-amd64-v1: windows-amd64-compatible:
GOARCH=amd64 GOOS=windows GOAMD64=v1 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@.exe GOARCH=amd64 GOOS=windows GOAMD64=v1 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@.exe
windows-amd64-v2:
GOARCH=amd64 GOOS=windows GOAMD64=v2 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@.exe
windows-amd64-v3:
GOARCH=amd64 GOOS=windows GOAMD64=v3 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@.exe
windows-arm64: windows-arm64:
GOARCH=arm64 GOOS=windows $(GOBUILD) -o $(BINDIR)/$(NAME)-$@.exe GOARCH=arm64 GOOS=windows $(GOBUILD) -o $(BINDIR)/$(NAME)-$@.exe
@@ -204,3 +163,7 @@ clean:
CLANG ?= clang-14 CLANG ?= clang-14
CFLAGS := -O2 -g -Wall -Werror $(CFLAGS) CFLAGS := -O2 -g -Wall -Werror $(CFLAGS)
ebpf: export BPF_CLANG := $(CLANG)
ebpf: export BPF_CFLAGS := $(CFLAGS)
ebpf:
cd component/ebpf/ && go generate ./...

View File

@@ -98,4 +98,4 @@ API.
This software is released under the GPL-3.0 license. This software is released under the GPL-3.0 license.
**In addition, any downstream projects not affiliated with `MetaCubeX` shall not contain the word `mihomo` in their names.** [![FOSSA Status](https://app.fossa.io/api/projects/git%2Bgithub.com%2FMetaCubeX%2Fmihomo.svg?type=large)](https://app.fossa.io/projects/git%2Bgithub.com%2FMetaCubeX%2Fmihomo?ref=badge_large)

View File

@@ -5,19 +5,18 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net" "net"
"net/http"
"net/netip"
"net/url" "net/url"
"strings" "strconv"
"time" "time"
"github.com/metacubex/mihomo/common/atomic" "github.com/metacubex/mihomo/common/atomic"
"github.com/metacubex/mihomo/common/queue" "github.com/metacubex/mihomo/common/queue"
"github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/common/xsync" "github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/ca"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log" "github.com/puzpuzpuz/xsync/v3"
"github.com/metacubex/http"
) )
var UnifiedDelay = atomic.NewBool(false) var UnifiedDelay = atomic.NewBool(false)
@@ -35,12 +34,7 @@ type Proxy struct {
C.ProxyAdapter C.ProxyAdapter
alive atomic.Bool alive atomic.Bool
history *queue.Queue[C.DelayHistory] history *queue.Queue[C.DelayHistory]
extra xsync.Map[string, *internalProxyState] extra *xsync.MapOf[string, *internalProxyState]
}
// Adapter implements C.Proxy
func (p *Proxy) Adapter() C.ProxyAdapter {
return p.ProxyAdapter
} }
// AliveForTestUrl implements C.Proxy // AliveForTestUrl implements C.Proxy
@@ -52,15 +46,29 @@ func (p *Proxy) AliveForTestUrl(url string) bool {
return p.alive.Load() return p.alive.Load()
} }
// Dial implements C.Proxy
func (p *Proxy) Dial(metadata *C.Metadata) (C.Conn, error) {
ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTCPTimeout)
defer cancel()
return p.DialContext(ctx, metadata)
}
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (p *Proxy) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { func (p *Proxy) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
conn, err := p.ProxyAdapter.DialContext(ctx, metadata) conn, err := p.ProxyAdapter.DialContext(ctx, metadata, opts...)
return conn, err return conn, err
} }
// DialUDP implements C.ProxyAdapter
func (p *Proxy) DialUDP(metadata *C.Metadata) (C.PacketConn, error) {
ctx, cancel := context.WithTimeout(context.Background(), C.DefaultUDPTimeout)
defer cancel()
return p.ListenPacketContext(ctx, metadata)
}
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (p *Proxy) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { func (p *Proxy) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
pc, err := p.ProxyAdapter.ListenPacketContext(ctx, metadata) pc, err := p.ProxyAdapter.ListenPacketContext(ctx, metadata, opts...)
return pc, err return pc, err
} }
@@ -146,18 +154,8 @@ func (p *Proxy) MarshalJSON() ([]byte, error) {
mapping["alive"] = p.alive.Load() mapping["alive"] = p.alive.Load()
mapping["name"] = p.Name() mapping["name"] = p.Name()
mapping["udp"] = p.SupportUDP() mapping["udp"] = p.SupportUDP()
mapping["uot"] = p.SupportUOT() mapping["xudp"] = p.SupportXUDP()
mapping["tfo"] = p.SupportTFO()
proxyInfo := p.ProxyInfo()
mapping["xudp"] = proxyInfo.XUDP
mapping["tfo"] = proxyInfo.TFO
mapping["mptcp"] = proxyInfo.MPTCP
mapping["smux"] = proxyInfo.SMUX
mapping["interface"] = proxyInfo.Interface
mapping["routing-mark"] = proxyInfo.RoutingMark
mapping["provider-name"] = proxyInfo.ProviderName
mapping["dialer-proxy"] = proxyInfo.DialerProxy
return json.Marshal(mapping) return json.Marshal(mapping)
} }
@@ -179,12 +177,14 @@ func (p *Proxy) URLTest(ctx context.Context, url string, expectedStatus utils.In
p.history.Pop() p.history.Pop()
} }
state, _ := p.extra.LoadOrStoreFn(url, func() *internalProxyState { state, ok := p.extra.Load(url)
return &internalProxyState{ if !ok {
state = &internalProxyState{
history: queue.New[C.DelayHistory](defaultHistoriesNum), history: queue.New[C.DelayHistory](defaultHistoriesNum),
alive: atomic.NewBool(true), alive: atomic.NewBool(true),
} }
}) p.extra.Store(url, state)
}
if !satisfied { if !satisfied {
record.Delay = 0 record.Delay = 0
@@ -221,11 +221,6 @@ func (p *Proxy) URLTest(ctx context.Context, url string, expectedStatus utils.In
} }
req = req.WithContext(ctx) req = req.WithContext(ctx)
tlsConfig, err := ca.GetTLSConfig(ca.Option{})
if err != nil {
return
}
transport := &http.Transport{ transport := &http.Transport{
DialContext: func(context.Context, string, string) (net.Conn, error) { DialContext: func(context.Context, string, string) (net.Conn, error) {
return instance, nil return instance, nil
@@ -235,7 +230,6 @@ func (p *Proxy) URLTest(ctx context.Context, url string, expectedStatus utils.In
IdleConnTimeout: 90 * time.Second, IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second, TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second, ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: tlsConfig,
} }
client := http.Client{ client := http.Client{
@@ -258,18 +252,10 @@ func (p *Proxy) URLTest(ctx context.Context, url string, expectedStatus utils.In
if unifiedDelay { if unifiedDelay {
second := time.Now() second := time.Now()
var ignoredErr error resp, err = client.Do(req)
var secondResp *http.Response if err == nil {
secondResp, ignoredErr = client.Do(req)
if ignoredErr == nil {
resp = secondResp
_ = resp.Body.Close() _ = resp.Body.Close()
start = second start = second
} else {
if strings.HasPrefix(url, "http://") {
log.Errorln("%s failed to get the second response from %s: %v", p.Name(), url, ignoredErr)
log.Warnln("It is recommended to use HTTPS for provider.health-check.url and group.url to ensure better reliability. Due to some proxy providers hijacking test addresses and not being compatible with repeated HEAD requests, using HTTP may result in failed tests.")
}
} }
} }
@@ -277,13 +263,12 @@ func (p *Proxy) URLTest(ctx context.Context, url string, expectedStatus utils.In
t = uint16(time.Since(start) / time.Millisecond) t = uint16(time.Since(start) / time.Millisecond)
return return
} }
func NewProxy(adapter C.ProxyAdapter) *Proxy { func NewProxy(adapter C.ProxyAdapter) *Proxy {
return &Proxy{ return &Proxy{
ProxyAdapter: adapter, ProxyAdapter: adapter,
history: queue.New[C.DelayHistory](defaultHistoriesNum), history: queue.New[C.DelayHistory](defaultHistoriesNum),
alive: atomic.NewBool(true), alive: atomic.NewBool(true),
} extra: xsync.NewMapOf[string, *internalProxyState]()}
} }
func urlToMetadata(rawURL string) (addr C.Metadata, err error) { func urlToMetadata(rawURL string) (addr C.Metadata, err error) {
@@ -304,7 +289,15 @@ func urlToMetadata(rawURL string) (addr C.Metadata, err error) {
return return
} }
} }
uintPort, err := strconv.ParseUint(port, 10, 16)
if err != nil {
return
}
err = addr.SetRemoteAddress(net.JoinHostPort(u.Hostname(), port)) addr = C.Metadata{
Host: u.Hostname(),
DstIP: netip.Addr{},
DstPort: uint16(uintPort),
}
return return
} }

View File

@@ -47,7 +47,7 @@ func WithDstAddr(addr net.Addr) Addition {
func WithSrcAddr(addr net.Addr) Addition { func WithSrcAddr(addr net.Addr) Addition {
return func(metadata *C.Metadata) { return func(metadata *C.Metadata) {
m := C.Metadata{} m := C.Metadata{}
if err := m.SetRemoteAddr(addr); err == nil { if err := m.SetRemoteAddr(addr);err ==nil{
metadata.SrcIP = m.DstIP metadata.SrcIP = m.DstIP
metadata.SrcPort = m.DstPort metadata.SrcPort = m.DstPort
} }
@@ -57,7 +57,7 @@ func WithSrcAddr(addr net.Addr) Addition {
func WithInAddr(addr net.Addr) Addition { func WithInAddr(addr net.Addr) Addition {
return func(metadata *C.Metadata) { return func(metadata *C.Metadata) {
m := C.Metadata{} m := C.Metadata{}
if err := m.SetRemoteAddr(addr); err == nil { if err := m.SetRemoteAddr(addr);err ==nil{
metadata.InIP = m.DstIP metadata.InIP = m.DstIP
metadata.InPort = m.DstPort metadata.InPort = m.DstPort
} }
@@ -69,5 +69,3 @@ func WithDSCP(dscp uint8) Addition {
metadata.DSCP = dscp metadata.DSCP = dscp
} }
} }
func Placeholder(metadata *C.Metadata) {}

View File

@@ -34,5 +34,12 @@ func SkipAuthRemoteAddress(addr string) bool {
} }
func skipAuth(addr netip.Addr) bool { func skipAuth(addr netip.Addr) bool {
return prefixesContains(skipAuthPrefixes, addr) if addr.IsValid() {
for _, prefix := range skipAuthPrefixes {
if prefix.Contains(addr.Unmap()) {
return true
}
}
}
return false
} }

View File

@@ -14,7 +14,7 @@ func NewHTTP(target socks5.Addr, srcConn net.Conn, conn net.Conn, additions ...A
metadata.Type = C.HTTP metadata.Type = C.HTTP
metadata.RawSrcAddr = srcConn.RemoteAddr() metadata.RawSrcAddr = srcConn.RemoteAddr()
metadata.RawDstAddr = srcConn.LocalAddr() metadata.RawDstAddr = srcConn.LocalAddr()
ApplyAdditions(metadata, WithSrcAddr(srcConn.RemoteAddr()), WithInAddr(srcConn.LocalAddr())) ApplyAdditions(metadata, WithSrcAddr(srcConn.RemoteAddr()), WithInAddr(conn.LocalAddr()))
ApplyAdditions(metadata, additions...) ApplyAdditions(metadata, additions...)
return conn, metadata return conn, metadata
} }

View File

@@ -2,18 +2,15 @@ package inbound
import ( import (
"net" "net"
"net/http"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/http"
) )
// NewHTTPS receive CONNECT request and return ConnContext // NewHTTPS receive CONNECT request and return ConnContext
func NewHTTPS(request *http.Request, conn net.Conn, additions ...Addition) (net.Conn, *C.Metadata) { func NewHTTPS(request *http.Request, conn net.Conn, additions ...Addition) (net.Conn, *C.Metadata) {
metadata := parseHTTPAddr(request) metadata := parseHTTPAddr(request)
metadata.Type = C.HTTPS metadata.Type = C.HTTPS
metadata.RawSrcAddr = conn.RemoteAddr()
metadata.RawDstAddr = conn.LocalAddr()
ApplyAdditions(metadata, WithSrcAddr(conn.RemoteAddr()), WithInAddr(conn.LocalAddr())) ApplyAdditions(metadata, WithSrcAddr(conn.RemoteAddr()), WithInAddr(conn.LocalAddr()))
ApplyAdditions(metadata, additions...) ApplyAdditions(metadata, additions...)
return conn, metadata return conn, metadata

View File

@@ -31,17 +31,27 @@ func IsRemoteAddrDisAllowed(addr net.Addr) bool {
if err := m.SetRemoteAddr(addr); err != nil { if err := m.SetRemoteAddr(addr); err != nil {
return false return false
} }
ipAddr := m.AddrPort().Addr() return isAllowed(m.AddrPort().Addr().Unmap()) && !isDisAllowed(m.AddrPort().Addr().Unmap())
if ipAddr.IsValid() { }
return isAllowed(ipAddr) && !isDisAllowed(ipAddr)
func isAllowed(addr netip.Addr) bool {
if addr.IsValid() {
for _, prefix := range lanAllowedIPs {
if prefix.Contains(addr) {
return true
}
}
} }
return false return false
} }
func isAllowed(addr netip.Addr) bool {
return prefixesContains(lanAllowedIPs, addr)
}
func isDisAllowed(addr netip.Addr) bool { func isDisAllowed(addr netip.Addr) bool {
return prefixesContains(lanDisAllowedIPs, addr) if addr.IsValid() {
for _, prefix := range lanDisAllowedIPs {
if prefix.Contains(addr) {
return true
}
}
}
return false
} }

View File

@@ -2,13 +2,7 @@ package inbound
import ( import (
"context" "context"
"fmt"
"net" "net"
"net/netip"
"sync"
"github.com/metacubex/mihomo/component/keepalive"
"github.com/metacubex/mihomo/component/mptcp"
"github.com/metacubex/tfo-go" "github.com/metacubex/tfo-go"
) )
@@ -17,91 +11,20 @@ var (
lc = tfo.ListenConfig{ lc = tfo.ListenConfig{
DisableTFO: true, DisableTFO: true,
} }
mutex sync.RWMutex
) )
func SetTfo(open bool) { func SetTfo(open bool) {
mutex.Lock()
defer mutex.Unlock()
lc.DisableTFO = !open lc.DisableTFO = !open
} }
func Tfo() bool {
mutex.RLock()
defer mutex.RUnlock()
return !lc.DisableTFO
}
func SetMPTCP(open bool) { func SetMPTCP(open bool) {
mutex.Lock() setMultiPathTCP(&lc.ListenConfig, open)
defer mutex.Unlock()
mptcp.SetNetListenConfig(&lc.ListenConfig, open)
}
func MPTCP() bool {
mutex.RLock()
defer mutex.RUnlock()
return mptcp.GetNetListenConfig(&lc.ListenConfig)
}
func preResolve(network, address string) (string, error) {
switch network { // like net.Resolver.internetAddrList but filter domain to avoid call net.Resolver.lookupIPAddr
case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6", "ip", "ip4", "ip6":
if host, port, err := net.SplitHostPort(address); err == nil {
switch host {
case "localhost":
switch network {
case "tcp6", "udp6", "ip6":
address = net.JoinHostPort("::1", port)
default:
address = net.JoinHostPort("127.0.0.1", port)
}
case "": // internetAddrList can handle this special case
break
default:
if _, err := netip.ParseAddr(host); err != nil { // not ip
return "", fmt.Errorf("invalid network address: %s", address)
}
}
}
}
return address, nil
} }
func ListenContext(ctx context.Context, network, address string) (net.Listener, error) { func ListenContext(ctx context.Context, network, address string) (net.Listener, error) {
address, err := preResolve(network, address)
if err != nil {
return nil, err
}
mutex.RLock()
defer mutex.RUnlock()
return lc.Listen(ctx, network, address) return lc.Listen(ctx, network, address)
} }
func Listen(network, address string) (net.Listener, error) { func Listen(network, address string) (net.Listener, error) {
return ListenContext(context.Background(), network, address) return ListenContext(context.Background(), network, address)
} }
func ListenPacketContext(ctx context.Context, network, address string) (net.PacketConn, error) {
address, err := preResolve(network, address)
if err != nil {
return nil, err
}
mutex.RLock()
defer mutex.RUnlock()
return lc.ListenPacket(ctx, network, address)
}
func ListenPacket(network, address string) (net.PacketConn, error) {
return ListenPacketContext(context.Background(), network, address)
}
func init() {
keepalive.SetDisableKeepAliveCallback.Register(func(b bool) {
mutex.Lock()
defer mutex.Unlock()
keepalive.SetNetListenConfig(&lc.ListenConfig)
})
}

View File

@@ -1,14 +0,0 @@
//go:build !windows
package inbound
import (
"net"
"os"
)
const SupportNamedPipe = false
func ListenNamedPipe(path string) (net.Listener, error) {
return nil, os.ErrInvalid
}

View File

@@ -1,32 +0,0 @@
package inbound
import (
"net"
"os"
"github.com/metacubex/wireguard-go/ipc/namedpipe"
"golang.org/x/sys/windows"
)
const SupportNamedPipe = true
// windowsSDDL is the Security Descriptor set on the namedpipe.
// It provides read/write access to all users and the local system.
const windowsSDDL = "D:PAI(A;OICI;GWGR;;;BU)(A;OICI;GWGR;;;SY)"
func ListenNamedPipe(path string) (net.Listener, error) {
sddl := os.Getenv("LISTEN_NAMEDPIPE_SDDL")
if sddl == "" {
sddl = windowsSDDL
}
securityDescriptor, err := windows.SecurityDescriptorFromString(sddl)
if err != nil {
return nil, err
}
namedpipeLC := namedpipe.ListenConfig{
SecurityDescriptor: securityDescriptor,
InputBufferSize: 256 * 1024,
OutputBufferSize: 256 * 1024,
}
return namedpipeLC.Listen(path)
}

View File

@@ -0,0 +1,10 @@
//go:build !go1.21
package inbound
import "net"
const multipathTCPAvailable = false
func setMultiPathTCP(listenConfig *net.ListenConfig, open bool) {
}

View File

@@ -0,0 +1,11 @@
//go:build go1.21
package inbound
import "net"
const multipathTCPAvailable = true
func setMultiPathTCP(listenConfig *net.ListenConfig, open bool) {
listenConfig.SetMultipathTCP(open)
}

View File

@@ -2,13 +2,14 @@ package inbound
import ( import (
"net" "net"
"net/http"
"net/netip" "net/netip"
"strconv"
"strings" "strings"
"github.com/metacubex/mihomo/common/nnip"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/socks5" "github.com/metacubex/mihomo/transport/socks5"
"github.com/metacubex/http"
) )
func parseSocksAddr(target socks5.Addr) *C.Metadata { func parseSocksAddr(target socks5.Addr) *C.Metadata {
@@ -20,13 +21,13 @@ func parseSocksAddr(target socks5.Addr) *C.Metadata {
metadata.Host = strings.TrimRight(string(target[2:2+target[1]]), ".") metadata.Host = strings.TrimRight(string(target[2:2+target[1]]), ".")
metadata.DstPort = uint16((int(target[2+target[1]]) << 8) | int(target[2+target[1]+1])) metadata.DstPort = uint16((int(target[2+target[1]]) << 8) | int(target[2+target[1]+1]))
case socks5.AtypIPv4: case socks5.AtypIPv4:
metadata.DstIP, _ = netip.AddrFromSlice(target[1 : 1+net.IPv4len]) metadata.DstIP = nnip.IpToAddr(net.IP(target[1 : 1+net.IPv4len]))
metadata.DstPort = uint16((int(target[1+net.IPv4len]) << 8) | int(target[1+net.IPv4len+1])) metadata.DstPort = uint16((int(target[1+net.IPv4len]) << 8) | int(target[1+net.IPv4len+1]))
case socks5.AtypIPv6: case socks5.AtypIPv6:
metadata.DstIP, _ = netip.AddrFromSlice(target[1 : 1+net.IPv6len]) ip6, _ := netip.AddrFromSlice(target[1 : 1+net.IPv6len])
metadata.DstIP = ip6.Unmap()
metadata.DstPort = uint16((int(target[1+net.IPv6len]) << 8) | int(target[1+net.IPv6len+1])) metadata.DstPort = uint16((int(target[1+net.IPv6len]) << 8) | int(target[1+net.IPv6len+1]))
} }
metadata.DstIP = metadata.DstIP.Unmap()
return metadata return metadata
} }
@@ -41,23 +42,22 @@ func parseHTTPAddr(request *http.Request) *C.Metadata {
// trim FQDN (#737) // trim FQDN (#737)
host = strings.TrimRight(host, ".") host = strings.TrimRight(host, ".")
metadata := &C.Metadata{} var uint16Port uint16
_ = metadata.SetRemoteAddress(net.JoinHostPort(host, port)) if port, err := strconv.ParseUint(port, 10, 16); err == nil {
uint16Port = uint16(port)
}
metadata := &C.Metadata{
NetWork: C.TCP,
Host: host,
DstIP: netip.Addr{},
DstPort: uint16Port,
}
ip, err := netip.ParseAddr(host)
if err == nil {
metadata.DstIP = ip
}
return metadata return metadata
} }
func prefixesContains(prefixes []netip.Prefix, addr netip.Addr) bool {
if len(prefixes) == 0 {
return false
}
if !addr.IsValid() {
return false
}
addr = addr.Unmap().WithZone("") // netip.Prefix.Contains returns false if ip has an IPv6 zone
for _, prefix := range prefixes {
if prefix.Contains(addr) {
return true
}
}
return false
}

View File

@@ -1,137 +0,0 @@
package outbound
import (
"context"
"net"
"strconv"
"time"
N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/anytls"
"github.com/metacubex/mihomo/transport/vmess"
M "github.com/metacubex/sing/common/metadata"
"github.com/metacubex/sing/common/uot"
)
type AnyTLS struct {
*Base
client *anytls.Client
option *AnyTLSOption
}
type AnyTLSOption struct {
BasicOption
Name string `proxy:"name"`
Server string `proxy:"server"`
Port int `proxy:"port"`
Password string `proxy:"password"`
ALPN []string `proxy:"alpn,omitempty"`
SNI string `proxy:"sni,omitempty"`
ECHOpts ECHOptions `proxy:"ech-opts,omitempty"`
ClientFingerprint string `proxy:"client-fingerprint,omitempty"`
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
Fingerprint string `proxy:"fingerprint,omitempty"`
Certificate string `proxy:"certificate,omitempty"`
PrivateKey string `proxy:"private-key,omitempty"`
UDP bool `proxy:"udp,omitempty"`
IdleSessionCheckInterval int `proxy:"idle-session-check-interval,omitempty"`
IdleSessionTimeout int `proxy:"idle-session-timeout,omitempty"`
MinIdleSession int `proxy:"min-idle-session,omitempty"`
}
func (t *AnyTLS) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) {
c, err := t.client.CreateProxy(ctx, M.ParseSocksaddrHostPort(metadata.String(), metadata.DstPort))
if err != nil {
return nil, err
}
return NewConn(c, t), nil
}
func (t *AnyTLS) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) {
if err = t.ResolveUDP(ctx, metadata); err != nil {
return nil, err
}
// create tcp
c, err := t.client.CreateProxy(ctx, uot.RequestDestination(2))
if err != nil {
return nil, err
}
// create uot on tcp
destination := M.SocksaddrFromNet(metadata.UDPAddr())
return newPacketConn(N.NewThreadSafePacketConn(uot.NewLazyConn(c, uot.Request{Destination: destination})), t), nil
}
// SupportUOT implements C.ProxyAdapter
func (t *AnyTLS) SupportUOT() bool {
return true
}
// ProxyInfo implements C.ProxyAdapter
func (t *AnyTLS) ProxyInfo() C.ProxyInfo {
info := t.Base.ProxyInfo()
info.DialerProxy = t.option.DialerProxy
return info
}
// Close implements C.ProxyAdapter
func (t *AnyTLS) Close() error {
return t.client.Close()
}
func NewAnyTLS(option AnyTLSOption) (*AnyTLS, error) {
addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port))
outbound := &AnyTLS{
Base: &Base{
name: option.Name,
addr: addr,
tp: C.AnyTLS,
pdName: option.ProviderName,
udp: option.UDP,
tfo: option.TFO,
mpTcp: option.MPTCP,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: option.IPVersion,
},
option: &option,
}
outbound.dialer = option.NewDialer(outbound.DialOptions())
singDialer := proxydialer.NewSingDialer(outbound.dialer)
tOption := anytls.ClientConfig{
Password: option.Password,
Server: M.ParseSocksaddrHostPort(option.Server, uint16(option.Port)),
Dialer: singDialer,
IdleSessionCheckInterval: time.Duration(option.IdleSessionCheckInterval) * time.Second,
IdleSessionTimeout: time.Duration(option.IdleSessionTimeout) * time.Second,
MinIdleSession: option.MinIdleSession,
}
echConfig, err := option.ECHOpts.Parse()
if err != nil {
return nil, err
}
tlsConfig := &vmess.TLSConfig{
Host: option.SNI,
SkipCertVerify: option.SkipCertVerify,
NextProtos: option.ALPN,
FingerPrint: option.Fingerprint,
Certificate: option.Certificate,
PrivateKey: option.PrivateKey,
ClientFingerprint: option.ClientFingerprint,
ECH: echConfig,
}
if tlsConfig.Host == "" {
tlsConfig.Host = option.Server
}
tOption.TLSConfig = tlsConfig
client := anytls.NewClient(context.TODO(), tOption)
outbound.client = client
return outbound, nil
}

View File

@@ -3,41 +3,28 @@ package outbound
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"net" "net"
"runtime" "strings"
"sync"
"syscall" "syscall"
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
"github.com/metacubex/mihomo/component/resolver"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log"
) )
type ProxyAdapter interface {
C.ProxyAdapter
DialOptions() []dialer.Option
ResolveUDP(ctx context.Context, metadata *C.Metadata) error
}
type Base struct { type Base struct {
name string name string
addr string addr string
iface string
tp C.AdapterType tp C.AdapterType
pdName string
udp bool udp bool
xudp bool xudp bool
tfo bool tfo bool
mpTcp bool mpTcp bool
iface string
rmark int rmark int
prefer C.DNSPrefer
dialer C.Dialer
id string id string
prefer C.DNSPrefer
} }
// Name implements C.ProxyAdapter // Name implements C.ProxyAdapter
@@ -59,15 +46,35 @@ func (b *Base) Type() C.AdapterType {
return b.tp return b.tp
} }
func (b *Base) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { // StreamConnContext implements C.ProxyAdapter
func (b *Base) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) {
return c, C.ErrNotSupport
}
func (b *Base) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
return nil, C.ErrNotSupport
}
// DialContextWithDialer implements C.ProxyAdapter
func (b *Base) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
return nil, C.ErrNotSupport return nil, C.ErrNotSupport
} }
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (b *Base) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { func (b *Base) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
return nil, C.ErrNotSupport return nil, C.ErrNotSupport
} }
// ListenPacketWithDialer implements C.ProxyAdapter
func (b *Base) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) {
return nil, C.ErrNotSupport
}
// SupportWithDialer implements C.ProxyAdapter
func (b *Base) SupportWithDialer() C.NetWork {
return C.InvalidNet
}
// SupportUOT implements C.ProxyAdapter // SupportUOT implements C.ProxyAdapter
func (b *Base) SupportUOT() bool { func (b *Base) SupportUOT() bool {
return false return false
@@ -78,16 +85,14 @@ func (b *Base) SupportUDP() bool {
return b.udp return b.udp
} }
// ProxyInfo implements C.ProxyAdapter // SupportXUDP implements C.ProxyAdapter
func (b *Base) ProxyInfo() (info C.ProxyInfo) { func (b *Base) SupportXUDP() bool {
info.XUDP = b.xudp return b.xudp
info.TFO = b.tfo }
info.MPTCP = b.mpTcp
info.SMUX = false // SupportTFO implements C.ProxyAdapter
info.Interface = b.iface func (b *Base) SupportTFO() bool {
info.RoutingMark = b.rmark return b.tfo
info.ProviderName = b.pdName
return
} }
// IsL3Protocol implements C.ProxyAdapter // IsL3Protocol implements C.ProxyAdapter
@@ -114,7 +119,7 @@ func (b *Base) Unwrap(metadata *C.Metadata, touch bool) C.Proxy {
} }
// DialOptions return []dialer.Option from struct // DialOptions return []dialer.Option from struct
func (b *Base) DialOptions() (opts []dialer.Option) { func (b *Base) DialOptions(opts ...dialer.Option) []dialer.Option {
if b.iface != "" { if b.iface != "" {
opts = append(opts, dialer.WithInterface(b.iface)) opts = append(opts, dialer.WithInterface(b.iface))
} }
@@ -146,46 +151,13 @@ func (b *Base) DialOptions() (opts []dialer.Option) {
return opts return opts
} }
func (b *Base) ResolveUDP(ctx context.Context, metadata *C.Metadata) error {
if !metadata.Resolved() {
ip, err := resolver.ResolveIP(ctx, metadata.Host)
if err != nil {
return fmt.Errorf("can't resolve ip: %w", err)
}
metadata.DstIP = ip
}
return nil
}
func (b *Base) Close() error {
return nil
}
type BasicOption struct { type BasicOption struct {
TFO bool `proxy:"tfo,omitempty"` TFO bool `proxy:"tfo,omitempty" group:"tfo,omitempty"`
MPTCP bool `proxy:"mptcp,omitempty"` MPTCP bool `proxy:"mptcp,omitempty" group:"mptcp,omitempty"`
Interface string `proxy:"interface-name,omitempty"` Interface string `proxy:"interface-name,omitempty" group:"interface-name,omitempty"`
RoutingMark int `proxy:"routing-mark,omitempty"` RoutingMark int `proxy:"routing-mark,omitempty" group:"routing-mark,omitempty"`
IPVersion C.DNSPrefer `proxy:"ip-version,omitempty"` IPVersion string `proxy:"ip-version,omitempty" group:"ip-version,omitempty"`
DialerProxy string `proxy:"dialer-proxy,omitempty"` // don't apply this option into groups, but can set a group name in a proxy DialerProxy string `proxy:"dialer-proxy,omitempty"` // don't apply this option into groups, but can set a group name in a proxy
//
// The following parameters are used internally, assign value by the structure decoder are disallowed
//
DialerForAPI C.Dialer `proxy:"-"` // the dialer used for API usage has higher priority than all the above configurations.
ProviderName string `proxy:"-"`
}
func (b *BasicOption) NewDialer(opts []dialer.Option) C.Dialer {
cDialer := b.DialerForAPI
if cDialer == nil {
if b.DialerProxy != "" {
cDialer = proxydialer.NewByName(b.DialerProxy)
} else {
cDialer = dialer.NewDialer(opts...)
}
}
return cDialer
} }
type BaseOption struct { type BaseOption struct {
@@ -218,22 +190,12 @@ func NewBase(opt BaseOption) *Base {
type conn struct { type conn struct {
N.ExtendedConn N.ExtendedConn
chain C.Chain chain C.Chain
pdChain C.Chain actualRemoteDestination string
adapterAddr string
} }
func (c *conn) RemoteDestination() string { func (c *conn) RemoteDestination() string {
if remoteAddr := c.RemoteAddr(); remoteAddr != nil { return c.actualRemoteDestination
m := C.Metadata{}
if err := m.SetRemoteAddr(remoteAddr); err == nil {
if m.Valid() {
return m.String()
}
}
}
host, _, _ := net.SplitHostPort(c.adapterAddr)
return host
} }
// Chains implements C.Connection // Chains implements C.Connection
@@ -241,15 +203,9 @@ func (c *conn) Chains() C.Chain {
return c.chain return c.chain
} }
// ProviderChains implements C.Connection
func (c *conn) ProviderChains() C.Chain {
return c.pdChain
}
// AppendToChains implements C.Connection // AppendToChains implements C.Connection
func (c *conn) AppendToChains(a C.ProxyAdapter) { func (c *conn) AppendToChains(a C.ProxyAdapter) {
c.chain = append(c.chain, a.Name()) c.chain = append(c.chain, a.Name())
c.pdChain = append(c.pdChain, a.ProxyInfo().ProviderName)
} }
func (c *conn) Upstream() any { func (c *conn) Upstream() any {
@@ -264,36 +220,23 @@ func (c *conn) ReaderReplaceable() bool {
return true return true
} }
func (c *conn) AddRef(ref any) {
c.ExtendedConn = N.NewRefConn(c.ExtendedConn, ref) // add ref for autoCloseProxyAdapter
}
func NewConn(c net.Conn, a C.ProxyAdapter) C.Conn { func NewConn(c net.Conn, a C.ProxyAdapter) C.Conn {
if _, ok := c.(syscall.Conn); !ok { // exclusion system conn like *net.TCPConn if _, ok := c.(syscall.Conn); !ok { // exclusion system conn like *net.TCPConn
c = N.NewDeadlineConn(c) // most conn from outbound can't handle readDeadline correctly c = N.NewDeadlineConn(c) // most conn from outbound can't handle readDeadline correctly
} }
cc := &conn{N.NewExtendedConn(c), nil, nil, a.Addr()} return &conn{N.NewExtendedConn(c), []string{a.Name()}, parseRemoteDestination(a.Addr())}
cc.AppendToChains(a)
return cc
} }
type packetConn struct { type packetConn struct {
N.EnhancePacketConn N.EnhancePacketConn
chain C.Chain chain C.Chain
pdChain C.Chain adapterName string
adapterName string connID string
connID string actualRemoteDestination string
adapterAddr string
resolveUDP func(ctx context.Context, metadata *C.Metadata) error
}
func (c *packetConn) ResolveUDP(ctx context.Context, metadata *C.Metadata) error {
return c.resolveUDP(ctx, metadata)
} }
func (c *packetConn) RemoteDestination() string { func (c *packetConn) RemoteDestination() string {
host, _, _ := net.SplitHostPort(c.adapterAddr) return c.actualRemoteDestination
return host
} }
// Chains implements C.Connection // Chains implements C.Connection
@@ -301,15 +244,9 @@ func (c *packetConn) Chains() C.Chain {
return c.chain return c.chain
} }
// ProviderChains implements C.Connection
func (c *packetConn) ProviderChains() C.Chain {
return c.pdChain
}
// AppendToChains implements C.Connection // AppendToChains implements C.Connection
func (c *packetConn) AppendToChains(a C.ProxyAdapter) { func (c *packetConn) AppendToChains(a C.ProxyAdapter) {
c.chain = append(c.chain, a.Name()) c.chain = append(c.chain, a.Name())
c.pdChain = append(c.pdChain, a.ProxyInfo().ProviderName)
} }
func (c *packetConn) LocalAddr() net.Addr { func (c *packetConn) LocalAddr() net.Addr {
@@ -329,66 +266,22 @@ func (c *packetConn) ReaderReplaceable() bool {
return true return true
} }
func (c *packetConn) AddRef(ref any) { func newPacketConn(pc net.PacketConn, a C.ProxyAdapter) C.PacketConn {
c.EnhancePacketConn = N.NewRefPacketConn(c.EnhancePacketConn, ref) // add ref for autoCloseProxyAdapter
}
func newPacketConn(pc net.PacketConn, a ProxyAdapter) C.PacketConn {
epc := N.NewEnhancePacketConn(pc) epc := N.NewEnhancePacketConn(pc)
if _, ok := pc.(syscall.Conn); !ok { // exclusion system conn like *net.UDPConn if _, ok := pc.(syscall.Conn); !ok { // exclusion system conn like *net.UDPConn
epc = N.NewDeadlineEnhancePacketConn(epc) // most conn from outbound can't handle readDeadline correctly epc = N.NewDeadlineEnhancePacketConn(epc) // most conn from outbound can't handle readDeadline correctly
} }
cpc := &packetConn{epc, nil, nil, a.Name(), utils.NewUUIDV4().String(), a.Addr(), a.ResolveUDP} return &packetConn{epc, []string{a.Name()}, a.Name(), utils.NewUUIDV4().String(), parseRemoteDestination(a.Addr())}
cpc.AppendToChains(a)
return cpc
} }
type AddRef interface { func parseRemoteDestination(addr string) string {
AddRef(ref any) if dst, _, err := net.SplitHostPort(addr); err == nil {
} return dst
} else {
type autoCloseProxyAdapter struct { if addrError, ok := err.(*net.AddrError); ok && strings.Contains(addrError.Err, "missing port") {
ProxyAdapter return dst
closeOnce sync.Once } else {
closeErr error return ""
} }
func (p *autoCloseProxyAdapter) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) {
c, err := p.ProxyAdapter.DialContext(ctx, metadata)
if err != nil {
return nil, err
} }
if c, ok := c.(AddRef); ok {
c.AddRef(p)
}
return c, nil
}
func (p *autoCloseProxyAdapter) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) {
pc, err := p.ProxyAdapter.ListenPacketContext(ctx, metadata)
if err != nil {
return nil, err
}
if pc, ok := pc.(AddRef); ok {
pc.AddRef(p)
}
return pc, nil
}
func (p *autoCloseProxyAdapter) Close() error {
p.closeOnce.Do(func() {
log.Debugln("Closing outdated proxy [%s]", p.Name())
runtime.SetFinalizer(p, nil)
p.closeErr = p.ProxyAdapter.Close()
})
return p.closeErr
}
func NewAutoCloseProxyAdapter(adapter ProxyAdapter) ProxyAdapter {
proxy := &autoCloseProxyAdapter{
ProxyAdapter: adapter,
}
// auto close ProxyAdapter
runtime.SetFinalizer(proxy, (*autoCloseProxyAdapter).Close)
return proxy
} }

View File

@@ -2,17 +2,19 @@ package outbound
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net/netip"
N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/loopback"
"github.com/metacubex/mihomo/component/resolver" "github.com/metacubex/mihomo/component/resolver"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
) )
type Direct struct { type Direct struct {
*Base *Base
loopBack *loopback.Detector loopBack *loopBackDetector
} }
type DirectOption struct { type DirectOption struct {
@@ -21,63 +23,52 @@ type DirectOption struct {
} }
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
if err := d.loopBack.CheckConn(metadata); err != nil { if d.loopBack.CheckConn(metadata.SourceAddrPort()) {
return nil, err return nil, fmt.Errorf("reject loopback connection to: %s", metadata.RemoteAddress())
} }
opts := d.DialOptions() opts = append(opts, dialer.WithResolver(resolver.DefaultResolver))
opts = append(opts, dialer.WithResolver(resolver.DirectHostResolver)) c, err := dialer.DialContext(ctx, "tcp", metadata.RemoteAddress(), d.Base.DialOptions(opts...)...)
c, err := dialer.DialContext(ctx, "tcp", metadata.RemoteAddress(), opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
N.TCPKeepAlive(c)
return d.loopBack.NewConn(NewConn(c, d)), nil return d.loopBack.NewConn(NewConn(c, d)), nil
} }
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (d *Direct) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { func (d *Direct) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
if err := d.loopBack.CheckPacketConn(metadata); err != nil { if d.loopBack.CheckPacketConn(metadata.SourceAddrPort()) {
return nil, err return nil, fmt.Errorf("reject loopback connection to: %s", metadata.RemoteAddress())
} }
if err := d.ResolveUDP(ctx, metadata); err != nil { // net.UDPConn.WriteTo only working with *net.UDPAddr, so we need a net.UDPAddr
return nil, err if !metadata.Resolved() {
ip, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, resolver.DefaultResolver)
if err != nil {
return nil, errors.New("can't resolve ip")
}
metadata.DstIP = ip
} }
pc, err := dialer.NewDialer(d.DialOptions()...).ListenPacket(ctx, "udp", "", metadata.AddrPort()) pc, err := dialer.NewDialer(d.Base.DialOptions(opts...)...).ListenPacket(ctx, "udp", "", netip.AddrPortFrom(metadata.DstIP, metadata.DstPort))
if err != nil { if err != nil {
return nil, err return nil, err
} }
return d.loopBack.NewPacketConn(newPacketConn(pc, d)), nil return d.loopBack.NewPacketConn(newPacketConn(pc, d)), nil
} }
func (d *Direct) ResolveUDP(ctx context.Context, metadata *C.Metadata) error {
if (!metadata.Resolved() || resolver.DirectHostResolver != resolver.DefaultResolver) && metadata.Host != "" {
ip, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, resolver.DirectHostResolver)
if err != nil {
return fmt.Errorf("can't resolve ip: %w", err)
}
metadata.DstIP = ip
}
return nil
}
func (d *Direct) IsL3Protocol(metadata *C.Metadata) bool {
return true // tell DNSDialer don't send domain to DialContext, avoid lookback to DefaultResolver
}
func NewDirectWithOption(option DirectOption) *Direct { func NewDirectWithOption(option DirectOption) *Direct {
return &Direct{ return &Direct{
Base: &Base{ Base: &Base{
name: option.Name, name: option.Name,
tp: C.Direct, tp: C.Direct,
pdName: option.ProviderName,
udp: true, udp: true,
tfo: option.TFO, tfo: option.TFO,
mpTcp: option.MPTCP, mpTcp: option.MPTCP,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: option.IPVersion, prefer: C.NewDNSPrefer(option.IPVersion),
}, },
loopBack: loopback.NewDetector(), loopBack: newLoopBackDetector(),
} }
} }
@@ -89,7 +80,7 @@ func NewDirect() *Direct {
udp: true, udp: true,
prefer: C.DualStack, prefer: C.DualStack,
}, },
loopBack: loopback.NewDetector(), loopBack: newLoopBackDetector(),
} }
} }
@@ -101,6 +92,6 @@ func NewCompatible() *Direct {
udp: true, udp: true,
prefer: C.DualStack, prefer: C.DualStack,
}, },
loopBack: loopback.NewDetector(), loopBack: newLoopBackDetector(),
} }
} }

View File

@@ -0,0 +1,68 @@
package outbound
import (
"net/netip"
"github.com/metacubex/mihomo/common/callback"
C "github.com/metacubex/mihomo/constant"
"github.com/puzpuzpuz/xsync/v3"
)
type loopBackDetector struct {
connMap *xsync.MapOf[netip.AddrPort, struct{}]
packetConnMap *xsync.MapOf[netip.AddrPort, struct{}]
}
func newLoopBackDetector() *loopBackDetector {
return &loopBackDetector{
connMap: xsync.NewMapOf[netip.AddrPort, struct{}](),
packetConnMap: xsync.NewMapOf[netip.AddrPort, struct{}](),
}
}
func (l *loopBackDetector) NewConn(conn C.Conn) C.Conn {
metadata := C.Metadata{}
if metadata.SetRemoteAddr(conn.LocalAddr()) != nil {
return conn
}
connAddr := metadata.AddrPort()
if !connAddr.IsValid() {
return conn
}
l.connMap.Store(connAddr, struct{}{})
return callback.NewCloseCallbackConn(conn, func() {
l.connMap.Delete(connAddr)
})
}
func (l *loopBackDetector) NewPacketConn(conn C.PacketConn) C.PacketConn {
metadata := C.Metadata{}
if metadata.SetRemoteAddr(conn.LocalAddr()) != nil {
return conn
}
connAddr := metadata.AddrPort()
if !connAddr.IsValid() {
return conn
}
l.packetConnMap.Store(connAddr, struct{}{})
return callback.NewCloseCallbackPacketConn(conn, func() {
l.packetConnMap.Delete(connAddr)
})
}
func (l *loopBackDetector) CheckConn(connAddr netip.AddrPort) bool {
if !connAddr.IsValid() {
return false
}
_, ok := l.connMap.Load(connAddr)
return ok
}
func (l *loopBackDetector) CheckPacketConn(connAddr netip.AddrPort) bool {
if !connAddr.IsValid() {
return false
}
_, ok := l.packetConnMap.Load(connAddr)
return ok
}

View File

@@ -3,11 +3,11 @@ package outbound
import ( import (
"context" "context"
"net" "net"
"net/netip"
"time" "time"
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/pool" "github.com/metacubex/mihomo/common/pool"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/resolver" "github.com/metacubex/mihomo/component/resolver"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/log"
@@ -23,18 +23,15 @@ type DnsOption struct {
} }
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (d *Dns) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { func (d *Dns) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
left, right := N.Pipe() left, right := N.Pipe()
go resolver.RelayDnsConn(context.Background(), right, 0) go resolver.RelayDnsConn(context.Background(), right, 0)
return NewConn(left, d), nil return NewConn(left, d), nil
} }
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (d *Dns) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { func (d *Dns) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
log.Debugln("[DNS] hijack udp:%s from %s", metadata.RemoteAddress(), metadata.SourceAddrPort()) log.Debugln("[DNS] hijack udp:%s from %s", metadata.RemoteAddress(), metadata.SourceAddrPort())
if err := d.ResolveUDP(ctx, metadata); err != nil {
return nil, err
}
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@@ -45,13 +42,6 @@ func (d *Dns) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.
}, d), nil }, d), nil
} }
func (d *Dns) ResolveUDP(ctx context.Context, metadata *C.Metadata) error {
if !metadata.Resolved() {
metadata.DstIP = netip.AddrFrom4([4]byte{127, 0, 0, 2})
}
return nil
}
type dnsPacket struct { type dnsPacket struct {
data []byte data []byte
put func() put func()
@@ -99,14 +89,14 @@ func (d *dnsPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return len(p), nil return len(p), nil
} }
ctx, cancel := context.WithTimeout(d.ctx, resolver.DefaultDnsRelayTimeout)
defer cancel()
buf := pool.Get(resolver.SafeDnsPacketSize) buf := pool.Get(resolver.SafeDnsPacketSize)
put := func() { _ = pool.Put(buf) } put := func() { _ = pool.Put(buf) }
copy(buf, p) // avoid p be changed after WriteTo returned copy(buf, p) // avoid p be changed after WriteTo returned
go func() { // don't block the WriteTo function go func() { // don't block the WriteTo function
ctx, cancel := context.WithTimeout(d.ctx, resolver.DefaultDnsRelayTimeout)
defer cancel()
buf, err = resolver.RelayDnsPacket(ctx, buf[:len(p)], buf) buf, err = resolver.RelayDnsPacket(ctx, buf[:len(p)], buf)
if err != nil { if err != nil {
put() put()
@@ -158,13 +148,12 @@ func NewDnsWithOption(option DnsOption) *Dns {
Base: &Base{ Base: &Base{
name: option.Name, name: option.Name,
tp: C.Dns, tp: C.Dns,
pdName: option.ProviderName,
udp: true, udp: true,
tfo: option.TFO, tfo: option.TFO,
mpTcp: option.MPTCP, mpTcp: option.MPTCP,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: option.IPVersion, prefer: C.NewDNSPrefer(option.IPVersion),
}, },
} }
} }

View File

@@ -1,41 +0,0 @@
package outbound
import (
"context"
"encoding/base64"
"fmt"
"github.com/metacubex/mihomo/component/ech"
"github.com/metacubex/mihomo/component/resolver"
)
type ECHOptions struct {
Enable bool `proxy:"enable,omitempty" obfs:"enable,omitempty"`
Config string `proxy:"config,omitempty" obfs:"config,omitempty"`
QueryServerName string `proxy:"query-server-name,omitempty" obfs:"query-server-name,omitempty"`
}
func (o ECHOptions) Parse() (*ech.Config, error) {
if !o.Enable {
return nil, nil
}
echConfig := &ech.Config{}
if o.Config != "" {
list, err := base64.StdEncoding.DecodeString(o.Config)
if err != nil {
return nil, fmt.Errorf("base64 decode ech config string failed: %v", err)
}
echConfig.GetEncryptedClientHelloConfigList = func(ctx context.Context, serverName string) ([]byte, error) {
return list, nil
}
} else {
echConfig.GetEncryptedClientHelloConfigList = func(ctx context.Context, serverName string) ([]byte, error) {
if o.QueryServerName != "" { // overrides the domain name used for ECH HTTPS record queries
serverName = o.QueryServerName
}
return resolver.ResolveECHWithResolver(ctx, serverName, resolver.ProxyServerHostResolver)
}
}
return echConfig, nil
}

View File

@@ -3,18 +3,21 @@ package outbound
import ( import (
"bufio" "bufio"
"context" "context"
"crypto/tls"
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"net/http"
"strconv" "strconv"
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/ca" "github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/http"
"github.com/metacubex/tls"
) )
type Http struct { type Http struct {
@@ -36,8 +39,6 @@ type HttpOption struct {
SNI string `proxy:"sni,omitempty"` SNI string `proxy:"sni,omitempty"`
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
Fingerprint string `proxy:"fingerprint,omitempty"` Fingerprint string `proxy:"fingerprint,omitempty"`
Certificate string `proxy:"certificate,omitempty"`
PrivateKey string `proxy:"private-key,omitempty"`
Headers map[string]string `proxy:"headers,omitempty"` Headers map[string]string `proxy:"headers,omitempty"`
} }
@@ -52,18 +53,30 @@ func (h *Http) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Me
} }
} }
if err := h.shakeHandContext(ctx, c, metadata); err != nil { if err := h.shakeHand(metadata, c); err != nil {
return nil, err return nil, err
} }
return c, nil return c, nil
} }
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (h *Http) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) { func (h *Http) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
c, err := h.dialer.DialContext(ctx, "tcp", h.addr) return h.DialContextWithDialer(ctx, dialer.NewDialer(h.Base.DialOptions(opts...)...), metadata)
}
// DialContextWithDialer implements C.ProxyAdapter
func (h *Http) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
if len(h.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(h.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
c, err := dialer.DialContext(ctx, "tcp", h.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", h.addr, err) return nil, fmt.Errorf("%s connect error: %w", h.addr, err)
} }
N.TCPKeepAlive(c)
defer func(c net.Conn) { defer func(c net.Conn) {
safeConnClose(c, err) safeConnClose(c, err)
@@ -77,19 +90,12 @@ func (h *Http) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn,
return NewConn(c, h), nil return NewConn(c, h), nil
} }
// ProxyInfo implements C.ProxyAdapter // SupportWithDialer implements C.ProxyAdapter
func (h *Http) ProxyInfo() C.ProxyInfo { func (h *Http) SupportWithDialer() C.NetWork {
info := h.Base.ProxyInfo() return C.TCP
info.DialerProxy = h.option.DialerProxy
return info
} }
func (h *Http) shakeHandContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (err error) { func (h *Http) shakeHand(metadata *C.Metadata, rw io.ReadWriter) error {
if ctx.Done() != nil {
done := N.SetupContextForConn(ctx, c)
defer done(&err)
}
addr := metadata.RemoteAddress() addr := metadata.RemoteAddress()
HeaderString := "CONNECT " + addr + " HTTP/1.1\r\n" HeaderString := "CONNECT " + addr + " HTTP/1.1\r\n"
tempHeaders := map[string]string{ tempHeaders := map[string]string{
@@ -113,13 +119,13 @@ func (h *Http) shakeHandContext(ctx context.Context, c net.Conn, metadata *C.Met
HeaderString += "\r\n" HeaderString += "\r\n"
_, err = c.Write([]byte(HeaderString)) _, err := rw.Write([]byte(HeaderString))
if err != nil { if err != nil {
return err return err
} }
resp, err := http.ReadResponse(bufio.NewReader(c), nil) resp, err := http.ReadResponse(bufio.NewReader(rw), nil)
if err != nil { if err != nil {
return err return err
@@ -152,37 +158,29 @@ func NewHttp(option HttpOption) (*Http, error) {
sni = option.SNI sni = option.SNI
} }
var err error var err error
tlsConfig, err = ca.GetTLSConfig(ca.Option{ tlsConfig, err = ca.GetSpecifiedFingerprintTLSConfig(&tls.Config{
TLSConfig: &tls.Config{ InsecureSkipVerify: option.SkipCertVerify,
InsecureSkipVerify: option.SkipCertVerify, ServerName: sni,
ServerName: sni, }, option.Fingerprint)
},
Fingerprint: option.Fingerprint,
Certificate: option.Certificate,
PrivateKey: option.PrivateKey,
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
outbound := &Http{ return &Http{
Base: &Base{ Base: &Base{
name: option.Name, name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)), addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.Http, tp: C.Http,
pdName: option.ProviderName,
tfo: option.TFO, tfo: option.TFO,
mpTcp: option.MPTCP, mpTcp: option.MPTCP,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: option.IPVersion, prefer: C.NewDNSPrefer(option.IPVersion),
}, },
user: option.UserName, user: option.UserName,
pass: option.Password, pass: option.Password,
tlsConfig: tlsConfig, tlsConfig: tlsConfig,
option: &option, option: &option,
} }, nil
outbound.dialer = option.NewDialer(outbound.DialOptions())
return outbound, nil
} }

View File

@@ -2,6 +2,7 @@ package outbound
import ( import (
"context" "context"
"crypto/tls"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"net" "net"
@@ -9,9 +10,13 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/metacubex/quic-go"
"github.com/metacubex/quic-go/congestion"
M "github.com/sagernet/sing/common/metadata"
"github.com/metacubex/mihomo/component/ca" "github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/ech" "github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/log"
hyCongestion "github.com/metacubex/mihomo/transport/hysteria/congestion" hyCongestion "github.com/metacubex/mihomo/transport/hysteria/congestion"
@@ -20,12 +25,6 @@ import (
"github.com/metacubex/mihomo/transport/hysteria/pmtud_fix" "github.com/metacubex/mihomo/transport/hysteria/pmtud_fix"
"github.com/metacubex/mihomo/transport/hysteria/transport" "github.com/metacubex/mihomo/transport/hysteria/transport"
"github.com/metacubex/mihomo/transport/hysteria/utils" "github.com/metacubex/mihomo/transport/hysteria/utils"
"github.com/metacubex/tls"
"github.com/metacubex/quic-go"
"github.com/metacubex/quic-go/congestion"
M "github.com/metacubex/sing/common/metadata"
) )
const ( const (
@@ -44,13 +43,10 @@ type Hysteria struct {
option *HysteriaOption option *HysteriaOption
client *core.Client client *core.Client
tlsConfig *tls.Config
echConfig *ech.Config
} }
func (h *Hysteria) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { func (h *Hysteria) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
tcpConn, err := h.client.DialTCP(metadata.String(), metadata.DstPort, h.genHdc(ctx)) tcpConn, err := h.client.DialTCP(metadata.String(), metadata.DstPort, h.genHdc(ctx, opts...))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -58,72 +54,61 @@ func (h *Hysteria) DialContext(ctx context.Context, metadata *C.Metadata) (C.Con
return NewConn(tcpConn, h), nil return NewConn(tcpConn, h), nil
} }
func (h *Hysteria) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { func (h *Hysteria) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
if err := h.ResolveUDP(ctx, metadata); err != nil { udpConn, err := h.client.DialUDP(h.genHdc(ctx, opts...))
return nil, err
}
udpConn, err := h.client.DialUDP(h.genHdc(ctx))
if err != nil { if err != nil {
return nil, err return nil, err
} }
return newPacketConn(&hyPacketConn{udpConn}, h), nil return newPacketConn(&hyPacketConn{udpConn}, h), nil
} }
func (h *Hysteria) genHdc(ctx context.Context) utils.PacketDialer { func (h *Hysteria) genHdc(ctx context.Context, opts ...dialer.Option) utils.PacketDialer {
return &hyDialerWithContext{ return &hyDialerWithContext{
ctx: context.Background(), ctx: context.Background(),
hyDialer: func(network string, rAddr net.Addr) (net.PacketConn, error) { hyDialer: func(network string) (net.PacketConn, error) {
rAddrPort, _ := netip.ParseAddrPort(rAddr.String()) var err error
return h.dialer.ListenPacket(ctx, network, "", rAddrPort) var cDialer C.Dialer = dialer.NewDialer(h.Base.DialOptions(opts...)...)
if len(h.option.DialerProxy) > 0 {
cDialer, err = proxydialer.NewByName(h.option.DialerProxy, cDialer)
if err != nil {
return nil, err
}
}
rAddrPort, _ := netip.ParseAddrPort(h.Addr())
return cDialer.ListenPacket(ctx, network, "", rAddrPort)
}, },
remoteAddr: func(addr string) (net.Addr, error) { remoteAddr: func(addr string) (net.Addr, error) {
udpAddr, err := resolveUDPAddr(ctx, "udp", addr, h.prefer) return resolveUDPAddrWithPrefer(ctx, "udp", addr, h.prefer)
if err != nil {
return nil, err
}
err = h.echConfig.ClientHandle(ctx, h.tlsConfig)
if err != nil {
return nil, err
}
return udpAddr, nil
}, },
} }
} }
// ProxyInfo implements C.ProxyAdapter
func (h *Hysteria) ProxyInfo() C.ProxyInfo {
info := h.Base.ProxyInfo()
info.DialerProxy = h.option.DialerProxy
return info
}
type HysteriaOption struct { type HysteriaOption struct {
BasicOption BasicOption
Name string `proxy:"name"` Name string `proxy:"name"`
Server string `proxy:"server"` Server string `proxy:"server"`
Port int `proxy:"port,omitempty"` Port int `proxy:"port,omitempty"`
Ports string `proxy:"ports,omitempty"` Ports string `proxy:"ports,omitempty"`
Protocol string `proxy:"protocol,omitempty"` Protocol string `proxy:"protocol,omitempty"`
ObfsProtocol string `proxy:"obfs-protocol,omitempty"` // compatible with Stash ObfsProtocol string `proxy:"obfs-protocol,omitempty"` // compatible with Stash
Up string `proxy:"up"` Up string `proxy:"up"`
UpSpeed int `proxy:"up-speed,omitempty"` // compatible with Stash UpSpeed int `proxy:"up-speed,omitempty"` // compatible with Stash
Down string `proxy:"down"` Down string `proxy:"down"`
DownSpeed int `proxy:"down-speed,omitempty"` // compatible with Stash DownSpeed int `proxy:"down-speed,omitempty"` // compatible with Stash
Auth string `proxy:"auth,omitempty"` Auth string `proxy:"auth,omitempty"`
AuthString string `proxy:"auth-str,omitempty"` AuthString string `proxy:"auth-str,omitempty"`
Obfs string `proxy:"obfs,omitempty"` Obfs string `proxy:"obfs,omitempty"`
SNI string `proxy:"sni,omitempty"` SNI string `proxy:"sni,omitempty"`
ECHOpts ECHOptions `proxy:"ech-opts,omitempty"` SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` Fingerprint string `proxy:"fingerprint,omitempty"`
Fingerprint string `proxy:"fingerprint,omitempty"` ALPN []string `proxy:"alpn,omitempty"`
Certificate string `proxy:"certificate,omitempty"` CustomCA string `proxy:"ca,omitempty"`
PrivateKey string `proxy:"private-key,omitempty"` CustomCAString string `proxy:"ca-str,omitempty"`
ALPN []string `proxy:"alpn,omitempty"` ReceiveWindowConn int `proxy:"recv-window-conn,omitempty"`
ReceiveWindowConn int `proxy:"recv-window-conn,omitempty"` ReceiveWindow int `proxy:"recv-window,omitempty"`
ReceiveWindow int `proxy:"recv-window,omitempty"` DisableMTUDiscovery bool `proxy:"disable-mtu-discovery,omitempty"`
DisableMTUDiscovery bool `proxy:"disable-mtu-discovery,omitempty"` FastOpen bool `proxy:"fast-open,omitempty"`
FastOpen bool `proxy:"fast-open,omitempty"` HopInterval int `proxy:"hop-interval,omitempty"`
HopInterval int `proxy:"hop-interval,omitempty"`
} }
func (c *HysteriaOption) Speed() (uint64, uint64, error) { func (c *HysteriaOption) Speed() (uint64, uint64, error) {
@@ -142,7 +127,11 @@ func (c *HysteriaOption) Speed() (uint64, uint64, error) {
} }
func NewHysteria(option HysteriaOption) (*Hysteria, error) { func NewHysteria(option HysteriaOption) (*Hysteria, error) {
clientTransport := &transport.ClientTransport{} clientTransport := &transport.ClientTransport{
Dialer: &net.Dialer{
Timeout: 8 * time.Second,
},
}
addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port)) addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port))
ports := option.Ports ports := option.Ports
@@ -151,32 +140,23 @@ func NewHysteria(option HysteriaOption) (*Hysteria, error) {
serverName = option.SNI serverName = option.SNI
} }
tlsConfig, err := ca.GetTLSConfig(ca.Option{ tlsConfig := &tls.Config{
TLSConfig: &tls.Config{ ServerName: serverName,
ServerName: serverName, InsecureSkipVerify: option.SkipCertVerify,
InsecureSkipVerify: option.SkipCertVerify, MinVersion: tls.VersionTLS13,
MinVersion: tls.VersionTLS13, }
},
Fingerprint: option.Fingerprint, var err error
Certificate: option.Certificate, tlsConfig, err = ca.GetTLSConfig(tlsConfig, option.Fingerprint, option.CustomCA, option.CustomCAString)
PrivateKey: option.PrivateKey,
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if option.ALPN != nil { // structure's Decode will ensure value not nil when input has value even it was set an empty array if len(option.ALPN) > 0 {
tlsConfig.NextProtos = option.ALPN tlsConfig.NextProtos = option.ALPN
} else { } else {
tlsConfig.NextProtos = []string{DefaultALPN} tlsConfig.NextProtos = []string{DefaultALPN}
} }
echConfig, err := option.ECHOpts.Parse()
if err != nil {
return nil, err
}
tlsClientConfig := tlsConfig
quicConfig := &quic.Config{ quicConfig := &quic.Config{
InitialStreamReceiveWindow: uint64(option.ReceiveWindowConn), InitialStreamReceiveWindow: uint64(option.ReceiveWindowConn),
MaxStreamReceiveWindow: uint64(option.ReceiveWindowConn), MaxStreamReceiveWindow: uint64(option.ReceiveWindowConn),
@@ -231,41 +211,27 @@ func NewHysteria(option HysteriaOption) (*Hysteria, error) {
down = uint64(option.DownSpeed * mbpsToBps) down = uint64(option.DownSpeed * mbpsToBps)
} }
client, err := core.NewClient( client, err := core.NewClient(
addr, ports, option.Protocol, auth, tlsClientConfig, quicConfig, clientTransport, up, down, func(refBPS uint64) congestion.CongestionControl { addr, ports, option.Protocol, auth, tlsConfig, quicConfig, clientTransport, up, down, func(refBPS uint64) congestion.CongestionControl {
return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS)) return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS))
}, obfuscator, hopInterval, option.FastOpen, }, obfuscator, hopInterval, option.FastOpen,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("hysteria %s create error: %w", addr, err) return nil, fmt.Errorf("hysteria %s create error: %w", addr, err)
} }
outbound := &Hysteria{ return &Hysteria{
Base: &Base{ Base: &Base{
name: option.Name, name: option.Name,
addr: addr, addr: addr,
tp: C.Hysteria, tp: C.Hysteria,
pdName: option.ProviderName,
udp: true, udp: true,
tfo: option.FastOpen, tfo: option.FastOpen,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: option.IPVersion, prefer: C.NewDNSPrefer(option.IPVersion),
}, },
option: &option, option: &option,
client: client, client: client,
tlsConfig: tlsClientConfig, }, nil
echConfig: echConfig,
}
outbound.dialer = option.NewDialer(outbound.DialOptions())
return outbound, nil
}
// Close implements C.ProxyAdapter
func (h *Hysteria) Close() error {
if h.client != nil {
return h.client.Close()
}
return nil
} }
type hyPacketConn struct { type hyPacketConn struct {
@@ -302,7 +268,7 @@ func (c *hyPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
} }
type hyDialerWithContext struct { type hyDialerWithContext struct {
hyDialer func(network string, rAddr net.Addr) (net.PacketConn, error) hyDialer func(network string) (net.PacketConn, error)
ctx context.Context ctx context.Context
remoteAddr func(host string) (net.Addr, error) remoteAddr func(host string) (net.Addr, error)
} }
@@ -312,7 +278,7 @@ func (h *hyDialerWithContext) ListenPacket(rAddr net.Addr) (net.PacketConn, erro
if addrPort, err := netip.ParseAddrPort(rAddr.String()); err == nil { if addrPort, err := netip.ParseAddrPort(rAddr.String()); err == nil {
network = dialer.ParseNetwork(network, addrPort.Addr()) network = dialer.ParseNetwork(network, addrPort.Addr())
} }
return h.hyDialer(network, rAddr) return h.hyDialer(network)
} }
func (h *hyDialerWithContext) Context() context.Context { func (h *hyDialerWithContext) Context() context.Context {

View File

@@ -2,24 +2,27 @@ package outbound
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"net" "net"
"runtime"
"strconv" "strconv"
"time" "time"
N "github.com/metacubex/mihomo/common/net" CN "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/component/ca" "github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer" "github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/log"
tuicCommon "github.com/metacubex/mihomo/transport/tuic/common" tuicCommon "github.com/metacubex/mihomo/transport/tuic/common"
"github.com/metacubex/quic-go"
"github.com/metacubex/sing-quic/hysteria2" "github.com/metacubex/sing-quic/hysteria2"
M "github.com/metacubex/sing/common/metadata"
"github.com/metacubex/tls" M "github.com/sagernet/sing/common/metadata"
"github.com/zhangyunhao116/fastrand"
) )
func init() { func init() {
@@ -34,49 +37,44 @@ type Hysteria2 struct {
option *Hysteria2Option option *Hysteria2Option
client *hysteria2.Client client *hysteria2.Client
dialer proxydialer.SingDialer
} }
type Hysteria2Option struct { type Hysteria2Option struct {
BasicOption BasicOption
Name string `proxy:"name"` Name string `proxy:"name"`
Server string `proxy:"server"` Server string `proxy:"server"`
Port int `proxy:"port,omitempty"` Port int `proxy:"port,omitempty"`
Ports string `proxy:"ports,omitempty"` Ports string `proxy:"ports,omitempty"`
HopInterval int `proxy:"hop-interval,omitempty"` HopInterval int `proxy:"hop-interval,omitempty"`
Up string `proxy:"up,omitempty"` Up string `proxy:"up,omitempty"`
Down string `proxy:"down,omitempty"` Down string `proxy:"down,omitempty"`
Password string `proxy:"password,omitempty"` Password string `proxy:"password,omitempty"`
Obfs string `proxy:"obfs,omitempty"` Obfs string `proxy:"obfs,omitempty"`
ObfsPassword string `proxy:"obfs-password,omitempty"` ObfsPassword string `proxy:"obfs-password,omitempty"`
SNI string `proxy:"sni,omitempty"` SNI string `proxy:"sni,omitempty"`
ECHOpts ECHOptions `proxy:"ech-opts,omitempty"` SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` Fingerprint string `proxy:"fingerprint,omitempty"`
Fingerprint string `proxy:"fingerprint,omitempty"` ALPN []string `proxy:"alpn,omitempty"`
Certificate string `proxy:"certificate,omitempty"` CustomCA string `proxy:"ca,omitempty"`
PrivateKey string `proxy:"private-key,omitempty"` CustomCAString string `proxy:"ca-str,omitempty"`
ALPN []string `proxy:"alpn,omitempty"` CWND int `proxy:"cwnd,omitempty"`
CWND int `proxy:"cwnd,omitempty"` UdpMTU int `proxy:"udp-mtu,omitempty"`
UdpMTU int `proxy:"udp-mtu,omitempty"`
// quic-go special config
InitialStreamReceiveWindow uint64 `proxy:"initial-stream-receive-window,omitempty"`
MaxStreamReceiveWindow uint64 `proxy:"max-stream-receive-window,omitempty"`
InitialConnectionReceiveWindow uint64 `proxy:"initial-connection-receive-window,omitempty"`
MaxConnectionReceiveWindow uint64 `proxy:"max-connection-receive-window,omitempty"`
} }
func (h *Hysteria2) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) { func (h *Hysteria2) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
options := h.Base.DialOptions(opts...)
h.dialer.SetDialer(dialer.NewDialer(options...))
c, err := h.client.DialConn(ctx, M.ParseSocksaddrHostPort(metadata.String(), metadata.DstPort)) c, err := h.client.DialConn(ctx, M.ParseSocksaddrHostPort(metadata.String(), metadata.DstPort))
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewConn(c, h), nil return NewConn(CN.NewRefConn(c, h), h), nil
} }
func (h *Hysteria2) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) { func (h *Hysteria2) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
if err = h.ResolveUDP(ctx, metadata); err != nil { options := h.Base.DialOptions(opts...)
return nil, err h.dialer.SetDialer(dialer.NewDialer(options...))
}
pc, err := h.client.ListenPacket(ctx) pc, err := h.client.ListenPacket(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -84,42 +82,17 @@ func (h *Hysteria2) ListenPacketContext(ctx context.Context, metadata *C.Metadat
if pc == nil { if pc == nil {
return nil, errors.New("packetConn is nil") return nil, errors.New("packetConn is nil")
} }
return newPacketConn(N.NewThreadSafePacketConn(pc), h), nil return newPacketConn(CN.NewRefPacketConn(CN.NewThreadSafePacketConn(pc), h), h), nil
} }
// Close implements C.ProxyAdapter func closeHysteria2(h *Hysteria2) {
func (h *Hysteria2) Close() error {
if h.client != nil { if h.client != nil {
return h.client.CloseWithError(errors.New("proxy removed")) _ = h.client.CloseWithError(errors.New("proxy removed"))
} }
return nil
}
// ProxyInfo implements C.ProxyAdapter
func (h *Hysteria2) ProxyInfo() C.ProxyInfo {
info := h.Base.ProxyInfo()
info.DialerProxy = h.option.DialerProxy
return info
} }
func NewHysteria2(option Hysteria2Option) (*Hysteria2, error) { func NewHysteria2(option Hysteria2Option) (*Hysteria2, error) {
addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port)) addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port))
outbound := &Hysteria2{
Base: &Base{
name: option.Name,
addr: addr,
tp: C.Hysteria2,
pdName: option.ProviderName,
udp: true,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: option.IPVersion,
},
option: &option,
}
outbound.dialer = option.NewDialer(outbound.DialOptions())
singDialer := proxydialer.NewSingDialer(outbound.dialer)
var salamanderPassword string var salamanderPassword string
if len(option.Obfs) > 0 { if len(option.Obfs) > 0 {
if option.ObfsPassword == "" { if option.ObfsPassword == "" {
@@ -138,42 +111,29 @@ func NewHysteria2(option Hysteria2Option) (*Hysteria2, error) {
serverName = option.SNI serverName = option.SNI
} }
tlsConfig, err := ca.GetTLSConfig(ca.Option{ tlsConfig := &tls.Config{
TLSConfig: &tls.Config{ ServerName: serverName,
ServerName: serverName, InsecureSkipVerify: option.SkipCertVerify,
InsecureSkipVerify: option.SkipCertVerify, MinVersion: tls.VersionTLS13,
MinVersion: tls.VersionTLS13, }
},
Fingerprint: option.Fingerprint, var err error
Certificate: option.Certificate, tlsConfig, err = ca.GetTLSConfig(tlsConfig, option.Fingerprint, option.CustomCA, option.CustomCAString)
PrivateKey: option.PrivateKey,
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if option.ALPN != nil { // structure's Decode will ensure value not nil when input has value even it was set an empty array if len(option.ALPN) > 0 {
tlsConfig.NextProtos = option.ALPN tlsConfig.NextProtos = option.ALPN
} }
tlsClientConfig := tlsConfig
echConfig, err := option.ECHOpts.Parse()
if err != nil {
return nil, err
}
if option.UdpMTU == 0 { if option.UdpMTU == 0 {
// "1200" from quic-go's MaxDatagramSize // "1200" from quic-go's MaxDatagramSize
// "-3" from quic-go's DatagramFrame.MaxDataLen // "-3" from quic-go's DatagramFrame.MaxDataLen
option.UdpMTU = 1200 - 3 option.UdpMTU = 1200 - 3
} }
quicConfig := &quic.Config{ singDialer := proxydialer.NewByNameSingDialer(option.DialerProxy, dialer.NewDialer())
InitialStreamReceiveWindow: option.InitialStreamReceiveWindow,
MaxStreamReceiveWindow: option.MaxStreamReceiveWindow,
InitialConnectionReceiveWindow: option.InitialConnectionReceiveWindow,
MaxConnectionReceiveWindow: option.MaxConnectionReceiveWindow,
}
clientOptions := hysteria2.ClientOptions{ clientOptions := hysteria2.ClientOptions{
Context: context.TODO(), Context: context.TODO(),
@@ -183,46 +143,40 @@ func NewHysteria2(option Hysteria2Option) (*Hysteria2, error) {
ReceiveBPS: StringToBps(option.Down), ReceiveBPS: StringToBps(option.Down),
SalamanderPassword: salamanderPassword, SalamanderPassword: salamanderPassword,
Password: option.Password, Password: option.Password,
TLSConfig: tlsClientConfig, TLSConfig: tlsConfig,
QUICConfig: quicConfig,
UDPDisabled: false, UDPDisabled: false,
CWND: option.CWND, CWND: option.CWND,
UdpMTU: option.UdpMTU, UdpMTU: option.UdpMTU,
ServerAddress: func(ctx context.Context) (*net.UDPAddr, error) { ServerAddress: func(ctx context.Context) (*net.UDPAddr, error) {
udpAddr, err := resolveUDPAddr(ctx, "udp", addr, option.IPVersion) return resolveUDPAddrWithPrefer(ctx, "udp", addr, C.NewDNSPrefer(option.IPVersion))
if err != nil {
return nil, err
}
err = echConfig.ClientHandle(ctx, tlsClientConfig)
if err != nil {
return nil, err
}
return udpAddr, nil
}, },
} }
var ranges utils.IntRanges[uint16] var ranges utils.IntRanges[uint16]
var serverPorts []uint16 var serverAddress []string
if option.Ports != "" { if option.Ports != "" {
ranges, err = utils.NewUnsignedRanges[uint16](option.Ports) ranges, err = utils.NewUnsignedRanges[uint16](option.Ports)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ranges.Range(func(port uint16) bool { ranges.Range(func(port uint16) bool {
serverPorts = append(serverPorts, port) serverAddress = append(serverAddress, net.JoinHostPort(option.Server, strconv.Itoa(int(port))))
return true return true
}) })
if len(serverPorts) > 0 { if len(serverAddress) > 0 {
clientOptions.ServerAddress = func(ctx context.Context) (*net.UDPAddr, error) {
return resolveUDPAddrWithPrefer(ctx, "udp", serverAddress[fastrand.Intn(len(serverAddress))], C.NewDNSPrefer(option.IPVersion))
}
if option.HopInterval == 0 { if option.HopInterval == 0 {
option.HopInterval = defaultHopInterval option.HopInterval = defaultHopInterval
} else if option.HopInterval < minHopInterval { } else if option.HopInterval < minHopInterval {
option.HopInterval = minHopInterval option.HopInterval = minHopInterval
} }
clientOptions.HopInterval = time.Duration(option.HopInterval) * time.Second clientOptions.HopInterval = time.Duration(option.HopInterval) * time.Second
clientOptions.ServerPorts = serverPorts
} }
} }
if option.Port == 0 && len(serverPorts) == 0 { if option.Port == 0 && len(serverAddress) == 0 {
return nil, errors.New("invalid port") return nil, errors.New("invalid port")
} }
@@ -230,7 +184,22 @@ func NewHysteria2(option Hysteria2Option) (*Hysteria2, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
outbound.client = client
outbound := &Hysteria2{
Base: &Base{
name: option.Name,
addr: addr,
tp: C.Hysteria2,
udp: true,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
},
option: &option,
client: client,
dialer: singDialer,
}
runtime.SetFinalizer(outbound, closeHysteria2)
return outbound, nil return outbound, nil
} }

View File

@@ -1,397 +0,0 @@
package outbound
import (
"context"
"crypto/ecdsa"
"crypto/x509"
"encoding/base64"
"errors"
"fmt"
"net"
"net/netip"
"strconv"
"strings"
"sync"
"time"
"github.com/metacubex/mihomo/common/atomic"
"github.com/metacubex/mihomo/common/contextutils"
"github.com/metacubex/mihomo/common/pool"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/resolver"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/dns"
"github.com/metacubex/mihomo/log"
"github.com/metacubex/mihomo/transport/masque"
"github.com/metacubex/mihomo/transport/tuic/common"
connectip "github.com/metacubex/connect-ip-go"
"github.com/metacubex/quic-go"
wireguard "github.com/metacubex/sing-wireguard"
M "github.com/metacubex/sing/common/metadata"
"github.com/metacubex/tls"
)
type Masque struct {
*Base
tlsConfig *tls.Config
quicConfig *quic.Config
tunDevice wireguard.Device
resolver resolver.Resolver
uri string
runCtx context.Context
runCancel context.CancelFunc
runMutex sync.Mutex
running atomic.Bool
runDevice atomic.Bool
option MasqueOption
}
type MasqueOption struct {
BasicOption
Name string `proxy:"name"`
Server string `proxy:"server"`
Port int `proxy:"port"`
PrivateKey string `proxy:"private-key"`
PublicKey string `proxy:"public-key"`
Ip string `proxy:"ip,omitempty"`
Ipv6 string `proxy:"ipv6,omitempty"`
URI string `proxy:"uri,omitempty"`
SNI string `proxy:"sni,omitempty"`
MTU int `proxy:"mtu,omitempty"`
UDP bool `proxy:"udp,omitempty"`
CongestionController string `proxy:"congestion-controller,omitempty"`
CWND int `proxy:"cwnd,omitempty"`
RemoteDnsResolve bool `proxy:"remote-dns-resolve,omitempty"`
Dns []string `proxy:"dns,omitempty"`
}
func (option MasqueOption) Prefixes() ([]netip.Prefix, error) {
localPrefixes := make([]netip.Prefix, 0, 2)
if len(option.Ip) > 0 {
if !strings.Contains(option.Ip, "/") {
option.Ip = option.Ip + "/32"
}
if prefix, err := netip.ParsePrefix(option.Ip); err == nil {
localPrefixes = append(localPrefixes, prefix)
} else {
return nil, fmt.Errorf("ip address parse error: %w", err)
}
}
if len(option.Ipv6) > 0 {
if !strings.Contains(option.Ipv6, "/") {
option.Ipv6 = option.Ipv6 + "/128"
}
if prefix, err := netip.ParsePrefix(option.Ipv6); err == nil {
localPrefixes = append(localPrefixes, prefix)
} else {
return nil, fmt.Errorf("ipv6 address parse error: %w", err)
}
}
if len(localPrefixes) == 0 {
return nil, errors.New("missing local address")
}
return localPrefixes, nil
}
func NewMasque(option MasqueOption) (*Masque, error) {
outbound := &Masque{
Base: &Base{
name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.Masque,
pdName: option.ProviderName,
udp: option.UDP,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: option.IPVersion,
},
}
outbound.dialer = option.NewDialer(outbound.DialOptions())
ctx, cancel := context.WithCancel(context.Background())
outbound.runCtx = ctx
outbound.runCancel = cancel
privKeyB64, err := base64.StdEncoding.DecodeString(option.PrivateKey)
if err != nil {
return nil, fmt.Errorf("failed to decode private key: %v", err)
}
privKey, err := x509.ParseECPrivateKey(privKeyB64)
if err != nil {
return nil, fmt.Errorf("failed to parse private key: %v", err)
}
endpointPubKeyB64, err := base64.StdEncoding.DecodeString(option.PublicKey)
if err != nil {
return nil, fmt.Errorf("failed to decode public key: %v", err)
}
pubKey, err := x509.ParsePKIXPublicKey(endpointPubKeyB64)
if err != nil {
return nil, fmt.Errorf("failed to parse public key: %v", err)
}
ecPubKey, ok := pubKey.(*ecdsa.PublicKey)
if !ok {
return nil, fmt.Errorf("failed to assert public key as ECDSA")
}
uri := option.URI
if uri == "" {
uri = masque.ConnectURI
}
outbound.uri = uri
sni := option.SNI
if sni == "" {
sni = masque.ConnectSNI
}
tlsConfig, err := masque.PrepareTlsConfig(privKey, ecPubKey, sni)
if err != nil {
return nil, fmt.Errorf("failed to prepare TLS config: %v\n", err)
}
outbound.tlsConfig = tlsConfig
outbound.quicConfig = &quic.Config{
EnableDatagrams: true,
InitialPacketSize: 1242,
KeepAlivePeriod: 30 * time.Second,
}
prefixes, err := option.Prefixes()
if err != nil {
return nil, err
}
outbound.option = option
mtu := option.MTU
if mtu == 0 {
mtu = 1280
}
if len(prefixes) == 0 {
return nil, errors.New("missing local address")
}
outbound.tunDevice, err = wireguard.NewStackDevice(prefixes, uint32(mtu))
if err != nil {
return nil, fmt.Errorf("create device: %w", err)
}
var has6 bool
for _, address := range prefixes {
if !address.Addr().Unmap().Is4() {
has6 = true
break
}
}
if option.RemoteDnsResolve && len(option.Dns) > 0 {
nss, err := dns.ParseNameServer(option.Dns)
if err != nil {
return nil, err
}
for i := range nss {
nss[i].ProxyAdapter = outbound
}
outbound.resolver = dns.NewResolver(dns.Config{
Main: nss,
IPv6: has6,
})
}
return outbound, nil
}
func (w *Masque) run(ctx context.Context) error {
if w.running.Load() {
return nil
}
w.runMutex.Lock()
defer w.runMutex.Unlock()
// double-check like sync.Once
if w.running.Load() {
return nil
}
if w.runCtx.Err() != nil {
return w.runCtx.Err()
}
if !w.runDevice.Load() {
err := w.tunDevice.Start()
if err != nil {
return err
}
w.runDevice.Store(true)
}
udpAddr, err := resolveUDPAddr(ctx, "udp", w.addr, w.prefer)
if err != nil {
return err
}
pc, err := w.dialer.ListenPacket(ctx, "udp", "", udpAddr.AddrPort())
if err != nil {
return err
}
quicConn, err := quic.Dial(ctx, pc, udpAddr, w.tlsConfig, w.quicConfig)
if err != nil {
return err
}
common.SetCongestionController(quicConn, w.option.CongestionController, w.option.CWND)
tr, ipConn, err := masque.ConnectTunnel(ctx, quicConn, w.uri)
if err != nil {
_ = pc.Close()
return err
}
w.running.Store(true)
runCtx, runCancel := context.WithCancel(w.runCtx)
contextutils.AfterFunc(runCtx, func() {
w.running.Store(false)
_ = ipConn.Close()
_ = tr.Close()
_ = pc.Close()
})
go func() {
defer runCancel()
buf := pool.Get(pool.UDPBufferSize)
defer pool.Put(buf)
bufs := [][]byte{buf}
sizes := []int{0}
for runCtx.Err() == nil {
_, err := w.tunDevice.Read(bufs, sizes, 0)
if err != nil {
log.Errorln("[Masque](%s) error reading from TUN device: %v", w.name, err)
return
}
icmp, err := ipConn.WritePacket(buf[:sizes[0]])
if err != nil {
if errors.As(err, new(*connectip.CloseError)) {
log.Errorln("[Masque](%s) connection closed while writing to IP connection: %v", w.name, err)
return
}
log.Warnln("[Masque](%s) error writing to IP connection: %v, continuing...", w.name, err)
continue
}
if len(icmp) > 0 {
if _, err := w.tunDevice.Write([][]byte{icmp}, 0); err != nil {
log.Warnln("[Masque](%s) error writing ICMP to TUN device: %v, continuing...", w.name, err)
}
}
}
}()
go func() {
defer runCancel()
buf := pool.Get(pool.UDPBufferSize)
defer pool.Put(buf)
for runCtx.Err() == nil {
n, err := ipConn.ReadPacket(buf)
if err != nil {
if errors.As(err, new(*connectip.CloseError)) {
log.Errorln("[Masque](%s) connection closed while writing to IP connection: %v", w.name, err)
return
}
log.Warnln("[Masque](%s) error reading from IP connection: %v, continuing...", w.name, err)
continue
}
if _, err := w.tunDevice.Write([][]byte{buf[:n]}, 0); err != nil {
log.Errorln("[Masque](%s) error writing to TUN device: %v", w.name, err)
return
}
}
}()
return nil
}
// Close implements C.ProxyAdapter
func (w *Masque) Close() error {
w.runCancel()
if w.tunDevice != nil {
w.tunDevice.Close()
}
return nil
}
func (w *Masque) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) {
var conn net.Conn
if err = w.run(ctx); err != nil {
return nil, err
}
if !metadata.Resolved() || w.resolver != nil {
r := resolver.DefaultResolver
if w.resolver != nil {
r = w.resolver
}
options := w.DialOptions()
options = append(options, dialer.WithResolver(r))
options = append(options, dialer.WithNetDialer(wgNetDialer{tunDevice: w.tunDevice}))
conn, err = dialer.NewDialer(options...).DialContext(ctx, "tcp", metadata.RemoteAddress())
} else {
conn, err = w.tunDevice.DialContext(ctx, "tcp", M.SocksaddrFrom(metadata.DstIP, metadata.DstPort).Unwrap())
}
if err != nil {
return nil, err
}
if conn == nil {
return nil, errors.New("conn is nil")
}
return NewConn(conn, w), nil
}
func (w *Masque) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) {
var pc net.PacketConn
if err = w.run(ctx); err != nil {
return nil, err
}
if err = w.ResolveUDP(ctx, metadata); err != nil {
return nil, err
}
pc, err = w.tunDevice.ListenPacket(ctx, M.SocksaddrFrom(metadata.DstIP, metadata.DstPort).Unwrap())
if err != nil {
return nil, err
}
if pc == nil {
return nil, errors.New("packetConn is nil")
}
return newPacketConn(pc, w), nil
}
func (w *Masque) ResolveUDP(ctx context.Context, metadata *C.Metadata) error {
if (!metadata.Resolved() || w.resolver != nil) && metadata.Host != "" {
r := resolver.DefaultResolver
if w.resolver != nil {
r = w.resolver
}
ip, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, r)
if err != nil {
return fmt.Errorf("can't resolve ip: %w", err)
}
metadata.DstIP = ip
}
return nil
}
// ProxyInfo implements C.ProxyAdapter
func (w *Masque) ProxyInfo() C.ProxyInfo {
info := w.Base.ProxyInfo()
info.DialerProxy = w.option.DialerProxy
return info
}
// IsL3Protocol implements C.ProxyAdapter
func (w *Masque) IsL3Protocol(metadata *C.Metadata) bool {
return true
}

View File

@@ -1,355 +0,0 @@
package outbound
import (
"context"
"fmt"
"net"
"net/netip"
"strconv"
"sync"
N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/resolver"
C "github.com/metacubex/mihomo/constant"
mieruclient "github.com/enfein/mieru/v3/apis/client"
mierucommon "github.com/enfein/mieru/v3/apis/common"
mierumodel "github.com/enfein/mieru/v3/apis/model"
mierupb "github.com/enfein/mieru/v3/pkg/appctl/appctlpb"
"google.golang.org/protobuf/proto"
)
type Mieru struct {
*Base
option *MieruOption
client mieruclient.Client
mu sync.Mutex
}
type MieruOption struct {
BasicOption
Name string `proxy:"name"`
Server string `proxy:"server"`
Port int `proxy:"port,omitempty"`
PortRange string `proxy:"port-range,omitempty"`
Transport string `proxy:"transport"`
UDP bool `proxy:"udp,omitempty"`
UserName string `proxy:"username"`
Password string `proxy:"password"`
Multiplexing string `proxy:"multiplexing,omitempty"`
HandshakeMode string `proxy:"handshake-mode,omitempty"`
}
type mieruPacketDialer struct {
C.Dialer
}
var _ mierucommon.PacketDialer = (*mieruPacketDialer)(nil)
func (pd mieruPacketDialer) ListenPacket(ctx context.Context, network, laddr, raddr string) (net.PacketConn, error) {
rAddrPort, err := netip.ParseAddrPort(raddr)
if err != nil {
return nil, fmt.Errorf("invalid address %s: %w", raddr, err)
}
return pd.Dialer.ListenPacket(ctx, network, laddr, rAddrPort)
}
type mieruDNSResolver struct {
prefer C.DNSPrefer
}
var _ mierucommon.DNSResolver = (*mieruDNSResolver)(nil)
func (dr mieruDNSResolver) LookupIP(ctx context.Context, network, host string) (_ []net.IP, err error) {
var ip netip.Addr
switch dr.prefer {
case C.IPv4Only:
ip, err = resolver.ResolveIPv4WithResolver(ctx, host, resolver.ProxyServerHostResolver)
case C.IPv6Only:
ip, err = resolver.ResolveIPv6WithResolver(ctx, host, resolver.ProxyServerHostResolver)
case C.IPv6Prefer:
ip, err = resolver.ResolveIPPrefer6WithResolver(ctx, host, resolver.ProxyServerHostResolver)
default:
ip, err = resolver.ResolveIPWithResolver(ctx, host, resolver.ProxyServerHostResolver)
}
if err != nil {
return nil, fmt.Errorf("can't resolve ip: %w", err)
}
// TODO: handle IP4P (due to interface limitations, it's currently impossible to modify the port here)
return []net.IP{ip.AsSlice()}, nil
}
// DialContext implements C.ProxyAdapter
func (m *Mieru) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
if err := m.ensureClientIsRunning(); err != nil {
return nil, err
}
addr := metadataToMieruNetAddrSpec(metadata)
c, err := m.client.DialContext(ctx, addr)
if err != nil {
return nil, fmt.Errorf("dial to %s failed: %w", addr, err)
}
return NewConn(c, m), nil
}
// ListenPacketContext implements C.ProxyAdapter
func (m *Mieru) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) {
if err = m.ResolveUDP(ctx, metadata); err != nil {
return nil, err
}
if err := m.ensureClientIsRunning(); err != nil {
return nil, err
}
c, err := m.client.DialContext(ctx, metadata.UDPAddr())
if err != nil {
return nil, fmt.Errorf("dial to %s failed: %w", metadata.UDPAddr(), err)
}
return newPacketConn(N.NewThreadSafePacketConn(mierucommon.NewUDPAssociateWrapper(mierucommon.NewPacketOverStreamTunnel(c))), m), nil
}
// SupportUOT implements C.ProxyAdapter
func (m *Mieru) SupportUOT() bool {
return true
}
// ProxyInfo implements C.ProxyAdapter
func (m *Mieru) ProxyInfo() C.ProxyInfo {
info := m.Base.ProxyInfo()
info.DialerProxy = m.option.DialerProxy
return info
}
func (m *Mieru) ensureClientIsRunning() error {
m.mu.Lock()
defer m.mu.Unlock()
if m.client.IsRunning() {
return nil
}
// Create a dialer and add it to the client config, before starting the client.
config, err := m.client.Load()
if err != nil {
return err
}
config.Dialer = m.dialer
config.PacketDialer = mieruPacketDialer{Dialer: m.dialer}
config.Resolver = mieruDNSResolver{prefer: m.prefer}
if err := m.client.Store(config); err != nil {
return err
}
if err := m.client.Start(); err != nil {
return fmt.Errorf("failed to start mieru client: %w", err)
}
return nil
}
func NewMieru(option MieruOption) (*Mieru, error) {
config, err := buildMieruClientConfig(option)
if err != nil {
return nil, fmt.Errorf("failed to build mieru client config: %w", err)
}
c := mieruclient.NewClient()
if err := c.Store(config); err != nil {
return nil, fmt.Errorf("failed to store mieru client config: %w", err)
}
// Client is started lazily on the first use.
var addr string
if option.Port != 0 {
addr = net.JoinHostPort(option.Server, strconv.Itoa(option.Port))
} else {
beginPort, _, _ := beginAndEndPortFromPortRange(option.PortRange)
addr = net.JoinHostPort(option.Server, strconv.Itoa(beginPort))
}
outbound := &Mieru{
Base: &Base{
name: option.Name,
addr: addr,
tp: C.Mieru,
pdName: option.ProviderName,
udp: option.UDP,
xudp: false,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: option.IPVersion,
},
option: &option,
client: c,
}
outbound.dialer = option.NewDialer(outbound.DialOptions())
return outbound, nil
}
// Close implements C.ProxyAdapter
func (m *Mieru) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
if m.client != nil && m.client.IsRunning() {
return m.client.Stop()
}
return nil
}
func metadataToMieruNetAddrSpec(metadata *C.Metadata) mierumodel.NetAddrSpec {
spec := mierumodel.NetAddrSpec{
Net: metadata.NetWork.String(),
}
if metadata.Host != "" {
spec.AddrSpec = mierumodel.AddrSpec{
FQDN: metadata.Host,
Port: int(metadata.DstPort),
}
} else {
spec.AddrSpec = mierumodel.AddrSpec{
IP: metadata.DstIP.AsSlice(),
Port: int(metadata.DstPort),
}
}
return spec
}
func buildMieruClientConfig(option MieruOption) (*mieruclient.ClientConfig, error) {
if err := validateMieruOption(option); err != nil {
return nil, fmt.Errorf("failed to validate mieru option: %w", err)
}
var transportProtocol = mierupb.TransportProtocol_UNKNOWN_TRANSPORT_PROTOCOL.Enum()
switch option.Transport {
case "TCP":
transportProtocol = mierupb.TransportProtocol_TCP.Enum()
case "UDP":
transportProtocol = mierupb.TransportProtocol_UDP.Enum()
}
var server *mierupb.ServerEndpoint
if net.ParseIP(option.Server) != nil {
// server is an IP address
if option.PortRange != "" {
server = &mierupb.ServerEndpoint{
IpAddress: proto.String(option.Server),
PortBindings: []*mierupb.PortBinding{
{
PortRange: proto.String(option.PortRange),
Protocol: transportProtocol,
},
},
}
} else {
server = &mierupb.ServerEndpoint{
IpAddress: proto.String(option.Server),
PortBindings: []*mierupb.PortBinding{
{
Port: proto.Int32(int32(option.Port)),
Protocol: transportProtocol,
},
},
}
}
} else {
// server is a domain name
if option.PortRange != "" {
server = &mierupb.ServerEndpoint{
DomainName: proto.String(option.Server),
PortBindings: []*mierupb.PortBinding{
{
PortRange: proto.String(option.PortRange),
Protocol: transportProtocol,
},
},
}
} else {
server = &mierupb.ServerEndpoint{
DomainName: proto.String(option.Server),
PortBindings: []*mierupb.PortBinding{
{
Port: proto.Int32(int32(option.Port)),
Protocol: transportProtocol,
},
},
}
}
}
config := &mieruclient.ClientConfig{
Profile: &mierupb.ClientProfile{
ProfileName: proto.String(option.Name),
User: &mierupb.User{
Name: proto.String(option.UserName),
Password: proto.String(option.Password),
},
Servers: []*mierupb.ServerEndpoint{server},
},
DNSConfig: &mierucommon.ClientDNSConfig{
BypassDialerDNS: true,
},
}
if multiplexing, ok := mierupb.MultiplexingLevel_value[option.Multiplexing]; ok {
config.Profile.Multiplexing = &mierupb.MultiplexingConfig{
Level: mierupb.MultiplexingLevel(multiplexing).Enum(),
}
}
if handshakeMode, ok := mierupb.HandshakeMode_value[option.HandshakeMode]; ok {
config.Profile.HandshakeMode = (*mierupb.HandshakeMode)(&handshakeMode)
}
return config, nil
}
func validateMieruOption(option MieruOption) error {
if option.Name == "" {
return fmt.Errorf("name is empty")
}
if option.Server == "" {
return fmt.Errorf("server is empty")
}
if option.Port == 0 && option.PortRange == "" {
return fmt.Errorf("either port or port-range must be set")
}
if option.Port != 0 && option.PortRange != "" {
return fmt.Errorf("port and port-range cannot be set at the same time")
}
if option.Port != 0 && (option.Port < 1 || option.Port > 65535) {
return fmt.Errorf("port must be between 1 and 65535")
}
if option.PortRange != "" {
begin, end, err := beginAndEndPortFromPortRange(option.PortRange)
if err != nil {
return fmt.Errorf("invalid port-range format")
}
if begin < 1 || begin > 65535 {
return fmt.Errorf("begin port must be between 1 and 65535")
}
if end < 1 || end > 65535 {
return fmt.Errorf("end port must be between 1 and 65535")
}
if begin > end {
return fmt.Errorf("begin port must be less than or equal to end port")
}
}
if option.Transport != "TCP" && option.Transport != "UDP" {
return fmt.Errorf("transport must be TCP or UDP")
}
if option.UserName == "" {
return fmt.Errorf("username is empty")
}
if option.Password == "" {
return fmt.Errorf("password is empty")
}
if option.Multiplexing != "" {
if _, ok := mierupb.MultiplexingLevel_value[option.Multiplexing]; !ok {
return fmt.Errorf("invalid multiplexing level: %s", option.Multiplexing)
}
}
if option.HandshakeMode != "" {
if _, ok := mierupb.HandshakeMode_value[option.HandshakeMode]; !ok {
return fmt.Errorf("invalid handshake mode: %s", option.HandshakeMode)
}
}
return nil
}
func beginAndEndPortFromPortRange(portRange string) (int, int, error) {
var begin, end int
_, err := fmt.Sscanf(portRange, "%d-%d", &begin, &end)
return begin, end, err
}

View File

@@ -1,92 +0,0 @@
package outbound
import "testing"
func TestNewMieru(t *testing.T) {
testCases := []struct {
option MieruOption
wantBaseAddr string
}{
{
option: MieruOption{
Name: "test",
Server: "1.2.3.4",
Port: 10000,
Transport: "TCP",
UserName: "test",
Password: "test",
},
wantBaseAddr: "1.2.3.4:10000",
},
{
option: MieruOption{
Name: "test",
Server: "2001:db8::1",
PortRange: "10001-10002",
Transport: "TCP",
UserName: "test",
Password: "test",
},
wantBaseAddr: "[2001:db8::1]:10001",
},
{
option: MieruOption{
Name: "test",
Server: "example.com",
Port: 10003,
Transport: "UDP",
UserName: "test",
Password: "test",
},
wantBaseAddr: "example.com:10003",
},
}
for _, testCase := range testCases {
mieru, err := NewMieru(testCase.option)
if err != nil {
t.Error(err)
}
if mieru.addr != testCase.wantBaseAddr {
t.Errorf("got addr %q, want %q", mieru.addr, testCase.wantBaseAddr)
}
}
}
func TestBeginAndEndPortFromPortRange(t *testing.T) {
testCases := []struct {
input string
begin int
end int
hasErr bool
}{
{"1-10", 1, 10, false},
{"1000-2000", 1000, 2000, false},
{"65535-65535", 65535, 65535, false},
{"1", 0, 0, true},
{"1-", 0, 0, true},
{"-10", 0, 0, true},
{"a-b", 0, 0, true},
{"1-b", 0, 0, true},
{"a-10", 0, 0, true},
}
for _, testCase := range testCases {
begin, end, err := beginAndEndPortFromPortRange(testCase.input)
if testCase.hasErr {
if err == nil {
t.Errorf("beginAndEndPortFromPortRange(%s) should return an error", testCase.input)
}
} else {
if err != nil {
t.Errorf("beginAndEndPortFromPortRange(%s) should not return an error, but got %v", testCase.input, err)
}
if begin != testCase.begin {
t.Errorf("beginAndEndPortFromPortRange(%s) begin port mismatch, got %d, want %d", testCase.input, begin, testCase.begin)
}
if end != testCase.end {
t.Errorf("beginAndEndPortFromPortRange(%s) end port mismatch, got %d, want %d", testCase.input, end, testCase.end)
}
}
}
}

View File

@@ -13,29 +13,23 @@ import (
type RealityOptions struct { type RealityOptions struct {
PublicKey string `proxy:"public-key"` PublicKey string `proxy:"public-key"`
ShortID string `proxy:"short-id"` ShortID string `proxy:"short-id"`
SupportX25519MLKEM768 bool `proxy:"support-x25519mlkem768"`
} }
func (o RealityOptions) Parse() (*tlsC.RealityConfig, error) { func (o RealityOptions) Parse() (*tlsC.RealityConfig, error) {
if o.PublicKey != "" { if o.PublicKey != "" {
config := new(tlsC.RealityConfig) config := new(tlsC.RealityConfig)
config.SupportX25519MLKEM768 = o.SupportX25519MLKEM768
const x25519ScalarSize = 32 const x25519ScalarSize = 32
publicKey, err := base64.RawURLEncoding.DecodeString(o.PublicKey) var publicKey [x25519ScalarSize]byte
if err != nil || len(publicKey) != x25519ScalarSize { n, err := base64.RawURLEncoding.Decode(publicKey[:], []byte(o.PublicKey))
if err != nil || n != x25519ScalarSize {
return nil, errors.New("invalid REALITY public key") return nil, errors.New("invalid REALITY public key")
} }
config.PublicKey, err = ecdh.X25519().NewPublicKey(publicKey) config.PublicKey, err = ecdh.X25519().NewPublicKey(publicKey[:])
if err != nil { if err != nil {
return nil, fmt.Errorf("fail to create REALITY public key: %w", err) return nil, fmt.Errorf("fail to create REALITY public key: %w", err)
} }
n := hex.DecodedLen(len(o.ShortID))
if n > tlsC.RealityMaxShortIDLen {
return nil, errors.New("invalid REALITY short id")
}
n, err = hex.Decode(config.ShortID[:], []byte(o.ShortID)) n, err = hex.Decode(config.ShortID[:], []byte(o.ShortID))
if err != nil || n > tlsC.RealityMaxShortIDLen { if err != nil || n > tlsC.RealityMaxShortIDLen {
return nil, errors.New("invalid REALITY short ID") return nil, errors.New("invalid REALITY short ID")

View File

@@ -4,10 +4,10 @@ import (
"context" "context"
"io" "io"
"net" "net"
"net/netip"
"time" "time"
"github.com/metacubex/mihomo/common/buf" "github.com/metacubex/mihomo/common/buf"
"github.com/metacubex/mihomo/component/dialer"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
) )
@@ -17,12 +17,11 @@ type Reject struct {
} }
type RejectOption struct { type RejectOption struct {
BasicOption
Name string `proxy:"name"` Name string `proxy:"name"`
} }
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (r *Reject) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { func (r *Reject) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
if r.drop { if r.drop {
return NewConn(dropConn{}, r), nil return NewConn(dropConn{}, r), nil
} }
@@ -30,25 +29,15 @@ func (r *Reject) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn,
} }
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (r *Reject) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { func (r *Reject) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
if err := r.ResolveUDP(ctx, metadata); err != nil {
return nil, err
}
return newPacketConn(&nopPacketConn{}, r), nil return newPacketConn(&nopPacketConn{}, r), nil
} }
func (r *Reject) ResolveUDP(ctx context.Context, metadata *C.Metadata) error {
if !metadata.Resolved() {
metadata.DstIP = netip.IPv4Unspecified()
}
return nil
}
func NewRejectWithOption(option RejectOption) *Reject { func NewRejectWithOption(option RejectOption) *Reject {
return &Reject{ return &Reject{
Base: &Base{ Base: &Base{
name: option.Name, name: option.Name,
tp: C.Reject, tp: C.Direct,
udp: true, udp: true,
}, },
} }

View File

@@ -2,25 +2,27 @@ package outbound
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"strconv" "strconv"
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/structure" "github.com/metacubex/mihomo/common/structure"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
"github.com/metacubex/mihomo/component/resolver"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/ntp"
gost "github.com/metacubex/mihomo/transport/gost-plugin"
"github.com/metacubex/mihomo/transport/kcptun"
"github.com/metacubex/mihomo/transport/restls" "github.com/metacubex/mihomo/transport/restls"
obfs "github.com/metacubex/mihomo/transport/simple-obfs" obfs "github.com/metacubex/mihomo/transport/simple-obfs"
shadowtls "github.com/metacubex/mihomo/transport/sing-shadowtls" shadowtls "github.com/metacubex/mihomo/transport/sing-shadowtls"
v2rayObfs "github.com/metacubex/mihomo/transport/v2ray-plugin" v2rayObfs "github.com/metacubex/mihomo/transport/v2ray-plugin"
restlsC "github.com/3andne/restls-client-go"
shadowsocks "github.com/metacubex/sing-shadowsocks2" shadowsocks "github.com/metacubex/sing-shadowsocks2"
"github.com/metacubex/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
M "github.com/metacubex/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
"github.com/metacubex/sing/common/uot" "github.com/sagernet/sing/common/uot"
) )
type ShadowSocks struct { type ShadowSocks struct {
@@ -32,10 +34,8 @@ type ShadowSocks struct {
obfsMode string obfsMode string
obfsOption *simpleObfsOption obfsOption *simpleObfsOption
v2rayOption *v2rayObfs.Option v2rayOption *v2rayObfs.Option
gostOption *gost.Option
shadowTLSOption *shadowtls.ShadowTLSOption shadowTLSOption *shadowtls.ShadowTLSOption
restlsConfig *restls.Config restlsConfig *restlsC.Config
kcptunClient *kcptun.Client
} }
type ShadowSocksOption struct { type ShadowSocksOption struct {
@@ -63,10 +63,7 @@ type v2rayObfsOption struct {
Host string `obfs:"host,omitempty"` Host string `obfs:"host,omitempty"`
Path string `obfs:"path,omitempty"` Path string `obfs:"path,omitempty"`
TLS bool `obfs:"tls,omitempty"` TLS bool `obfs:"tls,omitempty"`
ECHOpts ECHOptions `obfs:"ech-opts,omitempty"`
Fingerprint string `obfs:"fingerprint,omitempty"` Fingerprint string `obfs:"fingerprint,omitempty"`
Certificate string `obfs:"certificate,omitempty"`
PrivateKey string `obfs:"private-key,omitempty"`
Headers map[string]string `obfs:"headers,omitempty"` Headers map[string]string `obfs:"headers,omitempty"`
SkipCertVerify bool `obfs:"skip-cert-verify,omitempty"` SkipCertVerify bool `obfs:"skip-cert-verify,omitempty"`
Mux bool `obfs:"mux,omitempty"` Mux bool `obfs:"mux,omitempty"`
@@ -74,29 +71,12 @@ type v2rayObfsOption struct {
V2rayHttpUpgradeFastOpen bool `obfs:"v2ray-http-upgrade-fast-open,omitempty"` V2rayHttpUpgradeFastOpen bool `obfs:"v2ray-http-upgrade-fast-open,omitempty"`
} }
type gostObfsOption struct {
Mode string `obfs:"mode"`
Host string `obfs:"host,omitempty"`
Path string `obfs:"path,omitempty"`
TLS bool `obfs:"tls,omitempty"`
ECHOpts ECHOptions `obfs:"ech-opts,omitempty"`
Fingerprint string `obfs:"fingerprint,omitempty"`
Certificate string `obfs:"certificate,omitempty"`
PrivateKey string `obfs:"private-key,omitempty"`
Headers map[string]string `obfs:"headers,omitempty"`
SkipCertVerify bool `obfs:"skip-cert-verify,omitempty"`
Mux bool `obfs:"mux,omitempty"`
}
type shadowTLSOption struct { type shadowTLSOption struct {
Password string `obfs:"password,omitempty"` Password string `obfs:"password"`
Host string `obfs:"host"` Host string `obfs:"host"`
Fingerprint string `obfs:"fingerprint,omitempty"` Fingerprint string `obfs:"fingerprint,omitempty"`
Certificate string `obfs:"certificate,omitempty"` SkipCertVerify bool `obfs:"skip-cert-verify,omitempty"`
PrivateKey string `obfs:"private-key,omitempty"` Version int `obfs:"version,omitempty"`
SkipCertVerify bool `obfs:"skip-cert-verify,omitempty"`
Version int `obfs:"version,omitempty"`
ALPN []string `obfs:"alpn,omitempty"`
} }
type restlsOption struct { type restlsOption struct {
@@ -106,36 +86,8 @@ type restlsOption struct {
RestlsScript string `obfs:"restls-script,omitempty"` RestlsScript string `obfs:"restls-script,omitempty"`
} }
type kcpTunOption struct {
Key string `obfs:"key,omitempty"`
Crypt string `obfs:"crypt,omitempty"`
Mode string `obfs:"mode,omitempty"`
Conn int `obfs:"conn,omitempty"`
AutoExpire int `obfs:"autoexpire,omitempty"`
ScavengeTTL int `obfs:"scavengettl,omitempty"`
MTU int `obfs:"mtu,omitempty"`
RateLimit int `obfs:"ratelimit,omitempty"`
SndWnd int `obfs:"sndwnd,omitempty"`
RcvWnd int `obfs:"rcvwnd,omitempty"`
DataShard int `obfs:"datashard,omitempty"`
ParityShard int `obfs:"parityshard,omitempty"`
DSCP int `obfs:"dscp,omitempty"`
NoComp bool `obfs:"nocomp,omitempty"`
AckNodelay bool `obfs:"acknodelay,omitempty"`
NoDelay int `obfs:"nodelay,omitempty"`
Interval int `obfs:"interval,omitempty"`
Resend int `obfs:"resend,omitempty"`
NoCongestion int `obfs:"nc,omitempty"`
SockBuf int `obfs:"sockbuf,omitempty"`
SmuxVer int `obfs:"smuxver,omitempty"`
SmuxBuf int `obfs:"smuxbuf,omitempty"`
FrameSize int `obfs:"framesize,omitempty"`
StreamBuf int `obfs:"streambuf,omitempty"`
KeepAlive int `obfs:"keepalive,omitempty"`
}
// StreamConnContext implements C.ProxyAdapter // StreamConnContext implements C.ProxyAdapter
func (ss *ShadowSocks) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ net.Conn, err error) { func (ss *ShadowSocks) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) {
useEarly := false useEarly := false
switch ss.obfsMode { switch ss.obfsMode {
case "tls": case "tls":
@@ -144,23 +96,20 @@ func (ss *ShadowSocks) StreamConnContext(ctx context.Context, c net.Conn, metada
_, port, _ := net.SplitHostPort(ss.addr) _, port, _ := net.SplitHostPort(ss.addr)
c = obfs.NewHTTPObfs(c, ss.obfsOption.Host, port) c = obfs.NewHTTPObfs(c, ss.obfsOption.Host, port)
case "websocket": case "websocket":
if ss.v2rayOption != nil { var err error
c, err = v2rayObfs.NewV2rayObfs(ctx, c, ss.v2rayOption) c, err = v2rayObfs.NewV2rayObfs(ctx, c, ss.v2rayOption)
} else if ss.gostOption != nil {
c, err = gost.NewGostWebsocket(ctx, c, ss.gostOption)
} else {
return nil, fmt.Errorf("plugin options is required")
}
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", ss.addr, err) return nil, fmt.Errorf("%s connect error: %w", ss.addr, err)
} }
case shadowtls.Mode: case shadowtls.Mode:
var err error
c, err = shadowtls.NewShadowTLS(ctx, c, ss.shadowTLSOption) c, err = shadowtls.NewShadowTLS(ctx, c, ss.shadowTLSOption)
if err != nil { if err != nil {
return nil, err return nil, err
} }
useEarly = true useEarly = true
case restls.Mode: case restls.Mode:
var err error
c, err = restls.NewRestls(ctx, c, ss.restlsConfig) c, err = restls.NewRestls(ctx, c, ss.restlsConfig)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s (restls) connect error: %w", ss.addr, err) return nil, fmt.Errorf("%s (restls) connect error: %w", ss.addr, err)
@@ -168,12 +117,6 @@ func (ss *ShadowSocks) StreamConnContext(ctx context.Context, c net.Conn, metada
useEarly = true useEarly = true
} }
useEarly = useEarly || N.NeedHandshake(c) useEarly = useEarly || N.NeedHandshake(c)
if !useEarly {
if ctx.Done() != nil {
done := N.SetupContextForConn(ctx, c)
defer done(&err)
}
}
if metadata.NetWork == C.UDP && ss.option.UDPOverTCP { if metadata.NetWork == C.UDP && ss.option.UDPOverTCP {
uotDestination := uot.RequestDestination(uint8(ss.option.UDPOverTCPVersion)) uotDestination := uot.RequestDestination(uint8(ss.option.UDPOverTCPVersion))
if useEarly { if useEarly {
@@ -190,31 +133,23 @@ func (ss *ShadowSocks) StreamConnContext(ctx context.Context, c net.Conn, metada
} }
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (ss *ShadowSocks) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) { func (ss *ShadowSocks) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
var c net.Conn return ss.DialContextWithDialer(ctx, dialer.NewDialer(ss.Base.DialOptions(opts...)...), metadata)
if ss.kcptunClient != nil { }
c, err = ss.kcptunClient.OpenStream(ctx, func(ctx context.Context) (net.PacketConn, net.Addr, error) {
if err = ss.ResolveUDP(ctx, metadata); err != nil {
return nil, nil, err
}
addr, err := resolveUDPAddr(ctx, "udp", ss.addr, ss.prefer)
if err != nil {
return nil, nil, err
}
pc, err := ss.dialer.ListenPacket(ctx, "udp", "", addr.AddrPort()) // DialContextWithDialer implements C.ProxyAdapter
if err != nil { func (ss *ShadowSocks) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
return nil, nil, err if len(ss.option.DialerProxy) > 0 {
} dialer, err = proxydialer.NewByName(ss.option.DialerProxy, dialer)
if err != nil {
return pc, addr, nil return nil, err
}) }
} else {
c, err = ss.dialer.DialContext(ctx, "tcp", ss.addr)
} }
c, err := dialer.DialContext(ctx, "tcp", ss.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", ss.addr, err) return nil, fmt.Errorf("%s connect error: %w", ss.addr, err)
} }
N.TCPKeepAlive(c)
defer func(c net.Conn) { defer func(c net.Conn) {
safeConnClose(c, err) safeConnClose(c, err)
@@ -225,23 +160,31 @@ func (ss *ShadowSocks) DialContext(ctx context.Context, metadata *C.Metadata) (_
} }
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (ss *ShadowSocks) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { func (ss *ShadowSocks) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
return ss.ListenPacketWithDialer(ctx, dialer.NewDialer(ss.Base.DialOptions(opts...)...), metadata)
}
// ListenPacketWithDialer implements C.ProxyAdapter
func (ss *ShadowSocks) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) {
if len(ss.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(ss.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
if ss.option.UDPOverTCP { if ss.option.UDPOverTCP {
tcpConn, err := ss.DialContext(ctx, metadata) tcpConn, err := ss.DialContextWithDialer(ctx, dialer, metadata)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return ss.ListenPacketOnStreamConn(ctx, tcpConn, metadata) return ss.ListenPacketOnStreamConn(ctx, tcpConn, metadata)
} }
if err := ss.ResolveUDP(ctx, metadata); err != nil { addr, err := resolveUDPAddrWithPrefer(ctx, "udp", ss.addr, ss.prefer)
return nil, err
}
addr, err := resolveUDPAddr(ctx, "udp", ss.addr, ss.prefer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
pc, err := ss.dialer.ListenPacket(ctx, "udp", "", addr.AddrPort()) pc, err := dialer.ListenPacket(ctx, "udp", "", addr.AddrPort())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -249,19 +192,23 @@ func (ss *ShadowSocks) ListenPacketContext(ctx context.Context, metadata *C.Meta
return newPacketConn(pc, ss), nil return newPacketConn(pc, ss), nil
} }
// ProxyInfo implements C.ProxyAdapter // SupportWithDialer implements C.ProxyAdapter
func (ss *ShadowSocks) ProxyInfo() C.ProxyInfo { func (ss *ShadowSocks) SupportWithDialer() C.NetWork {
info := ss.Base.ProxyInfo() return C.ALLNet
info.DialerProxy = ss.option.DialerProxy
return info
} }
// ListenPacketOnStreamConn implements C.ProxyAdapter // ListenPacketOnStreamConn implements C.ProxyAdapter
func (ss *ShadowSocks) ListenPacketOnStreamConn(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ C.PacketConn, err error) { func (ss *ShadowSocks) ListenPacketOnStreamConn(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ C.PacketConn, err error) {
if ss.option.UDPOverTCP { if ss.option.UDPOverTCP {
if err = ss.ResolveUDP(ctx, metadata); err != nil { // ss uot use stream-oriented udp with a special address, so we need a net.UDPAddr
return nil, err if !metadata.Resolved() {
ip, err := resolver.ResolveIP(ctx, metadata.Host)
if err != nil {
return nil, errors.New("can't resolve ip")
}
metadata.DstIP = ip
} }
destination := M.SocksaddrFromNet(metadata.UDPAddr()) destination := M.SocksaddrFromNet(metadata.UDPAddr())
if ss.option.UDPOverTCPVersion == uot.LegacyVersion { if ss.option.UDPOverTCPVersion == uot.LegacyVersion {
return newPacketConn(N.NewThreadSafePacketConn(uot.NewConn(c, uot.Request{Destination: destination})), ss), nil return newPacketConn(N.NewThreadSafePacketConn(uot.NewConn(c, uot.Request{Destination: destination})), ss), nil
@@ -277,29 +224,19 @@ func (ss *ShadowSocks) SupportUOT() bool {
return ss.option.UDPOverTCP return ss.option.UDPOverTCP
} }
func (ss *ShadowSocks) Close() error {
if ss.kcptunClient != nil {
return ss.kcptunClient.Close()
}
return nil
}
func NewShadowSocks(option ShadowSocksOption) (*ShadowSocks, error) { func NewShadowSocks(option ShadowSocksOption) (*ShadowSocks, error) {
addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port)) addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port))
method, err := shadowsocks.CreateMethod(option.Cipher, shadowsocks.MethodOptions{ method, err := shadowsocks.CreateMethod(context.Background(), option.Cipher, shadowsocks.MethodOptions{
Password: option.Password, Password: option.Password,
TimeFunc: ntp.Now,
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("ss %s cipher: %s initialize error: %w", addr, option.Cipher, err) return nil, fmt.Errorf("ss %s initialize error: %w", addr, err)
} }
var v2rayOption *v2rayObfs.Option var v2rayOption *v2rayObfs.Option
var gostOption *gost.Option
var obfsOption *simpleObfsOption var obfsOption *simpleObfsOption
var shadowTLSOpt *shadowtls.ShadowTLSOption var shadowTLSOpt *shadowtls.ShadowTLSOption
var restlsConfig *restls.Config var restlsConfig *restlsC.Config
var kcptunClient *kcptun.Client
obfsMode := "" obfsMode := ""
decoder := structure.NewDecoder(structure.Option{TagName: "obfs", WeaklyTypedInput: true}) decoder := structure.NewDecoder(structure.Option{TagName: "obfs", WeaklyTypedInput: true})
@@ -336,45 +273,6 @@ func NewShadowSocks(option ShadowSocksOption) (*ShadowSocks, error) {
if opts.TLS { if opts.TLS {
v2rayOption.TLS = true v2rayOption.TLS = true
v2rayOption.SkipCertVerify = opts.SkipCertVerify v2rayOption.SkipCertVerify = opts.SkipCertVerify
v2rayOption.Fingerprint = opts.Fingerprint
v2rayOption.Certificate = opts.Certificate
v2rayOption.PrivateKey = opts.PrivateKey
echConfig, err := opts.ECHOpts.Parse()
if err != nil {
return nil, fmt.Errorf("ss %s initialize v2ray-plugin error: %w", addr, err)
}
v2rayOption.ECHConfig = echConfig
}
} else if option.Plugin == "gost-plugin" {
opts := gostObfsOption{Host: "bing.com", Mux: true}
if err := decoder.Decode(option.PluginOpts, &opts); err != nil {
return nil, fmt.Errorf("ss %s initialize gost-plugin error: %w", addr, err)
}
if opts.Mode != "websocket" {
return nil, fmt.Errorf("ss %s obfs mode error: %s", addr, opts.Mode)
}
obfsMode = opts.Mode
gostOption = &gost.Option{
Host: opts.Host,
Path: opts.Path,
Headers: opts.Headers,
Mux: opts.Mux,
}
if opts.TLS {
gostOption.TLS = true
gostOption.SkipCertVerify = opts.SkipCertVerify
gostOption.Fingerprint = opts.Fingerprint
gostOption.Certificate = opts.Certificate
gostOption.PrivateKey = opts.PrivateKey
echConfig, err := opts.ECHOpts.Parse()
if err != nil {
return nil, fmt.Errorf("ss %s initialize gost-plugin error: %w", addr, err)
}
gostOption.ECHConfig = echConfig
} }
} else if option.Plugin == shadowtls.Mode { } else if option.Plugin == shadowtls.Mode {
obfsMode = shadowtls.Mode obfsMode = shadowtls.Mode
@@ -389,18 +287,10 @@ func NewShadowSocks(option ShadowSocksOption) (*ShadowSocks, error) {
Password: opt.Password, Password: opt.Password,
Host: opt.Host, Host: opt.Host,
Fingerprint: opt.Fingerprint, Fingerprint: opt.Fingerprint,
Certificate: opt.Certificate,
PrivateKey: opt.PrivateKey,
ClientFingerprint: option.ClientFingerprint, ClientFingerprint: option.ClientFingerprint,
SkipCertVerify: opt.SkipCertVerify, SkipCertVerify: opt.SkipCertVerify,
Version: opt.Version, Version: opt.Version,
} }
if opt.ALPN != nil { // structure's Decode will ensure value not nil when input has value even it was set an empty array
shadowTLSOpt.ALPN = opt.ALPN
} else {
shadowTLSOpt.ALPN = shadowtls.DefaultALPN
}
} else if option.Plugin == restls.Mode { } else if option.Plugin == restls.Mode {
obfsMode = restls.Mode obfsMode = restls.Mode
restlsOpt := &restlsOption{} restlsOpt := &restlsOption{}
@@ -408,46 +298,11 @@ func NewShadowSocks(option ShadowSocksOption) (*ShadowSocks, error) {
return nil, fmt.Errorf("ss %s initialize restls-plugin error: %w", addr, err) return nil, fmt.Errorf("ss %s initialize restls-plugin error: %w", addr, err)
} }
restlsConfig, err = restls.NewRestlsConfig(restlsOpt.Host, restlsOpt.Password, restlsOpt.VersionHint, restlsOpt.RestlsScript, option.ClientFingerprint) restlsConfig, err = restlsC.NewRestlsConfig(restlsOpt.Host, restlsOpt.Password, restlsOpt.VersionHint, restlsOpt.RestlsScript, option.ClientFingerprint)
if err != nil { if err != nil {
return nil, fmt.Errorf("ss %s initialize restls-plugin error: %w", addr, err) return nil, fmt.Errorf("ss %s initialize restls-plugin error: %w", addr, err)
} }
} else if option.Plugin == kcptun.Mode {
obfsMode = kcptun.Mode
kcptunOpt := &kcpTunOption{}
if err := decoder.Decode(option.PluginOpts, kcptunOpt); err != nil {
return nil, fmt.Errorf("ss %s initialize kcptun-plugin error: %w", addr, err)
}
kcptunClient = kcptun.NewClient(kcptun.Config{
Key: kcptunOpt.Key,
Crypt: kcptunOpt.Crypt,
Mode: kcptunOpt.Mode,
Conn: kcptunOpt.Conn,
AutoExpire: kcptunOpt.AutoExpire,
ScavengeTTL: kcptunOpt.ScavengeTTL,
MTU: kcptunOpt.MTU,
RateLimit: kcptunOpt.RateLimit,
SndWnd: kcptunOpt.SndWnd,
RcvWnd: kcptunOpt.RcvWnd,
DataShard: kcptunOpt.DataShard,
ParityShard: kcptunOpt.ParityShard,
DSCP: kcptunOpt.DSCP,
NoComp: kcptunOpt.NoComp,
AckNodelay: kcptunOpt.AckNodelay,
NoDelay: kcptunOpt.NoDelay,
Interval: kcptunOpt.Interval,
Resend: kcptunOpt.Resend,
NoCongestion: kcptunOpt.NoCongestion,
SockBuf: kcptunOpt.SockBuf,
SmuxVer: kcptunOpt.SmuxVer,
SmuxBuf: kcptunOpt.SmuxBuf,
FrameSize: kcptunOpt.FrameSize,
StreamBuf: kcptunOpt.StreamBuf,
KeepAlive: kcptunOpt.KeepAlive,
})
option.UDPOverTCP = true // must open uot
} }
switch option.UDPOverTCPVersion { switch option.UDPOverTCPVersion {
case uot.Version, uot.LegacyVersion: case uot.Version, uot.LegacyVersion:
@@ -457,30 +312,25 @@ func NewShadowSocks(option ShadowSocksOption) (*ShadowSocks, error) {
return nil, fmt.Errorf("ss %s unknown udp over tcp protocol version: %d", addr, option.UDPOverTCPVersion) return nil, fmt.Errorf("ss %s unknown udp over tcp protocol version: %d", addr, option.UDPOverTCPVersion)
} }
outbound := &ShadowSocks{ return &ShadowSocks{
Base: &Base{ Base: &Base{
name: option.Name, name: option.Name,
addr: addr, addr: addr,
tp: C.Shadowsocks, tp: C.Shadowsocks,
pdName: option.ProviderName,
udp: option.UDP, udp: option.UDP,
tfo: option.TFO, tfo: option.TFO,
mpTcp: option.MPTCP, mpTcp: option.MPTCP,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: option.IPVersion, prefer: C.NewDNSPrefer(option.IPVersion),
}, },
method: method, method: method,
option: &option, option: &option,
obfsMode: obfsMode, obfsMode: obfsMode,
v2rayOption: v2rayOption, v2rayOption: v2rayOption,
gostOption: gostOption,
obfsOption: obfsOption, obfsOption: obfsOption,
shadowTLSOption: shadowTLSOpt, shadowTLSOption: shadowTLSOpt,
restlsConfig: restlsConfig, restlsConfig: restlsConfig,
kcptunClient: kcptunClient, }, nil
}
outbound.dialer = option.NewDialer(outbound.DialOptions())
return outbound, nil
} }

View File

@@ -8,6 +8,8 @@ import (
"strconv" "strconv"
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/shadowsocks/core" "github.com/metacubex/mihomo/transport/shadowsocks/core"
"github.com/metacubex/mihomo/transport/shadowsocks/shadowaead" "github.com/metacubex/mihomo/transport/shadowsocks/shadowaead"
@@ -40,15 +42,12 @@ type ShadowSocksROption struct {
} }
// StreamConnContext implements C.ProxyAdapter // StreamConnContext implements C.ProxyAdapter
func (ssr *ShadowSocksR) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ net.Conn, err error) { func (ssr *ShadowSocksR) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) {
if ctx.Done() != nil {
done := N.SetupContextForConn(ctx, c)
defer done(&err)
}
c = ssr.obfs.StreamConn(c) c = ssr.obfs.StreamConn(c)
c = ssr.cipher.StreamConn(c) c = ssr.cipher.StreamConn(c)
var ( var (
iv []byte iv []byte
err error
) )
switch conn := c.(type) { switch conn := c.(type) {
case *shadowstream.Conn: case *shadowstream.Conn:
@@ -65,11 +64,23 @@ func (ssr *ShadowSocksR) StreamConnContext(ctx context.Context, c net.Conn, meta
} }
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (ssr *ShadowSocksR) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) { func (ssr *ShadowSocksR) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
c, err := ssr.dialer.DialContext(ctx, "tcp", ssr.addr) return ssr.DialContextWithDialer(ctx, dialer.NewDialer(ssr.Base.DialOptions(opts...)...), metadata)
}
// DialContextWithDialer implements C.ProxyAdapter
func (ssr *ShadowSocksR) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
if len(ssr.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(ssr.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
c, err := dialer.DialContext(ctx, "tcp", ssr.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", ssr.addr, err) return nil, fmt.Errorf("%s connect error: %w", ssr.addr, err)
} }
N.TCPKeepAlive(c)
defer func(c net.Conn) { defer func(c net.Conn) {
safeConnClose(c, err) safeConnClose(c, err)
@@ -80,16 +91,24 @@ func (ssr *ShadowSocksR) DialContext(ctx context.Context, metadata *C.Metadata)
} }
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (ssr *ShadowSocksR) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { func (ssr *ShadowSocksR) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
if err := ssr.ResolveUDP(ctx, metadata); err != nil { return ssr.ListenPacketWithDialer(ctx, dialer.NewDialer(ssr.Base.DialOptions(opts...)...), metadata)
return nil, err }
// ListenPacketWithDialer implements C.ProxyAdapter
func (ssr *ShadowSocksR) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) {
if len(ssr.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(ssr.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
} }
addr, err := resolveUDPAddr(ctx, "udp", ssr.addr, ssr.prefer) addr, err := resolveUDPAddrWithPrefer(ctx, "udp", ssr.addr, ssr.prefer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
pc, err := ssr.dialer.ListenPacket(ctx, "udp", "", addr.AddrPort()) pc, err := dialer.ListenPacket(ctx, "udp", "", addr.AddrPort())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -99,11 +118,9 @@ func (ssr *ShadowSocksR) ListenPacketContext(ctx context.Context, metadata *C.Me
return newPacketConn(&ssrPacketConn{EnhancePacketConn: epc, rAddr: addr}, ssr), nil return newPacketConn(&ssrPacketConn{EnhancePacketConn: epc, rAddr: addr}, ssr), nil
} }
// ProxyInfo implements C.ProxyAdapter // SupportWithDialer implements C.ProxyAdapter
func (ssr *ShadowSocksR) ProxyInfo() C.ProxyInfo { func (ssr *ShadowSocksR) SupportWithDialer() C.NetWork {
info := ssr.Base.ProxyInfo() return C.ALLNet
info.DialerProxy = ssr.option.DialerProxy
return info
} }
func NewShadowSocksR(option ShadowSocksROption) (*ShadowSocksR, error) { func NewShadowSocksR(option ShadowSocksROption) (*ShadowSocksR, error) {
@@ -118,7 +135,7 @@ func NewShadowSocksR(option ShadowSocksROption) (*ShadowSocksR, error) {
password := option.Password password := option.Password
coreCiph, err := core.PickCipher(cipher, nil, password) coreCiph, err := core.PickCipher(cipher, nil, password)
if err != nil { if err != nil {
return nil, fmt.Errorf("ssr %s cipher: %s initialize error: %w", addr, cipher, err) return nil, fmt.Errorf("ssr %s initialize error: %w", addr, err)
} }
var ( var (
ivSize int ivSize int
@@ -157,26 +174,23 @@ func NewShadowSocksR(option ShadowSocksROption) (*ShadowSocksR, error) {
return nil, fmt.Errorf("ssr %s initialize protocol error: %w", addr, err) return nil, fmt.Errorf("ssr %s initialize protocol error: %w", addr, err)
} }
outbound := &ShadowSocksR{ return &ShadowSocksR{
Base: &Base{ Base: &Base{
name: option.Name, name: option.Name,
addr: addr, addr: addr,
tp: C.ShadowsocksR, tp: C.ShadowsocksR,
pdName: option.ProviderName,
udp: option.UDP, udp: option.UDP,
tfo: option.TFO, tfo: option.TFO,
mpTcp: option.MPTCP, mpTcp: option.MPTCP,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: option.IPVersion, prefer: C.NewDNSPrefer(option.IPVersion),
}, },
option: &option, option: &option,
cipher: coreCiph, cipher: coreCiph,
obfs: obfs, obfs: obfs,
protocol: protocol, protocol: protocol,
} }, nil
outbound.dialer = option.NewDialer(outbound.DialOptions())
return outbound, nil
} }
type ssrPacketConn struct { type ssrPacketConn struct {
@@ -226,14 +240,13 @@ func (spc *ssrPacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr
return nil, nil, nil, errors.New("parse addr error") return nil, nil, nil, errors.New("parse addr error")
} }
udpAddr := _addr.UDPAddr() addr = _addr.UDPAddr()
if udpAddr == nil { if addr == nil {
if put != nil { if put != nil {
put() put()
} }
return nil, nil, nil, errors.New("parse addr error") return nil, nil, nil, errors.New("parse addr error")
} }
addr = udpAddr
data = data[len(_addr):] data = data[len(_addr):]
return return

View File

@@ -2,20 +2,26 @@ package outbound
import ( import (
"context" "context"
"errors"
"runtime"
N "github.com/metacubex/mihomo/common/net" CN "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer" "github.com/metacubex/mihomo/component/proxydialer"
"github.com/metacubex/mihomo/component/resolver"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/log"
mux "github.com/metacubex/sing-mux" mux "github.com/sagernet/sing-mux"
E "github.com/metacubex/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/metacubex/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
) )
type SingMux struct { type SingMux struct {
ProxyAdapter C.ProxyAdapter
base ProxyBase
client *mux.Client client *mux.Client
dialer proxydialer.SingDialer
onlyTcp bool onlyTcp bool
} }
@@ -37,21 +43,36 @@ type BrutalOption struct {
Down string `proxy:"down,omitempty"` Down string `proxy:"down,omitempty"`
} }
func (s *SingMux) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) { type ProxyBase interface {
DialOptions(opts ...dialer.Option) []dialer.Option
}
func (s *SingMux) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
options := s.base.DialOptions(opts...)
s.dialer.SetDialer(dialer.NewDialer(options...))
c, err := s.client.DialContext(ctx, "tcp", M.ParseSocksaddrHostPort(metadata.String(), metadata.DstPort)) c, err := s.client.DialContext(ctx, "tcp", M.ParseSocksaddrHostPort(metadata.String(), metadata.DstPort))
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewConn(c, s), err return NewConn(CN.NewRefConn(c, s), s.ProxyAdapter), err
} }
func (s *SingMux) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) { func (s *SingMux) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
if s.onlyTcp { if s.onlyTcp {
return s.ProxyAdapter.ListenPacketContext(ctx, metadata) return s.ProxyAdapter.ListenPacketContext(ctx, metadata, opts...)
} }
if err = s.ProxyAdapter.ResolveUDP(ctx, metadata); err != nil { options := s.base.DialOptions(opts...)
return nil, err s.dialer.SetDialer(dialer.NewDialer(options...))
// sing-mux use stream-oriented udp with a special address, so we need a net.UDPAddr
if !metadata.Resolved() {
ip, err := resolver.ResolveIP(ctx, metadata.Host)
if err != nil {
return nil, errors.New("can't resolve ip")
}
metadata.DstIP = ip
} }
pc, err := s.client.ListenPacket(ctx, M.SocksaddrFromNet(metadata.UDPAddr())) pc, err := s.client.ListenPacket(ctx, M.SocksaddrFromNet(metadata.UDPAddr()))
if err != nil { if err != nil {
return nil, err return nil, err
@@ -59,7 +80,7 @@ func (s *SingMux) ListenPacketContext(ctx context.Context, metadata *C.Metadata)
if pc == nil { if pc == nil {
return nil, E.New("packetConn is nil") return nil, E.New("packetConn is nil")
} }
return newPacketConn(N.NewThreadSafePacketConn(pc), s), nil return newPacketConn(CN.NewRefPacketConn(CN.NewThreadSafePacketConn(pc), s), s.ProxyAdapter), nil
} }
func (s *SingMux) SupportUDP() bool { func (s *SingMux) SupportUDP() bool {
@@ -76,25 +97,15 @@ func (s *SingMux) SupportUOT() bool {
return true return true
} }
func (s *SingMux) ProxyInfo() C.ProxyInfo { func closeSingMux(s *SingMux) {
info := s.ProxyAdapter.ProxyInfo() _ = s.client.Close()
info.SMUX = true
return info
} }
// Close implements C.ProxyAdapter func NewSingMux(option SingMuxOption, proxy C.ProxyAdapter, base ProxyBase) (C.ProxyAdapter, error) {
func (s *SingMux) Close() error {
if s.client != nil {
_ = s.client.Close()
}
return s.ProxyAdapter.Close()
}
func NewSingMux(option SingMuxOption, proxy ProxyAdapter) (ProxyAdapter, error) {
// TODO // TODO
// "TCP Brutal is only supported on Linux-based systems" // "TCP Brutal is only supported on Linux-based systems"
singDialer := proxydialer.NewSingDialer(proxydialer.New(proxy, option.Statistic)) singDialer := proxydialer.NewSingDialer(proxy, dialer.NewDialer(), option.Statistic)
client, err := mux.NewClient(mux.Options{ client, err := mux.NewClient(mux.Options{
Dialer: singDialer, Dialer: singDialer,
Logger: log.SingLogger, Logger: log.SingLogger,
@@ -114,8 +125,11 @@ func NewSingMux(option SingMuxOption, proxy ProxyAdapter) (ProxyAdapter, error)
} }
outbound := &SingMux{ outbound := &SingMux{
ProxyAdapter: proxy, ProxyAdapter: proxy,
base: base,
client: client, client: client,
dialer: singDialer,
onlyTcp: option.OnlyTcp, onlyTcp: option.OnlyTcp,
} }
runtime.SetFinalizer(outbound, closeSingMux)
return outbound, nil return outbound, nil
} }

View File

@@ -8,6 +8,8 @@ import (
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/structure" "github.com/metacubex/mihomo/common/structure"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
obfs "github.com/metacubex/mihomo/transport/simple-obfs" obfs "github.com/metacubex/mihomo/transport/simple-obfs"
"github.com/metacubex/mihomo/transport/snell" "github.com/metacubex/mihomo/transport/snell"
@@ -40,7 +42,7 @@ type streamOption struct {
obfsOption *simpleObfsOption obfsOption *simpleObfsOption
} }
func snellStreamConn(c net.Conn, option streamOption) *snell.Snell { func streamConn(c net.Conn, option streamOption) *snell.Snell {
switch option.obfsOption.Mode { switch option.obfsOption.Mode {
case "tls": case "tls":
c = obfs.NewTLSObfs(c, option.obfsOption.Host) c = obfs.NewTLSObfs(c, option.obfsOption.Host)
@@ -53,44 +55,46 @@ func snellStreamConn(c net.Conn, option streamOption) *snell.Snell {
// StreamConnContext implements C.ProxyAdapter // StreamConnContext implements C.ProxyAdapter
func (s *Snell) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) { func (s *Snell) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) {
c = snellStreamConn(c, streamOption{s.psk, s.version, s.addr, s.obfsOption}) c = streamConn(c, streamOption{s.psk, s.version, s.addr, s.obfsOption})
err := s.writeHeaderContext(ctx, c, metadata) if metadata.NetWork == C.UDP {
err := snell.WriteUDPHeader(c, s.version)
return c, err
}
err := snell.WriteHeader(c, metadata.String(), uint(metadata.DstPort), s.version)
return c, err return c, err
} }
func (s *Snell) writeHeaderContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (err error) {
if ctx.Done() != nil {
done := N.SetupContextForConn(ctx, c)
defer done(&err)
}
if metadata.NetWork == C.UDP {
err = snell.WriteUDPHeader(c, s.version)
return
}
err = snell.WriteHeader(c, metadata.String(), uint(metadata.DstPort), s.version)
return
}
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (s *Snell) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) { func (s *Snell) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
if s.version == snell.Version2 { if s.version == snell.Version2 && len(opts) == 0 {
c, err := s.pool.Get() c, err := s.pool.Get()
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err = s.writeHeaderContext(ctx, c, metadata); err != nil { if err = snell.WriteHeader(c, metadata.String(), uint(metadata.DstPort), s.version); err != nil {
_ = c.Close() c.Close()
return nil, err return nil, err
} }
return NewConn(c, s), err return NewConn(c, s), err
} }
c, err := s.dialer.DialContext(ctx, "tcp", s.addr) return s.DialContextWithDialer(ctx, dialer.NewDialer(s.Base.DialOptions(opts...)...), metadata)
}
// DialContextWithDialer implements C.ProxyAdapter
func (s *Snell) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
if len(s.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(s.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
c, err := dialer.DialContext(ctx, "tcp", s.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", s.addr, err) return nil, fmt.Errorf("%s connect error: %w", s.addr, err)
} }
N.TCPKeepAlive(c)
defer func(c net.Conn) { defer func(c net.Conn) {
safeConnClose(c, err) safeConnClose(c, err)
@@ -101,34 +105,45 @@ func (s *Snell) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn
} }
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (s *Snell) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { func (s *Snell) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
return s.ListenPacketWithDialer(ctx, dialer.NewDialer(s.Base.DialOptions(opts...)...), metadata)
}
// ListenPacketWithDialer implements C.ProxyAdapter
func (s *Snell) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (C.PacketConn, error) {
var err error var err error
if err = s.ResolveUDP(ctx, metadata); err != nil { if len(s.option.DialerProxy) > 0 {
return nil, err dialer, err = proxydialer.NewByName(s.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
} }
c, err := s.dialer.DialContext(ctx, "tcp", s.addr) c, err := dialer.DialContext(ctx, "tcp", s.addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
N.TCPKeepAlive(c)
c = streamConn(c, streamOption{s.psk, s.version, s.addr, s.obfsOption})
c, err = s.StreamConnContext(ctx, c, metadata) err = snell.WriteUDPHeader(c, s.version)
if err != nil {
return nil, err
}
pc := snell.PacketConn(c) pc := snell.PacketConn(c)
return newPacketConn(pc, s), nil return newPacketConn(pc, s), nil
} }
// SupportWithDialer implements C.ProxyAdapter
func (s *Snell) SupportWithDialer() C.NetWork {
return C.ALLNet
}
// SupportUOT implements C.ProxyAdapter // SupportUOT implements C.ProxyAdapter
func (s *Snell) SupportUOT() bool { func (s *Snell) SupportUOT() bool {
return true return true
} }
// ProxyInfo implements C.ProxyAdapter
func (s *Snell) ProxyInfo() C.ProxyInfo {
info := s.Base.ProxyInfo()
info.DialerProxy = s.option.DialerProxy
return info
}
func NewSnell(option SnellOption) (*Snell, error) { func NewSnell(option SnellOption) (*Snell, error) {
addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port)) addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port))
psk := []byte(option.Psk) psk := []byte(option.Psk)
@@ -165,29 +180,36 @@ func NewSnell(option SnellOption) (*Snell, error) {
name: option.Name, name: option.Name,
addr: addr, addr: addr,
tp: C.Snell, tp: C.Snell,
pdName: option.ProviderName,
udp: option.UDP, udp: option.UDP,
tfo: option.TFO, tfo: option.TFO,
mpTcp: option.MPTCP, mpTcp: option.MPTCP,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: option.IPVersion, prefer: C.NewDNSPrefer(option.IPVersion),
}, },
option: &option, option: &option,
psk: psk, psk: psk,
obfsOption: obfsOption, obfsOption: obfsOption,
version: option.Version, version: option.Version,
} }
s.dialer = option.NewDialer(s.DialOptions())
if option.Version == snell.Version2 { if option.Version == snell.Version2 {
s.pool = snell.NewPool(func(ctx context.Context) (*snell.Snell, error) { s.pool = snell.NewPool(func(ctx context.Context) (*snell.Snell, error) {
c, err := s.dialer.DialContext(ctx, "tcp", addr) var err error
var cDialer C.Dialer = dialer.NewDialer(s.Base.DialOptions()...)
if len(s.option.DialerProxy) > 0 {
cDialer, err = proxydialer.NewByName(s.option.DialerProxy, cDialer)
if err != nil {
return nil, err
}
}
c, err := cDialer.DialContext(ctx, "tcp", addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return snellStreamConn(c, streamOption{psk, option.Version, addr, obfsOption}), nil N.TCPKeepAlive(c)
return streamConn(c, streamOption{psk, option.Version, addr, obfsOption}), nil
}) })
} }
return s, nil return s, nil

View File

@@ -2,6 +2,7 @@ package outbound
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -11,10 +12,10 @@ import (
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/ca" "github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/socks5" "github.com/metacubex/mihomo/transport/socks5"
"github.com/metacubex/tls"
) )
type Socks5 struct { type Socks5 struct {
@@ -38,8 +39,6 @@ type Socks5Option struct {
UDP bool `proxy:"udp,omitempty"` UDP bool `proxy:"udp,omitempty"`
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
Fingerprint string `proxy:"fingerprint,omitempty"` Fingerprint string `proxy:"fingerprint,omitempty"`
Certificate string `proxy:"certificate,omitempty"`
PrivateKey string `proxy:"private-key,omitempty"`
} }
// StreamConnContext implements C.ProxyAdapter // StreamConnContext implements C.ProxyAdapter
@@ -60,18 +59,30 @@ func (ss *Socks5) StreamConnContext(ctx context.Context, c net.Conn, metadata *C
Password: ss.pass, Password: ss.pass,
} }
} }
if _, err := ss.clientHandshakeContext(ctx, c, serializesSocksAddr(metadata), socks5.CmdConnect, user); err != nil { if _, err := socks5.ClientHandshake(c, serializesSocksAddr(metadata), socks5.CmdConnect, user); err != nil {
return nil, err return nil, err
} }
return c, nil return c, nil
} }
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (ss *Socks5) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) { func (ss *Socks5) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
c, err := ss.dialer.DialContext(ctx, "tcp", ss.addr) return ss.DialContextWithDialer(ctx, dialer.NewDialer(ss.Base.DialOptions(opts...)...), metadata)
}
// DialContextWithDialer implements C.ProxyAdapter
func (ss *Socks5) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
if len(ss.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(ss.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
c, err := dialer.DialContext(ctx, "tcp", ss.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", ss.addr, err) return nil, fmt.Errorf("%s connect error: %w", ss.addr, err)
} }
N.TCPKeepAlive(c)
defer func(c net.Conn) { defer func(c net.Conn) {
safeConnClose(c, err) safeConnClose(c, err)
@@ -85,12 +96,21 @@ func (ss *Socks5) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Co
return NewConn(c, ss), nil return NewConn(c, ss), nil
} }
// SupportWithDialer implements C.ProxyAdapter
func (ss *Socks5) SupportWithDialer() C.NetWork {
return C.TCP
}
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (ss *Socks5) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) { func (ss *Socks5) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
if err = ss.ResolveUDP(ctx, metadata); err != nil { var cDialer C.Dialer = dialer.NewDialer(ss.Base.DialOptions(opts...)...)
return nil, err if len(ss.option.DialerProxy) > 0 {
cDialer, err = proxydialer.NewByName(ss.option.DialerProxy, cDialer)
if err != nil {
return nil, err
}
} }
c, err := ss.dialer.DialContext(ctx, "tcp", ss.addr) c, err := cDialer.DialContext(ctx, "tcp", ss.addr)
if err != nil { if err != nil {
err = fmt.Errorf("%s connect error: %w", ss.addr, err) err = fmt.Errorf("%s connect error: %w", ss.addr, err)
return return
@@ -108,6 +128,7 @@ func (ss *Socks5) ListenPacketContext(ctx context.Context, metadata *C.Metadata)
safeConnClose(c, err) safeConnClose(c, err)
}(c) }(c)
N.TCPKeepAlive(c)
var user *socks5.User var user *socks5.User
if ss.user != "" { if ss.user != "" {
user = &socks5.User{ user = &socks5.User{
@@ -117,7 +138,7 @@ func (ss *Socks5) ListenPacketContext(ctx context.Context, metadata *C.Metadata)
} }
udpAssocateAddr := socks5.AddrFromStdAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0)) udpAssocateAddr := socks5.AddrFromStdAddrPort(netip.AddrPortFrom(netip.IPv4Unspecified(), 0))
bindAddr, err := ss.clientHandshakeContext(ctx, c, udpAssocateAddr, socks5.CmdUDPAssociate, user) bindAddr, err := socks5.ClientHandshake(c, udpAssocateAddr, socks5.CmdUDPAssociate, user)
if err != nil { if err != nil {
err = fmt.Errorf("client hanshake error: %w", err) err = fmt.Errorf("client hanshake error: %w", err)
return return
@@ -129,7 +150,7 @@ func (ss *Socks5) ListenPacketContext(ctx context.Context, metadata *C.Metadata)
err = errors.New("invalid UDP bind address") err = errors.New("invalid UDP bind address")
return return
} else if bindUDPAddr.IP.IsUnspecified() { } else if bindUDPAddr.IP.IsUnspecified() {
serverAddr, err := resolveUDPAddr(ctx, "udp", ss.Addr(), C.IPv4Prefer) serverAddr, err := resolveUDPAddr(ctx, "udp", ss.Addr())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -137,7 +158,7 @@ func (ss *Socks5) ListenPacketContext(ctx context.Context, metadata *C.Metadata)
bindUDPAddr.IP = serverAddr.IP bindUDPAddr.IP = serverAddr.IP
} }
pc, err := ss.dialer.ListenPacket(ctx, "udp", "", bindUDPAddr.AddrPort()) pc, err := cDialer.ListenPacket(ctx, "udp", "", bindUDPAddr.AddrPort())
if err != nil { if err != nil {
return return
} }
@@ -153,51 +174,32 @@ func (ss *Socks5) ListenPacketContext(ctx context.Context, metadata *C.Metadata)
return newPacketConn(&socksPacketConn{PacketConn: pc, rAddr: bindUDPAddr, tcpConn: c}, ss), nil return newPacketConn(&socksPacketConn{PacketConn: pc, rAddr: bindUDPAddr, tcpConn: c}, ss), nil
} }
// ProxyInfo implements C.ProxyAdapter
func (ss *Socks5) ProxyInfo() C.ProxyInfo {
info := ss.Base.ProxyInfo()
info.DialerProxy = ss.option.DialerProxy
return info
}
func (ss *Socks5) clientHandshakeContext(ctx context.Context, c net.Conn, addr socks5.Addr, command socks5.Command, user *socks5.User) (_ socks5.Addr, err error) {
if ctx.Done() != nil {
done := N.SetupContextForConn(ctx, c)
defer done(&err)
}
return socks5.ClientHandshake(c, addr, command, user)
}
func NewSocks5(option Socks5Option) (*Socks5, error) { func NewSocks5(option Socks5Option) (*Socks5, error) {
var tlsConfig *tls.Config var tlsConfig *tls.Config
if option.TLS { if option.TLS {
tlsConfig = &tls.Config{
InsecureSkipVerify: option.SkipCertVerify,
ServerName: option.Server,
}
var err error var err error
tlsConfig, err = ca.GetTLSConfig(ca.Option{ tlsConfig, err = ca.GetSpecifiedFingerprintTLSConfig(tlsConfig, option.Fingerprint)
TLSConfig: &tls.Config{
InsecureSkipVerify: option.SkipCertVerify,
ServerName: option.Server,
},
Fingerprint: option.Fingerprint,
Certificate: option.Certificate,
PrivateKey: option.PrivateKey,
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
outbound := &Socks5{ return &Socks5{
Base: &Base{ Base: &Base{
name: option.Name, name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)), addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.Socks5, tp: C.Socks5,
pdName: option.ProviderName,
udp: option.UDP, udp: option.UDP,
tfo: option.TFO, tfo: option.TFO,
mpTcp: option.MPTCP, mpTcp: option.MPTCP,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: option.IPVersion, prefer: C.NewDNSPrefer(option.IPVersion),
}, },
option: &option, option: &option,
user: option.UserName, user: option.UserName,
@@ -205,9 +207,7 @@ func NewSocks5(option Socks5Option) (*Socks5, error) {
tls: option.TLS, tls: option.TLS,
skipCertVerify: option.SkipCertVerify, skipCertVerify: option.SkipCertVerify,
tlsConfig: tlsConfig, tlsConfig: tlsConfig,
} }, nil
outbound.dialer = option.NewDialer(outbound.DialOptions())
return outbound, nil
} }
type socksPacketConn struct { type socksPacketConn struct {

View File

@@ -7,14 +7,17 @@ import (
"fmt" "fmt"
"net" "net"
"os" "os"
"runtime"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/randv2" "github.com/zhangyunhao116/fastrand"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@@ -22,10 +25,7 @@ type Ssh struct {
*Base *Base
option *SshOption option *SshOption
client *sshClient // using a standalone struct to avoid its inner loop invalidate the Finalizer
config *ssh.ClientConfig
client *ssh.Client
cMutex sync.Mutex
} }
type SshOption struct { type SshOption struct {
@@ -41,8 +41,15 @@ type SshOption struct {
HostKeyAlgorithms []string `proxy:"host-key-algorithms,omitempty"` HostKeyAlgorithms []string `proxy:"host-key-algorithms,omitempty"`
} }
func (s *Ssh) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) { func (s *Ssh) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
client, err := s.connect(ctx, s.addr) var cDialer C.Dialer = dialer.NewDialer(s.Base.DialOptions(opts...)...)
if len(s.option.DialerProxy) > 0 {
cDialer, err = proxydialer.NewByName(s.option.DialerProxy, cDialer)
if err != nil {
return nil, err
}
}
client, err := s.client.connect(ctx, cDialer, s.addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -51,19 +58,26 @@ func (s *Ssh) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn,
return nil, err return nil, err
} }
return NewConn(c, s), nil return NewConn(N.NewRefConn(c, s), s), nil
} }
func (s *Ssh) connect(ctx context.Context, addr string) (client *ssh.Client, err error) { type sshClient struct {
config *ssh.ClientConfig
client *ssh.Client
cMutex sync.Mutex
}
func (s *sshClient) connect(ctx context.Context, cDialer C.Dialer, addr string) (client *ssh.Client, err error) {
s.cMutex.Lock() s.cMutex.Lock()
defer s.cMutex.Unlock() defer s.cMutex.Unlock()
if s.client != nil { if s.client != nil {
return s.client, nil return s.client, nil
} }
c, err := s.dialer.DialContext(ctx, "tcp", addr) c, err := cDialer.DialContext(ctx, "tcp", addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
N.TCPKeepAlive(c)
defer func(c net.Conn) { defer func(c net.Conn) {
safeConnClose(c, err) safeConnClose(c, err)
@@ -95,15 +109,7 @@ func (s *Ssh) connect(ctx context.Context, addr string) (client *ssh.Client, err
return client, nil return client, nil
} }
// ProxyInfo implements C.ProxyAdapter func (s *sshClient) Close() error {
func (s *Ssh) ProxyInfo() C.ProxyInfo {
info := s.Base.ProxyInfo()
info.DialerProxy = s.option.DialerProxy
return info
}
// Close implements C.ProxyAdapter
func (s *Ssh) Close() error {
s.cMutex.Lock() s.cMutex.Lock()
defer s.cMutex.Unlock() defer s.cMutex.Unlock()
if s.client != nil { if s.client != nil {
@@ -112,6 +118,10 @@ func (s *Ssh) Close() error {
return nil return nil
} }
func closeSsh(s *Ssh) {
_ = s.client.Close()
}
func NewSsh(option SshOption) (*Ssh, error) { func NewSsh(option SshOption) (*Ssh, error) {
addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port)) addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port))
@@ -127,11 +137,7 @@ func NewSsh(option SshOption) (*Ssh, error) {
if strings.Contains(option.PrivateKey, "PRIVATE KEY") { if strings.Contains(option.PrivateKey, "PRIVATE KEY") {
b = []byte(option.PrivateKey) b = []byte(option.PrivateKey)
} else { } else {
path := C.Path.Resolve(option.PrivateKey) b, err = os.ReadFile(C.Path.Resolve(option.PrivateKey))
if !C.Path.IsSafePath(path) {
return nil, C.Path.ErrNotSafePath(path)
}
b, err = os.ReadFile(path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -174,10 +180,10 @@ func NewSsh(option SshOption) (*Ssh, error) {
} }
version := "SSH-2.0-OpenSSH_" version := "SSH-2.0-OpenSSH_"
if randv2.IntN(2) == 0 { if fastrand.Intn(2) == 0 {
version += "7." + strconv.Itoa(randv2.IntN(10)) version += "7." + strconv.Itoa(fastrand.Intn(10))
} else { } else {
version += "8." + strconv.Itoa(randv2.IntN(9)) version += "8." + strconv.Itoa(fastrand.Intn(9))
} }
config.ClientVersion = version config.ClientVersion = version
@@ -186,15 +192,17 @@ func NewSsh(option SshOption) (*Ssh, error) {
name: option.Name, name: option.Name,
addr: addr, addr: addr,
tp: C.Ssh, tp: C.Ssh,
pdName: option.ProviderName,
udp: false, udp: false,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: option.IPVersion, prefer: C.NewDNSPrefer(option.IPVersion),
}, },
option: &option, option: &option,
config: &config, client: &sshClient{
config: &config,
},
} }
outbound.dialer = option.NewDialer(outbound.DialOptions()) runtime.SetFinalizer(outbound, closeSsh)
return outbound, nil return outbound, nil
} }

View File

@@ -1,405 +0,0 @@
package outbound
import (
"context"
"fmt"
"net"
"strconv"
"strings"
"sync"
N "github.com/metacubex/mihomo/common/net"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/sudoku"
)
type Sudoku struct {
*Base
option *SudokuOption
baseConf sudoku.ProtocolConfig
httpMaskMu sync.Mutex
httpMaskClient *sudoku.HTTPMaskTunnelClient
muxMu sync.Mutex
muxClient *sudoku.MultiplexClient
}
type SudokuOption struct {
BasicOption
Name string `proxy:"name"`
Server string `proxy:"server"`
Port int `proxy:"port"`
Key string `proxy:"key"`
AEADMethod string `proxy:"aead-method,omitempty"`
PaddingMin *int `proxy:"padding-min,omitempty"`
PaddingMax *int `proxy:"padding-max,omitempty"`
TableType string `proxy:"table-type,omitempty"` // "prefer_ascii" or "prefer_entropy"
EnablePureDownlink *bool `proxy:"enable-pure-downlink,omitempty"`
HTTPMask bool `proxy:"http-mask,omitempty"`
HTTPMaskMode string `proxy:"http-mask-mode,omitempty"` // "legacy" (default), "stream", "poll", "auto"
HTTPMaskTLS bool `proxy:"http-mask-tls,omitempty"` // only for http-mask-mode stream/poll/auto
HTTPMaskHost string `proxy:"http-mask-host,omitempty"` // optional Host/SNI override (domain or domain:port)
PathRoot string `proxy:"path-root,omitempty"` // optional first-level path prefix for HTTP tunnel endpoints
HTTPMaskMultiplex string `proxy:"http-mask-multiplex,omitempty"` // "off" (default), "auto" (reuse h1/h2), "on" (single tunnel, multi-target)
CustomTable string `proxy:"custom-table,omitempty"` // optional custom byte layout, e.g. xpxvvpvv
CustomTables []string `proxy:"custom-tables,omitempty"` // optional table rotation patterns, overrides custom-table when non-empty
}
// DialContext implements C.ProxyAdapter
func (s *Sudoku) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) {
cfg, err := s.buildConfig(metadata)
if err != nil {
return nil, err
}
muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex)
if muxMode == "on" && !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) {
stream, muxErr := s.dialMultiplex(ctx, cfg.TargetAddress)
if muxErr == nil {
return NewConn(stream, s), nil
}
return nil, muxErr
}
c, err := s.dialAndHandshake(ctx, cfg)
if err != nil {
return nil, err
}
defer func() { safeConnClose(c, err) }()
addrBuf, err := sudoku.EncodeAddress(cfg.TargetAddress)
if err != nil {
return nil, fmt.Errorf("encode target address failed: %w", err)
}
if _, err = c.Write(addrBuf); err != nil {
return nil, fmt.Errorf("send target address failed: %w", err)
}
return NewConn(c, s), nil
}
// ListenPacketContext implements C.ProxyAdapter
func (s *Sudoku) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) {
if err := s.ResolveUDP(ctx, metadata); err != nil {
return nil, err
}
cfg, err := s.buildConfig(metadata)
if err != nil {
return nil, err
}
c, err := s.dialAndHandshake(ctx, cfg)
if err != nil {
return nil, err
}
if err = sudoku.WritePreface(c); err != nil {
_ = c.Close()
return nil, fmt.Errorf("send uot preface failed: %w", err)
}
return newPacketConn(N.NewThreadSafePacketConn(sudoku.NewUoTPacketConn(c)), s), nil
}
// SupportUOT implements C.ProxyAdapter
func (s *Sudoku) SupportUOT() bool {
return true
}
// ProxyInfo implements C.ProxyAdapter
func (s *Sudoku) ProxyInfo() C.ProxyInfo {
info := s.Base.ProxyInfo()
info.DialerProxy = s.option.DialerProxy
return info
}
func (s *Sudoku) buildConfig(metadata *C.Metadata) (*sudoku.ProtocolConfig, error) {
if metadata == nil || metadata.DstPort == 0 || !metadata.Valid() {
return nil, fmt.Errorf("invalid metadata for sudoku outbound")
}
cfg := s.baseConf
cfg.TargetAddress = metadata.RemoteAddress()
if err := cfg.ValidateClient(); err != nil {
return nil, err
}
return &cfg, nil
}
func NewSudoku(option SudokuOption) (*Sudoku, error) {
if option.Server == "" {
return nil, fmt.Errorf("server is required")
}
if option.Port <= 0 || option.Port > 65535 {
return nil, fmt.Errorf("invalid port: %d", option.Port)
}
if option.Key == "" {
return nil, fmt.Errorf("key is required")
}
tableType := strings.ToLower(option.TableType)
if tableType == "" {
tableType = "prefer_ascii"
}
if tableType != "prefer_ascii" && tableType != "prefer_entropy" {
return nil, fmt.Errorf("table-type must be prefer_ascii or prefer_entropy")
}
defaultConf := sudoku.DefaultConfig()
paddingMin := defaultConf.PaddingMin
paddingMax := defaultConf.PaddingMax
if option.PaddingMin != nil {
paddingMin = *option.PaddingMin
}
if option.PaddingMax != nil {
paddingMax = *option.PaddingMax
}
if option.PaddingMin == nil && option.PaddingMax != nil && paddingMax < paddingMin {
paddingMin = paddingMax
}
if option.PaddingMax == nil && option.PaddingMin != nil && paddingMax < paddingMin {
paddingMax = paddingMin
}
enablePureDownlink := defaultConf.EnablePureDownlink
if option.EnablePureDownlink != nil {
enablePureDownlink = *option.EnablePureDownlink
}
baseConf := sudoku.ProtocolConfig{
ServerAddress: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
Key: option.Key,
AEADMethod: defaultConf.AEADMethod,
PaddingMin: paddingMin,
PaddingMax: paddingMax,
EnablePureDownlink: enablePureDownlink,
HandshakeTimeoutSeconds: defaultConf.HandshakeTimeoutSeconds,
DisableHTTPMask: !option.HTTPMask,
HTTPMaskMode: defaultConf.HTTPMaskMode,
HTTPMaskTLSEnabled: option.HTTPMaskTLS,
HTTPMaskHost: option.HTTPMaskHost,
HTTPMaskPathRoot: strings.TrimSpace(option.PathRoot),
HTTPMaskMultiplex: defaultConf.HTTPMaskMultiplex,
}
if option.HTTPMaskMode != "" {
baseConf.HTTPMaskMode = option.HTTPMaskMode
}
if option.HTTPMaskMultiplex != "" {
baseConf.HTTPMaskMultiplex = option.HTTPMaskMultiplex
}
tables, err := sudoku.NewTablesWithCustomPatterns(sudoku.ClientAEADSeed(option.Key), tableType, option.CustomTable, option.CustomTables)
if err != nil {
return nil, fmt.Errorf("build table(s) failed: %w", err)
}
if len(tables) == 1 {
baseConf.Table = tables[0]
} else {
baseConf.Tables = tables
}
if option.AEADMethod != "" {
baseConf.AEADMethod = option.AEADMethod
}
outbound := &Sudoku{
Base: &Base{
name: option.Name,
addr: baseConf.ServerAddress,
tp: C.Sudoku,
pdName: option.ProviderName,
udp: true,
tfo: option.TFO,
mpTcp: option.MPTCP,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: option.IPVersion,
},
option: &option,
baseConf: baseConf,
}
outbound.dialer = option.NewDialer(outbound.DialOptions())
return outbound, nil
}
func (s *Sudoku) Close() error {
s.resetMuxClient()
s.resetHTTPMaskClient()
return s.Base.Close()
}
func normalizeHTTPMaskMultiplex(mode string) string {
switch strings.ToLower(strings.TrimSpace(mode)) {
case "", "off":
return "off"
case "auto":
return "auto"
case "on":
return "on"
default:
return "off"
}
}
func httpTunnelModeEnabled(mode string) bool {
switch strings.ToLower(strings.TrimSpace(mode)) {
case "stream", "poll", "auto":
return true
default:
return false
}
}
func (s *Sudoku) dialAndHandshake(ctx context.Context, cfg *sudoku.ProtocolConfig) (_ net.Conn, err error) {
if cfg == nil {
return nil, fmt.Errorf("config is required")
}
handshakeCfg := *cfg
if !handshakeCfg.DisableHTTPMask && httpTunnelModeEnabled(handshakeCfg.HTTPMaskMode) {
handshakeCfg.DisableHTTPMask = true
}
upgrade := func(raw net.Conn) (net.Conn, error) {
return sudoku.ClientHandshake(raw, &handshakeCfg)
}
var (
c net.Conn
handshakeDone bool
)
if !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) {
muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex)
switch muxMode {
case "auto", "on":
client, errX := s.getOrCreateHTTPMaskClient(cfg)
if errX != nil {
return nil, errX
}
c, err = client.Dial(ctx, upgrade)
default:
c, err = sudoku.DialHTTPMaskTunnel(ctx, cfg.ServerAddress, cfg, s.dialer.DialContext, upgrade)
}
if err == nil && c != nil {
handshakeDone = true
}
}
if c == nil && err == nil {
c, err = s.dialer.DialContext(ctx, "tcp", s.addr)
}
if err != nil {
return nil, fmt.Errorf("%s connect error: %w", s.addr, err)
}
defer func() { safeConnClose(c, err) }()
if ctx.Done() != nil {
done := N.SetupContextForConn(ctx, c)
defer done(&err)
}
if !handshakeDone {
c, err = sudoku.ClientHandshake(c, &handshakeCfg)
if err != nil {
return nil, err
}
}
return c, nil
}
func (s *Sudoku) dialMultiplex(ctx context.Context, targetAddress string) (net.Conn, error) {
for attempt := 0; attempt < 2; attempt++ {
client, err := s.getOrCreateMuxClient(ctx)
if err != nil {
return nil, err
}
stream, err := client.Dial(ctx, targetAddress)
if err != nil {
s.resetMuxClient()
continue
}
return stream, nil
}
return nil, fmt.Errorf("multiplex open stream failed")
}
func (s *Sudoku) getOrCreateMuxClient(ctx context.Context) (*sudoku.MultiplexClient, error) {
if s == nil {
return nil, fmt.Errorf("nil adapter")
}
s.muxMu.Lock()
if s.muxClient != nil && !s.muxClient.IsClosed() {
client := s.muxClient
s.muxMu.Unlock()
return client, nil
}
s.muxMu.Unlock()
s.muxMu.Lock()
defer s.muxMu.Unlock()
if s.muxClient != nil && !s.muxClient.IsClosed() {
return s.muxClient, nil
}
baseCfg := s.baseConf
baseConn, err := s.dialAndHandshake(ctx, &baseCfg)
if err != nil {
return nil, err
}
client, err := sudoku.StartMultiplexClient(baseConn)
if err != nil {
_ = baseConn.Close()
return nil, err
}
s.muxClient = client
return client, nil
}
func (s *Sudoku) resetMuxClient() {
s.muxMu.Lock()
defer s.muxMu.Unlock()
if s.muxClient != nil {
_ = s.muxClient.Close()
s.muxClient = nil
}
}
func (s *Sudoku) getOrCreateHTTPMaskClient(cfg *sudoku.ProtocolConfig) (*sudoku.HTTPMaskTunnelClient, error) {
if s == nil {
return nil, fmt.Errorf("nil adapter")
}
if cfg == nil {
return nil, fmt.Errorf("config is required")
}
s.httpMaskMu.Lock()
defer s.httpMaskMu.Unlock()
if s.httpMaskClient != nil {
return s.httpMaskClient, nil
}
c, err := sudoku.NewHTTPMaskTunnelClient(cfg.ServerAddress, cfg, s.dialer.DialContext)
if err != nil {
return nil, err
}
s.httpMaskClient = c
return c, nil
}
func (s *Sudoku) resetHTTPMaskClient() {
s.httpMaskMu.Lock()
defer s.httpMaskMu.Unlock()
if s.httpMaskClient != nil {
s.httpMaskClient.CloseIdleConnections()
s.httpMaskClient = nil
}
}

View File

@@ -2,29 +2,26 @@ package outbound
import ( import (
"context" "context"
"errors" "crypto/tls"
"fmt" "fmt"
"net" "net"
"net/http"
"strconv" "strconv"
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/ca" "github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/ech" "github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
tlsC "github.com/metacubex/mihomo/component/tls" tlsC "github.com/metacubex/mihomo/component/tls"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/gun" "github.com/metacubex/mihomo/transport/gun"
"github.com/metacubex/mihomo/transport/shadowsocks/core"
"github.com/metacubex/mihomo/transport/trojan" "github.com/metacubex/mihomo/transport/trojan"
"github.com/metacubex/mihomo/transport/vmess"
"github.com/metacubex/http"
"github.com/metacubex/tls"
) )
type Trojan struct { type Trojan struct {
*Base *Base
option *TrojanOption instance *trojan.Trojan
hexPassword [trojan.KeyLength]byte option *TrojanOption
// for gun mux // for gun mux
gunTLSConfig *tls.Config gunTLSConfig *tls.Config
@@ -32,9 +29,6 @@ type Trojan struct {
transport *gun.TransportWrap transport *gun.TransportWrap
realityConfig *tlsC.RealityConfig realityConfig *tlsC.RealityConfig
echConfig *ech.Config
ssCipher core.Cipher
} }
type TrojanOption struct { type TrojanOption struct {
@@ -47,41 +41,23 @@ type TrojanOption struct {
SNI string `proxy:"sni,omitempty"` SNI string `proxy:"sni,omitempty"`
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
Fingerprint string `proxy:"fingerprint,omitempty"` Fingerprint string `proxy:"fingerprint,omitempty"`
Certificate string `proxy:"certificate,omitempty"`
PrivateKey string `proxy:"private-key,omitempty"`
UDP bool `proxy:"udp,omitempty"` UDP bool `proxy:"udp,omitempty"`
Network string `proxy:"network,omitempty"` Network string `proxy:"network,omitempty"`
ECHOpts ECHOptions `proxy:"ech-opts,omitempty"`
RealityOpts RealityOptions `proxy:"reality-opts,omitempty"` RealityOpts RealityOptions `proxy:"reality-opts,omitempty"`
GrpcOpts GrpcOptions `proxy:"grpc-opts,omitempty"` GrpcOpts GrpcOptions `proxy:"grpc-opts,omitempty"`
WSOpts WSOptions `proxy:"ws-opts,omitempty"` WSOpts WSOptions `proxy:"ws-opts,omitempty"`
SSOpts TrojanSSOption `proxy:"ss-opts,omitempty"`
ClientFingerprint string `proxy:"client-fingerprint,omitempty"` ClientFingerprint string `proxy:"client-fingerprint,omitempty"`
} }
// TrojanSSOption from https://github.com/p4gefau1t/trojan-go/blob/v0.10.6/tunnel/shadowsocks/config.go#L5 func (t *Trojan) plainStream(ctx context.Context, c net.Conn) (net.Conn, error) {
type TrojanSSOption struct { if t.option.Network == "ws" {
Enabled bool `proxy:"enabled,omitempty"`
Method string `proxy:"method,omitempty"`
Password string `proxy:"password,omitempty"`
}
// StreamConnContext implements C.ProxyAdapter
func (t *Trojan) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ net.Conn, err error) {
switch t.option.Network {
case "ws":
host, port, _ := net.SplitHostPort(t.addr) host, port, _ := net.SplitHostPort(t.addr)
wsOpts := &trojan.WebsocketOption{
wsOpts := &vmess.WebsocketConfig{
Host: host, Host: host,
Port: port, Port: port,
Path: t.option.WSOpts.Path, Path: t.option.WSOpts.Path,
MaxEarlyData: t.option.WSOpts.MaxEarlyData,
EarlyDataHeaderName: t.option.WSOpts.EarlyDataHeaderName,
V2rayHttpUpgrade: t.option.WSOpts.V2rayHttpUpgrade, V2rayHttpUpgrade: t.option.WSOpts.V2rayHttpUpgrade,
V2rayHttpUpgradeFastOpen: t.option.WSOpts.V2rayHttpUpgradeFastOpen, V2rayHttpUpgradeFastOpen: t.option.WSOpts.V2rayHttpUpgradeFastOpen,
ClientFingerprint: t.option.ClientFingerprint,
ECHConfig: t.echConfig,
Headers: http.Header{}, Headers: http.Header{},
} }
@@ -95,110 +71,70 @@ func (t *Trojan) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.
} }
} }
alpn := trojan.DefaultWebsocketALPN return t.instance.StreamWebsocketConn(ctx, c, wsOpts)
if t.option.ALPN != nil { // structure's Decode will ensure value not nil when input has value even it was set an empty array
alpn = t.option.ALPN
}
wsOpts.TLS = true
wsOpts.TLSConfig, err = ca.GetTLSConfig(ca.Option{
TLSConfig: &tls.Config{
NextProtos: alpn,
MinVersion: tls.VersionTLS12,
InsecureSkipVerify: t.option.SkipCertVerify,
ServerName: t.option.SNI,
},
Fingerprint: t.option.Fingerprint,
Certificate: t.option.Certificate,
PrivateKey: t.option.PrivateKey,
})
if err != nil {
return nil, err
}
c, err = vmess.StreamWebsocketConn(ctx, c, wsOpts)
case "grpc":
c, err = gun.StreamGunWithConn(c, t.gunTLSConfig, t.gunConfig, t.echConfig, t.realityConfig)
default:
// default tcp network
// handle TLS
alpn := trojan.DefaultALPN
if t.option.ALPN != nil { // structure's Decode will ensure value not nil when input has value even it was set an empty array
alpn = t.option.ALPN
}
c, err = vmess.StreamTLSConn(ctx, c, &vmess.TLSConfig{
Host: t.option.SNI,
SkipCertVerify: t.option.SkipCertVerify,
FingerPrint: t.option.Fingerprint,
Certificate: t.option.Certificate,
PrivateKey: t.option.PrivateKey,
ClientFingerprint: t.option.ClientFingerprint,
NextProtos: alpn,
ECH: t.echConfig,
Reality: t.realityConfig,
})
} }
return t.instance.StreamConn(ctx, c)
}
// StreamConnContext implements C.ProxyAdapter
func (t *Trojan) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) {
var err error
if tlsC.HaveGlobalFingerprint() && len(t.option.ClientFingerprint) == 0 {
t.option.ClientFingerprint = tlsC.GetGlobalFingerprint()
}
if t.transport != nil {
c, err = gun.StreamGunWithConn(c, t.gunTLSConfig, t.gunConfig, t.realityConfig)
} else {
c, err = t.plainStream(ctx, c)
}
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", t.addr, err) return nil, fmt.Errorf("%s connect error: %w", t.addr, err)
} }
return t.streamConnContext(ctx, c, metadata)
}
func (t *Trojan) streamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ net.Conn, err error) {
if t.ssCipher != nil {
c = t.ssCipher.StreamConn(c)
}
if ctx.Done() != nil {
done := N.SetupContextForConn(ctx, c)
defer done(&err)
}
command := trojan.CommandTCP
if metadata.NetWork == C.UDP { if metadata.NetWork == C.UDP {
command = trojan.CommandUDP err = t.instance.WriteHeader(c, trojan.CommandUDP, serializesSocksAddr(metadata))
return c, err
} }
err = trojan.WriteHeader(c, t.hexPassword, command, serializesSocksAddr(metadata)) err = t.instance.WriteHeader(c, trojan.CommandTCP, serializesSocksAddr(metadata))
return c, err return c, err
} }
func (t *Trojan) writeHeaderContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (err error) {
if ctx.Done() != nil {
done := N.SetupContextForConn(ctx, c)
defer done(&err)
}
command := trojan.CommandTCP
if metadata.NetWork == C.UDP {
command = trojan.CommandUDP
}
err = trojan.WriteHeader(c, t.hexPassword, command, serializesSocksAddr(metadata))
return err
}
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (t *Trojan) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) { func (t *Trojan) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
var c net.Conn
// gun transport // gun transport
if t.transport != nil { if t.transport != nil && len(opts) == 0 {
c, err = gun.StreamGunWithTransport(t.transport, t.gunConfig) c, err := gun.StreamGunWithTransport(t.transport, t.gunConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func(c net.Conn) {
safeConnClose(c, err)
}(c)
c, err = t.streamConnContext(ctx, c, metadata) if err = t.instance.WriteHeader(c, trojan.CommandTCP, serializesSocksAddr(metadata)); err != nil {
if err != nil { c.Close()
return nil, err return nil, err
} }
return NewConn(c, t), nil return NewConn(c, t), nil
} }
c, err = t.dialer.DialContext(ctx, "tcp", t.addr) return t.DialContextWithDialer(ctx, dialer.NewDialer(t.Base.DialOptions(opts...)...), metadata)
}
// DialContextWithDialer implements C.ProxyAdapter
func (t *Trojan) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
if len(t.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(t.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
c, err := dialer.DialContext(ctx, "tcp", t.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", t.addr, err) return nil, fmt.Errorf("%s connect error: %w", t.addr, err)
} }
N.TCPKeepAlive(c)
defer func(c net.Conn) { defer func(c net.Conn) {
safeConnClose(c, err) safeConnClose(c, err)
@@ -213,15 +149,11 @@ func (t *Trojan) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Con
} }
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (t *Trojan) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) { func (t *Trojan) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
if err = t.ResolveUDP(ctx, metadata); err != nil {
return nil, err
}
var c net.Conn var c net.Conn
// grpc transport // grpc transport
if t.transport != nil { if t.transport != nil && len(opts) == 0 {
c, err = gun.StreamGunWithTransport(t.transport, t.gunConfig) c, err = gun.StreamGunWithTransport(t.transport, t.gunConfig)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", t.addr, err) return nil, fmt.Errorf("%s connect error: %w", t.addr, err)
@@ -229,31 +161,55 @@ func (t *Trojan) ListenPacketContext(ctx context.Context, metadata *C.Metadata)
defer func(c net.Conn) { defer func(c net.Conn) {
safeConnClose(c, err) safeConnClose(c, err)
}(c) }(c)
err = t.instance.WriteHeader(c, trojan.CommandUDP, serializesSocksAddr(metadata))
c, err = t.streamConnContext(ctx, c, metadata)
if err != nil { if err != nil {
return nil, err return nil, err
} }
pc := trojan.NewPacketConn(c) pc := t.instance.PacketConn(c)
return newPacketConn(pc, t), err return newPacketConn(pc, t), err
} }
if err = t.ResolveUDP(ctx, metadata); err != nil { return t.ListenPacketWithDialer(ctx, dialer.NewDialer(t.Base.DialOptions(opts...)...), metadata)
return nil, err }
// ListenPacketWithDialer implements C.ProxyAdapter
func (t *Trojan) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) {
if len(t.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(t.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
} }
c, err = t.dialer.DialContext(ctx, "tcp", t.addr) c, err := dialer.DialContext(ctx, "tcp", t.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %w", t.addr, err) return nil, fmt.Errorf("%s connect error: %w", t.addr, err)
} }
defer func(c net.Conn) { defer func(c net.Conn) {
safeConnClose(c, err) safeConnClose(c, err)
}(c) }(c)
c, err = t.StreamConnContext(ctx, c, metadata) N.TCPKeepAlive(c)
c, err = t.plainStream(ctx, c)
if err != nil {
return nil, fmt.Errorf("%s connect error: %w", t.addr, err)
}
err = t.instance.WriteHeader(c, trojan.CommandUDP, serializesSocksAddr(metadata))
if err != nil { if err != nil {
return nil, err return nil, err
} }
pc := trojan.NewPacketConn(c) pc := t.instance.PacketConn(c)
return newPacketConn(pc, t), err
}
// SupportWithDialer implements C.ProxyAdapter
func (t *Trojan) SupportWithDialer() C.NetWork {
return C.ALLNet
}
// ListenPacketOnStreamConn implements C.ProxyAdapter
func (t *Trojan) ListenPacketOnStreamConn(c net.Conn, metadata *C.Metadata) (_ C.PacketConn, err error) {
pc := t.instance.PacketConn(c)
return newPacketConn(pc, t), err return newPacketConn(pc, t), err
} }
@@ -262,26 +218,20 @@ func (t *Trojan) SupportUOT() bool {
return true return true
} }
// ProxyInfo implements C.ProxyAdapter
func (t *Trojan) ProxyInfo() C.ProxyInfo {
info := t.Base.ProxyInfo()
info.DialerProxy = t.option.DialerProxy
return info
}
// Close implements C.ProxyAdapter
func (t *Trojan) Close() error {
if t.transport != nil {
return t.transport.Close()
}
return nil
}
func NewTrojan(option TrojanOption) (*Trojan, error) { func NewTrojan(option TrojanOption) (*Trojan, error) {
addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port)) addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port))
if option.SNI == "" { tOption := &trojan.Option{
option.SNI = option.Server Password: option.Password,
ALPN: option.ALPN,
ServerName: option.Server,
SkipCertVerify: option.SkipCertVerify,
Fingerprint: option.Fingerprint,
ClientFingerprint: option.ClientFingerprint,
}
if option.SNI != "" {
tOption.ServerName = option.SNI
} }
t := &Trojan{ t := &Trojan{
@@ -289,76 +239,61 @@ func NewTrojan(option TrojanOption) (*Trojan, error) {
name: option.Name, name: option.Name,
addr: addr, addr: addr,
tp: C.Trojan, tp: C.Trojan,
pdName: option.ProviderName,
udp: option.UDP, udp: option.UDP,
tfo: option.TFO, tfo: option.TFO,
mpTcp: option.MPTCP, mpTcp: option.MPTCP,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: option.IPVersion, prefer: C.NewDNSPrefer(option.IPVersion),
}, },
option: &option, instance: trojan.New(tOption),
hexPassword: trojan.Key(option.Password), option: &option,
} }
t.dialer = option.NewDialer(t.DialOptions())
var err error var err error
t.realityConfig, err = option.RealityOpts.Parse() t.realityConfig, err = option.RealityOpts.Parse()
if err != nil { if err != nil {
return nil, err return nil, err
} }
tOption.Reality = t.realityConfig
t.echConfig, err = option.ECHOpts.Parse()
if err != nil {
return nil, err
}
if option.SSOpts.Enabled {
if option.SSOpts.Password == "" {
return nil, errors.New("empty password")
}
if option.SSOpts.Method == "" {
option.SSOpts.Method = "AES-128-GCM"
}
ciph, err := core.PickCipher(option.SSOpts.Method, nil, option.SSOpts.Password)
if err != nil {
return nil, err
}
t.ssCipher = ciph
}
if option.Network == "grpc" { if option.Network == "grpc" {
dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) { dialFn := func(network, addr string) (net.Conn, error) {
c, err := t.dialer.DialContext(ctx, "tcp", t.addr) var err error
var cDialer C.Dialer = dialer.NewDialer(t.Base.DialOptions()...)
if len(t.option.DialerProxy) > 0 {
cDialer, err = proxydialer.NewByName(t.option.DialerProxy, cDialer)
if err != nil {
return nil, err
}
}
c, err := cDialer.DialContext(context.Background(), "tcp", t.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %s", t.addr, err.Error()) return nil, fmt.Errorf("%s connect error: %s", t.addr, err.Error())
} }
N.TCPKeepAlive(c)
return c, nil return c, nil
} }
tlsConfig, err := ca.GetTLSConfig(ca.Option{ tlsConfig := &tls.Config{
TLSConfig: &tls.Config{ NextProtos: option.ALPN,
NextProtos: option.ALPN, MinVersion: tls.VersionTLS12,
MinVersion: tls.VersionTLS12, InsecureSkipVerify: tOption.SkipCertVerify,
InsecureSkipVerify: option.SkipCertVerify, ServerName: tOption.ServerName,
ServerName: option.SNI, }
},
Fingerprint: option.Fingerprint, var err error
Certificate: option.Certificate, tlsConfig, err = ca.GetSpecifiedFingerprintTLSConfig(tlsConfig, option.Fingerprint)
PrivateKey: option.PrivateKey,
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
t.transport = gun.NewHTTP2Client(dialFn, tlsConfig, option.ClientFingerprint, t.echConfig, t.realityConfig) t.transport = gun.NewHTTP2Client(dialFn, tlsConfig, tOption.ClientFingerprint, t.realityConfig)
t.gunTLSConfig = tlsConfig t.gunTLSConfig = tlsConfig
t.gunConfig = &gun.Config{ t.gunConfig = &gun.Config{
ServiceName: option.GrpcOpts.GrpcServiceName, ServiceName: option.GrpcOpts.GrpcServiceName,
UserAgent: option.GrpcOpts.GrpcUserAgent, Host: tOption.ServerName,
Host: option.SNI,
ClientFingerprint: option.ClientFingerprint,
} }
} }

View File

@@ -2,6 +2,8 @@ package outbound
import ( import (
"context" "context"
"crypto/tls"
"errors"
"fmt" "fmt"
"math" "math"
"net" "net"
@@ -9,24 +11,22 @@ import (
"time" "time"
"github.com/metacubex/mihomo/component/ca" "github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/ech" "github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
"github.com/metacubex/mihomo/component/resolver"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/tuic" "github.com/metacubex/mihomo/transport/tuic"
"github.com/gofrs/uuid/v5" "github.com/gofrs/uuid/v5"
"github.com/metacubex/quic-go" "github.com/metacubex/quic-go"
M "github.com/metacubex/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
"github.com/metacubex/sing/common/uot" "github.com/sagernet/sing/common/uot"
"github.com/metacubex/tls"
) )
type Tuic struct { type Tuic struct {
*Base *Base
option *TuicOption option *TuicOption
client *tuic.PoolClient client *tuic.PoolClient
tlsConfig *tls.Config
echConfig *ech.Config
} }
type TuicOption struct { type TuicOption struct {
@@ -47,27 +47,31 @@ type TuicOption struct {
DisableSni bool `proxy:"disable-sni,omitempty"` DisableSni bool `proxy:"disable-sni,omitempty"`
MaxUdpRelayPacketSize int `proxy:"max-udp-relay-packet-size,omitempty"` MaxUdpRelayPacketSize int `proxy:"max-udp-relay-packet-size,omitempty"`
FastOpen bool `proxy:"fast-open,omitempty"` FastOpen bool `proxy:"fast-open,omitempty"`
MaxOpenStreams int `proxy:"max-open-streams,omitempty"` MaxOpenStreams int `proxy:"max-open-streams,omitempty"`
CWND int `proxy:"cwnd,omitempty"` CWND int `proxy:"cwnd,omitempty"`
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
Fingerprint string `proxy:"fingerprint,omitempty"` Fingerprint string `proxy:"fingerprint,omitempty"`
Certificate string `proxy:"certificate,omitempty"` CustomCA string `proxy:"ca,omitempty"`
PrivateKey string `proxy:"private-key,omitempty"` CustomCAString string `proxy:"ca-str,omitempty"`
ReceiveWindowConn int `proxy:"recv-window-conn,omitempty"` ReceiveWindowConn int `proxy:"recv-window-conn,omitempty"`
ReceiveWindow int `proxy:"recv-window,omitempty"` ReceiveWindow int `proxy:"recv-window,omitempty"`
DisableMTUDiscovery bool `proxy:"disable-mtu-discovery,omitempty"` DisableMTUDiscovery bool `proxy:"disable-mtu-discovery,omitempty"`
MaxDatagramFrameSize int `proxy:"max-datagram-frame-size,omitempty"` MaxDatagramFrameSize int `proxy:"max-datagram-frame-size,omitempty"`
SNI string `proxy:"sni,omitempty"` SNI string `proxy:"sni,omitempty"`
ECHOpts ECHOptions `proxy:"ech-opts,omitempty"`
UDPOverStream bool `proxy:"udp-over-stream,omitempty"` UDPOverStream bool `proxy:"udp-over-stream,omitempty"`
UDPOverStreamVersion int `proxy:"udp-over-stream-version,omitempty"` UDPOverStreamVersion int `proxy:"udp-over-stream-version,omitempty"`
} }
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
conn, err := t.client.DialContext(ctx, metadata) return t.DialContextWithDialer(ctx, dialer.NewDialer(t.Base.DialOptions(opts...)...), metadata)
}
// DialContextWithDialer implements C.ProxyAdapter
func (t *Tuic) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (C.Conn, error) {
conn, err := t.client.DialContextWithDialer(ctx, metadata, dialer, t.dialWithDialer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -75,22 +79,30 @@ func (t *Tuic) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, e
} }
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (t *Tuic) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) { func (t *Tuic) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
if err = t.ResolveUDP(ctx, metadata); err != nil { return t.ListenPacketWithDialer(ctx, dialer.NewDialer(t.Base.DialOptions(opts...)...), metadata)
return nil, err }
}
// ListenPacketWithDialer implements C.ProxyAdapter
func (t *Tuic) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) {
if t.option.UDPOverStream { if t.option.UDPOverStream {
uotDestination := uot.RequestDestination(uint8(t.option.UDPOverStreamVersion)) uotDestination := uot.RequestDestination(uint8(t.option.UDPOverStreamVersion))
uotMetadata := *metadata uotMetadata := *metadata
uotMetadata.Host = uotDestination.Fqdn uotMetadata.Host = uotDestination.Fqdn
uotMetadata.DstPort = uotDestination.Port uotMetadata.DstPort = uotDestination.Port
c, err := t.DialContext(ctx, &uotMetadata) c, err := t.DialContextWithDialer(ctx, dialer, &uotMetadata)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// tuic uos use stream-oriented udp with a special address, so we need a net.UDPAddr // tuic uos use stream-oriented udp with a special address, so we need a net.UDPAddr
if !metadata.Resolved() {
ip, err := resolver.ResolveIP(ctx, metadata.Host)
if err != nil {
return nil, errors.New("can't resolve ip")
}
metadata.DstIP = ip
}
destination := M.SocksaddrFromNet(metadata.UDPAddr()) destination := M.SocksaddrFromNet(metadata.UDPAddr())
if t.option.UDPOverStreamVersion == uot.LegacyVersion { if t.option.UDPOverStreamVersion == uot.LegacyVersion {
@@ -99,25 +111,32 @@ func (t *Tuic) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_
return newPacketConn(uot.NewLazyConn(c, uot.Request{Destination: destination}), t), nil return newPacketConn(uot.NewLazyConn(c, uot.Request{Destination: destination}), t), nil
} }
} }
pc, err := t.client.ListenPacket(ctx, metadata) pc, err := t.client.ListenPacketWithDialer(ctx, metadata, dialer, t.dialWithDialer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return newPacketConn(pc, t), nil return newPacketConn(pc, t), nil
} }
func (t *Tuic) dial(ctx context.Context) (transport *quic.Transport, addr net.Addr, err error) { // SupportWithDialer implements C.ProxyAdapter
udpAddr, err := resolveUDPAddr(ctx, "udp", t.addr, t.prefer) func (t *Tuic) SupportWithDialer() C.NetWork {
if err != nil { return C.ALLNet
return nil, nil, err }
func (t *Tuic) dialWithDialer(ctx context.Context, dialer C.Dialer) (transport *quic.Transport, addr net.Addr, err error) {
if len(t.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(t.option.DialerProxy, dialer)
if err != nil {
return nil, nil, err
}
} }
err = t.echConfig.ClientHandle(ctx, t.tlsConfig) udpAddr, err := resolveUDPAddrWithPrefer(ctx, "udp", t.addr, t.prefer)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
addr = udpAddr addr = udpAddr
var pc net.PacketConn var pc net.PacketConn
pc, err = t.dialer.ListenPacket(ctx, "udp", "", udpAddr.AddrPort()) pc, err = dialer.ListenPacket(ctx, "udp", "", udpAddr.AddrPort())
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -127,30 +146,20 @@ func (t *Tuic) dial(ctx context.Context) (transport *quic.Transport, addr net.Ad
return return
} }
// ProxyInfo implements C.ProxyAdapter
func (t *Tuic) ProxyInfo() C.ProxyInfo {
info := t.Base.ProxyInfo()
info.DialerProxy = t.option.DialerProxy
return info
}
func NewTuic(option TuicOption) (*Tuic, error) { func NewTuic(option TuicOption) (*Tuic, error) {
addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port)) addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port))
serverName := option.Server serverName := option.Server
tlsConfig := &tls.Config{
ServerName: serverName,
InsecureSkipVerify: option.SkipCertVerify,
MinVersion: tls.VersionTLS13,
}
if option.SNI != "" { if option.SNI != "" {
serverName = option.SNI tlsConfig.ServerName = option.SNI
} }
tlsConfig, err := ca.GetTLSConfig(ca.Option{ var err error
TLSConfig: &tls.Config{ tlsConfig, err = ca.GetTLSConfig(tlsConfig, option.Fingerprint, option.CustomCA, option.CustomCAString)
ServerName: serverName,
InsecureSkipVerify: option.SkipCertVerify,
MinVersion: tls.VersionTLS13,
},
Fingerprint: option.Fingerprint,
Certificate: option.Certificate,
PrivateKey: option.PrivateKey,
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -232,12 +241,6 @@ func NewTuic(option TuicOption) (*Tuic, error) {
tlsConfig.InsecureSkipVerify = true // tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config tlsConfig.InsecureSkipVerify = true // tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config
} }
tlsClientConfig := tlsConfig
echConfig, err := option.ECHOpts.Parse()
if err != nil {
return nil, err
}
switch option.UDPOverStreamVersion { switch option.UDPOverStreamVersion {
case uot.Version, uot.LegacyVersion: case uot.Version, uot.LegacyVersion:
case 0: case 0:
@@ -251,18 +254,14 @@ func NewTuic(option TuicOption) (*Tuic, error) {
name: option.Name, name: option.Name,
addr: addr, addr: addr,
tp: C.Tuic, tp: C.Tuic,
pdName: option.ProviderName,
udp: true, udp: true,
tfo: option.FastOpen, tfo: option.FastOpen,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: option.IPVersion, prefer: C.NewDNSPrefer(option.IPVersion),
}, },
option: &option, option: &option,
tlsConfig: tlsClientConfig,
echConfig: echConfig,
} }
t.dialer = option.NewDialer(t.DialOptions())
clientMaxOpenStreams := int64(option.MaxOpenStreams) clientMaxOpenStreams := int64(option.MaxOpenStreams)
@@ -278,7 +277,7 @@ func NewTuic(option TuicOption) (*Tuic, error) {
if len(option.Token) > 0 { if len(option.Token) > 0 {
tkn := tuic.GenTKN(option.Token) tkn := tuic.GenTKN(option.Token)
clientOption := &tuic.ClientOptionV4{ clientOption := &tuic.ClientOptionV4{
TlsConfig: tlsClientConfig, TlsConfig: tlsConfig,
QuicConfig: quicConfig, QuicConfig: quicConfig,
Token: tkn, Token: tkn,
UdpRelayMode: udpRelayMode, UdpRelayMode: udpRelayMode,
@@ -291,14 +290,14 @@ func NewTuic(option TuicOption) (*Tuic, error) {
CWND: option.CWND, CWND: option.CWND,
} }
t.client = tuic.NewPoolClientV4(clientOption, t.dial) t.client = tuic.NewPoolClientV4(clientOption)
} else { } else {
maxUdpRelayPacketSize := option.MaxUdpRelayPacketSize maxUdpRelayPacketSize := option.MaxUdpRelayPacketSize
if maxUdpRelayPacketSize > tuic.MaxFragSizeV5 { if maxUdpRelayPacketSize > tuic.MaxFragSizeV5 {
maxUdpRelayPacketSize = tuic.MaxFragSizeV5 maxUdpRelayPacketSize = tuic.MaxFragSizeV5
} }
clientOption := &tuic.ClientOptionV5{ clientOption := &tuic.ClientOptionV5{
TlsConfig: tlsClientConfig, TlsConfig: tlsConfig,
QuicConfig: quicConfig, QuicConfig: quicConfig,
Uuid: uuid.FromStringOrNil(option.UUID), Uuid: uuid.FromStringOrNil(option.UUID),
Password: option.Password, Password: option.Password,
@@ -310,7 +309,7 @@ func NewTuic(option TuicOption) (*Tuic, error) {
CWND: option.CWND, CWND: option.CWND,
} }
t.client = tuic.NewPoolClientV5(clientOption, t.dial) t.client = tuic.NewPoolClientV5(clientOption)
} }
return t, nil return t, nil

View File

@@ -3,71 +3,119 @@ package outbound
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/tls"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"regexp" "regexp"
"strconv" "strconv"
"sync"
"github.com/metacubex/mihomo/component/resolver" "github.com/metacubex/mihomo/component/resolver"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/transport/socks5" "github.com/metacubex/mihomo/transport/socks5"
) )
var (
globalClientSessionCache tls.ClientSessionCache
once sync.Once
)
func getClientSessionCache() tls.ClientSessionCache {
once.Do(func() {
globalClientSessionCache = tls.NewLRUClientSessionCache(128)
})
return globalClientSessionCache
}
func serializesSocksAddr(metadata *C.Metadata) []byte { func serializesSocksAddr(metadata *C.Metadata) []byte {
var buf [][]byte var buf [][]byte
addrType := metadata.AddrType() addrType := metadata.AddrType()
aType := uint8(addrType)
p := uint(metadata.DstPort) p := uint(metadata.DstPort)
port := []byte{uint8(p >> 8), uint8(p & 0xff)} port := []byte{uint8(p >> 8), uint8(p & 0xff)}
switch addrType { switch addrType {
case C.AtypDomainName: case socks5.AtypDomainName:
lenM := uint8(len(metadata.Host)) lenM := uint8(len(metadata.Host))
host := []byte(metadata.Host) host := []byte(metadata.Host)
buf = [][]byte{{socks5.AtypDomainName, lenM}, host, port} buf = [][]byte{{aType, lenM}, host, port}
case C.AtypIPv4: case socks5.AtypIPv4:
host := metadata.DstIP.AsSlice() host := metadata.DstIP.AsSlice()
buf = [][]byte{{socks5.AtypIPv4}, host, port} buf = [][]byte{{aType}, host, port}
case C.AtypIPv6: case socks5.AtypIPv6:
host := metadata.DstIP.AsSlice() host := metadata.DstIP.AsSlice()
buf = [][]byte{{socks5.AtypIPv6}, host, port} buf = [][]byte{{aType}, host, port}
} }
return bytes.Join(buf, nil) return bytes.Join(buf, nil)
} }
func resolveUDPAddr(ctx context.Context, network, address string, prefer C.DNSPrefer) (*net.UDPAddr, error) { func resolveUDPAddr(ctx context.Context, network, address string) (*net.UDPAddr, error) {
host, port, err := net.SplitHostPort(address) host, port, err := net.SplitHostPort(address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ip, err := resolver.ResolveProxyServerHost(ctx, host)
if err != nil {
return nil, err
}
return net.ResolveUDPAddr(network, net.JoinHostPort(ip.String(), port))
}
func resolveUDPAddrWithPrefer(ctx context.Context, network, address string, prefer C.DNSPrefer) (*net.UDPAddr, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
var ip netip.Addr var ip netip.Addr
var fallback netip.Addr
switch prefer { switch prefer {
case C.IPv4Only: case C.IPv4Only:
ip, err = resolver.ResolveIPv4WithResolver(ctx, host, resolver.ProxyServerHostResolver) ip, err = resolver.ResolveIPv4ProxyServerHost(ctx, host)
case C.IPv6Only: case C.IPv6Only:
ip, err = resolver.ResolveIPv6WithResolver(ctx, host, resolver.ProxyServerHostResolver) ip, err = resolver.ResolveIPv6ProxyServerHost(ctx, host)
case C.IPv6Prefer: case C.IPv6Prefer:
ip, err = resolver.ResolveIPPrefer6WithResolver(ctx, host, resolver.ProxyServerHostResolver) var ips []netip.Addr
ips, err = resolver.LookupIPProxyServerHost(ctx, host)
if err == nil {
for _, addr := range ips {
if addr.Is6() {
ip = addr
break
} else {
if !fallback.IsValid() {
fallback = addr
}
}
}
}
default: default:
ip, err = resolver.ResolveIPWithResolver(ctx, host, resolver.ProxyServerHostResolver) // C.IPv4Prefer, C.DualStack and other
var ips []netip.Addr
ips, err = resolver.LookupIPProxyServerHost(ctx, host)
if err == nil {
for _, addr := range ips {
if addr.Is4() {
ip = addr
break
} else {
if !fallback.IsValid() {
fallback = addr
}
}
}
}
}
if !ip.IsValid() && fallback.IsValid() {
ip = fallback
} }
if err != nil { if err != nil {
return nil, err return nil, err
} }
return net.ResolveUDPAddr(network, net.JoinHostPort(ip.String(), port))
ip, port = resolver.LookupIP4P(ip, port)
var uint16Port uint16
if port, err := strconv.ParseUint(port, 10, 16); err == nil {
uint16Port = uint16(port)
} else {
return nil, err
}
// our resolver always unmap before return, so unneeded unmap at here
// which is different with net.ResolveUDPAddr maybe return 4in6 address
// 4in6 addresses can cause some strange effects on sing-based code
return net.UDPAddrFromAddrPort(netip.AddrPortFrom(ip, uint16Port)), nil
} }
func safeConnClose(c net.Conn, err error) { func safeConnClose(c net.Conn, err error) {

View File

@@ -2,27 +2,39 @@ package outbound
import ( import (
"context" "context"
"crypto/tls"
"encoding/binary"
"errors"
"fmt" "fmt"
"io"
"net" "net"
"net/http"
"strconv" "strconv"
"sync"
"github.com/metacubex/mihomo/common/convert" "github.com/metacubex/mihomo/common/convert"
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/component/ca" "github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/ech" "github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
"github.com/metacubex/mihomo/component/resolver"
tlsC "github.com/metacubex/mihomo/component/tls" tlsC "github.com/metacubex/mihomo/component/tls"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log"
"github.com/metacubex/mihomo/transport/gun" "github.com/metacubex/mihomo/transport/gun"
"github.com/metacubex/mihomo/transport/socks5"
"github.com/metacubex/mihomo/transport/vless" "github.com/metacubex/mihomo/transport/vless"
"github.com/metacubex/mihomo/transport/vless/encryption"
"github.com/metacubex/mihomo/transport/vmess" "github.com/metacubex/mihomo/transport/vmess"
"github.com/metacubex/http"
vmessSing "github.com/metacubex/sing-vmess" vmessSing "github.com/metacubex/sing-vmess"
"github.com/metacubex/sing-vmess/packetaddr" "github.com/metacubex/sing-vmess/packetaddr"
M "github.com/metacubex/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
"github.com/metacubex/tls" )
const (
// max packet length
maxLength = 1024 << 3
) )
type Vless struct { type Vless struct {
@@ -30,15 +42,12 @@ type Vless struct {
client *vless.Client client *vless.Client
option *VlessOption option *VlessOption
encryption *encryption.ClientInstance
// for gun mux // for gun mux
gunTLSConfig *tls.Config gunTLSConfig *tls.Config
gunConfig *gun.Config gunConfig *gun.Config
transport *gun.TransportWrap transport *gun.TransportWrap
realityConfig *tlsC.RealityConfig realityConfig *tlsC.RealityConfig
echConfig *ech.Config
} }
type VlessOption struct { type VlessOption struct {
@@ -54,24 +63,27 @@ type VlessOption struct {
PacketAddr bool `proxy:"packet-addr,omitempty"` PacketAddr bool `proxy:"packet-addr,omitempty"`
XUDP bool `proxy:"xudp,omitempty"` XUDP bool `proxy:"xudp,omitempty"`
PacketEncoding string `proxy:"packet-encoding,omitempty"` PacketEncoding string `proxy:"packet-encoding,omitempty"`
Encryption string `proxy:"encryption,omitempty"`
Network string `proxy:"network,omitempty"` Network string `proxy:"network,omitempty"`
ECHOpts ECHOptions `proxy:"ech-opts,omitempty"`
RealityOpts RealityOptions `proxy:"reality-opts,omitempty"` RealityOpts RealityOptions `proxy:"reality-opts,omitempty"`
HTTPOpts HTTPOptions `proxy:"http-opts,omitempty"` HTTPOpts HTTPOptions `proxy:"http-opts,omitempty"`
HTTP2Opts HTTP2Options `proxy:"h2-opts,omitempty"` HTTP2Opts HTTP2Options `proxy:"h2-opts,omitempty"`
GrpcOpts GrpcOptions `proxy:"grpc-opts,omitempty"` GrpcOpts GrpcOptions `proxy:"grpc-opts,omitempty"`
WSOpts WSOptions `proxy:"ws-opts,omitempty"` WSOpts WSOptions `proxy:"ws-opts,omitempty"`
WSPath string `proxy:"ws-path,omitempty"`
WSHeaders map[string]string `proxy:"ws-headers,omitempty"` WSHeaders map[string]string `proxy:"ws-headers,omitempty"`
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
Fingerprint string `proxy:"fingerprint,omitempty"` Fingerprint string `proxy:"fingerprint,omitempty"`
Certificate string `proxy:"certificate,omitempty"`
PrivateKey string `proxy:"private-key,omitempty"`
ServerName string `proxy:"servername,omitempty"` ServerName string `proxy:"servername,omitempty"`
ClientFingerprint string `proxy:"client-fingerprint,omitempty"` ClientFingerprint string `proxy:"client-fingerprint,omitempty"`
} }
func (v *Vless) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ net.Conn, err error) { func (v *Vless) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) {
var err error
if tlsC.HaveGlobalFingerprint() && len(v.option.ClientFingerprint) == 0 {
v.option.ClientFingerprint = tlsC.GetGlobalFingerprint()
}
switch v.option.Network { switch v.option.Network {
case "ws": case "ws":
host, port, _ := net.SplitHostPort(v.addr) host, port, _ := net.SplitHostPort(v.addr)
@@ -84,7 +96,6 @@ func (v *Vless) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.M
V2rayHttpUpgrade: v.option.WSOpts.V2rayHttpUpgrade, V2rayHttpUpgrade: v.option.WSOpts.V2rayHttpUpgrade,
V2rayHttpUpgradeFastOpen: v.option.WSOpts.V2rayHttpUpgradeFastOpen, V2rayHttpUpgradeFastOpen: v.option.WSOpts.V2rayHttpUpgradeFastOpen,
ClientFingerprint: v.option.ClientFingerprint, ClientFingerprint: v.option.ClientFingerprint,
ECHConfig: v.echConfig,
Headers: http.Header{}, Headers: http.Header{},
} }
@@ -95,17 +106,14 @@ func (v *Vless) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.M
} }
if v.option.TLS { if v.option.TLS {
wsOpts.TLS = true wsOpts.TLS = true
wsOpts.TLSConfig, err = ca.GetTLSConfig(ca.Option{ tlsConfig := &tls.Config{
TLSConfig: &tls.Config{ MinVersion: tls.VersionTLS12,
MinVersion: tls.VersionTLS12, ServerName: host,
ServerName: host, InsecureSkipVerify: v.option.SkipCertVerify,
InsecureSkipVerify: v.option.SkipCertVerify, NextProtos: []string{"http/1.1"},
NextProtos: []string{"http/1.1"}, }
},
Fingerprint: v.option.Fingerprint, wsOpts.TLSConfig, err = ca.GetSpecifiedFingerprintTLSConfig(tlsConfig, v.option.Fingerprint)
Certificate: v.option.Certificate,
PrivateKey: v.option.PrivateKey,
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -149,9 +157,9 @@ func (v *Vless) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.M
Path: v.option.HTTP2Opts.Path, Path: v.option.HTTP2Opts.Path,
} }
c, err = vmess.StreamH2Conn(ctx, c, h2Opts) c, err = vmess.StreamH2Conn(c, h2Opts)
case "grpc": case "grpc":
c, err = gun.StreamGunWithConn(c, v.gunTLSConfig, v.gunConfig, v.echConfig, v.realityConfig) c, err = gun.StreamGunWithConn(c, v.gunTLSConfig, v.gunConfig, v.realityConfig)
default: default:
// default tcp network // default tcp network
// handle TLS // handle TLS
@@ -162,20 +170,10 @@ func (v *Vless) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.M
return nil, err return nil, err
} }
return v.streamConnContext(ctx, c, metadata) return v.streamConn(c, metadata)
} }
func (v *Vless) streamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (conn net.Conn, err error) { func (v *Vless) streamConn(c net.Conn, metadata *C.Metadata) (conn net.Conn, err error) {
if ctx.Done() != nil {
done := N.SetupContextForConn(ctx, c)
defer done(&err)
}
if v.encryption != nil {
c, err = v.encryption.Handshake(c)
if err != nil {
return
}
}
if metadata.NetWork == C.UDP { if metadata.NetWork == C.UDP {
if v.option.PacketAddr { if v.option.PacketAddr {
metadata = &C.Metadata{ metadata = &C.Metadata{
@@ -191,6 +189,9 @@ func (v *Vless) streamConnContext(ctx context.Context, c net.Conn, metadata *C.M
} }
} }
conn, err = v.client.StreamConn(c, parseVlessAddr(metadata, v.option.XUDP)) conn, err = v.client.StreamConn(c, parseVlessAddr(metadata, v.option.XUDP))
if v.option.PacketAddr {
conn = packetaddr.NewBindConn(conn)
}
} else { } else {
conn, err = v.client.StreamConn(c, parseVlessAddr(metadata, false)) conn, err = v.client.StreamConn(c, parseVlessAddr(metadata, false))
} }
@@ -208,10 +209,7 @@ func (v *Vless) streamTLSConn(ctx context.Context, conn net.Conn, isH2 bool) (ne
Host: host, Host: host,
SkipCertVerify: v.option.SkipCertVerify, SkipCertVerify: v.option.SkipCertVerify,
FingerPrint: v.option.Fingerprint, FingerPrint: v.option.Fingerprint,
Certificate: v.option.Certificate,
PrivateKey: v.option.PrivateKey,
ClientFingerprint: v.option.ClientFingerprint, ClientFingerprint: v.option.ClientFingerprint,
ECH: v.echConfig,
Reality: v.realityConfig, Reality: v.realityConfig,
NextProtos: v.option.ALPN, NextProtos: v.option.ALPN,
} }
@@ -231,11 +229,10 @@ func (v *Vless) streamTLSConn(ctx context.Context, conn net.Conn, isH2 bool) (ne
} }
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (v *Vless) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) { func (v *Vless) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
var c net.Conn
// gun transport // gun transport
if v.transport != nil { if v.transport != nil && len(opts) == 0 {
c, err = gun.StreamGunWithTransport(v.transport, v.gunConfig) c, err := gun.StreamGunWithTransport(v.transport, v.gunConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -243,17 +240,29 @@ func (v *Vless) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn
safeConnClose(c, err) safeConnClose(c, err)
}(c) }(c)
c, err = v.streamConnContext(ctx, c, metadata) c, err = v.client.StreamConn(c, parseVlessAddr(metadata, v.option.XUDP))
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewConn(c, v), nil return NewConn(c, v), nil
} }
c, err = v.dialer.DialContext(ctx, "tcp", v.addr) return v.DialContextWithDialer(ctx, dialer.NewDialer(v.Base.DialOptions(opts...)...), metadata)
}
// DialContextWithDialer implements C.ProxyAdapter
func (v *Vless) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
if len(v.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(v.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
c, err := dialer.DialContext(ctx, "tcp", v.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error()) return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
} }
N.TCPKeepAlive(c)
defer func(c net.Conn) { defer func(c net.Conn) {
safeConnClose(c, err) safeConnClose(c, err)
}(c) }(c)
@@ -266,13 +275,18 @@ func (v *Vless) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn
} }
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (v *Vless) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) { func (v *Vless) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
if err = v.ResolveUDP(ctx, metadata); err != nil { // vless use stream-oriented udp with a special address, so we need a net.UDPAddr
return nil, err if !metadata.Resolved() {
ip, err := resolver.ResolveIP(ctx, metadata.Host)
if err != nil {
return nil, errors.New("can't resolve ip")
}
metadata.DstIP = ip
} }
var c net.Conn var c net.Conn
// gun transport // gun transport
if v.transport != nil { if v.transport != nil && len(opts) == 0 {
c, err = gun.StreamGunWithTransport(v.transport, v.gunConfig) c, err = gun.StreamGunWithTransport(v.transport, v.gunConfig)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -281,22 +295,39 @@ func (v *Vless) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (
safeConnClose(c, err) safeConnClose(c, err)
}(c) }(c)
c, err = v.streamConnContext(ctx, c, metadata) c, err = v.streamConn(c, metadata)
if err != nil { if err != nil {
return nil, fmt.Errorf("new vless client error: %v", err) return nil, fmt.Errorf("new vless client error: %v", err)
} }
return v.ListenPacketOnStreamConn(ctx, c, metadata) return v.ListenPacketOnStreamConn(ctx, c, metadata)
} }
return v.ListenPacketWithDialer(ctx, dialer.NewDialer(v.Base.DialOptions(opts...)...), metadata)
}
if err = v.ResolveUDP(ctx, metadata); err != nil { // ListenPacketWithDialer implements C.ProxyAdapter
return nil, err func (v *Vless) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) {
if len(v.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(v.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
} }
c, err = v.dialer.DialContext(ctx, "tcp", v.addr) // vless use stream-oriented udp with a special address, so we need a net.UDPAddr
if !metadata.Resolved() {
ip, err := resolver.ResolveIP(ctx, metadata.Host)
if err != nil {
return nil, errors.New("can't resolve ip")
}
metadata.DstIP = ip
}
c, err := dialer.DialContext(ctx, "tcp", v.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error()) return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
} }
N.TCPKeepAlive(c)
defer func(c net.Conn) { defer func(c net.Conn) {
safeConnClose(c, err) safeConnClose(c, err)
}(c) }(c)
@@ -309,10 +340,20 @@ func (v *Vless) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (
return v.ListenPacketOnStreamConn(ctx, c, metadata) return v.ListenPacketOnStreamConn(ctx, c, metadata)
} }
// SupportWithDialer implements C.ProxyAdapter
func (v *Vless) SupportWithDialer() C.NetWork {
return C.ALLNet
}
// ListenPacketOnStreamConn implements C.ProxyAdapter // ListenPacketOnStreamConn implements C.ProxyAdapter
func (v *Vless) ListenPacketOnStreamConn(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ C.PacketConn, err error) { func (v *Vless) ListenPacketOnStreamConn(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ C.PacketConn, err error) {
if err = v.ResolveUDP(ctx, metadata); err != nil { // vless use stream-oriented udp with a special address, so we need a net.UDPAddr
return nil, err if !metadata.Resolved() {
ip, err := resolver.ResolveIP(ctx, metadata.Host)
if err != nil {
return nil, errors.New("can't resolve ip")
}
metadata.DstIP = ip
} }
if v.option.XUDP { if v.option.XUDP {
@@ -327,11 +368,12 @@ func (v *Vless) ListenPacketOnStreamConn(ctx context.Context, c net.Conn, metada
), v), nil ), v), nil
} else if v.option.PacketAddr { } else if v.option.PacketAddr {
return newPacketConn(N.NewThreadSafePacketConn( return newPacketConn(N.NewThreadSafePacketConn(
packetaddr.NewConn(v.client.PacketConn(c, metadata.UDPAddr()), packetaddr.NewConn(&vlessPacketConn{
M.SocksaddrFromNet(metadata.UDPAddr())), Conn: c, rAddr: metadata.UDPAddr(),
}, M.SocksaddrFromNet(metadata.UDPAddr())),
), v), nil ), v), nil
} }
return newPacketConn(N.NewThreadSafePacketConn(v.client.PacketConn(c, metadata.UDPAddr())), v), nil return newPacketConn(N.NewThreadSafePacketConn(&vlessPacketConn{Conn: c, rAddr: metadata.UDPAddr()}), v), nil
} }
// SupportUOT implements C.ProxyAdapter // SupportUOT implements C.ProxyAdapter
@@ -339,34 +381,19 @@ func (v *Vless) SupportUOT() bool {
return true return true
} }
// ProxyInfo implements C.ProxyAdapter
func (v *Vless) ProxyInfo() C.ProxyInfo {
info := v.Base.ProxyInfo()
info.DialerProxy = v.option.DialerProxy
return info
}
// Close implements C.ProxyAdapter
func (v *Vless) Close() error {
if v.transport != nil {
return v.transport.Close()
}
return nil
}
func parseVlessAddr(metadata *C.Metadata, xudp bool) *vless.DstAddr { func parseVlessAddr(metadata *C.Metadata, xudp bool) *vless.DstAddr {
var addrType byte var addrType byte
var addr []byte var addr []byte
switch metadata.AddrType() { switch metadata.AddrType() {
case C.AtypIPv4: case socks5.AtypIPv4:
addrType = vless.AtypIPv4 addrType = vless.AtypIPv4
addr = make([]byte, net.IPv4len) addr = make([]byte, net.IPv4len)
copy(addr[:], metadata.DstIP.AsSlice()) copy(addr[:], metadata.DstIP.AsSlice())
case C.AtypIPv6: case socks5.AtypIPv6:
addrType = vless.AtypIPv6 addrType = vless.AtypIPv6
addr = make([]byte, net.IPv6len) addr = make([]byte, net.IPv6len)
copy(addr[:], metadata.DstIP.AsSlice()) copy(addr[:], metadata.DstIP.AsSlice())
case C.AtypDomainName: case socks5.AtypDomainName:
addrType = vless.AtypDomainName addrType = vless.AtypDomainName
addr = make([]byte, len(metadata.Host)+1) addr = make([]byte, len(metadata.Host)+1)
addr[0] = byte(len(metadata.Host)) addr[0] = byte(len(metadata.Host))
@@ -382,16 +409,113 @@ func parseVlessAddr(metadata *C.Metadata, xudp bool) *vless.DstAddr {
} }
} }
type vlessPacketConn struct {
net.Conn
rAddr net.Addr
remain int
mux sync.Mutex
cache [2]byte
}
func (c *vlessPacketConn) writePacket(payload []byte) (int, error) {
binary.BigEndian.PutUint16(c.cache[:], uint16(len(payload)))
if _, err := c.Conn.Write(c.cache[:]); err != nil {
return 0, err
}
return c.Conn.Write(payload)
}
func (c *vlessPacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
total := len(b)
if total == 0 {
return 0, nil
}
if total <= maxLength {
return c.writePacket(b)
}
offset := 0
for offset < total {
cursor := offset + maxLength
if cursor > total {
cursor = total
}
n, err := c.writePacket(b[offset:cursor])
if err != nil {
return offset + n, err
}
offset = cursor
if offset == total {
break
}
}
return total, nil
}
func (c *vlessPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
c.mux.Lock()
defer c.mux.Unlock()
if c.remain > 0 {
length := len(b)
if c.remain < length {
length = c.remain
}
n, err := c.Conn.Read(b[:length])
if err != nil {
return 0, c.rAddr, err
}
c.remain -= n
return n, c.rAddr, nil
}
if _, err := c.Conn.Read(b[:2]); err != nil {
return 0, c.rAddr, err
}
total := int(binary.BigEndian.Uint16(b[:2]))
if total == 0 {
return 0, c.rAddr, nil
}
length := len(b)
if length > total {
length = total
}
if _, err := io.ReadFull(c.Conn, b[:length]); err != nil {
return 0, c.rAddr, errors.New("read packet error")
}
c.remain = total - length
return length, c.rAddr, nil
}
func NewVless(option VlessOption) (*Vless, error) { func NewVless(option VlessOption) (*Vless, error) {
var addons *vless.Addons var addons *vless.Addons
if len(option.Flow) >= 16 { if option.Network != "ws" && len(option.Flow) >= 16 {
option.Flow = option.Flow[:16] option.Flow = option.Flow[:16]
if option.Flow != vless.XRV { switch option.Flow {
case vless.XRV:
log.Warnln("To use %s, ensure your server is upgrade to Xray-core v1.8.0+", vless.XRV)
addons = &vless.Addons{
Flow: option.Flow,
}
case vless.XRO, vless.XRD, vless.XRS:
log.Fatalln("Legacy XTLS protocol %s is deprecated and no longer supported", option.Flow)
default:
return nil, fmt.Errorf("unsupported xtls flow type: %s", option.Flow) return nil, fmt.Errorf("unsupported xtls flow type: %s", option.Flow)
} }
addons = &vless.Addons{
Flow: option.Flow,
}
} }
switch option.PacketEncoding { switch option.PacketEncoding {
@@ -417,52 +541,48 @@ func NewVless(option VlessOption) (*Vless, error) {
name: option.Name, name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)), addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.Vless, tp: C.Vless,
pdName: option.ProviderName,
udp: option.UDP, udp: option.UDP,
xudp: option.XUDP, xudp: option.XUDP,
tfo: option.TFO, tfo: option.TFO,
mpTcp: option.MPTCP, mpTcp: option.MPTCP,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: option.IPVersion, prefer: C.NewDNSPrefer(option.IPVersion),
}, },
client: client, client: client,
option: &option, option: &option,
} }
v.dialer = option.NewDialer(v.DialOptions())
v.encryption, err = encryption.NewClient(option.Encryption)
if err != nil {
return nil, err
}
v.realityConfig, err = v.option.RealityOpts.Parse() v.realityConfig, err = v.option.RealityOpts.Parse()
if err != nil { if err != nil {
return nil, err return nil, err
} }
v.echConfig, err = v.option.ECHOpts.Parse()
if err != nil {
return nil, err
}
switch option.Network { switch option.Network {
case "h2": case "h2":
if len(option.HTTP2Opts.Host) == 0 { if len(option.HTTP2Opts.Host) == 0 {
option.HTTP2Opts.Host = append(option.HTTP2Opts.Host, "www.example.com") option.HTTP2Opts.Host = append(option.HTTP2Opts.Host, "www.example.com")
} }
case "grpc": case "grpc":
dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) { dialFn := func(network, addr string) (net.Conn, error) {
c, err := v.dialer.DialContext(ctx, "tcp", v.addr) var err error
var cDialer C.Dialer = dialer.NewDialer(v.Base.DialOptions()...)
if len(v.option.DialerProxy) > 0 {
cDialer, err = proxydialer.NewByName(v.option.DialerProxy, cDialer)
if err != nil {
return nil, err
}
}
c, err := cDialer.DialContext(context.Background(), "tcp", v.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error()) return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
} }
N.TCPKeepAlive(c)
return c, nil return c, nil
} }
gunConfig := &gun.Config{ gunConfig := &gun.Config{
ServiceName: v.option.GrpcOpts.GrpcServiceName, ServiceName: v.option.GrpcOpts.GrpcServiceName,
UserAgent: v.option.GrpcOpts.GrpcUserAgent,
Host: v.option.ServerName, Host: v.option.ServerName,
ClientFingerprint: v.option.ClientFingerprint, ClientFingerprint: v.option.ClientFingerprint,
} }
@@ -471,18 +591,10 @@ func NewVless(option VlessOption) (*Vless, error) {
} }
var tlsConfig *tls.Config var tlsConfig *tls.Config
if option.TLS { if option.TLS {
tlsConfig, err = ca.GetTLSConfig(ca.Option{ tlsConfig = ca.GetGlobalTLSConfig(&tls.Config{
TLSConfig: &tls.Config{ InsecureSkipVerify: v.option.SkipCertVerify,
InsecureSkipVerify: v.option.SkipCertVerify, ServerName: v.option.ServerName,
ServerName: v.option.ServerName,
},
Fingerprint: v.option.Fingerprint,
Certificate: v.option.Certificate,
PrivateKey: v.option.PrivateKey,
}) })
if err != nil {
return nil, err
}
if option.ServerName == "" { if option.ServerName == "" {
host, _, _ := net.SplitHostPort(v.addr) host, _, _ := net.SplitHostPort(v.addr)
tlsConfig.ServerName = host tlsConfig.ServerName = host
@@ -492,7 +604,7 @@ func NewVless(option VlessOption) (*Vless, error) {
v.gunTLSConfig = tlsConfig v.gunTLSConfig = tlsConfig
v.gunConfig = gunConfig v.gunConfig = gunConfig
v.transport = gun.NewHTTP2Client(dialFn, tlsConfig, v.option.ClientFingerprint, v.echConfig, v.realityConfig) v.transport = gun.NewHTTP2Client(dialFn, tlsConfig, v.option.ClientFingerprint, v.realityConfig)
} }
return v, nil return v, nil

View File

@@ -2,9 +2,11 @@ package outbound
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/http"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@@ -12,18 +14,18 @@ import (
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/component/ca" "github.com/metacubex/mihomo/component/ca"
"github.com/metacubex/mihomo/component/ech" "github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
"github.com/metacubex/mihomo/component/resolver"
tlsC "github.com/metacubex/mihomo/component/tls" tlsC "github.com/metacubex/mihomo/component/tls"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/ntp" "github.com/metacubex/mihomo/ntp"
"github.com/metacubex/mihomo/transport/gun" "github.com/metacubex/mihomo/transport/gun"
mihomoVMess "github.com/metacubex/mihomo/transport/vmess" mihomoVMess "github.com/metacubex/mihomo/transport/vmess"
"github.com/metacubex/http"
vmess "github.com/metacubex/sing-vmess" vmess "github.com/metacubex/sing-vmess"
"github.com/metacubex/sing-vmess/packetaddr" "github.com/metacubex/sing-vmess/packetaddr"
M "github.com/metacubex/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
"github.com/metacubex/tls"
) )
var ErrUDPRemoteAddrMismatch = errors.New("udp packet dropped due to mismatched remote address") var ErrUDPRemoteAddrMismatch = errors.New("udp packet dropped due to mismatched remote address")
@@ -39,7 +41,6 @@ type Vmess struct {
transport *gun.TransportWrap transport *gun.TransportWrap
realityConfig *tlsC.RealityConfig realityConfig *tlsC.RealityConfig
echConfig *ech.Config
} }
type VmessOption struct { type VmessOption struct {
@@ -56,10 +57,7 @@ type VmessOption struct {
ALPN []string `proxy:"alpn,omitempty"` ALPN []string `proxy:"alpn,omitempty"`
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
Fingerprint string `proxy:"fingerprint,omitempty"` Fingerprint string `proxy:"fingerprint,omitempty"`
Certificate string `proxy:"certificate,omitempty"`
PrivateKey string `proxy:"private-key,omitempty"`
ServerName string `proxy:"servername,omitempty"` ServerName string `proxy:"servername,omitempty"`
ECHOpts ECHOptions `proxy:"ech-opts,omitempty"`
RealityOpts RealityOptions `proxy:"reality-opts,omitempty"` RealityOpts RealityOptions `proxy:"reality-opts,omitempty"`
HTTPOpts HTTPOptions `proxy:"http-opts,omitempty"` HTTPOpts HTTPOptions `proxy:"http-opts,omitempty"`
HTTP2Opts HTTP2Options `proxy:"h2-opts,omitempty"` HTTP2Opts HTTP2Options `proxy:"h2-opts,omitempty"`
@@ -86,7 +84,6 @@ type HTTP2Options struct {
type GrpcOptions struct { type GrpcOptions struct {
GrpcServiceName string `proxy:"grpc-service-name,omitempty"` GrpcServiceName string `proxy:"grpc-service-name,omitempty"`
GrpcUserAgent string `proxy:"grpc-user-agent,omitempty"`
} }
type WSOptions struct { type WSOptions struct {
@@ -99,7 +96,13 @@ type WSOptions struct {
} }
// StreamConnContext implements C.ProxyAdapter // StreamConnContext implements C.ProxyAdapter
func (v *Vmess) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ net.Conn, err error) { func (v *Vmess) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) {
var err error
if tlsC.HaveGlobalFingerprint() && (len(v.option.ClientFingerprint) == 0) {
v.option.ClientFingerprint = tlsC.GetGlobalFingerprint()
}
switch v.option.Network { switch v.option.Network {
case "ws": case "ws":
host, port, _ := net.SplitHostPort(v.addr) host, port, _ := net.SplitHostPort(v.addr)
@@ -112,7 +115,6 @@ func (v *Vmess) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.M
V2rayHttpUpgrade: v.option.WSOpts.V2rayHttpUpgrade, V2rayHttpUpgrade: v.option.WSOpts.V2rayHttpUpgrade,
V2rayHttpUpgradeFastOpen: v.option.WSOpts.V2rayHttpUpgradeFastOpen, V2rayHttpUpgradeFastOpen: v.option.WSOpts.V2rayHttpUpgradeFastOpen,
ClientFingerprint: v.option.ClientFingerprint, ClientFingerprint: v.option.ClientFingerprint,
ECHConfig: v.echConfig,
Headers: http.Header{}, Headers: http.Header{},
} }
@@ -124,16 +126,13 @@ func (v *Vmess) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.M
if v.option.TLS { if v.option.TLS {
wsOpts.TLS = true wsOpts.TLS = true
wsOpts.TLSConfig, err = ca.GetTLSConfig(ca.Option{ tlsConfig := &tls.Config{
TLSConfig: &tls.Config{ ServerName: host,
ServerName: host, InsecureSkipVerify: v.option.SkipCertVerify,
InsecureSkipVerify: v.option.SkipCertVerify, NextProtos: []string{"http/1.1"},
NextProtos: []string{"http/1.1"}, }
},
Fingerprint: v.option.Fingerprint, wsOpts.TLSConfig, err = ca.GetSpecifiedFingerprintTLSConfig(tlsConfig, v.option.Fingerprint)
Certificate: v.option.Certificate,
PrivateKey: v.option.PrivateKey,
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -153,7 +152,6 @@ func (v *Vmess) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.M
Host: host, Host: host,
SkipCertVerify: v.option.SkipCertVerify, SkipCertVerify: v.option.SkipCertVerify,
ClientFingerprint: v.option.ClientFingerprint, ClientFingerprint: v.option.ClientFingerprint,
ECH: v.echConfig,
Reality: v.realityConfig, Reality: v.realityConfig,
NextProtos: v.option.ALPN, NextProtos: v.option.ALPN,
} }
@@ -181,9 +179,6 @@ func (v *Vmess) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.M
tlsOpts := mihomoVMess.TLSConfig{ tlsOpts := mihomoVMess.TLSConfig{
Host: host, Host: host,
SkipCertVerify: v.option.SkipCertVerify, SkipCertVerify: v.option.SkipCertVerify,
FingerPrint: v.option.Fingerprint,
Certificate: v.option.Certificate,
PrivateKey: v.option.PrivateKey,
NextProtos: []string{"h2"}, NextProtos: []string{"h2"},
ClientFingerprint: v.option.ClientFingerprint, ClientFingerprint: v.option.ClientFingerprint,
Reality: v.realityConfig, Reality: v.realityConfig,
@@ -203,9 +198,9 @@ func (v *Vmess) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.M
Path: v.option.HTTP2Opts.Path, Path: v.option.HTTP2Opts.Path,
} }
c, err = mihomoVMess.StreamH2Conn(ctx, c, h2Opts) c, err = mihomoVMess.StreamH2Conn(c, h2Opts)
case "grpc": case "grpc":
c, err = gun.StreamGunWithConn(c, v.gunTLSConfig, v.gunConfig, v.echConfig, v.realityConfig) c, err = gun.StreamGunWithConn(c, v.gunTLSConfig, v.gunConfig, v.realityConfig)
default: default:
// handle TLS // handle TLS
if v.option.TLS { if v.option.TLS {
@@ -213,11 +208,7 @@ func (v *Vmess) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.M
tlsOpts := &mihomoVMess.TLSConfig{ tlsOpts := &mihomoVMess.TLSConfig{
Host: host, Host: host,
SkipCertVerify: v.option.SkipCertVerify, SkipCertVerify: v.option.SkipCertVerify,
FingerPrint: v.option.Fingerprint,
Certificate: v.option.Certificate,
PrivateKey: v.option.PrivateKey,
ClientFingerprint: v.option.ClientFingerprint, ClientFingerprint: v.option.ClientFingerprint,
ECH: v.echConfig,
Reality: v.realityConfig, Reality: v.realityConfig,
NextProtos: v.option.ALPN, NextProtos: v.option.ALPN,
} }
@@ -233,24 +224,17 @@ func (v *Vmess) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.M
if err != nil { if err != nil {
return nil, err return nil, err
} }
return v.streamConnContext(ctx, c, metadata) return v.streamConn(c, metadata)
} }
func (v *Vmess) streamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (conn net.Conn, err error) { func (v *Vmess) streamConn(c net.Conn, metadata *C.Metadata) (conn net.Conn, err error) {
useEarly := N.NeedHandshake(c)
if !useEarly {
if ctx.Done() != nil {
done := N.SetupContextForConn(ctx, c)
defer done(&err)
}
}
if metadata.NetWork == C.UDP { if metadata.NetWork == C.UDP {
if v.option.XUDP { if v.option.XUDP {
var globalID [8]byte var globalID [8]byte
if metadata.SourceValid() { if metadata.SourceValid() {
globalID = utils.GlobalID(metadata.SourceAddress()) globalID = utils.GlobalID(metadata.SourceAddress())
} }
if useEarly { if N.NeedHandshake(c) {
conn = v.client.DialEarlyXUDPPacketConn(c, conn = v.client.DialEarlyXUDPPacketConn(c,
globalID, globalID,
M.SocksaddrFromNet(metadata.UDPAddr())) M.SocksaddrFromNet(metadata.UDPAddr()))
@@ -260,7 +244,7 @@ func (v *Vmess) streamConnContext(ctx context.Context, c net.Conn, metadata *C.M
M.SocksaddrFromNet(metadata.UDPAddr())) M.SocksaddrFromNet(metadata.UDPAddr()))
} }
} else if v.option.PacketAddr { } else if v.option.PacketAddr {
if useEarly { if N.NeedHandshake(c) {
conn = v.client.DialEarlyPacketConn(c, conn = v.client.DialEarlyPacketConn(c,
M.ParseSocksaddrHostPort(packetaddr.SeqPacketMagicAddress, 443)) M.ParseSocksaddrHostPort(packetaddr.SeqPacketMagicAddress, 443))
} else { } else {
@@ -269,7 +253,7 @@ func (v *Vmess) streamConnContext(ctx context.Context, c net.Conn, metadata *C.M
} }
conn = packetaddr.NewBindConn(conn) conn = packetaddr.NewBindConn(conn)
} else { } else {
if useEarly { if N.NeedHandshake(c) {
conn = v.client.DialEarlyPacketConn(c, conn = v.client.DialEarlyPacketConn(c,
M.SocksaddrFromNet(metadata.UDPAddr())) M.SocksaddrFromNet(metadata.UDPAddr()))
} else { } else {
@@ -278,7 +262,7 @@ func (v *Vmess) streamConnContext(ctx context.Context, c net.Conn, metadata *C.M
} }
} }
} else { } else {
if useEarly { if N.NeedHandshake(c) {
conn = v.client.DialEarlyConn(c, conn = v.client.DialEarlyConn(c,
M.ParseSocksaddrHostPort(metadata.String(), metadata.DstPort)) M.ParseSocksaddrHostPort(metadata.String(), metadata.DstPort))
} else { } else {
@@ -293,11 +277,10 @@ func (v *Vmess) streamConnContext(ctx context.Context, c net.Conn, metadata *C.M
} }
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) { func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
var c net.Conn
// gun transport // gun transport
if v.transport != nil { if v.transport != nil && len(opts) == 0 {
c, err = gun.StreamGunWithTransport(v.transport, v.gunConfig) c, err := gun.StreamGunWithTransport(v.transport, v.gunConfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -305,17 +288,29 @@ func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn
safeConnClose(c, err) safeConnClose(c, err)
}(c) }(c)
c, err = v.streamConnContext(ctx, c, metadata) c, err = v.client.DialConn(c, M.ParseSocksaddrHostPort(metadata.String(), metadata.DstPort))
if err != nil { if err != nil {
return nil, err return nil, err
} }
return NewConn(c, v), nil return NewConn(c, v), nil
} }
c, err = v.dialer.DialContext(ctx, "tcp", v.addr) return v.DialContextWithDialer(ctx, dialer.NewDialer(v.Base.DialOptions(opts...)...), metadata)
}
// DialContextWithDialer implements C.ProxyAdapter
func (v *Vmess) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.Conn, err error) {
if len(v.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(v.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
}
c, err := dialer.DialContext(ctx, "tcp", v.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error()) return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
} }
N.TCPKeepAlive(c)
defer func(c net.Conn) { defer func(c net.Conn) {
safeConnClose(c, err) safeConnClose(c, err)
}(c) }(c)
@@ -325,13 +320,18 @@ func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn
} }
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) { func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
if err = v.ResolveUDP(ctx, metadata); err != nil { // vmess use stream-oriented udp with a special address, so we need a net.UDPAddr
return nil, err if !metadata.Resolved() {
ip, err := resolver.ResolveIP(ctx, metadata.Host)
if err != nil {
return nil, errors.New("can't resolve ip")
}
metadata.DstIP = ip
} }
var c net.Conn var c net.Conn
// gun transport // gun transport
if v.transport != nil { if v.transport != nil && len(opts) == 0 {
c, err = gun.StreamGunWithTransport(v.transport, v.gunConfig) c, err = gun.StreamGunWithTransport(v.transport, v.gunConfig)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -340,21 +340,38 @@ func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (
safeConnClose(c, err) safeConnClose(c, err)
}(c) }(c)
c, err = v.streamConnContext(ctx, c, metadata) c, err = v.streamConn(c, metadata)
if err != nil { if err != nil {
return nil, fmt.Errorf("new vmess client error: %v", err) return nil, fmt.Errorf("new vmess client error: %v", err)
} }
return v.ListenPacketOnStreamConn(ctx, c, metadata) return v.ListenPacketOnStreamConn(ctx, c, metadata)
} }
return v.ListenPacketWithDialer(ctx, dialer.NewDialer(v.Base.DialOptions(opts...)...), metadata)
}
if err = v.ResolveUDP(ctx, metadata); err != nil { // ListenPacketWithDialer implements C.ProxyAdapter
return nil, err func (v *Vmess) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (_ C.PacketConn, err error) {
if len(v.option.DialerProxy) > 0 {
dialer, err = proxydialer.NewByName(v.option.DialerProxy, dialer)
if err != nil {
return nil, err
}
} }
c, err = v.dialer.DialContext(ctx, "tcp", v.addr) // vmess use stream-oriented udp with a special address, so we need a net.UDPAddr
if !metadata.Resolved() {
ip, err := resolver.ResolveIP(ctx, metadata.Host)
if err != nil {
return nil, errors.New("can't resolve ip")
}
metadata.DstIP = ip
}
c, err := dialer.DialContext(ctx, "tcp", v.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error()) return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
} }
N.TCPKeepAlive(c)
defer func(c net.Conn) { defer func(c net.Conn) {
safeConnClose(c, err) safeConnClose(c, err)
}(c) }(c)
@@ -366,25 +383,20 @@ func (v *Vmess) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (
return v.ListenPacketOnStreamConn(ctx, c, metadata) return v.ListenPacketOnStreamConn(ctx, c, metadata)
} }
// ProxyInfo implements C.ProxyAdapter // SupportWithDialer implements C.ProxyAdapter
func (v *Vmess) ProxyInfo() C.ProxyInfo { func (v *Vmess) SupportWithDialer() C.NetWork {
info := v.Base.ProxyInfo() return C.ALLNet
info.DialerProxy = v.option.DialerProxy
return info
}
// Close implements C.ProxyAdapter
func (v *Vmess) Close() error {
if v.transport != nil {
return v.transport.Close()
}
return nil
} }
// ListenPacketOnStreamConn implements C.ProxyAdapter // ListenPacketOnStreamConn implements C.ProxyAdapter
func (v *Vmess) ListenPacketOnStreamConn(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ C.PacketConn, err error) { func (v *Vmess) ListenPacketOnStreamConn(ctx context.Context, c net.Conn, metadata *C.Metadata) (_ C.PacketConn, err error) {
if err = v.ResolveUDP(ctx, metadata); err != nil { // vmess use stream-oriented udp with a special address, so we need a net.UDPAddr
return nil, err if !metadata.Resolved() {
ip, err := resolver.ResolveIP(ctx, metadata.Host)
if err != nil {
return nil, errors.New("can't resolve ip")
}
metadata.DstIP = ip
} }
if pc, ok := c.(net.PacketConn); ok { if pc, ok := c.(net.PacketConn); ok {
@@ -428,29 +440,17 @@ func NewVmess(option VmessOption) (*Vmess, error) {
name: option.Name, name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)), addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.Vmess, tp: C.Vmess,
pdName: option.ProviderName,
udp: option.UDP, udp: option.UDP,
xudp: option.XUDP, xudp: option.XUDP,
tfo: option.TFO, tfo: option.TFO,
mpTcp: option.MPTCP, mpTcp: option.MPTCP,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: option.IPVersion, prefer: C.NewDNSPrefer(option.IPVersion),
}, },
client: client, client: client,
option: &option, option: &option,
} }
v.dialer = option.NewDialer(v.DialOptions())
v.realityConfig, err = v.option.RealityOpts.Parse()
if err != nil {
return nil, err
}
v.echConfig, err = v.option.ECHOpts.Parse()
if err != nil {
return nil, err
}
switch option.Network { switch option.Network {
case "h2": case "h2":
@@ -458,17 +458,25 @@ func NewVmess(option VmessOption) (*Vmess, error) {
option.HTTP2Opts.Host = append(option.HTTP2Opts.Host, "www.example.com") option.HTTP2Opts.Host = append(option.HTTP2Opts.Host, "www.example.com")
} }
case "grpc": case "grpc":
dialFn := func(ctx context.Context, network, addr string) (net.Conn, error) { dialFn := func(network, addr string) (net.Conn, error) {
c, err := v.dialer.DialContext(ctx, "tcp", v.addr) var err error
var cDialer C.Dialer = dialer.NewDialer(v.Base.DialOptions()...)
if len(v.option.DialerProxy) > 0 {
cDialer, err = proxydialer.NewByName(v.option.DialerProxy, cDialer)
if err != nil {
return nil, err
}
}
c, err := cDialer.DialContext(context.Background(), "tcp", v.addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error()) return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
} }
N.TCPKeepAlive(c)
return c, nil return c, nil
} }
gunConfig := &gun.Config{ gunConfig := &gun.Config{
ServiceName: v.option.GrpcOpts.GrpcServiceName, ServiceName: v.option.GrpcOpts.GrpcServiceName,
UserAgent: v.option.GrpcOpts.GrpcUserAgent,
Host: v.option.ServerName, Host: v.option.ServerName,
ClientFingerprint: v.option.ClientFingerprint, ClientFingerprint: v.option.ClientFingerprint,
} }
@@ -477,18 +485,10 @@ func NewVmess(option VmessOption) (*Vmess, error) {
} }
var tlsConfig *tls.Config var tlsConfig *tls.Config
if option.TLS { if option.TLS {
tlsConfig, err = ca.GetTLSConfig(ca.Option{ tlsConfig = ca.GetGlobalTLSConfig(&tls.Config{
TLSConfig: &tls.Config{ InsecureSkipVerify: v.option.SkipCertVerify,
InsecureSkipVerify: v.option.SkipCertVerify, ServerName: v.option.ServerName,
ServerName: v.option.ServerName,
},
Fingerprint: v.option.Fingerprint,
Certificate: v.option.Certificate,
PrivateKey: v.option.PrivateKey,
}) })
if err != nil {
return nil, err
}
if option.ServerName == "" { if option.ServerName == "" {
host, _, _ := net.SplitHostPort(v.addr) host, _, _ := net.SplitHostPort(v.addr)
tlsConfig.ServerName = host tlsConfig.ServerName = host
@@ -498,7 +498,12 @@ func NewVmess(option VmessOption) (*Vmess, error) {
v.gunTLSConfig = tlsConfig v.gunTLSConfig = tlsConfig
v.gunConfig = gunConfig v.gunConfig = gunConfig
v.transport = gun.NewHTTP2Client(dialFn, tlsConfig, v.option.ClientFingerprint, v.echConfig, v.realityConfig) v.transport = gun.NewHTTP2Client(dialFn, tlsConfig, v.option.ClientFingerprint, v.realityConfig)
}
v.realityConfig, err = v.option.RealityOpts.Parse()
if err != nil {
return nil, err
} }
return v, nil return v, nil

View File

@@ -4,15 +4,17 @@ import (
"context" "context"
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"runtime"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time"
"github.com/metacubex/mihomo/common/atomic" "github.com/metacubex/mihomo/common/atomic"
CN "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/dialer" "github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer" "github.com/metacubex/mihomo/component/proxydialer"
"github.com/metacubex/mihomo/component/resolver" "github.com/metacubex/mihomo/component/resolver"
@@ -21,37 +23,24 @@ import (
"github.com/metacubex/mihomo/dns" "github.com/metacubex/mihomo/dns"
"github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/log"
amnezia "github.com/metacubex/amneziawg-go/device"
wireguard "github.com/metacubex/sing-wireguard" wireguard "github.com/metacubex/sing-wireguard"
"github.com/metacubex/wireguard-go/device"
"github.com/metacubex/sing/common/debug" "github.com/sagernet/sing/common"
E "github.com/metacubex/sing/common/exceptions" "github.com/sagernet/sing/common/debug"
M "github.com/metacubex/sing/common/metadata" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/wireguard-go/device"
) )
type wireguardGoDevice interface {
Close()
IpcSet(uapiConf string) error
}
type WireGuard struct { type WireGuard struct {
*Base *Base
bind *wireguard.ClientBind bind *wireguard.ClientBind
device wireguardGoDevice device *device.Device
tunDevice wireguard.Device tunDevice wireguard.Device
resolver resolver.Resolver dialer proxydialer.SingDialer
init func(ctx context.Context) error
initOk atomic.Bool resolver *dns.Resolver
initMutex sync.Mutex refP *refProxyAdapter
initErr error
option WireGuardOption
connectAddr M.Socksaddr
localPrefixes []netip.Prefix
serverAddrMap map[M.Socksaddr]netip.AddrPort
serverAddrTime atomic.TypedValue[time.Time]
serverAddrMutex sync.Mutex
} }
type WireGuardOption struct { type WireGuardOption struct {
@@ -66,14 +55,10 @@ type WireGuardOption struct {
UDP bool `proxy:"udp,omitempty"` UDP bool `proxy:"udp,omitempty"`
PersistentKeepalive int `proxy:"persistent-keepalive,omitempty"` PersistentKeepalive int `proxy:"persistent-keepalive,omitempty"`
AmneziaWGOption *AmneziaWGOption `proxy:"amnezia-wg-option,omitempty"`
Peers []WireGuardPeerOption `proxy:"peers,omitempty"` Peers []WireGuardPeerOption `proxy:"peers,omitempty"`
RemoteDnsResolve bool `proxy:"remote-dns-resolve,omitempty"` RemoteDnsResolve bool `proxy:"remote-dns-resolve,omitempty"`
Dns []string `proxy:"dns,omitempty"` Dns []string `proxy:"dns,omitempty"`
RefreshServerIPInterval int `proxy:"refresh-server-ip-interval,omitempty"`
} }
type WireGuardPeerOption struct { type WireGuardPeerOption struct {
@@ -85,29 +70,6 @@ type WireGuardPeerOption struct {
AllowedIPs []string `proxy:"allowed-ips,omitempty"` AllowedIPs []string `proxy:"allowed-ips,omitempty"`
} }
type AmneziaWGOption struct {
JC int `proxy:"jc,omitempty"`
JMin int `proxy:"jmin,omitempty"`
JMax int `proxy:"jmax,omitempty"`
S1 int `proxy:"s1,omitempty"`
S2 int `proxy:"s2,omitempty"`
S3 int `proxy:"s3,omitempty"` // AmneziaWG v1.5 and v2
S4 int `proxy:"s4,omitempty"` // AmneziaWG v1.5 and v2
H1 string `proxy:"h1,omitempty"` // In AmneziaWG v1.x, it was uint32, but our WeaklyTypedInput can handle this situation
H2 string `proxy:"h2,omitempty"` // In AmneziaWG v1.x, it was uint32, but our WeaklyTypedInput can handle this situation
H3 string `proxy:"h3,omitempty"` // In AmneziaWG v1.x, it was uint32, but our WeaklyTypedInput can handle this situation
H4 string `proxy:"h4,omitempty"` // In AmneziaWG v1.x, it was uint32, but our WeaklyTypedInput can handle this situation
I1 string `proxy:"i1,omitempty"` // AmneziaWG v1.5 and v2
I2 string `proxy:"i2,omitempty"` // AmneziaWG v1.5 and v2
I3 string `proxy:"i3,omitempty"` // AmneziaWG v1.5 and v2
I4 string `proxy:"i4,omitempty"` // AmneziaWG v1.5 and v2
I5 string `proxy:"i5,omitempty"` // AmneziaWG v1.5 and v2
J1 string `proxy:"j1,omitempty"` // AmneziaWG v1.5 only (removed in v2)
J2 string `proxy:"j2,omitempty"` // AmneziaWG v1.5 only (removed in v2)
J3 string `proxy:"j3,omitempty"` // AmneziaWG v1.5 only (removed in v2)
Itime int64 `proxy:"itime,omitempty"` // AmneziaWG v1.5 only (removed in v2)
}
type wgSingErrorHandler struct { type wgSingErrorHandler struct {
name string name string
} }
@@ -170,15 +132,27 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
name: option.Name, name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)), addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.WireGuard, tp: C.WireGuard,
pdName: option.ProviderName,
udp: option.UDP, udp: option.UDP,
iface: option.Interface, iface: option.Interface,
rmark: option.RoutingMark, rmark: option.RoutingMark,
prefer: option.IPVersion, prefer: C.NewDNSPrefer(option.IPVersion),
}, },
dialer: proxydialer.NewSlowDownSingDialer(proxydialer.NewByNameSingDialer(option.DialerProxy, dialer.NewDialer()), slowdown.New()),
}
runtime.SetFinalizer(outbound, closeWireGuard)
resolv := func(ctx context.Context, address M.Socksaddr) (netip.AddrPort, error) {
if address.Addr.IsValid() {
return address.AddrPort(), nil
}
udpAddr, err := resolveUDPAddrWithPrefer(ctx, "udp", address.String(), outbound.prefer)
if err != nil {
return netip.AddrPort{}, err
}
// net.ResolveUDPAddr maybe return 4in6 address, so unmap at here
addrPort := udpAddr.AddrPort()
return netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()), nil
} }
outbound.dialer = option.NewDialer(outbound.DialOptions())
singDialer := proxydialer.NewSingDialer(proxydialer.NewSlowDownDialer(outbound.dialer, slowdown.New()))
var reserved [3]uint8 var reserved [3]uint8
if len(option.Reserved) > 0 { if len(option.Reserved) > 0 {
@@ -188,28 +162,29 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
copy(reserved[:], option.Reserved) copy(reserved[:], option.Reserved)
} }
var isConnect bool var isConnect bool
var connectAddr M.Socksaddr
if len(option.Peers) < 2 { if len(option.Peers) < 2 {
isConnect = true isConnect = true
if len(option.Peers) == 1 { if len(option.Peers) == 1 {
outbound.connectAddr = option.Peers[0].Addr() connectAddr = option.Peers[0].Addr()
} else { } else {
outbound.connectAddr = option.Addr() connectAddr = option.Addr()
} }
} }
outbound.bind = wireguard.NewClientBind(context.Background(), wgSingErrorHandler{outbound.Name()}, singDialer, isConnect, outbound.connectAddr.AddrPort(), reserved) outbound.bind = wireguard.NewClientBind(context.Background(), wgSingErrorHandler{outbound.Name()}, outbound.dialer, isConnect, connectAddr.AddrPort(), reserved)
var err error localPrefixes, err := option.Prefixes()
outbound.localPrefixes, err = option.Prefixes()
if err != nil { if err != nil {
return nil, err return nil, err
} }
var privateKey string
{ {
bytes, err := base64.StdEncoding.DecodeString(option.PrivateKey) bytes, err := base64.StdEncoding.DecodeString(option.PrivateKey)
if err != nil { if err != nil {
return nil, E.Cause(err, "decode private key") return nil, E.Cause(err, "decode private key")
} }
option.PrivateKey = hex.EncodeToString(bytes) privateKey = hex.EncodeToString(bytes)
} }
if len(option.Peers) > 0 { if len(option.Peers) > 0 {
@@ -255,49 +230,139 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
option.PreSharedKey = hex.EncodeToString(bytes) option.PreSharedKey = hex.EncodeToString(bytes)
} }
} }
outbound.option = option
var (
initOk atomic.Bool
initMutex sync.Mutex
initErr error
)
outbound.init = func(ctx context.Context) error {
if initOk.Load() {
return nil
}
initMutex.Lock()
defer initMutex.Unlock()
// double check like sync.Once
if initOk.Load() {
return nil
}
if initErr != nil {
return initErr
}
outbound.bind.ResetReservedForEndpoint()
ipcConf := "private_key=" + privateKey
if len(option.Peers) > 0 {
for i, peer := range option.Peers {
destination, err := resolv(ctx, peer.Addr())
if err != nil {
// !!! do not set initErr here !!!
// let us can retry domain resolve in next time
return E.Cause(err, "resolve endpoint domain for peer ", i)
}
ipcConf += "\npublic_key=" + peer.PublicKey
ipcConf += "\nendpoint=" + destination.String()
if peer.PreSharedKey != "" {
ipcConf += "\npreshared_key=" + peer.PreSharedKey
}
for _, allowedIP := range peer.AllowedIPs {
ipcConf += "\nallowed_ip=" + allowedIP
}
if len(peer.Reserved) > 0 {
copy(reserved[:], option.Reserved)
outbound.bind.SetReservedForEndpoint(destination, reserved)
}
}
} else {
ipcConf += "\npublic_key=" + option.PublicKey
destination, err := resolv(ctx, connectAddr)
if err != nil {
// !!! do not set initErr here !!!
// let us can retry domain resolve in next time
return E.Cause(err, "resolve endpoint domain")
}
outbound.bind.SetConnectAddr(destination)
ipcConf += "\nendpoint=" + destination.String()
if option.PreSharedKey != "" {
ipcConf += "\npreshared_key=" + option.PreSharedKey
}
var has4, has6 bool
for _, address := range localPrefixes {
if address.Addr().Is4() {
has4 = true
} else {
has6 = true
}
}
if has4 {
ipcConf += "\nallowed_ip=0.0.0.0/0"
}
if has6 {
ipcConf += "\nallowed_ip=::/0"
}
}
if option.PersistentKeepalive != 0 {
ipcConf += fmt.Sprintf("\npersistent_keepalive_interval=%d", option.PersistentKeepalive)
}
if debug.Enabled {
log.SingLogger.Trace(fmt.Sprintf("[WG](%s) created wireguard ipc conf: \n %s", option.Name, ipcConf))
}
err = outbound.device.IpcSet(ipcConf)
if err != nil {
initErr = E.Cause(err, "setup wireguard")
return initErr
}
err = outbound.tunDevice.Start()
if err != nil {
initErr = err
return initErr
}
initOk.Store(true)
return nil
}
mtu := option.MTU mtu := option.MTU
if mtu == 0 { if mtu == 0 {
mtu = 1408 mtu = 1408
} }
if len(outbound.localPrefixes) == 0 { if len(localPrefixes) == 0 {
return nil, E.New("missing local address") return nil, E.New("missing local address")
} }
outbound.tunDevice, err = wireguard.NewStackDevice(outbound.localPrefixes, uint32(mtu)) outbound.tunDevice, err = wireguard.NewStackDevice(localPrefixes, uint32(mtu))
if err != nil { if err != nil {
return nil, E.Cause(err, "create WireGuard device") return nil, E.Cause(err, "create WireGuard device")
} }
logger := &device.Logger{ outbound.device = device.NewDevice(context.Background(), outbound.tunDevice, outbound.bind, &device.Logger{
Verbosef: func(format string, args ...interface{}) { Verbosef: func(format string, args ...interface{}) {
log.SingLogger.Debug(fmt.Sprintf("[WG](%s) %s", option.Name, fmt.Sprintf(format, args...))) log.SingLogger.Debug(fmt.Sprintf("[WG](%s) %s", option.Name, fmt.Sprintf(format, args...)))
}, },
Errorf: func(format string, args ...interface{}) { Errorf: func(format string, args ...interface{}) {
log.SingLogger.Error(fmt.Sprintf("[WG](%s) %s", option.Name, fmt.Sprintf(format, args...))) log.SingLogger.Error(fmt.Sprintf("[WG](%s) %s", option.Name, fmt.Sprintf(format, args...)))
}, },
} }, option.Workers)
if option.AmneziaWGOption != nil {
outbound.bind.SetParseReserved(false) // AmneziaWG don't need parse reserved
outbound.device = amnezia.NewDevice(outbound.tunDevice, outbound.bind, logger, option.Workers)
} else {
outbound.device = device.NewDevice(outbound.tunDevice, outbound.bind, logger, option.Workers)
}
var has6 bool var has6 bool
for _, address := range outbound.localPrefixes { for _, address := range localPrefixes {
if !address.Addr().Unmap().Is4() { if !address.Addr().Unmap().Is4() {
has6 = true has6 = true
break break
} }
} }
refP := &refProxyAdapter{}
outbound.refP = refP
if option.RemoteDnsResolve && len(option.Dns) > 0 { if option.RemoteDnsResolve && len(option.Dns) > 0 {
nss, err := dns.ParseNameServer(option.Dns) nss, err := dns.ParseNameServer(option.Dns)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for i := range nss { for i := range nss {
nss[i].ProxyAdapter = outbound nss[i].ProxyAdapter = refP
} }
outbound.resolver = dns.NewResolver(dns.Config{ outbound.resolver = dns.NewResolver(dns.Config{
Main: nss, Main: nss,
@@ -308,249 +373,16 @@ func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
return outbound, nil return outbound, nil
} }
func (w *WireGuard) resolve(ctx context.Context, address M.Socksaddr) (netip.AddrPort, error) { func closeWireGuard(w *WireGuard) {
if address.Addr.IsValid() {
return address.AddrPort(), nil
}
udpAddr, err := resolveUDPAddr(ctx, "udp", address.String(), w.prefer)
if err != nil {
return netip.AddrPort{}, err
}
// net.ResolveUDPAddr maybe return 4in6 address, so unmap at here
addrPort := udpAddr.AddrPort()
return netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()), nil
}
func (w *WireGuard) init(ctx context.Context) error {
err := w.init0(ctx)
if err != nil {
return err
}
w.updateServerAddr(ctx)
return nil
}
func (w *WireGuard) init0(ctx context.Context) error {
if w.initOk.Load() {
return nil
}
w.initMutex.Lock()
defer w.initMutex.Unlock()
// double check like sync.Once
if w.initOk.Load() {
return nil
}
if w.initErr != nil {
return w.initErr
}
w.bind.ResetReservedForEndpoint()
w.serverAddrMap = make(map[M.Socksaddr]netip.AddrPort)
ipcConf, err := w.genIpcConf(ctx, false)
if err != nil {
// !!! do not set initErr here !!!
// let us can retry domain resolve in next time
return err
}
if debug.Enabled {
log.SingLogger.Trace(fmt.Sprintf("[WG](%s) created wireguard ipc conf: \n %s", w.option.Name, ipcConf))
}
err = w.device.IpcSet(ipcConf)
if err != nil {
w.initErr = E.Cause(err, "setup wireguard")
return w.initErr
}
w.serverAddrTime.Store(time.Now())
err = w.tunDevice.Start()
if err != nil {
w.initErr = err
return w.initErr
}
w.initOk.Store(true)
return nil
}
func (w *WireGuard) updateServerAddr(ctx context.Context) {
if w.option.RefreshServerIPInterval != 0 && time.Since(w.serverAddrTime.Load()) > time.Second*time.Duration(w.option.RefreshServerIPInterval) {
if w.serverAddrMutex.TryLock() {
defer w.serverAddrMutex.Unlock()
ipcConf, err := w.genIpcConf(ctx, true)
if err != nil {
log.Warnln("[WG](%s)UpdateServerAddr failed to generate wireguard ipc conf: %s", w.option.Name, err)
return
}
err = w.device.IpcSet(ipcConf)
if err != nil {
log.Warnln("[WG](%s)UpdateServerAddr failed to update wireguard ipc conf: %s", w.option.Name, err)
return
}
w.serverAddrTime.Store(time.Now())
}
}
}
func (w *WireGuard) genIpcConf(ctx context.Context, updateOnly bool) (string, error) {
ipcConf := ""
if !updateOnly {
ipcConf += "private_key=" + w.option.PrivateKey + "\n"
if w.option.AmneziaWGOption != nil {
if w.option.AmneziaWGOption.JC != 0 {
ipcConf += "jc=" + strconv.Itoa(w.option.AmneziaWGOption.JC) + "\n"
}
if w.option.AmneziaWGOption.JMin != 0 {
ipcConf += "jmin=" + strconv.Itoa(w.option.AmneziaWGOption.JMin) + "\n"
}
if w.option.AmneziaWGOption.JMax != 0 {
ipcConf += "jmax=" + strconv.Itoa(w.option.AmneziaWGOption.JMax) + "\n"
}
if w.option.AmneziaWGOption.S1 != 0 {
ipcConf += "s1=" + strconv.Itoa(w.option.AmneziaWGOption.S1) + "\n"
}
if w.option.AmneziaWGOption.S2 != 0 {
ipcConf += "s2=" + strconv.Itoa(w.option.AmneziaWGOption.S2) + "\n"
}
if w.option.AmneziaWGOption.S3 != 0 {
ipcConf += "s3=" + strconv.Itoa(w.option.AmneziaWGOption.S3) + "\n"
}
if w.option.AmneziaWGOption.S4 != 0 {
ipcConf += "s4=" + strconv.Itoa(w.option.AmneziaWGOption.S4) + "\n"
}
if w.option.AmneziaWGOption.H1 != "" {
ipcConf += "h1=" + w.option.AmneziaWGOption.H1 + "\n"
}
if w.option.AmneziaWGOption.H2 != "" {
ipcConf += "h2=" + w.option.AmneziaWGOption.H2 + "\n"
}
if w.option.AmneziaWGOption.H3 != "" {
ipcConf += "h3=" + w.option.AmneziaWGOption.H3 + "\n"
}
if w.option.AmneziaWGOption.H4 != "" {
ipcConf += "h4=" + w.option.AmneziaWGOption.H4 + "\n"
}
if w.option.AmneziaWGOption.I1 != "" {
ipcConf += "i1=" + w.option.AmneziaWGOption.I1 + "\n"
}
if w.option.AmneziaWGOption.I2 != "" {
ipcConf += "i2=" + w.option.AmneziaWGOption.I2 + "\n"
}
if w.option.AmneziaWGOption.I3 != "" {
ipcConf += "i3=" + w.option.AmneziaWGOption.I3 + "\n"
}
if w.option.AmneziaWGOption.I4 != "" {
ipcConf += "i4=" + w.option.AmneziaWGOption.I4 + "\n"
}
if w.option.AmneziaWGOption.I5 != "" {
ipcConf += "i5=" + w.option.AmneziaWGOption.I5 + "\n"
}
if w.option.AmneziaWGOption.J1 != "" {
ipcConf += "j1=" + w.option.AmneziaWGOption.J1 + "\n"
}
if w.option.AmneziaWGOption.J2 != "" {
ipcConf += "j2=" + w.option.AmneziaWGOption.J2 + "\n"
}
if w.option.AmneziaWGOption.J3 != "" {
ipcConf += "j3=" + w.option.AmneziaWGOption.J3 + "\n"
}
if w.option.AmneziaWGOption.Itime != 0 {
ipcConf += "itime=" + strconv.FormatInt(int64(w.option.AmneziaWGOption.Itime), 10) + "\n"
}
}
}
if len(w.option.Peers) > 0 {
for i, peer := range w.option.Peers {
peerAddr := peer.Addr()
destination, err := w.resolve(ctx, peerAddr)
if err != nil {
return "", E.Cause(err, "resolve endpoint domain for peer ", i)
}
if w.serverAddrMap[peerAddr] != destination {
w.serverAddrMap[peerAddr] = destination
} else if updateOnly {
continue
}
if len(w.option.Peers) == 1 { // must call SetConnectAddr if isConnect == true
w.bind.SetConnectAddr(destination)
}
ipcConf += "public_key=" + peer.PublicKey + "\n"
if updateOnly {
ipcConf += "update_only=true\n"
}
ipcConf += "endpoint=" + destination.String() + "\n"
if len(peer.Reserved) > 0 {
var reserved [3]uint8
copy(reserved[:], w.option.Reserved)
w.bind.SetReservedForEndpoint(destination, reserved)
}
if updateOnly {
continue
}
if peer.PreSharedKey != "" {
ipcConf += "preshared_key=" + peer.PreSharedKey + "\n"
}
for _, allowedIP := range peer.AllowedIPs {
ipcConf += "allowed_ip=" + allowedIP + "\n"
}
if w.option.PersistentKeepalive != 0 {
ipcConf += fmt.Sprintf("persistent_keepalive_interval=%d\n", w.option.PersistentKeepalive)
}
}
} else {
destination, err := w.resolve(ctx, w.connectAddr)
if err != nil {
return "", E.Cause(err, "resolve endpoint domain")
}
if w.serverAddrMap[w.connectAddr] != destination {
w.serverAddrMap[w.connectAddr] = destination
} else if updateOnly {
return "", nil
}
w.bind.SetConnectAddr(destination) // must call SetConnectAddr if isConnect == true
ipcConf += "public_key=" + w.option.PublicKey + "\n"
if updateOnly {
ipcConf += "update_only=true\n"
}
ipcConf += "endpoint=" + destination.String() + "\n"
if updateOnly {
return ipcConf, nil
}
if w.option.PreSharedKey != "" {
ipcConf += "preshared_key=" + w.option.PreSharedKey + "\n"
}
var has4, has6 bool
for _, address := range w.localPrefixes {
if address.Addr().Is4() {
has4 = true
} else {
has6 = true
}
}
if has4 {
ipcConf += "allowed_ip=0.0.0.0/0\n"
}
if has6 {
ipcConf += "allowed_ip=::/0\n"
}
if w.option.PersistentKeepalive != 0 {
ipcConf += fmt.Sprintf("persistent_keepalive_interval=%d\n", w.option.PersistentKeepalive)
}
}
return ipcConf, nil
}
// Close implements C.ProxyAdapter
func (w *WireGuard) Close() error {
if w.device != nil { if w.device != nil {
w.device.Close() w.device.Close()
} }
return nil _ = common.Close(w.tunDevice)
} }
func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.Conn, err error) { func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
options := w.Base.DialOptions(opts...)
w.dialer.SetDialer(dialer.NewDialer(options...))
var conn net.Conn var conn net.Conn
if err = w.init(ctx); err != nil { if err = w.init(ctx); err != nil {
return nil, err return nil, err
@@ -558,9 +390,10 @@ func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.
if !metadata.Resolved() || w.resolver != nil { if !metadata.Resolved() || w.resolver != nil {
r := resolver.DefaultResolver r := resolver.DefaultResolver
if w.resolver != nil { if w.resolver != nil {
w.refP.SetProxyAdapter(w)
defer w.refP.ClearProxyAdapter()
r = w.resolver r = w.resolver
} }
options := w.DialOptions()
options = append(options, dialer.WithResolver(r)) options = append(options, dialer.WithResolver(r))
options = append(options, dialer.WithNetDialer(wgNetDialer{tunDevice: w.tunDevice})) options = append(options, dialer.WithNetDialer(wgNetDialer{tunDevice: w.tunDevice}))
conn, err = dialer.NewDialer(options...).DialContext(ctx, "tcp", metadata.RemoteAddress()) conn, err = dialer.NewDialer(options...).DialContext(ctx, "tcp", metadata.RemoteAddress())
@@ -573,17 +406,32 @@ func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata) (_ C.
if conn == nil { if conn == nil {
return nil, E.New("conn is nil") return nil, E.New("conn is nil")
} }
return NewConn(conn, w), nil return NewConn(CN.NewRefConn(conn, w), w), nil
} }
func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (_ C.PacketConn, err error) { func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
options := w.Base.DialOptions(opts...)
w.dialer.SetDialer(dialer.NewDialer(options...))
var pc net.PacketConn var pc net.PacketConn
if err = w.init(ctx); err != nil { if err = w.init(ctx); err != nil {
return nil, err return nil, err
} }
if err = w.ResolveUDP(ctx, metadata); err != nil { if err != nil {
return nil, err return nil, err
} }
if (!metadata.Resolved() || w.resolver != nil) && metadata.Host != "" {
r := resolver.DefaultResolver
if w.resolver != nil {
w.refP.SetProxyAdapter(w)
defer w.refP.ClearProxyAdapter()
r = w.resolver
}
ip, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, r)
if err != nil {
return nil, errors.New("can't resolve ip")
}
metadata.DstIP = ip
}
pc, err = w.tunDevice.ListenPacket(ctx, M.SocksaddrFrom(metadata.DstIP, metadata.DstPort).Unwrap()) pc, err = w.tunDevice.ListenPacket(ctx, M.SocksaddrFrom(metadata.DstIP, metadata.DstPort).Unwrap())
if err != nil { if err != nil {
return nil, err return nil, err
@@ -591,32 +439,146 @@ func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadat
if pc == nil { if pc == nil {
return nil, E.New("packetConn is nil") return nil, E.New("packetConn is nil")
} }
return newPacketConn(pc, w), nil return newPacketConn(CN.NewRefPacketConn(pc, w), w), nil
}
func (w *WireGuard) ResolveUDP(ctx context.Context, metadata *C.Metadata) error {
if (!metadata.Resolved() || w.resolver != nil) && metadata.Host != "" {
r := resolver.DefaultResolver
if w.resolver != nil {
r = w.resolver
}
ip, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, r)
if err != nil {
return fmt.Errorf("can't resolve ip: %w", err)
}
metadata.DstIP = ip
}
return nil
}
// ProxyInfo implements C.ProxyAdapter
func (w *WireGuard) ProxyInfo() C.ProxyInfo {
info := w.Base.ProxyInfo()
info.DialerProxy = w.option.DialerProxy
return info
} }
// IsL3Protocol implements C.ProxyAdapter // IsL3Protocol implements C.ProxyAdapter
func (w *WireGuard) IsL3Protocol(metadata *C.Metadata) bool { func (w *WireGuard) IsL3Protocol(metadata *C.Metadata) bool {
return true return true
} }
type refProxyAdapter struct {
proxyAdapter C.ProxyAdapter
count int
mutex sync.Mutex
}
func (r *refProxyAdapter) SetProxyAdapter(proxyAdapter C.ProxyAdapter) {
r.mutex.Lock()
defer r.mutex.Unlock()
r.proxyAdapter = proxyAdapter
r.count++
}
func (r *refProxyAdapter) ClearProxyAdapter() {
r.mutex.Lock()
defer r.mutex.Unlock()
r.count--
if r.count == 0 {
r.proxyAdapter = nil
}
}
func (r *refProxyAdapter) Name() string {
if r.proxyAdapter != nil {
return r.proxyAdapter.Name()
}
return ""
}
func (r *refProxyAdapter) Type() C.AdapterType {
if r.proxyAdapter != nil {
return r.proxyAdapter.Type()
}
return C.AdapterType(0)
}
func (r *refProxyAdapter) Addr() string {
if r.proxyAdapter != nil {
return r.proxyAdapter.Addr()
}
return ""
}
func (r *refProxyAdapter) SupportUDP() bool {
if r.proxyAdapter != nil {
return r.proxyAdapter.SupportUDP()
}
return false
}
func (r *refProxyAdapter) SupportXUDP() bool {
if r.proxyAdapter != nil {
return r.proxyAdapter.SupportXUDP()
}
return false
}
func (r *refProxyAdapter) SupportTFO() bool {
if r.proxyAdapter != nil {
return r.proxyAdapter.SupportTFO()
}
return false
}
func (r *refProxyAdapter) MarshalJSON() ([]byte, error) {
if r.proxyAdapter != nil {
return r.proxyAdapter.MarshalJSON()
}
return nil, C.ErrNotSupport
}
func (r *refProxyAdapter) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) {
if r.proxyAdapter != nil {
return r.proxyAdapter.StreamConnContext(ctx, c, metadata)
}
return nil, C.ErrNotSupport
}
func (r *refProxyAdapter) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
if r.proxyAdapter != nil {
return r.proxyAdapter.DialContext(ctx, metadata, opts...)
}
return nil, C.ErrNotSupport
}
func (r *refProxyAdapter) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
if r.proxyAdapter != nil {
return r.proxyAdapter.ListenPacketContext(ctx, metadata, opts...)
}
return nil, C.ErrNotSupport
}
func (r *refProxyAdapter) SupportUOT() bool {
if r.proxyAdapter != nil {
return r.proxyAdapter.SupportUOT()
}
return false
}
func (r *refProxyAdapter) SupportWithDialer() C.NetWork {
if r.proxyAdapter != nil {
return r.proxyAdapter.SupportWithDialer()
}
return C.InvalidNet
}
func (r *refProxyAdapter) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (C.Conn, error) {
if r.proxyAdapter != nil {
return r.proxyAdapter.DialContextWithDialer(ctx, dialer, metadata)
}
return nil, C.ErrNotSupport
}
func (r *refProxyAdapter) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (C.PacketConn, error) {
if r.proxyAdapter != nil {
return r.proxyAdapter.ListenPacketWithDialer(ctx, dialer, metadata)
}
return nil, C.ErrNotSupport
}
func (r *refProxyAdapter) IsL3Protocol(metadata *C.Metadata) bool {
if r.proxyAdapter != nil {
return r.proxyAdapter.IsL3Protocol(metadata)
}
return false
}
func (r *refProxyAdapter) Unwrap(metadata *C.Metadata, touch bool) C.Proxy {
if r.proxyAdapter != nil {
return r.proxyAdapter.Unwrap(metadata, touch)
}
return nil
}
var _ C.ProxyAdapter = (*refProxyAdapter)(nil)

View File

@@ -6,11 +6,13 @@ import (
"errors" "errors"
"time" "time"
"github.com/metacubex/mihomo/adapter/outbound"
"github.com/metacubex/mihomo/common/callback" "github.com/metacubex/mihomo/common/callback"
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/component/dialer"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
P "github.com/metacubex/mihomo/constant/provider" "github.com/metacubex/mihomo/constant/provider"
) )
type Fallback struct { type Fallback struct {
@@ -29,13 +31,13 @@ func (f *Fallback) Now() string {
} }
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (f *Fallback) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { func (f *Fallback) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
proxy := f.findAliveProxy(true) proxy := f.findAliveProxy(true)
c, err := proxy.DialContext(ctx, metadata) c, err := proxy.DialContext(ctx, metadata, f.Base.DialOptions(opts...)...)
if err == nil { if err == nil {
c.AppendToChains(f) c.AppendToChains(f)
} else { } else {
f.onDialFailed(proxy.Type(), err, f.healthCheck) f.onDialFailed(proxy.Type(), err)
} }
if N.NeedHandshake(c) { if N.NeedHandshake(c) {
@@ -43,7 +45,7 @@ func (f *Fallback) DialContext(ctx context.Context, metadata *C.Metadata) (C.Con
if err == nil { if err == nil {
f.onDialSuccess() f.onDialSuccess()
} else { } else {
f.onDialFailed(proxy.Type(), err, f.healthCheck) f.onDialFailed(proxy.Type(), err)
} }
}) })
} }
@@ -52,9 +54,9 @@ func (f *Fallback) DialContext(ctx context.Context, metadata *C.Metadata) (C.Con
} }
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (f *Fallback) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { func (f *Fallback) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
proxy := f.findAliveProxy(true) proxy := f.findAliveProxy(true)
pc, err := proxy.ListenPacketContext(ctx, metadata) pc, err := proxy.ListenPacketContext(ctx, metadata, f.Base.DialOptions(opts...)...)
if err == nil { if err == nil {
pc.AppendToChains(f) pc.AppendToChains(f)
} }
@@ -150,25 +152,21 @@ func (f *Fallback) ForceSet(name string) {
f.selected = name f.selected = name
} }
func (f *Fallback) Providers() []P.ProxyProvider { func NewFallback(option *GroupCommonOption, providers []provider.ProxyProvider) *Fallback {
return f.providers
}
func (f *Fallback) Proxies() []C.Proxy {
return f.GetProxies(false)
}
func NewFallback(option *GroupCommonOption, providers []P.ProxyProvider) *Fallback {
return &Fallback{ return &Fallback{
GroupBase: NewGroupBase(GroupBaseOption{ GroupBase: NewGroupBase(GroupBaseOption{
Name: option.Name, outbound.BaseOption{
Type: C.Fallback, Name: option.Name,
Filter: option.Filter, Type: C.Fallback,
ExcludeFilter: option.ExcludeFilter, Interface: option.Interface,
ExcludeType: option.ExcludeType, RoutingMark: option.RoutingMark,
TestTimeout: option.TestTimeout, },
MaxFailedTimes: option.MaxFailedTimes, option.Filter,
Providers: providers, option.ExcludeFilter,
option.ExcludeType,
option.TestTimeout,
option.MaxFailedTimes,
providers,
}), }),
disableUDP: option.DisableUDP, disableUDP: option.DisableUDP,
testUrl: option.URL, testUrl: option.URL,

View File

@@ -2,7 +2,6 @@ package outboundgroup
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"strings" "strings"
"sync" "sync"
@@ -12,75 +11,67 @@ import (
"github.com/metacubex/mihomo/common/atomic" "github.com/metacubex/mihomo/common/atomic"
"github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/common/utils"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
P "github.com/metacubex/mihomo/constant/provider" "github.com/metacubex/mihomo/constant/provider"
types "github.com/metacubex/mihomo/constant/provider"
"github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/log"
"github.com/metacubex/mihomo/tunnel" "github.com/metacubex/mihomo/tunnel"
"github.com/dlclark/regexp2" "github.com/dlclark/regexp2"
"golang.org/x/exp/slices"
) )
type GroupBase struct { type GroupBase struct {
*outbound.Base *outbound.Base
filterRegs []*regexp2.Regexp filterRegs []*regexp2.Regexp
excludeFilterRegs []*regexp2.Regexp excludeFilterReg *regexp2.Regexp
excludeTypeArray []string excludeTypeArray []string
providers []P.ProxyProvider providers []provider.ProxyProvider
failedTestMux sync.Mutex failedTestMux sync.Mutex
failedTimes int failedTimes int
failedTime time.Time failedTime time.Time
failedTesting atomic.Bool failedTesting atomic.Bool
TestTimeout int proxies [][]C.Proxy
maxFailedTimes int versions []atomic.Uint32
TestTimeout int
// for GetProxies maxFailedTimes int
getProxiesMutex sync.Mutex
providerVersions []uint32
providerProxies []C.Proxy
} }
type GroupBaseOption struct { type GroupBaseOption struct {
Name string outbound.BaseOption
Type C.AdapterType filter string
Filter string excludeFilter string
ExcludeFilter string excludeType string
ExcludeType string
TestTimeout int TestTimeout int
MaxFailedTimes int maxFailedTimes int
Providers []P.ProxyProvider providers []provider.ProxyProvider
} }
func NewGroupBase(opt GroupBaseOption) *GroupBase { func NewGroupBase(opt GroupBaseOption) *GroupBase {
var excludeTypeArray []string var excludeFilterReg *regexp2.Regexp
if opt.ExcludeType != "" { if opt.excludeFilter != "" {
excludeTypeArray = strings.Split(opt.ExcludeType, "|") excludeFilterReg = regexp2.MustCompile(opt.excludeFilter, 0)
} }
var excludeTypeArray []string
var excludeFilterRegs []*regexp2.Regexp if opt.excludeType != "" {
if opt.ExcludeFilter != "" { excludeTypeArray = strings.Split(opt.excludeType, "|")
for _, excludeFilter := range strings.Split(opt.ExcludeFilter, "`") {
excludeFilterReg := regexp2.MustCompile(excludeFilter, regexp2.None)
excludeFilterRegs = append(excludeFilterRegs, excludeFilterReg)
}
} }
var filterRegs []*regexp2.Regexp var filterRegs []*regexp2.Regexp
if opt.Filter != "" { if opt.filter != "" {
for _, filter := range strings.Split(opt.Filter, "`") { for _, filter := range strings.Split(opt.filter, "`") {
filterReg := regexp2.MustCompile(filter, regexp2.None) filterReg := regexp2.MustCompile(filter, 0)
filterRegs = append(filterRegs, filterReg) filterRegs = append(filterRegs, filterReg)
} }
} }
gb := &GroupBase{ gb := &GroupBase{
Base: outbound.NewBase(outbound.BaseOption{Name: opt.Name, Type: opt.Type}), Base: outbound.NewBase(opt.BaseOption),
filterRegs: filterRegs, filterRegs: filterRegs,
excludeFilterRegs: excludeFilterRegs, excludeFilterReg: excludeFilterReg,
excludeTypeArray: excludeTypeArray, excludeTypeArray: excludeTypeArray,
providers: opt.Providers, providers: opt.providers,
failedTesting: atomic.NewBool(false), failedTesting: atomic.NewBool(false),
TestTimeout: opt.TestTimeout, TestTimeout: opt.TestTimeout,
maxFailedTimes: opt.MaxFailedTimes, maxFailedTimes: opt.maxFailedTimes,
} }
if gb.TestTimeout == 0 { if gb.TestTimeout == 0 {
@@ -90,6 +81,9 @@ func NewGroupBase(opt GroupBaseOption) *GroupBase {
gb.maxFailedTimes = 5 gb.maxFailedTimes = 5
} }
gb.proxies = make([][]C.Proxy, len(opt.providers))
gb.versions = make([]atomic.Uint32, len(opt.providers))
return gb return gb
} }
@@ -100,62 +94,63 @@ func (gb *GroupBase) Touch() {
} }
func (gb *GroupBase) GetProxies(touch bool) []C.Proxy { func (gb *GroupBase) GetProxies(touch bool) []C.Proxy {
providerVersions := make([]uint32, len(gb.providers))
for i, pd := range gb.providers {
if touch { // touch first
pd.Touch()
}
providerVersions[i] = pd.Version()
}
// thread safe
gb.getProxiesMutex.Lock()
defer gb.getProxiesMutex.Unlock()
// return the cached proxies if version not changed
if slices.Equal(providerVersions, gb.providerVersions) {
return gb.providerProxies
}
var proxies []C.Proxy var proxies []C.Proxy
if len(gb.filterRegs) == 0 { if len(gb.filterRegs) == 0 {
for _, pd := range gb.providers { for _, pd := range gb.providers {
if touch {
pd.Touch()
}
proxies = append(proxies, pd.Proxies()...) proxies = append(proxies, pd.Proxies()...)
} }
} else { } else {
for _, pd := range gb.providers { for i, pd := range gb.providers {
if pd.VehicleType() == P.Compatible { // compatible provider unneeded filter if touch {
proxies = append(proxies, pd.Proxies()...) pd.Touch()
}
if pd.VehicleType() == types.Compatible {
gb.versions[i].Store(pd.Version())
gb.proxies[i] = pd.Proxies()
continue continue
} }
var newProxies []C.Proxy version := gb.versions[i].Load()
proxiesSet := map[string]struct{}{} if version != pd.Version() && gb.versions[i].CompareAndSwap(version, pd.Version()) {
for _, filterReg := range gb.filterRegs { var (
for _, p := range pd.Proxies() { proxies []C.Proxy
name := p.Name() newProxies []C.Proxy
if mat, _ := filterReg.MatchString(name); mat { )
if _, ok := proxiesSet[name]; !ok {
proxiesSet[name] = struct{}{} proxies = pd.Proxies()
newProxies = append(newProxies, p) proxiesSet := map[string]struct{}{}
for _, filterReg := range gb.filterRegs {
for _, p := range proxies {
name := p.Name()
if mat, _ := filterReg.FindStringMatch(name); mat != nil {
if _, ok := proxiesSet[name]; !ok {
proxiesSet[name] = struct{}{}
newProxies = append(newProxies, p)
}
} }
} }
} }
gb.proxies[i] = newProxies
} }
proxies = append(proxies, newProxies...) }
for _, p := range gb.proxies {
proxies = append(proxies, p...)
} }
} }
// Multiple filers means that proxies are sorted in the order in which the filers appear.
// Although the filter has been performed once in the previous process,
// when there are multiple providers, the array needs to be reordered as a whole.
if len(gb.providers) > 1 && len(gb.filterRegs) > 1 { if len(gb.providers) > 1 && len(gb.filterRegs) > 1 {
var newProxies []C.Proxy var newProxies []C.Proxy
proxiesSet := map[string]struct{}{} proxiesSet := map[string]struct{}{}
for _, filterReg := range gb.filterRegs { for _, filterReg := range gb.filterRegs {
for _, p := range proxies { for _, p := range proxies {
name := p.Name() name := p.Name()
if mat, _ := filterReg.MatchString(name); mat { if mat, _ := filterReg.FindStringMatch(name); mat != nil {
if _, ok := proxiesSet[name]; !ok { if _, ok := proxiesSet[name]; !ok {
proxiesSet[name] = struct{}{} proxiesSet[name] = struct{}{}
newProxies = append(newProxies, p) newProxies = append(newProxies, p)
@@ -172,31 +167,32 @@ func (gb *GroupBase) GetProxies(touch bool) []C.Proxy {
} }
proxies = newProxies proxies = newProxies
} }
if gb.excludeTypeArray != nil {
if len(gb.excludeFilterRegs) > 0 {
var newProxies []C.Proxy var newProxies []C.Proxy
LOOP1:
for _, p := range proxies { for _, p := range proxies {
name := p.Name() mType := p.Type().String()
for _, excludeFilterReg := range gb.excludeFilterRegs { flag := false
if mat, _ := excludeFilterReg.MatchString(name); mat { for i := range gb.excludeTypeArray {
continue LOOP1 if strings.EqualFold(mType, gb.excludeTypeArray[i]) {
flag = true
break
} }
}
if flag {
continue
} }
newProxies = append(newProxies, p) newProxies = append(newProxies, p)
} }
proxies = newProxies proxies = newProxies
} }
if gb.excludeTypeArray != nil { if gb.excludeFilterReg != nil {
var newProxies []C.Proxy var newProxies []C.Proxy
LOOP2:
for _, p := range proxies { for _, p := range proxies {
mType := p.Type().String() name := p.Name()
for _, excludeType := range gb.excludeTypeArray { if mat, _ := gb.excludeFilterReg.FindStringMatch(name); mat != nil {
if strings.EqualFold(mType, excludeType) { continue
continue LOOP2
}
} }
newProxies = append(newProxies, p) newProxies = append(newProxies, p)
} }
@@ -204,13 +200,9 @@ func (gb *GroupBase) GetProxies(touch bool) []C.Proxy {
} }
if len(proxies) == 0 { if len(proxies) == 0 {
return []C.Proxy{tunnel.Proxies()["COMPATIBLE"]} return append(proxies, tunnel.Proxies()["COMPATIBLE"])
} }
// only cache when proxies not empty
gb.providerVersions = providerVersions
gb.providerProxies = proxies
return proxies return proxies
} }
@@ -242,21 +234,17 @@ func (gb *GroupBase) URLTest(ctx context.Context, url string, expectedStatus uti
} }
} }
func (gb *GroupBase) onDialFailed(adapterType C.AdapterType, err error, fn func()) { func (gb *GroupBase) onDialFailed(adapterType C.AdapterType, err error) {
if adapterType == C.Direct || adapterType == C.Compatible || adapterType == C.Reject || adapterType == C.Pass || adapterType == C.RejectDrop { if adapterType == C.Direct || adapterType == C.Compatible || adapterType == C.Reject || adapterType == C.Pass || adapterType == C.RejectDrop {
return return
} }
if errors.Is(err, C.ErrNotSupport) { if strings.Contains(err.Error(), "connection refused") {
go gb.healthCheck()
return return
} }
go func() { go func() {
if strings.Contains(err.Error(), "connection refused") {
fn()
return
}
gb.failedTestMux.Lock() gb.failedTestMux.Lock()
defer gb.failedTestMux.Unlock() defer gb.failedTestMux.Unlock()
@@ -273,7 +261,7 @@ func (gb *GroupBase) onDialFailed(adapterType C.AdapterType, err error, fn func(
log.Debugln("ProxyGroup: %s failed count: %d", gb.Name(), gb.failedTimes) log.Debugln("ProxyGroup: %s failed count: %d", gb.Name(), gb.failedTimes)
if gb.failedTimes >= gb.maxFailedTimes { if gb.failedTimes >= gb.maxFailedTimes {
log.Warnln("because %s failed multiple times, active health check", gb.Name()) log.Warnln("because %s failed multiple times, active health check", gb.Name())
fn() gb.healthCheck()
} }
} }
}() }()

View File

@@ -9,12 +9,14 @@ import (
"sync" "sync"
"time" "time"
"github.com/metacubex/mihomo/adapter/outbound"
"github.com/metacubex/mihomo/common/callback" "github.com/metacubex/mihomo/common/callback"
"github.com/metacubex/mihomo/common/lru" "github.com/metacubex/mihomo/common/lru"
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/component/dialer"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
P "github.com/metacubex/mihomo/constant/provider" "github.com/metacubex/mihomo/constant/provider"
"golang.org/x/net/publicsuffix" "golang.org/x/net/publicsuffix"
) )
@@ -86,14 +88,14 @@ func jumpHash(key uint64, buckets int32) int32 {
} }
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (lb *LoadBalance) DialContext(ctx context.Context, metadata *C.Metadata) (c C.Conn, err error) { func (lb *LoadBalance) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (c C.Conn, err error) {
proxy := lb.Unwrap(metadata, true) proxy := lb.Unwrap(metadata, true)
c, err = proxy.DialContext(ctx, metadata) c, err = proxy.DialContext(ctx, metadata, lb.Base.DialOptions(opts...)...)
if err == nil { if err == nil {
c.AppendToChains(lb) c.AppendToChains(lb)
} else { } else {
lb.onDialFailed(proxy.Type(), err, lb.healthCheck) lb.onDialFailed(proxy.Type(), err)
} }
if N.NeedHandshake(c) { if N.NeedHandshake(c) {
@@ -101,7 +103,7 @@ func (lb *LoadBalance) DialContext(ctx context.Context, metadata *C.Metadata) (c
if err == nil { if err == nil {
lb.onDialSuccess() lb.onDialSuccess()
} else { } else {
lb.onDialFailed(proxy.Type(), err, lb.healthCheck) lb.onDialFailed(proxy.Type(), err)
} }
}) })
} }
@@ -110,7 +112,7 @@ func (lb *LoadBalance) DialContext(ctx context.Context, metadata *C.Metadata) (c
} }
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (lb *LoadBalance) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (pc C.PacketConn, err error) { func (lb *LoadBalance) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (pc C.PacketConn, err error) {
defer func() { defer func() {
if err == nil { if err == nil {
pc.AppendToChains(lb) pc.AppendToChains(lb)
@@ -118,7 +120,7 @@ func (lb *LoadBalance) ListenPacketContext(ctx context.Context, metadata *C.Meta
}() }()
proxy := lb.Unwrap(metadata, true) proxy := lb.Unwrap(metadata, true)
return proxy.ListenPacketContext(ctx, metadata) return proxy.ListenPacketContext(ctx, metadata, lb.Base.DialOptions(opts...)...)
} }
// SupportUDP implements C.ProxyAdapter // SupportUDP implements C.ProxyAdapter
@@ -194,7 +196,7 @@ func strategyStickySessions(url string) strategyFn {
key := utils.MapHash(getKeyWithSrcAndDst(metadata)) key := utils.MapHash(getKeyWithSrcAndDst(metadata))
length := len(proxies) length := len(proxies)
idx, has := lruCache.Get(key) idx, has := lruCache.Get(key)
if !has || idx >= length { if !has {
idx = int(jumpHash(key+uint64(time.Now().UnixNano()), int32(length))) idx = int(jumpHash(key+uint64(time.Now().UnixNano()), int32(length)))
} }
@@ -202,7 +204,8 @@ func strategyStickySessions(url string) strategyFn {
for i := 1; i < maxRetry; i++ { for i := 1; i < maxRetry; i++ {
proxy := proxies[nowIdx] proxy := proxies[nowIdx]
if proxy.AliveForTestUrl(url) { if proxy.AliveForTestUrl(url) {
if !has || nowIdx != idx { if nowIdx != idx {
lruCache.Delete(key)
lruCache.Set(key, nowIdx) lruCache.Set(key, nowIdx)
} }
@@ -212,6 +215,7 @@ func strategyStickySessions(url string) strategyFn {
} }
} }
lruCache.Delete(key)
lruCache.Set(key, 0) lruCache.Set(key, 0)
return proxies[0] return proxies[0]
} }
@@ -239,19 +243,7 @@ func (lb *LoadBalance) MarshalJSON() ([]byte, error) {
}) })
} }
func (lb *LoadBalance) Providers() []P.ProxyProvider { func NewLoadBalance(option *GroupCommonOption, providers []provider.ProxyProvider, strategy string) (lb *LoadBalance, err error) {
return lb.providers
}
func (lb *LoadBalance) Proxies() []C.Proxy {
return lb.GetProxies(false)
}
func (lb *LoadBalance) Now() string {
return ""
}
func NewLoadBalance(option *GroupCommonOption, providers []P.ProxyProvider, strategy string) (lb *LoadBalance, err error) {
var strategyFn strategyFn var strategyFn strategyFn
switch strategy { switch strategy {
case "consistent-hashing": case "consistent-hashing":
@@ -265,14 +257,18 @@ func NewLoadBalance(option *GroupCommonOption, providers []P.ProxyProvider, stra
} }
return &LoadBalance{ return &LoadBalance{
GroupBase: NewGroupBase(GroupBaseOption{ GroupBase: NewGroupBase(GroupBaseOption{
Name: option.Name, outbound.BaseOption{
Type: C.LoadBalance, Name: option.Name,
Filter: option.Filter, Type: C.LoadBalance,
ExcludeFilter: option.ExcludeFilter, Interface: option.Interface,
ExcludeType: option.ExcludeType, RoutingMark: option.RoutingMark,
TestTimeout: option.TestTimeout, },
MaxFailedTimes: option.MaxFailedTimes, option.Filter,
Providers: providers, option.ExcludeFilter,
option.ExcludeType,
option.TestTimeout,
option.MaxFailedTimes,
providers,
}), }),
strategyFn: strategyFn, strategyFn: strategyFn,
disableUDP: option.DisableUDP, disableUDP: option.DisableUDP,

View File

@@ -5,14 +5,12 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/dlclark/regexp2" "github.com/metacubex/mihomo/adapter/outbound"
"github.com/metacubex/mihomo/adapter/provider" "github.com/metacubex/mihomo/adapter/provider"
"github.com/metacubex/mihomo/common/structure" "github.com/metacubex/mihomo/common/structure"
"github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/common/utils"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
P "github.com/metacubex/mihomo/constant/provider" types "github.com/metacubex/mihomo/constant/provider"
"github.com/metacubex/mihomo/log"
) )
var ( var (
@@ -23,6 +21,7 @@ var (
) )
type GroupCommonOption struct { type GroupCommonOption struct {
outbound.BasicOption
Name string `group:"name"` Name string `group:"name"`
Type string `group:"type"` Type string `group:"type"`
Proxies []string `group:"proxies,omitempty"` Proxies []string `group:"proxies,omitempty"`
@@ -42,13 +41,9 @@ type GroupCommonOption struct {
IncludeAllProviders bool `group:"include-all-providers,omitempty"` IncludeAllProviders bool `group:"include-all-providers,omitempty"`
Hidden bool `group:"hidden,omitempty"` Hidden bool `group:"hidden,omitempty"`
Icon string `group:"icon,omitempty"` Icon string `group:"icon,omitempty"`
// removed configs, only for error logging
Interface string `group:"interface-name,omitempty"`
RoutingMark int `group:"routing-mark,omitempty"`
} }
func ParseProxyGroup(config map[string]any, proxyMap map[string]C.Proxy, providersMap map[string]P.ProxyProvider, AllProxies []string, AllProviders []string) (C.ProxyAdapter, error) { func ParseProxyGroup(config map[string]any, proxyMap map[string]C.Proxy, providersMap map[string]types.ProxyProvider, AllProxies []string, AllProviders []string) (C.ProxyAdapter, error) {
decoder := structure.NewDecoder(structure.Option{TagName: "group", WeaklyTypedInput: true}) decoder := structure.NewDecoder(structure.Option{TagName: "group", WeaklyTypedInput: true})
groupOption := &GroupCommonOption{ groupOption := &GroupCommonOption{
@@ -62,16 +57,9 @@ func ParseProxyGroup(config map[string]any, proxyMap map[string]C.Proxy, provide
return nil, errFormat return nil, errFormat
} }
if groupOption.RoutingMark != 0 {
log.Errorln("The group [%s] with routing-mark configuration was removed, please set it directly on the proxy instead", groupOption.Name)
}
if groupOption.Interface != "" {
log.Errorln("The group [%s] with interface-name configuration was removed, please set it directly on the proxy instead", groupOption.Name)
}
groupName := groupOption.Name groupName := groupOption.Name
providers := []P.ProxyProvider{} providers := []types.ProxyProvider{}
if groupOption.IncludeAll { if groupOption.IncludeAll {
groupOption.IncludeAllProviders = true groupOption.IncludeAllProviders = true
@@ -79,28 +67,10 @@ func ParseProxyGroup(config map[string]any, proxyMap map[string]C.Proxy, provide
} }
if groupOption.IncludeAllProviders { if groupOption.IncludeAllProviders {
groupOption.Use = AllProviders groupOption.Use = append(groupOption.Use, AllProviders...)
} }
if groupOption.IncludeAllProxies { if groupOption.IncludeAllProxies {
if groupOption.Filter != "" { groupOption.Proxies = append(groupOption.Proxies, AllProxies...)
var filterRegs []*regexp2.Regexp
for _, filter := range strings.Split(groupOption.Filter, "`") {
filterReg := regexp2.MustCompile(filter, regexp2.None)
filterRegs = append(filterRegs, filterReg)
}
for _, p := range AllProxies {
for _, filterReg := range filterRegs {
if mat, _ := filterReg.MatchString(p); mat {
groupOption.Proxies = append(groupOption.Proxies, p)
}
}
}
} else {
groupOption.Proxies = append(groupOption.Proxies, AllProxies...)
}
if len(groupOption.Proxies) == 0 && len(groupOption.Use) == 0 {
groupOption.Proxies = []string{"COMPATIBLE"}
}
} }
if len(groupOption.Proxies) == 0 && len(groupOption.Use) == 0 { if len(groupOption.Proxies) == 0 && len(groupOption.Use) == 0 {
@@ -118,29 +88,6 @@ func ParseProxyGroup(config map[string]any, proxyMap map[string]C.Proxy, provide
} }
groupOption.ExpectedStatus = status groupOption.ExpectedStatus = status
if len(groupOption.Use) != 0 {
PDs, err := getProviders(providersMap, groupOption.Use)
if err != nil {
return nil, fmt.Errorf("%s: %w", groupName, err)
}
// if test URL is empty, use the first health check URL of providers
if groupOption.URL == "" {
for _, pd := range PDs {
if pd.HealthCheckURL() != "" {
groupOption.URL = pd.HealthCheckURL()
break
}
}
if groupOption.URL == "" {
groupOption.URL = C.DefaultTestURL
}
} else {
addTestUrlToProviders(PDs, groupOption.URL, expectedStatus, groupOption.Filter, uint(groupOption.Interval))
}
providers = append(providers, PDs...)
}
if len(groupOption.Proxies) != 0 { if len(groupOption.Proxies) != 0 {
ps, err := getProxies(proxyMap, groupOption.Proxies) ps, err := getProxies(proxyMap, groupOption.Proxies)
if err != nil { if err != nil {
@@ -151,15 +98,14 @@ func ParseProxyGroup(config map[string]any, proxyMap map[string]C.Proxy, provide
return nil, fmt.Errorf("%s: %w", groupName, errDuplicateProvider) return nil, fmt.Errorf("%s: %w", groupName, errDuplicateProvider)
} }
if groupOption.URL == "" { // select don't need health check
groupOption.URL = C.DefaultTestURL
}
// select don't need auto health check
if groupOption.Type != "select" && groupOption.Type != "relay" { if groupOption.Type != "select" && groupOption.Type != "relay" {
if groupOption.Interval == 0 { if groupOption.Interval == 0 {
groupOption.Interval = 300 groupOption.Interval = 300
} }
if groupOption.URL == "" {
groupOption.URL = C.DefaultTestURL
}
} }
hc := provider.NewHealthCheck(ps, groupOption.URL, uint(groupOption.TestTimeout), uint(groupOption.Interval), groupOption.Lazy, expectedStatus) hc := provider.NewHealthCheck(ps, groupOption.URL, uint(groupOption.TestTimeout), uint(groupOption.Interval), groupOption.Lazy, expectedStatus)
@@ -169,10 +115,34 @@ func ParseProxyGroup(config map[string]any, proxyMap map[string]C.Proxy, provide
return nil, fmt.Errorf("%s: %w", groupName, err) return nil, fmt.Errorf("%s: %w", groupName, err)
} }
providers = append([]P.ProxyProvider{pd}, providers...) providers = append(providers, pd)
providersMap[groupName] = pd providersMap[groupName] = pd
} }
if len(groupOption.Use) != 0 {
list, err := getProviders(providersMap, groupOption.Use)
if err != nil {
return nil, fmt.Errorf("%s: %w", groupName, err)
}
if groupOption.URL == "" {
for _, p := range list {
if p.HealthCheckURL() != "" {
groupOption.URL = p.HealthCheckURL()
}
break
}
if groupOption.URL == "" {
groupOption.URL = C.DefaultTestURL
}
}
// different proxy groups use different test URL
addTestUrlToProviders(list, groupOption.URL, expectedStatus, groupOption.Filter, uint(groupOption.Interval))
providers = append(providers, list...)
}
var group C.ProxyAdapter var group C.ProxyAdapter
switch groupOption.Type { switch groupOption.Type {
case "url-test": case "url-test":
@@ -186,7 +156,7 @@ func ParseProxyGroup(config map[string]any, proxyMap map[string]C.Proxy, provide
strategy := parseStrategy(config) strategy := parseStrategy(config)
return NewLoadBalance(groupOption, providers, strategy) return NewLoadBalance(groupOption, providers, strategy)
case "relay": case "relay":
return nil, fmt.Errorf("%w: The group [%s] with relay type was removed, please using dialer-proxy instead", errType, groupName) group = NewRelay(groupOption, providers)
default: default:
return nil, fmt.Errorf("%w: %s", errType, groupOption.Type) return nil, fmt.Errorf("%w: %s", errType, groupOption.Type)
} }
@@ -206,15 +176,15 @@ func getProxies(mapping map[string]C.Proxy, list []string) ([]C.Proxy, error) {
return ps, nil return ps, nil
} }
func getProviders(mapping map[string]P.ProxyProvider, list []string) ([]P.ProxyProvider, error) { func getProviders(mapping map[string]types.ProxyProvider, list []string) ([]types.ProxyProvider, error) {
var ps []P.ProxyProvider var ps []types.ProxyProvider
for _, name := range list { for _, name := range list {
p, ok := mapping[name] p, ok := mapping[name]
if !ok { if !ok {
return nil, fmt.Errorf("'%s' not found", name) return nil, fmt.Errorf("'%s' not found", name)
} }
if p.VehicleType() == P.Compatible { if p.VehicleType() == types.Compatible {
return nil, fmt.Errorf("proxy group %s can't contains in `use`", name) return nil, fmt.Errorf("proxy group %s can't contains in `use`", name)
} }
ps = append(ps, p) ps = append(ps, p)
@@ -222,7 +192,7 @@ func getProviders(mapping map[string]P.ProxyProvider, list []string) ([]P.ProxyP
return ps, nil return ps, nil
} }
func addTestUrlToProviders(providers []P.ProxyProvider, url string, expectedStatus utils.IntRanges[uint16], filter string, interval uint) { func addTestUrlToProviders(providers []types.ProxyProvider, url string, expectedStatus utils.IntRanges[uint16], filter string, interval uint) {
if len(providers) == 0 || len(url) == 0 { if len(providers) == 0 || len(url) == 0 {
return return
} }

View File

@@ -0,0 +1,64 @@
//go:build android && cmfa
package outboundgroup
import (
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/constant/provider"
)
type ProxyGroup interface {
C.ProxyAdapter
Providers() []provider.ProxyProvider
Proxies() []C.Proxy
Now() string
}
func (f *Fallback) Providers() []provider.ProxyProvider {
return f.providers
}
func (lb *LoadBalance) Providers() []provider.ProxyProvider {
return lb.providers
}
func (f *Fallback) Proxies() []C.Proxy {
return f.GetProxies(false)
}
func (lb *LoadBalance) Proxies() []C.Proxy {
return lb.GetProxies(false)
}
func (lb *LoadBalance) Now() string {
return ""
}
func (r *Relay) Providers() []provider.ProxyProvider {
return r.providers
}
func (r *Relay) Proxies() []C.Proxy {
return r.GetProxies(false)
}
func (r *Relay) Now() string {
return ""
}
func (s *Selector) Providers() []provider.ProxyProvider {
return s.providers
}
func (s *Selector) Proxies() []C.Proxy {
return s.GetProxies(false)
}
func (u *URLTest) Providers() []provider.ProxyProvider {
return u.providers
}
func (u *URLTest) Proxies() []C.Proxy {
return u.GetProxies(false)
}

View File

@@ -0,0 +1,170 @@
package outboundgroup
import (
"context"
"encoding/json"
"github.com/metacubex/mihomo/adapter/outbound"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/constant/provider"
)
type Relay struct {
*GroupBase
Hidden bool
Icon string
}
// DialContext implements C.ProxyAdapter
func (r *Relay) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
proxies, chainProxies := r.proxies(metadata, true)
switch len(proxies) {
case 0:
return outbound.NewDirect().DialContext(ctx, metadata, r.Base.DialOptions(opts...)...)
case 1:
return proxies[0].DialContext(ctx, metadata, r.Base.DialOptions(opts...)...)
}
var d C.Dialer
d = dialer.NewDialer(r.Base.DialOptions(opts...)...)
for _, proxy := range proxies[:len(proxies)-1] {
d = proxydialer.New(proxy, d, false)
}
last := proxies[len(proxies)-1]
conn, err := last.DialContextWithDialer(ctx, d, metadata)
if err != nil {
return nil, err
}
for i := len(chainProxies) - 2; i >= 0; i-- {
conn.AppendToChains(chainProxies[i])
}
conn.AppendToChains(r)
return conn, nil
}
// ListenPacketContext implements C.ProxyAdapter
func (r *Relay) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
proxies, chainProxies := r.proxies(metadata, true)
switch len(proxies) {
case 0:
return outbound.NewDirect().ListenPacketContext(ctx, metadata, r.Base.DialOptions(opts...)...)
case 1:
return proxies[0].ListenPacketContext(ctx, metadata, r.Base.DialOptions(opts...)...)
}
var d C.Dialer
d = dialer.NewDialer(r.Base.DialOptions(opts...)...)
for _, proxy := range proxies[:len(proxies)-1] {
d = proxydialer.New(proxy, d, false)
}
last := proxies[len(proxies)-1]
pc, err := last.ListenPacketWithDialer(ctx, d, metadata)
if err != nil {
return nil, err
}
for i := len(chainProxies) - 2; i >= 0; i-- {
pc.AppendToChains(chainProxies[i])
}
pc.AppendToChains(r)
return pc, nil
}
// SupportUDP implements C.ProxyAdapter
func (r *Relay) SupportUDP() bool {
proxies, _ := r.proxies(nil, false)
if len(proxies) == 0 { // C.Direct
return true
}
for i := len(proxies) - 1; i >= 0; i-- {
proxy := proxies[i]
if !proxy.SupportUDP() {
return false
}
if proxy.SupportUOT() {
return true
}
switch proxy.SupportWithDialer() {
case C.ALLNet:
case C.UDP:
default: // C.TCP and C.InvalidNet
return false
}
}
return true
}
// MarshalJSON implements C.ProxyAdapter
func (r *Relay) MarshalJSON() ([]byte, error) {
all := []string{}
for _, proxy := range r.GetProxies(false) {
all = append(all, proxy.Name())
}
return json.Marshal(map[string]any{
"type": r.Type().String(),
"all": all,
"hidden": r.Hidden,
"icon": r.Icon,
})
}
func (r *Relay) proxies(metadata *C.Metadata, touch bool) ([]C.Proxy, []C.Proxy) {
rawProxies := r.GetProxies(touch)
var proxies []C.Proxy
var chainProxies []C.Proxy
var targetProxies []C.Proxy
for n, proxy := range rawProxies {
proxies = append(proxies, proxy)
chainProxies = append(chainProxies, proxy)
subproxy := proxy.Unwrap(metadata, touch)
for subproxy != nil {
chainProxies = append(chainProxies, subproxy)
proxies[n] = subproxy
subproxy = subproxy.Unwrap(metadata, touch)
}
}
for _, proxy := range proxies {
if proxy.Type() != C.Direct && proxy.Type() != C.Compatible {
targetProxies = append(targetProxies, proxy)
}
}
return targetProxies, chainProxies
}
func (r *Relay) Addr() string {
proxies, _ := r.proxies(nil, false)
return proxies[len(proxies)-1].Addr()
}
func NewRelay(option *GroupCommonOption, providers []provider.ProxyProvider) *Relay {
return &Relay{
GroupBase: NewGroupBase(GroupBaseOption{
outbound.BaseOption{
Name: option.Name,
Type: C.Relay,
Interface: option.Interface,
RoutingMark: option.RoutingMark,
},
"",
"",
"",
5000,
5,
providers,
}),
Hidden: option.Hidden,
Icon: option.Icon,
}
}

View File

@@ -5,22 +5,23 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"github.com/metacubex/mihomo/adapter/outbound"
"github.com/metacubex/mihomo/component/dialer"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
P "github.com/metacubex/mihomo/constant/provider" "github.com/metacubex/mihomo/constant/provider"
) )
type Selector struct { type Selector struct {
*GroupBase *GroupBase
disableUDP bool disableUDP bool
selected string selected string
testUrl string
Hidden bool Hidden bool
Icon string Icon string
} }
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (s *Selector) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { func (s *Selector) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
c, err := s.selectedProxy(true).DialContext(ctx, metadata) c, err := s.selectedProxy(true).DialContext(ctx, metadata, s.Base.DialOptions(opts...)...)
if err == nil { if err == nil {
c.AppendToChains(s) c.AppendToChains(s)
} }
@@ -28,8 +29,8 @@ func (s *Selector) DialContext(ctx context.Context, metadata *C.Metadata) (C.Con
} }
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (s *Selector) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { func (s *Selector) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
pc, err := s.selectedProxy(true).ListenPacketContext(ctx, metadata) pc, err := s.selectedProxy(true).ListenPacketContext(ctx, metadata, s.Base.DialOptions(opts...)...)
if err == nil { if err == nil {
pc.AppendToChains(s) pc.AppendToChains(s)
} }
@@ -56,20 +57,13 @@ func (s *Selector) MarshalJSON() ([]byte, error) {
for _, proxy := range s.GetProxies(false) { for _, proxy := range s.GetProxies(false) {
all = append(all, proxy.Name()) all = append(all, proxy.Name())
} }
// When testurl is the default value
// do not append a value to ensure that the web dashboard follows the settings of the dashboard
var url string
if s.testUrl != C.DefaultTestURL {
url = s.testUrl
}
return json.Marshal(map[string]any{ return json.Marshal(map[string]any{
"type": s.Type().String(), "type": s.Type().String(),
"now": s.Now(), "now": s.Now(),
"all": all, "all": all,
"testUrl": url, "hidden": s.Hidden,
"hidden": s.Hidden, "icon": s.Icon,
"icon": s.Icon,
}) })
} }
@@ -108,29 +102,24 @@ func (s *Selector) selectedProxy(touch bool) C.Proxy {
return proxies[0] return proxies[0]
} }
func (s *Selector) Providers() []P.ProxyProvider { func NewSelector(option *GroupCommonOption, providers []provider.ProxyProvider) *Selector {
return s.providers
}
func (s *Selector) Proxies() []C.Proxy {
return s.GetProxies(false)
}
func NewSelector(option *GroupCommonOption, providers []P.ProxyProvider) *Selector {
return &Selector{ return &Selector{
GroupBase: NewGroupBase(GroupBaseOption{ GroupBase: NewGroupBase(GroupBaseOption{
Name: option.Name, outbound.BaseOption{
Type: C.Selector, Name: option.Name,
Filter: option.Filter, Type: C.Selector,
ExcludeFilter: option.ExcludeFilter, Interface: option.Interface,
ExcludeType: option.ExcludeType, RoutingMark: option.RoutingMark,
TestTimeout: option.TestTimeout, },
MaxFailedTimes: option.MaxFailedTimes, option.Filter,
Providers: providers, option.ExcludeFilter,
option.ExcludeType,
option.TestTimeout,
option.MaxFailedTimes,
providers,
}), }),
selected: "COMPATIBLE", selected: "COMPATIBLE",
disableUDP: option.DisableUDP, disableUDP: option.DisableUDP,
testUrl: option.URL,
Hidden: option.Hidden, Hidden: option.Hidden,
Icon: option.Icon, Icon: option.Icon,
} }

View File

@@ -4,14 +4,18 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"sync"
"time" "time"
"github.com/metacubex/mihomo/adapter/outbound"
"github.com/metacubex/mihomo/common/callback" "github.com/metacubex/mihomo/common/callback"
N "github.com/metacubex/mihomo/common/net" N "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/common/singledo" "github.com/metacubex/mihomo/common/singledo"
"github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/component/dialer"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
P "github.com/metacubex/mihomo/constant/provider" "github.com/metacubex/mihomo/constant/provider"
) )
type urlTestOption func(*URLTest) type urlTestOption func(*URLTest)
@@ -50,23 +54,23 @@ func (u *URLTest) Set(name string) error {
if p == nil { if p == nil {
return errors.New("proxy not exist") return errors.New("proxy not exist")
} }
u.ForceSet(name) u.selected = name
u.fast(false)
return nil return nil
} }
func (u *URLTest) ForceSet(name string) { func (u *URLTest) ForceSet(name string) {
u.selected = name u.selected = name
u.fastSingle.Reset()
} }
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata) (c C.Conn, err error) { func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (c C.Conn, err error) {
proxy := u.fast(true) proxy := u.fast(true)
c, err = proxy.DialContext(ctx, metadata) c, err = proxy.DialContext(ctx, metadata, u.Base.DialOptions(opts...)...)
if err == nil { if err == nil {
c.AppendToChains(u) c.AppendToChains(u)
} else { } else {
u.onDialFailed(proxy.Type(), err, u.healthCheck) u.onDialFailed(proxy.Type(), err)
} }
if N.NeedHandshake(c) { if N.NeedHandshake(c) {
@@ -74,7 +78,7 @@ func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata) (c C.Co
if err == nil { if err == nil {
u.onDialSuccess() u.onDialSuccess()
} else { } else {
u.onDialFailed(proxy.Type(), err, u.healthCheck) u.onDialFailed(proxy.Type(), err)
} }
}) })
} }
@@ -83,13 +87,10 @@ func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata) (c C.Co
} }
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (u *URLTest) ListenPacketContext(ctx context.Context, metadata *C.Metadata) (C.PacketConn, error) { func (u *URLTest) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
proxy := u.fast(true) pc, err := u.fast(true).ListenPacketContext(ctx, metadata, u.Base.DialOptions(opts...)...)
pc, err := proxy.ListenPacketContext(ctx, metadata)
if err == nil { if err == nil {
pc.AppendToChains(u) pc.AppendToChains(u)
} else {
u.onDialFailed(proxy.Type(), err, u.healthCheck)
} }
return pc, err return pc, err
@@ -100,27 +101,22 @@ func (u *URLTest) Unwrap(metadata *C.Metadata, touch bool) C.Proxy {
return u.fast(touch) return u.fast(touch)
} }
func (u *URLTest) healthCheck() {
u.fastSingle.Reset()
u.GroupBase.healthCheck()
u.fastSingle.Reset()
}
func (u *URLTest) fast(touch bool) C.Proxy { func (u *URLTest) fast(touch bool) C.Proxy {
elm, _, shared := u.fastSingle.Do(func() (C.Proxy, error) {
proxies := u.GetProxies(touch) proxies := u.GetProxies(touch)
if u.selected != "" { if u.selected != "" {
for _, proxy := range proxies { for _, proxy := range proxies {
if !proxy.AliveForTestUrl(u.testUrl) { if !proxy.AliveForTestUrl(u.testUrl) {
continue continue
} }
if proxy.Name() == u.selected { if proxy.Name() == u.selected {
u.fastNode = proxy u.fastNode = proxy
return proxy, nil return proxy
}
} }
} }
}
elm, _, shared := u.fastSingle.Do(func() (C.Proxy, error) {
fast := proxies[0] fast := proxies[0]
minDelay := fast.LastDelayForTestUrl(u.testUrl) minDelay := fast.LastDelayForTestUrl(u.testUrl)
fastNotExist := true fastNotExist := true
@@ -185,16 +181,32 @@ func (u *URLTest) MarshalJSON() ([]byte, error) {
}) })
} }
func (u *URLTest) Providers() []P.ProxyProvider {
return u.providers
}
func (u *URLTest) Proxies() []C.Proxy {
return u.GetProxies(false)
}
func (u *URLTest) URLTest(ctx context.Context, url string, expectedStatus utils.IntRanges[uint16]) (map[string]uint16, error) { func (u *URLTest) URLTest(ctx context.Context, url string, expectedStatus utils.IntRanges[uint16]) (map[string]uint16, error) {
return u.GroupBase.URLTest(ctx, u.testUrl, expectedStatus) var wg sync.WaitGroup
var lock sync.Mutex
mp := map[string]uint16{}
proxies := u.GetProxies(false)
for _, proxy := range proxies {
proxy := proxy
wg.Add(1)
go func() {
delay, err := proxy.URLTest(ctx, u.testUrl, expectedStatus)
if err == nil {
lock.Lock()
mp[proxy.Name()] = delay
lock.Unlock()
}
wg.Done()
}()
}
wg.Wait()
if len(mp) == 0 {
return mp, fmt.Errorf("get delay: all proxies timeout")
} else {
return mp, nil
}
} }
func parseURLTestOption(config map[string]any) []urlTestOption { func parseURLTestOption(config map[string]any) []urlTestOption {
@@ -210,17 +222,22 @@ func parseURLTestOption(config map[string]any) []urlTestOption {
return opts return opts
} }
func NewURLTest(option *GroupCommonOption, providers []P.ProxyProvider, options ...urlTestOption) *URLTest { func NewURLTest(option *GroupCommonOption, providers []provider.ProxyProvider, options ...urlTestOption) *URLTest {
urlTest := &URLTest{ urlTest := &URLTest{
GroupBase: NewGroupBase(GroupBaseOption{ GroupBase: NewGroupBase(GroupBaseOption{
Name: option.Name, outbound.BaseOption{
Type: C.URLTest, Name: option.Name,
Filter: option.Filter, Type: C.URLTest,
ExcludeFilter: option.ExcludeFilter, Interface: option.Interface,
ExcludeType: option.ExcludeType, RoutingMark: option.RoutingMark,
TestTimeout: option.TestTimeout, },
MaxFailedTimes: option.MaxFailedTimes,
Providers: providers, option.Filter,
option.ExcludeFilter,
option.ExcludeType,
option.TestTimeout,
option.MaxFailedTimes,
providers,
}), }),
fastSingle: singledo.NewSingle[C.Proxy](time.Second * 10), fastSingle: singledo.NewSingle[C.Proxy](time.Second * 10),
disableUDP: option.DisableUDP, disableUDP: option.DisableUDP,

View File

@@ -1,34 +1,6 @@
package outboundgroup package outboundgroup
import (
"context"
"github.com/metacubex/mihomo/common/utils"
C "github.com/metacubex/mihomo/constant"
P "github.com/metacubex/mihomo/constant/provider"
)
type ProxyGroup interface {
C.ProxyAdapter
Providers() []P.ProxyProvider
Proxies() []C.Proxy
Now() string
Touch()
URLTest(ctx context.Context, url string, expectedStatus utils.IntRanges[uint16]) (mp map[string]uint16, err error)
}
var _ ProxyGroup = (*Fallback)(nil)
var _ ProxyGroup = (*LoadBalance)(nil)
var _ ProxyGroup = (*URLTest)(nil)
var _ ProxyGroup = (*Selector)(nil)
type SelectAble interface { type SelectAble interface {
Set(string) error Set(string) error
ForceSet(name string) ForceSet(name string)
} }
var _ SelectAble = (*Fallback)(nil)
var _ SelectAble = (*URLTest)(nil)
var _ SelectAble = (*Selector)(nil)

View File

@@ -3,169 +3,144 @@ package adapter
import ( import (
"fmt" "fmt"
tlsC "github.com/metacubex/mihomo/component/tls"
"github.com/metacubex/mihomo/adapter/outbound" "github.com/metacubex/mihomo/adapter/outbound"
"github.com/metacubex/mihomo/common/structure" "github.com/metacubex/mihomo/common/structure"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
) )
func ParseProxy(mapping map[string]any, options ...ProxyOption) (C.Proxy, error) { func ParseProxy(mapping map[string]any) (C.Proxy, error) {
decoder := structure.NewDecoder(structure.Option{TagName: "proxy", WeaklyTypedInput: true, KeyReplacer: structure.DefaultKeyReplacer}) decoder := structure.NewDecoder(structure.Option{TagName: "proxy", WeaklyTypedInput: true, KeyReplacer: structure.DefaultKeyReplacer})
proxyType, existType := mapping["type"].(string) proxyType, existType := mapping["type"].(string)
if !existType { if !existType {
return nil, fmt.Errorf("missing type") return nil, fmt.Errorf("missing type")
} }
opt := applyProxyOptions(options...)
basicOption := outbound.BasicOption{
DialerForAPI: opt.DialerForAPI,
ProviderName: opt.ProviderName,
}
var ( var (
proxy outbound.ProxyAdapter proxy C.ProxyAdapter
err error err error
) )
switch proxyType { switch proxyType {
case "ss": case "ss":
ssOption := &outbound.ShadowSocksOption{BasicOption: basicOption} ssOption := &outbound.ShadowSocksOption{ClientFingerprint: tlsC.GetGlobalFingerprint()}
err = decoder.Decode(mapping, ssOption) err = decoder.Decode(mapping, ssOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewShadowSocks(*ssOption) proxy, err = outbound.NewShadowSocks(*ssOption)
case "ssr": case "ssr":
ssrOption := &outbound.ShadowSocksROption{BasicOption: basicOption} ssrOption := &outbound.ShadowSocksROption{}
err = decoder.Decode(mapping, ssrOption) err = decoder.Decode(mapping, ssrOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewShadowSocksR(*ssrOption) proxy, err = outbound.NewShadowSocksR(*ssrOption)
case "socks5": case "socks5":
socksOption := &outbound.Socks5Option{BasicOption: basicOption} socksOption := &outbound.Socks5Option{}
err = decoder.Decode(mapping, socksOption) err = decoder.Decode(mapping, socksOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewSocks5(*socksOption) proxy, err = outbound.NewSocks5(*socksOption)
case "http": case "http":
httpOption := &outbound.HttpOption{BasicOption: basicOption} httpOption := &outbound.HttpOption{}
err = decoder.Decode(mapping, httpOption) err = decoder.Decode(mapping, httpOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewHttp(*httpOption) proxy, err = outbound.NewHttp(*httpOption)
case "vmess": case "vmess":
vmessOption := &outbound.VmessOption{BasicOption: basicOption} vmessOption := &outbound.VmessOption{
HTTPOpts: outbound.HTTPOptions{
Method: "GET",
Path: []string{"/"},
},
ClientFingerprint: tlsC.GetGlobalFingerprint(),
}
err = decoder.Decode(mapping, vmessOption) err = decoder.Decode(mapping, vmessOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewVmess(*vmessOption) proxy, err = outbound.NewVmess(*vmessOption)
case "vless": case "vless":
vlessOption := &outbound.VlessOption{BasicOption: basicOption} vlessOption := &outbound.VlessOption{ClientFingerprint: tlsC.GetGlobalFingerprint()}
err = decoder.Decode(mapping, vlessOption) err = decoder.Decode(mapping, vlessOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewVless(*vlessOption) proxy, err = outbound.NewVless(*vlessOption)
case "snell": case "snell":
snellOption := &outbound.SnellOption{BasicOption: basicOption} snellOption := &outbound.SnellOption{}
err = decoder.Decode(mapping, snellOption) err = decoder.Decode(mapping, snellOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewSnell(*snellOption) proxy, err = outbound.NewSnell(*snellOption)
case "trojan": case "trojan":
trojanOption := &outbound.TrojanOption{BasicOption: basicOption} trojanOption := &outbound.TrojanOption{ClientFingerprint: tlsC.GetGlobalFingerprint()}
err = decoder.Decode(mapping, trojanOption) err = decoder.Decode(mapping, trojanOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewTrojan(*trojanOption) proxy, err = outbound.NewTrojan(*trojanOption)
case "hysteria": case "hysteria":
hyOption := &outbound.HysteriaOption{BasicOption: basicOption} hyOption := &outbound.HysteriaOption{}
err = decoder.Decode(mapping, hyOption) err = decoder.Decode(mapping, hyOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewHysteria(*hyOption) proxy, err = outbound.NewHysteria(*hyOption)
case "hysteria2": case "hysteria2":
hyOption := &outbound.Hysteria2Option{BasicOption: basicOption} hyOption := &outbound.Hysteria2Option{}
err = decoder.Decode(mapping, hyOption) err = decoder.Decode(mapping, hyOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewHysteria2(*hyOption) proxy, err = outbound.NewHysteria2(*hyOption)
case "wireguard": case "wireguard":
wgOption := &outbound.WireGuardOption{BasicOption: basicOption} wgOption := &outbound.WireGuardOption{}
err = decoder.Decode(mapping, wgOption) err = decoder.Decode(mapping, wgOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewWireGuard(*wgOption) proxy, err = outbound.NewWireGuard(*wgOption)
case "tuic": case "tuic":
tuicOption := &outbound.TuicOption{BasicOption: basicOption} tuicOption := &outbound.TuicOption{}
err = decoder.Decode(mapping, tuicOption) err = decoder.Decode(mapping, tuicOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewTuic(*tuicOption) proxy, err = outbound.NewTuic(*tuicOption)
case "direct": case "direct":
directOption := &outbound.DirectOption{BasicOption: basicOption} directOption := &outbound.DirectOption{}
err = decoder.Decode(mapping, directOption) err = decoder.Decode(mapping, directOption)
if err != nil { if err != nil {
break break
} }
proxy = outbound.NewDirectWithOption(*directOption) proxy = outbound.NewDirectWithOption(*directOption)
case "dns": case "dns":
dnsOptions := &outbound.DnsOption{BasicOption: basicOption} dnsOptions := &outbound.DnsOption{}
err = decoder.Decode(mapping, dnsOptions) err = decoder.Decode(mapping, dnsOptions)
if err != nil { if err != nil {
break break
} }
proxy = outbound.NewDnsWithOption(*dnsOptions) proxy = outbound.NewDnsWithOption(*dnsOptions)
case "reject": case "reject":
rejectOption := &outbound.RejectOption{BasicOption: basicOption} rejectOption := &outbound.RejectOption{}
err = decoder.Decode(mapping, rejectOption) err = decoder.Decode(mapping, rejectOption)
if err != nil { if err != nil {
break break
} }
proxy = outbound.NewRejectWithOption(*rejectOption) proxy = outbound.NewRejectWithOption(*rejectOption)
case "ssh": case "ssh":
sshOption := &outbound.SshOption{BasicOption: basicOption} sshOption := &outbound.SshOption{}
err = decoder.Decode(mapping, sshOption) err = decoder.Decode(mapping, sshOption)
if err != nil { if err != nil {
break break
} }
proxy, err = outbound.NewSsh(*sshOption) proxy, err = outbound.NewSsh(*sshOption)
case "mieru":
mieruOption := &outbound.MieruOption{BasicOption: basicOption}
err = decoder.Decode(mapping, mieruOption)
if err != nil {
break
}
proxy, err = outbound.NewMieru(*mieruOption)
case "anytls":
anytlsOption := &outbound.AnyTLSOption{BasicOption: basicOption}
err = decoder.Decode(mapping, anytlsOption)
if err != nil {
break
}
proxy, err = outbound.NewAnyTLS(*anytlsOption)
case "sudoku":
sudokuOption := &outbound.SudokuOption{BasicOption: basicOption}
err = decoder.Decode(mapping, sudokuOption)
if err != nil {
break
}
proxy, err = outbound.NewSudoku(*sudokuOption)
case "masque":
masqueOption := &outbound.MasqueOption{BasicOption: basicOption}
err = decoder.Decode(mapping, masqueOption)
if err != nil {
break
}
proxy, err = outbound.NewMasque(*masqueOption)
default: default:
return nil, fmt.Errorf("unsupport proxy type: %s", proxyType) return nil, fmt.Errorf("unsupport proxy type: %s", proxyType)
} }
@@ -181,40 +156,12 @@ func ParseProxy(mapping map[string]any, options ...ProxyOption) (C.Proxy, error)
return nil, err return nil, err
} }
if muxOption.Enabled { if muxOption.Enabled {
proxy, err = outbound.NewSingMux(*muxOption, proxy) proxy, err = outbound.NewSingMux(*muxOption, proxy, proxy.(outbound.ProxyBase))
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
} }
proxy = outbound.NewAutoCloseProxyAdapter(proxy)
return NewProxy(proxy), nil return NewProxy(proxy), nil
} }
type proxyOption struct {
DialerForAPI C.Dialer
ProviderName string
}
func applyProxyOptions(options ...ProxyOption) proxyOption {
opt := proxyOption{}
for _, o := range options {
o(&opt)
}
return opt
}
type ProxyOption func(opt *proxyOption)
func WithDialerForAPI(dialer C.Dialer) ProxyOption {
return func(opt *proxyOption) {
opt.DialerForAPI = dialer
}
}
func WithProviderName(name string) ProxyOption {
return func(opt *proxyOption) {
opt.ProviderName = name
}
}

View File

@@ -7,13 +7,13 @@ import (
"time" "time"
"github.com/metacubex/mihomo/common/atomic" "github.com/metacubex/mihomo/common/atomic"
"github.com/metacubex/mihomo/common/batch"
"github.com/metacubex/mihomo/common/singledo" "github.com/metacubex/mihomo/common/singledo"
"github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/common/utils"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log" "github.com/metacubex/mihomo/log"
"github.com/dlclark/regexp2" "github.com/dlclark/regexp2"
"golang.org/x/sync/errgroup"
) )
type HealthCheckOption struct { type HealthCheckOption struct {
@@ -27,23 +27,28 @@ type extraOption struct {
} }
type HealthCheck struct { type HealthCheck struct {
ctx context.Context
ctxCancel context.CancelFunc
url string url string
extra map[string]*extraOption extra map[string]*extraOption
mu sync.Mutex mu sync.Mutex
started atomic.Bool
proxies []C.Proxy proxies []C.Proxy
interval time.Duration interval time.Duration
lazy bool lazy bool
expectedStatus utils.IntRanges[uint16] expectedStatus utils.IntRanges[uint16]
lastTouch atomic.TypedValue[time.Time] lastTouch atomic.TypedValue[time.Time]
done chan struct{}
singleDo *singledo.Single[struct{}] singleDo *singledo.Single[struct{}]
timeout time.Duration timeout time.Duration
} }
func (hc *HealthCheck) process() { func (hc *HealthCheck) process() {
if hc.started.Load() {
log.Warnln("Skip start health check timer due to it's started")
return
}
ticker := time.NewTicker(hc.interval) ticker := time.NewTicker(hc.interval)
go hc.check() hc.start()
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
@@ -54,14 +59,15 @@ func (hc *HealthCheck) process() {
} else { } else {
log.Debugln("Skip once health check because we are lazy") log.Debugln("Skip once health check because we are lazy")
} }
case <-hc.ctx.Done(): case <-hc.done:
ticker.Stop() ticker.Stop()
hc.stop()
return return
} }
} }
} }
func (hc *HealthCheck) setProxies(proxies []C.Proxy) { func (hc *HealthCheck) setProxy(proxies []C.Proxy) {
hc.proxies = proxies hc.proxies = proxies
} }
@@ -98,6 +104,10 @@ func (hc *HealthCheck) registerHealthCheckTask(url string, expectedStatus utils.
option := &extraOption{filters: map[string]struct{}{}, expectedStatus: expectedStatus} option := &extraOption{filters: map[string]struct{}{}, expectedStatus: expectedStatus}
splitAndAddFiltersToExtra(filter, option) splitAndAddFiltersToExtra(filter, option)
hc.extra[url] = option hc.extra[url] = option
if hc.auto() && !hc.started.Load() {
go hc.process()
}
} }
func splitAndAddFiltersToExtra(filter string, option *extraOption) { func splitAndAddFiltersToExtra(filter string, option *extraOption) {
@@ -120,6 +130,14 @@ func (hc *HealthCheck) touch() {
hc.lastTouch.Store(time.Now()) hc.lastTouch.Store(time.Now())
} }
func (hc *HealthCheck) start() {
hc.started.Store(true)
}
func (hc *HealthCheck) stop() {
hc.started.Store(false)
}
func (hc *HealthCheck) check() { func (hc *HealthCheck) check() {
if len(hc.proxies) == 0 { if len(hc.proxies) == 0 {
return return
@@ -128,8 +146,7 @@ func (hc *HealthCheck) check() {
_, _, _ = hc.singleDo.Do(func() (struct{}, error) { _, _, _ = hc.singleDo.Do(func() (struct{}, error) {
id := utils.NewUUIDV4().String() id := utils.NewUUIDV4().String()
log.Debugln("Start New Health Checking {%s}", id) log.Debugln("Start New Health Checking {%s}", id)
b := new(errgroup.Group) b, _ := batch.New[bool](context.Background(), batch.WithConcurrencyNum[bool](10))
b.SetLimit(10)
// execute default health check // execute default health check
option := &extraOption{filters: nil, expectedStatus: hc.expectedStatus} option := &extraOption{filters: nil, expectedStatus: hc.expectedStatus}
@@ -141,13 +158,13 @@ func (hc *HealthCheck) check() {
hc.execute(b, url, id, option) hc.execute(b, url, id, option)
} }
} }
_ = b.Wait() b.Wait()
log.Debugln("Finish A Health Checking {%s}", id) log.Debugln("Finish A Health Checking {%s}", id)
return struct{}{}, nil return struct{}{}, nil
}) })
} }
func (hc *HealthCheck) execute(b *errgroup.Group, url, uid string, option *extraOption) { func (hc *HealthCheck) execute(b *batch.Batch[bool], url, uid string, option *extraOption) {
url = strings.TrimSpace(url) url = strings.TrimSpace(url)
if len(url) == 0 { if len(url) == 0 {
log.Debugln("Health Check has been skipped due to testUrl is empty, {%s}", uid) log.Debugln("Health Check has been skipped due to testUrl is empty, {%s}", uid)
@@ -164,32 +181,32 @@ func (hc *HealthCheck) execute(b *errgroup.Group, url, uid string, option *extra
filters = append(filters, filter) filters = append(filters, filter)
} }
filterReg = regexp2.MustCompile(strings.Join(filters, "|"), regexp2.None) filterReg = regexp2.MustCompile(strings.Join(filters, "|"), 0)
} }
} }
for _, proxy := range hc.proxies { for _, proxy := range hc.proxies {
// skip proxies that do not require health check // skip proxies that do not require health check
if filterReg != nil { if filterReg != nil {
if match, _ := filterReg.MatchString(proxy.Name()); !match { if match, _ := filterReg.FindStringMatch(proxy.Name()); match == nil {
continue continue
} }
} }
p := proxy p := proxy
b.Go(func() error { b.Go(p.Name(), func() (bool, error) {
ctx, cancel := context.WithTimeout(hc.ctx, hc.timeout) ctx, cancel := context.WithTimeout(context.Background(), hc.timeout)
defer cancel() defer cancel()
log.Debugln("Health Checking, proxy: %s, url: %s, id: {%s}", p.Name(), url, uid) log.Debugln("Health Checking, proxy: %s, url: %s, id: {%s}", p.Name(), url, uid)
_, _ = p.URLTest(ctx, url, expectedStatus) _, _ = p.URLTest(ctx, url, expectedStatus)
log.Debugln("Health Checked, proxy: %s, url: %s, alive: %t, delay: %d ms uid: {%s}", p.Name(), url, p.AliveForTestUrl(url), p.LastDelayForTestUrl(url), uid) log.Debugln("Health Checked, proxy: %s, url: %s, alive: %t, delay: %d ms uid: {%s}", p.Name(), url, p.AliveForTestUrl(url), p.LastDelayForTestUrl(url), uid)
return nil return false, nil
}) })
} }
} }
func (hc *HealthCheck) close() { func (hc *HealthCheck) close() {
hc.ctxCancel() hc.done <- struct{}{}
} }
func NewHealthCheck(proxies []C.Proxy, url string, timeout uint, interval uint, lazy bool, expectedStatus utils.IntRanges[uint16]) *HealthCheck { func NewHealthCheck(proxies []C.Proxy, url string, timeout uint, interval uint, lazy bool, expectedStatus utils.IntRanges[uint16]) *HealthCheck {
@@ -200,11 +217,8 @@ func NewHealthCheck(proxies []C.Proxy, url string, timeout uint, interval uint,
if timeout == 0 { if timeout == 0 {
timeout = 5000 timeout = 5000
} }
ctx, cancel := context.WithCancel(context.Background())
return &HealthCheck{ return &HealthCheck{
ctx: ctx,
ctxCancel: cancel,
proxies: proxies, proxies: proxies,
url: url, url: url,
timeout: time.Duration(timeout) * time.Millisecond, timeout: time.Duration(timeout) * time.Millisecond,
@@ -212,6 +226,7 @@ func NewHealthCheck(proxies []C.Proxy, url string, timeout uint, interval uint,
interval: time.Duration(interval) * time.Second, interval: time.Duration(interval) * time.Second,
lazy: lazy, lazy: lazy,
expectedStatus: expectedStatus, expectedStatus: expectedStatus,
done: make(chan struct{}, 1),
singleDo: singledo.NewSingle[struct{}](time.Second), singleDo: singledo.NewSingle[struct{}](time.Second),
} }
} }

View File

@@ -1,88 +0,0 @@
package provider
import (
"encoding"
"fmt"
"github.com/dlclark/regexp2"
)
type overrideSchema struct {
TFO *bool `provider:"tfo,omitempty"`
MPTcp *bool `provider:"mptcp,omitempty"`
UDP *bool `provider:"udp,omitempty"`
UDPOverTCP *bool `provider:"udp-over-tcp,omitempty"`
Up *string `provider:"up,omitempty"`
Down *string `provider:"down,omitempty"`
DialerProxy *string `provider:"dialer-proxy,omitempty"`
SkipCertVerify *bool `provider:"skip-cert-verify,omitempty"`
Interface *string `provider:"interface-name,omitempty"`
RoutingMark *int `provider:"routing-mark,omitempty"`
IPVersion *string `provider:"ip-version,omitempty"`
AdditionalPrefix *string `provider:"additional-prefix,omitempty"`
AdditionalSuffix *string `provider:"additional-suffix,omitempty"`
ProxyName []overrideProxyNameSchema `provider:"proxy-name,omitempty"`
}
type overrideProxyNameSchema struct {
// matching expression for regex replacement
Pattern *regexp2.Regexp `provider:"pattern"`
// the new content after regex matching
Target string `provider:"target"`
}
var _ encoding.TextUnmarshaler = (*regexp2.Regexp)(nil) // ensure *regexp2.Regexp can decode direct by structure package
func (o *overrideSchema) Apply(mapping map[string]any) error {
if o.TFO != nil {
mapping["tfo"] = *o.TFO
}
if o.MPTcp != nil {
mapping["mptcp"] = *o.MPTcp
}
if o.UDP != nil {
mapping["udp"] = *o.UDP
}
if o.UDPOverTCP != nil {
mapping["udp-over-tcp"] = *o.UDPOverTCP
}
if o.Up != nil {
mapping["up"] = *o.Up
}
if o.Down != nil {
mapping["down"] = *o.Down
}
if o.DialerProxy != nil {
mapping["dialer-proxy"] = *o.DialerProxy
}
if o.SkipCertVerify != nil {
mapping["skip-cert-verify"] = *o.SkipCertVerify
}
if o.Interface != nil {
mapping["interface"] = *o.Interface
}
if o.RoutingMark != nil {
mapping["routing-mark"] = *o.RoutingMark
}
if o.IPVersion != nil {
mapping["ip-version"] = *o.IPVersion
}
for _, expr := range o.ProxyName {
name := mapping["name"].(string)
newName, err := expr.Pattern.Replace(name, expr.Target, 0, -1)
if err != nil {
return fmt.Errorf("proxy name replace error: %w", err)
}
mapping["name"] = newName
}
if o.AdditionalPrefix != nil {
mapping["name"] = fmt.Sprintf("%s%s", *o.AdditionalPrefix, mapping["name"])
}
if o.AdditionalSuffix != nil {
mapping["name"] = fmt.Sprintf("%s%s", mapping["name"], *o.AdditionalSuffix)
}
return nil
}

View File

@@ -9,11 +9,13 @@ import (
"github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/component/resource" "github.com/metacubex/mihomo/component/resource"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
P "github.com/metacubex/mihomo/constant/provider" "github.com/metacubex/mihomo/constant/features"
types "github.com/metacubex/mihomo/constant/provider"
) )
var ( var (
errVehicleType = errors.New("unsupport vehicle type") errVehicleType = errors.New("unsupport vehicle type")
errSubPath = errors.New("path is not subpath of home directory")
) )
type healthCheckSchema struct { type healthCheckSchema struct {
@@ -25,25 +27,34 @@ type healthCheckSchema struct {
ExpectedStatus string `provider:"expected-status,omitempty"` ExpectedStatus string `provider:"expected-status,omitempty"`
} }
type proxyProviderSchema struct { type OverrideSchema struct {
Type string `provider:"type"` UDP *bool `provider:"udp,omitempty"`
Path string `provider:"path,omitempty"` Up *string `provider:"up,omitempty"`
URL string `provider:"url,omitempty"` Down *string `provider:"down,omitempty"`
Proxy string `provider:"proxy,omitempty"` DialerProxy *string `provider:"dialer-proxy,omitempty"`
Interval int `provider:"interval,omitempty"` SkipCertVerify *bool `provider:"skip-cert-verify,omitempty"`
Filter string `provider:"filter,omitempty"` Interface *string `provider:"interface-name,omitempty"`
ExcludeFilter string `provider:"exclude-filter,omitempty"` RoutingMark *int `provider:"routing-mark,omitempty"`
ExcludeType string `provider:"exclude-type,omitempty"` IPVersion *string `provider:"ip-version,omitempty"`
DialerProxy string `provider:"dialer-proxy,omitempty"` AdditionalPrefix *string `provider:"additional-prefix,omitempty"`
SizeLimit int64 `provider:"size-limit,omitempty"` AdditionalSuffix *string `provider:"additional-suffix,omitempty"`
Payload []map[string]any `provider:"payload,omitempty"`
HealthCheck healthCheckSchema `provider:"health-check,omitempty"`
Override overrideSchema `provider:"override,omitempty"`
Header map[string][]string `provider:"header,omitempty"`
} }
func ParseProxyProvider(name string, mapping map[string]any) (P.ProxyProvider, error) { type proxyProviderSchema struct {
Type string `provider:"type"`
Path string `provider:"path,omitempty"`
URL string `provider:"url,omitempty"`
Interval int `provider:"interval,omitempty"`
Filter string `provider:"filter,omitempty"`
ExcludeFilter string `provider:"exclude-filter,omitempty"`
ExcludeType string `provider:"exclude-type,omitempty"`
DialerProxy string `provider:"dialer-proxy,omitempty"`
HealthCheck healthCheckSchema `provider:"health-check,omitempty"`
Override OverrideSchema `provider:"override,omitempty"`
}
func ParseProxyProvider(name string, mapping map[string]any) (types.ProxyProvider, error) {
decoder := structure.NewDecoder(structure.Option{TagName: "provider", WeaklyTypedInput: true}) decoder := structure.NewDecoder(structure.Option{TagName: "provider", WeaklyTypedInput: true})
schema := &proxyProviderSchema{ schema := &proxyProviderSchema{
@@ -69,35 +80,32 @@ func ParseProxyProvider(name string, mapping map[string]any) (P.ProxyProvider, e
} }
hc := NewHealthCheck([]C.Proxy{}, schema.HealthCheck.URL, uint(schema.HealthCheck.TestTimeout), hcInterval, schema.HealthCheck.Lazy, expectedStatus) hc := NewHealthCheck([]C.Proxy{}, schema.HealthCheck.URL, uint(schema.HealthCheck.TestTimeout), hcInterval, schema.HealthCheck.Lazy, expectedStatus)
parser, err := NewProxiesParser(name, schema.Filter, schema.ExcludeFilter, schema.ExcludeType, schema.DialerProxy, schema.Override) var vehicle types.Vehicle
if err != nil {
return nil, err
}
var vehicle P.Vehicle
switch schema.Type { switch schema.Type {
case "file": case "file":
path := C.Path.Resolve(schema.Path) path := C.Path.Resolve(schema.Path)
if !C.Path.IsSafePath(path) {
return nil, C.Path.ErrNotSafePath(path)
}
vehicle = resource.NewFileVehicle(path) vehicle = resource.NewFileVehicle(path)
case "http": case "http":
path := C.Path.GetPathByHash("proxies", schema.URL)
if schema.Path != "" { if schema.Path != "" {
path = C.Path.Resolve(schema.Path) path := C.Path.Resolve(schema.Path)
if !C.Path.IsSafePath(path) { if !features.CMFA && !C.Path.IsSafePath(path) {
return nil, C.Path.ErrNotSafePath(path) return nil, fmt.Errorf("%w: %s", errSubPath, path)
} }
vehicle = resource.NewHTTPVehicle(schema.URL, path)
} else {
path := C.Path.GetPathByHash("proxies", schema.URL)
vehicle = resource.NewHTTPVehicle(schema.URL, path)
} }
vehicle = resource.NewHTTPVehicle(schema.URL, path, schema.Proxy, schema.Header, resource.DefaultHttpTimeout, schema.SizeLimit)
case "inline":
return NewInlineProvider(name, schema.Payload, parser, hc)
default: default:
return nil, fmt.Errorf("%w: %s", errVehicleType, schema.Type) return nil, fmt.Errorf("%w: %s", errVehicleType, schema.Type)
} }
interval := time.Duration(uint(schema.Interval)) * time.Second interval := time.Duration(uint(schema.Interval)) * time.Second
filter := schema.Filter
excludeFilter := schema.ExcludeFilter
excludeType := schema.ExcludeType
dialerProxy := schema.DialerProxy
override := schema.Override
return NewProxySetProvider(name, interval, schema.Payload, parser, vehicle, hc) return NewProxySetProvider(name, interval, filter, excludeFilter, excludeType, dialerProxy, override, vehicle, hc)
} }

View File

@@ -0,0 +1,36 @@
//go:build android && cmfa
package provider
import (
"time"
)
var (
suspended bool
)
type UpdatableProvider interface {
UpdatedAt() time.Time
}
func (pp *proxySetProvider) UpdatedAt() time.Time {
return pp.Fetcher.UpdatedAt
}
func (pp *proxySetProvider) Close() error {
pp.healthCheck.close()
pp.Fetcher.Destroy()
return nil
}
func (cp *compatibleProvider) Close() error {
cp.healthCheck.close()
return nil
}
func Suspend(s bool) {
suspended = s
}

View File

@@ -1,26 +1,27 @@
package provider package provider
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http"
"runtime" "runtime"
"strings" "strings"
"sync"
"time" "time"
"github.com/metacubex/mihomo/adapter" "github.com/metacubex/mihomo/adapter"
"github.com/metacubex/mihomo/common/convert" "github.com/metacubex/mihomo/common/convert"
"github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/mihomo/common/yaml" mihomoHttp "github.com/metacubex/mihomo/component/http"
"github.com/metacubex/mihomo/component/profile/cachefile"
"github.com/metacubex/mihomo/component/resource" "github.com/metacubex/mihomo/component/resource"
C "github.com/metacubex/mihomo/constant" C "github.com/metacubex/mihomo/constant"
P "github.com/metacubex/mihomo/constant/provider" types "github.com/metacubex/mihomo/constant/provider"
"github.com/metacubex/mihomo/log"
"github.com/metacubex/mihomo/tunnel/statistic" "github.com/metacubex/mihomo/tunnel/statistic"
"github.com/dlclark/regexp2" "github.com/dlclark/regexp2"
"github.com/metacubex/http" "gopkg.in/yaml.v3"
) )
const ( const (
@@ -31,141 +32,128 @@ type ProxySchema struct {
Proxies []map[string]any `yaml:"proxies"` Proxies []map[string]any `yaml:"proxies"`
} }
type providerForApi struct {
Name string `json:"name"`
Type string `json:"type"`
VehicleType string `json:"vehicleType"`
Proxies []C.Proxy `json:"proxies"`
TestUrl string `json:"testUrl"`
ExpectedStatus string `json:"expectedStatus"`
UpdatedAt time.Time `json:"updatedAt,omitempty"`
SubscriptionInfo *SubscriptionInfo `json:"subscriptionInfo,omitempty"`
}
type baseProvider struct {
mutex sync.RWMutex
name string
proxies []C.Proxy
healthCheck *HealthCheck
version uint32
}
func (bp *baseProvider) Name() string {
return bp.name
}
func (bp *baseProvider) Version() uint32 {
bp.mutex.RLock()
defer bp.mutex.RUnlock()
return bp.version
}
func (bp *baseProvider) Initial() error {
if bp.healthCheck.auto() {
go bp.healthCheck.process()
}
return nil
}
func (bp *baseProvider) HealthCheck() {
bp.healthCheck.check()
}
func (bp *baseProvider) Type() P.ProviderType {
return P.Proxy
}
func (bp *baseProvider) Proxies() []C.Proxy {
bp.mutex.RLock()
defer bp.mutex.RUnlock()
return bp.proxies
}
func (bp *baseProvider) Count() int {
bp.mutex.RLock()
defer bp.mutex.RUnlock()
return len(bp.proxies)
}
func (bp *baseProvider) Touch() {
bp.healthCheck.touch()
}
func (bp *baseProvider) HealthCheckURL() string {
return bp.healthCheck.url
}
func (bp *baseProvider) RegisterHealthCheckTask(url string, expectedStatus utils.IntRanges[uint16], filter string, interval uint) {
bp.healthCheck.registerHealthCheckTask(url, expectedStatus, filter, interval)
}
func (bp *baseProvider) setProxies(proxies []C.Proxy) {
bp.mutex.Lock()
defer bp.mutex.Unlock()
bp.proxies = proxies
bp.version += 1
bp.healthCheck.setProxies(proxies)
if bp.healthCheck.auto() {
go bp.healthCheck.check()
}
}
func (bp *baseProvider) Close() error {
bp.healthCheck.close()
return nil
}
// ProxySetProvider for auto gc // ProxySetProvider for auto gc
type ProxySetProvider struct { type ProxySetProvider struct {
*proxySetProvider *proxySetProvider
} }
type proxySetProvider struct { type proxySetProvider struct {
baseProvider
*resource.Fetcher[[]C.Proxy] *resource.Fetcher[[]C.Proxy]
proxies []C.Proxy
healthCheck *HealthCheck
version uint32
subscriptionInfo *SubscriptionInfo subscriptionInfo *SubscriptionInfo
} }
func (pp *proxySetProvider) MarshalJSON() ([]byte, error) { func (pp *proxySetProvider) MarshalJSON() ([]byte, error) {
return json.Marshal(providerForApi{ return json.Marshal(map[string]any{
Name: pp.Name(), "name": pp.Name(),
Type: pp.Type().String(), "type": pp.Type().String(),
VehicleType: pp.VehicleType().String(), "vehicleType": pp.VehicleType().String(),
Proxies: pp.Proxies(), "proxies": pp.Proxies(),
TestUrl: pp.healthCheck.url, "testUrl": pp.healthCheck.url,
ExpectedStatus: pp.healthCheck.expectedStatus.String(), "expectedStatus": pp.healthCheck.expectedStatus.String(),
UpdatedAt: pp.UpdatedAt(), "updatedAt": pp.UpdatedAt,
SubscriptionInfo: pp.subscriptionInfo, "subscriptionInfo": pp.subscriptionInfo,
}) })
} }
func (pp *proxySetProvider) Version() uint32 {
return pp.version
}
func (pp *proxySetProvider) Name() string { func (pp *proxySetProvider) Name() string {
return pp.Fetcher.Name() return pp.Fetcher.Name()
} }
func (pp *proxySetProvider) HealthCheck() {
pp.healthCheck.check()
}
func (pp *proxySetProvider) Update() error { func (pp *proxySetProvider) Update() error {
_, _, err := pp.Fetcher.Update() elm, same, err := pp.Fetcher.Update()
if err == nil && !same {
pp.OnUpdate(elm)
}
return err return err
} }
func (pp *proxySetProvider) Initial() error { func (pp *proxySetProvider) Initial() error {
if err := pp.baseProvider.Initial(); err != nil { elm, err := pp.Fetcher.Initial()
return err
}
_, err := pp.Fetcher.Initial()
if err != nil { if err != nil {
return err return err
} }
if subscriptionInfo := cachefile.Cache().GetSubscriptionInfo(pp.Name()); subscriptionInfo != "" { pp.OnUpdate(elm)
pp.subscriptionInfo = NewSubscriptionInfo(subscriptionInfo) pp.getSubscriptionInfo()
}
pp.closeAllConnections() pp.closeAllConnections()
return nil return nil
} }
func (pp *proxySetProvider) Type() types.ProviderType {
return types.Proxy
}
func (pp *proxySetProvider) Proxies() []C.Proxy {
return pp.proxies
}
func (pp *proxySetProvider) Touch() {
pp.healthCheck.touch()
}
func (pp *proxySetProvider) HealthCheckURL() string {
return pp.healthCheck.url
}
func (pp *proxySetProvider) RegisterHealthCheckTask(url string, expectedStatus utils.IntRanges[uint16], filter string, interval uint) {
pp.healthCheck.registerHealthCheckTask(url, expectedStatus, filter, interval)
}
func (pp *proxySetProvider) setProxies(proxies []C.Proxy) {
pp.proxies = proxies
pp.healthCheck.setProxy(proxies)
if pp.healthCheck.auto() {
go pp.healthCheck.check()
}
}
func (pp *proxySetProvider) getSubscriptionInfo() {
if pp.VehicleType() != types.HTTP {
return
}
go func() {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*90)
defer cancel()
resp, err := mihomoHttp.HttpRequest(ctx, pp.Vehicle().(*resource.HTTPVehicle).Url(),
http.MethodGet, http.Header{"User-Agent": {C.UA}}, nil)
if err != nil {
return
}
defer resp.Body.Close()
userInfoStr := strings.TrimSpace(resp.Header.Get("subscription-userinfo"))
if userInfoStr == "" {
resp2, err := mihomoHttp.HttpRequest(ctx, pp.Vehicle().(*resource.HTTPVehicle).Url(),
http.MethodGet, http.Header{"User-Agent": {"Quantumultx"}}, nil)
if err != nil {
return
}
defer resp2.Body.Close()
userInfoStr = strings.TrimSpace(resp2.Header.Get("subscription-userinfo"))
if userInfoStr == "" {
return
}
}
pp.subscriptionInfo, err = NewSubscriptionInfo(userInfoStr)
if err != nil {
log.Warnln("[Provider] get subscription-userinfo: %e", err)
}
}()
}
func (pp *proxySetProvider) closeAllConnections() { func (pp *proxySetProvider) closeAllConnections() {
statistic.DefaultManager.Range(func(c statistic.Tracker) bool { statistic.DefaultManager.Range(func(c statistic.Tracker) bool {
for _, chain := range c.ProviderChains() { for _, chain := range c.Chains() {
if chain == pp.Name() { if chain == pp.Name() {
_ = c.Close() _ = c.Close()
break break
@@ -175,145 +163,118 @@ func (pp *proxySetProvider) closeAllConnections() {
}) })
} }
func (pp *proxySetProvider) Close() error { func stopProxyProvider(pd *ProxySetProvider) {
_ = pp.baseProvider.Close() pd.healthCheck.close()
return pp.Fetcher.Close() _ = pd.Fetcher.Destroy()
} }
func NewProxySetProvider(name string, interval time.Duration, payload []map[string]any, parser resource.Parser[[]C.Proxy], vehicle P.Vehicle, hc *HealthCheck) (*ProxySetProvider, error) { func NewProxySetProvider(name string, interval time.Duration, filter string, excludeFilter string, excludeType string, dialerProxy string, override OverrideSchema, vehicle types.Vehicle, hc *HealthCheck) (*ProxySetProvider, error) {
excludeFilterReg, err := regexp2.Compile(excludeFilter, 0)
if err != nil {
return nil, fmt.Errorf("invalid excludeFilter regex: %w", err)
}
var excludeTypeArray []string
if excludeType != "" {
excludeTypeArray = strings.Split(excludeType, "|")
}
var filterRegs []*regexp2.Regexp
for _, filter := range strings.Split(filter, "`") {
filterReg, err := regexp2.Compile(filter, 0)
if err != nil {
return nil, fmt.Errorf("invalid filter regex: %w", err)
}
filterRegs = append(filterRegs, filterReg)
}
if hc.auto() {
go hc.process()
}
pd := &proxySetProvider{ pd := &proxySetProvider{
baseProvider: baseProvider{ proxies: []C.Proxy{},
name: name, healthCheck: hc,
proxies: []C.Proxy{},
healthCheck: hc,
},
} }
if len(payload) > 0 { // using as fallback proxies fetcher := resource.NewFetcher[[]C.Proxy](name, interval, vehicle, proxiesParseAndFilter(filter, excludeFilter, excludeTypeArray, filterRegs, excludeFilterReg, dialerProxy, override), proxiesOnUpdate(pd))
ps := ProxySchema{Proxies: payload}
buf, err := yaml.Marshal(ps)
if err != nil {
return nil, err
}
proxies, err := parser(buf)
if err != nil {
return nil, err
}
pd.proxies = proxies
// direct call setProxies on hc to avoid starting a health check process immediately, it should be done by Initial()
hc.setProxies(proxies)
}
fetcher := resource.NewFetcher[[]C.Proxy](name, interval, vehicle, parser, pd.setProxies)
pd.Fetcher = fetcher pd.Fetcher = fetcher
if httpVehicle, ok := vehicle.(*resource.HTTPVehicle); ok {
httpVehicle.SetInRead(func(resp *http.Response) {
if subscriptionInfo := resp.Header.Get("subscription-userinfo"); subscriptionInfo != "" {
cachefile.Cache().SetSubscriptionInfo(name, subscriptionInfo)
pd.subscriptionInfo = NewSubscriptionInfo(subscriptionInfo)
}
})
}
wrapper := &ProxySetProvider{pd} wrapper := &ProxySetProvider{pd}
runtime.SetFinalizer(wrapper, (*ProxySetProvider).Close) runtime.SetFinalizer(wrapper, stopProxyProvider)
return wrapper, nil return wrapper, nil
} }
func (pp *ProxySetProvider) Close() error {
runtime.SetFinalizer(pp, nil)
return pp.proxySetProvider.Close()
}
// InlineProvider for auto gc
type InlineProvider struct {
*inlineProvider
}
type inlineProvider struct {
baseProvider
updateAt time.Time
}
func (ip *inlineProvider) MarshalJSON() ([]byte, error) {
return json.Marshal(providerForApi{
Name: ip.Name(),
Type: ip.Type().String(),
VehicleType: ip.VehicleType().String(),
Proxies: ip.Proxies(),
TestUrl: ip.healthCheck.url,
ExpectedStatus: ip.healthCheck.expectedStatus.String(),
UpdatedAt: ip.updateAt,
})
}
func (ip *inlineProvider) VehicleType() P.VehicleType {
return P.Inline
}
func (ip *inlineProvider) Update() error {
// make api update happy
ip.updateAt = time.Now()
return nil
}
func NewInlineProvider(name string, payload []map[string]any, parser resource.Parser[[]C.Proxy], hc *HealthCheck) (*InlineProvider, error) {
ps := ProxySchema{Proxies: payload}
buf, err := yaml.Marshal(ps)
if err != nil {
return nil, err
}
proxies, err := parser(buf)
if err != nil {
return nil, err
}
// direct call setProxies on hc to avoid starting a health check process immediately, it should be done by Initial()
hc.setProxies(proxies)
ip := &inlineProvider{
baseProvider: baseProvider{
name: name,
proxies: proxies,
healthCheck: hc,
},
updateAt: time.Now(),
}
wrapper := &InlineProvider{ip}
runtime.SetFinalizer(wrapper, (*InlineProvider).Close)
return wrapper, nil
}
func (ip *InlineProvider) Close() error {
runtime.SetFinalizer(ip, nil)
return ip.baseProvider.Close()
}
// CompatibleProvider for auto gc // CompatibleProvider for auto gc
type CompatibleProvider struct { type CompatibleProvider struct {
*compatibleProvider *compatibleProvider
} }
type compatibleProvider struct { type compatibleProvider struct {
baseProvider name string
healthCheck *HealthCheck
proxies []C.Proxy
version uint32
} }
func (cp *compatibleProvider) MarshalJSON() ([]byte, error) { func (cp *compatibleProvider) MarshalJSON() ([]byte, error) {
return json.Marshal(providerForApi{ return json.Marshal(map[string]any{
Name: cp.Name(), "name": cp.Name(),
Type: cp.Type().String(), "type": cp.Type().String(),
VehicleType: cp.VehicleType().String(), "vehicleType": cp.VehicleType().String(),
Proxies: cp.Proxies(), "proxies": cp.Proxies(),
TestUrl: cp.healthCheck.url, "testUrl": cp.healthCheck.url,
ExpectedStatus: cp.healthCheck.expectedStatus.String(), "expectedStatus": cp.healthCheck.expectedStatus.String(),
}) })
} }
func (cp *compatibleProvider) Version() uint32 {
return cp.version
}
func (cp *compatibleProvider) Name() string {
return cp.name
}
func (cp *compatibleProvider) HealthCheck() {
cp.healthCheck.check()
}
func (cp *compatibleProvider) Update() error { func (cp *compatibleProvider) Update() error {
return nil return nil
} }
func (cp *compatibleProvider) VehicleType() P.VehicleType { func (cp *compatibleProvider) Initial() error {
return P.Compatible if cp.healthCheck.interval != 0 && cp.healthCheck.url != "" {
cp.HealthCheck()
}
return nil
}
func (cp *compatibleProvider) VehicleType() types.VehicleType {
return types.Compatible
}
func (cp *compatibleProvider) Type() types.ProviderType {
return types.Proxy
}
func (cp *compatibleProvider) Proxies() []C.Proxy {
return cp.proxies
}
func (cp *compatibleProvider) Touch() {
cp.healthCheck.touch()
}
func (cp *compatibleProvider) HealthCheckURL() string {
return cp.healthCheck.url
}
func (cp *compatibleProvider) RegisterHealthCheckTask(url string, expectedStatus utils.IntRanges[uint16], filter string, interval uint) {
cp.healthCheck.registerHealthCheckTask(url, expectedStatus, filter, interval)
}
func stopCompatibleProvider(pd *CompatibleProvider) {
pd.healthCheck.close()
} }
func NewCompatibleProvider(name string, proxies []C.Proxy, hc *HealthCheck) (*CompatibleProvider, error) { func NewCompatibleProvider(name string, proxies []C.Proxy, hc *HealthCheck) (*CompatibleProvider, error) {
@@ -321,50 +282,30 @@ func NewCompatibleProvider(name string, proxies []C.Proxy, hc *HealthCheck) (*Co
return nil, errors.New("provider need one proxy at least") return nil, errors.New("provider need one proxy at least")
} }
if hc.auto() {
go hc.process()
}
pd := &compatibleProvider{ pd := &compatibleProvider{
baseProvider: baseProvider{ name: name,
name: name, proxies: proxies,
proxies: proxies, healthCheck: hc,
healthCheck: hc,
},
} }
wrapper := &CompatibleProvider{pd} wrapper := &CompatibleProvider{pd}
runtime.SetFinalizer(wrapper, (*CompatibleProvider).Close) runtime.SetFinalizer(wrapper, stopCompatibleProvider)
return wrapper, nil return wrapper, nil
} }
func (cp *CompatibleProvider) Close() error { func proxiesOnUpdate(pd *proxySetProvider) func([]C.Proxy) {
runtime.SetFinalizer(cp, nil) return func(elm []C.Proxy) {
return cp.compatibleProvider.Close() pd.setProxies(elm)
pd.version += 1
pd.getSubscriptionInfo()
}
} }
func NewProxiesParser(pdName string, filter string, excludeFilter string, excludeType string, dialerProxy string, override overrideSchema) (resource.Parser[[]C.Proxy], error) { func proxiesParseAndFilter(filter string, excludeFilter string, excludeTypeArray []string, filterRegs []*regexp2.Regexp, excludeFilterReg *regexp2.Regexp, dialerProxy string, override OverrideSchema) resource.Parser[[]C.Proxy] {
var excludeTypeArray []string
if excludeType != "" {
excludeTypeArray = strings.Split(excludeType, "|")
}
var excludeFilterRegs []*regexp2.Regexp
if excludeFilter != "" {
for _, excludeFilter := range strings.Split(excludeFilter, "`") {
excludeFilterReg, err := regexp2.Compile(excludeFilter, regexp2.None)
if err != nil {
return nil, fmt.Errorf("invalid excludeFilter regex: %w", err)
}
excludeFilterRegs = append(excludeFilterRegs, excludeFilterReg)
}
}
var filterRegs []*regexp2.Regexp
for _, filter := range strings.Split(filter, "`") {
filterReg, err := regexp2.Compile(filter, regexp2.None)
if err != nil {
return nil, fmt.Errorf("invalid filter regex: %w", err)
}
filterRegs = append(filterRegs, filterReg)
}
return func(buf []byte) ([]C.Proxy, error) { return func(buf []byte) ([]C.Proxy, error) {
schema := &ProxySchema{} schema := &ProxySchema{}
@@ -383,9 +324,8 @@ func NewProxiesParser(pdName string, filter string, excludeFilter string, exclud
proxies := []C.Proxy{} proxies := []C.Proxy{}
proxiesSet := map[string]struct{}{} proxiesSet := map[string]struct{}{}
for _, filterReg := range filterRegs { for _, filterReg := range filterRegs {
LOOP1:
for idx, mapping := range schema.Proxies { for idx, mapping := range schema.Proxies {
if len(excludeTypeArray) > 0 { if nil != excludeTypeArray && len(excludeTypeArray) > 0 {
mType, ok := mapping["type"] mType, ok := mapping["type"]
if !ok { if !ok {
continue continue
@@ -394,11 +334,18 @@ func NewProxiesParser(pdName string, filter string, excludeFilter string, exclud
if !ok { if !ok {
continue continue
} }
for _, excludeType := range excludeTypeArray { flag := false
if strings.EqualFold(pType, excludeType) { for i := range excludeTypeArray {
continue LOOP1 if strings.EqualFold(pType, excludeTypeArray[i]) {
flag = true
break
} }
} }
if flag {
continue
}
} }
mName, ok := mapping["name"] mName, ok := mapping["name"]
if !ok { if !ok {
@@ -408,15 +355,13 @@ func NewProxiesParser(pdName string, filter string, excludeFilter string, exclud
if !ok { if !ok {
continue continue
} }
if len(excludeFilterRegs) > 0 { if len(excludeFilter) > 0 {
for _, excludeFilterReg := range excludeFilterRegs { if mat, _ := excludeFilterReg.FindStringMatch(name); mat != nil {
if mat, _ := excludeFilterReg.MatchString(name); mat { continue
continue LOOP1
}
} }
} }
if len(filter) > 0 { if len(filter) > 0 {
if mat, _ := filterReg.MatchString(name); !mat { if mat, _ := filterReg.FindStringMatch(name); mat == nil {
continue continue
} }
} }
@@ -428,12 +373,40 @@ func NewProxiesParser(pdName string, filter string, excludeFilter string, exclud
mapping["dialer-proxy"] = dialerProxy mapping["dialer-proxy"] = dialerProxy
} }
err := override.Apply(mapping) if override.UDP != nil {
if err != nil { mapping["udp"] = *override.UDP
return nil, fmt.Errorf("proxy %d override error: %w", idx, err) }
if override.Up != nil {
mapping["up"] = *override.Up
}
if override.Down != nil {
mapping["down"] = *override.Down
}
if override.DialerProxy != nil {
mapping["dialer-proxy"] = *override.DialerProxy
}
if override.SkipCertVerify != nil {
mapping["skip-cert-verify"] = *override.SkipCertVerify
}
if override.Interface != nil {
mapping["interface-name"] = *override.Interface
}
if override.RoutingMark != nil {
mapping["routing-mark"] = *override.RoutingMark
}
if override.IPVersion != nil {
mapping["ip-version"] = *override.IPVersion
}
if override.AdditionalPrefix != nil {
name := mapping["name"].(string)
mapping["name"] = *override.AdditionalPrefix + name
}
if override.AdditionalSuffix != nil {
name := mapping["name"].(string)
mapping["name"] = name + *override.AdditionalSuffix
} }
proxy, err := adapter.ParseProxy(mapping, adapter.WithProviderName(pdName)) proxy, err := adapter.ParseProxy(mapping)
if err != nil { if err != nil {
return nil, fmt.Errorf("proxy %d error: %w", idx, err) return nil, fmt.Errorf("proxy %d error: %w", idx, err)
} }
@@ -451,5 +424,5 @@ func NewProxiesParser(pdName string, filter string, excludeFilter string, exclud
} }
return proxies, nil return proxies, nil
}, nil }
} }

View File

@@ -1,11 +1,8 @@
package provider package provider
import ( import (
"fmt"
"strconv" "strconv"
"strings" "strings"
"github.com/metacubex/mihomo/log"
) )
type SubscriptionInfo struct { type SubscriptionInfo struct {
@@ -15,44 +12,28 @@ type SubscriptionInfo struct {
Expire int64 Expire int64
} }
func NewSubscriptionInfo(userinfo string) (si *SubscriptionInfo) { func NewSubscriptionInfo(userinfo string) (si *SubscriptionInfo, err error) {
userinfo = strings.ReplaceAll(strings.ToLower(userinfo), " ", "") userinfo = strings.ToLower(userinfo)
userinfo = strings.ReplaceAll(userinfo, " ", "")
si = new(SubscriptionInfo) si = new(SubscriptionInfo)
for _, field := range strings.Split(userinfo, ";") { for _, field := range strings.Split(userinfo, ";") {
name, value, ok := strings.Cut(field, "=") switch name, value, _ := strings.Cut(field, "="); name {
if !ok {
continue
}
intValue, err := parseValue(value)
if err != nil {
log.Warnln("[Provider] get subscription-userinfo: %e", err)
continue
}
switch name {
case "upload": case "upload":
si.Upload = intValue si.Upload, err = strconv.ParseInt(value, 10, 64)
case "download": case "download":
si.Download = intValue si.Download, err = strconv.ParseInt(value, 10, 64)
case "total": case "total":
si.Total = intValue si.Total, err = strconv.ParseInt(value, 10, 64)
case "expire": case "expire":
si.Expire = intValue if value == "" {
si.Expire = 0
} else {
si.Expire, err = strconv.ParseInt(value, 10, 64)
}
}
if err != nil {
return
} }
} }
return si return
}
func parseValue(value string) (int64, error) {
if intValue, err := strconv.ParseInt(value, 10, 64); err == nil {
return intValue, nil
}
if floatValue, err := strconv.ParseFloat(value, 64); err == nil {
return int64(floatValue), nil
}
return 0, fmt.Errorf("failed to parse value '%s'", value)
} }

View File

@@ -1,8 +1,6 @@
// Copyright 2014 The Go Authors. All rights reserved. // Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build android && cgo
// +build android,cgo
// kanged from https://github.com/golang/mobile/blob/c713f31d574bb632a93f169b2cc99c9e753fef0e/app/android.go#L89 // kanged from https://github.com/golang/mobile/blob/c713f31d574bb632a93f169b2cc99c9e753fef0e/app/android.go#L89

View File

@@ -33,8 +33,15 @@ type ARC[K comparable, V any] struct {
// New returns a new Adaptive Replacement Cache (ARC). // New returns a new Adaptive Replacement Cache (ARC).
func New[K comparable, V any](options ...Option[K, V]) *ARC[K, V] { func New[K comparable, V any](options ...Option[K, V]) *ARC[K, V] {
arc := &ARC[K, V]{} arc := &ARC[K, V]{
arc.Clear() p: 0,
t1: list.New[*entry[K, V]](),
b1: list.New[*entry[K, V]](),
t2: list.New[*entry[K, V]](),
b2: list.New[*entry[K, V]](),
len: 0,
cache: make(map[K]*entry[K, V]),
}
for _, option := range options { for _, option := range options {
option(arc) option(arc)
@@ -42,19 +49,6 @@ func New[K comparable, V any](options ...Option[K, V]) *ARC[K, V] {
return arc return arc
} }
func (a *ARC[K, V]) Clear() {
a.mutex.Lock()
defer a.mutex.Unlock()
a.p = 0
a.t1 = list.New[*entry[K, V]]()
a.b1 = list.New[*entry[K, V]]()
a.t2 = list.New[*entry[K, V]]()
a.b2 = list.New[*entry[K, V]]()
a.len = 0
a.cache = make(map[K]*entry[K, V])
}
// Set inserts a new key-value pair into the cache. // Set inserts a new key-value pair into the cache.
// This optimizes future access to this entry (side effect). // This optimizes future access to this entry (side effect).
func (a *ARC[K, V]) Set(key K, value V) { func (a *ARC[K, V]) Set(key K, value V) {

View File

@@ -1,63 +0,0 @@
package atomic
import (
"encoding/json"
"fmt"
"sync/atomic"
)
type Int32Enum[T ~int32] struct {
value atomic.Int32
}
func (i *Int32Enum[T]) MarshalJSON() ([]byte, error) {
return json.Marshal(i.Load())
}
func (i *Int32Enum[T]) UnmarshalJSON(b []byte) error {
var v T
if err := json.Unmarshal(b, &v); err != nil {
return err
}
i.Store(v)
return nil
}
func (i *Int32Enum[T]) MarshalYAML() (any, error) {
return i.Load(), nil
}
func (i *Int32Enum[T]) UnmarshalYAML(unmarshal func(any) error) error {
var v T
if err := unmarshal(&v); err != nil {
return err
}
i.Store(v)
return nil
}
func (i *Int32Enum[T]) String() string {
return fmt.Sprint(i.Load())
}
func (i *Int32Enum[T]) Store(v T) {
i.value.Store(int32(v))
}
func (i *Int32Enum[T]) Load() T {
return T(i.value.Load())
}
func (i *Int32Enum[T]) Swap(new T) T {
return T(i.value.Swap(int32(new)))
}
func (i *Int32Enum[T]) CompareAndSwap(old, new T) bool {
return i.value.CompareAndSwap(int32(old), int32(new))
}
func NewInt32Enum[T ~int32](v T) *Int32Enum[T] {
a := &Int32Enum[T]{}
a.Store(v)
return a
}

View File

@@ -29,19 +29,6 @@ func (i *Bool) UnmarshalJSON(b []byte) error {
return nil return nil
} }
func (i *Bool) MarshalYAML() (any, error) {
return i.Load(), nil
}
func (i *Bool) UnmarshalYAML(unmarshal func(any) error) error {
var v bool
if err := unmarshal(&v); err != nil {
return err
}
i.Store(v)
return nil
}
func (i *Bool) String() string { func (i *Bool) String() string {
v := i.Load() v := i.Load()
return strconv.FormatBool(v) return strconv.FormatBool(v)
@@ -71,19 +58,6 @@ func (p *Pointer[T]) UnmarshalJSON(b []byte) error {
return nil return nil
} }
func (p *Pointer[T]) MarshalYAML() (any, error) {
return p.Load(), nil
}
func (p *Pointer[T]) UnmarshalYAML(unmarshal func(any) error) error {
var v *T
if err := unmarshal(&v); err != nil {
return err
}
p.Store(v)
return nil
}
func (p *Pointer[T]) String() string { func (p *Pointer[T]) String() string {
return fmt.Sprint(p.Load()) return fmt.Sprint(p.Load())
} }
@@ -110,19 +84,6 @@ func (i *Int32) UnmarshalJSON(b []byte) error {
return nil return nil
} }
func (i *Int32) MarshalYAML() (any, error) {
return i.Load(), nil
}
func (i *Int32) UnmarshalYAML(unmarshal func(any) error) error {
var v int32
if err := unmarshal(&v); err != nil {
return err
}
i.Store(v)
return nil
}
func (i *Int32) String() string { func (i *Int32) String() string {
v := i.Load() v := i.Load()
return strconv.FormatInt(int64(v), 10) return strconv.FormatInt(int64(v), 10)
@@ -150,19 +111,6 @@ func (i *Int64) UnmarshalJSON(b []byte) error {
return nil return nil
} }
func (i *Int64) MarshalYAML() (any, error) {
return i.Load(), nil
}
func (i *Int64) UnmarshalYAML(unmarshal func(any) error) error {
var v int64
if err := unmarshal(&v); err != nil {
return err
}
i.Store(v)
return nil
}
func (i *Int64) String() string { func (i *Int64) String() string {
v := i.Load() v := i.Load()
return strconv.FormatInt(int64(v), 10) return strconv.FormatInt(int64(v), 10)
@@ -190,19 +138,6 @@ func (i *Uint32) UnmarshalJSON(b []byte) error {
return nil return nil
} }
func (i *Uint32) MarshalYAML() (any, error) {
return i.Load(), nil
}
func (i *Uint32) UnmarshalYAML(unmarshal func(any) error) error {
var v uint32
if err := unmarshal(&v); err != nil {
return err
}
i.Store(v)
return nil
}
func (i *Uint32) String() string { func (i *Uint32) String() string {
v := i.Load() v := i.Load()
return strconv.FormatUint(uint64(v), 10) return strconv.FormatUint(uint64(v), 10)
@@ -230,19 +165,6 @@ func (i *Uint64) UnmarshalJSON(b []byte) error {
return nil return nil
} }
func (i *Uint64) MarshalYAML() (any, error) {
return i.Load(), nil
}
func (i *Uint64) UnmarshalYAML(unmarshal func(any) error) error {
var v uint64
if err := unmarshal(&v); err != nil {
return err
}
i.Store(v)
return nil
}
func (i *Uint64) String() string { func (i *Uint64) String() string {
v := i.Load() v := i.Load()
return strconv.FormatUint(uint64(v), 10) return strconv.FormatUint(uint64(v), 10)
@@ -270,19 +192,6 @@ func (i *Uintptr) UnmarshalJSON(b []byte) error {
return nil return nil
} }
func (i *Uintptr) MarshalYAML() (any, error) {
return i.Load(), nil
}
func (i *Uintptr) UnmarshalYAML(unmarshal func(any) error) error {
var v uintptr
if err := unmarshal(&v); err != nil {
return err
}
i.Store(v)
return nil
}
func (i *Uintptr) String() string { func (i *Uintptr) String() string {
v := i.Load() v := i.Load()
return strconv.FormatUint(uint64(v), 10) return strconv.FormatUint(uint64(v), 10)

View File

@@ -5,50 +5,49 @@ import (
"sync/atomic" "sync/atomic"
) )
func DefaultValue[T any]() T {
var defaultValue T
return defaultValue
}
type TypedValue[T any] struct { type TypedValue[T any] struct {
value atomic.Pointer[T] _ noCopy
value atomic.Value
} }
func (t *TypedValue[T]) Load() (v T) { // tValue is a struct with determined type to resolve atomic.Value usages with interface types
v, _ = t.LoadOk() // https://github.com/golang/go/issues/22550
return //
// The intention to have an atomic value store for errors. However, running this code panics:
// panic: sync/atomic: store of inconsistently typed value into Value
// This is because atomic.Value requires that the underlying concrete type be the same (which is a reasonable expectation for its implementation).
// When going through the atomic.Value.Store method call, the fact that both these are of the error interface is lost.
type tValue[T any] struct {
value T
} }
func (t *TypedValue[T]) LoadOk() (v T, ok bool) { func (t *TypedValue[T]) Load() T {
value := t.value.Load() value := t.value.Load()
if value == nil { if value == nil {
return return DefaultValue[T]()
} }
return *value, true return value.(tValue[T]).value
} }
func (t *TypedValue[T]) Store(value T) { func (t *TypedValue[T]) Store(value T) {
t.value.Store(&value) t.value.Store(tValue[T]{value})
} }
func (t *TypedValue[T]) Swap(new T) (v T) { func (t *TypedValue[T]) Swap(new T) T {
old := t.value.Swap(&new) old := t.value.Swap(tValue[T]{new})
if old == nil { if old == nil {
return return DefaultValue[T]()
} }
return *old return old.(tValue[T]).value
} }
func (t *TypedValue[T]) CompareAndSwap(old, new T) bool { func (t *TypedValue[T]) CompareAndSwap(old, new T) bool {
for { return t.value.CompareAndSwap(tValue[T]{old}, tValue[T]{new})
currentP := t.value.Load()
var currentValue T
if currentP != nil {
currentValue = *currentP
}
// Compare old and current via runtime equality check.
if any(currentValue) != any(old) {
return false
}
if t.value.CompareAndSwap(currentP, &new) {
return true
}
}
} }
func (t *TypedValue[T]) MarshalJSON() ([]byte, error) { func (t *TypedValue[T]) MarshalJSON() ([]byte, error) {
@@ -64,20 +63,13 @@ func (t *TypedValue[T]) UnmarshalJSON(b []byte) error {
return nil return nil
} }
func (t *TypedValue[T]) MarshalYAML() (any, error) {
return t.Load(), nil
}
func (t *TypedValue[T]) UnmarshalYAML(unmarshal func(any) error) error {
var v T
if err := unmarshal(&v); err != nil {
return err
}
t.Store(v)
return nil
}
func NewTypedValue[T any](t T) (v TypedValue[T]) { func NewTypedValue[T any](t T) (v TypedValue[T]) {
v.Store(t) v.Store(t)
return return
} }
type noCopy struct{}
// Lock is a no-op used by -copylocks checker from `go vet`.
func (*noCopy) Lock() {}
func (*noCopy) Unlock() {}

View File

@@ -1,169 +0,0 @@
package atomic
import (
"io"
"os"
"testing"
)
func TestTypedValue(t *testing.T) {
{
var v TypedValue[int]
got, gotOk := v.LoadOk()
if got != 0 || gotOk {
t.Fatalf("LoadOk = (%v, %v), want (0, false)", got, gotOk)
}
v.Store(1)
got, gotOk = v.LoadOk()
if got != 1 || !gotOk {
t.Fatalf("LoadOk = (%v, %v), want (1, true)", got, gotOk)
}
}
{
var v TypedValue[error]
got, gotOk := v.LoadOk()
if got != nil || gotOk {
t.Fatalf("LoadOk = (%v, %v), want (nil, false)", got, gotOk)
}
v.Store(io.EOF)
got, gotOk = v.LoadOk()
if got != io.EOF || !gotOk {
t.Fatalf("LoadOk = (%v, %v), want (EOF, true)", got, gotOk)
}
err := &os.PathError{}
v.Store(err)
got, gotOk = v.LoadOk()
if got != err || !gotOk {
t.Fatalf("LoadOk = (%v, %v), want (%v, true)", got, gotOk, err)
}
v.Store(nil)
got, gotOk = v.LoadOk()
if got != nil || !gotOk {
t.Fatalf("LoadOk = (%v, %v), want (nil, true)", got, gotOk)
}
}
{
e1, e2, e3 := io.EOF, &os.PathError{}, &os.PathError{}
var v TypedValue[error]
if v.CompareAndSwap(e1, e2) != false {
t.Fatalf("CompareAndSwap = true, want false")
}
if value := v.Load(); value != nil {
t.Fatalf("Load = (%v), want (%v)", value, nil)
}
if v.CompareAndSwap(nil, e1) != true {
t.Fatalf("CompareAndSwap = false, want true")
}
if value := v.Load(); value != e1 {
t.Fatalf("Load = (%v), want (%v)", value, e1)
}
if v.CompareAndSwap(e2, e3) != false {
t.Fatalf("CompareAndSwap = true, want false")
}
if value := v.Load(); value != e1 {
t.Fatalf("Load = (%v), want (%v)", value, e1)
}
if v.CompareAndSwap(e1, e2) != true {
t.Fatalf("CompareAndSwap = false, want true")
}
if value := v.Load(); value != e2 {
t.Fatalf("Load = (%v), want (%v)", value, e2)
}
if v.CompareAndSwap(e3, e2) != false {
t.Fatalf("CompareAndSwap = true, want false")
}
if value := v.Load(); value != e2 {
t.Fatalf("Load = (%v), want (%v)", value, e2)
}
if v.CompareAndSwap(nil, e3) != false {
t.Fatalf("CompareAndSwap = true, want false")
}
if value := v.Load(); value != e2 {
t.Fatalf("Load = (%v), want (%v)", value, e2)
}
}
{
c1, c2, c3 := make(chan struct{}), make(chan struct{}), make(chan struct{})
var v TypedValue[chan struct{}]
if v.CompareAndSwap(c1, c2) != false {
t.Fatalf("CompareAndSwap = true, want false")
}
if value := v.Load(); value != nil {
t.Fatalf("Load = (%v), want (%v)", value, nil)
}
if v.CompareAndSwap(nil, c1) != true {
t.Fatalf("CompareAndSwap = false, want true")
}
if value := v.Load(); value != c1 {
t.Fatalf("Load = (%v), want (%v)", value, c1)
}
if v.CompareAndSwap(c2, c3) != false {
t.Fatalf("CompareAndSwap = true, want false")
}
if value := v.Load(); value != c1 {
t.Fatalf("Load = (%v), want (%v)", value, c1)
}
if v.CompareAndSwap(c1, c2) != true {
t.Fatalf("CompareAndSwap = false, want true")
}
if value := v.Load(); value != c2 {
t.Fatalf("Load = (%v), want (%v)", value, c2)
}
if v.CompareAndSwap(c3, c2) != false {
t.Fatalf("CompareAndSwap = true, want false")
}
if value := v.Load(); value != c2 {
t.Fatalf("Load = (%v), want (%v)", value, c2)
}
if v.CompareAndSwap(nil, c3) != false {
t.Fatalf("CompareAndSwap = true, want false")
}
if value := v.Load(); value != c2 {
t.Fatalf("Load = (%v), want (%v)", value, c2)
}
}
{
c1, c2, c3 := &io.LimitedReader{}, &io.SectionReader{}, &io.SectionReader{}
var v TypedValue[io.Reader]
if v.CompareAndSwap(c1, c2) != false {
t.Fatalf("CompareAndSwap = true, want false")
}
if value := v.Load(); value != nil {
t.Fatalf("Load = (%v), want (%v)", value, nil)
}
if v.CompareAndSwap(nil, c1) != true {
t.Fatalf("CompareAndSwap = false, want true")
}
if value := v.Load(); value != c1 {
t.Fatalf("Load = (%v), want (%v)", value, c1)
}
if v.CompareAndSwap(c2, c3) != false {
t.Fatalf("CompareAndSwap = true, want false")
}
if value := v.Load(); value != c1 {
t.Fatalf("Load = (%v), want (%v)", value, c1)
}
if v.CompareAndSwap(c1, c2) != true {
t.Fatalf("CompareAndSwap = false, want true")
}
if value := v.Load(); value != c2 {
t.Fatalf("Load = (%v), want (%v)", value, c2)
}
if v.CompareAndSwap(c3, c2) != false {
t.Fatalf("CompareAndSwap = true, want false")
}
if value := v.Load(); value != c2 {
t.Fatalf("Load = (%v), want (%v)", value, c2)
}
if v.CompareAndSwap(nil, c3) != false {
t.Fatalf("CompareAndSwap = true, want false")
}
if value := v.Load(); value != c2 {
t.Fatalf("Load = (%v), want (%v)", value, c2)
}
}
}

View File

@@ -1,8 +1,8 @@
package buf package buf
import ( import (
"github.com/metacubex/sing/common" "github.com/sagernet/sing/common"
"github.com/metacubex/sing/common/buf" "github.com/sagernet/sing/common/buf"
) )
const BufferSize = buf.BufferSize const BufferSize = buf.BufferSize
@@ -14,7 +14,6 @@ var NewPacket = buf.NewPacket
var NewSize = buf.NewSize var NewSize = buf.NewSize
var With = buf.With var With = buf.With
var As = buf.As var As = buf.As
var ReleaseMulti = buf.ReleaseMulti
var ( var (
Must = common.Must Must = common.Must

View File

@@ -1,31 +0,0 @@
package contextutils
import (
"context"
"sync"
)
func afterFunc(ctx context.Context, f func()) (stop func() bool) {
stopc := make(chan struct{})
once := sync.Once{} // either starts running f or stops f from running
if ctx.Done() != nil {
go func() {
select {
case <-ctx.Done():
once.Do(func() {
go f()
})
case <-stopc:
}
}()
}
return func() bool {
stopped := false
once.Do(func() {
stopped = true
close(stopc)
})
return stopped
}
}

View File

@@ -1,11 +0,0 @@
//go:build !go1.21
package contextutils
import (
"context"
)
func AfterFunc(ctx context.Context, f func()) (stop func() bool) {
return afterFunc(ctx, f)
}

View File

@@ -1,9 +0,0 @@
//go:build go1.21
package contextutils
import "context"
func AfterFunc(ctx context.Context, f func()) (stop func() bool) {
return context.AfterFunc(ctx, f)
}

View File

@@ -1,100 +0,0 @@
package contextutils
import (
"context"
"testing"
"time"
)
const (
shortDuration = 1 * time.Millisecond // a reasonable duration to block in a test
veryLongDuration = 1000 * time.Hour // an arbitrary upper bound on the test's running time
)
func TestAfterFuncCalledAfterCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
donec := make(chan struct{})
stop := afterFunc(ctx, func() {
close(donec)
})
select {
case <-donec:
t.Fatalf("AfterFunc called before context is done")
case <-time.After(shortDuration):
}
cancel()
select {
case <-donec:
case <-time.After(veryLongDuration):
t.Fatalf("AfterFunc not called after context is canceled")
}
if stop() {
t.Fatalf("stop() = true, want false")
}
}
func TestAfterFuncCalledAfterTimeout(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), shortDuration)
defer cancel()
donec := make(chan struct{})
afterFunc(ctx, func() {
close(donec)
})
select {
case <-donec:
case <-time.After(veryLongDuration):
t.Fatalf("AfterFunc not called after context is canceled")
}
}
func TestAfterFuncCalledImmediately(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
donec := make(chan struct{})
afterFunc(ctx, func() {
close(donec)
})
select {
case <-donec:
case <-time.After(veryLongDuration):
t.Fatalf("AfterFunc not called for already-canceled context")
}
}
func TestAfterFuncNotCalledAfterStop(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
donec := make(chan struct{})
stop := afterFunc(ctx, func() {
close(donec)
})
if !stop() {
t.Fatalf("stop() = false, want true")
}
cancel()
select {
case <-donec:
t.Fatalf("AfterFunc called for already-canceled context")
case <-time.After(shortDuration):
}
if stop() {
t.Fatalf("stop() = true, want false")
}
}
// This test verifies that canceling a context does not block waiting for AfterFuncs to finish.
func TestAfterFuncCalledAsynchronously(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
donec := make(chan struct{})
stop := afterFunc(ctx, func() {
// The channel send blocks until donec is read from.
donec <- struct{}{}
})
defer stop()
cancel()
// After cancel returns, read from donec and unblock the AfterFunc.
select {
case <-donec:
case <-time.After(veryLongDuration):
t.Fatalf("AfterFunc not called after context is canceled")
}
}

View File

@@ -2,7 +2,6 @@ package convert
import ( import (
"encoding/base64" "encoding/base64"
"fmt"
"strings" "strings"
) )
@@ -44,22 +43,3 @@ func decodeUrlSafe(data string) string {
} }
return string(dcBuf) return string(dcBuf)
} }
func TryDecodeBase64(s string) (decoded []byte, err error) {
if len(s)%4 == 0 {
if decoded, err = base64.StdEncoding.DecodeString(s); err == nil {
return
}
if decoded, err = base64.URLEncoding.DecodeString(s); err == nil {
return
}
} else {
if decoded, err = base64.RawStdEncoding.DecodeString(s); err == nil {
return
}
if decoded, err = base64.RawURLEncoding.DecodeString(s); err == nil {
return
}
}
return nil, fmt.Errorf("invalid base64-encoded string")
}

View File

@@ -201,10 +201,6 @@ func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
trojan["client-fingerprint"] = fingerprint trojan["client-fingerprint"] = fingerprint
} }
if pcs := query.Get("pcs"); pcs != "" {
trojan["fingerprint"] = pcs
}
proxies = append(proxies, trojan) proxies = append(proxies, trojan)
case "vless": case "vless":
@@ -212,9 +208,6 @@ func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
if err != nil { if err != nil {
continue continue
} }
if decodedHost, err := tryDecodeBase64([]byte(urlVLess.Host)); err == nil {
urlVLess.Host = string(decodedHost)
}
query := urlVLess.Query() query := urlVLess.Query()
vless := make(map[string]any, 20) vless := make(map[string]any, 20)
err = handleVShareLink(names, urlVLess, scheme, vless) err = handleVShareLink(names, urlVLess, scheme, vless)
@@ -225,9 +218,6 @@ func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
if flow := query.Get("flow"); flow != "" { if flow := query.Get("flow"); flow != "" {
vless["flow"] = strings.ToLower(flow) vless["flow"] = strings.ToLower(flow)
} }
if encryption := query.Get("encryption"); encryption != "" {
vless["encryption"] = encryption
}
proxies = append(proxies, vless) proxies = append(proxies, vless)
case "vmess": case "vmess":
@@ -285,24 +275,22 @@ func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
vmess["skip-cert-verify"] = false vmess["skip-cert-verify"] = false
vmess["cipher"] = "auto" vmess["cipher"] = "auto"
if cipher, ok := values["scy"].(string); ok && cipher != "" { if cipher, ok := values["scy"]; ok && cipher != "" {
vmess["cipher"] = cipher vmess["cipher"] = cipher
} }
if sni, ok := values["sni"].(string); ok && sni != "" { if sni, ok := values["sni"]; ok && sni != "" {
vmess["servername"] = sni vmess["servername"] = sni
} }
network, ok := values["net"].(string) network, _ := values["net"].(string)
if ok { network = strings.ToLower(network)
network = strings.ToLower(network) if values["type"] == "http" {
if values["type"] == "http" { network = "http"
network = "http" } else if network == "http" {
} else if network == "http" { network = "h2"
network = "h2"
}
vmess["network"] = network
} }
vmess["network"] = network
tls, ok := values["tls"].(string) tls, ok := values["tls"].(string)
if ok { if ok {
@@ -319,12 +307,12 @@ func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
case "http": case "http":
headers := make(map[string]any) headers := make(map[string]any)
httpOpts := make(map[string]any) httpOpts := make(map[string]any)
if host, ok := values["host"].(string); ok && host != "" { if host, ok := values["host"]; ok && host != "" {
headers["Host"] = []string{host} headers["Host"] = []string{host.(string)}
} }
httpOpts["path"] = []string{"/"} httpOpts["path"] = []string{"/"}
if path, ok := values["path"].(string); ok && path != "" { if path, ok := values["path"]; ok && path != "" {
httpOpts["path"] = []string{path} httpOpts["path"] = []string{path.(string)}
} }
httpOpts["headers"] = headers httpOpts["headers"] = headers
@@ -333,8 +321,8 @@ func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
case "h2": case "h2":
headers := make(map[string]any) headers := make(map[string]any)
h2Opts := make(map[string]any) h2Opts := make(map[string]any)
if host, ok := values["host"].(string); ok && host != "" { if host, ok := values["host"]; ok && host != "" {
headers["Host"] = []string{host} headers["Host"] = []string{host.(string)}
} }
h2Opts["path"] = values["path"] h2Opts["path"] = values["path"]
@@ -342,38 +330,15 @@ func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
vmess["h2-opts"] = h2Opts vmess["h2-opts"] = h2Opts
case "ws", "httpupgrade": case "ws":
headers := make(map[string]any) headers := make(map[string]any)
wsOpts := make(map[string]any) wsOpts := make(map[string]any)
wsOpts["path"] = "/" wsOpts["path"] = []string{"/"}
if host, ok := values["host"].(string); ok && host != "" { if host, ok := values["host"]; ok && host != "" {
headers["Host"] = host headers["Host"] = host.(string)
} }
if path, ok := values["path"].(string); ok && path != "" { if path, ok := values["path"]; ok && path != "" {
path := path wsOpts["path"] = path.(string)
pathURL, err := url.Parse(path)
if err == nil {
query := pathURL.Query()
if earlyData := query.Get("ed"); earlyData != "" {
med, err := strconv.Atoi(earlyData)
if err == nil {
switch network {
case "ws":
wsOpts["max-early-data"] = med
wsOpts["early-data-header-name"] = "Sec-WebSocket-Protocol"
case "httpupgrade":
wsOpts["v2ray-http-upgrade-fast-open"] = true
}
query.Del("ed")
pathURL.RawQuery = query.Encode()
path = pathURL.String()
}
}
if earlyDataHeader := query.Get("eh"); earlyDataHeader != "" {
wsOpts["early-data-header-name"] = earlyDataHeader
}
}
wsOpts["path"] = path
} }
wsOpts["headers"] = headers wsOpts["headers"] = headers
vmess["ws-opts"] = wsOpts vmess["ws-opts"] = wsOpts
@@ -466,12 +431,12 @@ func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
proxies = append(proxies, ss) proxies = append(proxies, ss)
case "ssr": case "ssr":
dcBuf, err := TryDecodeBase64(body) dcBuf, err := encRaw.DecodeString(body)
if err != nil { if err != nil {
continue continue
} }
// ssr://host:port:protocol:method:obfs:urlsafebase64pass/?obfsparam=urlsafebase64param&protoparam=urlsafebase64param&remarks=urlsafebase64remarks&group=urlsafebase64group&udpport=0&uot=1 // ssr://host:port:protocol:method:obfs:urlsafebase64pass/?obfsparam=urlsafebase64&protoparam=&remarks=urlsafebase64&group=urlsafebase64&udpport=0&uot=1
before, after, ok := strings.Cut(string(dcBuf), "/?") before, after, ok := strings.Cut(string(dcBuf), "/?")
if !ok { if !ok {
@@ -500,7 +465,7 @@ func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
name := uniqueName(names, remarks) name := uniqueName(names, remarks)
obfsParam := decodeUrlSafe(query.Get("obfsparam")) obfsParam := decodeUrlSafe(query.Get("obfsparam"))
protocolParam := decodeUrlSafe(query.Get("protoparam")) protocolParam := query.Get("protoparam")
ssr := make(map[string]any, 20) ssr := make(map[string]any, 20)
@@ -523,101 +488,6 @@ func ConvertsV2Ray(buf []byte) ([]map[string]any, error) {
} }
proxies = append(proxies, ssr) proxies = append(proxies, ssr)
case "socks", "socks5", "socks5h", "http", "https":
link, err := url.Parse(line)
if err != nil {
continue
}
server := link.Hostname()
if server == "" {
continue
}
portStr := link.Port()
if portStr == "" {
continue
}
remarks := link.Fragment
if remarks == "" {
remarks = fmt.Sprintf("%s:%s", server, portStr)
}
name := uniqueName(names, remarks)
encodeStr := link.User.String()
var username, password string
if encodeStr != "" {
decodeStr := string(DecodeBase64([]byte(encodeStr)))
splitStr := strings.Split(decodeStr, ":")
// todo: should use url.QueryUnescape ?
username = splitStr[0]
if len(splitStr) == 2 {
password = splitStr[1]
}
}
socks := make(map[string]any, 10)
socks["name"] = name
socks["type"] = func() string {
switch scheme {
case "socks", "socks5", "socks5h":
return "socks5"
case "http", "https":
return "http"
}
return scheme
}()
socks["server"] = server
socks["port"] = portStr
socks["username"] = username
socks["password"] = password
socks["skip-cert-verify"] = true
if scheme == "https" {
socks["tls"] = true
}
proxies = append(proxies, socks)
case "anytls":
// https://github.com/anytls/anytls-go/blob/main/docs/uri_scheme.md
link, err := url.Parse(line)
if err != nil {
continue
}
username := link.User.Username()
password, exist := link.User.Password()
if !exist {
password = username
}
query := link.Query()
server := link.Hostname()
if server == "" {
continue
}
portStr := link.Port()
if portStr == "" {
continue
}
insecure, sni := query.Get("insecure"), query.Get("sni")
insecureBool := insecure == "1"
fingerprint := query.Get("hpkp")
remarks := link.Fragment
if remarks == "" {
remarks = fmt.Sprintf("%s:%s", server, portStr)
}
name := uniqueName(names, remarks)
anytls := make(map[string]any, 10)
anytls["name"] = name
anytls["type"] = "anytls"
anytls["server"] = server
anytls["port"] = portStr
anytls["username"] = username
anytls["password"] = password
anytls["sni"] = sni
anytls["fingerprint"] = fingerprint
anytls["skip-cert-verify"] = insecureBool
anytls["udp"] = true
proxies = append(proxies, anytls)
} }
} }

View File

@@ -2,14 +2,14 @@ package convert
import ( import (
"encoding/base64" "encoding/base64"
"net/http"
"strings" "strings"
"time" "time"
"github.com/metacubex/mihomo/common/utils" "github.com/metacubex/mihomo/common/utils"
"github.com/metacubex/http"
"github.com/metacubex/randv2"
"github.com/metacubex/sing-shadowsocks/shadowimpl" "github.com/metacubex/sing-shadowsocks/shadowimpl"
"github.com/zhangyunhao116/fastrand"
) )
var hostsSuffix = []string{ var hostsSuffix = []string{
@@ -302,11 +302,11 @@ func RandHost() string {
prefix += string(buf[6:8]) + "-" prefix += string(buf[6:8]) + "-"
prefix += string(buf[len(buf)-8:]) prefix += string(buf[len(buf)-8:])
return prefix + hostsSuffix[randv2.IntN(hostsLen)] return prefix + hostsSuffix[fastrand.Intn(hostsLen)]
} }
func RandUserAgent() string { func RandUserAgent() string {
return userAgents[randv2.IntN(uaLen)] return userAgents[fastrand.Intn(uaLen)]
} }
func SetUserAgent(header http.Header) { func SetUserAgent(header http.Header) {

View File

@@ -35,9 +35,6 @@ func handleVShareLink(names map[string]int, url *url.URL, scheme string, proxy m
if alpn := query.Get("alpn"); alpn != "" { if alpn := query.Get("alpn"); alpn != "" {
proxy["alpn"] = strings.Split(alpn, ",") proxy["alpn"] = strings.Split(alpn, ",")
} }
if pcs := query.Get("pcs"); pcs != "" {
proxy["fingerprint"] = pcs
}
} }
if sni := query.Get("sni"); sni != "" { if sni := query.Get("sni"); sni != "" {
proxy["servername"] = sni proxy["servername"] = sni
@@ -103,7 +100,7 @@ func handleVShareLink(names map[string]int, url *url.URL, scheme string, proxy m
h2Opts["headers"] = headers h2Opts["headers"] = headers
proxy["h2-opts"] = h2Opts proxy["h2-opts"] = h2Opts
case "ws", "httpupgrade": case "ws":
headers := make(map[string]any) headers := make(map[string]any)
wsOpts := make(map[string]any) wsOpts := make(map[string]any)
headers["User-Agent"] = RandUserAgent() headers["User-Agent"] = RandUserAgent()
@@ -116,13 +113,7 @@ func handleVShareLink(names map[string]int, url *url.URL, scheme string, proxy m
if err != nil { if err != nil {
return fmt.Errorf("bad WebSocket max early data size: %v", err) return fmt.Errorf("bad WebSocket max early data size: %v", err)
} }
switch network { wsOpts["max-early-data"] = med
case "ws":
wsOpts["max-early-data"] = med
wsOpts["early-data-header-name"] = "Sec-WebSocket-Protocol"
case "httpupgrade":
wsOpts["v2ray-http-upgrade-fast-open"] = true
}
} }
if earlyDataHeader := query.Get("eh"); earlyDataHeader != "" { if earlyDataHeader := query.Get("eh"); earlyDataHeader != "" {
wsOpts["early-data-header-name"] = earlyDataHeader wsOpts["early-data-header-name"] = earlyDataHeader

View File

@@ -1,674 +0,0 @@
package deque
// copy and modified from https://github.com/gammazero/deque/blob/v1.2.0/deque.go
// which is licensed under MIT.
import (
"fmt"
)
// minCapacity is the smallest capacity that deque may have. Must be power of 2
// for bitwise modulus: x % n == x & (n - 1).
const minCapacity = 8
// Deque represents a single instance of the deque data structure. A Deque
// instance contains items of the type specified by the type argument.
//
// For example, to create a Deque that contains strings do one of the
// following:
//
// var stringDeque deque.Deque[string]
// stringDeque := new(deque.Deque[string])
// stringDeque := &deque.Deque[string]{}
//
// To create a Deque that will never resize to have space for less than 64
// items, specify a base capacity:
//
// var d deque.Deque[int]
// d.SetBaseCap(64)
//
// To ensure the Deque can store 1000 items without needing to resize while
// items are added:
//
// d.Grow(1000)
//
// Any values supplied to [SetBaseCap] and [Grow] are rounded up to the nearest
// power of 2, since the Deque grows by powers of 2.
type Deque[T any] struct {
buf []T
head int
tail int
count int
minCap int
}
// Cap returns the current capacity of the Deque. If q is nil, q.Cap() is zero.
func (q *Deque[T]) Cap() int {
if q == nil {
return 0
}
return len(q.buf)
}
// Len returns the number of elements currently stored in the queue. If q is
// nil, q.Len() returns zero.
func (q *Deque[T]) Len() int {
if q == nil {
return 0
}
return q.count
}
// PushBack appends an element to the back of the queue. Implements FIFO when
// elements are removed with [PopFront], and LIFO when elements are removed with
// [PopBack].
func (q *Deque[T]) PushBack(elem T) {
q.growIfFull()
q.buf[q.tail] = elem
// Calculate new tail position.
q.tail = q.next(q.tail)
q.count++
}
// PushFront prepends an element to the front of the queue.
func (q *Deque[T]) PushFront(elem T) {
q.growIfFull()
// Calculate new head position.
q.head = q.prev(q.head)
q.buf[q.head] = elem
q.count++
}
// PopFront removes and returns the element from the front of the queue.
// Implements FIFO when used with [PushBack]. If the queue is empty, the call
// panics.
func (q *Deque[T]) PopFront() T {
if q.count <= 0 {
panic("deque: PopFront() called on empty queue")
}
ret := q.buf[q.head]
var zero T
q.buf[q.head] = zero
// Calculate new head position.
q.head = q.next(q.head)
q.count--
q.shrinkIfExcess()
return ret
}
// IterPopFront returns an iterator that iteratively removes items from the
// front of the deque. This is more efficient than removing items one at a time
// because it avoids intermediate resizing. If a resize is necessary, only one
// is done when iteration ends.
func (q *Deque[T]) IterPopFront() func(yield func(T) bool) {
return func(yield func(T) bool) {
if q.Len() == 0 {
return
}
var zero T
for q.count != 0 {
ret := q.buf[q.head]
q.buf[q.head] = zero
q.head = q.next(q.head)
q.count--
if !yield(ret) {
break
}
}
q.shrinkToFit()
}
}
// PopBack removes and returns the element from the back of the queue.
// Implements LIFO when used with [PushBack]. If the queue is empty, the call
// panics.
func (q *Deque[T]) PopBack() T {
if q.count <= 0 {
panic("deque: PopBack() called on empty queue")
}
// Calculate new tail position
q.tail = q.prev(q.tail)
// Remove value at tail.
ret := q.buf[q.tail]
var zero T
q.buf[q.tail] = zero
q.count--
q.shrinkIfExcess()
return ret
}
// IterPopBack returns an iterator that iteratively removes items from the back
// of the deque. This is more efficient than removing items one at a time
// because it avoids intermediate resizing. If a resize is necessary, only one
// is done when iteration ends.
func (q *Deque[T]) IterPopBack() func(yield func(T) bool) {
return func(yield func(T) bool) {
if q.Len() == 0 {
return
}
var zero T
for q.count != 0 {
q.tail = q.prev(q.tail)
ret := q.buf[q.tail]
q.buf[q.tail] = zero
q.count--
if !yield(ret) {
break
}
}
q.shrinkToFit()
}
}
// Front returns the element at the front of the queue. This is the element
// that would be returned by [PopFront]. This call panics if the queue is
// empty.
func (q *Deque[T]) Front() T {
if q.count <= 0 {
panic("deque: Front() called when empty")
}
return q.buf[q.head]
}
// Back returns the element at the back of the queue. This is the element that
// would be returned by [PopBack]. This call panics if the queue is empty.
func (q *Deque[T]) Back() T {
if q.count <= 0 {
panic("deque: Back() called when empty")
}
return q.buf[q.prev(q.tail)]
}
// At returns the element at index i in the queue without removing the element
// from the queue. This method accepts only non-negative index values. At(0)
// refers to the first element and is the same as [Front]. At(Len()-1) refers
// to the last element and is the same as [Back]. If the index is invalid, the
// call panics.
//
// The purpose of At is to allow Deque to serve as a more general purpose
// circular buffer, where items are only added to and removed from the ends of
// the deque, but may be read from any place within the deque. Consider the
// case of a fixed-size circular log buffer: A new entry is pushed onto one end
// and when full the oldest is popped from the other end. All the log entries
// in the buffer must be readable without altering the buffer contents.
func (q *Deque[T]) At(i int) T {
q.checkRange(i)
// bitwise modulus
return q.buf[(q.head+i)&(len(q.buf)-1)]
}
// Set assigns the item to index i in the queue. Set indexes the deque the same
// as [At] but perform the opposite operation. If the index is invalid, the call
// panics.
func (q *Deque[T]) Set(i int, item T) {
q.checkRange(i)
// bitwise modulus
q.buf[(q.head+i)&(len(q.buf)-1)] = item
}
// Iter returns a go iterator to range over all items in the Deque, yielding
// each item from front (index 0) to back (index Len()-1). Modification of
// Deque during iteration panics.
func (q *Deque[T]) Iter() func(yield func(T) bool) {
return func(yield func(T) bool) {
origHead := q.head
origTail := q.tail
head := origHead
for i := -0; i < q.Len(); i++ {
if q.head != origHead || q.tail != origTail {
panic("deque: modified during iteration")
}
if !yield(q.buf[head]) {
return
}
head = q.next(head)
}
}
}
// RIter returns a reverse go iterator to range over all items in the Deque,
// yielding each item from back (index Len()-1) to front (index 0).
// Modification of Deque during iteration panics.
func (q *Deque[T]) RIter() func(yield func(T) bool) {
return func(yield func(T) bool) {
origHead := q.head
origTail := q.tail
tail := origTail
for i := -0; i < q.Len(); i++ {
if q.head != origHead || q.tail != origTail {
panic("deque: modified during iteration")
}
tail = q.prev(tail)
if !yield(q.buf[tail]) {
return
}
}
}
}
// Clear removes all elements from the queue, but retains the current capacity.
// This is useful when repeatedly reusing the queue at high frequency to avoid
// GC during reuse. The queue will not be resized smaller as long as items are
// only added. Only when items are removed is the queue subject to getting
// resized smaller.
func (q *Deque[T]) Clear() {
if q.Len() == 0 {
return
}
head, tail := q.head, q.tail
q.count = 0
q.head = 0
q.tail = 0
if head >= tail {
// [DEF....ABC]
clearSlice(q.buf[head:])
head = 0
}
clearSlice(q.buf[head:tail])
}
func clearSlice[S ~[]E, E any](s S) {
var zero E
for i := range s {
s[i] = zero
}
}
// Grow grows deque's capacity, if necessary, to guarantee space for another n
// items. After Grow(n), at least n items can be written to the deque without
// another allocation. If n is negative, Grow panics.
func (q *Deque[T]) Grow(n int) {
if n < 0 {
panic("deque.Grow: negative count")
}
c := q.Cap()
l := q.Len()
// If already big enough.
if n <= c-l {
return
}
if c == 0 {
c = minCapacity
}
newLen := l + n
for c < newLen {
c <<= 1
}
if l == 0 {
q.buf = make([]T, c)
q.head = 0
q.tail = 0
} else {
q.resize(c)
}
}
// Copy copies the contents of the given src Deque into this Deque.
//
// n := b.Copy(a)
//
// is an efficient shortcut for
//
// b.Clear()
// n := a.Len()
// b.Grow(n)
// for i := 0; i < n; i++ {
// b.PushBack(a.At(i))
// }
func (q *Deque[T]) Copy(src Deque[T]) int {
q.Clear()
q.Grow(src.Len())
n := src.CopyOutSlice(q.buf)
q.count = n
q.tail = n
q.head = 0
return n
}
// AppendToSlice appends from the Deque to the given slice. If the slice has
// insufficient capacity to store all elements in Deque, then allocate a new
// slice. Returns the resulting slice.
//
// out = q.AppendToSlice(out)
//
// is an efficient shortcut for
//
// for i := 0; i < q.Len(); i++ {
// x = append(out, q.At(i))
// }
func (q *Deque[T]) AppendToSlice(out []T) []T {
if q.count == 0 {
return out
}
head, tail := q.head, q.tail
if head >= tail {
// [DEF....ABC]
out = append(out, q.buf[head:]...)
head = 0
}
return append(out, q.buf[head:tail]...)
}
// CopyInSlice replaces the contents of Deque with all the elements from the
// given slice, in. If len(in) is zero, then this is equivalent to calling
// [Clear].
//
// q.CopyInSlice(in)
//
// is an efficient shortcut for
//
// q.Clear()
// for i := range in {
// q.PushBack(in[i])
// }
func (q *Deque[T]) CopyInSlice(in []T) {
// Allocate new buffer if more space needed.
if len(q.buf) < len(in) {
newCap := len(q.buf)
if newCap == 0 {
newCap = minCapacity
q.minCap = minCapacity
}
for newCap < len(in) {
newCap <<= 1
}
q.buf = make([]T, newCap)
} else if len(q.buf) > len(in) {
q.Clear()
}
n := copy(q.buf, in)
q.count = n
q.tail = n
q.head = 0
}
// CopyOutSlice copies elements from the Deque into the given slice, up to the
// size of the buffer. Returns the number of elements copied, which will be the
// minimum of q.Len() and len(out).
//
// n := q.CopyOutSlice(out)
//
// is an efficient shortcut for
//
// n := min(len(out), q.Len())
// for i := 0; i < n; i++ {
// out[i] = q.At(i)
// }
//
// This function is preferable to one that returns a copy of the internal
// buffer because this allows reuse of memory receiving data, for repeated copy
// operations.
func (q *Deque[T]) CopyOutSlice(out []T) int {
if q.count == 0 || len(out) == 0 {
return 0
}
head, tail := q.head, q.tail
var n int
if head >= tail {
// [DEF....ABC]
n = copy(out, q.buf[head:])
out = out[n:]
if len(out) == 0 {
return n
}
head = 0
}
n += copy(out, q.buf[head:tail])
return n
}
// Rotate rotates the deque n steps front-to-back. If n is negative, rotates
// back-to-front. Having Deque provide Rotate avoids resizing that could happen
// if implementing rotation using only Pop and Push methods. If q.Len() is one
// or less, or q is nil, then Rotate does nothing.
func (q *Deque[T]) Rotate(n int) {
if q.Len() <= 1 {
return
}
// Rotating a multiple of q.count is same as no rotation.
n %= q.count
if n == 0 {
return
}
modBits := len(q.buf) - 1
// If no empty space in buffer, only move head and tail indexes.
if q.head == q.tail {
// Calculate new head and tail using bitwise modulus.
q.head = (q.head + n) & modBits
q.tail = q.head
return
}
var zero T
if n < 0 {
// Rotate back to front.
for ; n < 0; n++ {
// Calculate new head and tail using bitwise modulus.
q.head = (q.head - 1) & modBits
q.tail = (q.tail - 1) & modBits
// Put tail value at head and remove value at tail.
q.buf[q.head] = q.buf[q.tail]
q.buf[q.tail] = zero
}
return
}
// Rotate front to back.
for ; n > 0; n-- {
// Put head value at tail and remove value at head.
q.buf[q.tail] = q.buf[q.head]
q.buf[q.head] = zero
// Calculate new head and tail using bitwise modulus.
q.head = (q.head + 1) & modBits
q.tail = (q.tail + 1) & modBits
}
}
// Index returns the index into the Deque of the first item satisfying f(item),
// or -1 if none do. If q is nil, then -1 is always returned. Search is linear
// starting with index 0.
func (q *Deque[T]) Index(f func(T) bool) int {
if q.Len() > 0 {
modBits := len(q.buf) - 1
for i := 0; i < q.count; i++ {
if f(q.buf[(q.head+i)&modBits]) {
return i
}
}
}
return -1
}
// RIndex is the same as Index, but searches from Back to Front. The index
// returned is from Front to Back, where index 0 is the index of the item
// returned by [Front].
func (q *Deque[T]) RIndex(f func(T) bool) int {
if q.Len() > 0 {
modBits := len(q.buf) - 1
for i := q.count - 1; i >= 0; i-- {
if f(q.buf[(q.head+i)&modBits]) {
return i
}
}
}
return -1
}
// Insert is used to insert an element into the middle of the queue, before the
// element at the specified index. Insert(0,e) is the same as PushFront(e) and
// Insert(Len(),e) is the same as PushBack(e). Out of range indexes result in
// pushing the item onto the front of back of the deque.
//
// Important: Deque is optimized for O(1) operations at the ends of the queue,
// not for operations in the the middle. Complexity of this function is
// constant plus linear in the lesser of the distances between the index and
// either of the ends of the queue.
func (q *Deque[T]) Insert(at int, item T) {
if at <= 0 {
q.PushFront(item)
return
}
if at >= q.Len() {
q.PushBack(item)
return
}
if at*2 < q.count {
q.PushFront(item)
front := q.head
for i := 0; i < at; i++ {
next := q.next(front)
q.buf[front], q.buf[next] = q.buf[next], q.buf[front]
front = next
}
return
}
swaps := q.count - at
q.PushBack(item)
back := q.prev(q.tail)
for i := 0; i < swaps; i++ {
prev := q.prev(back)
q.buf[back], q.buf[prev] = q.buf[prev], q.buf[back]
back = prev
}
}
// Remove removes and returns an element from the middle of the queue, at the
// specified index. Remove(0) is the same as [PopFront] and Remove(Len()-1) is
// the same as [PopBack]. Accepts only non-negative index values, and panics if
// index is out of range.
//
// Important: Deque is optimized for O(1) operations at the ends of the queue,
// not for operations in the the middle. Complexity of this function is
// constant plus linear in the lesser of the distances between the index and
// either of the ends of the queue.
func (q *Deque[T]) Remove(at int) T {
q.checkRange(at)
rm := (q.head + at) & (len(q.buf) - 1)
if at*2 < q.count {
for i := 0; i < at; i++ {
prev := q.prev(rm)
q.buf[prev], q.buf[rm] = q.buf[rm], q.buf[prev]
rm = prev
}
return q.PopFront()
}
swaps := q.count - at - 1
for i := 0; i < swaps; i++ {
next := q.next(rm)
q.buf[rm], q.buf[next] = q.buf[next], q.buf[rm]
rm = next
}
return q.PopBack()
}
// SetBaseCap sets a base capacity so that at least the specified number of
// items can always be stored without resizing.
func (q *Deque[T]) SetBaseCap(baseCap int) {
minCap := minCapacity
for minCap < baseCap {
minCap <<= 1
}
q.minCap = minCap
}
// Swap exchanges the two values at idxA and idxB. It panics if either index is
// out of range.
func (q *Deque[T]) Swap(idxA, idxB int) {
q.checkRange(idxA)
q.checkRange(idxB)
if idxA == idxB {
return
}
realA := (q.head + idxA) & (len(q.buf) - 1)
realB := (q.head + idxB) & (len(q.buf) - 1)
q.buf[realA], q.buf[realB] = q.buf[realB], q.buf[realA]
}
func (q *Deque[T]) checkRange(i int) {
if i < 0 || i >= q.count {
panic(fmt.Sprintf("deque: index out of range %d with length %d", i, q.Len()))
}
}
// prev returns the previous buffer position wrapping around buffer.
func (q *Deque[T]) prev(i int) int {
return (i - 1) & (len(q.buf) - 1) // bitwise modulus
}
// next returns the next buffer position wrapping around buffer.
func (q *Deque[T]) next(i int) int {
return (i + 1) & (len(q.buf) - 1) // bitwise modulus
}
// growIfFull resizes up if the buffer is full.
func (q *Deque[T]) growIfFull() {
if q.count != len(q.buf) {
return
}
if len(q.buf) == 0 {
if q.minCap == 0 {
q.minCap = minCapacity
}
q.buf = make([]T, q.minCap)
return
}
q.resize(q.count << 1)
}
// shrinkIfExcess resize down if the buffer 1/4 full.
func (q *Deque[T]) shrinkIfExcess() {
if len(q.buf) > q.minCap && (q.count<<2) == len(q.buf) {
q.resize(q.count << 1)
}
}
func (q *Deque[T]) shrinkToFit() {
if len(q.buf) > q.minCap && (q.count<<2) <= len(q.buf) {
if q.count == 0 {
q.head = 0
q.tail = 0
q.buf = make([]T, q.minCap)
return
}
c := q.minCap
for c < q.count {
c <<= 1
}
q.resize(c)
}
}
// resize resizes the deque to fit exactly twice its current contents. This is
// used to grow the queue when it is full, and also to shrink it when it is
// only a quarter full.
func (q *Deque[T]) resize(newSize int) {
newBuf := make([]T, newSize)
if q.tail > q.head {
copy(newBuf, q.buf[q.head:q.tail])
} else {
n := copy(newBuf, q.buf[q.head:])
copy(newBuf[n:], q.buf[:q.tail])
}
q.head = 0
q.tail = q.count
q.buf = newBuf
}

View File

@@ -68,8 +68,10 @@ type LruCache[K comparable, V any] struct {
// New creates an LruCache // New creates an LruCache
func New[K comparable, V any](options ...Option[K, V]) *LruCache[K, V] { func New[K comparable, V any](options ...Option[K, V]) *LruCache[K, V] {
lc := &LruCache[K, V]{} lc := &LruCache[K, V]{
lc.Clear() lru: list.New[*entry[K, V]](),
cache: make(map[K]*list.Element[*entry[K, V]]),
}
for _, option := range options { for _, option := range options {
option(lc) option(lc)
@@ -78,14 +80,6 @@ func New[K comparable, V any](options ...Option[K, V]) *LruCache[K, V] {
return lc return lc
} }
func (c *LruCache[K, V]) Clear() {
c.mu.Lock()
defer c.mu.Unlock()
c.lru = list.New[*entry[K, V]]()
c.cache = make(map[K]*list.Element[*entry[K, V]])
}
// Get returns any representation of a cached response and a bool // Get returns any representation of a cached response and a bool
// set to true if the key was found. // set to true if the key was found.
func (c *LruCache[K, V]) Get(key K) (V, bool) { func (c *LruCache[K, V]) Get(key K) (V, bool) {
@@ -229,10 +223,6 @@ func (c *LruCache[K, V]) Delete(key K) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
c.delete(key)
}
func (c *LruCache[K, V]) delete(key K) {
if le, ok := c.cache[key]; ok { if le, ok := c.cache[key]; ok {
c.deleteElement(le) c.deleteElement(le)
} }
@@ -256,32 +246,13 @@ func (c *LruCache[K, V]) deleteElement(le *list.Element[*entry[K, V]]) {
} }
} }
// Compute either sets the computed new value for the key or deletes func (c *LruCache[K, V]) Clear() error {
// the value for the key. When the delete result of the valueFn function
// is set to true, the value will be deleted, if it exists. When delete
// is set to false, the value is updated to the newValue.
// The ok result indicates whether value was computed and stored, thus, is
// present in the map. The actual result contains the new value in cases where
// the value was computed and stored.
func (c *LruCache[K, V]) Compute(
key K,
valueFn func(oldValue V, loaded bool) (newValue V, delete bool),
) (actual V, ok bool) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
if el := c.get(key); el != nil { c.cache = make(map[K]*list.Element[*entry[K, V]])
actual, ok = el.value, true
} return nil
if newValue, del := valueFn(actual, ok); del {
if ok { // data not in cache, so needn't delete
c.delete(key)
}
return lo.Empty[V](), false
} else {
c.set(key, newValue)
return newValue, true
}
} }
type entry[K comparable, V any] struct { type entry[K comparable, V any] struct {

View File

@@ -1,19 +0,0 @@
package maphash
import "hash/maphash"
type Seed = maphash.Seed
func MakeSeed() Seed {
return maphash.MakeSeed()
}
type Hash = maphash.Hash
func Bytes(seed Seed, b []byte) uint64 {
return maphash.Bytes(seed, b)
}
func String(seed Seed, s string) uint64 {
return maphash.String(seed, s)
}

View File

@@ -1,142 +0,0 @@
//go:build !go1.24
package maphash
import "unsafe"
func Comparable[T comparable](s Seed, v T) uint64 {
return comparableHash(*(*seedTyp)(unsafe.Pointer(&s)), v)
}
func comparableHash[T comparable](seed seedTyp, v T) uint64 {
s := seed.s
var m map[T]struct{}
mTyp := iTypeOf(m)
var hasher func(unsafe.Pointer, uintptr) uintptr
hasher = (*iMapType)(unsafe.Pointer(mTyp)).Hasher
p := escape(unsafe.Pointer(&v))
if ptrSize == 8 {
return uint64(hasher(p, uintptr(s)))
}
lo := hasher(p, uintptr(s))
hi := hasher(p, uintptr(s>>32))
return uint64(hi)<<32 | uint64(lo)
}
// WriteComparable adds x to the data hashed by h.
func WriteComparable[T comparable](h *Hash, x T) {
// writeComparable (not in purego mode) directly operates on h.state
// without using h.buf. Mix in the buffer length so it won't
// commute with a buffered write, which either changes h.n or changes
// h.state.
hash := (*hashTyp)(unsafe.Pointer(h))
if hash.n != 0 {
hash.state.s = comparableHash(hash.state, hash.n)
}
hash.state.s = comparableHash(hash.state, x)
}
// go/src/hash/maphash/maphash.go
type hashTyp struct {
_ [0]func() // not comparable
seed seedTyp // initial seed used for this hash
state seedTyp // current hash of all flushed bytes
buf [128]byte // unflushed byte buffer
n int // number of unflushed bytes
}
type seedTyp struct {
s uint64
}
type iTFlag uint8
type iKind uint8
type iNameOff int32
// TypeOff is the offset to a type from moduledata.types. See resolveTypeOff in runtime.
type iTypeOff int32
type iType struct {
Size_ uintptr
PtrBytes uintptr // number of (prefix) bytes in the type that can contain pointers
Hash uint32 // hash of type; avoids computation in hash tables
TFlag iTFlag // extra type information flags
Align_ uint8 // alignment of variable with this type
FieldAlign_ uint8 // alignment of struct field with this type
Kind_ iKind // enumeration for C
// function for comparing objects of this type
// (ptr to object A, ptr to object B) -> ==?
Equal func(unsafe.Pointer, unsafe.Pointer) bool
// GCData stores the GC type data for the garbage collector.
// Normally, GCData points to a bitmask that describes the
// ptr/nonptr fields of the type. The bitmask will have at
// least PtrBytes/ptrSize bits.
// If the TFlagGCMaskOnDemand bit is set, GCData is instead a
// **byte and the pointer to the bitmask is one dereference away.
// The runtime will build the bitmask if needed.
// (See runtime/type.go:getGCMask.)
// Note: multiple types may have the same value of GCData,
// including when TFlagGCMaskOnDemand is set. The types will, of course,
// have the same pointer layout (but not necessarily the same size).
GCData *byte
Str iNameOff // string form
PtrToThis iTypeOff // type for pointer to this type, may be zero
}
type iMapType struct {
iType
Key *iType
Elem *iType
Group *iType // internal type representing a slot group
// function for hashing keys (ptr to key, seed) -> hash
Hasher func(unsafe.Pointer, uintptr) uintptr
}
func iTypeOf(a any) *iType {
eface := *(*iEmptyInterface)(unsafe.Pointer(&a))
// Types are either static (for compiler-created types) or
// heap-allocated but always reachable (for reflection-created
// types, held in the central map). So there is no need to
// escape types. noescape here help avoid unnecessary escape
// of v.
return (*iType)(noescape(unsafe.Pointer(eface.Type)))
}
type iEmptyInterface struct {
Type *iType
Data unsafe.Pointer
}
// noescape hides a pointer from escape analysis. noescape is
// the identity function but escape analysis doesn't think the
// output depends on the input. noescape is inlined and currently
// compiles down to zero instructions.
// USE CAREFULLY!
//
// nolint:all
//
//go:nosplit
//goland:noinspection ALL
func noescape(p unsafe.Pointer) unsafe.Pointer {
x := uintptr(p)
return unsafe.Pointer(x ^ 0)
}
var alwaysFalse bool
var escapeSink any
// escape forces any pointers in x to escape to the heap.
func escape[T any](x T) T {
if alwaysFalse {
escapeSink = x
}
return x
}
// ptrSize is the size of a pointer in bytes - unsafe.Sizeof(uintptr(0)) but as an ideal constant.
// It is also the size of the machine's native word size (that is, 4 on 32-bit systems, 8 on 64-bit).
const ptrSize = 4 << (^uintptr(0) >> 63)
const testComparableAllocations = false

View File

@@ -1,15 +0,0 @@
//go:build go1.24
package maphash
import "hash/maphash"
func Comparable[T comparable](seed Seed, v T) uint64 {
return maphash.Comparable(seed, v)
}
func WriteComparable[T comparable](h *Hash, x T) {
maphash.WriteComparable(h, x)
}
const testComparableAllocations = true

View File

@@ -1,534 +0,0 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package maphash
import (
"bytes"
"fmt"
"hash"
"math"
"reflect"
"strings"
"testing"
"unsafe"
rand "github.com/metacubex/randv2"
)
func TestUnseededHash(t *testing.T) {
m := map[uint64]struct{}{}
for i := 0; i < 1000; i++ {
h := new(Hash)
m[h.Sum64()] = struct{}{}
}
if len(m) < 900 {
t.Errorf("empty hash not sufficiently random: got %d, want 1000", len(m))
}
}
func TestSeededHash(t *testing.T) {
s := MakeSeed()
m := map[uint64]struct{}{}
for i := 0; i < 1000; i++ {
h := new(Hash)
h.SetSeed(s)
m[h.Sum64()] = struct{}{}
}
if len(m) != 1 {
t.Errorf("seeded hash is random: got %d, want 1", len(m))
}
}
func TestHashGrouping(t *testing.T) {
b := bytes.Repeat([]byte("foo"), 100)
hh := make([]*Hash, 7)
for i := range hh {
hh[i] = new(Hash)
}
for _, h := range hh[1:] {
h.SetSeed(hh[0].Seed())
}
hh[0].Write(b)
hh[1].WriteString(string(b))
writeByte := func(h *Hash, b byte) {
err := h.WriteByte(b)
if err != nil {
t.Fatalf("WriteByte: %v", err)
}
}
writeSingleByte := func(h *Hash, b byte) {
_, err := h.Write([]byte{b})
if err != nil {
t.Fatalf("Write single byte: %v", err)
}
}
writeStringSingleByte := func(h *Hash, b byte) {
_, err := h.WriteString(string([]byte{b}))
if err != nil {
t.Fatalf("WriteString single byte: %v", err)
}
}
for i, x := range b {
writeByte(hh[2], x)
writeSingleByte(hh[3], x)
if i == 0 {
writeByte(hh[4], x)
} else {
writeSingleByte(hh[4], x)
}
writeStringSingleByte(hh[5], x)
if i == 0 {
writeByte(hh[6], x)
} else {
writeStringSingleByte(hh[6], x)
}
}
sum := hh[0].Sum64()
for i, h := range hh {
if sum != h.Sum64() {
t.Errorf("hash %d not identical to a single Write", i)
}
}
if sum1 := Bytes(hh[0].Seed(), b); sum1 != hh[0].Sum64() {
t.Errorf("hash using Bytes not identical to a single Write")
}
if sum1 := String(hh[0].Seed(), string(b)); sum1 != hh[0].Sum64() {
t.Errorf("hash using String not identical to a single Write")
}
}
func TestHashBytesVsString(t *testing.T) {
s := "foo"
b := []byte(s)
h1 := new(Hash)
h2 := new(Hash)
h2.SetSeed(h1.Seed())
n1, err1 := h1.WriteString(s)
if n1 != len(s) || err1 != nil {
t.Fatalf("WriteString(s) = %d, %v, want %d, nil", n1, err1, len(s))
}
n2, err2 := h2.Write(b)
if n2 != len(b) || err2 != nil {
t.Fatalf("Write(b) = %d, %v, want %d, nil", n2, err2, len(b))
}
if h1.Sum64() != h2.Sum64() {
t.Errorf("hash of string and bytes not identical")
}
}
func TestHashHighBytes(t *testing.T) {
// See issue 34925.
const N = 10
m := map[uint64]struct{}{}
for i := 0; i < N; i++ {
h := new(Hash)
h.WriteString("foo")
m[h.Sum64()>>32] = struct{}{}
}
if len(m) < N/2 {
t.Errorf("from %d seeds, wanted at least %d different hashes; got %d", N, N/2, len(m))
}
}
func TestRepeat(t *testing.T) {
h1 := new(Hash)
h1.WriteString("testing")
sum1 := h1.Sum64()
h1.Reset()
h1.WriteString("testing")
sum2 := h1.Sum64()
if sum1 != sum2 {
t.Errorf("different sum after resetting: %#x != %#x", sum1, sum2)
}
h2 := new(Hash)
h2.SetSeed(h1.Seed())
h2.WriteString("testing")
sum3 := h2.Sum64()
if sum1 != sum3 {
t.Errorf("different sum on the same seed: %#x != %#x", sum1, sum3)
}
}
func TestSeedFromSum64(t *testing.T) {
h1 := new(Hash)
h1.WriteString("foo")
x := h1.Sum64() // seed generated here
h2 := new(Hash)
h2.SetSeed(h1.Seed())
h2.WriteString("foo")
y := h2.Sum64()
if x != y {
t.Errorf("hashes don't match: want %x, got %x", x, y)
}
}
func TestSeedFromSeed(t *testing.T) {
h1 := new(Hash)
h1.WriteString("foo")
_ = h1.Seed() // seed generated here
x := h1.Sum64()
h2 := new(Hash)
h2.SetSeed(h1.Seed())
h2.WriteString("foo")
y := h2.Sum64()
if x != y {
t.Errorf("hashes don't match: want %x, got %x", x, y)
}
}
func TestSeedFromFlush(t *testing.T) {
b := make([]byte, 65)
h1 := new(Hash)
h1.Write(b) // seed generated here
x := h1.Sum64()
h2 := new(Hash)
h2.SetSeed(h1.Seed())
h2.Write(b)
y := h2.Sum64()
if x != y {
t.Errorf("hashes don't match: want %x, got %x", x, y)
}
}
func TestSeedFromReset(t *testing.T) {
h1 := new(Hash)
h1.WriteString("foo")
h1.Reset() // seed generated here
h1.WriteString("foo")
x := h1.Sum64()
h2 := new(Hash)
h2.SetSeed(h1.Seed())
h2.WriteString("foo")
y := h2.Sum64()
if x != y {
t.Errorf("hashes don't match: want %x, got %x", x, y)
}
}
func negativeZero[T float32 | float64]() T {
var f T
f = -f
return f
}
func TestComparable(t *testing.T) {
testComparable(t, int64(2))
testComparable(t, uint64(8))
testComparable(t, uintptr(12))
testComparable(t, any("s"))
testComparable(t, "s")
testComparable(t, true)
testComparable(t, new(float64))
testComparable(t, float64(9))
testComparable(t, complex128(9i+1))
testComparable(t, struct{}{})
testComparable(t, struct {
i int
u uint
b bool
f float64
p *int
a any
}{i: 9, u: 1, b: true, f: 9.9, p: new(int), a: 1})
type S struct {
s string
}
s1 := S{s: heapStr(t)}
s2 := S{s: heapStr(t)}
if unsafe.StringData(s1.s) == unsafe.StringData(s2.s) {
t.Fatalf("unexpected two heapStr ptr equal")
}
if s1.s != s2.s {
t.Fatalf("unexpected two heapStr value not equal")
}
testComparable(t, s1, s2)
testComparable(t, s1.s, s2.s)
testComparable(t, float32(0), negativeZero[float32]())
testComparable(t, float64(0), negativeZero[float64]())
testComparableNoEqual(t, math.NaN(), math.NaN())
testComparableNoEqual(t, [2]string{"a", ""}, [2]string{"", "a"})
testComparableNoEqual(t, struct{ a, b string }{"foo", ""}, struct{ a, b string }{"", "foo"})
testComparableNoEqual(t, struct{ a, b any }{int(0), struct{}{}}, struct{ a, b any }{struct{}{}, int(0)})
}
func testComparableNoEqual[T comparable](t *testing.T, v1, v2 T) {
seed := MakeSeed()
if Comparable(seed, v1) == Comparable(seed, v2) {
t.Fatalf("Comparable(seed, %v) == Comparable(seed, %v)", v1, v2)
}
}
var heapStrValue = []byte("aTestString")
func heapStr(t *testing.T) string {
return string(heapStrValue)
}
func testComparable[T comparable](t *testing.T, v T, v2 ...T) {
t.Run(TypeFor[T]().String(), func(t *testing.T) {
var a, b T = v, v
if len(v2) != 0 {
b = v2[0]
}
var pa *T = &a
seed := MakeSeed()
if Comparable(seed, a) != Comparable(seed, b) {
t.Fatalf("Comparable(seed, %v) != Comparable(seed, %v)", a, b)
}
old := Comparable(seed, pa)
stackGrow(8192)
new := Comparable(seed, pa)
if old != new {
t.Fatal("Comparable(seed, ptr) != Comparable(seed, ptr)")
}
})
}
var use byte
//go:noinline
func stackGrow(dep int) {
if dep == 0 {
return
}
var local [1024]byte
// make sure local is allocated on the stack.
local[rand.Uint64()%1024] = byte(rand.Uint64())
use = local[rand.Uint64()%1024]
stackGrow(dep - 1)
}
func TestWriteComparable(t *testing.T) {
testWriteComparable(t, int64(2))
testWriteComparable(t, uint64(8))
testWriteComparable(t, uintptr(12))
testWriteComparable(t, any("s"))
testWriteComparable(t, "s")
testComparable(t, true)
testWriteComparable(t, new(float64))
testWriteComparable(t, float64(9))
testWriteComparable(t, complex128(9i+1))
testWriteComparable(t, struct{}{})
testWriteComparable(t, struct {
i int
u uint
b bool
f float64
p *int
a any
}{i: 9, u: 1, b: true, f: 9.9, p: new(int), a: 1})
type S struct {
s string
}
s1 := S{s: heapStr(t)}
s2 := S{s: heapStr(t)}
if unsafe.StringData(s1.s) == unsafe.StringData(s2.s) {
t.Fatalf("unexpected two heapStr ptr equal")
}
if s1.s != s2.s {
t.Fatalf("unexpected two heapStr value not equal")
}
testWriteComparable(t, s1, s2)
testWriteComparable(t, s1.s, s2.s)
testWriteComparable(t, float32(0), negativeZero[float32]())
testWriteComparable(t, float64(0), negativeZero[float64]())
testWriteComparableNoEqual(t, math.NaN(), math.NaN())
testWriteComparableNoEqual(t, [2]string{"a", ""}, [2]string{"", "a"})
testWriteComparableNoEqual(t, struct{ a, b string }{"foo", ""}, struct{ a, b string }{"", "foo"})
testWriteComparableNoEqual(t, struct{ a, b any }{int(0), struct{}{}}, struct{ a, b any }{struct{}{}, int(0)})
}
func testWriteComparableNoEqual[T comparable](t *testing.T, v1, v2 T) {
seed := MakeSeed()
h1 := Hash{}
h2 := Hash{}
*(*Seed)(unsafe.Pointer(&h1)), *(*Seed)(unsafe.Pointer(&h2)) = seed, seed
WriteComparable(&h1, v1)
WriteComparable(&h2, v2)
if h1.Sum64() == h2.Sum64() {
t.Fatalf("WriteComparable(seed, %v) == WriteComparable(seed, %v)", v1, v2)
}
}
func testWriteComparable[T comparable](t *testing.T, v T, v2 ...T) {
t.Run(TypeFor[T]().String(), func(t *testing.T) {
var a, b T = v, v
if len(v2) != 0 {
b = v2[0]
}
var pa *T = &a
h1 := Hash{}
h2 := Hash{}
*(*Seed)(unsafe.Pointer(&h1)) = MakeSeed()
h2 = h1
WriteComparable(&h1, a)
WriteComparable(&h2, b)
if h1.Sum64() != h2.Sum64() {
t.Fatalf("WriteComparable(h, %v) != WriteComparable(h, %v)", a, b)
}
WriteComparable(&h1, pa)
old := h1.Sum64()
stackGrow(8192)
WriteComparable(&h2, pa)
new := h2.Sum64()
if old != new {
t.Fatal("WriteComparable(seed, ptr) != WriteComparable(seed, ptr)")
}
})
}
func TestComparableShouldPanic(t *testing.T) {
s := []byte("s")
a := any(s)
defer func() {
e := recover()
err, ok := e.(error)
if !ok {
t.Fatalf("Comaparable(any([]byte)) should panic")
}
want := "hash of unhashable type []uint8"
if s := err.Error(); !strings.Contains(s, want) {
t.Fatalf("want %s, got %s", want, s)
}
}()
Comparable(MakeSeed(), a)
}
func TestWriteComparableNoncommute(t *testing.T) {
seed := MakeSeed()
var h1, h2 Hash
h1.SetSeed(seed)
h2.SetSeed(seed)
h1.WriteString("abc")
WriteComparable(&h1, 123)
WriteComparable(&h2, 123)
h2.WriteString("abc")
if h1.Sum64() == h2.Sum64() {
t.Errorf("WriteComparable and WriteString unexpectedly commute")
}
}
func TestComparableAllocations(t *testing.T) {
if !testComparableAllocations {
t.Skip("test broken in old golang version")
}
seed := MakeSeed()
x := heapStr(t)
allocs := testing.AllocsPerRun(10, func() {
s := "s" + x
Comparable(seed, s)
})
if allocs > 0 {
t.Errorf("got %v allocs, want 0", allocs)
}
type S struct {
a int
b string
}
allocs = testing.AllocsPerRun(10, func() {
s := S{123, "s" + x}
Comparable(seed, s)
})
if allocs > 0 {
t.Errorf("got %v allocs, want 0", allocs)
}
}
// Make sure a Hash implements the hash.Hash and hash.Hash64 interfaces.
var _ hash.Hash = &Hash{}
var _ hash.Hash64 = &Hash{}
func benchmarkSize(b *testing.B, size int) {
h := &Hash{}
buf := make([]byte, size)
s := string(buf)
b.Run("Write", func(b *testing.B) {
b.SetBytes(int64(size))
for i := 0; i < b.N; i++ {
h.Reset()
h.Write(buf)
h.Sum64()
}
})
b.Run("Bytes", func(b *testing.B) {
b.SetBytes(int64(size))
seed := h.Seed()
for i := 0; i < b.N; i++ {
Bytes(seed, buf)
}
})
b.Run("String", func(b *testing.B) {
b.SetBytes(int64(size))
seed := h.Seed()
for i := 0; i < b.N; i++ {
String(seed, s)
}
})
}
func BenchmarkHash(b *testing.B) {
sizes := []int{4, 8, 16, 32, 64, 256, 320, 1024, 4096, 16384}
for _, size := range sizes {
b.Run(fmt.Sprint("n=", size), func(b *testing.B) {
benchmarkSize(b, size)
})
}
}
func benchmarkComparable[T comparable](b *testing.B, v T) {
b.Run(TypeFor[T]().String(), func(b *testing.B) {
seed := MakeSeed()
for i := 0; i < b.N; i++ {
Comparable(seed, v)
}
})
}
func BenchmarkComparable(b *testing.B) {
type testStruct struct {
i int
u uint
b bool
f float64
p *int
a any
}
benchmarkComparable(b, int64(2))
benchmarkComparable(b, uint64(8))
benchmarkComparable(b, uintptr(12))
benchmarkComparable(b, any("s"))
benchmarkComparable(b, "s")
benchmarkComparable(b, true)
benchmarkComparable(b, new(float64))
benchmarkComparable(b, float64(9))
benchmarkComparable(b, complex128(9i+1))
benchmarkComparable(b, struct{}{})
benchmarkComparable(b, testStruct{i: 9, u: 1, b: true, f: 9.9, p: new(int), a: 1})
}
// TypeFor returns the [Type] that represents the type argument T.
func TypeFor[T any]() reflect.Type {
var v T
if t := reflect.TypeOf(v); t != nil {
return t // optimize for T being a non-interface kind
}
return reflect.TypeOf((*T)(nil)).Elem() // only for an interface kind
}

View File

@@ -3,37 +3,29 @@ package net
import ( import (
"context" "context"
"net" "net"
"github.com/metacubex/mihomo/common/contextutils"
) )
// SetupContextForConn is a helper function that starts connection I/O interrupter. // SetupContextForConn is a helper function that starts connection I/O interrupter goroutine.
// if ctx be canceled before done called, it will close the connection.
// should use like this:
//
// func streamConn(ctx context.Context, conn net.Conn) (_ net.Conn, err error) {
// if ctx.Done() != nil {
// done := N.SetupContextForConn(ctx, conn)
// defer done(&err)
// }
// conn, err := xxx
// return conn, err
// }
func SetupContextForConn(ctx context.Context, conn net.Conn) (done func(*error)) { func SetupContextForConn(ctx context.Context, conn net.Conn) (done func(*error)) {
stopc := make(chan struct{}) var (
stop := contextutils.AfterFunc(ctx, func() { quit = make(chan struct{})
// Close the connection, discarding the error interrupt = make(chan error, 1)
_ = conn.Close() )
close(stopc) go func() {
}) select {
case <-quit:
interrupt <- nil
case <-ctx.Done():
// Close the connection, discarding the error
_ = conn.Close()
interrupt <- ctx.Err()
}
}()
return func(inputErr *error) { return func(inputErr *error) {
if !stop() { close(quit)
// The AfterFunc was started, wait for it to complete. if ctxErr := <-interrupt; ctxErr != nil && inputErr != nil {
<-stopc // Return context error to user.
if ctxErr := ctx.Err(); ctxErr != nil && inputErr != nil { inputErr = &ctxErr
// Return context error to user.
*inputErr = ctxErr
}
} }
} }
} }

View File

@@ -1,103 +0,0 @@
package net_test
import (
"context"
"errors"
"net"
"testing"
"time"
N "github.com/metacubex/mihomo/common/net"
"github.com/stretchr/testify/assert"
)
func testRead(ctx context.Context, conn net.Conn) (err error) {
if ctx.Done() != nil {
done := N.SetupContextForConn(ctx, conn)
defer done(&err)
}
_, err = conn.Read(make([]byte, 1))
return err
}
func TestSetupContextForConnWithCancel(t *testing.T) {
t.Parallel()
c1, c2 := N.Pipe()
defer c1.Close()
defer c2.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
errc := make(chan error)
go func() {
errc <- testRead(ctx, c1)
}()
select {
case <-errc:
t.Fatal("conn closed before cancel")
case <-time.After(100 * time.Millisecond):
cancel()
}
select {
case err := <-errc:
assert.ErrorIs(t, err, context.Canceled)
case <-time.After(100 * time.Millisecond):
t.Fatal("conn not be canceled")
}
}
func TestSetupContextForConnWithTimeout1(t *testing.T) {
t.Parallel()
c1, c2 := N.Pipe()
defer c1.Close()
defer c2.Close()
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
errc := make(chan error)
go func() {
errc <- testRead(ctx, c1)
}()
select {
case err := <-errc:
if !errors.Is(ctx.Err(), context.DeadlineExceeded) {
t.Fatal("conn closed before timeout")
}
assert.ErrorIs(t, err, context.DeadlineExceeded)
case <-time.After(200 * time.Millisecond):
t.Fatal("conn not be canceled")
}
}
func TestSetupContextForConnWithTimeout2(t *testing.T) {
t.Parallel()
c1, c2 := N.Pipe()
defer c1.Close()
defer c2.Close()
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
errc := make(chan error)
go func() {
errc <- testRead(ctx, c1)
}()
select {
case <-errc:
t.Fatal("conn closed before cancel")
case <-time.After(100 * time.Millisecond):
c2.Write(make([]byte, 1))
}
select {
case err := <-errc:
assert.Nil(t, ctx.Err())
assert.Nil(t, err)
case <-time.After(200 * time.Millisecond):
t.Fatal("conn not be canceled")
}
}

View File

@@ -7,9 +7,9 @@ import (
"github.com/metacubex/mihomo/common/atomic" "github.com/metacubex/mihomo/common/atomic"
"github.com/metacubex/sing/common/buf" "github.com/sagernet/sing/common/buf"
"github.com/metacubex/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
"github.com/metacubex/sing/common/network" "github.com/sagernet/sing/common/network"
) )
type connReadResult struct { type connReadResult struct {
@@ -20,7 +20,7 @@ type connReadResult struct {
type Conn struct { type Conn struct {
network.ExtendedConn network.ExtendedConn
deadline atomic.TypedValue[time.Time] deadline atomic.TypedValue[time.Time]
pipeDeadline PipeDeadline pipeDeadline pipeDeadline
disablePipe atomic.Bool disablePipe atomic.Bool
inRead atomic.Bool inRead atomic.Bool
resultCh chan *connReadResult resultCh chan *connReadResult
@@ -34,7 +34,7 @@ func IsConn(conn any) bool {
func NewConn(conn net.Conn) *Conn { func NewConn(conn net.Conn) *Conn {
c := &Conn{ c := &Conn{
ExtendedConn: bufio.NewExtendedConn(conn), ExtendedConn: bufio.NewExtendedConn(conn),
pipeDeadline: MakePipeDeadline(), pipeDeadline: makePipeDeadline(),
resultCh: make(chan *connReadResult, 1), resultCh: make(chan *connReadResult, 1),
} }
c.resultCh <- nil c.resultCh <- nil
@@ -58,7 +58,7 @@ func (c *Conn) Read(p []byte) (n int, err error) {
c.resultCh <- nil c.resultCh <- nil
break break
} }
case <-c.pipeDeadline.Wait(): case <-c.pipeDeadline.wait():
return 0, os.ErrDeadlineExceeded return 0, os.ErrDeadlineExceeded
} }
@@ -104,7 +104,7 @@ func (c *Conn) ReadBuffer(buffer *buf.Buffer) (err error) {
c.resultCh <- nil c.resultCh <- nil
break break
} }
case <-c.pipeDeadline.Wait(): case <-c.pipeDeadline.wait():
return os.ErrDeadlineExceeded return os.ErrDeadlineExceeded
} }
@@ -130,7 +130,7 @@ func (c *Conn) SetReadDeadline(t time.Time) error {
return c.ExtendedConn.SetReadDeadline(t) return c.ExtendedConn.SetReadDeadline(t)
} }
c.deadline.Store(t) c.deadline.Store(t)
c.pipeDeadline.Set(t) c.pipeDeadline.set(t)
return nil return nil
} }
@@ -149,10 +149,6 @@ func (c *Conn) ReaderReplaceable() bool {
return c.disablePipe.Load() || c.deadline.Load().IsZero() return c.disablePipe.Load() || c.deadline.Load().IsZero()
} }
func (c *Conn) WriterReplaceable() bool {
return true
}
func (c *Conn) Upstream() any { func (c *Conn) Upstream() any {
return c.ExtendedConn return c.ExtendedConn
} }

View File

@@ -19,7 +19,7 @@ type readResult struct {
type NetPacketConn struct { type NetPacketConn struct {
net.PacketConn net.PacketConn
deadline atomic.TypedValue[time.Time] deadline atomic.TypedValue[time.Time]
pipeDeadline PipeDeadline pipeDeadline pipeDeadline
disablePipe atomic.Bool disablePipe atomic.Bool
inRead atomic.Bool inRead atomic.Bool
resultCh chan any resultCh chan any
@@ -28,7 +28,7 @@ type NetPacketConn struct {
func NewNetPacketConn(pc net.PacketConn) net.PacketConn { func NewNetPacketConn(pc net.PacketConn) net.PacketConn {
npc := &NetPacketConn{ npc := &NetPacketConn{
PacketConn: pc, PacketConn: pc,
pipeDeadline: MakePipeDeadline(), pipeDeadline: makePipeDeadline(),
resultCh: make(chan any, 1), resultCh: make(chan any, 1),
} }
npc.resultCh <- nil npc.resultCh <- nil
@@ -83,7 +83,7 @@ FOR:
c.resultCh <- nil c.resultCh <- nil
break FOR break FOR
} }
case <-c.pipeDeadline.Wait(): case <-c.pipeDeadline.wait():
return 0, nil, os.ErrDeadlineExceeded return 0, nil, os.ErrDeadlineExceeded
} }
} }
@@ -122,7 +122,7 @@ func (c *NetPacketConn) SetReadDeadline(t time.Time) error {
return c.PacketConn.SetReadDeadline(t) return c.PacketConn.SetReadDeadline(t)
} }
c.deadline.Store(t) c.deadline.Store(t)
c.pipeDeadline.Set(t) c.pipeDeadline.set(t)
return nil return nil
} }

View File

@@ -52,7 +52,7 @@ FOR:
c.netPacketConn.resultCh <- nil c.netPacketConn.resultCh <- nil
break FOR break FOR
} }
case <-c.netPacketConn.pipeDeadline.Wait(): case <-c.netPacketConn.pipeDeadline.wait():
return nil, nil, nil, os.ErrDeadlineExceeded return nil, nil, nil, os.ErrDeadlineExceeded
} }
} }

View File

@@ -6,10 +6,10 @@ import (
"github.com/metacubex/mihomo/common/net/packet" "github.com/metacubex/mihomo/common/net/packet"
"github.com/metacubex/sing/common/buf" "github.com/sagernet/sing/common/buf"
"github.com/metacubex/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
M "github.com/metacubex/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/metacubex/sing/common/network" N "github.com/sagernet/sing/common/network"
) )
type SingPacketConn struct { type SingPacketConn struct {
@@ -69,7 +69,7 @@ FOR:
c.netPacketConn.resultCh <- nil c.netPacketConn.resultCh <- nil
break FOR break FOR
} }
case <-c.netPacketConn.pipeDeadline.Wait(): case <-c.netPacketConn.pipeDeadline.wait():
return M.Socksaddr{}, os.ErrDeadlineExceeded return M.Socksaddr{}, os.ErrDeadlineExceeded
} }
} }
@@ -146,7 +146,7 @@ FOR:
c.netPacketConn.resultCh <- nil c.netPacketConn.resultCh <- nil
break FOR break FOR
} }
case <-c.netPacketConn.pipeDeadline.Wait(): case <-c.netPacketConn.pipeDeadline.wait():
return nil, M.Socksaddr{}, os.ErrDeadlineExceeded return nil, M.Socksaddr{}, os.ErrDeadlineExceeded
} }
} }

Some files were not shown because too many files have changed in this diff Show More