Skip to content

Commit

Permalink
make text encode
Browse files Browse the repository at this point in the history
  • Loading branch information
ZHEQIUSHUI committed Aug 16, 2023
1 parent c7157c6 commit 041647f
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 85 deletions.
148 changes: 76 additions & 72 deletions src/Runner/CLIP.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,138 +12,142 @@
#include "Tokenizer.hpp"

#define LEN_IMAGE_FEATURE 512
#define LEN_TEXT_FEATURE 77
#define LEN_TEXT_FEATURE 512
#define LEN_TEXT_TOKEN 77

struct CLIP_IMAG_FEATURE_T
{
float feature[LEN_IMAGE_FEATURE];
};
// struct CLIP_IMAG_FEATURE_T
// {
// float feature[LEN_IMAGE_FEATURE];
// };

struct CLIP_TEXT_FEATURE_T
{
int feature[LEN_TEXT_FEATURE];
};
// struct CLIP_TEXT_FEATURE_T
// {
// int feature[LEN_TEXT_FEATURE];
// };

class CLIP
{
protected:
std::string device{"cpu"};
Ort::Env env;
Ort::SessionOptions session_options;
std::shared_ptr<Ort::Session> DecoderSession;
std::shared_ptr<Ort::Session> TextEncoderSession, DecoderSession;
Ort::MemoryInfo memory_info_handler = Ort::MemoryInfo::CreateCpu(
OrtArenaAllocator, OrtMemTypeDefault);

const char *DecoderInputNames[2]{"image_features", "text"},
const char
*TextEncInputNames[1]{"texts"},
*TextEncOutputNames[1]{"text_features"},
*DecoderInputNames[2]{"image_features", "text_features"},
*DecoderOutputNames[2]{"logits_per_image", "logits_per_text"};
float _mean_val[3] = {0.48145466f * 255.f, 0.4578275f * 255.f, 0.40821073f * 255.f};
float _std_val[3] = {1 / (0.26862954f * 255.f), 1 / (0.26130258f * 255.f), 1 / (0.27577711f * 255.f)};
Tokenizer tokenizer;

std::vector<float> image_features_input = std::vector<float>(1024 * LEN_IMAGE_FEATURE);
std::vector<int> text_features_input = std::vector<int>(1024 * LEN_TEXT_FEATURE);
std::vector<float> text_features_input = std::vector<float>(1024 * LEN_TEXT_FEATURE);
std::vector<int> text_tokens_input = std::vector<int>(1024 * LEN_TEXT_TOKEN);

public:
bool load_tokenizer(std::string vocab_path)
{
return tokenizer.load_tokenize(vocab_path);
}

bool load_decoder(std::string decoder_path)
// Constructor: prepares the shared ONNX Runtime environment and session options
// used by every session this class later creates (text encoder and decoder).
CLIP()
{
env = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "CLIP_DECODER");
session_options = Ort::SessionOptions();
// Use all available hardware threads for both inter- and intra-op parallelism.
session_options.SetInterOpNumThreads(std::thread::hardware_concurrency());
session_options.SetIntraOpNumThreads(std::thread::hardware_concurrency());
// Set the graph optimization level (enable all graph-level optimizations).
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
}

// Loads the tokenizer vocabulary from the given path.
// Returns true on success; thin wrapper over Tokenizer::load_tokenize.
bool load_tokenizer(std::string vocab_path)
{
return tokenizer.load_tokenize(vocab_path);
}

// Loads the CLIP decoder ONNX model from decoder_path into DecoderSession.
// The decoder is expected to take 2 inputs (image_features, text_features)
// and produce 2 outputs (logits_per_image, logits_per_text); any other
// input/output count is rejected as an invalid model.
// Returns true on success, false if the I/O signature does not match.
// NOTE(review): Ort::Session construction can throw on a bad path/model —
// presumably the caller tolerates that; confirm error-handling policy.
bool load_decoder(std::string decoder_path)
{
DecoderSession.reset(new Ort::Session(env, decoder_path.c_str(), session_options));
if (DecoderSession->GetInputCount() != 2 || DecoderSession->GetOutputCount() != 2)
{
ALOGE("Model not loaded (invalid input/output count)");
return false;
}

return true;
}

virtual bool load_encoder(std::string encoder_path) = 0;
virtual void encode(cv::Mat image, std::vector<float> &image_features) = 0;

