From 5a0af953e0e5e477ee0ab6e9d9c926bd904c6ef8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 7 Dec 2023 11:56:57 +0800 Subject: [PATCH] Improve read waiter interface --- common/badtls/badtls.go | 233 ----------------------------- common/badtls/badtls_stub.go | 14 -- common/badtls/link.go | 22 --- common/badtls/read_wait.go | 115 ++++++++++++++ common/badtls/read_wait_stub.go | 13 ++ common/tls/client.go | 13 +- common/tls/server.go | 14 +- go.mod | 10 +- go.sum | 20 +-- outbound/dns.go | 20 +-- transport/fakeip/packet_wait.go | 12 +- transport/trojan/mux.go | 2 +- transport/trojan/protocol.go | 7 +- transport/trojan/protocol_wait.go | 45 ++++++ transport/trojan/service.go | 3 +- transport/trojan/service_wait.go | 45 ++++++ transport/wireguard/client_bind.go | 7 +- 17 files changed, 283 insertions(+), 312 deletions(-) delete mode 100644 common/badtls/badtls.go delete mode 100644 common/badtls/badtls_stub.go delete mode 100644 common/badtls/link.go create mode 100644 common/badtls/read_wait.go create mode 100644 common/badtls/read_wait_stub.go create mode 100644 transport/trojan/protocol_wait.go create mode 100644 transport/trojan/service_wait.go diff --git a/common/badtls/badtls.go b/common/badtls/badtls.go deleted file mode 100644 index c5c55e3c..00000000 --- a/common/badtls/badtls.go +++ /dev/null @@ -1,233 +0,0 @@ -//go:build go1.20 && !go1.21 - -package badtls - -import ( - "crypto/cipher" - "crypto/rand" - "crypto/tls" - "encoding/binary" - "io" - "net" - "reflect" - "sync" - "sync/atomic" - "unsafe" - - "github.com/sagernet/sing-box/log" - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/bufio" - E "github.com/sagernet/sing/common/exceptions" - N "github.com/sagernet/sing/common/network" - aTLS "github.com/sagernet/sing/common/tls" -) - -type Conn struct { - *tls.Conn - writer N.ExtendedWriter - isHandshakeComplete *atomic.Bool - activeCall *atomic.Int32 - closeNotifySent *bool - version *uint16 - rand io.Reader - halfAccess *sync.Mutex - halfError *error - cipher cipher.AEAD - explicitNonceLen int - halfPtr uintptr - halfSeq []byte - halfScratchBuf []byte -} - -func TryCreate(conn aTLS.Conn) aTLS.Conn { - tlsConn, ok := conn.(*tls.Conn) - if !ok { - return conn - } - badConn, err := Create(tlsConn) - if err != nil { - log.Warn("initialize badtls: ", err) - return conn - } - return badConn -} - -func Create(conn *tls.Conn) (aTLS.Conn, error) { - rawConn := reflect.Indirect(reflect.ValueOf(conn)) - rawIsHandshakeComplete := rawConn.FieldByName("isHandshakeComplete") - if !rawIsHandshakeComplete.IsValid() || rawIsHandshakeComplete.Kind() != reflect.Struct { - return nil, E.New("badtls: invalid isHandshakeComplete") - } - isHandshakeComplete := (*atomic.Bool)(unsafe.Pointer(rawIsHandshakeComplete.UnsafeAddr())) - if !isHandshakeComplete.Load() { - return nil, E.New("handshake not finished") - } - rawActiveCall := rawConn.FieldByName("activeCall") - if !rawActiveCall.IsValid() || rawActiveCall.Kind() != reflect.Struct { - return nil, E.New("badtls: invalid active call") - } - activeCall := (*atomic.Int32)(unsafe.Pointer(rawActiveCall.UnsafeAddr())) - rawHalfConn := rawConn.FieldByName("out") - if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct { - return nil, E.New("badtls: invalid half conn") - } - rawVersion := rawConn.FieldByName("vers") - if !rawVersion.IsValid() || rawVersion.Kind() != reflect.Uint16 { - return nil, E.New("badtls: invalid version") - } - version := (*uint16)(unsafe.Pointer(rawVersion.UnsafeAddr())) - rawCloseNotifySent := rawConn.FieldByName("closeNotifySent") - if !rawCloseNotifySent.IsValid() || rawCloseNotifySent.Kind() != reflect.Bool { - return nil, E.New("badtls: invalid notify") - } - closeNotifySent := (*bool)(unsafe.Pointer(rawCloseNotifySent.UnsafeAddr())) - rawConfig := reflect.Indirect(rawConn.FieldByName("config")) - if !rawConfig.IsValid() || rawConfig.Kind() != reflect.Struct { - return nil, E.New("badtls: bad config") - } - config := (*tls.Config)(unsafe.Pointer(rawConfig.UnsafeAddr())) - randReader := config.Rand - if randReader == nil { - randReader = rand.Reader - } - rawHalfMutex := rawHalfConn.FieldByName("Mutex") - if !rawHalfMutex.IsValid() || rawHalfMutex.Kind() != reflect.Struct { - return nil, E.New("badtls: invalid half mutex") - } - halfAccess := (*sync.Mutex)(unsafe.Pointer(rawHalfMutex.UnsafeAddr())) - rawHalfError := rawHalfConn.FieldByName("err") - if !rawHalfError.IsValid() || rawHalfError.Kind() != reflect.Interface { - return nil, E.New("badtls: invalid half error") - } - halfError := (*error)(unsafe.Pointer(rawHalfError.UnsafeAddr())) - rawHalfCipherInterface := rawHalfConn.FieldByName("cipher") - if !rawHalfCipherInterface.IsValid() || rawHalfCipherInterface.Kind() != reflect.Interface { - return nil, E.New("badtls: invalid cipher interface") - } - rawHalfCipher := rawHalfCipherInterface.Elem() - aeadCipher, loaded := valueInterface(rawHalfCipher, false).(cipher.AEAD) - if !loaded { - return nil, E.New("badtls: invalid AEAD cipher") - } - var explicitNonceLen int - switch cipherName := reflect.Indirect(rawHalfCipher).Type().String(); cipherName { - case "tls.prefixNonceAEAD": - explicitNonceLen = aeadCipher.NonceSize() - case "tls.xorNonceAEAD": - default: - return nil, E.New("badtls: unknown cipher type: ", cipherName) - } - rawHalfSeq := rawHalfConn.FieldByName("seq") - if !rawHalfSeq.IsValid() || rawHalfSeq.Kind() != reflect.Array { - return nil, E.New("badtls: invalid seq") - } - halfSeq := rawHalfSeq.Bytes() - rawHalfScratchBuf := rawHalfConn.FieldByName("scratchBuf") - if !rawHalfScratchBuf.IsValid() || rawHalfScratchBuf.Kind() != reflect.Array { - return nil, E.New("badtls: invalid scratchBuf") - } - halfScratchBuf := rawHalfScratchBuf.Bytes() - return &Conn{ - Conn: conn, - writer: bufio.NewExtendedWriter(conn.NetConn()), - isHandshakeComplete: isHandshakeComplete, - activeCall: activeCall, - closeNotifySent: closeNotifySent, - version: version, - halfAccess: halfAccess, - halfError: halfError, - cipher: aeadCipher, - explicitNonceLen: explicitNonceLen, - rand: randReader, - halfPtr: rawHalfConn.UnsafeAddr(), - halfSeq: halfSeq, - halfScratchBuf: halfScratchBuf, - }, nil -} - -func (c *Conn) WriteBuffer(buffer *buf.Buffer) error { - if buffer.Len() > maxPlaintext { - defer buffer.Release() - return common.Error(c.Write(buffer.Bytes())) - } - for { - x := c.activeCall.Load() - if x&1 != 0 { - return net.ErrClosed - } - if c.activeCall.CompareAndSwap(x, x+2) { - break - } - } - defer c.activeCall.Add(-2) - c.halfAccess.Lock() - defer c.halfAccess.Unlock() - if err := *c.halfError; err != nil { - return err - } - if *c.closeNotifySent { - return errShutdown - } - dataLen := buffer.Len() - dataBytes := buffer.Bytes() - outBuf := buffer.ExtendHeader(recordHeaderLen + c.explicitNonceLen) - outBuf[0] = 23 - version := *c.version - if version == 0 { - version = tls.VersionTLS10 - } else if version == tls.VersionTLS13 { - version = tls.VersionTLS12 - } - binary.BigEndian.PutUint16(outBuf[1:], version) - var nonce []byte - if c.explicitNonceLen > 0 { - nonce = outBuf[5 : 5+c.explicitNonceLen] - if c.explicitNonceLen < 16 { - copy(nonce, c.halfSeq) - } else { - if _, err := io.ReadFull(c.rand, nonce); err != nil { - return err - } - } - } - if len(nonce) == 0 { - nonce = c.halfSeq - } - if *c.version == tls.VersionTLS13 { - buffer.FreeBytes()[0] = 23 - binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen+1+c.cipher.Overhead())) - c.cipher.Seal(outBuf, nonce, outBuf[recordHeaderLen:recordHeaderLen+c.explicitNonceLen+dataLen+1], outBuf[:recordHeaderLen]) - buffer.Extend(1 + c.cipher.Overhead()) - } else { - binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen)) - additionalData := append(c.halfScratchBuf[:0], c.halfSeq...) - additionalData = append(additionalData, outBuf[:recordHeaderLen]...) - c.cipher.Seal(outBuf, nonce, dataBytes, additionalData) - buffer.Extend(c.cipher.Overhead()) - binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen+c.explicitNonceLen+c.cipher.Overhead())) - } - incSeq(c.halfPtr) - log.Trace("badtls write ", buffer.Len()) - return c.writer.WriteBuffer(buffer) -} - -func (c *Conn) FrontHeadroom() int { - return recordHeaderLen + c.explicitNonceLen -} - -func (c *Conn) RearHeadroom() int { - return 1 + c.cipher.Overhead() -} - -func (c *Conn) WriterMTU() int { - return maxPlaintext -} - -func (c *Conn) Upstream() any { - return c.Conn -} - -func (c *Conn) UpstreamWriter() any { - return c.NetConn() -} diff --git a/common/badtls/badtls_stub.go b/common/badtls/badtls_stub.go deleted file mode 100644 index 2f0028f6..00000000 --- a/common/badtls/badtls_stub.go +++ /dev/null @@ -1,14 +0,0 @@ -//go:build !go1.19 || go1.21 - -package badtls - -import ( - "crypto/tls" - "os" - - aTLS "github.com/sagernet/sing/common/tls" -) - -func Create(conn *tls.Conn) (aTLS.Conn, error) { - return nil, os.ErrInvalid -} diff --git a/common/badtls/link.go b/common/badtls/link.go deleted file mode 100644 index b8d5f4bd..00000000 --- a/common/badtls/link.go +++ /dev/null @@ -1,22 +0,0 @@ -//go:build go1.20 && !go.1.21 - -package badtls - -import ( - "reflect" - _ "unsafe" -) - -const ( - maxPlaintext = 16384 // maximum plaintext payload length - recordHeaderLen = 5 // record header length -) - -//go:linkname errShutdown crypto/tls.errShutdown -var errShutdown error - -//go:linkname incSeq crypto/tls.(*halfConn).incSeq -func incSeq(conn uintptr) - -//go:linkname valueInterface reflect.valueInterface -func valueInterface(v reflect.Value, safe bool) any diff --git a/common/badtls/read_wait.go b/common/badtls/read_wait.go new file mode 100644 index 00000000..4657bc5b --- /dev/null +++ b/common/badtls/read_wait.go @@ -0,0 +1,115 @@ +//go:build go1.21 && !without_badtls + +package badtls + +import ( + "bytes" + "os" + "reflect" + "sync" + "unsafe" + + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/tls" +) + +var _ N.ReadWaiter = (*ReadWaitConn)(nil) + +type ReadWaitConn struct { + *tls.STDConn + halfAccess *sync.Mutex + rawInput *bytes.Buffer + input *bytes.Reader + hand *bytes.Buffer + readWaitOptions N.ReadWaitOptions +} + +func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) { + stdConn, isSTDConn := conn.(*tls.STDConn) + if !isSTDConn { + return nil, os.ErrInvalid + } + rawConn := reflect.Indirect(reflect.ValueOf(stdConn)) + rawHalfConn := rawConn.FieldByName("in") + if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct { + return nil, E.New("badtls: invalid half conn") + } + rawHalfMutex := rawHalfConn.FieldByName("Mutex") + if !rawHalfMutex.IsValid() || rawHalfMutex.Kind() != reflect.Struct { + return nil, E.New("badtls: invalid half mutex") + } + halfAccess := (*sync.Mutex)(unsafe.Pointer(rawHalfMutex.UnsafeAddr())) + rawRawInput := rawConn.FieldByName("rawInput") + if !rawRawInput.IsValid() || rawRawInput.Kind() != reflect.Struct { + return nil, E.New("badtls: invalid raw input") + } + rawInput := (*bytes.Buffer)(unsafe.Pointer(rawRawInput.UnsafeAddr())) + rawInput0 := rawConn.FieldByName("input") + if !rawInput0.IsValid() || rawInput0.Kind() != reflect.Struct { + return nil, E.New("badtls: invalid input") + } + input := (*bytes.Reader)(unsafe.Pointer(rawInput0.UnsafeAddr())) + rawHand := rawConn.FieldByName("hand") + if !rawHand.IsValid() || rawHand.Kind() != reflect.Struct { + return nil, E.New("badtls: invalid hand") + } + hand := (*bytes.Buffer)(unsafe.Pointer(rawHand.UnsafeAddr())) + return &ReadWaitConn{ + STDConn: stdConn, + halfAccess: halfAccess, + rawInput: rawInput, + input: input, + hand: hand, + }, nil +} + +func (c *ReadWaitConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + c.readWaitOptions = options + return false +} + +func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) { + err = c.Handshake() + if err != nil { + return + } + c.halfAccess.Lock() + defer c.halfAccess.Unlock() + for c.input.Len() == 0 { + err = tlsReadRecord(c.STDConn) + if err != nil { + return + } + for c.hand.Len() > 0 { + err = tlsHandlePostHandshakeMessage(c.STDConn) + if err != nil { + return + } + } + } + buffer = c.readWaitOptions.NewBuffer() + n, err := c.input.Read(buffer.FreeBytes()) + if err != nil { + buffer.Release() + return + } + buffer.Truncate(n) + + if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 && + // recordType(c.rawInput.Bytes()[0]) == recordTypeAlert { + c.rawInput.Bytes()[0] == 21 { + _ = tlsReadRecord(c.STDConn) + // return n, err // will be io.EOF on closeNotify + } + + c.readWaitOptions.PostReturn(buffer) + return +} + +//go:linkname tlsReadRecord crypto/tls.(*Conn).readRecord +func tlsReadRecord(c *tls.STDConn) error + +//go:linkname tlsHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage +func tlsHandlePostHandshakeMessage(c *tls.STDConn) error diff --git a/common/badtls/read_wait_stub.go b/common/badtls/read_wait_stub.go new file mode 100644 index 00000000..c5c9946f --- /dev/null +++ b/common/badtls/read_wait_stub.go @@ -0,0 +1,13 @@ +//go:build !go1.21 || without_badtls + +package badtls + +import ( + "os" + + "github.com/sagernet/sing/common/tls" +) + +func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) { + return nil, os.ErrInvalid +} diff --git a/common/tls/client.go b/common/tls/client.go index d1c9475a..4d6b0c54 100644 --- a/common/tls/client.go +++ b/common/tls/client.go @@ -6,6 +6,7 @@ import ( "os" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/badtls" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" M "github.com/sagernet/sing/common/metadata" @@ -42,7 +43,17 @@ func NewClient(ctx context.Context, serverAddress string, options option.Outboun func ClientHandshake(ctx context.Context, conn net.Conn, config Config) (Conn, error) { ctx, cancel := context.WithTimeout(ctx, C.TCPTimeout) defer cancel() - return aTLS.ClientHandshake(ctx, conn, config) + tlsConn, err := aTLS.ClientHandshake(ctx, conn, config) + if err != nil { + return nil, err + } + readWaitConn, err := badtls.NewReadWaitConn(tlsConn) + if err == nil { + return readWaitConn, nil + } else if err != os.ErrInvalid { + return nil, err + } + return tlsConn, nil } type Dialer struct { diff --git a/common/tls/server.go b/common/tls/server.go index ac6d0a2e..6afd89d6 100644 --- a/common/tls/server.go +++ b/common/tls/server.go @@ -3,7 +3,9 @@ package tls import ( "context" "net" + "os" + "github.com/sagernet/sing-box/common/badtls" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" @@ -26,5 +28,15 @@ func NewServer(ctx context.Context, logger log.Logger, options option.InboundTLS func ServerHandshake(ctx context.Context, conn net.Conn, config ServerConfig) (Conn, error) { ctx, cancel := context.WithTimeout(ctx, C.TCPTimeout) defer cancel() - return aTLS.ServerHandshake(ctx, conn, config) + tlsConn, err := aTLS.ServerHandshake(ctx, conn, config) + if err != nil { + return nil, err + } + readWaitConn, err := badtls.NewReadWaitConn(tlsConn) + if err == nil { + return readWaitConn, nil + } else if err != os.ErrInvalid { + return nil, err + } + return tlsConn, nil } diff --git a/go.mod b/go.mod index 90d9d92c..0d83725b 100644 --- a/go.mod +++ b/go.mod @@ -26,14 +26,14 @@ require ( github.com/sagernet/gvisor v0.0.0-20231119034329-07cfb6aaf930 github.com/sagernet/quic-go v0.40.0 github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 - github.com/sagernet/sing v0.2.19-0.20231208031707-0830d1517da8 + github.com/sagernet/sing v0.2.19-0.20231208031707-03f22280298b github.com/sagernet/sing-dns v0.1.11 - github.com/sagernet/sing-mux v0.1.5 - github.com/sagernet/sing-quic v0.1.5 + github.com/sagernet/sing-mux v0.1.6-0.20231207143704-9f6c20fb5266 + github.com/sagernet/sing-quic v0.1.6-0.20231207143711-eb3cbf9ed054 github.com/sagernet/sing-shadowsocks v0.2.5 - github.com/sagernet/sing-shadowsocks2 v0.1.5 + github.com/sagernet/sing-shadowsocks2 v0.1.6-0.20231207143709-50439739601a github.com/sagernet/sing-shadowtls v0.1.4 - github.com/sagernet/sing-tun v0.1.22 + github.com/sagernet/sing-tun v0.1.23-0.20231207143707-82a810316e14 github.com/sagernet/sing-vmess v0.1.8 github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 github.com/sagernet/tfo-go v0.0.0-20230816093905-5a5c285d44a6 diff --git a/go.sum b/go.sum index 93524512..a00f2d0a 100644 --- a/go.sum +++ b/go.sum @@ -110,22 +110,22 @@ github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 h1:5Th31OC6yj8byL github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691/go.mod h1:B8lp4WkQ1PwNnrVMM6KyuFR20pU8jYBD+A4EhJovEXU= github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY= github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk= -github.com/sagernet/sing v0.2.19-0.20231208031707-0830d1517da8 h1:w9gxEZISgkLf1VCySAu4ao7Ptgbkjl3t5JosaDAhqRE= -github.com/sagernet/sing v0.2.19-0.20231208031707-0830d1517da8/go.mod h1:Ce5LNojQOgOiWhiD8pPD6E9H7e2KgtOe3Zxx4Ou5u80= +github.com/sagernet/sing v0.2.19-0.20231208031707-03f22280298b h1:UlcBAGEJ2MgtmAyFLRm1q6cXazMvIXBa5kKAAMwVWDo= +github.com/sagernet/sing v0.2.19-0.20231208031707-03f22280298b/go.mod h1:Ce5LNojQOgOiWhiD8pPD6E9H7e2KgtOe3Zxx4Ou5u80= github.com/sagernet/sing-dns v0.1.11 h1:PPrMCVVrAeR3f5X23I+cmvacXJ+kzuyAsBiWyUKhGSE= github.com/sagernet/sing-dns v0.1.11/go.mod h1:zJ/YjnYB61SYE+ubMcMqVdpaSvsyQ2iShQGO3vuLvvE= -github.com/sagernet/sing-mux v0.1.5 h1:jUbYth9QQd1wsDmU8Ush+fKce7lNo9TMv2dp8PJtSOY= -github.com/sagernet/sing-mux v0.1.5/go.mod h1:MoH6Soz1R+CYZcCeIXZWx6fkZa6hQc9o3HZu9G6CDTw= -github.com/sagernet/sing-quic v0.1.5 h1:PIQzE4cGrry+JkkMEJH/EH3wRkv/QgD48+ScNr/2oig= -github.com/sagernet/sing-quic v0.1.5/go.mod h1:n2mXukpubasyV4SlWyyW0+LCdAn7DZ8/brAkUxZujrw= +github.com/sagernet/sing-mux v0.1.6-0.20231207143704-9f6c20fb5266 h1:QqwwUyEfmOuoGVTZ2cYvUJEeSWlzunvQLRmv+9B41uk= +github.com/sagernet/sing-mux v0.1.6-0.20231207143704-9f6c20fb5266/go.mod h1:uxpcXa8JqSR+ufC1sGAPsCs027wpE7v1ltnhuJKqyBQ= +github.com/sagernet/sing-quic v0.1.6-0.20231207143711-eb3cbf9ed054 h1:Ed7FskwQcep5oQ+QahgVK0F6jPPSV8Nqwjr9MwGatMU= +github.com/sagernet/sing-quic v0.1.6-0.20231207143711-eb3cbf9ed054/go.mod h1:u758WWv3G1OITG365CYblL0NfAruFL1PpLD9DUVTv1o= github.com/sagernet/sing-shadowsocks v0.2.5 h1:qxIttos4xu6ii7MTVJYA8EFQR7Q3KG6xMqmLJIFtBaY= github.com/sagernet/sing-shadowsocks v0.2.5/go.mod h1:MGWGkcU2xW2G2mfArT9/QqpVLOGU+dBaahZCtPHdt7A= -github.com/sagernet/sing-shadowsocks2 v0.1.5 h1:JDeAJ4ZWlYZ7F6qEVdDKPhQEangxKw/JtmU+i/YfCYE= -github.com/sagernet/sing-shadowsocks2 v0.1.5/go.mod h1:KF65y8lI5PGHyMgRZGYXYsH9ilgRc/yr+NYbSNGuBm4= +github.com/sagernet/sing-shadowsocks2 v0.1.6-0.20231207143709-50439739601a h1:uYIKfpE1/EJpa+1Bja7b006VixeRuVduOpeuesMk2lU= +github.com/sagernet/sing-shadowsocks2 v0.1.6-0.20231207143709-50439739601a/go.mod h1:pjeylQ4ApvpEH7B4PUBrdyJf4xmQkg8BaIzT5fI2fR0= github.com/sagernet/sing-shadowtls v0.1.4 h1:aTgBSJEgnumzFenPvc+kbD9/W0PywzWevnVpEx6Tw3k= github.com/sagernet/sing-shadowtls v0.1.4/go.mod h1:F8NBgsY5YN2beQavdgdm1DPlhaKQlaL6lpDdcBglGK4= -github.com/sagernet/sing-tun v0.1.22 h1:AECJTkiugCK+GCrV41YZ56HB/Z/lDXZvRVas4fNvO30= -github.com/sagernet/sing-tun v0.1.22/go.mod h1:fliIEXDRv2u1uT3uCZIoA1daoZcD4f6TeIuzNIzlsN8= +github.com/sagernet/sing-tun v0.1.23-0.20231207143707-82a810316e14 h1:79d3jw/nlhy3VAIoRvMxRjcOUh7e0D8Mx0cuaBrdIC4= +github.com/sagernet/sing-tun v0.1.23-0.20231207143707-82a810316e14/go.mod h1:ygdUHhVv4ZEsu0+4rAbAAoHqzqrhvhVNxrbMryapDwI= github.com/sagernet/sing-vmess v0.1.8 h1:XVWad1RpTy9b5tPxdm5MCU8cGfrTGdR8qCq6HV2aCNc= github.com/sagernet/sing-vmess v0.1.8/go.mod h1:vhx32UNzTDUkNwOyIjcZQohre1CaytquC5mPplId8uA= github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 h1:HuE6xSwco/Xed8ajZ+coeYLmioq0Qp1/Z2zczFaV8as= diff --git a/outbound/dns.go b/outbound/dns.go index 74adb3ae..fcb67d45 100644 --- a/outbound/dns.go +++ b/outbound/dns.go @@ -111,6 +111,9 @@ func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metada } } if readWaiter, created := bufio.CreatePacketReadWaiter(reader); created { + readWaiter.InitializeReadWaiter(N.ReadWaitOptions{ + MTU: dns.FixedPacketSize, + }) return d.newPacketConnection(ctx, conn, readWaiter, counters, cachedPackets, metadata) } break @@ -193,15 +196,13 @@ func (d *DNS) newPacketConnection(ctx context.Context, conn N.PacketConn, readWa timeout := canceler.New(fastClose, cancel, C.DNSTimeout) var group task.Group group.Append0(func(ctx context.Context) error { - var buffer *buf.Buffer - readWaiter.InitializeReadWaiter(func() *buf.Buffer { - return buf.NewSize(dns.FixedPacketSize) - }) - defer readWaiter.InitializeReadWaiter(nil) for { - var message mDNS.Msg - var destination M.Socksaddr - var err error + var ( + message mDNS.Msg + destination M.Socksaddr + err error + buffer *buf.Buffer + ) if len(cached) > 0 { packet := cached[0] cached = cached[1:] @@ -216,9 +217,8 @@ func (d *DNS) newPacketConnection(ctx context.Context, conn N.PacketConn, readWa } destination = packet.Destination } else { - destination, err = readWaiter.WaitReadPacket() + buffer, destination, err = readWaiter.WaitReadPacket() if err != nil { - buffer.Release() cancel(err) return err } diff --git a/transport/fakeip/packet_wait.go b/transport/fakeip/packet_wait.go index 3e3fd89f..9fa4a5bd 100644 --- a/transport/fakeip/packet_wait.go +++ b/transport/fakeip/packet_wait.go @@ -17,16 +17,16 @@ func (c *NATPacketConn) CreatePacketReadWaiter() (N.PacketReadWaiter, bool) { type waitNATPacketConn struct { *NATPacketConn - waiter N.PacketReadWaiter + readWaiter N.PacketReadWaiter } -func (c *waitNATPacketConn) InitializeReadWaiter(newBuffer func() *buf.Buffer) { - c.waiter.InitializeReadWaiter(newBuffer) +func (c *waitNATPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + return c.readWaiter.InitializeReadWaiter(options) } -func (c *waitNATPacketConn) WaitReadPacket() (destination M.Socksaddr, err error) { - destination, err = c.waiter.WaitReadPacket() - if socksaddrWithoutPort(destination) == c.origin { +func (c *waitNATPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + buffer, destination, err = c.readWaiter.WaitReadPacket() + if err == nil && socksaddrWithoutPort(destination) == c.origin { destination = M.Socksaddr{ Addr: c.destination.Addr, Fqdn: c.destination.Fqdn, diff --git a/transport/trojan/mux.go b/transport/trojan/mux.go index 77324000..13ac1e83 100644 --- a/transport/trojan/mux.go +++ b/transport/trojan/mux.go @@ -53,7 +53,7 @@ func newMuxConnection0(ctx context.Context, stream net.Conn, metadata M.Metadata case CommandTCP: return handler.NewConnection(ctx, stream, metadata) case CommandUDP: - return handler.NewPacketConnection(ctx, &PacketConn{stream}, metadata) + return handler.NewPacketConnection(ctx, &PacketConn{Conn: stream}, metadata) default: return E.New("unknown command ", command) } diff --git a/transport/trojan/protocol.go b/transport/trojan/protocol.go index 09e18782..394ba291 100644 --- a/transport/trojan/protocol.go +++ b/transport/trojan/protocol.go @@ -85,9 +85,10 @@ func (c *ClientConn) Upstream() any { type ClientPacketConn struct { net.Conn - access sync.Mutex - key [KeyLength]byte - headerWritten bool + access sync.Mutex + key [KeyLength]byte + headerWritten bool + readWaitOptions N.ReadWaitOptions } func NewClientPacketConn(conn net.Conn, key [KeyLength]byte) *ClientPacketConn { diff --git a/transport/trojan/protocol_wait.go b/transport/trojan/protocol_wait.go new file mode 100644 index 00000000..c6b4ec06 --- /dev/null +++ b/transport/trojan/protocol_wait.go @@ -0,0 +1,45 @@ +package trojan + +import ( + "encoding/binary" + + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/rw" +) + +var _ N.PacketReadWaiter = (*ClientPacketConn)(nil) + +func (c *ClientPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + c.readWaitOptions = options + return false +} + +func (c *ClientPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + destination, err = M.SocksaddrSerializer.ReadAddrPort(c.Conn) + if err != nil { + return nil, M.Socksaddr{}, E.Cause(err, "read destination") + } + + var length uint16 + err = binary.Read(c.Conn, binary.BigEndian, &length) + if err != nil { + return nil, M.Socksaddr{}, E.Cause(err, "read chunk length") + } + + err = rw.SkipN(c.Conn, 2) + if err != nil { + return nil, M.Socksaddr{}, E.Cause(err, "skip crlf") + } + + buffer = c.readWaitOptions.NewPacketBuffer() + _, err = buffer.ReadFullFrom(c.Conn, int(length)) + if err != nil { + buffer.Release() + return + } + c.readWaitOptions.PostReturn(buffer) + return +} diff --git a/transport/trojan/service.go b/transport/trojan/service.go index de6bd7e8..9078276c 100644 --- a/transport/trojan/service.go +++ b/transport/trojan/service.go @@ -105,7 +105,7 @@ func (s *Service[K]) NewConnection(ctx context.Context, conn net.Conn, metadata case CommandTCP: return s.handler.NewConnection(ctx, conn, metadata) case CommandUDP: - return s.handler.NewPacketConnection(ctx, &PacketConn{conn}, metadata) + return s.handler.NewPacketConnection(ctx, &PacketConn{Conn: conn}, metadata) // case CommandMux: default: return HandleMuxConnection(ctx, conn, metadata, s.handler) @@ -122,6 +122,7 @@ func (s *Service[K]) fallback(ctx context.Context, conn net.Conn, metadata M.Met type PacketConn struct { net.Conn + readWaitOptions N.ReadWaitOptions } func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { diff --git a/transport/trojan/service_wait.go b/transport/trojan/service_wait.go new file mode 100644 index 00000000..5ec082fe --- /dev/null +++ b/transport/trojan/service_wait.go @@ -0,0 +1,45 @@ +package trojan + +import ( + "encoding/binary" + + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/rw" +) + +var _ N.PacketReadWaiter = (*PacketConn)(nil) + +func (c *PacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + c.readWaitOptions = options + return false +} + +func (c *PacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + destination, err = M.SocksaddrSerializer.ReadAddrPort(c.Conn) + if err != nil { + return nil, M.Socksaddr{}, E.Cause(err, "read destination") + } + + var length uint16 + err = binary.Read(c.Conn, binary.BigEndian, &length) + if err != nil { + return nil, M.Socksaddr{}, E.Cause(err, "read chunk length") + } + + err = rw.SkipN(c.Conn, 2) + if err != nil { + return nil, M.Socksaddr{}, E.Cause(err, "skip crlf") + } + + buffer = c.readWaitOptions.NewPacketBuffer() + _, err = buffer.ReadFullFrom(c.Conn, int(length)) + if err != nil { + buffer.Release() + return + } + c.readWaitOptions.PostReturn(buffer) + return +} diff --git a/transport/wireguard/client_bind.go b/transport/wireguard/client_bind.go index 2b56f73a..a72432d3 100644 --- a/transport/wireguard/client_bind.go +++ b/transport/wireguard/client_bind.go @@ -76,11 +76,8 @@ func (c *ClientBind) connect() (*wireConn, error) { return nil, err } c.conn = &wireConn{ - PacketConn: &bufio.UnbindPacketConn{ - ExtendedConn: bufio.NewExtendedConn(udpConn), - Addr: c.connectAddr, - }, - done: make(chan struct{}), + PacketConn: bufio.NewUnbindPacketConn(udpConn), + done: make(chan struct{}), } } else { udpConn, err := c.dialer.ListenPacket(c.ctx, M.Socksaddr{Addr: netip.IPv4Unspecified()})