Dekadal for merge changes into croptype branch #121

Draft
wants to merge 10 commits into base: croptype
4 changes: 3 additions & 1 deletion .gitignore
@@ -22,4 +22,6 @@ scrap
output/*
imgs/*
# don't track catboost training info
catboost_info
catboost_info
# ignore slurm files
slurm*
158 changes: 68 additions & 90 deletions presto/dataset.py
@@ -20,6 +20,7 @@
NDVI_INDEX,
NODATAVALUE,
NORMED_BANDS,
NUM_TIMESTEPS,
S1_S2_ERA5_SRTM,
S2_RGB_INDEX,
DynamicWorld2020_2021,
@@ -38,7 +39,6 @@

class WorldCerealBase(Dataset):
# _NODATAVALUE = 65535
NUM_TIMESTEPS = 12
BAND_MAPPING = {
"OPTICAL-B02-ts{}-10m": "B2",
"OPTICAL-B03-ts{}-10m": "B3",
@@ -57,26 +57,27 @@ class WorldCerealBase(Dataset):
}
STATIC_BAND_MAPPING = {"DEM-alt-20m": "elevation", "DEM-slo-20m": "slope"}

def __init__(self, dataframe: pd.DataFrame):
def __init__(self, dataframe: pd.DataFrame, num_timesteps: int = NUM_TIMESTEPS):
self.df = dataframe
self.num_timesteps = num_timesteps

def __len__(self):
return self.df.shape[0]

@classmethod
def get_timestep_positions(
cls, row_d: Dict, augment: bool = False, is_ssl: bool = False
self, row_d: Dict, augment: bool = False, is_ssl: bool = False
) -> List[int]:
available_timesteps = int(row_d["available_timesteps"])

if is_ssl:
if available_timesteps == cls.NUM_TIMESTEPS:
valid_position = int(cls.NUM_TIMESTEPS // 2)
if available_timesteps == self.num_timesteps:
valid_position = int(self.num_timesteps // 2)
else:
valid_position = int(
np.random.choice(
range(
cls.NUM_TIMESTEPS // 2, (available_timesteps - cls.NUM_TIMESTEPS // 2)
self.num_timesteps // 2,
(available_timesteps - self.num_timesteps // 2),
),
1,
)
@@ -86,11 +87,11 @@ def get_timestep_positions(
valid_position = int(row_d["valid_position"])
if not augment:
# check if the valid position is too close to the start_date and force shifting it
if valid_position < cls.NUM_TIMESTEPS // 2:
center_point = cls.NUM_TIMESTEPS // 2
if valid_position < self.num_timesteps // 2:
center_point = self.num_timesteps // 2
# or too close to the end_date
elif valid_position > (available_timesteps - cls.NUM_TIMESTEPS // 2):
center_point = available_timesteps - cls.NUM_TIMESTEPS // 2
elif valid_position > (available_timesteps - self.num_timesteps // 2):
center_point = available_timesteps - self.num_timesteps // 2
else:
# Center the timesteps around the valid position
center_point = valid_position
@@ -99,35 +100,34 @@
# well includes the valid position

min_center_point = max(
cls.NUM_TIMESTEPS // 2,
valid_position + MIN_EDGE_BUFFER - cls.NUM_TIMESTEPS // 2,
self.num_timesteps // 2,
valid_position + MIN_EDGE_BUFFER - self.num_timesteps // 2,
)
max_center_point = min(
available_timesteps - cls.NUM_TIMESTEPS // 2,
valid_position - MIN_EDGE_BUFFER + cls.NUM_TIMESTEPS // 2,
available_timesteps - self.num_timesteps // 2,
valid_position - MIN_EDGE_BUFFER + self.num_timesteps // 2,
)

center_point = np.random.randint(
min_center_point, max_center_point + 1
) # max_center_point included

last_timestep = min(available_timesteps, center_point + cls.NUM_TIMESTEPS // 2)
first_timestep = max(0, last_timestep - cls.NUM_TIMESTEPS)
last_timestep = min(available_timesteps, center_point + self.num_timesteps // 2)
first_timestep = max(0, last_timestep - self.num_timesteps)
timestep_positions = list(range(first_timestep, last_timestep))

if len(timestep_positions) != cls.NUM_TIMESTEPS:
if len(timestep_positions) != self.num_timesteps:
raise ValueError(
f"Acquired timestep positions do not have correct length: \
required {cls.NUM_TIMESTEPS}, got {len(timestep_positions)}"
required {self.num_timesteps}, got {len(timestep_positions)}"
)
assert (
valid_position in timestep_positions
), f"Valid position {valid_position} not in timestep positions {timestep_positions}"
return timestep_positions
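
To illustrate the window selection above, a minimal standalone sketch (sizes are made up, not part of this module) of how a num_timesteps-long window is clamped so it stays inside the available series while still containing the valid position:

    import numpy as np

    num_timesteps, available_timesteps, valid_position = 12, 20, 2
    # clamp the center so the window never runs past either end of the series
    center_point = min(max(valid_position, num_timesteps // 2),
                       available_timesteps - num_timesteps // 2)
    last_timestep = min(available_timesteps, center_point + num_timesteps // 2)
    first_timestep = max(0, last_timestep - num_timesteps)
    positions = list(range(first_timestep, last_timestep))
    assert len(positions) == num_timesteps and valid_position in positions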

@classmethod
def row_to_arrays(
cls,
self,
row: pd.Series,
task_type: str = "cropland",
croptype_list: List = [],
@@ -140,11 +140,11 @@ def row_to_arrays(

latlon = np.array([row_d["lat"], row_d["lon"]], dtype=np.float32)

timestep_positions = cls.get_timestep_positions(row_d, augment=augment, is_ssl=is_ssl)
timestep_positions = self.get_timestep_positions(row_d, augment=augment, is_ssl=is_ssl)

if cls.NUM_TIMESTEPS == 12:
if self.num_timesteps == 12:
initial_start_date_position = pd.to_datetime(row_d["start_date"]).month
elif cls.NUM_TIMESTEPS > 12:
elif self.num_timesteps > 12:
# get the correct index of the start_date based on NUM_TIMESTEPS
# e.g. if NUM_TIMESTEPS is 36 (dekadal setup), we should take the correct
# 10-day interval that the start_date falls into
@@ -153,18 +153,18 @@
# should also be changed accordingly
year = pd.to_datetime(row_d["start_date"]).year
year_dates = pd.date_range(start=f"{year}-01-01", end=f"{year}-12-31")
bins = pd.cut(year_dates, bins=cls.NUM_TIMESTEPS, labels=False)
bins = pd.cut(year_dates, bins=self.num_timesteps, labels=False)
initial_start_date_position = bins[
np.where(year_dates == pd.to_datetime(row_d["start_date"]))[0][0]
]
else:
raise ValueError(
f"NUM_TIMESTEPS must be at least 12. Currently it is {cls.NUM_TIMESTEPS}"
f"NUM_TIMESTEPS must be at least 12. Currently it is {self.num_timesteps}"
)

# make sure that month for encoding gets shifted according to
# the selected timestep positions. Also ensure circular indexing
month = (initial_start_date_position - 1 + timestep_positions[0]) % cls.NUM_TIMESTEPS
month = (initial_start_date_position - 1 + timestep_positions[0]) % self.num_timesteps
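
As a rough illustration of the dekadal indexing above, a small sketch (dates are made up) of how pd.cut maps a start_date onto one of 36 roughly 10-day bins of the year:

    import numpy as np
    import pandas as pd

    num_timesteps = 36
    start_date = pd.to_datetime("2021-02-14")
    year_dates = pd.date_range(start="2021-01-01", end="2021-12-31")
    bins = pd.cut(year_dates, bins=num_timesteps, labels=False)
    initial_position = bins[np.where(year_dates == start_date)[0][0]]
    # -> 4, i.e. mid-February falls into the fifth ~10-day interval of the year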

# adding workaround for compatibility between Phase I and Phase II datasets.
# (in Phase II, the relevant attribute name was changed to valid_time)
@@ -176,12 +176,12 @@
else:
logger.error("Dataset does not contain neither valid_date, nor valid_time attribute.")

eo_data = np.zeros((cls.NUM_TIMESTEPS, len(BANDS)))
eo_data = np.zeros((self.num_timesteps, len(BANDS)))
# an assumption we make here is that all timesteps for a token
# have the same masking
mask = np.zeros((cls.NUM_TIMESTEPS, len(BANDS_GROUPS_IDX)))
mask = np.zeros((self.num_timesteps, len(BANDS_GROUPS_IDX)))

for df_val, presto_val in cls.BAND_MAPPING.items():
for df_val, presto_val in self.BAND_MAPPING.items():
values = np.array([float(row_d[df_val.format(t)]) for t in timestep_positions])
# this occurs for the DEM values in one point in Fiji
values = np.nan_to_num(values, nan=NODATAVALUE)
@@ -199,7 +199,7 @@
mask[:, IDX_TO_BAND_GROUPS[presto_val]] += ~idx_valid
eo_data[:, BANDS.index(presto_val)] = values * idx_valid

for df_val, presto_val in cls.STATIC_BAND_MAPPING.items():
for df_val, presto_val in self.STATIC_BAND_MAPPING.items():
# this occurs for the DEM values in one point in Fiji
values = np.nan_to_num(row_d[df_val], nan=NODATAVALUE)
idx_valid = values != NODATAVALUE
@@ -210,7 +210,7 @@
# or nir mask, and adjust the NDVI mask accordingly
mask[:, NDVI_INDEX] = np.logical_or(mask[:, S2_RGB_INDEX], mask[:, S2_NIR_10m_INDEX])

return (cls.check(eo_data), mask.astype(bool), latlon, month, valid_month)
return (self.check(eo_data), mask.astype(bool), latlon, month, valid_month)

def __getitem__(self, idx):
# Get the sample
@@ -220,13 +220,28 @@ def __getitem__(self, idx):
mask_per_variable = np.repeat(mask_per_token, BAND_EXPANSION, axis=1)
return (
self.normalize_and_mask(eo),
np.ones(self.NUM_TIMESTEPS) * (DynamicWorld2020_2021.class_amount),
np.ones(self.num_timesteps) * (DynamicWorld2020_2021.class_amount),
latlon,
month,
valid_month,
mask_per_variable,
)

def get_month_array(self, row: pd.Series) -> np.ndarray:
start_date, end_date = datetime.strptime(row.start_date, "%Y-%m-%d"), datetime.strptime(
row.end_date, "%Y-%m-%d"
)

# Calculate the step size for 10-day intervals and create a list of dates
step = int((end_date - start_date).days / (self.num_timesteps - 1))
date_vector = [start_date + timedelta(days=i * step) for i in range(self.num_timesteps)]

# Ensure last date is not beyond the end date
if date_vector[-1] > end_date:
date_vector[-1] = end_date

return np.array([d.month - 1 for d in date_vector])
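
For a sense of the output, a short sketch of the computation in get_month_array for a hypothetical one-year dekadal window (num_timesteps assumed to be 36):

    from datetime import datetime, timedelta
    import numpy as np

    num_timesteps = 36
    start_date, end_date = datetime(2021, 1, 1), datetime(2021, 12, 31)
    step = int((end_date - start_date).days / (num_timesteps - 1))  # ~10 days
    dates = [start_date + timedelta(days=i * step) for i in range(num_timesteps)]
    months = np.array([d.month - 1 for d in dates])
    # -> array([0, 0, 0, 0, 1, 1, 2, ...]): zero-indexed month per ~10-day step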

@classmethod
def normalize_and_mask(cls, eo: np.ndarray):
# TODO: this can be removed
@@ -348,19 +363,24 @@ def split_df(


class WorldCerealMaskedDataset(WorldCerealBase):

def __init__(
self,
dataframe: pd.DataFrame,
num_timesteps: int,
mask_params: MaskParamsNoDw,
task_type: str = "cropland",
croptype_list: List = [],
is_ssl: bool = True,
is_dekadal: bool = False,
):
super().__init__(dataframe)
super().__init__(dataframe, num_timesteps)
self.num_timesteps = num_timesteps
self.mask_params = mask_params
self.task_type = task_type
self.croptype_list = croptype_list
self.is_ssl = is_ssl
self.is_dekadal = is_dekadal
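
For context, a hedged usage sketch of constructing this dataset with the new dekadal parameters; the argument values here are illustrative assumptions, not taken from this PR:

    # hypothetical construction; df and mask_params are provided by the caller
    ds = WorldCerealMaskedDataset(
        dataframe=df,
        num_timesteps=36,   # dekadal setup: 36 ~10-day steps per year
        mask_params=mask_params,
        task_type="cropland",
        is_ssl=True,
        is_dekadal=True,    # use get_month_array instead of a single start month
    )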

def __getitem__(self, idx):
# Get the sample
@@ -370,21 +390,23 @@ def __getitem__(self, idx):
)

mask_eo, x_eo, y_eo, strat = self.mask_params.mask_data(
self.normalize_and_mask(eo), real_mask_per_token
self.normalize_and_mask(eo),
real_mask_per_token,
)
real_mask_per_variable = np.repeat(real_mask_per_token, BAND_EXPANSION, axis=1)

dynamic_world = np.ones(self.NUM_TIMESTEPS) * (DynamicWorld2020_2021.class_amount)
mask_dw = np.full(self.NUM_TIMESTEPS, True)
dynamic_world = np.ones(self.num_timesteps) * (DynamicWorld2020_2021.class_amount)
mask_dw = np.full(self.num_timesteps, True)
y_dw = dynamic_world.copy()

return MaskedExample(
mask_eo,
mask_dw,
x_eo,
y_eo,
dynamic_world,
y_dw,
month,
self.get_month_array(row) if self.is_dekadal else month,
latlon,
strat,
real_mask_per_variable,
@@ -425,13 +447,15 @@ class WorldCerealLabelledDataset(WorldCerealBase):
def __init__(
self,
dataframe: pd.DataFrame,
num_timesteps: int = NUM_TIMESTEPS,
countries_to_remove: Optional[List[str]] = None,
years_to_remove: Optional[List[int]] = None,
balance: bool = False,
task_type: str = "cropland",
croptype_list: List = [],
return_hierarchical_labels: bool = False,
augment: bool = False,
is_dekadal: bool = False,
mask_ratio: float = 0.0,
):
dataframe = dataframe.loc[~dataframe.LANDCOVER_LABEL.isin(self.FILTER_LABELS)]
@@ -452,6 +476,7 @@ def __init__(
self.croptype_list = croptype_list
self.return_hierarchical_labels = return_hierarchical_labels
self.augment = augment
self.is_dekadal = is_dekadal
if augment:
logger.info(
"Augmentation is enabled. \
@@ -468,7 +493,7 @@ def __init__(
mask_ratio,
)

super().__init__(dataframe)
super().__init__(dataframe, num_timesteps)
if balance:
if self.task_type == "cropland":
logger.info("Balancing is enabled. Underrepresented class will be upsampled.")
@@ -582,9 +607,9 @@ def __getitem__(self, idx):
return (
normed_eo,
target,
np.ones(self.NUM_TIMESTEPS) * (DynamicWorld2020_2021.class_amount),
np.ones(self.num_timesteps) * (DynamicWorld2020_2021.class_amount),
latlon,
month,
self.get_month_array(row) if self.is_dekadal else month,
valid_month,
mask_per_variable,
)
@@ -608,53 +633,8 @@ def class_weights(self) -> np.ndarray:
return self._class_weights


class WorldCerealLabelled10DDataset(WorldCerealLabelledDataset):

NUM_TIMESTEPS = 36

@classmethod
def get_month_array(cls, row: pd.Series) -> np.ndarray:
start_date, end_date = datetime.strptime(row.start_date, "%Y-%m-%d"), datetime.strptime(
row.end_date, "%Y-%m-%d"
)

# Calculate the step size for 10-day intervals and create a list of dates
step = int((end_date - start_date).days / (cls.NUM_TIMESTEPS - 1))
date_vector = [start_date + timedelta(days=i * step) for i in range(cls.NUM_TIMESTEPS)]

# Ensure last date is not beyond the end date
if date_vector[-1] > end_date:
date_vector[-1] = end_date

return np.array([d.month - 1 for d in date_vector])

def __getitem__(self, idx):
# Get the sample
df_index = self.indices[idx]
row = self.df.iloc[df_index, :]
eo, mask_per_token, latlon, _, valid_month = self.row_to_arrays(
row, self.task_type, self.croptype_list, self.augment
)
target = self.target_crop(
row, self.task_type, self.croptype_list, self.return_hierarchical_labels
)
if self.mask_ratio > 0:
mask_per_token, eo, _, _ = self.mask_params.mask_data(eo, mask_per_token)
mask_per_variable = np.repeat(mask_per_token, BAND_EXPANSION, axis=1)
return (
self.normalize_and_mask(eo),
target,
np.ones(self.NUM_TIMESTEPS) * (DynamicWorld2020_2021.class_amount),
latlon,
self.get_month_array(row),
valid_month,
mask_per_variable,
)


class WorldCerealInferenceDataset(Dataset):
# _NODATAVALUE = 65535
# Y = "worldcereal_cropland"

BAND_MAPPING = {
"B02": "B2",
"B03": "B3",
@@ -818,9 +798,7 @@ def _subset_array_temporally(inarr: xr.DataArray) -> xr.DataArray:
return inarr

@classmethod
def nc_to_arrays(
cls, filepath: Path
) -> Tuple[
def nc_to_arrays(cls, filepath: Path) -> Tuple[
np.ndarray,
np.ndarray,
np.ndarray,