From 58218599887eccaeba7b8a9da9f14bea18074dd4 Mon Sep 17 00:00:00 2001 From: Saeid Aghapour Date: Tue, 10 Dec 2024 13:16:19 +0330 Subject: [PATCH] integrate sing-box-plus' commits --- .github/workflows/build.yml | 40 -- .github/workflows/docker.yml | 6 +- common/dialer/default.go | 36 +- common/dialer/default_go1.20.go | 8 +- common/dialer/default_nongo1.20.go | 5 +- common/dialer/extended_tcp.go | 55 +++ common/dialer/extended_tcp_stub.go | 36 ++ common/dialer/fragment.go | 217 ++++++++++ common/dialer/tfo.go | 22 +- common/dialer/wireguard.go | 3 +- go.mod | 1 + go.sum | 2 + ipscanner/LICENSE | 21 + ipscanner/README.md | 38 ++ ipscanner/internal/engine/engine.go | 72 ++++ ipscanner/internal/engine/queue.go | 189 +++++++++ ipscanner/internal/iterator/iterator.go | 247 ++++++++++++ ipscanner/internal/ping/http.go | 128 ++++++ ipscanner/internal/ping/ping.go | 94 +++++ ipscanner/internal/ping/tcp.go | 84 ++++ ipscanner/internal/ping/tls.go | 80 ++++ ipscanner/internal/ping/warp.go | 304 ++++++++++++++ ipscanner/internal/statute/default.go | 173 ++++++++ ipscanner/internal/statute/ping.go | 35 ++ ipscanner/internal/statute/queue.go | 34 ++ ipscanner/internal/statute/statute.go | 66 +++ ipscanner/scanner.go | 275 +++++++++++++ ipscanner/warp_scanner.go | 80 ++++ iputils/iputils.go | 113 ++++++ option/fragment.go | 9 + option/outbound.go | 1 + option/range.go | 56 +++ warp/account.go | 131 ++++++ warp/api.go | 515 ++++++++++++++++++++++++ warp/endpoint.go | 117 ++++++ warp/key.go | 86 ++++ warp/tls.go | 145 +++++++ 37 files changed, 3447 insertions(+), 77 deletions(-) create mode 100644 common/dialer/extended_tcp.go create mode 100644 common/dialer/extended_tcp_stub.go create mode 100644 common/dialer/fragment.go create mode 100644 ipscanner/LICENSE create mode 100644 ipscanner/README.md create mode 100644 ipscanner/internal/engine/engine.go create mode 100644 ipscanner/internal/engine/queue.go create mode 100644 ipscanner/internal/iterator/iterator.go create mode 100644 ipscanner/internal/ping/http.go create mode 100644 ipscanner/internal/ping/ping.go create mode 100644 ipscanner/internal/ping/tcp.go create mode 100644 ipscanner/internal/ping/tls.go create mode 100644 ipscanner/internal/ping/warp.go create mode 100644 ipscanner/internal/statute/default.go create mode 100644 ipscanner/internal/statute/ping.go create mode 100644 ipscanner/internal/statute/queue.go create mode 100644 ipscanner/internal/statute/statute.go create mode 100644 ipscanner/scanner.go create mode 100644 ipscanner/warp_scanner.go create mode 100644 iputils/iputils.go create mode 100644 option/fragment.go create mode 100644 option/range.go create mode 100644 warp/account.go create mode 100644 warp/api.go create mode 100644 warp/endpoint.go create mode 100644 warp/key.go create mode 100644 warp/tls.go diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 8a6924d0..170342c3 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -82,9 +82,6 @@ jobs: strategy: matrix: include: - - name: linux_386 - goos: linux - goarch: 386 - name: linux_amd64 goos: linux goarch: amd64 @@ -99,46 +96,9 @@ jobs: goos: linux goarch: arm goarm: 7 - - name: linux_s390x - goos: linux - goarch: s390x - - name: linux_riscv64 - goos: linux - goarch: riscv64 - - name: linux_mips64le - goos: linux - goarch: mips64le - - name: windows_amd64 - goos: windows - goarch: amd64 - require_legacy_go: true - - name: windows_386 - goos: windows - goarch: 386 - require_legacy_go: true - - name: windows_arm64 - goos: windows - goarch: arm64 - name: darwin_arm64 goos: darwin goarch: arm64 - - name: darwin_amd64 - goos: darwin - goarch: amd64 - require_legacy_go: true - - name: android_arm64 - goos: android - goarch: arm64 - - name: android_arm - goos: android - goarch: arm - goarm: 7 - - name: android_amd64 - goos: android - goarch: amd64 - - name: android_386 - goos: android - goarch: 386 steps: - name: Checkout uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4 diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index bcf210ab..240932be 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -10,7 +10,7 @@ on: description: "The tag version you want to build" env: - REGISTRY_IMAGE: ghcr.io/sagernet/sing-box + REGISTRY_IMAGE: ghcr.io/tools4net/singbox jobs: build: @@ -23,10 +23,6 @@ jobs: - linux/arm/v6 - linux/arm/v7 - linux/arm64 - - linux/386 - - linux/ppc64le - - linux/riscv64 - - linux/s390x steps: - name: Get commit to build id: ref diff --git a/common/dialer/default.go b/common/dialer/default.go index bf553618..17aadd6c 100644 --- a/common/dialer/default.go +++ b/common/dialer/default.go @@ -129,6 +129,34 @@ func NewDefault(networkManager adapter.NetworkManager, options option.DialerOpti // TODO: Add an option to customize the keep alive period dialer.KeepAlive = C.TCPKeepAliveInitial dialer.Control = control.Append(dialer.Control, control.SetKeepAlivePeriod(C.TCPKeepAliveInitial, C.TCPKeepAliveInterval)) + if options.TLSFragment.Enabled && options.TCPFastOpen { + return nil, E.New("TLS Fragmentation is not compatible with TCP Fast Open, set `tcp_fast_open` to `false` in your outbound if you intend to enable TLS fragmentation.") + } + var tlsFragment TLSFragment + if options.TLSFragment.Enabled { + tlsFragment.Enabled = true + + sleep, err := option.ParseIntRange(options.TLSFragment.Sleep) + if err != nil { + return nil, E.Cause(err, "missing or invalid value supplied as TLS fragment `sleep` option") + } + if sleep[1] > 1000 { + return nil, E.New("invalid range supplied as TLS fragment `sleep` option! set to '0' to disable sleeps or set to range [0,1000]") + } + tlsFragment.Sleep.Min = sleep[0] + tlsFragment.Sleep.Max = sleep[1] + + size, err := option.ParseIntRange(options.TLSFragment.Size) + if err != nil { + return nil, E.Cause(err, "missing or invalid value supplied as TLS fragment `size` option") + } + if size[0] <= 0 || size[1] > 256 { + return nil, E.New("invalid range supplied as TLS fragment `size` option! valid range: [1,256]") + } + tlsFragment.Size.Min = size[0] + tlsFragment.Size.Max = size[1] + + } var udpFragment bool if options.UDPFragment != nil { udpFragment = *options.UDPFragment @@ -175,11 +203,11 @@ func NewDefault(networkManager adapter.NetworkManager, options option.DialerOpti if networkStrategy != C.NetworkStrategyDefault && options.TCPFastOpen { return nil, E.New("`tcp_fast_open` is conflict with `network_strategy` or `route.default_network_strategy`") } - tcpDialer4, err := newTCPDialer(dialer4, options.TCPFastOpen) + tcpDialer4, err := newTCPDialer(dialer4, options.TCPFastOpen, tlsFragment) if err != nil { return nil, err } - tcpDialer6, err := newTCPDialer(dialer6, options.TCPFastOpen) + tcpDialer6, err := newTCPDialer(dialer6, options.TCPFastOpen, tlsFragment) if err != nil { return nil, err } @@ -214,9 +242,9 @@ func (d *DefaultDialer) DialContext(ctx context.Context, network string, address } } if !address.IsIPv6() { - return trackConn(DialSlowContext(&d.dialer4, ctx, network, address)) + return trackConn(d.dialer4.DialContext(ctx, network, address)) } else { - return trackConn(DialSlowContext(&d.dialer6, ctx, network, address)) + return trackConn(d.dialer6.DialContext(ctx, network, address)) } } else { return d.DialParallelInterface(ctx, network, address, d.networkStrategy, d.networkType, d.fallbackNetworkType, d.networkFallbackDelay) diff --git a/common/dialer/default_go1.20.go b/common/dialer/default_go1.20.go index 9dde955f..a0cb095e 100644 --- a/common/dialer/default_go1.20.go +++ b/common/dialer/default_go1.20.go @@ -4,14 +4,12 @@ package dialer import ( "net" - - "github.com/metacubex/tfo-go" ) -type tcpDialer = tfo.Dialer +type tcpDialer = ExtendedTCPDialer -func newTCPDialer(dialer net.Dialer, tfoEnabled bool) (tcpDialer, error) { - return tfo.Dialer{Dialer: dialer, DisableTFO: !tfoEnabled}, nil +func newTCPDialer(dialer net.Dialer, tfoEnabled bool, tlsFragment TLSFragment) (tcpDialer, error) { + return tcpDialer{Dialer: dialer, DisableTFO: !tfoEnabled, TLSFragment: tlsFragment}, nil } func dialerFromTCPDialer(dialer tcpDialer) net.Dialer { diff --git a/common/dialer/default_nongo1.20.go b/common/dialer/default_nongo1.20.go index b2e4638d..8796f108 100644 --- a/common/dialer/default_nongo1.20.go +++ b/common/dialer/default_nongo1.20.go @@ -10,10 +10,13 @@ import ( type tcpDialer = net.Dialer -func newTCPDialer(dialer net.Dialer, tfoEnabled bool) (tcpDialer, error) { +func newTCPDialer(dialer net.Dialer, tfoEnabled bool, tlsFragment TLSFragment) (tcpDialer, error) { if tfoEnabled { return dialer, E.New("TCP Fast Open requires go1.20, please recompile your binary.") } + if tlsFragment.Enabled { + return tcpDialer{Dialer: dialer, DisableTFO: true, TLSFragment: tlsFragment}, nil + } return dialer, nil } diff --git a/common/dialer/extended_tcp.go b/common/dialer/extended_tcp.go new file mode 100644 index 00000000..8f827d40 --- /dev/null +++ b/common/dialer/extended_tcp.go @@ -0,0 +1,55 @@ +//go:build go1.20 + +package dialer + +import ( + "context" + "net" + + "github.com/metacubex/tfo-go" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +// Custom TCP dialer with extra features such as "TCP Fast Open" or "TLS Fragmentation" +type ExtendedTCPDialer struct { + net.Dialer + DisableTFO bool + TLSFragment TLSFragment +} + +func (d *ExtendedTCPDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + if (d.DisableTFO && !d.TLSFragment.Enabled) || N.NetworkName(network) != N.NetworkTCP { + switch N.NetworkName(network) { + case N.NetworkTCP, N.NetworkUDP: + return d.Dialer.DialContext(ctx, network, destination.String()) + default: + return d.Dialer.DialContext(ctx, network, destination.AddrString()) + } + } + // Create a fragment dialer + if d.TLSFragment.Enabled { + fragmentConn := &fragmentConn{ + dialer: d.Dialer, + fragment: d.TLSFragment, + network: network, + destination: destination, + } + conn, err := d.Dialer.DialContext(ctx, network, destination.String()) + if err != nil { + fragmentConn.err = err + return nil, err + } + fragmentConn.conn = conn + return fragmentConn, nil + } + // Create a TFO dialer + return &slowOpenConn{ + dialer: &tfo.Dialer{Dialer: d.Dialer, DisableTFO: d.DisableTFO}, + ctx: ctx, + network: network, + destination: destination, + create: make(chan struct{}), + }, + nil +} diff --git a/common/dialer/extended_tcp_stub.go b/common/dialer/extended_tcp_stub.go new file mode 100644 index 00000000..44770e2d --- /dev/null +++ b/common/dialer/extended_tcp_stub.go @@ -0,0 +1,36 @@ +//go:build !go1.20 + +package dialer + +import ( + "context" + "net" + + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func (d *ExtendedTCPDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + if !d.TLSFragment.Enabled || N.NetworkName(network) != N.NetworkTCP { + switch N.NetworkName(network) { + case N.NetworkTCP, N.NetworkUDP: + return d.Dialer.DialContext(ctx, network, destination.String()) + default: + return d.Dialer.DialContext(ctx, network, destination.AddrString()) + } + } + // Create a TLS-Fragmented dialer + fragmentConn := &fragmentConn{ + dialer: d.Dialer, + fragment: d.TLSFragment, + network: network, + destination: destination, + } + conn, err := d.Dialer.DialContext(ctx, network, destination.String()) + if err != nil { + fragmentConn.err = err + return nil, err + } + fragmentConn.conn = conn + return fragmentConn, nil +} diff --git a/common/dialer/fragment.go b/common/dialer/fragment.go new file mode 100644 index 00000000..130ed39e --- /dev/null +++ b/common/dialer/fragment.go @@ -0,0 +1,217 @@ +package dialer + +import ( + "io" + "net" + "os" + "time" + + opts "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/bufio" + M "github.com/sagernet/sing/common/metadata" +) + +type TLSFragment struct { + Enabled bool + Sleep IntRange + Size IntRange +} + +type fragmentConn struct { + conn net.Conn + err error + dialer net.Dialer + destination M.Socksaddr + network string + fragment TLSFragment +} + +type IntRange struct { + Min uint64 + Max uint64 +} + +// isClientHelloPacket checks if data resembles a TLS clientHello packet +func isClientHelloPacket(b []byte) bool { + // Check if the packet is at least 5 bytes long and the content type is 22 (TLS handshake) + if len(b) < 5 || b[0] != 22 { + return false + } + + // Check if the protocol version is TLS 1.0 or higher (0x0301 or greater) + version := uint16(b[1])<<8 | uint16(b[2]) + if version < 0x0301 { + return false + } + + // Check if the handshake message type is ClientHello (1) + if b[5] != 1 { + return false + } + + return true +} + +func (c *fragmentConn) writeFragments(b []byte) (n int, err error) { + recordLen := 5 + ((int(b[3]) << 8) | int(b[4])) + if len(b) < recordLen { // maybe already fragmented somehow + return c.conn.Write(b) + } + + var bytesWritten int + data := b[5:recordLen] + buf := make([]byte, 1024) + queue := make([]byte, 2048) + n_queue := int(opts.GetRandomIntFromRange(1, 4)) + L_queue := 0 + c_queue := 0 + for from := 0; ; { + to := from + int(opts.GetRandomIntFromRange(c.fragment.Size.Min, c.fragment.Size.Max)) + if to > len(data) { + to = len(data) + } + copy(buf[:3], b) + copy(buf[5:], data[from:to]) + l := to - from + from = to + buf[3] = byte(l >> 8) + buf[4] = byte(l) + + if c_queue < n_queue { + if l > 0 { + copy(queue[L_queue:], buf[:5+l]) + L_queue = L_queue + 5 + l + } + c_queue = c_queue + 1 + } else { + if l > 0 { + copy(queue[L_queue:], buf[:5+l]) + L_queue = L_queue + 5 + l + } + + if L_queue > 0 { + n, err := c.conn.Write(queue[:L_queue]) + if err != nil { + return 0, err + } + bytesWritten += n + if c.fragment.Sleep.Max != 0 { + time.Sleep(time.Duration(opts.GetRandomIntFromRange(c.fragment.Sleep.Min, c.fragment.Sleep.Max)) * time.Millisecond) + } + + } + + L_queue = 0 + c_queue = 0 + + } + + if from == len(data) { + if L_queue > 0 { + n, err := c.conn.Write(queue[:L_queue]) + if err != nil { + return 0, err + } + bytesWritten += n + if c.fragment.Sleep.Max != 0 { + time.Sleep(time.Duration(opts.GetRandomIntFromRange(c.fragment.Sleep.Min, c.fragment.Sleep.Max)) * time.Millisecond) + } + + } + if len(b) > recordLen { + n, err := c.conn.Write(b[recordLen:]) + if err != nil { + return recordLen + n, err + } + bytesWritten += n + } + return bytesWritten, nil + } + } +} + +func (c *fragmentConn) Write(b []byte) (n int, err error) { + if c.conn == nil { + return 0, c.err + } + + if isClientHelloPacket(b) { + return c.writeFragments(b) + } + + return c.conn.Write(b) +} + +func (c *fragmentConn) Read(b []byte) (n int, err error) { + if c.conn == nil { + return 0, c.err + } + return c.conn.Read(b) +} + +func (c *fragmentConn) Close() error { + return common.Close(c.conn) +} + +func (c *fragmentConn) LocalAddr() net.Addr { + if c.conn == nil { + return M.Socksaddr{} + } + return c.conn.LocalAddr() +} + +func (c *fragmentConn) RemoteAddr() net.Addr { + if c.conn == nil { + return M.Socksaddr{} + } + return c.conn.RemoteAddr() +} + +func (c *fragmentConn) SetDeadline(t time.Time) error { + if c.conn == nil { + return os.ErrInvalid + } + return c.conn.SetDeadline(t) +} + +func (c *fragmentConn) SetReadDeadline(t time.Time) error { + if c.conn == nil { + return os.ErrInvalid + } + return c.conn.SetReadDeadline(t) +} + +func (c *fragmentConn) SetWriteDeadline(t time.Time) error { + if c.conn == nil { + return os.ErrInvalid + } + return c.conn.SetWriteDeadline(t) +} + +func (c *fragmentConn) Upstream() any { + return c.conn +} + +func (c *fragmentConn) ReaderReplaceable() bool { + return c.conn != nil +} + +func (c *fragmentConn) WriterReplaceable() bool { + return c.conn != nil +} + +func (c *fragmentConn) LazyHeadroom() bool { + return c.conn == nil +} + +func (c *fragmentConn) NeedHandshake() bool { + return c.conn == nil +} + +func (c *fragmentConn) WriteTo(w io.Writer) (n int64, err error) { + if c.conn == nil { + return 0, c.err + } + return bufio.Copy(w, c.conn) +} diff --git a/common/dialer/tfo.go b/common/dialer/tfo.go index 9f72208d..e5f82620 100644 --- a/common/dialer/tfo.go +++ b/common/dialer/tfo.go @@ -10,13 +10,11 @@ import ( "sync" "time" + "github.com/metacubex/tfo-go" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" - - "github.com/metacubex/tfo-go" ) type slowOpenConn struct { @@ -30,24 +28,6 @@ type slowOpenConn struct { err error } -func DialSlowContext(dialer *tcpDialer, ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { - if dialer.DisableTFO || N.NetworkName(network) != N.NetworkTCP { - switch N.NetworkName(network) { - case N.NetworkTCP, N.NetworkUDP: - return dialer.Dialer.DialContext(ctx, network, destination.String()) - default: - return dialer.Dialer.DialContext(ctx, network, destination.AddrString()) - } - } - return &slowOpenConn{ - dialer: dialer, - ctx: ctx, - network: network, - destination: destination, - create: make(chan struct{}), - }, nil -} - func (c *slowOpenConn) Read(b []byte) (n int, err error) { if c.conn == nil { select { diff --git a/common/dialer/wireguard.go b/common/dialer/wireguard.go index fbd323d8..1f71e080 100644 --- a/common/dialer/wireguard.go +++ b/common/dialer/wireguard.go @@ -1,9 +1,10 @@ package dialer import ( + "github.com/sagernet/sing/common/control" "net" - "github.com/sagernet/sing/common/control" + _ "github.com/redpilllabs/wireguard-go/conn" ) type WireGuardListener interface { diff --git a/go.mod b/go.mod index fbfd60e7..1b113511 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/mholt/acmez v1.2.0 github.com/miekg/dns v1.1.62 github.com/oschwald/maxminddb-golang v1.12.0 + github.com/redpilllabs/wireguard-go v0.0.7 github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a github.com/sagernet/cloudflare-tls v0.0.0-20231208171750-a4483c1b7cd1 github.com/sagernet/cors v1.2.1 diff --git a/go.sum b/go.sum index 68cc88f1..748fa65b 100644 --- a/go.sum +++ b/go.sum @@ -88,6 +88,8 @@ github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= github.com/quic-go/qtls-go1-20 v0.4.1 h1:D33340mCNDAIKBqXuAvexTNMUByrYmFYVfKfDN5nfFs= github.com/quic-go/qtls-go1-20 v0.4.1/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= +github.com/redpilllabs/wireguard-go v0.0.7 h1:3Z/dSHMVCJl6FAeASSzCxr18fDndFYQ5KDIpXGnpDNU= +github.com/redpilllabs/wireguard-go v0.0.7/go.mod h1:TGR83JtUUguDqglsvDL6Av6DFBal8WfeF02Wb1iU/qM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a h1:+NkI2670SQpQWvkkD2QgdTuzQG263YZ+2emfpeyGqW0= github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a/go.mod h1:63s7jpZqcDAIpj8oI/1v4Izok+npJOHACFCU6+huCkM= diff --git a/ipscanner/LICENSE b/ipscanner/LICENSE new file mode 100644 index 00000000..8dea11b5 --- /dev/null +++ b/ipscanner/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Bepass + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/ipscanner/README.md b/ipscanner/README.md new file mode 100644 index 00000000..a6583540 --- /dev/null +++ b/ipscanner/README.md @@ -0,0 +1,38 @@ +# IPScanner + +IPScanner is a Go package designed for scanning and analyzing IP addresses. It utilizes various dialers and an internal engine to perform scans efficiently. + +## Features +- IPv4 and IPv6 support. +- Customizable timeout and dialer options. +- Extendable with various ping methods (HTTP, QUIC, TCP, TLS). +- Adjustable IP Queue size for scan optimization. + +## Getting Started +To use IPScanner, simply import the package and initialize a new scanner with your desired options. + +```go +import "github.com/bepass-org/warp-plus/ipscanner" + +func main() { + scanner := ipscanner.NewScanner( + // Configure your options here + ) + scanner.Run() +} +``` + +## Options +You can customize your scanner with several options: +- `WithUseIPv4` and `WithUseIPv6` to specify IP versions. +- `WithDialer` and `WithTLSDialer` to define custom dialing functions. +- `WithTimeout` to set the scan timeout. +- `WithIPQueueSize` to set the IP Queue size. +- `WithPingMethod` to set the ping method, it can be HTTP, QUIC, TCP, TLS at the same time. +- Various other options for detailed scan control. + +## Contributing +Contributions to IPScanner are welcome. Please ensure to follow the project's coding standards and submit detailed pull requests. + +## License +IPScanner is licensed under the MIT license. See [LICENSE](LICENSE) for more information. diff --git a/ipscanner/internal/engine/engine.go b/ipscanner/internal/engine/engine.go new file mode 100644 index 00000000..86615feb --- /dev/null +++ b/ipscanner/internal/engine/engine.go @@ -0,0 +1,72 @@ +package engine + +import ( + "context" + "errors" + "log/slog" + "net/netip" + + "github.com/sagernet/sing-box/ipscanner/internal/iterator" + "github.com/sagernet/sing-box/ipscanner/internal/ping" + "github.com/sagernet/sing-box/ipscanner/internal/statute" +) + +type Engine struct { + generator *iterator.IpGenerator + ipQueue *IPQueue + ping func(context.Context, netip.Addr) (statute.IPInfo, error) + log *slog.Logger +} + +func NewScannerEngine(opts *statute.ScannerOptions) *Engine { + queue := NewIPQueue(opts) + + p := ping.Ping{ + Options: opts, + } + return &Engine{ + ipQueue: queue, + ping: p.DoPing, + generator: iterator.NewIterator(opts), + log: opts.Logger, + } +} + +func (e *Engine) GetAvailableIPs(desc bool) []statute.IPInfo { + if e.ipQueue != nil { + return e.ipQueue.AvailableIPs(desc) + } + return nil +} + +func (e *Engine) Run(ctx context.Context) { + e.ipQueue.Init() + + select { + case <-ctx.Done(): + return + case <-e.ipQueue.available: + e.log.Debug("Started new scanning round") + batch, err := e.generator.NextBatch() + if err != nil { + e.log.Error("Error while generating IP: %v", err) + return + } + for _, ip := range batch { + select { + case <-ctx.Done(): + return + default: + ipInfo, err := e.ping(ctx, ip) + if err != nil { + if !errors.Is(err, context.Canceled) { + e.log.Error("ping error", "addr", ip, "error", err) + } + continue + } + e.log.Debug("ping success", "addr", ipInfo.AddrPort, "rtt", ipInfo.RTT) + e.ipQueue.Enqueue(ipInfo) + } + } + } +} diff --git a/ipscanner/internal/engine/queue.go b/ipscanner/internal/engine/queue.go new file mode 100644 index 00000000..72180e8e --- /dev/null +++ b/ipscanner/internal/engine/queue.go @@ -0,0 +1,189 @@ +package engine + +import ( + "log/slog" + "sort" + "sync" + "time" + + "github.com/sagernet/sing-box/ipscanner/internal/statute" +) + +type IPQueue struct { + queue []statute.IPInfo + maxQueueSize int + mu sync.Mutex + available chan struct{} + maxTTL time.Duration + rttThreshold time.Duration + inIdealMode bool + log *slog.Logger + reserved statute.IPInfQueue +} + +func NewIPQueue(opts *statute.ScannerOptions) *IPQueue { + var reserved statute.IPInfQueue + return &IPQueue{ + queue: make([]statute.IPInfo, 0), + maxQueueSize: opts.IPQueueSize, + maxTTL: opts.IPQueueTTL, + rttThreshold: opts.MaxDesirableRTT, + available: make(chan struct{}, opts.IPQueueSize), + log: opts.Logger, + reserved: reserved, + } +} + +func (q *IPQueue) Enqueue(info statute.IPInfo) bool { + q.mu.Lock() + defer q.mu.Unlock() + + defer func() { + q.log.Debug("queue change", "len", len(q.queue)) + for _, ipInfo := range q.queue { + q.log.Debug( + "queue change", + "created", ipInfo.CreatedAt, + "addr", ipInfo.AddrPort, + "rtt", ipInfo.RTT, + ) + } + }() + + q.log.Debug("Enqueue: Sorting queue by RTT") + sort.Slice(q.queue, func(i, j int) bool { + return q.queue[i].RTT < q.queue[j].RTT + }) + + if len(q.queue) == 0 { + q.log.Debug("Enqueue: empty queue adding first available item") + q.queue = append(q.queue, info) + return false + } + + if info.RTT <= q.rttThreshold { + q.log.Debug("Enqueue: the new item's RTT is less than at least one of the members.") + if len(q.queue) >= q.maxQueueSize && info.RTT < q.queue[len(q.queue)-1].RTT { + q.log.Debug("Enqueue: the queue is full, remove the item with the highest RTT.") + q.queue = q.queue[:len(q.queue)-1] + } else if len(q.queue) < q.maxQueueSize { + q.log.Debug("Enqueue: Insert the new item in a sorted position.") + index := sort.Search(len(q.queue), func(i int) bool { return q.queue[i].RTT > info.RTT }) + q.queue = append(q.queue[:index], append([]statute.IPInfo{info}, q.queue[index:]...)...) + } else { + q.log.Debug("Enqueue: The Queue is full but we keep the new item in the reserved queue.") + q.reserved.Enqueue(info) + } + } + + q.log.Debug("Enqueue: Checking if any member has a higher RTT than the threshold.") + for _, member := range q.queue { + if member.RTT > q.rttThreshold { + return false // If any member has a higher RTT than the threshold, return false. + } + } + + q.log.Debug("Enqueue: All members have an RTT lower than the threshold.") + if len(q.queue) < q.maxQueueSize { + // the queue isn't full dont wait + return false + } + + q.inIdealMode = true + // ok wait for expiration signal + q.log.Debug("Enqueue: All members have an RTT lower than the threshold. Waiting for expiration signal.") + return true +} + +func (q *IPQueue) Dequeue() (statute.IPInfo, bool) { + defer func() { + q.log.Debug("queue change", "len", len(q.queue)) + for _, ipInfo := range q.queue { + q.log.Debug( + "queue change", + "created", ipInfo.CreatedAt, + "addr", ipInfo.AddrPort, + "rtt", ipInfo.RTT, + ) + } + }() + q.mu.Lock() + defer q.mu.Unlock() + + if len(q.queue) == 0 { + return statute.IPInfo{}, false + } + + info := q.queue[len(q.queue)-1] + q.queue = q.queue[0 : len(q.queue)-1] + + q.available <- struct{}{} + + return info, true +} + +func (q *IPQueue) Init() { + q.mu.Lock() + defer q.mu.Unlock() + + if !q.inIdealMode { + q.available <- struct{}{} + return + } +} + +func (q *IPQueue) Expire() { + q.mu.Lock() + defer q.mu.Unlock() + + q.log.Debug("Expire: In ideal mode") + defer func() { + q.log.Debug("queue change", "len", len(q.queue)) + for _, ipInfo := range q.queue { + q.log.Debug( + "queue change", + "created", ipInfo.CreatedAt, + "addr", ipInfo.AddrPort, + "rtt", ipInfo.RTT, + ) + } + }() + + shouldStartNewScan := false + resQ := make([]statute.IPInfo, 0) + for i := 0; i < len(q.queue); i++ { + if time.Since(q.queue[i].CreatedAt) > q.maxTTL { + q.log.Debug("Expire: Removing expired item from queue") + shouldStartNewScan = true + } else { + resQ = append(resQ, q.queue[i]) + } + } + q.queue = resQ + q.log.Debug("Expire: Adding reserved items to queue") + for i := 0; i < q.maxQueueSize && i < q.reserved.Size(); i++ { + q.queue = append(q.queue, q.reserved.Dequeue()) + } + if shouldStartNewScan { + q.available <- struct{}{} + } +} + +func (q *IPQueue) AvailableIPs(desc bool) []statute.IPInfo { + q.mu.Lock() + defer q.mu.Unlock() + + // Create a separate slice for sorting + sortedQueue := make([]statute.IPInfo, len(q.queue)) + copy(sortedQueue, q.queue) + + // Sort by RTT ascending/descending + sort.Slice(sortedQueue, func(i, j int) bool { + if desc { + return sortedQueue[i].RTT > sortedQueue[j].RTT + } + return sortedQueue[i].RTT < sortedQueue[j].RTT + }) + + return sortedQueue +} diff --git a/ipscanner/internal/iterator/iterator.go b/ipscanner/internal/iterator/iterator.go new file mode 100644 index 00000000..e2cf3481 --- /dev/null +++ b/ipscanner/internal/iterator/iterator.go @@ -0,0 +1,247 @@ +package iterator + +import ( + "crypto/rand" + "errors" + "math/big" + "net" + "net/netip" + + "github.com/sagernet/sing-box/ipscanner/internal/statute" +) + +// LCG represents a linear congruential generator with full period. +type LCG struct { + modulus *big.Int + multiplier *big.Int + increment *big.Int + current *big.Int +} + +// NewLCG creates a new LCG instance with a given size. +func NewLCG(size *big.Int) *LCG { + modulus := new(big.Int).Set(size) + + // Generate random multiplier (a) and increment (c) that satisfy Hull-Dobell Theorem + var multiplier, increment *big.Int + for { + var err error + multiplier, err = rand.Int(rand.Reader, modulus) + if err != nil { + continue + } + increment, err = rand.Int(rand.Reader, modulus) + if err != nil { + continue + } + + // Check Hull-Dobell Theorem conditions + if checkHullDobell(modulus, multiplier, increment) { + break + } + } + + return &LCG{ + modulus: modulus, + multiplier: multiplier, + increment: increment, + current: big.NewInt(0), + } +} + +// checkHullDobell checks if the given parameters satisfy the Hull-Dobell Theorem. +func checkHullDobell(modulus, multiplier, increment *big.Int) bool { + // c and m are relatively prime + gcd := new(big.Int).GCD(nil, nil, increment, modulus) + if gcd.Cmp(big.NewInt(1)) != 0 { + return false + } + + // a - 1 is divisible by all prime factors of m + aMinusOne := new(big.Int).Sub(multiplier, big.NewInt(1)) + + // a - 1 is divisible by 4 if m is divisible by 4 + if new(big.Int).And(modulus, big.NewInt(3)).Cmp(big.NewInt(0)) == 0 { + if new(big.Int).And(aMinusOne, big.NewInt(3)).Cmp(big.NewInt(0)) != 0 { + return false + } + } + + return true +} + +// Next generates the next number in the sequence. +func (lcg *LCG) Next() *big.Int { + if lcg.current.Cmp(lcg.modulus) == 0 { + return nil // Sequence complete + } + + next := new(big.Int) + next.Mul(lcg.multiplier, lcg.current) + next.Add(next, lcg.increment) + next.Mod(next, lcg.modulus) + + lcg.current.Set(next) + return next +} + +type ipRange struct { + lcg *LCG + start netip.Addr + stop netip.Addr + size *big.Int + index *big.Int +} + +func newIPRange(cidr netip.Prefix) (ipRange, error) { + startIP := cidr.Addr() + stopIP := lastIP(cidr) + size := ipRangeSize(cidr) + return ipRange{ + start: startIP, + stop: stopIP, + size: size, + index: big.NewInt(0), + lcg: NewLCG(size), + }, nil +} + +func lastIP(prefix netip.Prefix) netip.Addr { + // Calculate the number of bits to fill for the last address based on the address family + fillBits := 128 - prefix.Bits() + if prefix.Addr().Is4() { + fillBits = 32 - prefix.Bits() + } + + // Calculate the numerical representation of the last address by setting the remaining bits to 1 + var lastAddrInt big.Int + lastAddrInt.SetBytes(prefix.Addr().AsSlice()) + for i := 0; i < fillBits; i++ { + lastAddrInt.SetBit(&lastAddrInt, i, 1) + } + + // Convert the big.Int back to netip.Addr + lastAddrBytes := lastAddrInt.Bytes() + var lastAddr netip.Addr + if prefix.Addr().Is4() { + // Ensure the slice is the right length for IPv4 + if len(lastAddrBytes) < net.IPv4len { + leadingZeros := make([]byte, net.IPv4len-len(lastAddrBytes)) + lastAddrBytes = append(leadingZeros, lastAddrBytes...) + } + lastAddr, _ = netip.AddrFromSlice(lastAddrBytes[len(lastAddrBytes)-net.IPv4len:]) + } else { + // Ensure the slice is the right length for IPv6 + if len(lastAddrBytes) < net.IPv6len { + leadingZeros := make([]byte, net.IPv6len-len(lastAddrBytes)) + lastAddrBytes = append(leadingZeros, lastAddrBytes...) + } + lastAddr, _ = netip.AddrFromSlice(lastAddrBytes) + } + + return lastAddr +} + +func addIP(ip netip.Addr, num *big.Int) netip.Addr { + addrAs16 := ip.As16() + ipInt := new(big.Int).SetBytes(addrAs16[:]) + ipInt.Add(ipInt, num) + addr, _ := netip.AddrFromSlice(ipInt.FillBytes(make([]byte, 16))) + return addr.Unmap() +} + +func ipRangeSize(prefix netip.Prefix) *big.Int { + // The number of bits in the address depends on whether it's IPv4 or IPv6. + totalBits := 128 // Assume IPv6 by default + if prefix.Addr().Is4() { + totalBits = 32 // Adjust for IPv4 + } + + // Calculate the size of the range + bits := prefix.Bits() // This is the prefix length + size := big.NewInt(1) + size.Lsh(size, uint(totalBits-bits)) // Left shift to calculate the range size + + return size +} + +type IpGenerator struct { + ipRanges []ipRange +} + +func (g *IpGenerator) NextBatch() ([]netip.Addr, error) { + var results []netip.Addr + for i, r := range g.ipRanges { + if r.index.Cmp(r.size) >= 0 { + continue + } + shuffleIndex := r.lcg.Next() + if shuffleIndex == nil { + continue + } + results = append(results, addIP(r.start, shuffleIndex)) + g.ipRanges[i].index.Add(g.ipRanges[i].index, big.NewInt(1)) + } + if len(results) == 0 { + okFlag := false + for i := range g.ipRanges { + if g.ipRanges[i].index.Cmp(big.NewInt(0)) > 0 { + okFlag = true + } + g.ipRanges[i].index.SetInt64(0) + } + if okFlag { + // Reshuffle and start over + for i := range g.ipRanges { + g.ipRanges[i].lcg = NewLCG(g.ipRanges[i].size) + } + return g.NextBatch() + } else { + return nil, errors.New("no more IP addresses") + } + } + return results, nil +} + +// shuffleSubnetsIpRange shuffles a slice of ipRange using crypto/rand +func shuffleSubnetsIpRange(subnets []ipRange) error { + for i := range subnets { + jBig, err := rand.Int(rand.Reader, big.NewInt(int64(len(subnets)))) + if err != nil { + return err + } + j := jBig.Int64() + + subnets[i], subnets[j] = subnets[j], subnets[i] + } + return nil +} + +func NewIterator(opts *statute.ScannerOptions) *IpGenerator { + var ranges []ipRange + for _, cidr := range opts.CidrList { + if !opts.UseIPv6 && cidr.Addr().Is6() { + continue + } + if !opts.UseIPv4 && cidr.Addr().Is4() { + continue + } + + ipRange, err := newIPRange(cidr) + if err != nil { + // TODO + continue + } + ranges = append(ranges, ipRange) + } + if len(ranges) == 0 { + // TODO + return nil + } + err := shuffleSubnetsIpRange(ranges) + if err != nil { + // TODO + return nil + } + return &IpGenerator{ipRanges: ranges} +} diff --git a/ipscanner/internal/ping/http.go b/ipscanner/internal/ping/http.go new file mode 100644 index 00000000..be1b6ec4 --- /dev/null +++ b/ipscanner/internal/ping/http.go @@ -0,0 +1,128 @@ +package ping + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "net/netip" + "net/url" + "time" + + "github.com/sagernet/sing-box/ipscanner/internal/statute" +) + +type HttpPingResult struct { + AddrPort netip.AddrPort + Proto string + Status int + Length int + RTT time.Duration + Err error +} + +func (h *HttpPingResult) Result() statute.IPInfo { + return statute.IPInfo{AddrPort: h.AddrPort, RTT: h.RTT, CreatedAt: time.Now()} +} + +func (h *HttpPingResult) Error() error { + return h.Err +} + +func (h *HttpPingResult) String() string { + if h.Err != nil { + return fmt.Sprintf("%s", h.Err) + } + + return fmt.Sprintf("%s: protocol=%s, status=%d, length=%d, time=%d ms", h.AddrPort, h.Proto, h.Status, h.Length, h.RTT) +} + +type HttpPing struct { + Method string + URL string + IP netip.Addr + + opts statute.ScannerOptions +} + +func (h *HttpPing) Ping() statute.IPingResult { + return h.PingContext(context.Background()) +} + +func (h *HttpPing) PingContext(ctx context.Context) statute.IPingResult { + u, err := url.Parse(h.URL) + if err != nil { + return h.errorResult(err) + } + orighost := u.Host + + if !h.IP.IsValid() { + return h.errorResult(errors.New("no IP specified")) + } + + req, err := http.NewRequestWithContext(ctx, h.Method, h.URL, nil) + if err != nil { + return h.errorResult(err) + } + ua := "httping" + if h.opts.UserAgent != "" { + ua = h.opts.UserAgent + } + req.Header.Set("User-Agent", ua) + if h.opts.Referrer != "" { + req.Header.Set("Referer", h.opts.Referrer) + } + req.Host = orighost + + addr := netip.AddrPortFrom(h.IP, h.opts.Port) + client := h.opts.HttpClientFunc(h.opts.RawDialerFunc, h.opts.TLSDialerFunc, h.opts.QuicDialerFunc, addr.String()) + + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + + t0 := time.Now() + resp, err := client.Do(req) + if err != nil { + return h.errorResult(err) + } + + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return h.errorResult(err) + } + + res := HttpPingResult{ + AddrPort: addr, + Proto: resp.Proto, + Status: resp.StatusCode, + Length: len(body), + RTT: time.Since(t0), + Err: nil, + } + + return &res +} + +func (h *HttpPing) errorResult(err error) *HttpPingResult { + r := &HttpPingResult{} + r.Err = err + return r +} + +func NewHttpPing(ip netip.Addr, method, url string, opts *statute.ScannerOptions) *HttpPing { + return &HttpPing{ + IP: ip, + Method: method, + URL: url, + + opts: *opts, + } +} + +var ( + _ statute.IPing = (*HttpPing)(nil) + _ statute.IPingResult = (*HttpPingResult)(nil) +) diff --git a/ipscanner/internal/ping/ping.go b/ipscanner/internal/ping/ping.go new file mode 100644 index 00000000..8b3ab6ac --- /dev/null +++ b/ipscanner/internal/ping/ping.go @@ -0,0 +1,94 @@ +package ping + +import ( + "context" + "errors" + "fmt" + "net/netip" + + "github.com/sagernet/sing-box/ipscanner/internal/statute" +) + +type Ping struct { + Options *statute.ScannerOptions +} + +// DoPing performs a ping on the given IP address. +func (p *Ping) DoPing(ctx context.Context, ip netip.Addr) (statute.IPInfo, error) { + if p.Options.SelectedOps&statute.HTTPPing > 0 { + res, err := p.httpPing(ctx, ip) + if err != nil { + return statute.IPInfo{}, err + } + + return res, nil + } + if p.Options.SelectedOps&statute.TLSPing > 0 { + res, err := p.tlsPing(ctx, ip) + if err != nil { + return statute.IPInfo{}, err + } + + return res, nil + } + if p.Options.SelectedOps&statute.TCPPing > 0 { + res, err := p.tcpPing(ctx, ip) + if err != nil { + return statute.IPInfo{}, err + } + + return res, nil + } + if p.Options.SelectedOps&statute.WARPPing > 0 { + res, err := p.warpPing(ctx, ip) + if err != nil { + return statute.IPInfo{}, err + } + + return res, nil + } + + return statute.IPInfo{}, errors.New("no ping operation selected") +} + +func (p *Ping) httpPing(ctx context.Context, ip netip.Addr) (statute.IPInfo, error) { + return p.calc( + ctx, + NewHttpPing( + ip, + "GET", + fmt.Sprintf( + "https://%s:%d%s", + p.Options.Hostname, + p.Options.Port, + p.Options.HTTPPath, + ), + p.Options, + ), + ) +} + +func (p *Ping) warpPing(ctx context.Context, ip netip.Addr) (statute.IPInfo, error) { + return p.calc(ctx, NewWarpPing(ip, p.Options)) +} + +func (p *Ping) tlsPing(ctx context.Context, ip netip.Addr) (statute.IPInfo, error) { + return p.calc(ctx, + NewTlsPing(ip, p.Options.Hostname, p.Options.Port, p.Options), + ) +} + +func (p *Ping) tcpPing(ctx context.Context, ip netip.Addr) (statute.IPInfo, error) { + return p.calc(ctx, + NewTcpPing(ip, p.Options.Hostname, p.Options.Port, p.Options), + ) +} + +func (p *Ping) calc(ctx context.Context, tp statute.IPing) (statute.IPInfo, error) { + pr := tp.PingContext(ctx) + err := pr.Error() + if err != nil { + return statute.IPInfo{}, err + } + return pr.Result(), nil +} diff --git a/ipscanner/internal/ping/tcp.go b/ipscanner/internal/ping/tcp.go new file mode 100644 index 00000000..d5694d7f --- /dev/null +++ b/ipscanner/internal/ping/tcp.go @@ -0,0 +1,84 @@ +package ping + +import ( + "context" + "errors" + "fmt" + "net/netip" + "time" + + "github.com/sagernet/sing-box/ipscanner/internal/statute" +) + +type TcpPingResult struct { + AddrPort netip.AddrPort + RTT time.Duration + Err error +} + +func (tp *TcpPingResult) Result() statute.IPInfo { + return statute.IPInfo{AddrPort: tp.AddrPort, RTT: tp.RTT, CreatedAt: time.Now()} +} + +func (tp *TcpPingResult) Error() error { + return tp.Err +} + +func (tp *TcpPingResult) String() string { + if tp.Err != nil { + return fmt.Sprintf("%s", tp.Err) + } else { + return fmt.Sprintf("%s: time=%d ms", tp.AddrPort, tp.RTT) + } +} + +type TcpPing struct { + host string + port uint16 + ip netip.Addr + + opts statute.ScannerOptions +} + +func (tp *TcpPing) SetHost(host string) { + tp.host = host + tp.ip, _ = netip.ParseAddr(host) +} + +func (tp *TcpPing) Host() string { + return tp.host +} + +func (tp *TcpPing) Ping() statute.IPingResult { + return tp.PingContext(context.Background()) +} + +func (tp *TcpPing) PingContext(ctx context.Context) statute.IPingResult { + if !tp.ip.IsValid() { + return &TcpPingResult{AddrPort: netip.AddrPort{}, RTT: 0, Err: errors.New("no IP specified")} + } + + addr := netip.AddrPortFrom(tp.ip, tp.port) + t0 := time.Now() + conn, err := tp.opts.RawDialerFunc(ctx, "tcp", addr.String()) + if err != nil { + return &TcpPingResult{AddrPort: addr, RTT: 0, Err: err} + } + defer conn.Close() + + return &TcpPingResult{AddrPort: addr, RTT: time.Since(t0), Err: nil} +} + +func NewTcpPing(ip netip.Addr, host string, port uint16, opts *statute.ScannerOptions) *TcpPing { + return &TcpPing{ + host: host, + port: port, + ip: ip, + opts: *opts, + } +} + +var ( + _ statute.IPing = (*TcpPing)(nil) + _ statute.IPingResult = (*TcpPingResult)(nil) +) diff --git a/ipscanner/internal/ping/tls.go b/ipscanner/internal/ping/tls.go new file mode 100644 index 00000000..93f6d7e6 --- /dev/null +++ b/ipscanner/internal/ping/tls.go @@ -0,0 +1,80 @@ +package ping + +import ( + "context" + "errors" + "fmt" + "net/netip" + "time" + + "github.com/sagernet/sing-box/ipscanner/internal/statute" +) + +type TlsPingResult struct { + AddrPort netip.AddrPort + TLSVersion uint16 + RTT time.Duration + Err error +} + +func (t *TlsPingResult) Result() statute.IPInfo { + return statute.IPInfo{AddrPort: t.AddrPort, RTT: t.RTT, CreatedAt: time.Now()} +} + +func (t *TlsPingResult) Error() error { + return t.Err +} + +func (t *TlsPingResult) String() string { + if t.Err != nil { + return fmt.Sprintf("%s", t.Err) + } + + return fmt.Sprintf("%s: protocol=%s, time=%d ms", t.AddrPort, statute.TlsVersionToString(t.TLSVersion), t.RTT) +} + +type TlsPing struct { + Host string + Port uint16 + IP netip.Addr + + opts *statute.ScannerOptions +} + +func (t *TlsPing) Ping() statute.IPingResult { + return t.PingContext(context.Background()) +} + +func (t *TlsPing) PingContext(ctx context.Context) statute.IPingResult { + if !t.IP.IsValid() { + return t.errorResult(errors.New("no IP specified")) + } + addr := netip.AddrPortFrom(t.IP, t.Port) + t0 := time.Now() + client, err := t.opts.TLSDialerFunc(ctx, "tcp", addr.String()) + if err != nil { + return t.errorResult(err) + } + defer client.Close() + return &TlsPingResult{AddrPort: addr, TLSVersion: t.opts.TlsVersion, RTT: time.Since(t0), Err: nil} +} + +func NewTlsPing(ip netip.Addr, host string, port uint16, opts *statute.ScannerOptions) *TlsPing { + return &TlsPing{ + IP: ip, + Host: host, + Port: port, + opts: opts, + } +} + +func (t *TlsPing) errorResult(err error) *TlsPingResult { + r := &TlsPingResult{} + r.Err = err + return r +} + +var ( + _ statute.IPing = (*TlsPing)(nil) + _ statute.IPingResult = (*TlsPingResult)(nil) +) diff --git a/ipscanner/internal/ping/warp.go b/ipscanner/internal/ping/warp.go new file mode 100644 index 00000000..e9209cdb --- /dev/null +++ b/ipscanner/internal/ping/warp.go @@ -0,0 +1,304 @@ +package ping + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/base64" + "encoding/binary" + "errors" + "fmt" + "math/big" + "net" + "net/netip" + "time" + + "github.com/flynn/noise" + "github.com/sagernet/sing-box/ipscanner/internal/statute" + "github.com/sagernet/sing-box/warp" + "golang.org/x/crypto/blake2s" + "golang.org/x/crypto/curve25519" +) + +type WarpPingResult struct { + AddrPort netip.AddrPort + RTT time.Duration + Err error +} + +func (h *WarpPingResult) Result() statute.IPInfo { + return statute.IPInfo{AddrPort: h.AddrPort, RTT: h.RTT, CreatedAt: time.Now()} +} + +func (h *WarpPingResult) Error() error { + return h.Err +} + +func (h *WarpPingResult) String() string { + if h.Err != nil { + return fmt.Sprintf("%s", h.Err) + } else { + return fmt.Sprintf("%s: protocol=%s, time=%d ms", h.AddrPort, "warp", h.RTT) + } +} + +type WarpPing struct { + PrivateKey string + PeerPublicKey string + PresharedKey string + IP netip.Addr + + opts statute.ScannerOptions +} + +func (h *WarpPing) Ping() statute.IPingResult { + return h.PingContext(context.Background()) +} + +func (h *WarpPing) PingContext(ctx context.Context) statute.IPingResult { + var port uint16 = 0 + if h.opts.Port == 0 || h.opts.Port == 443 { + port = warp.RandomWarpPort() + } else { + port = h.opts.Port + } + addr := netip.AddrPortFrom(h.IP, port) + rtt, err := initiateHandshake( + ctx, + addr, + h.PrivateKey, + h.PeerPublicKey, + h.PresharedKey, + ) + if err != nil { + return h.errorResult(err) + } + + return &WarpPingResult{AddrPort: addr, RTT: rtt, Err: nil} +} + +func (h *WarpPing) errorResult(err error) *WarpPingResult { + r := &WarpPingResult{} + r.Err = err + return r +} + +func uint32ToBytes(n uint32) []byte { + b := make([]byte, 4) + binary.LittleEndian.PutUint32(b, n) + return b +} + +func staticKeypair(privateKeyBase64 string) (noise.DHKey, error) { + privateKey, err := base64.StdEncoding.DecodeString(privateKeyBase64) + if err != nil { + return noise.DHKey{}, err + } + + var pubkey, privkey [32]byte + copy(privkey[:], privateKey) + curve25519.ScalarBaseMult(&pubkey, &privkey) + + return noise.DHKey{ + Private: privateKey, + Public: pubkey[:], + }, nil +} + +func ephemeralKeypair() (noise.DHKey, error) { + // Generate an ephemeral private key + ephemeralPrivateKey := make([]byte, 32) + if _, err := rand.Read(ephemeralPrivateKey); err != nil { + return noise.DHKey{}, err + } + + // Derive the corresponding ephemeral public key + ephemeralPublicKey, err := curve25519.X25519(ephemeralPrivateKey, curve25519.Basepoint) + if err != nil { + return noise.DHKey{}, err + } + + return noise.DHKey{ + Private: ephemeralPrivateKey, + Public: ephemeralPublicKey, + }, nil +} + +func randomInt(min, max uint64) uint64 { + rangee := max - min + if rangee < 1 { + return 0 + } + + n, err := rand.Int(rand.Reader, big.NewInt(int64(rangee))) + if err != nil { + panic(err) + } + + return min + n.Uint64() +} + +func initiateHandshake(ctx context.Context, serverAddr netip.AddrPort, privateKeyBase64, peerPublicKeyBase64, presharedKeyBase64 string) (time.Duration, error) { + staticKeyPair, err := staticKeypair(privateKeyBase64) + if err != nil { + return 0, err + } + + peerPublicKey, err := base64.StdEncoding.DecodeString(peerPublicKeyBase64) + if err != nil { + return 0, err + } + + presharedKey, err := base64.StdEncoding.DecodeString(presharedKeyBase64) + if err != nil { + return 0, err + } + + if presharedKeyBase64 == "" { + presharedKey = make([]byte, 32) + } + + ephemeral, err := ephemeralKeypair() + if err != nil { + return 0, err + } + + cs := noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashBLAKE2s) + hs, err := noise.NewHandshakeState(noise.Config{ + CipherSuite: cs, + Pattern: noise.HandshakeIK, + Initiator: true, + StaticKeypair: staticKeyPair, + PeerStatic: peerPublicKey, + Prologue: []byte("WireGuard v1 zx2c4 Jason@zx2c4.com"), + PresharedKey: presharedKey, + PresharedKeyPlacement: 2, + EphemeralKeypair: ephemeral, + Random: rand.Reader, + }) + if err != nil { + return 0, err + } + + // Prepare handshake initiation packet + + // TAI64N timestamp calculation + now := time.Now().UTC() + epochOffset := int64(4611686018427387914) // TAI offset from Unix epoch + + tai64nTimestampBuf := make([]byte, 0, 16) + tai64nTimestampBuf = binary.BigEndian.AppendUint64(tai64nTimestampBuf, uint64(epochOffset+now.Unix())) + tai64nTimestampBuf = binary.BigEndian.AppendUint32(tai64nTimestampBuf, uint32(now.Nanosecond())) + msg, _, _, err := hs.WriteMessage(nil, tai64nTimestampBuf) + if err != nil { + return 0, err + } + + initiationPacket := new(bytes.Buffer) + binary.Write(initiationPacket, binary.BigEndian, []byte{0x01, 0x00, 0x00, 0x00}) + binary.Write(initiationPacket, binary.BigEndian, uint32ToBytes(28)) + binary.Write(initiationPacket, binary.BigEndian, msg) + + macKey := blake2s.Sum256(append([]byte("mac1----"), peerPublicKey...)) + hasher, err := blake2s.New128(macKey[:]) // using macKey as the key + if err != nil { + return 0, err + } + _, err = hasher.Write(initiationPacket.Bytes()) + if err != nil { + return 0, err + } + initiationPacketMAC := hasher.Sum(nil) + + // Append the MAC and 16 null bytes to the initiation packet + binary.Write(initiationPacket, binary.BigEndian, initiationPacketMAC[:16]) + binary.Write(initiationPacket, binary.BigEndian, [16]byte{}) + + conn, err := net.Dial("udp", serverAddr.String()) + if err != nil { + return 0, err + } + defer conn.Close() + + numPackets := randomInt(8, 15) + randomPacket := make([]byte, 100) + for i := uint64(0); i < numPackets; i++ { + select { + case <-ctx.Done(): + return 0, ctx.Err() + default: + packetSize := randomInt(40, 100) + _, err := rand.Read(randomPacket[:packetSize]) + if err != nil { + return 0, fmt.Errorf("error generating random packet: %w", err) + } + + _, err = conn.Write(randomPacket[:packetSize]) + if err != nil { + return 0, fmt.Errorf("error sending random packet: %w", err) + } + + time.Sleep(time.Duration(randomInt(20, 250)) * time.Millisecond) + } + } + + _, err = initiationPacket.WriteTo(conn) + if err != nil { + return 0, err + } + t0 := time.Now() + + response := make([]byte, 92) + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + i, err := conn.Read(response) + if err != nil { + return 0, err + } + rtt := time.Since(t0) + + if i < 60 { + return 0, fmt.Errorf("invalid handshake response length %d bytes", i) + } + + // Check the response type + if response[0] != 2 { // 2 is the message type for response + return 0, errors.New("invalid response type") + } + + // Extract sender and receiver index from the response + // peer index + _ = binary.LittleEndian.Uint32(response[4:8]) + // our index(we set it to 28) + ourIndex := binary.LittleEndian.Uint32(response[8:12]) + if ourIndex != 28 { // Check if the response corresponds to our sender index + return 0, errors.New("invalid sender index in response") + } + + payload, _, _, err := hs.ReadMessage(nil, response[12:60]) + if err != nil { + return 0, err + } + + // Check if the payload is empty (as expected in WireGuard handshake) + if len(payload) != 0 { + return 0, errors.New("unexpected payload in response") + } + + return rtt, nil +} + +func NewWarpPing(ip netip.Addr, opts *statute.ScannerOptions) *WarpPing { + return &WarpPing{ + PrivateKey: opts.WarpPrivateKey, + PeerPublicKey: opts.WarpPeerPublicKey, + PresharedKey: opts.WarpPresharedKey, + IP: ip, + + opts: *opts, + } +} + +var ( + _ statute.IPing = (*WarpPing)(nil) + _ statute.IPingResult = (*WarpPingResult)(nil) +) diff --git a/ipscanner/internal/statute/default.go b/ipscanner/internal/statute/default.go new file mode 100644 index 00000000..690e12bb --- /dev/null +++ b/ipscanner/internal/statute/default.go @@ -0,0 +1,173 @@ +package statute + +import ( + "context" + "crypto/tls" + "net" + "net/http" + "net/netip" + "time" + + "github.com/noql-net/certpool" + "github.com/sagernet/quic-go" +) + +var FinalOptions *ScannerOptions + +func DefaultHTTPClientFunc(rawDialer TDialerFunc, tlsDialer TDialerFunc, quicDialer TQuicDialerFunc, targetAddr ...string) *http.Client { + var defaultDialer TDialerFunc + if rawDialer == nil { + defaultDialer = DefaultDialerFunc + } else { + defaultDialer = rawDialer + } + var defaultTLSDialer TDialerFunc + if rawDialer == nil { + defaultTLSDialer = DefaultTLSDialerFunc + } else { + defaultTLSDialer = tlsDialer + } + + transport := &http.Transport{ + DialContext: defaultDialer, + DialTLSContext: defaultTLSDialer, + ForceAttemptHTTP2: FinalOptions.UseHTTP2, + DisableCompression: FinalOptions.DisableCompression, + MaxIdleConnsPerHost: -1, + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: FinalOptions.InsecureSkipVerify, + ServerName: FinalOptions.Hostname, + }, + } + + return &http.Client{ + Transport: transport, + Timeout: FinalOptions.ConnectionTimeout, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } +} + +func DefaultDialerFunc(ctx context.Context, network, addr string) (net.Conn, error) { + d := &net.Dialer{ + Timeout: FinalOptions.ConnectionTimeout, // Connection timeout + // Add other custom settings as needed + } + return d.DialContext(ctx, network, addr) +} + +func getServerName(address string) (string, error) { + host, _, err := net.SplitHostPort(address) + if err != nil { + return "", err // handle the error properly in your real application + } + return host, nil +} + +func defaultTLSConfig(addr string) *tls.Config { + allowInsecure := false + sni, err := getServerName(addr) + if err != nil { + allowInsecure = true + } + + if FinalOptions.Hostname != "" { + sni = FinalOptions.Hostname + } + + alpnProtocols := []string{"http/1.1"} + + // Add protocols based on flags + if FinalOptions.UseHTTP3 { + alpnProtocols = []string{"http/1.1"} // ALPN token for HTTP/3 + } + if FinalOptions.UseHTTP2 { + alpnProtocols = []string{"h2", "http/1.1"} // ALPN token for HTTP/2 + } + + // Initiate a TLS handshake over the connection + return &tls.Config{ + InsecureSkipVerify: allowInsecure || FinalOptions.InsecureSkipVerify, + ServerName: sni, + MinVersion: FinalOptions.TlsVersion, + MaxVersion: FinalOptions.TlsVersion, + NextProtos: alpnProtocols, + RootCAs: certpool.Roots(), + } +} + +// DefaultTLSDialerFunc is a custom TLS dialer function +func DefaultTLSDialerFunc(ctx context.Context, network, addr string) (net.Conn, error) { + // Dial the raw connection using the default dialer + rawConn, err := DefaultDialerFunc(ctx, network, addr) + if err != nil { + return nil, err + } + + // Ensure the raw connection is closed in case of an error after this point + defer func() { + if err != nil { + _ = rawConn.Close() + } + }() + + // Prepare the TLS client connection + tlsClientConn := tls.Client(rawConn, defaultTLSConfig(addr)) + + // Perform the handshake with a timeout + err = tlsClientConn.SetDeadline(time.Now().Add(FinalOptions.HandshakeTimeout)) + if err != nil { + return nil, err + } + + err = tlsClientConn.Handshake() + if err != nil { + return nil, err // rawConn will be closed by the deferred function + } + + // Reset the deadline for future I/O operations + err = tlsClientConn.SetDeadline(time.Time{}) + if err != nil { + return nil, err + } + + // Return the established TLS connection + // Cancel the deferred closure of rawConn since everything succeeded + err = nil + return tlsClientConn, nil +} + +func DefaultQuicDialerFunc(ctx context.Context, addr string, _ *tls.Config, _ *quic.Config) (quic.EarlyConnection, error) { + quicConfig := &quic.Config{ + MaxIdleTimeout: FinalOptions.ConnectionTimeout, + HandshakeIdleTimeout: FinalOptions.HandshakeTimeout, + } + return quic.DialAddrEarly(ctx, addr, defaultTLSConfig(addr), quicConfig) +} + +func DefaultCFRanges() []netip.Prefix { + return []netip.Prefix{ + netip.MustParsePrefix("103.21.244.0/22"), + netip.MustParsePrefix("103.22.200.0/22"), + netip.MustParsePrefix("103.31.4.0/22"), + netip.MustParsePrefix("104.16.0.0/12"), + netip.MustParsePrefix("108.162.192.0/18"), + netip.MustParsePrefix("131.0.72.0/22"), + netip.MustParsePrefix("141.101.64.0/18"), + netip.MustParsePrefix("162.158.0.0/15"), + netip.MustParsePrefix("172.64.0.0/13"), + netip.MustParsePrefix("173.245.48.0/20"), + netip.MustParsePrefix("188.114.96.0/20"), + netip.MustParsePrefix("190.93.240.0/20"), + netip.MustParsePrefix("197.234.240.0/22"), + netip.MustParsePrefix("198.41.128.0/17"), + netip.MustParsePrefix("2400:cb00::/32"), + netip.MustParsePrefix("2405:8100::/32"), + netip.MustParsePrefix("2405:b500::/32"), + netip.MustParsePrefix("2606:4700::/32"), + netip.MustParsePrefix("2803:f800::/32"), + netip.MustParsePrefix("2c0f:f248::/32"), + netip.MustParsePrefix("2a06:98c0::/29"), + } +} diff --git a/ipscanner/internal/statute/ping.go b/ipscanner/internal/statute/ping.go new file mode 100644 index 00000000..816a26fc --- /dev/null +++ b/ipscanner/internal/statute/ping.go @@ -0,0 +1,35 @@ +package statute + +import ( + "context" + "crypto/tls" + "fmt" +) + +type IPingResult interface { + Result() IPInfo + Error() error + fmt.Stringer +} + +type IPing interface { + Ping() IPingResult + PingContext(context.Context) IPingResult +} + +func TlsVersionToString(ver uint16) string { + switch ver { + case tls.VersionSSL30: + return "SSL 3.0" + case tls.VersionTLS10: + return "TLS 1.0" + case tls.VersionTLS11: + return "TLS 1.1" + case tls.VersionTLS12: + return "TLS 1.2" + case tls.VersionTLS13: + return "TLS 1.3" + default: + return "unknown" + } +} diff --git a/ipscanner/internal/statute/queue.go b/ipscanner/internal/statute/queue.go new file mode 100644 index 00000000..16975362 --- /dev/null +++ b/ipscanner/internal/statute/queue.go @@ -0,0 +1,34 @@ +package statute + +import ( + "sort" + "time" +) + +type IPInfQueue struct { + items []IPInfo +} + +// Enqueue adds an item and then sorts the queue. +func (q *IPInfQueue) Enqueue(item IPInfo) { + q.items = append(q.items, item) + sort.Slice(q.items, func(i, j int) bool { + return q.items[i].RTT < q.items[j].RTT + }) +} + +// Dequeue removes and returns the item with the lowest RTT. +func (q *IPInfQueue) Dequeue() IPInfo { + if len(q.items) == 0 { + return IPInfo{} // Returning an empty IPInfo when the queue is empty. + } + item := q.items[0] + q.items = q.items[1:] + item.CreatedAt = time.Now() + return item +} + +// Size returns the number of items in the queue. +func (q *IPInfQueue) Size() int { + return len(q.items) +} diff --git a/ipscanner/internal/statute/statute.go b/ipscanner/internal/statute/statute.go new file mode 100644 index 00000000..1f7c0eab --- /dev/null +++ b/ipscanner/internal/statute/statute.go @@ -0,0 +1,66 @@ +package statute + +import ( + "context" + "crypto/tls" + "log/slog" + "net" + "net/http" + "net/netip" + "time" + + "github.com/sagernet/quic-go" +) + +type TIPQueueChangeCallback func(ips []IPInfo) + +type ( + TDialerFunc func(ctx context.Context, network, addr string) (net.Conn, error) + TQuicDialerFunc func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) + THTTPClientFunc func(rawDialer TDialerFunc, tlsDialer TDialerFunc, quicDialer TQuicDialerFunc, targetAddr ...string) *http.Client +) + +var ( + HTTPPing = 1 << 1 + TLSPing = 1 << 2 + TCPPing = 1 << 3 + QUICPing = 1 << 4 + WARPPing = 1 << 5 +) + +type IPInfo struct { + AddrPort netip.AddrPort + RTT time.Duration + CreatedAt time.Time +} + +type ScannerOptions struct { + UseIPv4 bool + UseIPv6 bool + CidrList []netip.Prefix // CIDR ranges to scan + SelectedOps int + Logger *slog.Logger + InsecureSkipVerify bool + RawDialerFunc TDialerFunc + TLSDialerFunc TDialerFunc + QuicDialerFunc TQuicDialerFunc + HttpClientFunc THTTPClientFunc + UseHTTP3 bool + UseHTTP2 bool + DisableCompression bool + HTTPPath string + Referrer string + UserAgent string + Hostname string + WarpPrivateKey string + WarpPeerPublicKey string + WarpPresharedKey string + Port uint16 + IPQueueSize int + IPQueueTTL time.Duration + MaxDesirableRTT time.Duration + IPQueueChangeCallback TIPQueueChangeCallback + ConnectionTimeout time.Duration + HandshakeTimeout time.Duration + TlsVersion uint16 +} diff --git a/ipscanner/scanner.go b/ipscanner/scanner.go new file mode 100644 index 00000000..1c3216b5 --- /dev/null +++ b/ipscanner/scanner.go @@ -0,0 +1,275 @@ +/* +Copyright and credits to @bepass-org [github.com/sagernet/sing-box] +*/ + +package ipscanner + +import ( + "context" + "crypto/tls" + "log/slog" + "net" + "net/netip" + "time" + + "github.com/sagernet/sing-box/ipscanner/internal/engine" + "github.com/sagernet/sing-box/ipscanner/internal/statute" +) + +type IPInfo = statute.IPInfo + +type IPScanner struct { + log *slog.Logger + engine *engine.Engine + options statute.ScannerOptions +} + +func NewScanner(options ...Option) *IPScanner { + p := &IPScanner{ + options: statute.ScannerOptions{ + UseIPv4: true, + UseIPv6: true, + CidrList: statute.DefaultCFRanges(), + SelectedOps: 0, + Logger: slog.Default(), + InsecureSkipVerify: true, + RawDialerFunc: statute.DefaultDialerFunc, + TLSDialerFunc: statute.DefaultTLSDialerFunc, + HttpClientFunc: statute.DefaultHTTPClientFunc, + UseHTTP2: false, + DisableCompression: false, + HTTPPath: "/", + Referrer: "", + UserAgent: "Chrome/80.0.3987.149", + Hostname: "www.cloudflare.com", + WarpPresharedKey: "", + WarpPeerPublicKey: "", + WarpPrivateKey: "", + Port: 443, + IPQueueSize: 8, + MaxDesirableRTT: 400 * time.Millisecond, + IPQueueTTL: 30 * time.Second, + ConnectionTimeout: 1 * time.Second, + HandshakeTimeout: 1 * time.Second, + TlsVersion: tls.VersionTLS13, + }, + log: slog.Default(), + } + + for _, option := range options { + option(p) + } + + return p +} + +type Option func(*IPScanner) + +func WithUseIPv4(useIPv4 bool) Option { + return func(i *IPScanner) { + i.options.UseIPv4 = useIPv4 + } +} + +func WithUseIPv6(useIPv6 bool) Option { + return func(i *IPScanner) { + i.options.UseIPv6 = useIPv6 + } +} + +func WithDialer(d statute.TDialerFunc) Option { + return func(i *IPScanner) { + i.options.RawDialerFunc = d + } +} + +func WithTLSDialer(t statute.TDialerFunc) Option { + return func(i *IPScanner) { + i.options.TLSDialerFunc = t + } +} + +func WithHttpClientFunc(h statute.THTTPClientFunc) Option { + return func(i *IPScanner) { + i.options.HttpClientFunc = h + } +} + +func WithUseHTTP2(useHTTP2 bool) Option { + return func(i *IPScanner) { + i.options.UseHTTP2 = useHTTP2 + } +} + +func WithDisableCompression(disableCompression bool) Option { + return func(i *IPScanner) { + i.options.DisableCompression = disableCompression + } +} + +func WithHttpPath(path string) Option { + return func(i *IPScanner) { + i.options.HTTPPath = path + } +} + +func WithReferrer(referrer string) Option { + return func(i *IPScanner) { + i.options.Referrer = referrer + } +} + +func WithUserAgent(userAgent string) Option { + return func(i *IPScanner) { + i.options.UserAgent = userAgent + } +} + +func WithLogger(logger *slog.Logger) Option { + return func(i *IPScanner) { + i.log = logger + i.options.Logger = logger + } +} + +func WithInsecureSkipVerify(insecureSkipVerify bool) Option { + return func(i *IPScanner) { + i.options.InsecureSkipVerify = insecureSkipVerify + } +} + +func WithHostname(hostname string) Option { + return func(i *IPScanner) { + i.options.Hostname = hostname + } +} + +func WithPort(port uint16) Option { + return func(i *IPScanner) { + i.options.Port = port + } +} + +func WithCidrList(cidrList []netip.Prefix) Option { + return func(i *IPScanner) { + i.options.CidrList = cidrList + } +} + +func WithHTTPPing() Option { + return func(i *IPScanner) { + i.options.SelectedOps |= statute.HTTPPing + } +} + +func WithWarpPing() Option { + return func(i *IPScanner) { + i.options.SelectedOps |= statute.WARPPing + } +} + +func WithQUICPing() Option { + return func(i *IPScanner) { + i.options.SelectedOps |= statute.QUICPing + } +} + +func WithTCPPing() Option { + return func(i *IPScanner) { + i.options.SelectedOps |= statute.TCPPing + } +} + +func WithTLSPing() Option { + return func(i *IPScanner) { + i.options.SelectedOps |= statute.TLSPing + } +} + +func WithIPQueueSize(size int) Option { + return func(i *IPScanner) { + i.options.IPQueueSize = size + } +} + +func WithMaxDesirableRTT(threshold time.Duration) Option { + return func(i *IPScanner) { + i.options.MaxDesirableRTT = threshold + } +} + +func WithIPQueueTTL(ttl time.Duration) Option { + return func(i *IPScanner) { + i.options.IPQueueTTL = ttl + } +} + +func WithConnectionTimeout(timeout time.Duration) Option { + return func(i *IPScanner) { + i.options.ConnectionTimeout = timeout + } +} + +func WithHandshakeTimeout(timeout time.Duration) Option { + return func(i *IPScanner) { + i.options.HandshakeTimeout = timeout + } +} + +func WithTlsVersion(version uint16) Option { + return func(i *IPScanner) { + i.options.TlsVersion = version + } +} + +func WithWarpPrivateKey(privateKey string) Option { + return func(i *IPScanner) { + i.options.WarpPrivateKey = privateKey + } +} + +func WithWarpPeerPublicKey(peerPublicKey string) Option { + return func(i *IPScanner) { + i.options.WarpPeerPublicKey = peerPublicKey + } +} + +func WithWarpPreSharedKey(presharedKey string) Option { + return func(i *IPScanner) { + i.options.WarpPresharedKey = presharedKey + } +} + +// run engine and in case of new event call onChange callback also if it gets canceled with context +// cancel all operations + +func (i *IPScanner) Run(ctx context.Context) { + statute.FinalOptions = &i.options + if !i.options.UseIPv4 && !i.options.UseIPv6 { + i.log.Error("Fatal: both IPv4 and IPv6 are disabled, nothing to do") + return + } + i.engine = engine.NewScannerEngine(&i.options) + go i.engine.Run(ctx) +} + +func (i *IPScanner) GetAvailableIPs() []statute.IPInfo { + if i.engine != nil { + return i.engine.GetAvailableIPs(false) + } + return nil +} + +func CanConnectIPv6(remoteAddr netip.AddrPort) bool { + dialer := net.Dialer{ + Timeout: 5 * time.Second, + } + + conn, err := dialer.Dial("tcp6", remoteAddr.String()) + if err != nil { + return false + } + defer conn.Close() + + return true +} diff --git a/ipscanner/warp_scanner.go b/ipscanner/warp_scanner.go new file mode 100644 index 00000000..22fd7479 --- /dev/null +++ b/ipscanner/warp_scanner.go @@ -0,0 +1,80 @@ +package ipscanner + +import ( + "context" + "errors" + "log/slog" + "net/netip" + "os" + "time" + + "github.com/sagernet/sing-box/warp" +) + +var googlev6DNSAddr80 = netip.MustParseAddrPort("[2001:4860:4860::8888]:80") + +type WarpScanOptions struct { + PrivateKey string + PublicKey string + MaxRTT time.Duration + V4 bool + V6 bool + Port uint16 +} + +func findMinRTT(ipInfos []IPInfo) (IPInfo, error) { + if len(ipInfos) == 0 { + return IPInfo{}, errors.New("list is empty") + } + + minRTTInfo := ipInfos[0] + for _, ipInfo := range ipInfos[1:] { + if ipInfo.RTT < minRTTInfo.RTT { + minRTTInfo = ipInfo + } + } + + return minRTTInfo, nil +} + +func RunWarpScan(ctx context.Context, opts WarpScanOptions) (result IPInfo, err error) { + ctx, cancel := context.WithTimeout(ctx, 1*time.Minute) + defer cancel() + + scanner := NewScanner( + WithLogger(slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelError}))), + WithWarpPing(), + WithWarpPrivateKey(opts.PrivateKey), + WithWarpPeerPublicKey(opts.PublicKey), + WithUseIPv4(opts.V4), + WithUseIPv6(CanConnectIPv6(googlev6DNSAddr80)), + WithMaxDesirableRTT(opts.MaxRTT), + WithCidrList(warp.WarpPrefixes()), + WithPort(opts.Port), + ) + + scanner.Run(ctx) + + t := time.NewTicker(1 * time.Second) + defer t.Stop() + + for { + ipList := scanner.GetAvailableIPs() + if len(ipList) > 1 { + bestIp, err := findMinRTT(ipList) + if err != nil { + return IPInfo{}, err + } + return bestIp, nil + } + + select { + case <-ctx.Done(): + // Context is done - canceled externally + return IPInfo{}, errors.New("user canceled the operation") + case <-t.C: + // Prevent the loop from spinning too fast + continue + } + } +} diff --git a/iputils/iputils.go b/iputils/iputils.go new file mode 100644 index 00000000..bab501ef --- /dev/null +++ b/iputils/iputils.go @@ -0,0 +1,113 @@ +package iputils + +import ( + "context" + "errors" + "fmt" + "math/big" + "math/rand" + "net" + "net/netip" + "strconv" + "time" +) + +// RandomIPFromPrefix returns a random IP from the provided CIDR prefix. +// Supports IPv4 and IPv6. Does not support mapped inputs. +func RandomIPFromPrefix(cidr netip.Prefix) (netip.Addr, error) { + startingAddress := cidr.Masked().Addr() + if startingAddress.Is4In6() { + return netip.Addr{}, errors.New("mapped v4 addresses not supported") + } + + prefixLen := cidr.Bits() + if prefixLen == -1 { + return netip.Addr{}, fmt.Errorf("invalid cidr: %s", cidr) + } + + // Initialise rand number generator + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + + // Find the bit length of the Host portion of the provided CIDR + // prefix + hostLen := big.NewInt(int64(startingAddress.BitLen() - prefixLen)) + + // Find the max value for our random number + max := new(big.Int).Exp(big.NewInt(2), hostLen, nil) + + // Generate the random number + randInt := new(big.Int).Rand(rng, max) + + // Get the first address in the CIDR prefix in 16-bytes form + startingAddress16 := startingAddress.As16() + + // Convert the first address into a decimal number + startingAddressInt := new(big.Int).SetBytes(startingAddress16[:]) + + // Add the random number to the decimal form of the starting address + // to get a random address in the desired range + randomAddressInt := new(big.Int).Add(startingAddressInt, randInt) + + // Convert the random address from decimal form back into netip.Addr + randomAddress, ok := netip.AddrFromSlice(randomAddressInt.FillBytes(make([]byte, 16))) + if !ok { + return netip.Addr{}, fmt.Errorf("failed to generate random IP from CIDR: %s", cidr) + } + + // Unmap any mapped v4 addresses before return + return randomAddress.Unmap(), nil +} + +func ParseResolveAddressPort(hostname string, includev6 bool, dnsServer string) (netip.AddrPort, error) { + // Attempt to split the hostname into a host and port + host, port, err := net.SplitHostPort(hostname) + if err != nil { + return netip.AddrPort{}, fmt.Errorf("can't parse provided hostname into host and port: %w", err) + } + + // Convert the string port to a uint16 + portInt, err := strconv.Atoi(port) + if err != nil { + return netip.AddrPort{}, fmt.Errorf("error parsing port: %w", err) + } + + if portInt < 1 || portInt > 65535 { + return netip.AddrPort{}, fmt.Errorf("port number %d is out of range", portInt) + } + + // Attempt to parse the host into an IP. Return on success. + addr, err := netip.ParseAddr(host) + if err == nil { + return netip.AddrPortFrom(addr.Unmap(), uint16(portInt)), nil + } + + // Use Go's built-in DNS resolver + resolver := &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial("udp", net.JoinHostPort(dnsServer, "53")) + }, + } + + // If the host wasn't an IP, perform a lookup + ips, err := resolver.LookupIP(context.Background(), "ip", host) + if err != nil { + return netip.AddrPort{}, fmt.Errorf("hostname lookup failed: %w", err) + } + + for _, ip := range ips { + // Take the first IP and then return it + addr, ok := netip.AddrFromSlice(ip) + if !ok { + continue + } + + if addr.Unmap().Is4() { + return netip.AddrPortFrom(addr.Unmap(), uint16(portInt)), nil + } else if includev6 { + return netip.AddrPortFrom(addr.Unmap(), uint16(portInt)), nil + } + } + + return netip.AddrPort{}, errors.New("no valid IP addresses found") +} diff --git a/option/fragment.go b/option/fragment.go new file mode 100644 index 00000000..b4255332 --- /dev/null +++ b/option/fragment.go @@ -0,0 +1,9 @@ +package option + +type TLSFragmentOptions struct { + Enabled bool `json:"enabled,omitempty"` + Method string `json:"method,omitempty"` // Wether to fragment only clientHello or a range of TCP packets. Valid options: ['tlsHello', 'range'] + Size string `json:"size,omitempty"` // Fragment size in Bytes + Sleep string `json:"sleep,omitempty"` // Time to sleep between sending the fragments in milliseconds + Range string `json:"range,omitempty"` // Range of packets to fragment, effective when 'method' is set to 'range' +} diff --git a/option/outbound.go b/option/outbound.go index 833a2d20..f352a4d8 100644 --- a/option/outbound.go +++ b/option/outbound.go @@ -75,6 +75,7 @@ type DialerOptions struct { ConnectTimeout badoption.Duration `json:"connect_timeout,omitempty"` TCPFastOpen bool `json:"tcp_fast_open,omitempty"` TCPMultiPath bool `json:"tcp_multi_path,omitempty"` + TLSFragment TLSFragmentOptions `json:"tls_fragment,omitempty"` UDPFragment *bool `json:"udp_fragment,omitempty"` UDPFragmentDefault bool `json:"-"` DomainStrategy DomainStrategy `json:"domain_strategy,omitempty"` diff --git a/option/range.go b/option/range.go new file mode 100644 index 00000000..047d0599 --- /dev/null +++ b/option/range.go @@ -0,0 +1,56 @@ +package option + +import ( + "crypto/rand" + "fmt" + "math/big" + "strconv" + "strings" + + E "github.com/sagernet/sing/common/exceptions" +) + +type IntRange struct { + Min uint64 + Max uint64 +} + +func ParseIntRange(str string) ([]uint64, error) { + var err error + result := make([]uint64, 2) + + splitString := strings.Split(str, "-") + if len(splitString) == 2 { + result[0], err = strconv.ParseUint(splitString[0], 10, 64) + if err != nil { + return nil, E.Cause(err, "error parsing string to integer") + } + result[1], err = strconv.ParseUint(splitString[1], 10, 64) + if err != nil { + return nil, E.Cause(err, "error parsing string to integer") + } + + if result[1] < result[0] { + return nil, E.Cause(E.New(fmt.Sprintf("upper bound value (%d) must be greater than or equal to lower bound value (%d)", result[1], result[0])), "invalid range") + } + } else { + result[0], err = strconv.ParseUint(splitString[0], 10, 64) + if err != nil { + return nil, E.Cause(err, "error parsing string to integer") + } + result[1] = result[0] + } + return result, err +} + +// GetRandomIntFromRange generate a uniform random number given the range +func GetRandomIntFromRange(min uint64, max uint64) int64 { + if max == 0 { + return 0 + } + if min == max { + return int64(min) + } + randomInt, _ := rand.Int(rand.Reader, big.NewInt(int64(max-min)+1)) + return int64(min) + randomInt.Int64() +} diff --git a/warp/account.go b/warp/account.go new file mode 100644 index 00000000..5c668cad --- /dev/null +++ b/warp/account.go @@ -0,0 +1,131 @@ +package warp + +import ( + "encoding/json" + "errors" + "log/slog" + "os" + "path/filepath" +) + +var identityFile = "wgcf-identity.json" + +func saveIdentity(a Identity, path string) error { + file, err := os.Create(filepath.Join(path, identityFile)) + if err != nil { + return err + } + + encoder := json.NewEncoder(file) + encoder.SetIndent("", " ") + err = encoder.Encode(a) + if err != nil { + return err + } + + return file.Close() +} + +func LoadOrCreateIdentity(l *slog.Logger, path, license string) (*Identity, error) { + l = l.With("subsystem", "warp/account") + + i, err := LoadIdentity(path) + if err != nil { + l.Info("failed to load identity", "path", path, "error", err) + if err := os.RemoveAll(path); err != nil { + return nil, err + } + + if err := os.MkdirAll(path, os.ModePerm); err != nil { + return nil, err + } + + i, err = CreateIdentity(l, license) + if err != nil { + return nil, err + } + + if err = saveIdentity(i, path); err != nil { + return nil, err + } + } + + if license != "" && i.Account.License != license { + l.Info("updating account license key") + _, err := UpdateAccount(i.Token, i.ID, license) + if err != nil { + return nil, err + } + + iAcc, err := GetAccount(i.Token, i.ID) + if err != nil { + return nil, err + } + i.Account = iAcc + + if err = saveIdentity(i, path); err != nil { + return nil, err + } + } + + l.Info("successfully loaded warp identity") + return &i, nil +} + +func LoadIdentity(path string) (Identity, error) { + identityPath := filepath.Join(path, identityFile) + _, err := os.Stat(identityPath) + if err != nil { + return Identity{}, err + } + + fileBytes, err := os.ReadFile(identityPath) + if err != nil { + return Identity{}, err + } + + i := &Identity{} + err = json.Unmarshal(fileBytes, i) + if err != nil { + return Identity{}, err + } + + if len(i.Config.Peers) < 1 { + return Identity{}, errors.New("identity contains 0 peers") + } + + return *i, nil +} + +func CreateIdentity(l *slog.Logger, license string) (Identity, error) { + priv, err := GeneratePrivateKey() + if err != nil { + return Identity{}, err + } + + privateKey, publicKey := priv.String(), priv.PublicKey().String() + + l.Info("creating new identity") + i, err := Register(publicKey) + if err != nil { + return Identity{}, err + } + + if license != "" { + l.Info("updating account license key") + _, err := UpdateAccount(i.Token, i.ID, license) + if err != nil { + return Identity{}, err + } + + ac, err := GetAccount(i.Token, i.ID) + if err != nil { + return Identity{}, err + } + i.Account = ac + } + + i.PrivateKey = privateKey + + return i, nil +} diff --git a/warp/api.go b/warp/api.go new file mode 100644 index 00000000..0c7be239 --- /dev/null +++ b/warp/api.go @@ -0,0 +1,515 @@ +package warp + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "time" +) + +const ( + apiBase string = "https://api.cloudflareclient.com/v0a4005" +) + +var client = makeClient() + +func defaultHeaders() map[string]string { + return map[string]string{ + "Content-Type": "application/json; charset=UTF-8", + "User-Agent": "okhttp/3.12.1", + "CF-Client-Version": "a-6.30-3596", + } +} + +func makeClient() *http.Client { + // Create a custom dialer using the TLS config + plainDialer := &net.Dialer{ + Timeout: 5 * time.Second, + KeepAlive: 5 * time.Second, + } + tlsDialer := Dialer{} + // Create a custom HTTP transport + transport := &http.Transport{ + DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return tlsDialer.TLSDial(plainDialer, network, addr) + }, + } + + // Create a custom HTTP client using the transport + return &http.Client{ + Transport: transport, + // Other client configurations can be added here + } +} + +type IdentityAccount struct { + Created string `json:"created"` + Updated string `json:"updated"` + License string `json:"license"` + PremiumData int64 `json:"premium_data"` + WarpPlus bool `json:"warp_plus"` + AccountType string `json:"account_type"` + ReferralRenewalCountdown int64 `json:"referral_renewal_countdown"` + Role string `json:"role"` + ID string `json:"id"` + Quota int64 `json:"quota"` + Usage int64 `json:"usage"` + ReferralCount int64 `json:"referral_count"` + TTL string `json:"ttl"` +} + +type IdentityConfigPeerEndpoint struct { + V4 string `json:"v4"` + V6 string `json:"v6"` + Host string `json:"host"` + Ports []uint16 `json:"ports"` +} + +type IdentityConfigPeer struct { + PublicKey string `json:"public_key"` + Endpoint IdentityConfigPeerEndpoint `json:"endpoint"` +} + +type IdentityConfigInterfaceAddresses struct { + V4 string `json:"v4"` + V6 string `json:"v6"` +} + +type IdentityConfigInterface struct { + Addresses IdentityConfigInterfaceAddresses `json:"addresses"` +} +type IdentityConfigServices struct { + HTTPProxy string `json:"http_proxy"` +} + +type IdentityConfig struct { + Peers []IdentityConfigPeer `json:"peers"` + Interface IdentityConfigInterface `json:"interface"` + Services IdentityConfigServices `json:"services"` + ClientID string `json:"client_id"` +} + +type Identity struct { + PrivateKey string `json:"private_key"` + Key string `json:"key"` + Account IdentityAccount `json:"account"` + Place int64 `json:"place"` + FCMToken string `json:"fcm_token"` + Name string `json:"name"` + TOS string `json:"tos"` + Locale string `json:"locale"` + InstallID string `json:"install_id"` + WarpEnabled bool `json:"warp_enabled"` + Type string `json:"type"` + Model string `json:"model"` + Config IdentityConfig `json:"config"` + Token string `json:"token"` + Enabled bool `json:"enabled"` + ID string `json:"id"` + Created string `json:"created"` + Updated string `json:"updated"` + WaitlistEnabled bool `json:"waitlist_enabled"` +} + +type IdentityDevice struct { + ID string `json:"id"` + Name string `json:"name"` + Type string `json:"type"` + Model string `json:"model"` + Created string `json:"created"` + Activated string `json:"updated"` + Active bool `json:"active"` + Role string `json:"role"` +} + +type License struct { + License string `json:"license"` +} + +func GetAccount(authToken, deviceID string) (IdentityAccount, error) { + reqUrl := fmt.Sprintf("%s/reg/%s/account", apiBase, deviceID) + method := "GET" + + req, err := http.NewRequest(method, reqUrl, nil) + if err != nil { + return IdentityAccount{}, err + } + + // Set headers + for k, v := range defaultHeaders() { + req.Header.Set(k, v) + } + req.Header.Set("Authorization", "Bearer "+authToken) + + // Create HTTP client and execute request + resp, err := client.Do(req) + if err != nil { + return IdentityAccount{}, err + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return IdentityAccount{}, fmt.Errorf("API request failed with status: %s", resp.Status) + } + + // convert response to byte array + responseData, err := io.ReadAll(resp.Body) + if err != nil { + return IdentityAccount{}, err + } + + var rspData = IdentityAccount{} + if err := json.Unmarshal(responseData, &rspData); err != nil { + return IdentityAccount{}, err + } + + return rspData, nil +} + +func GetBoundDevices(authToken, deviceID string) ([]IdentityDevice, error) { + reqUrl := fmt.Sprintf("%s/reg/%s/account/devices", apiBase, deviceID) + method := "GET" + + req, err := http.NewRequest(method, reqUrl, nil) + if err != nil { + return nil, err + } + + // Set headers + for k, v := range defaultHeaders() { + req.Header.Set(k, v) + } + req.Header.Set("Authorization", "Bearer "+authToken) + + // Create HTTP client and execute request + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("API request failed with status: %s", resp.Status) + } + + // convert response to byte array + responseData, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + var rspData = []IdentityDevice{} + if err := json.Unmarshal(responseData, &rspData); err != nil { + return nil, err + } + + return rspData, nil +} + +func GetSourceDevice(authToken, deviceID string) (Identity, error) { + reqUrl := fmt.Sprintf("%s/reg/%s", apiBase, deviceID) + method := "GET" + + req, err := http.NewRequest(method, reqUrl, nil) + if err != nil { + return Identity{}, err + } + + // Set headers + for k, v := range defaultHeaders() { + req.Header.Set(k, v) + } + req.Header.Set("Authorization", "Bearer "+authToken) + + // Create HTTP client and execute request + resp, err := client.Do(req) + if err != nil { + return Identity{}, err + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return Identity{}, fmt.Errorf("API request failed with status: %s", resp.Status) + } + + // convert response to byte array + responseData, err := io.ReadAll(resp.Body) + if err != nil { + return Identity{}, err + } + + var rspData = Identity{} + if err := json.Unmarshal(responseData, &rspData); err != nil { + return Identity{}, err + } + + return rspData, nil +} + +func Register(publicKey string) (Identity, error) { + reqUrl := fmt.Sprintf("%s/reg", apiBase) + method := "POST" + + data := map[string]interface{}{ + "install_id": "", + "fcm_token": "", + "tos": time.Now().Format(time.RFC3339Nano), + "key": publicKey, + "type": "Android", + "model": "PC", + "locale": "en_US", + "warp_enabled": true, + } + + jsonBody, err := json.Marshal(data) + if err != nil { + return Identity{}, err + } + + req, err := http.NewRequest(method, reqUrl, bytes.NewBuffer(jsonBody)) + if err != nil { + return Identity{}, err + } + + // Set headers + for k, v := range defaultHeaders() { + req.Header.Set(k, v) + } + + // Create HTTP client and execute request + resp, err := client.Do(req) + if err != nil { + return Identity{}, err + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return Identity{}, fmt.Errorf("API request failed with status: %s", resp.Status) + } + + // convert response to byte array + responseData, err := io.ReadAll(resp.Body) + if err != nil { + return Identity{}, err + } + + var rspData = Identity{} + if err := json.Unmarshal(responseData, &rspData); err != nil { + return Identity{}, err + } + + return rspData, nil +} + +func ResetAccountLicense(authToken, deviceID string) (License, error) { + reqUrl := fmt.Sprintf("%s/reg/%s/account/license", apiBase, deviceID) + method := "POST" + + req, err := http.NewRequest(method, reqUrl, nil) + if err != nil { + return License{}, err + } + + // Set headers + for k, v := range defaultHeaders() { + req.Header.Set(k, v) + } + req.Header.Set("Authorization", "Bearer "+authToken) + + // Create HTTP client and execute request + resp, err := client.Do(req) + if err != nil { + return License{}, err + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return License{}, fmt.Errorf("API request failed with response: %s", resp.Status) + } + + // convert response to byte array + responseData, err := io.ReadAll(resp.Body) + if err != nil { + return License{}, err + } + + var rspData = License{} + if err := json.Unmarshal(responseData, &rspData); err != nil { + return License{}, err + } + + return rspData, nil +} + +func UpdateAccount(authToken, deviceID, license string) (IdentityAccount, error) { + reqUrl := fmt.Sprintf("%s/reg/%s/account", apiBase, deviceID) + method := "PUT" + + jsonBody, err := json.Marshal(map[string]interface{}{"license": license}) + if err != nil { + return IdentityAccount{}, err + } + + req, err := http.NewRequest(method, reqUrl, bytes.NewBuffer(jsonBody)) + if err != nil { + return IdentityAccount{}, err + } + + // Set headers + for k, v := range defaultHeaders() { + req.Header.Set(k, v) + } + req.Header.Set("Authorization", "Bearer "+authToken) + + // Create HTTP client and execute request + resp, err := client.Do(req) + if err != nil { + return IdentityAccount{}, err + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return IdentityAccount{}, fmt.Errorf("API request failed with status: %s", resp.Status) + } + + // convert response to byte array + responseData, err := io.ReadAll(resp.Body) + if err != nil { + return IdentityAccount{}, err + } + + var rspData = IdentityAccount{} + if err := json.Unmarshal(responseData, &rspData); err != nil { + return IdentityAccount{}, err + } + + return rspData, nil +} + +func UpdateBoundDevice(authToken, deviceID, otherDeviceID, name string, active bool) (IdentityDevice, error) { + reqUrl := fmt.Sprintf("%s/reg/%s/account/reg/%s", apiBase, deviceID, otherDeviceID) + method := "PATCH" + + data := map[string]interface{}{ + "active": active, + "name": name, + } + + jsonBody, err := json.Marshal(data) + if err != nil { + return IdentityDevice{}, err + } + + req, err := http.NewRequest(method, reqUrl, bytes.NewBuffer(jsonBody)) + if err != nil { + return IdentityDevice{}, err + } + + // Set headers + for k, v := range defaultHeaders() { + req.Header.Set(k, v) + } + req.Header.Set("Authorization", "Bearer "+authToken) + + // Create HTTP client and execute request + resp, err := client.Do(req) + if err != nil { + return IdentityDevice{}, err + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return IdentityDevice{}, fmt.Errorf("API request failed with status: %s", resp.Status) + } + + // convert response to byte array + responseData, err := io.ReadAll(resp.Body) + if err != nil { + return IdentityDevice{}, err + } + + var rspData = IdentityDevice{} + if err := json.Unmarshal(responseData, &rspData); err != nil { + return IdentityDevice{}, err + } + + return rspData, nil +} + +func UpdateSourceDevice(authToken, deviceID, publicKey string) (Identity, error) { + reqUrl := fmt.Sprintf("%s/reg/%s", apiBase, deviceID) + method := "PATCH" + + jsonBody, err := json.Marshal(map[string]interface{}{"key": publicKey}) + if err != nil { + return Identity{}, err + } + + req, err := http.NewRequest(method, reqUrl, bytes.NewBuffer(jsonBody)) + if err != nil { + return Identity{}, err + } + + // Set headers + for k, v := range defaultHeaders() { + req.Header.Set(k, v) + } + req.Header.Set("Authorization", "Bearer "+authToken) + + // Create HTTP client and execute request + resp, err := client.Do(req) + if err != nil { + return Identity{}, err + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return Identity{}, fmt.Errorf("API request failed with status: %s", resp.Status) + } + + // convert response to byte array + responseData, err := io.ReadAll(resp.Body) + if err != nil { + return Identity{}, err + } + + var rspData = Identity{} + if err := json.Unmarshal(responseData, &rspData); err != nil { + return Identity{}, err + } + + return rspData, nil +} + +func DeleteDevice(authToken, deviceID string) error { + reqUrl := fmt.Sprintf("%s/reg/%s", apiBase, deviceID) + method := "DELETE" + + req, err := http.NewRequest(method, reqUrl, nil) + if err != nil { + return err + } + + // Set headers + for k, v := range defaultHeaders() { + req.Header.Set(k, v) + } + req.Header.Set("Authorization", "Bearer "+authToken) + + // Create HTTP client and execute request + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("API request failed with status: %s", resp.Status) + } + + return nil +} diff --git a/warp/endpoint.go b/warp/endpoint.go new file mode 100644 index 00000000..5de77695 --- /dev/null +++ b/warp/endpoint.go @@ -0,0 +1,117 @@ +package warp + +import ( + "math/rand" + "net/netip" + "time" + + "github.com/sagernet/sing-box/iputils" +) + +func WarpPrefixes() []netip.Prefix { + return []netip.Prefix{ + netip.MustParsePrefix("162.159.192.0/24"), + netip.MustParsePrefix("162.159.193.0/24"), + netip.MustParsePrefix("162.159.195.0/24"), + netip.MustParsePrefix("188.114.96.0/24"), + netip.MustParsePrefix("188.114.97.0/24"), + netip.MustParsePrefix("188.114.98.0/24"), + netip.MustParsePrefix("188.114.99.0/24"), + netip.MustParsePrefix("2606:4700:d0::/64"), + netip.MustParsePrefix("2606:4700:d1::/64"), + } +} + +func RandomWarpPrefix(v4, v6 bool) netip.Prefix { + if !v4 && !v6 { + panic("Must choose a IP version for RandomWarpPrefix") + } + + cidrs := WarpPrefixes() + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + for { + cidr := cidrs[rng.Intn(len(cidrs))] + + if v4 && cidr.Addr().Is4() { + return cidr + } + + if v6 && cidr.Addr().Is6() { + return cidr + } + } +} + +func WarpPorts() []uint16 { + return []uint16{ + 500, + 854, + 859, + 864, + 878, + 880, + 890, + 891, + 894, + 903, + 908, + 928, + 934, + 939, + 942, + 943, + 945, + 946, + 955, + 968, + 987, + 988, + 1002, + 1010, + 1014, + 1018, + 1070, + 1074, + 1180, + 1387, + 1701, + 1843, + 2371, + 2408, + 2506, + 3138, + 3476, + 3581, + 3854, + 4177, + 4198, + 4233, + 4500, + 5279, + 5956, + 7103, + 7152, + 7156, + 7281, + 7559, + 8319, + 8742, + 8854, + 8886, + } +} + +func RandomWarpPort() uint16 { + ports := WarpPorts() + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + return ports[rng.Intn(len(ports))] +} + +func RandomWarpEndpoint(v4, v6 bool) (netip.AddrPort, error) { + randomIP, err := iputils.RandomIPFromPrefix(RandomWarpPrefix(v4, v6)) + if err != nil { + return netip.AddrPort{}, err + } + + return netip.AddrPortFrom(randomIP, RandomWarpPort()), nil +} diff --git a/warp/key.go b/warp/key.go new file mode 100644 index 00000000..8b2dfb46 --- /dev/null +++ b/warp/key.go @@ -0,0 +1,86 @@ +package warp + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + + "golang.org/x/crypto/curve25519" +) + +const WarpPublicKey = "bmXOC+F1FxEMF9dyiK2H5/1SUtzH0JuVo51h2wPfgyo=" + +// KeyLen is the expected key length for a WireGuard key. +const KeyLen = 32 // wgh.KeyLen + +// A Key is a public, private, or pre-shared secret key. The Key constructor +// functions in this package can be used to create Keys suitable for each of +// these applications. +type Key [KeyLen]byte + +// GenerateKey generates a Key suitable for use as a pre-shared secret key from +// a cryptographically safe source. +// +// The output Key should not be used as a private key; use GeneratePrivateKey +// instead. +func GenerateKey() (Key, error) { + b := make([]byte, KeyLen) + if _, err := rand.Read(b); err != nil { + return Key{}, fmt.Errorf("wgtypes: failed to read random bytes: %w", err) + } + + return NewKey(b) +} + +// GeneratePrivateKey generates a Key suitable for use as a private key from a +// cryptographically safe source. +func GeneratePrivateKey() (Key, error) { + key, err := GenerateKey() + if err != nil { + return Key{}, err + } + + // Modify random bytes using algorithm described at: + // https://cr.yp.to/ecdh.html. + key[0] &= 248 + key[31] &= 127 + key[31] |= 64 + + return key, nil +} + +// NewKey creates a Key from an existing byte slice. The byte slice must be +// exactly 32 bytes in length. +func NewKey(b []byte) (Key, error) { + if len(b) != KeyLen { + return Key{}, fmt.Errorf("wgtypes: incorrect key size: %d", len(b)) + } + + var k Key + copy(k[:], b) + + return k, nil +} + +// PublicKey computes a public key from the private key k. +// +// PublicKey should only be called when k is a private key. +func (k Key) PublicKey() Key { + var ( + pub [KeyLen]byte + priv = [KeyLen]byte(k) + ) + + // ScalarBaseMult uses the correct base value per https://cr.yp.to/ecdh.html, + // so no need to specify it. + curve25519.ScalarBaseMult(&pub, &priv) + + return Key(pub) +} + +// String returns the base64-encoded string representation of a Key. +// +// ParseKey can be used to produce a new Key from this string. +func (k Key) String() string { + return base64.StdEncoding.EncodeToString(k[:]) +} diff --git a/warp/tls.go b/warp/tls.go new file mode 100644 index 00000000..f1560249 --- /dev/null +++ b/warp/tls.go @@ -0,0 +1,145 @@ +package warp + +import ( + "fmt" + "io" + "net" + "net/netip" + + "github.com/sagernet/sing-box/iputils" + + "github.com/noql-net/certpool" + tls "github.com/sagernet/utls" +) + +// Dialer is a struct that holds various options for custom dialing. +type Dialer struct{} + +const utlsExtensionSNICurve uint16 = 0x15 + +// SNICurveExtension implements SNICurve (0x15) extension +type SNICurveExtension struct { + *tls.GenericExtension + SNICurveLen int + WillPad bool // set false to disable extension +} + +// Len returns the length of the SNICurveExtension. +func (e *SNICurveExtension) Len() int { + if e.WillPad { + return 4 + e.SNICurveLen + } + return 0 +} + +// Read reads the SNICurveExtension. +func (e *SNICurveExtension) Read(b []byte) (n int, err error) { + if !e.WillPad { + return 0, io.EOF + } + if len(b) < e.Len() { + return 0, io.ErrShortBuffer + } + // https://tools.ietf.org/html/rfc7627 + b[0] = byte(utlsExtensionSNICurve >> 8) + b[1] = byte(utlsExtensionSNICurve) + b[2] = byte(e.SNICurveLen >> 8) + b[3] = byte(e.SNICurveLen) + y := make([]byte, 1200) + copy(b[4:], y) + return e.Len(), io.EOF +} + +// makeTLSHelloPacketWithSNICurve creates a TLS hello packet with SNICurve. +func (d *Dialer) makeTLSHelloPacketWithSNICurve(plainConn net.Conn, config *tls.Config, sni string) (*tls.UConn, error) { + SNICurveSize := 1200 + + utlsConn := tls.UClient(plainConn, config, tls.HelloCustom) + spec := tls.ClientHelloSpec{ + TLSVersMax: tls.VersionTLS12, + TLSVersMin: tls.VersionTLS12, + CipherSuites: []uint16{ + tls.GREASE_PLACEHOLDER, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_AES_128_GCM_SHA256, // tls 1.3 + tls.FAKE_TLS_DHE_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_RSA_WITH_AES_256_CBC_SHA, + }, + Extensions: []tls.TLSExtension{ + &SNICurveExtension{ + SNICurveLen: SNICurveSize, + WillPad: true, + }, + &tls.SupportedCurvesExtension{Curves: []tls.CurveID{tls.X25519, tls.CurveP256}}, + &tls.SupportedPointsExtension{SupportedPoints: []byte{0}}, // uncompressed + &tls.SessionTicketExtension{}, + &tls.ALPNExtension{AlpnProtocols: []string{"http/1.1"}}, + &tls.SignatureAlgorithmsExtension{ + SupportedSignatureAlgorithms: []tls.SignatureScheme{ + tls.ECDSAWithP256AndSHA256, + tls.ECDSAWithP384AndSHA384, + tls.ECDSAWithP521AndSHA512, + tls.PSSWithSHA256, + tls.PSSWithSHA384, + tls.PSSWithSHA512, + tls.PKCS1WithSHA256, + tls.PKCS1WithSHA384, + tls.PKCS1WithSHA512, + tls.ECDSAWithSHA1, + tls.PKCS1WithSHA1, + }, + }, + &tls.KeyShareExtension{KeyShares: []tls.KeyShare{ + {Group: tls.CurveID(tls.GREASE_PLACEHOLDER), Data: []byte{0}}, + {Group: tls.X25519}, + }}, + &tls.PSKKeyExchangeModesExtension{Modes: []uint8{1}}, // pskModeDHE + &tls.SNIExtension{ServerName: sni}, + }, + GetSessionID: nil, + } + err := utlsConn.ApplyPreset(&spec) + if err != nil { + return nil, fmt.Errorf("uTlsConn.Handshake() error: %w", err) + } + + err = utlsConn.Handshake() + if err != nil { + return nil, fmt.Errorf("uTlsConn.Handshake() error: %w", err) + } + + return utlsConn, nil +} + +// TLSDial dials a TLS connection. +func (d *Dialer) TLSDial(plainDialer *net.Dialer, network, addr string) (net.Conn, error) { + sni, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + ip, err := iputils.RandomIPFromPrefix(netip.MustParsePrefix("141.101.113.0/24")) + if err != nil { + return nil, err + } + plainConn, err := plainDialer.Dial(network, ip.String()+":443") + if err != nil { + return nil, err + } + + config := tls.Config{ + ServerName: sni, + MinVersion: tls.VersionTLS12, + RootCAs: certpool.Roots(), + } + + utlsConn, handshakeErr := d.makeTLSHelloPacketWithSNICurve(plainConn, &config, sni) + if handshakeErr != nil { + _ = plainConn.Close() + return nil, handshakeErr + } + return utlsConn, nil +}