diff --git a/inbound/shadowtls.go b/inbound/shadowtls.go index d3da651b..268dace5 100644 --- a/inbound/shadowtls.go +++ b/inbound/shadowtls.go @@ -3,7 +3,10 @@ package inbound import ( "bytes" "context" + "crypto/hmac" + "crypto/sha1" "encoding/binary" + "encoding/hex" "io" "net" "os" @@ -27,7 +30,7 @@ type ShadowTLS struct { myInboundAdapter handshakeDialer N.Dialer handshakeAddr M.Socksaddr - v2 bool + version int password string fallbackAfter int } @@ -47,17 +50,18 @@ func NewShadowTLS(ctx context.Context, router adapter.Router, logger log.Context handshakeAddr: options.Handshake.ServerOptions.Build(), password: options.Password, } + inbound.version = options.Version switch options.Version { case 0: fallthrough case 1: case 2: - inbound.v2 = true if options.FallbackAfter == nil { inbound.fallbackAfter = 2 } else { inbound.fallbackAfter = *options.FallbackAfter } + case 3: default: return nil, E.New("unknown shadowtls protocol version: ", options.Version) } @@ -70,7 +74,8 @@ func (s *ShadowTLS) NewConnection(ctx context.Context, conn net.Conn, metadata a if err != nil { return err } - if !s.v2 { + switch s.version { + case 1: var handshake task.Group handshake.Append("client handshake", func(ctx context.Context) error { return s.copyUntilHandshakeFinished(handshakeConn, conn) @@ -87,7 +92,7 @@ func (s *ShadowTLS) NewConnection(ctx context.Context, conn net.Conn, metadata a return err } return s.newConnection(ctx, conn, metadata) - } else { + case 2: hashConn := shadowtls.NewHashWriteConn(conn, s.password) go bufio.Copy(hashConn, handshakeConn) var request *buf.Buffer @@ -102,6 +107,97 @@ func (s *ShadowTLS) NewConnection(ctx context.Context, conn net.Conn, metadata a } else { return err } + default: + fallthrough + case 3: + var clientHelloFrame *buf.Buffer + clientHelloFrame, err = shadowtls.ExtractFrame(conn) + if err != nil { + return E.Cause(err, "read client handshake") + } + _, err = handshakeConn.Write(clientHelloFrame.Bytes()) + if err != nil { + clientHelloFrame.Release() + return E.Cause(err, "write client handshake") + } + err = shadowtls.VerifyClientHello(clientHelloFrame.Bytes(), s.password) + if err != nil { + s.logger.WarnContext(ctx, E.Cause(err, "client hello verify failed")) + return bufio.CopyConn(ctx, conn, handshakeConn) + } + s.logger.TraceContext(ctx, "client hello verify success") + clientHelloFrame.Release() + + var serverHelloFrame *buf.Buffer + serverHelloFrame, err = shadowtls.ExtractFrame(handshakeConn) + if err != nil { + return E.Cause(err, "read server handshake") + } + + _, err = conn.Write(serverHelloFrame.Bytes()) + if err != nil { + serverHelloFrame.Release() + return E.Cause(err, "write server handshake") + } + + serverRandom := shadowtls.ExtractServerRandom(serverHelloFrame.Bytes()) + + if serverRandom == nil { + s.logger.WarnContext(ctx, "server random extract failed, will copy bidirectional") + return bufio.CopyConn(ctx, conn, handshakeConn) + } + + if !shadowtls.IsServerHelloSupportTLS13(serverHelloFrame.Bytes()) { + s.logger.WarnContext(ctx, "TLS 1.3 is not supported, will copy bidirectional") + return bufio.CopyConn(ctx, conn, handshakeConn) + } + + serverHelloFrame.Release() + s.logger.TraceContext(ctx, "client authenticated. server random extracted: ", hex.EncodeToString(serverRandom)) + + hmacWrite := hmac.New(sha1.New, []byte(s.password)) + hmacWrite.Write(serverRandom) + + hmacAdd := hmac.New(sha1.New, []byte(s.password)) + hmacAdd.Write(serverRandom) + hmacAdd.Write([]byte("S")) + + hmacVerify := hmac.New(sha1.New, []byte(s.password)) + hmacVerifyReset := func() { + hmacVerify.Reset() + hmacVerify.Write(serverRandom) + hmacVerify.Write([]byte("C")) + } + + var clientFirstFrame *buf.Buffer + var group task.Group + var handshakeFinished bool + group.Append("client handshake relay", func(ctx context.Context) error { + clientFrame, cErr := shadowtls.CopyByFrameUntilHMACMatches(conn, handshakeConn, hmacVerify, hmacVerifyReset) + if cErr == nil { + clientFirstFrame = clientFrame + handshakeFinished = true + handshakeConn.Close() + } + return cErr + }) + group.Append("server handshake relay", func(ctx context.Context) error { + cErr := shadowtls.CopyByFrameWithModification(handshakeConn, conn, s.password, serverRandom, hmacWrite) + if E.IsClosedOrCanceled(cErr) && handshakeFinished { + return nil + } + return cErr + }) + group.Cleanup(func() { + handshakeConn.Close() + }) + err = group.Run(ctx) + if err != nil { + return E.Cause(err, "handshake relay") + } + + s.logger.TraceContext(ctx, "handshake relay finished") + return s.newConnection(ctx, bufio.NewCachedConn(shadowtls.NewVerifiedConn(conn, hmacAdd, hmacVerify, nil), clientFirstFrame), metadata) } } diff --git a/test/shadowtls_test.go b/test/shadowtls_test.go index 6720702c..361e5f45 100644 --- a/test/shadowtls_test.go +++ b/test/shadowtls_test.go @@ -17,22 +17,20 @@ import ( func TestShadowTLS(t *testing.T) { t.Run("v1", func(t *testing.T) { - testShadowTLS(t, "") + testShadowTLS(t, 1, "") }) t.Run("v2", func(t *testing.T) { - testShadowTLS(t, "hello") + testShadowTLS(t, 2, "hello") }) } -func testShadowTLS(t *testing.T, password string) { +func TestShadowTLSv3(t *testing.T) { + testShadowTLS(t, 3, "hello") +} + +func testShadowTLS(t *testing.T, version int, password string) { method := shadowaead_2022.List[0] ssPassword := mkBase64(t, 16) - var version int - if password != "" { - version = 2 - } else { - version = 1 - } startInstance(t, option.Options{ Inbounds: []option.Inbound{ { diff --git a/transport/shadowtls/client_v3.go b/transport/shadowtls/client_v3.go index 18c0e784..9d809590 100644 --- a/transport/shadowtls/client_v3.go +++ b/transport/shadowtls/client_v3.go @@ -105,10 +105,7 @@ func (w *StreamWrapper) Read(p []byte) (n int, err error) { copy(w.serverRandom, buffer[serverRandomIndex:serverRandomIndex+tlsRandomSize]) w.readHMAC = hmac.New(sha1.New, []byte(w.password)) w.readHMAC.Write(w.serverRandom) - hasher := sha256.New() - hasher.Write([]byte(w.password)) - hasher.Write(w.serverRandom) - w.readHMACKey = hasher.Sum(nil) + w.readHMACKey = kdf(w.password, w.serverRandom) } case applicationData: w.authorized = false @@ -126,6 +123,13 @@ func (w *StreamWrapper) Read(p []byte) (n int, err error) { return w.buffer.Read(p) } +func kdf(password string, serverRandom []byte) []byte { + hasher := sha256.New() + hasher.Write([]byte(password)) + hasher.Write(serverRandom) + return hasher.Sum(nil) +} + func xorSlice(data []byte, key []byte) { for i := range data { data[i] ^= key[i%len(key)] diff --git a/transport/shadowtls/server_v3.go b/transport/shadowtls/server_v3.go new file mode 100644 index 00000000..44bd6bbb --- /dev/null +++ b/transport/shadowtls/server_v3.go @@ -0,0 +1,181 @@ +package shadowtls + +import ( + "bytes" + "crypto/hmac" + "crypto/sha1" + "encoding/binary" + "hash" + "io" + "net" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/rw" +) + +func ExtractFrame(conn net.Conn) (*buf.Buffer, error) { + var tlsHeader [tlsHeaderSize]byte + _, err := io.ReadFull(conn, tlsHeader[:]) + if err != nil { + return nil, err + } + length := int(binary.BigEndian.Uint16(tlsHeader[3:])) + buffer := buf.NewSize(tlsHeaderSize + length) + common.Must1(buffer.Write(tlsHeader[:])) + _, err = buffer.ReadFullFrom(conn, length) + if err != nil { + buffer.Release() + } + return buffer, err +} + +func VerifyClientHello(frame []byte, password string) error { + const minLen = tlsHeaderSize + 1 + 3 + 2 + tlsRandomSize + 1 + tlsSessionIDSize + const hmacIndex = sessionIDLengthIndex + 1 + tlsSessionIDSize - hmacSize + if len(frame) < minLen { + return io.ErrUnexpectedEOF + } else if frame[0] != handshake { + return E.New("unexpected record type") + } else if frame[5] != clientHello { + return E.New("unexpected handshake type") + } else if frame[sessionIDLengthIndex] != tlsSessionIDSize { + return E.New("unexpected session id length") + } + hmacSHA1Hash := hmac.New(sha1.New, []byte(password)) + hmacSHA1Hash.Write(frame[tlsHeaderSize:hmacIndex]) + hmacSHA1Hash.Write(rw.ZeroBytes[:4]) + hmacSHA1Hash.Write(frame[hmacIndex+hmacSize:]) + if !hmac.Equal(frame[hmacIndex:hmacIndex+hmacSize], hmacSHA1Hash.Sum(nil)[:hmacSize]) { + return E.New("hmac mismatch") + } + return nil +} + +func ExtractServerRandom(frame []byte) []byte { + const minLen = tlsHeaderSize + 1 + 3 + 2 + tlsRandomSize + + if len(frame) < minLen || frame[0] != handshake || frame[5] != serverHello { + return nil + } + + serverRandom := make([]byte, tlsRandomSize) + copy(serverRandom, frame[serverRandomIndex:serverRandomIndex+tlsRandomSize]) + return serverRandom +} + +func IsServerHelloSupportTLS13(frame []byte) bool { + if len(frame) < sessionIDLengthIndex { + return false + } + + reader := bytes.NewReader(frame[sessionIDLengthIndex:]) + + var sessionIdLength uint8 + err := binary.Read(reader, binary.BigEndian, &sessionIdLength) + if err != nil { + return false + } + _, err = io.CopyN(io.Discard, reader, int64(sessionIdLength)) + if err != nil { + return false + } + + _, err = io.CopyN(io.Discard, reader, 3) + if err != nil { + return false + } + + var extensionListLength uint16 + err = binary.Read(reader, binary.BigEndian, &extensionListLength) + if err != nil { + return false + } + for i := uint16(0); i < extensionListLength; i++ { + var extensionType uint16 + err = binary.Read(reader, binary.BigEndian, &extensionType) + if err != nil { + return false + } + var extensionLength uint16 + err = binary.Read(reader, binary.BigEndian, &extensionLength) + if err != nil { + return false + } + if extensionType != 43 { + _, err = io.CopyN(io.Discard, reader, int64(extensionLength)) + if err != nil { + return false + } + continue + } + if extensionLength != 2 { + return false + } + var extensionValue uint16 + err = binary.Read(reader, binary.BigEndian, &extensionValue) + if err != nil { + return false + } + return extensionValue == 0x0304 + } + return false +} + +func CopyByFrameUntilHMACMatches(conn net.Conn, handshakeConn net.Conn, hmacVerify hash.Hash, hmacReset func()) (*buf.Buffer, error) { + for { + frameBuffer, err := ExtractFrame(conn) + if err != nil { + return nil, E.Cause(err, "read client record") + } + frame := frameBuffer.Bytes() + if len(frame) > tlsHmacHeaderSize && frame[0] == applicationData { + hmacReset() + hmacVerify.Write(frame[tlsHmacHeaderSize:]) + hmacHash := hmacVerify.Sum(nil)[:4] + if bytes.Equal(hmacHash, frame[tlsHeaderSize:tlsHmacHeaderSize]) { + hmacReset() + hmacVerify.Write(frame[tlsHmacHeaderSize:]) + hmacVerify.Write(frame[tlsHeaderSize:tlsHmacHeaderSize]) + frameBuffer.Advance(tlsHmacHeaderSize) + return frameBuffer, nil + } + } + _, err = handshakeConn.Write(frame) + frameBuffer.Release() + if err != nil { + return nil, E.Cause(err, "write clint frame") + } + } +} + +func CopyByFrameWithModification(conn net.Conn, handshakeConn net.Conn, password string, serverRandom []byte, hmacWrite hash.Hash) error { + writeKey := kdf(password, serverRandom) + writer := bufio.NewVectorisedWriter(handshakeConn) + for { + frameBuffer, err := ExtractFrame(conn) + if err != nil { + return E.Cause(err, "read server record") + } + frame := frameBuffer.Bytes() + if frame[0] == applicationData { + xorSlice(frame[tlsHeaderSize:], writeKey) + hmacWrite.Write(frame[tlsHeaderSize:]) + binary.BigEndian.PutUint16(frame[3:], uint16(len(frame)-tlsHeaderSize+hmacSize)) + hmacHash := hmacWrite.Sum(nil)[:4] + _, err = bufio.WriteVectorised(writer, [][]byte{frame[:tlsHeaderSize], hmacHash, frame[tlsHeaderSize:]}) + frameBuffer.Release() + if err != nil { + return E.Cause(err, "write modified server frame") + } + } else { + _, err = handshakeConn.Write(frame) + frameBuffer.Release() + if err != nil { + return E.Cause(err, "write server frame") + } + } + } +}