diff --git a/inbound/hysteria.go b/inbound/hysteria.go index f3fa337c..f3fb51ce 100644 --- a/inbound/hysteria.go +++ b/inbound/hysteria.go @@ -7,19 +7,20 @@ import ( "context" "net" "net/netip" + "sync" + "github.com/sagernet/quic-go" + "github.com/sagernet/quic-go/congestion" "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/transport/hysteria" - dns "github.com/sagernet/sing-dns" + "github.com/sagernet/sing-dns" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - - "github.com/lucas-clemente/quic-go" ) var _ adapter.Inbound = (*Hysteria)(nil) @@ -37,6 +38,10 @@ type Hysteria struct { sendBPS uint64 recvBPS uint64 listener quic.Listener + udpAccess sync.RWMutex + udpSessionId uint32 + udpSessions map[uint32]chan *hysteria.UDPMessage + udpDefragger hysteria.Defragger } func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaInboundOptions) (*Hysteria, error) { @@ -105,6 +110,7 @@ func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextL xplusKey: xplus, sendBPS: up, recvBPS: down, + udpSessions: make(map[uint32]chan *hysteria.UDPMessage), } if options.TLS == nil || !options.TLS.Enabled { return nil, ErrTLSRequired @@ -138,6 +144,7 @@ func (h *Hysteria) Start() error { } if len(h.xplusKey) > 0 { packetConn = hysteria.NewXPlusPacketConn(packetConn, h.xplusKey) + packetConn = &hysteria.PacketConnWrapper{PacketConn: packetConn} } err = h.tlsConfig.Start() if err != nil { @@ -159,6 +166,7 @@ func (h *Hysteria) acceptLoop() { if err != nil { return } + h.logger.InfoContext(ctx, "inbound connection from ", conn.RemoteAddr()) go func() { hErr := h.accept(ctx, conn) if hErr != nil { @@ -202,23 +210,51 @@ func (h *Hysteria) accept(ctx context.Context, conn quic.Connection) error { if err != nil { return err } - // TODO: set congestion control + conn.SetCongestionControl(hysteria.NewBrutalSender(congestion.ByteCount(serverSendBPS))) go h.udpRecvLoop(conn) - var stream quic.Stream for { + var stream quic.Stream stream, err = conn.AcceptStream(ctx) if err != nil { return err } - hErr := h.acceptStream(ctx, conn, stream) - if hErr != nil { - stream.Close() - NewError(h.logger, ctx, E.Cause(hErr, "process stream from ", conn.RemoteAddr())) - } + go func() { + hErr := h.acceptStream(ctx, conn /*&hysteria.StreamWrapper{Stream: stream}*/, stream) + if hErr != nil { + stream.Close() + NewError(h.logger, ctx, E.Cause(hErr, "process stream from ", conn.RemoteAddr())) + } + }() } } func (h *Hysteria) udpRecvLoop(conn quic.Connection) { + for { + packet, err := conn.ReceiveMessage() + if err != nil { + return + } + message, err := hysteria.ParseUDPMessage(packet) + if err != nil { + h.logger.Error("parse udp message: ", err) + continue + } + dfMsg := h.udpDefragger.Feed(message) + if dfMsg == nil { + continue + } + h.udpAccess.RLock() + ch, ok := h.udpSessions[dfMsg.SessionID] + if ok { + select { + case ch <- dfMsg: + // OK + default: + // Silently drop the message when the channel is full + } + } + h.udpAccess.RUnlock() + } } func (h *Hysteria) acceptStream(ctx context.Context, conn quic.Connection, stream quic.Stream) error { @@ -226,15 +262,11 @@ func (h *Hysteria) acceptStream(ctx context.Context, conn quic.Connection, strea if err != nil { return err } - if request.UDP { - err = hysteria.WriteServerResponse(stream, hysteria.ServerResponse{ - Message: "unsupported", - }, nil) - if err != nil { - return err - } - stream.Close() - return nil + err = hysteria.WriteServerResponse(stream, hysteria.ServerResponse{ + OK: true, + }) + if err != nil { + return err } var metadata adapter.InboundContext metadata.Inbound = h.tag @@ -242,13 +274,43 @@ func (h *Hysteria) acceptStream(ctx context.Context, conn quic.Connection, strea metadata.SniffEnabled = h.listenOptions.SniffEnabled metadata.SniffOverrideDestination = h.listenOptions.SniffOverrideDestination metadata.DomainStrategy = dns.DomainStrategy(h.listenOptions.DomainStrategy) - metadata.Network = N.NetworkTCP metadata.Source = M.SocksaddrFromNet(conn.RemoteAddr()) metadata.Destination = M.ParseSocksaddrHostPort(request.Host, request.Port) - return h.router.RouteConnection(ctx, hysteria.NewServerConn(stream, metadata.Destination), metadata) + if !request.UDP { + h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) + metadata.Network = N.NetworkTCP + return h.router.RouteConnection(ctx, hysteria.NewConn(stream, metadata.Destination), metadata) + } else { + h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination) + var id uint32 + h.udpAccess.Lock() + id = h.udpSessionId + nCh := make(chan *hysteria.UDPMessage, 1024) + h.udpSessions[id] = nCh + h.udpSessionId += 1 + h.udpAccess.Unlock() + metadata.Network = N.NetworkUDP + packetConn := hysteria.NewPacketConn(conn, stream, id, metadata.Destination, nCh, common.Closer(func() error { + h.udpAccess.Lock() + if ch, ok := h.udpSessions[id]; ok { + close(ch) + delete(h.udpSessions, id) + } + h.udpAccess.Unlock() + return nil + })) + go packetConn.Hold() + return h.router.RoutePacketConnection(ctx, packetConn, metadata) + } } func (h *Hysteria) Close() error { + h.udpAccess.Lock() + for _, session := range h.udpSessions { + close(session) + } + h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage) + h.udpAccess.Unlock() return common.Close( h.listener, common.PtrOrNil(h.tlsConfig), diff --git a/inbound/naive.go b/inbound/naive.go index 424156dd..a1537562 100644 --- a/inbound/naive.go +++ b/inbound/naive.go @@ -195,6 +195,8 @@ func (n *Naive) newConnection(ctx context.Context, conn net.Conn, source, destin metadata.Network = N.NetworkTCP metadata.Source = source metadata.Destination = destination + n.logger.InfoContext(ctx, "inbound connection from ", metadata.Source) + n.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) hErr := n.router.RouteConnection(ctx, conn, metadata) if hErr != nil { conn.Close() diff --git a/outbound/hysteria.go b/outbound/hysteria.go index a1339e4a..bdf51b5c 100644 --- a/outbound/hysteria.go +++ b/outbound/hysteria.go @@ -10,6 +10,8 @@ import ( "os" "sync" + "github.com/sagernet/quic-go" + "github.com/sagernet/quic-go/congestion" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" C "github.com/sagernet/sing-box/constant" @@ -21,8 +23,6 @@ import ( E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - - "github.com/lucas-clemente/quic-go" ) var _ adapter.Outbound = (*Hysteria)(nil) @@ -171,7 +171,7 @@ func (h *Hysteria) offer(ctx context.Context) (quic.Connection, error) { } h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage) h.udpDefragger = hysteria.Defragger{} - go h.recvLoop(conn) + go h.udpRecvLoop(conn) } return conn, nil } @@ -186,7 +186,7 @@ func (h *Hysteria) offerNew(ctx context.Context) (quic.Connection, error) { if h.xplusKey != nil { packetConn = hysteria.NewXPlusPacketConn(packetConn, h.xplusKey) } - packetConn = &hysteria.WrapPacketConn{PacketConn: packetConn} + packetConn = &hysteria.PacketConnWrapper{PacketConn: packetConn} quicConn, err := quic.Dial(packetConn, udpConn.RemoteAddr(), h.serverAddr.AddrString(), h.tlsConfig, h.quicConfig) if err != nil { packetConn.Close() @@ -203,20 +203,23 @@ func (h *Hysteria) offerNew(ctx context.Context) (quic.Connection, error) { Auth: h.authKey, }) if err != nil { + packetConn.Close() return nil, err } serverHello, err := hysteria.ReadServerHello(controlStream) if err != nil { + packetConn.Close() return nil, err } if !serverHello.OK { + packetConn.Close() return nil, E.New("remote error: ", serverHello.Message) } - // TODO: set congestion control + quicConn.SetCongestionControl(hysteria.NewBrutalSender(congestion.ByteCount(serverHello.RecvBPS))) return quicConn, nil } -func (h *Hysteria) recvLoop(conn quic.Connection) { +func (h *Hysteria) udpRecvLoop(conn quic.Connection) { for { packet, err := conn.ReceiveMessage() if err != nil { @@ -260,35 +263,58 @@ func (h *Hysteria) Close() error { return nil } +func (h *Hysteria) open(ctx context.Context) (quic.Connection, quic.Stream, error) { + conn, err := h.offer(ctx) + if err != nil { + return nil, nil, err + } + stream, err := conn.OpenStream() + if err != nil { + return nil, nil, err + } + return conn, &hysteria.StreamWrapper{Stream: stream}, nil +} + func (h *Hysteria) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { switch N.NetworkName(network) { case N.NetworkTCP: - conn, err := h.offer(ctx) + h.logger.InfoContext(ctx, "outbound connection to ", destination) + _, stream, err := h.open(ctx) if err != nil { return nil, err } - stream, err := conn.OpenStream() + err = hysteria.WriteClientRequest(stream, hysteria.ClientRequest{ + Host: destination.AddrString(), + Port: destination.Port, + }) if err != nil { + stream.Close() return nil, err } - return hysteria.NewClientConn(stream, destination), nil + response, err := hysteria.ReadServerResponse(stream) + if err != nil { + stream.Close() + return nil, err + } + if !response.OK { + stream.Close() + return nil, E.New("remote error: ", response.Message) + } + return hysteria.NewConn(stream, destination), nil case N.NetworkUDP: conn, err := h.ListenPacket(ctx, destination) if err != nil { return nil, err } - return conn.(*hysteria.ClientPacketConn), nil + return conn.(*hysteria.PacketConn), nil default: return nil, E.New("unsupported network: ", network) } } func (h *Hysteria) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - conn, err := h.offer(ctx) - if err != nil { - return nil, err - } - stream, err := conn.OpenStream() + h.logger.InfoContext(ctx, "outbound packet connection to ", destination) + conn, stream, err := h.open(ctx) if err != nil { return nil, err } @@ -296,7 +322,7 @@ func (h *Hysteria) ListenPacket(ctx context.Context, destination M.Socksaddr) (n UDP: true, Host: destination.AddrString(), Port: destination.Port, - }, nil) + }) if err != nil { stream.Close() return nil, err @@ -313,12 +339,9 @@ func (h *Hysteria) ListenPacket(ctx context.Context, destination M.Socksaddr) (n } h.udpAccess.Lock() nCh := make(chan *hysteria.UDPMessage, 1024) - // Store the current session map for CloseFunc below - // to ensures that we are adding and removing sessions on the same map, - // as reconnecting will reassign the map h.udpSessions[response.UDPSessionID] = nCh h.udpAccess.Unlock() - packetConn := hysteria.NewClientPacketConn(conn, stream, response.UDPSessionID, destination, nCh, common.Closer(func() error { + packetConn := hysteria.NewPacketConn(conn, stream, response.UDPSessionID, destination, nCh, common.Closer(func() error { h.udpAccess.Lock() if ch, ok := h.udpSessions[response.UDPSessionID]; ok { close(ch) diff --git a/test/box_test.go b/test/box_test.go index 939a52c6..37f85956 100644 --- a/test/box_test.go +++ b/test/box_test.go @@ -56,10 +56,21 @@ func testTCP(t *testing.T, clientPort uint16, testPort uint16) { dialTCP := func() (net.Conn, error) { return dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddrHostPort("127.0.0.1", testPort)) } - require.NoError(t, testPingPongWithConn(t, testPort, dialTCP)) require.NoError(t, testLargeDataWithConn(t, testPort, dialTCP)) } +func testSuitHy(t *testing.T, clientPort uint16, testPort uint16) { + dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", clientPort), socks.Version5, "", "") + dialTCP := func() (net.Conn, error) { + return dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddrHostPort("127.0.0.1", testPort)) + } + dialUDP := func() (net.PacketConn, error) { + return dialer.ListenPacket(context.Background(), M.ParseSocksaddrHostPort("127.0.0.1", testPort)) + } + require.NoError(t, testPingPongWithConn(t, testPort, dialTCP)) + require.NoError(t, testPingPongWithPacketConn(t, testPort, dialUDP)) +} + func testSuitWg(t *testing.T, clientPort uint16, testPort uint16) { dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", clientPort), socks.Version5, "", "") dialTCP := func() (net.Conn, error) { diff --git a/test/clash_test.go b/test/clash_test.go index 908480cc..b27df7ea 100644 --- a/test/clash_test.go +++ b/test/clash_test.go @@ -7,6 +7,8 @@ import ( "errors" "io" "net" + "net/http" + _ "net/http/pprof" "net/netip" "sync" "testing" @@ -91,6 +93,12 @@ func init() { io.Copy(io.Discard, imageStream) } + go func() { + err = http.ListenAndServe("0.0.0.0:8965", nil) + if err != nil { + log.Debug(err) + } + }() } func newPingPongPair() (chan []byte, chan []byte, func(t *testing.T) error) { diff --git a/test/config/hysteria-client.json b/test/config/hysteria-client.json new file mode 100644 index 00000000..3328c510 --- /dev/null +++ b/test/config/hysteria-client.json @@ -0,0 +1,12 @@ +{ + "server": "127.0.0.1:10000", + "auth_str": "password", + "obfs": "fuck me till the daylight", + "up_mbps": 100, + "down_mbps": 100, + "socks5": { + "listen": "127.0.0.1:10001" + }, + "server_name": "example.org", + "ca": "/etc/hysteria/ca.pem" +} \ No newline at end of file diff --git a/test/config/hysteria.json b/test/config/hysteria-server.json similarity index 100% rename from test/config/hysteria.json rename to test/config/hysteria-server.json diff --git a/test/hysteria_test.go b/test/hysteria_test.go index c9d0568f..057d12ca 100644 --- a/test/hysteria_test.go +++ b/test/hysteria_test.go @@ -80,7 +80,50 @@ func TestHysteriaSelf(t *testing.T) { }, }, }) - testTCP(t, clientPort, testPort) + testSuitHy(t, clientPort, testPort) +} + +func TestHysteriaInbound(t *testing.T) { + if !C.QUIC_AVAILABLE { + t.Skip("QUIC not included") + } + caPem, certPem, keyPem := createSelfSignedCertificate(t, "example.org") + startInstance(t, option.Options{ + Log: &option.LogOptions{ + Level: "trace", + }, + Inbounds: []option.Inbound{ + { + Type: C.TypeHysteria, + HysteriaOptions: option.HysteriaInboundOptions{ + ListenOptions: option.ListenOptions{ + Listen: option.ListenAddress(netip.IPv4Unspecified()), + ListenPort: serverPort, + }, + UpMbps: 100, + DownMbps: 100, + AuthString: "password", + Obfs: "fuck me till the daylight", + TLS: &option.InboundTLSOptions{ + Enabled: true, + ServerName: "example.org", + CertificatePath: certPem, + KeyPath: keyPem, + }, + }, + }, + }, + }) + startDockerContainer(t, DockerOptions{ + Image: ImageHysteria, + Ports: []uint16{serverPort, clientPort}, + Cmd: []string{"-c", "/etc/hysteria/config.json", "client"}, + Bind: map[string]string{ + "hysteria-client.json": "/etc/hysteria/config.json", + caPem: "/etc/hysteria/ca.pem", + }, + }) + testSuit(t, clientPort, testPort) } func TestHysteriaOutbound(t *testing.T) { @@ -93,9 +136,9 @@ func TestHysteriaOutbound(t *testing.T) { Ports: []uint16{serverPort, testPort}, Cmd: []string{"-c", "/etc/hysteria/config.json", "server"}, Bind: map[string]string{ - "hysteria.json": "/etc/hysteria/config.json", - certPem: "/etc/hysteria/cert.pem", - keyPem: "/etc/hysteria/key.pem", + "hysteria-server.json": "/etc/hysteria/config.json", + certPem: "/etc/hysteria/cert.pem", + keyPem: "/etc/hysteria/key.pem", }, }) startInstance(t, option.Options{ @@ -131,5 +174,5 @@ func TestHysteriaOutbound(t *testing.T) { }, }, }) - testSuit(t, clientPort, testPort) + testSuitHy(t, clientPort, testPort) } diff --git a/transport/hysteria/brutal.go b/transport/hysteria/brutal.go new file mode 100644 index 00000000..0e6dc794 --- /dev/null +++ b/transport/hysteria/brutal.go @@ -0,0 +1,149 @@ +package hysteria + +import ( + "time" + + "github.com/sagernet/quic-go/congestion" +) + +const ( + initMaxDatagramSize = 1252 + + pktInfoSlotCount = 4 + minSampleCount = 50 + minAckRate = 0.8 +) + +type BrutalSender struct { + rttStats congestion.RTTStatsProvider + bps congestion.ByteCount + maxDatagramSize congestion.ByteCount + pacer *pacer + + pktInfoSlots [pktInfoSlotCount]pktInfo + ackRate float64 +} + +type pktInfo struct { + Timestamp int64 + AckCount uint64 + LossCount uint64 +} + +func NewBrutalSender(bps congestion.ByteCount) *BrutalSender { + bs := &BrutalSender{ + bps: bps, + maxDatagramSize: initMaxDatagramSize, + ackRate: 1, + } + bs.pacer = newPacer(func() congestion.ByteCount { + return congestion.ByteCount(float64(bs.bps) / bs.ackRate) + }) + return bs +} + +func (b *BrutalSender) SetRTTStatsProvider(rttStats congestion.RTTStatsProvider) { + b.rttStats = rttStats +} + +func (b *BrutalSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time { + return b.pacer.TimeUntilSend() +} + +func (b *BrutalSender) HasPacingBudget() bool { + return b.pacer.Budget(time.Now()) >= b.maxDatagramSize +} + +func (b *BrutalSender) CanSend(bytesInFlight congestion.ByteCount) bool { + return bytesInFlight < b.GetCongestionWindow() +} + +func (b *BrutalSender) GetCongestionWindow() congestion.ByteCount { + rtt := maxDuration(b.rttStats.LatestRTT(), b.rttStats.SmoothedRTT()) + if rtt <= 0 { + return 10240 + } + return congestion.ByteCount(float64(b.bps) * rtt.Seconds() * 1.5 / b.ackRate) +} + +func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount, + packetNumber congestion.PacketNumber, bytes congestion.ByteCount, isRetransmittable bool, +) { + b.pacer.SentPacket(sentTime, bytes) +} + +func (b *BrutalSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount, + priorInFlight congestion.ByteCount, eventTime time.Time, +) { + currentTimestamp := eventTime.Unix() + slot := currentTimestamp % pktInfoSlotCount + if b.pktInfoSlots[slot].Timestamp == currentTimestamp { + b.pktInfoSlots[slot].AckCount++ + } else { + // uninitialized slot or too old, reset + b.pktInfoSlots[slot].Timestamp = currentTimestamp + b.pktInfoSlots[slot].AckCount = 1 + b.pktInfoSlots[slot].LossCount = 0 + } + b.updateAckRate(currentTimestamp) +} + +func (b *BrutalSender) OnPacketLost(number congestion.PacketNumber, lostBytes congestion.ByteCount, + priorInFlight congestion.ByteCount, +) { + currentTimestamp := time.Now().Unix() + slot := currentTimestamp % pktInfoSlotCount + if b.pktInfoSlots[slot].Timestamp == currentTimestamp { + b.pktInfoSlots[slot].LossCount++ + } else { + // uninitialized slot or too old, reset + b.pktInfoSlots[slot].Timestamp = currentTimestamp + b.pktInfoSlots[slot].AckCount = 0 + b.pktInfoSlots[slot].LossCount = 1 + } + b.updateAckRate(currentTimestamp) +} + +func (b *BrutalSender) SetMaxDatagramSize(size congestion.ByteCount) { + b.maxDatagramSize = size + b.pacer.SetMaxDatagramSize(size) +} + +func (b *BrutalSender) updateAckRate(currentTimestamp int64) { + minTimestamp := currentTimestamp - pktInfoSlotCount + var ackCount, lossCount uint64 + for _, info := range b.pktInfoSlots { + if info.Timestamp < minTimestamp { + continue + } + ackCount += info.AckCount + lossCount += info.LossCount + } + if ackCount+lossCount < minSampleCount { + b.ackRate = 1 + } + rate := float64(ackCount) / float64(ackCount+lossCount) + if rate < minAckRate { + b.ackRate = minAckRate + } + b.ackRate = rate +} + +func (b *BrutalSender) InSlowStart() bool { + return false +} + +func (b *BrutalSender) InRecovery() bool { + return false +} + +func (b *BrutalSender) MaybeExitSlowStart() {} + +func (b *BrutalSender) OnRetransmissionTimeout(packetsRetransmitted bool) {} + +func maxDuration(a, b time.Duration) time.Duration { + if a > b { + return a + } + return b +} diff --git a/transport/hysteria/client.go b/transport/hysteria/client.go deleted file mode 100644 index 60357e50..00000000 --- a/transport/hysteria/client.go +++ /dev/null @@ -1,186 +0,0 @@ -package hysteria - -import ( - "io" - "net" - "os" - "time" - - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" - - "github.com/lucas-clemente/quic-go" -) - -var _ net.Conn = (*ClientConn)(nil) - -type ClientConn struct { - quic.Stream - destination M.Socksaddr - requestWritten bool - responseRead bool -} - -func NewClientConn(stream quic.Stream, destination M.Socksaddr) *ClientConn { - return &ClientConn{ - Stream: stream, - destination: destination, - } -} - -func (c *ClientConn) Read(b []byte) (n int, err error) { - if !c.responseRead { - var response *ServerResponse - response, err = ReadServerResponse(c.Stream) - if err != nil { - return - } - if !response.OK { - return 0, E.New("remote error: " + response.Message) - } - c.responseRead = true - } - return c.Stream.Read(b) -} - -func (c *ClientConn) Write(b []byte) (n int, err error) { - if !c.requestWritten { - err = WriteClientRequest(c.Stream, ClientRequest{ - UDP: false, - Host: c.destination.Unwrap().AddrString(), - Port: c.destination.Port, - }, b) - if err != nil { - return - } - c.requestWritten = true - return len(b), nil - } - return c.Stream.Write(b) -} - -func (c *ClientConn) LocalAddr() net.Addr { - return nil -} - -func (c *ClientConn) RemoteAddr() net.Addr { - return c.destination.TCPAddr() -} - -func (c *ClientConn) Upstream() any { - return c.Stream -} - -func (c *ClientConn) ReaderReplaceable() bool { - return c.responseRead -} - -func (c *ClientConn) WriterReplaceable() bool { - return c.requestWritten -} - -type ClientPacketConn struct { - session quic.Connection - stream quic.Stream - sessionId uint32 - destination M.Socksaddr - msgCh <-chan *UDPMessage - closer io.Closer -} - -func NewClientPacketConn(session quic.Connection, stream quic.Stream, sessionId uint32, destination M.Socksaddr, msgCh <-chan *UDPMessage, closer io.Closer) *ClientPacketConn { - return &ClientPacketConn{ - session: session, - stream: stream, - sessionId: sessionId, - destination: destination, - msgCh: msgCh, - closer: closer, - } -} - -func (c *ClientPacketConn) Hold() { - // Hold the stream until it's closed - buf := make([]byte, 1024) - for { - _, err := c.stream.Read(buf) - if err != nil { - break - } - } - _ = c.Close() -} - -func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { - msg := <-c.msgCh - if msg == nil { - err = net.ErrClosed - return - } - err = common.Error(buffer.Write(msg.Data)) - destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port) - return -} - -func (c *ClientPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) { - msg := <-c.msgCh - if msg == nil { - err = net.ErrClosed - return - } - buffer = buf.As(msg.Data) - destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port) - return -} - -func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - return WriteUDPMessage(c.session, UDPMessage{ - SessionID: c.sessionId, - Host: destination.Unwrap().AddrString(), - Port: destination.Port, - FragCount: 1, - Data: buffer.Bytes(), - }) -} - -func (c *ClientPacketConn) LocalAddr() net.Addr { - return nil -} - -func (c *ClientPacketConn) RemoteAddr() net.Addr { - return c.destination.UDPAddr() -} - -func (c *ClientPacketConn) SetDeadline(t time.Time) error { - return os.ErrInvalid -} - -func (c *ClientPacketConn) SetReadDeadline(t time.Time) error { - return os.ErrInvalid -} - -func (c *ClientPacketConn) SetWriteDeadline(t time.Time) error { - return os.ErrInvalid -} - -func (c *ClientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - return 0, nil, os.ErrInvalid -} - -func (c *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - return 0, os.ErrInvalid -} - -func (c *ClientPacketConn) Read(b []byte) (n int, err error) { - return 0, os.ErrInvalid -} - -func (c *ClientPacketConn) Write(b []byte) (n int, err error) { - return 0, os.ErrInvalid -} - -func (c *ClientPacketConn) Close() error { - return common.Close(c.stream, c.closer) -} diff --git a/transport/hysteria/pacer.go b/transport/hysteria/pacer.go new file mode 100644 index 00000000..7e67f7f4 --- /dev/null +++ b/transport/hysteria/pacer.go @@ -0,0 +1,86 @@ +package hysteria + +import ( + "math" + "time" + + "github.com/sagernet/quic-go/congestion" +) + +const ( + maxBurstPackets = 10 + minPacingDelay = time.Millisecond +) + +// The pacer implements a token bucket pacing algorithm. +type pacer struct { + budgetAtLastSent congestion.ByteCount + maxDatagramSize congestion.ByteCount + lastSentTime time.Time + getBandwidth func() congestion.ByteCount // in bytes/s +} + +func newPacer(getBandwidth func() congestion.ByteCount) *pacer { + p := &pacer{ + budgetAtLastSent: maxBurstPackets * initMaxDatagramSize, + maxDatagramSize: initMaxDatagramSize, + getBandwidth: getBandwidth, + } + return p +} + +func (p *pacer) SentPacket(sendTime time.Time, size congestion.ByteCount) { + budget := p.Budget(sendTime) + if size > budget { + p.budgetAtLastSent = 0 + } else { + p.budgetAtLastSent = budget - size + } + p.lastSentTime = sendTime +} + +func (p *pacer) Budget(now time.Time) congestion.ByteCount { + if p.lastSentTime.IsZero() { + return p.maxBurstSize() + } + budget := p.budgetAtLastSent + (p.getBandwidth()*congestion.ByteCount(now.Sub(p.lastSentTime).Nanoseconds()))/1e9 + return minByteCount(p.maxBurstSize(), budget) +} + +func (p *pacer) maxBurstSize() congestion.ByteCount { + return maxByteCount( + congestion.ByteCount((minPacingDelay+time.Millisecond).Nanoseconds())*p.getBandwidth()/1e9, + maxBurstPackets*p.maxDatagramSize, + ) +} + +// TimeUntilSend returns when the next packet should be sent. +// It returns the zero value of time.Time if a packet can be sent immediately. +func (p *pacer) TimeUntilSend() time.Time { + if p.budgetAtLastSent >= p.maxDatagramSize { + return time.Time{} + } + return p.lastSentTime.Add(maxDuration( + minPacingDelay, + time.Duration(math.Ceil(float64(p.maxDatagramSize-p.budgetAtLastSent)*1e9/ + float64(p.getBandwidth())))*time.Nanosecond, + )) +} + +func (p *pacer) SetMaxDatagramSize(s congestion.ByteCount) { + p.maxDatagramSize = s +} + +func maxByteCount(a, b congestion.ByteCount) congestion.ByteCount { + if a < b { + return b + } + return a +} + +func minByteCount(a, b congestion.ByteCount) congestion.ByteCount { + if a < b { + return a + } + return b +} diff --git a/transport/hysteria/protocol.go b/transport/hysteria/protocol.go index df62bcb8..d3893b80 100644 --- a/transport/hysteria/protocol.go +++ b/transport/hysteria/protocol.go @@ -5,13 +5,15 @@ import ( "encoding/binary" "io" "math/rand" + "net" + "os" "time" + "github.com/sagernet/quic-go" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" - - "github.com/lucas-clemente/quic-go" + M "github.com/sagernet/sing/common/metadata" ) const ( @@ -177,13 +179,12 @@ func ReadClientRequest(stream io.Reader) (*ClientRequest, error) { return &clientRequest, nil } -func WriteClientRequest(stream io.Writer, request ClientRequest, payload []byte) error { +func WriteClientRequest(stream io.Writer, request ClientRequest) error { var requestLen int requestLen += 1 // udp requestLen += 2 // host len requestLen += len(request.Host) requestLen += 2 // port - requestLen += len(payload) _buffer := buf.StackNewSize(requestLen) defer common.KeepAlive(_buffer) buffer := common.Dup(_buffer) @@ -197,7 +198,6 @@ func WriteClientRequest(stream io.Writer, request ClientRequest, payload []byte) binary.Write(buffer, binary.BigEndian, uint16(len(request.Host))), common.Error(buffer.WriteString(request.Host)), binary.Write(buffer, binary.BigEndian, request.Port), - common.Error(buffer.Write(payload)), ) return common.Error(stream.Write(buffer.Bytes())) } @@ -237,13 +237,12 @@ func ReadServerResponse(stream io.Reader) (*ServerResponse, error) { return &serverResponse, nil } -func WriteServerResponse(stream io.Writer, response ServerResponse, payload []byte) error { +func WriteServerResponse(stream io.Writer, response ServerResponse) error { var responseLen int responseLen += 1 // ok responseLen += 4 // udp session id responseLen += 2 // message len responseLen += len(response.Message) - responseLen += len(payload) _buffer := buf.StackNewSize(responseLen) defer common.KeepAlive(_buffer) buffer := common.Dup(_buffer) @@ -257,7 +256,6 @@ func WriteServerResponse(stream io.Writer, response ServerResponse, payload []by binary.Write(buffer, binary.BigEndian, response.UDPSessionID), binary.Write(buffer, binary.BigEndian, uint16(len(response.Message))), common.Error(buffer.WriteString(response.Message)), - common.Error(buffer.Write(payload)), ) return common.Error(stream.Write(buffer.Bytes())) } @@ -341,9 +339,7 @@ func WriteUDPMessage(conn quic.Connection, message UDPMessage) error { buffer := common.Dup(_buffer) defer buffer.Release() err := writeUDPMessage(conn, message, buffer) - // TODO: wait for change upstream - if /*errSize, ok := err.(quic.ErrMessageToLarge); ok*/ false { - const errSize = 0 + if errSize, ok := err.(quic.ErrMessageToLarge); ok { // need to frag message.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1 fragMsgs := FragUDPMessage(message, int(errSize)) @@ -373,3 +369,142 @@ func writeUDPMessage(conn quic.Connection, message UDPMessage, buffer *buf.Buffe ) return conn.SendMessage(buffer.Bytes()) } + +var _ net.Conn = (*Conn)(nil) + +type Conn struct { + quic.Stream + destination M.Socksaddr + responseWritten bool +} + +func NewConn(stream quic.Stream, destination M.Socksaddr) *Conn { + return &Conn{ + Stream: stream, + destination: destination, + } +} + +func (c *Conn) LocalAddr() net.Addr { + return nil +} + +func (c *Conn) RemoteAddr() net.Addr { + return c.destination.TCPAddr() +} + +func (c *Conn) ReaderReplaceable() bool { + return true +} + +func (c *Conn) WriterReplaceable() bool { + return true +} + +func (c *Conn) Upstream() any { + return c.Stream +} + +type PacketConn struct { + session quic.Connection + stream quic.Stream + sessionId uint32 + destination M.Socksaddr + msgCh <-chan *UDPMessage + closer io.Closer +} + +func NewPacketConn(session quic.Connection, stream quic.Stream, sessionId uint32, destination M.Socksaddr, msgCh <-chan *UDPMessage, closer io.Closer) *PacketConn { + return &PacketConn{ + session: session, + stream: stream, + sessionId: sessionId, + destination: destination, + msgCh: msgCh, + closer: closer, + } +} + +func (c *PacketConn) Hold() { + // Hold the stream until it's closed + buf := make([]byte, 1024) + for { + _, err := c.stream.Read(buf) + if err != nil { + break + } + } + _ = c.Close() +} + +func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + msg := <-c.msgCh + if msg == nil { + err = net.ErrClosed + return + } + err = common.Error(buffer.Write(msg.Data)) + destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port) + return +} + +func (c *PacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + msg := <-c.msgCh + if msg == nil { + err = net.ErrClosed + return + } + buffer = buf.As(msg.Data) + destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port) + return +} + +func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + return WriteUDPMessage(c.session, UDPMessage{ + SessionID: c.sessionId, + Host: destination.Unwrap().AddrString(), + Port: destination.Port, + FragCount: 1, + Data: buffer.Bytes(), + }) +} + +func (c *PacketConn) LocalAddr() net.Addr { + return nil +} + +func (c *PacketConn) RemoteAddr() net.Addr { + return c.destination.UDPAddr() +} + +func (c *PacketConn) SetDeadline(t time.Time) error { + return os.ErrInvalid +} + +func (c *PacketConn) SetReadDeadline(t time.Time) error { + return os.ErrInvalid +} + +func (c *PacketConn) SetWriteDeadline(t time.Time) error { + return os.ErrInvalid +} + +func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + return 0, nil, os.ErrInvalid +} + +func (c *PacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + return 0, os.ErrInvalid +} + +func (c *PacketConn) Read(b []byte) (n int, err error) { + return 0, os.ErrInvalid +} + +func (c *PacketConn) Write(b []byte) (n int, err error) { + return 0, os.ErrInvalid +} + +func (c *PacketConn) Close() error { + return common.Close(c.stream, c.closer) +} diff --git a/transport/hysteria/server.go b/transport/hysteria/server.go deleted file mode 100644 index 75ed67e0..00000000 --- a/transport/hysteria/server.go +++ /dev/null @@ -1,68 +0,0 @@ -package hysteria - -import ( - "net" - - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" - - "github.com/lucas-clemente/quic-go" -) - -var ( - _ net.Conn = (*ServerConn)(nil) - _ N.HandshakeConn = (*ServerConn)(nil) -) - -type ServerConn struct { - quic.Stream - destination M.Socksaddr - responseWritten bool -} - -func NewServerConn(stream quic.Stream, destination M.Socksaddr) *ServerConn { - return &ServerConn{ - Stream: stream, - destination: destination, - } -} - -func (c *ServerConn) LocalAddr() net.Addr { - return nil -} - -func (c *ServerConn) RemoteAddr() net.Addr { - return c.destination.TCPAddr() -} - -func (c *ServerConn) Write(b []byte) (n int, err error) { - if !c.responseWritten { - err = WriteServerResponse(c.Stream, ServerResponse{ - OK: true, - }, b) - c.responseWritten = true - return len(b), nil - } - return c.Stream.Write(b) -} - -func (c *ServerConn) ReaderReplaceable() bool { - return true -} - -func (c *ServerConn) WriterReplaceable() bool { - return c.responseWritten -} - -func (c *ServerConn) HandshakeFailure(err error) error { - if c.responseWritten { - return nil - } - return WriteServerResponse(c.Stream, ServerResponse{ - Message: err.Error(), - }, nil) -} - -func (c *ServerConn) Upstream() any { - return c.Stream -} diff --git a/transport/hysteria/wrap.go b/transport/hysteria/wrap.go index 03aadb38..c280cf1c 100644 --- a/transport/hysteria/wrap.go +++ b/transport/hysteria/wrap.go @@ -5,25 +5,52 @@ import ( "os" "syscall" + "github.com/sagernet/quic-go" "github.com/sagernet/sing/common" ) -type WrapPacketConn struct { +type PacketConnWrapper struct { net.PacketConn } -func (c *WrapPacketConn) SetReadBuffer(bytes int) error { +func (c *PacketConnWrapper) SetReadBuffer(bytes int) error { return common.MustCast[*net.UDPConn](c.PacketConn).SetReadBuffer(bytes) } -func (c *WrapPacketConn) SetWriteBuffer(bytes int) error { +func (c *PacketConnWrapper) SetWriteBuffer(bytes int) error { return common.MustCast[*net.UDPConn](c.PacketConn).SetWriteBuffer(bytes) } -func (c *WrapPacketConn) SyscallConn() (syscall.RawConn, error) { +func (c *PacketConnWrapper) SyscallConn() (syscall.RawConn, error) { return common.MustCast[*net.UDPConn](c.PacketConn).SyscallConn() } -func (c *WrapPacketConn) File() (f *os.File, err error) { +func (c *PacketConnWrapper) File() (f *os.File, err error) { return common.MustCast[*net.UDPConn](c.PacketConn).File() } + +func (c *PacketConnWrapper) Upstream() any { + return c.PacketConn +} + +type StreamWrapper struct { + quic.Stream +} + +func (s *StreamWrapper) Upstream() any { + return s.Stream +} + +func (s *StreamWrapper) ReaderReplaceable() bool { + return true +} + +func (s *StreamWrapper) WriterReplaceable() bool { + return true +} + +func (s *StreamWrapper) Close() error { + s.CancelRead(0) + s.Stream.Close() + return nil +}