Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: validate the list of allowed file names when extracting files form zip folder #2221

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,19 @@ import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.data.asset.KaliumFileSystem
import com.wire.kalium.logic.functional.Either
import okio.BufferedSource
import okio.Source
import okio.Sink
import okio.Path
import okio.Sink
import okio.Source

actual fun createCompressedFile(files: List<Pair<Source, String>>, outputSink: Sink): Either<CoreFailure, Long> =
TODO("Implement own iOS compression method")

actual fun extractCompressedFile(inputSource: Source, outputRootPath: Path, fileSystem: KaliumFileSystem): Either<CoreFailure, Long> =
actual fun extractCompressedFile(
inputSource: Source,
outputRootPath: Path,
param: ExtractFilesParam,
fileSystem: KaliumFileSystem
): Either<CoreFailure, Long> =
TODO("Implement own iOS compression method")

actual fun checkIfCompressedFileContainsFileTypes(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,22 +68,45 @@ private fun addToCompressedFile(zipOutputStream: ZipOutputStream, fileSource: So
}

@Suppress("TooGenericExceptionCaught", "NestedBlockDepth")
actual fun extractCompressedFile(inputSource: Source, outputRootPath: Path, fileSystem: KaliumFileSystem): Either<CoreFailure, Long> = try {
actual fun extractCompressedFile(
inputSource: Source,
outputRootPath: Path,
param: ExtractFilesParam,
fileSystem: KaliumFileSystem
): Either<CoreFailure, Long> = try {
var totalExtractedFilesSize = 0L
ZipInputStream(inputSource.buffer().inputStream()).use { zipInputStream ->
var entry: ZipEntry? = zipInputStream.nextEntry
while (entry != null) {
readCompressedEntry(zipInputStream, outputRootPath, fileSystem, entry).let {
totalExtractedFilesSize += it.first
entry = it.second
totalExtractedFilesSize += when (param) {
is ExtractFilesParam.All -> readCompressedEntry(zipInputStream, outputRootPath, fileSystem, entry)
is ExtractFilesParam.Only -> readAndExtractIfMatch(zipInputStream, outputRootPath, fileSystem, entry, param.files)
}
zipInputStream.closeEntry()
entry = zipInputStream.nextEntry
}
}
Either.Right(totalExtractedFilesSize)
} catch (e: Exception) {
Either.Left(StorageFailure.Generic(RuntimeException("There was an error trying to extract the provided compressed file", e)))
}

private fun readAndExtractIfMatch(
zipInputStream: ZipInputStream,
outputRootPath: Path,
fileSystem: KaliumFileSystem,
entry: ZipEntry,
fileNames: Set<String>
): Long {
return entry.name.let {
if (fileNames.contains(it)) {
readCompressedEntry(zipInputStream, outputRootPath, fileSystem, entry)
} else {
0L
}
}
}

@Suppress("TooGenericExceptionCaught", "NestedBlockDepth")
actual fun checkIfCompressedFileContainsFileTypes(
compressedFilePath: Path,
Expand Down Expand Up @@ -121,11 +144,7 @@ private fun readCompressedEntry(
outputRootPath: Path,
fileSystem: KaliumFileSystem,
entry: ZipEntry
): Pair<Long, ZipEntry?> {
if (isInvalidEntryPathDestination(entry.name)) {
throw RuntimeException("The provided zip file is invalid or has invalid data")
}

): Long {
var totalExtractedFilesSize = 0L
var byteCount: Int
val entryPathName = "$outputRootPath/${entry.name}"
Expand All @@ -137,8 +156,7 @@ private fun readCompressedEntry(
}
output.write(zipInputStream.readBytes())
}
zipInputStream.closeEntry()
return totalExtractedFilesSize to zipInputStream.nextEntry
return totalExtractedFilesSize
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ package com.wire.kalium.logic.feature.backup
object BackupConstants {
const val BACKUP_FILE_NAME_PREFIX = "WBX"
const val BACKUP_ENCRYPTED_FILE_NAME = "user-backup.cc20"

// BACKUP_METADATA_FILE_NAME and BACKUP_USER_DB_NAME must not be changed
// if there is a need to change them, please create a new file names and add it to the list of acceptedFileNames()
const val BACKUP_USER_DB_NAME = "user-backup-database.db"
const val BACKUP_METADATA_FILE_NAME = "export.json"
const val BACKUP_ENCRYPTED_EXTENSION = "cc20"
Expand All @@ -30,6 +33,16 @@ object BackupConstants {
const val BACKUP_WEB_EVENTS_FILE_NAME = "events.json"
const val BACKUP_WEB_CONVERSATIONS_FILE_NAME = "conversations.json"

/**
* list of accepted file names for the backup file
* this is used when extracting data from the zip file
*/
fun acceptedFileNames() = setOf(
BACKUP_USER_DB_NAME,
BACKUP_METADATA_FILE_NAME,
BACKUP_ENCRYPTED_FILE_NAME
)

fun createBackupFileName(userHandle: String?, timestampIso: String) = // file names cannot have special characters
"$BACKUP_FILE_NAME_PREFIX-$userHandle-${timestampIso.replace(":", "-")}.zip"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ import com.wire.kalium.cryptography.backup.BackupHeader.HeaderDecodingErrors.INV
import com.wire.kalium.cryptography.backup.Passphrase
import com.wire.kalium.cryptography.utils.ChaCha20Decryptor.decryptBackupFile
import com.wire.kalium.logic.data.asset.KaliumFileSystem
import com.wire.kalium.logic.data.id.CurrentClientIdProvider
import com.wire.kalium.logic.data.id.IdMapper
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.data.user.UserRepository
import com.wire.kalium.logic.di.MapperProvider
import com.wire.kalium.logic.data.id.CurrentClientIdProvider
import com.wire.kalium.logic.feature.backup.BackupConstants.BACKUP_ENCRYPTED_EXTENSION
import com.wire.kalium.logic.feature.backup.BackupConstants.acceptedFileNames
import com.wire.kalium.logic.feature.backup.BackupConstants.createBackupFileName
import com.wire.kalium.logic.feature.backup.RestoreBackupResult.BackupRestoreFailure.BackupIOFailure
import com.wire.kalium.logic.feature.backup.RestoreBackupResult.BackupRestoreFailure.DecryptionFailure
Expand All @@ -43,6 +44,7 @@ import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.functional.fold
import com.wire.kalium.logic.functional.mapLeft
import com.wire.kalium.logic.kaliumLogger
import com.wire.kalium.logic.util.ExtractFilesParam
import com.wire.kalium.logic.util.extractCompressedFile
import com.wire.kalium.logic.wrapStorageRequest
import com.wire.kalium.network.tools.KtxSerializer
Expand Down Expand Up @@ -192,23 +194,19 @@ internal class RestoreBackupUseCaseImpl(
when (decodingError) {
INVALID_USER_ID -> InvalidUserId
INVALID_VERSION -> IncompatibleBackup("The provided backup version is lower than the minimum supported version")
INVALID_FORMAT -> IncompatibleBackup("The provided backup format is not supported")
INVALID_FORMAT -> IncompatibleBackup("mappedDecodingError: The provided backup format is not supported")
}

private suspend fun checkIsValidEncryption(extractedBackupPath: Path): Either<Failure, Path> =
with(kaliumFileSystem) {
val encryptedFilePath = listDirectories(extractedBackupPath).firstOrNull {
it.name.substringAfterLast('.', "") == BACKUP_ENCRYPTED_EXTENSION
}
return if (encryptedFilePath == null) {
Either.Left(Failure(DecryptionFailure("No encrypted backup file found")))
} else {
Either.Right(encryptedFilePath)
}
listDirectories(extractedBackupPath)
.firstOrNull { it.name.endsWith(".$BACKUP_ENCRYPTED_EXTENSION") }?.let {
Either.Right(it)
} ?: Either.Left(Failure(DecryptionFailure("No encrypted backup file found")))
}

private fun extractFiles(inputSource: Source, extractedBackupRootPath: Path) =
extractCompressedFile(inputSource, extractedBackupRootPath, kaliumFileSystem)
extractCompressedFile(inputSource, extractedBackupRootPath, ExtractFilesParam.Only(acceptedFileNames()), kaliumFileSystem)

private suspend fun getDbPathAndImport(
extractedBackupRootPath: Path,
Expand All @@ -230,14 +228,15 @@ internal class RestoreBackupUseCaseImpl(
private suspend fun backupMetadata(extractedBackupPath: Path): Either<Failure, BackupMetadata> =
kaliumFileSystem.listDirectories(extractedBackupPath)
.firstOrNull { it.name == BackupConstants.BACKUP_METADATA_FILE_NAME }
?.let { metadataFile ->
.let { it ?: return Either.Left(Failure(IncompatibleBackup("backupMetadata: No metadata file found"))) }
.let { metadataFile ->
try {
kaliumFileSystem.source(metadataFile).buffer()
.use { Either.Right(KtxSerializer.json.decodeFromString(it.readUtf8())) }
} catch (e: SerializationException) {
Either.Left(Failure(IncompatibleBackup(e.toString())))
}
} ?: Either.Left(Failure(IncompatibleBackup("The provided backup format is not supported")))
}

private fun isValidBackupAuthor(metadata: BackupMetadata): Either<Failure, BackupMetadata> =
if (metadata.userId == userId.toString() || metadata.userId == userId.value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ internal class RestoreWebBackupUseCaseImpl(
if (metadata.version == "19") {
importWebBackup(backupRootPath, this)
} else {
Either.Left(IncompatibleBackup("The provided backup format is not supported"))
Either.Left(IncompatibleBackup("invoke: The provided backup format is not supported"))
}.fold({ RestoreBackupResult.Failure(it) }, { RestoreBackupResult.Success })
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,24 @@ import okio.Sink
import okio.Source

expect fun createCompressedFile(files: List<Pair<Source, String>>, outputSink: Sink): Either<CoreFailure, Long>
expect fun extractCompressedFile(inputSource: Source, outputRootPath: Path, fileSystem: KaliumFileSystem): Either<CoreFailure, Long>
expect fun extractCompressedFile(
inputSource: Source,
outputRootPath: Path,
param: ExtractFilesParam,
fileSystem: KaliumFileSystem
): Either<CoreFailure, Long>

expect fun checkIfCompressedFileContainsFileTypes(
compressedFilePath: Path,
fileSystem: KaliumFileSystem,
expectedFileExtensions: List<String>
): Either<CoreFailure, Map<String, Boolean>>

sealed interface ExtractFilesParam {
data object All : ExtractFilesParam
data class Only(val files: Set<String>) : ExtractFilesParam {
constructor(vararg files: String) : this(files.toSet())
}
}

expect inline fun <reified T> decodeBufferSequence(bufferedSource: BufferedSource): Sequence<T>
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import com.wire.kalium.logic.feature.backup.BackupConstants.BACKUP_METADATA_FILE
import com.wire.kalium.logic.framework.TestUser.SELF
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.test_util.TestKaliumDispatcher
import com.wire.kalium.logic.util.ExtractFilesParam
import com.wire.kalium.logic.util.IgnoreIOS
import com.wire.kalium.logic.util.SecurityHelper
import com.wire.kalium.logic.util.extractCompressedFile
Expand Down Expand Up @@ -107,7 +108,7 @@ class CreateBackupUseCaseTest {
with(fakeFileSystem) {
val extractedFilesPath = tempFilePath()
createDirectory(extractedFilesPath)
extractCompressedFile(source(result.backupFilePath), extractedFilesPath, fakeFileSystem)
extractCompressedFile(source(result.backupFilePath), extractedFilesPath, ExtractFilesParam.All, fakeFileSystem)

assertTrue(listDirectories(extractedFilesPath).firstOrNull { it.name == BACKUP_METADATA_FILE_NAME } != null)
val extractedDB = listDirectories(extractedFilesPath).firstOrNull {
Expand Down Expand Up @@ -174,7 +175,7 @@ class CreateBackupUseCaseTest {
with(fakeFileSystem) {
val extractedFilesPath = tempFilePath()
createDirectory(extractedFilesPath)
extractCompressedFile(source(result.backupFilePath), extractedFilesPath, fakeFileSystem)
extractCompressedFile(source(result.backupFilePath), extractedFilesPath, ExtractFilesParam.All, fakeFileSystem)
val extractedDBPath = listDirectories(extractedFilesPath).firstOrNull { it.name.contains(".cc20") }
assertEquals(BACKUP_ENCRYPTED_FILE_NAME, extractedDBPath?.name)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,24 @@ package com.wire.kalium.logic.feature.backup
import com.wire.kalium.cryptography.backup.BackupCoder
import com.wire.kalium.cryptography.backup.Passphrase
import com.wire.kalium.cryptography.utils.ChaCha20Encryptor.encryptBackupFile
import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.clientPlatform
import com.wire.kalium.logic.data.asset.FakeKaliumFileSystem
import com.wire.kalium.logic.data.asset.KaliumFileSystem
import com.wire.kalium.logic.data.conversation.ClientId
import com.wire.kalium.logic.data.user.SelfUser
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.data.user.UserRepository
import com.wire.kalium.logic.di.MapperProvider
import com.wire.kalium.logic.data.id.CurrentClientIdProvider
import com.wire.kalium.logic.feature.backup.BackupConstants.BACKUP_ENCRYPTED_FILE_NAME
import com.wire.kalium.logic.feature.backup.BackupConstants.BACKUP_USER_DB_NAME
import com.wire.kalium.logic.framework.TestUser
import com.wire.kalium.logic.functional.Either
import com.wire.kalium.logic.util.ExtractFilesParam
import com.wire.kalium.logic.util.IgnoreIOS
import com.wire.kalium.logic.util.createCompressedFile
import com.wire.kalium.logic.util.extractCompressedFile
import com.wire.kalium.persistence.backup.DatabaseImporter
import com.wire.kalium.persistence.db.UserDBSecret
import com.wire.kalium.util.DateTimeUtil
Expand All @@ -48,14 +54,15 @@ import kotlinx.coroutines.test.runTest
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import okio.Path
import okio.Source
import okio.buffer
import okio.use
import kotlin.test.Ignore
import kotlin.test.Test
import kotlin.test.assertIs
import kotlin.test.assertTrue

@IgnoreIOS // TODO re-enable when BackupUtils is implemented on Darwin
@OptIn(ExperimentalCoroutinesApi::class)
class RestoreBackupUseCaseTest {

private val fakeFileSystem = FakeKaliumFileSystem()
Expand All @@ -76,7 +83,7 @@ class RestoreBackupUseCaseTest {
val result = useCase(backupPath, "")

// then
assertTrue(result is RestoreBackupResult.Success)
assertIs<RestoreBackupResult.Success>(result)
verify(arrangement.databaseImporter)
.suspendFunction(arrangement.databaseImporter::importFromFile)
.with(any(), any())
Expand All @@ -99,8 +106,8 @@ class RestoreBackupUseCaseTest {
val result = useCase(backupPath, "")

// then
assertTrue(result is RestoreBackupResult.Failure)
assertTrue(result.failure is RestoreBackupResult.BackupRestoreFailure.InvalidUserId)
assertIs<RestoreBackupResult.Failure>(result)
assertIs<RestoreBackupResult.BackupRestoreFailure.InvalidUserId>(result.failure)

verify(arrangement.databaseImporter)
.suspendFunction(arrangement.databaseImporter::importFromFile)
Expand All @@ -122,8 +129,8 @@ class RestoreBackupUseCaseTest {
val result = useCase(backupPath, "")

// then
assertTrue(result is RestoreBackupResult.Failure)
assertTrue(result.failure is RestoreBackupResult.BackupRestoreFailure.IncompatibleBackup)
assertIs<RestoreBackupResult.Failure>(result)
assertIs<RestoreBackupResult.BackupRestoreFailure.IncompatibleBackup>(result.failure)

verify(arrangement.databaseImporter)
.suspendFunction(arrangement.databaseImporter::importFromFile)
Expand All @@ -148,7 +155,7 @@ class RestoreBackupUseCaseTest {
val result = useCase(backupPath, password)

// then
assertTrue(result is RestoreBackupResult.Success)
assertIs<RestoreBackupResult.Success>(result)

verify(arrangement.databaseImporter)
.suspendFunction(arrangement.databaseImporter::importFromFile)
Expand Down Expand Up @@ -225,8 +232,8 @@ class RestoreBackupUseCaseTest {
val result = useCase(backupPath, password)

// then
assertTrue(result is RestoreBackupResult.Failure)
assertTrue(result.failure is RestoreBackupResult.BackupRestoreFailure.BackupIOFailure)
assertIs<RestoreBackupResult.Failure>(result)
assertIs<RestoreBackupResult.BackupRestoreFailure.BackupIOFailure>(result.failure)
verify(arrangement.databaseImporter)
.suspendFunction(arrangement.databaseImporter::importFromFile)
.with(any(), any())
Expand All @@ -247,7 +254,7 @@ class RestoreBackupUseCaseTest {
@Mock
val userRepository = mock(classOf<UserRepository>())

val fakeDBFileName = "fakeDBFile.db"
val fakeDBFileName = BACKUP_USER_DB_NAME
private val selfUserId = currentTestUserId
private val fakeDBData = fakeDBFileName.encodeToByteArray()
private val idMapper = MapperProvider.idMapper()
Expand Down Expand Up @@ -317,7 +324,7 @@ class RestoreBackupUseCaseTest {

suspend fun withEncryptedBackup(path: Path, userId: UserId, password: String) = apply {
with(fakeFileSystem) {
val encryptedBackupPath = fakeFileSystem.tempFilePath("backup.cc20")
val encryptedBackupPath = fakeFileSystem.tempFilePath(BACKUP_ENCRYPTED_FILE_NAME)
createEncryptedBackup(encryptedBackupPath, userId, password)
val outputSink = sink(path)
createCompressedFile(listOf(source(encryptedBackupPath) to encryptedBackupPath.name), outputSink)
Expand Down Expand Up @@ -352,6 +359,22 @@ class RestoreBackupUseCaseTest {
.thenReturn(selfUser)
}

lateinit var extractZipFile: (
inputSource: Source,
outputRootPath: Path,
param: ExtractFilesParam,
fileSystem: KaliumFileSystem
) -> Either<CoreFailure, Long>

fun withSuccessfulExtractZipFile() = apply {
extractZipFile = { _, _, _, _ -> Either.Right(10L) }
}

fun withDefaultExtractZipFile() = apply {
extractZipFile = ::extractCompressedFile
}


fun arrange() = this to RestoreBackupUseCaseImpl(
databaseImporter = databaseImporter,
kaliumFileSystem = fakeFileSystem,
Expand Down
Loading