void encode(std::vector<std::string> &texts, std::vector<std::vector<int>> &feats)
bool load_text_encoder(std::string encoder_path)
{
feats.resize(texts.size());
for (size_t i = 0; i < texts.size(); i++)
TextEncoderSession.reset(new Ort::Session(env, encoder_path.c_str(), session_options));
if (TextEncoderSession->GetInputCount() != 1 || TextEncoderSession->GetOutputCount() != 1)
{
tokenizer.encode_text(texts[i], feats[i]);
ALOGE("Model not loaded (invalid input/output count)");
return false;
}
return true;
}

void decode(std::vector<CLIP_IMAG_FEATURE_T> &image_features, std::vector<CLIP_TEXT_FEATURE_T> &text_features,
std::vector<std::vector<float>> &logits_per_image, std::vector<std::vector<float>> &logits_per_text)
virtual bool load_image_encoder(std::string encoder_path) = 0;
virtual void encode(cv::Mat image, std::vector<float> &image_features) = 0;

void encode(std::vector<std::string> &texts, std::vector<std::vector<float>> &text_features)
{
if (image_features.size() * LEN_IMAGE_FEATURE > image_features_input.size())
std::vector<std::vector<int>> text_token;
text_token.resize(texts.size());
for (size_t i = 0; i < texts.size(); i++)
{
image_features_input.resize(image_features.size() * LEN_IMAGE_FEATURE);
tokenizer.encode_text(texts[i], text_token[i]);
}
if (text_features.size() * LEN_IMAGE_FEATURE > text_features_input.size())

if (text_token.size() * LEN_TEXT_TOKEN > text_tokens_input.size())
{
text_features_input.resize(text_features.size() * LEN_IMAGE_FEATURE);
text_tokens_input.resize(text_token.size() * LEN_TEXT_TOKEN);
}

memset(image_features_input.data(), 0, image_features_input.size() * sizeof(float));
auto image_features_input_ptr = image_features_input.data();
memcpy(image_features_input_ptr, image_features.data(), image_features.size() * sizeof(CLIP_IMAG_FEATURE_T));

memset(text_features_input.data(), 0, text_features_input.size() * sizeof(int));
auto text_features_input_ptr = text_features_input.data();
memcpy(text_features_input_ptr, text_features.data(), text_features.size() * sizeof(CLIP_TEXT_FEATURE_T));

std::vector<Ort::Value> inputTensors;
memset(text_tokens_input.data(), 0, text_token.size() * LEN_TEXT_TOKEN * sizeof(int));
auto text_tokens_input_ptr = text_tokens_input.data();
for (size_t i = 0; i < text_token.size(); i++)
{
if (text_token[i].size() > LEN_TEXT_TOKEN)
{
ALOGW("text_features index %d ,bigger than %d\n", i, LEN_TEXT_TOKEN);
continue;
}
memcpy(text_tokens_input_ptr + i * LEN_TEXT_TOKEN, text_token[i].data(), text_token[i].size() * sizeof(int));
}

std::vector<int64_t> image_features_shape = {(int64_t)image_features.size(), LEN_IMAGE_FEATURE};
std::vector<int64_t> text_features_shape = {(int64_t)text_features.size(), LEN_TEXT_FEATURE};
std::vector<int64_t> text_token_shape = {(int64_t)text_token.size(), LEN_TEXT_TOKEN};

inputTensors.push_back(Ort::Value::CreateTensor<float>(
memory_info_handler, image_features_input.data(), image_features_input.size(), image_features_shape.data(), image_features_shape.size()));
inputTensors.push_back(Ort::Value::CreateTensor<int>(
memory_info_handler, text_features_input.data(), text_features_input.size(), text_features_shape.data(), text_features_shape.size()));
auto inputTensor = (Ort::Value::CreateTensor<int>(
memory_info_handler, text_tokens_input.data(), text_tokens_input.size(), text_token_shape.data(), text_token_shape.size()));

Ort::RunOptions runOptions;
auto DecoderOutputTensors = DecoderSession->Run(runOptions, DecoderInputNames, inputTensors.data(),
inputTensors.size(), DecoderOutputNames, 2);
auto OutputTensors = TextEncoderSession->Run(runOptions, TextEncInputNames, &inputTensor,
1, TextEncOutputNames, 1);

auto &logits_per_image_output = DecoderOutputTensors[0];
auto logits_per_image_ptr = logits_per_image_output.GetTensorMutableData<float>();
auto logits_per_image_shape = logits_per_image_output.GetTensorTypeAndShapeInfo().GetShape();
logits_per_image.resize(logits_per_image_shape[0]);
for (size_t i = 0; i < logits_per_image.size(); i++)
{
logits_per_image[i].resize(logits_per_image_shape[1]);
memcpy(logits_per_image[i].data(), logits_per_image_ptr + i * logits_per_image_shape[1], logits_per_image_shape[1] * sizeof(float));
}
auto &text_features_tensor = OutputTensors[0];
auto text_features_tensor_ptr = text_features_tensor.GetTensorMutableData<float>();
auto output_shape = text_features_tensor.GetTensorTypeAndShapeInfo().GetShape();

auto &logits_per_text_output = DecoderOutputTensors[1];
auto logits_per_text_ptr = logits_per_text_output.GetTensorMutableData<float>();
auto logits_per_text_shape = logits_per_text_output.GetTensorTypeAndShapeInfo().GetShape();
logits_per_text.resize(logits_per_text_shape[0]);
for (size_t i = 0; i < logits_per_text.size(); i++)
text_features.resize(output_shape[0]);

for (size_t i = 0; i < text_features.size(); i++)
{
logits_per_text[i].resize(logits_per_text_shape[1]);
memcpy(logits_per_text[i].data(), logits_per_text_ptr + i * logits_per_text_shape[1], logits_per_text_shape[1] * sizeof(float));
text_features[i].resize(output_shape[1]);
memcpy(text_features[i].data(), text_features_tensor_ptr + i * output_shape[1], output_shape[1] * sizeof(float));
}
}

