Add hysteria udp client

This commit is contained in:
世界 2022-08-18 16:29:10 +08:00
parent 17f10e0d3a
commit f5bb4cf53f
No known key found for this signature in database
GPG Key ID: CD109927C34A63C4
6 changed files with 422 additions and 34 deletions

View File

@ -51,6 +51,9 @@ type Hysteria struct {
recvBPS uint64 recvBPS uint64
connAccess sync.Mutex connAccess sync.Mutex
conn quic.Connection conn quic.Connection
udpAccess sync.RWMutex
udpSessions map[uint32]chan *hysteria.UDPMessage
udpDefragger hysteria.Defragger
} }
func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaOutboundOptions) (*Hysteria, error) { func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaOutboundOptions) (*Hysteria, error) {
@ -162,6 +165,8 @@ func (h *Hysteria) offer(ctx context.Context) (quic.Connection, error) {
} }
h.connAccess.Lock() h.connAccess.Lock()
defer h.connAccess.Unlock() defer h.connAccess.Unlock()
h.udpAccess.Lock()
defer h.udpAccess.Unlock()
conn = h.conn conn = h.conn
if conn != nil && !common.Done(conn.Context()) { if conn != nil && !common.Done(conn.Context()) {
return conn, nil return conn, nil
@ -171,6 +176,14 @@ func (h *Hysteria) offer(ctx context.Context) (quic.Connection, error) {
return nil, err return nil, err
} }
h.conn = conn h.conn = conn
if common.Contains(h.network, N.NetworkUDP) {
for _, session := range h.udpSessions {
close(session)
}
h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage)
h.udpDefragger = hysteria.Defragger{}
go h.recvLoop(conn)
}
return conn, nil return conn, nil
} }
@ -214,16 +227,53 @@ func (h *Hysteria) offerNew(ctx context.Context) (quic.Connection, error) {
return quicConn, nil return quicConn, nil
} }
func (h *Hysteria) recvLoop(conn quic.Connection) {
for {
packet, err := conn.ReceiveMessage()
if err != nil {
return
}
message, err := hysteria.ParseUDPMessage(packet)
if err != nil {
h.logger.Error("parse udp message: ", err)
continue
}
dfMsg := h.udpDefragger.Feed(message)
if dfMsg == nil {
continue
}
h.udpAccess.RLock()
ch, ok := h.udpSessions[dfMsg.SessionID]
if ok {
select {
case ch <- dfMsg:
// OK
default:
// Silently drop the message when the channel is full
}
}
h.udpAccess.RUnlock()
}
}
func (h *Hysteria) Close() error { func (h *Hysteria) Close() error {
h.connAccess.Lock() h.connAccess.Lock()
defer h.connAccess.Unlock() defer h.connAccess.Unlock()
h.udpAccess.Lock()
defer h.udpAccess.Unlock()
if h.conn != nil { if h.conn != nil {
h.conn.CloseWithError(0, "") h.conn.CloseWithError(0, "")
} }
for _, session := range h.udpSessions {
close(session)
}
h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage)
return nil return nil
} }
func (h *Hysteria) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { func (h *Hysteria) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
switch N.NetworkName(network) {
case N.NetworkTCP:
conn, err := h.offer(ctx) conn, err := h.offer(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
@ -232,16 +282,64 @@ func (h *Hysteria) DialContext(ctx context.Context, network string, destination
if err != nil { if err != nil {
return nil, err return nil, err
} }
switch N.NetworkName(network) {
case N.NetworkTCP:
return hysteria.NewClientConn(stream, destination), nil return hysteria.NewClientConn(stream, destination), nil
case N.NetworkUDP:
conn, err := h.ListenPacket(ctx, destination)
if err != nil {
return nil, err
}
return conn.(*hysteria.ClientPacketConn), nil
default: default:
return nil, E.New("unsupported network: ", network) return nil, E.New("unsupported network: ", network)
} }
} }
func (h *Hysteria) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (h *Hysteria) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
return nil, os.ErrInvalid conn, err := h.offer(ctx)
if err != nil {
return nil, err
}
stream, err := conn.OpenStream()
if err != nil {
return nil, err
}
err = hysteria.WriteClientRequest(stream, hysteria.ClientRequest{
UDP: true,
Host: destination.AddrString(),
Port: destination.Port,
}, nil)
if err != nil {
stream.Close()
return nil, err
}
var response *hysteria.ServerResponse
response, err = hysteria.ReadServerResponse(stream)
if err != nil {
stream.Close()
return nil, err
}
if !response.OK {
stream.Close()
return nil, E.New("remote error: ", response.Message)
}
h.udpAccess.Lock()
nCh := make(chan *hysteria.UDPMessage, 1024)
// Store the current session map for CloseFunc below
// to ensures that we are adding and removing sessions on the same map,
// as reconnecting will reassign the map
h.udpSessions[response.UDPSessionID] = nCh
h.udpAccess.Unlock()
packetConn := hysteria.NewClientPacketConn(conn, stream, response.UDPSessionID, destination, nCh, common.Closer(func() error {
h.udpAccess.Lock()
if ch, ok := h.udpSessions[response.UDPSessionID]; ok {
close(ch)
delete(h.udpSessions, response.UDPSessionID)
}
h.udpAccess.Unlock()
return nil
}))
go packetConn.Hold()
return packetConn, nil
} }
func (h *Hysteria) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { func (h *Hysteria) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {

View File

@ -381,7 +381,6 @@ func testLargeDataWithPacketConn(t *testing.T, port uint16, pcc func() (net.Pack
mux.Lock() mux.Lock()
hashMap[i] = hash[:] hashMap[i] = hash[:]
mux.Unlock() mux.Unlock()
println("write ti ", addr.String())
if _, err = pc.WriteTo(buf, addr); err != nil { if _, err = pc.WriteTo(buf, addr); err != nil {
t.Log(err) t.Log(err)
continue continue

View File

@ -56,5 +56,5 @@ func TestHysteriaOutbound(t *testing.T) {
}, },
}, },
}) })
testTCP(t, clientPort, testPort) testSuit(t, clientPort, testPort)
} }

