use ipondemand mode when route rule matching

(cherry picked from commit c33f91170b0fa5bc10cbcce5afe3b648f1a5bff1)
This commit is contained in:
PuerNya 2024-08-13 10:38:10 +08:00 committed by CHIZI-0618
parent 0ec07e573e
commit dd8eb2ce11
15 changed files with 199 additions and 35 deletions

View File

@ -102,6 +102,7 @@ type OutboundGroup interface {
Outbound
Now() string
All() []string
SelectedOutbound(network string) Outbound
}
type URLTestGroup interface {

View File

@ -18,3 +18,7 @@ type Outbound interface {
NewConnection(ctx context.Context, conn net.Conn, metadata InboundContext) error
NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata InboundContext) error
}
type OutboundUseIP interface {
UseIP() bool
}

View File

@ -80,6 +80,7 @@ type HeadlessRule interface {
Match(metadata *InboundContext) bool
RuleCount() uint64
String() string
ContainsDestinationIPCIDRRule() bool
}
type Rule interface {
@ -87,6 +88,7 @@ type Rule interface {
Service
Type() string
UpdateGeosite() error
SkipResolve() bool
Outbound() string
}

View File

@ -98,6 +98,7 @@ type _DefaultRule struct {
RuleSet Listable[string] `json:"rule_set,omitempty"`
RuleSetIPCIDRMatchSource bool `json:"rule_set_ip_cidr_match_source,omitempty"`
Invert bool `json:"invert,omitempty"`
SkipResolve bool `json:"skip_resolve,omitempty"`
Outbound string `json:"outbound,omitempty"`
// Deprecated: renamed to rule_set_ip_cidr_match_source
@ -131,6 +132,7 @@ type LogicalRule struct {
Mode string `json:"mode"`
Rules []Rule `json:"rules,omitempty"`
Invert bool `json:"invert,omitempty"`
SkipResolve bool `json:"skip_resolve,omitempty"`
Outbound string `json:"outbound,omitempty"`
}

View File

@ -81,15 +81,40 @@ func NewDirectConnection(ctx context.Context, router adapter.Router, this N.Dial
ctx = adapter.WithContext(ctx, &metadata)
var outConn net.Conn
var err error
if len(metadata.DestinationAddresses) > 0 {
outConn, err = N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses)
} else if metadata.Destination.IsFqdn() {
var destinationAddresses []netip.Addr
destinationAddresses, err = router.Lookup(ctx, metadata.Destination.Fqdn, domainStrategy)
addresses := metadata.DestinationAddresses
if len(addresses) == 0 && metadata.Destination.IsFqdn() {
addresses, err = router.Lookup(ctx, metadata.Destination.Fqdn, domainStrategy)
if err != nil {
return N.ReportHandshakeFailure(conn, err)
}
outConn, err = N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, destinationAddresses)
}
if len(addresses) > 0 {
addresses4 := common.Filter(addresses, func(address netip.Addr) bool {
return address.Is4() || address.Is4In6()
})
addresses6 := common.Filter(addresses, func(address netip.Addr) bool {
return address.Is6() && !address.Is4In6()
})
connFunc := func(primaries []netip.Addr, fallbacks []netip.Addr) (net.Conn, error) {
if len(primaries) > 0 {
if conn, err := N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, primaries); err == nil || len(fallbacks) == 0 {
return conn, err
}
}
return N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, fallbacks)
}
switch domainStrategy {
case dns.DomainStrategyAsIS:
outConn, err = connFunc(addresses, nil)
case dns.DomainStrategyUseIPv4:
outConn, err = connFunc(addresses4, nil)
case dns.DomainStrategyUseIPv6:
outConn, err = connFunc(addresses6, nil)
case dns.DomainStrategyPreferIPv4:
outConn, err = connFunc(addresses4, addresses6)
case dns.DomainStrategyPreferIPv6:
outConn, err = connFunc(addresses6, addresses4)
}
} else {
outConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
}
@ -150,15 +175,40 @@ func NewDirectPacketConnection(ctx context.Context, router adapter.Router, this
var outConn net.PacketConn
var destinationAddress netip.Addr
var err error
if len(metadata.DestinationAddresses) > 0 {
outConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses)
} else if metadata.Destination.IsFqdn() {
var destinationAddresses []netip.Addr
destinationAddresses, err = router.Lookup(ctx, metadata.Destination.Fqdn, domainStrategy)
addresses := metadata.DestinationAddresses
if len(addresses) == 0 && metadata.Destination.IsFqdn() {
addresses, err = router.Lookup(ctx, metadata.Destination.Fqdn, domainStrategy)
if err != nil {
return N.ReportHandshakeFailure(conn, err)
}
outConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, destinationAddresses)
}
if len(addresses) > 0 {
addresses4 := common.Filter(addresses, func(address netip.Addr) bool {
return address.Is4() || address.Is4In6()
})
addresses6 := common.Filter(addresses, func(address netip.Addr) bool {
return address.Is6() && !address.Is4In6()
})
connFunc := func(primaries []netip.Addr, fallbacks []netip.Addr) (net.PacketConn, netip.Addr, error) {
if len(primaries) > 0 {
if conn, addr, err := N.ListenSerial(ctx, this, metadata.Destination, primaries); err == nil || len(fallbacks) == 0 {
return conn, addr, err
}
}
return N.ListenSerial(ctx, this, metadata.Destination, fallbacks)
}
switch domainStrategy {
case dns.DomainStrategyAsIS:
outConn, destinationAddress, err = connFunc(addresses, nil)
case dns.DomainStrategyUseIPv4:
outConn, destinationAddress, err = connFunc(addresses4, nil)
case dns.DomainStrategyUseIPv6:
outConn, destinationAddress, err = connFunc(addresses6, nil)
case dns.DomainStrategyPreferIPv4:
outConn, destinationAddress, err = connFunc(addresses4, addresses6)
case dns.DomainStrategyPreferIPv6:
outConn, destinationAddress, err = connFunc(addresses6, addresses4)
}
} else {
outConn, err = this.ListenPacket(ctx, metadata.Destination)
}

