dae/common/netutils/dns.go
mzz 218ae3f654
fix: connection leaks (#624)
Co-authored-by: dae-prow[bot] <136105375+dae-prow[bot]@users.noreply.github.com>
2024-09-26 22:40:29 +08:00

279 lines
6.1 KiB
Go

/*
* SPDX-License-Identifier: AGPL-3.0-only
* Copyright (c) 2022-2024, daeuniverse Organization <dae@v2raya.org>
*/
package netutils
import (
"context"
"encoding/binary"
"fmt"
"io"
"math"
"net/netip"
"sync"
"time"
"github.com/daeuniverse/dae/common/consts"
"github.com/daeuniverse/outbound/netproxy"
"github.com/daeuniverse/outbound/pkg/fastrand"
"github.com/daeuniverse/outbound/pool"
dnsmessage "github.com/miekg/dns"
)
// Cached upstream DNS server state. The functions in this file read and write
// these only while holding systemDnsMu.
var (
	systemDnsMu              sync.Mutex
	systemDns                netip.AddrPort
	systemDnsNextUpdateAfter time.Time
)

var (
	// ErrBadDnsAns is returned when an answer record's header advertises the
	// requested type but the record value is not of the matching Go type.
	ErrBadDnsAns = fmt.Errorf("bad dns answer")
	// BootstrapDns is the fallback server used when /etc/resolv.conf lists no
	// usable non-loopback nameserver.
	BootstrapDns = netip.MustParseAddrPort("208.67.222.222:5353")
)
// TryUpdateSystemDns unconditionally re-reads /etc/resolv.conf and refreshes
// the cached system DNS server while holding systemDnsMu.
//
// Unlike TryUpdateSystemDnsElapse it neither checks nor advances the refresh
// cool-down.
func TryUpdateSystemDns() (err error) {
	systemDnsMu.Lock()
	// Release via defer, as the other locked entry points do, so that a panic
	// inside the update cannot leave systemDnsMu held forever.
	defer systemDnsMu.Unlock()
	return tryUpdateSystemDns()
}
// TryUpdateSystemDnsElapse refreshes the cached system DNS server from
// /etc/resolv.conf, but only if the cool-down started by the previous
// successful throttled refresh has expired. On success a new cool-down of
// length k begins. Calling it too early returns a non-nil error.
func TryUpdateSystemDnsElapse(k time.Duration) (err error) {
	systemDnsMu.Lock()
	defer systemDnsMu.Unlock()
	err = tryUpdateSystemDnsElapse(k)
	return err
}
// tryUpdateSystemDnsElapse is the lock-free core of TryUpdateSystemDnsElapse;
// callers hold systemDnsMu. It refuses to update before
// systemDnsNextUpdateAfter. After a successful update it schedules the next
// allowed update k from now. A failed update leaves the schedule unchanged, so
// the next call may retry immediately.
func tryUpdateSystemDnsElapse(k time.Duration) (err error) {
	if !time.Now().Before(systemDnsNextUpdateAfter) {
		if err = tryUpdateSystemDns(); err != nil {
			return err
		}
		// Measured after the update so the cool-down starts when it finished.
		systemDnsNextUpdateAfter = time.Now().Add(k)
		return nil
	}
	return fmt.Errorf("update too quickly")
}
// tryUpdateSystemDns reloads systemDns from /etc/resolv.conf; callers in this
// file hold systemDnsMu.
//
// It selects the first nameserver entry that parses as an ip:port and is not
// a loopback address. Entries that fail to parse are skipped rather than
// causing a panic, so a malformed resolv.conf cannot crash the process. If no
// entry qualifies, BootstrapDns is used. The returned error is currently
// always nil.
func tryUpdateSystemDns() (err error) {
	dnsConf := dnsReadConfig("/etc/resolv.conf")

	var chosen netip.AddrPort
	for _, s := range dnsConf.servers {
		ipPort, parseErr := netip.ParseAddrPort(s)
		if parseErr != nil {
			continue
		}
		// A loopback resolver is presumably the local stub (possibly this
		// process itself), so it is not usable as the upstream server.
		if ipPort.Addr().IsLoopback() {
			continue
		}
		chosen = ipPort
		break
	}
	if !chosen.IsValid() {
		chosen = BootstrapDns
	}
	systemDns = chosen
	return nil
}
// SystemDns returns the upstream DNS server taken from /etc/resolv.conf,
// loading it on first use. Every call also attempts a refresh throttled to at
// most once per 5 seconds, so that changes to resolv.conf are picked up.
func SystemDns() (dns netip.AddrPort, err error) {
	systemDnsMu.Lock()
	defer systemDnsMu.Unlock()

	// Nothing cached yet: a load is mandatory before anything can be returned.
	if needLoad := !systemDns.IsValid(); needLoad {
		if loadErr := tryUpdateSystemDns(); loadErr != nil {
			return netip.AddrPort{}, loadErr
		}
	}

	// Opportunistic refresh. Inside the cool-down window it reports
	// "update too quickly", which is expected, so the error is discarded.
	_ = tryUpdateSystemDnsElapse(5 * time.Second)

	dns = systemDns
	return dns, nil
}
// ResolveNetip looks up the addresses of host for record type typ by querying
// the DNS server dns through dialer d over network.
//
// typ is expected to be dnsmessage.TypeA or dnsmessage.TypeAAAA. For any other
// type no address is extracted and the result is empty. The following answer
// records are skipped:
//   - records whose type differs from typ (for example CNAMEs in the chain);
//   - records whose address bytes do not form a valid address;
//   - A records whose payload is not an IPv4 address.
//
// IPv4 results are always in their 4-byte form, never IPv4-mapped IPv6, so
// callers can rely on Addr.Is4. ErrBadDnsAns is returned if a record's header
// claims type typ but the record is not the matching Go type.
func ResolveNetip(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ uint16, network string) (addrs []netip.Addr, err error) {
	resources, err := resolve(ctx, d, dns, host, typ, network)
	if err != nil {
		return nil, err
	}
	for _, ans := range resources {
		if ans.Header().Rrtype != typ {
			continue
		}
		var (
			ip  netip.Addr
			okk bool
		)
		switch typ {
		case dnsmessage.TypeA:
			a, ok := ans.(*dnsmessage.A)
			if !ok {
				return nil, ErrBadDnsAns
			}
			ip, okk = netip.AddrFromSlice(a.A)
			// An A record may carry a 16-byte IPv4-mapped slice (net.IP allows
			// both forms). Normalise it so the result is a real IPv4 Addr, and
			// drop anything that still is not IPv4.
			ip = ip.Unmap()
			okk = okk && ip.Is4()
		case dnsmessage.TypeAAAA:
			a, ok := ans.(*dnsmessage.AAAA)
			if !ok {
				return nil, ErrBadDnsAns
			}
			ip, okk = netip.AddrFromSlice(a.AAAA)
		}
		if !okk {
			continue
		}
		addrs = append(addrs, ip)
	}
	return addrs, nil
}
// ResolveNS queries the DNS server dns, through dialer d over network, for the
// NS records of host and returns each record's name-server name in answer
// order. Answer records of other types are ignored. ErrBadDnsAns is returned if
// a record's header says NS but the record is not a *dnsmessage.NS.
func ResolveNS(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, network string) (records []string, err error) {
	resources, err := resolve(ctx, d, dns, host, dnsmessage.TypeNS, network)
	if err != nil {
		return nil, err
	}
	for _, rr := range resources {
		if rr.Header().Rrtype != dnsmessage.TypeNS {
			continue
		}
		switch rec := rr.(type) {
		case *dnsmessage.NS:
			records = append(records, rec.Ns)
		default:
			return nil, ErrBadDnsAns
		}
	}
	return records, nil
}
// resolve sends one DNS query for (host, typ) to the server dns, dialled
// through d over network, and returns the Answer section of the response.
//
// If typ is A or AAAA and host is already an IP literal, no query is sent. It
// returns a synthetic answer when the literal's family matches typ, and an
// empty answer (nil, nil) otherwise.
//
// network is parsed with netproxy.ParseMagicNetwork. When its base network is
// "tcp", messages use the 2-byte length prefix. When it is "udp", the query is
// re-sent every 3 seconds until an answer arrives. Only a response whose ID
// matches the query is accepted.
//
// No deadline is applied here; the caller is expected to bound ctx. If ctx
// ends first, the returned error starts with "timeout" and wraps ctx.Err().
func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ uint16, network string) (ans []dnsmessage.RR, err error) {
	const udpResendInterval = 3 * time.Second

	// Cancelled on return so the helper goroutines below stop promptly.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	fqdn := dnsmessage.CanonicalName(host)
	if typ == dnsmessage.TypeA || typ == dnsmessage.TypeAAAA {
		if addr, parseErr := netip.ParseAddr(host); parseErr == nil {
			// IP literal: answer locally, never ask the server.
			return literalAddrAnswer(fqdn, addr, typ), nil
		}
	}

	// Build DNS req.
	builder := dnsmessage.Msg{
		MsgHdr: dnsmessage.MsgHdr{
			Id:               uint16(fastrand.Intn(math.MaxUint16 + 1)),
			Response:         false,
			Opcode:           0,
			Truncated:        false,
			RecursionDesired: true,
			Authoritative:    false,
		},
	}
	builder.SetQuestion(fqdn, typ)
	queryID := builder.Id
	b, err := builder.Pack()
	if err != nil {
		return nil, err
	}

	magicNetwork, err := netproxy.ParseMagicNetwork(network)
	if err != nil {
		return nil, err
	}
	isTCP := magicNetwork.Network == "tcp"
	if isTCP {
		// DNS over TCP prefixes every message with its 2-byte length (RFC 1035 4.2.2).
		buf := pool.Get(2 + len(b))
		defer pool.Put(buf)
		binary.BigEndian.PutUint16(buf, uint16(len(b)))
		copy(buf[2:], b)
		b = buf
	}

	// Dial and write.
	c, err := d.DialContext(ctx, network, dns.String())
	if err != nil {
		return nil, err
	}
	defer c.Close()
	if _, err = c.Write(b); err != nil {
		return nil, err
	}

	// Each goroutine sends at most one result and the buffer holds two, so
	// neither goroutine blocks after resolve has returned. Passing the answer
	// through the channel avoids racing on the named result.
	type result struct {
		ans []dnsmessage.RR
		err error
	}
	results := make(chan result, 2)
	if magicNetwork.Network == "udp" {
		go func() {
			// UDP may silently drop the query, so keep re-sending it.
			if werr := resendDnsQuery(ctx, c, b, udpResendInterval); werr != nil {
				results <- result{err: werr}
			}
		}()
	}
	go func() {
		// Unblocked by the deferred c.Close() if resolve returns first.
		rrs, rerr := readDnsAnswer(c, isTCP, queryID)
		results <- result{ans: rrs, err: rerr}
	}()

	select {
	case <-ctx.Done():
		return nil, fmt.Errorf("timeout: %w", ctx.Err())
	case r := <-results:
		if r.err != nil {
			return nil, r.err
		}
		return r.ans, nil
	}
}

