Commit ffd853e

2023-12-07 evening

Auditor1234 committed Dec 7, 2023
1 parent 380ebb5 commit ffd853e
Showing 6 changed files with 117 additions and 27 deletions.
2 changes: 1 addition & 1 deletion clip/clip.py
@@ -259,7 +259,7 @@ def EMGload(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.i
with open(model_path, 'rb') as opened_file:
model = torch.jit.load(opened_file, map_location="cpu").eval()

-    model = EMGbuild_model(model.state_dict(), classification=classification, vis_pretrain=vis_pretrain, model_dim=model_dim).to(device)
+    model = EMGbuild_model(model.state_dict(), classification=classification, vis_pretrain=vis_pretrain, emg_model=name, model_dim=model_dim).to(device)
if str(device) == "cpu":
model.float()
return model
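With this change, the model name passed to EMGload also selects the EMG encoder backbone (see the emg_model branch in clip/model.py below). A minimal call sketch, assuming this repository's clip package is importable; the argument values mirror main_signal_text.py:

import torch
import clip  # this repository's clip package, not the PyPI openai-clip

device = "cuda" if torch.cuda.is_available() else "cpu"
# "RN50" keeps the EMGModifiedResNet* path; any other name (e.g. "ViT-B/32")
# now routes to the new EMGVisionTransformer* encoders.
model = clip.EMGload("ViT-B/32", device=device, classification=False,
                     vis_pretrain=False, model_dim=2)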
114 changes: 101 additions & 13 deletions clip/model.py
@@ -653,6 +653,84 @@ def stem(x):
return x


class EMGVisionTransformer1D(nn.Module):  # (input_resolution, patch_size, width, layers, heads, output_dim) = (400, 8, 512, 12, 8, 512)
# input shape(B,8,400,1)
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
super().__init__()
self.input_resolution = input_resolution
self.output_dim = output_dim
self.conv1 = nn.Conv2d(in_channels=8, out_channels=width, kernel_size=(patch_size, 1), stride=(patch_size, 1), bias=False)

scale = width ** -0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) + 1, width))
self.ln_pre = LayerNorm(width)

self.transformer = Transformer(width, layers, heads)

self.ln_post = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

def forward(self, x: torch.Tensor):
# shape(B,8,400,1)
x = self.conv1(x) # shape(B,width,50,1)
x = x.reshape(x.shape[0], x.shape[1], -1) # shape(B,width,50)
x = x.permute(0, 2, 1) # shape(B,50,width)
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape(B,50 + 1,width)
x = x + self.positional_embedding.to(x.dtype) # shape(B,51,width)
x = self.ln_pre(x) # shape(B,51,width)

x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x) # shape(51,B,width)
x = x.permute(1, 0, 2) # LND -> NLD

x = self.ln_post(x[:, 0, :]) # shape(B,width)

if self.proj is not None:
x = x @ self.proj

        return x  # shape(B, output_dim), e.g. (B, 512)


class EMGVisionTransformer2D(nn.Module):  # (input_resolution, patch_size, width, layers, heads, output_dim) = (400, 8, 512, 12, 8, 512)
# input shape(B,1,400,8)
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
super().__init__()
self.input_resolution = input_resolution
self.output_dim = output_dim
self.conv1 = nn.Conv2d(in_channels=1, out_channels=width, kernel_size=(patch_size, patch_size), stride=(patch_size, 1), bias=False)

scale = width ** -0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) + 1, width))
self.ln_pre = LayerNorm(width)

self.transformer = Transformer(width, layers, heads)

self.ln_post = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

def forward(self, x: torch.Tensor):
        # shape(B,1,400,8)
x = self.conv1(x) # shape(B,width,50,1)
x = x.reshape(x.shape[0], x.shape[1], -1) # shape(B,width,50)
x = x.permute(0, 2, 1) # shape(B,50,width)
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape(B,50 + 1,width)
x = x + self.positional_embedding.to(x.dtype) # shape(B,51,width)
x = self.ln_pre(x) # shape(B,51,width)

x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x) # shape(51,B,width)
x = x.permute(1, 0, 2) # LND -> NLD

x = self.ln_post(x[:, 0, :]) # shape(B,width)

if self.proj is not None:
x = x @ self.proj

        return x  # shape(B, output_dim), e.g. (B, 512)
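Both new encoders can be shape-checked in isolation. A minimal smoke test, assuming LayerNorm and Transformer from this file are in scope; the expected shapes follow the inline comments above:

# Shape check (a sketch, not part of the commit):
x1 = torch.randn(2, 8, 400, 1)  # (B, channels, samples, 1)
vit1d = EMGVisionTransformer1D(400, 8, 512, 12, 8, 512)
assert vit1d(x1).shape == (2, 512)

x2 = torch.randn(2, 1, 400, 8)  # (B, 1, samples, channels)
vit2d = EMGVisionTransformer2D(400, 8, 512, 12, 8, 512)
assert vit2d(x2).shape == (2, 512)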


class EMGCLIP(nn.Module):
def __init__(self,
embed_dim: int, # 512
@@ -663,7 +741,7 @@ def __init__(self,
transformer_heads: int, # 8
transformer_layers: int, # 12
classification: bool,
-                 vision_model = 'RN50',
+                 emg_model = 'RN50',
model_dim=2
):
super().__init__()
@@ -672,7 +750,7 @@ def __init__(self,
self.classification = classification

# signal
-        if vision_model == 'RN50':
+        if emg_model == 'RN50':
if model_dim == 1:
self.visual = EMGModifiedResNet1D(
layers=(3,4,6,4), # (3, 4, 6, 4)
@@ -686,14 +764,24 @@
width=64
)
else:
-            self.visual = VisionTransformer(
-                input_resolution=224,
-                patch_size=32,
-                width=768,
-                layers=12,
-                heads=12,
-                output_dim=embed_dim
-            )
+            if model_dim == 1:
+                self.visual = EMGVisionTransformer1D(
+                    input_resolution=400,
+                    patch_size=8,
+                    width=512,
+                    layers=12,
+                    heads=8,
+                    output_dim=embed_dim
+                )
+            else:
+                self.visual = EMGVisionTransformer2D(
+                    input_resolution=400,
+                    patch_size=8,
+                    width=512,
+                    layers=12,
+                    heads=8,
+                    output_dim=embed_dim
+                )

# text
self.transformer = Transformer(
@@ -795,7 +883,7 @@ def forward(self, image, text):



-def EMGbuild_model(state_dict: dict, classification, vis_pretrain=True, vision_model='RN50', model_dim=2):
+def EMGbuild_model(state_dict: dict, classification, vis_pretrain=True, emg_model='RN50', model_dim=2):
# text
embed_dim = state_dict["text_projection"].shape[1]
context_length = state_dict["positional_embedding"].shape[0]
@@ -807,7 +895,7 @@ def EMGbuild_model(state_dict: dict, classification, vis_pretrain=True, vision_m
model = EMGCLIP(
embed_dim,
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, classification,
-        vision_model=vision_model, model_dim=model_dim
+        emg_model=emg_model, model_dim=model_dim
)

for key in ["input_resolution", "context_length", "vocab_size"]:
@@ -820,5 +908,5 @@ def EMGbuild_model(state_dict: dict, classification, vis_pretrain=True, vision_m
del state_dict[key]

convert_weights(model)
-    model.load_state_dict(state_dict, strict=False)
+    # model.load_state_dict(state_dict, strict=False)
return model.eval()
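Since load_state_dict is now commented out, the CLIP checkpoint only supplies hyperparameters here and all weights, including the text tower, start from random initialization. If the intent is to keep the pretrained text transformer while leaving the EMG encoder random, one option is a filtered load; a sketch under that assumption, not part of this commit:

own = model.state_dict()
# Keep only checkpoint tensors whose name and shape match the new model
# (i.e. the text tower); the EMG encoder keys simply do not match.
filtered = {k: v for k, v in state_dict.items()
            if k in own and v.shape == own[k].shape}
model.load_state_dict(filtered, strict=False)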
4 changes: 2 additions & 2 deletions dataloader.py
@@ -188,9 +188,9 @@ def h5py_to_window(filename):

if __name__ == '__main__':
filename = 'D:/Download/Datasets/Ninapro/DB2/S1/S1_E1_A1.mat'
-    h5_filename = 'dataset/window_400_200.h5'
+    h5_filename = 'dataset/window_400_300.h5'
emg, label = load_emg_label_from_file(filename)
-    window_to_h5py(emg, label, h5_filename, window_overlap=200)
+    window_to_h5py(emg, label, h5_filename, window_overlap=300)
emg, label = h5py_to_window(h5_filename)
print(emg.shape)
print(label.shape)
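The overlap change from 200 to 300 samples shortens the hop between consecutive 400-sample windows and roughly doubles the number of training windows. A back-of-the-envelope sketch, assuming window_overlap counts the samples shared by consecutive windows (stride = window - overlap):

window, overlap = 400, 300
stride = window - overlap              # 100 samples between window starts
n_samples = 100_000                    # hypothetical recording length
n_windows = (n_samples - window) // stride + 1
print(n_windows)                       # 997, versus 499 at window_overlap=200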
7 changes: 5 additions & 2 deletions learn_grammer/test.py
@@ -2,6 +2,7 @@
import torch.nn as nn
from PIL import Image
import numpy as np

from torchvision.transforms import ToTensor

# print(torch.ones([])) # every element of the tensor is 1
@@ -44,7 +45,7 @@

# model = nn.LayerNorm(4) # uses the population standard deviation
# x = torch.tensor(range(12), dtype=torch.float32).reshape(3, 4)
-# print(model(x))
+# print(model(x)) # normalizes over the last dimension, shape(3,4)


# image = Image.open('CLIP.png')
@@ -156,4 +157,6 @@


x = torch.rand(4, 5)
-print(x[:1].shape)
+print(x[:1].shape)   # torch.Size([1, 5]): slicing keeps the dimension
+print(x[0,:].shape)  # torch.Size([5]): integer indexing drops it
+print(x.dim())       # 2
11 changes: 5 additions & 6 deletions main_signal_text.py
@@ -27,13 +27,13 @@ def main(args):
epochs = args.epochs
device = "cuda" if torch.cuda.is_available() else "cpu"

-    filename = 'dataset/window_400_200.h5'
+    filename = 'dataset/window_400_300.h5'
weight_path = 'res/best.pt'
best_precision, current_precision = 0, 0

-    model_dim = 1  # input layout: 1 means (B,8,400,1), 2 means (B,1,400,8)
-    classification = True  # whether to run the classification task
-    model = clip.EMGload("RN50", device=device, classification=classification, vis_pretrain=False, model_dim=model_dim)
+    model_dim = 2  # input layout: 1 means (B,8,400,1), 2 means (B,1,400,8)
+    classification = False  # whether to run the classification task
+    model = clip.EMGload("ViT-B/32", device=device, classification=classification, vis_pretrain=False, model_dim=model_dim)

# optimizer = Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98), weight_decay=0.2)
optimizer = Adam(model.parameters(), lr=args.lr, eps=1e-3)
@@ -92,10 +92,9 @@ def main(args):

if current_precision > best_precision:
best_precision = current_precision
-            print('Current best precision in val set is: %.4f' % (best_precision * 100) + '%')
save_model_weight(model=model, filename=weight_path)


+    print('\nCurrent best precision in val set is: %.4f' % (best_precision * 100) + '%')
model.load_state_dict(torch.load(weight_path))
evaluate_signal_text(model, device, eval_loader, loss_func, classification=classification, model_dim=model_dim)

6 changes: 3 additions & 3 deletions train.py
@@ -276,7 +276,7 @@ def validate_signal_text(model, device, val_loader, loss_func, classification, m
total_nums += len(predict_idx)

loop.set_description(f'Validating [{i + 1}/{len(val_loader)}]')
-        loop.set_postfix(loss = loss.item())
+        loop.set_postfix(loss = '%.6f' % loss.item())

precision = '%.4f' % (100 * correct_nums / total_nums) + '%'
print("Total loss: {}".format(total_loss))
@@ -289,7 +289,7 @@ def evaluate_signal_text(model, device, eval_loader, loss_func, classification,
    model.eval()  # accuracy is around 64%
total_loss, correct_nums, total_nums = 0.0, 0, 0

print("\nEvaluating...")
print("Evaluating...")
loop = tqdm(eval_loader, desc='Evaluation')
for i, (window_data, window_labels) in enumerate(loop): # shape(B,400,8)
if model_dim == 1:
@@ -317,7 +317,7 @@ def evaluate_signal_text(model, device, eval_loader, loss_func, classification,
total_nums += len(predict_idx)

loop.set_description(f'Evaluating [{i + 1}/{len(eval_loader)}]')
-        loop.set_postfix(loss = loss.item())
+        loop.set_postfix(loss = '%.6f' % loss.item())

precision = '%.4f' % (100 * correct_nums / total_nums) + '%'
print("Total loss: {}".format(total_loss))
