From 244243f2066e42a781995d83b62a095fda7aadbf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 31 Jan 2025 19:16:29 +0800 Subject: [PATCH] Fix domain resolver for DNS server --- box.go | 4 +-- common/dialer/detour.go | 16 +++++----- common/dialer/dialer.go | 54 ++++++++++++++++++++++++++++++++ common/dialer/resolve.go | 67 +++++++++++++++++++++++++++++++++------- dns/transport_adapter.go | 12 +++++-- dns/transport_dialer.go | 47 ++++++++++++---------------- option/dns.go | 44 +++++++++++++++++--------- 7 files changed, 180 insertions(+), 64 deletions(-) diff --git a/box.go b/box.go index 3e53fcd4..08a4fd0b 100644 --- a/box.go +++ b/box.go @@ -202,7 +202,7 @@ func New(options Options) (*Box, error) { transportOptions.Options, ) if err != nil { - return nil, E.Cause(err, "initialize inbound[", i, "]") + return nil, E.Cause(err, "initialize DNS server[", i, "]") } } err = dnsRouter.Initialize(dnsOptions.Rules) @@ -225,7 +225,7 @@ func New(options Options) (*Box, error) { endpointOptions.Options, ) if err != nil { - return nil, E.Cause(err, "initialize inbound[", i, "]") + return nil, E.Cause(err, "initialize endpoint[", i, "]") } } for i, inboundOptions := range options.Inbounds { diff --git a/common/dialer/detour.go b/common/dialer/detour.go index c1d40faa..e4a46049 100644 --- a/common/dialer/detour.go +++ b/common/dialer/detour.go @@ -29,16 +29,18 @@ func (d *DetourDialer) Start() error { } func (d *DetourDialer) Dialer() (N.Dialer, error) { - d.initOnce.Do(func() { - var loaded bool - d.dialer, loaded = d.outboundManager.Outbound(d.detour) - if !loaded { - d.initErr = E.New("outbound detour not found: ", d.detour) - } - }) + d.initOnce.Do(d.init) return d.dialer, d.initErr } +func (d *DetourDialer) init() { + var loaded bool + d.dialer, loaded = d.outboundManager.Outbound(d.detour) + if !loaded { + d.initErr = E.New("outbound detour not found: ", d.detour) + } +} + func (d *DetourDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { dialer, err := d.Dialer() if err != nil { diff --git a/common/dialer/dialer.go b/common/dialer/dialer.go index e860f520..93803fdb 100644 --- a/common/dialer/dialer.go +++ b/common/dialer/dialer.go @@ -84,6 +84,60 @@ func New(ctx context.Context, options option.DialerOptions, remoteIsDomain bool) ctx, dialer, options.Detour == "" && !options.TCPFastOpen, + "", + dnsQueryOptions, + resolveFallbackDelay, + ) + } + return dialer, nil +} + +func NewDNS(ctx context.Context, options option.DialerOptions, remoteIsDomain bool) (N.Dialer, error) { + var ( + dialer N.Dialer + err error + ) + if options.Detour != "" { + outboundManager := service.FromContext[adapter.OutboundManager](ctx) + if outboundManager == nil { + return nil, E.New("missing outbound manager") + } + dialer = NewDetour(outboundManager, options.Detour) + } else { + dialer, err = NewDefault(ctx, options) + if err != nil { + return nil, err + } + } + if remoteIsDomain { + var ( + dnsQueryOptions adapter.DNSQueryOptions + resolveFallbackDelay time.Duration + ) + if options.DomainResolver == nil || options.DomainResolver.Server == "" { + return nil, E.New("missing domain resolver for domain server address") + } + var strategy C.DomainStrategy + if options.DomainResolver.Strategy != option.DomainStrategy(C.DomainStrategyAsIS) { + strategy = C.DomainStrategy(options.DomainResolver.Strategy) + } else if + //nolint:staticcheck + options.DomainStrategy != option.DomainStrategy(C.DomainStrategyAsIS) { + //nolint:staticcheck + strategy = C.DomainStrategy(options.DomainStrategy) + } + dnsQueryOptions = adapter.DNSQueryOptions{ + Strategy: strategy, + DisableCache: options.DomainResolver.DisableCache, + RewriteTTL: options.DomainResolver.RewriteTTL, + ClientSubnet: options.DomainResolver.ClientSubnet.Build(netip.Prefix{}), + } + resolveFallbackDelay = time.Duration(options.FallbackDelay) + dialer = NewResolveDialer( + ctx, + dialer, + options.Detour == "" && !options.TCPFastOpen, + options.DomainResolver.Server, dnsQueryOptions, resolveFallbackDelay, ) diff --git a/common/dialer/resolve.go b/common/dialer/resolve.go index b28b9ebd..66b74e3c 100644 --- a/common/dialer/resolve.go +++ b/common/dialer/resolve.go @@ -3,12 +3,14 @@ package dialer import ( "context" "net" + "sync" "time" "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/service" @@ -30,20 +32,26 @@ type ParallelInterfaceResolveDialer interface { } type resolveDialer struct { + transport adapter.DNSTransportManager router adapter.DNSRouter dialer N.Dialer parallel bool + server string + initOnce sync.Once + initErr error queryOptions adapter.DNSQueryOptions fallbackDelay time.Duration } -func NewResolveDialer(ctx context.Context, dialer N.Dialer, parallel bool, queryOptions adapter.DNSQueryOptions, fallbackDelay time.Duration) ResolveDialer { +func NewResolveDialer(ctx context.Context, dialer N.Dialer, parallel bool, server string, queryOptions adapter.DNSQueryOptions, fallbackDelay time.Duration) ResolveDialer { return &resolveDialer{ - service.FromContext[adapter.DNSRouter](ctx), - dialer, - parallel, - queryOptions, - fallbackDelay, + transport: service.FromContext[adapter.DNSTransportManager](ctx), + router: service.FromContext[adapter.DNSRouter](ctx), + dialer: dialer, + parallel: parallel, + server: server, + queryOptions: queryOptions, + fallbackDelay: fallbackDelay, } } @@ -52,20 +60,43 @@ type resolveParallelNetworkDialer struct { dialer ParallelInterfaceDialer } -func NewResolveParallelInterfaceDialer(ctx context.Context, dialer ParallelInterfaceDialer, parallel bool, queryOptions adapter.DNSQueryOptions, fallbackDelay time.Duration) ParallelInterfaceResolveDialer { +func NewResolveParallelInterfaceDialer(ctx context.Context, dialer ParallelInterfaceDialer, parallel bool, server string, queryOptions adapter.DNSQueryOptions, fallbackDelay time.Duration) ParallelInterfaceResolveDialer { return &resolveParallelNetworkDialer{ resolveDialer{ - service.FromContext[adapter.DNSRouter](ctx), - dialer, - parallel, - queryOptions, - fallbackDelay, + transport: service.FromContext[adapter.DNSTransportManager](ctx), + router: service.FromContext[adapter.DNSRouter](ctx), + dialer: dialer, + parallel: parallel, + server: server, + queryOptions: queryOptions, + fallbackDelay: fallbackDelay, }, dialer, } } +func (d *resolveDialer) initialize() error { + d.initOnce.Do(d.initServer) + return d.initErr +} + +func (d *resolveDialer) initServer() { + if d.server == "" { + return + } + transport, loaded := d.transport.Transport(d.server) + if !loaded { + d.initErr = E.New("domain resolver not found: " + d.server) + return + } + d.queryOptions.Transport = transport +} + func (d *resolveDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + err := d.initialize() + if err != nil { + return nil, err + } if !destination.IsFqdn() { return d.dialer.DialContext(ctx, network, destination) } @@ -82,6 +113,10 @@ func (d *resolveDialer) DialContext(ctx context.Context, network string, destina } func (d *resolveDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + err := d.initialize() + if err != nil { + return nil, err + } if !destination.IsFqdn() { return d.dialer.ListenPacket(ctx, destination) } @@ -106,6 +141,10 @@ func (d *resolveDialer) Upstream() any { } func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context, network string, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) { + err := d.initialize() + if err != nil { + return nil, err + } if !destination.IsFqdn() { return d.dialer.DialContext(ctx, network, destination) } @@ -125,6 +164,10 @@ func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context } func (d *resolveParallelNetworkDialer) ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, error) { + err := d.initialize() + if err != nil { + return nil, err + } if !destination.IsFqdn() { return d.dialer.ListenPacket(ctx, destination) } diff --git a/dns/transport_adapter.go b/dns/transport_adapter.go index 02c84621..47345709 100644 --- a/dns/transport_adapter.go +++ b/dns/transport_adapter.go @@ -27,9 +27,14 @@ func NewTransportAdapter(transportType string, transportTag string, dependencies } func NewTransportAdapterWithLocalOptions(transportType string, transportTag string, localOptions option.LocalDNSServerOptions) TransportAdapter { + var dependencies []string + if localOptions.DomainResolver != nil && localOptions.DomainResolver.Server != "" { + dependencies = append(dependencies, localOptions.DomainResolver.Server) + } return TransportAdapter{ transportType: transportType, transportTag: transportTag, + dependencies: dependencies, strategy: C.DomainStrategy(localOptions.LegacyStrategy), clientSubnet: localOptions.LegacyClientSubnet, } @@ -37,8 +42,11 @@ func NewTransportAdapterWithLocalOptions(transportType string, transportTag stri func NewTransportAdapterWithRemoteOptions(transportType string, transportTag string, remoteOptions option.RemoteDNSServerOptions) TransportAdapter { var dependencies []string - if remoteOptions.AddressResolver != "" { - dependencies = []string{remoteOptions.AddressResolver} + if remoteOptions.DomainResolver != nil && remoteOptions.DomainResolver.Server != "" { + dependencies = append(dependencies, remoteOptions.DomainResolver.Server) + } + if remoteOptions.LegacyAddressResolver != "" { + dependencies = append(dependencies, remoteOptions.LegacyAddressResolver) } return TransportAdapter{ transportType: transportType, diff --git a/dns/transport_dialer.go b/dns/transport_dialer.go index d9298b7f..5fe2949d 100644 --- a/dns/transport_dialer.go +++ b/dns/transport_dialer.go @@ -19,37 +19,30 @@ func NewLocalDialer(ctx context.Context, options option.LocalDNSServerOptions) ( if options.LegacyDefaultDialer { return dialer.NewDefaultOutbound(ctx), nil } else { - return dialer.New(ctx, options.DialerOptions, false) + return dialer.NewDNS(ctx, options.DialerOptions, false) } } func NewRemoteDialer(ctx context.Context, options option.RemoteDNSServerOptions) (N.Dialer, error) { - var ( - transportDialer N.Dialer - err error - ) if options.LegacyDefaultDialer { - transportDialer = dialer.NewDefaultOutbound(ctx) - } else { - transportDialer, err = dialer.New(ctx, options.DialerOptions, options.ServerIsDomain()) - } - if err != nil { - return nil, err - } - if options.AddressResolver != "" { - transport := service.FromContext[adapter.DNSTransportManager](ctx) - resolverTransport, loaded := transport.Transport(options.AddressResolver) - if !loaded { - return nil, E.New("address resolver not found: ", options.AddressResolver) + transportDialer := dialer.NewDefaultOutbound(ctx) + if options.LegacyAddressResolver != "" { + transport := service.FromContext[adapter.DNSTransportManager](ctx) + resolverTransport, loaded := transport.Transport(options.LegacyAddressResolver) + if !loaded { + return nil, E.New("address resolver not found: ", options.LegacyAddressResolver) + } + transportDialer = newTransportDialer(transportDialer, service.FromContext[adapter.DNSRouter](ctx), resolverTransport, C.DomainStrategy(options.LegacyAddressStrategy), time.Duration(options.LegacyAddressFallbackDelay)) + } else if options.ServerIsDomain() { + return nil, E.New("missing address resolver for server: ", options.Server) } - transportDialer = NewTransportDialer(transportDialer, service.FromContext[adapter.DNSRouter](ctx), resolverTransport, C.DomainStrategy(options.AddressStrategy), time.Duration(options.AddressFallbackDelay)) - } else if options.ServerIsDomain() { - return nil, E.New("missing address resolver for server: ", options.Server) + return transportDialer, nil + } else { + return dialer.NewDNS(ctx, options.DialerOptions, options.ServerIsDomain()) } - return transportDialer, nil } -type TransportDialer struct { +type legacyTransportDialer struct { dialer N.Dialer dnsRouter adapter.DNSRouter transport adapter.DNSTransport @@ -57,8 +50,8 @@ type TransportDialer struct { fallbackDelay time.Duration } -func NewTransportDialer(dialer N.Dialer, dnsRouter adapter.DNSRouter, transport adapter.DNSTransport, strategy C.DomainStrategy, fallbackDelay time.Duration) *TransportDialer { - return &TransportDialer{ +func newTransportDialer(dialer N.Dialer, dnsRouter adapter.DNSRouter, transport adapter.DNSTransport, strategy C.DomainStrategy, fallbackDelay time.Duration) *legacyTransportDialer { + return &legacyTransportDialer{ dialer, dnsRouter, transport, @@ -67,7 +60,7 @@ func NewTransportDialer(dialer N.Dialer, dnsRouter adapter.DNSRouter, transport } } -func (d *TransportDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { +func (d *legacyTransportDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { if destination.IsIP() { return d.dialer.DialContext(ctx, network, destination) } @@ -81,7 +74,7 @@ func (d *TransportDialer) DialContext(ctx context.Context, network string, desti return N.DialParallel(ctx, d.dialer, network, destination, addresses, d.strategy == C.DomainStrategyPreferIPv6, d.fallbackDelay) } -func (d *TransportDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { +func (d *legacyTransportDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { if destination.IsIP() { return d.dialer.ListenPacket(ctx, destination) } @@ -96,6 +89,6 @@ func (d *TransportDialer) ListenPacket(ctx context.Context, destination M.Socksa return conn, err } -func (d *TransportDialer) Upstream() any { +func (d *legacyTransportDialer) Upstream() any { return d.dialer } diff --git a/option/dns.go b/option/dns.go index 2ed765fc..662bd895 100644 --- a/option/dns.go +++ b/option/dns.go @@ -128,18 +128,34 @@ func (o *NewDNSServerOptions) Upgrade(ctx context.Context) error { } else { serverType = C.DNSTypeUDP } - remoteOptions := RemoteDNSServerOptions{ - LocalDNSServerOptions: LocalDNSServerOptions{ - DialerOptions: DialerOptions{ - Detour: options.Detour, + var remoteOptions RemoteDNSServerOptions + if options.Detour == "" { + remoteOptions = RemoteDNSServerOptions{ + LocalDNSServerOptions: LocalDNSServerOptions{ + LegacyStrategy: options.Strategy, + LegacyDefaultDialer: options.Detour == "", + LegacyClientSubnet: options.ClientSubnet.Build(netip.Prefix{}), }, - LegacyStrategy: options.Strategy, - LegacyDefaultDialer: options.Detour == "", - LegacyClientSubnet: options.ClientSubnet.Build(netip.Prefix{}), - }, - AddressResolver: options.AddressResolver, - AddressStrategy: options.AddressStrategy, - AddressFallbackDelay: options.AddressFallbackDelay, + LegacyAddressResolver: options.AddressResolver, + LegacyAddressStrategy: options.AddressStrategy, + LegacyAddressFallbackDelay: options.AddressFallbackDelay, + } + } else { + remoteOptions = RemoteDNSServerOptions{ + LocalDNSServerOptions: LocalDNSServerOptions{ + DialerOptions: DialerOptions{ + Detour: options.Detour, + DomainResolver: &DomainResolveOptions{ + Server: options.AddressResolver, + Strategy: options.AddressStrategy, + }, + FallbackDelay: options.AddressFallbackDelay, + }, + LegacyStrategy: options.Strategy, + LegacyDefaultDialer: options.Detour == "", + LegacyClientSubnet: options.ClientSubnet.Build(netip.Prefix{}), + }, + } } switch serverType { case C.DNSTypeUDP: @@ -274,9 +290,9 @@ type LocalDNSServerOptions struct { type RemoteDNSServerOptions struct { LocalDNSServerOptions ServerOptions - AddressResolver string `json:"address_resolver,omitempty"` - AddressStrategy DomainStrategy `json:"address_strategy,omitempty"` - AddressFallbackDelay badoption.Duration `json:"address_fallback_delay,omitempty"` + LegacyAddressResolver string `json:"-"` + LegacyAddressStrategy DomainStrategy `json:"-"` + LegacyAddressFallbackDelay badoption.Duration `json:"-"` } type RemoteTLSDNSServerOptions struct {