Refactoring DNS server part for safer paraller execution (#144)

* Refactoring DNS server part for safer paraller execution and better data structures

* Fix linter issues
This commit is contained in:
Joona Hoikkala 2019-02-03 17:23:04 +02:00 committed by GitHub
parent d695f72963
commit 7a2f9f06b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 197 additions and 152 deletions

200
dns.go
View File

@ -8,65 +8,80 @@ import (
"time" "time"
) )
func readQuery(m *dns.Msg) { // Records is a slice of ResourceRecords
for _, que := range m.Question { type Records struct {
if rr, rc, err := answer(que); err == nil { Records []dns.RR
m.MsgHdr.Rcode = rc
for _, r := range rr {
m.Answer = append(m.Answer, r)
}
}
}
} }
func answerTXT(q dns.Question) ([]dns.RR, int, error) { // DNSServer is the main struct for acme-dns DNS server
var ra []dns.RR type DNSServer struct {
rcode := dns.RcodeNameError DB database
subdomain := sanitizeDomainQuestion(q.Name) Server *dns.Server
atxt, err := DB.GetTXTForDomain(subdomain) Domains map[string]Records
}
// NewDNSServer parses the DNS records from config and returns a new DNSServer struct
func NewDNSServer(db database, addr string, proto string) *DNSServer {
var server DNSServer
server.Server = &dns.Server{Addr: addr, Net: proto}
server.DB = db
server.Domains = make(map[string]Records)
return &server
}
// Start starts the DNSServer
func (d *DNSServer) Start(errorChannel chan error) {
// DNS server part
dns.HandleFunc(".", d.handleRequest)
log.WithFields(log.Fields{"addr": d.Server.Addr, "proto": d.Server.Net}).Info("Listening DNS")
err := d.Server.ListenAndServe()
if err != nil { if err != nil {
log.WithFields(log.Fields{"error": err.Error()}).Debug("Error while trying to get record") errorChannel <- err
return ra, dns.RcodeNameError, err
} }
for _, v := range atxt {
if len(v) > 0 {
r := new(dns.TXT)
r.Hdr = dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1}
r.Txt = append(r.Txt, v)
ra = append(ra, r)
rcode = dns.RcodeSuccess
}
}
log.WithFields(log.Fields{"domain": q.Name}).Info("Answering TXT question for domain")
return ra, rcode, nil
} }
func answer(q dns.Question) ([]dns.RR, int, error) { // ParseRecords parses a slice of DNS record string
if q.Qtype == dns.TypeTXT { func (d *DNSServer) ParseRecords(config DNSConfig) {
return answerTXT(q) for _, v := range config.General.StaticRecords {
rr, err := dns.NewRR(strings.ToLower(v))
if err != nil {
log.WithFields(log.Fields{"error": err.Error(), "rr": v}).Warning("Could not parse RR from config")
continue
} }
var r []dns.RR // Add parsed RR
var rcode = dns.RcodeSuccess d.appendRR(rr)
var domain = strings.ToLower(q.Name)
var rtype = q.Qtype
r, ok := RR.Records[rtype][domain]
if !ok {
r, ok = RR.Records[dns.TypeCNAME][domain]
if !ok {
rcode = dns.RcodeNameError
} }
// Create serial
serial := time.Now().Format("2006010215")
// Add SOA
SOAstring := fmt.Sprintf("%s. SOA %s. %s. %s 28800 7200 604800 86400", strings.ToLower(config.General.Domain), strings.ToLower(config.General.Nsname), strings.ToLower(config.General.Nsadmin), serial)
soarr, err := dns.NewRR(SOAstring)
if err != nil {
log.WithFields(log.Fields{"error": err.Error(), "soa": SOAstring}).Error("Error while adding SOA record")
} else {
d.appendRR(soarr)
} }
log.WithFields(log.Fields{"qtype": dns.TypeToString[rtype], "domain": domain, "rcode": dns.RcodeToString[rcode]}).Debug("Answering question for domain")
return r, rcode, nil
} }
func handleRequest(w dns.ResponseWriter, r *dns.Msg) { func (d *DNSServer) appendRR(rr dns.RR) {
addDomain := rr.Header().Name
_, ok := d.Domains[addDomain]
if !ok {
d.Domains[addDomain] = Records{[]dns.RR{rr}}
} else {
drecs := d.Domains[addDomain]
drecs.Records = append(drecs.Records, rr)
d.Domains[addDomain] = drecs
}
log.WithFields(log.Fields{"recordtype": dns.TypeToString[rr.Header().Rrtype], "domain": addDomain}).Debug("Adding new record to domain")
}
func (d *DNSServer) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg) m := new(dns.Msg)
m.SetReply(r) m.SetReply(r)
if r.Opcode == dns.OpcodeQuery { if r.Opcode == dns.OpcodeQuery {
readQuery(m) d.readQuery(m)
} else if r.Opcode == dns.OpcodeUpdate { } else if r.Opcode == dns.OpcodeUpdate {
log.Debug("Refusing DNS Dynamic update request") log.Debug("Refusing DNS Dynamic update request")
m.MsgHdr.Rcode = dns.RcodeRefused m.MsgHdr.Rcode = dns.RcodeRefused
@ -75,38 +90,81 @@ func handleRequest(w dns.ResponseWriter, r *dns.Msg) {
w.WriteMsg(m) w.WriteMsg(m)
} }
// Parse config records func (d *DNSServer) readQuery(m *dns.Msg) {
func (r *Records) Parse(config general) { for _, que := range m.Question {
rrmap := make(map[uint16]map[string][]dns.RR) if rr, rc, err := d.answer(que); err == nil {
for _, v := range config.StaticRecords { m.MsgHdr.Rcode = rc
rr, err := dns.NewRR(strings.ToLower(v)) for _, r := range rr {
if err != nil { m.Answer = append(m.Answer, r)
log.WithFields(log.Fields{"error": err.Error(), "rr": v}).Warning("Could not parse RR from config")
continue
} }
// Add parsed RR to the list
rrmap = appendRR(rrmap, rr)
} }
// Create serial
serial := time.Now().Format("2006010215")
// Add SOA
SOAstring := fmt.Sprintf("%s. SOA %s. %s. %s 28800 7200 604800 86400", strings.ToLower(config.Domain), strings.ToLower(config.Nsname), strings.ToLower(config.Nsadmin), serial)
soarr, err := dns.NewRR(SOAstring)
if err != nil {
log.WithFields(log.Fields{"error": err.Error(), "soa": SOAstring}).Error("Error while adding SOA record")
} else {
rrmap = appendRR(rrmap, soarr)
} }
r.Records = rrmap
} }
func appendRR(rrmap map[uint16]map[string][]dns.RR, rr dns.RR) map[uint16]map[string][]dns.RR { func (d *DNSServer) getRecord(q dns.Question) ([]dns.RR, error) {
_, ok := rrmap[rr.Header().Rrtype] var rr []dns.RR
var cnames []dns.RR
domain, ok := d.Domains[q.Name]
if !ok { if !ok {
newrr := make(map[string][]dns.RR) return rr, fmt.Errorf("No records for domain %s", q.Name)
rrmap[rr.Header().Rrtype] = newrr
} }
rrmap[rr.Header().Rrtype][rr.Header().Name] = append(rrmap[rr.Header().Rrtype][rr.Header().Name], rr) for _, ri := range domain.Records {
log.WithFields(log.Fields{"recordtype": dns.TypeToString[rr.Header().Rrtype], "domain": rr.Header().Name}).Debug("Adding new record type to domain") if ri.Header().Rrtype == q.Qtype {
return rrmap rr = append(rr, ri)
}
if ri.Header().Rrtype == dns.TypeCNAME {
cnames = append(cnames, ri)
}
}
if len(rr) == 0 {
return cnames, nil
}
return rr, nil
}
// answeringForDomain checks if we have any records for a domain
func (d *DNSServer) answeringForDomain(q dns.Question) bool {
_, ok := d.Domains[q.Name]
return ok
}
func (d *DNSServer) answer(q dns.Question) ([]dns.RR, int, error) {
var rcode int
if !d.answeringForDomain(q) {
rcode = dns.RcodeNameError
}
r, _ := d.getRecord(q)
if q.Qtype == dns.TypeTXT {
txtRRs, err := d.answerTXT(q)
if err == nil {
for _, txtRR := range txtRRs {
r = append(r, txtRR)
}
}
}
if len(r) > 0 {
// Make sure that we return NOERROR if there were dynamic records for the domain
rcode = dns.RcodeSuccess
}
log.WithFields(log.Fields{"qtype": dns.TypeToString[q.Qtype], "domain": q.Name, "rcode": dns.RcodeToString[rcode]}).Debug("Answering question for domain")
return r, rcode, nil
}
func (d *DNSServer) answerTXT(q dns.Question) ([]dns.RR, error) {
var ra []dns.RR
subdomain := sanitizeDomainQuestion(q.Name)
atxt, err := d.DB.GetTXTForDomain(subdomain)
if err != nil {
log.WithFields(log.Fields{"error": err.Error()}).Debug("Error while trying to get record")
return ra, err
}
for _, v := range atxt {
if len(v) > 0 {
r := new(dns.TXT)
r.Hdr = dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1}
r.Txt = append(r.Txt, v)
ra = append(ra, r)
}
}
return ra, nil
} }

