diff --git a/README.md b/README.md index f996a9a..306a889 100644 --- a/README.md +++ b/README.md @@ -326,10 +326,12 @@ logformat = "text" ## Changelog - master - - Unreleased - - Added new endpoint to perform health checks + - Added + - New endpoint to perform health checks - Changed - A new protocol selection for DNS server "both", that binds both - UDP and TCP ports. + - Refactored DNS server internals. + - Handle some aspects of DNS spec better. - v0.6 - New - Command line flag `-c` to specify location of config file. diff --git a/dns.go b/dns.go index 4e86d84..b703911 100644 --- a/dns.go +++ b/dns.go @@ -17,6 +17,7 @@ type Records struct { type DNSServer struct { DB database Server *dns.Server + SOA dns.RR Domains map[string]Records } @@ -60,6 +61,7 @@ func (d *DNSServer) ParseRecords(config DNSConfig) { log.WithFields(log.Fields{"error": err.Error(), "soa": SOAstring}).Error("Error while adding SOA record") } else { d.appendRR(soarr) + d.SOA = soarr } } @@ -91,14 +93,25 @@ func (d *DNSServer) handleRequest(w dns.ResponseWriter, r *dns.Msg) { } func (d *DNSServer) readQuery(m *dns.Msg) { + var authoritative = false for _, que := range m.Question { - if rr, rc, err := d.answer(que); err == nil { + if rr, rc, auth, err := d.answer(que); err == nil { + if auth { + authoritative = auth + } m.MsgHdr.Rcode = rc for _, r := range rr { m.Answer = append(m.Answer, r) } } } + m.MsgHdr.Authoritative = authoritative + if authoritative { + if m.MsgHdr.Rcode == dns.RcodeNameError { + m.Answer = append(m.Answer, d.SOA) + } + } + } func (d *DNSServer) getRecord(q dns.Question) ([]dns.RR, error) { @@ -123,14 +136,28 @@ func (d *DNSServer) getRecord(q dns.Question) ([]dns.RR, error) { } // answeringForDomain checks if we have any records for a domain -func (d *DNSServer) answeringForDomain(q dns.Question) bool { - _, ok := d.Domains[q.Name] +func (d *DNSServer) answeringForDomain(name string) bool { + _, ok := d.Domains[name] return ok } -func (d *DNSServer) answer(q dns.Question) ([]dns.RR, int, error) { +func (d *DNSServer) isAuthoritative(q dns.Question) bool { + if d.answeringForDomain(q.Name) { + return true + } + domainParts := strings.Split(q.Name, ".") + for i := range domainParts { + if d.answeringForDomain(strings.Join(domainParts[i:], ".")) { + return true + } + } + return false +} + +func (d *DNSServer) answer(q dns.Question) ([]dns.RR, int, bool, error) { var rcode int - if !d.answeringForDomain(q) { + var authoritative = d.isAuthoritative(q) + if !d.answeringForDomain(q.Name) { rcode = dns.RcodeNameError } r, _ := d.getRecord(q) @@ -146,8 +173,12 @@ func (d *DNSServer) answer(q dns.Question) ([]dns.RR, int, error) { // Make sure that we return NOERROR if there were dynamic records for the domain rcode = dns.RcodeSuccess } + // Handle EDNS (no support at the moment) + if q.Qtype == dns.TypeOPT { + return []dns.RR{}, dns.RcodeFormatError, authoritative, nil + } log.WithFields(log.Fields{"qtype": dns.TypeToString[q.Qtype], "domain": q.Name, "rcode": dns.RcodeToString[rcode]}).Debug("Answering question for domain") - return r, rcode, nil + return r, rcode, authoritative, nil } func (d *DNSServer) answerTXT(q dns.Question) ([]dns.RR, error) { diff --git a/dns_test.go b/dns_test.go index 7bd3da1..9c9138a 100644 --- a/dns_test.go +++ b/dns_test.go @@ -18,20 +18,20 @@ type resolver struct { server string } -func (r *resolver) lookup(host string, qtype uint16) ([]dns.RR, error) { +func (r *resolver) lookup(host string, qtype uint16) (*dns.Msg, error) { msg := new(dns.Msg) msg.Id = dns.Id() msg.Question = make([]dns.Question, 1) msg.Question[0] = dns.Question{Name: dns.Fqdn(host), Qtype: qtype, Qclass: dns.ClassINET} in, err := dns.Exchange(msg, r.server) if err != nil { - return []dns.RR{}, fmt.Errorf("Error querying the server [%v]", err) + return in, fmt.Errorf("Error querying the server [%v]", err) } if in != nil && in.Rcode != dns.RcodeSuccess { - return []dns.RR{}, fmt.Errorf("Received error from the server [%s]", dns.RcodeToString[in.Rcode]) + return in, fmt.Errorf("Received error from the server [%s]", dns.RcodeToString[in.Rcode]) } - return in.Answer, nil + return in, nil } func hasExpectedTXTAnswer(answer []dns.RR, cmpTXT string) error { @@ -98,7 +98,7 @@ func TestResolveA(t *testing.T) { t.Errorf("%v", err) } - if len(answer) == 0 { + if len(answer.Answer) == 0 { t.Error("No answer for DNS query") } @@ -108,6 +108,14 @@ func TestResolveA(t *testing.T) { } } +func TestEDNS(t *testing.T) { + resolv := resolver{server: "127.0.0.1:15353"} + answer, _ := resolv.lookup("auth.example.org", dns.TypeOPT) + if answer.Rcode != dns.RcodeFormatError { + t.Errorf("Was expecing FORMERR rcode for OPT query, but got [%s] instead.", dns.RcodeToString[answer.Rcode]) + } +} + func TestOpcodeUpdate(t *testing.T) { msg := new(dns.Msg) msg.Id = dns.Id() @@ -131,14 +139,38 @@ func TestResolveCNAME(t *testing.T) { if err != nil { t.Errorf("Got unexpected error: %s", err) } - if len(answer) != 1 { - t.Errorf("Expected exactly 1 RR in answer, but got %d instead.", len(answer)) + if len(answer.Answer) != 1 { + t.Errorf("Expected exactly 1 RR in answer, but got %d instead.", len(answer.Answer)) } - if answer[0].Header().Rrtype != dns.TypeCNAME { - t.Errorf("Expected a CNAME answer, but got [%s] instead.", dns.TypeToString[answer[0].Header().Rrtype]) + if answer.Answer[0].Header().Rrtype != dns.TypeCNAME { + t.Errorf("Expected a CNAME answer, but got [%s] instead.", dns.TypeToString[answer.Answer[0].Header().Rrtype]) } - if answer[0].String() != expected { - t.Errorf("Expected CNAME answer [%s] but got [%s] instead.", expected, answer[0].String()) + if answer.Answer[0].String() != expected { + t.Errorf("Expected CNAME answer [%s] but got [%s] instead.", expected, answer.Answer[0].String()) + } +} + +func TestAuthoritative(t *testing.T) { + resolv := resolver{server: "127.0.0.1:15353"} + answer, _ := resolv.lookup("nonexistent.auth.example.org", dns.TypeA) + if answer.Rcode != dns.RcodeNameError { + t.Errorf("Was expecing NXDOMAIN rcode, but got [%s] instead.", dns.RcodeToString[answer.Rcode]) + } + if len(answer.Answer) != 1 { + t.Errorf("Was expecting exactly one answer (SOA) for invalid subdomain, but got %d", len(answer.Answer)) + } + if answer.Answer[0].Header().Rrtype != dns.TypeSOA { + t.Errorf("Was expecting SOA record as answer for NXDOMAIN but got [%s]", dns.TypeToString[answer.Answer[0].Header().Rrtype]) + } + if !answer.MsgHdr.Authoritative { + t.Errorf("Was expecting authoritative bit to be set") + } + nanswer, _ := resolv.lookup("nonexsitent.nonauth.tld", dns.TypeA) + if len(nanswer.Answer) > 0 { + t.Errorf("Didn't expect answers for non authotitative domain query") + } + if nanswer.MsgHdr.Authoritative { + t.Errorf("Authoritative bit should not be set for non-authoritative domain.") } } @@ -179,18 +211,20 @@ func TestResolveTXT(t *testing.T) { } } - if len(answer) > 0 { - if !test.getAnswer { + if len(answer.Answer) > 0 { + if !test.getAnswer && answer.Answer[0].Header().Rrtype != dns.TypeSOA { t.Errorf("Test %d: Expected no answer, but got: [%q]", i, answer) } - err = hasExpectedTXTAnswer(answer, test.expTXT) - if err != nil { - if test.validAnswer { - t.Errorf("Test %d: %v", i, err) - } - } else { - if !test.validAnswer { - t.Errorf("Test %d: Answer was not expected to be valid, answer [%q], compared to [%s]", i, answer, test.expTXT) + if test.getAnswer { + err = hasExpectedTXTAnswer(answer.Answer, test.expTXT) + if err != nil { + if test.validAnswer { + t.Errorf("Test %d: %v", i, err) + } + } else { + if !test.validAnswer { + t.Errorf("Test %d: Answer was not expected to be valid, answer [%q], compared to [%s]", i, answer, test.expTXT) + } } } } else {