Skip to content

Commit

Permalink
use ::testing::TempDir/SrcDir
Browse files Browse the repository at this point in the history
  • Loading branch information
taku910 committed Feb 26, 2024
1 parent 3b2ea62 commit 9082653
Show file tree
Hide file tree
Showing 12 changed files with 55 additions and 50 deletions.
10 changes: 5 additions & 5 deletions src/bpe_model_trainer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ std::string RunTrainer(
const std::vector<std::string> &input, int size,
const std::vector<std::string> &user_defined_symbols = {}) {
const std::string input_file =
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "input");
util::JoinPath(::testing::TempDir(), "input");
const std::string model_prefix =
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "model");
util::JoinPath(::testing::TempDir(), "model");
{
auto output = filesystem::NewWritableFile(input_file);
for (const auto &line : input) {
Expand Down Expand Up @@ -93,13 +93,13 @@ static constexpr char kTestInputData[] = "wagahaiwa_nekodearu.txt";

TEST(BPETrainerTest, EndToEndTest) {
const std::string input =
util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kTestInputData);
util::JoinPath(::testing::SrcDir(), kTestInputData);

ASSERT_TRUE(
SentencePieceTrainer::Train(
absl::StrCat(
"--model_prefix=",
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "tmp_model"),
util::JoinPath(::testing::TempDir(), "tmp_model"),
" --input=", input,
" --vocab_size=8000 --normalization_rule_name=identity"
" --model_type=bpe --control_symbols=<ctrl> "
Expand All @@ -108,7 +108,7 @@ TEST(BPETrainerTest, EndToEndTest) {

SentencePieceProcessor sp;
ASSERT_TRUE(sp.Load(std::string(util::JoinPath(
absl::GetFlag(FLAGS_test_tmpdir), "tmp_model.model")))
::testing::TempDir(), "tmp_model.model")))
.ok());
EXPECT_EQ(8000, sp.GetPieceSize());

Expand Down
14 changes: 7 additions & 7 deletions src/builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ TEST(BuilderTest, LoadCharsMapTest) {
Builder::CharsMap chars_map;
ASSERT_TRUE(
Builder::LoadCharsMap(
util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kTestInputData),
util::JoinPath(::testing::SrcDir(), kTestInputData),
&chars_map)
.ok());

Expand All @@ -158,14 +158,14 @@ TEST(BuilderTest, LoadCharsMapTest) {

ASSERT_TRUE(
Builder::SaveCharsMap(
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "output.tsv"),
util::JoinPath(::testing::TempDir(), "output.tsv"),
chars_map)
.ok());

Builder::CharsMap saved_chars_map;
ASSERT_TRUE(
Builder::LoadCharsMap(
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "output.tsv"),
util::JoinPath(::testing::TempDir(), "output.tsv"),
&saved_chars_map)
.ok());
EXPECT_EQ(chars_map, saved_chars_map);
Expand All @@ -180,15 +180,15 @@ TEST(BuilderTest, LoadCharsMapTest) {
TEST(BuilderTest, LoadCharsMapWithEmptyeTest) {
{
auto output = filesystem::NewWritableFile(
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "test.tsv"));
util::JoinPath(::testing::TempDir(), "test.tsv"));
output->WriteLine("0061\t0041");
output->WriteLine("0062");
output->WriteLine("0063\t\t#foo=>bar");
}

Builder::CharsMap chars_map;
EXPECT_TRUE(Builder::LoadCharsMap(
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "test.tsv"),
util::JoinPath(::testing::TempDir(), "test.tsv"),
&chars_map)
.ok());

Expand All @@ -199,14 +199,14 @@ TEST(BuilderTest, LoadCharsMapWithEmptyeTest) {

EXPECT_TRUE(
Builder::SaveCharsMap(
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "test_out.tsv"),
util::JoinPath(::testing::TempDir(), "test_out.tsv"),
chars_map)
.ok());

