fetch health check nodes in time from the router by prefix ...

* Balancer & HealthCheck not hold adapter.Outbound, but only tags
* rename and unexport some structs and fields
* fix no check in the first rounds
This commit is contained in:
jebbs 2022-10-11 17:54:39 +08:00
parent 6d417949ae
commit a2d428f246
13 changed files with 182 additions and 148 deletions

View File

@ -6,6 +6,7 @@ import (
"sort"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
)
@ -14,7 +15,7 @@ var _ Balancer = (*rttBasedBalancer)(nil)
// Balancer is interface for load balancers
type Balancer interface {
Pick() *Node
Pick() string
}
type rttBasedBalancer struct {
@ -30,15 +31,14 @@ type rttFunc func(node *Node) time.Duration
// newRTTBasedLoad creates a new rtt based load balancer
func newRTTBasedBalancer(
nodes []*Node, logger log.ContextLogger,
router adapter.Router, logger log.ContextLogger,
options option.BalancerOutboundOptions,
rttFunc rttFunc,
) (Balancer, error) {
return &rttBasedBalancer{
nodes: nodes,
rttFunc: rttFunc,
options: &options,
HealthCheck: NewHealthCheck(nodes, logger, &options.Check),
HealthCheck: NewHealthCheck(router, options.Outbounds, logger, &options.Check),
costs: NewWeightManager(
logger, options.Pick.Costs, 1,
func(value, cost float64) float64 {
@ -49,28 +49,33 @@ func newRTTBasedBalancer(
}
// Select selects qualified nodes
func (s *rttBasedBalancer) Pick() *Node {
func (s *rttBasedBalancer) Pick() string {
nodes := s.HealthCheck.NodesByCategory()
var candidates []*Node
if len(nodes.Qualified) > 0 {
candidates = nodes.Qualified
for _, node := range candidates {
node.Weighted = time.Duration(s.costs.Apply(node.Outbound.Tag(), float64(s.rttFunc(node))))
node.Weighted = time.Duration(s.costs.Apply(node.Tag, float64(s.rttFunc(node))))
}
sortNodes(candidates)
} else {
candidates = nodes.Untested
shuffleNodes(candidates)
}
selects := selectNodes(
candidates, s.logger,
int(s.options.Pick.Expected), s.options.Pick.Baselines,
)
selects := selectNodes(candidates, int(s.options.Pick.Expected), s.options.Pick.Baselines)
count := len(selects)
if count == 0 {
return nil
return ""
}
return selects[rand.Intn(count)]
picked := selects[rand.Intn(count)]
s.logger.Debug(
"pick [", picked.Tag, "]",
" +W=", picked.Weighted,
" STD=", picked.Deviation,
" AVG=", picked.Average,
" Fail=", picked.Fail, "/", picked.All,
)
return picked.Tag
}
func sortNodes(nodes []*Node) {

View File

@ -3,7 +3,6 @@ package balancer
import (
"time"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
)
@ -23,7 +22,7 @@ import (
// 3. Speed priority: Baselines + `Expected Count <= 0`.
// go through all baselines until find selects, if not, select none. Used in combination
// with 'balancer.fallbackTag', it means: selects qualified nodes or use the fallback.
func selectNodes(nodes []*Node, logger log.Logger, expected int, baselines []option.Duration) []*Node {
func selectNodes(nodes []*Node, expected int, baselines []option.Duration) []*Node {
if len(nodes) == 0 {
// s.logger.Debug("no qualified nodes")
return nil
@ -53,9 +52,6 @@ func selectNodes(nodes []*Node, logger log.Logger, expected int, baselines []opt
}
// don't continue if find expected selects
if count >= expected2 {
if logger != nil {
logger.Debug("applied baseline: ", baseline)
}
break
}
}

View File

@ -8,12 +8,12 @@ import (
func TestSelectNodes(t *testing.T) {
nodes := []*Node{
{HealthCheckStats: HealthCheckStats{Weighted: 50}},
{HealthCheckStats: HealthCheckStats{Weighted: 70}},
{HealthCheckStats: HealthCheckStats{Weighted: 100}},
{HealthCheckStats: HealthCheckStats{Weighted: 110}},
{HealthCheckStats: HealthCheckStats{Weighted: 120}},
{HealthCheckStats: HealthCheckStats{Weighted: 150}},
{RTTStats: RTTStats{Weighted: 50}},
{RTTStats: RTTStats{Weighted: 70}},
{RTTStats: RTTStats{Weighted: 100}},
{RTTStats: RTTStats{Weighted: 110}},
{RTTStats: RTTStats{Weighted: 120}},
{RTTStats: RTTStats{Weighted: 150}},
}
tests := []struct {
expected int
@ -31,7 +31,7 @@ func TestSelectNodes(t *testing.T) {
}
for _, tt := range tests {
t.Run("", func(t *testing.T) {
if got := selectNodes(nodes, nil, tt.expected, tt.baselines); len(got) != tt.want {
if got := selectNodes(nodes, tt.expected, tt.baselines); len(got) != tt.want {
t.Errorf("selectNodes() = %v, want %v", len(got), tt.want)
}
})

View File

@ -6,6 +6,7 @@ import (
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
@ -16,15 +17,16 @@ type HealthCheck struct {
sync.Mutex
ticker *time.Ticker
nodes []*Node
router adapter.Router
tags []string
logger log.Logger
options *option.HealthCheckSettings
Results map[string]*HealthCheckRTTS
results map[string]*rttStorage
}
// NewHealthCheck creates a new HealthPing with settings
func NewHealthCheck(outbounds []*Node, logger log.Logger, config *option.HealthCheckSettings) *HealthCheck {
func NewHealthCheck(router adapter.Router, tags []string, logger log.Logger, config *option.HealthCheckSettings) *HealthCheck {
if config == nil {
config = &option.HealthCheckSettings{}
}
@ -46,29 +48,33 @@ func NewHealthCheck(outbounds []*Node, logger log.Logger, config *option.HealthC
config.Timeout = option.Duration(5 * time.Second)
}
return &HealthCheck{
nodes: outbounds,
router: router,
tags: tags,
options: config,
Results: nil,
results: make(map[string]*rttStorage),
logger: logger,
}
}
// Start starts the health check service
func (h *HealthCheck) Start() error {
h.Lock()
defer h.Unlock()
if h.ticker != nil {
return nil
}
interval := time.Duration(h.options.Interval) * time.Duration(h.options.SamplingCount)
ticker := time.NewTicker(interval)
h.ticker = ticker
// one time instant check
h.Check()
go func() {
h.doCheck(0, 1)
for {
h.doCheck(interval, h.options.SamplingCount)
_, ok := <-ticker.C
if !ok {
break
}
h.doCheck(interval, h.options.SamplingCount)
}
}()
return nil
@ -76,17 +82,17 @@ func (h *HealthCheck) Start() error {
// Stop stops the health check service
func (h *HealthCheck) Stop() {
h.Lock()
defer h.Unlock()
if h.ticker != nil {
h.ticker.Stop()
h.ticker = nil
}
}
// Check does a one time health check
func (h *HealthCheck) Check() error {
if len(h.nodes) == 0 {
return nil
}
h.doCheck(0, 1)
return nil
func (h *HealthCheck) Check() {
go h.doCheck(0, 1)
}
type rtt struct {
@ -97,14 +103,15 @@ type rtt struct {
// doCheck performs the 'rounds' amount checks in given 'duration'. You should make
// sure all tags are valid for current balancer
func (h *HealthCheck) doCheck(duration time.Duration, rounds int) {
count := len(h.nodes) * rounds
nodes := h.refreshNodes()
count := len(nodes) * rounds
if count == 0 {
return
}
ch := make(chan *rtt, count)
// rtts := make(map[string][]time.Duration)
for _, node := range h.nodes {
tag, detour := node.Outbound.Tag(), node.Outbound
for _, node := range nodes {
tag, detour := node.Tag(), node
client := newPingClient(
detour,
h.options.Destination,
@ -160,18 +167,10 @@ func (h *HealthCheck) doCheck(duration time.Duration, rounds int) {
func (h *HealthCheck) PutResult(tag string, rtt time.Duration) {
h.Lock()
defer h.Unlock()
if h.Results == nil {
h.Results = make(map[string]*HealthCheckRTTS)
}
r, ok := h.Results[tag]
r, ok := h.results[tag]
if !ok {
// validity is 2 times to sampling period, since the check are
// distributed in the time line randomly, in extreme cases,
// previous checks are distributed on the left, and latters
// on the right
validity := time.Duration(h.options.Interval) * time.Duration(h.options.SamplingCount) * 2
r = NewHealthPingResult(h.options.SamplingCount, validity)
h.Results[tag] = r
// the result may come after the node is removed
return
}
r.Put(rtt)
}

View File

@ -1,6 +1,11 @@
package balancer
import "time"
import (
"strings"
"time"
"github.com/sagernet/sing-box/adapter"
)
// CategorizedNodes holds the categorized nodes
type CategorizedNodes struct {
@ -12,28 +17,24 @@ type CategorizedNodes struct {
func (h *HealthCheck) NodesByCategory() *CategorizedNodes {
h.Lock()
defer h.Unlock()
if h == nil || h.Results == nil {
return &CategorizedNodes{
Untested: h.nodes,
}
if h == nil || len(h.results) == 0 {
return &CategorizedNodes{}
}
nodes := &CategorizedNodes{
Qualified: make([]*Node, 0, len(h.nodes)),
Unqualified: make([]*Node, 0, len(h.nodes)),
Failed: make([]*Node, 0, len(h.nodes)),
Untested: make([]*Node, 0, len(h.nodes)),
Qualified: make([]*Node, 0, len(h.results)),
Unqualified: make([]*Node, 0, len(h.results)),
Failed: make([]*Node, 0, len(h.results)),
Untested: make([]*Node, 0, len(h.results)),
}
for _, node := range h.nodes {
r, ok := h.Results[node.Outbound.Tag()]
if !ok {
node.HealthCheckStats = healthPingStatsUntested
continue
for tag, result := range h.results {
node := &Node{
Tag: tag,
RTTStats: result.Get(),
}
node.HealthCheckStats = r.Get()
switch {
case node.HealthCheckStats.All == 0:
case node.RTTStats.All == 0:
nodes.Untested = append(nodes.Untested, node)
case node.HealthCheckStats.All == node.HealthCheckStats.Fail,
case node.RTTStats.All == node.RTTStats.Fail,
float64(node.Fail)/float64(node.All) > float64(h.options.Tolerance):
nodes.Failed = append(nodes.Failed, node)
case h.options.MaxRTT > 0 && node.Average > time.Duration(h.options.MaxRTT):
@ -44,3 +45,49 @@ func (h *HealthCheck) NodesByCategory() *CategorizedNodes {
}
return nodes
}
// CoveredOutbounds returns the outbounds that should covered by health check
func CoveredOutbounds(router adapter.Router, tags []string) []adapter.Outbound {
outbounds := router.Outbounds()
nodes := make([]adapter.Outbound, 0, len(outbounds))
for _, outbound := range outbounds {
for _, prefix := range tags {
tag := outbound.Tag()
if strings.HasPrefix(tag, prefix) {
nodes = append(nodes, outbound)
}
}
}
return nodes
}
// refreshNodes matches nodes from router by tag prefix, and refreshes the health check results
func (h *HealthCheck) refreshNodes() []adapter.Outbound {
h.Lock()
defer h.Unlock()
nodes := CoveredOutbounds(h.router, h.tags)
tags := make(map[string]struct{})
for _, node := range nodes {
tag := node.Tag()
tags[tag] = struct{}{}
// make it known to the health check results
r, ok := h.results[tag]
if !ok {
// validity is 2 times to sampling period, since the check are
// distributed in the time line randomly, in extreme cases,
// previous checks are distributed on the left, and latters
// on the right
validity := time.Duration(h.options.Interval) * time.Duration(h.options.SamplingCount) * 2
r = newRTTStorage(h.options.SamplingCount, validity)
h.results[tag] = r
}
}
// remove unused rttStorage
for tag := range h.results {
if _, ok := tags[tag]; !ok {
delete(h.results, tag)
}
}
return nodes
}

View File

@ -3,17 +3,18 @@ package balancer
import (
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
)
// NewLeastLoad creates a new LeastLoad outbound
func NewLeastLoad(
nodes []*Node, logger log.ContextLogger,
router adapter.Router, logger log.ContextLogger,
options option.BalancerOutboundOptions,
) (Balancer, error) {
return newRTTBasedBalancer(
nodes, logger, options,
router, logger, options,
func(node *Node) time.Duration {
return node.Deviation
},

View File

@ -3,17 +3,18 @@ package balancer
import (
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
)
// NewLeastPing creates a new LeastPing outbound
func NewLeastPing(
nodes []*Node, logger log.ContextLogger,
router adapter.Router, logger log.ContextLogger,
options option.BalancerOutboundOptions,
) (Balancer, error) {
return newRTTBasedBalancer(
nodes, logger, options,
router, logger, options,
func(node *Node) time.Duration {
return node.Average
},

View File

@ -1,10 +1,6 @@
package balancer
import (
"github.com/sagernet/sing-box/adapter"
)
var healthPingStatsUntested = HealthCheckStats{
var healthPingStatsUntested = RTTStats{
All: 0,
Fail: 0,
Deviation: rttUntested,
@ -15,14 +11,6 @@ var healthPingStatsUntested = HealthCheckStats{
// Node is a banalcer node with health check result
type Node struct {
Outbound adapter.Outbound
HealthCheckStats
}
// NewNode creates a new balancer node from outbound
func NewNode(outbound adapter.Outbound) *Node {
return &Node{
Outbound: outbound,
HealthCheckStats: healthPingStatsUntested,
}
Tag string
RTTStats
}

View File

@ -11,8 +11,8 @@ const (
rttUnqualified
)
// HealthCheckStats is the statistics of HealthPingRTTS
type HealthCheckStats struct {
// RTTStats is the statistics of health check RTTs
type RTTStats struct {
All int
Fail int
Deviation time.Duration
@ -23,15 +23,15 @@ type HealthCheckStats struct {
Weighted time.Duration
}
// HealthCheckRTTS holds ping rtts for health Checker
type HealthCheckRTTS struct {
// rttStorage holds ping rtts for health Checker
type rttStorage struct {
idx int
cap int
validity time.Duration
rtts []*pingRTT
lastUpdateAt time.Time
stats *HealthCheckStats
stats *RTTStats
}
type pingRTT struct {
@ -39,25 +39,25 @@ type pingRTT struct {
value time.Duration
}
// NewHealthPingResult returns a *HealthPingResult with specified capacity
func NewHealthPingResult(cap int, validity time.Duration) *HealthCheckRTTS {
return &HealthCheckRTTS{cap: cap, validity: validity}
// newRTTStorage returns a *HealthPingResult with specified capacity
func newRTTStorage(cap int, validity time.Duration) *rttStorage {
return &rttStorage{cap: cap, validity: validity}
}
// Get gets statistics of the HealthPingRTTS
func (h *HealthCheckRTTS) Get() HealthCheckStats {
func (h *rttStorage) Get() RTTStats {
return h.getStatistics()
}
// GetWithCache get statistics and write cache for next call
// Make sure use Mutex.Lock() before calling it, RWMutex.RLock()
// is not an option since it writes cache
func (h *HealthCheckRTTS) GetWithCache() HealthCheckStats {
func (h *rttStorage) GetWithCache() RTTStats {
lastPutAt := h.rtts[h.idx].time
now := time.Now()
if h.stats == nil || h.lastUpdateAt.Before(lastPutAt) || h.findOutdated(now) >= 0 {
if h.stats == nil {
h.stats = &HealthCheckStats{}
h.stats = &RTTStats{}
}
*h.stats = h.getStatistics()
h.lastUpdateAt = now
@ -66,7 +66,7 @@ func (h *HealthCheckRTTS) GetWithCache() HealthCheckStats {
}
// Put puts a new rtt to the HealthPingResult
func (h *HealthCheckRTTS) Put(d time.Duration) {
func (h *rttStorage) Put(d time.Duration) {
if h.rtts == nil {
h.rtts = make([]*pingRTT, h.cap)
for i := 0; i < h.cap; i++ {
@ -80,7 +80,7 @@ func (h *HealthCheckRTTS) Put(d time.Duration) {
h.rtts[h.idx].value = d
}
func (h *HealthCheckRTTS) calcIndex(step int) int {
func (h *rttStorage) calcIndex(step int) int {
idx := h.idx
idx += step
if idx >= h.cap {
@ -89,8 +89,8 @@ func (h *HealthCheckRTTS) calcIndex(step int) int {
return idx
}
func (h *HealthCheckRTTS) getStatistics() HealthCheckStats {
stats := HealthCheckStats{}
func (h *rttStorage) getStatistics() RTTStats {
stats := RTTStats{}
stats.Fail = 0
stats.Max = 0
stats.Min = rttFailed
@ -125,7 +125,7 @@ func (h *HealthCheckRTTS) getStatistics() HealthCheckStats {
case stats.All == 0:
return healthPingStatsUntested
case stats.Fail == stats.All:
return HealthCheckStats{
return RTTStats{
All: stats.All,
Fail: stats.Fail,
Deviation: rttFailed,
@ -151,7 +151,7 @@ func (h *HealthCheckRTTS) getStatistics() HealthCheckStats {
return stats
}
func (h *HealthCheckRTTS) findOutdated(now time.Time) int {
func (h *rttStorage) findOutdated(now time.Time) int {
for i := h.cap - 1; i < 2*h.cap; i++ {
// from oldest to latest
idx := h.calcIndex(i)

View File

@ -1,22 +1,20 @@
package balancer_test
package balancer
import (
"math"
"reflect"
"testing"
"time"
"github.com/sagernet/sing-box/balancer"
)
func TestHealthPingResults(t *testing.T) {
func TestRTTStorage(t *testing.T) {
rtts := []int64{60, 140, 60, 140, 60, 60, 140, 60, 140}
hr := balancer.NewHealthPingResult(4, time.Hour)
hr := newRTTStorage(4, time.Hour)
for _, rtt := range rtts {
hr.Put(time.Duration(rtt))
}
rttFailed := time.Duration(math.MaxInt64)
expected := &balancer.HealthCheckStats{
expected := &RTTStats{
All: 4,
Fail: 0,
Deviation: 40,
@ -37,7 +35,7 @@ func TestHealthPingResults(t *testing.T) {
}
hr.Put(rttFailed)
hr.Put(rttFailed)
expected = &balancer.HealthCheckStats{
expected = &RTTStats{
All: 4,
Fail: 4,
Deviation: 0,
@ -53,7 +51,7 @@ func TestHealthPingResults(t *testing.T) {
func TestHealthPingResultsIgnoreOutdated(t *testing.T) {
rtts := []int64{60, 140, 60, 140}
hr := balancer.NewHealthPingResult(4, time.Duration(10)*time.Millisecond)
hr := newRTTStorage(4, time.Duration(10)*time.Millisecond)
for i, rtt := range rtts {
if i == 2 {
// wait for previous 2 outdated
@ -62,7 +60,7 @@ func TestHealthPingResultsIgnoreOutdated(t *testing.T) {
hr.Put(time.Duration(rtt))
}
hr.Get()
expected := &balancer.HealthCheckStats{
expected := &RTTStats{
All: 2,
Fail: 0,
Deviation: 40,
@ -76,7 +74,7 @@ func TestHealthPingResultsIgnoreOutdated(t *testing.T) {
}
// wait for all outdated
time.Sleep(time.Duration(10) * time.Millisecond)
expected = &balancer.HealthCheckStats{
expected = &RTTStats{
All: 0,
Fail: 0,
Deviation: 0,
@ -90,7 +88,7 @@ func TestHealthPingResultsIgnoreOutdated(t *testing.T) {
}
hr.Put(time.Duration(60))
expected = &balancer.HealthCheckStats{
expected = &RTTStats{
All: 1,
Fail: 0,
// 1 sample, std=0.5rtt

View File

@ -27,7 +27,6 @@ type Balancer struct {
fallbackTag string
balancer.Balancer
nodes []*balancer.Node
fallback adapter.Outbound
}
@ -106,13 +105,6 @@ func (s *Balancer) initialize() error {
return E.New("fallback outbound not found: ", s.fallbackTag)
}
s.fallback = outbound
for i, tag := range s.tags {
outbound, loaded := s.router.Outbound(tag)
if !loaded {
return E.New("outbound ", i, " not found: ", tag)
}
s.nodes = append(s.nodes, balancer.NewNode(outbound))
}
return nil
}
@ -128,19 +120,34 @@ func (s *Balancer) setBalancer(b balancer.Balancer) error {
}
func (s *Balancer) pick() adapter.Outbound {
if s.Balancer != nil {
selected := s.Balancer.Pick()
if selected == nil {
tag := s.pickTag()
if tag == "" {
return s.fallback
}
return selected.Outbound
}
// not started
count := len(s.nodes)
if count == 0 {
// goes to fallback
outbound, ok := s.router.Outbound(tag)
if !ok {
return s.fallback
}
picked := s.nodes[rand.Intn(count)]
return picked.Outbound
return outbound
}
func (s *Balancer) pickTag() string {
if s.Balancer == nil {
// not started yet, pick a random one
return s.randomTag()
}
tag := s.Balancer.Pick()
if tag == "" {
return ""
}
return tag
}
func (s *Balancer) randomTag() string {
nodes := balancer.CoveredOutbounds(s.router, s.tags)
count := len(nodes)
if count == 0 {
return ""
}
return s.tags[rand.Intn(count)]
}

View File

@ -6,7 +6,6 @@ import (
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
)
var (
@ -23,9 +22,6 @@ type LeastLoad struct {
// NewLeastLoad creates a new LeastLoad outbound
func NewLeastLoad(router adapter.Router, logger log.ContextLogger, tag string, options option.BalancerOutboundOptions) (*LeastLoad, error) {
if len(options.Outbounds) == 0 {
return nil, E.New("missing tags")
}
return &LeastLoad{
Balancer: NewBalancer(
C.TypeLeastLoad, router, logger, tag,
@ -41,7 +37,7 @@ func (s *LeastLoad) Start() error {
if err != nil {
return err
}
b, err := balancer.NewLeastLoad(s.nodes, s.logger, s.options)
b, err := balancer.NewLeastLoad(s.router, s.logger, s.options)
if err != nil {
return err
}

View File

@ -6,7 +6,6 @@ import (
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
)
var (
@ -23,9 +22,6 @@ type LeastPing struct {
// NewLeastPing creates a new LeastPing outbound
func NewLeastPing(router adapter.Router, logger log.ContextLogger, tag string, options option.BalancerOutboundOptions) (*LeastPing, error) {
if len(options.Outbounds) == 0 {
return nil, E.New("missing tags")
}
return &LeastPing{
Balancer: NewBalancer(
C.TypeLeastPing, router, logger, tag,
@ -41,7 +37,7 @@ func (s *LeastPing) Start() error {
if err != nil {
return err
}
b, err := balancer.NewLeastPing(s.nodes, s.logger, s.options)
b, err := balancer.NewLeastPing(s.router, s.logger, s.options)
if err != nil {
return err
}