Skip to content

Commit

Permalink
fix: handle the case where asset name can be missing (#2995)
Browse files Browse the repository at this point in the history
* fix: handle the case where asset name can be missing

* detekt

* handle web case correctly

* typo

* detekt

* remove the check for empty mimeType
  • Loading branch information
MohamadJaara authored Sep 10, 2024
1 parent 6b3f390 commit 538bae1
Show file tree
Hide file tree
Showing 8 changed files with 296 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ data class AssetContent(
}

// We should not display Preview Assets (assets w/o valid encryption keys sent by Mac/Web clients) unless they include image metadata
val shouldBeDisplayed = !isPreviewMessage || hasValidImageMetadata
val isAssetDataComplete = !isPreviewMessage || hasValidImageMetadata

sealed class AssetMetadata {
data class Image(val width: Int, val height: Int) : AssetMetadata()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ internal class GetMessageAssetUseCaseImpl(
val assetUploadStatus = content.value.uploadStatus
val wasDownloaded: Boolean = assetDownloadStatus == SAVED_INTERNALLY || assetDownloadStatus == SAVED_EXTERNALLY
// assets uploaded by other clients have upload status NOT_UPLOADED
val alreadyUploaded: Boolean = (assetUploadStatus == NOT_UPLOADED && content.value.shouldBeDisplayed)
val alreadyUploaded: Boolean = (assetUploadStatus == NOT_UPLOADED && content.value.isAssetDataComplete)
|| assetUploadStatus == UPLOADED
val assetMetadata = with(content.value.remoteData) {
DownloadAssetMessageMetadata(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,12 @@ internal class ScheduleNewAssetMessageUseCaseImpl(
FileSharingStatus.Value.EnabledAll -> { /* no-op*/
}

is FileSharingStatus.Value.EnabledSome -> if (!validateAssetFileUseCase(assetName, it.state.allowedType)) {
is FileSharingStatus.Value.EnabledSome -> if (!validateAssetFileUseCase(
fileName = assetName,
mimeType = assetMimeType,
allowedExtension = it.state.allowedType
)
) {
kaliumLogger.e("The asset message trying to be processed has invalid content data")
return ScheduleNewAssetMessageResult.Failure.RestrictedFileType
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,95 @@
*/
package com.wire.kalium.logic.feature.asset

import com.wire.kalium.logic.kaliumLogger

/**
* Returns true if the file extension is present in file name and is allowed and false otherwise.
* @param fileName the file name (with extension) to validate.
* @param allowedExtension the list of allowed extension.
*/
interface ValidateAssetFileTypeUseCase {
operator fun invoke(fileName: String?, allowedExtension: List<String>): Boolean
operator fun invoke(
fileName: String?,
mimeType: String,
allowedExtension: List<String>
): Boolean
}

internal class ValidateAssetFileTypeUseCaseImpl : ValidateAssetFileTypeUseCase {
override operator fun invoke(fileName: String?, allowedExtension: List<String>): Boolean {
if (fileName == null) return false

val split = fileName.split(".")
return if (split.size < 2) {
false
override operator fun invoke(
fileName: String?,
mimeType: String,
allowedExtension: List<String>
): Boolean {
kaliumLogger.d("Validating file type for $fileName with mimeType $mimeType is empty ${mimeType.isBlank()}")
val extension = if (fileName != null) {
extensionFromFileName(fileName)
} else {
val allowedExtensionLowerCase = allowedExtension.map { it.lowercase() }
val extensions = split.subList(1, split.size).map { it.lowercase() }
extensions.all { it.isNotEmpty() && allowedExtensionLowerCase.contains(it) }
extensionFromMimeType(mimeType)
}
return extension?.let { allowedExtension.contains(it) } ?: false
}

private fun extensionFromFileName(fileName: String): String? =
fileName.substringAfterLast('.', "").takeIf { it.isNotEmpty() }

private fun extensionFromMimeType(mimeType: String): String? = fileExtensions[mimeType]

private companion object {
val fileExtensions = mapOf(
"video/3gpp" to "3gpp",
"audio/aac" to "aac",
"audio/amr" to "amr",
"video/x-msvideo" to "avi",
"image/bmp" to "bmp",
"text/css" to "css",
"text/csv" to "csv",
"application/msword" to "doc",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document" to "docx",
"message/rfc822" to "eml",
"audio/flac" to "flac",
"image/gif" to "gif",
"text/html" to "html",
"image/vnd.microsoft.icon" to "ico",
"image/jpeg" to "jpeg",
"image/jpeg" to "jpg",
"image/jpeg" to "jfif",
"application/vnd.apple.keynote" to "key",
"audio/mp4" to "m4a",
"video/x-m4v" to "m4v",
"text/markdown" to "md",
"audio/midi" to "midi",
"video/x-matroska" to "mkv",
"video/quicktime" to "mov",
"audio/mpeg" to "mp3",
"video/mp4" to "mp4",
"video/mpeg" to "mpeg",
"application/vnd.ms-outlook" to "msg",
"application/vnd.oasis.opendocument.spreadsheet" to "ods",
"application/vnd.oasis.opendocument.text" to "odt",
"audio/ogg" to "ogg",
"application/pdf" to "pdf",
"image/jpeg" to "pjp",
"image/pjpeg" to "pjpeg",
"image/png" to "png",
"application/vnd.ms-powerpoint" to "ppt",
"application/vnd.openxmlformats-officedocument.presentationml.presentation" to "pptx",
"image/vnd.adobe.photoshop" to "psd",
"application/rtf" to "rtf",
"application/sql" to "sql",
"image/svg+xml" to "svg",
"application/x-tex" to "tex",
"image/tiff" to "tiff",
"text/plain" to "txt",
"text/x-vcard" to "vcf",
"audio/wav" to "wav",
"video/webm" to "webm",
"image/webp" to "webp",
"video/x-ms-wmv" to "wmv",
"application/vnd.ms-excel" to "xls",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" to "xlsx",
"application/xml" to "xml"
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import com.wire.kalium.logic.data.message.Message
import com.wire.kalium.logic.data.message.MessageContent
import com.wire.kalium.logic.data.message.MessageRepository
import com.wire.kalium.logic.data.message.PersistMessageUseCase
import com.wire.kalium.logic.data.message.getType
import com.wire.kalium.logic.feature.asset.ValidateAssetFileTypeUseCase
import com.wire.kalium.logic.functional.onFailure
import com.wire.kalium.logic.functional.onSuccess
Expand All @@ -46,22 +47,50 @@ internal class AssetMessageHandlerImpl(
kaliumLogger.e("The asset message trying to be processed has invalid content data")
return
}

val messageContent = message.content
userConfigRepository.isFileSharingEnabled().onSuccess {
val isThisAssetAllowed = when (it.state) {
FileSharingStatus.Value.Disabled -> false
FileSharingStatus.Value.EnabledAll -> true
FileSharingStatus.Value.Disabled -> AssetRestrictionContinuationStrategy.Restrict
FileSharingStatus.Value.EnabledAll -> AssetRestrictionContinuationStrategy.Continue

is FileSharingStatus.Value.EnabledSome -> validateAssetMimeTypeUseCase(
messageContent.value.name,
it.state.allowedType
)
is FileSharingStatus.Value.EnabledSome -> {
// If the asset message is missing the name, but it does have full
// asset data then we can not decide now if it is allowed or not
// it is safe to continue and the code later will check the original
// asset message and decide if it is allowed or not
if (
message.content.value.name.isNullOrEmpty() &&
message.content.value.isAssetDataComplete
) {
kaliumLogger.e("The asset message trying to be processed has invalid data looking locally")
AssetRestrictionContinuationStrategy.RestrictIfThereIsNotOldMessageWithTheSameAssetID
} else {
validateAssetMimeTypeUseCase(
fileName = messageContent.value.name,
mimeType = messageContent.value.mimeType,
allowedExtension = it.state.allowedType
).let { validateResult ->
if (validateResult) {
AssetRestrictionContinuationStrategy.Continue
} else {
AssetRestrictionContinuationStrategy.Restrict
}
}
}
}
}

if (isThisAssetAllowed) {
processNonRestrictedAssetMessage(message, messageContent)
} else {
persistRestrictedAssetMessage(message, messageContent)
when (isThisAssetAllowed) {
AssetRestrictionContinuationStrategy.Continue -> processNonRestrictedAssetMessage(message, messageContent, false)
AssetRestrictionContinuationStrategy.RestrictIfThereIsNotOldMessageWithTheSameAssetID -> processNonRestrictedAssetMessage(
message,
messageContent,
true
)

AssetRestrictionContinuationStrategy.Restrict -> persistRestrictedAssetMessage(message, messageContent)

}
}
}
Expand All @@ -77,23 +106,34 @@ internal class AssetMessageHandlerImpl(
persistMessage(newMessage)
}

private suspend fun processNonRestrictedAssetMessage(processedMessage: Message.Regular, assetContent: MessageContent.Asset) {
private suspend fun processNonRestrictedAssetMessage(
processedMessage: Message.Regular,
assetContent: MessageContent.Asset,
restrictIfNotAFollowUpMessage: Boolean
) {
messageRepository.getMessageById(processedMessage.conversationId, processedMessage.id).onFailure {
// No asset message was received previously, so just persist the preview of the asset message
// Web/Mac clients split the asset message delivery into 2. One with the preview metadata (assetName, assetSize...) and
// with empty encryption keys and the second with empty metadata but all the correct encryption keys. We just want to
// hide the preview of generic asset messages with empty encryption keys as a way to avoid user interaction with them.
val initialMessage = processedMessage.copy(
visibility = if (assetContent.value.shouldBeDisplayed) Message.Visibility.VISIBLE else Message.Visibility.HIDDEN
)
persistMessage(initialMessage)

if (restrictIfNotAFollowUpMessage) {
persistRestrictedAssetMessage(processedMessage, assetContent)
} else {
val initialMessage = processedMessage.copy(
visibility = if (assetContent.value.isAssetDataComplete) Message.Visibility.VISIBLE else Message.Visibility.HIDDEN
)
persistMessage(initialMessage)
}
}.onSuccess { persistedMessage ->
val validDecryptionKeys = assetContent.value.remoteData
// Check the second asset message is from the same original sender
if (isSenderVerified(persistedMessage, processedMessage) && persistedMessage is Message.Regular) {
// The second asset message received from Web/Mac clients contains the full asset decryption keys, so we need to update
// the preview message persisted previously with the rest of the data
persistMessage(updateAssetMessageWithDecryptionKeys(persistedMessage, validDecryptionKeys))
updateAssetMessageWithDecryptionKeys(persistedMessage, validDecryptionKeys)?.let {
persistMessage(it)
}
} else {
kaliumLogger.e("The previously persisted message has a different sender id than the one we are trying to process")
}
Expand All @@ -106,8 +146,21 @@ internal class AssetMessageHandlerImpl(
private fun updateAssetMessageWithDecryptionKeys(
persistedMessage: Message.Regular,
remoteData: AssetContent.RemoteData
): Message.Regular {
val assetMessageContent = persistedMessage.content as MessageContent.Asset
): Message.Regular? {
val assetMessageContent = when (persistedMessage.content) {
is MessageContent.Asset -> persistedMessage.content
is MessageContent.RestrictedAsset -> {
// original message was a restricted asset message, ignoring
return null
}

is MessageContent.FailedDecryption,
is MessageContent.Knock,
is MessageContent.Location,
is MessageContent.Composite,
is MessageContent.Text,
is MessageContent.Unknown -> error("Invalid asset message content type ${persistedMessage.content.getType()}")
}
// The message was previously received with just metadata info, so let's update it with the raw data info
return persistedMessage.copy(
content = assetMessageContent.copy(
Expand All @@ -120,3 +173,9 @@ internal class AssetMessageHandlerImpl(
)
}
}

private sealed interface AssetRestrictionContinuationStrategy {
data object Continue : AssetRestrictionContinuationStrategy
data object Restrict : AssetRestrictionContinuationStrategy
data object RestrictIfThereIsNotOldMessageWithTheSameAssetID : AssetRestrictionContinuationStrategy
}
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ class ScheduleNewAssetMessageUseCaseTest {

verify(arrangement.validateAssetMimeTypeUseCase)
.function(arrangement.validateAssetMimeTypeUseCase::invoke)
.with(eq("some-asset.txt"), eq(listOf("png")))
.with(eq("some-asset.txt"), eq("text/plain"), eq(listOf("png")))
.wasInvoked(exactly = once)
}

Expand Down Expand Up @@ -669,7 +669,7 @@ class ScheduleNewAssetMessageUseCaseTest {

verify(arrangement.validateAssetMimeTypeUseCase)
.function(arrangement.validateAssetMimeTypeUseCase::invoke)
.with(eq("some-asset.png"), eq(listOf("png")))
.with(eq("some-asset.png"), eq("image/png"), eq(listOf("png")))
.wasInvoked(exactly = once)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class ValidateAssetFileTypeUseCaseTest {
fun givenRegularFileNameWithAllowedExtension_whenInvoke_thenBeApproved() = runTest {
val (_, validate) = arrange {}

val result = validate("name.txt", listOf("txt", "jpg"))
val result = validate(fileName = "name.txt", mimeType = "", allowedExtension = listOf("txt", "jpg"))

assertTrue(result)
}
Expand All @@ -37,7 +37,7 @@ class ValidateAssetFileTypeUseCaseTest {
fun givenRegularFileNameWithNOTAllowedExtension_whenInvoke_thenBeRestricted() = runTest {
val (_, validate) = arrange {}

val result = validate("name.php", listOf("txt", "jpg"))
val result = validate(fileName = "name.php", mimeType = "", allowedExtension = listOf("txt", "jpg"))

assertFalse(result)
}
Expand All @@ -46,7 +46,7 @@ class ValidateAssetFileTypeUseCaseTest {
fun givenRegularFileNameWithoutExtension_whenInvoke_thenBeRestricted() = runTest {
val (_, validate) = arrange {}

val result = validate("name", listOf("txt", "jpg"))
val result = validate(fileName = "name", mimeType = "", allowedExtension = listOf("txt", "jpg"))

assertFalse(result)
}
Expand All @@ -55,24 +55,40 @@ class ValidateAssetFileTypeUseCaseTest {
fun givenNullFileName_whenInvoke_thenBeRestricted() = runTest {
val (_, validate) = arrange {}

val result = validate(null, listOf("txt", "jpg"))
val result = validate(fileName = null, mimeType = "", allowedExtension = listOf("txt", "jpg"))

assertFalse(result)
}

@Test
fun givenRegularFileNameWithFewExtensions_whenInvoke_thenEachExtensionIsChecked() = runTest {
fun givenFileNameIs() = runTest {
val (_, validate) = arrange {}

val result1 = validate("name.php.txt", listOf("txt", "jpg"))
val result2 = validate("name.txt.php", listOf("txt", "jpg"))
val result3 = validate("name..txt.jpg", listOf("txt", "jpg"))
val result4 = validate("name.txt.php.txt.jpg", listOf("txt", "jpg"))
val result = validate(fileName = null, mimeType = "image/jpg", allowedExtension = listOf("txt", "jpg"))

assertFalse(result1)
assertFalse(result2)
assertFalse(result3)
assertFalse(result4)
assertFalse(result)
}

@Test
fun givenNullFileNameAndValidMimeType_whenInvoke_thenMimeTypeIsChecked() = runTest {
val (_, validate) = arrange {}

val result = validate(fileName = null, mimeType = "image/jpg", allowedExtension = listOf("txt", "jpg"))

assertFalse(result)
}

@Test
fun givenNullFileNameAndInvalidMimeType_whenInvoke_thenMimeTypeIsChecked() = runTest {
val (_, validate) = arrange {}

val result = validate(
fileName = null,
mimeType = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
allowedExtension = listOf("txt", "jpg")
)

assertFalse(result)
}

private fun arrange(block: Arrangement.() -> Unit) = Arrangement(block).arrange()
Expand Down
Loading

0 comments on commit 538bae1

Please sign in to comment.