Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: validate the list of allowed file names when extracting files form zip folder #2221

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,19 @@ import com.wire.kalium.logic.CoreFailure
import com.wire.kalium.logic.data.asset.KaliumFileSystem
import com.wire.kalium.logic.functional.Either
import okio.BufferedSource
import okio.Source
import okio.Sink
import okio.Path
import okio.Sink
import okio.Source

actual fun createCompressedFile(files: List<Pair<Source, String>>, outputSink: Sink): Either<CoreFailure, Long> =
TODO("Implement own iOS compression method")

actual fun extractCompressedFile(inputSource: Source, outputRootPath: Path, fileSystem: KaliumFileSystem): Either<CoreFailure, Long> =
actual fun extractCompressedFile(
inputSource: Source,
outputRootPath: Path,
param: ExtractFilesParam,
fileSystem: KaliumFileSystem
): Either<CoreFailure, Long> =
TODO("Implement own iOS compression method")

actual fun checkIfCompressedFileContainsFileTypes(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,22 +68,45 @@ private fun addToCompressedFile(zipOutputStream: ZipOutputStream, fileSource: So
}

@Suppress("TooGenericExceptionCaught", "NestedBlockDepth")
actual fun extractCompressedFile(inputSource: Source, outputRootPath: Path, fileSystem: KaliumFileSystem): Either<CoreFailure, Long> = try {
actual fun extractCompressedFile(
inputSource: Source,
outputRootPath: Path,
param: ExtractFilesParam,
fileSystem: KaliumFileSystem
): Either<CoreFailure, Long> = try {
var totalExtractedFilesSize = 0L
ZipInputStream(inputSource.buffer().inputStream()).use { zipInputStream ->
var entry: ZipEntry? = zipInputStream.nextEntry
while (entry != null) {
readCompressedEntry(zipInputStream, outputRootPath, fileSystem, entry).let {
totalExtractedFilesSize += it.first
entry = it.second
totalExtractedFilesSize += when (param) {
is ExtractFilesParam.All -> readCompressedEntry(zipInputStream, outputRootPath, fileSystem, entry)
is ExtractFilesParam.Only -> readAndExtractIfMatch(zipInputStream, outputRootPath, fileSystem, entry, param.files)
}
zipInputStream.closeEntry()
entry = zipInputStream.nextEntry
}
}
Either.Right(totalExtractedFilesSize)
} catch (e: Exception) {
Either.Left(StorageFailure.Generic(RuntimeException("There was an error trying to extract the provided compressed file", e)))
}

private fun readAndExtractIfMatch(
zipInputStream: ZipInputStream,
outputRootPath: Path,
fileSystem: KaliumFileSystem,
entry: ZipEntry,
fileNames: Set<String>
): Long {
return entry.name.let {
if (fileNames.contains(it)) {
readCompressedEntry(zipInputStream, outputRootPath, fileSystem, entry)
} else {
0L
}
}
}

@Suppress("TooGenericExceptionCaught", "NestedBlockDepth")
actual fun checkIfCompressedFileContainsFileTypes(
compressedFilePath: Path,
Expand Down Expand Up @@ -121,11 +144,7 @@ private fun readCompressedEntry(
outputRootPath: Path,
fileSystem: KaliumFileSystem,
entry: ZipEntry
): Pair<Long, ZipEntry?> {
if (isInvalidEntryPathDestination(entry.name)) {
throw RuntimeException("The provided zip file is invalid or has invalid data")
}

): Long {
var totalExtractedFilesSize = 0L
var byteCount: Int
val entryPathName = "$outputRootPath/${entry.name}"
Expand All @@ -137,8 +156,7 @@ private fun readCompressedEntry(
}
output.write(zipInputStream.readBytes())
}
zipInputStream.closeEntry()
return totalExtractedFilesSize to zipInputStream.nextEntry
return totalExtractedFilesSize
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ package com.wire.kalium.logic.feature.backup
object BackupConstants {
const val BACKUP_FILE_NAME_PREFIX = "WBX"
const val BACKUP_ENCRYPTED_FILE_NAME = "user-backup.cc20"

// BACKUP_METADATA_FILE_NAME and BACKUP_USER_DB_NAME must not be changed
// if there is a need to change them, please create a new file names and add it to the list of acceptedFileNames()
const val BACKUP_USER_DB_NAME = "user-backup-database.db"
const val BACKUP_METADATA_FILE_NAME = "export.json"
const val BACKUP_ENCRYPTED_EXTENSION = "cc20"
Expand All @@ -30,6 +33,15 @@ object BackupConstants {
const val BACKUP_WEB_EVENTS_FILE_NAME = "events.json"
const val BACKUP_WEB_CONVERSATIONS_FILE_NAME = "conversations.json"

/**
* list of accepted file names for the backup file
* this is used when extracting data from the zip file
*/
fun acceptedFileNames() = setOf(
BACKUP_USER_DB_NAME,
BACKUP_METADATA_FILE_NAME
)

fun createBackupFileName(userHandle: String?, timestampIso: String) = // file names cannot have special characters
"$BACKUP_FILE_NAME_PREFIX-$userHandle-${timestampIso.replace(":", "-")}.zip"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ import com.wire.kalium.cryptography.backup.BackupHeader.HeaderDecodingErrors.INV
import com.wire.kalium.cryptography.backup.Passphrase
import com.wire.kalium.cryptography.utils.ChaCha20Decryptor.decryptBackupFile
import com.wire.kalium.logic.data.asset.KaliumFileSystem
import com.wire.kalium.logic.data.id.CurrentClientIdProvider
import com.wire.kalium.logic.data.id.IdMapper
import com.wire.kalium.logic.data.user.UserId
import com.wire.kalium.logic.data.user.UserRepository
import com.wire.kalium.logic.di.MapperProvider
import com.wire.kalium.logic.data.id.CurrentClientIdProvider
import com.wire.kalium.logic.feature.backup.BackupConstants.BACKUP_ENCRYPTED_EXTENSION
import com.wire.kalium.logic.feature.backup.BackupConstants.acceptedFileNames
import com.wire.kalium.logic.feature.backup.BackupConstants.createBackupFileName
import com.wire.kalium.logic.feature.backup.RestoreBackupResult.BackupRestoreFailure.BackupIOFailure
import com.wire.kalium.logic.feature.backup.RestoreBackupResult.BackupRestoreFailure.DecryptionFailure
Expand All @@ -43,6 +44,7 @@ import com.wire.kalium.logic.functional.flatMap
import com.wire.kalium.logic.functional.fold
import com.wire.kalium.logic.functional.mapLeft
import com.wire.kalium.logic.kaliumLogger
import com.wire.kalium.logic.util.ExtractFilesParam
import com.wire.kalium.logic.util.extractCompressedFile
import com.wire.kalium.logic.wrapStorageRequest
import com.wire.kalium.network.tools.KtxSerializer
Expand Down Expand Up @@ -208,7 +210,11 @@ internal class RestoreBackupUseCaseImpl(
}

private fun extractFiles(inputSource: Source, extractedBackupRootPath: Path) =
extractCompressedFile(inputSource, extractedBackupRootPath, kaliumFileSystem)
extractCompressedFile(
inputSource, extractedBackupRootPath, ExtractFilesParam.Only(
acceptedFileNames()
), kaliumFileSystem
)

private suspend fun getDbPathAndImport(
extractedBackupRootPath: Path,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,24 @@ import okio.Sink
import okio.Source

expect fun createCompressedFile(files: List<Pair<Source, String>>, outputSink: Sink): Either<CoreFailure, Long>
expect fun extractCompressedFile(inputSource: Source, outputRootPath: Path, fileSystem: KaliumFileSystem): Either<CoreFailure, Long>
expect fun extractCompressedFile(
inputSource: Source,
outputRootPath: Path,
param: ExtractFilesParam,
fileSystem: KaliumFileSystem
): Either<CoreFailure, Long>

expect fun checkIfCompressedFileContainsFileTypes(
compressedFilePath: Path,
fileSystem: KaliumFileSystem,
expectedFileExtensions: List<String>
): Either<CoreFailure, Map<String, Boolean>>

sealed interface ExtractFilesParam {
data object All : ExtractFilesParam
data class Only(val files: Set<String>) : ExtractFilesParam {
constructor(vararg files: String) : this(files.toSet())
}
}

expect inline fun <reified T> decodeBufferSequence(bufferedSource: BufferedSource): Sequence<T>
Loading