diff --git a/adapter/router.go b/adapter/router.go index 297d15e8..c39b5676 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -84,6 +84,7 @@ type DNSRule interface { Rule DisableCache() bool RewriteTTL() *uint32 + ClientSubnet() *netip.Addr WithAddressLimit() bool MatchAddressLimit(metadata *InboundContext) bool } diff --git a/experimental/libbox/dns.go b/experimental/libbox/dns.go index fcdaaa92..e1f8bcc3 100644 --- a/experimental/libbox/dns.go +++ b/experimental/libbox/dns.go @@ -9,9 +9,7 @@ import ( "github.com/sagernet/sing-dns" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/task" mDNS "github.com/miekg/dns" @@ -25,9 +23,11 @@ type LocalDNSTransport interface { func RegisterLocalDNSTransport(transport LocalDNSTransport) { if transport == nil { - dns.RegisterTransport([]string{"local"}, dns.CreateLocalTransport) + dns.RegisterTransport([]string{"local"}, func(options dns.TransportOptions) (dns.Transport, error) { + return dns.NewLocalTransport(options), nil + }) } else { - dns.RegisterTransport([]string{"local"}, func(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (dns.Transport, error) { + dns.RegisterTransport([]string{"local"}, func(options dns.TransportOptions) (dns.Transport, error) { return &platformLocalDNSTransport{ iif: transport, }, nil diff --git a/go.mod b/go.mod index 4b4868d1..9237f2b3 100644 --- a/go.mod +++ b/go.mod @@ -27,7 +27,7 @@ require ( github.com/sagernet/quic-go v0.40.1 github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 github.com/sagernet/sing v0.3.1-0.20240105061852-782bc05c5573 - github.com/sagernet/sing-dns v0.1.13-0.20240203102504-27e217be9060 + github.com/sagernet/sing-dns v0.1.13-0.20240209104932-6a377c9272fb github.com/sagernet/sing-mux v0.2.0 github.com/sagernet/sing-quic v0.1.8 github.com/sagernet/sing-shadowsocks v0.2.6 diff --git a/go.sum b/go.sum index 29ef3036..8f80c1e3 100644 --- a/go.sum +++ b/go.sum @@ -111,8 +111,8 @@ github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691/go.mod h1:B8lp4Wk github.com/sagernet/sing v0.2.18/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo= github.com/sagernet/sing v0.3.1-0.20240105061852-782bc05c5573 h1:1wGN3eNanp8r+Y3bNBys3ZnAVF5gdtDoDwtosMZEbgA= github.com/sagernet/sing v0.3.1-0.20240105061852-782bc05c5573/go.mod h1:9pfuAH6mZfgnz/YjP6xu5sxx882rfyjpcrTdUpd6w3g= -github.com/sagernet/sing-dns v0.1.13-0.20240203102504-27e217be9060 h1:ah78H3NjlBEov2MGAKC5Wtn71LhFfRatVrJ88PCQPjE= -github.com/sagernet/sing-dns v0.1.13-0.20240203102504-27e217be9060/go.mod h1:IxOqfSb6Zt6UVCy8fJpDxb2XxqzHUytNqeOuJfaiLu8= +github.com/sagernet/sing-dns v0.1.13-0.20240209104932-6a377c9272fb h1:jjlaJ9GMjTB4OoS8taQ1gh0oMBnWke72qSteaMzecPU= +github.com/sagernet/sing-dns v0.1.13-0.20240209104932-6a377c9272fb/go.mod h1:IxOqfSb6Zt6UVCy8fJpDxb2XxqzHUytNqeOuJfaiLu8= github.com/sagernet/sing-mux v0.2.0 h1:4C+vd8HztJCWNYfufvgL49xaOoOHXty2+EAjnzN3IYo= github.com/sagernet/sing-mux v0.2.0/go.mod h1:khzr9AOPocLa+g53dBplwNDz4gdsyx/YM3swtAhlkHQ= github.com/sagernet/sing-quic v0.1.8 h1:G4iBXAKIII+uTzd55oZ/9cAQswGjlvHh/0yKMQioDS0= diff --git a/include/dhcp_stub.go b/include/dhcp_stub.go index c57aa430..47a19d2e 100644 --- a/include/dhcp_stub.go +++ b/include/dhcp_stub.go @@ -3,16 +3,12 @@ package include import ( - "context" - "github.com/sagernet/sing-dns" E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/logger" - N "github.com/sagernet/sing/common/network" ) func init() { - dns.RegisterTransport([]string{"dhcp"}, func(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (dns.Transport, error) { + dns.RegisterTransport([]string{"dhcp"}, func(options dns.TransportOptions) (dns.Transport, error) { return nil, E.New(`DHCP is not included in this build, rebuild with -tags with_dhcp`) }) } diff --git a/option/dns.go b/option/dns.go index e0d237b7..15201343 100644 --- a/option/dns.go +++ b/option/dns.go @@ -19,6 +19,7 @@ type DNSServerOptions struct { AddressFallbackDelay Duration `json:"address_fallback_delay,omitempty"` Strategy DomainStrategy `json:"strategy,omitempty"` Detour string `json:"detour,omitempty"` + ClientSubnet *ListenAddress `json:"client_subnet,omitempty"` } type DNSClientOptions struct { @@ -26,6 +27,7 @@ type DNSClientOptions struct { DisableCache bool `json:"disable_cache,omitempty"` DisableExpire bool `json:"disable_expire,omitempty"` IndependentCache bool `json:"independent_cache,omitempty"` + ClientSubnet *ListenAddress `json:"client_subnet,omitempty"` } type DNSFakeIPOptions struct { diff --git a/option/rule_dns.go b/option/rule_dns.go index d148e264..dc5e5c2b 100644 --- a/option/rule_dns.go +++ b/option/rule_dns.go @@ -100,6 +100,7 @@ type DefaultDNSRule struct { Server string `json:"server,omitempty"` DisableCache bool `json:"disable_cache,omitempty"` RewriteTTL *uint32 `json:"rewrite_ttl,omitempty"` + ClientSubnet *ListenAddress `json:"client_subnet,omitempty"` } func (r DefaultDNSRule) IsValid() bool { @@ -108,16 +109,18 @@ func (r DefaultDNSRule) IsValid() bool { defaultValue.Server = r.Server defaultValue.DisableCache = r.DisableCache defaultValue.RewriteTTL = r.RewriteTTL + defaultValue.ClientSubnet = r.ClientSubnet return !reflect.DeepEqual(r, defaultValue) } type LogicalDNSRule struct { - Mode string `json:"mode"` - Rules []DNSRule `json:"rules,omitempty"` - Invert bool `json:"invert,omitempty"` - Server string `json:"server,omitempty"` - DisableCache bool `json:"disable_cache,omitempty"` - RewriteTTL *uint32 `json:"rewrite_ttl,omitempty"` + Mode string `json:"mode"` + Rules []DNSRule `json:"rules,omitempty"` + Invert bool `json:"invert,omitempty"` + Server string `json:"server,omitempty"` + DisableCache bool `json:"disable_cache,omitempty"` + RewriteTTL *uint32 `json:"rewrite_ttl,omitempty"` + ClientSubnet *ListenAddress `json:"client_subnet,omitempty"` } func (r LogicalDNSRule) IsValid() bool { diff --git a/route/router.go b/route/router.go index 7a7c535a..ed54383f 100644 --- a/route/router.go +++ b/route/router.go @@ -222,7 +222,20 @@ func NewRouter( return nil, E.New("parse dns server[", tag, "]: missing address_resolver") } } - transport, err := dns.CreateTransport(tag, ctx, logFactory.NewLogger(F.ToString("dns/transport[", tag, "]")), detour, server.Address) + var clientSubnet netip.Addr + if server.ClientSubnet != nil { + clientSubnet = server.ClientSubnet.Build() + } else if dnsOptions.ClientSubnet != nil { + clientSubnet = dnsOptions.ClientSubnet.Build() + } + transport, err := dns.CreateTransport(dns.TransportOptions{ + Context: ctx, + Logger: logFactory.NewLogger(F.ToString("dns/transport[", tag, "]")), + Name: tag, + Dialer: detour, + Address: server.Address, + ClientSubnet: clientSubnet, + }) if err != nil { return nil, E.Cause(err, "parse dns server[", tag, "]") } @@ -262,7 +275,11 @@ func NewRouter( } if defaultTransport == nil { if len(transports) == 0 { - transports = append(transports, dns.NewLocalTransport("local", N.SystemDialer)) + transports = append(transports, common.Must1(dns.CreateTransport(dns.TransportOptions{ + Context: ctx, + Name: "local", + Dialer: common.Must1(dialer.NewDefault(router, option.DialerOptions{})), + }))) } defaultTransport = transports[0] } diff --git a/route/router_dns.go b/route/router_dns.go index ee767e9e..7114882b 100644 --- a/route/router_dns.go +++ b/route/router_dns.go @@ -70,6 +70,9 @@ func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, index int) (con if rewriteTTL := rule.RewriteTTL(); rewriteTTL != nil { ctx = dns.ContextWithRewriteTTL(ctx, *rewriteTTL) } + if clientSubnet := rule.ClientSubnet(); clientSubnet != nil { + ctx = dns.ContextWithClientSubnet(ctx, *clientSubnet) + } if domainStrategy, dsLoaded := r.transportDomainStrategy[transport]; dsLoaded { return ctx, transport, domainStrategy, rule, ruleIndex } else { diff --git a/route/rule_dns.go b/route/rule_dns.go index 3eab61f8..153c9dcb 100644 --- a/route/rule_dns.go +++ b/route/rule_dns.go @@ -1,6 +1,8 @@ package route import ( + "net/netip" + "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" @@ -38,6 +40,7 @@ type DefaultDNSRule struct { abstractDefaultRule disableCache bool rewriteTTL *uint32 + clientSubnet *netip.Addr } func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options option.DefaultDNSRule) (*DefaultDNSRule, error) { @@ -48,6 +51,7 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options }, disableCache: options.DisableCache, rewriteTTL: options.RewriteTTL, + clientSubnet: (*netip.Addr)(options.ClientSubnet), } if len(options.Inbound) > 0 { item := NewInboundRule(options.Inbound) @@ -230,6 +234,10 @@ func (r *DefaultDNSRule) RewriteTTL() *uint32 { return r.rewriteTTL } +func (r *DefaultDNSRule) ClientSubnet() *netip.Addr { + return r.clientSubnet +} + func (r *DefaultDNSRule) WithAddressLimit() bool { if len(r.destinationIPCIDRItems) > 0 { return true @@ -264,6 +272,7 @@ type LogicalDNSRule struct { abstractLogicalRule disableCache bool rewriteTTL *uint32 + clientSubnet *netip.Addr } func NewLogicalDNSRule(router adapter.Router, logger log.ContextLogger, options option.LogicalDNSRule) (*LogicalDNSRule, error) { @@ -302,6 +311,10 @@ func (r *LogicalDNSRule) RewriteTTL() *uint32 { return r.rewriteTTL } +func (r *LogicalDNSRule) ClientSubnet() *netip.Addr { + return r.clientSubnet +} + func (r *LogicalDNSRule) WithAddressLimit() bool { for _, rawRule := range r.rules { switch rule := rawRule.(type) { diff --git a/transport/dhcp/server.go b/transport/dhcp/server.go index 1a2c2938..2b7346c6 100644 --- a/transport/dhcp/server.go +++ b/transport/dhcp/server.go @@ -21,9 +21,6 @@ import ( "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/control" E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/logger" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/task" "github.com/sagernet/sing/common/x/list" @@ -32,14 +29,14 @@ import ( ) func init() { - dns.RegisterTransport([]string{"dhcp"}, NewTransport) + dns.RegisterTransport([]string{"dhcp"}, func(options dns.TransportOptions) (dns.Transport, error) { + return NewTransport(options) + }) } type Transport struct { - name string - ctx context.Context + options dns.TransportOptions router adapter.Router - logger logger.Logger interfaceName string autoInterface bool interfaceCallback *list.Element[tun.DefaultInterfaceUpdateCallback] @@ -48,23 +45,20 @@ type Transport struct { updatedAt time.Time } -func NewTransport(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (dns.Transport, error) { - linkURL, err := url.Parse(link) +func NewTransport(options dns.TransportOptions) (*Transport, error) { + linkURL, err := url.Parse(options.Address) if err != nil { return nil, err } if linkURL.Host == "" { return nil, E.New("missing interface name for DHCP") } - router := adapter.RouterFromContext(ctx) + router := adapter.RouterFromContext(options.Context) if router == nil { return nil, E.New("missing router in context") } transport := &Transport{ - name: name, - ctx: ctx, router: router, - logger: logger, interfaceName: linkURL.Host, autoInterface: linkURL.Host == "auto", } @@ -72,7 +66,7 @@ func NewTransport(name string, ctx context.Context, logger logger.ContextLogger, } func (t *Transport) Name() string { - return t.name + return t.options.Name } func (t *Transport) Start() error { @@ -158,8 +152,8 @@ func (t *Transport) updateServers() error { return E.Cause(err, "dhcp: prepare interface") } - t.logger.Info("dhcp: query DNS servers on ", iface.Name) - fetchCtx, cancel := context.WithTimeout(t.ctx, C.DHCPTimeout) + t.options.Logger.Info("dhcp: query DNS servers on ", iface.Name) + fetchCtx, cancel := context.WithTimeout(t.options.Context, C.DHCPTimeout) err = t.fetchServers0(fetchCtx, iface) cancel() if err != nil { @@ -175,7 +169,7 @@ func (t *Transport) updateServers() error { func (t *Transport) interfaceUpdated(int) { err := t.updateServers() if err != nil { - t.logger.Error("update servers: ", err) + t.options.Logger.Error("update servers: ", err) } } @@ -187,7 +181,7 @@ func (t *Transport) fetchServers0(ctx context.Context, iface *net.Interface) err if runtime.GOOS == "linux" || runtime.GOOS == "android" { listenAddr = "255.255.255.255:68" } - packetConn, err := listener.ListenPacket(t.ctx, "udp4", listenAddr) + packetConn, err := listener.ListenPacket(t.options.Context, "udp4", listenAddr) if err != nil { return err } @@ -225,17 +219,17 @@ func (t *Transport) fetchServersResponse(iface *net.Interface, packetConn net.Pa dhcpPacket, err := dhcpv4.FromBytes(buffer.Bytes()) if err != nil { - t.logger.Trace("dhcp: parse DHCP response: ", err) + t.options.Logger.Trace("dhcp: parse DHCP response: ", err) return err } if dhcpPacket.MessageType() != dhcpv4.MessageTypeOffer { - t.logger.Trace("dhcp: expected OFFER response, but got ", dhcpPacket.MessageType()) + t.options.Logger.Trace("dhcp: expected OFFER response, but got ", dhcpPacket.MessageType()) continue } if dhcpPacket.TransactionID != transactionID { - t.logger.Trace("dhcp: expected transaction ID ", transactionID, ", but got ", dhcpPacket.TransactionID) + t.options.Logger.Trace("dhcp: expected transaction ID ", transactionID, ", but got ", dhcpPacket.TransactionID) continue } @@ -255,20 +249,22 @@ func (t *Transport) fetchServersResponse(iface *net.Interface, packetConn net.Pa func (t *Transport) recreateServers(iface *net.Interface, serverAddrs []netip.Addr) error { if len(serverAddrs) > 0 { - t.logger.Info("dhcp: updated DNS servers from ", iface.Name, ": [", strings.Join(common.Map(serverAddrs, func(it netip.Addr) string { + t.options.Logger.Info("dhcp: updated DNS servers from ", iface.Name, ": [", strings.Join(common.Map(serverAddrs, func(it netip.Addr) string { return it.String() }), ","), "]") } - serverDialer := common.Must1(dialer.NewDefault(t.router, option.DialerOptions{ BindInterface: iface.Name, UDPFragmentDefault: true, })) var transports []dns.Transport for _, serverAddr := range serverAddrs { - serverTransport, err := dns.NewUDPTransport(t.name, t.ctx, serverDialer, M.Socksaddr{Addr: serverAddr, Port: 53}) + newOptions := t.options + newOptions.Address = serverAddr.String() + newOptions.Dialer = serverDialer + serverTransport, err := dns.NewUDPTransport(newOptions) if err != nil { - return err + return E.Cause(err, "create UDP transport from DHCP result: ", serverAddr) } transports = append(transports, serverTransport) } diff --git a/transport/fakeip/server.go b/transport/fakeip/server.go index 40149aa4..5e0c7eef 100644 --- a/transport/fakeip/server.go +++ b/transport/fakeip/server.go @@ -9,7 +9,6 @@ import ( "github.com/sagernet/sing-dns" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" - N "github.com/sagernet/sing/common/network" mDNS "github.com/miekg/dns" ) @@ -20,7 +19,9 @@ var ( ) func init() { - dns.RegisterTransport([]string{"fakeip"}, NewTransport) + dns.RegisterTransport([]string{"fakeip"}, func(options dns.TransportOptions) (dns.Transport, error) { + return NewTransport(options) + }) } type Transport struct { @@ -30,15 +31,15 @@ type Transport struct { logger logger.ContextLogger } -func NewTransport(name string, ctx context.Context, logger logger.ContextLogger, dialer N.Dialer, link string) (dns.Transport, error) { - router := adapter.RouterFromContext(ctx) +func NewTransport(options dns.TransportOptions) (*Transport, error) { + router := adapter.RouterFromContext(options.Context) if router == nil { return nil, E.New("missing router in context") } return &Transport{ - name: name, + name: options.Name, router: router, - logger: logger, + logger: options.Logger, }, nil }