Improve read wait interface &

Refactor Authenticator interface to struct
This commit is contained in:
世界 2023-12-07 11:56:57 +08:00
parent 4197805a22
commit c8b4182ea3
No known key found for this signature in database
GPG Key ID: CD109927C34A63C4
22 changed files with 291 additions and 320 deletions

View File

@ -1,233 +0,0 @@
//go:build go1.20 && !go1.21
package badtls
import (
"crypto/cipher"
"crypto/rand"
"crypto/tls"
"encoding/binary"
"io"
"net"
"reflect"
"sync"
"sync/atomic"
"unsafe"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
aTLS "github.com/sagernet/sing/common/tls"
)
type Conn struct {
*tls.Conn
writer N.ExtendedWriter
isHandshakeComplete *atomic.Bool
activeCall *atomic.Int32
closeNotifySent *bool
version *uint16
rand io.Reader
halfAccess *sync.Mutex
halfError *error
cipher cipher.AEAD
explicitNonceLen int
halfPtr uintptr
halfSeq []byte
halfScratchBuf []byte
}
func TryCreate(conn aTLS.Conn) aTLS.Conn {
tlsConn, ok := conn.(*tls.Conn)
if !ok {
return conn
}
badConn, err := Create(tlsConn)
if err != nil {
log.Warn("initialize badtls: ", err)
return conn
}
return badConn
}
func Create(conn *tls.Conn) (aTLS.Conn, error) {
rawConn := reflect.Indirect(reflect.ValueOf(conn))
rawIsHandshakeComplete := rawConn.FieldByName("isHandshakeComplete")
if !rawIsHandshakeComplete.IsValid() || rawIsHandshakeComplete.Kind() != reflect.Struct {
return nil, E.New("badtls: invalid isHandshakeComplete")
}
isHandshakeComplete := (*atomic.Bool)(unsafe.Pointer(rawIsHandshakeComplete.UnsafeAddr()))
if !isHandshakeComplete.Load() {
return nil, E.New("handshake not finished")
}
rawActiveCall := rawConn.FieldByName("activeCall")
if !rawActiveCall.IsValid() || rawActiveCall.Kind() != reflect.Struct {
return nil, E.New("badtls: invalid active call")
}
activeCall := (*atomic.Int32)(unsafe.Pointer(rawActiveCall.UnsafeAddr()))
rawHalfConn := rawConn.FieldByName("out")
if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct {
return nil, E.New("badtls: invalid half conn")
}
rawVersion := rawConn.FieldByName("vers")
if !rawVersion.IsValid() || rawVersion.Kind() != reflect.Uint16 {
return nil, E.New("badtls: invalid version")
}
version := (*uint16)(unsafe.Pointer(rawVersion.UnsafeAddr()))
rawCloseNotifySent := rawConn.FieldByName("closeNotifySent")
if !rawCloseNotifySent.IsValid() || rawCloseNotifySent.Kind() != reflect.Bool {
return nil, E.New("badtls: invalid notify")
}
closeNotifySent := (*bool)(unsafe.Pointer(rawCloseNotifySent.UnsafeAddr()))
rawConfig := reflect.Indirect(rawConn.FieldByName("config"))
if !rawConfig.IsValid() || rawConfig.Kind() != reflect.Struct {
return nil, E.New("badtls: bad config")
}
config := (*tls.Config)(unsafe.Pointer(rawConfig.UnsafeAddr()))
randReader := config.Rand
if randReader == nil {
randReader = rand.Reader
}
rawHalfMutex := rawHalfConn.FieldByName("Mutex")
if !rawHalfMutex.IsValid() || rawHalfMutex.Kind() != reflect.Struct {
return nil, E.New("badtls: invalid half mutex")
}
halfAccess := (*sync.Mutex)(unsafe.Pointer(rawHalfMutex.UnsafeAddr()))
rawHalfError := rawHalfConn.FieldByName("err")
if !rawHalfError.IsValid() || rawHalfError.Kind() != reflect.Interface {
return nil, E.New("badtls: invalid half error")
}
halfError := (*error)(unsafe.Pointer(rawHalfError.UnsafeAddr()))
rawHalfCipherInterface := rawHalfConn.FieldByName("cipher")
if !rawHalfCipherInterface.IsValid() || rawHalfCipherInterface.Kind() != reflect.Interface {
return nil, E.New("badtls: invalid cipher interface")
}
rawHalfCipher := rawHalfCipherInterface.Elem()
aeadCipher, loaded := valueInterface(rawHalfCipher, false).(cipher.AEAD)
if !loaded {
return nil, E.New("badtls: invalid AEAD cipher")
}
var explicitNonceLen int
switch cipherName := reflect.Indirect(rawHalfCipher).Type().String(); cipherName {
case "tls.prefixNonceAEAD":
explicitNonceLen = aeadCipher.NonceSize()
case "tls.xorNonceAEAD":
default:
return nil, E.New("badtls: unknown cipher type: ", cipherName)
}
rawHalfSeq := rawHalfConn.FieldByName("seq")
if !rawHalfSeq.IsValid() || rawHalfSeq.Kind() != reflect.Array {
return nil, E.New("badtls: invalid seq")
}
halfSeq := rawHalfSeq.Bytes()
rawHalfScratchBuf := rawHalfConn.FieldByName("scratchBuf")
if !rawHalfScratchBuf.IsValid() || rawHalfScratchBuf.Kind() != reflect.Array {
return nil, E.New("badtls: invalid scratchBuf")
}
halfScratchBuf := rawHalfScratchBuf.Bytes()
return &Conn{
Conn: conn,
writer: bufio.NewExtendedWriter(conn.NetConn()),
isHandshakeComplete: isHandshakeComplete,
activeCall: activeCall,
closeNotifySent: closeNotifySent,
version: version,
halfAccess: halfAccess,
halfError: halfError,
cipher: aeadCipher,
explicitNonceLen: explicitNonceLen,
rand: randReader,
halfPtr: rawHalfConn.UnsafeAddr(),
halfSeq: halfSeq,
halfScratchBuf: halfScratchBuf,
}, nil
}
func (c *Conn) WriteBuffer(buffer *buf.Buffer) error {
if buffer.Len() > maxPlaintext {
defer buffer.Release()
return common.Error(c.Write(buffer.Bytes()))
}
for {
x := c.activeCall.Load()
if x&1 != 0 {
return net.ErrClosed
}
if c.activeCall.CompareAndSwap(x, x+2) {
break
}
}
defer c.activeCall.Add(-2)
c.halfAccess.Lock()
defer c.halfAccess.Unlock()
if err := *c.halfError; err != nil {
return err
}
if *c.closeNotifySent {
return errShutdown
}
dataLen := buffer.Len()
dataBytes := buffer.Bytes()
outBuf := buffer.ExtendHeader(recordHeaderLen + c.explicitNonceLen)
outBuf[0] = 23
version := *c.version
if version == 0 {
version = tls.VersionTLS10
} else if version == tls.VersionTLS13 {
version = tls.VersionTLS12
}
binary.BigEndian.PutUint16(outBuf[1:], version)
var nonce []byte
if c.explicitNonceLen > 0 {
nonce = outBuf[5 : 5+c.explicitNonceLen]
if c.explicitNonceLen < 16 {
copy(nonce, c.halfSeq)
} else {
if _, err := io.ReadFull(c.rand, nonce); err != nil {
return err
}
}
}
if len(nonce) == 0 {
nonce = c.halfSeq
}
if *c.version == tls.VersionTLS13 {
buffer.FreeBytes()[0] = 23
binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen+1+c.cipher.Overhead()))
c.cipher.Seal(outBuf, nonce, outBuf[recordHeaderLen:recordHeaderLen+c.explicitNonceLen+dataLen+1], outBuf[:recordHeaderLen])
buffer.Extend(1 + c.cipher.Overhead())
} else {
binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen))
additionalData := append(c.halfScratchBuf[:0], c.halfSeq...)
additionalData = append(additionalData, outBuf[:recordHeaderLen]...)
c.cipher.Seal(outBuf, nonce, dataBytes, additionalData)
buffer.Extend(c.cipher.Overhead())
binary.BigEndian.PutUint16(outBuf[3:], uint16(dataLen+c.explicitNonceLen+c.cipher.Overhead()))
}
incSeq(c.halfPtr)
log.Trace("badtls write ", buffer.Len())
return c.writer.WriteBuffer(buffer)
}
func (c *Conn) FrontHeadroom() int {
return recordHeaderLen + c.explicitNonceLen
}
func (c *Conn) RearHeadroom() int {
return 1 + c.cipher.Overhead()
}
func (c *Conn) WriterMTU() int {
return maxPlaintext
}
func (c *Conn) Upstream() any {
return c.Conn
}
func (c *Conn) UpstreamWriter() any {
return c.NetConn()
}

