hello-dns/tdns/tdns.cc

479 lines
15 KiB
C++
Raw Normal View History

2018-04-01 23:31:41 +07:00
/* Goal: a fully standards compliant basic authoritative server. In <500 lines.
Non-goals: notifications, slaving zones, name compression, edns,
performance
*/
#include <cstdint>
#include <string>
#include <vector>
#include <deque>
#include <map>
#include <stdexcept>
#include "sclasses.hh"
#include "dns.hh"
2018-04-02 18:25:19 +07:00
#include "safearray.hh"
2018-04-03 18:18:37 +07:00
#include <thread>
2018-04-09 17:54:27 +07:00
#include <signal.h>
#include "dns-types.hh"
2018-04-09 04:20:11 +07:00
#include "dns-storage.hh"
2018-04-01 23:31:41 +07:00
using namespace std;
struct DNSMessage
{
struct dnsheader dh=dnsheader{};
2018-04-02 18:25:19 +07:00
SafeArray<500> payload;
2018-04-01 23:31:41 +07:00
dnsname getName();
2018-04-03 18:18:37 +07:00
void getQuestion(dnsname& name, DNSType& type);
void setQuestion(const dnsname& name, DNSType type);
void putRR(DNSSection section, const dnsname& name, DNSType type, uint32_t ttl, const std::string& rr);
string serialize() const;
2018-04-09 04:20:11 +07:00
};
2018-04-01 23:31:41 +07:00
// Reads length-prefixed labels from the payload until the zero-length
// terminator and returns them as a dnsname.
// Throws std::runtime_error when a length byte exceeds 63; such bytes are
// treated as compressed labels, which this server does not support.
dnsname DNSMessage::getName()
{
  dnsname result;
  uint8_t len;
  while((len = payload.getUInt8()) != 0) {
    if(len > 63)
      throw std::runtime_error("Got a compressed label");
    dnslabel piece = payload.getBlob(len);
    result.push_back(piece);
  }
  return result;
}
2018-04-03 18:18:37 +07:00
// Reads the question from the payload: first the name, then the 16-bit type.
// The class value that setQuestion() writes after the type is not read here.
void DNSMessage::getQuestion(dnsname& name, DNSType& type)
{
  name = getName();
  const auto rawType = payload.getUInt16();
  type = (DNSType)rawType;
}
2018-04-03 18:18:37 +07:00
void putName(auto& payload, const dnsname& name)
2018-04-01 23:31:41 +07:00
{
for(const auto& l : name) {
2018-04-03 18:18:37 +07:00
if(l.size() > 63)
throw std::runtime_error("Can't emit a label larger than 63 characters");
2018-04-02 18:25:19 +07:00
payload.putUInt8(l.size());
payload.putBlob(l);
2018-04-01 23:31:41 +07:00
}
2018-04-02 18:25:19 +07:00
payload.putUInt8(0);
2018-04-01 23:31:41 +07:00
}
2018-04-03 18:18:37 +07:00
// Appends one resource record to the payload (name, type, class 1, ttl,
// 16-bit content length, content) and increments the counter of `section`
// in the header.
//
// Throws std::runtime_error when `section` is the Question section or when
// `content` is too long for the 16-bit length field. Both are checked before
// anything is written, so a rejected record leaves the message untouched.
// If writing the record itself throws (for example because the payload is
// full), the payload position is restored and the exception is rethrown.
void DNSMessage::putRR(DNSSection section, const dnsname& name, DNSType type, uint32_t ttl, const std::string& content)
{
  if(section == DNSSection::Question)
    throw runtime_error("Can't add questions to a DNS Message with putRR");
  if(content.size() > 0xFFFF)
    throw runtime_error("Record content does not fit in a 16 bit length field");

  const auto cursize = payload.payloadpos;
  try {
    putName(payload, name);
    payload.putUInt16((int)type);
    payload.putUInt16(1); // class
    payload.putUInt32(ttl);
    payload.putUInt16(content.size());
    payload.putBlob(content);
  }
  catch(...) {
    // undo the partially written record
    payload.payloadpos = cursize;
    throw;
  }

  // header counters are kept in network byte order
  switch(section) {
  case DNSSection::Question:
    break; // rejected above
  case DNSSection::Answer:
    dh.ancount = htons(ntohs(dh.ancount) + 1);
    break;
  case DNSSection::Authority:
    dh.nscount = htons(ntohs(dh.nscount) + 1);
    break;
  case DNSSection::Additional:
    dh.arcount = htons(ntohs(dh.arcount) + 1);
    break;
  }
}
2018-04-03 18:18:37 +07:00
void DNSMessage::setQuestion(const dnsname& name, DNSType type)
2018-04-01 23:31:41 +07:00
{
2018-04-03 18:18:37 +07:00
putName(payload, name);
payload.putUInt16((uint16_t)type);
2018-04-02 18:25:19 +07:00
payload.putUInt16(1); // class
2018-04-01 23:31:41 +07:00
}
// Returns the message as a byte string: sizeof(dnsheader) plus
// payload.payloadpos bytes copied straight from this object's memory.
string DNSMessage::serialize() const
{
  const char* first = (const char*)this;
  const char* last = first + sizeof(dnsheader) + payload.payloadpos;
  return string(first, last);
}
2018-04-03 18:18:37 +07:00
2018-04-01 23:31:41 +07:00
// DNSMessage is copied to and from raw bytes (serialize() reads from `this`,
// the receive paths memcpy into it), so its size is pinned at compile time.
static_assert(sizeof(DNSMessage) == 516, "dnsmessage size must be 516");
2018-04-02 18:25:19 +07:00
// Returns `dn` as a byte string of length-prefixed labels followed by a
// terminating zero byte, suitable as record content (used for NS and CNAME
// records in loadZones()).
// Throws std::runtime_error for a label longer than 63 bytes, matching
// putName(): a larger length byte would be read back as a compressed label,
// and lengths above 255 would not even fit in the single length byte.
std::string serializeDNSName(const dnsname& dn)
{
  std::string ret;
  for(const auto & l : dn) {
    if(l.size() > 63)
      throw std::runtime_error("Can't emit a label larger than 63 characters");
    ret.append(1, (char)l.size());
    ret+=l;
  }
  ret.append(1, (char)0);
  return ret;
}
2018-04-09 04:20:11 +07:00
std::string serializeMXRecord(uint16_t prio, const dnsname& mname)
{
SafeArray<256> sa;
sa.putUInt16(prio);
putName(sa, mname);
return sa.serialize();
}
2018-04-03 18:18:37 +07:00
std::string serializeSOARecord(const dnsname& mname, const dnsname& rname, uint32_t serial, uint32_t minimum=3600, uint32_t refresh=10800, uint32_t retry=3600, uint32_t expire=604800)
{
SafeArray<256> sa;
2018-04-09 17:54:27 +07:00
putName(sa, mname); putName(sa, rname);
sa.putUInt32(serial); sa.putUInt32(refresh);
sa.putUInt32(retry); sa.putUInt32(expire);
2018-04-03 18:18:37 +07:00
sa.putUInt32(minimum);
return sa.serialize();
}
std::string serializeARecord(const std::string& src)
{
ComboAddress ca(src);
if(ca.sin4.sin_family != AF_INET)
throw std::runtime_error("Could not convert '"+src+"' to an IPv4 address");
auto p = (const char*)&ca.sin4.sin_addr.s_addr;
return std::string(p, p+4);
}
std::string serializeAAAARecord(const std::string& src)
{
ComboAddress ca(src);
if(ca.sin4.sin_family != AF_INET6)
throw std::runtime_error("Could not convert '"+src+"' to an IPv6 address");
auto p = (const char*)ca.sin6.sin6_addr.s6_addr;
return std::string(p, p+16);
}
2018-04-03 23:03:59 +07:00
bool processQuestion(const DNSNode& zones, DNSMessage& dm, const ComboAddress& local, const ComboAddress& remote, DNSMessage& response)
try
2018-04-01 23:31:41 +07:00
{
2018-04-03 23:03:59 +07:00
dnsname name;
DNSType type;
dm.getQuestion(name, type);
2018-04-09 04:20:11 +07:00
cout<<"Received a query from "<<remote.toStringWithPort()<<" for "<<name<<" and type "<<type<<endl;
2018-04-03 23:03:59 +07:00
response.dh = dm.dh;
response.dh.ad = 0;
response.dh.ra = 0;
response.dh.aa = 0;
response.dh.qr = 1;
response.dh.ancount = response.dh.arcount = response.dh.nscount = 0;
response.setQuestion(name, type);
if(type == DNSType::AXFR) {
2018-04-09 04:20:11 +07:00
cout<<"Query was for AXFR or IXFR over UDP, can't do that"<<endl;
response.dh.rcode = (int)RCode::Servfail;
return true;
2018-04-03 23:03:59 +07:00
}
if(dm.dh.opcode != 0) {
cout<<"Query had non-zero opcode "<<dm.dh.opcode<<", sending NOTIMP"<<endl;
response.dh.rcode = (int)RCode::Notimp;
return true;
}
dnsname zone;
auto fnd = zones.find(name, zone);
if(fnd && fnd->zone) {
cout<<"---\nBest zone: "<<zone<<", name now "<<name<<", loaded: "<<(void*)fnd->zone<<endl;
response.dh.aa = 1;
auto bestzone = fnd->zone;
dnsname searchname(name), lastnode;
bool passedZonecut=false;
auto node = bestzone->find(searchname, lastnode, &passedZonecut);
if(passedZonecut)
response.dh.aa = false;
if(!node) {
cout<<"Found nothing in zone '"<<zone<<"' for lhs '"<<name<<"'"<<endl;
}
else if(!searchname.empty()) {
cout<<"This was a partial match, searchname now "<<searchname<<endl;
for(const auto& rr: node->rrsets) {
2018-04-09 04:20:11 +07:00
cout<<" Have type "<<rr.first<<endl;
2018-04-03 23:03:59 +07:00
}
auto iter = node->rrsets.find(DNSType::NS);
if(iter != node->rrsets.end() && passedZonecut) {
cout<<"Have delegation"<<endl;
const auto& rrset = iter->second;
for(const auto& rr : rrset.contents) {
2018-04-09 17:54:27 +07:00
response.putRR(DNSSection::Authority, lastnode+zone, DNSType::NS, rrset.ttl, rr);
}
dnsname addname{"ns1", "fra"}, wuh;
cout<<"Looking up glue record "<<addname<<endl;
auto addnode = bestzone->find(addname, wuh);
auto iter2 = addnode->rrsets.find(DNSType::A);
if(iter2 != addnode->rrsets.end()) {
cout<<"Lastnode for '"<<addname<<"' glue: "<<wuh<<endl;
const auto& rrset = iter2->second;
for(const auto& rr : rrset.contents) {
response.putRR(DNSSection::Additional, wuh+zone, DNSType::A, rrset.ttl, rr);
}
2018-04-03 23:03:59 +07:00
}
// should do additional processing here
}
else {
cout<<"This is an NXDOMAIN situation"<<endl;
const auto& rrset = fnd->zone->rrsets[DNSType::SOA];
response.dh.rcode = (int)RCode::Nxdomain;
response.putRR(DNSSection::Authority, zone, DNSType::SOA, rrset.ttl, rrset.contents[0]);
}
}
else {
cout<<"Found something in zone '"<<zone<<"' for lhs '"<<name<<"', searchname now '"<<searchname<<"', lastnode '"<<lastnode<<"', passedZonecut="<<passedZonecut<<endl;
auto iter = node->rrsets.cbegin();
if(type == DNSType::ANY) {
for(const auto& t : node->rrsets) {
const auto& rrset = t.second;
for(const auto& rr : rrset.contents) {
response.putRR(DNSSection::Answer, lastnode+zone, t.first, rrset.ttl, rr);
}
}
}
else if(iter = node->rrsets.find(type), iter != node->rrsets.end()) {
const auto& rrset = iter->second;
for(const auto& rr : rrset.contents) {
response.putRR(DNSSection::Answer, lastnode+zone, type, rrset.ttl, rr);
}
}
else if(iter = node->rrsets.find(DNSType::CNAME), iter != node->rrsets.end()) {
cout<<"We do have a CNAME!"<<endl;
const auto& rrset = iter->second;
for(const auto& rr : rrset.contents) {
response.putRR(DNSSection::Answer, lastnode+zone, DNSType::CNAME, rrset.ttl, rr);
}
cout<<" We should actually follow this, at least within our zone"<<endl;
}
else {
cout<<"Node exists, qtype doesn't, NOERROR situation, inserting SOA"<<endl;
const auto& rrset = fnd->zone->rrsets[DNSType::SOA];
response.putRR(DNSSection::Answer, zone, DNSType::SOA, rrset.ttl, rrset.contents[0]);
}
}
}
else {
2018-04-09 04:20:11 +07:00
cout<<"No zone matched"<<endl;
2018-04-03 23:03:59 +07:00
response.dh.rcode = (uint8_t)RCode::Refused;
}
return true;
}
catch(std::exception& e) {
cout<<"Error processing query: "<<e.what()<<endl;
return false;
}
2018-04-01 23:31:41 +07:00
2018-04-03 23:03:59 +07:00
// Serves queries arriving on the UDP socket `sock`, forever.
//
// local - address the socket is bound to; also used to initialise the
//         sender address before each receive
// sock  - bound UDP socket
// zones - zone tree handed to processQuestion()
//
// Datagrams shorter than a DNS header and messages with the QR bit set are
// dropped with a log line.
void udpThread(ComboAddress local, Socket* sock, const DNSNode* zones)
{
  // Accept at most 512 bytes, the same limit the TCP path applies. The
  // datagram is memcpy'd over a DNSMessage, which is larger than that; a
  // bigger receive limit would let a sender overwrite whatever the object
  // stores after the message bytes.
  // NOTE(review): presumably that is SafeArray's position bookkeeping -
  // confirm against safearray.hh.
  constexpr size_t maxMessageSize = 512;
  static_assert(maxMessageSize <= sizeof(DNSMessage), "received message must fit in a DNSMessage");

  for(;;) {
    ComboAddress remote(local);
    DNSMessage dm;
    string message = SRecvfrom(*sock, maxMessageSize, remote);
    if(message.size() < sizeof(dnsheader)) {
      cerr<<"Dropping query from "<<remote.toStringWithPort()<<", too short"<<endl;
      continue;
    }
    memcpy(&dm, message.c_str(), message.size());
    if(dm.dh.qr) {
      cerr<<"Dropping non-query from "<<remote.toStringWithPort()<<endl;
      continue;
    }
    DNSMessage response;
    if(processQuestion(*zones, dm, local, remote, response)) {
      cout<<"Sending response with rcode "<<(RCode)response.dh.rcode <<endl;
      try {
        SSendto(*sock, response.serialize(), remote);
      }
      catch(const std::exception& e) {
        // One undeliverable response must not escape the thread function
        // and take the whole server down.
        cerr<<"Error sending response to "<<remote.toStringWithPort()<<": "<<e.what()<<endl;
      }
    }
  }
}
2018-04-01 23:31:41 +07:00
2018-04-09 04:20:11 +07:00
void writeTCPResponse(int sock, const DNSMessage& response)
{
string ser="00"+response.serialize();
cout<<"Should send a message of "<<ser.size()<<" bytes in response"<<endl;
uint16_t len = htons(ser.length()-2);
ser[0] = *((char*)&len);
ser[1] = *(((char*)&len) + 1);
SWriten(sock, ser);
cout<<"Sent!"<<endl;
}
2018-04-03 23:03:59 +07:00
// Serves one TCP connection until the client disconnects or an error occurs.
//
// local  - address the listener is bound to, passed on to processQuestion()
// remote - address of the client, used for logging
// s      - connected socket descriptor; wrapped in a Socket for the lifetime
//          of this function
// zones  - zone tree to answer from
//
// Each query is preceded by a 16-bit length in network byte order. AXFR
// queries are answered here by walking the zone; everything else goes
// through processQuestion().
//
// This is a thread entry point, so no exception may leave it: parsing a
// malformed question, filling a response or writing to a closed connection
// can all throw, and an escaping exception would terminate the whole server.
void tcpClientThread(ComboAddress local, ComboAddress remote, int s, const DNSNode* zones)
{
  Socket sock(s);
  try {
    cout<<"TCP Connection from "<<remote.toStringWithPort()<<endl;
    for(;;) {
      string message = SRead(sock, 2);
      if(message.size() != 2)
        break;
      uint16_t len;
      memcpy(&len, message.data(), 2);
      len = ntohs(len);

      if(len > 512) {
        cerr<<"Remote "<<remote.toStringWithPort()<<" sent question that was too big"<<endl;
        return;
      }
      if(len < sizeof(dnsheader)) {
        cerr<<"Dropping query from "<<remote.toStringWithPort()<<", too short"<<endl;
        return;
      }
      cout<<"Reading "<<len<<" bytes"<<endl;

      message = SRead(sock, len);
      if(message.size() != len) {
        // a partial message must not be parsed as if it were complete
        cerr<<"Remote "<<remote.toStringWithPort()<<" closed connection in the middle of a question"<<endl;
        return;
      }
      DNSMessage dm;
      memcpy(&dm, message.c_str(), message.size());
      if(dm.dh.qr) {
        cerr<<"Dropping non-query from "<<remote.toStringWithPort()<<endl;
        return;
      }

      dnsname name;
      DNSType type;
      dm.getQuestion(name, type);
      DNSMessage response;

      if(type == DNSType::AXFR) {
        cout<<"Should do AXFR for "<<name<<endl;
        dnsname zone;
        auto fnd = zones->find(name, zone);
        // the query name must be exactly a zone we have loaded
        if(!fnd || !fnd->zone || !name.empty()) {
          cout<<" This was not a zone"<<endl;
          return;
        }
        cout<<"Have zone, walking it"<<endl;

        response.dh = dm.dh;
        response.dh.ad = 0;
        response.dh.ra = 0;
        response.dh.aa = 0;
        response.dh.qr = 1;
        response.dh.ancount = response.dh.arcount = response.dh.nscount = 0;
        response.setQuestion(zone, type);

        // Empties `response` so the next message of the transfer can be
        // built in it: counters to zero, payload rewound, question re-added.
        auto startNextMessage = [&response, &zone, &type]() {
          response.dh.ancount = response.dh.arcount = response.dh.nscount = 0;
          response.payload.rewind();
          response.setQuestion(zone, type);
        };

        auto node = fnd->zone;

        // send SOA
        response.putRR(DNSSection::Answer, zone, DNSType::SOA, node->rrsets[DNSType::SOA].ttl, node->rrsets[DNSType::SOA].contents[0]);
        writeTCPResponse(sock, response);
        startNextMessage();

        // send all other records
        node->visit([&response, &sock, &startNextMessage](const dnsname& nname, const DNSNode* n) {
          cout<<nname<<", types: ";
          for(const auto& p : n->rrsets) {
            if(p.first == DNSType::SOA)
              continue;
            for(const auto& rr : p.second.contents) {
              try {
                response.putRR(DNSSection::Answer, nname, p.first, p.second.ttl, rr);
              }
              catch(...) { // exceeded packet size
                // Flush what we have and retry once in an empty message. If
                // the record does not fit there either it never will, so the
                // second failure propagates and aborts the transfer instead
                // of retrying forever.
                writeTCPResponse(sock, response);
                startNextMessage();
                response.putRR(DNSSection::Answer, nname, p.first, p.second.ttl, rr);
              }
            }
            cout<<p.first<<" ";
          }
          cout<<endl;
        }, zone);

        writeTCPResponse(sock, response);
        startNextMessage();

        // send SOA again
        response.putRR(DNSSection::Answer, zone, DNSType::SOA, node->rrsets[DNSType::SOA].ttl, node->rrsets[DNSType::SOA].contents[0]);
        writeTCPResponse(sock, response);
        return;
      }
      else {
        // processQuestion() reads the question itself, so start over
        dm.payload.rewind();
        if(processQuestion(*zones, dm, local, remote, response)) {
          writeTCPResponse(sock, response);
        }
        else
          return;
      }
    }
  }
  catch(const std::exception& e) {
    cerr<<"Error serving TCP client "<<remote.toStringWithPort()<<": "<<e.what()<<endl;
  }
  catch(...) {
    cerr<<"Unknown error serving TCP client "<<remote.toStringWithPort()<<endl;
  }
}
// Populates `zones` with the built-in powerdns.org zone: records at the zone
// apex plus the www, server1 and fra / ns1.fra / NS2.fra nodes below it.
// Record sets given without an explicit second value use the rrset's
// default TTL.
void loadZones(DNSNode& zones)
{
  auto zone = zones.add({"powerdns", "org"});
  zone->zone = new DNSNode(); // XXX ICK
  auto& apex = *zone->zone;

  // records at the apex
  apex.rrsets[DNSType::SOA]={{serializeSOARecord({"ns1", "powerdns", "org"}, {"admin", "powerdns", "org"}, 1)}};
  apex.rrsets[DNSType::MX]={{serializeMXRecord(25, {"server1", "powerdns", "org"})}};
  apex.rrsets[DNSType::A]={{serializeARecord("1.2.3.4")}, 300};
  apex.rrsets[DNSType::AAAA]={{serializeAAAARecord("::1"), serializeAAAARecord("2001::1"), serializeAAAARecord("2a02:a440:b085:1:beee:7bff:fe89:f0fb")}, 900};
  apex.rrsets[DNSType::NS]={{serializeDNSName({"ns1", "powerdns", "org"})}, 300};

  // names inside the zone
  apex.add({"www"})->rrsets[DNSType::CNAME]={{serializeDNSName({"server1","powerdns","org"})}};
  apex.add({"server1"})->rrsets[DNSType::A]={{serializeARecord("213.244.168.210")}};
  apex.add({"server1"})->rrsets[DNSType::AAAA]={{serializeAAAARecord("::1")}};
  // apex.add({"*"})->rrsets[(dnstype)DNSType::A]={"\x05\x06\x07\x08"};

  // the "fra" child, with address records below it
  apex.add({"fra"})->rrsets[DNSType::NS]={{serializeDNSName({"ns1","fra","powerdns","org"})}};
  apex.add({"ns1", "fra"})->rrsets[DNSType::A]={{serializeARecord("12.13.14.15")}, 86400};
  apex.add({"NS2", "fra"})->rrsets[DNSType::A]={{serializeARecord("12.13.14.16")}, 86400};
}
int main(int argc, char** argv)
{
2018-04-09 17:54:27 +07:00
signal(SIGPIPE, SIG_IGN);
2018-04-03 18:18:37 +07:00
ComboAddress local(argv[1], 53);
2018-04-03 23:03:59 +07:00
Socket udplistener(local.sin4.sin_family, SOCK_DGRAM);
SBind(udplistener, local);
Socket tcplistener(local.sin4.sin_family, SOCK_STREAM);
SSetsockopt(tcplistener, SOL_SOCKET, SO_REUSEPORT, 1);
SBind(tcplistener, local);
SListen(tcplistener, 10);
2018-04-03 18:18:37 +07:00
DNSNode zones;
loadZones(zones);
2018-04-03 23:03:59 +07:00
thread udpServer(udpThread, local, &udplistener, &zones);
2018-04-03 18:18:37 +07:00
2018-04-03 23:03:59 +07:00
for(;;) {
ComboAddress remote;
int client = SAccept(tcplistener, remote);
thread t(tcpClientThread, local, remote, client, &zones);
t.detach();
}
2018-04-01 23:31:41 +07:00
}