diff --git a/option/rule_set.go b/option/rule_set.go index b6ec113e..f7730dfc 100644 --- a/option/rule_set.go +++ b/option/rule_set.go @@ -17,8 +17,8 @@ type _RuleSet struct { Type string `json:"type,omitempty"` Tag string `json:"tag"` Format string `json:"format,omitempty"` + Path string `json:"path,omitempty"` InlineOptions PlainRuleSet `json:"-"` - LocalOptions LocalRuleSet `json:"-"` RemoteOptions RemoteRuleSet `json:"-"` } @@ -31,7 +31,7 @@ func (r RuleSet) MarshalJSON() ([]byte, error) { r.Type = "" v = r.InlineOptions case C.RuleSetTypeLocal: - v = r.LocalOptions + v = nil case C.RuleSetTypeRemote: v = r.RemoteOptions default: @@ -58,6 +58,7 @@ func (r *RuleSet) UnmarshalJSON(bytes []byte) error { } } else { r.Format = "" + r.Path = "" } var v any switch r.Type { @@ -65,7 +66,7 @@ func (r *RuleSet) UnmarshalJSON(bytes []byte) error { r.Type = C.RuleSetTypeInline v = &r.InlineOptions case C.RuleSetTypeLocal: - v = &r.LocalOptions + v = nil case C.RuleSetTypeRemote: v = &r.RemoteOptions default: @@ -78,10 +79,6 @@ func (r *RuleSet) UnmarshalJSON(bytes []byte) error { return nil } -type LocalRuleSet struct { - Path string `json:"path,omitempty"` -} - type RemoteRuleSet struct { URL string `json:"url"` DownloadDetour string `json:"download_detour,omitempty"` diff --git a/route/rule_set_abstract.go b/route/rule_set_abstract.go new file mode 100644 index 00000000..49bae13e --- /dev/null +++ b/route/rule_set_abstract.go @@ -0,0 +1,155 @@ +package route + +import ( + "bytes" + "io" + "os" + "strings" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/srs" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/atomic" + E "github.com/sagernet/sing/common/exceptions" + F "github.com/sagernet/sing/common/format" + "github.com/sagernet/sing/common/json" + "github.com/sagernet/sing/common/logger" + "github.com/sagernet/sing/common/rw" + + "go4.org/netipx" +) + +type abstractRuleSet struct { + router adapter.Router + logger logger.ContextLogger + tag string + path string + format string + rules []adapter.HeadlessRule + metadata adapter.RuleSetMetadata + lastUpdated time.Time + refs atomic.Int32 +} + +func (s *abstractRuleSet) Name() string { + return s.tag +} + +func (s *abstractRuleSet) String() string { + return strings.Join(F.MapToString(s.rules), " ") +} + +func (s *abstractRuleSet) getPath(path string) (string, error) { + if path == "" { + path = s.tag + switch s.format { + case C.RuleSetFormatSource, "": + path += ".json" + case C.RuleSetFormatBinary: + path += ".srs" + } + } + if rw.IsDir(path) { + return "", E.New("rule_set path is a directory: ", path) + } + return path, nil +} + +func (s *abstractRuleSet) Metadata() adapter.RuleSetMetadata { + return s.metadata +} + +func (s *abstractRuleSet) ExtractIPSet() []*netipx.IPSet { + return common.FlatMap(s.rules, extractIPSetFromRule) +} + +func (s *abstractRuleSet) IncRef() { + s.refs.Add(1) +} + +func (s *abstractRuleSet) DecRef() { + if s.refs.Add(-1) < 0 { + panic("rule-set: negative refs") + } +} + +func (s *abstractRuleSet) Cleanup() { + if s.refs.Load() == 0 { + s.rules = nil + } +} + +func (s *abstractRuleSet) loadFromFile(path string) error { + file, err := os.Open(path) + if err != nil { + return err + } + content, err := io.ReadAll(file) + if err != nil { + return err + } + err = s.loadBytes(content) + if err != nil { + return err + } + fs, _ := file.Stat() + s.lastUpdated = fs.ModTime() + return nil +} + +func (s *abstractRuleSet) loadBytes(content []byte) error { + var ( + plainRuleSet option.PlainRuleSet + err error + ) + switch s.format { + case C.RuleSetFormatSource: + var compat option.PlainRuleSetCompat + compat, err = json.UnmarshalExtended[option.PlainRuleSetCompat](content) + if err != nil { + return err + } + plainRuleSet, err = compat.Upgrade() + if err != nil { + return err + } + case C.RuleSetFormatBinary: + plainRuleSet, err = srs.Read(bytes.NewReader(content), false) + if err != nil { + return err + } + default: + return E.New("unknown rule-set format: ", s.format) + } + return s.reloadRules(plainRuleSet.Rules) +} + +func (s *abstractRuleSet) reloadRules(headlessRules []option.HeadlessRule) error { + rules := make([]adapter.HeadlessRule, len(headlessRules)) + var err error + for i, ruleOptions := range headlessRules { + rules[i], err = NewHeadlessRule(s.router, ruleOptions) + if err != nil { + return E.Cause(err, "parse rule_set.rules.[", i, "]") + } + } + var metadata adapter.RuleSetMetadata + metadata.ContainsProcessRule = hasHeadlessRule(headlessRules, isProcessHeadlessRule) + metadata.ContainsWIFIRule = hasHeadlessRule(headlessRules, isWIFIHeadlessRule) + metadata.ContainsIPCIDRRule = hasHeadlessRule(headlessRules, isIPCIDRHeadlessRule) + s.rules = rules + s.metadata = metadata + return nil +} + +func (s *abstractRuleSet) Match(metadata *adapter.InboundContext) bool { + for _, rule := range s.rules { + if rule.Match(metadata) { + return true + } + } + return false +} diff --git a/route/rule_set_local.go b/route/rule_set_local.go index 893842d5..f7cb5edd 100644 --- a/route/rule_set_local.go +++ b/route/rule_set_local.go @@ -2,46 +2,33 @@ package route import ( "context" - "os" "path/filepath" - "strings" "github.com/sagernet/fswatch" "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/common/srs" C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/atomic" E "github.com/sagernet/sing/common/exceptions" - F "github.com/sagernet/sing/common/format" - "github.com/sagernet/sing/common/json" "github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/x/list" - "github.com/sagernet/sing/service/filemanager" - - "go4.org/netipx" ) var _ adapter.RuleSet = (*LocalRuleSet)(nil) type LocalRuleSet struct { - router adapter.Router - logger logger.Logger - tag string - rules []adapter.HeadlessRule - metadata adapter.RuleSetMetadata - fileFormat string - watcher *fswatch.Watcher - refs atomic.Int32 + abstractRuleSet + watcher *fswatch.Watcher } -func NewLocalRuleSet(ctx context.Context, router adapter.Router, logger logger.Logger, options option.RuleSet) (*LocalRuleSet, error) { +func NewLocalRuleSet(ctx context.Context, router adapter.Router, logger logger.ContextLogger, options option.RuleSet) (*LocalRuleSet, error) { ruleSet := &LocalRuleSet{ - router: router, - logger: logger, - tag: options.Tag, - fileFormat: options.Format, + abstractRuleSet: abstractRuleSet{ + router: router, + logger: logger, + tag: options.Tag, + }, } if options.Type == C.RuleSetTypeInline { if len(options.InlineOptions.Rules) == 0 { @@ -51,40 +38,36 @@ func NewLocalRuleSet(ctx context.Context, router adapter.Router, logger logger.L if err != nil { return nil, err } - } else { - err := ruleSet.reloadFile(filemanager.BasePath(ctx, options.LocalOptions.Path)) - if err != nil { - return nil, err - } + return ruleSet, nil } - if options.Type == C.RuleSetTypeLocal { - var watcher *fswatch.Watcher - filePath, _ := filepath.Abs(options.LocalOptions.Path) - watcher, err := fswatch.NewWatcher(fswatch.Options{ - Path: []string{filePath}, - Callback: func(path string) { - uErr := ruleSet.reloadFile(path) - if uErr != nil { - logger.Error(E.Cause(uErr, "reload rule-set ", options.Tag)) - } - }, - }) - if err != nil { - return nil, err - } - ruleSet.watcher = watcher + ruleSet.path = options.Path + ruleSet.format = options.Format + path, err := ruleSet.getPath(options.Path) + if err != nil { + return nil, err } + err = ruleSet.loadFromFile(path) + if err != nil { + return nil, err + } + var watcher *fswatch.Watcher + filePath, _ := filepath.Abs(path) + watcher, err = fswatch.NewWatcher(fswatch.Options{ + Path: []string{filePath}, + Callback: func(path string) { + uErr := ruleSet.loadFromFile(path) + if uErr != nil { + logger.ErrorContext(log.ContextWithNewID(context.Background()), E.Cause(uErr, "reload rule-set ", options.Tag)) + } + }, + }) + if err != nil { + return nil, err + } + ruleSet.watcher = watcher return ruleSet, nil } -func (s *LocalRuleSet) Name() string { - return s.tag -} - -func (s *LocalRuleSet) String() string { - return strings.Join(F.MapToString(s.rules), " ") -} - func (s *LocalRuleSet) StartContext(ctx context.Context, startContext adapter.RuleSetStartContext) error { if s.watcher != nil { err := s.watcher.Start() @@ -95,83 +78,10 @@ func (s *LocalRuleSet) StartContext(ctx context.Context, startContext adapter.Ru return nil } -func (s *LocalRuleSet) reloadFile(path string) error { - var plainRuleSet option.PlainRuleSet - switch s.fileFormat { - case C.RuleSetFormatSource, "": - content, err := os.ReadFile(path) - if err != nil { - return err - } - compat, err := json.UnmarshalExtended[option.PlainRuleSetCompat](content) - if err != nil { - return err - } - plainRuleSet, err = compat.Upgrade() - if err != nil { - return err - } - case C.RuleSetFormatBinary: - setFile, err := os.Open(path) - if err != nil { - return err - } - plainRuleSet, err = srs.Read(setFile, false) - if err != nil { - return err - } - default: - return E.New("unknown rule-set format: ", s.fileFormat) - } - return s.reloadRules(plainRuleSet.Rules) -} - -func (s *LocalRuleSet) reloadRules(headlessRules []option.HeadlessRule) error { - rules := make([]adapter.HeadlessRule, len(headlessRules)) - var err error - for i, ruleOptions := range headlessRules { - rules[i], err = NewHeadlessRule(s.router, ruleOptions) - if err != nil { - return E.Cause(err, "parse rule_set.rules.[", i, "]") - } - } - var metadata adapter.RuleSetMetadata - metadata.ContainsProcessRule = hasHeadlessRule(headlessRules, isProcessHeadlessRule) - metadata.ContainsWIFIRule = hasHeadlessRule(headlessRules, isWIFIHeadlessRule) - metadata.ContainsIPCIDRRule = hasHeadlessRule(headlessRules, isIPCIDRHeadlessRule) - s.rules = rules - s.metadata = metadata - return nil -} - func (s *LocalRuleSet) PostStart() error { return nil } -func (s *LocalRuleSet) Metadata() adapter.RuleSetMetadata { - return s.metadata -} - -func (s *LocalRuleSet) ExtractIPSet() []*netipx.IPSet { - return common.FlatMap(s.rules, extractIPSetFromRule) -} - -func (s *LocalRuleSet) IncRef() { - s.refs.Add(1) -} - -func (s *LocalRuleSet) DecRef() { - if s.refs.Add(-1) < 0 { - panic("rule-set: negative refs") - } -} - -func (s *LocalRuleSet) Cleanup() { - if s.refs.Load() == 0 { - s.rules = nil - } -} - func (s *LocalRuleSet) RegisterCallback(callback adapter.RuleSetUpdateCallback) *list.Element[adapter.RuleSetUpdateCallback] { return nil } @@ -183,12 +93,3 @@ func (s *LocalRuleSet) Close() error { s.rules = nil return common.Close(common.PtrOrNil(s.watcher)) } - -func (s *LocalRuleSet) Match(metadata *adapter.InboundContext) bool { - for _, rule := range s.rules { - if rule.Match(metadata) { - return true - } - } - return false -} diff --git a/route/rule_set_remote.go b/route/rule_set_remote.go index 03662ee4..11e475f8 100644 --- a/route/rule_set_remote.go +++ b/route/rule_set_remote.go @@ -1,54 +1,45 @@ package route import ( - "bytes" "context" "io" "net" "net/http" + "os" "runtime" "strings" "sync" "time" "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/common/srs" C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/atomic" E "github.com/sagernet/sing/common/exceptions" F "github.com/sagernet/sing/common/format" - "github.com/sagernet/sing/common/json" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/x/list" "github.com/sagernet/sing/service" "github.com/sagernet/sing/service/pause" - - "go4.org/netipx" ) var _ adapter.RuleSet = (*RemoteRuleSet)(nil) type RemoteRuleSet struct { + abstractRuleSet ctx context.Context cancel context.CancelFunc - router adapter.Router - logger logger.ContextLogger - options option.RuleSet - metadata adapter.RuleSetMetadata + path string + options option.RemoteRuleSet updateInterval time.Duration dialer N.Dialer - rules []adapter.HeadlessRule - lastUpdated time.Time lastEtag string updateTicker *time.Ticker pauseManager pause.Manager callbackAccess sync.Mutex callbacks list.List[adapter.RuleSetUpdateCallback] - refs atomic.Int32 } func NewRemoteRuleSet(ctx context.Context, router adapter.Router, logger logger.ContextLogger, options option.RuleSet) *RemoteRuleSet { @@ -60,30 +51,31 @@ func NewRemoteRuleSet(ctx context.Context, router adapter.Router, logger logger. updateInterval = 24 * time.Hour } return &RemoteRuleSet{ + abstractRuleSet: abstractRuleSet{ + router: router, + logger: logger, + tag: options.Tag, + format: options.Format, + }, ctx: ctx, cancel: cancel, - router: router, - logger: logger, - options: options, + path: options.Path, + options: options.RemoteOptions, updateInterval: updateInterval, pauseManager: service.FromContext[pause.Manager](ctx), } } -func (s *RemoteRuleSet) Name() string { - return s.options.Tag -} - func (s *RemoteRuleSet) String() string { return strings.Join(F.MapToString(s.rules), " ") } func (s *RemoteRuleSet) StartContext(ctx context.Context, startContext adapter.RuleSetStartContext) error { var dialer N.Dialer - if s.options.RemoteOptions.DownloadDetour != "" { - outbound, loaded := s.router.Outbound(s.options.RemoteOptions.DownloadDetour) + if s.options.DownloadDetour != "" { + outbound, loaded := s.router.Outbound(s.options.DownloadDetour) if !loaded { - return E.New("download_detour not found: ", s.options.RemoteOptions.DownloadDetour) + return E.New("download_detour not found: ", s.options.DownloadDetour) } dialer = outbound } else { @@ -94,21 +86,14 @@ func (s *RemoteRuleSet) StartContext(ctx context.Context, startContext adapter.R dialer = outbound } s.dialer = dialer - cacheFile := service.FromContext[adapter.CacheFile](s.ctx) - if cacheFile != nil { - if savedSet := cacheFile.LoadRuleSet(s.options.Tag); savedSet != nil { - err := s.loadBytes(savedSet.Content) - if err != nil { - return E.Cause(err, "restore cached rule-set") - } - s.lastUpdated = savedSet.LastUpdated - s.lastEtag = savedSet.LastEtag - } + if path, err := s.getPath(s.path); err == nil { + s.path = path + s.loadFromFile(path) } if s.lastUpdated.IsZero() { err := s.fetchOnce(ctx, startContext) if err != nil { - return E.Cause(err, "initial rule-set: ", s.options.Tag) + return E.Cause(err, "initial rule-set: ", s.tag) } } s.updateTicker = time.NewTicker(s.updateInterval) @@ -120,30 +105,6 @@ func (s *RemoteRuleSet) PostStart() error { return nil } -func (s *RemoteRuleSet) Metadata() adapter.RuleSetMetadata { - return s.metadata -} - -func (s *RemoteRuleSet) ExtractIPSet() []*netipx.IPSet { - return common.FlatMap(s.rules, extractIPSetFromRule) -} - -func (s *RemoteRuleSet) IncRef() { - s.refs.Add(1) -} - -func (s *RemoteRuleSet) DecRef() { - if s.refs.Add(-1) < 0 { - panic("rule-set: negative refs") - } -} - -func (s *RemoteRuleSet) Cleanup() { - if s.refs.Load() == 0 { - s.rules = nil - } -} - func (s *RemoteRuleSet) RegisterCallback(callback adapter.RuleSetUpdateCallback) *list.Element[adapter.RuleSetUpdateCallback] { s.callbackAccess.Lock() defer s.callbackAccess.Unlock() @@ -157,40 +118,10 @@ func (s *RemoteRuleSet) UnregisterCallback(element *list.Element[adapter.RuleSet } func (s *RemoteRuleSet) loadBytes(content []byte) error { - var ( - plainRuleSet option.PlainRuleSet - err error - ) - switch s.options.Format { - case C.RuleSetFormatSource: - var compat option.PlainRuleSetCompat - compat, err = json.UnmarshalExtended[option.PlainRuleSetCompat](content) - if err != nil { - return err - } - plainRuleSet, err = compat.Upgrade() - if err != nil { - return err - } - case C.RuleSetFormatBinary: - plainRuleSet, err = srs.Read(bytes.NewReader(content), false) - if err != nil { - return err - } - default: - return E.New("unknown rule-set format: ", s.options.Format) + err := s.abstractRuleSet.loadBytes(content) + if err != nil { + return err } - rules := make([]adapter.HeadlessRule, len(plainRuleSet.Rules)) - for i, ruleOptions := range plainRuleSet.Rules { - rules[i], err = NewHeadlessRule(s.router, ruleOptions) - if err != nil { - return E.Cause(err, "parse rule_set.rules.[", i, "]") - } - } - s.metadata.ContainsProcessRule = hasHeadlessRule(plainRuleSet.Rules, isProcessHeadlessRule) - s.metadata.ContainsWIFIRule = hasHeadlessRule(plainRuleSet.Rules, isWIFIHeadlessRule) - s.metadata.ContainsIPCIDRRule = hasHeadlessRule(plainRuleSet.Rules, isIPCIDRHeadlessRule) - s.rules = rules s.callbackAccess.Lock() callbacks := s.callbacks.Array() s.callbackAccess.Unlock() @@ -202,12 +133,7 @@ func (s *RemoteRuleSet) loadBytes(content []byte) error { func (s *RemoteRuleSet) loopUpdate() { if time.Since(s.lastUpdated) > s.updateInterval { - err := s.fetchOnce(s.ctx, nil) - if err != nil { - s.logger.Error("fetch rule-set ", s.options.Tag, ": ", err) - } else if s.refs.Load() == 0 { - s.rules = nil - } + s.update() } for { runtime.GC() @@ -216,21 +142,26 @@ func (s *RemoteRuleSet) loopUpdate() { return case <-s.updateTicker.C: s.pauseManager.WaitActive() - err := s.fetchOnce(s.ctx, nil) - if err != nil { - s.logger.Error("fetch rule-set ", s.options.Tag, ": ", err) - } else if s.refs.Load() == 0 { - s.rules = nil - } + s.update() } } } +func (s *RemoteRuleSet) update() { + ctx := log.ContextWithNewID(s.ctx) + err := s.fetchOnce(ctx, nil) + if err != nil { + s.logger.ErrorContext(ctx, "fetch rule-set ", s.tag, ": ", err) + } else if s.refs.Load() == 0 { + s.rules = nil + } +} + func (s *RemoteRuleSet) fetchOnce(ctx context.Context, startContext adapter.RuleSetStartContext) error { - s.logger.Debug("updating rule-set ", s.options.Tag, " from URL: ", s.options.RemoteOptions.URL) + s.logger.DebugContext(ctx, "updating rule-set ", s.tag, " from URL: ", s.options.URL) var httpClient *http.Client if startContext != nil { - httpClient = startContext.HTTPClient(s.options.RemoteOptions.DownloadDetour, s.dialer) + httpClient = startContext.HTTPClient(s.options.DownloadDetour, s.dialer) } else { httpClient = &http.Client{ Transport: &http.Transport{ @@ -242,7 +173,7 @@ func (s *RemoteRuleSet) fetchOnce(ctx context.Context, startContext adapter.Rule }, } } - request, err := http.NewRequest("GET", s.options.RemoteOptions.URL, nil) + request, err := http.NewRequest("GET", s.options.URL, nil) if err != nil { return err } @@ -257,19 +188,8 @@ func (s *RemoteRuleSet) fetchOnce(ctx context.Context, startContext adapter.Rule case http.StatusOK: case http.StatusNotModified: s.lastUpdated = time.Now() - cacheFile := service.FromContext[adapter.CacheFile](s.ctx) - if cacheFile != nil { - savedRuleSet := cacheFile.LoadRuleSet(s.options.Tag) - if savedRuleSet != nil { - savedRuleSet.LastUpdated = s.lastUpdated - err = cacheFile.SaveRuleSet(s.options.Tag, savedRuleSet) - if err != nil { - s.logger.Error("save rule-set updated time: ", err) - return nil - } - } - } - s.logger.Info("update rule-set ", s.options.Tag, ": not modified") + os.Chtimes(s.path, s.lastUpdated, s.lastUpdated) + s.logger.InfoContext(ctx, "update rule-set ", s.tag, ": not modified") return nil default: return E.New("unexpected status: ", response.Status) @@ -290,18 +210,8 @@ func (s *RemoteRuleSet) fetchOnce(ctx context.Context, startContext adapter.Rule s.lastEtag = eTagHeader } s.lastUpdated = time.Now() - cacheFile := service.FromContext[adapter.CacheFile](s.ctx) - if cacheFile != nil { - err = cacheFile.SaveRuleSet(s.options.Tag, &adapter.SavedRuleSet{ - LastUpdated: s.lastUpdated, - Content: content, - LastEtag: s.lastEtag, - }) - if err != nil { - s.logger.Error("save rule-set cache: ", err) - } - } - s.logger.Info("updated rule-set ", s.options.Tag) + os.WriteFile(s.path, content, 0o666) + s.logger.InfoContext(ctx, "updated rule-set ", s.tag) return nil } @@ -311,12 +221,3 @@ func (s *RemoteRuleSet) Close() error { s.cancel() return nil } - -func (s *RemoteRuleSet) Match(metadata *adapter.InboundContext) bool { - for _, rule := range s.rules { - if rule.Match(metadata) { - return true - } - } - return false -}