diff --git a/balancer/healthcheck_nodes.go b/balancer/healthcheck_nodes.go new file mode 100644 index 00000000..105a9459 --- /dev/null +++ b/balancer/healthcheck_nodes.go @@ -0,0 +1,46 @@ +package balancer + +import "time" + +// CategorizedNodes holds the categorized nodes +type CategorizedNodes struct { + Qualified, Unqualified []*Node + Failed, Untested []*Node +} + +// NodesByCategory returns the categorized nodes +func (h *HealthCheck) NodesByCategory() *CategorizedNodes { + h.Lock() + defer h.Unlock() + if h == nil || h.Results == nil { + return &CategorizedNodes{ + Untested: h.nodes, + } + } + nodes := &CategorizedNodes{ + Qualified: make([]*Node, 0, len(h.nodes)), + Unqualified: make([]*Node, 0, len(h.nodes)), + Failed: make([]*Node, 0, len(h.nodes)), + Untested: make([]*Node, 0, len(h.nodes)), + } + for _, node := range h.nodes { + r, ok := h.Results[node.Outbound.Tag()] + if !ok { + node.HealthCheckStats = healthPingStatsUntested + continue + } + node.HealthCheckStats = r.Get() + switch { + case node.HealthCheckStats.All == 0: + nodes.Untested = append(nodes.Untested, node) + case node.HealthCheckStats.All == node.HealthCheckStats.Fail, + float64(node.Fail)/float64(node.All) > float64(h.options.Tolerance): + nodes.Failed = append(nodes.Failed, node) + case h.options.MaxRTT > 0 && node.Average > time.Duration(h.options.MaxRTT): + nodes.Unqualified = append(nodes.Unqualified, node) + default: + nodes.Qualified = append(nodes.Qualified, node) + } + } + return nodes +} diff --git a/balancer/healthcheck_result.go b/balancer/healthcheck_result.go index 780788b2..f83c5848 100644 --- a/balancer/healthcheck_result.go +++ b/balancer/healthcheck_result.go @@ -20,7 +20,7 @@ type HealthCheckStats struct { Max time.Duration Min time.Duration - applied time.Duration + Weighted time.Duration } // HealthCheckRTTS holds ping rtts for health Checker @@ -45,21 +45,24 @@ func NewHealthPingResult(cap int, validity time.Duration) *HealthCheckRTTS { } // Get gets statistics of the HealthPingRTTS -func (h *HealthCheckRTTS) Get() *HealthCheckStats { +func (h *HealthCheckRTTS) Get() HealthCheckStats { return h.getStatistics() } // GetWithCache get statistics and write cache for next call // Make sure use Mutex.Lock() before calling it, RWMutex.RLock() // is not an option since it writes cache -func (h *HealthCheckRTTS) GetWithCache() *HealthCheckStats { +func (h *HealthCheckRTTS) GetWithCache() HealthCheckStats { lastPutAt := h.rtts[h.idx].time now := time.Now() if h.stats == nil || h.lastUpdateAt.Before(lastPutAt) || h.findOutdated(now) >= 0 { - h.stats = h.getStatistics() + if h.stats == nil { + h.stats = &HealthCheckStats{} + } + *h.stats = h.getStatistics() h.lastUpdateAt = now } - return h.stats + return *h.stats } // Put puts a new rtt to the HealthPingResult @@ -86,14 +89,14 @@ func (h *HealthCheckRTTS) calcIndex(step int) int { return idx } -func (h *HealthCheckRTTS) getStatistics() *HealthCheckStats { - stats := &HealthCheckStats{} +func (h *HealthCheckRTTS) getStatistics() HealthCheckStats { + stats := HealthCheckStats{} stats.Fail = 0 stats.Max = 0 stats.Min = rttFailed sum := time.Duration(0) cnt := 0 - validRTTs := make([]time.Duration, 0) + validRTTs := make([]time.Duration, 0, h.cap) for _, rtt := range h.rtts { switch { case rtt.value == 0 || time.Since(rtt.time) > h.validity: @@ -115,9 +118,22 @@ func (h *HealthCheckRTTS) getStatistics() *HealthCheckStats { stats.All = cnt + stats.Fail if cnt == 0 { stats.Min = 0 - return stats + return healthPingStatsUntested } stats.Average = time.Duration(int(sum) / cnt) + switch { + case stats.All == 0: + return healthPingStatsUntested + case stats.Fail == stats.All: + return HealthCheckStats{ + All: stats.All, + Fail: stats.Fail, + Deviation: rttFailed, + Average: rttFailed, + Max: rttFailed, + Min: rttFailed, + } + } var std float64 if cnt < 2 { // no enough data for standard deviation, we assume it's half of the average rtt diff --git a/balancer/leastload.go b/balancer/leastload.go index 6136fcac..7514eacf 100644 --- a/balancer/leastload.go +++ b/balancer/leastload.go @@ -41,8 +41,17 @@ func NewLeastLoad( // Select selects qualified nodes func (s *LeastLoad) Select() *Node { - qualified, _ := s.getNodes() - selects := s.selectLeastLoad(qualified) + nodes := s.HealthCheck.NodesByCategory() + var candidates []*Node + if len(nodes.Qualified) > 0 { + candidates := nodes.Qualified + appliyCost(candidates, s.costs) + leastPingSort(candidates) + } else { + candidates = nodes.Untested + shuffle(candidates) + } + selects := s.selectLeastLoad(candidates) count := len(selects) if count == 0 { return nil @@ -89,7 +98,7 @@ func (s *LeastLoad) selectLeastLoad(nodes []*Node) []*Node { for _, b := range s.options.Baselines { baseline := time.Duration(b) for i := 0; i < availableCount; i++ { - if nodes[i].applied > baseline { + if nodes[i].Weighted > baseline { break } count = i + 1 @@ -106,57 +115,21 @@ func (s *LeastLoad) selectLeastLoad(nodes []*Node) []*Node { return nodes[:count] } -func (s *LeastLoad) getNodes() ([]*Node, []*Node) { - s.HealthCheck.Lock() - defer s.HealthCheck.Unlock() - - qualified := make([]*Node, 0) - unqualified := make([]*Node, 0) - failed := make([]*Node, 0) - untested := make([]*Node, 0) - others := make([]*Node, 0) - for _, node := range s.nodes { - node.FetchStats(s.HealthCheck) - switch { - case node.All == 0: - node.applied = rttUntested - untested = append(untested, node) - case s.options.HealthCheck.MaxRTT > 0 && node.Average > time.Duration(s.options.HealthCheck.MaxRTT): - node.applied = rttUnqualified - unqualified = append(unqualified, node) - case float64(node.Fail)/float64(node.All) > float64(s.options.HealthCheck.Tolerance): - node.applied = rttFailed - if node.All-node.Fail == 0 { - // no good, put them after has-good nodes - node.applied = rttFailed - node.Deviation = rttFailed - node.Average = rttFailed - } - failed = append(failed, node) - default: - node.applied = time.Duration(s.costs.Apply(node.Outbound.Tag(), float64(node.Deviation))) - qualified = append(qualified, node) - } +func appliyCost(nodes []*Node, costs *WeightManager) { + for _, node := range nodes { + node.Weighted = time.Duration(costs.Apply(node.Outbound.Tag(), float64(node.Deviation))) } - if len(qualified) > 0 { - leastloadSort(qualified) - others = append(others, unqualified...) - others = append(others, untested...) - others = append(others, failed...) - } else { - qualified = untested - others = append(others, unqualified...) - others = append(others, failed...) - } - return qualified, others } func leastloadSort(nodes []*Node) { sort.Slice(nodes, func(i, j int) bool { left := nodes[i] right := nodes[j] - if left.applied != right.applied { - return left.applied < right.applied + if left.Weighted != right.Weighted { + return left.Weighted < right.Weighted + } + if left.Deviation != right.Deviation { + return left.Deviation < right.Deviation } if left.Average != right.Average { return left.Average < right.Average diff --git a/balancer/leastping.go b/balancer/leastping.go index ba7043bc..708eabef 100644 --- a/balancer/leastping.go +++ b/balancer/leastping.go @@ -33,66 +33,27 @@ func NewLeastPing( // Select selects least ping node func (s *LeastPing) Select() *Node { - qualified, _ := s.getNodes() - if len(qualified) == 0 { + nodes := s.HealthCheck.NodesByCategory() + var candidates []*Node + if len(nodes.Qualified) > 0 { + candidates := nodes.Qualified + leastPingSort(candidates) + } else { + candidates = nodes.Untested + shuffle(candidates) + } + if len(candidates) == 0 { return nil } - return qualified[0] -} - -func (s *LeastPing) getNodes() ([]*Node, []*Node) { - s.HealthCheck.Lock() - defer s.HealthCheck.Unlock() - - qualified := make([]*Node, 0) - unqualified := make([]*Node, 0) - failed := make([]*Node, 0) - untested := make([]*Node, 0) - others := make([]*Node, 0) - for _, node := range s.nodes { - node.FetchStats(s.HealthCheck) - switch { - case node.All == 0: - node.applied = rttUntested - untested = append(untested, node) - case s.options.HealthCheck.MaxRTT > 0 && node.Average > time.Duration(s.options.HealthCheck.MaxRTT): - node.applied = rttUnqualified - unqualified = append(unqualified, node) - case float64(node.Fail)/float64(node.All) > float64(s.options.HealthCheck.Tolerance): - node.applied = rttFailed - if node.All-node.Fail == 0 { - // no good, put them after has-good nodes - node.applied = rttFailed - node.Deviation = rttFailed - node.Average = rttFailed - } - failed = append(failed, node) - default: - node.applied = node.Average - qualified = append(qualified, node) - } - } - if len(qualified) > 0 { - leastPingSort(qualified) - others = append(others, unqualified...) - others = append(others, untested...) - others = append(others, failed...) - } else { - // random node if not tested - shuffle(untested) - qualified = untested - others = append(others, unqualified...) - others = append(others, failed...) - } - return qualified, others + return candidates[0] } func leastPingSort(nodes []*Node) { sort.Slice(nodes, func(i, j int) bool { left := nodes[i] right := nodes[j] - if left.applied != right.applied { - return left.applied < right.applied + if left.Average != right.Average { + return left.Average < right.Average } if left.Fail != right.Fail { return left.Fail < right.Fail diff --git a/balancer/node.go b/balancer/node.go index 7f689a6d..aa370ef8 100644 --- a/balancer/node.go +++ b/balancer/node.go @@ -4,8 +4,13 @@ import ( "github.com/sagernet/sing-box/adapter" ) -var healthPingStatsZero = HealthCheckStats{ - applied: rttUntested, +var healthPingStatsUntested = HealthCheckStats{ + All: 0, + Fail: 0, + Deviation: rttUntested, + Average: rttUntested, + Max: rttUntested, + Min: rttUntested, } // Node is a banalcer node with health check result @@ -18,20 +23,6 @@ type Node struct { func NewNode(outbound adapter.Outbound) *Node { return &Node{ Outbound: outbound, - HealthCheckStats: healthPingStatsZero, + HealthCheckStats: healthPingStatsUntested, } } - -// FetchStats fetches statistics from *HealthPing p -func (s *Node) FetchStats(p *HealthCheck) { - if p == nil || p.Results == nil { - s.HealthCheckStats = healthPingStatsZero - return - } - r, ok := p.Results[s.Outbound.Tag()] - if !ok { - s.HealthCheckStats = healthPingStatsZero - return - } - s.HealthCheckStats = *r.Get() -}