View File

@ -20,6 +20,7 @@ import (
var (
_ adapter.Outbound = (*Direct)(nil)
_ adapter.OutboundUseIP = (*Direct)(nil)
_ N.ParallelDialer = (*Direct)(nil)
)
@ -168,3 +169,7 @@ func (h *Direct) NewPacketConnection(ctx context.Context, conn N.PacketConn, met
}
return NewPacketConnection(ctx, h, conn, metadata)
}
func (h *Direct) UseIP() bool {
return true
}

View File

@ -100,6 +100,10 @@ func (s *Selector) Now() string {
return s.selected.Tag()
}
func (s *Selector) SelectedOutbound(network string) adapter.Outbound {
return s.selected
}
func (s *Selector) All() []string {
return s.tags
}
@ -158,3 +162,11 @@ func RealTag(detour adapter.Outbound) string {
}
return detour.Tag()
}
func RealOutboundTag(detour adapter.Outbound, network string) string {
group, isGroup := detour.(adapter.OutboundGroup)
if !isGroup {
return detour.Tag()
}
return RealOutboundTag(group.SelectedOutbound(network), network)
}

View File

@ -18,7 +18,10 @@ import (
"github.com/sagernet/sing/protocol/socks"
)
var _ adapter.Outbound = (*Socks)(nil)
var (
_ adapter.Outbound = (*Socks)(nil)
_ adapter.OutboundUseIP = (*Socks)(nil)
)
type Socks struct {
myOutboundAdapter
@ -128,3 +131,7 @@ func (h *Socks) NewPacketConnection(ctx context.Context, conn N.PacketConn, meta
return NewPacketConnection(ctx, h, conn, metadata)
}
}
func (h *Socks) UseIP() bool {
return h.resolve
}

