From d0aaf71770536f2420b43acdbc495c4b3d3ca1f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 18 Dec 2023 21:32:42 +0800 Subject: [PATCH] Fix direct UDP override --- outbound/direct.go | 26 ++++---------------------- 1 file changed, 4 insertions(+), 22 deletions(-) diff --git a/outbound/direct.go b/outbound/direct.go index 3bf80494..259205e4 100644 --- a/outbound/direct.go +++ b/outbound/direct.go @@ -12,7 +12,6 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-dns" - "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" @@ -123,6 +122,7 @@ func (h *Direct) ListenPacket(ctx context.Context, destination M.Socksaddr) (net ctx, metadata := adapter.ExtendContext(ctx) metadata.Outbound = h.tag metadata.Destination = destination + originDestination := destination switch h.overrideOption { case 1: destination = h.overrideDestination @@ -142,11 +142,10 @@ func (h *Direct) ListenPacket(ctx context.Context, destination M.Socksaddr) (net if err != nil { return nil, err } - if h.overrideOption == 0 { - return conn, nil - } else { - return &overridePacketConn{bufio.NewPacketConn(conn), destination}, nil + if originDestination != destination { + conn = bufio.NewNATPacketConn(bufio.NewPacketConn(conn), destination, originDestination) } + return conn, nil } func (h *Direct) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { @@ -156,20 +155,3 @@ func (h *Direct) NewConnection(ctx context.Context, conn net.Conn, metadata adap func (h *Direct) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { return NewPacketConnection(ctx, h, conn, metadata) } - -type overridePacketConn struct { - N.NetPacketConn - overrideDestination M.Socksaddr -} - -func (c *overridePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { - return c.NetPacketConn.WritePacket(buffer, c.overrideDestination) -} - -func (c *overridePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - return c.NetPacketConn.WriteTo(p, c.overrideDestination.UDPAddr()) -} - -func (c *overridePacketConn) Upstream() any { - return c.NetPacketConn -}