// literalAddrAnswer builds the answer resolve returns for an IP-literal host.
// It returns a single A record for an IPv4 (or IPv4-mapped) literal queried as
// A, a single AAAA record for an IPv6 literal queried as AAAA, and nil (no
// record) for a family/type mismatch.
func literalAddrAnswer(fqdn string, addr netip.Addr, typ uint16) []dnsmessage.RR {
	hdr := dnsmessage.RR_Header{
		Name:   fqdn,
		Class:  dnsmessage.ClassINET,
		Ttl:    0,
		Rrtype: typ,
	}
	switch {
	case (addr.Is4() || addr.Is4In6()) && typ == dnsmessage.TypeA:
		// Unmap so the A payload is 4 bytes, matching records decoded from the wire.
		return []dnsmessage.RR{&dnsmessage.A{Hdr: hdr, A: addr.Unmap().AsSlice()}}
	case addr.Is6() && typ == dnsmessage.TypeAAAA:
		return []dnsmessage.RR{&dnsmessage.AAAA{Hdr: hdr, AAAA: addr.AsSlice()}}
	}
	// MUST No record.
	return nil
}

// readDnsAnswer reads responses from r until one carries ID id, then returns
// its Answer section.
//
// With tcp set, each message is read in full using its 2-byte length prefix;
// a single Read may return only part of a TCP message. Otherwise each Read is
// treated as one datagram, and datagrams with a foreign ID are ignored. Over
// TCP an ID mismatch is an error, since the stream is dedicated to this query.
func readDnsAnswer(r io.Reader, tcp bool, id uint16) ([]dnsmessage.RR, error) {
	buf := pool.GetFullCap(consts.EthernetMtu)
	defer buf.Put()
	for {
		var raw []byte
		if tcp {
			// Read DNS response length
			if _, err := io.ReadFull(r, buf[:2]); err != nil {
				return nil, err
			}
			n := binary.BigEndian.Uint16(buf)
			if int(n) > cap(buf) {
				return nil, fmt.Errorf("too big dns resp")
			}
			if _, err := io.ReadFull(r, buf[:n]); err != nil {
				return nil, err
			}
			raw = buf[:n]
		} else {
			n, err := r.Read(buf)
			if err != nil {
				return nil, err
			}
			raw = buf[:n]
		}

		var msg dnsmessage.Msg
		if err := msg.Unpack(raw); err != nil {
			return nil, err
		}
		if msg.Id != id {
			if tcp {
				return nil, fmt.Errorf("dns response id mismatch: got %d, want %d", msg.Id, id)
			}
			// Stray or spoofed datagram; keep waiting for our answer.
			continue
		}
		return msg.Answer, nil
	}
}

// resendDnsQuery writes b to w once per interval until ctx is done. It returns
// nil when ctx ends, or the first write error.
func resendDnsQuery(ctx context.Context, w io.Writer, b []byte, interval time.Duration) error {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return nil
		case <-ticker.C:
			if _, err := w.Write(b); err != nil {
				return err
			}
		}
	}
}