acme-dns/db.go
2018-08-10 16:51:32 +03:00

344 lines
7.7 KiB
Go

package main
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"regexp"
"strconv"
"time"
"github.com/google/uuid"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt"
)
// Schema-level definitions used by Init and the upgrade routines.
var (
	// DBVersion shows the database version this code uses. This is used for
	// update checks: Init compares it with the db_version row in acmedns.
	DBVersion = 1

	// acmeTable holds acme-dns metadata as Name/Value pairs; the version
	// bookkeeping reads and writes the 'db_version' row here.
	acmeTable = `
CREATE TABLE IF NOT EXISTS acmedns(
Name TEXT,
Value TEXT
);`

	// userTable holds one row per registered account.
	userTable = `
CREATE TABLE IF NOT EXISTS records(
Username TEXT UNIQUE NOT NULL PRIMARY KEY,
Password TEXT UNIQUE NOT NULL,
Subdomain TEXT UNIQUE NOT NULL,
AllowFrom TEXT
);`

	// txtTable is the SQLite form of the TXT value table. SQLite supplies the
	// implicit rowid column that Update addresses rows by.
	txtTable = `
CREATE TABLE IF NOT EXISTS txt(
Subdomain TEXT NOT NULL,
Value TEXT NOT NULL DEFAULT '',
LastUpdate INT
);`

	// txtTablePG is the PostgreSQL form of the TXT value table. PostgreSQL has
	// no implicit rowid, so the column is declared explicitly.
	txtTablePG = `
CREATE TABLE IF NOT EXISTS txt(
rowid SERIAL,
Subdomain TEXT NOT NULL,
Value TEXT NOT NULL DEFAULT '',
LastUpdate INT
);`
)
// pgPlaceholderRe matches PostgreSQL positional placeholders such as $1 or
// $12. It is compiled once at package initialization rather than per call.
var pgPlaceholderRe = regexp.MustCompile(`\$[0-9]+`)

