Fix sniffer errors override each others

* Fix sniffer errors override each others

* Do not return ErrNeedMoreData if header is not expected
This commit is contained in:
dyhkwong 2025-04-22 14:44:55 +08:00 committed by GitHub
parent c54d50fd36
commit f2bbf6b2aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 102 additions and 10 deletions

View File

@ -31,13 +31,18 @@ func BitTorrent(_ context.Context, metadata *adapter.InboundContext, reader io.R
return os.ErrInvalid return os.ErrInvalid
} }
const header = "BitTorrent protocol"
var protocol [19]byte var protocol [19]byte
_, err = reader.Read(protocol[:]) var n int
n, err = reader.Read(protocol[:])
if string(protocol[:n]) != header[:n] {
return os.ErrInvalid
}
if err != nil { if err != nil {
return E.Cause1(ErrNeedMoreData, err) return E.Cause1(ErrNeedMoreData, err)
} }
if string(protocol[:]) != "BitTorrent protocol" { if n < 19 {
return os.ErrInvalid return ErrNeedMoreData
} }
metadata.Protocol = C.ProtocolBitTorrent metadata.Protocol = C.ProtocolBitTorrent

View File

@ -32,6 +32,27 @@ func TestSniffBittorrent(t *testing.T) {
} }
} }
func TestSniffIncompleteBittorrent(t *testing.T) {
t.Parallel()
pkt, err := hex.DecodeString("13426974546f7272656e74")
require.NoError(t, err)
var metadata adapter.InboundContext
err = sniff.BitTorrent(context.TODO(), &metadata, bytes.NewReader(pkt))
require.ErrorIs(t, err, sniff.ErrNeedMoreData)
}
func TestSniffNotBittorrent(t *testing.T) {
t.Parallel()
pkt, err := hex.DecodeString("13426974546f7272656e75")
require.NoError(t, err)
var metadata adapter.InboundContext
err = sniff.BitTorrent(context.TODO(), &metadata, bytes.NewReader(pkt))
require.NotEmpty(t, err)
require.NotErrorIs(t, err, sniff.ErrNeedMoreData)
}
func TestSniffUTP(t *testing.T) { func TestSniffUTP(t *testing.T) {
t.Parallel() t.Parallel()

View File

@ -20,22 +20,36 @@ func StreamDomainNameQuery(readCtx context.Context, metadata *adapter.InboundCon
if err != nil { if err != nil {
return E.Cause1(ErrNeedMoreData, err) return E.Cause1(ErrNeedMoreData, err)
} }
if length == 0 { if length < 12 {
return os.ErrInvalid return os.ErrInvalid
} }
buffer := buf.NewSize(int(length)) buffer := buf.NewSize(int(length))
defer buffer.Release() defer buffer.Release()
_, err = buffer.ReadFullFrom(reader, buffer.FreeLen()) var n int
n, err = buffer.ReadFullFrom(reader, buffer.FreeLen())
packet := buffer.Bytes()
if n > 2 && packet[2]&0x80 != 0 { // QR
return os.ErrInvalid
}
if n > 5 && packet[4] == 0 && packet[5] == 0 { // QDCOUNT
return os.ErrInvalid
}
for i := 6; i < 10; i++ {
// ANCOUNT, NSCOUNT
if n > i && packet[i] != 0 {
return os.ErrInvalid
}
}
if err != nil { if err != nil {
return E.Cause1(ErrNeedMoreData, err) return E.Cause1(ErrNeedMoreData, err)
} }
return DomainNameQuery(readCtx, metadata, buffer.Bytes()) return DomainNameQuery(readCtx, metadata, packet)
} }
func DomainNameQuery(ctx context.Context, metadata *adapter.InboundContext, packet []byte) error { func DomainNameQuery(ctx context.Context, metadata *adapter.InboundContext, packet []byte) error {
var msg mDNS.Msg var msg mDNS.Msg
err := msg.Unpack(packet) err := msg.Unpack(packet)
if err != nil { if err != nil || msg.Response || len(msg.Question) == 0 || len(msg.Answer) > 0 || len(msg.Ns) > 0 {
return err return err
} }
metadata.Protocol = C.ProtocolDNS metadata.Protocol = C.ProtocolDNS

View File

@ -1,6 +1,7 @@
package sniff_test package sniff_test
import ( import (
"bytes"
"context" "context"
"encoding/hex" "encoding/hex"
"testing" "testing"
@ -21,3 +22,32 @@ func TestSniffDNS(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, C.ProtocolDNS, metadata.Protocol) require.Equal(t, C.ProtocolDNS, metadata.Protocol)
} }
func TestSniffStreamDNS(t *testing.T) {
t.Parallel()
query, err := hex.DecodeString("001e740701000001000000000000012a06676f6f676c6503636f6d0000010001")
require.NoError(t, err)
var metadata adapter.InboundContext
err = sniff.StreamDomainNameQuery(context.TODO(), &metadata, bytes.NewReader(query))
require.NoError(t, err)
require.Equal(t, C.ProtocolDNS, metadata.Protocol)
}
func TestSniffIncompleteStreamDNS(t *testing.T) {
t.Parallel()
query, err := hex.DecodeString("001e740701000001000000000000")
require.NoError(t, err)
var metadata adapter.InboundContext
err = sniff.StreamDomainNameQuery(context.TODO(), &metadata, bytes.NewReader(query))
require.ErrorIs(t, err, sniff.ErrNeedMoreData)
}
func TestSniffNotStreamDNS(t *testing.T) {
t.Parallel()
query, err := hex.DecodeString("001e740701000000000000000000")
require.NoError(t, err)
var metadata adapter.InboundContext
err = sniff.StreamDomainNameQuery(context.TODO(), &metadata, bytes.NewReader(query))
require.NotEmpty(t, err)
require.NotErrorIs(t, err, sniff.ErrNeedMoreData)
}

View File

@ -68,7 +68,7 @@ func PeekStream(ctx context.Context, metadata *adapter.InboundContext, conn net.
} }
sniffError = E.Errors(sniffError, err) sniffError = E.Errors(sniffError, err)
} }
if !errors.Is(err, ErrNeedMoreData) { if !errors.Is(sniffError, ErrNeedMoreData) {
break break
} }
} }

