diff --git a/transport/v2raywebsocket/conn.go b/transport/v2raywebsocket/conn.go index f9099764..a796e612 100644 --- a/transport/v2raywebsocket/conn.go +++ b/transport/v2raywebsocket/conn.go @@ -3,6 +3,7 @@ package v2raywebsocket import ( "context" "encoding/base64" + "errors" "io" "net" "os" @@ -61,7 +62,7 @@ func (c *WebsocketConn) Close() error { func (c *WebsocketConn) Read(b []byte) (n int, err error) { var header ws.Header for { - n, err = c.reader.Read(b) + n, err = wrapWsError0(c.reader.Read(b)) if n > 0 { err = nil return @@ -95,7 +96,7 @@ func (c *WebsocketConn) Read(b []byte) (n int, err error) { } func (c *WebsocketConn) Write(p []byte) (n int, err error) { - err = wsutil.WriteMessage(c.Conn, c.state, ws.OpBinary, p) + err = wrapWsError(wsutil.WriteMessage(c.Conn, c.state, ws.OpBinary, p)) if err != nil { return } @@ -146,7 +147,7 @@ func (c *EarlyWebsocketConn) Read(b []byte) (n int, err error) { return 0, c.err } } - return c.conn.Read(b) + return wrapWsError0(c.conn.Read(b)) } func (c *EarlyWebsocketConn) writeRequest(content []byte) error { @@ -177,12 +178,12 @@ func (c *EarlyWebsocketConn) writeRequest(content []byte) error { conn, err = c.dialContext(c.ctx, &c.requestURL, c.headers) } if err != nil { - return err + return wrapWsError(err) } if len(lateData) > 0 { _, err = conn.Write(lateData) if err != nil { - return err + return wrapWsError(err) } } c.conn = conn @@ -191,7 +192,7 @@ func (c *EarlyWebsocketConn) writeRequest(content []byte) error { func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) { if c.conn != nil { - return c.conn.Write(b) + return wrapWsError0(c.conn.Write(b)) } c.access.Lock() defer c.access.Unlock() @@ -199,9 +200,9 @@ func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) { return 0, c.err } if c.conn != nil { - return c.conn.Write(b) + return wrapWsError0(c.conn.Write(b)) } - err = c.writeRequest(b) + err = wrapWsError(c.writeRequest(b)) c.err = err close(c.create) if err != nil { @@ -212,17 +213,17 @@ func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) { func (c *EarlyWebsocketConn) WriteBuffer(buffer *buf.Buffer) error { if c.conn != nil { - return c.conn.WriteBuffer(buffer) + return wrapWsError(c.conn.WriteBuffer(buffer)) } c.access.Lock() defer c.access.Unlock() if c.conn != nil { - return c.conn.WriteBuffer(buffer) + return wrapWsError(c.conn.WriteBuffer(buffer)) } if c.err != nil { return c.err } - err := c.writeRequest(buffer.Bytes()) + err := wrapWsError(c.writeRequest(buffer.Bytes())) c.err = err close(c.create) return err @@ -272,3 +273,23 @@ func (c *EarlyWebsocketConn) Upstream() any { func (c *EarlyWebsocketConn) LazyHeadroom() bool { return c.conn == nil } + +func wrapWsError(err error) error { + if err == nil { + return nil + } + var closedErr *wsutil.ClosedError + if errors.As(err, &closedErr) { + if closedErr.Code == ws.StatusNormalClosure { + err = io.EOF + } + } + return err +} + +func wrapWsError0[T any](value T, err error) (T, error) { + if err == nil { + return value, nil + } + return common.DefaultValue[T](), wrapWsError(err) +}