diff --git a/adapter/router.go b/adapter/router.go index 6bf30589..c96e998f 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -6,8 +6,8 @@ import ( "net/netip" "github.com/sagernet/sing-box/common/geoip" - "github.com/sagernet/sing-dns" - "github.com/sagernet/sing-tun" + dns "github.com/sagernet/sing-dns" + tun "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common/control" N "github.com/sagernet/sing/common/network" @@ -19,6 +19,8 @@ type Router interface { Outbounds() []Outbound Outbound(tag string) (Outbound, bool) + AddOutbound(Outbound) + RemoveOutbound(string) DefaultOutbound(network string) Outbound RouteConnection(ctx context.Context, conn net.Conn, metadata InboundContext) error diff --git a/route/router.go b/route/router.go index 4cc725a7..a2d18ecc 100644 --- a/route/router.go +++ b/route/router.go @@ -24,9 +24,9 @@ import ( C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-dns" - "github.com/sagernet/sing-tun" - "github.com/sagernet/sing-vmess" + dns "github.com/sagernet/sing-dns" + tun "github.com/sagernet/sing-tun" + vmess "github.com/sagernet/sing-vmess" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" @@ -67,8 +67,7 @@ type Router struct { logger log.ContextLogger dnsLogger log.ContextLogger inboundByTag map[string]adapter.Inbound - outbounds []adapter.Outbound - outboundByTag map[string]adapter.Outbound + outbounds *outboundsManager rules []adapter.Rule defaultDetour string defaultOutboundForConnection adapter.Outbound @@ -114,7 +113,7 @@ func NewRouter(ctx context.Context, logger log.ContextLogger, dnsLogger log.Cont ctx: ctx, logger: logger, dnsLogger: dnsLogger, - outboundByTag: make(map[string]adapter.Outbound), + outbounds: newOutboundsManager(), rules: make([]adapter.Rule, 0, len(options.Rules)), dnsRules: make([]adapter.DNSRule, 0, len(dnsOptions.Rules)), needGeoIPDatabase: hasRule(options.Rules, isGeoIPRule) || hasDNSRule(dnsOptions.Rules, isGeoIPDNSRule), @@ -311,14 +310,13 @@ func (r *Router) Initialize(inbounds []adapter.Inbound, outbounds []adapter.Outb for _, inbound := range inbounds { inboundByTag[inbound.Tag()] = inbound } - outboundByTag := make(map[string]adapter.Outbound) for _, detour := range outbounds { - outboundByTag[detour.Tag()] = detour + r.outbounds.Add(detour) } var defaultOutboundForConnection adapter.Outbound var defaultOutboundForPacketConnection adapter.Outbound if r.defaultDetour != "" { - detour, loaded := outboundByTag[r.defaultDetour] + detour, loaded := r.outbounds.Get(r.defaultDetour) if !loaded { return E.New("default detour not found: ", r.defaultDetour) } @@ -357,7 +355,7 @@ func (r *Router) Initialize(inbounds []adapter.Inbound, outbounds []adapter.Outb defaultOutboundForPacketConnection = detour } outbounds = append(outbounds, detour) - outboundByTag[detour.Tag()] = detour + r.outbounds.Add(detour) } if defaultOutboundForConnection != defaultOutboundForPacketConnection { var description string @@ -376,12 +374,10 @@ func (r *Router) Initialize(inbounds []adapter.Inbound, outbounds []adapter.Outb r.logger.Info("using ", defaultOutboundForPacketConnection.Type(), "[", packetDescription, "] as default outbound for packet connection") } r.inboundByTag = inboundByTag - r.outbounds = outbounds r.defaultOutboundForConnection = defaultOutboundForConnection r.defaultOutboundForPacketConnection = defaultOutboundForPacketConnection - r.outboundByTag = outboundByTag for i, rule := range r.rules { - if _, loaded := outboundByTag[rule.Outbound()]; !loaded { + if _, loaded := r.outbounds.Get(rule.Outbound()); !loaded { return E.New("outbound not found for rule[", i, "]: ", rule.Outbound()) } } @@ -389,7 +385,7 @@ func (r *Router) Initialize(inbounds []adapter.Inbound, outbounds []adapter.Outb } func (r *Router) Outbounds() []adapter.Outbound { - return r.outbounds + return r.outbounds.All() } func (r *Router) Start() error { @@ -499,10 +495,15 @@ func (r *Router) LoadGeosite(code string) (adapter.Rule, error) { r.geositeCache[code] = rule return rule, nil } +func (r *Router) AddOutbound(o adapter.Outbound) { + r.outbounds.Add(o) +} +func (r *Router) RemoveOutbound(tag string) { + r.outbounds.Remove(tag) +} func (r *Router) Outbound(tag string) (adapter.Outbound, bool) { - outbound, loaded := r.outboundByTag[tag] - return outbound, loaded + return r.outbounds.Get(tag) } func (r *Router) DefaultOutbound(network string) adapter.Outbound { diff --git a/route/router_outbounds.go b/route/router_outbounds.go new file mode 100644 index 00000000..54e6a116 --- /dev/null +++ b/route/router_outbounds.go @@ -0,0 +1,81 @@ +package route + +import ( + "sync" + + "github.com/sagernet/sing-box/adapter" +) + +// outboundsManager is the thread-safe outbound manager. +type outboundsManager struct { + sync.RWMutex + + tags []string // tags keeps the order of outbounds + all map[string]adapter.Outbound +} + +func newOutboundsManager() *outboundsManager { + return &outboundsManager{ + all: make(map[string]adapter.Outbound), + } +} + +func (o *outboundsManager) Add(outbound adapter.Outbound) { + o.Lock() + defer o.Unlock() + + tag := outbound.Tag() + if _, ok := o.all[tag]; ok { + // update and return + o.all[tag] = outbound + return + } + + o.all[tag] = outbound + o.tags = append(o.tags, tag) +} + +func (o *outboundsManager) Remove(tag string) { + o.Lock() + defer o.Unlock() + + if _, ok := o.all[tag]; !ok { + return + } + delete(o.all, tag) + o.tags = findDeleteElement(o.tags, tag) +} + +func (o *outboundsManager) Get(tag string) (adapter.Outbound, bool) { + o.RLock() + defer o.RUnlock() + + outbound, ok := o.all[tag] + return outbound, ok +} + +func (o *outboundsManager) All() []adapter.Outbound { + o.RLock() + defer o.RUnlock() + + all := make([]adapter.Outbound, 0, len(o.tags)) + for _, tag := range o.tags { + all = append(all, o.all[tag]) + } + return all +} + +func findDeleteElement[T comparable](slice []T, element T) []T { + idx := -1 + for i := 0; i < len(slice); i++ { + if slice[i] == element { + idx = i + break + } + } + if idx < 0 { + return slice + } + copy(slice[idx:], slice[idx+1:]) + return slice[:len(slice)-1] +}