From d8e1cd0d51e21f4e29e05b573af4c81b6f7da886 Mon Sep 17 00:00:00 2001 From: Restia-Ashbell <107416976+Restia-Ashbell@users.noreply.github.com> Date: Thu, 12 Jun 2025 09:13:23 +0800 Subject: [PATCH] Add ECH support for uTLS --- common/tls/ech.go | 32 +++++----- common/tls/{ech_keygen.go => ech_shared.go} | 6 ++ common/tls/ech_stub.go | 2 +- common/tls/reality_client.go | 2 +- common/tls/std_client.go | 44 ++++++++------ common/tls/utls_client.go | 67 ++++++++++++--------- 6 files changed, 87 insertions(+), 66 deletions(-) rename common/tls/{ech_keygen.go => ech_shared.go} (97%) diff --git a/common/tls/ech.go b/common/tls/ech.go index 61d2a209..f4434604 100644 --- a/common/tls/ech.go +++ b/common/tls/ech.go @@ -25,7 +25,7 @@ import ( "golang.org/x/crypto/cryptobyte" ) -func parseECHClientConfig(ctx context.Context, stdConfig *STDClientConfig, options option.OutboundTLSOptions) (Config, error) { +func parseECHClientConfig(ctx context.Context, clientConfig ECHCapableConfig, options option.OutboundTLSOptions) (Config, error) { var echConfig []byte if len(options.ECH.Config) > 0 { echConfig = []byte(strings.Join(options.ECH.Config, "\n")) @@ -45,12 +45,12 @@ func parseECHClientConfig(ctx context.Context, stdConfig *STDClientConfig, optio if block == nil || block.Type != "ECH CONFIGS" || len(rest) > 0 { return nil, E.New("invalid ECH configs pem") } - stdConfig.config.EncryptedClientHelloConfigList = block.Bytes - return stdConfig, nil + clientConfig.SetECHConfigList(block.Bytes) + return clientConfig, nil } else { - return &STDECHClientConfig{ - STDClientConfig: stdConfig, - dnsRouter: service.FromContext[adapter.DNSRouter](ctx), + return &ECHClientConfig{ + ECHCapableConfig: clientConfig, + dnsRouter: service.FromContext[adapter.DNSRouter](ctx), }, nil } } @@ -102,15 +102,15 @@ func reloadECHKeys(echKeyPath string, tlsConfig *tls.Config) error { return nil } -type STDECHClientConfig struct { - *STDClientConfig +type ECHClientConfig struct { + ECHCapableConfig access sync.Mutex dnsRouter adapter.DNSRouter lastTTL time.Duration lastUpdate time.Time } -func (s *STDECHClientConfig) ClientHandshake(ctx context.Context, conn net.Conn) (aTLS.Conn, error) { +func (s *ECHClientConfig) ClientHandshake(ctx context.Context, conn net.Conn) (aTLS.Conn, error) { tlsConn, err := s.fetchAndHandshake(ctx, conn) if err != nil { return nil, err @@ -122,17 +122,17 @@ func (s *STDECHClientConfig) ClientHandshake(ctx context.Context, conn net.Conn) return tlsConn, nil } -func (s *STDECHClientConfig) fetchAndHandshake(ctx context.Context, conn net.Conn) (aTLS.Conn, error) { +func (s *ECHClientConfig) fetchAndHandshake(ctx context.Context, conn net.Conn) (aTLS.Conn, error) { s.access.Lock() defer s.access.Unlock() - if len(s.config.EncryptedClientHelloConfigList) == 0 || s.lastTTL == 0 || time.Now().Sub(s.lastUpdate) > s.lastTTL { + if len(s.ECHConfigList()) == 0 || s.lastTTL == 0 || time.Now().Sub(s.lastUpdate) > s.lastTTL { message := &mDNS.Msg{ MsgHdr: mDNS.MsgHdr{ RecursionDesired: true, }, Question: []mDNS.Question{ { - Name: mDNS.Fqdn(s.config.ServerName), + Name: mDNS.Fqdn(s.ServerName()), Qtype: mDNS.TypeHTTPS, Qclass: mDNS.ClassINET, }, @@ -157,21 +157,21 @@ func (s *STDECHClientConfig) fetchAndHandshake(ctx context.Context, conn net.Con } s.lastTTL = time.Duration(rr.Header().Ttl) * time.Second s.lastUpdate = time.Now() - s.config.EncryptedClientHelloConfigList = echConfigList + s.SetECHConfigList(echConfigList) break match } } } } - if len(s.config.EncryptedClientHelloConfigList) == 0 { + if len(s.ECHConfigList()) == 0 { return nil, E.New("no ECH config found in DNS records") } } return s.Client(conn) } -func (s *STDECHClientConfig) Clone() Config { - return &STDECHClientConfig{STDClientConfig: s.STDClientConfig.Clone().(*STDClientConfig), dnsRouter: s.dnsRouter, lastUpdate: s.lastUpdate} +func (s *ECHClientConfig) Clone() Config { + return &ECHClientConfig{ECHCapableConfig: s.ECHCapableConfig.Clone().(ECHCapableConfig), dnsRouter: s.dnsRouter, lastUpdate: s.lastUpdate} } func UnmarshalECHKeys(raw []byte) ([]tls.EncryptedClientHelloKey, error) { diff --git a/common/tls/ech_keygen.go b/common/tls/ech_shared.go similarity index 97% rename from common/tls/ech_keygen.go rename to common/tls/ech_shared.go index b6c353ac..f27f0727 100644 --- a/common/tls/ech_keygen.go +++ b/common/tls/ech_shared.go @@ -11,6 +11,12 @@ import ( "github.com/cloudflare/circl/kem" ) +type ECHCapableConfig interface { + Config + ECHConfigList() []byte + SetECHConfigList([]byte) +} + func ECHKeygenDefault(serverName string) (configPem string, keyPem string, err error) { cipherSuites := []echCipherSuite{ { diff --git a/common/tls/ech_stub.go b/common/tls/ech_stub.go index 8c93690c..3b6ffc23 100644 --- a/common/tls/ech_stub.go +++ b/common/tls/ech_stub.go @@ -10,7 +10,7 @@ import ( E "github.com/sagernet/sing/common/exceptions" ) -func parseECHClientConfig(ctx context.Context, options option.OutboundTLSOptions, tlsConfig *tls.Config) (Config, error) { +func parseECHClientConfig(ctx context.Context, clientConfig ECHCapableConfig, options option.OutboundTLSOptions) (Config, error) { return nil, E.New("ECH requires go1.24, please recompile your binary.") } diff --git a/common/tls/reality_client.go b/common/tls/reality_client.go index 823e8285..ca0e1e04 100644 --- a/common/tls/reality_client.go +++ b/common/tls/reality_client.go @@ -74,7 +74,7 @@ func NewRealityClient(ctx context.Context, serverAddress string, options option. if decodedLen > 8 { return nil, E.New("invalid short_id") } - return &RealityClientConfig{ctx, uClient, publicKey, shortID}, nil + return &RealityClientConfig{ctx, uClient.(*UTLSClientConfig), publicKey, shortID}, nil } func (e *RealityClientConfig) ServerName() string { diff --git a/common/tls/std_client.go b/common/tls/std_client.go index 49b239a9..0705d949 100644 --- a/common/tls/std_client.go +++ b/common/tls/std_client.go @@ -24,35 +24,43 @@ type STDClientConfig struct { recordFragment bool } -func (s *STDClientConfig) ServerName() string { - return s.config.ServerName +func (c *STDClientConfig) ServerName() string { + return c.config.ServerName } -func (s *STDClientConfig) SetServerName(serverName string) { - s.config.ServerName = serverName +func (c *STDClientConfig) SetServerName(serverName string) { + c.config.ServerName = serverName } -func (s *STDClientConfig) NextProtos() []string { - return s.config.NextProtos +func (c *STDClientConfig) NextProtos() []string { + return c.config.NextProtos } -func (s *STDClientConfig) SetNextProtos(nextProto []string) { - s.config.NextProtos = nextProto +func (c *STDClientConfig) SetNextProtos(nextProto []string) { + c.config.NextProtos = nextProto } -func (s *STDClientConfig) Config() (*STDConfig, error) { - return s.config, nil +func (c *STDClientConfig) Config() (*STDConfig, error) { + return c.config, nil } -func (s *STDClientConfig) Client(conn net.Conn) (Conn, error) { - if s.recordFragment { - conn = tf.NewConn(conn, s.ctx, s.fragment, s.recordFragment, s.fragmentFallbackDelay) +func (c *STDClientConfig) Client(conn net.Conn) (Conn, error) { + if c.recordFragment { + conn = tf.NewConn(conn, c.ctx, c.fragment, c.recordFragment, c.fragmentFallbackDelay) } - return tls.Client(conn, s.config), nil + return tls.Client(conn, c.config), nil } -func (s *STDClientConfig) Clone() Config { - return &STDClientConfig{s.ctx, s.config.Clone(), s.fragment, s.fragmentFallbackDelay, s.recordFragment} +func (c *STDClientConfig) Clone() Config { + return &STDClientConfig{c.ctx, c.config.Clone(), c.fragment, c.fragmentFallbackDelay, c.recordFragment} +} + +func (c *STDClientConfig) ECHConfigList() []byte { + return c.config.EncryptedClientHelloConfigList +} + +func (c *STDClientConfig) SetECHConfigList(EncryptedClientHelloConfigList []byte) { + c.config.EncryptedClientHelloConfigList = EncryptedClientHelloConfigList } func NewSTDClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (Config, error) { @@ -69,9 +77,7 @@ func NewSTDClient(ctx context.Context, serverAddress string, options option.Outb var tlsConfig tls.Config tlsConfig.Time = ntp.TimeFuncFromContext(ctx) tlsConfig.RootCAs = adapter.RootPoolFromContext(ctx) - if options.DisableSNI { - tlsConfig.ServerName = "127.0.0.1" - } else { + if !options.DisableSNI { tlsConfig.ServerName = serverName } if options.Insecure { diff --git a/common/tls/utls_client.go b/common/tls/utls_client.go index f9a61cd4..6ed81eb4 100644 --- a/common/tls/utls_client.go +++ b/common/tls/utls_client.go @@ -8,7 +8,6 @@ import ( "crypto/x509" "math/rand" "net" - "net/netip" "os" "strings" "time" @@ -32,46 +31,54 @@ type UTLSClientConfig struct { recordFragment bool } -func (e *UTLSClientConfig) ServerName() string { - return e.config.ServerName +func (c *UTLSClientConfig) ServerName() string { + return c.config.ServerName } -func (e *UTLSClientConfig) SetServerName(serverName string) { - e.config.ServerName = serverName +func (c *UTLSClientConfig) SetServerName(serverName string) { + c.config.ServerName = serverName } -func (e *UTLSClientConfig) NextProtos() []string { - return e.config.NextProtos +func (c *UTLSClientConfig) NextProtos() []string { + return c.config.NextProtos } -func (e *UTLSClientConfig) SetNextProtos(nextProto []string) { +func (c *UTLSClientConfig) SetNextProtos(nextProto []string) { if len(nextProto) == 1 && nextProto[0] == http2.NextProtoTLS { nextProto = append(nextProto, "http/1.1") } - e.config.NextProtos = nextProto + c.config.NextProtos = nextProto } -func (e *UTLSClientConfig) Config() (*STDConfig, error) { +func (c *UTLSClientConfig) Config() (*STDConfig, error) { return nil, E.New("unsupported usage for uTLS") } -func (e *UTLSClientConfig) Client(conn net.Conn) (Conn, error) { - if e.recordFragment { - conn = tf.NewConn(conn, e.ctx, e.fragment, e.recordFragment, e.fragmentFallbackDelay) +func (c *UTLSClientConfig) Client(conn net.Conn) (Conn, error) { + if c.recordFragment { + conn = tf.NewConn(conn, c.ctx, c.fragment, c.recordFragment, c.fragmentFallbackDelay) } - return &utlsALPNWrapper{utlsConnWrapper{utls.UClient(conn, e.config.Clone(), e.id)}, e.config.NextProtos}, nil + return &utlsALPNWrapper{utlsConnWrapper{utls.UClient(conn, c.config.Clone(), c.id)}, c.config.NextProtos}, nil } -func (e *UTLSClientConfig) SetSessionIDGenerator(generator func(clientHello []byte, sessionID []byte) error) { - e.config.SessionIDGenerator = generator +func (c *UTLSClientConfig) SetSessionIDGenerator(generator func(clientHello []byte, sessionID []byte) error) { + c.config.SessionIDGenerator = generator } -func (e *UTLSClientConfig) Clone() Config { +func (c *UTLSClientConfig) Clone() Config { return &UTLSClientConfig{ - e.ctx, e.config.Clone(), e.id, e.fragment, e.fragmentFallbackDelay, e.recordFragment, + c.ctx, c.config.Clone(), c.id, c.fragment, c.fragmentFallbackDelay, c.recordFragment, } } +func (c *UTLSClientConfig) ECHConfigList() []byte { + return c.config.EncryptedClientHelloConfigList +} + +func (c *UTLSClientConfig) SetECHConfigList(EncryptedClientHelloConfigList []byte) { + c.config.EncryptedClientHelloConfigList = EncryptedClientHelloConfigList +} + type utlsConnWrapper struct { *utls.UConn } @@ -124,14 +131,12 @@ func (c *utlsALPNWrapper) HandshakeContext(ctx context.Context) error { return c.UConn.HandshakeContext(ctx) } -func NewUTLSClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (*UTLSClientConfig, error) { +func NewUTLSClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (Config, error) { var serverName string if options.ServerName != "" { serverName = options.ServerName } else if serverAddress != "" { - if _, err := netip.ParseAddr(serverName); err != nil { - serverName = serverAddress - } + serverName = serverAddress } if serverName == "" && !options.Insecure { return nil, E.New("missing server_name or insecure=true") @@ -140,11 +145,7 @@ func NewUTLSClient(ctx context.Context, serverAddress string, options option.Out var tlsConfig utls.Config tlsConfig.Time = ntp.TimeFuncFromContext(ctx) tlsConfig.RootCAs = adapter.RootPoolFromContext(ctx) - if options.DisableSNI { - tlsConfig.ServerName = "127.0.0.1" - } else { - tlsConfig.ServerName = serverName - } + tlsConfig.ServerName = serverName if options.Insecure { tlsConfig.InsecureSkipVerify = options.Insecure } else if options.DisableSNI { @@ -200,7 +201,15 @@ func NewUTLSClient(ctx context.Context, serverAddress string, options option.Out if err != nil { return nil, err } - return &UTLSClientConfig{ctx, &tlsConfig, id, options.Fragment, time.Duration(options.FragmentFallbackDelay), options.RecordFragment}, nil + uConfig := &UTLSClientConfig{ctx, &tlsConfig, id, options.Fragment, time.Duration(options.FragmentFallbackDelay), options.RecordFragment} + if options.ECH != nil && options.ECH.Enabled { + if options.Reality != nil && options.Reality.Enabled { + return nil, E.New("Reality is conflict with ECH") + } + return parseECHClientConfig(ctx, uConfig, options) + } else { + return uConfig, nil + } } var ( @@ -228,7 +237,7 @@ func init() { func uTLSClientHelloID(name string) (utls.ClientHelloID, error) { switch name { - case "chrome_psk", "chrome_psk_shuffle", "chrome_padding_psk_shuffle", "chrome_pq": + case "chrome_psk", "chrome_psk_shuffle", "chrome_padding_psk_shuffle", "chrome_pq", "chrome_pq_psk": fallthrough case "chrome", "": return utls.HelloChrome_Auto, nil