DB code for CIDR handling

This commit is contained in:
Joona Hoikkala
2016-12-01 00:03:08 +02:00
parent 8ffed7a3d7
commit c3ac7a211c
10 changed files with 156 additions and 56 deletions

51
acmetxt.go Normal file
View File

@ -0,0 +1,51 @@
package main
import (
"encoding/json"
"net"
"github.com/satori/go.uuid"
)
// ACMETxt is the default structure for the user controlled record
type ACMETxt struct {
Username uuid.UUID
Password string
ACMETxtPost
LastActive int64
AllowFrom cidrslice
}
// ACMETxtPost holds the DNS part of the ACMETxt struct
type ACMETxtPost struct {
Subdomain string `json:"subdomain"`
Value string `json:"txt"`
}
// cidrslice is a list of allowed cidr ranges
type cidrslice []string
func (c *cidrslice) JSON() string {
ret, _ := json.Marshal(c.ValidEntries())
return string(ret)
}
func (c *cidrslice) ValidEntries() []string {
valid := []string{}
for _, v := range *c {
_, _, err := net.ParseCIDR(v)
if err == nil {
valid = append(valid, v)
}
}
return valid
}
func newACMETxt() ACMETxt {
var a = ACMETxt{}
password := generatePassword(40)
a.Username = uuid.NewV4()
a.Password = password
a.Subdomain = uuid.NewV4().String()
return a
}

30
api.go
View File

