tls: add reject handshake

This commit is contained in:
arm64v8a 2023-07-20 14:59:12 +09:00
parent e075bb5c8d
commit bb451ab74d
2 changed files with 36 additions and 4 deletions

View File

@ -3,6 +3,7 @@ package tls
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509"
"net" "net"
"os" "os"
@ -25,6 +26,7 @@ type STDServerConfig struct {
key []byte key []byte
certificatePath string certificatePath string
keyPath string keyPath string
rejectHandshake bool
watcher *fsnotify.Watcher watcher *fsnotify.Watcher
} }
@ -141,7 +143,9 @@ func (c *STDServerConfig) reloadKeyPair() error {
if err != nil { if err != nil {
return E.Cause(err, "reload key pair") return E.Cause(err, "reload key pair")
} }
c.config.Certificates = []tls.Certificate{keyPair} setGetCertificateFunc(c.config, func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return &keyPair, nil
}, c.rejectHandshake)
c.logger.Info("reloaded TLS certificate") c.logger.Info("reloaded TLS certificate")
return nil return nil
} }
@ -230,9 +234,9 @@ func NewSTDServer(ctx context.Context, router adapter.Router, logger log.Logger,
key = content key = content
} }
if certificate == nil && key == nil && options.Insecure { if certificate == nil && key == nil && options.Insecure {
tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { setGetCertificateFunc(tlsConfig, func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return GenerateKeyPair(router.TimeFunc(), info.ServerName) return GenerateKeyPair(router.TimeFunc(), info.ServerName)
} }, options.RejectHandshake)
} else { } else {
if certificate == nil { if certificate == nil {
return nil, E.New("missing certificate") return nil, E.New("missing certificate")
@ -244,7 +248,9 @@ func NewSTDServer(ctx context.Context, router adapter.Router, logger log.Logger,
if err != nil { if err != nil {
return nil, E.Cause(err, "parse x509 key pair") return nil, E.Cause(err, "parse x509 key pair")
} }
tlsConfig.Certificates = []tls.Certificate{keyPair} setGetCertificateFunc(tlsConfig, func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return &keyPair, nil
}, options.RejectHandshake)
} }
} }
return &STDServerConfig{ return &STDServerConfig{
@ -255,5 +261,30 @@ func NewSTDServer(ctx context.Context, router adapter.Router, logger log.Logger,
key: key, key: key,
certificatePath: options.CertificatePath, certificatePath: options.CertificatePath,
keyPath: options.KeyPath, keyPath: options.KeyPath,
rejectHandshake: options.RejectHandshake,
}, nil }, nil
} }
func setGetCertificateFunc(tlsConfig *tls.Config, getCertificate func(*tls.ClientHelloInfo) (*tls.Certificate, error), rejectHandshake bool) {
tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := getCertificate(info)
if err != nil {
return nil, err
}
if rejectHandshake {
if info.ServerName != "" && info.ServerName == tlsConfig.ServerName {
return cert, nil
}
if cert.Leaf == nil {
cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
if err != nil {
return nil, err
}
}
if err = cert.Leaf.VerifyHostname(info.ServerName); err != nil {
return nil, E.Cause(err, "cert is not valid for SNI")
}
}
return cert, nil
}
}

View File

@ -4,6 +4,7 @@ type InboundTLSOptions struct {
Enabled bool `json:"enabled,omitempty"` Enabled bool `json:"enabled,omitempty"`
ServerName string `json:"server_name,omitempty"` ServerName string `json:"server_name,omitempty"`
Insecure bool `json:"insecure,omitempty"` Insecure bool `json:"insecure,omitempty"`
RejectHandshake bool `json:"reject_handshake,omitempty"`
ALPN Listable[string] `json:"alpn,omitempty"` ALPN Listable[string] `json:"alpn,omitempty"`
MinVersion string `json:"min_version,omitempty"` MinVersion string `json:"min_version,omitempty"`
MaxVersion string `json:"max_version,omitempty"` MaxVersion string `json:"max_version,omitempty"`