Add address filter support for DNS rules

This commit is contained in:
世界 2024-02-03 17:45:27 +08:00
parent 748ff7d915
commit 98050e3b34
No known key found for this signature in database
GPG Key ID: CD109927C34A63C4
14 changed files with 283 additions and 70 deletions

View File

@ -56,6 +56,8 @@ type InboundContext struct {
SourcePortMatch bool SourcePortMatch bool
DestinationAddressMatch bool DestinationAddressMatch bool
DestinationPortMatch bool DestinationPortMatch bool
DidMatch bool
IgnoreDestinationIPCIDRMatch bool
} }
func (c *InboundContext) ResetRuleCache() { func (c *InboundContext) ResetRuleCache() {

View File

@ -84,6 +84,8 @@ type DNSRule interface {
Rule Rule
DisableCache() bool DisableCache() bool
RewriteTTL() *uint32 RewriteTTL() *uint32
WithAddressLimit() bool
MatchAddressLimit(metadata *InboundContext) bool
} }
type RuleSet interface { type RuleSet interface {
@ -97,6 +99,7 @@ type RuleSet interface {
type RuleSetMetadata struct { type RuleSetMetadata struct {
ContainsProcessRule bool ContainsProcessRule bool
ContainsWIFIRule bool ContainsWIFIRule bool
ContainsIPCIDRRule bool
} }
type RuleSetStartContext interface { type RuleSetStartContext interface {

2
go.mod
View File

@ -27,7 +27,7 @@ require (
github.com/sagernet/quic-go v0.40.1 github.com/sagernet/quic-go v0.40.1
github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691
github.com/sagernet/sing v0.3.1-beta.1 github.com/sagernet/sing v0.3.1-beta.1
github.com/sagernet/sing-dns v0.1.12 github.com/sagernet/sing-dns v0.2.0-beta.1
github.com/sagernet/sing-mux v0.2.0 github.com/sagernet/sing-mux v0.2.0
github.com/sagernet/sing-quic v0.1.8 github.com/sagernet/sing-quic v0.1.8
github.com/sagernet/sing-shadowsocks v0.2.6 github.com/sagernet/sing-shadowsocks v0.2.6

4
go.sum
View File

@ -111,8 +111,8 @@ github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691/go.mod h1:B8lp4Wk
github.com/sagernet/sing v0.2.18/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo= github.com/sagernet/sing v0.2.18/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo=
github.com/sagernet/sing v0.3.1-beta.1 h1:5s4hUokrd58xgmiT2baXQJmZDUvjf7qzXXIqVbPTa3o= github.com/sagernet/sing v0.3.1-beta.1 h1:5s4hUokrd58xgmiT2baXQJmZDUvjf7qzXXIqVbPTa3o=
github.com/sagernet/sing v0.3.1-beta.1/go.mod h1:9pfuAH6mZfgnz/YjP6xu5sxx882rfyjpcrTdUpd6w3g= github.com/sagernet/sing v0.3.1-beta.1/go.mod h1:9pfuAH6mZfgnz/YjP6xu5sxx882rfyjpcrTdUpd6w3g=
github.com/sagernet/sing-dns v0.1.12 h1:1HqZ+ln+Rezx/aJMStaS0d7oPeX2EobSV1NT537kyj4= github.com/sagernet/sing-dns v0.2.0-beta.1 h1:Lj+zQ3Y38GJpP2vr0/5ELQbYjhPc92Lq7/JaB/GuPTw=
github.com/sagernet/sing-dns v0.1.12/go.mod h1:rx/DTOisneQpCgNQ4jbFU/JNEtnz0lYcHXenlVzpjEU= github.com/sagernet/sing-dns v0.2.0-beta.1/go.mod h1:IxOqfSb6Zt6UVCy8fJpDxb2XxqzHUytNqeOuJfaiLu8=
github.com/sagernet/sing-mux v0.2.0 h1:4C+vd8HztJCWNYfufvgL49xaOoOHXty2+EAjnzN3IYo= github.com/sagernet/sing-mux v0.2.0 h1:4C+vd8HztJCWNYfufvgL49xaOoOHXty2+EAjnzN3IYo=
github.com/sagernet/sing-mux v0.2.0/go.mod h1:khzr9AOPocLa+g53dBplwNDz4gdsyx/YM3swtAhlkHQ= github.com/sagernet/sing-mux v0.2.0/go.mod h1:khzr9AOPocLa+g53dBplwNDz4gdsyx/YM3swtAhlkHQ=
github.com/sagernet/sing-quic v0.1.8 h1:G4iBXAKIII+uTzd55oZ/9cAQswGjlvHh/0yKMQioDS0= github.com/sagernet/sing-quic v0.1.8 h1:G4iBXAKIII+uTzd55oZ/9cAQswGjlvHh/0yKMQioDS0=

View File

@ -77,6 +77,9 @@ type DefaultDNSRule struct {
DomainRegex Listable[string] `json:"domain_regex,omitempty"` DomainRegex Listable[string] `json:"domain_regex,omitempty"`
Geosite Listable[string] `json:"geosite,omitempty"` Geosite Listable[string] `json:"geosite,omitempty"`
SourceGeoIP Listable[string] `json:"source_geoip,omitempty"` SourceGeoIP Listable[string] `json:"source_geoip,omitempty"`
GeoIP Listable[string] `json:"geoip,omitempty"`
IPCIDR Listable[string] `json:"ip_cidr,omitempty"`
IPIsPrivate bool `json:"ip_is_private,omitempty"`
SourceIPCIDR Listable[string] `json:"source_ip_cidr,omitempty"` SourceIPCIDR Listable[string] `json:"source_ip_cidr,omitempty"`
SourceIPIsPrivate bool `json:"source_ip_is_private,omitempty"` SourceIPIsPrivate bool `json:"source_ip_is_private,omitempty"`
SourcePort Listable[uint16] `json:"source_port,omitempty"` SourcePort Listable[uint16] `json:"source_port,omitempty"`

View File

@ -2,13 +2,13 @@ package route
import ( import (
"context" "context"
"errors"
"net/netip" "net/netip"
"strings" "strings"
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-dns" "github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common/cache" "github.com/sagernet/sing/common/cache"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
@ -37,12 +37,17 @@ func (m *DNSReverseMapping) Query(address netip.Addr) (string, bool) {
return domain, loaded return domain, loaded
} }
func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool) (context.Context, dns.Transport, dns.DomainStrategy) { func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, index int) (context.Context, dns.Transport, dns.DomainStrategy, adapter.DNSRule, int) {
metadata := adapter.ContextFrom(ctx) metadata := adapter.ContextFrom(ctx)
if metadata == nil { if metadata == nil {
panic("no context") panic("no context")
} }
for i, rule := range r.dnsRules { if index < len(r.dnsRules) {
dnsRules := r.dnsRules
if index != -1 {
dnsRules = dnsRules[index+1:]
}
for ruleIndex, rule := range dnsRules {
metadata.ResetRuleCache() metadata.ResetRuleCache()
if rule.Match(metadata) { if rule.Match(metadata) {
detour := rule.Outbound() detour := rule.Outbound()
@ -54,7 +59,11 @@ func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool) (context.Contex
if _, isFakeIP := transport.(adapter.FakeIPTransport); isFakeIP && !allowFakeIP { if _, isFakeIP := transport.(adapter.FakeIPTransport); isFakeIP && !allowFakeIP {
continue continue
} }
r.dnsLogger.DebugContext(ctx, "match[", i, "] ", rule.String(), " => ", detour) displayRuleIndex := ruleIndex
if index != -1 {
displayRuleIndex += index + 1
}
r.dnsLogger.DebugContext(ctx, "match[", displayRuleIndex, "] ", rule.String(), " => ", detour)
if rule.DisableCache() { if rule.DisableCache() {
ctx = dns.ContextWithDisableCache(ctx, true) ctx = dns.ContextWithDisableCache(ctx, true)
} }
@ -62,16 +71,17 @@ func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool) (context.Contex
ctx = dns.ContextWithRewriteTTL(ctx, *rewriteTTL) ctx = dns.ContextWithRewriteTTL(ctx, *rewriteTTL)
} }
if domainStrategy, dsLoaded := r.transportDomainStrategy[transport]; dsLoaded { if domainStrategy, dsLoaded := r.transportDomainStrategy[transport]; dsLoaded {
return ctx, transport, domainStrategy return ctx, transport, domainStrategy, rule, ruleIndex
} else { } else {
return ctx, transport, r.defaultDomainStrategy return ctx, transport, r.defaultDomainStrategy, rule, ruleIndex
}
} }
} }
} }
if domainStrategy, dsLoaded := r.transportDomainStrategy[r.defaultTransport]; dsLoaded { if domainStrategy, dsLoaded := r.transportDomainStrategy[r.defaultTransport]; dsLoaded {
return ctx, r.defaultTransport, domainStrategy return ctx, r.defaultTransport, domainStrategy, nil, -1
} else { } else {
return ctx, r.defaultTransport, r.defaultDomainStrategy return ctx, r.defaultTransport, r.defaultDomainStrategy, nil, -1
} }
} }
@ -86,7 +96,8 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, er
) )
response, cached = r.dnsClient.ExchangeCache(ctx, message) response, cached = r.dnsClient.ExchangeCache(ctx, message)
if !cached { if !cached {
ctx, metadata := adapter.AppendContext(ctx) var metadata *adapter.InboundContext
ctx, metadata = adapter.AppendContext(ctx)
if len(message.Question) > 0 { if len(message.Question) > 0 {
metadata.QueryType = message.Question[0].Qtype metadata.QueryType = message.Question[0].Qtype
switch metadata.QueryType { switch metadata.QueryType {
@ -97,16 +108,46 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, er
} }
metadata.Domain = fqdnToDomain(message.Question[0].Name) metadata.Domain = fqdnToDomain(message.Question[0].Name)
} }
ctx, transport, strategy := r.matchDNS(ctx, true) var (
ctx, cancel := context.WithTimeout(ctx, C.DNSTimeout) transport dns.Transport
defer cancel() strategy dns.DomainStrategy
response, err = r.dnsClient.Exchange(ctx, transport, message, strategy) rule adapter.DNSRule
if err != nil && len(message.Question) > 0 { ruleIndex int
)
ruleIndex = -1
for {
var (
dnsCtx context.Context
cancel context.CancelFunc
addressLimit bool
)
dnsCtx, transport, strategy, rule, ruleIndex = r.matchDNS(ctx, true, ruleIndex)
dnsCtx, cancel = context.WithTimeout(dnsCtx, C.DNSTimeout)
if rule != nil && rule.WithAddressLimit() && isAddressQuery(message) {
addressLimit = true
response, err = r.dnsClient.ExchangeWithResponseCheck(dnsCtx, transport, message, strategy, func(response *mDNS.Msg) bool {
metadata.DestinationAddresses, _ = dns.MessageToAddresses(response)
return rule.MatchAddressLimit(metadata)
})
} else {
addressLimit = false
response, err = r.dnsClient.Exchange(dnsCtx, transport, message, strategy)
}
cancel()
if err != nil {
if errors.Is(err, dns.ErrResponseRejected) {
r.dnsLogger.DebugContext(ctx, E.Cause(err, "response rejected for ", formatQuestion(message.Question[0].String())))
} else if len(message.Question) > 0 {
r.dnsLogger.ErrorContext(ctx, E.Cause(err, "exchange failed for ", formatQuestion(message.Question[0].String()))) r.dnsLogger.ErrorContext(ctx, E.Cause(err, "exchange failed for ", formatQuestion(message.Question[0].String())))
} else {
r.dnsLogger.ErrorContext(ctx, E.Cause(err, "exchange failed for <empty query>"))
}
}
if !addressLimit || err == nil {
break
} }
} }
if len(message.Question) > 0 && response != nil {
LogDNSAnswers(r.dnsLogger, ctx, message.Question[0].Name, response.Answer)
} }
if r.dnsReverseMapping != nil && len(message.Question) > 0 && response != nil && len(response.Answer) > 0 { if r.dnsReverseMapping != nil && len(message.Question) > 0 && response != nil && len(response.Answer) > 0 {
for _, answer := range response.Answer { for _, answer := range response.Answer {
@ -125,22 +166,57 @@ func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainS
r.dnsLogger.DebugContext(ctx, "lookup domain ", domain) r.dnsLogger.DebugContext(ctx, "lookup domain ", domain)
ctx, metadata := adapter.AppendContext(ctx) ctx, metadata := adapter.AppendContext(ctx)
metadata.Domain = domain metadata.Domain = domain
ctx, transport, transportStrategy := r.matchDNS(ctx, false) var (
transport dns.Transport
transportStrategy dns.DomainStrategy
rule adapter.DNSRule
ruleIndex int
resultAddrs []netip.Addr
err error
)
ruleIndex = -1
for {
var (
dnsCtx context.Context
cancel context.CancelFunc
addressLimit bool
)
metadata.ResetRuleCache()
metadata.DestinationAddresses = nil
dnsCtx, transport, transportStrategy, rule, ruleIndex = r.matchDNS(ctx, false, ruleIndex)
if strategy == dns.DomainStrategyAsIS { if strategy == dns.DomainStrategyAsIS {
strategy = transportStrategy strategy = transportStrategy
} }
ctx, cancel := context.WithTimeout(ctx, C.DNSTimeout) dnsCtx, cancel = context.WithTimeout(dnsCtx, C.DNSTimeout)
defer cancel() if rule != nil && rule.WithAddressLimit() {
addrs, err := r.dnsClient.Lookup(ctx, transport, domain, strategy) addressLimit = true
if len(addrs) > 0 { resultAddrs, err = r.dnsClient.LookupWithResponseCheck(dnsCtx, transport, domain, strategy, func(responseAddrs []netip.Addr) bool {
r.dnsLogger.InfoContext(ctx, "lookup succeed for ", domain, ": ", strings.Join(F.MapToString(addrs), " ")) metadata.DestinationAddresses = responseAddrs
} else if err != nil { return rule.MatchAddressLimit(metadata)
r.dnsLogger.ErrorContext(ctx, E.Cause(err, "lookup failed for ", domain)) })
} else { } else {
addressLimit = false
resultAddrs, err = r.dnsClient.Lookup(dnsCtx, transport, domain, strategy)
}
cancel()
if err != nil {
if errors.Is(err, dns.ErrResponseRejected) {
r.dnsLogger.DebugContext(ctx, "response rejected for ", domain)
} else {
r.dnsLogger.ErrorContext(ctx, E.Cause(err, "lookup failed for ", domain))
}
} else if len(resultAddrs) == 0 {
r.dnsLogger.ErrorContext(ctx, "lookup failed for ", domain, ": empty result") r.dnsLogger.ErrorContext(ctx, "lookup failed for ", domain, ": empty result")
err = dns.RCodeNameError err = dns.RCodeNameError
} }
return addrs, err if !addressLimit || err == nil {
break
}
}
if len(resultAddrs) > 0 {
r.dnsLogger.InfoContext(ctx, "lookup succeed for ", domain, ": ", strings.Join(F.MapToString(resultAddrs), " "))
}
return resultAddrs, err
} }
func (r *Router) LookupDefault(ctx context.Context, domain string) ([]netip.Addr, error) { func (r *Router) LookupDefault(ctx context.Context, domain string) ([]netip.Addr, error) {
@ -154,10 +230,13 @@ func (r *Router) ClearDNSCache() {
} }
} }
func LogDNSAnswers(logger log.ContextLogger, ctx context.Context, domain string, answers []mDNS.RR) { func isAddressQuery(message *mDNS.Msg) bool {
for _, answer := range answers { for _, question := range message.Question {
logger.InfoContext(ctx, "exchanged ", domain, " ", mDNS.Type(answer.Header().Rrtype).String(), " ", formatQuestion(answer.String())) if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA {
return true
} }
}
return false
} }
func fqdnToDomain(fqdn string) string { func fqdnToDomain(fqdn string) string {

View File

@ -59,7 +59,7 @@ func isGeoIPRule(rule option.DefaultRule) bool {
} }
func isGeoIPDNSRule(rule option.DefaultDNSRule) bool { func isGeoIPDNSRule(rule option.DefaultDNSRule) bool {
return len(rule.SourceGeoIP) > 0 && common.Any(rule.SourceGeoIP, notPrivateNode) return len(rule.SourceGeoIP) > 0 && common.Any(rule.SourceGeoIP, notPrivateNode) || len(rule.GeoIP) > 0 && common.Any(rule.GeoIP, notPrivateNode)
} }
func isGeositeRule(rule option.DefaultRule) bool { func isGeositeRule(rule option.DefaultRule) bool {
@ -97,3 +97,7 @@ func isWIFIDNSRule(rule option.DefaultDNSRule) bool {
func isWIFIHeadlessRule(rule option.DefaultHeadlessRule) bool { func isWIFIHeadlessRule(rule option.DefaultHeadlessRule) bool {
return len(rule.WIFISSID) > 0 || len(rule.WIFIBSSID) > 0 return len(rule.WIFISSID) > 0 || len(rule.WIFIBSSID) > 0
} }
func isIPCIDRHeadlessRule(rule option.DefaultHeadlessRule) bool {
return len(rule.IPCIDR) > 0 || rule.IPSet != nil
}

View File

@ -15,6 +15,7 @@ type abstractDefaultRule struct {
sourceAddressItems []RuleItem sourceAddressItems []RuleItem
sourcePortItems []RuleItem sourcePortItems []RuleItem
destinationAddressItems []RuleItem destinationAddressItems []RuleItem
destinationIPCIDRItems []RuleItem
destinationPortItems []RuleItem destinationPortItems []RuleItem
allItems []RuleItem allItems []RuleItem
ruleSetItem RuleItem ruleSetItem RuleItem
@ -64,6 +65,7 @@ func (r *abstractDefaultRule) Match(metadata *adapter.InboundContext) bool {
} }
if len(r.sourceAddressItems) > 0 && !metadata.SourceAddressMatch { if len(r.sourceAddressItems) > 0 && !metadata.SourceAddressMatch {
metadata.DidMatch = true
for _, item := range r.sourceAddressItems { for _, item := range r.sourceAddressItems {
if item.Match(metadata) { if item.Match(metadata) {
metadata.SourceAddressMatch = true metadata.SourceAddressMatch = true
@ -73,6 +75,7 @@ func (r *abstractDefaultRule) Match(metadata *adapter.InboundContext) bool {
} }
if len(r.sourcePortItems) > 0 && !metadata.SourcePortMatch { if len(r.sourcePortItems) > 0 && !metadata.SourcePortMatch {
metadata.DidMatch = true
for _, item := range r.sourcePortItems { for _, item := range r.sourcePortItems {
if item.Match(metadata) { if item.Match(metadata) {
metadata.SourcePortMatch = true metadata.SourcePortMatch = true
@ -82,6 +85,7 @@ func (r *abstractDefaultRule) Match(metadata *adapter.InboundContext) bool {
} }
if len(r.destinationAddressItems) > 0 && !metadata.DestinationAddressMatch { if len(r.destinationAddressItems) > 0 && !metadata.DestinationAddressMatch {
metadata.DidMatch = true
for _, item := range r.destinationAddressItems { for _, item := range r.destinationAddressItems {
if item.Match(metadata) { if item.Match(metadata) {
metadata.DestinationAddressMatch = true metadata.DestinationAddressMatch = true
@ -90,7 +94,18 @@ func (r *abstractDefaultRule) Match(metadata *adapter.InboundContext) bool {
} }
} }
if !metadata.IgnoreDestinationIPCIDRMatch && len(r.destinationIPCIDRItems) > 0 && !metadata.DestinationAddressMatch {
metadata.DidMatch = true
for _, item := range r.destinationIPCIDRItems {
if item.Match(metadata) {
metadata.DestinationAddressMatch = true
break
}
}
}
if len(r.destinationPortItems) > 0 && !metadata.DestinationPortMatch { if len(r.destinationPortItems) > 0 && !metadata.DestinationPortMatch {
metadata.DidMatch = true
for _, item := range r.destinationPortItems { for _, item := range r.destinationPortItems {
if item.Match(metadata) { if item.Match(metadata) {
metadata.DestinationPortMatch = true metadata.DestinationPortMatch = true
@ -100,6 +115,9 @@ func (r *abstractDefaultRule) Match(metadata *adapter.InboundContext) bool {
} }
for _, item := range r.items { for _, item := range r.items {
if _, isRuleSet := item.(*RuleSetItem); !isRuleSet {
metadata.DidMatch = true
}
if !item.Match(metadata) { if !item.Match(metadata) {
return r.invert return r.invert
} }
@ -113,7 +131,7 @@ func (r *abstractDefaultRule) Match(metadata *adapter.InboundContext) bool {
return r.invert return r.invert
} }
if len(r.destinationAddressItems) > 0 && !metadata.DestinationAddressMatch { if ((!metadata.IgnoreDestinationIPCIDRMatch && len(r.destinationIPCIDRItems) > 0) || len(r.destinationAddressItems) > 0) && !metadata.DestinationAddressMatch {
return r.invert return r.invert
} }
@ -121,6 +139,10 @@ func (r *abstractDefaultRule) Match(metadata *adapter.InboundContext) bool {
return r.invert return r.invert
} }
if !metadata.DidMatch {
return false
}
return !r.invert return !r.invert
} }

View File

@ -109,7 +109,7 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt
} }
if len(options.GeoIP) > 0 { if len(options.GeoIP) > 0 {
item := NewGeoIPItem(router, logger, false, options.GeoIP) item := NewGeoIPItem(router, logger, false, options.GeoIP)
rule.destinationAddressItems = append(rule.destinationAddressItems, item) rule.destinationIPCIDRItems = append(rule.destinationIPCIDRItems, item)
rule.allItems = append(rule.allItems, item) rule.allItems = append(rule.allItems, item)
} }
if len(options.SourceIPCIDR) > 0 { if len(options.SourceIPCIDR) > 0 {
@ -130,12 +130,12 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt
if err != nil { if err != nil {
return nil, E.Cause(err, "ipcidr") return nil, E.Cause(err, "ipcidr")
} }
rule.destinationAddressItems = append(rule.destinationAddressItems, item) rule.destinationIPCIDRItems = append(rule.destinationIPCIDRItems, item)
rule.allItems = append(rule.allItems, item) rule.allItems = append(rule.allItems, item)
} }
if options.IPIsPrivate { if options.IPIsPrivate {
item := NewIPIsPrivateItem(false) item := NewIPIsPrivateItem(false)
rule.destinationAddressItems = append(rule.destinationAddressItems, item) rule.destinationIPCIDRItems = append(rule.destinationIPCIDRItems, item)
rule.allItems = append(rule.allItems, item) rule.allItems = append(rule.allItems, item)
} }
if len(options.SourcePort) > 0 { if len(options.SourcePort) > 0 {

View File

@ -5,6 +5,7 @@ import (
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
) )
@ -111,6 +112,11 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options
rule.sourceAddressItems = append(rule.sourceAddressItems, item) rule.sourceAddressItems = append(rule.sourceAddressItems, item)
rule.allItems = append(rule.allItems, item) rule.allItems = append(rule.allItems, item)
} }
if len(options.GeoIP) > 0 {
item := NewGeoIPItem(router, logger, false, options.GeoIP)
rule.destinationIPCIDRItems = append(rule.destinationIPCIDRItems, item)
rule.allItems = append(rule.allItems, item)
}
if len(options.SourceIPCIDR) > 0 { if len(options.SourceIPCIDR) > 0 {
item, err := NewIPCIDRItem(true, options.SourceIPCIDR) item, err := NewIPCIDRItem(true, options.SourceIPCIDR)
if err != nil { if err != nil {
@ -119,11 +125,24 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options
rule.sourceAddressItems = append(rule.sourceAddressItems, item) rule.sourceAddressItems = append(rule.sourceAddressItems, item)
rule.allItems = append(rule.allItems, item) rule.allItems = append(rule.allItems, item)
} }
if len(options.IPCIDR) > 0 {
item, err := NewIPCIDRItem(false, options.IPCIDR)
if err != nil {
return nil, E.Cause(err, "ip_cidr")
}
rule.destinationIPCIDRItems = append(rule.destinationIPCIDRItems, item)
rule.allItems = append(rule.allItems, item)
}
if options.SourceIPIsPrivate { if options.SourceIPIsPrivate {
item := NewIPIsPrivateItem(true) item := NewIPIsPrivateItem(true)
rule.sourceAddressItems = append(rule.sourceAddressItems, item) rule.sourceAddressItems = append(rule.sourceAddressItems, item)
rule.allItems = append(rule.allItems, item) rule.allItems = append(rule.allItems, item)
} }
if options.IPIsPrivate {
item := NewIPIsPrivateItem(false)
rule.destinationIPCIDRItems = append(rule.destinationIPCIDRItems, item)
rule.allItems = append(rule.allItems, item)
}
if len(options.SourcePort) > 0 { if len(options.SourcePort) > 0 {
item := NewPortItem(true, options.SourcePort) item := NewPortItem(true, options.SourcePort)
rule.sourcePortItems = append(rule.sourcePortItems, item) rule.sourcePortItems = append(rule.sourcePortItems, item)
@ -211,6 +230,34 @@ func (r *DefaultDNSRule) RewriteTTL() *uint32 {
return r.rewriteTTL return r.rewriteTTL
} }
func (r *DefaultDNSRule) WithAddressLimit() bool {
if len(r.destinationIPCIDRItems) > 0 {
return true
}
for _, rawRule := range r.items {
ruleSet, isRuleSet := rawRule.(*RuleSetItem)
if !isRuleSet {
continue
}
if ruleSet.ContainsIPCIDRRule() {
return true
}
}
return false
}
func (r *DefaultDNSRule) Match(metadata *adapter.InboundContext) bool {
metadata.IgnoreDestinationIPCIDRMatch = true
defer func() {
metadata.IgnoreDestinationIPCIDRMatch = false
}()
return r.abstractDefaultRule.Match(metadata)
}
func (r *DefaultDNSRule) MatchAddressLimit(metadata *adapter.InboundContext) bool {
return r.abstractDefaultRule.Match(metadata)
}
var _ adapter.DNSRule = (*LogicalDNSRule)(nil) var _ adapter.DNSRule = (*LogicalDNSRule)(nil)
type LogicalDNSRule struct { type LogicalDNSRule struct {
@ -254,3 +301,47 @@ func (r *LogicalDNSRule) DisableCache() bool {
func (r *LogicalDNSRule) RewriteTTL() *uint32 { func (r *LogicalDNSRule) RewriteTTL() *uint32 {
return r.rewriteTTL return r.rewriteTTL
} }
func (r *LogicalDNSRule) WithAddressLimit() bool {
for _, rawRule := range r.rules {
switch rule := rawRule.(type) {
case *DefaultDNSRule:
if rule.WithAddressLimit() {
return true
}
case *LogicalDNSRule:
if rule.WithAddressLimit() {
return true
}
}
}
return false
}
func (r *LogicalDNSRule) Match(metadata *adapter.InboundContext) bool {
if r.mode == C.LogicalTypeAnd {
return common.All(r.rules, func(it adapter.HeadlessRule) bool {
metadata.ResetRuleCache()
return it.(adapter.DNSRule).Match(metadata)
}) != r.invert
} else {
return common.Any(r.rules, func(it adapter.HeadlessRule) bool {
metadata.ResetRuleCache()
return it.(adapter.DNSRule).Match(metadata)
}) != r.invert
}
}
func (r *LogicalDNSRule) MatchAddressLimit(metadata *adapter.InboundContext) bool {
if r.mode == C.LogicalTypeAnd {
return common.All(r.rules, func(it adapter.HeadlessRule) bool {
metadata.ResetRuleCache()
return it.(adapter.DNSRule).MatchAddressLimit(metadata)
}) != r.invert
} else {
return common.Any(r.rules, func(it adapter.HeadlessRule) bool {
metadata.ResetRuleCache()
return it.(adapter.DNSRule).MatchAddressLimit(metadata)
}) != r.invert
}
}

View File

@ -80,11 +80,11 @@ func NewDefaultHeadlessRule(router adapter.Router, options option.DefaultHeadles
if err != nil { if err != nil {
return nil, E.Cause(err, "ipcidr") return nil, E.Cause(err, "ipcidr")
} }
rule.destinationAddressItems = append(rule.destinationAddressItems, item) rule.destinationIPCIDRItems = append(rule.destinationIPCIDRItems, item)
rule.allItems = append(rule.allItems, item) rule.allItems = append(rule.allItems, item)
} else if options.IPSet != nil { } else if options.IPSet != nil {
item := NewRawIPCIDRItem(false, options.IPSet) item := NewRawIPCIDRItem(false, options.IPSet)
rule.destinationAddressItems = append(rule.destinationAddressItems, item) rule.destinationIPCIDRItems = append(rule.destinationIPCIDRItems, item)
rule.allItems = append(rule.allItems, item) rule.allItems = append(rule.allItems, item)
} }
if len(options.SourcePort) > 0 { if len(options.SourcePort) > 0 {

View File

@ -4,6 +4,7 @@ import (
"strings" "strings"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format" F "github.com/sagernet/sing/common/format"
) )
@ -13,7 +14,7 @@ var _ RuleItem = (*RuleSetItem)(nil)
type RuleSetItem struct { type RuleSetItem struct {
router adapter.Router router adapter.Router
tagList []string tagList []string
setList []adapter.HeadlessRule setList []adapter.RuleSet
ipcidrMatchSource bool ipcidrMatchSource bool
} }
@ -46,6 +47,12 @@ func (r *RuleSetItem) Match(metadata *adapter.InboundContext) bool {
return false return false
} }
func (r *RuleSetItem) ContainsIPCIDRRule() bool {
return common.Any(r.setList, func(ruleSet adapter.RuleSet) bool {
return ruleSet.Metadata().ContainsIPCIDRRule
})
}
func (r *RuleSetItem) String() string { func (r *RuleSetItem) String() string {
if len(r.tagList) == 1 { if len(r.tagList) == 1 {
return F.ToString("rule_set=", r.tagList[0]) return F.ToString("rule_set=", r.tagList[0])

View File

@ -55,6 +55,7 @@ func NewLocalRuleSet(router adapter.Router, options option.RuleSet) (*LocalRuleS
var metadata adapter.RuleSetMetadata var metadata adapter.RuleSetMetadata
metadata.ContainsProcessRule = hasHeadlessRule(plainRuleSet.Rules, isProcessHeadlessRule) metadata.ContainsProcessRule = hasHeadlessRule(plainRuleSet.Rules, isProcessHeadlessRule)
metadata.ContainsWIFIRule = hasHeadlessRule(plainRuleSet.Rules, isWIFIHeadlessRule) metadata.ContainsWIFIRule = hasHeadlessRule(plainRuleSet.Rules, isWIFIHeadlessRule)
metadata.ContainsIPCIDRRule = hasHeadlessRule(plainRuleSet.Rules, isIPCIDRHeadlessRule)
return &LocalRuleSet{rules, metadata}, nil return &LocalRuleSet{rules, metadata}, nil
} }

View File

@ -150,6 +150,7 @@ func (s *RemoteRuleSet) loadBytes(content []byte) error {
} }
s.metadata.ContainsProcessRule = hasHeadlessRule(plainRuleSet.Rules, isProcessHeadlessRule) s.metadata.ContainsProcessRule = hasHeadlessRule(plainRuleSet.Rules, isProcessHeadlessRule)
s.metadata.ContainsWIFIRule = hasHeadlessRule(plainRuleSet.Rules, isWIFIHeadlessRule) s.metadata.ContainsWIFIRule = hasHeadlessRule(plainRuleSet.Rules, isWIFIHeadlessRule)
s.metadata.ContainsIPCIDRRule = hasHeadlessRule(plainRuleSet.Rules, isIPCIDRHeadlessRule)
s.rules = rules s.rules = rules
return nil return nil
} }