clean up EDNS situation and astoundingly weird endianness bug

This commit is contained in:
bert hubert 2018-04-13 15:23:00 +02:00
parent d9b33de683
commit 0398d80a87
5 changed files with 76 additions and 66 deletions

View File

@ -57,10 +57,10 @@ Known broken:
* ~~RCode after one CNAME chase~~
* ~~On output (to screen) we do not escape DNS names correctly~~
* TCP/IP does not follow recommended timeouts
* EDNS is a bit clunky and should move into DNSMessageWriter
* ~~EDNS is a bit clunky and should move into DNSMessageWriter~~
The code is not quite in a teachable state and still contains ugly bits. But
well worth [a
The code is not quite in a teachable state yet and still contains ugly bits.
But well worth [a
read](https://github.com/ahupowerdns/hello-dns/tree/master/tdns).
# Layout
@ -694,9 +694,8 @@ such a larger buffer size, a packet may exceed the available space. In that
case, the standard tells us to truncate the packet, and then still put an
EDNS record in the response.
This is why the code is currently littered with 'if(haveEDNS)' in a number
of places. This will be moved into `DNSMessageWriter` soon.
The DNSMessageWriter, in somewhat of a layering violation, takes care of
this in `serialize()`.
# Internals
`tdns` uses several small pieces of code not core to dns:

View File

@ -1,7 +1,7 @@
#pragma once
struct dnsheader {
unsigned id :16; /* query identification number */
#if BYTE_ORDER == BIG_ENDIAN
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
/* fields in third byte */
unsigned qr: 1; /* response flag */
unsigned opcode: 4; /* purpose of message */
@ -14,7 +14,7 @@ struct dnsheader {
unsigned ad: 1; /* authentic data from named */
unsigned cd: 1; /* checking disabled by resolver */
unsigned rcode :4; /* response code */
#elif BYTE_ORDER == LITTLE_ENDIAN || BYTE_ORDER == PDP_ENDIAN
#elif __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
/* fields in third byte */
unsigned rd :1; /* recursion desired */
unsigned tc :1; /* truncated message */

View File

@ -21,6 +21,7 @@ DNSMessageReader::DNSMessageReader(const char* in, uint16_t size)
d_doBit = flags & 0x80;
payload.getUInt8(); payload.getUInt16(); // ignore rest
cout<<" There was an EDNS section, size supported: "<< d_bufsize<<endl;
d_haveEDNS = true;
}
}
}
@ -96,26 +97,50 @@ void DNSMessageWriter::putEDNS(uint16_t bufsize, bool doBit)
payloadpos = cursize;
throw;
}
dh.nscount = htons(ntohs(dh.nscount)+1);
dh.arcount = htons(ntohs(dh.arcount)+1);
}
void DNSMessageWriter::setQuestion(const DNSName& name, DNSType type)
DNSMessageWriter::DNSMessageWriter(const DNSName& name, DNSType type, int maxsize) : d_qname(name), d_qtype(type)
{
dh.ancount = dh.arcount = dh.nscount = 0;
memset(&dh, 0, sizeof(dh));
payload.resize(maxsize);
clearRRs();
}
void DNSMessageWriter::clearRRs()
{
dh.qdcount = htons(1) ; dh.ancount = dh.arcount = dh.nscount = 0;
payloadpos=0;
putName(name);
putUInt16((uint16_t)type);
putName(d_qname);
putUInt16((uint16_t)d_qtype);
putUInt16(1); // class
}
string DNSMessageReader::serialize() const
{
return string((const char*)this, (const char*)this + sizeof(dnsheader) + payload.payloadpos);
}
string DNSMessageWriter::serialize() const
{
std::string ret((const char*)this, (const char*)this + sizeof(dnsheader));
ret.append((const unsigned char*)&payload[0], (const unsigned char*)&payload[payloadpos]);
DNSMessageWriter act = *this;
try {
if(haveEDNS) {
cout<<"Adding EDNS to DNS Message"<<endl;
act.putEDNS(payload.size() + sizeof(dnsheader), d_doBit);
}
}
catch(std::out_of_range& e) {
cout<<"Got truncated while adding EDNS! Truncating"<<endl;
act.clearRRs();
act.dh.tc = 1; act.dh.aa = 0;
act.putEDNS(payload.size() + sizeof(dnsheader), d_doBit);
}
std::string ret((const char*)&act.dh, ((const char*)&act.dh) + sizeof(dnsheader));
ret.append((const unsigned char*)&act.payload.at(0), (const unsigned char*)&act.payload.at(act.payloadpos));
return ret;
}
void DNSMessageWriter::setEDNS(uint16_t newsize, bool doBit)
{
cout<<"Setting new buffer size "<<newsize<<" for writer"<<endl;
if(newsize > sizeof(dnsheader))
payload.resize(newsize - sizeof(dnsheader));
d_doBit = doBit;
haveEDNS=true;
}

