diff --git a/route/dns.go b/route/dns.go index 2c6efefe..8d57c646 100644 --- a/route/dns.go +++ b/route/dns.go @@ -31,7 +31,7 @@ func (r *Router) hijackDNSStream(ctx context.Context, conn net.Conn, metadata ad } } -func (r *Router) hijackDNSPacket(ctx context.Context, conn N.PacketConn, packetBuffers []*N.PacketBuffer, metadata adapter.InboundContext) { +func (r *Router) hijackDNSPacket(ctx context.Context, conn N.PacketConn, packetBuffers []*N.PacketBuffer, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { if natConn, isNatConn := conn.(udpnat.Conn); isNatConn { metadata.Destination = M.Socksaddr{} for _, packet := range packetBuffers { @@ -45,10 +45,12 @@ func (r *Router) hijackDNSPacket(ctx context.Context, conn N.PacketConn, packetB conn: conn, ctx: ctx, metadata: metadata, + onClose: onClose, }) return } err := dnsOutbound.NewDNSPacketConnection(ctx, r, conn, packetBuffers, metadata) + N.CloseOnHandshakeFailure(conn, onClose, err) if err != nil && !E.IsClosedOrCanceled(err) { r.dnsLogger.ErrorContext(ctx, E.Cause(err, "process packet connection")) } @@ -85,8 +87,16 @@ type dnsHijacker struct { conn N.PacketConn ctx context.Context metadata adapter.InboundContext + onClose N.CloseHandlerFunc } func (h *dnsHijacker) NewPacketEx(buffer *buf.Buffer, destination M.Socksaddr) { go ExchangeDNSPacket(h.ctx, h.router, h.conn, buffer, h.metadata, destination) } + +func (h *dnsHijacker) Close() error { + if h.onClose != nil { + h.onClose(nil) + } + return nil +} diff --git a/route/route.go b/route/route.go index d0f93e0b..d9bf2638 100644 --- a/route/route.go +++ b/route/route.go @@ -120,7 +120,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad for _, buffer := range buffers { conn = bufio.NewCachedConn(conn, buffer) } - r.hijackDNSStream(ctx, conn, metadata) + N.CloseOnHandshakeFailure(conn, onClose, r.hijackDNSStream(ctx, conn, metadata)) return nil } } @@ -233,7 +233,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m N.CloseOnHandshakeFailure(conn, onClose, action.Error(ctx)) return nil case *rule.RuleActionHijackDNS: - r.hijackDNSPacket(ctx, conn, packetBuffers, metadata) + r.hijackDNSPacket(ctx, conn, packetBuffers, metadata, onClose) return nil } }