Skip to content

Commit

Permalink
Fix test with resources (pytorch#7071)
Browse files Browse the repository at this point in the history
Fix test failure due to resources not handled correctly by ios tests.

Differential Revision: [D66392647](https://our.internmc.facebook.com/intern/diff/D66392647/)

ghstack-source-id: 255370795
Pull Request resolved: pytorch#7062

Co-authored-by: Mengwei Liu <[email protected]>
  • Loading branch information
pytorchbot and larryliu0820 authored Nov 25, 2024
1 parent a1f668d commit 20c8e8c
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 21 deletions.
15 changes: 13 additions & 2 deletions examples/models/llama/tokenizer/test/test_tiktoken.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,31 @@

#include <gtest/gtest.h>

#ifdef EXECUTORCH_FB_BUCK
#include <TestResourceUtils/TestResourceUtils.h>
#endif

using namespace ::testing;

using ::example::Version;
using ::executorch::extension::llm::Tokenizer;
using ::executorch::runtime::Error;
using ::executorch::runtime::Result;

static std::string get_resource_path(const std::string& name) {
#ifdef EXECUTORCH_FB_BUCK
return facebook::xplat::testing::getPathForTestResource("resources/" + name);
#else
return std::getenv("RESOURCES_PATH") + std::string("/") + name;
#endif
}

class MultimodalTiktokenV5ExtensionTest : public Test {
public:
void SetUp() override {
executorch::runtime::runtime_init();
tokenizer_ = get_tiktoken_for_llama(Version::Multimodal);
modelPath_ = std::getenv("RESOURCES_PATH") +
std::string("/test_tiktoken_tokenizer.model");
modelPath_ = get_resource_path("test_tiktoken_tokenizer.model");
}

std::unique_ptr<Tokenizer> tokenizer_;
Expand Down
8 changes: 8 additions & 0 deletions extension/llm/tokenizer/test/test_bpe_tokenizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
* LICENSE file in the root directory of this source tree.
*/

#ifdef EXECUTORCH_FB_BUCK
#include <TestResourceUtils/TestResourceUtils.h>
#endif
#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
#include <executorch/runtime/platform/runtime.h>
#include <gtest/gtest.h>
Expand All @@ -23,8 +26,13 @@ class TokenizerExtensionTest : public Test {
void SetUp() override {
executorch::runtime::runtime_init();
tokenizer_ = std::make_unique<BPETokenizer>();
#ifdef EXECUTORCH_FB_BUCK
modelPath_ = facebook::xplat::testing::getPathForTestResource(
"resources/test_bpe_tokenizer.bin");
#else
modelPath_ =
std::getenv("RESOURCES_PATH") + std::string("/test_bpe_tokenizer.bin");
#endif
}

std::unique_ptr<Tokenizer> tokenizer_;
Expand Down
40 changes: 21 additions & 19 deletions extension/llm/tokenizer/test/test_tiktoken.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
* LICENSE file in the root directory of this source tree.
*/

#ifdef EXECUTORCH_FB_BUCK
#include <TestResourceUtils/TestResourceUtils.h>
#endif
#include <executorch/extension/llm/tokenizer/tiktoken.h>
#include <executorch/runtime/platform/runtime.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <sstream>
#include <vector>

using namespace ::testing;
Expand Down Expand Up @@ -47,6 +49,15 @@ static inline std::unique_ptr<std::vector<std::string>> _get_special_tokens() {
}
return special_tokens;
}

static inline std::string _get_resource_path(const std::string& name) {
#ifdef EXECUTORCH_FB_BUCK
return facebook::xplat::testing::getPathForTestResource("resources/" + name);
#else
return std::getenv("RESOURCES_PATH") + std::string("/") + name;
#endif
}

} // namespace

class TiktokenExtensionTest : public Test {
Expand All @@ -55,8 +66,7 @@ class TiktokenExtensionTest : public Test {
executorch::runtime::runtime_init();
tokenizer_ = std::make_unique<Tiktoken>(
_get_special_tokens(), kBOSTokenIndex, kEOSTokenIndex);
modelPath_ = std::getenv("RESOURCES_PATH") +
std::string("/test_tiktoken_tokenizer.model");
modelPath_ = _get_resource_path("test_tiktoken_tokenizer.model");
}

std::unique_ptr<Tokenizer> tokenizer_;
Expand Down Expand Up @@ -144,44 +154,36 @@ TEST_F(TiktokenExtensionTest, ConstructionWithInvalidEOSIndex) {
}

TEST_F(TiktokenExtensionTest, LoadWithInvalidPath) {
auto invalidModelPath =
std::getenv("RESOURCES_PATH") + std::string("/nonexistent.model");

Error res = tokenizer_->load(invalidModelPath.c_str());
auto invalidModelPath = "./nonexistent.model";
Error res = tokenizer_->load(invalidModelPath);
EXPECT_EQ(res, Error::InvalidArgument);
}

TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithInvalidRank) {
auto invalidModelPath = std::getenv("RESOURCES_PATH") +
std::string("/test_tiktoken_invalid_rank.model");

auto invalidModelPath =
_get_resource_path("test_tiktoken_invalid_rank.model");
Error res = tokenizer_->load(invalidModelPath.c_str());

EXPECT_EQ(res, Error::InvalidArgument);
}

TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithInvalidBase64) {
auto invalidModelPath = std::getenv("RESOURCES_PATH") +
std::string("/test_tiktoken_invalid_base64.model");

auto invalidModelPath =
_get_resource_path("test_tiktoken_invalid_base64.model");
Error res = tokenizer_->load(invalidModelPath.c_str());

EXPECT_EQ(res, Error::InvalidArgument);
}

TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithNoSpace) {
auto invalidModelPath = std::getenv("RESOURCES_PATH") +
std::string("/test_tiktoken_no_space.model");

auto invalidModelPath = _get_resource_path("test_tiktoken_no_space.model");
Error res = tokenizer_->load(invalidModelPath.c_str());

EXPECT_EQ(res, Error::InvalidArgument);
}

TEST_F(TiktokenExtensionTest, LoadTiktokenFileWithBPEFile) {
auto invalidModelPath =
std::getenv("RESOURCES_PATH") + std::string("/test_bpe_tokenizer.bin");

auto invalidModelPath = _get_resource_path("test_bpe_tokenizer.bin");
Error res = tokenizer_->load(invalidModelPath.c_str());

EXPECT_EQ(res, Error::InvalidArgument);
Expand Down

0 comments on commit 20c8e8c

Please sign in to comment.