mirror of
https://github.com/joohoi/acme-dns.git
synced 2025-07-13 17:27:51 +07:00
Coding style fixes
This commit is contained in:
35
api.go
35
api.go
@ -20,14 +20,14 @@ func PostHandlerMap() map[string]func(*iris.Context) {
|
||||
}
|
||||
|
||||
func (a AuthMiddleware) Serve(ctx *iris.Context) {
|
||||
username_str := ctx.RequestHeader("X-Api-User")
|
||||
usernameStr := ctx.RequestHeader("X-Api-User")
|
||||
password := ctx.RequestHeader("X-Api-Key")
|
||||
|
||||
username, err := GetValidUsername(username_str)
|
||||
username, err := GetValidUsername(usernameStr)
|
||||
if err == nil && ValidKey(password) {
|
||||
au, err := DB.GetByUsername(username)
|
||||
if err == nil && CorrectPassword(password, au.Password) {
|
||||
log.Debugf("Accepted authentication from [%s]", username_str)
|
||||
log.Debugf("Accepted authentication from [%s]", usernameStr)
|
||||
ctx.Next()
|
||||
return
|
||||
}
|
||||
@ -40,19 +40,18 @@ func (a AuthMiddleware) Serve(ctx *iris.Context) {
|
||||
func WebRegisterPost(ctx *iris.Context) {
|
||||
// Create new user
|
||||
nu, err := DB.Register()
|
||||
var reg_json iris.Map
|
||||
var reg_status int
|
||||
var regJSON iris.Map
|
||||
var regStatus int
|
||||
if err != nil {
|
||||
errstr := fmt.Sprintf("%v", err)
|
||||
|
||||
reg_json = iris.Map{"username": "", "password": "", "domain": "", "error": errstr}
|
||||
reg_status = iris.StatusInternalServerError
|
||||
regJSON = iris.Map{"username": "", "password": "", "domain": "", "error": errstr}
|
||||
regStatus = iris.StatusInternalServerError
|
||||
} else {
|
||||
reg_json = iris.Map{"username": nu.Username, "password": nu.Password, "fulldomain": nu.Subdomain + "." + DnsConf.General.Domain, "subdomain": nu.Subdomain}
|
||||
reg_status = iris.StatusCreated
|
||||
regJSON = iris.Map{"username": nu.Username, "password": nu.Password, "fulldomain": nu.Subdomain + "." + DNSConf.General.Domain, "subdomain": nu.Subdomain}
|
||||
regStatus = iris.StatusCreated
|
||||
}
|
||||
log.Debugf("Successful registration, created user [%s]", nu.Username)
|
||||
ctx.JSON(reg_status, reg_json)
|
||||
ctx.JSON(regStatus, regJSON)
|
||||
}
|
||||
|
||||
func WebRegisterGet(ctx *iris.Context) {
|
||||
@ -62,11 +61,11 @@ func WebRegisterGet(ctx *iris.Context) {
|
||||
|
||||
func WebUpdatePost(ctx *iris.Context) {
|
||||
// User auth done in middleware
|
||||
var a ACMETxt = ACMETxt{}
|
||||
user_string := ctx.RequestHeader("X-API-User")
|
||||
username, err := GetValidUsername(user_string)
|
||||
a := ACMETxt{}
|
||||
userStr := ctx.RequestHeader("X-API-User")
|
||||
username, err := GetValidUsername(userStr)
|
||||
if err != nil {
|
||||
log.Warningf("Error while getting username [%s]. This should never happen because of auth middlware.", user_string)
|
||||
log.Warningf("Error while getting username [%s]. This should never happen because of auth middlware.", userStr)
|
||||
WebUpdatePostError(ctx, err, iris.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
@ -94,7 +93,7 @@ func WebUpdatePost(ctx *iris.Context) {
|
||||
}
|
||||
|
||||
func WebUpdatePostError(ctx *iris.Context, err error, status int) {
|
||||
err_str := fmt.Sprintf("%v", err)
|
||||
upd_json := iris.Map{"error": err_str}
|
||||
ctx.JSON(status, upd_json)
|
||||
errStr := fmt.Sprintf("%v", err)
|
||||
updJSON := iris.Map{"error": errStr}
|
||||
ctx.JSON(status, updJSON)
|
||||
}
|
||||
|
29
db.go
29
db.go
@ -12,7 +12,7 @@ type Database struct {
|
||||
DB *sql.DB
|
||||
}
|
||||
|
||||
var records_table string = `
|
||||
var recordsTable = `
|
||||
CREATE TABLE IF NOT EXISTS records(
|
||||
Username TEXT UNIQUE NOT NULL PRIMARY KEY,
|
||||
Password TEXT UNIQUE NOT NULL,
|
||||
@ -27,7 +27,7 @@ func (d *Database) Init(filename string) error {
|
||||
return err
|
||||
}
|
||||
d.DB = db
|
||||
_, err = d.DB.Exec(records_table)
|
||||
_, err = d.DB.Exec(recordsTable)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -39,8 +39,8 @@ func (d *Database) Register() (ACMETxt, error) {
|
||||
if err != nil {
|
||||
return ACMETxt{}, err
|
||||
}
|
||||
password_hash, err := bcrypt.GenerateFromPassword([]byte(a.Password), 10)
|
||||
reg_sql := `
|
||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte(a.Password), 10)
|
||||
regSQL := `
|
||||
INSERT INTO records(
|
||||
Username,
|
||||
Password,
|
||||
@ -48,12 +48,12 @@ func (d *Database) Register() (ACMETxt, error) {
|
||||
Value,
|
||||
LastActive)
|
||||
values(?, ?, ?, ?, CURRENT_TIMESTAMP)`
|
||||
sm, err := d.DB.Prepare(reg_sql)
|
||||
sm, err := d.DB.Prepare(regSQL)
|
||||
if err != nil {
|
||||
return a, err
|
||||
}
|
||||
defer sm.Close()
|
||||
_, err = sm.Exec(a.Username, password_hash, a.Subdomain, a.Value)
|
||||
_, err = sm.Exec(a.Username, passwordHash, a.Subdomain, a.Value)
|
||||
if err != nil {
|
||||
return a, err
|
||||
}
|
||||
@ -62,12 +62,12 @@ func (d *Database) Register() (ACMETxt, error) {
|
||||
|
||||
func (d *Database) GetByUsername(u uuid.UUID) (ACMETxt, error) {
|
||||
var results []ACMETxt
|
||||
get_sql := `
|
||||
getSQL := `
|
||||
SELECT Username, Password, Subdomain, Value, LastActive
|
||||
FROM records
|
||||
WHERE Username=? LIMIT 1
|
||||
`
|
||||
sm, err := d.DB.Prepare(get_sql)
|
||||
sm, err := d.DB.Prepare(getSQL)
|
||||
if err != nil {
|
||||
return ACMETxt{}, err
|
||||
}
|
||||
@ -80,7 +80,7 @@ func (d *Database) GetByUsername(u uuid.UUID) (ACMETxt, error) {
|
||||
|
||||
// It will only be one row though
|
||||
for rows.Next() {
|
||||
var a ACMETxt = ACMETxt{}
|
||||
a := ACMETxt{}
|
||||
var uname string
|
||||
err = rows.Scan(&uname, &a.Password, &a.Subdomain, &a.Value, &a.LastActive)
|
||||
if err != nil {
|
||||
@ -94,21 +94,20 @@ func (d *Database) GetByUsername(u uuid.UUID) (ACMETxt, error) {
|
||||
}
|
||||
if len(results) > 0 {
|
||||
return results[0], nil
|
||||
} else {
|
||||
return ACMETxt{}, errors.New("no user")
|
||||
}
|
||||
return ACMETxt{}, errors.New("no user")
|
||||
}
|
||||
|
||||
func (d *Database) GetByDomain(domain string) ([]ACMETxt, error) {
|
||||
domain = SanitizeString(domain)
|
||||
log.Debugf("Trying to select domain [%s] from table", domain)
|
||||
var a []ACMETxt
|
||||
get_sql := `
|
||||
getSQL := `
|
||||
SELECT Username, Password, Subdomain, Value
|
||||
FROM records
|
||||
WHERE Subdomain=? LIMIT 1
|
||||
`
|
||||
sm, err := d.DB.Prepare(get_sql)
|
||||
sm, err := d.DB.Prepare(getSQL)
|
||||
if err != nil {
|
||||
return a, err
|
||||
}
|
||||
@ -133,11 +132,11 @@ func (d *Database) GetByDomain(domain string) ([]ACMETxt, error) {
|
||||
func (d *Database) Update(a ACMETxt) error {
|
||||
// Data in a is already sanitized
|
||||
log.Debugf("Trying to update domain [%s] with TXT data [%s]", a.Subdomain, a.Value)
|
||||
upd_sql := `
|
||||
updSQL := `
|
||||
UPDATE records SET Value=?
|
||||
WHERE Username=? AND Subdomain=?
|
||||
`
|
||||
sm, err := d.DB.Prepare(upd_sql)
|
||||
sm, err := d.DB.Prepare(updSQL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
6
dns.go
6
dns.go
@ -21,7 +21,7 @@ func readQuery(m *dns.Msg) {
|
||||
func answerTXT(q dns.Question) ([]dns.RR, int, error) {
|
||||
var ra []dns.RR
|
||||
var rcode int = dns.RcodeNameError
|
||||
var domain string = strings.ToLower(q.Name)
|
||||
var domain = strings.ToLower(q.Name)
|
||||
|
||||
atxt, err := DB.GetByDomain(SanitizeDomainQuestion(domain))
|
||||
if err != nil {
|
||||
@ -47,7 +47,7 @@ func answer(q dns.Question) ([]dns.RR, int, error) {
|
||||
}
|
||||
var r []dns.RR
|
||||
var rcode int = dns.RcodeSuccess
|
||||
var domain string = strings.ToLower(q.Name)
|
||||
var domain = strings.ToLower(q.Name)
|
||||
var rtype uint16 = q.Qtype
|
||||
r, ok := RR.Records[rtype][domain]
|
||||
if !ok {
|
||||
@ -83,7 +83,7 @@ func (r *Records) Parse(recs []string) {
|
||||
// Create serial
|
||||
serial := time.Now().Format("2006010215")
|
||||
// Add SOA
|
||||
SOAstring := fmt.Sprintf("%s. SOA %s. %s. %s 28800 7200 604800 86400", strings.ToLower(DnsConf.General.Domain), strings.ToLower(DnsConf.General.Nsname), strings.ToLower(DnsConf.General.Nsadmin), serial)
|
||||
SOAstring := fmt.Sprintf("%s. SOA %s. %s. %s 28800 7200 604800 86400", strings.ToLower(DNSConf.General.Domain), strings.ToLower(DNSConf.General.Nsname), strings.ToLower(DNSConf.General.Nsadmin), serial)
|
||||
soarr, err := dns.NewRR(SOAstring)
|
||||
if err != nil {
|
||||
log.Errorf("Error [%v] while trying to add SOA record: [%s]", err, SOAstring)
|
||||
|
36
main.go
36
main.go
@ -13,7 +13,7 @@ import (
|
||||
var log = logging.MustGetLogger("acme-dns")
|
||||
|
||||
// Global configuration struct
|
||||
var DnsConf DnsConfig
|
||||
var DNSConf DNSConfig
|
||||
|
||||
var DB Database
|
||||
|
||||
@ -22,24 +22,24 @@ var RR Records
|
||||
|
||||
func main() {
|
||||
// Read global config
|
||||
config_tmp, err := ReadConfig("config.cfg")
|
||||
configTmp, err := readConfig("config.cfg")
|
||||
if err != nil {
|
||||
fmt.Printf("Got error %v\n", DnsConf.Logconfig.File)
|
||||
fmt.Printf("Got error %v\n", DNSConf.Logconfig.File)
|
||||
os.Exit(1)
|
||||
}
|
||||
DnsConf = config_tmp
|
||||
DNSConf = configTmp
|
||||
// Setup logging
|
||||
var logformat = logging.MustStringFormatter(DnsConf.Logconfig.Format)
|
||||
var logformat = logging.MustStringFormatter(DNSConf.Logconfig.Format)
|
||||
var logBackend *logging.LogBackend
|
||||
switch DnsConf.Logconfig.Logtype {
|
||||
switch DNSConf.Logconfig.Logtype {
|
||||
default:
|
||||
// Setup logging - stdout
|
||||
logBackend = logging.NewLogBackend(os.Stdout, "", 0)
|
||||
case "file":
|
||||
// Logging to file
|
||||
logfh, err := os.OpenFile(DnsConf.Logconfig.File, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
|
||||
logfh, err := os.OpenFile(DNSConf.Logconfig.File, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
fmt.Printf("Could not open log file %s\n", DnsConf.Logconfig.File)
|
||||
fmt.Printf("Could not open log file %s\n", DNSConf.Logconfig.File)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer logfh.Close()
|
||||
@ -47,7 +47,7 @@ func main() {
|
||||
}
|
||||
|
||||
logLevel := logging.AddModuleLevel(logBackend)
|
||||
switch DnsConf.Logconfig.Level {
|
||||
switch DNSConf.Logconfig.Level {
|
||||
case "warning":
|
||||
logLevel.SetLevel(logging.WARNING, "")
|
||||
case "error":
|
||||
@ -59,7 +59,7 @@ func main() {
|
||||
logging.SetBackend(logFormatter)
|
||||
|
||||
// Read the default records in
|
||||
RR.Parse(DnsConf.General.StaticRecords)
|
||||
RR.Parse(DNSConf.General.StaticRecords)
|
||||
|
||||
// Open database
|
||||
err = DB.Init("acme-dns.db")
|
||||
@ -83,26 +83,26 @@ func main() {
|
||||
// API server and endpoints
|
||||
api := iris.New()
|
||||
crs := cors.New(cors.Options{
|
||||
AllowedOrigins: DnsConf.Api.CorsOrigins,
|
||||
AllowedOrigins: DNSConf.API.CorsOrigins,
|
||||
AllowedMethods: []string{"GET", "POST"},
|
||||
OptionsPassthrough: false,
|
||||
Debug: DnsConf.General.Debug,
|
||||
Debug: DNSConf.General.Debug,
|
||||
})
|
||||
api.Use(crs)
|
||||
var ForceAuth AuthMiddleware = AuthMiddleware{}
|
||||
var ForceAuth = AuthMiddleware{}
|
||||
api.Get("/register", WebRegisterGet)
|
||||
api.Post("/register", WebRegisterPost)
|
||||
api.Post("/update", ForceAuth.Serve, WebUpdatePost)
|
||||
// TODO: migrate to api.Serve(iris.LETSENCRYPTPROD("mydomain.com"))
|
||||
switch DnsConf.Api.Tls {
|
||||
switch DNSConf.API.TLS {
|
||||
case "letsencrypt":
|
||||
host := DnsConf.Api.Domain + ":" + DnsConf.Api.Port
|
||||
host := DNSConf.API.Domain + ":" + DNSConf.API.Port
|
||||
api.Listen(host)
|
||||
case "cert":
|
||||
host := DnsConf.Api.Domain + ":" + DnsConf.Api.Port
|
||||
api.ListenTLS(host, DnsConf.Api.Tls_cert_fullchain, DnsConf.Api.Tls_cert_privkey)
|
||||
host := DNSConf.API.Domain + ":" + DNSConf.API.Port
|
||||
api.ListenTLS(host, DNSConf.API.TLSCertFullchain, DNSConf.API.TLSCertPrivkey)
|
||||
default:
|
||||
host := DnsConf.Api.Domain + ":" + DnsConf.Api.Port
|
||||
host := DNSConf.API.Domain + ":" + DNSConf.API.Port
|
||||
api.Listen(host)
|
||||
}
|
||||
if err != nil {
|
||||
|
10
types.go
10
types.go
@ -12,9 +12,9 @@ type Records struct {
|
||||
}
|
||||
|
||||
// Config file main struct
|
||||
type DnsConfig struct {
|
||||
type DNSConfig struct {
|
||||
General general
|
||||
Api httpapi
|
||||
API httpapi
|
||||
Logconfig logconfig
|
||||
}
|
||||
|
||||
@ -34,9 +34,9 @@ type general struct {
|
||||
type httpapi struct {
|
||||
Domain string
|
||||
Port string
|
||||
Tls string
|
||||
Tls_cert_privkey string
|
||||
Tls_cert_fullchain string
|
||||
TLS string
|
||||
TLSCertPrivkey string `toml:"tls_cert_privkey"`
|
||||
TLSCertFullchain string `toml:"tls_cert_fullchain"`
|
||||
CorsOrigins []string
|
||||
}
|
||||
|
||||
|
14
util.go
14
util.go
@ -10,10 +10,10 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func ReadConfig(fname string) (DnsConfig, error) {
|
||||
var conf DnsConfig
|
||||
func readConfig(fname string) (DNSConfig, error) {
|
||||
var conf DNSConfig
|
||||
if _, err := toml.DecodeFile(fname, &conf); err != nil {
|
||||
return DnsConfig{}, errors.New("Malformed configuration file")
|
||||
return DNSConfig{}, errors.New("Malformed configuration file")
|
||||
}
|
||||
return conf, nil
|
||||
}
|
||||
@ -45,9 +45,9 @@ func GeneratePassword(length int) (string, error) {
|
||||
|
||||
func SanitizeDomainQuestion(d string) string {
|
||||
var dom string
|
||||
dns_suff := DnsConf.General.Domain + "."
|
||||
if strings.HasSuffix(d, dns_suff) {
|
||||
dom = d[0 : len(d)-len(dns_suff)]
|
||||
suffix := DNSConf.General.Domain + "."
|
||||
if strings.HasSuffix(d, suffix) {
|
||||
dom = d[0 : len(d)-len(suffix)]
|
||||
} else {
|
||||
dom = d
|
||||
}
|
||||
@ -55,7 +55,7 @@ func SanitizeDomainQuestion(d string) string {
|
||||
}
|
||||
|
||||
func NewACMETxt() (ACMETxt, error) {
|
||||
var a ACMETxt = ACMETxt{}
|
||||
var a = ACMETxt{}
|
||||
password, err := GeneratePassword(40)
|
||||
if err != nil {
|
||||
return a, err
|
||||
|
Reference in New Issue
Block a user