Add hysteria udp server

This commit is contained in:
世界 2022-08-18 23:02:36 +08:00
parent b992d942c4
commit 32cb511b7c
No known key found for this signature in database
GPG Key ID: CD109927C34A63C4
14 changed files with 621 additions and 317 deletions

View File

@ -7,19 +7,20 @@ import (
"context"
"net"
"net/netip"
"sync"
"github.com/sagernet/quic-go"
"github.com/sagernet/quic-go/congestion"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/hysteria"
dns "github.com/sagernet/sing-dns"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/lucas-clemente/quic-go"
)
var _ adapter.Inbound = (*Hysteria)(nil)
@ -37,6 +38,10 @@ type Hysteria struct {
sendBPS uint64
recvBPS uint64
listener quic.Listener
udpAccess sync.RWMutex
udpSessionId uint32
udpSessions map[uint32]chan *hysteria.UDPMessage
udpDefragger hysteria.Defragger
}
func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaInboundOptions) (*Hysteria, error) {
@ -105,6 +110,7 @@ func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextL
xplusKey: xplus,
sendBPS: up,
recvBPS: down,
udpSessions: make(map[uint32]chan *hysteria.UDPMessage),
}
if options.TLS == nil || !options.TLS.Enabled {
return nil, ErrTLSRequired
@ -138,6 +144,7 @@ func (h *Hysteria) Start() error {
}
if len(h.xplusKey) > 0 {
packetConn = hysteria.NewXPlusPacketConn(packetConn, h.xplusKey)
packetConn = &hysteria.PacketConnWrapper{PacketConn: packetConn}
}
err = h.tlsConfig.Start()
if err != nil {
@ -159,6 +166,7 @@ func (h *Hysteria) acceptLoop() {
if err != nil {
return
}
h.logger.InfoContext(ctx, "inbound connection from ", conn.RemoteAddr())
go func() {
hErr := h.accept(ctx, conn)
if hErr != nil {
@ -202,23 +210,51 @@ func (h *Hysteria) accept(ctx context.Context, conn quic.Connection) error {
if err != nil {
return err
}
// TODO: set congestion control
conn.SetCongestionControl(hysteria.NewBrutalSender(congestion.ByteCount(serverSendBPS)))
go h.udpRecvLoop(conn)
var stream quic.Stream
for {
var stream quic.Stream
stream, err = conn.AcceptStream(ctx)
if err != nil {
return err
}
hErr := h.acceptStream(ctx, conn, stream)
go func() {
hErr := h.acceptStream(ctx, conn /*&hysteria.StreamWrapper{Stream: stream}*/, stream)
if hErr != nil {
stream.Close()
NewError(h.logger, ctx, E.Cause(hErr, "process stream from ", conn.RemoteAddr()))
}
}()
}
}
func (h *Hysteria) udpRecvLoop(conn quic.Connection) {
for {
packet, err := conn.ReceiveMessage()
if err != nil {
return
}
message, err := hysteria.ParseUDPMessage(packet)
if err != nil {
h.logger.Error("parse udp message: ", err)
continue
}
dfMsg := h.udpDefragger.Feed(message)
if dfMsg == nil {
continue
}
h.udpAccess.RLock()
ch, ok := h.udpSessions[dfMsg.SessionID]
if ok {
select {
case ch <- dfMsg:
// OK
default:
// Silently drop the message when the channel is full
}
}
h.udpAccess.RUnlock()
}
}
func (h *Hysteria) acceptStream(ctx context.Context, conn quic.Connection, stream quic.Stream) error {
@ -226,29 +262,55 @@ func (h *Hysteria) acceptStream(ctx context.Context, conn quic.Connection, strea
if err != nil {
return err
}
if request.UDP {
err = hysteria.WriteServerResponse(stream, hysteria.ServerResponse{
Message: "unsupported",
}, nil)
OK: true,
})
if err != nil {
return err
}
stream.Close()
return nil
}
var metadata adapter.InboundContext
metadata.Inbound = h.tag
metadata.InboundType = C.TypeHysteria
metadata.SniffEnabled = h.listenOptions.SniffEnabled
metadata.SniffOverrideDestination = h.listenOptions.SniffOverrideDestination
metadata.DomainStrategy = dns.DomainStrategy(h.listenOptions.DomainStrategy)
metadata.Network = N.NetworkTCP
metadata.Source = M.SocksaddrFromNet(conn.RemoteAddr())
metadata.Destination = M.ParseSocksaddrHostPort(request.Host, request.Port)
return h.router.RouteConnection(ctx, hysteria.NewServerConn(stream, metadata.Destination), metadata)
if !request.UDP {
h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
metadata.Network = N.NetworkTCP
return h.router.RouteConnection(ctx, hysteria.NewConn(stream, metadata.Destination), metadata)
} else {
h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
var id uint32
h.udpAccess.Lock()
id = h.udpSessionId
nCh := make(chan *hysteria.UDPMessage, 1024)
h.udpSessions[id] = nCh
h.udpSessionId += 1
h.udpAccess.Unlock()
metadata.Network = N.NetworkUDP
packetConn := hysteria.NewPacketConn(conn, stream, id, metadata.Destination, nCh, common.Closer(func() error {
h.udpAccess.Lock()
if ch, ok := h.udpSessions[id]; ok {
close(ch)
delete(h.udpSessions, id)
}
h.udpAccess.Unlock()
return nil
}))
go packetConn.Hold()
return h.router.RoutePacketConnection(ctx, packetConn, metadata)
}
}
func (h *Hysteria) Close() error {
h.udpAccess.Lock()
for _, session := range h.udpSessions {
close(session)
}
h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage)
h.udpAccess.Unlock()
return common.Close(
h.listener,
common.PtrOrNil(h.tlsConfig),

View File

@ -195,6 +195,8 @@ func (n *Naive) newConnection(ctx context.Context, conn net.Conn, source, destin
metadata.Network = N.NetworkTCP
metadata.Source = source
metadata.Destination = destination
n.logger.InfoContext(ctx, "inbound connection from ", metadata.Source)
n.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
hErr := n.router.RouteConnection(ctx, conn, metadata)
if hErr != nil {
conn.Close()

View File

@ -10,6 +10,8 @@ import (
"os"
"sync"
"github.com/sagernet/quic-go"
"github.com/sagernet/quic-go/congestion"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
C "github.com/sagernet/sing-box/constant"
@ -21,8 +23,6 @@ import (
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/lucas-clemente/quic-go"
)
var _ adapter.Outbound = (*Hysteria)(nil)
@ -171,7 +171,7 @@ func (h *Hysteria) offer(ctx context.Context) (quic.Connection, error) {
}
h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage)
h.udpDefragger = hysteria.Defragger{}
go h.recvLoop(conn)
go h.udpRecvLoop(conn)
}
return conn, nil
}
@ -186,7 +186,7 @@ func (h *Hysteria) offerNew(ctx context.Context) (quic.Connection, error) {
if h.xplusKey != nil {
packetConn = hysteria.NewXPlusPacketConn(packetConn, h.xplusKey)
}
packetConn = &hysteria.WrapPacketConn{PacketConn: packetConn}
packetConn = &hysteria.PacketConnWrapper{PacketConn: packetConn}
quicConn, err := quic.Dial(packetConn, udpConn.RemoteAddr(), h.serverAddr.AddrString(), h.tlsConfig, h.quicConfig)
if err != nil {
packetConn.Close()
@ -203,20 +203,23 @@ func (h *Hysteria) offerNew(ctx context.Context) (quic.Connection, error) {
Auth: h.authKey,
})
if err != nil {
packetConn.Close()
return nil, err
}
serverHello, err := hysteria.ReadServerHello(controlStream)
if err != nil {
packetConn.Close()
return nil, err
}
if !serverHello.OK {
packetConn.Close()
return nil, E.New("remote error: ", serverHello.Message)
}
// TODO: set congestion control
quicConn.SetCongestionControl(hysteria.NewBrutalSender(congestion.ByteCount(serverHello.RecvBPS)))
return quicConn, nil
}
func (h *Hysteria) recvLoop(conn quic.Connection) {
func (h *Hysteria) udpRecvLoop(conn quic.Connection) {
for {
packet, err := conn.ReceiveMessage()
if err != nil {
@ -260,35 +263,58 @@ func (h *Hysteria) Close() error {
return nil
}
func (h *Hysteria) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
switch N.NetworkName(network) {
case N.NetworkTCP:
func (h *Hysteria) open(ctx context.Context) (quic.Connection, quic.Stream, error) {
conn, err := h.offer(ctx)
if err != nil {
return nil, err
return nil, nil, err
}
stream, err := conn.OpenStream()
if err != nil {
return nil, nil, err
}
return conn, &hysteria.StreamWrapper{Stream: stream}, nil
}
func (h *Hysteria) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
switch N.NetworkName(network) {
case N.NetworkTCP:
h.logger.InfoContext(ctx, "outbound connection to ", destination)
_, stream, err := h.open(ctx)
if err != nil {
return nil, err
}
return hysteria.NewClientConn(stream, destination), nil
err = hysteria.WriteClientRequest(stream, hysteria.ClientRequest{
Host: destination.AddrString(),
Port: destination.Port,
})
if err != nil {
stream.Close()
return nil, err
}
response, err := hysteria.ReadServerResponse(stream)
if err != nil {
stream.Close()
return nil, err
}
if !response.OK {
stream.Close()
return nil, E.New("remote error: ", response.Message)
}
return hysteria.NewConn(stream, destination), nil
case N.NetworkUDP:
conn, err := h.ListenPacket(ctx, destination)
if err != nil {
return nil, err
}
return conn.(*hysteria.ClientPacketConn), nil
return conn.(*hysteria.PacketConn), nil
default:
return nil, E.New("unsupported network: ", network)
}
}
func (h *Hysteria) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
conn, err := h.offer(ctx)
if err != nil {
return nil, err
}
stream, err := conn.OpenStream()
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
conn, stream, err := h.open(ctx)
if err != nil {
return nil, err
}
@ -296,7 +322,7 @@ func (h *Hysteria) ListenPacket(ctx context.Context, destination M.Socksaddr) (n
UDP: true,
Host: destination.AddrString(),
Port: destination.Port,
}, nil)
})
if err != nil {
stream.Close()
return nil, err
@ -313,12 +339,9 @@ func (h *Hysteria) ListenPacket(ctx context.Context, destination M.Socksaddr) (n
}
h.udpAccess.Lock()
nCh := make(chan *hysteria.UDPMessage, 1024)
// Store the current session map for CloseFunc below
// to ensures that we are adding and removing sessions on the same map,
// as reconnecting will reassign the map
h.udpSessions[response.UDPSessionID] = nCh
h.udpAccess.Unlock()
packetConn := hysteria.NewClientPacketConn(conn, stream, response.UDPSessionID, destination, nCh, common.Closer(func() error {
packetConn := hysteria.NewPacketConn(conn, stream, response.UDPSessionID, destination, nCh, common.Closer(func() error {
h.udpAccess.Lock()
if ch, ok := h.udpSessions[response.UDPSessionID]; ok {
close(ch)

View File

@ -56,10 +56,21 @@ func testTCP(t *testing.T, clientPort uint16, testPort uint16) {
dialTCP := func() (net.Conn, error) {
return dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddrHostPort("127.0.0.1", testPort))
}
require.NoError(t, testPingPongWithConn(t, testPort, dialTCP))
require.NoError(t, testLargeDataWithConn(t, testPort, dialTCP))
}
func testSuitHy(t *testing.T, clientPort uint16, testPort uint16) {
dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", clientPort), socks.Version5, "", "")
dialTCP := func() (net.Conn, error) {
return dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddrHostPort("127.0.0.1", testPort))
}
dialUDP := func() (net.PacketConn, error) {
return dialer.ListenPacket(context.Background(), M.ParseSocksaddrHostPort("127.0.0.1", testPort))
}
require.NoError(t, testPingPongWithConn(t, testPort, dialTCP))
require.NoError(t, testPingPongWithPacketConn(t, testPort, dialUDP))
}
func testSuitWg(t *testing.T, clientPort uint16, testPort uint16) {
dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", clientPort), socks.Version5, "", "")
dialTCP := func() (net.Conn, error) {

View File

@ -7,6 +7,8 @@ import (
"errors"
"io"
"net"
"net/http"
_ "net/http/pprof"
"net/netip"
"sync"
"testing"
@ -91,6 +93,12 @@ func init() {
io.Copy(io.Discard, imageStream)
}
go func() {
err = http.ListenAndServe("0.0.0.0:8965", nil)
if err != nil {
log.Debug(err)
}
}()
}
func newPingPongPair() (chan []byte, chan []byte, func(t *testing.T) error) {

View File

@ -0,0 +1,12 @@
{
"server": "127.0.0.1:10000",
"auth_str": "password",
"obfs": "fuck me till the daylight",
"up_mbps": 100,
"down_mbps": 100,
"socks5": {
"listen": "127.0.0.1:10001"
},
"server_name": "example.org",
"ca": "/etc/hysteria/ca.pem"
}

View File

@ -80,7 +80,50 @@ func TestHysteriaSelf(t *testing.T) {
},
},
})
testTCP(t, clientPort, testPort)
testSuitHy(t, clientPort, testPort)
}
func TestHysteriaInbound(t *testing.T) {
if !C.QUIC_AVAILABLE {
t.Skip("QUIC not included")
}
caPem, certPem, keyPem := createSelfSignedCertificate(t, "example.org")
startInstance(t, option.Options{
Log: &option.LogOptions{
Level: "trace",
},
Inbounds: []option.Inbound{
{
Type: C.TypeHysteria,
HysteriaOptions: option.HysteriaInboundOptions{
ListenOptions: option.ListenOptions{
Listen: option.ListenAddress(netip.IPv4Unspecified()),
ListenPort: serverPort,
},
UpMbps: 100,
DownMbps: 100,
AuthString: "password",
Obfs: "fuck me till the daylight",
TLS: &option.InboundTLSOptions{
Enabled: true,
ServerName: "example.org",
CertificatePath: certPem,
KeyPath: keyPem,
},
},
},
},
})
startDockerContainer(t, DockerOptions{
Image: ImageHysteria,
Ports: []uint16{serverPort, clientPort},
Cmd: []string{"-c", "/etc/hysteria/config.json", "client"},
Bind: map[string]string{
"hysteria-client.json": "/etc/hysteria/config.json",
caPem: "/etc/hysteria/ca.pem",
},
})
testSuit(t, clientPort, testPort)
}
func TestHysteriaOutbound(t *testing.T) {
@ -93,7 +136,7 @@ func TestHysteriaOutbound(t *testing.T) {
Ports: []uint16{serverPort, testPort},
Cmd: []string{"-c", "/etc/hysteria/config.json", "server"},
Bind: map[string]string{
"hysteria.json": "/etc/hysteria/config.json",
"hysteria-server.json": "/etc/hysteria/config.json",
certPem: "/etc/hysteria/cert.pem",
keyPem: "/etc/hysteria/key.pem",
},
@ -131,5 +174,5 @@ func TestHysteriaOutbound(t *testing.T) {
},
},
})
testSuit(t, clientPort, testPort)
testSuitHy(t, clientPort, testPort)
}

View File

@ -0,0 +1,149 @@
package hysteria
import (
"time"
"github.com/sagernet/quic-go/congestion"
)
const (
initMaxDatagramSize = 1252
pktInfoSlotCount = 4
minSampleCount = 50
minAckRate = 0.8
)
type BrutalSender struct {
rttStats congestion.RTTStatsProvider
bps congestion.ByteCount
maxDatagramSize congestion.ByteCount
pacer *pacer
pktInfoSlots [pktInfoSlotCount]pktInfo
ackRate float64
}
type pktInfo struct {
Timestamp int64
AckCount uint64
LossCount uint64
}
func NewBrutalSender(bps congestion.ByteCount) *BrutalSender {
bs := &BrutalSender{
bps: bps,
maxDatagramSize: initMaxDatagramSize,
ackRate: 1,
}
bs.pacer = newPacer(func() congestion.ByteCount {
return congestion.ByteCount(float64(bs.bps) / bs.ackRate)
})
return bs
}
func (b *BrutalSender) SetRTTStatsProvider(rttStats congestion.RTTStatsProvider) {
b.rttStats = rttStats
}
func (b *BrutalSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time {
return b.pacer.TimeUntilSend()
}
func (b *BrutalSender) HasPacingBudget() bool {
return b.pacer.Budget(time.Now()) >= b.maxDatagramSize
}
func (b *BrutalSender) CanSend(bytesInFlight congestion.ByteCount) bool {
return bytesInFlight < b.GetCongestionWindow()
}
func (b *BrutalSender) GetCongestionWindow() congestion.ByteCount {
rtt := maxDuration(b.rttStats.LatestRTT(), b.rttStats.SmoothedRTT())
if rtt <= 0 {
return 10240
}
return congestion.ByteCount(float64(b.bps) * rtt.Seconds() * 1.5 / b.ackRate)
}
func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount,
packetNumber congestion.PacketNumber, bytes congestion.ByteCount, isRetransmittable bool,
) {
b.pacer.SentPacket(sentTime, bytes)
}
func (b *BrutalSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount,
priorInFlight congestion.ByteCount, eventTime time.Time,
) {
currentTimestamp := eventTime.Unix()
slot := currentTimestamp % pktInfoSlotCount
if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
b.pktInfoSlots[slot].AckCount++
} else {
// uninitialized slot or too old, reset
b.pktInfoSlots[slot].Timestamp = currentTimestamp
b.pktInfoSlots[slot].AckCount = 1
b.pktInfoSlots[slot].LossCount = 0
}
b.updateAckRate(currentTimestamp)
}
func (b *BrutalSender) OnPacketLost(number congestion.PacketNumber, lostBytes congestion.ByteCount,
priorInFlight congestion.ByteCount,
) {
currentTimestamp := time.Now().Unix()
slot := currentTimestamp % pktInfoSlotCount
if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
b.pktInfoSlots[slot].LossCount++
} else {
// uninitialized slot or too old, reset
b.pktInfoSlots[slot].Timestamp = currentTimestamp
b.pktInfoSlots[slot].AckCount = 0
b.pktInfoSlots[slot].LossCount = 1
}
b.updateAckRate(currentTimestamp)
}
func (b *BrutalSender) SetMaxDatagramSize(size congestion.ByteCount) {
b.maxDatagramSize = size
b.pacer.SetMaxDatagramSize(size)
}
func (b *BrutalSender) updateAckRate(currentTimestamp int64) {
minTimestamp := currentTimestamp - pktInfoSlotCount
var ackCount, lossCount uint64
for _, info := range b.pktInfoSlots {
if info.Timestamp < minTimestamp {
continue
}
ackCount += info.AckCount
lossCount += info.LossCount
}
if ackCount+lossCount < minSampleCount {
b.ackRate = 1
}
rate := float64(ackCount) / float64(ackCount+lossCount)
if rate < minAckRate {
b.ackRate = minAckRate
}
b.ackRate = rate
}
func (b *BrutalSender) InSlowStart() bool {
return false
}
func (b *BrutalSender) InRecovery() bool {
return false
}
func (b *BrutalSender) MaybeExitSlowStart() {}
func (b *BrutalSender) OnRetransmissionTimeout(packetsRetransmitted bool) {}
func maxDuration(a, b time.Duration) time.Duration {
if a > b {
return a
}
return b
}

View File

@ -1,186 +0,0 @@
package hysteria
import (
"io"
"net"
"os"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/lucas-clemente/quic-go"
)
var _ net.Conn = (*ClientConn)(nil)
type ClientConn struct {
quic.Stream
destination M.Socksaddr
requestWritten bool
responseRead bool
}
func NewClientConn(stream quic.Stream, destination M.Socksaddr) *ClientConn {
return &ClientConn{
Stream: stream,
destination: destination,
}
}
func (c *ClientConn) Read(b []byte) (n int, err error) {
if !c.responseRead {
var response *ServerResponse
response, err = ReadServerResponse(c.Stream)
if err != nil {
return
}
if !response.OK {
return 0, E.New("remote error: " + response.Message)
}
c.responseRead = true
}
return c.Stream.Read(b)
}
func (c *ClientConn) Write(b []byte) (n int, err error) {
if !c.requestWritten {
err = WriteClientRequest(c.Stream, ClientRequest{
UDP: false,
Host: c.destination.Unwrap().AddrString(),
Port: c.destination.Port,
}, b)
if err != nil {
return
}
c.requestWritten = true
return len(b), nil
}
return c.Stream.Write(b)
}
func (c *ClientConn) LocalAddr() net.Addr {
return nil
}
func (c *ClientConn) RemoteAddr() net.Addr {
return c.destination.TCPAddr()
}
func (c *ClientConn) Upstream() any {
return c.Stream
}
func (c *ClientConn) ReaderReplaceable() bool {
return c.responseRead
}
func (c *ClientConn) WriterReplaceable() bool {
return c.requestWritten
}
type ClientPacketConn struct {
session quic.Connection
stream quic.Stream
sessionId uint32
destination M.Socksaddr
msgCh <-chan *UDPMessage
closer io.Closer
}
func NewClientPacketConn(session quic.Connection, stream quic.Stream, sessionId uint32, destination M.Socksaddr, msgCh <-chan *UDPMessage, closer io.Closer) *ClientPacketConn {
return &ClientPacketConn{
session: session,
stream: stream,
sessionId: sessionId,
destination: destination,
msgCh: msgCh,
closer: closer,
}
}
func (c *ClientPacketConn) Hold() {
// Hold the stream until it's closed
buf := make([]byte, 1024)
for {
_, err := c.stream.Read(buf)
if err != nil {
break
}
}
_ = c.Close()
}
func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
msg := <-c.msgCh
if msg == nil {
err = net.ErrClosed
return
}
err = common.Error(buffer.Write(msg.Data))
destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port)
return
}
func (c *ClientPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
msg := <-c.msgCh
if msg == nil {
err = net.ErrClosed
return
}
buffer = buf.As(msg.Data)
destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port)
return
}
func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
return WriteUDPMessage(c.session, UDPMessage{
SessionID: c.sessionId,
Host: destination.Unwrap().AddrString(),
Port: destination.Port,
FragCount: 1,
Data: buffer.Bytes(),
})
}
func (c *ClientPacketConn) LocalAddr() net.Addr {
return nil
}
func (c *ClientPacketConn) RemoteAddr() net.Addr {
return c.destination.UDPAddr()
}
func (c *ClientPacketConn) SetDeadline(t time.Time) error {
return os.ErrInvalid
}
func (c *ClientPacketConn) SetReadDeadline(t time.Time) error {
return os.ErrInvalid
}
func (c *ClientPacketConn) SetWriteDeadline(t time.Time) error {
return os.ErrInvalid
}
func (c *ClientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
return 0, nil, os.ErrInvalid
}
func (c *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return 0, os.ErrInvalid
}
func (c *ClientPacketConn) Read(b []byte) (n int, err error) {
return 0, os.ErrInvalid
}
func (c *ClientPacketConn) Write(b []byte) (n int, err error) {
return 0, os.ErrInvalid
}
func (c *ClientPacketConn) Close() error {
return common.Close(c.stream, c.closer)
}

View File

@ -0,0 +1,86 @@
package hysteria
import (
"math"
"time"
"github.com/sagernet/quic-go/congestion"
)
const (
maxBurstPackets = 10
minPacingDelay = time.Millisecond
)
// The pacer implements a token bucket pacing algorithm.
type pacer struct {
budgetAtLastSent congestion.ByteCount
maxDatagramSize congestion.ByteCount
lastSentTime time.Time
getBandwidth func() congestion.ByteCount // in bytes/s
}
func newPacer(getBandwidth func() congestion.ByteCount) *pacer {
p := &pacer{
budgetAtLastSent: maxBurstPackets * initMaxDatagramSize,
maxDatagramSize: initMaxDatagramSize,
getBandwidth: getBandwidth,
}
return p
}
func (p *pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) {
budget := p.Budget(sendTime)
if size > budget {
p.budgetAtLastSent = 0
} else {
p.budgetAtLastSent = budget - size
}
p.lastSentTime = sendTime
}
func (p *pacer) Budget(now time.Time) congestion.ByteCount {
if p.lastSentTime.IsZero() {
return p.maxBurstSize()
}
budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
return minByteCount(p.maxBurstSize(), budget)
}
func (p *pacer) maxBurstSize() congestion.ByteCount {
return maxByteCount(
congestion.ByteCount((minPacingDelay+time.Millisecond).Nanoseconds())*p.getBandwidth()/1e9,
maxBurstPackets*p.maxDatagramSize,
)
}
// TimeUntilSend returns when the next packet should be sent.
// It returns the zero value of time.Time if a packet can be sent immediately.
func (p *pacer) TimeUntilSend() time.Time {
if p.budgetAtLastSent >= p.maxDatagramSize {
return time.Time{}
}
return p.lastSentTime.Add(maxDuration(
minPacingDelay,
time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/
float64(p.getBandwidth())))*time.Nanosecond,
))
}
func (p *pacer) SetMaxDatagramSize(s congestion.ByteCount) {
p.maxDatagramSize = s
}
func maxByteCount(a, b congestion.ByteCount) congestion.ByteCount {
if a < b {
return b
}
return a
}
func minByteCount(a, b congestion.ByteCount) congestion.ByteCount {
if a < b {
return a
}
return b
}

View File

@ -5,13 +5,15 @@ import (
"encoding/binary"
"io"
"math/rand"
"net"
"os"
"time"
"github.com/sagernet/quic-go"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/lucas-clemente/quic-go"
M "github.com/sagernet/sing/common/metadata"
)
const (
@ -177,13 +179,12 @@ func ReadClientRequest(stream io.Reader) (*ClientRequest, error) {
return &clientRequest, nil
}
func WriteClientRequest(stream io.Writer, request ClientRequest, payload []byte) error {
func WriteClientRequest(stream io.Writer, request ClientRequest) error {
var requestLen int
requestLen += 1 // udp
requestLen += 2 // host len
requestLen += len(request.Host)
requestLen += 2 // port
requestLen += len(payload)
_buffer := buf.StackNewSize(requestLen)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
@ -197,7 +198,6 @@ func WriteClientRequest(stream io.Writer, request ClientRequest, payload []byte)
binary.Write(buffer, binary.BigEndian, uint16(len(request.Host))),
common.Error(buffer.WriteString(request.Host)),
binary.Write(buffer, binary.BigEndian, request.Port),
common.Error(buffer.Write(payload)),
)
return common.Error(stream.Write(buffer.Bytes()))
}
@ -237,13 +237,12 @@ func ReadServerResponse(stream io.Reader) (*ServerResponse, error) {
return &serverResponse, nil
}
func WriteServerResponse(stream io.Writer, response ServerResponse, payload []byte) error {
func WriteServerResponse(stream io.Writer, response ServerResponse) error {
var responseLen int
responseLen += 1 // ok
responseLen += 4 // udp session id
responseLen += 2 // message len
responseLen += len(response.Message)
responseLen += len(payload)
_buffer := buf.StackNewSize(responseLen)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
@ -257,7 +256,6 @@ func WriteServerResponse(stream io.Writer, response ServerResponse, payload []by
binary.Write(buffer, binary.BigEndian, response.UDPSessionID),
binary.Write(buffer, binary.BigEndian, uint16(len(response.Message))),
common.Error(buffer.WriteString(response.Message)),
common.Error(buffer.Write(payload)),
)
return common.Error(stream.Write(buffer.Bytes()))
}
@ -341,9 +339,7 @@ func WriteUDPMessage(conn quic.Connection, message UDPMessage) error {
buffer := common.Dup(_buffer)
defer buffer.Release()
err := writeUDPMessage(conn, message, buffer)
// TODO: wait for change upstream
if /*errSize, ok := err.(quic.ErrMessageToLarge); ok*/ false {
const errSize = 0
if errSize, ok := err.(quic.ErrMessageToLarge); ok {
// need to frag
message.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1
fragMsgs := FragUDPMessage(message, int(errSize))
@ -373,3 +369,142 @@ func writeUDPMessage(conn quic.Connection, message UDPMessage, buffer *buf.Buffe
)
return conn.SendMessage(buffer.Bytes())
}
var _ net.Conn = (*Conn)(nil)
type Conn struct {
quic.Stream
destination M.Socksaddr
responseWritten bool
}
func NewConn(stream quic.Stream, destination M.Socksaddr) *Conn {
return &Conn{
Stream: stream,
destination: destination,
}
}
func (c *Conn) LocalAddr() net.Addr {
return nil
}
func (c *Conn) RemoteAddr() net.Addr {
return c.destination.TCPAddr()
}
func (c *Conn) ReaderReplaceable() bool {
return true
}
func (c *Conn) WriterReplaceable() bool {
return true
}
func (c *Conn) Upstream() any {
return c.Stream
}
type PacketConn struct {
session quic.Connection
stream quic.Stream
sessionId uint32
destination M.Socksaddr
msgCh <-chan *UDPMessage
closer io.Closer
}
func NewPacketConn(session quic.Connection, stream quic.Stream, sessionId uint32, destination M.Socksaddr, msgCh <-chan *UDPMessage, closer io.Closer) *PacketConn {
return &PacketConn{
session: session,
stream: stream,
sessionId: sessionId,
destination: destination,
msgCh: msgCh,
closer: closer,
}
}
func (c *PacketConn) Hold() {
// Hold the stream until it's closed
buf := make([]byte, 1024)
for {
_, err := c.stream.Read(buf)
if err != nil {
break
}
}
_ = c.Close()
}
func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
msg := <-c.msgCh
if msg == nil {
err = net.ErrClosed
return
}
err = common.Error(buffer.Write(msg.Data))
destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port)
return
}
func (c *PacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
msg := <-c.msgCh
if msg == nil {
err = net.ErrClosed
return
}
buffer = buf.As(msg.Data)
destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port)
return
}
func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
return WriteUDPMessage(c.session, UDPMessage{
SessionID: c.sessionId,
Host: destination.Unwrap().AddrString(),
Port: destination.Port,
FragCount: 1,
Data: buffer.Bytes(),
})
}
func (c *PacketConn) LocalAddr() net.Addr {
return nil
}
func (c *PacketConn) RemoteAddr() net.Addr {
return c.destination.UDPAddr()
}
func (c *PacketConn) SetDeadline(t time.Time) error {
return os.ErrInvalid
}
func (c *PacketConn) SetReadDeadline(t time.Time) error {
return os.ErrInvalid
}
func (c *PacketConn) SetWriteDeadline(t time.Time) error {
return os.ErrInvalid
}
func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
return 0, nil, os.ErrInvalid
}
func (c *PacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return 0, os.ErrInvalid
}
func (c *PacketConn) Read(b []byte) (n int, err error) {
return 0, os.ErrInvalid
}
func (c *PacketConn) Write(b []byte) (n int, err error) {
return 0, os.ErrInvalid
}
func (c *PacketConn) Close() error {
return common.Close(c.stream, c.closer)
}

