diff --git a/create_data.py b/create_data.py index 17f2b58..f6fcac3 100644 --- a/create_data.py +++ b/create_data.py @@ -310,23 +310,35 @@ def loadData(args): data_url = os.path.join(args.main_dir, "data.json") if args.mwz_ver == '2.1': dataset_url = "https://www.repository.cam.ac.uk/bitstream/handle/1810/294507/MULTIWOZ2.1.zip?sequence=1&isAllowed=y" + dir_name = 'MULTIWOZ2.1' + + elif args.mwz_ver == '2.2': + dataset_url = "https://github.com/JJinIT/som-dst/raw/feature/MultiWOZ_2x/MultiWOZ_data/MULTIWOZ2.2.zip" + dir_name = 'MULTIWOZ2.2' + + elif args.mwz_ver == '2.3': + dataset_url = "https://github.com/JJinIT/som-dst/raw/feature/MultiWOZ_2x/MultiWOZ_data/MULTIWOZ2.3.zip" + dir_name = 'MULTIWOZ2.3' + else: dataset_url = "https://www.repository.cam.ac.uk/bitstream/handle/1810/280608/MULTIWOZ2.zip?sequence=3&isAllowed=y" + dir_name = 'MULTIWOZ2 2' + if not os.path.exists(args.main_dir): os.makedirs(args.main_dir) - if not os.path.exists(data_url): + if not os.path.exists(os.path.join(args.main_dir, dir_name)): print("Downloading and unzipping the MultiWOZ %s dataset" % args.mwz_ver) resp = urllib.request.urlopen(dataset_url) zip_ref = ZipFile(BytesIO(resp.read())) zip_ref.extractall(args.main_dir) zip_ref.close() - dir_name = 'MULTIWOZ2.1' if args.mwz_ver == '2.1' else 'MULTIWOZ2 2' - shutil.copy(os.path.join(args.main_dir, dir_name, 'data.json'), args.main_dir) - shutil.copy(os.path.join(args.main_dir, dir_name, 'ontology.json'), args.main_dir) - shutil.copy(os.path.join(args.main_dir, dir_name, 'valListFile.json'), args.main_dir) - shutil.copy(os.path.join(args.main_dir, dir_name, 'testListFile.json'), args.main_dir) - shutil.copy(os.path.join(args.main_dir, dir_name, 'dialogue_acts.json'), args.main_dir) + + shutil.copy(os.path.join(args.main_dir, dir_name, 'data.json'), args.main_dir) + shutil.copy(os.path.join(args.main_dir, dir_name, 'ontology.json'), args.main_dir) + shutil.copy(os.path.join(args.main_dir, dir_name, 'valListFile.json'), args.main_dir) + shutil.copy(os.path.join(args.main_dir, dir_name, 'testListFile.json'), args.main_dir) + shutil.copy(os.path.join(args.main_dir, dir_name, 'dialogue_acts.json'), args.main_dir) def getDomain(idx, log, domains, last_domain):