working AXFR, printable enumbs

This commit is contained in:
bert hubert 2018-04-08 23:20:11 +02:00
parent 65a39afcfa
commit 515bfca635
5 changed files with 317 additions and 141 deletions

View File

@ -12,5 +12,5 @@ clean:
-include *.d
tdns: tdns.o ext/simplesocket/comboaddress.o ext/simplesocket/sclasses.o ext/simplesocket/swrappers.o
tdns: tdns.o dns-storage.o ext/simplesocket/comboaddress.o ext/simplesocket/sclasses.o ext/simplesocket/swrappers.o
g++ -std=gnu++14 $^ -o $@ -pthread

81
tdns/dns-storage.cc Normal file
View File

@ -0,0 +1,81 @@
#include "dns-storage.hh"
using namespace std;
const DNSNode* DNSNode::find(dnsname& name, dnsname& last, bool* passedZonecut) const
{
cout<<"find for '"<<name<<"', last is now '"<<last<<"'"<<endl;
if(!last.empty() && passedZonecut && rrsets.count(DNSType::NS)) {
*passedZonecut=true;
}
if(name.empty()) {
cout<<"Empty lookup, returning this node or 0"<<endl;
if(!zone && rrsets.empty()) // only root zone can have this
return 0;
else
return this;
}
cout<<"Children at this node: ";
for(const auto& c: children) cout <<"'"<<c.first<<"' ";
cout<<endl;
auto iter = children.find(name.back());
cout<<"Looked for child called '"<<name.back()<<"'"<<endl;
if(iter == children.end()) {
cout<<"Found nothing, trying wildcard"<<endl;
iter = children.find("*");
if(iter == children.end()) {
cout<<"Still nothing, returning leaf"<<endl;
return this;
}
else {
cout<<"Had wildcard match, following"<<endl;
}
}
cout<<"Had match, continuing to child '"<<iter->first<<"'"<<endl;
last.push_front(name.back());
name.pop_back();
return iter->second.find(name, last, passedZonecut);
}
DNSNode* DNSNode::add(dnsname name)
{
cout<<"Add for '"<<name<<"'"<<endl;
if(name.size() == 1) {
cout<<"Last label, adding "<<name.front()<<endl;
return &children[name.front()];
}
auto back = name.back();
name.pop_back();
auto iter = children.find(back);
if(iter == children.end()) {
cout<<"Inserting new child for "<<back<<endl;
return children[back].add(name);
}
return iter->second.add(name);
}
dnsname operator+(const dnsname& a, const dnsname& b)
{
dnsname ret=a;
for(const auto& l : b.d_name)
ret.d_name.push_back(l);
return ret;
}
void DNSNode::visit(std::function<void(const dnsname& name, const DNSNode*)> visitor, dnsname name) const
{
visitor(name, this);
for(const auto& c : children)
c.second.visit(visitor, dnsname{c.first}+name);
}
// this should perform escaping rules!
std::ostream & operator<<(std::ostream &os, const dnsname& d)
{
for(const auto& l : d.d_name) {
os<<l<<".";
}
return os;
}

85
tdns/dns-storage.hh Normal file
View File

