-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathdonkey_static.lua
43 lines (32 loc) · 1.6 KB
/
donkey_static.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
--[[----------------------------------------------------------------------------
Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
------------------------------------------------------------------------------]]
local tnt = require 'torchnet'
local utils = paths.dofile'utils.lua'
require 'fbcoco'
--- Build the training dataset described by `opt`.
-- Creates the dataset through DataSetJSON.lua, then applies the
-- per-box sampling settings and the minimum-area filters from `opt`.
-- @param opt option table; fields read here: dataset, train_set, year,
--   proposal_dir, proposals, train_nsamples, data_root, sample_n_per_box,
--   sample_sigma, train_min_proposal_size, train_min_gtroi_size
-- @return the dataset object returned by DataSetJSON's `create`
function loadDataSet(opt)
   local dataset_name = opt.dataset .. '_' .. opt.train_set .. opt.year

   -- Proposal sub-folder. For imagenet this is 'imagenet', i.e. the dataset
   -- name itself. A bare `cond and x` (with no `or` branch) would evaluate
   -- to `false` for every other dataset and hand a boolean to
   -- makeProposalPath, so always use the dataset name.
   -- NOTE(review): assumes proposal folders for non-imagenet datasets are
   -- also named after opt.dataset; confirm against the proposal_dir layout.
   local folder_name = opt.dataset

   local proposals_path = utils.makeProposalPath(
      opt.proposal_dir, folder_name, opt.proposals, opt.train_set)

   local ds = paths.dofile('DataSetJSON.lua'):create(
      dataset_name, proposals_path, opt.train_nsamples, opt.data_root)
   ds.sample_n_per_box = opt.sample_n_per_box
   ds.sample_sigma = opt.sample_sigma
   ds:setMinProposalArea(opt.train_min_proposal_size)
   ds:setMinArea(opt.train_min_gtroi_size)
   return ds
end
--- Build the static-image ROI batch provider used for training.
-- Loads the dataset, attaches the supplied roidb/scoredb to it, loads the
-- image transformer from disk and wires everything into
-- fbcoco.BatchProviderROI_StaticImg.
-- @param opt option table; fields read here: transformer, bg_threshold_min,
--   bg_threshold_max, fg_threshold, batchSize, train_class_specific
--   (plus whatever loadDataSet reads)
-- @param roidb stored on the dataset as `roidb`
-- @param scoredb stored on the dataset as `scoredb`
-- @param loader_idx not used in this function; kept for caller compatibility
-- @return the configured batch provider
function createTrainLoader(opt, roidb, scoredb, loader_idx)
   local dataset = loadDataSet(opt)
   dataset.roidb = roidb
   dataset.scoredb = scoredb

   local transformer = torch.load(opt.transformer)
   -- Background overlap range is passed as a {min, max} pair.
   local bg_range = { opt.bg_threshold_min, opt.bg_threshold_max }

   local provider = fbcoco.BatchProviderROI_StaticImg(
      dataset, transformer, opt.fg_threshold, bg_range, opt)
   provider.batch_size = opt.batchSize
   provider.class_specific = opt.train_class_specific
   return provider
end