mirror of
synced 2025-03-10 12:51:49 +07:00
558 lines
16 KiB
558 lines
16 KiB
/*
 * SPDX-License-Identifier: AGPL-3.0-only
 * Copyright (c) 2023, v2rayA Organization <team@v2raya.org>
 */
package control
import (
const (
	// MaxDnsLookupDepth bounds how many times dialSend may hand one query
	// over to another upstream before it fails with a "too deep" error.
	MaxDnsLookupDepth = 3
)
var (
	// SuspectedRushAnswerError is returned by DnsRespHandler when rush-answer
	// validation is requested and the response lacks the additional OPT
	// record that was ensured on the request.
	SuspectedRushAnswerError = errors.New("suspected DNS rush-answer")
	// UnsupportedQuestionTypeError is tolerated silently by DnsRespHandler
	// when EnsureAdditionalOpt reports it.
	UnsupportedQuestionTypeError = errors.New("unsupported question type")
)
type DnsControllerOption struct {
Log *logrus.Logger
CacheAccessCallback func(cache *DnsCache) (err error)
NewCache func(fqdn string, answers []dnsmessage.Resource, deadline time.Time) (cache *DnsCache, err error)
BestDialerChooser func(req *udpRequest, upstream *dns.Upstream) (*dialArgument, error)
type DnsController struct {
routing *dns.Dns
log *logrus.Logger
cacheAccessCallback func(cache *DnsCache) (err error)
newCache func(fqdn string, answers []dnsmessage.Resource, deadline time.Time) (cache *DnsCache, err error)
bestDialerChooser func(req *udpRequest, upstream *dns.Upstream) (*dialArgument, error)
// mutex protects the dnsCache.
dnsCacheMu sync.Mutex
dnsCache map[string]*DnsCache
func NewDnsController(routing *dns.Dns, option *DnsControllerOption) (c *DnsController, err error) {
return &DnsController{
routing: routing,
log: option.Log,
cacheAccessCallback: option.CacheAccessCallback,
newCache: option.NewCache,
bestDialerChooser: option.BestDialerChooser,
dnsCacheMu: sync.Mutex{},
dnsCache: make(map[string]*DnsCache),
}, nil
func (c *DnsController) LookupDnsRespCache(domain string, t dnsmessage.Type) (cache *DnsCache) {
now := time.Now()
// To fqdn.
if !strings.HasSuffix(domain, ".") {
domain = domain + "."
cache, ok := c.dnsCache[strings.ToLower(domain)+t.String()]
if ok && cache.Deadline.After(now) {
return cache
return nil
// LookupDnsRespCache_ will modify the msg in place.
func (c *DnsController) LookupDnsRespCache_(msg *dnsmessage.Message) (resp []byte) {
if len(msg.Questions) == 0 {
return nil
q := msg.Questions[0]
if msg.Response {
return nil
switch q.Type {
case dnsmessage.TypeA, dnsmessage.TypeAAAA:
return nil
cache := c.LookupDnsRespCache(q.Name.String(), q.Type)
if cache != nil {
b, err := msg.Pack()
if err != nil {
c.log.Warnf("failed to pack: %v", err)
return nil
if err = c.cacheAccessCallback(cache); err != nil {
c.log.Warnf("failed to BatchUpdateDomainRouting: %v", err)
return nil
return b
return nil
// DnsRespHandler handle DNS resp.
func (c *DnsController) DnsRespHandler(data []byte, validateRushAns bool) (newMsg *dnsmessage.Message, err error) {
var msg dnsmessage.Message
if err = msg.Unpack(data); err != nil {
return nil, fmt.Errorf("unpack dns pkt: %w", err)
// Check healthy resp.
if !msg.Response || len(msg.Questions) == 0 {
return &msg, nil
q := msg.Questions[0]
// Check suc resp.
if msg.RCode != dnsmessage.RCodeSuccess {
return &msg, nil
// Check req type.
switch q.Type {
case dnsmessage.TypeA, dnsmessage.TypeAAAA:
return &msg, nil
// Set ttl.
var ttl uint32
for i := range msg.Answers {
if ttl == 0 {
ttl = msg.Answers[i].Header.TTL
// Set TTL = zero. This requests applications must resend every request.
// However, it may be not defined in the standard.
msg.Answers[i].Header.TTL = 0
// Check if there is any A/AAAA record.
var hasIpRecord bool
for i := range msg.Answers {
switch msg.Answers[i].Header.Type {
case dnsmessage.TypeA, dnsmessage.TypeAAAA:
hasIpRecord = true
break loop
if !hasIpRecord {
return &msg, nil
if validateRushAns {
exist, e := EnsureAdditionalOpt(&msg, false)
if e != nil && !errors.Is(e, UnsupportedQuestionTypeError) {
c.log.Warnf("EnsureAdditionalOpt: %v", e)
if e == nil && !exist {
// Additional record OPT in the request was ensured, and in normal case the resp should also set it.
// This DNS packet may be a rush-answer, and we should reject it.
"ques": q,
"addition": FormatDnsRsc(msg.Additionals),
"ans": FormatDnsRsc(msg.Answers),
}).Traceln("DNS rush-answer detected")
return nil, SuspectedRushAnswerError
// Update DnsCache.
if c.log.IsLevelEnabled(logrus.TraceLevel) {
"qname": q.Name,
"rcode": msg.RCode,
"ans": FormatDnsRsc(msg.Answers),
"auth": FormatDnsRsc(msg.Authorities),
"addition": FormatDnsRsc(msg.Additionals),
}).Tracef("Update DNS record cache")
if err = c.UpdateDnsCache(q.Name.String(), q.Type, msg.Answers, time.Now().Add(time.Duration(ttl)*time.Second+DnsNatTimeout)); err != nil {
return nil, err
// Pack to get newData.
return &msg, nil
func (c *DnsController) UpdateDnsCache(host string, typ dnsmessage.Type, answers []dnsmessage.Resource, deadline time.Time) (err error) {
var fqdn string
if strings.HasSuffix(host, ".") {
fqdn = host
host = host[:len(host)-1]
} else {
fqdn = host + "."
// Bypass pure IP.
if _, err = netip.ParseAddr(host); err == nil {
return nil
cacheKey := fqdn + typ.String()
cache, ok := c.dnsCache[cacheKey]
if ok {
cache.Deadline = deadline
cache.Answers = answers
} else {
cache, err = c.newCache(fqdn, answers, deadline)
if err != nil {
return err
c.dnsCache[cacheKey] = cache
if err = c.cacheAccessCallback(cache); err != nil {
return err
return nil
func (c *DnsController) DnsRespHandlerFactory(req *udpRequest, validateRushAnsFunc func(from netip.AddrPort) bool) func(data []byte, from netip.AddrPort) (msg *dnsmessage.Message, err error) {
return func(data []byte, from netip.AddrPort) (msg *dnsmessage.Message, err error) {
// Do not return conn-unrelated err in this func.
validateRushAns := validateRushAnsFunc(from)
msg, err = c.DnsRespHandler(data, validateRushAns)
if err != nil {
if errors.Is(err, SuspectedRushAnswerError) {
if validateRushAns {
// Reject DNS rush-answer.
"from": from,
}).Tracef("DNS rush-answer rejected")
return nil, nil
} else {
if c.log.IsLevelEnabled(logrus.DebugLevel) {
c.log.Debugf("DnsRespHandler: %v", err)
return nil, err
return msg, nil
type udpRequest struct {
lanWanFlag consts.LanWanFlag
realSrc netip.AddrPort
realDst netip.AddrPort
src netip.AddrPort
lConn *net.UDPConn
routingResult *bpfRoutingResult
type dialArgument struct {
l4proto consts.L4ProtoStr
ipversion consts.IpVersionStr
bestDialer *dialer.Dialer
bestOutbound *outbound.DialerGroup
bestTarget netip.AddrPort
mark uint32
func (c *DnsController) Handle_(dnsMessage *dnsmessage.Message, req *udpRequest) (err error) {
if resp := c.LookupDnsRespCache_(dnsMessage); resp != nil {
// Send cache to client directly.
if err = sendPkt(resp, req.realDst, req.realSrc, req.src, req.lConn, req.lanWanFlag); err != nil {
return fmt.Errorf("failed to write cached DNS resp: %w", err)
if c.log.IsLevelEnabled(logrus.DebugLevel) && len(dnsMessage.Questions) > 0 {
q := dnsMessage.Questions[0]
c.log.Tracef("UDP(DNS) %v <-> Cache: %v %v",
RefineSourceToShow(req.realSrc, req.realDst.Addr(), req.lanWanFlag), strings.ToLower(q.Name.String()), q.Type,
return nil
// Make sure there is additional record OPT in the request to filter DNS rush-answer in the response process.
// Because rush-answer has no resp OPT. We can distinguish them from multiple responses.
// Note that additional record OPT may not be supported by home router either.
_, _ = EnsureAdditionalOpt(dnsMessage, true)
// Route request.
upstream, err := c.routing.RequestSelect(dnsMessage)
if err != nil {
return err
if c.log.IsLevelEnabled(logrus.TraceLevel) {
upstreamName := "asis"
if upstream != nil {
upstreamName = upstream.String()
"question": dnsMessage.Questions,
"upstream": upstreamName,
}).Traceln("Request to DNS upstream")
// Re-pack DNS packet.
data, err := dnsMessage.Pack()
if err != nil {
return fmt.Errorf("pack DNS packet: %w", err)
return c.dialSend(req, data, upstream, 0)
// dialSend sends the packed DNS query data to upstream, waits for a response,
// routes that response with c.routing.ResponseSelect, and then either relays
// it to the client with sendPkt or calls itself again with the next upstream.
//
// A nil upstream means "as-is": an upstream with scheme "udp" is built from
// the address the client originally targeted (req.realDst). invokingDepth
// counts the re-dispatches so far; the call fails once it reaches
// MaxDnsLookupDepth.
//
// NOTE(review): this listing appears to have lost every line that consisted of
// a single token (closing braces and parentheses, `default:` arms, `break`,
// the `c.log.WithFields(logrus.Fields{` openers, and presumably the
// conn.Close() calls). The code lines below are left exactly as found; restore
// the missing lines from the upstream repository before building.
func (c *DnsController) dialSend(req *udpRequest, data []byte, upstream *dns.Upstream, invokingDepth int) (err error) {
// Guard against response-routing rules that bounce between upstreams forever.
if invokingDepth >= MaxDnsLookupDepth {
return fmt.Errorf("too deep DNS lookup invoking (depth: %v); there may be infinite loop in your DNS response routing", MaxDnsLookupDepth)
upstreamName := "asis"
if upstream == nil {
// As-is.
// As-is should not be valid in response routing, thus using connection realDest is reasonable.
var ip46 netutils.Ip46
if req.realDst.Addr().Is4() {
ip46.Ip4 = req.realDst.Addr()
} else {
ip46.Ip6 = req.realDst.Addr()
upstream = &dns.Upstream{
Scheme: "udp",
Hostname: req.realDst.Addr().String(),
Port: req.realDst.Port(),
Ip46: &ip46,
} else {
upstreamName = upstream.String()
// Select best dial arguments (outbound, dialer, l4proto, ipversion, etc.)
dialArgument, err := c.bestDialerChooser(req, upstream)
if err != nil {
return err
// networkType is only used for logging below.
networkType := &dialer.NetworkType{
L4Proto: dialArgument.l4proto,
IpVersion: dialArgument.ipversion,
IsDns: true, // UDP relies on DNS check result.
// dnsRespHandler caches dns response and check rush answers.
dnsRespHandler := c.DnsRespHandlerFactory(req, func(from netip.AddrPort) bool {
// We only validate rush-ans when outbound is direct and pkt does not send to a home device.
// Because additional record OPT may not be supported by home router.
// So we should trust home devices even if they make rush-answer (or looks like).
return dialArgument.bestDialer.Property().Name == "direct" && !from.Addr().IsPrivate()
// Dial and send.
var respMsg *dnsmessage.Message
// defer in a recursive call will delay Close(), thus we Close() before
// the next recursive call. However, a connection cannot be closed twice.
// We should set a connClosed flag to avoid it.
var connClosed bool
var conn netproxy.Conn
// TODO: Rewritten domain should not use full-cone (such as VMess Packet Addr).
// Maybe we should set up a mapping for UDP: Dialer + Target Domain => Remote Resolved IP.
// However, games may not use QUIC for communication, thus we cannot use domain to dial, which is fine.
switch dialArgument.l4proto {
case consts.L4ProtoStr_UDP:
// Get udp endpoint.
// TODO: connection pool.
// NOTE(review): the Dial call below is visibly unterminated; its remaining
// argument(s) and closing parenthesis are missing from this listing.
conn, err = dialArgument.bestDialer.Dial(
MagicNetwork("udp", dialArgument.mark),
if err != nil {
return fmt.Errorf("failed to dial '%v': %w", dialArgument.bestTarget, err)
// Deferred cleanup that is skipped once connClosed has been set.
// NOTE(review): the body of the `if` (presumably the Close call) is missing.
defer func() {
if !connClosed {
// Bound the whole exchange by DnsNatTimeout.
_ = conn.SetDeadline(time.Now().Add(DnsNatTimeout))
_, err = conn.Write(data)
if err != nil {
if c.log.IsLevelEnabled(logrus.DebugLevel) {
"to": dialArgument.bestTarget.String(),
"pid": req.routingResult.Pid,
"pname": ProcessName2String(req.routingResult.Pname[:]),
"mac": Mac2String(req.routingResult.Mac[:]),
"from": req.realSrc.String(),
"network": networkType.String(),
"err": err.Error(),
}).Debugln("Failed to write UDP(DNS) packet request.")
return fmt.Errorf("failed to write UDP(DNS) packet request: %w", err)
// We can block here because we are in a coroutine.
// NOTE(review): the receive buffer is 512 bytes, while Handle_ asks
// EnsureAdditionalOpt to add an OPT record to the request; confirm that
// larger UDP responses cannot arrive truncated here.
respBuf := pool.Get(512)
defer pool.Put(respBuf)
// Keep reading until the handler yields a message: dnsRespHandler returns
// (nil, nil) for a rejected rush-answer, in which case the next packet is
// awaited. NOTE(review): the loop exit after `if respMsg != nil {` is missing
// from this listing.
for {
// Wait for response.
n, err := conn.Read(respBuf)
if err != nil {
return fmt.Errorf("failed to read from: %v (dialer: %v): %w", dialArgument.bestTarget, dialArgument.bestDialer.Property().Name, err)
respMsg, err = dnsRespHandler(respBuf[:n], dialArgument.bestTarget)
if err != nil {
return err
if respMsg != nil {
case consts.L4ProtoStr_TCP:
// We can block here because we are in a coroutine.
conn, err = dialArgument.bestDialer.Dial(MagicNetwork("tcp", dialArgument.mark), dialArgument.bestTarget.String())
if err != nil {
return fmt.Errorf("failed to dial proxy to tcp: %w", err)
// Same deferred cleanup as in the UDP case.
defer func() {
if !connClosed {
_ = conn.SetDeadline(time.Now().Add(DnsNatTimeout))
// We should write two byte length in the front of TCP DNS request.
bReq := pool.Get(2 + len(data))
defer pool.Put(bReq)
binary.BigEndian.PutUint16(bReq, uint16(len(data)))
copy(bReq[2:], data)
_, err = conn.Write(bReq)
if err != nil {
return fmt.Errorf("failed to write DNS req: %w", err)
// Read two byte length.
if _, err = io.ReadFull(conn, bReq[:2]); err != nil {
return fmt.Errorf("failed to read DNS resp payload length: %w", err)
respLen := int(binary.BigEndian.Uint16(bReq))
// Try to reuse the buf.
// The request buffer is reused for the payload when it is large enough;
// otherwise a buffer of respLen bytes is taken from the pool.
var buf []byte
if len(bReq) < respLen {
buf = pool.Get(respLen)
defer pool.Put(buf)
} else {
buf = bReq
var n int
if n, err = io.ReadFull(conn, buf[:respLen]); err != nil {
return fmt.Errorf("failed to read DNS resp payload: %w", err)
respMsg, err = dnsRespHandler(buf[:n], dialArgument.bestTarget)
// Over TCP there is a single response, so a rejected one (nil message, nil
// error) is turned into an error instead of waiting for another packet.
if respMsg == nil && err == nil {
err = fmt.Errorf("bad DNS response")
if err != nil {
return fmt.Errorf("failed to write DNS resp to client: %w", err)
// NOTE(review): the `default:` arm label for the return below is missing.
return fmt.Errorf("unexpected l4proto: %v", dialArgument.l4proto)
// Close conn before the recursive call.
// NOTE(review): the Close call itself is missing; only the flag is set.
connClosed = true
// Route response.
upstreamIndex, nextUpstream, err := c.routing.ResponseSelect(respMsg, upstream)
if err != nil {
return err
switch upstreamIndex {
case consts.DnsResponseOutboundIndex_Accept:
// Accept.
// NOTE(review): the log call that closes this field list is missing.
if c.log.IsLevelEnabled(logrus.TraceLevel) {
"question": respMsg.Questions,
"upstream": upstreamName,
case consts.DnsResponseOutboundIndex_Reject:
// Reject the request with empty answer.
respMsg.Answers = nil
if c.log.IsLevelEnabled(logrus.TraceLevel) {
"question": respMsg.Questions,
"upstream": upstreamName,
}).Traceln("Reject with empty answer")
// Presumably the `default:` arm: any other index resends the same query to
// nextUpstream with the depth incremented.
if c.log.IsLevelEnabled(logrus.TraceLevel) {
"question": respMsg.Questions,
"last_upstream": upstreamName,
"next_upstream": nextUpstream.String(),
}).Traceln("Change DNS upstream and resend")
return c.dialSend(req, data, nextUpstream, invokingDepth+1)
// Summary log line for accepted and rejected responses at info level.
if upstreamIndex.IsReserved() && c.log.IsLevelEnabled(logrus.InfoLevel) {
var qname, qtype string
if len(respMsg.Questions) > 0 {
q := respMsg.Questions[0]
qname = strings.ToLower(q.Name.String())
qtype = q.Type.String()
fields := logrus.Fields{
"network": networkType.String(),
"outbound": dialArgument.bestOutbound.Name,
"policy": dialArgument.bestOutbound.GetSelectionPolicy(),
"dialer": dialArgument.bestDialer.Property().Name,
"qname": qname,
"qtype": qtype,
"pid": req.routingResult.Pid,
"pname": ProcessName2String(req.routingResult.Pname[:]),
"mac": Mac2String(req.routingResult.Mac[:]),
switch upstreamIndex {
case consts.DnsResponseOutboundIndex_Accept:
c.log.WithFields(fields).Infof("%v <-> %v", RefineSourceToShow(req.realSrc, req.realDst.Addr(), req.lanWanFlag), RefineAddrPortToShow(dialArgument.bestTarget))
case consts.DnsResponseOutboundIndex_Reject:
c.log.WithFields(fields).Infof("%v -> reject", RefineSourceToShow(req.realSrc, req.realDst.Addr(), req.lanWanFlag))
// NOTE(review): the `default:` arm label for the return below is missing.
return fmt.Errorf("unknown upstream: %v", upstreamIndex.String())
// Re-pack the (possibly emptied) response and relay it to the client.
data, err = respMsg.Pack()
if err != nil {
return err
if err = sendPkt(data, req.realDst, req.realSrc, req.src, req.lConn, req.lanWanFlag); err != nil {
return err
return nil