From 4615826267f62ce5bb777ef3aa965f2102864b7d Mon Sep 17 00:00:00 2001 From: Joona Hoikkala Date: Sun, 27 Nov 2016 23:21:46 +0200 Subject: [PATCH] Made DB an interface --- api_test.go | 12 ++++++------ db.go | 28 +++++++++++++++++----------- main.go | 13 +++++-------- main_test.go | 8 +++++--- types.go | 22 +++++++++++++++++++++- util.go | 12 +++++------- util_test.go | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 107 insertions(+), 36 deletions(-) diff --git a/api_test.go b/api_test.go index ea7e9b0..b067f14 100644 --- a/api_test.go +++ b/api_test.go @@ -60,9 +60,9 @@ func TestApiRegister(t *testing.T) { func TestApiRegisterWithMockDB(t *testing.T) { e := setupIris(t, false, false) - oldDb := DB.DB + oldDb := DB.GetBackend() db, mock, _ := sqlmock.New() - DB.DB = db + DB.SetBackend(db) defer db.Close() mock.ExpectBegin() mock.ExpectPrepare("INSERT INTO records").WillReturnError(errors.New("error")) @@ -70,7 +70,7 @@ func TestApiRegisterWithMockDB(t *testing.T) { Status(iris.StatusInternalServerError). JSON().Object(). ContainsKey("error") - DB.DB = oldDb + DB.SetBackend(oldDb) } func TestApiUpdateWithoutCredentials(t *testing.T) { @@ -121,9 +121,9 @@ func TestApiUpdateWithCredentialsMockDB(t *testing.T) { updateJSON["txt"] = validTxtData e := setupIris(t, false, true) - oldDb := DB.DB + oldDb := DB.GetBackend() db, mock, _ := sqlmock.New() - DB.DB = db + DB.SetBackend(db) defer db.Close() mock.ExpectBegin() mock.ExpectPrepare("UPDATE records").WillReturnError(errors.New("error")) @@ -133,7 +133,7 @@ func TestApiUpdateWithCredentialsMockDB(t *testing.T) { Status(iris.StatusInternalServerError). JSON().Object(). ContainsKey("error") - DB.DB = oldDb + DB.SetBackend(oldDb) } func TestApiManyUpdateWithCredentials(t *testing.T) { diff --git a/db.go b/db.go index 99092df..85348d8 100644 --- a/db.go +++ b/db.go @@ -9,15 +9,9 @@ import ( "github.com/satori/go.uuid" "golang.org/x/crypto/bcrypt" "regexp" - "sync" "time" ) -type database struct { - sync.Mutex - DB *sql.DB -} - var recordsTable = ` CREATE TABLE IF NOT EXISTS records( Username TEXT UNIQUE NOT NULL PRIMARY KEY, @@ -37,7 +31,7 @@ func getSQLiteStmt(s string) string { return re.ReplaceAllString(s, "?") } -func (d *database) Init(engine string, connection string) error { +func (d *acmedb) Init(engine string, connection string) error { d.Lock() defer d.Unlock() db, err := sql.Open(engine, connection) @@ -53,7 +47,7 @@ func (d *database) Init(engine string, connection string) error { return nil } -func (d *database) Register() (ACMETxt, error) { +func (d *acmedb) Register() (ACMETxt, error) { d.Lock() defer d.Unlock() a, err := newACMETxt() @@ -85,7 +79,7 @@ func (d *database) Register() (ACMETxt, error) { return a, nil } -func (d *database) GetByUsername(u uuid.UUID) (ACMETxt, error) { +func (d *acmedb) GetByUsername(u uuid.UUID) (ACMETxt, error) { d.Lock() defer d.Unlock() var results []ACMETxt @@ -129,7 +123,7 @@ func (d *database) GetByUsername(u uuid.UUID) (ACMETxt, error) { return ACMETxt{}, errors.New("no user") } -func (d *database) GetByDomain(domain string) ([]ACMETxt, error) { +func (d *acmedb) GetByDomain(domain string) ([]ACMETxt, error) { d.Lock() defer d.Unlock() domain = sanitizeString(domain) @@ -165,7 +159,7 @@ func (d *database) GetByDomain(domain string) ([]ACMETxt, error) { return a, nil } -func (d *database) Update(a ACMETxt) error { +func (d *acmedb) Update(a ACMETxt) error { d.Lock() defer d.Unlock() // Data in a is already sanitized @@ -189,3 +183,15 @@ func (d *database) Update(a ACMETxt) error { } return nil } + +func (d *acmedb) Close() { + d.DB.Close() +} + +func (d *acmedb) GetBackend() *sql.DB { + return d.DB +} + +func (d *acmedb) SetBackend(backend *sql.DB) { + d.DB = backend +} diff --git a/main.go b/main.go index 288703b..ce895f8 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,6 @@ package main import ( - "fmt" log "github.com/Sirupsen/logrus" "os" ) @@ -17,11 +16,7 @@ var RR Records func main() { // Read global config - configTmp, err := readConfig("config.cfg") - if err != nil { - fmt.Printf("Got error %v\n", err) - os.Exit(1) - } + configTmp := readConfig("config.cfg") DNSConf = configTmp setupLogging(DNSConf.Logconfig.Format, DNSConf.Logconfig.Level) @@ -30,12 +25,14 @@ func main() { RR.Parse(DNSConf.General.StaticRecords) // Open database - err = DB.Init(DNSConf.Database.Engine, DNSConf.Database.Connection) + newDB := new(acmedb) + err := newDB.Init(DNSConf.Database.Engine, DNSConf.Database.Connection) if err != nil { log.Errorf("Could not open database [%v]", err) os.Exit(1) } - defer DB.DB.Close() + DB = newDB + defer DB.Close() // DNS server startDNS(DNSConf.General.Listen) diff --git a/main_test.go b/main_test.go index 0f98dab..ed5d827 100644 --- a/main_test.go +++ b/main_test.go @@ -16,22 +16,24 @@ func TestMain(m *testing.M) { RR.Parse(records) flag.Parse() + newDb := new(acmedb) if *postgres { DNSConf.Database.Engine = "postgres" - err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns") + err := newDb.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns") if err != nil { fmt.Println("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"") os.Exit(1) } } else { DNSConf.Database.Engine = "sqlite3" - _ = DB.Init("sqlite3", ":memory:") + _ = newDb.Init("sqlite3", ":memory:") } + DB = newDb server := startDNS("0.0.0.0:15353") exitval := m.Run() server.Shutdown() - DB.DB.Close() + DB.Close() os.Exit(exitval) } diff --git a/types.go b/types.go index e6cfb90..3ea3d38 100644 --- a/types.go +++ b/types.go @@ -1,8 +1,10 @@ package main import ( + "database/sql" "github.com/miekg/dns" "github.com/satori/go.uuid" + "sync" ) // Records is for static records @@ -38,7 +40,7 @@ type dbsettings struct { // API config type httpapi struct { - Domain string + Domain string `toml:"api_domain"` Port string TLS string TLSCertPrivkey string `toml:"tls_cert_privkey"` @@ -67,3 +69,21 @@ type ACMETxtPost struct { Subdomain string `json:"subdomain"` Value string `json:"txt"` } + +type acmedb struct { + sync.Mutex + DB *sql.DB +} + +type database interface { + Init(string, string) error + Register() (ACMETxt, error) + GetByUsername(uuid.UUID) (ACMETxt, error) + GetByDomain(string) ([]ACMETxt, error) + Update(ACMETxt) error + GetBackend() *sql.DB + SetBackend(*sql.DB) + Close() + Lock() + Unlock() +} diff --git a/util.go b/util.go index 69adca6..200719d 100644 --- a/util.go +++ b/util.go @@ -2,7 +2,6 @@ package main import ( "crypto/rand" - "errors" "fmt" "github.com/BurntSushi/toml" log "github.com/Sirupsen/logrus" @@ -16,12 +15,11 @@ import ( "strings" ) -func readConfig(fname string) (DNSConfig, error) { +func readConfig(fname string) DNSConfig { var conf DNSConfig - if _, err := toml.DecodeFile(fname, &conf); err != nil { - return DNSConfig{}, errors.New("Malformed configuration file") - } - return conf, nil + // Practically never errors + _, _ = toml.DecodeFile(fname, &conf) + return conf } func sanitizeString(s string) string { @@ -62,7 +60,7 @@ func newACMETxt() (ACMETxt, error) { } func setupLogging(format string, level string) { - if DNSConf.Logconfig.Format == "json" { + if format == "json" { log.SetFormatter(&log.JSONFormatter{}) } switch level { diff --git a/util_test.go b/util_test.go index 381bb86..3d8d9fd 100644 --- a/util_test.go +++ b/util_test.go @@ -2,6 +2,8 @@ package main import ( log "github.com/Sirupsen/logrus" + "io/ioutil" + "os" "testing" ) @@ -23,3 +25,49 @@ func TestSetupLogging(t *testing.T) { } } } + +func TestReadConfig(t *testing.T) { + for i, test := range []struct { + inFile []byte + output DNSConfig + }{ + { + []byte("[general]\nlisten = \":53\"\ndebug = true\n[api]\napi_domain = \"something.strange\""), + DNSConfig{ + General: general{ + Listen: ":53", + Debug: true, + }, + API: httpapi{ + Domain: "something.strange", + }, + }, + }, + + { + []byte("[\x00[[[[[[[[[de\nlisten =]"), + DNSConfig{}, + }, + } { + tmpfile, err := ioutil.TempFile("", "acmedns") + if err != nil { + t.Error("Could not create temporary file") + } + defer os.Remove(tmpfile.Name()) + + if _, err := tmpfile.Write(test.inFile); err != nil { + t.Error("Could not write to temporary file") + } + + if err := tmpfile.Close(); err != nil { + t.Error("Could not close temporary file") + } + ret := readConfig(tmpfile.Name()) + if ret.General.Listen != test.output.General.Listen { + t.Errorf("Test %d: Expected listen value %s, but got %s", i, test.output.General.Listen, ret.General.Listen) + } + if ret.API.Domain != test.output.API.Domain { + t.Errorf("Test %d: Expected HTTP API domain %s, but got %s", i, test.output.API.Domain, ret.API.Domain) + } + } +}