View File

@ -1,14 +0,0 @@
//go:build !go1.19 || go1.21
package badtls
import (
"crypto/tls"
"os"
aTLS "github.com/sagernet/sing/common/tls"
)
func Create(conn *tls.Conn) (aTLS.Conn, error) {
return nil, os.ErrInvalid
}

View File

@ -1,22 +0,0 @@
//go:build go1.20 && !go.1.21
package badtls
import (
"reflect"
_ "unsafe"
)
const (
maxPlaintext = 16384 // maximum plaintext payload length
recordHeaderLen = 5 // record header length
)
//go:linkname errShutdown crypto/tls.errShutdown
var errShutdown error
//go:linkname incSeq crypto/tls.(*halfConn).incSeq
func incSeq(conn uintptr)
//go:linkname valueInterface reflect.valueInterface
func valueInterface(v reflect.Value, safe bool) any

115
common/badtls/read_wait.go Normal file
View File

@ -0,0 +1,115 @@
//go:build go1.21 && !without_badtls
package badtls
import (
"bytes"
"os"
"reflect"
"sync"
"unsafe"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/tls"
)
var _ N.ReadWaiter = (*ReadWaitConn)(nil)
type ReadWaitConn struct {
*tls.STDConn
halfAccess *sync.Mutex
rawInput *bytes.Buffer
input *bytes.Reader
hand *bytes.Buffer
readWaitOptions N.ReadWaitOptions
}
func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) {
stdConn, isSTDConn := conn.(*tls.STDConn)
if !isSTDConn {
return nil, os.ErrInvalid
}
rawConn := reflect.Indirect(reflect.ValueOf(stdConn))
rawHalfConn := rawConn.FieldByName("in")
if !rawHalfConn.IsValid() || rawHalfConn.Kind() != reflect.Struct {
return nil, E.New("badtls: invalid half conn")
}
rawHalfMutex := rawHalfConn.FieldByName("Mutex")
if !rawHalfMutex.IsValid() || rawHalfMutex.Kind() != reflect.Struct {
return nil, E.New("badtls: invalid half mutex")
}
halfAccess := (*sync.Mutex)(unsafe.Pointer(rawHalfMutex.UnsafeAddr()))
rawRawInput := rawConn.FieldByName("rawInput")
if !rawRawInput.IsValid() || rawRawInput.Kind() != reflect.Struct {
return nil, E.New("badtls: invalid raw input")
}
rawInput := (*bytes.Buffer)(unsafe.Pointer(rawRawInput.UnsafeAddr()))
rawInput0 := rawConn.FieldByName("input")
if !rawInput0.IsValid() || rawInput0.Kind() != reflect.Struct {
return nil, E.New("badtls: invalid input")
}
input := (*bytes.Reader)(unsafe.Pointer(rawInput0.UnsafeAddr()))
rawHand := rawConn.FieldByName("hand")
if !rawHand.IsValid() || rawHand.Kind() != reflect.Struct {
return nil, E.New("badtls: invalid hand")
}
hand := (*bytes.Buffer)(unsafe.Pointer(rawHand.UnsafeAddr()))
return &ReadWaitConn{
STDConn: stdConn,
halfAccess: halfAccess,
rawInput: rawInput,
input: input,
hand: hand,
}, nil
}
func (c *ReadWaitConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return false
}
func (c *ReadWaitConn) WaitReadBuffer() (buffer *buf.Buffer, err error) {
err = c.Handshake()
if err != nil {
return
}
c.halfAccess.Lock()
defer c.halfAccess.Unlock()
for c.input.Len() == 0 {
err = tlsReadRecord(c.STDConn)
if err != nil {
return
}
for c.hand.Len() > 0 {
err = tlsHandlePostHandshakeMessage(c.STDConn)
if err != nil {
return
}
}
}
buffer = c.readWaitOptions.NewBuffer()
n, err := c.input.Read(buffer.FreeBytes())
if err != nil {
buffer.Release()
return
}
buffer.Truncate(n)
if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
// recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
c.rawInput.Bytes()[0] == 21 {
_ = tlsReadRecord(c.STDConn)
// return n, err // will be io.EOF on closeNotify
}
c.readWaitOptions.PostReturn(buffer)
return
}
//go:linkname tlsReadRecord crypto/tls.(*Conn).readRecord
func tlsReadRecord(c *tls.STDConn) error
//go:linkname tlsHandlePostHandshakeMessage crypto/tls.(*Conn).handlePostHandshakeMessage
func tlsHandlePostHandshakeMessage(c *tls.STDConn) error

