diff --git a/route/rule.go b/route/rule.go index 12186e4b..5d0763a5 100644 --- a/route/rule.go +++ b/route/rule.go @@ -41,7 +41,9 @@ var _ adapter.Rule = (*DefaultRule)(nil) type DefaultRule struct { items []RuleItem sourceAddressItems []RuleItem + sourcePortItems []RuleItem destinationAddressItems []RuleItem + destinationPortItems []RuleItem allItems []RuleItem invert bool outbound string @@ -143,7 +145,7 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt } if len(options.SourcePort) > 0 { item := NewPortItem(true, options.SourcePort) - rule.items = append(rule.items, item) + rule.sourcePortItems = append(rule.sourcePortItems, item) rule.allItems = append(rule.allItems, item) } if len(options.SourcePortRange) > 0 { @@ -151,12 +153,12 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt if err != nil { return nil, E.Cause(err, "source_port_range") } - rule.items = append(rule.items, item) + rule.sourcePortItems = append(rule.sourcePortItems, item) rule.allItems = append(rule.allItems, item) } if len(options.Port) > 0 { item := NewPortItem(false, options.Port) - rule.items = append(rule.items, item) + rule.destinationPortItems = append(rule.destinationPortItems, item) rule.allItems = append(rule.allItems, item) } if len(options.PortRange) > 0 { @@ -164,7 +166,7 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt if err != nil { return nil, E.Cause(err, "port_range") } - rule.items = append(rule.items, item) + rule.destinationPortItems = append(rule.destinationPortItems, item) rule.allItems = append(rule.allItems, item) } if len(options.ProcessName) > 0 { @@ -251,6 +253,19 @@ func (r *DefaultRule) Match(metadata *adapter.InboundContext) bool { } } + if len(r.sourcePortItems) > 0 { + var sourcePortMatch bool + for _, item := range r.sourcePortItems { + if item.Match(metadata) { + sourcePortMatch = true + break + } + } + if !sourcePortMatch { + return r.invert + } + } + if len(r.destinationAddressItems) > 0 { var destinationAddressMatch bool for _, item := range r.destinationAddressItems { @@ -264,6 +279,19 @@ func (r *DefaultRule) Match(metadata *adapter.InboundContext) bool { } } + if len(r.destinationPortItems) > 0 { + var destinationPortMatch bool + for _, item := range r.destinationPortItems { + if item.Match(metadata) { + destinationPortMatch = true + break + } + } + if !destinationPortMatch { + return r.invert + } + } + return !r.invert } diff --git a/route/rule_dns.go b/route/rule_dns.go index 6364f548..1fb1aecf 100644 --- a/route/rule_dns.go +++ b/route/rule_dns.go @@ -39,12 +39,15 @@ func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option. var _ adapter.DNSRule = (*DefaultDNSRule)(nil) type DefaultDNSRule struct { - items []RuleItem - addressItems []RuleItem - allItems []RuleItem - invert bool - outbound string - disableCache bool + items []RuleItem + sourceAddressItems []RuleItem + sourcePortItems []RuleItem + destinationAddressItems []RuleItem + destinationPortItems []RuleItem + allItems []RuleItem + invert bool + outbound string + disableCache bool } func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options option.DefaultDNSRule) (*DefaultDNSRule, error) { @@ -90,12 +93,12 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options } if len(options.Domain) > 0 || len(options.DomainSuffix) > 0 { item := NewDomainItem(options.Domain, options.DomainSuffix) - rule.addressItems = append(rule.addressItems, item) + rule.destinationAddressItems = append(rule.destinationAddressItems, item) rule.allItems = append(rule.allItems, item) } if len(options.DomainKeyword) > 0 { item := NewDomainKeywordItem(options.DomainKeyword) - rule.addressItems = append(rule.addressItems, item) + rule.destinationAddressItems = append(rule.destinationAddressItems, item) rule.allItems = append(rule.allItems, item) } if len(options.DomainRegex) > 0 { @@ -103,17 +106,17 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options if err != nil { return nil, E.Cause(err, "domain_regex") } - rule.addressItems = append(rule.addressItems, item) + rule.destinationAddressItems = append(rule.destinationAddressItems, item) rule.allItems = append(rule.allItems, item) } if len(options.Geosite) > 0 { item := NewGeositeItem(router, logger, options.Geosite) - rule.addressItems = append(rule.addressItems, item) + rule.destinationAddressItems = append(rule.destinationAddressItems, item) rule.allItems = append(rule.allItems, item) } if len(options.SourceGeoIP) > 0 { item := NewGeoIPItem(router, logger, true, options.SourceGeoIP) - rule.items = append(rule.items, item) + rule.sourceAddressItems = append(rule.sourceAddressItems, item) rule.allItems = append(rule.allItems, item) } if len(options.SourceIPCIDR) > 0 { @@ -121,12 +124,12 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options if err != nil { return nil, E.Cause(err, "source_ipcidr") } - rule.items = append(rule.items, item) + rule.sourceAddressItems = append(rule.sourceAddressItems, item) rule.allItems = append(rule.allItems, item) } if len(options.SourcePort) > 0 { item := NewPortItem(true, options.SourcePort) - rule.items = append(rule.items, item) + rule.sourcePortItems = append(rule.sourcePortItems, item) rule.allItems = append(rule.allItems, item) } if len(options.SourcePortRange) > 0 { @@ -134,12 +137,12 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options if err != nil { return nil, E.Cause(err, "source_port_range") } - rule.items = append(rule.items, item) + rule.sourcePortItems = append(rule.sourcePortItems, item) rule.allItems = append(rule.allItems, item) } if len(options.Port) > 0 { item := NewPortItem(false, options.Port) - rule.items = append(rule.items, item) + rule.destinationPortItems = append(rule.destinationPortItems, item) rule.allItems = append(rule.allItems, item) } if len(options.PortRange) > 0 { @@ -147,7 +150,7 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options if err != nil { return nil, E.Cause(err, "port_range") } - rule.items = append(rule.items, item) + rule.destinationPortItems = append(rule.destinationPortItems, item) rule.allItems = append(rule.allItems, item) } if len(options.ProcessName) > 0 { @@ -225,18 +228,59 @@ func (r *DefaultDNSRule) Match(metadata *adapter.InboundContext) bool { return r.invert } } - if len(r.addressItems) > 0 { - var addressMatch bool - for _, item := range r.addressItems { + + if len(r.sourceAddressItems) > 0 { + var sourceAddressMatch bool + for _, item := range r.sourceAddressItems { if item.Match(metadata) { - addressMatch = true + sourceAddressMatch = true break } } - if !addressMatch { + if !sourceAddressMatch { return r.invert } } + + if len(r.sourcePortItems) > 0 { + var sourcePortMatch bool + for _, item := range r.sourcePortItems { + if item.Match(metadata) { + sourcePortMatch = true + break + } + } + if !sourcePortMatch { + return r.invert + } + } + + if len(r.destinationAddressItems) > 0 { + var destinationAddressMatch bool + for _, item := range r.destinationAddressItems { + if item.Match(metadata) { + destinationAddressMatch = true + break + } + } + if !destinationAddressMatch { + return r.invert + } + } + + if len(r.destinationPortItems) > 0 { + var destinationPortMatch bool + for _, item := range r.destinationPortItems { + if item.Match(metadata) { + destinationPortMatch = true + break + } + } + if !destinationPortMatch { + return r.invert + } + } + return !r.invert }