diff --git a/common/mux/client.go b/common/mux/client.go index 7c7b9a2b..4be9915d 100644 --- a/common/mux/client.go +++ b/common/mux/client.go @@ -15,40 +15,44 @@ import ( M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/x/list" - - "github.com/hashicorp/yamux" ) var _ N.Dialer = (*Client)(nil) type Client struct { access sync.Mutex - connections list.List[*yamux.Session] + connections list.List[abstractSession] ctx context.Context dialer N.Dialer + protocol Protocol maxConnections int minStreams int maxStreams int } -func NewClient(ctx context.Context, dialer N.Dialer, maxConnections int, minStreams int, maxStreams int) *Client { +func NewClient(ctx context.Context, dialer N.Dialer, protocol Protocol, maxConnections int, minStreams int, maxStreams int) *Client { return &Client{ ctx: ctx, dialer: dialer, + protocol: protocol, maxConnections: maxConnections, minStreams: minStreams, maxStreams: maxStreams, } } -func NewClientWithOptions(ctx context.Context, dialer N.Dialer, options option.MultiplexOptions) N.Dialer { +func NewClientWithOptions(ctx context.Context, dialer N.Dialer, options option.MultiplexOptions) (N.Dialer, error) { if !options.Enabled { - return dialer + return dialer, nil } if options.MaxConnections == 0 && options.MaxStreams == 0 { options.MinStreams = 8 } - return NewClient(ctx, dialer, options.MaxConnections, options.MinStreams, options.MaxStreams) + protocol, err := ParseProtocol(options.Protocol) + if err != nil { + return nil, err + } + return NewClient(ctx, dialer, protocol, options.MaxConnections, options.MinStreams, options.MaxStreams), nil } func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { @@ -80,8 +84,8 @@ func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net func (c *Client) openStream() (net.Conn, error) { var ( - session *yamux.Session - stream *yamux.Stream + session abstractSession + stream net.Conn err error ) for attempts := 0; attempts < 2; attempts++ { @@ -89,7 +93,7 @@ func (c *Client) openStream() (net.Conn, error) { if err != nil { continue } - stream, err = session.OpenStream() + stream, err = session.Open() if err != nil { continue } @@ -101,11 +105,11 @@ func (c *Client) openStream() (net.Conn, error) { return &wrapStream{stream}, nil } -func (c *Client) offer() (*yamux.Session, error) { +func (c *Client) offer() (abstractSession, error) { c.access.Lock() defer c.access.Unlock() - sessions := make([]*yamux.Session, 0, c.maxConnections) + sessions := make([]abstractSession, 0, c.maxConnections) for element := c.connections.Front(); element != nil; { if element.Value.IsClosed() { nextElement := element.Next() @@ -120,10 +124,7 @@ func (c *Client) offer() (*yamux.Session, error) { if sLen == 0 { return c.offerNew() } - // session := common.MinBy(sessions, yamux.Session.NumStreams) - session := common.MinBy(sessions, func(it *yamux.Session) int { - return it.NumStreams() - }) + session := common.MinBy(sessions, abstractSession.NumStreams) numStreams := session.NumStreams() if numStreams == 0 { return session, nil @@ -140,12 +141,12 @@ func (c *Client) offer() (*yamux.Session, error) { return c.offerNew() } -func (c *Client) offerNew() (*yamux.Session, error) { +func (c *Client) offerNew() (abstractSession, error) { conn, err := c.dialer.DialContext(c.ctx, N.NetworkTCP, Destination) if err != nil { return nil, err } - session, err := yamux.Client(conn, newMuxConfig()) + session, err := c.protocol.newClient(&protocolConn{Conn: conn, protocol: c.protocol}) if err != nil { return nil, err } @@ -170,7 +171,7 @@ type ClientConn struct { } func (c *ClientConn) readResponse() error { - response, err := ReadResponse(c.Conn) + response, err := ReadStreamResponse(c.Conn) if err != nil { return err } @@ -195,7 +196,7 @@ func (c *ClientConn) Write(b []byte) (n int, err error) { if c.requestWrite { return c.Conn.Write(b) } - request := Request{ + request := StreamRequest{ Network: N.NetworkTCP, Destination: c.destination, } @@ -203,7 +204,7 @@ func (c *ClientConn) Write(b []byte) (n int, err error) { defer common.KeepAlive(_buffer) buffer := common.Dup(_buffer) defer buffer.Release() - EncodeRequest(request, buffer) + EncodeStreamRequest(request, buffer) buffer.Write(b) _, err = c.Conn.Write(buffer.Bytes()) if err != nil { @@ -255,7 +256,7 @@ type ClientPacketConn struct { } func (c *ClientPacketConn) readResponse() error { - response, err := ReadResponse(c.ExtendedConn) + response, err := ReadStreamResponse(c.ExtendedConn) if err != nil { return err } @@ -285,7 +286,7 @@ func (c *ClientPacketConn) Read(b []byte) (n int, err error) { } func (c *ClientPacketConn) writeRequest(payload []byte) (n int, err error) { - request := Request{ + request := StreamRequest{ Network: N.NetworkUDP, Destination: c.destination, } @@ -297,7 +298,7 @@ func (c *ClientPacketConn) writeRequest(payload []byte) (n int, err error) { defer common.KeepAlive(_buffer) buffer := common.Dup(_buffer) defer buffer.Release() - EncodeRequest(request, buffer) + EncodeStreamRequest(request, buffer) if len(payload) > 0 { common.Must( binary.Write(buffer, binary.BigEndian, uint16(len(payload))), @@ -363,7 +364,7 @@ type ClientPacketAddrConn struct { } func (c *ClientPacketAddrConn) readResponse() error { - response, err := ReadResponse(c.ExtendedConn) + response, err := ReadStreamResponse(c.ExtendedConn) if err != nil { return err } @@ -399,7 +400,7 @@ func (c *ClientPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err err } func (c *ClientPacketAddrConn) writeRequest(payload []byte, destination M.Socksaddr) (n int, err error) { - request := Request{ + request := StreamRequest{ Network: N.NetworkUDP, Destination: c.destination, PacketAddr: true, @@ -412,7 +413,7 @@ func (c *ClientPacketAddrConn) writeRequest(payload []byte, destination M.Socksa defer common.KeepAlive(_buffer) buffer := common.Dup(_buffer) defer buffer.Release() - EncodeRequest(request, buffer) + EncodeStreamRequest(request, buffer) if len(payload) > 0 { common.Must( M.SocksaddrSerializer.WriteAddrPort(buffer, destination), diff --git a/common/mux/protocol.go b/common/mux/protocol.go index 97f6de31..8ffb5a8c 100644 --- a/common/mux/protocol.go +++ b/common/mux/protocol.go @@ -14,6 +14,7 @@ import ( "github.com/sagernet/sing/common/rw" "github.com/hashicorp/yamux" + "github.com/xtaci/smux" ) var Destination = M.Socksaddr{ @@ -21,7 +22,55 @@ var Destination = M.Socksaddr{ Port: 444, } -func newMuxConfig() *yamux.Config { +const ( + ProtocolYAMux Protocol = 0 + ProtocolSMux Protocol = 1 +) + +type Protocol byte + +func ParseProtocol(name string) (Protocol, error) { + switch name { + case "", "yamux": + return ProtocolYAMux, nil + case "smux": + return ProtocolSMux, nil + default: + return ProtocolYAMux, E.New("unknown multiplex protocol: ", name) + } +} + +func (p Protocol) newServer(conn net.Conn) (abstractSession, error) { + switch p { + case ProtocolYAMux: + return yamux.Server(conn, yaMuxConfig()) + case ProtocolSMux: + session, err := smux.Server(conn, nil) + if err != nil { + return nil, err + } + return &smuxSession{session}, nil + default: + panic("unknown protocol") + } +} + +func (p Protocol) newClient(conn net.Conn) (abstractSession, error) { + switch p { + case ProtocolYAMux: + return yamux.Client(conn, yaMuxConfig()) + case ProtocolSMux: + session, err := smux.Client(conn, nil) + if err != nil { + return nil, err + } + return &smuxSession{session}, nil + default: + panic("unknown protocol") + } +} + +func yaMuxConfig() *yamux.Config { config := yamux.DefaultConfig() config.LogOutput = io.Discard config.StreamCloseTimeout = C.TCPTimeout @@ -29,18 +78,23 @@ func newMuxConfig() *yamux.Config { return config } +func (p Protocol) String() string { + switch p { + case ProtocolYAMux: + return "yamux" + case ProtocolSMux: + return "smux" + default: + return "unknown" + } +} + const ( - version0 = 0 - flagUDP = 1 - flagAddr = 2 - statusSuccess = 0 - statusError = 1 + version0 = 0 ) type Request struct { - Network string - Destination M.Socksaddr - PacketAddr bool + Protocol Protocol } func ReadRequest(reader io.Reader) (*Request, error) { @@ -51,8 +105,37 @@ func ReadRequest(reader io.Reader) (*Request, error) { if version != version0 { return nil, E.New("unsupported version: ", version) } + protocol, err := rw.ReadByte(reader) + if err != nil { + return nil, err + } + if protocol > byte(ProtocolSMux) { + return nil, E.New("unsupported protocol: ", protocol) + } + return &Request{Protocol: Protocol(protocol)}, nil +} + +func EncodeRequest(buffer *buf.Buffer, request Request) { + buffer.WriteByte(version0) + buffer.WriteByte(byte(request.Protocol)) +} + +const ( + flagUDP = 1 + flagAddr = 2 + statusSuccess = 0 + statusError = 1 +) + +type StreamRequest struct { + Network string + Destination M.Socksaddr + PacketAddr bool +} + +func ReadStreamRequest(reader io.Reader) (*StreamRequest, error) { var flags uint16 - err = binary.Read(reader, binary.BigEndian, &flags) + err := binary.Read(reader, binary.BigEndian, &flags) if err != nil { return nil, err } @@ -68,10 +151,10 @@ func ReadRequest(reader io.Reader) (*Request, error) { network = N.NetworkUDP udpAddr = flags&flagAddr != 0 } - return &Request{network, destination, udpAddr}, nil + return &StreamRequest{network, destination, udpAddr}, nil } -func requestLen(request Request) int { +func requestLen(request StreamRequest) int { var rLen int rLen += 1 // version rLen += 2 // flags @@ -79,7 +162,7 @@ func requestLen(request Request) int { return rLen } -func EncodeRequest(request Request, buffer *buf.Buffer) { +func EncodeStreamRequest(request StreamRequest, buffer *buf.Buffer) { destination := request.Destination var flags uint16 if request.Network == N.NetworkUDP { @@ -92,19 +175,18 @@ func EncodeRequest(request Request, buffer *buf.Buffer) { } } common.Must( - buffer.WriteByte(version0), binary.Write(buffer, binary.BigEndian, flags), M.SocksaddrSerializer.WriteAddrPort(buffer, destination), ) } -type Response struct { +type StreamResponse struct { Status uint8 Message string } -func ReadResponse(reader io.Reader) (*Response, error) { - var response Response +func ReadStreamResponse(reader io.Reader) (*StreamResponse, error) { + var response StreamResponse status, err := rw.ReadByte(reader) if err != nil { return nil, err diff --git a/common/mux/service.go b/common/mux/service.go index bed5a5ca..acb8ab9f 100644 --- a/common/mux/service.go +++ b/common/mux/service.go @@ -14,12 +14,14 @@ import ( M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/rw" - - "github.com/hashicorp/yamux" ) func NewConnection(ctx context.Context, router adapter.Router, errorHandler E.Handler, logger log.ContextLogger, conn net.Conn, metadata adapter.InboundContext) error { - session, err := yamux.Server(conn, newMuxConfig()) + request, err := ReadRequest(conn) + if err != nil { + return err + } + session, err := request.Protocol.newServer(conn) if err != nil { return err } @@ -34,7 +36,7 @@ func NewConnection(ctx context.Context, router adapter.Router, errorHandler E.Ha func newConnection(ctx context.Context, router adapter.Router, errorHandler E.Handler, logger log.ContextLogger, stream net.Conn, metadata adapter.InboundContext) { stream = &wrapStream{stream} - request, err := ReadRequest(stream) + request, err := ReadStreamRequest(stream) if err != nil { logger.ErrorContext(ctx, err) return diff --git a/common/mux/session.go b/common/mux/session.go new file mode 100644 index 00000000..da1b8b65 --- /dev/null +++ b/common/mux/session.go @@ -0,0 +1,71 @@ +package mux + +import ( + "io" + "net" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + + "github.com/xtaci/smux" +) + +type abstractSession interface { + Open() (net.Conn, error) + Accept() (net.Conn, error) + NumStreams() int + Close() error + IsClosed() bool +} + +var _ abstractSession = (*smuxSession)(nil) + +type smuxSession struct { + *smux.Session +} + +func (s *smuxSession) Open() (net.Conn, error) { + return s.OpenStream() +} + +func (s *smuxSession) Accept() (net.Conn, error) { + return s.AcceptStream() +} + +type protocolConn struct { + net.Conn + protocol Protocol + protocolWritten bool +} + +func (c *protocolConn) Write(p []byte) (n int, err error) { + if c.protocolWritten { + return c.Conn.Write(p) + } + _buffer := buf.StackNewSize(2 + len(p)) + defer common.KeepAlive(_buffer) + buffer := common.Dup(_buffer) + defer buffer.Release() + EncodeRequest(buffer, Request{ + Protocol: c.protocol, + }) + common.Must(common.Error(buffer.Write(p))) + n, err = c.Conn.Write(buffer.Bytes()) + if err == nil { + n-- + } + c.protocolWritten = true + return n, err +} + +func (c *protocolConn) ReadFrom(r io.Reader) (n int64, err error) { + if !c.protocolWritten { + return bufio.ReadFrom0(c, r) + } + return bufio.Copy(c.Conn, r) +} + +func (c *protocolConn) Upstream() any { + return c.Conn +} diff --git a/docs/configuration/shared/multiplex.md b/docs/configuration/shared/multiplex.md index 534827b0..73b6bdd0 100644 --- a/docs/configuration/shared/multiplex.md +++ b/docs/configuration/shared/multiplex.md @@ -7,6 +7,7 @@ ```json { "enabled": true, + "protocol": "yamux", "max_connections": 4, "min_streams": 4, "max_streams": 0 @@ -19,6 +20,17 @@ Enable multiplex. +#### protocol + +Multiplex protocol. + +| Protocol | Description | +|----------|------------------------------------| +| yamux | https://github.com/hashicorp/yamux | +| smux | https://github.com/xtaci/smux | + +YAMux is used by default. + #### max_connections Maximum connections. diff --git a/go.mod b/go.mod index dd4d29d0..0b683761 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/sagernet/sing-vmess v0.0.0-20220802053753-a38d3b22e6b9 github.com/spf13/cobra v1.5.0 github.com/stretchr/testify v1.8.0 + github.com/xtaci/smux v1.5.16 go.uber.org/atomic v1.9.0 golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa golang.org/x/net v0.0.0-20220802222814-0bcc04d9c69b diff --git a/go.sum b/go.sum index 4c494839..7ebd437b 100644 --- a/go.sum +++ b/go.sum @@ -201,6 +201,8 @@ github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49u github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 h1:gga7acRE695APm9hlsSMoOoE65U4/TcqNj90mc69Rlg= github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= +github.com/xtaci/smux v1.5.16 h1:FBPYOkW8ZTjLKUM4LI4xnnuuDC8CQ/dB04HD519WoEk= +github.com/xtaci/smux v1.5.16/go.mod h1:OMlQbT5vcgl2gb49mFkYo6SMf+zP3rcjcwQz7ZU7IGY= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= diff --git a/option/outbound.go b/option/outbound.go index 843f4be3..ae88e6d7 100644 --- a/option/outbound.go +++ b/option/outbound.go @@ -100,8 +100,9 @@ func (o ServerOptions) Build() M.Socksaddr { } type MultiplexOptions struct { - Enabled bool `json:"enabled,omitempty"` - MaxConnections int `json:"max_connections,omitempty"` - MinStreams int `json:"min_streams,omitempty"` - MaxStreams int `json:"max_streams,omitempty"` + Enabled bool `json:"enabled,omitempty"` + Protocol string `json:"protocol,omitempty"` + MaxConnections int `json:"max_connections,omitempty"` + MinStreams int `json:"min_streams,omitempty"` + MaxStreams int `json:"max_streams,omitempty"` } diff --git a/outbound/shadowsocks.go b/outbound/shadowsocks.go index 6557e4c9..3fa4fd33 100644 --- a/outbound/shadowsocks.go +++ b/outbound/shadowsocks.go @@ -46,7 +46,10 @@ func NewShadowsocks(ctx context.Context, router adapter.Router, logger log.Conte method: method, serverAddr: options.ServerOptions.Build(), } - outbound.multiplexDialer = mux.NewClientWithOptions(ctx, (*shadowsocksDialer)(outbound), common.PtrValueOrDefault(options.Multiplex)) + outbound.multiplexDialer, err = mux.NewClientWithOptions(ctx, (*shadowsocksDialer)(outbound), common.PtrValueOrDefault(options.Multiplex)) + if err != nil { + return nil, err + } return outbound, nil } diff --git a/route/router_dns.go b/route/router_dns.go index 762d915d..8d8c8b44 100644 --- a/route/router_dns.go +++ b/route/router_dns.go @@ -32,13 +32,13 @@ func (r *Router) Exchange(ctx context.Context, message *dnsmessage.Message) (*dn } func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) { - r.dnsLogger.Debug(ctx, "lookup domain ", domain) + r.dnsLogger.DebugContext(ctx, "lookup domain ", domain) ctx, transport := r.matchDNS(ctx) ctx, cancel := context.WithTimeout(ctx, C.DNSTimeout) defer cancel() addrs, err := r.dnsClient.Lookup(ctx, transport, domain, strategy) if len(addrs) > 0 { - r.logger.InfoContext(ctx, "lookup succeed for ", domain, ": ", F.MapToString(addrs)) + r.logger.InfoContext(ctx, "lookup succeed for ", domain, ": ", strings.Join(F.MapToString(addrs), " ")) } else { r.logger.ErrorContext(ctx, E.Cause(err, "lookup failed for ", domain)) } diff --git a/test/go.mod b/test/go.mod index 7a9de1b6..0015e00c 100644 --- a/test/go.mod +++ b/test/go.mod @@ -58,6 +58,7 @@ require ( github.com/sagernet/sing-vmess v0.0.0-20220802053753-a38d3b22e6b9 // indirect github.com/sirupsen/logrus v1.8.1 // indirect github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect + github.com/xtaci/smux v1.5.16 // indirect go.uber.org/atomic v1.9.0 // indirect golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect golang.org/x/mod v0.5.1 // indirect diff --git a/test/go.sum b/test/go.sum index 0d8dec8a..85fb2669 100644 --- a/test/go.sum +++ b/test/go.sum @@ -228,6 +228,8 @@ github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49u github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 h1:gga7acRE695APm9hlsSMoOoE65U4/TcqNj90mc69Rlg= github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= +github.com/xtaci/smux v1.5.16 h1:FBPYOkW8ZTjLKUM4LI4xnnuuDC8CQ/dB04HD519WoEk= +github.com/xtaci/smux v1.5.16/go.mod h1:OMlQbT5vcgl2gb49mFkYo6SMf+zP3rcjcwQz7ZU7IGY= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= diff --git a/test/mux_test.go b/test/mux_test.go index c16c29ba..6d03e15d 100644 --- a/test/mux_test.go +++ b/test/mux_test.go @@ -4,12 +4,24 @@ import ( "net/netip" "testing" + "github.com/sagernet/sing-box/common/mux" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-shadowsocks/shadowaead_2022" ) func TestShadowsocksMux(t *testing.T) { + for _, protocol := range []mux.Protocol{ + mux.ProtocolYAMux, + mux.ProtocolSMux, + } { + t.Run(protocol.String(), func(t *testing.T) { + testShadowsocksMux(t, protocol.String()) + }) + } +} + +func testShadowsocksMux(t *testing.T, protocol string) { method := shadowaead_2022.List[0] password := mkBase64(t, 16) startInstance(t, option.Options{ @@ -54,7 +66,8 @@ func TestShadowsocksMux(t *testing.T) { Method: method, Password: password, Multiplex: &option.MultiplexOptions{ - Enabled: true, + Enabled: true, + Protocol: protocol, }, }, },