Skip to content

Commit

Permalink
fix: validate the list of allowed file names when extracting files fo…
Browse files Browse the repository at this point in the history
…rm zip folder (#2221)

* fix: validate the list of allowed file names when extracting files form  zip folder

* detekt

* fix tests

* detekt

(cherry picked from commit d6c8d60)
  • Loading branch information
MohamadJaara committed Nov 16, 2023
1 parent a6c1110 commit 21cb550
Show file tree
Hide file tree
Showing 8 changed files with 114 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,19 @@ import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.data.asset.KaliumFileSystem
import com.wire.kalium.logic.functional.Either
import okio.BufferedSource
import okio.Source
import okio.Sink
import okio.Path
import okio.Sink
import okio.Source

actual fun createCompressedFile(files: List<Pair<Source, String>>, outputSink: Sink): Either<CoreFailure, Long> =
TODO("Implement own iOS compression method")

actual fun extractCompressedFile(inputSource: Source, outputRootPath: Path, fileSystem: KaliumFileSystem): Either<CoreFailure, Long> =
actual fun extractCompressedFile(
inputSource: Source,
outputRootPath: Path,
param: ExtractFilesParam,
fileSystem: KaliumFileSystem
): Either<CoreFailure, Long> =
TODO("Implement own iOS compression method")

actual fun checkIfCompressedFileContainsFileTypes(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,22 +68,45 @@ private fun addToCompressedFile(zipOutputStream: ZipOutputStream, fileSource: So
}

@Suppress("TooGenericExceptionCaught", "NestedBlockDepth")
actual fun extractCompressedFile(inputSource: Source, outputRootPath: Path, fileSystem: KaliumFileSystem): Either<CoreFailure, Long> = try {
actual fun extractCompressedFile(
inputSource: Source,
outputRootPath: Path,
param: ExtractFilesParam,
fileSystem: KaliumFileSystem
): Either<CoreFailure, Long> = try {
var totalExtractedFilesSize = 0L
ZipInputStream(inputSource.buffer().inputStream()).use { zipInputStream ->
var entry: ZipEntry? = zipInputStream.nextEntry
while (entry != null) {
readCompressedEntry(zipInputStream, outputRootPath, fileSystem, entry).let {
totalExtractedFilesSize += it.first
entry = it.second
totalExtractedFilesSize += when (param) {
is ExtractFilesParam.All -> readCompressedEntry(zipInputStream, outputRootPath, fileSystem, entry)
is ExtractFilesParam.Only -> readAndExtractIfMatch(zipInputStream, outputRootPath, fileSystem, entry, param.files)
}
zipInputStream.closeEntry()
entry = zipInputStream.nextEntry
}
}
Either.Right(totalExtractedFilesSize)
} catch (e: Exception) {
Either.Left(StorageFailure.Generic(RuntimeException("There was an error trying to extract the provided compressed file", e)))
}

private fun readAndExtractIfMatch(
zipInputStream: ZipInputStream,
outputRootPath: Path,
fileSystem: KaliumFileSystem,
entry: ZipEntry,
fileNames: Set<String>
): Long {
return entry.name.let {
if (fileNames.contains(it)) {
readCompressedEntry(zipInputStream, outputRootPath, fileSystem, entry)
} else {
0L
}
}
}

@Suppress("TooGenericExceptionCaught", "NestedBlockDepth")
actual fun checkIfCompressedFileContainsFileTypes(
compressedFilePath: Path,
Expand Down Expand Up @@ -121,11 +144,7 @@ private fun readCompressedEntry(
outputRootPath: Path,
fileSystem: KaliumFileSystem,
entry: ZipEntry
): Pair<Long, ZipEntry?> {
if (isInvalidEntryPathDestination(entry.name)) {
throw RuntimeException("The provided zip file is invalid or has invalid data")
}

): Long {
var totalExtractedFilesSize = 0L
var byteCount: Int
val entryPathName = "$outputRootPath/${entry.name}"
Expand All @@ -137,8 +156,7 @@ private fun readCompressedEntry(
}
output.write(zipInputStream.readBytes())
}
zipInputStream.closeEntry()
return totalExtractedFilesSize to zipInputStream.nextEntry
return totalExtractedFilesSize
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ package com.wire.kalium.logic.feature.backup
object BackupConstants {
const val BACKUP_FILE_NAME_PREFIX = "WBX"
const val BACKUP_ENCRYPTED_FILE_NAME = "user-backup.cc20"

// BACKUP_METADATA_FILE_NAME and BACKUP_USER_DB_NAME must not be changed
// if there is a need to change them, please create a new file names and add it to the list of acceptedFileNames()
const val BACKUP_USER_DB_NAME = "user-backup-database.db"
const val BACKUP_METADATA_FILE_NAME = "export.json"
const val BACKUP_ENCRYPTED_EXTENSION = "cc20"
Expand All @@ -30,6 +33,16 @@ object BackupConstants {
const val BACKUP_WEB_EVENTS_FILE_NAME = "events.json"
const val BACKUP_WEB_CONVERSATIONS_FILE_NAME = "conversations.json"

/**
* list of accepted file names for the backup file
* this is used when extracting data from the zip file
*/
fun acceptedFileNames() = setOf(
BACKUP_USER_DB_NAME,
BACKUP_METADATA_FILE_NAME,
BACKUP_ENCRYPTED_FILE_NAME
)

fun createBackupFileName(userHandle: String?, timestampIso: String) = // file names cannot have special characters
"$BACKUP_FILE_NAME_PREFIX-$userHandle-${timestampIso.replace(":", "-")}.zip"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ import com.wire.kalium.cryptography.backup.BackupHeader.HeaderDecodingErrors.INV
import com.wire.kalium.cryptography.backup.Passphrase
import com.wire.kalium.cryptography.utils.ChaCha20Decryptor.decryptBackupFile
import com.wire.kalium.logic.data.asset.KaliumFileSystem
import com.wire.kalium.logic.data.id.CurrentClientIdProvider
import com.wire.kalium.logic.data.id.IdMapper
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.data.user.UserRepository
import com.wire.kalium.logic.di.MapperProvider
import com.wire.kalium.logic.data.id.CurrentClientIdProvider
import com.wire.kalium.logic.feature.backup.BackupConstants.BACKUP_ENCRYPTED_EXTENSION
import com.wire.kalium.logic.feature.backup.BackupConstants.acceptedFileNames
import com.wire.kalium.logic.feature.backup.BackupConstants.createBackupFileName
import com.wire.kalium.logic.feature.backup.RestoreBackupResult.BackupRestoreFailure.BackupIOFailure
import com.wire.kalium.logic.feature.backup.RestoreBackupResult.BackupRestoreFailure.DecryptionFailure
Expand All @@ -43,6 +44,7 @@ import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.functional.fold
import com.wire.kalium.logic.functional.mapLeft
import com.wire.kalium.logic.kaliumLogger
import com.wire.kalium.logic.util.ExtractFilesParam
import com.wire.kalium.logic.util.extractCompressedFile
import com.wire.kalium.logic.wrapStorageRequest
import com.wire.kalium.network.tools.KtxSerializer
Expand Down Expand Up @@ -192,23 +194,19 @@ internal class RestoreBackupUseCaseImpl(
when (decodingError) {
INVALID_USER_ID -> InvalidUserId
INVALID_VERSION -> IncompatibleBackup("The provided backup version is lower than the minimum supported version")
INVALID_FORMAT -> IncompatibleBackup("The provided backup format is not supported")
INVALID_FORMAT -> IncompatibleBackup("mappedDecodingError: The provided backup format is not supported")
}

private suspend fun checkIsValidEncryption(extractedBackupPath: Path): Either<Failure, Path> =
with(kaliumFileSystem) {
val encryptedFilePath = listDirectories(extractedBackupPath).firstOrNull {
it.name.substringAfterLast('.', "") == BACKUP_ENCRYPTED_EXTENSION
}
return if (encryptedFilePath == null) {
Either.Left(Failure(DecryptionFailure("No encrypted backup file found")))
} else {
Either.Right(encryptedFilePath)
}
listDirectories(extractedBackupPath)
.firstOrNull { it.name.endsWith(".$BACKUP_ENCRYPTED_EXTENSION") }?.let {
Either.Right(it)
} ?: Either.Left(Failure(DecryptionFailure("No encrypted backup file found")))
}

private fun extractFiles(inputSource: Source, extractedBackupRootPath: Path) =
extractCompressedFile(inputSource, extractedBackupRootPath, kaliumFileSystem)
extractCompressedFile(inputSource, extractedBackupRootPath, ExtractFilesParam.Only(acceptedFileNames()), kaliumFileSystem)

private suspend fun getDbPathAndImport(
extractedBackupRootPath: Path,
Expand All @@ -230,14 +228,15 @@ internal class RestoreBackupUseCaseImpl(
private suspend fun backupMetadata(extractedBackupPath: Path): Either<Failure, BackupMetadata> =
kaliumFileSystem.listDirectories(extractedBackupPath)
.firstOrNull { it.name == BackupConstants.BACKUP_METADATA_FILE_NAME }
?.let { metadataFile ->
.let { it ?: return Either.Left(Failure(IncompatibleBackup("backupMetadata: No metadata file found"))) }
.let { metadataFile ->
try {
kaliumFileSystem.source(metadataFile).buffer()
.use { Either.Right(KtxSerializer.json.decodeFromString(it.readUtf8())) }
} catch (e: SerializationException) {
Either.Left(Failure(IncompatibleBackup(e.toString())))
}
} ?: Either.Left(Failure(IncompatibleBackup("The provided backup format is not supported")))
}

private fun isValidBackupAuthor(metadata: BackupMetadata): Either<Failure, BackupMetadata> =
if (metadata.userId == userId.toString() || metadata.userId == userId.value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ internal class RestoreWebBackupUseCaseImpl(
if (metadata.version == "19") {
importWebBackup(backupRootPath, this)
} else {
Either.Left(IncompatibleBackup("The provided backup format is not supported"))
Either.Left(IncompatibleBackup("invoke: The provided backup format is not supported"))
}.fold({ RestoreBackupResult.Failure(it) }, { RestoreBackupResult.Success })
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,24 @@ import okio.Sink
import okio.Source

expect fun createCompressedFile(files: List<Pair<Source, String>>, outputSink: Sink): Either<CoreFailure, Long>
expect fun extractCompressedFile(inputSource: Source, outputRootPath: Path, fileSystem: KaliumFileSystem): Either<CoreFailure, Long>
expect fun extractCompressedFile(
inputSource: Source,
outputRootPath: Path,
param: ExtractFilesParam,
fileSystem: KaliumFileSystem
): Either<CoreFailure, Long>

expect fun checkIfCompressedFileContainsFileTypes(
compressedFilePath: Path,
fileSystem: KaliumFileSystem,
expectedFileExtensions: List<String>
): Either<CoreFailure, Map<String, Boolean>>

sealed interface ExtractFilesParam {
data object All : ExtractFilesParam
data class Only(val files: Set<String>) : ExtractFilesParam {
constructor(vararg files: String) : this(files.toSet())
}
}

expect inline fun <reified T> decodeBufferSequence(bufferedSource: BufferedSource): Sequence<T>
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import com.wire.kalium.logic.feature.backup.BackupConstants.BACKUP_METADATA_FILE
import com.wire.kalium.logic.framework.TestUser.SELF
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.test_util.TestKaliumDispatcher
import com.wire.kalium.logic.util.ExtractFilesParam
import com.wire.kalium.logic.util.IgnoreIOS
import com.wire.kalium.logic.util.SecurityHelper
import com.wire.kalium.logic.util.extractCompressedFile
Expand Down Expand Up @@ -107,7 +108,7 @@ class CreateBackupUseCaseTest {
with(fakeFileSystem) {
val extractedFilesPath = tempFilePath()
createDirectory(extractedFilesPath)
extractCompressedFile(source(result.backupFilePath), extractedFilesPath, fakeFileSystem)
extractCompressedFile(source(result.backupFilePath), extractedFilesPath, ExtractFilesParam.All, fakeFileSystem)

assertTrue(listDirectories(extractedFilesPath).firstOrNull { it.name == BACKUP_METADATA_FILE_NAME } != null)
val extractedDB = listDirectories(extractedFilesPath).firstOrNull {
Expand Down Expand Up @@ -174,7 +175,7 @@ class CreateBackupUseCaseTest {
with(fakeFileSystem) {
val extractedFilesPath = tempFilePath()
createDirectory(extractedFilesPath)
extractCompressedFile(source(result.backupFilePath), extractedFilesPath, fakeFileSystem)
extractCompressedFile(source(result.backupFilePath), extractedFilesPath, ExtractFilesParam.All, fakeFileSystem)
val extractedDBPath = listDirectories(extractedFilesPath).firstOrNull { it.name.contains(".cc20") }
assertEquals(BACKUP_ENCRYPTED_FILE_NAME, extractedDBPath?.name)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,24 @@ package com.wire.kalium.logic.feature.backup
import com.wire.kalium.cryptography.backup.BackupCoder
import com.wire.kalium.cryptography.backup.Passphrase
import com.wire.kalium.cryptography.utils.ChaCha20Encryptor.encryptBackupFile
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.clientPlatform
import com.wire.kalium.logic.data.asset.FakeKaliumFileSystem
import com.wire.kalium.logic.data.asset.KaliumFileSystem
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.user.SelfUser
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.data.user.UserRepository
import com.wire.kalium.logic.di.MapperProvider
import com.wire.kalium.logic.data.id.CurrentClientIdProvider
import com.wire.kalium.logic.feature.backup.BackupConstants.BACKUP_ENCRYPTED_FILE_NAME
import com.wire.kalium.logic.feature.backup.BackupConstants.BACKUP_USER_DB_NAME
import com.wire.kalium.logic.framework.TestUser
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.util.ExtractFilesParam
import com.wire.kalium.logic.util.IgnoreIOS
import com.wire.kalium.logic.util.createCompressedFile
import com.wire.kalium.logic.util.extractCompressedFile
import com.wire.kalium.persistence.backup.DatabaseImporter
import com.wire.kalium.persistence.db.UserDBSecret
import com.wire.kalium.util.DateTimeUtil
Expand All @@ -48,14 +54,15 @@ import kotlinx.coroutines.test.runTest
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import okio.Path
import okio.Source
import okio.buffer
import okio.use
import kotlin.test.Ignore
import kotlin.test.Test
import kotlin.test.assertIs
import kotlin.test.assertTrue

@IgnoreIOS // TODO re-enable when BackupUtils is implemented on Darwin
@OptIn(ExperimentalCoroutinesApi::class)
class RestoreBackupUseCaseTest {

private val fakeFileSystem = FakeKaliumFileSystem()
Expand All @@ -76,7 +83,7 @@ class RestoreBackupUseCaseTest {
val result = useCase(backupPath, "")

// then
assertTrue(result is RestoreBackupResult.Success)
assertIs<RestoreBackupResult.Success>(result)
verify(arrangement.databaseImporter)
.suspendFunction(arrangement.databaseImporter::importFromFile)
.with(any(), any())
Expand All @@ -99,8 +106,8 @@ class RestoreBackupUseCaseTest {
val result = useCase(backupPath, "")

// then
assertTrue(result is RestoreBackupResult.Failure)
assertTrue(result.failure is RestoreBackupResult.BackupRestoreFailure.InvalidUserId)
assertIs<RestoreBackupResult.Failure>(result)
assertIs<RestoreBackupResult.BackupRestoreFailure.InvalidUserId>(result.failure)

verify(arrangement.databaseImporter)
.suspendFunction(arrangement.databaseImporter::importFromFile)
Expand All @@ -122,8 +129,8 @@ class RestoreBackupUseCaseTest {
val result = useCase(backupPath, "")

// then
assertTrue(result is RestoreBackupResult.Failure)
assertTrue(result.failure is RestoreBackupResult.BackupRestoreFailure.IncompatibleBackup)
assertIs<RestoreBackupResult.Failure>(result)
assertIs<RestoreBackupResult.BackupRestoreFailure.IncompatibleBackup>(result.failure)

verify(arrangement.databaseImporter)
.suspendFunction(arrangement.databaseImporter::importFromFile)
Expand All @@ -148,7 +155,7 @@ class RestoreBackupUseCaseTest {
val result = useCase(backupPath, password)

// then
assertTrue(result is RestoreBackupResult.Success)
assertIs<RestoreBackupResult.Success>(result)

verify(arrangement.databaseImporter)
.suspendFunction(arrangement.databaseImporter::importFromFile)
Expand Down Expand Up @@ -225,8 +232,8 @@ class RestoreBackupUseCaseTest {
val result = useCase(backupPath, password)

// then
assertTrue(result is RestoreBackupResult.Failure)
assertTrue(result.failure is RestoreBackupResult.BackupRestoreFailure.BackupIOFailure)
assertIs<RestoreBackupResult.Failure>(result)
assertIs<RestoreBackupResult.BackupRestoreFailure.BackupIOFailure>(result.failure)
verify(arrangement.databaseImporter)
.suspendFunction(arrangement.databaseImporter::importFromFile)
.with(any(), any())
Expand All @@ -247,7 +254,7 @@ class RestoreBackupUseCaseTest {
@Mock
val userRepository = mock(classOf<UserRepository>())

val fakeDBFileName = "fakeDBFile.db"
val fakeDBFileName = BACKUP_USER_DB_NAME
private val selfUserId = currentTestUserId
private val fakeDBData = fakeDBFileName.encodeToByteArray()
private val idMapper = MapperProvider.idMapper()
Expand Down Expand Up @@ -317,7 +324,7 @@ class RestoreBackupUseCaseTest {

suspend fun withEncryptedBackup(path: Path, userId: UserId, password: String) = apply {
with(fakeFileSystem) {
val encryptedBackupPath = fakeFileSystem.tempFilePath("backup.cc20")
val encryptedBackupPath = fakeFileSystem.tempFilePath(BACKUP_ENCRYPTED_FILE_NAME)
createEncryptedBackup(encryptedBackupPath, userId, password)
val outputSink = sink(path)
createCompressedFile(listOf(source(encryptedBackupPath) to encryptedBackupPath.name), outputSink)
Expand Down Expand Up @@ -352,6 +359,22 @@ class RestoreBackupUseCaseTest {
.thenReturn(selfUser)
}

lateinit var extractZipFile: (
inputSource: Source,
outputRootPath: Path,
param: ExtractFilesParam,
fileSystem: KaliumFileSystem
) -> Either<CoreFailure, Long>

fun withSuccessfulExtractZipFile() = apply {
extractZipFile = { _, _, _, _ -> Either.Right(10L) }
}

fun withDefaultExtractZipFile() = apply {
extractZipFile = ::extractCompressedFile
}


fun arrange() = this to RestoreBackupUseCaseImpl(
databaseImporter = databaseImporter,
kaliumFileSystem = fakeFileSystem,
Expand Down

0 comments on commit 21cb550

Please sign in to comment.