forked from n-waves/multifit
-
Notifications
You must be signed in to change notification settings - Fork 0
/
prepare_xnli.py
73 lines (54 loc) · 2.17 KB
/
prepare_xnli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import zipfile
from pathlib import Path
from typing import Optional, Union
import fire
from tqdm import tqdm
from fastai.core import *
from fastai.datasets import *
# All downloaded data lives under ./data, resolved against the working
# directory at import time.
ROOT = Path("data").resolve()
XNLI_DIR = ROOT / "xnli"

# One idempotent call creates both ROOT and XNLI_DIR; this avoids the
# check-then-create race of testing `ROOT.exists()` before `mkdir()`.
XNLI_DIR.mkdir(parents=True, exist_ok=True)
print(f"Saving data in {ROOT}")

# Archive file names and the local paths they are stored at.
MT_FILE = "XNLI-MT-1.0.zip"
XNLI_FILE = "XNLI-1.0.zip"
MT_PATH = XNLI_DIR / MT_FILE
XNLI_PATH = XNLI_DIR / XNLI_FILE

# Remote locations of the two archives.
MT_URL = "https://s3.amazonaws.com/xnli/XNLI-MT-1.0.zip"
XNLI_URL = "https://s3.amazonaws.com/xnli/XNLI-1.0.zip"
class TqdmUpTo(tqdm):
    """Progress bar exposing `update_to`, a hook taking (b, bsize, tsize)."""

    def update_to(self, b=1, bsize=1, tsize=None):
        """
        Move the bar to position `b * bsize`.

        When `tsize` is given it replaces the bar's total before the
        position is updated.
        """
        if tsize is not None:
            self.total = tsize
        position = b * bsize
        # `update` expects an increment, so subtract what is already shown.
        self.update(position - self.n)
def download_data(url: str, fname: Union[str, Path], dest: Optional[Union[str, Path]] = None):
    """
    Download `url` to `dest / fname` unless that file already exists.

    Progress is shown with a `TqdmUpTo` bar labelled with the last path
    segment of `url`.

    Args:
        url: Address of the file to fetch.
        fname: Target file name, joined onto `dest` (an absolute `fname`
            therefore wins over `dest`).
        dest: Directory to store the file in. `None` means the current
            working directory.

    Returns:
        The absolute path of the (possibly pre-existing) file, as a string.
    """
    from urllib.request import urlretrieve

    base_dir = Path.cwd() if dest is None else Path(dest)
    filepath = (base_dir / fname).resolve()
    if not filepath.exists():
        dirname = filepath.parent
        print(f"Creating directory {dirname} from {filepath}")
        # parents=True so a nested destination can be created in one go.
        dirname.mkdir(parents=True, exist_ok=True)
        # Download to a sibling ".part" file and rename only on success, so an
        # interrupted transfer is never mistaken for a finished one by the
        # `exists()` check on a later run.
        partial = filepath.with_name(filepath.name + ".part")
        try:
            with TqdmUpTo(unit="B", unit_scale=True, miniters=1, desc=url.split("/")[-1]) as t:
                urlretrieve(url, partial, reporthook=t.update_to)
            partial.replace(filepath)
        finally:
            # Still present only if the download or rename failed.
            if partial.exists():
                partial.unlink()
    return str(filepath.resolve().absolute())
def get_and_unzip_data(url: str, fname: Union[str, Path] = None, dest: Union[str, Path] = None):
    """
    Download the zip archive at `url` into `dest` if missing, then extract it there.

    The archive is extracted on every call, whether or not it had to be
    downloaded first.

    Args:
        url: Address of the zip archive.
        fname: Archive file name, taken relative to `dest`. Defaults to the
            last path segment of `url`.
        dest: Directory that holds the archive and receives its contents.
            Defaults to a directory named after the last path segment of `url`.

    Returns:
        `dest` as a `Path`.
    """
    url_name = url.split("/")[-1]
    dest = Path(url_name if dest is None else dest)
    archive = dest / (url_name if fname is None else fname)
    if not archive.exists():
        archive.parent.mkdir(parents=True, exist_ok=True)
        # Pass the directory and the bare file name separately: handing over the
        # already-joined path together with a relative `dest` would apply `dest`
        # twice and store the archive in the wrong place.
        download_data(url=url, fname=archive.name, dest=archive.parent)
    print(f"Extracting {archive.resolve().absolute()} \n to {dest}")
    # Context manager so the archive handle is closed even if extraction fails.
    with zipfile.ZipFile(archive, "r") as zf:
        zf.extractall(dest)
    return dest
def get_xnli_and_MT(dest: Union[str, Path] = XNLI_DIR):
    """Fetch and extract the XNLI archive, then the XNLI-MT archive, into `dest`."""
    archives = (
        (XNLI_URL, XNLI_FILE),
        (MT_URL, MT_FILE),
    )
    for archive_url, archive_name in archives:
        get_and_unzip_data(url=archive_url, fname=archive_name, dest=dest)
# Script entry point: hand `get_xnli_and_MT` to python-fire so it can be
# invoked from the command line.
if __name__ == "__main__":
    fire.Fire(get_xnli_and_MT)