From af19ba6119a4d8cb5f973a789312761de0419bb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 24 Jul 2022 14:05:06 +0800 Subject: [PATCH] Add disable_cache option to dns rule --- adapter/router.go | 5 +++ option/dns.go | 18 ++++++--- option/route.go | 2 + route/router.go | 24 +++++++----- route/rule.go | 66 ++++++++++++++++++--------------- route/rule_dns.go | 94 ++++++++++++++++++++++++++++------------------- 6 files changed, 127 insertions(+), 82 deletions(-) diff --git a/adapter/router.go b/adapter/router.go index c42d4e0f..f18faf11 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -50,3 +50,8 @@ type Rule interface { Outbound() string String() string } + +type DNSRule interface { + Rule + DisableCache() bool +} diff --git a/option/dns.go b/option/dns.go index d7b9c6b0..acca67b0 100644 --- a/option/dns.go +++ b/option/dns.go @@ -55,6 +55,7 @@ func (r DNSRule) MarshalJSON() ([]byte, error) { var v any switch r.Type { case C.RuleTypeDefault: + r.Type = "" v = r.DefaultOptions case C.RuleTypeLogical: v = r.LogicalOptions @@ -109,6 +110,7 @@ type DefaultDNSRule struct { Outbound Listable[string] `json:"outbound,omitempty"` Invert bool `json:"invert,omitempty"` Server string `json:"server,omitempty"` + DisableCache bool `json:"disable_cache,omitempty"` } func (r DefaultDNSRule) IsValid() bool { @@ -135,13 +137,17 @@ func (r DefaultDNSRule) Equals(other DefaultDNSRule) bool { common.ComparableSliceEquals(r.UserID, other.UserID) && common.ComparableSliceEquals(r.PackageName, other.PackageName) && common.ComparableSliceEquals(r.Outbound, other.Outbound) && - r.Server == other.Server + r.Invert == other.Invert && + r.Server == other.Server && + r.DisableCache == other.DisableCache } type LogicalDNSRule struct { - Mode string `json:"mode"` - Rules []DefaultDNSRule `json:"rules,omitempty"` - Server string `json:"server,omitempty"` + Mode string `json:"mode"` + Rules []DefaultDNSRule `json:"rules,omitempty"` + Invert bool `json:"invert,omitempty"` + Server string `json:"server,omitempty"` + DisableCache bool `json:"disable_cache,omitempty"` } func (r LogicalDNSRule) IsValid() bool { @@ -151,5 +157,7 @@ func (r LogicalDNSRule) IsValid() bool { func (r LogicalDNSRule) Equals(other LogicalDNSRule) bool { return r.Mode == other.Mode && common.SliceEquals(r.Rules, other.Rules) && - r.Server == other.Server + r.Invert == other.Invert && + r.Server == other.Server && + r.DisableCache == other.DisableCache } diff --git a/option/route.go b/option/route.go index 6a71fdaa..4083fb1f 100644 --- a/option/route.go +++ b/option/route.go @@ -145,6 +145,7 @@ func (r DefaultRule) Equals(other DefaultRule) bool { type LogicalRule struct { Mode string `json:"mode"` Rules []DefaultRule `json:"rules,omitempty"` + Invert bool `json:"invert,omitempty"` Outbound string `json:"outbound,omitempty"` } @@ -155,5 +156,6 @@ func (r LogicalRule) IsValid() bool { func (r LogicalRule) Equals(other LogicalRule) bool { return r.Mode == other.Mode && common.SliceEquals(r.Rules, other.Rules) && + r.Invert == other.Invert && r.Outbound == other.Outbound } diff --git a/route/router.go b/route/router.go index b4ca61b2..634e8cf0 100644 --- a/route/router.go +++ b/route/router.go @@ -59,7 +59,7 @@ type Router struct { geositeCache map[string]adapter.Rule dnsClient *dns.Client defaultDomainStrategy dns.DomainStrategy - dnsRules []adapter.Rule + dnsRules []adapter.DNSRule defaultTransport dns.Transport transports []dns.Transport transportMap map[string]dns.Transport @@ -80,7 +80,7 @@ func NewRouter(ctx context.Context, logger log.ContextLogger, dnsLogger log.Cont dnsLogger: dnsLogger, outboundByTag: make(map[string]adapter.Outbound), rules: make([]adapter.Rule, 0, len(options.Rules)), - dnsRules: make([]adapter.Rule, 0, len(dnsOptions.Rules)), + dnsRules: make([]adapter.DNSRule, 0, len(dnsOptions.Rules)), needGeoIPDatabase: hasRule(options.Rules, isGeoIPRule) || hasDNSRule(dnsOptions.Rules, isGeoIPDNSRule), needGeositeDatabase: hasRule(options.Rules, isGeositeRule) || hasDNSRule(dnsOptions.Rules, isGeositeDNSRule), geoIPOptions: common.PtrValueOrDefault(options.GeoIP), @@ -536,15 +536,18 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m } func (r *Router) Exchange(ctx context.Context, message *dnsmessage.Message) (*dnsmessage.Message, error) { - return r.dnsClient.Exchange(ctx, r.matchDNS(ctx), message) + ctx, transport := r.matchDNS(ctx) + return r.dnsClient.Exchange(ctx, transport, message) } func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) { - return r.dnsClient.Lookup(ctx, r.matchDNS(ctx), domain, strategy) + ctx, transport := r.matchDNS(ctx) + return r.dnsClient.Lookup(ctx, transport, domain, strategy) } func (r *Router) LookupDefault(ctx context.Context, domain string) ([]netip.Addr, error) { - return r.dnsClient.Lookup(ctx, r.matchDNS(ctx), domain, r.defaultDomainStrategy) + ctx, transport := r.matchDNS(ctx) + return r.dnsClient.Lookup(ctx, transport, domain, r.defaultDomainStrategy) } func (r *Router) match(ctx context.Context, metadata *adapter.InboundContext, defaultOutbound adapter.Outbound) (adapter.Rule, adapter.Outbound) { @@ -586,23 +589,26 @@ func (r *Router) match(ctx context.Context, metadata *adapter.InboundContext, de return nil, defaultOutbound } -func (r *Router) matchDNS(ctx context.Context) dns.Transport { +func (r *Router) matchDNS(ctx context.Context) (context.Context, dns.Transport) { metadata := adapter.ContextFrom(ctx) if metadata == nil { r.dnsLogger.WarnContext(ctx, "no context: ", reflect.TypeOf(ctx)) - return r.defaultTransport + return ctx, r.defaultTransport } for i, rule := range r.dnsRules { if rule.Match(metadata) { + if rule.DisableCache() { + ctx = dns.ContextWithDisableCache(ctx, true) + } detour := rule.Outbound() r.dnsLogger.DebugContext(ctx, "match[", i, "] ", rule.String(), " => ", detour) if transport, loaded := r.transportMap[detour]; loaded { - return transport + return ctx, transport } r.dnsLogger.ErrorContext(ctx, "transport not found: ", detour) } } - return r.defaultTransport + return ctx, r.defaultTransport } func (r *Router) InterfaceBindManager() control.BindManager { diff --git a/route/rule.go b/route/rule.go index 96936697..c4148478 100644 --- a/route/rule.go +++ b/route/rule.go @@ -49,10 +49,6 @@ type DefaultRule struct { outbound string } -func (r *DefaultRule) Type() string { - return C.RuleTypeDefault -} - type RuleItem interface { Match(metadata *adapter.InboundContext) bool String() string @@ -180,6 +176,10 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt return rule, nil } +func (r *DefaultRule) Type() string { + return C.RuleTypeDefault +} + func (r *DefaultRule) Start() error { for _, item := range r.allItems { err := common.Start(item) @@ -261,9 +261,34 @@ var _ adapter.Rule = (*LogicalRule)(nil) type LogicalRule struct { mode string rules []*DefaultRule + invert bool outbound string } +func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) { + r := &LogicalRule{ + rules: make([]*DefaultRule, len(options.Rules)), + invert: options.Invert, + outbound: options.Outbound, + } + switch options.Mode { + case C.LogicalTypeAnd: + r.mode = C.LogicalTypeAnd + case C.LogicalTypeOr: + r.mode = C.LogicalTypeOr + default: + return nil, E.New("unknown logical mode: ", options.Mode) + } + for i, subRule := range options.Rules { + rule, err := NewDefaultRule(router, logger, subRule) + if err != nil { + return nil, E.Cause(err, "sub rule[", i, "]") + } + r.rules[i] = rule + } + return r, nil +} + func (r *LogicalRule) Type() string { return C.RuleTypeLogical } @@ -298,38 +323,15 @@ func (r *LogicalRule) Close() error { return nil } -func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) { - r := &LogicalRule{ - rules: make([]*DefaultRule, len(options.Rules)), - outbound: options.Outbound, - } - switch options.Mode { - case C.LogicalTypeAnd: - r.mode = C.LogicalTypeAnd - case C.LogicalTypeOr: - r.mode = C.LogicalTypeOr - default: - return nil, E.New("unknown logical mode: ", options.Mode) - } - for i, subRule := range options.Rules { - rule, err := NewDefaultRule(router, logger, subRule) - if err != nil { - return nil, E.Cause(err, "sub rule[", i, "]") - } - r.rules[i] = rule - } - return r, nil -} - func (r *LogicalRule) Match(metadata *adapter.InboundContext) bool { if r.mode == C.LogicalTypeAnd { return common.All(r.rules, func(it *DefaultRule) bool { return it.Match(metadata) - }) + }) != r.invert } else { return common.Any(r.rules, func(it *DefaultRule) bool { return it.Match(metadata) - }) + }) != r.invert } } @@ -345,5 +347,9 @@ func (r *LogicalRule) String() string { case C.LogicalTypeOr: op = "||" } - return "logical(" + strings.Join(F.MapToString(r.rules), " "+op+" ") + ")" + if !r.invert { + return strings.Join(F.MapToString(r.rules), " "+op+" ") + } else { + return "!(" + strings.Join(F.MapToString(r.rules), " "+op+" ") + ")" + } } diff --git a/route/rule_dns.go b/route/rule_dns.go index 8814cab1..8fd0e2ac 100644 --- a/route/rule_dns.go +++ b/route/rule_dns.go @@ -12,7 +12,7 @@ import ( F "github.com/sagernet/sing/common/format" ) -func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option.DNSRule) (adapter.Rule, error) { +func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option.DNSRule) (adapter.DNSRule, error) { if common.IsEmptyByEquals(options) { return nil, E.New("empty rule config") } @@ -38,7 +38,7 @@ func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option. } } -var _ adapter.Rule = (*DefaultDNSRule)(nil) +var _ adapter.DNSRule = (*DefaultDNSRule)(nil) type DefaultDNSRule struct { items []RuleItem @@ -46,16 +46,14 @@ type DefaultDNSRule struct { allItems []RuleItem invert bool outbound string -} - -func (r *DefaultDNSRule) Type() string { - return C.RuleTypeDefault + disableCache bool } func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options option.DefaultDNSRule) (*DefaultDNSRule, error) { rule := &DefaultDNSRule{ - invert: true, - outbound: options.Server, + invert: options.Invert, + outbound: options.Server, + disableCache: options.DisableCache, } if len(options.Inbound) > 0 { item := NewInboundRule(options.Inbound) @@ -156,6 +154,10 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options return rule, nil } +func (r *DefaultDNSRule) Type() string { + return C.RuleTypeDefault +} + func (r *DefaultDNSRule) Start() error { for _, item := range r.allItems { err := common.Start(item) @@ -213,16 +215,47 @@ func (r *DefaultDNSRule) Outbound() string { return r.outbound } +func (r *DefaultDNSRule) DisableCache() bool { + return r.disableCache +} + func (r *DefaultDNSRule) String() string { return strings.Join(F.MapToString(r.allItems), " ") } -var _ adapter.Rule = (*LogicalRule)(nil) +var _ adapter.DNSRule = (*LogicalDNSRule)(nil) type LogicalDNSRule struct { - mode string - rules []*DefaultDNSRule - outbound string + mode string + rules []*DefaultDNSRule + invert bool + outbound string + disableCache bool +} + +func NewLogicalDNSRule(router adapter.Router, logger log.ContextLogger, options option.LogicalDNSRule) (*LogicalDNSRule, error) { + r := &LogicalDNSRule{ + rules: make([]*DefaultDNSRule, len(options.Rules)), + invert: options.Invert, + outbound: options.Server, + disableCache: options.DisableCache, + } + switch options.Mode { + case C.LogicalTypeAnd: + r.mode = C.LogicalTypeAnd + case C.LogicalTypeOr: + r.mode = C.LogicalTypeOr + default: + return nil, E.New("unknown logical mode: ", options.Mode) + } + for i, subRule := range options.Rules { + rule, err := NewDefaultDNSRule(router, logger, subRule) + if err != nil { + return nil, E.Cause(err, "sub rule[", i, "]") + } + r.rules[i] = rule + } + return r, nil } func (r *LogicalDNSRule) Type() string { @@ -259,38 +292,15 @@ func (r *LogicalDNSRule) Close() error { return nil } -func NewLogicalDNSRule(router adapter.Router, logger log.ContextLogger, options option.LogicalDNSRule) (*LogicalDNSRule, error) { - r := &LogicalDNSRule{ - rules: make([]*DefaultDNSRule, len(options.Rules)), - outbound: options.Server, - } - switch options.Mode { - case C.LogicalTypeAnd: - r.mode = C.LogicalTypeAnd - case C.LogicalTypeOr: - r.mode = C.LogicalTypeOr - default: - return nil, E.New("unknown logical mode: ", options.Mode) - } - for i, subRule := range options.Rules { - rule, err := NewDefaultDNSRule(router, logger, subRule) - if err != nil { - return nil, E.Cause(err, "sub rule[", i, "]") - } - r.rules[i] = rule - } - return r, nil -} - func (r *LogicalDNSRule) Match(metadata *adapter.InboundContext) bool { if r.mode == C.LogicalTypeAnd { return common.All(r.rules, func(it *DefaultDNSRule) bool { return it.Match(metadata) - }) + }) != r.invert } else { return common.Any(r.rules, func(it *DefaultDNSRule) bool { return it.Match(metadata) - }) + }) != r.invert } } @@ -298,6 +308,10 @@ func (r *LogicalDNSRule) Outbound() string { return r.outbound } +func (r *LogicalDNSRule) DisableCache() bool { + return r.disableCache +} + func (r *LogicalDNSRule) String() string { var op string switch r.mode { @@ -306,5 +320,9 @@ func (r *LogicalDNSRule) String() string { case C.LogicalTypeOr: op = "||" } - return "logical(" + strings.Join(F.MapToString(r.rules), " "+op+" ") + ")" + if !r.invert { + return strings.Join(F.MapToString(r.rules), " "+op+" ") + } else { + return "!(" + strings.Join(F.MapToString(r.rules), " "+op+" ") + ")" + } }