void decode(std::vector<std::vector<float>> &image_features, std::vector<std::vector<int>> &text_features,
void decode(std::vector<std::vector<float>> &image_features, std::vector<std::vector<float>> &text_features,
std::vector<std::vector<float>> &logits_per_image, std::vector<std::vector<float>> &logits_per_text)
{
if (image_features.size() * LEN_IMAGE_FEATURE > image_features_input.size())
{
image_features_input.resize(image_features.size() * LEN_IMAGE_FEATURE);
}
if (text_features.size() * LEN_IMAGE_FEATURE > text_features_input.size())
if (text_features.size() * LEN_TEXT_FEATURE > text_features_input.size())
{
text_features_input.resize(text_features.size() * LEN_IMAGE_FEATURE);
text_features_input.resize(text_features.size() * LEN_TEXT_FEATURE);
}

memset(image_features_input.data(), 0, image_features_input.size() * sizeof(float));
Expand All @@ -158,16 +162,16 @@ class CLIP
memcpy(image_features_input_ptr + i * LEN_IMAGE_FEATURE, image_features[i].data(), LEN_IMAGE_FEATURE * sizeof(float));
}

memset(text_features_input.data(), 0, text_features_input.size() * sizeof(int));
memset(text_features_input.data(), 0, text_features_input.size() * sizeof(float));
auto text_features_input_ptr = text_features_input.data();
for (size_t i = 0; i < text_features.size(); i++)
{
if (text_features[i].size() > LEN_TEXT_FEATURE)
if (text_features[i].size() != LEN_TEXT_FEATURE)
{
ALOGW("text_features index %d ,bigger than %d\n", i, LEN_TEXT_FEATURE);
ALOGW("text_features index %d ,not equal %d\n", i, LEN_TEXT_FEATURE);
continue;
}
memcpy(text_features_input_ptr + i * LEN_TEXT_FEATURE, text_features[i].data(), text_features[i].size() * sizeof(int));
memcpy(text_features_input_ptr + i * LEN_TEXT_FEATURE, text_features[i].data(), text_features[i].size() * sizeof(float));
}
std::vector<Ort::Value> inputTensors;

Expand All @@ -176,7 +180,7 @@ class CLIP

inputTensors.push_back(Ort::Value::CreateTensor<float>(
memory_info_handler, image_features_input.data(), image_features_input.size(), image_features_shape.data(), image_features_shape.size()));
inputTensors.push_back(Ort::Value::CreateTensor<int>(
inputTensors.push_back(Ort::Value::CreateTensor<float>(
memory_info_handler, text_features_input.data(), text_features_input.size(), text_features_shape.data(), text_features_shape.size()));

Ort::RunOptions runOptions;
Expand Down
2 changes: 1 addition & 1 deletion src/Runner/CLIPAX650.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class CLIPAX650 : public CLIP
cv::Mat input;

public:
bool load_encoder(std::string encoder_path) override
bool load_image_encoder(std::string encoder_path) override
{
m_encoder.reset(new ax_runner_ax650);
m_encoder->init(encoder_path.c_str());
Expand Down
2 changes: 1 addition & 1 deletion src/Runner/CLIPOnnx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class CLIPOnnx : public CLIP
cv::Mat input;

public:
bool load_encoder(std::string encoder_path) override
bool load_image_encoder(std::string encoder_path) override
{
m_encoder = CreateRunner(RT_OnnxRunner);
BaseConfig config;
Expand Down
9 changes: 9 additions & 0 deletions src/Runner/OnnxWarpper/OnnxWarpper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ class OnnxRunner : virtual public BaseRunner

printf("%20s: ", input_name.c_str());
auto input_shape = session.GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();

if (input_shape.size() && input_shape[0] < 0)
{
input_shape[0] = 1;
}
std::vector<size_t> tmp_input_shape(input_shape.size());
for (size_t j = 0; j < input_shape.size(); j++)
{
Expand Down Expand Up @@ -71,6 +76,10 @@ class OnnxRunner : virtual public BaseRunner

printf("%20s: ", output_name.c_str());
auto output_shape = session.GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();
if (output_shape.size() && output_shape[0] < 0)
{
output_shape[0] = 1;
}
std::vector<size_t> tmp_output_shape(output_shape.size());
for (size_t j = 0; j < output_shape.size(); j++)
{
Expand Down
29 changes: 18 additions & 11 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,42 @@ int main(int argc, char *argv[])
std::string image_src = "./test.jpg";
std::string text_src = "a cat";
std::string vocab_path = "./onnx_models/mobile_sam_decoder.onnx";
std::string encoder_model_path = "./onnx_models/mobile_sam_encoder.onnx";
std::string image_encoder_model_path = "./onnx_models/mobile_sam_encoder.onnx";
std::string text_encoder_model_path = "./onnx_models/mobile_sam_encoder.onnx";
std::string decoder_model_path = "./onnx_models/mobile_sam_decoder.onnx";

cmdline::parser cmd;
cmd.add<std::string>("encoder", 'e', "encoder model(onnx model or axmodel)", true, encoder_model_path);
cmd.add<std::string>("decoder", 'd', "decoder model(onnx)", true, decoder_model_path);
cmd.add<std::string>("ienc", 0, "encoder model(onnx model or axmodel)", true, image_encoder_model_path);
cmd.add<std::string>("tenc", 0, "text encoder model(onnx model or axmodel)", true, text_encoder_model_path);
cmd.add<std::string>("dec", 'd', "decoder model(onnx)", true, decoder_model_path);
cmd.add<std::string>("image", 'i', "image file or folder(jpg png etc....)", true, image_src);
cmd.add<std::string>("text", 't', "text or txt file", true, text_src);
cmd.add<std::string>("vocab", 'v', "vocab path", true, vocab_path);

cmd.parse_check(argc, argv);

vocab_path = cmd.get<std::string>("vocab");
encoder_model_path = cmd.get<std::string>("encoder");
decoder_model_path = cmd.get<std::string>("decoder");
image_encoder_model_path = cmd.get<std::string>("ienc");
text_encoder_model_path = cmd.get<std::string>("tenc");
decoder_model_path = cmd.get<std::string>("dec");

std::shared_ptr<CLIP> mClip;
if (string_utility<std::string>::ends_with(encoder_model_path, ".onnx"))
if (string_utility<std::string>::ends_with(image_encoder_model_path, ".onnx"))
{
mClip.reset(new CLIPOnnx);
}
else if (string_utility<std::string>::ends_with(encoder_model_path, ".axmodel"))
else if (string_utility<std::string>::ends_with(image_encoder_model_path, ".axmodel"))
{
mClip.reset(new CLIPAX650);
}
else
{
fprintf(stderr, "no impl for %s\n", encoder_model_path.c_str());
fprintf(stderr, "no impl for %s\n", image_encoder_model_path.c_str());
return -1;
}

mClip->load_encoder(encoder_model_path);
mClip->load_image_encoder(image_encoder_model_path);
mClip->load_text_encoder(text_encoder_model_path);
mClip->load_decoder(decoder_model_path);
mClip->load_tokenizer(vocab_path);

Expand Down Expand Up @@ -70,7 +74,7 @@ int main(int argc, char *argv[])
{
texts.push_back(text_src);
}
std::vector<std::vector<int>> text_features;
std::vector<std::vector<float>> text_features;
mClip->encode(texts, text_features);

std::vector<std::vector<float>> image_features;
Expand Down Expand Up @@ -107,8 +111,11 @@ int main(int argc, char *argv[])
}

std::vector<std::vector<float>> logits_per_image, logits_per_text;

auto time_start = std::chrono::high_resolution_clock::now();
mClip->decode(image_features, text_features, logits_per_image, logits_per_text);
auto time_end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = time_end - time_start;
std::cout << "decode Inference Cost time : " << diff.count() << "s" << std::endl;

printf("\n");
if (texts.size() > 1)
Expand Down

0 comments on commit 041647f

Please sign in to comment.