Initial commit, PoC quality

This commit is contained in:
Joona Hoikkala 2016-11-11 16:48:00 +02:00
commit 5433444b2f
No known key found for this signature in database
GPG Key ID: C14AAE0F5ADCB854
9 changed files with 593 additions and 0 deletions

27
acmetxt.go Normal file
View File

@ -0,0 +1,27 @@
package main
import (
"github.com/satori/go.uuid"
"time"
)
// The default database object
type ACMETxt struct {
Username string `json:"username"`
Password string `json:"password"`
ACMETxtPost
LastActive time.Time
}
type ACMETxtPost struct {
Subdomain string `json:"subdomain"`
Value string `json:"txt"`
}
func NewACMETxt() ACMETxt {
var a ACMETxt = ACMETxt{}
a.Username = uuid.NewV4().String()
a.Password = uuid.NewV4().String()
a.Subdomain = uuid.NewV4().String()
return a
}

97
api.go Normal file
View File

@ -0,0 +1,97 @@
package main
import (
"errors"
"fmt"
"github.com/kataras/iris"
)
func GetHandlerMap() map[string]func(*iris.Context) {
return map[string]func(*iris.Context){
"/register": WebRegisterGet,
}
}
func PostHandlerMap() map[string]func(*iris.Context) {
return map[string]func(*iris.Context){
"/register": WebRegisterPost,
"/update": WebUpdatePost,
}
}
func WebRegisterPost(ctx *iris.Context) {
// Create new user
nu, err := DB.Register()
var reg_json iris.Map
var reg_status int
if err != nil {
errstr := fmt.Sprintf("%v", err)
reg_json = iris.Map{"username": "", "password": "", "domain": "", "error": errstr}
reg_status = iris.StatusInternalServerError
} else {
reg_json = iris.Map{"username": nu.Username, "password": nu.Password, "fulldomain": nu.Subdomain + "." + DnsConf.General.Domain, "subdomain": nu.Subdomain}
reg_status = iris.StatusCreated
}
ctx.JSON(reg_status, reg_json)
}
func WebRegisterGet(ctx *iris.Context) {
// This is placeholder for now
WebRegisterPost(ctx)
}
func WebUpdatePost(ctx *iris.Context) {
var username, password string
var a ACMETxtPost = ACMETxtPost{}
username = ctx.RequestHeader("X-API-User")
password = ctx.RequestHeader("X-API-Key")
if err := ctx.ReadJSON(&a); err != nil {
// Handle bad post data
WebUpdatePostError(ctx, err, iris.StatusBadRequest)
return
}
// Sanitized by db function
euser, err := DB.GetByUsername(username)
if err != nil {
// DB error
WebUpdatePostError(ctx, err, iris.StatusInternalServerError)
return
}
if len(euser) == 0 {
// User not found
// TODO: do bcrypt to avoid side channel
WebUpdatePostError(ctx, errors.New("invalid user or api key"), iris.StatusUnauthorized)
return
}
// Get first (and the only) user
upduser := euser[0]
// Validate password
if upduser.Password != password {
// Invalid password
WebUpdatePostError(ctx, errors.New("invalid user or api key"), iris.StatusUnauthorized)
return
} else {
// Do update
if len(a.Value) == 0 {
WebUpdatePostError(ctx, errors.New("missing txt value"), iris.StatusBadRequest)
return
} else {
upduser.Value = a.Value
err = DB.Update(upduser)
if err != nil {
WebUpdatePostError(ctx, err, iris.StatusInternalServerError)
return
}
// All ok
ctx.JSON(iris.StatusOK, iris.Map{"txt": upduser.Value})
}
}
}
func WebUpdatePostError(ctx *iris.Context, err error, status int) {
err_str := fmt.Sprintf("%v", err)
upd_json := iris.Map{"error": err_str}
ctx.JSON(status, upd_json)
}

27
config.cfg Normal file
View File

