mirror of
https://github.com/joohoi/acme-dns.git
synced 2025-02-22 12:38:10 +07:00
Made DB an interface
This commit is contained in:
parent
e9f18c99d8
commit
4615826267
12
api_test.go
12
api_test.go
@ -60,9 +60,9 @@ func TestApiRegister(t *testing.T) {
|
|||||||
|
|
||||||
func TestApiRegisterWithMockDB(t *testing.T) {
|
func TestApiRegisterWithMockDB(t *testing.T) {
|
||||||
e := setupIris(t, false, false)
|
e := setupIris(t, false, false)
|
||||||
oldDb := DB.DB
|
oldDb := DB.GetBackend()
|
||||||
db, mock, _ := sqlmock.New()
|
db, mock, _ := sqlmock.New()
|
||||||
DB.DB = db
|
DB.SetBackend(db)
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
mock.ExpectBegin()
|
mock.ExpectBegin()
|
||||||
mock.ExpectPrepare("INSERT INTO records").WillReturnError(errors.New("error"))
|
mock.ExpectPrepare("INSERT INTO records").WillReturnError(errors.New("error"))
|
||||||
@ -70,7 +70,7 @@ func TestApiRegisterWithMockDB(t *testing.T) {
|
|||||||
Status(iris.StatusInternalServerError).
|
Status(iris.StatusInternalServerError).
|
||||||
JSON().Object().
|
JSON().Object().
|
||||||
ContainsKey("error")
|
ContainsKey("error")
|
||||||
DB.DB = oldDb
|
DB.SetBackend(oldDb)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestApiUpdateWithoutCredentials(t *testing.T) {
|
func TestApiUpdateWithoutCredentials(t *testing.T) {
|
||||||
@ -121,9 +121,9 @@ func TestApiUpdateWithCredentialsMockDB(t *testing.T) {
|
|||||||
updateJSON["txt"] = validTxtData
|
updateJSON["txt"] = validTxtData
|
||||||
|
|
||||||
e := setupIris(t, false, true)
|
e := setupIris(t, false, true)
|
||||||
oldDb := DB.DB
|
oldDb := DB.GetBackend()
|
||||||
db, mock, _ := sqlmock.New()
|
db, mock, _ := sqlmock.New()
|
||||||
DB.DB = db
|
DB.SetBackend(db)
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
mock.ExpectBegin()
|
mock.ExpectBegin()
|
||||||
mock.ExpectPrepare("UPDATE records").WillReturnError(errors.New("error"))
|
mock.ExpectPrepare("UPDATE records").WillReturnError(errors.New("error"))
|
||||||
@ -133,7 +133,7 @@ func TestApiUpdateWithCredentialsMockDB(t *testing.T) {
|
|||||||
Status(iris.StatusInternalServerError).
|
Status(iris.StatusInternalServerError).
|
||||||
JSON().Object().
|
JSON().Object().
|
||||||
ContainsKey("error")
|
ContainsKey("error")
|
||||||
DB.DB = oldDb
|
DB.SetBackend(oldDb)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestApiManyUpdateWithCredentials(t *testing.T) {
|
func TestApiManyUpdateWithCredentials(t *testing.T) {
|
||||||
|
28
db.go
28
db.go
@ -9,15 +9,9 @@ import (
|
|||||||
"github.com/satori/go.uuid"
|
"github.com/satori/go.uuid"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"regexp"
|
"regexp"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type database struct {
|
|
||||||
sync.Mutex
|
|
||||||
DB *sql.DB
|
|
||||||
}
|
|
||||||
|
|
||||||
var recordsTable = `
|
var recordsTable = `
|
||||||
CREATE TABLE IF NOT EXISTS records(
|
CREATE TABLE IF NOT EXISTS records(
|
||||||
Username TEXT UNIQUE NOT NULL PRIMARY KEY,
|
Username TEXT UNIQUE NOT NULL PRIMARY KEY,
|
||||||
@ -37,7 +31,7 @@ func getSQLiteStmt(s string) string {
|
|||||||
return re.ReplaceAllString(s, "?")
|
return re.ReplaceAllString(s, "?")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *database) Init(engine string, connection string) error {
|
func (d *acmedb) Init(engine string, connection string) error {
|
||||||
d.Lock()
|
d.Lock()
|
||||||
defer d.Unlock()
|
defer d.Unlock()
|
||||||
db, err := sql.Open(engine, connection)
|
db, err := sql.Open(engine, connection)
|
||||||
@ -53,7 +47,7 @@ func (d *database) Init(engine string, connection string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *database) Register() (ACMETxt, error) {
|
func (d *acmedb) Register() (ACMETxt, error) {
|
||||||
d.Lock()
|
d.Lock()
|
||||||
defer d.Unlock()
|
defer d.Unlock()
|
||||||
a, err := newACMETxt()
|
a, err := newACMETxt()
|
||||||
@ -85,7 +79,7 @@ func (d *database) Register() (ACMETxt, error) {
|
|||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *database) GetByUsername(u uuid.UUID) (ACMETxt, error) {
|
func (d *acmedb) GetByUsername(u uuid.UUID) (ACMETxt, error) {
|
||||||
d.Lock()
|
d.Lock()
|
||||||
defer d.Unlock()
|
defer d.Unlock()
|
||||||
var results []ACMETxt
|
var results []ACMETxt
|
||||||
@ -129,7 +123,7 @@ func (d *database) GetByUsername(u uuid.UUID) (ACMETxt, error) {
|
|||||||
return ACMETxt{}, errors.New("no user")
|
return ACMETxt{}, errors.New("no user")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *database) GetByDomain(domain string) ([]ACMETxt, error) {
|
func (d *acmedb) GetByDomain(domain string) ([]ACMETxt, error) {
|
||||||
d.Lock()
|
d.Lock()
|
||||||
defer d.Unlock()
|
defer d.Unlock()
|
||||||
domain = sanitizeString(domain)
|
domain = sanitizeString(domain)
|
||||||
@ -165,7 +159,7 @@ func (d *database) GetByDomain(domain string) ([]ACMETxt, error) {
|
|||||||
return a, nil
|
return a, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *database) Update(a ACMETxt) error {
|
func (d *acmedb) Update(a ACMETxt) error {
|
||||||
d.Lock()
|
d.Lock()
|
||||||
defer d.Unlock()
|
defer d.Unlock()
|
||||||
// Data in a is already sanitized
|
// Data in a is already sanitized
|
||||||
@ -189,3 +183,15 @@ func (d *database) Update(a ACMETxt) error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *acmedb) Close() {
|
||||||
|
d.DB.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *acmedb) GetBackend() *sql.DB {
|
||||||
|
return d.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *acmedb) SetBackend(backend *sql.DB) {
|
||||||
|
d.DB = backend
|
||||||
|
}
|
||||||
|
13
main.go
13
main.go
@ -1,7 +1,6 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
log "github.com/Sirupsen/logrus"
|
log "github.com/Sirupsen/logrus"
|
||||||
"os"
|
"os"
|
||||||
)
|
)
|
||||||
@ -17,11 +16,7 @@ var RR Records
|
|||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// Read global config
|
// Read global config
|
||||||
configTmp, err := readConfig("config.cfg")
|
configTmp := readConfig("config.cfg")
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("Got error %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
DNSConf = configTmp
|
DNSConf = configTmp
|
||||||
|
|
||||||
setupLogging(DNSConf.Logconfig.Format, DNSConf.Logconfig.Level)
|
setupLogging(DNSConf.Logconfig.Format, DNSConf.Logconfig.Level)
|
||||||
@ -30,12 +25,14 @@ func main() {
|
|||||||
RR.Parse(DNSConf.General.StaticRecords)
|
RR.Parse(DNSConf.General.StaticRecords)
|
||||||
|
|
||||||
// Open database
|
// Open database
|
||||||
err = DB.Init(DNSConf.Database.Engine, DNSConf.Database.Connection)
|
newDB := new(acmedb)
|
||||||
|
err := newDB.Init(DNSConf.Database.Engine, DNSConf.Database.Connection)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Could not open database [%v]", err)
|
log.Errorf("Could not open database [%v]", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
defer DB.DB.Close()
|
DB = newDB
|
||||||
|
defer DB.Close()
|
||||||
|
|
||||||
// DNS server
|
// DNS server
|
||||||
startDNS(DNSConf.General.Listen)
|
startDNS(DNSConf.General.Listen)
|
||||||
|
@ -16,22 +16,24 @@ func TestMain(m *testing.M) {
|
|||||||
RR.Parse(records)
|
RR.Parse(records)
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
|
newDb := new(acmedb)
|
||||||
if *postgres {
|
if *postgres {
|
||||||
DNSConf.Database.Engine = "postgres"
|
DNSConf.Database.Engine = "postgres"
|
||||||
err := DB.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns")
|
err := newDb.Init("postgres", "postgres://acmedns:acmedns@localhost/acmedns")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"")
|
fmt.Println("PostgreSQL integration tests expect database \"acmedns\" running in localhost, with username and password set to \"acmedns\"")
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
DNSConf.Database.Engine = "sqlite3"
|
DNSConf.Database.Engine = "sqlite3"
|
||||||
_ = DB.Init("sqlite3", ":memory:")
|
_ = newDb.Init("sqlite3", ":memory:")
|
||||||
}
|
}
|
||||||
|
DB = newDb
|
||||||
|
|
||||||
server := startDNS("0.0.0.0:15353")
|
server := startDNS("0.0.0.0:15353")
|
||||||
exitval := m.Run()
|
exitval := m.Run()
|
||||||
server.Shutdown()
|
server.Shutdown()
|
||||||
DB.DB.Close()
|
DB.Close()
|
||||||
os.Exit(exitval)
|
os.Exit(exitval)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
22
types.go
22
types.go
@ -1,8 +1,10 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/satori/go.uuid"
|
"github.com/satori/go.uuid"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Records is for static records
|
// Records is for static records
|
||||||
@ -38,7 +40,7 @@ type dbsettings struct {
|
|||||||
|
|
||||||
// API config
|
// API config
|
||||||
type httpapi struct {
|
type httpapi struct {
|
||||||
Domain string
|
Domain string `toml:"api_domain"`
|
||||||
Port string
|
Port string
|
||||||
TLS string
|
TLS string
|
||||||
TLSCertPrivkey string `toml:"tls_cert_privkey"`
|
TLSCertPrivkey string `toml:"tls_cert_privkey"`
|
||||||
@ -67,3 +69,21 @@ type ACMETxtPost struct {
|
|||||||
Subdomain string `json:"subdomain"`
|
Subdomain string `json:"subdomain"`
|
||||||
Value string `json:"txt"`
|
Value string `json:"txt"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type acmedb struct {
|
||||||
|
sync.Mutex
|
||||||
|
DB *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
type database interface {
|
||||||
|
Init(string, string) error
|
||||||
|
Register() (ACMETxt, error)
|
||||||
|
GetByUsername(uuid.UUID) (ACMETxt, error)
|
||||||
|
GetByDomain(string) ([]ACMETxt, error)
|
||||||
|
Update(ACMETxt) error
|
||||||
|
GetBackend() *sql.DB
|
||||||
|
SetBackend(*sql.DB)
|
||||||
|
Close()
|
||||||
|
Lock()
|
||||||
|
Unlock()
|
||||||
|
}
|
||||||
|
12
util.go
12
util.go
@ -2,7 +2,6 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/BurntSushi/toml"
|
"github.com/BurntSushi/toml"
|
||||||
log "github.com/Sirupsen/logrus"
|
log "github.com/Sirupsen/logrus"
|
||||||
@ -16,12 +15,11 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func readConfig(fname string) (DNSConfig, error) {
|
func readConfig(fname string) DNSConfig {
|
||||||
var conf DNSConfig
|
var conf DNSConfig
|
||||||
if _, err := toml.DecodeFile(fname, &conf); err != nil {
|
// Practically never errors
|
||||||
return DNSConfig{}, errors.New("Malformed configuration file")
|
_, _ = toml.DecodeFile(fname, &conf)
|
||||||
}
|
return conf
|
||||||
return conf, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func sanitizeString(s string) string {
|
func sanitizeString(s string) string {
|
||||||
@ -62,7 +60,7 @@ func newACMETxt() (ACMETxt, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func setupLogging(format string, level string) {
|
func setupLogging(format string, level string) {
|
||||||
if DNSConf.Logconfig.Format == "json" {
|
if format == "json" {
|
||||||
log.SetFormatter(&log.JSONFormatter{})
|
log.SetFormatter(&log.JSONFormatter{})
|
||||||
}
|
}
|
||||||
switch level {
|
switch level {
|
||||||
|
48
util_test.go
48
util_test.go
@ -2,6 +2,8 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
log "github.com/Sirupsen/logrus"
|
log "github.com/Sirupsen/logrus"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -23,3 +25,49 @@ func TestSetupLogging(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReadConfig(t *testing.T) {
|
||||||
|
for i, test := range []struct {
|
||||||
|
inFile []byte
|
||||||
|
output DNSConfig
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
[]byte("[general]\nlisten = \":53\"\ndebug = true\n[api]\napi_domain = \"something.strange\""),
|
||||||
|
DNSConfig{
|
||||||
|
General: general{
|
||||||
|
Listen: ":53",
|
||||||
|
Debug: true,
|
||||||
|
},
|
||||||
|
API: httpapi{
|
||||||
|
Domain: "something.strange",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
[]byte("[\x00[[[[[[[[[de\nlisten =]"),
|
||||||
|
DNSConfig{},
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
tmpfile, err := ioutil.TempFile("", "acmedns")
|
||||||
|
if err != nil {
|
||||||
|
t.Error("Could not create temporary file")
|
||||||
|
}
|
||||||
|
defer os.Remove(tmpfile.Name())
|
||||||
|
|
||||||
|
if _, err := tmpfile.Write(test.inFile); err != nil {
|
||||||
|
t.Error("Could not write to temporary file")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tmpfile.Close(); err != nil {
|
||||||
|
t.Error("Could not close temporary file")
|
||||||
|
}
|
||||||
|
ret := readConfig(tmpfile.Name())
|
||||||
|
if ret.General.Listen != test.output.General.Listen {
|
||||||
|
t.Errorf("Test %d: Expected listen value %s, but got %s", i, test.output.General.Listen, ret.General.Listen)
|
||||||
|
}
|
||||||
|
if ret.API.Domain != test.output.API.Domain {
|
||||||
|
t.Errorf("Test %d: Expected HTTP API domain %s, but got %s", i, test.output.API.Domain, ret.API.Domain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user