From 8c893f518023308fdd7934b5f61a918d47ceb421 Mon Sep 17 00:00:00 2001 From: ashly-right Date: Sun, 10 Mar 2024 04:58:00 +0100 Subject: [PATCH] Implementing select random outbound --- option/group.go | 1 + outbound/urltest.go | 163 +++++++++++++++++++++++++++++++++++--------- 2 files changed, 130 insertions(+), 34 deletions(-) diff --git a/option/group.go b/option/group.go index 72a0f637..8aece458 100644 --- a/option/group.go +++ b/option/group.go @@ -13,4 +13,5 @@ type URLTestOutboundOptions struct { Tolerance uint16 `json:"tolerance,omitempty"` IdleTimeout Duration `json:"idle_timeout,omitempty"` InterruptExistConnections bool `json:"interrupt_exist_connections,omitempty"` + Randomized bool `json:"randomized,omitempty"` } diff --git a/outbound/urltest.go b/outbound/urltest.go index aa7cff6c..411f0eaf 100644 --- a/outbound/urltest.go +++ b/outbound/urltest.go @@ -2,6 +2,7 @@ package outbound import ( "context" + "math/rand" "net" "sync" "time" @@ -38,6 +39,7 @@ type URLTest struct { idleTimeout time.Duration group *URLTestGroup interruptExternalConnections bool + randomized bool } func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.URLTestOutboundOptions) (*URLTest, error) { @@ -57,6 +59,7 @@ func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLo tolerance: options.Tolerance, idleTimeout: time.Duration(options.IdleTimeout), interruptExternalConnections: options.InterruptExistConnections, + randomized: options.Randomized, } if len(outbound.tags) == 0 { return nil, E.New("missing tags") @@ -83,6 +86,7 @@ func (s *URLTest) Start() error { s.tolerance, s.idleTimeout, s.interruptExternalConnections, + s.randomized, ) if err != nil { return err @@ -126,16 +130,20 @@ func (s *URLTest) CheckOutbounds() { func (s *URLTest) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { s.group.Touch() var outbound adapter.Outbound - switch N.NetworkName(network) { - case N.NetworkTCP: - outbound = s.group.selectedOutboundTCP - case N.NetworkUDP: - outbound = s.group.selectedOutboundUDP - default: - return nil, E.Extend(N.ErrUnknownNetwork, network) - } - if outbound == nil { - outbound, _ = s.group.Select(network) + if s.randomized { + outbound = s.group.selectRandomOutbound(network) + } else { + switch N.NetworkName(network) { + case N.NetworkTCP: + outbound = s.group.selectedOutboundTCP + case N.NetworkUDP: + outbound = s.group.selectedOutboundUDP + default: + return nil, E.Extend(N.ErrUnknownNetwork, network) + } + if outbound == nil { + outbound, _ = s.group.Select(network) + } } if outbound == nil { return nil, E.New("missing supported outbound") @@ -151,9 +159,14 @@ func (s *URLTest) DialContext(ctx context.Context, network string, destination M func (s *URLTest) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { s.group.Touch() - outbound := s.group.selectedOutboundUDP - if outbound == nil { - outbound, _ = s.group.Select(N.NetworkUDP) + var outbound adapter.Outbound + if s.randomized { + outbound = s.group.selectRandomOutbound(N.NetworkUDP) // Since ListenPacket is for UDP, we pass "N.NetworkUDP" as the network type + } else { + outbound = s.group.selectedOutboundUDP + if outbound == nil { + outbound, _ = s.group.Select(N.NetworkUDP) + } } if outbound == nil { return nil, E.New("missing supported outbound") @@ -196,9 +209,12 @@ type URLTestGroup struct { pauseManager pause.Manager selectedOutboundTCP adapter.Outbound selectedOutboundUDP adapter.Outbound + randomized bool + bestTCPLatencyOutbounds []adapter.Outbound + bestUDPLatencyOutbounds []adapter.Outbound interruptGroup *interrupt.Group interruptExternalConnections bool - + access sync.Mutex ticker *time.Ticker close chan struct{} @@ -216,6 +232,7 @@ func NewURLTestGroup( tolerance uint16, idleTimeout time.Duration, interruptExternalConnections bool, + randomized bool, ) (*URLTestGroup, error) { if interval == 0 { interval = C.DefaultURLTestInterval @@ -250,6 +267,7 @@ func NewURLTestGroup( pauseManager: service.FromContext[pause.Manager](ctx), interruptGroup: interrupt.NewGroup(), interruptExternalConnections: interruptExternalConnections, + randomized: randomized, }, nil } @@ -330,26 +348,29 @@ func (g *URLTestGroup) Select(network string) (adapter.Outbound, bool) { } func (g *URLTestGroup) loopCheck() { - if time.Now().Sub(g.lastActive.Load()) > g.interval { - g.lastActive.Store(time.Now()) - g.CheckOutbounds(false) - } - for { - select { - case <-g.close: - return - case <-g.ticker.C: + if time.Now().Sub(g.lastActive.Load()) > g.interval { + g.lastActive.Store(time.Now()) + g.CheckOutbounds(false) + } + for { + select { + case <-g.close: + return + case <-g.ticker.C: + } + if time.Now().Sub(g.lastActive.Load()) > g.idleTimeout { + g.access.Lock() + g.ticker.Stop() + g.ticker = nil + g.access.Unlock() + return + } + g.pauseManager.WaitActive() + g.CheckOutbounds(false) + if g.randomized { + g.selectBestLatencyOutbounds() } - if time.Now().Sub(g.lastActive.Load()) > g.idleTimeout { - g.access.Lock() - g.ticker.Stop() - g.ticker = nil - g.access.Unlock() - return - } - g.pauseManager.WaitActive() - g.CheckOutbounds(false) - } + } } func (g *URLTestGroup) CheckOutbounds(force bool) { @@ -357,7 +378,15 @@ func (g *URLTestGroup) CheckOutbounds(force bool) { } func (g *URLTestGroup) URLTest(ctx context.Context) (map[string]uint16, error) { - return g.urlTest(ctx, false) + result, err := g.urlTest(ctx, false) + if err != nil { + return nil, err + } + + if g.randomized { + g.selectBestLatencyOutbounds() + } + return result, nil } func (g *URLTestGroup) urlTest(ctx context.Context, force bool) (map[string]uint16, error) { @@ -423,3 +452,69 @@ func (g *URLTestGroup) performUpdateCheck() { g.interruptGroup.Interrupt(g.interruptExternalConnections) } } + +func (g *URLTestGroup) selectBestLatencyOutbounds() { + var bestTCPLatency uint16 + var bestUDPLatency uint16 + + var bestTCPOutbounds []adapter.Outbound + var bestUDPOutbounds []adapter.Outbound + + for _, detour := range g.outbounds { + history := g.history.LoadURLTestHistory(RealTag(detour)) + if history == nil { + continue + } + + if common.Contains(detour.Network(), N.NetworkTCP) { + if bestTCPLatency == 0 || history.Delay < bestTCPLatency { + bestTCPLatency = history.Delay + } + } else if common.Contains(detour.Network(), N.NetworkUDP) { + if bestUDPLatency == 0 || history.Delay < bestUDPLatency { + bestUDPLatency = history.Delay + } + } + } + + for _, detour := range g.outbounds { + history := g.history.LoadURLTestHistory(RealTag(detour)) + if history == nil { + continue + } + + if common.Contains(detour.Network(), N.NetworkTCP) && history.Delay <= bestTCPLatency+g.tolerance { + bestTCPOutbounds = append(bestTCPOutbounds, detour) + } else if common.Contains(detour.Network(), N.NetworkUDP) && history.Delay <= bestUDPLatency+g.tolerance { + bestUDPOutbounds = append(bestUDPOutbounds, detour) + } + } + + g.bestTCPLatencyOutbounds = bestTCPOutbounds + g.bestUDPLatencyOutbounds = bestUDPOutbounds +} + +// selectRandomOutbound selects an outbound randomly among the outbounds with the best latency +func (g *URLTestGroup) selectRandomOutbound(network string) adapter.Outbound { + var bestOutbounds []adapter.Outbound + + switch network { + case N.NetworkTCP: + bestOutbounds = g.bestTCPLatencyOutbounds + case N.NetworkUDP: + bestOutbounds = g.bestUDPLatencyOutbounds + default: + return nil + } + + if len(bestOutbounds) == 0 { + return nil + } + + randIndex := rand.Intn(len(bestOutbounds)) + g.logger.Debug("Random outbound selection: ", bestOutbounds[randIndex].Tag()) + + return bestOutbounds[randIndex] +} + +