// getSQLiteStmt replaces all PostgreSQL prepared statement placeholders
// (eg. $1, $2, $10) with the SQLite variant "?".
//
// "?" placeholders are bound strictly by position, so the result is only
// equivalent when each $N appears once and in ascending order.
func getSQLiteStmt(s string) string {
	return pgPlaceholderRe.ReplaceAllString(s, "?")
}
// Init opens the database identified by engine and connection. It then
// creates the acme-dns tables if they do not exist yet and runs any pending
// schema upgrades. For fresh or pre-versioning databases it also records the
// current DBVersion. d's lock is held for the whole call.
//
// engine is the database/sql driver name. "sqlite3" selects the SQLite
// variant of the txt table; any other value selects the PostgreSQL variant.
// The first error encountered is returned.
func (d *acmedb) Init(engine string, connection string) error {
	d.Lock()
	defer d.Unlock()
	db, err := sql.Open(engine, connection)
	if err != nil {
		return err
	}
	d.DB = db
	// Check version first to try to catch old versions without version string.
	// The error is deliberately ignored: on a brand-new database the acmedns
	// table does not exist yet, which is treated the same as version 0.
	var versionString string
	_ = d.DB.QueryRow("SELECT Value FROM acmedns WHERE Name='db_version'").Scan(&versionString)
	if versionString == "" {
		versionString = "0"
	}
	// Choose the txt schema for the engine that was actually opened above.
	txtSchema := txtTablePG
	if engine == "sqlite3" {
		txtSchema = txtTable
	}
	// Every table must exist before upgrades run, so stop at the first failure.
	for _, schema := range []string{acmeTable, userTable, txtSchema} {
		if _, err = d.DB.Exec(schema); err != nil {
			return err
		}
	}
	if err = d.checkDBUpgrades(versionString); err != nil {
		return err
	}
	if versionString == "0" {
		// No errors, so the schema is now at DBVersion; record that.
		insversion := fmt.Sprintf("INSERT INTO acmedns (Name, Value) values('db_version', '%d')", DBVersion)
		_, err = d.DB.Exec(insversion)
	}
	return err
}
// checkDBUpgrades parses the stored schema version string. When it differs
// from DBVersion, control passes to handleDBUpgrades. A string that is not a
// decimal integer yields the strconv parse error.
func (d *acmedb) checkDBUpgrades(versionString string) error {
	stored, parseErr := strconv.Atoi(versionString)
	switch {
	case parseErr != nil:
		return parseErr
	case stored == DBVersion:
		// Already current; nothing to migrate.
		return nil
	default:
		return d.handleDBUpgrades(stored)
	}
}
// handleDBUpgrades runs the upgrade routine that matches the stored schema
// version. Versions without a dedicated routine are left as they are.
func (d *acmedb) handleDBUpgrades(version int) error {
	switch version {
	case 0:
		return d.handleDBUpgradeTo1()
	default:
		return nil
	}
}
// handleDBUpgradeTo1 migrates a version 0 database to version 1. It:
//   - rebuilds the txt table with two empty rows per non-empty registered
//     subdomain;
//   - drops the obsolete Value and LastActive columns from records (skipped
//     on sqlite3);
//   - updates the db_version row (if present) to 1.
//
// All writes run in one transaction. It is rolled back on the first error
// and committed otherwise; a failed commit is returned to the caller.
func (d *acmedb) handleDBUpgradeTo1() (err error) {
	// Read everything up front so the result set, and the connection it holds,
	// is released before the transaction starts.
	subdomains, err := d.upgradeReadSubdomains()
	if err != nil {
		return err
	}
	tx, err := d.DB.Begin()
	if err != nil {
		log.WithFields(log.Fields{"error": err.Error()}).Error("Error in DB upgrade")
		return err
	}
	// Rollback if errored, commit if not. Assigning the Commit result to the
	// named return value makes a failed commit visible to the caller.
	defer func() {
		if err != nil {
			_ = tx.Rollback()
			return
		}
		err = tx.Commit()
	}()
	if _, err = tx.Exec("DELETE FROM txt"); err != nil {
		log.WithFields(log.Fields{"error": err.Error()}).Error("Error in DB upgrade while clearing txt table")
		return err
	}
	for _, subdomain := range subdomains {
		if subdomain == "" {
			continue
		}
		// Insert two rows for each subdomain to txt table
		if err = d.NewTXTValuesInTransaction(tx, subdomain); err != nil {
			log.WithFields(log.Fields{"error": err.Error()}).Error("Error in DB upgrade while inserting values")
			return err
		}
	}
	// Older SQLite versions don't support dropping columns, so skip this on
	// sqlite3. The results are intentionally ignored because IF EXISTS makes
	// these no-ops when the columns are already gone. On PostgreSQL an
	// unexpected failure aborts the transaction, so the UPDATE below fails and
	// the whole upgrade rolls back.
	if Config.Database.Engine != "sqlite3" {
		_, _ = tx.Exec("ALTER TABLE records DROP COLUMN IF EXISTS Value")
		_, _ = tx.Exec("ALTER TABLE records DROP COLUMN IF EXISTS LastActive")
	}
	_, err = tx.Exec("UPDATE acmedns SET Value='1' WHERE Name='db_version'")
	return err
}

