diff --git a/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/MaliciousSiteRepository.kt b/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/MaliciousSiteRepository.kt index 1be3aed90e55..8e0468fde422 100644 --- a/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/MaliciousSiteRepository.kt +++ b/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/MaliciousSiteRepository.kt @@ -33,7 +33,7 @@ import timber.log.Timber interface MaliciousSiteRepository { suspend fun containsHashPrefix(hashPrefix: String): Boolean - suspend fun getFilter(hash: String): Filter? + suspend fun getFilters(hash: String): List? suspend fun matches(hashPrefix: String): List } @@ -86,9 +86,11 @@ class RealMaliciousSiteRepository @Inject constructor( return maliciousSiteDao.getHashPrefix(hashPrefix) != null } - override suspend fun getFilter(hash: String): Filter? { + override suspend fun getFilters(hash: String): List? { return maliciousSiteDao.getFilter(hash)?.let { - Filter(it.hash, it.regex) + it.map { + Filter(it.hash, it.regex) + } } } diff --git a/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/db/MaliciousSiteDao.kt b/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/db/MaliciousSiteDao.kt index df9a5690f589..db62ba442d6b 100644 --- a/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/db/MaliciousSiteDao.kt +++ b/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/data/db/MaliciousSiteDao.kt @@ -17,7 +17,6 @@ package com.duckduckgo.malicioussiteprotection.impl.data.db import androidx.room.Dao -import androidx.room.Delete import androidx.room.Insert import androidx.room.OnConflictStrategy import androidx.room.Query @@ -33,10 +32,10 @@ interface MaliciousSiteDao { @Insert(onConflict = OnConflictStrategy.REPLACE) suspend fun insertHashPrefixes(items: List) - @Delete(HashPrefixEntity::class) + @Query("DELETE FROM hash_prefixes") suspend fun deleteHashPrefixes() - @Delete(FilterEntity::class) + @Query("DELETE FROM filters") suspend fun deleteFilters() @Insert(onConflict = OnConflictStrategy.REPLACE) @@ -52,7 +51,7 @@ interface MaliciousSiteDao { suspend fun getHashPrefix(hashPrefix: String): HashPrefixEntity? @Query("SELECT * FROM filters WHERE hash = :hash") - suspend fun getFilter(hash: String): FilterEntity? + suspend fun getFilter(hash: String): List? @Transaction suspend fun insertData( diff --git a/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/domain/RealMaliciousSiteProtection.kt b/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/domain/RealMaliciousSiteProtection.kt index 2909bdf22f1f..87384dc0ffca 100644 --- a/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/domain/RealMaliciousSiteProtection.kt +++ b/malicious-site-protection/malicious-site-protection-impl/src/main/kotlin/com/duckduckgo/malicioussiteprotection/impl/domain/RealMaliciousSiteProtection.kt @@ -59,10 +59,12 @@ class RealMaliciousSiteProtection @Inject constructor( Timber.d("\uD83D\uDFE2 Cris: should not block (no hash) $hashPrefix, $url") return IsMaliciousResult.SAFE } - maliciousSiteRepository.getFilter(hash)?.let { - if (Pattern.compile(it.regex).matcher(url.toString()).find()) { - Timber.d("\uD83D\uDFE2 Cris: shouldBlock $url") - return IsMaliciousResult.MALICIOUS + maliciousSiteRepository.getFilters(hash)?.let { + for (filter in it) { + if (Pattern.compile(filter.regex).matcher(url.toString()).find()) { + Timber.d("\uD83D\uDFE2 Cris: shouldBlock $url") + return IsMaliciousResult.MALICIOUS + } } } appCoroutineScope.launch(dispatchers.io()) { diff --git a/malicious-site-protection/malicious-site-protection-impl/src/test/kotlin/com/duckduckgo/malicioussiteprotection/impl/domain/RealMaliciousSiteProtectionTest.kt b/malicious-site-protection/malicious-site-protection-impl/src/test/kotlin/com/duckduckgo/malicioussiteprotection/impl/domain/RealMaliciousSiteProtectionTest.kt index d1484cabf63b..270e4a1f0550 100644 --- a/malicious-site-protection/malicious-site-protection-impl/src/test/kotlin/com/duckduckgo/malicioussiteprotection/impl/domain/RealMaliciousSiteProtectionTest.kt +++ b/malicious-site-protection/malicious-site-protection-impl/src/test/kotlin/com/duckduckgo/malicioussiteprotection/impl/domain/RealMaliciousSiteProtectionTest.kt @@ -81,7 +81,7 @@ class RealMaliciousSiteProtectionTest { val filter = Filter(hash, ".*malicious.*") whenever(maliciousSiteRepository.containsHashPrefix(hashPrefix)).thenReturn(true) - whenever(maliciousSiteRepository.getFilter(hash)).thenReturn(filter) + whenever(maliciousSiteRepository.getFilters(hash)).thenReturn(listOf(filter)) val result = realMaliciousSiteProtection.isMalicious(url) {} @@ -97,7 +97,7 @@ class RealMaliciousSiteProtectionTest { val filter = Filter(hash, ".*malicious.*") whenever(maliciousSiteRepository.containsHashPrefix(hashPrefix)).thenReturn(true) - whenever(maliciousSiteRepository.getFilter(hash)).thenReturn(filter) + whenever(maliciousSiteRepository.getFilters(hash)).thenReturn(listOf(filter)) whenever(mockMaliciousSiteProtectionRCFeature.isFeatureEnabled()).thenReturn(false) val result = realMaliciousSiteProtection.isMalicious(url) {} @@ -114,7 +114,7 @@ class RealMaliciousSiteProtectionTest { val filter = Filter(hash, ".*unsafe.*") whenever(maliciousSiteRepository.containsHashPrefix(hashPrefix)).thenReturn(true) - whenever(maliciousSiteRepository.getFilter(hash)).thenReturn(filter) + whenever(maliciousSiteRepository.getFilters(hash)).thenReturn(listOf(filter)) val result = realMaliciousSiteProtection.isMalicious(url) {} @@ -131,7 +131,7 @@ class RealMaliciousSiteProtectionTest { var onSiteBlockedAsyncCalled = false whenever(maliciousSiteRepository.containsHashPrefix(hashPrefix)).thenReturn(true) - whenever(maliciousSiteRepository.getFilter(hash)).thenReturn(filter) + whenever(maliciousSiteRepository.getFilters(hash)).thenReturn(listOf(filter)) whenever(maliciousSiteRepository.matches(hashPrefix.substring(0, 4))) .thenReturn(listOf(Match(hostname, url.toString(), ".*malicious.*", hash)))