@ -0,0 +1,27 @@
[general]
# domain name to serve th requests off of
domain = "auth.example.org"
# zone name server
nsname = "ns1.auth.example.org"
# admin email address, with @ substituted with .
nsadmin = "admin.example.org"
# possible values: "letsencrypt", "cert", "false"
tls = "letsencrypt"
# only used if tls = "cert"
tls_cert_privkey = "/etc/tls/example.org/privkey.pem"
tls_cert_fullchain = "/etc/tls/example.org/fullchain.pem"
# predefined records that we're serving in addition to the TXT
records = [
# default A
"auth.example.org. A 192.168.1.100",
# A
"ns1.auth.example.org. A 192.168.1.100",
"ns2.auth.example.org. A 192.168.1.100",
# NS
"auth.example.org. NS ns1.auth.example.org.",
"auth.example.org. NS ns2.auth.example.org.",
]

14
config.go Normal file
View File

@ -0,0 +1,14 @@
package main
import (
"errors"
"github.com/BurntSushi/toml"
)
func ReadConfig(fname string) (DnsConfig, error) {
var conf DnsConfig
if _, err := toml.DecodeFile(fname, &conf); err != nil {
return DnsConfig{}, errors.New("Malformed configuration file")
}
return conf, nil
}

189
db.go Normal file
View File

@ -0,0 +1,189 @@
package main
import (
"database/sql"
//"encoding/json"
//"github.com/boltdb/bolt"
_ "github.com/mattn/go-sqlite3"
//"strings"
)
type Database struct {
DB *sql.DB
}
var records_table string = `
CREATE TABLE IF NOT EXISTS records(
Username TEXT UNIQUE NOT NULL PRIMARY KEY,
Password TEXT UNIQUE NOT NULL,
Subdomain TEXT UNIQUE NOT NULL,
Value TEXT,
LastActive DATETIME
);`
func (d *Database) Init(filename string) error {
db, err := sql.Open("sqlite3", filename)
if err != nil {
return err
}
d.DB = db
_, err = d.DB.Exec(records_table)
if err != nil {
return err
}
return nil
}
func (d *Database) Register() (ACMETxt, error) {
a := NewACMETxt()
reg_sql := `
INSERT INTO records(
Username,
Password,
Subdomain,
Value,
LastActive)
values(?, ?, ?, ?, CURRENT_TIMESTAMP)`
sm, err := d.DB.Prepare(reg_sql)
if err != nil {
return a, err
}
defer sm.Close()
_, err = sm.Exec(a.Username, a.Password, a.Subdomain, a.Value)
if err != nil {
return a, err
}
// Do an insert check
/*
id, err := status.LastInsertId()
if err != nil {
return a, err
}*/
return a, nil
}
func (d *Database) GetByUsername(u string) ([]ACMETxt, error) {
u = NormalizeString(u, 36)
log.Debugf("Trying to select by user [%s] from table", u)
var results []ACMETxt
get_sql := `
SELECT Username, Password, Subdomain, Value
FROM records
WHERE Username=? LIMIT 1
`
sm, err := d.DB.Prepare(get_sql)
if err != nil {
return nil, err
}
defer sm.Close()
rows, err := sm.Query(u)
if err != nil {
return nil, err
}
defer rows.Close()
// It will only be one row though
for rows.Next() {
var a ACMETxt = ACMETxt{}
err = rows.Scan(&a.Username, &a.Password, &a.Subdomain, &a.Value)
if err != nil {
return nil, err
}
results = append(results, a)
}
return results, nil
}
func (d *Database) GetByDomain(domain string) ([]ACMETxt, error) {
domain = NormalizeString(domain, 36)
log.Debugf("Trying to select domain [%s] from table", domain)
var a []ACMETxt
get_sql := `
SELECT Username, Password, Subdomain, Value
FROM records
WHERE Subdomain=? LIMIT 1
`
sm, err := d.DB.Prepare(get_sql)
if err != nil {
return a, err
}
defer sm.Close()
rows, err := sm.Query(domain)
if err != nil {
return a, err
}
defer rows.Close()
for rows.Next() {
txt := ACMETxt{}
err = rows.Scan(&txt.Username, &txt.Password, &txt.Subdomain, &txt.Value)
if err != nil {
return a, err
}
a = append(a, txt)
}
return a, nil
}
func (d *Database) Update(a ACMETxt) error {
// Data in a is already sanitized
log.Debugf("Trying to update domain [%s] with TXT data [%s]", a.Subdomain, a.Value)
upd_sql := `
UPDATE records SET Value=?
WHERE Username=? AND Subdomain=?
`
sm, err := d.DB.Prepare(upd_sql)
if err != nil {
return err
}
defer sm.Close()
_, err = sm.Exec(a.Value, a.Username, a.Subdomain)
if err != nil {
return err
}
return nil
}
/*
func addTXT(txt ACMETxt) error {
err := db.Update(func(tx *bolt.Tx) error {
bucket, err := tx.CreateBucketIfNotExists([]byte("domains"))
if err != nil {
return err
}
jtxt, err := json.Marshal(txt)
if err != nil {
return err
}
// put returns nil if successful, nil return commits db.Update
return bucket.Put([]byte(strings.ToLower(txt.Domain)), jtxt)
})
return err
}
func getTXT(domain string) (ACMETxt, error) {
var atxt ACMETxt
err := db.View(func(tx *bolt.Tx) error {
bucket := tx.Bucket([]byte("domains"))
value := bucket.Get([]byte(strings.ToLower(domain)))
if len(value) == 0 {
// Not found
log.Debugf("Record for [%s] not found", domain)
atxt = ACMETxt{}
} else {
if err := json.Unmarshal(value, &atxt); err != nil {
return err
}
}
return nil
})
if err != nil {
return ACMETxt{}, err
}
return atxt, err
}
*/

