mirror of
https://github.com/SagerNet/sing-box.git
synced 2025-09-09 13:04:06 +08:00
wsc protocol with sing-box transport
This commit is contained in:
parent
80bf1f838e
commit
425c455a0d
148
inbound/wsc.go
148
inbound/wsc.go
@ -2,36 +2,39 @@ package inbound
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/mux"
|
||||
"github.com/sagernet/sing-box/common/tls"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing-box/transport/v2ray"
|
||||
"github.com/sagernet/sing-box/transport/wsc"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/exceptions"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
"github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/network"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
var _ adapter.Inbound = &WSC{}
|
||||
var _ adapter.InjectableInbound = &WSC{}
|
||||
|
||||
var _ adapter.WSCServerTransportHandler = &wscTransportHandler{}
|
||||
var _ adapter.V2RayServerTransportHandler = &wscTransportHandler{}
|
||||
|
||||
var _ wsc.Authenticator = &CustomAuthenticator{}
|
||||
|
||||
type WSC struct {
|
||||
myInboundAdapter
|
||||
server adapter.WSCServerTransport
|
||||
service *wsc.Service
|
||||
tlsConfig tls.ServerConfig
|
||||
transport adapter.V2RayServerTransport
|
||||
}
|
||||
|
||||
type wscTransportHandler WSC
|
||||
@ -45,7 +48,7 @@ func NewWSC(ctx context.Context, router adapter.Router, logger log.ContextLogger
|
||||
inbound := &WSC{
|
||||
myInboundAdapter: myInboundAdapter{
|
||||
protocol: C.TypeWSC,
|
||||
network: []string{network.NetworkTCP},
|
||||
network: []string{N.NetworkTCP},
|
||||
ctx: ctx,
|
||||
router: router,
|
||||
logger: logger,
|
||||
@ -53,16 +56,36 @@ func NewWSC(ctx context.Context, router adapter.Router, logger log.ContextLogger
|
||||
listenOptions: options.ListenOptions,
|
||||
},
|
||||
}
|
||||
server, err := wsc.NewServer(wsc.ServerConfig{
|
||||
Ctx: ctx,
|
||||
|
||||
var err error
|
||||
|
||||
if options.TLS != nil {
|
||||
tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inbound.tlsConfig = tlsConfig
|
||||
}
|
||||
if options.Transport != nil {
|
||||
inbound.transport, err = v2ray.NewServerTransport(ctx, common.PtrValueOrDefault(options.Transport), inbound.tlsConfig, (*wscTransportHandler)(inbound))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
inbound.router, err = mux.NewRouterWithOptions(inbound.router, logger, common.PtrValueOrDefault(options.Multiplex))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inbound.service, err = wsc.NewService(wsc.ServiceConfig{
|
||||
Handler: adapter.NewUpstreamContextHandler(inbound.newConnection, inbound.newPacketConnection, inbound),
|
||||
Logger: logger,
|
||||
Handler: (*wscTransportHandler)(inbound),
|
||||
Router: router,
|
||||
Authenticator: &CustomAuthenticator{
|
||||
id: 0,
|
||||
logger: logger,
|
||||
},
|
||||
Router: router,
|
||||
Dialer: network.SystemDialer,
|
||||
MaxConnectionPerUser: options.MaxConnectionPerUser,
|
||||
UsageReportTrafficInterval: options.UsageTraffic.Traffic,
|
||||
UsageReportTimeInterval: time.Duration(options.UsageTraffic.Time),
|
||||
@ -70,70 +93,87 @@ func NewWSC(ctx context.Context, router adapter.Router, logger log.ContextLogger
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if options.TLS != nil {
|
||||
inbound.tlsConfig, err = tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
inbound.server = server
|
||||
|
||||
inbound.connHandler = inbound
|
||||
|
||||
return inbound, nil
|
||||
}
|
||||
|
||||
func (wsc *WSC) Close() error {
|
||||
return common.Close(&wsc.myInboundAdapter, wsc.tlsConfig, wsc.server)
|
||||
}
|
||||
|
||||
func (wsc *WSC) Start() error {
|
||||
tcpListener, err := wsc.ListenTCP()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
sErr := wsc.server.Serve(tcpListener)
|
||||
if sErr != nil && !exceptions.IsClosedOrCanceled(sErr) && !errors.Is(sErr, http.ErrServerClosed) {
|
||||
wsc.logger.Error("wsc server serve error: ", sErr)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (wsc *WSC) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
|
||||
wsc.routeTCP(ctx, conn, metadata)
|
||||
return nil
|
||||
var err error
|
||||
if wsc.tlsConfig != nil && wsc.transport == nil {
|
||||
conn, err = tls.ServerHandshake(ctx, conn, wsc.tlsConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return wsc.service.NewConnection(adapter.WithContext(ctx, &metadata), conn, adapter.UpstreamMetadata(metadata))
|
||||
}
|
||||
|
||||
func (wsc *WSC) NewPacketConnection(ctx context.Context, conn network.PacketConn, metadata adapter.InboundContext) error {
|
||||
return wsc.myInboundAdapter.newPacketConnection(ctx, conn, metadata)
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
func (wsc *WSC) Inject(conn net.Conn, metadata adapter.InboundContext) error {
|
||||
func (wsc *WSC) Close() error {
|
||||
return common.Close(&wsc.myInboundAdapter, wsc.tlsConfig, wsc.transport)
|
||||
}
|
||||
|
||||
func (wsc *WSC) Start() error {
|
||||
if wsc.tlsConfig != nil {
|
||||
err := wsc.tlsConfig.Start()
|
||||
if err != nil {
|
||||
return E.Cause(err, "create TLS config")
|
||||
}
|
||||
}
|
||||
if wsc.transport == nil {
|
||||
return wsc.myInboundAdapter.Start()
|
||||
}
|
||||
|
||||
if common.Contains(wsc.transport.Network(), N.NetworkTCP) {
|
||||
tcpListener, err := wsc.myInboundAdapter.ListenTCP()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
sErr := wsc.transport.Serve(tcpListener)
|
||||
if sErr != nil && !E.IsClosed(sErr) {
|
||||
wsc.logger.Error("transport serve error: ", sErr)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if common.Contains(wsc.transport.Network(), N.NetworkUDP) {
|
||||
udpConn, err := wsc.myInboundAdapter.ListenUDP()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
sErr := wsc.transport.ServePacket(udpConn)
|
||||
if sErr != nil && !E.IsClosed(sErr) {
|
||||
wsc.logger.Error("transport serve error: ", sErr)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (wsc *WSC) newTransportConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
|
||||
wsc.injectTCP(conn, metadata)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (wsc *WSC) NewError(ctx context.Context, err error) {
|
||||
wsc.myInboundAdapter.NewError(ctx, err)
|
||||
func (wsc *WSC) newPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
|
||||
return wsc.router.RoutePacketConnection(ctx, conn, metadata)
|
||||
}
|
||||
|
||||
func (handler *wscTransportHandler) NewConnection(ctx context.Context, conn net.Conn, metadata metadata.Metadata) error {
|
||||
return (*WSC)(handler).NewConnection(ctx, conn, adapter.InboundContext{
|
||||
return (*WSC)(handler).newTransportConnection(ctx, conn, adapter.InboundContext{
|
||||
Source: metadata.Source,
|
||||
Destination: metadata.Destination,
|
||||
})
|
||||
}
|
||||
|
||||
func (handler *wscTransportHandler) NewPacketConnection(ctx context.Context, conn network.PacketConn, metadata metadata.Metadata) error {
|
||||
return (*WSC)(handler).NewPacketConnection(ctx, conn, adapter.InboundContext{
|
||||
Source: metadata.Source,
|
||||
Destination: metadata.Destination,
|
||||
})
|
||||
}
|
||||
|
||||
func (handler *wscTransportHandler) NewError(ctx context.Context, err error) {
|
||||
(*WSC)(handler).NewError(ctx, err)
|
||||
}
|
||||
|
||||
func (auth *CustomAuthenticator) Authenticate(ctx context.Context, params wsc.AuthenticateParams) (wsc.AuthenticateResult, error) {
|
||||
auth.id++
|
||||
return wsc.AuthenticateResult{
|
||||
|
@ -6,23 +6,27 @@ type WSCUsageReport struct {
|
||||
}
|
||||
|
||||
type WSCRule struct {
|
||||
Action string `json:"action"`
|
||||
Args []interface{} `json:"args"`
|
||||
Action string `json:"action"`
|
||||
Direction string `json:"direction,omitempty"`
|
||||
Args []interface{} `json:"args"`
|
||||
}
|
||||
|
||||
type WSCInboundOptions struct {
|
||||
ListenOptions
|
||||
InboundTLSOptionsContainer
|
||||
MaxConnectionPerUser int `json:"max_connections,omitempty"`
|
||||
UsageTraffic WSCUsageReport `json:"usage_traffic,omitempty"`
|
||||
Multiplex *InboundMultiplexOptions `json:"multiplex,omitempty"`
|
||||
Transport *V2RayTransportOptions `json:"transport,omitempty"`
|
||||
MaxConnectionPerUser int `json:"max_connections,omitempty"`
|
||||
UsageTraffic WSCUsageReport `json:"usage_traffic,omitempty"`
|
||||
}
|
||||
|
||||
type WSCOutboundOptions struct {
|
||||
DialerOptions
|
||||
ServerOptions
|
||||
OutboundTLSOptionsContainer
|
||||
Network NetworkList `json:"network,omitempty"`
|
||||
Auth string `json:"auth"`
|
||||
Path string `json:"path"`
|
||||
Rules []WSCRule `json:"rules,omitempty"`
|
||||
Multiplex *OutboundMultiplexOptions `json:"multiplex,omitempty"`
|
||||
Transport *V2RayTransportOptions `json:"transport,omitempty"`
|
||||
Network NetworkList `json:"network,omitempty"`
|
||||
Auth string `json:"auth"`
|
||||
Rules []WSCRule `json:"rules,omitempty"`
|
||||
}
|
||||
|
167
outbound/wsc.go
167
outbound/wsc.go
@ -3,37 +3,54 @@ package outbound
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
"net/url"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-box/common/mux"
|
||||
"github.com/sagernet/sing-box/common/tls"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing-box/transport/v2ray"
|
||||
"github.com/sagernet/sing-box/transport/wsc"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/network"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
var _ adapter.Outbound = &WSC{}
|
||||
|
||||
var _ N.Dialer = &wscDialer{}
|
||||
|
||||
type WSC struct {
|
||||
myOutboundAdapter
|
||||
dialer N.Dialer
|
||||
tlsConfig tls.Config
|
||||
client adapter.WSCClientTransport
|
||||
dialer N.Dialer
|
||||
serverAddr metadata.Socksaddr
|
||||
multiplexDialer *mux.Client
|
||||
tlsConfig tls.Config
|
||||
transport adapter.V2RayClientTransport
|
||||
auth string
|
||||
ruleApplicator *wsc.WSCRuleApplicator
|
||||
}
|
||||
|
||||
type wscDialer WSC
|
||||
|
||||
func NewWSC(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WSCOutboundOptions) (*WSC, error) {
|
||||
outboundDialer, err := dialer.New(router, options.DialerOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ruleApplicator *wsc.WSCRuleApplicator = nil
|
||||
if len(options.Rules) > 0 {
|
||||
if ruleApplicator, err = wsc.NewRuleApplicator(options.Rules); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
outbound := &WSC{
|
||||
myOutboundAdapter: myOutboundAdapter{
|
||||
protocol: C.TypeWSC,
|
||||
@ -43,20 +60,11 @@ func NewWSC(ctx context.Context, router adapter.Router, logger log.ContextLogger
|
||||
tag: tag,
|
||||
dependencies: withDialerDependency(options.DialerOptions),
|
||||
},
|
||||
dialer: outboundDialer,
|
||||
dialer: outboundDialer,
|
||||
auth: options.Auth,
|
||||
ruleApplicator: ruleApplicator,
|
||||
}
|
||||
|
||||
serverAddr := options.ServerOptions.Build()
|
||||
|
||||
if options.Auth == "" {
|
||||
return nil, exceptions.New("Invalid Auth to use in authentications")
|
||||
}
|
||||
if !serverAddr.IsValid() {
|
||||
return nil, exceptions.New("Invalid server address")
|
||||
}
|
||||
if options.Path == "" {
|
||||
options.Path = "/"
|
||||
}
|
||||
if options.TLS != nil {
|
||||
outbound.tlsConfig, err = tls.NewClient(ctx, options.Server, common.PtrValueOrDefault(options.TLS))
|
||||
if err != nil {
|
||||
@ -64,14 +72,16 @@ func NewWSC(ctx context.Context, router adapter.Router, logger log.ContextLogger
|
||||
}
|
||||
}
|
||||
|
||||
outbound.client, err = wsc.NewClient(wsc.ClientConfig{
|
||||
Auth: options.Auth,
|
||||
Host: serverAddr.String(),
|
||||
Path: options.Path,
|
||||
TLS: outbound.tlsConfig,
|
||||
Dialer: outbound.dialer,
|
||||
Rules: options.Rules,
|
||||
})
|
||||
outbound.serverAddr = options.ServerOptions.Build()
|
||||
|
||||
if options.Transport != nil {
|
||||
outbound.transport, err = v2ray.NewClientTransport(ctx, outbound.dialer, outbound.serverAddr, common.PtrValueOrDefault(options.Transport), outbound.tlsConfig)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "create client transport: ", options.Transport.Type)
|
||||
}
|
||||
}
|
||||
|
||||
outbound.multiplexDialer, err = mux.NewClientWithOptions((*wscDialer)(outbound), logger, common.PtrValueOrDefault(options.Multiplex))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -80,34 +90,105 @@ func NewWSC(ctx context.Context, router adapter.Router, logger log.ContextLogger
|
||||
}
|
||||
|
||||
func (wsc *WSC) DialContext(ctx context.Context, network string, destination metadata.Socksaddr) (net.Conn, error) {
|
||||
ctx, meta := adapter.ExtendContext(ctx)
|
||||
meta.Outbound = wsc.tag
|
||||
meta.Destination = destination
|
||||
if N.NetworkName(network) != N.NetworkTCP {
|
||||
return nil, exceptions.Extend(N.ErrUnknownNetwork, network)
|
||||
if wsc.multiplexDialer == nil {
|
||||
switch N.NetworkName(network) {
|
||||
case N.NetworkTCP:
|
||||
wsc.logger.InfoContext(ctx, "outbound connection to ", destination)
|
||||
case N.NetworkUDP:
|
||||
wsc.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
||||
}
|
||||
return (*wscDialer)(wsc).DialContext(ctx, network, destination)
|
||||
} else {
|
||||
switch N.NetworkName(network) {
|
||||
case N.NetworkTCP:
|
||||
wsc.logger.InfoContext(ctx, "outbound multiplex connection to ", destination)
|
||||
case N.NetworkUDP:
|
||||
wsc.logger.InfoContext(ctx, "outbound multiplex packet connection to ", destination)
|
||||
}
|
||||
return wsc.multiplexDialer.DialContext(ctx, network, destination)
|
||||
}
|
||||
wsc.logger.InfoContext(ctx, "WSC outbound connection to ", destination)
|
||||
return wsc.client.DialContext(ctx, network, destination.String())
|
||||
}
|
||||
|
||||
func (wsc *WSC) ListenPacket(ctx context.Context, destination metadata.Socksaddr) (net.PacketConn, error) {
|
||||
ctx, meta := adapter.ExtendContext(ctx)
|
||||
meta.Outbound = wsc.tag
|
||||
meta.Destination = destination
|
||||
wsc.logger.InfoContext(ctx, "WSC outbound packet to ", destination)
|
||||
return wsc.client.ListenPacket(ctx, N.NetworkUDP, destination.String())
|
||||
if wsc.multiplexDialer == nil {
|
||||
wsc.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
||||
return (*wscDialer)(wsc).ListenPacket(ctx, destination)
|
||||
} else {
|
||||
wsc.logger.InfoContext(ctx, "outbound multiplex packet connection to ", destination)
|
||||
return wsc.multiplexDialer.ListenPacket(ctx, destination)
|
||||
}
|
||||
}
|
||||
|
||||
func (wsc *WSC) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
|
||||
return NewConnection(ctx, wsc, conn, metadata)
|
||||
}
|
||||
|
||||
func (wsc *WSC) NewPacketConnection(ctx context.Context, conn network.PacketConn, metadata adapter.InboundContext) error {
|
||||
func (wsc *WSC) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
|
||||
return NewPacketConnection(ctx, wsc, conn, metadata)
|
||||
}
|
||||
|
||||
func (wsc *WSC) Close() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
return wsc.client.Close(ctx)
|
||||
func (wsc *WSC) InterfaceUpdated() {
|
||||
if wsc.transport != nil {
|
||||
wsc.transport.Close()
|
||||
}
|
||||
if wsc.multiplexDialer != nil {
|
||||
wsc.multiplexDialer.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
func (wsc *WSC) Close() error {
|
||||
return common.Close(common.PtrOrNil(wsc.multiplexDialer))
|
||||
}
|
||||
|
||||
func (dialer *wscDialer) DialContext(ctx context.Context, network string, destination metadata.Socksaddr) (net.Conn, error) {
|
||||
ctx, metadata := adapter.ExtendContext(ctx)
|
||||
metadata.Outbound = dialer.tag
|
||||
metadata.Destination = destination
|
||||
|
||||
ep, netw := destination.String(), network
|
||||
if dialer.ruleApplicator != nil {
|
||||
ep, netw = dialer.ruleApplicator.ApplyEndpointReplace(ep, netw, wsc.RuleDirectionOutbound)
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("auth", dialer.auth)
|
||||
params.Set("ep", ep)
|
||||
params.Set("net", netw)
|
||||
ctx = context.WithValue(ctx, adapter.V2RayExtraOptionsKey, adapter.V2RayExtraOptions{
|
||||
QueryParams: params,
|
||||
})
|
||||
|
||||
var conn net.Conn
|
||||
var err error
|
||||
|
||||
if dialer.transport != nil {
|
||||
conn, err = dialer.transport.DialContext(ctx)
|
||||
} else {
|
||||
conn, err = dialer.dialer.DialContext(ctx, N.NetworkTCP, dialer.serverAddr)
|
||||
if err == nil && dialer.tlsConfig != nil {
|
||||
conn, err = tls.ClientHandshake(ctx, conn, dialer.tlsConfig)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
common.Close(conn)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch N.NetworkName(network) {
|
||||
case N.NetworkTCP:
|
||||
return wsc.NewClientConn(conn, destination)
|
||||
case N.NetworkUDP:
|
||||
packetConn, err := wsc.NewClientPacketConn(conn, dialer.ruleApplicator)
|
||||
return bufio.NewBindPacketConn(packetConn, destination), err
|
||||
default:
|
||||
return nil, E.Extend(N.ErrUnknownNetwork, network)
|
||||
}
|
||||
}
|
||||
|
||||
func (dialer *wscDialer) ListenPacket(ctx context.Context, destination metadata.Socksaddr) (net.PacketConn, error) {
|
||||
conn, err := dialer.DialContext(ctx, N.NetworkUDP, destination)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conn.(net.PacketConn), nil
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"encoding"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
@ -62,7 +63,7 @@ func (payload *packetConnPayload) MarshalBinaryUnsafe(data []byte) error {
|
||||
}
|
||||
|
||||
if len(data) < hLen+len(payload.payload) {
|
||||
return errors.New("invalid data length to write")
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
|
||||
addr := payload.addrPort.Addr().As16()
|
||||
|
235
transport/wsc/protocol.go
Normal file
235
transport/wsc/protocol.go
Normal file
@ -0,0 +1,235 @@
|
||||
package wsc
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
"github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/metadata"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/network"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
var _ N.EarlyConn = &ClientConn{}
|
||||
|
||||
var _ net.Conn = &ClientPacketConn{}
|
||||
var _ net.PacketConn = &ClientPacketConn{}
|
||||
var _ N.NetPacketReader = &ClientPacketConn{}
|
||||
var _ N.NetPacketWriter = &ClientPacketConn{}
|
||||
|
||||
var _ N.NetPacketReader = &servicePacketConn{}
|
||||
var _ N.NetPacketWriter = &servicePacketConn{}
|
||||
|
||||
type ClientConn struct {
|
||||
N.ExtendedConn
|
||||
destination M.Socksaddr
|
||||
}
|
||||
|
||||
type ClientPacketConn struct {
|
||||
net.Conn
|
||||
ruleApplicator *WSCRuleApplicator
|
||||
writePayload packetConnPayload
|
||||
readPayload packetConnPayload
|
||||
packet [buf.UDPBufferSize]byte
|
||||
}
|
||||
|
||||
type servicePacketConn struct {
|
||||
net.Conn
|
||||
writePayload packetConnPayload
|
||||
readPayload packetConnPayload
|
||||
packet [buf.UDPBufferSize]byte
|
||||
}
|
||||
|
||||
func NewClientConn(conn net.Conn, destination M.Socksaddr) (*ClientConn, error) {
|
||||
return &ClientConn{
|
||||
ExtendedConn: bufio.NewExtendedConn(conn),
|
||||
destination: destination,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewClientPacketConn(conn net.Conn, ruleApplicator *WSCRuleApplicator) (*ClientPacketConn, error) {
|
||||
return &ClientPacketConn{
|
||||
Conn: conn,
|
||||
ruleApplicator: ruleApplicator,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (conn *ClientConn) NeedHandshake() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (conn *ClientConn) FrontHeadroom() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (conn *ClientConn) Upstream() any {
|
||||
return conn.ExtendedConn
|
||||
}
|
||||
|
||||
func (packetConn *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
if buffer == nil {
|
||||
return exceptions.New("buffer is nil")
|
||||
}
|
||||
|
||||
if packetConn.ruleApplicator != nil {
|
||||
ep, _ := packetConn.ruleApplicator.ApplyEndpointReplace(destination.String(), network.NetworkUDP, RuleDirectionOutbound)
|
||||
packetConn.writePayload.addrPort = metadata.ParseSocksaddr(ep).AddrPort()
|
||||
} else {
|
||||
packetConn.writePayload.addrPort = destination.AddrPort()
|
||||
}
|
||||
packetConn.writePayload.payload = buffer.Bytes()
|
||||
payloadBytes, err := packetConn.writePayload.MarshalBinary()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = packetConn.Conn.Write(payloadBytes)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (packetConn *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
if buffer == nil {
|
||||
return destination, exceptions.New("buffer is nil")
|
||||
}
|
||||
|
||||
n, err := packetConn.Conn.Read(packetConn.packet[:])
|
||||
if err != nil {
|
||||
return destination, err
|
||||
}
|
||||
|
||||
if err := packetConn.readPayload.UnmarshalBinaryUnsafe(packetConn.packet[:n]); err != nil {
|
||||
return destination, err
|
||||
}
|
||||
|
||||
if _, err := buffer.Write(packetConn.readPayload.payload); err != nil {
|
||||
return destination, err
|
||||
}
|
||||
|
||||
destination = metadata.SocksaddrFromNetIP(packetConn.readPayload.addrPort)
|
||||
|
||||
if packetConn.ruleApplicator != nil {
|
||||
ep, _ := packetConn.ruleApplicator.ApplyEndpointReplace(destination.String(), N.NetworkUDP, RuleDirectionInbound)
|
||||
destination = metadata.ParseSocksaddr(ep)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (packetConn *ClientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
buffer := buf.With(p)
|
||||
destination, err := packetConn.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n = buffer.Len()
|
||||
if destination.IsFqdn() {
|
||||
addr = destination
|
||||
} else {
|
||||
addr = destination.UDPAddr()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (packetConn *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
return bufio.WritePacket(packetConn, p, addr)
|
||||
}
|
||||
|
||||
func (packetConn *ClientPacketConn) Read(b []byte) (n int, err error) {
|
||||
n, _, err = packetConn.ReadFrom(b)
|
||||
return
|
||||
}
|
||||
|
||||
func (packetConn *ClientPacketConn) Write(b []byte) (n int, err error) {
|
||||
return 0, os.ErrInvalid
|
||||
}
|
||||
|
||||
func (packetConn *ClientPacketConn) NeedHandshake() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (packetConn *ClientPacketConn) FrontHeadroom() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (packetConn *ClientPacketConn) Upstream() any {
|
||||
return packetConn.Conn
|
||||
}
|
||||
|
||||
func (packetConn *servicePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
if buffer == nil {
|
||||
return exceptions.New("buffer is nil")
|
||||
}
|
||||
|
||||
packetConn.writePayload.addrPort = destination.AddrPort()
|
||||
packetConn.writePayload.payload = buffer.Bytes()
|
||||
payloadBytes, err := packetConn.writePayload.MarshalBinary()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = packetConn.Conn.Write(payloadBytes)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (packetConn *servicePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
return bufio.WritePacket(packetConn, p, addr)
|
||||
}
|
||||
|
||||
func (packetConn *servicePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
buffer := buf.With(p)
|
||||
destination, err := packetConn.ReadPacket(buffer)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
n = buffer.Len()
|
||||
if destination.IsFqdn() {
|
||||
addr = destination
|
||||
} else {
|
||||
addr = destination.UDPAddr()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (packetConn *servicePacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
if buffer == nil {
|
||||
return destination, exceptions.New("buffer is nil")
|
||||
}
|
||||
|
||||
n, err := packetConn.Conn.Read(packetConn.packet[:])
|
||||
if err != nil {
|
||||
return destination, err
|
||||
}
|
||||
|
||||
if err := packetConn.readPayload.UnmarshalBinaryUnsafe(packetConn.packet[:n]); err != nil {
|
||||
return destination, err
|
||||
}
|
||||
|
||||
if _, err := buffer.Write(packetConn.readPayload.payload); err != nil {
|
||||
return destination, err
|
||||
}
|
||||
|
||||
destination = metadata.SocksaddrFromNetIP(packetConn.readPayload.addrPort)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (packetConn *servicePacketConn) NeedHandshake() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (packetConn *servicePacketConn) FrontHeadroom() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (packetConn *servicePacketConn) NeedAdditionalReadDeadline() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (packetConn *servicePacketConn) Upstream() any {
|
||||
return packetConn.Conn
|
||||
}
|
@ -1,6 +1,7 @@
|
||||
package wsc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
|
||||
"github.com/sagernet/sing-box/option"
|
||||
@ -8,15 +9,23 @@ import (
|
||||
)
|
||||
|
||||
type RuleAction int
|
||||
type RuleDirection int
|
||||
|
||||
const (
|
||||
RuleActionUnknown RuleAction = iota
|
||||
RuleActionReplace
|
||||
)
|
||||
|
||||
const (
|
||||
RuleDirectionUnknown RuleDirection = iota
|
||||
RuleDirectionInbound
|
||||
RuleDirectionOutbound
|
||||
)
|
||||
|
||||
type WSCRule struct {
|
||||
Action RuleAction
|
||||
Args []interface{}
|
||||
Action RuleAction
|
||||
Direction RuleDirection
|
||||
Args []interface{}
|
||||
}
|
||||
|
||||
type WSCRuleApplicator struct {
|
||||
@ -30,9 +39,17 @@ func NewRuleApplicator(rules []option.WSCRule) (*WSCRuleApplicator, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var direction RuleDirection = RuleDirectionUnknown
|
||||
if len(rule.Direction) > 0 {
|
||||
direction, err = RuleDirectionFromString(rule.Direction)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
wscRules = append(wscRules, WSCRule{
|
||||
Action: action,
|
||||
Args: rule.Args,
|
||||
Action: action,
|
||||
Direction: direction,
|
||||
Args: rule.Args,
|
||||
})
|
||||
}
|
||||
return &WSCRuleApplicator{
|
||||
@ -40,13 +57,14 @@ func NewRuleApplicator(rules []option.WSCRule) (*WSCRuleApplicator, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (ruleManager *WSCRuleApplicator) ApplyEndpointReplace(ep string, netw string) (finalEp string, finalNetw string) {
|
||||
finalEp, finalNetw = ep, netw
|
||||
|
||||
func (ruleManager *WSCRuleApplicator) ApplyEndpointReplace(ep string, netw string, direction RuleDirection) (finalEp string, finalNetw string) {
|
||||
for _, rule := range ruleManager.Rules {
|
||||
if rule.Action != RuleActionReplace {
|
||||
continue
|
||||
}
|
||||
if rule.Direction != RuleDirectionUnknown && direction != RuleDirectionUnknown && rule.Direction != direction {
|
||||
continue
|
||||
}
|
||||
|
||||
sType, ok := rule.Args[0].(string)
|
||||
if !ok {
|
||||
@ -80,7 +98,18 @@ func (ruleManager *WSCRuleApplicator) ApplyEndpointReplace(ep string, netw strin
|
||||
epAddr := metadata.ParseSocksaddr(ep)
|
||||
|
||||
equal := false
|
||||
if (whatAddr.IsFqdn() && epAddr.IsFqdn() && whatAddr.Fqdn == epAddr.Fqdn) || whatAddr.Addr.Compare(epAddr.Addr) == 0 {
|
||||
if whatAddr.IsFqdn() && epAddr.IsFqdn() && whatAddr.Fqdn == epAddr.Fqdn {
|
||||
equal = true
|
||||
} else if whatAddr.IsIPv4() {
|
||||
if epAddr.IsIPv4() || epAddr.Addr.Is4In6() {
|
||||
whatAddr4 := whatAddr.Addr.As4()
|
||||
epAddr4 := epAddr.Addr.As4()
|
||||
equal = bytes.Equal(whatAddr4[:], epAddr4[:])
|
||||
}
|
||||
} else if whatAddr.IsIPv6() && epAddr.IsIPv6() {
|
||||
equal = whatAddr.Addr.Compare(epAddr.Addr) == 0
|
||||
}
|
||||
if equal {
|
||||
if whatAddr.Port == 0 {
|
||||
equal = true
|
||||
} else {
|
||||
@ -117,6 +146,17 @@ func RuleActionFromString(actionStr string) (RuleAction, error) {
|
||||
case "replace":
|
||||
return RuleActionReplace, nil
|
||||
default:
|
||||
return 0, errors.New("rule action doesn't exist")
|
||||
return RuleActionUnknown, errors.New("rule action doesn't exist")
|
||||
}
|
||||
}
|
||||
|
||||
func RuleDirectionFromString(directionStr string) (RuleDirection, error) {
|
||||
switch directionStr {
|
||||
case "inbound":
|
||||
return RuleDirectionInbound, nil
|
||||
case "outbound":
|
||||
return RuleDirectionOutbound, nil
|
||||
default:
|
||||
return RuleDirectionUnknown, errors.New("rule direction doesn't exist")
|
||||
}
|
||||
}
|
||||
|
202
transport/wsc/service.go
Normal file
202
transport/wsc/service.go
Normal file
@ -0,0 +1,202 @@
|
||||
package wsc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/network"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type Handler interface {
|
||||
N.TCPConnectionHandler
|
||||
N.UDPConnectionHandler
|
||||
E.Handler
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
logger logger.ContextLogger
|
||||
router adapter.Router
|
||||
handler Handler
|
||||
authenticator Authenticator
|
||||
userManager *wscUserManager
|
||||
}
|
||||
|
||||
type ServiceConfig struct {
|
||||
Logger logger.ContextLogger
|
||||
Router adapter.Router
|
||||
Handler Handler
|
||||
Authenticator Authenticator
|
||||
MaxConnectionPerUser int
|
||||
UsageReportTrafficInterval int64
|
||||
UsageReportTimeInterval time.Duration
|
||||
}
|
||||
|
||||
type meteredConn struct {
|
||||
net.Conn
|
||||
user *wscUser
|
||||
}
|
||||
|
||||
func NewService(config ServiceConfig) (*Service, error) {
|
||||
if config.Handler == nil {
|
||||
return nil, E.New("Handler required")
|
||||
}
|
||||
if config.Authenticator == nil {
|
||||
return nil, E.New("Authenticator required")
|
||||
}
|
||||
return &Service{
|
||||
logger: config.Logger,
|
||||
router: config.Router,
|
||||
handler: config.Handler,
|
||||
authenticator: config.Authenticator,
|
||||
userManager: &wscUserManager{
|
||||
users: map[int64]*wscUser{},
|
||||
authenticator: config.Authenticator,
|
||||
maxConnPerUser: config.MaxConnectionPerUser,
|
||||
usageReportTrafficInterval: config.UsageReportTrafficInterval,
|
||||
usageReportTimeInterval: config.UsageReportTimeInterval,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (service *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
||||
params, err := service.readQueryParams(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
auth := params.Get("auth")
|
||||
if auth == "" {
|
||||
return E.New("authentication required")
|
||||
}
|
||||
|
||||
account, err := service.authenticator.Authenticate(ctx, AuthenticateParams{
|
||||
Auth: auth,
|
||||
MaxConn: service.userManager.maxConnPerUser,
|
||||
})
|
||||
if err != nil {
|
||||
if account.ID != 0 {
|
||||
if err := service.userManager.cleanupUser(ctx, account.ID, false); err != nil {
|
||||
return err
|
||||
}
|
||||
return E.Cause(err, "authentication failed")
|
||||
}
|
||||
}
|
||||
|
||||
// user cleanup
|
||||
//{}
|
||||
|
||||
user := service.userManager.findOrCreateUser(ctx, account.ID, account.Rate, account.MaxConn)
|
||||
|
||||
netw := params.Get("net")
|
||||
if netw == "" {
|
||||
netw = network.NetworkTCP
|
||||
}
|
||||
|
||||
endpoint := params.Get("ep")
|
||||
addr, err := service.resolveDestination(ctx, M.ParseSocksaddr(endpoint))
|
||||
if err != nil {
|
||||
return E.Cause(err, "failed to parse and resolve endpoint")
|
||||
}
|
||||
|
||||
service.log("New request (Client: ", metadata.Source, ", Auth: ", auth, ", User-ID: ", account.ID, ", ", netw+"-Addr: ", addr.String(), ")")
|
||||
|
||||
metadata.Protocol = C.TypeWSC
|
||||
metadata.Destination = addr
|
||||
|
||||
if popedConn, err := user.addConn(conn); err != nil {
|
||||
return err
|
||||
} else {
|
||||
if popedConn != nil {
|
||||
popedConn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
switch N.NetworkName(netw) {
|
||||
case N.NetworkTCP:
|
||||
err = service.handler.NewConnection(ctx, &meteredConn{Conn: conn, user: user}, metadata)
|
||||
case N.NetworkUDP:
|
||||
return service.handler.NewPacketConnection(ctx, &servicePacketConn{
|
||||
Conn: &meteredConn{
|
||||
Conn: conn,
|
||||
user: user,
|
||||
},
|
||||
}, metadata)
|
||||
default:
|
||||
return E.New("not supported protocol ", netw)
|
||||
}
|
||||
|
||||
if cErr := service.userManager.cleanupUserConn(ctx, user, conn); cErr != nil && err == nil {
|
||||
err = cErr
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (service *Service) readQueryParams(conn net.Conn) (url.Values, error) {
|
||||
var queryParamsRaw [500]byte
|
||||
n, err := conn.Read(queryParamsRaw[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pURL := url.URL{
|
||||
RawQuery: string(queryParamsRaw[:n]),
|
||||
}
|
||||
return pURL.Query(), nil
|
||||
}
|
||||
|
||||
func (service *Service) resolveDestination(ctx context.Context, dest M.Socksaddr) (M.Socksaddr, error) {
|
||||
if dest.IsFqdn() {
|
||||
addrs, err := service.router.LookupDefault(ctx, dest.Fqdn)
|
||||
if err != nil {
|
||||
return M.Socksaddr{}, err
|
||||
}
|
||||
if len(addrs) == 0 {
|
||||
return M.Socksaddr{}, E.New("no address found for endpoint domain: ", dest.Fqdn)
|
||||
}
|
||||
return M.Socksaddr{
|
||||
Addr: addrs[0],
|
||||
Port: dest.Port,
|
||||
}, nil
|
||||
}
|
||||
return dest, nil
|
||||
}
|
||||
|
||||
func (service *Service) log(args ...any) {
|
||||
if service.logger != nil {
|
||||
service.logger.Debug(args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *meteredConn) Read(p []byte) (int, error) {
|
||||
reader, err := conn.user.connReader(conn.Conn)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n, err := reader.Read(p)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
conn.user.usedTrafficBytes.Add(int64(n))
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (conn *meteredConn) Write(p []byte) (int, error) {
|
||||
writer, err := conn.user.connWriter(conn.Conn)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n, err := writer.Write(p)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
conn.user.usedTrafficBytes.Add(int64(n))
|
||||
return n, nil
|
||||
}
|
@ -27,7 +27,6 @@ type wscUser struct {
|
||||
reportedTrafficBytes atomic.Int64
|
||||
lastTrafficUpdateTick atomic.Int64
|
||||
conns map[net.Conn]connData
|
||||
heap []byte
|
||||
rateLimit int64
|
||||
maxConnCount int
|
||||
usedIds []bool
|
||||
@ -37,7 +36,6 @@ func (manager *wscUserManager) newUser(id int64, usedTrafficBytes int64, maxConn
|
||||
user := &wscUser{
|
||||
id: id,
|
||||
conns: make(map[net.Conn]connData, maxConnCount),
|
||||
heap: make([]byte, connReadSize*2*maxConnCount),
|
||||
rateLimit: rateLimit,
|
||||
usedIds: make([]bool, maxConnCount),
|
||||
maxConnCount: maxConnCount,
|
||||
@ -48,40 +46,6 @@ func (manager *wscUserManager) newUser(id int64, usedTrafficBytes int64, maxConn
|
||||
return user
|
||||
}
|
||||
|
||||
func (user *wscUser) outBuffer(conn net.Conn) []byte {
|
||||
user.mu.Lock()
|
||||
defer user.mu.Unlock()
|
||||
|
||||
if user.maxConnCount < 1 {
|
||||
return make([]byte, connReadSize)
|
||||
}
|
||||
|
||||
if d, found := user.conns[conn]; found {
|
||||
bufStart := (connReadSize*2)*(d.id+1) - connReadSize
|
||||
bufEnd := bufStart + connReadSize
|
||||
return user.heap[bufStart:bufEnd]
|
||||
}
|
||||
|
||||
return make([]byte, connReadSize)
|
||||
}
|
||||
|
||||
func (user *wscUser) inBuffer(conn net.Conn) []byte {
|
||||
user.mu.Lock()
|
||||
defer user.mu.Unlock()
|
||||
|
||||
if user.maxConnCount < 1 {
|
||||
return make([]byte, connReadSize)
|
||||
}
|
||||
|
||||
if d, found := user.conns[conn]; found {
|
||||
bufStart := (connReadSize * 2) * d.id
|
||||
bufEnd := bufStart + connReadSize
|
||||
return user.heap[bufStart:bufEnd]
|
||||
}
|
||||
|
||||
return make([]byte, connReadSize)
|
||||
}
|
||||
|
||||
func (user *wscUser) connReader(conn net.Conn) (io.Reader, error) {
|
||||
user.mu.Lock()
|
||||
defer user.mu.Unlock()
|
||||
|
@ -1,19 +1,9 @@
|
||||
package wsc
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/itsabgr/ge"
|
||||
)
|
||||
|
||||
func nowns() int64 {
|
||||
return time.Now().UnixNano()
|
||||
}
|
||||
|
||||
func isTimeoutErr(err error) bool {
|
||||
if nErr, ok := ge.As[net.Error](err); ok && nErr.Timeout() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user