/*
 * NOTE(review): this file appears to be a damaged copy, not the real source.
 *  - All line breaks were lost: the whole program sits on three physical
 *    lines. As a result, the first line comment on line 1
 *    ("this should perform escaping rules!") and the one on line 2
 *    ("check for overflow!") swallow the rest of their lines.
 *  - Also, everything on line 1 after the opening block comment is parsed as
 *    a single malformed #include directive.
 *  - As stored, the file cannot compile.
 *  - Text that looked like an HTML tag (a '<' followed by a letter, up to the
 *    next '>') seems to have been removed. Evidence:
 *      - the #include lines have no header names;
 *      - "typedef deque dnsname;", "std::array payload;" and "map > rrsets;";
 *      - fragments such as "os< children;" and
 *        "cout<<"Lookup for '"<second.find(name, last);" where code is gone.
 *  - The missing parts cannot be recovered from this copy: the rest of
 *    operator<<, the start of struct DNSNode, most of DNSNode::find and
 *    DNSNode::add, and much of main().
 *  - TODO: restore this file from version control rather than hand-editing
 *    this text.
 *
 * Intended structure, as far as the surviving text shows:
 *  - dnsname is a sequence of dnslabel (std::string) labels.
 *  - DNSNode is a tree node holding "children", per-type "rrsets", and a
 *    "zone" pointer that is set when the node is a zone apex.
 *  - DNSMessage is a packed dnsheader followed by a fixed payload buffer and
 *    two cursors (payloadpos, payloadsize).
 *  - The static_assert pins sizeof(DNSMessage) to 516 bytes. serialize()
 *    returns the header plus the first payloadpos payload bytes, which
 *    relies on that packed layout.
 *  - main() binds a UDP socket on address argv[1], port 53. It builds a
 *    hard-coded tree for "powerdns.org" (placeholder SOA rdata "hello", an A
 *    record, and a CNAME on child "www"), then answers queries in a loop.
 *    The surviving else-branch sets rcode to RCode::Refused; the condition
 *    it belongs to was stripped.
 *
 * Review findings in the surviving code (to fix once the source is restored):
 *  - getUInt16() converts with htons(); ntohs() is the intended call. The
 *    result is the same on every platform, but the name is misleading.
 *  - getName() bounds reads only by the payload array (via at()), not by
 *    payloadsize. A truncated query is therefore parsed from buffer bytes
 *    that were never received.
 *  - The "&payload.at(pos + n) - n" bounds idiom is used in getName(),
 *    getUInt16(), putUInt16(), putUInt32() and putBlob().
 *      - It throws std::out_of_range when pos + n equals the buffer size, so
 *        the final bytes of the buffer can never be read or written.
 *      - An empty putBlob() at the very end of the buffer also throws.
 *  - getUInt32() is declared, but no definition is visible here.
 *  - putRR() shadows the member array with its parameter.
 *      - The definition names the parameter "payload" (the declaration says
 *        "rr"). It works, but it is easy to misread.
 *      - The rdlength written by putUInt16(payload.size()) silently truncates
 *        above 65535 bytes, as its own comment admits.
 *  - main() reads argv[1] without checking argc.
 *  - The CNAME rdata literal "\x03www\x02nl\x00" is likely converted from a
 *    const char*, which stops at the first NUL. That drops the root-label
 *    terminator. (The element type of rrsets was stripped; confirm.)
 *  - In the answer code, owner names are hard-coded ({"powerdns", "org"} and
 *    {"www", "powerdns", "org"}) instead of taken from the query name.
 *  - CNAME records are emitted with the query "type" instead of CNAME.
 *  - "zone->zone = new DNSNode(); // XXX ICK" is a raw allocation that is
 *    never freed.
 */
