diff --git a/CHANGELOG.md b/CHANGELOG.md index a28551e35..d715056dc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ * Python audio unit test, and support to verify outputs * rocDecode for HW decode * Support for Audio augmentation - PreEmphasis filter +* Support for reading from file lists in file reader ### Optimizations diff --git a/rocAL/include/api/rocal_api_data_loaders.h b/rocAL/include/api/rocal_api_data_loaders.h index 03b01066d..eec0c9a64 100644 --- a/rocAL/include/api/rocal_api_data_loaders.h +++ b/rocAL/include/api/rocal_api_data_loaders.h @@ -882,7 +882,8 @@ extern "C" RocalTensor ROCAL_API_CALL rocalJpegExternalFileSource(RocalContext p /*! Creates Audio file reader and decoder. It allocates the resources and objects required to read and decode audio files stored on the file systems. It has internal sharding capability to load/decode in parallel if user wants. * If the files are not in standard audio compression formats they will be ignored, Currently wav format is supported * \param [in] context Rocal context - * \param [in] source_path A NULL terminated char string pointing to the location of files on the disk + * \param [in] source_path A NULL terminated char string pointing to the location on the disk + * \param [in] source_file_list_path A char string pointing to the file list location on the disk * \param [in] shard_count Defines the parallelism level by internally sharding the input dataset and load/decode using multiple decoder/loader instances. Using shard counts bigger than 1 improves the load/decode performance if compute resources (CPU cores) are available. * \param [in] is_output Boolean variable to enable the audio to be part of the output. * \param [in] shuffle Boolean variable to shuffle the dataset. 
@@ -892,6 +893,7 @@ extern "C" RocalTensor ROCAL_API_CALL rocalJpegExternalFileSource(RocalContext p */ extern "C" RocalTensor ROCAL_API_CALL rocalAudioFileSource(RocalContext context, const char* source_path, + const char* source_file_list_path, unsigned shard_count, bool is_output, bool shuffle = false, @@ -901,7 +903,8 @@ extern "C" RocalTensor ROCAL_API_CALL rocalAudioFileSource(RocalContext context, /*! Creates Audio file reader and decoder. It allocates the resources and objects required to read and decode audio files stored on the file systems. It has internal sharding capability to load/decode in parallel is user wants. * If the files are not in standard audio compression formats they will be ignored. * \param [in] context Rocal context - * \param [in] source_path A NULL terminated char string pointing to the location of files on the disk + * \param [in] source_path A NULL terminated char string pointing to the location on the disk + * \param [in] source_file_list_path A char string pointing to the file list location on the disk * \param [in] shard_id Shard id for this loader * \param [in] shard_count Defines the parallelism level by internally sharding the input dataset and load/decode using multiple decoder/loader instances. Using shard counts bigger than 1 improves the load/decode performance if compute resources (CPU cores) are available. * \param [in] is_output Boolean variable to enable the audio to be part of the output. 
@@ -912,6 +915,7 @@ extern "C" RocalTensor ROCAL_API_CALL rocalAudioFileSource(RocalContext context, */ extern "C" RocalTensor ROCAL_API_CALL rocalAudioFileSourceSingleShard(RocalContext p_context, const char* source_path, + const char* source_file_list_path, unsigned shard_id, unsigned shard_count, bool is_output, diff --git a/rocAL/include/api/rocal_api_meta_data.h b/rocAL/include/api/rocal_api_meta_data.h index 50e0f671a..9907427bb 100644 --- a/rocAL/include/api/rocal_api_meta_data.h +++ b/rocAL/include/api/rocal_api_meta_data.h @@ -36,9 +36,10 @@ THE SOFTWARE. * \ingroup group_rocal_meta_data * \param [in] rocal_context rocal context * \param [in] source_path path to the folder that contains the dataset or metadata file + * \param [in] file_list_path path to the file list that contains the file names and their corresponding labels; when empty (the default) the folder-based label reader is created on source_path instead * \return RocalMetaData object, can be used to inquire about the rocal's output (processed) tensors */ -extern "C" RocalMetaData ROCAL_API_CALL rocalCreateLabelReader(RocalContext rocal_context, const char* source_path); +extern "C" RocalMetaData ROCAL_API_CALL rocalCreateLabelReader(RocalContext rocal_context, const char* source_path, const char* file_list_path = ""); /*! 
\brief creates video label reader * \ingroup group_rocal_meta_data diff --git a/rocAL/include/readers/file_source_reader.h b/rocAL/include/readers/file_source_reader.h index 1d7043c22..eeefdefd5 100644 --- a/rocAL/include/readers/file_source_reader.h +++ b/rocAL/include/readers/file_source_reader.h @@ -75,6 +75,7 @@ class FileSourceReader : public Reader { Reader::Status open_folder(); Reader::Status subfolder_reading(); std::string _folder_path; + std::string _file_list_path; DIR *_src_dir; DIR *_sub_dir; struct dirent *_entity; diff --git a/rocAL/include/readers/image/image_reader.h b/rocAL/include/readers/image/image_reader.h index 87a8e3e57..7c3cd0eaa 100644 --- a/rocAL/include/readers/image/image_reader.h +++ b/rocAL/include/readers/image/image_reader.h @@ -99,6 +99,8 @@ struct ReaderConfig { std::map feature_key_map() { return _feature_key_map; } void set_file_prefix(const std::string &prefix) { _file_prefix = prefix; } std::string file_prefix() { return _file_prefix; } + void set_file_list_path(const std::string &file_list_path) { _file_list_path = file_list_path; } + std::string file_list_path() { return _file_list_path; } std::shared_ptr meta_data_reader() { return _meta_data_reader; } ExternalSourceFileMode mode() { return _file_mode; } std::pair get_last_batch_policy() { return _last_batch_info; } @@ -117,7 +119,8 @@ struct ReaderConfig { size_t _sequence_frame_stride = 1; bool _shuffle = false; bool _loop = false; - std::string _file_prefix = ""; //!< to read only files with prefix. supported only for cifar10_data_reader and tf_record_reader + std::string _file_prefix; //!< to read only files with prefix. 
supported only for cifar10_data_reader and tf_record_reader + std::string _file_list_path; //!< to read only files present in the file list std::shared_ptr _meta_data_reader = nullptr; ExternalSourceFileMode _file_mode = ExternalSourceFileMode::NONE; std::pair _last_batch_info = {RocalBatchPolicy::FILL, true}; diff --git a/rocAL/source/api/rocal_api_data_loaders.cpp b/rocAL/source/api/rocal_api_data_loaders.cpp index 99c495d68..344cf6008 100644 --- a/rocAL/source/api/rocal_api_data_loaders.cpp +++ b/rocAL/source/api/rocal_api_data_loaders.cpp @@ -47,7 +47,9 @@ std::tuple evaluate_audio_data_set(StorageType storage_type, DecoderType decoder_type, const std::string& source_path, const std::string& file_list_path) { AudioSourceEvaluator source_evaluator; - if (source_evaluator.Create(ReaderConfig(storage_type, source_path, file_list_path), DecoderConfig(decoder_type)) != AudioSourceEvaluatorStatus::OK) + auto reader_config = ReaderConfig(storage_type, source_path); + reader_config.set_file_list_path(file_list_path); + if (source_evaluator.Create(reader_config, DecoderConfig(decoder_type)) != AudioSourceEvaluatorStatus::OK) THROW("Initializing file source input evaluator failed") auto max_samples = source_evaluator.GetMaxSamples(); auto max_channels = source_evaluator.GetMaxChannels(); @@ -2189,6 +2191,7 @@ RocalTensor ROCAL_API_CALL rocalAudioFileSourceSingleShard( RocalContext p_context, const char* source_path, + const char* source_file_list_path, unsigned shard_id, unsigned shard_count, bool is_output, @@ -2203,7 +2206,7 @@ rocalAudioFileSourceSingleShard( THROW("Shard count should be bigger than 0") if (shard_id >= shard_count) THROW("Shard id should be smaller than shard count") - auto [max_sample_length, max_channels] = evaluate_audio_data_set(StorageType::FILE_SYSTEM, DecoderType::AUDIO_SOFTWARE_DECODE, source_path, ""); + auto [max_sample_length, max_channels] = evaluate_audio_data_set(StorageType::FILE_SYSTEM, DecoderType::AUDIO_SOFTWARE_DECODE, source_path, 
source_file_list_path); INFO("Internal buffer size for audio samples = " + TOSTR(max_sample_length) + " and channels = " + TOSTR(max_channels)) RocalTensorDataType tensor_data_type = RocalTensorDataType::FP32; std::vector dims = {context->user_batch_size(), max_sample_length, max_channels}; @@ -2234,6 +2237,7 @@ RocalTensor ROCAL_API_CALL rocalAudioFileSource( RocalContext p_context, const char* source_path, + const char* source_file_list_path, unsigned shard_count, bool is_output, bool shuffle, @@ -2243,7 +2247,7 @@ rocalAudioFileSource( auto context = static_cast(p_context); try { #ifdef ROCAL_AUDIO - auto [max_sample_length, max_channels] = evaluate_audio_data_set(StorageType::FILE_SYSTEM, DecoderType::AUDIO_SOFTWARE_DECODE, source_path, ""); + auto [max_sample_length, max_channels] = evaluate_audio_data_set(StorageType::FILE_SYSTEM, DecoderType::AUDIO_SOFTWARE_DECODE, source_path, source_file_list_path); INFO("Internal buffer size for audio samples = " + TOSTR(max_sample_length) + " and channels = " + TOSTR(max_channels)) RocalTensorDataType tensor_data_type = RocalTensorDataType::FP32; std::vector dims = {context->user_batch_size(), max_sample_length, max_channels}; diff --git a/rocAL/source/api/rocal_api_meta_data.cpp b/rocAL/source/api/rocal_api_meta_data.cpp index edea792da..6250d985a 100644 --- a/rocAL/source/api/rocal_api_meta_data.cpp +++ b/rocAL/source/api/rocal_api_meta_data.cpp @@ -51,12 +51,14 @@ void RocalMetaData ROCAL_API_CALL - rocalCreateLabelReader(RocalContext p_context, const char* source_path) { + rocalCreateLabelReader(RocalContext p_context, const char* source_path, const char* file_list_path) { if (!p_context) THROW("Invalid rocal context passed to rocalCreateLabelReader") auto context = static_cast(p_context); - - return context->master_graph->create_label_reader(source_path, MetaDataReaderType::FOLDER_BASED_LABEL_READER); + if (!file_list_path || strlen(file_list_path) == 0) /* a null or empty list path selects the folder-based label reader on source_path; the null test must come first because strlen on a null pointer is undefined */ + return context->master_graph->create_label_reader(source_path, 
MetaDataReaderType::FOLDER_BASED_LABEL_READER); + else + return context->master_graph->create_label_reader(file_list_path, MetaDataReaderType::TEXT_FILE_META_DATA_READER); } RocalMetaData diff --git a/rocAL/source/loaders/audio/node_audio_loader.cpp b/rocAL/source/loaders/audio/node_audio_loader.cpp index 5e568a532..c1d3cc511 100644 --- a/rocAL/source/loaders/audio/node_audio_loader.cpp +++ b/rocAL/source/loaders/audio/node_audio_loader.cpp @@ -42,6 +42,7 @@ void AudioLoaderNode::Init(unsigned internal_shard_count, unsigned cpu_num_threa reader_cfg.set_batch_count(load_batch_count); reader_cfg.set_meta_data_reader(meta_data_reader); reader_cfg.set_cpu_num_threads(cpu_num_threads); + reader_cfg.set_file_list_path(file_list_path); _loader_module->initialize(reader_cfg, DecoderConfig(decoder_type), mem_type, _batch_size, false); _loader_module->start_loading(); } diff --git a/rocAL/source/loaders/audio/node_audio_loader_single_shard.cpp b/rocAL/source/loaders/audio/node_audio_loader_single_shard.cpp index 2212969b6..c7238fab8 100644 --- a/rocAL/source/loaders/audio/node_audio_loader_single_shard.cpp +++ b/rocAL/source/loaders/audio/node_audio_loader_single_shard.cpp @@ -46,6 +46,7 @@ void AudioLoaderSingleShardNode::Init(unsigned shard_id, unsigned shard_count, u reader_cfg.set_batch_count(load_batch_count); reader_cfg.set_meta_data_reader(meta_data_reader); reader_cfg.set_cpu_num_threads(cpu_num_threads); + reader_cfg.set_file_list_path(file_list_path); _loader_module->initialize(reader_cfg, DecoderConfig(decoder_type), mem_type, _batch_size); _loader_module->start_loading(); } diff --git a/rocAL/source/meta_data/text_file_meta_data_reader.cpp b/rocAL/source/meta_data/text_file_meta_data_reader.cpp index 760c5aff1..f326f3d9a 100644 --- a/rocAL/source/meta_data/text_file_meta_data_reader.cpp +++ b/rocAL/source/meta_data/text_file_meta_data_reader.cpp @@ -69,15 +69,20 @@ void TextFileMetaDataReader::lookup(const std::vector &image_names) void 
TextFileMetaDataReader::read_all(const std::string &path) { std::ifstream text_file(path.c_str()); if (text_file.good()) { - //_text_file.open(path.c_str(), std::ifstream::in); std::string line; while (std::getline(text_file, line)) { std::istringstream line_ss(line); int label; - std::string image_name; - if (!(line_ss >> image_name >> label)) + std::string file_name; + if (!(line_ss >> file_name >> label)) continue; - add(image_name, label); + // process pair (file_name, label) + auto last_id = file_name; + auto last_slash_idx = last_id.find_last_of("\\/"); + if (std::string::npos != last_slash_idx) { + last_id.erase(0, last_slash_idx + 1); + } + add(last_id, label); } } else { THROW("Can't open the metadata file at " + path) diff --git a/rocAL/source/pipeline/master_graph.cpp b/rocAL/source/pipeline/master_graph.cpp index 4dbf7ddca..2f9227fc7 100644 --- a/rocAL/source/pipeline/master_graph.cpp +++ b/rocAL/source/pipeline/master_graph.cpp @@ -1136,7 +1136,8 @@ std::vector MasterGraph::create_label_reader(const char *sour THROW("A metadata reader has already been created") if (_augmented_meta_data) THROW("Metadata can only have a single output") - + if (!source_path || strlen(source_path) == 0) /* reject a null pointer before strlen dereferences it */ + THROW("Source path needs to be provided") MetaDataConfig config(MetaDataType::Label, reader_type, source_path); _meta_data_reader = create_meta_data_reader(config, _augmented_meta_data); _meta_data_reader->read_all(source_path); diff --git a/rocAL/source/readers/file_source_reader.cpp b/rocAL/source/readers/file_source_reader.cpp index 9856fa92d..564bef03c 100644 --- a/rocAL/source/readers/file_source_reader.cpp +++ b/rocAL/source/readers/file_source_reader.cpp @@ -52,6 +52,7 @@ Reader::Status FileSourceReader::initialize(ReaderConfig desc) { auto ret = Reader::Status::OK; _file_id = 0; _folder_path = desc.path(); + _file_list_path = desc.file_list_path(); _shard_id = desc.get_shard_id(); _shard_count = desc.get_shard_count(); _batch_count = desc.get_batch_size(); @@ -161,25 +162,61 @@ 
Reader::Status FileSourceReader::generate_file_names() { std::sort(entry_name_list.begin(), entry_name_list.end()); auto ret = Reader::Status::OK; - for (unsigned dir_count = 0; dir_count < entry_name_list.size(); ++dir_count) { - std::string subfolder_path = _full_path + "/" + entry_name_list[dir_count]; - filesys::path pathObj(subfolder_path); - if (filesys::exists(pathObj) && filesys::is_regular_file(pathObj)) { - // ignore files with unsupported extensions - auto file_extension_idx = subfolder_path.find_last_of("."); - if (file_extension_idx != std::string::npos) { - std::string file_extension = subfolder_path.substr(file_extension_idx + 1); - std::transform(file_extension.begin(), file_extension.end(), file_extension.begin(), - [](unsigned char c) { return std::tolower(c); }); - if ((file_extension != "jpg") && (file_extension != "jpeg") && (file_extension != "png") && (file_extension != "ppm") && (file_extension != "bmp") && (file_extension != "pgm") && (file_extension != "tif") && (file_extension != "tiff") && (file_extension != "webp") && (file_extension != "wav")) - continue; + if (!_file_list_path.empty()) { // Reads the file paths from the file list and adds to file_names vector for decoding + std::ifstream fp(_file_list_path); + if (!fp.is_open()) /* an explicitly requested file list that cannot be read is an error, not an empty dataset */ + THROW("FileReader ShardID [" + TOSTR(_shard_id) + "] cannot open the file list at " + _file_list_path); + std::string file_label_path; + while (std::getline(fp, file_label_path)) { /* loop on the read itself so the failed read at end-of-file is never processed as a record */ + std::istringstream ss(file_label_path); + std::string file_path; + if (!(ss >> file_path)) continue; /* skip blank lines; first whitespace-separated token is the path, the same split TextFileMetaDataReader::read_all uses */ + if (filesys::path(file_path).is_relative()) { // Only add root path if the file list contains relative file paths + if (!filesys::exists(_folder_path)) + THROW("File list contains relative paths but root path doesn't exist"); + file_path = _folder_path + "/" + file_path; + } + std::string file_name = file_path.substr(file_path.find_last_of("/\\") + 1); + + if (!_meta_data_reader || _meta_data_reader->exists(file_name)) { // Check if the file is present in metadata reader and add to file names list, to avoid issues while 
lookup + if (filesys::is_regular_file(file_path)) { + if (get_file_shard_id() != _shard_id) { + _file_count_all_shards++; + incremenet_file_id(); + continue; + } + _in_batch_read_count++; + _in_batch_read_count = (_in_batch_read_count % _batch_count == 0) ? 0 : _in_batch_read_count; + _last_file_name = file_path; + _file_names.push_back(file_path); + _file_count_all_shards++; + incremenet_file_id(); + } + } else { + WRN("Skipping file," + std::string(file_path) + " as it is not present in metadata reader") + } + } + } else { + for (unsigned dir_count = 0; dir_count < entry_name_list.size(); ++dir_count) { + std::string subfolder_path = _full_path + "/" + entry_name_list[dir_count]; + filesys::path pathObj(subfolder_path); + if (filesys::exists(pathObj) && filesys::is_regular_file(pathObj)) { + // ignore files with unsupported extensions + auto file_extension_idx = subfolder_path.find_last_of("."); + if (file_extension_idx != std::string::npos) { + std::string file_extension = subfolder_path.substr(file_extension_idx + 1); + std::transform(file_extension.begin(), file_extension.end(), file_extension.begin(), + [](unsigned char c) { return std::tolower(c); }); + if ((file_extension != "jpg") && (file_extension != "jpeg") && (file_extension != "png") && (file_extension != "ppm") && (file_extension != "bmp") && (file_extension != "pgm") && (file_extension != "tif") && (file_extension != "tiff") && (file_extension != "webp") && (file_extension != "wav")) + continue; + } + ret = open_folder(); + break; // assume directory has only files. + } else if (filesys::exists(pathObj) && filesys::is_directory(pathObj)) { + _folder_path = subfolder_path; + if (open_folder() != Reader::Status::OK) + WRN("FileReader ShardID [" + TOSTR(_shard_id) + "] File reader cannot access the storage at " + _folder_path); } 
- } else if (filesys::exists(pathObj) && filesys::is_directory(pathObj)) { - _folder_path = subfolder_path; - if (open_folder() != Reader::Status::OK) - WRN("FileReader ShardID [" + TOSTR(_shard_id) + "] File reader cannot access the storage at " + _folder_path); } } diff --git a/rocAL_pybind/amd/rocal/decoders.py b/rocAL_pybind/amd/rocal/decoders.py index 42206d763..038034d6a 100644 --- a/rocAL_pybind/amd/rocal/decoders.py +++ b/rocAL_pybind/amd/rocal/decoders.py @@ -454,6 +454,7 @@ def audio(*inputs, file_root='', file_list_path='', shard_id=0, num_shards=1, ra """ kwargs_pybind = { "source_path": file_root, + "source_file_list_path": file_list_path, "shard_id": shard_id, "num_shards": num_shards, "is_output": False, diff --git a/rocAL_pybind/amd/rocal/plugin/pytorch.py b/rocAL_pybind/amd/rocal/plugin/pytorch.py index 367dcdca0..4c20350e7 100644 --- a/rocAL_pybind/amd/rocal/plugin/pytorch.py +++ b/rocAL_pybind/amd/rocal/plugin/pytorch.py @@ -335,6 +335,8 @@ def __next__(self): self.output_tensor_list[i].copy_data(ctypes.c_void_p(output.data_ptr()), self.output_memory_type) self.output_list.append(output) + self.labels = self.loader.get_image_labels() + self.labels_tensor = self.labels_tensor.copy_(torch.from_numpy(self.labels)).long() return self.output_list, self.labels_tensor, torch.tensor(self.roi_array.reshape(self.batch_size,4)[...,2:4]) def reset(self): diff --git a/rocAL_pybind/amd/rocal/readers.py b/rocAL_pybind/amd/rocal/readers.py index 70e5a25f3..65f53b827 100644 --- a/rocAL_pybind/amd/rocal/readers.py +++ b/rocAL_pybind/amd/rocal/readers.py @@ -78,7 +78,7 @@ def file(file_root, file_filters=None, file_list='', stick_to_shard=False, pad_l Pipeline._current_pipeline._reader = "labelReader" # Output labels = [] - kwargs_pybind = {"source_path": file_root} + kwargs_pybind = {"source_path": file_root, "file_list": file_list} label_reader_meta_data = b.labelReader( Pipeline._current_pipeline._handle, *(kwargs_pybind.values())) return 
(label_reader_meta_data, labels) diff --git a/tests/cpp_api/audio_tests/audio_tests.cpp b/tests/cpp_api/audio_tests/audio_tests.cpp index d3fad18a0..01067b51b 100644 --- a/tests/cpp_api/audio_tests/audio_tests.cpp +++ b/tests/cpp_api/audio_tests/audio_tests.cpp @@ -132,12 +132,22 @@ int test(int test_case, const char *path, int qa_mode, int downmix, int gpu) { return -1; } - std::cout << "Running LABEL READER" << std::endl; - rocalCreateLabelReader(handle, path); + std::string file_list_path; // User can modify this with the file list path if required + if (qa_mode) { // setting the default file list path from ROCAL_DATA_PATH + const char *rocal_data_path = std::getenv("ROCAL_DATA_PATH"); /* std::getenv returns a null pointer when the variable is unset; building a std::string from it would be undefined */ + if (!rocal_data_path) { + std::cout << "ROCAL_DATA_PATH is not set, cannot locate the default file list" << std::endl; + return -1; + } + file_list_path = std::string(rocal_data_path) + "/rocal_data/audio/wav_file_list.txt"; /* explicit separator so the path is valid whether or not ROCAL_DATA_PATH ends with a slash */ + } + + std::cout << ">>>>>>> Running LABEL READER" << std::endl; + rocalCreateLabelReader(handle, path, file_list_path.c_str()); if (test_case == 0) is_output_audio_decoder = true; - RocalTensor decoded_output = rocalAudioFileSourceSingleShard(handle, path, 0, 1, is_output_audio_decoder, false, false, downmix); + RocalTensor decoded_output = rocalAudioFileSourceSingleShard(handle, path, file_list_path.c_str(), 0, 1, is_output_audio_decoder, false, false, downmix); if (rocalGetStatus(handle) != ROCAL_OK) { std::cout << "Audio source could not initialize : " << rocalGetErrorMessage(handle) << std::endl; return -1; @@ -187,7 +197,10 @@ int test(int test_case, const char *path, int qa_mode, int downmix, int gpu) { char audio_file_name[file_name_size]; std::vector roi(4 * input_batch_size, 0); rocalGetImageName(handle, audio_file_name); + RocalTensorList labels = rocalGetImageLabels(handle); + int *label_id = reinterpret_cast(labels->at(0)->buffer()); // The labels are present contiguously in memory std::cerr << "Audio file : " << audio_file_name << "\n"; + std::cerr << "Label : " << *label_id << "\n"; for (uint idx = 0; idx < output_tensor_list->size(); idx++) { buffer = static_cast(output_tensor_list->at(idx)->buffer()); 
output_tensor_list->at(idx)->copy_roi(roi.data()); diff --git a/tests/cpp_api/audio_tests/audio_tests.py b/tests/cpp_api/audio_tests/audio_tests.py index 64a0b3b87..bc27a0cc8 100644 --- a/tests/cpp_api/audio_tests/audio_tests.py +++ b/tests/cpp_api/audio_tests/audio_tests.py @@ -72,7 +72,7 @@ def main(): sys.exit() sys.dont_write_bytecode = True - input_file_path = rocal_data_path + "/rocal_data/audio/wav" + input_file_path = rocal_data_path + "/rocal_data/audio" build_folder_path = os.getcwd() args = audio_test_suite_parser_and_validator() diff --git a/tests/python_api/audio_unit_test.py b/tests/python_api/audio_unit_test.py index 56c8efd09..f90e8a7d0 100644 --- a/tests/python_api/audio_unit_test.py +++ b/tests/python_api/audio_unit_test.py @@ -71,22 +71,24 @@ def verify_output(audio_tensor, rocal_data_path, roi_tensor, test_results, case_ test_results[case_name] = "FAILED" @pipeline_def(seed=seed) -def audio_decoder_pipeline(path): - audio, labels = fn.readers.file(file_root=path) +def audio_decoder_pipeline(path, file_list): + audio, labels = fn.readers.file(file_root=path, file_list=file_list) return fn.decoders.audio( audio, file_root=path, + file_list_path=file_list, downmix=False, shard_id=0, num_shards=1, stick_to_shard=False) @pipeline_def(seed=seed) -def pre_emphasis_filter_pipeline(path): - audio, labels = fn.readers.file(file_root=path) +def pre_emphasis_filter_pipeline(path, file_list): + audio, labels = fn.readers.file(file_root=path, file_list=file_list) decoded_audio = fn.decoders.audio( audio, file_root=path, + file_list_path=file_list, downmix=False, shard_id=0, num_shards=1, @@ -97,6 +99,7 @@ def main(): args = parse_args() audio_path = args.audio_path + file_list = args.file_list_path rocal_cpu = False if args.rocal_gpu else True batch_size = args.batch_size test_case = args.test_case @@ -129,8 +132,9 @@ def main(): if not rocal_cpu: print("The GPU support for Audio is not given yet. 
Running on CPU") rocal_cpu = True - if audio_path == "": - audio_path = f'{rocal_data_path}/rocal_data/audio/wav/' + if audio_path == "" and file_list == "": + audio_path = f'{rocal_data_path}/rocal_data/audio/' + file_list = f'{rocal_data_path}/rocal_data/audio/wav_file_list.txt' else: print("QA mode is disabled for custom audio data") qa_mode = 0 @@ -143,9 +147,9 @@ def main(): for case in case_list: case_name = test_case_augmentation_map.get(case) if case_name == "audio_decoder": - audio_pipeline = audio_decoder_pipeline(batch_size=batch_size, num_threads=num_threads, device_id=device_id, rocal_cpu=rocal_cpu, path=audio_path) + audio_pipeline = audio_decoder_pipeline(batch_size=batch_size, num_threads=num_threads, device_id=device_id, rocal_cpu=rocal_cpu, path=audio_path, file_list=file_list) if case_name == "preemphasis_filter": - audio_pipeline = pre_emphasis_filter_pipeline(batch_size=batch_size, num_threads=num_threads, device_id=device_id, rocal_cpu=rocal_cpu, path=audio_path) + audio_pipeline = pre_emphasis_filter_pipeline(batch_size=batch_size, num_threads=num_threads, device_id=device_id, rocal_cpu=rocal_cpu, path=audio_path, file_list=file_list) audio_pipeline.build() audio_loader = ROCALAudioIterator(audio_pipeline, auto_reset=True) cnt = 0 diff --git a/tests/python_api/parse_config.py b/tests/python_api/parse_config.py index d2b41987b..b2784fbdd 100644 --- a/tests/python_api/parse_config.py +++ b/tests/python_api/parse_config.py @@ -111,6 +111,8 @@ def parse_args(): 'audio-python-unittest', 'audio-python-unittest-related options') audio_unit_test.add_argument('--audio_path', type=str, default="", help='audio files path') + audio_unit_test.add_argument('--file_list_path', type=str, default="", + help='file list path') audio_unit_test.add_argument('--test_case', type=int, default=None, help='test case') audio_unit_test.add_argument('--qa_mode', type=int, default=1,