diff --git a/CMakeLists.txt b/CMakeLists.txt
index b52e7fad..a7a424dd 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -115,6 +115,7 @@ if(ENABLE_PROSTT5)
         set(GGML_STATIC ON)
         set(BUILD_SHARED_LIBS OFF)
         set(GGML_BLAS OFF)
+        set(GGML_OPENMP OFF)
         if (NOT NATIVE_ARCH)
             set(GGML_NATIVE OFF)
             if (HAVE_AVX2)
diff --git a/src/strucclustutils/ProstT5.cpp b/src/strucclustutils/ProstT5.cpp
index ffa97f18..605ce300 100644
--- a/src/strucclustutils/ProstT5.cpp
+++ b/src/strucclustutils/ProstT5.cpp
@@ -34,24 +34,14 @@ static char number_to_char(unsigned int n) {
 static int encode(llama_context * ctx, std::vector<llama_token> & enc_input, std::string & result) {
     const struct llama_model * model = llama_get_model(ctx);
-    // clear previous kv_cache values (irrelevant for embeddings)
-    // llama_kv_cache_clear(ctx);
-    // llama_set_embeddings(ctx, true);
-    // run model
-    if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
-        if (llama_encode(ctx, llama_batch_get_one(enc_input.data(), enc_input.size())) < 0) {
-            // LOG_ERR("%s : failed to encode\n", __func__);
-            return 1;
-        }
-    } else {
-        // LOG_ERR("%s : no encoder\n", __func__);
+    if (llama_encode(ctx, llama_batch_get_one(enc_input.data(), enc_input.size())) < 0) {
+        // LOG_ERR("%s : failed to encode\n", __func__);
         return 1;
     }
-    // Log the embeddings (assuming n_embd is the embedding size per token)
+    // LOG_INF("%s: n_tokens = %zu, n_seq = %d\n", __func__, enc_input.size(), 1);
     float* embeddings = llama_get_embeddings(ctx);
     if (embeddings == nullptr) {
         return 1;
     }
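+    // arg_max_idx records, for every input token, the index of its best-scoring output entry;
+    // number_to_char maps that index to a 3Di character further below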
     int * arg_max_idx = new int[enc_input.size()];
@@ -69,8 +59,8 @@ static int encode(llama_context * ctx, std::vector<llama_token> & enc_input, std
     for (int i = 0; i < seq_len - 1; ++i) {
         result.push_back(number_to_char(arg_max_idx[i]));
     }
-    delete [] arg_max_idx;
-    delete [] arg_max;
+    delete[] arg_max_idx;
+    delete[] arg_max;
     return 0;
 }
@@ -110,182 +100,18 @@ static std::vector<ggml_backend_dev_t> parse_device_list(const std::string & val
     return devices;
 }
 
-struct lora_adapter_info {
-    std::string path;
-    float scale;
-};
-
-struct lora_adapter_container : lora_adapter_info {
-    struct llama_lora_adapter * adapter;
-};
-
-struct init_result {
-    struct llama_model * model = nullptr;
-    struct llama_context * context = nullptr;
-    std::vector<lora_adapter_container> lora_adapters;
-};
-
-struct cpu_params {
-    int n_threads = -1;
-    bool cpumask[GGML_MAX_N_THREADS] = {false}; // CPU affinity mask.
-    bool mask_valid = false; // Default: any CPU
-    enum ggml_sched_priority priority = GGML_SCHED_PRIO_NORMAL; // Scheduling prio : (0 - normal, 1 - medium, 2 - high, 3 - realtime)
-    bool strict_cpu = false; // Use strict CPU placement
-    uint32_t poll = 50; // Polling (busywait) level (0 - no polling, 100 - mostly polling)
-};
-
-struct common_params {
-    int32_t n_ctx = 4096; // context size
-    int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
-    int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
-    int32_t n_parallel = 1; // number of parallel sequences to decode
-    // float rope_freq_base = 0.0f; // RoPE base frequency
-    // float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
-    // float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor
-    // float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
-    // float yarn_beta_fast = 32.0f; // YaRN low correction dim
-    // float yarn_beta_slow = 1.0f; // YaRN high correction dim
-    // int32_t yarn_orig_ctx = 0; // YaRN original context length
-    float defrag_thold = 0.1f; // KV cache defragmentation threshold
-
-    // offload params
-    std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
-
-    int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
-    int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
-    float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
-
-    enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
-
-    struct cpu_params cpuparams;
-    struct cpu_params cpuparams_batch;
-
-    ggml_backend_sched_eval_callback cb_eval = nullptr;
-    void * cb_eval_user_data = nullptr;
-
-    ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
-
-    enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
-    enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
-    enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
-
-    std::string model = ""; // model path // NOLINT
-    std::string rpc_servers = ""; // comma separated list of RPC servers // NOLINT
-
-    bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_lora_adapter_apply)
-    std::vector<lora_adapter_info> lora_adapters; // lora adapter path with user defined scale
-
-    bool flash_attn = false; // flash attention
-    bool no_perf = false; // disable performance metrics
-    bool logits_all = false; // return logits for all tokens in the batch
-    bool use_mmap = true; // use mmap for faster loads
-    bool use_mlock = false; // use mlock to keep model in memory
-    bool no_kv_offload = false; // disable KV offloading
-    bool warmup = true; // warmup run
-    bool check_tensors = false; // validate tensor data
-
-    ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
-    ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
-
-    bool embedding = true; // get only sentence embedding
-};
-
-static struct init_result init_from_params(common_params & params) {
-    init_result iparams;
-    auto mparams = llama_model_default_params();
-
-    if (!params.devices.empty()) {
-        mparams.devices = params.devices.data();
-    }
-    if (params.n_gpu_layers != -1) {
-        mparams.n_gpu_layers = params.n_gpu_layers;
-    }
-    mparams.rpc_servers = params.rpc_servers.c_str();
-    mparams.main_gpu = params.main_gpu;
-    mparams.split_mode = params.split_mode;
-    mparams.tensor_split = params.tensor_split;
-    mparams.use_mmap = params.use_mmap;
-    mparams.use_mlock = params.use_mlock;
-    mparams.check_tensors = params.check_tensors;
-    mparams.n_gpu_layers = params.n_gpu_layers;
-    mparams.kv_overrides = NULL;
-
-    llama_model * model = nullptr;
-
-    model = llama_load_model_from_file(params.model.c_str(), mparams);
-
-    if (model == NULL) {
-        // LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.c_str());
-        return iparams;
-    }
-
-    auto cparams = llama_context_default_params();
-
-    cparams.n_ctx = params.n_ctx;
-    cparams.n_seq_max = params.n_parallel;
-    cparams.n_batch = params.n_batch;
-    cparams.n_ubatch = params.n_ubatch;
-    cparams.n_threads = params.cpuparams.n_threads;
-    cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
-        params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
-    cparams.logits_all = params.logits_all;
-    cparams.embeddings = params.embedding;
-    // cparams.rope_scaling_type = params.rope_scaling_type;
-    // cparams.rope_freq_base = params.rope_freq_base;
-    // cparams.rope_freq_scale = params.rope_freq_scale;
-    // cparams.yarn_ext_factor = params.yarn_ext_factor;
-    // cparams.yarn_attn_factor = params.yarn_attn_factor;
-    // cparams.yarn_beta_fast = params.yarn_beta_fast;
-    // cparams.yarn_beta_slow = params.yarn_beta_slow;
-    // cparams.yarn_orig_ctx = params.yarn_orig_ctx;
-    cparams.pooling_type = params.pooling_type;
-    cparams.attention_type = params.attention_type;
-    cparams.defrag_thold = params.defrag_thold;
-    cparams.cb_eval = params.cb_eval;
-    cparams.cb_eval_user_data = params.cb_eval_user_data;
-    cparams.offload_kqv = !params.no_kv_offload;
-    cparams.flash_attn = params.flash_attn;
-    cparams.no_perf = params.no_perf;
-
-    cparams.type_k = params.cache_type_k;
-    cparams.type_v = params.cache_type_v;
-
-    llama_context * lctx = llama_new_context_with_model(model, cparams);
-    if (lctx == NULL) {
-        // LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str());
-        llama_free_model(model);
-        return iparams;
-    }
-
-    // load and optionally apply lora adapters
-    for (auto & la : params.lora_adapters) {
-        lora_adapter_container loaded_la;
-        loaded_la.path = la.path;
-        loaded_la.scale = la.scale;
-        loaded_la.adapter = llama_lora_adapter_init(model, la.path.c_str());
-        if (loaded_la.adapter == nullptr) {
-            // LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
-            llama_free(lctx);
-            llama_free_model(model);
-            return iparams;
-        }
-        iparams.lora_adapters.push_back(loaded_la); // copy to list of loaded adapters
-    }
-    if (!params.lora_init_without_apply) {
-        llama_lora_adapter_clear(lctx);
-        for (auto & la : iparams.lora_adapters) {
-            if (la.scale != 0.0f) {
-                llama_lora_adapter_set(lctx, la.adapter, la.scale);
-            }
-        }
-    }
+// struct lora_adapter_info {
+//     std::string path;
+//     float scale;
+// };
-    iparams.model = model;
-    iparams.context = lctx;
+// struct lora_adapter_container : lora_adapter_info {
+//     struct llama_lora_adapter* adapter;
+// };
-    return iparams;
-}
+// struct init_result {
+//     std::vector<lora_adapter_container> lora_adapters;
+// };
 
 LlamaInitGuard::LlamaInitGuard(bool verbose) {
     if (!verbose) {
@@ -299,57 +125,88 @@ LlamaInitGuard::~LlamaInitGuard() {
     llama_backend_free();
 }
 
-ProstT5::ProstT5(const std::string& model_file, std::string & device) {
-    common_params params;
-    params.n_ubatch = params.n_batch;
-    params.warmup = false;
-    params.model = model_file;
-    params.cpuparams.n_threads = 1;
-    params.use_mmap = true;
-    params.devices = parse_device_list(device);
+ProstT5Model::ProstT5Model(const std::string& model_file, std::string& device) {
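+    // ProstT5Model owns only the llama_model; per-worker inference state lives in ProstT5 contexts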
+    auto mparams = llama_model_default_params();
+    std::vector<ggml_backend_dev_t> devices = parse_device_list(device);
+    if (!devices.empty()) {
+        mparams.devices = devices.data();
+    }
+
     int gpus = 0;
-    for (const auto& dev : params.devices) {
+    for (const auto& dev : devices) {
         if (!dev) {
             continue;
         }
         gpus += ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU;
     }
     if (gpus > 0) {
-        params.n_gpu_layers = 24;
+        mparams.n_gpu_layers = 24;
     } else {
-        params.n_gpu_layers = 0;
-    }
+        mparams.n_gpu_layers = 0;
+    }
+    mparams.use_mmap = true;
+    model = llama_load_model_from_file(model_file.c_str(), mparams);
+
+    // for (auto & la : params.lora_adapters) {
+    //     lora_adapter_container loaded_la;
+    //     loaded_la.path = la.path;
+    //     loaded_la.scale = la.scale;
+    //     loaded_la.adapter = llama_lora_adapter_init(model, la.path.c_str());
+    //     if (loaded_la.adapter == nullptr) {
+    //         llama_free_model(model);
+    //         return;
+    //     }
+    //     lora_adapters.push_back(loaded_la); // copy to list of loaded adapters
+    // }
+}
 
-    // load the model
-    init_result llama_init = init_from_params(params);
+ProstT5Model::~ProstT5Model() {
+    llama_free_model(model);
+}
 
-    model = llama_init.model;
-    ctx = llama_init.context;
+ProstT5::ProstT5(ProstT5Model& model, int threads) : model(model) {
+    auto cparams = llama_context_default_params();
+    cparams.n_threads = threads;
+    cparams.n_threads_batch = threads;
+    cparams.n_ubatch = 4096;
+    cparams.n_batch = 4096;
+    cparams.n_ctx = 4096;
+    cparams.embeddings = true;
+    cparams.attention_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL;
+
+    ctx = llama_new_context_with_model(model.model, cparams);
+    // batch = llama_batch_init(4096, 0, 1);
+    // if (!params.lora_init_without_apply) {
+    //     llama_lora_adapter_clear(lctx);
+    //     for (auto & la : iparams.lora_adapters) {
+    //         if (la.scale != 0.0f) {
+    //             llama_lora_adapter_set(lctx, la.adapter, la.scale);
+    //         }
+    //     }
+    // }
 };
 
 ProstT5::~ProstT5() {
     llama_free(ctx);
-    llama_free_model(model);
 }
 
 std::string ProstT5::predict(const std::string& aa) {
     std::string result;
     std::vector<llama_token> embd_inp;
     embd_inp.reserve(aa.length() + 2);
-    embd_inp.emplace_back(llama_token_get_token(model, "<AA2fold>"));
-    llama_token unk_aa = llama_token_get_token(model, "▁X");
+    embd_inp.emplace_back(llama_token_get_token(model.model, "<AA2fold>"));
+    llama_token unk_aa = llama_token_get_token(model.model, "▁X");
     for (size_t i = 0; i < aa.length(); ++i) {
         std::string current_char("▁");
         current_char.append(1, toupper(aa[i]));
-        llama_token token = llama_token_get_token(model, current_char.c_str());
+        llama_token token = llama_token_get_token(model.model, current_char.c_str());
         if (token == LLAMA_TOKEN_NULL) {
             embd_inp.emplace_back(unk_aa);
         } else {
             embd_inp.emplace_back(token);
         }
     }
-    embd_inp.emplace_back(llama_token_get_token(model, "</s>"));
-
+    embd_inp.emplace_back(llama_token_get_token(model.model, "</s>"));
     encode(ctx, embd_inp, result);
     return result;
 }
diff --git a/src/strucclustutils/ProstT5.h b/src/strucclustutils/ProstT5.h
index 16ef9ef6..9ad388dd 100644
--- a/src/strucclustutils/ProstT5.h
+++ b/src/strucclustutils/ProstT5.h
@@ -16,9 +16,17 @@ class LlamaInitGuard {
     LlamaInitGuard& operator=(const LlamaInitGuard&) = delete;
 };
 
+class ProstT5Model {
+public:
+    ProstT5Model(const std::string& model_file, std::string& device);
+    ~ProstT5Model();
+
+    llama_model* model;
+};
+
 class ProstT5 {
 public:
-    ProstT5(const std::string& model_file, std::string & device);
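+    // the context keeps only a reference to the model: the ProstT5Model must outlive every ProstT5 built from it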
+    ProstT5(ProstT5Model& model, int threads);
     ~ProstT5();
 
     static std::vector<std::string> getDevices();
@@ -26,8 +34,8 @@ class ProstT5 {
     std::string predict(const std::string& aa);
     void perf();
 
-    llama_model * model;
-    llama_context * ctx;
+    ProstT5Model& model;
+    llama_context* ctx;
 };
diff --git a/src/strucclustutils/ProstT5ForkRunner.h b/src/strucclustutils/ProstT5ForkRunner.h
new file mode 100644
index 00000000..b6eb24e3
--- /dev/null
+++ b/src/strucclustutils/ProstT5ForkRunner.h
@@ -0,0 +1,160 @@
+#ifndef PROSTT5_FORK_RUNNER_H
+#define PROSTT5_FORK_RUNNER_H
+#include "DBReader.h"
+#include "DBWriter.h"
+#include "ProstT5.h"
+
+#include <sys/ipc.h>
+#include <sys/msg.h>
+#include <sys/wait.h>
+#include <unistd.h>
+#include <cerrno>
+
+#ifdef OPENMP
+#include <omp.h>
+#endif
+
+struct TaskMsg {
+    long mtype;
+    long idx;
+};
+
+void prostt5Forking(
+        const std::string& modelWeights,
+        unsigned int split_length,
+        unsigned int minSplitLength,
+        DBReader<unsigned int>& reader,
+        const std::string& db,
+        const std::string& index,
+        int threads,
+        int compressed
+) {
+#ifdef OPENMP
+    // forking does not play well with OpenMP threads
+    omp_set_num_threads(1);
+#endif
+
+    int procs = (threads + 3) / 4;
+    int leftover = threads;
+    int msgid = msgget(IPC_PRIVATE, IPC_CREAT | 0666);
+    if (msgid == -1) {
+        Debug(Debug::ERROR) << "Could not create SysV message queue!\n";
+        EXIT(EXIT_FAILURE);
+    }
+
+    Debug::Progress progress(reader.getSize());
+    std::vector<pid_t> children;
+    int maxSplits = procs;
+    for (int p = 0; p < procs; ++p) {
+        int inner = std::min(4, leftover);
+        leftover -= inner;
+        switch (pid_t pid = fork()) {
+            default:
+                children.push_back(pid);
+                break;
+            case -1:
+                Debug(Debug::ERROR) << "Could not fork worker process!\n";
+                EXIT(EXIT_FAILURE);
+            case 0: {
+                std::string device = "none";
+                ProstT5Model model(modelWeights, device);
+                ProstT5 context(model, inner);
+                std::string result;
+                const char newline = '\n';
+
+                std::pair<std::string, std::string> outDb = Util::createTmpFileNames(db, index, p);
+                DBWriter writer(outDb.first.c_str(), outDb.second.c_str(), 1, compressed, reader.getDbtype());
+                writer.open();
+                while (true) {
+                    TaskMsg msg;
+                    if (msgrcv(msgid, &msg, sizeof(msg.idx), 0, 0) == -1) {
+                        Debug(Debug::ERROR) << "msgrcv failed in child " << p << "\n";
+                        _Exit(1);
+                    }
+                    if (msg.idx == -1) {
+                        break;
+                    }
+
+                    size_t i = static_cast<size_t>(msg.idx);
+                    unsigned int key = reader.getDbKey(i);
+                    size_t length = reader.getSeqLen(i);
+                    std::string seq = std::string(reader.getData(i, 0), length);
+                    result.clear();
+                    // splitting input sequences longer than ProstT5 attention (current cutoff 6000 AAs)
+                    // split length of 0 will deactivate splitting
+                    // Debug(Debug::INFO) << "split_length: " << split_length << " minSplitLength: " << minSplitLength << "\n";
+                    // Debug(Debug::INFO) << "seq: " << seq << "\n";
+                    if (split_length > 0 && length > split_length) {
+                        unsigned int n_splits, overlap_length;
+                        n_splits = int(length / split_length) + 1;
+                        overlap_length = length % split_length;
+
+                        // ensure minimum overlap length; the adjustment was not computed properly with ceil/ceilf, now using a simple int cast
+                        if (overlap_length < minSplitLength) {
+                            split_length -= int((minSplitLength - overlap_length) / (n_splits - 1)) + 1;
+                        }
+
+                        // loop over splits and predict
+                        for (unsigned int i = 0; i < n_splits; i++) {
+                            unsigned int split_start = i * split_length;
+                            result.append(context.predict(seq.substr(split_start, split_length)));
+                        }
+                    } else {
+                        result.append(context.predict(seq));
+                    }
+                    // Debug(Debug::INFO) << "p: " << p << "pred: " << result << "\n";
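+                    // each child appends to its own temporary split database (a single-threaded
+                    // DBWriter); the parent merges all splits after the workers exit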
+                    writer.writeStart(0);
+                    writer.writeAdd(result.c_str(), result.length(), 0);
+                    writer.writeAdd(&newline, 1, 0);
+                    writer.writeEnd(key, 0);
+                    progress.updateProgress();
+                }
+
+                std::cout.setstate(std::ios_base::failbit);
+                writer.close(true);
+                fflush(NULL);
+                sync();
+                _Exit(0);
+            }
+        }
+        if (leftover <= 0) {
+            maxSplits = p + 1;
+            break;
+        }
+    }
+
+    for (size_t i = 0; i < reader.getSize(); ++i) {
+        TaskMsg msg {1, static_cast<long>(i)};
+        if (msgsnd(msgid, &msg, sizeof(msg.idx), 0) == -1) {
+            Debug(Debug::ERROR) << "msgsnd failed for index " << i << "\n";
+            EXIT(EXIT_FAILURE);
+        }
+    }
+
+    for (int p = 0; p < procs; ++p) {
+        TaskMsg quitMsg {1, -1};
+        msgsnd(msgid, &quitMsg, sizeof(quitMsg.idx), 0);
+    }
+
+    for (const pid_t& child_pid : children) {
+        int status = 0;
+        while (waitpid(child_pid, &status, 0) == -1) {
+            if (errno == EINTR) {
+                continue;
+            }
+            perror("waitpid");
+            break;
+        }
+    }
+    msgctl(msgid, IPC_RMID, nullptr);
+    fflush(NULL);
+    sync();
+    std::pair<std::string, std::string> outDb = std::make_pair(db, index);
+    std::vector<std::pair<std::string, std::string>> splitFiles;
+    for (int p = 0; p < maxSplits; ++p) {
+        splitFiles.emplace_back(Util::createTmpFileNames(outDb.first, outDb.second, p));
+    }
+    DBWriter::mergeResults(outDb.first, outDb.second, splitFiles);
+}
+#endif
\ No newline at end of file
diff --git a/src/strucclustutils/structcreatedb.cpp b/src/strucclustutils/structcreatedb.cpp
index b98c8b3f..2059ab0b 100644
--- a/src/strucclustutils/structcreatedb.cpp
+++ b/src/strucclustutils/structcreatedb.cpp
@@ -19,6 +19,12 @@
 #ifdef HAVE_PROSTT5
 #include "ProstT5.h"
+#if !defined(__CYGWIN__) && !defined(__EMSCRIPTEN__) && !defined(__APPLE__)
+#include "ProstT5ForkRunner.h"
+#define FORK_RUNNER 1
+#else
+#define FORK_RUNNER 0
+#endif
 #endif
 
 #include
@@ -546,6 +552,7 @@ void sortDatafileByIdOrder(DBWriter & dbw,
         }
     }
 }
+
 extern int createdb(int argc, const char **argv, const Command& command);
 int structcreatedb(int argc, const char **argv, const Command& command) {
     LocalParameters& par = LocalParameters::getLocalInstance();
@@ -557,20 +564,14 @@
     for (size_t i = 0; i < command.params->size(); ++i) {
         command.params->at(i)->wasSet = false;
    }
+    par.shuffleDatabase = true;
+    par.PARAM_SHUFFLE.wasSet = true;
     int status = createdb(argc, argv, command);
     if (status != EXIT_SUCCESS) {
         return status;
     }
     fflush(stdout);
-    DBReader<unsigned int> reader(outputName.c_str(), (outputName+".index").c_str(), par.threads, DBReader<unsigned int>::USE_INDEX|DBReader<unsigned int>::USE_DATA);
-    reader.open(DBReader<unsigned int>::LINEAR_ACCCESS);
-
-    std::string ssDb = outputName + "_ss";
-    DBWriter writer(ssDb.c_str(), (ssDb + ".index").c_str(), par.threads, par.compressed, reader.getDbtype());
-    writer.open();
-    Debug::Progress progress(reader.getSize());
-
     std::vector<std::string> prefix = { "", "/model" };
     std::vector<std::string> suffix = { "", "/prostt5-f16.gguf" };
     // bool quantized = false;
@@ -608,6 +609,9 @@
     LlamaInitGuard guard(par.verbosity > 3);
     std::vector<std::string> devices = ProstT5::getDevices();
+    for (std::vector<std::string>::iterator it = devices.begin(); it != devices.end(); ++it) {
+        Debug(Debug::INFO) << *it << "\n";
+    }
     if (par.gpu == 1 && !devices.empty()) {
         for (std::vector<std::string>::iterator it = devices.begin(); it != devices.end();) {
             if (it->find("CUDA") == std::string::npos) {
                 it = devices.erase(it);
             } else {
                 ++it; // Move to the next element
             }
         }
-        if(devices.size() == 0) {
+        if (devices.size() == 0) {
             Debug(Debug::ERROR) << "No GPU devices found\n";
             return EXIT_FAILURE;
         }
@@ -631,62 +635,77 @@ int structcreatedb(int argc, const char **argv, const Command& command) {
         }
     }
 
-    #ifdef OPENMP
-    size_t localThreads = (par.gpu != 0) ? devices.size() : par.threads;
-    localThreads = std::max(std::min(localThreads, reader.getSize()), (size_t)1);
-    #endif
-    #pragma omp parallel num_threads(localThreads)
-    {
-        int thread_idx = 0;
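+    // CPU-only runs use the fork-based runner where available; GPU runs keep the in-process OpenMP path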
+    bool useForkRunner = FORK_RUNNER && par.gpu == 0;
+    DBReader<unsigned int> reader(outputName.c_str(), (outputName+".index").c_str(), par.threads, DBReader<unsigned int>::USE_INDEX|DBReader<unsigned int>::USE_DATA);
+    reader.open(useForkRunner ? DBReader<unsigned int>::SORT_BY_LENGTH : DBReader<unsigned int>::LINEAR_ACCCESS);
+
+    unsigned const int MIN_SPLIT_LENGTH = 2;
+    std::string ssDb = outputName + "_ss";
+    std::string ssIndex = ssDb + ".index";
+    if (useForkRunner) {
+        prostt5Forking(modelWeights, par.prostt5SplitLength, MIN_SPLIT_LENGTH, reader, ssDb, ssIndex, par.threads, par.compressed);
+    } else {
+        DBWriter writer(ssDb.c_str(), ssIndex.c_str(), par.threads, par.compressed, reader.getDbtype());
+        writer.open();
+
+        Debug::Progress progress(reader.getSize());
 #ifdef OPENMP
-        thread_idx = omp_get_thread_num();
+        size_t localThreads = par.gpu == 1 ? devices.size() : 1;
 #endif
-        std::string device = "none";
-        if (par.gpu == 1) {
-            device = devices[thread_idx];
-        }
-        ProstT5 model(modelWeights.c_str(), device);
-        const char newline = '\n';
-        unsigned const int MIN_SPLIT_LENGTH = 2;
-        std::string result;
-        #pragma omp for schedule(dynamic, 1)
-        for (size_t i = 0; i < reader.getSize(); ++i) {
-            unsigned int key = reader.getDbKey(i);
-            size_t length = reader.getSeqLen(i);
-            std::string seq = std::string(reader.getData(i, thread_idx), length);
-            result.clear();
-            // splitting input sequences longer than ProstT5 attention (current cutoff 6000 AAs)
-            unsigned int split_length = par.prostt5SplitLength;
-
-            // split lenght of 0 will deactivate splitting
-            if (split_length > 0 && length > split_length) {
-                unsigned int n_splits, overlap_length;
-                n_splits = int(length / split_length) + 1;
-                overlap_length = length % split_length;
-
-                // ensure minimum overlap length; adjustment length was not computed properly with ceil/ceilf now using simple int cast
-                if (overlap_length < MIN_SPLIT_LENGTH) {
-                    split_length -= int((MIN_SPLIT_LENGTH - overlap_length) / (n_splits - 1)) + 1;
-                }
+#pragma omp parallel num_threads(localThreads)
+        {
+            int thread_idx = 0;
+#ifdef OPENMP
+            thread_idx = omp_get_thread_num();
+#endif
+            std::string device = "none";
+            int localThreads = par.threads;
+            if (par.gpu == 1) {
+                device = devices[thread_idx];
+                localThreads = 1;
+            }
+            ProstT5Model model(modelWeights.c_str(), device);
+            ProstT5 context(model, localThreads);
+            const char newline = '\n';
+            std::string result;
+#pragma omp for schedule(dynamic, 1)
+            for (size_t i = 0; i < reader.getSize(); ++i) {
+                unsigned int key = reader.getDbKey(i);
+                size_t length = reader.getSeqLen(i);
+                std::string seq = std::string(reader.getData(i, thread_idx), length);
+                result.clear();
+
+                // splitting input sequences longer than ProstT5 attention (current cutoff 6000 AAs)
+                unsigned int split_length = par.prostt5SplitLength;
+                // split length of 0 will deactivate splitting
+                if (split_length > 0 && length > split_length) {
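+                    // e.g. a 7000-residue sequence with split_length 6000 gives n_splits = 2 and overlap_length = 1000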
+                    unsigned int n_splits, overlap_length;
+                    n_splits = int(length / split_length) + 1;
+                    overlap_length = length % split_length;
+
+                    // ensure minimum overlap length; the adjustment was not computed properly with ceil/ceilf, now using a simple int cast
+                    if (overlap_length < MIN_SPLIT_LENGTH) {
+                        split_length -= int((MIN_SPLIT_LENGTH - overlap_length) / (n_splits - 1)) + 1;
+                    }
 
-                // loop over splits and predict
-                for (unsigned int i = 0; i < n_splits; i++){
-                    unsigned int split_start = i * split_length;
-                    std::vector split_input(split_length);
-                    result.append(model.predict(std::string(seq.substr(split_start, split_length))));
+                    // loop over splits and predict
+                    for (unsigned int i = 0; i < n_splits; i++) {
+                        unsigned int split_start = i * split_length;
+                        result.append(context.predict(seq.substr(split_start, split_length)));
+                    }
+                } else {
+                    result.append(context.predict(seq));
                 }
-            } else {
-                result.append(model.predict(seq));
-            }
-            writer.writeStart(thread_idx);
-            writer.writeAdd(result.c_str(), result.length(), thread_idx);
-            writer.writeAdd(&newline, 1, thread_idx);
-            writer.writeEnd(key, thread_idx);
-            progress.updateProgress();
+                writer.writeStart(thread_idx);
+                writer.writeAdd(result.c_str(), result.length(), thread_idx);
+                writer.writeAdd(&newline, 1, thread_idx);
+                writer.writeEnd(key, thread_idx);
+                progress.updateProgress();
+            }
         }
+        writer.close(true);
     }
-    writer.close(true);
     reader.close();
 
     DBReader<unsigned int> resultReader(ssDb.c_str(), (ssDb+".index").c_str(), par.threads, DBReader<unsigned int>::USE_INDEX|DBReader<unsigned int>::USE_DATA);