diff --git a/inbound/vless.go b/inbound/vless.go index 2de8ce6d..d4c00f76 100644 --- a/inbound/vless.go +++ b/inbound/vless.go @@ -55,6 +55,8 @@ func NewVLESS(ctx context.Context, router adapter.Router, logger log.ContextLogg return index }), common.Map(inbound.users, func(it option.VLESSUser) string { return it.UUID + }), common.Map(inbound.users, func(it option.VLESSUser) string { + return it.Flow })) inbound.service = service var err error diff --git a/option/vless.go b/option/vless.go index 9b48d993..2b6521b1 100644 --- a/option/vless.go +++ b/option/vless.go @@ -10,6 +10,7 @@ type VLESSInboundOptions struct { type VLESSUser struct { Name string `json:"name"` UUID string `json:"uuid"` + Flow string `json:"flow,omitempty"` } type VLESSOutboundOptions struct { diff --git a/transport/vless/service.go b/transport/vless/service.go index b77f9702..1a6cfcfb 100644 --- a/transport/vless/service.go +++ b/transport/vless/service.go @@ -18,10 +18,11 @@ import ( "github.com/gofrs/uuid" ) -type Service[T any] struct { - userMap map[[16]byte]T - logger logger.Logger - handler Handler +type Service[T comparable] struct { + userMap map[[16]byte]T + userFlow map[T]string + logger logger.Logger + handler Handler } type Handler interface { @@ -30,23 +31,26 @@ type Handler interface { E.Handler } -func NewService[T any](logger logger.Logger, handler Handler) *Service[T] { +func NewService[T comparable](logger logger.Logger, handler Handler) *Service[T] { return &Service[T]{ logger: logger, handler: handler, } } -func (s *Service[T]) UpdateUsers(userList []T, userUUIDList []string) { +func (s *Service[T]) UpdateUsers(userList []T, userUUIDList []string, userFlowList []string) { userMap := make(map[[16]byte]T) + userFlowMap := make(map[T]string) for i, userName := range userList { userID := uuid.FromStringOrNil(userUUIDList[i]) if userID == uuid.Nil { userID = uuid.NewV5(uuid.Nil, userUUIDList[i]) } userMap[userID] = userName + userFlowMap[userName] = userFlowList[i] } s.userMap = userMap + s.userFlow = userFlowMap } var _ N.TCPConnectionHandler = (*Service[int])(nil) @@ -63,8 +67,13 @@ func (s *Service[T]) NewConnection(ctx context.Context, conn net.Conn, metadata ctx = auth.ContextWithUser(ctx, user) metadata.Destination = request.Destination + userFlow := s.userFlow[user] + if request.Flow != userFlow { + return E.New("flow mismatch: expected ", userFlow, ", but got ", request.Flow) + } + protocolConn := conn - switch request.Flow { + switch userFlow { case "": case FlowVision: protocolConn, err = NewVisionConn(conn, request.UUID, s.logger)