View File

@ -14,7 +14,6 @@ public:
void getQuestion(DNSName& name, DNSType& type) const;
bool getEDNS(uint16_t* newsize, bool* doBit) const;
std::string serialize() const;
private:
DNSName getName();
@ -26,20 +25,23 @@ private:
bool d_haveEDNS;
};
struct DNSMessageWriter
class DNSMessageWriter
{
public:
struct dnsheader dh=dnsheader{};
std::vector<uint8_t> payload;
uint16_t payloadpos=0;
DNSName d_qname;
DNSType d_qtype;
DNSClass d_qclass;
bool haveEDNS{false};
bool d_doBit;
explicit DNSMessageWriter(int maxsize=500)
{
payload.resize(maxsize);
}
void setQuestion(const DNSName& name, DNSType type);
DNSMessageWriter(const DNSName& name, DNSType type, int maxsize=500);
void clearRRs();
void putRR(DNSSection section, const DNSName& name, DNSType type, uint32_t ttl, const std::unique_ptr<RRGen>& rr);
void putEDNS(uint16_t bufsize, bool doBit);
void setEDNS(uint16_t bufsize, bool doBit);
std::string serialize() const;
void putUInt8(uint8_t val)
@ -87,5 +89,8 @@ struct DNSMessageWriter
}
putUInt8(0);
}
private:
void putEDNS(uint16_t bufsize, bool doBit);
};

View File

