From 597248130f9f541b96f1a60cabe66a2d003c324e Mon Sep 17 00:00:00 2001 From: shadow750d6 <124365938+shadow750d6@users.noreply.github.com> Date: Sun, 11 Jun 2023 22:20:55 +0800 Subject: [PATCH] Reconnect once if hysteria request fails This allows graceful recovery when network isn't good enough. [Original hysteria source code](https://github.com/apernet/hysteria/blob/13d46da99876c2c9feb1083ff5f2da201d9d0a1e/core/cs/client.go#L182) has similar mechanism. --- outbound/hysteria.go | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/outbound/hysteria.go b/outbound/hysteria.go index b773970a..26352a40 100644 --- a/outbound/hysteria.go +++ b/outbound/hysteria.go @@ -150,6 +150,7 @@ func (h *Hysteria) offer(ctx context.Context) (quic.Connection, error) { if conn != nil && !common.Done(conn.Context()) { return conn, nil } + common.Close(h.rawConn) conn, err := h.offerNew(ctx) if err != nil { return nil, err @@ -260,14 +261,18 @@ func (h *Hysteria) Close() error { return nil } -func (h *Hysteria) open(ctx context.Context) (quic.Connection, quic.Stream, error) { +func (h *Hysteria) open(ctx context.Context, reconnect bool) (quic.Connection, quic.Stream, error) { conn, err := h.offer(ctx) if err != nil { - return nil, nil, err + if nErr, ok := err.(net.Error); ok && !nErr.Temporary() && reconnect { + return h.open(ctx, false) + } } stream, err := conn.OpenStream() if err != nil { - return nil, nil, err + if nErr, ok := err.(net.Error); ok && !nErr.Temporary() && reconnect { + return h.open(ctx, false) + } } return conn, &hysteria.StreamWrapper{Stream: stream}, nil } @@ -276,7 +281,7 @@ func (h *Hysteria) DialContext(ctx context.Context, network string, destination switch N.NetworkName(network) { case N.NetworkTCP: h.logger.InfoContext(ctx, "outbound connection to ", destination) - _, stream, err := h.open(ctx) + _, stream, err := h.open(ctx, true) if err != nil { return nil, err } @@ -302,7 +307,7 @@ func (h *Hysteria) DialContext(ctx context.Context, network string, destination func (h *Hysteria) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { h.logger.InfoContext(ctx, "outbound packet connection to ", destination) - conn, stream, err := h.open(ctx) + conn, stream, err := h.open(ctx, true) if err != nil { return nil, err }