optimize code

This commit is contained in:
arm64v8a 2023-07-23 18:47:37 +09:00
parent 043e473000
commit e63a68ab04

View File

@ -143,9 +143,12 @@ func (c *STDServerConfig) reloadKeyPair() error {
if err != nil { if err != nil {
return E.Cause(err, "reload key pair") return E.Cause(err, "reload key pair")
} }
setGetCertificateFunc(c.config, func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { c.config.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return &keyPair, nil return &keyPair, nil
}, c.rejectUnknownSNI) }
if c.rejectUnknownSNI {
setRejectUnknownSNI(c.config)
}
c.logger.Info("reloaded TLS certificate") c.logger.Info("reloaded TLS certificate")
return nil return nil
} }
@ -234,9 +237,12 @@ func NewSTDServer(ctx context.Context, router adapter.Router, logger log.Logger,
key = content key = content
} }
if certificate == nil && key == nil && options.Insecure { if certificate == nil && key == nil && options.Insecure {
setGetCertificateFunc(tlsConfig, func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return GenerateKeyPair(router.TimeFunc(), info.ServerName) return GenerateKeyPair(router.TimeFunc(), info.ServerName)
}, options.RejectUnknownSNI) }
if options.RejectUnknownSNI {
return nil, E.New("insecure conflict with reject_unknown_sni")
}
} else { } else {
if certificate == nil { if certificate == nil {
return nil, E.New("missing certificate") return nil, E.New("missing certificate")
@ -248,9 +254,12 @@ func NewSTDServer(ctx context.Context, router adapter.Router, logger log.Logger,
if err != nil { if err != nil {
return nil, E.Cause(err, "parse x509 key pair") return nil, E.Cause(err, "parse x509 key pair")
} }
setGetCertificateFunc(tlsConfig, func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return &keyPair, nil return &keyPair, nil
}, options.RejectUnknownSNI) }
if options.RejectUnknownSNI {
setRejectUnknownSNI(tlsConfig)
}
} }
} }
return &STDServerConfig{ return &STDServerConfig{
@ -265,26 +274,25 @@ func NewSTDServer(ctx context.Context, router adapter.Router, logger log.Logger,
}, nil }, nil
} }
func setGetCertificateFunc(tlsConfig *tls.Config, getCertificate func(*tls.ClientHelloInfo) (*tls.Certificate, error), rejectUnknownSNI bool) { func setRejectUnknownSNI(tlsConfig *tls.Config) {
getCertificate := tlsConfig.GetCertificate
tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := getCertificate(info) cert, err := getCertificate(info)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if rejectUnknownSNI { if info.ServerName != "" && info.ServerName == tlsConfig.ServerName {
if info.ServerName != "" && info.ServerName == tlsConfig.ServerName { return cert, nil
return cert, nil }
} if cert.Leaf == nil {
if cert.Leaf == nil { cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0]) if err != nil {
if err != nil { return nil, err
return nil, err
}
}
if err = cert.Leaf.VerifyHostname(info.ServerName); err != nil {
return nil, E.Cause(err, "cert is not valid for SNI")
} }
} }
if err = cert.Leaf.VerifyHostname(info.ServerName); err != nil {
return nil, E.Cause(err, "cert is not valid for SNI")
}
return cert, nil return cert, nil
} }
} }