// upgradeReadSubdomains returns the Subdomain value of every row in the
// records table, for use by handleDBUpgradeTo1.
func (d *acmedb) upgradeReadSubdomains() ([]string, error) {
	rows, err := d.DB.Query("SELECT Subdomain FROM records")
	if err != nil {
		log.WithFields(log.Fields{"error": err.Error()}).Error("Error in DB upgrade")
		return nil, err
	}
	defer rows.Close()
	var subdomains []string
	for rows.Next() {
		var subdomain string
		if err = rows.Scan(&subdomain); err != nil {
			log.WithFields(log.Fields{"error": err.Error()}).Error("Error in DB upgrade while reading values")
			return nil, err
		}
		subdomains = append(subdomains, subdomain)
	}
	if err = rows.Err(); err != nil {
		log.WithFields(log.Fields{"error": err.Error()}).Error("Error in DB upgrade while reading values")
		return nil, err
	}
	return subdomains, nil
}
// NewTXTValuesInTransaction inserts two rows for subdomain into the txt table
// using tx. Each row has LastUpdate 0 and the column default for Value.
// Two rows exist per subdomain because lookups return up to two TXT values,
// and Update overwrites whichever row has the oldest LastUpdate.
// It returns the first insert error, if any.
func (d *acmedb) NewTXTValuesInTransaction(tx *sql.Tx, subdomain string) error {
	// Bind subdomain as a parameter instead of formatting it into the SQL
	// text, so quotes in the value cannot break or alter the statement.
	insSQL := "INSERT INTO txt (Subdomain, LastUpdate) values($1, 0)"
	if Config.Database.Engine == "sqlite3" {
		insSQL = getSQLiteStmt(insSQL)
	}
	for i := 0; i < 2; i++ {
		if _, err := tx.Exec(insSQL, subdomain); err != nil {
			return err
		}
	}
	return nil
}
// Register creates a new account from newACMETxt. It restricts AllowFrom to
// the valid entries of afrom and stores the record with a bcrypt hash of the
// generated password. It then creates the two txt rows for the new subdomain.
//
// Everything runs in a single transaction. The transaction is committed only
// if every step succeeds; a commit failure is reported through the returned
// error. The returned ACMETxt still carries the plaintext password; only its
// hash is written to the database.
func (d *acmedb) Register(afrom cidrslice) (a ACMETxt, err error) {
	d.Lock()
	defer d.Unlock()
	tx, err := d.DB.Begin()
	if err != nil {
		// Without a transaction the deferred Rollback/Commit below would
		// dereference a nil *sql.Tx, so bail out before installing it.
		log.WithFields(log.Fields{"error": err.Error()}).Error("Database error in begin")
		return a, err
	}
	// Rollback if errored, commit if not.
	defer func() {
		if err != nil {
			_ = tx.Rollback()
			return
		}
		err = tx.Commit()
	}()
	a = newACMETxt()
	a.AllowFrom = cidrslice(afrom.ValidEntries())
	passwordHash, err := bcrypt.GenerateFromPassword([]byte(a.Password), 10)
	if err != nil {
		log.WithFields(log.Fields{"error": err.Error()}).Error("Could not generate password hash")
		return a, err
	}
	regSQL := `
INSERT INTO records(
Username,
Password,
Subdomain,
AllowFrom)
values($1, $2, $3, $4)`
	if Config.Database.Engine == "sqlite3" {
		regSQL = getSQLiteStmt(regSQL)
	}
	sm, err := tx.Prepare(regSQL)
	if err != nil {
		log.WithFields(log.Fields{"error": err.Error()}).Error("Database error in prepare")
		return a, errors.New("SQL error")
	}
	// Deferred after the commit closure, so the statement is closed first.
	defer sm.Close()
	_, err = sm.Exec(a.Username.String(), passwordHash, a.Subdomain, a.AllowFrom.JSON())
	if err == nil {
		err = d.NewTXTValuesInTransaction(tx, a.Subdomain)
	}
	return a, err
}
// GetByUsername returns the account whose Username equals u.
// It returns errors.New("no user") when no row matches. Prepare, query,
// iteration and row-decoding errors are returned as-is, with an empty
// ACMETxt.
func (d *acmedb) GetByUsername(u uuid.UUID) (ACMETxt, error) {
	d.Lock()
	defer d.Unlock()
	getSQL := `
SELECT Username, Password, Subdomain, AllowFrom
FROM records
WHERE Username=$1 LIMIT 1
`
	if Config.Database.Engine == "sqlite3" {
		getSQL = getSQLiteStmt(getSQL)
	}
	sm, err := d.DB.Prepare(getSQL)
	if err != nil {
		return ACMETxt{}, err
	}
	defer sm.Close()
	rows, err := sm.Query(u.String())
	if err != nil {
		return ACMETxt{}, err
	}
	defer rows.Close()
	// The query is LIMIT 1, so at most one row can come back.
	if rows.Next() {
		txt, err := getModelFromRow(rows)
		if err != nil {
			return ACMETxt{}, err
		}
		return txt, nil
	}
	// Next also returns false on a driver error; report that instead of
	// pretending the user does not exist.
	if err := rows.Err(); err != nil {
		return ACMETxt{}, err
	}
	return ACMETxt{}, errors.New("no user")
}
// GetTXTForDomain returns the stored TXT values (at most two) for domain,
// after passing domain through sanitizeString. If an error occurs, the values
// read before it are returned together with the error.
func (d *acmedb) GetTXTForDomain(domain string) ([]string, error) {
	d.Lock()
	defer d.Unlock()
	domain = sanitizeString(domain)
	var txts []string
	getSQL := `
SELECT Value FROM txt WHERE Subdomain=$1 LIMIT 2
`
	if Config.Database.Engine == "sqlite3" {
		getSQL = getSQLiteStmt(getSQL)
	}
	sm, err := d.DB.Prepare(getSQL)
	if err != nil {
		return txts, err
	}
	defer sm.Close()
	rows, err := sm.Query(domain)
	if err != nil {
		return txts, err
	}
	defer rows.Close()
	for rows.Next() {
		var rtxt string
		if err = rows.Scan(&rtxt); err != nil {
			return txts, err
		}
		txts = append(txts, rtxt)
	}
	// Next also returns false on a driver error; surface it rather than
	// returning a silently truncated result.
	if err = rows.Err(); err != nil {
		return txts, err
	}
	return txts, nil
}
// Update writes a.Value into the txt row for a.Subdomain with the smallest
// LastUpdate, and sets that row's LastUpdate to the current Unix time. It
// returns any prepare or execution error.
func (d *acmedb) Update(a ACMETxt) error {
	d.Lock()
	defer d.Unlock()
	// Data in a is already sanitized
	now := time.Now().Unix()
	stmtSQL := `
UPDATE txt SET Value=$1, LastUpdate=$2
WHERE rowid=(
SELECT rowid FROM txt WHERE Subdomain=$3 ORDER BY LastUpdate LIMIT 1)
`
	if Config.Database.Engine == "sqlite3" {
		stmtSQL = getSQLiteStmt(stmtSQL)
	}
	stmt, prepErr := d.DB.Prepare(stmtSQL)
	if prepErr != nil {
		return prepErr
	}
	defer stmt.Close()
	_, execErr := stmt.Exec(a.Value, now, a.Subdomain)
	return execErr
}
// getModelFromRow scans the current row of r into an ACMETxt. The row is
// expected to hold Username, Password, Subdomain and AllowFrom, in that order.
// AllowFrom is stored as JSON text and is decoded into a cidrslice.
//
// A scan failure returns an empty ACMETxt and the scan error. A JSON decode
// failure returns the scanned fields and the decode error.
func getModelFromRow(r *sql.Rows) (ACMETxt, error) {
	txt := ACMETxt{}
	afrom := ""
	if err := r.Scan(
		&txt.Username,
		&txt.Password,
		&txt.Subdomain,
		&afrom); err != nil {
		log.WithFields(log.Fields{"error": err.Error()}).Error("Row scan error")
		// Return the real cause. Continuing would try to decode the empty
		// afrom and replace this error with an unrelated JSON error.
		return ACMETxt{}, err
	}
	cslice := cidrslice{}
	err := json.Unmarshal([]byte(afrom), &cslice)
	if err != nil {
		log.WithFields(log.Fields{"error": err.Error()}).Error("JSON unmarshall error")
	}
	txt.AllowFrom = cslice
	return txt, err
}
// Close closes the underlying database handle. Because Close has no error
// return, any error reported by the handle is discarded.
func (d *acmedb) Close() {
	_ = d.DB.Close()
}

// GetBackend returns the *sql.DB currently held by d. It does not take d's
// lock.
func (d *acmedb) GetBackend() *sql.DB {
	backend := d.DB
	return backend
}

// SetBackend makes d use backend as its database handle. The previous handle
// is not closed, and d's lock is not taken.
func (d *acmedb) SetBackend(backend *sql.DB) {
	d.DB = backend
}