Add SSM API service

This commit is contained in:
世界 2025-04-14 15:41:20 +08:00
parent bdc008c8ce
commit 4f425e13aa
No known key found for this signature in database
GPG Key ID: CD109927C34A63C4
13 changed files with 726 additions and 9 deletions

18
adapter/ssm.go Normal file
View File

@ -0,0 +1,18 @@
package adapter
import (
"net"
N "github.com/sagernet/sing/common/network"
)
type ManagedSSMServer interface {
Inbound
SetTracker(tracker SSMTracker)
UpdateUsers(users []string, uPSKs []string) error
}
type SSMTracker interface {
TrackConnection(conn net.Conn, metadata InboundContext) net.Conn
TrackPacketConnection(conn N.PacketConn, metadata InboundContext) N.PacketConn
}

View File

@ -27,6 +27,7 @@ const (
TypeTailscale = "tailscale"
TypeDERP = "derp"
TypeResolved = "resolved"
TypeSSMAPI = "ssm-api"
)
const (

View File

@ -0,0 +1,52 @@
---
icon: material/new-box
---
!!! question "Since sing-box 1.12.0"
# SSM API
SSM API service is a RESTful API server for managing Shadowsocks servers.
See https://github.com/Shadowsocks-NET/shadowsocks-specs/blob/main/2023-1-shadowsocks-server-management-api-v1.md
### Structure
```json
{
"type": "ssm-api",
... // Listen Fields
"servers": {},
"tls": {}
}
```
### Listen Fields
See [Listen Fields](/configuration/shared/listen/) for details.
### Fields
#### servers
==Required==
A mapping Object from HTTP endpoints to [Shadowsocks Inbound](/configuration/inbound/shadowsocks) tags.
Selected Shadowsocks inbounds must be configured with [managed](/configuration/inbound/shadowsocks#managed) enabled.
Example:
```json
{
"servers": {
"/": "ss-in"
}
}
```
#### tls
TLS configuration, see [TLS](/configuration/shared/tls/#inbound).

View File

@ -35,6 +35,7 @@ import (
"github.com/sagernet/sing-box/protocol/vless"
"github.com/sagernet/sing-box/protocol/vmess"
"github.com/sagernet/sing-box/service/resolved"
"github.com/sagernet/sing-box/service/ssmapi"
E "github.com/sagernet/sing/common/exceptions"
)
@ -125,6 +126,7 @@ func ServiceRegistry() *service.Registry {
registry := service.NewRegistry()
resolved.RegisterService(registry)
ssmapi.RegisterService(registry)
registerDERPService(registry)

View File

@ -174,6 +174,7 @@ nav:
- configuration/service/index.md
- DERP: configuration/service/derp.md
- Resolved: configuration/service/resolved.md
- SSM API: configuration/service/ssm-api.md
markdown_extensions:
- pymdownx.inlinehilite
- pymdownx.snippets

View File

@ -8,6 +8,7 @@ type ShadowsocksInboundOptions struct {
Users []ShadowsocksUser `json:"users,omitempty"`
Destinations []ShadowsocksDestination `json:"destinations,omitempty"`
Multiplex *InboundMultiplexOptions `json:"multiplex,omitempty"`
Managed bool `json:"managed,omitempty"`
}
type ShadowsocksUser struct {

11
option/ssmapi.go Normal file
View File

@ -0,0 +1,11 @@
package option
import (
"github.com/sagernet/sing/common/json/badjson"
)
type SSMAPIServiceOptions struct {
ListenOptions
Servers *badjson.TypedMap[string, string] `json:"servers"`
InboundTLSOptionsContainer
}

View File

@ -32,8 +32,10 @@ func RegisterInbound(registry *inbound.Registry) {
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowsocksInboundOptions) (adapter.Inbound, error) {
if len(options.Users) > 0 && len(options.Destinations) > 0 {
return nil, E.New("users and destinations options must not be combined")
} else if options.Managed && (len(options.Users) > 0 || len(options.Destinations) > 0) {
return nil, E.New("users and destinations options are not supported in managed servers")
}
if len(options.Users) > 0 {
if len(options.Users) > 0 || options.Managed {
return newMultiInbound(ctx, router, logger, tag, options)
} else if len(options.Destinations) > 0 {
return newRelayInbound(ctx, router, logger, tag, options)

View File

@ -28,7 +28,10 @@ import (
"github.com/sagernet/sing/common/ntp"
)
var _ adapter.TCPInjectableInbound = (*MultiInbound)(nil)
var (
_ adapter.TCPInjectableInbound = (*MultiInbound)(nil)
_ adapter.ManagedSSMServer = (*MultiInbound)(nil)
)
type MultiInbound struct {
inbound.Adapter
@ -38,6 +41,7 @@ type MultiInbound struct {
listener *listener.Listener
service shadowsocks.MultiService[int]
users []option.ShadowsocksUser
tracker adapter.SSMTracker
}
func newMultiInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ShadowsocksInboundOptions) (*MultiInbound, error) {
@ -79,13 +83,15 @@ func newMultiInbound(ctx context.Context, router adapter.Router, logger log.Cont
if err != nil {
return nil, err
}
err = service.UpdateUsersWithPasswords(common.MapIndexed(options.Users, func(index int, user option.ShadowsocksUser) int {
return index
}), common.Map(options.Users, func(user option.ShadowsocksUser) string {
return user.Password
}))
if err != nil {
return nil, err
if len(options.Users) > 0 {
err = service.UpdateUsersWithPasswords(common.MapIndexed(options.Users, func(index int, user option.ShadowsocksUser) int {
return index
}), common.Map(options.Users, func(user option.ShadowsocksUser) string {
return user.Password
}))
if err != nil {
return nil, err
}
}
inbound.service = service
inbound.users = options.Users
@ -112,6 +118,25 @@ func (h *MultiInbound) Close() error {
return h.listener.Close()
}
func (h *MultiInbound) SetTracker(tracker adapter.SSMTracker) {
h.tracker = tracker
}
func (h *MultiInbound) UpdateUsers(users []string, uPSKs []string) error {
err := h.service.UpdateUsersWithPasswords(common.MapIndexed(users, func(index int, user string) int {
return index
}), uPSKs)
if err != nil {
return err
}
h.users = common.Map(users, func(user string) option.ShadowsocksUser {
return option.ShadowsocksUser{
Name: user,
}
})
return nil
}
//nolint:staticcheck
func (h *MultiInbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
err := h.service.NewConnection(ctx, conn, adapter.UpstreamMetadata(metadata))
@ -151,6 +176,9 @@ func (h *MultiInbound) newConnection(ctx context.Context, conn net.Conn, metadat
metadata.InboundDetour = h.listener.ListenOptions().Detour
//nolint:staticcheck
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
if h.tracker != nil {
conn = h.tracker.TrackConnection(conn, metadata)
}
return h.router.RouteConnection(ctx, conn, metadata)
}
@ -174,6 +202,9 @@ func (h *MultiInbound) newPacketConnection(ctx context.Context, conn N.PacketCon
metadata.InboundDetour = h.listener.ListenOptions().Detour
//nolint:staticcheck
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
if h.tracker != nil {
conn = h.tracker.TrackPacketConnection(conn, metadata)
}
return h.router.RoutePacketConnection(ctx, conn, metadata)
}

181
service/ssmapi/api.go Normal file
View File

@ -0,0 +1,181 @@
package ssmapi
import (
"net/http"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common/logger"
sHTTP "github.com/sagernet/sing/protocol/http"
"github.com/go-chi/chi/v5"
"github.com/go-chi/render"
)
type APIServer struct {
logger logger.Logger
traffic *TrafficManager
user *UserManager
}
func NewAPIServer(logger logger.Logger, traffic *TrafficManager, user *UserManager) *APIServer {
return &APIServer{
logger: logger,
traffic: traffic,
user: user,
}
}
func (s *APIServer) Route(r chi.Router) {
r.Route("/server/v1", func(r chi.Router) {
r.Use(func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
s.logger.Debug(request.Method, " ", request.RequestURI, " ", sHTTP.SourceAddress(request))
handler.ServeHTTP(writer, request)
})
})
r.Get("/", s.getServerInfo)
r.Get("/users", s.listUser)
r.Post("/users", s.addUser)
r.Get("/users/{username}", s.getUser)
r.Put("/users/{username}", s.updateUser)
r.Delete("/users/{username}", s.deleteUser)
r.Get("/stats", s.getStats)
})
}
func (s *APIServer) getServerInfo(writer http.ResponseWriter, request *http.Request) {
render.JSON(writer, request, render.M{
"server": "sing-box " + C.Version,
"apiVersion": "v1",
})
}
type UserObject struct {
UserName string `json:"username"`
Password string `json:"uPSK,omitempty"`
DownlinkBytes int64 `json:"downlinkBytes"`
UplinkBytes int64 `json:"uplinkBytes"`
DownlinkPackets int64 `json:"downlinkPackets"`
UplinkPackets int64 `json:"uplinkPackets"`
TCPSessions int64 `json:"tcpSessions"`
UDPSessions int64 `json:"udpSessions"`
}
func (s *APIServer) listUser(writer http.ResponseWriter, request *http.Request) {
render.JSON(writer, request, render.M{
"users": s.user.List(),
})
}
func (s *APIServer) addUser(writer http.ResponseWriter, request *http.Request) {
var addRequest struct {
UserName string `json:"username"`
Password string `json:"uPSK"`
}
err := render.DecodeJSON(request.Body, &addRequest)
if err != nil {
render.Status(request, http.StatusBadRequest)
render.PlainText(writer, request, err.Error())
return
}
err = s.user.Add(addRequest.UserName, addRequest.Password)
if err != nil {
render.Status(request, http.StatusBadRequest)
render.PlainText(writer, request, err.Error())
return
}
writer.WriteHeader(http.StatusCreated)
}
func (s *APIServer) getUser(writer http.ResponseWriter, request *http.Request) {
userName := chi.URLParam(request, "username")
if userName == "" {
writer.WriteHeader(http.StatusBadRequest)
return
}
uPSK, loaded := s.user.Get(userName)
if !loaded {
writer.WriteHeader(http.StatusNotFound)
return
}
user := UserObject{
UserName: userName,
Password: uPSK,
}
s.traffic.ReadUser(&user)
render.JSON(writer, request, user)
}
func (s *APIServer) updateUser(writer http.ResponseWriter, request *http.Request) {
userName := chi.URLParam(request, "username")
if userName == "" {
writer.WriteHeader(http.StatusBadRequest)
return
}
var updateRequest struct {
Password string `json:"uPSK"`
}
err := render.DecodeJSON(request.Body, &updateRequest)
if err != nil {
render.Status(request, http.StatusBadRequest)
render.PlainText(writer, request, err.Error())
return
}
_, loaded := s.user.Get(userName)
if !loaded {
writer.WriteHeader(http.StatusNotFound)
return
}
err = s.user.Update(userName, updateRequest.Password)
if err != nil {
render.Status(request, http.StatusBadRequest)
render.PlainText(writer, request, err.Error())
return
}
writer.WriteHeader(http.StatusNoContent)
}
func (s *APIServer) deleteUser(writer http.ResponseWriter, request *http.Request) {
userName := chi.URLParam(request, "username")
if userName == "" {
writer.WriteHeader(http.StatusBadRequest)
return
}
_, loaded := s.user.Get(userName)
if !loaded {
writer.WriteHeader(http.StatusNotFound)
return
}
err := s.user.Delete(userName)
if err != nil {
render.Status(request, http.StatusBadRequest)
render.PlainText(writer, request, err.Error())
return
}
writer.WriteHeader(http.StatusNoContent)
}
func (s *APIServer) getStats(writer http.ResponseWriter, request *http.Request) {
requireClear := chi.URLParam(request, "clear") == "true"
users := s.user.List()
s.traffic.ReadUsers(users)
for i := range users {
users[i].Password = ""
}
uplinkBytes, downlinkBytes, uplinkPackets, downlinkPackets, tcpSessions, udpSessions := s.traffic.ReadGlobal()
if requireClear {
s.traffic.Clear()
}
render.JSON(writer, request, render.M{
"uplinkBytes": uplinkBytes,
"downlinkBytes": downlinkBytes,
"uplinkPackets": uplinkPackets,
"downlinkPackets": downlinkPackets,
"tcpSessions": tcpSessions,
"udpSessions": udpSessions,
"users": users,
})
}

117
service/ssmapi/server.go Normal file
View File

@ -0,0 +1,117 @@
package ssmapi
import (
"context"
"errors"
"net/http"
"github.com/sagernet/sing-box/adapter"
boxService "github.com/sagernet/sing-box/adapter/service"
"github.com/sagernet/sing-box/common/listener"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
aTLS "github.com/sagernet/sing/common/tls"
"github.com/sagernet/sing/service"
"github.com/go-chi/chi/v5"
"golang.org/x/net/http2"
)
func RegisterService(registry *boxService.Registry) {
boxService.Register[option.SSMAPIServiceOptions](registry, C.TypeSSMAPI, NewService)
}
type Service struct {
boxService.Adapter
ctx context.Context
logger log.ContextLogger
listener *listener.Listener
tlsConfig tls.ServerConfig
httpServer *http.Server
}
func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.SSMAPIServiceOptions) (adapter.Service, error) {
chiRouter := chi.NewRouter()
s := &Service{
Adapter: boxService.NewAdapter(C.TypeSSMAPI, tag),
ctx: ctx,
logger: logger,
listener: listener.New(listener.Options{
Context: ctx,
Logger: logger,
Network: []string{N.NetworkTCP},
Listen: options.ListenOptions,
}),
httpServer: &http.Server{
Handler: chiRouter,
},
}
inboundManager := service.FromContext[adapter.InboundManager](ctx)
if options.Servers.Size() == 0 {
return nil, E.New("missing servers")
}
for i, entry := range options.Servers.Entries() {
inbound, loaded := inboundManager.Get(entry.Value)
if !loaded {
return nil, E.New("parse SSM server[", i, "]: inbound ", entry.Value, "not found")
}
managedServer, isManaged := inbound.(adapter.ManagedSSMServer)
if !isManaged {
return nil, E.New("parse SSM server[", i, "]: inbound/", inbound.Type(), "[", inbound.Tag(), "] is not a SSM server")
}
traffic := NewTrafficManager()
managedServer.SetTracker(traffic)
user := NewUserManager(managedServer, traffic)
chiRouter.Route(entry.Key, NewAPIServer(logger, traffic, user).Route)
}
if options.TLS != nil {
tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
if err != nil {
return nil, err
}
s.tlsConfig = tlsConfig
}
return s, nil
}
func (s *Service) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if s.tlsConfig != nil {
err := s.tlsConfig.Start()
if err != nil {
return E.Cause(err, "create TLS config")
}
}
tcpListener, err := s.listener.ListenTCP()
if err != nil {
return err
}
if s.tlsConfig != nil {
if !common.Contains(s.tlsConfig.NextProtos(), http2.NextProtoTLS) {
s.tlsConfig.SetNextProtos(append([]string{"h2"}, s.tlsConfig.NextProtos()...))
}
tcpListener = aTLS.NewListener(tcpListener, s.tlsConfig)
}
go func() {
err = s.httpServer.Serve(tcpListener)
if err != nil && !errors.Is(err, http.ErrServerClosed) {
s.logger.Error("serve error: ", err)
}
}()
return nil
}
func (s *Service) Close() error {
return common.Close(
common.PtrOrNil(s.httpServer),
common.PtrOrNil(s.listener),
s.tlsConfig,
)
}

215
service/ssmapi/traffic.go Normal file
View File

@ -0,0 +1,215 @@
package ssmapi
import (
"net"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/bufio"
N "github.com/sagernet/sing/common/network"
)
var _ adapter.SSMTracker = (*TrafficManager)(nil)
type TrafficManager struct {
globalUplink atomic.Int64
globalDownlink atomic.Int64
globalUplinkPackets atomic.Int64
globalDownlinkPackets atomic.Int64
globalTCPSessions atomic.Int64
globalUDPSessions atomic.Int64
userAccess sync.Mutex
userUplink map[string]*atomic.Int64
userDownlink map[string]*atomic.Int64
userUplinkPackets map[string]*atomic.Int64
userDownlinkPackets map[string]*atomic.Int64
userTCPSessions map[string]*atomic.Int64
userUDPSessions map[string]*atomic.Int64
}
func NewTrafficManager() *TrafficManager {
manager := &TrafficManager{
userUplink: make(map[string]*atomic.Int64),
userDownlink: make(map[string]*atomic.Int64),
userUplinkPackets: make(map[string]*atomic.Int64),
userDownlinkPackets: make(map[string]*atomic.Int64),
userTCPSessions: make(map[string]*atomic.Int64),
userUDPSessions: make(map[string]*atomic.Int64),
}
return manager
}
func (s *TrafficManager) UpdateUsers(users []string) {
s.userAccess.Lock()
defer s.userAccess.Unlock()
newUserUplink := make(map[string]*atomic.Int64)
newUserDownlink := make(map[string]*atomic.Int64)
newUserUplinkPackets := make(map[string]*atomic.Int64)
newUserDownlinkPackets := make(map[string]*atomic.Int64)
newUserTCPSessions := make(map[string]*atomic.Int64)
newUserUDPSessions := make(map[string]*atomic.Int64)
for _, user := range users {
newUserUplink[user] = s.userUplinkPackets[user]
newUserDownlink[user] = s.userDownlinkPackets[user]
newUserUplinkPackets[user] = s.userUplinkPackets[user]
newUserDownlinkPackets[user] = s.userDownlinkPackets[user]
newUserTCPSessions[user] = s.userTCPSessions[user]
newUserUDPSessions[user] = s.userUDPSessions[user]
}
s.userUplink = newUserUplink
s.userDownlink = newUserDownlink
s.userUplinkPackets = newUserUplinkPackets
s.userDownlinkPackets = newUserDownlinkPackets
s.userTCPSessions = newUserTCPSessions
s.userUDPSessions = newUserUDPSessions
}
func (s *TrafficManager) userCounter(user string) (*atomic.Int64, *atomic.Int64, *atomic.Int64, *atomic.Int64, *atomic.Int64, *atomic.Int64) {
s.userAccess.Lock()
defer s.userAccess.Unlock()
upCounter, loaded := s.userUplink[user]
if !loaded {
upCounter = new(atomic.Int64)
s.userUplink[user] = upCounter
}
downCounter, loaded := s.userDownlink[user]
if !loaded {
downCounter = new(atomic.Int64)
s.userDownlink[user] = downCounter
}
upPacketsCounter, loaded := s.userUplinkPackets[user]
if !loaded {
upPacketsCounter = new(atomic.Int64)
s.userUplinkPackets[user] = upPacketsCounter
}
downPacketsCounter, loaded := s.userDownlinkPackets[user]
if !loaded {
downPacketsCounter = new(atomic.Int64)
s.userDownlinkPackets[user] = downPacketsCounter
}
tcpSessionsCounter, loaded := s.userTCPSessions[user]
if !loaded {
tcpSessionsCounter = new(atomic.Int64)
s.userTCPSessions[user] = tcpSessionsCounter
}
udpSessionsCounter, loaded := s.userUDPSessions[user]
if !loaded {
udpSessionsCounter = new(atomic.Int64)
s.userUDPSessions[user] = udpSessionsCounter
}
return upCounter, downCounter, upPacketsCounter, downPacketsCounter, tcpSessionsCounter, udpSessionsCounter
}
func (s *TrafficManager) TrackConnection(conn net.Conn, metadata adapter.InboundContext) net.Conn {
s.globalTCPSessions.Add(1)
var readCounter []*atomic.Int64
var writeCounter []*atomic.Int64
readCounter = append(readCounter, &s.globalUplink)
writeCounter = append(writeCounter, &s.globalDownlink)
upCounter, downCounter, _, _, tcpSessionCounter, _ := s.userCounter(metadata.User)
readCounter = append(readCounter, upCounter)
writeCounter = append(writeCounter, downCounter)
tcpSessionCounter.Add(1)
return bufio.NewInt64CounterConn(conn, readCounter, writeCounter)
}
func (s *TrafficManager) TrackPacketConnection(conn N.PacketConn, metadata adapter.InboundContext) N.PacketConn {
s.globalUDPSessions.Add(1)
var readCounter []*atomic.Int64
var readPacketCounter []*atomic.Int64
var writeCounter []*atomic.Int64
var writePacketCounter []*atomic.Int64
readCounter = append(readCounter, &s.globalUplink)
writeCounter = append(writeCounter, &s.globalDownlink)
readPacketCounter = append(readPacketCounter, &s.globalUplinkPackets)
writePacketCounter = append(writePacketCounter, &s.globalDownlinkPackets)
upCounter, downCounter, upPacketsCounter, downPacketsCounter, _, udpSessionCounter := s.userCounter(metadata.User)
readCounter = append(readCounter, upCounter)
writeCounter = append(writeCounter, downCounter)
readPacketCounter = append(readPacketCounter, upPacketsCounter)
writePacketCounter = append(writePacketCounter, downPacketsCounter)
udpSessionCounter.Add(1)
return bufio.NewInt64CounterPacketConn(conn, append(readCounter, readPacketCounter...), append(writeCounter, writePacketCounter...))
}
func (s *TrafficManager) ReadUser(user *UserObject) {
s.userAccess.Lock()
defer s.userAccess.Unlock()
s.readUser(user)
}
func (s *TrafficManager) readUser(user *UserObject) {
if counter, loaded := s.userUplink[user.UserName]; loaded {
user.UplinkBytes = counter.Load()
}
if counter, loaded := s.userDownlink[user.UserName]; loaded {
user.DownlinkBytes = counter.Load()
}
if counter, loaded := s.userUplinkPackets[user.UserName]; loaded {
user.UplinkPackets = counter.Load()
}
if counter, loaded := s.userDownlinkPackets[user.UserName]; loaded {
user.DownlinkPackets = counter.Load()
}
if counter, loaded := s.userTCPSessions[user.UserName]; loaded {
user.TCPSessions = counter.Load()
}
if counter, loaded := s.userUDPSessions[user.UserName]; loaded {
user.UDPSessions = counter.Load()
}
}
func (s *TrafficManager) ReadUsers(users []*UserObject) {
s.userAccess.Lock()
defer s.userAccess.Unlock()
for _, user := range users {
s.readUser(user)
}
return
}
func (s *TrafficManager) ReadGlobal() (
uplinkBytes int64,
downlinkBytes int64,
uplinkPackets int64,
downlinkPackets int64,
tcpSessions int64,
udpSessions int64,
) {
return s.globalUplink.Load(),
s.globalDownlink.Load(),
s.globalUplinkPackets.Load(),
s.globalDownlinkPackets.Load(),
s.globalTCPSessions.Load(),
s.globalUDPSessions.Load()
}
func (s *TrafficManager) Clear() {
s.globalUplink.Store(0)
s.globalDownlink.Store(0)
s.globalUplinkPackets.Store(0)
s.globalDownlinkPackets.Store(0)
s.globalTCPSessions.Store(0)
s.globalUDPSessions.Store(0)
s.userAccess.Lock()
defer s.userAccess.Unlock()
for _, counter := range s.userUplink {
counter.Store(0)
}
for _, counter := range s.userDownlink {
counter.Store(0)
}
for _, counter := range s.userUplinkPackets {
counter.Store(0)
}
for _, counter := range s.userDownlinkPackets {
counter.Store(0)
}
for _, counter := range s.userTCPSessions {
counter.Store(0)
}
for _, counter := range s.userUDPSessions {
counter.Store(0)
}
}

85
service/ssmapi/user.go Normal file
View File

@ -0,0 +1,85 @@
package ssmapi
import (
"sync"
"github.com/sagernet/sing-box/adapter"
E "github.com/sagernet/sing/common/exceptions"
)
type UserManager struct {
access sync.Mutex
usersMap map[string]string
server adapter.ManagedSSMServer
trafficManager *TrafficManager
}
func NewUserManager(inbound adapter.ManagedSSMServer, trafficManager *TrafficManager) *UserManager {
return &UserManager{
usersMap: make(map[string]string),
server: inbound,
trafficManager: trafficManager,
}
}
func (m *UserManager) postUpdate() error {
users := make([]string, 0, len(m.usersMap))
uPSKs := make([]string, 0, len(m.usersMap))
for username, password := range m.usersMap {
users = append(users, username)
uPSKs = append(uPSKs, password)
}
err := m.server.UpdateUsers(users, uPSKs)
if err != nil {
return err
}
m.trafficManager.UpdateUsers(users)
return nil
}
func (m *UserManager) List() []*UserObject {
m.access.Lock()
defer m.access.Unlock()
users := make([]*UserObject, 0, len(m.usersMap))
for username, password := range m.usersMap {
users = append(users, &UserObject{
UserName: username,
Password: password,
})
}
return users
}
func (m *UserManager) Add(username string, password string) error {
m.access.Lock()
defer m.access.Unlock()
if _, found := m.usersMap[username]; found {
return E.New("user", username, "already exists")
}
m.usersMap[username] = password
return m.postUpdate()
}
func (m *UserManager) Get(username string) (string, bool) {
m.access.Lock()
defer m.access.Unlock()
if password, found := m.usersMap[username]; found {
return password, true
}
return "", false
}
func (m *UserManager) Update(username string, password string) error {
m.access.Lock()
defer m.access.Unlock()
m.usersMap[username] = password
return m.postUpdate()
}
func (m *UserManager) Delete(username string) error {
m.access.Lock()
defer m.access.Unlock()
delete(m.usersMap, username)
return m.postUpdate()
}