View File

@ -1,8 +1,13 @@
package hysteria package hysteria
import ( import (
"io"
"net" "net"
"os"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
@ -67,3 +72,107 @@ func (c *ClientConn) RemoteAddr() net.Addr {
func (c *ClientConn) Upstream() any { func (c *ClientConn) Upstream() any {
return c.Stream return c.Stream
} }
type ClientPacketConn struct {
session quic.Connection
stream quic.Stream
sessionId uint32
destination M.Socksaddr
msgCh <-chan *UDPMessage
closer io.Closer
}
func NewClientPacketConn(session quic.Connection, stream quic.Stream, sessionId uint32, destination M.Socksaddr, msgCh <-chan *UDPMessage, closer io.Closer) *ClientPacketConn {
return &ClientPacketConn{
session: session,
stream: stream,
sessionId: sessionId,
destination: destination,
msgCh: msgCh,
closer: closer,
}
}
func (c *ClientPacketConn) Hold() {
// Hold the stream until it's closed
buf := make([]byte, 1024)
for {
_, err := c.stream.Read(buf)
if err != nil {
break
}
}
_ = c.Close()
}
func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
msg := <-c.msgCh
if msg == nil {
err = net.ErrClosed
return
}
err = common.Error(buffer.Write(msg.Data))
destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port)
return
}
func (c *ClientPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
msg := <-c.msgCh
if msg == nil {
err = net.ErrClosed
return
}
buffer = buf.As(msg.Data)
destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port)
return
}
func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
return WriteUDPMessage(c.session, UDPMessage{
SessionID: c.sessionId,
Host: destination.AddrString(),
Port: destination.Port,
FragCount: 1,
Data: buffer.Bytes(),
})
}
func (c *ClientPacketConn) LocalAddr() net.Addr {
return nil
}
func (c *ClientPacketConn) RemoteAddr() net.Addr {
return c.destination.UDPAddr()
}
func (c *ClientPacketConn) SetDeadline(t time.Time) error {
return os.ErrInvalid
}
func (c *ClientPacketConn) SetReadDeadline(t time.Time) error {
return os.ErrInvalid
}
func (c *ClientPacketConn) SetWriteDeadline(t time.Time) error {
return os.ErrInvalid
}
func (c *ClientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
panic("invalid")
}
func (c *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
panic("invalid")
}
func (c *ClientPacketConn) Read(b []byte) (n int, err error) {
panic("invalid")
}
func (c *ClientPacketConn) Write(b []byte) (n int, err error) {
panic("invalid")
}
func (c *ClientPacketConn) Close() error {
return common.Close(c.stream, c.closer)
}

View File

@ -0,0 +1,65 @@
package hysteria
func FragUDPMessage(m UDPMessage, maxSize int) []UDPMessage {
if m.Size() <= maxSize {
return []UDPMessage{m}
}
fullPayload := m.Data
maxPayloadSize := maxSize - m.HeaderSize()
off := 0
fragID := uint8(0)
fragCount := uint8((len(fullPayload) + maxPayloadSize - 1) / maxPayloadSize) // round up
var frags []UDPMessage
for off < len(fullPayload) {
payloadSize := len(fullPayload) - off
if payloadSize > maxPayloadSize {
payloadSize = maxPayloadSize
}
frag := m
frag.FragID = fragID
frag.FragCount = fragCount
frag.Data = fullPayload[off : off+payloadSize]
frags = append(frags, frag)
off += payloadSize
fragID++
}
return frags
}
type Defragger struct {
msgID uint16
frags []*UDPMessage
count uint8
}
func (d *Defragger) Feed(m UDPMessage) *UDPMessage {
if m.FragCount <= 1 {
return &m
}
if m.FragID >= m.FragCount {
// wtf is this?
return nil
}
if m.MsgID != d.msgID {
// new message, clear previous state
d.msgID = m.MsgID
d.frags = make([]*UDPMessage, m.FragCount)
d.count = 1
d.frags[m.FragID] = &m
} else if d.frags[m.FragID] == nil {
d.frags[m.FragID] = &m
d.count++
if int(d.count) == len(d.frags) {
// all fragments received, assemble
var data []byte
for _, frag := range d.frags {
data = append(data, frag.Data...)
}
m.Data = data
m.FragID = 0
m.FragCount = 1
return &m
}
}
return nil
}

