// Package mux contains the handshake request and per-stream framing used to
// multiplex connections over smux, yamux or h2mux sessions.
package mux

import (
	"encoding/binary"
	"errors"
	"io"
	"math/rand"
	"net"

	"github.com/hashicorp/yamux"
	C "github.com/sagernet/sing-box/constant"
	"github.com/sagernet/sing/common"
	"github.com/sagernet/sing/common/buf"
	E "github.com/sagernet/sing/common/exceptions"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
	"github.com/sagernet/sing/common/rw"
	"github.com/sagernet/smux"
)

// Destination is the placeholder address that EncodeStreamRequest writes when
// a packet-address request does not carry a valid destination of its own.
var Destination = M.Socksaddr{
	Fqdn: "sp.mux.sing-box.arpa",
	Port: 444,
}

// Protocol identifies the multiplexing implementation used for a session. Its
// numeric value is what is written on the wire by EncodeRequest.
type Protocol byte

// Supported multiplexing protocols.
const (
	ProtocolSMux Protocol = iota
	ProtocolYAMux
	ProtocolH2Mux
)

// ParseProtocol maps a configuration name to a Protocol. The empty string
// selects smux. An unrecognised name yields an error (together with
// ProtocolSMux as the placeholder value).
func ParseProtocol(name string) (Protocol, error) {
	switch name {
	case "", "smux":
		return ProtocolSMux, nil
	case "yamux":
		return ProtocolYAMux, nil
	case "h2mux":
		return ProtocolH2Mux, nil
	default:
		return ProtocolSMux, E.New("unknown multiplex protocol: ", name)
	}
}

// newServer creates the server side of a multiplexed session over conn.
//
// An unrecognised protocol value is reported as an error rather than a panic,
// because the value may originate from a byte supplied by the remote peer.
func (p Protocol) newServer(conn net.Conn) (abstractSession, error) {
	switch p {
	case ProtocolSMux:
		session, err := smux.Server(conn, smuxConfig())
		if err != nil {
			return nil, err
		}
		return &smuxSession{session}, nil
	case ProtocolYAMux:
		session, err := yamux.Server(conn, yaMuxConfig())
		if err != nil {
			return nil, err
		}
		return &yamuxSession{session}, nil
	case ProtocolH2Mux:
		return NewH2MuxServer(conn), nil
	default:
		return nil, E.New("unknown multiplex protocol: ", byte(p))
	}
}

// newClient creates the client side of a multiplexed session over conn.
//
// An unrecognised protocol value is reported as an error rather than a panic.
func (p Protocol) newClient(conn net.Conn) (abstractSession, error) {
	switch p {
	case ProtocolSMux:
		session, err := smux.Client(conn, smuxConfig())
		if err != nil {
			return nil, err
		}
		return &smuxSession{session}, nil
	case ProtocolYAMux:
		session, err := yamux.Client(conn, yaMuxConfig())
		if err != nil {
			return nil, err
		}
		return &yamuxSession{session}, nil
	case ProtocolH2Mux:
		return NewH2MuxClient(conn)
	default:
		return nil, E.New("unknown multiplex protocol: ", byte(p))
	}
}

// smuxConfig returns the smux default configuration with keep-alive disabled.
func smuxConfig() *smux.Config {
	config := smux.DefaultConfig()
	config.KeepAliveDisabled = true
	return config
}

// yaMuxConfig returns the yamux default configuration with logging discarded
// and the stream open/close timeouts set to C.TCPTimeout.
func yaMuxConfig() *yamux.Config {
	config := yamux.DefaultConfig()
	config.LogOutput = io.Discard
	config.StreamCloseTimeout = C.TCPTimeout
	config.StreamOpenTimeout = C.TCPTimeout
	return config
}

// String returns the configuration name of the protocol, or "unknown" for a
// value that is not one of the declared constants.
func (p Protocol) String() string {
	switch p {
	case ProtocolSMux:
		return "smux"
	case ProtocolYAMux:
		return "yamux"
	case ProtocolH2Mux:
		return "h2mux"
	default:
		return "unknown"
	}
}

