diff --git a/balancer/balancer.go b/balancer/balancer.go index b0a00b11..6430a58b 100644 --- a/balancer/balancer.go +++ b/balancer/balancer.go @@ -9,13 +9,18 @@ import ( "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common" + N "github.com/sagernet/sing/common/network" ) var _ Balancer = (*rttBasedBalancer)(nil) // Balancer is interface for load balancers type Balancer interface { - Pick() string + // Pick picks a qualified nodes + Pick(network string) string + // Networks returns the supported network types + Networks() []string } type rttBasedBalancer struct { @@ -49,8 +54,46 @@ func newRTTBasedBalancer( } // Select selects qualified nodes -func (s *rttBasedBalancer) Pick() string { - nodes := s.HealthCheck.NodesByCategory() +func (s *rttBasedBalancer) Networks() []string { + hasTCP, hasUDP := false, false + nodes := s.HealthCheck.NodesByCategory("") + for _, node := range nodes.Qualified { + if !hasTCP && common.Contains(node.Networks, N.NetworkTCP) { + hasTCP = true + } + if !hasUDP && common.Contains(node.Networks, N.NetworkUDP) { + hasUDP = true + } + if hasTCP && hasUDP { + break + } + } + if !hasTCP && !hasUDP { + for _, node := range nodes.Untested { + if !hasTCP && common.Contains(node.Networks, N.NetworkTCP) { + hasTCP = true + } + if !hasUDP && common.Contains(node.Networks, N.NetworkUDP) { + hasUDP = true + } + if hasTCP && hasUDP { + break + } + } + } + switch { + case hasTCP && hasUDP: + return []string{N.NetworkTCP, N.NetworkUDP} + case hasTCP: + return []string{N.NetworkTCP} + default: + return []string{N.NetworkUDP} + } +} + +// Select selects qualified nodes +func (s *rttBasedBalancer) Pick(network string) string { + nodes := s.HealthCheck.NodesByCategory(network) var candidates []*Node if len(nodes.Qualified) > 0 { candidates = nodes.Qualified diff --git a/balancer/healthcheck.go b/balancer/healthcheck.go index 260dd59c..ece6f0bc 100644 --- a/balancer/healthcheck.go +++ b/balancer/healthcheck.go @@ -22,7 +22,13 @@ type HealthCheck struct { logger log.Logger options *option.HealthCheckSettings - results map[string]*rttStorage + results map[string]*result +} + +type result struct { + // tag string + networks []string + *rttStorage } // NewHealthCheck creates a new HealthPing with settings @@ -51,7 +57,7 @@ func NewHealthCheck(router adapter.Router, tags []string, logger log.Logger, con router: router, tags: tags, options: config, - results: make(map[string]*rttStorage), + results: make(map[string]*result), logger: logger, } } diff --git a/balancer/healthcheck_nodes.go b/balancer/healthcheck_nodes.go index 5245fce4..aac4cf90 100644 --- a/balancer/healthcheck_nodes.go +++ b/balancer/healthcheck_nodes.go @@ -5,6 +5,7 @@ import ( "time" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing/common" ) // CategorizedNodes holds the categorized nodes @@ -13,8 +14,9 @@ type CategorizedNodes struct { Failed, Untested []*Node } -// NodesByCategory returns the categorized nodes -func (h *HealthCheck) NodesByCategory() *CategorizedNodes { +// NodesByCategory returns the categorized nodes for specific network. +// If network is empty, all nodes are returned. +func (h *HealthCheck) NodesByCategory(network string) *CategorizedNodes { h.Lock() defer h.Unlock() if h == nil || len(h.results) == 0 { @@ -27,9 +29,13 @@ func (h *HealthCheck) NodesByCategory() *CategorizedNodes { Untested: make([]*Node, 0, len(h.results)), } for tag, result := range h.results { + if network != "" && !common.Contains(result.networks, network) { + continue + } node := &Node{ Tag: tag, - RTTStats: result.Get(), + Networks: result.networks, + RTTStats: result.rttStorage.Get(), } switch { case node.RTTStats.All == 0: @@ -72,15 +78,18 @@ func (h *HealthCheck) refreshNodes() []adapter.Outbound { tag := node.Tag() tags[tag] = struct{}{} // make it known to the health check results - r, ok := h.results[tag] + _, ok := h.results[tag] if !ok { // validity is 2 times to sampling period, since the check are // distributed in the time line randomly, in extreme cases, // previous checks are distributed on the left, and latters // on the right validity := time.Duration(h.options.Interval) * time.Duration(h.options.SamplingCount) * 2 - r = newRTTStorage(h.options.SamplingCount, validity) - h.results[tag] = r + h.results[tag] = &result{ + // tag: tag, + networks: node.Network(), + rttStorage: newRTTStorage(h.options.SamplingCount, validity), + } } } // remove unused rttStorage diff --git a/balancer/node.go b/balancer/node.go index a052f60e..e0134c62 100644 --- a/balancer/node.go +++ b/balancer/node.go @@ -11,6 +11,7 @@ var healthPingStatsUntested = RTTStats{ // Node is a banalcer node with health check result type Node struct { - Tag string + Tag string + Networks []string RTTStats } diff --git a/outbound/balancer.go b/outbound/balancer.go index c85fd1ba..b7931192 100644 --- a/outbound/balancer.go +++ b/outbound/balancer.go @@ -49,23 +49,16 @@ func NewBalancer( } // Network implements adapter.Outbound -// -// FIXME: logic issue: -// picked node is very likely to be different between first "Network() assetion" -// then "NewConnection()", maybe we need to keep the picked node in the context? -// that requests to change the Network() signature of the interface of -// adapter.Outbound func (s *Balancer) Network() []string { - picked := s.pick() - if picked == nil { + if s.Balancer == nil { return []string{N.NetworkTCP, N.NetworkUDP} } - return picked.Network() + return s.Balancer.Networks() } // Now implements adapter.OutboundGroup func (s *Balancer) Now() string { - return s.pick().Tag() + return s.pick("").Tag() } // All implements adapter.OutboundGroup @@ -75,22 +68,22 @@ func (s *Balancer) All() []string { // DialContext implements adapter.Outbound func (s *Balancer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { - return s.pick().DialContext(ctx, network, destination) + return s.pick(network).DialContext(ctx, network, destination) } // ListenPacket implements adapter.Outbound func (s *Balancer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - return s.pick().ListenPacket(ctx, destination) + return s.pick(N.NetworkUDP).ListenPacket(ctx, destination) } // NewConnection implements adapter.Outbound func (s *Balancer) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { - return s.pick().NewConnection(ctx, conn, metadata) + return s.pick(N.NetworkTCP).NewConnection(ctx, conn, metadata) } // NewPacketConnection implements adapter.Outbound func (s *Balancer) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { - return s.pick().NewPacketConnection(ctx, conn, metadata) + return s.pick(N.NetworkUDP).NewPacketConnection(ctx, conn, metadata) } // initialize inits the balancer @@ -119,8 +112,8 @@ func (s *Balancer) setBalancer(b balancer.Balancer) error { return nil } -func (s *Balancer) pick() adapter.Outbound { - tag := s.pickTag() +func (s *Balancer) pick(network string) adapter.Outbound { + tag := s.pickTag(network) if tag == "" { return s.fallback } @@ -131,12 +124,12 @@ func (s *Balancer) pick() adapter.Outbound { return outbound } -func (s *Balancer) pickTag() string { +func (s *Balancer) pickTag(network string) string { if s.Balancer == nil { // not started yet, pick a random one return s.randomTag() } - tag := s.Balancer.Pick() + tag := s.Balancer.Pick(network) if tag == "" { return "" }