Improve loopback detector

This commit is contained in:
世界 2024-04-12 09:24:49 +08:00
parent 64a05a27a2
commit d9f2d31147
No known key found for this signature in database
GPG Key ID: CD109927C34A63C4
7 changed files with 48 additions and 92 deletions

View File

@ -9,7 +9,6 @@ import (
"github.com/sagernet/sing-box" "github.com/sagernet/sing-box"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/process" "github.com/sagernet/sing-box/common/process"
"github.com/sagernet/sing-box/experimental/libbox/platform"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-tun" "github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/control" "github.com/sagernet/sing/common/control"
@ -75,7 +74,7 @@ func (s *platformInterfaceStub) UsePlatformInterfaceGetter() bool {
return true return true
} }
func (s *platformInterfaceStub) Interfaces() ([]platform.NetworkInterface, error) { func (s *platformInterfaceStub) Interfaces() ([]control.Interface, error) {
return nil, os.ErrInvalid return nil, os.ErrInvalid
} }

View File

@ -2,7 +2,6 @@ package platform
import ( import (
"context" "context"
"net/netip"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/process" "github.com/sagernet/sing-box/common/process"
@ -20,16 +19,9 @@ type Interface interface {
UsePlatformDefaultInterfaceMonitor() bool UsePlatformDefaultInterfaceMonitor() bool
CreateDefaultInterfaceMonitor(logger logger.Logger) tun.DefaultInterfaceMonitor CreateDefaultInterfaceMonitor(logger logger.Logger) tun.DefaultInterfaceMonitor
UsePlatformInterfaceGetter() bool UsePlatformInterfaceGetter() bool
Interfaces() ([]NetworkInterface, error) Interfaces() ([]control.Interface, error)
UnderNetworkExtension() bool UnderNetworkExtension() bool
ClearDNSCache() ClearDNSCache()
ReadWIFIState() adapter.WIFIState ReadWIFIState() adapter.WIFIState
process.Searcher process.Searcher
} }
type NetworkInterface struct {
Index int
MTU int
Name string
Addresses []netip.Prefix
}

View File

@ -192,14 +192,14 @@ func (w *platformInterfaceWrapper) UsePlatformInterfaceGetter() bool {
return w.iif.UsePlatformInterfaceGetter() return w.iif.UsePlatformInterfaceGetter()
} }
func (w *platformInterfaceWrapper) Interfaces() ([]platform.NetworkInterface, error) { func (w *platformInterfaceWrapper) Interfaces() ([]control.Interface, error) {
interfaceIterator, err := w.iif.GetInterfaces() interfaceIterator, err := w.iif.GetInterfaces()
if err != nil { if err != nil {
return nil, err return nil, err
} }
var interfaces []platform.NetworkInterface var interfaces []control.Interface
for _, netInterface := range iteratorToArray[*NetworkInterface](interfaceIterator) { for _, netInterface := range iteratorToArray[*NetworkInterface](interfaceIterator) {
interfaces = append(interfaces, platform.NetworkInterface{ interfaces = append(interfaces, control.Interface{
Index: int(netInterface.Index), Index: int(netInterface.Index),
MTU: int(netInterface.MTU), MTU: int(netInterface.MTU),
Name: netInterface.Name, Name: netInterface.Name,

View File

@ -51,7 +51,7 @@ func NewDirect(router adapter.Router, logger log.ContextLogger, tag string, opti
domainStrategy: dns.DomainStrategy(options.DomainStrategy), domainStrategy: dns.DomainStrategy(options.DomainStrategy),
fallbackDelay: time.Duration(options.FallbackDelay), fallbackDelay: time.Duration(options.FallbackDelay),
dialer: outboundDialer, dialer: outboundDialer,
loopBack: newLoopBackDetector(), loopBack: newLoopBackDetector(router),
} }
if options.ProxyProtocol != 0 { if options.ProxyProtocol != 0 {
return nil, E.New("Proxy Protocol is deprecated and removed in sing-box 1.6.0") return nil, E.New("Proxy Protocol is deprecated and removed in sing-box 1.6.0")

View File

@ -5,21 +5,24 @@ import (
"net/netip" "net/netip"
"sync" "sync"
"github.com/sagernet/sing-box/adapter"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
) )
type loopBackDetector struct { type loopBackDetector struct {
router adapter.Router
connAccess sync.RWMutex connAccess sync.RWMutex
packetConnAccess sync.RWMutex packetConnAccess sync.RWMutex
connMap map[netip.AddrPort]bool connMap map[netip.AddrPort]bool
packetConnMap map[netip.AddrPort]bool packetConnMap map[uint16]bool
} }
func newLoopBackDetector() *loopBackDetector { func newLoopBackDetector(router adapter.Router) *loopBackDetector {
return &loopBackDetector{ return &loopBackDetector{
router: router,
connMap: make(map[netip.AddrPort]bool), connMap: make(map[netip.AddrPort]bool),
packetConnMap: make(map[netip.AddrPort]bool), packetConnMap: make(map[uint16]bool),
} }
} }
@ -29,10 +32,16 @@ func (l *loopBackDetector) NewConn(conn net.Conn) net.Conn {
return conn return conn
} }
if udpConn, isUDPConn := conn.(abstractUDPConn); isUDPConn { if udpConn, isUDPConn := conn.(abstractUDPConn); isUDPConn {
if !connAddr.Addr().IsLoopback() {
_, err := l.router.InterfaceFinder().InterfaceByAddr(connAddr.Addr())
if err != nil {
return conn
}
}
l.packetConnAccess.Lock() l.packetConnAccess.Lock()
l.packetConnMap[connAddr] = true l.packetConnMap[connAddr.Port()] = true
l.packetConnAccess.Unlock() l.packetConnAccess.Unlock()
return &loopBackDetectUDPWrapper{abstractUDPConn: udpConn, detector: l, connAddr: connAddr} return &loopBackDetectUDPWrapper{abstractUDPConn: udpConn, detector: l, connPort: connAddr.Port()}
} else { } else {
l.connAccess.Lock() l.connAccess.Lock()
l.connMap[connAddr] = true l.connMap[connAddr] = true
@ -46,10 +55,16 @@ func (l *loopBackDetector) NewPacketConn(conn N.NetPacketConn) N.NetPacketConn {
if !connAddr.IsValid() { if !connAddr.IsValid() {
return conn return conn
} }
if !connAddr.Addr().IsLoopback() {
_, err := l.router.InterfaceFinder().InterfaceByAddr(connAddr.Addr())
if err != nil {
return conn
}
}
l.packetConnAccess.Lock() l.packetConnAccess.Lock()
l.packetConnMap[connAddr] = true l.packetConnMap[connAddr.Port()] = true
l.packetConnAccess.Unlock() l.packetConnAccess.Unlock()
return &loopBackDetectPacketWrapper{NetPacketConn: conn, detector: l, connAddr: connAddr} return &loopBackDetectPacketWrapper{NetPacketConn: conn, detector: l, connPort: connAddr.Port()}
} }
func (l *loopBackDetector) CheckConn(connAddr netip.AddrPort) bool { func (l *loopBackDetector) CheckConn(connAddr netip.AddrPort) bool {
@ -59,9 +74,18 @@ func (l *loopBackDetector) CheckConn(connAddr netip.AddrPort) bool {
} }
func (l *loopBackDetector) CheckPacketConn(connAddr netip.AddrPort) bool { func (l *loopBackDetector) CheckPacketConn(connAddr netip.AddrPort) bool {
if !connAddr.IsValid() || !connAddr.Addr().IsLoopback() {
return false
}
if !connAddr.Addr().IsLoopback() {
_, err := l.router.InterfaceFinder().InterfaceByAddr(connAddr.Addr())
if err != nil {
return false
}
}
l.packetConnAccess.RLock() l.packetConnAccess.RLock()
defer l.packetConnAccess.RUnlock() defer l.packetConnAccess.RUnlock()
return l.packetConnMap[connAddr] return l.packetConnMap[connAddr.Port()]
} }
type loopBackDetectWrapper struct { type loopBackDetectWrapper struct {
@ -95,14 +119,14 @@ func (w *loopBackDetectWrapper) Upstream() any {
type loopBackDetectPacketWrapper struct { type loopBackDetectPacketWrapper struct {
N.NetPacketConn N.NetPacketConn
detector *loopBackDetector detector *loopBackDetector
connAddr netip.AddrPort connPort uint16
closeOnce sync.Once closeOnce sync.Once
} }
func (w *loopBackDetectPacketWrapper) Close() error { func (w *loopBackDetectPacketWrapper) Close() error {
w.closeOnce.Do(func() { w.closeOnce.Do(func() {
w.detector.packetConnAccess.Lock() w.detector.packetConnAccess.Lock()
delete(w.detector.packetConnMap, w.connAddr) delete(w.detector.packetConnMap, w.connPort)
w.detector.packetConnAccess.Unlock() w.detector.packetConnAccess.Unlock()
}) })
return w.NetPacketConn.Close() return w.NetPacketConn.Close()
@ -128,14 +152,14 @@ type abstractUDPConn interface {
type loopBackDetectUDPWrapper struct { type loopBackDetectUDPWrapper struct {
abstractUDPConn abstractUDPConn
detector *loopBackDetector detector *loopBackDetector
connAddr netip.AddrPort connPort uint16
closeOnce sync.Once closeOnce sync.Once
} }
func (w *loopBackDetectUDPWrapper) Close() error { func (w *loopBackDetectUDPWrapper) Close() error {
w.closeOnce.Do(func() { w.closeOnce.Do(func() {
w.detector.packetConnAccess.Lock() w.detector.packetConnAccess.Lock()
delete(w.detector.packetConnMap, w.connAddr) delete(w.detector.packetConnMap, w.connPort)
w.detector.packetConnAccess.Unlock() w.detector.packetConnAccess.Unlock()
}) })
return w.abstractUDPConn.Close() return w.abstractUDPConn.Close()

View File

@ -1,54 +0,0 @@
package route
import (
"net"
"github.com/sagernet/sing/common/control"
)
var _ control.InterfaceFinder = (*myInterfaceFinder)(nil)
type myInterfaceFinder struct {
interfaces []net.Interface
}
func (f *myInterfaceFinder) update() error {
ifs, err := net.Interfaces()
if err != nil {
return err
}
f.interfaces = ifs
return nil
}
func (f *myInterfaceFinder) updateInterfaces(interfaces []net.Interface) {
f.interfaces = interfaces
}
func (f *myInterfaceFinder) InterfaceIndexByName(name string) (interfaceIndex int, err error) {
for _, netInterface := range f.interfaces {
if netInterface.Name == name {
return netInterface.Index, nil
}
}
netInterface, err := net.InterfaceByName(name)
if err != nil {
return
}
f.update()
return netInterface.Index, nil
}
func (f *myInterfaceFinder) InterfaceNameByIndex(index int) (interfaceName string, err error) {
for _, netInterface := range f.interfaces {
if netInterface.Index == index {
return netInterface.Name, nil
}
}
netInterface, err := net.InterfaceByIndex(index)
if err != nil {
return
}
f.update()
return netInterface.Name, nil
}

View File

@ -79,7 +79,7 @@ type Router struct {
transportDomainStrategy map[dns.Transport]dns.DomainStrategy transportDomainStrategy map[dns.Transport]dns.DomainStrategy
dnsReverseMapping *DNSReverseMapping dnsReverseMapping *DNSReverseMapping
fakeIPStore adapter.FakeIPStore fakeIPStore adapter.FakeIPStore
interfaceFinder myInterfaceFinder interfaceFinder *control.DefaultInterfaceFinder
autoDetectInterface bool autoDetectInterface bool
defaultInterface string defaultInterface string
defaultMark int defaultMark int
@ -124,6 +124,7 @@ func NewRouter(
needFindProcess: hasRule(options.Rules, isProcessRule) || hasDNSRule(dnsOptions.Rules, isProcessDNSRule) || options.FindProcess, needFindProcess: hasRule(options.Rules, isProcessRule) || hasDNSRule(dnsOptions.Rules, isProcessDNSRule) || options.FindProcess,
defaultDetour: options.Final, defaultDetour: options.Final,
defaultDomainStrategy: dns.DomainStrategy(dnsOptions.Strategy), defaultDomainStrategy: dns.DomainStrategy(dnsOptions.Strategy),
interfaceFinder: control.NewDefaultInterfaceFinder(),
autoDetectInterface: options.AutoDetectInterface, autoDetectInterface: options.AutoDetectInterface,
defaultInterface: options.DefaultInterface, defaultInterface: options.DefaultInterface,
defaultMark: options.DefaultMark, defaultMark: options.DefaultMark,
@ -305,7 +306,7 @@ func NewRouter(
} }
router.networkMonitor = networkMonitor router.networkMonitor = networkMonitor
networkMonitor.RegisterCallback(func() { networkMonitor.RegisterCallback(func() {
_ = router.interfaceFinder.update() _ = router.interfaceFinder.Update()
}) })
interfaceMonitor, err := tun.NewDefaultInterfaceMonitor(router.networkMonitor, router.logger, tun.DefaultInterfaceMonitorOptions{ interfaceMonitor, err := tun.NewDefaultInterfaceMonitor(router.networkMonitor, router.logger, tun.DefaultInterfaceMonitorOptions{
OverrideAndroidVPN: options.OverrideAndroidVPN, OverrideAndroidVPN: options.OverrideAndroidVPN,
@ -1063,24 +1064,18 @@ func (r *Router) match0(ctx context.Context, metadata *adapter.InboundContext, d
} }
func (r *Router) InterfaceFinder() control.InterfaceFinder { func (r *Router) InterfaceFinder() control.InterfaceFinder {
return &r.interfaceFinder return r.interfaceFinder
} }
func (r *Router) UpdateInterfaces() error { func (r *Router) UpdateInterfaces() error {
if r.platformInterface == nil || !r.platformInterface.UsePlatformInterfaceGetter() { if r.platformInterface == nil || !r.platformInterface.UsePlatformInterfaceGetter() {
return r.interfaceFinder.update() return r.interfaceFinder.Update()
} else { } else {
interfaces, err := r.platformInterface.Interfaces() interfaces, err := r.platformInterface.Interfaces()
if err != nil { if err != nil {
return err return err
} }
r.interfaceFinder.updateInterfaces(common.Map(interfaces, func(it platform.NetworkInterface) net.Interface { r.interfaceFinder.UpdateInterfaces(interfaces)
return net.Interface{
Name: it.Name,
Index: it.Index,
MTU: it.MTU,
}
}))
return nil return nil
} }
} }