Fix naive padding

This commit is contained in:
arm64v8a 2022-09-07 08:00:00 +08:00
parent ef013e0639
commit 53b15cca1b

View File

@ -19,8 +19,6 @@ import (
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/auth" "github.com/sagernet/sing/common/auth"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
@ -160,14 +158,14 @@ func (n *Naive) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
destination := M.ParseSocksaddr(hostPort) destination := M.ParseSocksaddr(hostPort)
if hijacker, isHijacker := writer.(http.Hijacker); isHijacker { if hijacker, isHijacker := writer.(http.Hijacker); isHijacker {
conn, _, err := hijacker.Hijack() clientConn, _, err := hijacker.Hijack()
if err != nil { if err != nil {
n.badRequest(ctx, request, E.New("hijack failed")) n.badRequest(ctx, request, E.New("hijack failed"))
return return
} }
n.newConnection(ctx, &naiveH1Conn{Conn: conn}, source, destination) n.newConnection(ctx, &naivePaddingConn{reader: clientConn, writer: clientConn}, source, destination)
} else { } else {
n.newConnection(ctx, &naiveH2Conn{reader: request.Body, writer: writer, flusher: writer.(http.Flusher)}, source, destination) n.newConnection(ctx, &naivePaddingConn{reader: request.Body, writer: writer, flusher: writer.(http.Flusher)}, source, destination)
} }
} }
@ -216,191 +214,24 @@ func generateNaivePaddingHeader() string {
const kFirstPaddings = 8 const kFirstPaddings = 8
type naiveH1Conn struct { type naivePaddingConn struct {
net.Conn reader io.Reader
readPadding int writer io.Writer
writePadding int flusher http.Flusher
rAddr net.Addr
readPadding int
writePadding int
readRemaining int readRemaining int
paddingRemaining int paddingRemaining int
} }
func (c *naiveH1Conn) Read(p []byte) (n int, err error) { func (c *naivePaddingConn) Read(p []byte) (n int, err error) {
n, err = c.read(p) n, err = c.read(p)
return n, wrapHttpError(err) return n, wrapHttpError(err)
} }
func (c *naiveH1Conn) read(p []byte) (n int, err error) { func (c *naivePaddingConn) read(p []byte) (n int, err error) {
if c.readRemaining > 0 {
if len(p) > c.readRemaining {
p = p[:c.readRemaining]
}
n, err = c.Conn.Read(p)
if err != nil {
return
}
c.readRemaining -= n
return
}
if c.paddingRemaining > 0 {
err = rw.SkipN(c.Conn, c.paddingRemaining)
if err != nil {
return
}
c.paddingRemaining = 0
}
if c.readPadding < kFirstPaddings {
paddingHdr := p[:3]
_, err = io.ReadFull(c.Conn, paddingHdr)
if err != nil {
return
}
originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2]))
paddingSize := int(paddingHdr[2])
if len(p) > originalDataSize {
p = p[:originalDataSize]
}
n, err = c.Conn.Read(p)
if err != nil {
return
}
c.readPadding++
c.readRemaining = originalDataSize - n
c.paddingRemaining = paddingSize
return
}
return c.Conn.Read(p)
}
func (c *naiveH1Conn) Write(p []byte) (n int, err error) {
for pLen := len(p); pLen > 0; {
var data []byte
if pLen > 65535 {
data = p[:65535]
p = p[65535:]
pLen -= 65535
} else {
data = p
pLen = 0
}
var writeN int
writeN, err = c.write(data)
n += writeN
if err != nil {
break
}
}
return n, wrapHttpError(err)
}
func (c *naiveH1Conn) write(p []byte) (n int, err error) {
if c.writePadding < kFirstPaddings {
paddingSize := rand.Intn(256)
_buffer := buf.StackNewSize(3 + len(p) + paddingSize)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
header := buffer.Extend(3)
binary.BigEndian.PutUint16(header, uint16(len(p)))
header[2] = byte(paddingSize)
common.Must1(buffer.Write(p))
_, err = c.Conn.Write(buffer.Bytes())
if err == nil {
n = len(p)
}
c.writePadding++
return
}
return c.Conn.Write(p)
}
func (c *naiveH1Conn) FrontHeadroom() int {
if c.writePadding < kFirstPaddings {
return 3
}
return 0
}
func (c *naiveH1Conn) RearHeadroom() int {
if c.writePadding < kFirstPaddings {
return 255
}
return 0
}
func (c *naiveH1Conn) WriterMTU() int {
if c.writePadding < kFirstPaddings {
return 65535
}
return 0
}
func (c *naiveH1Conn) WriteBuffer(buffer *buf.Buffer) error {
defer buffer.Release()
if c.writePadding < kFirstPaddings {
bufferLen := buffer.Len()
if bufferLen > 65535 {
return common.Error(c.Write(buffer.Bytes()))
}
paddingSize := rand.Intn(256)
header := buffer.ExtendHeader(3)
binary.BigEndian.PutUint16(header, uint16(bufferLen))
header[2] = byte(paddingSize)
buffer.Extend(paddingSize)
c.writePadding++
}
return wrapHttpError(common.Error(c.Conn.Write(buffer.Bytes())))
}
// FIXME
/*func (c *naiveH1Conn) WriteTo(w io.Writer) (n int64, err error) {
if c.readPadding < kFirstPaddings {
n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding)
} else {
n, err = bufio.Copy(w, c.Conn)
}
return n, wrapHttpError(err)
}*/
func (c *naiveH1Conn) ReadFrom(r io.Reader) (n int64, err error) {
if c.writePadding < kFirstPaddings {
n, err = bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding)
} else {
n, err = bufio.Copy(c.Conn, r)
}
return n, wrapHttpError(err)
}
func (c *naiveH1Conn) Upstream() any {
return c.Conn
}
func (c *naiveH1Conn) ReaderReplaceable() bool {
return c.readRemaining == kFirstPaddings
}
func (c *naiveH1Conn) WriterReplaceable() bool {
return c.writePadding == kFirstPaddings
}
type naiveH2Conn struct {
reader io.Reader
writer io.Writer
flusher http.Flusher
rAddr net.Addr
readPadding int
writePadding int
readRemaining int
paddingRemaining int
}
func (c *naiveH2Conn) Read(p []byte) (n int, err error) {
n, err = c.read(p)
return n, wrapHttpError(err)
}
func (c *naiveH2Conn) read(p []byte) (n int, err error) {
if c.readRemaining > 0 { if c.readRemaining > 0 {
if len(p) > c.readRemaining { if len(p) > c.readRemaining {
p = p[:c.readRemaining] p = p[:c.readRemaining]
@ -420,29 +251,34 @@ func (c *naiveH2Conn) read(p []byte) (n int, err error) {
c.paddingRemaining = 0 c.paddingRemaining = 0
} }
if c.readPadding < kFirstPaddings { if c.readPadding < kFirstPaddings {
paddingHdr := p[:3]
_, err = io.ReadFull(c.reader, paddingHdr)
if err != nil {
return
}
originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2]))
paddingSize := int(paddingHdr[2])
if len(p) > originalDataSize {
p = p[:originalDataSize]
}
n, err = c.reader.Read(p)
if err != nil {
return
}
c.readPadding++ c.readPadding++
c.readRemaining = originalDataSize - n nr, err := io.ReadFull(c.reader, p[0:3])
c.paddingRemaining = paddingSize if nr > 0 {
return nr = int(p[0])*256 + int(p[1])
paddingSize := int(p[2])
if nr > len(p) {
c.readRemaining = nr - len(p)
c.paddingRemaining = paddingSize
nr = len(p)
paddingSize = 0
}
nr, err = io.ReadFull(c.reader, p[0:nr])
if nr > 0 && paddingSize > 0 {
var junk [256]byte
_, err = io.ReadFull(c.reader, junk[0:paddingSize])
}
}
if err != nil {
return 0, err
}
return nr, nil
} }
return c.reader.Read(p) return c.reader.Read(p)
} }
func (c *naiveH2Conn) Write(p []byte) (n int, err error) { func (c *naivePaddingConn) Write(p []byte) (n int, err error) {
for pLen := len(p); pLen > 0; { for pLen := len(p); pLen > 0; {
var data []byte var data []byte
if pLen > 65535 { if pLen > 65535 {
@ -460,136 +296,96 @@ func (c *naiveH2Conn) Write(p []byte) (n int, err error) {
break break
} }
} }
if err == nil { if err == nil && c.flusher != nil {
c.flusher.Flush() c.flusher.Flush()
} }
return n, wrapHttpError(err) return n, wrapHttpError(err)
} }
func (c *naiveH2Conn) write(p []byte) (n int, err error) { func (c *naivePaddingConn) write(p []byte) (n int, err error) {
if c.writePadding < kFirstPaddings { if c.writePadding < kFirstPaddings {
c.writePadding++
paddingSize := rand.Intn(256) paddingSize := rand.Intn(256)
_buffer := buf.StackNewSize(3 + len(p) + paddingSize) var hdr [3]byte
defer common.KeepAlive(_buffer) binary.BigEndian.PutUint16(hdr[0:2], uint16(len(p)))
buffer := common.Dup(_buffer) hdr[2] = byte(paddingSize)
defer buffer.Release() p = append(hdr[:], p...)
header := buffer.Extend(3)
binary.BigEndian.PutUint16(header, uint16(len(p)))
header[2] = byte(paddingSize)
common.Must1(buffer.Write(p)) junk := make([]byte, paddingSize)
_, err = c.writer.Write(buffer.Bytes()) p = append(p, junk...)
if err == nil {
n = len(p) _, err = c.writer.Write(p)
if err != nil {
return 0, err
} }
c.writePadding++ return len(p), nil
return
} }
return c.writer.Write(p) return c.writer.Write(p)
} }
func (c *naiveH2Conn) FrontHeadroom() int { func (c *naivePaddingConn) FrontHeadroom() int {
if c.writePadding < kFirstPaddings { if c.writePadding < kFirstPaddings {
return 3 return 3
} }
return 0 return 0
} }
func (c *naiveH2Conn) RearHeadroom() int { func (c *naivePaddingConn) RearHeadroom() int {
if c.writePadding < kFirstPaddings { if c.writePadding < kFirstPaddings {
return 255 return 255
} }
return 0 return 0
} }
func (c *naiveH2Conn) WriterMTU() int { func (c *naivePaddingConn) WriterMTU() int {
if c.writePadding < kFirstPaddings { if c.writePadding < kFirstPaddings {
return 65535 return 65535
} }
return 0 return 0
} }
func (c *naiveH2Conn) WriteBuffer(buffer *buf.Buffer) error { func (c *naivePaddingConn) Close() error {
defer buffer.Release()
if c.writePadding < kFirstPaddings {
bufferLen := buffer.Len()
if bufferLen > 65535 {
return common.Error(c.Write(buffer.Bytes()))
}
paddingSize := rand.Intn(256)
header := buffer.ExtendHeader(3)
binary.BigEndian.PutUint16(header, uint16(bufferLen))
header[2] = byte(paddingSize)
buffer.Extend(paddingSize)
c.writePadding++
}
err := common.Error(c.writer.Write(buffer.Bytes()))
if err == nil {
c.flusher.Flush()
}
return wrapHttpError(err)
}
// FIXME
/*func (c *naiveH2Conn) WriteTo(w io.Writer) (n int64, err error) {
if c.readPadding < kFirstPaddings {
n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding)
} else {
n, err = bufio.Copy(w, c.reader)
}
return n, wrapHttpError(err)
}*/
func (c *naiveH2Conn) ReadFrom(r io.Reader) (n int64, err error) {
if c.writePadding < kFirstPaddings {
n, err = bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding)
} else {
n, err = bufio.Copy(c.writer, r)
}
return n, wrapHttpError(err)
}
func (c *naiveH2Conn) Close() error {
return common.Close( return common.Close(
c.reader, c.reader,
c.writer, c.writer,
) )
} }
func (c *naiveH2Conn) LocalAddr() net.Addr { func (c *naivePaddingConn) LocalAddr() net.Addr {
return nil return nil
} }
func (c *naiveH2Conn) RemoteAddr() net.Addr { func (c *naivePaddingConn) RemoteAddr() net.Addr {
return c.rAddr return c.rAddr
} }
func (c *naiveH2Conn) SetDeadline(t time.Time) error { func (c *naivePaddingConn) SetDeadline(t time.Time) error {
return os.ErrInvalid return os.ErrInvalid
} }
func (c *naiveH2Conn) SetReadDeadline(t time.Time) error { func (c *naivePaddingConn) SetReadDeadline(t time.Time) error {
return os.ErrInvalid return os.ErrInvalid
} }
func (c *naiveH2Conn) SetWriteDeadline(t time.Time) error { func (c *naivePaddingConn) SetWriteDeadline(t time.Time) error {
return os.ErrInvalid return os.ErrInvalid
} }
func (c *naiveH2Conn) UpstreamReader() any { func (c *naivePaddingConn) UpstreamReader() any {
return c.reader return c.reader
} }
func (c *naiveH2Conn) UpstreamWriter() any { func (c *naivePaddingConn) UpstreamWriter() any {
return c.writer return c.writer
} }
func (c *naiveH2Conn) ReaderReplaceable() bool { func (c *naivePaddingConn) ReaderReplaceable() bool {
return c.readRemaining == kFirstPaddings return c.readPadding == kFirstPaddings
} }
func (c *naiveH2Conn) WriterReplaceable() bool { func (c *naivePaddingConn) WriterReplaceable() bool {
return c.writePadding == kFirstPaddings return c.writePadding == kFirstPaddings
} }