mirror of
https://github.com/joohoi/acme-dns.git
synced 2025-07-13 09:17:47 +07:00
Refactoring DNS server part for safer paraller execution (#144)
* Refactoring DNS server part for safer paraller execution and better data structures * Fix linter issues
This commit is contained in:
202
dns.go
202
dns.go
@ -8,65 +8,80 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func readQuery(m *dns.Msg) {
|
||||
for _, que := range m.Question {
|
||||
if rr, rc, err := answer(que); err == nil {
|
||||
m.MsgHdr.Rcode = rc
|
||||
for _, r := range rr {
|
||||
m.Answer = append(m.Answer, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Records is a slice of ResourceRecords
|
||||
type Records struct {
|
||||
Records []dns.RR
|
||||
}
|
||||
|
||||
func answerTXT(q dns.Question) ([]dns.RR, int, error) {
|
||||
var ra []dns.RR
|
||||
rcode := dns.RcodeNameError
|
||||
subdomain := sanitizeDomainQuestion(q.Name)
|
||||
atxt, err := DB.GetTXTForDomain(subdomain)
|
||||
// DNSServer is the main struct for acme-dns DNS server
|
||||
type DNSServer struct {
|
||||
DB database
|
||||
Server *dns.Server
|
||||
Domains map[string]Records
|
||||
}
|
||||
|
||||
// NewDNSServer parses the DNS records from config and returns a new DNSServer struct
|
||||
func NewDNSServer(db database, addr string, proto string) *DNSServer {
|
||||
var server DNSServer
|
||||
server.Server = &dns.Server{Addr: addr, Net: proto}
|
||||
server.DB = db
|
||||
server.Domains = make(map[string]Records)
|
||||
return &server
|
||||
}
|
||||
|
||||
// Start starts the DNSServer
|
||||
func (d *DNSServer) Start(errorChannel chan error) {
|
||||
// DNS server part
|
||||
dns.HandleFunc(".", d.handleRequest)
|
||||
log.WithFields(log.Fields{"addr": d.Server.Addr, "proto": d.Server.Net}).Info("Listening DNS")
|
||||
err := d.Server.ListenAndServe()
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{"error": err.Error()}).Debug("Error while trying to get record")
|
||||
return ra, dns.RcodeNameError, err
|
||||
errorChannel <- err
|
||||
}
|
||||
for _, v := range atxt {
|
||||
if len(v) > 0 {
|
||||
r := new(dns.TXT)
|
||||
r.Hdr = dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1}
|
||||
r.Txt = append(r.Txt, v)
|
||||
ra = append(ra, r)
|
||||
rcode = dns.RcodeSuccess
|
||||
}
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{"domain": q.Name}).Info("Answering TXT question for domain")
|
||||
return ra, rcode, nil
|
||||
}
|
||||
|
||||
func answer(q dns.Question) ([]dns.RR, int, error) {
|
||||
if q.Qtype == dns.TypeTXT {
|
||||
return answerTXT(q)
|
||||
// ParseRecords parses a slice of DNS record string
|
||||
func (d *DNSServer) ParseRecords(config DNSConfig) {
|
||||
for _, v := range config.General.StaticRecords {
|
||||
rr, err := dns.NewRR(strings.ToLower(v))
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{"error": err.Error(), "rr": v}).Warning("Could not parse RR from config")
|
||||
continue
|
||||
}
|
||||
// Add parsed RR
|
||||
d.appendRR(rr)
|
||||
}
|
||||
var r []dns.RR
|
||||
var rcode = dns.RcodeSuccess
|
||||
var domain = strings.ToLower(q.Name)
|
||||
var rtype = q.Qtype
|
||||
r, ok := RR.Records[rtype][domain]
|
||||
// Create serial
|
||||
serial := time.Now().Format("2006010215")
|
||||
// Add SOA
|
||||
SOAstring := fmt.Sprintf("%s. SOA %s. %s. %s 28800 7200 604800 86400", strings.ToLower(config.General.Domain), strings.ToLower(config.General.Nsname), strings.ToLower(config.General.Nsadmin), serial)
|
||||
soarr, err := dns.NewRR(SOAstring)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{"error": err.Error(), "soa": SOAstring}).Error("Error while adding SOA record")
|
||||
} else {
|
||||
d.appendRR(soarr)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DNSServer) appendRR(rr dns.RR) {
|
||||
addDomain := rr.Header().Name
|
||||
_, ok := d.Domains[addDomain]
|
||||
if !ok {
|
||||
r, ok = RR.Records[dns.TypeCNAME][domain]
|
||||
if !ok {
|
||||
rcode = dns.RcodeNameError
|
||||
}
|
||||
d.Domains[addDomain] = Records{[]dns.RR{rr}}
|
||||
} else {
|
||||
drecs := d.Domains[addDomain]
|
||||
drecs.Records = append(drecs.Records, rr)
|
||||
d.Domains[addDomain] = drecs
|
||||
}
|
||||
log.WithFields(log.Fields{"qtype": dns.TypeToString[rtype], "domain": domain, "rcode": dns.RcodeToString[rcode]}).Debug("Answering question for domain")
|
||||
return r, rcode, nil
|
||||
log.WithFields(log.Fields{"recordtype": dns.TypeToString[rr.Header().Rrtype], "domain": addDomain}).Debug("Adding new record to domain")
|
||||
}
|
||||
|
||||
func handleRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
func (d *DNSServer) handleRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
|
||||
if r.Opcode == dns.OpcodeQuery {
|
||||
readQuery(m)
|
||||
d.readQuery(m)
|
||||
} else if r.Opcode == dns.OpcodeUpdate {
|
||||
log.Debug("Refusing DNS Dynamic update request")
|
||||
m.MsgHdr.Rcode = dns.RcodeRefused
|
||||
@ -75,38 +90,81 @@ func handleRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
w.WriteMsg(m)
|
||||
}
|
||||
|
||||
// Parse config records
|
||||
func (r *Records) Parse(config general) {
|
||||
rrmap := make(map[uint16]map[string][]dns.RR)
|
||||
for _, v := range config.StaticRecords {
|
||||
rr, err := dns.NewRR(strings.ToLower(v))
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{"error": err.Error(), "rr": v}).Warning("Could not parse RR from config")
|
||||
continue
|
||||
func (d *DNSServer) readQuery(m *dns.Msg) {
|
||||
for _, que := range m.Question {
|
||||
if rr, rc, err := d.answer(que); err == nil {
|
||||
m.MsgHdr.Rcode = rc
|
||||
for _, r := range rr {
|
||||
m.Answer = append(m.Answer, r)
|
||||
}
|
||||
}
|
||||
// Add parsed RR to the list
|
||||
rrmap = appendRR(rrmap, rr)
|
||||
}
|
||||
// Create serial
|
||||
serial := time.Now().Format("2006010215")
|
||||
// Add SOA
|
||||
SOAstring := fmt.Sprintf("%s. SOA %s. %s. %s 28800 7200 604800 86400", strings.ToLower(config.Domain), strings.ToLower(config.Nsname), strings.ToLower(config.Nsadmin), serial)
|
||||
soarr, err := dns.NewRR(SOAstring)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{"error": err.Error(), "soa": SOAstring}).Error("Error while adding SOA record")
|
||||
} else {
|
||||
rrmap = appendRR(rrmap, soarr)
|
||||
}
|
||||
r.Records = rrmap
|
||||
}
|
||||
|
||||
func appendRR(rrmap map[uint16]map[string][]dns.RR, rr dns.RR) map[uint16]map[string][]dns.RR {
|
||||
_, ok := rrmap[rr.Header().Rrtype]
|
||||
func (d *DNSServer) getRecord(q dns.Question) ([]dns.RR, error) {
|
||||
var rr []dns.RR
|
||||
var cnames []dns.RR
|
||||
domain, ok := d.Domains[q.Name]
|
||||
if !ok {
|
||||
newrr := make(map[string][]dns.RR)
|
||||
rrmap[rr.Header().Rrtype] = newrr
|
||||
return rr, fmt.Errorf("No records for domain %s", q.Name)
|
||||
}
|
||||
rrmap[rr.Header().Rrtype][rr.Header().Name] = append(rrmap[rr.Header().Rrtype][rr.Header().Name], rr)
|
||||
log.WithFields(log.Fields{"recordtype": dns.TypeToString[rr.Header().Rrtype], "domain": rr.Header().Name}).Debug("Adding new record type to domain")
|
||||
return rrmap
|
||||
for _, ri := range domain.Records {
|
||||
if ri.Header().Rrtype == q.Qtype {
|
||||
rr = append(rr, ri)
|
||||
}
|
||||
if ri.Header().Rrtype == dns.TypeCNAME {
|
||||
cnames = append(cnames, ri)
|
||||
}
|
||||
}
|
||||
if len(rr) == 0 {
|
||||
return cnames, nil
|
||||
}
|
||||
return rr, nil
|
||||
}
|
||||
|
||||
// answeringForDomain checks if we have any records for a domain
|
||||
func (d *DNSServer) answeringForDomain(q dns.Question) bool {
|
||||
_, ok := d.Domains[q.Name]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (d *DNSServer) answer(q dns.Question) ([]dns.RR, int, error) {
|
||||
var rcode int
|
||||
if !d.answeringForDomain(q) {
|
||||
rcode = dns.RcodeNameError
|
||||
}
|
||||
r, _ := d.getRecord(q)
|
||||
if q.Qtype == dns.TypeTXT {
|
||||
txtRRs, err := d.answerTXT(q)
|
||||
if err == nil {
|
||||
for _, txtRR := range txtRRs {
|
||||
r = append(r, txtRR)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(r) > 0 {
|
||||
// Make sure that we return NOERROR if there were dynamic records for the domain
|
||||
rcode = dns.RcodeSuccess
|
||||
}
|
||||
log.WithFields(log.Fields{"qtype": dns.TypeToString[q.Qtype], "domain": q.Name, "rcode": dns.RcodeToString[rcode]}).Debug("Answering question for domain")
|
||||
return r, rcode, nil
|
||||
}
|
||||
|
||||
func (d *DNSServer) answerTXT(q dns.Question) ([]dns.RR, error) {
|
||||
var ra []dns.RR
|
||||
subdomain := sanitizeDomainQuestion(q.Name)
|
||||
atxt, err := d.DB.GetTXTForDomain(subdomain)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{"error": err.Error()}).Debug("Error while trying to get record")
|
||||
return ra, err
|
||||
}
|
||||
for _, v := range atxt {
|
||||
if len(v) > 0 {
|
||||
r := new(dns.TXT)
|
||||
r.Hdr = dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1}
|
||||
r.Txt = append(r.Txt, v)
|
||||
ra = append(ra, r)
|
||||
}
|
||||
}
|
||||
return ra, nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user