mirror of
https://github.com/daeuniverse/dae.git
synced 2024-12-22 15:54:42 +07:00
feat(dns): support DoH, DoT, DoH3, DoQ (#649)
This commit is contained in:
parent
0e1301b851
commit
bfc17c3e2d
@ -128,7 +128,7 @@ func New(dns *config.Dns, opt *NewOption) (s *Dns, err error) {
|
||||
|
||||
func (s *Dns) CheckUpstreamsFormat() error {
|
||||
for _, upstream := range s.upstream {
|
||||
_, _, _, err := ParseRawUpstream(upstream.Raw)
|
||||
_, _, _, _, err := ParseRawUpstream(upstream.Raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -30,6 +30,11 @@ const (
|
||||
UpstreamScheme_UDP UpstreamScheme = "udp"
|
||||
UpstreamScheme_TCP_UDP UpstreamScheme = "tcp+udp"
|
||||
upstreamScheme_TCP_UDP_Alias UpstreamScheme = "udp+tcp"
|
||||
UpstreamScheme_TLS UpstreamScheme = "tls"
|
||||
UpstreamScheme_QUIC UpstreamScheme = "quic"
|
||||
UpstreamScheme_HTTPS UpstreamScheme = "https"
|
||||
upstreamScheme_H3_Alias UpstreamScheme = "http3"
|
||||
UpstreamScheme_H3 UpstreamScheme = "h3"
|
||||
)
|
||||
|
||||
func (s UpstreamScheme) ContainsTcp() bool {
|
||||
@ -42,8 +47,9 @@ func (s UpstreamScheme) ContainsTcp() bool {
|
||||
}
|
||||
}
|
||||
|
||||
func ParseRawUpstream(raw *url.URL) (scheme UpstreamScheme, hostname string, port uint16, err error) {
|
||||
func ParseRawUpstream(raw *url.URL) (scheme UpstreamScheme, hostname string, port uint16, path string, err error) {
|
||||
var __port string
|
||||
var __path string
|
||||
switch scheme = UpstreamScheme(raw.Scheme); scheme {
|
||||
case upstreamScheme_TCP_UDP_Alias:
|
||||
scheme = UpstreamScheme_TCP_UDP
|
||||
@ -53,27 +59,45 @@ func ParseRawUpstream(raw *url.URL) (scheme UpstreamScheme, hostname string, por
|
||||
if __port == "" {
|
||||
__port = "53"
|
||||
}
|
||||
case upstreamScheme_H3_Alias:
|
||||
scheme = UpstreamScheme_H3
|
||||
fallthrough
|
||||
case UpstreamScheme_HTTPS, UpstreamScheme_H3:
|
||||
__port = raw.Port()
|
||||
if __port == "" {
|
||||
__port = "443"
|
||||
}
|
||||
__path = raw.Path
|
||||
if __path == "" {
|
||||
__path = "/dns-query"
|
||||
}
|
||||
case UpstreamScheme_QUIC, UpstreamScheme_TLS:
|
||||
__port = raw.Port()
|
||||
if __port == "" {
|
||||
__port = "853"
|
||||
}
|
||||
default:
|
||||
return "", "", 0, fmt.Errorf("unexpected scheme: %v", raw.Scheme)
|
||||
return "", "", 0, "", fmt.Errorf("unexpected scheme: %v", raw.Scheme)
|
||||
}
|
||||
_port, err := strconv.ParseUint(__port, 10, 16)
|
||||
if err != nil {
|
||||
return "", "", 0, fmt.Errorf("failed to parse dns_upstream port: %v", err)
|
||||
return "", "", 0, "", fmt.Errorf("failed to parse dns_upstream port: %v", err)
|
||||
}
|
||||
port = uint16(_port)
|
||||
hostname = raw.Hostname()
|
||||
return scheme, hostname, port, nil
|
||||
return scheme, hostname, port, __path, nil
|
||||
}
|
||||
|
||||
type Upstream struct {
|
||||
Scheme UpstreamScheme
|
||||
Hostname string
|
||||
Port uint16
|
||||
Path string
|
||||
*netutils.Ip46
|
||||
}
|
||||
|
||||
func NewUpstream(ctx context.Context, upstream *url.URL, resolverNetwork string) (up *Upstream, err error) {
|
||||
scheme, hostname, port, err := ParseRawUpstream(upstream)
|
||||
scheme, hostname, port, path, err := ParseRawUpstream(upstream)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %v", ErrFormat, err)
|
||||
}
|
||||
@ -100,6 +124,7 @@ func NewUpstream(ctx context.Context, upstream *url.URL, resolverNetwork string)
|
||||
Scheme: scheme,
|
||||
Hostname: hostname,
|
||||
Port: port,
|
||||
Path: path,
|
||||
Ip46: ip46,
|
||||
}, nil
|
||||
}
|
||||
@ -115,9 +140,9 @@ func (u *Upstream) SupportedNetworks() (ipversions []consts.IpVersionStr, l4prot
|
||||
}
|
||||
}
|
||||
switch u.Scheme {
|
||||
case UpstreamScheme_TCP:
|
||||
case UpstreamScheme_TCP, UpstreamScheme_HTTPS, UpstreamScheme_TLS:
|
||||
l4protos = []consts.L4ProtoStr{consts.L4ProtoStr_TCP}
|
||||
case UpstreamScheme_UDP:
|
||||
case UpstreamScheme_UDP, UpstreamScheme_QUIC, UpstreamScheme_H3:
|
||||
l4protos = []consts.L4ProtoStr{consts.L4ProtoStr_UDP}
|
||||
case UpstreamScheme_TCP_UDP:
|
||||
// UDP first.
|
||||
@ -127,7 +152,7 @@ func (u *Upstream) SupportedNetworks() (ipversions []consts.IpVersionStr, l4prot
|
||||
}
|
||||
|
||||
func (u *Upstream) String() string {
|
||||
return string(u.Scheme) + "://" + net.JoinHostPort(u.Hostname, strconv.Itoa(int(u.Port)))
|
||||
return string(u.Scheme) + "://" + net.JoinHostPort(u.Hostname, strconv.Itoa(int(u.Port))) + u.Path
|
||||
}
|
||||
|
||||
type UpstreamResolver struct {
|
||||
|
437
control/dns.go
Normal file
437
control/dns.go
Normal file
@ -0,0 +1,437 @@
|
||||
package control
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/daeuniverse/dae/common"
|
||||
"github.com/daeuniverse/dae/common/consts"
|
||||
"github.com/daeuniverse/dae/component/dns"
|
||||
"github.com/daeuniverse/outbound/netproxy"
|
||||
"github.com/daeuniverse/outbound/pool"
|
||||
tc "github.com/daeuniverse/outbound/protocol/tuic/common"
|
||||
"github.com/daeuniverse/quic-go"
|
||||
"github.com/daeuniverse/quic-go/http3"
|
||||
dnsmessage "github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type DnsForwarder interface {
|
||||
ForwardDNS(ctx context.Context, data []byte) (*dnsmessage.Msg, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
func newDnsForwarder(upstream *dns.Upstream, dialArgument dialArgument) (DnsForwarder, error) {
|
||||
forwarder, err := func() (DnsForwarder, error) {
|
||||
switch dialArgument.l4proto {
|
||||
case consts.L4ProtoStr_TCP:
|
||||
switch upstream.Scheme {
|
||||
case dns.UpstreamScheme_TCP, dns.UpstreamScheme_TCP_UDP:
|
||||
return &DoTCP{Upstream: *upstream, Dialer: dialArgument.bestDialer, dialArgument: dialArgument}, nil
|
||||
case dns.UpstreamScheme_TLS:
|
||||
return &DoTLS{Upstream: *upstream, Dialer: dialArgument.bestDialer, dialArgument: dialArgument}, nil
|
||||
case dns.UpstreamScheme_HTTPS:
|
||||
return &DoH{Upstream: *upstream, Dialer: dialArgument.bestDialer, dialArgument: dialArgument, http3: false}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected scheme: %v", upstream.Scheme)
|
||||
}
|
||||
case consts.L4ProtoStr_UDP:
|
||||
switch upstream.Scheme {
|
||||
case dns.UpstreamScheme_UDP, dns.UpstreamScheme_TCP_UDP:
|
||||
return &DoUDP{Upstream: *upstream, Dialer: dialArgument.bestDialer, dialArgument: dialArgument}, nil
|
||||
case dns.UpstreamScheme_QUIC:
|
||||
return &DoQ{Upstream: *upstream, Dialer: dialArgument.bestDialer, dialArgument: dialArgument}, nil
|
||||
case dns.UpstreamScheme_H3:
|
||||
return &DoH{Upstream: *upstream, Dialer: dialArgument.bestDialer, dialArgument: dialArgument, http3: true}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected scheme: %v", upstream.Scheme)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected l4proto: %v", dialArgument.l4proto)
|
||||
}
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return forwarder, nil
|
||||
}
|
||||
|
||||
type DoH struct {
|
||||
dns.Upstream
|
||||
netproxy.Dialer
|
||||
dialArgument dialArgument
|
||||
http3 bool
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func (d *DoH) ForwardDNS(ctx context.Context, data []byte) (*dnsmessage.Msg, error) {
|
||||
if d.client == nil {
|
||||
d.client = d.getClient()
|
||||
}
|
||||
msg, err := sendHttpDNS(d.client, d.dialArgument.bestTarget.String(), &d.Upstream, data)
|
||||
if err != nil {
|
||||
// If failed to send DNS request, we should try to create a new client.
|
||||
d.client = d.getClient()
|
||||
msg, err = sendHttpDNS(d.client, d.dialArgument.bestTarget.String(), &d.Upstream, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (d *DoH) getClient() *http.Client {
|
||||
var roundTripper http.RoundTripper
|
||||
if d.http3 {
|
||||
roundTripper = d.getHttp3RoundTripper()
|
||||
} else {
|
||||
roundTripper = d.getHttpRoundTripper()
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: roundTripper,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DoH) getHttpRoundTripper() *http.Transport {
|
||||
httpTransport := http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
ServerName: d.Upstream.Hostname,
|
||||
InsecureSkipVerify: false,
|
||||
},
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
conn, err := d.dialArgument.bestDialer.DialContext(
|
||||
ctx,
|
||||
common.MagicNetwork("tcp", d.dialArgument.mark, d.dialArgument.mptcp),
|
||||
d.dialArgument.bestTarget.String(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &netproxy.FakeNetConn{Conn: conn}, nil
|
||||
},
|
||||
}
|
||||
|
||||
return &httpTransport
|
||||
}
|
||||
|
||||
func (d *DoH) getHttp3RoundTripper() *http3.RoundTripper {
|
||||
roundTripper := &http3.RoundTripper{
|
||||
TLSClientConfig: &tls.Config{
|
||||
ServerName: d.Upstream.Hostname,
|
||||
NextProtos: []string{"h3"},
|
||||
InsecureSkipVerify: false,
|
||||
},
|
||||
QuicConfig: &quic.Config{},
|
||||
Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
|
||||
udpAddr := net.UDPAddrFromAddrPort(d.dialArgument.bestTarget)
|
||||
conn, err := d.dialArgument.bestDialer.DialContext(
|
||||
ctx,
|
||||
common.MagicNetwork("udp", d.dialArgument.mark, d.dialArgument.mptcp),
|
||||
d.dialArgument.bestTarget.String(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fakePkt := netproxy.NewFakeNetPacketConn(conn.(netproxy.PacketConn), net.UDPAddrFromAddrPort(tc.GetUniqueFakeAddrPort()), udpAddr)
|
||||
c, e := quic.DialEarly(ctx, fakePkt, udpAddr, tlsCfg, cfg)
|
||||
return c, e
|
||||
},
|
||||
}
|
||||
return roundTripper
|
||||
}
|
||||
|
||||
func (d *DoH) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type DoQ struct {
|
||||
dns.Upstream
|
||||
netproxy.Dialer
|
||||
dialArgument dialArgument
|
||||
connection quic.EarlyConnection
|
||||
}
|
||||
|
||||
func (d *DoQ) ForwardDNS(ctx context.Context, data []byte) (*dnsmessage.Msg, error) {
|
||||
if d.connection == nil {
|
||||
qc, err := d.createConnection(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.connection = qc
|
||||
}
|
||||
|
||||
stream, err := d.connection.OpenStreamSync(ctx)
|
||||
if err != nil {
|
||||
// If failed to open stream, we should try to create a new connection.
|
||||
qc, err := d.createConnection(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.connection = qc
|
||||
stream, err = d.connection.OpenStreamSync(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
defer func() {
|
||||
_ = stream.Close()
|
||||
}()
|
||||
|
||||
// According https://datatracker.ietf.org/doc/html/rfc9250#section-4.2.1
|
||||
// msg id should set to 0 when transport over QUIC.
|
||||
// thanks https://github.com/natesales/q/blob/1cb2639caf69bd0a9b46494a3c689130df8fb24a/transport/quic.go#L97
|
||||
binary.BigEndian.PutUint16(data[0:2], 0)
|
||||
|
||||
msg, err := sendStreamDNS(stream, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
func (d *DoQ) createConnection(ctx context.Context) (quic.EarlyConnection, error) {
|
||||
|
||||
udpAddr := net.UDPAddrFromAddrPort(d.dialArgument.bestTarget)
|
||||
conn, err := d.dialArgument.bestDialer.DialContext(
|
||||
ctx,
|
||||
common.MagicNetwork("udp", d.dialArgument.mark, d.dialArgument.mptcp),
|
||||
d.dialArgument.bestTarget.String(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fakePkt := netproxy.NewFakeNetPacketConn(conn.(netproxy.PacketConn), net.UDPAddrFromAddrPort(tc.GetUniqueFakeAddrPort()), udpAddr)
|
||||
tlsCfg := &tls.Config{
|
||||
NextProtos: []string{"doq"},
|
||||
InsecureSkipVerify: false,
|
||||
ServerName: d.Upstream.Hostname,
|
||||
}
|
||||
addr := net.UDPAddrFromAddrPort(d.dialArgument.bestTarget)
|
||||
qc, err := quic.DialEarly(ctx, fakePkt, addr, tlsCfg, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return qc, nil
|
||||
|
||||
}
|
||||
|
||||
func (d *DoQ) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type DoTLS struct {
|
||||
dns.Upstream
|
||||
netproxy.Dialer
|
||||
dialArgument dialArgument
|
||||
conn netproxy.Conn
|
||||
}
|
||||
|
||||
func (d *DoTLS) ForwardDNS(ctx context.Context, data []byte) (*dnsmessage.Msg, error) {
|
||||
conn, err := d.dialArgument.bestDialer.DialContext(
|
||||
ctx,
|
||||
common.MagicNetwork("tcp", d.dialArgument.mark, d.dialArgument.mptcp),
|
||||
d.dialArgument.bestTarget.String(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConn := tls.Client(&netproxy.FakeNetConn{Conn: conn}, &tls.Config{
|
||||
InsecureSkipVerify: false,
|
||||
ServerName: d.Upstream.Hostname,
|
||||
})
|
||||
if err = tlsConn.Handshake(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.conn = tlsConn
|
||||
|
||||
return sendStreamDNS(tlsConn, data)
|
||||
}
|
||||
|
||||
func (d *DoTLS) Close() error {
|
||||
if d.conn != nil {
|
||||
return d.conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type DoTCP struct {
|
||||
dns.Upstream
|
||||
netproxy.Dialer
|
||||
dialArgument dialArgument
|
||||
conn netproxy.Conn
|
||||
}
|
||||
|
||||
func (d *DoTCP) ForwardDNS(ctx context.Context, data []byte) (*dnsmessage.Msg, error) {
|
||||
conn, err := d.dialArgument.bestDialer.DialContext(
|
||||
ctx,
|
||||
common.MagicNetwork("tcp", d.dialArgument.mark, d.dialArgument.mptcp),
|
||||
d.dialArgument.bestTarget.String(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d.conn = conn
|
||||
return sendStreamDNS(conn, data)
|
||||
}
|
||||
|
||||
func (d *DoTCP) Close() error {
|
||||
if d.conn != nil {
|
||||
return d.conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type DoUDP struct {
|
||||
dns.Upstream
|
||||
netproxy.Dialer
|
||||
dialArgument dialArgument
|
||||
conn netproxy.Conn
|
||||
}
|
||||
|
||||
func (d *DoUDP) ForwardDNS(ctx context.Context, data []byte) (*dnsmessage.Msg, error) {
|
||||
conn, err := d.dialArgument.bestDialer.DialContext(
|
||||
ctx,
|
||||
common.MagicNetwork("udp", d.dialArgument.mark, d.dialArgument.mptcp),
|
||||
d.dialArgument.bestTarget.String(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
timeout := 5 * time.Second
|
||||
_ = conn.SetDeadline(time.Now().Add(timeout))
|
||||
dnsReqCtx, cancelDnsReqCtx := context.WithTimeout(context.TODO(), timeout)
|
||||
defer cancelDnsReqCtx()
|
||||
|
||||
go func() {
|
||||
// Send DNS request every seconds.
|
||||
for {
|
||||
_, err = conn.Write(data)
|
||||
// if err != nil {
|
||||
// if c.log.IsLevelEnabled(logrus.DebugLevel) {
|
||||
// c.log.WithFields(logrus.Fields{
|
||||
// "to": dialArgument.bestTarget.String(),
|
||||
// "pid": req.routingResult.Pid,
|
||||
// "pname": ProcessName2String(req.routingResult.Pname[:]),
|
||||
// "mac": Mac2String(req.routingResult.Mac[:]),
|
||||
// "from": req.realSrc.String(),
|
||||
// "network": networkType.String(),
|
||||
// "err": err.Error(),
|
||||
// }).Debugln("Failed to write UDP(DNS) packet request.")
|
||||
// }
|
||||
// return
|
||||
// }
|
||||
select {
|
||||
case <-dnsReqCtx.Done():
|
||||
return
|
||||
case <-time.After(1 * time.Second):
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// We can block here because we are in a coroutine.
|
||||
respBuf := pool.GetFullCap(consts.EthernetMtu)
|
||||
defer pool.Put(respBuf)
|
||||
// Wait for response.
|
||||
n, err := conn.Read(respBuf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var msg dnsmessage.Msg
|
||||
if err = msg.Unpack(respBuf[:n]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
func (d *DoUDP) Close() error {
|
||||
if d.conn != nil {
|
||||
return d.conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sendHttpDNS(client *http.Client, target string, upstream *dns.Upstream, data []byte) (respMsg *dnsmessage.Msg, err error) {
|
||||
// disable redirect https://github.com/daeuniverse/dae/pull/649#issuecomment-2379577896
|
||||
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
return fmt.Errorf("do not use a server that will redirect, upstream: %v", upstream.String())
|
||||
}
|
||||
serverURL := url.URL{
|
||||
Scheme: "https",
|
||||
Host: target,
|
||||
Path: upstream.Path,
|
||||
}
|
||||
q := serverURL.Query()
|
||||
// According https://datatracker.ietf.org/doc/html/rfc8484#section-4
|
||||
// msg id should set to 0 when transport over HTTPS for cache friendly.
|
||||
binary.BigEndian.PutUint16(data[0:2], 0)
|
||||
q.Set("dns", base64.RawURLEncoding.EncodeToString(data))
|
||||
serverURL.RawQuery = q.Encode()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, serverURL.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Accept", "application/dns-message")
|
||||
req.Host = upstream.Hostname
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
buf, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var msg dnsmessage.Msg
|
||||
if err = msg.Unpack(buf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
func sendStreamDNS(stream io.ReadWriter, data []byte) (respMsg *dnsmessage.Msg, err error) {
|
||||
// We should write two byte length in the front of stream DNS request.
|
||||
bReq := pool.Get(2 + len(data))
|
||||
defer pool.Put(bReq)
|
||||
binary.BigEndian.PutUint16(bReq, uint16(len(data)))
|
||||
copy(bReq[2:], data)
|
||||
_, err = stream.Write(bReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to write DNS req: %w", err)
|
||||
}
|
||||
|
||||
// Read two byte length.
|
||||
if _, err = io.ReadFull(stream, bReq[:2]); err != nil {
|
||||
return nil, fmt.Errorf("failed to read DNS resp payload length: %w", err)
|
||||
}
|
||||
respLen := int(binary.BigEndian.Uint16(bReq))
|
||||
// Try to reuse the buf.
|
||||
var buf []byte
|
||||
if len(bReq) < respLen {
|
||||
buf = pool.Get(respLen)
|
||||
defer pool.Put(buf)
|
||||
} else {
|
||||
buf = bReq
|
||||
}
|
||||
var n int
|
||||
if n, err = io.ReadFull(stream, buf[:respLen]); err != nil {
|
||||
return nil, fmt.Errorf("failed to read DNS resp payload: %w", err)
|
||||
}
|
||||
var msg dnsmessage.Msg
|
||||
if err = msg.Unpack(buf[:n]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &msg, nil
|
||||
}
|
@ -7,9 +7,7 @@ package control
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"net/netip"
|
||||
@ -18,16 +16,12 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/daeuniverse/dae/common"
|
||||
|
||||
"github.com/daeuniverse/dae/common/consts"
|
||||
"github.com/daeuniverse/dae/common/netutils"
|
||||
"github.com/daeuniverse/dae/component/dns"
|
||||
"github.com/daeuniverse/dae/component/outbound"
|
||||
"github.com/daeuniverse/dae/component/outbound/dialer"
|
||||
"github.com/daeuniverse/outbound/netproxy"
|
||||
"github.com/daeuniverse/outbound/pkg/fastrand"
|
||||
"github.com/daeuniverse/outbound/pool"
|
||||
dnsmessage "github.com/miekg/dns"
|
||||
"github.com/mohae/deepcopy"
|
||||
"github.com/sirupsen/logrus"
|
||||
@ -84,6 +78,8 @@ type DnsController struct {
|
||||
// mutex protects the dnsCache.
|
||||
dnsCacheMu sync.Mutex
|
||||
dnsCache map[string]*DnsCache
|
||||
dnsForwarderCacheMu sync.Mutex
|
||||
dnsForwarderCache map[string]DnsForwarder
|
||||
}
|
||||
|
||||
func parseIpVersionPreference(prefer int) (uint16, error) {
|
||||
@ -120,6 +116,8 @@ func NewDnsController(routing *dns.Dns, option *DnsControllerOption) (c *DnsCont
|
||||
fixedDomainTtl: option.FixedDomainTtl,
|
||||
dnsCacheMu: sync.Mutex{},
|
||||
dnsCache: make(map[string]*DnsCache),
|
||||
dnsForwarderCacheMu: sync.Mutex{},
|
||||
dnsForwarderCache: make(map[string]DnsForwarder),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -558,130 +556,40 @@ func (c *DnsController) dialSend(invokingDepth int, req *udpRequest, data []byte
|
||||
// the next recursive call. However, a connection cannot be closed twice.
|
||||
// We should set a connClosed flag to avoid it.
|
||||
var connClosed bool
|
||||
var conn netproxy.Conn
|
||||
|
||||
ctxDial, cancel := context.WithTimeout(context.TODO(), consts.DefaultDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
switch dialArgument.l4proto {
|
||||
case consts.L4ProtoStr_UDP:
|
||||
// Get udp endpoint.
|
||||
|
||||
// TODO: connection pool.
|
||||
conn, err = dialArgument.bestDialer.DialContext(
|
||||
ctxDial,
|
||||
common.MagicNetwork("udp", dialArgument.mark, dialArgument.mptcp),
|
||||
dialArgument.bestTarget.String(),
|
||||
)
|
||||
// get forwarder from cache
|
||||
c.dnsForwarderCacheMu.Lock()
|
||||
forwarder, ok := c.dnsForwarderCache[upstreamName]
|
||||
if !ok {
|
||||
forwarder, err = newDnsForwarder(upstream, *dialArgument)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to dial '%v': %w", dialArgument.bestTarget, err)
|
||||
}
|
||||
defer func() {
|
||||
if !connClosed {
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
timeout := 5 * time.Second
|
||||
_ = conn.SetDeadline(time.Now().Add(timeout))
|
||||
dnsReqCtx, cancelDnsReqCtx := context.WithTimeout(context.TODO(), timeout)
|
||||
defer cancelDnsReqCtx()
|
||||
go func() {
|
||||
// Send DNS request every seconds.
|
||||
for {
|
||||
_, err = conn.Write(data)
|
||||
if err != nil {
|
||||
if c.log.IsLevelEnabled(logrus.DebugLevel) {
|
||||
c.log.WithFields(logrus.Fields{
|
||||
"to": dialArgument.bestTarget.String(),
|
||||
"pid": req.routingResult.Pid,
|
||||
"pname": ProcessName2String(req.routingResult.Pname[:]),
|
||||
"mac": Mac2String(req.routingResult.Mac[:]),
|
||||
"from": req.realSrc.String(),
|
||||
"network": networkType.String(),
|
||||
"err": err.Error(),
|
||||
}).Debugln("Failed to write UDP(DNS) packet request.")
|
||||
}
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-dnsReqCtx.Done():
|
||||
return
|
||||
case <-time.After(1 * time.Second):
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// We can block here because we are in a coroutine.
|
||||
respBuf := pool.GetFullCap(consts.EthernetMtu)
|
||||
defer pool.Put(respBuf)
|
||||
// Wait for response.
|
||||
n, err := conn.Read(respBuf)
|
||||
if err != nil {
|
||||
if c.timeoutExceedCallback != nil {
|
||||
c.timeoutExceedCallback(dialArgument, err)
|
||||
}
|
||||
return fmt.Errorf("failed to read from: %v (dialer: %v): %w", dialArgument.bestTarget, dialArgument.bestDialer.Property().Name, err)
|
||||
}
|
||||
var msg dnsmessage.Msg
|
||||
if err = msg.Unpack(respBuf[:n]); err != nil {
|
||||
c.dnsForwarderCacheMu.Unlock()
|
||||
return err
|
||||
}
|
||||
respMsg = &msg
|
||||
cancelDnsReqCtx()
|
||||
|
||||
case consts.L4ProtoStr_TCP:
|
||||
// We can block here because we are in a coroutine.
|
||||
|
||||
conn, err = dialArgument.bestDialer.DialContext(ctxDial, common.MagicNetwork("tcp", dialArgument.mark, dialArgument.mptcp), dialArgument.bestTarget.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to dial proxy to tcp: %w", err)
|
||||
c.dnsForwarderCache[upstreamName] = forwarder
|
||||
}
|
||||
c.dnsForwarderCacheMu.Unlock()
|
||||
|
||||
defer func() {
|
||||
if !connClosed {
|
||||
conn.Close()
|
||||
forwarder.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
_ = conn.SetDeadline(time.Now().Add(4900 * time.Millisecond))
|
||||
// We should write two byte length in the front of TCP DNS request.
|
||||
bReq := pool.Get(2 + len(data))
|
||||
defer pool.Put(bReq)
|
||||
binary.BigEndian.PutUint16(bReq, uint16(len(data)))
|
||||
copy(bReq[2:], data)
|
||||
_, err = conn.Write(bReq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write DNS req: %w", err)
|
||||
}
|
||||
|
||||
// Read two byte length.
|
||||
if _, err = io.ReadFull(conn, bReq[:2]); err != nil {
|
||||
return fmt.Errorf("failed to read DNS resp payload length: %w", err)
|
||||
}
|
||||
respLen := int(binary.BigEndian.Uint16(bReq))
|
||||
// Try to reuse the buf.
|
||||
var buf []byte
|
||||
if len(bReq) < respLen {
|
||||
buf = pool.Get(respLen)
|
||||
defer pool.Put(buf)
|
||||
} else {
|
||||
buf = bReq
|
||||
}
|
||||
var n int
|
||||
if n, err = io.ReadFull(conn, buf[:respLen]); err != nil {
|
||||
return fmt.Errorf("failed to read DNS resp payload: %w", err)
|
||||
}
|
||||
var msg dnsmessage.Msg
|
||||
if err = msg.Unpack(buf[:n]); err != nil {
|
||||
return err
|
||||
}
|
||||
respMsg = &msg
|
||||
default:
|
||||
return fmt.Errorf("unexpected l4proto: %v", dialArgument.l4proto)
|
||||
|
||||
respMsg, err = forwarder.ForwardDNS(ctxDial, data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Close conn before the recursive call.
|
||||
conn.Close()
|
||||
forwarder.Close()
|
||||
connClosed = true
|
||||
|
||||
// Route response.
|
||||
|
Loading…
Reference in New Issue
Block a user