tidy up balancer package exports

This commit is contained in:
jebbs 2022-10-12 11:54:43 +08:00
parent 92adb851bd
commit 1a2e43a2b6
7 changed files with 64 additions and 67 deletions

View File

@ -32,7 +32,7 @@ type rttBasedBalancer struct {
costs *WeightManager costs *WeightManager
} }
type rttFunc func(node *Node) time.Duration type rttFunc func(n *Node) time.Duration
// newRTTBasedLoad creates a new rtt based load balancer // newRTTBasedLoad creates a new rtt based load balancer
func newRTTBasedBalancer( func newRTTBasedBalancer(
@ -56,12 +56,12 @@ func newRTTBasedBalancer(
// Select selects qualified nodes // Select selects qualified nodes
func (s *rttBasedBalancer) Networks() []string { func (s *rttBasedBalancer) Networks() []string {
hasTCP, hasUDP := false, false hasTCP, hasUDP := false, false
nodes := s.HealthCheck.NodesByCategory("") nodes := s.HealthCheck.Nodes("")
for _, node := range nodes.Qualified { for _, n := range nodes.Qualified {
if !hasTCP && common.Contains(node.Networks, N.NetworkTCP) { if !hasTCP && common.Contains(n.Networks, N.NetworkTCP) {
hasTCP = true hasTCP = true
} }
if !hasUDP && common.Contains(node.Networks, N.NetworkUDP) { if !hasUDP && common.Contains(n.Networks, N.NetworkUDP) {
hasUDP = true hasUDP = true
} }
if hasTCP && hasUDP { if hasTCP && hasUDP {
@ -69,11 +69,11 @@ func (s *rttBasedBalancer) Networks() []string {
} }
} }
if !hasTCP && !hasUDP { if !hasTCP && !hasUDP {
for _, node := range nodes.Untested { for _, n := range nodes.Untested {
if !hasTCP && common.Contains(node.Networks, N.NetworkTCP) { if !hasTCP && common.Contains(n.Networks, N.NetworkTCP) {
hasTCP = true hasTCP = true
} }
if !hasUDP && common.Contains(node.Networks, N.NetworkUDP) { if !hasUDP && common.Contains(n.Networks, N.NetworkUDP) {
hasUDP = true hasUDP = true
} }
if hasTCP && hasUDP { if hasTCP && hasUDP {
@ -93,12 +93,12 @@ func (s *rttBasedBalancer) Networks() []string {
// Select selects qualified nodes // Select selects qualified nodes
func (s *rttBasedBalancer) Pick(network string) string { func (s *rttBasedBalancer) Pick(network string) string {
nodes := s.HealthCheck.NodesByCategory(network) nodes := s.HealthCheck.Nodes(network)
var candidates []*Node var candidates []*Node
if len(nodes.Qualified) > 0 { if len(nodes.Qualified) > 0 {
candidates = nodes.Qualified candidates = nodes.Qualified
for _, node := range candidates { for _, n := range candidates {
node.Weighted = time.Duration(s.costs.Apply(node.Tag, float64(s.rttFunc(node)))) n.Weighted = time.Duration(s.costs.Apply(n.Tag, float64(s.rttFunc(n))))
} }
sortNodes(candidates) sortNodes(candidates)
} else { } else {

View File

@ -14,7 +14,7 @@ import (
// HealthCheck is the health checker for balancers // HealthCheck is the health checker for balancers
type HealthCheck struct { type HealthCheck struct {
sync.Mutex mutex sync.Mutex
ticker *time.Ticker ticker *time.Ticker
router adapter.Router router adapter.Router
@ -64,8 +64,8 @@ func NewHealthCheck(router adapter.Router, tags []string, logger log.Logger, con
// Start starts the health check service // Start starts the health check service
func (h *HealthCheck) Start() error { func (h *HealthCheck) Start() error {
h.Lock() h.mutex.Lock()
defer h.Unlock() defer h.mutex.Unlock()
if h.ticker != nil { if h.ticker != nil {
return nil return nil
} }
@ -88,8 +88,8 @@ func (h *HealthCheck) Start() error {
// Stop stops the health check service // Stop stops the health check service
func (h *HealthCheck) Stop() { func (h *HealthCheck) Stop() {
h.Lock() h.mutex.Lock()
defer h.Unlock() defer h.mutex.Unlock()
if h.ticker != nil { if h.ticker != nil {
h.ticker.Stop() h.ticker.Stop()
h.ticker = nil h.ticker = nil
@ -116,8 +116,8 @@ func (h *HealthCheck) doCheck(duration time.Duration, rounds int) {
} }
ch := make(chan *rtt, count) ch := make(chan *rtt, count)
// rtts := make(map[string][]time.Duration) // rtts := make(map[string][]time.Duration)
for _, node := range nodes { for _, n := range nodes {
tag, detour := node.Tag(), node tag, detour := n.Tag(), n
client := newPingClient( client := newPingClient(
detour, detour,
h.options.Destination, h.options.Destination,
@ -164,15 +164,15 @@ func (h *HealthCheck) doCheck(duration time.Duration, rounds int) {
if rtt.value > 0 { if rtt.value > 0 {
// h.logger.Debug("ping ", rtt.tag, ":", rtt.value) // h.logger.Debug("ping ", rtt.tag, ":", rtt.value)
// should not put results when network is down // should not put results when network is down
h.PutResult(rtt.tag, rtt.value) h.putResult(rtt.tag, rtt.value)
} }
} }
} }
// PutResult put a ping rtt to results // putResult put a ping rtt to results
func (h *HealthCheck) PutResult(tag string, rtt time.Duration) { func (h *HealthCheck) putResult(tag string, rtt time.Duration) {
h.Lock() h.mutex.Lock()
defer h.Unlock() defer h.mutex.Unlock()
r, ok := h.results[tag] r, ok := h.results[tag]
if !ok { if !ok {
// the result may come after the node is removed // the result may come after the node is removed

View File

@ -8,21 +8,28 @@ import (
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
) )
// CategorizedNodes holds the categorized nodes // Nodes holds the categorized nodes
type CategorizedNodes struct { type Nodes struct {
Qualified, Unqualified []*Node Qualified, Unqualified []*Node
Failed, Untested []*Node Failed, Untested []*Node
} }
// NodesByCategory returns the categorized nodes for specific network. // Node is a banalcer Node with health check result
type Node struct {
Tag string
Networks []string
RTTStats
}
// Nodes returns the categorized nodes for specific network.
// If network is empty, all nodes are returned. // If network is empty, all nodes are returned.
func (h *HealthCheck) NodesByCategory(network string) *CategorizedNodes { func (h *HealthCheck) Nodes(network string) *Nodes {
h.Lock() h.mutex.Lock()
defer h.Unlock() defer h.mutex.Unlock()
if h == nil || len(h.results) == 0 { if h == nil || len(h.results) == 0 {
return &CategorizedNodes{} return &Nodes{}
} }
nodes := &CategorizedNodes{ nodes := &Nodes{
Qualified: make([]*Node, 0, len(h.results)), Qualified: make([]*Node, 0, len(h.results)),
Unqualified: make([]*Node, 0, len(h.results)), Unqualified: make([]*Node, 0, len(h.results)),
Failed: make([]*Node, 0, len(h.results)), Failed: make([]*Node, 0, len(h.results)),
@ -32,21 +39,21 @@ func (h *HealthCheck) NodesByCategory(network string) *CategorizedNodes {
if network != "" && !common.Contains(result.networks, network) { if network != "" && !common.Contains(result.networks, network) {
continue continue
} }
node := &Node{ n := &Node{
Tag: tag, Tag: tag,
Networks: result.networks, Networks: result.networks,
RTTStats: result.rttStorage.Get(), RTTStats: result.rttStorage.Get(),
} }
switch { switch {
case node.RTTStats.All == 0: case n.RTTStats.All == 0:
nodes.Untested = append(nodes.Untested, node) nodes.Untested = append(nodes.Untested, n)
case node.RTTStats.All == node.RTTStats.Fail, case n.RTTStats.All == n.RTTStats.Fail,
float64(node.Fail)/float64(node.All) > float64(h.options.Tolerance): float64(n.Fail)/float64(n.All) > float64(h.options.Tolerance):
nodes.Failed = append(nodes.Failed, node) nodes.Failed = append(nodes.Failed, n)
case h.options.MaxRTT > 0 && node.Average > time.Duration(h.options.MaxRTT): case h.options.MaxRTT > 0 && n.Average > time.Duration(h.options.MaxRTT):
nodes.Unqualified = append(nodes.Unqualified, node) nodes.Unqualified = append(nodes.Unqualified, n)
default: default:
nodes.Qualified = append(nodes.Qualified, node) nodes.Qualified = append(nodes.Qualified, n)
} }
} }
return nodes return nodes
@ -69,13 +76,13 @@ func CoveredOutbounds(router adapter.Router, tags []string) []adapter.Outbound {
// refreshNodes matches nodes from router by tag prefix, and refreshes the health check results // refreshNodes matches nodes from router by tag prefix, and refreshes the health check results
func (h *HealthCheck) refreshNodes() []adapter.Outbound { func (h *HealthCheck) refreshNodes() []adapter.Outbound {
h.Lock() h.mutex.Lock()
defer h.Unlock() defer h.mutex.Unlock()
nodes := CoveredOutbounds(h.router, h.tags) nodes := CoveredOutbounds(h.router, h.tags)
tags := make(map[string]struct{}) tags := make(map[string]struct{})
for _, node := range nodes { for _, n := range nodes {
tag := node.Tag() tag := n.Tag()
tags[tag] = struct{}{} tags[tag] = struct{}{}
// make it known to the health check results // make it known to the health check results
_, ok := h.results[tag] _, ok := h.results[tag]
@ -87,7 +94,7 @@ func (h *HealthCheck) refreshNodes() []adapter.Outbound {
validity := time.Duration(h.options.Interval) * time.Duration(h.options.SamplingCount) * 2 validity := time.Duration(h.options.Interval) * time.Duration(h.options.SamplingCount) * 2
h.results[tag] = &result{ h.results[tag] = &result{
// tag: tag, // tag: tag,
networks: node.Network(), networks: n.Network(),
rttStorage: newRTTStorage(h.options.SamplingCount, validity), rttStorage: newRTTStorage(h.options.SamplingCount, validity),
} }
} }

View File

@ -15,8 +15,8 @@ func NewLeastLoad(
) (Balancer, error) { ) (Balancer, error) {
return newRTTBasedBalancer( return newRTTBasedBalancer(
router, logger, options, router, logger, options,
func(node *Node) time.Duration { func(n *Node) time.Duration {
return node.Deviation return n.Deviation
}, },
) )
} }

View File

@ -15,8 +15,8 @@ func NewLeastPing(
) (Balancer, error) { ) (Balancer, error) {
return newRTTBasedBalancer( return newRTTBasedBalancer(
router, logger, options, router, logger, options,
func(node *Node) time.Duration { func(n *Node) time.Duration {
return node.Average return n.Average
}, },
) )
} }

View File

@ -1,17 +0,0 @@
package balancer
var healthPingStatsUntested = RTTStats{
All: 0,
Fail: 0,
Deviation: rttUntested,
Average: rttUntested,
Max: rttUntested,
Min: rttUntested,
}
// Node is a banalcer node with health check result
type Node struct {
Tag string
Networks []string
RTTStats
}

View File

@ -120,7 +120,14 @@ func (h *rttStorage) getStatistics() RTTStats {
} }
switch { switch {
case stats.All == 0: case stats.All == 0:
return healthPingStatsUntested return RTTStats{
All: 0,
Fail: 0,
Deviation: rttUntested,
Average: rttUntested,
Max: rttUntested,
Min: rttUntested,
}
case stats.Fail == stats.All: case stats.Fail == stats.All:
return RTTStats{ return RTTStats{
All: stats.All, All: stats.All,