Skip to content

Commit

Permalink
Load model in background thread (pytorch#2470)
Browse files Browse the repository at this point in the history
Summary:
If we load it on main thread, it could cause ANR.

Pull Request resolved: pytorch#2470

Reviewed By: shoumikhin

Differential Revision: D54970842

Pulled By: kirklandsign

fbshipit-source-id: ffd9a8fddbfefbb9e57d94d0b66190993d023646
  • Loading branch information
kirklandsign authored and facebook-github-bot committed Mar 19, 2024
1 parent 07f52ff commit a5add5f
Showing 1 changed file with 28 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,24 @@ private static String[] listLocalFile(String path, String suffix) {
}

private void setLocalModel(String modelPath, String tokenizerPath) {
Message modelLoadingMessage = new Message("Loading model...", false);
runOnUiThread(
() -> {
mSendButton.setEnabled(false);
mMessageAdapter.add(modelLoadingMessage);
mMessageAdapter.notifyDataSetChanged();
});
long runStartTime = System.currentTimeMillis();
mModule = new LlamaModule(modelPath, tokenizerPath, 0.8f);
int loadResult = mModule.load();
if (loadResult != 0) {
AlertDialog.Builder builder = new AlertDialog.Builder(this);
builder.setTitle("Load failed: " + loadResult);
AlertDialog alert = builder.create();
alert.show();
runOnUiThread(
() -> {
alert.show();
});
}

long runDuration = System.currentTimeMillis() - runStartTime;
Expand All @@ -79,8 +89,13 @@ private void setLocalModel(String modelPath, String tokenizerPath) {
+ runDuration
+ " ms";
Message modelLoadedMessage = new Message(modelInfo, false);
mMessageAdapter.add(modelLoadedMessage);
mMessageAdapter.notifyDataSetChanged();
runOnUiThread(
() -> {
mSendButton.setEnabled(true);
mMessageAdapter.remove(modelLoadingMessage);
mMessageAdapter.add(modelLoadedMessage);
mMessageAdapter.notifyDataSetChanged();
});
}

private String memoryInfo() {
Expand Down Expand Up @@ -116,7 +131,14 @@ private void modelDialog() {
-1,
(dialog, item) -> {
mModelFilePath = pteFiles[item];
setLocalModel(mModelFilePath, mTokenizerFilePath);
Runnable runnable =
new Runnable() {
@Override
public void run() {
setLocalModel(mModelFilePath, mTokenizerFilePath);
}
};
new Thread(runnable).start();
dialog.dismiss();
});

Expand All @@ -130,6 +152,7 @@ protected void onCreate(Bundle savedInstanceState) {

mEditTextMessage = findViewById(R.id.editTextMessage);
mSendButton = findViewById(R.id.sendButton);
mSendButton.setEnabled(false);
mModelButton = findViewById(R.id.modelButton);
mMessagesView = findViewById(R.id.messages_view);
mMessageAdapter = new MessageAdapter(this, R.layout.sent_message);
Expand All @@ -142,8 +165,8 @@ protected void onCreate(Bundle savedInstanceState) {
modelDialog();
});

setLocalModel("/data/local/tmp/llama/stories110M.pte", "/data/local/tmp/llama/tokenizer.bin");
onModelRunStopped();
modelDialog();
}

private void onModelRunStarted() {
Expand Down

0 comments on commit a5add5f

Please sign in to comment.