Builder::CharsMap new_chars_map;
EXPECT_TRUE(
Builder::LoadCharsMap(
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "test_out.tsv"),
util::JoinPath(::testing::TempDir(), "test_out.tsv"),
&new_chars_map)
.ok());
EXPECT_EQ(chars_map, new_chars_map);
Expand Down
4 changes: 2 additions & 2 deletions src/char_model_trainer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ namespace {

std::string RunTrainer(const std::vector<std::string> &input, int size) {
const std::string input_file =
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "input");
util::JoinPath(::testing::TempDir(), "input");
const std::string model_prefix =
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "model");
util::JoinPath(::testing::TempDir(), "model");
{
auto output = filesystem::NewWritableFile(input_file);
for (const auto &line : input) {
Expand Down
4 changes: 2 additions & 2 deletions src/filesystem_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ TEST(UtilTest, FilesystemTest) {

{
auto output = filesystem::NewWritableFile(
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "test_file"));
util::JoinPath(::testing::TempDir(), "test_file"));
for (size_t i = 0; i < kData.size(); ++i) {
output->WriteLine(kData[i]);
}
}

{
auto input = filesystem::NewReadableFile(
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "test_file"));
util::JoinPath(::testing::TempDir(), "test_file"));
std::string line;
for (size_t i = 0; i < kData.size(); ++i) {
EXPECT_TRUE(input->ReadLine(&line));
Expand Down
8 changes: 4 additions & 4 deletions src/sentencepiece_processor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -994,13 +994,13 @@ TEST(SentencePieceProcessorTest, EndToEndTest) {

{
auto output = filesystem::NewWritableFile(
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "model"), true);
util::JoinPath(::testing::TempDir(), "model"), true);
output->Write(model_proto.SerializeAsString());
}

SentencePieceProcessor sp;
EXPECT_TRUE(
sp.Load(util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "model")).ok());
sp.Load(util::JoinPath(::testing::TempDir(), "model")).ok());

EXPECT_EQ(model_proto.SerializeAsString(),
sp.model_proto().SerializeAsString());
Expand Down Expand Up @@ -1467,10 +1467,10 @@ TEST(SentencePieceProcessorTest, VocabularyTest) {
auto GetInlineFilename = [](const std::string content) {
{
auto out = filesystem::NewWritableFile(
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "vocab.txt"));
util::JoinPath(::testing::TempDir(), "vocab.txt"));
out->Write(content);
}
return util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "vocab.txt");
return util::JoinPath(::testing::TempDir(), "vocab.txt");
};

sp1->set_type(ModelProto::SentencePiece::UNKNOWN);
Expand Down
32 changes: 16 additions & 16 deletions src/sentencepiece_trainer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ void CheckNormalizer(absl::string_view filename, bool expected_has_normalizer,

TEST(SentencePieceTrainerTest, TrainFromArgsTest) {
const std::string input =
util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kTestData);
util::JoinPath(::testing::SrcDir(), kTestData);
const std::string model =
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "m");
util::JoinPath(::testing::TempDir(), "m");

