diff --git a/handlers/handlers.go b/handlers/handlers.go index 0f2e066..5f2aba1 100644 --- a/handlers/handlers.go +++ b/handlers/handlers.go @@ -66,6 +66,11 @@ func (a *App) CreateShortURL(w http.ResponseWriter, r *http.Request) { http.Error(w, "Invalid URL", http.StatusBadRequest) return } + // Verify alias is valid + if u.Alias != "" && !utils.IsValidAlias(u.Alias) { + http.Error(w, "Invalid Alias", http.StatusBadRequest) + return + } // Check if URL entry already exists existingURL := &URLEntry{} a.DB.Where("url = ?", u.URL).Find(existingURL) diff --git a/handlers/handlers_test.go b/handlers/handlers_test.go index 2ba7e4b..dc5f84b 100644 --- a/handlers/handlers_test.go +++ b/handlers/handlers_test.go @@ -76,10 +76,13 @@ func TestInvalidCreate(t *testing.T) { alias string status int }{ - "bad": {url: "https/agoogle.com", alias: "ggls", status: http.StatusBadRequest}, - "no colon": {url: "http//google.com", alias: "ggl", status: http.StatusBadRequest}, - "empty": {url: "", alias: "", status: http.StatusBadRequest}, - "asdf": {url: "asdf", alias: "", status: http.StatusBadRequest}, + "bad": {url: "https/agoogle.com", alias: "ggls", status: http.StatusBadRequest}, + "no colon": {url: "http//google.com", alias: "ggl", status: http.StatusBadRequest}, + "empty": {url: "", alias: "", status: http.StatusBadRequest}, + "asdf": {url: "asdf", alias: "", status: http.StatusBadRequest}, + "spaces": {url: "https://google.com", alias: "oh no spaces", status: http.StatusBadRequest}, + "question mark": {url: "https://google.com", alias: "huh?", status: http.StatusBadRequest}, + "percent": {url: "https://google.com", alias: "test%20stuff", status: http.StatusBadRequest}, } app := setup() diff --git a/utils/utils.go b/utils/utils.go index 94264fd..b8f0350 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -2,22 +2,23 @@ package utils import ( "math/rand" - "strings" + "regexp" "time" ) -// Alphanumeric charset -const CHARSET = "abcdefghijklmnopqrstuvwxyz" + - "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" +const ( + // Alphanumeric charset + CHARSET = "abcdefghijklmnopqrstuvwxyz" + + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + + // RFC 3986 Section 2.3 URI Unreserved Characters + URIUnreservedChars = `^([A-Za-z0-9_.~-])+$` +) // Tests whether a string is in the alphanumeric charset -func IsAlphaNum(str string) bool { - for _, r := range []rune(str) { - if !strings.ContainsRune(CHARSET, r) { - return false - } - } - return true +func IsValidAlias(str string) bool { + valid, err := regexp.MatchString(URIUnreservedChars, str) + return valid && err == nil } // Generate a random alphanumeric string of given length diff --git a/utils/utils_test.go b/utils/utils_test.go index 6ab095a..bc8b7f3 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -4,23 +4,31 @@ import ( "testing" ) -func TestIsAlphaNum(t *testing.T) { +func TestValidAlias(t *testing.T) { tests := map[string]struct { name string str string val bool }{ - "basic": {str: "hello", val: true}, - "dash": {str: "-", val: false}, - "period": {str: ".", val: false}, + "LOWERCASE": {str: "lowercase", val: true}, + "UPPERCASE": {str: "UPPERCASE", val: true}, + "underscore": {str: "_", val: true}, + "period": {str: ".", val: true}, + "dash": {str: "-", val: true}, + "tilde": {str: "~", val: true}, + "adsf": {str: "a_S-d.f~", val: true}, "question mark": {str: "?", val: false}, - "backslash": {str: "?", val: false}, + "backslash": {str: "\\", val: false}, + "asterix": {str: "*", val: false}, + "ampersand": {str: "&", val: false}, + "space": {str: " ", val: false}, + "percent": {str: "%", val: false}, } for name, tc := range tests { t.Run(name, func(t *testing.T) { - if IsAlphaNum(tc.str) != tc.val { - t.Errorf("For '%s', expected %t. Got %t instead.", tc.str, tc.val, IsAlphaNum(tc.str)) + if IsValidAlias(tc.str) != tc.val { + t.Errorf("For '%s', expected %t. Got %t instead.", tc.str, tc.val, IsValidAlias(tc.str)) } }) } @@ -29,7 +37,7 @@ func TestIsAlphaNum(t *testing.T) { // For the code coverage why not func TestRandString(t *testing.T) { str := RandString(6) - if len(str) == 6 && IsAlphaNum(str) { + if len(str) == 6 && IsValidAlias(str) { return } t.Errorf("Seriously? How?")