diff --git a/transport/v2raywebsocket/client.go b/transport/v2raywebsocket/client.go index bc24481a..d3b7a7a6 100644 --- a/transport/v2raywebsocket/client.go +++ b/transport/v2raywebsocket/client.go @@ -70,7 +70,7 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { } return nil, wrapDialError(response, err) } else { - return &EarlyWebsocketConn{Client: c, create: make(chan struct{})}, nil + return &EarlyWebsocketConn{Client: c, ctx: ctx, create: make(chan struct{})}, nil } } diff --git a/transport/v2raywebsocket/conn.go b/transport/v2raywebsocket/conn.go index 18ef70fc..68fdc8e8 100644 --- a/transport/v2raywebsocket/conn.go +++ b/transport/v2raywebsocket/conn.go @@ -1,6 +1,7 @@ package v2raywebsocket import ( + "context" "encoding/base64" "io" "net" @@ -68,6 +69,7 @@ func (c *WebsocketConn) SetDeadline(t time.Time) error { type EarlyWebsocketConn struct { *Client + ctx context.Context conn *WebsocketConn create chan struct{} } @@ -98,14 +100,14 @@ func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) { if len(earlyData) > 0 { earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData) if c.earlyDataHeaderName == "" { - conn, response, err = c.dialer.Dial(c.uri+earlyDataString, c.headers) + conn, response, err = c.dialer.DialContext(c.ctx, c.uri+earlyDataString, c.headers) } else { headers := c.headers.Clone() headers.Set(c.earlyDataHeaderName, earlyDataString) - conn, response, err = c.dialer.Dial(c.uri, headers) + conn, response, err = c.dialer.DialContext(c.ctx, c.uri, headers) } } else { - conn, response, err = c.dialer.Dial(c.uri, c.headers) + conn, response, err = c.dialer.DialContext(c.ctx, c.uri, c.headers) } if err != nil { return 0, wrapDialError(response, err)