From d6fd5f8c03b0d445e318e100984ce52afbcf791f Mon Sep 17 00:00:00 2001 From: arm64v8a <48624112+arm64v8a@users.noreply.github.com> Date: Sat, 8 Jul 2023 10:36:48 +0900 Subject: [PATCH] Add hysteria hop client --- option/hysteria.go | 2 + outbound/hysteria.go | 84 +++++++-- transport/hysteria/hop/hop.go | 346 ++++++++++++++++++++++++++++++++++ 3 files changed, 414 insertions(+), 18 deletions(-) create mode 100644 transport/hysteria/hop/hop.go diff --git a/option/hysteria.go b/option/hysteria.go index d3ce6a10..03e39ead 100644 --- a/option/hysteria.go +++ b/option/hysteria.go @@ -36,4 +36,6 @@ type HysteriaOutboundOptions struct { DisableMTUDiscovery bool `json:"disable_mtu_discovery,omitempty"` Network NetworkList `json:"network,omitempty"` TLS *OutboundTLSOptions `json:"tls,omitempty"` + HopPorts string `json:"hop_ports,omitempty"` + HopInterval int `json:"hop_interval,omitempty"` } diff --git a/outbound/hysteria.go b/outbound/hysteria.go index 207b9c4c..0389d8c5 100644 --- a/outbound/hysteria.go +++ b/outbound/hysteria.go @@ -4,8 +4,11 @@ package outbound import ( "context" + "fmt" + "io" "net" "sync" + "time" "github.com/sagernet/quic-go" "github.com/sagernet/quic-go/congestion" @@ -16,6 +19,7 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/transport/hysteria" + "github.com/sagernet/sing-box/transport/hysteria/hop" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" @@ -33,6 +37,8 @@ type Hysteria struct { ctx context.Context dialer N.Dialer serverAddr M.Socksaddr + hopPorts string + hopInterval time.Duration tlsConfig *tls.STDConfig quicConfig *quic.Config authKey []byte @@ -41,7 +47,7 @@ type Hysteria struct { recvBPS uint64 connAccess sync.Mutex conn quic.Connection - rawConn net.Conn + rawConn io.Closer udpAccess sync.RWMutex udpSessions map[uint32]chan *hysteria.UDPMessage udpDefragger hysteria.Defragger @@ -117,6 +123,9 @@ func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextL if down < hysteria.MinSpeedBPS { return nil, E.New("invalid down speed") } + if options.HopInterval < 10 { + options.HopInterval = 10 + } return &Hysteria{ myOutboundAdapter: myOutboundAdapter{ protocol: C.TypeHysteria, @@ -126,15 +135,17 @@ func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextL tag: tag, dependencies: withDialerDependency(options.DialerOptions), }, - ctx: ctx, - dialer: dialer.New(router, options.DialerOptions), - serverAddr: options.ServerOptions.Build(), - tlsConfig: tlsConfig, - quicConfig: quicConfig, - authKey: auth, - xplusKey: xplus, - sendBPS: up, - recvBPS: down, + ctx: ctx, + dialer: dialer.New(router, options.DialerOptions), + serverAddr: options.ServerOptions.Build(), + hopPorts: options.HopPorts, + hopInterval: time.Second * time.Duration(options.HopInterval), + tlsConfig: tlsConfig, + quicConfig: quicConfig, + authKey: auth, + xplusKey: xplus, + sendBPS: up, + recvBPS: down, }, nil } @@ -168,17 +179,54 @@ func (h *Hysteria) offer(ctx context.Context) (quic.Connection, error) { } func (h *Hysteria) offerNew(ctx context.Context) (quic.Connection, error) { - udpConn, err := h.dialer.DialContext(h.ctx, "udp", h.serverAddr) - if err != nil { - return nil, err - } + var addr net.Addr var packetConn net.PacketConn - packetConn = bufio.NewUnbindPacketConn(udpConn) + var rawConn io.Closer + // + if h.hopPorts != "" { + hyAddrStr := h.serverAddr.AddrString() + if h.serverAddr.IsIPv6() { + hyAddrStr = fmt.Sprintf("[%s]", hyAddrStr) + } + hyAddrStr += ":" + h.hopPorts + host, ports, err := hop.ParseAddr(hyAddrStr) + if err != nil { + return nil, E.Cause(err, "hop.ParseAddr") + } + packetConn, addr, err = hop.NewUDPHopClientPacketConn(hyAddrStr, h.hopInterval, func() (net.PacketConn, error) { + return h.dialer.ListenPacket(ctx, M.ParseSocksaddrHostPort(host, ports[0])) + }, func(host string) (net.IP, error) { + if ip := net.ParseIP(host); ip != nil { + return ip, nil + } + ips, err := h.router.LookupDefault(ctx, host) + if err != nil { + return nil, err + } + return ips[0].AsSlice(), nil + }) + if err != nil { + return nil, E.Cause(err, "hop.NewUDPHopClientPacketConn") + } + rawConn = packetConn + } else { + udpConn, err := h.dialer.DialContext(h.ctx, "udp", h.serverAddr) + if err != nil { + return nil, err + } + addr = udpConn.RemoteAddr() + rawConn = udpConn + packetConn = bufio.NewUnbindPacketConn(udpConn) + } + // if h.xplusKey != nil { packetConn = hysteria.NewXPlusPacketConn(packetConn, h.xplusKey) } - packetConn = &hysteria.PacketConnWrapper{PacketConn: packetConn} - quicConn, err := quic.Dial(h.ctx, packetConn, udpConn.RemoteAddr(), h.tlsConfig, h.quicConfig) + if h.hopPorts == "" { + packetConn = &hysteria.PacketConnWrapper{PacketConn: packetConn} + } + // + quicConn, err := quic.Dial(h.ctx, packetConn, addr, h.tlsConfig, h.quicConfig) if err != nil { packetConn.Close() return nil, err @@ -208,7 +256,7 @@ func (h *Hysteria) offerNew(ctx context.Context) (quic.Connection, error) { } quicConn.SetCongestionControl(hysteria.NewBrutalSender(congestion.ByteCount(serverHello.RecvBPS))) h.conn = quicConn - h.rawConn = udpConn + h.rawConn = rawConn return quicConn, nil } diff --git a/transport/hysteria/hop/hop.go b/transport/hysteria/hop/hop.go new file mode 100644 index 00000000..c09db67e --- /dev/null +++ b/transport/hysteria/hop/hop.go @@ -0,0 +1,346 @@ +package hop + +import ( + "errors" + "math/rand" + "net" + "strconv" + "strings" + "sync" + "syscall" + "time" +) + +const ( + udpBufferSize = 4096 + packetQueueSize = 1024 +) + +type ListenPacketFunc func() (net.PacketConn, error) +type ResolveFunc func(host string) (net.IP, error) + +// UDPHopClientPacketConn is the UDP port-hopping packet connection for client side. +// It hops to a different local & server port every once in a while. +type UDPHopClientPacketConn struct { + listenPacket ListenPacketFunc + serverAddr net.Addr // Combined udpHopAddr + serverAddrs []net.Addr + hopInterval time.Duration + + connMutex sync.RWMutex + prevConn net.PacketConn + currentConn net.PacketConn + addrIndex int + + readBufferSize int + writeBufferSize int + + recvQueue chan *udpPacket + closeChan chan struct{} + closed bool + + bufPool sync.Pool +} + +type udpHopAddr string + +func (a *udpHopAddr) Network() string { + return "udp-hop" +} + +func (a *udpHopAddr) String() string { + return string(*a) +} + +type udpPacket struct { + buf []byte + n int + addr net.Addr +} + +func NewUDPHopClientPacketConn(server string, hopInterval time.Duration, listenPacket ListenPacketFunc, lookupFunc ResolveFunc) (*UDPHopClientPacketConn, net.Addr, error) { + host, ports, err := ParseAddr(server) + if err != nil { + return nil, nil, err + } + // Resolve the server IP address, then attach the ports to UDP addresses + ip, err := lookupFunc(host) + if err != nil { + return nil, nil, err + } + serverAddrs := make([]net.Addr, len(ports)) + for i, port := range ports { + serverAddrs[i] = &net.UDPAddr{ + IP: ip, + Port: int(port), + } + } + hopAddr := udpHopAddr(server) + conn := &UDPHopClientPacketConn{ + listenPacket: listenPacket, + serverAddr: &hopAddr, + serverAddrs: serverAddrs, + hopInterval: hopInterval, + addrIndex: rand.Intn(len(serverAddrs)), + recvQueue: make(chan *udpPacket, packetQueueSize), + closeChan: make(chan struct{}), + bufPool: sync.Pool{ + New: func() interface{} { + return make([]byte, udpBufferSize) + }, + }, + } + curConn, err := listenPacket() + if err != nil { + return nil, nil, err + } + conn.currentConn = curConn + go conn.recvRoutine(conn.currentConn) + go conn.hopRoutine() + return conn, conn.serverAddr, nil +} + +func (c *UDPHopClientPacketConn) recvRoutine(conn net.PacketConn) { + for { + buf := c.bufPool.Get().([]byte) + n, addr, err := conn.ReadFrom(buf) + if err != nil { + return + } + select { + case c.recvQueue <- &udpPacket{buf, n, addr}: + default: + // Drop the packet if the queue is full + c.bufPool.Put(buf) + } + } +} + +func (c *UDPHopClientPacketConn) hopRoutine() { + ticker := time.NewTicker(c.hopInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + c.hop() + case <-c.closeChan: + return + } + } +} + +func (c *UDPHopClientPacketConn) hop() { + c.connMutex.Lock() + defer c.connMutex.Unlock() + if c.closed { + return + } + newConn, err := c.listenPacket() + if err != nil { + // Skip this hop if failed to listen + return + } + // Close prevConn, + // prevConn <- currentConn + // currentConn <- newConn + // update addrIndex + // + // We need to keep receiving packets from the previous connection, + // because otherwise there will be packet loss due to the time gap + // between we hop to a new port and the server acknowledges this change. + if c.prevConn != nil { + _ = c.prevConn.Close() // recvRoutine will exit on error + } + c.prevConn = c.currentConn + c.currentConn = newConn + // Set buffer sizes if previously set + if c.readBufferSize > 0 { + _ = trySetPacketConnReadBuffer(c.currentConn, c.readBufferSize) + } + if c.writeBufferSize > 0 { + _ = trySetPacketConnWriteBuffer(c.currentConn, c.writeBufferSize) + } + go c.recvRoutine(c.currentConn) + c.addrIndex = rand.Intn(len(c.serverAddrs)) +} + +func (c *UDPHopClientPacketConn) ReadFrom(b []byte) (int, net.Addr, error) { + for { + select { + case p := <-c.recvQueue: + /* + // Check if the packet is from one of the server addresses + for _, addr := range c.serverAddrs { + if addr.String() == p.addr.String() { + // Copy the packet to the buffer + n := copy(b, p.buf[:p.n]) + c.bufPool.Put(p.buf) + return n, c.serverAddr, nil + } + } + // Drop the packet, continue + c.bufPool.Put(p.buf) + */ + // The above code was causing performance issues when the range is large, + // so we skip the check for now. Should probably still check by using a map + // or something in the future. + n := copy(b, p.buf[:p.n]) + c.bufPool.Put(p.buf) + return n, c.serverAddr, nil + case <-c.closeChan: + return 0, nil, net.ErrClosed + } + // Ignore packets from other addresses + } +} + +func (c *UDPHopClientPacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { + c.connMutex.RLock() + defer c.connMutex.RUnlock() + if c.closed { + return 0, net.ErrClosed + } + /* + // Check if the address is the server address + if addr.String() != c.serverAddr.String() { + return 0, net.ErrWriteToConnected + } + */ + // Skip the check for now, always write to the server + return c.currentConn.WriteTo(b, c.serverAddrs[c.addrIndex]) +} + +func (c *UDPHopClientPacketConn) Close() error { + c.connMutex.Lock() + defer c.connMutex.Unlock() + if c.closed { + return nil + } + // Close prevConn and currentConn + // Close closeChan to unblock ReadFrom & hopRoutine + // Set closed flag to true to prevent double close + if c.prevConn != nil { + _ = c.prevConn.Close() + } + err := c.currentConn.Close() + close(c.closeChan) + c.closed = true + c.serverAddrs = nil // For GC + return err +} + +func (c *UDPHopClientPacketConn) LocalAddr() net.Addr { + c.connMutex.RLock() + defer c.connMutex.RUnlock() + return c.currentConn.LocalAddr() +} + +func (c *UDPHopClientPacketConn) SetReadDeadline(t time.Time) error { + // Not supported + return nil +} + +func (c *UDPHopClientPacketConn) SetWriteDeadline(t time.Time) error { + // Not supported + return nil +} + +func (c *UDPHopClientPacketConn) SetDeadline(t time.Time) error { + err := c.SetReadDeadline(t) + if err != nil { + return err + } + return c.SetWriteDeadline(t) +} + +func (c *UDPHopClientPacketConn) SetReadBuffer(bytes int) error { + c.connMutex.Lock() + defer c.connMutex.Unlock() + c.readBufferSize = bytes + if c.prevConn != nil { + _ = trySetPacketConnReadBuffer(c.prevConn, bytes) + } + return trySetPacketConnReadBuffer(c.currentConn, bytes) +} + +func (c *UDPHopClientPacketConn) SetWriteBuffer(bytes int) error { + c.connMutex.Lock() + defer c.connMutex.Unlock() + c.writeBufferSize = bytes + if c.prevConn != nil { + _ = trySetPacketConnWriteBuffer(c.prevConn, bytes) + } + return trySetPacketConnWriteBuffer(c.currentConn, bytes) +} + +func (c *UDPHopClientPacketConn) SyscallConn() (syscall.RawConn, error) { + c.connMutex.RLock() + defer c.connMutex.RUnlock() + sc, ok := c.currentConn.(syscall.Conn) + if !ok { + return nil, errors.New("not supported") + } + return sc.SyscallConn() +} + +func trySetPacketConnReadBuffer(pc net.PacketConn, bytes int) error { + sc, ok := pc.(interface { + SetReadBuffer(bytes int) error + }) + if ok { + return sc.SetReadBuffer(bytes) + } + return nil +} + +func trySetPacketConnWriteBuffer(pc net.PacketConn, bytes int) error { + sc, ok := pc.(interface { + SetWriteBuffer(bytes int) error + }) + if ok { + return sc.SetWriteBuffer(bytes) + } + return nil +} + +// ParseAddr parses the multi-port server address and returns the host and ports. +// Supports both comma-separated single ports and dash-separated port ranges. +// Format: "host:port1,port2-port3,port4" +func ParseAddr(addr string) (host string, ports []uint16, err error) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + return "", nil, err + } + portStrs := strings.Split(portStr, ",") + for _, portStr := range portStrs { + if strings.Contains(portStr, "-") { + // Port range + portRange := strings.Split(portStr, "-") + if len(portRange) != 2 { + return "", nil, net.InvalidAddrError("invalid port range") + } + start, err := strconv.ParseUint(portRange[0], 10, 16) + if err != nil { + return "", nil, net.InvalidAddrError("invalid port range") + } + end, err := strconv.ParseUint(portRange[1], 10, 16) + if err != nil { + return "", nil, net.InvalidAddrError("invalid port range") + } + if start > end { + start, end = end, start + } + for i := start; i <= end; i++ { + ports = append(ports, uint16(i)) + } + } else { + // Single port + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return "", nil, net.InvalidAddrError("invalid port") + } + ports = append(ports, uint16(port)) + } + } + return host, ports, nil +}