Improve ktls rx error handling

This commit is contained in:
世界 2025-09-09 22:20:53 +08:00
parent 160663e1cf
commit 23af702b27
No known key found for this signature in database
GPG Key ID: CD109927C34A63C4
2 changed files with 5 additions and 3 deletions

View File

@ -32,6 +32,7 @@ type Conn struct {
readWaitOptions N.ReadWaitOptions readWaitOptions N.ReadWaitOptions
kernelTx bool kernelTx bool
kernelRx bool kernelRx bool
pendingRxSplice bool
} }
func NewConn(ctx context.Context, logger logger.ContextLogger, conn aTLS.Conn, txOffload, rxOffload bool) (aTLS.Conn, error) { func NewConn(ctx context.Context, logger logger.ContextLogger, conn aTLS.Conn, txOffload, rxOffload bool) (aTLS.Conn, error) {
@ -103,6 +104,7 @@ func (c *Conn) SyscallConnForRead() syscall.RawConn {
func (c *Conn) HandleSyscallReadError(inputErr error) ([]byte, error) { func (c *Conn) HandleSyscallReadError(inputErr error) ([]byte, error) {
if errors.Is(inputErr, unix.EINVAL) { if errors.Is(inputErr, unix.EINVAL) {
c.pendingRxSplice = true
err := c.readRecord() err := c.readRecord()
if err != nil { if err != nil {
return nil, E.Cause(err, "ktls: handle non-application-data record") return nil, E.Cause(err, "ktls: handle non-application-data record")

View File

@ -258,14 +258,14 @@ func (c *Conn) readKernelRecord() (uint8, []byte, error) {
var err error var err error
er := c.rawSyscallConn.Read(func(fd uintptr) bool { er := c.rawSyscallConn.Read(func(fd uintptr) bool {
n, err = recvmsg(int(fd), &msg, 0) n, err = recvmsg(int(fd), &msg, 0)
return err != unix.EAGAIN return err != unix.EAGAIN || c.pendingRxSplice
}) })
if er != nil { if er != nil {
return 0, nil, er return 0, nil, er
} }
switch err { switch err {
case nil: case nil:
case syscall.EINVAL: case syscall.EINVAL, syscall.EAGAIN:
return 0, nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertProtocolVersion)) return 0, nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertProtocolVersion))
case syscall.EMSGSIZE: case syscall.EMSGSIZE:
return 0, nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertRecordOverflow)) return 0, nil, c.rawConn.In.SetErrorLocked(c.sendAlert(alertRecordOverflow))
@ -276,7 +276,7 @@ func (c *Conn) readKernelRecord() (uint8, []byte, error) {
} }
if n <= 0 { if n <= 0 {
return 0, nil, io.EOF return 0, nil, c.rawConn.In.SetErrorLocked(io.EOF)
} }
if cmsg.Level == unix.SOL_TLS && cmsg.Type == TLS_GET_RECORD_TYPE { if cmsg.Level == unix.SOL_TLS && cmsg.Type == TLS_GET_RECORD_TYPE {