diff --git a/acmetxt.go b/acmetxt.go deleted file mode 100644 index 7bb812c..0000000 --- a/acmetxt.go +++ /dev/null @@ -1,27 +0,0 @@ -package main - -import ( - "github.com/satori/go.uuid" - "time" -) - -// The default database object -type ACMETxt struct { - Username string `json:"username"` - Password string `json:"password"` - ACMETxtPost - LastActive time.Time -} - -type ACMETxtPost struct { - Subdomain string `json:"subdomain"` - Value string `json:"txt"` -} - -func NewACMETxt() ACMETxt { - var a ACMETxt = ACMETxt{} - a.Username = uuid.NewV4().String() - a.Password = uuid.NewV4().String() - a.Subdomain = uuid.NewV4().String() - return a -} diff --git a/api.go b/api.go index fa1df85..618e939 100644 --- a/api.go +++ b/api.go @@ -19,6 +19,24 @@ func PostHandlerMap() map[string]func(*iris.Context) { } } +func (a AuthMiddleware) Serve(ctx *iris.Context) { + username_str := ctx.RequestHeader("X-Api-User") + password := ctx.RequestHeader("X-Api-Key") + + username, err := GetValidUsername(username_str) + if err == nil && ValidKey(password) { + au, err := DB.GetByUsername(username) + if err == nil && CorrectPassword(password, au.Password) { + log.Debugf("Accepted authentication from [%s]", username_str) + ctx.Next() + return + } + // To protect against timed side channel (never gonna give you up) + CorrectPassword(password, "$2a$10$8JEFVNYYhLoBysjAxe2yBuXrkDojBQBkVpXEQgyQyjn43SvJ4vL36") + } + ctx.JSON(iris.StatusUnauthorized, iris.Map{"error": "unauthorized"}) +} + func WebRegisterPost(ctx *iris.Context) { // Create new user nu, err := DB.Register() @@ -33,6 +51,7 @@ func WebRegisterPost(ctx *iris.Context) { reg_json = iris.Map{"username": nu.Username, "password": nu.Password, "fulldomain": nu.Subdomain + "." + DnsConf.General.Domain, "subdomain": nu.Subdomain} reg_status = iris.StatusCreated } + log.Debugf("Successful registration, created user [%s]", nu.Username) ctx.JSON(reg_status, reg_json) } @@ -42,52 +61,36 @@ func WebRegisterGet(ctx *iris.Context) { } func WebUpdatePost(ctx *iris.Context) { - var username, password string - var a ACMETxtPost = ACMETxtPost{} - username = ctx.RequestHeader("X-API-User") - password = ctx.RequestHeader("X-API-Key") + // User auth done in middleware + var a ACMETxt = ACMETxt{} + user_string := ctx.RequestHeader("X-API-User") + username, err := GetValidUsername(user_string) + if err != nil { + log.Warningf("Error while getting username [%s]. This should never happen because of auth middlware.", user_string) + WebUpdatePostError(ctx, err, iris.StatusUnauthorized) + return + } if err := ctx.ReadJSON(&a); err != nil { // Handle bad post data + log.Warningf("Could not unmarshal: [%v]", err) WebUpdatePostError(ctx, err, iris.StatusBadRequest) return } - // Sanitized by db function - euser, err := DB.GetByUsername(username) - if err != nil { - // DB error - WebUpdatePostError(ctx, err, iris.StatusInternalServerError) - return - } - if len(euser) == 0 { - // User not found - // TODO: do bcrypt to avoid side channel - WebUpdatePostError(ctx, errors.New("invalid user or api key"), iris.StatusUnauthorized) - return - } - // Get first (and the only) user - upduser := euser[0] - // Validate password - if upduser.Password != password { - // Invalid password - WebUpdatePostError(ctx, errors.New("invalid user or api key"), iris.StatusUnauthorized) - return - } else { - // Do update - if len(a.Value) == 0 { - WebUpdatePostError(ctx, errors.New("missing txt value"), iris.StatusBadRequest) + a.Username = username + // Do update + if ValidSubdomain(a.Subdomain) && ValidTXT(a.Value) { + err := DB.Update(a) + if err != nil { + log.Warningf("Error trying to update [%v]", err) + WebUpdatePostError(ctx, errors.New("internal error"), iris.StatusInternalServerError) return - } else { - upduser.Value = a.Value - err = DB.Update(upduser) - if err != nil { - WebUpdatePostError(ctx, err, iris.StatusInternalServerError) - return - } - // All ok - ctx.JSON(iris.StatusOK, iris.Map{"txt": upduser.Value}) } + ctx.JSON(iris.StatusOK, iris.Map{"txt": a.Value}) + } else { + log.Warningf("Bad data, subdomain: [%s], txt: [%s]", a.Subdomain, a.Value) + WebUpdatePostError(ctx, errors.New("bad data"), iris.StatusBadRequest) + return } - } func WebUpdatePostError(ctx *iris.Context, err error, status int) { diff --git a/config.cfg b/config.cfg index f7ca620..e8c787a 100644 --- a/config.cfg +++ b/config.cfg @@ -6,15 +6,18 @@ nsname = "ns1.auth.example.org" # admin email address, with @ substituted with . nsadmin = "admin.example.org" +[api] +# domain name to listen requests for, mandatory if using tls = "letsencrypt" +# use "" (empty string) to bind to all interfaces +api_domain = "" +# listen port, eg. 443 for default HTTPS +port = "8080" # possible values: "letsencrypt", "cert", "false" tls = "letsencrypt" - # only used if tls = "cert" tls_cert_privkey = "/etc/tls/example.org/privkey.pem" tls_cert_fullchain = "/etc/tls/example.org/fullchain.pem" - # predefined records that we're serving in addition to the TXT - records = [ # default A "auth.example.org. A 192.168.1.100", @@ -25,3 +28,13 @@ records = [ "auth.example.org. NS ns1.auth.example.org.", "auth.example.org. NS ns2.auth.example.org.", ] + +[logconfig] +# logging level +loglevel = "debug" +# possible values: stdout, file +logtype = "stdout" +# file path for logfile +logfile = "./acme-dns.log" +# format +logformat = "%{time:15:04:05.000} %{shortfunc} - %{level:.4s} %{id:03x} %{message}" diff --git a/db.go b/db.go index 72a65fb..945b7ee 100644 --- a/db.go +++ b/db.go @@ -2,10 +2,10 @@ package main import ( "database/sql" - //"encoding/json" - //"github.com/boltdb/bolt" + "errors" _ "github.com/mattn/go-sqlite3" - //"strings" + "github.com/satori/go.uuid" + "golang.org/x/crypto/bcrypt" ) type Database struct { @@ -35,7 +35,11 @@ func (d *Database) Init(filename string) error { } func (d *Database) Register() (ACMETxt, error) { - a := NewACMETxt() + a, err := NewACMETxt() + if err != nil { + return ACMETxt{}, err + } + password_hash, err := bcrypt.GenerateFromPassword([]byte(a.Password), 10) reg_sql := ` INSERT INTO records( Username, @@ -49,54 +53,54 @@ func (d *Database) Register() (ACMETxt, error) { return a, err } defer sm.Close() - _, err = sm.Exec(a.Username, a.Password, a.Subdomain, a.Value) + _, err = sm.Exec(a.Username, password_hash, a.Subdomain, a.Value) if err != nil { return a, err } - // Do an insert check - /* - id, err := status.LastInsertId() - if err != nil { - return a, err - }*/ - return a, nil } -func (d *Database) GetByUsername(u string) ([]ACMETxt, error) { - u = NormalizeString(u, 36) - log.Debugf("Trying to select by user [%s] from table", u) +func (d *Database) GetByUsername(u uuid.UUID) (ACMETxt, error) { var results []ACMETxt get_sql := ` - SELECT Username, Password, Subdomain, Value + SELECT Username, Password, Subdomain, Value, LastActive FROM records WHERE Username=? LIMIT 1 ` sm, err := d.DB.Prepare(get_sql) if err != nil { - return nil, err + return ACMETxt{}, err } defer sm.Close() - rows, err := sm.Query(u) + rows, err := sm.Query(u.String()) if err != nil { - return nil, err + return ACMETxt{}, err } defer rows.Close() // It will only be one row though for rows.Next() { var a ACMETxt = ACMETxt{} - err = rows.Scan(&a.Username, &a.Password, &a.Subdomain, &a.Value) + var uname string + err = rows.Scan(&uname, &a.Password, &a.Subdomain, &a.Value, &a.LastActive) if err != nil { - return nil, err + return ACMETxt{}, err + } + a.Username, err = uuid.FromString(uname) + if err != nil { + return ACMETxt{}, err } results = append(results, a) } - return results, nil + if len(results) > 0 { + return results[0], nil + } else { + return ACMETxt{}, errors.New("no user") + } } func (d *Database) GetByDomain(domain string) ([]ACMETxt, error) { - domain = NormalizeString(domain, 36) + domain = SanitizeString(domain) log.Debugf("Trying to select domain [%s] from table", domain) var a []ACMETxt get_sql := ` @@ -144,46 +148,3 @@ func (d *Database) Update(a ACMETxt) error { } return nil } - -/* -func addTXT(txt ACMETxt) error { - - err := db.Update(func(tx *bolt.Tx) error { - bucket, err := tx.CreateBucketIfNotExists([]byte("domains")) - if err != nil { - return err - } - jtxt, err := json.Marshal(txt) - if err != nil { - return err - } - - // put returns nil if successful, nil return commits db.Update - return bucket.Put([]byte(strings.ToLower(txt.Domain)), jtxt) - }) - return err - -} - -func getTXT(domain string) (ACMETxt, error) { - var atxt ACMETxt - err := db.View(func(tx *bolt.Tx) error { - bucket := tx.Bucket([]byte("domains")) - value := bucket.Get([]byte(strings.ToLower(domain))) - if len(value) == 0 { - // Not found - log.Debugf("Record for [%s] not found", domain) - atxt = ACMETxt{} - } else { - if err := json.Unmarshal(value, &atxt); err != nil { - return err - } - } - return nil - }) - if err != nil { - return ACMETxt{}, err - } - return atxt, err -} -*/ diff --git a/dns.go b/dns.go index e5be504..dddf969 100644 --- a/dns.go +++ b/dns.go @@ -22,7 +22,7 @@ func answerTXT(q dns.Question) ([]dns.RR, int, error) { var rcode int = dns.RcodeNameError var domain string = q.Name - atxt, err := DB.GetByDomain(domain) + atxt, err := DB.GetByDomain(SanitizeDomainQuestion(domain)) if err != nil { log.Errorf("Error while trying to get record [%v]", err) return ra, dns.RcodeNameError, err diff --git a/main.go b/main.go index 02e7de9..c6839c4 100644 --- a/main.go +++ b/main.go @@ -9,7 +9,6 @@ import ( ) // Logging config -var logfile_path = "acme-dns.log" var log = logging.MustGetLogger("acme-dns") // Global configuration struct @@ -21,41 +20,45 @@ var DB Database var RR Records func main() { - // Setup logging - var stdout_format = logging.MustStringFormatter( - `%{color}%{time:15:04:05.000} %{shortfunc} ▶ %{level:.4s} %{id:03x}%{color:reset} %{message}`, - ) - var file_format = logging.MustStringFormatter( - `%{time:15:04:05.000} %{shortfunc} - %{level:.4s} %{id:03x} %{message}`, - ) - // Setup logging - stdout - logStdout := logging.NewLogBackend(os.Stdout, "", 0) - logStdoutFormatter := logging.NewBackendFormatter(logStdout, stdout_format) - // Setup logging - file - // Logging to file - logfh, err := os.OpenFile(logfile_path, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) - if err != nil { - fmt.Printf("Could not open log file %s\n", logfile_path) - os.Exit(1) - } - defer logfh.Close() - logFile := logging.NewLogBackend(logfh, "", 0) - logFileFormatter := logging.NewBackendFormatter(logFile, file_format) - /* To limit logging to a level - logFileLeveled := logging.AddModuleLevel(logFile) - logFileLeveled.SetLevel(logging.ERROR, "") - */ - - // Start logging - logging.SetBackend(logStdoutFormatter, logFileFormatter) - log.Debug("Starting up...") - // Read global config - if DnsConf, err = ReadConfig("config.cfg"); err != nil { - log.Errorf("Got error %v", err) + config_tmp, err := ReadConfig("config.cfg") + if err != nil { + fmt.Printf("Got error %v\n", DnsConf.Logconfig.File) os.Exit(1) } - RR.Parse(DnsConf.General.StaticRecords) + DnsConf = config_tmp + // Setup logging + var logformat = logging.MustStringFormatter(DnsConf.Logconfig.Format) + var logBackend *logging.LogBackend + switch DnsConf.Logconfig.Logtype { + default: + // Setup logging - stdout + logBackend = logging.NewLogBackend(os.Stdout, "", 0) + case "file": + // Logging to file + logfh, err := os.OpenFile(DnsConf.Logconfig.File, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) + if err != nil { + fmt.Printf("Could not open log file %s\n", DnsConf.Logconfig.File) + os.Exit(1) + } + defer logfh.Close() + logBackend = logging.NewLogBackend(logfh, "", 0) + } + + logLevel := logging.AddModuleLevel(logBackend) + switch DnsConf.Logconfig.Level { + case "warning": + logLevel.SetLevel(logging.WARNING, "") + case "error": + logLevel.SetLevel(logging.ERROR, "") + case "info": + logLevel.SetLevel(logging.INFO, "") + } + logFormatter := logging.NewBackendFormatter(logLevel, logformat) + logging.SetBackend(logFormatter) + + // Read the default records in + RR.Parse(DnsConf.Api.StaticRecords) // Open database err = DB.Init("acme-dns.db") @@ -76,14 +79,27 @@ func main() { } }() - // API server + // API server and endpoints api := iris.New() - for path, handlerfunc := range GetHandlerMap() { - api.Get(path, handlerfunc) + var ForceAuth AuthMiddleware = AuthMiddleware{} + api.Get("/register", WebRegisterGet) + api.Post("/register", WebRegisterPost) + api.Post("/update", ForceAuth.Serve, WebUpdatePost) + // TODO: migrate to api.Serve(iris.LETSENCRYPTPROD("mydomain.com")) + switch DnsConf.Api.Tls { + case "letsencrypt": + host := DnsConf.Api.Domain + ":" + DnsConf.Api.Port + api.Listen(host) + case "cert": + host := DnsConf.Api.Domain + ":" + DnsConf.Api.Port + api.ListenTLS(host, DnsConf.Api.Tls_cert_fullchain, DnsConf.Api.Tls_cert_privkey) + + default: + host := DnsConf.Api.Domain + ":" + DnsConf.Api.Port + api.Listen(host) } - for path, handlerfunc := range PostHandlerMap() { - api.Post(path, handlerfunc) + if err != nil { + log.Errorf("Error in HTTP server [%v]", err) } - api.Listen(":8080") log.Debugf("Shutting down...") } diff --git a/types.go b/types.go index 3fd2d35..a014834 100644 --- a/types.go +++ b/types.go @@ -2,6 +2,8 @@ package main import ( "github.com/miekg/dns" + "github.com/satori/go.uuid" + "time" ) // Static records @@ -11,16 +13,48 @@ type Records struct { // Config file main struct type DnsConfig struct { - General general + General general + Api httpapi + Logconfig logconfig } +// Auth middleware +type AuthMiddleware struct{} + // Config file general section type general struct { + Domain string + Nsname string + Nsadmin string +} + +// API config +type httpapi struct { Domain string - Nsname string - Nsadmin string + Port string Tls string Tls_cert_privkey string Tls_cert_fullchain string StaticRecords []string `toml:"records"` } + +// Logging config +type logconfig struct { + Level string `toml:"loglevel"` + Logtype string `toml:"logtype"` + File string `toml:"logfile"` + Format string `toml:"logformat"` +} + +// The default object +type ACMETxt struct { + Username uuid.UUID + Password string + ACMETxtPost + LastActive time.Time +} + +type ACMETxtPost struct { + Subdomain string `json:"subdomain"` + Value string `json:"txt"` +} diff --git a/util.go b/util.go index 234cf56..8fcc4b9 100644 --- a/util.go +++ b/util.go @@ -1,20 +1,57 @@ package main import ( + "crypto/rand" + "github.com/satori/go.uuid" + "math/big" "regexp" - "unicode/utf8" + "strings" ) -func NormalizeString(s string, length int) string { - var ret string - re, err := regexp.Compile("[^A-Za-z\\-0-9]+") +func SanitizeString(s string) string { + // URL safe base64 alphabet without padding as defined in ACME + re, err := regexp.Compile("[^A-Za-z\\-\\_0-9]+") if err != nil { log.Errorf("%v", err) return "" } - ret = re.ReplaceAllString(s, "") - if utf8.RuneCountInString(ret) > length { - ret = ret[0:length] - } - return ret + return re.ReplaceAllString(s, "") +} + +func GeneratePassword(length int) (string, error) { + ret := make([]byte, length) + const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz1234567890-_" + alphalen := big.NewInt(int64(len(alphabet))) + for i := 0; i < length; i++ { + c, err := rand.Int(rand.Reader, alphalen) + if err != nil { + return "", err + } + r := int(c.Int64()) + ret[i] = alphabet[r] + } + return string(ret), nil +} + +func SanitizeDomainQuestion(d string) string { + var dom string + dns_suff := DnsConf.General.Domain + "." + if strings.HasSuffix(d, dns_suff) { + dom = d[0 : len(d)-len(dns_suff)] + } else { + dom = d + } + return dom +} + +func NewACMETxt() (ACMETxt, error) { + var a ACMETxt = ACMETxt{} + password, err := GeneratePassword(40) + if err != nil { + return a, err + } + a.Username = uuid.NewV4() + a.Password = password + a.Subdomain = uuid.NewV4().String() + return a, nil } diff --git a/validation.go b/validation.go new file mode 100644 index 0000000..589012c --- /dev/null +++ b/validation.go @@ -0,0 +1,48 @@ +package main + +import ( + "github.com/satori/go.uuid" + "golang.org/x/crypto/bcrypt" + "unicode/utf8" +) + +func GetValidUsername(u string) (uuid.UUID, error) { + uname, err := uuid.FromString(u) + if err != nil { + return uuid.UUID{}, err + } + return uname, nil +} + +func ValidKey(k string) bool { + kn := SanitizeString(k) + if utf8.RuneCountInString(k) == 40 && utf8.RuneCountInString(kn) == 40 { + // Correct length and all chars valid + return true + } + return false +} + +func ValidSubdomain(s string) bool { + _, err := uuid.FromString(s) + if err == nil { + return true + } + return false +} + +func ValidTXT(s string) bool { + sn := SanitizeString(s) + if utf8.RuneCountInString(s) == 43 && utf8.RuneCountInString(sn) == 43 { + // 43 chars is the current LE auth key size, but not limited / defined by ACME + return true + } + return false +} + +func CorrectPassword(pw string, hash string) bool { + if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(pw)); err == nil { + return true + } + return false +}