104
dns.go Normal file
View File

@ -0,0 +1,104 @@
package main
import (
"fmt"
"github.com/miekg/dns"
"time"
)
func readQuery(m *dns.Msg) {
for _, que := range m.Question {
if rr, rc, err := answer(que); err == nil {
m.MsgHdr.Rcode = rc
for _, r := range rr {
m.Answer = append(m.Answer, r)
}
}
}
}
func answerTXT(q dns.Question) ([]dns.RR, int, error) {
var ra []dns.RR
var rcode int = dns.RcodeNameError
var domain string = q.Name
atxt, err := DB.GetByDomain(domain)
if err != nil {
log.Errorf("Error while trying to get record [%v]", err)
return ra, dns.RcodeNameError, err
}
for _, v := range atxt {
if len(v.Value) > 0 {
r := new(dns.TXT)
r.Hdr = dns.RR_Header{Name: domain, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1}
r.Txt = append(r.Txt, v.Value)
ra = append(ra, r)
rcode = dns.RcodeSuccess
}
}
log.Debugf("Answering TXT question for domain [%s]", domain)
return ra, rcode, nil
}
func answer(q dns.Question) ([]dns.RR, int, error) {
if q.Qtype == dns.TypeTXT {
return answerTXT(q)
}
var r []dns.RR
var rcode int = dns.RcodeSuccess
var domain string = q.Name
var rtype uint16 = q.Qtype
r, ok := RR.Records[rtype][domain]
if !ok {
rcode = dns.RcodeNameError
}
log.Debugf("Answering [%s] question for domain [%s] with rcode [%s]", dns.TypeToString[rtype], domain, dns.RcodeToString[rcode])
return r, rcode, nil
}
func handleRequest(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
if r.Opcode == dns.OpcodeQuery {
readQuery(m)
}
w.WriteMsg(m)
}
// Parse config records
func (r *Records) Parse(recs []string) {
rrmap := make(map[uint16]map[string][]dns.RR)
for _, v := range recs {
rr, err := dns.NewRR(v)
if err != nil {
log.Errorf("Could not parse RR from config: [%v] for RR: [%s]", err, v)
continue
}
// Add parsed RR to the list
rrmap = AppendRR(rrmap, rr)
}
// Create serial
serial := time.Now().Format("2006010215")
// Add SOA
SOAstring := fmt.Sprintf("%s. SOA %s. %s. %s 28800 7200 604800 86400", DnsConf.General.Domain, DnsConf.General.Nsname, DnsConf.General.Nsadmin, serial)
soarr, err := dns.NewRR(SOAstring)
if err != nil {
log.Errorf("Error [%v] while trying to add SOA record: [%s]", err, SOAstring)
} else {
rrmap = AppendRR(rrmap, soarr)
}
r.Records = rrmap
}
func AppendRR(rrmap map[uint16]map[string][]dns.RR, rr dns.RR) map[uint16]map[string][]dns.RR {
_, ok := rrmap[rr.Header().Rrtype]
if !ok {
newrr := make(map[string][]dns.RR)
rrmap[rr.Header().Rrtype] = newrr
}
rrmap[rr.Header().Rrtype][rr.Header().Name] = append(rrmap[rr.Header().Rrtype][rr.Header().Name], rr)
log.Debugf("Adding new record of type [%s] for domain [%s]", dns.TypeToString[rr.Header().Rrtype], rr.Header().Name)
return rrmap
}

