Fix naive padding

This commit is contained in:
arm64v8a 2022-09-07 08:00:00 +08:00
parent ef013e0639
commit 53b15cca1b

View File

@ -19,8 +19,6 @@ import (
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/auth" "github.com/sagernet/sing/common/auth"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
@ -160,14 +158,14 @@ func (n *Naive) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
destination := M.ParseSocksaddr(hostPort) destination := M.ParseSocksaddr(hostPort)
if hijacker, isHijacker := writer.(http.Hijacker); isHijacker { if hijacker, isHijacker := writer.(http.Hijacker); isHijacker {
conn, _, err := hijacker.Hijack() clientConn, _, err := hijacker.Hijack()
if err != nil { if err != nil {
n.badRequest(ctx, request, E.New("hijack failed")) n.badRequest(ctx, request, E.New("hijack failed"))
return return
} }
n.newConnection(ctx, &naiveH1Conn{Conn: conn}, source, destination) n.newConnection(ctx, &naivePaddingConn{reader: clientConn, writer: clientConn}, source, destination)
} else { } else {
n.newConnection(ctx, &naiveH2Conn{reader: request.Body, writer: writer, flusher: writer.(http.Flusher)}, source, destination) n.newConnection(ctx, &naivePaddingConn{reader: request.Body, writer: writer, flusher: writer.(http.Flusher)}, source, destination)
} }
} }
@ -216,191 +214,24 @@ func generateNaivePaddingHeader() string {
const kFirstPaddings = 8 const kFirstPaddings = 8
type naiveH1Conn struct { type naivePaddingConn struct {
net.Conn
readPadding int
writePadding int
readRemaining int
paddingRemaining int
}
func (c *naiveH1Conn) Read(p []byte) (n int, err error) {
n, err = c.read(p)
return n, wrapHttpError(err)
}
func (c *naiveH1Conn) read(p []byte) (n int, err error) {
if c.readRemaining > 0 {
if len(p) > c.readRemaining {
p = p[:c.readRemaining]
}
n, err = c.Conn.Read(p)
if err != nil {
return
}
c.readRemaining -= n
return
}
if c.paddingRemaining > 0 {
err = rw.SkipN(c.Conn, c.paddingRemaining)
if err != nil {
return
}
c.paddingRemaining = 0
}
if c.readPadding < kFirstPaddings {
paddingHdr := p[:3]
_, err = io.ReadFull(c.Conn, paddingHdr)
if err != nil {
return
}
originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2]))
paddingSize := int(paddingHdr[2])
if len(p) > originalDataSize {
p = p[:originalDataSize]
}
n, err = c.Conn.Read(p)
if err != nil {
return
}
c.readPadding++
c.readRemaining = originalDataSize - n
c.paddingRemaining = paddingSize
return
}
return c.Conn.Read(p)
}
func (c *naiveH1Conn) Write(p []byte) (n int, err error) {
for pLen := len(p); pLen > 0; {
var data []byte
if pLen > 65535 {
data = p[:65535]
p = p[65535:]
pLen -= 65535
} else {
data = p
pLen = 0
}
var writeN int
writeN, err = c.write(data)
n += writeN
if err != nil {
break
}
}
return n, wrapHttpError(err)
}
func (c *naiveH1Conn) write(p []byte) (n int, err error) {
if c.writePadding < kFirstPaddings {
paddingSize := rand.Intn(256)
_buffer := buf.StackNewSize(3 + len(p) + paddingSize)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
header := buffer.Extend(3)
binary.BigEndian.PutUint16(header, uint16(len(p)))
header[2] = byte(paddingSize)
common.Must1(buffer.Write(p))
_, err = c.Conn.Write(buffer.Bytes())
if err == nil {
n = len(p)
}
c.writePadding++
return
}
return c.Conn.Write(p)
}
func (c *naiveH1Conn) FrontHeadroom() int {
if c.writePadding < kFirstPaddings {
return 3
}
return 0
}
func (c *naiveH1Conn) RearHeadroom() int {
if c.writePadding < kFirstPaddings {
return 255
}
return 0
}
func (c *naiveH1Conn) WriterMTU() int {
if c.writePadding < kFirstPaddings {
return 65535
}
return 0
}
func (c *naiveH1Conn) WriteBuffer(buffer *buf.Buffer) error {
defer buffer.Release()
if c.writePadding < kFirstPaddings {
bufferLen := buffer.Len()
if bufferLen > 65535 {
return common.Error(c.Write(buffer.Bytes()))
}
paddingSize := rand.Intn(256)
header := buffer.ExtendHeader(3)
binary.BigEndian.PutUint16(header, uint16(bufferLen))
header[2] = byte(paddingSize)
buffer.Extend(paddingSize)
c.writePadding++
}
return wrapHttpError(common.Error(c.Conn.Write(buffer.Bytes())))
}
// FIXME
/*func (c *naiveH1Conn) WriteTo(w io.Writer) (n int64, err error) {
if c.readPadding < kFirstPaddings {
n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding)
} else {
n, err = bufio.Copy(w, c.Conn)
}
return n, wrapHttpError(err)
}*/
func (c *naiveH1Conn) ReadFrom(r io.Reader) (n int64, err error) {
if c.writePadding < kFirstPaddings {
n, err = bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding)
} else {
n, err = bufio.Copy(c.Conn, r)
}
return n, wrapHttpError(err)
}
func (c *naiveH1Conn) Upstream() any {
return c.Conn
}
func (c *naiveH1Conn) ReaderReplaceable() bool {
return c.readRemaining == kFirstPaddings
}
func (c *naiveH1Conn) WriterReplaceable() bool {
return c.writePadding == kFirstPaddings
}
type naiveH2Conn struct {
reader io.Reader reader io.Reader
writer io.Writer writer io.Writer
flusher http.Flusher flusher http.Flusher
rAddr net.Addr rAddr net.Addr
readPadding int readPadding int
writePadding int writePadding int
readRemaining int readRemaining int
paddingRemaining int paddingRemaining int
} }
func (c *naiveH2Conn) Read(p []byte) (n int, err error) { func (c *naivePaddingConn) Read(p []byte) (n int, err error) {
n, err = c.read(p) n, err = c.read(p)
return n, wrapHttpError(err) return n, wrapHttpError(err)
} }
func (c *naiveH2Conn) read(p []byte) (n int, err error) { func (c *naivePaddingConn) read(p []byte) (n int, err error) {
if c.readRemaining > 0 { if c.readRemaining > 0 {
if len(p) > c.readRemaining { if len(p) > c.readRemaining {
p = p[:c.readRemaining] p = p[:c.readRemaining]
@ -420,29 +251,34 @@ func (c *naiveH2Conn) read(p []byte) (n int, err error) {
c.paddingRemaining = 0 c.paddingRemaining = 0
} }
if c.readPadding < kFirstPaddings { if c.readPadding < kFirstPaddings {
paddingHdr := p[:3]
_, err = io.ReadFull(c.reader, paddingHdr)
if err != nil {
return
}
originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2]))
paddingSize := int(paddingHdr[2])
if len(p) > originalDataSize {
p = p[:originalDataSize]
}
n, err = c.reader.Read(p)
if err != nil {
return
}
c.readPadding++ c.readPadding++
c.readRemaining = originalDataSize - n nr, err := io.ReadFull(c.reader, p[0:3])
if nr > 0 {
nr = int(p[0])*256 + int(p[1])
paddingSize := int(p[2])
if nr > len(p) {
c.readRemaining = nr - len(p)
c.paddingRemaining = paddingSize c.paddingRemaining = paddingSize
return nr = len(p)
paddingSize = 0
}
nr, err = io.ReadFull(c.reader, p[0:nr])
if nr > 0 && paddingSize > 0 {
var junk [256]byte
_, err = io.ReadFull(c.reader, junk[0:paddingSize])
}
}
if err != nil {
return 0, err
}
return nr, nil
} }
return c.reader.Read(p) return c.reader.Read(p)
} }
func (c *naiveH2Conn) Write(p []byte) (n int, err error) { func (c *naivePaddingConn) Write(p []byte) (n int, err error) {
for pLen := len(p); pLen > 0; { for pLen := len(p); pLen > 0; {
var data []byte var data []byte
if pLen > 65535 { if pLen > 65535 {
@ -460,136 +296,96 @@ func (c *naiveH2Conn) Write(p []byte) (n int, err error) {
break break
} }
} }
if err == nil { if err == nil && c.flusher != nil {
c.flusher.Flush() c.flusher.Flush()
} }
return n, wrapHttpError(err) return n, wrapHttpError(err)
} }
func (c *naiveH2Conn) write(p []byte) (n int, err error) { func (c *naivePaddingConn) write(p []byte) (n int, err error) {
if c.writePadding < kFirstPaddings { if c.writePadding < kFirstPaddings {
c.writePadding++
paddingSize := rand.Intn(256) paddingSize := rand.Intn(256)
_buffer := buf.StackNewSize(3 + len(p) + paddingSize) var hdr [3]byte
defer common.KeepAlive(_buffer) binary.BigEndian.PutUint16(hdr[0:2], uint16(len(p)))
buffer := common.Dup(_buffer) hdr[2] = byte(paddingSize)
defer buffer.Release() p = append(hdr[:], p...)
header := buffer.Extend(3)
binary.BigEndian.PutUint16(header, uint16(len(p)))
header[2] = byte(paddingSize)
common.Must1(buffer.Write(p)) junk := make([]byte, paddingSize)
_, err = c.writer.Write(buffer.Bytes()) p = append(p, junk...)
if err == nil {
n = len(p) _, err = c.writer.Write(p)
if err != nil {
return 0, err
} }
c.writePadding++ return len(p), nil
return
} }
return c.writer.Write(p) return c.writer.Write(p)
} }
func (c *naiveH2Conn) FrontHeadroom() int { func (c *naivePaddingConn) FrontHeadroom() int {
if c.writePadding < kFirstPaddings { if c.writePadding < kFirstPaddings {
return 3 return 3
} }
return 0 return 0
} }
func (c *naiveH2Conn) RearHeadroom() int { func (c *naivePaddingConn) RearHeadroom() int {
if c.writePadding < kFirstPaddings { if c.writePadding < kFirstPaddings {
return 255 return 255
} }
return 0 return 0
} }
func (c *naiveH2Conn) WriterMTU() int { func (c *naivePaddingConn) WriterMTU() int {
if c.writePadding < kFirstPaddings { if c.writePadding < kFirstPaddings {
return 65535 return 65535
} }
return 0 return 0
} }
func (c *naiveH2Conn) WriteBuffer(buffer *buf.Buffer) error { func (c *naivePaddingConn) Close() error {
defer buffer.Release()
if c.writePadding < kFirstPaddings {
bufferLen := buffer.Len()
if bufferLen > 65535 {
return common.Error(c.Write(buffer.Bytes()))
}
paddingSize := rand.Intn(256)
header := buffer.ExtendHeader(3)
binary.BigEndian.PutUint16(header, uint16(bufferLen))
header[2] = byte(paddingSize)
buffer.Extend(paddingSize)
c.writePadding++
}
err := common.Error(c.writer.Write(buffer.Bytes()))
if err == nil {
c.flusher.Flush()
}
return wrapHttpError(err)
}
// FIXME
/*func (c *naiveH2Conn) WriteTo(w io.Writer) (n int64, err error) {
if c.readPadding < kFirstPaddings {
n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding)
} else {
n, err = bufio.Copy(w, c.reader)
}
return n, wrapHttpError(err)
}*/
func (c *naiveH2Conn) ReadFrom(r io.Reader) (n int64, err error) {
if c.writePadding < kFirstPaddings {
n, err = bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding)
} else {
n, err = bufio.Copy(c.writer, r)
}
return n, wrapHttpError(err)
}
func (c *naiveH2Conn) Close() error {
return common.Close( return common.Close(
c.reader, c.reader,
c.writer, c.writer,
) )
} }
func (c *naiveH2Conn) LocalAddr() net.Addr { func (c *naivePaddingConn) LocalAddr() net.Addr {
return nil return nil
} }
func (c *naiveH2Conn) RemoteAddr() net.Addr { func (c *naivePaddingConn) RemoteAddr() net.Addr {
return c.rAddr return c.rAddr
} }
func (c *naiveH2Conn) SetDeadline(t time.Time) error { func (c *naivePaddingConn) SetDeadline(t time.Time) error {
return os.ErrInvalid return os.ErrInvalid
} }
func (c *naiveH2Conn) SetReadDeadline(t time.Time) error { func (c *naivePaddingConn) SetReadDeadline(t time.Time) error {
return os.ErrInvalid return os.ErrInvalid
} }
func (c *naiveH2Conn) SetWriteDeadline(t time.Time) error { func (c *naivePaddingConn) SetWriteDeadline(t time.Time) error {
return os.ErrInvalid return os.ErrInvalid
} }
func (c *naiveH2Conn) UpstreamReader() any { func (c *naivePaddingConn) UpstreamReader() any {
return c.reader return c.reader
} }
func (c *naiveH2Conn) UpstreamWriter() any { func (c *naivePaddingConn) UpstreamWriter() any {
return c.writer return c.writer
} }
func (c *naiveH2Conn) ReaderReplaceable() bool { func (c *naivePaddingConn) ReaderReplaceable() bool {
return c.readRemaining == kFirstPaddings return c.readPadding == kFirstPaddings
} }
func (c *naiveH2Conn) WriterReplaceable() bool { func (c *naivePaddingConn) WriterReplaceable() bool {
return c.writePadding == kFirstPaddings return c.writePadding == kFirstPaddings
} }