From 477ba6eed4d2c6741fc44c4bafad120792efe7dd Mon Sep 17 00:00:00 2001 From: anytls Date: Wed, 19 Feb 2025 15:06:55 +0900 Subject: [PATCH] Implement AnyTLS client & server --- constant/proxy.go | 1 + include/registry.go | 3 + option/anytls.go | 24 + protocol/anytls/inbound.go | 130 +++++ protocol/anytls/outbound.go | 100 ++++ transport/anytls/client.go | 101 ++++ transport/anytls/padding/padding.go | 92 ++++ transport/anytls/pipe/deadline.go | 74 +++ transport/anytls/pipe/io_pipe.go | 232 +++++++++ transport/anytls/service.go | 125 +++++ transport/anytls/session/client.go | 159 ++++++ transport/anytls/session/frame.go | 44 ++ transport/anytls/session/session.go | 383 +++++++++++++++ transport/anytls/session/stream.go | 110 +++++ transport/anytls/skiplist/contianer.go | 46 ++ transport/anytls/skiplist/skiplist.go | 457 ++++++++++++++++++ transport/anytls/skiplist/skiplist_newnode.go | 297 ++++++++++++ transport/anytls/skiplist/types.go | 75 +++ transport/anytls/util/routine.go | 28 ++ transport/anytls/util/string_map.go | 27 ++ 20 files changed, 2508 insertions(+) create mode 100644 option/anytls.go create mode 100644 protocol/anytls/inbound.go create mode 100644 protocol/anytls/outbound.go create mode 100644 transport/anytls/client.go create mode 100644 transport/anytls/padding/padding.go create mode 100644 transport/anytls/pipe/deadline.go create mode 100644 transport/anytls/pipe/io_pipe.go create mode 100644 transport/anytls/service.go create mode 100644 transport/anytls/session/client.go create mode 100644 transport/anytls/session/frame.go create mode 100644 transport/anytls/session/session.go create mode 100644 transport/anytls/session/stream.go create mode 100644 transport/anytls/skiplist/contianer.go create mode 100644 transport/anytls/skiplist/skiplist.go create mode 100644 transport/anytls/skiplist/skiplist_newnode.go create mode 100644 transport/anytls/skiplist/types.go create mode 100644 transport/anytls/util/routine.go create mode 100644 transport/anytls/util/string_map.go diff --git a/constant/proxy.go b/constant/proxy.go index 45e79f84..787a1243 100644 --- a/constant/proxy.go +++ b/constant/proxy.go @@ -19,6 +19,7 @@ const ( TypeTor = "tor" TypeSSH = "ssh" TypeShadowTLS = "shadowtls" + TypeAnyTLS = "anytls" TypeShadowsocksR = "shadowsocksr" TypeVLESS = "vless" TypeTUIC = "tuic" diff --git a/include/registry.go b/include/registry.go index 866c506a..87aea576 100644 --- a/include/registry.go +++ b/include/registry.go @@ -15,6 +15,7 @@ import ( "github.com/sagernet/sing-box/dns/transport/local" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/protocol/anytls" "github.com/sagernet/sing-box/protocol/block" "github.com/sagernet/sing-box/protocol/direct" protocolDNS "github.com/sagernet/sing-box/protocol/dns" @@ -53,6 +54,7 @@ func InboundRegistry() *inbound.Registry { naive.RegisterInbound(registry) shadowtls.RegisterInbound(registry) vless.RegisterInbound(registry) + anytls.RegisterInbound(registry) registerQUICInbounds(registry) registerStubForRemovedInbounds(registry) @@ -80,6 +82,7 @@ func OutboundRegistry() *outbound.Registry { ssh.RegisterOutbound(registry) shadowtls.RegisterOutbound(registry) vless.RegisterOutbound(registry) + anytls.RegisterOutbound(registry) registerQUICOutbounds(registry) registerWireGuardOutbound(registry) diff --git a/option/anytls.go b/option/anytls.go new file mode 100644 index 00000000..0ac19cd1 --- /dev/null +++ b/option/anytls.go @@ -0,0 +1,24 @@ +package option + +import "github.com/sagernet/sing/common/json/badoption" + +type AnyTLSInboundOptions struct { + ListenOptions + InboundTLSOptionsContainer + Users []AnyTLSUser `json:"users,omitempty"` + PaddingScheme badoption.Listable[string] `json:"padding_scheme,omitempty"` +} + +type AnyTLSUser struct { + Name string `json:"name,omitempty"` + Password string `json:"password,omitempty"` +} + +type AnyTLSOutboundOptions struct { + DialerOptions + ServerOptions + OutboundTLSOptionsContainer + Password string `json:"password,omitempty"` + IdleSessionCheckInterval badoption.Duration `json:"idle_session_check_interval,omitempty"` + IdleSessionTimeout badoption.Duration `json:"idle_session_timeout,omitempty"` +} diff --git a/protocol/anytls/inbound.go b/protocol/anytls/inbound.go new file mode 100644 index 00000000..ca43f780 --- /dev/null +++ b/protocol/anytls/inbound.go @@ -0,0 +1,130 @@ +package anytls + +import ( + "context" + "net" + "strings" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/inbound" + "github.com/sagernet/sing-box/common/listener" + "github.com/sagernet/sing-box/common/tls" + "github.com/sagernet/sing-box/common/uot" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/transport/anytls" + "github.com/sagernet/sing-box/transport/anytls/padding" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/auth" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func RegisterInbound(registry *inbound.Registry) { + inbound.Register[option.AnyTLSInboundOptions](registry, C.TypeAnyTLS, NewInbound) +} + +type Inbound struct { + inbound.Adapter + tlsConfig tls.ServerConfig + router adapter.ConnectionRouterEx + logger logger.ContextLogger + listener *listener.Listener + service *anytls.Service +} + +func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.AnyTLSInboundOptions) (adapter.Inbound, error) { + inbound := &Inbound{ + Adapter: inbound.NewAdapter(C.TypeAnyTLS, tag), + router: uot.NewRouter(router, logger), + logger: logger, + } + + if options.TLS != nil && options.TLS.Enabled { + tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS)) + if err != nil { + return nil, err + } + inbound.tlsConfig = tlsConfig + } + + paddingScheme := padding.DefaultPaddingScheme + if len(options.PaddingScheme) > 0 { + paddingScheme = []byte(strings.Join(options.PaddingScheme, "\n")) + } + + service, err := anytls.NewService(anytls.ServiceConfig{ + Users: common.Map(options.Users, func(it option.AnyTLSUser) anytls.User { + return (anytls.User)(it) + }), + TLSConfig: inbound.tlsConfig, + PaddingScheme: paddingScheme, + Handler: (*inboundHandler)(inbound), + Logger: logger, + }) + if err != nil { + return nil, err + } + inbound.service = service + inbound.listener = listener.New(listener.Options{ + Context: ctx, + Logger: logger, + Network: []string{N.NetworkTCP}, + Listen: options.ListenOptions, + ConnectionHandler: inbound, + }) + return inbound, nil +} + +func (h *Inbound) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } + if h.tlsConfig != nil { + err := h.tlsConfig.Start() + if err != nil { + return err + } + } + return h.listener.Start() +} + +func (h *Inbound) Close() error { + return common.Close(h.listener, h.tlsConfig) +} + +func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { + err := h.service.NewConnection(adapter.WithContext(log.ContextWithNewID(ctx), &metadata), conn, metadata.Source, metadata.Destination, onClose) + N.CloseOnHandshakeFailure(conn, onClose, err) + if err != nil { + if E.IsClosedOrCanceled(err) { + h.logger.DebugContext(ctx, "connection closed: ", err) + } else { + h.logger.ErrorContext(ctx, E.Cause(err, "process connection from ", metadata.Source)) + } + } +} + +type inboundHandler Inbound + +func (h *inboundHandler) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) { + var metadata adapter.InboundContext + metadata.Inbound = h.Tag() + metadata.InboundType = h.Type() + //nolint:staticcheck + metadata.InboundDetour = h.listener.ListenOptions().Detour + //nolint:staticcheck + metadata.InboundOptions = h.listener.ListenOptions().InboundOptions + metadata.Source = source + metadata.Destination = destination + if userName, _ := auth.UserFromContext[string](ctx); userName != "" { + metadata.User = userName + h.logger.InfoContext(ctx, "[", userName, "] inbound connection to ", metadata.Destination) + } else { + h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) + } + h.router.RouteConnectionEx(ctx, conn, metadata, onClose) +} diff --git a/protocol/anytls/outbound.go b/protocol/anytls/outbound.go new file mode 100644 index 00000000..2da04000 --- /dev/null +++ b/protocol/anytls/outbound.go @@ -0,0 +1,100 @@ +package anytls + +import ( + "context" + "net" + "os" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/outbound" + "github.com/sagernet/sing-box/common/dialer" + "github.com/sagernet/sing-box/common/tls" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/transport/anytls" + "github.com/sagernet/sing/common" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/uot" +) + +func RegisterOutbound(registry *outbound.Registry) { + outbound.Register[option.AnyTLSOutboundOptions](registry, C.TypeAnyTLS, NewOutbound) +} + +type Outbound struct { + outbound.Adapter + client *anytls.Client + uotClient *uot.Client + logger log.ContextLogger +} + +func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.AnyTLSOutboundOptions) (adapter.Outbound, error) { + outbound := &Outbound{ + Adapter: outbound.NewAdapterWithDialerOptions(C.TypeAnyTLS, tag, []string{N.NetworkTCP}, options.DialerOptions), + logger: logger, + } + if options.TLS == nil || !options.TLS.Enabled { + return nil, C.ErrTLSRequired + } + + tlsConfig, err := tls.NewClient(ctx, options.Server, common.PtrValueOrDefault(options.TLS)) + if err != nil { + return nil, err + } + + outboundDialer, err := dialer.NewWithOptions(dialer.Options{ + Context: ctx, + Options: options.DialerOptions, + }) + if err != nil { + return nil, err + } + client, err := anytls.NewClient(ctx, anytls.ClientConfig{ + Password: options.Password, + IdleSessionCheckInterval: options.IdleSessionCheckInterval.Build(), + IdleSessionTimeout: options.IdleSessionTimeout.Build(), + Server: options.ServerOptions.Build(), + Dialer: outboundDialer, + TLSConfig: tlsConfig, + Logger: logger, + }) + if err != nil { + return nil, err + } + outbound.client = client + + outbound.uotClient = &uot.Client{ + Dialer: outbound, + Version: uot.Version, + } + return outbound, nil +} + +func (h *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + ctx, metadata := adapter.ExtendContext(ctx) + metadata.Outbound = h.Tag() + metadata.Destination = destination + switch N.NetworkName(network) { + case N.NetworkTCP: + h.logger.InfoContext(ctx, "outbound connection to ", destination) + return h.client.CreateProxy(ctx, destination) + case N.NetworkUDP: + h.logger.InfoContext(ctx, "outbound UoT packet connection to ", destination) + return h.uotClient.DialContext(ctx, network, destination) + } + return nil, os.ErrInvalid +} + +func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + ctx, metadata := adapter.ExtendContext(ctx) + metadata.Outbound = h.Tag() + metadata.Destination = destination + h.logger.InfoContext(ctx, "outbound UoT packet connection to ", destination) + return h.uotClient.ListenPacket(ctx, destination) +} + +func (h *Outbound) Close() error { + return common.Close(h.client) +} diff --git a/transport/anytls/client.go b/transport/anytls/client.go new file mode 100644 index 00000000..7ad40f0f --- /dev/null +++ b/transport/anytls/client.go @@ -0,0 +1,101 @@ +package anytls + +import ( + "context" + "crypto/sha256" + "encoding/binary" + "net" + "time" + + "github.com/sagernet/sing-box/common/tls" + "github.com/sagernet/sing-box/transport/anytls/padding" + "github.com/sagernet/sing-box/transport/anytls/session" + "github.com/sagernet/sing/common/atomic" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +type ClientConfig struct { + Password string + IdleSessionCheckInterval time.Duration + IdleSessionTimeout time.Duration + Server M.Socksaddr + Dialer N.Dialer + TLSConfig tls.Config + Logger logger.ContextLogger +} + +type Client struct { + passwordSha256 []byte + tlsConfig tls.Config + dialer N.Dialer + server M.Socksaddr + sessionClient *session.Client + padding atomic.TypedValue[*padding.PaddingFactory] +} + +func NewClient(ctx context.Context, config ClientConfig) (*Client, error) { + pw := sha256.Sum256([]byte(config.Password)) + c := &Client{ + passwordSha256: pw[:], + tlsConfig: config.TLSConfig, + dialer: config.Dialer, + server: config.Server, + } + // Initialize the padding state of this client + padding.UpdatePaddingScheme(padding.DefaultPaddingScheme, &c.padding) + c.sessionClient = session.NewClient(ctx, c.CreateOutboundTLSConnection, &c.padding, config.IdleSessionCheckInterval, config.IdleSessionTimeout) + return c, nil +} + +func (c *Client) CreateProxy(ctx context.Context, destination M.Socksaddr) (net.Conn, error) { + conn, err := c.sessionClient.CreateStream(ctx) + if err != nil { + return nil, err + } + err = M.SocksaddrSerializer.WriteAddrPort(conn, destination) + if err != nil { + conn.Close() + return nil, err + } + return conn, nil +} + +func (c *Client) CreateOutboundTLSConnection(ctx context.Context) (net.Conn, error) { + conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, c.server) + if err != nil { + return nil, err + } + + b := buf.NewPacket() + defer b.Release() + + b.Write(c.passwordSha256) + var paddingLen int + if pad := c.padding.Load().GenerateRecordPayloadSizes(0); len(pad) > 0 { + paddingLen = pad[0] + } + binary.BigEndian.PutUint16(b.Extend(2), uint16(paddingLen)) + if paddingLen > 0 { + b.WriteZeroN(paddingLen) + } + + conn, err = tls.ClientHandshake(ctx, conn, c.tlsConfig) + if err != nil { + return nil, err + } + + _, err = b.WriteTo(conn) + if err != nil { + conn.Close() + return nil, err + } + + return conn, nil +} + +func (h *Client) Close() error { + return h.sessionClient.Close() +} diff --git a/transport/anytls/padding/padding.go b/transport/anytls/padding/padding.go new file mode 100644 index 00000000..296d081f --- /dev/null +++ b/transport/anytls/padding/padding.go @@ -0,0 +1,92 @@ +package padding + +import ( + "crypto/md5" + "crypto/rand" + "fmt" + "math/big" + "strconv" + "strings" + + "github.com/sagernet/sing-box/transport/anytls/util" + "github.com/sagernet/sing/common/atomic" +) + +const CheckMark = -1 + +var DefaultPaddingScheme = []byte(`stop=8 +0=34-120 +1=100-400 +2=400-500,c,500-1000,c,400-500,c,500-1000,c,500-1000,c,400-500 +3=500-1000 +4=500-1000 +5=500-1000 +6=500-1000 +7=500-1000`) + +type PaddingFactory struct { + scheme util.StringMap + RawScheme []byte + Stop uint32 + Md5 string +} + +func UpdatePaddingScheme(rawScheme []byte, to *atomic.TypedValue[*PaddingFactory]) bool { + if p := NewPaddingFactory(rawScheme); p != nil { + to.Store(p) + return true + } + return false +} + +func NewPaddingFactory(rawScheme []byte) *PaddingFactory { + p := &PaddingFactory{ + RawScheme: rawScheme, + Md5: fmt.Sprintf("%x", md5.Sum(rawScheme)), + } + scheme := util.StringMapFromBytes(rawScheme) + if len(scheme) == 0 { + return nil + } + if stop, err := strconv.Atoi(scheme["stop"]); err == nil { + p.Stop = uint32(stop) + } else { + return nil + } + p.scheme = scheme + return p +} + +func (p *PaddingFactory) GenerateRecordPayloadSizes(pkt uint32) (pktSizes []int) { + if s, ok := p.scheme[strconv.Itoa(int(pkt))]; ok { + sRanges := strings.Split(s, ",") + for _, sRange := range sRanges { + sRangeMinMax := strings.Split(sRange, "-") + if len(sRangeMinMax) == 2 { + _min, err := strconv.ParseInt(sRangeMinMax[0], 10, 64) + if err != nil { + continue + } + _max, err := strconv.ParseInt(sRangeMinMax[1], 10, 64) + if err != nil { + continue + } + if _min > _max { + _min, _max = _max, _min + } + if _min <= 0 || _max <= 0 { + continue + } + if _min == _max { + pktSizes = append(pktSizes, int(_min)) + } else { + i, _ := rand.Int(rand.Reader, big.NewInt(_max-_min)) + pktSizes = append(pktSizes, int(i.Int64()+_min)) + } + } else if sRange == "c" { + pktSizes = append(pktSizes, CheckMark) + } + } + } + return +} diff --git a/transport/anytls/pipe/deadline.go b/transport/anytls/pipe/deadline.go new file mode 100644 index 00000000..29c4ec0a --- /dev/null +++ b/transport/anytls/pipe/deadline.go @@ -0,0 +1,74 @@ +package pipe + +import ( + "sync" + "time" +) + +// PipeDeadline is an abstraction for handling timeouts. +type PipeDeadline struct { + mu sync.Mutex // Guards timer and cancel + timer *time.Timer + cancel chan struct{} // Must be non-nil +} + +func MakePipeDeadline() PipeDeadline { + return PipeDeadline{cancel: make(chan struct{})} +} + +// Set sets the point in time when the deadline will time out. +// A timeout event is signaled by closing the channel returned by waiter. +// Once a timeout has occurred, the deadline can be refreshed by specifying a +// t value in the future. +// +// A zero value for t prevents timeout. +func (d *PipeDeadline) Set(t time.Time) { + d.mu.Lock() + defer d.mu.Unlock() + + if d.timer != nil && !d.timer.Stop() { + <-d.cancel // Wait for the timer callback to finish and close cancel + } + d.timer = nil + + // Time is zero, then there is no deadline. + closed := isClosedChan(d.cancel) + if t.IsZero() { + if closed { + d.cancel = make(chan struct{}) + } + return + } + + // Time in the future, setup a timer to cancel in the future. + if dur := time.Until(t); dur > 0 { + if closed { + d.cancel = make(chan struct{}) + } + d.timer = time.AfterFunc(dur, func() { + close(d.cancel) + }) + return + } + + // Time in the past, so close immediately. + if !closed { + close(d.cancel) + } +} + +// Wait returns a channel that is closed when the deadline is exceeded. +func (d *PipeDeadline) Wait() chan struct{} { + d.mu.Lock() + defer d.mu.Unlock() + return d.cancel +} + +func isClosedChan(c <-chan struct{}) bool { + select { + case <-c: + return true + default: + return false + } +} diff --git a/transport/anytls/pipe/io_pipe.go b/transport/anytls/pipe/io_pipe.go new file mode 100644 index 00000000..5d0fd252 --- /dev/null +++ b/transport/anytls/pipe/io_pipe.go @@ -0,0 +1,232 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Pipe adapter to connect code expecting an io.Reader +// with code expecting an io.Writer. + +package pipe + +import ( + "io" + "os" + "sync" + "time" +) + +// onceError is an object that will only store an error once. +type onceError struct { + sync.Mutex // guards following + err error +} + +func (a *onceError) Store(err error) { + a.Lock() + defer a.Unlock() + if a.err != nil { + return + } + a.err = err +} +func (a *onceError) Load() error { + a.Lock() + defer a.Unlock() + return a.err +} + +// A pipe is the shared pipe structure underlying PipeReader and PipeWriter. +type pipe struct { + wrMu sync.Mutex // Serializes Write operations + wrCh chan []byte + rdCh chan int + + once sync.Once // Protects closing done + done chan struct{} + rerr onceError + werr onceError + + readDeadline PipeDeadline + writeDeadline PipeDeadline +} + +func (p *pipe) read(b []byte) (n int, err error) { + select { + case <-p.done: + return 0, p.readCloseError() + case <-p.readDeadline.Wait(): + return 0, os.ErrDeadlineExceeded + default: + } + + select { + case bw := <-p.wrCh: + nr := copy(b, bw) + p.rdCh <- nr + return nr, nil + case <-p.done: + return 0, p.readCloseError() + case <-p.readDeadline.Wait(): + return 0, os.ErrDeadlineExceeded + } +} + +func (p *pipe) closeRead(err error) error { + if err == nil { + err = io.ErrClosedPipe + } + p.rerr.Store(err) + p.once.Do(func() { close(p.done) }) + return nil +} + +func (p *pipe) write(b []byte) (n int, err error) { + select { + case <-p.done: + return 0, p.writeCloseError() + case <-p.writeDeadline.Wait(): + return 0, os.ErrDeadlineExceeded + default: + p.wrMu.Lock() + defer p.wrMu.Unlock() + } + + for once := true; once || len(b) > 0; once = false { + select { + case p.wrCh <- b: + nw := <-p.rdCh + b = b[nw:] + n += nw + case <-p.done: + return n, p.writeCloseError() + case <-p.writeDeadline.Wait(): + return n, os.ErrDeadlineExceeded + } + } + return n, nil +} + +func (p *pipe) closeWrite(err error) error { + if err == nil { + err = io.EOF + } + p.werr.Store(err) + p.once.Do(func() { close(p.done) }) + return nil +} + +// readCloseError is considered internal to the pipe type. +func (p *pipe) readCloseError() error { + rerr := p.rerr.Load() + if werr := p.werr.Load(); rerr == nil && werr != nil { + return werr + } + return io.ErrClosedPipe +} + +// writeCloseError is considered internal to the pipe type. +func (p *pipe) writeCloseError() error { + werr := p.werr.Load() + if rerr := p.rerr.Load(); werr == nil && rerr != nil { + return rerr + } + return io.ErrClosedPipe +} + +// A PipeReader is the read half of a pipe. +type PipeReader struct{ pipe } + +// Read implements the standard Read interface: +// it reads data from the pipe, blocking until a writer +// arrives or the write end is closed. +// If the write end is closed with an error, that error is +// returned as err; otherwise err is EOF. +func (r *PipeReader) Read(data []byte) (n int, err error) { + return r.pipe.read(data) +} + +// Close closes the reader; subsequent writes to the +// write half of the pipe will return the error [ErrClosedPipe]. +func (r *PipeReader) Close() error { + return r.CloseWithError(nil) +} + +// CloseWithError closes the reader; subsequent writes +// to the write half of the pipe will return the error err. +// +// CloseWithError never overwrites the previous error if it exists +// and always returns nil. +func (r *PipeReader) CloseWithError(err error) error { + return r.pipe.closeRead(err) +} + +// A PipeWriter is the write half of a pipe. +type PipeWriter struct{ r PipeReader } + +// Write implements the standard Write interface: +// it writes data to the pipe, blocking until one or more readers +// have consumed all the data or the read end is closed. +// If the read end is closed with an error, that err is +// returned as err; otherwise err is [ErrClosedPipe]. +func (w *PipeWriter) Write(data []byte) (n int, err error) { + return w.r.pipe.write(data) +} + +// Close closes the writer; subsequent reads from the +// read half of the pipe will return no bytes and EOF. +func (w *PipeWriter) Close() error { + return w.CloseWithError(nil) +} + +// CloseWithError closes the writer; subsequent reads from the +// read half of the pipe will return no bytes and the error err, +// or EOF if err is nil. +// +// CloseWithError never overwrites the previous error if it exists +// and always returns nil. +func (w *PipeWriter) CloseWithError(err error) error { + return w.r.pipe.closeWrite(err) +} + +// Pipe creates a synchronous in-memory pipe. +// It can be used to connect code expecting an [io.Reader] +// with code expecting an [io.Writer]. +// +// Reads and Writes on the pipe are matched one to one +// except when multiple Reads are needed to consume a single Write. +// That is, each Write to the [PipeWriter] blocks until it has satisfied +// one or more Reads from the [PipeReader] that fully consume +// the written data. +// The data is copied directly from the Write to the corresponding +// Read (or Reads); there is no internal buffering. +// +// It is safe to call Read and Write in parallel with each other or with Close. +// Parallel calls to Read and parallel calls to Write are also safe: +// the individual calls will be gated sequentially. +// +// Added SetReadDeadline and SetWriteDeadline methods based on `io.Pipe`. +func Pipe() (*PipeReader, *PipeWriter) { + pw := &PipeWriter{r: PipeReader{pipe: pipe{ + wrCh: make(chan []byte), + rdCh: make(chan int), + done: make(chan struct{}), + readDeadline: MakePipeDeadline(), + writeDeadline: MakePipeDeadline(), + }}} + return &pw.r, pw +} + +func (p *PipeReader) SetReadDeadline(t time.Time) error { + if isClosedChan(p.done) { + return io.ErrClosedPipe + } + p.readDeadline.Set(t) + return nil +} + +func (p *PipeWriter) SetWriteDeadline(t time.Time) error { + if isClosedChan(p.r.done) { + return io.ErrClosedPipe + } + p.r.writeDeadline.Set(t) + return nil +} diff --git a/transport/anytls/service.go b/transport/anytls/service.go new file mode 100644 index 00000000..aa131150 --- /dev/null +++ b/transport/anytls/service.go @@ -0,0 +1,125 @@ +package anytls + +import ( + "context" + "crypto/sha256" + "encoding/binary" + "errors" + "net" + "os" + + "github.com/sagernet/sing-box/common/tls" + "github.com/sagernet/sing-box/transport/anytls/padding" + "github.com/sagernet/sing-box/transport/anytls/session" + "github.com/sagernet/sing/common/atomic" + "github.com/sagernet/sing/common/auth" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +type Service struct { + users map[[32]byte]string + padding atomic.TypedValue[*padding.PaddingFactory] + tlsConfig tls.ServerConfig + handler N.TCPConnectionHandlerEx + logger logger.ContextLogger +} + +type ServiceConfig struct { + PaddingScheme []byte + Users []User + TLSConfig tls.ServerConfig + Handler N.TCPConnectionHandlerEx + Logger logger.ContextLogger +} + +type User struct { + Name string + Password string +} + +func NewService(config ServiceConfig) (*Service, error) { + service := &Service{ + tlsConfig: config.TLSConfig, + handler: config.Handler, + logger: config.Logger, + users: make(map[[32]byte]string), + } + + if service.handler == nil || service.logger == nil { + return nil, os.ErrInvalid + } + + for _, user := range config.Users { + service.users[sha256.Sum256([]byte(user.Password))] = user.Name + } + + if !padding.UpdatePaddingScheme(config.PaddingScheme, &service.padding) { + return nil, errors.New("incorrect padding scheme format") + } + + return service, nil +} + +func (s *Service) NewConnection(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) error { + var err error + + if s.tlsConfig != nil { + conn, err = tls.ServerHandshake(ctx, conn, s.tlsConfig) + if err != nil { + return err + } + } + + b := buf.NewPacket() + defer b.Release() + + _, err = b.ReadOnceFrom(conn) + if err != nil { + return err + } + conn = bufio.NewCachedConn(conn, b) + + by, err := b.ReadBytes(32) + if err != nil { + b.Reset() + return os.ErrInvalid + } + var passwordSha256 [32]byte + copy(passwordSha256[:], by) + if user, ok := s.users[passwordSha256]; ok { + ctx = auth.ContextWithUser(ctx, user) + } else { + b.Reset() + return os.ErrInvalid + } + by, err = b.ReadBytes(2) + if err != nil { + b.Reset() + return os.ErrInvalid + } + paddingLen := binary.BigEndian.Uint16(by) + if paddingLen > 0 { + _, err = b.ReadBytes(int(paddingLen)) + if err != nil { + b.Reset() + return os.ErrInvalid + } + } + + session := session.NewServerSession(conn, func(stream *session.Stream) { + destination, err := M.SocksaddrSerializer.ReadAddrPort(stream) + if err != nil { + s.logger.ErrorContext(ctx, "ReadAddrPort:", err) + return + } + + s.handler.NewConnectionEx(ctx, stream, source, destination, onClose) + }, &s.padding) + session.Run() + session.Close() + return nil +} diff --git a/transport/anytls/session/client.go b/transport/anytls/session/client.go new file mode 100644 index 00000000..0b18eaa4 --- /dev/null +++ b/transport/anytls/session/client.go @@ -0,0 +1,159 @@ +package session + +import ( + "context" + "fmt" + "io" + "math" + "net" + "sync" + "time" + + "github.com/sagernet/sing-box/transport/anytls/padding" + "github.com/sagernet/sing-box/transport/anytls/skiplist" + "github.com/sagernet/sing-box/transport/anytls/util" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/atomic" +) + +type Client struct { + die context.Context + dieCancel context.CancelFunc + + dialOut func(ctx context.Context) (net.Conn, error) + + sessionCounter atomic.Uint64 + idleSession *skiplist.SkipList[uint64, *Session] + idleSessionLock sync.Mutex + + padding *atomic.TypedValue[*padding.PaddingFactory] + + idleSessionTimeout time.Duration +} + +func NewClient(ctx context.Context, dialOut func(ctx context.Context) (net.Conn, error), _padding *atomic.TypedValue[*padding.PaddingFactory], idleSessionCheckInterval, idleSessionTimeout time.Duration) *Client { + c := &Client{ + dialOut: dialOut, + padding: _padding, + idleSessionTimeout: idleSessionTimeout, + } + if idleSessionCheckInterval <= time.Second*5 { + idleSessionCheckInterval = time.Second * 30 + } + if c.idleSessionTimeout <= time.Second*5 { + c.idleSessionTimeout = time.Second * 30 + } + c.die, c.dieCancel = context.WithCancel(ctx) + c.idleSession = skiplist.NewSkipList[uint64, *Session]() + util.StartRoutine(c.die, idleSessionCheckInterval, c.idleCleanup) + return c +} + +func (c *Client) CreateStream(ctx context.Context) (net.Conn, error) { + select { + case <-c.die.Done(): + return nil, io.ErrClosedPipe + default: + } + + var session *Session + var stream *Stream + var err error + + for i := 0; i < 3; i++ { + session, err = c.findSession(ctx) + if session == nil { + return nil, fmt.Errorf("failed to create session: %w", err) + } + stream, err = session.OpenStream() + if err != nil { + common.Close(session, stream) + continue + } + break + } + if session == nil || stream == nil { + return nil, fmt.Errorf("too many closed session: %w", err) + } + + stream.dieHook = func() { + if session.IsClosed() { + if session.dieHook != nil { + session.dieHook() + } + } else { + c.idleSessionLock.Lock() + session.idleSince = time.Now() + c.idleSession.Insert(math.MaxUint64-session.seq, session) + c.idleSessionLock.Unlock() + } + } + + return stream, nil +} + +func (c *Client) findSession(ctx context.Context) (*Session, error) { + var idle *Session + + c.idleSessionLock.Lock() + if !c.idleSession.IsEmpty() { + it := c.idleSession.Iterate() + idle = it.Value() + c.idleSession.Remove(it.Key()) + } + c.idleSessionLock.Unlock() + + if idle == nil { + s, err := c.createSession(ctx) + return s, err + } + return idle, nil +} + +func (c *Client) createSession(ctx context.Context) (*Session, error) { + underlying, err := c.dialOut(ctx) + if err != nil { + return nil, err + } + + session := NewClientSession(underlying, c.padding) + session.seq = c.sessionCounter.Add(1) + session.dieHook = func() { + //logrus.Debugln("session died", session) + c.idleSessionLock.Lock() + c.idleSession.Remove(math.MaxUint64 - session.seq) + c.idleSessionLock.Unlock() + } + session.Run() + return session, nil +} + +func (c *Client) Close() error { + c.dieCancel() + go c.idleCleanupExpTime(time.Now()) + return nil +} + +func (c *Client) idleCleanup() { + c.idleCleanupExpTime(time.Now().Add(-c.idleSessionTimeout)) +} + +func (c *Client) idleCleanupExpTime(expTime time.Time) { + var sessionToRemove = make([]*Session, 0) + + c.idleSessionLock.Lock() + it := c.idleSession.Iterate() + for it.IsNotEnd() { + session := it.Value() + if session.idleSince.Before(expTime) { + sessionToRemove = append(sessionToRemove, session) + c.idleSession.Remove(it.Key()) + } + it.MoveToNext() + } + c.idleSessionLock.Unlock() + + for _, session := range sessionToRemove { + session.Close() + } +} diff --git a/transport/anytls/session/frame.go b/transport/anytls/session/frame.go new file mode 100644 index 00000000..49597c55 --- /dev/null +++ b/transport/anytls/session/frame.go @@ -0,0 +1,44 @@ +package session + +import ( + "encoding/binary" +) + +const ( // cmds + cmdWaste = 0 // Paddings + cmdSYN = 1 // stream open + cmdPSH = 2 // data push + cmdFIN = 3 // stream close, a.k.a EOF mark + cmdSettings = 4 // Settings + cmdAlert = 5 // Alert + cmdUpdatePaddingScheme = 6 // update padding scheme +) + +const ( + headerOverHeadSize = 1 + 4 + 2 +) + +// frame defines a packet from or to be multiplexed into a single connection +type frame struct { + cmd byte // 1 + sid uint32 // 4 + data []byte // 2 + len(data) +} + +func newFrame(cmd byte, sid uint32) frame { + return frame{cmd: cmd, sid: sid} +} + +type rawHeader [headerOverHeadSize]byte + +func (h rawHeader) Cmd() byte { + return h[0] +} + +func (h rawHeader) StreamID() uint32 { + return binary.BigEndian.Uint32(h[1:]) +} + +func (h rawHeader) Length() uint16 { + return binary.BigEndian.Uint16(h[5:]) +} diff --git a/transport/anytls/session/session.go b/transport/anytls/session/session.go new file mode 100644 index 00000000..037c8625 --- /dev/null +++ b/transport/anytls/session/session.go @@ -0,0 +1,383 @@ +package session + +import ( + "crypto/md5" + "encoding/binary" + "fmt" + "io" + "net" + "runtime/debug" + "sync" + "time" + + "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/transport/anytls/padding" + "github.com/sagernet/sing-box/transport/anytls/util" + "github.com/sagernet/sing/common/atomic" + "github.com/sagernet/sing/common/buf" +) + +type Session struct { + conn net.Conn + connLock sync.Mutex + + streams map[uint32]*Stream + streamId atomic.Uint32 + streamLock sync.RWMutex + + dieOnce sync.Once + die chan struct{} + dieHook func() + + // pool + seq uint64 + idleSince time.Time + padding *atomic.TypedValue[*padding.PaddingFactory] + + // client + isClient bool + sendPadding bool + buffering bool + buffer []byte + pktCounter atomic.Uint32 + + // server + onNewStream func(stream *Stream) +} + +func NewClientSession(conn net.Conn, _padding *atomic.TypedValue[*padding.PaddingFactory]) *Session { + s := &Session{ + conn: conn, + isClient: true, + sendPadding: true, + padding: _padding, + } + s.die = make(chan struct{}) + s.streams = make(map[uint32]*Stream) + return s +} + +func NewServerSession(conn net.Conn, onNewStream func(stream *Stream), _padding *atomic.TypedValue[*padding.PaddingFactory]) *Session { + s := &Session{ + conn: conn, + onNewStream: onNewStream, + padding: _padding, + } + s.die = make(chan struct{}) + s.streams = make(map[uint32]*Stream) + return s +} + +func (s *Session) Run() { + if !s.isClient { + s.recvLoop() + return + } + + settings := util.StringMap{ + "v": "1", + "client": "sing-box/" + constant.Version, + "padding-md5": s.padding.Load().Md5, + } + f := newFrame(cmdSettings, 0) + f.data = settings.ToBytes() + s.buffering = true + s.writeFrame(f) + + go s.recvLoop() +} + +// IsClosed does a safe check to see if we have shutdown +func (s *Session) IsClosed() bool { + select { + case <-s.die: + return true + default: + return false + } +} + +// Close is used to close the session and all streams. +func (s *Session) Close() error { + var once bool + s.dieOnce.Do(func() { + close(s.die) + once = true + }) + + if once { + if s.dieHook != nil { + s.dieHook() + } + s.streamLock.Lock() + for k := range s.streams { + s.streams[k].sessionClose() + } + s.streamLock.Unlock() + return s.conn.Close() + } else { + return io.ErrClosedPipe + } +} + +// OpenStream is used to create a new stream for CLIENT +func (s *Session) OpenStream() (*Stream, error) { + if s.IsClosed() { + return nil, io.ErrClosedPipe + } + + sid := s.streamId.Add(1) + stream := newStream(sid, s) + + //logrus.Debugln("stream open", sid, s.streams) + + if _, err := s.writeFrame(newFrame(cmdSYN, sid)); err != nil { + return nil, err + } + + s.buffering = false // proxy Write it's SocksAddr to flush the buffer + + s.streamLock.Lock() + defer s.streamLock.Unlock() + select { + case <-s.die: + return nil, io.ErrClosedPipe + default: + s.streams[sid] = stream + return stream, nil + } +} + +func (s *Session) recvLoop() error { + defer func() { + if r := recover(); r != nil { + log.Error("[BUG]", r, string(debug.Stack())) + } + }() + defer s.Close() + + var receivedSettingsFromClient bool + var hdr rawHeader + + for { + if s.IsClosed() { + return io.ErrClosedPipe + } + // read header first + if _, err := io.ReadFull(s.conn, hdr[:]); err == nil { + sid := hdr.StreamID() + switch hdr.Cmd() { + case cmdPSH: + if hdr.Length() > 0 { + buffer := buf.Get(int(hdr.Length())) + if _, err := io.ReadFull(s.conn, buffer); err == nil { + s.streamLock.RLock() + stream, ok := s.streams[sid] + s.streamLock.RUnlock() + if ok { + stream.pipeW.Write(buffer) + } + buf.Put(buffer) + } else { + buf.Put(buffer) + return err + } + } + case cmdSYN: // should be server only + if !s.isClient && !receivedSettingsFromClient { + f := newFrame(cmdAlert, 0) + f.data = []byte("client did not send its settings") + s.writeFrame(f) + return nil + } + s.streamLock.Lock() + if _, ok := s.streams[sid]; !ok { + stream := newStream(sid, s) + s.streams[sid] = stream + if s.onNewStream != nil { + go s.onNewStream(stream) + } else { + go s.Close() + } + } + s.streamLock.Unlock() + case cmdFIN: + s.streamLock.RLock() + stream, ok := s.streams[sid] + s.streamLock.RUnlock() + if ok { + stream.Close() + } + //logrus.Debugln("stream fin", sid, s.streams) + case cmdWaste: + if hdr.Length() > 0 { + buffer := buf.Get(int(hdr.Length())) + if _, err := io.ReadFull(s.conn, buffer); err != nil { + buf.Put(buffer) + return err + } + buf.Put(buffer) + } + case cmdSettings: + if hdr.Length() > 0 { + buffer := buf.Get(int(hdr.Length())) + if _, err := io.ReadFull(s.conn, buffer); err != nil { + buf.Put(buffer) + return err + } + if !s.isClient { + receivedSettingsFromClient = true + m := util.StringMapFromBytes(buffer) + paddingF := s.padding.Load() + if m["padding-md5"] != paddingF.Md5 { + // logrus.Debugln("remote md5 is", m["padding-md5"]) + f := newFrame(cmdUpdatePaddingScheme, 0) + f.data = paddingF.RawScheme + _, err = s.writeFrame(f) + if err != nil { + buf.Put(buffer) + return err + } + } + } + buf.Put(buffer) + } + case cmdAlert: + if hdr.Length() > 0 { + buffer := buf.Get(int(hdr.Length())) + if _, err := io.ReadFull(s.conn, buffer); err != nil { + buf.Put(buffer) + return err + } + if s.isClient { + log.Error("[Alert from server]", string(buffer)) + } + buf.Put(buffer) + return nil + } + case cmdUpdatePaddingScheme: + if hdr.Length() > 0 { + // `rawScheme` Do not use buffer to prevent subsequent misuse + rawScheme := make([]byte, int(hdr.Length())) + if _, err := io.ReadFull(s.conn, rawScheme); err != nil { + return err + } + if s.isClient { + if padding.UpdatePaddingScheme(rawScheme, s.padding) { + log.Info(fmt.Sprintf("[Update padding succeed] %x\n", md5.Sum(rawScheme))) + } else { + log.Warn(fmt.Sprintf("[Update padding failed] %x\n", md5.Sum(rawScheme))) + } + } + } + default: + // I don't know what command it is (can't have data) + } + } else { + return err + } + } +} + +// notify the session that a stream has closed +func (s *Session) streamClosed(sid uint32) error { + _, err := s.writeFrame(newFrame(cmdFIN, sid)) + s.streamLock.Lock() + delete(s.streams, sid) + s.streamLock.Unlock() + return err +} + +func (s *Session) writeFrame(frame frame) (int, error) { + dataLen := len(frame.data) + + buffer := buf.NewSize(dataLen + headerOverHeadSize) + buffer.WriteByte(frame.cmd) + binary.BigEndian.PutUint32(buffer.Extend(4), frame.sid) + binary.BigEndian.PutUint16(buffer.Extend(2), uint16(dataLen)) + buffer.Write(frame.data) + _, err := s.writeConn(buffer.Bytes()) + buffer.Release() + if err != nil { + return 0, err + } + + return dataLen, nil +} + +func (s *Session) writeConn(b []byte) (n int, err error) { + s.connLock.Lock() + defer s.connLock.Unlock() + + if s.buffering { + s.buffer = append(s.buffer, b...) + return len(b), nil + } else if len(s.buffer) > 0 { + b = append(s.buffer, b...) + s.buffer = nil + } + + // calulate & send padding + if s.sendPadding { + pkt := s.pktCounter.Add(1) + paddingF := s.padding.Load() + if pkt < paddingF.Stop { + pktSizes := paddingF.GenerateRecordPayloadSizes(pkt) + for _, l := range pktSizes { + remainPayloadLen := len(b) + if l == padding.CheckMark { + if remainPayloadLen == 0 { + break + } else { + continue + } + } + if remainPayloadLen > l { // this packet is all payload + _, err = s.conn.Write(b[:l]) + if err != nil { + return 0, err + } + n += l + b = b[l:] + } else if remainPayloadLen > 0 { // this packet contains padding and the last part of payload + paddingLen := l - remainPayloadLen - headerOverHeadSize + if paddingLen > 0 { + padding := make([]byte, headerOverHeadSize+paddingLen) + padding[0] = cmdWaste + binary.BigEndian.PutUint32(padding[1:5], 0) + binary.BigEndian.PutUint16(padding[5:7], uint16(paddingLen)) + b = append(b, padding...) + } + _, err = s.conn.Write(b) + if err != nil { + return 0, err + } + n += remainPayloadLen + b = nil + } else { // this packet is all padding + padding := make([]byte, headerOverHeadSize+l) + padding[0] = cmdWaste + binary.BigEndian.PutUint32(padding[1:5], 0) + binary.BigEndian.PutUint16(padding[5:7], uint16(l)) + _, err = s.conn.Write(padding) + if err != nil { + return 0, err + } + b = nil + } + } + // maybe still remain payload to write + if len(b) == 0 { + return + } else { + n2, err := s.conn.Write(b) + return n + n2, err + } + } else { + s.sendPadding = false + } + } + + return s.conn.Write(b) +} diff --git a/transport/anytls/session/stream.go b/transport/anytls/session/stream.go new file mode 100644 index 00000000..f1e6b8b9 --- /dev/null +++ b/transport/anytls/session/stream.go @@ -0,0 +1,110 @@ +package session + +import ( + "io" + "net" + "os" + "sync" + "time" + + "github.com/sagernet/sing-box/transport/anytls/pipe" +) + +// Stream implements net.Conn +type Stream struct { + id uint32 + + sess *Session + + pipeR *pipe.PipeReader + pipeW *pipe.PipeWriter + writeDeadline pipe.PipeDeadline + + dieOnce sync.Once + dieHook func() +} + +// newStream initiates a Stream struct +func newStream(id uint32, sess *Session) *Stream { + s := new(Stream) + s.id = id + s.sess = sess + s.pipeR, s.pipeW = pipe.Pipe() + s.writeDeadline = pipe.MakePipeDeadline() + return s +} + +// Read implements net.Conn +func (s *Stream) Read(b []byte) (n int, err error) { + return s.pipeR.Read(b) +} + +// Write implements net.Conn +func (s *Stream) Write(b []byte) (n int, err error) { + select { + case <-s.writeDeadline.Wait(): + return 0, os.ErrDeadlineExceeded + default: + } + f := newFrame(cmdPSH, s.id) + f.data = b + n, err = s.sess.writeFrame(f) + return +} + +// Close implements net.Conn +func (s *Stream) Close() error { + if s.sessionClose() { + // notify remote + return s.sess.streamClosed(s.id) + } else { + return io.ErrClosedPipe + } +} + +// sessionClose close stream from session side, do not notify remote +func (s *Stream) sessionClose() (once bool) { + s.dieOnce.Do(func() { + s.pipeR.Close() + once = true + if s.dieHook != nil { + s.dieHook() + s.dieHook = nil + } + }) + return +} + +func (s *Stream) SetReadDeadline(t time.Time) error { + return s.pipeR.SetReadDeadline(t) +} + +func (s *Stream) SetWriteDeadline(t time.Time) error { + s.writeDeadline.Set(t) + return nil +} + +func (s *Stream) SetDeadline(t time.Time) error { + s.SetWriteDeadline(t) + return s.SetReadDeadline(t) +} + +// LocalAddr satisfies net.Conn interface +func (s *Stream) LocalAddr() net.Addr { + if ts, ok := s.sess.conn.(interface { + LocalAddr() net.Addr + }); ok { + return ts.LocalAddr() + } + return nil +} + +// RemoteAddr satisfies net.Conn interface +func (s *Stream) RemoteAddr() net.Addr { + if ts, ok := s.sess.conn.(interface { + RemoteAddr() net.Addr + }); ok { + return ts.RemoteAddr() + } + return nil +} diff --git a/transport/anytls/skiplist/contianer.go b/transport/anytls/skiplist/contianer.go new file mode 100644 index 00000000..ceda0421 --- /dev/null +++ b/transport/anytls/skiplist/contianer.go @@ -0,0 +1,46 @@ +package skiplist + +// Container is a holder object that stores a collection of other objects. +type Container interface { + IsEmpty() bool // IsEmpty checks if the container has no elements. + Len() int // Len returns the number of elements in the container. + Clear() // Clear erases all elements from the container. After this call, Len() returns zero. +} + +// Map is a associative container that contains key-value pairs with unique keys. +type Map[K any, V any] interface { + Container + Has(K) bool // Checks whether the container contains element with specific key. + Find(K) *V // Finds element with specific key. + Insert(K, V) // Inserts a key-value pair in to the container or replace existing value. + Remove(K) bool // Remove element with specific key. + ForEach(func(K, V)) // Iterate the container. + ForEachIf(func(K, V) bool) // Iterate the container, stops when the callback returns false. + ForEachMutable(func(K, *V)) // Iterate the container, *V is mutable. + ForEachMutableIf(func(K, *V) bool) // Iterate the container, *V is mutable, stops when the callback returns false. +} + +// Set is a containers that store unique elements. +type Set[K any] interface { + Container + Has(K) bool // Checks whether the container contains element with specific key. + Insert(K) // Inserts a key-value pair in to the container or replace existing value. + InsertN(...K) // Inserts multiple key-value pairs in to the container or replace existing value. + Remove(K) bool // Remove element with specific key. + RemoveN(...K) // Remove multiple elements with specific keys. + ForEach(func(K)) // Iterate the container. + ForEachIf(func(K) bool) // Iterate the container, stops when the callback returns false. +} + +// Iterator is the interface for container's iterator. +type Iterator[T any] interface { + IsNotEnd() bool // Whether it is point to the end of the range. + MoveToNext() // Let it point to the next element. + Value() T // Return the value of current element. +} + +// MapIterator is the interface for map's iterator. +type MapIterator[K any, V any] interface { + Iterator[V] + Key() K // The key of the element +} diff --git a/transport/anytls/skiplist/skiplist.go b/transport/anytls/skiplist/skiplist.go new file mode 100644 index 00000000..f1ce402a --- /dev/null +++ b/transport/anytls/skiplist/skiplist.go @@ -0,0 +1,457 @@ +package skiplist + +// This implementation is based on https://github.com/liyue201/gostl/tree/master/ds/skiplist +// (many thanks), added many optimizations, such as: +// +// - adaptive level +// - lesser search for prevs when key already exists. +// - reduce memory allocations +// - richer interface. +// +// etc. + +import ( + "math/bits" + "math/rand" + "time" +) + +const ( + skipListMaxLevel = 40 +) + +// SkipList is a probabilistic data structure that seem likely to supplant balanced trees as the +// implementation method of choice for many applications. Skip list algorithms have the same +// asymptotic expected time bounds as balanced trees and are simpler, faster and use less space. +// +// See https://en.wikipedia.org/wiki/Skip_list for more details. +type SkipList[K any, V any] struct { + level int // Current level, may increase dynamically during insertion + len int // Total elements numner in the skiplist. + head skipListNode[K, V] // head.next[level] is the head of each level. + // This cache is used to save the previous nodes when modifying the skip list to avoid + // allocating memory each time it is called. + prevsCache []*skipListNode[K, V] + rander *rand.Rand + impl skipListImpl[K, V] +} + +// NewSkipList creates a new SkipList for Ordered key type. +func NewSkipList[K Ordered, V any]() *SkipList[K, V] { + sl := skipListOrdered[K, V]{} + sl.init() + sl.impl = (skipListImpl[K, V])(&sl) + return &sl.SkipList +} + +// NewSkipListFromMap creates a new SkipList from a map. +func NewSkipListFromMap[K Ordered, V any](m map[K]V) *SkipList[K, V] { + sl := NewSkipList[K, V]() + for k, v := range m { + sl.Insert(k, v) + } + return sl +} + +// NewSkipListFunc creates a new SkipList with specified compare function keyCmp. +func NewSkipListFunc[K any, V any](keyCmp CompareFn[K]) *SkipList[K, V] { + sl := skipListFunc[K, V]{} + sl.init() + sl.keyCmp = keyCmp + sl.impl = skipListImpl[K, V](&sl) + return &sl.SkipList +} + +// IsEmpty implements the Container interface. +func (sl *SkipList[K, V]) IsEmpty() bool { + return sl.len == 0 +} + +// Len implements the Container interface. +func (sl *SkipList[K, V]) Len() int { + return sl.len +} + +// Clear implements the Container interface. +func (sl *SkipList[K, V]) Clear() { + for i := range sl.head.next { + sl.head.next[i] = nil + } + sl.level = 1 + sl.len = 0 +} + +// Iterate return an iterator to the skiplist. +func (sl *SkipList[K, V]) Iterate() MapIterator[K, V] { + return &skipListIterator[K, V]{sl.head.next[0], nil} +} + +// Insert inserts a key-value pair into the skiplist. +// If the key is already in the skip list, it's value will be updated. +func (sl *SkipList[K, V]) Insert(key K, value V) { + node, prevs := sl.impl.findInsertPoint(key) + + if node != nil { + // Already exist, update the value + node.value = value + return + } + + level := sl.randomLevel() + node = newSkipListNode(level, key, value) + + minLevel := level + if sl.level < level { + minLevel = sl.level + } + for i := 0; i < minLevel; i++ { + node.next[i] = prevs[i].next[i] + prevs[i].next[i] = node + } + + if level > sl.level { + for i := sl.level; i < level; i++ { + sl.head.next[i] = node + } + sl.level = level + } + + sl.len++ +} + +// Find returns the value associated with the passed key if the key is in the skiplist, otherwise +// returns nil. +func (sl *SkipList[K, V]) Find(key K) *V { + node := sl.impl.findNode(key) + if node != nil { + return &node.value + } + return nil +} + +// Has implement the Map interface. +func (sl *SkipList[K, V]) Has(key K) bool { + return sl.impl.findNode(key) != nil +} + +// LowerBound returns an iterator to the first element in the skiplist that +// does not satisfy element < value (i.e. greater or equal to), +// or a end itetator if no such element is found. +func (sl *SkipList[K, V]) LowerBound(key K) MapIterator[K, V] { + return &skipListIterator[K, V]{sl.impl.lowerBound(key), nil} +} + +// UpperBound returns an iterator to the first element in the skiplist that +// does not satisfy value < element (i.e. strictly greater), +// or a end itetator if no such element is found. +func (sl *SkipList[K, V]) UpperBound(key K) MapIterator[K, V] { + return &skipListIterator[K, V]{sl.impl.upperBound(key), nil} +} + +// FindRange returns an iterator in range [first, last) (last is not includeed). +func (sl *SkipList[K, V]) FindRange(first, last K) MapIterator[K, V] { + return &skipListIterator[K, V]{sl.impl.lowerBound(first), sl.impl.upperBound(last)} +} + +// Remove removes the key-value pair associated with the passed key and returns true if the key is +// in the skiplist, otherwise returns false. +func (sl *SkipList[K, V]) Remove(key K) bool { + node, prevs := sl.impl.findRemovePoint(key) + if node == nil { + return false + } + for i, v := range node.next { + prevs[i].next[i] = v + } + for sl.level > 1 && sl.head.next[sl.level-1] == nil { + sl.level-- + } + sl.len-- + return true +} + +// ForEach implements the Map interface. +func (sl *SkipList[K, V]) ForEach(op func(K, V)) { + for e := sl.head.next[0]; e != nil; e = e.next[0] { + op(e.key, e.value) + } +} + +// ForEachMutable implements the Map interface. +func (sl *SkipList[K, V]) ForEachMutable(op func(K, *V)) { + for e := sl.head.next[0]; e != nil; e = e.next[0] { + op(e.key, &e.value) + } +} + +// ForEachIf implements the Map interface. +func (sl *SkipList[K, V]) ForEachIf(op func(K, V) bool) { + for e := sl.head.next[0]; e != nil; e = e.next[0] { + if !op(e.key, e.value) { + return + } + } +} + +// ForEachMutableIf implements the Map interface. +func (sl *SkipList[K, V]) ForEachMutableIf(op func(K, *V) bool) { + for e := sl.head.next[0]; e != nil; e = e.next[0] { + if !op(e.key, &e.value) { + return + } + } +} + +/// SkipList implementation part. + +type skipListNode[K any, V any] struct { + key K + value V + next []*skipListNode[K, V] +} + +//go:generate bash ./skiplist_newnode_generate.sh skipListMaxLevel skiplist_newnode.go +// func newSkipListNode[K Ordered, V any](level int, key K, value V) *skipListNode[K, V] + +type skipListIterator[K any, V any] struct { + node, end *skipListNode[K, V] +} + +func (it *skipListIterator[K, V]) IsNotEnd() bool { + return it.node != it.end +} + +func (it *skipListIterator[K, V]) MoveToNext() { + it.node = it.node.next[0] +} + +func (it *skipListIterator[K, V]) Key() K { + return it.node.key +} + +func (it *skipListIterator[K, V]) Value() V { + return it.node.value +} + +// skipListImpl is an interface to provide different implementation for Ordered key or CompareFn. +// +// We can use CompareFn to cumpare Ordered keys, but a separated implementation is much faster. +// We don't make the whole skip list an interface, in order to share the type independented method. +// And because these methods are called directly without going through the interface, they are also +// much faster. +type skipListImpl[K any, V any] interface { + findNode(key K) *skipListNode[K, V] + lowerBound(key K) *skipListNode[K, V] + upperBound(key K) *skipListNode[K, V] + findInsertPoint(key K) (*skipListNode[K, V], []*skipListNode[K, V]) + findRemovePoint(key K) (*skipListNode[K, V], []*skipListNode[K, V]) +} + +func (sl *SkipList[K, V]) init() { + sl.level = 1 + // #nosec G404 -- This is not a security condition + sl.rander = rand.New(rand.NewSource(time.Now().Unix())) + sl.prevsCache = make([]*skipListNode[K, V], skipListMaxLevel) + sl.head.next = make([]*skipListNode[K, V], skipListMaxLevel) +} + +func (sl *SkipList[K, V]) randomLevel() int { + total := uint64(1)< 3 && 1<<(level-3) > sl.len { + level-- + } + + return level +} + +/// skipListOrdered part + +// skipListOrdered is the skip list implementation for Ordered types. +type skipListOrdered[K Ordered, V any] struct { + SkipList[K, V] +} + +func (sl *skipListOrdered[K, V]) findNode(key K) *skipListNode[K, V] { + return sl.doFindNode(key, true) +} + +func (sl *skipListOrdered[K, V]) doFindNode(key K, eq bool) *skipListNode[K, V] { + // This function execute the job of findNode if eq is true, otherwise lowBound. + // Passing the control variable eq is ugly but it's faster than testing node + // again outside the function in findNode. + prev := &sl.head + for i := sl.level - 1; i >= 0; i-- { + for cur := prev.next[i]; cur != nil; cur = cur.next[i] { + if cur.key == key { + return cur + } + if cur.key > key { + // All other node in this level must be greater than the key, + // search the next level. + break + } + prev = cur + } + } + if eq { + return nil + } + return prev.next[0] +} + +func (sl *skipListOrdered[K, V]) lowerBound(key K) *skipListNode[K, V] { + return sl.doFindNode(key, false) +} + +func (sl *skipListOrdered[K, V]) upperBound(key K) *skipListNode[K, V] { + node := sl.lowerBound(key) + if node != nil && node.key == key { + return node.next[0] + } + return node +} + +// findInsertPoint returns (*node, nil) to the existed node if the key exists, +// or (nil, []*node) to the previous nodes if the key doesn't exist +func (sl *skipListOrdered[K, V]) findInsertPoint(key K) (*skipListNode[K, V], []*skipListNode[K, V]) { + prevs := sl.prevsCache[0:sl.level] + prev := &sl.head + for i := sl.level - 1; i >= 0; i-- { + for next := prev.next[i]; next != nil; next = next.next[i] { + if next.key == key { + // The key is already existed, prevs are useless because no new node insertion. + // stop searching. + return next, nil + } + if next.key > key { + // All other node in this level must be greater than the key, + // search the next level. + break + } + prev = next + } + prevs[i] = prev + } + return nil, prevs +} + +// findRemovePoint finds the node which match the key and it's previous nodes. +func (sl *skipListOrdered[K, V]) findRemovePoint(key K) (*skipListNode[K, V], []*skipListNode[K, V]) { + prevs := sl.findPrevNodes(key) + node := prevs[0].next[0] + if node == nil || node.key != key { + return nil, nil + } + return node, prevs +} + +func (sl *skipListOrdered[K, V]) findPrevNodes(key K) []*skipListNode[K, V] { + prevs := sl.prevsCache[0:sl.level] + prev := &sl.head + for i := sl.level - 1; i >= 0; i-- { + for next := prev.next[i]; next != nil; next = next.next[i] { + if next.key >= key { + break + } + prev = next + } + prevs[i] = prev + } + return prevs +} + +/// skipListFunc part + +// skipListFunc is the skip list implementation which compare keys with func. +type skipListFunc[K any, V any] struct { + SkipList[K, V] + keyCmp CompareFn[K] +} + +func (sl *skipListFunc[K, V]) findNode(key K) *skipListNode[K, V] { + node := sl.lowerBound(key) + if node != nil && sl.keyCmp(node.key, key) == 0 { + return node + } + return nil +} + +func (sl *skipListFunc[K, V]) lowerBound(key K) *skipListNode[K, V] { + var prev = &sl.head + for i := sl.level - 1; i >= 0; i-- { + cur := prev.next[i] + for ; cur != nil; cur = cur.next[i] { + cmpRet := sl.keyCmp(cur.key, key) + if cmpRet == 0 { + return cur + } + if cmpRet > 0 { + break + } + prev = cur + } + } + return prev.next[0] +} + +func (sl *skipListFunc[K, V]) upperBound(key K) *skipListNode[K, V] { + node := sl.lowerBound(key) + if node != nil && sl.keyCmp(node.key, key) == 0 { + return node.next[0] + } + return node +} + +// findInsertPoint returns (*node, nil) to the existed node if the key exists, +// or (nil, []*node) to the previous nodes if the key doesn't exist +func (sl *skipListFunc[K, V]) findInsertPoint(key K) (*skipListNode[K, V], []*skipListNode[K, V]) { + prevs := sl.prevsCache[0:sl.level] + prev := &sl.head + for i := sl.level - 1; i >= 0; i-- { + for cur := prev.next[i]; cur != nil; cur = cur.next[i] { + r := sl.keyCmp(cur.key, key) + if r == 0 { + // The key is already existed, prevs are useless because no new node insertion. + // stop searching. + return cur, nil + } + if r > 0 { + // All other node in this level must be greater than the key, + // search the next level. + break + } + prev = cur + } + prevs[i] = prev + } + return nil, prevs +} + +// findRemovePoint finds the node which match the key and it's previous nodes. +func (sl *skipListFunc[K, V]) findRemovePoint(key K) (*skipListNode[K, V], []*skipListNode[K, V]) { + prevs := sl.findPrevNodes(key) + node := prevs[0].next[0] + if node == nil || sl.keyCmp(node.key, key) != 0 { + return nil, nil + } + return node, prevs +} + +func (sl *skipListFunc[K, V]) findPrevNodes(key K) []*skipListNode[K, V] { + prevs := sl.prevsCache[0:sl.level] + prev := &sl.head + for i := sl.level - 1; i >= 0; i-- { + for next := prev.next[i]; next != nil; next = next.next[i] { + if sl.keyCmp(next.key, key) >= 0 { + break + } + prev = next + } + prevs[i] = prev + } + return prevs +} diff --git a/transport/anytls/skiplist/skiplist_newnode.go b/transport/anytls/skiplist/skiplist_newnode.go new file mode 100644 index 00000000..4e8a6d88 --- /dev/null +++ b/transport/anytls/skiplist/skiplist_newnode.go @@ -0,0 +1,297 @@ +// AUTO GENERATED CODE, DON'T EDIT!!! +// EDIT skiplist_newnode_generate.sh accordingly. + +package skiplist + +// newSkipListNode creates a new node initialized with specified key, value and next slice. +func newSkipListNode[K any, V any](level int, key K, value V) *skipListNode[K, V] { + // For nodes with each levels, point their next slice to the nexts array allocated together, + // which can reduce 1 memory allocation and improve performance. + // + // The generics of the golang doesn't support non-type parameters like in C++, + // so we have to generate it manually. + switch level { + case 1: + n := struct { + head skipListNode[K, V] + nexts [1]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 2: + n := struct { + head skipListNode[K, V] + nexts [2]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 3: + n := struct { + head skipListNode[K, V] + nexts [3]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 4: + n := struct { + head skipListNode[K, V] + nexts [4]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 5: + n := struct { + head skipListNode[K, V] + nexts [5]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 6: + n := struct { + head skipListNode[K, V] + nexts [6]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 7: + n := struct { + head skipListNode[K, V] + nexts [7]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 8: + n := struct { + head skipListNode[K, V] + nexts [8]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 9: + n := struct { + head skipListNode[K, V] + nexts [9]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 10: + n := struct { + head skipListNode[K, V] + nexts [10]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 11: + n := struct { + head skipListNode[K, V] + nexts [11]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 12: + n := struct { + head skipListNode[K, V] + nexts [12]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 13: + n := struct { + head skipListNode[K, V] + nexts [13]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 14: + n := struct { + head skipListNode[K, V] + nexts [14]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 15: + n := struct { + head skipListNode[K, V] + nexts [15]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 16: + n := struct { + head skipListNode[K, V] + nexts [16]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 17: + n := struct { + head skipListNode[K, V] + nexts [17]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 18: + n := struct { + head skipListNode[K, V] + nexts [18]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 19: + n := struct { + head skipListNode[K, V] + nexts [19]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 20: + n := struct { + head skipListNode[K, V] + nexts [20]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 21: + n := struct { + head skipListNode[K, V] + nexts [21]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 22: + n := struct { + head skipListNode[K, V] + nexts [22]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 23: + n := struct { + head skipListNode[K, V] + nexts [23]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 24: + n := struct { + head skipListNode[K, V] + nexts [24]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 25: + n := struct { + head skipListNode[K, V] + nexts [25]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 26: + n := struct { + head skipListNode[K, V] + nexts [26]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 27: + n := struct { + head skipListNode[K, V] + nexts [27]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 28: + n := struct { + head skipListNode[K, V] + nexts [28]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 29: + n := struct { + head skipListNode[K, V] + nexts [29]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 30: + n := struct { + head skipListNode[K, V] + nexts [30]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 31: + n := struct { + head skipListNode[K, V] + nexts [31]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 32: + n := struct { + head skipListNode[K, V] + nexts [32]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 33: + n := struct { + head skipListNode[K, V] + nexts [33]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 34: + n := struct { + head skipListNode[K, V] + nexts [34]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 35: + n := struct { + head skipListNode[K, V] + nexts [35]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 36: + n := struct { + head skipListNode[K, V] + nexts [36]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 37: + n := struct { + head skipListNode[K, V] + nexts [37]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 38: + n := struct { + head skipListNode[K, V] + nexts [38]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 39: + n := struct { + head skipListNode[K, V] + nexts [39]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + case 40: + n := struct { + head skipListNode[K, V] + nexts [40]*skipListNode[K, V] + }{head: skipListNode[K, V]{key, value, nil}} + n.head.next = n.nexts[:] + return &n.head + } + + panic("should not reach here") +} diff --git a/transport/anytls/skiplist/types.go b/transport/anytls/skiplist/types.go new file mode 100644 index 00000000..c534f460 --- /dev/null +++ b/transport/anytls/skiplist/types.go @@ -0,0 +1,75 @@ +package skiplist + +// Signed is a constraint that permits any signed integer type. +// If future releases of Go add new predeclared signed integer types, +// this constraint will be modified to include them. +type Signed interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 +} + +// Unsigned is a constraint that permits any unsigned integer type. +// If future releases of Go add new predeclared unsigned integer types, +// this constraint will be modified to include them. +type Unsigned interface { + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr +} + +// Integer is a constraint that permits any integer type. +// If future releases of Go add new predeclared integer types, +// this constraint will be modified to include them. +type Integer interface { + Signed | Unsigned +} + +// Float is a constraint that permits any floating-point type. +// If future releases of Go add new predeclared floating-point types, +// this constraint will be modified to include them. +type Float interface { + ~float32 | ~float64 +} + +// Ordered is a constraint that permits any ordered type: any type +// that supports the operators < <= >= >. +// If future releases of Go add new ordered types, +// this constraint will be modified to include them. +type Ordered interface { + Integer | Float | ~string +} + +// Numeric is a constraint that permits any numeric type. +type Numeric interface { + Integer | Float +} + +// LessFn is a function that returns whether 'a' is less than 'b'. +type LessFn[T any] func(a, b T) bool + +// CompareFn is a 3 way compare function that +// returns 1 if a > b, +// returns 0 if a == b, +// returns -1 if a < b. +type CompareFn[T any] func(a, b T) int + +// HashFn is a function that returns the hash of 't'. +type HashFn[T any] func(t T) uint64 + +// Equals wraps the '==' operator for comparable types. +func Equals[T comparable](a, b T) bool { + return a == b +} + +// Less wraps the '<' operator for ordered types. +func Less[T Ordered](a, b T) bool { + return a < b +} + +// OrderedCompare provide default CompareFn for ordered types. +func OrderedCompare[T Ordered](a, b T) int { + if a < b { + return -1 + } + if a > b { + return 1 + } + return 0 +} diff --git a/transport/anytls/util/routine.go b/transport/anytls/util/routine.go new file mode 100644 index 00000000..029dbdca --- /dev/null +++ b/transport/anytls/util/routine.go @@ -0,0 +1,28 @@ +package util + +import ( + "context" + "runtime/debug" + "time" + + "github.com/sagernet/sing-box/log" +) + +func StartRoutine(ctx context.Context, d time.Duration, f func()) { + go func() { + defer func() { + if r := recover(); r != nil { + log.Error("[BUG]", r, string(debug.Stack())) + } + }() + for { + time.Sleep(d) + f() + select { + case <-ctx.Done(): + return + default: + } + } + }() +} diff --git a/transport/anytls/util/string_map.go b/transport/anytls/util/string_map.go new file mode 100644 index 00000000..27fb3581 --- /dev/null +++ b/transport/anytls/util/string_map.go @@ -0,0 +1,27 @@ +package util + +import ( + "strings" +) + +type StringMap map[string]string + +func (s StringMap) ToBytes() []byte { + var lines []string + for k, v := range s { + lines = append(lines, k+"="+v) + } + return []byte(strings.Join(lines, "\n")) +} + +func StringMapFromBytes(b []byte) StringMap { + var m = make(StringMap) + var lines = strings.Split(string(b), "\n") + for _, line := range lines { + v := strings.SplitN(line, "=", 2) + if len(v) == 2 { + m[v[0]] = v[1] + } + } + return m +}