View File

@ -15,10 +15,11 @@ func SSH(_ context.Context, metadata *adapter.InboundContext, reader io.Reader)
const sshPrefix = "SSH-2.0-" const sshPrefix = "SSH-2.0-"
bReader := bufio.NewReader(reader) bReader := bufio.NewReader(reader)
prefix, err := bReader.Peek(len(sshPrefix)) prefix, err := bReader.Peek(len(sshPrefix))
if string(prefix[:]) != sshPrefix[:len(prefix)] {
return os.ErrInvalid
}
if err != nil { if err != nil {
return E.Cause1(ErrNeedMoreData, err) return E.Cause1(ErrNeedMoreData, err)
} else if string(prefix) != sshPrefix {
return os.ErrInvalid
} }
fistLine, _, err := bReader.ReadLine() fistLine, _, err := bReader.ReadLine()
if err != nil { if err != nil {

View File

@ -24,3 +24,24 @@ func TestSniffSSH(t *testing.T) {
require.Equal(t, C.ProtocolSSH, metadata.Protocol) require.Equal(t, C.ProtocolSSH, metadata.Protocol)
require.Equal(t, "dropbear", metadata.Client) require.Equal(t, "dropbear", metadata.Client)
} }
func TestSniffIncompleteSSH(t *testing.T) {
t.Parallel()
pkt, err := hex.DecodeString("5353482d322e30")
require.NoError(t, err)
var metadata adapter.InboundContext
err = sniff.SSH(context.TODO(), &metadata, bytes.NewReader(pkt))
require.ErrorIs(t, err, sniff.ErrNeedMoreData)
}
func TestSniffNotSSH(t *testing.T) {
t.Parallel()
pkt, err := hex.DecodeString("5353482d322e31")
require.NoError(t, err)
var metadata adapter.InboundContext
err = sniff.SSH(context.TODO(), &metadata, bytes.NewReader(pkt))
require.NotEmpty(t, err)
require.NotErrorIs(t, err, sniff.ErrNeedMoreData)
}