From 4dada2a5adb115c9266bfb643998efab8481da95 Mon Sep 17 00:00:00 2001 From: Yelin Jeong Date: Fri, 27 Sep 2024 13:37:09 +0900 Subject: [PATCH] App/Language: Add a LLM example about text generation This patch adds a draft of the LLM example. It uses llama2 model to generate text using input prompt. Signed-off-by: Yelin Jeong --- .../ml/inference/offloading/MainService.kt | 45 +++++++++++++++++++ .../inference/offloading/domain/LlamaUtil.kt | 31 +++++++++++++ .../inference/offloading/domain/NewDataCb.kt | 25 +++++++++++ 3 files changed, 101 insertions(+) create mode 100644 ml_inference_offloading/src/main/java/ai/nnstreamer/ml/inference/offloading/domain/LlamaUtil.kt create mode 100644 ml_inference_offloading/src/main/java/ai/nnstreamer/ml/inference/offloading/domain/NewDataCb.kt diff --git a/ml_inference_offloading/src/main/java/ai/nnstreamer/ml/inference/offloading/MainService.kt b/ml_inference_offloading/src/main/java/ai/nnstreamer/ml/inference/offloading/MainService.kt index 3c866b7..ff1b7df 100644 --- a/ml_inference_offloading/src/main/java/ai/nnstreamer/ml/inference/offloading/MainService.kt +++ b/ml_inference_offloading/src/main/java/ai/nnstreamer/ml/inference/offloading/MainService.kt @@ -7,6 +7,8 @@ import ai.nnstreamer.ml.inference.offloading.data.PreferencesDataStoreImpl import ai.nnstreamer.ml.inference.offloading.network.NsdRegistrationListener import ai.nnstreamer.ml.inference.offloading.network.findPort import ai.nnstreamer.ml.inference.offloading.network.getIpAddress +import ai.nnstreamer.ml.inference.offloading.domain.NewDataCb +import ai.nnstreamer.ml.inference.offloading.domain.runLlama2 import android.Manifest import android.app.NotificationChannel import android.app.NotificationManager @@ -34,6 +36,8 @@ import androidx.core.app.ServiceCompat import androidx.core.content.ContextCompat import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.flow.filter +import kotlinx.coroutines.flow.firstOrNull import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking import org.nnsuite.nnstreamer.NNStreamer @@ -49,6 +53,7 @@ enum class MessageType(val value: Int) { STOP_MODEL(2), DESTROY_MODEL(3), REQ_OBJ_CLASSIFICATION_FILTER(4), + REQ_LLM_FILTER(5) } /** @@ -118,6 +123,11 @@ class MainService : Service() { } } + MessageType.REQ_LLM_FILTER.value -> { + loadModels() + runLLM(msg.data.getString("input") ?: "", msg.replyTo) + } + else -> super.handleMessage(msg) } } @@ -407,4 +417,39 @@ class MainService : Service() { offloadingServiceRepositoryImpl.deleteOffloadingService(id) } } + + private suspend fun findService(name: String): OffloadingService? { + val models = modelsRepository.getAllModelsStream() + // todo: Improve search methods + val model = models.filter { list -> + list.any { + it.name.contains(name) + } + }.firstOrNull()?.get(0) + + val modelId = model?.uid + + val services = offloadingServiceRepositoryImpl.getAllOffloadingService() + val service = services.filter { list -> + list.any { + it.modelId == modelId + } + }.firstOrNull()?.get(0) + + return service + } + + private fun runLLM(input: String, messenger: Messenger?) { + CoroutineScope(Dispatchers.IO).launch { + val service = findService("llama") + + service?.let { + startService(it.serviceId) + // todo: Support other models + runLlama2(input, getIpAddress(isRunningOnEmulator), it.port, NewDataCb(messenger)) + }?.run { + Log.e(TAG, "Not supported LLM") + } + } + } } diff --git a/ml_inference_offloading/src/main/java/ai/nnstreamer/ml/inference/offloading/domain/LlamaUtil.kt b/ml_inference_offloading/src/main/java/ai/nnstreamer/ml/inference/offloading/domain/LlamaUtil.kt new file mode 100644 index 0000000..8e4da9c --- /dev/null +++ b/ml_inference_offloading/src/main/java/ai/nnstreamer/ml/inference/offloading/domain/LlamaUtil.kt @@ -0,0 +1,31 @@ +package ai.nnstreamer.ml.inference.offloading.domain + +import ai.nnstreamer.ml.inference.offloading.network.findPort +import org.nnsuite.nnstreamer.NNStreamer +import org.nnsuite.nnstreamer.Pipeline +import org.nnsuite.nnstreamer.TensorsData +import org.nnsuite.nnstreamer.TensorsInfo +import java.nio.ByteBuffer + +fun runLlama2(input: String, hostAddress: String, servicePort: Int, newDataCb: NewDataCb) { + val port = findPort() + val desc = + "appsrc name=srcx ! application/octet-stream ! tensor_converter ! other/tensors,format=flexible ! tensor_query_client host=${hostAddress} port=${port} dest-host=${hostAddress} dest-port=${servicePort} timeout=1000000 ! tensor_sink name=sinkx" + val pipeline = Pipeline(desc, null) + + pipeline.registerSinkCallback("sinkx", newDataCb) + // todo: Reuse or destroy the client pipeline + pipeline.start() + + val info = TensorsInfo() + info.addTensorInfo(NNStreamer.TensorType.UINT8, intArrayOf(input.length, 1, 1, 1)) + + val size = info.getTensorSize(0) + val data = TensorsData.allocate(info) + val byteBuffer: ByteBuffer = ByteBuffer.wrap(input.toByteArray()) + + val buffer = TensorsData.allocateByteBuffer(size) + buffer.put(byteBuffer) + data.setTensorData(0, buffer) + pipeline.inputData("srcx", data) +} diff --git a/ml_inference_offloading/src/main/java/ai/nnstreamer/ml/inference/offloading/domain/NewDataCb.kt b/ml_inference_offloading/src/main/java/ai/nnstreamer/ml/inference/offloading/domain/NewDataCb.kt new file mode 100644 index 0000000..e5f83ef --- /dev/null +++ b/ml_inference_offloading/src/main/java/ai/nnstreamer/ml/inference/offloading/domain/NewDataCb.kt @@ -0,0 +1,25 @@ +package ai.nnstreamer.ml.inference.offloading.domain + +import android.os.Message +import android.os.Messenger +import org.nnsuite.nnstreamer.Pipeline +import org.nnsuite.nnstreamer.TensorsData + +class NewDataCb(private val messenger: Messenger?) : Pipeline.NewDataCallback { + override fun onNewDataReceived(data: TensorsData?) { + val received = data?.getTensorData(0) + received?.let { + val result = mutableListOf() + + for (byte in received.array()) { + if (byte != 0.toByte()) { + result.add(byte) + } + } + + val response = Message.obtain() + response.data.putString("response", String(result.toByteArray(), Charsets.UTF_8)) + messenger?.send(response) + } + } +}