From 17f10e0d3aa4d9508c9cba4f6795c0b866a13a3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 18 Aug 2022 14:39:48 +0800 Subject: [PATCH] Add hysteria tcp client --- constant/proxy.go | 1 + option/hysteria.go | 23 +++ option/outbound.go | 5 + outbound/builder.go | 2 + outbound/hysteria.go | 253 +++++++++++++++++++++++++++++++++ outbound/hysteria_stub.go | 16 +++ test/box_test.go | 43 ++---- test/clash_test.go | 2 + test/config/hysteria.json | 9 ++ test/hysteria_test.go | 60 ++++++++ transport/hysteria/client.go | 69 +++++++++ transport/hysteria/protocol.go | 142 ++++++++++++++++++ transport/hysteria/speed.go | 36 +++++ transport/hysteria/wrap.go | 29 ++++ transport/hysteria/xplus.go | 119 ++++++++++++++++ 15 files changed, 777 insertions(+), 32 deletions(-) create mode 100644 option/hysteria.go create mode 100644 outbound/hysteria.go create mode 100644 outbound/hysteria_stub.go create mode 100644 test/config/hysteria.json create mode 100644 test/hysteria_test.go create mode 100644 transport/hysteria/client.go create mode 100644 transport/hysteria/protocol.go create mode 100644 transport/hysteria/speed.go create mode 100644 transport/hysteria/wrap.go create mode 100644 transport/hysteria/xplus.go diff --git a/constant/proxy.go b/constant/proxy.go index f5024059..fa4d25c6 100644 --- a/constant/proxy.go +++ b/constant/proxy.go @@ -15,6 +15,7 @@ const ( TypeTrojan = "trojan" TypeNaive = "naive" TypeWireGuard = "wireguard" + TypeHysteria = "hysteria" ) const ( diff --git a/option/hysteria.go b/option/hysteria.go new file mode 100644 index 00000000..9d57f82e --- /dev/null +++ b/option/hysteria.go @@ -0,0 +1,23 @@ +package option + +type HysteriaOutboundOptions struct { + OutboundDialerOptions + ServerOptions + Protocol string `json:"protocol"` + Up string `json:"up"` + UpMbps int `json:"up_mbps"` + Down string `json:"down"` + DownMbps int `json:"down_mbps"` + Obfs string `json:"obfs"` + Auth []byte `json:"auth"` + AuthString string `json:"auth_str"` + ALPN string `json:"alpn"` + ServerName string `json:"server_name"` + Insecure bool `json:"insecure"` + CustomCA string `json:"ca"` + CustomCAStr string `json:"ca_str"` + ReceiveWindowConn uint64 `json:"recv_window_conn"` + ReceiveWindow uint64 `json:"recv_window"` + DisableMTUDiscovery bool `json:"disable_mtu_discovery"` + Network NetworkList `json:"network,omitempty"` +} diff --git a/option/outbound.go b/option/outbound.go index e735e130..f1f84409 100644 --- a/option/outbound.go +++ b/option/outbound.go @@ -17,6 +17,7 @@ type _Outbound struct { VMessOptions VMessOutboundOptions `json:"-"` TrojanOptions TrojanOutboundOptions `json:"-"` WireGuardOptions WireGuardOutboundOptions `json:"-"` + HysteriaOutbound HysteriaOutboundOptions `json:"-"` SelectorOptions SelectorOutboundOptions `json:"-"` } @@ -41,6 +42,8 @@ func (h Outbound) MarshalJSON() ([]byte, error) { v = h.TrojanOptions case C.TypeWireGuard: v = h.WireGuardOptions + case C.TypeHysteria: + v = h.HysteriaOutbound case C.TypeSelector: v = h.SelectorOptions default: @@ -72,6 +75,8 @@ func (h *Outbound) UnmarshalJSON(bytes []byte) error { v = &h.TrojanOptions case C.TypeWireGuard: v = &h.WireGuardOptions + case C.TypeHysteria: + v = &h.HysteriaOutbound case C.TypeSelector: v = &h.SelectorOptions default: diff --git a/outbound/builder.go b/outbound/builder.go index 843faa3f..97735dae 100644 --- a/outbound/builder.go +++ b/outbound/builder.go @@ -33,6 +33,8 @@ func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, o return NewTrojan(ctx, router, logger, options.Tag, options.TrojanOptions) case C.TypeWireGuard: return NewWireGuard(ctx, router, logger, options.Tag, options.WireGuardOptions) + case C.TypeHysteria: + return NewHysteria(ctx, router, logger, options.Tag, options.HysteriaOutbound) case C.TypeSelector: return NewSelector(router, logger, options.Tag, options.SelectorOptions) default: diff --git a/outbound/hysteria.go b/outbound/hysteria.go new file mode 100644 index 00000000..40a37327 --- /dev/null +++ b/outbound/hysteria.go @@ -0,0 +1,253 @@ +//go:build with_quic + +package outbound + +import ( + "context" + "crypto/tls" + "crypto/x509" + "net" + "os" + "sync" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/transport/hysteria" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + + "github.com/lucas-clemente/quic-go" +) + +const ( + hyMbpsToBps = 125000 + hyMinSpeedBPS = 16384 + hyDefaultStreamReceiveWindow = 15728640 // 15 MB/s + hyDefaultConnectionReceiveWindow = 67108864 // 64 MB/s + hyDefaultMaxIncomingStreams = 1024 + hyDefaultALPN = "hysteria" + hyKeepAlivePeriod = 10 * time.Second +) + +var _ adapter.Outbound = (*Hysteria)(nil) + +type Hysteria struct { + myOutboundAdapter + ctx context.Context + dialer N.Dialer + serverAddr M.Socksaddr + tlsConfig *tls.Config + quicConfig *quic.Config + authKey []byte + xplusKey []byte + sendBPS uint64 + recvBPS uint64 + connAccess sync.Mutex + conn quic.Connection +} + +func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaOutboundOptions) (*Hysteria, error) { + tlsConfig := &tls.Config{ + ServerName: options.ServerName, + InsecureSkipVerify: options.Insecure, + MinVersion: tls.VersionTLS13, + } + if options.ALPN != "" { + tlsConfig.NextProtos = []string{options.ALPN} + } else { + tlsConfig.NextProtos = []string{hyDefaultALPN} + } + var ca []byte + var err error + if options.CustomCA != "" { + ca, err = os.ReadFile(options.CustomCA) + if err != nil { + return nil, err + } + } + if options.CustomCAStr != "" { + ca = []byte(options.CustomCAStr) + } + if len(ca) > 0 { + cp := x509.NewCertPool() + if !cp.AppendCertsFromPEM(ca) { + return nil, E.New("parse ca failed") + } + tlsConfig.RootCAs = cp + } + quicConfig := &quic.Config{ + InitialStreamReceiveWindow: options.ReceiveWindowConn, + MaxStreamReceiveWindow: options.ReceiveWindowConn, + InitialConnectionReceiveWindow: options.ReceiveWindow, + MaxConnectionReceiveWindow: options.ReceiveWindow, + KeepAlivePeriod: hyKeepAlivePeriod, + DisablePathMTUDiscovery: options.DisableMTUDiscovery, + EnableDatagrams: true, + } + if options.ReceiveWindowConn == 0 { + quicConfig.InitialStreamReceiveWindow = hyDefaultStreamReceiveWindow + quicConfig.MaxStreamReceiveWindow = hyDefaultStreamReceiveWindow + } + if options.ReceiveWindow == 0 { + quicConfig.InitialConnectionReceiveWindow = hyDefaultConnectionReceiveWindow + quicConfig.MaxConnectionReceiveWindow = hyDefaultConnectionReceiveWindow + } + if quicConfig.MaxIncomingStreams == 0 { + quicConfig.MaxIncomingStreams = hyDefaultMaxIncomingStreams + } + var auth []byte + if len(options.Auth) > 0 { + auth = options.Auth + } else { + auth = []byte(options.AuthString) + } + var xplus []byte + if options.Obfs != "" { + xplus = []byte(options.Obfs) + } + var up, down uint64 + if len(options.Up) > 0 { + up = hysteria.StringToBps(options.Up) + if up == 0 { + return nil, E.New("invalid up speed format: ", options.Up) + } + } else { + up = uint64(options.UpMbps) * hyMbpsToBps + } + if len(options.Down) > 0 { + down = hysteria.StringToBps(options.Down) + if down == 0 { + return nil, E.New("invalid down speed format: ", options.Down) + } + } else { + down = uint64(options.DownMbps) * hyMbpsToBps + } + if up < hyMinSpeedBPS { + return nil, E.New("invalid up speed") + } + if down < hyMinSpeedBPS { + return nil, E.New("invalid down speed") + } + return &Hysteria{ + myOutboundAdapter: myOutboundAdapter{ + protocol: C.TypeHysteria, + network: options.Network.Build(), + router: router, + logger: logger, + tag: tag, + }, + ctx: ctx, + dialer: dialer.NewOutbound(router, options.OutboundDialerOptions), + serverAddr: options.ServerOptions.Build(), + tlsConfig: tlsConfig, + quicConfig: quicConfig, + authKey: auth, + xplusKey: xplus, + sendBPS: up, + recvBPS: down, + }, nil +} + +func (h *Hysteria) offer(ctx context.Context) (quic.Connection, error) { + conn := h.conn + if conn != nil && !common.Done(conn.Context()) { + return conn, nil + } + h.connAccess.Lock() + defer h.connAccess.Unlock() + conn = h.conn + if conn != nil && !common.Done(conn.Context()) { + return conn, nil + } + conn, err := h.offerNew(ctx) + if err != nil { + return nil, err + } + h.conn = conn + return conn, nil +} + +func (h *Hysteria) offerNew(ctx context.Context) (quic.Connection, error) { + udpConn, err := h.dialer.DialContext(h.ctx, "udp", h.serverAddr) + if err != nil { + return nil, err + } + var packetConn net.PacketConn + packetConn = bufio.NewUnbindPacketConn(udpConn) + if h.xplusKey != nil { + packetConn = hysteria.NewXPlusPacketConn(packetConn, h.xplusKey) + } + packetConn = &hysteria.WrapPacketConn{PacketConn: packetConn} + quicConn, err := quic.Dial(packetConn, udpConn.RemoteAddr(), h.serverAddr.AddrString(), h.tlsConfig, h.quicConfig) + if err != nil { + packetConn.Close() + return nil, err + } + controlStream, err := quicConn.OpenStreamSync(ctx) + if err != nil { + packetConn.Close() + return nil, err + } + err = hysteria.WriteClientHello(controlStream, hysteria.ClientHello{ + SendBPS: h.sendBPS, + RecvBPS: h.recvBPS, + Auth: h.authKey, + }) + if err != nil { + return nil, E.Cause(err, "write hysteria client hello") + } + serverHello, err := hysteria.ReadServerHello(controlStream) + if err != nil { + return nil, err + } + if !serverHello.OK { + return nil, E.New("remote error: ", serverHello.Message) + } + // TODO: set congestion control + return quicConn, nil +} + +func (h *Hysteria) Close() error { + h.connAccess.Lock() + defer h.connAccess.Unlock() + if h.conn != nil { + h.conn.CloseWithError(0, "") + } + return nil +} + +func (h *Hysteria) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + conn, err := h.offer(ctx) + if err != nil { + return nil, err + } + stream, err := conn.OpenStream() + if err != nil { + return nil, err + } + switch N.NetworkName(network) { + case N.NetworkTCP: + return hysteria.NewClientConn(stream, destination), nil + default: + return nil, E.New("unsupported network: ", network) + } +} + +func (h *Hysteria) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + return nil, os.ErrInvalid +} + +func (h *Hysteria) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { + return NewConnection(ctx, h, conn, metadata) +} + +func (h *Hysteria) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + return NewPacketConnection(ctx, h, conn, metadata) +} diff --git a/outbound/hysteria_stub.go b/outbound/hysteria_stub.go new file mode 100644 index 00000000..ae2d62b4 --- /dev/null +++ b/outbound/hysteria_stub.go @@ -0,0 +1,16 @@ +//go:build !with_quic + +package outbound + +import ( + "context" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" +) + +func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaOutboundOptions) (adapter.Outbound, error) { + return nil, E.New(`QUIC is not included in this build, rebuild with -tags with_quic`) +} diff --git a/test/box_test.go b/test/box_test.go index 58ab0df1..939a52c6 100644 --- a/test/box_test.go +++ b/test/box_test.go @@ -35,14 +35,6 @@ func startInstance(t *testing.T, options option.Options) { }) } -func testTCP(t *testing.T, clientPort uint16, testPort uint16) { - dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", clientPort), socks.Version5, "", "") - dialTCP := func() (net.Conn, error) { - return dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddrHostPort("127.0.0.1", testPort)) - } - require.NoError(t, testPingPongWithConn(t, testPort, dialTCP)) -} - func testSuit(t *testing.T, clientPort uint16, testPort uint16) { dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", clientPort), socks.Version5, "", "") dialTCP := func() (net.Conn, error) { @@ -51,36 +43,23 @@ func testSuit(t *testing.T, clientPort uint16, testPort uint16) { dialUDP := func() (net.PacketConn, error) { return dialer.ListenPacket(context.Background(), M.ParseSocksaddrHostPort("127.0.0.1", testPort)) } - /*t.Run("tcp", func(t *testing.T) { - t.Parallel() - var err error - for retry := 0; retry < 3; retry++ { - err = testLargeDataWithConn(t, testPort, dialTCP) - if err == nil { - break - } - } - require.NoError(t, err) - }) - t.Run("udp", func(t *testing.T) { - t.Parallel() - var err error - for retry := 0; retry < 3; retry++ { - err = testLargeDataWithPacketConn(t, testPort, dialUDP) - if err == nil { - break - } - } - require.NoError(t, err) - })*/ - //require.NoError(t, testPingPongWithConn(t, testPort, dialTCP)) - //require.NoError(t, testPingPongWithPacketConn(t, testPort, dialUDP)) + // require.NoError(t, testPingPongWithConn(t, testPort, dialTCP)) + // require.NoError(t, testPingPongWithPacketConn(t, testPort, dialUDP)) require.NoError(t, testLargeDataWithConn(t, testPort, dialTCP)) require.NoError(t, testLargeDataWithPacketConn(t, testPort, dialUDP)) // require.NoError(t, testPacketConnTimeout(t, dialUDP)) } +func testTCP(t *testing.T, clientPort uint16, testPort uint16) { + dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", clientPort), socks.Version5, "", "") + dialTCP := func() (net.Conn, error) { + return dialer.DialContext(context.Background(), "tcp", M.ParseSocksaddrHostPort("127.0.0.1", testPort)) + } + require.NoError(t, testPingPongWithConn(t, testPort, dialTCP)) + require.NoError(t, testLargeDataWithConn(t, testPort, dialTCP)) +} + func testSuitWg(t *testing.T, clientPort uint16, testPort uint16) { dialer := socks.NewClient(N.SystemDialer, M.ParseSocksaddrHostPort("127.0.0.1", clientPort), socks.Version5, "", "") dialTCP := func() (net.Conn, error) { diff --git a/test/clash_test.go b/test/clash_test.go index bd76aa19..b490434a 100644 --- a/test/clash_test.go +++ b/test/clash_test.go @@ -32,6 +32,7 @@ const ( ImageTrojan = "trojangfw/trojan:latest" ImageNaive = "pocat/naiveproxy:client" ImageBoringTun = "ghcr.io/ntkme/boringtun:edge" + ImageHysteria = "tobyxdd/hysteria:latest" ) var allImages = []string{ @@ -41,6 +42,7 @@ var allImages = []string{ ImageTrojan, ImageNaive, ImageBoringTun, + ImageHysteria, } var localIP = netip.MustParseAddr("127.0.0.1") diff --git a/test/config/hysteria.json b/test/config/hysteria.json new file mode 100644 index 00000000..e33624a2 --- /dev/null +++ b/test/config/hysteria.json @@ -0,0 +1,9 @@ +{ + "listen": ":10000", + "cert": "/etc/hysteria/cert.pem", + "key": "/etc/hysteria/key.pem", + "auth_str": "password", + "obfs": "fuck me till the daylight", + "up_mbps": 100, + "down_mbps": 100 +} \ No newline at end of file diff --git a/test/hysteria_test.go b/test/hysteria_test.go new file mode 100644 index 00000000..055f76f5 --- /dev/null +++ b/test/hysteria_test.go @@ -0,0 +1,60 @@ +package main + +import ( + "net/netip" + "testing" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" +) + +func TestHysteriaOutbound(t *testing.T) { + if !C.QUIC_AVAILABLE { + t.Skip("QUIC not included") + } + caPem, certPem, keyPem := createSelfSignedCertificate(t, "example.org") + startDockerContainer(t, DockerOptions{ + Image: ImageHysteria, + Ports: []uint16{serverPort, testPort}, + Cmd: []string{"-c", "/etc/hysteria/config.json", "server"}, + Bind: map[string]string{ + "hysteria.json": "/etc/hysteria/config.json", + certPem: "/etc/hysteria/cert.pem", + keyPem: "/etc/hysteria/key.pem", + }, + }) + startInstance(t, option.Options{ + Log: &option.LogOptions{ + Level: "trace", + }, + Inbounds: []option.Inbound{ + { + Type: C.TypeMixed, + MixedOptions: option.HTTPMixedInboundOptions{ + ListenOptions: option.ListenOptions{ + Listen: option.ListenAddress(netip.IPv4Unspecified()), + ListenPort: clientPort, + }, + }, + }, + }, + Outbounds: []option.Outbound{ + { + Type: C.TypeHysteria, + HysteriaOutbound: option.HysteriaOutboundOptions{ + ServerOptions: option.ServerOptions{ + Server: "127.0.0.1", + ServerPort: serverPort, + }, + UpMbps: 100, + DownMbps: 100, + AuthString: "password", + Obfs: "fuck me till the daylight", + CustomCA: caPem, + ServerName: "example.org", + }, + }, + }, + }) + testTCP(t, clientPort, testPort) +} diff --git a/transport/hysteria/client.go b/transport/hysteria/client.go new file mode 100644 index 00000000..5ae2bfe4 --- /dev/null +++ b/transport/hysteria/client.go @@ -0,0 +1,69 @@ +package hysteria + +import ( + "net" + + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + + "github.com/lucas-clemente/quic-go" +) + +var _ net.Conn = (*ClientConn)(nil) + +type ClientConn struct { + quic.Stream + destination M.Socksaddr + requestWritten bool + responseRead bool +} + +func NewClientConn(stream quic.Stream, destination M.Socksaddr) *ClientConn { + return &ClientConn{ + Stream: stream, + destination: destination, + } +} + +func (c *ClientConn) Read(b []byte) (n int, err error) { + if !c.responseRead { + var response *ServerResponse + response, err = ReadServerResponse(c.Stream) + if err != nil { + return + } + if !response.OK { + return 0, E.New("remote error: " + response.Message) + } + c.responseRead = true + } + return c.Stream.Read(b) +} + +func (c *ClientConn) Write(b []byte) (n int, err error) { + if !c.requestWritten { + err = WriteClientRequest(c.Stream, ClientRequest{ + UDP: false, + Host: c.destination.AddrString(), + Port: c.destination.Port, + }, b) + if err != nil { + return + } + c.requestWritten = true + return len(b), nil + } + return c.Stream.Write(b) +} + +func (c *ClientConn) LocalAddr() net.Addr { + return nil +} + +func (c *ClientConn) RemoteAddr() net.Addr { + return c.destination.TCPAddr() +} + +func (c *ClientConn) Upstream() any { + return c.Stream +} diff --git a/transport/hysteria/protocol.go b/transport/hysteria/protocol.go new file mode 100644 index 00000000..5daedc37 --- /dev/null +++ b/transport/hysteria/protocol.go @@ -0,0 +1,142 @@ +package hysteria + +import ( + "encoding/binary" + "io" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" +) + +const Version = 3 + +type ClientHello struct { + SendBPS uint64 + RecvBPS uint64 + Auth []byte +} + +type ServerHello struct { + OK bool + SendBPS uint64 + RecvBPS uint64 + Message string +} + +type ClientRequest struct { + UDP bool + Host string + Port uint16 +} + +type ServerResponse struct { + OK bool + UDPSessionID uint32 + Message string +} + +func WriteClientHello(stream io.Writer, hello ClientHello) error { + var requestLen int + requestLen += 1 // version + requestLen += 8 // sendBPS + requestLen += 8 // recvBPS + requestLen += 2 // auth len + requestLen += len(hello.Auth) + _request := buf.StackNewSize(requestLen) + defer common.KeepAlive(_request) + request := common.Dup(_request) + defer request.Release() + common.Must( + request.WriteByte(Version), + binary.Write(request, binary.BigEndian, hello.SendBPS), + binary.Write(request, binary.BigEndian, hello.RecvBPS), + binary.Write(request, binary.BigEndian, uint16(len(hello.Auth))), + common.Error(request.Write(hello.Auth)), + ) + return common.Error(stream.Write(request.Bytes())) +} + +func ReadServerHello(stream io.Reader) (*ServerHello, error) { + var responseLen int + responseLen += 1 // ok + responseLen += 8 // sendBPS + responseLen += 8 // recvBPS + responseLen += 2 // message len + _response := buf.StackNewSize(responseLen) + defer common.KeepAlive(_response) + response := common.Dup(_response) + defer response.Release() + _, err := response.ReadFullFrom(stream, responseLen) + if err != nil { + return nil, err + } + var serverHello ServerHello + serverHello.OK = response.Byte(0) == 1 + serverHello.SendBPS = binary.BigEndian.Uint64(response.Range(1, 9)) + serverHello.RecvBPS = binary.BigEndian.Uint64(response.Range(9, 17)) + messageLen := binary.BigEndian.Uint16(response.Range(17, 19)) + if messageLen == 0 { + return &serverHello, nil + } + message := make([]byte, messageLen) + _, err = io.ReadFull(stream, message) + if err != nil { + return nil, err + } + serverHello.Message = string(message) + return &serverHello, nil +} + +func WriteClientRequest(stream io.Writer, request ClientRequest, payload []byte) error { + var requestLen int + requestLen += 1 // udp + requestLen += 2 // host len + requestLen += len(request.Host) + requestLen += 2 // port + requestLen += len(payload) + _buffer := buf.StackNewSize(requestLen) + defer common.KeepAlive(_buffer) + buffer := common.Dup(_buffer) + defer buffer.Release() + if request.UDP { + common.Must(buffer.WriteByte(1)) + } else { + common.Must(buffer.WriteByte(0)) + } + common.Must( + binary.Write(buffer, binary.BigEndian, uint16(len(request.Host))), + common.Error(buffer.WriteString(request.Host)), + binary.Write(buffer, binary.BigEndian, request.Port), + common.Error(buffer.Write(payload)), + ) + return common.Error(stream.Write(buffer.Bytes())) +} + +func ReadServerResponse(stream io.Reader) (*ServerResponse, error) { + var responseLen int + responseLen += 1 // ok + responseLen += 4 // udp session id + responseLen += 2 // message len + _response := buf.StackNewSize(responseLen) + defer common.KeepAlive(_response) + response := common.Dup(_response) + defer response.Release() + _, err := response.ReadFullFrom(stream, responseLen) + if err != nil { + return nil, err + } + var serverResponse ServerResponse + serverResponse.OK = response.Byte(0) == 1 + serverResponse.UDPSessionID = binary.BigEndian.Uint32(response.Range(1, 5)) + messageLen := binary.BigEndian.Uint16(response.Range(5, 7)) + if messageLen == 0 { + return &serverResponse, nil + } + message := make([]byte, messageLen) + _, err = io.ReadFull(stream, message) + if err != nil { + return nil, err + } + serverResponse.Message = string(message) + return &serverResponse, nil +} diff --git a/transport/hysteria/speed.go b/transport/hysteria/speed.go new file mode 100644 index 00000000..161e0d58 --- /dev/null +++ b/transport/hysteria/speed.go @@ -0,0 +1,36 @@ +package hysteria + +import ( + "regexp" + "strconv" +) + +func StringToBps(s string) uint64 { + if s == "" { + return 0 + } + m := regexp.MustCompile(`^(\d+)\s*([KMGT]?)([Bb])ps$`).FindStringSubmatch(s) + if m == nil { + return 0 + } + var n uint64 + switch m[2] { + case "K": + n = 1 << 10 + case "M": + n = 1 << 20 + case "G": + n = 1 << 30 + case "T": + n = 1 << 40 + default: + n = 1 + } + v, _ := strconv.ParseUint(m[1], 10, 64) + n = v * n + if m[3] == "b" { + // Bits, need to convert to bytes + n = n >> 3 + } + return n +} diff --git a/transport/hysteria/wrap.go b/transport/hysteria/wrap.go new file mode 100644 index 00000000..03aadb38 --- /dev/null +++ b/transport/hysteria/wrap.go @@ -0,0 +1,29 @@ +package hysteria + +import ( + "net" + "os" + "syscall" + + "github.com/sagernet/sing/common" +) + +type WrapPacketConn struct { + net.PacketConn +} + +func (c *WrapPacketConn) SetReadBuffer(bytes int) error { + return common.MustCast[*net.UDPConn](c.PacketConn).SetReadBuffer(bytes) +} + +func (c *WrapPacketConn) SetWriteBuffer(bytes int) error { + return common.MustCast[*net.UDPConn](c.PacketConn).SetWriteBuffer(bytes) +} + +func (c *WrapPacketConn) SyscallConn() (syscall.RawConn, error) { + return common.MustCast[*net.UDPConn](c.PacketConn).SyscallConn() +} + +func (c *WrapPacketConn) File() (f *os.File, err error) { + return common.MustCast[*net.UDPConn](c.PacketConn).File() +} diff --git a/transport/hysteria/xplus.go b/transport/hysteria/xplus.go new file mode 100644 index 00000000..90a9a325 --- /dev/null +++ b/transport/hysteria/xplus.go @@ -0,0 +1,119 @@ +package hysteria + +import ( + "crypto/sha256" + "math/rand" + "net" + "sync" + "time" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +const xplusSaltLen = 16 + +var errInalidPacket = E.New("invalid packet") + +func NewXPlusPacketConn(conn net.PacketConn, key []byte) net.PacketConn { + vectorisedWriter, isVectorised := bufio.CreateVectorisedPacketWriter(conn) + if isVectorised { + return &VectorisedXPlusConn{ + XPlusPacketConn: XPlusPacketConn{ + PacketConn: conn, + key: key, + rand: rand.New(rand.NewSource(time.Now().UnixNano())), + }, + writer: vectorisedWriter, + } + } else { + return &XPlusPacketConn{ + PacketConn: conn, + key: key, + rand: rand.New(rand.NewSource(time.Now().UnixNano())), + } + } +} + +type XPlusPacketConn struct { + net.PacketConn + key []byte + randAccess sync.Mutex + rand *rand.Rand +} + +func (c *XPlusPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, addr, err = c.PacketConn.ReadFrom(p) + if err != nil { + return + } else if n < xplusSaltLen { + return 0, nil, errInalidPacket + } + key := sha256.Sum256(append(c.key, p[:xplusSaltLen]...)) + for i := range p[xplusSaltLen:] { + p[i] = p[xplusSaltLen+i] ^ key[i%sha256.Size] + } + n -= xplusSaltLen + return +} + +func (c *XPlusPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + // can't use unsafe buffer on WriteTo + buffer := buf.NewSize(len(p) + xplusSaltLen) + defer buffer.Release() + salt := buffer.Extend(xplusSaltLen) + c.randAccess.Lock() + _, _ = c.rand.Read(salt) + c.randAccess.Unlock() + key := sha256.Sum256(append(c.key, salt...)) + for i := range p { + common.Must(buffer.WriteByte(p[i] ^ key[i%sha256.Size])) + } + return c.PacketConn.WriteTo(buffer.Bytes(), addr) +} + +func (c *XPlusPacketConn) Upstream() any { + return c.PacketConn +} + +type VectorisedXPlusConn struct { + XPlusPacketConn + writer N.VectorisedPacketWriter +} + +func (c *VectorisedXPlusConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + header := buf.NewSize(xplusSaltLen) + defer header.Release() + salt := header.Extend(xplusSaltLen) + c.randAccess.Lock() + _, _ = c.rand.Read(salt) + c.randAccess.Unlock() + key := sha256.Sum256(append(c.key, salt...)) + for i := range p { + p[i] ^= key[i%sha256.Size] + } + return bufio.WriteVectorisedPacket(c.writer, [][]byte{header.Bytes(), p}, M.SocksaddrFromNet(addr)) +} + +func (c *VectorisedXPlusConn) WriteVectorisedPacket(buffers []*buf.Buffer, destination M.Socksaddr) error { + header := buf.NewSize(xplusSaltLen) + salt := header.Extend(xplusSaltLen) + c.randAccess.Lock() + _, _ = c.rand.Read(salt) + c.randAccess.Unlock() + key := sha256.Sum256(append(c.key, salt...)) + var index int + for _, buffer := range buffers { + data := buffer.Bytes() + for i := range data { + data[i] ^= key[index%sha256.Size] + index++ + } + } + buffers = append([]*buf.Buffer{header}, buffers...) + return c.writer.WriteVectorisedPacket(buffers, destination) +}