diff --git a/common/tls/ech.go b/common/tls/ech.go index f375da23..880ca27c 100644 --- a/common/tls/ech.go +++ b/common/tls/ech.go @@ -46,7 +46,7 @@ func parseECHClientConfig(ctx context.Context, options option.OutboundTLSOptions tlsConfig.EncryptedClientHelloConfigList = block.Bytes return &STDClientConfig{tlsConfig}, nil } else { - return &STDECHClientConfig{STDClientConfig{tlsConfig}}, nil + return &STDECHClientConfig{STDClientConfig{tlsConfig}, service.FromContext[adapter.DNSRouter](ctx)}, nil } } @@ -99,9 +99,10 @@ func reloadECHKeys(echKeyPath string, tlsConfig *tls.Config) error { type STDECHClientConfig struct { STDClientConfig + dnsRouter adapter.DNSRouter } -func (s *STDClientConfig) ClientHandshake(ctx context.Context, conn net.Conn) (aTLS.Conn, error) { +func (s *STDECHClientConfig) ClientHandshake(ctx context.Context, conn net.Conn) (aTLS.Conn, error) { if len(s.config.EncryptedClientHelloConfigList) == 0 { message := &mDNS.Msg{ MsgHdr: mDNS.MsgHdr{ @@ -115,8 +116,7 @@ func (s *STDClientConfig) ClientHandshake(ctx context.Context, conn net.Conn) (a }, }, } - dnsRouter := service.FromContext[adapter.DNSRouter](ctx) - response, err := dnsRouter.Exchange(ctx, message, adapter.DNSQueryOptions{}) + response, err := s.dnsRouter.Exchange(ctx, message, adapter.DNSQueryOptions{}) if err != nil { return nil, E.Cause(err, "fetch ECH config list") } @@ -151,7 +151,7 @@ func (s *STDClientConfig) ClientHandshake(ctx context.Context, conn net.Conn) (a } func (s *STDECHClientConfig) Clone() Config { - return &STDECHClientConfig{STDClientConfig{s.config.Clone()}} + return &STDECHClientConfig{STDClientConfig{s.config.Clone()}, s.dnsRouter} } func UnmarshalECHKeys(raw []byte) ([]tls.EncryptedClientHelloKey, error) { diff --git a/common/tls/mkcert.go b/common/tls/mkcert.go index 12680c48..4e0ed102 100644 --- a/common/tls/mkcert.go +++ b/common/tls/mkcert.go @@ -12,6 +12,9 @@ import ( ) func GenerateKeyPair(parent *x509.Certificate, parentKey any, timeFunc func() time.Time, serverName string) (*tls.Certificate, error) { + if timeFunc == nil { + timeFunc = time.Now + } privateKeyPem, publicKeyPem, err := GenerateCertificate(parent, parentKey, timeFunc, serverName, timeFunc().Add(time.Hour)) if err != nil { return nil, err @@ -24,9 +27,6 @@ func GenerateKeyPair(parent *x509.Certificate, parentKey any, timeFunc func() ti } func GenerateCertificate(parent *x509.Certificate, parentKey any, timeFunc func() time.Time, serverName string, expire time.Time) (privateKeyPem []byte, publicKeyPem []byte, err error) { - if timeFunc == nil { - timeFunc = time.Now - } key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { return