add outbound 'leastload'

This commit is contained in:
jebbs 2022-10-17 15:22:33 +08:00
parent 6591dd58ca
commit e0e3c5153f
13 changed files with 1058 additions and 2 deletions

207
balancer/healthcheck.go Normal file
View File

@ -0,0 +1,207 @@
package balancer
import (
"fmt"
"math/rand"
"strings"
"sync"
"time"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
)
// HealthPingSettings holds settings for health Checker
type HealthPingSettings struct {
Destination string `json:"destination"`
Connectivity string `json:"connectivity"`
Interval time.Duration `json:"interval"`
SamplingCount int `json:"sampling"`
Timeout time.Duration `json:"timeout"`
}
// HealthCheck is the health checker for balancers
type HealthCheck struct {
sync.Mutex
ticker *time.Ticker
nodes []*Node
logger log.Logger
Settings *HealthPingSettings
Results map[string]*HealthCheckRTTS
}
// NewHealthCheck creates a new HealthPing with settings
func NewHealthCheck(outbounds []*Node, logger log.Logger, config *option.HealthCheckSettings) *HealthCheck {
settings := &HealthPingSettings{}
if config != nil {
settings = &HealthPingSettings{
Connectivity: strings.TrimSpace(config.Connectivity),
Destination: strings.TrimSpace(config.Destination),
Interval: time.Duration(config.Interval),
SamplingCount: int(config.SamplingCount),
Timeout: time.Duration(config.Timeout),
}
}
if settings.Destination == "" {
settings.Destination = "http://www.google.com/gen_204"
}
if settings.Interval == 0 {
settings.Interval = time.Duration(1) * time.Minute
} else if settings.Interval < 10 {
logger.Warn("health check interval is too small, 10s is applied")
settings.Interval = time.Duration(10) * time.Second
}
if settings.SamplingCount <= 0 {
settings.SamplingCount = 10
}
if settings.Timeout <= 0 {
// results are saved after all health pings finish,
// a larger timeout could possibly makes checks run longer
settings.Timeout = time.Duration(5) * time.Second
}
return &HealthCheck{
nodes: outbounds,
Settings: settings,
Results: nil,
logger: logger,
}
}
// Start starts the health check service
func (h *HealthCheck) Start() {
if h.ticker != nil {
return
}
interval := h.Settings.Interval * time.Duration(h.Settings.SamplingCount)
ticker := time.NewTicker(interval)
h.ticker = ticker
go func() {
for {
h.doCheck(interval, h.Settings.SamplingCount)
_, ok := <-ticker.C
if !ok {
break
}
}
}()
}
// Stop stops the health check service
func (h *HealthCheck) Stop() {
h.ticker.Stop()
h.ticker = nil
}
// Check does a one time health check
func (h *HealthCheck) Check() error {
if len(h.nodes) == 0 {
return nil
}
h.doCheck(0, 1)
return nil
}
type rtt struct {
handler string
value time.Duration
}
// doCheck performs the 'rounds' amount checks in given 'duration'. You should make
// sure all tags are valid for current balancer
func (h *HealthCheck) doCheck(duration time.Duration, rounds int) {
count := len(h.nodes) * rounds
if count == 0 {
return
}
ch := make(chan *rtt, count)
// rtts := make(map[string][]time.Duration)
for _, node := range h.nodes {
tag, detour := node.Outbound.Tag(), node.Outbound
client := newPingClient(
detour,
h.Settings.Destination,
h.Settings.Timeout,
)
for i := 0; i < rounds; i++ {
delay := time.Duration(0)
if duration > 0 {
delay = time.Duration(rand.Intn(int(duration)))
}
time.AfterFunc(delay, func() {
// h.logger.Debug("checking ", tag)
delay, err := client.MeasureDelay()
if err == nil {
ch <- &rtt{
handler: tag,
value: delay,
}
return
}
if !h.checkConnectivity() {
h.logger.Debug("network is down")
ch <- &rtt{
handler: tag,
value: 0,
}
return
}
h.logger.Debug(
E.Cause(
err,
fmt.Sprintf("ping %s via %s", h.Settings.Destination, tag),
),
)
ch <- &rtt{
handler: tag,
value: rttFailed,
}
})
}
}
for i := 0; i < count; i++ {
rtt := <-ch
if rtt.value > 0 {
// should not put results when network is down
h.PutResult(rtt.handler, rtt.value)
}
}
}
// PutResult put a ping rtt to results
func (h *HealthCheck) PutResult(tag string, rtt time.Duration) {
h.Lock()
defer h.Unlock()
if h.Results == nil {
h.Results = make(map[string]*HealthCheckRTTS)
}
r, ok := h.Results[tag]
if !ok {
// validity is 2 times to sampling period, since the check are
// distributed in the time line randomly, in extreme cases,
// previous checks are distributed on the left, and latters
// on the right
validity := h.Settings.Interval * time.Duration(h.Settings.SamplingCount) * 2
r = NewHealthPingResult(h.Settings.SamplingCount, validity)
h.Results[tag] = r
}
r.Put(rtt)
}
// checkConnectivity checks the network connectivity, it returns
// true if network is good or "connectivity check url" not set
func (h *HealthCheck) checkConnectivity() bool {
if h.Settings.Connectivity == "" {
return true
}
tester := newDirectPingClient(
h.Settings.Connectivity,
h.Settings.Timeout,
)
if _, err := tester.MeasureDelay(); err != nil {
return false
}
return true
}