View File

@ -5,10 +5,10 @@ import (
"database/sql/driver" "database/sql/driver"
"errors" "errors"
"fmt" "fmt"
"testing"
"github.com/erikstmartin/go-testdb" "github.com/erikstmartin/go-testdb"
"github.com/miekg/dns" "github.com/miekg/dns"
"strings"
"testing"
) )
var resolv resolver var resolv resolver
@ -51,25 +51,6 @@ func hasExpectedTXTAnswer(answer []dns.RR, cmpTXT string) error {
return errors.New("Expected answer not found") return errors.New("Expected answer not found")
} }
func findRecordFromMemory(rrstr string, host string, qtype uint16) error {
var errmsg = "No record found"
arr, _ := dns.NewRR(strings.ToLower(rrstr))
if arrQt, ok := RR.Records[qtype]; ok {
if arrHst, ok := arrQt[host]; ok {
for _, v := range arrHst {
if arr.String() == v.String() {
return nil
}
}
} else {
errmsg = "No records for domain"
}
} else {
errmsg = "No records for this type in DB"
}
return errors.New(errmsg)
}
func TestQuestionDBError(t *testing.T) { func TestQuestionDBError(t *testing.T) {
testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) { testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) {
columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"} columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"}
@ -88,44 +69,36 @@ func TestQuestionDBError(t *testing.T) {
defer DB.SetBackend(oldDb) defer DB.SetBackend(oldDb)
q := dns.Question{Name: dns.Fqdn("whatever.tld"), Qtype: dns.TypeTXT, Qclass: dns.ClassINET} q := dns.Question{Name: dns.Fqdn("whatever.tld"), Qtype: dns.TypeTXT, Qclass: dns.ClassINET}
_, rcode, err := answerTXT(q) _, err = dnsserver.answerTXT(q)
if err == nil { if err == nil {
t.Errorf("Expected error but got none") t.Errorf("Expected error but got none")
} }
if rcode != dns.RcodeNameError {
t.Errorf("Expected [%s] rcode, but got [%s]", dns.RcodeToString[dns.RcodeNameError], dns.RcodeToString[rcode])
}
} }
func TestParse(t *testing.T) { func TestParse(t *testing.T) {
var testcfg = general{ var testcfg = DNSConfig{
General: general{
Domain: ")", Domain: ")",
Nsname: "ns1.auth.example.org", Nsname: "ns1.auth.example.org",
Nsadmin: "admin.example.org", Nsadmin: "admin.example.org",
StaticRecords: []string{}, StaticRecords: []string{},
Debug: false, Debug: false,
},
} }
var testRR Records dnsserver.ParseRecords(testcfg)
testRR.Parse(testcfg)
if !loggerHasEntryWithMessage("Error while adding SOA record") { if !loggerHasEntryWithMessage("Error while adding SOA record") {
t.Errorf("Expected SOA parsing to return error, but did not find one") t.Errorf("Expected SOA parsing to return error, but did not find one")
} }
} }
func TestResolveA(t *testing.T) { func TestResolveA(t *testing.T) {
resolv := resolver{server: "0.0.0.0:15353"} resolv := resolver{server: "127.0.0.1:15353"}
answer, err := resolv.lookup("auth.example.org", dns.TypeA) answer, err := resolv.lookup("auth.example.org", dns.TypeA)
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("%v", err)
} }
if len(answer) > 0 { if len(answer) == 0 {
err = findRecordFromMemory(answer[0].String(), "auth.example.org.", dns.TypeA)
if err != nil {
t.Errorf("Answer [%s] did not match the expected, got error: [%s], debug: [%q]", answer[0].String(), err, RR.Records)
}
} else {
t.Error("No answer for DNS query") t.Error("No answer for DNS query")
} }
@ -135,8 +108,42 @@ func TestResolveA(t *testing.T) {
} }
} }
func TestOpcodeUpdate(t *testing.T) {
msg := new(dns.Msg)
msg.Id = dns.Id()
msg.Question = make([]dns.Question, 1)
msg.Question[0] = dns.Question{Name: dns.Fqdn("auth.example.org"), Qtype: dns.TypeANY, Qclass: dns.ClassINET}
msg.MsgHdr.Opcode = dns.OpcodeUpdate
in, err := dns.Exchange(msg, "127.0.0.1:15353")
if err != nil || in == nil {
t.Errorf("Encountered an error with UPDATE request")
} else if err == nil {
if in.Rcode != dns.RcodeRefused {
t.Errorf("Expected RCODE Refused from UPDATE request, but got [%s] instead", dns.RcodeToString[in.Rcode])
}
}
}
func TestResolveCNAME(t *testing.T) {
resolv := resolver{server: "127.0.0.1:15353"}
expected := "cn.example.org. 3600 IN CNAME something.example.org."
answer, err := resolv.lookup("cn.example.org", dns.TypeCNAME)
if err != nil {
t.Errorf("Got unexpected error: %s", err)
}
if len(answer) != 1 {
t.Errorf("Expected exactly 1 RR in answer, but got %d instead.", len(answer))
}
if answer[0].Header().Rrtype != dns.TypeCNAME {
t.Errorf("Expected a CNAME answer, but got [%s] instead.", dns.TypeToString[answer[0].Header().Rrtype])
}
if answer[0].String() != expected {
t.Errorf("Expected CNAME answer [%s] but got [%s] instead.", expected, answer[0].String())
}
}
func TestResolveTXT(t *testing.T) { func TestResolveTXT(t *testing.T) {
resolv := resolver{server: "0.0.0.0:15353"} resolv := resolver{server: "127.0.0.1:15353"}
validTXT := "______________valid_response_______________" validTXT := "______________valid_response_______________"
atxt, err := DB.Register(cidrslice{}) atxt, err := DB.Register(cidrslice{})

34
main.go
View File

@ -12,7 +12,6 @@ import (
"syscall" "syscall"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"github.com/miekg/dns"
"github.com/rs/cors" "github.com/rs/cors"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/acme/autocert"
@ -42,9 +41,6 @@ func main() {
setupLogging(Config.Logconfig.Format, Config.Logconfig.Level) setupLogging(Config.Logconfig.Format, Config.Logconfig.Level)
// Read the default records in
RR.Parse(Config.General)
// Open database // Open database
newDB := new(acmedb) newDB := new(acmedb)
err = newDB.Init(Config.Database.Engine, Config.Database.Connection) err = newDB.Init(Config.Database.Engine, Config.Database.Connection)
@ -72,13 +68,17 @@ func main() {
udpProto += "6" udpProto += "6"
tcpProto += "6" tcpProto += "6"
} }
dnsServerUDP := setupDNSServer(udpProto) dnsServerUDP := NewDNSServer(DB, Config.General.Listen, udpProto)
dnsServerTCP := setupDNSServer(tcpProto) dnsServerUDP.ParseRecords(Config)
go startDNS(dnsServerUDP, errChan) dnsServerTCP := NewDNSServer(DB, Config.General.Listen, tcpProto)
go startDNS(dnsServerTCP, errChan) // No need to parse records from config again
dnsServerTCP.Domains = dnsServerUDP.Domains
go dnsServerUDP.Start(errChan)
go dnsServerTCP.Start(errChan)
} else { } else {
dnsServer := setupDNSServer(Config.General.Proto) dnsServer := NewDNSServer(DB, Config.General.Listen, Config.General.Proto)
go startDNS(dnsServer, errChan) dnsServer.ParseRecords(Config)
go dnsServer.Start(errChan)
} }
// HTTP API // HTTP API
@ -94,20 +94,6 @@ func main() {
log.Debugf("Shutting down...") log.Debugf("Shutting down...")
} }
func startDNS(server *dns.Server, errChan chan error) {
// DNS server part
dns.HandleFunc(".", handleRequest)
log.WithFields(log.Fields{"addr": Config.General.Listen, "proto": server.Net}).Info("Listening DNS")
err := server.ListenAndServe()
if err != nil {
errChan <- err
}
}
func setupDNSServer(proto string) *dns.Server {
return &dns.Server{Addr: Config.General.Listen, Net: proto}
}
func startHTTPAPI(errChan chan error) { func startHTTPAPI(errChan chan error) {
// Setup http logger // Setup http logger
logger := log.New() logger := log.New()

View File

@ -12,6 +12,7 @@ import (
) )
var loghook = new(logrustest.Hook) var loghook = new(logrustest.Hook)
var dnsserver *DNSServer
var ( var (
postgres = flag.Bool("postgres", false, "run integration tests against PostgreSQL") postgres = flag.Bool("postgres", false, "run integration tests against PostgreSQL")
@ -20,6 +21,7 @@ var (
var records = []string{ var records = []string{
"auth.example.org. A 192.168.1.100", "auth.example.org. A 192.168.1.100",
"ns1.auth.example.org. A 192.168.1.101", "ns1.auth.example.org. A 192.168.1.101",
"cn.example.org CNAME something.example.org.",
"!''b', unparseable ", "!''b', unparseable ",
"ns2.auth.example.org. A 192.168.1.102", "ns2.auth.example.org. A 192.168.1.102",
} }
@ -27,7 +29,6 @@ var records = []string{
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
setupTestLogger() setupTestLogger()
setupConfig() setupConfig()
RR.Parse(Config.General)
flag.Parse() flag.Parse()
newDb := new(acmedb) newDb := new(acmedb)
@ -43,17 +44,19 @@ func TestMain(m *testing.M) {
_ = newDb.Init("sqlite3", ":memory:") _ = newDb.Init("sqlite3", ":memory:")
} }
DB = newDb DB = newDb
server := setupDNSServer("udp") dnsserver = NewDNSServer(DB, Config.General.Listen, Config.General.Proto)
dnsserver.ParseRecords(Config)
// Make sure that we're not creating a race condition in tests // Make sure that we're not creating a race condition in tests
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
server.NotifyStartedFunc = func() { dnsserver.Server.NotifyStartedFunc = func() {
wg.Done() wg.Done()
} }
go startDNS(server, make(chan error, 1)) go dnsserver.Start(make(chan error, 1))
wg.Wait() wg.Wait()
exitval := m.Run() exitval := m.Run()
server.Shutdown() dnsserver.Server.Shutdown()
DB.Close() DB.Close()
os.Exit(exitval) os.Exit(exitval)
} }

View File

@ -5,7 +5,6 @@ import (
"sync" "sync"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/miekg/dns"
) )
// Config is global configuration struct // Config is global configuration struct
@ -14,14 +13,6 @@ var Config DNSConfig
// DB is used to access the database functions in acme-dns // DB is used to access the database functions in acme-dns
var DB database var DB database
// RR holds the static DNS records
var RR Records
// Records is for static records
type Records struct {
Records map[uint16]map[string][]dns.RR
}
// DNSConfig holds the config structure // DNSConfig holds the config structure
type DNSConfig struct { type DNSConfig struct {
General general General general