Refactoring

This commit is contained in:
Joona Hoikkala
2016-11-23 17:11:31 +02:00
parent f32c4940e1
commit ba63bad793
8 changed files with 66 additions and 58 deletions

12
api.go
View File

@ -24,10 +24,10 @@ func (a AuthMiddleware) Serve(ctx *iris.Context) {
password := ctx.RequestHeader("X-Api-Key")
postData := ACMETxt{}
username, err := GetValidUsername(usernameStr)
if err == nil && ValidKey(password) {
username, err := getValidUsername(usernameStr)
if err == nil && validKey(password) {
au, err := DB.GetByUsername(username)
if err == nil && CorrectPassword(password, au.Password) {
if err == nil && correctPassword(password, au.Password) {
// Password ok
if err := ctx.ReadJSON(&postData); err == nil {
// Check that the subdomain belongs to the user
@ -39,7 +39,7 @@ func (a AuthMiddleware) Serve(ctx *iris.Context) {
}
}
// To protect against timed side channel (never gonna give you up)
CorrectPassword(password, "$2a$10$8JEFVNYYhLoBysjAxe2yBuXrkDojBQBkVpXEQgyQyjn43SvJ4vL36")
correctPassword(password, "$2a$10$8JEFVNYYhLoBysjAxe2yBuXrkDojBQBkVpXEQgyQyjn43SvJ4vL36")
}
ctx.JSON(iris.StatusUnauthorized, iris.Map{"error": "unauthorized"})
}
@ -72,7 +72,7 @@ func WebUpdatePost(ctx *iris.Context) {
// User auth done in middleware
a := ACMETxt{}
userStr := ctx.RequestHeader("X-API-User")
username, err := GetValidUsername(userStr)
username, err := getValidUsername(userStr)
if err != nil {
log.Warningf("Error while getting username [%s]. This should never happen because of auth middlware.", userStr)
WebUpdatePostError(ctx, err, iris.StatusUnauthorized)
@ -86,7 +86,7 @@ func WebUpdatePost(ctx *iris.Context) {
}
a.Username = username
// Do update
if ValidSubdomain(a.Subdomain) && ValidTXT(a.Value) {
if validSubdomain(a.Subdomain) && validTXT(a.Value) {
err := DB.Update(a)
if err != nil {
log.Warningf("Error trying to update [%v]", err)

4
db.go
View File

@ -49,7 +49,7 @@ func (d *Database) Init(engine string, connection string) error {
}
func (d *Database) Register() (ACMETxt, error) {
a, err := NewACMETxt()
a, err := newACMETxt()
if err != nil {
return ACMETxt{}, err
}
@ -121,7 +121,7 @@ func (d *Database) GetByUsername(u uuid.UUID) (ACMETxt, error) {
}
func (d *Database) GetByDomain(domain string) ([]ACMETxt, error) {
domain = SanitizeString(domain)
domain = sanitizeString(domain)
log.Debugf("Trying to select domain [%s] from table", domain)
var a []ACMETxt
getSQL := `

View File

@ -66,7 +66,7 @@ func TestGetByUsername(t *testing.T) {
}
// regUser password already is a bcrypt hash
if !CorrectPassword(reg.Password, regUser.Password) {
if !correctPassword(reg.Password, regUser.Password) {
t.Errorf("The password [%s] does not match the hash [%s]", reg.Password, regUser.Password)
}
}
@ -113,7 +113,7 @@ func TestGetByDomain(t *testing.T) {
}
// regDomain password already is a bcrypt hash
if !CorrectPassword(reg.Password, regDomain.Password) {
if !correctPassword(reg.Password, regDomain.Password) {
t.Errorf("The password [%s] does not match the hash [%s]", reg.Password, regDomain.Password)
}

2
dns.go
View File

@ -23,7 +23,7 @@ func answerTXT(q dns.Question) ([]dns.RR, int, error) {
var rcode int = dns.RcodeNameError
var domain = strings.ToLower(q.Name)
atxt, err := DB.GetByDomain(SanitizeDomainQuestion(domain))
atxt, err := DB.GetByDomain(sanitizeDomainQuestion(domain))
if err != nil {
log.Errorf("Error while trying to get record [%v]", err)
return ra, dns.RcodeNameError, err

32
main.go
View File

@ -28,36 +28,8 @@ func main() {
os.Exit(1)
}
DNSConf = configTmp
// Setup logging
var logformat = logging.MustStringFormatter(DNSConf.Logconfig.Format)
var logBackend *logging.LogBackend
switch DNSConf.Logconfig.Logtype {
default:
// Setup logging - stdout
logBackend = logging.NewLogBackend(os.Stdout, "", 0)
case "file":
// Logging to file
logfh, err := os.OpenFile(DNSConf.Logconfig.File, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
if err != nil {
fmt.Printf("Could not open log file %s\n", DNSConf.Logconfig.File)
os.Exit(1)
}
defer logfh.Close()
logBackend = logging.NewLogBackend(logfh, "", 0)
}
logFormatter := logging.NewBackendFormatter(logBackend, logformat)
logLevel := logging.AddModuleLevel(logFormatter)
switch DNSConf.Logconfig.Level {
default:
logLevel.SetLevel(logging.DEBUG, "")
case "warning":
logLevel.SetLevel(logging.WARNING, "")
case "error":
logLevel.SetLevel(logging.ERROR, "")
case "info":
logLevel.SetLevel(logging.INFO, "")
}
logging.SetBackend(logFormatter)
setupLogging()
// Read the default records in
RR.Parse(DNSConf.General.StaticRecords)

46
util.go
View File

@ -3,9 +3,12 @@ package main
import (
"crypto/rand"
"errors"
"fmt"
"github.com/BurntSushi/toml"
"github.com/op/go-logging"
"github.com/satori/go.uuid"
"math/big"
"os"
"regexp"
"strings"
)
@ -18,7 +21,7 @@ func readConfig(fname string) (DNSConfig, error) {
return conf, nil
}
func SanitizeString(s string) string {
func sanitizeString(s string) string {
// URL safe base64 alphabet without padding as defined in ACME
re, err := regexp.Compile("[^A-Za-z\\-\\_0-9]+")
if err != nil {
@ -28,7 +31,7 @@ func SanitizeString(s string) string {
return re.ReplaceAllString(s, "")
}
func GeneratePassword(length int) (string, error) {
func generatePassword(length int) (string, error) {
ret := make([]byte, length)
const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz1234567890-_"
alphalen := big.NewInt(int64(len(alphabet)))
@ -43,7 +46,7 @@ func GeneratePassword(length int) (string, error) {
return string(ret), nil
}
func SanitizeDomainQuestion(d string) string {
func sanitizeDomainQuestion(d string) string {
var dom string
suffix := DNSConf.General.Domain + "."
if strings.HasSuffix(d, suffix) {
@ -54,9 +57,9 @@ func SanitizeDomainQuestion(d string) string {
return dom
}
func NewACMETxt() (ACMETxt, error) {
func newACMETxt() (ACMETxt, error) {
var a = ACMETxt{}
password, err := GeneratePassword(40)
password, err := generatePassword(40)
if err != nil {
return a, err
}
@ -65,3 +68,36 @@ func NewACMETxt() (ACMETxt, error) {
a.Subdomain = uuid.NewV4().String()
return a, nil
}
func setupLogging() {
var logformat = logging.MustStringFormatter(DNSConf.Logconfig.Format)
var logBackend *logging.LogBackend
switch DNSConf.Logconfig.Logtype {
default:
// Setup logging - stdout
logBackend = logging.NewLogBackend(os.Stdout, "", 0)
case "file":
// Logging to file
logfh, err := os.OpenFile(DNSConf.Logconfig.File, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
if err != nil {
fmt.Printf("Could not open log file %s\n", DNSConf.Logconfig.File)
os.Exit(1)
}
defer logfh.Close()
logBackend = logging.NewLogBackend(logfh, "", 0)
}
logFormatter := logging.NewBackendFormatter(logBackend, logformat)
logLevel := logging.AddModuleLevel(logFormatter)
switch DNSConf.Logconfig.Level {
default:
logLevel.SetLevel(logging.DEBUG, "")
case "warning":
logLevel.SetLevel(logging.WARNING, "")
case "error":
logLevel.SetLevel(logging.ERROR, "")
case "info":
logLevel.SetLevel(logging.INFO, "")
}
logging.SetBackend(logFormatter)
}

View File

@ -6,7 +6,7 @@ import (
"unicode/utf8"
)
func GetValidUsername(u string) (uuid.UUID, error) {
func getValidUsername(u string) (uuid.UUID, error) {
uname, err := uuid.FromString(u)
if err != nil {
return uuid.UUID{}, err
@ -14,8 +14,8 @@ func GetValidUsername(u string) (uuid.UUID, error) {
return uname, nil
}
func ValidKey(k string) bool {
kn := SanitizeString(k)
func validKey(k string) bool {
kn := sanitizeString(k)
if utf8.RuneCountInString(k) == 40 && utf8.RuneCountInString(kn) == 40 {
// Correct length and all chars valid
return true
@ -23,7 +23,7 @@ func ValidKey(k string) bool {
return false
}
func ValidSubdomain(s string) bool {
func validSubdomain(s string) bool {
_, err := uuid.FromString(s)
if err == nil {
return true
@ -31,8 +31,8 @@ func ValidSubdomain(s string) bool {
return false
}
func ValidTXT(s string) bool {
sn := SanitizeString(s)
func validTXT(s string) bool {
sn := sanitizeString(s)
if utf8.RuneCountInString(s) == 43 && utf8.RuneCountInString(sn) == 43 {
// 43 chars is the current LE auth key size, but not limited / defined by ACME
return true
@ -40,7 +40,7 @@ func ValidTXT(s string) bool {
return false
}
func CorrectPassword(pw string, hash string) bool {
func correctPassword(pw string, hash string) bool {
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(pw)); err == nil {
return true
}

View File

@ -17,7 +17,7 @@ func TestGetValidUsername(t *testing.T) {
{"", uuid.UUID{}, true},
{"&!#!25123!%!'%", uuid.UUID{}, true},
} {
ret, err := GetValidUsername(test.uname)
ret, err := getValidUsername(test.uname)
if test.shouldErr && err == nil {
t.Errorf("Test %d: Expected error, but there was none", i)
}
@ -41,7 +41,7 @@ func TestValidKey(t *testing.T) {
{"aaaaaaaa-aaa-aaaaaa#aaaaaaaa-aaa_aacaaaa", false},
{"aaaaaaaa-aaa-aaaaaa-aaaaaaaa-aaa_aacaaaaa", false},
} {
ret := ValidKey(test.key)
ret := validKey(test.key)
if ret != test.output {
t.Errorf("Test %d: Expected return value %t, but got %t", i, test.output, ret)
}
@ -58,7 +58,7 @@ func TestGetValidSubdomain(t *testing.T) {
{"", false},
{"&!#!25123!%!'%", false},
} {
ret := ValidSubdomain(test.subdomain)
ret := validSubdomain(test.subdomain)
if ret != test.output {
t.Errorf("Test %d: Expected return value %t, but got %t", i, test.output, ret)
}
@ -76,7 +76,7 @@ func TestValidTXT(t *testing.T) {
{"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", false},
{"", false},
} {
ret := ValidTXT(test.txt)
ret := validTXT(test.txt)
if ret != test.output {
t.Errorf("Test %d: Expected return value %t, but got %t", i, test.output, ret)
}
@ -100,7 +100,7 @@ func TestCorrectPassword(t *testing.T) {
false},
{"", "", false},
} {
ret := CorrectPassword(test.pw, test.hash)
ret := correctPassword(test.pw, test.hash)
if ret != test.output {
t.Errorf("Test %d: Expected return value %t, but got %t", i, test.output, ret)
}