chore: cleanup dns policy match code

This commit is contained in:
wwqgtxx
2024-08-15 20:04:24 +08:00
parent 4c10d42fbf
commit 92ec5f2236
9 changed files with 241 additions and 420 deletions

View File

@@ -1,102 +0,0 @@
package dns
import (
"net/netip"
"strings"
"github.com/metacubex/mihomo/component/geodata"
"github.com/metacubex/mihomo/component/geodata/router"
"github.com/metacubex/mihomo/component/mmdb"
"github.com/metacubex/mihomo/component/trie"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/log"
)
type fallbackIPFilter interface {
Match(netip.Addr) bool
}
type geoipFilter struct {
code string
}
var geoIPMatcher *router.GeoIPMatcher
func (gf *geoipFilter) Match(ip netip.Addr) bool {
if !C.GeodataMode {
codes := mmdb.IPInstance().LookupCode(ip.AsSlice())
for _, code := range codes {
if !strings.EqualFold(code, gf.code) && !ip.IsPrivate() {
return true
}
}
return false
}
if geoIPMatcher == nil {
var err error
geoIPMatcher, _, err = geodata.LoadGeoIPMatcher("CN")
if err != nil {
log.Errorln("[GeoIPFilter] LoadGeoIPMatcher error: %s", err.Error())
return false
}
}
return !geoIPMatcher.Match(ip)
}
type ipnetFilter struct {
ipnet netip.Prefix
}
func (inf *ipnetFilter) Match(ip netip.Addr) bool {
return inf.ipnet.Contains(ip)
}
type fallbackDomainFilter interface {
Match(domain string) bool
}
type domainFilter struct {
tree *trie.DomainTrie[struct{}]
}
func NewDomainFilter(domains []string) *domainFilter {
df := domainFilter{tree: trie.New[struct{}]()}
for _, domain := range domains {
_ = df.tree.Insert(domain, struct{}{})
}
df.tree.Optimize()
return &df
}
func (df *domainFilter) Match(domain string) bool {
return df.tree.Search(domain) != nil
}
type geoSiteFilter struct {
matchers []router.DomainMatcher
}
func NewGeoSite(group string) (fallbackDomainFilter, error) {
if err := geodata.InitGeoSite(); err != nil {
log.Errorln("can't initial GeoSite: %s", err)
return nil, err
}
matcher, _, err := geodata.LoadGeoSiteMatcher(group)
if err != nil {
return nil, err
}
filter := &geoSiteFilter{
matchers: []router.DomainMatcher{matcher},
}
return filter, nil
}
func (gsf *geoSiteFilter) Match(domain string) bool {
for _, matcher := range gsf.matchers {
if matcher.ApplyDomain(domain) {
return true
}
}
return false
}

View File

