Commit
1. Support dynamic sequence length during training
2. Update README.md
3. Update evaluation code
Showing 12 changed files with 167 additions and 21 deletions.
@@ -0,0 +1,49 @@
import numpy as np
import torch

IGNORE_INDEX = -100


def pad_data_collator(features, pad_id=0):

    first = features[0]
    batch = {}

    batch_lens = [feat['input_ids'].shape for feat in features]
    max_item_length = max(batch_lens)[0]
    for idx in range(len(features)):
        feat = features[idx]
        temp_input_ids = torch.LongTensor([pad_id] * max_item_length)
        temp_input_ids[:feat['input_ids'].shape[0]] = feat['input_ids']
        feat['input_ids'] = temp_input_ids
        temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length)
        temp_labels[:feat['labels'].shape[0]] = feat['labels']
        feat['labels'] = temp_labels
        feat['attention_mask'] = feat['input_ids'].ne(pad_id)

    # Special handling for labels.
    # Ensure that tensor is created with the correct type
    # (it should be automatically the case, but let's make sure of it.)
    if 'label' in first and first['label'] is not None:
        label = first['label'].item() if isinstance(first['label'], torch.Tensor) else first['label']
        dtype = torch.long if isinstance(label, int) else torch.float
        batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype)
    elif 'label_ids' in first and first['label_ids'] is not None:
        if isinstance(first['label_ids'], torch.Tensor):
            batch['labels'] = torch.stack([f['label_ids'] for f in features])
        else:
            dtype = torch.long if isinstance(first['label_ids'][0], int) else torch.float
            batch['labels'] = torch.tensor([f['label_ids'] for f in features], dtype=dtype)

    # Handling of all other possible keys.
    # Again, we will use the first element to figure out which key/values are not None for this model.
    for k, v in first.items():
        if k not in ('label', 'label_ids') and v is not None and not isinstance(v, str):
            if isinstance(v, torch.Tensor):
                batch[k] = torch.stack([f[k] for f in features])
            elif isinstance(v, np.ndarray):
                batch[k] = torch.tensor(np.stack([f[k] for f in features]))
            else:
                batch[k] = torch.tensor([f[k] for f in features])

    return batch
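This collator is what makes dynamic sequence length possible: samples of different lengths are padded per batch to the batch maximum rather than to a fixed global length. A minimal usage sketch, not part of this commit, follows; `model`, `tokenizer`, and `train_dataset` are hypothetical placeholders, and each dataset item is assumed to provide 'input_ids' and 'labels' tensors of varying length:

    from functools import partial

    from transformers import Trainer, TrainingArguments

    # Assumed objects: model, tokenizer, train_dataset are defined elsewhere.
    # pad_id is taken from the tokenizer so padding matches the model vocabulary.
    collator = partial(pad_data_collator, pad_id=tokenizer.pad_token_id)

    trainer = Trainer(
        model=model,
        args=TrainingArguments(output_dir='outputs', per_device_train_batch_size=4),
        train_dataset=train_dataset,
        data_collator=collator,  # pads each batch to its own longest sample
    )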
@@ -0,0 +1,31 @@
from typing import Optional

import torch
import transformers
from transformers.trainer import (LengthGroupedSampler, RandomSampler,
                                  has_length)


# patch trainer
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
    if self.train_dataset is None or not has_length(self.train_dataset):
        return None
    # Build the sampler.
    if self.args.group_by_length:
        lengths = []
        for dataset in self.train_dataset.datasets:
            lengths = lengths + dataset.length
        model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
        return LengthGroupedSampler(
            self.args.train_batch_size * self.args.gradient_accumulation_steps,
            dataset=self.train_dataset,
            lengths=lengths,
            model_input_name=model_input_name,
        )
    else:
        return RandomSampler(self.train_dataset)


def replace_train_sampler():
    transformers.Trainer._get_train_sampler = _get_train_sampler
    print('Replace train sampler!!')
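replace_train_sampler() monkey-patches Trainer._get_train_sampler, so it has to run before the Trainer is used. A minimal sketch of the assumed usage follows; `model`, `tokenizer`, and `train_dataset` are placeholders, and the group_by_length branch requires a concatenated dataset exposing `.datasets`, each sub-dataset carrying a precomputed `length` list:

    from transformers import Trainer, TrainingArguments

    replace_train_sampler()  # install the patched _get_train_sampler on transformers.Trainer

    training_args = TrainingArguments(
        output_dir='outputs',
        per_device_train_batch_size=4,
        gradient_accumulation_steps=2,
        group_by_length=True,  # only this branch differs from the stock sampler
    )
    trainer = Trainer(
        model=model,                   # placeholder model
        args=training_args,
        train_dataset=train_dataset,   # assumed ConcatDataset-style object with .datasets
        tokenizer=tokenizer,           # placeholder tokenizer
    )
    trainer.train()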
@@ -0,0 +1,25 @@
import argparse

import torch
from internvl.model.internvl_chat import InternVLChatModel
from transformers import AutoTokenizer

argparse = argparse.ArgumentParser()
argparse.add_argument('model_path', type=str, default='')
argparse.add_argument('output_path', type=str, default='')
argparse.add_argument('force_image_size', type=int, default=448)

args = argparse.parse_args()

model = InternVLChatModel.from_pretrained(args.model_path, torch_dtype=torch.bfloat16)
model.vision_model.resize_pos_embeddings(old_size=model.config.vision_config.image_size,
                                         new_size=args.force_image_size,
                                         patch_size=14)
model.config.vision_config.image_size = args.force_image_size
model.config.force_image_size = args.force_image_size

model.save_pretrained(args.output_path)

tokenizer = AutoTokenizer.from_pretrained(args.model_path)
tokenizer.save_pretrained(args.output_path)
print('finished')
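The script takes three positional arguments (model_path, output_path, force_image_size), interpolates the vision encoder's position embeddings to the new resolution, and writes the updated checkpoint and tokenizer to output_path. A small follow-up sketch, assuming a hypothetical output directory, is to reload the saved model and confirm the config was updated:

    import torch
    from internvl.model.internvl_chat import InternVLChatModel
    from transformers import AutoTokenizer

    output_path = 'work_dirs/internvl_chat_resized'  # hypothetical output directory
    model = InternVLChatModel.from_pretrained(output_path, torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(output_path)
    print(model.config.force_image_size)            # should equal the requested force_image_size
    print(model.config.vision_config.image_size)    # should match as well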