diff --git a/common/tlsfragment/index.go b/common/tlsfragment/index.go index 9c26eb3d..0d58c445 100644 --- a/common/tlsfragment/index.go +++ b/common/tlsfragment/index.go @@ -36,47 +36,48 @@ func IndexTLSServerName(payload []byte) *MyServerName { if len(payload) < recordLayerHeaderLen+int(segmentLen) { return nil } - serverName := indexTLSServerNameFromHandshake(payload[recordLayerHeaderLen : recordLayerHeaderLen+int(segmentLen)]) + serverName := indexTLSServerNameFromHandshake(payload[recordLayerHeaderLen:]) if serverName == nil { return nil } - serverName.Length += recordLayerHeaderLen + serverName.Index += recordLayerHeaderLen return serverName } -func indexTLSServerNameFromHandshake(hs []byte) *MyServerName { - if len(hs) < handshakeHeaderLen+randomDataLen+sessionIDHeaderLen { +func indexTLSServerNameFromHandshake(handshake []byte) *MyServerName { + if len(handshake) < handshakeHeaderLen+randomDataLen+sessionIDHeaderLen { return nil } - if hs[0] != handshakeType { + if handshake[0] != handshakeType { return nil } - handshakeLen := uint32(hs[1])<<16 | uint32(hs[2])<<8 | uint32(hs[3]) - if len(hs[4:]) != int(handshakeLen) { + handshakeLen := uint32(handshake[1])<<16 | uint32(handshake[2])<<8 | uint32(handshake[3]) + if len(handshake[4:]) != int(handshakeLen) { return nil } - tlsVersion := uint16(hs[4])<<8 | uint16(hs[5]) + tlsVersion := uint16(handshake[4])<<8 | uint16(handshake[5]) if tlsVersion&tlsVersionBitmask != 0x0300 && tlsVersion != tls13 { return nil } - sessionIDLen := hs[38] - if len(hs) < handshakeHeaderLen+randomDataLen+sessionIDHeaderLen+int(sessionIDLen) { + sessionIDLen := handshake[38] + currentIndex := handshakeHeaderLen + randomDataLen + sessionIDHeaderLen + int(sessionIDLen) + if len(handshake) < currentIndex { return nil } - cs := hs[handshakeHeaderLen+randomDataLen+sessionIDHeaderLen+int(sessionIDLen):] - if len(cs) < cipherSuiteHeaderLen { + cipherSuites := handshake[currentIndex:] + if len(cipherSuites) < cipherSuiteHeaderLen { return nil } - csLen := uint16(cs[0])<<8 | uint16(cs[1]) - if len(cs) < cipherSuiteHeaderLen+int(csLen)+compressMethodHeaderLen { + csLen := uint16(cipherSuites[0])<<8 | uint16(cipherSuites[1]) + if len(cipherSuites) < cipherSuiteHeaderLen+int(csLen)+compressMethodHeaderLen { return nil } - compressMethodLen := uint16(cs[cipherSuiteHeaderLen+int(csLen)]) - if len(cs) < cipherSuiteHeaderLen+int(csLen)+compressMethodHeaderLen+int(compressMethodLen) { + compressMethodLen := uint16(cipherSuites[cipherSuiteHeaderLen+int(csLen)]) + currentIndex += cipherSuiteHeaderLen + int(csLen) + compressMethodHeaderLen + int(compressMethodLen) + if len(handshake) < currentIndex { return nil } - currentIndex := cipherSuiteHeaderLen + int(csLen) + compressMethodHeaderLen + int(compressMethodLen) - serverName := indexTLSServerNameFromExtensions(cs[currentIndex:]) + serverName := indexTLSServerNameFromExtensions(handshake[currentIndex:]) if serverName == nil { return nil } @@ -118,6 +119,7 @@ func indexTLSServerNameFromExtensions(exs []byte) *MyServerName { } sniLen := uint16(sex[3])<<8 | uint16(sex[4]) sex = sex[sniExtensionHeaderLen:] + return &MyServerName{ Index: currentIndex + extensionHeaderLen + sniExtensionHeaderLen, Length: int(sniLen), diff --git a/common/tlsfragment/index_test.go b/common/tlsfragment/index_test.go index a4fb7bcb..5086d6c7 100644 --- a/common/tlsfragment/index_test.go +++ b/common/tlsfragment/index_test.go @@ -15,5 +15,6 @@ func TestIndexTLSServerName(t *testing.T) { require.NoError(t, err) serverName := tf.IndexTLSServerName(payload) require.NotNil(t, serverName) + require.Equal(t, serverName.ServerName, string(payload[serverName.Index:serverName.Index+serverName.Length])) require.Equal(t, "github.com", serverName.ServerName) }