View File

@ -1,68 +0,0 @@
package hysteria
import (
"net"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/lucas-clemente/quic-go"
)
var (
_ net.Conn = (*ServerConn)(nil)
_ N.HandshakeConn = (*ServerConn)(nil)
)
type ServerConn struct {
quic.Stream
destination M.Socksaddr
responseWritten bool
}
func NewServerConn(stream quic.Stream, destination M.Socksaddr) *ServerConn {
return &ServerConn{
Stream: stream,
destination: destination,
}
}
func (c *ServerConn) LocalAddr() net.Addr {
return nil
}
func (c *ServerConn) RemoteAddr() net.Addr {
return c.destination.TCPAddr()
}
func (c *ServerConn) Write(b []byte) (n int, err error) {
if !c.responseWritten {
err = WriteServerResponse(c.Stream, ServerResponse{
OK: true,
}, b)
c.responseWritten = true
return len(b), nil
}
return c.Stream.Write(b)
}
func (c *ServerConn) ReaderReplaceable() bool {
return true
}
func (c *ServerConn) WriterReplaceable() bool {
return c.responseWritten
}
func (c *ServerConn) HandshakeFailure(err error) error {
if c.responseWritten {
return nil
}
return WriteServerResponse(c.Stream, ServerResponse{
Message: err.Error(),
}, nil)
}
func (c *ServerConn) Upstream() any {
return c.Stream
}

