mirror of
https://github.com/joohoi/acme-dns.git
synced 2025-01-13 00:05:33 +07:00
Refactoring DNS server part for safer paraller execution (#144)
* Refactoring DNS server part for safer paraller execution and better data structures * Fix linter issues
This commit is contained in:
parent
d695f72963
commit
7a2f9f06b1
202
dns.go
202
dns.go
@ -8,65 +8,80 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func readQuery(m *dns.Msg) {
|
// Records is a slice of ResourceRecords
|
||||||
for _, que := range m.Question {
|
type Records struct {
|
||||||
if rr, rc, err := answer(que); err == nil {
|
Records []dns.RR
|
||||||
m.MsgHdr.Rcode = rc
|
|
||||||
for _, r := range rr {
|
|
||||||
m.Answer = append(m.Answer, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func answerTXT(q dns.Question) ([]dns.RR, int, error) {
|
// DNSServer is the main struct for acme-dns DNS server
|
||||||
var ra []dns.RR
|
type DNSServer struct {
|
||||||
rcode := dns.RcodeNameError
|
DB database
|
||||||
subdomain := sanitizeDomainQuestion(q.Name)
|
Server *dns.Server
|
||||||
atxt, err := DB.GetTXTForDomain(subdomain)
|
Domains map[string]Records
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDNSServer parses the DNS records from config and returns a new DNSServer struct
|
||||||
|
func NewDNSServer(db database, addr string, proto string) *DNSServer {
|
||||||
|
var server DNSServer
|
||||||
|
server.Server = &dns.Server{Addr: addr, Net: proto}
|
||||||
|
server.DB = db
|
||||||
|
server.Domains = make(map[string]Records)
|
||||||
|
return &server
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the DNSServer
|
||||||
|
func (d *DNSServer) Start(errorChannel chan error) {
|
||||||
|
// DNS server part
|
||||||
|
dns.HandleFunc(".", d.handleRequest)
|
||||||
|
log.WithFields(log.Fields{"addr": d.Server.Addr, "proto": d.Server.Net}).Info("Listening DNS")
|
||||||
|
err := d.Server.ListenAndServe()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithFields(log.Fields{"error": err.Error()}).Debug("Error while trying to get record")
|
errorChannel <- err
|
||||||
return ra, dns.RcodeNameError, err
|
|
||||||
}
|
}
|
||||||
for _, v := range atxt {
|
|
||||||
if len(v) > 0 {
|
|
||||||
r := new(dns.TXT)
|
|
||||||
r.Hdr = dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1}
|
|
||||||
r.Txt = append(r.Txt, v)
|
|
||||||
ra = append(ra, r)
|
|
||||||
rcode = dns.RcodeSuccess
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.WithFields(log.Fields{"domain": q.Name}).Info("Answering TXT question for domain")
|
|
||||||
return ra, rcode, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func answer(q dns.Question) ([]dns.RR, int, error) {
|
// ParseRecords parses a slice of DNS record string
|
||||||
if q.Qtype == dns.TypeTXT {
|
func (d *DNSServer) ParseRecords(config DNSConfig) {
|
||||||
return answerTXT(q)
|
for _, v := range config.General.StaticRecords {
|
||||||
|
rr, err := dns.NewRR(strings.ToLower(v))
|
||||||
|
if err != nil {
|
||||||
|
log.WithFields(log.Fields{"error": err.Error(), "rr": v}).Warning("Could not parse RR from config")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Add parsed RR
|
||||||
|
d.appendRR(rr)
|
||||||
}
|
}
|
||||||
var r []dns.RR
|
// Create serial
|
||||||
var rcode = dns.RcodeSuccess
|
serial := time.Now().Format("2006010215")
|
||||||
var domain = strings.ToLower(q.Name)
|
// Add SOA
|
||||||
var rtype = q.Qtype
|
SOAstring := fmt.Sprintf("%s. SOA %s. %s. %s 28800 7200 604800 86400", strings.ToLower(config.General.Domain), strings.ToLower(config.General.Nsname), strings.ToLower(config.General.Nsadmin), serial)
|
||||||
r, ok := RR.Records[rtype][domain]
|
soarr, err := dns.NewRR(SOAstring)
|
||||||
|
if err != nil {
|
||||||
|
log.WithFields(log.Fields{"error": err.Error(), "soa": SOAstring}).Error("Error while adding SOA record")
|
||||||
|
} else {
|
||||||
|
d.appendRR(soarr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DNSServer) appendRR(rr dns.RR) {
|
||||||
|
addDomain := rr.Header().Name
|
||||||
|
_, ok := d.Domains[addDomain]
|
||||||
if !ok {
|
if !ok {
|
||||||
r, ok = RR.Records[dns.TypeCNAME][domain]
|
d.Domains[addDomain] = Records{[]dns.RR{rr}}
|
||||||
if !ok {
|
} else {
|
||||||
rcode = dns.RcodeNameError
|
drecs := d.Domains[addDomain]
|
||||||
}
|
drecs.Records = append(drecs.Records, rr)
|
||||||
|
d.Domains[addDomain] = drecs
|
||||||
}
|
}
|
||||||
log.WithFields(log.Fields{"qtype": dns.TypeToString[rtype], "domain": domain, "rcode": dns.RcodeToString[rcode]}).Debug("Answering question for domain")
|
log.WithFields(log.Fields{"recordtype": dns.TypeToString[rr.Header().Rrtype], "domain": addDomain}).Debug("Adding new record to domain")
|
||||||
return r, rcode, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleRequest(w dns.ResponseWriter, r *dns.Msg) {
|
func (d *DNSServer) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetReply(r)
|
m.SetReply(r)
|
||||||
|
|
||||||
if r.Opcode == dns.OpcodeQuery {
|
if r.Opcode == dns.OpcodeQuery {
|
||||||
readQuery(m)
|
d.readQuery(m)
|
||||||
} else if r.Opcode == dns.OpcodeUpdate {
|
} else if r.Opcode == dns.OpcodeUpdate {
|
||||||
log.Debug("Refusing DNS Dynamic update request")
|
log.Debug("Refusing DNS Dynamic update request")
|
||||||
m.MsgHdr.Rcode = dns.RcodeRefused
|
m.MsgHdr.Rcode = dns.RcodeRefused
|
||||||
@ -75,38 +90,81 @@ func handleRequest(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
w.WriteMsg(m)
|
w.WriteMsg(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse config records
|
func (d *DNSServer) readQuery(m *dns.Msg) {
|
||||||
func (r *Records) Parse(config general) {
|
for _, que := range m.Question {
|
||||||
rrmap := make(map[uint16]map[string][]dns.RR)
|
if rr, rc, err := d.answer(que); err == nil {
|
||||||
for _, v := range config.StaticRecords {
|
m.MsgHdr.Rcode = rc
|
||||||
rr, err := dns.NewRR(strings.ToLower(v))
|
for _, r := range rr {
|
||||||
if err != nil {
|
m.Answer = append(m.Answer, r)
|
||||||
log.WithFields(log.Fields{"error": err.Error(), "rr": v}).Warning("Could not parse RR from config")
|
}
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
// Add parsed RR to the list
|
|
||||||
rrmap = appendRR(rrmap, rr)
|
|
||||||
}
|
}
|
||||||
// Create serial
|
|
||||||
serial := time.Now().Format("2006010215")
|
|
||||||
// Add SOA
|
|
||||||
SOAstring := fmt.Sprintf("%s. SOA %s. %s. %s 28800 7200 604800 86400", strings.ToLower(config.Domain), strings.ToLower(config.Nsname), strings.ToLower(config.Nsadmin), serial)
|
|
||||||
soarr, err := dns.NewRR(SOAstring)
|
|
||||||
if err != nil {
|
|
||||||
log.WithFields(log.Fields{"error": err.Error(), "soa": SOAstring}).Error("Error while adding SOA record")
|
|
||||||
} else {
|
|
||||||
rrmap = appendRR(rrmap, soarr)
|
|
||||||
}
|
|
||||||
r.Records = rrmap
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func appendRR(rrmap map[uint16]map[string][]dns.RR, rr dns.RR) map[uint16]map[string][]dns.RR {
|
func (d *DNSServer) getRecord(q dns.Question) ([]dns.RR, error) {
|
||||||
_, ok := rrmap[rr.Header().Rrtype]
|
var rr []dns.RR
|
||||||
|
var cnames []dns.RR
|
||||||
|
domain, ok := d.Domains[q.Name]
|
||||||
if !ok {
|
if !ok {
|
||||||
newrr := make(map[string][]dns.RR)
|
return rr, fmt.Errorf("No records for domain %s", q.Name)
|
||||||
rrmap[rr.Header().Rrtype] = newrr
|
|
||||||
}
|
}
|
||||||
rrmap[rr.Header().Rrtype][rr.Header().Name] = append(rrmap[rr.Header().Rrtype][rr.Header().Name], rr)
|
for _, ri := range domain.Records {
|
||||||
log.WithFields(log.Fields{"recordtype": dns.TypeToString[rr.Header().Rrtype], "domain": rr.Header().Name}).Debug("Adding new record type to domain")
|
if ri.Header().Rrtype == q.Qtype {
|
||||||
return rrmap
|
rr = append(rr, ri)
|
||||||
|
}
|
||||||
|
if ri.Header().Rrtype == dns.TypeCNAME {
|
||||||
|
cnames = append(cnames, ri)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(rr) == 0 {
|
||||||
|
return cnames, nil
|
||||||
|
}
|
||||||
|
return rr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// answeringForDomain checks if we have any records for a domain
|
||||||
|
func (d *DNSServer) answeringForDomain(q dns.Question) bool {
|
||||||
|
_, ok := d.Domains[q.Name]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DNSServer) answer(q dns.Question) ([]dns.RR, int, error) {
|
||||||
|
var rcode int
|
||||||
|
if !d.answeringForDomain(q) {
|
||||||
|
rcode = dns.RcodeNameError
|
||||||
|
}
|
||||||
|
r, _ := d.getRecord(q)
|
||||||
|
if q.Qtype == dns.TypeTXT {
|
||||||
|
txtRRs, err := d.answerTXT(q)
|
||||||
|
if err == nil {
|
||||||
|
for _, txtRR := range txtRRs {
|
||||||
|
r = append(r, txtRR)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(r) > 0 {
|
||||||
|
// Make sure that we return NOERROR if there were dynamic records for the domain
|
||||||
|
rcode = dns.RcodeSuccess
|
||||||
|
}
|
||||||
|
log.WithFields(log.Fields{"qtype": dns.TypeToString[q.Qtype], "domain": q.Name, "rcode": dns.RcodeToString[rcode]}).Debug("Answering question for domain")
|
||||||
|
return r, rcode, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DNSServer) answerTXT(q dns.Question) ([]dns.RR, error) {
|
||||||
|
var ra []dns.RR
|
||||||
|
subdomain := sanitizeDomainQuestion(q.Name)
|
||||||
|
atxt, err := d.DB.GetTXTForDomain(subdomain)
|
||||||
|
if err != nil {
|
||||||
|
log.WithFields(log.Fields{"error": err.Error()}).Debug("Error while trying to get record")
|
||||||
|
return ra, err
|
||||||
|
}
|
||||||
|
for _, v := range atxt {
|
||||||
|
if len(v) > 0 {
|
||||||
|
r := new(dns.TXT)
|
||||||
|
r.Hdr = dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1}
|
||||||
|
r.Txt = append(r.Txt, v)
|
||||||
|
ra = append(ra, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ra, nil
|
||||||
}
|
}
|
||||||
|
91
dns_test.go
91
dns_test.go
@ -5,10 +5,10 @@ import (
|
|||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
"github.com/erikstmartin/go-testdb"
|
"github.com/erikstmartin/go-testdb"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var resolv resolver
|
var resolv resolver
|
||||||
@ -51,25 +51,6 @@ func hasExpectedTXTAnswer(answer []dns.RR, cmpTXT string) error {
|
|||||||
return errors.New("Expected answer not found")
|
return errors.New("Expected answer not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
func findRecordFromMemory(rrstr string, host string, qtype uint16) error {
|
|
||||||
var errmsg = "No record found"
|
|
||||||
arr, _ := dns.NewRR(strings.ToLower(rrstr))
|
|
||||||
if arrQt, ok := RR.Records[qtype]; ok {
|
|
||||||
if arrHst, ok := arrQt[host]; ok {
|
|
||||||
for _, v := range arrHst {
|
|
||||||
if arr.String() == v.String() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
errmsg = "No records for domain"
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
errmsg = "No records for this type in DB"
|
|
||||||
}
|
|
||||||
return errors.New(errmsg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQuestionDBError(t *testing.T) {
|
func TestQuestionDBError(t *testing.T) {
|
||||||
testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) {
|
testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) {
|
||||||
columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"}
|
columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"}
|
||||||
@ -88,44 +69,36 @@ func TestQuestionDBError(t *testing.T) {
|
|||||||
defer DB.SetBackend(oldDb)
|
defer DB.SetBackend(oldDb)
|
||||||
|
|
||||||
q := dns.Question{Name: dns.Fqdn("whatever.tld"), Qtype: dns.TypeTXT, Qclass: dns.ClassINET}
|
q := dns.Question{Name: dns.Fqdn("whatever.tld"), Qtype: dns.TypeTXT, Qclass: dns.ClassINET}
|
||||||
_, rcode, err := answerTXT(q)
|
_, err = dnsserver.answerTXT(q)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("Expected error but got none")
|
t.Errorf("Expected error but got none")
|
||||||
}
|
}
|
||||||
if rcode != dns.RcodeNameError {
|
|
||||||
t.Errorf("Expected [%s] rcode, but got [%s]", dns.RcodeToString[dns.RcodeNameError], dns.RcodeToString[rcode])
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParse(t *testing.T) {
|
func TestParse(t *testing.T) {
|
||||||
var testcfg = general{
|
var testcfg = DNSConfig{
|
||||||
Domain: ")",
|
General: general{
|
||||||
Nsname: "ns1.auth.example.org",
|
Domain: ")",
|
||||||
Nsadmin: "admin.example.org",
|
Nsname: "ns1.auth.example.org",
|
||||||
StaticRecords: []string{},
|
Nsadmin: "admin.example.org",
|
||||||
Debug: false,
|
StaticRecords: []string{},
|
||||||
|
Debug: false,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
var testRR Records
|
dnsserver.ParseRecords(testcfg)
|
||||||
testRR.Parse(testcfg)
|
|
||||||
if !loggerHasEntryWithMessage("Error while adding SOA record") {
|
if !loggerHasEntryWithMessage("Error while adding SOA record") {
|
||||||
t.Errorf("Expected SOA parsing to return error, but did not find one")
|
t.Errorf("Expected SOA parsing to return error, but did not find one")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResolveA(t *testing.T) {
|
func TestResolveA(t *testing.T) {
|
||||||
resolv := resolver{server: "0.0.0.0:15353"}
|
resolv := resolver{server: "127.0.0.1:15353"}
|
||||||
answer, err := resolv.lookup("auth.example.org", dns.TypeA)
|
answer, err := resolv.lookup("auth.example.org", dns.TypeA)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("%v", err)
|
t.Errorf("%v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(answer) > 0 {
|
if len(answer) == 0 {
|
||||||
err = findRecordFromMemory(answer[0].String(), "auth.example.org.", dns.TypeA)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Answer [%s] did not match the expected, got error: [%s], debug: [%q]", answer[0].String(), err, RR.Records)
|
|
||||||
}
|
|
||||||
|
|
||||||
} else {
|
|
||||||
t.Error("No answer for DNS query")
|
t.Error("No answer for DNS query")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -135,8 +108,42 @@ func TestResolveA(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpcodeUpdate(t *testing.T) {
|
||||||
|
msg := new(dns.Msg)
|
||||||
|
msg.Id = dns.Id()
|
||||||
|
msg.Question = make([]dns.Question, 1)
|
||||||
|
msg.Question[0] = dns.Question{Name: dns.Fqdn("auth.example.org"), Qtype: dns.TypeANY, Qclass: dns.ClassINET}
|
||||||
|
msg.MsgHdr.Opcode = dns.OpcodeUpdate
|
||||||
|
in, err := dns.Exchange(msg, "127.0.0.1:15353")
|
||||||
|
if err != nil || in == nil {
|
||||||
|
t.Errorf("Encountered an error with UPDATE request")
|
||||||
|
} else if err == nil {
|
||||||
|
if in.Rcode != dns.RcodeRefused {
|
||||||
|
t.Errorf("Expected RCODE Refused from UPDATE request, but got [%s] instead", dns.RcodeToString[in.Rcode])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveCNAME(t *testing.T) {
|
||||||
|
resolv := resolver{server: "127.0.0.1:15353"}
|
||||||
|
expected := "cn.example.org. 3600 IN CNAME something.example.org."
|
||||||
|
answer, err := resolv.lookup("cn.example.org", dns.TypeCNAME)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Got unexpected error: %s", err)
|
||||||
|
}
|
||||||
|
if len(answer) != 1 {
|
||||||
|
t.Errorf("Expected exactly 1 RR in answer, but got %d instead.", len(answer))
|
||||||
|
}
|
||||||
|
if answer[0].Header().Rrtype != dns.TypeCNAME {
|
||||||
|
t.Errorf("Expected a CNAME answer, but got [%s] instead.", dns.TypeToString[answer[0].Header().Rrtype])
|
||||||
|
}
|
||||||
|
if answer[0].String() != expected {
|
||||||
|
t.Errorf("Expected CNAME answer [%s] but got [%s] instead.", expected, answer[0].String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestResolveTXT(t *testing.T) {
|
func TestResolveTXT(t *testing.T) {
|
||||||
resolv := resolver{server: "0.0.0.0:15353"}
|
resolv := resolver{server: "127.0.0.1:15353"}
|
||||||
validTXT := "______________valid_response_______________"
|
validTXT := "______________valid_response_______________"
|
||||||
|
|
||||||
atxt, err := DB.Register(cidrslice{})
|
atxt, err := DB.Register(cidrslice{})
|
||||||
|
34
main.go
34
main.go
@ -12,7 +12,6 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/julienschmidt/httprouter"
|
"github.com/julienschmidt/httprouter"
|
||||||
"github.com/miekg/dns"
|
|
||||||
"github.com/rs/cors"
|
"github.com/rs/cors"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/crypto/acme/autocert"
|
"golang.org/x/crypto/acme/autocert"
|
||||||
@ -42,9 +41,6 @@ func main() {
|
|||||||
|
|
||||||
setupLogging(Config.Logconfig.Format, Config.Logconfig.Level)
|
setupLogging(Config.Logconfig.Format, Config.Logconfig.Level)
|
||||||
|
|
||||||
// Read the default records in
|
|
||||||
RR.Parse(Config.General)
|
|
||||||
|
|
||||||
// Open database
|
// Open database
|
||||||
newDB := new(acmedb)
|
newDB := new(acmedb)
|
||||||
err = newDB.Init(Config.Database.Engine, Config.Database.Connection)
|
err = newDB.Init(Config.Database.Engine, Config.Database.Connection)
|
||||||
@ -72,13 +68,17 @@ func main() {
|
|||||||
udpProto += "6"
|
udpProto += "6"
|
||||||
tcpProto += "6"
|
tcpProto += "6"
|
||||||
}
|
}
|
||||||
dnsServerUDP := setupDNSServer(udpProto)
|
dnsServerUDP := NewDNSServer(DB, Config.General.Listen, udpProto)
|
||||||
dnsServerTCP := setupDNSServer(tcpProto)
|
dnsServerUDP.ParseRecords(Config)
|
||||||
go startDNS(dnsServerUDP, errChan)
|
dnsServerTCP := NewDNSServer(DB, Config.General.Listen, tcpProto)
|
||||||
go startDNS(dnsServerTCP, errChan)
|
// No need to parse records from config again
|
||||||
|
dnsServerTCP.Domains = dnsServerUDP.Domains
|
||||||
|
go dnsServerUDP.Start(errChan)
|
||||||
|
go dnsServerTCP.Start(errChan)
|
||||||
} else {
|
} else {
|
||||||
dnsServer := setupDNSServer(Config.General.Proto)
|
dnsServer := NewDNSServer(DB, Config.General.Listen, Config.General.Proto)
|
||||||
go startDNS(dnsServer, errChan)
|
dnsServer.ParseRecords(Config)
|
||||||
|
go dnsServer.Start(errChan)
|
||||||
}
|
}
|
||||||
|
|
||||||
// HTTP API
|
// HTTP API
|
||||||
@ -94,20 +94,6 @@ func main() {
|
|||||||
log.Debugf("Shutting down...")
|
log.Debugf("Shutting down...")
|
||||||
}
|
}
|
||||||
|
|
||||||
func startDNS(server *dns.Server, errChan chan error) {
|
|
||||||
// DNS server part
|
|
||||||
dns.HandleFunc(".", handleRequest)
|
|
||||||
log.WithFields(log.Fields{"addr": Config.General.Listen, "proto": server.Net}).Info("Listening DNS")
|
|
||||||
err := server.ListenAndServe()
|
|
||||||
if err != nil {
|
|
||||||
errChan <- err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupDNSServer(proto string) *dns.Server {
|
|
||||||
return &dns.Server{Addr: Config.General.Listen, Net: proto}
|
|
||||||
}
|
|
||||||
|
|
||||||
func startHTTPAPI(errChan chan error) {
|
func startHTTPAPI(errChan chan error) {
|
||||||
// Setup http logger
|
// Setup http logger
|
||||||
logger := log.New()
|
logger := log.New()
|
||||||
|
13
main_test.go
13
main_test.go
@ -12,6 +12,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var loghook = new(logrustest.Hook)
|
var loghook = new(logrustest.Hook)
|
||||||
|
var dnsserver *DNSServer
|
||||||
|
|
||||||
var (
|
var (
|
||||||
postgres = flag.Bool("postgres", false, "run integration tests against PostgreSQL")
|
postgres = flag.Bool("postgres", false, "run integration tests against PostgreSQL")
|
||||||
@ -20,6 +21,7 @@ var (
|
|||||||
var records = []string{
|
var records = []string{
|
||||||
"auth.example.org. A 192.168.1.100",
|
"auth.example.org. A 192.168.1.100",
|
||||||
"ns1.auth.example.org. A 192.168.1.101",
|
"ns1.auth.example.org. A 192.168.1.101",
|
||||||
|
"cn.example.org CNAME something.example.org.",
|
||||||
"!''b', unparseable ",
|
"!''b', unparseable ",
|
||||||
"ns2.auth.example.org. A 192.168.1.102",
|
"ns2.auth.example.org. A 192.168.1.102",
|
||||||
}
|
}
|
||||||
@ -27,7 +29,6 @@ var records = []string{
|
|||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
setupTestLogger()
|
setupTestLogger()
|
||||||
setupConfig()
|
setupConfig()
|
||||||
RR.Parse(Config.General)
|
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
newDb := new(acmedb)
|
newDb := new(acmedb)
|
||||||
@ -43,17 +44,19 @@ func TestMain(m *testing.M) {
|
|||||||
_ = newDb.Init("sqlite3", ":memory:")
|
_ = newDb.Init("sqlite3", ":memory:")
|
||||||
}
|
}
|
||||||
DB = newDb
|
DB = newDb
|
||||||
server := setupDNSServer("udp")
|
dnsserver = NewDNSServer(DB, Config.General.Listen, Config.General.Proto)
|
||||||
|
dnsserver.ParseRecords(Config)
|
||||||
|
|
||||||
// Make sure that we're not creating a race condition in tests
|
// Make sure that we're not creating a race condition in tests
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
server.NotifyStartedFunc = func() {
|
dnsserver.Server.NotifyStartedFunc = func() {
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}
|
}
|
||||||
go startDNS(server, make(chan error, 1))
|
go dnsserver.Start(make(chan error, 1))
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
exitval := m.Run()
|
exitval := m.Run()
|
||||||
server.Shutdown()
|
dnsserver.Server.Shutdown()
|
||||||
DB.Close()
|
DB.Close()
|
||||||
os.Exit(exitval)
|
os.Exit(exitval)
|
||||||
}
|
}
|
||||||
|
9
types.go
9
types.go
@ -5,7 +5,6 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/miekg/dns"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config is global configuration struct
|
// Config is global configuration struct
|
||||||
@ -14,14 +13,6 @@ var Config DNSConfig
|
|||||||
// DB is used to access the database functions in acme-dns
|
// DB is used to access the database functions in acme-dns
|
||||||
var DB database
|
var DB database
|
||||||
|
|
||||||
// RR holds the static DNS records
|
|
||||||
var RR Records
|
|
||||||
|
|
||||||
// Records is for static records
|
|
||||||
type Records struct {
|
|
||||||
Records map[uint16]map[string][]dns.RR
|
|
||||||
}
|
|
||||||
|
|
||||||
// DNSConfig holds the config structure
|
// DNSConfig holds the config structure
|
||||||
type DNSConfig struct {
|
type DNSConfig struct {
|
||||||
General general
|
General general
|
||||||
|
Loading…
Reference in New Issue
Block a user