Refactoring

This commit is contained in:
Joona Hoikkala
2016-11-27 21:21:38 +02:00
parent bff6310dc7
commit 74b82c87a6
6 changed files with 31 additions and 47 deletions

View File

@ -1,4 +1,6 @@
[general] [general]
# dns interface
listen = ":53"
# domain name to serve th requests off of # domain name to serve th requests off of
domain = "auth.example.org" domain = "auth.example.org"
# zone name server # zone name server

View File

@ -3,7 +3,6 @@ package main
import ( import (
"errors" "errors"
"fmt" "fmt"
log "github.com/Sirupsen/logrus"
"github.com/miekg/dns" "github.com/miekg/dns"
"strings" "strings"
"testing" "testing"
@ -74,22 +73,8 @@ func findRecordFromMemory(rrstr string, host string, qtype uint16) error {
return errors.New(errmsg) return errors.New(errmsg)
} }
func startDNSServer(addr string) (*dns.Server, resolver) {
// DNS server part
dns.HandleFunc(".", handleRequest)
server := &dns.Server{Addr: addr, Net: "udp"}
go func() {
err := server.ListenAndServe()
if err != nil {
log.Errorf("%v", err)
}
}()
return server, resolver{server: addr}
}
func TestResolveA(t *testing.T) { func TestResolveA(t *testing.T) {
setupConfig() resolv := resolver{server: "0.0.0.0:15353"}
answer, err := resolv.lookup("auth.example.org", dns.TypeA) answer, err := resolv.lookup("auth.example.org", dns.TypeA)
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("%v", err)
@ -107,8 +92,7 @@ func TestResolveA(t *testing.T) {
} }
func TestResolveTXT(t *testing.T) { func TestResolveTXT(t *testing.T) {
setupConfig() resolv := resolver{server: "0.0.0.0:15353"}
validTXT := "______________valid_response_______________" validTXT := "______________valid_response_______________"
atxt, err := DB.Register() atxt, err := DB.Register()

15
main.go
View File

@ -5,7 +5,6 @@ import (
log "github.com/Sirupsen/logrus" log "github.com/Sirupsen/logrus"
"github.com/iris-contrib/middleware/cors" "github.com/iris-contrib/middleware/cors"
"github.com/kataras/iris" "github.com/kataras/iris"
"github.com/miekg/dns"
"os" "os"
) )
@ -27,7 +26,7 @@ func main() {
} }
DNSConf = configTmp DNSConf = configTmp
setupLogging() setupLogging(DNSConf.Logconfig.Format, DNSConf.Logconfig.Level)
// Read the default records in // Read the default records in
RR.Parse(DNSConf.General.StaticRecords) RR.Parse(DNSConf.General.StaticRecords)
@ -40,16 +39,8 @@ func main() {
} }
defer DB.DB.Close() defer DB.DB.Close()
// DNS server part // DNS server
dns.HandleFunc(".", handleRequest) startDNS(DNSConf.General.Listen)
server := &dns.Server{Addr: ":53", Net: "udp"}
go func() {
err = server.ListenAndServe()
if err != nil {
log.Errorf("%v", err)
os.Exit(1)
}
}()
// API server and endpoints // API server and endpoints
api := iris.New() api := iris.New()

View File

@ -28,7 +28,7 @@ func TestMain(m *testing.M) {
_ = DB.Init("sqlite3", ":memory:") _ = DB.Init("sqlite3", ":memory:")
} }
server, resolv = startDNSServer("0.0.0.0:15353") server := startDNS("0.0.0.0:15353")
exitval := m.Run() exitval := m.Run()
server.Shutdown() server.Shutdown()
DB.DB.Close() DB.DB.Close()

View File

@ -23,6 +23,7 @@ type authMiddleware struct{}
// Config file general section // Config file general section
type general struct { type general struct {
Listen string
Domain string Domain string
Nsname string Nsname string
Nsadmin string Nsadmin string

38
util.go
View File

@ -6,8 +6,10 @@ import (
"fmt" "fmt"
"github.com/BurntSushi/toml" "github.com/BurntSushi/toml"
log "github.com/Sirupsen/logrus" log "github.com/Sirupsen/logrus"
"github.com/miekg/dns"
"github.com/satori/go.uuid" "github.com/satori/go.uuid"
"math/big" "math/big"
"os"
"regexp" "regexp"
"strings" "strings"
) )
@ -22,27 +24,20 @@ func readConfig(fname string) (DNSConfig, error) {
func sanitizeString(s string) string { func sanitizeString(s string) string {
// URL safe base64 alphabet without padding as defined in ACME // URL safe base64 alphabet without padding as defined in ACME
re, err := regexp.Compile("[^A-Za-z\\-\\_0-9]+") re, _ := regexp.Compile("[^A-Za-z\\-\\_0-9]+")
if err != nil {
log.Errorf("%v", err)
return ""
}
return re.ReplaceAllString(s, "") return re.ReplaceAllString(s, "")
} }
func generatePassword(length int) (string, error) { func generatePassword(length int) string {
ret := make([]byte, length) ret := make([]byte, length)
const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz1234567890-_" const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz1234567890-_"
alphalen := big.NewInt(int64(len(alphabet))) alphalen := big.NewInt(int64(len(alphabet)))
for i := 0; i < length; i++ { for i := 0; i < length; i++ {
c, err := rand.Int(rand.Reader, alphalen) c, _ := rand.Int(rand.Reader, alphalen)
if err != nil {
return "", err
}
r := int(c.Int64()) r := int(c.Int64())
ret[i] = alphabet[r] ret[i] = alphabet[r]
} }
return string(ret), nil return string(ret)
} }
func sanitizeDomainQuestion(d string) string { func sanitizeDomainQuestion(d string) string {
@ -57,17 +52,14 @@ func sanitizeDomainQuestion(d string) string {
func newACMETxt() (ACMETxt, error) { func newACMETxt() (ACMETxt, error) {
var a = ACMETxt{} var a = ACMETxt{}
password, err := generatePassword(40) password := generatePassword(40)
if err != nil {
return a, err
}
a.Username = uuid.NewV4() a.Username = uuid.NewV4()
a.Password = password a.Password = password
a.Subdomain = uuid.NewV4().String() a.Subdomain = uuid.NewV4().String()
return a, nil return a, nil
} }
func setupLogging() { func setupLogging(format string, level string) {
if DNSConf.Logconfig.Format == "json" { if DNSConf.Logconfig.Format == "json" {
log.SetFormatter(&log.JSONFormatter{}) log.SetFormatter(&log.JSONFormatter{})
} }
@ -83,3 +75,17 @@ func setupLogging() {
} }
// TODO: file logging // TODO: file logging
} }
func startDNS(listen string) *dns.Server {
// DNS server part
dns.HandleFunc(".", handleRequest)
server := &dns.Server{Addr: listen, Net: "udp"}
go func() {
err := server.ListenAndServe()
if err != nil {
log.Errorf("%v", err)
os.Exit(1)
}
}()
return server
}