Merge pull request #12 from gizatechxyz/feature/cache_management
Feature/cache management
alejandromartinezgotor authored Mar 21, 2024
2 parents 2ec1656 + b5ec7a3 commit 1daa45b
Showing 3 changed files with 182 additions and 31 deletions.
84 changes: 84 additions & 0 deletions giza_datasets/cache_manager.py
@@ -0,0 +1,84 @@
import polars as pl
import pyarrow.dataset as ds
import os


class CacheManager:
def __init__(self, cache_dir):
"""
Initializes the CacheManager with a specified cache directory.
Args:
cache_dir (str): The directory path where cached datasets will be stored.
"""
self.cache_dir = cache_dir
if not os.path.exists(cache_dir):
os.makedirs(cache_dir)

def set_cache_dir(self, cache_dir):
"""
Sets a new cache directory and creates the directory if it does not exist.
Args:
cache_dir (str): The new directory path for caching datasets.
"""
self.cache_dir = cache_dir
if not os.path.exists(cache_dir):
os.makedirs(cache_dir)

def get_cache_path(self, dataset_name):
"""
Determines the full file path in the cache based on the dataset name.
Args:
dataset_name (str): Name of the dataset to identify the cached file.
Returns:
str: The file path for the cached dataset.
"""
return os.path.join(self.cache_dir, f"{dataset_name}")

def load_from_cache(self, dataset_name, eager):
"""
Attempts to load a Polars DataFrame from a cached Parquet file.
If the file exists, it returns the DataFrame; otherwise, it returns None.
Args:
dataset_name (str): Name of the dataset to identify the cached file.
eager (bool): If True, loads the dataset in eager mode; otherwise, in lazy mode.
Returns:
            polars.DataFrame, polars.LazyFrame, or None: The cached data as a DataFrame (eager) or a LazyFrame (lazy), or None if the dataset is not cached.
"""
cache_path = self.get_cache_path(dataset_name)
if os.path.exists(cache_path):
print("Dataset read from cache.")
myds = ds.dataset(cache_path)
            if eager:
                return pl.from_arrow(myds.to_table())
            else:
                return pl.scan_pyarrow_dataset(myds)
return None

def save_to_cache(self, data, dataset_name):
"""
Saves a Polars DataFrame to a Parquet file in the cache.
Args:
data (polars.DataFrame): The DataFrame to be cached.
dataset_name (str): Name of the dataset for caching.
"""
cache_path = self.get_cache_path(dataset_name)
data.write_parquet(cache_path)

def clear_cache(self):
"""
Removes all files in the cache directory and returns the count of deleted files.
"""
deleted_files_count = 0
for filename in os.listdir(self.cache_dir):
file_path = os.path.join(self.cache_dir, filename)
if os.path.isfile(file_path):
os.remove(file_path)
deleted_files_count += 1
return deleted_files_count
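
A minimal usage sketch of the cache manager above. The cache directory, dataset name, and frame contents are illustrative only; the import path follows the new module added in this commit (giza_datasets/cache_manager.py).

import polars as pl

from giza_datasets.cache_manager import CacheManager

# Illustrative cache directory and dataset name (not part of the library).
manager = CacheManager("/tmp/giza_cache_demo")  # creates the directory if it does not exist
frame = pl.DataFrame({"block": [1, 2, 3], "value": [0.1, 0.2, 0.3]})

manager.save_to_cache(frame, "demo-dataset")  # writes /tmp/giza_cache_demo/demo-dataset as Parquet

eager_frame = manager.load_from_cache("demo-dataset", eager=True)   # polars.DataFrame
lazy_frame = manager.load_from_cache("demo-dataset", eager=False)   # polars.LazyFrame; .collect() materializes it

print(manager.clear_cache())  # prints the number of cached files removed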
4 changes: 2 additions & 2 deletions giza_datasets/constants.py
@@ -197,14 +197,14 @@
tags=["DeFi", "Lending", "Aave-v3", "Ethereum", "Swap Fees"],
documentation="https://datasets.gizatech.xyz/hub/aave/daily-exchange-rates-and-indexes-v3",
),
    Dataset(
name="aave-liquidationsV2",
path="gs://datasets-giza/Aave/Aave_LiquidationsV2.parquet",
description="Individual liquidations in Aave v2, including colleteral and lending values",
tags=["DeFi", "Lending", "Aave-v2", "Ethereum", "Liquiditations"],
documentation="https://datasets.gizatech.xyz/hub/aave/liquidations-v2",
),
    Dataset(
name="aave-liquidationsV3",
path="gs://datasets-giza/Aave/Aave_LiquidationsV3.parquet",
description="Individual liquidations in Aave v3, including colleteral and lending values",
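
For context, a short sketch of how entries like the ones above can be looked up programmatically. It assumes DATASET_HUB in giza_datasets/constants.py is the list of Dataset entries whose name and path fields are attribute-accessible, which is how the loader code below uses them.

from giza_datasets.constants import DATASET_HUB

# Print every registered Aave liquidations dataset and its GCS location.
for dataset in DATASET_HUB:
    if dataset.name.startswith("aave-liquidations"):
        print(dataset.name, "->", dataset.path)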
125 changes: 96 additions & 29 deletions giza_datasets/loaders.py
@@ -1,7 +1,8 @@
import gcsfs
import polars as pl

import os
from giza_datasets.constants import DATASET_HUB
from giza_datasets.cache_manager import CacheManager


class DatasetsLoader:
@@ -10,23 +11,26 @@ class DatasetsLoader:
It uses the GCSFileSystem for accessing files and Polars for handling data.
"""

def __init__(self):
"""
Initializes the DatasetsLoader with a GCS filesystem and the dataset configuration.
Verification is turned off for the GCS filesystem.
"""
def __init__(self, use_cache=True, cache_dir=None):
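        """
        Initializes the DatasetsLoader with a GCS filesystem, the dataset configuration,
        and optional local caching. Verification is turned off for the GCS filesystem.
        Args:
            use_cache (bool): If True, downloaded datasets are cached on disk and reused.
            cache_dir (str, optional): Directory for cached datasets. Defaults to
                ~/.cache/giza_datasets.
        """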
self.fs = gcsfs.GCSFileSystem(verify=False)
self.dataset_hub = DATASET_HUB
self.use_cache = use_cache
self.cache_dir = (
cache_dir
if cache_dir is not None
else os.path.join(os.path.expanduser("~"), ".cache/giza_datasets")
)
self.cache_manager = CacheManager(self.cache_dir) if use_cache else None

def _get_all_parquet_files(self, directory):
"""
Recursively retrieves all the parquet file paths in the given directory.
Args:
directory: The GCS directory to search for parquet files.
directory (str): The GCS directory to search for parquet files.
Returns:
A list of full paths to all the parquet files found.
List[str]: A list of full paths to all the parquet files found.
"""
all_files = self.fs.ls(directory, detail=True)
parquet_files = []
@@ -44,49 +48,112 @@ def _load_multiple_parquet_files(self, file_paths):
Loads multiple parquet files into a single Polars DataFrame.
Args:
file_paths: A list of file paths to load.
file_paths (List[str]): A list of file paths to load.
Returns:
A concatenated Polars DataFrame containing data from all files.
polars.DataFrame: A concatenated Polars DataFrame containing data from all files.
"""
dfs = []
for file_path in file_paths:
with self.fs.open(file_path) as f:
df = pl.read_parquet(f, use_pyarrow=True)
dfs.append(df)
concatenated_df = pl.concat(dfs, how="diagonal_relaxed")

return concatenated_df

def load(self, dataset_name):
def _dataset_exists_in_gcs(self, dataset_name):
"""
Loads a dataset by name, either as a single file or multiple files.
Checks if a dataset exists in GCS by looking for its path in the dataset hub.
Args:
dataset_name: The name of the dataset to load.
dataset_name (str): The name of the dataset to check.
Returns:
A Polars DataFrame containing the loaded dataset.
Raises:
ValueError: If the dataset name is not found or if no parquet files are found.
str or None: The GCS path of the dataset if found, otherwise None.
"""
gcs_path = None
for dataset in self.dataset_hub:
if dataset.name == dataset_name:
gcs_path = dataset.path
break
return gcs_path

if not gcs_path:
raise ValueError(f"Dataset name '{dataset_name}' not found in Giza.")
elif gcs_path.endswith(".parquet"):
with self.fs.open(gcs_path) as f:
df = pl.read_parquet(f, use_pyarrow=True)
else:
parquet_files = self._get_all_parquet_files(gcs_path)
if not parquet_files:
raise ValueError(
"No .parquet files were found in the directory or subdirectories."
def load(self, dataset_name, eager=False):
"""
Loads a dataset by name, either from cache or GCS. Supports eager and lazy loading.
Args:
dataset_name (str): The name of the dataset to load.
eager (bool): If True, loads the dataset eagerly; otherwise, loads lazily.
Returns:
            polars.DataFrame or polars.LazyFrame: The loaded dataset, materialized in eager mode or as a lazy scan otherwise.
        Raises:
            ValueError: If the dataset cannot be found in cache or GCS, or if lazy loading is requested while caching is disabled.
"""
if self.use_cache:
# Check if the dataset is already cached
cached_data = self.cache_manager.load_from_cache(dataset_name, eager)
if cached_data is not None:
print(f"Loading dataset {dataset_name} from cache.")
return cached_data
else:
print(
f"Dataset {dataset_name} not found in cache. Downloading from GCS."
)
df = self._load_multiple_parquet_files(parquet_files)

gcs_path = self._dataset_exists_in_gcs(dataset_name)

if not gcs_path:
raise ValueError(f"Dataset name '{dataset_name}' not found.")

local_file_path = os.path.join(self.cache_dir, f"{dataset_name}")

if gcs_path.endswith(".parquet"):
self.fs.get(gcs_path, local_file_path)
else:
self.fs.get(gcs_path, local_file_path, recursive=True)

df = self.cache_manager.load_from_cache(dataset_name, eager)
else:
            if not eager:
                raise ValueError("Lazy loading is only available for cached datasets")
gcs_path = self._dataset_exists_in_gcs(dataset_name)
if not gcs_path:
raise ValueError(f"Dataset name '{dataset_name}' not found in Giza.")
elif gcs_path.endswith(".parquet"):
with self.fs.open(gcs_path) as f:
df = pl.read_parquet(f, use_pyarrow=True)
else:
parquet_files = self._get_all_parquet_files(gcs_path)
if not parquet_files:
raise ValueError(
"No .parquet files were found in the directory or subdirectories."
)
df = self._load_multiple_parquet_files(parquet_files)

return df

def set_cache_dir(self, new_cache_dir):
"""
Sets a new cache directory for the CacheManager and updates the instance cache directory.
Args:
new_cache_dir (str): The new directory path for caching datasets.
"""
self.cache_dir = new_cache_dir
if self.use_cache:
self.cache_manager = CacheManager(self.cache_dir)

def clear_cache(self):
"""
Removes all files in the cache directory through the CacheManager and prints the count.
"""
if self.use_cache and self.cache_manager:
deleted_files_count = self.cache_manager.clear_cache()
print(
f"{deleted_files_count} datasets have been cleared from the cache directory."
)
else:
print("Cache management is not enabled or CacheManager is not initialized.")
