mirror of
https://github.com/SagerNet/sing-box.git
synced 2025-06-08 03:34:13 +08:00
Add hysteria udp server
This commit is contained in:
parent
b992d942c4
commit
32cb511b7c
@ -7,19 +7,20 @@ import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/quic-go"
|
||||
"github.com/sagernet/quic-go/congestion"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing-box/transport/hysteria"
|
||||
dns "github.com/sagernet/sing-dns"
|
||||
"github.com/sagernet/sing-dns"
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
var _ adapter.Inbound = (*Hysteria)(nil)
|
||||
@ -37,6 +38,10 @@ type Hysteria struct {
|
||||
sendBPS uint64
|
||||
recvBPS uint64
|
||||
listener quic.Listener
|
||||
udpAccess sync.RWMutex
|
||||
udpSessionId uint32
|
||||
udpSessions map[uint32]chan *hysteria.UDPMessage
|
||||
udpDefragger hysteria.Defragger
|
||||
}
|
||||
|
||||
func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaInboundOptions) (*Hysteria, error) {
|
||||
@ -105,6 +110,7 @@ func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextL
|
||||
xplusKey: xplus,
|
||||
sendBPS: up,
|
||||
recvBPS: down,
|
||||
udpSessions: make(map[uint32]chan *hysteria.UDPMessage),
|
||||
}
|
||||
if options.TLS == nil || !options.TLS.Enabled {
|
||||
return nil, ErrTLSRequired
|
||||
@ -138,6 +144,7 @@ func (h *Hysteria) Start() error {
|
||||
}
|
||||
if len(h.xplusKey) > 0 {
|
||||
packetConn = hysteria.NewXPlusPacketConn(packetConn, h.xplusKey)
|
||||
packetConn = &hysteria.PacketConnWrapper{PacketConn: packetConn}
|
||||
}
|
||||
err = h.tlsConfig.Start()
|
||||
if err != nil {
|
||||
@ -159,6 +166,7 @@ func (h *Hysteria) acceptLoop() {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
h.logger.InfoContext(ctx, "inbound connection from ", conn.RemoteAddr())
|
||||
go func() {
|
||||
hErr := h.accept(ctx, conn)
|
||||
if hErr != nil {
|
||||
@ -202,23 +210,51 @@ func (h *Hysteria) accept(ctx context.Context, conn quic.Connection) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// TODO: set congestion control
|
||||
conn.SetCongestionControl(hysteria.NewBrutalSender(congestion.ByteCount(serverSendBPS)))
|
||||
go h.udpRecvLoop(conn)
|
||||
var stream quic.Stream
|
||||
for {
|
||||
var stream quic.Stream
|
||||
stream, err = conn.AcceptStream(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hErr := h.acceptStream(ctx, conn, stream)
|
||||
if hErr != nil {
|
||||
stream.Close()
|
||||
NewError(h.logger, ctx, E.Cause(hErr, "process stream from ", conn.RemoteAddr()))
|
||||
}
|
||||
go func() {
|
||||
hErr := h.acceptStream(ctx, conn /*&hysteria.StreamWrapper{Stream: stream}*/, stream)
|
||||
if hErr != nil {
|
||||
stream.Close()
|
||||
NewError(h.logger, ctx, E.Cause(hErr, "process stream from ", conn.RemoteAddr()))
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hysteria) udpRecvLoop(conn quic.Connection) {
|
||||
for {
|
||||
packet, err := conn.ReceiveMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
message, err := hysteria.ParseUDPMessage(packet)
|
||||
if err != nil {
|
||||
h.logger.Error("parse udp message: ", err)
|
||||
continue
|
||||
}
|
||||
dfMsg := h.udpDefragger.Feed(message)
|
||||
if dfMsg == nil {
|
||||
continue
|
||||
}
|
||||
h.udpAccess.RLock()
|
||||
ch, ok := h.udpSessions[dfMsg.SessionID]
|
||||
if ok {
|
||||
select {
|
||||
case ch <- dfMsg:
|
||||
// OK
|
||||
default:
|
||||
// Silently drop the message when the channel is full
|
||||
}
|
||||
}
|
||||
h.udpAccess.RUnlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hysteria) acceptStream(ctx context.Context, conn quic.Connection, stream quic.Stream) error {
|
||||
@ -226,15 +262,11 @@ func (h *Hysteria) acceptStream(ctx context.Context, conn quic.Connection, strea
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if request.UDP {
|
||||
err = hysteria.WriteServerResponse(stream, hysteria.ServerResponse{
|
||||
Message: "unsupported",
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stream.Close()
|
||||
return nil
|
||||
err = hysteria.WriteServerResponse(stream, hysteria.ServerResponse{
|
||||
OK: true,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var metadata adapter.InboundContext
|
||||
metadata.Inbound = h.tag
|
||||
@ -242,13 +274,43 @@ func (h *Hysteria) acceptStream(ctx context.Context, conn quic.Connection, strea
|
||||
metadata.SniffEnabled = h.listenOptions.SniffEnabled
|
||||
metadata.SniffOverrideDestination = h.listenOptions.SniffOverrideDestination
|
||||
metadata.DomainStrategy = dns.DomainStrategy(h.listenOptions.DomainStrategy)
|
||||
metadata.Network = N.NetworkTCP
|
||||
metadata.Source = M.SocksaddrFromNet(conn.RemoteAddr())
|
||||
metadata.Destination = M.ParseSocksaddrHostPort(request.Host, request.Port)
|
||||
return h.router.RouteConnection(ctx, hysteria.NewServerConn(stream, metadata.Destination), metadata)
|
||||
if !request.UDP {
|
||||
h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
|
||||
metadata.Network = N.NetworkTCP
|
||||
return h.router.RouteConnection(ctx, hysteria.NewConn(stream, metadata.Destination), metadata)
|
||||
} else {
|
||||
h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
|
||||
var id uint32
|
||||
h.udpAccess.Lock()
|
||||
id = h.udpSessionId
|
||||
nCh := make(chan *hysteria.UDPMessage, 1024)
|
||||
h.udpSessions[id] = nCh
|
||||
h.udpSessionId += 1
|
||||
h.udpAccess.Unlock()
|
||||
metadata.Network = N.NetworkUDP
|
||||
packetConn := hysteria.NewPacketConn(conn, stream, id, metadata.Destination, nCh, common.Closer(func() error {
|
||||
h.udpAccess.Lock()
|
||||
if ch, ok := h.udpSessions[id]; ok {
|
||||
close(ch)
|
||||
delete(h.udpSessions, id)
|
||||
}
|
||||
h.udpAccess.Unlock()
|
||||
return nil
|
||||
}))
|
||||
go packetConn.Hold()
|
||||
return h.router.RoutePacketConnection(ctx, packetConn, metadata)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hysteria) Close() error {
|
||||
h.udpAccess.Lock()
|
||||
for _, session := range h.udpSessions {
|
||||
close(session)
|
||||
}
|
||||
h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage)
|
||||
h.udpAccess.Unlock()
|
||||
return common.Close(
|
||||
h.listener,
|
||||
common.PtrOrNil(h.tlsConfig),
|
||||
|
@ -195,6 +195,8 @@ func (n *Naive) newConnection(ctx context.Context, conn net.Conn, source, destin
|
||||
metadata.Network = N.NetworkTCP
|
||||
metadata.Source = source
|
||||
metadata.Destination = destination
|
||||
n.logger.InfoContext(ctx, "inbound connection from ", metadata.Source)
|
||||
n.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
|
||||
hErr := n.router.RouteConnection(ctx, conn, metadata)
|
||||
if hErr != nil {
|
||||
conn.Close()
|
||||
|
@ -10,6 +10,8 @@ import (
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/quic-go"
|
||||
"github.com/sagernet/quic-go/congestion"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
@ -21,8 +23,6 @@ import (
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
var _ adapter.Outbound = (*Hysteria)(nil)
|
||||
@ -171,7 +171,7 @@ func (h *Hysteria) offer(ctx context.Context) (quic.Connection, error) {
|
||||
}
|
||||
h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage)
|
||||
h.udpDefragger = hysteria.Defragger{}
|
||||
go h.recvLoop(conn)
|
||||
go h.udpRecvLoop(conn)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
@ -186,7 +186,7 @@ func (h *Hysteria) offerNew(ctx context.Context) (quic.Connection, error) {
|
||||
if h.xplusKey != nil {
|
||||
packetConn = hysteria.NewXPlusPacketConn(packetConn, h.xplusKey)
|
||||
}
|
||||
packetConn = &hysteria.WrapPacketConn{PacketConn: packetConn}
|
||||
packetConn = &hysteria.PacketConnWrapper{PacketConn: packetConn}
|
||||
quicConn, err := quic.Dial(packetConn, udpConn.RemoteAddr(), h.serverAddr.AddrString(), h.tlsConfig, h.quicConfig)
|
||||
if err != nil {
|
||||
packetConn.Close()
|
||||
@ -203,20 +203,23 @@ func (h *Hysteria) offerNew(ctx context.Context) (quic.Connection, error) {
|
||||
Auth: h.authKey,
|
||||
})
|
||||
if err != nil {
|
||||
packetConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
serverHello, err := hysteria.ReadServerHello(controlStream)
|
||||
if err != nil {
|
||||
packetConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
if !serverHello.OK {
|
||||
packetConn.Close()
|
||||
return nil, E.New("remote error: ", serverHello.Message)
|
||||
}
|
||||
// TODO: set congestion control
|
||||
quicConn.SetCongestionControl(hysteria.NewBrutalSender(congestion.ByteCount(serverHello.RecvBPS)))
|
||||
return quicConn, nil
|
||||
}
|
||||
|
||||
func (h *Hysteria) recvLoop(conn quic.Connection) {
|
||||
func (h *Hysteria) udpRecvLoop(conn quic.Connection) {
|
||||
for {
|
||||
packet, err := conn.ReceiveMessage()
|
||||
if err != nil {
|
||||
@ -260,35 +263,58 @@ func (h *Hysteria) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Hysteria) open(ctx context.Context) (quic.Connection, quic.Stream, error) {
|
||||
conn, err := h.offer(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
stream, err := conn.OpenStream()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return conn, &hysteria.StreamWrapper{Stream: stream}, nil
|
||||
}
|
||||
|
||||
func (h *Hysteria) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
switch N.NetworkName(network) {
|
||||
case N.NetworkTCP:
|
||||
conn, err := h.offer(ctx)
|
||||
h.logger.InfoContext(ctx, "outbound connection to ", destination)
|
||||
_, stream, err := h.open(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stream, err := conn.OpenStream()
|
||||
err = hysteria.WriteClientRequest(stream, hysteria.ClientRequest{
|
||||
Host: destination.AddrString(),
|
||||
Port: destination.Port,
|
||||
})
|
||||
if err != nil {
|
||||
stream.Close()
|
||||
return nil, err
|
||||
}
|
||||
return hysteria.NewClientConn(stream, destination), nil
|
||||
response, err := hysteria.ReadServerResponse(stream)
|
||||
if err != nil {
|
||||
stream.Close()
|
||||
return nil, err
|
||||
}
|
||||
if !response.OK {
|
||||
stream.Close()
|
||||
return nil, E.New("remote error: ", response.Message)
|
||||
}
|
||||
return hysteria.NewConn(stream, destination), nil
|
||||
case N.NetworkUDP:
|
||||
conn, err := h.ListenPacket(ctx, destination)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conn.(*hysteria.ClientPacketConn), nil
|
||||
return conn.(*hysteria.PacketConn), nil
|
||||
default:
|
||||
return nil, E.New("unsupported network: ", network)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hysteria) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
conn, err := h.offer(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stream, err := conn.OpenStream()
|
||||
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
||||
conn, stream, err := h.open(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -296,7 +322,7 @@ func (h *Hysteria) ListenPacket(ctx context.Context, destination M.Socksaddr) (n
|
||||
UDP: true,
|
||||
Host: destination.AddrString(),
|
||||
Port: destination.Port,
|
||||
}, nil)
|
||||
})
|
||||
if err != nil {
|
||||
stream.Close()
|
||||
return nil, err
|
||||
@ -313,12 +339,9 @@ func (h *Hysteria) ListenPacket(ctx context.Context, destination M.Socksaddr) (n
|
||||
}
|
||||
h.udpAccess.Lock()
|
||||
nCh := make(chan *hysteria.UDPMessage, 1024)
|
||||
// Store the current session map for CloseFunc below
|
||||
// to ensures that we are adding and removing sessions on the same map,
|
||||
// as reconnecting will reassign the map
|
||||
h.udpSessions[response.UDPSessionID] = nCh
|
||||
h.udpAccess.Unlock()
|
||||
packetConn := hysteria.NewClientPacketConn(conn, stream, response.UDPSessionID, destination, nCh, common.Closer(func() error {
|
||||
packetConn := hysteria.NewPacketConn(conn, stream, response.UDPSessionID, destination, nCh, common.Closer(func() error {
|
||||
h.udpAccess.Lock()
|
||||
if ch, ok := h.udpSessions[response.UDPSessionID]; ok {
|
||||
close(ch)
|
||||
|
@ -56,10 +56,21 @@ func testTCP(t *testing.T, clientPort uint16, testPort uint16) {
|
||||
dialTCP := func() (net.Conn, error) {
|
||||
return dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddrHostPort("127.0.0.1", testPort))
|
||||
}
|
||||
require.NoError(t, testPingPongWithConn(t, testPort, dialTCP))
|
||||
require.NoError(t, testLargeDataWithConn(t, testPort, dialTCP))
|
||||
}
|
||||
|
||||
func testSuitHy(t *testing.T, clientPort uint16, testPort uint16) {
|
||||
dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", clientPort), socks.Version5, "", "")
|
||||
dialTCP := func() (net.Conn, error) {
|
||||
return dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddrHostPort("127.0.0.1", testPort))
|
||||
}
|
||||
dialUDP := func() (net.PacketConn, error) {
|
||||
return dialer.ListenPacket(context.Background(), M.ParseSocksaddrHostPort("127.0.0.1", testPort))
|
||||
}
|
||||
require.NoError(t, testPingPongWithConn(t, testPort, dialTCP))
|
||||
require.NoError(t, testPingPongWithPacketConn(t, testPort, dialUDP))
|
||||
}
|
||||
|
||||
func testSuitWg(t *testing.T, clientPort uint16, testPort uint16) {
|
||||
dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", clientPort), socks.Version5, "", "")
|
||||
dialTCP := func() (net.Conn, error) {
|
||||
|
@ -7,6 +7,8 @@ import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
@ -91,6 +93,12 @@ func init() {
|
||||
|
||||
io.Copy(io.Discard, imageStream)
|
||||
}
|
||||
go func() {
|
||||
err = http.ListenAndServe("0.0.0.0:8965", nil)
|
||||
if err != nil {
|
||||
log.Debug(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func newPingPongPair() (chan []byte, chan []byte, func(t *testing.T) error) {
|
||||
|
12
test/config/hysteria-client.json
Normal file
12
test/config/hysteria-client.json
Normal file
@ -0,0 +1,12 @@
|
||||
{
|
||||
"server": "127.0.0.1:10000",
|
||||
"auth_str": "password",
|
||||
"obfs": "fuck me till the daylight",
|
||||
"up_mbps": 100,
|
||||
"down_mbps": 100,
|
||||
"socks5": {
|
||||
"listen": "127.0.0.1:10001"
|
||||
},
|
||||
"server_name": "example.org",
|
||||
"ca": "/etc/hysteria/ca.pem"
|
||||
}
|
@ -80,7 +80,50 @@ func TestHysteriaSelf(t *testing.T) {
|
||||
},
|
||||
},
|
||||
})
|
||||
testTCP(t, clientPort, testPort)
|
||||
testSuitHy(t, clientPort, testPort)
|
||||
}
|
||||
|
||||
func TestHysteriaInbound(t *testing.T) {
|
||||
if !C.QUIC_AVAILABLE {
|
||||
t.Skip("QUIC not included")
|
||||
}
|
||||
caPem, certPem, keyPem := createSelfSignedCertificate(t, "example.org")
|
||||
startInstance(t, option.Options{
|
||||
Log: &option.LogOptions{
|
||||
Level: "trace",
|
||||
},
|
||||
Inbounds: []option.Inbound{
|
||||
{
|
||||
Type: C.TypeHysteria,
|
||||
HysteriaOptions: option.HysteriaInboundOptions{
|
||||
ListenOptions: option.ListenOptions{
|
||||
Listen: option.ListenAddress(netip.IPv4Unspecified()),
|
||||
ListenPort: serverPort,
|
||||
},
|
||||
UpMbps: 100,
|
||||
DownMbps: 100,
|
||||
AuthString: "password",
|
||||
Obfs: "fuck me till the daylight",
|
||||
TLS: &option.InboundTLSOptions{
|
||||
Enabled: true,
|
||||
ServerName: "example.org",
|
||||
CertificatePath: certPem,
|
||||
KeyPath: keyPem,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
startDockerContainer(t, DockerOptions{
|
||||
Image: ImageHysteria,
|
||||
Ports: []uint16{serverPort, clientPort},
|
||||
Cmd: []string{"-c", "/etc/hysteria/config.json", "client"},
|
||||
Bind: map[string]string{
|
||||
"hysteria-client.json": "/etc/hysteria/config.json",
|
||||
caPem: "/etc/hysteria/ca.pem",
|
||||
},
|
||||
})
|
||||
testSuit(t, clientPort, testPort)
|
||||
}
|
||||
|
||||
func TestHysteriaOutbound(t *testing.T) {
|
||||
@ -93,9 +136,9 @@ func TestHysteriaOutbound(t *testing.T) {
|
||||
Ports: []uint16{serverPort, testPort},
|
||||
Cmd: []string{"-c", "/etc/hysteria/config.json", "server"},
|
||||
Bind: map[string]string{
|
||||
"hysteria.json": "/etc/hysteria/config.json",
|
||||
certPem: "/etc/hysteria/cert.pem",
|
||||
keyPem: "/etc/hysteria/key.pem",
|
||||
"hysteria-server.json": "/etc/hysteria/config.json",
|
||||
certPem: "/etc/hysteria/cert.pem",
|
||||
keyPem: "/etc/hysteria/key.pem",
|
||||
},
|
||||
})
|
||||
startInstance(t, option.Options{
|
||||
@ -131,5 +174,5 @@ func TestHysteriaOutbound(t *testing.T) {
|
||||
},
|
||||
},
|
||||
})
|
||||
testSuit(t, clientPort, testPort)
|
||||
testSuitHy(t, clientPort, testPort)
|
||||
}
|
||||
|
149
transport/hysteria/brutal.go
Normal file
149
transport/hysteria/brutal.go
Normal file
@ -0,0 +1,149 @@
|
||||
package hysteria
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/quic-go/congestion"
|
||||
)
|
||||
|
||||
const (
|
||||
initMaxDatagramSize = 1252
|
||||
|
||||
pktInfoSlotCount = 4
|
||||
minSampleCount = 50
|
||||
minAckRate = 0.8
|
||||
)
|
||||
|
||||
type BrutalSender struct {
|
||||
rttStats congestion.RTTStatsProvider
|
||||
bps congestion.ByteCount
|
||||
maxDatagramSize congestion.ByteCount
|
||||
pacer *pacer
|
||||
|
||||
pktInfoSlots [pktInfoSlotCount]pktInfo
|
||||
ackRate float64
|
||||
}
|
||||
|
||||
type pktInfo struct {
|
||||
Timestamp int64
|
||||
AckCount uint64
|
||||
LossCount uint64
|
||||
}
|
||||
|
||||
func NewBrutalSender(bps congestion.ByteCount) *BrutalSender {
|
||||
bs := &BrutalSender{
|
||||
bps: bps,
|
||||
maxDatagramSize: initMaxDatagramSize,
|
||||
ackRate: 1,
|
||||
}
|
||||
bs.pacer = newPacer(func() congestion.ByteCount {
|
||||
return congestion.ByteCount(float64(bs.bps) / bs.ackRate)
|
||||
})
|
||||
return bs
|
||||
}
|
||||
|
||||
func (b *BrutalSender) SetRTTStatsProvider(rttStats congestion.RTTStatsProvider) {
|
||||
b.rttStats = rttStats
|
||||
}
|
||||
|
||||
func (b *BrutalSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time {
|
||||
return b.pacer.TimeUntilSend()
|
||||
}
|
||||
|
||||
func (b *BrutalSender) HasPacingBudget() bool {
|
||||
return b.pacer.Budget(time.Now()) >= b.maxDatagramSize
|
||||
}
|
||||
|
||||
func (b *BrutalSender) CanSend(bytesInFlight congestion.ByteCount) bool {
|
||||
return bytesInFlight < b.GetCongestionWindow()
|
||||
}
|
||||
|
||||
func (b *BrutalSender) GetCongestionWindow() congestion.ByteCount {
|
||||
rtt := maxDuration(b.rttStats.LatestRTT(), b.rttStats.SmoothedRTT())
|
||||
if rtt <= 0 {
|
||||
return 10240
|
||||
}
|
||||
return congestion.ByteCount(float64(b.bps) * rtt.Seconds() * 1.5 / b.ackRate)
|
||||
}
|
||||
|
||||
func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount,
|
||||
packetNumber congestion.PacketNumber, bytes congestion.ByteCount, isRetransmittable bool,
|
||||
) {
|
||||
b.pacer.SentPacket(sentTime, bytes)
|
||||
}
|
||||
|
||||
func (b *BrutalSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount,
|
||||
priorInFlight congestion.ByteCount, eventTime time.Time,
|
||||
) {
|
||||
currentTimestamp := eventTime.Unix()
|
||||
slot := currentTimestamp % pktInfoSlotCount
|
||||
if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
|
||||
b.pktInfoSlots[slot].AckCount++
|
||||
} else {
|
||||
// uninitialized slot or too old, reset
|
||||
b.pktInfoSlots[slot].Timestamp = currentTimestamp
|
||||
b.pktInfoSlots[slot].AckCount = 1
|
||||
b.pktInfoSlots[slot].LossCount = 0
|
||||
}
|
||||
b.updateAckRate(currentTimestamp)
|
||||
}
|
||||
|
||||
func (b *BrutalSender) OnPacketLost(number congestion.PacketNumber, lostBytes congestion.ByteCount,
|
||||
priorInFlight congestion.ByteCount,
|
||||
) {
|
||||
currentTimestamp := time.Now().Unix()
|
||||
slot := currentTimestamp % pktInfoSlotCount
|
||||
if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
|
||||
b.pktInfoSlots[slot].LossCount++
|
||||
} else {
|
||||
// uninitialized slot or too old, reset
|
||||
b.pktInfoSlots[slot].Timestamp = currentTimestamp
|
||||
b.pktInfoSlots[slot].AckCount = 0
|
||||
b.pktInfoSlots[slot].LossCount = 1
|
||||
}
|
||||
b.updateAckRate(currentTimestamp)
|
||||
}
|
||||
|
||||
func (b *BrutalSender) SetMaxDatagramSize(size congestion.ByteCount) {
|
||||
b.maxDatagramSize = size
|
||||
b.pacer.SetMaxDatagramSize(size)
|
||||
}
|
||||
|
||||
func (b *BrutalSender) updateAckRate(currentTimestamp int64) {
|
||||
minTimestamp := currentTimestamp - pktInfoSlotCount
|
||||
var ackCount, lossCount uint64
|
||||
for _, info := range b.pktInfoSlots {
|
||||
if info.Timestamp < minTimestamp {
|
||||
continue
|
||||
}
|
||||
ackCount += info.AckCount
|
||||
lossCount += info.LossCount
|
||||
}
|
||||
if ackCount+lossCount < minSampleCount {
|
||||
b.ackRate = 1
|
||||
}
|
||||
rate := float64(ackCount) / float64(ackCount+lossCount)
|
||||
if rate < minAckRate {
|
||||
b.ackRate = minAckRate
|
||||
}
|
||||
b.ackRate = rate
|
||||
}
|
||||
|
||||
func (b *BrutalSender) InSlowStart() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (b *BrutalSender) InRecovery() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (b *BrutalSender) MaybeExitSlowStart() {}
|
||||
|
||||
func (b *BrutalSender) OnRetransmissionTimeout(packetsRetransmitted bool) {}
|
||||
|
||||
func maxDuration(a, b time.Duration) time.Duration {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
@ -1,186 +0,0 @@
|
||||
package hysteria
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
var _ net.Conn = (*ClientConn)(nil)
|
||||
|
||||
type ClientConn struct {
|
||||
quic.Stream
|
||||
destination M.Socksaddr
|
||||
requestWritten bool
|
||||
responseRead bool
|
||||
}
|
||||
|
||||
func NewClientConn(stream quic.Stream, destination M.Socksaddr) *ClientConn {
|
||||
return &ClientConn{
|
||||
Stream: stream,
|
||||
destination: destination,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientConn) Read(b []byte) (n int, err error) {
|
||||
if !c.responseRead {
|
||||
var response *ServerResponse
|
||||
response, err = ReadServerResponse(c.Stream)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if !response.OK {
|
||||
return 0, E.New("remote error: " + response.Message)
|
||||
}
|
||||
c.responseRead = true
|
||||
}
|
||||
return c.Stream.Read(b)
|
||||
}
|
||||
|
||||
func (c *ClientConn) Write(b []byte) (n int, err error) {
|
||||
if !c.requestWritten {
|
||||
err = WriteClientRequest(c.Stream, ClientRequest{
|
||||
UDP: false,
|
||||
Host: c.destination.Unwrap().AddrString(),
|
||||
Port: c.destination.Port,
|
||||
}, b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.requestWritten = true
|
||||
return len(b), nil
|
||||
}
|
||||
return c.Stream.Write(b)
|
||||
}
|
||||
|
||||
func (c *ClientConn) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientConn) RemoteAddr() net.Addr {
|
||||
return c.destination.TCPAddr()
|
||||
}
|
||||
|
||||
func (c *ClientConn) Upstream() any {
|
||||
return c.Stream
|
||||
}
|
||||
|
||||
func (c *ClientConn) ReaderReplaceable() bool {
|
||||
return c.responseRead
|
||||
}
|
||||
|
||||
func (c *ClientConn) WriterReplaceable() bool {
|
||||
return c.requestWritten
|
||||
}
|
||||
|
||||
type ClientPacketConn struct {
|
||||
session quic.Connection
|
||||
stream quic.Stream
|
||||
sessionId uint32
|
||||
destination M.Socksaddr
|
||||
msgCh <-chan *UDPMessage
|
||||
closer io.Closer
|
||||
}
|
||||
|
||||
func NewClientPacketConn(session quic.Connection, stream quic.Stream, sessionId uint32, destination M.Socksaddr, msgCh <-chan *UDPMessage, closer io.Closer) *ClientPacketConn {
|
||||
return &ClientPacketConn{
|
||||
session: session,
|
||||
stream: stream,
|
||||
sessionId: sessionId,
|
||||
destination: destination,
|
||||
msgCh: msgCh,
|
||||
closer: closer,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) Hold() {
|
||||
// Hold the stream until it's closed
|
||||
buf := make([]byte, 1024)
|
||||
for {
|
||||
_, err := c.stream.Read(buf)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
_ = c.Close()
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
msg := <-c.msgCh
|
||||
if msg == nil {
|
||||
err = net.ErrClosed
|
||||
return
|
||||
}
|
||||
err = common.Error(buffer.Write(msg.Data))
|
||||
destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
msg := <-c.msgCh
|
||||
if msg == nil {
|
||||
err = net.ErrClosed
|
||||
return
|
||||
}
|
||||
buffer = buf.As(msg.Data)
|
||||
destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
return WriteUDPMessage(c.session, UDPMessage{
|
||||
SessionID: c.sessionId,
|
||||
Host: destination.Unwrap().AddrString(),
|
||||
Port: destination.Port,
|
||||
FragCount: 1,
|
||||
Data: buffer.Bytes(),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) RemoteAddr() net.Addr {
|
||||
return c.destination.UDPAddr()
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) SetDeadline(t time.Time) error {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) SetReadDeadline(t time.Time) error {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) SetWriteDeadline(t time.Time) error {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
return 0, nil, os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
return 0, os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) Read(b []byte) (n int, err error) {
|
||||
return 0, os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) Write(b []byte) (n int, err error) {
|
||||
return 0, os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) Close() error {
|
||||
return common.Close(c.stream, c.closer)
|
||||
}
|
86
transport/hysteria/pacer.go
Normal file
86
transport/hysteria/pacer.go
Normal file
@ -0,0 +1,86 @@
|
||||
package hysteria
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/quic-go/congestion"
|
||||
)
|
||||
|
||||
const (
|
||||
maxBurstPackets = 10
|
||||
minPacingDelay = time.Millisecond
|
||||
)
|
||||
|
||||
// The pacer implements a token bucket pacing algorithm.
|
||||
type pacer struct {
|
||||
budgetAtLastSent congestion.ByteCount
|
||||
maxDatagramSize congestion.ByteCount
|
||||
lastSentTime time.Time
|
||||
getBandwidth func() congestion.ByteCount // in bytes/s
|
||||
}
|
||||
|
||||
func newPacer(getBandwidth func() congestion.ByteCount) *pacer {
|
||||
p := &pacer{
|
||||
budgetAtLastSent: maxBurstPackets * initMaxDatagramSize,
|
||||
maxDatagramSize: initMaxDatagramSize,
|
||||
getBandwidth: getBandwidth,
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) {
|
||||
budget := p.Budget(sendTime)
|
||||
if size > budget {
|
||||
p.budgetAtLastSent = 0
|
||||
} else {
|
||||
p.budgetAtLastSent = budget - size
|
||||
}
|
||||
p.lastSentTime = sendTime
|
||||
}
|
||||
|
||||
func (p *pacer) Budget(now time.Time) congestion.ByteCount {
|
||||
if p.lastSentTime.IsZero() {
|
||||
return p.maxBurstSize()
|
||||
}
|
||||
budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
|
||||
return minByteCount(p.maxBurstSize(), budget)
|
||||
}
|
||||
|
||||
func (p *pacer) maxBurstSize() congestion.ByteCount {
|
||||
return maxByteCount(
|
||||
congestion.ByteCount((minPacingDelay+time.Millisecond).Nanoseconds())*p.getBandwidth()/1e9,
|
||||
maxBurstPackets*p.maxDatagramSize,
|
||||
)
|
||||
}
|
||||
|
||||
// TimeUntilSend returns when the next packet should be sent.
|
||||
// It returns the zero value of time.Time if a packet can be sent immediately.
|
||||
func (p *pacer) TimeUntilSend() time.Time {
|
||||
if p.budgetAtLastSent >= p.maxDatagramSize {
|
||||
return time.Time{}
|
||||
}
|
||||
return p.lastSentTime.Add(maxDuration(
|
||||
minPacingDelay,
|
||||
time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/
|
||||
float64(p.getBandwidth())))*time.Nanosecond,
|
||||
))
|
||||
}
|
||||
|
||||
func (p *pacer) SetMaxDatagramSize(s congestion.ByteCount) {
|
||||
p.maxDatagramSize = s
|
||||
}
|
||||
|
||||
func maxByteCount(a, b congestion.ByteCount) congestion.ByteCount {
|
||||
if a < b {
|
||||
return b
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func minByteCount(a, b congestion.ByteCount) congestion.ByteCount {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
@ -5,13 +5,15 @@ import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/quic-go"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -177,13 +179,12 @@ func ReadClientRequest(stream io.Reader) (*ClientRequest, error) {
|
||||
return &clientRequest, nil
|
||||
}
|
||||
|
||||
func WriteClientRequest(stream io.Writer, request ClientRequest, payload []byte) error {
|
||||
func WriteClientRequest(stream io.Writer, request ClientRequest) error {
|
||||
var requestLen int
|
||||
requestLen += 1 // udp
|
||||
requestLen += 2 // host len
|
||||
requestLen += len(request.Host)
|
||||
requestLen += 2 // port
|
||||
requestLen += len(payload)
|
||||
_buffer := buf.StackNewSize(requestLen)
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
@ -197,7 +198,6 @@ func WriteClientRequest(stream io.Writer, request ClientRequest, payload []byte)
|
||||
binary.Write(buffer, binary.BigEndian, uint16(len(request.Host))),
|
||||
common.Error(buffer.WriteString(request.Host)),
|
||||
binary.Write(buffer, binary.BigEndian, request.Port),
|
||||
common.Error(buffer.Write(payload)),
|
||||
)
|
||||
return common.Error(stream.Write(buffer.Bytes()))
|
||||
}
|
||||
@ -237,13 +237,12 @@ func ReadServerResponse(stream io.Reader) (*ServerResponse, error) {
|
||||
return &serverResponse, nil
|
||||
}
|
||||
|
||||
func WriteServerResponse(stream io.Writer, response ServerResponse, payload []byte) error {
|
||||
func WriteServerResponse(stream io.Writer, response ServerResponse) error {
|
||||
var responseLen int
|
||||
responseLen += 1 // ok
|
||||
responseLen += 4 // udp session id
|
||||
responseLen += 2 // message len
|
||||
responseLen += len(response.Message)
|
||||
responseLen += len(payload)
|
||||
_buffer := buf.StackNewSize(responseLen)
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
@ -257,7 +256,6 @@ func WriteServerResponse(stream io.Writer, response ServerResponse, payload []by
|
||||
binary.Write(buffer, binary.BigEndian, response.UDPSessionID),
|
||||
binary.Write(buffer, binary.BigEndian, uint16(len(response.Message))),
|
||||
common.Error(buffer.WriteString(response.Message)),
|
||||
common.Error(buffer.Write(payload)),
|
||||
)
|
||||
return common.Error(stream.Write(buffer.Bytes()))
|
||||
}
|
||||
@ -341,9 +339,7 @@ func WriteUDPMessage(conn quic.Connection, message UDPMessage) error {
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
err := writeUDPMessage(conn, message, buffer)
|
||||
// TODO: wait for change upstream
|
||||
if /*errSize, ok := err.(quic.ErrMessageToLarge); ok*/ false {
|
||||
const errSize = 0
|
||||
if errSize, ok := err.(quic.ErrMessageToLarge); ok {
|
||||
// need to frag
|
||||
message.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1
|
||||
fragMsgs := FragUDPMessage(message, int(errSize))
|
||||
@ -373,3 +369,142 @@ func writeUDPMessage(conn quic.Connection, message UDPMessage, buffer *buf.Buffe
|
||||
)
|
||||
return conn.SendMessage(buffer.Bytes())
|
||||
}
|
||||
|
||||
var _ net.Conn = (*Conn)(nil)
|
||||
|
||||
type Conn struct {
|
||||
quic.Stream
|
||||
destination M.Socksaddr
|
||||
responseWritten bool
|
||||
}
|
||||
|
||||
func NewConn(stream quic.Stream, destination M.Socksaddr) *Conn {
|
||||
return &Conn{
|
||||
Stream: stream,
|
||||
destination: destination,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.destination.TCPAddr()
|
||||
}
|
||||
|
||||
func (c *Conn) ReaderReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Conn) WriterReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Conn) Upstream() any {
|
||||
return c.Stream
|
||||
}
|
||||
|
||||
type PacketConn struct {
|
||||
session quic.Connection
|
||||
stream quic.Stream
|
||||
sessionId uint32
|
||||
destination M.Socksaddr
|
||||
msgCh <-chan *UDPMessage
|
||||
closer io.Closer
|
||||
}
|
||||
|
||||
func NewPacketConn(session quic.Connection, stream quic.Stream, sessionId uint32, destination M.Socksaddr, msgCh <-chan *UDPMessage, closer io.Closer) *PacketConn {
|
||||
return &PacketConn{
|
||||
session: session,
|
||||
stream: stream,
|
||||
sessionId: sessionId,
|
||||
destination: destination,
|
||||
msgCh: msgCh,
|
||||
closer: closer,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *PacketConn) Hold() {
|
||||
// Hold the stream until it's closed
|
||||
buf := make([]byte, 1024)
|
||||
for {
|
||||
_, err := c.stream.Read(buf)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
_ = c.Close()
|
||||
}
|
||||
|
||||
func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
msg := <-c.msgCh
|
||||
if msg == nil {
|
||||
err = net.ErrClosed
|
||||
return
|
||||
}
|
||||
err = common.Error(buffer.Write(msg.Data))
|
||||
destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *PacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
msg := <-c.msgCh
|
||||
if msg == nil {
|
||||
err = net.ErrClosed
|
||||
return
|
||||
}
|
||||
buffer = buf.As(msg.Data)
|
||||
destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
return WriteUDPMessage(c.session, UDPMessage{
|
||||
SessionID: c.sessionId,
|
||||
Host: destination.Unwrap().AddrString(),
|
||||
Port: destination.Port,
|
||||
FragCount: 1,
|
||||
Data: buffer.Bytes(),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *PacketConn) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *PacketConn) RemoteAddr() net.Addr {
|
||||
return c.destination.UDPAddr()
|
||||
}
|
||||
|
||||
func (c *PacketConn) SetDeadline(t time.Time) error {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *PacketConn) SetReadDeadline(t time.Time) error {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *PacketConn) SetWriteDeadline(t time.Time) error {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
return 0, nil, os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *PacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
return 0, os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *PacketConn) Read(b []byte) (n int, err error) {
|
||||
return 0, os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *PacketConn) Write(b []byte) (n int, err error) {
|
||||
return 0, os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *PacketConn) Close() error {
|
||||
return common.Close(c.stream, c.closer)
|
||||
}
|
||||
|
@ -1,68 +0,0 @@
|
||||
package hysteria
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
var (
|
||||
_ net.Conn = (*ServerConn)(nil)
|
||||
_ N.HandshakeConn = (*ServerConn)(nil)
|
||||
)
|
||||
|
||||
type ServerConn struct {
|
||||
quic.Stream
|
||||
destination M.Socksaddr
|
||||
responseWritten bool
|
||||
}
|
||||
|
||||
func NewServerConn(stream quic.Stream, destination M.Socksaddr) *ServerConn {
|
||||
return &ServerConn{
|
||||
Stream: stream,
|
||||
destination: destination,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ServerConn) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ServerConn) RemoteAddr() net.Addr {
|
||||
return c.destination.TCPAddr()
|
||||
}
|
||||
|
||||
func (c *ServerConn) Write(b []byte) (n int, err error) {
|
||||
if !c.responseWritten {
|
||||
err = WriteServerResponse(c.Stream, ServerResponse{
|
||||
OK: true,
|
||||
}, b)
|
||||
c.responseWritten = true
|
||||
return len(b), nil
|
||||
}
|
||||
return c.Stream.Write(b)
|
||||
}
|
||||
|
||||
func (c *ServerConn) ReaderReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *ServerConn) WriterReplaceable() bool {
|
||||
return c.responseWritten
|
||||
}
|
||||
|
||||
func (c *ServerConn) HandshakeFailure(err error) error {
|
||||
if c.responseWritten {
|
||||
return nil
|
||||
}
|
||||
return WriteServerResponse(c.Stream, ServerResponse{
|
||||
Message: err.Error(),
|
||||
}, nil)
|
||||
}
|
||||
|
||||
func (c *ServerConn) Upstream() any {
|
||||
return c.Stream
|
||||
}
|
@ -5,25 +5,52 @@ import (
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/quic-go"
|
||||
"github.com/sagernet/sing/common"
|
||||
)
|
||||
|
||||
type WrapPacketConn struct {
|
||||
type PacketConnWrapper struct {
|
||||
net.PacketConn
|
||||
}
|
||||
|
||||
func (c *WrapPacketConn) SetReadBuffer(bytes int) error {
|
||||
func (c *PacketConnWrapper) SetReadBuffer(bytes int) error {
|
||||
return common.MustCast[*net.UDPConn](c.PacketConn).SetReadBuffer(bytes)
|
||||
}
|
||||
|
||||
func (c *WrapPacketConn) SetWriteBuffer(bytes int) error {
|
||||
func (c *PacketConnWrapper) SetWriteBuffer(bytes int) error {
|
||||
return common.MustCast[*net.UDPConn](c.PacketConn).SetWriteBuffer(bytes)
|
||||
}
|
||||
|
||||
func (c *WrapPacketConn) SyscallConn() (syscall.RawConn, error) {
|
||||
func (c *PacketConnWrapper) SyscallConn() (syscall.RawConn, error) {
|
||||
return common.MustCast[*net.UDPConn](c.PacketConn).SyscallConn()
|
||||
}
|
||||
|
||||
func (c *WrapPacketConn) File() (f *os.File, err error) {
|
||||
func (c *PacketConnWrapper) File() (f *os.File, err error) {
|
||||
return common.MustCast[*net.UDPConn](c.PacketConn).File()
|
||||
}
|
||||
|
||||
func (c *PacketConnWrapper) Upstream() any {
|
||||
return c.PacketConn
|
||||
}
|
||||
|
||||
type StreamWrapper struct {
|
||||
quic.Stream
|
||||
}
|
||||
|
||||
func (s *StreamWrapper) Upstream() any {
|
||||
return s.Stream
|
||||
}
|
||||
|
||||
func (s *StreamWrapper) ReaderReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *StreamWrapper) WriterReplaceable() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *StreamWrapper) Close() error {
|
||||
s.CancelRead(0)
|
||||
s.Stream.Close()
|
||||
return nil
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user