Skip to content

Commit

Permalink
Support duplicated hash filters
Browse files Browse the repository at this point in the history
  • Loading branch information
CrisBarreiro committed Dec 19, 2024
1 parent 0c8a08a commit ef6e36e
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import timber.log.Timber

interface MaliciousSiteRepository {
suspend fun containsHashPrefix(hashPrefix: String): Boolean
suspend fun getFilter(hash: String): Filter?
suspend fun getFilters(hash: String): List<Filter>?
suspend fun matches(hashPrefix: String): List<Match>
}

Expand Down Expand Up @@ -86,9 +86,11 @@ class RealMaliciousSiteRepository @Inject constructor(
return maliciousSiteDao.getHashPrefix(hashPrefix) != null
}

override suspend fun getFilter(hash: String): Filter? {
override suspend fun getFilters(hash: String): List<Filter>? {
return maliciousSiteDao.getFilter(hash)?.let {
Filter(it.hash, it.regex)
it.map {
Filter(it.hash, it.regex)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package com.duckduckgo.malicioussiteprotection.impl.data.db

import androidx.room.Dao
import androidx.room.Delete
import androidx.room.Insert
import androidx.room.OnConflictStrategy
import androidx.room.Query
Expand All @@ -33,10 +32,10 @@ interface MaliciousSiteDao {
@Insert(onConflict = OnConflictStrategy.REPLACE)
suspend fun insertHashPrefixes(items: List<HashPrefixEntity>)

@Delete(HashPrefixEntity::class)
@Query("DELETE FROM hash_prefixes")
suspend fun deleteHashPrefixes()

@Delete(FilterEntity::class)
@Query("DELETE FROM filters")
suspend fun deleteFilters()

@Insert(onConflict = OnConflictStrategy.REPLACE)
Expand All @@ -52,7 +51,7 @@ interface MaliciousSiteDao {
suspend fun getHashPrefix(hashPrefix: String): HashPrefixEntity?

@Query("SELECT * FROM filters WHERE hash = :hash")
suspend fun getFilter(hash: String): FilterEntity?
suspend fun getFilter(hash: String): List<FilterEntity>?

@Transaction
suspend fun insertData(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,12 @@ class RealMaliciousSiteProtection @Inject constructor(
Timber.d("\uD83D\uDFE2 Cris: should not block (no hash) $hashPrefix, $url")
return IsMaliciousResult.SAFE
}
maliciousSiteRepository.getFilter(hash)?.let {
if (Pattern.compile(it.regex).matcher(url.toString()).find()) {
Timber.d("\uD83D\uDFE2 Cris: shouldBlock $url")
return IsMaliciousResult.MALICIOUS
maliciousSiteRepository.getFilters(hash)?.let {
for (filter in it) {
if (Pattern.compile(filter.regex).matcher(url.toString()).find()) {
Timber.d("\uD83D\uDFE2 Cris: shouldBlock $url")
return IsMaliciousResult.MALICIOUS
}
}
}
appCoroutineScope.launch(dispatchers.io()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class RealMaliciousSiteProtectionTest {
val filter = Filter(hash, ".*malicious.*")

whenever(maliciousSiteRepository.containsHashPrefix(hashPrefix)).thenReturn(true)
whenever(maliciousSiteRepository.getFilter(hash)).thenReturn(filter)
whenever(maliciousSiteRepository.getFilters(hash)).thenReturn(listOf(filter))

val result = realMaliciousSiteProtection.isMalicious(url) {}

Expand All @@ -97,7 +97,7 @@ class RealMaliciousSiteProtectionTest {
val filter = Filter(hash, ".*malicious.*")

whenever(maliciousSiteRepository.containsHashPrefix(hashPrefix)).thenReturn(true)
whenever(maliciousSiteRepository.getFilter(hash)).thenReturn(filter)
whenever(maliciousSiteRepository.getFilters(hash)).thenReturn(listOf(filter))
whenever(mockMaliciousSiteProtectionRCFeature.isFeatureEnabled()).thenReturn(false)

val result = realMaliciousSiteProtection.isMalicious(url) {}
Expand All @@ -114,7 +114,7 @@ class RealMaliciousSiteProtectionTest {
val filter = Filter(hash, ".*unsafe.*")

whenever(maliciousSiteRepository.containsHashPrefix(hashPrefix)).thenReturn(true)
whenever(maliciousSiteRepository.getFilter(hash)).thenReturn(filter)
whenever(maliciousSiteRepository.getFilters(hash)).thenReturn(listOf(filter))

val result = realMaliciousSiteProtection.isMalicious(url) {}

Expand All @@ -131,7 +131,7 @@ class RealMaliciousSiteProtectionTest {
var onSiteBlockedAsyncCalled = false

whenever(maliciousSiteRepository.containsHashPrefix(hashPrefix)).thenReturn(true)
whenever(maliciousSiteRepository.getFilter(hash)).thenReturn(filter)
whenever(maliciousSiteRepository.getFilters(hash)).thenReturn(listOf(filter))
whenever(maliciousSiteRepository.matches(hashPrefix.substring(0, 4)))
.thenReturn(listOf(Match(hostname, url.toString(), ".*malicious.*", hash)))

Expand Down

0 comments on commit ef6e36e

Please sign in to comment.