diff --git a/transport/shadowtls/conn.go b/transport/shadowtls/conn.go index 0bb3ee8b..d498e021 100644 --- a/transport/shadowtls/conn.go +++ b/transport/shadowtls/conn.go @@ -9,24 +9,22 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" N "github.com/sagernet/sing/common/network" ) -var ( - _ N.ExtendedConn = (*Conn)(nil) - _ N.VectorisedWriter = (*Conn)(nil) -) +var _ N.VectorisedWriter = (*Conn)(nil) type Conn struct { - N.ExtendedConn + net.Conn writer N.VectorisedWriter readRemaining int } func NewConn(conn net.Conn) *Conn { return &Conn{ - ExtendedConn: bufio.NewExtendedConn(conn), - writer: bufio.NewVectorisedWriter(conn), + Conn: conn, + writer: bufio.NewVectorisedWriter(conn), } } @@ -35,21 +33,24 @@ func (c *Conn) Read(p []byte) (n int, err error) { if len(p) > c.readRemaining { p = p[:c.readRemaining] } - n, err = c.ExtendedConn.Read(p) + n, err = c.Conn.Read(p) c.readRemaining -= n return } var tlsHeader [5]byte - _, err = io.ReadFull(c.ExtendedConn, common.Dup(tlsHeader[:])) + _, err = io.ReadFull(c.Conn, common.Dup(tlsHeader[:])) if err != nil { return } length := int(binary.BigEndian.Uint16(tlsHeader[3:5])) + if tlsHeader[0] != 23 { + return 0, E.New("unexpected TLS record type: ", tlsHeader[0]) + } readLen := len(p) if readLen > length { readLen = length } - n, err = c.ExtendedConn.Read(p[:readLen]) + n, err = c.Conn.Read(p[:readLen]) if err != nil { return } @@ -92,5 +93,5 @@ func (c *Conn) WriteVectorised(buffers []*buf.Buffer) error { } func (c *Conn) Upstream() any { - return c.ExtendedConn + return c.Conn }