diff --git a/README.md b/README.md
index 1841881..b0fe761 100644
--- a/README.md
+++ b/README.md
@@ -99,6 +99,30 @@ torchrun --standalone --nproc_per_node=1 ml_mdm/clis/generate_sample.py --port
## Codebase
+
+### 1. /configs
+
+| module | description |
+| - | - |
+| `configs.dataset_creation` | Configuration files for splitting a dataset into train/eval/validation sets |
+| `configs.datasets` | Dataset configurations used during the model's training and evaluation phases |
+| `configs.models` | Configuration files for the models at different resolutions |
+
+
+### 2. /data
+
+| file | description |
+| - | - |
+| `bert.vocab` | BERT vocabulary (tokens and their scores) used by the tokenizer |
+| `c4_wpm.vocab` | C4 WordPiece vocabulary (tokens and their scores) |
+| `cifar10.vocab` | CIFAR10 vocabulary (tokens and their scores) |
+| `imagenet.vocab` | Token vocabulary for the ImageNet prompts |
+| `prompts_cc12m-64x64.tsv` | Prompts for the CC12M dataset, used with the 64x64 model |
+| `prompts_cc12m-256x256.tsv` | Prompts for the CC12M dataset, used with the 256x256 model |
+| `prompts_cifar10-32x32.tsv` | Prompts for the CIFAR10 dataset, used with the 32x32 model |
+| `prompts_cifar10-64x64.tsv` | Prompts for the CIFAR10 dataset, used with the 64x64 model |
+| `prompts_demo.tsv` | Extra demo prompts |
+| `prompts_imagenet-64px.tsv` | Prompts for the ImageNet dataset, used with the 64x64 model |
+| `prompts_WebImage-ALIGN-64px.tsv` | Prompts for the WebImage-ALIGN dataset, used with the 64x64 model |
+| `t5.vocab` | T5 vocabulary (tokens and their scores) |
+| `tokenizer_spm_32000_50m.vocab` | SentencePiece (SPM) vocabulary (tokens and their scores) |
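+
+The `.vocab` files can be loaded with the project's `Tokenizer` (`ml_mdm/language_models/tokenizer.py`). A minimal sketch, assuming the repository root as the working directory:
+
+```python
+from ml_mdm.language_models.tokenizer import Tokenizer
+
+# Build a trie-backed tokenizer from one of the vocabulary files above;
+# use mode="t5" for t5.vocab, or omit mode for the other vocab files.
+tokenizer = Tokenizer("data/bert.vocab", mode="bert")
+print(tokenizer.vocab_size)
+```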
+
+### 3. /docs
+
+| module | description |
+| - | - |
+| `docs` | `web_demo.png`: screenshot of the model's web demo |
+
+### 4. /ml_mdm
+
| module | description |
| - | - |
| `ml_mdm.models` | The core model implementations |
@@ -107,7 +131,11 @@ torchrun --standalone --nproc_per_node=1 ml_mdm/clis/generate_sample.py --port
| `ml_mdm.clis` | All command line tools in the project, the most relevant being `train_parallel.py` |
| `tests/` | Unit tests and sample training files |
+
+### 5. /tests
+
+| module | description |
+| - | - |
+| `tests.test_files` | Sample files for testing |
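+
+A minimal sketch of running the unit tests programmatically, assuming the suite is run with `pytest` (equivalent to `pytest tests/` from the repository root):
+
+```python
+# Run the test suite under tests/ and exit with pytest's exit code.
+import sys
+
+import pytest
+
+sys.exit(pytest.main(["tests/"]))
+```
+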
# Concepts
diff --git a/ml_mdm/language_models/factory.py b/ml_mdm/language_models/factory.py
index 180d406..df8f838 100644
--- a/ml_mdm/language_models/factory.py
+++ b/ml_mdm/language_models/factory.py
@@ -8,7 +8,7 @@
import torch.nn as nn
import torch.nn.functional as F
-from .tokenizer import Tokenizer
+from ml_mdm.language_models.tokenizer import Tokenizer
class T5Encoder(T5ForConditionalGeneration):
diff --git a/ml_mdm/language_models/tokenizer.py b/ml_mdm/language_models/tokenizer.py
index 0fb08dd..b3af8a8 100644
--- a/ml_mdm/language_models/tokenizer.py
+++ b/ml_mdm/language_models/tokenizer.py
@@ -5,11 +5,11 @@
from mlx.data.core import CharTrie
-def read_dictionary_bert(token_file):
+def read_dictionary_bert(vocab_file):
trie_key_scores = []
trie = CharTrie()
- f = open(token_file, "rb")
+ f = open(vocab_file, "rb")
sep = "\u2581".encode()
max_score = 0
@@ -42,11 +42,11 @@ def read_dictionary_bert(token_file):
return trie, trie_key_scores, eos, bos, pad
-def read_dictionary_t5(token_file):
+def read_dictionary_t5(vocab_file):
trie_key_scores = []
trie = CharTrie()
- f = open(token_file, "rb")
+ f = open(vocab_file, "rb")
sep = "\u2581".encode()
max_score = 0
@@ -75,7 +75,7 @@ def read_dictionary_t5(token_file):
return trie, trie_key_scores, eos, bos, pad
-def read_dictionary(token_file):
+def read_dictionary(vocab_file):
trie_key_scores = []
trie = CharTrie()
@@ -85,7 +85,7 @@ def read_dictionary(token_file):
trie.insert(token)
trie_key_scores.append(0.0)
- f = open(token_file, "rb")
+ f = open(vocab_file, "rb")
sep = "\u2581".encode()
max_score = 0
@@ -130,7 +130,7 @@ def read_dictionary(token_file):
class Tokenizer:
- def __init__(self, token_file, mode=None):
+ def __init__(self, vocab_file, mode=None):
if mode == "t5":
(
self._trie,
@@ -138,7 +138,7 @@ def __init__(self, token_file, mode=None):
self.eos,
self.bos,
self.pad,
- ) = read_dictionary_t5(token_file)
+ ) = read_dictionary_t5(vocab_file)
elif mode == "bert":
(
self._trie,
@@ -146,7 +146,7 @@ def __init__(self, token_file, mode=None):
self.eos,
self.bos,
self.pad,
- ) = read_dictionary_bert(token_file)
+ ) = read_dictionary_bert(vocab_file)
else:
(
self._trie,
@@ -154,7 +154,7 @@ def __init__(self, token_file, mode=None):
self.eos,
self.bos,
self.pad,
- ) = read_dictionary(token_file)
+ ) = read_dictionary(vocab_file)
self.vocab_size = self._trie.num_keys()
@property
diff --git a/tests/test_configs.py b/tests/test_configs.py
index ed622ae..a61a8dd 100644
--- a/tests/test_configs.py
+++ b/tests/test_configs.py
@@ -107,4 +107,4 @@ def test_config_cc12m_1024x1024():
mode="demo",
additional_config_paths=[f],
)
- assert args
+ assert args
\ No newline at end of file
diff --git a/tests/test_reader.py b/tests/test_reader.py
index dd6d983..3ad3fe8 100644
--- a/tests/test_reader.py
+++ b/tests/test_reader.py
@@ -56,3 +56,7 @@ def test_process_text():
)
assert len(tokens) > 0
assert len(tokens[0]) > 0
+
+
+if __name__ == "__main__":
+    test_get_dataset()
diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py
new file mode 100644
index 0000000..49ea485
--- /dev/null
+++ b/tests/test_tokenizer.py
@@ -0,0 +1,25 @@
+# For licensing see accompanying LICENSE file.
+# Copyright (C) 2024 Apple Inc. All rights reserved.
+
+from pathlib import Path
+
+from ml_mdm.language_models.tokenizer import Tokenizer
+
+def test_tokenizer_bert():
+    # Vocab files live in the repository's top-level data/ directory (see README)
+    f = Path(__file__).parent.parent / "data" / "bert.vocab"
+ assert Tokenizer(f, mode="bert")
+
+def test_tokenizer_t5():
+    f = Path(__file__).parent.parent / "data" / "t5.vocab"
+    assert Tokenizer(f, mode="t5")
+
+def test_tokenizer():
+    f = Path(__file__).parent.parent / "data" / "imagenet.vocab"
+ assert Tokenizer(f)
+
+
+if __name__ == "__main__":
+    test_tokenizer_bert()
+    test_tokenizer_t5()
+    test_tokenizer()