wsc protocol with sing-box transport

This commit is contained in:
Mobin 2025-09-07 20:24:11 +03:30
parent 80bf1f838e
commit 425c455a0d
9 changed files with 718 additions and 161 deletions

View File

@ -2,36 +2,39 @@ package inbound
import (
"context"
"errors"
"math"
"net"
"net/http"
"os"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/mux"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/v2ray"
"github.com/sagernet/sing-box/transport/wsc"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/exceptions"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/network"
N "github.com/sagernet/sing/common/network"
)
var _ adapter.Inbound = &WSC{}
var _ adapter.InjectableInbound = &WSC{}
var _ adapter.WSCServerTransportHandler = &wscTransportHandler{}
var _ adapter.V2RayServerTransportHandler = &wscTransportHandler{}
var _ wsc.Authenticator = &CustomAuthenticator{}
type WSC struct {
myInboundAdapter
server adapter.WSCServerTransport
service *wsc.Service
tlsConfig tls.ServerConfig
transport adapter.V2RayServerTransport
}
type wscTransportHandler WSC
@ -45,7 +48,7 @@ func NewWSC(ctx context.Context, router adapter.Router, logger log.ContextLogger
inbound := &WSC{
myInboundAdapter: myInboundAdapter{
protocol: C.TypeWSC,
network: []string{network.NetworkTCP},
network: []string{N.NetworkTCP},
ctx: ctx,
router: router,
logger: logger,
@ -53,16 +56,36 @@ func NewWSC(ctx context.Context, router adapter.Router, logger log.ContextLogger
listenOptions: options.ListenOptions,
},
}
server, err := wsc.NewServer(wsc.ServerConfig{
Ctx: ctx,
var err error
if options.TLS != nil {
tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
if err != nil {
return nil, err
}
inbound.tlsConfig = tlsConfig
}
if options.Transport != nil {
inbound.transport, err = v2ray.NewServerTransport(ctx, common.PtrValueOrDefault(options.Transport), inbound.tlsConfig, (*wscTransportHandler)(inbound))
if err != nil {
return nil, err
}
}
inbound.router, err = mux.NewRouterWithOptions(inbound.router, logger, common.PtrValueOrDefault(options.Multiplex))
if err != nil {
return nil, err
}
inbound.service, err = wsc.NewService(wsc.ServiceConfig{
Handler: adapter.NewUpstreamContextHandler(inbound.newConnection, inbound.newPacketConnection, inbound),
Logger: logger,
Handler: (*wscTransportHandler)(inbound),
Router: router,
Authenticator: &CustomAuthenticator{
id: 0,
logger: logger,
},
Router: router,
Dialer: network.SystemDialer,
MaxConnectionPerUser: options.MaxConnectionPerUser,
UsageReportTrafficInterval: options.UsageTraffic.Traffic,
UsageReportTimeInterval: time.Duration(options.UsageTraffic.Time),
@ -70,70 +93,87 @@ func NewWSC(ctx context.Context, router adapter.Router, logger log.ContextLogger
if err != nil {
return nil, err
}
if options.TLS != nil {
inbound.tlsConfig, err = tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
if err != nil {
return nil, err
}
}
inbound.server = server
inbound.connHandler = inbound
return inbound, nil
}
func (wsc *WSC) Close() error {
return common.Close(&wsc.myInboundAdapter, wsc.tlsConfig, wsc.server)
}
func (wsc *WSC) Start() error {
tcpListener, err := wsc.ListenTCP()
if err != nil {
return err
}
go func() {
sErr := wsc.server.Serve(tcpListener)
if sErr != nil && !exceptions.IsClosedOrCanceled(sErr) && !errors.Is(sErr, http.ErrServerClosed) {
wsc.logger.Error("wsc server serve error: ", sErr)
}
}()
return nil
}
func (wsc *WSC) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
wsc.routeTCP(ctx, conn, metadata)
return nil
var err error
if wsc.tlsConfig != nil && wsc.transport == nil {
conn, err = tls.ServerHandshake(ctx, conn, wsc.tlsConfig)
if err != nil {
return err
}
}
return wsc.service.NewConnection(adapter.WithContext(ctx, &metadata), conn, adapter.UpstreamMetadata(metadata))
}
func (wsc *WSC) NewPacketConnection(ctx context.Context, conn network.PacketConn, metadata adapter.InboundContext) error {
return wsc.myInboundAdapter.newPacketConnection(ctx, conn, metadata)
return os.ErrInvalid
}
func (wsc *WSC) Inject(conn net.Conn, metadata adapter.InboundContext) error {
func (wsc *WSC) Close() error {
return common.Close(&wsc.myInboundAdapter, wsc.tlsConfig, wsc.transport)
}
func (wsc *WSC) Start() error {
if wsc.tlsConfig != nil {
err := wsc.tlsConfig.Start()
if err != nil {
return E.Cause(err, "create TLS config")
}
}
if wsc.transport == nil {
return wsc.myInboundAdapter.Start()
}
if common.Contains(wsc.transport.Network(), N.NetworkTCP) {
tcpListener, err := wsc.myInboundAdapter.ListenTCP()
if err != nil {
return err
}
go func() {
sErr := wsc.transport.Serve(tcpListener)
if sErr != nil && !E.IsClosed(sErr) {
wsc.logger.Error("transport serve error: ", sErr)
}
}()
}
if common.Contains(wsc.transport.Network(), N.NetworkUDP) {
udpConn, err := wsc.myInboundAdapter.ListenUDP()
if err != nil {
return err
}
go func() {
sErr := wsc.transport.ServePacket(udpConn)
if sErr != nil && !E.IsClosed(sErr) {
wsc.logger.Error("transport serve error: ", sErr)
}
}()
}
return nil
}
func (wsc *WSC) newTransportConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
wsc.injectTCP(conn, metadata)
return nil
}
func (wsc *WSC) NewError(ctx context.Context, err error) {
wsc.myInboundAdapter.NewError(ctx, err)
func (wsc *WSC) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
return wsc.router.RoutePacketConnection(ctx, conn, metadata)
}
func (handler *wscTransportHandler) NewConnection(ctx context.Context, conn net.Conn, metadata metadata.Metadata) error {
return (*WSC)(handler).NewConnection(ctx, conn, adapter.InboundContext{
return (*WSC)(handler).newTransportConnection(ctx, conn, adapter.InboundContext{
Source: metadata.Source,
Destination: metadata.Destination,
})
}
func (handler *wscTransportHandler) NewPacketConnection(ctx context.Context, conn network.PacketConn, metadata metadata.Metadata) error {
return (*WSC)(handler).NewPacketConnection(ctx, conn, adapter.InboundContext{
Source: metadata.Source,
Destination: metadata.Destination,
})
}
func (handler *wscTransportHandler) NewError(ctx context.Context, err error) {
(*WSC)(handler).NewError(ctx, err)
}
func (auth *CustomAuthenticator) Authenticate(ctx context.Context, params wsc.AuthenticateParams) (wsc.AuthenticateResult, error) {
auth.id++
return wsc.AuthenticateResult{

View File

@ -6,23 +6,27 @@ type WSCUsageReport struct {
}
type WSCRule struct {
Action string `json:"action"`
Args []interface{} `json:"args"`
Action string `json:"action"`
Direction string `json:"direction,omitempty"`
Args []interface{} `json:"args"`
}
type WSCInboundOptions struct {
ListenOptions
InboundTLSOptionsContainer
MaxConnectionPerUser int `json:"max_connections,omitempty"`
UsageTraffic WSCUsageReport `json:"usage_traffic,omitempty"`
Multiplex *InboundMultiplexOptions `json:"multiplex,omitempty"`
Transport *V2RayTransportOptions `json:"transport,omitempty"`
MaxConnectionPerUser int `json:"max_connections,omitempty"`
UsageTraffic WSCUsageReport `json:"usage_traffic,omitempty"`
}
type WSCOutboundOptions struct {
DialerOptions
ServerOptions
OutboundTLSOptionsContainer
Network NetworkList `json:"network,omitempty"`
Auth string `json:"auth"`
Path string `json:"path"`
Rules []WSCRule `json:"rules,omitempty"`
Multiplex *OutboundMultiplexOptions `json:"multiplex,omitempty"`
Transport *V2RayTransportOptions `json:"transport,omitempty"`
Network NetworkList `json:"network,omitempty"`
Auth string `json:"auth"`
Rules []WSCRule `json:"rules,omitempty"`
}

View File

@ -3,37 +3,54 @@ package outbound
import (
"context"
"net"
"time"
"net/url"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/mux"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/v2ray"
"github.com/sagernet/sing-box/transport/wsc"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/network"
N "github.com/sagernet/sing/common/network"
)
var _ adapter.Outbound = &WSC{}
var _ N.Dialer = &wscDialer{}
type WSC struct {
myOutboundAdapter
dialer N.Dialer
tlsConfig tls.Config
client adapter.WSCClientTransport
dialer N.Dialer
serverAddr metadata.Socksaddr
multiplexDialer *mux.Client
tlsConfig tls.Config
transport adapter.V2RayClientTransport
auth string
ruleApplicator *wsc.WSCRuleApplicator
}
type wscDialer WSC
func NewWSC(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WSCOutboundOptions) (*WSC, error) {
outboundDialer, err := dialer.New(router, options.DialerOptions)
if err != nil {
return nil, err
}
var ruleApplicator *wsc.WSCRuleApplicator = nil
if len(options.Rules) > 0 {
if ruleApplicator, err = wsc.NewRuleApplicator(options.Rules); err != nil {
return nil, err
}
}
outbound := &WSC{
myOutboundAdapter: myOutboundAdapter{
protocol: C.TypeWSC,
@ -43,20 +60,11 @@ func NewWSC(ctx context.Context, router adapter.Router, logger log.ContextLogger
tag: tag,
dependencies: withDialerDependency(options.DialerOptions),
},
dialer: outboundDialer,
dialer: outboundDialer,
auth: options.Auth,
ruleApplicator: ruleApplicator,
}
serverAddr := options.ServerOptions.Build()
if options.Auth == "" {
return nil, exceptions.New("Invalid Auth to use in authentications")
}
if !serverAddr.IsValid() {
return nil, exceptions.New("Invalid server address")
}
if options.Path == "" {
options.Path = "/"
}
if options.TLS != nil {
outbound.tlsConfig, err = tls.NewClient(ctx, options.Server, common.PtrValueOrDefault(options.TLS))
if err != nil {
@ -64,14 +72,16 @@ func NewWSC(ctx context.Context, router adapter.Router, logger log.ContextLogger
}
}
outbound.client, err = wsc.NewClient(wsc.ClientConfig{
Auth: options.Auth,
Host: serverAddr.String(),
Path: options.Path,
TLS: outbound.tlsConfig,
Dialer: outbound.dialer,
Rules: options.Rules,
})
outbound.serverAddr = options.ServerOptions.Build()
if options.Transport != nil {
outbound.transport, err = v2ray.NewClientTransport(ctx, outbound.dialer, outbound.serverAddr, common.PtrValueOrDefault(options.Transport), outbound.tlsConfig)
if err != nil {
return nil, E.Cause(err, "create client transport: ", options.Transport.Type)
}
}
outbound.multiplexDialer, err = mux.NewClientWithOptions((*wscDialer)(outbound), logger, common.PtrValueOrDefault(options.Multiplex))
if err != nil {
return nil, err
}
@ -80,34 +90,105 @@ func NewWSC(ctx context.Context, router adapter.Router, logger log.ContextLogger
}
func (wsc *WSC) DialContext(ctx context.Context, network string, destination metadata.Socksaddr) (net.Conn, error) {
ctx, meta := adapter.ExtendContext(ctx)
meta.Outbound = wsc.tag
meta.Destination = destination
if N.NetworkName(network) != N.NetworkTCP {
return nil, exceptions.Extend(N.ErrUnknownNetwork, network)
if wsc.multiplexDialer == nil {
switch N.NetworkName(network) {
case N.NetworkTCP:
wsc.logger.InfoContext(ctx, "outbound connection to ", destination)
case N.NetworkUDP:
wsc.logger.InfoContext(ctx, "outbound packet connection to ", destination)
}
return (*wscDialer)(wsc).DialContext(ctx, network, destination)
} else {
switch N.NetworkName(network) {
case N.NetworkTCP:
wsc.logger.InfoContext(ctx, "outbound multiplex connection to ", destination)
case N.NetworkUDP:
wsc.logger.InfoContext(ctx, "outbound multiplex packet connection to ", destination)
}
return wsc.multiplexDialer.DialContext(ctx, network, destination)
}
wsc.logger.InfoContext(ctx, "WSC outbound connection to ", destination)
return wsc.client.DialContext(ctx, network, destination.String())
}
func (wsc *WSC) ListenPacket(ctx context.Context, destination metadata.Socksaddr) (net.PacketConn, error) {
ctx, meta := adapter.ExtendContext(ctx)
meta.Outbound = wsc.tag
meta.Destination = destination
wsc.logger.InfoContext(ctx, "WSC outbound packet to ", destination)
return wsc.client.ListenPacket(ctx, N.NetworkUDP, destination.String())
if wsc.multiplexDialer == nil {
wsc.logger.InfoContext(ctx, "outbound packet connection to ", destination)
return (*wscDialer)(wsc).ListenPacket(ctx, destination)
} else {
wsc.logger.InfoContext(ctx, "outbound multiplex packet connection to ", destination)
return wsc.multiplexDialer.ListenPacket(ctx, destination)
}
}
func (wsc *WSC) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
return NewConnection(ctx, wsc, conn, metadata)
}
func (wsc *WSC) NewPacketConnection(ctx context.Context, conn network.PacketConn, metadata adapter.InboundContext) error {
func (wsc *WSC) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
return NewPacketConnection(ctx, wsc, conn, metadata)
}
func (wsc *WSC) Close() error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
return wsc.client.Close(ctx)
func (wsc *WSC) InterfaceUpdated() {
if wsc.transport != nil {
wsc.transport.Close()
}
if wsc.multiplexDialer != nil {
wsc.multiplexDialer.Reset()
}
}
func (wsc *WSC) Close() error {
return common.Close(common.PtrOrNil(wsc.multiplexDialer))
}
func (dialer *wscDialer) DialContext(ctx context.Context, network string, destination metadata.Socksaddr) (net.Conn, error) {
ctx, metadata := adapter.ExtendContext(ctx)
metadata.Outbound = dialer.tag
metadata.Destination = destination
ep, netw := destination.String(), network
if dialer.ruleApplicator != nil {
ep, netw = dialer.ruleApplicator.ApplyEndpointReplace(ep, netw, wsc.RuleDirectionOutbound)
}
params := url.Values{}
params.Set("auth", dialer.auth)
params.Set("ep", ep)
params.Set("net", netw)
ctx = context.WithValue(ctx, adapter.V2RayExtraOptionsKey, adapter.V2RayExtraOptions{
QueryParams: params,
})
var conn net.Conn
var err error
if dialer.transport != nil {
conn, err = dialer.transport.DialContext(ctx)
} else {
conn, err = dialer.dialer.DialContext(ctx, N.NetworkTCP, dialer.serverAddr)
if err == nil && dialer.tlsConfig != nil {
conn, err = tls.ClientHandshake(ctx, conn, dialer.tlsConfig)
}
}
if err != nil {
common.Close(conn)
return nil, err
}
switch N.NetworkName(network) {
case N.NetworkTCP:
return wsc.NewClientConn(conn, destination)
case N.NetworkUDP:
packetConn, err := wsc.NewClientPacketConn(conn, dialer.ruleApplicator)
return bufio.NewBindPacketConn(packetConn, destination), err
default:
return nil, E.Extend(N.ErrUnknownNetwork, network)
}
}
func (dialer *wscDialer) ListenPacket(ctx context.Context, destination metadata.Socksaddr) (net.PacketConn, error) {
conn, err := dialer.DialContext(ctx, N.NetworkUDP, destination)
if err != nil {
return nil, err
}
return conn.(net.PacketConn), nil
}

View File

@ -4,6 +4,7 @@ import (
"encoding"
"encoding/binary"
"errors"
"io"
"net/netip"
)
@ -62,7 +63,7 @@ func (payload *packetConnPayload) MarshalBinaryUnsafe(data []byte) error {
}
if len(data) < hLen+len(payload.payload) {
return errors.New("invalid data length to write")
return io.ErrShortBuffer
}
addr := payload.addrPort.Addr().As16()

235
transport/wsc/protocol.go Normal file
View File

@ -0,0 +1,235 @@
package wsc
import (
"net"
"os"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/metadata"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/network"
N "github.com/sagernet/sing/common/network"
)
var _ N.EarlyConn = &ClientConn{}
var _ net.Conn = &ClientPacketConn{}
var _ net.PacketConn = &ClientPacketConn{}
var _ N.NetPacketReader = &ClientPacketConn{}
var _ N.NetPacketWriter = &ClientPacketConn{}
var _ N.NetPacketReader = &servicePacketConn{}
var _ N.NetPacketWriter = &servicePacketConn{}
type ClientConn struct {
N.ExtendedConn
destination M.Socksaddr
}
type ClientPacketConn struct {
net.Conn
ruleApplicator *WSCRuleApplicator
writePayload packetConnPayload
readPayload packetConnPayload
packet [buf.UDPBufferSize]byte
}
type servicePacketConn struct {
net.Conn
writePayload packetConnPayload
readPayload packetConnPayload
packet [buf.UDPBufferSize]byte
}
func NewClientConn(conn net.Conn, destination M.Socksaddr) (*ClientConn, error) {
return &ClientConn{
ExtendedConn: bufio.NewExtendedConn(conn),
destination: destination,
}, nil
}
func NewClientPacketConn(conn net.Conn, ruleApplicator *WSCRuleApplicator) (*ClientPacketConn, error) {
return &ClientPacketConn{
Conn: conn,
ruleApplicator: ruleApplicator,
}, nil
}
func (conn *ClientConn) NeedHandshake() bool {
return false
}
func (conn *ClientConn) FrontHeadroom() int {
return 0
}
func (conn *ClientConn) Upstream() any {
return conn.ExtendedConn
}
func (packetConn *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if buffer == nil {
return exceptions.New("buffer is nil")
}
if packetConn.ruleApplicator != nil {
ep, _ := packetConn.ruleApplicator.ApplyEndpointReplace(destination.String(), network.NetworkUDP, RuleDirectionOutbound)
packetConn.writePayload.addrPort = metadata.ParseSocksaddr(ep).AddrPort()
} else {
packetConn.writePayload.addrPort = destination.AddrPort()
}
packetConn.writePayload.payload = buffer.Bytes()
payloadBytes, err := packetConn.writePayload.MarshalBinary()
if err != nil {
return err
}
_, err = packetConn.Conn.Write(payloadBytes)
return err
}
func (packetConn *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
if buffer == nil {
return destination, exceptions.New("buffer is nil")
}
n, err := packetConn.Conn.Read(packetConn.packet[:])
if err != nil {
return destination, err
}
if err := packetConn.readPayload.UnmarshalBinaryUnsafe(packetConn.packet[:n]); err != nil {
return destination, err
}
if _, err := buffer.Write(packetConn.readPayload.payload); err != nil {
return destination, err
}
destination = metadata.SocksaddrFromNetIP(packetConn.readPayload.addrPort)
if packetConn.ruleApplicator != nil {
ep, _ := packetConn.ruleApplicator.ApplyEndpointReplace(destination.String(), N.NetworkUDP, RuleDirectionInbound)
destination = metadata.ParseSocksaddr(ep)
}
return
}
func (packetConn *ClientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
buffer := buf.With(p)
destination, err := packetConn.ReadPacket(buffer)
if err != nil {
return
}
n = buffer.Len()
if destination.IsFqdn() {
addr = destination
} else {
addr = destination.UDPAddr()
}
return
}
func (packetConn *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return bufio.WritePacket(packetConn, p, addr)
}
func (packetConn *ClientPacketConn) Read(b []byte) (n int, err error) {
n, _, err = packetConn.ReadFrom(b)
return
}
func (packetConn *ClientPacketConn) Write(b []byte) (n int, err error) {
return 0, os.ErrInvalid
}
func (packetConn *ClientPacketConn) NeedHandshake() bool {
return false
}
func (packetConn *ClientPacketConn) FrontHeadroom() int {
return 0
}
func (packetConn *ClientPacketConn) Upstream() any {
return packetConn.Conn
}
func (packetConn *servicePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
if buffer == nil {
return exceptions.New("buffer is nil")
}
packetConn.writePayload.addrPort = destination.AddrPort()
packetConn.writePayload.payload = buffer.Bytes()
payloadBytes, err := packetConn.writePayload.MarshalBinary()
if err != nil {
return err
}
_, err = packetConn.Conn.Write(payloadBytes)
return err
}
func (packetConn *servicePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return bufio.WritePacket(packetConn, p, addr)
}
func (packetConn *servicePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
buffer := buf.With(p)
destination, err := packetConn.ReadPacket(buffer)
if err != nil {
return
}
n = buffer.Len()
if destination.IsFqdn() {
addr = destination
} else {
addr = destination.UDPAddr()
}
return
}
func (packetConn *servicePacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
if buffer == nil {
return destination, exceptions.New("buffer is nil")
}
n, err := packetConn.Conn.Read(packetConn.packet[:])
if err != nil {
return destination, err
}
if err := packetConn.readPayload.UnmarshalBinaryUnsafe(packetConn.packet[:n]); err != nil {
return destination, err
}
if _, err := buffer.Write(packetConn.readPayload.payload); err != nil {
return destination, err
}
destination = metadata.SocksaddrFromNetIP(packetConn.readPayload.addrPort)
return
}
func (packetConn *servicePacketConn) NeedHandshake() bool {
return false
}
func (packetConn *servicePacketConn) FrontHeadroom() int {
return 0
}
func (packetConn *servicePacketConn) NeedAdditionalReadDeadline() bool {
return false
}
func (packetConn *servicePacketConn) Upstream() any {
return packetConn.Conn
}

View File

@ -1,6 +1,7 @@
package wsc
import (
"bytes"
"errors"
"github.com/sagernet/sing-box/option"
@ -8,15 +9,23 @@ import (
)
type RuleAction int
type RuleDirection int
const (
RuleActionUnknown RuleAction = iota
RuleActionReplace
)
const (
RuleDirectionUnknown RuleDirection = iota
RuleDirectionInbound
RuleDirectionOutbound
)
type WSCRule struct {
Action RuleAction
Args []interface{}
Action RuleAction
Direction RuleDirection
Args []interface{}
}
type WSCRuleApplicator struct {
@ -30,9 +39,17 @@ func NewRuleApplicator(rules []option.WSCRule) (*WSCRuleApplicator, error) {
if err != nil {
return nil, err
}
var direction RuleDirection = RuleDirectionUnknown
if len(rule.Direction) > 0 {
direction, err = RuleDirectionFromString(rule.Direction)
if err != nil {
return nil, err
}
}
wscRules = append(wscRules, WSCRule{
Action: action,
Args: rule.Args,
Action: action,
Direction: direction,
Args: rule.Args,
})
}
return &WSCRuleApplicator{
@ -40,13 +57,14 @@ func NewRuleApplicator(rules []option.WSCRule) (*WSCRuleApplicator, error) {
}, nil
}
func (ruleManager *WSCRuleApplicator) ApplyEndpointReplace(ep string, netw string) (finalEp string, finalNetw string) {
finalEp, finalNetw = ep, netw
func (ruleManager *WSCRuleApplicator) ApplyEndpointReplace(ep string, netw string, direction RuleDirection) (finalEp string, finalNetw string) {
for _, rule := range ruleManager.Rules {
if rule.Action != RuleActionReplace {
continue
}
if rule.Direction != RuleDirectionUnknown && direction != RuleDirectionUnknown && rule.Direction != direction {
continue
}
sType, ok := rule.Args[0].(string)
if !ok {
@ -80,7 +98,18 @@ func (ruleManager *WSCRuleApplicator) ApplyEndpointReplace(ep string, netw strin
epAddr := metadata.ParseSocksaddr(ep)
equal := false
if (whatAddr.IsFqdn() && epAddr.IsFqdn() && whatAddr.Fqdn == epAddr.Fqdn) || whatAddr.Addr.Compare(epAddr.Addr) == 0 {
if whatAddr.IsFqdn() && epAddr.IsFqdn() && whatAddr.Fqdn == epAddr.Fqdn {
equal = true
} else if whatAddr.IsIPv4() {
if epAddr.IsIPv4() || epAddr.Addr.Is4In6() {
whatAddr4 := whatAddr.Addr.As4()
epAddr4 := epAddr.Addr.As4()
equal = bytes.Equal(whatAddr4[:], epAddr4[:])
}
} else if whatAddr.IsIPv6() && epAddr.IsIPv6() {
equal = whatAddr.Addr.Compare(epAddr.Addr) == 0
}
if equal {
if whatAddr.Port == 0 {
equal = true
} else {
@ -117,6 +146,17 @@ func RuleActionFromString(actionStr string) (RuleAction, error) {
case "replace":
return RuleActionReplace, nil
default:
return 0, errors.New("rule action doesn't exist")
return RuleActionUnknown, errors.New("rule action doesn't exist")
}
}
func RuleDirectionFromString(directionStr string) (RuleDirection, error) {
switch directionStr {
case "inbound":
return RuleDirectionInbound, nil
case "outbound":
return RuleDirectionOutbound, nil
default:
return RuleDirectionUnknown, errors.New("rule direction doesn't exist")
}
}

202
transport/wsc/service.go Normal file
View File

@ -0,0 +1,202 @@
package wsc
import (
"context"
"net"
"net/url"
"time"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/network"
N "github.com/sagernet/sing/common/network"
)
type Handler interface {
N.TCPConnectionHandler
N.UDPConnectionHandler
E.Handler
}
type Service struct {
logger logger.ContextLogger
router adapter.Router
handler Handler
authenticator Authenticator
userManager *wscUserManager
}
type ServiceConfig struct {
Logger logger.ContextLogger
Router adapter.Router
Handler Handler
Authenticator Authenticator
MaxConnectionPerUser int
UsageReportTrafficInterval int64
UsageReportTimeInterval time.Duration
}
type meteredConn struct {
net.Conn
user *wscUser
}
func NewService(config ServiceConfig) (*Service, error) {
if config.Handler == nil {
return nil, E.New("Handler required")
}
if config.Authenticator == nil {
return nil, E.New("Authenticator required")
}
return &Service{
logger: config.Logger,
router: config.Router,
handler: config.Handler,
authenticator: config.Authenticator,
userManager: &wscUserManager{
users: map[int64]*wscUser{},
authenticator: config.Authenticator,
maxConnPerUser: config.MaxConnectionPerUser,
usageReportTrafficInterval: config.UsageReportTrafficInterval,
usageReportTimeInterval: config.UsageReportTimeInterval,
},
}, nil
}
func (service *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
params, err := service.readQueryParams(conn)
if err != nil {
return err
}
auth := params.Get("auth")
if auth == "" {
return E.New("authentication required")
}
account, err := service.authenticator.Authenticate(ctx, AuthenticateParams{
Auth: auth,
MaxConn: service.userManager.maxConnPerUser,
})
if err != nil {
if account.ID != 0 {
if err := service.userManager.cleanupUser(ctx, account.ID, false); err != nil {
return err
}
return E.Cause(err, "authentication failed")
}
}
// user cleanup
//{}
user := service.userManager.findOrCreateUser(ctx, account.ID, account.Rate, account.MaxConn)
netw := params.Get("net")
if netw == "" {
netw = network.NetworkTCP
}
endpoint := params.Get("ep")
addr, err := service.resolveDestination(ctx, M.ParseSocksaddr(endpoint))
if err != nil {
return E.Cause(err, "failed to parse and resolve endpoint")
}
service.log("New request (Client: ", metadata.Source, ", Auth: ", auth, ", User-ID: ", account.ID, ", ", netw+"-Addr: ", addr.String(), ")")
metadata.Protocol = C.TypeWSC
metadata.Destination = addr
if popedConn, err := user.addConn(conn); err != nil {
return err
} else {
if popedConn != nil {
popedConn.Close()
}
}
switch N.NetworkName(netw) {
case N.NetworkTCP:
err = service.handler.NewConnection(ctx, &meteredConn{Conn: conn, user: user}, metadata)
case N.NetworkUDP:
return service.handler.NewPacketConnection(ctx, &servicePacketConn{
Conn: &meteredConn{
Conn: conn,
user: user,
},
}, metadata)
default:
return E.New("not supported protocol ", netw)
}
if cErr := service.userManager.cleanupUserConn(ctx, user, conn); cErr != nil && err == nil {
err = cErr
}
return err
}
func (service *Service) readQueryParams(conn net.Conn) (url.Values, error) {
var queryParamsRaw [500]byte
n, err := conn.Read(queryParamsRaw[:])
if err != nil {
return nil, err
}
pURL := url.URL{
RawQuery: string(queryParamsRaw[:n]),
}
return pURL.Query(), nil
}
func (service *Service) resolveDestination(ctx context.Context, dest M.Socksaddr) (M.Socksaddr, error) {
if dest.IsFqdn() {
addrs, err := service.router.LookupDefault(ctx, dest.Fqdn)
if err != nil {
return M.Socksaddr{}, err
}
if len(addrs) == 0 {
return M.Socksaddr{}, E.New("no address found for endpoint domain: ", dest.Fqdn)
}
return M.Socksaddr{
Addr: addrs[0],
Port: dest.Port,
}, nil
}
return dest, nil
}
func (service *Service) log(args ...any) {
if service.logger != nil {
service.logger.Debug(args...)
}
}
func (conn *meteredConn) Read(p []byte) (int, error) {
reader, err := conn.user.connReader(conn.Conn)
if err != nil {
return 0, err
}
n, err := reader.Read(p)
if err != nil {
return 0, err
}
conn.user.usedTrafficBytes.Add(int64(n))
return n, nil
}
func (conn *meteredConn) Write(p []byte) (int, error) {
writer, err := conn.user.connWriter(conn.Conn)
if err != nil {
return 0, err
}
n, err := writer.Write(p)
if err != nil {
return 0, err
}
conn.user.usedTrafficBytes.Add(int64(n))
return n, nil
}

View File

@ -27,7 +27,6 @@ type wscUser struct {
reportedTrafficBytes atomic.Int64
lastTrafficUpdateTick atomic.Int64
conns map[net.Conn]connData
heap []byte
rateLimit int64
maxConnCount int
usedIds []bool
@ -37,7 +36,6 @@ func (manager *wscUserManager) newUser(id int64, usedTrafficBytes int64, maxConn
user := &wscUser{
id: id,
conns: make(map[net.Conn]connData, maxConnCount),
heap: make([]byte, connReadSize*2*maxConnCount),
rateLimit: rateLimit,
usedIds: make([]bool, maxConnCount),
maxConnCount: maxConnCount,
@ -48,40 +46,6 @@ func (manager *wscUserManager) newUser(id int64, usedTrafficBytes int64, maxConn
return user
}
func (user *wscUser) outBuffer(conn net.Conn) []byte {
user.mu.Lock()
defer user.mu.Unlock()
if user.maxConnCount < 1 {
return make([]byte, connReadSize)
}
if d, found := user.conns[conn]; found {
bufStart := (connReadSize*2)*(d.id+1) - connReadSize
bufEnd := bufStart + connReadSize
return user.heap[bufStart:bufEnd]
}
return make([]byte, connReadSize)
}
func (user *wscUser) inBuffer(conn net.Conn) []byte {
user.mu.Lock()
defer user.mu.Unlock()
if user.maxConnCount < 1 {
return make([]byte, connReadSize)
}
if d, found := user.conns[conn]; found {
bufStart := (connReadSize * 2) * d.id
bufEnd := bufStart + connReadSize
return user.heap[bufStart:bufEnd]
}
return make([]byte, connReadSize)
}
func (user *wscUser) connReader(conn net.Conn) (io.Reader, error) {
user.mu.Lock()
defer user.mu.Unlock()

View File

@ -1,19 +1,9 @@
package wsc
import (
"net"
"time"
"github.com/itsabgr/ge"
)
func nowns() int64 {
return time.Now().UnixNano()
}
func isTimeoutErr(err error) bool {
if nErr, ok := ge.As[net.Error](err); ok && nErr.Timeout() {
return true
}
return false
}