From 7a2f9f06b1fe5809e9a9646c9ab3e7a9d405154d Mon Sep 17 00:00:00 2001 From: Joona Hoikkala Date: Sun, 3 Feb 2019 17:23:04 +0200 Subject: [PATCH] Refactoring DNS server part for safer paraller execution (#144) * Refactoring DNS server part for safer paraller execution and better data structures * Fix linter issues --- dns.go | 202 +++++++++++++++++++++++++++++++++------------------ dns_test.go | 91 ++++++++++++----------- main.go | 34 +++------ main_test.go | 13 ++-- types.go | 9 --- 5 files changed, 197 insertions(+), 152 deletions(-) diff --git a/dns.go b/dns.go index 3531820..4e86d84 100644 --- a/dns.go +++ b/dns.go @@ -8,65 +8,80 @@ import ( "time" ) -func readQuery(m *dns.Msg) { - for _, que := range m.Question { - if rr, rc, err := answer(que); err == nil { - m.MsgHdr.Rcode = rc - for _, r := range rr { - m.Answer = append(m.Answer, r) - } - } - } +// Records is a slice of ResourceRecords +type Records struct { + Records []dns.RR } -func answerTXT(q dns.Question) ([]dns.RR, int, error) { - var ra []dns.RR - rcode := dns.RcodeNameError - subdomain := sanitizeDomainQuestion(q.Name) - atxt, err := DB.GetTXTForDomain(subdomain) +// DNSServer is the main struct for acme-dns DNS server +type DNSServer struct { + DB database + Server *dns.Server + Domains map[string]Records +} + +// NewDNSServer parses the DNS records from config and returns a new DNSServer struct +func NewDNSServer(db database, addr string, proto string) *DNSServer { + var server DNSServer + server.Server = &dns.Server{Addr: addr, Net: proto} + server.DB = db + server.Domains = make(map[string]Records) + return &server +} + +// Start starts the DNSServer +func (d *DNSServer) Start(errorChannel chan error) { + // DNS server part + dns.HandleFunc(".", d.handleRequest) + log.WithFields(log.Fields{"addr": d.Server.Addr, "proto": d.Server.Net}).Info("Listening DNS") + err := d.Server.ListenAndServe() if err != nil { - log.WithFields(log.Fields{"error": err.Error()}).Debug("Error while trying to get record") - return ra, dns.RcodeNameError, err + errorChannel <- err } - for _, v := range atxt { - if len(v) > 0 { - r := new(dns.TXT) - r.Hdr = dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1} - r.Txt = append(r.Txt, v) - ra = append(ra, r) - rcode = dns.RcodeSuccess - } - } - - log.WithFields(log.Fields{"domain": q.Name}).Info("Answering TXT question for domain") - return ra, rcode, nil } -func answer(q dns.Question) ([]dns.RR, int, error) { - if q.Qtype == dns.TypeTXT { - return answerTXT(q) +// ParseRecords parses a slice of DNS record string +func (d *DNSServer) ParseRecords(config DNSConfig) { + for _, v := range config.General.StaticRecords { + rr, err := dns.NewRR(strings.ToLower(v)) + if err != nil { + log.WithFields(log.Fields{"error": err.Error(), "rr": v}).Warning("Could not parse RR from config") + continue + } + // Add parsed RR + d.appendRR(rr) } - var r []dns.RR - var rcode = dns.RcodeSuccess - var domain = strings.ToLower(q.Name) - var rtype = q.Qtype - r, ok := RR.Records[rtype][domain] + // Create serial + serial := time.Now().Format("2006010215") + // Add SOA + SOAstring := fmt.Sprintf("%s. SOA %s. %s. %s 28800 7200 604800 86400", strings.ToLower(config.General.Domain), strings.ToLower(config.General.Nsname), strings.ToLower(config.General.Nsadmin), serial) + soarr, err := dns.NewRR(SOAstring) + if err != nil { + log.WithFields(log.Fields{"error": err.Error(), "soa": SOAstring}).Error("Error while adding SOA record") + } else { + d.appendRR(soarr) + } +} + +func (d *DNSServer) appendRR(rr dns.RR) { + addDomain := rr.Header().Name + _, ok := d.Domains[addDomain] if !ok { - r, ok = RR.Records[dns.TypeCNAME][domain] - if !ok { - rcode = dns.RcodeNameError - } + d.Domains[addDomain] = Records{[]dns.RR{rr}} + } else { + drecs := d.Domains[addDomain] + drecs.Records = append(drecs.Records, rr) + d.Domains[addDomain] = drecs } - log.WithFields(log.Fields{"qtype": dns.TypeToString[rtype], "domain": domain, "rcode": dns.RcodeToString[rcode]}).Debug("Answering question for domain") - return r, rcode, nil + log.WithFields(log.Fields{"recordtype": dns.TypeToString[rr.Header().Rrtype], "domain": addDomain}).Debug("Adding new record to domain") } -func handleRequest(w dns.ResponseWriter, r *dns.Msg) { +func (d *DNSServer) handleRequest(w dns.ResponseWriter, r *dns.Msg) { m := new(dns.Msg) m.SetReply(r) if r.Opcode == dns.OpcodeQuery { - readQuery(m) + d.readQuery(m) } else if r.Opcode == dns.OpcodeUpdate { log.Debug("Refusing DNS Dynamic update request") m.MsgHdr.Rcode = dns.RcodeRefused @@ -75,38 +90,81 @@ func handleRequest(w dns.ResponseWriter, r *dns.Msg) { w.WriteMsg(m) } -// Parse config records -func (r *Records) Parse(config general) { - rrmap := make(map[uint16]map[string][]dns.RR) - for _, v := range config.StaticRecords { - rr, err := dns.NewRR(strings.ToLower(v)) - if err != nil { - log.WithFields(log.Fields{"error": err.Error(), "rr": v}).Warning("Could not parse RR from config") - continue +func (d *DNSServer) readQuery(m *dns.Msg) { + for _, que := range m.Question { + if rr, rc, err := d.answer(que); err == nil { + m.MsgHdr.Rcode = rc + for _, r := range rr { + m.Answer = append(m.Answer, r) + } } - // Add parsed RR to the list - rrmap = appendRR(rrmap, rr) } - // Create serial - serial := time.Now().Format("2006010215") - // Add SOA - SOAstring := fmt.Sprintf("%s. SOA %s. %s. %s 28800 7200 604800 86400", strings.ToLower(config.Domain), strings.ToLower(config.Nsname), strings.ToLower(config.Nsadmin), serial) - soarr, err := dns.NewRR(SOAstring) - if err != nil { - log.WithFields(log.Fields{"error": err.Error(), "soa": SOAstring}).Error("Error while adding SOA record") - } else { - rrmap = appendRR(rrmap, soarr) - } - r.Records = rrmap } -func appendRR(rrmap map[uint16]map[string][]dns.RR, rr dns.RR) map[uint16]map[string][]dns.RR { - _, ok := rrmap[rr.Header().Rrtype] +func (d *DNSServer) getRecord(q dns.Question) ([]dns.RR, error) { + var rr []dns.RR + var cnames []dns.RR + domain, ok := d.Domains[q.Name] if !ok { - newrr := make(map[string][]dns.RR) - rrmap[rr.Header().Rrtype] = newrr + return rr, fmt.Errorf("No records for domain %s", q.Name) } - rrmap[rr.Header().Rrtype][rr.Header().Name] = append(rrmap[rr.Header().Rrtype][rr.Header().Name], rr) - log.WithFields(log.Fields{"recordtype": dns.TypeToString[rr.Header().Rrtype], "domain": rr.Header().Name}).Debug("Adding new record type to domain") - return rrmap + for _, ri := range domain.Records { + if ri.Header().Rrtype == q.Qtype { + rr = append(rr, ri) + } + if ri.Header().Rrtype == dns.TypeCNAME { + cnames = append(cnames, ri) + } + } + if len(rr) == 0 { + return cnames, nil + } + return rr, nil +} + +// answeringForDomain checks if we have any records for a domain +func (d *DNSServer) answeringForDomain(q dns.Question) bool { + _, ok := d.Domains[q.Name] + return ok +} + +func (d *DNSServer) answer(q dns.Question) ([]dns.RR, int, error) { + var rcode int + if !d.answeringForDomain(q) { + rcode = dns.RcodeNameError + } + r, _ := d.getRecord(q) + if q.Qtype == dns.TypeTXT { + txtRRs, err := d.answerTXT(q) + if err == nil { + for _, txtRR := range txtRRs { + r = append(r, txtRR) + } + } + } + if len(r) > 0 { + // Make sure that we return NOERROR if there were dynamic records for the domain + rcode = dns.RcodeSuccess + } + log.WithFields(log.Fields{"qtype": dns.TypeToString[q.Qtype], "domain": q.Name, "rcode": dns.RcodeToString[rcode]}).Debug("Answering question for domain") + return r, rcode, nil +} + +func (d *DNSServer) answerTXT(q dns.Question) ([]dns.RR, error) { + var ra []dns.RR + subdomain := sanitizeDomainQuestion(q.Name) + atxt, err := d.DB.GetTXTForDomain(subdomain) + if err != nil { + log.WithFields(log.Fields{"error": err.Error()}).Debug("Error while trying to get record") + return ra, err + } + for _, v := range atxt { + if len(v) > 0 { + r := new(dns.TXT) + r.Hdr = dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1} + r.Txt = append(r.Txt, v) + ra = append(ra, r) + } + } + return ra, nil } diff --git a/dns_test.go b/dns_test.go index 63c2701..7bd3da1 100644 --- a/dns_test.go +++ b/dns_test.go @@ -5,10 +5,10 @@ import ( "database/sql/driver" "errors" "fmt" + "testing" + "github.com/erikstmartin/go-testdb" "github.com/miekg/dns" - "strings" - "testing" ) var resolv resolver @@ -51,25 +51,6 @@ func hasExpectedTXTAnswer(answer []dns.RR, cmpTXT string) error { return errors.New("Expected answer not found") } -func findRecordFromMemory(rrstr string, host string, qtype uint16) error { - var errmsg = "No record found" - arr, _ := dns.NewRR(strings.ToLower(rrstr)) - if arrQt, ok := RR.Records[qtype]; ok { - if arrHst, ok := arrQt[host]; ok { - for _, v := range arrHst { - if arr.String() == v.String() { - return nil - } - } - } else { - errmsg = "No records for domain" - } - } else { - errmsg = "No records for this type in DB" - } - return errors.New(errmsg) -} - func TestQuestionDBError(t *testing.T) { testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) { columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"} @@ -88,44 +69,36 @@ func TestQuestionDBError(t *testing.T) { defer DB.SetBackend(oldDb) q := dns.Question{Name: dns.Fqdn("whatever.tld"), Qtype: dns.TypeTXT, Qclass: dns.ClassINET} - _, rcode, err := answerTXT(q) + _, err = dnsserver.answerTXT(q) if err == nil { t.Errorf("Expected error but got none") } - if rcode != dns.RcodeNameError { - t.Errorf("Expected [%s] rcode, but got [%s]", dns.RcodeToString[dns.RcodeNameError], dns.RcodeToString[rcode]) - } } func TestParse(t *testing.T) { - var testcfg = general{ - Domain: ")", - Nsname: "ns1.auth.example.org", - Nsadmin: "admin.example.org", - StaticRecords: []string{}, - Debug: false, + var testcfg = DNSConfig{ + General: general{ + Domain: ")", + Nsname: "ns1.auth.example.org", + Nsadmin: "admin.example.org", + StaticRecords: []string{}, + Debug: false, + }, } - var testRR Records - testRR.Parse(testcfg) + dnsserver.ParseRecords(testcfg) if !loggerHasEntryWithMessage("Error while adding SOA record") { t.Errorf("Expected SOA parsing to return error, but did not find one") } } func TestResolveA(t *testing.T) { - resolv := resolver{server: "0.0.0.0:15353"} + resolv := resolver{server: "127.0.0.1:15353"} answer, err := resolv.lookup("auth.example.org", dns.TypeA) if err != nil { t.Errorf("%v", err) } - if len(answer) > 0 { - err = findRecordFromMemory(answer[0].String(), "auth.example.org.", dns.TypeA) - if err != nil { - t.Errorf("Answer [%s] did not match the expected, got error: [%s], debug: [%q]", answer[0].String(), err, RR.Records) - } - - } else { + if len(answer) == 0 { t.Error("No answer for DNS query") } @@ -135,8 +108,42 @@ func TestResolveA(t *testing.T) { } } +func TestOpcodeUpdate(t *testing.T) { + msg := new(dns.Msg) + msg.Id = dns.Id() + msg.Question = make([]dns.Question, 1) + msg.Question[0] = dns.Question{Name: dns.Fqdn("auth.example.org"), Qtype: dns.TypeANY, Qclass: dns.ClassINET} + msg.MsgHdr.Opcode = dns.OpcodeUpdate + in, err := dns.Exchange(msg, "127.0.0.1:15353") + if err != nil || in == nil { + t.Errorf("Encountered an error with UPDATE request") + } else if err == nil { + if in.Rcode != dns.RcodeRefused { + t.Errorf("Expected RCODE Refused from UPDATE request, but got [%s] instead", dns.RcodeToString[in.Rcode]) + } + } +} + +func TestResolveCNAME(t *testing.T) { + resolv := resolver{server: "127.0.0.1:15353"} + expected := "cn.example.org. 3600 IN CNAME something.example.org." + answer, err := resolv.lookup("cn.example.org", dns.TypeCNAME) + if err != nil { + t.Errorf("Got unexpected error: %s", err) + } + if len(answer) != 1 { + t.Errorf("Expected exactly 1 RR in answer, but got %d instead.", len(answer)) + } + if answer[0].Header().Rrtype != dns.TypeCNAME { + t.Errorf("Expected a CNAME answer, but got [%s] instead.", dns.TypeToString[answer[0].Header().Rrtype]) + } + if answer[0].String() != expected { + t.Errorf("Expected CNAME answer [%s] but got [%s] instead.", expected, answer[0].String()) + } +} + func TestResolveTXT(t *testing.T) { - resolv := resolver{server: "0.0.0.0:15353"} + resolv := resolver{server: "127.0.0.1:15353"} validTXT := "______________valid_response_______________" atxt, err := DB.Register(cidrslice{}) diff --git a/main.go b/main.go index 1557c8e..80a75ff 100644 --- a/main.go +++ b/main.go @@ -12,7 +12,6 @@ import ( "syscall" "github.com/julienschmidt/httprouter" - "github.com/miekg/dns" "github.com/rs/cors" log "github.com/sirupsen/logrus" "golang.org/x/crypto/acme/autocert" @@ -42,9 +41,6 @@ func main() { setupLogging(Config.Logconfig.Format, Config.Logconfig.Level) - // Read the default records in - RR.Parse(Config.General) - // Open database newDB := new(acmedb) err = newDB.Init(Config.Database.Engine, Config.Database.Connection) @@ -72,13 +68,17 @@ func main() { udpProto += "6" tcpProto += "6" } - dnsServerUDP := setupDNSServer(udpProto) - dnsServerTCP := setupDNSServer(tcpProto) - go startDNS(dnsServerUDP, errChan) - go startDNS(dnsServerTCP, errChan) + dnsServerUDP := NewDNSServer(DB, Config.General.Listen, udpProto) + dnsServerUDP.ParseRecords(Config) + dnsServerTCP := NewDNSServer(DB, Config.General.Listen, tcpProto) + // No need to parse records from config again + dnsServerTCP.Domains = dnsServerUDP.Domains + go dnsServerUDP.Start(errChan) + go dnsServerTCP.Start(errChan) } else { - dnsServer := setupDNSServer(Config.General.Proto) - go startDNS(dnsServer, errChan) + dnsServer := NewDNSServer(DB, Config.General.Listen, Config.General.Proto) + dnsServer.ParseRecords(Config) + go dnsServer.Start(errChan) } // HTTP API @@ -94,20 +94,6 @@ func main() { log.Debugf("Shutting down...") } -func startDNS(server *dns.Server, errChan chan error) { - // DNS server part - dns.HandleFunc(".", handleRequest) - log.WithFields(log.Fields{"addr": Config.General.Listen, "proto": server.Net}).Info("Listening DNS") - err := server.ListenAndServe() - if err != nil { - errChan <- err - } -} - -func setupDNSServer(proto string) *dns.Server { - return &dns.Server{Addr: Config.General.Listen, Net: proto} -} - func startHTTPAPI(errChan chan error) { // Setup http logger logger := log.New() diff --git a/main_test.go b/main_test.go index 3ec98fa..1faaca6 100644 --- a/main_test.go +++ b/main_test.go @@ -12,6 +12,7 @@ import ( ) var loghook = new(logrustest.Hook) +var dnsserver *DNSServer var ( postgres = flag.Bool("postgres", false, "run integration tests against PostgreSQL") @@ -20,6 +21,7 @@ var ( var records = []string{ "auth.example.org. A 192.168.1.100", "ns1.auth.example.org. A 192.168.1.101", + "cn.example.org CNAME something.example.org.", "!''b', unparseable ", "ns2.auth.example.org. A 192.168.1.102", } @@ -27,7 +29,6 @@ var records = []string{ func TestMain(m *testing.M) { setupTestLogger() setupConfig() - RR.Parse(Config.General) flag.Parse() newDb := new(acmedb) @@ -43,17 +44,19 @@ func TestMain(m *testing.M) { _ = newDb.Init("sqlite3", ":memory:") } DB = newDb - server := setupDNSServer("udp") + dnsserver = NewDNSServer(DB, Config.General.Listen, Config.General.Proto) + dnsserver.ParseRecords(Config) + // Make sure that we're not creating a race condition in tests var wg sync.WaitGroup wg.Add(1) - server.NotifyStartedFunc = func() { + dnsserver.Server.NotifyStartedFunc = func() { wg.Done() } - go startDNS(server, make(chan error, 1)) + go dnsserver.Start(make(chan error, 1)) wg.Wait() exitval := m.Run() - server.Shutdown() + dnsserver.Server.Shutdown() DB.Close() os.Exit(exitval) } diff --git a/types.go b/types.go index a615a59..f1efa9c 100644 --- a/types.go +++ b/types.go @@ -5,7 +5,6 @@ import ( "sync" "github.com/google/uuid" - "github.com/miekg/dns" ) // Config is global configuration struct @@ -14,14 +13,6 @@ var Config DNSConfig // DB is used to access the database functions in acme-dns var DB database -// RR holds the static DNS records -var RR Records - -// Records is for static records -type Records struct { - Records map[uint16]map[string][]dns.RR -} - // DNSConfig holds the config structure type DNSConfig struct { General general