mirror of
https://github.com/SagerNet/sing-box.git
synced 2025-06-13 21:54:13 +08:00
Implement AnyTLS client & server
This commit is contained in:
parent
63a0b5e2ce
commit
477ba6eed4
@ -19,6 +19,7 @@ const (
|
|||||||
TypeTor = "tor"
|
TypeTor = "tor"
|
||||||
TypeSSH = "ssh"
|
TypeSSH = "ssh"
|
||||||
TypeShadowTLS = "shadowtls"
|
TypeShadowTLS = "shadowtls"
|
||||||
|
TypeAnyTLS = "anytls"
|
||||||
TypeShadowsocksR = "shadowsocksr"
|
TypeShadowsocksR = "shadowsocksr"
|
||||||
TypeVLESS = "vless"
|
TypeVLESS = "vless"
|
||||||
TypeTUIC = "tuic"
|
TypeTUIC = "tuic"
|
||||||
|
@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/sagernet/sing-box/dns/transport/local"
|
"github.com/sagernet/sing-box/dns/transport/local"
|
||||||
"github.com/sagernet/sing-box/log"
|
"github.com/sagernet/sing-box/log"
|
||||||
"github.com/sagernet/sing-box/option"
|
"github.com/sagernet/sing-box/option"
|
||||||
|
"github.com/sagernet/sing-box/protocol/anytls"
|
||||||
"github.com/sagernet/sing-box/protocol/block"
|
"github.com/sagernet/sing-box/protocol/block"
|
||||||
"github.com/sagernet/sing-box/protocol/direct"
|
"github.com/sagernet/sing-box/protocol/direct"
|
||||||
protocolDNS "github.com/sagernet/sing-box/protocol/dns"
|
protocolDNS "github.com/sagernet/sing-box/protocol/dns"
|
||||||
@ -53,6 +54,7 @@ func InboundRegistry() *inbound.Registry {
|
|||||||
naive.RegisterInbound(registry)
|
naive.RegisterInbound(registry)
|
||||||
shadowtls.RegisterInbound(registry)
|
shadowtls.RegisterInbound(registry)
|
||||||
vless.RegisterInbound(registry)
|
vless.RegisterInbound(registry)
|
||||||
|
anytls.RegisterInbound(registry)
|
||||||
|
|
||||||
registerQUICInbounds(registry)
|
registerQUICInbounds(registry)
|
||||||
registerStubForRemovedInbounds(registry)
|
registerStubForRemovedInbounds(registry)
|
||||||
@ -80,6 +82,7 @@ func OutboundRegistry() *outbound.Registry {
|
|||||||
ssh.RegisterOutbound(registry)
|
ssh.RegisterOutbound(registry)
|
||||||
shadowtls.RegisterOutbound(registry)
|
shadowtls.RegisterOutbound(registry)
|
||||||
vless.RegisterOutbound(registry)
|
vless.RegisterOutbound(registry)
|
||||||
|
anytls.RegisterOutbound(registry)
|
||||||
|
|
||||||
registerQUICOutbounds(registry)
|
registerQUICOutbounds(registry)
|
||||||
registerWireGuardOutbound(registry)
|
registerWireGuardOutbound(registry)
|
||||||
|
24
option/anytls.go
Normal file
24
option/anytls.go
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
package option
|
||||||
|
|
||||||
|
import "github.com/sagernet/sing/common/json/badoption"
|
||||||
|
|
||||||
|
type AnyTLSInboundOptions struct {
|
||||||
|
ListenOptions
|
||||||
|
InboundTLSOptionsContainer
|
||||||
|
Users []AnyTLSUser `json:"users,omitempty"`
|
||||||
|
PaddingScheme badoption.Listable[string] `json:"padding_scheme,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AnyTLSUser struct {
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
Password string `json:"password,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AnyTLSOutboundOptions struct {
|
||||||
|
DialerOptions
|
||||||
|
ServerOptions
|
||||||
|
OutboundTLSOptionsContainer
|
||||||
|
Password string `json:"password,omitempty"`
|
||||||
|
IdleSessionCheckInterval badoption.Duration `json:"idle_session_check_interval,omitempty"`
|
||||||
|
IdleSessionTimeout badoption.Duration `json:"idle_session_timeout,omitempty"`
|
||||||
|
}
|
130
protocol/anytls/inbound.go
Normal file
130
protocol/anytls/inbound.go
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
package anytls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/adapter/inbound"
|
||||||
|
"github.com/sagernet/sing-box/common/listener"
|
||||||
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
|
"github.com/sagernet/sing-box/common/uot"
|
||||||
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
"github.com/sagernet/sing-box/log"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
"github.com/sagernet/sing-box/transport/anytls"
|
||||||
|
"github.com/sagernet/sing-box/transport/anytls/padding"
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
"github.com/sagernet/sing/common/auth"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RegisterInbound(registry *inbound.Registry) {
|
||||||
|
inbound.Register[option.AnyTLSInboundOptions](registry, C.TypeAnyTLS, NewInbound)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Inbound struct {
|
||||||
|
inbound.Adapter
|
||||||
|
tlsConfig tls.ServerConfig
|
||||||
|
router adapter.ConnectionRouterEx
|
||||||
|
logger logger.ContextLogger
|
||||||
|
listener *listener.Listener
|
||||||
|
service *anytls.Service
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.AnyTLSInboundOptions) (adapter.Inbound, error) {
|
||||||
|
inbound := &Inbound{
|
||||||
|
Adapter: inbound.NewAdapter(C.TypeAnyTLS, tag),
|
||||||
|
router: uot.NewRouter(router, logger),
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
|
||||||
|
if options.TLS != nil && options.TLS.Enabled {
|
||||||
|
tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
inbound.tlsConfig = tlsConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
paddingScheme := padding.DefaultPaddingScheme
|
||||||
|
if len(options.PaddingScheme) > 0 {
|
||||||
|
paddingScheme = []byte(strings.Join(options.PaddingScheme, "\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
service, err := anytls.NewService(anytls.ServiceConfig{
|
||||||
|
Users: common.Map(options.Users, func(it option.AnyTLSUser) anytls.User {
|
||||||
|
return (anytls.User)(it)
|
||||||
|
}),
|
||||||
|
TLSConfig: inbound.tlsConfig,
|
||||||
|
PaddingScheme: paddingScheme,
|
||||||
|
Handler: (*inboundHandler)(inbound),
|
||||||
|
Logger: logger,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
inbound.service = service
|
||||||
|
inbound.listener = listener.New(listener.Options{
|
||||||
|
Context: ctx,
|
||||||
|
Logger: logger,
|
||||||
|
Network: []string{N.NetworkTCP},
|
||||||
|
Listen: options.ListenOptions,
|
||||||
|
ConnectionHandler: inbound,
|
||||||
|
})
|
||||||
|
return inbound, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||||
|
if stage != adapter.StartStateStart {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if h.tlsConfig != nil {
|
||||||
|
err := h.tlsConfig.Start()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return h.listener.Start()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Inbound) Close() error {
|
||||||
|
return common.Close(h.listener, h.tlsConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||||
|
err := h.service.NewConnection(adapter.WithContext(log.ContextWithNewID(ctx), &metadata), conn, metadata.Source, metadata.Destination, onClose)
|
||||||
|
N.CloseOnHandshakeFailure(conn, onClose, err)
|
||||||
|
if err != nil {
|
||||||
|
if E.IsClosedOrCanceled(err) {
|
||||||
|
h.logger.DebugContext(ctx, "connection closed: ", err)
|
||||||
|
} else {
|
||||||
|
h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type inboundHandler Inbound
|
||||||
|
|
||||||
|
func (h *inboundHandler) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
|
||||||
|
var metadata adapter.InboundContext
|
||||||
|
metadata.Inbound = h.Tag()
|
||||||
|
metadata.InboundType = h.Type()
|
||||||
|
//nolint:staticcheck
|
||||||
|
metadata.InboundDetour = h.listener.ListenOptions().Detour
|
||||||
|
//nolint:staticcheck
|
||||||
|
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
|
||||||
|
metadata.Source = source
|
||||||
|
metadata.Destination = destination
|
||||||
|
if userName, _ := auth.UserFromContext[string](ctx); userName != "" {
|
||||||
|
metadata.User = userName
|
||||||
|
h.logger.InfoContext(ctx, "[", userName, "] inbound connection to ", metadata.Destination)
|
||||||
|
} else {
|
||||||
|
h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
|
||||||
|
}
|
||||||
|
h.router.RouteConnectionEx(ctx, conn, metadata, onClose)
|
||||||
|
}
|
100
protocol/anytls/outbound.go
Normal file
100
protocol/anytls/outbound.go
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
package anytls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/adapter/outbound"
|
||||||
|
"github.com/sagernet/sing-box/common/dialer"
|
||||||
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
"github.com/sagernet/sing-box/log"
|
||||||
|
"github.com/sagernet/sing-box/option"
|
||||||
|
"github.com/sagernet/sing-box/transport/anytls"
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
"github.com/sagernet/sing/common/uot"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RegisterOutbound(registry *outbound.Registry) {
|
||||||
|
outbound.Register[option.AnyTLSOutboundOptions](registry, C.TypeAnyTLS, NewOutbound)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Outbound struct {
|
||||||
|
outbound.Adapter
|
||||||
|
client *anytls.Client
|
||||||
|
uotClient *uot.Client
|
||||||
|
logger log.ContextLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.AnyTLSOutboundOptions) (adapter.Outbound, error) {
|
||||||
|
outbound := &Outbound{
|
||||||
|
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeAnyTLS, tag, []string{N.NetworkTCP}, options.DialerOptions),
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
if options.TLS == nil || !options.TLS.Enabled {
|
||||||
|
return nil, C.ErrTLSRequired
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig, err := tls.NewClient(ctx, options.Server, common.PtrValueOrDefault(options.TLS))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
outboundDialer, err := dialer.NewWithOptions(dialer.Options{
|
||||||
|
Context: ctx,
|
||||||
|
Options: options.DialerOptions,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
client, err := anytls.NewClient(ctx, anytls.ClientConfig{
|
||||||
|
Password: options.Password,
|
||||||
|
IdleSessionCheckInterval: options.IdleSessionCheckInterval.Build(),
|
||||||
|
IdleSessionTimeout: options.IdleSessionTimeout.Build(),
|
||||||
|
Server: options.ServerOptions.Build(),
|
||||||
|
Dialer: outboundDialer,
|
||||||
|
TLSConfig: tlsConfig,
|
||||||
|
Logger: logger,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
outbound.client = client
|
||||||
|
|
||||||
|
outbound.uotClient = &uot.Client{
|
||||||
|
Dialer: outbound,
|
||||||
|
Version: uot.Version,
|
||||||
|
}
|
||||||
|
return outbound, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||||
|
ctx, metadata := adapter.ExtendContext(ctx)
|
||||||
|
metadata.Outbound = h.Tag()
|
||||||
|
metadata.Destination = destination
|
||||||
|
switch N.NetworkName(network) {
|
||||||
|
case N.NetworkTCP:
|
||||||
|
h.logger.InfoContext(ctx, "outbound connection to ", destination)
|
||||||
|
return h.client.CreateProxy(ctx, destination)
|
||||||
|
case N.NetworkUDP:
|
||||||
|
h.logger.InfoContext(ctx, "outbound UoT packet connection to ", destination)
|
||||||
|
return h.uotClient.DialContext(ctx, network, destination)
|
||||||
|
}
|
||||||
|
return nil, os.ErrInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||||
|
ctx, metadata := adapter.ExtendContext(ctx)
|
||||||
|
metadata.Outbound = h.Tag()
|
||||||
|
metadata.Destination = destination
|
||||||
|
h.logger.InfoContext(ctx, "outbound UoT packet connection to ", destination)
|
||||||
|
return h.uotClient.ListenPacket(ctx, destination)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Outbound) Close() error {
|
||||||
|
return common.Close(h.client)
|
||||||
|
}
|
101
transport/anytls/client.go
Normal file
101
transport/anytls/client.go
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
package anytls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/binary"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
|
"github.com/sagernet/sing-box/transport/anytls/padding"
|
||||||
|
"github.com/sagernet/sing-box/transport/anytls/session"
|
||||||
|
"github.com/sagernet/sing/common/atomic"
|
||||||
|
"github.com/sagernet/sing/common/buf"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ClientConfig struct {
|
||||||
|
Password string
|
||||||
|
IdleSessionCheckInterval time.Duration
|
||||||
|
IdleSessionTimeout time.Duration
|
||||||
|
Server M.Socksaddr
|
||||||
|
Dialer N.Dialer
|
||||||
|
TLSConfig tls.Config
|
||||||
|
Logger logger.ContextLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
type Client struct {
|
||||||
|
passwordSha256 []byte
|
||||||
|
tlsConfig tls.Config
|
||||||
|
dialer N.Dialer
|
||||||
|
server M.Socksaddr
|
||||||
|
sessionClient *session.Client
|
||||||
|
padding atomic.TypedValue[*padding.PaddingFactory]
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClient(ctx context.Context, config ClientConfig) (*Client, error) {
|
||||||
|
pw := sha256.Sum256([]byte(config.Password))
|
||||||
|
c := &Client{
|
||||||
|
passwordSha256: pw[:],
|
||||||
|
tlsConfig: config.TLSConfig,
|
||||||
|
dialer: config.Dialer,
|
||||||
|
server: config.Server,
|
||||||
|
}
|
||||||
|
// Initialize the padding state of this client
|
||||||
|
padding.UpdatePaddingScheme(padding.DefaultPaddingScheme, &c.padding)
|
||||||
|
c.sessionClient = session.NewClient(ctx, c.CreateOutboundTLSConnection, &c.padding, config.IdleSessionCheckInterval, config.IdleSessionTimeout)
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) CreateProxy(ctx context.Context, destination M.Socksaddr) (net.Conn, error) {
|
||||||
|
conn, err := c.sessionClient.CreateStream(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = M.SocksaddrSerializer.WriteAddrPort(conn, destination)
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) CreateOutboundTLSConnection(ctx context.Context) (net.Conn, error) {
|
||||||
|
conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
b := buf.NewPacket()
|
||||||
|
defer b.Release()
|
||||||
|
|
||||||
|
b.Write(c.passwordSha256)
|
||||||
|
var paddingLen int
|
||||||
|
if pad := c.padding.Load().GenerateRecordPayloadSizes(0); len(pad) > 0 {
|
||||||
|
paddingLen = pad[0]
|
||||||
|
}
|
||||||
|
binary.BigEndian.PutUint16(b.Extend(2), uint16(paddingLen))
|
||||||
|
if paddingLen > 0 {
|
||||||
|
b.WriteZeroN(paddingLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err = tls.ClientHandshake(ctx, conn, c.tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = b.WriteTo(conn)
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Client) Close() error {
|
||||||
|
return h.sessionClient.Close()
|
||||||
|
}
|
92
transport/anytls/padding/padding.go
Normal file
92
transport/anytls/padding/padding.go
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
package padding
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/md5"
|
||||||
|
"crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/transport/anytls/util"
|
||||||
|
"github.com/sagernet/sing/common/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
const CheckMark = -1
|
||||||
|
|
||||||
|
var DefaultPaddingScheme = []byte(`stop=8
|
||||||
|
0=34-120
|
||||||
|
1=100-400
|
||||||
|
2=400-500,c,500-1000,c,400-500,c,500-1000,c,500-1000,c,400-500
|
||||||
|
3=500-1000
|
||||||
|
4=500-1000
|
||||||
|
5=500-1000
|
||||||
|
6=500-1000
|
||||||
|
7=500-1000`)
|
||||||
|
|
||||||
|
type PaddingFactory struct {
|
||||||
|
scheme util.StringMap
|
||||||
|
RawScheme []byte
|
||||||
|
Stop uint32
|
||||||
|
Md5 string
|
||||||
|
}
|
||||||
|
|
||||||
|
func UpdatePaddingScheme(rawScheme []byte, to *atomic.TypedValue[*PaddingFactory]) bool {
|
||||||
|
if p := NewPaddingFactory(rawScheme); p != nil {
|
||||||
|
to.Store(p)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPaddingFactory(rawScheme []byte) *PaddingFactory {
|
||||||
|
p := &PaddingFactory{
|
||||||
|
RawScheme: rawScheme,
|
||||||
|
Md5: fmt.Sprintf("%x", md5.Sum(rawScheme)),
|
||||||
|
}
|
||||||
|
scheme := util.StringMapFromBytes(rawScheme)
|
||||||
|
if len(scheme) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if stop, err := strconv.Atoi(scheme["stop"]); err == nil {
|
||||||
|
p.Stop = uint32(stop)
|
||||||
|
} else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
p.scheme = scheme
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PaddingFactory) GenerateRecordPayloadSizes(pkt uint32) (pktSizes []int) {
|
||||||
|
if s, ok := p.scheme[strconv.Itoa(int(pkt))]; ok {
|
||||||
|
sRanges := strings.Split(s, ",")
|
||||||
|
for _, sRange := range sRanges {
|
||||||
|
sRangeMinMax := strings.Split(sRange, "-")
|
||||||
|
if len(sRangeMinMax) == 2 {
|
||||||
|
_min, err := strconv.ParseInt(sRangeMinMax[0], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_max, err := strconv.ParseInt(sRangeMinMax[1], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _min > _max {
|
||||||
|
_min, _max = _max, _min
|
||||||
|
}
|
||||||
|
if _min <= 0 || _max <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _min == _max {
|
||||||
|
pktSizes = append(pktSizes, int(_min))
|
||||||
|
} else {
|
||||||
|
i, _ := rand.Int(rand.Reader, big.NewInt(_max-_min))
|
||||||
|
pktSizes = append(pktSizes, int(i.Int64()+_min))
|
||||||
|
}
|
||||||
|
} else if sRange == "c" {
|
||||||
|
pktSizes = append(pktSizes, CheckMark)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
74
transport/anytls/pipe/deadline.go
Normal file
74
transport/anytls/pipe/deadline.go
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
package pipe
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PipeDeadline is an abstraction for handling timeouts.
|
||||||
|
type PipeDeadline struct {
|
||||||
|
mu sync.Mutex // Guards timer and cancel
|
||||||
|
timer *time.Timer
|
||||||
|
cancel chan struct{} // Must be non-nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func MakePipeDeadline() PipeDeadline {
|
||||||
|
return PipeDeadline{cancel: make(chan struct{})}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set sets the point in time when the deadline will time out.
|
||||||
|
// A timeout event is signaled by closing the channel returned by waiter.
|
||||||
|
// Once a timeout has occurred, the deadline can be refreshed by specifying a
|
||||||
|
// t value in the future.
|
||||||
|
//
|
||||||
|
// A zero value for t prevents timeout.
|
||||||
|
func (d *PipeDeadline) Set(t time.Time) {
|
||||||
|
d.mu.Lock()
|
||||||
|
defer d.mu.Unlock()
|
||||||
|
|
||||||
|
if d.timer != nil && !d.timer.Stop() {
|
||||||
|
<-d.cancel // Wait for the timer callback to finish and close cancel
|
||||||
|
}
|
||||||
|
d.timer = nil
|
||||||
|
|
||||||
|
// Time is zero, then there is no deadline.
|
||||||
|
closed := isClosedChan(d.cancel)
|
||||||
|
if t.IsZero() {
|
||||||
|
if closed {
|
||||||
|
d.cancel = make(chan struct{})
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Time in the future, setup a timer to cancel in the future.
|
||||||
|
if dur := time.Until(t); dur > 0 {
|
||||||
|
if closed {
|
||||||
|
d.cancel = make(chan struct{})
|
||||||
|
}
|
||||||
|
d.timer = time.AfterFunc(dur, func() {
|
||||||
|
close(d.cancel)
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Time in the past, so close immediately.
|
||||||
|
if !closed {
|
||||||
|
close(d.cancel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait returns a channel that is closed when the deadline is exceeded.
|
||||||
|
func (d *PipeDeadline) Wait() chan struct{} {
|
||||||
|
d.mu.Lock()
|
||||||
|
defer d.mu.Unlock()
|
||||||
|
return d.cancel
|
||||||
|
}
|
||||||
|
|
||||||
|
func isClosedChan(c <-chan struct{}) bool {
|
||||||
|
select {
|
||||||
|
case <-c:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
232
transport/anytls/pipe/io_pipe.go
Normal file
232
transport/anytls/pipe/io_pipe.go
Normal file
@ -0,0 +1,232 @@
|
|||||||
|
// Copyright 2009 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// Pipe adapter to connect code expecting an io.Reader
|
||||||
|
// with code expecting an io.Writer.
|
||||||
|
|
||||||
|
package pipe
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// onceError is an object that will only store an error once.
|
||||||
|
type onceError struct {
|
||||||
|
sync.Mutex // guards following
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *onceError) Store(err error) {
|
||||||
|
a.Lock()
|
||||||
|
defer a.Unlock()
|
||||||
|
if a.err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
a.err = err
|
||||||
|
}
|
||||||
|
func (a *onceError) Load() error {
|
||||||
|
a.Lock()
|
||||||
|
defer a.Unlock()
|
||||||
|
return a.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// A pipe is the shared pipe structure underlying PipeReader and PipeWriter.
|
||||||
|
type pipe struct {
|
||||||
|
wrMu sync.Mutex // Serializes Write operations
|
||||||
|
wrCh chan []byte
|
||||||
|
rdCh chan int
|
||||||
|
|
||||||
|
once sync.Once // Protects closing done
|
||||||
|
done chan struct{}
|
||||||
|
rerr onceError
|
||||||
|
werr onceError
|
||||||
|
|
||||||
|
readDeadline PipeDeadline
|
||||||
|
writeDeadline PipeDeadline
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pipe) read(b []byte) (n int, err error) {
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
return 0, p.readCloseError()
|
||||||
|
case <-p.readDeadline.Wait():
|
||||||
|
return 0, os.ErrDeadlineExceeded
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case bw := <-p.wrCh:
|
||||||
|
nr := copy(b, bw)
|
||||||
|
p.rdCh <- nr
|
||||||
|
return nr, nil
|
||||||
|
case <-p.done:
|
||||||
|
return 0, p.readCloseError()
|
||||||
|
case <-p.readDeadline.Wait():
|
||||||
|
return 0, os.ErrDeadlineExceeded
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pipe) closeRead(err error) error {
|
||||||
|
if err == nil {
|
||||||
|
err = io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
p.rerr.Store(err)
|
||||||
|
p.once.Do(func() { close(p.done) })
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pipe) write(b []byte) (n int, err error) {
|
||||||
|
select {
|
||||||
|
case <-p.done:
|
||||||
|
return 0, p.writeCloseError()
|
||||||
|
case <-p.writeDeadline.Wait():
|
||||||
|
return 0, os.ErrDeadlineExceeded
|
||||||
|
default:
|
||||||
|
p.wrMu.Lock()
|
||||||
|
defer p.wrMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
for once := true; once || len(b) > 0; once = false {
|
||||||
|
select {
|
||||||
|
case p.wrCh <- b:
|
||||||
|
nw := <-p.rdCh
|
||||||
|
b = b[nw:]
|
||||||
|
n += nw
|
||||||
|
case <-p.done:
|
||||||
|
return n, p.writeCloseError()
|
||||||
|
case <-p.writeDeadline.Wait():
|
||||||
|
return n, os.ErrDeadlineExceeded
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pipe) closeWrite(err error) error {
|
||||||
|
if err == nil {
|
||||||
|
err = io.EOF
|
||||||
|
}
|
||||||
|
p.werr.Store(err)
|
||||||
|
p.once.Do(func() { close(p.done) })
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// readCloseError is considered internal to the pipe type.
|
||||||
|
func (p *pipe) readCloseError() error {
|
||||||
|
rerr := p.rerr.Load()
|
||||||
|
if werr := p.werr.Load(); rerr == nil && werr != nil {
|
||||||
|
return werr
|
||||||
|
}
|
||||||
|
return io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeCloseError is considered internal to the pipe type.
|
||||||
|
func (p *pipe) writeCloseError() error {
|
||||||
|
werr := p.werr.Load()
|
||||||
|
if rerr := p.rerr.Load(); werr == nil && rerr != nil {
|
||||||
|
return rerr
|
||||||
|
}
|
||||||
|
return io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
|
||||||
|
// A PipeReader is the read half of a pipe.
|
||||||
|
type PipeReader struct{ pipe }
|
||||||
|
|
||||||
|
// Read implements the standard Read interface:
|
||||||
|
// it reads data from the pipe, blocking until a writer
|
||||||
|
// arrives or the write end is closed.
|
||||||
|
// If the write end is closed with an error, that error is
|
||||||
|
// returned as err; otherwise err is EOF.
|
||||||
|
func (r *PipeReader) Read(data []byte) (n int, err error) {
|
||||||
|
return r.pipe.read(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the reader; subsequent writes to the
|
||||||
|
// write half of the pipe will return the error [ErrClosedPipe].
|
||||||
|
func (r *PipeReader) Close() error {
|
||||||
|
return r.CloseWithError(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseWithError closes the reader; subsequent writes
|
||||||
|
// to the write half of the pipe will return the error err.
|
||||||
|
//
|
||||||
|
// CloseWithError never overwrites the previous error if it exists
|
||||||
|
// and always returns nil.
|
||||||
|
func (r *PipeReader) CloseWithError(err error) error {
|
||||||
|
return r.pipe.closeRead(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// A PipeWriter is the write half of a pipe.
|
||||||
|
type PipeWriter struct{ r PipeReader }
|
||||||
|
|
||||||
|
// Write implements the standard Write interface:
|
||||||
|
// it writes data to the pipe, blocking until one or more readers
|
||||||
|
// have consumed all the data or the read end is closed.
|
||||||
|
// If the read end is closed with an error, that err is
|
||||||
|
// returned as err; otherwise err is [ErrClosedPipe].
|
||||||
|
func (w *PipeWriter) Write(data []byte) (n int, err error) {
|
||||||
|
return w.r.pipe.write(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the writer; subsequent reads from the
|
||||||
|
// read half of the pipe will return no bytes and EOF.
|
||||||
|
func (w *PipeWriter) Close() error {
|
||||||
|
return w.CloseWithError(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseWithError closes the writer; subsequent reads from the
|
||||||
|
// read half of the pipe will return no bytes and the error err,
|
||||||
|
// or EOF if err is nil.
|
||||||
|
//
|
||||||
|
// CloseWithError never overwrites the previous error if it exists
|
||||||
|
// and always returns nil.
|
||||||
|
func (w *PipeWriter) CloseWithError(err error) error {
|
||||||
|
return w.r.pipe.closeWrite(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pipe creates a synchronous in-memory pipe.
|
||||||
|
// It can be used to connect code expecting an [io.Reader]
|
||||||
|
// with code expecting an [io.Writer].
|
||||||
|
//
|
||||||
|
// Reads and Writes on the pipe are matched one to one
|
||||||
|
// except when multiple Reads are needed to consume a single Write.
|
||||||
|
// That is, each Write to the [PipeWriter] blocks until it has satisfied
|
||||||
|
// one or more Reads from the [PipeReader] that fully consume
|
||||||
|
// the written data.
|
||||||
|
// The data is copied directly from the Write to the corresponding
|
||||||
|
// Read (or Reads); there is no internal buffering.
|
||||||
|
//
|
||||||
|
// It is safe to call Read and Write in parallel with each other or with Close.
|
||||||
|
// Parallel calls to Read and parallel calls to Write are also safe:
|
||||||
|
// the individual calls will be gated sequentially.
|
||||||
|
//
|
||||||
|
// Added SetReadDeadline and SetWriteDeadline methods based on `io.Pipe`.
|
||||||
|
func Pipe() (*PipeReader, *PipeWriter) {
|
||||||
|
pw := &PipeWriter{r: PipeReader{pipe: pipe{
|
||||||
|
wrCh: make(chan []byte),
|
||||||
|
rdCh: make(chan int),
|
||||||
|
done: make(chan struct{}),
|
||||||
|
readDeadline: MakePipeDeadline(),
|
||||||
|
writeDeadline: MakePipeDeadline(),
|
||||||
|
}}}
|
||||||
|
return &pw.r, pw
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PipeReader) SetReadDeadline(t time.Time) error {
|
||||||
|
if isClosedChan(p.done) {
|
||||||
|
return io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
p.readDeadline.Set(t)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PipeWriter) SetWriteDeadline(t time.Time) error {
|
||||||
|
if isClosedChan(p.r.done) {
|
||||||
|
return io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
p.r.writeDeadline.Set(t)
|
||||||
|
return nil
|
||||||
|
}
|
125
transport/anytls/service.go
Normal file
125
transport/anytls/service.go
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
package anytls
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
|
"github.com/sagernet/sing-box/transport/anytls/padding"
|
||||||
|
"github.com/sagernet/sing-box/transport/anytls/session"
|
||||||
|
"github.com/sagernet/sing/common/atomic"
|
||||||
|
"github.com/sagernet/sing/common/auth"
|
||||||
|
"github.com/sagernet/sing/common/buf"
|
||||||
|
"github.com/sagernet/sing/common/bufio"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Service struct {
|
||||||
|
users map[[32]byte]string
|
||||||
|
padding atomic.TypedValue[*padding.PaddingFactory]
|
||||||
|
tlsConfig tls.ServerConfig
|
||||||
|
handler N.TCPConnectionHandlerEx
|
||||||
|
logger logger.ContextLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServiceConfig struct {
|
||||||
|
PaddingScheme []byte
|
||||||
|
Users []User
|
||||||
|
TLSConfig tls.ServerConfig
|
||||||
|
Handler N.TCPConnectionHandlerEx
|
||||||
|
Logger logger.ContextLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
Name string
|
||||||
|
Password string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewService(config ServiceConfig) (*Service, error) {
|
||||||
|
service := &Service{
|
||||||
|
tlsConfig: config.TLSConfig,
|
||||||
|
handler: config.Handler,
|
||||||
|
logger: config.Logger,
|
||||||
|
users: make(map[[32]byte]string),
|
||||||
|
}
|
||||||
|
|
||||||
|
if service.handler == nil || service.logger == nil {
|
||||||
|
return nil, os.ErrInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, user := range config.Users {
|
||||||
|
service.users[sha256.Sum256([]byte(user.Password))] = user.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
if !padding.UpdatePaddingScheme(config.PaddingScheme, &service.padding) {
|
||||||
|
return nil, errors.New("incorrect padding scheme format")
|
||||||
|
}
|
||||||
|
|
||||||
|
return service, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) NewConnection(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) error {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if s.tlsConfig != nil {
|
||||||
|
conn, err = tls.ServerHandshake(ctx, conn, s.tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b := buf.NewPacket()
|
||||||
|
defer b.Release()
|
||||||
|
|
||||||
|
_, err = b.ReadOnceFrom(conn)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
conn = bufio.NewCachedConn(conn, b)
|
||||||
|
|
||||||
|
by, err := b.ReadBytes(32)
|
||||||
|
if err != nil {
|
||||||
|
b.Reset()
|
||||||
|
return os.ErrInvalid
|
||||||
|
}
|
||||||
|
var passwordSha256 [32]byte
|
||||||
|
copy(passwordSha256[:], by)
|
||||||
|
if user, ok := s.users[passwordSha256]; ok {
|
||||||
|
ctx = auth.ContextWithUser(ctx, user)
|
||||||
|
} else {
|
||||||
|
b.Reset()
|
||||||
|
return os.ErrInvalid
|
||||||
|
}
|
||||||
|
by, err = b.ReadBytes(2)
|
||||||
|
if err != nil {
|
||||||
|
b.Reset()
|
||||||
|
return os.ErrInvalid
|
||||||
|
}
|
||||||
|
paddingLen := binary.BigEndian.Uint16(by)
|
||||||
|
if paddingLen > 0 {
|
||||||
|
_, err = b.ReadBytes(int(paddingLen))
|
||||||
|
if err != nil {
|
||||||
|
b.Reset()
|
||||||
|
return os.ErrInvalid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
session := session.NewServerSession(conn, func(stream *session.Stream) {
|
||||||
|
destination, err := M.SocksaddrSerializer.ReadAddrPort(stream)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.ErrorContext(ctx, "ReadAddrPort:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.handler.NewConnectionEx(ctx, stream, source, destination, onClose)
|
||||||
|
}, &s.padding)
|
||||||
|
session.Run()
|
||||||
|
session.Close()
|
||||||
|
return nil
|
||||||
|
}
|
159
transport/anytls/session/client.go
Normal file
159
transport/anytls/session/client.go
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/transport/anytls/padding"
|
||||||
|
"github.com/sagernet/sing-box/transport/anytls/skiplist"
|
||||||
|
"github.com/sagernet/sing-box/transport/anytls/util"
|
||||||
|
"github.com/sagernet/sing/common"
|
||||||
|
"github.com/sagernet/sing/common/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Client struct {
|
||||||
|
die context.Context
|
||||||
|
dieCancel context.CancelFunc
|
||||||
|
|
||||||
|
dialOut func(ctx context.Context) (net.Conn, error)
|
||||||
|
|
||||||
|
sessionCounter atomic.Uint64
|
||||||
|
idleSession *skiplist.SkipList[uint64, *Session]
|
||||||
|
idleSessionLock sync.Mutex
|
||||||
|
|
||||||
|
padding *atomic.TypedValue[*padding.PaddingFactory]
|
||||||
|
|
||||||
|
idleSessionTimeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClient(ctx context.Context, dialOut func(ctx context.Context) (net.Conn, error), _padding *atomic.TypedValue[*padding.PaddingFactory], idleSessionCheckInterval, idleSessionTimeout time.Duration) *Client {
|
||||||
|
c := &Client{
|
||||||
|
dialOut: dialOut,
|
||||||
|
padding: _padding,
|
||||||
|
idleSessionTimeout: idleSessionTimeout,
|
||||||
|
}
|
||||||
|
if idleSessionCheckInterval <= time.Second*5 {
|
||||||
|
idleSessionCheckInterval = time.Second * 30
|
||||||
|
}
|
||||||
|
if c.idleSessionTimeout <= time.Second*5 {
|
||||||
|
c.idleSessionTimeout = time.Second * 30
|
||||||
|
}
|
||||||
|
c.die, c.dieCancel = context.WithCancel(ctx)
|
||||||
|
c.idleSession = skiplist.NewSkipList[uint64, *Session]()
|
||||||
|
util.StartRoutine(c.die, idleSessionCheckInterval, c.idleCleanup)
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) CreateStream(ctx context.Context) (net.Conn, error) {
|
||||||
|
select {
|
||||||
|
case <-c.die.Done():
|
||||||
|
return nil, io.ErrClosedPipe
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
var session *Session
|
||||||
|
var stream *Stream
|
||||||
|
var err error
|
||||||
|
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
session, err = c.findSession(ctx)
|
||||||
|
if session == nil {
|
||||||
|
return nil, fmt.Errorf("failed to create session: %w", err)
|
||||||
|
}
|
||||||
|
stream, err = session.OpenStream()
|
||||||
|
if err != nil {
|
||||||
|
common.Close(session, stream)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if session == nil || stream == nil {
|
||||||
|
return nil, fmt.Errorf("too many closed session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stream.dieHook = func() {
|
||||||
|
if session.IsClosed() {
|
||||||
|
if session.dieHook != nil {
|
||||||
|
session.dieHook()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
c.idleSessionLock.Lock()
|
||||||
|
session.idleSince = time.Now()
|
||||||
|
c.idleSession.Insert(math.MaxUint64-session.seq, session)
|
||||||
|
c.idleSessionLock.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return stream, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) findSession(ctx context.Context) (*Session, error) {
|
||||||
|
var idle *Session
|
||||||
|
|
||||||
|
c.idleSessionLock.Lock()
|
||||||
|
if !c.idleSession.IsEmpty() {
|
||||||
|
it := c.idleSession.Iterate()
|
||||||
|
idle = it.Value()
|
||||||
|
c.idleSession.Remove(it.Key())
|
||||||
|
}
|
||||||
|
c.idleSessionLock.Unlock()
|
||||||
|
|
||||||
|
if idle == nil {
|
||||||
|
s, err := c.createSession(ctx)
|
||||||
|
return s, err
|
||||||
|
}
|
||||||
|
return idle, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) createSession(ctx context.Context) (*Session, error) {
|
||||||
|
underlying, err := c.dialOut(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
session := NewClientSession(underlying, c.padding)
|
||||||
|
session.seq = c.sessionCounter.Add(1)
|
||||||
|
session.dieHook = func() {
|
||||||
|
//logrus.Debugln("session died", session)
|
||||||
|
c.idleSessionLock.Lock()
|
||||||
|
c.idleSession.Remove(math.MaxUint64 - session.seq)
|
||||||
|
c.idleSessionLock.Unlock()
|
||||||
|
}
|
||||||
|
session.Run()
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) Close() error {
|
||||||
|
c.dieCancel()
|
||||||
|
go c.idleCleanupExpTime(time.Now())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) idleCleanup() {
|
||||||
|
c.idleCleanupExpTime(time.Now().Add(-c.idleSessionTimeout))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) idleCleanupExpTime(expTime time.Time) {
|
||||||
|
var sessionToRemove = make([]*Session, 0)
|
||||||
|
|
||||||
|
c.idleSessionLock.Lock()
|
||||||
|
it := c.idleSession.Iterate()
|
||||||
|
for it.IsNotEnd() {
|
||||||
|
session := it.Value()
|
||||||
|
if session.idleSince.Before(expTime) {
|
||||||
|
sessionToRemove = append(sessionToRemove, session)
|
||||||
|
c.idleSession.Remove(it.Key())
|
||||||
|
}
|
||||||
|
it.MoveToNext()
|
||||||
|
}
|
||||||
|
c.idleSessionLock.Unlock()
|
||||||
|
|
||||||
|
for _, session := range sessionToRemove {
|
||||||
|
session.Close()
|
||||||
|
}
|
||||||
|
}
|
44
transport/anytls/session/frame.go
Normal file
44
transport/anytls/session/frame.go
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
)
|
||||||
|
|
||||||
|
const ( // cmds
|
||||||
|
cmdWaste = 0 // Paddings
|
||||||
|
cmdSYN = 1 // stream open
|
||||||
|
cmdPSH = 2 // data push
|
||||||
|
cmdFIN = 3 // stream close, a.k.a EOF mark
|
||||||
|
cmdSettings = 4 // Settings
|
||||||
|
cmdAlert = 5 // Alert
|
||||||
|
cmdUpdatePaddingScheme = 6 // update padding scheme
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
headerOverHeadSize = 1 + 4 + 2
|
||||||
|
)
|
||||||
|
|
||||||
|
// frame defines a packet from or to be multiplexed into a single connection
|
||||||
|
type frame struct {
|
||||||
|
cmd byte // 1
|
||||||
|
sid uint32 // 4
|
||||||
|
data []byte // 2 + len(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFrame(cmd byte, sid uint32) frame {
|
||||||
|
return frame{cmd: cmd, sid: sid}
|
||||||
|
}
|
||||||
|
|
||||||
|
type rawHeader [headerOverHeadSize]byte
|
||||||
|
|
||||||
|
func (h rawHeader) Cmd() byte {
|
||||||
|
return h[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h rawHeader) StreamID() uint32 {
|
||||||
|
return binary.BigEndian.Uint32(h[1:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h rawHeader) Length() uint16 {
|
||||||
|
return binary.BigEndian.Uint16(h[5:])
|
||||||
|
}
|
383
transport/anytls/session/session.go
Normal file
383
transport/anytls/session/session.go
Normal file
@ -0,0 +1,383 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/md5"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"runtime/debug"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/constant"
|
||||||
|
"github.com/sagernet/sing-box/log"
|
||||||
|
"github.com/sagernet/sing-box/transport/anytls/padding"
|
||||||
|
"github.com/sagernet/sing-box/transport/anytls/util"
|
||||||
|
"github.com/sagernet/sing/common/atomic"
|
||||||
|
"github.com/sagernet/sing/common/buf"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Session struct {
|
||||||
|
conn net.Conn
|
||||||
|
connLock sync.Mutex
|
||||||
|
|
||||||
|
streams map[uint32]*Stream
|
||||||
|
streamId atomic.Uint32
|
||||||
|
streamLock sync.RWMutex
|
||||||
|
|
||||||
|
dieOnce sync.Once
|
||||||
|
die chan struct{}
|
||||||
|
dieHook func()
|
||||||
|
|
||||||
|
// pool
|
||||||
|
seq uint64
|
||||||
|
idleSince time.Time
|
||||||
|
padding *atomic.TypedValue[*padding.PaddingFactory]
|
||||||
|
|
||||||
|
// client
|
||||||
|
isClient bool
|
||||||
|
sendPadding bool
|
||||||
|
buffering bool
|
||||||
|
buffer []byte
|
||||||
|
pktCounter atomic.Uint32
|
||||||
|
|
||||||
|
// server
|
||||||
|
onNewStream func(stream *Stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClientSession(conn net.Conn, _padding *atomic.TypedValue[*padding.PaddingFactory]) *Session {
|
||||||
|
s := &Session{
|
||||||
|
conn: conn,
|
||||||
|
isClient: true,
|
||||||
|
sendPadding: true,
|
||||||
|
padding: _padding,
|
||||||
|
}
|
||||||
|
s.die = make(chan struct{})
|
||||||
|
s.streams = make(map[uint32]*Stream)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewServerSession(conn net.Conn, onNewStream func(stream *Stream), _padding *atomic.TypedValue[*padding.PaddingFactory]) *Session {
|
||||||
|
s := &Session{
|
||||||
|
conn: conn,
|
||||||
|
onNewStream: onNewStream,
|
||||||
|
padding: _padding,
|
||||||
|
}
|
||||||
|
s.die = make(chan struct{})
|
||||||
|
s.streams = make(map[uint32]*Stream)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) Run() {
|
||||||
|
if !s.isClient {
|
||||||
|
s.recvLoop()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
settings := util.StringMap{
|
||||||
|
"v": "1",
|
||||||
|
"client": "sing-box/" + constant.Version,
|
||||||
|
"padding-md5": s.padding.Load().Md5,
|
||||||
|
}
|
||||||
|
f := newFrame(cmdSettings, 0)
|
||||||
|
f.data = settings.ToBytes()
|
||||||
|
s.buffering = true
|
||||||
|
s.writeFrame(f)
|
||||||
|
|
||||||
|
go s.recvLoop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsClosed does a safe check to see if we have shutdown
|
||||||
|
func (s *Session) IsClosed() bool {
|
||||||
|
select {
|
||||||
|
case <-s.die:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close is used to close the session and all streams.
|
||||||
|
func (s *Session) Close() error {
|
||||||
|
var once bool
|
||||||
|
s.dieOnce.Do(func() {
|
||||||
|
close(s.die)
|
||||||
|
once = true
|
||||||
|
})
|
||||||
|
|
||||||
|
if once {
|
||||||
|
if s.dieHook != nil {
|
||||||
|
s.dieHook()
|
||||||
|
}
|
||||||
|
s.streamLock.Lock()
|
||||||
|
for k := range s.streams {
|
||||||
|
s.streams[k].sessionClose()
|
||||||
|
}
|
||||||
|
s.streamLock.Unlock()
|
||||||
|
return s.conn.Close()
|
||||||
|
} else {
|
||||||
|
return io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenStream is used to create a new stream for CLIENT
|
||||||
|
func (s *Session) OpenStream() (*Stream, error) {
|
||||||
|
if s.IsClosed() {
|
||||||
|
return nil, io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
|
||||||
|
sid := s.streamId.Add(1)
|
||||||
|
stream := newStream(sid, s)
|
||||||
|
|
||||||
|
//logrus.Debugln("stream open", sid, s.streams)
|
||||||
|
|
||||||
|
if _, err := s.writeFrame(newFrame(cmdSYN, sid)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.buffering = false // proxy Write it's SocksAddr to flush the buffer
|
||||||
|
|
||||||
|
s.streamLock.Lock()
|
||||||
|
defer s.streamLock.Unlock()
|
||||||
|
select {
|
||||||
|
case <-s.die:
|
||||||
|
return nil, io.ErrClosedPipe
|
||||||
|
default:
|
||||||
|
s.streams[sid] = stream
|
||||||
|
return stream, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) recvLoop() error {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
log.Error("[BUG]", r, string(debug.Stack()))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
var receivedSettingsFromClient bool
|
||||||
|
var hdr rawHeader
|
||||||
|
|
||||||
|
for {
|
||||||
|
if s.IsClosed() {
|
||||||
|
return io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
// read header first
|
||||||
|
if _, err := io.ReadFull(s.conn, hdr[:]); err == nil {
|
||||||
|
sid := hdr.StreamID()
|
||||||
|
switch hdr.Cmd() {
|
||||||
|
case cmdPSH:
|
||||||
|
if hdr.Length() > 0 {
|
||||||
|
buffer := buf.Get(int(hdr.Length()))
|
||||||
|
if _, err := io.ReadFull(s.conn, buffer); err == nil {
|
||||||
|
s.streamLock.RLock()
|
||||||
|
stream, ok := s.streams[sid]
|
||||||
|
s.streamLock.RUnlock()
|
||||||
|
if ok {
|
||||||
|
stream.pipeW.Write(buffer)
|
||||||
|
}
|
||||||
|
buf.Put(buffer)
|
||||||
|
} else {
|
||||||
|
buf.Put(buffer)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case cmdSYN: // should be server only
|
||||||
|
if !s.isClient && !receivedSettingsFromClient {
|
||||||
|
f := newFrame(cmdAlert, 0)
|
||||||
|
f.data = []byte("client did not send its settings")
|
||||||
|
s.writeFrame(f)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s.streamLock.Lock()
|
||||||
|
if _, ok := s.streams[sid]; !ok {
|
||||||
|
stream := newStream(sid, s)
|
||||||
|
s.streams[sid] = stream
|
||||||
|
if s.onNewStream != nil {
|
||||||
|
go s.onNewStream(stream)
|
||||||
|
} else {
|
||||||
|
go s.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.streamLock.Unlock()
|
||||||
|
case cmdFIN:
|
||||||
|
s.streamLock.RLock()
|
||||||
|
stream, ok := s.streams[sid]
|
||||||
|
s.streamLock.RUnlock()
|
||||||
|
if ok {
|
||||||
|
stream.Close()
|
||||||
|
}
|
||||||
|
//logrus.Debugln("stream fin", sid, s.streams)
|
||||||
|
case cmdWaste:
|
||||||
|
if hdr.Length() > 0 {
|
||||||
|
buffer := buf.Get(int(hdr.Length()))
|
||||||
|
if _, err := io.ReadFull(s.conn, buffer); err != nil {
|
||||||
|
buf.Put(buffer)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
buf.Put(buffer)
|
||||||
|
}
|
||||||
|
case cmdSettings:
|
||||||
|
if hdr.Length() > 0 {
|
||||||
|
buffer := buf.Get(int(hdr.Length()))
|
||||||
|
if _, err := io.ReadFull(s.conn, buffer); err != nil {
|
||||||
|
buf.Put(buffer)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !s.isClient {
|
||||||
|
receivedSettingsFromClient = true
|
||||||
|
m := util.StringMapFromBytes(buffer)
|
||||||
|
paddingF := s.padding.Load()
|
||||||
|
if m["padding-md5"] != paddingF.Md5 {
|
||||||
|
// logrus.Debugln("remote md5 is", m["padding-md5"])
|
||||||
|
f := newFrame(cmdUpdatePaddingScheme, 0)
|
||||||
|
f.data = paddingF.RawScheme
|
||||||
|
_, err = s.writeFrame(f)
|
||||||
|
if err != nil {
|
||||||
|
buf.Put(buffer)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
buf.Put(buffer)
|
||||||
|
}
|
||||||
|
case cmdAlert:
|
||||||
|
if hdr.Length() > 0 {
|
||||||
|
buffer := buf.Get(int(hdr.Length()))
|
||||||
|
if _, err := io.ReadFull(s.conn, buffer); err != nil {
|
||||||
|
buf.Put(buffer)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if s.isClient {
|
||||||
|
log.Error("[Alert from server]", string(buffer))
|
||||||
|
}
|
||||||
|
buf.Put(buffer)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case cmdUpdatePaddingScheme:
|
||||||
|
if hdr.Length() > 0 {
|
||||||
|
// `rawScheme` Do not use buffer to prevent subsequent misuse
|
||||||
|
rawScheme := make([]byte, int(hdr.Length()))
|
||||||
|
if _, err := io.ReadFull(s.conn, rawScheme); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if s.isClient {
|
||||||
|
if padding.UpdatePaddingScheme(rawScheme, s.padding) {
|
||||||
|
log.Info(fmt.Sprintf("[Update padding succeed] %x\n", md5.Sum(rawScheme)))
|
||||||
|
} else {
|
||||||
|
log.Warn(fmt.Sprintf("[Update padding failed] %x\n", md5.Sum(rawScheme)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// I don't know what command it is (can't have data)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// notify the session that a stream has closed
|
||||||
|
func (s *Session) streamClosed(sid uint32) error {
|
||||||
|
_, err := s.writeFrame(newFrame(cmdFIN, sid))
|
||||||
|
s.streamLock.Lock()
|
||||||
|
delete(s.streams, sid)
|
||||||
|
s.streamLock.Unlock()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) writeFrame(frame frame) (int, error) {
|
||||||
|
dataLen := len(frame.data)
|
||||||
|
|
||||||
|
buffer := buf.NewSize(dataLen + headerOverHeadSize)
|
||||||
|
buffer.WriteByte(frame.cmd)
|
||||||
|
binary.BigEndian.PutUint32(buffer.Extend(4), frame.sid)
|
||||||
|
binary.BigEndian.PutUint16(buffer.Extend(2), uint16(dataLen))
|
||||||
|
buffer.Write(frame.data)
|
||||||
|
_, err := s.writeConn(buffer.Bytes())
|
||||||
|
buffer.Release()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return dataLen, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) writeConn(b []byte) (n int, err error) {
|
||||||
|
s.connLock.Lock()
|
||||||
|
defer s.connLock.Unlock()
|
||||||
|
|
||||||
|
if s.buffering {
|
||||||
|
s.buffer = append(s.buffer, b...)
|
||||||
|
return len(b), nil
|
||||||
|
} else if len(s.buffer) > 0 {
|
||||||
|
b = append(s.buffer, b...)
|
||||||
|
s.buffer = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// calulate & send padding
|
||||||
|
if s.sendPadding {
|
||||||
|
pkt := s.pktCounter.Add(1)
|
||||||
|
paddingF := s.padding.Load()
|
||||||
|
if pkt < paddingF.Stop {
|
||||||
|
pktSizes := paddingF.GenerateRecordPayloadSizes(pkt)
|
||||||
|
for _, l := range pktSizes {
|
||||||
|
remainPayloadLen := len(b)
|
||||||
|
if l == padding.CheckMark {
|
||||||
|
if remainPayloadLen == 0 {
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if remainPayloadLen > l { // this packet is all payload
|
||||||
|
_, err = s.conn.Write(b[:l])
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
n += l
|
||||||
|
b = b[l:]
|
||||||
|
} else if remainPayloadLen > 0 { // this packet contains padding and the last part of payload
|
||||||
|
paddingLen := l - remainPayloadLen - headerOverHeadSize
|
||||||
|
if paddingLen > 0 {
|
||||||
|
padding := make([]byte, headerOverHeadSize+paddingLen)
|
||||||
|
padding[0] = cmdWaste
|
||||||
|
binary.BigEndian.PutUint32(padding[1:5], 0)
|
||||||
|
binary.BigEndian.PutUint16(padding[5:7], uint16(paddingLen))
|
||||||
|
b = append(b, padding...)
|
||||||
|
}
|
||||||
|
_, err = s.conn.Write(b)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
n += remainPayloadLen
|
||||||
|
b = nil
|
||||||
|
} else { // this packet is all padding
|
||||||
|
padding := make([]byte, headerOverHeadSize+l)
|
||||||
|
padding[0] = cmdWaste
|
||||||
|
binary.BigEndian.PutUint32(padding[1:5], 0)
|
||||||
|
binary.BigEndian.PutUint16(padding[5:7], uint16(l))
|
||||||
|
_, err = s.conn.Write(padding)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
b = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// maybe still remain payload to write
|
||||||
|
if len(b) == 0 {
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
n2, err := s.conn.Write(b)
|
||||||
|
return n + n2, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
s.sendPadding = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.conn.Write(b)
|
||||||
|
}
|
110
transport/anytls/session/stream.go
Normal file
110
transport/anytls/session/stream.go
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/transport/anytls/pipe"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Stream implements net.Conn
|
||||||
|
type Stream struct {
|
||||||
|
id uint32
|
||||||
|
|
||||||
|
sess *Session
|
||||||
|
|
||||||
|
pipeR *pipe.PipeReader
|
||||||
|
pipeW *pipe.PipeWriter
|
||||||
|
writeDeadline pipe.PipeDeadline
|
||||||
|
|
||||||
|
dieOnce sync.Once
|
||||||
|
dieHook func()
|
||||||
|
}
|
||||||
|
|
||||||
|
// newStream initiates a Stream struct
|
||||||
|
func newStream(id uint32, sess *Session) *Stream {
|
||||||
|
s := new(Stream)
|
||||||
|
s.id = id
|
||||||
|
s.sess = sess
|
||||||
|
s.pipeR, s.pipeW = pipe.Pipe()
|
||||||
|
s.writeDeadline = pipe.MakePipeDeadline()
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read implements net.Conn
|
||||||
|
func (s *Stream) Read(b []byte) (n int, err error) {
|
||||||
|
return s.pipeR.Read(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write implements net.Conn
|
||||||
|
func (s *Stream) Write(b []byte) (n int, err error) {
|
||||||
|
select {
|
||||||
|
case <-s.writeDeadline.Wait():
|
||||||
|
return 0, os.ErrDeadlineExceeded
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
f := newFrame(cmdPSH, s.id)
|
||||||
|
f.data = b
|
||||||
|
n, err = s.sess.writeFrame(f)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close implements net.Conn
|
||||||
|
func (s *Stream) Close() error {
|
||||||
|
if s.sessionClose() {
|
||||||
|
// notify remote
|
||||||
|
return s.sess.streamClosed(s.id)
|
||||||
|
} else {
|
||||||
|
return io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sessionClose close stream from session side, do not notify remote
|
||||||
|
func (s *Stream) sessionClose() (once bool) {
|
||||||
|
s.dieOnce.Do(func() {
|
||||||
|
s.pipeR.Close()
|
||||||
|
once = true
|
||||||
|
if s.dieHook != nil {
|
||||||
|
s.dieHook()
|
||||||
|
s.dieHook = nil
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stream) SetReadDeadline(t time.Time) error {
|
||||||
|
return s.pipeR.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stream) SetWriteDeadline(t time.Time) error {
|
||||||
|
s.writeDeadline.Set(t)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stream) SetDeadline(t time.Time) error {
|
||||||
|
s.SetWriteDeadline(t)
|
||||||
|
return s.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalAddr satisfies net.Conn interface
|
||||||
|
func (s *Stream) LocalAddr() net.Addr {
|
||||||
|
if ts, ok := s.sess.conn.(interface {
|
||||||
|
LocalAddr() net.Addr
|
||||||
|
}); ok {
|
||||||
|
return ts.LocalAddr()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoteAddr satisfies net.Conn interface
|
||||||
|
func (s *Stream) RemoteAddr() net.Addr {
|
||||||
|
if ts, ok := s.sess.conn.(interface {
|
||||||
|
RemoteAddr() net.Addr
|
||||||
|
}); ok {
|
||||||
|
return ts.RemoteAddr()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
46
transport/anytls/skiplist/contianer.go
Normal file
46
transport/anytls/skiplist/contianer.go
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
package skiplist
|
||||||
|
|
||||||
|
// Container is a holder object that stores a collection of other objects.
|
||||||
|
type Container interface {
|
||||||
|
IsEmpty() bool // IsEmpty checks if the container has no elements.
|
||||||
|
Len() int // Len returns the number of elements in the container.
|
||||||
|
Clear() // Clear erases all elements from the container. After this call, Len() returns zero.
|
||||||
|
}
|
||||||
|
|
||||||
|
// Map is a associative container that contains key-value pairs with unique keys.
|
||||||
|
type Map[K any, V any] interface {
|
||||||
|
Container
|
||||||
|
Has(K) bool // Checks whether the container contains element with specific key.
|
||||||
|
Find(K) *V // Finds element with specific key.
|
||||||
|
Insert(K, V) // Inserts a key-value pair in to the container or replace existing value.
|
||||||
|
Remove(K) bool // Remove element with specific key.
|
||||||
|
ForEach(func(K, V)) // Iterate the container.
|
||||||
|
ForEachIf(func(K, V) bool) // Iterate the container, stops when the callback returns false.
|
||||||
|
ForEachMutable(func(K, *V)) // Iterate the container, *V is mutable.
|
||||||
|
ForEachMutableIf(func(K, *V) bool) // Iterate the container, *V is mutable, stops when the callback returns false.
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set is a containers that store unique elements.
|
||||||
|
type Set[K any] interface {
|
||||||
|
Container
|
||||||
|
Has(K) bool // Checks whether the container contains element with specific key.
|
||||||
|
Insert(K) // Inserts a key-value pair in to the container or replace existing value.
|
||||||
|
InsertN(...K) // Inserts multiple key-value pairs in to the container or replace existing value.
|
||||||
|
Remove(K) bool // Remove element with specific key.
|
||||||
|
RemoveN(...K) // Remove multiple elements with specific keys.
|
||||||
|
ForEach(func(K)) // Iterate the container.
|
||||||
|
ForEachIf(func(K) bool) // Iterate the container, stops when the callback returns false.
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iterator is the interface for container's iterator.
|
||||||
|
type Iterator[T any] interface {
|
||||||
|
IsNotEnd() bool // Whether it is point to the end of the range.
|
||||||
|
MoveToNext() // Let it point to the next element.
|
||||||
|
Value() T // Return the value of current element.
|
||||||
|
}
|
||||||
|
|
||||||
|
// MapIterator is the interface for map's iterator.
|
||||||
|
type MapIterator[K any, V any] interface {
|
||||||
|
Iterator[V]
|
||||||
|
Key() K // The key of the element
|
||||||
|
}
|
457
transport/anytls/skiplist/skiplist.go
Normal file
457
transport/anytls/skiplist/skiplist.go
Normal file
@ -0,0 +1,457 @@
|
|||||||
|
package skiplist
|
||||||
|
|
||||||
|
// This implementation is based on https://github.com/liyue201/gostl/tree/master/ds/skiplist
|
||||||
|
// (many thanks), added many optimizations, such as:
|
||||||
|
//
|
||||||
|
// - adaptive level
|
||||||
|
// - lesser search for prevs when key already exists.
|
||||||
|
// - reduce memory allocations
|
||||||
|
// - richer interface.
|
||||||
|
//
|
||||||
|
// etc.
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/bits"
|
||||||
|
"math/rand"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
skipListMaxLevel = 40
|
||||||
|
)
|
||||||
|
|
||||||
|
// SkipList is a probabilistic data structure that seem likely to supplant balanced trees as the
|
||||||
|
// implementation method of choice for many applications. Skip list algorithms have the same
|
||||||
|
// asymptotic expected time bounds as balanced trees and are simpler, faster and use less space.
|
||||||
|
//
|
||||||
|
// See https://en.wikipedia.org/wiki/Skip_list for more details.
|
||||||
|
type SkipList[K any, V any] struct {
|
||||||
|
level int // Current level, may increase dynamically during insertion
|
||||||
|
len int // Total elements numner in the skiplist.
|
||||||
|
head skipListNode[K, V] // head.next[level] is the head of each level.
|
||||||
|
// This cache is used to save the previous nodes when modifying the skip list to avoid
|
||||||
|
// allocating memory each time it is called.
|
||||||
|
prevsCache []*skipListNode[K, V]
|
||||||
|
rander *rand.Rand
|
||||||
|
impl skipListImpl[K, V]
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSkipList creates a new SkipList for Ordered key type.
|
||||||
|
func NewSkipList[K Ordered, V any]() *SkipList[K, V] {
|
||||||
|
sl := skipListOrdered[K, V]{}
|
||||||
|
sl.init()
|
||||||
|
sl.impl = (skipListImpl[K, V])(&sl)
|
||||||
|
return &sl.SkipList
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSkipListFromMap creates a new SkipList from a map.
|
||||||
|
func NewSkipListFromMap[K Ordered, V any](m map[K]V) *SkipList[K, V] {
|
||||||
|
sl := NewSkipList[K, V]()
|
||||||
|
for k, v := range m {
|
||||||
|
sl.Insert(k, v)
|
||||||
|
}
|
||||||
|
return sl
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSkipListFunc creates a new SkipList with specified compare function keyCmp.
|
||||||
|
func NewSkipListFunc[K any, V any](keyCmp CompareFn[K]) *SkipList[K, V] {
|
||||||
|
sl := skipListFunc[K, V]{}
|
||||||
|
sl.init()
|
||||||
|
sl.keyCmp = keyCmp
|
||||||
|
sl.impl = skipListImpl[K, V](&sl)
|
||||||
|
return &sl.SkipList
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsEmpty implements the Container interface.
|
||||||
|
func (sl *SkipList[K, V]) IsEmpty() bool {
|
||||||
|
return sl.len == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len implements the Container interface.
|
||||||
|
func (sl *SkipList[K, V]) Len() int {
|
||||||
|
return sl.len
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear implements the Container interface.
|
||||||
|
func (sl *SkipList[K, V]) Clear() {
|
||||||
|
for i := range sl.head.next {
|
||||||
|
sl.head.next[i] = nil
|
||||||
|
}
|
||||||
|
sl.level = 1
|
||||||
|
sl.len = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iterate return an iterator to the skiplist.
|
||||||
|
func (sl *SkipList[K, V]) Iterate() MapIterator[K, V] {
|
||||||
|
return &skipListIterator[K, V]{sl.head.next[0], nil}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert inserts a key-value pair into the skiplist.
|
||||||
|
// If the key is already in the skip list, it's value will be updated.
|
||||||
|
func (sl *SkipList[K, V]) Insert(key K, value V) {
|
||||||
|
node, prevs := sl.impl.findInsertPoint(key)
|
||||||
|
|
||||||
|
if node != nil {
|
||||||
|
// Already exist, update the value
|
||||||
|
node.value = value
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
level := sl.randomLevel()
|
||||||
|
node = newSkipListNode(level, key, value)
|
||||||
|
|
||||||
|
minLevel := level
|
||||||
|
if sl.level < level {
|
||||||
|
minLevel = sl.level
|
||||||
|
}
|
||||||
|
for i := 0; i < minLevel; i++ {
|
||||||
|
node.next[i] = prevs[i].next[i]
|
||||||
|
prevs[i].next[i] = node
|
||||||
|
}
|
||||||
|
|
||||||
|
if level > sl.level {
|
||||||
|
for i := sl.level; i < level; i++ {
|
||||||
|
sl.head.next[i] = node
|
||||||
|
}
|
||||||
|
sl.level = level
|
||||||
|
}
|
||||||
|
|
||||||
|
sl.len++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find returns the value associated with the passed key if the key is in the skiplist, otherwise
|
||||||
|
// returns nil.
|
||||||
|
func (sl *SkipList[K, V]) Find(key K) *V {
|
||||||
|
node := sl.impl.findNode(key)
|
||||||
|
if node != nil {
|
||||||
|
return &node.value
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Has implement the Map interface.
|
||||||
|
func (sl *SkipList[K, V]) Has(key K) bool {
|
||||||
|
return sl.impl.findNode(key) != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LowerBound returns an iterator to the first element in the skiplist that
|
||||||
|
// does not satisfy element < value (i.e. greater or equal to),
|
||||||
|
// or a end itetator if no such element is found.
|
||||||
|
func (sl *SkipList[K, V]) LowerBound(key K) MapIterator[K, V] {
|
||||||
|
return &skipListIterator[K, V]{sl.impl.lowerBound(key), nil}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpperBound returns an iterator to the first element in the skiplist that
|
||||||
|
// does not satisfy value < element (i.e. strictly greater),
|
||||||
|
// or a end itetator if no such element is found.
|
||||||
|
func (sl *SkipList[K, V]) UpperBound(key K) MapIterator[K, V] {
|
||||||
|
return &skipListIterator[K, V]{sl.impl.upperBound(key), nil}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FindRange returns an iterator in range [first, last) (last is not includeed).
|
||||||
|
func (sl *SkipList[K, V]) FindRange(first, last K) MapIterator[K, V] {
|
||||||
|
return &skipListIterator[K, V]{sl.impl.lowerBound(first), sl.impl.upperBound(last)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove removes the key-value pair associated with the passed key and returns true if the key is
|
||||||
|
// in the skiplist, otherwise returns false.
|
||||||
|
func (sl *SkipList[K, V]) Remove(key K) bool {
|
||||||
|
node, prevs := sl.impl.findRemovePoint(key)
|
||||||
|
if node == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i, v := range node.next {
|
||||||
|
prevs[i].next[i] = v
|
||||||
|
}
|
||||||
|
for sl.level > 1 && sl.head.next[sl.level-1] == nil {
|
||||||
|
sl.level--
|
||||||
|
}
|
||||||
|
sl.len--
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForEach implements the Map interface.
|
||||||
|
func (sl *SkipList[K, V]) ForEach(op func(K, V)) {
|
||||||
|
for e := sl.head.next[0]; e != nil; e = e.next[0] {
|
||||||
|
op(e.key, e.value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForEachMutable implements the Map interface.
|
||||||
|
func (sl *SkipList[K, V]) ForEachMutable(op func(K, *V)) {
|
||||||
|
for e := sl.head.next[0]; e != nil; e = e.next[0] {
|
||||||
|
op(e.key, &e.value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForEachIf implements the Map interface.
|
||||||
|
func (sl *SkipList[K, V]) ForEachIf(op func(K, V) bool) {
|
||||||
|
for e := sl.head.next[0]; e != nil; e = e.next[0] {
|
||||||
|
if !op(e.key, e.value) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForEachMutableIf implements the Map interface.
|
||||||
|
func (sl *SkipList[K, V]) ForEachMutableIf(op func(K, *V) bool) {
|
||||||
|
for e := sl.head.next[0]; e != nil; e = e.next[0] {
|
||||||
|
if !op(e.key, &e.value) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// SkipList implementation part.
|
||||||
|
|
||||||
|
type skipListNode[K any, V any] struct {
|
||||||
|
key K
|
||||||
|
value V
|
||||||
|
next []*skipListNode[K, V]
|
||||||
|
}
|
||||||
|
|
||||||
|
//go:generate bash ./skiplist_newnode_generate.sh skipListMaxLevel skiplist_newnode.go
|
||||||
|
// func newSkipListNode[K Ordered, V any](level int, key K, value V) *skipListNode[K, V]
|
||||||
|
|
||||||
|
type skipListIterator[K any, V any] struct {
|
||||||
|
node, end *skipListNode[K, V]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (it *skipListIterator[K, V]) IsNotEnd() bool {
|
||||||
|
return it.node != it.end
|
||||||
|
}
|
||||||
|
|
||||||
|
func (it *skipListIterator[K, V]) MoveToNext() {
|
||||||
|
it.node = it.node.next[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (it *skipListIterator[K, V]) Key() K {
|
||||||
|
return it.node.key
|
||||||
|
}
|
||||||
|
|
||||||
|
func (it *skipListIterator[K, V]) Value() V {
|
||||||
|
return it.node.value
|
||||||
|
}
|
||||||
|
|
||||||
|
// skipListImpl is an interface to provide different implementation for Ordered key or CompareFn.
|
||||||
|
//
|
||||||
|
// We can use CompareFn to cumpare Ordered keys, but a separated implementation is much faster.
|
||||||
|
// We don't make the whole skip list an interface, in order to share the type independented method.
|
||||||
|
// And because these methods are called directly without going through the interface, they are also
|
||||||
|
// much faster.
|
||||||
|
type skipListImpl[K any, V any] interface {
|
||||||
|
findNode(key K) *skipListNode[K, V]
|
||||||
|
lowerBound(key K) *skipListNode[K, V]
|
||||||
|
upperBound(key K) *skipListNode[K, V]
|
||||||
|
findInsertPoint(key K) (*skipListNode[K, V], []*skipListNode[K, V])
|
||||||
|
findRemovePoint(key K) (*skipListNode[K, V], []*skipListNode[K, V])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sl *SkipList[K, V]) init() {
|
||||||
|
sl.level = 1
|
||||||
|
// #nosec G404 -- This is not a security condition
|
||||||
|
sl.rander = rand.New(rand.NewSource(time.Now().Unix()))
|
||||||
|
sl.prevsCache = make([]*skipListNode[K, V], skipListMaxLevel)
|
||||||
|
sl.head.next = make([]*skipListNode[K, V], skipListMaxLevel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sl *SkipList[K, V]) randomLevel() int {
|
||||||
|
total := uint64(1)<<uint64(skipListMaxLevel) - 1 // 2^n-1
|
||||||
|
k := sl.rander.Uint64() % total
|
||||||
|
level := skipListMaxLevel - bits.Len64(k) + 1
|
||||||
|
// Since levels are randomly generated, most should be less than log2(s.len).
|
||||||
|
// Then make a limit according to sl.len to avoid unexpectedly large value.
|
||||||
|
for level > 3 && 1<<(level-3) > sl.len {
|
||||||
|
level--
|
||||||
|
}
|
||||||
|
|
||||||
|
return level
|
||||||
|
}
|
||||||
|
|
||||||
|
/// skipListOrdered part
|
||||||
|
|
||||||
|
// skipListOrdered is the skip list implementation for Ordered types.
|
||||||
|
type skipListOrdered[K Ordered, V any] struct {
|
||||||
|
SkipList[K, V]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sl *skipListOrdered[K, V]) findNode(key K) *skipListNode[K, V] {
|
||||||
|
return sl.doFindNode(key, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sl *skipListOrdered[K, V]) doFindNode(key K, eq bool) *skipListNode[K, V] {
|
||||||
|
// This function execute the job of findNode if eq is true, otherwise lowBound.
|
||||||
|
// Passing the control variable eq is ugly but it's faster than testing node
|
||||||
|
// again outside the function in findNode.
|
||||||
|
prev := &sl.head
|
||||||
|
for i := sl.level - 1; i >= 0; i-- {
|
||||||
|
for cur := prev.next[i]; cur != nil; cur = cur.next[i] {
|
||||||
|
if cur.key == key {
|
||||||
|
return cur
|
||||||
|
}
|
||||||
|
if cur.key > key {
|
||||||
|
// All other node in this level must be greater than the key,
|
||||||
|
// search the next level.
|
||||||
|
break
|
||||||
|
}
|
||||||
|
prev = cur
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if eq {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return prev.next[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sl *skipListOrdered[K, V]) lowerBound(key K) *skipListNode[K, V] {
|
||||||
|
return sl.doFindNode(key, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sl *skipListOrdered[K, V]) upperBound(key K) *skipListNode[K, V] {
|
||||||
|
node := sl.lowerBound(key)
|
||||||
|
if node != nil && node.key == key {
|
||||||
|
return node.next[0]
|
||||||
|
}
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
// findInsertPoint returns (*node, nil) to the existed node if the key exists,
|
||||||
|
// or (nil, []*node) to the previous nodes if the key doesn't exist
|
||||||
|
func (sl *skipListOrdered[K, V]) findInsertPoint(key K) (*skipListNode[K, V], []*skipListNode[K, V]) {
|
||||||
|
prevs := sl.prevsCache[0:sl.level]
|
||||||
|
prev := &sl.head
|
||||||
|
for i := sl.level - 1; i >= 0; i-- {
|
||||||
|
for next := prev.next[i]; next != nil; next = next.next[i] {
|
||||||
|
if next.key == key {
|
||||||
|
// The key is already existed, prevs are useless because no new node insertion.
|
||||||
|
// stop searching.
|
||||||
|
return next, nil
|
||||||
|
}
|
||||||
|
if next.key > key {
|
||||||
|
// All other node in this level must be greater than the key,
|
||||||
|
// search the next level.
|
||||||
|
break
|
||||||
|
}
|
||||||
|
prev = next
|
||||||
|
}
|
||||||
|
prevs[i] = prev
|
||||||
|
}
|
||||||
|
return nil, prevs
|
||||||
|
}
|
||||||
|
|
||||||
|
// findRemovePoint finds the node which match the key and it's previous nodes.
|
||||||
|
func (sl *skipListOrdered[K, V]) findRemovePoint(key K) (*skipListNode[K, V], []*skipListNode[K, V]) {
|
||||||
|
prevs := sl.findPrevNodes(key)
|
||||||
|
node := prevs[0].next[0]
|
||||||
|
if node == nil || node.key != key {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return node, prevs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sl *skipListOrdered[K, V]) findPrevNodes(key K) []*skipListNode[K, V] {
|
||||||
|
prevs := sl.prevsCache[0:sl.level]
|
||||||
|
prev := &sl.head
|
||||||
|
for i := sl.level - 1; i >= 0; i-- {
|
||||||
|
for next := prev.next[i]; next != nil; next = next.next[i] {
|
||||||
|
if next.key >= key {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
prev = next
|
||||||
|
}
|
||||||
|
prevs[i] = prev
|
||||||
|
}
|
||||||
|
return prevs
|
||||||
|
}
|
||||||
|
|
||||||
|
/// skipListFunc part
|
||||||
|
|
||||||
|
// skipListFunc is the skip list implementation which compare keys with func.
|
||||||
|
type skipListFunc[K any, V any] struct {
|
||||||
|
SkipList[K, V]
|
||||||
|
keyCmp CompareFn[K]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sl *skipListFunc[K, V]) findNode(key K) *skipListNode[K, V] {
|
||||||
|
node := sl.lowerBound(key)
|
||||||
|
if node != nil && sl.keyCmp(node.key, key) == 0 {
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sl *skipListFunc[K, V]) lowerBound(key K) *skipListNode[K, V] {
|
||||||
|
var prev = &sl.head
|
||||||
|
for i := sl.level - 1; i >= 0; i-- {
|
||||||
|
cur := prev.next[i]
|
||||||
|
for ; cur != nil; cur = cur.next[i] {
|
||||||
|
cmpRet := sl.keyCmp(cur.key, key)
|
||||||
|
if cmpRet == 0 {
|
||||||
|
return cur
|
||||||
|
}
|
||||||
|
if cmpRet > 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
prev = cur
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return prev.next[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sl *skipListFunc[K, V]) upperBound(key K) *skipListNode[K, V] {
|
||||||
|
node := sl.lowerBound(key)
|
||||||
|
if node != nil && sl.keyCmp(node.key, key) == 0 {
|
||||||
|
return node.next[0]
|
||||||
|
}
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
// findInsertPoint returns (*node, nil) to the existed node if the key exists,
|
||||||
|
// or (nil, []*node) to the previous nodes if the key doesn't exist
|
||||||
|
func (sl *skipListFunc[K, V]) findInsertPoint(key K) (*skipListNode[K, V], []*skipListNode[K, V]) {
|
||||||
|
prevs := sl.prevsCache[0:sl.level]
|
||||||
|
prev := &sl.head
|
||||||
|
for i := sl.level - 1; i >= 0; i-- {
|
||||||
|
for cur := prev.next[i]; cur != nil; cur = cur.next[i] {
|
||||||
|
r := sl.keyCmp(cur.key, key)
|
||||||
|
if r == 0 {
|
||||||
|
// The key is already existed, prevs are useless because no new node insertion.
|
||||||
|
// stop searching.
|
||||||
|
return cur, nil
|
||||||
|
}
|
||||||
|
if r > 0 {
|
||||||
|
// All other node in this level must be greater than the key,
|
||||||
|
// search the next level.
|
||||||
|
break
|
||||||
|
}
|
||||||
|
prev = cur
|
||||||
|
}
|
||||||
|
prevs[i] = prev
|
||||||
|
}
|
||||||
|
return nil, prevs
|
||||||
|
}
|
||||||
|
|
||||||
|
// findRemovePoint finds the node which match the key and it's previous nodes.
|
||||||
|
func (sl *skipListFunc[K, V]) findRemovePoint(key K) (*skipListNode[K, V], []*skipListNode[K, V]) {
|
||||||
|
prevs := sl.findPrevNodes(key)
|
||||||
|
node := prevs[0].next[0]
|
||||||
|
if node == nil || sl.keyCmp(node.key, key) != 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return node, prevs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sl *skipListFunc[K, V]) findPrevNodes(key K) []*skipListNode[K, V] {
|
||||||
|
prevs := sl.prevsCache[0:sl.level]
|
||||||
|
prev := &sl.head
|
||||||
|
for i := sl.level - 1; i >= 0; i-- {
|
||||||
|
for next := prev.next[i]; next != nil; next = next.next[i] {
|
||||||
|
if sl.keyCmp(next.key, key) >= 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
prev = next
|
||||||
|
}
|
||||||
|
prevs[i] = prev
|
||||||
|
}
|
||||||
|
return prevs
|
||||||
|
}
|
297
transport/anytls/skiplist/skiplist_newnode.go
Normal file
297
transport/anytls/skiplist/skiplist_newnode.go
Normal file
@ -0,0 +1,297 @@
|
|||||||
|
// AUTO GENERATED CODE, DON'T EDIT!!!
|
||||||
|
// EDIT skiplist_newnode_generate.sh accordingly.
|
||||||
|
|
||||||
|
package skiplist
|
||||||
|
|
||||||
|
// newSkipListNode creates a new node initialized with specified key, value and next slice.
|
||||||
|
func newSkipListNode[K any, V any](level int, key K, value V) *skipListNode[K, V] {
|
||||||
|
// For nodes with each levels, point their next slice to the nexts array allocated together,
|
||||||
|
// which can reduce 1 memory allocation and improve performance.
|
||||||
|
//
|
||||||
|
// The generics of the golang doesn't support non-type parameters like in C++,
|
||||||
|
// so we have to generate it manually.
|
||||||
|
switch level {
|
||||||
|
case 1:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [1]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 2:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [2]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 3:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [3]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 4:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [4]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 5:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [5]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 6:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [6]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 7:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [7]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 8:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [8]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 9:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [9]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 10:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [10]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 11:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [11]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 12:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [12]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 13:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [13]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 14:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [14]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 15:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [15]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 16:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [16]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 17:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [17]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 18:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [18]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 19:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [19]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 20:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [20]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 21:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [21]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 22:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [22]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 23:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [23]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 24:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [24]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 25:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [25]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 26:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [26]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 27:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [27]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 28:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [28]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 29:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [29]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 30:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [30]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 31:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [31]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 32:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [32]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 33:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [33]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 34:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [34]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 35:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [35]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 36:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [36]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 37:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [37]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 38:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [38]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 39:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [39]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
case 40:
|
||||||
|
n := struct {
|
||||||
|
head skipListNode[K, V]
|
||||||
|
nexts [40]*skipListNode[K, V]
|
||||||
|
}{head: skipListNode[K, V]{key, value, nil}}
|
||||||
|
n.head.next = n.nexts[:]
|
||||||
|
return &n.head
|
||||||
|
}
|
||||||
|
|
||||||
|
panic("should not reach here")
|
||||||
|
}
|
75
transport/anytls/skiplist/types.go
Normal file
75
transport/anytls/skiplist/types.go
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
package skiplist
|
||||||
|
|
||||||
|
// Signed is a constraint that permits any signed integer type.
|
||||||
|
// If future releases of Go add new predeclared signed integer types,
|
||||||
|
// this constraint will be modified to include them.
|
||||||
|
type Signed interface {
|
||||||
|
~int | ~int8 | ~int16 | ~int32 | ~int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsigned is a constraint that permits any unsigned integer type.
|
||||||
|
// If future releases of Go add new predeclared unsigned integer types,
|
||||||
|
// this constraint will be modified to include them.
|
||||||
|
type Unsigned interface {
|
||||||
|
~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Integer is a constraint that permits any integer type.
|
||||||
|
// If future releases of Go add new predeclared integer types,
|
||||||
|
// this constraint will be modified to include them.
|
||||||
|
type Integer interface {
|
||||||
|
Signed | Unsigned
|
||||||
|
}
|
||||||
|
|
||||||
|
// Float is a constraint that permits any floating-point type.
|
||||||
|
// If future releases of Go add new predeclared floating-point types,
|
||||||
|
// this constraint will be modified to include them.
|
||||||
|
type Float interface {
|
||||||
|
~float32 | ~float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ordered is a constraint that permits any ordered type: any type
|
||||||
|
// that supports the operators < <= >= >.
|
||||||
|
// If future releases of Go add new ordered types,
|
||||||
|
// this constraint will be modified to include them.
|
||||||
|
type Ordered interface {
|
||||||
|
Integer | Float | ~string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Numeric is a constraint that permits any numeric type.
|
||||||
|
type Numeric interface {
|
||||||
|
Integer | Float
|
||||||
|
}
|
||||||
|
|
||||||
|
// LessFn is a function that returns whether 'a' is less than 'b'.
|
||||||
|
type LessFn[T any] func(a, b T) bool
|
||||||
|
|
||||||
|
// CompareFn is a 3 way compare function that
|
||||||
|
// returns 1 if a > b,
|
||||||
|
// returns 0 if a == b,
|
||||||
|
// returns -1 if a < b.
|
||||||
|
type CompareFn[T any] func(a, b T) int
|
||||||
|
|
||||||
|
// HashFn is a function that returns the hash of 't'.
|
||||||
|
type HashFn[T any] func(t T) uint64
|
||||||
|
|
||||||
|
// Equals wraps the '==' operator for comparable types.
|
||||||
|
func Equals[T comparable](a, b T) bool {
|
||||||
|
return a == b
|
||||||
|
}
|
||||||
|
|
||||||
|
// Less wraps the '<' operator for ordered types.
|
||||||
|
func Less[T Ordered](a, b T) bool {
|
||||||
|
return a < b
|
||||||
|
}
|
||||||
|
|
||||||
|
// OrderedCompare provide default CompareFn for ordered types.
|
||||||
|
func OrderedCompare[T Ordered](a, b T) int {
|
||||||
|
if a < b {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
if a > b {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
28
transport/anytls/util/routine.go
Normal file
28
transport/anytls/util/routine.go
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"runtime/debug"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sagernet/sing-box/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
func StartRoutine(ctx context.Context, d time.Duration, f func()) {
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
log.Error("[BUG]", r, string(debug.Stack()))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
for {
|
||||||
|
time.Sleep(d)
|
||||||
|
f()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
27
transport/anytls/util/string_map.go
Normal file
27
transport/anytls/util/string_map.go
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type StringMap map[string]string
|
||||||
|
|
||||||
|
func (s StringMap) ToBytes() []byte {
|
||||||
|
var lines []string
|
||||||
|
for k, v := range s {
|
||||||
|
lines = append(lines, k+"="+v)
|
||||||
|
}
|
||||||
|
return []byte(strings.Join(lines, "\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func StringMapFromBytes(b []byte) StringMap {
|
||||||
|
var m = make(StringMap)
|
||||||
|
var lines = strings.Split(string(b), "\n")
|
||||||
|
for _, line := range lines {
|
||||||
|
v := strings.SplitN(line, "=", 2)
|
||||||
|
if len(v) == 2 {
|
||||||
|
m[v[0]] = v[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user