mirror of
https://github.com/SagerNet/sing-box.git
synced 2025-06-08 11:44:13 +08:00
Add hysteria udp client
This commit is contained in:
parent
17f10e0d3a
commit
f5bb4cf53f
@ -51,6 +51,9 @@ type Hysteria struct {
|
||||
recvBPS uint64
|
||||
connAccess sync.Mutex
|
||||
conn quic.Connection
|
||||
udpAccess sync.RWMutex
|
||||
udpSessions map[uint32]chan *hysteria.UDPMessage
|
||||
udpDefragger hysteria.Defragger
|
||||
}
|
||||
|
||||
func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaOutboundOptions) (*Hysteria, error) {
|
||||
@ -162,6 +165,8 @@ func (h *Hysteria) offer(ctx context.Context) (quic.Connection, error) {
|
||||
}
|
||||
h.connAccess.Lock()
|
||||
defer h.connAccess.Unlock()
|
||||
h.udpAccess.Lock()
|
||||
defer h.udpAccess.Unlock()
|
||||
conn = h.conn
|
||||
if conn != nil && !common.Done(conn.Context()) {
|
||||
return conn, nil
|
||||
@ -171,6 +176,14 @@ func (h *Hysteria) offer(ctx context.Context) (quic.Connection, error) {
|
||||
return nil, err
|
||||
}
|
||||
h.conn = conn
|
||||
if common.Contains(h.network, N.NetworkUDP) {
|
||||
for _, session := range h.udpSessions {
|
||||
close(session)
|
||||
}
|
||||
h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage)
|
||||
h.udpDefragger = hysteria.Defragger{}
|
||||
go h.recvLoop(conn)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
@ -214,16 +227,53 @@ func (h *Hysteria) offerNew(ctx context.Context) (quic.Connection, error) {
|
||||
return quicConn, nil
|
||||
}
|
||||
|
||||
func (h *Hysteria) recvLoop(conn quic.Connection) {
|
||||
for {
|
||||
packet, err := conn.ReceiveMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
message, err := hysteria.ParseUDPMessage(packet)
|
||||
if err != nil {
|
||||
h.logger.Error("parse udp message: ", err)
|
||||
continue
|
||||
}
|
||||
dfMsg := h.udpDefragger.Feed(message)
|
||||
if dfMsg == nil {
|
||||
continue
|
||||
}
|
||||
h.udpAccess.RLock()
|
||||
ch, ok := h.udpSessions[dfMsg.SessionID]
|
||||
if ok {
|
||||
select {
|
||||
case ch <- dfMsg:
|
||||
// OK
|
||||
default:
|
||||
// Silently drop the message when the channel is full
|
||||
}
|
||||
}
|
||||
h.udpAccess.RUnlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hysteria) Close() error {
|
||||
h.connAccess.Lock()
|
||||
defer h.connAccess.Unlock()
|
||||
h.udpAccess.Lock()
|
||||
defer h.udpAccess.Unlock()
|
||||
if h.conn != nil {
|
||||
h.conn.CloseWithError(0, "")
|
||||
}
|
||||
for _, session := range h.udpSessions {
|
||||
close(session)
|
||||
}
|
||||
h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Hysteria) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
switch N.NetworkName(network) {
|
||||
case N.NetworkTCP:
|
||||
conn, err := h.offer(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -232,16 +282,64 @@ func (h *Hysteria) DialContext(ctx context.Context, network string, destination
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch N.NetworkName(network) {
|
||||
case N.NetworkTCP:
|
||||
return hysteria.NewClientConn(stream, destination), nil
|
||||
case N.NetworkUDP:
|
||||
conn, err := h.ListenPacket(ctx, destination)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conn.(*hysteria.ClientPacketConn), nil
|
||||
default:
|
||||
return nil, E.New("unsupported network: ", network)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hysteria) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
return nil, os.ErrInvalid
|
||||
conn, err := h.offer(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stream, err := conn.OpenStream()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = hysteria.WriteClientRequest(stream, hysteria.ClientRequest{
|
||||
UDP: true,
|
||||
Host: destination.AddrString(),
|
||||
Port: destination.Port,
|
||||
}, nil)
|
||||
if err != nil {
|
||||
stream.Close()
|
||||
return nil, err
|
||||
}
|
||||
var response *hysteria.ServerResponse
|
||||
response, err = hysteria.ReadServerResponse(stream)
|
||||
if err != nil {
|
||||
stream.Close()
|
||||
return nil, err
|
||||
}
|
||||
if !response.OK {
|
||||
stream.Close()
|
||||
return nil, E.New("remote error: ", response.Message)
|
||||
}
|
||||
h.udpAccess.Lock()
|
||||
nCh := make(chan *hysteria.UDPMessage, 1024)
|
||||
// Store the current session map for CloseFunc below
|
||||
// to ensures that we are adding and removing sessions on the same map,
|
||||
// as reconnecting will reassign the map
|
||||
h.udpSessions[response.UDPSessionID] = nCh
|
||||
h.udpAccess.Unlock()
|
||||
packetConn := hysteria.NewClientPacketConn(conn, stream, response.UDPSessionID, destination, nCh, common.Closer(func() error {
|
||||
h.udpAccess.Lock()
|
||||
if ch, ok := h.udpSessions[response.UDPSessionID]; ok {
|
||||
close(ch)
|
||||
delete(h.udpSessions, response.UDPSessionID)
|
||||
}
|
||||
h.udpAccess.Unlock()
|
||||
return nil
|
||||
}))
|
||||
go packetConn.Hold()
|
||||
return packetConn, nil
|
||||
}
|
||||
|
||||
func (h *Hysteria) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
|
||||
|
@ -381,7 +381,6 @@ func testLargeDataWithPacketConn(t *testing.T, port uint16, pcc func() (net.Pack
|
||||
mux.Lock()
|
||||
hashMap[i] = hash[:]
|
||||
mux.Unlock()
|
||||
println("write ti ", addr.String())
|
||||
if _, err = pc.WriteTo(buf, addr); err != nil {
|
||||
t.Log(err)
|
||||
continue
|
||||
|
@ -56,5 +56,5 @@ func TestHysteriaOutbound(t *testing.T) {
|
||||
},
|
||||
},
|
||||
})
|
||||
testTCP(t, clientPort, testPort)
|
||||
testSuit(t, clientPort, testPort)
|
||||
}
|
||||
|
@ -1,8 +1,13 @@
|
||||
package hysteria
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
@ -67,3 +72,107 @@ func (c *ClientConn) RemoteAddr() net.Addr {
|
||||
func (c *ClientConn) Upstream() any {
|
||||
return c.Stream
|
||||
}
|
||||
|
||||
type ClientPacketConn struct {
|
||||
session quic.Connection
|
||||
stream quic.Stream
|
||||
sessionId uint32
|
||||
destination M.Socksaddr
|
||||
msgCh <-chan *UDPMessage
|
||||
closer io.Closer
|
||||
}
|
||||
|
||||
func NewClientPacketConn(session quic.Connection, stream quic.Stream, sessionId uint32, destination M.Socksaddr, msgCh <-chan *UDPMessage, closer io.Closer) *ClientPacketConn {
|
||||
return &ClientPacketConn{
|
||||
session: session,
|
||||
stream: stream,
|
||||
sessionId: sessionId,
|
||||
destination: destination,
|
||||
msgCh: msgCh,
|
||||
closer: closer,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) Hold() {
|
||||
// Hold the stream until it's closed
|
||||
buf := make([]byte, 1024)
|
||||
for {
|
||||
_, err := c.stream.Read(buf)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
_ = c.Close()
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||
msg := <-c.msgCh
|
||||
if msg == nil {
|
||||
err = net.ErrClosed
|
||||
return
|
||||
}
|
||||
err = common.Error(buffer.Write(msg.Data))
|
||||
destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||
msg := <-c.msgCh
|
||||
if msg == nil {
|
||||
err = net.ErrClosed
|
||||
return
|
||||
}
|
||||
buffer = buf.As(msg.Data)
|
||||
destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
return WriteUDPMessage(c.session, UDPMessage{
|
||||
SessionID: c.sessionId,
|
||||
Host: destination.AddrString(),
|
||||
Port: destination.Port,
|
||||
FragCount: 1,
|
||||
Data: buffer.Bytes(),
|
||||
})
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) RemoteAddr() net.Addr {
|
||||
return c.destination.UDPAddr()
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) SetDeadline(t time.Time) error {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) SetReadDeadline(t time.Time) error {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) SetWriteDeadline(t time.Time) error {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||
panic("invalid")
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||
panic("invalid")
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) Read(b []byte) (n int, err error) {
|
||||
panic("invalid")
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) Write(b []byte) (n int, err error) {
|
||||
panic("invalid")
|
||||
}
|
||||
|
||||
func (c *ClientPacketConn) Close() error {
|
||||
return common.Close(c.stream, c.closer)
|
||||
}
|
||||
|
65
transport/hysteria/frag.go
Normal file
65
transport/hysteria/frag.go
Normal file
@ -0,0 +1,65 @@
|
||||
package hysteria
|
||||
|
||||
func FragUDPMessage(m UDPMessage, maxSize int) []UDPMessage {
|
||||
if m.Size() <= maxSize {
|
||||
return []UDPMessage{m}
|
||||
}
|
||||
fullPayload := m.Data
|
||||
maxPayloadSize := maxSize - m.HeaderSize()
|
||||
off := 0
|
||||
fragID := uint8(0)
|
||||
fragCount := uint8((len(fullPayload) + maxPayloadSize - 1) / maxPayloadSize) // round up
|
||||
var frags []UDPMessage
|
||||
for off < len(fullPayload) {
|
||||
payloadSize := len(fullPayload) - off
|
||||
if payloadSize > maxPayloadSize {
|
||||
payloadSize = maxPayloadSize
|
||||
}
|
||||
frag := m
|
||||
frag.FragID = fragID
|
||||
frag.FragCount = fragCount
|
||||
frag.Data = fullPayload[off : off+payloadSize]
|
||||
frags = append(frags, frag)
|
||||
off += payloadSize
|
||||
fragID++
|
||||
}
|
||||
return frags
|
||||
}
|
||||
|
||||
type Defragger struct {
|
||||
msgID uint16
|
||||
frags []*UDPMessage
|
||||
count uint8
|
||||
}
|
||||
|
||||
func (d *Defragger) Feed(m UDPMessage) *UDPMessage {
|
||||
if m.FragCount <= 1 {
|
||||
return &m
|
||||
}
|
||||
if m.FragID >= m.FragCount {
|
||||
// wtf is this?
|
||||
return nil
|
||||
}
|
||||
if m.MsgID != d.msgID {
|
||||
// new message, clear previous state
|
||||
d.msgID = m.MsgID
|
||||
d.frags = make([]*UDPMessage, m.FragCount)
|
||||
d.count = 1
|
||||
d.frags[m.FragID] = &m
|
||||
} else if d.frags[m.FragID] == nil {
|
||||
d.frags[m.FragID] = &m
|
||||
d.count++
|
||||
if int(d.count) == len(d.frags) {
|
||||
// all fragments received, assemble
|
||||
var data []byte
|
||||
for _, frag := range d.frags {
|
||||
data = append(data, frag.Data...)
|
||||
}
|
||||
m.Data = data
|
||||
m.FragID = 0
|
||||
m.FragCount = 1
|
||||
return &m
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
@ -1,11 +1,16 @@
|
||||
package hysteria
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"math/rand"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
const Version = 3
|
||||
@ -23,18 +28,6 @@ type ServerHello struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
type ClientRequest struct {
|
||||
UDP bool
|
||||
Host string
|
||||
Port uint16
|
||||
}
|
||||
|
||||
type ServerResponse struct {
|
||||
OK bool
|
||||
UDPSessionID uint32
|
||||
Message string
|
||||
}
|
||||
|
||||
func WriteClientHello(stream io.Writer, hello ClientHello) error {
|
||||
var requestLen int
|
||||
requestLen += 1 // version
|
||||
@ -87,6 +80,18 @@ func ReadServerHello(stream io.Reader) (*ServerHello, error) {
|
||||
return &serverHello, nil
|
||||
}
|
||||
|
||||
type ClientRequest struct {
|
||||
UDP bool
|
||||
Host string
|
||||
Port uint16
|
||||
}
|
||||
|
||||
type ServerResponse struct {
|
||||
OK bool
|
||||
UDPSessionID uint32
|
||||
Message string
|
||||
}
|
||||
|
||||
func WriteClientRequest(stream io.Writer, request ClientRequest, payload []byte) error {
|
||||
var requestLen int
|
||||
requestLen += 1 // udp
|
||||
@ -140,3 +145,115 @@ func ReadServerResponse(stream io.Reader) (*ServerResponse, error) {
|
||||
serverResponse.Message = string(message)
|
||||
return &serverResponse, nil
|
||||
}
|
||||
|
||||
type UDPMessage struct {
|
||||
SessionID uint32
|
||||
Host string
|
||||
Port uint16
|
||||
MsgID uint16 // doesn't matter when not fragmented, but must not be 0 when fragmented
|
||||
FragID uint8 // doesn't matter when not fragmented, starts at 0 when fragmented
|
||||
FragCount uint8 // must be 1 when not fragmented
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func (m UDPMessage) HeaderSize() int {
|
||||
return 4 + 2 + len(m.Host) + 2 + 2 + 1 + 1 + 2
|
||||
}
|
||||
|
||||
func (m UDPMessage) Size() int {
|
||||
return m.HeaderSize() + len(m.Data)
|
||||
}
|
||||
|
||||
func ParseUDPMessage(packet []byte) (message UDPMessage, err error) {
|
||||
reader := bytes.NewReader(packet)
|
||||
err = binary.Read(reader, binary.BigEndian, &message.SessionID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var hostLen uint16
|
||||
err = binary.Read(reader, binary.BigEndian, &hostLen)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = reader.Seek(int64(hostLen), io.SeekCurrent)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
message.Host = string(packet[6 : 6+hostLen])
|
||||
err = binary.Read(reader, binary.BigEndian, &message.Port)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = binary.Read(reader, binary.BigEndian, &message.MsgID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = binary.Read(reader, binary.BigEndian, &message.FragID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = binary.Read(reader, binary.BigEndian, &message.FragCount)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var dataLen uint16
|
||||
err = binary.Read(reader, binary.BigEndian, &dataLen)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if reader.Len() != int(dataLen) {
|
||||
err = E.New("invalid data length")
|
||||
}
|
||||
dataOffset := int(reader.Size()) - reader.Len()
|
||||
message.Data = packet[dataOffset:]
|
||||
return
|
||||
}
|
||||
|
||||
func WriteUDPMessage(conn quic.Connection, message UDPMessage) error {
|
||||
var messageLen int
|
||||
messageLen += 4 // session id
|
||||
messageLen += 2 // host len
|
||||
messageLen += len(message.Host)
|
||||
messageLen += 2 // port
|
||||
messageLen += 2 // msg id
|
||||
messageLen += 1 // frag id
|
||||
messageLen += 1 // frag count
|
||||
messageLen += 2 // data len
|
||||
messageLen += len(message.Data)
|
||||
_buffer := buf.StackNewSize(messageLen)
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
err := writeUDPMessage(conn, message, buffer)
|
||||
// TODO: wait for change upstream
|
||||
if /*errSize, ok := err.(quic.ErrMessageToLarge); ok*/ false {
|
||||
const errSize = 0
|
||||
// need to frag
|
||||
message.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1
|
||||
fragMsgs := FragUDPMessage(message, int(errSize))
|
||||
for _, fragMsg := range fragMsgs {
|
||||
buffer.FullReset()
|
||||
err = writeUDPMessage(conn, fragMsg, buffer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func writeUDPMessage(conn quic.Connection, message UDPMessage, buffer *buf.Buffer) error {
|
||||
common.Must(
|
||||
binary.Write(buffer, binary.BigEndian, message.SessionID),
|
||||
binary.Write(buffer, binary.BigEndian, uint16(len(message.Host))),
|
||||
common.Error(buffer.WriteString(message.Host)),
|
||||
binary.Write(buffer, binary.BigEndian, message.Port),
|
||||
binary.Write(buffer, binary.BigEndian, message.MsgID),
|
||||
binary.Write(buffer, binary.BigEndian, message.FragID),
|
||||
binary.Write(buffer, binary.BigEndian, message.FragCount),
|
||||
binary.Write(buffer, binary.BigEndian, uint16(len(message.Data))),
|
||||
common.Error(buffer.Write(message.Data)),
|
||||
)
|
||||
return conn.SendMessage(buffer.Bytes())
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user