feat: dns upstream

This commit is contained in:
mzz2017
2023-01-24 17:15:27 +08:00
parent 14b215752f
commit 686d6dedc3
4 changed files with 59 additions and 23 deletions

View File

@ -44,7 +44,9 @@ type ControlPlane struct {
// mutex protects the dnsCache.
mutex sync.Mutex
dnsCache map[string]*dnsCache
epoch uint32
// Deprecated
epoch uint32
dnsUpstream netip.AddrPort
deferFuncs []func() error
}
@ -92,12 +94,12 @@ retry_load:
if err := bpf.ParamMap.Update(consts.DisableL4RxChecksumKey, consts.DisableL4ChecksumPolicy_SetZero, ebpf.UpdateAny); err != nil {
return nil, err
}
var epoch uint32
bpf.ParamMap.Lookup(consts.EpochKey, &epoch)
epoch++
if err := bpf.ParamMap.Update(consts.EpochKey, epoch, ebpf.UpdateAny); err != nil {
return nil, err
}
//var epoch uint32
//bpf.ParamMap.Lookup(consts.EpochKey, &epoch)
//epoch++
//if err := bpf.ParamMap.Update(consts.EpochKey, epoch, ebpf.UpdateAny); err != nil {
// return nil, err
//}
//if err := bpf.ParamMap.Update(consts.InterfaceIpParamOff, binary.LittleEndian.Uint32([]byte{172, 17, 0, 1}), ebpf.UpdateAny); err != nil { // 172.17.0.1
// return nil, err
//}
@ -110,6 +112,16 @@ retry_load:
// return
//}
cfDnsAddr := netip.AddrFrom4([4]byte{1, 1, 1, 1})
cfDnsAddr16 := cfDnsAddr.As16()
cfDnsPort := uint16(53)
if err := bpf.DnsUpstreamMap.Update(consts.ZeroKey, bpfIpPort{
Ip: common.Ipv6ByteSliceToUint32Array(cfDnsAddr16[:]),
Port: swap16(cfDnsPort),
}, ebpf.UpdateAny); err != nil {
return nil, err
}
/**/
// TODO:
d, err := dialer.NewFromLink("socks5://localhost:1080#proxy")
@ -177,8 +189,9 @@ retry_load:
Final: final,
mutex: sync.Mutex{},
dnsCache: make(map[string]*dnsCache),
epoch: epoch,
deferFuncs: []func() error{bpf.Close},
dnsUpstream: netip.AddrPortFrom(cfDnsAddr, cfDnsPort),
//epoch: epoch,
deferFuncs: []func() error{bpf.Close},
}, nil
}

View File