View File

@ -0,0 +1,151 @@
package balancer
import (
"math"
"time"
)
const (
rttFailed = time.Duration(math.MaxInt64 - iota)
rttUntested
rttUnqualified
)
// HealthCheckStats is the statistics of HealthPingRTTS
type HealthCheckStats struct {
All int
Fail int
Deviation time.Duration
Average time.Duration
Max time.Duration
Min time.Duration
applied time.Duration
}
// HealthCheckRTTS holds ping rtts for health Checker
type HealthCheckRTTS struct {
idx int
cap int
validity time.Duration
rtts []*pingRTT
lastUpdateAt time.Time
stats *HealthCheckStats
}
type pingRTT struct {
time time.Time
value time.Duration
}
// NewHealthPingResult returns a *HealthPingResult with specified capacity
func NewHealthPingResult(cap int, validity time.Duration) *HealthCheckRTTS {
return &HealthCheckRTTS{cap: cap, validity: validity}
}
// Get gets statistics of the HealthPingRTTS
func (h *HealthCheckRTTS) Get() *HealthCheckStats {
return h.getStatistics()
}
// GetWithCache get statistics and write cache for next call
// Make sure use Mutex.Lock() before calling it, RWMutex.RLock()
// is not an option since it writes cache
func (h *HealthCheckRTTS) GetWithCache() *HealthCheckStats {
lastPutAt := h.rtts[h.idx].time
now := time.Now()
if h.stats == nil || h.lastUpdateAt.Before(lastPutAt) || h.findOutdated(now) >= 0 {
h.stats = h.getStatistics()
h.lastUpdateAt = now
}
return h.stats
}
// Put puts a new rtt to the HealthPingResult
func (h *HealthCheckRTTS) Put(d time.Duration) {
if h.rtts == nil {
h.rtts = make([]*pingRTT, h.cap)
for i := 0; i < h.cap; i++ {
h.rtts[i] = &pingRTT{}
}
h.idx = -1
}
h.idx = h.calcIndex(1)
now := time.Now()
h.rtts[h.idx].time = now
h.rtts[h.idx].value = d
}
func (h *HealthCheckRTTS) calcIndex(step int) int {
idx := h.idx
idx += step
if idx >= h.cap {
idx %= h.cap
}
return idx
}
func (h *HealthCheckRTTS) getStatistics() *HealthCheckStats {
stats := &HealthCheckStats{}
stats.Fail = 0
stats.Max = 0
stats.Min = rttFailed
sum := time.Duration(0)
cnt := 0
validRTTs := make([]time.Duration, 0)
for _, rtt := range h.rtts {
switch {
case rtt.value == 0 || time.Since(rtt.time) > h.validity:
continue
case rtt.value == rttFailed:
stats.Fail++
continue
}
cnt++
sum += rtt.value
validRTTs = append(validRTTs, rtt.value)
if stats.Max < rtt.value {
stats.Max = rtt.value
}
if stats.Min > rtt.value {
stats.Min = rtt.value
}
}
stats.All = cnt + stats.Fail
if cnt == 0 {
stats.Min = 0
return stats
}
stats.Average = time.Duration(int(sum) / cnt)
var std float64
if cnt < 2 {
// no enough data for standard deviation, we assume it's half of the average rtt
// if we don't do this, standard deviation of 1 round tested nodes is 0, will always
// selected before 2 or more rounds tested nodes
std = float64(stats.Average / 2)
} else {
variance := float64(0)
for _, rtt := range validRTTs {
variance += math.Pow(float64(rtt-stats.Average), 2)
}
std = math.Sqrt(variance / float64(cnt))
}
stats.Deviation = time.Duration(std)
return stats
}
func (h *HealthCheckRTTS) findOutdated(now time.Time) int {
for i := h.cap - 1; i < 2*h.cap; i++ {
// from oldest to latest
idx := h.calcIndex(i)
validity := h.rtts[idx].time.Add(h.validity)
if h.lastUpdateAt.After(validity) {
return idx
}
if validity.Before(now) {
return idx
}
}
return -1
}

