diff --git a/src/main/java/redis/clients/jedis/search/RediSearchUtil.java b/src/main/java/redis/clients/jedis/search/RediSearchUtil.java index a6a82486b7..e08e0d2dd8 100644 --- a/src/main/java/redis/clients/jedis/search/RediSearchUtil.java +++ b/src/main/java/redis/clients/jedis/search/RediSearchUtil.java @@ -2,9 +2,12 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.Map; +import java.util.Set; import redis.clients.jedis.util.SafeEncoder; @@ -54,6 +57,39 @@ public static byte[] ToByteArray(float[] input) { return bytes; } + public static final Set TAG_ESCAPE_CHARS = new HashSet<>(Arrays.asList(// + ',', '.', '<', '>', '{', '}', '[', // + ']', '"', '\'', ':', ';', '!', '@', // + '#', '$', '%', '^', '&', '*', '(', // + ')', '-', '+', '=', '~', '|' // + )); + + public static String escape(String text) { + return escape(text, false); + } + + public static String escapeQuery(String query) { + return escape(query, true); + } + + public static String escape(String text, boolean querying) { + char[] chars = text.toCharArray(); + + StringBuilder sb = new StringBuilder(); + for (char ch : chars) { + if (TAG_ESCAPE_CHARS.contains(ch) + || (querying && ch == ' ')) { + sb.append("\\"); + } + sb.append(ch); + } + return sb.toString(); + } + + public static String unescape(String text) { + return text.replace("\\", ""); + } + private RediSearchUtil() { throw new InstantiationError("Must not instantiate this class"); } diff --git a/src/test/java/redis/clients/jedis/modules/search/SearchWithParamsTest.java b/src/test/java/redis/clients/jedis/modules/search/SearchWithParamsTest.java index 98bc5968e8..046fccfb54 100644 --- a/src/test/java/redis/clients/jedis/modules/search/SearchWithParamsTest.java +++ b/src/test/java/redis/clients/jedis/modules/search/SearchWithParamsTest.java @@ -1210,4 +1210,19 @@ public void searchIterationCollect() { "pupil:4444", "student:5555", "teacher:6666").stream().collect(Collectors.toSet()), collect.stream().map(Document::getId).collect(Collectors.toSet())); } + + @Test + public void escapeUtil() { + assertOK(client.ftCreate(index, TextField.of("txt"))); + + client.hset("doc1", "txt", RediSearchUtil.escape("hello-world")); + assertNotEquals("hello-world", client.hget("doc1", "txt")); + assertEquals("hello-world", RediSearchUtil.unescape(client.hget("doc1", "txt"))); + + SearchResult resultNoEscape = client.ftSearch(index, "hello-world"); + assertEquals(0, resultNoEscape.getTotalResults()); + + SearchResult resultEscaped = client.ftSearch(index, RediSearchUtil.escapeQuery("hello-world")); + assertEquals(1, resultEscaped.getTotalResults()); + } }