View File

@ -1,11 +1,16 @@
package hysteria package hysteria
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"io" "io"
"math/rand"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/lucas-clemente/quic-go"
) )
const Version = 3 const Version = 3
@ -23,18 +28,6 @@ type ServerHello struct {
Message string Message string
} }
type ClientRequest struct {
UDP bool
Host string
Port uint16
}
type ServerResponse struct {
OK bool
UDPSessionID uint32
Message string
}
func WriteClientHello(stream io.Writer, hello ClientHello) error { func WriteClientHello(stream io.Writer, hello ClientHello) error {
var requestLen int var requestLen int
requestLen += 1 // version requestLen += 1 // version
@ -87,6 +80,18 @@ func ReadServerHello(stream io.Reader) (*ServerHello, error) {
return &serverHello, nil return &serverHello, nil
} }
type ClientRequest struct {
UDP bool
Host string
Port uint16
}
type ServerResponse struct {
OK bool
UDPSessionID uint32
Message string
}
func WriteClientRequest(stream io.Writer, request ClientRequest, payload []byte) error { func WriteClientRequest(stream io.Writer, request ClientRequest, payload []byte) error {
var requestLen int var requestLen int
requestLen += 1 // udp requestLen += 1 // udp
@ -140,3 +145,115 @@ func ReadServerResponse(stream io.Reader) (*ServerResponse, error) {
serverResponse.Message = string(message) serverResponse.Message = string(message)
return &serverResponse, nil return &serverResponse, nil
} }
type UDPMessage struct {
SessionID uint32
Host string
Port uint16
MsgID uint16 // doesn't matter when not fragmented, but must not be 0 when fragmented
FragID uint8 // doesn't matter when not fragmented, starts at 0 when fragmented
FragCount uint8 // must be 1 when not fragmented
Data []byte
}
func (m UDPMessage) HeaderSize() int {
return 4 + 2 + len(m.Host) + 2 + 2 + 1 + 1 + 2
}
func (m UDPMessage) Size() int {
return m.HeaderSize() + len(m.Data)
}
func ParseUDPMessage(packet []byte) (message UDPMessage, err error) {
reader := bytes.NewReader(packet)
err = binary.Read(reader, binary.BigEndian, &message.SessionID)
if err != nil {
return
}
var hostLen uint16
err = binary.Read(reader, binary.BigEndian, &hostLen)
if err != nil {
return
}
_, err = reader.Seek(int64(hostLen), io.SeekCurrent)
if err != nil {
return
}
message.Host = string(packet[6 : 6+hostLen])
err = binary.Read(reader, binary.BigEndian, &message.Port)
if err != nil {
return
}
err = binary.Read(reader, binary.BigEndian, &message.MsgID)
if err != nil {
return
}
err = binary.Read(reader, binary.BigEndian, &message.FragID)
if err != nil {
return
}
err = binary.Read(reader, binary.BigEndian, &message.FragCount)
if err != nil {
return
}
var dataLen uint16
err = binary.Read(reader, binary.BigEndian, &dataLen)
if err != nil {
return
}
if reader.Len() != int(dataLen) {
err = E.New("invalid data length")
}
dataOffset := int(reader.Size()) - reader.Len()
message.Data = packet[dataOffset:]
return
}
func WriteUDPMessage(conn quic.Connection, message UDPMessage) error {
var messageLen int
messageLen += 4 // session id
messageLen += 2 // host len
messageLen += len(message.Host)
messageLen += 2 // port
messageLen += 2 // msg id
messageLen += 1 // frag id
messageLen += 1 // frag count
messageLen += 2 // data len
messageLen += len(message.Data)
_buffer := buf.StackNewSize(messageLen)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
err := writeUDPMessage(conn, message, buffer)
// TODO: wait for change upstream
if /*errSize, ok := err.(quic.ErrMessageToLarge); ok*/ false {
const errSize = 0
// need to frag
message.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1
fragMsgs := FragUDPMessage(message, int(errSize))
for _, fragMsg := range fragMsgs {
buffer.FullReset()
err = writeUDPMessage(conn, fragMsg, buffer)
if err != nil {
return err
}
}
return nil
}
return err
}
func writeUDPMessage(conn quic.Connection, message UDPMessage, buffer *buf.Buffer) error {
common.Must(
binary.Write(buffer, binary.BigEndian, message.SessionID),
binary.Write(buffer, binary.BigEndian, uint16(len(message.Host))),
common.Error(buffer.WriteString(message.Host)),
binary.Write(buffer, binary.BigEndian, message.Port),
binary.Write(buffer, binary.BigEndian, message.MsgID),
binary.Write(buffer, binary.BigEndian, message.FragID),
binary.Write(buffer, binary.BigEndian, message.FragCount),
binary.Write(buffer, binary.BigEndian, uint16(len(message.Data))),
common.Error(buffer.Write(message.Data)),
)
return conn.SendMessage(buffer.Bytes())
}