View File

@ -0,0 +1,106 @@
package balancer_test
import (
"math"
"reflect"
"testing"
"time"
"github.com/sagernet/sing-box/balancer"
)
func TestHealthPingResults(t *testing.T) {
rtts := []int64{60, 140, 60, 140, 60, 60, 140, 60, 140}
hr := balancer.NewHealthPingResult(4, time.Hour)
for _, rtt := range rtts {
hr.Put(time.Duration(rtt))
}
rttFailed := time.Duration(math.MaxInt64)
expected := &balancer.HealthCheckStats{
All: 4,
Fail: 0,
Deviation: 40,
Average: 100,
Max: 140,
Min: 60,
}
actual := hr.Get()
if !reflect.DeepEqual(expected, actual) {
t.Errorf("expected: %v, actual: %v", expected, actual)
}
hr.Put(rttFailed)
hr.Put(rttFailed)
expected.Fail = 2
actual = hr.Get()
if !reflect.DeepEqual(expected, actual) {
t.Errorf("failed half-failures test, expected: %v, actual: %v", expected, actual)
}
hr.Put(rttFailed)
hr.Put(rttFailed)
expected = &balancer.HealthCheckStats{
All: 4,
Fail: 4,
Deviation: 0,
Average: 0,
Max: 0,
Min: 0,
}
actual = hr.Get()
if !reflect.DeepEqual(expected, actual) {
t.Errorf("failed all-failures test, expected: %v, actual: %v", expected, actual)
}
}
func TestHealthPingResultsIgnoreOutdated(t *testing.T) {
rtts := []int64{60, 140, 60, 140}
hr := balancer.NewHealthPingResult(4, time.Duration(10)*time.Millisecond)
for i, rtt := range rtts {
if i == 2 {
// wait for previous 2 outdated
time.Sleep(time.Duration(10) * time.Millisecond)
}
hr.Put(time.Duration(rtt))
}
hr.Get()
expected := &balancer.HealthCheckStats{
All: 2,
Fail: 0,
Deviation: 40,
Average: 100,
Max: 140,
Min: 60,
}
actual := hr.Get()
if !reflect.DeepEqual(expected, actual) {
t.Errorf("failed 'half-outdated' test, expected: %v, actual: %v", expected, actual)
}
// wait for all outdated
time.Sleep(time.Duration(10) * time.Millisecond)
expected = &balancer.HealthCheckStats{
All: 0,
Fail: 0,
Deviation: 0,
Average: 0,
Max: 0,
Min: 0,
}
actual = hr.Get()
if !reflect.DeepEqual(expected, actual) {
t.Errorf("failed 'outdated / not-tested' test, expected: %v, actual: %v", expected, actual)
}
hr.Put(time.Duration(60))
expected = &balancer.HealthCheckStats{
All: 1,
Fail: 0,
// 1 sample, std=0.5rtt
Deviation: 30,
Average: 60,
Max: 60,
Min: 60,
}
actual = hr.Get()
if !reflect.DeepEqual(expected, actual) {
t.Errorf("expected: %v, actual: %v", expected, actual)
}
}

164
balancer/leastload.go Normal file
View File

