Better error handling in goroutines (#122)

* More robust goroutine error handling using channels

* Fix tests and make startup log msg saner

* Clarification to README and config file
This commit is contained in:
Joona Hoikkala
2018-10-31 00:54:51 +02:00
committed by GitHub
parent a09073da12
commit c2c5c5cd70
6 changed files with 55 additions and 19 deletions

View File

@ -1,6 +1,7 @@
language: go language: go
go: go:
- 1.9 - 1.9
- 1.11
env: env:
- "PATH=/home/travis/gopath/bin:$PATH" - "PATH=/home/travis/gopath/bin:$PATH"
before_install: before_install:

View File

@ -212,7 +212,9 @@ $ dig @auth.example.org d420c923-bbd7-4056-ab64-c3ca54c9b3cf.auth.example.org
```bash ```bash
[general] [general]
# dns interface # DNS interface. Note that systemd-resolved may reserve port 53 on 127.0.0.53
# In this case acme-dns will error out and you will need to define the listening interface
# for example: listen = "127.0.0.1:53"
listen = ":53" listen = ":53"
# protocol, "udp", "udp4", "udp6" or "tcp", "tcp4", "tcp6" # protocol, "udp", "udp4", "udp6" or "tcp", "tcp4", "tcp6"
protocol = "udp" protocol = "udp"

View File

@ -1,5 +1,7 @@
[general] [general]
# dns interface # DNS interface. Note that systemd-resolved may reserve port 53 on 127.0.0.53
# In this case acme-dns will error out and you will need to define the listening interface
# for example: listen = "127.0.0.1:53"
listen = ":53" listen = ":53"
# protocol, "udp", "udp4", "udp6" or "tcp", "tcp4", "tcp6" # protocol, "udp", "udp4", "udp6" or "tcp", "tcp4", "tcp6"
protocol = "udp" protocol = "udp"

43
main.go
View File

@ -11,6 +11,7 @@ import (
"syscall" "syscall"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"github.com/miekg/dns"
"github.com/rs/cors" "github.com/rs/cors"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/acme/autocert"
@ -55,16 +56,41 @@ func main() {
DB = newDB DB = newDB
defer DB.Close() defer DB.Close()
// Error channel for servers
errChan := make(chan error, 1)
// DNS server // DNS server
startDNS(Config.General.Listen, Config.General.Proto) dnsServer := setupDNSServer()
go startDNS(dnsServer, errChan)
// HTTP API // HTTP API
startHTTPAPI() go startHTTPAPI(errChan)
// block waiting for error
select {
case err = <-errChan:
if err != nil {
log.Fatal(err)
}
}
log.Debugf("Shutting down...") log.Debugf("Shutting down...")
} }
func startHTTPAPI() { func startDNS(server *dns.Server, errChan chan error) {
// DNS server part
dns.HandleFunc(".", handleRequest)
log.WithFields(log.Fields{"addr": Config.General.Listen}).Info("Listening DNS")
err := server.ListenAndServe()
if err != nil {
errChan <- err
}
}
func setupDNSServer() *dns.Server {
return &dns.Server{Addr: Config.General.Listen, Net: Config.General.Proto}
}
func startHTTPAPI(errChan chan error) {
// Setup http logger // Setup http logger
logger := log.New() logger := log.New()
logwriter := logger.Writer() logwriter := logger.Writer()
@ -90,7 +116,7 @@ func startHTTPAPI() {
cfg := &tls.Config{ cfg := &tls.Config{
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
} }
var err error
switch Config.API.TLS { switch Config.API.TLS {
case "letsencrypt": case "letsencrypt":
m := autocert.Manager{ m := autocert.Manager{
@ -109,7 +135,7 @@ func startHTTPAPI() {
ErrorLog: stdlog.New(logwriter, "", 0), ErrorLog: stdlog.New(logwriter, "", 0),
} }
log.WithFields(log.Fields{"host": host, "domain": Config.API.Domain}).Info("Listening HTTPS, using certificate from autocert") log.WithFields(log.Fields{"host": host, "domain": Config.API.Domain}).Info("Listening HTTPS, using certificate from autocert")
log.Fatal(srv.ListenAndServeTLS("", "")) err = srv.ListenAndServeTLS("", "")
case "cert": case "cert":
srv := &http.Server{ srv := &http.Server{
Addr: host, Addr: host,
@ -118,9 +144,12 @@ func startHTTPAPI() {
ErrorLog: stdlog.New(logwriter, "", 0), ErrorLog: stdlog.New(logwriter, "", 0),
} }
log.WithFields(log.Fields{"host": host}).Info("Listening HTTPS") log.WithFields(log.Fields{"host": host}).Info("Listening HTTPS")
log.Fatal(srv.ListenAndServeTLS(Config.API.TLSCertFullchain, Config.API.TLSCertPrivkey)) err = srv.ListenAndServeTLS(Config.API.TLSCertFullchain, Config.API.TLSCertPrivkey)
default: default:
log.WithFields(log.Fields{"host": host}).Info("Listening HTTP") log.WithFields(log.Fields{"host": host}).Info("Listening HTTP")
log.Fatal(http.ListenAndServe(host, c.Handler(api))) err = http.ListenAndServe(host, c.Handler(api))
}
if err != nil {
errChan <- err
} }
} }

View File

@ -7,6 +7,7 @@ import (
logrustest "github.com/sirupsen/logrus/hooks/test" logrustest "github.com/sirupsen/logrus/hooks/test"
"io/ioutil" "io/ioutil"
"os" "os"
"sync"
"testing" "testing"
) )
@ -42,7 +43,15 @@ func TestMain(m *testing.M) {
_ = newDb.Init("sqlite3", ":memory:") _ = newDb.Init("sqlite3", ":memory:")
} }
DB = newDb DB = newDb
server := startDNS("0.0.0.0:15353", "udp") server := setupDNSServer()
// Make sure that we're not creating a race condition in tests
var wg sync.WaitGroup
wg.Add(1)
server.NotifyStartedFunc = func() {
wg.Done()
}
go startDNS(server, make(chan error, 1))
wg.Wait()
exitval := m.Run() exitval := m.Run()
server.Shutdown() server.Shutdown()
DB.Close() DB.Close()
@ -57,6 +66,8 @@ func setupConfig() {
var generalcfg = general{ var generalcfg = general{
Domain: "auth.example.org", Domain: "auth.example.org",
Listen: "127.0.0.1:15353",
Proto: "udp",
Nsname: "ns1.auth.example.org", Nsname: "ns1.auth.example.org",
Nsadmin: "admin.example.org", Nsadmin: "admin.example.org",
StaticRecords: records, StaticRecords: records,

View File

@ -10,7 +10,6 @@ import (
"strings" "strings"
"github.com/BurntSushi/toml" "github.com/BurntSushi/toml"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -108,14 +107,6 @@ func setupLogging(format string, level string) {
// TODO: file logging // TODO: file logging
} }
func startDNS(listen string, proto string) *dns.Server {
// DNS server part
dns.HandleFunc(".", handleRequest)
server := &dns.Server{Addr: listen, Net: proto}
go server.ListenAndServe()
return server
}
func getIPListFromHeader(header string) []string { func getIPListFromHeader(header string) []string {
iplist := []string{} iplist := []string{}
for _, v := range strings.Split(header, ",") { for _, v := range strings.Split(header, ",") {