mirror of
https://github.com/joohoi/acme-dns.git
synced 2024-12-22 23:43:44 +07:00
302 lines
7.8 KiB
Go
302 lines
7.8 KiB
Go
package main
|
|
|
|
import (
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"errors"
|
|
"github.com/erikstmartin/go-testdb"
|
|
"testing"
|
|
)
|
|
|
|
// testResult is a canned driver.Result handed back by the go-testdb exec stubs
// in this file. It is built positionally as testResult{lastID, affectedRows},
// so the field order is part of its contract.
type testResult struct {
	lastID       int64 // value reported by LastInsertId
	affectedRows int64 // value reported by RowsAffected
}

// Compile-time check that testResult satisfies driver.Result, so interface
// drift is reported here instead of at every stub that returns it.
var _ driver.Result = testResult{}

// LastInsertId returns the canned insert ID. It never returns an error.
func (r testResult) LastInsertId() (int64, error) {
	return r.lastID, nil
}

// RowsAffected returns the canned affected-row count. It never returns an error.
func (r testResult) RowsAffected() (int64, error) {
	return r.affectedRows, nil
}
|
|
|
|
func TestDBInit(t *testing.T) {
|
|
fakeDB := new(acmedb)
|
|
err := fakeDB.Init("notarealegine", "connectionstring")
|
|
if err == nil {
|
|
t.Errorf("Was expecting error, didn't get one.")
|
|
}
|
|
|
|
testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) {
|
|
return testResult{1, 0}, errors.New("Prepared query error")
|
|
})
|
|
defer testdb.Reset()
|
|
|
|
errorDB := new(acmedb)
|
|
err = errorDB.Init("testdb", "")
|
|
if err == nil {
|
|
t.Errorf("Was expecting DB initiation error but got none")
|
|
}
|
|
errorDB.Close()
|
|
}
|
|
|
|
func TestRegisterNoCIDR(t *testing.T) {
|
|
// Register tests
|
|
_, err := DB.Register(cidrslice{})
|
|
if err != nil {
|
|
t.Errorf("Registration failed, got error [%v]", err)
|
|
}
|
|
}
|
|
|
|
func TestRegisterMany(t *testing.T) {
|
|
for i, test := range []struct {
|
|
input cidrslice
|
|
output cidrslice
|
|
}{
|
|
{cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}, cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}},
|
|
{cidrslice{"1.1.1./32", "1922.168.42.42/8", "1.1.1.1/33", "1.2.3.4/"}, cidrslice{}},
|
|
{cidrslice{"7.6.5.4/32", "invalid", "1.0.0.1/2"}, cidrslice{"7.6.5.4/32", "1.0.0.1/2"}},
|
|
} {
|
|
user, err := DB.Register(test.input)
|
|
if err != nil {
|
|
t.Errorf("Test %d: Got error from register method: [%v]", i, err)
|
|
}
|
|
res, err := DB.GetByUsername(user.Username)
|
|
if err != nil {
|
|
t.Errorf("Test %d: Got error when fetching username: [%v]", i, err)
|
|
}
|
|
if len(user.AllowFrom) != len(test.output) {
|
|
t.Errorf("Test %d: Expected to receive struct with [%d] entries in AllowFrom, but got [%d] records", i, len(test.output), len(user.AllowFrom))
|
|
}
|
|
if len(res.AllowFrom) != len(test.output) {
|
|
t.Errorf("Test %d: Expected to receive struct with [%d] entries in AllowFrom, but got [%d] records", i, len(test.output), len(res.AllowFrom))
|
|
}
|
|
|
|
}
|
|
}
|
|
|
|
func TestGetByUsername(t *testing.T) {
|
|
// Create reg to refer to
|
|
reg, err := DB.Register(cidrslice{})
|
|
if err != nil {
|
|
t.Errorf("Registration failed, got error [%v]", err)
|
|
}
|
|
|
|
regUser, err := DB.GetByUsername(reg.Username)
|
|
if err != nil {
|
|
t.Errorf("Could not get test user, got error [%v]", err)
|
|
}
|
|
|
|
if reg.Username != regUser.Username {
|
|
t.Errorf("GetByUsername username [%q] did not match the original [%q]", regUser.Username, reg.Username)
|
|
}
|
|
|
|
if reg.Subdomain != regUser.Subdomain {
|
|
t.Errorf("GetByUsername subdomain [%q] did not match the original [%q]", regUser.Subdomain, reg.Subdomain)
|
|
}
|
|
|
|
// regUser password already is a bcrypt hash
|
|
if !correctPassword(reg.Password, regUser.Password) {
|
|
t.Errorf("The password [%s] does not match the hash [%s]", reg.Password, regUser.Password)
|
|
}
|
|
}
|
|
|
|
func TestPrepareErrors(t *testing.T) {
|
|
reg, _ := DB.Register(cidrslice{})
|
|
tdb, err := sql.Open("testdb", "")
|
|
if err != nil {
|
|
t.Errorf("Got error: %v", err)
|
|
}
|
|
oldDb := DB.GetBackend()
|
|
DB.SetBackend(tdb)
|
|
defer DB.SetBackend(oldDb)
|
|
defer testdb.Reset()
|
|
|
|
_, err = DB.GetByUsername(reg.Username)
|
|
if err == nil {
|
|
t.Errorf("Expected error, but didn't get one")
|
|
}
|
|
|
|
_, err = DB.GetTXTForDomain(reg.Subdomain)
|
|
if err == nil {
|
|
t.Errorf("Expected error, but didn't get one")
|
|
}
|
|
}
|
|
|
|
func TestQueryExecErrors(t *testing.T) {
|
|
reg, _ := DB.Register(cidrslice{})
|
|
testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) {
|
|
return testResult{1, 0}, errors.New("Prepared query error")
|
|
})
|
|
|
|
testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) {
|
|
columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"}
|
|
return testdb.RowsFromSlice(columns, [][]driver.Value{}), errors.New("Prepared query error")
|
|
})
|
|
|
|
defer testdb.Reset()
|
|
|
|
tdb, err := sql.Open("testdb", "")
|
|
if err != nil {
|
|
t.Errorf("Got error: %v", err)
|
|
}
|
|
oldDb := DB.GetBackend()
|
|
|
|
DB.SetBackend(tdb)
|
|
defer DB.SetBackend(oldDb)
|
|
|
|
_, err = DB.GetByUsername(reg.Username)
|
|
if err == nil {
|
|
t.Errorf("Expected error from exec, but got none")
|
|
}
|
|
|
|
_, err = DB.GetTXTForDomain(reg.Subdomain)
|
|
if err == nil {
|
|
t.Errorf("Expected error from exec in GetByDomain, but got none")
|
|
}
|
|
|
|
_, err = DB.Register(cidrslice{})
|
|
if err == nil {
|
|
t.Errorf("Expected error from exec in Register, but got none")
|
|
}
|
|
reg.Value = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
|
|
err = DB.Update(reg)
|
|
if err == nil {
|
|
t.Errorf("Expected error from exec in Update, but got none")
|
|
}
|
|
|
|
}
|
|
|
|
func TestQueryScanErrors(t *testing.T) {
|
|
reg, _ := DB.Register(cidrslice{})
|
|
|
|
testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) {
|
|
return testResult{1, 0}, errors.New("Prepared query error")
|
|
})
|
|
|
|
testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) {
|
|
columns := []string{"Only one"}
|
|
resultrows := "this value"
|
|
return testdb.RowsFromCSVString(columns, resultrows), nil
|
|
})
|
|
|
|
defer testdb.Reset()
|
|
tdb, err := sql.Open("testdb", "")
|
|
if err != nil {
|
|
t.Errorf("Got error: %v", err)
|
|
}
|
|
oldDb := DB.GetBackend()
|
|
|
|
DB.SetBackend(tdb)
|
|
defer DB.SetBackend(oldDb)
|
|
|
|
_, err = DB.GetByUsername(reg.Username)
|
|
if err == nil {
|
|
t.Errorf("Expected error from scan in, but got none")
|
|
}
|
|
}
|
|
|
|
func TestBadDBValues(t *testing.T) {
|
|
reg, _ := DB.Register(cidrslice{})
|
|
|
|
testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) {
|
|
columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"}
|
|
resultrows := "invalid,invalid,invalid,invalid,"
|
|
return testdb.RowsFromCSVString(columns, resultrows), nil
|
|
})
|
|
|
|
defer testdb.Reset()
|
|
tdb, err := sql.Open("testdb", "")
|
|
if err != nil {
|
|
t.Errorf("Got error: %v", err)
|
|
}
|
|
oldDb := DB.GetBackend()
|
|
|
|
DB.SetBackend(tdb)
|
|
defer DB.SetBackend(oldDb)
|
|
|
|
_, err = DB.GetByUsername(reg.Username)
|
|
if err == nil {
|
|
t.Errorf("Expected error from scan in, but got none")
|
|
}
|
|
|
|
_, err = DB.GetTXTForDomain(reg.Subdomain)
|
|
if err == nil {
|
|
t.Errorf("Expected error from scan in GetByDomain, but got none")
|
|
}
|
|
}
|
|
|
|
func TestGetTXTForDomain(t *testing.T) {
|
|
// Create reg to refer to
|
|
reg, err := DB.Register(cidrslice{})
|
|
if err != nil {
|
|
t.Errorf("Registration failed, got error [%v]", err)
|
|
}
|
|
|
|
txtval1 := "___validation_token_received_from_the_ca___"
|
|
txtval2 := "___validation_token_received_YEAH_the_ca___"
|
|
|
|
reg.Value = txtval1
|
|
_ = DB.Update(reg)
|
|
|
|
reg.Value = txtval2
|
|
_ = DB.Update(reg)
|
|
|
|
regDomainSlice, err := DB.GetTXTForDomain(reg.Subdomain)
|
|
if err != nil {
|
|
t.Errorf("Could not get test user, got error [%v]", err)
|
|
}
|
|
if len(regDomainSlice) == 0 {
|
|
t.Errorf("No rows returned for GetTXTForDomain [%s]", reg.Subdomain)
|
|
}
|
|
|
|
var val1found = false
|
|
var val2found = false
|
|
for _, v := range regDomainSlice {
|
|
if v == txtval1 {
|
|
val1found = true
|
|
}
|
|
if v == txtval2 {
|
|
val2found = true
|
|
}
|
|
}
|
|
if !val1found {
|
|
t.Errorf("No TXT value found for val1")
|
|
}
|
|
if !val2found {
|
|
t.Errorf("No TXT value found for val2")
|
|
}
|
|
|
|
// Not found
|
|
regNotfound, _ := DB.GetTXTForDomain("does-not-exist")
|
|
if len(regNotfound) > 0 {
|
|
t.Errorf("No records should be returned.")
|
|
}
|
|
}
|
|
|
|
func TestUpdate(t *testing.T) {
|
|
// Create reg to refer to
|
|
reg, err := DB.Register(cidrslice{})
|
|
if err != nil {
|
|
t.Errorf("Registration failed, got error [%v]", err)
|
|
}
|
|
|
|
regUser, err := DB.GetByUsername(reg.Username)
|
|
if err != nil {
|
|
t.Errorf("Could not get test user, got error [%v]", err)
|
|
}
|
|
|
|
// Set new values (only TXT should be updated) (matches by username and subdomain)
|
|
|
|
validTXT := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
|
|
|
regUser.Password = "nevergonnagiveyouup"
|
|
regUser.Value = validTXT
|
|
|
|
err = DB.Update(regUser)
|
|
if err != nil {
|
|
t.Errorf("DB Update failed, got error: [%v]", err)
|
|
}
|
|
}
|