@ -0,0 +1,164 @@
package balancer
import (
"math"
"sort"
"time"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
)
// LeastLoad is leastload balancer
type LeastLoad struct {
nodes []*Node
options *option.LeastLoadOutboundOptions
*HealthCheck
costs *WeightManager
}
// NewLeastLoad creates a new LeastLoad outbound
func NewLeastLoad(
nodes []*Node, logger log.ContextLogger,
options option.LeastLoadOutboundOptions,
) (*LeastLoad, error) {
return &LeastLoad{
nodes: nodes,
options: &options,
HealthCheck: NewHealthCheck(nodes, logger, options.HealthCheck),
costs: NewWeightManager(
logger, options.Costs, 1,
func(value, cost float64) float64 {
return value * math.Pow(cost, 0.5)
},
),
}, nil
}
// Select selects qualified nodes
func (s *LeastLoad) Select() []*Node {
qualified, _ := s.getNodes()
return s.selectLeastLoad(qualified)
}
// selectLeastLoad selects nodes according to Baselines and Expected Count.
//
// The strategy always improves network response speed, not matter which mode below is configurated.
// But they can still have different priorities.
//
// 1. Bandwidth priority: no Baseline + Expected Count > 0.: selects `Expected Count` of nodes.
// (one if Expected Count <= 0)
//
// 2. Bandwidth priority advanced: Baselines + Expected Count > 0.
// Select `Expected Count` amount of nodes, and also those near them according to baselines.
// In other words, it selects according to different Baselines, until one of them matches
// the Expected Count, if no Baseline matches, Expected Count applied.
//
// 3. Speed priority: Baselines + `Expected Count <= 0`.
// go through all baselines until find selects, if not, select none. Used in combination
// with 'balancer.fallbackTag', it means: selects qualified nodes or use the fallback.
func (s *LeastLoad) selectLeastLoad(nodes []*Node) []*Node {
if len(nodes) == 0 {
// s.logger.Debug("LeastLoad: no qualified nodes")
return nil
}
expected := int(s.options.Expected)
availableCount := len(nodes)
if expected > availableCount {
return nodes
}
if expected <= 0 {
expected = 1
}
if len(s.options.Baselines) == 0 {
return nodes[:expected]
}
count := 0
// go through all base line until find expected selects
for _, b := range s.options.Baselines {
baseline := time.Duration(b)
for i := 0; i < availableCount; i++ {
if nodes[i].applied > baseline {
break
}
count = i + 1
}
// don't continue if find expected selects
if count >= expected {
s.logger.Debug("applied baseline: ", baseline)
break
}
}
if s.options.Expected > 0 && count < expected {
count = expected
}
return nodes[:count]
}
func (s *LeastLoad) getNodes() ([]*Node, []*Node) {
s.HealthCheck.Lock()
defer s.HealthCheck.Unlock()
qualified := make([]*Node, 0)
unqualified := make([]*Node, 0)
failed := make([]*Node, 0)
untested := make([]*Node, 0)
others := make([]*Node, 0)
for _, node := range s.nodes {
node.FetchStats(s.HealthCheck)
switch {
case node.All == 0:
node.applied = rttUntested
untested = append(untested, node)
case s.options.MaxRTT > 0 && node.Average > time.Duration(s.options.MaxRTT):
node.applied = rttUnqualified
unqualified = append(unqualified, node)
case float64(node.Fail)/float64(node.All) > float64(s.options.Tolerance):
node.applied = rttFailed
if node.All-node.Fail == 0 {
// no good, put them after has-good nodes
node.applied = rttFailed
node.Deviation = rttFailed
node.Average = rttFailed
}
failed = append(failed, node)
default:
node.applied = time.Duration(s.costs.Apply(node.Outbound.Tag(), float64(node.Deviation)))
qualified = append(qualified, node)
}
}
if len(qualified) > 0 {
leastloadSort(qualified)
others = append(others, unqualified...)
others = append(others, untested...)
others = append(others, failed...)
} else {
qualified = untested
others = append(others, unqualified...)
others = append(others, failed...)
}
return qualified, others
}
func leastloadSort(nodes []*Node) {
sort.Slice(nodes, func(i, j int) bool {
left := nodes[i]
right := nodes[j]
if left.applied != right.applied {
return left.applied < right.applied
}
if left.applied != right.applied {
return left.applied < right.applied
}
if left.Average != right.Average {
return left.Average < right.Average
}
if left.Fail != right.Fail {
return left.Fail < right.Fail
}
return left.All > right.All
})
}

