From 550b23778f597e1cd2c2c26a5cec72f211511939 Mon Sep 17 00:00:00 2001 From: Joona Hoikkala Date: Mon, 28 Nov 2016 15:39:52 +0200 Subject: [PATCH] Better coverage and refactored static record parsing --- dns.go | 6 +++--- dns_test.go | 49 ++++++++++++++++++++++++++++++++++++++++++------- main.go | 2 +- main_test.go | 18 +++++++++++++----- 4 files changed, 59 insertions(+), 16 deletions(-) diff --git a/dns.go b/dns.go index 7c7656b..6320af3 100644 --- a/dns.go +++ b/dns.go @@ -70,9 +70,9 @@ func handleRequest(w dns.ResponseWriter, r *dns.Msg) { } // Parse config records -func (r *Records) Parse(recs []string) { +func (r *Records) Parse(config general) { rrmap := make(map[uint16]map[string][]dns.RR) - for _, v := range recs { + for _, v := range config.StaticRecords { rr, err := dns.NewRR(strings.ToLower(v)) if err != nil { log.WithFields(log.Fields{"error": err.Error(), "rr": v}).Warning("Could not parse RR from config") @@ -84,7 +84,7 @@ func (r *Records) Parse(recs []string) { // Create serial serial := time.Now().Format("2006010215") // Add SOA - SOAstring := fmt.Sprintf("%s. SOA %s. %s. %s 28800 7200 604800 86400", strings.ToLower(DNSConf.General.Domain), strings.ToLower(DNSConf.General.Nsname), strings.ToLower(DNSConf.General.Nsadmin), serial) + SOAstring := fmt.Sprintf("%s. SOA %s. %s. %s 28800 7200 604800 86400", strings.ToLower(config.Domain), strings.ToLower(config.Nsname), strings.ToLower(config.Nsadmin), serial) soarr, err := dns.NewRR(SOAstring) if err != nil { log.WithFields(log.Fields{"error": err.Error(), "soa": SOAstring}).Warning("Error while adding SOA record") diff --git a/dns_test.go b/dns_test.go index 3349f99..6c24c17 100644 --- a/dns_test.go +++ b/dns_test.go @@ -1,8 +1,11 @@ package main import ( + "database/sql" + "database/sql/driver" "errors" "fmt" + "github.com/erikstmartin/go-testdb" "github.com/miekg/dns" "strings" "testing" @@ -11,13 +14,6 @@ import ( var resolv resolver var server *dns.Server -var records = []string{ - "auth.example.org. A 192.168.1.100", - "ns1.auth.example.org. A 192.168.1.101", - "!''b', unparseable ", - "ns2.auth.example.org. A 192.168.1.102", -} - type resolver struct { server string } @@ -74,6 +70,45 @@ func findRecordFromMemory(rrstr string, host string, qtype uint16) error { return errors.New(errmsg) } +func TestQuestionDBError(t *testing.T) { + testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) { + columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"} + return testdb.RowsFromSlice(columns, [][]driver.Value{}), errors.New("Prepared query error") + }) + + defer testdb.Reset() + + tdb, err := sql.Open("testdb", "") + if err != nil { + t.Errorf("Got error: %v", err) + } + oldDb := DB.GetBackend() + + DB.SetBackend(tdb) + defer DB.SetBackend(oldDb) + + q := dns.Question{Name: dns.Fqdn("whatever.tld"), Qtype: dns.TypeTXT, Qclass: dns.ClassINET} + _, rcode, err := answerTXT(q) + if err == nil { + t.Errorf("Expected error but got none") + } + if rcode != dns.RcodeNameError { + t.Errorf("Expected [%s] rcode, but got [%s]", dns.RcodeToString[dns.RcodeNameError], dns.RcodeToString[rcode]) + } +} + +func TestParse(t *testing.T) { + var testcfg = general{ + Domain: ")", + Nsname: "ns1.auth.example.org", + Nsadmin: "admin.example.org", + StaticRecords: []string{}, + Debug: false, + } + var testRR Records + testRR.Parse(testcfg) +} + func TestResolveA(t *testing.T) { resolv := resolver{server: "0.0.0.0:15353"} answer, err := resolv.lookup("auth.example.org", dns.TypeA) diff --git a/main.go b/main.go index b7f4dda..21b9f31 100644 --- a/main.go +++ b/main.go @@ -17,7 +17,7 @@ func main() { setupLogging(DNSConf.Logconfig.Format, DNSConf.Logconfig.Level) // Read the default records in - RR.Parse(DNSConf.General.StaticRecords) + RR.Parse(DNSConf.General) // Open database newDB := new(acmedb) diff --git a/main_test.go b/main_test.go index ed5d827..95987ea 100644 --- a/main_test.go +++ b/main_test.go @@ -11,9 +11,16 @@ var ( postgres = flag.Bool("postgres", false, "run integration tests against PostgreSQL") ) +var records = []string{ + "auth.example.org. A 192.168.1.100", + "ns1.auth.example.org. A 192.168.1.101", + "!''b', unparseable ", + "ns2.auth.example.org. A 192.168.1.102", +} + func TestMain(m *testing.M) { setupConfig() - RR.Parse(records) + RR.Parse(DNSConf.General) flag.Parse() newDb := new(acmedb) @@ -44,10 +51,11 @@ func setupConfig() { } var generalcfg = general{ - Domain: "auth.example.org", - Nsname: "ns1.auth.example.org", - Nsadmin: "admin.example.org", - Debug: false, + Domain: "auth.example.org", + Nsname: "ns1.auth.example.org", + Nsadmin: "admin.example.org", + StaticRecords: records, + Debug: false, } var httpapicfg = httpapi{