View File

@ -0,0 +1,13 @@
//go:build !go1.21 || without_badtls
package badtls
import (
"os"
"github.com/sagernet/sing/common/tls"
)
func NewReadWaitConn(conn tls.Conn) (tls.Conn, error) {
return nil, os.ErrInvalid
}

View File

@ -6,6 +6,7 @@ import (
"os" "os"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/badtls"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
@ -42,7 +43,17 @@ func NewClient(ctx context.Context, serverAddress string, options option.Outboun
func ClientHandshake(ctx context.Context, conn net.Conn, config Config) (Conn, error) { func ClientHandshake(ctx context.Context, conn net.Conn, config Config) (Conn, error) {
ctx, cancel := context.WithTimeout(ctx, C.TCPTimeout) ctx, cancel := context.WithTimeout(ctx, C.TCPTimeout)
defer cancel() defer cancel()
return aTLS.ClientHandshake(ctx, conn, config) tlsConn, err := aTLS.ClientHandshake(ctx, conn, config)
if err != nil {
return nil, err
}
readWaitConn, err := badtls.NewReadWaitConn(tlsConn)
if err == nil {
return readWaitConn, nil
} else if err != os.ErrInvalid {
return nil, err
}
return tlsConn, nil
} }
type Dialer struct { type Dialer struct {

View File

@ -3,7 +3,9 @@ package tls
import ( import (
"context" "context"
"net" "net"
"os"
"github.com/sagernet/sing-box/common/badtls"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
@ -26,5 +28,15 @@ func NewServer(ctx context.Context, logger log.Logger, options option.InboundTLS
func ServerHandshake(ctx context.Context, conn net.Conn, config ServerConfig) (Conn, error) { func ServerHandshake(ctx context.Context, conn net.Conn, config ServerConfig) (Conn, error) {
ctx, cancel := context.WithTimeout(ctx, C.TCPTimeout) ctx, cancel := context.WithTimeout(ctx, C.TCPTimeout)
defer cancel() defer cancel()
return aTLS.ServerHandshake(ctx, conn, config) tlsConn, err := aTLS.ServerHandshake(ctx, conn, config)
if err != nil {
return nil, err
}
readWaitConn, err := badtls.NewReadWaitConn(tlsConn)
if err == nil {
return readWaitConn, nil
} else if err != os.ErrInvalid {
return nil, err
}
return tlsConn, nil
} }

12
go.mod
View File

@ -26,14 +26,14 @@ require (
github.com/sagernet/gvisor v0.0.0-20231119034329-07cfb6aaf930 github.com/sagernet/gvisor v0.0.0-20231119034329-07cfb6aaf930
github.com/sagernet/quic-go v0.40.0 github.com/sagernet/quic-go v0.40.0
github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691
github.com/sagernet/sing v0.2.19-0.20231208065534-70794cb91cc5 github.com/sagernet/sing v0.2.19-0.20231208110306-a3ce328ce759
github.com/sagernet/sing-dns v0.1.11 github.com/sagernet/sing-dns v0.1.11
github.com/sagernet/sing-mux v0.1.5 github.com/sagernet/sing-mux v0.1.6-0.20231207143704-9f6c20fb5266
github.com/sagernet/sing-quic v0.1.5 github.com/sagernet/sing-quic v0.1.6-0.20231207143711-eb3cbf9ed054
github.com/sagernet/sing-shadowsocks v0.2.5 github.com/sagernet/sing-shadowsocks v0.2.6
github.com/sagernet/sing-shadowsocks2 v0.1.5 github.com/sagernet/sing-shadowsocks2 v0.1.6-0.20231207143709-50439739601a
github.com/sagernet/sing-shadowtls v0.1.4 github.com/sagernet/sing-shadowtls v0.1.4
github.com/sagernet/sing-tun v0.1.22 github.com/sagernet/sing-tun v0.1.23-0.20231207143707-82a810316e14
github.com/sagernet/sing-vmess v0.1.8 github.com/sagernet/sing-vmess v0.1.8
github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37
github.com/sagernet/tfo-go v0.0.0-20230816093905-5a5c285d44a6 github.com/sagernet/tfo-go v0.0.0-20230816093905-5a5c285d44a6

24
go.sum
View File

@ -110,22 +110,22 @@ github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 h1:5Th31OC6yj8byL
github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691/go.mod h1:B8lp4WkQ1PwNnrVMM6KyuFR20pU8jYBD+A4EhJovEXU= github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691/go.mod h1:B8lp4WkQ1PwNnrVMM6KyuFR20pU8jYBD+A4EhJovEXU=
github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY= github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY=
github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk= github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk=
github.com/sagernet/sing v0.2.19-0.20231208065534-70794cb91cc5 h1:0VDJK3Y2ZnBHptVyDwtqBMyyIsTcWD+RLzsCYxkixDA= github.com/sagernet/sing v0.2.19-0.20231208110306-a3ce328ce759 h1:BZfmPnZ2n0zD0YZb7UnAAaZ0Ib5riPgKvl5Jasz3LA4=
github.com/sagernet/sing v0.2.19-0.20231208065534-70794cb91cc5/go.mod h1:Ce5LNojQOgOiWhiD8pPD6E9H7e2KgtOe3Zxx4Ou5u80= github.com/sagernet/sing v0.2.19-0.20231208110306-a3ce328ce759/go.mod h1:Ce5LNojQOgOiWhiD8pPD6E9H7e2KgtOe3Zxx4Ou5u80=
github.com/sagernet/sing-dns v0.1.11 h1:PPrMCVVrAeR3f5X23I+cmvacXJ+kzuyAsBiWyUKhGSE= github.com/sagernet/sing-dns v0.1.11 h1:PPrMCVVrAeR3f5X23I+cmvacXJ+kzuyAsBiWyUKhGSE=
github.com/sagernet/sing-dns v0.1.11/go.mod h1:zJ/YjnYB61SYE+ubMcMqVdpaSvsyQ2iShQGO3vuLvvE= github.com/sagernet/sing-dns v0.1.11/go.mod h1:zJ/YjnYB61SYE+ubMcMqVdpaSvsyQ2iShQGO3vuLvvE=
github.com/sagernet/sing-mux v0.1.5 h1:jUbYth9QQd1wsDmU8Ush+fKce7lNo9TMv2dp8PJtSOY= github.com/sagernet/sing-mux v0.1.6-0.20231207143704-9f6c20fb5266 h1:QqwwUyEfmOuoGVTZ2cYvUJEeSWlzunvQLRmv+9B41uk=
github.com/sagernet/sing-mux v0.1.5/go.mod h1:MoH6Soz1R+CYZcCeIXZWx6fkZa6hQc9o3HZu9G6CDTw= github.com/sagernet/sing-mux v0.1.6-0.20231207143704-9f6c20fb5266/go.mod h1:uxpcXa8JqSR+ufC1sGAPsCs027wpE7v1ltnhuJKqyBQ=
github.com/sagernet/sing-quic v0.1.5 h1:PIQzE4cGrry+JkkMEJH/EH3wRkv/QgD48+ScNr/2oig= github.com/sagernet/sing-quic v0.1.6-0.20231207143711-eb3cbf9ed054 h1:Ed7FskwQcep5oQ+QahgVK0F6jPPSV8Nqwjr9MwGatMU=
github.com/sagernet/sing-quic v0.1.5/go.mod h1:n2mXukpubasyV4SlWyyW0+LCdAn7DZ8/brAkUxZujrw= github.com/sagernet/sing-quic v0.1.6-0.20231207143711-eb3cbf9ed054/go.mod h1:u758WWv3G1OITG365CYblL0NfAruFL1PpLD9DUVTv1o=
github.com/sagernet/sing-shadowsocks v0.2.5 h1:qxIttos4xu6ii7MTVJYA8EFQR7Q3KG6xMqmLJIFtBaY= github.com/sagernet/sing-shadowsocks v0.2.6 h1:xr7ylAS/q1cQYS8oxKKajhuQcchd5VJJ4K4UZrrpp0s=
github.com/sagernet/sing-shadowsocks v0.2.5/go.mod h1:MGWGkcU2xW2G2mfArT9/QqpVLOGU+dBaahZCtPHdt7A= github.com/sagernet/sing-shadowsocks v0.2.6/go.mod h1:j2YZBIpWIuElPFL/5sJAj470bcn/3QQ5lxZUNKLDNAM=
github.com/sagernet/sing-shadowsocks2 v0.1.5 h1:JDeAJ4ZWlYZ7F6qEVdDKPhQEangxKw/JtmU+i/YfCYE= github.com/sagernet/sing-shadowsocks2 v0.1.6-0.20231207143709-50439739601a h1:uYIKfpE1/EJpa+1Bja7b006VixeRuVduOpeuesMk2lU=
github.com/sagernet/sing-shadowsocks2 v0.1.5/go.mod h1:KF65y8lI5PGHyMgRZGYXYsH9ilgRc/yr+NYbSNGuBm4= github.com/sagernet/sing-shadowsocks2 v0.1.6-0.20231207143709-50439739601a/go.mod h1:pjeylQ4ApvpEH7B4PUBrdyJf4xmQkg8BaIzT5fI2fR0=
github.com/sagernet/sing-shadowtls v0.1.4 h1:aTgBSJEgnumzFenPvc+kbD9/W0PywzWevnVpEx6Tw3k= github.com/sagernet/sing-shadowtls v0.1.4 h1:aTgBSJEgnumzFenPvc+kbD9/W0PywzWevnVpEx6Tw3k=
github.com/sagernet/sing-shadowtls v0.1.4/go.mod h1:F8NBgsY5YN2beQavdgdm1DPlhaKQlaL6lpDdcBglGK4= github.com/sagernet/sing-shadowtls v0.1.4/go.mod h1:F8NBgsY5YN2beQavdgdm1DPlhaKQlaL6lpDdcBglGK4=
github.com/sagernet/sing-tun v0.1.22 h1:AECJTkiugCK+GCrV41YZ56HB/Z/lDXZvRVas4fNvO30= github.com/sagernet/sing-tun v0.1.23-0.20231207143707-82a810316e14 h1:79d3jw/nlhy3VAIoRvMxRjcOUh7e0D8Mx0cuaBrdIC4=
github.com/sagernet/sing-tun v0.1.22/go.mod h1:fliIEXDRv2u1uT3uCZIoA1daoZcD4f6TeIuzNIzlsN8= github.com/sagernet/sing-tun v0.1.23-0.20231207143707-82a810316e14/go.mod h1:ygdUHhVv4ZEsu0+4rAbAAoHqzqrhvhVNxrbMryapDwI=
github.com/sagernet/sing-vmess v0.1.8 h1:XVWad1RpTy9b5tPxdm5MCU8cGfrTGdR8qCq6HV2aCNc= github.com/sagernet/sing-vmess v0.1.8 h1:XVWad1RpTy9b5tPxdm5MCU8cGfrTGdR8qCq6HV2aCNc=
github.com/sagernet/sing-vmess v0.1.8/go.mod h1:vhx32UNzTDUkNwOyIjcZQohre1CaytquC5mPplId8uA= github.com/sagernet/sing-vmess v0.1.8/go.mod h1:vhx32UNzTDUkNwOyIjcZQohre1CaytquC5mPplId8uA=
github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 h1:HuE6xSwco/Xed8ajZ+coeYLmioq0Qp1/Z2zczFaV8as= github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 h1:HuE6xSwco/Xed8ajZ+coeYLmioq0Qp1/Z2zczFaV8as=

View File

@ -26,7 +26,7 @@ var (
type HTTP struct { type HTTP struct {
myInboundAdapter myInboundAdapter
authenticator auth.Authenticator authenticator *auth.Authenticator
tlsConfig tls.ServerConfig tlsConfig tls.ServerConfig
} }

View File

@ -29,7 +29,7 @@ var (
type Mixed struct { type Mixed struct {
myInboundAdapter myInboundAdapter
authenticator auth.Authenticator authenticator *auth.Authenticator
} }
func NewMixed(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HTTPMixedInboundOptions) *Mixed { func NewMixed(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HTTPMixedInboundOptions) *Mixed {

View File

@ -32,7 +32,7 @@ var _ adapter.Inbound = (*Naive)(nil)
type Naive struct { type Naive struct {
myInboundAdapter myInboundAdapter
authenticator auth.Authenticator authenticator *auth.Authenticator
tlsConfig tls.ServerConfig tlsConfig tls.ServerConfig
httpServer *http.Server httpServer *http.Server
h3Server any h3Server any

View File

@ -22,7 +22,7 @@ var (
type Socks struct { type Socks struct {
myInboundAdapter myInboundAdapter
authenticator auth.Authenticator authenticator *auth.Authenticator
} }
func NewSocks(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SocksInboundOptions) *Socks { func NewSocks(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SocksInboundOptions) *Socks {

View File

@ -111,6 +111,9 @@ func (d *DNS) NewPacketConnection(ctx context.Context, conn N.PacketConn, metada
} }
} }
if readWaiter, created := bufio.CreatePacketReadWaiter(reader); created { if readWaiter, created := bufio.CreatePacketReadWaiter(reader); created {
readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
MTU: dns.FixedPacketSize,
})
return d.newPacketConnection(ctx, conn, readWaiter, counters, cachedPackets, metadata) return d.newPacketConnection(ctx, conn, readWaiter, counters, cachedPackets, metadata)
} }
break break
@ -193,15 +196,13 @@ func (d *DNS) newPacketConnection(ctx context.Context, conn N.PacketConn, readWa
timeout := canceler.New(fastClose, cancel, C.DNSTimeout) timeout := canceler.New(fastClose, cancel, C.DNSTimeout)
var group task.Group var group task.Group
group.Append0(func(ctx context.Context) error { group.Append0(func(ctx context.Context) error {
var buffer *buf.Buffer
readWaiter.InitializeReadWaiter(func() *buf.Buffer {
return buf.NewSize(dns.FixedPacketSize)
})
defer readWaiter.InitializeReadWaiter(nil)
for { for {
var message mDNS.Msg var (
var destination M.Socksaddr message mDNS.Msg
var err error destination M.Socksaddr
err error
buffer *buf.Buffer
)
if len(cached) > 0 { if len(cached) > 0 {
packet := cached[0] packet := cached[0]
cached = cached[1:] cached = cached[1:]
@ -216,9 +217,8 @@ func (d *DNS) newPacketConnection(ctx context.Context, conn N.PacketConn, readWa
} }
destination = packet.Destination destination = packet.Destination
} else { } else {
destination, err = readWaiter.WaitReadPacket() buffer, destination, err = readWaiter.WaitReadPacket()
if err != nil { if err != nil {
buffer.Release()
cancel(err) cancel(err)
return err return err
} }

View File

@ -30,7 +30,7 @@ type ProxyListener struct {
tcpListener *net.TCPListener tcpListener *net.TCPListener
username string username string
password string password string
authenticator auth.Authenticator authenticator *auth.Authenticator
} }
func NewProxyListener(ctx context.Context, logger log.ContextLogger, dialer N.Dialer) *ProxyListener { func NewProxyListener(ctx context.Context, logger log.ContextLogger, dialer N.Dialer) *ProxyListener {

View File

@ -17,16 +17,16 @@ func (c *NATPacketConn) CreatePacketReadWaiter() (N.PacketReadWaiter, bool) {
type waitNATPacketConn struct { type waitNATPacketConn struct {
*NATPacketConn *NATPacketConn
waiter N.PacketReadWaiter readWaiter N.PacketReadWaiter
} }
func (c *waitNATPacketConn) InitializeReadWaiter(newBuffer func() *buf.Buffer) { func (c *waitNATPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.waiter.InitializeReadWaiter(newBuffer) return c.readWaiter.InitializeReadWaiter(options)
} }
func (c *waitNATPacketConn) WaitReadPacket() (destination M.Socksaddr, err error) { func (c *waitNATPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
destination, err = c.waiter.WaitReadPacket() buffer, destination, err = c.readWaiter.WaitReadPacket()
if socksaddrWithoutPort(destination) == c.origin { if err == nil && socksaddrWithoutPort(destination) == c.origin {
destination = M.Socksaddr{ destination = M.Socksaddr{
Addr: c.destination.Addr, Addr: c.destination.Addr,
Fqdn: c.destination.Fqdn, Fqdn: c.destination.Fqdn,

View File

@ -53,7 +53,7 @@ func newMuxConnection0(ctx context.Context, stream net.Conn, metadata M.Metadata
case CommandTCP: case CommandTCP:
return handler.NewConnection(ctx, stream, metadata) return handler.NewConnection(ctx, stream, metadata)
case CommandUDP: case CommandUDP:
return handler.NewPacketConnection(ctx, &PacketConn{stream}, metadata) return handler.NewPacketConnection(ctx, &PacketConn{Conn: stream}, metadata)
default: default:
return E.New("unknown command ", command) return E.New("unknown command ", command)
} }

View File

@ -85,9 +85,10 @@ func (c *ClientConn) Upstream() any {
type ClientPacketConn struct { type ClientPacketConn struct {
net.Conn net.Conn
access sync.Mutex access sync.Mutex
key [KeyLength]byte key [KeyLength]byte
headerWritten bool headerWritten bool
readWaitOptions N.ReadWaitOptions
} }
func NewClientPacketConn(conn net.Conn, key [KeyLength]byte) *ClientPacketConn { func NewClientPacketConn(conn net.Conn, key [KeyLength]byte) *ClientPacketConn {

View File

@ -0,0 +1,45 @@
package trojan
import (
"encoding/binary"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
)
var _ N.PacketReadWaiter = (*ClientPacketConn)(nil)
func (c *ClientPacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return false
}
func (c *ClientPacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
destination, err = M.SocksaddrSerializer.ReadAddrPort(c.Conn)
if err != nil {
return nil, M.Socksaddr{}, E.Cause(err, "read destination")
}
var length uint16
err = binary.Read(c.Conn, binary.BigEndian, &length)
if err != nil {
return nil, M.Socksaddr{}, E.Cause(err, "read chunk length")
}
err = rw.SkipN(c.Conn, 2)
if err != nil {
return nil, M.Socksaddr{}, E.Cause(err, "skip crlf")
}
buffer = c.readWaitOptions.NewPacketBuffer()
_, err = buffer.ReadFullFrom(c.Conn, int(length))
if err != nil {
buffer.Release()
return
}
c.readWaitOptions.PostReturn(buffer)
return
}

View File

@ -105,7 +105,7 @@ func (s *Service[K]) NewConnection(ctx context.Context, conn net.Conn, metadata
case CommandTCP: case CommandTCP:
return s.handler.NewConnection(ctx, conn, metadata) return s.handler.NewConnection(ctx, conn, metadata)
case CommandUDP: case CommandUDP:
return s.handler.NewPacketConnection(ctx, &PacketConn{conn}, metadata) return s.handler.NewPacketConnection(ctx, &PacketConn{Conn: conn}, metadata)
// case CommandMux: // case CommandMux:
default: default:
return HandleMuxConnection(ctx, conn, metadata, s.handler) return HandleMuxConnection(ctx, conn, metadata, s.handler)
@ -122,6 +122,7 @@ func (s *Service[K]) fallback(ctx context.Context, conn net.Conn, metadata M.Met
type PacketConn struct { type PacketConn struct {
net.Conn net.Conn
readWaitOptions N.ReadWaitOptions
} }
func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {

View File

@ -0,0 +1,45 @@
package trojan
import (
"encoding/binary"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
)
var _ N.PacketReadWaiter = (*PacketConn)(nil)
func (c *PacketConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return false
}
func (c *PacketConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
destination, err = M.SocksaddrSerializer.ReadAddrPort(c.Conn)
if err != nil {
return nil, M.Socksaddr{}, E.Cause(err, "read destination")
}
var length uint16
err = binary.Read(c.Conn, binary.BigEndian, &length)
if err != nil {
return nil, M.Socksaddr{}, E.Cause(err, "read chunk length")
}
err = rw.SkipN(c.Conn, 2)
if err != nil {
return nil, M.Socksaddr{}, E.Cause(err, "skip crlf")
}
buffer = c.readWaitOptions.NewPacketBuffer()
_, err = buffer.ReadFullFrom(c.Conn, int(length))
if err != nil {
buffer.Release()
return
}
c.readWaitOptions.PostReturn(buffer)
return
}

View File

@ -76,11 +76,8 @@ func (c *ClientBind) connect() (*wireConn, error) {
return nil, err return nil, err
} }
c.conn = &wireConn{ c.conn = &wireConn{
PacketConn: &bufio.UnbindPacketConn{ PacketConn: bufio.NewUnbindPacketConn(udpConn),
ExtendedConn: bufio.NewExtendedConn(udpConn), done: make(chan struct{}),
Addr: c.connectAddr,
},
done: make(chan struct{}),
} }
} else { } else {
udpConn, err := c.dialer.ListenPacket(c.ctx, M.Socksaddr{Addr: netip.IPv4Unspecified()}) udpConn, err := c.dialer.ListenPacket(c.ctx, M.Socksaddr{Addr: netip.IPv4Unspecified()})