Skip to content

Commit

Permalink
2023-12-11 night
Browse files Browse the repository at this point in the history
  • Loading branch information
Auditor1234 committed Dec 11, 2023
1 parent dad6a6f commit 7c24266
Showing 1 changed file with 21 additions and 20 deletions.
41 changes: 21 additions & 20 deletions main_signal_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,15 @@ def main(args):
epochs = args.epochs
device = "cuda" if torch.cuda.is_available() else "cpu"

filename = 'dataset/window_400_300.h5'
filename = 'dataset/window.h5'
weight_path = 'res/best.pt'
best_precision, current_precision = 0, 0

model_dim = 2 # 数据维数 1为(B,8,400,1),2为(B,1,400,8)
classification = True # 是否是分类任务
classification = False # 是否是分类任务
# 'RN50', 'ViT-B/32'
model = clip.EMGload("ViT-B/32", device=device, classification=classification, vis_pretrain=False, model_dim=model_dim)
model_types = ['RN50', 'ViT-B/32']
model = clip.EMGload(model_types[0], device=device, classification=classification, vis_pretrain=False, model_dim=model_dim)

# optimizer = Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98), weight_decay=0.2)
optimizer = Adam(model.parameters(), lr=args.lr, eps=1e-3)
Expand All @@ -53,23 +54,23 @@ def main(args):
(train_len, val_len, data_len - train_len - val_len))

# 数据按照8:1:1分为训练集、验证集和测试集
# train_emg = emg[: train_len]
# train_label = label[: train_len]
# val_emg = emg[train_len : train_len + val_len]
# val_label = label[train_len : train_len + val_len]
# eval_emg = emg[train_len + val_len :]
# eval_label = label[train_len + val_len :]

data_filename = 'D:/Download/Datasets/Ninapro/DB2/S1/S1_E1_A1.mat'
emg, label = load_emg_label_from_file(data_filename)
# [(4,1,1), 200] [(2,1,1), 300]
train_emg, train_label, val_emg, val_label, eval_emg, eval_label = split_window_ration(emg, label, (2,1,1), window_overlap=300)
train_index = np.random.permutation(len(train_emg))
val_index = np.random.permutation(len(val_emg))
eval_index = np.random.permutation(len(eval_emg))
train_emg, train_label = train_emg[train_index] * 20000, train_label[train_index]
val_emg, val_label = val_emg[val_index] * 20000, val_label[val_index]
eval_emg, eval_label = eval_emg[eval_index] * 20000, eval_label[eval_index]
train_emg = emg[: train_len]
train_label = label[: train_len]
val_emg = emg[train_len : train_len + val_len]
val_label = label[train_len : train_len + val_len]
eval_emg = emg[train_len + val_len :]
eval_label = label[train_len + val_len :]

# data_filename = 'D:/Download/Datasets/Ninapro/DB2/S1/S1_E1_A1.mat'
# emg, label = load_emg_label_from_file(data_filename)
# # [(4,1,1), 200] [(2,1,1), 300]
# train_emg, train_label, val_emg, val_label, eval_emg, eval_label = split_window_ration(emg, label, (2,1,1), window_overlap=300)
# train_index = np.random.permutation(len(train_emg))
# val_index = np.random.permutation(len(val_emg))
# eval_index = np.random.permutation(len(eval_emg))
# train_emg, train_label = train_emg[train_index] * 20000, train_label[train_index]
# val_emg, val_label = val_emg[val_index] * 20000, val_label[val_index]
# eval_emg, eval_label = eval_emg[eval_index] * 20000, eval_label[eval_index]

train_loader = DataLoader(
SignalWindow(train_emg, train_label),
Expand Down

0 comments on commit 7c24266

Please sign in to comment.