diff --git a/dns/transport/hosts/hosts.go b/dns/transport/hosts/hosts.go index 773a1d2a..f13e85ae 100644 --- a/dns/transport/hosts/hosts.go +++ b/dns/transport/hosts/hosts.go @@ -2,6 +2,7 @@ package hosts import ( "context" + "net/netip" "os" "github.com/sagernet/sing-box/adapter" @@ -9,6 +10,8 @@ import ( "github.com/sagernet/sing-box/dns" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/json/badjson" + "github.com/sagernet/sing/common/json/badoption" "github.com/sagernet/sing/service/filemanager" mDNS "github.com/miekg/dns" @@ -22,7 +25,8 @@ var _ adapter.DNSTransport = (*Transport)(nil) type Transport struct { dns.TransportAdapter - files []*File + files []*File + predefined badjson.TypedMap[string, badoption.Listable[netip.Addr]] } func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, options option.HostsDNSServerOptions) (adapter.DNSTransport, error) { @@ -37,6 +41,7 @@ func NewTransport(ctx context.Context, logger log.ContextLogger, tag string, opt return &Transport{ TransportAdapter: dns.NewTransportAdapter(C.DNSTypeHosts, tag, nil), files: files, + predefined: options.Predefined, }, nil } @@ -47,6 +52,10 @@ func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, question := message.Question[0] domain := dns.FqdnToDomain(question.Name) if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA { + if addresses, ok := t.predefined.Get(domain); ok { + return dns.FixedResponse(message.Id, question, addresses, C.DefaultDNSTTL), nil + } + for _, file := range t.files { addresses := file.Lookup(domain) if len(addresses) > 0 { diff --git a/dns/transport/hosts/hosts_file.go b/dns/transport/hosts/hosts_file.go index 7ff34f69..84d7316c 100644 --- a/dns/transport/hosts/hosts_file.go +++ b/dns/transport/hosts/hosts_file.go @@ -9,8 +9,6 @@ import ( "strings" "sync" "time" - - "github.com/miekg/dns" ) const cacheMaxAge = 5 * time.Second @@ -91,8 +89,9 @@ func (f *File) update() { continue } for index := 1; index < len(fields); index++ { - canonicalName := dns.CanonicalName(fields[index]) - byName[canonicalName] = append(byName[canonicalName], addr) + // canonicalName := dns.CanonicalName(fields[index]) + domain := fields[index] + byName[domain] = append(byName[domain], addr) } } f.expire = now.Add(cacheMaxAge) diff --git a/dns/transport/hosts/hosts_test.go b/dns/transport/hosts/hosts_test.go index 944aa437..3ae160b7 100644 --- a/dns/transport/hosts/hosts_test.go +++ b/dns/transport/hosts/hosts_test.go @@ -11,6 +11,6 @@ import ( func TestHosts(t *testing.T) { t.Parallel() - require.Equal(t, []netip.Addr{netip.AddrFrom4([4]byte{127, 0, 0, 1}), netip.IPv6Loopback()}, hosts.NewFile("testdata/hosts").Lookup("localhost.")) - require.NotEmpty(t, hosts.NewFile(hosts.DefaultPath).Lookup("localhost.")) + require.Equal(t, []netip.Addr{netip.AddrFrom4([4]byte{127, 0, 0, 1}), netip.IPv6Loopback()}, hosts.NewFile("testdata/hosts").Lookup("localhost")) + require.NotEmpty(t, hosts.NewFile(hosts.DefaultPath).Lookup("localhost")) }