mirror of
https://github.com/joohoi/acme-dns.git
synced 2025-03-09 20:29:14 +07:00
Initial commit, PoC quality
This commit is contained in:
commit
5433444b2f
27
acmetxt.go
Normal file
27
acmetxt.go
Normal file
@ -0,0 +1,27 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/satori/go.uuid"
|
||||
"time"
|
||||
)
|
||||
|
||||
// The default database object
|
||||
type ACMETxt struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
ACMETxtPost
|
||||
LastActive time.Time
|
||||
}
|
||||
|
||||
type ACMETxtPost struct {
|
||||
Subdomain string `json:"subdomain"`
|
||||
Value string `json:"txt"`
|
||||
}
|
||||
|
||||
func NewACMETxt() ACMETxt {
|
||||
var a ACMETxt = ACMETxt{}
|
||||
a.Username = uuid.NewV4().String()
|
||||
a.Password = uuid.NewV4().String()
|
||||
a.Subdomain = uuid.NewV4().String()
|
||||
return a
|
||||
}
|
97
api.go
Normal file
97
api.go
Normal file
@ -0,0 +1,97 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/kataras/iris"
|
||||
)
|
||||
|
||||
func GetHandlerMap() map[string]func(*iris.Context) {
|
||||
return map[string]func(*iris.Context){
|
||||
"/register": WebRegisterGet,
|
||||
}
|
||||
}
|
||||
|
||||
func PostHandlerMap() map[string]func(*iris.Context) {
|
||||
return map[string]func(*iris.Context){
|
||||
"/register": WebRegisterPost,
|
||||
"/update": WebUpdatePost,
|
||||
}
|
||||
}
|
||||
|
||||
func WebRegisterPost(ctx *iris.Context) {
|
||||
// Create new user
|
||||
nu, err := DB.Register()
|
||||
var reg_json iris.Map
|
||||
var reg_status int
|
||||
if err != nil {
|
||||
errstr := fmt.Sprintf("%v", err)
|
||||
|
||||
reg_json = iris.Map{"username": "", "password": "", "domain": "", "error": errstr}
|
||||
reg_status = iris.StatusInternalServerError
|
||||
} else {
|
||||
reg_json = iris.Map{"username": nu.Username, "password": nu.Password, "fulldomain": nu.Subdomain + "." + DnsConf.General.Domain, "subdomain": nu.Subdomain}
|
||||
reg_status = iris.StatusCreated
|
||||
}
|
||||
ctx.JSON(reg_status, reg_json)
|
||||
}
|
||||
|
||||
func WebRegisterGet(ctx *iris.Context) {
|
||||
// This is placeholder for now
|
||||
WebRegisterPost(ctx)
|
||||
}
|
||||
|
||||
func WebUpdatePost(ctx *iris.Context) {
|
||||
var username, password string
|
||||
var a ACMETxtPost = ACMETxtPost{}
|
||||
username = ctx.RequestHeader("X-API-User")
|
||||
password = ctx.RequestHeader("X-API-Key")
|
||||
if err := ctx.ReadJSON(&a); err != nil {
|
||||
// Handle bad post data
|
||||
WebUpdatePostError(ctx, err, iris.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
// Sanitized by db function
|
||||
euser, err := DB.GetByUsername(username)
|
||||
if err != nil {
|
||||
// DB error
|
||||
WebUpdatePostError(ctx, err, iris.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if len(euser) == 0 {
|
||||
// User not found
|
||||
// TODO: do bcrypt to avoid side channel
|
||||
WebUpdatePostError(ctx, errors.New("invalid user or api key"), iris.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
// Get first (and the only) user
|
||||
upduser := euser[0]
|
||||
// Validate password
|
||||
if upduser.Password != password {
|
||||
// Invalid password
|
||||
WebUpdatePostError(ctx, errors.New("invalid user or api key"), iris.StatusUnauthorized)
|
||||
return
|
||||
} else {
|
||||
// Do update
|
||||
if len(a.Value) == 0 {
|
||||
WebUpdatePostError(ctx, errors.New("missing txt value"), iris.StatusBadRequest)
|
||||
return
|
||||
} else {
|
||||
upduser.Value = a.Value
|
||||
err = DB.Update(upduser)
|
||||
if err != nil {
|
||||
WebUpdatePostError(ctx, err, iris.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
// All ok
|
||||
ctx.JSON(iris.StatusOK, iris.Map{"txt": upduser.Value})
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func WebUpdatePostError(ctx *iris.Context, err error, status int) {
|
||||
err_str := fmt.Sprintf("%v", err)
|
||||
upd_json := iris.Map{"error": err_str}
|
||||
ctx.JSON(status, upd_json)
|
||||
}
|
27
config.cfg
Normal file
27
config.cfg
Normal file
@ -0,0 +1,27 @@
|
||||
[general]
|
||||
# domain name to serve th requests off of
|
||||
domain = "auth.example.org"
|
||||
# zone name server
|
||||
nsname = "ns1.auth.example.org"
|
||||
# admin email address, with @ substituted with .
|
||||
nsadmin = "admin.example.org"
|
||||
|
||||
# possible values: "letsencrypt", "cert", "false"
|
||||
tls = "letsencrypt"
|
||||
|
||||
# only used if tls = "cert"
|
||||
tls_cert_privkey = "/etc/tls/example.org/privkey.pem"
|
||||
tls_cert_fullchain = "/etc/tls/example.org/fullchain.pem"
|
||||
|
||||
# predefined records that we're serving in addition to the TXT
|
||||
|
||||
records = [
|
||||
# default A
|
||||
"auth.example.org. A 192.168.1.100",
|
||||
# A
|
||||
"ns1.auth.example.org. A 192.168.1.100",
|
||||
"ns2.auth.example.org. A 192.168.1.100",
|
||||
# NS
|
||||
"auth.example.org. NS ns1.auth.example.org.",
|
||||
"auth.example.org. NS ns2.auth.example.org.",
|
||||
]
|
14
config.go
Normal file
14
config.go
Normal file
@ -0,0 +1,14 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/BurntSushi/toml"
|
||||
)
|
||||
|
||||
func ReadConfig(fname string) (DnsConfig, error) {
|
||||
var conf DnsConfig
|
||||
if _, err := toml.DecodeFile(fname, &conf); err != nil {
|
||||
return DnsConfig{}, errors.New("Malformed configuration file")
|
||||
}
|
||||
return conf, nil
|
||||
}
|
189
db.go
Normal file
189
db.go
Normal file
@ -0,0 +1,189 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
//"encoding/json"
|
||||
//"github.com/boltdb/bolt"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
//"strings"
|
||||
)
|
||||
|
||||
type Database struct {
|
||||
DB *sql.DB
|
||||
}
|
||||
|
||||
var records_table string = `
|
||||
CREATE TABLE IF NOT EXISTS records(
|
||||
Username TEXT UNIQUE NOT NULL PRIMARY KEY,
|
||||
Password TEXT UNIQUE NOT NULL,
|
||||
Subdomain TEXT UNIQUE NOT NULL,
|
||||
Value TEXT,
|
||||
LastActive DATETIME
|
||||
);`
|
||||
|
||||
func (d *Database) Init(filename string) error {
|
||||
db, err := sql.Open("sqlite3", filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.DB = db
|
||||
_, err = d.DB.Exec(records_table)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) Register() (ACMETxt, error) {
|
||||
a := NewACMETxt()
|
||||
reg_sql := `
|
||||
INSERT INTO records(
|
||||
Username,
|
||||
Password,
|
||||
Subdomain,
|
||||
Value,
|
||||
LastActive)
|
||||
values(?, ?, ?, ?, CURRENT_TIMESTAMP)`
|
||||
sm, err := d.DB.Prepare(reg_sql)
|
||||
if err != nil {
|
||||
return a, err
|
||||
}
|
||||
defer sm.Close()
|
||||
_, err = sm.Exec(a.Username, a.Password, a.Subdomain, a.Value)
|
||||
if err != nil {
|
||||
return a, err
|
||||
}
|
||||
// Do an insert check
|
||||
/*
|
||||
id, err := status.LastInsertId()
|
||||
if err != nil {
|
||||
return a, err
|
||||
}*/
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (d *Database) GetByUsername(u string) ([]ACMETxt, error) {
|
||||
u = NormalizeString(u, 36)
|
||||
log.Debugf("Trying to select by user [%s] from table", u)
|
||||
var results []ACMETxt
|
||||
get_sql := `
|
||||
SELECT Username, Password, Subdomain, Value
|
||||
FROM records
|
||||
WHERE Username=? LIMIT 1
|
||||
`
|
||||
sm, err := d.DB.Prepare(get_sql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer sm.Close()
|
||||
rows, err := sm.Query(u)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// It will only be one row though
|
||||
for rows.Next() {
|
||||
var a ACMETxt = ACMETxt{}
|
||||
err = rows.Scan(&a.Username, &a.Password, &a.Subdomain, &a.Value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, a)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (d *Database) GetByDomain(domain string) ([]ACMETxt, error) {
|
||||
domain = NormalizeString(domain, 36)
|
||||
log.Debugf("Trying to select domain [%s] from table", domain)
|
||||
var a []ACMETxt
|
||||
get_sql := `
|
||||
SELECT Username, Password, Subdomain, Value
|
||||
FROM records
|
||||
WHERE Subdomain=? LIMIT 1
|
||||
`
|
||||
sm, err := d.DB.Prepare(get_sql)
|
||||
if err != nil {
|
||||
return a, err
|
||||
}
|
||||
defer sm.Close()
|
||||
rows, err := sm.Query(domain)
|
||||
if err != nil {
|
||||
return a, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
txt := ACMETxt{}
|
||||
err = rows.Scan(&txt.Username, &txt.Password, &txt.Subdomain, &txt.Value)
|
||||
if err != nil {
|
||||
return a, err
|
||||
}
|
||||
a = append(a, txt)
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (d *Database) Update(a ACMETxt) error {
|
||||
// Data in a is already sanitized
|
||||
log.Debugf("Trying to update domain [%s] with TXT data [%s]", a.Subdomain, a.Value)
|
||||
upd_sql := `
|
||||
UPDATE records SET Value=?
|
||||
WHERE Username=? AND Subdomain=?
|
||||
`
|
||||
sm, err := d.DB.Prepare(upd_sql)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sm.Close()
|
||||
_, err = sm.Exec(a.Value, a.Username, a.Subdomain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
/*
|
||||
func addTXT(txt ACMETxt) error {
|
||||
|
||||
err := db.Update(func(tx *bolt.Tx) error {
|
||||
bucket, err := tx.CreateBucketIfNotExists([]byte("domains"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
jtxt, err := json.Marshal(txt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// put returns nil if successful, nil return commits db.Update
|
||||
return bucket.Put([]byte(strings.ToLower(txt.Domain)), jtxt)
|
||||
})
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
func getTXT(domain string) (ACMETxt, error) {
|
||||
var atxt ACMETxt
|
||||
err := db.View(func(tx *bolt.Tx) error {
|
||||
bucket := tx.Bucket([]byte("domains"))
|
||||
value := bucket.Get([]byte(strings.ToLower(domain)))
|
||||
if len(value) == 0 {
|
||||
// Not found
|
||||
log.Debugf("Record for [%s] not found", domain)
|
||||
atxt = ACMETxt{}
|
||||
} else {
|
||||
if err := json.Unmarshal(value, &atxt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return ACMETxt{}, err
|
||||
}
|
||||
return atxt, err
|
||||
}
|
||||
*/
|
104
dns.go
Normal file
104
dns.go
Normal file
@ -0,0 +1,104 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/miekg/dns"
|
||||
"time"
|
||||
)
|
||||
|
||||
func readQuery(m *dns.Msg) {
|
||||
for _, que := range m.Question {
|
||||
if rr, rc, err := answer(que); err == nil {
|
||||
m.MsgHdr.Rcode = rc
|
||||
for _, r := range rr {
|
||||
m.Answer = append(m.Answer, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func answerTXT(q dns.Question) ([]dns.RR, int, error) {
|
||||
var ra []dns.RR
|
||||
var rcode int = dns.RcodeNameError
|
||||
var domain string = q.Name
|
||||
|
||||
atxt, err := DB.GetByDomain(domain)
|
||||
if err != nil {
|
||||
log.Errorf("Error while trying to get record [%v]", err)
|
||||
return ra, dns.RcodeNameError, err
|
||||
}
|
||||
for _, v := range atxt {
|
||||
if len(v.Value) > 0 {
|
||||
r := new(dns.TXT)
|
||||
r.Hdr = dns.RR_Header{Name: domain, Rrtype: dns.TypeTXT, Class: dns.ClassINET, Ttl: 1}
|
||||
r.Txt = append(r.Txt, v.Value)
|
||||
ra = append(ra, r)
|
||||
rcode = dns.RcodeSuccess
|
||||
}
|
||||
}
|
||||
log.Debugf("Answering TXT question for domain [%s]", domain)
|
||||
return ra, rcode, nil
|
||||
}
|
||||
|
||||
func answer(q dns.Question) ([]dns.RR, int, error) {
|
||||
if q.Qtype == dns.TypeTXT {
|
||||
return answerTXT(q)
|
||||
}
|
||||
var r []dns.RR
|
||||
var rcode int = dns.RcodeSuccess
|
||||
var domain string = q.Name
|
||||
var rtype uint16 = q.Qtype
|
||||
r, ok := RR.Records[rtype][domain]
|
||||
if !ok {
|
||||
rcode = dns.RcodeNameError
|
||||
}
|
||||
log.Debugf("Answering [%s] question for domain [%s] with rcode [%s]", dns.TypeToString[rtype], domain, dns.RcodeToString[rcode])
|
||||
return r, rcode, nil
|
||||
}
|
||||
|
||||
func handleRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
|
||||
if r.Opcode == dns.OpcodeQuery {
|
||||
readQuery(m)
|
||||
}
|
||||
|
||||
w.WriteMsg(m)
|
||||
}
|
||||
|
||||
// Parse config records
|
||||
func (r *Records) Parse(recs []string) {
|
||||
rrmap := make(map[uint16]map[string][]dns.RR)
|
||||
for _, v := range recs {
|
||||
rr, err := dns.NewRR(v)
|
||||
if err != nil {
|
||||
log.Errorf("Could not parse RR from config: [%v] for RR: [%s]", err, v)
|
||||
continue
|
||||
}
|
||||
// Add parsed RR to the list
|
||||
rrmap = AppendRR(rrmap, rr)
|
||||
}
|
||||
// Create serial
|
||||
serial := time.Now().Format("2006010215")
|
||||
// Add SOA
|
||||
SOAstring := fmt.Sprintf("%s. SOA %s. %s. %s 28800 7200 604800 86400", DnsConf.General.Domain, DnsConf.General.Nsname, DnsConf.General.Nsadmin, serial)
|
||||
soarr, err := dns.NewRR(SOAstring)
|
||||
if err != nil {
|
||||
log.Errorf("Error [%v] while trying to add SOA record: [%s]", err, SOAstring)
|
||||
} else {
|
||||
rrmap = AppendRR(rrmap, soarr)
|
||||
}
|
||||
r.Records = rrmap
|
||||
}
|
||||
|
||||
func AppendRR(rrmap map[uint16]map[string][]dns.RR, rr dns.RR) map[uint16]map[string][]dns.RR {
|
||||
_, ok := rrmap[rr.Header().Rrtype]
|
||||
if !ok {
|
||||
newrr := make(map[string][]dns.RR)
|
||||
rrmap[rr.Header().Rrtype] = newrr
|
||||
}
|
||||
rrmap[rr.Header().Rrtype][rr.Header().Name] = append(rrmap[rr.Header().Rrtype][rr.Header().Name], rr)
|
||||
log.Debugf("Adding new record of type [%s] for domain [%s]", dns.TypeToString[rr.Header().Rrtype], rr.Header().Name)
|
||||
return rrmap
|
||||
}
|
89
main.go
Normal file
89
main.go
Normal file
@ -0,0 +1,89 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/kataras/iris"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/op/go-logging"
|
||||
"os"
|
||||
)
|
||||
|
||||
// Logging config
|
||||
var logfile_path = "acme-dns.log"
|
||||
var log = logging.MustGetLogger("acme-dns")
|
||||
|
||||
// Global configuration struct
|
||||
var DnsConf DnsConfig
|
||||
|
||||
var DB Database
|
||||
|
||||
// Static records
|
||||
var RR Records
|
||||
|
||||
func main() {
|
||||
// Setup logging
|
||||
var stdout_format = logging.MustStringFormatter(
|
||||
`%{color}%{time:15:04:05.000} %{shortfunc} ▶ %{level:.4s} %{id:03x}%{color:reset} %{message}`,
|
||||
)
|
||||
var file_format = logging.MustStringFormatter(
|
||||
`%{time:15:04:05.000} %{shortfunc} - %{level:.4s} %{id:03x} %{message}`,
|
||||
)
|
||||
// Setup logging - stdout
|
||||
logStdout := logging.NewLogBackend(os.Stdout, "", 0)
|
||||
logStdoutFormatter := logging.NewBackendFormatter(logStdout, stdout_format)
|
||||
// Setup logging - file
|
||||
// Logging to file
|
||||
logfh, err := os.OpenFile(logfile_path, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
fmt.Printf("Could not open log file %s\n", logfile_path)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer logfh.Close()
|
||||
logFile := logging.NewLogBackend(logfh, "", 0)
|
||||
logFileFormatter := logging.NewBackendFormatter(logFile, file_format)
|
||||
/* To limit logging to a level
|
||||
logFileLeveled := logging.AddModuleLevel(logFile)
|
||||
logFileLeveled.SetLevel(logging.ERROR, "")
|
||||
*/
|
||||
|
||||
// Start logging
|
||||
logging.SetBackend(logStdoutFormatter, logFileFormatter)
|
||||
log.Debug("Starting up...")
|
||||
|
||||
// Read global config
|
||||
if DnsConf, err = ReadConfig("config.cfg"); err != nil {
|
||||
log.Errorf("Got error %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
RR.Parse(DnsConf.General.StaticRecords)
|
||||
|
||||
// Open database
|
||||
err = DB.Init("acme-dns.db")
|
||||
if err != nil {
|
||||
log.Errorf("Could not open database [%v]", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer DB.DB.Close()
|
||||
|
||||
// DNS server part
|
||||
dns.HandleFunc(".", handleRequest)
|
||||
server := &dns.Server{Addr: ":53", Net: "udp"}
|
||||
go func() {
|
||||
err = server.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Errorf("%v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}()
|
||||
|
||||
// API server
|
||||
api := iris.New()
|
||||
for path, handlerfunc := range GetHandlerMap() {
|
||||
api.Get(path, handlerfunc)
|
||||
}
|
||||
for path, handlerfunc := range PostHandlerMap() {
|
||||
api.Post(path, handlerfunc)
|
||||
}
|
||||
api.Listen(":8080")
|
||||
log.Debugf("Shutting down...")
|
||||
}
|
26
types.go
Normal file
26
types.go
Normal file
@ -0,0 +1,26 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Static records
|
||||
type Records struct {
|
||||
Records map[uint16]map[string][]dns.RR
|
||||
}
|
||||
|
||||
// Config file main struct
|
||||
type DnsConfig struct {
|
||||
General general
|
||||
}
|
||||
|
||||
// Config file general section
|
||||
type general struct {
|
||||
Domain string
|
||||
Nsname string
|
||||
Nsadmin string
|
||||
Tls string
|
||||
Tls_cert_privkey string
|
||||
Tls_cert_fullchain string
|
||||
StaticRecords []string `toml:"records"`
|
||||
}
|
20
util.go
Normal file
20
util.go
Normal file
@ -0,0 +1,20 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
func NormalizeString(s string, len int) string {
|
||||
var ret string
|
||||
re, err := regexp.Compile("[^A-Za-z\\-0-9]+")
|
||||
if err != nil {
|
||||
log.Errorf("%v", err)
|
||||
return ""
|
||||
}
|
||||
ret = re.ReplaceAllString(s, "")
|
||||
if utf8.RuneCountInString(ret) > len {
|
||||
ret = ret[0:len]
|
||||
}
|
||||
return ret
|
||||
}
|
Loading…
Reference in New Issue
Block a user