mirror of
https://github.com/SagerNet/sing-box.git
synced 2025-06-13 05:44:12 +08:00
add outbound 'leastload'
This commit is contained in:
parent
6591dd58ca
commit
e0e3c5153f
207
balancer/healthcheck.go
Normal file
207
balancer/healthcheck.go
Normal file
@ -0,0 +1,207 @@
|
||||
package balancer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
// HealthPingSettings holds settings for health Checker
|
||||
type HealthPingSettings struct {
|
||||
Destination string `json:"destination"`
|
||||
Connectivity string `json:"connectivity"`
|
||||
Interval time.Duration `json:"interval"`
|
||||
SamplingCount int `json:"sampling"`
|
||||
Timeout time.Duration `json:"timeout"`
|
||||
}
|
||||
|
||||
// HealthCheck is the health checker for balancers
|
||||
type HealthCheck struct {
|
||||
sync.Mutex
|
||||
|
||||
ticker *time.Ticker
|
||||
nodes []*Node
|
||||
logger log.Logger
|
||||
|
||||
Settings *HealthPingSettings
|
||||
Results map[string]*HealthCheckRTTS
|
||||
}
|
||||
|
||||
// NewHealthCheck creates a new HealthPing with settings
|
||||
func NewHealthCheck(outbounds []*Node, logger log.Logger, config *option.HealthCheckSettings) *HealthCheck {
|
||||
settings := &HealthPingSettings{}
|
||||
if config != nil {
|
||||
settings = &HealthPingSettings{
|
||||
Connectivity: strings.TrimSpace(config.Connectivity),
|
||||
Destination: strings.TrimSpace(config.Destination),
|
||||
Interval: time.Duration(config.Interval),
|
||||
SamplingCount: int(config.SamplingCount),
|
||||
Timeout: time.Duration(config.Timeout),
|
||||
}
|
||||
}
|
||||
if settings.Destination == "" {
|
||||
settings.Destination = "http://www.google.com/gen_204"
|
||||
}
|
||||
if settings.Interval == 0 {
|
||||
settings.Interval = time.Duration(1) * time.Minute
|
||||
} else if settings.Interval < 10 {
|
||||
logger.Warn("health check interval is too small, 10s is applied")
|
||||
settings.Interval = time.Duration(10) * time.Second
|
||||
}
|
||||
if settings.SamplingCount <= 0 {
|
||||
settings.SamplingCount = 10
|
||||
}
|
||||
if settings.Timeout <= 0 {
|
||||
// results are saved after all health pings finish,
|
||||
// a larger timeout could possibly makes checks run longer
|
||||
settings.Timeout = time.Duration(5) * time.Second
|
||||
}
|
||||
return &HealthCheck{
|
||||
nodes: outbounds,
|
||||
Settings: settings,
|
||||
Results: nil,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the health check service
|
||||
func (h *HealthCheck) Start() {
|
||||
if h.ticker != nil {
|
||||
return
|
||||
}
|
||||
interval := h.Settings.Interval * time.Duration(h.Settings.SamplingCount)
|
||||
ticker := time.NewTicker(interval)
|
||||
h.ticker = ticker
|
||||
go func() {
|
||||
for {
|
||||
h.doCheck(interval, h.Settings.SamplingCount)
|
||||
_, ok := <-ticker.C
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Stop stops the health check service
|
||||
func (h *HealthCheck) Stop() {
|
||||
h.ticker.Stop()
|
||||
h.ticker = nil
|
||||
}
|
||||
|
||||
// Check does a one time health check
|
||||
func (h *HealthCheck) Check() error {
|
||||
if len(h.nodes) == 0 {
|
||||
return nil
|
||||
}
|
||||
h.doCheck(0, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
type rtt struct {
|
||||
handler string
|
||||
value time.Duration
|
||||
}
|
||||
|
||||
// doCheck performs the 'rounds' amount checks in given 'duration'. You should make
|
||||
// sure all tags are valid for current balancer
|
||||
func (h *HealthCheck) doCheck(duration time.Duration, rounds int) {
|
||||
count := len(h.nodes) * rounds
|
||||
if count == 0 {
|
||||
return
|
||||
}
|
||||
ch := make(chan *rtt, count)
|
||||
// rtts := make(map[string][]time.Duration)
|
||||
for _, node := range h.nodes {
|
||||
tag, detour := node.Outbound.Tag(), node.Outbound
|
||||
client := newPingClient(
|
||||
detour,
|
||||
h.Settings.Destination,
|
||||
h.Settings.Timeout,
|
||||
)
|
||||
for i := 0; i < rounds; i++ {
|
||||
delay := time.Duration(0)
|
||||
if duration > 0 {
|
||||
delay = time.Duration(rand.Intn(int(duration)))
|
||||
}
|
||||
time.AfterFunc(delay, func() {
|
||||
// h.logger.Debug("checking ", tag)
|
||||
delay, err := client.MeasureDelay()
|
||||
if err == nil {
|
||||
ch <- &rtt{
|
||||
handler: tag,
|
||||
value: delay,
|
||||
}
|
||||
return
|
||||
}
|
||||
if !h.checkConnectivity() {
|
||||
h.logger.Debug("network is down")
|
||||
ch <- &rtt{
|
||||
handler: tag,
|
||||
value: 0,
|
||||
}
|
||||
return
|
||||
}
|
||||
h.logger.Debug(
|
||||
E.Cause(
|
||||
err,
|
||||
fmt.Sprintf("ping %s via %s", h.Settings.Destination, tag),
|
||||
),
|
||||
)
|
||||
ch <- &rtt{
|
||||
handler: tag,
|
||||
value: rttFailed,
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
for i := 0; i < count; i++ {
|
||||
rtt := <-ch
|
||||
if rtt.value > 0 {
|
||||
// should not put results when network is down
|
||||
h.PutResult(rtt.handler, rtt.value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// PutResult put a ping rtt to results
|
||||
func (h *HealthCheck) PutResult(tag string, rtt time.Duration) {
|
||||
h.Lock()
|
||||
defer h.Unlock()
|
||||
if h.Results == nil {
|
||||
h.Results = make(map[string]*HealthCheckRTTS)
|
||||
}
|
||||
r, ok := h.Results[tag]
|
||||
if !ok {
|
||||
// validity is 2 times to sampling period, since the check are
|
||||
// distributed in the time line randomly, in extreme cases,
|
||||
// previous checks are distributed on the left, and latters
|
||||
// on the right
|
||||
validity := h.Settings.Interval * time.Duration(h.Settings.SamplingCount) * 2
|
||||
r = NewHealthPingResult(h.Settings.SamplingCount, validity)
|
||||
h.Results[tag] = r
|
||||
}
|
||||
r.Put(rtt)
|
||||
}
|
||||
|
||||
// checkConnectivity checks the network connectivity, it returns
|
||||
// true if network is good or "connectivity check url" not set
|
||||
func (h *HealthCheck) checkConnectivity() bool {
|
||||
if h.Settings.Connectivity == "" {
|
||||
return true
|
||||
}
|
||||
tester := newDirectPingClient(
|
||||
h.Settings.Connectivity,
|
||||
h.Settings.Timeout,
|
||||
)
|
||||
if _, err := tester.MeasureDelay(); err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
151
balancer/healthcheck_result.go
Normal file
151
balancer/healthcheck_result.go
Normal file
@ -0,0 +1,151 @@
|
||||
package balancer
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
rttFailed = time.Duration(math.MaxInt64 - iota)
|
||||
rttUntested
|
||||
rttUnqualified
|
||||
)
|
||||
|
||||
// HealthCheckStats is the statistics of HealthPingRTTS
|
||||
type HealthCheckStats struct {
|
||||
All int
|
||||
Fail int
|
||||
Deviation time.Duration
|
||||
Average time.Duration
|
||||
Max time.Duration
|
||||
Min time.Duration
|
||||
|
||||
applied time.Duration
|
||||
}
|
||||
|
||||
// HealthCheckRTTS holds ping rtts for health Checker
|
||||
type HealthCheckRTTS struct {
|
||||
idx int
|
||||
cap int
|
||||
validity time.Duration
|
||||
rtts []*pingRTT
|
||||
|
||||
lastUpdateAt time.Time
|
||||
stats *HealthCheckStats
|
||||
}
|
||||
|
||||
type pingRTT struct {
|
||||
time time.Time
|
||||
value time.Duration
|
||||
}
|
||||
|
||||
// NewHealthPingResult returns a *HealthPingResult with specified capacity
|
||||
func NewHealthPingResult(cap int, validity time.Duration) *HealthCheckRTTS {
|
||||
return &HealthCheckRTTS{cap: cap, validity: validity}
|
||||
}
|
||||
|
||||
// Get gets statistics of the HealthPingRTTS
|
||||
func (h *HealthCheckRTTS) Get() *HealthCheckStats {
|
||||
return h.getStatistics()
|
||||
}
|
||||
|
||||
// GetWithCache get statistics and write cache for next call
|
||||
// Make sure use Mutex.Lock() before calling it, RWMutex.RLock()
|
||||
// is not an option since it writes cache
|
||||
func (h *HealthCheckRTTS) GetWithCache() *HealthCheckStats {
|
||||
lastPutAt := h.rtts[h.idx].time
|
||||
now := time.Now()
|
||||
if h.stats == nil || h.lastUpdateAt.Before(lastPutAt) || h.findOutdated(now) >= 0 {
|
||||
h.stats = h.getStatistics()
|
||||
h.lastUpdateAt = now
|
||||
}
|
||||
return h.stats
|
||||
}
|
||||
|
||||
// Put puts a new rtt to the HealthPingResult
|
||||
func (h *HealthCheckRTTS) Put(d time.Duration) {
|
||||
if h.rtts == nil {
|
||||
h.rtts = make([]*pingRTT, h.cap)
|
||||
for i := 0; i < h.cap; i++ {
|
||||
h.rtts[i] = &pingRTT{}
|
||||
}
|
||||
h.idx = -1
|
||||
}
|
||||
h.idx = h.calcIndex(1)
|
||||
now := time.Now()
|
||||
h.rtts[h.idx].time = now
|
||||
h.rtts[h.idx].value = d
|
||||
}
|
||||
|
||||
func (h *HealthCheckRTTS) calcIndex(step int) int {
|
||||
idx := h.idx
|
||||
idx += step
|
||||
if idx >= h.cap {
|
||||
idx %= h.cap
|
||||
}
|
||||
return idx
|
||||
}
|
||||
|
||||
func (h *HealthCheckRTTS) getStatistics() *HealthCheckStats {
|
||||
stats := &HealthCheckStats{}
|
||||
stats.Fail = 0
|
||||
stats.Max = 0
|
||||
stats.Min = rttFailed
|
||||
sum := time.Duration(0)
|
||||
cnt := 0
|
||||
validRTTs := make([]time.Duration, 0)
|
||||
for _, rtt := range h.rtts {
|
||||
switch {
|
||||
case rtt.value == 0 || time.Since(rtt.time) > h.validity:
|
||||
continue
|
||||
case rtt.value == rttFailed:
|
||||
stats.Fail++
|
||||
continue
|
||||
}
|
||||
cnt++
|
||||
sum += rtt.value
|
||||
validRTTs = append(validRTTs, rtt.value)
|
||||
if stats.Max < rtt.value {
|
||||
stats.Max = rtt.value
|
||||
}
|
||||
if stats.Min > rtt.value {
|
||||
stats.Min = rtt.value
|
||||
}
|
||||
}
|
||||
stats.All = cnt + stats.Fail
|
||||
if cnt == 0 {
|
||||
stats.Min = 0
|
||||
return stats
|
||||
}
|
||||
stats.Average = time.Duration(int(sum) / cnt)
|
||||
var std float64
|
||||
if cnt < 2 {
|
||||
// no enough data for standard deviation, we assume it's half of the average rtt
|
||||
// if we don't do this, standard deviation of 1 round tested nodes is 0, will always
|
||||
// selected before 2 or more rounds tested nodes
|
||||
std = float64(stats.Average / 2)
|
||||
} else {
|
||||
variance := float64(0)
|
||||
for _, rtt := range validRTTs {
|
||||
variance += math.Pow(float64(rtt-stats.Average), 2)
|
||||
}
|
||||
std = math.Sqrt(variance / float64(cnt))
|
||||
}
|
||||
stats.Deviation = time.Duration(std)
|
||||
return stats
|
||||
}
|
||||
|
||||
func (h *HealthCheckRTTS) findOutdated(now time.Time) int {
|
||||
for i := h.cap - 1; i < 2*h.cap; i++ {
|
||||
// from oldest to latest
|
||||
idx := h.calcIndex(i)
|
||||
validity := h.rtts[idx].time.Add(h.validity)
|
||||
if h.lastUpdateAt.After(validity) {
|
||||
return idx
|
||||
}
|
||||
if validity.Before(now) {
|
||||
return idx
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
106
balancer/healthcheck_result_test.go
Normal file
106
balancer/healthcheck_result_test.go
Normal file
@ -0,0 +1,106 @@
|
||||
package balancer_test
|
||||
|
||||
import (
|
||||
"math"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/balancer"
|
||||
)
|
||||
|
||||
func TestHealthPingResults(t *testing.T) {
|
||||
rtts := []int64{60, 140, 60, 140, 60, 60, 140, 60, 140}
|
||||
hr := balancer.NewHealthPingResult(4, time.Hour)
|
||||
for _, rtt := range rtts {
|
||||
hr.Put(time.Duration(rtt))
|
||||
}
|
||||
rttFailed := time.Duration(math.MaxInt64)
|
||||
expected := &balancer.HealthCheckStats{
|
||||
All: 4,
|
||||
Fail: 0,
|
||||
Deviation: 40,
|
||||
Average: 100,
|
||||
Max: 140,
|
||||
Min: 60,
|
||||
}
|
||||
actual := hr.Get()
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("expected: %v, actual: %v", expected, actual)
|
||||
}
|
||||
hr.Put(rttFailed)
|
||||
hr.Put(rttFailed)
|
||||
expected.Fail = 2
|
||||
actual = hr.Get()
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("failed half-failures test, expected: %v, actual: %v", expected, actual)
|
||||
}
|
||||
hr.Put(rttFailed)
|
||||
hr.Put(rttFailed)
|
||||
expected = &balancer.HealthCheckStats{
|
||||
All: 4,
|
||||
Fail: 4,
|
||||
Deviation: 0,
|
||||
Average: 0,
|
||||
Max: 0,
|
||||
Min: 0,
|
||||
}
|
||||
actual = hr.Get()
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("failed all-failures test, expected: %v, actual: %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthPingResultsIgnoreOutdated(t *testing.T) {
|
||||
rtts := []int64{60, 140, 60, 140}
|
||||
hr := balancer.NewHealthPingResult(4, time.Duration(10)*time.Millisecond)
|
||||
for i, rtt := range rtts {
|
||||
if i == 2 {
|
||||
// wait for previous 2 outdated
|
||||
time.Sleep(time.Duration(10) * time.Millisecond)
|
||||
}
|
||||
hr.Put(time.Duration(rtt))
|
||||
}
|
||||
hr.Get()
|
||||
expected := &balancer.HealthCheckStats{
|
||||
All: 2,
|
||||
Fail: 0,
|
||||
Deviation: 40,
|
||||
Average: 100,
|
||||
Max: 140,
|
||||
Min: 60,
|
||||
}
|
||||
actual := hr.Get()
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("failed 'half-outdated' test, expected: %v, actual: %v", expected, actual)
|
||||
}
|
||||
// wait for all outdated
|
||||
time.Sleep(time.Duration(10) * time.Millisecond)
|
||||
expected = &balancer.HealthCheckStats{
|
||||
All: 0,
|
||||
Fail: 0,
|
||||
Deviation: 0,
|
||||
Average: 0,
|
||||
Max: 0,
|
||||
Min: 0,
|
||||
}
|
||||
actual = hr.Get()
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("failed 'outdated / not-tested' test, expected: %v, actual: %v", expected, actual)
|
||||
}
|
||||
|
||||
hr.Put(time.Duration(60))
|
||||
expected = &balancer.HealthCheckStats{
|
||||
All: 1,
|
||||
Fail: 0,
|
||||
// 1 sample, std=0.5rtt
|
||||
Deviation: 30,
|
||||
Average: 60,
|
||||
Max: 60,
|
||||
Min: 60,
|
||||
}
|
||||
actual = hr.Get()
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("expected: %v, actual: %v", expected, actual)
|
||||
}
|
||||
}
|
164
balancer/leastload.go
Normal file
164
balancer/leastload.go
Normal file
@ -0,0 +1,164 @@
|
||||
package balancer
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
)
|
||||
|
||||
// LeastLoad is leastload balancer
|
||||
type LeastLoad struct {
|
||||
nodes []*Node
|
||||
options *option.LeastLoadOutboundOptions
|
||||
|
||||
*HealthCheck
|
||||
costs *WeightManager
|
||||
}
|
||||
|
||||
// NewLeastLoad creates a new LeastLoad outbound
|
||||
func NewLeastLoad(
|
||||
nodes []*Node, logger log.ContextLogger,
|
||||
options option.LeastLoadOutboundOptions,
|
||||
) (*LeastLoad, error) {
|
||||
return &LeastLoad{
|
||||
nodes: nodes,
|
||||
options: &options,
|
||||
HealthCheck: NewHealthCheck(nodes, logger, options.HealthCheck),
|
||||
costs: NewWeightManager(
|
||||
logger, options.Costs, 1,
|
||||
func(value, cost float64) float64 {
|
||||
return value * math.Pow(cost, 0.5)
|
||||
},
|
||||
),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Select selects qualified nodes
|
||||
func (s *LeastLoad) Select() []*Node {
|
||||
qualified, _ := s.getNodes()
|
||||
return s.selectLeastLoad(qualified)
|
||||
}
|
||||
|
||||
// selectLeastLoad selects nodes according to Baselines and Expected Count.
|
||||
//
|
||||
// The strategy always improves network response speed, not matter which mode below is configurated.
|
||||
// But they can still have different priorities.
|
||||
//
|
||||
// 1. Bandwidth priority: no Baseline + Expected Count > 0.: selects `Expected Count` of nodes.
|
||||
// (one if Expected Count <= 0)
|
||||
//
|
||||
// 2. Bandwidth priority advanced: Baselines + Expected Count > 0.
|
||||
// Select `Expected Count` amount of nodes, and also those near them according to baselines.
|
||||
// In other words, it selects according to different Baselines, until one of them matches
|
||||
// the Expected Count, if no Baseline matches, Expected Count applied.
|
||||
//
|
||||
// 3. Speed priority: Baselines + `Expected Count <= 0`.
|
||||
// go through all baselines until find selects, if not, select none. Used in combination
|
||||
// with 'balancer.fallbackTag', it means: selects qualified nodes or use the fallback.
|
||||
func (s *LeastLoad) selectLeastLoad(nodes []*Node) []*Node {
|
||||
if len(nodes) == 0 {
|
||||
// s.logger.Debug("LeastLoad: no qualified nodes")
|
||||
return nil
|
||||
}
|
||||
expected := int(s.options.Expected)
|
||||
availableCount := len(nodes)
|
||||
if expected > availableCount {
|
||||
return nodes
|
||||
}
|
||||
|
||||
if expected <= 0 {
|
||||
expected = 1
|
||||
}
|
||||
if len(s.options.Baselines) == 0 {
|
||||
return nodes[:expected]
|
||||
}
|
||||
|
||||
count := 0
|
||||
// go through all base line until find expected selects
|
||||
for _, b := range s.options.Baselines {
|
||||
baseline := time.Duration(b)
|
||||
for i := 0; i < availableCount; i++ {
|
||||
if nodes[i].applied > baseline {
|
||||
break
|
||||
}
|
||||
count = i + 1
|
||||
}
|
||||
// don't continue if find expected selects
|
||||
if count >= expected {
|
||||
s.logger.Debug("applied baseline: ", baseline)
|
||||
break
|
||||
}
|
||||
}
|
||||
if s.options.Expected > 0 && count < expected {
|
||||
count = expected
|
||||
}
|
||||
return nodes[:count]
|
||||
}
|
||||
|
||||
func (s *LeastLoad) getNodes() ([]*Node, []*Node) {
|
||||
s.HealthCheck.Lock()
|
||||
defer s.HealthCheck.Unlock()
|
||||
|
||||
qualified := make([]*Node, 0)
|
||||
unqualified := make([]*Node, 0)
|
||||
failed := make([]*Node, 0)
|
||||
untested := make([]*Node, 0)
|
||||
others := make([]*Node, 0)
|
||||
for _, node := range s.nodes {
|
||||
node.FetchStats(s.HealthCheck)
|
||||
switch {
|
||||
case node.All == 0:
|
||||
node.applied = rttUntested
|
||||
untested = append(untested, node)
|
||||
case s.options.MaxRTT > 0 && node.Average > time.Duration(s.options.MaxRTT):
|
||||
node.applied = rttUnqualified
|
||||
unqualified = append(unqualified, node)
|
||||
case float64(node.Fail)/float64(node.All) > float64(s.options.Tolerance):
|
||||
node.applied = rttFailed
|
||||
if node.All-node.Fail == 0 {
|
||||
// no good, put them after has-good nodes
|
||||
node.applied = rttFailed
|
||||
node.Deviation = rttFailed
|
||||
node.Average = rttFailed
|
||||
}
|
||||
failed = append(failed, node)
|
||||
default:
|
||||
node.applied = time.Duration(s.costs.Apply(node.Outbound.Tag(), float64(node.Deviation)))
|
||||
qualified = append(qualified, node)
|
||||
}
|
||||
}
|
||||
if len(qualified) > 0 {
|
||||
leastloadSort(qualified)
|
||||
others = append(others, unqualified...)
|
||||
others = append(others, untested...)
|
||||
others = append(others, failed...)
|
||||
} else {
|
||||
qualified = untested
|
||||
others = append(others, unqualified...)
|
||||
others = append(others, failed...)
|
||||
}
|
||||
return qualified, others
|
||||
}
|
||||
|
||||
func leastloadSort(nodes []*Node) {
|
||||
sort.Slice(nodes, func(i, j int) bool {
|
||||
left := nodes[i]
|
||||
right := nodes[j]
|
||||
if left.applied != right.applied {
|
||||
return left.applied < right.applied
|
||||
}
|
||||
if left.applied != right.applied {
|
||||
return left.applied < right.applied
|
||||
}
|
||||
if left.Average != right.Average {
|
||||
return left.Average < right.Average
|
||||
}
|
||||
if left.Fail != right.Fail {
|
||||
return left.Fail < right.Fail
|
||||
}
|
||||
return left.All > right.All
|
||||
})
|
||||
}
|
37
balancer/node.go
Normal file
37
balancer/node.go
Normal file
@ -0,0 +1,37 @@
|
||||
package balancer
|
||||
|
||||
import (
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
)
|
||||
|
||||
var healthPingStatsZero = HealthCheckStats{
|
||||
applied: rttUntested,
|
||||
}
|
||||
|
||||
// Node is a banalcer node with health check result
|
||||
type Node struct {
|
||||
Outbound adapter.Outbound
|
||||
HealthCheckStats
|
||||
}
|
||||
|
||||
// NewNode creates a new balancer node from outbound
|
||||
func NewNode(outbound adapter.Outbound) *Node {
|
||||
return &Node{
|
||||
Outbound: outbound,
|
||||
HealthCheckStats: healthPingStatsZero,
|
||||
}
|
||||
}
|
||||
|
||||
// FetchStats fetches statistics from *HealthPing p
|
||||
func (s *Node) FetchStats(p *HealthCheck) {
|
||||
if p == nil || p.Results == nil {
|
||||
s.HealthCheckStats = healthPingStatsZero
|
||||
return
|
||||
}
|
||||
r, ok := p.Results[s.Outbound.Tag()]
|
||||
if !ok {
|
||||
s.HealthCheckStats = healthPingStatsZero
|
||||
return
|
||||
}
|
||||
s.HealthCheckStats = *r.Get()
|
||||
}
|
67
balancer/ping.go
Normal file
67
balancer/ping.go
Normal file
@ -0,0 +1,67 @@
|
||||
package balancer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"net"
|
||||
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type pingClient struct {
|
||||
destination string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func newPingClient(detour N.Dialer, destination string, timeout time.Duration) *pingClient {
|
||||
return &pingClient{
|
||||
destination: destination,
|
||||
httpClient: newHTTPClient(detour, timeout),
|
||||
}
|
||||
}
|
||||
|
||||
func newDirectPingClient(destination string, timeout time.Duration) *pingClient {
|
||||
return &pingClient{
|
||||
destination: destination,
|
||||
httpClient: &http.Client{Timeout: timeout},
|
||||
}
|
||||
}
|
||||
|
||||
func newHTTPClient(detour N.Dialer, timeout time.Duration) *http.Client {
|
||||
tr := &http.Transport{
|
||||
DisableKeepAlives: true,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return detour.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
},
|
||||
}
|
||||
return &http.Client{
|
||||
Transport: tr,
|
||||
Timeout: timeout,
|
||||
// don't follow redirect
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// MeasureDelay returns the delay time of the request to dest
|
||||
func (s *pingClient) MeasureDelay() (time.Duration, error) {
|
||||
if s.httpClient == nil {
|
||||
panic("pingClient no initialized")
|
||||
}
|
||||
req, err := http.NewRequest(http.MethodHead, s.destination, nil)
|
||||
if err != nil {
|
||||
return rttFailed, err
|
||||
}
|
||||
start := time.Now()
|
||||
resp, err := s.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return rttFailed, err
|
||||
}
|
||||
// don't wait for body
|
||||
resp.Body.Close()
|
||||
return time.Since(start), nil
|
||||
}
|
91
balancer/weight.go
Normal file
91
balancer/weight.go
Normal file
@ -0,0 +1,91 @@
|
||||
package balancer
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
type weightScaler func(value, weight float64) float64
|
||||
|
||||
var numberFinder = regexp.MustCompile(`\d+(\.\d+)?`)
|
||||
|
||||
// NewWeightManager creates a new WeightManager with settings
|
||||
func NewWeightManager(logger log.Logger, s []*option.StrategyWeight, defaultWeight float64, scaler weightScaler) *WeightManager {
|
||||
return &WeightManager{
|
||||
settings: s,
|
||||
cache: make(map[string]float64),
|
||||
scaler: scaler,
|
||||
defaultWeight: defaultWeight,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// WeightManager manages weights for specific settings
|
||||
type WeightManager struct {
|
||||
settings []*option.StrategyWeight
|
||||
cache map[string]float64
|
||||
scaler weightScaler
|
||||
defaultWeight float64
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
// Get get the weight of specified tag
|
||||
func (s *WeightManager) Get(tag string) float64 {
|
||||
weight, ok := s.cache[tag]
|
||||
if ok {
|
||||
return weight
|
||||
}
|
||||
weight = s.findValue(tag)
|
||||
s.cache[tag] = weight
|
||||
return weight
|
||||
}
|
||||
|
||||
// Apply applies weight to the value
|
||||
func (s *WeightManager) Apply(tag string, value float64) float64 {
|
||||
return s.scaler(value, s.Get(tag))
|
||||
}
|
||||
|
||||
func (s *WeightManager) findValue(tag string) float64 {
|
||||
for _, w := range s.settings {
|
||||
matched := s.getMatch(tag, w.Match, w.Regexp)
|
||||
if matched == "" {
|
||||
continue
|
||||
}
|
||||
if w.Value > 0 {
|
||||
return float64(w.Value)
|
||||
}
|
||||
// auto weight from matched
|
||||
numStr := numberFinder.FindString(matched)
|
||||
if numStr == "" {
|
||||
return s.defaultWeight
|
||||
}
|
||||
weight, err := strconv.ParseFloat(numStr, 64)
|
||||
if err != nil {
|
||||
s.logger.Warn(E.Cause(err, "parse weight from tag"))
|
||||
return s.defaultWeight
|
||||
}
|
||||
return weight
|
||||
}
|
||||
return s.defaultWeight
|
||||
}
|
||||
|
||||
func (s *WeightManager) getMatch(tag, find string, isRegexp bool) string {
|
||||
if !isRegexp {
|
||||
idx := strings.Index(tag, find)
|
||||
if idx < 0 {
|
||||
return ""
|
||||
}
|
||||
return find
|
||||
}
|
||||
r, err := regexp.Compile(find)
|
||||
if err != nil {
|
||||
s.logger.Warn(E.Cause(err, "weight regexp"))
|
||||
return ""
|
||||
}
|
||||
return r.FindString(tag)
|
||||
}
|
63
balancer/weight_test.go
Normal file
63
balancer/weight_test.go
Normal file
@ -0,0 +1,63 @@
|
||||
package balancer_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing-box/balancer"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
)
|
||||
|
||||
func TestWeight(t *testing.T) {
|
||||
manager := balancer.NewWeightManager(
|
||||
log.NewNOPFactory().Logger(),
|
||||
[]*option.StrategyWeight{
|
||||
{
|
||||
Match: "x5",
|
||||
Value: 100,
|
||||
},
|
||||
{
|
||||
Match: "x8",
|
||||
},
|
||||
{
|
||||
Regexp: true,
|
||||
Match: `\bx0+(\.\d+)?\b`,
|
||||
Value: 1,
|
||||
},
|
||||
{
|
||||
Regexp: true,
|
||||
Match: `\bx\d+(\.\d+)?\b`,
|
||||
},
|
||||
},
|
||||
1, func(v, w float64) float64 {
|
||||
return v * w
|
||||
},
|
||||
)
|
||||
tags := []string{
|
||||
"node name, x5, and more",
|
||||
"node name, x8",
|
||||
"node name, x15",
|
||||
"node name, x0100, and more",
|
||||
"node name, x10.1",
|
||||
"node name, x00.1, and more",
|
||||
}
|
||||
// test weight
|
||||
expected := []float64{100, 8, 15, 100, 10.1, 1}
|
||||
actual := make([]float64, 0)
|
||||
for _, tag := range tags {
|
||||
actual = append(actual, manager.Get(tag))
|
||||
}
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("expected: %v, actual: %v", expected, actual)
|
||||
}
|
||||
// test scale
|
||||
expected2 := []float64{1000, 80, 150, 1000, 101, 10}
|
||||
actual2 := make([]float64, 0)
|
||||
for _, tag := range tags {
|
||||
actual2 = append(actual2, manager.Apply(tag, 10))
|
||||
}
|
||||
if !reflect.DeepEqual(expected2, actual2) {
|
||||
t.Errorf("expected2: %v, actual2: %v", expected2, actual2)
|
||||
}
|
||||
}
|
@ -24,6 +24,7 @@ const (
|
||||
)
|
||||
|
||||
const (
|
||||
TypeSelector = "selector"
|
||||
TypeURLTest = "urltest"
|
||||
TypeSelector = "selector"
|
||||
TypeURLTest = "urltest"
|
||||
TypeLeastLoad = "leastload"
|
||||
)
|
||||
|
35
option/balancer.go
Normal file
35
option/balancer.go
Normal file
@ -0,0 +1,35 @@
|
||||
package option
|
||||
|
||||
// LeastLoadOutboundOptions is the options for leastload outbound
|
||||
type LeastLoadOutboundOptions struct {
|
||||
Outbounds []string `json:"outbounds"`
|
||||
Fallback string `json:"fallback,omitempty"`
|
||||
// health check settings
|
||||
HealthCheck *HealthCheckSettings `json:"healthCheck,omitempty"`
|
||||
// cost settings
|
||||
Costs []*StrategyWeight `json:"costs,omitempty"`
|
||||
// ping rtt baselines (ms)
|
||||
Baselines []Duration `json:"baselines,omitempty"`
|
||||
// expected nodes count to select
|
||||
Expected int32 `json:"expected,omitempty"`
|
||||
// max acceptable rtt (ms), filter away high delay nodes. defalut 0
|
||||
MaxRTT Duration `json:"maxRTT,omitempty"`
|
||||
// acceptable failure rate
|
||||
Tolerance float64 `json:"tolerance,omitempty"`
|
||||
}
|
||||
|
||||
// HealthCheckSettings is the settings for health check
|
||||
type HealthCheckSettings struct {
|
||||
Destination string `json:"destination"`
|
||||
Connectivity string `json:"connectivity"`
|
||||
Interval Duration `json:"interval"`
|
||||
SamplingCount int `json:"sampling"`
|
||||
Timeout Duration `json:"timeout"`
|
||||
}
|
||||
|
||||
// StrategyWeight is the weight for a balancing strategy
|
||||
type StrategyWeight struct {
|
||||
Regexp bool `json:"regexp,omitempty"`
|
||||
Match string `json:"match,omitempty"`
|
||||
Value float32 `json:"value,omitempty"`
|
||||
}
|
@ -25,6 +25,7 @@ type _Outbound struct {
|
||||
VLESSOptions VLESSOutboundOptions `json:"-"`
|
||||
SelectorOptions SelectorOutboundOptions `json:"-"`
|
||||
URLTestOptions URLTestOutboundOptions `json:"-"`
|
||||
LeastLoadOptions LeastLoadOutboundOptions `json:"-"`
|
||||
}
|
||||
|
||||
type Outbound _Outbound
|
||||
@ -64,6 +65,8 @@ func (h Outbound) MarshalJSON() ([]byte, error) {
|
||||
v = h.SelectorOptions
|
||||
case C.TypeURLTest:
|
||||
v = h.URLTestOptions
|
||||
case C.TypeLeastLoad:
|
||||
v = h.LeastLoadOptions
|
||||
default:
|
||||
return nil, E.New("unknown outbound type: ", h.Type)
|
||||
}
|
||||
@ -109,6 +112,8 @@ func (h *Outbound) UnmarshalJSON(bytes []byte) error {
|
||||
v = &h.SelectorOptions
|
||||
case C.TypeURLTest:
|
||||
v = &h.URLTestOptions
|
||||
case C.TypeLeastLoad:
|
||||
v = &h.LeastLoadOptions
|
||||
default:
|
||||
return E.New("unknown outbound type: ", h.Type)
|
||||
}
|
||||
|
@ -49,6 +49,8 @@ func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, o
|
||||
return NewSelector(router, logger, options.Tag, options.SelectorOptions)
|
||||
case C.TypeURLTest:
|
||||
return NewURLTest(router, logger, options.Tag, options.URLTestOptions)
|
||||
case C.TypeLeastLoad:
|
||||
return NewLeastLoad(router, logger, options.Tag, options.LeastLoadOptions)
|
||||
default:
|
||||
return nil, E.New("unknown outbound type: ", options.Type)
|
||||
}
|
||||
|
127
outbound/leastload.go
Normal file
127
outbound/leastload.go
Normal file
@ -0,0 +1,127 @@
|
||||
package outbound
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/balancer"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
var (
|
||||
_ adapter.Outbound = (*LeastLoad)(nil)
|
||||
_ adapter.OutboundGroup = (*LeastLoad)(nil)
|
||||
)
|
||||
|
||||
// LeastLoad is a outbound group that picks outbound with least load
|
||||
type LeastLoad struct {
|
||||
myOutboundAdapter
|
||||
options option.LeastLoadOutboundOptions
|
||||
|
||||
*balancer.LeastLoad
|
||||
nodes []*balancer.Node
|
||||
fallback adapter.Outbound
|
||||
}
|
||||
|
||||
// NewLeastLoad creates a new LeastLoad outbound
|
||||
func NewLeastLoad(router adapter.Router, logger log.ContextLogger, tag string, options option.LeastLoadOutboundOptions) (*LeastLoad, error) {
|
||||
outbound := &LeastLoad{
|
||||
myOutboundAdapter: myOutboundAdapter{
|
||||
protocol: C.TypeLeastLoad,
|
||||
router: router,
|
||||
logger: logger,
|
||||
tag: tag,
|
||||
},
|
||||
options: options,
|
||||
nodes: make([]*balancer.Node, 0, len(options.Outbounds)),
|
||||
}
|
||||
if len(options.Outbounds) == 0 {
|
||||
return nil, E.New("missing tags")
|
||||
}
|
||||
return outbound, nil
|
||||
}
|
||||
|
||||
// Network implements adapter.Outbound
|
||||
func (s *LeastLoad) Network() []string {
|
||||
picked := s.pick()
|
||||
if picked == nil {
|
||||
return []string{N.NetworkTCP, N.NetworkUDP}
|
||||
}
|
||||
return picked.Network()
|
||||
}
|
||||
|
||||
// Start implements common.Starter
|
||||
func (s *LeastLoad) Start() error {
|
||||
for i, tag := range s.options.Outbounds {
|
||||
outbound, loaded := s.router.Outbound(tag)
|
||||
if !loaded {
|
||||
return E.New("outbound ", i, " not found: ", tag)
|
||||
}
|
||||
s.nodes = append(s.nodes, balancer.NewNode(outbound))
|
||||
}
|
||||
if s.options.Fallback != "" {
|
||||
outbound, loaded := s.router.Outbound(s.options.Fallback)
|
||||
if !loaded {
|
||||
return E.New("fallback outbound not found: ", s.options.Fallback)
|
||||
}
|
||||
s.fallback = outbound
|
||||
}
|
||||
var err error
|
||||
s.LeastLoad, err = balancer.NewLeastLoad(s.nodes, s.logger, s.options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.HealthCheck.Start()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Now implements adapter.OutboundGroup
|
||||
func (s *LeastLoad) Now() string {
|
||||
return s.pick().Tag()
|
||||
}
|
||||
|
||||
// All implements adapter.OutboundGroup
|
||||
func (s *LeastLoad) All() []string {
|
||||
return s.options.Outbounds
|
||||
}
|
||||
|
||||
// DialContext implements adapter.Outbound
|
||||
func (s *LeastLoad) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
return s.pick().DialContext(ctx, network, destination)
|
||||
}
|
||||
|
||||
// ListenPacket implements adapter.Outbound
|
||||
func (s *LeastLoad) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
return s.pick().ListenPacket(ctx, destination)
|
||||
}
|
||||
|
||||
// NewConnection implements adapter.Outbound
|
||||
func (s *LeastLoad) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
|
||||
return s.pick().NewConnection(ctx, conn, metadata)
|
||||
}
|
||||
|
||||
// NewPacketConnection implements adapter.Outbound
|
||||
func (s *LeastLoad) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
|
||||
return s.pick().NewPacketConnection(ctx, conn, metadata)
|
||||
}
|
||||
|
||||
func (s *LeastLoad) pick() adapter.Outbound {
|
||||
selects := s.nodes
|
||||
if s.LeastLoad != nil {
|
||||
selects = s.LeastLoad.Select()
|
||||
}
|
||||
count := len(selects)
|
||||
if count == 0 {
|
||||
// goes to fallbackTag
|
||||
return s.fallback
|
||||
}
|
||||
picked := selects[rand.Intn(count)]
|
||||
return picked.Outbound
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user