diff --git a/common/sniff/bittorrent.go b/common/sniff/bittorrent.go index 39c19598..e4d9f4b8 100644 --- a/common/sniff/bittorrent.go +++ b/common/sniff/bittorrent.go @@ -31,13 +31,18 @@ func BitTorrent(_ context.Context, metadata *adapter.InboundContext, reader io.R return os.ErrInvalid } + const header = "BitTorrent protocol" var protocol [19]byte - _, err = reader.Read(protocol[:]) + var n int + n, err = reader.Read(protocol[:]) + if string(protocol[:n]) != header[:n] { + return os.ErrInvalid + } if err != nil { return E.Cause1(ErrNeedMoreData, err) } - if string(protocol[:]) != "BitTorrent protocol" { - return os.ErrInvalid + if n < 19 { + return ErrNeedMoreData } metadata.Protocol = C.ProtocolBitTorrent diff --git a/common/sniff/bittorrent_test.go b/common/sniff/bittorrent_test.go index f4762e32..fcb5f6fa 100644 --- a/common/sniff/bittorrent_test.go +++ b/common/sniff/bittorrent_test.go @@ -32,6 +32,27 @@ func TestSniffBittorrent(t *testing.T) { } } +func TestSniffIncompleteBittorrent(t *testing.T) { + t.Parallel() + + pkt, err := hex.DecodeString("13426974546f7272656e74") + require.NoError(t, err) + var metadata adapter.InboundContext + err = sniff.BitTorrent(context.TODO(), &metadata, bytes.NewReader(pkt)) + require.ErrorIs(t, err, sniff.ErrNeedMoreData) +} + +func TestSniffNotBittorrent(t *testing.T) { + t.Parallel() + + pkt, err := hex.DecodeString("13426974546f7272656e75") + require.NoError(t, err) + var metadata adapter.InboundContext + err = sniff.BitTorrent(context.TODO(), &metadata, bytes.NewReader(pkt)) + require.NotEmpty(t, err) + require.NotErrorIs(t, err, sniff.ErrNeedMoreData) +} + func TestSniffUTP(t *testing.T) { t.Parallel() diff --git a/common/sniff/dns.go b/common/sniff/dns.go index 2e22f3d7..7125a08e 100644 --- a/common/sniff/dns.go +++ b/common/sniff/dns.go @@ -20,22 +20,36 @@ func StreamDomainNameQuery(readCtx context.Context, metadata *adapter.InboundCon if err != nil { return E.Cause1(ErrNeedMoreData, err) } - if length == 0 { + if length < 12 { return os.ErrInvalid } buffer := buf.NewSize(int(length)) defer buffer.Release() - _, err = buffer.ReadFullFrom(reader, buffer.FreeLen()) + var n int + n, err = buffer.ReadFullFrom(reader, buffer.FreeLen()) + packet := buffer.Bytes() + if n > 2 && packet[2]&0x80 != 0 { // QR + return os.ErrInvalid + } + if n > 5 && packet[4] == 0 && packet[5] == 0 { // QDCOUNT + return os.ErrInvalid + } + for i := 6; i < 10; i++ { + // ANCOUNT, NSCOUNT + if n > i && packet[i] != 0 { + return os.ErrInvalid + } + } if err != nil { return E.Cause1(ErrNeedMoreData, err) } - return DomainNameQuery(readCtx, metadata, buffer.Bytes()) + return DomainNameQuery(readCtx, metadata, packet) } func DomainNameQuery(ctx context.Context, metadata *adapter.InboundContext, packet []byte) error { var msg mDNS.Msg err := msg.Unpack(packet) - if err != nil { + if err != nil || msg.Response || len(msg.Question) == 0 || len(msg.Answer) > 0 || len(msg.Ns) > 0 { return err } metadata.Protocol = C.ProtocolDNS diff --git a/common/sniff/dns_test.go b/common/sniff/dns_test.go index eaf4dd1a..d78b0bf5 100644 --- a/common/sniff/dns_test.go +++ b/common/sniff/dns_test.go @@ -1,6 +1,7 @@ package sniff_test import ( + "bytes" "context" "encoding/hex" "testing" @@ -21,3 +22,32 @@ func TestSniffDNS(t *testing.T) { require.NoError(t, err) require.Equal(t, C.ProtocolDNS, metadata.Protocol) } + +func TestSniffStreamDNS(t *testing.T) { + t.Parallel() + query, err := hex.DecodeString("001e740701000001000000000000012a06676f6f676c6503636f6d0000010001") + require.NoError(t, err) + var metadata adapter.InboundContext + err = sniff.StreamDomainNameQuery(context.TODO(), &metadata, bytes.NewReader(query)) + require.NoError(t, err) + require.Equal(t, C.ProtocolDNS, metadata.Protocol) +} + +func TestSniffIncompleteStreamDNS(t *testing.T) { + t.Parallel() + query, err := hex.DecodeString("001e740701000001000000000000") + require.NoError(t, err) + var metadata adapter.InboundContext + err = sniff.StreamDomainNameQuery(context.TODO(), &metadata, bytes.NewReader(query)) + require.ErrorIs(t, err, sniff.ErrNeedMoreData) +} + +func TestSniffNotStreamDNS(t *testing.T) { + t.Parallel() + query, err := hex.DecodeString("001e740701000000000000000000") + require.NoError(t, err) + var metadata adapter.InboundContext + err = sniff.StreamDomainNameQuery(context.TODO(), &metadata, bytes.NewReader(query)) + require.NotEmpty(t, err) + require.NotErrorIs(t, err, sniff.ErrNeedMoreData) +} diff --git a/common/sniff/sniff.go b/common/sniff/sniff.go index 59e81aaa..b3651e1f 100644 --- a/common/sniff/sniff.go +++ b/common/sniff/sniff.go @@ -68,7 +68,7 @@ func PeekStream(ctx context.Context, metadata *adapter.InboundContext, conn net. } sniffError = E.Errors(sniffError, err) } - if !errors.Is(err, ErrNeedMoreData) { + if !errors.Is(sniffError, ErrNeedMoreData) { break } } diff --git a/common/sniff/ssh.go b/common/sniff/ssh.go index d373d292..dce5d54f 100644 --- a/common/sniff/ssh.go +++ b/common/sniff/ssh.go @@ -15,10 +15,11 @@ func SSH(_ context.Context, metadata *adapter.InboundContext, reader io.Reader) const sshPrefix = "SSH-2.0-" bReader := bufio.NewReader(reader) prefix, err := bReader.Peek(len(sshPrefix)) + if string(prefix[:]) != sshPrefix[:len(prefix)] { + return os.ErrInvalid + } if err != nil { return E.Cause1(ErrNeedMoreData, err) - } else if string(prefix) != sshPrefix { - return os.ErrInvalid } fistLine, _, err := bReader.ReadLine() if err != nil { diff --git a/common/sniff/ssh_test.go b/common/sniff/ssh_test.go index be530980..7cea5aab 100644 --- a/common/sniff/ssh_test.go +++ b/common/sniff/ssh_test.go @@ -24,3 +24,24 @@ func TestSniffSSH(t *testing.T) { require.Equal(t, C.ProtocolSSH, metadata.Protocol) require.Equal(t, "dropbear", metadata.Client) } + +func TestSniffIncompleteSSH(t *testing.T) { + t.Parallel() + + pkt, err := hex.DecodeString("5353482d322e30") + require.NoError(t, err) + var metadata adapter.InboundContext + err = sniff.SSH(context.TODO(), &metadata, bytes.NewReader(pkt)) + require.ErrorIs(t, err, sniff.ErrNeedMoreData) +} + +func TestSniffNotSSH(t *testing.T) { + t.Parallel() + + pkt, err := hex.DecodeString("5353482d322e31") + require.NoError(t, err) + var metadata adapter.InboundContext + err = sniff.SSH(context.TODO(), &metadata, bytes.NewReader(pkt)) + require.NotEmpty(t, err) + require.NotErrorIs(t, err, sniff.ErrNeedMoreData) +}