-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[shortfin] Add C++ tokenizer wrapper library.
* This is gated by SHORTFIN_ENABLE_TOKENIZERS (presently off). * I'd like to either take over the wrapper or get mlc-ai/tokenizers-cpp#50 before putting much weight on this. * There is no great C++ option for this component, so we go to the trouble of integrating a Rust component. We will need to do a bit of prep on our CI systems to enable this by default. * Python API will be added in a subsequent commit. This should be more efficient than the tokenizers Python API since we will allow direct access to the tokens vs doing a lot of conversions. * Obligatory language flame bait: Use Rust, they said. It's super efficient. Prior to this patch, libshortfin was 1.8MB, which gave us an entire GPU and CPU runtime stack. After this patch (stripped) it is 8.4MB. Given how important the use case is, I'm willing to tolerate this for the moment. It seems like there is room for something better here, which is why I did not expose the underlying vendor'd API directly.
- Loading branch information
1 parent
ddc3091
commit 2016aae
Showing
9 changed files
with
330 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
# Downloads some test data file as part of configure. | ||
# This does a download->rename in an attempt to be robust to partial downloads. | ||
# It should not be used to manage large test data files or anything sensitive | ||
# enough to require a hash check. | ||
# The output file is added as an additional clean file on the global | ||
# shortfin_testdata_deps target, meaning the "ninja clean" will remove it. | ||
# It is also added to the current directories list of configure depends, which | ||
# means that if ninja is run and it is not present, cmake will be re-invoked. | ||
function(shortfin_download_test_data) | ||
cmake_parse_arguments( | ||
_RULE | ||
"" | ||
"URL;OUTPUT_FILE" | ||
"" | ||
${ARGN} | ||
) | ||
if(NOT EXISTS "${_RULE_OUTPUT_FILE}") | ||
set(_stage_file "${_RULE_OUTPUT_FILE}.stage") | ||
message(STATUS "Downloading test data ${_RULE_URL} -> ${_RULE_OUTPUT_FILE}") | ||
file(DOWNLOAD "${_RULE_URL}" "${_stage_file}" STATUS _status) | ||
list(POP_FRONT _status _status_code) | ||
if(_status_code EQUAL "0") | ||
file(RENAME "${_stage_file}" "${_RULE_OUTPUT_FILE}") | ||
else() | ||
message(SEND_ERROR "Error downloading file ${_RULE_URL} -> ${_RULE_OUTPUT_FILE}") | ||
endif() | ||
endif() | ||
|
||
# Make clean remove it. | ||
set_property( | ||
TARGET shortfin_testdata_deps | ||
APPEND PROPERTY ADDITIONAL_CLEAN_FILES | ||
"${CMAKE_CURRENT_BINARY_DIR}/tokenizer.json" | ||
) | ||
|
||
# And make us reconfigure if it isn't there. | ||
set_property( | ||
DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}" | ||
APPEND PROPERTY | ||
CMAKE_CONFIGURE_DEPENDS "${_RULE_OUTPUT_FILE}") | ||
endfunction() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
39 changes: 39 additions & 0 deletions
39
shortfin/src/shortfin/components/tokenizers/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
if(SHORTFIN_ENABLE_TOKENIZERS) | ||
shortfin_cc_component( | ||
NAME | ||
shortfin_tokenizers | ||
HDRS | ||
tokenizers.h | ||
SRCS | ||
tokenizers.cc | ||
DEFINES | ||
SHORTFIN_HAVE_TOKENIZERS | ||
COMPONENTS | ||
shortfin_support | ||
DEPS | ||
tokenizers_cpp | ||
) | ||
set_property(GLOBAL APPEND | ||
PROPERTY SHORTFIN_LIB_OPTIONAL_COMPONENTS | ||
shortfin_tokenizers) | ||
target_compile_definitions(shortfin_public_defs INTERFACE SHORTFIN_HAVE_TOKENIZERS) | ||
|
||
# Download test data. | ||
shortfin_download_test_data( | ||
URL "https://huggingface.co/google-bert/bert-base-cased/resolve/main/tokenizer.json" | ||
OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/tokenizer.json" | ||
) | ||
|
||
# Note that tests run from the binary dir of the project. | ||
shortfin_gtest_test( | ||
NAME shortfin_tokenizers_test | ||
SRCS | ||
tokenizers_test.cc | ||
) | ||
endif() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
// Copyright 2024 Advanced Micro Devices, Inc. | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include "shortfin/components/tokenizers/tokenizers.h" | ||
|
||
#include <exception> | ||
|
||
#include "shortfin/support/logging.h" | ||
#include "tokenizers_cpp.h" | ||
|
||
namespace shortfin::tokenizers { | ||
|
||
namespace { | ||
|
||
class AccessibleTokenizer : public Tokenizer { | ||
public: | ||
using Tokenizer::vendor_tokenizer_; | ||
}; | ||
|
||
::tokenizers::Tokenizer *Get(Tokenizer *self) { | ||
void *ptr = static_cast<AccessibleTokenizer *>(self)->vendor_tokenizer_; | ||
if (!ptr) { | ||
throw std::logic_error("Tokenizer is null"); | ||
} | ||
return static_cast<::tokenizers::Tokenizer *>(ptr); | ||
} | ||
|
||
} // namespace | ||
|
||
Tokenizer::~Tokenizer() { delete Get(this); } | ||
|
||
Tokenizer Tokenizer::FromBlobJSON(const std::string &json_blob) { | ||
SHORTFIN_TRACE_SCOPE_NAMED("Tokenizer::FromBlobJSON"); | ||
return Tokenizer(::tokenizers::Tokenizer::FromBlobJSON(json_blob).release()); | ||
} | ||
|
||
std::vector<int32_t> Tokenizer::Encode(const std::string &text) { | ||
SHORTFIN_TRACE_SCOPE_NAMED("Tokenizer::Encode"); | ||
return Get(this)->Encode(text); | ||
} | ||
|
||
std::vector<std::vector<int32_t>> Tokenizer::EncodeBatch( | ||
const std::vector<std::string> &texts) { | ||
SHORTFIN_TRACE_SCOPE_NAMED("Tokenizer::EncodeBatch"); | ||
return Get(this)->EncodeBatch(texts); | ||
} | ||
|
||
std::string Tokenizer::Decode(const std::vector<int32_t> &ids) { | ||
SHORTFIN_TRACE_SCOPE_NAMED("Tokenizer::Decode"); | ||
return Get(this)->Decode(ids); | ||
} | ||
size_t Tokenizer::GetVocabSize() { return Get(this)->GetVocabSize(); } | ||
std::string Tokenizer::IdToToken(int32_t token_id) { | ||
return Get(this)->IdToToken(token_id); | ||
} | ||
int32_t Tokenizer::TokenToId(const std::string &token) { | ||
return Get(this)->TokenToId(token); | ||
} | ||
|
||
} // namespace shortfin::tokenizers |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
// Copyright 2024 Advanced Micro Devices, Inc. | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#ifndef SHORTFIN_COMPONENTS_TOKENIZERS_TOKENIZERS_H | ||
#define SHORTFIN_COMPONENTS_TOKENIZERS_TOKENIZERS_H | ||
|
||
#include <string> | ||
#include <vector> | ||
|
||
#include "shortfin/support/api.h" | ||
|
||
namespace shortfin::tokenizers { | ||
|
||
// A vendored Tokenizer class that does not export the details of the backing | ||
// implementation. While a little bit gross, this keeps us from needing to | ||
// re-export a vendor'ed API as part of our public API. | ||
// The current vendor tokenizer is based on mlc-ai/tokenizers-cpp. The API | ||
// is fairly close to that implementation. | ||
// See: https://github.com/mlc-ai/tokenizers-cpp | ||
class SHORTFIN_API Tokenizer { | ||
public: | ||
Tokenizer(const Tokenizer &) = delete; | ||
Tokenizer &operator=(const Tokenizer &) = delete; | ||
Tokenizer(Tokenizer &&other) : vendor_tokenizer_(other.vendor_tokenizer_) { | ||
vendor_tokenizer_ = nullptr; | ||
} | ||
~Tokenizer(); | ||
|
||
// Factory functions. | ||
static Tokenizer FromBlobJSON(const std::string &json_blob); | ||
|
||
std::vector<int32_t> Encode(const std::string &text); | ||
std::vector<std::vector<int32_t>> EncodeBatch( | ||
const std::vector<std::string> &texts); | ||
std::string Decode(const std::vector<int32_t> &ids); | ||
size_t GetVocabSize(); | ||
std::string IdToToken(int32_t token_id); | ||
int32_t TokenToId(const std::string &token); | ||
|
||
private: | ||
Tokenizer(void *vendor_tokenizer) : vendor_tokenizer_(vendor_tokenizer) {} | ||
|
||
protected: | ||
void *vendor_tokenizer_; | ||
}; | ||
|
||
} // namespace shortfin::tokenizers | ||
|
||
#endif // SHORTFIN_COMPONENTS_TOKENIZERS_TOKENIZERS_H |
Oops, something went wrong.