37
balancer/node.go Normal file
View File

@ -0,0 +1,37 @@
package balancer
import (
"github.com/sagernet/sing-box/adapter"
)
var healthPingStatsZero = HealthCheckStats{
applied: rttUntested,
}
// Node is a banalcer node with health check result
type Node struct {
Outbound adapter.Outbound
HealthCheckStats
}
// NewNode creates a new balancer node from outbound
func NewNode(outbound adapter.Outbound) *Node {
return &Node{
Outbound: outbound,
HealthCheckStats: healthPingStatsZero,
}
}
// FetchStats fetches statistics from *HealthPing p
func (s *Node) FetchStats(p *HealthCheck) {
if p == nil || p.Results == nil {
s.HealthCheckStats = healthPingStatsZero
return
}
r, ok := p.Results[s.Outbound.Tag()]
if !ok {
s.HealthCheckStats = healthPingStatsZero
return
}
s.HealthCheckStats = *r.Get()
}

67
balancer/ping.go Normal file
View File

@ -0,0 +1,67 @@
package balancer
import (
"context"
"net/http"
"time"
"net"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type pingClient struct {
destination string
httpClient *http.Client
}
func newPingClient(detour N.Dialer, destination string, timeout time.Duration) *pingClient {
return &pingClient{
destination: destination,
httpClient: newHTTPClient(detour, timeout),
}
}
func newDirectPingClient(destination string, timeout time.Duration) *pingClient {
return &pingClient{
destination: destination,
httpClient: &http.Client{Timeout: timeout},
}
}
func newHTTPClient(detour N.Dialer, timeout time.Duration) *http.Client {
tr := &http.Transport{
DisableKeepAlives: true,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return detour.DialContext(ctx, network, M.ParseSocksaddr(addr))
},
}
return &http.Client{
Transport: tr,
Timeout: timeout,
// don't follow redirect
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
}
// MeasureDelay returns the delay time of the request to dest
func (s *pingClient) MeasureDelay() (time.Duration, error) {
if s.httpClient == nil {
panic("pingClient no initialized")
}
req, err := http.NewRequest(http.MethodHead, s.destination, nil)
if err != nil {
return rttFailed, err
}
start := time.Now()
resp, err := s.httpClient.Do(req)
if err != nil {
return rttFailed, err
}
// don't wait for body
resp.Body.Close()
return time.Since(start), nil
}

91
balancer/weight.go Normal file
View File

@ -0,0 +1,91 @@
package balancer
import (
"regexp"
"strconv"
"strings"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
)
type weightScaler func(value, weight float64) float64
var numberFinder = regexp.MustCompile(`\d+(\.\d+)?`)
// NewWeightManager creates a new WeightManager with settings
func NewWeightManager(logger log.Logger, s []*option.StrategyWeight, defaultWeight float64, scaler weightScaler) *WeightManager {
return &WeightManager{
settings: s,
cache: make(map[string]float64),
scaler: scaler,
defaultWeight: defaultWeight,
logger: logger,
}
}
// WeightManager manages weights for specific settings
type WeightManager struct {
settings []*option.StrategyWeight
cache map[string]float64
scaler weightScaler
defaultWeight float64
logger log.Logger
}
// Get get the weight of specified tag
func (s *WeightManager) Get(tag string) float64 {
weight, ok := s.cache[tag]
if ok {
return weight
}
weight = s.findValue(tag)
s.cache[tag] = weight
return weight
}
// Apply applies weight to the value
func (s *WeightManager) Apply(tag string, value float64) float64 {
return s.scaler(value, s.Get(tag))
}
func (s *WeightManager) findValue(tag string) float64 {
for _, w := range s.settings {
matched := s.getMatch(tag, w.Match, w.Regexp)
if matched == "" {
continue
}
if w.Value > 0 {
return float64(w.Value)
}
// auto weight from matched
numStr := numberFinder.FindString(matched)
if numStr == "" {
return s.defaultWeight
}
weight, err := strconv.ParseFloat(numStr, 64)
if err != nil {
s.logger.Warn(E.Cause(err, "parse weight from tag"))
return s.defaultWeight
}
return weight
}
return s.defaultWeight
}
func (s *WeightManager) getMatch(tag, find string, isRegexp bool) string {
if !isRegexp {
idx := strings.Index(tag, find)
if idx < 0 {
return ""
}
return find
}
r, err := regexp.Compile(find)
if err != nil {
s.logger.Warn(E.Cause(err, "weight regexp"))
return ""
}
return r.FindString(tag)
}

63
balancer/weight_test.go Normal file
View File

@ -0,0 +1,63 @@
package balancer_test
import (
"reflect"
"testing"
"github.com/sagernet/sing-box/balancer"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
)
func TestWeight(t *testing.T) {
manager := balancer.NewWeightManager(
log.NewNOPFactory().Logger(),
[]*option.StrategyWeight{
{
Match: "x5",
Value: 100,
},
{
Match: "x8",
},
{
Regexp: true,
Match: `\bx0+(\.\d+)?\b`,
Value: 1,
},
{
Regexp: true,
Match: `\bx\d+(\.\d+)?\b`,
},
},
1, func(v, w float64) float64 {
return v * w
},
)
tags := []string{
"node name, x5, and more",
"node name, x8",
"node name, x15",
"node name, x0100, and more",
"node name, x10.1",
"node name, x00.1, and more",
}
// test weight
expected := []float64{100, 8, 15, 100, 10.1, 1}
actual := make([]float64, 0)
for _, tag := range tags {
actual = append(actual, manager.Get(tag))
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("expected: %v, actual: %v", expected, actual)
}
// test scale
expected2 := []float64{1000, 80, 150, 1000, 101, 10}
actual2 := make([]float64, 0)
for _, tag := range tags {
actual2 = append(actual2, manager.Apply(tag, 10))
}
if !reflect.DeepEqual(expected2, actual2) {
t.Errorf("expected2: %v, actual2: %v", expected2, actual2)
}
}

View File

@ -24,6 +24,7 @@ const (
)
const (
TypeSelector = "selector"
TypeURLTest = "urltest"
TypeSelector = "selector"
TypeURLTest = "urltest"
TypeLeastLoad = "leastload"
)

35
option/balancer.go Normal file
View File

@ -0,0 +1,35 @@
package option
// LeastLoadOutboundOptions is the options for leastload outbound
type LeastLoadOutboundOptions struct {
Outbounds []string `json:"outbounds"`
Fallback string `json:"fallback,omitempty"`
// health check settings
HealthCheck *HealthCheckSettings `json:"healthCheck,omitempty"`
// cost settings
Costs []*StrategyWeight `json:"costs,omitempty"`
// ping rtt baselines (ms)
Baselines []Duration `json:"baselines,omitempty"`
// expected nodes count to select
Expected int32 `json:"expected,omitempty"`
// max acceptable rtt (ms), filter away high delay nodes. defalut 0
MaxRTT Duration `json:"maxRTT,omitempty"`
// acceptable failure rate
Tolerance float64 `json:"tolerance,omitempty"`
}
// HealthCheckSettings is the settings for health check
type HealthCheckSettings struct {
Destination string `json:"destination"`
Connectivity string `json:"connectivity"`
Interval Duration `json:"interval"`
SamplingCount int `json:"sampling"`
Timeout Duration `json:"timeout"`
}
// StrategyWeight is the weight for a balancing strategy
type StrategyWeight struct {
Regexp bool `json:"regexp,omitempty"`
Match string `json:"match,omitempty"`
Value float32 `json:"value,omitempty"`
}

View File

@ -25,6 +25,7 @@ type _Outbound struct {
VLESSOptions VLESSOutboundOptions `json:"-"`
SelectorOptions SelectorOutboundOptions `json:"-"`
URLTestOptions URLTestOutboundOptions `json:"-"`
LeastLoadOptions LeastLoadOutboundOptions `json:"-"`
}
type Outbound _Outbound
@ -64,6 +65,8 @@ func (h Outbound) MarshalJSON() ([]byte, error) {
v = h.SelectorOptions
case C.TypeURLTest:
v = h.URLTestOptions
case C.TypeLeastLoad:
v = h.LeastLoadOptions
default:
return nil, E.New("unknown outbound type: ", h.Type)
}
@ -109,6 +112,8 @@ func (h *Outbound) UnmarshalJSON(bytes []byte) error {
v = &h.SelectorOptions
case C.TypeURLTest:
v = &h.URLTestOptions
case C.TypeLeastLoad:
v = &h.LeastLoadOptions
default:
return E.New("unknown outbound type: ", h.Type)
}

View File

@ -49,6 +49,8 @@ func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, o
return NewSelector(router, logger, options.Tag, options.SelectorOptions)
case C.TypeURLTest:
return NewURLTest(router, logger, options.Tag, options.URLTestOptions)
case C.TypeLeastLoad:
return NewLeastLoad(router, logger, options.Tag, options.LeastLoadOptions)
default:
return nil, E.New("unknown outbound type: ", options.Type)
}

