Refactoring, alpha v0.1

This commit is contained in:
Joona Hoikkala
2016-11-13 14:50:44 +02:00
parent f20ad11188
commit ed56a11e88
9 changed files with 271 additions and 186 deletions

View File

@ -1,27 +0,0 @@
package main
import (
"github.com/satori/go.uuid"
"time"
)
// The default database object
type ACMETxt struct {
Username string `json:"username"`
Password string `json:"password"`
ACMETxtPost
LastActive time.Time
}
type ACMETxtPost struct {
Subdomain string `json:"subdomain"`
Value string `json:"txt"`
}
func NewACMETxt() ACMETxt {
var a ACMETxt = ACMETxt{}
a.Username = uuid.NewV4().String()
a.Password = uuid.NewV4().String()
a.Subdomain = uuid.NewV4().String()
return a
}

79
api.go
View File

@ -19,6 +19,24 @@ func PostHandlerMap() map[string]func(*iris.Context) {
}
}
func (a AuthMiddleware) Serve(ctx *iris.Context) {
username_str := ctx.RequestHeader("X-Api-User")
password := ctx.RequestHeader("X-Api-Key")
username, err := GetValidUsername(username_str)
if err == nil && ValidKey(password) {
au, err := DB.GetByUsername(username)
if err == nil && CorrectPassword(password, au.Password) {
log.Debugf("Accepted authentication from [%s]", username_str)
ctx.Next()
return
}
// To protect against timed side channel (never gonna give you up)
CorrectPassword(password, "$2a$10$8JEFVNYYhLoBysjAxe2yBuXrkDojBQBkVpXEQgyQyjn43SvJ4vL36")
}
ctx.JSON(iris.StatusUnauthorized, iris.Map{"error": "unauthorized"})
}
func WebRegisterPost(ctx *iris.Context) {
// Create new user
nu, err := DB.Register()
@ -33,6 +51,7 @@ func WebRegisterPost(ctx *iris.Context) {
reg_json = iris.Map{"username": nu.Username, "password": nu.Password, "fulldomain": nu.Subdomain + "." + DnsConf.General.Domain, "subdomain": nu.Subdomain}
reg_status = iris.StatusCreated
}
log.Debugf("Successful registration, created user [%s]", nu.Username)
ctx.JSON(reg_status, reg_json)
}
@ -42,52 +61,36 @@ func WebRegisterGet(ctx *iris.Context) {
}
func WebUpdatePost(ctx *iris.Context) {
var username, password string
var a ACMETxtPost = ACMETxtPost{}
username = ctx.RequestHeader("X-API-User")
password = ctx.RequestHeader("X-API-Key")
// User auth done in middleware
var a ACMETxt = ACMETxt{}
user_string := ctx.RequestHeader("X-API-User")
username, err := GetValidUsername(user_string)
if err != nil {
log.Warningf("Error while getting username [%s]. This should never happen because of auth middlware.", user_string)
WebUpdatePostError(ctx, err, iris.StatusUnauthorized)
return
}
if err := ctx.ReadJSON(&a); err != nil {
// Handle bad post data
log.Warningf("Could not unmarshal: [%v]", err)
WebUpdatePostError(ctx, err, iris.StatusBadRequest)
return
}
// Sanitized by db function
euser, err := DB.GetByUsername(username)
if err != nil {
// DB error
WebUpdatePostError(ctx, err, iris.StatusInternalServerError)
return
}
if len(euser) == 0 {
// User not found
// TODO: do bcrypt to avoid side channel
WebUpdatePostError(ctx, errors.New("invalid user or api key"), iris.StatusUnauthorized)
return
}
// Get first (and the only) user
upduser := euser[0]
// Validate password
if upduser.Password != password {
// Invalid password
WebUpdatePostError(ctx, errors.New("invalid user or api key"), iris.StatusUnauthorized)
return
} else {
// Do update
if len(a.Value) == 0 {
WebUpdatePostError(ctx, errors.New("missing txt value"), iris.StatusBadRequest)
a.Username = username
// Do update
if ValidSubdomain(a.Subdomain) && ValidTXT(a.Value) {
err := DB.Update(a)
if err != nil {
log.Warningf("Error trying to update [%v]", err)
WebUpdatePostError(ctx, errors.New("internal error"), iris.StatusInternalServerError)
return
} else {
upduser.Value = a.Value
err = DB.Update(upduser)
if err != nil {
WebUpdatePostError(ctx, err, iris.StatusInternalServerError)
return
}
// All ok
ctx.JSON(iris.StatusOK, iris.Map{"txt": upduser.Value})
}
ctx.JSON(iris.StatusOK, iris.Map{"txt": a.Value})
} else {
log.Warningf("Bad data, subdomain: [%s], txt: [%s]", a.Subdomain, a.Value)
WebUpdatePostError(ctx, errors.New("bad data"), iris.StatusBadRequest)
return
}
}
func WebUpdatePostError(ctx *iris.Context, err error, status int) {

View File

@ -6,15 +6,18 @@ nsname = "ns1.auth.example.org"
# admin email address, with @ substituted with .
nsadmin = "admin.example.org"
[api]
# domain name to listen requests for, mandatory if using tls = "letsencrypt"
# use "" (empty string) to bind to all interfaces
api_domain = ""
# listen port, eg. 443 for default HTTPS
port = "8080"
# possible values: "letsencrypt", "cert", "false"
tls = "letsencrypt"
# only used if tls = "cert"
tls_cert_privkey = "/etc/tls/example.org/privkey.pem"
tls_cert_fullchain = "/etc/tls/example.org/fullchain.pem"
# predefined records that we're serving in addition to the TXT
records = [
# default A
"auth.example.org. A 192.168.1.100",
@ -25,3 +28,13 @@ records = [
"auth.example.org. NS ns1.auth.example.org.",
"auth.example.org. NS ns2.auth.example.org.",
]
[logconfig]
# logging level
loglevel = "debug"
# possible values: stdout, file
logtype = "stdout"
# file path for logfile
logfile = "./acme-dns.log"
# format
logformat = "%{time:15:04:05.000} %{shortfunc} - %{level:.4s} %{id:03x} %{message}"

93
db.go
View File

@ -2,10 +2,10 @@ package main
import (
"database/sql"
//"encoding/json"
//"github.com/boltdb/bolt"
"errors"
_ "github.com/mattn/go-sqlite3"
//"strings"
"github.com/satori/go.uuid"
"golang.org/x/crypto/bcrypt"
)
type Database struct {
@ -35,7 +35,11 @@ func (d *Database) Init(filename string) error {
}
func (d *Database) Register() (ACMETxt, error) {
a := NewACMETxt()
a, err := NewACMETxt()
if err != nil {
return ACMETxt{}, err
}
password_hash, err := bcrypt.GenerateFromPassword([]byte(a.Password), 10)
reg_sql := `
INSERT INTO records(
Username,
@ -49,54 +53,54 @@ func (d *Database) Register() (ACMETxt, error) {
return a, err
}
defer sm.Close()
_, err = sm.Exec(a.Username, a.Password, a.Subdomain, a.Value)
_, err = sm.Exec(a.Username, password_hash, a.Subdomain, a.Value)
if err != nil {
return a, err
}
// Do an insert check
/*
id, err := status.LastInsertId()
if err != nil {
return a, err
}*/
return a, nil
}
func (d *Database) GetByUsername(u string) ([]ACMETxt, error) {
u = NormalizeString(u, 36)
log.Debugf("Trying to select by user [%s] from table", u)
func (d *Database) GetByUsername(u uuid.UUID) (ACMETxt, error) {
var results []ACMETxt
get_sql := `
SELECT Username, Password, Subdomain, Value
SELECT Username, Password, Subdomain, Value, LastActive
FROM records
WHERE Username=? LIMIT 1
`
sm, err := d.DB.Prepare(get_sql)
if err != nil {
return nil, err
return ACMETxt{}, err
}
defer sm.Close()
rows, err := sm.Query(u)
rows, err := sm.Query(u.String())
if err != nil {
return nil, err
return ACMETxt{}, err
}
defer rows.Close()
// It will only be one row though
for rows.Next() {
var a ACMETxt = ACMETxt{}
err = rows.Scan(&a.Username, &a.Password, &a.Subdomain, &a.Value)
var uname string
err = rows.Scan(&uname, &a.Password, &a.Subdomain, &a.Value, &a.LastActive)
if err != nil {
return nil, err
return ACMETxt{}, err
}
a.Username, err = uuid.FromString(uname)
if err != nil {
return ACMETxt{}, err
}
results = append(results, a)
}
return results, nil
if len(results) > 0 {
return results[0], nil
} else {
return ACMETxt{}, errors.New("no user")
}
}
func (d *Database) GetByDomain(domain string) ([]ACMETxt, error) {
domain = NormalizeString(domain, 36)
domain = SanitizeString(domain)
log.Debugf("Trying to select domain [%s] from table", domain)
var a []ACMETxt
get_sql := `
@ -144,46 +148,3 @@ func (d *Database) Update(a ACMETxt) error {
}
return nil
}
/*
func addTXT(txt ACMETxt) error {
err := db.Update(func(tx *bolt.Tx) error {
bucket, err := tx.CreateBucketIfNotExists([]byte("domains"))
if err != nil {
return err
}
jtxt, err := json.Marshal(txt)
if err != nil {
return err
}
// put returns nil if successful, nil return commits db.Update
return bucket.Put([]byte(strings.ToLower(txt.Domain)), jtxt)
})
return err
}
func getTXT(domain string) (ACMETxt, error) {
var atxt ACMETxt
err := db.View(func(tx *bolt.Tx) error {
bucket := tx.Bucket([]byte("domains"))
value := bucket.Get([]byte(strings.ToLower(domain)))
if len(value) == 0 {
// Not found
log.Debugf("Record for [%s] not found", domain)
atxt = ACMETxt{}
} else {
if err := json.Unmarshal(value, &atxt); err != nil {
return err
}
}
return nil
})
if err != nil {
return ACMETxt{}, err
}
return atxt, err
}
*/

2
dns.go
View File

@ -22,7 +22,7 @@ func answerTXT(q dns.Question) ([]dns.RR, int, error) {
var rcode int = dns.RcodeNameError
var domain string = q.Name
atxt, err := DB.GetByDomain(domain)
atxt, err := DB.GetByDomain(SanitizeDomainQuestion(domain))
if err != nil {
log.Errorf("Error while trying to get record [%v]", err)
return ra, dns.RcodeNameError, err

94
main.go
View File

@ -9,7 +9,6 @@ import (
)
// Logging config
var logfile_path = "acme-dns.log"
var log = logging.MustGetLogger("acme-dns")
// Global configuration struct
@ -21,41 +20,45 @@ var DB Database
var RR Records
func main() {
// Setup logging
var stdout_format = logging.MustStringFormatter(
`%{color}%{time:15:04:05.000} %{shortfunc} ▶ %{level:.4s} %{id:03x}%{color:reset} %{message}`,
)
var file_format = logging.MustStringFormatter(
`%{time:15:04:05.000} %{shortfunc} - %{level:.4s} %{id:03x} %{message}`,
)
// Setup logging - stdout
logStdout := logging.NewLogBackend(os.Stdout, "", 0)
logStdoutFormatter := logging.NewBackendFormatter(logStdout, stdout_format)
// Setup logging - file
// Logging to file
logfh, err := os.OpenFile(logfile_path, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
if err != nil {
fmt.Printf("Could not open log file %s\n", logfile_path)
os.Exit(1)
}
defer logfh.Close()
logFile := logging.NewLogBackend(logfh, "", 0)
logFileFormatter := logging.NewBackendFormatter(logFile, file_format)
/* To limit logging to a level
logFileLeveled := logging.AddModuleLevel(logFile)
logFileLeveled.SetLevel(logging.ERROR, "")
*/
// Start logging
logging.SetBackend(logStdoutFormatter, logFileFormatter)
log.Debug("Starting up...")
// Read global config
if DnsConf, err = ReadConfig("config.cfg"); err != nil {
log.Errorf("Got error %v", err)
config_tmp, err := ReadConfig("config.cfg")
if err != nil {
fmt.Printf("Got error %v\n", DnsConf.Logconfig.File)
os.Exit(1)
}
RR.Parse(DnsConf.General.StaticRecords)
DnsConf = config_tmp
// Setup logging
var logformat = logging.MustStringFormatter(DnsConf.Logconfig.Format)
var logBackend *logging.LogBackend
switch DnsConf.Logconfig.Logtype {
default:
// Setup logging - stdout
logBackend = logging.NewLogBackend(os.Stdout, "", 0)
case "file":
// Logging to file
logfh, err := os.OpenFile(DnsConf.Logconfig.File, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
if err != nil {
fmt.Printf("Could not open log file %s\n", DnsConf.Logconfig.File)
os.Exit(1)
}
defer logfh.Close()
logBackend = logging.NewLogBackend(logfh, "", 0)
}
logLevel := logging.AddModuleLevel(logBackend)
switch DnsConf.Logconfig.Level {
case "warning":
logLevel.SetLevel(logging.WARNING, "")
case "error":
logLevel.SetLevel(logging.ERROR, "")
case "info":
logLevel.SetLevel(logging.INFO, "")
}
logFormatter := logging.NewBackendFormatter(logLevel, logformat)
logging.SetBackend(logFormatter)
// Read the default records in
RR.Parse(DnsConf.Api.StaticRecords)
// Open database
err = DB.Init("acme-dns.db")
@ -76,14 +79,27 @@ func main() {
}
}()
// API server
// API server and endpoints
api := iris.New()
for path, handlerfunc := range GetHandlerMap() {
api.Get(path, handlerfunc)
var ForceAuth AuthMiddleware = AuthMiddleware{}
api.Get("/register", WebRegisterGet)
api.Post("/register", WebRegisterPost)
api.Post("/update", ForceAuth.Serve, WebUpdatePost)
// TODO: migrate to api.Serve(iris.LETSENCRYPTPROD("mydomain.com"))
switch DnsConf.Api.Tls {
case "letsencrypt":
host := DnsConf.Api.Domain + ":" + DnsConf.Api.Port
api.Listen(host)
case "cert":
host := DnsConf.Api.Domain + ":" + DnsConf.Api.Port
api.ListenTLS(host, DnsConf.Api.Tls_cert_fullchain, DnsConf.Api.Tls_cert_privkey)
default:
host := DnsConf.Api.Domain + ":" + DnsConf.Api.Port
api.Listen(host)
}
for path, handlerfunc := range PostHandlerMap() {
api.Post(path, handlerfunc)
if err != nil {
log.Errorf("Error in HTTP server [%v]", err)
}
api.Listen(":8080")
log.Debugf("Shutting down...")
}

View File

@ -2,6 +2,8 @@ package main
import (
"github.com/miekg/dns"
"github.com/satori/go.uuid"
"time"
)
// Static records
@ -11,16 +13,48 @@ type Records struct {
// Config file main struct
type DnsConfig struct {
General general
General general
Api httpapi
Logconfig logconfig
}
// Auth middleware
type AuthMiddleware struct{}
// Config file general section
type general struct {
Domain string
Nsname string
Nsadmin string
}
// API config
type httpapi struct {
Domain string
Nsname string
Nsadmin string
Port string
Tls string
Tls_cert_privkey string
Tls_cert_fullchain string
StaticRecords []string `toml:"records"`
}
// Logging config
type logconfig struct {
Level string `toml:"loglevel"`
Logtype string `toml:"logtype"`
File string `toml:"logfile"`
Format string `toml:"logformat"`
}
// The default object
type ACMETxt struct {
Username uuid.UUID
Password string
ACMETxtPost
LastActive time.Time
}
type ACMETxtPost struct {
Subdomain string `json:"subdomain"`
Value string `json:"txt"`
}

55
util.go
View File

@ -1,20 +1,57 @@
package main
import (
"crypto/rand"
"github.com/satori/go.uuid"
"math/big"
"regexp"
"unicode/utf8"
"strings"
)
func NormalizeString(s string, length int) string {
var ret string
re, err := regexp.Compile("[^A-Za-z\\-0-9]+")
func SanitizeString(s string) string {
// URL safe base64 alphabet without padding as defined in ACME
re, err := regexp.Compile("[^A-Za-z\\-\\_0-9]+")
if err != nil {
log.Errorf("%v", err)
return ""
}
ret = re.ReplaceAllString(s, "")
if utf8.RuneCountInString(ret) > length {
ret = ret[0:length]
}
return ret
return re.ReplaceAllString(s, "")
}
func GeneratePassword(length int) (string, error) {
ret := make([]byte, length)
const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz1234567890-_"
alphalen := big.NewInt(int64(len(alphabet)))
for i := 0; i < length; i++ {
c, err := rand.Int(rand.Reader, alphalen)
if err != nil {
return "", err
}
r := int(c.Int64())
ret[i] = alphabet[r]
}
return string(ret), nil
}
func SanitizeDomainQuestion(d string) string {
var dom string
dns_suff := DnsConf.General.Domain + "."
if strings.HasSuffix(d, dns_suff) {
dom = d[0 : len(d)-len(dns_suff)]
} else {
dom = d
}
return dom
}
func NewACMETxt() (ACMETxt, error) {
var a ACMETxt = ACMETxt{}
password, err := GeneratePassword(40)
if err != nil {
return a, err
}
a.Username = uuid.NewV4()
a.Password = password
a.Subdomain = uuid.NewV4().String()
return a, nil
}

48
validation.go Normal file
View File

@ -0,0 +1,48 @@
package main
import (
"github.com/satori/go.uuid"
"golang.org/x/crypto/bcrypt"
"unicode/utf8"
)
func GetValidUsername(u string) (uuid.UUID, error) {
uname, err := uuid.FromString(u)
if err != nil {
return uuid.UUID{}, err
}
return uname, nil
}
func ValidKey(k string) bool {
kn := SanitizeString(k)
if utf8.RuneCountInString(k) == 40 && utf8.RuneCountInString(kn) == 40 {
// Correct length and all chars valid
return true
}
return false
}
func ValidSubdomain(s string) bool {
_, err := uuid.FromString(s)
if err == nil {
return true
}
return false
}
func ValidTXT(s string) bool {
sn := SanitizeString(s)
if utf8.RuneCountInString(s) == 43 && utf8.RuneCountInString(sn) == 43 {
// 43 chars is the current LE auth key size, but not limited / defined by ACME
return true
}
return false
}
func CorrectPassword(pw string, hash string) bool {
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(pw)); err == nil {
return true
}
return false
}