@ -47,39 +47,25 @@ bool processQuestion(const DNSNode& zones, DNSMessageReader& dm, const ComboAddr
DNSName qname;
DNSType qtype;
dm.getQuestion(qname, qtype);
DNSName origname=qname; // we need this for error reporting, we munch the original name
bool haveEDNS=false;
cout<<"Received a query from "<<remote.toStringWithPort()<<" for "<<qname<<" and type "<<qtype<<endl;
uint16_t newsize=0;
bool doBit=false;
haveEDNS = dm.getEDNS(&newsize, &doBit);
if(haveEDNS && newsize > sizeof(dnsheader))
response.payload.resize(newsize - sizeof(dnsheader));
try {
response.dh = dm.dh;
response.dh.id = dm.dh.id;
response.dh.ad = response.dh.ra = response.dh.aa = 0;
response.dh.qr = 1;
response.setQuestion(qname, qtype);
if(qtype == DNSType::AXFR || qtype == DNSType::IXFR) {
cout<<"Query was for AXFR or IXFR over UDP, can't do that"<<endl;
response.dh.rcode = (int)RCode::Servfail;
if(haveEDNS) {
response.putEDNS(newsize, doBit);
}
return true;
}
if(dm.dh.opcode != 0) {
cout<<"Query had non-zero opcode "<<dm.dh.opcode<<", sending NOTIMP"<<endl;
response.dh.rcode = (int)RCode::Notimp;
if(haveEDNS) {
response.putEDNS(newsize, doBit);
}
return true;
}
@ -88,9 +74,6 @@ bool processQuestion(const DNSNode& zones, DNSMessageReader& dm, const ComboAddr
if(!fnd || !fnd->zone) {
cout<<"No zone matched"<<endl;
response.dh.rcode = (uint8_t)RCode::Refused;
if(haveEDNS) {
response.putEDNS(newsize, doBit);
}
return true;
}
@ -175,19 +158,12 @@ bool processQuestion(const DNSNode& zones, DNSMessageReader& dm, const ComboAddr
}
addAdditional(bestzone, zonename, additional, response);
}
if(haveEDNS) {
response.putEDNS(newsize, doBit);
}
return true;
}
catch(std::out_of_range& e) { // exceeded packet size
cout<<"Query for '"<<origname<<"'|"<<qtype<<" got truncated"<<endl;
response.setQuestion(origname, qtype); // this resets the packet
response.dh.tc=1; response.dh.aa=0;
if(haveEDNS) {
response.putEDNS(newsize, doBit);
}
response.clearRRs();
response.dh.aa = 0; response.dh.tc = 1;
return true;
}
catch(std::exception& e) {
@ -208,11 +184,19 @@ void udpThread(ComboAddress local, Socket* sock, const DNSNode* zones)
cerr<<"Dropping non-query from "<<remote.toStringWithPort()<<endl;
continue;
}
DNSName qname;
DNSType qtype;
dm.getQuestion(qname, qtype);
DNSMessageWriter response(qname, qtype);
uint16_t newsize; bool doBit;
if(dm.getEDNS(&newsize, &doBit))
response.setEDNS(newsize, doBit);
DNSMessageWriter response;
if(processQuestion(*zones, dm, local, remote, response)) {
cout<<"Sending response with rcode "<<(RCode)response.dh.rcode <<endl;
SSendto(*sock, response.serialize(), remote);
string ret = response.serialize();
SSendto(*sock, ret, remote);
}
}
}
@ -261,15 +245,14 @@ void tcpClientThread(ComboAddress local, ComboAddress remote, int s, const DNSNo
DNSName name;
DNSType type;
dm.getQuestion(name, type);
DNSMessageWriter response(std::numeric_limits<uint16_t>::max()-sizeof(dnsheader));
DNSMessageWriter response(name, type, std::numeric_limits<uint16_t>::max());
if(type == DNSType::AXFR) {
cout<<"Should do AXFR for "<<name<<endl;
response.dh = dm.dh;
response.dh.id = dm.dh.id;
response.dh.ad = response.dh.ra = response.dh.aa = 0;
response.dh.qr = 1;
response.setQuestion(name, type);
DNSName zone;
auto fnd = zones->find(name, zone);
@ -287,7 +270,7 @@ void tcpClientThread(ComboAddress local, ComboAddress remote, int s, const DNSNo
response.putRR(DNSSection::Answer, zone, DNSType::SOA, node->rrsets[DNSType::SOA].ttl, node->rrsets[DNSType::SOA].contents[0]);
writeTCPResponse(sock, response);
response.setQuestion(zone, type);
response.clearRRs();
// send all other records
node->visit([&response,&sock,&name,&type,&zone](const DNSName& nname, const DNSNode* n) {
@ -301,7 +284,7 @@ void tcpClientThread(ComboAddress local, ComboAddress remote, int s, const DNSNo
}
catch(std::out_of_range& e) { // exceeded packet size
writeTCPResponse(sock, response);
response.setQuestion(zone, type);
response.clearRRs();
goto retry;
}
}
@ -309,7 +292,7 @@ void tcpClientThread(ComboAddress local, ComboAddress remote, int s, const DNSNo
}, zone);
writeTCPResponse(sock, response);
response.setQuestion(zone, type);
response.clearRRs();
// send SOA again
response.putRR(DNSSection::Answer, zone, DNSType::SOA, node->rrsets[DNSType::SOA].ttl, node->rrsets[DNSType::SOA].contents[0]);
@ -318,8 +301,6 @@ void tcpClientThread(ComboAddress local, ComboAddress remote, int s, const DNSNo
return;
}
else {
dm.payload.rewind();
if(processQuestion(*zones, dm, local, remote, response)) {
writeTCPResponse(sock, response);
}