From b992d942c4d23929e632b62faa323b4cfb2163dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 18 Aug 2022 19:54:05 +0800 Subject: [PATCH] Add hysteria tcp server --- inbound/builder.go | 2 + inbound/hysteria.go | 256 +++++++++++++++++++++++++++++++++ inbound/hysteria_stub.go | 16 +++ inbound/naive.go | 4 +- option/hysteria.go | 47 +++--- option/inbound.go | 5 + outbound/hysteria.go | 35 ++--- test/hysteria_test.go | 75 ++++++++++ transport/hysteria/client.go | 20 ++- transport/hysteria/protocol.go | 138 ++++++++++++++++-- transport/hysteria/server.go | 68 +++++++++ 11 files changed, 608 insertions(+), 58 deletions(-) create mode 100644 inbound/hysteria.go create mode 100644 inbound/hysteria_stub.go create mode 100644 transport/hysteria/server.go diff --git a/inbound/builder.go b/inbound/builder.go index ee51d67e..ae5b1a6e 100644 --- a/inbound/builder.go +++ b/inbound/builder.go @@ -37,6 +37,8 @@ func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, o return NewTrojan(ctx, router, logger, options.Tag, options.TrojanOptions) case C.TypeNaive: return NewNaive(ctx, router, logger, options.Tag, options.NaiveOptions) + case C.TypeHysteria: + return NewHysteria(ctx, router, logger, options.Tag, options.HysteriaOptions) default: return nil, E.New("unknown inbound type: ", options.Type) } diff --git a/inbound/hysteria.go b/inbound/hysteria.go new file mode 100644 index 00000000..f3fa337c --- /dev/null +++ b/inbound/hysteria.go @@ -0,0 +1,256 @@ +//go:build with_quic + +package inbound + +import ( + "bytes" + "context" + "net" + "net/netip" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/transport/hysteria" + dns "github.com/sagernet/sing-dns" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + "github.com/lucas-clemente/quic-go" +) + +var _ adapter.Inbound = (*Hysteria)(nil) + +type Hysteria struct { + ctx context.Context + router adapter.Router + logger log.ContextLogger + tag string + listenOptions option.ListenOptions + quicConfig *quic.Config + tlsConfig *TLSConfig + authKey []byte + xplusKey []byte + sendBPS uint64 + recvBPS uint64 + listener quic.Listener +} + +func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaInboundOptions) (*Hysteria, error) { + quicConfig := &quic.Config{ + InitialStreamReceiveWindow: options.ReceiveWindowConn, + MaxStreamReceiveWindow: options.ReceiveWindowConn, + InitialConnectionReceiveWindow: options.ReceiveWindowClient, + MaxConnectionReceiveWindow: options.ReceiveWindowClient, + MaxIncomingStreams: int64(options.MaxConnClient), + KeepAlivePeriod: hysteria.KeepAlivePeriod, + DisablePathMTUDiscovery: options.DisableMTUDiscovery, + EnableDatagrams: true, + } + if options.ReceiveWindowConn == 0 { + quicConfig.InitialStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow + quicConfig.MaxStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow + } + if options.ReceiveWindowClient == 0 { + quicConfig.InitialConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow + quicConfig.MaxConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow + } + if quicConfig.MaxIncomingStreams == 0 { + quicConfig.MaxIncomingStreams = hysteria.DefaultMaxIncomingStreams + } + var auth []byte + if len(options.Auth) > 0 { + auth = options.Auth + } else { + auth = []byte(options.AuthString) + } + var xplus []byte + if options.Obfs != "" { + xplus = []byte(options.Obfs) + } + var up, down uint64 + if len(options.Up) > 0 { + up = hysteria.StringToBps(options.Up) + if up == 0 { + return nil, E.New("invalid up speed format: ", options.Up) + } + } else { + up = uint64(options.UpMbps) * hysteria.MbpsToBps + } + if len(options.Down) > 0 { + down = hysteria.StringToBps(options.Down) + if down == 0 { + return nil, E.New("invalid down speed format: ", options.Down) + } + } else { + down = uint64(options.DownMbps) * hysteria.MbpsToBps + } + if up < hysteria.MinSpeedBPS { + return nil, E.New("invalid up speed") + } + if down < hysteria.MinSpeedBPS { + return nil, E.New("invalid down speed") + } + inbound := &Hysteria{ + ctx: ctx, + router: router, + logger: logger, + tag: tag, + quicConfig: quicConfig, + listenOptions: options.ListenOptions, + authKey: auth, + xplusKey: xplus, + sendBPS: up, + recvBPS: down, + } + if options.TLS == nil || !options.TLS.Enabled { + return nil, ErrTLSRequired + } + if len(options.TLS.ALPN) == 0 { + options.TLS.ALPN = []string{hysteria.DefaultALPN} + } + tlsConfig, err := NewTLSConfig(logger, common.PtrValueOrDefault(options.TLS)) + if err != nil { + return nil, err + } + inbound.tlsConfig = tlsConfig + return inbound, nil +} + +func (h *Hysteria) Type() string { + return C.TypeHysteria +} + +func (h *Hysteria) Tag() string { + return h.tag +} + +func (h *Hysteria) Start() error { + listenAddr := M.SocksaddrFrom(netip.Addr(h.listenOptions.Listen), h.listenOptions.ListenPort) + var packetConn net.PacketConn + var err error + packetConn, err = net.ListenUDP(M.NetworkFromNetAddr("udp", listenAddr.Addr), listenAddr.UDPAddr()) + if err != nil { + return err + } + if len(h.xplusKey) > 0 { + packetConn = hysteria.NewXPlusPacketConn(packetConn, h.xplusKey) + } + err = h.tlsConfig.Start() + if err != nil { + return err + } + listener, err := quic.Listen(packetConn, h.tlsConfig.Config(), h.quicConfig) + if err != nil { + return err + } + h.listener = listener + go h.acceptLoop() + return nil +} + +func (h *Hysteria) acceptLoop() { + for { + ctx := log.ContextWithNewID(h.ctx) + conn, err := h.listener.Accept(ctx) + if err != nil { + return + } + go func() { + hErr := h.accept(ctx, conn) + if hErr != nil { + conn.CloseWithError(0, "") + NewError(h.logger, ctx, E.Cause(hErr, "process connection from ", conn.RemoteAddr())) + } + }() + } +} + +func (h *Hysteria) accept(ctx context.Context, conn quic.Connection) error { + controlStream, err := conn.AcceptStream(ctx) + if err != nil { + return err + } + clientHello, err := hysteria.ReadClientHello(controlStream) + if err != nil { + return err + } + if !bytes.Equal(clientHello.Auth, h.authKey) { + err = hysteria.WriteServerHello(controlStream, hysteria.ServerHello{ + Message: "wrong password", + }) + return E.Errors(E.New("wrong password: ", string(clientHello.Auth)), err) + } + if clientHello.SendBPS == 0 || clientHello.RecvBPS == 0 { + return E.New("invalid rate from client") + } + serverSendBPS, serverRecvBPS := clientHello.RecvBPS, clientHello.SendBPS + if h.sendBPS > 0 && serverSendBPS > h.sendBPS { + serverSendBPS = h.sendBPS + } + if h.recvBPS > 0 && serverRecvBPS > h.recvBPS { + serverRecvBPS = h.recvBPS + } + err = hysteria.WriteServerHello(controlStream, hysteria.ServerHello{ + OK: true, + SendBPS: serverSendBPS, + RecvBPS: serverRecvBPS, + }) + if err != nil { + return err + } + // TODO: set congestion control + go h.udpRecvLoop(conn) + var stream quic.Stream + for { + stream, err = conn.AcceptStream(ctx) + if err != nil { + return err + } + hErr := h.acceptStream(ctx, conn, stream) + if hErr != nil { + stream.Close() + NewError(h.logger, ctx, E.Cause(hErr, "process stream from ", conn.RemoteAddr())) + } + } +} + +func (h *Hysteria) udpRecvLoop(conn quic.Connection) { +} + +func (h *Hysteria) acceptStream(ctx context.Context, conn quic.Connection, stream quic.Stream) error { + request, err := hysteria.ReadClientRequest(stream) + if err != nil { + return err + } + if request.UDP { + err = hysteria.WriteServerResponse(stream, hysteria.ServerResponse{ + Message: "unsupported", + }, nil) + if err != nil { + return err + } + stream.Close() + return nil + } + var metadata adapter.InboundContext + metadata.Inbound = h.tag + metadata.InboundType = C.TypeHysteria + metadata.SniffEnabled = h.listenOptions.SniffEnabled + metadata.SniffOverrideDestination = h.listenOptions.SniffOverrideDestination + metadata.DomainStrategy = dns.DomainStrategy(h.listenOptions.DomainStrategy) + metadata.Network = N.NetworkTCP + metadata.Source = M.SocksaddrFromNet(conn.RemoteAddr()) + metadata.Destination = M.ParseSocksaddrHostPort(request.Host, request.Port) + return h.router.RouteConnection(ctx, hysteria.NewServerConn(stream, metadata.Destination), metadata) +} + +func (h *Hysteria) Close() error { + return common.Close( + h.listener, + common.PtrOrNil(h.tlsConfig), + ) +} diff --git a/inbound/hysteria_stub.go b/inbound/hysteria_stub.go new file mode 100644 index 00000000..5a619115 --- /dev/null +++ b/inbound/hysteria_stub.go @@ -0,0 +1,16 @@ +//go:build !with_quic + +package inbound + +import ( + "context" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" +) + +func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaInboundOptions) (adapter.Inbound, error) { + return nil, E.New(`QUIC is not included in this build, rebuild with -tags with_quic`) +} diff --git a/inbound/naive.go b/inbound/naive.go index 660447f6..424156dd 100644 --- a/inbound/naive.go +++ b/inbound/naive.go @@ -46,7 +46,7 @@ type Naive struct { } var ( - ErrNaiveTLSRequired = E.New("TLS required") + ErrTLSRequired = E.New("TLS required") ErrNaiveMissingUsers = E.New("missing users") ) @@ -61,7 +61,7 @@ func NewNaive(ctx context.Context, router adapter.Router, logger log.ContextLogg authenticator: auth.NewAuthenticator(options.Users), } if options.TLS == nil || !options.TLS.Enabled { - return nil, ErrNaiveTLSRequired + return nil, ErrTLSRequired } if len(options.Users) == 0 { return nil, ErrNaiveMissingUsers diff --git a/option/hysteria.go b/option/hysteria.go index 9d57f82e..f01a7b83 100644 --- a/option/hysteria.go +++ b/option/hysteria.go @@ -1,23 +1,38 @@ package option +type HysteriaInboundOptions struct { + ListenOptions + Up string `json:"up,omitempty"` + UpMbps int `json:"up_mbps,omitempty"` + Down string `json:"down,omitempty"` + DownMbps int `json:"down_mbps,omitempty"` + Obfs string `json:"obfs,omitempty"` + Auth []byte `json:"auth,omitempty"` + AuthString string `json:"auth_str,omitempty"` + ReceiveWindowConn uint64 `json:"recv_window_conn,omitempty"` + ReceiveWindowClient uint64 `json:"recv_window_client,omitempty"` + MaxConnClient int `json:"max_conn_client,omitempty"` + DisableMTUDiscovery bool `json:"disable_mtu_discovery,omitempty"` + TLS *InboundTLSOptions `json:"tls,omitempty"` +} + type HysteriaOutboundOptions struct { OutboundDialerOptions ServerOptions - Protocol string `json:"protocol"` - Up string `json:"up"` - UpMbps int `json:"up_mbps"` - Down string `json:"down"` - DownMbps int `json:"down_mbps"` - Obfs string `json:"obfs"` - Auth []byte `json:"auth"` - AuthString string `json:"auth_str"` - ALPN string `json:"alpn"` - ServerName string `json:"server_name"` - Insecure bool `json:"insecure"` - CustomCA string `json:"ca"` - CustomCAStr string `json:"ca_str"` - ReceiveWindowConn uint64 `json:"recv_window_conn"` - ReceiveWindow uint64 `json:"recv_window"` - DisableMTUDiscovery bool `json:"disable_mtu_discovery"` + Up string `json:"up,omitempty"` + UpMbps int `json:"up_mbps,omitempty"` + Down string `json:"down,omitempty"` + DownMbps int `json:"down_mbps,omitempty"` + Obfs string `json:"obfs,omitempty"` + Auth []byte `json:"auth,omitempty"` + AuthString string `json:"auth_str,omitempty"` + ALPN string `json:"alpn,omitempty"` + ServerName string `json:"server_name,omitempty"` + Insecure bool `json:"insecure,omitempty"` + CustomCA string `json:"ca,omitempty"` + CustomCAStr string `json:"ca_str,omitempty"` + ReceiveWindowConn uint64 `json:"recv_window_conn,omitempty"` + ReceiveWindow uint64 `json:"recv_window,omitempty"` + DisableMTUDiscovery bool `json:"disable_mtu_discovery,omitempty"` Network NetworkList `json:"network,omitempty"` } diff --git a/option/inbound.go b/option/inbound.go index 57f7ac92..2ac440c5 100644 --- a/option/inbound.go +++ b/option/inbound.go @@ -20,6 +20,7 @@ type _Inbound struct { VMessOptions VMessInboundOptions `json:"-"` TrojanOptions TrojanInboundOptions `json:"-"` NaiveOptions NaiveInboundOptions `json:"-"` + HysteriaOptions HysteriaInboundOptions `json:"-"` } type Inbound _Inbound @@ -49,6 +50,8 @@ func (h Inbound) MarshalJSON() ([]byte, error) { v = h.TrojanOptions case C.TypeNaive: v = h.NaiveOptions + case C.TypeHysteria: + v = h.HysteriaOptions default: return nil, E.New("unknown inbound type: ", h.Type) } @@ -84,6 +87,8 @@ func (h *Inbound) UnmarshalJSON(bytes []byte) error { v = &h.TrojanOptions case C.TypeNaive: v = &h.NaiveOptions + case C.TypeHysteria: + v = &h.HysteriaOptions default: return E.New("unknown inbound type: ", h.Type) } diff --git a/outbound/hysteria.go b/outbound/hysteria.go index c539fda3..a1339e4a 100644 --- a/outbound/hysteria.go +++ b/outbound/hysteria.go @@ -9,7 +9,6 @@ import ( "net" "os" "sync" - "time" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" @@ -26,16 +25,6 @@ import ( "github.com/lucas-clemente/quic-go" ) -const ( - hyMbpsToBps = 125000 - hyMinSpeedBPS = 16384 - hyDefaultStreamReceiveWindow = 15728640 // 15 MB/s - hyDefaultConnectionReceiveWindow = 67108864 // 64 MB/s - hyDefaultMaxIncomingStreams = 1024 - hyDefaultALPN = "hysteria" - hyKeepAlivePeriod = 10 * time.Second -) - var _ adapter.Outbound = (*Hysteria)(nil) type Hysteria struct { @@ -65,7 +54,7 @@ func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextL if options.ALPN != "" { tlsConfig.NextProtos = []string{options.ALPN} } else { - tlsConfig.NextProtos = []string{hyDefaultALPN} + tlsConfig.NextProtos = []string{hysteria.DefaultALPN} } var ca []byte var err error @@ -90,20 +79,20 @@ func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextL MaxStreamReceiveWindow: options.ReceiveWindowConn, InitialConnectionReceiveWindow: options.ReceiveWindow, MaxConnectionReceiveWindow: options.ReceiveWindow, - KeepAlivePeriod: hyKeepAlivePeriod, + KeepAlivePeriod: hysteria.KeepAlivePeriod, DisablePathMTUDiscovery: options.DisableMTUDiscovery, EnableDatagrams: true, } if options.ReceiveWindowConn == 0 { - quicConfig.InitialStreamReceiveWindow = hyDefaultStreamReceiveWindow - quicConfig.MaxStreamReceiveWindow = hyDefaultStreamReceiveWindow + quicConfig.InitialStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow + quicConfig.MaxStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow } if options.ReceiveWindow == 0 { - quicConfig.InitialConnectionReceiveWindow = hyDefaultConnectionReceiveWindow - quicConfig.MaxConnectionReceiveWindow = hyDefaultConnectionReceiveWindow + quicConfig.InitialConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow + quicConfig.MaxConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow } if quicConfig.MaxIncomingStreams == 0 { - quicConfig.MaxIncomingStreams = hyDefaultMaxIncomingStreams + quicConfig.MaxIncomingStreams = hysteria.DefaultMaxIncomingStreams } var auth []byte if len(options.Auth) > 0 { @@ -122,7 +111,7 @@ func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextL return nil, E.New("invalid up speed format: ", options.Up) } } else { - up = uint64(options.UpMbps) * hyMbpsToBps + up = uint64(options.UpMbps) * hysteria.MbpsToBps } if len(options.Down) > 0 { down = hysteria.StringToBps(options.Down) @@ -130,12 +119,12 @@ func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextL return nil, E.New("invalid down speed format: ", options.Down) } } else { - down = uint64(options.DownMbps) * hyMbpsToBps + down = uint64(options.DownMbps) * hysteria.MbpsToBps } - if up < hyMinSpeedBPS { + if up < hysteria.MinSpeedBPS { return nil, E.New("invalid up speed") } - if down < hyMinSpeedBPS { + if down < hysteria.MinSpeedBPS { return nil, E.New("invalid down speed") } return &Hysteria{ @@ -214,7 +203,7 @@ func (h *Hysteria) offerNew(ctx context.Context) (quic.Connection, error) { Auth: h.authKey, }) if err != nil { - return nil, E.Cause(err, "write hysteria client hello") + return nil, err } serverHello, err := hysteria.ReadServerHello(controlStream) if err != nil { diff --git a/test/hysteria_test.go b/test/hysteria_test.go index e15b7b37..c9d0568f 100644 --- a/test/hysteria_test.go +++ b/test/hysteria_test.go @@ -8,6 +8,81 @@ import ( "github.com/sagernet/sing-box/option" ) +func TestHysteriaSelf(t *testing.T) { + if !C.QUIC_AVAILABLE { + t.Skip("QUIC not included") + } + caPem, certPem, keyPem := createSelfSignedCertificate(t, "example.org") + startInstance(t, option.Options{ + Log: &option.LogOptions{ + Level: "trace", + }, + Inbounds: []option.Inbound{ + { + Type: C.TypeMixed, + Tag: "mixed-in", + MixedOptions: option.HTTPMixedInboundOptions{ + ListenOptions: option.ListenOptions{ + Listen: option.ListenAddress(netip.IPv4Unspecified()), + ListenPort: clientPort, + }, + }, + }, + { + Type: C.TypeHysteria, + HysteriaOptions: option.HysteriaInboundOptions{ + ListenOptions: option.ListenOptions{ + Listen: option.ListenAddress(netip.IPv4Unspecified()), + ListenPort: serverPort, + }, + UpMbps: 100, + DownMbps: 100, + AuthString: "password", + Obfs: "fuck me till the daylight", + TLS: &option.InboundTLSOptions{ + Enabled: true, + ServerName: "example.org", + CertificatePath: certPem, + KeyPath: keyPem, + }, + }, + }, + }, + Outbounds: []option.Outbound{ + { + Type: C.TypeDirect, + }, + { + Type: C.TypeHysteria, + Tag: "hy-out", + HysteriaOutbound: option.HysteriaOutboundOptions{ + ServerOptions: option.ServerOptions{ + Server: "127.0.0.1", + ServerPort: serverPort, + }, + UpMbps: 100, + DownMbps: 100, + AuthString: "password", + Obfs: "fuck me till the daylight", + CustomCA: caPem, + ServerName: "example.org", + }, + }, + }, + Route: &option.RouteOptions{ + Rules: []option.Rule{ + { + DefaultOptions: option.DefaultRule{ + Inbound: []string{"mixed-in"}, + Outbound: "hy-out", + }, + }, + }, + }, + }) + testTCP(t, clientPort, testPort) +} + func TestHysteriaOutbound(t *testing.T) { if !C.QUIC_AVAILABLE { t.Skip("QUIC not included") diff --git a/transport/hysteria/client.go b/transport/hysteria/client.go index 36cc84b0..60357e50 100644 --- a/transport/hysteria/client.go +++ b/transport/hysteria/client.go @@ -49,7 +49,7 @@ func (c *ClientConn) Write(b []byte) (n int, err error) { if !c.requestWritten { err = WriteClientRequest(c.Stream, ClientRequest{ UDP: false, - Host: c.destination.AddrString(), + Host: c.destination.Unwrap().AddrString(), Port: c.destination.Port, }, b) if err != nil { @@ -73,6 +73,14 @@ func (c *ClientConn) Upstream() any { return c.Stream } +func (c *ClientConn) ReaderReplaceable() bool { + return c.responseRead +} + +func (c *ClientConn) WriterReplaceable() bool { + return c.requestWritten +} + type ClientPacketConn struct { session quic.Connection stream quic.Stream @@ -130,7 +138,7 @@ func (c *ClientPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destinati func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { return WriteUDPMessage(c.session, UDPMessage{ SessionID: c.sessionId, - Host: destination.AddrString(), + Host: destination.Unwrap().AddrString(), Port: destination.Port, FragCount: 1, Data: buffer.Bytes(), @@ -158,19 +166,19 @@ func (c *ClientPacketConn) SetWriteDeadline(t time.Time) error { } func (c *ClientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - panic("invalid") + return 0, nil, os.ErrInvalid } func (c *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - panic("invalid") + return 0, os.ErrInvalid } func (c *ClientPacketConn) Read(b []byte) (n int, err error) { - panic("invalid") + return 0, os.ErrInvalid } func (c *ClientPacketConn) Write(b []byte) (n int, err error) { - panic("invalid") + return 0, os.ErrInvalid } func (c *ClientPacketConn) Close() error { diff --git a/transport/hysteria/protocol.go b/transport/hysteria/protocol.go index 091923e5..df62bcb8 100644 --- a/transport/hysteria/protocol.go +++ b/transport/hysteria/protocol.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "io" "math/rand" + "time" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -13,6 +14,16 @@ import ( "github.com/lucas-clemente/quic-go" ) +const ( + MbpsToBps = 125000 + MinSpeedBPS = 16384 + DefaultStreamReceiveWindow = 15728640 // 15 MB/s + DefaultConnectionReceiveWindow = 67108864 // 64 MB/s + DefaultMaxIncomingStreams = 1024 + DefaultALPN = "hysteria" + KeepAlivePeriod = 10 * time.Second +) + const Version = 3 type ClientHello struct { @@ -21,13 +32,6 @@ type ClientHello struct { Auth []byte } -type ServerHello struct { - OK bool - SendBPS uint64 - RecvBPS uint64 - Message string -} - func WriteClientHello(stream io.Writer, hello ClientHello) error { var requestLen int requestLen += 1 // version @@ -49,6 +53,44 @@ func WriteClientHello(stream io.Writer, hello ClientHello) error { return common.Error(stream.Write(request.Bytes())) } +func ReadClientHello(reader io.Reader) (*ClientHello, error) { + var version uint8 + err := binary.Read(reader, binary.BigEndian, &version) + if err != nil { + return nil, err + } + if version != Version { + return nil, E.New("unsupported client version: ", version) + } + var clientHello ClientHello + err = binary.Read(reader, binary.BigEndian, &clientHello.SendBPS) + if err != nil { + return nil, err + } + err = binary.Read(reader, binary.BigEndian, &clientHello.RecvBPS) + if err != nil { + return nil, err + } + var authLen uint16 + err = binary.Read(reader, binary.BigEndian, &authLen) + if err != nil { + return nil, err + } + clientHello.Auth = make([]byte, authLen) + _, err = io.ReadFull(reader, clientHello.Auth) + if err != nil { + return nil, err + } + return &clientHello, nil +} + +type ServerHello struct { + OK bool + SendBPS uint64 + RecvBPS uint64 + Message string +} + func ReadServerHello(stream io.Reader) (*ServerHello, error) { var responseLen int responseLen += 1 // ok @@ -80,16 +122,59 @@ func ReadServerHello(stream io.Reader) (*ServerHello, error) { return &serverHello, nil } +func WriteServerHello(stream io.Writer, hello ServerHello) error { + var responseLen int + responseLen += 1 // ok + responseLen += 8 // sendBPS + responseLen += 8 // recvBPS + responseLen += 2 // message len + responseLen += len(hello.Message) + _response := buf.StackNewSize(responseLen) + defer common.KeepAlive(_response) + response := common.Dup(_response) + defer response.Release() + if hello.OK { + common.Must(response.WriteByte(1)) + } else { + common.Must(response.WriteByte(0)) + } + common.Must( + binary.Write(response, binary.BigEndian, hello.SendBPS), + binary.Write(response, binary.BigEndian, hello.RecvBPS), + binary.Write(response, binary.BigEndian, uint16(len(hello.Message))), + common.Error(response.WriteString(hello.Message)), + ) + return common.Error(stream.Write(response.Bytes())) +} + type ClientRequest struct { UDP bool Host string Port uint16 } -type ServerResponse struct { - OK bool - UDPSessionID uint32 - Message string +func ReadClientRequest(stream io.Reader) (*ClientRequest, error) { + var clientRequest ClientRequest + err := binary.Read(stream, binary.BigEndian, &clientRequest.UDP) + if err != nil { + return nil, err + } + var hostLen uint16 + err = binary.Read(stream, binary.BigEndian, &hostLen) + if err != nil { + return nil, err + } + host := make([]byte, hostLen) + _, err = io.ReadFull(stream, host) + if err != nil { + return nil, err + } + clientRequest.Host = string(host) + err = binary.Read(stream, binary.BigEndian, &clientRequest.Port) + if err != nil { + return nil, err + } + return &clientRequest, nil } func WriteClientRequest(stream io.Writer, request ClientRequest, payload []byte) error { @@ -117,6 +202,12 @@ func WriteClientRequest(stream io.Writer, request ClientRequest, payload []byte) return common.Error(stream.Write(buffer.Bytes())) } +type ServerResponse struct { + OK bool + UDPSessionID uint32 + Message string +} + func ReadServerResponse(stream io.Reader) (*ServerResponse, error) { var responseLen int responseLen += 1 // ok @@ -146,6 +237,31 @@ func ReadServerResponse(stream io.Reader) (*ServerResponse, error) { return &serverResponse, nil } +func WriteServerResponse(stream io.Writer, response ServerResponse, payload []byte) error { + var responseLen int + responseLen += 1 // ok + responseLen += 4 // udp session id + responseLen += 2 // message len + responseLen += len(response.Message) + responseLen += len(payload) + _buffer := buf.StackNewSize(responseLen) + defer common.KeepAlive(_buffer) + buffer := common.Dup(_buffer) + defer buffer.Release() + if response.OK { + common.Must(buffer.WriteByte(1)) + } else { + common.Must(buffer.WriteByte(0)) + } + common.Must( + binary.Write(buffer, binary.BigEndian, response.UDPSessionID), + binary.Write(buffer, binary.BigEndian, uint16(len(response.Message))), + common.Error(buffer.WriteString(response.Message)), + common.Error(buffer.Write(payload)), + ) + return common.Error(stream.Write(buffer.Bytes())) +} + type UDPMessage struct { SessionID uint32 Host string diff --git a/transport/hysteria/server.go b/transport/hysteria/server.go new file mode 100644 index 00000000..75ed67e0 --- /dev/null +++ b/transport/hysteria/server.go @@ -0,0 +1,68 @@ +package hysteria + +import ( + "net" + + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + "github.com/lucas-clemente/quic-go" +) + +var ( + _ net.Conn = (*ServerConn)(nil) + _ N.HandshakeConn = (*ServerConn)(nil) +) + +type ServerConn struct { + quic.Stream + destination M.Socksaddr + responseWritten bool +} + +func NewServerConn(stream quic.Stream, destination M.Socksaddr) *ServerConn { + return &ServerConn{ + Stream: stream, + destination: destination, + } +} + +func (c *ServerConn) LocalAddr() net.Addr { + return nil +} + +func (c *ServerConn) RemoteAddr() net.Addr { + return c.destination.TCPAddr() +} + +func (c *ServerConn) Write(b []byte) (n int, err error) { + if !c.responseWritten { + err = WriteServerResponse(c.Stream, ServerResponse{ + OK: true, + }, b) + c.responseWritten = true + return len(b), nil + } + return c.Stream.Write(b) +} + +func (c *ServerConn) ReaderReplaceable() bool { + return true +} + +func (c *ServerConn) WriterReplaceable() bool { + return c.responseWritten +} + +func (c *ServerConn) HandshakeFailure(err error) error { + if c.responseWritten { + return nil + } + return WriteServerResponse(c.Stream, ServerResponse{ + Message: err.Error(), + }, nil) +} + +func (c *ServerConn) Upstream() any { + return c.Stream +}