/* Goal: a fully standards compliant basic authoritative server. In <500 lines. Non-goals: notifications, slaving zones, name compression, edns, performance */ /* NOTE(review): this copy of the file is damaged and will not compile as-is. Its line breaks were collapsed, so each #include directive and each line comment now swallows the rest of its physical line, and all text enclosed in angle brackets was stripped: the #include targets, template arguments (e.g. the element type of dnsname, the key/value types of DNSNode::children and DNSNode::rrsets) and most stream-output chains are gone. That also removed parts of operator<< for dnsname, the head of the record-set struct, DNSNode::find, DNSNode::add, processQuestion, the request-handling code after it and the start of loadZones. Restore the file from version control rather than reconstructing it by hand; the comments below describe only what is still visible. */ #include #include #include #include #include #include #include "sclasses.hh" #include "dns.hh" #include "safearray.hh" #include using namespace std; /* One DNS label: the bytes of a single name component, without its length byte. */ typedef std::string dnslabel; /* Response codes this server sets in the header rcode field. */ enum class RCode { Noerror = 0, Servfail = 2, Nxdomain = 3, Notimp = 4, Refused = 5 }; /* Record types known to this server; the enumerator values are written verbatim as the 16-bit TYPE field on the wire. */ enum class DNSType : uint16_t { A = 1, NS = 2, CNAME = 5, SOA=6, AAAA = 28, IXFR = 251, AXFR = 252, ANY = 255 }; /* Message sections; DNSMessage::putRR uses this to decide which header count to increment. */ enum class DNSSection { Question, Answer, Authority, Additional }; /* A name as a sequence of labels, leftmost label first (loadZones writes {"ns1", "powerdns", "org"}); the empty sequence serializes as a single zero byte, i.e. the root. Element type lost in this copy - presumably dnslabel. */ typedef deque dnsname; /* Streams a dnsname for the log output in this file; body truncated in this copy. */ // this should perform escaping rules! static std::ostream & operator<<(std::ostream &os, const dnsname& d) { for(const auto& l : d) { os< contents; /* (The end of operator<< and the start of this record-set struct were lost; processQuestion reads its .contents and .ttl.) TTL applied when a set is initialised with contents only, as several loadZones entries are. */ uint32_t ttl{3600}; }; /* Case-insensitive less-than on strings, presumably the comparator for DNSNode::children so label lookups ignore case (the template arguments that would show this are lost). NOTE(review): std::binary_function was deprecated in C++11 and removed in C++17; drop the base class when restoring. strcasecmp on c_str() also stops at embedded NUL bytes, as the XXX note says. */ struct CIStringCompare: public std::binary_function { bool operator()(const string& a, const string& b) const { return strcasecmp(a.c_str(), b.c_str()) < 0; // XXX locale pain, plus embedded zeros } }; /* One node of the in-memory name tree. rrsets holds the record sets owned by this name, indexed by DNSType; children holds the nodes one label further down. A non-null zone marks this name as a zone apex and points at a separately allocated DNSNode holding that zone's contents (see loadZones). */ struct DNSNode { const DNSNode* find(dnsname& name, dnsname& last, bool* passedZonecut=0) const; DNSNode* add(dnsname name); map children; map rrsets; DNSNode* zone{0}; // if this is set, this node is a zone }; /* Looks name up beneath this node, recursing into a child node (the surviving tail ends in iter->second.find(...)). name and last are non-const references, presumably because the walk consumes name and fills last (processQuestion passes a copy, searchname, and later builds owner names as lastnode+zone); when passedZonecut is non-null it reports whether a zone cut was crossed. Most of the body was lost in this copy - confirm the exact contract against the original. */ const DNSNode* DNSNode::find(dnsname& name, dnsname& last, bool* passedZonecut) const { cout<<"find for '"<first<<"'"<second.find(name, last, passedZonecut); } /* Returns the node for name below this one, recursing into children, so callers can assign its rrsets; presumably creates missing nodes on the way (loadZones calls add({"ns1", "fra"}) after add({"fra"})). Body mostly lost in this copy. */ DNSNode* DNSNode::add(dnsname name) { cout<<"Add for '"<second.add(name); } /* A DNS message held as the header dh plus a 500-byte payload buffer. serialize() copies sizeof(dnsheader) + payload.payloadpos raw bytes starting at this, so it assumes the payload bytes sit directly after dh (presumably SafeArray stores its buffer as its first member - confirm in safearray.hh); the static_assert on sizeof is the only check of that layout. */ struct DNSMessage { struct dnsheader dh=dnsheader{}; SafeArray<500> payload; dnsname getName(); void getQuestion(dnsname& name, DNSType& type); void setQuestion(const dnsname& name, DNSType type); void putRR(DNSSection section, const dnsname& name, DNSType type, uint32_t ttl, const std::string& rr); string serialize() const; }; // __attribute__((packed)); /* Reads one name from the payload cursor, label by label, until the zero-length root label. Length bytes above 63 are rejected because RFC 1035 uses the two high bits of that byte to flag compression pointers, which this server does not support (see Non-goals). NOTE(review): the 255-octet limit on a whole name is not enforced here; only the payload's own bounds handling, presumably in SafeArray, stops an over-long name. */ dnsname DNSMessage::getName() { dnsname name; for(;;) { uint8_t labellen=payload.getUInt8(); if(labellen > 63) throw std::runtime_error("Got a compressed label"); if(!labellen) // end of 
dnsname break; dnslabel label = payload.getBlob(labellen); name.push_back(label); } return name; } /* Parses the question: the name, then the 16-bit QTYPE. NOTE(review): QCLASS is never read or checked, so a query of any class is answered from the class-1 (IN) data. */ void DNSMessage::getQuestion(dnsname& name, DNSType& type) { name=getName(); type=(DNSType)payload.getUInt16(); } /* Appends name in uncompressed wire form (length byte plus label bytes for each label, then a zero byte) to any buffer offering putUInt8 and putBlob - used with both DNSMessage::payload and SafeArray<256>. Labels over 63 bytes throw. NOTE(review): an empty label would be written as a zero byte and end the name early, and the 255-octet total is not checked. A plain function parameter declared auto is a C++20 abbreviated template. */ void putName(auto& payload, const dnsname& name) { for(const auto& l : name) { if(l.size() > 63) throw std::runtime_error("Can't emit a label larger than 63 characters"); payload.putUInt8(l.size()); payload.putBlob(l); } payload.putUInt8(0); } /* Appends one record (class 1, IN) with the given owner name, type, TTL and rdata, then increments the header count for section. NOTE(review): the record bytes are written before section is checked, so asking for the Question section throws after the payload has already been extended; and content.size() is passed to putUInt16 unchecked (the "check for overflow" note), so rdata over 65535 bytes would get a wrong RDLENGTH. */ void DNSMessage::putRR(DNSSection section, const dnsname& name, DNSType type, uint32_t ttl, const std::string& content) { putName(payload, name); payload.putUInt16((int)type); payload.putUInt16(1); payload.putUInt32(ttl); payload.putUInt16(content.size()); // check for overflow! payload.putBlob(content); switch(section) { case DNSSection::Question: throw runtime_error("Can't add questions to a DNS Message with putRR"); case DNSSection::Answer: dh.ancount = htons(ntohs(dh.ancount) + 1); break; case DNSSection::Authority: dh.nscount = htons(ntohs(dh.nscount) + 1); break; case DNSSection::Additional: dh.arcount = htons(ntohs(dh.arcount) + 1); break; } } /* Writes the question section: name, QTYPE and QCLASS 1 (IN). It does not touch dh.qdcount, so the header count has to be set elsewhere. */ void DNSMessage::setQuestion(const dnsname& name, DNSType type) { putName(payload, name); payload.putUInt16((uint16_t)type); payload.putUInt16(1); // class } /* Returns the wire bytes: the header plus the payloadpos payload bytes written so far, copied straight out of this object's memory. */ string DNSMessage::serialize() const { return string((const char*)this, (const char*)this + sizeof(dnsheader) + payload.payloadpos); } /* Pins the object layout that serialize() relies on. */ static_assert(sizeof(DNSMessage) == 516, "dnsmessage size must be 516"); /* Same wire encoding as putName, but returned as a string for use as rdata (NS and CNAME targets in loadZones). NOTE(review): unlike putName it does not reject labels longer than 63 bytes - l.size() is simply narrowed into a single char. */ std::string serializeDNSName(const dnsname& dn) { std::string ret; for(const auto & l : dn) { ret.append(1, l.size()); ret+=l; } ret.append(1, (char)0); return ret; } /* Builds SOA rdata in RFC 1035 section 3.3.13 order: MNAME, RNAME, SERIAL, REFRESH, RETRY, EXPIRE, MINIMUM. Note that minimum is the fourth parameter although it is written last. */ std::string serializeSOARecord(const dnsname& mname, const dnsname& rname, uint32_t serial, uint32_t minimum=3600, uint32_t refresh=10800, uint32_t retry=3600, uint32_t expire=604800) { SafeArray<256> sa; putName(sa, mname); putName(sa, rname); sa.putUInt32(serial); sa.putUInt32(refresh); sa.putUInt32(retry); sa.putUInt32(expire); 
sa.putUInt32(minimum); return sa.serialize(); } /* Converts dotted-quad text to A rdata: the 4 bytes of sin_addr.s_addr (sin4 is presumably a sockaddr_in, so these are already in network byte order). Throws if the parsed address is not IPv4; ComboAddress itself may also throw on unparsable text - see sclasses.hh. */ std::string serializeARecord(const std::string& src) { ComboAddress ca(src); if(ca.sin4.sin_family != AF_INET) throw std::runtime_error("Could not convert '"+src+"' to an IPv4 address"); auto p = (const char*)&ca.sin4.sin_addr.s_addr; return std::string(p, p+4); } /* IPv6 counterpart: returns the 16 address bytes as AAAA rdata. The family is read through sin4, which presumably works because ComboAddress overlays sin4 and sin6 so the family field is shared - confirm in sclasses.hh. */ std::string serializeAAAARecord(const std::string& src) { ComboAddress ca(src); if(ca.sin4.sin_family != AF_INET6) throw std::runtime_error("Could not convert '"+src+"' to an IPv6 address"); auto p = (const char*)ca.sin6.sin6_addr.s6_addr; return std::string(p, p+16); } /* Concatenates two names: all labels of a followed by all labels of b (processQuestion uses lastnode+zone, presumably to turn a zone-relative name into an absolute one). */ dnsname operator+(const dnsname& a, const dnsname& b) { dnsname ret=a; for(const auto& l : b) ret.push_back(l); return ret; } /* Answers the single question in dm by writing into response. Visible flow: presumably pick the served zone that encloses the name (the lookup itself is lost; the final else sets REFUSED), look the name up in it with DNSNode::find, clear AA if that crossed a zone cut, then emit a referral, an NXDOMAIN with the zone SOA in the Authority section, a CNAME (returned but not chased), or - in branches lost in this copy - the requested data. Returns true on the normal path; std::exception is caught by the function-try-block at the end. NOTE(review): the referral branch adds the delegation NS records to the Answer section, but RFC 1034 section 4.3.2 puts them in the Authority section, and no glue is added yet ("should do additional processing here"). NOTE(review): rrsets[DNSType::SOA] default-inserts an empty set when a zone has no SOA, after which contents[0] is out of bounds; since that insert mutates the zone tree shared by the UDP and TCP threads started in main, it is also a potential data race. */ bool processQuestion(const DNSNode& zones, DNSMessage& dm, const ComboAddress& local, const ComboAddress& remote, DNSMessage& response) try { dnsname name; DNSType type; dm.getQuestion(name, type); cout<<"Received a query from "<zone) { cout<<"---\nBest zone: "<zone<zone; dnsname searchname(name), lastnode; bool passedZonecut=false; auto node = bestzone->find(searchname, lastnode, &passedZonecut); if(passedZonecut) response.dh.aa = false; if(!node) { cout<<"Found nothing in zone '"<rrsets) { cout<<" Have type "<<(int)rr.first<rrsets.find(DNSType::NS); if(iter != node->rrsets.end() && passedZonecut) { cout<<"Have delegation"<second; for(const auto& rr : rrset.contents) { response.putRR(DNSSection::Answer, lastnode+zone, DNSType::NS, rrset.ttl, rr); } // should do additional processing here } else { cout<<"This is an NXDOMAIN situation"<zone->rrsets[DNSType::SOA]; response.dh.rcode = (int)RCode::Nxdomain; response.putRR(DNSSection::Authority, zone, DNSType::SOA, rrset.ttl, rrset.contents[0]); } } else { cout<<"Found something in zone '"<second; for(const auto& rr : rrset.contents) { response.putRR(DNSSection::Answer, lastnode+zone, DNSType::CNAME, rrset.ttl, rr); } cout<<" We should actually follow this, at least within our zone"<zone->rrsets[DNSType::SOA]; 
/* NOTE(review): if this is the no-data case (the name exists but the requested type does not - the condition is lost in this copy), RFC 2308 section 2.2 places the SOA in the Authority section, not the Answer section. */ response.putRR(DNSSection::Answer, zone, DNSType::SOA, rrset.ttl, rrset.contents[0]); } } } else { /* presumably reached when no zone we serve encloses the query name */ response.dh.rcode = (uint8_t)RCode::Refused; } return true; } /* From here up to the zone data below the code is largely lost in this copy; the 512 comparison and the "Remote " message are presumably fragments of the UDP/TCP request handlers that main starts. */ catch(std::exception& e) { cout<<"Error processing query: "< 512) { cerr<<"Remote "<zone = new DNSNode(); // XXX ICK /* loadZones (its opening, including the add() that creates the apex node, was lost): populates one hard-coded example zone - SOA with ns1.powerdns.org and admin.powerdns.org, serial 1; apex A, AAAA and NS records; www as a CNAME to server1.powerdns.org; server1 with an A record; and fra delegated to ns1.fra with a glue A record. The zone contents hang off a raw new'd DNSNode that is never freed (hence XXX ICK). */ zone->zone->rrsets[DNSType::SOA]={{serializeSOARecord({"ns1", "powerdns", "org"}, {"admin", "powerdns", "org"}, 1)}}; zone->zone->rrsets[DNSType::A]={{serializeARecord("1.2.3.4")}, 300}; zone->zone->rrsets[DNSType::AAAA]={{serializeAAAARecord("::1"), serializeAAAARecord("2001::1")}, 900}; zone->zone->rrsets[DNSType::NS]={{serializeDNSName({"ns1", "powerdns", "org"})}, 300}; zone->zone->add({"www"})->rrsets[DNSType::CNAME]={{serializeDNSName({"server1","powerdns","org"})}}; zone->zone->add({"server1"})->rrsets[DNSType::A]={{serializeARecord("213.244.168.210")}}; // zone->zone->add({"*"})->rrsets[(dnstype)DNSType::A]={"\x05\x06\x07\x08"}; zone->zone->add({"fra"})->rrsets[DNSType::NS]={{serializeDNSName({"ns1","fra","powerdns","org"})}}; zone->zone->add({"ns1", "fra"})->rrsets[DNSType::A]={{serializeARecord("12.13.14.15")}, 86400}; } /* Serves the built-in zone on the address given as argv[1], port 53: binds a UDP and a TCP socket, loads the zones, handles UDP on one thread and gives each accepted TCP connection its own detached thread. NOTE(review): argc is never checked, so starting without an address argument passes the null argv[1] to ComboAddress; SO_REUSEPORT is set on the TCP socket only; the accept loop never exits, so udpServer.join() is unreachable. */ int main(int argc, char** argv) { ComboAddress local(argv[1], 53); Socket udplistener(local.sin4.sin_family, SOCK_DGRAM); SBind(udplistener, local); Socket tcplistener(local.sin4.sin_family, SOCK_STREAM); SSetsockopt(tcplistener, SOL_SOCKET, SO_REUSEPORT, 1); SBind(tcplistener, local); SListen(tcplistener, 10); DNSNode zones; loadZones(zones); thread udpServer(udpThread, local, &udplistener, &zones); for(;;) { ComboAddress remote; int client = SAccept(tcplistener, remote); thread t(tcpClientThread, local, remote, client, &zones); t.detach(); } udpServer.join(); }