From 302301bc2175f7e717fb8548516188e89f649753 Mon Sep 17 00:00:00 2001
From: xuruyi
Date: Sun, 14 Apr 2024 16:13:28 +0000
Subject: [PATCH] fix bugs

---
 llava_uhd/train/llava-uhd/adapt_clip.py  | 102 +++++++++++++++++------
 llava_uhd/train/llava-uhd/adapt_llava.py |  41 +++++++--
 llava_uhd/train/llava-uhd/slice_logic.py |  17 +---
 llava_uhd/train/llava-uhd/train.py       |  40 ++++++++-
 4 files changed, 154 insertions(+), 46 deletions(-)

diff --git a/llava_uhd/train/llava-uhd/adapt_clip.py b/llava_uhd/train/llava-uhd/adapt_clip.py
index 530d893..3effb54 100644
--- a/llava_uhd/train/llava-uhd/adapt_clip.py
+++ b/llava_uhd/train/llava-uhd/adapt_clip.py
@@ -91,8 +91,8 @@ def __init__(self, config: CLIPVisionConfig):
 
     def forward(self,
                 pixel_values: torch.FloatTensor,
-                origin_image_widths,
-                origin_image_heights) -> torch.Tensor:
+                w_patch_num,
+                h_patch_num) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
         patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
@@ -105,10 +105,18 @@ def forward(self,
                 self.position_embedding(self.position_ids),
                 patch_width_num=dim[0],
                 patch_height_num=dim[1]
-            ).unsqueeze(0) for dim in list(zip(origin_image_widths, origin_image_heights))
+            ).unsqueeze(0) for dim in list(zip(w_patch_num, h_patch_num))
         ])
-
+        # processed_position_embedding: one position embedding per image, resized to
+        # that image's (w_patch_num, h_patch_num) patch grid and stacked over the batch.
         embeddings = embeddings + processed_position_embedding
         return embeddings
 
 class adapt_CLIPVisionTransformer(nn.Module):
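The embeddings hunk above resizes CLIP's fixed position embedding to each image's actual patch grid before adding it. The helper being called (it takes patch_width_num / patch_height_num) is not shown in this patch, so the following is only a standalone, generic sketch of that idea: bilinear interpolation of the 24 x 24 grid part of a 577 x 1024 CLIP ViT-L/14-336 position embedding. It is illustrative, not the project's actual implementation.

import torch
import torch.nn.functional as F

def resize_position_embedding(pos_embed: torch.Tensor, w: int, h: int) -> torch.Tensor:
    # pos_embed: [1 + grid*grid, C], e.g. [577, 1024]; row 0 is the CLS embedding.
    cls_embed, grid_embed = pos_embed[:1], pos_embed[1:]
    grid = int(grid_embed.shape[0] ** 0.5)                                  # 24
    c = grid_embed.shape[1]
    grid_embed = grid_embed.reshape(1, grid, grid, c).permute(0, 3, 1, 2)   # [1, C, 24, 24]
    grid_embed = F.interpolate(grid_embed, size=(h, w), mode="bilinear",
                               align_corners=False)                         # [1, C, h, w]
    grid_embed = grid_embed.permute(0, 2, 3, 1).reshape(h * w, c)
    return torch.cat([cls_embed, grid_embed], dim=0)                        # [1 + h*w, C]

pos = torch.randn(577, 1024)
print(resize_position_embedding(pos, w=16, h=12).shape)                     # torch.Size([193, 1024])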
@@ -128,8 +136,8 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        origin_image_widths = None,
-        origin_image_heights = None,
+        w_patch_num = None,
+        h_patch_num = None,
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
         r"""
         Returns:
@@ -145,36 +153,81 @@ def forward(
             raise ValueError("You have to specify pixel_values")
 
         hidden_states = self.embeddings(pixel_values = pixel_values,
-                                        origin_image_widths = origin_image_widths,
-                                        origin_image_heights = origin_image_heights)
+                                        w_patch_num = w_patch_num,
+                                        h_patch_num = h_patch_num)
+
+        # Padding mask computed on the raw embeddings: rows that sum to exactly zero
+        # are padded patch positions. hidden_states is [batch, 577, 1024].
+        _sums = hidden_states.sum(dim=-1)
+        _attentionMask = (_sums == 0.00)
+        _attentionMask = _attentionMask.float()
+        _attentionMask[_attentionMask == 1] = -float('inf')
 
         hidden_states = self.pre_layrnorm(hidden_states)
 
         sums = hidden_states.sum(dim=-1)
-        attentionMask = (sums == 0)
+        attentionMask = (sums == -1.0000)
+        # attentionMask = (sums == 0)
         attentionMask = attentionMask.float()
         attentionMask[attentionMask == 1] = -float('inf')
+
+        # Sanity check: the mask derived after pre_layrnorm should agree with the one
+        # derived from the raw embeddings.
+        _true = True
+        for i in range(577):
+            if attentionMask[0][i] != _attentionMask[0][i]:
+                _true = False
 
-        attentionMask = attentionMask.unsqueeze(1).unsqueeze(3).repeat(1, 1, 1, 577).to(torch.bfloat16)
+        # attentionMask is [batch, 577]; expand it to [batch, 1, 577, 577] so that padded
+        # *key* positions are masked for every query, instead of masking whole query rows
+        # as the old unsqueeze(3) layout did.
+        attentionMask = attentionMask.unsqueeze(1).unsqueeze(2).repeat(1, 1, 577, 1).to(torch.bfloat16)
 
         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
             attention_mask = attentionMask,
+            causal_attention_mask = attentionMask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
 
         last_hidden_state = encoder_outputs[0]
+        _sums = last_hidden_state.sum(dim=-1)
         pooled_output = last_hidden_state[:, 0, :]
         pooled_output = self.post_layernorm(pooled_output)
 
         if not return_dict:
             return (last_hidden_state, pooled_output) + encoder_outputs[1:]
 
         return BaseModelOutputWithPooling(
             last_hidden_state=last_hidden_state,
             pooler_output=pooled_output,
@@ -199,8 +252,8 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        origin_image_widths = None,
-        origin_image_heights = None,
+        w_patch_num = None,
+        h_patch_num = None,
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
 
         if pixel_values.shape[0] == 1:
@@ -214,8 +267,8 @@ def forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            origin_image_widths = origin_image_widths,
-            origin_image_heights = origin_image_heights
+            w_patch_num = w_patch_num,
+            h_patch_num = h_patch_num
         )
 
@@ -259,10 +312,8 @@ def forward(self, images, origin_image_widths,origin_image_heights):
 
         if images.shape[1] == 24:
-
             image_features = []
             split_images = torch.chunk(images, chunks=8, dim=1)
-
             slice_w_nums=[]
             slice_h_nums=[]
             abstract_w_nums=[]
             abstract_h_nums=[]
@@ -275,29 +326,30 @@ def forward(self, images, origin_image_widths,origin_image_heights):
             abstract_w_nums.append(abstract_w_num)
             abstract_h_nums.append(abstract_h_num)
-
             for i, image in enumerate(split_images):
 
                 if i == 7:
+                    # The last chunk is the abstract (overview) image, so it gets the
+                    # abstract patch grid; the slice chunks get the slice grid below.
                     image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                                                           output_hidden_states=True,
-                                                          origin_image_widths = slice_w_nums,
-                                                          origin_image_heights = slice_h_nums)
+                                                          w_patch_num = abstract_w_nums,
+                                                          h_patch_num = abstract_h_nums)
                 else:
                     image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                                                           output_hidden_states=True,
-                                                          origin_image_widths = abstract_w_nums,
-                                                          origin_image_heights = abstract_h_nums)
+                                                          w_patch_num = slice_w_nums,
+                                                          h_patch_num = slice_h_nums)
 
                 image_feature = self.feature_select(image_forward_out).to(image.dtype)
-
+                # image_feature: one feature tensor per slice position, [batch, 576, 1024].
                 image_features.append(image_feature)
         else:
             image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
                                                    output_hidden_states=True,
-                                                   origin_image_widths = origin_image_widths,
-                                                   origin_image_heights = origin_image_heights)
+                                                   w_patch_num = origin_image_widths,
+                                                   h_patch_num = origin_image_heights)
             image_features = self.feature_select(image_forward_outs).to(images.dtype)
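The central fix in adapt_clip.py is the shape of the additive padding mask. The short standalone script below (not part of the patch, toy sequence length 6 instead of 577) shows why unsqueeze(2) plus a repeat over the query axis is the right layout, and what goes wrong with the old unsqueeze(3) version.

import torch

B, L = 2, 6                                   # toy batch and sequence length
valid = torch.tensor([[1, 1, 1, 1, 0, 0],
                      [1, 1, 1, 0, 0, 0]], dtype=torch.bool)   # True = real patch token

mask_1d = torch.zeros(B, L)
mask_1d[~valid] = float('-inf')               # additive mask, -inf at padded positions

# Old layout: [B, 1, L, 1] repeated over the last axis masks whole *query rows*:
# real queries still attend to padding, and padded queries get an all -inf row (NaN).
old = mask_1d.unsqueeze(1).unsqueeze(3).repeat(1, 1, 1, L)

# Fixed layout: [B, 1, 1, L] repeated over the query axis masks padded *key columns*
# for every query, the usual [batch, 1, L_query, L_key] padding-mask shape.
new = mask_1d.unsqueeze(1).unsqueeze(2).repeat(1, 1, L, 1)

scores = torch.zeros(B, 1, L, L)              # stand-in attention logits
print(torch.softmax(scores + old, dim=-1)[0, 0, 0])  # weight still leaks onto padded keys
print(torch.softmax(scores + old, dim=-1)[0, 0, 4])  # padded query row becomes NaN
print(torch.softmax(scores + new, dim=-1)[0, 0, 0])  # zero weight on the padded keys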
diff --git a/llava_uhd/train/llava-uhd/adapt_llava.py b/llava_uhd/train/llava-uhd/adapt_llava.py
index 2eb0491..cfe44b9 100644
--- a/llava_uhd/train/llava-uhd/adapt_llava.py
+++ b/llava_uhd/train/llava-uhd/adapt_llava.py
@@ -86,17 +86,47 @@ def get_vision_tower(self):
 
     def encode_images(self, images,origin_image_widths,origin_image_heights):
         image_features = self.get_model().get_vision_tower()(images,origin_image_widths,origin_image_heights)
+        # For sliced inputs the vision tower returns a list with one feature tensor per
+        # slice position; project each one and concatenate them along the batch axis.
+        if isinstance(image_features,list):
             image_features_list = []
             for image_feature in image_features:
                 image_features_list.append(self.get_model().mm_projector(image_feature))
             image_features = torch.concat( tuple(image_features_list) ,dim = 0)
+        else:
             image_features = self.get_model().mm_projector(image_features)
-
+
         return image_features
 
     def prepare_inputs_labels_for_multimodal(
@@ -115,7 +145,7 @@ def prepare_inputs_labels_for_multimodal(
             return input_ids, position_ids, attention_mask, past_key_values, None, labels
 
         image_features = self.encode_images(images,origin_image_widths,origin_image_heights).to(self.device)
-
+
         # TODO: image start / end is not implemented here to support pretraining.
         if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
             raise NotImplementedError
@@ -143,6 +173,7 @@ def prepare_inputs_labels_for_multimodal(
         new_input_embeds = []
         new_labels = []
         cur_image_idx = 0
+
         for batch_idx, cur_input_ids in enumerate(input_ids):
             num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
@@ -167,16 +198,14 @@ def prepare_inputs_labels_for_multimodal(
                 cur_new_labels.append(cur_labels_noim[i])
                 if i < num_images:
-                    for j in range(5):
-                        cur_image_features = image_features[cur_image_idx+j*16]
+                    # One projected feature group per slice position (8 in total); group j
+                    # of sample cur_image_idx sits at index cur_image_idx + j*4, where 4 is
+                    # the per-device batch size this indexing assumes.
+                    for j in range(8):
+                        cur_image_features = image_features[cur_image_idx+j*4]
                         cur_new_input_embeds.append(cur_image_features)
                         cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
-
             cur_new_input_embeds = torch.cat(cur_new_input_embeds)
             cur_new_labels = torch.cat(cur_new_labels)
-
             new_input_embeds.append(cur_new_input_embeds)
             new_labels.append(cur_new_labels)
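After encode_images concatenates the projected slice groups along dim 0, prepare_inputs_labels_for_multimodal gathers them back per sample with image_features[cur_image_idx + j*4]. The standalone sketch below (not from the repo) illustrates that layout; the per-device batch size of 4 is an assumption read off the hard-coded stride, and the feature sizes are shrunk for readability.

import torch

batch, n_slices, tokens, dim = 4, 8, 64, 16      # toy sizes; the stride of 4 assumes batch == 4

# One projected feature tensor per slice position, each covering the whole batch,
# filled with the slice index so the gather below is easy to verify.
per_slice = [torch.full((batch, tokens, dim), float(j)) for j in range(n_slices)]
flat = torch.concat(tuple(per_slice), dim=0)     # [n_slices * batch, tokens, dim]

cur_image_idx = 1                                # pick the second sample in the batch
gathered = [flat[cur_image_idx + j * batch] for j in range(n_slices)]
print([g[0, 0].item() for g in gathered])        # [0.0, 1.0, ..., 7.0]: its slices, in order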
diff --git a/llava_uhd/train/llava-uhd/slice_logic.py b/llava_uhd/train/llava-uhd/slice_logic.py
index c90752b..3b74609 100644
--- a/llava_uhd/train/llava-uhd/slice_logic.py
+++ b/llava_uhd/train/llava-uhd/slice_logic.py
@@ -126,13 +126,16 @@ def slice_image(image):
     best_w, best_h = cal_num_of_slices(origin_image_width=origin_image_width,origin_image_height=origin_image_height)
 
     slices = []
     for j in range(best_h):
         for i in range(best_w):
             box = (i * origin_image_width//best_w, j * origin_image_height//best_h,
                    (i + 1) * origin_image_width//best_w, (j + 1) * origin_image_height//best_h)
-
+            # Crop this grid cell from the image
             region = image.crop(box).convert("RGB")
+            # and append it to the list of slices.
             slices.append(region)
 
     return slices
@@ -210,15 +213,3 @@ def process_image(image):
         resized_patch_widths.append(resized_patch_width)
         resized_patch_heights.append(resized_patch_height)
     return images
-
-
-img = Image.open("/home/xuruyi/myLLaVa/883700e3366b775c93315373510e7e7.png")
-images = process_image(img)
-
-for i in range(len(images)):
-    img = images[i]
-    to_pil = ToPILImage()
-
-    img = to_pil(img)
-
-    img.save(f"image{i}.png")
\ No newline at end of file
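The slice_image hunk above crops the image into a best_w x best_h grid with integer box arithmetic. Here is a standalone illustration on a synthetic image; the 3 x 2 grid simply stands in for whatever cal_num_of_slices would pick.

from PIL import Image

image = Image.new("RGB", (250, 130))        # synthetic stand-in, width x height
best_w, best_h = 3, 2                       # grid size; chosen by cal_num_of_slices in the repo
W, H = image.size

slices = []
for j in range(best_h):                     # rows, top to bottom
    for i in range(best_w):                 # columns, left to right
        box = (i * W // best_w, j * H // best_h,
               (i + 1) * W // best_w, (j + 1) * H // best_h)
        slices.append(image.crop(box).convert("RGB"))

# Integer division means the right/bottom cells absorb any rounding remainder, and the
# boxes tile the image exactly because the last right edge is best_w * W // best_w == W.
print([s.size for s in slices])             # [(83, 65), (83, 65), (84, 65), ...]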
diff --git a/llava_uhd/train/llava-uhd/train.py b/llava_uhd/train/llava-uhd/train.py
index f38e9fe..6d4ad57 100644
--- a/llava_uhd/train/llava-uhd/train.py
+++ b/llava_uhd/train/llava-uhd/train.py
@@ -679,11 +679,16 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
             image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
 
             origin_image_width = image.size[0]
             origin_image_height = image.size[1]
+            # process_image returns the image slices plus the resized overview image;
+            # concatenating them along the channel axis gives one tensor per sample.
             slices_and_image = process_image(image)
             image_tuple = tuple(slices_and_image)
             image_tensor = torch.cat(image_tuple,dim = 0)
-
+
             sources = preprocess_multimodal(
                 copy.deepcopy([e["conversations"] for e in sources]),
@@ -754,6 +762,7 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
 
         if 'image' in instances[0]:
             images = [instance['image'] for instance in instances]
+
             if all(x is not None and x.shape == images[0].shape for x in images):
                 batch['images'] = torch.stack(images)
             else:
@@ -768,7 +777,19 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
 
                 batch['images'] = torch.stack(padded_x_tensors)
-
+                # batch['images']: [batch, channels, crop, crop], where channels is a
+                # multiple of 3 (three channels per slice / overview image).
 
         return batch
 
@@ -937,6 +961,12 @@ def make_inputs_require_grad(module, input, output):
         for p in model.get_model().mm_projector.parameters():
             p.requires_grad = False
+        print("Freezing mm_projector parameters")
+        # model.get_vision_tower().unfreeze_position_embedding()
+
     if training_args.bits in [4, 8]:
         model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)
@@ -976,6 +1006,12 @@ def make_inputs_require_grad(module, input, output):
                     args=training_args,
                     **data_module)
 
     #-----------------------------------------------------#
    # Check whether the checkpoints path already contains a saved checkpoint
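Because each sample's image tensor stacks a different number of slices along the channel axis, the collator can only call torch.stack directly when all shapes match; otherwise it pads first. The padding code itself sits outside the hunks above, so the following is only a standalone sketch of that step, with made-up slice counts and an assumed 336-pixel crop size.

import torch

crop = 336                                               # assumed CLIP-style crop size
images = [torch.randn(3 * n, crop, crop) for n in (3, 6, 8)]   # per-sample slice stacks

max_c = max(x.shape[0] for x in images)
padded = []
for x in images:
    pad = torch.zeros(max_c - x.shape[0], crop, crop, dtype=x.dtype)
    padded.append(torch.cat([x, pad], dim=0))            # zero-fill the missing slices

batch = torch.stack(padded)                              # analogous to torch.stack(padded_x_tensors)
print(batch.shape)                                       # torch.Size([3, 24, 336, 336])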