ASSERT_TRUE(SentencePieceTrainer::Train(
absl::StrCat("--input=", input, " --model_prefix=", model,
Expand Down Expand Up @@ -118,9 +118,9 @@ TEST(SentencePieceTrainerTest, TrainFromIterator) {
};

const std::string input =
util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kTestData);
util::JoinPath(::testing::SrcDir(), kTestData);
const std::string model =
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "m");
util::JoinPath(::testing::TempDir(), "m");

std::vector<std::string> sentences;
{
Expand Down Expand Up @@ -154,11 +154,11 @@ TEST(SentencePieceTrainerTest, TrainFromIterator) {

TEST(SentencePieceTrainerTest, TrainWithCustomNormalizationRule) {
std::string input =
util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kTestData);
util::JoinPath(::testing::SrcDir(), kTestData);
std::string rule =
util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kNfkcTestData);
util::JoinPath(::testing::SrcDir(), kNfkcTestData);
const std::string model =
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "m");
util::JoinPath(::testing::TempDir(), "m");

EXPECT_TRUE(SentencePieceTrainer::Train(
absl::StrCat("--input=", input, " --model_prefix=", model,
Expand All @@ -170,13 +170,13 @@ TEST(SentencePieceTrainerTest, TrainWithCustomNormalizationRule) {

TEST(SentencePieceTrainerTest, TrainWithCustomDenormalizationRule) {
const std::string input_file =
util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kTestDataJa);
util::JoinPath(::testing::SrcDir(), kTestDataJa);
const std::string model =
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "m");
util::JoinPath(::testing::TempDir(), "m");
const std::string norm_rule_tsv =
util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kIdsNormTsv);
util::JoinPath(::testing::SrcDir(), kIdsNormTsv);
const std::string denorm_rule_tsv =
util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kIdsDenormTsv);
util::JoinPath(::testing::SrcDir(), kIdsDenormTsv);
EXPECT_TRUE(
SentencePieceTrainer::Train(
absl::StrCat("--input=", input_file, " --model_prefix=", model,
Expand All @@ -199,9 +199,9 @@ TEST(SentencePieceTrainerTest, TrainErrorTest) {
TEST(SentencePieceTrainerTest, TrainTest) {
TrainerSpec trainer_spec;
trainer_spec.add_input(
util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kTestData));
util::JoinPath(::testing::SrcDir(), kTestData));
trainer_spec.set_model_prefix(
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "m"));
util::JoinPath(::testing::TempDir(), "m"));
trainer_spec.set_vocab_size(1000);
NormalizerSpec normalizer_spec;
ASSERT_TRUE(SentencePieceTrainer::Train(trainer_spec, normalizer_spec).ok());
Expand Down Expand Up @@ -366,12 +366,12 @@ TEST(SentencePieceTrainerTest, PopulateModelTypeFromStringTest) {

TEST(SentencePieceTrainerTest, NormalizationTest) {
const auto model_prefix =
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "m");
util::JoinPath(::testing::TempDir(), "m");
const auto model_file = absl::StrCat(model_prefix, ".model");

TrainerSpec trainer_spec;
trainer_spec.add_input(
util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kTestData));
util::JoinPath(::testing::SrcDir(), kTestData));
trainer_spec.set_model_prefix(model_prefix);
trainer_spec.set_vocab_size(1000);
ASSERT_TRUE(SentencePieceTrainer::Train(trainer_spec).ok());
Expand Down Expand Up @@ -424,7 +424,7 @@ TEST(SentencePieceTrainerTest, NormalizationTest) {
{
SentencePieceNormalizer sp;
EXPECT_OK(sp.LoadFromRuleTSV(
util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), "nfkc_cf.tsv")));
util::JoinPath(::testing::SrcDir(), "nfkc_cf.tsv")));
set_normalization_only(&sp);
EXPECT_EQ(sp.Normalize("ABCD"), "abcd");
}
Expand Down
4 changes: 2 additions & 2 deletions src/testharness.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ bool RegisterTest(const char *base, const char *name, void (*func)()) {
int RunAllTests() {
int num = 0;
#ifdef OS_WIN
_mkdir(absl::GetFlag(FLAGS_test_tmpdir).c_str());
_mkdir(::testing::TempDir().c_str());
#else
mkdir(absl::GetFlag(FLAGS_test_tmpdir).c_str(), S_IRUSR | S_IWUSR | S_IXUSR);
mkdir(::testing::TempDir().c_str(), S_IRUSR | S_IWUSR | S_IXUSR);
#endif

if (tests == nullptr) {
Expand Down
5 changes: 5 additions & 0 deletions src/testharness.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@
ABSL_DECLARE_FLAG(std::string, test_tmpdir);
ABSL_DECLARE_FLAG(std::string, test_srcdir);

namespace testing {
inline std::string TempDir() { return absl::GetFlag(FLAGS_test_tmpdir); }
inline std::string SrcDir() { return absl::GetFlag(FLAGS_test_srcdir); }
} // namespace testing

namespace sentencepiece {
namespace test {
// Run some of the tests registered by the TEST() macro.
Expand Down
6 changes: 3 additions & 3 deletions src/trainer_interface_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ TEST(TrainerInterfaceTest, SerializeTest) {

TEST(TrainerInterfaceTest, CharactersTest) {
const std::string input_file =
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "input");
util::JoinPath(::testing::TempDir(), "input");
{
auto output = filesystem::NewWritableFile(input_file);
// Make a single line with 50 "a", 49 "あ", and 1 "b".
Expand Down Expand Up @@ -559,7 +559,7 @@ TEST(TrainerInterfaceTest, MultiFileSentenceIteratorTest) {
std::vector<std::string> files;
std::vector<std::string> expected;
for (int i = 0; i < 10; ++i) {
const std::string file = util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir),
const std::string file = util::JoinPath(::testing::TempDir(),
absl::StrCat("input", i));
auto output = filesystem::NewWritableFile(file);
int num_line = (rand() % 100) + 1;
Expand All @@ -581,7 +581,7 @@ TEST(TrainerInterfaceTest, MultiFileSentenceIteratorTest) {
TEST(TrainerInterfaceTest, MultiFileSentenceIteratorErrorTest) {
std::vector<std::string> files;
for (int i = 0; i < 10; ++i) {
const std::string file = util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir),
const std::string file = util::JoinPath(::testing::TempDir(),
absl::StrCat("input_not_exist", i));
files.push_back(file);
}
Expand Down
10 changes: 5 additions & 5 deletions src/unigram_model_trainer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ TrainerResult RunTrainer(const std::vector<std::string>& input, int size,
const bool use_dp = false, const float dp_noise = 0.0,
const uint32 dp_clip = 0) {
const std::string input_file =
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "input");
util::JoinPath(::testing::TempDir(), "input");
const std::string model_prefix =
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "model");
util::JoinPath(::testing::TempDir(), "model");
{
auto output = filesystem::NewWritableFile(input_file);
for (const auto& line : input) {
Expand Down Expand Up @@ -154,21 +154,21 @@ static constexpr char kTestInputData[] = "wagahaiwa_nekodearu.txt";

TEST(UnigramTrainerTest, EndToEndTest) {
const std::string input =
util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kTestInputData);
util::JoinPath(::testing::SrcDir(), kTestInputData);

ASSERT_TRUE(
SentencePieceTrainer::Train(
absl::StrCat(
"--model_prefix=",
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "tmp_model"),
util::JoinPath(::testing::TempDir(), "tmp_model"),
" --input=", input,
" --vocab_size=8000 --normalization_rule_name=identity",
" --model_type=unigram --user_defined_symbols=<user>",
" --control_symbols=<ctrl> --max_sentence_length=2048"))
.ok());

SentencePieceProcessor sp;
EXPECT_TRUE(sp.Load(util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir),
EXPECT_TRUE(sp.Load(util::JoinPath(::testing::TempDir(),
"tmp_model.model"))
.ok());
EXPECT_EQ(8000, sp.GetPieceSize());
Expand Down
4 changes: 2 additions & 2 deletions src/util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -332,15 +332,15 @@ TEST(UtilTest, InputOutputBufferTest) {

{
auto output = filesystem::NewWritableFile(
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "test_file"));
util::JoinPath(::testing::TempDir(), "test_file"));
for (size_t i = 0; i < kData.size(); ++i) {
output->WriteLine(kData[i]);
}
}

{
auto input = filesystem::NewReadableFile(
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "test_file"));
util::JoinPath(::testing::TempDir(), "test_file"));
std::string line;
for (size_t i = 0; i < kData.size(); ++i) {
EXPECT_TRUE(input->ReadLine(&line));
Expand Down
4 changes: 2 additions & 2 deletions src/word_model_trainer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ namespace {

std::string RunTrainer(const std::vector<std::string> &input, int size) {
const std::string input_file =
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "input");
util::JoinPath(::testing::TempDir(), "input");
const std::string model_prefix =
util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "model");
util::JoinPath(::testing::TempDir(), "model");
{
auto output = filesystem::NewWritableFile(input_file);
for (const auto &line : input) {
Expand Down

0 comments on commit 9082653

Please sign in to comment.