Refactoring and comments

This commit is contained in:
Joona Hoikkala 2016-11-23 18:07:38 +02:00
parent ba63bad793
commit 670c20f904
5 changed files with 36 additions and 46 deletions

34
api.go
View File

@ -6,20 +6,8 @@ import (
"github.com/kataras/iris"
)
func GetHandlerMap() map[string]func(*iris.Context) {
return map[string]func(*iris.Context){
"/register": WebRegisterGet,
}
}
func PostHandlerMap() map[string]func(*iris.Context) {
return map[string]func(*iris.Context){
"/register": WebRegisterPost,
"/update": WebUpdatePost,
}
}
func (a AuthMiddleware) Serve(ctx *iris.Context) {
// Serve is an authentication middlware function used to authenticate update requests
func (a authMiddleware) Serve(ctx *iris.Context) {
usernameStr := ctx.RequestHeader("X-Api-User")
password := ctx.RequestHeader("X-Api-Key")
postData := ACMETxt{}
@ -44,7 +32,7 @@ func (a AuthMiddleware) Serve(ctx *iris.Context) {
ctx.JSON(iris.StatusUnauthorized, iris.Map{"error": "unauthorized"})
}
func WebRegisterPost(ctx *iris.Context) {
func webRegisterPost(ctx *iris.Context) {
// Create new user
nu, err := DB.Register()
var regJSON iris.Map
@ -63,25 +51,25 @@ func WebRegisterPost(ctx *iris.Context) {
ctx.JSON(regStatus, regJSON)
}
func WebRegisterGet(ctx *iris.Context) {
func webRegisterGet(ctx *iris.Context) {
// This is placeholder for now
WebRegisterPost(ctx)
webRegisterPost(ctx)
}
func WebUpdatePost(ctx *iris.Context) {
func webUpdatePost(ctx *iris.Context) {
// User auth done in middleware
a := ACMETxt{}
userStr := ctx.RequestHeader("X-API-User")
username, err := getValidUsername(userStr)
if err != nil {
log.Warningf("Error while getting username [%s]. This should never happen because of auth middlware.", userStr)
WebUpdatePostError(ctx, err, iris.StatusUnauthorized)
webUpdatePostError(ctx, err, iris.StatusUnauthorized)
return
}
if err := ctx.ReadJSON(&a); err != nil {
// Handle bad post data
log.Warningf("Could not unmarshal: [%v]", err)
WebUpdatePostError(ctx, err, iris.StatusBadRequest)
webUpdatePostError(ctx, err, iris.StatusBadRequest)
return
}
a.Username = username
@ -90,18 +78,18 @@ func WebUpdatePost(ctx *iris.Context) {
err := DB.Update(a)
if err != nil {
log.Warningf("Error trying to update [%v]", err)
WebUpdatePostError(ctx, errors.New("internal error"), iris.StatusInternalServerError)
webUpdatePostError(ctx, errors.New("internal error"), iris.StatusInternalServerError)
return
}
ctx.JSON(iris.StatusOK, iris.Map{"txt": a.Value})
} else {
log.Warningf("Bad data, subdomain: [%s], txt: [%s]", a.Subdomain, a.Value)
WebUpdatePostError(ctx, errors.New("bad data"), iris.StatusBadRequest)
webUpdatePostError(ctx, errors.New("bad data"), iris.StatusBadRequest)
return
}
}
func WebUpdatePostError(ctx *iris.Context, err error, status int) {
func webUpdatePostError(ctx *iris.Context, err error, status int) {
errStr := fmt.Sprintf("%v", err)
updJSON := iris.Map{"error": errStr}
ctx.JSON(status, updJSON)

12
db.go
View File

@ -11,7 +11,7 @@ import (
"time"
)
type Database struct {
type database struct {
DB *sql.DB
}
@ -34,7 +34,7 @@ func getSQLiteStmt(s string) string {
return re.ReplaceAllString(s, "?")
}
func (d *Database) Init(engine string, connection string) error {
func (d *database) Init(engine string, connection string) error {
db, err := sql.Open(engine, connection)
if err != nil {
return err
@ -48,7 +48,7 @@ func (d *Database) Init(engine string, connection string) error {
return nil
}
func (d *Database) Register() (ACMETxt, error) {
func (d *database) Register() (ACMETxt, error) {
a, err := newACMETxt()
if err != nil {
return ACMETxt{}, err
@ -78,7 +78,7 @@ func (d *Database) Register() (ACMETxt, error) {
return a, nil
}
func (d *Database) GetByUsername(u uuid.UUID) (ACMETxt, error) {
func (d *database) GetByUsername(u uuid.UUID) (ACMETxt, error) {
var results []ACMETxt
getSQL := `
SELECT Username, Password, Subdomain, Value, LastActive
@ -120,7 +120,7 @@ func (d *Database) GetByUsername(u uuid.UUID) (ACMETxt, error) {
return ACMETxt{}, errors.New("no user")
}
func (d *Database) GetByDomain(domain string) ([]ACMETxt, error) {
func (d *database) GetByDomain(domain string) ([]ACMETxt, error) {
domain = sanitizeString(domain)
log.Debugf("Trying to select domain [%s] from table", domain)
var a []ACMETxt
@ -155,7 +155,7 @@ func (d *Database) GetByDomain(domain string) ([]ACMETxt, error) {
return a, nil
}
func (d *Database) Update(a ACMETxt) error {
func (d *database) Update(a ACMETxt) error {
// Data in a is already sanitized
log.Debugf("Trying to update domain [%s] with TXT data [%s]", a.Subdomain, a.Value)
timenow := time.Now().Unix()

12
dns.go
View File

@ -20,7 +20,7 @@ func readQuery(m *dns.Msg) {
func answerTXT(q dns.Question) ([]dns.RR, int, error) {
var ra []dns.RR
var rcode int = dns.RcodeNameError
var rcode = dns.RcodeNameError
var domain = strings.ToLower(q.Name)
atxt, err := DB.GetByDomain(sanitizeDomainQuestion(domain))
@ -46,9 +46,9 @@ func answer(q dns.Question) ([]dns.RR, int, error) {
return answerTXT(q)
}
var r []dns.RR
var rcode int = dns.RcodeSuccess
var rcode = dns.RcodeSuccess
var domain = strings.ToLower(q.Name)
var rtype uint16 = q.Qtype
var rtype = q.Qtype
r, ok := RR.Records[rtype][domain]
if !ok {
rcode = dns.RcodeNameError
@ -78,7 +78,7 @@ func (r *Records) Parse(recs []string) {
continue
}
// Add parsed RR to the list
rrmap = AppendRR(rrmap, rr)
rrmap = appendRR(rrmap, rr)
}
// Create serial
serial := time.Now().Format("2006010215")
@ -88,12 +88,12 @@ func (r *Records) Parse(recs []string) {
if err != nil {
log.Errorf("Error [%v] while trying to add SOA record: [%s]", err, SOAstring)
} else {
rrmap = AppendRR(rrmap, soarr)
rrmap = appendRR(rrmap, soarr)
}
r.Records = rrmap
}
func AppendRR(rrmap map[uint16]map[string][]dns.RR, rr dns.RR) map[uint16]map[string][]dns.RR {
func appendRR(rrmap map[uint16]map[string][]dns.RR, rr dns.RR) map[uint16]map[string][]dns.RR {
_, ok := rrmap[rr.Header().Rrtype]
if !ok {
newrr := make(map[string][]dns.RR)

15
main.go
View File

@ -12,12 +12,13 @@ import (
// Logging config
var log = logging.MustGetLogger("acme-dns")
// Global configuration struct
// DNSConf is global configuration struct
var DNSConf DNSConfig
var DB Database
// DB is used to access the database functions in acme-dns
var DB database
// Static records
// RR holds the static DNS records
var RR Records
func main() {
@ -63,10 +64,10 @@ func main() {
Debug: DNSConf.General.Debug,
})
api.Use(crs)
var ForceAuth = AuthMiddleware{}
api.Get("/register", WebRegisterGet)
api.Post("/register", WebRegisterPost)
api.Post("/update", ForceAuth.Serve, WebUpdatePost)
var ForceAuth = authMiddleware{}
api.Get("/register", webRegisterGet)
api.Post("/register", webRegisterPost)
api.Post("/update", ForceAuth.Serve, webUpdatePost)
// TODO: migrate to api.Serve(iris.LETSENCRYPTPROD("mydomain.com"))
switch DNSConf.API.TLS {
case "letsencrypt":

View File

@ -5,12 +5,12 @@ import (
"github.com/satori/go.uuid"
)
// Static records
// Records is for static records
type Records struct {
Records map[uint16]map[string][]dns.RR
}
// Config file main struct
// DNSConfig holds the config structure
type DNSConfig struct {
General general
Database dbsettings
@ -19,7 +19,7 @@ type DNSConfig struct {
}
// Auth middleware
type AuthMiddleware struct{}
type authMiddleware struct{}
// Config file general section
type general struct {
@ -53,7 +53,7 @@ type logconfig struct {
Format string `toml:"logformat"`
}
// The default object
// ACMETxt is the default structure for the user controlled record
type ACMETxt struct {
Username uuid.UUID
Password string
@ -61,6 +61,7 @@ type ACMETxt struct {
LastActive int64
}
// ACMETxtPost holds the DNS part of the ACMETxt struct
type ACMETxtPost struct {
Subdomain string `json:"subdomain"`
Value string `json:"txt"`