acme-dns/db.go

213 lines
4.4 KiB
Go
Raw Normal View History

2016-11-11 21:48:00 +07:00
package main
import (
"database/sql"
2016-12-01 05:03:08 +07:00
"encoding/json"
2016-11-13 19:50:44 +07:00
"errors"
2016-12-01 05:03:08 +07:00
"regexp"
"time"
log "github.com/sirupsen/logrus"
2016-11-17 22:52:55 +07:00
_ "github.com/lib/pq"
2016-11-11 21:48:00 +07:00
_ "github.com/mattn/go-sqlite3"
2016-11-13 19:50:44 +07:00
"github.com/satori/go.uuid"
"golang.org/x/crypto/bcrypt"
2016-11-11 21:48:00 +07:00
)
2016-11-17 00:15:36 +07:00
var (
	// recordsTable is the DDL executed by Init. It creates the records table
	// (one row per registered account) only when it does not already exist.
	// AllowFrom stores the JSON produced by cidrslice.JSON() in Register.
	recordsTable = `
CREATE TABLE IF NOT EXISTS records(
Username TEXT UNIQUE NOT NULL PRIMARY KEY,
Password TEXT UNIQUE NOT NULL,
Subdomain TEXT UNIQUE NOT NULL,
Value TEXT,
LastActive INT,
AllowFrom TEXT
);`
)
2016-11-17 22:52:55 +07:00
// pgPlaceholderRe matches PostgreSQL positional placeholders such as $1 or
// $12. It is compiled once at package initialisation rather than on every
// call, and matches multi-digit indexes so that "$12" is not turned into "?2".
var pgPlaceholderRe = regexp.MustCompile(`\$[0-9]+`)

