Skip to content

Commit

Permalink
add version to frameworks
Browse files Browse the repository at this point in the history
  • Loading branch information
gordinmitya committed Nov 27, 2022
1 parent 27a5348 commit ad75056
Show file tree
Hide file tree
Showing 12 changed files with 38 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class MainActivity : AppCompatActivity() {
delay(sleep)
configurations.forEach { configuration ->
logger.spoiler("Running with ${configuration.inferenceFramework} on ${configuration.inferenceType}")
val result =
val result: InferenceResult =
if (!configuration.inferenceType.isSupported) {
NotSupportedResult(ConfigurationEntity(configuration))
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ sealed class InferenceResult : Parcelable {

override fun toString(): String = configuration.run {
val taskLetter = configuration.task[0].uppercaseChar()
return "$taskLetter $frameworkName $inferenceType"
return "$taskLetter $frameworkName-$frameworkVersion $inferenceType"
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package ru.gordinmitya.dnnbenchmark.model

import android.os.Parcelable
import kotlinx.android.parcel.Parcelize
import kotlinx.parcelize.Parcelize
import ru.gordinmitya.common.Configuration
import ru.gordinmitya.dnnbenchmark.App
import java.util.*

@Parcelize
class ConfigurationEntity(
val frameworkName: String,
val frameworkVersion: String,
val frameworkClassName: String,
val inferenceType: String,
val isSupported: Boolean,
Expand All @@ -18,10 +19,11 @@ class ConfigurationEntity(

constructor(configuration: Configuration) : this(
frameworkName = configuration.inferenceFramework.name,
frameworkVersion = configuration.inferenceFramework.version.toString(),
frameworkClassName = App.describeFramework(configuration.inferenceFramework.javaClass.kotlin),
inferenceType = configuration.inferenceType.name,
isSupported = configuration.inferenceType.isSupported,
task = configuration.model.task.name.toLowerCase(Locale.ROOT),
task = configuration.model.task.name.lowercase(Locale.ROOT),
model = configuration.model.name
)

Expand All @@ -35,4 +37,4 @@ class ConfigurationEntity(
}
return Configuration(frameworkInstance, type, model)
}
}
}
12 changes: 11 additions & 1 deletion common/src/main/java/ru/gordinmitya/common/InferenceFramework.kt
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
package ru.gordinmitya.common

class Version(
val name: String,
val commitHash: String? = null,
) {
override fun toString(): String {
if (commitHash == null) return name
return "$name ($commitHash)"
}
}

abstract class InferenceFramework(
val name: String,
val description: String
val version: Version
) {
abstract fun getModels(): List<Model>
abstract fun getInferenceTypes(): List<InferenceType>
Expand Down
4 changes: 2 additions & 2 deletions common/src/main/java/ru/gordinmitya/common/NativeHelper.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import android.util.Log
import java.io.*

object NativeHelper {
private fun processName(): String? {
private fun processName(): String {
val path = "/proc/" + Process.myPid() + "/cmdline"
BufferedReader(InputStreamReader(FileInputStream(path), "iso-8859-1")).use { reader ->
var c: Int
Expand All @@ -26,4 +26,4 @@ object NativeHelper {
)
System.loadLibrary(libname!!)
}
}
}
11 changes: 2 additions & 9 deletions common/src/test/java/ru/gordinmitya/common/ConfigurationTest.kt
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
package ru.gordinmitya.common

import android.content.Context
import io.mockk.mockk
import org.junit.jupiter.api.Assertions.assertDoesNotThrow
import org.junit.jupiter.api.Assertions.assertThrows
import org.junit.jupiter.api.Test
import ru.gordinmitya.common.classification.Classifier

internal class ConfigurationTest {
private fun createFramework(): InferenceFramework {
Expand All @@ -17,15 +15,10 @@ internal class ConfigurationTest {
mockk(),
mockk()
)
return object : InferenceFramework("", "") {
return object : InferenceFramework("", Version("?")) {
override fun getInferenceTypes(): List<InferenceType> = types

override fun getModels(): List<Model> = models

override fun createClassifier(
context: Context,
configuration: Configuration
): Classifier = mockk()
}
}

Expand Down Expand Up @@ -55,4 +48,4 @@ internal class ConfigurationTest {
Configuration(framework, otherInferenceType, framework.getModels()[0])
}
}
}
}
4 changes: 2 additions & 2 deletions mnn/src/main/java/ru/gordinmitya/mnn/MNNFramework.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import ru.gordinmitya.common.segmentation.SegmentationFramework
import ru.gordinmitya.common.segmentation.SegmentationModel
import ru.gordinmitya.common.segmentation.Segmentator

class MNNFramework : InferenceFramework("MNN", "by Alibaba"), ClassificationFramework,
class MNNFramework : InferenceFramework("MNN", Version("?")), ClassificationFramework,
SegmentationFramework {
private val TYPES = arrayListOf(
CPU,
Expand Down Expand Up @@ -42,4 +42,4 @@ class MNNFramework : InferenceFramework("MNN", "by Alibaba"), ClassificationFram
}

override fun getDataOrder(): DataOrder = DataOrder.NCWH
}
}
9 changes: 3 additions & 6 deletions ncnn/src/main/java/ru/gordinmitya/ncnn/NCNNFramework.kt
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
package ru.gordinmitya.ncnn

import android.content.Context
import ru.gordinmitya.common.Configuration
import ru.gordinmitya.common.InferenceFramework
import ru.gordinmitya.common.InferenceType
import ru.gordinmitya.common.Model
import ru.gordinmitya.common.*
import ru.gordinmitya.common.classification.ClassificationFramework
import ru.gordinmitya.common.classification.Classifier

class NCNNFramework : InferenceFramework("NCNN", "by Tencent"), ClassificationFramework {
class NCNNFramework : InferenceFramework("NCNN", Version("?")), ClassificationFramework {
private val types = arrayListOf(
NCNN_CPU,
NCNN_VULKAN
Expand All @@ -27,4 +24,4 @@ class NCNNFramework : InferenceFramework("NCNN", "by Tencent"), ClassificationFr

return NCNNClassifier(context, configuration, convertedModel, inferenceType)
}
}
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
package ru.gordinmitya.onnxruntime

import android.content.Context
import ru.gordinmitya.common.Configuration
import ru.gordinmitya.common.InferenceFramework
import ru.gordinmitya.common.InferenceType
import ru.gordinmitya.common.Model
import ru.gordinmitya.common.*
import ru.gordinmitya.common.classification.ClassificationFramework
import ru.gordinmitya.common.classification.Classifier

class ONNXFramework : InferenceFramework("onnxruntime", "by Microsoft"), ClassificationFramework {
class ONNXFramework : InferenceFramework("onnxruntime", Version("?")), ClassificationFramework {
private val types = arrayListOf(
ONNX_CPU,
ONNX_NNAPI
Expand All @@ -27,4 +24,4 @@ class ONNXFramework : InferenceFramework("onnxruntime", "by Microsoft"), Classif

return ONNXClassifier(context, configuration, convertedModel, inferenceType)
}
}
}
5 changes: 3 additions & 2 deletions opencv/src/main/java/ru/gordinmitya/opencv/OpenCVFramework.kt
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package ru.gordinmitya.opencv

import android.content.Context
import org.opencv.android.OpenCVLoader.OPENCV_VERSION
import ru.gordinmitya.common.*
import ru.gordinmitya.common.classification.ClassificationFramework
import ru.gordinmitya.common.classification.ClassificationModel
import ru.gordinmitya.common.classification.Classifier

class OpenCVFramework : InferenceFramework("OpenCV DNN", "by OpenCV"), ClassificationFramework {
class OpenCVFramework : InferenceFramework("OpenCV DNN", Version(OPENCV_VERSION)), ClassificationFramework {
private val TYPES = listOf(OPENCV_CPU)

override fun getInferenceTypes(): List<InferenceType> = TYPES
Expand All @@ -22,4 +23,4 @@ class OpenCVFramework : InferenceFramework("OpenCV DNN", "by OpenCV"), Classific

return OpenCVClassifier(context, configuration, convertedModel, inferenceType)
}
}
}
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
package ru.gordinmitya.pytorch

import android.content.Context
import ru.gordinmitya.common.Configuration
import ru.gordinmitya.common.InferenceFramework
import ru.gordinmitya.common.InferenceType
import ru.gordinmitya.common.Model
import ru.gordinmitya.common.*
import ru.gordinmitya.common.classification.ClassificationFramework
import ru.gordinmitya.common.classification.ClassificationModel
import ru.gordinmitya.common.classification.Classifier

class PytorchFramework : InferenceFramework("Pytorch", "by Facebook"), ClassificationFramework {
class PytorchFramework : InferenceFramework("Pytorch", Version("1.5.0")), ClassificationFramework {
private val TYPES = listOf(PYTORCH_CPU)

override fun getInferenceTypes(): List<InferenceType> = TYPES
Expand All @@ -25,4 +22,4 @@ class PytorchFramework : InferenceFramework("Pytorch", "by Facebook"), Classific

return PytorchClassifier(context, configuration, convertedModel, inferenceType)
}
}
}
4 changes: 2 additions & 2 deletions tflite/src/main/java/ru/gordinmitya/tflite/TFLiteFramework.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import ru.gordinmitya.common.segmentation.SegmentationFramework
import ru.gordinmitya.common.segmentation.SegmentationModel
import ru.gordinmitya.common.segmentation.Segmentator

class TFLiteFramework : InferenceFramework("TFLite", "by Google"),
class TFLiteFramework : InferenceFramework("TFLite", Version("2.3.0")),
ClassificationFramework,
SegmentationFramework {

Expand Down Expand Up @@ -46,4 +46,4 @@ class TFLiteFramework : InferenceFramework("TFLite", "by Google"),
}

override fun getDataOrder(): DataOrder = DataOrder.NHWC
}
}

0 comments on commit ad75056

Please sign in to comment.