diff --git a/adapter/router.go b/adapter/router.go index ca4d6547..52eb5c8f 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -2,12 +2,14 @@ package adapter import ( "context" + "net/http" "net/netip" "github.com/sagernet/sing-box/common/geoip" "github.com/sagernet/sing-dns" "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common/control" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/service" mdns "github.com/miekg/dns" @@ -83,8 +85,14 @@ type DNSRule interface { } type RuleSet interface { + StartContext(ctx context.Context, startContext RuleSetStartContext) error + Close() error HeadlessRule - Service +} + +type RuleSetStartContext interface { + HTTPClient(detour string, dialer N.Dialer) *http.Client + Close() } type InterfaceUpdateListener interface { diff --git a/go.mod b/go.mod index 095326d7..69f0946a 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,7 @@ require ( github.com/sagernet/gvisor v0.0.0-20231119034329-07cfb6aaf930 github.com/sagernet/quic-go v0.40.0 github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 - github.com/sagernet/sing v0.2.18-0.20231129075305-eb56a60214be + github.com/sagernet/sing v0.2.18-0.20231130092223-1f82310f0375 github.com/sagernet/sing-dns v0.1.11 github.com/sagernet/sing-mux v0.1.5-0.20231109075101-6b086ed6bb07 github.com/sagernet/sing-quic v0.1.5-0.20231123150216-00957d136203 diff --git a/go.sum b/go.sum index 6d8d994a..a3d6fd73 100644 --- a/go.sum +++ b/go.sum @@ -110,8 +110,8 @@ github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 h1:5Th31OC6yj8byL github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691/go.mod h1:B8lp4WkQ1PwNnrVMM6KyuFR20pU8jYBD+A4EhJovEXU= github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY= github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk= -github.com/sagernet/sing v0.2.18-0.20231129075305-eb56a60214be h1:FigAM9kq7RRXmHvgn8w2a8tqCY5CMV5GIk0id84dI0o= -github.com/sagernet/sing v0.2.18-0.20231129075305-eb56a60214be/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo= +github.com/sagernet/sing v0.2.18-0.20231130092223-1f82310f0375 h1:5Q5K/twBNT1Hrpjd5Ghft0Sv0V+eVfTZX17CiPItSV8= +github.com/sagernet/sing v0.2.18-0.20231130092223-1f82310f0375/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo= github.com/sagernet/sing-dns v0.1.11 h1:PPrMCVVrAeR3f5X23I+cmvacXJ+kzuyAsBiWyUKhGSE= github.com/sagernet/sing-dns v0.1.11/go.mod h1:zJ/YjnYB61SYE+ubMcMqVdpaSvsyQ2iShQGO3vuLvvE= github.com/sagernet/sing-mux v0.1.5-0.20231109075101-6b086ed6bb07 h1:ncKb5tVOsCQgCsv6UpsA0jinbNb5OQ5GMPJlyQP3EHM= diff --git a/route/router.go b/route/router.go index 96706195..c7021823 100644 --- a/route/router.go +++ b/route/router.go @@ -39,6 +39,7 @@ import ( M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" serviceNTP "github.com/sagernet/sing/common/ntp" + "github.com/sagernet/sing/common/task" "github.com/sagernet/sing/common/uot" "github.com/sagernet/sing/service" "github.com/sagernet/sing/service/pause" @@ -490,12 +491,24 @@ func (r *Router) Start() error { if r.needWIFIState { r.updateWIFIState() } + ruleSetStartContext := NewRuleSetStartContext() + var ruleSetStartGroup task.Group for i, ruleSet := range r.ruleSets { - err := ruleSet.Start() - if err != nil { - return E.Cause(err, "initialize rule-set[", i, "]") - } + ruleSetStartGroup.Append0(func(ctx context.Context) error { + err := ruleSet.StartContext(ctx, ruleSetStartContext) + if err != nil { + return E.Cause(err, "initialize rule-set[", i, "]") + } + return nil + }) } + ruleSetStartGroup.Concurrency(5) + ruleSetStartGroup.FastFail() + err := ruleSetStartGroup.Run(r.ctx) + if err != nil { + return err + } + ruleSetStartContext.Close() for i, rule := range r.rules { err := rule.Start() if err != nil { diff --git a/route/rule_set.go b/route/rule_set.go index 76c78c62..f644fb40 100644 --- a/route/rule_set.go +++ b/route/rule_set.go @@ -2,12 +2,17 @@ package route import ( "context" + "net" + "net/http" + "sync" "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" ) func NewRuleSet(ctx context.Context, router adapter.Router, logger logger.ContextLogger, options option.RuleSet) (adapter.RuleSet, error) { @@ -20,3 +25,43 @@ func NewRuleSet(ctx context.Context, router adapter.Router, logger logger.Contex return nil, E.New("unknown rule set type: ", options.Type) } } + +var _ adapter.RuleSetStartContext = (*RuleSetStartContext)(nil) + +type RuleSetStartContext struct { + access sync.Mutex + httpClientCache map[string]*http.Client +} + +func NewRuleSetStartContext() *RuleSetStartContext { + return &RuleSetStartContext{ + httpClientCache: make(map[string]*http.Client), + } +} + +func (c *RuleSetStartContext) HTTPClient(detour string, dialer N.Dialer) *http.Client { + c.access.Lock() + defer c.access.Unlock() + if httpClient, loaded := c.httpClientCache[detour]; loaded { + return httpClient + } + httpClient := &http.Client{ + Transport: &http.Transport{ + ForceAttemptHTTP2: true, + TLSHandshakeTimeout: C.TCPTimeout, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + }, + } + c.httpClientCache[detour] = httpClient + return httpClient +} + +func (c *RuleSetStartContext) Close() { + c.access.Lock() + defer c.access.Unlock() + for _, client := range c.httpClientCache { + client.CloseIdleConnections() + } +} diff --git a/route/rule_set_local.go b/route/rule_set_local.go index ccdb1704..b466012a 100644 --- a/route/rule_set_local.go +++ b/route/rule_set_local.go @@ -1,6 +1,7 @@ package route import ( + "context" "os" "github.com/sagernet/sing-box/adapter" @@ -60,7 +61,7 @@ func (s *LocalRuleSet) Match(metadata *adapter.InboundContext) bool { return false } -func (s *LocalRuleSet) Start() error { +func (s *LocalRuleSet) StartContext(ctx context.Context, startContext adapter.RuleSetStartContext) error { return nil } diff --git a/route/rule_set_remote.go b/route/rule_set_remote.go index d9102320..05b18bf6 100644 --- a/route/rule_set_remote.go +++ b/route/rule_set_remote.go @@ -63,7 +63,7 @@ func (s *RemoteRuleSet) Match(metadata *adapter.InboundContext) bool { return false } -func (s *RemoteRuleSet) Start() error { +func (s *RemoteRuleSet) StartContext(ctx context.Context, startContext adapter.RuleSetStartContext) error { var dialer N.Dialer if s.options.RemoteOptions.DownloadDetour != "" { outbound, loaded := s.router.Outbound(s.options.RemoteOptions.DownloadDetour) @@ -91,7 +91,7 @@ func (s *RemoteRuleSet) Start() error { } } if s.lastUpdated.IsZero() || time.Since(s.lastUpdated) > s.updateInterval { - err := s.fetchOnce() + err := s.fetchOnce(ctx, startContext) if err != nil { return E.Cause(err, "fetch rule-set ", s.options.Tag) } @@ -141,7 +141,7 @@ func (s *RemoteRuleSet) loopUpdate() { case <-s.ctx.Done(): return case <-s.updateTicker.C: - err := s.fetchOnce() + err := s.fetchOnce(s.ctx, nil) if err != nil { s.logger.Error("fetch rule-set ", s.options.Tag, ": ", err) } @@ -149,18 +149,22 @@ func (s *RemoteRuleSet) loopUpdate() { } } -func (s *RemoteRuleSet) fetchOnce() error { +func (s *RemoteRuleSet) fetchOnce(ctx context.Context, startContext adapter.RuleSetStartContext) error { s.logger.Debug("updating rule-set ", s.options.Tag, " from URL: ", s.options.RemoteOptions.URL) - httpClient := &http.Client{ - Transport: &http.Transport{ - ForceAttemptHTTP2: true, - TLSHandshakeTimeout: C.TCPTimeout, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return s.dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + var httpClient *http.Client + if startContext != nil { + httpClient = startContext.HTTPClient(s.options.RemoteOptions.DownloadDetour, s.dialer) + } else { + httpClient = &http.Client{ + Transport: &http.Transport{ + ForceAttemptHTTP2: true, + TLSHandshakeTimeout: C.TCPTimeout, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return s.dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, }, - }, + } } - defer httpClient.CloseIdleConnections() request, err := http.NewRequest("GET", s.options.RemoteOptions.URL, nil) if err != nil { return err @@ -168,7 +172,7 @@ func (s *RemoteRuleSet) fetchOnce() error { if s.lastEtag != "" { request.Header.Set("If-None-Match", s.lastEtag) } - response, err := httpClient.Do(request.WithContext(s.ctx)) + response, err := httpClient.Do(request.WithContext(ctx)) if err != nil { return err }