@ -58,7 +58,7 @@ enum {
#define OUTBOUND_LOGICAL_AND 0xFF
// Param keys:
static const __u32 ips_len_key __attribute__((unused, deprecated)) = 0;
static const __u32 zero_key = 0;
static const __u32 tproxy_port_key = 1;
static const __u32 disable_l4_tx_checksum_key = 2;
static const __u32 disable_l4_rx_checksum_key = 3;
@ -110,6 +110,14 @@ struct {
__uint(pinning, LIBBPF_PIN_BY_NAME);
} param_map SEC(".maps");
// Dns upstream:
struct {
__uint(type, BPF_MAP_TYPE_ARRAY);
__type(key, __u32);
__type(value, struct ip_port);
__uint(max_entries, 1);
} dns_upstream_map SEC(".maps");
// Interface Ips:
struct if_ip {
__be32 ip4[4];
@ -753,6 +761,15 @@ static long routing(__u8 flag[2], void *l4_hdr, __be32 saddr[4],
h_dport = bpf_ntohs(((struct udphdr *)l4_hdr)->dest);
h_sport = bpf_ntohs(((struct udphdr *)l4_hdr)->source);
}
// Modify DNS upstream for routing.
if (h_dport == 53 && _network == NETWORK_TYPE_UDP) {
struct ip_port* upstream = bpf_map_lookup_elem(&dns_upstream_map, &zero_key);
if (!upstream) {
return -EFAULT;
}
h_dport = bpf_ntohs(upstream->port);
__builtin_memcpy(daddr, upstream->ip, IPV6_BYTE_LENGTH);
}
struct lpm_key lpm_key_saddr, lpm_key_daddr, lpm_key_mac, *lpm_key;
lpm_key_saddr.trie_key.prefixlen = IPV6_BYTE_LENGTH * 8;
lpm_key_daddr.trie_key.prefixlen = IPV6_BYTE_LENGTH * 8;
@ -798,9 +815,6 @@ static long routing(__u8 flag[2], void *l4_hdr, __be32 saddr[4],
if (!bpf_map_lookup_elem(lpm, lpm_key)) {
// Routing not hit.
bad_rule = true;
bpf_printk("index: %u not hit", routing->index);
} else {
bpf_printk("index: %u hit", routing->index);
}
} else if (routing->type == ROUTING_TYPE_DOMAIN_SET) {
// Bottleneck of insns limit.
@ -854,7 +868,7 @@ static long routing(__u8 flag[2], void *l4_hdr, __be32 saddr[4],
if (!bad_rule) {
if (routing->outbound == OUTBOUND_DIRECT && h_dport == 53 &&
_network == NETWORK_TYPE_UDP) {
// DNS packet should go through control plane.
// DNS packet should go through control plane.
return OUTBOUND_CONTROL_PLANE_DIRECT;
}
return routing->outbound;

View File

@ -76,7 +76,7 @@ func sendPktWithHdr(data []byte, from netip.AddrPort, lConn *net.UDPConn, to net
return err
}
func (c *ControlPlane) RelayToUDP(lConn *net.UDPConn, to netip.AddrPort, isDNS bool) UdpHandler {
func (c *ControlPlane) RelayToUDP(lConn *net.UDPConn, to netip.AddrPort, isDNS bool, dummyFrom *netip.AddrPort) UdpHandler {
return func(data []byte, from netip.AddrPort) (err error) {
if isDNS {
data, err = c.DnsRespHandler(data)
@ -84,6 +84,9 @@ func (c *ControlPlane) RelayToUDP(lConn *net.UDPConn, to netip.AddrPort, isDNS b
c.log.Warnf("DnsRespHandler: %v", err)
}
}
if dummyFrom != nil {
from = *dummyFrom
}
return sendPktWithHdr(data, from, lConn, to)
}
}
@ -104,9 +107,11 @@ func (c *ControlPlane) handlePkt(data []byte, lConn *net.UDPConn, lAddrPort neti
dnsMessage, natTimeout := ChooseNatTimeout(data)
// We should cache DNS records and set record TTL to 0, in order to monitor the dns req and resp in real time.
isDns := dnsMessage != nil
var dummyFrom *netip.AddrPort
dest := addrHdr.Dest
if isDns {
if resp := c.LookupDnsRespCache(dnsMessage); resp != nil {
if err = sendPktWithHdr(resp, addrHdr.Dest, lConn, lAddrPort); err != nil {
if err = sendPktWithHdr(resp, dest, lConn, lAddrPort); err != nil {
return fmt.Errorf("failed to write cached DNS resp: %w", err)
}
q := dnsMessage.Questions[0]
@ -115,28 +120,33 @@ func (c *ControlPlane) handlePkt(data []byte, lConn *net.UDPConn, lAddrPort neti
)
return nil
} else {
c.log.Debugf("Modify dns target %v to upstream: %v", addrHdr.Dest.String(), c.dnsUpstream.String())
// Modify dns target to upstream.
// NOTICE: Routing was calculated in advance by the eBPF program.
dummyFrom = &addrHdr.Dest
dest = c.dnsUpstream
q := dnsMessage.Questions[0]
c.log.Debugf("UDP(DNS) %v <-[%v]-> %v: %v %v",
lAddrPort.String(), outbound.Name, addrHdr.Dest.String(), q.Name, q.Type,
lAddrPort.String(), outbound.Name, dest.String(), q.Name, q.Type,
)
}
} else {
// TODO: Set-up ip to domain mapping and show domain if possible.
c.log.Infof("UDP %v <-[%v]-> %v",
lAddrPort.String(), outbound.Name, addrHdr.Dest.String(),
lAddrPort.String(), outbound.Name, dest.String(),
)
}
ue, err := DefaultUdpEndpointPool.GetOrCreate(lAddrPort, &UdpEndpointOptions{
Handler: c.RelayToUDP(lConn, lAddrPort, isDns),
Handler: c.RelayToUDP(lConn, lAddrPort, isDns, dummyFrom),
NatTimeout: natTimeout,
Dialer: outbound,
Target: addrHdr.Dest,
Target: dest,
})
if err != nil {
return fmt.Errorf("failed to GetOrCreate: %w", err)
}
//log.Printf("WriteToUDPAddrPort->%v", dest)
_, err = ue.WriteToUDPAddrPort(data, addrHdr.Dest)
_, err = ue.WriteToUDPAddrPort(data, dest)
if err != nil {
return fmt.Errorf("failed to write UDP packet req: %w", err)
}