From 79b6bdfda168b6da8bda5940e83bcfff8a9dacde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 12 Sep 2022 23:56:18 +0800 Subject: [PATCH] Skip wait for hysteria tcp handshake response Co-authored-by: arm64v8a <48624112+arm64v8a@users.noreply.github.com> --- inbound/hysteria.go | 2 +- outbound/hysteria.go | 11 +---------- transport/hysteria/protocol.go | 30 ++++++++++++++++++++++++------ 3 files changed, 26 insertions(+), 17 deletions(-) diff --git a/inbound/hysteria.go b/inbound/hysteria.go index 0facc47f..642d101c 100644 --- a/inbound/hysteria.go +++ b/inbound/hysteria.go @@ -273,7 +273,7 @@ func (h *Hysteria) acceptStream(ctx context.Context, conn quic.Connection, strea return err } h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) - return h.router.RouteConnection(ctx, hysteria.NewConn(stream, metadata.Destination), metadata) + return h.router.RouteConnection(ctx, hysteria.NewConn(stream, metadata.Destination, false), metadata) } else { h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination) var id uint32 diff --git a/outbound/hysteria.go b/outbound/hysteria.go index 176d417d..529083b2 100644 --- a/outbound/hysteria.go +++ b/outbound/hysteria.go @@ -276,16 +276,7 @@ func (h *Hysteria) DialContext(ctx context.Context, network string, destination stream.Close() return nil, err } - response, err := hysteria.ReadServerResponse(stream) - if err != nil { - stream.Close() - return nil, err - } - if !response.OK { - stream.Close() - return nil, E.New("remote error: ", response.Message) - } - return hysteria.NewConn(stream, destination), nil + return hysteria.NewConn(stream, destination, true), nil case N.NetworkUDP: conn, err := h.ListenPacket(ctx, destination) if err != nil { diff --git a/transport/hysteria/protocol.go b/transport/hysteria/protocol.go index d3893b80..3a92d194 100644 --- a/transport/hysteria/protocol.go +++ b/transport/hysteria/protocol.go @@ -374,17 +374,35 @@ var _ net.Conn = (*Conn)(nil) type Conn struct { quic.Stream - destination M.Socksaddr - responseWritten bool + destination M.Socksaddr + needReadResponse bool } -func NewConn(stream quic.Stream, destination M.Socksaddr) *Conn { +func NewConn(stream quic.Stream, destination M.Socksaddr, isClient bool) *Conn { return &Conn{ - Stream: stream, - destination: destination, + Stream: stream, + destination: destination, + needReadResponse: isClient, } } +func (c *Conn) Read(p []byte) (n int, err error) { + if c.needReadResponse { + var response *ServerResponse + response, err = ReadServerResponse(c.Stream) + if err != nil { + c.Close() + return + } + if !response.OK { + c.Close() + return 0, E.New("remote error: ", response.Message) + } + c.needReadResponse = false + } + return c.Stream.Read(p) +} + func (c *Conn) LocalAddr() net.Addr { return nil } @@ -394,7 +412,7 @@ func (c *Conn) RemoteAddr() net.Addr { } func (c *Conn) ReaderReplaceable() bool { - return true + return !c.needReadResponse } func (c *Conn) WriterReplaceable() bool {