127
outbound/leastload.go Normal file
View File

@ -0,0 +1,127 @@
package outbound
import (
"context"
"math/rand"
"net"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/balancer"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
var (
_ adapter.Outbound = (*LeastLoad)(nil)
_ adapter.OutboundGroup = (*LeastLoad)(nil)
)
// LeastLoad is a outbound group that picks outbound with least load
type LeastLoad struct {
myOutboundAdapter
options option.LeastLoadOutboundOptions
*balancer.LeastLoad
nodes []*balancer.Node
fallback adapter.Outbound
}
// NewLeastLoad creates a new LeastLoad outbound
func NewLeastLoad(router adapter.Router, logger log.ContextLogger, tag string, options option.LeastLoadOutboundOptions) (*LeastLoad, error) {
outbound := &LeastLoad{
myOutboundAdapter: myOutboundAdapter{
protocol: C.TypeLeastLoad,
router: router,
logger: logger,
tag: tag,
},
options: options,
nodes: make([]*balancer.Node, 0, len(options.Outbounds)),
}
if len(options.Outbounds) == 0 {
return nil, E.New("missing tags")
}
return outbound, nil
}
// Network implements adapter.Outbound
func (s *LeastLoad) Network() []string {
picked := s.pick()
if picked == nil {
return []string{N.NetworkTCP, N.NetworkUDP}
}
return picked.Network()
}
// Start implements common.Starter
func (s *LeastLoad) Start() error {
for i, tag := range s.options.Outbounds {
outbound, loaded := s.router.Outbound(tag)
if !loaded {
return E.New("outbound ", i, " not found: ", tag)
}
s.nodes = append(s.nodes, balancer.NewNode(outbound))
}
if s.options.Fallback != "" {
outbound, loaded := s.router.Outbound(s.options.Fallback)
if !loaded {
return E.New("fallback outbound not found: ", s.options.Fallback)
}
s.fallback = outbound
}
var err error
s.LeastLoad, err = balancer.NewLeastLoad(s.nodes, s.logger, s.options)
if err != nil {
return err
}
s.HealthCheck.Start()
return nil
}
// Now implements adapter.OutboundGroup
func (s *LeastLoad) Now() string {
return s.pick().Tag()
}
// All implements adapter.OutboundGroup
func (s *LeastLoad) All() []string {
return s.options.Outbounds
}
// DialContext implements adapter.Outbound
func (s *LeastLoad) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
return s.pick().DialContext(ctx, network, destination)
}
// ListenPacket implements adapter.Outbound
func (s *LeastLoad) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return s.pick().ListenPacket(ctx, destination)
}
// NewConnection implements adapter.Outbound
func (s *LeastLoad) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
return s.pick().NewConnection(ctx, conn, metadata)
}
// NewPacketConnection implements adapter.Outbound
func (s *LeastLoad) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
return s.pick().NewPacketConnection(ctx, conn, metadata)
}
func (s *LeastLoad) pick() adapter.Outbound {
selects := s.nodes
if s.LeastLoad != nil {
selects = s.LeastLoad.Select()
}
count := len(selects)
if count == 0 {
// goes to fallbackTag
return s.fallback
}
picked := selects[rand.Intn(count)]
return picked.Outbound
}