make DNSMessageWriter variable length

This commit is contained in:
bert hubert 2018-04-11 22:12:57 +02:00
parent 62b880018c
commit 294e72b261
7 changed files with 215 additions and 204 deletions

View File

@ -4,7 +4,7 @@
# teaching DNS
Welcome to tdns, the teaching authoritative server, implementing all of
basic DNS in 1000 lines of code.
basic DNS in ~~1000~~ 1100 lines of code.
The goals of tdns are:
@ -25,9 +25,9 @@ Features are complete:
* Wildcards
* Delegations
* Glue records
* Truncation
Missing:
* Truncation
* Compression (may not fit in the 1000 lines!)
* EDNS (not 'basic' DNS by our definition, but ok)

View File

@ -7,14 +7,15 @@ void loadZones(DNSNode& zones)
auto newzone = zone->zone = new DNSNode(); // XXX ICK
newzone->addRRs(SOAGen::make({"ns1", "powerdns", "org"}, {"admin", "powerdns", "org"}, 1));
newzone->rrsets[DNSType::MX].add(MXGen::make(25, {"server1", "powerdns", "org"}));
newzone->addRRs(MXGen::make(25, {"server1", "powerdns", "org"}));
newzone->rrsets[DNSType::A].add(AGen::make("1.2.3.4"));
newzone->rrsets[DNSType::AAAA].add(AAAAGen::make("::1"));
newzone->addRRs(AGen::make("1.2.3.4"));
newzone->addRRs(AAAAGen::make("::1"));
newzone->rrsets[DNSType::AAAA].ttl= 900;
newzone->rrsets[DNSType::NS].add(NSGen::make({"ns1", "powerdns", "org"}));
newzone->addRRs(TXTGen::make("Proudly served by tdns " __DATE__ " " __TIME__));
newzone->addRRs(NSGen::make({"ns1", "powerdns", "org"}), NSGen::make({"ns2", "powerdns", "org"}));
newzone->addRRs(TXTGen::make("Proudly served by tdns compiled on " __DATE__ " " __TIME__),
TXTGen::make("This is some more filler to make this packet exceed 512 bytes"));
newzone->add({"www"})->rrsets[DNSType::CNAME].add(CNAMEGen::make({"server1","powerdns","org"}));
newzone->add({"www2"})->rrsets[DNSType::CNAME].add(CNAMEGen::make({"nosuchserver1","powerdns","org"}));
@ -25,7 +26,7 @@ void loadZones(DNSNode& zones)
newzone->add({"*", "fr"})->rrsets[DNSType::CNAME].add(CNAMEGen::make({"server2", "powerdns", "org"}));
newzone->add({"fra"})->addRRs(NSGen::make({"ns1","fra","powerdns","org"}), NSGen::make({"ns1","fra","powerdns","org"}));
newzone->add({"ns1"})->addRRs(AGen::make("212.13.14.15"));
newzone->add({"ns1", "fra"})->addRRs(AGen::make("12.13.14.15"));
newzone->add({"NS2", "fra"})->addRRs(AGen::make("12.13.14.16"), AAAAGen::make("::1"));
newzone->add({"something"})->addRRs(AAAAGen::make("::1"), AGen::make("12.13.14.15"));

View File

