Commit 380ebb5
Auditor1234 committed Dec 7, 2023
1 parent 61b0702
Showing 6 changed files with 198 additions and 61 deletions.
23 changes: 21 additions & 2 deletions clip/clip.py
@@ -10,7 +10,7 @@
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm

-from .model import build_model
+from .model import build_model, EMGbuild_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

try:
@@ -24,7 +24,7 @@
warnings.warn("PyTorch version 1.7.1 or higher is recommended")


__all__ = ["available_models", "load", "tokenize"]
__all__ = ["available_models", "load", "tokenize", "EMGload"]
_tokenizer = _Tokenizer()

_MODELS = {
@@ -244,3 +244,22 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool
result[i, :len(tokens)] = torch.tensor(tokens)

return result
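For reference, tokenize packs each prompt into a fixed-width row, which is what lets the EMG pipeline batch text prompts directly; a one-line illustration (not part of this commit):

tokens = clip.tokenize(["a photo of a clenched fist"])
print(tokens.shape)  # torch.Size([1, 77]); rows are zero-padded after the end-of-text token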


def EMGload(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
download_root: str = None, classification=True, vis_pretrain=True, model_dim=2):

if name in _MODELS:
model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
elif os.path.isfile(name):
model_path = name
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

with open(model_path, 'rb') as opened_file:
model = torch.jit.load(opened_file, map_location="cpu").eval()

model = EMGbuild_model(model.state_dict(), classification=classification, vis_pretrain=vis_pretrain, model_dim=model_dim).to(device)
if str(device) == "cpu":
model.float()
return model
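For context, a minimal usage sketch of the new EMGload entry point; the tensors and prompt below are illustrative, not part of this commit:

import torch
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"

# vis_pretrain=False drops CLIP's pretrained visual weights, so the EMG tower trains from scratch
model = clip.EMGload("RN50", device=device, classification=True,
                     vis_pretrain=False, model_dim=2)

signals = torch.randn(4, 1, 400, 8).to(device)  # hypothetical sEMG batch in the model_dim=2 layout
texts = clip.tokenize(["rest"]).to(device)      # ignored when classification=True
logits = model(signals, texts)                  # class logits, shape (4, 10)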
148 changes: 121 additions & 27 deletions clip/model.py
@@ -527,7 +527,7 @@ def forward(self, x):
return x.squeeze(0)


-class EMGModifiedResNet(nn.Module):
+class EMGModifiedResNet1D(nn.Module):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
@@ -559,8 +559,8 @@ def __init__(self, layers, output_dim, width=64):
self.layer3 = self._make_layer(width * 4, layers[2], stride=2) # shape(B,1024,25,1)
self.layer4 = self._make_layer(width * 8, layers[3], stride=1) # shape(B,2048,25,1)

-embed_dim = width * 32 # the ResNet feature dimension
-self.attnpool = EMGAttentionPool2d(25, embed_dim, 16, output_dim)
+embed_dim = width * 8 # the ResNet feature dimension
+self.attnpool = EMGAttentionPool2d(50, embed_dim, 16, output_dim)

def _make_layer(self, planes, blocks, stride=1):
layers = [EMGBottleneck(self._inplanes, planes, stride)]
@@ -583,8 +583,71 @@ def stem(x):
x = stem(x)
x = self.layer1(x) # shape(B,256,100,1)
x = self.layer2(x) # shape(B,512,50,1)
-x = self.layer3(x) # shape(B,1024,25,1)
-x = self.layer4(x) # shape(B,2048,25,1)
+# x = self.layer3(x) # shape(B,1024,25,1)
+# x = self.layer4(x) # shape(B,2048,25,1)
x = self.attnpool(x) # shape(B,1024)

return x
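(Note: with layers 3 and 4 bypassed in the forward pass, the attention pool now consumes the layer2 output of shape (B,512,50,1); that is why the constructor above moves embed_dim from width * 32 down to width * 8 = 512 and the pool's spatial size from 25 to 50.)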


class EMGModifiedResNet2D(nn.Module):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""

# input shape(B, 1, 400, 8)
def __init__(self, layers, output_dim, width=64):
super().__init__()
self.output_dim = output_dim

# the 3-layer stem
self.conv1 = nn.Conv2d(1, width // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) # shape(B,32,200,4)
self.bn1 = nn.BatchNorm2d(width // 2)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=(3, 3), padding=(1, 1), bias=False) # shape(B,32,200,4)
self.bn2 = nn.BatchNorm2d(width // 2)
self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) # shape(B,64,100,2)
self.bn3 = nn.BatchNorm2d(width)
self.relu3 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d((2, 2)) # shape(B,64,50,1)

# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0]) # shape(B,256,50,1)
self.layer2 = self._make_layer(width * 2, layers[1], stride=1) # shape(B,512,50,1)
self.layer3 = self._make_layer(width * 4, layers[2], stride=1) # shape(B,1024,50,1)
self.layer4 = self._make_layer(width * 8, layers[3], stride=1) # shape(B,2048,50,1)

embed_dim = width * 8 # the ResNet feature dimension
self.attnpool = EMGAttentionPool2d(50, embed_dim, 16, output_dim)

def _make_layer(self, planes, blocks, stride=1):
layers = [EMGBottleneck(self._inplanes, planes, stride)]

self._inplanes = planes * EMGBottleneck.expansion
for _ in range(1, blocks):
layers.append(EMGBottleneck(self._inplanes, planes))

return nn.Sequential(*layers)

def forward(self, x):
def stem(x):
x = self.relu1(self.bn1(self.conv1(x))) # shape(B,32,200,4)
x = self.relu2(self.bn2(self.conv2(x))) # shape(B,32,200,4)
x = self.relu3(self.bn3(self.conv3(x))) # shape(B,64,100,2)
x = self.avgpool(x) # shape(B,64,50,1)
return x

x = x.type(self.conv1.weight.dtype)
x = stem(x)
x = self.layer1(x) # shape(B,256,50,1)
x = self.layer2(x) # shape(B,512,50,1)
# x = self.layer3(x) # shape(B,1024,50,1)
# x = self.layer4(x) # shape(B,2048,50,1)
x = self.attnpool(x) # shape(B,1024)

return x
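The stem's shape comments can be sanity-checked with the usual convolution arithmetic, out = floor((in + 2*padding - kernel)/stride) + 1. A shape-only sketch of the 2D stem under the same width=64 (BatchNorm and ReLU omitted since they preserve shape; not part of this commit):

import torch
import torch.nn as nn

stem = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),   # (B,1,400,8)  -> (B,32,200,4)
    nn.Conv2d(32, 32, kernel_size=3, padding=1),            # (B,32,200,4) -> (B,32,200,4)
    nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # (B,32,200,4) -> (B,64,100,2)
    nn.AvgPool2d((2, 2)),                                   # (B,64,100,2) -> (B,64,50,1)
)
print(stem(torch.randn(2, 1, 400, 8)).shape)  # torch.Size([2, 64, 50, 1])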
@@ -593,45 +656,46 @@ def stem(x):
class EMGCLIP(nn.Module):
def __init__(self,
embed_dim: int, # 512
-# vision
-image_resolution: int, # 224
-vision_layers: Union[Tuple[int, int, int, int], int], # 12
-vision_width: int, # 768
-vision_patch_size: int, # 32
# text
context_length: int, # 77
vocab_size: int, # 49408
transformer_width: int, # 512
transformer_heads: int, # 8
transformer_layers: int, # 12
classification: bool,
-vision_model = 'RN50'
+vision_model = 'RN50',
+model_dim=2
):
super().__init__()

self.context_length = context_length
self.classification = classification

# signal
if vision_model == 'RN50':
-vision_heads = vision_width * 32 // 64 # 32
-self.visual = EMGModifiedResNet(
-layers=vision_layers, # (3, 4, 6, 4)
-output_dim=embed_dim,
-heads=vision_heads,
-input_resolution=image_resolution,
-width=vision_width
-)
+if model_dim == 1:
+self.visual = EMGModifiedResNet1D(
+layers=(3, 4, 6, 4),
+output_dim=embed_dim,
+width=64
+)
+else:
+self.visual = EMGModifiedResNet2D(
+layers=(3, 4, 6, 4),
+output_dim=embed_dim,
+width=64
+)
else:
-vision_heads = vision_width // 64 # 12
self.visual = VisionTransformer(
-input_resolution=image_resolution,
-patch_size=vision_patch_size,
-width=vision_width,
-layers=vision_layers,
-heads=vision_heads,
+input_resolution=224,
+patch_size=32,
+width=768,
+layers=12,
+heads=12,
output_dim=embed_dim
)

# text
self.transformer = Transformer(
width=transformer_width,
layers=transformer_layers,
Expand All @@ -648,7 +712,7 @@ def __init__(self,
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

self.dropout = nn.Dropout()
-self.cnn_projection = nn.Linear(embed_dim, 10)
+self.class_projection = nn.Linear(embed_dim, 10)

self.initialize_parameters()

@@ -714,7 +778,7 @@ def encode_text(self, text): # shape(3,77)
def forward(self, image, text):
image_features = self.encode_image(image) # image: (B,8,400,1) or (B,1,400,8); features: (B,embed_dim)
if self.classification:
-return self.cnn_projection(self.dropout(image_features))
+return self.class_projection(self.dropout(image_features))
text_features = self.encode_text(text)

# normalized features
@@ -728,3 +792,33 @@ def forward(self, image, text):

# shape = [global_batch_size, global_batch_size]
return logits_per_image, logits_per_text
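For clarity, the two forward modes this hunk distinguishes; clf_model and pair_model are hypothetical instances built with classification=True and classification=False respectively, and signals/texts are batches as in the EMGload sketch earlier:

# classification=True: dropout + class_projection over the encoded signal; text is never used
class_logits = clf_model(signals, texts)               # shape (B, 10)

# classification=False: CLIP-style pairing of signal and text embeddings
logits_per_signal, logits_per_text = pair_model(signals, texts)
probs = logits_per_signal.softmax(dim=-1)              # each signal scored against the text set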



def EMGbuild_model(state_dict: dict, classification, vis_pretrain=True, vision_model='RN50', model_dim=2):
# text
embed_dim = state_dict["text_projection"].shape[1]
context_length = state_dict["positional_embedding"].shape[0]
vocab_size = state_dict["token_embedding.weight"].shape[0]
transformer_width = state_dict["ln_final.weight"].shape[0]
transformer_heads = transformer_width // 64
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

model = EMGCLIP(
embed_dim,
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, classification,
vision_model=vision_model, model_dim=model_dim
)

for key in ["input_resolution", "context_length", "vocab_size"]:
if key in state_dict:
del state_dict[key]

if not vis_pretrain:
for key in list(state_dict.keys()):
if 'visual' in key:
del state_dict[key]

convert_weights(model)
model.load_state_dict(state_dict, strict=False)
return model.eval()
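A small sketch of what the non-strict load implies, assuming the RN50 jit archive that EMGload downloads (the path below is illustrative):

import torch
from clip.model import EMGbuild_model

jit_archive = torch.jit.load("RN50.pt", map_location="cpu").eval()
model = EMGbuild_model(jit_archive.state_dict(), classification=True,
                       vis_pretrain=False, model_dim=1)

# strict=False means the text transformer's weights transfer, while the EMG tower
# (whose visual.* names and shapes no longer match CLIP's ResNet) keeps its fresh init
emg_params = [n for n, _ in model.named_parameters() if n.startswith("visual.")]
print(len(emg_params), "visual tensors left untouched by the pretrained checkpoint")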
10 changes: 5 additions & 5 deletions dataloader.py
@@ -159,14 +159,14 @@ def load_emg_label_from_file(filename, class_type=10):
return emg, label


-def window_to_h5py(emg, label, filename, window_size=400):
+def window_to_h5py(emg, label, filename, window_size=400, window_overlap=0):
window_data = []
window_label = []
for i in range(len(label)):
emg_type = np.array(emg[i])
window_count = 0
print('{} emg points found in type {} emg signal.'.format(len(emg_type), label[i]))
-for j in range(0, len(emg_type) - window_size, window_size):
+for j in range(0, len(emg_type) - window_size, window_size - window_overlap):
window_data.append(emg_type[j : j + window_size])
window_label.append(label[i])
window_count += 1
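The new window_overlap parameter turns the slide step into window_size - window_overlap, so 400/200 roughly doubles the window count; a quick check of the arithmetic on a hypothetical 4000-sample recording:

window_size, window_overlap = 400, 200
n = 4000  # hypothetical length of one gesture's EMG signal
starts = list(range(0, n - window_size, window_size - window_overlap))
print(len(starts))  # 18 overlapping windows, vs. 9 with window_overlap=0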
@@ -188,9 +188,9 @@ def h5py_to_window(filename):

if __name__ == '__main__':
filename = 'D:/Download/Datasets/Ninapro/DB2/S1/S1_E1_A1.mat'
-h5_filename = 'dataset/window.h5'
-# emg, label = load_emg_label_from_file(filename)
-# window_to_h5py(emg, label, h5_filename)
+h5_filename = 'dataset/window_400_200.h5'
+emg, label = load_emg_label_from_file(filename)
+window_to_h5py(emg, label, h5_filename, window_overlap=200)
emg, label = h5py_to_window(h5_filename)
print(emg.shape)
print(label.shape)
Binary file added dataset/window_400_200.h5
Binary file not shown.
30 changes: 19 additions & 11 deletions main_signal_text.py
@@ -15,7 +15,7 @@
def arg_parse():
parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", type=int, default=16, help="dataset batch size")
parser.add_argument("--epochs", type=int, default=100, help="training epochs")
parser.add_argument("--epochs", type=int, default=60, help="training epochs")
parser.add_argument("--lr", type=float, default=0.0004, help="learning rate")
parser.add_argument("--dataset", type=str, default="./dataset/img", help="dataset directory path")

@@ -26,17 +26,20 @@ def main(args):

epochs = args.epochs
device = "cuda" if torch.cuda.is_available() else "cpu"
-model, _ = clip.load("RN50", device=device, vis_pretrain=False)

+filename = 'dataset/window_400_200.h5'
+weight_path = 'res/best.pt'
+best_precision, current_precision = 0, 0
+
+model_dim = 1  # input layout: 1 for (B,8,400,1), 2 for (B,1,400,8)
+classification = True  # whether to train the classification head rather than the pairing objective
+model = clip.EMGload("RN50", device=device, classification=classification, vis_pretrain=False, model_dim=model_dim)

# optimizer = Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98), weight_decay=0.2)
optimizer = Adam(model.parameters(), lr=args.lr, eps=1e-3)
scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
loss_func = nn.CrossEntropyLoss()

-filename = 'dataset/window.h5'
-weight_path = 'res/best.pt'
-best_precision, current_precision = 0, 0

emg, label = h5py_to_window(filename)
data_len = len(label)
index = np.random.permutation(data_len)
@@ -74,22 +77,27 @@ def main(args):
num_workers=0
)


if classification:
print('-------- Classification task --------')
else:
print('-------- Pair task --------')
print("{}D signal input.".format(model_dim))
model.train().half()
model.to(device)
print("start training...")
for epoch in range(epochs):
-train_one_epoch_signal_text(model, epoch, epochs, device, train_loader, loss_func, optimizer, scheduler)
-current_precision = validate_signal_text(model, device, val_loader, loss_func)
+train_one_epoch_signal_text(model, epoch, epochs, device, train_loader, loss_func,
+                            optimizer, scheduler, classification=classification, model_dim=model_dim)
+current_precision = validate_signal_text(model, device, val_loader, loss_func, classification=classification, model_dim=model_dim)

if current_precision > best_precision:
best_precision = current_precision
-print('Current best precision in val set is:%.4f' % (best_precision * 100) + '%')
+print('Current best precision in val set is: %.4f' % (best_precision * 100) + '%')
save_model_weight(model=model, filename=weight_path)


model.load_state_dict(torch.load(weight_path))
-evaluate_signal_text(model, device, eval_loader, loss_func)
+evaluate_signal_text(model, device, eval_loader, loss_func, classification=classification, model_dim=model_dim)

if __name__ == "__main__":
args = arg_parse()