@ -0,0 +1,85 @@
#pragma once
#include <strings.h>
#include <string>
#include <map>
#include <vector>
#include <deque>
#include <iostream>
#include <cstdint>
#include <functional>
#include "nenum.hh"
typedef std::string dnslabel;
enum class RCode
{
Noerror = 0, Servfail = 2, Nxdomain = 3, Notimp = 4, Refused = 5
};
SMARTENUMSTART(RCode)
SENUM5(RCode, Noerror, Servfail, Nxdomain, Notimp, Refused)
SMARTENUMEND(RCode)
enum class DNSType : uint16_t
{
A = 1, NS = 2, CNAME = 5, SOA=6, MX=15, AAAA = 28, IXFR = 251, AXFR = 252, ANY = 255
};
SMARTENUMSTART(DNSType)
SENUM11(DNSType, A, NS, CNAME, SOA, MX, AAAA, IXFR, AAAA, IXFR, AXFR, ANY)
SMARTENUMEND(DNSType)
enum class DNSSection
{
Question, Answer, Authority, Additional
};
SMARTENUMSTART(DNSSection)
SENUM4(DNSSection, Question, Answer, Authority, Additional)
SMARTENUMEND(DNSSection)
struct dnsname
{
dnsname() {}
dnsname(std::initializer_list<dnslabel> dls) : d_name(dls) {}
void push_back(const dnslabel& l) { d_name.push_back(l); }
auto back() const { return d_name.back(); }
auto begin() const { return d_name.begin(); }
bool empty() const { return d_name.empty(); }
auto end() const { return d_name.end(); }
auto front() const { return d_name.front(); }
void pop_back() { d_name.pop_back(); }
auto push_front(const dnslabel& dn) { return d_name.push_front(dn); }
auto size() { return d_name.size(); }
std::deque<dnslabel> d_name;
};
std::ostream & operator<<(std::ostream &os, const dnsname& d);
dnsname operator+(const dnsname& a, const dnsname& b);
struct RRSet
{
std::vector<std::string> contents;
uint32_t ttl{3600};
};
struct DNSLabelCompare: public std::binary_function<std::string, std::string, bool>
{
bool operator()(const dnslabel& a, const dnslabel& b) const
{
return strcasecmp(a.c_str(), b.c_str()) < 0; // XXX locale pain, plus embedded zeros
}
};
struct DNSNode
{
const DNSNode* find(dnsname& name, dnsname& last, bool* passedZonecut=0) const;
DNSNode* add(dnsname name);
std::map<dnslabel, DNSNode, DNSLabelCompare> children;
std::map<DNSType, RRSet > rrsets;
void visit(std::function<void(const dnsname& name, const DNSNode*)> visitor, dnsname name) const;
DNSNode* zone{0}; // if this is set, this node is a zone
};

40
tdns/nenum.hh Normal file
View File

