Compare commits

...

5 Commits

Author SHA1 Message Date
世界
d8191a1ee6
Treat requests with OPT extra but no options as simple requests 2025-09-05 16:28:29 +08:00
世界
c6396b8a59
Fix DNS packet size 2025-09-05 16:17:49 +08:00
世界
d78cdac640
Fix DNS client 2025-09-05 15:49:06 +08:00
世界
531beeeed7
Fix DNS cache 2025-09-05 15:20:49 +08:00
世界
f352f84483
Fix read address 2025-09-05 15:16:14 +08:00
14 changed files with 106 additions and 86 deletions

View File

@ -78,8 +78,8 @@ func (w *myUpstreamHandlerWrapper) NewError(ctx context.Context, err error) {
// Deprecated: removed
func UpstreamMetadata(metadata InboundContext) M.Metadata {
return M.Metadata{
Source: metadata.Source,
Destination: metadata.Destination,
Source: metadata.Source.Unwrap(),
Destination: metadata.Destination.Unwrap(),
}
}

View File

@ -2,12 +2,14 @@ package dns
import (
"context"
"errors"
"net"
"net/netip"
"strings"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/compatible"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
@ -17,7 +19,7 @@ import (
"github.com/sagernet/sing/contrab/freelru"
"github.com/sagernet/sing/contrab/maphash"
dns "github.com/miekg/dns"
"github.com/miekg/dns"
)
var (
@ -30,16 +32,18 @@ var (
var _ adapter.DNSClient = (*Client)(nil)
type Client struct {
timeout time.Duration
disableCache bool
disableExpire bool
independentCache bool
clientSubnet netip.Prefix
rdrc adapter.RDRCStore
initRDRCFunc func() adapter.RDRCStore
logger logger.ContextLogger
cache freelru.Cache[dns.Question, *dns.Msg]
transportCache freelru.Cache[transportCacheKey, *dns.Msg]
timeout time.Duration
disableCache bool
disableExpire bool
independentCache bool
clientSubnet netip.Prefix
rdrc adapter.RDRCStore
initRDRCFunc func() adapter.RDRCStore
logger logger.ContextLogger
cache freelru.Cache[dns.Question, *dns.Msg]
cacheLock compatible.Map[dns.Question, chan struct{}]
transportCache freelru.Cache[transportCacheKey, *dns.Msg]
transportCacheLock compatible.Map[dns.Question, chan struct{}]
}
type ClientOptions struct {
@ -96,17 +100,15 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
if c.logger != nil {
c.logger.WarnContext(ctx, "bad question size: ", len(message.Question))
}
responseMessage := dns.Msg{
MsgHdr: dns.MsgHdr{
Id: message.Id,
Response: true,
Rcode: dns.RcodeFormatError,
},
Question: message.Question,
}
return &responseMessage, nil
return FixedResponseStatus(message, dns.RcodeFormatError), nil
}
question := message.Question[0]
if question.Qtype == dns.TypeA && options.Strategy == C.DomainStrategyIPv6Only || question.Qtype == dns.TypeAAAA && options.Strategy == C.DomainStrategyIPv4Only {
if c.logger != nil {
c.logger.DebugContext(ctx, "strategy rejected")
}
return FixedResponseStatus(message, dns.RcodeSuccess), nil
}
clientSubnet := options.ClientSubnet
if !clientSubnet.IsValid() {
clientSubnet = c.clientSubnet
@ -114,12 +116,38 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
if clientSubnet.IsValid() {
message = SetClientSubnet(message, clientSubnet)
}
isSimpleRequest := len(message.Question) == 1 &&
len(message.Ns) == 0 &&
len(message.Extra) == 0 &&
(len(message.Extra) == 0 || len(message.Extra) == 1 &&
message.Extra[0].Header().Rrtype == dns.TypeOPT &&
message.Extra[0].Header().Class > 0 &&
message.Extra[0].Header().Ttl == 0 &&
len(message.Extra[0].(*dns.OPT).Option) == 0) &&
!options.ClientSubnet.IsValid()
disableCache := !isSimpleRequest || c.disableCache || options.DisableCache
if !disableCache {
if c.cache != nil {
cond, loaded := c.cacheLock.LoadOrStore(question, make(chan struct{}))
if loaded {
<-cond
} else {
defer func() {
c.cacheLock.Delete(question)
close(cond)
}()
}
} else if c.transportCache != nil {
cond, loaded := c.transportCacheLock.LoadOrStore(question, make(chan struct{}))
if loaded {
<-cond
} else {
defer func() {
c.transportCacheLock.Delete(question)
close(cond)
}()
}
}
response, ttl := c.loadResponse(question, transport)
if response != nil {
logCachedResponse(c.logger, ctx, response, ttl)
@ -127,27 +155,14 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
return response, nil
}
}
if question.Qtype == dns.TypeA && options.Strategy == C.DomainStrategyIPv6Only || question.Qtype == dns.TypeAAAA && options.Strategy == C.DomainStrategyIPv4Only {
responseMessage := dns.Msg{
MsgHdr: dns.MsgHdr{
Id: message.Id,
Response: true,
Rcode: dns.RcodeSuccess,
},
Question: []dns.Question{question},
}
if c.logger != nil {
c.logger.DebugContext(ctx, "strategy rejected")
}
return &responseMessage, nil
}
messageId := message.Id
contextTransport, clientSubnetLoaded := transportTagFromContext(ctx)
if clientSubnetLoaded && transport.Tag() == contextTransport {
return nil, E.New("DNS query loopback in transport[", contextTransport, "]")
}
ctx = contextWithTransportTag(ctx, transport.Tag())
if responseChecker != nil && c.rdrc != nil {
if !disableCache && responseChecker != nil && c.rdrc != nil {
rejected := c.rdrc.LoadRDRC(transport.Tag(), question.Name, question.Qtype)
if rejected {
return nil, ErrResponseRejectedCached
@ -157,7 +172,12 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
response, err := transport.Exchange(ctx, message)
cancel()
if err != nil {
return nil, err
var rcodeError RcodeError
if errors.As(err, &rcodeError) {
response = FixedResponseStatus(message, int(rcodeError))
} else {
return nil, err
}
}
/*if question.Qtype == dns.TypeA || question.Qtype == dns.TypeAAAA {
validResponse := response
@ -196,13 +216,14 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
}*/
if responseChecker != nil {
var rejected bool
if !(response.Rcode == dns.RcodeSuccess || response.Rcode == dns.RcodeNameError) {
// TODO: add accept_any rule and support to check response instead of addresses
if response.Rcode != dns.RcodeSuccess || len(response.Answer) == 0 {
rejected = true
} else {
rejected = !responseChecker(MessageToAddresses(response))
}
if rejected {
if c.rdrc != nil {
if !disableCache && c.rdrc != nil {
c.rdrc.SaveRDRCAsync(transport.Tag(), question.Name, question.Qtype, c.logger)
}
logRejectedResponse(c.logger, ctx, response)
@ -305,8 +326,7 @@ func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, dom
func (c *Client) ClearCache() {
if c.cache != nil {
c.cache.Purge()
}
if c.transportCache != nil {
} else if c.transportCache != nil {
c.transportCache.Purge()
}
}
@ -320,36 +340,36 @@ func (c *Client) LookupCache(domain string, strategy C.DomainStrategy) ([]netip.
}
dnsName := dns.Fqdn(domain)
if strategy == C.DomainStrategyIPv4Only {
response, err := c.questionCache(dns.Question{
addresses, err := c.questionCache(dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}, nil)
if err != ErrNotCached {
return response, true
return addresses, true
}
} else if strategy == C.DomainStrategyIPv6Only {
response, err := c.questionCache(dns.Question{
addresses, err := c.questionCache(dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}, nil)
if err != ErrNotCached {
return response, true
return addresses, true
}
} else {
response4, _ := c.questionCache(dns.Question{
response4, _ := c.loadResponse(dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}, nil)
response6, _ := c.questionCache(dns.Question{
response6, _ := c.loadResponse(dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}, nil)
if len(response4) > 0 || len(response6) > 0 {
return sortAddresses(response4, response6, strategy), true
if response4 != nil || response6 != nil {
return sortAddresses(MessageToAddresses(response4), MessageToAddresses(response6), strategy), true
}
}
return nil, false
@ -390,15 +410,15 @@ func (c *Client) storeCache(transport adapter.DNSTransport, question dns.Questio
transportTag: transport.Tag(),
}, message)
}
return
}
if !c.independentCache {
c.cache.AddWithLifetime(question, message, time.Second*time.Duration(timeToLive))
} else {
c.transportCache.AddWithLifetime(transportCacheKey{
Question: question,
transportTag: transport.Tag(),
}, message, time.Second*time.Duration(timeToLive))
if !c.independentCache {
c.cache.AddWithLifetime(question, message, time.Second*time.Duration(timeToLive))
} else {
c.transportCache.AddWithLifetime(transportCacheKey{
Question: question,
transportTag: transport.Tag(),
}, message, time.Second*time.Duration(timeToLive))
}
}
}
@ -517,6 +537,9 @@ func (c *Client) loadResponse(question dns.Question, transport adapter.DNSTransp
}
func MessageToAddresses(response *dns.Msg) []netip.Addr {
if response.Rcode != dns.RcodeSuccess {
return nil
}
addresses := make([]netip.Addr, 0, len(response.Answer))
for _, rawAnswer := range response.Answer {
switch answer := rawAnswer.(type) {
@ -561,9 +584,12 @@ func transportTagFromContext(ctx context.Context) (string, bool) {
func FixedResponseStatus(message *dns.Msg, rcode int) *dns.Msg {
return &dns.Msg{
MsgHdr: dns.MsgHdr{
Id: message.Id,
Rcode: rcode,
Response: true,
Id: message.Id,
Response: true,
Authoritative: true,
RecursionDesired: true,
RecursionAvailable: true,
Rcode: rcode,
},
Question: message.Question,
}

View File

@ -16,11 +16,6 @@ import (
mDNS "github.com/miekg/dns"
)
const (
// net.maxDNSPacketSize
maxDNSPacketSize = 1232
)
func (t *Transport) exchangeSingleRequest(ctx context.Context, servers []M.Socksaddr, message *mDNS.Msg, domain string) (*mDNS.Msg, error) {
var lastErr error
for _, fqdn := range t.nameList(domain) {
@ -118,7 +113,7 @@ func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, questio
Question: []mDNS.Question{question},
Compress: true,
}
request.SetEdns0(maxDNSPacketSize, false)
request.SetEdns0(buf.UDPBufferSize, false)
buffer := buf.Get(buf.UDPBufferSize)
defer buf.Put(buffer)
for _, network := range networks {

View File

@ -2,7 +2,9 @@ package local
import (
"context"
"errors"
"math/rand"
"syscall"
"time"
"github.com/sagernet/sing-box/adapter"
@ -164,7 +166,7 @@ func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, questio
Question: []mDNS.Question{question},
Compress: true,
}
request.SetEdns0(maxDNSPacketSize, false)
request.SetEdns0(buf.UDPBufferSize, false)
buffer := buf.Get(buf.UDPBufferSize)
defer buf.Put(buffer)
for _, network := range networks {
@ -184,6 +186,9 @@ func (t *Transport) exchangeOne(ctx context.Context, server M.Socksaddr, questio
}
_, err = conn.Write(rawMessage)
if err != nil {
if errors.Is(err, syscall.EMSGSIZE) && network == N.NetworkUDP {
continue
}
return nil, E.Cause(err, "write request")
}
n, err := conn.Read(buffer)

View File

@ -10,11 +10,6 @@ import (
"time"
)
const (
// net.maxDNSPacketSize
maxDNSPacketSize = 1232
)
type resolverConfig struct {
initOnce sync.Once
ch chan struct{}

View File

@ -19,7 +19,6 @@ func NewLocalDialer(ctx context.Context, options option.LocalDNSServerOptions) (
if options.LegacyDefaultDialer {
return dialer.NewDefaultOutbound(ctx), nil
} else {
options.UDPFragmentDefault = true
return dialer.NewWithOptions(dialer.Options{
Context: ctx,
Options: options.DialerOptions,

View File

@ -6,8 +6,8 @@ import (
"sync/atomic"
"time"
"github.com/sagernet/sing-box/common/compatible"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental/clashapi/compatible"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/json"
"github.com/sagernet/sing/common/x/list"

4
go.mod
View File

@ -27,9 +27,9 @@ require (
github.com/sagernet/gomobile v0.1.8
github.com/sagernet/gvisor v0.0.0-20250325023245-7a9c0f5725fb
github.com/sagernet/quic-go v0.52.0-beta.1
github.com/sagernet/sing v0.7.6
github.com/sagernet/sing v0.7.7
github.com/sagernet/sing-mux v0.3.3
github.com/sagernet/sing-quic v0.5.0
github.com/sagernet/sing-quic v0.5.1
github.com/sagernet/sing-shadowsocks v0.2.8
github.com/sagernet/sing-shadowsocks2 v0.2.1
github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11

8
go.sum
View File

@ -167,12 +167,12 @@ github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/l
github.com/sagernet/quic-go v0.52.0-beta.1 h1:hWkojLg64zjV+MJOvJU/kOeWndm3tiEfBLx5foisszs=
github.com/sagernet/quic-go v0.52.0-beta.1/go.mod h1:OV+V5kEBb8kJS7k29MzDu6oj9GyMc7HA07sE1tedxz4=
github.com/sagernet/sing v0.6.9/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/sagernet/sing v0.7.6 h1:6LBfDH+aI/26J3r9UHlaxTNjJeMhBpU/wrk0JKDZYI4=
github.com/sagernet/sing v0.7.6/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/sagernet/sing v0.7.7 h1:o46FzVZS+wKbBMEkMEdEHoVZxyM9jvfRpKXc7pEgS/c=
github.com/sagernet/sing v0.7.7/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/sagernet/sing-mux v0.3.3 h1:YFgt9plMWzH994BMZLmyKL37PdIVaIilwP0Jg+EcLfw=
github.com/sagernet/sing-mux v0.3.3/go.mod h1:pht8iFY4c9Xltj7rhVd208npkNaeCxzyXCgulDPLUDA=
github.com/sagernet/sing-quic v0.5.0 h1:jNLIyVk24lFPvu8A4x+ZNEnZdI+Tg1rp7eCJ6v0Csak=
github.com/sagernet/sing-quic v0.5.0/go.mod h1:SAv/qdeDN+75msGG5U5ZIwG+3Ua50jVIKNrRSY8pkx0=
github.com/sagernet/sing-quic v0.5.1 h1:o+mX/schfy6fbbU2rnb6ouUYOL+iUBjA4jOZqyIvDsU=
github.com/sagernet/sing-quic v0.5.1/go.mod h1:evP1e++ZG8TJHVV5HudXV4vWeYzGfCdF4HwSJZcdqkI=
github.com/sagernet/sing-shadowsocks v0.2.8 h1:PURj5PRoAkqeHh2ZW205RWzN9E9RtKCVCzByXruQWfE=
github.com/sagernet/sing-shadowsocks v0.2.8/go.mod h1:lo7TWEMDcN5/h5B8S0ew+r78ZODn6SwVaFhvB6H+PTI=
github.com/sagernet/sing-shadowsocks2 v0.2.1 h1:dWV9OXCeFPuYGHb6IRqlSptVnSzOelnqqs2gQ2/Qioo=

View File

@ -124,7 +124,7 @@ func (h *inboundHandler) NewConnectionEx(ctx context.Context, conn net.Conn, sou
//nolint:staticcheck
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
metadata.Source = source
metadata.Destination = destination
metadata.Destination = destination.Unwrap()
if userName, _ := auth.UserFromContext[string](ctx); userName != "" {
metadata.User = userName
h.logger.InfoContext(ctx, "[", userName, "] inbound connection to ", metadata.Destination)

View File

@ -170,7 +170,7 @@ func (n *Inbound) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
hostPort = request.Host
}
source := sHttp.SourceAddress(request)
destination := M.ParseSocksaddr(hostPort)
destination := M.ParseSocksaddr(hostPort).Unwrap()
if hijacker, isHijacker := writer.(http.Hijacker); isHijacker {
conn, _, err := hijacker.Hijack()

View File

@ -292,7 +292,7 @@ func ReadPacket(conn net.Conn, buffer *buf.Buffer) (M.Socksaddr, error) {
}
_, err = buffer.ReadFullFrom(conn, int(length))
return destination.Unwrap(), err
return destination, err
}
func WritePacket(conn net.Conn, buffer *buf.Buffer, destination M.Socksaddr) error {

View File

@ -138,7 +138,7 @@ func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint)
b := packets[0]
common.ClearArray(b[1:4])
}
eps[0] = remoteEndpoint(M.AddrPortFromNet(addr))
eps[0] = remoteEndpoint(M.SocksaddrFromNet(addr).Unwrap().AddrPort())
count = 1
return
}