@@ -3,7 +3,6 @@ package dns
import (
"github.com/metacubex/mihomo/component/trie"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/constant/provider"
)
type dnsPolicy interface {
@@ -22,32 +21,14 @@ func (p domainTriePolicy) Match(domain string) []dnsClient {
return nil
}
type geositePolicy struct {
matcher fallbackDomainFilter
inverse bool
type domainRulePolicy struct {
rule C.Rule
dnsClients []dnsClient
}
func (p geositePolicy) Match(domain string) []dnsClient {
matched := p.matcher.Match(domain)
if matched != p.inverse {
func (p domainRulePolicy) Match(domain string) []dnsClient {
if ok, _ := p.rule.Match(&C.Metadata{Host: domain}); ok {
return p.dnsClients
}
return nil
}
type domainSetPolicy struct {
tunnel provider.Tunnel
name string
dnsClients []dnsClient
}
func (p domainSetPolicy) Match(domain string) []dnsClient {
if ruleProvider, ok := p.tunnel.RuleProviders()[p.name]; ok {
metadata := &C.Metadata{Host: domain}
if ok := ruleProvider.Match(metadata); ok {
return p.dnsClients
}
}
return nil
}

View File

@@ -4,13 +4,11 @@ import (
"context"
"errors"
"net/netip"
"strings"
"time"
"github.com/metacubex/mihomo/common/arc"
"github.com/metacubex/mihomo/common/lru"
"github.com/metacubex/mihomo/component/fakeip"
"github.com/metacubex/mihomo/component/geodata/router"
"github.com/metacubex/mihomo/component/resolver"
"github.com/metacubex/mihomo/component/trie"
C "github.com/metacubex/mihomo/constant"
@@ -19,7 +17,6 @@ import (
D "github.com/miekg/dns"
"github.com/samber/lo"
orderedmap "github.com/wk8/go-ordered-map/v2"
"golang.org/x/exp/maps"
"golang.org/x/sync/singleflight"
)
@@ -45,8 +42,8 @@ type Resolver struct {
hosts *trie.DomainTrie[resolver.HostValue]
main []dnsClient
fallback []dnsClient
fallbackDomainFilters []fallbackDomainFilter
fallbackIPFilters []fallbackIPFilter
fallbackDomainFilters []C.Rule
fallbackIPFilters []C.Rule
group singleflight.Group
cache dnsCache
policy []dnsPolicy
@@ -122,7 +119,7 @@ func (r *Resolver) LookupIPv6(ctx context.Context, host string) ([]netip.Addr, e
func (r *Resolver) shouldIPFallback(ip netip.Addr) bool {
for _, filter := range r.fallbackIPFilters {
if filter.Match(ip) {
if ok, _ := filter.Match(&C.Metadata{DstIP: ip}); ok {
return true
}
}
@@ -277,7 +274,7 @@ func (r *Resolver) shouldOnlyQueryFallback(m *D.Msg) bool {
}
for _, df := range r.fallbackDomainFilters {
if df.Match(domain) {
if ok, _ := df.Match(&C.Metadata{Host: domain}); ok {
return true
}
}
@@ -398,27 +395,26 @@ func (ns NameServer) Equal(ns2 NameServer) bool {
return false
}
type FallbackFilter struct {
GeoIP bool
GeoIPCode string
IPCIDR []netip.Prefix
Domain []string
GeoSite []router.DomainMatcher
type Policy struct {
Domain string
Rule C.Rule
NameServers []NameServer
}
type Config struct {
Main, Fallback []NameServer
Default []NameServer
ProxyServer []NameServer
IPv6 bool
IPv6Timeout uint
EnhancedMode C.DNSMode
FallbackFilter FallbackFilter
Pool *fakeip.Pool
Hosts *trie.DomainTrie[resolver.HostValue]
Policy *orderedmap.OrderedMap[string, []NameServer]
Tunnel provider.Tunnel
CacheAlgorithm string
Main, Fallback []NameServer
Default []NameServer
ProxyServer []NameServer
IPv6 bool
IPv6Timeout uint
EnhancedMode C.DNSMode
FallbackIPFilter []C.Rule
FallbackDomainFilter []C.Rule
Pool *fakeip.Pool
Hosts *trie.DomainTrie[resolver.HostValue]
Policy []Policy
Tunnel provider.Tunnel
CacheAlgorithm string
}
func NewResolver(config Config) *Resolver {
@@ -482,7 +478,7 @@ func NewResolver(config Config) *Resolver {
r.proxyServer = cacheTransform(config.ProxyServer)
}
if config.Policy.Len() != 0 {
if len(config.Policy) != 0 {
r.policy = make([]dnsPolicy, 0)
var triePolicy *trie.DomainTrie[[]dnsClient]
@@ -497,75 +493,20 @@ func NewResolver(config Config) *Resolver {
}
}
for pair := config.Policy.Oldest(); pair != nil; pair = pair.Next() {
domain, nameserver := pair.Key, pair.Value
if temp := strings.Split(domain, ":"); len(temp) == 2 {
prefix := temp[0]
key := temp[1]
switch prefix {
case "rule-set":
if _, ok := config.Tunnel.RuleProviders()[key]; ok {
log.Debugln("Adding rule-set policy: %s ", key)
insertPolicy(domainSetPolicy{
tunnel: config.Tunnel,
name: key,
dnsClients: cacheTransform(nameserver),
})
continue
} else {
log.Warnln("Can't found ruleset policy: %s", key)
}
case "geosite":
inverse := false
if strings.HasPrefix(key, "!") {
inverse = true
key = key[1:]
}
log.Debugln("Adding geosite policy: %s inversed %t", key, inverse)
matcher, err := NewGeoSite(key)
if err != nil {
log.Warnln("adding geosite policy %s error: %s", key, err)
continue
}
insertPolicy(geositePolicy{
matcher: matcher,
inverse: inverse,
dnsClients: cacheTransform(nameserver),
})
continue // skip triePolicy new
for _, policy := range config.Policy {
if policy.Rule != nil {
insertPolicy(domainRulePolicy{rule: policy.Rule, dnsClients: cacheTransform(policy.NameServers)})
} else {
if triePolicy == nil {
triePolicy = trie.New[[]dnsClient]()
}
_ = triePolicy.Insert(policy.Domain, cacheTransform(policy.NameServers))
}
if triePolicy == nil {
triePolicy = trie.New[[]dnsClient]()
}
_ = triePolicy.Insert(domain, cacheTransform(nameserver))
}
insertPolicy(nil)
}
fallbackIPFilters := []fallbackIPFilter{}
if config.FallbackFilter.GeoIP {
fallbackIPFilters = append(fallbackIPFilters, &geoipFilter{
code: config.FallbackFilter.GeoIPCode,
})
}
for _, ipnet := range config.FallbackFilter.IPCIDR {
fallbackIPFilters = append(fallbackIPFilters, &ipnetFilter{ipnet: ipnet})
}
r.fallbackIPFilters = fallbackIPFilters
fallbackDomainFilters := []fallbackDomainFilter{}
if len(config.FallbackFilter.Domain) != 0 {
fallbackDomainFilters = append(fallbackDomainFilters, NewDomainFilter(config.FallbackFilter.Domain))
}
if len(config.FallbackFilter.GeoSite) != 0 {
fallbackDomainFilters = append(fallbackDomainFilters, &geoSiteFilter{
matchers: config.FallbackFilter.GeoSite,
})
}
r.fallbackDomainFilters = fallbackDomainFilters
r.fallbackIPFilters = config.FallbackIPFilter
r.fallbackDomainFilters = config.FallbackDomainFilter
return r
}