mirror of
https://github.com/SagerNet/sing-box.git
synced 2025-06-08 11:44:13 +08:00
Add hysteria udp client
This commit is contained in:
parent
17f10e0d3a
commit
f5bb4cf53f
@ -40,17 +40,20 @@ var _ adapter.Outbound = (*Hysteria)(nil)
|
|||||||
|
|
||||||
type Hysteria struct {
|
type Hysteria struct {
|
||||||
myOutboundAdapter
|
myOutboundAdapter
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
dialer N.Dialer
|
dialer N.Dialer
|
||||||
serverAddr M.Socksaddr
|
serverAddr M.Socksaddr
|
||||||
tlsConfig *tls.Config
|
tlsConfig *tls.Config
|
||||||
quicConfig *quic.Config
|
quicConfig *quic.Config
|
||||||
authKey []byte
|
authKey []byte
|
||||||
xplusKey []byte
|
xplusKey []byte
|
||||||
sendBPS uint64
|
sendBPS uint64
|
||||||
recvBPS uint64
|
recvBPS uint64
|
||||||
connAccess sync.Mutex
|
connAccess sync.Mutex
|
||||||
conn quic.Connection
|
conn quic.Connection
|
||||||
|
udpAccess sync.RWMutex
|
||||||
|
udpSessions map[uint32]chan *hysteria.UDPMessage
|
||||||
|
udpDefragger hysteria.Defragger
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaOutboundOptions) (*Hysteria, error) {
|
func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaOutboundOptions) (*Hysteria, error) {
|
||||||
@ -162,6 +165,8 @@ func (h *Hysteria) offer(ctx context.Context) (quic.Connection, error) {
|
|||||||
}
|
}
|
||||||
h.connAccess.Lock()
|
h.connAccess.Lock()
|
||||||
defer h.connAccess.Unlock()
|
defer h.connAccess.Unlock()
|
||||||
|
h.udpAccess.Lock()
|
||||||
|
defer h.udpAccess.Unlock()
|
||||||
conn = h.conn
|
conn = h.conn
|
||||||
if conn != nil && !common.Done(conn.Context()) {
|
if conn != nil && !common.Done(conn.Context()) {
|
||||||
return conn, nil
|
return conn, nil
|
||||||
@ -171,6 +176,14 @@ func (h *Hysteria) offer(ctx context.Context) (quic.Connection, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
h.conn = conn
|
h.conn = conn
|
||||||
|
if common.Contains(h.network, N.NetworkUDP) {
|
||||||
|
for _, session := range h.udpSessions {
|
||||||
|
close(session)
|
||||||
|
}
|
||||||
|
h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage)
|
||||||
|
h.udpDefragger = hysteria.Defragger{}
|
||||||
|
go h.recvLoop(conn)
|
||||||
|
}
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -214,16 +227,74 @@ func (h *Hysteria) offerNew(ctx context.Context) (quic.Connection, error) {
|
|||||||
return quicConn, nil
|
return quicConn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *Hysteria) recvLoop(conn quic.Connection) {
|
||||||
|
for {
|
||||||
|
packet, err := conn.ReceiveMessage()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
message, err := hysteria.ParseUDPMessage(packet)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("parse udp message: ", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
dfMsg := h.udpDefragger.Feed(message)
|
||||||
|
if dfMsg == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
h.udpAccess.RLock()
|
||||||
|
ch, ok := h.udpSessions[dfMsg.SessionID]
|
||||||
|
if ok {
|
||||||
|
select {
|
||||||
|
case ch <- dfMsg:
|
||||||
|
// OK
|
||||||
|
default:
|
||||||
|
// Silently drop the message when the channel is full
|
||||||
|
}
|
||||||
|
}
|
||||||
|
h.udpAccess.RUnlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (h *Hysteria) Close() error {
|
func (h *Hysteria) Close() error {
|
||||||
h.connAccess.Lock()
|
h.connAccess.Lock()
|
||||||
defer h.connAccess.Unlock()
|
defer h.connAccess.Unlock()
|
||||||
|
h.udpAccess.Lock()
|
||||||
|
defer h.udpAccess.Unlock()
|
||||||
if h.conn != nil {
|
if h.conn != nil {
|
||||||
h.conn.CloseWithError(0, "")
|
h.conn.CloseWithError(0, "")
|
||||||
}
|
}
|
||||||
|
for _, session := range h.udpSessions {
|
||||||
|
close(session)
|
||||||
|
}
|
||||||
|
h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Hysteria) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
func (h *Hysteria) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||||
|
switch N.NetworkName(network) {
|
||||||
|
case N.NetworkTCP:
|
||||||
|
conn, err := h.offer(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
stream, err := conn.OpenStream()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return hysteria.NewClientConn(stream, destination), nil
|
||||||
|
case N.NetworkUDP:
|
||||||
|
conn, err := h.ListenPacket(ctx, destination)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return conn.(*hysteria.ClientPacketConn), nil
|
||||||
|
default:
|
||||||
|
return nil, E.New("unsupported network: ", network)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Hysteria) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||||
conn, err := h.offer(ctx)
|
conn, err := h.offer(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -232,16 +303,43 @@ func (h *Hysteria) DialContext(ctx context.Context, network string, destination
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
switch N.NetworkName(network) {
|
err = hysteria.WriteClientRequest(stream, hysteria.ClientRequest{
|
||||||
case N.NetworkTCP:
|
UDP: true,
|
||||||
return hysteria.NewClientConn(stream, destination), nil
|
Host: destination.AddrString(),
|
||||||
default:
|
Port: destination.Port,
|
||||||
return nil, E.New("unsupported network: ", network)
|
}, nil)
|
||||||
|
if err != nil {
|
||||||
|
stream.Close()
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
var response *hysteria.ServerResponse
|
||||||
|
response, err = hysteria.ReadServerResponse(stream)
|
||||||
func (h *Hysteria) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
if err != nil {
|
||||||
return nil, os.ErrInvalid
|
stream.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !response.OK {
|
||||||
|
stream.Close()
|
||||||
|
return nil, E.New("remote error: ", response.Message)
|
||||||
|
}
|
||||||
|
h.udpAccess.Lock()
|
||||||
|
nCh := make(chan *hysteria.UDPMessage, 1024)
|
||||||
|
// Store the current session map for CloseFunc below
|
||||||
|
// to ensures that we are adding and removing sessions on the same map,
|
||||||
|
// as reconnecting will reassign the map
|
||||||
|
h.udpSessions[response.UDPSessionID] = nCh
|
||||||
|
h.udpAccess.Unlock()
|
||||||
|
packetConn := hysteria.NewClientPacketConn(conn, stream, response.UDPSessionID, destination, nCh, common.Closer(func() error {
|
||||||
|
h.udpAccess.Lock()
|
||||||
|
if ch, ok := h.udpSessions[response.UDPSessionID]; ok {
|
||||||
|
close(ch)
|
||||||
|
delete(h.udpSessions, response.UDPSessionID)
|
||||||
|
}
|
||||||
|
h.udpAccess.Unlock()
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
go packetConn.Hold()
|
||||||
|
return packetConn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Hysteria) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
|
func (h *Hysteria) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
|
||||||
|
@ -381,7 +381,6 @@ func testLargeDataWithPacketConn(t *testing.T, port uint16, pcc func() (net.Pack
|
|||||||
mux.Lock()
|
mux.Lock()
|
||||||
hashMap[i] = hash[:]
|
hashMap[i] = hash[:]
|
||||||
mux.Unlock()
|
mux.Unlock()
|
||||||
println("write ti ", addr.String())
|
|
||||||
if _, err = pc.WriteTo(buf, addr); err != nil {
|
if _, err = pc.WriteTo(buf, addr); err != nil {
|
||||||
t.Log(err)
|
t.Log(err)
|
||||||
continue
|
continue
|
||||||
|
@ -56,5 +56,5 @@ func TestHysteriaOutbound(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
testTCP(t, clientPort, testPort)
|
testSuit(t, clientPort, testPort)
|
||||||
}
|
}
|
||||||
|
@ -1,8 +1,13 @@
|
|||||||
package hysteria
|
package hysteria
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
"github.com/sagernet/sing/common/buf"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
|
||||||
@ -67,3 +72,107 @@ func (c *ClientConn) RemoteAddr() net.Addr {
|
|||||||
func (c *ClientConn) Upstream() any {
|
func (c *ClientConn) Upstream() any {
|
||||||
return c.Stream
|
return c.Stream
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ClientPacketConn struct {
|
||||||
|
session quic.Connection
|
||||||
|
stream quic.Stream
|
||||||
|
sessionId uint32
|
||||||
|
destination M.Socksaddr
|
||||||
|
msgCh <-chan *UDPMessage
|
||||||
|
closer io.Closer
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClientPacketConn(session quic.Connection, stream quic.Stream, sessionId uint32, destination M.Socksaddr, msgCh <-chan *UDPMessage, closer io.Closer) *ClientPacketConn {
|
||||||
|
return &ClientPacketConn{
|
||||||
|
session: session,
|
||||||
|
stream: stream,
|
||||||
|
sessionId: sessionId,
|
||||||
|
destination: destination,
|
||||||
|
msgCh: msgCh,
|
||||||
|
closer: closer,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientPacketConn) Hold() {
|
||||||
|
// Hold the stream until it's closed
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
for {
|
||||||
|
_, err := c.stream.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = c.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||||
|
msg := <-c.msgCh
|
||||||
|
if msg == nil {
|
||||||
|
err = net.ErrClosed
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = common.Error(buffer.Write(msg.Data))
|
||||||
|
destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
|
||||||
|
msg := <-c.msgCh
|
||||||
|
if msg == nil {
|
||||||
|
err = net.ErrClosed
|
||||||
|
return
|
||||||
|
}
|
||||||
|
buffer = buf.As(msg.Data)
|
||||||
|
destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||||
|
return WriteUDPMessage(c.session, UDPMessage{
|
||||||
|
SessionID: c.sessionId,
|
||||||
|
Host: destination.AddrString(),
|
||||||
|
Port: destination.Port,
|
||||||
|
FragCount: 1,
|
||||||
|
Data: buffer.Bytes(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientPacketConn) LocalAddr() net.Addr {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientPacketConn) RemoteAddr() net.Addr {
|
||||||
|
return c.destination.UDPAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientPacketConn) SetDeadline(t time.Time) error {
|
||||||
|
return os.ErrInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientPacketConn) SetReadDeadline(t time.Time) error {
|
||||||
|
return os.ErrInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientPacketConn) SetWriteDeadline(t time.Time) error {
|
||||||
|
return os.ErrInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
||||||
|
panic("invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
|
||||||
|
panic("invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientPacketConn) Read(b []byte) (n int, err error) {
|
||||||
|
panic("invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientPacketConn) Write(b []byte) (n int, err error) {
|
||||||
|
panic("invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientPacketConn) Close() error {
|
||||||
|
return common.Close(c.stream, c.closer)
|
||||||
|
}
|
||||||
|
65
transport/hysteria/frag.go
Normal file
65
transport/hysteria/frag.go
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
package hysteria
|
||||||
|
|
||||||
|
func FragUDPMessage(m UDPMessage, maxSize int) []UDPMessage {
|
||||||
|
if m.Size() <= maxSize {
|
||||||
|
return []UDPMessage{m}
|
||||||
|
}
|
||||||
|
fullPayload := m.Data
|
||||||
|
maxPayloadSize := maxSize - m.HeaderSize()
|
||||||
|
off := 0
|
||||||
|
fragID := uint8(0)
|
||||||
|
fragCount := uint8((len(fullPayload) + maxPayloadSize - 1) / maxPayloadSize) // round up
|
||||||
|
var frags []UDPMessage
|
||||||
|
for off < len(fullPayload) {
|
||||||
|
payloadSize := len(fullPayload) - off
|
||||||
|
if payloadSize > maxPayloadSize {
|
||||||
|
payloadSize = maxPayloadSize
|
||||||
|
}
|
||||||
|
frag := m
|
||||||
|
frag.FragID = fragID
|
||||||
|
frag.FragCount = fragCount
|
||||||
|
frag.Data = fullPayload[off : off+payloadSize]
|
||||||
|
frags = append(frags, frag)
|
||||||
|
off += payloadSize
|
||||||
|
fragID++
|
||||||
|
}
|
||||||
|
return frags
|
||||||
|
}
|
||||||
|
|
||||||
|
type Defragger struct {
|
||||||
|
msgID uint16
|
||||||
|
frags []*UDPMessage
|
||||||
|
count uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Defragger) Feed(m UDPMessage) *UDPMessage {
|
||||||
|
if m.FragCount <= 1 {
|
||||||
|
return &m
|
||||||
|
}
|
||||||
|
if m.FragID >= m.FragCount {
|
||||||
|
// wtf is this?
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if m.MsgID != d.msgID {
|
||||||
|
// new message, clear previous state
|
||||||
|
d.msgID = m.MsgID
|
||||||
|
d.frags = make([]*UDPMessage, m.FragCount)
|
||||||
|
d.count = 1
|
||||||
|
d.frags[m.FragID] = &m
|
||||||
|
} else if d.frags[m.FragID] == nil {
|
||||||
|
d.frags[m.FragID] = &m
|
||||||
|
d.count++
|
||||||
|
if int(d.count) == len(d.frags) {
|
||||||
|
// all fragments received, assemble
|
||||||
|
var data []byte
|
||||||
|
for _, frag := range d.frags {
|
||||||
|
data = append(data, frag.Data...)
|
||||||
|
}
|
||||||
|
m.Data = data
|
||||||
|
m.FragID = 0
|
||||||
|
m.FragCount = 1
|
||||||
|
return &m
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
@ -1,11 +1,16 @@
|
|||||||
package hysteria
|
package hysteria
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"io"
|
"io"
|
||||||
|
"math/rand"
|
||||||
|
|
||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
"github.com/sagernet/sing/common/buf"
|
"github.com/sagernet/sing/common/buf"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
|
||||||
|
"github.com/lucas-clemente/quic-go"
|
||||||
)
|
)
|
||||||
|
|
||||||
const Version = 3
|
const Version = 3
|
||||||
@ -23,18 +28,6 @@ type ServerHello struct {
|
|||||||
Message string
|
Message string
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClientRequest struct {
|
|
||||||
UDP bool
|
|
||||||
Host string
|
|
||||||
Port uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
type ServerResponse struct {
|
|
||||||
OK bool
|
|
||||||
UDPSessionID uint32
|
|
||||||
Message string
|
|
||||||
}
|
|
||||||
|
|
||||||
func WriteClientHello(stream io.Writer, hello ClientHello) error {
|
func WriteClientHello(stream io.Writer, hello ClientHello) error {
|
||||||
var requestLen int
|
var requestLen int
|
||||||
requestLen += 1 // version
|
requestLen += 1 // version
|
||||||
@ -87,6 +80,18 @@ func ReadServerHello(stream io.Reader) (*ServerHello, error) {
|
|||||||
return &serverHello, nil
|
return &serverHello, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ClientRequest struct {
|
||||||
|
UDP bool
|
||||||
|
Host string
|
||||||
|
Port uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerResponse struct {
|
||||||
|
OK bool
|
||||||
|
UDPSessionID uint32
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
func WriteClientRequest(stream io.Writer, request ClientRequest, payload []byte) error {
|
func WriteClientRequest(stream io.Writer, request ClientRequest, payload []byte) error {
|
||||||
var requestLen int
|
var requestLen int
|
||||||
requestLen += 1 // udp
|
requestLen += 1 // udp
|
||||||
@ -140,3 +145,115 @@ func ReadServerResponse(stream io.Reader) (*ServerResponse, error) {
|
|||||||
serverResponse.Message = string(message)
|
serverResponse.Message = string(message)
|
||||||
return &serverResponse, nil
|
return &serverResponse, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type UDPMessage struct {
|
||||||
|
SessionID uint32
|
||||||
|
Host string
|
||||||
|
Port uint16
|
||||||
|
MsgID uint16 // doesn't matter when not fragmented, but must not be 0 when fragmented
|
||||||
|
FragID uint8 // doesn't matter when not fragmented, starts at 0 when fragmented
|
||||||
|
FragCount uint8 // must be 1 when not fragmented
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m UDPMessage) HeaderSize() int {
|
||||||
|
return 4 + 2 + len(m.Host) + 2 + 2 + 1 + 1 + 2
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m UDPMessage) Size() int {
|
||||||
|
return m.HeaderSize() + len(m.Data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseUDPMessage(packet []byte) (message UDPMessage, err error) {
|
||||||
|
reader := bytes.NewReader(packet)
|
||||||
|
err = binary.Read(reader, binary.BigEndian, &message.SessionID)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var hostLen uint16
|
||||||
|
err = binary.Read(reader, binary.BigEndian, &hostLen)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, err = reader.Seek(int64(hostLen), io.SeekCurrent)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
message.Host = string(packet[6 : 6+hostLen])
|
||||||
|
err = binary.Read(reader, binary.BigEndian, &message.Port)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = binary.Read(reader, binary.BigEndian, &message.MsgID)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = binary.Read(reader, binary.BigEndian, &message.FragID)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = binary.Read(reader, binary.BigEndian, &message.FragCount)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var dataLen uint16
|
||||||
|
err = binary.Read(reader, binary.BigEndian, &dataLen)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if reader.Len() != int(dataLen) {
|
||||||
|
err = E.New("invalid data length")
|
||||||
|
}
|
||||||
|
dataOffset := int(reader.Size()) - reader.Len()
|
||||||
|
message.Data = packet[dataOffset:]
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func WriteUDPMessage(conn quic.Connection, message UDPMessage) error {
|
||||||
|
var messageLen int
|
||||||
|
messageLen += 4 // session id
|
||||||
|
messageLen += 2 // host len
|
||||||
|
messageLen += len(message.Host)
|
||||||
|
messageLen += 2 // port
|
||||||
|
messageLen += 2 // msg id
|
||||||
|
messageLen += 1 // frag id
|
||||||
|
messageLen += 1 // frag count
|
||||||
|
messageLen += 2 // data len
|
||||||
|
messageLen += len(message.Data)
|
||||||
|
_buffer := buf.StackNewSize(messageLen)
|
||||||
|
defer common.KeepAlive(_buffer)
|
||||||
|
buffer := common.Dup(_buffer)
|
||||||
|
defer buffer.Release()
|
||||||
|
err := writeUDPMessage(conn, message, buffer)
|
||||||
|
// TODO: wait for change upstream
|
||||||
|
if /*errSize, ok := err.(quic.ErrMessageToLarge); ok*/ false {
|
||||||
|
const errSize = 0
|
||||||
|
// need to frag
|
||||||
|
message.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1
|
||||||
|
fragMsgs := FragUDPMessage(message, int(errSize))
|
||||||
|
for _, fragMsg := range fragMsgs {
|
||||||
|
buffer.FullReset()
|
||||||
|
err = writeUDPMessage(conn, fragMsg, buffer)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeUDPMessage(conn quic.Connection, message UDPMessage, buffer *buf.Buffer) error {
|
||||||
|
common.Must(
|
||||||
|
binary.Write(buffer, binary.BigEndian, message.SessionID),
|
||||||
|
binary.Write(buffer, binary.BigEndian, uint16(len(message.Host))),
|
||||||
|
common.Error(buffer.WriteString(message.Host)),
|
||||||
|
binary.Write(buffer, binary.BigEndian, message.Port),
|
||||||
|
binary.Write(buffer, binary.BigEndian, message.MsgID),
|
||||||
|
binary.Write(buffer, binary.BigEndian, message.FragID),
|
||||||
|
binary.Write(buffer, binary.BigEndian, message.FragCount),
|
||||||
|
binary.Write(buffer, binary.BigEndian, uint16(len(message.Data))),
|
||||||
|
common.Error(buffer.Write(message.Data)),
|
||||||
|
)
|
||||||
|
return conn.SendMessage(buffer.Bytes())
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user