Add support for client-subnet DNS options

This commit is contained in:
世界 2024-02-09 18:37:25 +08:00
parent 3daa477514
commit 801b8ae8e3
No known key found for this signature in database
GPG Key ID: CD109927C34A63C4
11 changed files with 81 additions and 50 deletions

View File

@ -84,6 +84,7 @@ type DNSRule interface {
Rule Rule
DisableCache() bool DisableCache() bool
RewriteTTL() *uint32 RewriteTTL() *uint32
ClientSubnet() *netip.Addr
WithAddressLimit() bool WithAddressLimit() bool
MatchAddressLimit(metadata *InboundContext) bool MatchAddressLimit(metadata *InboundContext) bool
} }

View File

@ -9,9 +9,7 @@ import (
"github.com/sagernet/sing-dns" "github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/task" "github.com/sagernet/sing/common/task"
mDNS "github.com/miekg/dns" mDNS "github.com/miekg/dns"
@ -25,9 +23,11 @@ type LocalDNSTransport interface {
func RegisterLocalDNSTransport(transport LocalDNSTransport) { func RegisterLocalDNSTransport(transport LocalDNSTransport) {
if transport == nil { if transport == nil {
dns.RegisterTransport([]string{"local"}, dns.CreateLocalTransport) dns.RegisterTransport([]string{"local"}, func(options dns.TransportOptions) (dns.Transport, error) {
return dns.NewLocalTransport(options), nil
})
} else { } else {
dns.RegisterTransport([]string{"local"}, func(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (dns.Transport, error) { dns.RegisterTransport([]string{"local"}, func(options dns.TransportOptions) (dns.Transport, error) {
return &platformLocalDNSTransport{ return &platformLocalDNSTransport{
iif: transport, iif: transport,
}, nil }, nil

View File

@ -3,16 +3,12 @@
package include package include
import ( import (
"context"
"github.com/sagernet/sing-dns" "github.com/sagernet/sing-dns"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
N "github.com/sagernet/sing/common/network"
) )
func init() { func init() {
dns.RegisterTransport([]string{"dhcp"}, func(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (dns.Transport, error) { dns.RegisterTransport([]string{"dhcp"}, func(options dns.TransportOptions) (dns.Transport, error) {
return nil, E.New(`DHCP is not included in this build, rebuild with -tags with_dhcp`) return nil, E.New(`DHCP is not included in this build, rebuild with -tags with_dhcp`)
}) })
} }

View File

@ -11,13 +11,12 @@ import (
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/v2ray" "github.com/sagernet/sing-box/transport/v2ray"
"github.com/sagernet/sing-dns" "github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
) )
func init() { func init() {
dns.RegisterTransport([]string{"quic", "h3"}, func(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (dns.Transport, error) { dns.RegisterTransport([]string{"quic", "h3"}, func(options dns.TransportOptions) (dns.Transport, error) {
return nil, C.ErrQUICNotIncluded return nil, C.ErrQUICNotIncluded
}) })
v2ray.RegisterQUICConstructor( v2ray.RegisterQUICConstructor(

View File

@ -19,6 +19,7 @@ type DNSServerOptions struct {
AddressFallbackDelay Duration `json:"address_fallback_delay,omitempty"` AddressFallbackDelay Duration `json:"address_fallback_delay,omitempty"`
Strategy DomainStrategy `json:"strategy,omitempty"` Strategy DomainStrategy `json:"strategy,omitempty"`
Detour string `json:"detour,omitempty"` Detour string `json:"detour,omitempty"`
ClientSubnet *ListenAddress `json:"client_subnet,omitempty"`
} }
type DNSClientOptions struct { type DNSClientOptions struct {
@ -26,6 +27,7 @@ type DNSClientOptions struct {
DisableCache bool `json:"disable_cache,omitempty"` DisableCache bool `json:"disable_cache,omitempty"`
DisableExpire bool `json:"disable_expire,omitempty"` DisableExpire bool `json:"disable_expire,omitempty"`
IndependentCache bool `json:"independent_cache,omitempty"` IndependentCache bool `json:"independent_cache,omitempty"`
ClientSubnet *ListenAddress `json:"client_subnet,omitempty"`
} }
type DNSFakeIPOptions struct { type DNSFakeIPOptions struct {

View File

@ -100,6 +100,7 @@ type DefaultDNSRule struct {
Server string `json:"server,omitempty"` Server string `json:"server,omitempty"`
DisableCache bool `json:"disable_cache,omitempty"` DisableCache bool `json:"disable_cache,omitempty"`
RewriteTTL *uint32 `json:"rewrite_ttl,omitempty"` RewriteTTL *uint32 `json:"rewrite_ttl,omitempty"`
ClientSubnet *ListenAddress `json:"client_subnet,omitempty"`
} }
func (r DefaultDNSRule) IsValid() bool { func (r DefaultDNSRule) IsValid() bool {
@ -108,16 +109,18 @@ func (r DefaultDNSRule) IsValid() bool {
defaultValue.Server = r.Server defaultValue.Server = r.Server
defaultValue.DisableCache = r.DisableCache defaultValue.DisableCache = r.DisableCache
defaultValue.RewriteTTL = r.RewriteTTL defaultValue.RewriteTTL = r.RewriteTTL
defaultValue.ClientSubnet = r.ClientSubnet
return !reflect.DeepEqual(r, defaultValue) return !reflect.DeepEqual(r, defaultValue)
} }
type LogicalDNSRule struct { type LogicalDNSRule struct {
Mode string `json:"mode"` Mode string `json:"mode"`
Rules []DNSRule `json:"rules,omitempty"` Rules []DNSRule `json:"rules,omitempty"`
Invert bool `json:"invert,omitempty"` Invert bool `json:"invert,omitempty"`
Server string `json:"server,omitempty"` Server string `json:"server,omitempty"`
DisableCache bool `json:"disable_cache,omitempty"` DisableCache bool `json:"disable_cache,omitempty"`
RewriteTTL *uint32 `json:"rewrite_ttl,omitempty"` RewriteTTL *uint32 `json:"rewrite_ttl,omitempty"`
ClientSubnet *ListenAddress `json:"client_subnet,omitempty"`
} }
func (r LogicalDNSRule) IsValid() bool { func (r LogicalDNSRule) IsValid() bool {

View File

@ -222,7 +222,20 @@ func NewRouter(
return nil, E.New("parse dns server[", tag, "]: missing address_resolver") return nil, E.New("parse dns server[", tag, "]: missing address_resolver")
} }
} }
transport, err := dns.CreateTransport(tag, ctx, logFactory.NewLogger(F.ToString("dns/transport[", tag, "]")), detour, server.Address) var clientSubnet netip.Addr
if server.ClientSubnet != nil {
clientSubnet = server.ClientSubnet.Build()
} else if dnsOptions.ClientSubnet != nil {
clientSubnet = dnsOptions.ClientSubnet.Build()
}
transport, err := dns.CreateTransport(dns.TransportOptions{
Context: ctx,
Logger: logFactory.NewLogger(F.ToString("dns/transport[", tag, "]")),
Name: tag,
Dialer: detour,
Address: server.Address,
ClientSubnet: clientSubnet,
})
if err != nil { if err != nil {
return nil, E.Cause(err, "parse dns server[", tag, "]") return nil, E.Cause(err, "parse dns server[", tag, "]")
} }
@ -262,7 +275,11 @@ func NewRouter(
} }
if defaultTransport == nil { if defaultTransport == nil {
if len(transports) == 0 { if len(transports) == 0 {
transports = append(transports, dns.NewLocalTransport("local", N.SystemDialer)) transports = append(transports, common.Must1(dns.CreateTransport(dns.TransportOptions{
Context: ctx,
Name: "local",
Dialer: common.Must1(dialer.NewDefault(router, option.DialerOptions{})),
})))
} }
defaultTransport = transports[0] defaultTransport = transports[0]
} }

View File

@ -70,6 +70,9 @@ func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, index int) (con
if rewriteTTL := rule.RewriteTTL(); rewriteTTL != nil { if rewriteTTL := rule.RewriteTTL(); rewriteTTL != nil {
ctx = dns.ContextWithRewriteTTL(ctx, *rewriteTTL) ctx = dns.ContextWithRewriteTTL(ctx, *rewriteTTL)
} }
if clientSubnet := rule.ClientSubnet(); clientSubnet != nil {
ctx = dns.ContextWithClientSubnet(ctx, *clientSubnet)
}
if domainStrategy, dsLoaded := r.transportDomainStrategy[transport]; dsLoaded { if domainStrategy, dsLoaded := r.transportDomainStrategy[transport]; dsLoaded {
return ctx, transport, domainStrategy, rule, ruleIndex return ctx, transport, domainStrategy, rule, ruleIndex
} else { } else {

View File

@ -1,6 +1,8 @@
package route package route
import ( import (
"net/netip"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
@ -38,6 +40,7 @@ type DefaultDNSRule struct {
abstractDefaultRule abstractDefaultRule
disableCache bool disableCache bool
rewriteTTL *uint32 rewriteTTL *uint32
clientSubnet *netip.Addr
} }
func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options option.DefaultDNSRule) (*DefaultDNSRule, error) { func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options option.DefaultDNSRule) (*DefaultDNSRule, error) {
@ -48,6 +51,7 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options
}, },
disableCache: options.DisableCache, disableCache: options.DisableCache,
rewriteTTL: options.RewriteTTL, rewriteTTL: options.RewriteTTL,
clientSubnet: (*netip.Addr)(options.ClientSubnet),
} }
if len(options.Inbound) > 0 { if len(options.Inbound) > 0 {
item := NewInboundRule(options.Inbound) item := NewInboundRule(options.Inbound)
@ -230,6 +234,10 @@ func (r *DefaultDNSRule) RewriteTTL() *uint32 {
return r.rewriteTTL return r.rewriteTTL
} }
func (r *DefaultDNSRule) ClientSubnet() *netip.Addr {
return r.clientSubnet
}
func (r *DefaultDNSRule) WithAddressLimit() bool { func (r *DefaultDNSRule) WithAddressLimit() bool {
if len(r.destinationIPCIDRItems) > 0 { if len(r.destinationIPCIDRItems) > 0 {
return true return true
@ -264,6 +272,7 @@ type LogicalDNSRule struct {
abstractLogicalRule abstractLogicalRule
disableCache bool disableCache bool
rewriteTTL *uint32 rewriteTTL *uint32
clientSubnet *netip.Addr
} }
func NewLogicalDNSRule(router adapter.Router, logger log.ContextLogger, options option.LogicalDNSRule) (*LogicalDNSRule, error) { func NewLogicalDNSRule(router adapter.Router, logger log.ContextLogger, options option.LogicalDNSRule) (*LogicalDNSRule, error) {
@ -302,6 +311,10 @@ func (r *LogicalDNSRule) RewriteTTL() *uint32 {
return r.rewriteTTL return r.rewriteTTL
} }
func (r *LogicalDNSRule) ClientSubnet() *netip.Addr {
return r.clientSubnet
}
func (r *LogicalDNSRule) WithAddressLimit() bool { func (r *LogicalDNSRule) WithAddressLimit() bool {
for _, rawRule := range r.rules { for _, rawRule := range r.rules {
switch rule := rawRule.(type) { switch rule := rawRule.(type) {

View File

@ -21,9 +21,6 @@ import (
"github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/control" "github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/task" "github.com/sagernet/sing/common/task"
"github.com/sagernet/sing/common/x/list" "github.com/sagernet/sing/common/x/list"
@ -32,14 +29,14 @@ import (
) )
func init() { func init() {
dns.RegisterTransport([]string{"dhcp"}, NewTransport) dns.RegisterTransport([]string{"dhcp"}, func(options dns.TransportOptions) (dns.Transport, error) {
return NewTransport(options)
})
} }
type Transport struct { type Transport struct {
name string options dns.TransportOptions
ctx context.Context
router adapter.Router router adapter.Router
logger logger.Logger
interfaceName string interfaceName string
autoInterface bool autoInterface bool
interfaceCallback *list.Element[tun.DefaultInterfaceUpdateCallback] interfaceCallback *list.Element[tun.DefaultInterfaceUpdateCallback]
@ -48,23 +45,20 @@ type Transport struct {
updatedAt time.Time updatedAt time.Time
} }
func NewTransport(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (dns.Transport, error) { func NewTransport(options dns.TransportOptions) (*Transport, error) {
linkURL, err := url.Parse(link) linkURL, err := url.Parse(options.Address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if linkURL.Host == "" { if linkURL.Host == "" {
return nil, E.New("missing interface name for DHCP") return nil, E.New("missing interface name for DHCP")
} }
router := adapter.RouterFromContext(ctx) router := adapter.RouterFromContext(options.Context)
if router == nil { if router == nil {
return nil, E.New("missing router in context") return nil, E.New("missing router in context")
} }
transport := &Transport{ transport := &Transport{
name: name,
ctx: ctx,
router: router, router: router,
logger: logger,
interfaceName: linkURL.Host, interfaceName: linkURL.Host,
autoInterface: linkURL.Host == "auto", autoInterface: linkURL.Host == "auto",
} }
@ -72,7 +66,7 @@ func NewTransport(name string, ctx context.Context, logger logger.ContextLogger,
} }
func (t *Transport) Name() string { func (t *Transport) Name() string {
return t.name return t.options.Name
} }
func (t *Transport) Start() error { func (t *Transport) Start() error {
@ -158,8 +152,8 @@ func (t *Transport) updateServers() error {
return E.Cause(err, "dhcp: prepare interface") return E.Cause(err, "dhcp: prepare interface")
} }
t.logger.Info("dhcp: query DNS servers on ", iface.Name) t.options.Logger.Info("dhcp: query DNS servers on ", iface.Name)
fetchCtx, cancel := context.WithTimeout(t.ctx, C.DHCPTimeout) fetchCtx, cancel := context.WithTimeout(t.options.Context, C.DHCPTimeout)
err = t.fetchServers0(fetchCtx, iface) err = t.fetchServers0(fetchCtx, iface)
cancel() cancel()
if err != nil { if err != nil {
@ -175,7 +169,7 @@ func (t *Transport) updateServers() error {
func (t *Transport) interfaceUpdated(int) { func (t *Transport) interfaceUpdated(int) {
err := t.updateServers() err := t.updateServers()
if err != nil { if err != nil {
t.logger.Error("update servers: ", err) t.options.Logger.Error("update servers: ", err)
} }
} }
@ -187,7 +181,7 @@ func (t *Transport) fetchServers0(ctx context.Context, iface *net.Interface) err
if runtime.GOOS == "linux" || runtime.GOOS == "android" { if runtime.GOOS == "linux" || runtime.GOOS == "android" {
listenAddr = "255.255.255.255:68" listenAddr = "255.255.255.255:68"
} }
packetConn, err := listener.ListenPacket(t.ctx, "udp4", listenAddr) packetConn, err := listener.ListenPacket(t.options.Context, "udp4", listenAddr)
if err != nil { if err != nil {
return err return err
} }
@ -225,17 +219,17 @@ func (t *Transport) fetchServersResponse(iface *net.Interface, packetConn net.Pa
dhcpPacket, err := dhcpv4.FromBytes(buffer.Bytes()) dhcpPacket, err := dhcpv4.FromBytes(buffer.Bytes())
if err != nil { if err != nil {
t.logger.Trace("dhcp: parse DHCP response: ", err) t.options.Logger.Trace("dhcp: parse DHCP response: ", err)
return err return err
} }
if dhcpPacket.MessageType() != dhcpv4.MessageTypeOffer { if dhcpPacket.MessageType() != dhcpv4.MessageTypeOffer {
t.logger.Trace("dhcp: expected OFFER response, but got ", dhcpPacket.MessageType()) t.options.Logger.Trace("dhcp: expected OFFER response, but got ", dhcpPacket.MessageType())
continue continue
} }
if dhcpPacket.TransactionID != transactionID { if dhcpPacket.TransactionID != transactionID {
t.logger.Trace("dhcp: expected transaction ID ", transactionID, ", but got ", dhcpPacket.TransactionID) t.options.Logger.Trace("dhcp: expected transaction ID ", transactionID, ", but got ", dhcpPacket.TransactionID)
continue continue
} }
@ -255,20 +249,22 @@ func (t *Transport) fetchServersResponse(iface *net.Interface, packetConn net.Pa
func (t *Transport) recreateServers(iface *net.Interface, serverAddrs []netip.Addr) error { func (t *Transport) recreateServers(iface *net.Interface, serverAddrs []netip.Addr) error {
if len(serverAddrs) > 0 { if len(serverAddrs) > 0 {
t.logger.Info("dhcp: updated DNS servers from ", iface.Name, ": [", strings.Join(common.Map(serverAddrs, func(it netip.Addr) string { t.options.Logger.Info("dhcp: updated DNS servers from ", iface.Name, ": [", strings.Join(common.Map(serverAddrs, func(it netip.Addr) string {
return it.String() return it.String()
}), ","), "]") }), ","), "]")
} }
serverDialer := common.Must1(dialer.NewDefault(t.router, option.DialerOptions{ serverDialer := common.Must1(dialer.NewDefault(t.router, option.DialerOptions{
BindInterface: iface.Name, BindInterface: iface.Name,
UDPFragmentDefault: true, UDPFragmentDefault: true,
})) }))
var transports []dns.Transport var transports []dns.Transport
for _, serverAddr := range serverAddrs { for _, serverAddr := range serverAddrs {
serverTransport, err := dns.NewUDPTransport(t.name, t.ctx, serverDialer, M.Socksaddr{Addr: serverAddr, Port: 53}) newOptions := t.options
newOptions.Address = serverAddr.String()
newOptions.Dialer = serverDialer
serverTransport, err := dns.NewUDPTransport(newOptions)
if err != nil { if err != nil {
return err return E.Cause(err, "create UDP transport from DHCP result: ", serverAddr)
} }
transports = append(transports, serverTransport) transports = append(transports, serverTransport)
} }

View File

@ -9,7 +9,6 @@ import (
"github.com/sagernet/sing-dns" "github.com/sagernet/sing-dns"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
N "github.com/sagernet/sing/common/network"
mDNS "github.com/miekg/dns" mDNS "github.com/miekg/dns"
) )
@ -20,7 +19,9 @@ var (
) )
func init() { func init() {
dns.RegisterTransport([]string{"fakeip"}, NewTransport) dns.RegisterTransport([]string{"fakeip"}, func(options dns.TransportOptions) (dns.Transport, error) {
return NewTransport(options)
})
} }
type Transport struct { type Transport struct {
@ -30,15 +31,15 @@ type Transport struct {
logger logger.ContextLogger logger logger.ContextLogger
} }
func NewTransport(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (dns.Transport, error) { func NewTransport(options dns.TransportOptions) (*Transport, error) {
router := adapter.RouterFromContext(ctx) router := adapter.RouterFromContext(options.Context)
if router == nil { if router == nil {
return nil, E.New("missing router in context") return nil, E.New("missing router in context")
} }
return &Transport{ return &Transport{
name: name, name: options.Name,
router: router, router: router,
logger: logger, logger: options.Logger,
}, nil }, nil
} }