From d511698f3f6f7d1a684e740202be202a5e65c1c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 6 Jun 2025 14:39:40 +0800 Subject: [PATCH] Fix slowOpenConn --- common/dialer/tfo.go | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/common/dialer/tfo.go b/common/dialer/tfo.go index 9f72208d..8ea59ca6 100644 --- a/common/dialer/tfo.go +++ b/common/dialer/tfo.go @@ -10,9 +10,7 @@ import ( "sync" "time" - "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" - E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -26,7 +24,9 @@ type slowOpenConn struct { destination M.Socksaddr conn net.Conn create chan struct{} + done chan struct{} access sync.Mutex + closeOnce sync.Once err error } @@ -45,6 +45,7 @@ func DialSlowContext(dialer *tcpDialer, ctx context.Context, network string, des network: network, destination: destination, create: make(chan struct{}), + done: make(chan struct{}), }, nil } @@ -55,8 +56,8 @@ func (c *slowOpenConn) Read(b []byte) (n int, err error) { if c.err != nil { return 0, c.err } - case <-c.ctx.Done(): - return 0, c.ctx.Err() + case <-c.done: + return 0, os.ErrClosed } } return c.conn.Read(b) @@ -74,12 +75,15 @@ func (c *slowOpenConn) Write(b []byte) (n int, err error) { return 0, c.err } return c.conn.Write(b) + case <-c.done: + return 0, os.ErrClosed default: } - c.conn, err = c.dialer.DialContext(c.ctx, c.network, c.destination.String(), b) + conn, err := c.dialer.DialContext(c.ctx, c.network, c.destination.String(), b) if err != nil { - c.conn = nil - c.err = E.Cause(err, "dial tcp fast open") + c.err = err + } else { + c.conn = conn } n = len(b) close(c.create) @@ -87,7 +91,13 @@ func (c *slowOpenConn) Write(b []byte) (n int, err error) { } func (c *slowOpenConn) Close() error { - return common.Close(c.conn) + c.closeOnce.Do(func() { + close(c.done) + if c.conn != nil { + c.conn.Close() + } + }) + return nil } func (c *slowOpenConn) LocalAddr() net.Addr { @@ -152,8 +162,8 @@ func (c *slowOpenConn) WriteTo(w io.Writer) (n int64, err error) { if c.err != nil { return 0, c.err } - case <-c.ctx.Done(): - return 0, c.ctx.Err() + case <-c.done: + return 0, c.err } } return bufio.Copy(w, c.conn)