Made DB an interface

This commit is contained in:
Joona Hoikkala 2016-11-27 23:21:46 +02:00
parent e9f18c99d8
commit 4615826267
No known key found for this signature in database
GPG Key ID: C14AAE0F5ADCB854
7 changed files with 107 additions and 36 deletions

View File

@ -60,9 +60,9 @@ func TestApiRegister(t *testing.T) {
func TestApiRegisterWithMockDB(t *testing.T) { func TestApiRegisterWithMockDB(t *testing.T) {
e := setupIris(t, false, false) e := setupIris(t, false, false)
oldDb := DB.DB oldDb := DB.GetBackend()
db, mock, _ := sqlmock.New() db, mock, _ := sqlmock.New()
DB.DB = db DB.SetBackend(db)
defer db.Close() defer db.Close()
mock.ExpectBegin() mock.ExpectBegin()
mock.ExpectPrepare("INSERT INTO records").WillReturnError(errors.New("error")) mock.ExpectPrepare("INSERT INTO records").WillReturnError(errors.New("error"))
@ -70,7 +70,7 @@ func TestApiRegisterWithMockDB(t *testing.T) {
Status(iris.StatusInternalServerError). Status(iris.StatusInternalServerError).
JSON().Object(). JSON().Object().
ContainsKey("error") ContainsKey("error")
DB.DB = oldDb DB.SetBackend(oldDb)
} }
func TestApiUpdateWithoutCredentials(t *testing.T) { func TestApiUpdateWithoutCredentials(t *testing.T) {
@ -121,9 +121,9 @@ func TestApiUpdateWithCredentialsMockDB(t *testing.T) {
updateJSON["txt"] = validTxtData updateJSON["txt"] = validTxtData
e := setupIris(t, false, true) e := setupIris(t, false, true)
oldDb := DB.DB oldDb := DB.GetBackend()
db, mock, _ := sqlmock.New() db, mock, _ := sqlmock.New()
DB.DB = db DB.SetBackend(db)
defer db.Close() defer db.Close()
mock.ExpectBegin() mock.ExpectBegin()
mock.ExpectPrepare("UPDATE records").WillReturnError(errors.New("error")) mock.ExpectPrepare("UPDATE records").WillReturnError(errors.New("error"))
@ -133,7 +133,7 @@ func TestApiUpdateWithCredentialsMockDB(t *testing.T) {
Status(iris.StatusInternalServerError). Status(iris.StatusInternalServerError).
JSON().Object(). JSON().Object().
ContainsKey("error") ContainsKey("error")
DB.DB = oldDb DB.SetBackend(oldDb)
} }
func TestApiManyUpdateWithCredentials(t *testing.T) { func TestApiManyUpdateWithCredentials(t *testing.T) {

28
db.go
View File

@ -9,15 +9,9 @@ import (
"github.com/satori/go.uuid" "github.com/satori/go.uuid"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"regexp" "regexp"
"sync"
"time" "time"
) )
type database struct {
sync.Mutex
DB *sql.DB
}
var recordsTable = ` var recordsTable = `
CREATE TABLE IF NOT EXISTS records( CREATE TABLE IF NOT EXISTS records(
Username TEXT UNIQUE NOT NULL PRIMARY KEY, Username TEXT UNIQUE NOT NULL PRIMARY KEY,
@ -37,7 +31,7 @@ func getSQLiteStmt(s string) string {
return re.ReplaceAllString(s, "?") return re.ReplaceAllString(s, "?")
} }
func (d *database) Init(engine string, connection string) error { func (d *acmedb) Init(engine string, connection string) error {
d.Lock() d.Lock()
defer d.Unlock() defer d.Unlock()
db, err := sql.Open(engine, connection) db, err := sql.Open(engine, connection)
@ -53,7 +47,7 @@ func (d *database) Init(engine string, connection string) error {
return nil return nil
} }
func (d *database) Register() (ACMETxt, error) { func (d *acmedb) Register() (ACMETxt, error) {
d.Lock() d.Lock()
defer d.Unlock() defer d.Unlock()
a, err := newACMETxt() a, err := newACMETxt()
@ -85,7 +79,7 @@ func (d *database) Register() (ACMETxt, error) {
return a, nil return a, nil
} }
func (d *database) GetByUsername(u uuid.UUID) (ACMETxt, error) { func (d *acmedb) GetByUsername(u uuid.UUID) (ACMETxt, error) {
d.Lock() d.Lock()
defer d.Unlock() defer d.Unlock()
var results []ACMETxt var results []ACMETxt
@ -129,7 +123,7 @@ func (d *database) GetByUsername(u uuid.UUID) (ACMETxt, error) {
return ACMETxt{}, errors.New("no user") return ACMETxt{}, errors.New("no user")
} }
func (d *database) GetByDomain(domain string) ([]ACMETxt, error) { func (d *acmedb) GetByDomain(domain string) ([]ACMETxt, error) {
d.Lock() d.Lock()
defer d.Unlock() defer d.Unlock()
domain = sanitizeString(domain) domain = sanitizeString(domain)
@ -165,7 +159,7 @@ func (d *database) GetByDomain(domain string) ([]ACMETxt, error) {
return a, nil return a, nil
} }
func (d *database) Update(a ACMETxt) error { func (d *acmedb) Update(a ACMETxt) error {
d.Lock() d.Lock()
defer d.Unlock() defer d.Unlock()
// Data in a is already sanitized // Data in a is already sanitized
@ -189,3 +183,15 @@ func (d *database) Update(a ACMETxt) error {
} }
return nil return nil
} }
func (d *acmedb) Close() {
d.DB.Close()
}
func (d *acmedb) GetBackend() *sql.DB {
return d.DB
}
func (d *acmedb) SetBackend(backend *sql.DB) {
d.DB = backend
}

13
main.go
View File

@ -1,7 +1,6 @@
package main package main
import ( import (
"fmt"
log "github.com/Sirupsen/logrus" log "github.com/Sirupsen/logrus"
"os" "os"
) )
@ -17,11 +16,7 @@ var RR Records
func main() { func main() {
// Read global config // Read global config
configTmp, err := readConfig("config.cfg") configTmp := readConfig("config.cfg")
if err != nil {
fmt.Printf("Got error %v\n", err)
os.Exit(1)
}
DNSConf = configTmp DNSConf = configTmp
setupLogging(DNSConf.Logconfig.Format, DNSConf.Logconfig.Level) setupLogging(DNSConf.Logconfig.Format, DNSConf.Logconfig.Level)
@ -30,12 +25,14 @@ func main() {
RR.Parse(DNSConf.General.StaticRecords) RR.Parse(DNSConf.General.StaticRecords)
// Open database // Open database
err = DB.Init(DNSConf.Database.Engine, DNSConf.Database.Connection) newDB := new(acmedb)
err := newDB.Init(DNSConf.Database.Engine, DNSConf.Database.Connection)
if err != nil { if err != nil {
log.Errorf("Could not open database [%v]", err) log.Errorf("Could not open database [%v]", err)
os.Exit(1) os.Exit(1)
} }
defer DB.DB.Close() DB = newDB
defer DB.Close()
// DNS server // DNS server
startDNS(DNSConf.General.Listen) startDNS(DNSConf.General.Listen)

View File

@ -16,22 +16,24 @@ func TestMain(m *testing.M) {
RR.Parse(records) RR.Parse(records)
flag.Parse() flag.Parse()
newDb := new(acmedb)
if *postgres { if *postgres {
DNSConf.Database.Engine = "postgres" DNSConf.Database.Engine = "postgres"
err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns") err := newDb.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns")
if err != nil { if err != nil {
fmt.Println("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"") fmt.Println("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"")
os.Exit(1) os.Exit(1)
} }
} else { } else {
DNSConf.Database.Engine = "sqlite3" DNSConf.Database.Engine = "sqlite3"
_ = DB.Init("sqlite3", ":memory:") _ = newDb.Init("sqlite3", ":memory:")
} }
DB = newDb
server := startDNS("0.0.0.0:15353") server := startDNS("0.0.0.0:15353")
exitval := m.Run() exitval := m.Run()
server.Shutdown() server.Shutdown()
DB.DB.Close() DB.Close()
os.Exit(exitval) os.Exit(exitval)
} }

View File

@ -1,8 +1,10 @@
package main package main
import ( import (
"database/sql"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/satori/go.uuid" "github.com/satori/go.uuid"
"sync"
) )
// Records is for static records // Records is for static records
@ -38,7 +40,7 @@ type dbsettings struct {
// API config // API config
type httpapi struct { type httpapi struct {
Domain string Domain string `toml:"api_domain"`
Port string Port string
TLS string TLS string
TLSCertPrivkey string `toml:"tls_cert_privkey"` TLSCertPrivkey string `toml:"tls_cert_privkey"`
@ -67,3 +69,21 @@ type ACMETxtPost struct {
Subdomain string `json:"subdomain"` Subdomain string `json:"subdomain"`
Value string `json:"txt"` Value string `json:"txt"`
} }
type acmedb struct {
sync.Mutex
DB *sql.DB
}
type database interface {
Init(string, string) error
Register() (ACMETxt, error)
GetByUsername(uuid.UUID) (ACMETxt, error)
GetByDomain(string) ([]ACMETxt, error)
Update(ACMETxt) error
GetBackend() *sql.DB
SetBackend(*sql.DB)
Close()
Lock()
Unlock()
}

12
util.go
View File

@ -2,7 +2,6 @@ package main
import ( import (
"crypto/rand" "crypto/rand"
"errors"
"fmt" "fmt"
"github.com/BurntSushi/toml" "github.com/BurntSushi/toml"
log "github.com/Sirupsen/logrus" log "github.com/Sirupsen/logrus"
@ -16,12 +15,11 @@ import (
"strings" "strings"
) )
func readConfig(fname string) (DNSConfig, error) { func readConfig(fname string) DNSConfig {
var conf DNSConfig var conf DNSConfig
if _, err := toml.DecodeFile(fname, &conf); err != nil { // Practically never errors
return DNSConfig{}, errors.New("Malformed configuration file") _, _ = toml.DecodeFile(fname, &conf)
} return conf
return conf, nil
} }
func sanitizeString(s string) string { func sanitizeString(s string) string {
@ -62,7 +60,7 @@ func newACMETxt() (ACMETxt, error) {
} }
func setupLogging(format string, level string) { func setupLogging(format string, level string) {
if DNSConf.Logconfig.Format == "json" { if format == "json" {
log.SetFormatter(&log.JSONFormatter{}) log.SetFormatter(&log.JSONFormatter{})
} }
switch level { switch level {

View File

@ -2,6 +2,8 @@ package main
import ( import (
log "github.com/Sirupsen/logrus" log "github.com/Sirupsen/logrus"
"io/ioutil"
"os"
"testing" "testing"
) )
@ -23,3 +25,49 @@ func TestSetupLogging(t *testing.T) {
} }
} }
} }
func TestReadConfig(t *testing.T) {
for i, test := range []struct {
inFile []byte
output DNSConfig
}{
{
[]byte("[general]\nlisten = \":53\"\ndebug = true\n[api]\napi_domain = \"something.strange\""),
DNSConfig{
General: general{
Listen: ":53",
Debug: true,
},
API: httpapi{
Domain: "something.strange",
},
},
},
{
[]byte("[\x00[[[[[[[[[de\nlisten =]"),
DNSConfig{},
},
} {
tmpfile, err := ioutil.TempFile("", "acmedns")
if err != nil {
t.Error("Could not create temporary file")
}
defer os.Remove(tmpfile.Name())
if _, err := tmpfile.Write(test.inFile); err != nil {
t.Error("Could not write to temporary file")
}
if err := tmpfile.Close(); err != nil {
t.Error("Could not close temporary file")
}
ret := readConfig(tmpfile.Name())
if ret.General.Listen != test.output.General.Listen {
t.Errorf("Test %d: Expected listen value %s, but got %s", i, test.output.General.Listen, ret.General.Listen)
}
if ret.API.Domain != test.output.API.Domain {
t.Errorf("Test %d: Expected HTTP API domain %s, but got %s", i, test.output.API.Domain, ret.API.Domain)
}
}
}