@ -16,28 +16,36 @@ func (a authMiddleware) Serve(ctx *iris.Context) {
username, err := getValidUsername(usernameStr)
if err == nil && validKey(password) {
au, err := DB.GetByUsername(username)
if err == nil && correctPassword(password, au.Password) {
// Password ok
if err := ctx.ReadJSON(&postData); err == nil {
// Check that the subdomain belongs to the user
if au.Subdomain == postData.Subdomain {
ctx.Next()
if err != nil {
log.WithFields(log.Fields{"error": err.Error()}).Error("Error while trying to get user")
// To protect against timed side channel (never gonna give you up)
correctPassword(password, "$2a$10$8JEFVNYYhLoBysjAxe2yBuXrkDojBQBkVpXEQgyQyjn43SvJ4vL36")
} else {
if correctPassword(password, au.Password) {
// Password ok
if err := ctx.ReadJSON(&postData); err == nil {
// Check that the subdomain belongs to the user
if au.Subdomain == postData.Subdomain {
ctx.Next()
return
}
} else {
// JSON error
ctx.JSON(iris.StatusBadRequest, iris.Map{"error": "bad data"})
return
}
} else {
ctx.JSON(iris.StatusBadRequest, iris.Map{"error": "bad data"})
return
// Wrong password
log.WithFields(log.Fields{"username": username}).Warning("Failed password check")
}
}
// To protect against timed side channel (never gonna give you up)
correctPassword(password, "$2a$10$8JEFVNYYhLoBysjAxe2yBuXrkDojBQBkVpXEQgyQyjn43SvJ4vL36")
}
ctx.JSON(iris.StatusUnauthorized, iris.Map{"error": "unauthorized"})
}
func webRegisterPost(ctx *iris.Context) {
// Create new user
nu, err := DB.Register()
nu, err := DB.Register(cidrslice{})
var regJSON iris.Map
var regStatus int
if err != nil {

View File

@ -90,7 +90,7 @@ func TestApiUpdateWithCredentials(t *testing.T) {
"txt": ""}
e := setupIris(t, false, false)
newUser, err := DB.Register()
newUser, err := DB.Register(cidrslice{})
if err != nil {
t.Errorf("Could not create new user, got error [%v]", err)
}
@ -146,7 +146,7 @@ func TestApiManyUpdateWithCredentials(t *testing.T) {
"txt": ""}
e := setupIris(t, false, false)
newUser, err := DB.Register()
newUser, err := DB.Register(cidrslice{})
if err != nil {
t.Errorf("Could not create new user, got error [%v]", err)
}
@ -164,6 +164,7 @@ func TestApiManyUpdateWithCredentials(t *testing.T) {
{newUser.Username.String(), newUser.Password, newUser.Subdomain, "tooshortfortxt", 400},
{newUser.Username.String(), newUser.Password, newUser.Subdomain, 1234567890, 400},
{newUser.Username.String(), newUser.Password, newUser.Subdomain, validTxtData, 200},
{newUser.Username.String(), "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", newUser.Subdomain, validTxtData, 401},
} {
updateJSON = map[string]interface{}{
"subdomain": test.subdomain,

26
db.go
View File

@ -2,13 +2,16 @@ package main
import (
"database/sql"
"encoding/json"
"errors"
"regexp"
"time"
log "github.com/Sirupsen/logrus"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
"github.com/satori/go.uuid"
"golang.org/x/crypto/bcrypt"
"regexp"
"time"
)
var recordsTable = `
@ -43,10 +46,11 @@ func (d *acmedb) Init(engine string, connection string) error {
return nil
}
func (d *acmedb) Register() (ACMETxt, error) {
func (d *acmedb) Register(afrom cidrslice) (ACMETxt, error) {
d.Lock()
defer d.Unlock()
a := newACMETxt()
a.AllowFrom = cidrslice(afrom.ValidEntries())
passwordHash, err := bcrypt.GenerateFromPassword([]byte(a.Password), 10)
timenow := time.Now().Unix()
regSQL := `
@ -63,10 +67,11 @@ func (d *acmedb) Register() (ACMETxt, error) {
}
sm, err := d.DB.Prepare(regSQL)
if err != nil {
log.WithFields(log.Fields{"error": err.Error()}).Error("Database error in prepare")
return a, errors.New("SQL error")
}
defer sm.Close()
_, err = sm.Exec(a.Username.String(), passwordHash, a.Subdomain, timenow, a.AllowFrom)
_, err = sm.Exec(a.Username.String(), passwordHash, a.Subdomain, timenow, a.AllowFrom.JSON())
if err != nil {
return a, err
}
@ -173,13 +178,24 @@ func (d *acmedb) Update(a ACMETxt) error {
func getModelFromRow(r *sql.Rows) (ACMETxt, error) {
txt := ACMETxt{}
afrom := ""
err := r.Scan(
&txt.Username,
&txt.Password,
&txt.Subdomain,
&txt.Value,
&txt.LastActive,
&txt.AllowFrom)
&afrom)
if err != nil {
log.WithFields(log.Fields{"error": err.Error()}).Error("Row scan error")
}
cslice := cidrslice{}
err = json.Unmarshal([]byte(afrom), &cslice)
if err != nil {
log.WithFields(log.Fields{"error": err.Error()}).Error("JSON unmarshall error")
}
txt.AllowFrom = cslice
return txt, err
}

View File

@ -41,17 +41,44 @@ func TestDBInit(t *testing.T) {
errorDB.Close()
}
func TestRegister(t *testing.T) {
func TestRegisterNoCIDR(t *testing.T) {
// Register tests
_, err := DB.Register()
_, err := DB.Register(cidrslice{})
if err != nil {
t.Errorf("Registration failed, got error [%v]", err)
}
}
func TestRegisterMany(t *testing.T) {
for i, test := range []struct {
input cidrslice
output cidrslice
}{
{cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}, cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}},
{cidrslice{"1.1.1./32", "1922.168.42.42/8", "1.1.1.1/33", "1.2.3.4/"}, cidrslice{}},
{cidrslice{"7.6.5.4/32", "invalid", "1.0.0.1/2"}, cidrslice{"7.6.5.4/32", "1.0.0.1/2"}},
} {
user, err := DB.Register(test.input)
if err != nil {
t.Errorf("Test %d: Got error from register method: [%v]", i, err)
}
res, err := DB.GetByUsername(user.Username)
if err != nil {
t.Errorf("Test %d: Got error when fetching username: [%v]", i, err)
}
if len(user.AllowFrom) != len(test.output) {
t.Errorf("Test %d: Expected to recieve struct with [%d] entries in AllowFrom, but got [%d] records", i, len(test.output), len(user.AllowFrom))
}
if len(res.AllowFrom) != len(test.output) {
t.Errorf("Test %d: Expected to recieve struct with [%d] entries in AllowFrom, but got [%d] records", i, len(test.output), len(res.AllowFrom))
}
}
}
func TestGetByUsername(t *testing.T) {
// Create reg to refer to
reg, err := DB.Register()
reg, err := DB.Register(cidrslice{})
if err != nil {
t.Errorf("Registration failed, got error [%v]", err)
}
@ -76,7 +103,7 @@ func TestGetByUsername(t *testing.T) {
}
func TestPrepareErrors(t *testing.T) {
reg, _ := DB.Register()
reg, _ := DB.Register(cidrslice{})
tdb, err := sql.Open("testdb", "")
if err != nil {
t.Errorf("Got error: %v", err)
@ -98,7 +125,7 @@ func TestPrepareErrors(t *testing.T) {
}
func TestQueryExecErrors(t *testing.T) {
reg, _ := DB.Register()
reg, _ := DB.Register(cidrslice{})
testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) {
return testResult{1, 0}, errors.New("Prepared query error")
})
@ -129,7 +156,7 @@ func TestQueryExecErrors(t *testing.T) {
t.Errorf("Expected error from exec in GetByDomain, but got none")
}
_, err = DB.Register()
_, err = DB.Register(cidrslice{})
if err == nil {
t.Errorf("Expected error from exec in Register, but got none")
}
@ -142,7 +169,7 @@ func TestQueryExecErrors(t *testing.T) {
}
func TestQueryScanErrors(t *testing.T) {
reg, _ := DB.Register()
reg, _ := DB.Register(cidrslice{})
testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) {
return testResult{1, 0}, errors.New("Prepared query error")
@ -176,7 +203,7 @@ func TestQueryScanErrors(t *testing.T) {
}
func TestBadDBValues(t *testing.T) {
reg, _ := DB.Register()
reg, _ := DB.Register(cidrslice{})
testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) {
columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"}
@ -209,7 +236,7 @@ func TestGetByDomain(t *testing.T) {
var regDomain = ACMETxt{}
// Create reg to refer to
reg, err := DB.Register()
reg, err := DB.Register(cidrslice{})
if err != nil {
t.Errorf("Registration failed, got error [%v]", err)
}
@ -246,7 +273,7 @@ func TestGetByDomain(t *testing.T) {
func TestUpdate(t *testing.T) {
// Create reg to refer to
reg, err := DB.Register()
reg, err := DB.Register(cidrslice{})
if err != nil {
t.Errorf("Registration failed, got error [%v]", err)
}

View File

@ -139,7 +139,7 @@ func TestResolveTXT(t *testing.T) {
resolv := resolver{server: "0.0.0.0:15353"}
validTXT := "______________valid_response_______________"
atxt, err := DB.Register()
atxt, err := DB.Register(cidrslice{})
if err != nil {
t.Errorf("Could not initiate db record: [%v]", err)
return

View File

@ -66,21 +66,6 @@ type logconfig struct {
Format string `toml:"logformat"`
}
// ACMETxt is the default structure for the user controlled record
type ACMETxt struct {
Username uuid.UUID
Password string
ACMETxtPost
LastActive int64
AllowFrom string
}
// ACMETxtPost holds the DNS part of the ACMETxt struct
type ACMETxtPost struct {
Subdomain string `json:"subdomain"`
Value string `json:"txt"`
}
type acmedb struct {
sync.Mutex
DB *sql.DB
@ -88,7 +73,7 @@ type acmedb struct {
type database interface {
Init(string, string) error
Register() (ACMETxt, error)
Register(cidrslice) (ACMETxt, error)
GetByUsername(uuid.UUID) (ACMETxt, error)
GetByDomain(string) ([]ACMETxt, error)
Update(ACMETxt) error

10
util.go
View File

@ -5,7 +5,6 @@ import (
"github.com/BurntSushi/toml"
log "github.com/Sirupsen/logrus"
"github.com/miekg/dns"
"github.com/satori/go.uuid"
"math/big"
"regexp"
"strings"
@ -45,15 +44,6 @@ func sanitizeDomainQuestion(d string) string {
return dom
}
func newACMETxt() ACMETxt {
var a = ACMETxt{}
password := generatePassword(40)
a.Username = uuid.NewV4()
a.Password = password
a.Subdomain = uuid.NewV4().String()
return a
}
func setupLogging(format string, level string) {
if format == "json" {
log.SetFormatter(&log.JSONFormatter{})

View File

@ -1,9 +1,10 @@
package main
import (
"unicode/utf8"
"github.com/satori/go.uuid"
"golang.org/x/crypto/bcrypt"
"unicode/utf8"
)
func getValidUsername(u string) (uuid.UUID, error) {

View File

@ -106,3 +106,24 @@ func TestCorrectPassword(t *testing.T) {
}
}
}
func TestGetValidCIDRMasks(t *testing.T) {
for i, test := range []struct {
input cidrslice
output cidrslice
}{
{cidrslice{"10.0.0.1/24"}, cidrslice{"10.0.0.1/24"}},
{cidrslice{"invalid", "127.0.0.1/32"}, cidrslice{"127.0.0.1/32"}},
} {
ret := test.input.ValidEntries()
if len(ret) == len(test.output) {
for i, v := range ret {
if v != test.output[i] {
t.Errorf("Test %d: Expected %q but got %q", i, test.output, ret)
}
}
} else {
t.Errorf("Test %d: Expected %q but got %q", i, test.output, ret)
}
}
}