package main import ( "database/sql" "database/sql/driver" "errors" "github.com/erikstmartin/go-testdb" "testing" ) type testResult struct { lastID int64 affectedRows int64 } func (r testResult) LastInsertId() (int64, error) { return r.lastID, nil } func (r testResult) RowsAffected() (int64, error) { return r.affectedRows, nil } func TestDBInit(t *testing.T) { fakeDB := new(acmedb) err := fakeDB.Init("notarealegine", "connectionstring") if err == nil { t.Errorf("Was expecting error, didn't get one.") } testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) { return testResult{1, 0}, errors.New("Prepared query error") }) defer testdb.Reset() errorDB := new(acmedb) err = errorDB.Init("testdb", "") if err == nil { t.Errorf("Was expecting DB initiation error but got none") } errorDB.Close() } func TestRegister(t *testing.T) { // Register tests _, err := DB.Register() if err != nil { t.Errorf("Registration failed, got error [%v]", err) } } func TestGetByUsername(t *testing.T) { // Create reg to refer to reg, err := DB.Register() if err != nil { t.Errorf("Registration failed, got error [%v]", err) } regUser, err := DB.GetByUsername(reg.Username) if err != nil { t.Errorf("Could not get test user, got error [%v]", err) } if reg.Username != regUser.Username { t.Errorf("GetByUsername username [%q] did not match the original [%q]", regUser.Username, reg.Username) } if reg.Subdomain != regUser.Subdomain { t.Errorf("GetByUsername subdomain [%q] did not match the original [%q]", regUser.Subdomain, reg.Subdomain) } // regUser password already is a bcrypt hash if !correctPassword(reg.Password, regUser.Password) { t.Errorf("The password [%s] does not match the hash [%s]", reg.Password, regUser.Password) } } func TestPrepareErrors(t *testing.T) { reg, _ := DB.Register() tdb, err := sql.Open("testdb", "") if err != nil { t.Errorf("Got error: %v", err) } oldDb := DB.GetBackend() DB.SetBackend(tdb) defer DB.SetBackend(oldDb) defer testdb.Reset() _, err = DB.GetByUsername(reg.Username) if err == nil { t.Errorf("Expected error, but didn't get one") } _, err = DB.GetByDomain(reg.Subdomain) if err == nil { t.Errorf("Expected error, but didn't get one") } } func TestQueryExecErrors(t *testing.T) { reg, _ := DB.Register() testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) { return testResult{1, 0}, errors.New("Prepared query error") }) testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) { columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"} return testdb.RowsFromSlice(columns, [][]driver.Value{}), errors.New("Prepared query error") }) defer testdb.Reset() tdb, err := sql.Open("testdb", "") if err != nil { t.Errorf("Got error: %v", err) } oldDb := DB.GetBackend() DB.SetBackend(tdb) defer DB.SetBackend(oldDb) _, err = DB.GetByUsername(reg.Username) if err == nil { t.Errorf("Expected error from exec, but got none") } _, err = DB.GetByDomain(reg.Subdomain) if err == nil { t.Errorf("Expected error from exec in GetByDomain, but got none") } _, err = DB.Register() if err == nil { t.Errorf("Expected error from exec in Register, but got none") } reg.Value = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" err = DB.Update(reg) if err == nil { t.Errorf("Expected error from exec in Update, but got none") } } func TestQueryScanErrors(t *testing.T) { reg, _ := DB.Register() testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) { return testResult{1, 0}, errors.New("Prepared query error") }) testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) { columns := []string{"Only one"} resultrows := "this value" return testdb.RowsFromCSVString(columns, resultrows), nil }) defer testdb.Reset() tdb, err := sql.Open("testdb", "") if err != nil { t.Errorf("Got error: %v", err) } oldDb := DB.GetBackend() DB.SetBackend(tdb) defer DB.SetBackend(oldDb) _, err = DB.GetByUsername(reg.Username) if err == nil { t.Errorf("Expected error from scan in, but got none") } _, err = DB.GetByDomain(reg.Subdomain) if err == nil { t.Errorf("Expected error from scan in GetByDomain, but got none") } } func TestBadDBValues(t *testing.T) { reg, _ := DB.Register() testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) { columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"} resultrows := "invalid,invalid,invalid,invalid," return testdb.RowsFromCSVString(columns, resultrows), nil }) defer testdb.Reset() tdb, err := sql.Open("testdb", "") if err != nil { t.Errorf("Got error: %v", err) } oldDb := DB.GetBackend() DB.SetBackend(tdb) defer DB.SetBackend(oldDb) _, err = DB.GetByUsername(reg.Username) if err == nil { t.Errorf("Expected error from scan in, but got none") } _, err = DB.GetByDomain(reg.Subdomain) if err == nil { t.Errorf("Expected error from scan in GetByDomain, but got none") } } func TestGetByDomain(t *testing.T) { var regDomain = ACMETxt{} // Create reg to refer to reg, err := DB.Register() if err != nil { t.Errorf("Registration failed, got error [%v]", err) } regDomainSlice, err := DB.GetByDomain(reg.Subdomain) if err != nil { t.Errorf("Could not get test user, got error [%v]", err) } if len(regDomainSlice) == 0 { t.Errorf("No rows returned for GetByDomain [%s]", reg.Subdomain) } else { regDomain = regDomainSlice[0] } if reg.Username != regDomain.Username { t.Errorf("GetByUsername username [%q] did not match the original [%q]", regDomain.Username, reg.Username) } if reg.Subdomain != regDomain.Subdomain { t.Errorf("GetByUsername subdomain [%q] did not match the original [%q]", regDomain.Subdomain, reg.Subdomain) } // regDomain password already is a bcrypt hash if !correctPassword(reg.Password, regDomain.Password) { t.Errorf("The password [%s] does not match the hash [%s]", reg.Password, regDomain.Password) } // Not found regNotfound, _ := DB.GetByDomain("does-not-exist") if len(regNotfound) > 0 { t.Errorf("No records should be returned.") } } func TestUpdate(t *testing.T) { // Create reg to refer to reg, err := DB.Register() if err != nil { t.Errorf("Registration failed, got error [%v]", err) } regUser, err := DB.GetByUsername(reg.Username) if err != nil { t.Errorf("Could not get test user, got error [%v]", err) } // Set new values (only TXT should be updated) (matches by username and subdomain) validTXT := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" regUser.Password = "nevergonnagiveyouup" regUser.Value = validTXT err = DB.Update(regUser) if err != nil { t.Errorf("DB Update failed, got error: [%v]", err) } updUser, err := DB.GetByUsername(regUser.Username) if err != nil { t.Errorf("GetByUsername threw error [%v]", err) } if updUser.Value != validTXT { t.Errorf("Update failed, fetched value [%s] does not match the update value [%s]", updUser.Value, validTXT) } }