diff --git a/common/tls/std_server.go b/common/tls/std_server.go index 2c875855..ec7707e6 100644 --- a/common/tls/std_server.go +++ b/common/tls/std_server.go @@ -3,6 +3,7 @@ package tls import ( "context" "crypto/tls" + "crypto/x509" "net" "os" @@ -25,6 +26,7 @@ type STDServerConfig struct { key []byte certificatePath string keyPath string + rejectHandshake bool watcher *fsnotify.Watcher } @@ -141,7 +143,9 @@ func (c *STDServerConfig) reloadKeyPair() error { if err != nil { return E.Cause(err, "reload key pair") } - c.config.Certificates = []tls.Certificate{keyPair} + setGetCertificateFunc(c.config, func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + return &keyPair, nil + }, c.rejectHandshake) c.logger.Info("reloaded TLS certificate") return nil } @@ -230,9 +234,9 @@ func NewSTDServer(ctx context.Context, router adapter.Router, logger log.Logger, key = content } if certificate == nil && key == nil && options.Insecure { - tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + setGetCertificateFunc(tlsConfig, func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { return GenerateKeyPair(router.TimeFunc(), info.ServerName) - } + }, options.RejectHandshake) } else { if certificate == nil { return nil, E.New("missing certificate") @@ -244,7 +248,9 @@ func NewSTDServer(ctx context.Context, router adapter.Router, logger log.Logger, if err != nil { return nil, E.Cause(err, "parse x509 key pair") } - tlsConfig.Certificates = []tls.Certificate{keyPair} + setGetCertificateFunc(tlsConfig, func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + return &keyPair, nil + }, options.RejectHandshake) } } return &STDServerConfig{ @@ -255,5 +261,30 @@ func NewSTDServer(ctx context.Context, router adapter.Router, logger log.Logger, key: key, certificatePath: options.CertificatePath, keyPath: options.KeyPath, + rejectHandshake: options.RejectHandshake, }, nil } + +func setGetCertificateFunc(tlsConfig *tls.Config, getCertificate func(*tls.ClientHelloInfo) (*tls.Certificate, error), rejectHandshake bool) { + tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + cert, err := getCertificate(info) + if err != nil { + return nil, err + } + if rejectHandshake { + if info.ServerName != "" && info.ServerName == tlsConfig.ServerName { + return cert, nil + } + if cert.Leaf == nil { + cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return nil, err + } + } + if err = cert.Leaf.VerifyHostname(info.ServerName); err != nil { + return nil, E.Cause(err, "cert is not valid for SNI") + } + } + return cert, nil + } +} diff --git a/option/tls.go b/option/tls.go index 2ff5f2e4..07d8f1ce 100644 --- a/option/tls.go +++ b/option/tls.go @@ -4,6 +4,7 @@ type InboundTLSOptions struct { Enabled bool `json:"enabled,omitempty"` ServerName string `json:"server_name,omitempty"` Insecure bool `json:"insecure,omitempty"` + RejectHandshake bool `json:"reject_handshake,omitempty"` ALPN Listable[string] `json:"alpn,omitempty"` MinVersion string `json:"min_version,omitempty"` MaxVersion string `json:"max_version,omitempty"`