89
main.go Normal file
View File

@ -0,0 +1,89 @@
package main
import (
"fmt"
"github.com/kataras/iris"
"github.com/miekg/dns"
"github.com/op/go-logging"
"os"
)
// Logging config
var logfile_path = "acme-dns.log"
var log = logging.MustGetLogger("acme-dns")
// Global configuration struct
var DnsConf DnsConfig
var DB Database
// Static records
var RR Records
func main() {
// Setup logging
var stdout_format = logging.MustStringFormatter(
`%{color}%{time:15:04:05.000} %{shortfunc} ▶ %{level:.4s} %{id:03x}%{color:reset} %{message}`,
)
var file_format = logging.MustStringFormatter(
`%{time:15:04:05.000} %{shortfunc} - %{level:.4s} %{id:03x} %{message}`,
)
// Setup logging - stdout
logStdout := logging.NewLogBackend(os.Stdout, "", 0)
logStdoutFormatter := logging.NewBackendFormatter(logStdout, stdout_format)
// Setup logging - file
// Logging to file
logfh, err := os.OpenFile(logfile_path, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
if err != nil {
fmt.Printf("Could not open log file %s\n", logfile_path)
os.Exit(1)
}
defer logfh.Close()
logFile := logging.NewLogBackend(logfh, "", 0)
logFileFormatter := logging.NewBackendFormatter(logFile, file_format)
/* To limit logging to a level
logFileLeveled := logging.AddModuleLevel(logFile)
logFileLeveled.SetLevel(logging.ERROR, "")
*/
// Start logging
logging.SetBackend(logStdoutFormatter, logFileFormatter)
log.Debug("Starting up...")
// Read global config
if DnsConf, err = ReadConfig("config.cfg"); err != nil {
log.Errorf("Got error %v", err)
os.Exit(1)
}
RR.Parse(DnsConf.General.StaticRecords)
// Open database
err = DB.Init("acme-dns.db")
if err != nil {
log.Errorf("Could not open database [%v]", err)
os.Exit(1)
}
defer DB.DB.Close()
// DNS server part
dns.HandleFunc(".", handleRequest)
server := &dns.Server{Addr: ":53", Net: "udp"}
go func() {
err = server.ListenAndServe()
if err != nil {
log.Errorf("%v", err)
os.Exit(1)
}
}()
// API server
api := iris.New()
for path, handlerfunc := range GetHandlerMap() {
api.Get(path, handlerfunc)
}
for path, handlerfunc := range PostHandlerMap() {
api.Post(path, handlerfunc)
}
api.Listen(":8080")
log.Debugf("Shutting down...")
}

26
types.go Normal file
View File

@ -0,0 +1,26 @@
package main
import (
"github.com/miekg/dns"
)
// Static records
type Records struct {
Records map[uint16]map[string][]dns.RR
}
// Config file main struct
type DnsConfig struct {
General general
}
// Config file general section
type general struct {
Domain string
Nsname string
Nsadmin string
Tls string
Tls_cert_privkey string
Tls_cert_fullchain string
StaticRecords []string `toml:"records"`
}

20
util.go Normal file
View File

@ -0,0 +1,20 @@
package main
import (
"regexp"
"unicode/utf8"
)
func NormalizeString(s string, len int) string {
var ret string
re, err := regexp.Compile("[^A-Za-z\\-0-9]+")
if err != nil {
log.Errorf("%v", err)
return ""
}
ret = re.ReplaceAllString(s, "")
if utf8.RuneCountInString(ret) > len {
ret = ret[0:len]
}
return ret
}