Skip to content

Commit

Permalink
Add Av2 Motion Forecasting dataset (#49)
Browse files Browse the repository at this point in the history
* Add av2 motion forecasting to trajdata

* Update README

* Update README

* Set robot's name to "ego"
  • Loading branch information
EPronovost authored Nov 16, 2024
1 parent c1a9499 commit 32d071d
Show file tree
Hide file tree
Showing 8 changed files with 486 additions and 0 deletions.
27 changes: 27 additions & 0 deletions DATASETS.md
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,30 @@ It should look like this after downloading:
```

**Note**: Only the annotations need to be downloaded (not the videos).


## Argoverse 2 Motion Forecasting
The dataset can be downloaded from [here](https://www.argoverse.org/av2.html#download-link).

It should look like this after downloading:
```
/path/to/av2mf/
├── train/
| ├── 0000b0f9-99f9-4a1f-a231-5be9e4c523f7/
| | ├── log_map_archive_0000b0f9-99f9-4a1f-a231-5be9e4c523f7.json
| | └── scenario_0000b0f9-99f9-4a1f-a231-5be9e4c523f7.parquet
| ├── 0000b6ab-e100-4f6b-aee8-b520b57c0530/
| | ├── log_map_archive_0000b6ab-e100-4f6b-aee8-b520b57c0530.json
| | └── scenario_0000b6ab-e100-4f6b-aee8-b520b57c0530.parquet
| └── ...
├── val/
| ├── 00010486-9a07-48ae-b493-cf4545855937/
| | ├── log_map_archive_00010486-9a07-48ae-b493-cf4545855937.json
| | └── scenario_00010486-9a07-48ae-b493-cf4545855937.parquet
| └── ...
└── test/
    ├── 0000b329-f890-4c2b-93f2-7e2413d4ca5b/
    | ├── log_map_archive_0000b329-f890-4c2b-93f2-7e2413d4ca5b.json
    | └── scenario_0000b329-f890-4c2b-93f2-7e2413d4ca5b.parquet
    └── ...
```
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ Currently, the dataloader supports interfacing with the following datasets:
| Lyft Level 5 Train Full | `lyft_train_full` | `train` | `palo_alto` | Lyft Level 5 training data - part 2/2 (70 GB) | 0.1s (10Hz) | :white_check_mark: |
| Lyft Level 5 Validation | `lyft_val` | `val` | `palo_alto` | Lyft Level 5 validation data (8.2 GB) | 0.1s (10Hz) | :white_check_mark: |
| Lyft Level 5 Sample | `lyft_sample` | `mini_train`, `mini_val` | `palo_alto` | Lyft Level 5 sample data (100 scenes, randomly split 80/20 for training/validation) | 0.1s (10Hz) | :white_check_mark: |
| Argoverse 2 Motion Forecasting | `av2_motion_forecasting` | `train`, `val`, `test` | N/A | 250,000 motion forecasting scenarios of 11s each | 0.1s (10Hz) | :white_check_mark: |
| INTERACTION Dataset Single-Agent | `interaction_single` | `train`, `val`, `test`, `test_conditional` | `usa`, `china`, `germany`, `bulgaria` | Single-agent split of the INTERACTION Dataset (where the goal is to predict one target agent's future motion) | 0.1s (10Hz) | :white_check_mark: |
| INTERACTION Dataset Multi-Agent | `interaction_multi` | `train`, `val`, `test`, `test_conditional` | `usa`, `china`, `germany`, `bulgaria` | Multi-agent split of the INTERACTION Dataset (where the goal is to jointly predict multiple agents' future motion) | 0.1s (10Hz) | :white_check_mark: |
| ETH - Univ | `eupeds_eth` | `train`, `val`, `train_loo`, `val_loo`, `test_loo` | `zurich` | The ETH (University) scene from the ETH BIWI Walking Pedestrians dataset | 0.4s (2.5Hz) | |
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ dependencies = [
]

[project.optional-dependencies]
av2 = ["av2==0.2.1"]
dev = ["black", "isort", "pytest", "pytest-xdist", "twine", "build"]
interaction = ["lanelet2==1.2.1"]
lyft = ["l5kit==1.5.0"]
Expand Down
1 change: 1 addition & 0 deletions src/trajdata/dataset_specific/argoverse2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .av2_dataset import Av2Dataset
191 changes: 191 additions & 0 deletions src/trajdata/dataset_specific/argoverse2/av2_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
from pathlib import Path
from typing import Any, Dict, List, Tuple, Type

import pandas as pd
import tqdm
from av2.datasets.motion_forecasting.constants import (
AV2_SCENARIO_OBS_TIMESTEPS,
AV2_SCENARIO_STEP_HZ,
AV2_SCENARIO_TOTAL_TIMESTEPS,
)

from trajdata.caching.env_cache import EnvCache
from trajdata.caching.scene_cache import SceneCache
from trajdata.data_structures import AgentMetadata, EnvMetadata, Scene, SceneMetadata
from trajdata.data_structures.scene_tag import SceneTag
from trajdata.dataset_specific.argoverse2.av2_utils import (
AV2_SPLITS,
Av2Object,
Av2ScenarioIds,
av2_map_to_vector_map,
get_track_metadata,
scenario_name_to_split,
)
from trajdata.dataset_specific.raw_dataset import RawDataset
from trajdata.dataset_specific.scene_records import Argoverse2Record
from trajdata.utils import arr_utils

# Canonical trajdata environment name for this dataset; compute_metadata
# rejects any other env name.
AV2_MOTION_FORECASTING = "av2_motion_forecasting"
# Timestep duration in seconds (AV2 scenarios are sampled at 10 Hz -> 0.1 s).
AV2_DT = 1 / AV2_SCENARIO_STEP_HZ


class Av2Dataset(RawDataset):
    """trajdata RawDataset adapter for the Argoverse 2 Motion Forecasting data."""

    def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata:
        """Build the EnvMetadata describing the AV2 motion forecasting env.

        Raises:
            ValueError: If `env_name` is not the AV2 motion forecasting env.
        """
        if env_name != AV2_MOTION_FORECASTING:
            raise ValueError(f"Unknown Argoverse 2 env name: {env_name}")

        ids = Av2ScenarioIds.create(Path(data_dir))
        return EnvMetadata(
            name=env_name,
            data_dir=data_dir,
            dt=AV2_DT,
            parts=[AV2_SPLITS],
            scene_split_map=ids.scene_split_map,
            map_locations=None,
        )

    def load_dataset_obj(self, verbose: bool = False) -> None:
        """Create the accessor object used to read scenarios and maps from disk."""
        if verbose:
            print(f"Loading {self.name} dataset...", flush=True)
        self.dataset_obj = Av2Object(self.metadata.data_dir)

    def _get_matching_scenes_from_obj(
        self,
        scene_tag: SceneTag,
        scene_desc_contains: List[str] | None,
        env_cache: EnvCache,
    ) -> List[SceneMetadata]:
        """Compute SceneMetadata for all samples from self.dataset_obj.
        Also saves records to env_cache for later reuse.
        """
        if scene_desc_contains:
            raise ValueError("Argoverse dataset does not support scene descriptions.")

        names = self.dataset_obj.scenario_names

        # Persist the lightweight records so later runs can skip the raw data.
        records = [Argoverse2Record(name, idx) for idx, name in enumerate(names)]
        self.cache_all_scenes_list(env_cache, records)

        return [
            SceneMetadata(
                env_name=self.metadata.name,
                name=name,
                dt=AV2_DT,
                raw_data_idx=idx,
            )
            for idx, name in enumerate(names)
        ]

    def _get_matching_scenes_from_cache(
        self,
        scene_tag: SceneTag,
        scene_desc_contains: List[str] | None,
        env_cache: EnvCache,
    ) -> List[Scene]:
        """Computes Scene data for all samples by reading data from env_cache."""
        if scene_desc_contains:
            raise ValueError("Argoverse dataset does not support scene descriptions.")

        cached: List[Argoverse2Record] = env_cache.load_env_scenes_list(self.name)
        scenes = []
        for rec in cached:
            scenes.append(self._create_scene(rec.name, rec.data_idx))
        return scenes

    def get_scene(self, scene_info: SceneMetadata) -> Scene:
        """Materialize the full Scene for a previously-listed SceneMetadata."""
        return self._create_scene(scene_info.name, scene_info.raw_data_idx)

    def _create_scene(self, scenario_name: str, data_idx: int) -> Scene:
        """Construct a Scene object for one scenario.

        Test-split scenarios only contain the observed prefix of a scenario,
        so their timestep count differs from train/val.
        """
        split = scenario_name_to_split(scenario_name)
        if split == "test":
            num_timesteps = AV2_SCENARIO_OBS_TIMESTEPS
        else:
            num_timesteps = AV2_SCENARIO_TOTAL_TIMESTEPS

        return Scene(
            env_metadata=self.metadata,
            name=scenario_name,
            location=scenario_name,
            data_split=split,
            length_timesteps=num_timesteps,
            raw_data_idx=data_idx,
            data_access_info=None,
        )

    def get_agent_info(
        self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache]
    ) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]:
        """
        Get frame-level information from source dataset, caching it
        to cache_path.
        Always called after cache_maps, can load map if needed
        to associate map information to positions.
        """
        scenario = self.dataset_obj.load_scenario(scene.name)

        agents: List[AgentMetadata] = []
        presence: List[List[AgentMetadata]] = [[] for _ in scenario.timestamps_ns]
        rows = []

        for track in scenario.tracks:
            meta = get_track_metadata(track)
            if meta is None:
                # Tracks without usable metadata are dropped entirely.
                continue
            agents.append(meta)

            for state in track.object_states:
                presence[int(state.timestep)].append(meta)
                rows.append(
                    {
                        "agent_id": meta.name,
                        "scene_ts": state.timestep,
                        "x": state.position[0],
                        "y": state.position[1],
                        # Source states are planar; z is fixed at ground level.
                        "z": 0.0,
                        "vx": state.velocity[0],
                        "vy": state.velocity[1],
                        "heading": state.heading,
                    }
                )

        df = pd.DataFrame.from_records(rows).set_index(["agent_id", "scene_ts"])
        df.sort_index(inplace=True)

        # Finite-difference velocities per agent to obtain accelerations.
        vel_diff = arr_utils.agent_aware_diff(
            df[["vx", "vy"]].to_numpy(), df.index.get_level_values(0)
        )
        df[["ax", "ay"]] = vel_diff / AV2_DT

        cache_class.save_agent_data(df, cache_path, scene)
        return agents, presence

    def cache_maps(
        self,
        cache_path: Path,
        map_cache_class: Type[SceneCache],
        map_params: Dict[str, Any],
    ) -> None:
        """
        Get static, scene-level info from the source dataset, caching it
        to cache_path. (Primarily this is info needed to construct VectorMap)
        Resolution is in pixels per meter.
        """
        progress = tqdm.tqdm(
            self.dataset_obj.scenario_names,
            desc=f"{self.name} cache maps",
            dynamic_ncols=True,
        )
        for scenario_name in progress:
            scenario_map = self.dataset_obj.load_map(scenario_name)
            vector_map = av2_map_to_vector_map(
                f"{self.name}:{scenario_name}", scenario_map
            )
            map_cache_class.finalize_and_cache_map(cache_path, vector_map, map_params)
Loading

0 comments on commit 32d071d

Please sign in to comment.