@ -7,7 +7,7 @@ std::unique_ptr<RRGen> AGen::make(const ComboAddress& ca)
void AGen::toMessage(DNSMessageWriter& dmw)
{
dmw.payload.putUInt32(d_ip);
dmw.putUInt32(d_ip);
}
std::unique_ptr<RRGen> AAAAGen::make(const ComboAddress& ca)
@ -23,39 +23,39 @@ std::unique_ptr<RRGen> AAAAGen::make(const ComboAddress& ca)
void AAAAGen::toMessage(DNSMessageWriter& dmw)
{
dmw.payload.putBlob(d_ip, 16);
dmw.putBlob(d_ip, 16);
}
void SOAGen::toMessage(DNSMessageWriter& dmw)
{
putName(dmw.payload, d_mname); putName(dmw.payload, d_rname);
dmw.payload.putUInt32(d_serial); dmw.payload.putUInt32(d_refresh);
dmw.payload.putUInt32(d_retry); dmw.payload.putUInt32(d_expire);
dmw.payload.putUInt32(d_minimum);
dmw.putName(d_mname); dmw.putName(d_rname);
dmw.putUInt32(d_serial); dmw.putUInt32(d_refresh);
dmw.putUInt32(d_retry); dmw.putUInt32(d_expire);
dmw.putUInt32(d_minimum);
}
void CNAMEGen::toMessage(DNSMessageWriter& dmw)
{
putName(dmw.payload, d_name);
dmw.putName(d_name);
}
void NSGen::toMessage(DNSMessageWriter& dmw)
{
putName(dmw.payload, d_name);
dmw.putName(d_name);
}
void MXGen::toMessage(DNSMessageWriter& dmw)
{
dmw.payload.putUInt16(d_prio);
putName(dmw.payload, d_name);
dmw.putUInt16(d_prio);
dmw.putName(d_name);
}
void TXTGen::toMessage(DNSMessageWriter& dmw)
{
// XXX should autosplit
dmw.payload.putUInt8(d_txt.length());
dmw.payload.putBlob(d_txt);
dmw.putUInt8(d_txt.length());
dmw.putBlob(d_txt);
}
void ClockTXTGen::toMessage(DNSMessageWriter& dmw)
@ -70,6 +70,6 @@ void ClockTXTGen::toMessage(DNSMessageWriter& dmw)
else
txt="Overflow";
// XXX should autosplit
dmw.payload.putUInt8(txt.length());
dmw.payload.putBlob(txt);
dmw.putUInt8(txt.length());
dmw.putBlob(txt);
}

View File

