Coding style fixes

This commit is contained in:
Joona Hoikkala
2016-11-16 19:15:36 +02:00
parent e3a2577f7f
commit d30860eeb8
6 changed files with 67 additions and 69 deletions

35
api.go
View File

@ -20,14 +20,14 @@ func PostHandlerMap() map[string]func(*iris.Context) {
} }
func (a AuthMiddleware) Serve(ctx *iris.Context) { func (a AuthMiddleware) Serve(ctx *iris.Context) {
username_str := ctx.RequestHeader("X-Api-User") usernameStr := ctx.RequestHeader("X-Api-User")
password := ctx.RequestHeader("X-Api-Key") password := ctx.RequestHeader("X-Api-Key")
username, err := GetValidUsername(username_str) username, err := GetValidUsername(usernameStr)
if err == nil && ValidKey(password) { if err == nil && ValidKey(password) {
au, err := DB.GetByUsername(username) au, err := DB.GetByUsername(username)
if err == nil && CorrectPassword(password, au.Password) { if err == nil && CorrectPassword(password, au.Password) {
log.Debugf("Accepted authentication from [%s]", username_str) log.Debugf("Accepted authentication from [%s]", usernameStr)
ctx.Next() ctx.Next()
return return
} }
@ -40,19 +40,18 @@ func (a AuthMiddleware) Serve(ctx *iris.Context) {
func WebRegisterPost(ctx *iris.Context) { func WebRegisterPost(ctx *iris.Context) {
// Create new user // Create new user
nu, err := DB.Register() nu, err := DB.Register()
var reg_json iris.Map var regJSON iris.Map
var reg_status int var regStatus int
if err != nil { if err != nil {
errstr := fmt.Sprintf("%v", err) errstr := fmt.Sprintf("%v", err)
regJSON = iris.Map{"username": "", "password": "", "domain": "", "error": errstr}
reg_json = iris.Map{"username": "", "password": "", "domain": "", "error": errstr} regStatus = iris.StatusInternalServerError
reg_status = iris.StatusInternalServerError
} else { } else {
reg_json = iris.Map{"username": nu.Username, "password": nu.Password, "fulldomain": nu.Subdomain + "." + DnsConf.General.Domain, "subdomain": nu.Subdomain} regJSON = iris.Map{"username": nu.Username, "password": nu.Password, "fulldomain": nu.Subdomain + "." + DNSConf.General.Domain, "subdomain": nu.Subdomain}
reg_status = iris.StatusCreated regStatus = iris.StatusCreated
} }
log.Debugf("Successful registration, created user [%s]", nu.Username) log.Debugf("Successful registration, created user [%s]", nu.Username)
ctx.JSON(reg_status, reg_json) ctx.JSON(regStatus, regJSON)
} }
func WebRegisterGet(ctx *iris.Context) { func WebRegisterGet(ctx *iris.Context) {
@ -62,11 +61,11 @@ func WebRegisterGet(ctx *iris.Context) {
func WebUpdatePost(ctx *iris.Context) { func WebUpdatePost(ctx *iris.Context) {
// User auth done in middleware // User auth done in middleware
var a ACMETxt = ACMETxt{} a := ACMETxt{}
user_string := ctx.RequestHeader("X-API-User") userStr := ctx.RequestHeader("X-API-User")
username, err := GetValidUsername(user_string) username, err := GetValidUsername(userStr)
if err != nil { if err != nil {
log.Warningf("Error while getting username [%s]. This should never happen because of auth middlware.", user_string) log.Warningf("Error while getting username [%s]. This should never happen because of auth middlware.", userStr)
WebUpdatePostError(ctx, err, iris.StatusUnauthorized) WebUpdatePostError(ctx, err, iris.StatusUnauthorized)
return return
} }
@ -94,7 +93,7 @@ func WebUpdatePost(ctx *iris.Context) {
} }
func WebUpdatePostError(ctx *iris.Context, err error, status int) { func WebUpdatePostError(ctx *iris.Context, err error, status int) {
err_str := fmt.Sprintf("%v", err) errStr := fmt.Sprintf("%v", err)
upd_json := iris.Map{"error": err_str} updJSON := iris.Map{"error": errStr}
ctx.JSON(status, upd_json) ctx.JSON(status, updJSON)
} }

29
db.go
View File

@ -12,7 +12,7 @@ type Database struct {
DB *sql.DB DB *sql.DB
} }
var records_table string = ` var recordsTable = `
CREATE TABLE IF NOT EXISTS records( CREATE TABLE IF NOT EXISTS records(
Username TEXT UNIQUE NOT NULL PRIMARY KEY, Username TEXT UNIQUE NOT NULL PRIMARY KEY,
Password TEXT UNIQUE NOT NULL, Password TEXT UNIQUE NOT NULL,
@ -27,7 +27,7 @@ func (d *Database) Init(filename string) error {
return err return err
} }
d.DB = db d.DB = db
_, err = d.DB.Exec(records_table) _, err = d.DB.Exec(recordsTable)
if err != nil { if err != nil {
return err return err
} }
@ -39,8 +39,8 @@ func (d *Database) Register() (ACMETxt, error) {
if err != nil { if err != nil {
return ACMETxt{}, err return ACMETxt{}, err
} }
password_hash, err := bcrypt.GenerateFromPassword([]byte(a.Password), 10) passwordHash, err := bcrypt.GenerateFromPassword([]byte(a.Password), 10)
reg_sql := ` regSQL := `
INSERT INTO records( INSERT INTO records(
Username, Username,
Password, Password,
@ -48,12 +48,12 @@ func (d *Database) Register() (ACMETxt, error) {
Value, Value,
LastActive) LastActive)
values(?, ?, ?, ?, CURRENT_TIMESTAMP)` values(?, ?, ?, ?, CURRENT_TIMESTAMP)`
sm, err := d.DB.Prepare(reg_sql) sm, err := d.DB.Prepare(regSQL)
if err != nil { if err != nil {
return a, err return a, err
} }
defer sm.Close() defer sm.Close()
_, err = sm.Exec(a.Username, password_hash, a.Subdomain, a.Value) _, err = sm.Exec(a.Username, passwordHash, a.Subdomain, a.Value)
if err != nil { if err != nil {
return a, err return a, err
} }
@ -62,12 +62,12 @@ func (d *Database) Register() (ACMETxt, error) {
func (d *Database) GetByUsername(u uuid.UUID) (ACMETxt, error) { func (d *Database) GetByUsername(u uuid.UUID) (ACMETxt, error) {
var results []ACMETxt var results []ACMETxt
get_sql := ` getSQL := `
SELECT Username, Password, Subdomain, Value, LastActive SELECT Username, Password, Subdomain, Value, LastActive
FROM records FROM records
WHERE Username=? LIMIT 1 WHERE Username=? LIMIT 1
` `
sm, err := d.DB.Prepare(get_sql) sm, err := d.DB.Prepare(getSQL)
if err != nil { if err != nil {
return ACMETxt{}, err return ACMETxt{}, err
} }
@ -80,7 +80,7 @@ func (d *Database) GetByUsername(u uuid.UUID) (ACMETxt, error) {
// It will only be one row though // It will only be one row though
for rows.Next() { for rows.Next() {
var a ACMETxt = ACMETxt{} a := ACMETxt{}
var uname string var uname string
err = rows.Scan(&uname, &a.Password, &a.Subdomain, &a.Value, &a.LastActive) err = rows.Scan(&uname, &a.Password, &a.Subdomain, &a.Value, &a.LastActive)
if err != nil { if err != nil {
@ -94,21 +94,20 @@ func (d *Database) GetByUsername(u uuid.UUID) (ACMETxt, error) {
} }
if len(results) > 0 { if len(results) > 0 {
return results[0], nil return results[0], nil
} else {
return ACMETxt{}, errors.New("no user")
} }
return ACMETxt{}, errors.New("no user")
} }
func (d *Database) GetByDomain(domain string) ([]ACMETxt, error) { func (d *Database) GetByDomain(domain string) ([]ACMETxt, error) {
domain = SanitizeString(domain) domain = SanitizeString(domain)
log.Debugf("Trying to select domain [%s] from table", domain) log.Debugf("Trying to select domain [%s] from table", domain)
var a []ACMETxt var a []ACMETxt
get_sql := ` getSQL := `
SELECT Username, Password, Subdomain, Value SELECT Username, Password, Subdomain, Value
FROM records FROM records
WHERE Subdomain=? LIMIT 1 WHERE Subdomain=? LIMIT 1
` `
sm, err := d.DB.Prepare(get_sql) sm, err := d.DB.Prepare(getSQL)
if err != nil { if err != nil {
return a, err return a, err
} }
@ -133,11 +132,11 @@ func (d *Database) GetByDomain(domain string) ([]ACMETxt, error) {
func (d *Database) Update(a ACMETxt) error { func (d *Database) Update(a ACMETxt) error {
// Data in a is already sanitized // Data in a is already sanitized
log.Debugf("Trying to update domain [%s] with TXT data [%s]", a.Subdomain, a.Value) log.Debugf("Trying to update domain [%s] with TXT data [%s]", a.Subdomain, a.Value)
upd_sql := ` updSQL := `
UPDATE records SET Value=? UPDATE records SET Value=?
WHERE Username=? AND Subdomain=? WHERE Username=? AND Subdomain=?
` `
sm, err := d.DB.Prepare(upd_sql) sm, err := d.DB.Prepare(updSQL)
if err != nil { if err != nil {
return err return err
} }

6
dns.go
View File

@ -21,7 +21,7 @@ func readQuery(m *dns.Msg) {
func answerTXT(q dns.Question) ([]dns.RR, int, error) { func answerTXT(q dns.Question) ([]dns.RR, int, error) {
var ra []dns.RR var ra []dns.RR
var rcode int = dns.RcodeNameError var rcode int = dns.RcodeNameError
var domain string = strings.ToLower(q.Name) var domain = strings.ToLower(q.Name)
atxt, err := DB.GetByDomain(SanitizeDomainQuestion(domain)) atxt, err := DB.GetByDomain(SanitizeDomainQuestion(domain))
if err != nil { if err != nil {
@ -47,7 +47,7 @@ func answer(q dns.Question) ([]dns.RR, int, error) {
} }
var r []dns.RR var r []dns.RR
var rcode int = dns.RcodeSuccess var rcode int = dns.RcodeSuccess
var domain string = strings.ToLower(q.Name) var domain = strings.ToLower(q.Name)
var rtype uint16 = q.Qtype var rtype uint16 = q.Qtype
r, ok := RR.Records[rtype][domain] r, ok := RR.Records[rtype][domain]
if !ok { if !ok {
@ -83,7 +83,7 @@ func (r *Records) Parse(recs []string) {
// Create serial // Create serial
serial := time.Now().Format("2006010215") serial := time.Now().Format("2006010215")
// Add SOA // Add SOA
SOAstring := fmt.Sprintf("%s. SOA %s. %s. %s 28800 7200 604800 86400", strings.ToLower(DnsConf.General.Domain), strings.ToLower(DnsConf.General.Nsname), strings.ToLower(DnsConf.General.Nsadmin), serial) SOAstring := fmt.Sprintf("%s. SOA %s. %s. %s 28800 7200 604800 86400", strings.ToLower(DNSConf.General.Domain), strings.ToLower(DNSConf.General.Nsname), strings.ToLower(DNSConf.General.Nsadmin), serial)
soarr, err := dns.NewRR(SOAstring) soarr, err := dns.NewRR(SOAstring)
if err != nil { if err != nil {
log.Errorf("Error [%v] while trying to add SOA record: [%s]", err, SOAstring) log.Errorf("Error [%v] while trying to add SOA record: [%s]", err, SOAstring)

36
main.go
View File

@ -13,7 +13,7 @@ import (
var log = logging.MustGetLogger("acme-dns") var log = logging.MustGetLogger("acme-dns")
// Global configuration struct // Global configuration struct
var DnsConf DnsConfig var DNSConf DNSConfig
var DB Database var DB Database
@ -22,24 +22,24 @@ var RR Records
func main() { func main() {
// Read global config // Read global config
config_tmp, err := ReadConfig("config.cfg") configTmp, err := readConfig("config.cfg")
if err != nil { if err != nil {
fmt.Printf("Got error %v\n", DnsConf.Logconfig.File) fmt.Printf("Got error %v\n", DNSConf.Logconfig.File)
os.Exit(1) os.Exit(1)
} }
DnsConf = config_tmp DNSConf = configTmp
// Setup logging // Setup logging
var logformat = logging.MustStringFormatter(DnsConf.Logconfig.Format) var logformat = logging.MustStringFormatter(DNSConf.Logconfig.Format)
var logBackend *logging.LogBackend var logBackend *logging.LogBackend
switch DnsConf.Logconfig.Logtype { switch DNSConf.Logconfig.Logtype {
default: default:
// Setup logging - stdout // Setup logging - stdout
logBackend = logging.NewLogBackend(os.Stdout, "", 0) logBackend = logging.NewLogBackend(os.Stdout, "", 0)
case "file": case "file":
// Logging to file // Logging to file
logfh, err := os.OpenFile(DnsConf.Logconfig.File, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) logfh, err := os.OpenFile(DNSConf.Logconfig.File, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
if err != nil { if err != nil {
fmt.Printf("Could not open log file %s\n", DnsConf.Logconfig.File) fmt.Printf("Could not open log file %s\n", DNSConf.Logconfig.File)
os.Exit(1) os.Exit(1)
} }
defer logfh.Close() defer logfh.Close()
@ -47,7 +47,7 @@ func main() {
} }
logLevel := logging.AddModuleLevel(logBackend) logLevel := logging.AddModuleLevel(logBackend)
switch DnsConf.Logconfig.Level { switch DNSConf.Logconfig.Level {
case "warning": case "warning":
logLevel.SetLevel(logging.WARNING, "") logLevel.SetLevel(logging.WARNING, "")
case "error": case "error":
@ -59,7 +59,7 @@ func main() {
logging.SetBackend(logFormatter) logging.SetBackend(logFormatter)
// Read the default records in // Read the default records in
RR.Parse(DnsConf.General.StaticRecords) RR.Parse(DNSConf.General.StaticRecords)
// Open database // Open database
err = DB.Init("acme-dns.db") err = DB.Init("acme-dns.db")
@ -83,26 +83,26 @@ func main() {
// API server and endpoints // API server and endpoints
api := iris.New() api := iris.New()
crs := cors.New(cors.Options{ crs := cors.New(cors.Options{
AllowedOrigins: DnsConf.Api.CorsOrigins, AllowedOrigins: DNSConf.API.CorsOrigins,
AllowedMethods: []string{"GET", "POST"}, AllowedMethods: []string{"GET", "POST"},
OptionsPassthrough: false, OptionsPassthrough: false,
Debug: DnsConf.General.Debug, Debug: DNSConf.General.Debug,
}) })
api.Use(crs) api.Use(crs)
var ForceAuth AuthMiddleware = AuthMiddleware{} var ForceAuth = AuthMiddleware{}
api.Get("/register", WebRegisterGet) api.Get("/register", WebRegisterGet)
api.Post("/register", WebRegisterPost) api.Post("/register", WebRegisterPost)
api.Post("/update", ForceAuth.Serve, WebUpdatePost) api.Post("/update", ForceAuth.Serve, WebUpdatePost)
// TODO: migrate to api.Serve(iris.LETSENCRYPTPROD("mydomain.com")) // TODO: migrate to api.Serve(iris.LETSENCRYPTPROD("mydomain.com"))
switch DnsConf.Api.Tls { switch DNSConf.API.TLS {
case "letsencrypt": case "letsencrypt":
host := DnsConf.Api.Domain + ":" + DnsConf.Api.Port host := DNSConf.API.Domain + ":" + DNSConf.API.Port
api.Listen(host) api.Listen(host)
case "cert": case "cert":
host := DnsConf.Api.Domain + ":" + DnsConf.Api.Port host := DNSConf.API.Domain + ":" + DNSConf.API.Port
api.ListenTLS(host, DnsConf.Api.Tls_cert_fullchain, DnsConf.Api.Tls_cert_privkey) api.ListenTLS(host, DNSConf.API.TLSCertFullchain, DNSConf.API.TLSCertPrivkey)
default: default:
host := DnsConf.Api.Domain + ":" + DnsConf.Api.Port host := DNSConf.API.Domain + ":" + DNSConf.API.Port
api.Listen(host) api.Listen(host)
} }
if err != nil { if err != nil {

View File

@ -12,9 +12,9 @@ type Records struct {
} }
// Config file main struct // Config file main struct
type DnsConfig struct { type DNSConfig struct {
General general General general
Api httpapi API httpapi
Logconfig logconfig Logconfig logconfig
} }
@ -32,12 +32,12 @@ type general struct {
// API config // API config
type httpapi struct { type httpapi struct {
Domain string Domain string
Port string Port string
Tls string TLS string
Tls_cert_privkey string TLSCertPrivkey string `toml:"tls_cert_privkey"`
Tls_cert_fullchain string TLSCertFullchain string `toml:"tls_cert_fullchain"`
CorsOrigins []string CorsOrigins []string
} }
// Logging config // Logging config

14
util.go
View File

@ -10,10 +10,10 @@ import (
"strings" "strings"
) )
func ReadConfig(fname string) (DnsConfig, error) { func readConfig(fname string) (DNSConfig, error) {
var conf DnsConfig var conf DNSConfig
if _, err := toml.DecodeFile(fname, &conf); err != nil { if _, err := toml.DecodeFile(fname, &conf); err != nil {
return DnsConfig{}, errors.New("Malformed configuration file") return DNSConfig{}, errors.New("Malformed configuration file")
} }
return conf, nil return conf, nil
} }
@ -45,9 +45,9 @@ func GeneratePassword(length int) (string, error) {
func SanitizeDomainQuestion(d string) string { func SanitizeDomainQuestion(d string) string {
var dom string var dom string
dns_suff := DnsConf.General.Domain + "." suffix := DNSConf.General.Domain + "."
if strings.HasSuffix(d, dns_suff) { if strings.HasSuffix(d, suffix) {
dom = d[0 : len(d)-len(dns_suff)] dom = d[0 : len(d)-len(suffix)]
} else { } else {
dom = d dom = d
} }
@ -55,7 +55,7 @@ func SanitizeDomainQuestion(d string) string {
} }
func NewACMETxt() (ACMETxt, error) { func NewACMETxt() (ACMETxt, error) {
var a ACMETxt = ACMETxt{} var a = ACMETxt{}
password, err := GeneratePassword(40) password, err := GeneratePassword(40)
if err != nil { if err != nil {
return a, err return a, err