// Handshake request versions. Version1 adds the padding flag and the optional
// padding block to the request header.
const (
	Version0 = iota
	Version1
)

// Request is the handshake header exchanged before a multiplexed session is
// established.
type Request struct {
	Version        byte
	Protocol       Protocol
	PaddingEnabled bool
}

// ReadRequest decodes a handshake Request from reader.
//
// Layout: one version byte, one protocol byte and, for Version1 only, a
// one-byte padding flag. When the flag is set, a big-endian uint16 padding
// length follows and that many bytes are skipped.
//
// It returns an error if reading fails, if the version is greater than
// Version1, or if the protocol byte does not name a known Protocol.
func ReadRequest(reader io.Reader) (*Request, error) {
	version, err := rw.ReadByte(reader)
	if err != nil {
		return nil, err
	}
	// version is unsigned, so only the upper bound needs checking.
	if version > Version1 {
		return nil, E.New("unsupported version: ", version)
	}
	protocol, err := rw.ReadByte(reader)
	if err != nil {
		return nil, err
	}
	// Reject unknown protocols here so that a value chosen by the peer never
	// reaches the session constructors.
	if protocol > byte(ProtocolH2Mux) {
		return nil, E.New("unsupported protocol: ", protocol)
	}
	var paddingEnabled bool
	if version == Version1 {
		err = binary.Read(reader, binary.BigEndian, &paddingEnabled)
		if err != nil {
			return nil, err
		}
		if paddingEnabled {
			var paddingLen uint16
			err = binary.Read(reader, binary.BigEndian, &paddingLen)
			if err != nil {
				return nil, err
			}
			err = rw.SkipN(reader, int(paddingLen))
			if err != nil {
				return nil, err
			}
		}
	}
	return &Request{
		Version:        version,
		Protocol:       Protocol(protocol),
		PaddingEnabled: paddingEnabled,
	}, nil
}

// EncodeRequest serialises request followed by payload into a newly allocated
// buffer sized to fit both.
//
// For Version1 requests with padding enabled, a padding block of 256 to 767
// bytes (length chosen with math/rand) is reserved after the header.
// NOTE(review): the padding bytes are reserved with Extend and not written
// here; whether they are zeroed depends on buf.Buffer — confirm.
func EncodeRequest(request Request, payload []byte) *buf.Buffer {
	requestLen := 2 // version + protocol
	var paddingLen uint16
	if request.Version == Version1 {
		requestLen++ // padding flag
		if request.PaddingEnabled {
			requestLen += 2 // padding length field
			paddingLen = uint16(256 + rand.Intn(512))
			requestLen += int(paddingLen)
		}
	}
	buffer := buf.NewSize(requestLen + len(payload))
	common.Must(
		buffer.WriteByte(request.Version),
		buffer.WriteByte(byte(request.Protocol)),
	)
	if request.Version == Version1 {
		common.Must(binary.Write(buffer, binary.BigEndian, request.PaddingEnabled))
		if request.PaddingEnabled {
			common.Must(binary.Write(buffer, binary.BigEndian, paddingLen))
			buffer.Extend(int(paddingLen))
		}
	}
	common.Must1(buffer.Write(payload))
	return buffer
}

// Stream request flag bits and stream response status codes.
const (
	flagUDP       = 1
	flagAddr      = 2
	statusSuccess = 0
	statusError   = 1
)

// StreamRequest describes a single stream opened inside a multiplexed session.
type StreamRequest struct {
	Network     string
	Destination M.Socksaddr
	PacketAddr  bool
}

