From 53b15cca1b133bde04ca49e217365fd09d2ba4f7 Mon Sep 17 00:00:00 2001 From: arm64v8a <48624112+arm64v8a@users.noreply.github.com> Date: Wed, 7 Sep 2022 08:00:00 +0800 Subject: [PATCH] Fix naive padding --- inbound/naive.go | 334 +++++++++-------------------------------------- 1 file changed, 65 insertions(+), 269 deletions(-) diff --git a/inbound/naive.go b/inbound/naive.go index b082d5a9..1ba8ec44 100644 --- a/inbound/naive.go +++ b/inbound/naive.go @@ -19,8 +19,6 @@ import ( "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/auth" - "github.com/sagernet/sing/common/buf" - "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -160,14 +158,14 @@ func (n *Naive) ServeHTTP(writer http.ResponseWriter, request *http.Request) { destination := M.ParseSocksaddr(hostPort) if hijacker, isHijacker := writer.(http.Hijacker); isHijacker { - conn, _, err := hijacker.Hijack() + clientConn, _, err := hijacker.Hijack() if err != nil { n.badRequest(ctx, request, E.New("hijack failed")) return } - n.newConnection(ctx, &naiveH1Conn{Conn: conn}, source, destination) + n.newConnection(ctx, &naivePaddingConn{reader: clientConn, writer: clientConn}, source, destination) } else { - n.newConnection(ctx, &naiveH2Conn{reader: request.Body, writer: writer, flusher: writer.(http.Flusher)}, source, destination) + n.newConnection(ctx, &naivePaddingConn{reader: request.Body, writer: writer, flusher: writer.(http.Flusher)}, source, destination) } } @@ -216,191 +214,24 @@ func generateNaivePaddingHeader() string { const kFirstPaddings = 8 -type naiveH1Conn struct { - net.Conn - readPadding int - writePadding int +type naivePaddingConn struct { + reader io.Reader + writer io.Writer + flusher http.Flusher + rAddr net.Addr + readPadding int + writePadding int + readRemaining int paddingRemaining int } -func (c *naiveH1Conn) Read(p []byte) (n int, err error) { +func (c *naivePaddingConn) Read(p []byte) (n int, err error) { n, err = c.read(p) return n, wrapHttpError(err) } -func (c *naiveH1Conn) read(p []byte) (n int, err error) { - if c.readRemaining > 0 { - if len(p) > c.readRemaining { - p = p[:c.readRemaining] - } - n, err = c.Conn.Read(p) - if err != nil { - return - } - c.readRemaining -= n - return - } - if c.paddingRemaining > 0 { - err = rw.SkipN(c.Conn, c.paddingRemaining) - if err != nil { - return - } - c.paddingRemaining = 0 - } - if c.readPadding < kFirstPaddings { - paddingHdr := p[:3] - _, err = io.ReadFull(c.Conn, paddingHdr) - if err != nil { - return - } - originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2])) - paddingSize := int(paddingHdr[2]) - if len(p) > originalDataSize { - p = p[:originalDataSize] - } - n, err = c.Conn.Read(p) - if err != nil { - return - } - c.readPadding++ - c.readRemaining = originalDataSize - n - c.paddingRemaining = paddingSize - return - } - return c.Conn.Read(p) -} - -func (c *naiveH1Conn) Write(p []byte) (n int, err error) { - for pLen := len(p); pLen > 0; { - var data []byte - if pLen > 65535 { - data = p[:65535] - p = p[65535:] - pLen -= 65535 - } else { - data = p - pLen = 0 - } - var writeN int - writeN, err = c.write(data) - n += writeN - if err != nil { - break - } - } - return n, wrapHttpError(err) -} - -func (c *naiveH1Conn) write(p []byte) (n int, err error) { - if c.writePadding < kFirstPaddings { - paddingSize := rand.Intn(256) - - _buffer := buf.StackNewSize(3 + len(p) + paddingSize) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) - defer buffer.Release() - header := buffer.Extend(3) - binary.BigEndian.PutUint16(header, uint16(len(p))) - header[2] = byte(paddingSize) - - common.Must1(buffer.Write(p)) - _, err = c.Conn.Write(buffer.Bytes()) - if err == nil { - n = len(p) - } - c.writePadding++ - return - } - return c.Conn.Write(p) -} - -func (c *naiveH1Conn) FrontHeadroom() int { - if c.writePadding < kFirstPaddings { - return 3 - } - return 0 -} - -func (c *naiveH1Conn) RearHeadroom() int { - if c.writePadding < kFirstPaddings { - return 255 - } - return 0 -} - -func (c *naiveH1Conn) WriterMTU() int { - if c.writePadding < kFirstPaddings { - return 65535 - } - return 0 -} - -func (c *naiveH1Conn) WriteBuffer(buffer *buf.Buffer) error { - defer buffer.Release() - if c.writePadding < kFirstPaddings { - bufferLen := buffer.Len() - if bufferLen > 65535 { - return common.Error(c.Write(buffer.Bytes())) - } - paddingSize := rand.Intn(256) - header := buffer.ExtendHeader(3) - binary.BigEndian.PutUint16(header, uint16(bufferLen)) - header[2] = byte(paddingSize) - buffer.Extend(paddingSize) - c.writePadding++ - } - return wrapHttpError(common.Error(c.Conn.Write(buffer.Bytes()))) -} - -// FIXME -/*func (c *naiveH1Conn) WriteTo(w io.Writer) (n int64, err error) { - if c.readPadding < kFirstPaddings { - n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding) - } else { - n, err = bufio.Copy(w, c.Conn) - } - return n, wrapHttpError(err) -}*/ - -func (c *naiveH1Conn) ReadFrom(r io.Reader) (n int64, err error) { - if c.writePadding < kFirstPaddings { - n, err = bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding) - } else { - n, err = bufio.Copy(c.Conn, r) - } - return n, wrapHttpError(err) -} - -func (c *naiveH1Conn) Upstream() any { - return c.Conn -} - -func (c *naiveH1Conn) ReaderReplaceable() bool { - return c.readRemaining == kFirstPaddings -} - -func (c *naiveH1Conn) WriterReplaceable() bool { - return c.writePadding == kFirstPaddings -} - -type naiveH2Conn struct { - reader io.Reader - writer io.Writer - flusher http.Flusher - rAddr net.Addr - readPadding int - writePadding int - readRemaining int - paddingRemaining int -} - -func (c *naiveH2Conn) Read(p []byte) (n int, err error) { - n, err = c.read(p) - return n, wrapHttpError(err) -} - -func (c *naiveH2Conn) read(p []byte) (n int, err error) { +func (c *naivePaddingConn) read(p []byte) (n int, err error) { if c.readRemaining > 0 { if len(p) > c.readRemaining { p = p[:c.readRemaining] @@ -420,29 +251,34 @@ func (c *naiveH2Conn) read(p []byte) (n int, err error) { c.paddingRemaining = 0 } if c.readPadding < kFirstPaddings { - paddingHdr := p[:3] - _, err = io.ReadFull(c.reader, paddingHdr) - if err != nil { - return - } - originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2])) - paddingSize := int(paddingHdr[2]) - if len(p) > originalDataSize { - p = p[:originalDataSize] - } - n, err = c.reader.Read(p) - if err != nil { - return - } c.readPadding++ - c.readRemaining = originalDataSize - n - c.paddingRemaining = paddingSize - return + nr, err := io.ReadFull(c.reader, p[0:3]) + if nr > 0 { + nr = int(p[0])*256 + int(p[1]) + paddingSize := int(p[2]) + + if nr > len(p) { + c.readRemaining = nr - len(p) + c.paddingRemaining = paddingSize + nr = len(p) + paddingSize = 0 + } + + nr, err = io.ReadFull(c.reader, p[0:nr]) + if nr > 0 && paddingSize > 0 { + var junk [256]byte + _, err = io.ReadFull(c.reader, junk[0:paddingSize]) + } + } + if err != nil { + return 0, err + } + return nr, nil } return c.reader.Read(p) } -func (c *naiveH2Conn) Write(p []byte) (n int, err error) { +func (c *naivePaddingConn) Write(p []byte) (n int, err error) { for pLen := len(p); pLen > 0; { var data []byte if pLen > 65535 { @@ -460,136 +296,96 @@ func (c *naiveH2Conn) Write(p []byte) (n int, err error) { break } } - if err == nil { + if err == nil && c.flusher != nil { c.flusher.Flush() } return n, wrapHttpError(err) } -func (c *naiveH2Conn) write(p []byte) (n int, err error) { +func (c *naivePaddingConn) write(p []byte) (n int, err error) { if c.writePadding < kFirstPaddings { + c.writePadding++ paddingSize := rand.Intn(256) - _buffer := buf.StackNewSize(3 + len(p) + paddingSize) - defer common.KeepAlive(_buffer) - buffer := common.Dup(_buffer) - defer buffer.Release() - header := buffer.Extend(3) - binary.BigEndian.PutUint16(header, uint16(len(p))) - header[2] = byte(paddingSize) + var hdr [3]byte + binary.BigEndian.PutUint16(hdr[0:2], uint16(len(p))) + hdr[2] = byte(paddingSize) + p = append(hdr[:], p...) - common.Must1(buffer.Write(p)) - _, err = c.writer.Write(buffer.Bytes()) - if err == nil { - n = len(p) + junk := make([]byte, paddingSize) + p = append(p, junk...) + + _, err = c.writer.Write(p) + + if err != nil { + return 0, err } - c.writePadding++ - return + return len(p), nil } return c.writer.Write(p) } -func (c *naiveH2Conn) FrontHeadroom() int { +func (c *naivePaddingConn) FrontHeadroom() int { if c.writePadding < kFirstPaddings { return 3 } return 0 } -func (c *naiveH2Conn) RearHeadroom() int { +func (c *naivePaddingConn) RearHeadroom() int { if c.writePadding < kFirstPaddings { return 255 } return 0 } -func (c *naiveH2Conn) WriterMTU() int { +func (c *naivePaddingConn) WriterMTU() int { if c.writePadding < kFirstPaddings { return 65535 } return 0 } -func (c *naiveH2Conn) WriteBuffer(buffer *buf.Buffer) error { - defer buffer.Release() - if c.writePadding < kFirstPaddings { - bufferLen := buffer.Len() - if bufferLen > 65535 { - return common.Error(c.Write(buffer.Bytes())) - } - paddingSize := rand.Intn(256) - header := buffer.ExtendHeader(3) - binary.BigEndian.PutUint16(header, uint16(bufferLen)) - header[2] = byte(paddingSize) - buffer.Extend(paddingSize) - c.writePadding++ - } - err := common.Error(c.writer.Write(buffer.Bytes())) - if err == nil { - c.flusher.Flush() - } - return wrapHttpError(err) -} - -// FIXME -/*func (c *naiveH2Conn) WriteTo(w io.Writer) (n int64, err error) { - if c.readPadding < kFirstPaddings { - n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding) - } else { - n, err = bufio.Copy(w, c.reader) - } - return n, wrapHttpError(err) -}*/ - -func (c *naiveH2Conn) ReadFrom(r io.Reader) (n int64, err error) { - if c.writePadding < kFirstPaddings { - n, err = bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding) - } else { - n, err = bufio.Copy(c.writer, r) - } - return n, wrapHttpError(err) -} - -func (c *naiveH2Conn) Close() error { +func (c *naivePaddingConn) Close() error { return common.Close( c.reader, c.writer, ) } -func (c *naiveH2Conn) LocalAddr() net.Addr { +func (c *naivePaddingConn) LocalAddr() net.Addr { return nil } -func (c *naiveH2Conn) RemoteAddr() net.Addr { +func (c *naivePaddingConn) RemoteAddr() net.Addr { return c.rAddr } -func (c *naiveH2Conn) SetDeadline(t time.Time) error { +func (c *naivePaddingConn) SetDeadline(t time.Time) error { return os.ErrInvalid } -func (c *naiveH2Conn) SetReadDeadline(t time.Time) error { +func (c *naivePaddingConn) SetReadDeadline(t time.Time) error { return os.ErrInvalid } -func (c *naiveH2Conn) SetWriteDeadline(t time.Time) error { +func (c *naivePaddingConn) SetWriteDeadline(t time.Time) error { return os.ErrInvalid } -func (c *naiveH2Conn) UpstreamReader() any { +func (c *naivePaddingConn) UpstreamReader() any { return c.reader } -func (c *naiveH2Conn) UpstreamWriter() any { +func (c *naivePaddingConn) UpstreamWriter() any { return c.writer } -func (c *naiveH2Conn) ReaderReplaceable() bool { - return c.readRemaining == kFirstPaddings +func (c *naivePaddingConn) ReaderReplaceable() bool { + return c.readPadding == kFirstPaddings } -func (c *naiveH2Conn) WriterReplaceable() bool { +func (c *naivePaddingConn) WriterReplaceable() bool { return c.writePadding == kFirstPaddings }