diff --git a/outbound/hysteria.go b/outbound/hysteria.go index 40a37327..c539fda3 100644 --- a/outbound/hysteria.go +++ b/outbound/hysteria.go @@ -40,17 +40,20 @@ var _ adapter.Outbound = (*Hysteria)(nil) type Hysteria struct { myOutboundAdapter - ctx context.Context - dialer N.Dialer - serverAddr M.Socksaddr - tlsConfig *tls.Config - quicConfig *quic.Config - authKey []byte - xplusKey []byte - sendBPS uint64 - recvBPS uint64 - connAccess sync.Mutex - conn quic.Connection + ctx context.Context + dialer N.Dialer + serverAddr M.Socksaddr + tlsConfig *tls.Config + quicConfig *quic.Config + authKey []byte + xplusKey []byte + sendBPS uint64 + recvBPS uint64 + connAccess sync.Mutex + conn quic.Connection + udpAccess sync.RWMutex + udpSessions map[uint32]chan *hysteria.UDPMessage + udpDefragger hysteria.Defragger } func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaOutboundOptions) (*Hysteria, error) { @@ -162,6 +165,8 @@ func (h *Hysteria) offer(ctx context.Context) (quic.Connection, error) { } h.connAccess.Lock() defer h.connAccess.Unlock() + h.udpAccess.Lock() + defer h.udpAccess.Unlock() conn = h.conn if conn != nil && !common.Done(conn.Context()) { return conn, nil @@ -171,6 +176,14 @@ func (h *Hysteria) offer(ctx context.Context) (quic.Connection, error) { return nil, err } h.conn = conn + if common.Contains(h.network, N.NetworkUDP) { + for _, session := range h.udpSessions { + close(session) + } + h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage) + h.udpDefragger = hysteria.Defragger{} + go h.recvLoop(conn) + } return conn, nil } @@ -214,16 +227,74 @@ func (h *Hysteria) offerNew(ctx context.Context) (quic.Connection, error) { return quicConn, nil } +func (h *Hysteria) recvLoop(conn quic.Connection) { + for { + packet, err := conn.ReceiveMessage() + if err != nil { + return + } + message, err := hysteria.ParseUDPMessage(packet) + if err != nil { + h.logger.Error("parse udp message: ", err) + continue + } + dfMsg := h.udpDefragger.Feed(message) + if dfMsg == nil { + continue + } + h.udpAccess.RLock() + ch, ok := h.udpSessions[dfMsg.SessionID] + if ok { + select { + case ch <- dfMsg: + // OK + default: + // Silently drop the message when the channel is full + } + } + h.udpAccess.RUnlock() + } +} + func (h *Hysteria) Close() error { h.connAccess.Lock() defer h.connAccess.Unlock() + h.udpAccess.Lock() + defer h.udpAccess.Unlock() if h.conn != nil { h.conn.CloseWithError(0, "") } + for _, session := range h.udpSessions { + close(session) + } + h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage) return nil } func (h *Hysteria) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + switch N.NetworkName(network) { + case N.NetworkTCP: + conn, err := h.offer(ctx) + if err != nil { + return nil, err + } + stream, err := conn.OpenStream() + if err != nil { + return nil, err + } + return hysteria.NewClientConn(stream, destination), nil + case N.NetworkUDP: + conn, err := h.ListenPacket(ctx, destination) + if err != nil { + return nil, err + } + return conn.(*hysteria.ClientPacketConn), nil + default: + return nil, E.New("unsupported network: ", network) + } +} + +func (h *Hysteria) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { conn, err := h.offer(ctx) if err != nil { return nil, err @@ -232,16 +303,43 @@ func (h *Hysteria) DialContext(ctx context.Context, network string, destination if err != nil { return nil, err } - switch N.NetworkName(network) { - case N.NetworkTCP: - return hysteria.NewClientConn(stream, destination), nil - default: - return nil, E.New("unsupported network: ", network) + err = hysteria.WriteClientRequest(stream, hysteria.ClientRequest{ + UDP: true, + Host: destination.AddrString(), + Port: destination.Port, + }, nil) + if err != nil { + stream.Close() + return nil, err } -} - -func (h *Hysteria) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - return nil, os.ErrInvalid + var response *hysteria.ServerResponse + response, err = hysteria.ReadServerResponse(stream) + if err != nil { + stream.Close() + return nil, err + } + if !response.OK { + stream.Close() + return nil, E.New("remote error: ", response.Message) + } + h.udpAccess.Lock() + nCh := make(chan *hysteria.UDPMessage, 1024) + // Store the current session map for CloseFunc below + // to ensures that we are adding and removing sessions on the same map, + // as reconnecting will reassign the map + h.udpSessions[response.UDPSessionID] = nCh + h.udpAccess.Unlock() + packetConn := hysteria.NewClientPacketConn(conn, stream, response.UDPSessionID, destination, nCh, common.Closer(func() error { + h.udpAccess.Lock() + if ch, ok := h.udpSessions[response.UDPSessionID]; ok { + close(ch) + delete(h.udpSessions, response.UDPSessionID) + } + h.udpAccess.Unlock() + return nil + })) + go packetConn.Hold() + return packetConn, nil } func (h *Hysteria) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { diff --git a/test/clash_test.go b/test/clash_test.go index b490434a..908480cc 100644 --- a/test/clash_test.go +++ b/test/clash_test.go @@ -381,7 +381,6 @@ func testLargeDataWithPacketConn(t *testing.T, port uint16, pcc func() (net.Pack mux.Lock() hashMap[i] = hash[:] mux.Unlock() - println("write ti ", addr.String()) if _, err = pc.WriteTo(buf, addr); err != nil { t.Log(err) continue diff --git a/test/hysteria_test.go b/test/hysteria_test.go index 055f76f5..e15b7b37 100644 --- a/test/hysteria_test.go +++ b/test/hysteria_test.go @@ -56,5 +56,5 @@ func TestHysteriaOutbound(t *testing.T) { }, }, }) - testTCP(t, clientPort, testPort) + testSuit(t, clientPort, testPort) } diff --git a/transport/hysteria/client.go b/transport/hysteria/client.go index 5ae2bfe4..36cc84b0 100644 --- a/transport/hysteria/client.go +++ b/transport/hysteria/client.go @@ -1,8 +1,13 @@ package hysteria import ( + "io" "net" + "os" + "time" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" @@ -67,3 +72,107 @@ func (c *ClientConn) RemoteAddr() net.Addr { func (c *ClientConn) Upstream() any { return c.Stream } + +type ClientPacketConn struct { + session quic.Connection + stream quic.Stream + sessionId uint32 + destination M.Socksaddr + msgCh <-chan *UDPMessage + closer io.Closer +} + +func NewClientPacketConn(session quic.Connection, stream quic.Stream, sessionId uint32, destination M.Socksaddr, msgCh <-chan *UDPMessage, closer io.Closer) *ClientPacketConn { + return &ClientPacketConn{ + session: session, + stream: stream, + sessionId: sessionId, + destination: destination, + msgCh: msgCh, + closer: closer, + } +} + +func (c *ClientPacketConn) Hold() { + // Hold the stream until it's closed + buf := make([]byte, 1024) + for { + _, err := c.stream.Read(buf) + if err != nil { + break + } + } + _ = c.Close() +} + +func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + msg := <-c.msgCh + if msg == nil { + err = net.ErrClosed + return + } + err = common.Error(buffer.Write(msg.Data)) + destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port) + return +} + +func (c *ClientPacketConn) ReadPacketThreadSafe() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + msg := <-c.msgCh + if msg == nil { + err = net.ErrClosed + return + } + buffer = buf.As(msg.Data) + destination = M.ParseSocksaddrHostPort(msg.Host, msg.Port) + return +} + +func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + return WriteUDPMessage(c.session, UDPMessage{ + SessionID: c.sessionId, + Host: destination.AddrString(), + Port: destination.Port, + FragCount: 1, + Data: buffer.Bytes(), + }) +} + +func (c *ClientPacketConn) LocalAddr() net.Addr { + return nil +} + +func (c *ClientPacketConn) RemoteAddr() net.Addr { + return c.destination.UDPAddr() +} + +func (c *ClientPacketConn) SetDeadline(t time.Time) error { + return os.ErrInvalid +} + +func (c *ClientPacketConn) SetReadDeadline(t time.Time) error { + return os.ErrInvalid +} + +func (c *ClientPacketConn) SetWriteDeadline(t time.Time) error { + return os.ErrInvalid +} + +func (c *ClientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + panic("invalid") +} + +func (c *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + panic("invalid") +} + +func (c *ClientPacketConn) Read(b []byte) (n int, err error) { + panic("invalid") +} + +func (c *ClientPacketConn) Write(b []byte) (n int, err error) { + panic("invalid") +} + +func (c *ClientPacketConn) Close() error { + return common.Close(c.stream, c.closer) +} diff --git a/transport/hysteria/frag.go b/transport/hysteria/frag.go new file mode 100644 index 00000000..721341f1 --- /dev/null +++ b/transport/hysteria/frag.go @@ -0,0 +1,65 @@ +package hysteria + +func FragUDPMessage(m UDPMessage, maxSize int) []UDPMessage { + if m.Size() <= maxSize { + return []UDPMessage{m} + } + fullPayload := m.Data + maxPayloadSize := maxSize - m.HeaderSize() + off := 0 + fragID := uint8(0) + fragCount := uint8((len(fullPayload) + maxPayloadSize - 1) / maxPayloadSize) // round up + var frags []UDPMessage + for off < len(fullPayload) { + payloadSize := len(fullPayload) - off + if payloadSize > maxPayloadSize { + payloadSize = maxPayloadSize + } + frag := m + frag.FragID = fragID + frag.FragCount = fragCount + frag.Data = fullPayload[off : off+payloadSize] + frags = append(frags, frag) + off += payloadSize + fragID++ + } + return frags +} + +type Defragger struct { + msgID uint16 + frags []*UDPMessage + count uint8 +} + +func (d *Defragger) Feed(m UDPMessage) *UDPMessage { + if m.FragCount <= 1 { + return &m + } + if m.FragID >= m.FragCount { + // wtf is this? + return nil + } + if m.MsgID != d.msgID { + // new message, clear previous state + d.msgID = m.MsgID + d.frags = make([]*UDPMessage, m.FragCount) + d.count = 1 + d.frags[m.FragID] = &m + } else if d.frags[m.FragID] == nil { + d.frags[m.FragID] = &m + d.count++ + if int(d.count) == len(d.frags) { + // all fragments received, assemble + var data []byte + for _, frag := range d.frags { + data = append(data, frag.Data...) + } + m.Data = data + m.FragID = 0 + m.FragCount = 1 + return &m + } + } + return nil +} diff --git a/transport/hysteria/protocol.go b/transport/hysteria/protocol.go index 5daedc37..091923e5 100644 --- a/transport/hysteria/protocol.go +++ b/transport/hysteria/protocol.go @@ -1,11 +1,16 @@ package hysteria import ( + "bytes" "encoding/binary" "io" + "math/rand" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + + "github.com/lucas-clemente/quic-go" ) const Version = 3 @@ -23,18 +28,6 @@ type ServerHello struct { Message string } -type ClientRequest struct { - UDP bool - Host string - Port uint16 -} - -type ServerResponse struct { - OK bool - UDPSessionID uint32 - Message string -} - func WriteClientHello(stream io.Writer, hello ClientHello) error { var requestLen int requestLen += 1 // version @@ -87,6 +80,18 @@ func ReadServerHello(stream io.Reader) (*ServerHello, error) { return &serverHello, nil } +type ClientRequest struct { + UDP bool + Host string + Port uint16 +} + +type ServerResponse struct { + OK bool + UDPSessionID uint32 + Message string +} + func WriteClientRequest(stream io.Writer, request ClientRequest, payload []byte) error { var requestLen int requestLen += 1 // udp @@ -140,3 +145,115 @@ func ReadServerResponse(stream io.Reader) (*ServerResponse, error) { serverResponse.Message = string(message) return &serverResponse, nil } + +type UDPMessage struct { + SessionID uint32 + Host string + Port uint16 + MsgID uint16 // doesn't matter when not fragmented, but must not be 0 when fragmented + FragID uint8 // doesn't matter when not fragmented, starts at 0 when fragmented + FragCount uint8 // must be 1 when not fragmented + Data []byte +} + +func (m UDPMessage) HeaderSize() int { + return 4 + 2 + len(m.Host) + 2 + 2 + 1 + 1 + 2 +} + +func (m UDPMessage) Size() int { + return m.HeaderSize() + len(m.Data) +} + +func ParseUDPMessage(packet []byte) (message UDPMessage, err error) { + reader := bytes.NewReader(packet) + err = binary.Read(reader, binary.BigEndian, &message.SessionID) + if err != nil { + return + } + var hostLen uint16 + err = binary.Read(reader, binary.BigEndian, &hostLen) + if err != nil { + return + } + _, err = reader.Seek(int64(hostLen), io.SeekCurrent) + if err != nil { + return + } + message.Host = string(packet[6 : 6+hostLen]) + err = binary.Read(reader, binary.BigEndian, &message.Port) + if err != nil { + return + } + err = binary.Read(reader, binary.BigEndian, &message.MsgID) + if err != nil { + return + } + err = binary.Read(reader, binary.BigEndian, &message.FragID) + if err != nil { + return + } + err = binary.Read(reader, binary.BigEndian, &message.FragCount) + if err != nil { + return + } + var dataLen uint16 + err = binary.Read(reader, binary.BigEndian, &dataLen) + if err != nil { + return + } + if reader.Len() != int(dataLen) { + err = E.New("invalid data length") + } + dataOffset := int(reader.Size()) - reader.Len() + message.Data = packet[dataOffset:] + return +} + +func WriteUDPMessage(conn quic.Connection, message UDPMessage) error { + var messageLen int + messageLen += 4 // session id + messageLen += 2 // host len + messageLen += len(message.Host) + messageLen += 2 // port + messageLen += 2 // msg id + messageLen += 1 // frag id + messageLen += 1 // frag count + messageLen += 2 // data len + messageLen += len(message.Data) + _buffer := buf.StackNewSize(messageLen) + defer common.KeepAlive(_buffer) + buffer := common.Dup(_buffer) + defer buffer.Release() + err := writeUDPMessage(conn, message, buffer) + // TODO: wait for change upstream + if /*errSize, ok := err.(quic.ErrMessageToLarge); ok*/ false { + const errSize = 0 + // need to frag + message.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1 + fragMsgs := FragUDPMessage(message, int(errSize)) + for _, fragMsg := range fragMsgs { + buffer.FullReset() + err = writeUDPMessage(conn, fragMsg, buffer) + if err != nil { + return err + } + } + return nil + } + return err +} + +func writeUDPMessage(conn quic.Connection, message UDPMessage, buffer *buf.Buffer) error { + common.Must( + binary.Write(buffer, binary.BigEndian, message.SessionID), + binary.Write(buffer, binary.BigEndian, uint16(len(message.Host))), + common.Error(buffer.WriteString(message.Host)), + binary.Write(buffer, binary.BigEndian, message.Port), + binary.Write(buffer, binary.BigEndian, message.MsgID), + binary.Write(buffer, binary.BigEndian, message.FragID), + binary.Write(buffer, binary.BigEndian, message.FragCount), + binary.Write(buffer, binary.BigEndian, uint16(len(message.Data))), + common.Error(buffer.Write(message.Data)), + ) + return conn.SendMessage(buffer.Bytes()) +}