From 833eb660fd274164568a06d208fabd35e9dde9d5 Mon Sep 17 00:00:00 2001 From: Peter Hellberg Date: Tue, 5 Sep 2023 18:37:43 +0200 Subject: [PATCH] Check that the alphabet does not contain multibyte characters --- alphabet_test.go | 10 ++++++++++ sqids.go | 19 ++++++++++++++++--- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/alphabet_test.go b/alphabet_test.go index f1472ad..6c7c907 100644 --- a/alphabet_test.go +++ b/alphabet_test.go @@ -5,6 +5,16 @@ import ( "testing" ) +func TestMultibyteAlphabet(t *testing.T) { + _, err := New(Options{ + Alphabet: "ë1092", + }) + + if err != errAlphabetMultibyte { + t.Fatalf("unexpected error: %v", err) + } +} + func TestAlphabetSimple(t *testing.T) { numbers := []uint64{1, 2, 3} id := "4d9fd2" diff --git a/sqids.go b/sqids.go index 7658f9a..707a4e6 100644 --- a/sqids.go +++ b/sqids.go @@ -18,6 +18,14 @@ const ( var defaultBlocklist []string = newDefaultBlocklist() +// Alphabet validation errors +var ( + errAlphabetMultibyte = errors.New("alphabet must not contain any multibyte characters") + errAlphabetTooShort = errors.New("alphabet length must be at least 5") + errAlphabetNotUniqueChars = errors.New("alphabet must contain unique characters") + errAlphabetMinLength = errors.New("alphabet minimum length") +) + // Options for a custom instance of Sqids type Options struct { Alphabet string @@ -59,19 +67,24 @@ func validatedOptions(o Options) (Options, error) { o.Alphabet = defaultAlphabet } + // check that the alphabet does not contain multibyte characters + if len(o.Alphabet) != len([]rune(o.Alphabet)) { + return Options{}, errAlphabetMultibyte + } + // check the length of the alphabet if len(o.Alphabet) < minAlphabetLength { - return Options{}, errors.New("alphabet length must be at least 5") + return Options{}, errAlphabetTooShort } // check that the alphabet has only unique characters if !hasUniqueChars(o.Alphabet) { - return Options{}, errors.New("alphabet must contain unique characters") + return Options{}, errAlphabetNotUniqueChars } // test min length (type [might be lang-specific] + min length + max length) if o.MinLength < int(minUint64Value) || o.MinLength > len(o.Alphabet) { - return Options{}, fmt.Errorf("minimum length has to be between %d and %d", minUint64Value, len(o.Alphabet)) + return Options{}, fmt.Errorf("%w has to be between %d and %d", errAlphabetMinLength, minUint64Value, len(o.Alphabet)) } o.Blocklist = filterBlocklist(o.Alphabet, o.Blocklist)