mirror of
https://github.com/joohoi/acme-dns.git
synced 2025-02-21 20:18:12 +07:00
Made DB an interface
This commit is contained in:
parent
e9f18c99d8
commit
4615826267
12
api_test.go
12
api_test.go
@ -60,9 +60,9 @@ func TestApiRegister(t *testing.T) {
|
||||
|
||||
func TestApiRegisterWithMockDB(t *testing.T) {
|
||||
e := setupIris(t, false, false)
|
||||
oldDb := DB.DB
|
||||
oldDb := DB.GetBackend()
|
||||
db, mock, _ := sqlmock.New()
|
||||
DB.DB = db
|
||||
DB.SetBackend(db)
|
||||
defer db.Close()
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectPrepare("INSERT INTO records").WillReturnError(errors.New("error"))
|
||||
@ -70,7 +70,7 @@ func TestApiRegisterWithMockDB(t *testing.T) {
|
||||
Status(iris.StatusInternalServerError).
|
||||
JSON().Object().
|
||||
ContainsKey("error")
|
||||
DB.DB = oldDb
|
||||
DB.SetBackend(oldDb)
|
||||
}
|
||||
|
||||
func TestApiUpdateWithoutCredentials(t *testing.T) {
|
||||
@ -121,9 +121,9 @@ func TestApiUpdateWithCredentialsMockDB(t *testing.T) {
|
||||
updateJSON["txt"] = validTxtData
|
||||
|
||||
e := setupIris(t, false, true)
|
||||
oldDb := DB.DB
|
||||
oldDb := DB.GetBackend()
|
||||
db, mock, _ := sqlmock.New()
|
||||
DB.DB = db
|
||||
DB.SetBackend(db)
|
||||
defer db.Close()
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectPrepare("UPDATE records").WillReturnError(errors.New("error"))
|
||||
@ -133,7 +133,7 @@ func TestApiUpdateWithCredentialsMockDB(t *testing.T) {
|
||||
Status(iris.StatusInternalServerError).
|
||||
JSON().Object().
|
||||
ContainsKey("error")
|
||||
DB.DB = oldDb
|
||||
DB.SetBackend(oldDb)
|
||||
}
|
||||
|
||||
func TestApiManyUpdateWithCredentials(t *testing.T) {
|
||||
|
28
db.go
28
db.go
@ -9,15 +9,9 @@ import (
|
||||
"github.com/satori/go.uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"regexp"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type database struct {
|
||||
sync.Mutex
|
||||
DB *sql.DB
|
||||
}
|
||||
|
||||
var recordsTable = `
|
||||
CREATE TABLE IF NOT EXISTS records(
|
||||
Username TEXT UNIQUE NOT NULL PRIMARY KEY,
|
||||
@ -37,7 +31,7 @@ func getSQLiteStmt(s string) string {
|
||||
return re.ReplaceAllString(s, "?")
|
||||
}
|
||||
|
||||
func (d *database) Init(engine string, connection string) error {
|
||||
func (d *acmedb) Init(engine string, connection string) error {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
db, err := sql.Open(engine, connection)
|
||||
@ -53,7 +47,7 @@ func (d *database) Init(engine string, connection string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *database) Register() (ACMETxt, error) {
|
||||
func (d *acmedb) Register() (ACMETxt, error) {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
a, err := newACMETxt()
|
||||
@ -85,7 +79,7 @@ func (d *database) Register() (ACMETxt, error) {
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (d *database) GetByUsername(u uuid.UUID) (ACMETxt, error) {
|
||||
func (d *acmedb) GetByUsername(u uuid.UUID) (ACMETxt, error) {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
var results []ACMETxt
|
||||
@ -129,7 +123,7 @@ func (d *database) GetByUsername(u uuid.UUID) (ACMETxt, error) {
|
||||
return ACMETxt{}, errors.New("no user")
|
||||
}
|
||||
|
||||
func (d *database) GetByDomain(domain string) ([]ACMETxt, error) {
|
||||
func (d *acmedb) GetByDomain(domain string) ([]ACMETxt, error) {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
domain = sanitizeString(domain)
|
||||
@ -165,7 +159,7 @@ func (d *database) GetByDomain(domain string) ([]ACMETxt, error) {
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (d *database) Update(a ACMETxt) error {
|
||||
func (d *acmedb) Update(a ACMETxt) error {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
// Data in a is already sanitized
|
||||
@ -189,3 +183,15 @@ func (d *database) Update(a ACMETxt) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *acmedb) Close() {
|
||||
d.DB.Close()
|
||||
}
|
||||
|
||||
func (d *acmedb) GetBackend() *sql.DB {
|
||||
return d.DB
|
||||
}
|
||||
|
||||
func (d *acmedb) SetBackend(backend *sql.DB) {
|
||||
d.DB = backend
|
||||
}
|
||||
|
13
main.go
13
main.go
@ -1,7 +1,6 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
log "github.com/Sirupsen/logrus"
|
||||
"os"
|
||||
)
|
||||
@ -17,11 +16,7 @@ var RR Records
|
||||
|
||||
func main() {
|
||||
// Read global config
|
||||
configTmp, err := readConfig("config.cfg")
|
||||
if err != nil {
|
||||
fmt.Printf("Got error %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
configTmp := readConfig("config.cfg")
|
||||
DNSConf = configTmp
|
||||
|
||||
setupLogging(DNSConf.Logconfig.Format, DNSConf.Logconfig.Level)
|
||||
@ -30,12 +25,14 @@ func main() {
|
||||
RR.Parse(DNSConf.General.StaticRecords)
|
||||
|
||||
// Open database
|
||||
err = DB.Init(DNSConf.Database.Engine, DNSConf.Database.Connection)
|
||||
newDB := new(acmedb)
|
||||
err := newDB.Init(DNSConf.Database.Engine, DNSConf.Database.Connection)
|
||||
if err != nil {
|
||||
log.Errorf("Could not open database [%v]", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer DB.DB.Close()
|
||||
DB = newDB
|
||||
defer DB.Close()
|
||||
|
||||
// DNS server
|
||||
startDNS(DNSConf.General.Listen)
|
||||
|
@ -16,22 +16,24 @@ func TestMain(m *testing.M) {
|
||||
RR.Parse(records)
|
||||
flag.Parse()
|
||||
|
||||
newDb := new(acmedb)
|
||||
if *postgres {
|
||||
DNSConf.Database.Engine = "postgres"
|
||||
err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns")
|
||||
err := newDb.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns")
|
||||
if err != nil {
|
||||
fmt.Println("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"")
|
||||
os.Exit(1)
|
||||
}
|
||||
} else {
|
||||
DNSConf.Database.Engine = "sqlite3"
|
||||
_ = DB.Init("sqlite3", ":memory:")
|
||||
_ = newDb.Init("sqlite3", ":memory:")
|
||||
}
|
||||
DB = newDb
|
||||
|
||||
server := startDNS("0.0.0.0:15353")
|
||||
exitval := m.Run()
|
||||
server.Shutdown()
|
||||
DB.DB.Close()
|
||||
DB.Close()
|
||||
os.Exit(exitval)
|
||||
}
|
||||
|
||||
|
22
types.go
22
types.go
@ -1,8 +1,10 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/satori/go.uuid"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Records is for static records
|
||||
@ -38,7 +40,7 @@ type dbsettings struct {
|
||||
|
||||
// API config
|
||||
type httpapi struct {
|
||||
Domain string
|
||||
Domain string `toml:"api_domain"`
|
||||
Port string
|
||||
TLS string
|
||||
TLSCertPrivkey string `toml:"tls_cert_privkey"`
|
||||
@ -67,3 +69,21 @@ type ACMETxtPost struct {
|
||||
Subdomain string `json:"subdomain"`
|
||||
Value string `json:"txt"`
|
||||
}
|
||||
|
||||
type acmedb struct {
|
||||
sync.Mutex
|
||||
DB *sql.DB
|
||||
}
|
||||
|
||||
type database interface {
|
||||
Init(string, string) error
|
||||
Register() (ACMETxt, error)
|
||||
GetByUsername(uuid.UUID) (ACMETxt, error)
|
||||
GetByDomain(string) ([]ACMETxt, error)
|
||||
Update(ACMETxt) error
|
||||
GetBackend() *sql.DB
|
||||
SetBackend(*sql.DB)
|
||||
Close()
|
||||
Lock()
|
||||
Unlock()
|
||||
}
|
||||
|
12
util.go
12
util.go
@ -2,7 +2,6 @@ package main
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/BurntSushi/toml"
|
||||
log "github.com/Sirupsen/logrus"
|
||||
@ -16,12 +15,11 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func readConfig(fname string) (DNSConfig, error) {
|
||||
func readConfig(fname string) DNSConfig {
|
||||
var conf DNSConfig
|
||||
if _, err := toml.DecodeFile(fname, &conf); err != nil {
|
||||
return DNSConfig{}, errors.New("Malformed configuration file")
|
||||
}
|
||||
return conf, nil
|
||||
// Practically never errors
|
||||
_, _ = toml.DecodeFile(fname, &conf)
|
||||
return conf
|
||||
}
|
||||
|
||||
func sanitizeString(s string) string {
|
||||
@ -62,7 +60,7 @@ func newACMETxt() (ACMETxt, error) {
|
||||
}
|
||||
|
||||
func setupLogging(format string, level string) {
|
||||
if DNSConf.Logconfig.Format == "json" {
|
||||
if format == "json" {
|
||||
log.SetFormatter(&log.JSONFormatter{})
|
||||
}
|
||||
switch level {
|
||||
|
48
util_test.go
48
util_test.go
@ -2,6 +2,8 @@ package main
|
||||
|
||||
import (
|
||||
log "github.com/Sirupsen/logrus"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@ -23,3 +25,49 @@ func TestSetupLogging(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadConfig(t *testing.T) {
|
||||
for i, test := range []struct {
|
||||
inFile []byte
|
||||
output DNSConfig
|
||||
}{
|
||||
{
|
||||
[]byte("[general]\nlisten = \":53\"\ndebug = true\n[api]\napi_domain = \"something.strange\""),
|
||||
DNSConfig{
|
||||
General: general{
|
||||
Listen: ":53",
|
||||
Debug: true,
|
||||
},
|
||||
API: httpapi{
|
||||
Domain: "something.strange",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
[]byte("[\x00[[[[[[[[[de\nlisten =]"),
|
||||
DNSConfig{},
|
||||
},
|
||||
} {
|
||||
tmpfile, err := ioutil.TempFile("", "acmedns")
|
||||
if err != nil {
|
||||
t.Error("Could not create temporary file")
|
||||
}
|
||||
defer os.Remove(tmpfile.Name())
|
||||
|
||||
if _, err := tmpfile.Write(test.inFile); err != nil {
|
||||
t.Error("Could not write to temporary file")
|
||||
}
|
||||
|
||||
if err := tmpfile.Close(); err != nil {
|
||||
t.Error("Could not close temporary file")
|
||||
}
|
||||
ret := readConfig(tmpfile.Name())
|
||||
if ret.General.Listen != test.output.General.Listen {
|
||||
t.Errorf("Test %d: Expected listen value %s, but got %s", i, test.output.General.Listen, ret.General.Listen)
|
||||
}
|
||||
if ret.API.Domain != test.output.API.Domain {
|
||||
t.Errorf("Test %d: Expected HTTP API domain %s, but got %s", i, test.output.API.Domain, ret.API.Domain)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user