From d55d5009c29d594cc5a07ec29f9f44d4a5b20605 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 16 Mar 2025 09:21:54 +0800 Subject: [PATCH] Fix processing multiple sniffs --- adapter/inbound.go | 9 +++++---- common/sniff/sniff.go | 8 ++++++-- route/route.go | 25 ++++++++++++++++++------- 3 files changed, 29 insertions(+), 13 deletions(-) diff --git a/adapter/inbound.go b/adapter/inbound.go index 93d2ec60..173dd0ee 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -53,10 +53,11 @@ type InboundContext struct { // sniffer - Protocol string - Domain string - Client string - SniffContext any + Protocol string + Domain string + Client string + SniffContext any + PacketSniffError error // cache diff --git a/common/sniff/sniff.go b/common/sniff/sniff.go index 81fc0a27..ecb0488b 100644 --- a/common/sniff/sniff.go +++ b/common/sniff/sniff.go @@ -9,6 +9,7 @@ import ( "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" ) @@ -34,7 +35,7 @@ func Skip(metadata *adapter.InboundContext) bool { return false } -func PeekStream(ctx context.Context, metadata *adapter.InboundContext, conn net.Conn, buffer *buf.Buffer, timeout time.Duration, sniffers ...StreamSniffer) error { +func PeekStream(ctx context.Context, metadata *adapter.InboundContext, conn net.Conn, buffers []*buf.Buffer, buffer *buf.Buffer, timeout time.Duration, sniffers ...StreamSniffer) error { if timeout == 0 { timeout = C.ReadPayloadTimeout } @@ -55,7 +56,10 @@ func PeekStream(ctx context.Context, metadata *adapter.InboundContext, conn net. } errors = nil for _, sniffer := range sniffers { - err = sniffer(ctx, metadata, bytes.NewReader(buffer.Bytes())) + reader := io.MultiReader(common.Map(append(buffers, buffer), func(it *buf.Buffer) io.Reader { + return bytes.NewReader(it.Bytes()) + })...) + err = sniffer(ctx, metadata, reader) if err == nil { return nil } diff --git a/route/route.go b/route/route.go index bb484efd..dab750a2 100644 --- a/route/route.go +++ b/route/route.go @@ -358,7 +358,7 @@ func (r *Router) matchRule( newBuffer, newPackerBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{ OverrideDestination: metadata.InboundOptions.SniffOverrideDestination, Timeout: time.Duration(metadata.InboundOptions.SniffTimeout), - }, inputConn, inputPacketConn) + }, inputConn, inputPacketConn, nil) if newErr != nil { fatalErr = newErr return @@ -458,7 +458,7 @@ match: switch action := currentRule.Action().(type) { case *rule.RuleActionSniff: if !preMatch { - newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn) + newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn, buffers) if newErr != nil { fatalErr = newErr return @@ -490,7 +490,7 @@ match: } } if !preMatch && inputPacketConn != nil && (metadata.InboundType == C.TypeSOCKS || metadata.InboundType == C.TypeMixed) && !metadata.Destination.IsFqdn() && !metadata.Destination.Addr.IsGlobalUnicast() { - newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{Timeout: C.TCPTimeout}, inputConn, inputPacketConn) + newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{Timeout: C.TCPTimeout}, inputConn, inputPacketConn, buffers) if newErr != nil { fatalErr = newErr return @@ -506,11 +506,16 @@ match: func (r *Router) actionSniff( ctx context.Context, metadata *adapter.InboundContext, action *rule.RuleActionSniff, - inputConn net.Conn, inputPacketConn N.PacketConn, + inputConn net.Conn, inputPacketConn N.PacketConn, inputBuffers []*buf.Buffer, ) (buffer *buf.Buffer, packetBuffers []*N.PacketBuffer, fatalErr error) { if sniff.Skip(metadata) { + r.logger.DebugContext(ctx, "sniff skipped due to port considered as server-first") return - } else if inputConn != nil { + } else if metadata.Protocol != "" { + r.logger.DebugContext(ctx, "duplicate sniff skipped") + return + } + if inputConn != nil { sniffBuffer := buf.NewPacket() var streamSniffers []sniff.StreamSniffer if len(action.StreamSniffers) > 0 { @@ -529,6 +534,7 @@ func (r *Router) actionSniff( ctx, metadata, inputConn, + inputBuffers, sniffBuffer, action.Timeout, streamSniffers..., @@ -555,6 +561,10 @@ func (r *Router) actionSniff( sniffBuffer.Release() } } else if inputPacketConn != nil { + if metadata.PacketSniffError != nil && !errors.Is(metadata.PacketSniffError, sniff.ErrClientHelloFragmented) { + r.logger.DebugContext(ctx, "packet sniff skipped due to previous error: ", metadata.PacketSniffError) + return + } for { var ( sniffBuffer = buf.NewPacket() @@ -589,7 +599,7 @@ func (r *Router) actionSniff( if (metadata.InboundType == C.TypeSOCKS || metadata.InboundType == C.TypeMixed) && !metadata.Destination.IsFqdn() && !metadata.Destination.Addr.IsGlobalUnicast() && !metadata.RouteOriginalDestination.IsValid() { metadata.Destination = destination } - if len(packetBuffers) > 0 { + if len(packetBuffers) > 0 || metadata.PacketSniffError != nil { err = sniff.PeekPacket( ctx, metadata, @@ -622,7 +632,8 @@ func (r *Router) actionSniff( Destination: destination, } packetBuffers = append(packetBuffers, packetBuffer) - if E.IsMulti(err, sniff.ErrClientHelloFragmented) { + metadata.PacketSniffError = err + if errors.Is(err, sniff.ErrClientHelloFragmented) { r.logger.DebugContext(ctx, "attempt to sniff fragmented QUIC client hello") continue }