mirror of
https://github.com/SagerNet/sing-box.git
synced 2025-06-08 19:54:12 +08:00
Add hysteria udp server
This commit is contained in:
parent
b992d942c4
commit
32cb511b7c
@ -7,19 +7,20 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/sagernet/quic-go"
|
||||||
|
"github.com/sagernet/quic-go/congestion"
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
"github.com/sagernet/sing-box/log"
|
"github.com/sagernet/sing-box/log"
|
||||||
"github.com/sagernet/sing-box/option"
|
"github.com/sagernet/sing-box/option"
|
||||||
"github.com/sagernet/sing-box/transport/hysteria"
|
"github.com/sagernet/sing-box/transport/hysteria"
|
||||||
dns "github.com/sagernet/sing-dns"
|
"github.com/sagernet/sing-dns"
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ adapter.Inbound = (*Hysteria)(nil)
|
var _ adapter.Inbound = (*Hysteria)(nil)
|
||||||
@ -37,6 +38,10 @@ type Hysteria struct {
|
|||||||
sendBPS uint64
|
sendBPS uint64
|
||||||
recvBPS uint64
|
recvBPS uint64
|
||||||
listener quic.Listener
|
listener quic.Listener
|
||||||
|
udpAccess sync.RWMutex
|
||||||
|
udpSessionId uint32
|
||||||
|
udpSessions map[uint32]chan *hysteria.UDPMessage
|
||||||
|
udpDefragger hysteria.Defragger
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaInboundOptions) (*Hysteria, error) {
|
func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaInboundOptions) (*Hysteria, error) {
|
||||||
@ -105,6 +110,7 @@ func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextL
|
|||||||
xplusKey: xplus,
|
xplusKey: xplus,
|
||||||
sendBPS: up,
|
sendBPS: up,
|
||||||
recvBPS: down,
|
recvBPS: down,
|
||||||
|
udpSessions: make(map[uint32]chan *hysteria.UDPMessage),
|
||||||
}
|
}
|
||||||
if options.TLS == nil || !options.TLS.Enabled {
|
if options.TLS == nil || !options.TLS.Enabled {
|
||||||
return nil, ErrTLSRequired
|
return nil, ErrTLSRequired
|
||||||
@ -138,6 +144,7 @@ func (h *Hysteria) Start() error {
|
|||||||
}
|
}
|
||||||
if len(h.xplusKey) > 0 {
|
if len(h.xplusKey) > 0 {
|
||||||
packetConn = hysteria.NewXPlusPacketConn(packetConn, h.xplusKey)
|
packetConn = hysteria.NewXPlusPacketConn(packetConn, h.xplusKey)
|
||||||
|
packetConn = &hysteria.PacketConnWrapper{PacketConn: packetConn}
|
||||||
}
|
}
|
||||||
err = h.tlsConfig.Start()
|
err = h.tlsConfig.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -159,6 +166,7 @@ func (h *Hysteria) acceptLoop() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
h.logger.InfoContext(ctx, "inbound connection from ", conn.RemoteAddr())
|
||||||
go func() {
|
go func() {
|
||||||
hErr := h.accept(ctx, conn)
|
hErr := h.accept(ctx, conn)
|
||||||
if hErr != nil {
|
if hErr != nil {
|
||||||
@ -202,23 +210,51 @@ func (h *Hysteria) accept(ctx context.Context, conn quic.Connection) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// TODO: set congestion control
|
conn.SetCongestionControl(hysteria.NewBrutalSender(congestion.ByteCount(serverSendBPS)))
|
||||||
go h.udpRecvLoop(conn)
|
go h.udpRecvLoop(conn)
|
||||||
var stream quic.Stream
|
|
||||||
for {
|
for {
|
||||||
|
var stream quic.Stream
|
||||||
stream, err = conn.AcceptStream(ctx)
|
stream, err = conn.AcceptStream(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
hErr := h.acceptStream(ctx, conn, stream)
|
go func() {
|
||||||
|
hErr := h.acceptStream(ctx, conn /*&hysteria.StreamWrapper{Stream: stream}*/, stream)
|
||||||
if hErr != nil {
|
if hErr != nil {
|
||||||
stream.Close()
|
stream.Close()
|
||||||
NewError(h.logger, ctx, E.Cause(hErr, "process stream from ", conn.RemoteAddr()))
|
NewError(h.logger, ctx, E.Cause(hErr, "process stream from ", conn.RemoteAddr()))
|
||||||
}
|
}
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Hysteria) udpRecvLoop(conn quic.Connection) {
|
func (h *Hysteria) udpRecvLoop(conn quic.Connection) {
|
||||||
|
for {
|
||||||
|
packet, err := conn.ReceiveMessage()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
message, err := hysteria.ParseUDPMessage(packet)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("parse udp message: ", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
dfMsg := h.udpDefragger.Feed(message)
|
||||||
|
if dfMsg == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
h.udpAccess.RLock()
|
||||||
|
ch, ok := h.udpSessions[dfMsg.SessionID]
|
||||||
|
if ok {
|
||||||
|
select {
|
||||||
|
case ch <- dfMsg:
|
||||||
|
// OK
|
||||||
|
default:
|
||||||
|
// Silently drop the message when the channel is full
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.udpAccess.RUnlock()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Hysteria) acceptStream(ctx context.Context, conn quic.Connection, stream quic.Stream) error {
|
func (h *Hysteria) acceptStream(ctx context.Context, conn quic.Connection, stream quic.Stream) error {
|
||||||
@ -226,29 +262,55 @@ func (h *Hysteria) acceptStream(ctx context.Context, conn quic.Connection, strea
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if request.UDP {
|
|
||||||
err = hysteria.WriteServerResponse(stream, hysteria.ServerResponse{
|
err = hysteria.WriteServerResponse(stream, hysteria.ServerResponse{
|
||||||
Message: "unsupported",
|
OK: true,
|
||||||
}, nil)
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
stream.Close()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var metadata adapter.InboundContext
|
var metadata adapter.InboundContext
|
||||||
metadata.Inbound = h.tag
|
metadata.Inbound = h.tag
|
||||||
metadata.InboundType = C.TypeHysteria
|
metadata.InboundType = C.TypeHysteria
|
||||||
metadata.SniffEnabled = h.listenOptions.SniffEnabled
|
metadata.SniffEnabled = h.listenOptions.SniffEnabled
|
||||||
metadata.SniffOverrideDestination = h.listenOptions.SniffOverrideDestination
|
metadata.SniffOverrideDestination = h.listenOptions.SniffOverrideDestination
|
||||||
metadata.DomainStrategy = dns.DomainStrategy(h.listenOptions.DomainStrategy)
|
metadata.DomainStrategy = dns.DomainStrategy(h.listenOptions.DomainStrategy)
|
||||||
metadata.Network = N.NetworkTCP
|
|
||||||
metadata.Source = M.SocksaddrFromNet(conn.RemoteAddr())
|
metadata.Source = M.SocksaddrFromNet(conn.RemoteAddr())
|
||||||
metadata.Destination = M.ParseSocksaddrHostPort(request.Host, request.Port)
|
metadata.Destination = M.ParseSocksaddrHostPort(request.Host, request.Port)
|
||||||
return h.router.RouteConnection(ctx, hysteria.NewServerConn(stream, metadata.Destination), metadata)
|
if !request.UDP {
|
||||||
|
h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
|
||||||
|
metadata.Network = N.NetworkTCP
|
||||||
|
return h.router.RouteConnection(ctx, hysteria.NewConn(stream, metadata.Destination), metadata)
|
||||||
|
} else {
|
||||||
|
h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
|
||||||
|
var id uint32
|
||||||
|
h.udpAccess.Lock()
|
||||||
|
id = h.udpSessionId
|
||||||
|
nCh := make(chan *hysteria.UDPMessage, 1024)
|
||||||
|
h.udpSessions[id] = nCh
|
||||||
|
h.udpSessionId += 1
|
||||||
|
h.udpAccess.Unlock()
|
||||||
|
metadata.Network = N.NetworkUDP
|
||||||
|
packetConn := hysteria.NewPacketConn(conn, stream, id, metadata.Destination, nCh, common.Closer(func() error {
|
||||||
|
h.udpAccess.Lock()
|
||||||
|
if ch, ok := h.udpSessions[id]; ok {
|
||||||
|
close(ch)
|
||||||
|
delete(h.udpSessions, id)
|
||||||
|
}
|
||||||
|
h.udpAccess.Unlock()
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
go packetConn.Hold()
|
||||||
|
return h.router.RoutePacketConnection(ctx, packetConn, metadata)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Hysteria) Close() error {
|
func (h *Hysteria) Close() error {
|
||||||
|
h.udpAccess.Lock()
|
||||||
|
for _, session := range h.udpSessions {
|
||||||
|
close(session)
|
||||||
|
}
|
||||||
|
h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage)
|
||||||
|
h.udpAccess.Unlock()
|
||||||
return common.Close(
|
return common.Close(
|
||||||
h.listener,
|
h.listener,
|
||||||
common.PtrOrNil(h.tlsConfig),
|
common.PtrOrNil(h.tlsConfig),
|
||||||
|
@ -195,6 +195,8 @@ func (n *Naive) newConnection(ctx context.Context, conn net.Conn, source, destin
|
|||||||
metadata.Network = N.NetworkTCP
|
metadata.Network = N.NetworkTCP
|
||||||
metadata.Source = source
|
metadata.Source = source
|
||||||
metadata.Destination = destination
|
metadata.Destination = destination
|
||||||
|
n.logger.InfoContext(ctx, "inbound connection from ", metadata.Source)
|
||||||
|
n.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
|
||||||
hErr := n.router.RouteConnection(ctx, conn, metadata)
|
hErr := n.router.RouteConnection(ctx, conn, metadata)
|
||||||
if hErr != nil {
|
if hErr != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
|
@ -10,6 +10,8 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/sagernet/quic-go"
|
||||||
|
"github.com/sagernet/quic-go/congestion"
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
"github.com/sagernet/sing-box/common/dialer"
|
"github.com/sagernet/sing-box/common/dialer"
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
@ -21,8 +23,6 @@ import (
|
|||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ adapter.Outbound = (*Hysteria)(nil)
|
var _ adapter.Outbound = (*Hysteria)(nil)
|
||||||
@ -171,7 +171,7 @@ func (h *Hysteria) offer(ctx context.Context) (quic.Connection, error) {
|
|||||||
}
|
}
|
||||||
h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage)
|
h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage)
|
||||||
h.udpDefragger = hysteria.Defragger{}
|
h.udpDefragger = hysteria.Defragger{}
|
||||||
go h.recvLoop(conn)
|
go h.udpRecvLoop(conn)
|
||||||
}
|
}
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
@ -186,7 +186,7 @@ func (h *Hysteria) offerNew(ctx context.Context) (quic.Connection, error) {
|
|||||||
if h.xplusKey != nil {
|
if h.xplusKey != nil {
|
||||||
packetConn = hysteria.NewXPlusPacketConn(packetConn, h.xplusKey)
|
packetConn = hysteria.NewXPlusPacketConn(packetConn, h.xplusKey)
|
||||||
}
|
}
|
||||||
packetConn = &hysteria.WrapPacketConn{PacketConn: packetConn}
|
packetConn = &hysteria.PacketConnWrapper{PacketConn: packetConn}
|
||||||
quicConn, err := quic.Dial(packetConn, udpConn.RemoteAddr(), h.serverAddr.AddrString(), h.tlsConfig, h.quicConfig)
|
quicConn, err := quic.Dial(packetConn, udpConn.RemoteAddr(), h.serverAddr.AddrString(), h.tlsConfig, h.quicConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
packetConn.Close()
|
packetConn.Close()
|
||||||
@ -203,20 +203,23 @@ func (h *Hysteria) offerNew(ctx context.Context) (quic.Connection, error) {
|
|||||||
Auth: h.authKey,
|
Auth: h.authKey,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
packetConn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
serverHello, err := hysteria.ReadServerHello(controlStream)
|
serverHello, err := hysteria.ReadServerHello(controlStream)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
packetConn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if !serverHello.OK {
|
if !serverHello.OK {
|
||||||
|
packetConn.Close()
|
||||||
return nil, E.New("remote error: ", serverHello.Message)
|
return nil, E.New("remote error: ", serverHello.Message)
|
||||||
}
|
}
|
||||||
// TODO: set congestion control
|
quicConn.SetCongestionControl(hysteria.NewBrutalSender(congestion.ByteCount(serverHello.RecvBPS)))
|
||||||
return quicConn, nil
|
return quicConn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Hysteria) recvLoop(conn quic.Connection) {
|
func (h *Hysteria) udpRecvLoop(conn quic.Connection) {
|
||||||
for {
|
for {
|
||||||
packet, err := conn.ReceiveMessage()
|
packet, err := conn.ReceiveMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -260,35 +263,58 @@ func (h *Hysteria) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Hysteria) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
func (h *Hysteria) open(ctx context.Context) (quic.Connection, quic.Stream, error) {
|
||||||
switch N.NetworkName(network) {
|
|
||||||
case N.NetworkTCP:
|
|
||||||
conn, err := h.offer(ctx)
|
conn, err := h.offer(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
stream, err := conn.OpenStream()
|
stream, err := conn.OpenStream()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return conn, &hysteria.StreamWrapper{Stream: stream}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hysteria) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||||
|
switch N.NetworkName(network) {
|
||||||
|
case N.NetworkTCP:
|
||||||
|
h.logger.InfoContext(ctx, "outbound connection to ", destination)
|
||||||
|
_, stream, err := h.open(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return hysteria.NewClientConn(stream, destination), nil
|
err = hysteria.WriteClientRequest(stream, hysteria.ClientRequest{
|
||||||
|
Host: destination.AddrString(),
|
||||||
|
Port: destination.Port,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
stream.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
response, err := hysteria.ReadServerResponse(stream)
|
||||||
|
if err != nil {
|
||||||
|
stream.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !response.OK {
|
||||||
|
stream.Close()
|
||||||
|
return nil, E.New("remote error: ", response.Message)
|
||||||
|
}
|
||||||
|
return hysteria.NewConn(stream, destination), nil
|
||||||
case N.NetworkUDP:
|
case N.NetworkUDP:
|
||||||
conn, err := h.ListenPacket(ctx, destination)
|
conn, err := h.ListenPacket(ctx, destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return conn.(*hysteria.ClientPacketConn), nil
|
return conn.(*hysteria.PacketConn), nil
|
||||||
default:
|
default:
|
||||||
return nil, E.New("unsupported network: ", network)
|
return nil, E.New("unsupported network: ", network)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Hysteria) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
func (h *Hysteria) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||||
conn, err := h.offer(ctx)
|
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
||||||
if err != nil {
|
conn, stream, err := h.open(ctx)
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
stream, err := conn.OpenStream()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -296,7 +322,7 @@ func (h *Hysteria) ListenPacket(ctx context.Context, destination M.Socksaddr) (n
|
|||||||
UDP: true,
|
UDP: true,
|
||||||
Host: destination.AddrString(),
|
Host: destination.AddrString(),
|
||||||
Port: destination.Port,
|
Port: destination.Port,
|
||||||
}, nil)
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
stream.Close()
|
stream.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -313,12 +339,9 @@ func (h *Hysteria) ListenPacket(ctx context.Context, destination M.Socksaddr) (n
|
|||||||
}
|
}
|
||||||
h.udpAccess.Lock()
|
h.udpAccess.Lock()
|
||||||
nCh := make(chan *hysteria.UDPMessage, 1024)
|
nCh := make(chan *hysteria.UDPMessage, 1024)
|
||||||
// Store the current session map for CloseFunc below
|
|
||||||
// to ensures that we are adding and removing sessions on the same map,
|
|
||||||
// as reconnecting will reassign the map
|
|
||||||
h.udpSessions[response.UDPSessionID] = nCh
|
h.udpSessions[response.UDPSessionID] = nCh
|
||||||
h.udpAccess.Unlock()
|
h.udpAccess.Unlock()
|
||||||
packetConn := hysteria.NewClientPacketConn(conn, stream, response.UDPSessionID, destination, nCh, common.Closer(func() error {
|
packetConn := hysteria.NewPacketConn(conn, stream, response.UDPSessionID, destination, nCh, common.Closer(func() error {
|
||||||
h.udpAccess.Lock()
|
h.udpAccess.Lock()
|
||||||
if ch, ok := h.udpSessions[response.UDPSessionID]; ok {
|
if ch, ok := h.udpSessions[response.UDPSessionID]; ok {
|
||||||
close(ch)
|
close(ch)
|
||||||
|
@ -56,10 +56,21 @@ func testTCP(t *testing.T, clientPort uint16, testPort uint16) {
|
|||||||
dialTCP := func() (net.Conn, error) {
|
dialTCP := func() (net.Conn, error) {
|
||||||
return dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddrHostPort("127.0.0.1", testPort))
|
return dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddrHostPort("127.0.0.1", testPort))
|
||||||
}
|
}
|
||||||
require.NoError(t, testPingPongWithConn(t, testPort, dialTCP))
|
|
||||||
require.NoError(t, testLargeDataWithConn(t, testPort, dialTCP))
|
require.NoError(t, testLargeDataWithConn(t, testPort, dialTCP))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func testSuitHy(t *testing.T, clientPort uint16, testPort uint16) {
|
||||||
|
dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", clientPort), socks.Version5, "", "")
|
||||||
|
dialTCP := func() (net.Conn, error) {
|
||||||
|
return dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddrHostPort("127.0.0.1", testPort))
|
||||||
|
}
|
||||||
|
dialUDP := func() (net.PacketConn, error) {
|
||||||
|
return dialer.ListenPacket(context.Background(), M.ParseSocksaddrHostPort("127.0.0.1", testPort))
|
||||||
|
}
|
||||||
|
require.NoError(t, testPingPongWithConn(t, testPort, dialTCP))
|
||||||
|
require.NoError(t, testPingPongWithPacketConn(t, testPort, dialUDP))
|
||||||
|
}
|
||||||
|
|
||||||
func testSuitWg(t *testing.T, clientPort uint16, testPort uint16) {
|
func testSuitWg(t *testing.T, clientPort uint16, testPort uint16) {
|
||||||
dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", clientPort), socks.Version5, "", "")
|
dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", clientPort), socks.Version5, "", "")
|
||||||
dialTCP := func() (net.Conn, error) {
|
dialTCP := func() (net.Conn, error) {
|
||||||
|
@ -7,6 +7,8 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
|
_ "net/http/pprof"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
@ -91,6 +93,12 @@ func init() {
|
|||||||
|
|
||||||
io.Copy(io.Discard, imageStream)
|
io.Copy(io.Discard, imageStream)
|
||||||
}
|
}
|
||||||
|
go func() {
|
||||||
|
err = http.ListenAndServe("0.0.0.0:8965", nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Debug(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func newPingPongPair() (chan []byte, chan []byte, func(t *testing.T) error) {
|
func newPingPongPair() (chan []byte, chan []byte, func(t *testing.T) error) {
|
||||||
|
12
test/config/hysteria-client.json
Normal file
12
test/config/hysteria-client.json
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
{
|
||||||
|
"server": "127.0.0.1:10000",
|
||||||
|
"auth_str": "password",
|
||||||
|
"obfs": "fuck me till the daylight",
|
||||||
|
"up_mbps": 100,
|
||||||
|
"down_mbps": 100,
|
||||||
|
"socks5": {
|
||||||
|
"listen": "127.0.0.1:10001"
|
||||||
|
},
|
||||||
|
"server_name": "example.org",
|
||||||
|
"ca": "/etc/hysteria/ca.pem"
|
||||||
|
}
|
@ -80,7 +80,50 @@ func TestHysteriaSelf(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
testTCP(t, clientPort, testPort)
|
testSuitHy(t, clientPort, testPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHysteriaInbound(t *testing.T) {
|
||||||
|
if !C.QUIC_AVAILABLE {
|
||||||
|
t.Skip("QUIC not included")
|
||||||
|
}
|
||||||
|
caPem, certPem, keyPem := createSelfSignedCertificate(t, "example.org")
|
||||||
|
startInstance(t, option.Options{
|
||||||
|
Log: &option.LogOptions{
|
||||||
|
Level: "trace",
|
||||||
|
},
|
||||||
|
Inbounds: []option.Inbound{
|
||||||
|
{
|
||||||
|
Type: C.TypeHysteria,
|
||||||
|
HysteriaOptions: option.HysteriaInboundOptions{
|
||||||
|
ListenOptions: option.ListenOptions{
|
||||||
|
Listen: option.ListenAddress(netip.IPv4Unspecified()),
|
||||||
|
ListenPort: serverPort,
|
||||||
|
},
|
||||||
|
UpMbps: 100,
|
||||||
|
DownMbps: 100,
|
||||||
|
AuthString: "password",
|
||||||
|
Obfs: "fuck me till the daylight",
|
||||||
|
TLS: &option.InboundTLSOptions{
|
||||||
|
Enabled: true,
|
||||||
|
ServerName: "example.org",
|
||||||
|
CertificatePath: certPem,
|
||||||
|
KeyPath: keyPem,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
startDockerContainer(t, DockerOptions{
|
||||||
|
Image: ImageHysteria,
|
||||||
|
Ports: []uint16{serverPort, clientPort},
|
||||||
|
Cmd: []string{"-c", "/etc/hysteria/config.json", "client"},
|
||||||
|
Bind: map[string]string{
|
||||||
|
"hysteria-client.json": "/etc/hysteria/config.json",
|
||||||
|
caPem: "/etc/hysteria/ca.pem",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
testSuit(t, clientPort, testPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHysteriaOutbound(t *testing.T) {
|
func TestHysteriaOutbound(t *testing.T) {
|
||||||
@ -93,7 +136,7 @@ func TestHysteriaOutbound(t *testing.T) {
|
|||||||
Ports: []uint16{serverPort, testPort},
|
Ports: []uint16{serverPort, testPort},
|
||||||
Cmd: []string{"-c", "/etc/hysteria/config.json", "server"},
|
Cmd: []string{"-c", "/etc/hysteria/config.json", "server"},
|
||||||
Bind: map[string]string{
|
Bind: map[string]string{
|
||||||
"hysteria.json": "/etc/hysteria/config.json",
|
"hysteria-server.json": "/etc/hysteria/config.json",
|
||||||
certPem: "/etc/hysteria/cert.pem",
|
certPem: "/etc/hysteria/cert.pem",
|
||||||
keyPem: "/etc/hysteria/key.pem",
|
keyPem: "/etc/hysteria/key.pem",
|
||||||
},
|
},
|
||||||
@ -131,5 +174,5 @@ func TestHysteriaOutbound(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
testSuit(t, clientPort, testPort)
|
testSuitHy(t, clientPort, testPort)
|
||||||
}
|
}
|
||||||
|
149
transport/hysteria/brutal.go
Normal file
149
transport/hysteria/brutal.go
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
package hysteria
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/quic-go/congestion"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
initMaxDatagramSize = 1252
|
||||||
|
|
||||||
|
pktInfoSlotCount = 4
|
||||||
|
minSampleCount = 50
|
||||||
|
minAckRate = 0.8
|
||||||
|
)
|
||||||
|
|
||||||
|
type BrutalSender struct {
|
||||||
|
rttStats congestion.RTTStatsProvider
|
||||||
|
bps congestion.ByteCount
|
||||||
|
maxDatagramSize congestion.ByteCount
|
||||||
|
pacer *pacer
|
||||||
|
|
||||||
|
pktInfoSlots [pktInfoSlotCount]pktInfo
|
||||||
|
ackRate float64
|
||||||
|
}
|
||||||
|
|
||||||
|
type pktInfo struct {
|
||||||
|
Timestamp int64
|
||||||
|
AckCount uint64
|
||||||
|
LossCount uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBrutalSender(bps congestion.ByteCount) *BrutalSender {
|
||||||
|
bs := &BrutalSender{
|
||||||
|
bps: bps,
|
||||||
|
maxDatagramSize: initMaxDatagramSize,
|
||||||
|
ackRate: 1,
|
||||||
|
}
|
||||||
|
bs.pacer = newPacer(func() congestion.ByteCount {
|
||||||
|
return congestion.ByteCount(float64(bs.bps) / bs.ackRate)
|
||||||
|
})
|
||||||
|
return bs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BrutalSender) SetRTTStatsProvider(rttStats congestion.RTTStatsProvider) {
|
||||||
|
b.rttStats = rttStats
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BrutalSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time {
|
||||||
|
return b.pacer.TimeUntilSend()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BrutalSender) HasPacingBudget() bool {
|
||||||
|
return b.pacer.Budget(time.Now()) >= b.maxDatagramSize
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BrutalSender) CanSend(bytesInFlight congestion.ByteCount) bool {
|
||||||
|
return bytesInFlight < b.GetCongestionWindow()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BrutalSender) GetCongestionWindow() congestion.ByteCount {
|
||||||
|
rtt := maxDuration(b.rttStats.LatestRTT(), b.rttStats.SmoothedRTT())
|
||||||
|
if rtt <= 0 {
|
||||||
|
return 10240
|
||||||
|
}
|
||||||
|
return congestion.ByteCount(float64(b.bps) * rtt.Seconds() * 1.5 / b.ackRate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount,
|
||||||
|
packetNumber congestion.PacketNumber, bytes congestion.ByteCount, isRetransmittable bool,
|
||||||
|
) {
|
||||||
|
b.pacer.SentPacket(sentTime, bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BrutalSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount,
|
||||||
|
priorInFlight congestion.ByteCount, eventTime time.Time,
|
||||||
|
) {
|
||||||
|
currentTimestamp := eventTime.Unix()
|
||||||
|
slot := currentTimestamp % pktInfoSlotCount
|
||||||
|
if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
|
||||||
|
b.pktInfoSlots[slot].AckCount++
|
||||||
|
} else {
|
||||||
|
// uninitialized slot or too old, reset
|
||||||
|
b.pktInfoSlots[slot].Timestamp = currentTimestamp
|
||||||
|
b.pktInfoSlots[slot].AckCount = 1
|
||||||
|
b.pktInfoSlots[slot].LossCount = 0
|
||||||
|
}
|
||||||
|
b.updateAckRate(currentTimestamp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BrutalSender) OnPacketLost(number congestion.PacketNumber, lostBytes congestion.ByteCount,
|
||||||
|
priorInFlight congestion.ByteCount,
|
||||||
|
) {
|
||||||
|
currentTimestamp := time.Now().Unix()
|
||||||
|
slot := currentTimestamp % pktInfoSlotCount
|
||||||
|
if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
|
||||||
|
b.pktInfoSlots[slot].LossCount++
|
||||||
|
} else {
|
||||||
|
// uninitialized slot or too old, reset
|
||||||
|
b.pktInfoSlots[slot].Timestamp = currentTimestamp
|
||||||
|
b.pktInfoSlots[slot].AckCount = 0
|
||||||
|
b.pktInfoSlots[slot].LossCount = 1
|
||||||
|
}
|
||||||
|
b.updateAckRate(currentTimestamp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BrutalSender) SetMaxDatagramSize(size congestion.ByteCount) {
|
||||||
|
b.maxDatagramSize = size
|
||||||
|
b.pacer.SetMaxDatagramSize(size)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BrutalSender) updateAckRate(currentTimestamp int64) {
|
||||||
|
minTimestamp := currentTimestamp - pktInfoSlotCount
|
||||||
|
var ackCount, lossCount uint64
|
||||||
|
for _, info := range b.pktInfoSlots {
|
||||||
|
if info.Timestamp < minTimestamp {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ackCount += info.AckCount
|
||||||
|
lossCount += info.LossCount
|
||||||
|
}
|
||||||
|
if ackCount+lossCount < minSampleCount {
|
||||||
|
b.ackRate = 1
|
||||||
|
}
|
||||||
|
rate := float64(ackCount) / float64(ackCount+lossCount)
|
||||||
|
if rate < minAckRate {
|
||||||
|
b.ackRate = minAckRate
|
||||||
|
}
|
||||||
|
b.ackRate = rate
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BrutalSender) InSlowStart() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BrutalSender) InRecovery() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BrutalSender) MaybeExitSlowStart() {}
|
||||||
|
|
||||||
|
func (b *BrutalSender) OnRetransmissionTimeout(packetsRetransmitted bool) {}
|
||||||
|
|
||||||
|
func maxDuration(a, b time.Duration) time.Duration {
|
||||||
|
if a > b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
@ -1,186 +0,0 @@
|
|||||||
package hysteria
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
|
||||||
"github.com/sagernet/sing/common/buf"
|
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ net.Conn = (*ClientConn)(nil)
|
|
||||||
|
|
||||||
type ClientConn struct {
|
|
||||||
quic.Stream
|
|
||||||
destination M.Socksaddr
|
|
||||||
requestWritten bool
|
|
||||||
responseRead bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewClientConn(stream quic.Stream, destination M.Socksaddr) *ClientConn {
|
|
||||||
return &ClientConn{
|
|
||||||
Stream: stream,
|
|
||||||
destination: destination,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ClientConn) Read(b []byte) (n int, err error) {
|
|
||||||
if !c.responseRead {
|
|
||||||
var response *ServerResponse
|
|
||||||
response, err = ReadServerResponse(c.Stream)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !response.OK {
|
|
||||||
return 0, E.New("remote error: " + response.Message)
|
|
||||||
}
|
|
||||||
c.responseRead = true
|
|
||||||
}
|
|
||||||
return c.Stream.Read(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ClientConn) Write(b []byte) (n int, err error) {
|
|
||||||
if !c.requestWritten {
|
|
||||||
err = WriteClientRequest(c.Stream, ClientRequest{
|
|
||||||
UDP: false,
|
|
||||||
Host: c.destination.Unwrap().AddrString(),
|
|
||||||
Port: c.destination.Port,
|
|
||||||
}, b)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.requestWritten = true
|
|
||||||
return len(b), nil
|
|
||||||
}
|
|
||||||
return c.Stream.Write(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ClientConn) LocalAddr() net.Addr {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ClientConn) RemoteAddr() net.Addr {
|
|
||||||
return c.destination.TCPAddr()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ClientConn) Upstream() any {
|
|
||||||
return c.Stream
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ClientConn) ReaderReplaceable() bool {
|
|
||||||
return c.responseRead
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ClientConn) WriterReplaceable() bool {
|
|
||||||
return c.requestWritten
|
|
||||||
}
|
|
||||||
|
|
||||||
type ClientPacketConn struct {
|
|
||||||
session quic.Connection
|
|
||||||
stream quic.Stream
|
|
||||||
sessionId uint32
|
|
||||||
destination M.Socksaddr
|
|
||||||
msgCh <-chan *UDPMessage
|
|
||||||
closer io.Closer
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewClientPacketConn(session quic.Connection, stream quic.Stream, sessionId uint32, destination M.Socksaddr, msgCh <-chan *UDPMessage, closer io.Closer) *ClientPacketConn {
|
|
||||||
return &ClientPacketConn{
|
|
||||||
session: session,
|
|
||||||
stream: stream,
|
|
||||||
sessionId: sessionId,
|
|
||||||
destination: destination,
|
|
||||||
msgCh: msgCh,
|
|
||||||
closer: closer,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ClientPacketConn) Hold() {
|
|
||||||
// Hold the stream until it's closed
|
|
||||||
buf := make([]byte, 1024)
|
|
||||||
for {
|
|
||||||
_, err := c.stream.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ = c.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
|
||||||
msg := <-c.msgCh
|
|
||||||
if msg == nil {
|
|
||||||
err = net.ErrClosed
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = common.Error(buffer.Write(msg.Data))
|
|
||||||
destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ClientPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
|
||||||
msg := <-c.msgCh
|
|
||||||
if msg == nil {
|
|
||||||
err = net.ErrClosed
|
|
||||||
return
|
|
||||||
}
|
|
||||||
buffer = buf.As(msg.Data)
|
|
||||||
destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
|
||||||
return WriteUDPMessage(c.session, UDPMessage{
|
|
||||||
SessionID: c.sessionId,
|
|
||||||
Host: destination.Unwrap().AddrString(),
|
|
||||||
Port: destination.Port,
|
|
||||||
FragCount: 1,
|
|
||||||
Data: buffer.Bytes(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ClientPacketConn) LocalAddr() net.Addr {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ClientPacketConn) RemoteAddr() net.Addr {
|
|
||||||
return c.destination.UDPAddr()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ClientPacketConn) SetDeadline(t time.Time) error {
|
|
||||||
return os.ErrInvalid
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ClientPacketConn) SetReadDeadline(t time.Time) error {
|
|
||||||
return os.ErrInvalid
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ClientPacketConn) SetWriteDeadline(t time.Time) error {
|
|
||||||
return os.ErrInvalid
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ClientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
|
||||||
return 0, nil, os.ErrInvalid
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
|
||||||
return 0, os.ErrInvalid
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ClientPacketConn) Read(b []byte) (n int, err error) {
|
|
||||||
return 0, os.ErrInvalid
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ClientPacketConn) Write(b []byte) (n int, err error) {
|
|
||||||
return 0, os.ErrInvalid
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ClientPacketConn) Close() error {
|
|
||||||
return common.Close(c.stream, c.closer)
|
|
||||||
}
|
|
86
transport/hysteria/pacer.go
Normal file
86
transport/hysteria/pacer.go
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
package hysteria
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/quic-go/congestion"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
maxBurstPackets = 10
|
||||||
|
minPacingDelay = time.Millisecond
|
||||||
|
)
|
||||||
|
|
||||||
|
// The pacer implements a token bucket pacing algorithm.
|
||||||
|
type pacer struct {
|
||||||
|
budgetAtLastSent congestion.ByteCount
|
||||||
|
maxDatagramSize congestion.ByteCount
|
||||||
|
lastSentTime time.Time
|
||||||
|
getBandwidth func() congestion.ByteCount // in bytes/s
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPacer(getBandwidth func() congestion.ByteCount) *pacer {
|
||||||
|
p := &pacer{
|
||||||
|
budgetAtLastSent: maxBurstPackets * initMaxDatagramSize,
|
||||||
|
maxDatagramSize: initMaxDatagramSize,
|
||||||
|
getBandwidth: getBandwidth,
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) {
|
||||||
|
budget := p.Budget(sendTime)
|
||||||
|
if size > budget {
|
||||||
|
p.budgetAtLastSent = 0
|
||||||
|
} else {
|
||||||
|
p.budgetAtLastSent = budget - size
|
||||||
|
}
|
||||||
|
p.lastSentTime = sendTime
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pacer) Budget(now time.Time) congestion.ByteCount {
|
||||||
|
if p.lastSentTime.IsZero() {
|
||||||
|
return p.maxBurstSize()
|
||||||
|
}
|
||||||
|
budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9
|
||||||
|
return minByteCount(p.maxBurstSize(), budget)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pacer) maxBurstSize() congestion.ByteCount {
|
||||||
|
return maxByteCount(
|
||||||
|
congestion.ByteCount((minPacingDelay+time.Millisecond).Nanoseconds())*p.getBandwidth()/1e9,
|
||||||
|
maxBurstPackets*p.maxDatagramSize,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TimeUntilSend returns when the next packet should be sent.
|
||||||
|
// It returns the zero value of time.Time if a packet can be sent immediately.
|
||||||
|
func (p *pacer) TimeUntilSend() time.Time {
|
||||||
|
if p.budgetAtLastSent >= p.maxDatagramSize {
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
return p.lastSentTime.Add(maxDuration(
|
||||||
|
minPacingDelay,
|
||||||
|
time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/
|
||||||
|
float64(p.getBandwidth())))*time.Nanosecond,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pacer) SetMaxDatagramSize(s congestion.ByteCount) {
|
||||||
|
p.maxDatagramSize = s
|
||||||
|
}
|
||||||
|
|
||||||
|
func maxByteCount(a, b congestion.ByteCount) congestion.ByteCount {
|
||||||
|
if a < b {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
func minByteCount(a, b congestion.ByteCount) congestion.ByteCount {
|
||||||
|
if a < b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
@ -5,13 +5,15 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"io"
|
"io"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/quic-go"
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
"github.com/sagernet/sing/common/buf"
|
"github.com/sagernet/sing/common/buf"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
"github.com/lucas-clemente/quic-go"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -177,13 +179,12 @@ func ReadClientRequest(stream io.Reader) (*ClientRequest, error) {
|
|||||||
return &clientRequest, nil
|
return &clientRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func WriteClientRequest(stream io.Writer, request ClientRequest, payload []byte) error {
|
func WriteClientRequest(stream io.Writer, request ClientRequest) error {
|
||||||
var requestLen int
|
var requestLen int
|
||||||
requestLen += 1 // udp
|
requestLen += 1 // udp
|
||||||
requestLen += 2 // host len
|
requestLen += 2 // host len
|
||||||
requestLen += len(request.Host)
|
requestLen += len(request.Host)
|
||||||
requestLen += 2 // port
|
requestLen += 2 // port
|
||||||
requestLen += len(payload)
|
|
||||||
_buffer := buf.StackNewSize(requestLen)
|
_buffer := buf.StackNewSize(requestLen)
|
||||||
defer common.KeepAlive(_buffer)
|
defer common.KeepAlive(_buffer)
|
||||||
buffer := common.Dup(_buffer)
|
buffer := common.Dup(_buffer)
|
||||||
@ -197,7 +198,6 @@ func WriteClientRequest(stream io.Writer, request ClientRequest, payload []byte)
|
|||||||
binary.Write(buffer, binary.BigEndian, uint16(len(request.Host))),
|
binary.Write(buffer, binary.BigEndian, uint16(len(request.Host))),
|
||||||
common.Error(buffer.WriteString(request.Host)),
|
common.Error(buffer.WriteString(request.Host)),
|
||||||
binary.Write(buffer, binary.BigEndian, request.Port),
|
binary.Write(buffer, binary.BigEndian, request.Port),
|
||||||
common.Error(buffer.Write(payload)),
|
|
||||||
)
|
)
|
||||||
return common.Error(stream.Write(buffer.Bytes()))
|
return common.Error(stream.Write(buffer.Bytes()))
|
||||||
}
|
}
|
||||||
@ -237,13 +237,12 @@ func ReadServerResponse(stream io.Reader) (*ServerResponse, error) {
|
|||||||
return &serverResponse, nil
|
return &serverResponse, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func WriteServerResponse(stream io.Writer, response ServerResponse, payload []byte) error {
|
func WriteServerResponse(stream io.Writer, response ServerResponse) error {
|
||||||
var responseLen int
|
var responseLen int
|
||||||
responseLen += 1 // ok
|
responseLen += 1 // ok
|
||||||
responseLen += 4 // udp session id
|
responseLen += 4 // udp session id
|
||||||
responseLen += 2 // message len
|
responseLen += 2 // message len
|
||||||
responseLen += len(response.Message)
|
responseLen += len(response.Message)
|
||||||
responseLen += len(payload)
|
|
||||||
_buffer := buf.StackNewSize(responseLen)
|
_buffer := buf.StackNewSize(responseLen)
|
||||||
defer common.KeepAlive(_buffer)
|
defer common.KeepAlive(_buffer)
|
||||||
buffer := common.Dup(_buffer)
|
buffer := common.Dup(_buffer)
|
||||||
@ -257,7 +256,6 @@ func WriteServerResponse(stream io.Writer, response ServerResponse, payload []by
|
|||||||
binary.Write(buffer, binary.BigEndian, response.UDPSessionID),
|
binary.Write(buffer, binary.BigEndian, response.UDPSessionID),
|
||||||
binary.Write(buffer, binary.BigEndian, uint16(len(response.Message))),
|
binary.Write(buffer, binary.BigEndian, uint16(len(response.Message))),
|
||||||
common.Error(buffer.WriteString(response.Message)),
|
common.Error(buffer.WriteString(response.Message)),
|
||||||
common.Error(buffer.Write(payload)),
|
|
||||||
)
|
)
|
||||||
return common.Error(stream.Write(buffer.Bytes()))
|
return common.Error(stream.Write(buffer.Bytes()))
|
||||||
}
|
}
|
||||||
@ -341,9 +339,7 @@ func WriteUDPMessage(conn quic.Connection, message UDPMessage) error {
|
|||||||
buffer := common.Dup(_buffer)
|
buffer := common.Dup(_buffer)
|
||||||
defer buffer.Release()
|
defer buffer.Release()
|
||||||
err := writeUDPMessage(conn, message, buffer)
|
err := writeUDPMessage(conn, message, buffer)
|
||||||
// TODO: wait for change upstream
|
if errSize, ok := err.(quic.ErrMessageToLarge); ok {
|
||||||
if /*errSize, ok := err.(quic.ErrMessageToLarge); ok*/ false {
|
|
||||||
const errSize = 0
|
|
||||||
// need to frag
|
// need to frag
|
||||||
message.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1
|
message.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1
|
||||||
fragMsgs := FragUDPMessage(message, int(errSize))
|
fragMsgs := FragUDPMessage(message, int(errSize))
|
||||||
@ -373,3 +369,142 @@ func writeUDPMessage(conn quic.Connection, message UDPMessage, buffer *buf.Buffe
|
|||||||
)
|
)
|
||||||
return conn.SendMessage(buffer.Bytes())
|
return conn.SendMessage(buffer.Bytes())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ net.Conn = (*Conn)(nil)
|
||||||
|
|
||||||
|
type Conn struct {
|
||||||
|
quic.Stream
|
||||||
|
destination M.Socksaddr
|
||||||
|
responseWritten bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConn(stream quic.Stream, destination M.Socksaddr) *Conn {
|
||||||
|
return &Conn{
|
||||||
|
Stream: stream,
|
||||||
|
destination: destination,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) LocalAddr() net.Addr {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) RemoteAddr() net.Addr {
|
||||||
|
return c.destination.TCPAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) ReaderReplaceable() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) WriterReplaceable() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) Upstream() any {
|
||||||
|
return c.Stream
|
||||||
|
}
|
||||||
|
|
||||||
|
type PacketConn struct {
|
||||||
|
session quic.Connection
|
||||||
|
stream quic.Stream
|
||||||
|
sessionId uint32
|
||||||
|
destination M.Socksaddr
|
||||||
|
msgCh <-chan *UDPMessage
|
||||||
|
closer io.Closer
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPacketConn(session quic.Connection, stream quic.Stream, sessionId uint32, destination M.Socksaddr, msgCh <-chan *UDPMessage, closer io.Closer) *PacketConn {
|
||||||
|
return &PacketConn{
|
||||||
|
session: session,
|
||||||
|
stream: stream,
|
||||||
|
sessionId: sessionId,
|
||||||
|
destination: destination,
|
||||||
|
msgCh: msgCh,
|
||||||
|
closer: closer,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PacketConn) Hold() {
|
||||||
|
// Hold the stream until it's closed
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
for {
|
||||||
|
_, err := c.stream.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = c.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||||
|
msg := <-c.msgCh
|
||||||
|
if msg == nil {
|
||||||
|
err = net.ErrClosed
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = common.Error(buffer.Write(msg.Data))
|
||||||
|
destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||||
|
msg := <-c.msgCh
|
||||||
|
if msg == nil {
|
||||||
|
err = net.ErrClosed
|
||||||
|
return
|
||||||
|
}
|
||||||
|
buffer = buf.As(msg.Data)
|
||||||
|
destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||||
|
return WriteUDPMessage(c.session, UDPMessage{
|
||||||
|
SessionID: c.sessionId,
|
||||||
|
Host: destination.Unwrap().AddrString(),
|
||||||
|
Port: destination.Port,
|
||||||
|
FragCount: 1,
|
||||||
|
Data: buffer.Bytes(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PacketConn) LocalAddr() net.Addr {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PacketConn) RemoteAddr() net.Addr {
|
||||||
|
return c.destination.UDPAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PacketConn) SetDeadline(t time.Time) error {
|
||||||
|
return os.ErrInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PacketConn) SetReadDeadline(t time.Time) error {
|
||||||
|
return os.ErrInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PacketConn) SetWriteDeadline(t time.Time) error {
|
||||||
|
return os.ErrInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||||
|
return 0, nil, os.ErrInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||||
|
return 0, os.ErrInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PacketConn) Read(b []byte) (n int, err error) {
|
||||||
|
return 0, os.ErrInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PacketConn) Write(b []byte) (n int, err error) {
|
||||||
|
return 0, os.ErrInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *PacketConn) Close() error {
|
||||||
|
return common.Close(c.stream, c.closer)
|
||||||
|
}
|
||||||
|
@ -1,68 +0,0 @@
|
|||||||
package hysteria
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
|
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
|
||||||
N "github.com/sagernet/sing/common/network"
|
|
||||||
|
|
||||||
"github.com/lucas-clemente/quic-go"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
_ net.Conn = (*ServerConn)(nil)
|
|
||||||
_ N.HandshakeConn = (*ServerConn)(nil)
|
|
||||||
)
|
|
||||||
|
|
||||||
type ServerConn struct {
|
|
||||||
quic.Stream
|
|
||||||
destination M.Socksaddr
|
|
||||||
responseWritten bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewServerConn(stream quic.Stream, destination M.Socksaddr) *ServerConn {
|
|
||||||
return &ServerConn{
|
|
||||||
Stream: stream,
|
|
||||||
destination: destination,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ServerConn) LocalAddr() net.Addr {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ServerConn) RemoteAddr() net.Addr {
|
|
||||||
return c.destination.TCPAddr()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ServerConn) Write(b []byte) (n int, err error) {
|
|
||||||
if !c.responseWritten {
|
|
||||||
err = WriteServerResponse(c.Stream, ServerResponse{
|
|
||||||
OK: true,
|
|
||||||
}, b)
|
|
||||||
c.responseWritten = true
|
|
||||||
return len(b), nil
|
|
||||||
}
|
|
||||||
return c.Stream.Write(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ServerConn) ReaderReplaceable() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ServerConn) WriterReplaceable() bool {
|
|
||||||
return c.responseWritten
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ServerConn) HandshakeFailure(err error) error {
|
|
||||||
if c.responseWritten {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return WriteServerResponse(c.Stream, ServerResponse{
|
|
||||||
Message: err.Error(),
|
|
||||||
}, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ServerConn) Upstream() any {
|
|
||||||
return c.Stream
|
|
||||||
}
|
|
@ -5,25 +5,52 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/sagernet/quic-go"
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
type WrapPacketConn struct {
|
type PacketConnWrapper struct {
|
||||||
net.PacketConn
|
net.PacketConn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WrapPacketConn) SetReadBuffer(bytes int) error {
|
func (c *PacketConnWrapper) SetReadBuffer(bytes int) error {
|
||||||
return common.MustCast[*net.UDPConn](c.PacketConn).SetReadBuffer(bytes)
|
return common.MustCast[*net.UDPConn](c.PacketConn).SetReadBuffer(bytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WrapPacketConn) SetWriteBuffer(bytes int) error {
|
func (c *PacketConnWrapper) SetWriteBuffer(bytes int) error {
|
||||||
return common.MustCast[*net.UDPConn](c.PacketConn).SetWriteBuffer(bytes)
|
return common.MustCast[*net.UDPConn](c.PacketConn).SetWriteBuffer(bytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WrapPacketConn) SyscallConn() (syscall.RawConn, error) {
|
func (c *PacketConnWrapper) SyscallConn() (syscall.RawConn, error) {
|
||||||
return common.MustCast[*net.UDPConn](c.PacketConn).SyscallConn()
|
return common.MustCast[*net.UDPConn](c.PacketConn).SyscallConn()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WrapPacketConn) File() (f *os.File, err error) {
|
func (c *PacketConnWrapper) File() (f *os.File, err error) {
|
||||||
return common.MustCast[*net.UDPConn](c.PacketConn).File()
|
return common.MustCast[*net.UDPConn](c.PacketConn).File()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *PacketConnWrapper) Upstream() any {
|
||||||
|
return c.PacketConn
|
||||||
|
}
|
||||||
|
|
||||||
|
type StreamWrapper struct {
|
||||||
|
quic.Stream
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StreamWrapper) Upstream() any {
|
||||||
|
return s.Stream
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StreamWrapper) ReaderReplaceable() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StreamWrapper) WriterReplaceable() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StreamWrapper) Close() error {
|
||||||
|
s.CancelRead(0)
|
||||||
|
s.Stream.Close()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user