mirror of
https://github.com/joohoi/acme-dns.git
synced 2025-07-04 07:17:24 +07:00
DB code for CIDR handling
This commit is contained in:
51
acmetxt.go
Normal file
51
acmetxt.go
Normal file
@ -0,0 +1,51 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net"
|
||||
|
||||
"github.com/satori/go.uuid"
|
||||
)
|
||||
|
||||
// ACMETxt is the default structure for the user controlled record
|
||||
type ACMETxt struct {
|
||||
Username uuid.UUID
|
||||
Password string
|
||||
ACMETxtPost
|
||||
LastActive int64
|
||||
AllowFrom cidrslice
|
||||
}
|
||||
|
||||
// ACMETxtPost holds the DNS part of the ACMETxt struct
|
||||
type ACMETxtPost struct {
|
||||
Subdomain string `json:"subdomain"`
|
||||
Value string `json:"txt"`
|
||||
}
|
||||
|
||||
// cidrslice is a list of allowed cidr ranges
|
||||
type cidrslice []string
|
||||
|
||||
func (c *cidrslice) JSON() string {
|
||||
ret, _ := json.Marshal(c.ValidEntries())
|
||||
return string(ret)
|
||||
}
|
||||
|
||||
func (c *cidrslice) ValidEntries() []string {
|
||||
valid := []string{}
|
||||
for _, v := range *c {
|
||||
_, _, err := net.ParseCIDR(v)
|
||||
if err == nil {
|
||||
valid = append(valid, v)
|
||||
}
|
||||
}
|
||||
return valid
|
||||
}
|
||||
|
||||
func newACMETxt() ACMETxt {
|
||||
var a = ACMETxt{}
|
||||
password := generatePassword(40)
|
||||
a.Username = uuid.NewV4()
|
||||
a.Password = password
|
||||
a.Subdomain = uuid.NewV4().String()
|
||||
return a
|
||||
}
|
30
api.go
30
api.go
@ -16,28 +16,36 @@ func (a authMiddleware) Serve(ctx *iris.Context) {
|
||||
username, err := getValidUsername(usernameStr)
|
||||
if err == nil && validKey(password) {
|
||||
au, err := DB.GetByUsername(username)
|
||||
if err == nil && correctPassword(password, au.Password) {
|
||||
// Password ok
|
||||
if err := ctx.ReadJSON(&postData); err == nil {
|
||||
// Check that the subdomain belongs to the user
|
||||
if au.Subdomain == postData.Subdomain {
|
||||
ctx.Next()
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{"error": err.Error()}).Error("Error while trying to get user")
|
||||
// To protect against timed side channel (never gonna give you up)
|
||||
correctPassword(password, "$2a$10$8JEFVNYYhLoBysjAxe2yBuXrkDojBQBkVpXEQgyQyjn43SvJ4vL36")
|
||||
} else {
|
||||
if correctPassword(password, au.Password) {
|
||||
// Password ok
|
||||
if err := ctx.ReadJSON(&postData); err == nil {
|
||||
// Check that the subdomain belongs to the user
|
||||
if au.Subdomain == postData.Subdomain {
|
||||
ctx.Next()
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// JSON error
|
||||
ctx.JSON(iris.StatusBadRequest, iris.Map{"error": "bad data"})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
ctx.JSON(iris.StatusBadRequest, iris.Map{"error": "bad data"})
|
||||
return
|
||||
// Wrong password
|
||||
log.WithFields(log.Fields{"username": username}).Warning("Failed password check")
|
||||
}
|
||||
}
|
||||
// To protect against timed side channel (never gonna give you up)
|
||||
correctPassword(password, "$2a$10$8JEFVNYYhLoBysjAxe2yBuXrkDojBQBkVpXEQgyQyjn43SvJ4vL36")
|
||||
}
|
||||
ctx.JSON(iris.StatusUnauthorized, iris.Map{"error": "unauthorized"})
|
||||
}
|
||||
|
||||
func webRegisterPost(ctx *iris.Context) {
|
||||
// Create new user
|
||||
nu, err := DB.Register()
|
||||
nu, err := DB.Register(cidrslice{})
|
||||
var regJSON iris.Map
|
||||
var regStatus int
|
||||
if err != nil {
|
||||
|
@ -90,7 +90,7 @@ func TestApiUpdateWithCredentials(t *testing.T) {
|
||||
"txt": ""}
|
||||
|
||||
e := setupIris(t, false, false)
|
||||
newUser, err := DB.Register()
|
||||
newUser, err := DB.Register(cidrslice{})
|
||||
if err != nil {
|
||||
t.Errorf("Could not create new user, got error [%v]", err)
|
||||
}
|
||||
@ -146,7 +146,7 @@ func TestApiManyUpdateWithCredentials(t *testing.T) {
|
||||
"txt": ""}
|
||||
|
||||
e := setupIris(t, false, false)
|
||||
newUser, err := DB.Register()
|
||||
newUser, err := DB.Register(cidrslice{})
|
||||
if err != nil {
|
||||
t.Errorf("Could not create new user, got error [%v]", err)
|
||||
}
|
||||
@ -164,6 +164,7 @@ func TestApiManyUpdateWithCredentials(t *testing.T) {
|
||||
{newUser.Username.String(), newUser.Password, newUser.Subdomain, "tooshortfortxt", 400},
|
||||
{newUser.Username.String(), newUser.Password, newUser.Subdomain, 1234567890, 400},
|
||||
{newUser.Username.String(), newUser.Password, newUser.Subdomain, validTxtData, 200},
|
||||
{newUser.Username.String(), "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", newUser.Subdomain, validTxtData, 401},
|
||||
} {
|
||||
updateJSON = map[string]interface{}{
|
||||
"subdomain": test.subdomain,
|
||||
|
26
db.go
26
db.go
@ -2,13 +2,16 @@ package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
_ "github.com/lib/pq"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/satori/go.uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"regexp"
|
||||
"time"
|
||||
)
|
||||
|
||||
var recordsTable = `
|
||||
@ -43,10 +46,11 @@ func (d *acmedb) Init(engine string, connection string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *acmedb) Register() (ACMETxt, error) {
|
||||
func (d *acmedb) Register(afrom cidrslice) (ACMETxt, error) {
|
||||
d.Lock()
|
||||
defer d.Unlock()
|
||||
a := newACMETxt()
|
||||
a.AllowFrom = cidrslice(afrom.ValidEntries())
|
||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte(a.Password), 10)
|
||||
timenow := time.Now().Unix()
|
||||
regSQL := `
|
||||
@ -63,10 +67,11 @@ func (d *acmedb) Register() (ACMETxt, error) {
|
||||
}
|
||||
sm, err := d.DB.Prepare(regSQL)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{"error": err.Error()}).Error("Database error in prepare")
|
||||
return a, errors.New("SQL error")
|
||||
}
|
||||
defer sm.Close()
|
||||
_, err = sm.Exec(a.Username.String(), passwordHash, a.Subdomain, timenow, a.AllowFrom)
|
||||
_, err = sm.Exec(a.Username.String(), passwordHash, a.Subdomain, timenow, a.AllowFrom.JSON())
|
||||
if err != nil {
|
||||
return a, err
|
||||
}
|
||||
@ -173,13 +178,24 @@ func (d *acmedb) Update(a ACMETxt) error {
|
||||
|
||||
func getModelFromRow(r *sql.Rows) (ACMETxt, error) {
|
||||
txt := ACMETxt{}
|
||||
afrom := ""
|
||||
err := r.Scan(
|
||||
&txt.Username,
|
||||
&txt.Password,
|
||||
&txt.Subdomain,
|
||||
&txt.Value,
|
||||
&txt.LastActive,
|
||||
&txt.AllowFrom)
|
||||
&afrom)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{"error": err.Error()}).Error("Row scan error")
|
||||
}
|
||||
|
||||
cslice := cidrslice{}
|
||||
err = json.Unmarshal([]byte(afrom), &cslice)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{"error": err.Error()}).Error("JSON unmarshall error")
|
||||
}
|
||||
txt.AllowFrom = cslice
|
||||
return txt, err
|
||||
}
|
||||
|
||||
|
47
db_test.go
47
db_test.go
@ -41,17 +41,44 @@ func TestDBInit(t *testing.T) {
|
||||
errorDB.Close()
|
||||
}
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
func TestRegisterNoCIDR(t *testing.T) {
|
||||
// Register tests
|
||||
_, err := DB.Register()
|
||||
_, err := DB.Register(cidrslice{})
|
||||
if err != nil {
|
||||
t.Errorf("Registration failed, got error [%v]", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterMany(t *testing.T) {
|
||||
for i, test := range []struct {
|
||||
input cidrslice
|
||||
output cidrslice
|
||||
}{
|
||||
{cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}, cidrslice{"127.0.0.1/8", "8.8.8.8/32", "1.0.0.1/1"}},
|
||||
{cidrslice{"1.1.1./32", "1922.168.42.42/8", "1.1.1.1/33", "1.2.3.4/"}, cidrslice{}},
|
||||
{cidrslice{"7.6.5.4/32", "invalid", "1.0.0.1/2"}, cidrslice{"7.6.5.4/32", "1.0.0.1/2"}},
|
||||
} {
|
||||
user, err := DB.Register(test.input)
|
||||
if err != nil {
|
||||
t.Errorf("Test %d: Got error from register method: [%v]", i, err)
|
||||
}
|
||||
res, err := DB.GetByUsername(user.Username)
|
||||
if err != nil {
|
||||
t.Errorf("Test %d: Got error when fetching username: [%v]", i, err)
|
||||
}
|
||||
if len(user.AllowFrom) != len(test.output) {
|
||||
t.Errorf("Test %d: Expected to recieve struct with [%d] entries in AllowFrom, but got [%d] records", i, len(test.output), len(user.AllowFrom))
|
||||
}
|
||||
if len(res.AllowFrom) != len(test.output) {
|
||||
t.Errorf("Test %d: Expected to recieve struct with [%d] entries in AllowFrom, but got [%d] records", i, len(test.output), len(res.AllowFrom))
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetByUsername(t *testing.T) {
|
||||
// Create reg to refer to
|
||||
reg, err := DB.Register()
|
||||
reg, err := DB.Register(cidrslice{})
|
||||
if err != nil {
|
||||
t.Errorf("Registration failed, got error [%v]", err)
|
||||
}
|
||||
@ -76,7 +103,7 @@ func TestGetByUsername(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestPrepareErrors(t *testing.T) {
|
||||
reg, _ := DB.Register()
|
||||
reg, _ := DB.Register(cidrslice{})
|
||||
tdb, err := sql.Open("testdb", "")
|
||||
if err != nil {
|
||||
t.Errorf("Got error: %v", err)
|
||||
@ -98,7 +125,7 @@ func TestPrepareErrors(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestQueryExecErrors(t *testing.T) {
|
||||
reg, _ := DB.Register()
|
||||
reg, _ := DB.Register(cidrslice{})
|
||||
testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) {
|
||||
return testResult{1, 0}, errors.New("Prepared query error")
|
||||
})
|
||||
@ -129,7 +156,7 @@ func TestQueryExecErrors(t *testing.T) {
|
||||
t.Errorf("Expected error from exec in GetByDomain, but got none")
|
||||
}
|
||||
|
||||
_, err = DB.Register()
|
||||
_, err = DB.Register(cidrslice{})
|
||||
if err == nil {
|
||||
t.Errorf("Expected error from exec in Register, but got none")
|
||||
}
|
||||
@ -142,7 +169,7 @@ func TestQueryExecErrors(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestQueryScanErrors(t *testing.T) {
|
||||
reg, _ := DB.Register()
|
||||
reg, _ := DB.Register(cidrslice{})
|
||||
|
||||
testdb.SetExecWithArgsFunc(func(query string, args []driver.Value) (result driver.Result, err error) {
|
||||
return testResult{1, 0}, errors.New("Prepared query error")
|
||||
@ -176,7 +203,7 @@ func TestQueryScanErrors(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBadDBValues(t *testing.T) {
|
||||
reg, _ := DB.Register()
|
||||
reg, _ := DB.Register(cidrslice{})
|
||||
|
||||
testdb.SetQueryWithArgsFunc(func(query string, args []driver.Value) (result driver.Rows, err error) {
|
||||
columns := []string{"Username", "Password", "Subdomain", "Value", "LastActive"}
|
||||
@ -209,7 +236,7 @@ func TestGetByDomain(t *testing.T) {
|
||||
var regDomain = ACMETxt{}
|
||||
|
||||
// Create reg to refer to
|
||||
reg, err := DB.Register()
|
||||
reg, err := DB.Register(cidrslice{})
|
||||
if err != nil {
|
||||
t.Errorf("Registration failed, got error [%v]", err)
|
||||
}
|
||||
@ -246,7 +273,7 @@ func TestGetByDomain(t *testing.T) {
|
||||
|
||||
func TestUpdate(t *testing.T) {
|
||||
// Create reg to refer to
|
||||
reg, err := DB.Register()
|
||||
reg, err := DB.Register(cidrslice{})
|
||||
if err != nil {
|
||||
t.Errorf("Registration failed, got error [%v]", err)
|
||||
}
|
||||
|
@ -139,7 +139,7 @@ func TestResolveTXT(t *testing.T) {
|
||||
resolv := resolver{server: "0.0.0.0:15353"}
|
||||
validTXT := "______________valid_response_______________"
|
||||
|
||||
atxt, err := DB.Register()
|
||||
atxt, err := DB.Register(cidrslice{})
|
||||
if err != nil {
|
||||
t.Errorf("Could not initiate db record: [%v]", err)
|
||||
return
|
||||
|
17
types.go
17
types.go
@ -66,21 +66,6 @@ type logconfig struct {
|
||||
Format string `toml:"logformat"`
|
||||
}
|
||||
|
||||
// ACMETxt is the default structure for the user controlled record
|
||||
type ACMETxt struct {
|
||||
Username uuid.UUID
|
||||
Password string
|
||||
ACMETxtPost
|
||||
LastActive int64
|
||||
AllowFrom string
|
||||
}
|
||||
|
||||
// ACMETxtPost holds the DNS part of the ACMETxt struct
|
||||
type ACMETxtPost struct {
|
||||
Subdomain string `json:"subdomain"`
|
||||
Value string `json:"txt"`
|
||||
}
|
||||
|
||||
type acmedb struct {
|
||||
sync.Mutex
|
||||
DB *sql.DB
|
||||
@ -88,7 +73,7 @@ type acmedb struct {
|
||||
|
||||
type database interface {
|
||||
Init(string, string) error
|
||||
Register() (ACMETxt, error)
|
||||
Register(cidrslice) (ACMETxt, error)
|
||||
GetByUsername(uuid.UUID) (ACMETxt, error)
|
||||
GetByDomain(string) ([]ACMETxt, error)
|
||||
Update(ACMETxt) error
|
||||
|
10
util.go
10
util.go
@ -5,7 +5,6 @@ import (
|
||||
"github.com/BurntSushi/toml"
|
||||
log "github.com/Sirupsen/logrus"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/satori/go.uuid"
|
||||
"math/big"
|
||||
"regexp"
|
||||
"strings"
|
||||
@ -45,15 +44,6 @@ func sanitizeDomainQuestion(d string) string {
|
||||
return dom
|
||||
}
|
||||
|
||||
func newACMETxt() ACMETxt {
|
||||
var a = ACMETxt{}
|
||||
password := generatePassword(40)
|
||||
a.Username = uuid.NewV4()
|
||||
a.Password = password
|
||||
a.Subdomain = uuid.NewV4().String()
|
||||
return a
|
||||
}
|
||||
|
||||
func setupLogging(format string, level string) {
|
||||
if format == "json" {
|
||||
log.SetFormatter(&log.JSONFormatter{})
|
||||
|
@ -1,9 +1,10 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/satori/go.uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
func getValidUsername(u string) (uuid.UUID, error) {
|
||||
|
@ -106,3 +106,24 @@ func TestCorrectPassword(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetValidCIDRMasks(t *testing.T) {
|
||||
for i, test := range []struct {
|
||||
input cidrslice
|
||||
output cidrslice
|
||||
}{
|
||||
{cidrslice{"10.0.0.1/24"}, cidrslice{"10.0.0.1/24"}},
|
||||
{cidrslice{"invalid", "127.0.0.1/32"}, cidrslice{"127.0.0.1/32"}},
|
||||
} {
|
||||
ret := test.input.ValidEntries()
|
||||
if len(ret) == len(test.output) {
|
||||
for i, v := range ret {
|
||||
if v != test.output[i] {
|
||||
t.Errorf("Test %d: Expected %q but got %q", i, test.output, ret)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
t.Errorf("Test %d: Expected %q but got %q", i, test.output, ret)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user