fetch health check nodes in time from the router by prefix ...

* Balancer & HealthCheck not hold adapter.Outbound, but only tags
* rename and unexport some structs and fields
* fix no check in the first rounds
This commit is contained in:
jebbs 2022-10-11 17:54:39 +08:00
parent 6d417949ae
commit a2d428f246
13 changed files with 182 additions and 148 deletions

View File

@ -6,6 +6,7 @@ import (
"sort" "sort"
"time" "time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
) )
@ -14,7 +15,7 @@ var _ Balancer = (*rttBasedBalancer)(nil)
// Balancer is interface for load balancers // Balancer is interface for load balancers
type Balancer interface { type Balancer interface {
Pick() *Node Pick() string
} }
type rttBasedBalancer struct { type rttBasedBalancer struct {
@ -30,15 +31,14 @@ type rttFunc func(node *Node) time.Duration
// newRTTBasedLoad creates a new rtt based load balancer // newRTTBasedLoad creates a new rtt based load balancer
func newRTTBasedBalancer( func newRTTBasedBalancer(
nodes []*Node, logger log.ContextLogger, router adapter.Router, logger log.ContextLogger,
options option.BalancerOutboundOptions, options option.BalancerOutboundOptions,
rttFunc rttFunc, rttFunc rttFunc,
) (Balancer, error) { ) (Balancer, error) {
return &rttBasedBalancer{ return &rttBasedBalancer{
nodes: nodes,
rttFunc: rttFunc, rttFunc: rttFunc,
options: &options, options: &options,
HealthCheck: NewHealthCheck(nodes, logger, &options.Check), HealthCheck: NewHealthCheck(router, options.Outbounds, logger, &options.Check),
costs: NewWeightManager( costs: NewWeightManager(
logger, options.Pick.Costs, 1, logger, options.Pick.Costs, 1,
func(value, cost float64) float64 { func(value, cost float64) float64 {
@ -49,28 +49,33 @@ func newRTTBasedBalancer(
} }
// Select selects qualified nodes // Select selects qualified nodes
func (s *rttBasedBalancer) Pick() *Node { func (s *rttBasedBalancer) Pick() string {
nodes := s.HealthCheck.NodesByCategory() nodes := s.HealthCheck.NodesByCategory()
var candidates []*Node var candidates []*Node
if len(nodes.Qualified) > 0 { if len(nodes.Qualified) > 0 {
candidates = nodes.Qualified candidates = nodes.Qualified
for _, node := range candidates { for _, node := range candidates {
node.Weighted = time.Duration(s.costs.Apply(node.Outbound.Tag(), float64(s.rttFunc(node)))) node.Weighted = time.Duration(s.costs.Apply(node.Tag, float64(s.rttFunc(node))))
} }
sortNodes(candidates) sortNodes(candidates)
} else { } else {
candidates = nodes.Untested candidates = nodes.Untested
shuffleNodes(candidates) shuffleNodes(candidates)
} }
selects := selectNodes( selects := selectNodes(candidates, int(s.options.Pick.Expected), s.options.Pick.Baselines)
candidates, s.logger,
int(s.options.Pick.Expected), s.options.Pick.Baselines,
)
count := len(selects) count := len(selects)
if count == 0 { if count == 0 {
return nil return ""
} }
return selects[rand.Intn(count)] picked := selects[rand.Intn(count)]
s.logger.Debug(
"pick [", picked.Tag, "]",
" +W=", picked.Weighted,
" STD=", picked.Deviation,
" AVG=", picked.Average,
" Fail=", picked.Fail, "/", picked.All,
)
return picked.Tag
} }
func sortNodes(nodes []*Node) { func sortNodes(nodes []*Node) {

View File

@ -3,7 +3,6 @@ package balancer
import ( import (
"time" "time"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
) )
@ -23,7 +22,7 @@ import (
// 3. Speed priority: Baselines + `Expected Count <= 0`. // 3. Speed priority: Baselines + `Expected Count <= 0`.
// go through all baselines until find selects, if not, select none. Used in combination // go through all baselines until find selects, if not, select none. Used in combination
// with 'balancer.fallbackTag', it means: selects qualified nodes or use the fallback. // with 'balancer.fallbackTag', it means: selects qualified nodes or use the fallback.
func selectNodes(nodes []*Node, logger log.Logger, expected int, baselines []option.Duration) []*Node { func selectNodes(nodes []*Node, expected int, baselines []option.Duration) []*Node {
if len(nodes) == 0 { if len(nodes) == 0 {
// s.logger.Debug("no qualified nodes") // s.logger.Debug("no qualified nodes")
return nil return nil
@ -53,9 +52,6 @@ func selectNodes(nodes []*Node, logger log.Logger, expected int, baselines []opt
} }
// don't continue if find expected selects // don't continue if find expected selects
if count >= expected2 { if count >= expected2 {
if logger != nil {
logger.Debug("applied baseline: ", baseline)
}
break break
} }
} }

View File

@ -8,12 +8,12 @@ import (
func TestSelectNodes(t *testing.T) { func TestSelectNodes(t *testing.T) {
nodes := []*Node{ nodes := []*Node{
{HealthCheckStats: HealthCheckStats{Weighted: 50}}, {RTTStats: RTTStats{Weighted: 50}},
{HealthCheckStats: HealthCheckStats{Weighted: 70}}, {RTTStats: RTTStats{Weighted: 70}},
{HealthCheckStats: HealthCheckStats{Weighted: 100}}, {RTTStats: RTTStats{Weighted: 100}},
{HealthCheckStats: HealthCheckStats{Weighted: 110}}, {RTTStats: RTTStats{Weighted: 110}},
{HealthCheckStats: HealthCheckStats{Weighted: 120}}, {RTTStats: RTTStats{Weighted: 120}},
{HealthCheckStats: HealthCheckStats{Weighted: 150}}, {RTTStats: RTTStats{Weighted: 150}},
} }
tests := []struct { tests := []struct {
expected int expected int
@ -31,7 +31,7 @@ func TestSelectNodes(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
if got := selectNodes(nodes, nil, tt.expected, tt.baselines); len(got) != tt.want { if got := selectNodes(nodes, tt.expected, tt.baselines); len(got) != tt.want {
t.Errorf("selectNodes() = %v, want %v", len(got), tt.want) t.Errorf("selectNodes() = %v, want %v", len(got), tt.want)
} }
}) })

View File

@ -6,6 +6,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
@ -16,15 +17,16 @@ type HealthCheck struct {
sync.Mutex sync.Mutex
ticker *time.Ticker ticker *time.Ticker
nodes []*Node router adapter.Router
tags []string
logger log.Logger logger log.Logger
options *option.HealthCheckSettings options *option.HealthCheckSettings
Results map[string]*HealthCheckRTTS results map[string]*rttStorage
} }
// NewHealthCheck creates a new HealthPing with settings // NewHealthCheck creates a new HealthPing with settings
func NewHealthCheck(outbounds []*Node, logger log.Logger, config *option.HealthCheckSettings) *HealthCheck { func NewHealthCheck(router adapter.Router, tags []string, logger log.Logger, config *option.HealthCheckSettings) *HealthCheck {
if config == nil { if config == nil {
config = &option.HealthCheckSettings{} config = &option.HealthCheckSettings{}
} }
@ -46,29 +48,33 @@ func NewHealthCheck(outbounds []*Node, logger log.Logger, config *option.HealthC
config.Timeout = option.Duration(5 * time.Second) config.Timeout = option.Duration(5 * time.Second)
} }
return &HealthCheck{ return &HealthCheck{
nodes: outbounds, router: router,
tags: tags,
options: config, options: config,
Results: nil, results: make(map[string]*rttStorage),
logger: logger, logger: logger,
} }
} }
// Start starts the health check service // Start starts the health check service
func (h *HealthCheck) Start() error { func (h *HealthCheck) Start() error {
h.Lock()
defer h.Unlock()
if h.ticker != nil { if h.ticker != nil {
return nil return nil
} }
interval := time.Duration(h.options.Interval) * time.Duration(h.options.SamplingCount) interval := time.Duration(h.options.Interval) * time.Duration(h.options.SamplingCount)
ticker := time.NewTicker(interval) ticker := time.NewTicker(interval)
h.ticker = ticker h.ticker = ticker
// one time instant check
h.Check()
go func() { go func() {
h.doCheck(0, 1)
for { for {
h.doCheck(interval, h.options.SamplingCount)
_, ok := <-ticker.C _, ok := <-ticker.C
if !ok { if !ok {
break break
} }
h.doCheck(interval, h.options.SamplingCount)
} }
}() }()
return nil return nil
@ -76,17 +82,17 @@ func (h *HealthCheck) Start() error {
// Stop stops the health check service // Stop stops the health check service
func (h *HealthCheck) Stop() { func (h *HealthCheck) Stop() {
h.ticker.Stop() h.Lock()
h.ticker = nil defer h.Unlock()
if h.ticker != nil {
h.ticker.Stop()
h.ticker = nil
}
} }
// Check does a one time health check // Check does a one time health check
func (h *HealthCheck) Check() error { func (h *HealthCheck) Check() {
if len(h.nodes) == 0 { go h.doCheck(0, 1)
return nil
}
h.doCheck(0, 1)
return nil
} }
type rtt struct { type rtt struct {
@ -97,14 +103,15 @@ type rtt struct {
// doCheck performs the 'rounds' amount checks in given 'duration'. You should make // doCheck performs the 'rounds' amount checks in given 'duration'. You should make
// sure all tags are valid for current balancer // sure all tags are valid for current balancer
func (h *HealthCheck) doCheck(duration time.Duration, rounds int) { func (h *HealthCheck) doCheck(duration time.Duration, rounds int) {
count := len(h.nodes) * rounds nodes := h.refreshNodes()
count := len(nodes) * rounds
if count == 0 { if count == 0 {
return return
} }
ch := make(chan *rtt, count) ch := make(chan *rtt, count)
// rtts := make(map[string][]time.Duration) // rtts := make(map[string][]time.Duration)
for _, node := range h.nodes { for _, node := range nodes {
tag, detour := node.Outbound.Tag(), node.Outbound tag, detour := node.Tag(), node
client := newPingClient( client := newPingClient(
detour, detour,
h.options.Destination, h.options.Destination,
@ -160,18 +167,10 @@ func (h *HealthCheck) doCheck(duration time.Duration, rounds int) {
func (h *HealthCheck) PutResult(tag string, rtt time.Duration) { func (h *HealthCheck) PutResult(tag string, rtt time.Duration) {
h.Lock() h.Lock()
defer h.Unlock() defer h.Unlock()
if h.Results == nil { r, ok := h.results[tag]
h.Results = make(map[string]*HealthCheckRTTS)
}
r, ok := h.Results[tag]
if !ok { if !ok {
// validity is 2 times to sampling period, since the check are // the result may come after the node is removed
// distributed in the time line randomly, in extreme cases, return
// previous checks are distributed on the left, and latters
// on the right
validity := time.Duration(h.options.Interval) * time.Duration(h.options.SamplingCount) * 2
r = NewHealthPingResult(h.options.SamplingCount, validity)
h.Results[tag] = r
} }
r.Put(rtt) r.Put(rtt)
} }

View File

@ -1,6 +1,11 @@
package balancer package balancer
import "time" import (
"strings"
"time"
"github.com/sagernet/sing-box/adapter"
)
// CategorizedNodes holds the categorized nodes // CategorizedNodes holds the categorized nodes
type CategorizedNodes struct { type CategorizedNodes struct {
@ -12,28 +17,24 @@ type CategorizedNodes struct {
func (h *HealthCheck) NodesByCategory() *CategorizedNodes { func (h *HealthCheck) NodesByCategory() *CategorizedNodes {
h.Lock() h.Lock()
defer h.Unlock() defer h.Unlock()
if h == nil || h.Results == nil { if h == nil || len(h.results) == 0 {
return &CategorizedNodes{ return &CategorizedNodes{}
Untested: h.nodes,
}
} }
nodes := &CategorizedNodes{ nodes := &CategorizedNodes{
Qualified: make([]*Node, 0, len(h.nodes)), Qualified: make([]*Node, 0, len(h.results)),
Unqualified: make([]*Node, 0, len(h.nodes)), Unqualified: make([]*Node, 0, len(h.results)),
Failed: make([]*Node, 0, len(h.nodes)), Failed: make([]*Node, 0, len(h.results)),
Untested: make([]*Node, 0, len(h.nodes)), Untested: make([]*Node, 0, len(h.results)),
} }
for _, node := range h.nodes { for tag, result := range h.results {
r, ok := h.Results[node.Outbound.Tag()] node := &Node{
if !ok { Tag: tag,
node.HealthCheckStats = healthPingStatsUntested RTTStats: result.Get(),
continue
} }
node.HealthCheckStats = r.Get()
switch { switch {
case node.HealthCheckStats.All == 0: case node.RTTStats.All == 0:
nodes.Untested = append(nodes.Untested, node) nodes.Untested = append(nodes.Untested, node)
case node.HealthCheckStats.All == node.HealthCheckStats.Fail, case node.RTTStats.All == node.RTTStats.Fail,
float64(node.Fail)/float64(node.All) > float64(h.options.Tolerance): float64(node.Fail)/float64(node.All) > float64(h.options.Tolerance):
nodes.Failed = append(nodes.Failed, node) nodes.Failed = append(nodes.Failed, node)
case h.options.MaxRTT > 0 && node.Average > time.Duration(h.options.MaxRTT): case h.options.MaxRTT > 0 && node.Average > time.Duration(h.options.MaxRTT):
@ -44,3 +45,49 @@ func (h *HealthCheck) NodesByCategory() *CategorizedNodes {
} }
return nodes return nodes
} }
// CoveredOutbounds returns the outbounds that should covered by health check
func CoveredOutbounds(router adapter.Router, tags []string) []adapter.Outbound {
outbounds := router.Outbounds()
nodes := make([]adapter.Outbound, 0, len(outbounds))
for _, outbound := range outbounds {
for _, prefix := range tags {
tag := outbound.Tag()
if strings.HasPrefix(tag, prefix) {
nodes = append(nodes, outbound)
}
}
}
return nodes
}
// refreshNodes matches nodes from router by tag prefix, and refreshes the health check results
func (h *HealthCheck) refreshNodes() []adapter.Outbound {
h.Lock()
defer h.Unlock()
nodes := CoveredOutbounds(h.router, h.tags)
tags := make(map[string]struct{})
for _, node := range nodes {
tag := node.Tag()
tags[tag] = struct{}{}
// make it known to the health check results
r, ok := h.results[tag]
if !ok {
// validity is 2 times to sampling period, since the check are
// distributed in the time line randomly, in extreme cases,
// previous checks are distributed on the left, and latters
// on the right
validity := time.Duration(h.options.Interval) * time.Duration(h.options.SamplingCount) * 2
r = newRTTStorage(h.options.SamplingCount, validity)
h.results[tag] = r
}
}
// remove unused rttStorage
for tag := range h.results {
if _, ok := tags[tag]; !ok {
delete(h.results, tag)
}
}
return nodes
}

View File

@ -3,17 +3,18 @@ package balancer
import ( import (
"time" "time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
) )
// NewLeastLoad creates a new LeastLoad outbound // NewLeastLoad creates a new LeastLoad outbound
func NewLeastLoad( func NewLeastLoad(
nodes []*Node, logger log.ContextLogger, router adapter.Router, logger log.ContextLogger,
options option.BalancerOutboundOptions, options option.BalancerOutboundOptions,
) (Balancer, error) { ) (Balancer, error) {
return newRTTBasedBalancer( return newRTTBasedBalancer(
nodes, logger, options, router, logger, options,
func(node *Node) time.Duration { func(node *Node) time.Duration {
return node.Deviation return node.Deviation
}, },

View File

@ -3,17 +3,18 @@ package balancer
import ( import (
"time" "time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
) )
// NewLeastPing creates a new LeastPing outbound // NewLeastPing creates a new LeastPing outbound
func NewLeastPing( func NewLeastPing(
nodes []*Node, logger log.ContextLogger, router adapter.Router, logger log.ContextLogger,
options option.BalancerOutboundOptions, options option.BalancerOutboundOptions,
) (Balancer, error) { ) (Balancer, error) {
return newRTTBasedBalancer( return newRTTBasedBalancer(
nodes, logger, options, router, logger, options,
func(node *Node) time.Duration { func(node *Node) time.Duration {
return node.Average return node.Average
}, },

View File

@ -1,10 +1,6 @@
package balancer package balancer
import ( var healthPingStatsUntested = RTTStats{
"github.com/sagernet/sing-box/adapter"
)
var healthPingStatsUntested = HealthCheckStats{
All: 0, All: 0,
Fail: 0, Fail: 0,
Deviation: rttUntested, Deviation: rttUntested,
@ -15,14 +11,6 @@ var healthPingStatsUntested = HealthCheckStats{
// Node is a banalcer node with health check result // Node is a banalcer node with health check result
type Node struct { type Node struct {
Outbound adapter.Outbound Tag string
HealthCheckStats RTTStats
}
// NewNode creates a new balancer node from outbound
func NewNode(outbound adapter.Outbound) *Node {
return &Node{
Outbound: outbound,
HealthCheckStats: healthPingStatsUntested,
}
} }

View File

@ -11,8 +11,8 @@ const (
rttUnqualified rttUnqualified
) )
// HealthCheckStats is the statistics of HealthPingRTTS // RTTStats is the statistics of health check RTTs
type HealthCheckStats struct { type RTTStats struct {
All int All int
Fail int Fail int
Deviation time.Duration Deviation time.Duration
@ -23,15 +23,15 @@ type HealthCheckStats struct {
Weighted time.Duration Weighted time.Duration
} }
// HealthCheckRTTS holds ping rtts for health Checker // rttStorage holds ping rtts for health Checker
type HealthCheckRTTS struct { type rttStorage struct {
idx int idx int
cap int cap int
validity time.Duration validity time.Duration
rtts []*pingRTT rtts []*pingRTT
lastUpdateAt time.Time lastUpdateAt time.Time
stats *HealthCheckStats stats *RTTStats
} }
type pingRTT struct { type pingRTT struct {
@ -39,25 +39,25 @@ type pingRTT struct {
value time.Duration value time.Duration
} }
// NewHealthPingResult returns a *HealthPingResult with specified capacity // newRTTStorage returns a *HealthPingResult with specified capacity
func NewHealthPingResult(cap int, validity time.Duration) *HealthCheckRTTS { func newRTTStorage(cap int, validity time.Duration) *rttStorage {
return &HealthCheckRTTS{cap: cap, validity: validity} return &rttStorage{cap: cap, validity: validity}
} }
// Get gets statistics of the HealthPingRTTS // Get gets statistics of the HealthPingRTTS
func (h *HealthCheckRTTS) Get() HealthCheckStats { func (h *rttStorage) Get() RTTStats {
return h.getStatistics() return h.getStatistics()
} }
// GetWithCache get statistics and write cache for next call // GetWithCache get statistics and write cache for next call
// Make sure use Mutex.Lock() before calling it, RWMutex.RLock() // Make sure use Mutex.Lock() before calling it, RWMutex.RLock()
// is not an option since it writes cache // is not an option since it writes cache
func (h *HealthCheckRTTS) GetWithCache() HealthCheckStats { func (h *rttStorage) GetWithCache() RTTStats {
lastPutAt := h.rtts[h.idx].time lastPutAt := h.rtts[h.idx].time
now := time.Now() now := time.Now()
if h.stats == nil || h.lastUpdateAt.Before(lastPutAt) || h.findOutdated(now) >= 0 { if h.stats == nil || h.lastUpdateAt.Before(lastPutAt) || h.findOutdated(now) >= 0 {
if h.stats == nil { if h.stats == nil {
h.stats = &HealthCheckStats{} h.stats = &RTTStats{}
} }
*h.stats = h.getStatistics() *h.stats = h.getStatistics()
h.lastUpdateAt = now h.lastUpdateAt = now
@ -66,7 +66,7 @@ func (h *HealthCheckRTTS) GetWithCache() HealthCheckStats {
} }
// Put puts a new rtt to the HealthPingResult // Put puts a new rtt to the HealthPingResult
func (h *HealthCheckRTTS) Put(d time.Duration) { func (h *rttStorage) Put(d time.Duration) {
if h.rtts == nil { if h.rtts == nil {
h.rtts = make([]*pingRTT, h.cap) h.rtts = make([]*pingRTT, h.cap)
for i := 0; i < h.cap; i++ { for i := 0; i < h.cap; i++ {
@ -80,7 +80,7 @@ func (h *HealthCheckRTTS) Put(d time.Duration) {
h.rtts[h.idx].value = d h.rtts[h.idx].value = d
} }
func (h *HealthCheckRTTS) calcIndex(step int) int { func (h *rttStorage) calcIndex(step int) int {
idx := h.idx idx := h.idx
idx += step idx += step
if idx >= h.cap { if idx >= h.cap {
@ -89,8 +89,8 @@ func (h *HealthCheckRTTS) calcIndex(step int) int {
return idx return idx
} }
func (h *HealthCheckRTTS) getStatistics() HealthCheckStats { func (h *rttStorage) getStatistics() RTTStats {
stats := HealthCheckStats{} stats := RTTStats{}
stats.Fail = 0 stats.Fail = 0
stats.Max = 0 stats.Max = 0
stats.Min = rttFailed stats.Min = rttFailed
@ -125,7 +125,7 @@ func (h *HealthCheckRTTS) getStatistics() HealthCheckStats {
case stats.All == 0: case stats.All == 0:
return healthPingStatsUntested return healthPingStatsUntested
case stats.Fail == stats.All: case stats.Fail == stats.All:
return HealthCheckStats{ return RTTStats{
All: stats.All, All: stats.All,
Fail: stats.Fail, Fail: stats.Fail,
Deviation: rttFailed, Deviation: rttFailed,
@ -151,7 +151,7 @@ func (h *HealthCheckRTTS) getStatistics() HealthCheckStats {
return stats return stats
} }
func (h *HealthCheckRTTS) findOutdated(now time.Time) int { func (h *rttStorage) findOutdated(now time.Time) int {
for i := h.cap - 1; i < 2*h.cap; i++ { for i := h.cap - 1; i < 2*h.cap; i++ {
// from oldest to latest // from oldest to latest
idx := h.calcIndex(i) idx := h.calcIndex(i)

View File

@ -1,22 +1,20 @@
package balancer_test package balancer
import ( import (
"math" "math"
"reflect" "reflect"
"testing" "testing"
"time" "time"
"github.com/sagernet/sing-box/balancer"
) )
func TestHealthPingResults(t *testing.T) { func TestRTTStorage(t *testing.T) {
rtts := []int64{60, 140, 60, 140, 60, 60, 140, 60, 140} rtts := []int64{60, 140, 60, 140, 60, 60, 140, 60, 140}
hr := balancer.NewHealthPingResult(4, time.Hour) hr := newRTTStorage(4, time.Hour)
for _, rtt := range rtts { for _, rtt := range rtts {
hr.Put(time.Duration(rtt)) hr.Put(time.Duration(rtt))
} }
rttFailed := time.Duration(math.MaxInt64) rttFailed := time.Duration(math.MaxInt64)
expected := &balancer.HealthCheckStats{ expected := &RTTStats{
All: 4, All: 4,
Fail: 0, Fail: 0,
Deviation: 40, Deviation: 40,
@ -37,7 +35,7 @@ func TestHealthPingResults(t *testing.T) {
} }
hr.Put(rttFailed) hr.Put(rttFailed)
hr.Put(rttFailed) hr.Put(rttFailed)
expected = &balancer.HealthCheckStats{ expected = &RTTStats{
All: 4, All: 4,
Fail: 4, Fail: 4,
Deviation: 0, Deviation: 0,
@ -53,7 +51,7 @@ func TestHealthPingResults(t *testing.T) {
func TestHealthPingResultsIgnoreOutdated(t *testing.T) { func TestHealthPingResultsIgnoreOutdated(t *testing.T) {
rtts := []int64{60, 140, 60, 140} rtts := []int64{60, 140, 60, 140}
hr := balancer.NewHealthPingResult(4, time.Duration(10)*time.Millisecond) hr := newRTTStorage(4, time.Duration(10)*time.Millisecond)
for i, rtt := range rtts { for i, rtt := range rtts {
if i == 2 { if i == 2 {
// wait for previous 2 outdated // wait for previous 2 outdated
@ -62,7 +60,7 @@ func TestHealthPingResultsIgnoreOutdated(t *testing.T) {
hr.Put(time.Duration(rtt)) hr.Put(time.Duration(rtt))
} }
hr.Get() hr.Get()
expected := &balancer.HealthCheckStats{ expected := &RTTStats{
All: 2, All: 2,
Fail: 0, Fail: 0,
Deviation: 40, Deviation: 40,
@ -76,7 +74,7 @@ func TestHealthPingResultsIgnoreOutdated(t *testing.T) {
} }
// wait for all outdated // wait for all outdated
time.Sleep(time.Duration(10) * time.Millisecond) time.Sleep(time.Duration(10) * time.Millisecond)
expected = &balancer.HealthCheckStats{ expected = &RTTStats{
All: 0, All: 0,
Fail: 0, Fail: 0,
Deviation: 0, Deviation: 0,
@ -90,7 +88,7 @@ func TestHealthPingResultsIgnoreOutdated(t *testing.T) {
} }
hr.Put(time.Duration(60)) hr.Put(time.Duration(60))
expected = &balancer.HealthCheckStats{ expected = &RTTStats{
All: 1, All: 1,
Fail: 0, Fail: 0,
// 1 sample, std=0.5rtt // 1 sample, std=0.5rtt

View File

@ -27,7 +27,6 @@ type Balancer struct {
fallbackTag string fallbackTag string
balancer.Balancer balancer.Balancer
nodes []*balancer.Node
fallback adapter.Outbound fallback adapter.Outbound
} }
@ -106,13 +105,6 @@ func (s *Balancer) initialize() error {
return E.New("fallback outbound not found: ", s.fallbackTag) return E.New("fallback outbound not found: ", s.fallbackTag)
} }
s.fallback = outbound s.fallback = outbound
for i, tag := range s.tags {
outbound, loaded := s.router.Outbound(tag)
if !loaded {
return E.New("outbound ", i, " not found: ", tag)
}
s.nodes = append(s.nodes, balancer.NewNode(outbound))
}
return nil return nil
} }
@ -128,19 +120,34 @@ func (s *Balancer) setBalancer(b balancer.Balancer) error {
} }
func (s *Balancer) pick() adapter.Outbound { func (s *Balancer) pick() adapter.Outbound {
if s.Balancer != nil { tag := s.pickTag()
selected := s.Balancer.Pick() if tag == "" {
if selected == nil {
return s.fallback
}
return selected.Outbound
}
// not started
count := len(s.nodes)
if count == 0 {
// goes to fallback
return s.fallback return s.fallback
} }
picked := s.nodes[rand.Intn(count)] outbound, ok := s.router.Outbound(tag)
return picked.Outbound if !ok {
return s.fallback
}
return outbound
}
func (s *Balancer) pickTag() string {
if s.Balancer == nil {
// not started yet, pick a random one
return s.randomTag()
}
tag := s.Balancer.Pick()
if tag == "" {
return ""
}
return tag
}
func (s *Balancer) randomTag() string {
nodes := balancer.CoveredOutbounds(s.router, s.tags)
count := len(nodes)
if count == 0 {
return ""
}
return s.tags[rand.Intn(count)]
} }

View File

@ -6,7 +6,6 @@ import (
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
) )
var ( var (
@ -23,9 +22,6 @@ type LeastLoad struct {
// NewLeastLoad creates a new LeastLoad outbound // NewLeastLoad creates a new LeastLoad outbound
func NewLeastLoad(router adapter.Router, logger log.ContextLogger, tag string, options option.BalancerOutboundOptions) (*LeastLoad, error) { func NewLeastLoad(router adapter.Router, logger log.ContextLogger, tag string, options option.BalancerOutboundOptions) (*LeastLoad, error) {
if len(options.Outbounds) == 0 {
return nil, E.New("missing tags")
}
return &LeastLoad{ return &LeastLoad{
Balancer: NewBalancer( Balancer: NewBalancer(
C.TypeLeastLoad, router, logger, tag, C.TypeLeastLoad, router, logger, tag,
@ -41,7 +37,7 @@ func (s *LeastLoad) Start() error {
if err != nil { if err != nil {
return err return err
} }
b, err := balancer.NewLeastLoad(s.nodes, s.logger, s.options) b, err := balancer.NewLeastLoad(s.router, s.logger, s.options)
if err != nil { if err != nil {
return err return err
} }

View File

@ -6,7 +6,6 @@ import (
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
) )
var ( var (
@ -23,9 +22,6 @@ type LeastPing struct {
// NewLeastPing creates a new LeastPing outbound // NewLeastPing creates a new LeastPing outbound
func NewLeastPing(router adapter.Router, logger log.ContextLogger, tag string, options option.BalancerOutboundOptions) (*LeastPing, error) { func NewLeastPing(router adapter.Router, logger log.ContextLogger, tag string, options option.BalancerOutboundOptions) (*LeastPing, error) {
if len(options.Outbounds) == 0 {
return nil, E.New("missing tags")
}
return &LeastPing{ return &LeastPing{
Balancer: NewBalancer( Balancer: NewBalancer(
C.TypeLeastPing, router, logger, tag, C.TypeLeastPing, router, logger, tag,
@ -41,7 +37,7 @@ func (s *LeastPing) Start() error {
if err != nil { if err != nil {
return err return err
} }
b, err := balancer.NewLeastPing(s.nodes, s.logger, s.options) b, err := balancer.NewLeastPing(s.router, s.logger, s.options)
if err != nil { if err != nil {
return err return err
} }