further split up
This commit is contained in:
parent
97402d1071
commit
8694ac87eb
@ -12,5 +12,5 @@ clean:
|
||||
|
||||
-include *.d
|
||||
|
||||
tdns: tdns.o dns-storage.o ext/simplesocket/comboaddress.o ext/simplesocket/sclasses.o ext/simplesocket/swrappers.o
|
||||
tdns: tdns.o dns-storage.o dnsmessages.o ext/simplesocket/comboaddress.o ext/simplesocket/sclasses.o ext/simplesocket/swrappers.o
|
||||
g++ -std=gnu++14 $^ -o $@ -pthread
|
||||
|
@ -1,17 +1,16 @@
|
||||
#pragma once
|
||||
#include "dns-storage.hh"
|
||||
|
||||
struct DNSPacketWriter;
|
||||
#include "dnsmessages.hh"
|
||||
|
||||
struct RRGenerator
|
||||
{
|
||||
virtual void toPacket(DNSPacketWriter& dpw) = 0;
|
||||
virtual void toPacket(DNSMessageWriter& dpw) = 0;
|
||||
};
|
||||
|
||||
struct AGenerator : RRGenerator
|
||||
{
|
||||
std::unique_ptr<RRGenerator> make(ComboAddress);
|
||||
std::unique_ptr<RRGenerator> make(std::string);
|
||||
void toPacket(DNSPacketWriter& dpw) override;
|
||||
void toPacket(DNSMessageWriter& dpw) override;
|
||||
uint32_t d_ip;
|
||||
};
|
||||
|
73
tdns/dnsmessages.cc
Normal file
73
tdns/dnsmessages.cc
Normal file
@ -0,0 +1,73 @@
|
||||
#include "dnsmessages.hh"
|
||||
|
||||
using namespace std;
|
||||
|
||||
dnsname DNSMessageReader::getName()
|
||||
{
|
||||
dnsname name;
|
||||
for(;;) {
|
||||
uint8_t labellen=payload.getUInt8();
|
||||
if(labellen > 63)
|
||||
throw std::runtime_error("Got a compressed label");
|
||||
if(!labellen) // end of dnsname
|
||||
break;
|
||||
dnslabel label = payload.getBlob(labellen);
|
||||
name.push_back(label);
|
||||
}
|
||||
return name;
|
||||
}
|
||||
|
||||
void DNSMessageReader::getQuestion(dnsname& name, DNSType& type)
|
||||
{
|
||||
name=getName();
|
||||
type=(DNSType)payload.getUInt16();
|
||||
}
|
||||
|
||||
|
||||
void DNSMessageWriter::putRR(DNSSection section, const dnsname& name, DNSType type, uint32_t ttl, const std::string& content)
|
||||
{
|
||||
auto cursize = payload.payloadpos;
|
||||
try {
|
||||
putName(payload, name);
|
||||
payload.putUInt16((int)type); payload.putUInt16(1);
|
||||
payload.putUInt32(ttl);
|
||||
payload.putUInt16(content.size()); // check for overflow!
|
||||
payload.putBlob(content);
|
||||
}
|
||||
catch(...) {
|
||||
payload.payloadpos = cursize;
|
||||
throw;
|
||||
}
|
||||
switch(section) {
|
||||
case DNSSection::Question:
|
||||
throw runtime_error("Can't add questions to a DNS Message with putRR");
|
||||
case DNSSection::Answer:
|
||||
dh.ancount = htons(ntohs(dh.ancount) + 1);
|
||||
break;
|
||||
case DNSSection::Authority:
|
||||
dh.nscount = htons(ntohs(dh.nscount) + 1);
|
||||
break;
|
||||
case DNSSection::Additional:
|
||||
dh.arcount = htons(ntohs(dh.arcount) + 1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void DNSMessageWriter::setQuestion(const dnsname& name, DNSType type)
|
||||
{
|
||||
payload.rewind();
|
||||
putName(payload, name);
|
||||
payload.putUInt16((uint16_t)type);
|
||||
payload.putUInt16(1); // class
|
||||
}
|
||||
|
||||
string DNSMessageReader::serialize() const
|
||||
{
|
||||
return string((const char*)this, (const char*)this + sizeof(dnsheader) + payload.payloadpos);
|
||||
}
|
||||
string DNSMessageWriter::serialize() const
|
||||
{
|
||||
return string((const char*)this, (const char*)this + sizeof(dnsheader) + payload.payloadpos);
|
||||
}
|
||||
|
||||
static_assert(sizeof(DNSMessageReader) == 516, "dnsmessagereader size must be 516");
|
35
tdns/dnsmessages.hh
Normal file
35
tdns/dnsmessages.hh
Normal file
@ -0,0 +1,35 @@
|
||||
#include "dns.hh"
|
||||
#include "safearray.hh"
|
||||
#include "dns-storage.hh"
|
||||
|
||||
struct DNSMessageReader
|
||||
{
|
||||
struct dnsheader dh=dnsheader{};
|
||||
SafeArray<500> payload;
|
||||
|
||||
dnsname getName();
|
||||
void getQuestion(dnsname& name, DNSType& type);
|
||||
|
||||
std::string serialize() const;
|
||||
};
|
||||
|
||||
struct DNSMessageWriter
|
||||
{
|
||||
struct dnsheader dh=dnsheader{};
|
||||
SafeArray<1500> payload;
|
||||
void setQuestion(const dnsname& name, DNSType type);
|
||||
void putRR(DNSSection section, const dnsname& name, DNSType type, uint32_t ttl, const std::string& rr);
|
||||
std::string serialize() const;
|
||||
};
|
||||
|
||||
void putName(auto& payload, const dnsname& name)
|
||||
{
|
||||
for(const auto& l : name) {
|
||||
if(l.size() > 63)
|
||||
throw std::runtime_error("Can't emit a label larger than 63 characters");
|
||||
payload.putUInt8(l.size());
|
||||
payload.putBlob(l);
|
||||
}
|
||||
payload.putUInt8(0);
|
||||
}
|
||||
|
@ -1,6 +1,9 @@
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include <cstdint>
|
||||
#include <array>
|
||||
#include <arpa/inet.h>
|
||||
#include <string.h>
|
||||
|
||||
template<int N>
|
||||
struct SafeArray
|
||||
|
98
tdns/tdns.cc
98
tdns/tdns.cc
@ -18,93 +18,7 @@
|
||||
|
||||
using namespace std;
|
||||
|
||||
struct DNSMessage
|
||||
{
|
||||
struct dnsheader dh=dnsheader{};
|
||||
SafeArray<500> payload;
|
||||
|
||||
dnsname getName();
|
||||
void getQuestion(dnsname& name, DNSType& type);
|
||||
void setQuestion(const dnsname& name, DNSType type);
|
||||
void putRR(DNSSection section, const dnsname& name, DNSType type, uint32_t ttl, const std::string& rr);
|
||||
|
||||
string serialize() const;
|
||||
};
|
||||
|
||||
dnsname DNSMessage::getName()
|
||||
{
|
||||
dnsname name;
|
||||
for(;;) {
|
||||
uint8_t labellen=payload.getUInt8();
|
||||
if(labellen > 63)
|
||||
throw std::runtime_error("Got a compressed label");
|
||||
if(!labellen) // end of dnsname
|
||||
break;
|
||||
dnslabel label = payload.getBlob(labellen);
|
||||
name.push_back(label);
|
||||
}
|
||||
return name;
|
||||
}
|
||||
|
||||
void DNSMessage::getQuestion(dnsname& name, DNSType& type)
|
||||
{
|
||||
name=getName();
|
||||
type=(DNSType)payload.getUInt16();
|
||||
}
|
||||
|
||||
void putName(auto& payload, const dnsname& name)
|
||||
{
|
||||
for(const auto& l : name) {
|
||||
if(l.size() > 63)
|
||||
throw std::runtime_error("Can't emit a label larger than 63 characters");
|
||||
payload.putUInt8(l.size());
|
||||
payload.putBlob(l);
|
||||
}
|
||||
payload.putUInt8(0);
|
||||
}
|
||||
|
||||
void DNSMessage::putRR(DNSSection section, const dnsname& name, DNSType type, uint32_t ttl, const std::string& content)
|
||||
{
|
||||
auto cursize = payload.payloadpos;
|
||||
try {
|
||||
putName(payload, name);
|
||||
payload.putUInt16((int)type); payload.putUInt16(1);
|
||||
payload.putUInt32(ttl);
|
||||
payload.putUInt16(content.size()); // check for overflow!
|
||||
payload.putBlob(content);
|
||||
}
|
||||
catch(...) {
|
||||
payload.payloadpos = cursize;
|
||||
throw;
|
||||
}
|
||||
switch(section) {
|
||||
case DNSSection::Question:
|
||||
throw runtime_error("Can't add questions to a DNS Message with putRR");
|
||||
case DNSSection::Answer:
|
||||
dh.ancount = htons(ntohs(dh.ancount) + 1);
|
||||
break;
|
||||
case DNSSection::Authority:
|
||||
dh.nscount = htons(ntohs(dh.nscount) + 1);
|
||||
break;
|
||||
case DNSSection::Additional:
|
||||
dh.arcount = htons(ntohs(dh.arcount) + 1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void DNSMessage::setQuestion(const dnsname& name, DNSType type)
|
||||
{
|
||||
putName(payload, name);
|
||||
payload.putUInt16((uint16_t)type);
|
||||
payload.putUInt16(1); // class
|
||||
}
|
||||
|
||||
string DNSMessage::serialize() const
|
||||
{
|
||||
return string((const char*)this, (const char*)this + sizeof(dnsheader) + payload.payloadpos);
|
||||
}
|
||||
|
||||
static_assert(sizeof(DNSMessage) == 516, "dnsmessage size must be 516");
|
||||
|
||||
std::string serializeDNSName(const dnsname& dn)
|
||||
{
|
||||
@ -155,7 +69,7 @@ std::string serializeAAAARecord(const std::string& src)
|
||||
return std::string(p, p+16);
|
||||
}
|
||||
|
||||
bool processQuestion(const DNSNode& zones, DNSMessage& dm, const ComboAddress& local, const ComboAddress& remote, DNSMessage& response)
|
||||
bool processQuestion(const DNSNode& zones, DNSMessageReader& dm, const ComboAddress& local, const ComboAddress& remote, DNSMessageWriter& response)
|
||||
try
|
||||
{
|
||||
dnsname name;
|
||||
@ -281,7 +195,7 @@ void udpThread(ComboAddress local, Socket* sock, const DNSNode* zones)
|
||||
{
|
||||
for(;;) {
|
||||
ComboAddress remote(local);
|
||||
DNSMessage dm;
|
||||
DNSMessageReader dm;
|
||||
string message = SRecvfrom(*sock, sizeof(dm), remote);
|
||||
if(message.size() < sizeof(dnsheader)) {
|
||||
cerr<<"Dropping query from "<<remote.toStringWithPort()<<", too short"<<endl;
|
||||
@ -294,7 +208,7 @@ void udpThread(ComboAddress local, Socket* sock, const DNSNode* zones)
|
||||
continue;
|
||||
}
|
||||
|
||||
DNSMessage response;
|
||||
DNSMessageWriter response;
|
||||
if(processQuestion(*zones, dm, local, remote, response)) {
|
||||
cout<<"Sending response with rcode "<<(RCode)response.dh.rcode <<endl;
|
||||
SSendto(*sock, response.serialize(), remote);
|
||||
@ -302,7 +216,7 @@ void udpThread(ComboAddress local, Socket* sock, const DNSNode* zones)
|
||||
}
|
||||
}
|
||||
|
||||
void writeTCPResponse(int sock, const DNSMessage& response)
|
||||
void writeTCPResponse(int sock, const DNSMessageWriter& response)
|
||||
{
|
||||
string ser="00"+response.serialize();
|
||||
cout<<"Should send a message of "<<ser.size()<<" bytes in response"<<endl;
|
||||
@ -339,7 +253,7 @@ void tcpClientThread(ComboAddress local, ComboAddress remote, int s, const DNSNo
|
||||
cout<<"Reading "<<len<<" bytes"<<endl;
|
||||
|
||||
message = SRead(sock, len);
|
||||
DNSMessage dm;
|
||||
DNSMessageReader dm;
|
||||
memcpy(&dm, message.c_str(), message.size());
|
||||
|
||||
if(dm.dh.qr) {
|
||||
@ -350,7 +264,7 @@ void tcpClientThread(ComboAddress local, ComboAddress remote, int s, const DNSNo
|
||||
dnsname name;
|
||||
DNSType type;
|
||||
dm.getQuestion(name, type);
|
||||
DNSMessage response;
|
||||
DNSMessageWriter response;
|
||||
|
||||
if(type == DNSType::AXFR) {
|
||||
cout<<"Should do AXFR for "<<name<<endl;
|
||||
|
Loading…
Reference in New Issue
Block a user