Skip to content

Commit

Permalink
App/Language: Add an LLM example about text generation
Browse files Browse the repository at this point in the history
This patch adds a draft of the LLM example.
It uses the llama2 model to generate text from an input prompt.

Signed-off-by: Yelin Jeong <[email protected]>
  • Loading branch information
niley7464 committed Sep 27, 2024
1 parent c68f3f2 commit 4dada2a
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import ai.nnstreamer.ml.inference.offloading.data.PreferencesDataStoreImpl
import ai.nnstreamer.ml.inference.offloading.network.NsdRegistrationListener
import ai.nnstreamer.ml.inference.offloading.network.findPort
import ai.nnstreamer.ml.inference.offloading.network.getIpAddress
import ai.nnstreamer.ml.inference.offloading.domain.NewDataCb
import ai.nnstreamer.ml.inference.offloading.domain.runLlama2
import android.Manifest
import android.app.NotificationChannel
import android.app.NotificationManager
Expand Down Expand Up @@ -34,6 +36,8 @@ import androidx.core.app.ServiceCompat
import androidx.core.content.ContextCompat
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.firstOrNull
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.nnsuite.nnstreamer.NNStreamer
Expand All @@ -49,6 +53,7 @@ enum class MessageType(val value: Int) {
STOP_MODEL(2),
DESTROY_MODEL(3),
REQ_OBJ_CLASSIFICATION_FILTER(4),
REQ_LLM_FILTER(5)
}

/**
Expand Down Expand Up @@ -118,6 +123,11 @@ class MainService : Service() {
}
}

MessageType.REQ_LLM_FILTER.value -> {
loadModels()
runLLM(msg.data.getString("input") ?: "", msg.replyTo)
}

else -> super.handleMessage(msg)
}
}
Expand Down Expand Up @@ -407,4 +417,39 @@ class MainService : Service() {
offloadingServiceRepositoryImpl.deleteOffloadingService(id)
}
}

/**
 * Finds the [OffloadingService] registered for the first model whose name
 * contains [name].
 *
 * @param name substring to match against registered model names (e.g. "llama")
 * @return the matching service, or null if no model or no service matches
 */
private suspend fun findService(name: String): OffloadingService? {
    // todo: Improve search methods
    // Bug fix: the original emitted the list that *contained* a match and then
    // took get(0) — the first element of the list, which is not necessarily the
    // element that matched. Select the matching element explicitly.
    val model = modelsRepository.getAllModelsStream()
        .firstOrNull { list -> list.any { it.name.contains(name) } }
        ?.firstOrNull { it.name.contains(name) }
        ?: return null // no model matched; skip the service scan entirely

    val modelId = model.uid

    return offloadingServiceRepositoryImpl.getAllOffloadingService()
        .firstOrNull { list -> list.any { it.modelId == modelId } }
        ?.firstOrNull { it.modelId == modelId }
}

/**
 * Looks up the llama offloading service, starts it, and feeds [input] to the
 * remote model; generated text is delivered back through [messenger].
 *
 * @param input the prompt text to send to the model
 * @param messenger reply channel for the client that requested generation
 */
private fun runLLM(input: String, messenger: Messenger?) {
    // NOTE(review): an ad-hoc CoroutineScope is not tied to the Service
    // lifecycle — consider a service-scoped scope cancelled in onDestroy.
    CoroutineScope(Dispatchers.IO).launch {
        val service = findService("llama")

        // Bug fix: the original used `service?.let { ... }?.run { Log.e(...) }`.
        // `let` returns Unit (non-null), so the error branch ran even when the
        // service was found. An explicit null check runs exactly one branch.
        if (service != null) {
            startService(service.serviceId)
            // todo: Support other models
            runLlama2(input, getIpAddress(isRunningOnEmulator), service.port, NewDataCb(messenger))
        } else {
            Log.e(TAG, "Not supported LLM")
        }
    }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package ai.nnstreamer.ml.inference.offloading.domain

import ai.nnstreamer.ml.inference.offloading.network.findPort
import org.nnsuite.nnstreamer.NNStreamer
import org.nnsuite.nnstreamer.Pipeline
import org.nnsuite.nnstreamer.TensorsData
import org.nnsuite.nnstreamer.TensorsInfo
import java.nio.ByteBuffer

/**
 * Builds a tensor_query_client pipeline to the offloading service at
 * [hostAddress]:[servicePort] and pushes [input] as a flexible UINT8 tensor.
 * Generated output arrives through [newDataCb] on the "sinkx" element.
 *
 * @param input prompt text; encoded as UTF-8 before being sent
 * @param hostAddress address used for both the local client and the remote host
 * @param servicePort destination port of the running offloading service
 * @param newDataCb sink callback that receives the model's response
 */
fun runLlama2(input: String, hostAddress: String, servicePort: Int, newDataCb: NewDataCb) {
    val port = findPort()
    val desc =
        "appsrc name=srcx ! application/octet-stream ! tensor_converter ! other/tensors,format=flexible ! tensor_query_client host=${hostAddress} port=${port} dest-host=${hostAddress} dest-port=${servicePort} timeout=1000000 ! tensor_sink name=sinkx"
    val pipeline = Pipeline(desc, null)

    pipeline.registerSinkCallback("sinkx", newDataCb)
    // todo: Reuse or destroy the client pipeline
    pipeline.start()

    // Bug fix: size the tensor by the UTF-8 byte length, not the character
    // count. For non-ASCII prompts String.length < toByteArray().size, so the
    // original allocation overflowed when the encoded bytes were put().
    val payload = input.toByteArray(Charsets.UTF_8)

    val info = TensorsInfo()
    info.addTensorInfo(NNStreamer.TensorType.UINT8, intArrayOf(payload.size, 1, 1, 1))

    val data = TensorsData.allocate(info)
    val buffer = TensorsData.allocateByteBuffer(info.getTensorSize(0))
    buffer.put(payload)
    data.setTensorData(0, buffer)
    pipeline.inputData("srcx", data)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package ai.nnstreamer.ml.inference.offloading.domain

import android.os.Message
import android.os.Messenger
import org.nnsuite.nnstreamer.Pipeline
import org.nnsuite.nnstreamer.TensorsData

/**
 * Pipeline sink callback that relays generated text back to the requesting
 * client: strips NUL padding bytes from the received tensor, decodes the
 * remainder as UTF-8, and sends it via [messenger] under the "response" key.
 */
class NewDataCb(private val messenger: Messenger?) : Pipeline.NewDataCallback {
    override fun onNewDataReceived(data: TensorsData?) {
        val payload = data?.getTensorData(0) ?: return

        // Drop zero bytes (padding) before decoding the text.
        val textBytes = payload.array()
            .filter { b -> b != 0.toByte() }
            .toByteArray()

        val reply = Message.obtain()
        reply.data.putString("response", String(textBytes, Charsets.UTF_8))
        messenger?.send(reply)
    }
}

0 comments on commit 4dada2a

Please sign in to comment.