@ -0,0 +1,40 @@
#pragma once
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <algorithm>
#include <array>
#include <string.h>
#define SMARTENUMSTART(x) static constexpr std::pair<x, const char*> enumtypemap##x[]= {
#define SENUM(x,a1) { x::a1, #a1},
#define SENUM2(x, a1, ...) SENUM(x,a1) SENUM(x, __VA_ARGS__)
#define SENUM3(x, a1, ...) SENUM(x,a1) SENUM2(x, __VA_ARGS__)
#define SENUM4(x, a1, ...) SENUM(x,a1) SENUM3(x, __VA_ARGS__)
#define SENUM5(x, a1, ...) SENUM(x,a1) SENUM4(x, __VA_ARGS__)
#define SENUM6(x, a1, ...) SENUM(x,a1) SENUM5(x, __VA_ARGS__)
#define SENUM7(x, a1, ...) SENUM(x,a1) SENUM6(x, __VA_ARGS__)
#define SENUM8(x, a1, ...) SENUM(x,a1) SENUM7(x, __VA_ARGS__)
#define SENUM9(x, a1, ...) SENUM(x,a1) SENUM8(x, __VA_ARGS__)
#define SENUM10(x, a1, ...) SENUM(x,a1) SENUM9(x, __VA_ARGS__)
#define SENUM11(x, a1, ...) SENUM(x,a1) SENUM10(x, __VA_ARGS__)
#define SMARTENUMEND(x) }; \
inline const char* toString(const x& t) \
{ \
for(const auto &a : enumtypemap##x) \
if(a.first == t) \
return a.second; \
return "?"; \
} \
inline x make##x(const char* from) { \
for(const auto& a : enumtypemap##x) \
if(!strcmp(a.second, from)) \
return a.first; \
throw std::runtime_error("Unknown value '" + std::string(from) + "' for enum "#x); \
} \
inline std::ostream& operator<<(std::ostream &os, const x& s) { \
os << toString(s); return os; } \

View File

@ -12,116 +12,10 @@
#include "dns.hh"
#include "safearray.hh"
#include <thread>
#include "dns-storage.hh"
using namespace std;
typedef std::string dnslabel;
enum class RCode
{
Noerror = 0, Servfail = 2, Nxdomain = 3, Notimp = 4, Refused = 5
};
enum class DNSType : uint16_t
{
A = 1, NS = 2, CNAME = 5, SOA=6, AAAA = 28, IXFR = 251, AXFR = 252, ANY = 255
};
enum class DNSSection
{
Question, Answer, Authority, Additional
};
typedef deque<dnslabel> dnsname;
// this should perform escaping rules!
static std::ostream & operator<<(std::ostream &os, const dnsname& d)
{
for(const auto& l : d) {
os<<l<<".";
}
return os;
}
struct RRSet
{
vector<string> contents;
uint32_t ttl{3600};
};
struct CIStringCompare: public std::binary_function<string, string, bool>
{
bool operator()(const string& a, const string& b) const
{
return strcasecmp(a.c_str(), b.c_str()) < 0; // XXX locale pain, plus embedded zeros
}
};
struct DNSNode
{
const DNSNode* find(dnsname& name, dnsname& last, bool* passedZonecut=0) const;
DNSNode* add(dnsname name);
map<dnslabel, DNSNode, CIStringCompare> children;
map<DNSType, RRSet > rrsets;
DNSNode* zone{0}; // if this is set, this node is a zone
};
const DNSNode* DNSNode::find(dnsname& name, dnsname& last, bool* passedZonecut) const
{
cout<<"find for '"<<name<<"', last is now '"<<last<<"'"<<endl;
if(!last.empty() && passedZonecut && rrsets.count(DNSType::NS)) {
*passedZonecut=true;
}
if(name.empty()) {
cout<<"Empty lookup, returning this node or 0"<<endl;
if(!zone && rrsets.empty()) // only root zone can have this
return 0;
else
return this;
}
cout<<"Children at this node: ";
for(const auto& c: children) cout <<"'"<<c.first<<"' ";
cout<<endl;
auto iter = children.find(name.back());
cout<<"Looked for child called '"<<name.back()<<"'"<<endl;
if(iter == children.end()) {
cout<<"Found nothing, trying wildcard"<<endl;
iter = children.find("*");
if(iter == children.end()) {
cout<<"Still nothing, returning leaf"<<endl;
return this;
}
else {
cout<<"Had wildcard match, following"<<endl;
}
}
cout<<"Had match, continuing to child '"<<iter->first<<"'"<<endl;
last.push_front(name.back());
name.pop_back();
return iter->second.find(name, last, passedZonecut);
}
DNSNode* DNSNode::add(dnsname name)
{
cout<<"Add for '"<<name<<"'"<<endl;
if(name.size() == 1) {
cout<<"Last label, adding "<<name.front()<<endl;
return &children[name.front()];
}
auto back = name.back();
name.pop_back();
auto iter = children.find(back);
if(iter == children.end()) {
cout<<"Inserting new child for "<<back<<endl;
return children[back].add(name);
}
return iter->second.add(name);
}
struct DNSMessage
{
struct dnsheader dh=dnsheader{};
@ -133,7 +27,7 @@ struct DNSMessage
void putRR(DNSSection section, const dnsname& name, DNSType type, uint32_t ttl, const std::string& rr);
string serialize() const;
}; // __attribute__((packed));
};
dnsname DNSMessage::getName()
{
@ -169,12 +63,18 @@ void putName(auto& payload, const dnsname& name)
void DNSMessage::putRR(DNSSection section, const dnsname& name, DNSType type, uint32_t ttl, const std::string& content)
{
putName(payload, name);
payload.putUInt16((int)type); payload.putUInt16(1);
payload.putUInt32(ttl);
payload.putUInt16(content.size()); // check for overflow!
payload.putBlob(content);
auto cursize = payload.payloadpos;
try {
putName(payload, name);
payload.putUInt16((int)type); payload.putUInt16(1);
payload.putUInt32(ttl);
payload.putUInt16(content.size()); // check for overflow!
payload.putBlob(content);
}
catch(...) {
payload.payloadpos = cursize;
throw;
}
switch(section) {
case DNSSection::Question:
throw runtime_error("Can't add questions to a DNS Message with putRR");
@ -202,7 +102,6 @@ string DNSMessage::serialize() const
return string((const char*)this, (const char*)this + sizeof(dnsheader) + payload.payloadpos);
}
static_assert(sizeof(DNSMessage) == 516, "dnsmessage size must be 516");
std::string serializeDNSName(const dnsname& dn)
@ -216,6 +115,15 @@ std::string serializeDNSName(const dnsname& dn)
return ret;
}
std::string serializeMXRecord(uint16_t prio, const dnsname& mname)
{
SafeArray<256> sa;
sa.putUInt16(prio);
putName(sa, mname);
return sa.serialize();
}
std::string serializeSOARecord(const dnsname& mname, const dnsname& rname, uint32_t serial, uint32_t minimum=3600, uint32_t refresh=10800, uint32_t retry=3600, uint32_t expire=604800)
{
SafeArray<256> sa;
@ -249,13 +157,6 @@ std::string serializeAAAARecord(const std::string& src)
}
dnsname operator+(const dnsname& a, const dnsname& b)
{
dnsname ret=a;
for(const auto& l : b)
ret.push_back(l);
return ret;
}
bool processQuestion(const DNSNode& zones, DNSMessage& dm, const ComboAddress& local, const ComboAddress& remote, DNSMessage& response)
try
@ -263,7 +164,7 @@ try
dnsname name;
DNSType type;
dm.getQuestion(name, type);
cout<<"Received a query from "<<remote.toStringWithPort()<<" for "<<name<<" and type "<<(int)type<<endl;
cout<<"Received a query from "<<remote.toStringWithPort()<<" for "<<name<<" and type "<<type<<endl;
response.dh = dm.dh;
response.dh.ad = 0;
@ -274,9 +175,9 @@ try
response.setQuestion(name, type);
if(type == DNSType::AXFR) {
cout<<"Query was for AXFR or IXFR over UDP, can't do that"<<endl;
response.dh.rcode = (int)RCode::Servfail;
return true;
cout<<"Query was for AXFR or IXFR over UDP, can't do that"<<endl;
response.dh.rcode = (int)RCode::Servfail;
return true;
}
if(dm.dh.opcode != 0) {
@ -306,7 +207,7 @@ try
cout<<"This was a partial match, searchname now "<<searchname<<endl;
for(const auto& rr: node->rrsets) {
cout<<" Have type "<<(int)rr.first<<endl;
cout<<" Have type "<<rr.first<<endl;
}
auto iter = node->rrsets.find(DNSType::NS);
if(iter != node->rrsets.end() && passedZonecut) {
@ -358,6 +259,7 @@ try
}
}
else {
cout<<"No zone matched"<<endl;
response.dh.rcode = (uint8_t)RCode::Refused;
}
return true;
@ -386,17 +288,28 @@ void udpThread(ComboAddress local, Socket* sock, const DNSNode* zones)
DNSMessage response;
if(processQuestion(*zones, dm, local, remote, response)) {
cout<<"Sending response with rcode "<<(RCode)response.dh.rcode <<endl;
SSendto(*sock, response.serialize(), remote);
}
}
}
void writeTCPResponse(int sock, const DNSMessage& response)
{
string ser="00"+response.serialize();
cout<<"Should send a message of "<<ser.size()<<" bytes in response"<<endl;
uint16_t len = htons(ser.length()-2);
ser[0] = *((char*)&len);
ser[1] = *(((char*)&len) + 1);
SWriten(sock, ser);
cout<<"Sent!"<<endl;
}
void tcpClientThread(ComboAddress local, ComboAddress remote, int s, const DNSNode* zones)
{
Socket sock(s);
cout<<"TCP Connection from "<<remote.toStringWithPort()<<endl;
for(;;) {
uint16_t len;
string message = SRead(sock, 2);
@ -429,23 +342,78 @@ void tcpClientThread(ComboAddress local, ComboAddress remote, int s, const DNSNo
dnsname name;
DNSType type;
dm.getQuestion(name, type);
DNSMessage response;
if(type == DNSType::AXFR) {
cout<<"Should do AXFR for "<<name<<endl;
dnsname zone;
auto fnd = zones->find(name, zone);
if(!fnd || !fnd->zone || !name.empty()) {
cout<<" This was not a zone"<<endl;
return;
}
cout<<"Have zone, walking it"<<endl;
response.dh = dm.dh;
response.dh.ad = 0;
response.dh.ra = 0;
response.dh.aa = 0;
response.dh.qr = 1;
response.dh.ancount = response.dh.arcount = response.dh.nscount = 0;
response.setQuestion(zone, type);
auto node = fnd->zone;
// send SOA
response.putRR(DNSSection::Answer, zone, DNSType::SOA, node->rrsets[DNSType::SOA].ttl, node->rrsets[DNSType::SOA].contents[0]);
writeTCPResponse(sock, response);
response.dh.ancount = response.dh.arcount = response.dh.nscount = 0;
response.payload.rewind();
response.setQuestion(zone, type);
// send all other records
node->visit([&response,&sock,&name,&type,&zone](const dnsname& nname, const DNSNode* n) {
cout<<nname<<", types: ";
for(const auto& p : n->rrsets) {
if(p.first == DNSType::SOA)
continue;
for(const auto& rr : p.second.contents) {
retry:
try {
response.putRR(DNSSection::Answer, nname, p.first, p.second.ttl, rr);
}
catch(...) {
writeTCPResponse(sock, response);
response.dh.ancount = response.dh.arcount = response.dh.nscount = 0;
response.payload.rewind();
response.setQuestion(zone, type);
goto retry;
}
}
cout<<p.first<<" ";
}
cout<<endl;
}, zone);
writeTCPResponse(sock, response);
response.dh.ancount = response.dh.arcount = response.dh.nscount = 0;
response.payload.rewind();
response.setQuestion(zone, type);
// send SOA again
response.putRR(DNSSection::Answer, zone, DNSType::SOA, node->rrsets[DNSType::SOA].ttl, node->rrsets[DNSType::SOA].contents[0]);
writeTCPResponse(sock, response);
return;
}
else {
dm.payload.rewind();
DNSMessage response;
if(processQuestion(*zones, dm, local, remote, response)) {
string ser="00"+response.serialize();
cout<<"Should send a message of "<<ser.size()<<" bytes in response"<<endl;
len = htons(ser.length()-2);
ser[0] = *((char*)&len);
ser[1] = *(((char*)&len) + 1);
SWriten(sock, ser);
cout<<"Sent!"<<endl;
writeTCPResponse(sock, response);
}
else
return;
@ -459,19 +427,23 @@ void loadZones(DNSNode& zones)
auto zone = zones.add({"powerdns", "org"});
zone->zone = new DNSNode(); // XXX ICK
zone->zone->rrsets[DNSType::SOA]={{serializeSOARecord({"ns1", "powerdns", "org"}, {"admin", "powerdns", "org"}, 1)}};
zone->zone->rrsets[DNSType::MX]={{serializeMXRecord(25, {"server1", "powerdns", "org"})}};
zone->zone->rrsets[DNSType::A]={{serializeARecord("1.2.3.4")}, 300};
zone->zone->rrsets[DNSType::AAAA]={{serializeAAAARecord("::1"), serializeAAAARecord("2001::1")}, 900};
zone->zone->rrsets[DNSType::AAAA]={{serializeAAAARecord("::1"), serializeAAAARecord("2001::1"), serializeAAAARecord("2a02:a440:b085:1:beee:7bff:fe89:f0fb")}, 900};
zone->zone->rrsets[DNSType::NS]={{serializeDNSName({"ns1", "powerdns", "org"})}, 300};
zone->zone->add({"www"})->rrsets[DNSType::CNAME]={{serializeDNSName({"server1","powerdns","org"})}};
zone->zone->add({"server1"})->rrsets[DNSType::A]={{serializeARecord("213.244.168.210")}};
zone->zone->add({"server1"})->rrsets[DNSType::AAAA]={{serializeAAAARecord("::1")}};
// zone->zone->add({"*"})->rrsets[(dnstype)DNSType::A]={"\x05\x06\x07\x08"};
zone->zone->add({"fra"})->rrsets[DNSType::NS]={{serializeDNSName({"ns1","fra","powerdns","org"})}};
zone->zone->add({"ns1", "fra"})->rrsets[DNSType::A]={{serializeARecord("12.13.14.15")}, 86400};
zone->zone->add({"NS2", "fra"})->rrsets[DNSType::A]={{serializeARecord("12.13.14.16")}, 86400};
}
int main(int argc, char** argv)
@ -491,8 +463,6 @@ int main(int argc, char** argv)
loadZones(zones);
thread udpServer(udpThread, local, &udplistener, &zones);
for(;;) {
ComboAddress remote;