From e7ae3ddf3170821b03dbd579a469503ee009a629 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 10 Jun 2025 21:14:56 +0800 Subject: [PATCH] Add cache support for ssm-api --- docs/configuration/service/ssm-api.md | 8 +- option/ssmapi.go | 3 +- service/derp/service.go | 2 +- service/ssmapi/cache.go | 222 ++++++++++++++++++++++++++ service/ssmapi/server.go | 18 ++- service/ssmapi/user.go | 12 +- 6 files changed, 256 insertions(+), 9 deletions(-) create mode 100644 service/ssmapi/cache.go diff --git a/docs/configuration/service/ssm-api.md b/docs/configuration/service/ssm-api.md index 854ec687..1ef9f373 100644 --- a/docs/configuration/service/ssm-api.md +++ b/docs/configuration/service/ssm-api.md @@ -19,6 +19,7 @@ See https://github.com/Shadowsocks-NET/shadowsocks-specs/blob/main/2023-1-shadow ... // Listen Fields "servers": {}, + "cache_path": "", "tls": {} } ``` @@ -37,7 +38,7 @@ A mapping Object from HTTP endpoints to [Shadowsocks Inbound](/configuration/inb Selected Shadowsocks inbounds must be configured with [managed](/configuration/inbound/shadowsocks#managed) enabled. -Example: +Example: ```json { @@ -47,6 +48,11 @@ Example: } ``` +#### cache_path + +If set, when the server is about to stop, traffic and user state will be saved to the specified JSON file +to be restored on the next startup. + #### tls TLS configuration, see [TLS](/configuration/shared/tls/#inbound). diff --git a/option/ssmapi.go b/option/ssmapi.go index 2fbdc1bc..8d25f400 100644 --- a/option/ssmapi.go +++ b/option/ssmapi.go @@ -6,6 +6,7 @@ import ( type SSMAPIServiceOptions struct { ListenOptions - Servers *badjson.TypedMap[string, string] `json:"servers"` + Servers *badjson.TypedMap[string, string] `json:"servers"` + CachePath string `json:"cache_path,omitempty"` InboundTLSOptionsContainer } diff --git a/service/derp/service.go b/service/derp/service.go index f42e7e29..861bb235 100644 --- a/service/derp/service.go +++ b/service/derp/service.go @@ -134,7 +134,7 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio func (d *Service) Start(stage adapter.StartStage) error { switch stage { case adapter.StartStateStart: - config, err := readDERPConfig(d.configPath) + config, err := readDERPConfig(filemanager.BasePath(d.ctx, d.configPath)) if err != nil { return err } diff --git a/service/ssmapi/cache.go b/service/ssmapi/cache.go new file mode 100644 index 00000000..4c82c9d0 --- /dev/null +++ b/service/ssmapi/cache.go @@ -0,0 +1,222 @@ +package ssmapi + +import ( + "bytes" + "os" + "path/filepath" + "sort" + + "github.com/sagernet/sing/common/atomic" + "github.com/sagernet/sing/common/json" + "github.com/sagernet/sing/common/json/badjson" + "github.com/sagernet/sing/service/filemanager" +) + +type Cache struct { + Endpoints *badjson.TypedMap[string, *EndpointCache] `json:"endpoints"` +} + +type EndpointCache struct { + GlobalUplink int64 `json:"global_uplink"` + GlobalDownlink int64 `json:"global_downlink"` + GlobalUplinkPackets int64 `json:"global_uplink_packets"` + GlobalDownlinkPackets int64 `json:"global_downlink_packets"` + GlobalTCPSessions int64 `json:"global_tcp_sessions"` + GlobalUDPSessions int64 `json:"global_udp_sessions"` + UserUplink *badjson.TypedMap[string, int64] `json:"user_uplink"` + UserDownlink *badjson.TypedMap[string, int64] `json:"user_downlink"` + UserUplinkPackets *badjson.TypedMap[string, int64] `json:"user_uplink_packets"` + UserDownlinkPackets *badjson.TypedMap[string, int64] `json:"user_downlink_packets"` + UserTCPSessions *badjson.TypedMap[string, int64] `json:"user_tcp_sessions"` + UserUDPSessions *badjson.TypedMap[string, int64] `json:"user_udp_sessions"` + Users *badjson.TypedMap[string, string] `json:"users"` +} + +func (s *Service) loadCache() error { + if s.cachePath == "" { + return nil + } + basePath := filemanager.BasePath(s.ctx, s.cachePath) + cacheBinary, err := os.ReadFile(basePath) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + err = s.decodeCache(cacheBinary) + if err != nil { + os.RemoveAll(basePath) + return err + } + return nil +} + +func (s *Service) saveCache() error { + if s.cachePath == "" { + return nil + } + basePath := filemanager.BasePath(s.ctx, s.cachePath) + err := os.MkdirAll(filepath.Dir(basePath), 0o777) + if err != nil { + return err + } + cacheBinary, err := s.encodeCache() + if err != nil { + return err + } + return os.WriteFile(s.cachePath, cacheBinary, 0o644) +} + +func (s *Service) decodeCache(cacheBinary []byte) error { + if len(cacheBinary) == 0 { + return nil + } + cache, err := json.UnmarshalExtended[*Cache](cacheBinary) + if err != nil { + return err + } + if cache.Endpoints == nil || cache.Endpoints.Size() == 0 { + return nil + } + for _, entry := range cache.Endpoints.Entries() { + trafficManager, loaded := s.traffics[entry.Key] + if !loaded { + continue + } + trafficManager.globalUplink.Store(entry.Value.GlobalUplink) + trafficManager.globalDownlink.Store(entry.Value.GlobalDownlink) + trafficManager.globalUplinkPackets.Store(entry.Value.GlobalUplinkPackets) + trafficManager.globalDownlinkPackets.Store(entry.Value.GlobalDownlinkPackets) + trafficManager.globalTCPSessions.Store(entry.Value.GlobalTCPSessions) + trafficManager.globalUDPSessions.Store(entry.Value.GlobalUDPSessions) + trafficManager.userUplink = typedAtomicInt64Map(entry.Value.UserUplink) + trafficManager.userDownlink = typedAtomicInt64Map(entry.Value.UserDownlink) + trafficManager.userUplinkPackets = typedAtomicInt64Map(entry.Value.UserUplinkPackets) + trafficManager.userDownlinkPackets = typedAtomicInt64Map(entry.Value.UserDownlinkPackets) + trafficManager.userTCPSessions = typedAtomicInt64Map(entry.Value.UserTCPSessions) + trafficManager.userUDPSessions = typedAtomicInt64Map(entry.Value.UserUDPSessions) + userManager, loaded := s.users[entry.Key] + if !loaded { + continue + } + userManager.usersMap = typedMap(entry.Value.Users) + _ = userManager.postUpdate(false) + } + return nil +} + +func (s *Service) encodeCache() ([]byte, error) { + endpoints := new(badjson.TypedMap[string, *EndpointCache]) + for tag, traffic := range s.traffics { + var ( + userUplink = new(badjson.TypedMap[string, int64]) + userDownlink = new(badjson.TypedMap[string, int64]) + userUplinkPackets = new(badjson.TypedMap[string, int64]) + userDownlinkPackets = new(badjson.TypedMap[string, int64]) + userTCPSessions = new(badjson.TypedMap[string, int64]) + userUDPSessions = new(badjson.TypedMap[string, int64]) + userMap = new(badjson.TypedMap[string, string]) + ) + for user, uplink := range traffic.userUplink { + if uplink.Load() > 0 { + userUplink.Put(user, uplink.Load()) + } + } + for user, downlink := range traffic.userDownlink { + if downlink.Load() > 0 { + userDownlink.Put(user, downlink.Load()) + } + } + for user, uplinkPackets := range traffic.userUplinkPackets { + if uplinkPackets.Load() > 0 { + userUplinkPackets.Put(user, uplinkPackets.Load()) + } + } + for user, downlinkPackets := range traffic.userDownlinkPackets { + if downlinkPackets.Load() > 0 { + userDownlinkPackets.Put(user, downlinkPackets.Load()) + } + } + for user, tcpSessions := range traffic.userTCPSessions { + if tcpSessions.Load() > 0 { + userTCPSessions.Put(user, tcpSessions.Load()) + } + } + for user, udpSessions := range traffic.userUDPSessions { + if udpSessions.Load() > 0 { + userUDPSessions.Put(user, udpSessions.Load()) + } + } + userManager := s.users[tag] + if userManager != nil && len(userManager.usersMap) > 0 { + userMap = new(badjson.TypedMap[string, string]) + for username, password := range userManager.usersMap { + if username != "" && password != "" { + userMap.Put(username, password) + } + } + } + endpoints.Put(tag, &EndpointCache{ + GlobalUplink: traffic.globalUplink.Load(), + GlobalDownlink: traffic.globalDownlink.Load(), + GlobalUplinkPackets: traffic.globalUplinkPackets.Load(), + GlobalDownlinkPackets: traffic.globalDownlinkPackets.Load(), + GlobalTCPSessions: traffic.globalTCPSessions.Load(), + GlobalUDPSessions: traffic.globalUDPSessions.Load(), + UserUplink: sortTypedMap(userUplink), + UserDownlink: sortTypedMap(userDownlink), + UserUplinkPackets: sortTypedMap(userUplinkPackets), + UserDownlinkPackets: sortTypedMap(userDownlinkPackets), + UserTCPSessions: sortTypedMap(userTCPSessions), + UserUDPSessions: sortTypedMap(userUDPSessions), + Users: sortTypedMap(userMap), + }) + } + var buffer bytes.Buffer + encoder := json.NewEncoder(&buffer) + encoder.SetIndent("", " ") + err := encoder.Encode(&Cache{ + Endpoints: sortTypedMap(endpoints), + }) + if err != nil { + return nil, err + } + return buffer.Bytes(), nil +} + +func sortTypedMap[T comparable](trafficMap *badjson.TypedMap[string, T]) *badjson.TypedMap[string, T] { + if trafficMap == nil { + return nil + } + keys := trafficMap.Keys() + sort.Strings(keys) + sortedMap := new(badjson.TypedMap[string, T]) + for _, key := range keys { + value, _ := trafficMap.Get(key) + sortedMap.Put(key, value) + } + return sortedMap +} + +func typedAtomicInt64Map(trafficMap *badjson.TypedMap[string, int64]) map[string]*atomic.Int64 { + result := make(map[string]*atomic.Int64) + if trafficMap != nil { + for _, entry := range trafficMap.Entries() { + counter := new(atomic.Int64) + counter.Store(entry.Value) + result[entry.Key] = counter + } + } + return result +} + +func typedMap[T comparable](trafficMap *badjson.TypedMap[string, T]) map[string]T { + result := make(map[string]T) + if trafficMap != nil { + for _, entry := range trafficMap.Entries() { + result[entry.Key] = entry.Value + } + } + return result +} diff --git a/service/ssmapi/server.go b/service/ssmapi/server.go index 92d7354f..f9b382af 100644 --- a/service/ssmapi/server.go +++ b/service/ssmapi/server.go @@ -33,6 +33,9 @@ type Service struct { listener *listener.Listener tlsConfig tls.ServerConfig httpServer *http.Server + traffics map[string]*TrafficManager + users map[string]*UserManager + cachePath string } func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.SSMAPIServiceOptions) (adapter.Service, error) { @@ -50,6 +53,9 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio httpServer: &http.Server{ Handler: chiRouter, }, + traffics: make(map[string]*TrafficManager), + users: make(map[string]*UserManager), + cachePath: options.CachePath, } inboundManager := service.FromContext[adapter.InboundManager](ctx) if options.Servers.Size() == 0 { @@ -68,6 +74,8 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio managedServer.SetTracker(traffic) user := NewUserManager(managedServer, traffic) chiRouter.Route(entry.Key, NewAPIServer(logger, traffic, user).Route) + s.traffics[entry.Key] = traffic + s.users[entry.Key] = user } if options.TLS != nil { tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS)) @@ -83,8 +91,12 @@ func (s *Service) Start(stage adapter.StartStage) error { if stage != adapter.StartStateStart { return nil } + err := s.loadCache() + if err != nil { + s.logger.Error(E.Cause(err, "load cache")) + } if s.tlsConfig != nil { - err := s.tlsConfig.Start() + err = s.tlsConfig.Start() if err != nil { return E.Cause(err, "create TLS config") } @@ -109,6 +121,10 @@ func (s *Service) Start(stage adapter.StartStage) error { } func (s *Service) Close() error { + err := s.saveCache() + if err != nil { + s.logger.Error(E.Cause(err, "save cache")) + } return common.Close( common.PtrOrNil(s.httpServer), common.PtrOrNil(s.listener), diff --git a/service/ssmapi/user.go b/service/ssmapi/user.go index a8eb27fb..26bc621a 100644 --- a/service/ssmapi/user.go +++ b/service/ssmapi/user.go @@ -22,7 +22,7 @@ func NewUserManager(inbound adapter.ManagedSSMServer, trafficManager *TrafficMan } } -func (m *UserManager) postUpdate() error { +func (m *UserManager) postUpdate(updated bool) error { users := make([]string, 0, len(m.usersMap)) uPSKs := make([]string, 0, len(m.usersMap)) for username, password := range m.usersMap { @@ -33,7 +33,9 @@ func (m *UserManager) postUpdate() error { if err != nil { return err } - m.trafficManager.UpdateUsers(users) + if updated { + m.trafficManager.UpdateUsers(users) + } return nil } @@ -58,7 +60,7 @@ func (m *UserManager) Add(username string, password string) error { return E.New("user ", username, " already exists") } m.usersMap[username] = password - return m.postUpdate() + return m.postUpdate(true) } func (m *UserManager) Get(username string) (string, bool) { @@ -74,12 +76,12 @@ func (m *UserManager) Update(username string, password string) error { m.access.Lock() defer m.access.Unlock() m.usersMap[username] = password - return m.postUpdate() + return m.postUpdate(true) } func (m *UserManager) Delete(username string) error { m.access.Lock() defer m.access.Unlock() delete(m.usersMap, username) - return m.postUpdate() + return m.postUpdate(true) }