further split up

This commit is contained in:
bert hubert 2018-04-09 16:43:24 +02:00
parent 97402d1071
commit 8694ac87eb
6 changed files with 121 additions and 97 deletions

View File

@ -12,5 +12,5 @@ clean:
-include *.d
tdns: tdns.o dns-storage.o ext/simplesocket/comboaddress.o ext/simplesocket/sclasses.o ext/simplesocket/swrappers.o
tdns: tdns.o dns-storage.o dnsmessages.o ext/simplesocket/comboaddress.o ext/simplesocket/sclasses.o ext/simplesocket/swrappers.o
g++ -std=gnu++14 $^ -o $@ -pthread

View File

@ -1,17 +1,16 @@
#pragma once
#include "dns-storage.hh"
struct DNSPacketWriter;
#include "dnsmessages.hh"
struct RRGenerator
{
virtual void toPacket(DNSPacketWriter& dpw) = 0;
virtual void toPacket(DNSMessageWriter& dpw) = 0;
};
struct AGenerator : RRGenerator
{
std::unique_ptr<RRGenerator> make(ComboAddress);
std::unique_ptr<RRGenerator> make(std::string);
void toPacket(DNSPacketWriter& dpw) override;
void toPacket(DNSMessageWriter& dpw) override;
uint32_t d_ip;
};

73
tdns/dnsmessages.cc Normal file
View File

@ -0,0 +1,73 @@
#include "dnsmessages.hh"
using namespace std;
dnsname DNSMessageReader::getName()
{
dnsname name;
for(;;) {
uint8_t labellen=payload.getUInt8();
if(labellen > 63)
throw std::runtime_error("Got a compressed label");
if(!labellen) // end of dnsname
break;
dnslabel label = payload.getBlob(labellen);
name.push_back(label);
}
return name;
}
void DNSMessageReader::getQuestion(dnsname& name, DNSType& type)
{
name=getName();
type=(DNSType)payload.getUInt16();
}
void DNSMessageWriter::putRR(DNSSection section, const dnsname& name, DNSType type, uint32_t ttl, const std::string& content)
{
auto cursize = payload.payloadpos;
try {
putName(payload, name);
payload.putUInt16((int)type); payload.putUInt16(1);
payload.putUInt32(ttl);
payload.putUInt16(content.size()); // check for overflow!
payload.putBlob(content);
}
catch(...) {
payload.payloadpos = cursize;
throw;
}
switch(section) {
case DNSSection::Question:
throw runtime_error("Can't add questions to a DNS Message with putRR");
case DNSSection::Answer:
dh.ancount = htons(ntohs(dh.ancount) + 1);
break;
case DNSSection::Authority:
dh.nscount = htons(ntohs(dh.nscount) + 1);
break;
case DNSSection::Additional:
dh.arcount = htons(ntohs(dh.arcount) + 1);
break;
}
}
void DNSMessageWriter::setQuestion(const dnsname& name, DNSType type)
{
payload.rewind();
putName(payload, name);
payload.putUInt16((uint16_t)type);
payload.putUInt16(1); // class
}
string DNSMessageReader::serialize() const
{
return string((const char*)this, (const char*)this + sizeof(dnsheader) + payload.payloadpos);
}
string DNSMessageWriter::serialize() const
{
return string((const char*)this, (const char*)this + sizeof(dnsheader) + payload.payloadpos);
}
static_assert(sizeof(DNSMessageReader) == 516, "dnsmessagereader size must be 516");

35
tdns/dnsmessages.hh Normal file
View File

@ -0,0 +1,35 @@
#include "dns.hh"
#include "safearray.hh"
#include "dns-storage.hh"
struct DNSMessageReader
{
struct dnsheader dh=dnsheader{};
SafeArray<500> payload;
dnsname getName();
void getQuestion(dnsname& name, DNSType& type);
std::string serialize() const;
};
struct DNSMessageWriter
{
struct dnsheader dh=dnsheader{};
SafeArray<1500> payload;
void setQuestion(const dnsname& name, DNSType type);
void putRR(DNSSection section, const dnsname& name, DNSType type, uint32_t ttl, const std::string& rr);
std::string serialize() const;
};
void putName(auto& payload, const dnsname& name)
{
for(const auto& l : name) {
if(l.size() > 63)
throw std::runtime_error("Can't emit a label larger than 63 characters");
payload.putUInt8(l.size());
payload.putBlob(l);
}
payload.putUInt8(0);
}

View File

@ -1,6 +1,9 @@
#pragma once
#include <string>
#include <cstdint>
#include <array>
#include <arpa/inet.h>
#include <string.h>
template<int N>
struct SafeArray

View File

