Enhancements for wild time benchmarks (#305)
This PR introduces three enhancements for the WildTime benchmarks:

1. So far, all datasets except fMoW save data using only the year as the timestamp (finer-grained information is unavailable). Since years can be pre-1970, they are mapped to days counted from 1/1/1970 (e.g., in Yearbook, 1930 -> 1/1/1970 and 1931 -> 2/1/1970). fMoW, on the other hand, has exact timestamps available, and before this PR they were used when saving files. To ensure consistency, fMoW now also considers only the year; to keep the full timestamps, use the `--daily` option. A sketch of this year-to-days mapping follows this list.
2. In addition to the training split, WildTime provides validation and test splits. Since Modyn does not have these concepts, all splits can be stored for training by specifying the `--all` option.
3. In Modyn, training is initiated by a sample that satisfies the triggering condition. Consequently, samples from the last year are never used, because no sample from the following year arrives to trigger training on them. Specifying the `--dummyyear` parameter adds a sample to the year following the last one, enabling training on the last year's data. For example, arXiv has data up to 2022; since there is no 2023 data point, training on the 2022 data never happens. With this addition, a dummy 2023 sample is added so that the 2022 data can be used.
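For illustration, here is a minimal sketch of the year-to-days mapping described in point 1. The helper name `create_timestamp` matches the one called in the diffs below, but this body is an assumption, not the repository's actual implementation:

from datetime import datetime, timedelta, timezone

def create_timestamp(year: int, month: int = 1, day: int = 1) -> int:
    # The day offset may exceed the month's length for late years, so add a
    # timedelta to the first of the month instead of building the date directly.
    base = datetime(year=year, month=month, day=1, tzinfo=timezone.utc)
    return int((base + timedelta(days=day - 1)).timestamp())

# Mirrors the Yearbook example above: 1930 -> 1/1/1970, 1931 -> 2/1/1970.
assert create_timestamp(year=1970, month=1, day=1) == 0
assert create_timestamp(year=1970, month=1, day=2) == 86_400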
francescodeaglio authored Oct 11, 2023
1 parent cc139e0 commit af063be
Showing 6 changed files with 118 additions and 47 deletions.
13 changes: 13 additions & 0 deletions benchmark/wildtime_benchmarks/benchmark_utils.py
@@ -31,6 +31,19 @@ def setup_argparser_wildtime(dataset: str) -> argparse.ArgumentParser:
         "--dir", type=pathlib.Path, action="store", help="Path to data directory"
     )
 
+    parser_.add_argument(
+        "--all", action="store_true", help="Store all the available data, including the validation and test sets."
+    )
+    parser_.add_argument(
+        "--dummyyear", action="store_true", help="Add a final dummy year to train also on the last trigger in Modyn"
+    )
+
+    if dataset == "fMoW":
+        parser_.add_argument(
+            "--daily", action="store_true", help="If specified, data is stored with real timestamps (dd/mm/yy)."
+                                                 "Otherwise, only the year is considered (as done in the other "
+                                                 "datasets).")
+
     return parser_


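As a usage sketch of the new flags (assumes benchmark_utils is importable from the working directory; the dataset names and paths here are illustrative):

from benchmark_utils import setup_argparser_wildtime

# fMoW registers all three new flags.
args = setup_argparser_wildtime("fMoW").parse_args(
    ["--dir", "/tmp/fmow", "--all", "--dummyyear", "--daily"]
)
print(args.all, args.dummyyear, args.daily)  # True True True

# Other datasets only get --all and --dummyyear; --daily is not registered.
args = setup_argparser_wildtime("arxiv").parse_args(["--dir", "/tmp/arxiv"])
print(args.all, args.dummyyear)  # False False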
27 changes: 20 additions & 7 deletions benchmark/wildtime_benchmarks/data_generation_arxiv.py
@@ -14,7 +14,7 @@ def main():
     args = parser.parse_args()
 
     logger.info(f"Downloading data to {args.dir}")
-    ArXivDownloader(args.dir).store_data()
+    ArXivDownloader(args.dir).store_data(args.all, args.dummyyear)


class ArXivDownloader(Dataset):
@@ -43,16 +43,19 @@ def __init__(self, data_dir):
         self._dataset = datasets
         self.path = data_dir
 
-    def store_data(self):
+    def store_data(self, store_all_data: bool, add_final_dummy_year: bool):
         for year in tqdm(self._dataset):
             # for simplicity, instead of using years we map each day to a year from 1970
             year_timestamp = create_timestamp(year=1970, month=1, day=year-2006)
             year_rows = []
-            for i in range(len(self._dataset[year][0]["title"])):
-                text = self._dataset[year][0]["title"][i].replace("\n", " ")
-                label = self._dataset[year][0]["category"][i]
-                csv_row = f"{text}\t{label}"
-                year_rows.append(csv_row)
+
+            splits = [0, 1] if store_all_data else [0]
+            for split in splits:
+                for i in range(len(self._dataset[year][split]["title"])):
+                    text = self._dataset[year][split]["title"][i].replace("\n", " ")
+                    label = self._dataset[year][split]["category"][i]
+                    csv_row = f"{text}\t{label}"
+                    year_rows.append(csv_row)
 
             # store the year file
             text_file = os.path.join(self.path, f"{year}.csv")
@@ -62,6 +65,16 @@ def store_data(self):
             # set timestamp
             os.utime(text_file, (year_timestamp, year_timestamp))
 
+        if add_final_dummy_year:
+            dummy_year = year + 1
+            year_timestamp = create_timestamp(year=1970, month=1, day=dummy_year - 2006)
+            text_file = os.path.join(self.path, f"{dummy_year}.csv")
+            with open(text_file, "w", encoding="utf-8") as f:
+                f.write("\n".join(["dummy\t0"]))
+
+            # set timestamp
+            os.utime(text_file, (year_timestamp, year_timestamp))
+
         os.remove(os.path.join(self.path, "arxiv.pkl"))


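Since the logical year is encoded purely in the file modification time set via os.utime, a quick sanity check (hypothetical output path) is to read the mtime back:

import os
from datetime import datetime, timezone

path = "/tmp/arxiv/2007.csv"  # hypothetical output file of ArXivDownloader
mtime = os.stat(path).st_mtime
print(datetime.fromtimestamp(mtime, tz=timezone.utc))
# 2007 is mapped to day 2007 - 2006 = 1, i.e. 1970-01-01 00:00:00+00:00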
72 changes: 47 additions & 25 deletions benchmark/wildtime_benchmarks/data_generation_fmow.py
@@ -4,7 +4,7 @@
 import shutil
 from datetime import datetime
 
-from benchmark_utils import download_if_not_exists, setup_argparser_wildtime, setup_logger
+from benchmark_utils import create_timestamp, download_if_not_exists, setup_argparser_wildtime, setup_logger
 from torch.utils.data import Dataset
 from tqdm import tqdm
 from wilds import get_dataset
@@ -13,13 +13,13 @@
 
 
 def main() -> None:
-    parser = setup_argparser_wildtime("FMoW")
+    parser = setup_argparser_wildtime("fMoW")
     args = parser.parse_args()
 
     logger.info(f"Downloading data to {args.dir}")
 
     downloader = FMOWDownloader(args.dir)
-    downloader.store_data()
+    downloader.store_data(args.daily, args.all, args.dummyyear)
     downloader.clean_folder()


@@ -59,30 +59,52 @@ def move_file_and_rename(self, index: int) -> None:
         new_name = os.path.join(self.data_dir, f"{index}.png")
         os.rename(dest_file, new_name)
 
-    def store_data(self) -> None:
+    def store_data(self, store_daily: bool, store_all_data: bool, add_final_dummy_year: bool) -> None:
 
         for year in tqdm(self._dataset):
-            split = 0  # just use training split for now
-            for i in range(len(self._dataset[year][split]["image_idxs"])):
-                index = self._dataset[year][split]["image_idxs"][i]
-                label = self._dataset[year][split]["labels"][i]
-                raw_timestamp = self.metadata[index]["timestamp"]
-
-                if len(raw_timestamp) == 24:
-                    timestamp = datetime.strptime(raw_timestamp, '%Y-%m-%dT%H:%M:%S.%fZ')
-                else:
-                    timestamp = datetime.strptime(raw_timestamp, '%Y-%m-%dT%H:%M:%SZ')
-
-                # save label
-                label_file = os.path.join(self.data_dir, f"{index}.label")
-                with open(label_file, "w", encoding="utf-8") as f:
-                    f.write(str(int(label)))
-                os.utime(label_file, (timestamp.timestamp(), timestamp.timestamp()))
-
-                # set image timestamp
-                self.move_file_and_rename(index)
-                image_file = os.path.join(self.data_dir, f"{index}.png")
-                os.utime(image_file, (timestamp.timestamp(), timestamp.timestamp()))
+            splits = [0, 1] if store_all_data else [0]
+            for split in splits:
+                for i in range(len(self._dataset[year][split]["image_idxs"])):
+                    index = self._dataset[year][split]["image_idxs"][i]
+                    label = self._dataset[year][split]["labels"][i]
+
+                    if store_daily:
+                        raw_timestamp = self.metadata[index]["timestamp"]
+
+                        if len(raw_timestamp) == 24:
+                            timestamp = datetime.strptime(raw_timestamp, '%Y-%m-%dT%H:%M:%S.%fZ').timestamp()
+                        else:
+                            timestamp = datetime.strptime(raw_timestamp, '%Y-%m-%dT%H:%M:%SZ').timestamp()
+                    else:
+                        timestamp = create_timestamp(year=1970, month=1, day=year+1)
+
+                    # save label
+                    label_file = os.path.join(self.data_dir, f"{index}.label")
+                    with open(label_file, "w", encoding="utf-8") as f:
+                        f.write(str(int(label)))
+                    os.utime(label_file, (timestamp, timestamp))
+
+                    # set image timestamp
+                    self.move_file_and_rename(index)
+                    image_file = os.path.join(self.data_dir, f"{index}.png")
+                    os.utime(image_file, (timestamp, timestamp))
+
+        if add_final_dummy_year:
+            dummy_year = year + 1
+            timestamp = create_timestamp(year=1970, month=1, day=dummy_year+1)
+            dummy_index = 1000000  # not used by any real sample (last: 99999)
+
+            to_copy_image_file = os.path.join(self.data_dir, f"{index}.png")
+            dummy_image_file = os.path.join(self.data_dir, f"{dummy_index}.png")
+            shutil.copy(to_copy_image_file, dummy_image_file)
+            os.utime(dummy_image_file, (timestamp, timestamp))
+
+            to_copy_label_file = os.path.join(self.data_dir, f"{index}.label")
+            dummy_label_file = os.path.join(self.data_dir, f"{dummy_index}.label")
+            shutil.copy(to_copy_label_file, dummy_label_file)
+            os.utime(dummy_label_file, (timestamp, timestamp))



@staticmethod
def parse_metadata(data_dir: str) -> list:
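The length check in the diff above exists because fMoW metadata mixes two 'Z'-suffixed UTC formats, with and without fractional seconds. A small illustration with made-up timestamp values:

from datetime import datetime

def parse_fmow_timestamp(raw: str) -> float:
    # Mirrors the branch above: 24 characters means fractional seconds present.
    if len(raw) == 24:
        return datetime.strptime(raw, '%Y-%m-%dT%H:%M:%S.%fZ').timestamp()
    return datetime.strptime(raw, '%Y-%m-%dT%H:%M:%SZ').timestamp()

print(parse_fmow_timestamp("2015-03-27T10:42:11.000Z"))  # made-up value, 24 chars
print(parse_fmow_timestamp("2015-03-27T10:42:11Z"))      # made-up value, 20 chars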
27 changes: 20 additions & 7 deletions benchmark/wildtime_benchmarks/data_generation_huffpost.py
@@ -15,7 +15,7 @@ def main():
     args = parser.parse_args()
 
     logger.info(f"Downloading data to {args.dir}")
-    HuffpostDownloader(args.dir).store_data()
+    HuffpostDownloader(args.dir).store_data(args.all, args.dummyyear)


class HuffpostDownloader(Dataset):
@@ -44,15 +44,17 @@ def __init__(self, data_dir: str):
         self._dataset = datasets
         self.path = data_dir
 
-    def store_data(self) -> None:
+    def store_data(self, store_all_data: bool, add_final_dummy_year: bool) -> None:
         for year in tqdm(self._dataset):
             year_timestamp = create_timestamp(year=1970, month=1, day=year-2011)
             year_rows = []
-            for i in range(len(self._dataset[year][0]["headline"])):
-                text = self._dataset[year][0]["headline"][i]
-                label = self._dataset[year][0]["category"][i]
-                csv_row = f"{text}\t{label}"
-                year_rows.append(csv_row)
+            splits = [0, 1] if store_all_data else [0]
+            for split in splits:
+                for i in range(len(self._dataset[year][split]["headline"])):
+                    text = self._dataset[year][split]["headline"][i]
+                    label = self._dataset[year][split]["category"][i]
+                    csv_row = f"{text}\t{label}"
+                    year_rows.append(csv_row)
 
             # store the sentences
             text_file = os.path.join(self.path, f"{year}.csv")
@@ -62,6 +64,17 @@ def store_data(self) -> None:
             # set timestamp
             os.utime(text_file, (year_timestamp, year_timestamp))
 
+        if add_final_dummy_year:
+            dummy_year = year + 1
+            year_timestamp = create_timestamp(year=1970, month=1, day=dummy_year - 2011)
+            text_file = os.path.join(self.path, f"{dummy_year}.csv")
+            with open(text_file, "w", encoding="utf-8") as f:
+                f.write("\n".join(["dummy\t0"]))
+
+            # set timestamp
+            os.utime(text_file, (year_timestamp, year_timestamp))
+
+
         os.remove(os.path.join(self.path, "huffpost.pkl"))


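The generated year files are plain tab-separated text, so they can be read back with the standard library alone (hypothetical path; assumes headlines contain no tab characters):

with open("/tmp/huffpost/2012.csv", encoding="utf-8") as f:
    for line in f:
        headline, label = line.rstrip("\n").rsplit("\t", 1)
        print(int(label), headline)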
25 changes: 17 additions & 8 deletions benchmark/wildtime_benchmarks/data_generation_yearbook.py
@@ -17,7 +17,7 @@ def main():
     logger.info(f"Downloading data to {args.dir}")
 
     downloader = YearbookDownloader(args.dir)
-    downloader.store_data()
+    downloader.store_data(args.all, args.dummyyear)


class YearbookDownloader(Dataset):
@@ -38,35 +38,44 @@ def __init__(self, data_dir: str):
         self._dataset = datasets
         self.data_dir = data_dir
 
-    def _get_year_data(self, year: int) -> list[Tuple]:
+    def _get_year_data(self, year: int, store_all_data: bool) -> list[Tuple]:
+        splits = [0, 1] if store_all_data else [0]
         images = torch.FloatTensor(
             np.array(
                 [  # transpose to transform from HWC to CHW (H=height, W=width, C=channels).
                     # Pytorch requires CHW format
-                    img.transpose(2, 0, 1)[0].reshape(*self.input_dim)
-                    # _dataset has 3 dimensions [years][train=0,valid=1]["images"/"labels"]
-                    for img in self._dataset[year][0]["images"]
+                    img.transpose(2, 0, 1)[split].reshape(*self.input_dim)
+                    # _dataset has 3 dimensions [years][train=0,valid=1,test=2]["images"/"labels"]
+                    for split in splits  # just train if --all not specified, else test, train and val
+                    for img in self._dataset[year][split]["images"]
                 ]
             )
         )
-        labels = torch.LongTensor(self._dataset[year][0]["labels"])
+        labels = torch.cat([torch.LongTensor(self._dataset[year][split]["labels"]) for split in splits])
         return [(images[i], labels[i]) for i in range(len(images))]
 
     def __len__(self) -> int:
         return len(self._dataset["labels"])
 
-    def store_data(self) -> None:
+    def store_data(self, store_all_data: bool, add_final_dummy_year: bool) -> None:
         # create directories
         if not os.path.exists(self.data_dir):
             os.mkdir(self.data_dir)
 
         for year in self.time_steps:
             print(f"Saving data for year {year}")
-            ds = self._get_year_data(year)
+            ds = self._get_year_data(year, store_all_data)
             self.create_binary_file(ds,
                                     os.path.join(self.data_dir, f"{year}.bin"),
                                     create_fake_timestamp(year, base_year=1930))
 
+        if add_final_dummy_year:
+            dummy_year = year + 1
+            dummy_data = [ds[0]]  # get one sample from the previous year
+            self.create_binary_file(dummy_data,
+                                    os.path.join(self.data_dir, f"{dummy_year}.bin"),
+                                    create_fake_timestamp(dummy_year, base_year=1930))
+
         os.remove(os.path.join(self.data_dir, "yearbook.pkl"))

@staticmethod
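A shape walk-through of the HWC-to-CHW transpose in _get_year_data, using an illustrative 32x32 three-channel array (the actual Yearbook image size and input_dim value are assumptions here):

import numpy as np

img = np.zeros((32, 32, 3))       # HWC: height, width, channels
chw = img.transpose(2, 0, 1)      # CHW: (3, 32, 32), as PyTorch expects
one = chw[0].reshape(1, 32, 32)   # single channel, e.g. input_dim = (1, 32, 32)
print(one.shape)                  # (1, 32, 32)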
1 change: 1 addition & 0 deletions modyn/models/articlenet/articlenet.py
@@ -1,3 +1,4 @@
+# pylint: disable=W0223
 from typing import Any
 
 import torch
