Fix IndexTLSServerName

This commit is contained in:
世界 2025-07-22 22:52:51 +08:00
parent 1e068c78e6
commit 685d7afd9b
No known key found for this signature in database
GPG Key ID: CD109927C34A63C4
2 changed files with 21 additions and 18 deletions

View File

@ -36,47 +36,48 @@ func IndexTLSServerName(payload []byte) *MyServerName {
if len(payload) < recordLayerHeaderLen+int(segmentLen) { if len(payload) < recordLayerHeaderLen+int(segmentLen) {
return nil return nil
} }
serverName := indexTLSServerNameFromHandshake(payload[recordLayerHeaderLen : recordLayerHeaderLen+int(segmentLen)]) serverName := indexTLSServerNameFromHandshake(payload[recordLayerHeaderLen:])
if serverName == nil { if serverName == nil {
return nil return nil
} }
serverName.Length += recordLayerHeaderLen serverName.Index += recordLayerHeaderLen
return serverName return serverName
} }
func indexTLSServerNameFromHandshake(hs []byte) *MyServerName { func indexTLSServerNameFromHandshake(handshake []byte) *MyServerName {
if len(hs) < handshakeHeaderLen+randomDataLen+sessionIDHeaderLen { if len(handshake) < handshakeHeaderLen+randomDataLen+sessionIDHeaderLen {
return nil return nil
} }
if hs[0] != handshakeType { if handshake[0] != handshakeType {
return nil return nil
} }
handshakeLen := uint32(hs[1])<<16 | uint32(hs[2])<<8 | uint32(hs[3]) handshakeLen := uint32(handshake[1])<<16 | uint32(handshake[2])<<8 | uint32(handshake[3])
if len(hs[4:]) != int(handshakeLen) { if len(handshake[4:]) != int(handshakeLen) {
return nil return nil
} }
tlsVersion := uint16(hs[4])<<8 | uint16(hs[5]) tlsVersion := uint16(handshake[4])<<8 | uint16(handshake[5])
if tlsVersion&tlsVersionBitmask != 0x0300 && tlsVersion != tls13 { if tlsVersion&tlsVersionBitmask != 0x0300 && tlsVersion != tls13 {
return nil return nil
} }
sessionIDLen := hs[38] sessionIDLen := handshake[38]
if len(hs) < handshakeHeaderLen+randomDataLen+sessionIDHeaderLen+int(sessionIDLen) { currentIndex := handshakeHeaderLen + randomDataLen + sessionIDHeaderLen + int(sessionIDLen)
if len(handshake) < currentIndex {
return nil return nil
} }
cs := hs[handshakeHeaderLen+randomDataLen+sessionIDHeaderLen+int(sessionIDLen):] cipherSuites := handshake[currentIndex:]
if len(cs) < cipherSuiteHeaderLen { if len(cipherSuites) < cipherSuiteHeaderLen {
return nil return nil
} }
csLen := uint16(cs[0])<<8 | uint16(cs[1]) csLen := uint16(cipherSuites[0])<<8 | uint16(cipherSuites[1])
if len(cs) < cipherSuiteHeaderLen+int(csLen)+compressMethodHeaderLen { if len(cipherSuites) < cipherSuiteHeaderLen+int(csLen)+compressMethodHeaderLen {
return nil return nil
} }
compressMethodLen := uint16(cs[cipherSuiteHeaderLen+int(csLen)]) compressMethodLen := uint16(cipherSuites[cipherSuiteHeaderLen+int(csLen)])
if len(cs) < cipherSuiteHeaderLen+int(csLen)+compressMethodHeaderLen+int(compressMethodLen) { currentIndex += cipherSuiteHeaderLen + int(csLen) + compressMethodHeaderLen + int(compressMethodLen)
if len(handshake) < currentIndex {
return nil return nil
} }
currentIndex := cipherSuiteHeaderLen + int(csLen) + compressMethodHeaderLen + int(compressMethodLen) serverName := indexTLSServerNameFromExtensions(handshake[currentIndex:])
serverName := indexTLSServerNameFromExtensions(cs[currentIndex:])
if serverName == nil { if serverName == nil {
return nil return nil
} }
@ -118,6 +119,7 @@ func indexTLSServerNameFromExtensions(exs []byte) *MyServerName {
} }
sniLen := uint16(sex[3])<<8 | uint16(sex[4]) sniLen := uint16(sex[3])<<8 | uint16(sex[4])
sex = sex[sniExtensionHeaderLen:] sex = sex[sniExtensionHeaderLen:]
return &MyServerName{ return &MyServerName{
Index: currentIndex + extensionHeaderLen + sniExtensionHeaderLen, Index: currentIndex + extensionHeaderLen + sniExtensionHeaderLen,
Length: int(sniLen), Length: int(sniLen),

View File

@ -15,5 +15,6 @@ func TestIndexTLSServerName(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
serverName := tf.IndexTLSServerName(payload) serverName := tf.IndexTLSServerName(payload)
require.NotNil(t, serverName) require.NotNil(t, serverName)
require.Equal(t, serverName.ServerName, string(payload[serverName.Index:serverName.Index+serverName.Length]))
require.Equal(t, "github.com", serverName.ServerName) require.Equal(t, "github.com", serverName.ServerName)
} }