@ -18,93 +18,7 @@
using namespace std;
struct DNSMessage
{
struct dnsheader dh=dnsheader{};
SafeArray<500> payload;
dnsname getName();
void getQuestion(dnsname& name, DNSType& type);
void setQuestion(const dnsname& name, DNSType type);
void putRR(DNSSection section, const dnsname& name, DNSType type, uint32_t ttl, const std::string& rr);
string serialize() const;
};
dnsname DNSMessage::getName()
{
dnsname name;
for(;;) {
uint8_t labellen=payload.getUInt8();
if(labellen > 63)
throw std::runtime_error("Got a compressed label");
if(!labellen) // end of dnsname
break;
dnslabel label = payload.getBlob(labellen);
name.push_back(label);
}
return name;
}
void DNSMessage::getQuestion(dnsname& name, DNSType& type)
{
name=getName();
type=(DNSType)payload.getUInt16();
}
void putName(auto& payload, const dnsname& name)
{
for(const auto& l : name) {
if(l.size() > 63)
throw std::runtime_error("Can't emit a label larger than 63 characters");
payload.putUInt8(l.size());
payload.putBlob(l);
}
payload.putUInt8(0);
}
void DNSMessage::putRR(DNSSection section, const dnsname& name, DNSType type, uint32_t ttl, const std::string& content)
{
auto cursize = payload.payloadpos;
try {
putName(payload, name);
payload.putUInt16((int)type); payload.putUInt16(1);
payload.putUInt32(ttl);
payload.putUInt16(content.size()); // check for overflow!
payload.putBlob(content);
}
catch(...) {
payload.payloadpos = cursize;
throw;
}
switch(section) {
case DNSSection::Question:
throw runtime_error("Can't add questions to a DNS Message with putRR");
case DNSSection::Answer:
dh.ancount = htons(ntohs(dh.ancount) + 1);
break;
case DNSSection::Authority:
dh.nscount = htons(ntohs(dh.nscount) + 1);
break;
case DNSSection::Additional:
dh.arcount = htons(ntohs(dh.arcount) + 1);
break;
}
}
void DNSMessage::setQuestion(const dnsname& name, DNSType type)
{
putName(payload, name);
payload.putUInt16((uint16_t)type);
payload.putUInt16(1); // class
}
string DNSMessage::serialize() const
{
return string((const char*)this, (const char*)this + sizeof(dnsheader) + payload.payloadpos);
}
static_assert(sizeof(DNSMessage) == 516, "dnsmessage size must be 516");
std::string serializeDNSName(const dnsname& dn)
{
@ -155,7 +69,7 @@ std::string serializeAAAARecord(const std::string& src)
return std::string(p, p+16);
}
bool processQuestion(const DNSNode& zones, DNSMessage& dm, const ComboAddress& local, const ComboAddress& remote, DNSMessage& response)
bool processQuestion(const DNSNode& zones, DNSMessageReader& dm, const ComboAddress& local, const ComboAddress& remote, DNSMessageWriter& response)
try
{
dnsname name;
@ -281,7 +195,7 @@ void udpThread(ComboAddress local, Socket* sock, const DNSNode* zones)
{
for(;;) {
ComboAddress remote(local);
DNSMessage dm;
DNSMessageReader dm;
string message = SRecvfrom(*sock, sizeof(dm), remote);
if(message.size() < sizeof(dnsheader)) {
cerr<<"Dropping query from "<<remote.toStringWithPort()<<", too short"<<endl;
@ -294,7 +208,7 @@ void udpThread(ComboAddress local, Socket* sock, const DNSNode* zones)
continue;
}
DNSMessage response;
DNSMessageWriter response;
if(processQuestion(*zones, dm, local, remote, response)) {
cout<<"Sending response with rcode "<<(RCode)response.dh.rcode <<endl;
SSendto(*sock, response.serialize(), remote);
@ -302,7 +216,7 @@ void udpThread(ComboAddress local, Socket* sock, const DNSNode* zones)
}
}
void writeTCPResponse(int sock, const DNSMessage& response)
void writeTCPResponse(int sock, const DNSMessageWriter& response)
{
string ser="00"+response.serialize();
cout<<"Should send a message of "<<ser.size()<<" bytes in response"<<endl;
@ -339,7 +253,7 @@ void tcpClientThread(ComboAddress local, ComboAddress remote, int s, const DNSNo
cout<<"Reading "<<len<<" bytes"<<endl;
message = SRead(sock, len);
DNSMessage dm;
DNSMessageReader dm;
memcpy(&dm, message.c_str(), message.size());
if(dm.dh.qr) {
@ -350,7 +264,7 @@ void tcpClientThread(ComboAddress local, ComboAddress remote, int s, const DNSNo
dnsname name;
DNSType type;
dm.getQuestion(name, type);
DNSMessage response;
DNSMessageWriter response;
if(type == DNSType::AXFR) {
cout<<"Should do AXFR for "<<name<<endl;