mirror of
https://github.com/SagerNet/sing-box.git
synced 2025-09-09 13:04:06 +08:00
fstable
This commit is contained in:
parent
6b783e6674
commit
8815618648
@ -62,6 +62,7 @@ func NewWSC(ctx context.Context, router adapter.Router, logger log.ContextLogger
|
||||
logger: logger,
|
||||
},
|
||||
Router: router,
|
||||
Dialer: network.SystemDialer,
|
||||
MaxConnectionPerUser: options.MaxConnectionPerUser,
|
||||
UsageReportTrafficInterval: options.UsageTraffic.Traffic,
|
||||
UsageReportTimeInterval: time.Duration(options.UsageTraffic.Time),
|
||||
@ -138,7 +139,7 @@ func (auth *CustomAuthenticator) Authenticate(ctx context.Context, params wsc.Au
|
||||
return wsc.AuthenticateResult{
|
||||
ID: int64(auth.id),
|
||||
Rate: math.MaxInt64,
|
||||
MaxConn: 60,
|
||||
MaxConn: params.MaxConn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -5,6 +5,11 @@ type WSCUsageReport struct {
|
||||
Time Duration `json:"time,omitempty"`
|
||||
}
|
||||
|
||||
type WSCRule struct {
|
||||
Action string `json:"action"`
|
||||
Args []interface{} `json:"args"`
|
||||
}
|
||||
|
||||
type WSCInboundOptions struct {
|
||||
ListenOptions
|
||||
InboundTLSOptionsContainer
|
||||
@ -19,4 +24,5 @@ type WSCOutboundOptions struct {
|
||||
Network NetworkList `json:"network,omitempty"`
|
||||
Auth string `json:"auth"`
|
||||
Path string `json:"path"`
|
||||
Rules []WSCRule `json:"rules,omitempty"`
|
||||
}
|
||||
|
@ -70,6 +70,7 @@ func NewWSC(ctx context.Context, router adapter.Router, logger log.ContextLogger
|
||||
Path: options.Path,
|
||||
TLS: outbound.tlsConfig,
|
||||
Dialer: outbound.dialer,
|
||||
Rules: options.Rules,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -3,7 +3,8 @@ package wsc
|
||||
import "context"
|
||||
|
||||
type AuthenticateParams struct {
|
||||
Auth string
|
||||
Auth string
|
||||
MaxConn int
|
||||
}
|
||||
|
||||
type AuthenticateResult struct {
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/tls"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/ws"
|
||||
@ -16,29 +17,42 @@ import (
|
||||
var _ adapter.WSCClientTransport = &Client{}
|
||||
|
||||
type Client struct {
|
||||
auth string
|
||||
host string
|
||||
path string
|
||||
tls tls.Config
|
||||
dialer N.Dialer
|
||||
auth string
|
||||
host string
|
||||
path string
|
||||
tls tls.Config
|
||||
dialer N.Dialer
|
||||
endpointReplace map[string]string
|
||||
ruleApplicator *WSCRuleApplicator
|
||||
}
|
||||
|
||||
type ClientConfig struct {
|
||||
Auth string
|
||||
Host string
|
||||
Path string
|
||||
TLS tls.Config
|
||||
Dialer N.Dialer
|
||||
Auth string
|
||||
Host string
|
||||
Path string
|
||||
TLS tls.Config
|
||||
Dialer N.Dialer
|
||||
EndpointReplace map[string]string
|
||||
Rules []option.WSCRule
|
||||
}
|
||||
|
||||
func NewClient(params ClientConfig) (*Client, error) {
|
||||
return &Client{
|
||||
auth: params.Auth,
|
||||
host: params.Host,
|
||||
path: params.Path,
|
||||
tls: params.TLS,
|
||||
dialer: params.Dialer,
|
||||
}, nil
|
||||
ruleApplicator, err := NewRuleApplicator(params.Rules)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cli := &Client{
|
||||
auth: params.Auth,
|
||||
host: params.Host,
|
||||
path: params.Path,
|
||||
tls: params.TLS,
|
||||
dialer: params.Dialer,
|
||||
endpointReplace: params.EndpointReplace,
|
||||
ruleApplicator: ruleApplicator,
|
||||
}
|
||||
|
||||
return cli, nil
|
||||
}
|
||||
|
||||
func (cli *Client) DialContext(ctx context.Context, network string, endpoint string) (net.Conn, error) {
|
||||
@ -46,7 +60,7 @@ func (cli *Client) DialContext(ctx context.Context, network string, endpoint str
|
||||
}
|
||||
|
||||
func (cli *Client) ListenPacket(ctx context.Context, network string, endpoint string) (net.PacketConn, error) {
|
||||
return cli.newPacketConn(ctx, network, endpoint)
|
||||
return cli.newPacketConn(ctx, cli.ruleApplicator, network, endpoint)
|
||||
}
|
||||
|
||||
func (cli *Client) Close(ctx context.Context) error {
|
||||
@ -128,6 +142,12 @@ func (cli *Client) newURL(scheme string, path string, endpoint string, network s
|
||||
path = cli.path
|
||||
}
|
||||
|
||||
if with, exists := cli.endpointReplace[endpoint]; exists {
|
||||
endpoint = with
|
||||
}
|
||||
|
||||
endpoint, network = cli.ruleApplicator.ApplyEndpointReplace(endpoint, network)
|
||||
|
||||
pURL := url.URL{
|
||||
Scheme: scheme,
|
||||
Host: cli.host,
|
||||
|
@ -25,21 +25,23 @@ type readerCache struct {
|
||||
|
||||
type clientPacketConn struct {
|
||||
net.Conn
|
||||
reader *wsutil.Reader
|
||||
cache *readerCache
|
||||
mu sync.Mutex
|
||||
reader *wsutil.Reader
|
||||
cache *readerCache
|
||||
mu sync.Mutex
|
||||
ruleApplicator *WSCRuleApplicator
|
||||
}
|
||||
|
||||
func (cli *Client) newPacketConn(ctx context.Context, network string, endpoint string) (*clientPacketConn, error) {
|
||||
func (cli *Client) newPacketConn(ctx context.Context, ruleApplicator *WSCRuleApplicator, network string, endpoint string) (*clientPacketConn, error) {
|
||||
conn, err := cli.newWSConn(ctx, network, endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
reader := wsutil.NewReader(conn, ws.StateClientSide)
|
||||
return &clientPacketConn{
|
||||
Conn: conn,
|
||||
reader: reader,
|
||||
cache: nil,
|
||||
Conn: conn,
|
||||
reader: reader,
|
||||
cache: nil,
|
||||
ruleApplicator: ruleApplicator,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -63,12 +65,13 @@ func (packetConn *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination
|
||||
}
|
||||
|
||||
destination = metadata.SocksaddrFromNetIP(payload.addrPort)
|
||||
ep, _ := packetConn.ruleApplicator.ApplyEndpointReplace(destination.String(), network.NetworkUDP)
|
||||
|
||||
if _, err := buffer.Write(payload.payload); err != nil {
|
||||
return metadata.Socksaddr{}, err
|
||||
}
|
||||
|
||||
return destination, nil
|
||||
return metadata.ParseSocksaddr(ep), nil
|
||||
}
|
||||
|
||||
func (packetConn *clientPacketConn) WritePacket(buffer *buf.Buffer, destination metadata.Socksaddr) error {
|
||||
@ -76,8 +79,10 @@ func (packetConn *clientPacketConn) WritePacket(buffer *buf.Buffer, destination
|
||||
return errors.New("buffer is nil")
|
||||
}
|
||||
|
||||
ep, _ := packetConn.ruleApplicator.ApplyEndpointReplace(destination.String(), network.NetworkUDP)
|
||||
|
||||
payload := packetConnPayload{
|
||||
addrPort: destination.AddrPort(),
|
||||
addrPort: metadata.ParseSocksaddr(ep).AddrPort(),
|
||||
payload: buffer.Bytes(),
|
||||
}
|
||||
payloadBytes, err := payload.MarshalBinary()
|
||||
@ -122,9 +127,12 @@ func (packetConn *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, er
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
ep, _ := packetConn.ruleApplicator.ApplyEndpointReplace(payload.addrPort.String(), network.NetworkUDP)
|
||||
|
||||
packetConn.cache = &readerCache{
|
||||
reader: bytes.NewReader(payload.payload),
|
||||
addr: metadata.SocksaddrFromNetIP(payload.addrPort),
|
||||
// addr: metadata.SocksaddrFromNetIP(payload.addrPort),
|
||||
addr: metadata.ParseSocksaddr(ep),
|
||||
}
|
||||
|
||||
n, err = packetConn.cache.reader.Read(p)
|
||||
@ -137,8 +145,11 @@ func (packetConn *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, er
|
||||
}
|
||||
|
||||
func (packetConn *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
ep, _ := packetConn.ruleApplicator.ApplyEndpointReplace(addr.String(), network.NetworkUDP)
|
||||
|
||||
payload := packetConnPayload{
|
||||
addrPort: metadata.SocksaddrFromNet(addr).AddrPort(),
|
||||
// addrPort: metadata.SocksaddrFromNet(addr).AddrPort(),
|
||||
addrPort: metadata.ParseSocksaddr(ep).AddrPort(),
|
||||
payload: p,
|
||||
}
|
||||
payloadBytes, err := payload.MarshalBinary()
|
||||
|
122
transport/wsc/rule.go
Normal file
122
transport/wsc/rule.go
Normal file
@ -0,0 +1,122 @@
|
||||
package wsc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
type RuleAction int
|
||||
|
||||
const (
|
||||
RuleActionUnknown RuleAction = iota
|
||||
RuleActionReplace
|
||||
)
|
||||
|
||||
type WSCRule struct {
|
||||
Action RuleAction
|
||||
Args []interface{}
|
||||
}
|
||||
|
||||
type WSCRuleApplicator struct {
|
||||
Rules []WSCRule
|
||||
}
|
||||
|
||||
func NewRuleApplicator(rules []option.WSCRule) (*WSCRuleApplicator, error) {
|
||||
wscRules := make([]WSCRule, 0, len(rules))
|
||||
for _, rule := range rules {
|
||||
action, err := RuleActionFromString(rule.Action)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
wscRules = append(wscRules, WSCRule{
|
||||
Action: action,
|
||||
Args: rule.Args,
|
||||
})
|
||||
}
|
||||
return &WSCRuleApplicator{
|
||||
Rules: wscRules,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (ruleManager *WSCRuleApplicator) ApplyEndpointReplace(ep string, netw string) (finalEp string, finalNetw string) {
|
||||
finalEp, finalNetw = ep, netw
|
||||
|
||||
for _, rule := range ruleManager.Rules {
|
||||
if rule.Action != RuleActionReplace {
|
||||
continue
|
||||
}
|
||||
|
||||
sType, ok := rule.Args[0].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
what, ok := rule.Args[1].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
with, ok := rule.Args[2].(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
switch sType {
|
||||
case "endpoint":
|
||||
{
|
||||
var proto string
|
||||
var protoWith string
|
||||
var protoOk bool = false
|
||||
var protoWithOk bool = false
|
||||
if len(rule.Args) > 3 {
|
||||
proto, protoOk = rule.Args[3].(string)
|
||||
}
|
||||
if len(rule.Args) > 4 {
|
||||
protoWith, protoWithOk = rule.Args[4].(string)
|
||||
}
|
||||
|
||||
whatAddr := metadata.ParseSocksaddr(what)
|
||||
withAddr := metadata.ParseSocksaddr(with)
|
||||
epAddr := metadata.ParseSocksaddr(ep)
|
||||
|
||||
equal := false
|
||||
if (whatAddr.IsFqdn() && epAddr.IsFqdn() && whatAddr.Fqdn == epAddr.Fqdn) || whatAddr.Addr.Compare(epAddr.Addr) == 0 {
|
||||
if whatAddr.Port == 0 {
|
||||
equal = true
|
||||
} else {
|
||||
equal = whatAddr.Port == epAddr.Port
|
||||
}
|
||||
}
|
||||
if protoOk {
|
||||
equal = equal && netw == proto
|
||||
}
|
||||
|
||||
if equal {
|
||||
port := withAddr.Port
|
||||
if port == 0 {
|
||||
port = epAddr.Port
|
||||
}
|
||||
ep = (metadata.Socksaddr{
|
||||
Addr: withAddr.Addr,
|
||||
Port: port,
|
||||
Fqdn: withAddr.Fqdn,
|
||||
}).String()
|
||||
if protoWithOk {
|
||||
netw = protoWith
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ep, netw
|
||||
}
|
||||
|
||||
func RuleActionFromString(actionStr string) (RuleAction, error) {
|
||||
switch actionStr {
|
||||
case "replace":
|
||||
return RuleActionReplace, nil
|
||||
default:
|
||||
return 0, errors.New("rule action doesn't exist")
|
||||
}
|
||||
}
|
@ -30,6 +30,7 @@ type Server struct {
|
||||
authenticator Authenticator
|
||||
userManager *wscUserManager
|
||||
router adapter.Router
|
||||
dialer network.Dialer
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
@ -38,6 +39,7 @@ type ServerConfig struct {
|
||||
Handler adapter.WSCServerTransportHandler
|
||||
Authenticator Authenticator
|
||||
Router adapter.Router
|
||||
Dialer network.Dialer
|
||||
MaxConnectionPerUser int
|
||||
UsageReportTrafficInterval int64
|
||||
UsageReportTimeInterval time.Duration
|
||||
@ -55,6 +57,7 @@ func NewServer(config ServerConfig) (*Server, error) {
|
||||
logger: config.Logger,
|
||||
authenticator: config.Authenticator,
|
||||
router: config.Router,
|
||||
dialer: config.Dialer,
|
||||
userManager: &wscUserManager{
|
||||
users: map[int64]*wscUser{},
|
||||
authenticator: config.Authenticator,
|
||||
@ -71,6 +74,7 @@ func NewServer(config ServerConfig) (*Server, error) {
|
||||
return config.Ctx
|
||||
},
|
||||
}
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
@ -79,12 +83,13 @@ func (server *Server) ServeHTTP(res http.ResponseWriter, req *http.Request) {
|
||||
|
||||
auth := req.URL.Query().Get("auth")
|
||||
if auth == "" {
|
||||
server.failRequest(res, req, "Authentication required", http.StatusBadRequest, 0, "", metadata.Socksaddr{})
|
||||
server.failRequest(res, req, "Authentication required", http.StatusBadRequest, 0, "", &metadata.Socksaddr{})
|
||||
return
|
||||
}
|
||||
|
||||
account, err := server.authenticator.Authenticate(ctx, AuthenticateParams{
|
||||
Auth: auth,
|
||||
Auth: auth,
|
||||
MaxConn: server.userManager.maxConnPerUser,
|
||||
})
|
||||
if err != nil {
|
||||
if account.ID != 0 {
|
||||
@ -92,20 +97,20 @@ func (server *Server) ServeHTTP(res http.ResponseWriter, req *http.Request) {
|
||||
server.logger.Debug("Request failed. Couldn't cleanup user: ", err.Error(), " (Client: ", req.RemoteAddr, ", User-ID: ", account.ID, ")")
|
||||
}
|
||||
}
|
||||
server.failRequest(res, req, "Authentication failed: "+err.Error(), http.StatusBadRequest, account.ID, "", metadata.Socksaddr{})
|
||||
server.failRequest(res, req, "Authentication failed: "+err.Error(), http.StatusBadRequest, account.ID, "", &metadata.Socksaddr{})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Method == http.MethodPost && req.URL.Path == "/cleanup" {
|
||||
if err := server.userManager.cleanupUser(ctx, account.ID, true); err != nil {
|
||||
server.failRequest(res, req, "Failed to cleanup user: "+err.Error(), http.StatusInternalServerError, account.ID, "", metadata.Socksaddr{})
|
||||
server.failRequest(res, req, "Failed to cleanup user: "+err.Error(), http.StatusInternalServerError, account.ID, "", &metadata.Socksaddr{})
|
||||
return
|
||||
}
|
||||
res.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
user := server.userManager.findOrCreateUser(ctx, account.ID, account.Rate)
|
||||
user := server.userManager.findOrCreateUser(ctx, account.ID, account.Rate, account.MaxConn)
|
||||
|
||||
netW := req.URL.Query().Get("net")
|
||||
if netW == "" {
|
||||
@ -115,7 +120,7 @@ func (server *Server) ServeHTTP(res http.ResponseWriter, req *http.Request) {
|
||||
endpoint := req.URL.Query().Get("ep")
|
||||
addr, err := server.resolveDestination(ctx, metadata.ParseSocksaddr(endpoint))
|
||||
if err != nil {
|
||||
server.failRequest(res, req, "Failed to parse and resolve endpoint: "+err.Error(), http.StatusBadRequest, account.ID, netW, addr)
|
||||
server.failRequest(res, req, "Failed to parse and resolve endpoint: "+err.Error(), http.StatusBadRequest, account.ID, netW, &addr)
|
||||
return
|
||||
}
|
||||
|
||||
@ -123,7 +128,7 @@ func (server *Server) ServeHTTP(res http.ResponseWriter, req *http.Request) {
|
||||
|
||||
conn, _, _, err := ws.UpgradeHTTP(req, res)
|
||||
if err != nil {
|
||||
server.failRequest(res, req, "Websocket upgrade failed: "+err.Error(), http.StatusBadRequest, account.ID, netW, addr)
|
||||
server.failRequest(res, req, "Websocket upgrade failed: "+err.Error(), http.StatusBadRequest, account.ID, netW, &addr)
|
||||
return
|
||||
}
|
||||
|
||||
@ -136,9 +141,44 @@ func (server *Server) ServeHTTP(res http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
}()
|
||||
|
||||
server.logger.Info("serve http called: ", req.URL.String(), " | ", req.RemoteAddr, " | ", endpoint, " | ", addr)
|
||||
res.Write([]byte("endpoint is : " + endpoint))
|
||||
res.WriteHeader(http.StatusOK)
|
||||
if err := server.pipeConn(ctx, user, conn, netW, &addr); err != nil {
|
||||
server.failRequest(res, req, "Failed to pipe connection: "+err.Error(), http.StatusInternalServerError, account.ID, netW, &addr)
|
||||
}
|
||||
}
|
||||
|
||||
func (server *Server) pipeConn(ctx context.Context, user *wscUser, conn net.Conn, netW string, addr *metadata.Socksaddr) error {
|
||||
if poppedConn, err := user.addConn(conn); err != nil {
|
||||
return err
|
||||
} else {
|
||||
if poppedConn != nil {
|
||||
poppedConn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
switch netW {
|
||||
case network.NetworkTCP:
|
||||
{
|
||||
piper := serverTCPPiper{
|
||||
conn: conn,
|
||||
user: user,
|
||||
addr: addr,
|
||||
dialer: server.dialer,
|
||||
}
|
||||
return piper.pipe(ctx)
|
||||
}
|
||||
case network.NetworkUDP:
|
||||
{
|
||||
piper := serverUDPPiper{
|
||||
conn: conn,
|
||||
user: user,
|
||||
addr: addr,
|
||||
dialer: server.dialer,
|
||||
}
|
||||
return piper.pipe(ctx)
|
||||
}
|
||||
default:
|
||||
return errors.New("network " + netW + " not supported")
|
||||
}
|
||||
}
|
||||
|
||||
func (server *Server) Close() error {
|
||||
@ -174,7 +214,7 @@ func (server *Server) resolveDestination(ctx context.Context, dest metadata.Sock
|
||||
return dest, nil
|
||||
}
|
||||
|
||||
func (server *Server) failRequest(res http.ResponseWriter, request *http.Request, msg string, code int, uid int64, network string, addr metadata.Socksaddr) {
|
||||
func (server *Server) failRequest(res http.ResponseWriter, request *http.Request, msg string, code int, uid int64, network string, addr *metadata.Socksaddr) {
|
||||
http.Error(res, msg, code)
|
||||
|
||||
info := "(Client: " + request.RemoteAddr
|
||||
|
170
transport/wsc/server_tcp_piper.go
Normal file
170
transport/wsc/server_tcp_piper.go
Normal file
@ -0,0 +1,170 @@
|
||||
package wsc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/ws"
|
||||
"github.com/sagernet/ws/wsutil"
|
||||
)
|
||||
|
||||
type serverTCPPiper struct {
|
||||
conn net.Conn
|
||||
user *wscUser
|
||||
addr *metadata.Socksaddr
|
||||
dialer network.Dialer
|
||||
}
|
||||
|
||||
func (piper *serverTCPPiper) pipe(ctx context.Context) error {
|
||||
remote, err := piper.prepare(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer remote.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
var mu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
var gErr error = nil
|
||||
collectErr := func(err error) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
gErr = errors.Join(gErr, err)
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer cancel()
|
||||
defer wg.Done()
|
||||
if err := piper.pipeInbound(ctx, remote); err != nil {
|
||||
collectErr(err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := piper.pipeOutbount(ctx, remote); err != nil {
|
||||
collectErr(err)
|
||||
}
|
||||
cancel()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
return gErr
|
||||
}
|
||||
|
||||
func (piper *serverTCPPiper) pipeInbound(ctx context.Context, remote net.Conn) error {
|
||||
clientInReader, err := piper.user.connReader(piper.conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
clientOut, err := piper.user.connWriter(piper.conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
clientIn := wsutil.NewReader(clientInReader, ws.StateServerSide)
|
||||
buf := piper.user.inBuffer(piper.conn)
|
||||
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
if err := piper.conn.SetReadDeadline(time.Now().Add(time.Millisecond * 300)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
header, err := clientIn.NextFrame()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
if isTimeoutErr(err) {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
pass := false
|
||||
switch header.OpCode {
|
||||
case ws.OpPing:
|
||||
wsutil.WriteServerMessage(clientOut, ws.OpPong, nil)
|
||||
pass = true
|
||||
case ws.OpPong:
|
||||
pass = true
|
||||
case ws.OpClose:
|
||||
wsutil.WriteServerMessage(clientOut, ws.OpClose, nil)
|
||||
return nil
|
||||
}
|
||||
if pass {
|
||||
continue
|
||||
}
|
||||
|
||||
for {
|
||||
n, err := clientIn.Read(buf)
|
||||
if n > 0 {
|
||||
if _, wErr := remote.Write(buf[:n]); wErr != nil {
|
||||
return wErr
|
||||
} else {
|
||||
piper.user.usedTrafficBytes.Add(int64(n))
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (piper *serverTCPPiper) pipeOutbount(ctx context.Context, remote net.Conn) error {
|
||||
clientOut, err := piper.user.connWriter(piper.conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := piper.user.outBuffer(piper.conn)
|
||||
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := remote.SetReadDeadline(time.Now().Add(time.Millisecond * 300)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
n, err := remote.Read(buf)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
if isTimeoutErr(err) {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
piper.user.usedTrafficBytes.Add(int64(n))
|
||||
|
||||
if err := wsutil.WriteServerBinary(clientOut, buf[:n]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (piper *serverTCPPiper) prepare(ctx context.Context) (net.Conn, error) {
|
||||
remote, err := piper.dialer.DialContext(ctx, network.NetworkTCP, *piper.addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return remote, nil
|
||||
}
|
186
transport/wsc/server_udp_piper.go
Normal file
186
transport/wsc/server_udp_piper.go
Normal file
@ -0,0 +1,186 @@
|
||||
package wsc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/ws"
|
||||
"github.com/sagernet/ws/wsutil"
|
||||
)
|
||||
|
||||
type serverUDPPiper struct {
|
||||
conn net.Conn
|
||||
user *wscUser
|
||||
addr *metadata.Socksaddr
|
||||
dialer network.Dialer
|
||||
}
|
||||
|
||||
func (piper *serverUDPPiper) pipe(ctx context.Context) error {
|
||||
remote, err := piper.prepare(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer remote.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
var mu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
var gErr error = nil
|
||||
collectErr := func(err error) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
gErr = errors.Join(gErr, err)
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer cancel()
|
||||
defer wg.Done()
|
||||
if err := piper.pipeInbound(ctx, remote); err != nil {
|
||||
collectErr(err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := piper.pipeOutbount(ctx, remote); err != nil {
|
||||
collectErr(err)
|
||||
}
|
||||
cancel()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
return gErr
|
||||
}
|
||||
|
||||
func (piper *serverUDPPiper) pipeInbound(ctx context.Context, remote net.PacketConn) error {
|
||||
clientInReader, err := piper.user.connReader(piper.conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
clientOut, err := piper.user.connWriter(piper.conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
clientIn := wsutil.NewReader(clientInReader, ws.StateServerSide)
|
||||
buf := piper.user.inBuffer(piper.conn)
|
||||
payload := packetConnPayload{}
|
||||
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
if err := piper.conn.SetReadDeadline(time.Now().Add(time.Millisecond * 300)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
header, err := clientIn.NextFrame()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
if isTimeoutErr(err) {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
pass := false
|
||||
switch header.OpCode {
|
||||
case ws.OpPing:
|
||||
wsutil.WriteServerMessage(clientOut, ws.OpPong, nil)
|
||||
pass = true
|
||||
case ws.OpPong:
|
||||
pass = true
|
||||
case ws.OpClose:
|
||||
wsutil.WriteServerMessage(clientOut, ws.OpClose, nil)
|
||||
return nil
|
||||
}
|
||||
if pass {
|
||||
continue
|
||||
}
|
||||
|
||||
for {
|
||||
n, err := clientIn.Read(buf)
|
||||
if n > 0 {
|
||||
if err := payload.UnmarshalBinaryUnsafe(buf[:n]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, wErr := remote.WriteTo(payload.payload, net.UDPAddrFromAddrPort(payload.addrPort)); wErr != nil {
|
||||
return wErr
|
||||
} else {
|
||||
piper.user.usedTrafficBytes.Add(int64(n))
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (piper *serverUDPPiper) pipeOutbount(ctx context.Context, remote net.PacketConn) error {
|
||||
clientOut, err := piper.user.connWriter(piper.conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := piper.user.outBuffer(piper.conn)
|
||||
payload := packetConnPayload{}
|
||||
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
if err := remote.SetReadDeadline(time.Now().Add(time.Millisecond * 300)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
n, netAddr, err := remote.ReadFrom(buf)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
if isTimeoutErr(err) {
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if ua, ok := netAddr.(*net.UDPAddr); ok {
|
||||
payload.addrPort = ua.AddrPort()
|
||||
} else {
|
||||
return errors.New("unexpected addr type")
|
||||
}
|
||||
payload.payload = buf[:n]
|
||||
payloadBytes, err := payload.MarshalBinary()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
piper.user.usedTrafficBytes.Add(int64(n))
|
||||
|
||||
if err := wsutil.WriteServerBinary(clientOut, payloadBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (piper *serverUDPPiper) prepare(ctx context.Context) (net.PacketConn, error) {
|
||||
remote, err := piper.dialer.ListenPacket(ctx, *piper.addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return remote, nil
|
||||
}
|
@ -154,7 +154,9 @@ func (user *wscUser) removeConn(conn net.Conn) error {
|
||||
user.mu.Lock()
|
||||
defer user.mu.Unlock()
|
||||
if d, exists := user.conns[conn]; exists {
|
||||
user.usedIds[d.id] = false
|
||||
if user.maxConnCount > 0 {
|
||||
user.usedIds[d.id] = false
|
||||
}
|
||||
delete(user.conns, conn)
|
||||
return nil
|
||||
}
|
||||
|
@ -17,14 +17,14 @@ type wscUserManager struct {
|
||||
authenticator Authenticator
|
||||
}
|
||||
|
||||
func (manager *wscUserManager) findOrCreateUser(ctx context.Context, uid int64, rateLimit int64) *wscUser {
|
||||
func (manager *wscUserManager) findOrCreateUser(ctx context.Context, uid int64, rateLimit int64, maxConn int) *wscUser {
|
||||
manager.mu.Lock()
|
||||
defer manager.mu.Unlock()
|
||||
if user, exists := manager.users[uid]; exists {
|
||||
manager.reportUser(ctx, user, false)
|
||||
return user
|
||||
}
|
||||
user := manager.newUser(uid, 0, manager.maxConnPerUser, rateLimit)
|
||||
user := manager.newUser(uid, 0, min(manager.maxConnPerUser, maxConn), rateLimit)
|
||||
manager.users[uid] = user
|
||||
return user
|
||||
}
|
||||
|
@ -1,7 +1,19 @@
|
||||
package wsc
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/itsabgr/ge"
|
||||
)
|
||||
|
||||
func nowns() int64 {
|
||||
return time.Now().UnixNano()
|
||||
}
|
||||
|
||||
func isTimeoutErr(err error) bool {
|
||||
if nErr, ok := ge.As[net.Error](err); ok && nErr.Timeout() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user