diff --git a/clip/clip.py b/clip/clip.py
index 48a6bd7..5db67ca 100644
--- a/clip/clip.py
+++ b/clip/clip.py
@@ -10,7 +10,7 @@ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
 from tqdm import tqdm
 
-from .model import build_model
+from .model import build_model, EMGbuild_model
 from .simple_tokenizer import SimpleTokenizer as _Tokenizer
 
 try:
@@ -24,7 +24,7 @@
     warnings.warn("PyTorch version 1.7.1 or higher is recommended")
 
-__all__ = ["available_models", "load", "tokenize"]
+__all__ = ["available_models", "load", "tokenize", "EMGload"]
 _tokenizer = _Tokenizer()
 
 _MODELS = {
@@ -244,3 +244,22 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: b
         result[i, :len(tokens)] = torch.tensor(tokens)
 
     return result
+
+
+def EMGload(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
+            download_root: str = None, classification=True, vis_pretrain=True, model_dim=2):
+
+    if name in _MODELS:
+        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
+    elif os.path.isfile(name):
+        model_path = name
+    else:
+        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
+
+    with open(model_path, 'rb') as opened_file:
+        model = torch.jit.load(opened_file, map_location="cpu").eval()
+
+    model = EMGbuild_model(model.state_dict(), classification=classification, vis_pretrain=vis_pretrain, model_dim=model_dim).to(device)
+    if str(device) == "cpu":
+        model.float()
+    return model
diff --git a/clip/model.py b/clip/model.py
index 7067bdc..c741112 100644
--- a/clip/model.py
+++ b/clip/model.py
@@ -527,7 +527,7 @@ def forward(self, x):
         return x.squeeze(0)
 
 
-class EMGModifiedResNet(nn.Module):
+class EMGModifiedResNet1D(nn.Module):
     """
     A ResNet class that is similar to torchvision's but contains the following changes:
     - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
@@ -559,8 +559,8 @@ def __init__(self, layers, output_dim, width=64):
         self.layer3 = self._make_layer(width * 4, layers[2], stride=2)  # shape(B,1024,25,1)
         self.layer4 = self._make_layer(width * 8, layers[3], stride=1)  # shape(B,2048,25,1)
 
-        embed_dim = width * 32  # the ResNet feature dimension
-        self.attnpool = EMGAttentionPool2d(25, embed_dim, 16, output_dim)
+        embed_dim = width * 8  # the ResNet feature dimension
+        self.attnpool = EMGAttentionPool2d(50, embed_dim, 16, output_dim)
 
     def _make_layer(self, planes, blocks, stride=1):
         layers = [EMGBottleneck(self._inplanes, planes, stride)]
@@ -583,8 +583,71 @@ def stem(x):
         x = stem(x)
         x = self.layer1(x)  # shape(B,256,100,1)
         x = self.layer2(x)  # shape(B,512,50,1)
-        x = self.layer3(x)  # shape(B,1024,25,1)
-        x = self.layer4(x)  # shape(B,2048,25,1)
+        # x = self.layer3(x)  # shape(B,1024,25,1)
+        # x = self.layer4(x)  # shape(B,2048,25,1)
+        x = self.attnpool(x)  # shape(B,1024)
+
+        return x
+
+
+class EMGModifiedResNet2D(nn.Module):
+    """
+    A ResNet class that is similar to torchvision's but contains the following changes:
+    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+    - The final pooling layer is a QKV attention instead of an average pool
+    """
+
+    # input shape(B, 1, 400, 8)
+    def __init__(self, layers, output_dim, width=64):
+        super().__init__()
+        self.output_dim = output_dim
+
+        # the 3-layer stem
+        self.conv1 = nn.Conv2d(1, width // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)  # shape(B,32,200,4)
+        self.bn1 = nn.BatchNorm2d(width // 2)
+        self.relu1 = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=(3, 3), padding=(1, 1), bias=False)  # shape(B,32,200,4)
+        self.bn2 = nn.BatchNorm2d(width // 2)
+        self.relu2 = nn.ReLU(inplace=True)
+        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)  # shape(B,64,100,2)
+        self.bn3 = nn.BatchNorm2d(width)
+        self.relu3 = nn.ReLU(inplace=True)
+        self.avgpool = nn.AvgPool2d((2, 2))  # shape(B,64,50,1)
+
+        # residual layers
+        self._inplanes = width  # this is a *mutable* variable used during construction
+        self.layer1 = self._make_layer(width, layers[0])  # shape(B,256,50,1)
+        self.layer2 = self._make_layer(width * 2, layers[1], stride=1)  # shape(B,512,50,1)
+        self.layer3 = self._make_layer(width * 4, layers[2], stride=1)  # shape(B,1024,25,1)
+        self.layer4 = self._make_layer(width * 8, layers[3], stride=1)  # shape(B,2048,25,1)
+
+        embed_dim = width * 8  # the ResNet feature dimension
+        self.attnpool = EMGAttentionPool2d(50, embed_dim, 16, output_dim)
+
+    def _make_layer(self, planes, blocks, stride=1):
+        layers = [EMGBottleneck(self._inplanes, planes, stride)]
+
+        self._inplanes = planes * EMGBottleneck.expansion
+        for _ in range(1, blocks):
+            layers.append(EMGBottleneck(self._inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        def stem(x):
+            x = self.relu1(self.bn1(self.conv1(x)))  # shape(B,32,200,4)
+            x = self.relu2(self.bn2(self.conv2(x)))  # shape(B,32,200,4)
+            x = self.relu3(self.bn3(self.conv3(x)))  # shape(B,64,100,2)
+            x = self.avgpool(x)  # shape(B,64,50,1)
+            return x
+
+        x = x.type(self.conv1.weight.dtype)
+        x = stem(x)
+        x = self.layer1(x)  # shape(B,256,50,1)
+        x = self.layer2(x)  # shape(B,512,50,1)
+        # x = self.layer3(x)  # shape(B,1024,25,2)
+        # x = self.layer4(x)  # shape(B,2048,25,1)
         x = self.attnpool(x)  # shape(B,1024)
 
         return x
@@ -593,11 +656,6 @@ def stem(x):
 class EMGCLIP(nn.Module):
     def __init__(self,
                  embed_dim: int,  # 512
-                 # vision
-                 image_resolution: int,  # 224
-                 vision_layers: Union[Tuple[int, int, int, int], int],  # 12
-                 vision_width: int,  # 768
-                 vision_patch_size: int,  # 32
                  # text
                  context_length: int,  # 77
                  vocab_size: int,  # 49408
@@ -605,33 +663,39 @@ def __init__(self,
                  transformer_heads: int,  # 8
                  transformer_layers: int,  # 12
                  classification: bool,
-                 vision_model = 'RN50'
+                 vision_model = 'RN50',
+                 model_dim=2
                  ):
         super().__init__()
 
         self.context_length = context_length
         self.classification = classification
 
+        # signal
         if vision_model == 'RN50':
-            vision_heads = vision_width * 32 // 64  # 32
-            self.visual = EMGModifiedResNet(
-                layers=vision_layers,  # (3, 4, 6, 4)
-                output_dim=embed_dim,
-                heads=vision_heads,
-                input_resolution=image_resolution,
-                width=vision_width
-            )
+            if model_dim == 1:
+                self.visual = EMGModifiedResNet1D(
+                    layers=(3, 4, 6, 4),
+                    output_dim=embed_dim,
+                    width=64
+                )
+            else:
+                self.visual = EMGModifiedResNet2D(
+                    layers=(3, 4, 6, 4),
+                    output_dim=embed_dim,
+                    width=64
+                )
         else:
-            vision_heads = vision_width // 64  # 12
             self.visual = VisionTransformer(
-                input_resolution=image_resolution,
-                patch_size=vision_patch_size,
-                width=vision_width,
-                layers=vision_layers,
-                heads=vision_heads,
+                input_resolution=224,
+                patch_size=32,
+                width=768,
+                layers=12,
+                heads=12,
                 output_dim=embed_dim
             )
 
+        # text
         self.transformer = Transformer(
             width=transformer_width,
             layers=transformer_layers,
@@ -648,7 +712,7 @@ def __init__(self,
         self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
         self.dropout = nn.Dropout()
 
-        self.cnn_projection = nn.Linear(embed_dim, 10)
+        self.class_projection = nn.Linear(embed_dim, 10)
 
         self.initialize_parameters()
@@ -714,7 +778,7 @@ def encode_text(self, text):  # shape(3,77)
     def forward(self, image, text):
         image_features = self.encode_image(image)  # shape(B,8,400,1)
         if self.classification:
-            return self.cnn_projection(self.dropout(image_features))
+            return self.class_projection(self.dropout(image_features))
         text_features = self.encode_text(text)
 
         # normalized features
@@ -728,3 +792,33 @@ def forward(self, image, text):
 
         # shape = [global_batch_size, global_batch_size]
         return logits_per_image, logits_per_text
+
+
+
+def EMGbuild_model(state_dict: dict, classification, vis_pretrain=True, vision_model='RN50', model_dim=2):
+    # text
+    embed_dim = state_dict["text_projection"].shape[1]
+    context_length = state_dict["positional_embedding"].shape[0]
+    vocab_size = state_dict["token_embedding.weight"].shape[0]
+    transformer_width = state_dict["ln_final.weight"].shape[0]
+    transformer_heads = transformer_width // 64
+    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
+
+    model = EMGCLIP(
+        embed_dim,
+        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, classification,
+        vision_model=vision_model, model_dim=model_dim
+    )
+
+    for key in ["input_resolution", "context_length", "vocab_size"]:
+        if key in state_dict:
+            del state_dict[key]
+
+    if not vis_pretrain:
+        for key in list(state_dict.keys()):
+            if 'visual' in key:
+                del state_dict[key]
+
+    convert_weights(model)
+    model.load_state_dict(state_dict, strict=False)
+    return model.eval()
diff --git a/dataloader.py b/dataloader.py
index ceca461..ab1cc04 100644
--- a/dataloader.py
+++ b/dataloader.py
@@ -159,14 +159,14 @@ def load_emg_label_from_file(filename, class_type=10):
     return emg, label
 
 
-def window_to_h5py(emg, label, filename, window_size=400):
+def window_to_h5py(emg, label, filename, window_size=400, window_overlap=0):
     window_data = []
     window_label = []
     for i in range(len(label)):
         emg_type = np.array(emg[i])
         window_count = 0
         print('{} emg points found in type {} emg signal.'.format(len(emg_type), label[i]))
-        for j in range(0, len(emg_type) - window_size, window_size):
+        for j in range(0, len(emg_type) - window_size, window_size - window_overlap):
             window_data.append(emg_type[j : j + window_size])
             window_label.append(label[i])
             window_count += 1
@@ -188,9 +188,9 @@ def h5py_to_window(filename):
 
 if __name__ == '__main__':
     filename = 'D:/Download/Datasets/Ninapro/DB2/S1/S1_E1_A1.mat'
-    h5_filename = 'dataset/window.h5'
-    # emg, label = load_emg_label_from_file(filename)
-    # window_to_h5py(emg, label, h5_filename)
+    h5_filename = 'dataset/window_400_200.h5'
+    emg, label = load_emg_label_from_file(filename)
+    window_to_h5py(emg, label, h5_filename, window_overlap=200)
     emg, label = h5py_to_window(h5_filename)
     print(emg.shape)
     print(label.shape)
\ No newline at end of file
diff --git a/dataset/window_400_200.h5 b/dataset/window_400_200.h5
new file mode 100644
index 0000000..b66231d
Binary files /dev/null and b/dataset/window_400_200.h5 differ
diff --git a/main_signal_text.py b/main_signal_text.py
index e3e0d73..3e44f00 100644
--- a/main_signal_text.py
+++ b/main_signal_text.py
@@ -15,7 +15,7 @@ def arg_parse():
     parser = argparse.ArgumentParser()
 
     parser.add_argument("--batch_size", type=int, default=16, help="dataset batch size")
-    parser.add_argument("--epochs", type=int, default=100, help="training epochs")
+    parser.add_argument("--epochs", type=int, default=60, help="training epochs")
     parser.add_argument("--lr", type=float, default=0.0004, help="learning rate")
     parser.add_argument("--dataset", type=str, default="./dataset/img", help="dataset directory path")
@@ -26,17 +26,20 @@ def main(args):
     epochs = args.epochs
     device = "cuda" if torch.cuda.is_available() else "cpu"
 
-    model, _ = clip.load("RN50", device=device, vis_pretrain=False)
+
+    filename = 'dataset/window_400_200.h5'
+    weight_path = 'res/best.pt'
+    best_precision, current_precision = 0, 0
+
+    model_dim = 1           # data dimensionality: 1 gives shape(B,8,400,1), 2 gives shape(B,1,400,8)
+    classification = True   # whether this is a classification task
+    model = clip.EMGload("RN50", device=device, classification=classification, vis_pretrain=False, model_dim=model_dim)
 
     # optimizer = Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98), weight_decay=0.2)
     optimizer = Adam(model.parameters(), lr=args.lr, eps=1e-3)
     scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
     loss_func = nn.CrossEntropyLoss()
 
-    filename = 'dataset/window.h5'
-    weight_path = 'res/best.pt'
-    best_precision, current_precision = 0, 0
-
     emg, label = h5py_to_window(filename)
     data_len = len(label)
     index = np.random.permutation(data_len)
@@ -74,22 +77,27 @@ def main(args):
         num_workers=0
     )
 
-
+    if classification:
+        print('-------- Classification task --------')
+    else:
+        print('-------- Pair task --------')
+    print("{}D signal input.".format(model_dim))
     model.train().half()
     model.to(device)
 
     print("start training...")
     for epoch in range(epochs):
-        train_one_epoch_signal_text(model, epoch, epochs, device, train_loader, loss_func, optimizer, scheduler)
-        current_precision = validate_signal_text(model, device, val_loader, loss_func)
+        train_one_epoch_signal_text(model, epoch, epochs, device, train_loader, loss_func,
+                                    optimizer, scheduler, classification=classification, model_dim=model_dim)
+        current_precision = validate_signal_text(model, device, val_loader, loss_func, classification=classification, model_dim=model_dim)
         if current_precision > best_precision:
             best_precision = current_precision
-            print('Current best precision in val set is:%.4f' % (best_precision * 100) + '%')
+            print('Current best precision in val set is: %.4f' % (best_precision * 100) + '%')
             save_model_weight(model=model, filename=weight_path)
 
     model.load_state_dict(torch.load(weight_path))
-    evaluate_signal_text(model, device, eval_loader, loss_func)
+    evaluate_signal_text(model, device, eval_loader, loss_func, classification=classification, model_dim=model_dim)
 
 
 if __name__ == "__main__":
     args = arg_parse()
diff --git a/train.py b/train.py
index 268f80d..bbc2421 100644
--- a/train.py
+++ b/train.py
@@ -206,12 +206,19 @@ def validate_signal(model, device, val_loader, loss_func):
 
 
 # train signal and text jointly
-def train_one_epoch_signal_text(model, epoch, epochs, device, train_loader, loss_func, optimizer, scheduler, classification=True):
-    model.train()
+def train_one_epoch_signal_text(model, epoch, epochs, device, train_loader, loss_func, optimizer, scheduler, classification, model_dim=2):
+    # Best classification accuracy so far: EMGModifiedResNet2D with window_400_200.h5, 86% on the validation set and 82% on the test set
+    # Best pairing accuracy so far: EMGModifiedResNet2D with window_400_200.h5, 63% on the validation set and 63% on the test set
+    model.train()
     total_loss = 0.0
-    loop = tqdm(train_loader, desc='Train')
-    for _, (window_data, window_labels) in enumerate(loop):  # shape(4,400,8)
-        window_data = window_data.transpose(1, 2).unsqueeze(-1).to(device).type(torch.float32)  # shape(4,8,400,1)
+    loop = tqdm(train_loader, desc='Train', ncols=150)
+    for _, (window_data, window_labels) in enumerate(loop):  # shape(B,400,8)
+        if model_dim == 1:
+            window_data = window_data.transpose(1, 2).unsqueeze(-1)  # shape(B,8,400,1)
+        else:
+            window_data = window_data.unsqueeze(1)  # shape(B,1,400,8)
+        window_data = window_data.to(device).type(torch.float32)
+
         text = clip.tokenize([template + prompts[mov_idx + 12] for mov_idx in window_labels]).to(device)
 
         if classification:
@@ -231,20 +238,25 @@ def train_one_epoch_signal_text(model, epoch, epochs, device, train_loader, loss
         loss.backward()
         optimizer.step()
         loop.set_description(f'Epoch [{epoch+1}/{epochs}]')
-        loop.set_postfix(loss = loss.item())
+        loop.set_postfix(loss = '%.6f' % loss.item())
     scheduler.step()
     print("[%d/%d] epoch's total loss = %f" % (epoch + 1, epochs, total_loss))
     save_results('res/results.csv', '%d, %12.6f\n' % (epoch + 1, total_loss))
 
 
-def validate_signal_text(model, device, val_loader, loss_func, classification=True):
+def validate_signal_text(model, device, val_loader, loss_func, classification, model_dim=2):
     model.eval()
     total_loss, correct_nums, total_nums = 0.0, 0, 0
     print("Validating...")
-    loop = tqdm(val_loader, desc='Train')
-    for i, (window_data, window_labels) in enumerate(loop):  # shape(4,400,8)
-        window_data = window_data.transpose(1, 2).unsqueeze(-1).to(device).type(torch.float32)  # shape(4,8,400,1)
+    loop = tqdm(val_loader, desc='Validation', ncols=100)
+    for i, (window_data, window_labels) in enumerate(loop):  # shape(B,400,8)
+        if model_dim == 1:
+            window_data = window_data.transpose(1, 2).unsqueeze(-1)  # shape(B,8,400,1)
+        else:
+            window_data = window_data.unsqueeze(1)  # shape(B,1,400,8)
+        window_data = window_data.to(device).type(torch.float32)
+
         text = clip.tokenize([template + prompts[mov_idx + 12] for mov_idx in window_labels]).to(device)
 
         if classification:
@@ -267,24 +279,29 @@ def validate_signal_text(model, device, val_loader, loss_func, classification=Tr
         loop.set_postfix(loss = loss.item())
 
     precision = '%.4f' % (100 * correct_nums / total_nums) + '%'
-    print("Validation:")
     print("Total loss: {}".format(total_loss))
     print("Correct/Total: {}/{}".format(correct_nums, total_nums))
     print("Precision:", precision)
     return correct_nums.item() / total_nums
 
 
-def evaluate_signal_text(model, device, eval_loader, loss_func, classification=True):
+def evaluate_signal_text(model, device, eval_loader, loss_func, classification, model_dim=2):
     model.eval()    # accuracy is around 64%
     total_loss, correct_nums, total_nums = 0.0, 0, 0
-    print("Evaluating...")
+    print("\nEvaluating...")
     loop = tqdm(eval_loader, desc='Evaluation')
-    for i, (window_data, window_labels) in enumerate(loop):  # shape(16,400,8)
+    for i, (window_data, window_labels) in enumerate(loop):  # shape(B,400,8)
+        if model_dim == 1:
+            window_data = window_data.transpose(1, 2).unsqueeze(-1)  # shape(B,8,400,1)
+        else:
+            window_data = window_data.unsqueeze(1)  # shape(B,1,400,8)
+        window_data = window_data.to(device).type(torch.float32)
+
         text = clip.tokenize([template + prompts[mov_idx + 12] for mov_idx in window_labels]).to(device)
 
         if classification:
-            logits_per_image = model(window_data, text)
+            logits_per_image = model(window_data, text)   # shape(B,10)
             labels = window_labels.to(device).type(torch.long) - 1
             loss = loss_func(logits_per_image, labels)
         else:
@@ -303,7 +320,6 @@ def evaluate_signal_text(model, device, eval_loader, loss_func, classification=T
         loop.set_postfix(loss = loss.item())
 
     precision = '%.4f' % (100 * correct_nums / total_nums) + '%'
-    print("Evalution:")
     print("Total loss: {}".format(total_loss))
     print("Correct/Total: {}/{}".format(correct_nums, total_nums))
     print("Precision:", precision)
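
Usage sketch: a minimal example of how the new EMGload entry point and the windowed dataset from this patch fit together for the classification task. It assumes dataset/window_400_200.h5 has already been generated, that h5py_to_window returns NumPy arrays of shape (N,400,8) and (N,), and that the prompt strings below are placeholders for the template/prompts lists used in train.py.

import torch

import clip
from dataloader import h5py_to_window

device = "cuda" if torch.cuda.is_available() else "cpu"
model_dim = 1    # 1 -> shape(B,8,400,1) input, 2 -> shape(B,1,400,8) input
model = clip.EMGload("RN50", device=device, classification=True, vis_pretrain=False, model_dim=model_dim)

emg, label = h5py_to_window('dataset/window_400_200.h5')    # assumed: NumPy arrays (N,400,8), (N,)
window_data = torch.from_numpy(emg[:4])                     # small batch, shape (4,400,8)
if model_dim == 1:
    window_data = window_data.transpose(1, 2).unsqueeze(-1)  # shape (4,8,400,1)
else:
    window_data = window_data.unsqueeze(1)                   # shape (4,1,400,8)
window_data = window_data.to(device).type(torch.float32)

# placeholder prompts; the training code builds these from its own template and prompts lists
text = clip.tokenize(["sEMG signal of a hand gesture"] * 4).to(device)

with torch.no_grad():
    logits = model(window_data, text)    # classification=True -> class logits of shape (4,10)
print(logits.argmax(dim=-1))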