Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

App/Language: Add an LLM example about text generation #114

Merged
merged 1 commit into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import ai.nnstreamer.ml.inference.offloading.data.PreferencesDataStoreImpl
import ai.nnstreamer.ml.inference.offloading.network.NsdRegistrationListener
import ai.nnstreamer.ml.inference.offloading.network.findPort
import ai.nnstreamer.ml.inference.offloading.network.getIpAddress
import ai.nnstreamer.ml.inference.offloading.domain.NewDataCb
import ai.nnstreamer.ml.inference.offloading.domain.runLlama2
import android.Manifest
import android.app.NotificationChannel
import android.app.NotificationManager
Expand Down Expand Up @@ -34,6 +36,8 @@ import androidx.core.app.ServiceCompat
import androidx.core.content.ContextCompat
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.firstOrNull
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import org.nnsuite.nnstreamer.NNStreamer
Expand All @@ -49,6 +53,7 @@ enum class MessageType(val value: Int) {
STOP_MODEL(2),
DESTROY_MODEL(3),
REQ_OBJ_CLASSIFICATION_FILTER(4),
REQ_LLM_FILTER(5)
}

/**
Expand Down Expand Up @@ -118,6 +123,11 @@ class MainService : Service() {
}
}

MessageType.REQ_LLM_FILTER.value -> {
loadModels()
runLLM(msg.data.getString("input") ?: "", msg.replyTo)
}

else -> super.handleMessage(msg)
}
}
Expand Down Expand Up @@ -407,4 +417,37 @@ class MainService : Service() {
offloadingServiceRepositoryImpl.deleteOffloadingService(id)
}
}

/**
 * Finds the [OffloadingService] registered for the first model whose name
 * contains [name], or null when no such model/service pair exists.
 *
 * Suspends until the backing flows emit a matching list (or complete).
 *
 * @param name substring to match against stored model names.
 * @return the matching service, or null if no model or service matches.
 */
private suspend fun findService(name: String): OffloadingService? {
    // todo: Improve search methods
    val model = modelsRepository.getAllModelsStream()
        .filter { list -> list.any { it.name.contains(name) } }
        .firstOrNull()
        // Select the element that actually matched the predicate —
        // taking index 0 could return an unrelated model when the
        // emitted list holds several entries.
        ?.firstOrNull { it.name.contains(name) }

    // No matching model means no service can match either.
    val modelId = model?.uid ?: return null

    return offloadingServiceRepositoryImpl.getAllOffloadingService()
        .filter { list -> list.any { it.modelId == modelId } }
        .firstOrNull()
        // Same fix as above: pick the service bound to this model,
        // not merely the first one in the list.
        ?.firstOrNull { it.modelId == modelId }
}

/**
 * Runs the LLM text-generation pipeline for [input] on a background coroutine.
 *
 * Looks up the "llama" offloading service, starts it, and streams the prompt
 * through [runLlama2]; generated tokens are delivered back to [messenger]
 * via [NewDataCb]. Does nothing when no matching service is registered.
 *
 * @param input the prompt text to send to the model.
 * @param messenger reply channel for generated output; may be null.
 */
private fun runLLM(input: String, messenger: Messenger?) {
    CoroutineScope(Dispatchers.IO).launch {
        // todo: Support other models
        findService("llama")?.let { service ->
            startService(service.serviceId)
            runLlama2(input, getIpAddress(isRunningOnEmulator), service.port, NewDataCb(messenger))
        }
    }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package ai.nnstreamer.ml.inference.offloading.domain

import ai.nnstreamer.ml.inference.offloading.network.findPort
import org.nnsuite.nnstreamer.NNStreamer
import org.nnsuite.nnstreamer.Pipeline
import org.nnsuite.nnstreamer.TensorsData
import org.nnsuite.nnstreamer.TensorsInfo
import java.nio.ByteBuffer

/**
 * Sends [input] to a remote Llama2 tensor-query service and registers
 * [newDataCb] to receive the generated output.
 *
 * Builds a one-shot client pipeline (appsrc -> tensor_converter ->
 * tensor_query_client -> tensor_sink), starts it, and pushes the UTF-8
 * encoded prompt as a single UINT8 tensor.
 *
 * @param input prompt text; encoded as UTF-8 before transmission.
 * @param hostAddress local/destination host address for the query client.
 * @param servicePort port of the remote offloading service.
 * @param newDataCb callback invoked with each tensor produced by the sink.
 */
fun runLlama2(input: String, hostAddress: String, servicePort: Int, newDataCb: NewDataCb) {
    val port = findPort()
    val desc =
        "appsrc name=srcx ! application/octet-stream ! tensor_converter ! other/tensors,format=flexible ! tensor_query_client host=$hostAddress port=$port dest-host=$hostAddress dest-port=$servicePort timeout=1000000 ! tensor_sink name=sinkx"
    val pipeline = Pipeline(desc, null)

    pipeline.registerSinkCallback("sinkx", newDataCb)
    // todo: Reuse or destroy the client pipeline
    pipeline.start()

    // Encode first and size the tensor from the encoded bytes: for
    // non-ASCII input the UTF-8 byte count exceeds input.length (the
    // char count), which previously overflowed the allocated buffer.
    val payload = input.toByteArray(Charsets.UTF_8)

    val info = TensorsInfo()
    info.addTensorInfo(NNStreamer.TensorType.UINT8, intArrayOf(payload.size, 1, 1, 1))

    val data = TensorsData.allocate(info)
    val buffer = TensorsData.allocateByteBuffer(info.getTensorSize(0))
    buffer.put(payload)
    data.setTensorData(0, buffer)
    pipeline.inputData("srcx", data)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package ai.nnstreamer.ml.inference.offloading.domain

import android.os.Message
import android.os.Messenger
import org.nnsuite.nnstreamer.Pipeline
import org.nnsuite.nnstreamer.TensorsData

/**
 * Pipeline sink callback that relays generated text back to a client.
 *
 * Each tensor received from the sink is stripped of zero (padding) bytes,
 * decoded as UTF-8, and forwarded through [messenger] in a [Message] whose
 * data bundle carries the text under the "response" key.
 *
 * @property messenger reply channel for the client; sends are skipped when null.
 */
class NewDataCb(private val messenger: Messenger?) : Pipeline.NewDataCallback {
    override fun onNewDataReceived(data: TensorsData?) {
        val tensor = data?.getTensorData(0) ?: return

        // Drop zero bytes so padding does not appear in the decoded string.
        val payload = tensor.array().filter { b -> b != 0.toByte() }.toByteArray()

        val reply = Message.obtain()
        reply.data.putString("response", String(payload, Charsets.UTF_8))
        messenger?.send(reply)
    }
}