Fix processing multiple sniffs

This commit is contained in:
世界 2025-03-16 09:21:54 +08:00
parent 4f3ee61104
commit d55d5009c2
No known key found for this signature in database
GPG Key ID: CD109927C34A63C4
3 changed files with 29 additions and 13 deletions

View File

@ -57,6 +57,7 @@ type InboundContext struct {
Domain string Domain string
Client string Client string
SniffContext any SniffContext any
PacketSniffError error
// cache // cache

View File

@ -9,6 +9,7 @@ import (
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
) )
@ -34,7 +35,7 @@ func Skip(metadata *adapter.InboundContext) bool {
return false return false
} }
func PeekStream(ctx context.Context, metadata *adapter.InboundContext, conn net.Conn, buffer *buf.Buffer, timeout time.Duration, sniffers ...StreamSniffer) error { func PeekStream(ctx context.Context, metadata *adapter.InboundContext, conn net.Conn, buffers []*buf.Buffer, buffer *buf.Buffer, timeout time.Duration, sniffers ...StreamSniffer) error {
if timeout == 0 { if timeout == 0 {
timeout = C.ReadPayloadTimeout timeout = C.ReadPayloadTimeout
} }
@ -55,7 +56,10 @@ func PeekStream(ctx context.Context, metadata *adapter.InboundContext, conn net.
} }
errors = nil errors = nil
for _, sniffer := range sniffers { for _, sniffer := range sniffers {
err = sniffer(ctx, metadata, bytes.NewReader(buffer.Bytes())) reader := io.MultiReader(common.Map(append(buffers, buffer), func(it *buf.Buffer) io.Reader {
return bytes.NewReader(it.Bytes())
})...)
err = sniffer(ctx, metadata, reader)
if err == nil { if err == nil {
return nil return nil
} }

View File

@ -358,7 +358,7 @@ func (r *Router) matchRule(
newBuffer, newPackerBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{ newBuffer, newPackerBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{
OverrideDestination: metadata.InboundOptions.SniffOverrideDestination, OverrideDestination: metadata.InboundOptions.SniffOverrideDestination,
Timeout: time.Duration(metadata.InboundOptions.SniffTimeout), Timeout: time.Duration(metadata.InboundOptions.SniffTimeout),
}, inputConn, inputPacketConn) }, inputConn, inputPacketConn, nil)
if newErr != nil { if newErr != nil {
fatalErr = newErr fatalErr = newErr
return return
@ -458,7 +458,7 @@ match:
switch action := currentRule.Action().(type) { switch action := currentRule.Action().(type) {
case *rule.RuleActionSniff: case *rule.RuleActionSniff:
if !preMatch { if !preMatch {
newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn) newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn, buffers)
if newErr != nil { if newErr != nil {
fatalErr = newErr fatalErr = newErr
return return
@ -490,7 +490,7 @@ match:
} }
} }
if !preMatch && inputPacketConn != nil && (metadata.InboundType == C.TypeSOCKS || metadata.InboundType == C.TypeMixed) && !metadata.Destination.IsFqdn() && !metadata.Destination.Addr.IsGlobalUnicast() { if !preMatch && inputPacketConn != nil && (metadata.InboundType == C.TypeSOCKS || metadata.InboundType == C.TypeMixed) && !metadata.Destination.IsFqdn() && !metadata.Destination.Addr.IsGlobalUnicast() {
newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{Timeout: C.TCPTimeout}, inputConn, inputPacketConn) newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{Timeout: C.TCPTimeout}, inputConn, inputPacketConn, buffers)
if newErr != nil { if newErr != nil {
fatalErr = newErr fatalErr = newErr
return return
@ -506,11 +506,16 @@ match:
func (r *Router) actionSniff( func (r *Router) actionSniff(
ctx context.Context, metadata *adapter.InboundContext, action *rule.RuleActionSniff, ctx context.Context, metadata *adapter.InboundContext, action *rule.RuleActionSniff,
inputConn net.Conn, inputPacketConn N.PacketConn, inputConn net.Conn, inputPacketConn N.PacketConn, inputBuffers []*buf.Buffer,
) (buffer *buf.Buffer, packetBuffers []*N.PacketBuffer, fatalErr error) { ) (buffer *buf.Buffer, packetBuffers []*N.PacketBuffer, fatalErr error) {
if sniff.Skip(metadata) { if sniff.Skip(metadata) {
r.logger.DebugContext(ctx, "sniff skipped due to port considered as server-first")
return return
} else if inputConn != nil { } else if metadata.Protocol != "" {
r.logger.DebugContext(ctx, "duplicate sniff skipped")
return
}
if inputConn != nil {
sniffBuffer := buf.NewPacket() sniffBuffer := buf.NewPacket()
var streamSniffers []sniff.StreamSniffer var streamSniffers []sniff.StreamSniffer
if len(action.StreamSniffers) > 0 { if len(action.StreamSniffers) > 0 {
@ -529,6 +534,7 @@ func (r *Router) actionSniff(
ctx, ctx,
metadata, metadata,
inputConn, inputConn,
inputBuffers,
sniffBuffer, sniffBuffer,
action.Timeout, action.Timeout,
streamSniffers..., streamSniffers...,
@ -555,6 +561,10 @@ func (r *Router) actionSniff(
sniffBuffer.Release() sniffBuffer.Release()
} }
} else if inputPacketConn != nil { } else if inputPacketConn != nil {
if metadata.PacketSniffError != nil && !errors.Is(metadata.PacketSniffError, sniff.ErrClientHelloFragmented) {
r.logger.DebugContext(ctx, "packet sniff skipped due to previous error: ", metadata.PacketSniffError)
return
}
for { for {
var ( var (
sniffBuffer = buf.NewPacket() sniffBuffer = buf.NewPacket()
@ -589,7 +599,7 @@ func (r *Router) actionSniff(
if (metadata.InboundType == C.TypeSOCKS || metadata.InboundType == C.TypeMixed) && !metadata.Destination.IsFqdn() && !metadata.Destination.Addr.IsGlobalUnicast() && !metadata.RouteOriginalDestination.IsValid() { if (metadata.InboundType == C.TypeSOCKS || metadata.InboundType == C.TypeMixed) && !metadata.Destination.IsFqdn() && !metadata.Destination.Addr.IsGlobalUnicast() && !metadata.RouteOriginalDestination.IsValid() {
metadata.Destination = destination metadata.Destination = destination
} }
if len(packetBuffers) > 0 { if len(packetBuffers) > 0 || metadata.PacketSniffError != nil {
err = sniff.PeekPacket( err = sniff.PeekPacket(
ctx, ctx,
metadata, metadata,
@ -622,7 +632,8 @@ func (r *Router) actionSniff(
Destination: destination, Destination: destination,
} }
packetBuffers = append(packetBuffers, packetBuffer) packetBuffers = append(packetBuffers, packetBuffer)
if E.IsMulti(err, sniff.ErrClientHelloFragmented) { metadata.PacketSniffError = err
if errors.Is(err, sniff.ErrClientHelloFragmented) {
r.logger.DebugContext(ctx, "attempt to sniff fragmented QUIC client hello") r.logger.DebugContext(ctx, "attempt to sniff fragmented QUIC client hello")
continue continue
} }