diff --git a/common/dialer/default.go b/common/dialer/default.go index 77536c43..bd8d9018 100644 --- a/common/dialer/default.go +++ b/common/dialer/default.go @@ -333,7 +333,17 @@ func (d *DefaultDialer) ListenSerialInterfacePacket(ctx context.Context, destina } func (d *DefaultDialer) ListenPacketCompat(network, address string) (net.PacketConn, error) { - return d.udpListener.ListenPacket(context.Background(), network, address) + udpListener := d.udpListener + udpListener.Control = control.Append(udpListener.Control, func(network, address string, conn syscall.RawConn) error { + for _, wgControlFn := range WgControlFns { + err := wgControlFn(network, address, conn) + if err != nil { + return err + } + } + return nil + }) + return udpListener.ListenPacket(context.Background(), network, address) } func trackConn(conn net.Conn, err error) (net.Conn, error) { diff --git a/transport/wireguard/endpoint.go b/transport/wireguard/endpoint.go index bddd2a12..3801640f 100644 --- a/transport/wireguard/endpoint.go +++ b/transport/wireguard/endpoint.go @@ -141,7 +141,7 @@ func (e *Endpoint) Start(resolve bool) error { return nil } var bind conn.Bind - wgListener, isWgListener := e.options.Dialer.(conn.Listener) + wgListener, isWgListener := common.Cast[conn.Listener](e.options.Dialer) if isWgListener { bind = conn.NewStdNetBind(wgListener) } else {