forked from kennymckormick/pyskl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathresampling_smototeen5_oversampling
123 lines (94 loc) · 4.84 KB
/
resampling_smototeen5_oversampling
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import mmcv
import numpy as np
from imblearn.combine import SMOTEENN
import os
from collections import Counter
def _resample_keypoints_within_frame_dirs(annotations, target_frames_per_class):
    """Rebuild each annotation's keypoints from a per-class pool of 2-D points.

    Annotations are grouped by ``ann['label']``. For every group, all keypoint
    arrays are flattened to rows of two values and concatenated into one pool.
    Each annotation of the group is then handed the next
    ``frame_count * points_per_frame`` unused rows of that pool, reshaped to
    ``(1, frame_count, points_per_frame, 2)``, where ``frame_count`` is
    ``min(target_frames_per_class, ann['total_frames'])``.

    Args:
        annotations: List of annotation dicts providing ``'label'``,
            ``'keypoint'`` (NumPy array), ``'total_frames'`` and
            ``'frame_dir'``.
        target_frames_per_class: Upper bound on the frames kept per annotation.

    Returns:
        list: The annotations that were rewritten. These are the same dict
        objects as in ``annotations``; ``'keypoint'`` and ``'total_frames'``
        are overwritten in place.
    """
    # Group annotations by class, keeping first-seen label order.
    class_to_annotations = {}
    for ann in annotations:
        class_to_annotations.setdefault(ann['label'], []).append(ann)

    resampled_annotations = []
    for label, anns in class_to_annotations.items():
        # One (num_points, 2) float32 array per annotation; concatenating the
        # arrays avoids building a Python list with one entry per point.
        per_ann_points = [
            ann['keypoint'].reshape(-1, 2).astype(np.float32) for ann in anns
        ]
        aggregated_keypoints = np.concatenate(per_ann_points)
        y = np.repeat(
            np.array([ann['label'] for ann in anns], dtype=np.int32),
            [len(points) for points in per_ann_points],
        )

        # NOTE(review): every annotation in this group shares one label, so
        # ``y`` holds a single class and the SMOTEENN branch below is never
        # taken with the current per-label grouping. The pool is therefore
        # always the untouched concatenation of the original points.
        if np.unique(y).size < 2:
            print(f"Not enough diversity in label {label}. Oversampling data.")
            X_resampled = aggregated_keypoints  # No resampling
        else:
            X_resampled, _ = SMOTEENN().fit_resample(aggregated_keypoints, y)

        # Assumes keypoints are laid out as (persons, frames, points, coords)
        # so that axis 2 is the number of points per frame -- TODO confirm.
        points_per_frame = anns[0]['keypoint'].shape[2]

        # Redistribute the pool back to the annotations, front to back.
        cursor = 0
        for ann in anns:
            frame_count = min(target_frames_per_class, ann['total_frames'])
            total_points_needed = frame_count * points_per_frame
            # Compare against what is still unused in the pool. Checking the
            # initial pool size instead lets an exhausted pool slip through
            # and fail in the reshape below.
            if total_points_needed > len(X_resampled) - cursor:
                print(f"Not enough resampled points for label {label}, annotation {ann['frame_dir']}.")
                break
            chunk = X_resampled[cursor:cursor + total_points_needed]
            cursor += total_points_needed
            ann['keypoint'] = chunk.reshape(
                (1, frame_count, points_per_frame, 2))
            ann['total_frames'] = frame_count
            resampled_annotations.append(ann)
    return resampled_annotations


def resample_pickle_data(input_file, train_output_file, val_output_file, *,
                         num_classes=6):
    """Resample keypoints and split into train/validation datasets.

    Loads ``input_file`` with ``mmcv.load``, splits its ``'annotations'`` list
    positionally (first 80% for training, the rest for validation, no
    shuffling), rewrites the keypoints of each split per class, and writes one
    file per split with ``mmcv.dump``. Each output is a dict with the keys
    ``'split'`` (``{'train': [...], 'val': [...]}`` of ``frame_dir`` values,
    identical in both files) and ``'annotations'`` (that split's annotations).

    Args:
        input_file: Path of the annotation file to load.
        train_output_file: Destination path for the training data.
        val_output_file: Destination path for the validation data.
        num_classes: Number of classes used to derive the per-class frame
            budget. Defaults to 6, the value that was previously hard-coded.

    Raises:
        FileNotFoundError: If ``input_file`` does not exist.
        ValueError: If ``num_classes`` is smaller than 1, or if the loaded
            annotations are empty or the first one has no ``'keypoint'`` entry.
    """
    if not os.path.exists(input_file):
        raise FileNotFoundError(f"Input file {input_file} not found.")
    if num_classes < 1:
        raise ValueError(f"num_classes must be at least 1, got {num_classes}.")

    data = mmcv.load(input_file)
    annotations = data['annotations']
    if not annotations or 'keypoint' not in annotations[0]:
        raise ValueError("The input pickle file does not contain keypoint data.")
    print(f"Loaded {len(annotations)} annotations from {input_file}.")

    # Calculate total frames and target frames per class
    total_frames = sum(ann['total_frames'] for ann in annotations)
    train_target_frames_per_class = int((total_frames * 0.8) / num_classes)
    val_target_frames_per_class = int((total_frames * 0.2) / num_classes)
    print(f"Target frames per class for training: {train_target_frames_per_class}")
    print(f"Target frames per class for validation: {val_target_frames_per_class}")

    # Split annotations into training and validation sets (by position)
    split_index = int(0.8 * len(annotations))
    train_annotations = annotations[:split_index]
    val_annotations = annotations[split_index:]

    # Resample keypoints for training and validation
    resampled_train_annotations = _resample_keypoints_within_frame_dirs(
        train_annotations, train_target_frames_per_class
    )
    resampled_val_annotations = _resample_keypoints_within_frame_dirs(
        val_annotations, val_target_frames_per_class
    )

    # Assign split keys
    split_data = {
        "train": [ann['frame_dir'] for ann in resampled_train_annotations],
        "val": [ann['frame_dir'] for ann in resampled_val_annotations],
    }

    # Save resampled annotations
    train_data = {'split': split_data, 'annotations': resampled_train_annotations}
    val_data = {'split': split_data, 'annotations': resampled_val_annotations}
    mmcv.dump(train_data, train_output_file)
    mmcv.dump(val_data, val_output_file)
    print(f"Resampled train data saved to {train_output_file}.")
    print(f"Resampled validation data saved to {val_output_file}.")
# Example usage
#
# Alternative invocation with Windows paths:
# resample_pickle_data(
#     input_file=r"D:\pyskl-main\hand_pose_dataset_combined_2Dec_modified.pkl",
#     train_output_file=r"D:\pyskl-main\test_pkl\smotenn_train_resampled.pkl",
#     val_output_file=r"D:\pyskl-main\test_pkl\smotenn_val_resampled.pkl"
# )
#
# The call is guarded so that the resampling job only runs when this file is
# executed as a script, not as a side effect of importing it.
if __name__ == "__main__":
    resample_pickle_data(
        input_file=r"/root/pyskl_thesis/hand_pose_dataset_combined_2Dec_modified.pkl",
        train_output_file=r"/root/pyskl_thesis/test/smoteenn_2dec_train9.pkl",
        val_output_file=r"/root/pyskl_thesis/test/smoteenn_2dec_val9.pkl"
    )