diff --git a/balancer/balancer.go b/balancer/balancer.go new file mode 100644 index 00000000..ac21dd4e --- /dev/null +++ b/balancer/balancer.go @@ -0,0 +1,6 @@ +package balancer + +// Balancer is interface for load balancers +type Balancer interface { + Select() *Node +} diff --git a/balancer/healthcheck.go b/balancer/healthcheck.go index 15697206..f1552679 100644 --- a/balancer/healthcheck.go +++ b/balancer/healthcheck.go @@ -71,22 +71,24 @@ func NewHealthCheck(outbounds []*Node, logger log.Logger, config *option.HealthC } // Start starts the health check service -func (h *HealthCheck) Start() { +func (h *HealthCheck) Start() error { if h.ticker != nil { - return + return nil } interval := h.Settings.Interval * time.Duration(h.Settings.SamplingCount) ticker := time.NewTicker(interval) h.ticker = ticker go func() { + h.doCheck(0, 1) for { - h.doCheck(interval, h.Settings.SamplingCount) _, ok := <-ticker.C if !ok { break } + h.doCheck(interval, h.Settings.SamplingCount) } }() + return nil } // Stop stops the health check service @@ -105,8 +107,8 @@ func (h *HealthCheck) Check() error { } type rtt struct { - handler string - value time.Duration + tag string + value time.Duration } // doCheck performs the 'rounds' amount checks in given 'duration'. You should make @@ -135,16 +137,16 @@ func (h *HealthCheck) doCheck(duration time.Duration, rounds int) { delay, err := client.MeasureDelay() if err == nil { ch <- &rtt{ - handler: tag, - value: delay, + tag: tag, + value: delay, } return } if !h.checkConnectivity() { h.logger.Debug("network is down") ch <- &rtt{ - handler: tag, - value: 0, + tag: tag, + value: 0, } return } @@ -155,8 +157,8 @@ func (h *HealthCheck) doCheck(duration time.Duration, rounds int) { ), ) ch <- &rtt{ - handler: tag, - value: rttFailed, + tag: tag, + value: rttFailed, } }) } @@ -164,8 +166,9 @@ func (h *HealthCheck) doCheck(duration time.Duration, rounds int) { for i := 0; i < count; i++ { rtt := <-ch if rtt.value > 0 { + // h.logger.Debug("ping ", rtt.tag, ":", rtt.value) // should not put results when network is down - h.PutResult(rtt.handler, rtt.value) + h.PutResult(rtt.tag, rtt.value) } } } diff --git a/balancer/leastload.go b/balancer/leastload.go index 5c9d9cf7..6e0c91ca 100644 --- a/balancer/leastload.go +++ b/balancer/leastload.go @@ -2,6 +2,7 @@ package balancer import ( "math" + "math/rand" "sort" "time" @@ -9,6 +10,8 @@ import ( "github.com/sagernet/sing-box/option" ) +var _ Balancer = (*LeastLoad)(nil) + // LeastLoad is leastload balancer type LeastLoad struct { nodes []*Node @@ -22,11 +25,11 @@ type LeastLoad struct { func NewLeastLoad( nodes []*Node, logger log.ContextLogger, options option.LeastLoadOutboundOptions, -) (*LeastLoad, error) { +) (Balancer, error) { return &LeastLoad{ nodes: nodes, options: &options, - HealthCheck: NewHealthCheck(nodes, logger, options.HealthCheck), + HealthCheck: NewHealthCheck(nodes, logger, &options.HealthCheck), costs: NewWeightManager( logger, options.Costs, 1, func(value, cost float64) float64 { @@ -37,9 +40,14 @@ func NewLeastLoad( } // Select selects qualified nodes -func (s *LeastLoad) Select() []*Node { +func (s *LeastLoad) Select() *Node { qualified, _ := s.getNodes() - return s.selectLeastLoad(qualified) + selects := s.selectLeastLoad(qualified) + count := len(selects) + if count == 0 { + return nil + } + return selects[rand.Intn(count)] } // selectLeastLoad selects nodes according to Baselines and Expected Count. @@ -150,9 +158,6 @@ func leastloadSort(nodes []*Node) { if left.applied != right.applied { return left.applied < right.applied } - if left.applied != right.applied { - return left.applied < right.applied - } if left.Average != right.Average { return left.Average < right.Average } diff --git a/balancer/leastping.go b/balancer/leastping.go new file mode 100644 index 00000000..ceedcb01 --- /dev/null +++ b/balancer/leastping.go @@ -0,0 +1,109 @@ +package balancer + +import ( + "math/rand" + "sort" + "time" + + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" +) + +var _ Balancer = (*LeastPing)(nil) + +// LeastPing is least ping balancer +type LeastPing struct { + nodes []*Node + options *option.LeastPingOutboundOptions + + *HealthCheck +} + +// NewLeastPing creates a new LeastPing outbound +func NewLeastPing( + nodes []*Node, logger log.ContextLogger, + options option.LeastPingOutboundOptions, +) (Balancer, error) { + return &LeastPing{ + nodes: nodes, + options: &options, + HealthCheck: NewHealthCheck(nodes, logger, &options.HealthCheck), + }, nil +} + +// Select selects least ping node +func (s *LeastPing) Select() *Node { + qualified, _ := s.getNodes() + if len(qualified) == 0 { + return nil + } + return qualified[0] +} + +func (s *LeastPing) getNodes() ([]*Node, []*Node) { + s.HealthCheck.Lock() + defer s.HealthCheck.Unlock() + + qualified := make([]*Node, 0) + unqualified := make([]*Node, 0) + failed := make([]*Node, 0) + untested := make([]*Node, 0) + others := make([]*Node, 0) + for _, node := range s.nodes { + node.FetchStats(s.HealthCheck) + switch { + case node.All == 0: + node.applied = rttUntested + untested = append(untested, node) + case s.options.MaxRTT > 0 && node.Average > time.Duration(s.options.MaxRTT): + node.applied = rttUnqualified + unqualified = append(unqualified, node) + case float64(node.Fail)/float64(node.All) > float64(s.options.Tolerance): + node.applied = rttFailed + if node.All-node.Fail == 0 { + // no good, put them after has-good nodes + node.applied = rttFailed + node.Deviation = rttFailed + node.Average = rttFailed + } + failed = append(failed, node) + default: + node.applied = node.Average + qualified = append(qualified, node) + } + } + if len(qualified) > 0 { + leastPingSort(qualified) + others = append(others, unqualified...) + others = append(others, untested...) + others = append(others, failed...) + } else { + // random node if not tested + shuffle(untested) + qualified = untested + others = append(others, unqualified...) + others = append(others, failed...) + } + return qualified, others +} + +func leastPingSort(nodes []*Node) { + sort.Slice(nodes, func(i, j int) bool { + left := nodes[i] + right := nodes[j] + if left.applied != right.applied { + return left.applied < right.applied + } + if left.Fail != right.Fail { + return left.Fail < right.Fail + } + return left.All > right.All + }) +} + +func shuffle(nodes []*Node) { + rand.Seed(time.Now().Unix()) + rand.Shuffle(len(nodes), func(i, j int) { + nodes[i], nodes[j] = nodes[j], nodes[i] + }) +} diff --git a/constant/proxy.go b/constant/proxy.go index af7edcab..867b41de 100644 --- a/constant/proxy.go +++ b/constant/proxy.go @@ -24,4 +24,5 @@ const ( const ( TypeSelector = "selector" TypeLeastLoad = "leastload" + TypeLeastPing = "leastping" ) diff --git a/option/balancer.go b/option/balancer.go index 4b0305ec..3f52a889 100644 --- a/option/balancer.go +++ b/option/balancer.go @@ -1,23 +1,39 @@ package option -// LeastLoadOutboundOptions is the options for leastload outbound -type LeastLoadOutboundOptions struct { +// BalancerOutboundOptions is the options for balancer outbound +type BalancerOutboundOptions struct { Outbounds []string `json:"outbounds"` Fallback string `json:"fallback,omitempty"` +} + +// HealthCheckOptions is the options for health check +type HealthCheckOptions struct { // health check settings - HealthCheck *HealthCheckSettings `json:"healthCheck,omitempty"` - // cost settings - Costs []*StrategyWeight `json:"costs,omitempty"` - // ping rtt baselines (ms) - Baselines []Duration `json:"baselines,omitempty"` - // expected nodes count to select - Expected int32 `json:"expected,omitempty"` + HealthCheck HealthCheckSettings `json:"healthCheck,omitempty"` // max acceptable rtt (ms), filter away high delay nodes. defalut 0 MaxRTT Duration `json:"maxRTT,omitempty"` // acceptable failure rate Tolerance float64 `json:"tolerance,omitempty"` } +// LeastPingOutboundOptions is the options for leastping outbound +type LeastPingOutboundOptions struct { + BalancerOutboundOptions + HealthCheckOptions +} + +// LeastLoadOutboundOptions is the options for leastload outbound +type LeastLoadOutboundOptions struct { + BalancerOutboundOptions + HealthCheckOptions + // expected nodes count to select + Expected int32 `json:"expected,omitempty"` + // ping rtt baselines (ms) + Baselines []Duration `json:"baselines,omitempty"` + // cost settings + Costs []*StrategyWeight `json:"costs,omitempty"` +} + // HealthCheckSettings is the settings for health check type HealthCheckSettings struct { Destination string `json:"destination"` diff --git a/option/outbound.go b/option/outbound.go index 52d70de0..4b5003e6 100644 --- a/option/outbound.go +++ b/option/outbound.go @@ -23,6 +23,7 @@ type _Outbound struct { ShadowTLSOptions ShadowTLSOutboundOptions `json:"-"` SelectorOptions SelectorOutboundOptions `json:"-"` LeastLoadOptions LeastLoadOutboundOptions `json:"-"` + LeastPingOptions LeastPingOutboundOptions `json:"-"` } type Outbound _Outbound @@ -58,6 +59,8 @@ func (h Outbound) MarshalJSON() ([]byte, error) { v = h.SelectorOptions case C.TypeLeastLoad: v = h.LeastLoadOptions + case C.TypeLeastPing: + v = h.LeastPingOptions default: return nil, E.New("unknown outbound type: ", h.Type) } @@ -99,6 +102,8 @@ func (h *Outbound) UnmarshalJSON(bytes []byte) error { v = &h.SelectorOptions case C.TypeLeastLoad: v = &h.LeastLoadOptions + case C.TypeLeastPing: + v = &h.LeastPingOptions default: return E.New("unknown outbound type: ", h.Type) } diff --git a/outbound/balancer.go b/outbound/balancer.go new file mode 100644 index 00000000..0b6e1d9b --- /dev/null +++ b/outbound/balancer.go @@ -0,0 +1,137 @@ +package outbound + +import ( + "context" + "math/rand" + "net" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/balancer" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +var ( + _ adapter.Outbound = (*Balancer)(nil) + _ adapter.OutboundGroup = (*Balancer)(nil) +) + +// Balancer is a outbound group that picks outbound with least load +type Balancer struct { + myOutboundAdapter + + tags []string + fallbackTag string + + balancer.Balancer + nodes []*balancer.Node + fallback adapter.Outbound +} + +// NewBalancer creates a new Balancer outbound +func NewBalancer( + protocol string, router adapter.Router, logger log.ContextLogger, tag string, + outbounds []string, fallbackTag string, +) *Balancer { + b := &Balancer{ + myOutboundAdapter: myOutboundAdapter{ + protocol: protocol, + router: router, + logger: logger, + tag: tag, + }, + tags: outbounds, + fallbackTag: fallbackTag, + } + return b +} + +// Network implements adapter.Outbound +func (s *Balancer) Network() []string { + picked := s.pick() + if picked == nil { + return []string{N.NetworkTCP, N.NetworkUDP} + } + return picked.Network() +} + +// Now implements adapter.OutboundGroup +func (s *Balancer) Now() string { + return s.pick().Tag() +} + +// All implements adapter.OutboundGroup +func (s *Balancer) All() []string { + return s.tags +} + +// DialContext implements adapter.Outbound +func (s *Balancer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + return s.pick().DialContext(ctx, network, destination) +} + +// ListenPacket implements adapter.Outbound +func (s *Balancer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + return s.pick().ListenPacket(ctx, destination) +} + +// NewConnection implements adapter.Outbound +func (s *Balancer) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { + return s.pick().NewConnection(ctx, conn, metadata) +} + +// NewPacketConnection implements adapter.Outbound +func (s *Balancer) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + return s.pick().NewPacketConnection(ctx, conn, metadata) +} + +// initialize inits the balancer +func (s *Balancer) initialize() error { + for i, tag := range s.tags { + outbound, loaded := s.router.Outbound(tag) + if !loaded { + return E.New("outbound ", i, " not found: ", tag) + } + s.nodes = append(s.nodes, balancer.NewNode(outbound)) + } + if s.fallbackTag != "" { + outbound, loaded := s.router.Outbound(s.fallbackTag) + if !loaded { + return E.New("fallback outbound not found: ", s.fallbackTag) + } + s.fallback = outbound + } + return nil +} + +func (s *Balancer) setBalancer(b balancer.Balancer) error { + s.Balancer = b + if starter, isStarter := b.(common.Starter); isStarter { + err := starter.Start() + if err != nil { + return err + } + } + return nil +} + +func (s *Balancer) pick() adapter.Outbound { + if s.Balancer != nil { + selected := s.Balancer.Select() + if selected == nil { + return s.fallback + } + return selected.Outbound + } + // not started + count := len(s.nodes) + if count == 0 { + // goes to fallbackTag + return s.fallback + } + picked := s.nodes[rand.Intn(count)] + return picked.Outbound +} diff --git a/outbound/builder.go b/outbound/builder.go index 60306942..ab748ed2 100644 --- a/outbound/builder.go +++ b/outbound/builder.go @@ -45,6 +45,8 @@ func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, o return NewSelector(router, logger, options.Tag, options.SelectorOptions) case C.TypeLeastLoad: return NewLeastLoad(router, logger, options.Tag, options.LeastLoadOptions) + case C.TypeLeastPing: + return NewLeastPing(router, logger, options.Tag, options.LeastPingOptions) default: return nil, E.New("unknown outbound type: ", options.Type) } diff --git a/outbound/leastload.go b/outbound/leastload.go index 844cbc14..ce1e0478 100644 --- a/outbound/leastload.go +++ b/outbound/leastload.go @@ -1,18 +1,12 @@ package outbound import ( - "context" - "math/rand" - "net" - "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/balancer" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" ) var ( @@ -22,106 +16,34 @@ var ( // LeastLoad is a outbound group that picks outbound with least load type LeastLoad struct { - myOutboundAdapter - options option.LeastLoadOutboundOptions + *Balancer - *balancer.LeastLoad - nodes []*balancer.Node - fallback adapter.Outbound + options option.LeastLoadOutboundOptions } // NewLeastLoad creates a new LeastLoad outbound func NewLeastLoad(router adapter.Router, logger log.ContextLogger, tag string, options option.LeastLoadOutboundOptions) (*LeastLoad, error) { - outbound := &LeastLoad{ - myOutboundAdapter: myOutboundAdapter{ - protocol: C.TypeLeastLoad, - router: router, - logger: logger, - tag: tag, - }, - options: options, - nodes: make([]*balancer.Node, 0, len(options.Outbounds)), - } if len(options.Outbounds) == 0 { return nil, E.New("missing tags") } - return outbound, nil -} - -// Network implements adapter.Outbound -func (s *LeastLoad) Network() []string { - picked := s.pick() - if picked == nil { - return []string{N.NetworkTCP, N.NetworkUDP} - } - return picked.Network() + return &LeastLoad{ + Balancer: NewBalancer( + C.TypeLeastLoad, router, logger, tag, + options.Outbounds, options.Fallback, + ), + options: options, + }, nil } // Start implements common.Starter func (s *LeastLoad) Start() error { - for i, tag := range s.options.Outbounds { - outbound, loaded := s.router.Outbound(tag) - if !loaded { - return E.New("outbound ", i, " not found: ", tag) - } - s.nodes = append(s.nodes, balancer.NewNode(outbound)) - } - if s.options.Fallback != "" { - outbound, loaded := s.router.Outbound(s.options.Fallback) - if !loaded { - return E.New("fallback outbound not found: ", s.options.Fallback) - } - s.fallback = outbound - } - var err error - s.LeastLoad, err = balancer.NewLeastLoad(s.nodes, s.logger, s.options) + err := s.Balancer.initialize() if err != nil { return err } - s.HealthCheck.Start() - return nil -} - -// Now implements adapter.OutboundGroup -func (s *LeastLoad) Now() string { - return s.pick().Tag() -} - -// All implements adapter.OutboundGroup -func (s *LeastLoad) All() []string { - return s.options.Outbounds -} - -// DialContext implements adapter.Outbound -func (s *LeastLoad) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { - return s.pick().DialContext(ctx, network, destination) -} - -// ListenPacket implements adapter.Outbound -func (s *LeastLoad) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - return s.pick().ListenPacket(ctx, destination) -} - -// NewConnection implements adapter.Outbound -func (s *LeastLoad) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { - return s.pick().NewConnection(ctx, conn, metadata) -} - -// NewPacketConnection implements adapter.Outbound -func (s *LeastLoad) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { - return s.pick().NewPacketConnection(ctx, conn, metadata) -} - -func (s *LeastLoad) pick() adapter.Outbound { - selects := s.nodes - if s.LeastLoad != nil { - selects = s.LeastLoad.Select() + b, err := balancer.NewLeastLoad(s.nodes, s.logger, s.options) + if err != nil { + return err } - count := len(selects) - if count == 0 { - // goes to fallbackTag - return s.fallback - } - picked := selects[rand.Intn(count)] - return picked.Outbound + return s.setBalancer(b) } diff --git a/outbound/leastping.go b/outbound/leastping.go new file mode 100644 index 00000000..f3172a68 --- /dev/null +++ b/outbound/leastping.go @@ -0,0 +1,49 @@ +package outbound + +import ( + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/balancer" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" +) + +var ( + _ adapter.Outbound = (*LeastPing)(nil) + _ adapter.OutboundGroup = (*LeastPing)(nil) +) + +// LeastPing is a outbound group that picks outbound with least load +type LeastPing struct { + *Balancer + + options option.LeastPingOutboundOptions +} + +// NewLeastPing creates a new LeastPing outbound +func NewLeastPing(router adapter.Router, logger log.ContextLogger, tag string, options option.LeastPingOutboundOptions) (*LeastPing, error) { + if len(options.Outbounds) == 0 { + return nil, E.New("missing tags") + } + return &LeastPing{ + Balancer: NewBalancer( + C.TypeLeastPing, router, logger, tag, + options.Outbounds, options.Fallback, + ), + options: options, + }, nil +} + +// Start implements common.Starter +func (s *LeastPing) Start() error { + err := s.Balancer.initialize() + if err != nil { + return err + } + b, err := balancer.NewLeastPing(s.nodes, s.logger, s.options) + if err != nil { + return err + } + return s.setBalancer(b) +}