mirror of
https://github.com/joohoi/acme-dns.git
synced 2025-07-28 13:47:44 +07:00
Refactoring
This commit is contained in:
12
api.go
12
api.go
@ -24,10 +24,10 @@ func (a AuthMiddleware) Serve(ctx *iris.Context) {
|
||||
password := ctx.RequestHeader("X-Api-Key")
|
||||
postData := ACMETxt{}
|
||||
|
||||
username, err := GetValidUsername(usernameStr)
|
||||
if err == nil && ValidKey(password) {
|
||||
username, err := getValidUsername(usernameStr)
|
||||
if err == nil && validKey(password) {
|
||||
au, err := DB.GetByUsername(username)
|
||||
if err == nil && CorrectPassword(password, au.Password) {
|
||||
if err == nil && correctPassword(password, au.Password) {
|
||||
// Password ok
|
||||
if err := ctx.ReadJSON(&postData); err == nil {
|
||||
// Check that the subdomain belongs to the user
|
||||
@ -39,7 +39,7 @@ func (a AuthMiddleware) Serve(ctx *iris.Context) {
|
||||
}
|
||||
}
|
||||
// To protect against timed side channel (never gonna give you up)
|
||||
CorrectPassword(password, "$2a$10$8JEFVNYYhLoBysjAxe2yBuXrkDojBQBkVpXEQgyQyjn43SvJ4vL36")
|
||||
correctPassword(password, "$2a$10$8JEFVNYYhLoBysjAxe2yBuXrkDojBQBkVpXEQgyQyjn43SvJ4vL36")
|
||||
}
|
||||
ctx.JSON(iris.StatusUnauthorized, iris.Map{"error": "unauthorized"})
|
||||
}
|
||||
@ -72,7 +72,7 @@ func WebUpdatePost(ctx *iris.Context) {
|
||||
// User auth done in middleware
|
||||
a := ACMETxt{}
|
||||
userStr := ctx.RequestHeader("X-API-User")
|
||||
username, err := GetValidUsername(userStr)
|
||||
username, err := getValidUsername(userStr)
|
||||
if err != nil {
|
||||
log.Warningf("Error while getting username [%s]. This should never happen because of auth middlware.", userStr)
|
||||
WebUpdatePostError(ctx, err, iris.StatusUnauthorized)
|
||||
@ -86,7 +86,7 @@ func WebUpdatePost(ctx *iris.Context) {
|
||||
}
|
||||
a.Username = username
|
||||
// Do update
|
||||
if ValidSubdomain(a.Subdomain) && ValidTXT(a.Value) {
|
||||
if validSubdomain(a.Subdomain) && validTXT(a.Value) {
|
||||
err := DB.Update(a)
|
||||
if err != nil {
|
||||
log.Warningf("Error trying to update [%v]", err)
|
||||
|
4
db.go
4
db.go
@ -49,7 +49,7 @@ func (d *Database) Init(engine string, connection string) error {
|
||||
}
|
||||
|
||||
func (d *Database) Register() (ACMETxt, error) {
|
||||
a, err := NewACMETxt()
|
||||
a, err := newACMETxt()
|
||||
if err != nil {
|
||||
return ACMETxt{}, err
|
||||
}
|
||||
@ -121,7 +121,7 @@ func (d *Database) GetByUsername(u uuid.UUID) (ACMETxt, error) {
|
||||
}
|
||||
|
||||
func (d *Database) GetByDomain(domain string) ([]ACMETxt, error) {
|
||||
domain = SanitizeString(domain)
|
||||
domain = sanitizeString(domain)
|
||||
log.Debugf("Trying to select domain [%s] from table", domain)
|
||||
var a []ACMETxt
|
||||
getSQL := `
|
||||
|
@ -66,7 +66,7 @@ func TestGetByUsername(t *testing.T) {
|
||||
}
|
||||
|
||||
// regUser password already is a bcrypt hash
|
||||
if !CorrectPassword(reg.Password, regUser.Password) {
|
||||
if !correctPassword(reg.Password, regUser.Password) {
|
||||
t.Errorf("The password [%s] does not match the hash [%s]", reg.Password, regUser.Password)
|
||||
}
|
||||
}
|
||||
@ -113,7 +113,7 @@ func TestGetByDomain(t *testing.T) {
|
||||
}
|
||||
|
||||
// regDomain password already is a bcrypt hash
|
||||
if !CorrectPassword(reg.Password, regDomain.Password) {
|
||||
if !correctPassword(reg.Password, regDomain.Password) {
|
||||
t.Errorf("The password [%s] does not match the hash [%s]", reg.Password, regDomain.Password)
|
||||
}
|
||||
|
||||
|
2
dns.go
2
dns.go
@ -23,7 +23,7 @@ func answerTXT(q dns.Question) ([]dns.RR, int, error) {
|
||||
var rcode int = dns.RcodeNameError
|
||||
var domain = strings.ToLower(q.Name)
|
||||
|
||||
atxt, err := DB.GetByDomain(SanitizeDomainQuestion(domain))
|
||||
atxt, err := DB.GetByDomain(sanitizeDomainQuestion(domain))
|
||||
if err != nil {
|
||||
log.Errorf("Error while trying to get record [%v]", err)
|
||||
return ra, dns.RcodeNameError, err
|
||||
|
32
main.go
32
main.go
@ -28,36 +28,8 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
DNSConf = configTmp
|
||||
// Setup logging
|
||||
var logformat = logging.MustStringFormatter(DNSConf.Logconfig.Format)
|
||||
var logBackend *logging.LogBackend
|
||||
switch DNSConf.Logconfig.Logtype {
|
||||
default:
|
||||
// Setup logging - stdout
|
||||
logBackend = logging.NewLogBackend(os.Stdout, "", 0)
|
||||
case "file":
|
||||
// Logging to file
|
||||
logfh, err := os.OpenFile(DNSConf.Logconfig.File, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
fmt.Printf("Could not open log file %s\n", DNSConf.Logconfig.File)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer logfh.Close()
|
||||
logBackend = logging.NewLogBackend(logfh, "", 0)
|
||||
}
|
||||
logFormatter := logging.NewBackendFormatter(logBackend, logformat)
|
||||
logLevel := logging.AddModuleLevel(logFormatter)
|
||||
switch DNSConf.Logconfig.Level {
|
||||
default:
|
||||
logLevel.SetLevel(logging.DEBUG, "")
|
||||
case "warning":
|
||||
logLevel.SetLevel(logging.WARNING, "")
|
||||
case "error":
|
||||
logLevel.SetLevel(logging.ERROR, "")
|
||||
case "info":
|
||||
logLevel.SetLevel(logging.INFO, "")
|
||||
}
|
||||
logging.SetBackend(logFormatter)
|
||||
|
||||
setupLogging()
|
||||
|
||||
// Read the default records in
|
||||
RR.Parse(DNSConf.General.StaticRecords)
|
||||
|
46
util.go
46
util.go
@ -3,9 +3,12 @@ package main
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/BurntSushi/toml"
|
||||
"github.com/op/go-logging"
|
||||
"github.com/satori/go.uuid"
|
||||
"math/big"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
@ -18,7 +21,7 @@ func readConfig(fname string) (DNSConfig, error) {
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func SanitizeString(s string) string {
|
||||
func sanitizeString(s string) string {
|
||||
// URL safe base64 alphabet without padding as defined in ACME
|
||||
re, err := regexp.Compile("[^A-Za-z\\-\\_0-9]+")
|
||||
if err != nil {
|
||||
@ -28,7 +31,7 @@ func SanitizeString(s string) string {
|
||||
return re.ReplaceAllString(s, "")
|
||||
}
|
||||
|
||||
func GeneratePassword(length int) (string, error) {
|
||||
func generatePassword(length int) (string, error) {
|
||||
ret := make([]byte, length)
|
||||
const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz1234567890-_"
|
||||
alphalen := big.NewInt(int64(len(alphabet)))
|
||||
@ -43,7 +46,7 @@ func GeneratePassword(length int) (string, error) {
|
||||
return string(ret), nil
|
||||
}
|
||||
|
||||
func SanitizeDomainQuestion(d string) string {
|
||||
func sanitizeDomainQuestion(d string) string {
|
||||
var dom string
|
||||
suffix := DNSConf.General.Domain + "."
|
||||
if strings.HasSuffix(d, suffix) {
|
||||
@ -54,9 +57,9 @@ func SanitizeDomainQuestion(d string) string {
|
||||
return dom
|
||||
}
|
||||
|
||||
func NewACMETxt() (ACMETxt, error) {
|
||||
func newACMETxt() (ACMETxt, error) {
|
||||
var a = ACMETxt{}
|
||||
password, err := GeneratePassword(40)
|
||||
password, err := generatePassword(40)
|
||||
if err != nil {
|
||||
return a, err
|
||||
}
|
||||
@ -65,3 +68,36 @@ func NewACMETxt() (ACMETxt, error) {
|
||||
a.Subdomain = uuid.NewV4().String()
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func setupLogging() {
|
||||
var logformat = logging.MustStringFormatter(DNSConf.Logconfig.Format)
|
||||
var logBackend *logging.LogBackend
|
||||
switch DNSConf.Logconfig.Logtype {
|
||||
default:
|
||||
// Setup logging - stdout
|
||||
logBackend = logging.NewLogBackend(os.Stdout, "", 0)
|
||||
case "file":
|
||||
// Logging to file
|
||||
logfh, err := os.OpenFile(DNSConf.Logconfig.File, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
fmt.Printf("Could not open log file %s\n", DNSConf.Logconfig.File)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer logfh.Close()
|
||||
logBackend = logging.NewLogBackend(logfh, "", 0)
|
||||
}
|
||||
logFormatter := logging.NewBackendFormatter(logBackend, logformat)
|
||||
logLevel := logging.AddModuleLevel(logFormatter)
|
||||
switch DNSConf.Logconfig.Level {
|
||||
default:
|
||||
logLevel.SetLevel(logging.DEBUG, "")
|
||||
case "warning":
|
||||
logLevel.SetLevel(logging.WARNING, "")
|
||||
case "error":
|
||||
logLevel.SetLevel(logging.ERROR, "")
|
||||
case "info":
|
||||
logLevel.SetLevel(logging.INFO, "")
|
||||
}
|
||||
logging.SetBackend(logFormatter)
|
||||
|
||||
}
|
||||
|
@ -6,7 +6,7 @@ import (
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
func GetValidUsername(u string) (uuid.UUID, error) {
|
||||
func getValidUsername(u string) (uuid.UUID, error) {
|
||||
uname, err := uuid.FromString(u)
|
||||
if err != nil {
|
||||
return uuid.UUID{}, err
|
||||
@ -14,8 +14,8 @@ func GetValidUsername(u string) (uuid.UUID, error) {
|
||||
return uname, nil
|
||||
}
|
||||
|
||||
func ValidKey(k string) bool {
|
||||
kn := SanitizeString(k)
|
||||
func validKey(k string) bool {
|
||||
kn := sanitizeString(k)
|
||||
if utf8.RuneCountInString(k) == 40 && utf8.RuneCountInString(kn) == 40 {
|
||||
// Correct length and all chars valid
|
||||
return true
|
||||
@ -23,7 +23,7 @@ func ValidKey(k string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func ValidSubdomain(s string) bool {
|
||||
func validSubdomain(s string) bool {
|
||||
_, err := uuid.FromString(s)
|
||||
if err == nil {
|
||||
return true
|
||||
@ -31,8 +31,8 @@ func ValidSubdomain(s string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func ValidTXT(s string) bool {
|
||||
sn := SanitizeString(s)
|
||||
func validTXT(s string) bool {
|
||||
sn := sanitizeString(s)
|
||||
if utf8.RuneCountInString(s) == 43 && utf8.RuneCountInString(sn) == 43 {
|
||||
// 43 chars is the current LE auth key size, but not limited / defined by ACME
|
||||
return true
|
||||
@ -40,7 +40,7 @@ func ValidTXT(s string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func CorrectPassword(pw string, hash string) bool {
|
||||
func correctPassword(pw string, hash string) bool {
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(pw)); err == nil {
|
||||
return true
|
||||
}
|
||||
|
@ -17,7 +17,7 @@ func TestGetValidUsername(t *testing.T) {
|
||||
{"", uuid.UUID{}, true},
|
||||
{"&!#!25123!%!'%", uuid.UUID{}, true},
|
||||
} {
|
||||
ret, err := GetValidUsername(test.uname)
|
||||
ret, err := getValidUsername(test.uname)
|
||||
if test.shouldErr && err == nil {
|
||||
t.Errorf("Test %d: Expected error, but there was none", i)
|
||||
}
|
||||
@ -41,7 +41,7 @@ func TestValidKey(t *testing.T) {
|
||||
{"aaaaaaaa-aaa-aaaaaa#aaaaaaaa-aaa_aacaaaa", false},
|
||||
{"aaaaaaaa-aaa-aaaaaa-aaaaaaaa-aaa_aacaaaaa", false},
|
||||
} {
|
||||
ret := ValidKey(test.key)
|
||||
ret := validKey(test.key)
|
||||
if ret != test.output {
|
||||
t.Errorf("Test %d: Expected return value %t, but got %t", i, test.output, ret)
|
||||
}
|
||||
@ -58,7 +58,7 @@ func TestGetValidSubdomain(t *testing.T) {
|
||||
{"", false},
|
||||
{"&!#!25123!%!'%", false},
|
||||
} {
|
||||
ret := ValidSubdomain(test.subdomain)
|
||||
ret := validSubdomain(test.subdomain)
|
||||
if ret != test.output {
|
||||
t.Errorf("Test %d: Expected return value %t, but got %t", i, test.output, ret)
|
||||
}
|
||||
@ -76,7 +76,7 @@ func TestValidTXT(t *testing.T) {
|
||||
{"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", false},
|
||||
{"", false},
|
||||
} {
|
||||
ret := ValidTXT(test.txt)
|
||||
ret := validTXT(test.txt)
|
||||
if ret != test.output {
|
||||
t.Errorf("Test %d: Expected return value %t, but got %t", i, test.output, ret)
|
||||
}
|
||||
@ -100,7 +100,7 @@ func TestCorrectPassword(t *testing.T) {
|
||||
false},
|
||||
{"", "", false},
|
||||
} {
|
||||
ret := CorrectPassword(test.pw, test.hash)
|
||||
ret := correctPassword(test.pw, test.hash)
|
||||
if ret != test.output {
|
||||
t.Errorf("Test %d: Expected return value %t, but got %t", i, test.output, ret)
|
||||
}
|
||||
|
Reference in New Issue
Block a user