View File

@ -111,6 +111,20 @@ func (s *URLTest) Now() string {
return ""
}
func (s *URLTest) SelectedOutbound(network string) adapter.Outbound {
switch network {
case N.NetworkTCP:
if s.group.selectedOutboundTCP != nil {
return s.group.selectedOutboundTCP
}
case N.NetworkUDP:
if s.group.selectedOutboundUDP != nil {
return s.group.selectedOutboundUDP
}
}
return s.group.outbounds[0]
}
func (s *URLTest) All() []string {
return s.tags
}

View File

@ -32,6 +32,7 @@ import (
var (
_ adapter.Outbound = (*WireGuard)(nil)
_ adapter.OutboundUseIP = (*WireGuard)(nil)
_ adapter.InterfaceUpdateListener = (*WireGuard)(nil)
)
@ -241,3 +242,7 @@ func (w *WireGuard) NewConnection(ctx context.Context, conn net.Conn, metadata a
func (w *WireGuard) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
return NewDirectPacketConnection(ctx, w.router, w, conn, metadata, dns.DomainStrategyAsIS)
}
func (w *WireGuard) UseIP() bool {
return true
}

View File

@ -24,7 +24,7 @@ import (
"github.com/sagernet/sing-box/experimental/libbox/platform"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/outbound"
O "github.com/sagernet/sing-box/outbound"
"github.com/sagernet/sing-box/transport/fakeip"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing-mux"
@ -1113,12 +1113,12 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m
func (r *Router) match(ctx context.Context, metadata *adapter.InboundContext, defaultOutbound adapter.Outbound) (context.Context, adapter.Rule, adapter.Outbound, error) {
matchRule, matchOutbound := r.match0(ctx, metadata, defaultOutbound)
if contextOutbound, loaded := outbound.TagFromContext(ctx); loaded {
if contextOutbound, loaded := O.TagFromContext(ctx); loaded {
if contextOutbound == matchOutbound.Tag() {
return nil, nil, nil, E.New("connection loopback in outbound/", matchOutbound.Type(), "[", matchOutbound.Tag(), "]")
}
}
ctx = outbound.ContextWithTag(ctx, matchOutbound.Tag())
ctx = O.ContextWithTag(ctx, matchOutbound.Tag())
return ctx, matchRule, matchOutbound, nil
}
@ -1154,18 +1154,49 @@ func (r *Router) match0(ctx context.Context, metadata *adapter.InboundContext, d
metadata.ProcessInfo = processInfo
}
}
resolveStatus := -1
if metadata.Destination.IsFqdn() && len(metadata.DestinationAddresses) == 0 {
resolveStatus = 0
}
var outbound adapter.Outbound
defer func() {
if resolveStatus == 1 && !r.mustUseIP(outbound, metadata.Network) {
metadata.DestinationAddresses = []netip.Addr{}
}
}()
for i, rule := range r.rules {
metadata.ResetRuleCache()
if !rule.SkipResolve() && resolveStatus == 0 && rule.ContainsDestinationIPCIDRRule() {
addresses, err := r.LookupDefault(adapter.WithContext(ctx, metadata), metadata.Destination.Fqdn)
resolveStatus = 2
if err == nil {
resolveStatus = 1
metadata.DestinationAddresses = addresses
}
metadata.ResetRuleCache()
}
if rule.Match(metadata) {
detour := rule.Outbound()
r.logger.DebugContext(ctx, "match[", i, "] ", rule.String(), " => ", detour)
if outbound, loaded := r.Outbound(detour); loaded {
var loaded bool
if outbound, loaded = r.Outbound(detour); loaded {
return rule, outbound
}
r.logger.ErrorContext(ctx, "outbound not found: ", detour)
}
}
return nil, defaultOutbound
outbound = defaultOutbound
return nil, outbound
}
func (r *Router) mustUseIP(outbound adapter.Outbound, network string) bool {
tag := O.RealOutboundTag(outbound, network)
detour, _ := r.Outbound(tag)
d, ok := detour.(adapter.OutboundUseIP)
if !ok {
return false
}
return d.UseIP()
}
func (r *Router) InterfaceFinder() control.InterfaceFinder {

View File

@ -18,9 +18,10 @@ type abstractDefaultRule struct {
destinationIPCIDRItems []RuleItem
destinationPortItems []RuleItem
allItems []RuleItem
ruleSetItem RuleItem
ruleSetItems []RuleItem
ruleCount uint64
invert bool
skipResolve bool
outbound string
}
@ -32,6 +33,17 @@ func (r *abstractDefaultRule) RuleCount() uint64 {
return r.ruleCount
}
func (r *abstractDefaultRule) SkipResolve() bool {
return r.skipResolve
}
func (r *abstractDefaultRule) ContainsDestinationIPCIDRRule() bool {
return len(r.destinationIPCIDRItems) > 0 || common.Any(r.ruleSetItems, func(it RuleItem) bool {
r, _ := it.(*RuleSetItem)
return r.ContainsDestinationIPCIDRRule()
})
}
func (r *abstractDefaultRule) Start() error {
for _, item := range r.allItems {
if starter, isStarter := item.(interface {
@ -171,6 +183,7 @@ type abstractLogicalRule struct {
rules []adapter.HeadlessRule
mode string
invert bool
skipResolve bool
outbound string
ruleCount uint64
}
@ -183,6 +196,16 @@ func (r *abstractLogicalRule) RuleCount() uint64 {
return r.ruleCount
}
func (r *abstractLogicalRule) SkipResolve() bool {
return r.skipResolve
}
func (r *abstractLogicalRule) ContainsDestinationIPCIDRRule() bool {
return common.Any(r.rules, func(it adapter.HeadlessRule) bool {
return it.ContainsDestinationIPCIDRRule()
})
}
func (r *abstractLogicalRule) UpdateGeosite() error {
for _, rule := range common.FilterIsInstance(r.rules, func(it adapter.HeadlessRule) (adapter.Rule, bool) {
rule, loaded := it.(adapter.Rule)

View File

@ -46,6 +46,7 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt
rule := &DefaultRule{
abstractDefaultRule{
invert: options.Invert,
skipResolve: options.SkipResolve,
outbound: options.Outbound,
},
}
@ -220,6 +221,7 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt
if len(options.RuleSet) > 0 {
item := NewRuleSetItem(router, options.RuleSet, options.RuleSetIPCIDRMatchSource, false)
rule.items = append(rule.items, item)
rule.ruleSetItems = append(rule.ruleSetItems, item)
rule.allItems = append(rule.allItems, item)
}
return rule, nil
@ -236,6 +238,7 @@ func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options opt
abstractLogicalRule{
rules: make([]adapter.HeadlessRule, len(options.Rules)),
invert: options.Invert,
skipResolve: options.SkipResolve,
outbound: options.Outbound,
},
}

View File

@ -229,6 +229,7 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options
if len(options.RuleSet) > 0 {
item := NewRuleSetItem(router, options.RuleSet, options.RuleSetIPCIDRMatchSource, options.RuleSetIPCIDRAcceptEmpty)
rule.items = append(rule.items, item)
rule.ruleSetItems = append(rule.ruleSetItems, item)
rule.allItems = append(rule.allItems, item)
}
return rule, nil

View File

@ -52,6 +52,10 @@ func (s *abstractRuleSet) RuleCount() uint64 {
return s.ruleCount
}
func (s *abstractRuleSet) ContainsDestinationIPCIDRRule() bool {
return s.metadata.ContainsIPCIDRRule
}
func (s *abstractRuleSet) UpdatedTime() time.Time {
return s.lastUpdated
}