dae/control/dns.go

438 lines
12 KiB
Go

package control
import (
"context"
"crypto/tls"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"net"
"net/http"
"net/url"
"time"
"github.com/daeuniverse/dae/common"
"github.com/daeuniverse/dae/common/consts"
"github.com/daeuniverse/dae/component/dns"
"github.com/daeuniverse/outbound/netproxy"
"github.com/daeuniverse/outbound/pool"
tc "github.com/daeuniverse/outbound/protocol/tuic/common"
"github.com/daeuniverse/quic-go"
"github.com/daeuniverse/quic-go/http3"
dnsmessage "github.com/miekg/dns"
)
// DnsForwarder sends one wire-format DNS query to an upstream resolver and
// returns the unpacked response. Each implementation in this file covers one
// transport (UDP, TCP, TLS, HTTPS/HTTP3, QUIC).
type DnsForwarder interface {
// ForwardDNS sends the packed DNS message in data and returns the parsed
// reply. Some implementations modify data in place: DoH and DoQ overwrite
// the first two bytes (the message ID) with zero.
ForwardDNS(ctx context.Context, data []byte) (*dnsmessage.Msg, error)
// Close releases whatever the forwarder decided to hand over for cleanup;
// for DoH and DoQ it is a no-op.
Close() error
}
// newDnsForwarder builds the DnsForwarder that matches the L4 protocol chosen
// in dialArgument and the scheme of upstream.
//
// With L4ProtoStr_TCP the schemes UpstreamScheme_TCP, UpstreamScheme_TCP_UDP,
// UpstreamScheme_TLS and UpstreamScheme_HTTPS are accepted; with
// L4ProtoStr_UDP the schemes UpstreamScheme_UDP, UpstreamScheme_TCP_UDP,
// UpstreamScheme_QUIC and UpstreamScheme_H3 are accepted. Any other
// combination yields an error. The upstream is copied into the forwarder.
func newDnsForwarder(upstream *dns.Upstream, dialArgument dialArgument) (DnsForwarder, error) {
	dialer := dialArgument.bestDialer
	scheme := upstream.Scheme

	if dialArgument.l4proto == consts.L4ProtoStr_TCP {
		switch scheme {
		case dns.UpstreamScheme_TCP, dns.UpstreamScheme_TCP_UDP:
			return &DoTCP{Upstream: *upstream, Dialer: dialer, dialArgument: dialArgument}, nil
		case dns.UpstreamScheme_TLS:
			return &DoTLS{Upstream: *upstream, Dialer: dialer, dialArgument: dialArgument}, nil
		case dns.UpstreamScheme_HTTPS:
			return &DoH{Upstream: *upstream, Dialer: dialer, dialArgument: dialArgument, http3: false}, nil
		}
		return nil, fmt.Errorf("unexpected scheme: %v", scheme)
	}

	if dialArgument.l4proto == consts.L4ProtoStr_UDP {
		switch scheme {
		case dns.UpstreamScheme_UDP, dns.UpstreamScheme_TCP_UDP:
			return &DoUDP{Upstream: *upstream, Dialer: dialer, dialArgument: dialArgument}, nil
		case dns.UpstreamScheme_QUIC:
			return &DoQ{Upstream: *upstream, Dialer: dialer, dialArgument: dialArgument}, nil
		case dns.UpstreamScheme_H3:
			return &DoH{Upstream: *upstream, Dialer: dialer, dialArgument: dialArgument, http3: true}, nil
		}
		return nil, fmt.Errorf("unexpected scheme: %v", scheme)
	}

	return nil, fmt.Errorf("unexpected l4proto: %v", dialArgument.l4proto)
}
// DoH forwards DNS queries as HTTPS GET requests, over the net/http transport
// or, when http3 is set, over an HTTP/3 round tripper.
type DoH struct {
dns.Upstream
netproxy.Dialer
// dialArgument supplies the dialer, target address, mark and mptcp flag
// used when the transports open their connections.
dialArgument dialArgument
// http3 selects the HTTP/3 round tripper in getClient.
http3 bool
// client is created lazily by ForwardDNS and replaced after a failed request.
client *http.Client
}
// ForwardDNS sends data to the upstream as a DNS-over-HTTPS request and
// returns the unpacked response.
//
// The HTTP client is built on first use and kept for later calls. When a
// request fails, the client is rebuilt and the request is retried exactly
// once; if the retry fails as well, its error is returned.
//
// ctx is accepted to satisfy DnsForwarder but is not propagated to the HTTP
// request, because sendHttpDNS takes no context.
func (d *DoH) ForwardDNS(ctx context.Context, data []byte) (*dnsmessage.Msg, error) {
	if d.client == nil {
		d.client = d.getClient()
	}
	target := d.dialArgument.bestTarget.String()

	msg, err := sendHttpDNS(d.client, target, &d.Upstream, data)
	if err == nil {
		return msg, nil
	}

	// The request failed, so retry once with a fresh client. Release the
	// idle connections of the client being discarded first; otherwise they
	// stay open until they time out on their own.
	d.client.CloseIdleConnections()
	d.client = d.getClient()

	msg, err = sendHttpDNS(d.client, target, &d.Upstream, data)
	if err != nil {
		return nil, err
	}
	return msg, nil
}
// getClient returns a new HTTP client whose transport is the HTTP/3 round
// tripper when d.http3 is set and the regular HTTP transport otherwise.
func (d *DoH) getClient() *http.Client {
	client := new(http.Client)
	if d.http3 {
		client.Transport = d.getHttp3RoundTripper()
		return client
	}
	client.Transport = d.getHttpRoundTripper()
	return client
}
// getHttpRoundTripper builds the transport used for DNS-over-HTTPS.
//
// TLS verifies the certificate against the upstream hostname. Every
// connection is opened through the best dialer to the best target; the
// network and address that net/http passes to DialContext are ignored.
func (d *DoH) getHttpRoundTripper() *http.Transport {
	dial := func(ctx context.Context, network, addr string) (net.Conn, error) {
		proxyConn, err := d.dialArgument.bestDialer.DialContext(
			ctx,
			common.MagicNetwork("tcp", d.dialArgument.mark, d.dialArgument.mptcp),
			d.dialArgument.bestTarget.String(),
		)
		if err != nil {
			return nil, err
		}
		// Wrap the proxy connection so it can be used as a net.Conn.
		return &netproxy.FakeNetConn{Conn: proxyConn}, nil
	}

	tlsCfg := &tls.Config{
		ServerName:         d.Upstream.Hostname,
		InsecureSkipVerify: false,
	}

	return &http.Transport{
		TLSClientConfig: tlsCfg,
		DialContext:     dial,
	}
}
// getHttp3RoundTripper builds the HTTP/3 round tripper used for DNS-over-HTTP/3.
//
// TLS verifies the certificate against the upstream hostname and offers the
// "h3" ALPN. The Dial hook ignores the address supplied by http3 and always
// opens a UDP association to the best target through the best dialer, then
// runs QUIC on top of it.
func (d *DoH) getHttp3RoundTripper() *http3.RoundTripper {
	return &http3.RoundTripper{
		TLSClientConfig: &tls.Config{
			ServerName:         d.Upstream.Hostname,
			NextProtos:         []string{"h3"},
			InsecureSkipVerify: false,
		},
		QuicConfig: &quic.Config{},
		Dial: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
			conn, err := d.dialArgument.bestDialer.DialContext(
				ctx,
				common.MagicNetwork("udp", d.dialArgument.mark, d.dialArgument.mptcp),
				d.dialArgument.bestTarget.String(),
			)
			if err != nil {
				return nil, err
			}
			// QUIC needs packet semantics; fail cleanly instead of panicking
			// when the dialer hands back a connection without them.
			pktConn, ok := conn.(netproxy.PacketConn)
			if !ok {
				_ = conn.Close()
				return nil, fmt.Errorf("dialer returned %T, which is not a packet conn", conn)
			}
			udpAddr := net.UDPAddrFromAddrPort(d.dialArgument.bestTarget)
			fakePkt := netproxy.NewFakeNetPacketConn(pktConn, net.UDPAddrFromAddrPort(tc.GetUniqueFakeAddrPort()), udpAddr)
			qc, err := quic.DialEarly(ctx, fakePkt, udpAddr, tlsCfg, cfg)
			if err != nil {
				// The packet conn was supplied by us, so it is ours to
				// release when the QUIC dial fails.
				_ = conn.Close()
				return nil, err
			}
			return qc, nil
		},
	}
}
// Close is a no-op and always returns nil; the HTTP client and its
// connections are left untouched.
// NOTE(review): presumably so the client can be reused by later ForwardDNS
// calls — confirm against the caller.
func (d *DoH) Close() error {
return nil
}
// DoQ forwards DNS queries over QUIC using the "doq" ALPN.
type DoQ struct {
dns.Upstream
netproxy.Dialer
// dialArgument supplies the dialer, target address, mark and mptcp flag
// used when the QUIC connection is opened.
dialArgument dialArgument
// connection is dialed on the first ForwardDNS call, kept for later calls,
// and re-dialed when opening a stream on it fails.
connection quic.EarlyConnection
}
// ForwardDNS sends data to the upstream over a new stream on the QUIC
// connection and returns the unpacked response.
//
// The connection is dialed on first use and kept for later calls. If a stream
// cannot be opened on it, a replacement connection is dialed, the old one is
// closed, and opening the stream is retried once.
//
// The first two bytes of data (the DNS message ID) are overwritten with zero.
// An error is returned when data is shorter than two bytes.
func (d *DoQ) ForwardDNS(ctx context.Context, data []byte) (*dnsmessage.Msg, error) {
	if len(data) < 2 {
		return nil, fmt.Errorf("dns message too short: %d bytes", len(data))
	}

	if d.connection == nil {
		qc, err := d.createConnection(ctx)
		if err != nil {
			return nil, err
		}
		d.connection = qc
	}

	stream, err := d.connection.OpenStreamSync(ctx)
	if err != nil {
		// The kept connection is unusable: dial a replacement. The old
		// connection is closed only once the new one exists, so a failed
		// re-dial leaves the current state untouched.
		qc, dialErr := d.createConnection(ctx)
		if dialErr != nil {
			return nil, dialErr
		}
		_ = d.connection.CloseWithError(0, "")
		d.connection = qc

		stream, err = d.connection.OpenStreamSync(ctx)
		if err != nil {
			return nil, err
		}
	}
	defer func() {
		_ = stream.Close()
	}()

	// Per https://datatracker.ietf.org/doc/html/rfc9250#section-4.2.1 the
	// message ID must be set to 0 when DNS is carried over QUIC.
	// Thanks to https://github.com/natesales/q/blob/1cb2639caf69bd0a9b46494a3c689130df8fb24a/transport/quic.go#L97
	binary.BigEndian.PutUint16(data[0:2], 0)

	msg, err := sendStreamDNS(stream, data)
	if err != nil {
		return nil, err
	}
	return msg, nil
}
// createConnection opens a UDP association to the best target through the
// best dialer and establishes a QUIC connection on top of it.
//
// TLS verifies the certificate against the upstream hostname and offers the
// "doq" ALPN. The default QUIC configuration is used. On any failure the
// dialed association is closed before the error is returned.
func (d *DoQ) createConnection(ctx context.Context) (quic.EarlyConnection, error) {
	conn, err := d.dialArgument.bestDialer.DialContext(
		ctx,
		common.MagicNetwork("udp", d.dialArgument.mark, d.dialArgument.mptcp),
		d.dialArgument.bestTarget.String(),
	)
	if err != nil {
		return nil, err
	}

	// QUIC needs packet semantics; fail cleanly instead of panicking when
	// the dialer hands back a connection without them.
	pktConn, ok := conn.(netproxy.PacketConn)
	if !ok {
		_ = conn.Close()
		return nil, fmt.Errorf("dialer returned %T, which is not a packet conn", conn)
	}

	udpAddr := net.UDPAddrFromAddrPort(d.dialArgument.bestTarget)
	fakePkt := netproxy.NewFakeNetPacketConn(pktConn, net.UDPAddrFromAddrPort(tc.GetUniqueFakeAddrPort()), udpAddr)
	tlsCfg := &tls.Config{
		NextProtos:         []string{"doq"},
		InsecureSkipVerify: false,
		ServerName:         d.Upstream.Hostname,
	}

	qc, err := quic.DialEarly(ctx, fakePkt, udpAddr, tlsCfg, nil)
	if err != nil {
		// The packet conn was supplied by us, so it is ours to release when
		// the QUIC dial fails.
		_ = conn.Close()
		return nil, err
	}
	return qc, nil
}
// Close is a no-op and always returns nil; the QUIC connection held in
// d.connection is left open.
// NOTE(review): presumably so the connection can be reused by later
// ForwardDNS calls — confirm against the caller.
func (d *DoQ) Close() error {
return nil
}
// DoTLS forwards DNS queries over TLS on top of a TCP connection.
type DoTLS struct {
dns.Upstream
netproxy.Dialer
// dialArgument supplies the dialer, target address, mark and mptcp flag
// used when the connection is opened.
dialArgument dialArgument
// conn is the TLS connection stored by ForwardDNS after a successful
// handshake; Close closes it.
conn netproxy.Conn
}
// ForwardDNS dials the best target over TCP, performs a TLS handshake that
// verifies the upstream hostname, and exchanges one length-prefixed DNS
// message over the TLS connection.
//
// The handshake honours ctx. If it fails, the dialed connection is closed
// before the error is returned. On success the TLS connection is stored in
// d.conn so that Close can release it.
func (d *DoTLS) ForwardDNS(ctx context.Context, data []byte) (*dnsmessage.Msg, error) {
	conn, err := d.dialArgument.bestDialer.DialContext(
		ctx,
		common.MagicNetwork("tcp", d.dialArgument.mark, d.dialArgument.mptcp),
		d.dialArgument.bestTarget.String(),
	)
	if err != nil {
		return nil, err
	}

	tlsConn := tls.Client(&netproxy.FakeNetConn{Conn: conn}, &tls.Config{
		InsecureSkipVerify: false,
		ServerName:         d.Upstream.Hostname,
	})
	if err = tlsConn.HandshakeContext(ctx); err != nil {
		// d.conn has not been set yet, so Close would never see this
		// connection; release it here.
		_ = conn.Close()
		return nil, err
	}

	d.conn = tlsConn
	return sendStreamDNS(tlsConn, data)
}
// Close closes the TLS connection stored in d.conn and returns its error.
// It returns nil when no connection has been stored.
func (d *DoTLS) Close() error {
	if d.conn == nil {
		return nil
	}
	return d.conn.Close()
}
// DoTCP forwards DNS queries over a plain TCP connection.
type DoTCP struct {
dns.Upstream
netproxy.Dialer
// dialArgument supplies the dialer, target address, mark and mptcp flag
// used when the connection is opened.
dialArgument dialArgument
// conn is the connection stored by ForwardDNS right after dialing; Close
// closes it.
conn netproxy.Conn
}
// ForwardDNS dials the best target over TCP, stores the connection in d.conn
// so that Close can release it, and exchanges one length-prefixed DNS message.
func (d *DoTCP) ForwardDNS(ctx context.Context, data []byte) (*dnsmessage.Msg, error) {
	arg := &d.dialArgument
	network := common.MagicNetwork("tcp", arg.mark, arg.mptcp)

	tcpConn, err := arg.bestDialer.DialContext(ctx, network, arg.bestTarget.String())
	if err != nil {
		return nil, err
	}

	d.conn = tcpConn
	return sendStreamDNS(tcpConn, data)
}
// Close closes the connection stored in d.conn and returns its error.
// It returns nil when no connection has been stored.
func (d *DoTCP) Close() error {
	if d.conn == nil {
		return nil
	}
	return d.conn.Close()
}
// DoUDP forwards DNS queries over UDP.
type DoUDP struct {
dns.Upstream
netproxy.Dialer
// dialArgument supplies the dialer, target address, mark and mptcp flag
// used when the connection is opened.
dialArgument dialArgument
// conn is the connection that Close closes when it is set.
conn netproxy.Conn
}
// ForwardDNS dials the best target over UDP, sends data, and returns the
// first datagram received, unpacked as a DNS message.
//
// Because UDP is lossy, the query is re-sent once per second until a response
// arrives, ctx is cancelled, or the 5-second overall timeout elapses. The
// same timeout is applied to the connection as a read/write deadline. The
// connection is stored in d.conn so that Close can release it.
func (d *DoUDP) ForwardDNS(ctx context.Context, data []byte) (*dnsmessage.Msg, error) {
	conn, err := d.dialArgument.bestDialer.DialContext(
		ctx,
		common.MagicNetwork("udp", d.dialArgument.mark, d.dialArgument.mptcp),
		d.dialArgument.bestTarget.String(),
	)
	if err != nil {
		return nil, err
	}
	// Hand the connection to Close, as DoTCP and DoTLS do; without this the
	// socket is never closed.
	d.conn = conn

	const (
		timeout        = 5 * time.Second
		resendInterval = 1 * time.Second
	)
	// Best effort: if the deadline cannot be set, the read below simply has
	// no upper bound from the connection itself.
	_ = conn.SetDeadline(time.Now().Add(timeout))

	dnsReqCtx, cancelDnsReqCtx := context.WithTimeout(ctx, timeout)
	defer cancelDnsReqCtx()

	go func() {
		ticker := time.NewTicker(resendInterval)
		defer ticker.Stop()
		// Send the query immediately, then once per interval until the
		// request context ends (response received, timeout, or cancel).
		for {
			// Write errors are deliberately ignored: a later resend may
			// succeed, and a persistent failure surfaces as a read timeout.
			// The result is kept out of the outer err so this goroutine does
			// not race with the read below.
			_, _ = conn.Write(data)

			select {
			case <-dnsReqCtx.Done():
				return
			case <-ticker.C:
			}
		}
	}()

	respBuf := pool.GetFullCap(consts.EthernetMtu)
	defer pool.Put(respBuf)

	// Block until a response arrives or the connection deadline expires.
	n, err := conn.Read(respBuf)
	if err != nil {
		return nil, err
	}

	var msg dnsmessage.Msg
	if err = msg.Unpack(respBuf[:n]); err != nil {
		return nil, err
	}
	return &msg, nil
}
// Close closes the connection stored in d.conn and returns its error.
// It returns nil when d.conn is not set.
func (d *DoUDP) Close() error {
	if d.conn == nil {
		return nil
	}
	return d.conn.Close()
}
// sendHttpDNS performs one DNS-over-HTTPS exchange using the GET method
// (RFC 8484) and returns the unpacked response.
//
// The request goes to https://<target><upstream.Path>?dns=<base64url(data)>
// with the Host header set to upstream.Hostname, so target can be a literal
// address while the server still sees the configured name.
//
// Side effects: the first two bytes of data (the DNS message ID) are
// overwritten with zero, and client.CheckRedirect is replaced so that any
// redirect fails the request.
//
// An error is returned when data is shorter than two bytes, when the request
// fails, when the server answers with a non-2xx status, or when the body is
// not a valid DNS message.
func sendHttpDNS(client *http.Client, target string, upstream *dns.Upstream, data []byte) (respMsg *dnsmessage.Msg, err error) {
	if len(data) < 2 {
		return nil, fmt.Errorf("dns message too short: %d bytes", len(data))
	}

	// Redirects are disabled, see
	// https://github.com/daeuniverse/dae/pull/649#issuecomment-2379577896
	client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
		return fmt.Errorf("do not use a server that will redirect, upstream: %v", upstream.String())
	}

	// Per https://datatracker.ietf.org/doc/html/rfc8484#section-4 the message
	// ID should be 0 over HTTPS so that responses are cache friendly.
	binary.BigEndian.PutUint16(data[0:2], 0)

	query := url.Values{}
	query.Set("dns", base64.RawURLEncoding.EncodeToString(data))
	serverURL := url.URL{
		Scheme:   "https",
		Host:     target,
		Path:     upstream.Path,
		RawQuery: query.Encode(),
	}

	req, err := http.NewRequest(http.MethodGet, serverURL.String(), nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Accept", "application/dns-message")
	req.Host = upstream.Hostname

	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	// Only a 2xx response carries a DNS message; anything else (error page,
	// rate limit, ...) must not be fed to the DNS parser.
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return nil, fmt.Errorf("unexpected http status %q from upstream: %v", resp.Status, upstream.String())
	}

	buf, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	var msg dnsmessage.Msg
	if err = msg.Unpack(buf); err != nil {
		return nil, err
	}
	return &msg, nil
}
// sendStreamDNS performs one DNS exchange over a byte stream using the
// two-byte big-endian length prefix: it writes the framed request in a single
// Write, reads the length of the response, reads exactly that many bytes, and
// unpacks them.
//
// An error is returned when data does not fit the 16-bit length prefix, when
// writing or reading fails, or when the payload is not a valid DNS message.
func sendStreamDNS(stream io.ReadWriter, data []byte) (respMsg *dnsmessage.Msg, err error) {
	// The length prefix is 16 bits wide; a larger message would be silently
	// truncated by the conversion below and corrupt the framing.
	if len(data) > 0xffff {
		return nil, fmt.Errorf("dns request too large for stream transport: %d bytes", len(data))
	}

	// Frame the request: two-byte length followed by the message.
	bReq := pool.Get(2 + len(data))
	defer pool.Put(bReq)
	binary.BigEndian.PutUint16(bReq, uint16(len(data)))
	copy(bReq[2:], data)
	_, err = stream.Write(bReq)
	if err != nil {
		return nil, fmt.Errorf("failed to write DNS req: %w", err)
	}

	// Read the two-byte response length into the now-unused request buffer.
	if _, err = io.ReadFull(stream, bReq[:2]); err != nil {
		return nil, fmt.Errorf("failed to read DNS resp payload length: %w", err)
	}
	respLen := int(binary.BigEndian.Uint16(bReq))

	// Reuse the request buffer for the payload when it is large enough.
	buf := bReq
	if len(bReq) < respLen {
		buf = pool.Get(respLen)
		defer pool.Put(buf)
	}

	var n int
	if n, err = io.ReadFull(stream, buf[:respLen]); err != nil {
		return nil, fmt.Errorf("failed to read DNS resp payload: %w", err)
	}

	var msg dnsmessage.Msg
	if err = msg.Unpack(buf[:n]); err != nil {
		return nil, err
	}
	return &msg, nil
}