diff --git a/inbound/hysteria.go b/inbound/hysteria.go index 0facc47f..642d101c 100644 --- a/inbound/hysteria.go +++ b/inbound/hysteria.go @@ -273,7 +273,7 @@ func (h *Hysteria) acceptStream(ctx context.Context, conn quic.Connection, strea return err } h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) - return h.router.RouteConnection(ctx, hysteria.NewConn(stream, metadata.Destination), metadata) + return h.router.RouteConnection(ctx, hysteria.NewConn(stream, metadata.Destination, false), metadata) } else { h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination) var id uint32 diff --git a/outbound/hysteria.go b/outbound/hysteria.go index 176d417d..529083b2 100644 --- a/outbound/hysteria.go +++ b/outbound/hysteria.go @@ -276,16 +276,7 @@ func (h *Hysteria) DialContext(ctx context.Context, network string, destination stream.Close() return nil, err } - response, err := hysteria.ReadServerResponse(stream) - if err != nil { - stream.Close() - return nil, err - } - if !response.OK { - stream.Close() - return nil, E.New("remote error: ", response.Message) - } - return hysteria.NewConn(stream, destination), nil + return hysteria.NewConn(stream, destination, true), nil case N.NetworkUDP: conn, err := h.ListenPacket(ctx, destination) if err != nil { diff --git a/transport/hysteria/protocol.go b/transport/hysteria/protocol.go index d3893b80..3a92d194 100644 --- a/transport/hysteria/protocol.go +++ b/transport/hysteria/protocol.go @@ -374,17 +374,35 @@ var _ net.Conn = (*Conn)(nil) type Conn struct { quic.Stream - destination M.Socksaddr - responseWritten bool + destination M.Socksaddr + needReadResponse bool } -func NewConn(stream quic.Stream, destination M.Socksaddr) *Conn { +func NewConn(stream quic.Stream, destination M.Socksaddr, isClient bool) *Conn { return &Conn{ - Stream: stream, - destination: destination, + Stream: stream, + destination: destination, + needReadResponse: isClient, } } +func (c *Conn) Read(p []byte) (n int, err error) { + if c.needReadResponse { + var response *ServerResponse + response, err = ReadServerResponse(c.Stream) + if err != nil { + c.Close() + return + } + if !response.OK { + c.Close() + return 0, E.New("remote error: ", response.Message) + } + c.needReadResponse = false + } + return c.Stream.Read(p) +} + func (c *Conn) LocalAddr() net.Addr { return nil } @@ -394,7 +412,7 @@ func (c *Conn) RemoteAddr() net.Addr { } func (c *Conn) ReaderReplaceable() bool { - return true + return !c.needReadResponse } func (c *Conn) WriterReplaceable() bool {