// ReadStreamRequest decodes a StreamRequest from reader: a big-endian uint16
// flag word followed by a socks-style address and port.
//
// Network is UDP when flagUDP is set and TCP otherwise. PacketAddr is only
// taken from flagAddr for UDP requests.
func ReadStreamRequest(reader io.Reader) (*StreamRequest, error) {
	var flags uint16
	err := binary.Read(reader, binary.BigEndian, &flags)
	if err != nil {
		return nil, err
	}
	destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
	if err != nil {
		return nil, err
	}
	var (
		network string
		udpAddr bool
	)
	if flags&flagUDP == 0 {
		network = N.NetworkTCP
	} else {
		network = N.NetworkUDP
		udpAddr = flags&flagAddr != 0
	}
	return &StreamRequest{
		Network:     network,
		Destination: destination,
		PacketAddr:  udpAddr,
	}, nil
}

// resolveStreamDestination returns the address that is actually written for
// request: the placeholder Destination when a packet-address request has no
// valid destination, and request.Destination otherwise.
func resolveStreamDestination(request StreamRequest) M.Socksaddr {
	if request.PacketAddr && !request.Destination.IsValid() {
		return Destination
	}
	return request.Destination
}

// streamRequestLen returns the number of bytes to reserve for encoding
// request. The address length is computed from the address that
// EncodeStreamRequest will really write, so the placeholder substitution
// cannot overflow the reserved space.
func streamRequestLen(request StreamRequest) int {
	var rLen int
	// One extra byte, labelled "version" in the original code, is kept so
	// buffer sizing is never smaller than before; EncodeStreamRequest in this
	// file does not write a version byte.
	rLen += 1 // version
	rLen += 2 // flags
	rLen += M.SocksaddrSerializer.AddrPortLen(resolveStreamDestination(request))
	return rLen
}

// EncodeStreamRequest appends request to buffer as a big-endian uint16 flag
// word followed by the destination address and port. It panics (via
// common.Must) if a write to buffer fails.
func EncodeStreamRequest(request StreamRequest, buffer *buf.Buffer) {
	var flags uint16
	if request.Network == N.NetworkUDP {
		flags |= flagUDP
	}
	if request.PacketAddr {
		flags |= flagAddr
	}
	destination := resolveStreamDestination(request)
	common.Must(
		binary.Write(buffer, binary.BigEndian, flags),
		M.SocksaddrSerializer.WriteAddrPort(buffer, destination),
	)
}

// StreamResponse is the reply to a StreamRequest.
type StreamResponse struct {
	Status  uint8
	Message string
}

// ReadStreamResponse decodes a StreamResponse from reader: one status byte
// and, only when the status equals statusError, a variable-length message
// string.
func ReadStreamResponse(reader io.Reader) (*StreamResponse, error) {
	status, err := rw.ReadByte(reader)
	if err != nil {
		return nil, err
	}
	response := StreamResponse{Status: status}
	if status == statusError {
		response.Message, err = rw.ReadVString(reader)
		if err != nil {
			return nil, err
		}
	}
	return &response, nil
}

// wrapStream wraps a net.Conn and passes its Read and Write errors through
// wrapError.
type wrapStream struct {
	net.Conn
}

// Read reads from the underlying connection, translating its error with
// wrapError.
func (w *wrapStream) Read(p []byte) (n int, err error) {
	n, err = w.Conn.Read(p)
	return n, wrapError(err)
}

// Write writes to the underlying connection, translating its error with
// wrapError.
func (w *wrapStream) Write(p []byte) (n int, err error) {
	n, err = w.Conn.Write(p)
	return n, wrapError(err)
}

// WriteIsThreadUnsafe is an empty marker method.
// NOTE(review): presumably detected via interface assertion elsewhere to mark
// Write as unsafe for concurrent use — confirm against the consumer.
func (w *wrapStream) WriteIsThreadUnsafe() {
}

// Upstream returns the wrapped connection.
func (w *wrapStream) Upstream() any {
	return w.Conn
}

// wrapError converts yamux.ErrStreamClosed (including when wrapped) into
// io.EOF and returns every other error, including nil, unchanged.
func wrapError(err error) error {
	if errors.Is(err, yamux.ErrStreamClosed) {
		return io.EOF
	}
	return err
}