From e625012219a8b03f33205c028b4a76f3e2238af9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 22 Apr 2025 15:08:30 +0800 Subject: [PATCH] Fix fetch ECH configs --- common/tls/ech.go | 42 +++++++++++++++++++++++++++++------------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/common/tls/ech.go b/common/tls/ech.go index ddb9b5dd..de911126 100644 --- a/common/tls/ech.go +++ b/common/tls/ech.go @@ -10,6 +10,8 @@ import ( "net" "os" "strings" + "sync" + "time" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/dns" @@ -46,7 +48,10 @@ func parseECHClientConfig(ctx context.Context, options option.OutboundTLSOptions tlsConfig.EncryptedClientHelloConfigList = block.Bytes return &STDClientConfig{tlsConfig}, nil } else { - return &STDECHClientConfig{STDClientConfig{tlsConfig}, service.FromContext[adapter.DNSRouter](ctx)}, nil + return &STDECHClientConfig{ + STDClientConfig: STDClientConfig{tlsConfig}, + dnsRouter: service.FromContext[adapter.DNSRouter](ctx), + }, nil } } @@ -99,11 +104,28 @@ func reloadECHKeys(echKeyPath string, tlsConfig *tls.Config) error { type STDECHClientConfig struct { STDClientConfig - dnsRouter adapter.DNSRouter + access sync.Mutex + dnsRouter adapter.DNSRouter + lastTTL time.Duration + lastUpdate time.Time } func (s *STDECHClientConfig) ClientHandshake(ctx context.Context, conn net.Conn) (aTLS.Conn, error) { - if len(s.config.EncryptedClientHelloConfigList) == 0 { + tlsConn, err := s.fetchAndHandshake(ctx, conn) + if err != nil { + return nil, err + } + err = tlsConn.HandshakeContext(ctx) + if err != nil { + return nil, err + } + return tlsConn, nil +} + +func (s *STDECHClientConfig) fetchAndHandshake(ctx context.Context, conn net.Conn) (aTLS.Conn, error) { + s.access.Lock() + defer s.access.Unlock() + if len(s.config.EncryptedClientHelloConfigList) == 0 || s.lastTTL == 0 || time.Now().Sub(s.lastUpdate) > s.lastTTL { message := &mDNS.Msg{ MsgHdr: mDNS.MsgHdr{ RecursionDesired: true, @@ -133,6 +155,8 @@ func (s *STDECHClientConfig) ClientHandshake(ctx context.Context, conn net.Conn) if err != nil { return nil, E.Cause(err, "decode ECH config") } + s.lastTTL = time.Duration(rr.Header().Ttl) * time.Second + s.lastUpdate = time.Now() s.config.EncryptedClientHelloConfigList = echConfigList break match } @@ -143,19 +167,11 @@ func (s *STDECHClientConfig) ClientHandshake(ctx context.Context, conn net.Conn) return nil, E.New("no ECH config found in DNS records") } } - tlsConn, err := s.Client(conn) - if err != nil { - return nil, err - } - err = tlsConn.HandshakeContext(ctx) - if err != nil { - return nil, err - } - return tlsConn, nil + return s.Client(conn) } func (s *STDECHClientConfig) Clone() Config { - return &STDECHClientConfig{STDClientConfig{s.config.Clone()}, s.dnsRouter} + return &STDECHClientConfig{STDClientConfig: STDClientConfig{s.config.Clone()}, dnsRouter: s.dnsRouter, lastUpdate: s.lastUpdate} } func UnmarshalECHKeys(raw []byte) ([]tls.EncryptedClientHelloKey, error) {