hello-dns/tdns/tdns.cc

302 lines
8.0 KiB
C++
Raw Normal View History

2018-04-01 23:31:41 +07:00
/* Goal: a fully standards compliant basic authoritative server. In <500 lines.
Non-goals: notifications, slaving zones, name compression, edns,
performance
*/
#include <cstdint>
#include <string>
#include <vector>
#include <deque>
#include <map>
#include <stdexcept>
#include "sclasses.hh"
#include "dns.hh"
2018-04-02 18:25:19 +07:00
#include "safearray.hh"
2018-04-01 23:31:41 +07:00
using namespace std;

// Numeric record type as stored in rrsets keys and written by putUInt16().
using dnstype = uint16_t;

// A single label of a name, kept as raw bytes.
using dnslabel = std::string;

// Response codes this server can emit.
enum class RCode
{
  Refused = 5
};

// Record types this server knows by name; values are the numeric type codes.
enum class DNSType
{
  A     = 1,
  NS    = 2,
  CNAME = 5,
  SOA   = 6,
  AAAA  = 28
};

// A name is a sequence of labels, front() being the leftmost label.
using dnsname = std::deque<dnslabel>;
// this should perform escaping rules!
static std::ostream & operator<<(std::ostream &os, const dnsname& d)
{
for(const auto& l : d) {
os<<l<<".";
}
return os;
}
struct DNSNode
{
2018-04-02 18:25:19 +07:00
DNSNode* find(dnsname& name, dnsname& last, bool* passedZonecut=0);
2018-04-01 23:31:41 +07:00
DNSNode* add(dnsname name);
map<dnslabel, DNSNode> children;
map<dnstype, vector<string> > rrsets;
DNSNode* zone{0}; // if this is set, this node is a zone
};
2018-04-02 18:25:19 +07:00
// Recursive lookup. Labels are consumed from the back of 'name'; each label
// that leads to a child (directly or through a "*" child) is moved to the
// front of 'last'.
//
// Returns:
//  - when 'name' is exhausted: this node, or 0 if it has neither a zone nor
//    any record sets;
//  - when no child (and no "*" child) matches the next label: this node, with
//    the unmatched labels left in 'name'.
//
// If 'passedZonecut' is non-null it is set to true at every visited node that
// holds an NS record set.
DNSNode* DNSNode::find(dnsname& name, dnsname& last, bool* passedZonecut)
{
  cout<<"find for '"<<name<<"', last is now '"<<last<<"'"<<endl;

  if(passedZonecut != nullptr && rrsets.count(static_cast<dnstype>(DNSType::NS)) > 0)
    *passedZonecut = true;

  if(name.empty()) {
    cout<<"Empty lookup, returning this node or 0"<<endl;
    // only root zone can have this
    const bool hasContent = zone || !rrsets.empty();
    return hasContent ? this : nullptr;
  }

  cout<<"Children at this node: ";
  for(auto c = children.begin(); c != children.end(); ++c)
    cout<<"'"<<c->first<<"' ";
  cout<<endl;

  const dnslabel& wanted = name.back();
  auto child = children.find(wanted);
  cout<<"Looked for child called '"<<wanted<<"'"<<endl;

  if(child == children.end()) {
    cout<<"Found nothing, trying wildcard"<<endl;
    child = children.find("*");
    if(child == children.end()) {
      cout<<"Still nothing, returning leaf"<<endl;
      return this;
    }
    cout<<"Had wildcard match, following"<<endl;
  }

  cout<<"Had match, continuing to child '"<<child->first<<"'"<<endl;
  last.push_front(wanted); // must happen before pop_back() invalidates 'wanted'
  name.pop_back();
  return child->second.find(name, last, passedZonecut);
}
// Returns the node for 'name' below this node, creating any missing nodes.
// Labels are consumed from the back of 'name', so {"www","powerdns"} yields
// children["powerdns"].children["www"].
//
// An empty name refers to this node itself and returns 'this'. Previously an
// empty name fell through to back()/pop_back() on an empty deque, which is
// undefined behaviour.
DNSNode* DNSNode::add(dnsname name)
{
  cout<<"Add for '"<<name<<"'"<<endl;
  if(name.empty())
    return this;

  if(name.size() == 1) {
    cout<<"Last label, adding "<<name.front()<<endl;
    return &children[name.front()];
  }

  const auto back = name.back();
  name.pop_back();
  auto iter = children.find(back);
  if(iter == children.end()) {
    cout<<"Inserting new child for "<<back<<endl;
    return children[back].add(name);
  }
  return iter->second.add(name);
}
2018-04-02 18:25:19 +07:00
2018-04-01 23:31:41 +07:00
struct DNSMessage
{
struct dnsheader dh=dnsheader{};
2018-04-02 18:25:19 +07:00
SafeArray<500> payload;
2018-04-01 23:31:41 +07:00
dnsname getName();
void putName(const dnsname& name);
void getQuestion(dnsname& name, dnstype& type);
void setQuestion(const dnsname& name, dnstype type);
void putRR(const dnsname& name, uint16_t type, uint32_t ttl, const std::string& rr);
std::string serialize() const;
} __attribute__((packed));
// Reads a name from the current payload position: a sequence of
// length-prefixed labels ended by a zero length byte.
// Throws std::runtime_error if a length byte exceeds 63, which is treated as
// a compressed label (name compression is a stated non-goal of this server).
dnsname DNSMessage::getName()
{
  dnsname result;
  uint8_t len;
  while((len = payload.getUInt8()) != 0) { // zero length: end of dnsname
    if(len > 63)
      throw std::runtime_error("Got a compressed label");
    dnslabel piece = payload.getBlob(len);
    result.push_back(piece);
  }
  return result;
}
// Reads the question's name and 16-bit type from the payload into the two
// output parameters, in that order. Nothing after the type is consumed.
void DNSMessage::getQuestion(dnsname& name, dnstype& type)
{
  name = this->getName();
  type = this->payload.getUInt16();
}
void DNSMessage::putName(const dnsname& name)
{
for(const auto& l : name) {
2018-04-02 18:25:19 +07:00
payload.putUInt8(l.size());
payload.putBlob(l);
2018-04-01 23:31:41 +07:00
}
2018-04-02 18:25:19 +07:00
payload.putUInt8(0);
2018-04-01 23:31:41 +07:00
}
2018-04-02 18:25:19 +07:00
void DNSMessage::putRR(const dnsname& name, uint16_t type, uint32_t ttl, const std::string& content)
2018-04-01 23:31:41 +07:00
{
putName(name);
2018-04-02 18:25:19 +07:00
payload.putUInt16(type); payload.putUInt16(1);
payload.putUInt32(ttl);
payload.putUInt16(content.size()); // check for overflow!
payload.putBlob(content);
2018-04-01 23:31:41 +07:00
}
// Appends a question to the payload: the name, the 16-bit type, and a 16-bit
// class that is always 1.
void DNSMessage::setQuestion(const dnsname& name, dnstype type)
{
  const uint16_t qclass = 1;
  putName(name);
  payload.putUInt16(type);
  payload.putUInt16(qclass);
}
// Returns a copy of this object's raw bytes, starting at the object itself and
// covering sizeof(dnsheader) + payload.payloadpos bytes. This relies on the
// header being the first member and the struct being packed.
string DNSMessage::serialize() const
{
  const char* first = reinterpret_cast<const char*>(this);
  const char* last = first + sizeof(dnsheader) + payload.payloadpos;
  return string(first, last);
}
// DNSMessage is filled with memcpy() from received datagrams and read back as
// raw bytes by serialize(), so its packed layout must not change silently.
// NOTE(review): 516 is presumably dnsheader plus SafeArray<500> including its
// position bookkeeping; confirm against dns.hh / safearray.hh.
static_assert(sizeof(DNSMessage) == 516, "dnsmessage size must be 516");
2018-04-02 18:25:19 +07:00
std::string serializeDNSName(const dnsname& dn)
{
std::string ret;
for(const auto & l : dn) {
ret.append(1, l.size());
ret+=l;
}
ret.append(1, (char)0);
return ret;
}
// Concatenates two names: all labels of 'a' followed by all labels of 'b'.
dnsname operator+(const dnsname& a, const dnsname& b)
{
  dnsname joined(a);
  joined.insert(joined.end(), b.begin(), b.end());
  return joined;
}
2018-04-01 23:31:41 +07:00
int main(int argc, char** argv)
{
ComboAddress local(argv[1], 53);
Socket udplistener(local.sin4.sin_family, SOCK_DGRAM);
SBind(udplistener, local);
DNSNode zones;
auto zone = zones.add({"powerdns", "org"});
zone->zone = new DNSNode(); // XXX ICK
zone->zone->rrsets[(dnstype)DNSType::SOA]={"hello"};
zone->zone->rrsets[(dnstype)DNSType::A]={"\x01\x02\x03\x04"};
2018-04-02 18:25:19 +07:00
zone->zone->add({"www"})->rrsets[(dnstype)DNSType::CNAME]={serializeDNSName({"server1","powerdns","com"})};
// zone->zone->add({"*"})->rrsets[(dnstype)DNSType::A]={"\x05\x06\x07\x08"};
zone->zone->add({"fra"})->rrsets[(dnstype)DNSType::NS]={serializeDNSName({"ns1","fra","powerdns","org"})};
zone->zone->add({"ns1", "fra"})->rrsets[(dnstype)DNSType::A]={"\x05\x06\x07\x08"};
2018-04-01 23:31:41 +07:00
for(;;) {
ComboAddress remote(local);
DNSMessage dm;
string message = SRecvfrom(udplistener, sizeof(dm), remote);
if(message.size() < sizeof(dnsheader)) {
cerr<<"Dropping query from "<<remote.toStringWithPort()<<", too short"<<endl;
continue;
}
memcpy(&dm, message.c_str(), message.size());
if(dm.dh.qr || dm.dh.opcode) {
cerr<<"Dropping non-query from "<<remote.toStringWithPort()<<endl;
}
dnsname name;
dnstype type;
dm.getQuestion(name, type);
cout<<"Received a query from "<<remote.toStringWithPort()<<" for "<<name<<" and type "<<type<<endl;
DNSMessage response;
response.dh = dm.dh;
response.dh.ad = 0;
response.dh.ra = 0;
response.dh.aa = 0;
response.dh.qr = 1;
response.dh.ancount = response.dh.arcount = response.dh.nscount = 0;
response.setQuestion(name, type);
dnsname zone;
auto fnd = zones.find(name, zone);
if(fnd && fnd->zone) {
2018-04-02 18:25:19 +07:00
cout<<"---\nBest zone: "<<zone<<", name now "<<name<<", loaded: "<<(void*)fnd->zone<<endl;
2018-04-01 23:31:41 +07:00
response.dh.aa = 1;
auto bestzone = fnd->zone;
dnsname searchname(name), lastnode;
2018-04-02 18:25:19 +07:00
bool passedZonecut=false;
auto node = bestzone->find(searchname, lastnode, &passedZonecut);
if(!node) {
2018-04-01 23:31:41 +07:00
cout<<"Found nothing in zone '"<<zone<<"' for lhs '"<<name<<"'"<<endl;
}
2018-04-02 18:25:19 +07:00
else if(!searchname.empty()) {
cout<<"This was a partial match, searchname now "<<searchname<<endl;
for(const auto& rr: node->rrsets) {
cout<<" Have type "<<rr.first<<endl;
}
if(node->rrsets.count((int)DNSType::NS)) {
for(const auto& rr : node->rrsets[(int)DNSType::NS]) {
response.putRR(lastnode+zone, (int)DNSType::NS, 3600, rr);
response.dh.nscount = htons(ntohs(response.dh.ancount)+1);
}
// should do additional processing here
}
}
2018-04-01 23:31:41 +07:00
else {
2018-04-02 18:25:19 +07:00
cout<<"Found something in zone '"<<zone<<"' for lhs '"<<name<<"', searchname now '"<<searchname<<"', lastnode '"<<lastnode<<"', passedZonecut="<<passedZonecut<<endl;
if(passedZonecut)
response.dh.aa = false;
if(node->rrsets.count(type)) {
2018-04-01 23:31:41 +07:00
cout<<"Had qtype too!"<<endl;
2018-04-02 18:25:19 +07:00
for(const auto& rr : node->rrsets[type]) {
response.putRR(lastnode+zone, type, 3600, rr);
2018-04-01 23:31:41 +07:00
response.dh.ancount = htons(ntohs(response.dh.ancount)+1);
}
}
2018-04-02 18:25:19 +07:00
else if(node->rrsets.count((int)DNSType::CNAME)) {
cout<<"We do have a CNAME!"<<endl;
for(const auto& rr : node->rrsets[(int)DNSType::CNAME]) {
response.putRR(lastnode+zone, (int)DNSType::CNAME, 3600, rr);
response.dh.ancount = htons(ntohs(response.dh.ancount)+1);
2018-04-01 23:31:41 +07:00
}
}
2018-04-02 18:25:19 +07:00
2018-04-01 23:31:41 +07:00
}
}
else {
response.dh.rcode = (uint8_t)RCode::Refused;
}
SSendto(udplistener, response.serialize(), remote);
}
}