From ba0e2a850e0fc597d63cc34efff093e747579a6f Mon Sep 17 00:00:00 2001 From: arm64v8a <48624112+arm64v8a@users.noreply.github.com> Date: Wed, 7 Sep 2022 11:22:53 +0800 Subject: [PATCH] Fix naive padding (naiveH1Conn & WriteBuffer) --- inbound/naive.go | 125 +++++++++++++++++++++++++++++++---------------- 1 file changed, 84 insertions(+), 41 deletions(-) diff --git a/inbound/naive.go b/inbound/naive.go index 0b2cb2c7..91b0206a 100644 --- a/inbound/naive.go +++ b/inbound/naive.go @@ -19,6 +19,7 @@ import ( "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/auth" + "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -163,9 +164,9 @@ func (n *Naive) ServeHTTP(writer http.ResponseWriter, request *http.Request) { n.badRequest(ctx, request, E.New("hijack failed")) return } - n.newConnection(ctx, &naivePaddingConn{reader: clientConn, writer: clientConn}, source, destination) + n.newConnection(ctx, &naiveH1Conn{clientConn, &naivePaddingConn{reader: clientConn, writer: clientConn}}, source, destination) } else { - n.newConnection(ctx, &naivePaddingConn{reader: request.Body, writer: writer, flusher: writer.(http.Flusher)}, source, destination) + n.newConnection(ctx, &naiveH2Conn{&naivePaddingConn{reader: request.Body, writer: writer, flusher: writer.(http.Flusher)}}, source, destination) } } @@ -214,11 +215,20 @@ func generateNaivePaddingHeader() string { const kFirstPaddings = 8 +type naiveH1Conn struct { + net.Conn + *naivePaddingConn +} + +type naiveH2Conn struct { + *naivePaddingConn +} + type naivePaddingConn struct { - reader io.Reader - writer io.Writer - flusher http.Flusher - rAddr net.Addr + reader io.Reader + writer io.Writer + flusher http.Flusher + readPadding int writePadding int @@ -323,6 +333,27 @@ func (c *naivePaddingConn) write(p []byte) (n int, err error) { return c.writer.Write(p) } +func (c *naivePaddingConn) WriteBuffer(buffer *buf.Buffer) error { + defer buffer.Release() + if c.writePadding < kFirstPaddings { + bufferLen := buffer.Len() + if bufferLen > 65535 { + return common.Error(c.Write(buffer.Bytes())) + } + paddingSize := rand.Intn(256) + header := buffer.ExtendHeader(3) + binary.BigEndian.PutUint16(header, uint16(bufferLen)) + header[2] = byte(paddingSize) + buffer.Extend(paddingSize) + c.writePadding++ + } + err := common.Error(c.writer.Write(buffer.Bytes())) + if err == nil && c.flusher != nil { + c.flusher.Flush() + } + return wrapHttpError(err) +} + func (c *naivePaddingConn) FrontHeadroom() int { if c.writePadding < kFirstPaddings { return 3 @@ -344,41 +375,6 @@ func (c *naivePaddingConn) WriterMTU() int { return 0 } -func (c *naivePaddingConn) Close() error { - return common.Close( - c.reader, - c.writer, - ) -} - -func (c *naivePaddingConn) LocalAddr() net.Addr { - return nil -} - -func (c *naivePaddingConn) RemoteAddr() net.Addr { - return c.rAddr -} - -func (c *naivePaddingConn) SetDeadline(t time.Time) error { - return os.ErrInvalid -} - -func (c *naivePaddingConn) SetReadDeadline(t time.Time) error { - return os.ErrInvalid -} - -func (c *naivePaddingConn) SetWriteDeadline(t time.Time) error { - return os.ErrInvalid -} - -func (c *naivePaddingConn) UpstreamReader() any { - return c.reader -} - -func (c *naivePaddingConn) UpstreamWriter() any { - return c.writer -} - func (c *naivePaddingConn) ReaderReplaceable() bool { return c.readPadding == kFirstPaddings } @@ -387,6 +383,53 @@ func (c *naivePaddingConn) WriterReplaceable() bool { return c.writePadding == kFirstPaddings } +func (c *naiveH1Conn) Read(p []byte) (n int, err error) { + return c.naivePaddingConn.read(p) +} + +func (c *naiveH1Conn) Write(p []byte) (n int, err error) { + return c.naivePaddingConn.Write(p) +} + +func (c *naiveH1Conn) Upstream() any { + return c.Conn +} + +func (c *naiveH2Conn) Close() error { + return common.Close( + c.reader, + c.writer, + ) +} + +func (c *naiveH2Conn) LocalAddr() net.Addr { + return nil +} + +func (c *naiveH2Conn) RemoteAddr() net.Addr { + return nil +} + +func (c *naiveH2Conn) SetDeadline(t time.Time) error { + return os.ErrInvalid +} + +func (c *naiveH2Conn) SetReadDeadline(t time.Time) error { + return os.ErrInvalid +} + +func (c *naiveH2Conn) SetWriteDeadline(t time.Time) error { + return os.ErrInvalid +} + +func (c *naiveH2Conn) UpstreamReader() any { + return c.reader +} + +func (c *naiveH2Conn) UpstreamWriter() any { + return c.writer +} + func wrapHttpError(err error) error { if err == nil { return err