/* Goal: a fully standards compliant basic authoritative server. In <500 lines. Non-goals: notifications, slaving zones, name compression, edns, performance */ #include #include #include #include #include #include #include "sclasses.hh" #include "dns.hh" #include "safearray.hh" #include #include "dns-storage.hh" using namespace std; struct DNSMessage { struct dnsheader dh=dnsheader{}; SafeArray<500> payload; dnsname getName(); void getQuestion(dnsname& name, DNSType& type); void setQuestion(const dnsname& name, DNSType type); void putRR(DNSSection section, const dnsname& name, DNSType type, uint32_t ttl, const std::string& rr); string serialize() const; }; dnsname DNSMessage::getName() { dnsname name; for(;;) { uint8_t labellen=payload.getUInt8(); if(labellen > 63) throw std::runtime_error("Got a compressed label"); if(!labellen) // end of dnsname break; dnslabel label = payload.getBlob(labellen); name.push_back(label); } return name; } void DNSMessage::getQuestion(dnsname& name, DNSType& type) { name=getName(); type=(DNSType)payload.getUInt16(); } void putName(auto& payload, const dnsname& name) { for(const auto& l : name) { if(l.size() > 63) throw std::runtime_error("Can't emit a label larger than 63 characters"); payload.putUInt8(l.size()); payload.putBlob(l); } payload.putUInt8(0); } void DNSMessage::putRR(DNSSection section, const dnsname& name, DNSType type, uint32_t ttl, const std::string& content) { auto cursize = payload.payloadpos; try { putName(payload, name); payload.putUInt16((int)type); payload.putUInt16(1); payload.putUInt32(ttl); payload.putUInt16(content.size()); // check for overflow! payload.putBlob(content); } catch(...) 
{ payload.payloadpos = cursize; throw; } switch(section) { case DNSSection::Question: throw runtime_error("Can't add questions to a DNS Message with putRR"); case DNSSection::Answer: dh.ancount = htons(ntohs(dh.ancount) + 1); break; case DNSSection::Authority: dh.nscount = htons(ntohs(dh.nscount) + 1); break; case DNSSection::Additional: dh.arcount = htons(ntohs(dh.arcount) + 1); break; } } void DNSMessage::setQuestion(const dnsname& name, DNSType type) { putName(payload, name); payload.putUInt16((uint16_t)type); payload.putUInt16(1); // class } string DNSMessage::serialize() const { return string((const char*)this, (const char*)this + sizeof(dnsheader) + payload.payloadpos); } static_assert(sizeof(DNSMessage) == 516, "dnsmessage size must be 516"); std::string serializeDNSName(const dnsname& dn) { std::string ret; for(const auto & l : dn) { ret.append(1, l.size()); ret+=l; } ret.append(1, (char)0); return ret; } std::string serializeMXRecord(uint16_t prio, const dnsname& mname) { SafeArray<256> sa; sa.putUInt16(prio); putName(sa, mname); return sa.serialize(); } std::string serializeSOARecord(const dnsname& mname, const dnsname& rname, uint32_t serial, uint32_t minimum=3600, uint32_t refresh=10800, uint32_t retry=3600, uint32_t expire=604800) { SafeArray<256> sa; putName(sa, mname); putName(sa, rname); sa.putUInt32(serial); sa.putUInt32(refresh); sa.putUInt32(retry); sa.putUInt32(expire); sa.putUInt32(minimum); return sa.serialize(); } std::string serializeARecord(const std::string& src) { ComboAddress ca(src); if(ca.sin4.sin_family != AF_INET) throw std::runtime_error("Could not convert '"+src+"' to an IPv4 address"); auto p = (const char*)&ca.sin4.sin_addr.s_addr; return std::string(p, p+4); } std::string serializeAAAARecord(const std::string& src) { ComboAddress ca(src); if(ca.sin4.sin_family != AF_INET6) throw std::runtime_error("Could not convert '"+src+"' to an IPv6 address"); auto p = (const char*)ca.sin6.sin6_addr.s6_addr; return std::string(p, p+16); } 
/*
 * NOTE(review): in this copy the text below has lost its line breaks, and
 * several `<...>` spans (mostly inside `cout<<...` logging statements) appear
 * to have been stripped. Parts of the query handling, the UDP loop, the TCP
 * framing helpers and the start of the TCP client handler are therefore not
 * readable here. The notes below describe only what is visible.
 *
 * processQuestion(zones, dm, local, remote, response)
 *   Written as a function-try-block. Reads the question from `dm` with
 *   getQuestion(), selects a best-matching zone and calls
 *   bestzone->find(searchname, lastnode, &passedZonecut). If a zone cut was
 *   passed, the AA bit of the response is cleared.
 *   - Not-found branch (partially visible), when an NS rrset exists and a
 *     zone cut was passed: treated as a delegation, and the NS records are
 *     added with DNSSection::Answer. Additional/glue processing is still a
 *     TODO per the inline comment.
 *     NOTE(review): RFC 1034 4.3.2 puts referral NS records in the Authority
 *     section, not Answer — confirm intent.
 *   - Not-found branch otherwise: rcode is set to NXDOMAIN, and the zone SOA
 *     (contents[0]) is added to the Authority section.
 *   - Found branch: a CNAME rrset is copied into the Answer section without
 *     being followed (see inline comment). In a branch whose condition is not
 *     visible, the zone SOA is added to the Answer section.
 *   - A "No zone matched" branch exists; its handling is not visible.
 *   The TCP path below writes the response only when this returns true.
 *   NOTE(review): `...->rrsets[DNSType::SOA]` uses operator[]. If rrsets is a
 *   std::map, a zone without an SOA would get an empty rrset inserted, and
 *   `contents[0]` would then be out of range. Because `zones` is shared by
 *   pointer between the UDP thread and detached TCP threads (see main), such
 *   an insertion would also be an unsynchronized write — confirm.
 *
 * UDP side: a comparison against 512 followed by a `cerr<<"Remote "...`
 *   diagnostic is visible; the surrounding code is not.
 *
 * TCP client handling
 *   Zone-transfer path (its triggering condition is not visible; presumably an
 *   AXFR query): calls find(name, zone) and logs " This was not a zone" unless
 *   a zone node is found and `name` is empty afterwards. This presumably means
 *   the query was for the zone apex — confirm find() semantics. Otherwise the
 *   zone is sent as a series of TCP messages:
 *   - first the SOA;
 *   - then every non-SOA record reached via node->visit(). When putRR throws,
 *     the current message is written, the counts are zeroed, the payload is
 *     rewound, the question is re-added, and the same record is retried;
 *   - finally the SOA again (partially visible), followed by a last write.
 *   NOTE(review): `catch(...) { ...; goto retry; }` retries indefinitely. A
 *   record that cannot fit even in an empty message, or any other exception
 *   from putRR (e.g. a label over 63 bytes), makes this loop forever.
 *   NOTE(review): on each flush, ancount/nscount/arcount are reset but qdcount
 *   is not; presumably it keeps the query's value (1) — confirm.
 *   Other queries: the query payload is rewound and processQuestion() is
 *   reused; the handler returns if it yields false.
 *
 * loadZones(zones)
 *   Populates `zones` with one hard-coded zone, powerdns.org:
 *   - at the apex: SOA, MX, A, AAAA and NS records;
 *   - www: a CNAME to server1.powerdns.org;
 *   - server1: A and AAAA records;
 *   - fra: a delegation to ns1.fra.powerdns.org, with A records for ns1.fra
 *     and NS2.fra.
 *   A wildcard A record is commented out.
 *   NOTE(review): the zone node is allocated with raw `new` ("XXX ICK");
 *   whether DNSNode ever frees it is not visible here.
 *
 * main(argc, argv)
 *   Binds a UDP socket and a TCP listener (SO_REUSEPORT, backlog 10) to the
 *   address in argv[1], port 53, and loads the zones. It then runs the UDP
 *   server on its own thread and spawns one detached thread per accepted TCP
 *   connection, passing each a pointer to the shared `zones`.
 *   NOTE(review): argv[1] is used without checking argc. The for(;;) accept
 *   loop never exits, so udpServer.join() is unreachable.
 */
bool processQuestion(const DNSNode& zones, DNSMessage& dm, const ComboAddress& local, const ComboAddress& remote, DNSMessage& response) try { dnsname name; DNSType type; dm.getQuestion(name, type); cout<<"Received a query from "<zone) { cout<<"---\nBest zone: "<zone<zone; dnsname searchname(name), lastnode; bool passedZonecut=false; auto node = bestzone->find(searchname, lastnode, &passedZonecut); if(passedZonecut) response.dh.aa = false; if(!node) { cout<<"Found nothing in zone '"<rrsets) { cout<<" Have type "<rrsets.find(DNSType::NS); if(iter != node->rrsets.end() && passedZonecut) { cout<<"Have delegation"<second; for(const auto& rr : rrset.contents) { response.putRR(DNSSection::Answer, lastnode+zone, DNSType::NS, rrset.ttl, rr); } // should do additional processing here } else { cout<<"This is an NXDOMAIN situation"<zone->rrsets[DNSType::SOA]; response.dh.rcode = (int)RCode::Nxdomain; response.putRR(DNSSection::Authority, zone, DNSType::SOA, rrset.ttl, rrset.contents[0]); } } else { cout<<"Found something in zone '"<second; for(const auto& rr : rrset.contents) { response.putRR(DNSSection::Answer, lastnode+zone, DNSType::CNAME, rrset.ttl, rr); } cout<<" We should actually follow this, at least within our zone"<zone->rrsets[DNSType::SOA]; response.putRR(DNSSection::Answer, zone, DNSType::SOA, rrset.ttl, rrset.contents[0]); } } } else { cout<<"No zone matched"< 512) { cerr<<"Remote "<find(name, zone); if(!fnd || !fnd->zone || !name.empty()) { cout<<" This was not a zone"<zone; // send SOA response.putRR(DNSSection::Answer, zone, DNSType::SOA, node->rrsets[DNSType::SOA].ttl, node->rrsets[DNSType::SOA].contents[0]); writeTCPResponse(sock, response); response.dh.ancount = response.dh.arcount = response.dh.nscount = 0; response.payload.rewind(); response.setQuestion(zone, type); // send all other records node->visit([&response,&sock,&name,&type,&zone](const dnsname& nname, const DNSNode* n) { cout<rrsets) { if(p.first == DNSType::SOA) continue; for(const auto& rr : 
p.second.contents) { retry: try { response.putRR(DNSSection::Answer, nname, p.first, p.second.ttl, rr); } catch(...) { writeTCPResponse(sock, response); response.dh.ancount = response.dh.arcount = response.dh.nscount = 0; response.payload.rewind(); response.setQuestion(zone, type); goto retry; } } cout<rrsets[DNSType::SOA].ttl, node->rrsets[DNSType::SOA].contents[0]); writeTCPResponse(sock, response); return; } else { dm.payload.rewind(); if(processQuestion(*zones, dm, local, remote, response)) { writeTCPResponse(sock, response); } else return; } } } void loadZones(DNSNode& zones) { auto zone = zones.add({"powerdns", "org"}); zone->zone = new DNSNode(); // XXX ICK zone->zone->rrsets[DNSType::SOA]={{serializeSOARecord({"ns1", "powerdns", "org"}, {"admin", "powerdns", "org"}, 1)}}; zone->zone->rrsets[DNSType::MX]={{serializeMXRecord(25, {"server1", "powerdns", "org"})}}; zone->zone->rrsets[DNSType::A]={{serializeARecord("1.2.3.4")}, 300}; zone->zone->rrsets[DNSType::AAAA]={{serializeAAAARecord("::1"), serializeAAAARecord("2001::1"), serializeAAAARecord("2a02:a440:b085:1:beee:7bff:fe89:f0fb")}, 900}; zone->zone->rrsets[DNSType::NS]={{serializeDNSName({"ns1", "powerdns", "org"})}, 300}; zone->zone->add({"www"})->rrsets[DNSType::CNAME]={{serializeDNSName({"server1","powerdns","org"})}}; zone->zone->add({"server1"})->rrsets[DNSType::A]={{serializeARecord("213.244.168.210")}}; zone->zone->add({"server1"})->rrsets[DNSType::AAAA]={{serializeAAAARecord("::1")}}; // zone->zone->add({"*"})->rrsets[(dnstype)DNSType::A]={"\x05\x06\x07\x08"}; zone->zone->add({"fra"})->rrsets[DNSType::NS]={{serializeDNSName({"ns1","fra","powerdns","org"})}}; zone->zone->add({"ns1", "fra"})->rrsets[DNSType::A]={{serializeARecord("12.13.14.15")}, 86400}; zone->zone->add({"NS2", "fra"})->rrsets[DNSType::A]={{serializeARecord("12.13.14.16")}, 86400}; } int main(int argc, char** argv) { ComboAddress local(argv[1], 53); Socket udplistener(local.sin4.sin_family, SOCK_DGRAM); SBind(udplistener, local); 
Socket tcplistener(local.sin4.sin_family, SOCK_STREAM); SSetsockopt(tcplistener, SOL_SOCKET, SO_REUSEPORT, 1); SBind(tcplistener, local); SListen(tcplistener, 10); DNSNode zones; loadZones(zones); thread udpServer(udpThread, local, &udplistener, &zones); for(;;) { ComboAddress remote; int client = SAccept(tcplistener, remote); thread t(tcpClientThread, local, remote, client, &zones); t.detach(); } udpServer.join(); }