/* Goal: a fully standards compliant basic authoritative server. In <500 lines. Non-goals: notifications, slaving zones, name compression, edns, performance */ #include #include #include #include #include #include #include "sclasses.hh" #include "dns.hh" using namespace std; typedef uint16_t dnstype; typedef std::string dnslabel; enum class RCode { Refused=5 }; enum class DNSType { A = 1, NS = 2, CNAME = 5, SOA=6, AAAA = 28 }; typedef deque dnsname; // this should perform escaping rules! static std::ostream & operator<<(std::ostream &os, const dnsname& d) { for(const auto& l : d) { os< children; map > rrsets; DNSNode* zone{0}; // if this is set, this node is a zone }; DNSNode* DNSNode::find(dnsname& name, dnsname& last) { cout<<"Lookup for '"<second.find(name, last); } DNSNode* DNSNode::add(dnsname name) { cout<<"Add for '"<second.add(name); } struct DNSMessage { struct dnsheader dh=dnsheader{}; std::array payload; uint16_t payloadpos{0}, payloadsize{0}; dnsname getName(); uint16_t getUInt16(); uint32_t getUInt32(); void putName(const dnsname& name); void putUInt16(uint16_t val); void putUInt32(uint32_t val); void putBlob(const std::string& blob); void getQuestion(dnsname& name, dnstype& type); void setQuestion(const dnsname& name, dnstype type); void putRR(const dnsname& name, uint16_t type, uint32_t ttl, const std::string& rr); std::string serialize() const; } __attribute__((packed)); dnsname DNSMessage::getName() { dnsname name; for(;;) { uint8_t labellen=payload.at(payloadpos++); if(labellen > 63) throw std::runtime_error("Got a compressed label"); if(!labellen) // end of dnsname break; dnslabel label(&payload.at(payloadpos), &payload.at(payloadpos+labellen)); payloadpos += labellen; name.push_back(label); } return name; } uint16_t DNSMessage::getUInt16() { uint16_t ret; memcpy(&ret, &payload.at(payloadpos+2)-2, 2); payloadpos+=2; return htons(ret); } void DNSMessage::getQuestion(dnsname& name, dnstype& type) { name=getName(); type=getUInt16(); } void 
DNSMessage::putName(const dnsname& name) { for(const auto& l : name) { payload.at(payloadpos++)=l.size(); for(const auto& a : l) payload.at(payloadpos++)=(uint8_t)a; } payload.at(payloadpos++)=0; } void DNSMessage::putUInt16(uint16_t val) { val = htons(val); memcpy(&payload.at(payloadpos+2)-2, &val, 2); payloadpos+=2; } void DNSMessage::putUInt32(uint32_t val) { val = htonl(val); memcpy(&payload.at(payloadpos+sizeof(val)) - sizeof(val), &val, sizeof(val)); payloadpos += sizeof(val); } void DNSMessage::putBlob(const std::string& blob) { memcpy(&payload.at(payloadpos+blob.size()) - blob.size(), blob.c_str(), blob.size()); payloadpos += blob.size();; } void DNSMessage::putRR(const dnsname& name, uint16_t type, uint32_t ttl, const std::string& payload) { putName(name); putUInt16(type); putUInt16(1); putUInt32(ttl); putUInt16(payload.size()); // check for overflow! putBlob(payload); } void DNSMessage::setQuestion(const dnsname& name, dnstype type) { putName(name); putUInt16(type); putUInt16(1); // class } string DNSMessage::serialize() const { return string((const char*)this, (const char*)this + sizeof(dnsheader) + payloadpos); } static_assert(sizeof(DNSMessage) == 516, "dnsmessage size must be 516"); int main(int argc, char** argv) { ComboAddress local(argv[1], 53); Socket udplistener(local.sin4.sin_family, SOCK_DGRAM); SBind(udplistener, local); DNSNode zones; auto zone = zones.add({"powerdns", "org"}); zone->zone = new DNSNode(); // XXX ICK zone->zone->rrsets[(dnstype)DNSType::SOA]={"hello"}; zone->zone->rrsets[(dnstype)DNSType::A]={"\x01\x02\x03\x04"}; zone->zone->add({"www"})->rrsets[(dnstype)DNSType::CNAME]={"\x03www\x02nl\x00"}; for(;;) { ComboAddress remote(local); DNSMessage dm; string message = SRecvfrom(udplistener, sizeof(dm), remote); if(message.size() < sizeof(dnsheader)) { cerr<<"Dropping query from "<zone) { cout<<"Best zone: "<zone<zone; dnsname searchname(name), lastnode; auto rrsets = bestzone->find(searchname, lastnode); if(!rrsets) { cout<<"Found 
nothing in zone '"<rrsets.count(type)) { cout<<"Had qtype too!"<rrsets[type]) { response.putRR({"powerdns", "org"}, type, 3600, rr); response.dh.ancount = htons(ntohs(response.dh.ancount)+1); } } else { cout<<"Node exists, but no matching qtype"<rrsets.count((int)DNSType::CNAME)) { cout<<"We do have a CNAME!"<rrsets[(int)DNSType::CNAME]) { response.putRR({"www", "powerdns", "org"}, type, 3600, rr); response.dh.ancount = htons(ntohs(response.dh.ancount)+1); } } } } } else { response.dh.rcode = (uint8_t)RCode::Refused; } SSendto(udplistener, response.serialize(), remote); } }