// getSQLiteStmt replaces all PostgreSQL prepared statement placeholders
// (eg. $1, $2, $10) with the SQLite variant "?". Text without placeholders is
// returned unchanged.
func getSQLiteStmt(s string) string {
	return pgPlaceholderRe.ReplaceAllString(s, "?")
}
2016-11-28 04:21:46 +07:00
func (d *acmedb) Init(engine string, connection string) error {
2016-11-28 00:41:54 +07:00
d.Lock()
defer d.Unlock()
2016-11-17 22:52:55 +07:00
db, err := sql.Open(engine, connection)
2016-11-11 21:48:00 +07:00
if err != nil {
return err
}
d.DB = db
2016-11-28 00:41:54 +07:00
//d.DB.SetMaxOpenConns(1)
2016-11-17 00:15:36 +07:00
_, err = d.DB.Exec(recordsTable)
2016-11-11 21:48:00 +07:00
if err != nil {
return err
}
return nil
}
2016-12-01 05:03:08 +07:00
func (d *acmedb) Register(afrom cidrslice) (ACMETxt, error) {
2016-11-28 00:41:54 +07:00
d.Lock()
defer d.Unlock()
2016-11-28 06:55:57 +07:00
a := newACMETxt()
2016-12-01 05:03:08 +07:00
a.AllowFrom = cidrslice(afrom.ValidEntries())
2016-11-17 00:15:36 +07:00
passwordHash, err := bcrypt.GenerateFromPassword([]byte(a.Password), 10)
2016-11-17 22:52:55 +07:00
timenow := time.Now().Unix()
2016-11-17 00:15:36 +07:00
regSQL := `
2016-11-11 21:48:00 +07:00
INSERT INTO records(
Username,
Password,
Subdomain,
2016-11-17 22:52:55 +07:00
Value,
LastActive,
AllowFrom)
values($1, $2, $3, '', $4, $5)`
2016-11-17 22:52:55 +07:00
if DNSConf.Database.Engine == "sqlite3" {
regSQL = getSQLiteStmt(regSQL)
}
2016-11-17 00:15:36 +07:00
sm, err := d.DB.Prepare(regSQL)
2016-11-11 21:48:00 +07:00
if err != nil {
2016-12-01 05:03:08 +07:00
log.WithFields(log.Fields{"error": err.Error()}).Error("Database error in prepare")
return a, errors.New("SQL error")
2016-11-11 21:48:00 +07:00
}
defer sm.Close()
2016-12-01 05:03:08 +07:00
_, err = sm.Exec(a.Username.String(), passwordHash, a.Subdomain, timenow, a.AllowFrom.JSON())
2016-11-11 21:48:00 +07:00
if err != nil {
return a, err
}
return a, nil
}
2016-11-28 04:21:46 +07:00
func (d *acmedb) GetByUsername(u uuid.UUID) (ACMETxt, error) {
2016-11-28 00:41:54 +07:00
d.Lock()
defer d.Unlock()
2016-11-11 21:48:00 +07:00
var results []ACMETxt
2016-11-17 00:15:36 +07:00
getSQL := `
SELECT Username, Password, Subdomain, Value, LastActive, AllowFrom
2016-11-11 21:48:00 +07:00
FROM records
2016-11-17 22:52:55 +07:00
WHERE Username=$1 LIMIT 1
2016-11-11 21:48:00 +07:00
`
2016-11-17 22:52:55 +07:00
if DNSConf.Database.Engine == "sqlite3" {
getSQL = getSQLiteStmt(getSQL)
}
2016-11-17 00:15:36 +07:00
sm, err := d.DB.Prepare(getSQL)
2016-11-11 21:48:00 +07:00
if err != nil {
2016-11-13 19:50:44 +07:00
return ACMETxt{}, err
2016-11-11 21:48:00 +07:00
}
defer sm.Close()
2016-11-13 19:50:44 +07:00
rows, err := sm.Query(u.String())
2016-11-11 21:48:00 +07:00
if err != nil {
2016-11-13 19:50:44 +07:00
return ACMETxt{}, err
2016-11-11 21:48:00 +07:00
}
defer rows.Close()
// It will only be one row though
for rows.Next() {
txt, err := getModelFromRow(rows)
2016-11-11 21:48:00 +07:00
if err != nil {
2016-11-13 19:50:44 +07:00
return ACMETxt{}, err
2016-11-11 21:48:00 +07:00
}
results = append(results, txt)
2016-11-11 21:48:00 +07:00
}
2016-11-13 19:50:44 +07:00
if len(results) > 0 {
return results[0], nil
}
2016-11-17 00:15:36 +07:00
return ACMETxt{}, errors.New("no user")
2016-11-11 21:48:00 +07:00
}
2016-11-28 04:21:46 +07:00
func (d *acmedb) GetByDomain(domain string) ([]ACMETxt, error) {
2016-11-28 00:41:54 +07:00
d.Lock()
defer d.Unlock()
2016-11-23 22:11:31 +07:00
domain = sanitizeString(domain)
2016-11-11 21:48:00 +07:00
var a []ACMETxt
2016-11-17 00:15:36 +07:00
getSQL := `
SELECT Username, Password, Subdomain, Value, LastActive, AllowFrom
2016-11-11 21:48:00 +07:00
FROM records
2016-11-17 22:52:55 +07:00
WHERE Subdomain=$1 LIMIT 1
2016-11-11 21:48:00 +07:00
`
2016-11-17 22:52:55 +07:00
if DNSConf.Database.Engine == "sqlite3" {
getSQL = getSQLiteStmt(getSQL)
}
2016-11-17 00:15:36 +07:00
sm, err := d.DB.Prepare(getSQL)
2016-11-11 21:48:00 +07:00
if err != nil {
return a, err
}
defer sm.Close()
rows, err := sm.Query(domain)
if err != nil {
return a, err
}
defer rows.Close()
for rows.Next() {
txt, err := getModelFromRow(rows)
2016-11-11 21:48:00 +07:00
if err != nil {
return a, err
}
a = append(a, txt)
}
return a, nil
}
2016-11-28 04:21:46 +07:00
func (d *acmedb) Update(a ACMETxt) error {
2016-11-28 00:41:54 +07:00
d.Lock()
defer d.Unlock()
2016-11-11 21:48:00 +07:00
// Data in a is already sanitized
2016-11-17 22:52:55 +07:00
timenow := time.Now().Unix()
2016-11-17 00:15:36 +07:00
updSQL := `
2016-11-17 22:52:55 +07:00
UPDATE records SET Value=$1, LastActive=$2
WHERE Username=$3 AND Subdomain=$4
2016-11-11 21:48:00 +07:00
`
2016-11-17 22:52:55 +07:00
if DNSConf.Database.Engine == "sqlite3" {
updSQL = getSQLiteStmt(updSQL)
}
2016-11-17 00:15:36 +07:00
sm, err := d.DB.Prepare(updSQL)
2016-11-11 21:48:00 +07:00
if err != nil {
return err
}
defer sm.Close()
2016-11-17 22:52:55 +07:00
_, err = sm.Exec(a.Value, timenow, a.Username, a.Subdomain)
2016-11-11 21:48:00 +07:00
if err != nil {
return err
}
return nil
}
2016-11-28 04:21:46 +07:00
func getModelFromRow(r *sql.Rows) (ACMETxt, error) {
txt := ACMETxt{}
2016-12-01 05:03:08 +07:00
afrom := ""
err := r.Scan(
&txt.Username,
&txt.Password,
&txt.Subdomain,
&txt.Value,
&txt.LastActive,
2016-12-01 05:03:08 +07:00
&afrom)
if err != nil {
log.WithFields(log.Fields{"error": err.Error()}).Error("Row scan error")
}
cslice := cidrslice{}
err = json.Unmarshal([]byte(afrom), &cslice)
if err != nil {
log.WithFields(log.Fields{"error": err.Error()}).Error("JSON unmarshall error")
}
txt.AllowFrom = cslice
return txt, err
}
2016-11-28 04:21:46 +07:00
// Close closes the underlying database handle d.DB. Any error reported by
// the driver is discarded.
func (d *acmedb) Close() {
	_ = d.DB.Close()
}
// GetBackend exposes the *sql.DB handle currently held in d.DB.
func (d *acmedb) GetBackend() *sql.DB {
	backend := d.DB
	return backend
}
// SetBackend swaps the handle held in d.DB for backend. The previous handle
// is not closed.
func (d *acmedb) SetBackend(backend *sql.DB) {
	d.DB = backend
}