diff --git a/acmetxt.go b/acmetxt.go index d4dbab0..8584017 100644 --- a/acmetxt.go +++ b/acmetxt.go @@ -12,8 +12,7 @@ type ACMETxt struct { Username uuid.UUID Password string ACMETxtPost - LastActive int64 - AllowFrom cidrslice + AllowFrom cidrslice } // ACMETxtPost holds the DNS part of the ACMETxt struct diff --git a/db.go b/db.go index 9b6b0a0..d245808 100644 --- a/db.go +++ b/db.go @@ -4,7 +4,9 @@ import ( "database/sql" "encoding/json" "errors" + "fmt" "regexp" + "strconv" "time" _ "github.com/lib/pq" @@ -14,16 +16,38 @@ import ( "golang.org/x/crypto/bcrypt" ) -var recordsTable = ` +// DBVersion shows the database version this code uses. This is used for update checks. +var DBVersion = 1 + +var acmeTable = ` + CREATE TABLE IF NOT EXISTS acmedns( + Name TEXT, + Value TEXT + );` + +var userTable = ` CREATE TABLE IF NOT EXISTS records( Username TEXT UNIQUE NOT NULL PRIMARY KEY, Password TEXT UNIQUE NOT NULL, Subdomain TEXT UNIQUE NOT NULL, - Value TEXT, - LastActive INT, AllowFrom TEXT );` +var txtTable = ` + CREATE TABLE IF NOT EXISTS txt( + Subdomain TEXT NOT NULL, + Value TEXT NOT NULL DEFAULT '', + LastUpdate INT + );` + +var txtTablePG = ` + CREATE TABLE IF NOT EXISTS txt( + rowid SERIAL, + Subdomain TEXT NOT NULL, + Value TEXT NOT NULL DEFAULT '', + LastUpdate INT + );` + // getSQLiteStmt replaces all PostgreSQL prepared statement placeholders (eg. $1, $2) with SQLite variant "?" func getSQLiteStmt(s string) string { re, _ := regexp.Compile("\\$[0-9]") @@ -38,44 +62,151 @@ func (d *acmedb) Init(engine string, connection string) error { return err } d.DB = db - //d.DB.SetMaxOpenConns(1) - _, err = d.DB.Exec(recordsTable) + // Check version first to try to catch old versions without version string + var versionString string + _ = d.DB.QueryRow("SELECT Value FROM acmedns WHERE Name='db_version'").Scan(&versionString) + if versionString == "" { + versionString = "0" + } + _, err = d.DB.Exec(acmeTable) + _, err = d.DB.Exec(userTable) + if Config.Database.Engine == "sqlite3" { + _, err = d.DB.Exec(txtTable) + } else { + _, err = d.DB.Exec(txtTablePG) + } + // If everything is fine, handle db upgrade tasks + if err == nil { + err = d.checkDBUpgrades(versionString) + } + if err == nil { + if versionString == "0" { + // No errors so we should now be in version 1 + insversion := fmt.Sprintf("INSERT INTO acmedns (Name, Value) values('db_version', '%d')", DBVersion) + _, err = db.Exec(insversion) + } + } + return err +} + +func (d *acmedb) checkDBUpgrades(versionString string) error { + var err error + version, err := strconv.Atoi(versionString) if err != nil { return err } + if version != DBVersion { + return d.handleDBUpgrades(version) + } return nil + +} + +func (d *acmedb) handleDBUpgrades(version int) error { + if version == 0 { + return d.handleDBUpgradeTo1() + } + return nil +} + +func (d *acmedb) handleDBUpgradeTo1() error { + var err error + var subdomains []string + rows, err := d.DB.Query("SELECT Subdomain FROM records") + if err != nil { + log.WithFields(log.Fields{"error": err.Error()}).Error("Error in DB upgrade") + return err + } + defer rows.Close() + for rows.Next() { + var subdomain string + err = rows.Scan(&subdomain) + if err != nil { + log.WithFields(log.Fields{"error": err.Error()}).Error("Error in DB upgrade while reading values") + return err + } + subdomains = append(subdomains, subdomain) + } + err = rows.Err() + if err != nil { + log.WithFields(log.Fields{"error": err.Error()}).Error("Error in DB upgrade while inserting values") + return err + } + tx, err := d.DB.Begin() + // Rollback if errored, commit if not + defer func() { + if err != nil { + tx.Rollback() + return + } + tx.Commit() + }() + _, _ = tx.Exec("DELETE FROM txt") + for _, subdomain := range subdomains { + if subdomain != "" { + // Insert two rows for each subdomain to txt table + err = d.NewTXTValuesInTransaction(tx, subdomain) + if err != nil { + log.WithFields(log.Fields{"error": err.Error()}).Error("Error in DB upgrade while inserting values") + return err + } + } + } + // SQLite doesn't support dropping columns + if Config.Database.Engine != "sqlite3" { + _, _ = tx.Exec("ALTER TABLE records DROP COLUMN IF EXISTS Value") + _, _ = tx.Exec("ALTER TABLE records DROP COLUMN IF EXISTS LastActive") + } + _, err = tx.Exec("UPDATE acmedns SET Value='1' WHERE Name='db_version'") + return err +} + +// Create two rows for subdomain to the txt table +func (d *acmedb) NewTXTValuesInTransaction(tx *sql.Tx, subdomain string) error { + var err error + instr := fmt.Sprintf("INSERT INTO txt (Subdomain, LastUpdate) values('%s', 0)", subdomain) + _, err = tx.Exec(instr) + _, err = tx.Exec(instr) + return err } func (d *acmedb) Register(afrom cidrslice) (ACMETxt, error) { d.Lock() defer d.Unlock() + var err error + tx, err := d.DB.Begin() + // Rollback if errored, commit if not + defer func() { + if err != nil { + tx.Rollback() + return + } + tx.Commit() + }() a := newACMETxt() a.AllowFrom = cidrslice(afrom.ValidEntries()) passwordHash, err := bcrypt.GenerateFromPassword([]byte(a.Password), 10) - timenow := time.Now().Unix() regSQL := ` INSERT INTO records( Username, Password, Subdomain, - Value, - LastActive, AllowFrom) - values($1, $2, $3, '', $4, $5)` + values($1, $2, $3, $4)` if Config.Database.Engine == "sqlite3" { regSQL = getSQLiteStmt(regSQL) } - sm, err := d.DB.Prepare(regSQL) + sm, err := tx.Prepare(regSQL) if err != nil { log.WithFields(log.Fields{"error": err.Error()}).Error("Database error in prepare") return a, errors.New("SQL error") } defer sm.Close() - _, err = sm.Exec(a.Username.String(), passwordHash, a.Subdomain, timenow, a.AllowFrom.JSON()) - if err != nil { - return a, err + _, err = sm.Exec(a.Username.String(), passwordHash, a.Subdomain, a.AllowFrom.JSON()) + if err == nil { + err = d.NewTXTValuesInTransaction(tx, a.Subdomain) } - return a, nil + return a, err } func (d *acmedb) GetByUsername(u uuid.UUID) (ACMETxt, error) { @@ -83,7 +214,7 @@ func (d *acmedb) GetByUsername(u uuid.UUID) (ACMETxt, error) { defer d.Unlock() var results []ACMETxt getSQL := ` - SELECT Username, Password, Subdomain, Value, LastActive, AllowFrom + SELECT Username, Password, Subdomain, AllowFrom FROM records WHERE Username=$1 LIMIT 1 ` @@ -116,15 +247,13 @@ func (d *acmedb) GetByUsername(u uuid.UUID) (ACMETxt, error) { return ACMETxt{}, errors.New("no user") } -func (d *acmedb) GetByDomain(domain string) ([]ACMETxt, error) { +func (d *acmedb) GetTXTForDomain(domain string) ([]string, error) { d.Lock() defer d.Unlock() domain = sanitizeString(domain) - var a []ACMETxt + var txts []string getSQL := ` - SELECT Username, Password, Subdomain, Value, LastActive, AllowFrom - FROM records - WHERE Subdomain=$1 LIMIT 1 + SELECT Value FROM txt WHERE Subdomain=$1 LIMIT 2 ` if Config.Database.Engine == "sqlite3" { getSQL = getSQLiteStmt(getSQL) @@ -132,33 +261,37 @@ func (d *acmedb) GetByDomain(domain string) ([]ACMETxt, error) { sm, err := d.DB.Prepare(getSQL) if err != nil { - return a, err + return txts, err } defer sm.Close() rows, err := sm.Query(domain) if err != nil { - return a, err + return txts, err } defer rows.Close() for rows.Next() { - txt, err := getModelFromRow(rows) + var rtxt string + err = rows.Scan(&rtxt) if err != nil { - return a, err + return txts, err } - a = append(a, txt) + txts = append(txts, rtxt) } - return a, nil + return txts, nil } func (d *acmedb) Update(a ACMETxt) error { d.Lock() defer d.Unlock() + var err error // Data in a is already sanitized timenow := time.Now().Unix() + updSQL := ` - UPDATE records SET Value=$1, LastActive=$2 - WHERE Username=$3 AND Subdomain=$4 + UPDATE txt SET Value=$1, LastUpdate=$2 + WHERE rowid=( + SELECT rowid FROM txt WHERE Subdomain=$3 ORDER BY LastUpdate LIMIT 1) ` if Config.Database.Engine == "sqlite3" { updSQL = getSQLiteStmt(updSQL) @@ -169,7 +302,7 @@ func (d *acmedb) Update(a ACMETxt) error { return err } defer sm.Close() - _, err = sm.Exec(a.Value, timenow, a.Username, a.Subdomain) + _, err = sm.Exec(a.Value, timenow, a.Subdomain) if err != nil { return err } @@ -183,8 +316,6 @@ func getModelFromRow(r *sql.Rows) (ACMETxt, error) { &txt.Username, &txt.Password, &txt.Subdomain, - &txt.Value, - &txt.LastActive, &afrom) if err != nil { log.WithFields(log.Fields{"error": err.Error()}).Error("Row scan error") diff --git a/db_test.go b/db_test.go index 4580a34..3cb265c 100644 --- a/db_test.go +++ b/db_test.go @@ -118,7 +118,7 @@ func TestPrepareErrors(t *testing.T) { t.Errorf("Expected error, but didn't get one") } - _, err = DB.GetByDomain(reg.Subdomain) + _, err = DB.GetTXTForDomain(reg.Subdomain) if err == nil { t.Errorf("Expected error, but didn't get one") } @@ -151,7 +151,7 @@ func TestQueryExecErrors(t *testing.T) { t.Errorf("Expected error from exec, but got none") } - _, err = DB.GetByDomain(reg.Subdomain) + _, err = DB.GetTXTForDomain(reg.Subdomain) if err == nil { t.Errorf("Expected error from exec in GetByDomain, but got none") } @@ -195,11 +195,6 @@ func TestQueryScanErrors(t *testing.T) { if err == nil { t.Errorf("Expected error from scan in, but got none") } - - _, err = DB.GetByDomain(reg.Subdomain) - if err == nil { - t.Errorf("Expected error from scan in GetByDomain, but got none") - } } func TestBadDBValues(t *testing.T) { @@ -226,46 +221,55 @@ func TestBadDBValues(t *testing.T) { t.Errorf("Expected error from scan in, but got none") } - _, err = DB.GetByDomain(reg.Subdomain) + _, err = DB.GetTXTForDomain(reg.Subdomain) if err == nil { t.Errorf("Expected error from scan in GetByDomain, but got none") } } -func TestGetByDomain(t *testing.T) { - var regDomain = ACMETxt{} - +func TestGetTXTForDomain(t *testing.T) { // Create reg to refer to reg, err := DB.Register(cidrslice{}) if err != nil { t.Errorf("Registration failed, got error [%v]", err) } - regDomainSlice, err := DB.GetByDomain(reg.Subdomain) + txtval1 := "___validation_token_received_from_the_ca___" + txtval2 := "___validation_token_received_YEAH_the_ca___" + + reg.Value = txtval1 + _ = DB.Update(reg) + + reg.Value = txtval2 + _ = DB.Update(reg) + + regDomainSlice, err := DB.GetTXTForDomain(reg.Subdomain) if err != nil { t.Errorf("Could not get test user, got error [%v]", err) } if len(regDomainSlice) == 0 { - t.Errorf("No rows returned for GetByDomain [%s]", reg.Subdomain) - } else { - regDomain = regDomainSlice[0] + t.Errorf("No rows returned for GetTXTForDomain [%s]", reg.Subdomain) } - if reg.Username != regDomain.Username { - t.Errorf("GetByUsername username [%q] did not match the original [%q]", regDomain.Username, reg.Username) + var val1found = false + var val2found = false + for _, v := range regDomainSlice { + if v == txtval1 { + val1found = true + } + if v == txtval2 { + val2found = true + } } - - if reg.Subdomain != regDomain.Subdomain { - t.Errorf("GetByUsername subdomain [%q] did not match the original [%q]", regDomain.Subdomain, reg.Subdomain) + if !val1found { + t.Errorf("No TXT value found for val1") } - - // regDomain password already is a bcrypt hash - if !correctPassword(reg.Password, regDomain.Password) { - t.Errorf("The password [%s] does not match the hash [%s]", reg.Password, regDomain.Password) + if !val2found { + t.Errorf("No TXT value found for val2") } // Not found - regNotfound, _ := DB.GetByDomain("does-not-exist") + regNotfound, _ := DB.GetTXTForDomain("does-not-exist") if len(regNotfound) > 0 { t.Errorf("No records should be returned.") } @@ -294,12 +298,4 @@ func TestUpdate(t *testing.T) { if err != nil { t.Errorf("DB Update failed, got error: [%v]", err) } - - updUser, err := DB.GetByUsername(regUser.Username) - if err != nil { - t.Errorf("GetByUsername threw error [%v]", err) - } - if updUser.Value != validTXT { - t.Errorf("Update failed, fetched value [%s] does not match the update value [%s]", updUser.Value, validTXT) - } } diff --git a/dns.go b/dns.go index 122c5d7..3e687a4 100644 --- a/dns.go +++ b/dns.go @@ -2,8 +2,8 @@ package main import ( "fmt" - log "github.com/sirupsen/logrus" "github.com/miekg/dns" + log "github.com/sirupsen/logrus" "strings" "time" ) @@ -23,16 +23,16 @@ func answerTXT(q dns.Question) ([]dns.RR, int, error) { var ra []dns.RR rcode := dns.RcodeNameError subdomain := sanitizeDomainQuestion(q.Name) - atxt, err := DB.GetByDomain(subdomain) + atxt, err := DB.GetTXTForDomain(subdomain) if err != nil { log.WithFields(log.Fields{"error": err.Error()}).Debug("Error while trying to get record") return ra, dns.RcodeNameError, err } for _, v := range atxt { - if len(v.Value) > 0 { + if len(v) > 0 { r := new(dns.TXT) r.Hdr = dns.RR_Header{Name: q.Name, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1} - r.Txt = append(r.Txt, v.Value) + r.Txt = append(r.Txt, v) ra = append(ra, r) rcode = dns.RcodeSuccess } diff --git a/main.go b/main.go index c521d56..036818b 100644 --- a/main.go +++ b/main.go @@ -36,6 +36,8 @@ func main() { if err != nil { log.Errorf("Could not open database [%v]", err) os.Exit(1) + } else { + log.Info("Connected to database") } DB = newDB defer DB.Close() diff --git a/types.go b/types.go index 173ab52..d6b6054 100644 --- a/types.go +++ b/types.go @@ -79,7 +79,7 @@ type database interface { Init(string, string) error Register(cidrslice) (ACMETxt, error) GetByUsername(uuid.UUID) (ACMETxt, error) - GetByDomain(string) ([]ACMETxt, error) + GetTXTForDomain(string) ([]string, error) Update(ACMETxt) error GetBackend() *sql.DB SetBackend(*sql.DB)