diff --git a/README.md b/README.md
index 6ab2a5d..6862e0e 100644
--- a/README.md
+++ b/README.md
@@ -102,6 +102,16 @@ Per users' request, we processed two non-anthropogenic datasets
## Quick Start [Back to Top]
+
+### Colab Tutorials
+
+Explore the following tutorials that can be opened directly in Google Colab:
+
+- [![Open Tutorial 1 in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/ant-research/EasyTemporalPointProcess/blob/main/notebooks/easytpp_1_dataset.ipynb) Tutorial 1: Dataset in EasyTPP.
+
+
+### End-to-end Example
+
We provide an end-to-end example for users to run a standard TPP model with `EasyTPP`.
diff --git a/notebooks/easytpp_1_dataset.ipynb b/notebooks/easytpp_1_dataset.ipynb
index ff13e90..1a37128 100644
--- a/notebooks/easytpp_1_dataset.ipynb
+++ b/notebooks/easytpp_1_dataset.ipynb
@@ -1,263 +1,3153 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "source": [
- "\n",
- " \n",
- ""
- ],
- "metadata": {
- "collapsed": false,
- "pycharm": {
- "name": "#%% md\n"
- }
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "# Tutorial 1: Dataset"
- ],
- "metadata": {
- "collapsed": false,
- "pycharm": {
- "name": "#%% md\n"
- }
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "In this tutorial, we simply show how the dataset-related functionalities work in **EasyTPP**.\n",
- "\n",
- "\n",
- "Firstly, we install the package."
- ],
- "metadata": {
- "id": "26Wvh9rZbTcg",
- "pycharm": {
- "name": "#%% md\n"
- }
- }
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "U-gIiMZqMPFy",
- "pycharm": {
- "name": "#%%\n"
- }
- },
- "outputs": [],
- "source": [
- "!pip install easy_tpp"
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "Currently, there are two options to load the preprocessed dataset:\n",
- "- copy the pickle files from [Google Drive](https://drive.google.com/drive/folders/1f8k82-NL6KFKuNMsUwozmbzDSFycYvz7).\n",
- "- load the json fils from [HuggingFace](https://huggingface.co/easytpp).\n",
- "\n",
- "In the future the first way will be depreciated and the second way is recommended."
- ],
- "metadata": {
- "id": "I5YUvAc7bngQ",
- "pycharm": {
- "name": "#%% md\n"
- }
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "\n",
- "## Load pickle data files\n",
- "\n",
- "If we choose to use the pickle files as the sources, we can download the data files, put it under a data folder, specify the directory in the config file and run the training and prediction pipeline.\n",
- "\n",
- "\n",
- "Take taxi dataset for example, we put it this way:\n",
- "\n",
- "```\n",
- "data:\n",
- " taxi:\n",
- " data_format: pickle\n",
- " train_dir: ./data/taxi/train.pkl\n",
- " valid_dir: ./data/taxi/dev.pkl\n",
- " test_dir: ./data/taxi/test.pkl\n",
- "```\n",
- "\n",
- "See [experiment_config](https://github.com/ant-research/EasyTemporalPointProcess/blob/main/examples/configs/experiment_config.yaml) for the full example.\n",
- "\n"
- ],
- "metadata": {
- "id": "6zfSHKhDmFSS",
- "pycharm": {
- "name": "#%% md\n"
- }
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "## Load json data files\n",
- "\n",
- "\n",
- "The recommended way is to load data from HuggingFace, where all data have been preprocessed in json format and hosted in [EasyTPP Repo](https://huggingface.co/easytpp)."
- ],
- "metadata": {
- "id": "HHPDzqud2wJf",
- "pycharm": {
- "name": "#%% md\n"
- }
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "We use the official APIs to directly download and inspect the dataset."
- ],
- "metadata": {
- "id": "6HJd1lZB33mP",
- "pycharm": {
- "name": "#%% md\n"
- }
- }
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "8sM6riIxQClw",
- "pycharm": {
- "name": "#%%\n"
- }
- },
- "outputs": [],
- "source": [
- "from datasets import load_dataset\n",
- "\n",
- "# we choose taxi dataset as it is relatively small\n",
- "dataset = load_dataset('easytpp/taxi')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Tutorial 1: Dataset in EasyTPP"
+ ],
+ "metadata": {
+ "id": "TmnzuOArbQk-"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "In this tutorial, we’ll explore the dataset-related functionalities in **EasyTPP**, an advanced library designed for temporal point process modeling. We will guide you through the installation process, data loading options, and configurations to set up a training pipeline effectively.\n",
+ "\n",
+ "\n",
+ "## Step 1: Install EasyTPP\n",
+ "First, let’s install the EasyTPP package. Run the following command to install the library in your Colab environment:"
+ ],
+ "metadata": {
+ "id": "26Wvh9rZbTcg"
+ }
},
- "id": "BZYUTFsDRHmL",
- "outputId": "478e4afb-6806-4266-83da-2f3c55bf93db",
- "pycharm": {
- "name": "#%%\n"
- }
- },
- "outputs": [
{
- "output_type": "execute_result",
- "data": {
- "text/plain": [
- "DatasetDict({\n",
- " train: Dataset({\n",
- " features: ['time_since_last_event', 'type_event', 'time_since_start', 'dim_process', 'seq_len', 'seq_idx'],\n",
- " num_rows: 1400\n",
- " })\n",
- " validation: Dataset({\n",
- " features: ['time_since_last_event', 'type_event', 'time_since_start', 'dim_process', 'seq_len', 'seq_idx'],\n",
- " num_rows: 200\n",
- " })\n",
- " test: Dataset({\n",
- " features: ['time_since_last_event', 'type_event', 'time_since_start', 'dim_process', 'seq_len', 'seq_idx'],\n",
- " num_rows: 400\n",
- " })\n",
- "})"
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "U-gIiMZqMPFy",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "40a622f2-57c2-4742-f913-9d58302d4e0e"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Collecting git+https://github.com/ant-research/EasyTemporalPointProcess.git\n",
+ " Cloning https://github.com/ant-research/EasyTemporalPointProcess.git to /tmp/pip-req-build-ddqktd57\n",
+ " Running command git clone --filter=blob:none --quiet https://github.com/ant-research/EasyTemporalPointProcess.git /tmp/pip-req-build-ddqktd57\n",
+ " Resolved https://github.com/ant-research/EasyTemporalPointProcess.git to commit de2eef65c9ee66c1dff8dc12d8bef7de270db86f\n",
+ " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ "Requirement already satisfied: PyYAML>=5.1 in /usr/local/lib/python3.10/dist-packages (from easy_tpp==0.0.8) (6.0.2)\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from easy_tpp==0.0.8) (1.26.4)\n",
+ "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from easy_tpp==0.0.8) (2.2.2)\n",
+ "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from easy_tpp==0.0.8) (2.5.0+cu121)\n",
+ "Requirement already satisfied: tensorboard in /usr/local/lib/python3.10/dist-packages (from easy_tpp==0.0.8) (2.17.0)\n",
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from easy_tpp==0.0.8) (24.1)\n",
+ "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (from easy_tpp==0.0.8) (3.0.2)\n",
+ "Requirement already satisfied: omegaconf in /usr/local/lib/python3.10/dist-packages (from easy_tpp==0.0.8) (2.3.0)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets->easy_tpp==0.0.8) (3.16.1)\n",
+ "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets->easy_tpp==0.0.8) (16.1.0)\n",
+ "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets->easy_tpp==0.0.8) (0.3.8)\n",
+ "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.10/dist-packages (from datasets->easy_tpp==0.0.8) (2.32.3)\n",
+ "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.10/dist-packages (from datasets->easy_tpp==0.0.8) (4.66.5)\n",
+ "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets->easy_tpp==0.0.8) (3.5.0)\n",
+ "Requirement already satisfied: multiprocess<0.70.17 in /usr/local/lib/python3.10/dist-packages (from datasets->easy_tpp==0.0.8) (0.70.16)\n",
+ "Requirement already satisfied: fsspec<=2024.9.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets->easy_tpp==0.0.8) (2024.6.1)\n",
+ "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets->easy_tpp==0.0.8) (3.10.10)\n",
+ "Requirement already satisfied: huggingface-hub>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from datasets->easy_tpp==0.0.8) (0.24.7)\n",
+ "Requirement already satisfied: antlr4-python3-runtime==4.9.* in /usr/local/lib/python3.10/dist-packages (from omegaconf->easy_tpp==0.0.8) (4.9.3)\n",
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->easy_tpp==0.0.8) (2.8.2)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->easy_tpp==0.0.8) (2024.2)\n",
+ "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas->easy_tpp==0.0.8) (2024.2)\n",
+ "Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.10/dist-packages (from tensorboard->easy_tpp==0.0.8) (1.4.0)\n",
+ "Requirement already satisfied: grpcio>=1.48.2 in /usr/local/lib/python3.10/dist-packages (from tensorboard->easy_tpp==0.0.8) (1.64.1)\n",
+ "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.10/dist-packages (from tensorboard->easy_tpp==0.0.8) (3.7)\n",
+ "Requirement already satisfied: protobuf!=4.24.0,<5.0.0,>=3.19.6 in /usr/local/lib/python3.10/dist-packages (from tensorboard->easy_tpp==0.0.8) (3.20.3)\n",
+ "Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard->easy_tpp==0.0.8) (75.1.0)\n",
+ "Requirement already satisfied: six>1.9 in /usr/local/lib/python3.10/dist-packages (from tensorboard->easy_tpp==0.0.8) (1.16.0)\n",
+ "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard->easy_tpp==0.0.8) (0.7.2)\n",
+ "Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from tensorboard->easy_tpp==0.0.8) (3.0.4)\n",
+ "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch->easy_tpp==0.0.8) (4.12.2)\n",
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->easy_tpp==0.0.8) (3.4.2)\n",
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->easy_tpp==0.0.8) (3.1.4)\n",
+ "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch->easy_tpp==0.0.8) (1.13.1)\n",
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch->easy_tpp==0.0.8) (1.3.0)\n",
+ "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->easy_tpp==0.0.8) (2.4.3)\n",
+ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->easy_tpp==0.0.8) (1.3.1)\n",
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->easy_tpp==0.0.8) (24.2.0)\n",
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->easy_tpp==0.0.8) (1.4.1)\n",
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->easy_tpp==0.0.8) (6.1.0)\n",
+ "Requirement already satisfied: yarl<2.0,>=1.12.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->easy_tpp==0.0.8) (1.16.0)\n",
+ "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->easy_tpp==0.0.8) (4.0.3)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets->easy_tpp==0.0.8) (3.4.0)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets->easy_tpp==0.0.8) (3.10)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets->easy_tpp==0.0.8) (2.2.3)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets->easy_tpp==0.0.8) (2024.8.30)\n",
+ "Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.10/dist-packages (from werkzeug>=1.0.1->tensorboard->easy_tpp==0.0.8) (3.0.2)\n",
+ "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from yarl<2.0,>=1.12.0->aiohttp->datasets->easy_tpp==0.0.8) (0.2.0)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# ues the latest release\n",
+ "# !pip install easy_tpp\n",
+ "\n",
+ "# or use the git main branch\n",
+ "!pip install git+https://github.com/ant-research/EasyTemporalPointProcess.git"
]
- },
- "metadata": {},
- "execution_count": 3
- }
- ],
- "source": [
- "dataset"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "NJKP0ATnv4_l",
- "pycharm": {
- "name": "#%%\n"
- }
- },
- "outputs": [],
- "source": [
- "dataset['train']['type_event'][0]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "an__K1qzmRSo",
- "pycharm": {
- "name": "#%% md\n"
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Step 2: Loading Preprocessed Datasets\n",
+ "\n",
+ "EasyTPP provides two methods to load preprocessed datasets:\n",
+ "- [Google Drive](https://drive.google.com/drive/folders/1f8k82-NL6KFKuNMsUwozmbzDSFycYvz7): Download the dataset in pickle format.\n",
+ "- [HuggingFace](https://huggingface.co/easytpp): Load the dataset in JSON format from the HuggingFace repository.\n",
+ "\n",
+ "> Note: The pickle format from Google Drive will be deprecated in future releases, and we recommend using the JSON files from HuggingFace for better compatibility and performance.\n",
+ "\n",
+ "\n",
+ "### Option 1: Load Pickle Data Files (Deprecated Soon)\n",
+ "If you choose to use the pickle files, muanlly download the data files fromt he Google Drive mentioned above, place them under a data directory in your workspace, and specify the directory path in the configuration file.\n",
+ "\n",
+ "Here is an example configuration for loading a Taxi dataset in pickle format:"
+ ],
+ "metadata": {
+ "id": "I5YUvAc7bngQ"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "\n",
+ "\n",
+ "```\n",
+ "data:\n",
+ " taxi:\n",
+ " data_format: pickle\n",
+ " train_dir: ./data/taxi/train.pkl\n",
+ " valid_dir: ./data/taxi/dev.pkl\n",
+ " test_dir: ./data/taxi/test.pkl\n",
+ "```\n",
+ "\n",
+ "Then we can launch the train/evaluation pipeline process. See [experiment_config](https://github.com/ant-research/EasyTemporalPointProcess/blob/main/examples/configs/experiment_config.yaml) for the full example.\n"
+ ],
+ "metadata": {
+ "id": "6zfSHKhDmFSS"
+ }
+ },
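+ {
+ "cell_type": "markdown",
+ "source": [
+ "Before wiring the pickle files into a pipeline, it can help to sanity-check one of them directly. The cell below is a minimal sketch, assuming the file stores a pickled Python object (typically a dict of event sequences); the exact path and key names depend on your download."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Minimal sketch: inspect a downloaded pickle file.\n",
+ "# Assumption: the file holds a pickled Python object (typically a dict\n",
+ "# of event sequences); adjust the path and keys to match the actual file.\n",
+ "import pickle\n",
+ "\n",
+ "with open('./data/taxi/train.pkl', 'rb') as f:\n",
+ "    data = pickle.load(f, encoding='latin-1')\n",
+ "\n",
+ "print(type(data))\n",
+ "if isinstance(data, dict):\n",
+ "    print(list(data.keys()))"
+ ]
+ },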
+ {
+ "cell_type": "markdown",
+ "source": [
+ "\n",
+ "### Option 2: Load JSON Data Files (Recommended)\n",
+ "\n",
+ "To use JSON data files from HuggingFace - [EasyTPP Repo](https://huggingface.co/easytpp), simply replace `data_format: pickle` with `data_format: json` in the config file, and update the directory paths accordingly. This setup is recommended for newer versions of EasyTPP and provides better compatibility with various processing functions in the library.\n",
+ "\n",
+ "To activate this loading process in the train/evaluation pipeline, similarly, we put the directory of huggingface repo in the config file, e.g.,\n",
+ "\n",
+ "```\n",
+ "data:\n",
+ " taxi:\n",
+ " data_format: json\n",
+ " train_dir: easytpp/taxi\n",
+ " valid_dir: easytpp/taxi\n",
+ " test_dir: easytpp/taxi\n",
+ "```\n",
+ "\n",
+ "Note that we can also manually put the locally directory of json files in the config:\n",
+ "\n",
+ "```\n",
+ "data:\n",
+ " taxi:\n",
+ " data_format: json\n",
+ " train_dir: ./data/taxi/train.json\n",
+ " valid_dir: ./data/taxi/dev.json\n",
+ " test_dir: ./data/taxi/test.json\n",
+ "```"
+ ],
+ "metadata": {
+ "id": "HHPDzqud2wJf"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Step 3: Exploring Datasets\n",
+ "\n",
+ "The EasyTPP library offers several functions to streamline dataset loading and preprocessing. Let’s go over a few key functionalities:\n",
+ "\n",
+ "### Dataset Properties\n",
+ "\n",
+ "We firstly use the official HuggingFace APIs to directly download and inspect the dataset.\n",
+ "\n",
+ "In this example, the `load_dataset` function is used to load the \"taxi\" dataset, which is relatively small and suited for quick testing. The dataset is automatically split into three parts: train, validation, and test, with each split containing structured information on the events."
+ ],
+ "metadata": {
+ "id": "6HJd1lZB33mP"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "8sM6riIxQClw",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 368,
+ "referenced_widgets": [
+ "41522af32da04925bc328c4a9d7e3a82",
+ "375ac02a2d6b43f4a70feaf921c5377d",
+ "ae9d5f81adc343edb3f7de9dccb15f47",
+ "03b8f2a927ff406b88d8ff844f620403",
+ "a4ac525ccb134ccbab57aaaafc37045b",
+ "8cfe483c4a924e6992a81fc44274f240",
+ "5d9a76f49322447d9ff1c71ab3f9bee7",
+ "8525620ea0104a05b5ba633671a95873",
+ "65d7f89d96cb4e9188da3db1300076aa",
+ "36ca2862fb10458abb4ba13a9e75ecf2",
+ "722c07901b4c454bb68055971a445778",
+ "5848c90c9f2144b588aa88290b529151",
+ "925eca71c7c64811ab88b672364f440d",
+ "3b29ed8d36864982a47b3510a83f00af",
+ "2b100553e6fe4bc5ad83e17ae523a42e",
+ "28db861ff5f845baac00370fbcd6521b",
+ "e6a5ce28f9974ba3a12d5044aa2e07d4",
+ "9c6f4364de0f4cfe90e189916df4a040",
+ "cd76860acc1b41b997d6df7427596106",
+ "500c3e320fef4e9e906bccc844c4bf4e",
+ "cf4cf5ecb14f45539e14971c4e8a7b07",
+ "5ff53796dc0046f8b077a84836d4fc61",
+ "1d41b7f7b722466f89c74067a3ed783d",
+ "7f82d7dbe42e445198a1f2e0bc512ea6",
+ "e0ad14cb4708466584a0ff46ebf70539",
+ "0b921bafccf148dfaa76e94580800dfb",
+ "cccb52e989424990a2f6fcd881a4137c",
+ "f3ad7c86e9604c27837ea509677cda39",
+ "599d263b0edc42c48b732f22b6f4bb52",
+ "b9cf1a3cdf9544578a1a1ac73dd182b9",
+ "8522fc20649e40aeb1fc9aabbd3501b0",
+ "c986defd720346949d9dfb4a4822bb68",
+ "e2ed3c0ee1bf4beeb1f564f43031648d",
+ "3cdce1231011425e913007b372d1354a",
+ "cfa4572b597345f7a5c69780eb2e19ab",
+ "494232ab57db40e391e34a90f0120dd7",
+ "df95b675f0504e34ab3b2f17d09109c2",
+ "02f0c7450af24e849e4c573ff7e0f9d8",
+ "a8eee5c478d84a4b828970f5f9c55016",
+ "faaf164d7dc142c3bfe55aac535efbf4",
+ "78bfe063ac114e4398007a26053956eb",
+ "0829c53e563b48c0b84dbad6baef16e0",
+ "86b40e77e7a34386b0876859f151b49a",
+ "59ebcd3733c14e8b87a42dc7399cc5e4",
+ "6d1e4b19483343bfa1dbd713c2efbd1f",
+ "ba666b3bd47e4568afeaec47a7acb912",
+ "38c9c5e1bd814f5ea92047d6f5833506",
+ "76579ac830bd404ba5008c5315c323c3",
+ "63da81f729204fb8b2dc5cc12074e09b",
+ "0e2d3e97ceed4b44b55ea7f70047ec7e",
+ "9adf81c38d8d4cce873ebed60c5b208c",
+ "3d24235ee98a4daeaaaf85a7b5763f0b",
+ "c4312014f1a54feab2dc7c2b58a837c5",
+ "75a31d51a0f940aa835b47236b4761b7",
+ "9bc1dbb8d296432f8c79f67a3cd5ef80",
+ "441615ebaf0f4db699d5fd569ee1655a",
+ "21ebc80b5d3c4eca8f8935ffd1a3590c",
+ "f622d1072c63472db30cb6ec6f6e72ae",
+ "879c8911eb3b41febff0d2a33cde55b4",
+ "22e1b3631d324a119e35e53088209296",
+ "93fbb1bbc4304d758998c904955ece3e",
+ "6bb36084b56f4277ad35da260978d9bf",
+ "fb6be1690b3247ee96124fd59c4c001d",
+ "1da7223d069345d88624d9a5485e1752",
+ "dc0b2a2b34fc483abe00409859000d02",
+ "5cdeca479dec41e9ba78fef207360283",
+ "3d4156a02f7a436f9d678531bdcad532",
+ "04ac17403fdd480d985d1f7885196c46",
+ "0b0bc44093404ad8b396dbc7f86a6967",
+ "f9d8cff1f4b24cd4b1aea8693746a81e",
+ "9b8a0445912840c194aed0846228972f",
+ "db91bc99d2964fc69fb4a71abeb0c254",
+ "0d554c589be24338bdeda06958b77140",
+ "33336a9fd5aa42acb54a06486781fb8b",
+ "47d50654d080473b9360a49e246d70df",
+ "f12c68ae19424e8eb8cad1e105674f4c",
+ "1d5ef40cf0004c5caa90c04564c8facb"
+ ]
+ },
+ "outputId": "a5742b18-2ea9-408e-8ef4-dac9c6448cc5"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n",
+ "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
+ "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
+ "You will be able to reuse this secret in all of your notebooks.\n",
+ "Please note that authentication is recommended but still optional to access public models or datasets.\n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "README.md: 0%| | 0.00/28.0 [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "41522af32da04925bc328c4a9d7e3a82"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "train.json: 0%| | 0.00/2.29M [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "5848c90c9f2144b588aa88290b529151"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "dev.json: 0%| | 0.00/327k [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "1d41b7f7b722466f89c74067a3ed783d"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "test.json: 0%| | 0.00/654k [00:00, ?B/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "3cdce1231011425e913007b372d1354a"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Generating train split: 0%| | 0/1400 [00:00, ? examples/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "6d1e4b19483343bfa1dbd713c2efbd1f"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Generating validation split: 0%| | 0/200 [00:00, ? examples/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "441615ebaf0f4db699d5fd569ee1655a"
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "Generating test split: 0%| | 0/400 [00:00, ? examples/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "3d4156a02f7a436f9d678531bdcad532"
+ }
+ },
+ "metadata": {}
+ }
+ ],
+ "source": [
+ "from datasets import load_dataset\n",
+ "\n",
+ "# we choose taxi dataset as it is relatively small\n",
+ "dataset = load_dataset('easytpp/taxi')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Each dataset split is a Dataset object with multiple features such as `seq_len `(sequence length), `time_since_start`, `seq_idx` (sequence index), `time_since_last_event`, `type_event` (event type), and `dim_process` (dimension of the process). This structured format provides essential information about each event's timing and type, which is crucial for modeling temporal point processes. Additionally, the package simplifies data access, allowing users to select specific splits and features for further analysis or model input with minimal setup."
+ ],
+ "metadata": {
+ "id": "DM6yw1u0E6kL"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "BZYUTFsDRHmL",
+ "outputId": "dfc38373-6079-4ca6-9cbc-ceba5b81e2d3"
+ },
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "DatasetDict({\n",
+ " train: Dataset({\n",
+ " features: ['seq_len', 'time_since_start', 'seq_idx', 'time_since_last_event', 'type_event', 'dim_process'],\n",
+ " num_rows: 1400\n",
+ " })\n",
+ " validation: Dataset({\n",
+ " features: ['seq_len', 'time_since_start', 'seq_idx', 'time_since_last_event', 'type_event', 'dim_process'],\n",
+ " num_rows: 200\n",
+ " })\n",
+ " test: Dataset({\n",
+ " features: ['seq_len', 'time_since_start', 'seq_idx', 'time_since_last_event', 'type_event', 'dim_process'],\n",
+ " num_rows: 400\n",
+ " })\n",
+ "})"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 4
+ }
+ ],
+ "source": [
+ "dataset"
+ ]
+ },
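+ {
+ "cell_type": "markdown",
+ "source": [
+ "A single sequence can also be pulled as a plain dict of these features. The next cell is a small sketch of row-style access using the standard `datasets` indexing API:"
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch: fetch the first training sequence as a dict of features.\n",
+ "example = dataset['train'][0]\n",
+ "print(example['seq_len'], example['dim_process'])\n",
+ "print(example['time_since_start'][:5])"
+ ]
+ },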
+ {
+ "cell_type": "markdown",
+ "source": [
+ "In the easytpp dataset, the `type_event` feature represents the event type codes within each sequence. In the example shown, `dataset['train']['type_event'][0]` reveals a list of integer codes, such as [8, 3, 8, 3, 8, 3, ...], corresponding to different types of events in the first sequence of the training set. These codes are likely categorical identifiers used to differentiate various types of events in the temporal point process, which can be useful in understanding event dynamics and patterns over time. This feature enables the model to learn and predict not only the timing but also the type of future events within the sequence."
+ ],
+ "metadata": {
+ "id": "-CpGsgnMFa6Z"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "NJKP0ATnv4_l",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "7bd0df3f-14c1-4fda-bc99-d5ff7c9d9fbe"
+ },
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "[8,\n",
+ " 3,\n",
+ " 8,\n",
+ " 3,\n",
+ " 8,\n",
+ " 3,\n",
+ " 8,\n",
+ " 3,\n",
+ " 8,\n",
+ " 3,\n",
+ " 8,\n",
+ " 3,\n",
+ " 8,\n",
+ " 3,\n",
+ " 8,\n",
+ " 3,\n",
+ " 8,\n",
+ " 3,\n",
+ " 8,\n",
+ " 3,\n",
+ " 8,\n",
+ " 3,\n",
+ " 8,\n",
+ " 3,\n",
+ " 8,\n",
+ " 3,\n",
+ " 8,\n",
+ " 3,\n",
+ " 8,\n",
+ " 3,\n",
+ " 8,\n",
+ " 3,\n",
+ " 8,\n",
+ " 3,\n",
+ " 8,\n",
+ " 3]"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 7
+ }
+ ],
+ "source": [
+ "dataset['train']['type_event'][0]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Dataset Distributions\n",
+ "\n",
+ "In the following code snippet, several functions from the EasyTPP package are used to configure and analyze a TPP dataset, providing insights into event distribution and timing characteristics.\n",
+ "\n",
+ "#### Dataset Configuration\n",
+ "\n",
+ "The `Config.build_from_yaml_file` function loads configurations from a specified YAML file (`config.yaml`). This file contains settings for data preprocessing, model parameters, and other configurations needed by the `TPPDataLoader` to manage and process the TPP data. By centralizing settings in a configuration file, this function allows for easier parameter management and adjustments without altering the code."
+ ],
+ "metadata": {
+ "id": "int17B0zKRtR"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# As an illustrative example, we write the YAML content to a file\n",
+ "yaml_content = \"\"\"\n",
+ "pipeline_config_id: data_config\n",
+ "\n",
+ "data_format: json\n",
+ "train_dir: easytpp/taxi # ./data/taxi/train.json\n",
+ "valid_dir: easytpp/taxi # ./data/taxi/dev.json\n",
+ "test_dir: easytpp/taxi # ./data/taxi/test.json\n",
+ "data_specs:\n",
+ " num_event_types: 10\n",
+ " pad_token_id: 10\n",
+ " padding_side: right\n",
+ "\"\"\"\n",
+ "\n",
+ "# Save the content to a file named config.yaml\n",
+ "with open(\"config.yaml\", \"w\") as file:\n",
+ " file.write(yaml_content)"
+ ],
+ "metadata": {
+ "id": "G6a74A43MLQN"
+ },
+ "execution_count": 11,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from easy_tpp.config_factory import Config\n",
+ "from easy_tpp.preprocess.data_loader import TPPDataLoader\n",
+ "\n",
+ "\n",
+ "config = Config.build_from_yaml_file('./config.yaml')\n",
+ "tpp_loader = TPPDataLoader(config)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "rUBkm8JULMmP",
+ "outputId": "c1657049-7fd0-483f-ef24-ad5197a83714"
+ },
+ "execution_count": 12,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\u001b[31;1m2024-10-29 07:33:39,927 - config.py[pid:253;line:34:build_from_yaml_file] - CRITICAL: Load pipeline config class DataConfig\u001b[0m\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "#### Dataset Statistics\n",
+ "\n",
+ "\n",
+ "The `get_statistics` function retrieves statistical information about the dataset,\n",
+ "such as the distribution of event types, sequence lengths, and timing intervals. By specifying `split='train'`, this function targets only the training subset of the dataset. The resulting stats variable is printed to provide an overview of the dataset's\n",
+ "structure and characteristics, which can be helpful for understanding the data before model training."
+ ],
+ "metadata": {
+ "id": "wFBh_0sBM3EF"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "stats = tpp_loader.get_statistics(split='train')\n",
+ "stats"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "mz95yYH4NL3T",
+ "outputId": "40e28c82-c57f-4dd5-8687-e2d324ee9097"
+ },
+ "execution_count": 13,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "{'num_sequences': 1400,\n",
+ " 'avg_sequence_length': 37.03857142857143,\n",
+ " 'event_type_distribution': {8: 23131,\n",
+ " 3: 22239,\n",
+ " 5: 2161,\n",
+ " 0: 2088,\n",
+ " 1: 1443,\n",
+ " 6: 625,\n",
+ " 4: 107,\n",
+ " 2: 50,\n",
+ " 9: 4,\n",
+ " 7: 6},\n",
+ " 'max_sequence_length': 38,\n",
+ " 'min_sequence_length': 36,\n",
+ " 'mean_time_delta': 0.21851826495759416,\n",
+ " 'min_time_delta': 0.0,\n",
+ " 'max_time_delta': 5.721388888888889}"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 13
+ }
+ ]
+ },
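+ {
+ "cell_type": "markdown",
+ "source": [
+ "As a quick cross-check, the event-type counts above can be recomputed straight from the HuggingFace dataset loaded in Step 3, using only the standard library. This is a sketch rather than part of the EasyTPP API:"
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Sketch: recompute the event-type distribution without the EasyTPP loader.\n",
+ "from collections import Counter\n",
+ "from itertools import chain\n",
+ "\n",
+ "type_counts = Counter(chain.from_iterable(dataset['train']['type_event']))\n",
+ "print(sorted(type_counts.items(), key=lambda kv: -kv[1]))"
+ ]
+ },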
+ {
+ "cell_type": "markdown",
+ "source": [
+ "#### Event Type Distribution Plot\n",
+ "\n",
+ "The following function generates a plot of the distribution of event types within the dataset. This visualization helps identify the frequency of different event types, which can be useful for analyzing class imbalance or the prevalence of certain types of events. Understanding event type distribution is essential for TPP models, as it informs the model about the likelihood and variety of event types."
+ ],
+ "metadata": {
+ "id": "yA61VBnhLLEw"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "tpp_loader.plot_event_type_distribution()"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 564
+ },
+ "id": "M3pheYYPNfYP",
+ "outputId": "3964b087-678f-4a69-ca34-50be977639cf"
+ },
+ "execution_count": 14,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "