
Add LLama32 Vision Model Support in Nemo 2.0 #10763

Merged — 238 commits, Nov 8, 2024

Commits
2f2f600
add initial code for llama vlm
yaoyu-33 Sep 10, 2024
1c823c5
some restructure
yaoyu-33 Sep 10, 2024
50535cc
add mock data placeholder
yaoyu-33 Sep 10, 2024
27b0240
Fix some importing
yaoyu-33 Sep 10, 2024
3127129
add language component for vlm llama
suiyoubi Sep 10, 2024
af853a7
update code
yaoyu-33 Sep 10, 2024
09b487c
now match num of params
suiyoubi Sep 11, 2024
ec0ce83
update language part and fix vision part
yaoyu-33 Sep 16, 2024
8c9332b
minor fix
yaoyu-33 Sep 16, 2024
97d06ed
model can now init
yaoyu-33 Sep 16, 2024
9d79861
minor update for llama32 text config
yaoyu-33 Sep 16, 2024
26e106d
make checkpoint loading work
cuichenx Sep 17, 2024
8e7a1a9
missing import
cuichenx Sep 17, 2024
ea77296
match vision part tensor shapes with configs
yaoyu-33 Sep 17, 2024
5a76638
Merge remote-tracking branch 'internal/yuya/add_llama_vlm' into yuya/…
yaoyu-33 Sep 17, 2024
28a195c
solve some fwd issues and mismatch issues
yaoyu-33 Sep 17, 2024
3ce7ebb
add vision import
cuichenx Sep 18, 2024
caea32e
fixes
yaoyu-33 Sep 18, 2024
ed4087c
Merge remote-tracking branch 'internal/yuya/add_llama_vlm' into yuya/…
yaoyu-33 Sep 18, 2024
40b33c9
update importer to convert both text and image weights
cuichenx Sep 18, 2024
62a3c19
importer typos and reduce clutter
cuichenx Sep 18, 2024
e232385
fix import qkv
cuichenx Sep 18, 2024
2ece533
some fixes for LLM
yaoyu-33 Sep 18, 2024
7aeb632
Merge remote-tracking branch 'internal/yuya/add_llama_vlm' into yuya/…
yaoyu-33 Sep 18, 2024
8d2dd62
Add embedding
meatybobby Sep 19, 2024
8da75f6
some updates
yaoyu-33 Sep 19, 2024
221eee9
enable loading only text or only vision
cuichenx Sep 19, 2024
4f91a03
add example script
cuichenx Sep 19, 2024
be8ee62
TP fix
yaoyu-33 Sep 19, 2024
7fdbc5c
update
cuichenx Sep 19, 2024
610f9a3
Merge remote-tracking branch 'origin/yuya/add_llama_vlm' into yuya/ad…
cuichenx Sep 19, 2024
43dc2fd
upload examples
yaoyu-33 Sep 19, 2024
8054701
update generate
yaoyu-33 Sep 19, 2024
fbed2b5
update to newer version
yaoyu-33 Sep 19, 2024
a5772e2
upload for sharing
cuichenx Sep 20, 2024
24a40a2
update to new pyt ckpt
cuichenx Sep 20, 2024
931ebb7
xattn_caches matches (except small differences due to TE RMSNorm)
cuichenx Sep 20, 2024
e2dcc00
cleanup
cuichenx Sep 20, 2024
c377844
embeddings match
cuichenx Sep 22, 2024
1a5082c
match precision of weights
cuichenx Sep 23, 2024
342e58c
update sharded state dict
yaoyu-33 Sep 23, 2024
394ad18
Merge remote-tracking branch 'internal/yuya/add_llama_vlm' into yuya/…
yaoyu-33 Sep 23, 2024
d1a38e0
change xattn layer num to 3 7 11 etc
cuichenx Sep 23, 2024
3dcb53b
Merge remote-tracking branch 'origin/yuya/add_llama_vlm' into yuya/ad…
cuichenx Sep 23, 2024
751fe44
upload llama generation
cuichenx Sep 24, 2024
60c9a33
minor fix
cuichenx Sep 24, 2024
b083e12
fix dummy layer input format
cuichenx Sep 24, 2024
f4f5252
fix vision qkv order
cuichenx Sep 24, 2024
f742ecd
fix shareded state dict
yaoyu-33 Sep 24, 2024
aecc8cc
Merge remote-tracking branch 'internal/yuya/add_llama_vlm' into yuya/…
yaoyu-33 Sep 24, 2024
36e2398
fix vision precision
cuichenx Sep 24, 2024
cb2875d
fix rope
cuichenx Sep 24, 2024
8941ccd
match cross attn layer
cuichenx Sep 24, 2024
b931a82
remove nrep
cuichenx Sep 25, 2024
ffd3473
Remove cross attention in ImageTransformerLayer and fix _gate_ffn
meatybobby Sep 25, 2024
dd9eb75
PP draft
yaoyu-33 Sep 25, 2024
198246c
Fix intermediate tensor
meatybobby Sep 25, 2024
78ba2d3
temp save for pp2 is working
yaoyu-33 Sep 25, 2024
86fc928
fix pp issues
yaoyu-33 Sep 25, 2024
816ea88
Merge remote-tracking branch 'internal/yuya/add_llama_vlm' into yuya/…
yaoyu-33 Sep 26, 2024
152e59a
merge
cuichenx Sep 26, 2024
4ff0325
update mcore parallelism initialization
yaoyu-33 Sep 26, 2024
82b9d5a
small update to pretrain script
yaoyu-33 Sep 26, 2024
b45f5ab
Merge remote-tracking branch 'internal/yuya/add_llama_vlm' into yuya/…
yaoyu-33 Sep 26, 2024
232009d
update mcore parallelism initialization
yaoyu-33 Sep 26, 2024
a5e3a3a
Apply isort and black reformatting
yaoyu-33 Sep 26, 2024
8acea06
added energon dataloader for neva training (#10451)
yashaswikarnati Sep 19, 2024
13a3211
llama energon dataloader
yashaswikarnati Sep 23, 2024
16f1c05
have tokenizer for base task encoder class
yashaswikarnati Sep 23, 2024
598f95b
Update megatron_init.py
yaoyu-33 Sep 26, 2024
fd32cfa
Add simple inference
meatybobby Sep 27, 2024
f5549f3
evian3 update
yaoyu-33 Sep 27, 2024
6e9566a
add encoder parallel default config
yaoyu-33 Sep 27, 2024
8bc8823
add encoder parallel default config
yaoyu-33 Sep 27, 2024
d0ff08b
clean up
yaoyu-33 Sep 27, 2024
54923f0
add aspect ratio in model
cuichenx Sep 27, 2024
5843c04
Merge remote-tracking branch 'internal/yuya/add_llama_vlm' into yuya/…
yaoyu-33 Sep 27, 2024
e0e5f00
support energon dataloader
cuichenx Sep 28, 2024
44cabfa
some pp update
yaoyu-33 Sep 29, 2024
5e54a29
Merge remote-tracking branch 'internal/yuya/add_llama_vlm' into yuya/…
yaoyu-33 Sep 29, 2024
3ad96e5
fixes
yaoyu-33 Sep 30, 2024
21e41ca
fix kv merging
yaoyu-33 Sep 30, 2024
ff143ff
fix get_key_value_tensors
yaoyu-33 Sep 30, 2024
25c7781
rename files
yaoyu-33 Sep 30, 2024
4dbe2e3
update to HF style position embedding
yaoyu-33 Sep 30, 2024
ca10c21
fix energon dataloader and support batching
cuichenx Sep 30, 2024
cca1acb
update forward args
yaoyu-33 Sep 30, 2024
f8cc794
Merge remote-tracking branch 'internal/yuya/add_llama_vlm' into yuya/…
yaoyu-33 Sep 30, 2024
f35e2f4
clean up and move to aspect_ratio_ids
yaoyu-33 Sep 30, 2024
428403e
rename back to language.py
yaoyu-33 Sep 30, 2024
ffcb9df
fix loss function
yaoyu-33 Sep 30, 2024
8f65450
update and fix energon
yaoyu-33 Sep 30, 2024
7c33686
Add hf import
meatybobby Oct 1, 2024
11338d9
Fix type
meatybobby Oct 1, 2024
28a2f84
Change config
meatybobby Oct 1, 2024
d5ea385
update energon pretrain
yaoyu-33 Oct 1, 2024
23db738
Merge remote-tracking branch 'internal/yuya/add_llama_vlm_hf' into yu…
yaoyu-33 Oct 1, 2024
412facb
clean up
cuichenx Oct 1, 2024
1881095
Merge remote-tracking branch 'origin/yuya/add_llama_vlm_hf' into yuya…
cuichenx Oct 1, 2024
5441d38
clean up
cuichenx Oct 1, 2024
e67fe30
reformat
yaoyu-33 Oct 1, 2024
a069601
Merge remote-tracking branch 'internal/yuya/add_llama_vlm_hf' into yu…
yaoyu-33 Oct 1, 2024
eb35dba
update inference files for new code
cuichenx Oct 1, 2024
d2ee556
update to instruct
cuichenx Oct 1, 2024
298e3a3
update to instruct
cuichenx Oct 1, 2024
894c709
update few names
yaoyu-33 Oct 1, 2024
c4dad1e
Merge remote-tracking branch 'internal/yuya/add_llama_vlm_hf' into yu…
yaoyu-33 Oct 1, 2024
109bcd6
update generation
yaoyu-33 Oct 1, 2024
ea515d1
fix importer embedding.weight
cuichenx Oct 1, 2024
1867b29
few fixes
yaoyu-33 Oct 1, 2024
6c26440
add hf script
yaoyu-33 Oct 1, 2024
cec153b
fix kv import
cuichenx Oct 2, 2024
caa3b27
remove interleaved
cuichenx Oct 2, 2024
1fb8318
fixes and updates
yaoyu-33 Oct 2, 2024
ecae4db
lora fixes
yaoyu-33 Oct 2, 2024
3933bae
some code clean ups
yaoyu-33 Oct 2, 2024
f4e921f
update training scripts
yaoyu-33 Oct 2, 2024
b0787e9
refactors
yaoyu-33 Oct 2, 2024
9714618
add LoRA finetuning
cuichenx Oct 3, 2024
0ba8272
Merge remote-tracking branch 'origin/yuya/add_llama_vlm_hf' into yuya…
cuichenx Oct 3, 2024
472b85f
Merge branch 'main' into yuya/add_llama_vlm_hf
yaoyu-33 Oct 3, 2024
2ee7608
fixes and nemo update
yaoyu-33 Oct 3, 2024
c4938f0
fix importer registering issue by adding 11B and 90B configs
cuichenx Oct 3, 2024
7f0d003
update `decoder_seq_len`
yaoyu-33 Oct 3, 2024
5df6e11
science vqa script
yaoyu-33 Oct 3, 2024
806a16f
clean up script name
yaoyu-33 Oct 3, 2024
c533329
fix ckpt save serialization issue
cuichenx Oct 3, 2024
bd562ce
fix predefined config classes
cuichenx Oct 3, 2024
ae0f1ad
add num_chunks in input
yaoyu-33 Oct 3, 2024
2045c80
fix format
yaoyu-33 Oct 4, 2024
2156ebc
update finetuning scripts for PEFT
cuichenx Oct 4, 2024
e38f615
add 11b recipe (need #10645 to test)
cuichenx Oct 4, 2024
aaf7746
fix mask generation
yaoyu-33 Oct 4, 2024
2f0d740
minor fix code style
yaoyu-33 Oct 4, 2024
7279efa
Apply isort and black reformatting
yaoyu-33 Oct 4, 2024
3368f29
Support no image inference
meatybobby Oct 4, 2024
eaa7ce7
add llama svqa eval
meatybobby Oct 4, 2024
a4674d0
fix masking
yaoyu-33 Oct 4, 2024
31b64a6
Merge remote-tracking branch 'refs/remotes/origin/yuya/add_llama_vlm_…
yaoyu-33 Oct 4, 2024
d9c8b84
Merge remote-tracking branch 'internal/yuya/add_llama_vlm_hf' into yu…
yaoyu-33 Oct 4, 2024
a1e3bda
Apply isort and black reformatting
yaoyu-33 Oct 4, 2024
9fe9833
fix generation
yaoyu-33 Oct 4, 2024
41506bd
Merge remote-tracking branch 'internal/yuya/add_llama_vlm_hf' into yu…
yaoyu-33 Oct 4, 2024
e66d956
Apply isort and black reformatting
yaoyu-33 Oct 4, 2024
0807220
add 90b recipe and revise 11b recipe
cuichenx Oct 4, 2024
c3bd7ad
Apply isort and black reformatting
cuichenx Oct 4, 2024
4c77909
clean up typing
cuichenx Oct 7, 2024
8004563
add option to disable vision padding
cuichenx Oct 7, 2024
9d77f37
Apply isort and black reformatting
cuichenx Oct 7, 2024
531a97d
Merge remote-tracking branch 'refs/remotes/github/main' into github_y…
cuichenx Oct 9, 2024
045d0d8
base model finetuning (does not work yet)
cuichenx Oct 9, 2024
f74a19d
Apply isort and black reformatting
cuichenx Oct 9, 2024
3c9fa57
fixed default conversation template config for MLLama
Oct 9, 2024
685dc92
Update svqa
meatybobby Oct 10, 2024
b5fc604
Merge remote-tracking branch 'refs/remotes/github/main' into yuya/add…
cuichenx Oct 12, 2024
78314c9
add multinode
cuichenx Oct 12, 2024
279529a
bot happy
cuichenx Oct 12, 2024
d2ae960
Apply isort and black reformatting
cuichenx Oct 12, 2024
e078aa2
Merge branch 'main' into yuya/add_llama_vlm_hf
cuichenx Oct 14, 2024
f689452
Merge remote-tracking branch 'origin/yuya/add_llama_vlm_hf' into yuya…
yaoyu-33 Oct 14, 2024
b49c539
Apply isort and black reformatting
yaoyu-33 Oct 14, 2024
1bf5848
Merge branch 'refs/heads/main' into yuya/add_llama_vlm_hf
yaoyu-33 Oct 15, 2024
da4a4ab
Apply isort and black reformatting
yaoyu-33 Oct 15, 2024
1424f53
Apply isort and black reformatting
artbataev Oct 15, 2024
6c70c7b
Perf improvements. Mainly from XAttn mask calculation (#10901)
parthmannan Oct 16, 2024
59df701
fix existing issues
yaoyu-33 Oct 16, 2024
409f1d8
fix scripts
yaoyu-33 Oct 16, 2024
b6557f9
Apply isort and black reformatting
yaoyu-33 Oct 16, 2024
dbedae0
fix lora
cuichenx Oct 16, 2024
6892f6e
few fixes for non image support
yaoyu-33 Oct 17, 2024
96d912a
update masking gen
yaoyu-33 Oct 17, 2024
06ab440
update lazy dataset
yaoyu-33 Oct 17, 2024
3fd7a80
fix data sampler and loading issue
yaoyu-33 Oct 21, 2024
4711c75
Add vlm generation
meatybobby Oct 21, 2024
9045f4f
Apply isort and black reformatting
meatybobby Oct 21, 2024
668b40d
Merge remote-tracking branch 'refs/remotes/origin/yuya/add_llama_vlm_…
yaoyu-33 Oct 22, 2024
84d3b86
Merge remote-tracking branch 'origin/yuya/add_llama_vlm_hf' into yuya…
yaoyu-33 Oct 22, 2024
32fbb1a
Apply isort and black reformatting
yaoyu-33 Oct 22, 2024
1922d07
generation update
yaoyu-33 Oct 22, 2024
6768696
update lazy dataset
yaoyu-33 Oct 23, 2024
ab58b82
Fix _strategy_lib.py
yaoyu-33 Oct 23, 2024
fd1ddcd
Merge remote-tracking branch 'origin/yuya/add_llama_vlm_hf' into yuya…
yaoyu-33 Oct 23, 2024
4e099ad
Merge branch 'main' into yuya/add_llama_vlm_hf
yaoyu-33 Oct 23, 2024
51530bf
Apply isort and black reformatting
yaoyu-33 Oct 23, 2024
e9db2a4
fix warning
yaoyu-33 Oct 23, 2024
433054d
hide vlm examples
yaoyu-33 Oct 23, 2024
882a6ae
Revert "Add vlm generation"
yaoyu-33 Oct 23, 2024
70521a3
Fix VisionEncoder multi-batch bug
meatybobby Oct 25, 2024
33a83df
update mcore parallelism initialization
yaoyu-33 Sep 26, 2024
7c3e022
Apply isort and black reformatting
yaoyu-33 Sep 26, 2024
5e3b7d7
Update megatron_init.py
yaoyu-33 Sep 26, 2024
340b28d
add encoder parallel default config
yaoyu-33 Sep 27, 2024
a575b61
Fix _strategy_lib.py
yaoyu-33 Oct 23, 2024
8009676
llm.generate fixes (#10983)
HuiyingLi Oct 23, 2024
fc9e855
use __dict__ in check (#11012)
akoumpa Oct 24, 2024
5f2c7fe
LoRA support for HF::AutoModelForCausalLM (#10982)
akoumpa Oct 24, 2024
d2ba2c3
Change default for always_save_context to True (#11014)
athitten Oct 24, 2024
29d37c4
Add a build option to load_context (#10713)
marcromeyn Oct 24, 2024
b58b219
Fix pip install (#11026)
marcromeyn Oct 24, 2024
5044e9d
[WIP] Add docs for NEST SSL (#10804)
stevehuang52 Oct 24, 2024
e550063
Change dist ckpt defaults (#10913)
ShriyaPalsamudram Oct 24, 2024
eae77cb
Akoumparouli/mixtral recipe fix r2.0.0 (#10994)
akoumpa Oct 24, 2024
ed23ca8
Fix _strategy_lib tests (#11033)
maanug-nv Oct 25, 2024
7a4b544
Update `BaseMegatronSampler` for compatibility with PTL's `_BatchProg…
ashors1 Oct 25, 2024
55aa6f9
PTQ example for NeMo 2.0 (#10642)
Laplasjan107 Oct 25, 2024
c37ef00
TDT compute timestamps option and Extra Whitespace handling for SPE (…
monica-sekoyan Oct 25, 2024
f696815
Basic online dynamic FP8 quantization with vLLM (#10904)
janekl Oct 25, 2024
e71b073
ci: Improve VM maintenance (#10758)
ko3n1g Oct 25, 2024
0d62a47
Add comment for vision transpose
meatybobby Oct 25, 2024
8a13481
Merge branch 'refs/heads/main' into yuya/update_megatron_parallel
yaoyu-33 Oct 25, 2024
e7e6798
update megatron_init.py inside lightning
yaoyu-33 Oct 25, 2024
ac5fc29
Merge branch 'refs/heads/yuya/update_megatron_parallel' into yuya/add…
yaoyu-33 Oct 25, 2024
3ff21bf
Merge remote-tracking branch 'origin/yuya/add_llama_vlm_hf' into yuya…
yaoyu-33 Oct 25, 2024
03823d2
Merge branch 'main' into yuya/add_llama_vlm_hf
yaoyu-33 Oct 28, 2024
d9c9f1a
Merge branch 'main' into yuya/add_llama_vlm_hf
yaoyu-33 Oct 31, 2024
f6ef4db
rename llama to mllama folder name
yaoyu-33 Oct 31, 2024
ccf7187
update to attention bias
yaoyu-33 Nov 2, 2024
7142415
Apply isort and black reformatting
yaoyu-33 Nov 2, 2024
4774016
update dropout to 0
yaoyu-33 Nov 4, 2024
5b5d601
fix attention bias
yaoyu-33 Nov 4, 2024
703c10e
remove disable_vision_padding since we now have a fix
yaoyu-33 Nov 5, 2024
485e4fc
Merge remote-tracking branch 'origin/yuya/add_llama_vlm_hf' into yuya…
yaoyu-33 Nov 5, 2024
0669646
Apply isort and black reformatting
yaoyu-33 Nov 5, 2024
477e694
Update init for mllama
yaoyu-33 Nov 5, 2024
a266ce0
Merge remote-tracking branch 'origin/yuya/add_llama_vlm_hf' into yuya…
yaoyu-33 Nov 5, 2024
9c60865
Apply isort and black reformatting
yaoyu-33 Nov 5, 2024
a7d33d5
Address comments
yaoyu-33 Nov 5, 2024
dcf115c
Merge remote-tracking branch 'origin/yuya/add_llama_vlm_hf' into yuya…
yaoyu-33 Nov 5, 2024
03838bf
Apply isort and black reformatting
yaoyu-33 Nov 5, 2024
95244fd
fix copyright title
yaoyu-33 Nov 5, 2024
21f0123
Merge branch 'main' into yuya/add_llama_vlm_hf
yaoyu-33 Nov 6, 2024
ad1fbe9
fix code scan
yaoyu-33 Nov 6, 2024
782b961
update vision code
yaoyu-33 Nov 7, 2024
e49f3a7
revert attention bias changes until latest MLM code got merged
yaoyu-33 Nov 7, 2024
6b272f6
fix warning
yaoyu-33 Nov 7, 2024
f40f4d4
Turn off system message check, as it's "" now
yaoyu-33 Nov 7, 2024
a66fa0e
Merge branch 'main' into yuya/add_llama_vlm_hf
yaoyu-33 Nov 7, 2024
3cfce5f
Rolllback megatron_parallel.py
yaoyu-33 Nov 7, 2024
30 changes: 24 additions & 6 deletions nemo/collections/multimodal/data/energon/base.py
@@ -11,23 +11,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional

from copy import deepcopy
from typing import Any, Dict, Literal, Optional

import fiddle as fdl
import pytorch_lightning as pl
from megatron.core import parallel_state
from megatron.energon import WorkerConfig, get_savable_loader, get_train_dataset
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.utils.data import DataLoader
from typing_extensions import Self

from nemo.collections.multimodal.data.energon.config import MultiModalSampleConfig
from nemo.collections.multimodal.data.energon.task_encoder import MultiModalTaskEncoder
from nemo.lightning.io.mixin import IOMixin
from nemo.lightning.io.mixin import IOMixin, serialization, track_io
from nemo.lightning.pytorch.plugins import MegatronDataSampler
from nemo.utils import logging

if TYPE_CHECKING:
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec


class SimpleMultiModalDataModule(pl.LightningDataModule, IOMixin):
"""
Expand Down Expand Up @@ -66,6 +67,7 @@ def __init__(
pin_memory: bool = True,
multimodal_sample_config: Optional[MultiModalSampleConfig] = MultiModalSampleConfig(),
task_encoder: Optional[MultiModalTaskEncoder] = None,
decoder_seq_length: Optional[int] = None,
) -> None:
"""
Initialize the SimpleMultiModalDataModule.
@@ -87,6 +89,7 @@ def __init__(
self.tokenizer = tokenizer
self.image_processor = image_processor
self.seq_length = seq_length
self.decoder_seq_length = decoder_seq_length
self.micro_batch_size = micro_batch_size
self.global_batch_size = global_batch_size
self.num_workers = num_workers
@@ -99,11 +102,24 @@
)
self.init_global_step = 0
self.data_sampler = SequentialMegatronSampler(
seq_len=self.seq_length, micro_batch_size=self.micro_batch_size, global_batch_size=self.global_batch_size
seq_len=self.seq_length,
decoder_seq_len=self.decoder_seq_length,
micro_batch_size=self.micro_batch_size,
global_batch_size=self.global_batch_size,
)
self.train_dataloader_object = None
self.val_dataloader_object = None

def io_init(self, **kwargs) -> fdl.Config[Self]:
# (pleasefixme) image_processor and task_encoder are problematic with Fiddle so we skip serializing them for now
cfg_kwargs = {k: deepcopy(v) for k, v in kwargs.items() if k not in ['image_processor', 'task_encoder']}

for val in cfg_kwargs.values():
if not serialization.find_node_traverser(type(val)):
track_io(type(val))
cfg = fdl.Config(type(self), **cfg_kwargs)
return cfg

def datasets_provider(self, worker_config, split: Literal['train', 'val'] = 'val'):
"""
Provide the dataset for training or validation.
@@ -315,6 +331,7 @@ def __init__(
micro_batch_size: int = 4,
global_batch_size: int = 8,
init_consumed_samples: int = 0,
decoder_seq_len: Optional[int] = None,
init_global_step=0,
):
"""
@@ -328,6 +345,7 @@
"""
super().__init__(
seq_len=seq_len,
decoder_seq_len=decoder_seq_len,
micro_batch_size=micro_batch_size,
global_batch_size=global_batch_size,
init_consumed_samples=init_consumed_samples,
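The `io_init` override added in base.py works around Fiddle serialization failures by dropping `image_processor` and `task_encoder` from the captured constructor kwargs and registering any still-unknown types via `track_io`. The filtering step itself is plain Python; a minimal stand-alone sketch (no Fiddle or NeMo dependency, names illustrative):

```python
from copy import deepcopy

def filter_io_kwargs(kwargs, skip=('image_processor', 'task_encoder')):
    """Drop constructor kwargs the config serializer cannot handle,
    deep-copying the rest so later mutation does not leak into the config."""
    return {k: deepcopy(v) for k, v in kwargs.items() if k not in skip}

captured = filter_io_kwargs(
    {'seq_length': 2048, 'micro_batch_size': 4, 'image_processor': object()}
)
# captured now holds only the serializable kwargs
```

In the real method, the surviving kwargs are then wrapped in an `fdl.Config` for checkpoint restoration.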
8 changes: 1 addition & 7 deletions nemo/collections/multimodal/data/energon/config.py
@@ -15,7 +15,7 @@
from dataclasses import dataclass, field
from typing import List
import torch
from nemo.collections.multimodal.data.energon.conversation import BaseConversationTemplateConfig
from nemo.collections.multimodal.data.energon.conversation import LLaVATemplateConfig


@dataclass
@@ -56,12 +56,6 @@ class ImageTextRawBatch:
loss_mask: torch.Tensor = field(default_factory=lambda: torch.empty(0, dtype=torch.float))


class LLaVATemplateConfig(BaseConversationTemplateConfig):
"""LLava specific template configuration which extends the base config"""

pass


@dataclass
class MultiModalSampleConfig:
image_token: ImageToken = field(default_factory=ImageToken)
20 changes: 20 additions & 0 deletions nemo/collections/multimodal/data/energon/conversation.py
@@ -19,6 +19,15 @@
class BaseConversationTemplateConfig:
"""Conversation template config related parameters"""

system: Optional[str] = "".format() # fmt: off
roles: List[str] = field(default_factory=lambda: ['user', 'assistant'])
stop_string: Optional[str] = None
chat_template = None


class LLaVATemplateConfig(BaseConversationTemplateConfig):
"""LLava specific template configuration which extends the base config"""

system: Optional[str] = (
"A chat between a curious user and artificial assistant agent. The assistant gives helpful, detailed and polite answers to user's questions.".format()
) # fmt: off
@@ -36,3 +45,14 @@ class BaseConversationTemplateConfig:
{%- endif %}
{%- endfor -%}
"""


class MLlamaTemplateConfig(BaseConversationTemplateConfig):
"""MLlama specific template configuration which extends the base config"""

system: Optional[str] = None
roles: List[str] = field(default_factory=lambda: ['user', 'assistant'])
stop_string: str = None
chat_template = """
'{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now("%d %b %Y") %}\n {%- else %}\n {%- set date_string = "26 Jul 2024" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0][\'role\'] == \'system\' %}\n {%- set system_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = "" %}\n{%- endif %}\n\n{#- Find out if there are any images #}\n{% set image_ns = namespace(has_images=false) %} \n{%- for message in messages %}\n {%- for content in message[\'content\'] %}\n {%- if content[\'type\'] == \'image\' %}\n {%- set image_ns.has_images = true %}\n {%- endif %}\n {%- endfor %}\n{%- endfor %}\n\n{#- Error out if there are images and system message #}\n{%- if image_ns.has_images and not system_message == "" %}\n {{- raise_exception("Prompting with images is incompatible with system messages.") }}\n{%- endif %}\n\n{#- System message if there are no images #}\n{%- if not image_ns.has_images %}\n {{- "<|start_header_id|>system<|end_header_id|>\\n\\n" }}\n {%- if tools is not none %}\n {{- "Environment: ipython\\n" }}\n {%- endif %}\n {{- "Cutting Knowledge Date: December 2023\\n" }}\n {{- "Today Date: " + date_string + "\\n\\n" }}\n {%- if tools is not none and not tools_in_user_message %}\n {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." 
}}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- "<|eot_id|>" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception("Cannot put tools in the first user message when there\'s no first user message!") }}\n{%- endif %}\n {{- \'<|start_header_id|>user<|end_header_id|>\\n\\n\' -}}\n {{- "Given the following functions, please respond with a JSON for a function call " }}\n {{- "with its proper arguments that best answers the given prompt.\\n\\n" }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {{- first_user_message + "<|eot_id|>"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == \'ipython\' or message.role == \'tool\' or \'tool_calls\' in message) %}\n {{- \'<|start_header_id|>\' + message[\'role\'] + \'<|end_header_id|>\\n\\n\' }}\n {%- if message[\'content\'] is string %}\n {{- message[\'content\'] }}\n {%- else %}\n {%- for content in message[\'content\'] %}\n {%- if content[\'type\'] == \'image\' %}\n {{- \'<|image|>\' }}\n {%- elif content[\'type\'] == \'text\' %}\n {{- content[\'text\'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- \'<|eot_id|>\' }}\n {%- elif \'tool_calls\' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception("This 
model only supports single tool-calls at once!") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' -}}\n {{- \'{"name": "\' + tool_call.name + \'", \' }}\n {{- \'"parameters": \' }}\n {{- tool_call.arguments | tojson }}\n {{- "}" }}\n {{- "<|eot_id|>" }}\n {%- elif message.role == "tool" or message.role == "ipython" %}\n {{- "<|start_header_id|>ipython<|end_header_id|>\\n\\n" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- "<|eot_id|>" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' }}\n{%- endif %}\n'
"""
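The MLlama `chat_template` above is one large Jinja string. Its core behavior — emitting `<|start_header_id|>` / `<|end_header_id|>` tokens per role, replacing image content parts with `<|image|>`, and closing each turn with `<|eot_id|>` — can be illustrated with a tiny dependency-free stand-in (this is not the template's full logic; tool calls, system-message handling, and error cases are omitted):

```python
def render_chat(messages):
    """Tiny stand-in for the chat template above: emit role header tokens,
    substitute <|image|> for image content parts, end each turn with <|eot_id|>."""
    out = []
    for message in messages:
        out.append(f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n")
        for part in message['content']:
            out.append("<|image|>" if part['type'] == 'image' else part['text'])
        out.append("<|eot_id|>")
    return "".join(out)

prompt = render_chat([
    {'role': 'user',
     'content': [{'type': 'image'}, {'type': 'text', 'text': 'Describe this image.'}]},
])
```

The real template additionally raises an exception when images are combined with a system message, which is why `MLlamaTemplateConfig` sets `system` to `None`.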
2 changes: 1 addition & 1 deletion nemo/collections/multimodal/data/energon/task_encoder.py
@@ -62,7 +62,7 @@ def __init__(self, tokenizer, image_processor, multimodal_sample_config):
image_processor (ImageProcessor): The image processor used for preprocessing images across different sample types.
multimodal_sample_config (MultiModalSampleConfig): Configuration object for multimodal samples, including tokens and placeholders.
"""

self.tokenizer = tokenizer
self.encoders: Dict[str, SampleEncoder] = {
VQASample.__name__: VQASampleEncoder(
tokenizer=tokenizer,
52 changes: 45 additions & 7 deletions nemo/collections/vlm/__init__.py
@@ -1,28 +1,56 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.vlm.mllama.data import MLlamaLazyDataModule, MLlamaMockDataModule
from nemo.collections.vlm.mllama.model.base import (
CrossAttentionTextConfig,
CrossAttentionVisionConfig,
MLlamaModel,
MLlamaModelConfig,
)
from nemo.collections.vlm.mllama.model.mllama import (
MLlamaConfig11B,
MLlamaConfig11BInstruct,
MLlamaConfig90B,
MLlamaConfig90BInstruct,
)
from nemo.collections.vlm.neva.data import (
DataConfig,
ImageDataConfig,
ImageToken,
MockDataModule,
MultiModalToken,
NevaLazyDataModule,
NevaMockDataModule,
VideoDataConfig,
VideoToken,
)
from nemo.collections.vlm.neva.model import (
from nemo.collections.vlm.neva.model.base import (
CLIPViTConfig,
HFCLIPVisionConfig,
Llava1_5Config7B,
Llava1_5Config13B,
LlavaConfig,
LlavaModel,
MultimodalProjectorConfig,
NevaConfig,
NevaModel,
)
from nemo.collections.vlm.neva.model.llava import Llava1_5Config7B, Llava1_5Config13B, LlavaConfig, LlavaModel
from nemo.collections.vlm.peft import LoRA
from nemo.collections.vlm.recipes import *

__all__ = [
"MockDataModule",
"NevaMockDataModule",
"NevaLazyDataModule",
"MLlamaMockDataModule",
"MLlamaLazyDataModule",
"DataConfig",
"ImageDataConfig",
"VideoDataConfig",
@@ -38,4 +66,14 @@
"Llava1_5Config7B",
"Llava1_5Config13B",
"LlavaModel",
"MLlamaModel",
"MLlamaModelConfig",
"CrossAttentionTextConfig",
"CrossAttentionVisionConfig",
"MLlamaConfig11B",
"MLlamaConfig11BInstruct",
"MLlamaConfig90B",
"MLlamaConfig90BInstruct",
"mllama_11b",
"mllama_90b",
]
17 changes: 17 additions & 0 deletions nemo/collections/vlm/mllama/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import PreTrainedTokenizerFast
from nemo.lightning.io import track_io

track_io(PreTrainedTokenizerFast)
21 changes: 21 additions & 0 deletions nemo/collections/vlm/mllama/data/__init__.py
@@ -0,0 +1,21 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.vlm.mllama.data.lazy import MLlamaLazyDataModule
from nemo.collections.vlm.mllama.data.mock import MockDataModule as MLlamaMockDataModule

__all__ = [
"MLlamaMockDataModule",
"MLlamaLazyDataModule",
]