From b0cd264c714093cc402bfb69cf4bf40feaa9a418 Mon Sep 17 00:00:00 2001
From: Joona Hoikkala <joohoi@users.noreply.github.com>
Date: Wed, 15 Nov 2017 13:52:27 +0200
Subject: [PATCH] Fail on malformed JSON payloads in register endpoint (#24)

---
 api.go      | 14 +++++++++++++-
 api_test.go | 52 +++++++++++++++++++++++++++++++++++++++++++++-------
 2 files changed, 58 insertions(+), 8 deletions(-)

diff --git a/api.go b/api.go
index 8498d65..ef2b02f 100644
--- a/api.go
+++ b/api.go
@@ -3,6 +3,7 @@ package main
 import (
 	"encoding/json"
 	"fmt"
+	"io/ioutil"
 	"net/http"
 
 	"github.com/julienschmidt/httprouter"
@@ -22,7 +23,18 @@ func webRegisterPost(w http.ResponseWriter, r *http.Request, _ httprouter.Params
 	var regStatus int
 	var reg []byte
 	aTXT := ACMETxt{}
-	json.NewDecoder(r.Body).Decode(&aTXT)
+	bdata, _ := ioutil.ReadAll(r.Body)
+	if bdata != nil && len(bdata) > 0 {
+		err := json.Unmarshal(bdata, &aTXT)
+		if err != nil {
+			regStatus = http.StatusBadRequest
+			reg = jsonError("malformed_json_payload")
+			w.Header().Set("Content-Type", "application/json")
+			w.WriteHeader(regStatus)
+			w.Write(reg)
+			return
+		}
+	}
 	// Create new user
 	nu, err := DB.Register(aTXT.AllowFrom)
 	if err != nil {
diff --git a/api_test.go b/api_test.go
index e8f2012..b546966 100644
--- a/api_test.go
+++ b/api_test.go
@@ -35,6 +35,17 @@ func noAuth(update httprouter.Handle) httprouter.Handle {
 	}
 }
 
+func getExpect(t *testing.T, server *httptest.Server) *httpexpect.Expect {
+	return httpexpect.WithConfig(httpexpect.Config{
+		BaseURL:  server.URL,
+		Reporter: httpexpect.NewAssertReporter(t),
+		Printers: []httpexpect.Printer{
+			httpexpect.NewCurlPrinter(t),
+			httpexpect.NewDebugPrinter(t, true),
+		},
+	})
+}
+
 func setupRouter(debug bool, noauth bool) http.Handler {
 	api := httprouter.New()
 	var dbcfg = dbsettings{
@@ -72,7 +83,7 @@ func TestApiRegister(t *testing.T) {
 	router := setupRouter(false, false)
 	server := httptest.NewServer(router)
 	defer server.Close()
-	e := httpexpect.New(t, server.URL)
+	e := getExpect(t, server)
 	e.POST("/register").Expect().
 		Status(http.StatusCreated).
 		JSON().Object().
@@ -103,11 +114,38 @@ func TestApiRegister(t *testing.T) {
 	response.Value("allowfrom").Array().Elements("123.123.123.123/32")
 }
 
+func TestApiRegisterMalformedJSON(t *testing.T) {
+	router := setupRouter(false, false)
+	server := httptest.NewServer(router)
+	defer server.Close()
+	e := getExpect(t, server)
+
+	malPayloads := []string{
+		"{\"allowfrom': '1.1.1.1/32'}",
+		"\"allowfrom\": \"1.1.1.1/32\"",
+		"{\"allowfrom\": \"[1.1.1.1/32]\"",
+		"\"allowfrom\": \"1.1.1.1/32\"}",
+		"{allowfrom: \"1.2.3.4\"}",
+		"{allowfrom: [1.2.3.4]}",
+		"whatever that's not a json payload",
+	}
+	for _, test := range malPayloads {
+		e.POST("/register").
+			WithBytes([]byte(test)).
+			Expect().
+			Status(http.StatusBadRequest).
+			JSON().Object().
+			ContainsKey("error").
+			NotContainsKey("subdomain").
+			NotContainsKey("username")
+	}
+}
+
 func TestApiRegisterWithMockDB(t *testing.T) {
 	router := setupRouter(false, false)
 	server := httptest.NewServer(router)
 	defer server.Close()
-	e := httpexpect.New(t, server.URL)
+	e := getExpect(t, server)
 	oldDb := DB.GetBackend()
 	db, mock, _ := sqlmock.New()
 	DB.SetBackend(db)
@@ -125,7 +163,7 @@ func TestApiUpdateWithoutCredentials(t *testing.T) {
 	router := setupRouter(false, false)
 	server := httptest.NewServer(router)
 	defer server.Close()
-	e := httpexpect.New(t, server.URL)
+	e := getExpect(t, server)
 	e.POST("/update").Expect().
 		Status(http.StatusUnauthorized).
 		JSON().Object().
@@ -143,7 +181,7 @@ func TestApiUpdateWithCredentials(t *testing.T) {
 	router := setupRouter(false, false)
 	server := httptest.NewServer(router)
 	defer server.Close()
-	e := httpexpect.New(t, server.URL)
+	e := getExpect(t, server)
 	newUser, err := DB.Register(cidrslice{})
 	if err != nil {
 		t.Errorf("Could not create new user, got error [%v]", err)
@@ -176,7 +214,7 @@ func TestApiUpdateWithCredentialsMockDB(t *testing.T) {
 	router := setupRouter(false, true)
 	server := httptest.NewServer(router)
 	defer server.Close()
-	e := httpexpect.New(t, server.URL)
+	e := getExpect(t, server)
 	oldDb := DB.GetBackend()
 	db, mock, _ := sqlmock.New()
 	DB.SetBackend(db)
@@ -202,7 +240,7 @@ func TestApiManyUpdateWithCredentials(t *testing.T) {
 	router := setupRouter(true, false)
 	server := httptest.NewServer(router)
 	defer server.Close()
-	e := httpexpect.New(t, server.URL)
+	e := getExpect(t, server)
 	// User without defined CIDR masks
 	newUser, err := DB.Register(cidrslice{})
 	if err != nil {
@@ -262,7 +300,7 @@ func TestApiManyUpdateWithIpCheckHeaders(t *testing.T) {
 	router := setupRouter(false, false)
 	server := httptest.NewServer(router)
 	defer server.Close()
-	e := httpexpect.New(t, server.URL)
+	e := getExpect(t, server)
 	// Use header checks from default header (X-Forwarded-For)
 	Config.API.UseHeader = true
 	// User without defined CIDR masks