Fix domain resolver for DNS server

This commit is contained in:
世界 2025-01-31 19:16:29 +08:00
parent 89855ff548
commit 244243f206
No known key found for this signature in database
GPG Key ID: CD109927C34A63C4
7 changed files with 180 additions and 64 deletions

4
box.go
View File

@ -202,7 +202,7 @@ func New(options Options) (*Box, error) {
transportOptions.Options, transportOptions.Options,
) )
if err != nil { if err != nil {
return nil, E.Cause(err, "initialize inbound[", i, "]") return nil, E.Cause(err, "initialize DNS server[", i, "]")
} }
} }
err = dnsRouter.Initialize(dnsOptions.Rules) err = dnsRouter.Initialize(dnsOptions.Rules)
@ -225,7 +225,7 @@ func New(options Options) (*Box, error) {
endpointOptions.Options, endpointOptions.Options,
) )
if err != nil { if err != nil {
return nil, E.Cause(err, "initialize inbound[", i, "]") return nil, E.Cause(err, "initialize endpoint[", i, "]")
} }
} }
for i, inboundOptions := range options.Inbounds { for i, inboundOptions := range options.Inbounds {

View File

@ -29,14 +29,16 @@ func (d *DetourDialer) Start() error {
} }
func (d *DetourDialer) Dialer() (N.Dialer, error) { func (d *DetourDialer) Dialer() (N.Dialer, error) {
d.initOnce.Do(func() { d.initOnce.Do(d.init)
return d.dialer, d.initErr
}
func (d *DetourDialer) init() {
var loaded bool var loaded bool
d.dialer, loaded = d.outboundManager.Outbound(d.detour) d.dialer, loaded = d.outboundManager.Outbound(d.detour)
if !loaded { if !loaded {
d.initErr = E.New("outbound detour not found: ", d.detour) d.initErr = E.New("outbound detour not found: ", d.detour)
} }
})
return d.dialer, d.initErr
} }
func (d *DetourDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { func (d *DetourDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {

View File

@ -84,6 +84,60 @@ func New(ctx context.Context, options option.DialerOptions, remoteIsDomain bool)
ctx, ctx,
dialer, dialer,
options.Detour == "" && !options.TCPFastOpen, options.Detour == "" && !options.TCPFastOpen,
"",
dnsQueryOptions,
resolveFallbackDelay,
)
}
return dialer, nil
}
func NewDNS(ctx context.Context, options option.DialerOptions, remoteIsDomain bool) (N.Dialer, error) {
var (
dialer N.Dialer
err error
)
if options.Detour != "" {
outboundManager := service.FromContext[adapter.OutboundManager](ctx)
if outboundManager == nil {
return nil, E.New("missing outbound manager")
}
dialer = NewDetour(outboundManager, options.Detour)
} else {
dialer, err = NewDefault(ctx, options)
if err != nil {
return nil, err
}
}
if remoteIsDomain {
var (
dnsQueryOptions adapter.DNSQueryOptions
resolveFallbackDelay time.Duration
)
if options.DomainResolver == nil || options.DomainResolver.Server == "" {
return nil, E.New("missing domain resolver for domain server address")
}
var strategy C.DomainStrategy
if options.DomainResolver.Strategy != option.DomainStrategy(C.DomainStrategyAsIS) {
strategy = C.DomainStrategy(options.DomainResolver.Strategy)
} else if
//nolint:staticcheck
options.DomainStrategy != option.DomainStrategy(C.DomainStrategyAsIS) {
//nolint:staticcheck
strategy = C.DomainStrategy(options.DomainStrategy)
}
dnsQueryOptions = adapter.DNSQueryOptions{
Strategy: strategy,
DisableCache: options.DomainResolver.DisableCache,
RewriteTTL: options.DomainResolver.RewriteTTL,
ClientSubnet: options.DomainResolver.ClientSubnet.Build(netip.Prefix{}),
}
resolveFallbackDelay = time.Duration(options.FallbackDelay)
dialer = NewResolveDialer(
ctx,
dialer,
options.Detour == "" && !options.TCPFastOpen,
options.DomainResolver.Server,
dnsQueryOptions, dnsQueryOptions,
resolveFallbackDelay, resolveFallbackDelay,
) )

View File

@ -3,12 +3,14 @@ package dialer
import ( import (
"context" "context"
"net" "net"
"sync"
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service" "github.com/sagernet/sing/service"
@ -30,20 +32,26 @@ type ParallelInterfaceResolveDialer interface {
} }
type resolveDialer struct { type resolveDialer struct {
transport adapter.DNSTransportManager
router adapter.DNSRouter router adapter.DNSRouter
dialer N.Dialer dialer N.Dialer
parallel bool parallel bool
server string
initOnce sync.Once
initErr error
queryOptions adapter.DNSQueryOptions queryOptions adapter.DNSQueryOptions
fallbackDelay time.Duration fallbackDelay time.Duration
} }
func NewResolveDialer(ctx context.Context, dialer N.Dialer, parallel bool, queryOptions adapter.DNSQueryOptions, fallbackDelay time.Duration) ResolveDialer { func NewResolveDialer(ctx context.Context, dialer N.Dialer, parallel bool, server string, queryOptions adapter.DNSQueryOptions, fallbackDelay time.Duration) ResolveDialer {
return &resolveDialer{ return &resolveDialer{
service.FromContext[adapter.DNSRouter](ctx), transport: service.FromContext[adapter.DNSTransportManager](ctx),
dialer, router: service.FromContext[adapter.DNSRouter](ctx),
parallel, dialer: dialer,
queryOptions, parallel: parallel,
fallbackDelay, server: server,
queryOptions: queryOptions,
fallbackDelay: fallbackDelay,
} }
} }
@ -52,20 +60,43 @@ type resolveParallelNetworkDialer struct {
dialer ParallelInterfaceDialer dialer ParallelInterfaceDialer
} }
func NewResolveParallelInterfaceDialer(ctx context.Context, dialer ParallelInterfaceDialer, parallel bool, queryOptions adapter.DNSQueryOptions, fallbackDelay time.Duration) ParallelInterfaceResolveDialer { func NewResolveParallelInterfaceDialer(ctx context.Context, dialer ParallelInterfaceDialer, parallel bool, server string, queryOptions adapter.DNSQueryOptions, fallbackDelay time.Duration) ParallelInterfaceResolveDialer {
return &resolveParallelNetworkDialer{ return &resolveParallelNetworkDialer{
resolveDialer{ resolveDialer{
service.FromContext[adapter.DNSRouter](ctx), transport: service.FromContext[adapter.DNSTransportManager](ctx),
dialer, router: service.FromContext[adapter.DNSRouter](ctx),
parallel, dialer: dialer,
queryOptions, parallel: parallel,
fallbackDelay, server: server,
queryOptions: queryOptions,
fallbackDelay: fallbackDelay,
}, },
dialer, dialer,
} }
} }
func (d *resolveDialer) initialize() error {
d.initOnce.Do(d.initServer)
return d.initErr
}
func (d *resolveDialer) initServer() {
if d.server == "" {
return
}
transport, loaded := d.transport.Transport(d.server)
if !loaded {
d.initErr = E.New("domain resolver not found: " + d.server)
return
}
d.queryOptions.Transport = transport
}
func (d *resolveDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { func (d *resolveDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
err := d.initialize()
if err != nil {
return nil, err
}
if !destination.IsFqdn() { if !destination.IsFqdn() {
return d.dialer.DialContext(ctx, network, destination) return d.dialer.DialContext(ctx, network, destination)
} }
@ -82,6 +113,10 @@ func (d *resolveDialer) DialContext(ctx context.Context, network string, destina
} }
func (d *resolveDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (d *resolveDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
err := d.initialize()
if err != nil {
return nil, err
}
if !destination.IsFqdn() { if !destination.IsFqdn() {
return d.dialer.ListenPacket(ctx, destination) return d.dialer.ListenPacket(ctx, destination)
} }
@ -106,6 +141,10 @@ func (d *resolveDialer) Upstream() any {
} }
func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context, network string, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) { func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context, network string, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) {
err := d.initialize()
if err != nil {
return nil, err
}
if !destination.IsFqdn() { if !destination.IsFqdn() {
return d.dialer.DialContext(ctx, network, destination) return d.dialer.DialContext(ctx, network, destination)
} }
@ -125,6 +164,10 @@ func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context
} }
func (d *resolveParallelNetworkDialer) ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, error) { func (d *resolveParallelNetworkDialer) ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, error) {
err := d.initialize()
if err != nil {
return nil, err
}
if !destination.IsFqdn() { if !destination.IsFqdn() {
return d.dialer.ListenPacket(ctx, destination) return d.dialer.ListenPacket(ctx, destination)
} }

View File

@ -27,9 +27,14 @@ func NewTransportAdapter(transportType string, transportTag string, dependencies
} }
func NewTransportAdapterWithLocalOptions(transportType string, transportTag string, localOptions option.LocalDNSServerOptions) TransportAdapter { func NewTransportAdapterWithLocalOptions(transportType string, transportTag string, localOptions option.LocalDNSServerOptions) TransportAdapter {
var dependencies []string
if localOptions.DomainResolver != nil && localOptions.DomainResolver.Server != "" {
dependencies = append(dependencies, localOptions.DomainResolver.Server)
}
return TransportAdapter{ return TransportAdapter{
transportType: transportType, transportType: transportType,
transportTag: transportTag, transportTag: transportTag,
dependencies: dependencies,
strategy: C.DomainStrategy(localOptions.LegacyStrategy), strategy: C.DomainStrategy(localOptions.LegacyStrategy),
clientSubnet: localOptions.LegacyClientSubnet, clientSubnet: localOptions.LegacyClientSubnet,
} }
@ -37,8 +42,11 @@ func NewTransportAdapterWithLocalOptions(transportType string, transportTag stri
func NewTransportAdapterWithRemoteOptions(transportType string, transportTag string, remoteOptions option.RemoteDNSServerOptions) TransportAdapter { func NewTransportAdapterWithRemoteOptions(transportType string, transportTag string, remoteOptions option.RemoteDNSServerOptions) TransportAdapter {
var dependencies []string var dependencies []string
if remoteOptions.AddressResolver != "" { if remoteOptions.DomainResolver != nil && remoteOptions.DomainResolver.Server != "" {
dependencies = []string{remoteOptions.AddressResolver} dependencies = append(dependencies, remoteOptions.DomainResolver.Server)
}
if remoteOptions.LegacyAddressResolver != "" {
dependencies = append(dependencies, remoteOptions.LegacyAddressResolver)
} }
return TransportAdapter{ return TransportAdapter{
transportType: transportType, transportType: transportType,

View File

@ -19,37 +19,30 @@ func NewLocalDialer(ctx context.Context, options option.LocalDNSServerOptions) (
if options.LegacyDefaultDialer { if options.LegacyDefaultDialer {
return dialer.NewDefaultOutbound(ctx), nil return dialer.NewDefaultOutbound(ctx), nil
} else { } else {
return dialer.New(ctx, options.DialerOptions, false) return dialer.NewDNS(ctx, options.DialerOptions, false)
} }
} }
func NewRemoteDialer(ctx context.Context, options option.RemoteDNSServerOptions) (N.Dialer, error) { func NewRemoteDialer(ctx context.Context, options option.RemoteDNSServerOptions) (N.Dialer, error) {
var (
transportDialer N.Dialer
err error
)
if options.LegacyDefaultDialer { if options.LegacyDefaultDialer {
transportDialer = dialer.NewDefaultOutbound(ctx) transportDialer := dialer.NewDefaultOutbound(ctx)
} else { if options.LegacyAddressResolver != "" {
transportDialer, err = dialer.New(ctx, options.DialerOptions, options.ServerIsDomain())
}
if err != nil {
return nil, err
}
if options.AddressResolver != "" {
transport := service.FromContext[adapter.DNSTransportManager](ctx) transport := service.FromContext[adapter.DNSTransportManager](ctx)
resolverTransport, loaded := transport.Transport(options.AddressResolver) resolverTransport, loaded := transport.Transport(options.LegacyAddressResolver)
if !loaded { if !loaded {
return nil, E.New("address resolver not found: ", options.AddressResolver) return nil, E.New("address resolver not found: ", options.LegacyAddressResolver)
} }
transportDialer = NewTransportDialer(transportDialer, service.FromContext[adapter.DNSRouter](ctx), resolverTransport, C.DomainStrategy(options.AddressStrategy), time.Duration(options.AddressFallbackDelay)) transportDialer = newTransportDialer(transportDialer, service.FromContext[adapter.DNSRouter](ctx), resolverTransport, C.DomainStrategy(options.LegacyAddressStrategy), time.Duration(options.LegacyAddressFallbackDelay))
} else if options.ServerIsDomain() { } else if options.ServerIsDomain() {
return nil, E.New("missing address resolver for server: ", options.Server) return nil, E.New("missing address resolver for server: ", options.Server)
} }
return transportDialer, nil return transportDialer, nil
} else {
return dialer.NewDNS(ctx, options.DialerOptions, options.ServerIsDomain())
}
} }
type TransportDialer struct { type legacyTransportDialer struct {
dialer N.Dialer dialer N.Dialer
dnsRouter adapter.DNSRouter dnsRouter adapter.DNSRouter
transport adapter.DNSTransport transport adapter.DNSTransport
@ -57,8 +50,8 @@ type TransportDialer struct {
fallbackDelay time.Duration fallbackDelay time.Duration
} }
func NewTransportDialer(dialer N.Dialer, dnsRouter adapter.DNSRouter, transport adapter.DNSTransport, strategy C.DomainStrategy, fallbackDelay time.Duration) *TransportDialer { func newTransportDialer(dialer N.Dialer, dnsRouter adapter.DNSRouter, transport adapter.DNSTransport, strategy C.DomainStrategy, fallbackDelay time.Duration) *legacyTransportDialer {
return &TransportDialer{ return &legacyTransportDialer{
dialer, dialer,
dnsRouter, dnsRouter,
transport, transport,
@ -67,7 +60,7 @@ func NewTransportDialer(dialer N.Dialer, dnsRouter adapter.DNSRouter, transport
} }
} }
func (d *TransportDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { func (d *legacyTransportDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
if destination.IsIP() { if destination.IsIP() {
return d.dialer.DialContext(ctx, network, destination) return d.dialer.DialContext(ctx, network, destination)
} }
@ -81,7 +74,7 @@ func (d *TransportDialer) DialContext(ctx context.Context, network string, desti
return N.DialParallel(ctx, d.dialer, network, destination, addresses, d.strategy == C.DomainStrategyPreferIPv6, d.fallbackDelay) return N.DialParallel(ctx, d.dialer, network, destination, addresses, d.strategy == C.DomainStrategyPreferIPv6, d.fallbackDelay)
} }
func (d *TransportDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (d *legacyTransportDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
if destination.IsIP() { if destination.IsIP() {
return d.dialer.ListenPacket(ctx, destination) return d.dialer.ListenPacket(ctx, destination)
} }
@ -96,6 +89,6 @@ func (d *TransportDialer) ListenPacket(ctx context.Context, destination M.Socksa
return conn, err return conn, err
} }
func (d *TransportDialer) Upstream() any { func (d *legacyTransportDialer) Upstream() any {
return d.dialer return d.dialer
} }

View File

@ -128,18 +128,34 @@ func (o *NewDNSServerOptions) Upgrade(ctx context.Context) error {
} else { } else {
serverType = C.DNSTypeUDP serverType = C.DNSTypeUDP
} }
remoteOptions := RemoteDNSServerOptions{ var remoteOptions RemoteDNSServerOptions
if options.Detour == "" {
remoteOptions = RemoteDNSServerOptions{
LocalDNSServerOptions: LocalDNSServerOptions{
LegacyStrategy: options.Strategy,
LegacyDefaultDialer: options.Detour == "",
LegacyClientSubnet: options.ClientSubnet.Build(netip.Prefix{}),
},
LegacyAddressResolver: options.AddressResolver,
LegacyAddressStrategy: options.AddressStrategy,
LegacyAddressFallbackDelay: options.AddressFallbackDelay,
}
} else {
remoteOptions = RemoteDNSServerOptions{
LocalDNSServerOptions: LocalDNSServerOptions{ LocalDNSServerOptions: LocalDNSServerOptions{
DialerOptions: DialerOptions{ DialerOptions: DialerOptions{
Detour: options.Detour, Detour: options.Detour,
DomainResolver: &DomainResolveOptions{
Server: options.AddressResolver,
Strategy: options.AddressStrategy,
},
FallbackDelay: options.AddressFallbackDelay,
}, },
LegacyStrategy: options.Strategy, LegacyStrategy: options.Strategy,
LegacyDefaultDialer: options.Detour == "", LegacyDefaultDialer: options.Detour == "",
LegacyClientSubnet: options.ClientSubnet.Build(netip.Prefix{}), LegacyClientSubnet: options.ClientSubnet.Build(netip.Prefix{}),
}, },
AddressResolver: options.AddressResolver, }
AddressStrategy: options.AddressStrategy,
AddressFallbackDelay: options.AddressFallbackDelay,
} }
switch serverType { switch serverType {
case C.DNSTypeUDP: case C.DNSTypeUDP:
@ -274,9 +290,9 @@ type LocalDNSServerOptions struct {
type RemoteDNSServerOptions struct { type RemoteDNSServerOptions struct {
LocalDNSServerOptions LocalDNSServerOptions
ServerOptions ServerOptions
AddressResolver string `json:"address_resolver,omitempty"` LegacyAddressResolver string `json:"-"`
AddressStrategy DomainStrategy `json:"address_strategy,omitempty"` LegacyAddressStrategy DomainStrategy `json:"-"`
AddressFallbackDelay badoption.Duration `json:"address_fallback_delay,omitempty"` LegacyAddressFallbackDelay badoption.Duration `json:"-"`
} }
type RemoteTLSDNSServerOptions struct { type RemoteTLSDNSServerOptions struct {