From 515bfca635d18bd269bcc01af93064049d5ca6d1 Mon Sep 17 00:00:00 2001 From: bert hubert Date: Sun, 8 Apr 2018 23:20:11 +0200 Subject: [PATCH] working AXFR, printable enumbs --- tdns/Makefile | 2 +- tdns/dns-storage.cc | 81 ++++++++++++++ tdns/dns-storage.hh | 85 +++++++++++++++ tdns/nenum.hh | 40 +++++++ tdns/tdns.cc | 250 +++++++++++++++++++------------------------- 5 files changed, 317 insertions(+), 141 deletions(-) create mode 100644 tdns/dns-storage.cc create mode 100644 tdns/dns-storage.hh create mode 100644 tdns/nenum.hh diff --git a/tdns/Makefile b/tdns/Makefile index 0591f56..94dcd32 100644 --- a/tdns/Makefile +++ b/tdns/Makefile @@ -12,5 +12,5 @@ clean: -include *.d -tdns: tdns.o ext/simplesocket/comboaddress.o ext/simplesocket/sclasses.o ext/simplesocket/swrappers.o +tdns: tdns.o dns-storage.o ext/simplesocket/comboaddress.o ext/simplesocket/sclasses.o ext/simplesocket/swrappers.o g++ -std=gnu++14 $^ -o $@ -pthread diff --git a/tdns/dns-storage.cc b/tdns/dns-storage.cc new file mode 100644 index 0000000..ccd6dba --- /dev/null +++ b/tdns/dns-storage.cc @@ -0,0 +1,81 @@ +#include "dns-storage.hh" +using namespace std; + +const DNSNode* DNSNode::find(dnsname& name, dnsname& last, bool* passedZonecut) const +{ + cout<<"find for '"<first<<"'"<second.find(name, last, passedZonecut); +} + +DNSNode* DNSNode::add(dnsname name) +{ + cout<<"Add for '"<second.add(name); +} + +dnsname operator+(const dnsname& a, const dnsname& b) +{ + dnsname ret=a; + for(const auto& l : b.d_name) + ret.d_name.push_back(l); + return ret; +} + +void DNSNode::visit(std::function visitor, dnsname name) const +{ + visitor(name, this); + for(const auto& c : children) + c.second.visit(visitor, dnsname{c.first}+name); +} + +// this should perform escaping rules! +std::ostream & operator<<(std::ostream &os, const dnsname& d) +{ + for(const auto& l : d.d_name) { + os< +#include +#include +#include +#include +#include +#include +#include +#include "nenum.hh" + +typedef std::string dnslabel; + +enum class RCode +{ + Noerror = 0, Servfail = 2, Nxdomain = 3, Notimp = 4, Refused = 5 +}; + +SMARTENUMSTART(RCode) +SENUM5(RCode, Noerror, Servfail, Nxdomain, Notimp, Refused) +SMARTENUMEND(RCode) + + +enum class DNSType : uint16_t +{ + A = 1, NS = 2, CNAME = 5, SOA=6, MX=15, AAAA = 28, IXFR = 251, AXFR = 252, ANY = 255 +}; + +SMARTENUMSTART(DNSType) +SENUM11(DNSType, A, NS, CNAME, SOA, MX, AAAA, IXFR, AAAA, IXFR, AXFR, ANY) +SMARTENUMEND(DNSType) + +enum class DNSSection +{ + Question, Answer, Authority, Additional +}; + +SMARTENUMSTART(DNSSection) +SENUM4(DNSSection, Question, Answer, Authority, Additional) +SMARTENUMEND(DNSSection) + +struct dnsname +{ + dnsname() {} + dnsname(std::initializer_list dls) : d_name(dls) {} + void push_back(const dnslabel& l) { d_name.push_back(l); } + auto back() const { return d_name.back(); } + auto begin() const { return d_name.begin(); } + bool empty() const { return d_name.empty(); } + auto end() const { return d_name.end(); } + auto front() const { return d_name.front(); } + void pop_back() { d_name.pop_back(); } + auto push_front(const dnslabel& dn) { return d_name.push_front(dn); } + auto size() { return d_name.size(); } + std::deque d_name; +}; + + +std::ostream & operator<<(std::ostream &os, const dnsname& d); +dnsname operator+(const dnsname& a, const dnsname& b); +struct RRSet +{ + std::vector contents; + uint32_t ttl{3600}; +}; + +struct DNSLabelCompare: public std::binary_function +{ + bool operator()(const dnslabel& a, const dnslabel& b) const + { + return strcasecmp(a.c_str(), b.c_str()) < 0; // XXX locale pain, plus embedded zeros + } +}; + +struct DNSNode +{ + const DNSNode* find(dnsname& name, dnsname& last, bool* passedZonecut=0) const; + DNSNode* add(dnsname name); + std::map children; + std::map rrsets; + void visit(std::function visitor, dnsname name) const; + DNSNode* zone{0}; // if this is set, this node is a zone +}; + + diff --git a/tdns/nenum.hh b/tdns/nenum.hh new file mode 100644 index 0000000..4d87137 --- /dev/null +++ b/tdns/nenum.hh @@ -0,0 +1,40 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#define SMARTENUMSTART(x) static constexpr std::pair enumtypemap##x[]= { +#define SENUM(x,a1) { x::a1, #a1}, +#define SENUM2(x, a1, ...) SENUM(x,a1) SENUM(x, __VA_ARGS__) +#define SENUM3(x, a1, ...) SENUM(x,a1) SENUM2(x, __VA_ARGS__) +#define SENUM4(x, a1, ...) SENUM(x,a1) SENUM3(x, __VA_ARGS__) +#define SENUM5(x, a1, ...) SENUM(x,a1) SENUM4(x, __VA_ARGS__) +#define SENUM6(x, a1, ...) SENUM(x,a1) SENUM5(x, __VA_ARGS__) +#define SENUM7(x, a1, ...) SENUM(x,a1) SENUM6(x, __VA_ARGS__) +#define SENUM8(x, a1, ...) SENUM(x,a1) SENUM7(x, __VA_ARGS__) +#define SENUM9(x, a1, ...) SENUM(x,a1) SENUM8(x, __VA_ARGS__) +#define SENUM10(x, a1, ...) SENUM(x,a1) SENUM9(x, __VA_ARGS__) +#define SENUM11(x, a1, ...) SENUM(x,a1) SENUM10(x, __VA_ARGS__) + +#define SMARTENUMEND(x) }; \ +inline const char* toString(const x& t) \ +{ \ + for(const auto &a : enumtypemap##x) \ + if(a.first == t) \ + return a.second; \ + return "?"; \ +} \ +inline x make##x(const char* from) { \ +for(const auto& a : enumtypemap##x) \ + if(!strcmp(a.second, from)) \ + return a.first; \ + throw std::runtime_error("Unknown value '" + std::string(from) + "' for enum "#x); \ + } \ +inline std::ostream& operator<<(std::ostream &os, const x& s) { \ + os << toString(s); return os; } \ + + diff --git a/tdns/tdns.cc b/tdns/tdns.cc index f4a71f4..1c00720 100644 --- a/tdns/tdns.cc +++ b/tdns/tdns.cc @@ -12,116 +12,10 @@ #include "dns.hh" #include "safearray.hh" #include +#include "dns-storage.hh" using namespace std; -typedef std::string dnslabel; - -enum class RCode -{ - Noerror = 0, Servfail = 2, Nxdomain = 3, Notimp = 4, Refused = 5 -}; - -enum class DNSType : uint16_t -{ - A = 1, NS = 2, CNAME = 5, SOA=6, AAAA = 28, IXFR = 251, AXFR = 252, ANY = 255 -}; - -enum class DNSSection -{ - Question, Answer, Authority, Additional -}; - -typedef deque dnsname; -// this should perform escaping rules! -static std::ostream & operator<<(std::ostream &os, const dnsname& d) -{ - for(const auto& l : d) { - os< contents; - uint32_t ttl{3600}; -}; - -struct CIStringCompare: public std::binary_function -{ - bool operator()(const string& a, const string& b) const - { - return strcasecmp(a.c_str(), b.c_str()) < 0; // XXX locale pain, plus embedded zeros - } -}; - - -struct DNSNode -{ - const DNSNode* find(dnsname& name, dnsname& last, bool* passedZonecut=0) const; - DNSNode* add(dnsname name); - map children; - map rrsets; - - DNSNode* zone{0}; // if this is set, this node is a zone -}; - -const DNSNode* DNSNode::find(dnsname& name, dnsname& last, bool* passedZonecut) const -{ - cout<<"find for '"<first<<"'"<second.find(name, last, passedZonecut); -} - -DNSNode* DNSNode::add(dnsname name) -{ - cout<<"Add for '"<second.add(name); -} - struct DNSMessage { struct dnsheader dh=dnsheader{}; @@ -133,7 +27,7 @@ struct DNSMessage void putRR(DNSSection section, const dnsname& name, DNSType type, uint32_t ttl, const std::string& rr); string serialize() const; -}; // __attribute__((packed)); +}; dnsname DNSMessage::getName() { @@ -169,12 +63,18 @@ void putName(auto& payload, const dnsname& name) void DNSMessage::putRR(DNSSection section, const dnsname& name, DNSType type, uint32_t ttl, const std::string& content) { - putName(payload, name); - payload.putUInt16((int)type); payload.putUInt16(1); - payload.putUInt32(ttl); - payload.putUInt16(content.size()); // check for overflow! - payload.putBlob(content); - + auto cursize = payload.payloadpos; + try { + putName(payload, name); + payload.putUInt16((int)type); payload.putUInt16(1); + payload.putUInt32(ttl); + payload.putUInt16(content.size()); // check for overflow! + payload.putBlob(content); + } + catch(...) { + payload.payloadpos = cursize; + throw; + } switch(section) { case DNSSection::Question: throw runtime_error("Can't add questions to a DNS Message with putRR"); @@ -202,7 +102,6 @@ string DNSMessage::serialize() const return string((const char*)this, (const char*)this + sizeof(dnsheader) + payload.payloadpos); } - static_assert(sizeof(DNSMessage) == 516, "dnsmessage size must be 516"); std::string serializeDNSName(const dnsname& dn) @@ -216,6 +115,15 @@ std::string serializeDNSName(const dnsname& dn) return ret; } +std::string serializeMXRecord(uint16_t prio, const dnsname& mname) +{ + SafeArray<256> sa; + sa.putUInt16(prio); + putName(sa, mname); + return sa.serialize(); +} + + std::string serializeSOARecord(const dnsname& mname, const dnsname& rname, uint32_t serial, uint32_t minimum=3600, uint32_t refresh=10800, uint32_t retry=3600, uint32_t expire=604800) { SafeArray<256> sa; @@ -249,13 +157,6 @@ std::string serializeAAAARecord(const std::string& src) } -dnsname operator+(const dnsname& a, const dnsname& b) -{ - dnsname ret=a; - for(const auto& l : b) - ret.push_back(l); - return ret; -} bool processQuestion(const DNSNode& zones, DNSMessage& dm, const ComboAddress& local, const ComboAddress& remote, DNSMessage& response) try @@ -263,7 +164,7 @@ try dnsname name; DNSType type; dm.getQuestion(name, type); - cout<<"Received a query from "<rrsets) { - cout<<" Have type "<<(int)rr.first<rrsets.find(DNSType::NS); if(iter != node->rrsets.end() && passedZonecut) { @@ -358,6 +259,7 @@ try } } else { + cout<<"No zone matched"<find(name, zone); + if(!fnd || !fnd->zone || !name.empty()) { + cout<<" This was not a zone"<zone; + + // send SOA + response.putRR(DNSSection::Answer, zone, DNSType::SOA, node->rrsets[DNSType::SOA].ttl, node->rrsets[DNSType::SOA].contents[0]); + + writeTCPResponse(sock, response); + response.dh.ancount = response.dh.arcount = response.dh.nscount = 0; + response.payload.rewind(); + response.setQuestion(zone, type); + + + // send all other records + node->visit([&response,&sock,&name,&type,&zone](const dnsname& nname, const DNSNode* n) { + cout<rrsets) { + if(p.first == DNSType::SOA) + continue; + for(const auto& rr : p.second.contents) { + retry: + try { + response.putRR(DNSSection::Answer, nname, p.first, p.second.ttl, rr); + } + catch(...) { + writeTCPResponse(sock, response); + response.dh.ancount = response.dh.arcount = response.dh.nscount = 0; + response.payload.rewind(); + response.setQuestion(zone, type); + goto retry; + } + } + cout<rrsets[DNSType::SOA].ttl, node->rrsets[DNSType::SOA].contents[0]); + + writeTCPResponse(sock, response); + return; } else { dm.payload.rewind(); - DNSMessage response; if(processQuestion(*zones, dm, local, remote, response)) { - string ser="00"+response.serialize(); - cout<<"Should send a message of "<zone = new DNSNode(); // XXX ICK zone->zone->rrsets[DNSType::SOA]={{serializeSOARecord({"ns1", "powerdns", "org"}, {"admin", "powerdns", "org"}, 1)}}; + zone->zone->rrsets[DNSType::MX]={{serializeMXRecord(25, {"server1", "powerdns", "org"})}}; + zone->zone->rrsets[DNSType::A]={{serializeARecord("1.2.3.4")}, 300}; - zone->zone->rrsets[DNSType::AAAA]={{serializeAAAARecord("::1"), serializeAAAARecord("2001::1")}, 900}; + zone->zone->rrsets[DNSType::AAAA]={{serializeAAAARecord("::1"), serializeAAAARecord("2001::1"), serializeAAAARecord("2a02:a440:b085:1:beee:7bff:fe89:f0fb")}, 900}; zone->zone->rrsets[DNSType::NS]={{serializeDNSName({"ns1", "powerdns", "org"})}, 300}; zone->zone->add({"www"})->rrsets[DNSType::CNAME]={{serializeDNSName({"server1","powerdns","org"})}}; zone->zone->add({"server1"})->rrsets[DNSType::A]={{serializeARecord("213.244.168.210")}}; + zone->zone->add({"server1"})->rrsets[DNSType::AAAA]={{serializeAAAARecord("::1")}}; // zone->zone->add({"*"})->rrsets[(dnstype)DNSType::A]={"\x05\x06\x07\x08"}; zone->zone->add({"fra"})->rrsets[DNSType::NS]={{serializeDNSName({"ns1","fra","powerdns","org"})}}; zone->zone->add({"ns1", "fra"})->rrsets[DNSType::A]={{serializeARecord("12.13.14.15")}, 86400}; + zone->zone->add({"NS2", "fra"})->rrsets[DNSType::A]={{serializeARecord("12.13.14.16")}, 86400}; } int main(int argc, char** argv) @@ -491,8 +463,6 @@ int main(int argc, char** argv) loadZones(zones); thread udpServer(udpThread, local, &udplistener, &zones); - - for(;;) { ComboAddress remote;