Made DB an interface

This commit is contained in:
Joona Hoikkala 2016-11-27 23:21:46 +02:00
parent e9f18c99d8
commit 4615826267
No known key found for this signature in database
GPG Key ID: C14AAE0F5ADCB854
7 changed files with 107 additions and 36 deletions

View File

@ -60,9 +60,9 @@ func TestApiRegister(t *testing.T) {
func TestApiRegisterWithMockDB(t *testing.T) {
e := setupIris(t, false, false)
oldDb := DB.DB
oldDb := DB.GetBackend()
db, mock, _ := sqlmock.New()
DB.DB = db
DB.SetBackend(db)
defer db.Close()
mock.ExpectBegin()
mock.ExpectPrepare("INSERT INTO records").WillReturnError(errors.New("error"))
@ -70,7 +70,7 @@ func TestApiRegisterWithMockDB(t *testing.T) {
Status(iris.StatusInternalServerError).
JSON().Object().
ContainsKey("error")
DB.DB = oldDb
DB.SetBackend(oldDb)
}
func TestApiUpdateWithoutCredentials(t *testing.T) {
@ -121,9 +121,9 @@ func TestApiUpdateWithCredentialsMockDB(t *testing.T) {
updateJSON["txt"] = validTxtData
e := setupIris(t, false, true)
oldDb := DB.DB
oldDb := DB.GetBackend()
db, mock, _ := sqlmock.New()
DB.DB = db
DB.SetBackend(db)
defer db.Close()
mock.ExpectBegin()
mock.ExpectPrepare("UPDATE records").WillReturnError(errors.New("error"))
@ -133,7 +133,7 @@ func TestApiUpdateWithCredentialsMockDB(t *testing.T) {
Status(iris.StatusInternalServerError).
JSON().Object().
ContainsKey("error")
DB.DB = oldDb
DB.SetBackend(oldDb)
}
func TestApiManyUpdateWithCredentials(t *testing.T) {

28
db.go
View File

@ -9,15 +9,9 @@ import (
"github.com/satori/go.uuid"
"golang.org/x/crypto/bcrypt"
"regexp"
"sync"
"time"
)
type database struct {
sync.Mutex
DB *sql.DB
}
var recordsTable = `
CREATE TABLE IF NOT EXISTS records(
Username TEXT UNIQUE NOT NULL PRIMARY KEY,
@ -37,7 +31,7 @@ func getSQLiteStmt(s string) string {
return re.ReplaceAllString(s, "?")
}
func (d *database) Init(engine string, connection string) error {
func (d *acmedb) Init(engine string, connection string) error {
d.Lock()
defer d.Unlock()
db, err := sql.Open(engine, connection)
@ -53,7 +47,7 @@ func (d *database) Init(engine string, connection string) error {
return nil
}
func (d *database) Register() (ACMETxt, error) {
func (d *acmedb) Register() (ACMETxt, error) {
d.Lock()
defer d.Unlock()
a, err := newACMETxt()
@ -85,7 +79,7 @@ func (d *database) Register() (ACMETxt, error) {
return a, nil
}
func (d *database) GetByUsername(u uuid.UUID) (ACMETxt, error) {
func (d *acmedb) GetByUsername(u uuid.UUID) (ACMETxt, error) {
d.Lock()
defer d.Unlock()
var results []ACMETxt
@ -129,7 +123,7 @@ func (d *database) GetByUsername(u uuid.UUID) (ACMETxt, error) {
return ACMETxt{}, errors.New("no user")
}
func (d *database) GetByDomain(domain string) ([]ACMETxt, error) {
func (d *acmedb) GetByDomain(domain string) ([]ACMETxt, error) {
d.Lock()
defer d.Unlock()
domain = sanitizeString(domain)
@ -165,7 +159,7 @@ func (d *database) GetByDomain(domain string) ([]ACMETxt, error) {
return a, nil
}
func (d *database) Update(a ACMETxt) error {
func (d *acmedb) Update(a ACMETxt) error {
d.Lock()
defer d.Unlock()
// Data in a is already sanitized
@ -189,3 +183,15 @@ func (d *database) Update(a ACMETxt) error {
}
return nil
}
func (d *acmedb) Close() {
d.DB.Close()
}
func (d *acmedb) GetBackend() *sql.DB {
return d.DB
}
func (d *acmedb) SetBackend(backend *sql.DB) {
d.DB = backend
}

13
main.go
View File

@ -1,7 +1,6 @@
package main
import (
"fmt"
log "github.com/Sirupsen/logrus"
"os"
)
@ -17,11 +16,7 @@ var RR Records
func main() {
// Read global config
configTmp, err := readConfig("config.cfg")
if err != nil {
fmt.Printf("Got error %v\n", err)
os.Exit(1)
}
configTmp := readConfig("config.cfg")
DNSConf = configTmp
setupLogging(DNSConf.Logconfig.Format, DNSConf.Logconfig.Level)
@ -30,12 +25,14 @@ func main() {
RR.Parse(DNSConf.General.StaticRecords)
// Open database
err = DB.Init(DNSConf.Database.Engine, DNSConf.Database.Connection)
newDB := new(acmedb)
err := newDB.Init(DNSConf.Database.Engine, DNSConf.Database.Connection)
if err != nil {
log.Errorf("Could not open database [%v]", err)
os.Exit(1)
}
defer DB.DB.Close()
DB = newDB
defer DB.Close()
// DNS server
startDNS(DNSConf.General.Listen)

View File

@ -16,22 +16,24 @@ func TestMain(m *testing.M) {
RR.Parse(records)
flag.Parse()
newDb := new(acmedb)
if *postgres {
DNSConf.Database.Engine = "postgres"
err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns")
err := newDb.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns")
if err != nil {
fmt.Println("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"")
os.Exit(1)
}
} else {
DNSConf.Database.Engine = "sqlite3"
_ = DB.Init("sqlite3", ":memory:")
_ = newDb.Init("sqlite3", ":memory:")
}
DB = newDb
server := startDNS("0.0.0.0:15353")
exitval := m.Run()
server.Shutdown()
DB.DB.Close()
DB.Close()
os.Exit(exitval)
}

View File

@ -1,8 +1,10 @@
package main
import (
"database/sql"
"github.com/miekg/dns"
"github.com/satori/go.uuid"
"sync"
)
// Records is for static records
@ -38,7 +40,7 @@ type dbsettings struct {
// API config
type httpapi struct {
Domain string
Domain string `toml:"api_domain"`
Port string
TLS string
TLSCertPrivkey string `toml:"tls_cert_privkey"`
@ -67,3 +69,21 @@ type ACMETxtPost struct {
Subdomain string `json:"subdomain"`
Value string `json:"txt"`
}
type acmedb struct {
sync.Mutex
DB *sql.DB
}
type database interface {
Init(string, string) error
Register() (ACMETxt, error)
GetByUsername(uuid.UUID) (ACMETxt, error)
GetByDomain(string) ([]ACMETxt, error)
Update(ACMETxt) error
GetBackend() *sql.DB
SetBackend(*sql.DB)
Close()
Lock()
Unlock()
}

12
util.go
View File

@ -2,7 +2,6 @@ package main
import (
"crypto/rand"
"errors"
"fmt"
"github.com/BurntSushi/toml"
log "github.com/Sirupsen/logrus"
@ -16,12 +15,11 @@ import (
"strings"
)
func readConfig(fname string) (DNSConfig, error) {
func readConfig(fname string) DNSConfig {
var conf DNSConfig
if _, err := toml.DecodeFile(fname, &conf); err != nil {
return DNSConfig{}, errors.New("Malformed configuration file")
}
return conf, nil
// Practically never errors
_, _ = toml.DecodeFile(fname, &conf)
return conf
}
func sanitizeString(s string) string {
@ -62,7 +60,7 @@ func newACMETxt() (ACMETxt, error) {
}
func setupLogging(format string, level string) {
if DNSConf.Logconfig.Format == "json" {
if format == "json" {
log.SetFormatter(&log.JSONFormatter{})
}
switch level {

View File

@ -2,6 +2,8 @@ package main
import (
log "github.com/Sirupsen/logrus"
"io/ioutil"
"os"
"testing"
)
@ -23,3 +25,49 @@ func TestSetupLogging(t *testing.T) {
}
}
}
func TestReadConfig(t *testing.T) {
for i, test := range []struct {
inFile []byte
output DNSConfig
}{
{
[]byte("[general]\nlisten = \":53\"\ndebug = true\n[api]\napi_domain = \"something.strange\""),
DNSConfig{
General: general{
Listen: ":53",
Debug: true,
},
API: httpapi{
Domain: "something.strange",
},
},
},
{
[]byte("[\x00[[[[[[[[[de\nlisten =]"),
DNSConfig{},
},
} {
tmpfile, err := ioutil.TempFile("", "acmedns")
if err != nil {
t.Error("Could not create temporary file")
}
defer os.Remove(tmpfile.Name())
if _, err := tmpfile.Write(test.inFile); err != nil {
t.Error("Could not write to temporary file")
}
if err := tmpfile.Close(); err != nil {
t.Error("Could not close temporary file")
}
ret := readConfig(tmpfile.Name())
if ret.General.Listen != test.output.General.Listen {
t.Errorf("Test %d: Expected listen value %s, but got %s", i, test.output.General.Listen, ret.General.Listen)
}
if ret.API.Domain != test.output.API.Domain {
t.Errorf("Test %d: Expected HTTP API domain %s, but got %s", i, test.output.API.Domain, ret.API.Domain)
}
}
}