From dd8eb2ce119c62d55e21c93bc0bf897beaaeb631 Mon Sep 17 00:00:00 2001 From: PuerNya Date: Tue, 13 Aug 2024 10:38:10 +0800 Subject: [PATCH] use ipondemand mode when route rule matching (cherry picked from commit c33f91170b0fa5bc10cbcce5afe3b648f1a5bff1) --- adapter/experimental.go | 1 + adapter/outbound.go | 4 +++ adapter/router.go | 2 ++ option/rule.go | 10 +++--- outbound/default.go | 74 +++++++++++++++++++++++++++++++------- outbound/direct.go | 9 +++-- outbound/selector.go | 12 +++++++ outbound/socks.go | 9 ++++- outbound/urltest.go | 14 ++++++++ outbound/wireguard.go | 5 +++ route/router.go | 41 ++++++++++++++++++--- route/rule_abstract.go | 35 ++++++++++++++---- route/rule_default.go | 13 ++++--- route/rule_dns.go | 1 + route/rule_set_abstract.go | 4 +++ 15 files changed, 199 insertions(+), 35 deletions(-) diff --git a/adapter/experimental.go b/adapter/experimental.go index 0cab5ed5..4a74b92d 100644 --- a/adapter/experimental.go +++ b/adapter/experimental.go @@ -102,6 +102,7 @@ type OutboundGroup interface { Outbound Now() string All() []string + SelectedOutbound(network string) Outbound } type URLTestGroup interface { diff --git a/adapter/outbound.go b/adapter/outbound.go index b6980fb9..bdfe1f19 100644 --- a/adapter/outbound.go +++ b/adapter/outbound.go @@ -18,3 +18,7 @@ type Outbound interface { NewConnection(ctx context.Context, conn net.Conn, metadata InboundContext) error NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata InboundContext) error } + +type OutboundUseIP interface { + UseIP() bool +} diff --git a/adapter/router.go b/adapter/router.go index 8649061b..0d5a1612 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -80,6 +80,7 @@ type HeadlessRule interface { Match(metadata *InboundContext) bool RuleCount() uint64 String() string + ContainsDestinationIPCIDRRule() bool } type Rule interface { @@ -87,6 +88,7 @@ type Rule interface { Service Type() string UpdateGeosite() error + SkipResolve() bool Outbound() string } diff --git a/option/rule.go b/option/rule.go index 5f15645c..03b55576 100644 --- a/option/rule.go +++ b/option/rule.go @@ -98,6 +98,7 @@ type _DefaultRule struct { RuleSet Listable[string] `json:"rule_set,omitempty"` RuleSetIPCIDRMatchSource bool `json:"rule_set_ip_cidr_match_source,omitempty"` Invert bool `json:"invert,omitempty"` + SkipResolve bool `json:"skip_resolve,omitempty"` Outbound string `json:"outbound,omitempty"` // Deprecated: renamed to rule_set_ip_cidr_match_source @@ -128,10 +129,11 @@ func (r *DefaultRule) IsValid() bool { } type LogicalRule struct { - Mode string `json:"mode"` - Rules []Rule `json:"rules,omitempty"` - Invert bool `json:"invert,omitempty"` - Outbound string `json:"outbound,omitempty"` + Mode string `json:"mode"` + Rules []Rule `json:"rules,omitempty"` + Invert bool `json:"invert,omitempty"` + SkipResolve bool `json:"skip_resolve,omitempty"` + Outbound string `json:"outbound,omitempty"` } func (r LogicalRule) IsValid() bool { diff --git a/outbound/default.go b/outbound/default.go index 972aca94..994c1875 100644 --- a/outbound/default.go +++ b/outbound/default.go @@ -81,15 +81,40 @@ func NewDirectConnection(ctx context.Context, router adapter.Router, this N.Dial ctx = adapter.WithContext(ctx, &metadata) var outConn net.Conn var err error - if len(metadata.DestinationAddresses) > 0 { - outConn, err = N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses) - } else if metadata.Destination.IsFqdn() { - var destinationAddresses []netip.Addr - destinationAddresses, err = router.Lookup(ctx, metadata.Destination.Fqdn, domainStrategy) + addresses := metadata.DestinationAddresses + if len(addresses) == 0 && metadata.Destination.IsFqdn() { + addresses, err = router.Lookup(ctx, metadata.Destination.Fqdn, domainStrategy) if err != nil { return N.ReportHandshakeFailure(conn, err) } - outConn, err = N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, destinationAddresses) + } + if len(addresses) > 0 { + addresses4 := common.Filter(addresses, func(address netip.Addr) bool { + return address.Is4() || address.Is4In6() + }) + addresses6 := common.Filter(addresses, func(address netip.Addr) bool { + return address.Is6() && !address.Is4In6() + }) + connFunc := func(primaries []netip.Addr, fallbacks []netip.Addr) (net.Conn, error) { + if len(primaries) > 0 { + if conn, err := N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, primaries); err == nil || len(fallbacks) == 0 { + return conn, err + } + } + return N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, fallbacks) + } + switch domainStrategy { + case dns.DomainStrategyAsIS: + outConn, err = connFunc(addresses, nil) + case dns.DomainStrategyUseIPv4: + outConn, err = connFunc(addresses4, nil) + case dns.DomainStrategyUseIPv6: + outConn, err = connFunc(addresses6, nil) + case dns.DomainStrategyPreferIPv4: + outConn, err = connFunc(addresses4, addresses6) + case dns.DomainStrategyPreferIPv6: + outConn, err = connFunc(addresses6, addresses4) + } } else { outConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination) } @@ -150,15 +175,40 @@ func NewDirectPacketConnection(ctx context.Context, router adapter.Router, this var outConn net.PacketConn var destinationAddress netip.Addr var err error - if len(metadata.DestinationAddresses) > 0 { - outConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses) - } else if metadata.Destination.IsFqdn() { - var destinationAddresses []netip.Addr - destinationAddresses, err = router.Lookup(ctx, metadata.Destination.Fqdn, domainStrategy) + addresses := metadata.DestinationAddresses + if len(addresses) == 0 && metadata.Destination.IsFqdn() { + addresses, err = router.Lookup(ctx, metadata.Destination.Fqdn, domainStrategy) if err != nil { return N.ReportHandshakeFailure(conn, err) } - outConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, destinationAddresses) + } + if len(addresses) > 0 { + addresses4 := common.Filter(addresses, func(address netip.Addr) bool { + return address.Is4() || address.Is4In6() + }) + addresses6 := common.Filter(addresses, func(address netip.Addr) bool { + return address.Is6() && !address.Is4In6() + }) + connFunc := func(primaries []netip.Addr, fallbacks []netip.Addr) (net.PacketConn, netip.Addr, error) { + if len(primaries) > 0 { + if conn, addr, err := N.ListenSerial(ctx, this, metadata.Destination, primaries); err == nil || len(fallbacks) == 0 { + return conn, addr, err + } + } + return N.ListenSerial(ctx, this, metadata.Destination, fallbacks) + } + switch domainStrategy { + case dns.DomainStrategyAsIS: + outConn, destinationAddress, err = connFunc(addresses, nil) + case dns.DomainStrategyUseIPv4: + outConn, destinationAddress, err = connFunc(addresses4, nil) + case dns.DomainStrategyUseIPv6: + outConn, destinationAddress, err = connFunc(addresses6, nil) + case dns.DomainStrategyPreferIPv4: + outConn, destinationAddress, err = connFunc(addresses4, addresses6) + case dns.DomainStrategyPreferIPv6: + outConn, destinationAddress, err = connFunc(addresses6, addresses4) + } } else { outConn, err = this.ListenPacket(ctx, metadata.Destination) } diff --git a/outbound/direct.go b/outbound/direct.go index 11f650e4..2fa2b1bc 100644 --- a/outbound/direct.go +++ b/outbound/direct.go @@ -19,8 +19,9 @@ import ( ) var ( - _ adapter.Outbound = (*Direct)(nil) - _ N.ParallelDialer = (*Direct)(nil) + _ adapter.Outbound = (*Direct)(nil) + _ adapter.OutboundUseIP = (*Direct)(nil) + _ N.ParallelDialer = (*Direct)(nil) ) type Direct struct { @@ -168,3 +169,7 @@ func (h *Direct) NewPacketConnection(ctx context.Context, conn N.PacketConn, met } return NewPacketConnection(ctx, h, conn, metadata) } + +func (h *Direct) UseIP() bool { + return true +} diff --git a/outbound/selector.go b/outbound/selector.go index e801daea..069a6324 100644 --- a/outbound/selector.go +++ b/outbound/selector.go @@ -100,6 +100,10 @@ func (s *Selector) Now() string { return s.selected.Tag() } +func (s *Selector) SelectedOutbound(network string) adapter.Outbound { + return s.selected +} + func (s *Selector) All() []string { return s.tags } @@ -158,3 +162,11 @@ func RealTag(detour adapter.Outbound) string { } return detour.Tag() } + +func RealOutboundTag(detour adapter.Outbound, network string) string { + group, isGroup := detour.(adapter.OutboundGroup) + if !isGroup { + return detour.Tag() + } + return RealOutboundTag(group.SelectedOutbound(network), network) +} diff --git a/outbound/socks.go b/outbound/socks.go index 063f7b95..e18fe457 100644 --- a/outbound/socks.go +++ b/outbound/socks.go @@ -18,7 +18,10 @@ import ( "github.com/sagernet/sing/protocol/socks" ) -var _ adapter.Outbound = (*Socks)(nil) +var ( + _ adapter.Outbound = (*Socks)(nil) + _ adapter.OutboundUseIP = (*Socks)(nil) +) type Socks struct { myOutboundAdapter @@ -128,3 +131,7 @@ func (h *Socks) NewPacketConnection(ctx context.Context, conn N.PacketConn, meta return NewPacketConnection(ctx, h, conn, metadata) } } + +func (h *Socks) UseIP() bool { + return h.resolve +} diff --git a/outbound/urltest.go b/outbound/urltest.go index c6e38ec5..eb762bd7 100644 --- a/outbound/urltest.go +++ b/outbound/urltest.go @@ -111,6 +111,20 @@ func (s *URLTest) Now() string { return "" } +func (s *URLTest) SelectedOutbound(network string) adapter.Outbound { + switch network { + case N.NetworkTCP: + if s.group.selectedOutboundTCP != nil { + return s.group.selectedOutboundTCP + } + case N.NetworkUDP: + if s.group.selectedOutboundUDP != nil { + return s.group.selectedOutboundUDP + } + } + return s.group.outbounds[0] +} + func (s *URLTest) All() []string { return s.tags } diff --git a/outbound/wireguard.go b/outbound/wireguard.go index 3ae3f63b..e2b0323e 100644 --- a/outbound/wireguard.go +++ b/outbound/wireguard.go @@ -32,6 +32,7 @@ import ( var ( _ adapter.Outbound = (*WireGuard)(nil) + _ adapter.OutboundUseIP = (*WireGuard)(nil) _ adapter.InterfaceUpdateListener = (*WireGuard)(nil) ) @@ -241,3 +242,7 @@ func (w *WireGuard) NewConnection(ctx context.Context, conn net.Conn, metadata a func (w *WireGuard) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { return NewDirectPacketConnection(ctx, w.router, w, conn, metadata, dns.DomainStrategyAsIS) } + +func (w *WireGuard) UseIP() bool { + return true +} diff --git a/route/router.go b/route/router.go index 4f7bea40..0df83be7 100644 --- a/route/router.go +++ b/route/router.go @@ -24,7 +24,7 @@ import ( "github.com/sagernet/sing-box/experimental/libbox/platform" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-box/outbound" + O "github.com/sagernet/sing-box/outbound" "github.com/sagernet/sing-box/transport/fakeip" "github.com/sagernet/sing-dns" "github.com/sagernet/sing-mux" @@ -1113,12 +1113,12 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m func (r *Router) match(ctx context.Context, metadata *adapter.InboundContext, defaultOutbound adapter.Outbound) (context.Context, adapter.Rule, adapter.Outbound, error) { matchRule, matchOutbound := r.match0(ctx, metadata, defaultOutbound) - if contextOutbound, loaded := outbound.TagFromContext(ctx); loaded { + if contextOutbound, loaded := O.TagFromContext(ctx); loaded { if contextOutbound == matchOutbound.Tag() { return nil, nil, nil, E.New("connection loopback in outbound/", matchOutbound.Type(), "[", matchOutbound.Tag(), "]") } } - ctx = outbound.ContextWithTag(ctx, matchOutbound.Tag()) + ctx = O.ContextWithTag(ctx, matchOutbound.Tag()) return ctx, matchRule, matchOutbound, nil } @@ -1154,18 +1154,49 @@ func (r *Router) match0(ctx context.Context, metadata *adapter.InboundContext, d metadata.ProcessInfo = processInfo } } + resolveStatus := -1 + if metadata.Destination.IsFqdn() && len(metadata.DestinationAddresses) == 0 { + resolveStatus = 0 + } + var outbound adapter.Outbound + defer func() { + if resolveStatus == 1 && !r.mustUseIP(outbound, metadata.Network) { + metadata.DestinationAddresses = []netip.Addr{} + } + }() for i, rule := range r.rules { metadata.ResetRuleCache() + if !rule.SkipResolve() && resolveStatus == 0 && rule.ContainsDestinationIPCIDRRule() { + addresses, err := r.LookupDefault(adapter.WithContext(ctx, metadata), metadata.Destination.Fqdn) + resolveStatus = 2 + if err == nil { + resolveStatus = 1 + metadata.DestinationAddresses = addresses + } + metadata.ResetRuleCache() + } if rule.Match(metadata) { detour := rule.Outbound() r.logger.DebugContext(ctx, "match[", i, "] ", rule.String(), " => ", detour) - if outbound, loaded := r.Outbound(detour); loaded { + var loaded bool + if outbound, loaded = r.Outbound(detour); loaded { return rule, outbound } r.logger.ErrorContext(ctx, "outbound not found: ", detour) } } - return nil, defaultOutbound + outbound = defaultOutbound + return nil, outbound +} + +func (r *Router) mustUseIP(outbound adapter.Outbound, network string) bool { + tag := O.RealOutboundTag(outbound, network) + detour, _ := r.Outbound(tag) + d, ok := detour.(adapter.OutboundUseIP) + if !ok { + return false + } + return d.UseIP() } func (r *Router) InterfaceFinder() control.InterfaceFinder { diff --git a/route/rule_abstract.go b/route/rule_abstract.go index 06af8f61..135c6e79 100644 --- a/route/rule_abstract.go +++ b/route/rule_abstract.go @@ -18,9 +18,10 @@ type abstractDefaultRule struct { destinationIPCIDRItems []RuleItem destinationPortItems []RuleItem allItems []RuleItem - ruleSetItem RuleItem + ruleSetItems []RuleItem ruleCount uint64 invert bool + skipResolve bool outbound string } @@ -32,6 +33,17 @@ func (r *abstractDefaultRule) RuleCount() uint64 { return r.ruleCount } +func (r *abstractDefaultRule) SkipResolve() bool { + return r.skipResolve +} + +func (r *abstractDefaultRule) ContainsDestinationIPCIDRRule() bool { + return len(r.destinationIPCIDRItems) > 0 || common.Any(r.ruleSetItems, func(it RuleItem) bool { + r, _ := it.(*RuleSetItem) + return r.ContainsDestinationIPCIDRRule() + }) +} + func (r *abstractDefaultRule) Start() error { for _, item := range r.allItems { if starter, isStarter := item.(interface { @@ -168,11 +180,12 @@ func (r *abstractDefaultRule) String() string { } type abstractLogicalRule struct { - rules []adapter.HeadlessRule - mode string - invert bool - outbound string - ruleCount uint64 + rules []adapter.HeadlessRule + mode string + invert bool + skipResolve bool + outbound string + ruleCount uint64 } func (r *abstractLogicalRule) Type() string { @@ -183,6 +196,16 @@ func (r *abstractLogicalRule) RuleCount() uint64 { return r.ruleCount } +func (r *abstractLogicalRule) SkipResolve() bool { + return r.skipResolve +} + +func (r *abstractLogicalRule) ContainsDestinationIPCIDRRule() bool { + return common.Any(r.rules, func(it adapter.HeadlessRule) bool { + return it.ContainsDestinationIPCIDRRule() + }) +} + func (r *abstractLogicalRule) UpdateGeosite() error { for _, rule := range common.FilterIsInstance(r.rules, func(it adapter.HeadlessRule) (adapter.Rule, bool) { rule, loaded := it.(adapter.Rule) diff --git a/route/rule_default.go b/route/rule_default.go index 40b93e5f..4bf24472 100644 --- a/route/rule_default.go +++ b/route/rule_default.go @@ -45,8 +45,9 @@ type RuleItem interface { func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options option.DefaultRule) (*DefaultRule, error) { rule := &DefaultRule{ abstractDefaultRule{ - invert: options.Invert, - outbound: options.Outbound, + invert: options.Invert, + skipResolve: options.SkipResolve, + outbound: options.Outbound, }, } if len(options.Inbound) > 0 { @@ -220,6 +221,7 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt if len(options.RuleSet) > 0 { item := NewRuleSetItem(router, options.RuleSet, options.RuleSetIPCIDRMatchSource, false) rule.items = append(rule.items, item) + rule.ruleSetItems = append(rule.ruleSetItems, item) rule.allItems = append(rule.allItems, item) } return rule, nil @@ -234,9 +236,10 @@ type LogicalRule struct { func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) { r := &LogicalRule{ abstractLogicalRule{ - rules: make([]adapter.HeadlessRule, len(options.Rules)), - invert: options.Invert, - outbound: options.Outbound, + rules: make([]adapter.HeadlessRule, len(options.Rules)), + invert: options.Invert, + skipResolve: options.SkipResolve, + outbound: options.Outbound, }, } switch options.Mode { diff --git a/route/rule_dns.go b/route/rule_dns.go index 616f956a..455312da 100644 --- a/route/rule_dns.go +++ b/route/rule_dns.go @@ -229,6 +229,7 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options if len(options.RuleSet) > 0 { item := NewRuleSetItem(router, options.RuleSet, options.RuleSetIPCIDRMatchSource, options.RuleSetIPCIDRAcceptEmpty) rule.items = append(rule.items, item) + rule.ruleSetItems = append(rule.ruleSetItems, item) rule.allItems = append(rule.allItems, item) } return rule, nil diff --git a/route/rule_set_abstract.go b/route/rule_set_abstract.go index bb65a7eb..970ae999 100644 --- a/route/rule_set_abstract.go +++ b/route/rule_set_abstract.go @@ -52,6 +52,10 @@ func (s *abstractRuleSet) RuleCount() uint64 { return s.ruleCount } +func (s *abstractRuleSet) ContainsDestinationIPCIDRRule() bool { + return s.metadata.ContainsIPCIDRRule +} + func (s *abstractRuleSet) UpdatedTime() time.Time { return s.lastUpdated }