From 7301a2e69ae453e00cc7742dea394f8bdf1b4e91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 27 Aug 2022 16:40:00 +0800 Subject: [PATCH] Improve grpc lite conn --- transport/v2raygrpclite/conn.go | 97 ++++++++++++++++++------------- transport/v2raygrpclite/server.go | 2 +- 2 files changed, 58 insertions(+), 41 deletions(-) diff --git a/transport/v2raygrpclite/conn.go b/transport/v2raygrpclite/conn.go index db89bef3..97366ef9 100644 --- a/transport/v2raygrpclite/conn.go +++ b/transport/v2raygrpclite/conn.go @@ -4,6 +4,7 @@ package v2raygrpclite import ( + std_bufio "bufio" "bytes" "encoding/binary" "io" @@ -16,6 +17,7 @@ import ( "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/rw" ) var ErrInvalidLength = E.New("invalid length") @@ -23,18 +25,19 @@ var ErrInvalidLength = E.New("invalid length") var _ net.Conn = (*GunConn)(nil) type GunConn struct { - reader io.Reader - writer io.Writer - create chan struct{} - err error - cached []byte - cachedIndex int + reader *std_bufio.Reader + writer io.Writer + flusher http.Flusher + create chan struct{} + err error + readRemaining int } -func newGunConn(reader io.Reader, writer io.Writer) *GunConn { +func newGunConn(reader io.Reader, writer io.Writer, flusher http.Flusher) *GunConn { return &GunConn{ - reader: reader, - writer: writer, + reader: std_bufio.NewReader(reader), + writer: writer, + flusher: flusher, } } @@ -46,7 +49,7 @@ func newLateGunConn(writer io.Writer) *GunConn { } func (c *GunConn) setup(reader io.Reader, err error) { - c.reader = reader + c.reader = std_bufio.NewReader(reader) c.err = err close(c.create) } @@ -59,42 +62,34 @@ func (c *GunConn) Read(b []byte) (n int, err error) { } } - if c.cached != nil { - n = copy(b, c.cached[c.cachedIndex:]) - c.cachedIndex += n - if c.cachedIndex == len(c.cached) { - buf.Put(c.cached) - c.cached = nil + if c.readRemaining > 0 { + if len(b) > c.readRemaining { + b = b[:c.readRemaining] } + n, err = c.reader.Read(b) + c.readRemaining -= n return } - buffer := buf.Get(5) - _, err = io.ReadFull(c.reader, buffer) - if err != nil { - return 0, err - } - grpcPayloadLen := binary.BigEndian.Uint32(buffer[1:]) - buf.Put(buffer) - buffer = buf.Get(int(grpcPayloadLen)) - _, err = io.ReadFull(c.reader, buffer) + _, err = c.reader.Discard(6) if err != nil { - return 0, io.ErrUnexpectedEOF + return } - protobufPayloadLen, protobufLengthLen := binary.Uvarint(buffer[1:]) - if protobufLengthLen == 0 { - return 0, ErrInvalidLength + + dataLen, err := binary.ReadUvarint(c.reader) + if err != nil { + return } - if grpcPayloadLen != uint32(protobufPayloadLen)+uint32(protobufLengthLen)+1 { - return 0, ErrInvalidLength + + readLen := int(dataLen) + c.readRemaining = readLen + if len(b) > readLen { + b = b[:readLen] } - n = copy(b, buffer[1+protobufLengthLen:]) - if n < int(protobufPayloadLen) { - c.cached = buffer - c.cachedIndex = 1 + int(protobufLengthLen) + n - return n, nil - } - return n, nil + + n, err = c.reader.Read(b) + c.readRemaining -= n + return } func (c *GunConn) Write(b []byte) (n int, err error) { @@ -111,11 +106,33 @@ func (c *GunConn) Write(b []byte) (n int, err error) { return len(b), err } -/*func (c *GunConn) ReadBuffer(buffer *buf.Buffer) error { +func uLen(x uint64) int { + i := 0 + for x >= 0x80 { + x >>= 7 + i++ + } + return i + 1 } func (c *GunConn) WriteBuffer(buffer *buf.Buffer) error { -}*/ + defer buffer.Release() + dataLen := buffer.Len() + varLen := uLen(uint64(dataLen)) + header := buffer.ExtendHeader(6 + varLen) + binary.BigEndian.PutUint32(header[1:5], uint32(1+varLen+dataLen)) + header[5] = 0x0A + binary.PutUvarint(header[6:], uint64(dataLen)) + err := rw.WriteBytes(c.writer, buffer.Bytes()) + if c.flusher != nil { + c.flusher.Flush() + } + return err +} + +func (c *GunConn) FrontHeadroom() int { + return 6 + binary.MaxVarintLen64 +} func (c *GunConn) Close() error { return common.Close(c.reader, c.writer) diff --git a/transport/v2raygrpclite/server.go b/transport/v2raygrpclite/server.go index a792ae19..9257770f 100644 --- a/transport/v2raygrpclite/server.go +++ b/transport/v2raygrpclite/server.go @@ -69,7 +69,7 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) { writer.WriteHeader(http.StatusOK) var metadata M.Metadata metadata.Source = sHttp.SourceAddress(request) - conn := newGunConn(request.Body, writer) + conn := newGunConn(request.Body, writer, writer.(http.Flusher)) s.handler.NewConnection(request.Context(), conn, metadata) }