DNS tests continued

This commit is contained in:
Joona Hoikkala
2016-11-26 10:02:32 +02:00
parent 8f8262acdd
commit f71b1772c6

View File

@ -2,6 +2,7 @@ package main
import ( import (
"errors" "errors"
"flag"
"fmt" "fmt"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/op/go-logging" "github.com/op/go-logging"
@ -22,31 +23,45 @@ type resolver struct {
server string server string
} }
func (r *resolver) lookup(host string, qtype uint16) (string, error) { func (r *resolver) lookup(host string, qtype uint16) ([]dns.RR, error) {
msg := new(dns.Msg) msg := new(dns.Msg)
msg.Id = dns.Id() msg.Id = dns.Id()
msg.Question = make([]dns.Question, 1) msg.Question = make([]dns.Question, 1)
msg.Question[0] = dns.Question{dns.Fqdn(host), qtype, dns.ClassINET} msg.Question[0] = dns.Question{Name: dns.Fqdn(host), Qtype: qtype, Qclass: dns.ClassINET}
in, err := dns.Exchange(msg, r.server) in, err := dns.Exchange(msg, r.server)
if err != nil { if err != nil {
return "", errors.New(fmt.Sprintf("Error querying the server [%v]", err)) return []dns.RR{}, fmt.Errorf("Error querying the server [%v]", err)
} }
if in != nil && in.Rcode != dns.RcodeSuccess { if in != nil && in.Rcode != dns.RcodeSuccess {
return "", errors.New(fmt.Sprintf("Recieved error from the server [%s]", dns.RcodeToString[in.Rcode])) return []dns.RR{}, fmt.Errorf("Recieved error from the server [%s]", dns.RcodeToString[in.Rcode])
} }
if len(in.Answer) > 0 { return in.Answer, nil
return in.Answer[0].String(), nil
}
return "", errors.New("No answer")
} }
func findRecord(rrstr string, host string, qtype uint16) error { func hasExpectedTXTAnswer(answer []dns.RR, cmpTXT string) error {
for _, record := range answer {
// We expect only one answer, so no need to loop through the answer slice
if rec, ok := record.(*dns.TXT); ok {
for _, txtValue := range rec.Txt {
if txtValue == cmpTXT {
return nil
}
}
} else {
errmsg := fmt.Sprintf("Got answer of unexpected type [%q]", answer[0])
return errors.New(errmsg)
}
}
return errors.New("Expected answer not found")
}
func findRecordFromMemory(rrstr string, host string, qtype uint16) error {
var errmsg = "No record found" var errmsg = "No record found"
arr, _ := dns.NewRR(strings.ToLower(rrstr)) arr, _ := dns.NewRR(strings.ToLower(rrstr))
if arr_qt, ok := RR.Records[qtype]; ok { if arrQt, ok := RR.Records[qtype]; ok {
if arr_hst, ok := arr_qt[host]; ok { if arrHst, ok := arrQt[host]; ok {
for _, v := range arr_hst { for _, v := range arrHst {
if arr.String() == v.String() { if arr.String() == v.String() {
return nil return nil
} }
@ -61,6 +76,26 @@ func findRecord(rrstr string, host string, qtype uint16) error {
} }
func startDNSServer(addr string) (*dns.Server, resolver) { func startDNSServer(addr string) (*dns.Server, resolver) {
var dbcfg = dbsettings{
Engine: "sqlite3",
Connection: ":memory:",
}
var generalcfg = general{
Domain: "auth.example.org",
Nsname: "ns1.auth.example.org",
Nsadmin: "admin.example.org",
Debug: false,
}
var dnscfg = DNSConfig{
Database: dbcfg,
General: generalcfg,
}
DNSConf = dnscfg
logging.InitForTesting(logging.DEBUG) logging.InitForTesting(logging.DEBUG)
// DNS server part // DNS server part
dns.HandleFunc(".", handleRequest) dns.HandleFunc(".", handleRequest)
@ -77,14 +112,68 @@ func startDNSServer(addr string) (*dns.Server, resolver) {
func TestResolveA(t *testing.T) { func TestResolveA(t *testing.T) {
server, resolver := startDNSServer(testAddr) server, resolver := startDNSServer(testAddr)
defer server.Shutdown()
RR.Parse(records) RR.Parse(records)
a, err := resolver.lookup("auth.example.org", dns.TypeA) answer, err := resolver.lookup("auth.example.org", dns.TypeA)
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("%v", err)
} }
err = findRecord(a, "auth.example.org.", dns.TypeA)
if len(answer) > 0 {
err = findRecordFromMemory(answer[0].String(), "auth.example.org.", dns.TypeA)
if err != nil { if err != nil {
t.Errorf("Answer [%s] did not match the expected, got error: [%s], debug: [%q]", a, err, RR.Records) t.Errorf("Answer [%s] did not match the expected, got error: [%s], debug: [%q]", answer[0].String(), err, RR.Records)
}
} else {
t.Error("No answer for DNS query")
}
}
func TestResolveTXT(t *testing.T) {
flag.Parse()
if *postgres {
DNSConf.Database.Engine = "postgres"
err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns")
if err != nil {
t.Errorf("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"")
return
}
} else {
DNSConf.Database.Engine = "sqlite3"
_ = DB.Init("sqlite3", ":memory:")
}
defer DB.DB.Close()
server, resolver := startDNSServer(testAddr)
defer server.Shutdown()
RR.Parse(records)
validTXT := "______________valid_response_______________"
atxt, err := DB.Register()
if err != nil {
t.Errorf("Could not initiate db record: [%v]", err)
return
}
atxt.Value = validTXT
err = DB.Update(atxt)
if err != nil {
t.Errorf("Could not update db record: [%v]", err)
return
}
answer, err := resolver.lookup(atxt.Subdomain+".auth.example.org", dns.TypeTXT)
if err != nil {
t.Errorf("%v", err)
return
}
if len(answer) > 0 {
err = hasExpectedTXTAnswer(answer, validTXT)
if err != nil {
t.Errorf("%v", err)
}
} else {
t.Error("No answer for DNS query")
} }
server.Shutdown()
} }