@ -27,17 +27,17 @@ void DNSMessageReader::getQuestion(dnsname& name, DNSType& type)
void DNSMessageWriter::putRR(DNSSection section, const dnsname& name, DNSType type, uint32_t ttl, const std::unique_ptr<RRGen>& content)
{
auto cursize = payload.payloadpos;
auto cursize = payloadpos;
try {
putName(payload, name);
payload.putUInt16((int)type); payload.putUInt16(1);
payload.putUInt32(ttl);
auto pos = payload.putUInt16(0); // placeholder
putName(name);
putUInt16((int)type); putUInt16(1);
putUInt32(ttl);
auto pos = putUInt16(0); // placeholder
content->toMessage(*this);
payload.putUInt16At(pos, payload.payloadpos-pos-2);
putUInt16At(pos, payloadpos-pos-2);
}
catch(...) {
payload.payloadpos = cursize;
payloadpos = cursize;
throw;
}
switch(section) {
@ -58,10 +58,10 @@ void DNSMessageWriter::putRR(DNSSection section, const dnsname& name, DNSType ty
void DNSMessageWriter::setQuestion(const dnsname& name, DNSType type)
{
dh.ancount = dh.arcount = dh.nscount = 0;
payload.rewind();
putName(payload, name);
payload.putUInt16((uint16_t)type);
payload.putUInt16(1); // class
payloadpos=0;
putName(name);
putUInt16((uint16_t)type);
putUInt16(1); // class
}
string DNSMessageReader::serialize() const
@ -70,7 +70,9 @@ string DNSMessageReader::serialize() const
}
string DNSMessageWriter::serialize() const
{
return string((const char*)this, (const char*)this + sizeof(dnsheader) + payload.payloadpos);
std::string ret((const char*)this, (const char*)this + sizeof(dnsheader));
ret.append((const unsigned char*)&payload[0], (const unsigned char*)&payload[payloadpos]);
return ret;
}
static_assert(sizeof(DNSMessageReader) == 516, "dnsmessagereader size must be 516");

View File

@ -2,6 +2,7 @@
#include "dns.hh"
#include "safearray.hh"
#include "dns-storage.hh"
#include <vector>
struct DNSMessageReader
{
@ -16,20 +17,63 @@ struct DNSMessageReader
struct DNSMessageWriter
{
explicit DNSMessageWriter(int maxsize=512)
{
payload.resize(maxsize);
}
struct dnsheader dh=dnsheader{};
SafeArray<1500> payload;
std::vector<uint8_t> payload;
void setQuestion(const dnsname& name, DNSType type);
void putRR(DNSSection section, const dnsname& name, DNSType type, uint32_t ttl, const std::unique_ptr<RRGen>& rr);
std::string serialize() const;
uint16_t payloadpos=0;
void putUInt8(uint8_t val)
{
payload.at(payloadpos++)=val;
}
uint16_t putUInt16(uint16_t val)
{
val = htons(val);
memcpy(&payload.at(payloadpos+2)-2, &val, 2);
payloadpos+=2;
return payloadpos - 2;
}
void putUInt16At(uint16_t pos, uint16_t val)
{
val = htons(val);
memcpy(&payload.at(pos+2)-2, &val, 2);
}
void putUInt32(uint32_t val)
{
val = htonl(val);
memcpy(&payload.at(payloadpos+sizeof(val)) - sizeof(val), &val, sizeof(val));
payloadpos += sizeof(val);
}
void putBlob(const std::string& blob)
{
memcpy(&payload.at(payloadpos+blob.size()) - blob.size(), blob.c_str(), blob.size());
payloadpos += blob.size();;
}
void putBlob(const unsigned char* blob, int size)
{
memcpy(&payload.at(payloadpos+size) - size, blob, size);
payloadpos += size;
}
void putName(const dnsname& name)
{
for(const auto& l : name) {
putUInt8(l.size());
putBlob(l.d_s);
}
putUInt8(0);
}
};
inline void putName(SafeArray<1500>& payload, const dnsname& name)
{
for(const auto& l : name) {
if(l.size() > 63)
throw std::runtime_error("Can't emit a label larger than 63 characters");
payload.putUInt8(l.size());
payload.putBlob(l.d_s);
}
payload.putUInt8(0);
}

View File

@ -27,46 +27,8 @@ struct SafeArray
memcpy(&ret, &payload.at(payloadpos+2)-2, 2);
payloadpos+=2;
return htons(ret);
}
void putUInt8(uint8_t val)
{
payload.at(payloadpos++)=val;
}
uint16_t putUInt16(uint16_t val)
{
val = htons(val);
memcpy(&payload.at(payloadpos+2)-2, &val, 2);
payloadpos+=2;
return payloadpos - 2;
}
void putUInt16At(uint16_t pos, uint16_t val)
{
val = htons(val);
memcpy(&payload.at(pos+2)-2, &val, 2);
}
void putUInt32(uint32_t val)
{
val = htonl(val);
memcpy(&payload.at(payloadpos+sizeof(val)) - sizeof(val), &val, sizeof(val));
payloadpos += sizeof(val);
}
void putBlob(const std::string& blob)
{
memcpy(&payload.at(payloadpos+blob.size()) - blob.size(), blob.c_str(), blob.size());
payloadpos += blob.size();;
}
void putBlob(const unsigned char* blob, int size)
{
memcpy(&payload.at(payloadpos+size) - size, blob, size);
payloadpos += size;
}
std::string getBlob(int size)
{

View File

@ -19,13 +19,12 @@ using namespace std;
void addAdditional(const DNSNode* bestzone, const dnsname& zone, const vector<dnsname>& toresolve, DNSMessageWriter& response)
{
for(auto addname : toresolve ) {
cout<<"Doing additional or glue lookup for "<<addname<<endl;
cout<<"Doing additional or glue lookup for "<<addname<<" in "<<zone<<endl;
if(!addname.makeRelative(zone)) {
cout<<addname<<" is not within our zone, not doing glue"<<endl;
continue;
}
dnsname wuh;
cout<<"Looking up glue record "<<addname<<" in zone "<<zone<<endl;
auto addnode = bestzone->find(addname, wuh);
if(!addnode || !addname.empty()) {
cout<<" Found nothing, continuing"<<endl;
@ -44,136 +43,144 @@ void addAdditional(const DNSNode* bestzone, const dnsname& zone, const vector<dn
}
bool processQuestion(const DNSNode& zones, DNSMessageReader& dm, const ComboAddress& local, const ComboAddress& remote, DNSMessageWriter& response)
try
{
dnsname name;
dnsname name, origname;
DNSType type;
dm.getQuestion(name, type);
origname=name; // we munch on this below
cout<<"Received a query from "<<remote.toStringWithPort()<<" for "<<name<<" and type "<<type<<endl;
response.dh = dm.dh;
response.dh.ad = response.dh.ra = response.dh.aa = 0;
response.dh.qr = 1;
response.setQuestion(name, type);
if(type == DNSType::AXFR || type == DNSType::IXFR) {
cout<<"Query was for AXFR or IXFR over UDP, can't do that"<<endl;
response.dh.rcode = (int)RCode::Servfail;
return true;
}
if(dm.dh.opcode != 0) {
cout<<"Query had non-zero opcode "<<dm.dh.opcode<<", sending NOTIMP"<<endl;
response.dh.rcode = (int)RCode::Notimp;
return true;
}
try {
response.dh = dm.dh;
response.dh.ad = response.dh.ra = response.dh.aa = 0;
response.dh.qr = 1;
response.setQuestion(name, type);
dnsname zone;
auto fnd = zones.find(name, zone);
if(fnd && fnd->zone) {
cout<<"---\nBest zone: "<<zone<<", name now "<<name<<", loaded: "<<(void*)fnd->zone<<endl;
response.dh.aa = 1;
auto bestzone = fnd->zone;
dnsname searchname(name), lastnode, zonecutname;
const DNSNode* passedZonecut=0;
int CNAMELoopCount = 0;
loopCNAME:;
lastnode.clear();
zonecutname.clear();
auto node = bestzone->find(searchname, lastnode, &passedZonecut, &zonecutname);
if(!node) {
cout<<"Found nothing in zone '"<<zone<<"' for lhs '"<<name<<"'"<<endl;
if(type == DNSType::AXFR || type == DNSType::IXFR) {
cout<<"Query was for AXFR or IXFR over UDP, can't do that"<<endl;
response.dh.rcode = (int)RCode::Servfail;
return true;
}
else if(passedZonecut) {
response.dh.aa = false;
cout<<"This is a delegation, zonecutname: '"<<zonecutname<<"'"<<endl;
for(const auto& rr: passedZonecut->rrsets) {
cout<<" Have type "<<rr.first<<endl;
if(dm.dh.opcode != 0) {
cout<<"Query had non-zero opcode "<<dm.dh.opcode<<", sending NOTIMP"<<endl;
response.dh.rcode = (int)RCode::Notimp;
return true;
}
dnsname zone;
auto fnd = zones.find(name, zone);
if(fnd && fnd->zone) {
cout<<"---\nBest zone: "<<zone<<", name now "<<name<<", loaded: "<<(void*)fnd->zone<<endl;
response.dh.aa = 1;
auto bestzone = fnd->zone;
dnsname searchname(name), lastnode, zonecutname;
const DNSNode* passedZonecut=0;
int CNAMELoopCount = 0;
loopCNAME:;
auto node = bestzone->find(searchname, lastnode, &passedZonecut, &zonecutname);
if(!node) {
cout<<"Found nothing in zone '"<<zone<<"' for lhs '"<<name<<"'"<<endl;
}
auto iter = passedZonecut->rrsets.find(DNSType::NS);
if(iter != passedZonecut->rrsets.end()) {
const auto& rrset = iter->second;
vector<dnsname> toresolve;
for(const auto& rr : rrset.contents) {
response.putRR(DNSSection::Authority, zonecutname+zone, DNSType::NS, rrset.ttl, rr);
toresolve.push_back(dynamic_cast<NSGen*>(rr.get())->d_name);
else if(passedZonecut) {
response.dh.aa = false;
cout<<"This is a delegation, zonecutname: '"<<zonecutname<<"'"<<endl;
for(const auto& rr: passedZonecut->rrsets) {
cout<<" Have type "<<rr.first<<endl;
}
addAdditional(bestzone, zone, toresolve, response);
}
}
else if(!searchname.empty()) {
cout<<"This is an NXDOMAIN situation"<<endl;
const auto& rrset = fnd->zone->rrsets[DNSType::SOA];
response.dh.rcode = (int)RCode::Nxdomain;
response.putRR(DNSSection::Authority, zone, DNSType::SOA, rrset.ttl, rrset.contents[0]);
}
else {
cout<<"Found something in zone '"<<zone<<"' for lhs '"<<name<<"', searchname now '"<<searchname<<"', lastnode '"<<lastnode<<"', passedZonecut="<<passedZonecut<<endl;
auto iter = node->rrsets.cbegin();
vector<dnsname> additional;
if(type == DNSType::ANY) {
for(const auto& t : node->rrsets) {
const auto& rrset = t.second;
auto iter = passedZonecut->rrsets.find(DNSType::NS);
if(iter != passedZonecut->rrsets.end()) {
const auto& rrset = iter->second;
vector<dnsname> toresolve;
for(const auto& rr : rrset.contents) {
response.putRR(DNSSection::Answer, lastnode+zone, t.first, rrset.ttl, rr);
if(t.first == DNSType::MX)
additional.push_back(dynamic_cast<MXGen*>(rr.get())->d_name);
response.putRR(DNSSection::Authority, zonecutname+zone, DNSType::NS, rrset.ttl, rr);
toresolve.push_back(dynamic_cast<NSGen*>(rr.get())->d_name);
}
addAdditional(bestzone, zone, toresolve, response);
}
}
else if(iter = node->rrsets.find(type), iter != node->rrsets.end()) {
const auto& rrset = iter->second;
for(const auto& rr : rrset.contents) {
response.putRR(DNSSection::Answer, lastnode+zone, type, rrset.ttl, rr);
if(type == DNSType::MX)
additional.push_back(dynamic_cast<MXGen*>(rr.get())->d_name);
}
}
else if(iter = node->rrsets.find(DNSType::CNAME), iter != node->rrsets.end()) {
cout<<"We do have a CNAME!"<<endl;
const auto& rrset = iter->second;
dnsname target;
for(const auto& rr : rrset.contents) {
response.putRR(DNSSection::Answer, lastnode+zone, DNSType::CNAME, rrset.ttl, rr);
target=dynamic_cast<CNAMEGen*>(rr.get())->d_name;
}
if(target.makeRelative(zone)) {
cout<<" Should follow CNAME to "<<target<<" within our zone"<<endl;
// XXX we need to change our behaviour on NXDOMAIN I think depending on if you've followed a CNAME
searchname = target;
if(CNAMELoopCount++ < 10)
goto loopCNAME;
}
else
cout<<" CNAME points to record "<<target<<" in other zone, good luck"<<endl;
else if(!searchname.empty()) {
cout<<"This is an NXDOMAIN situation"<<endl;
const auto& rrset = fnd->zone->rrsets[DNSType::SOA];
response.dh.rcode = (int)RCode::Nxdomain;
response.putRR(DNSSection::Authority, zone, DNSType::SOA, rrset.ttl, rrset.contents[0]);
}
else {
cout<<"Node exists, qtype doesn't, NOERROR situation, inserting SOA"<<endl;
const auto& rrset = fnd->zone->rrsets[DNSType::SOA];
response.putRR(DNSSection::Answer, zone, DNSType::SOA, rrset.ttl, rrset.contents[0]);
}
addAdditional(bestzone, zone, additional, response);
cout<<"Found something in zone '"<<zone<<"' for lhs '"<<name<<"', searchname now '"<<searchname<<"', lastnode '"<<lastnode<<"', passedZonecut="<<passedZonecut<<endl;
auto iter = node->rrsets.cbegin();
vector<dnsname> additional;
if(type == DNSType::ANY) {
for(const auto& t : node->rrsets) {
const auto& rrset = t.second;
for(const auto& rr : rrset.contents) {
response.putRR(DNSSection::Answer, lastnode+zone, t.first, rrset.ttl, rr);
if(t.first == DNSType::MX)
additional.push_back(dynamic_cast<MXGen*>(rr.get())->d_name);
}
}
}
else if(iter = node->rrsets.find(type), iter != node->rrsets.end()) {
const auto& rrset = iter->second;
for(const auto& rr : rrset.contents) {
response.putRR(DNSSection::Answer, lastnode+zone, type, rrset.ttl, rr);
if(type == DNSType::MX)
additional.push_back(dynamic_cast<MXGen*>(rr.get())->d_name);
}
}
else if(iter = node->rrsets.find(DNSType::CNAME), iter != node->rrsets.end()) {
cout<<"We do have a CNAME!"<<endl;
const auto& rrset = iter->second;
dnsname target;
for(const auto& rr : rrset.contents) {
response.putRR(DNSSection::Answer, lastnode+zone, DNSType::CNAME, rrset.ttl, rr);
target=dynamic_cast<CNAMEGen*>(rr.get())->d_name;
}
if(target.makeRelative(zone)) {
cout<<" Should follow CNAME to "<<target<<" within our zone"<<endl;
// XXX we need to change our behaviour on NXDOMAIN I think depending on if you've followed a CNAME
searchname = target;
if(CNAMELoopCount++ < 10) {
lastnode.clear();
zonecutname.clear();
goto loopCNAME;
}
}
else
cout<<" CNAME points to record "<<target<<" in other zone, good luck"<<endl;
}
else {
cout<<"Node exists, qtype doesn't, NOERROR situation, inserting SOA"<<endl;
const auto& rrset = fnd->zone->rrsets[DNSType::SOA];
response.putRR(DNSSection::Answer, zone, DNSType::SOA, rrset.ttl, rrset.contents[0]);
}
addAdditional(bestzone, zone, additional, response);
}
}
else {
cout<<"No zone matched"<<endl;
response.dh.rcode = (uint8_t)RCode::Refused;
}
return true;
}
else {
cout<<"No zone matched"<<endl;
response.dh.rcode = (uint8_t)RCode::Refused;
catch(std::out_of_range& e) { // exceeded packet size
cout<<"Query for '"<<origname<<"'|"<<type<<" got truncated"<<endl;
response.setQuestion(origname, type); // this resets the packet
response.dh.tc=1; response.dh.aa=0;
return true;
}
catch(std::exception& e) {
cout<<"Error processing query: "<<e.what()<<endl;
return false;
}
return true;
}
catch(std::exception& e) {
cout<<"Error processing query: "<<e.what()<<endl;
return false;
}
void udpThread(ComboAddress local, Socket* sock, const DNSNode* zones)
@ -246,7 +253,7 @@ void tcpClientThread(ComboAddress local, ComboAddress remote, int s, const DNSNo
dnsname name;
DNSType type;
dm.getQuestion(name, type);
DNSMessageWriter response;
DNSMessageWriter response(std::numeric_limits<uint16_t>::max()-sizeof(dnsheader));
if(type == DNSType::AXFR) {
cout<<"Should do AXFR for "<<name<<endl;
@ -281,7 +288,7 @@ void tcpClientThread(ComboAddress local, ComboAddress remote, int s, const DNSNo
try {
response.putRR(DNSSection::Answer, nname, p.first, p.second.ttl, rr);
}
catch(...) { // exceeded packet size
catch(std::out_of_range& e) { // exceeded packet size
writeTCPResponse(sock, response);
response.setQuestion(zone, type);
goto retry;
@ -311,14 +318,9 @@ void tcpClientThread(ComboAddress local, ComboAddress remote, int s, const DNSNo
}
}
int main(int argc, char** argv)
try
{
cout<<sizeof(AGen)<<endl;
cout<<sizeof(MXGen)<<endl;
cout<<sizeof(RRGen)<<endl;
if(argc != 2) {
cerr<<"Syntax: tdns ipaddress:port"<<endl;
return(EXIT_FAILURE);