-
Notifications
You must be signed in to change notification settings - Fork 5
/
run_train_boundary_classifiers_with_support.m
142 lines (115 loc) · 4.54 KB
/
run_train_boundary_classifiers_with_support.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
% RUN_TRAIN_BOUNDARY_CLASSIFIERS_WITH_SUPPORT  Trains several stages of
% boundary classifiers using the support-augmented boundary feature set.
%
% This opening section only prepares the environment. It:
%   * puts the project folders on the MATLAB path;
%   * runs the Consts and Params scripts, which define the `consts` and
%     `params` structs used below;
%   * selects the bottom-up region source and the BFT_RGBD_SUP feature set;
%   * loads the training indices;
%   * makes sure the boundary-features output folder exists.

% One addpath call per folder, in this order, so the resulting search-path
% order matches a sequence of separate addpath calls. The loop variable is
% cleared afterwards so it does not linger in the workspace.
for libDir = {'common/', 'nn/', 'segmentation/', 'structure_classes/', 'support/'}
  addpath(libDir{1});
end
clear libDir;
addpath(genpath('iccv07Final'));

Consts;
% Extra toolbox folder whose location is configured in Consts
% (presumably the SPAMS toolbox, judging by the field name).
addpath(consts.spamsPath);

Params;
params.regionSrc = consts.REGION_SRC_BOTTOM_UP;
params.seg.featureSet = consts.BFT_RGBD_SUP;

% Indices of the training images from the train/test split.
load(consts.splitsPath, 'trainNdxs');

% exist(..., 'dir') returns 7 for an existing folder and 0 otherwise.
if exist(consts.boundaryFeaturesDir, 'dir') ~= 7
  mkdir(consts.boundaryFeaturesDir);
end

% Script-level switch: when true, cached structure-class feature files are
% recomputed instead of reused.
OVERWRITE = false;
%% Train each boundary-classifier stage, starting at stage 3.
% Each stage:
%   1. (support feature sets only) refreshes the structure-class and support
%      models, using the regions produced by the previous stage;
%   2. extracts boundary features/labels and builds (or loads) the dataset;
%   3. trains (or loads) the boundary classifier;
%   4. merges regions in every used image and saves the updated boundaryInfo.
%
% The feature-set test does not change between stages or images, so it is
% evaluated once here rather than inside the loops.
usesSupportFeatures = params.seg.featureSet == consts.BFT_RGBD_SUP || ...
  params.seg.featureSet == consts.BFT_RGBD_SUP_SC;

for stage = 3 : params.seg.numStages
  params.stage = stage;

  if stage >= 3 && usesSupportFeatures
    % The structure-class and support models for this stage are built from
    % the regions of the previous stage, so they run on a copy of the
    % parameters whose stage is (stage - 1).
    dummyParams = params;
    dummyParams.stage = stage - 1;
    for ii = 1 : consts.numImages
      % Extract structure class features from the regions at the previous
      % stage.
      outFilename = sprintf(consts.structureFeaturesFilename, ...
        dummyParams.regionSrc, dummyParams.seg.featureSet, dummyParams.stage, ii);
      % Cached files are reused unless the script-level OVERWRITE flag is
      % set; params.overwrite is deliberately not consulted here.
      % NOTE(review): unlike the merge loop below, this loop does not skip
      % images excluded by consts.useImages -- confirm this is intended.
      if exist(outFilename, 'file') && ~OVERWRITE
        continue;
      end
      regionFeatures = extract_region_to_structure_classes_features(ii, dummyParams);
      save(outFilename, 'regionFeatures');
    end
    %%
    create_dataset_structure_class_features(dummyParams);
    train_structure_class_classifier(dummyParams);
    %%
    extract_support_features_and_labels(dummyParams, stage);
    %%
    create_dataset_support_features_for_seg(dummyParams, stage);
    train_support_classifier(dummyParams, stage);
  end

  %%
  extract_boundary_classifier_features_and_labels(stage, params);

  %% Create the boundary-classification dataset (or load the cached one).
  datasetFilename = sprintf(consts.boundaryFeaturesDataset, ...
    params.seg.featureSet, stage);
  if ~exist(datasetFilename, 'file') || params.overwrite
    [trainData, testData, trainLabels, testLabels] = ...
      create_boundary_classifier_dataset(stage, trainNdxs, params);
    fprintf('Saving dataset...');
    % -v7.3 is needed for variables larger than 2 GB.
    save(datasetFilename, 'trainData', 'trainLabels', ...
      'testData', 'testLabels', '-v7.3');
    fprintf('DONE\n');
  else
    fprintf('Loading the boundary-classification dataset.\n');
    load(datasetFilename, 'trainData', 'trainLabels', ...
      'testData', 'testLabels');
  end

  %% Train the boundary classifier (or load the cached one).
  boundaryClassifierFilename = ...
    sprintf(consts.boundaryClassifierFilename, params.seg.featureSet, stage);
  if ~exist(boundaryClassifierFilename, 'file') || params.overwrite
    classifier = train_boundary_classifier_dt(stage, trainData, trainLabels, ...
      testData, testLabels, params);
    save(boundaryClassifierFilename, 'classifier');
  else
    fprintf('Skipping creation of boundary classifier for stage %d\n', stage);
    load(boundaryClassifierFilename, 'classifier');
  end

  %% Merge regions in each image using this stage's classifier.
  fprintf('\n');
  for ii = 1 : consts.numImages
    % Excluded images are skipped before the progress message, so the log
    % only lists images that are actually considered for merging.
    if ~consts.useImages(ii)
      continue;
    end
    fprintf('Performing merge for image %d/%d (stage %d).\n', ...
      ii, consts.numImages, stage);

    outFilename = sprintf(consts.boundaryInfoPostMerge, ...
      params.seg.featureSet, stage, ii);
    if exist(outFilename, 'file') && ~params.overwrite
      continue;
    end

    % Pick the file holding the previous stage's boundaryInfo:
    %  * stage 1 would read the watershed output (this branch cannot be
    %    reached from this loop, which starts at stage 3);
    %  * for the support feature sets at stage <= 3, the previous stage's
    %    result was saved under the plain BFT_RGBD feature set;
    %  * otherwise it was saved under the current feature set.
    if stage == 1
      boundaryInfoFilename = sprintf(consts.watershedFilename, ii);
    elseif stage <= 3 && usesSupportFeatures
      boundaryInfoFilename = sprintf(consts.boundaryInfoPostMerge, ...
        consts.BFT_RGBD, stage-1, ii);
    else
      boundaryInfoFilename = sprintf(consts.boundaryInfoPostMerge, ...
        params.seg.featureSet, stage-1, ii);
    end

    % Only the inputs actually used by the merge are loaded. The previous
    % version also loaded planeData, pbAll and the object/instance labels
    % (plus derived instance labels), none of which were ever used.
    load(boundaryInfoFilename, 'boundaryInfo');
    load(sprintf(consts.imageRgbFilename, ii), 'imgRgb');
    load(sprintf(consts.boundaryFeaturesFilename, ...
      params.seg.featureSet, stage, ii), 'boundaryFeatures');

    result = merge_regions(boundaryInfo, boundaryFeatures, ...
      classifier, stage, params);
    boundaryInfo = update_boundary_info(boundaryInfo, result, imgRgb);
    save(outFilename, 'boundaryInfo');
  end
  fprintf('\n');
end