View File

@ -5,25 +5,52 @@ import (
"os"
"syscall"
"github.com/sagernet/quic-go"
"github.com/sagernet/sing/common"
)
type WrapPacketConn struct {
type PacketConnWrapper struct {
net.PacketConn
}
func (c *WrapPacketConn) SetReadBuffer(bytes int) error {
func (c *PacketConnWrapper) SetReadBuffer(bytes int) error {
return common.MustCast[*net.UDPConn](c.PacketConn).SetReadBuffer(bytes)
}
func (c *WrapPacketConn) SetWriteBuffer(bytes int) error {
func (c *PacketConnWrapper) SetWriteBuffer(bytes int) error {
return common.MustCast[*net.UDPConn](c.PacketConn).SetWriteBuffer(bytes)
}
func (c *WrapPacketConn) SyscallConn() (syscall.RawConn, error) {
func (c *PacketConnWrapper) SyscallConn() (syscall.RawConn, error) {
return common.MustCast[*net.UDPConn](c.PacketConn).SyscallConn()
}
func (c *WrapPacketConn) File() (f *os.File, err error) {
func (c *PacketConnWrapper) File() (f *os.File, err error) {
return common.MustCast[*net.UDPConn](c.PacketConn).File()
}
func (c *PacketConnWrapper) Upstream() any {
return c.PacketConn
}
type StreamWrapper struct {
quic.Stream
}
func (s *StreamWrapper) Upstream() any {
return s.Stream
}
func (s *StreamWrapper) ReaderReplaceable() bool {
return true
}
func (s *StreamWrapper) WriterReplaceable() bool {
return true
}
func (s *StreamWrapper) Close() error {
s.CancelRead(0)
s.Stream.Close()
return nil
}