Add highD dataset
stepankonev committed Nov 20, 2024
1 parent a2b5a7e commit 5e3f06c
Showing 6 changed files with 313 additions and 2 deletions.
24 changes: 24 additions & 0 deletions src/trajdata/caching/df_cache.py
@@ -749,6 +749,30 @@ def is_map_cached(
and raster_map_path.exists()
)

@staticmethod
def cache_raster_map(
env_name: str,
data_idx: str,
cache_path: Path,
raster_map: np.ndarray, # RasterizedMap,
raster_metadata: RasterizedMapMetadata,
map_params: Dict[str, Any],
) -> None:
raster_resolution: float = map_params["px_per_m"]
maps_path: Path = DataFrameCache.get_maps_path(cache_path, env_name)
raster_map_path: Path = (
maps_path / f"{int(data_idx)+1}_{raster_resolution:.2f}px_m.zarr"
)
raster_metadata_path: Path = (
maps_path / f"{int(data_idx)+1}_{raster_resolution:.2f}px_m.dill"
)

maps_path.mkdir(parents=True, exist_ok=True)
zarr.save(raster_map_path, raster_map)

with open(raster_metadata_path, "wb") as f:
dill.dump(raster_metadata, f)

@staticmethod
def finalize_and_cache_map(
cache_path: Path,
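The new method persists the raster as a zarr array alongside a dill-pickled metadata object, keyed by the 1-based scene number and resolution. A minimal round-trip sketch for reading the cache back (the cache location and layout here are assumptions for illustration; `DataFrameCache.get_maps_path` determines the real paths):

from pathlib import Path

import dill
import zarr

# Hypothetical cache layout and values, mirroring the format strings above.
maps_path = Path("~/.unified_data_cache/highD/maps").expanduser()
data_idx, px_per_m = 0, 2.0

# e.g. "1_2.00px_m.zarr", matching cache_raster_map's naming scheme.
raster_map = zarr.load(str(maps_path / f"{data_idx + 1}_{px_per_m:.2f}px_m.zarr"))
with open(maps_path / f"{data_idx + 1}_{px_per_m:.2f}px_m.dill", "rb") as f:
    raster_metadata = dill.load(f)  # a RasterizedMapMetadata object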
1 change: 1 addition & 0 deletions src/trajdata/dataset_specific/highD/__init__.py
@@ -0,0 +1 @@
from .highd_dataset import HighDDataset
272 changes: 272 additions & 0 deletions src/trajdata/dataset_specific/highD/highd_dataset.py
@@ -0,0 +1,272 @@
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Any, Dict, Final, List, Optional, Tuple, Type
from trajdata.dataset_specific.raw_dataset import RawDataset
from trajdata.data_structures.environment import EnvMetadata
from trajdata.data_structures import (
AgentMetadata,
EnvMetadata,
Scene,
SceneMetadata,
SceneTag,
)
from trajdata.caching import EnvCache, SceneCache
from trajdata.dataset_specific.scene_records import HighDRecord
from trajdata.caching.df_cache import STATE_COLS, EXTENT_COLS
from trajdata.data_structures.agent import (
AgentType,
VariableExtent,
)
from trajdata.maps import RasterizedMapMetadata
from tqdm import tqdm
import cv2
import math


HIGHD_DT: Final[float] = 0.04
HIGHD_NUM_SCENES: Final[int] = 60
HIGHD_ENV_NAME: Final[str] = "highD"
HIGHD_SPLIT_NAME: Final[str] = "all"
# Scaling factor for the highD raster map (effectively meters per pixel of the background image), see
# https://github.com/RobertKrajewski/highD-dataset/blob/master/Python/src/visualization/visualize_frame.py#L151-L152
HIGHD_PX_PER_M: Final[float] = 0.40424
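# For illustration: at map_params["px_per_m"] = 2.0, a 4096-px-wide background
# image is resized to ceil(0.40424 * 2.0 * 4096) = 3312 px in cache_map below
# (illustrative numbers, not taken from the dataset).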


class HighDDataset(RawDataset):
def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata:
if env_name != HIGHD_ENV_NAME:
raise ValueError(f"Invalid environment name: {env_name}")
dataset_parts = [(HIGHD_SPLIT_NAME,)]
scene_split_map = {
str(scene_id): HIGHD_SPLIT_NAME
for scene_id in range(1, HIGHD_NUM_SCENES + 1)
}
return EnvMetadata(
name=env_name,
data_dir=data_dir,
dt=HIGHD_DT,
parts=dataset_parts,
scene_split_map=scene_split_map,
)

def load_dataset_obj(self, verbose: bool = False) -> None:
if verbose:
print(f"Loading {self.name} dataset...", flush=True)
self.dataset_obj: Dict[int, Dict[str, Any]] = dict()
for scene_id in tqdm(range(1, HIGHD_NUM_SCENES + 1)):
raw_data_idx = scene_id - 1
scene_id_str = str(scene_id).zfill(2)
tracks_metadata = pd.read_csv(
Path(self.metadata.data_dir) / f"{scene_id_str}_tracksMeta.csv"
)
tracks_metadata["id"] = tracks_metadata["id"].astype(str)
tracks_data = pd.read_csv(
Path(self.metadata.data_dir) / f"{scene_id_str}_tracks.csv"
)
tracks_data["id"] = tracks_data["id"].astype(str)
tracks_data = tracks_data.merge(
tracks_metadata[["id", "numFrames"]], on="id"
)
tracks_metadata.set_index("id", inplace=True)
tracks_data = tracks_data[tracks_data["numFrames"] > 1].reset_index(
drop=True
)
tracks_data["z"] = np.zeros_like(tracks_data["x"])
            # Regarding width -> length and height -> width, please see
            # "Track Meta Information" in
            # https://levelxdata.com/wp-content/uploads/2023/10/highD-Format.pdf
tracks_data.rename(
columns={
"frame": "scene_ts",
"id": "agent_id",
"width": "length",
"height": "width",
"xVelocity": "vx",
"yVelocity": "vy",
"xAcceleration": "ax",
"yAcceleration": "ay",
},
inplace=True,
)
# Originally in the data:
# The x position of the upper left corner of the vehicle's bounding box.
tracks_data["x"] = tracks_data["x"] + tracks_data["length"] / 2
tracks_data["y"] = tracks_data["y"] + tracks_data["width"] / 2
tracks_data["heading"] = np.arctan2(tracks_data["vy"], tracks_data["vx"])
# agent_id -> {scene_id}_{agent_id}
tracks_data["agent_id"] = tracks_data["agent_id"].apply(
lambda x: f"{scene_id_str}_{x}"
)
# "height" is unavailable in the HighD dataset
index_cols = ["agent_id", "scene_ts"]
tracks_data = tracks_data[
["heading"] + STATE_COLS + EXTENT_COLS[:-1] + index_cols
]
tracks_data.set_index(["agent_id", "scene_ts"], inplace=True)
tracks_data.sort_index(inplace=True)
tracks_data.reset_index(level=1, inplace=True)
scene_data = (
pd.read_csv(
Path(self.metadata.data_dir) / f"{scene_id_str}_recordingMeta.csv"
)
.iloc[0]
.to_dict()
)
self.dataset_obj[raw_data_idx] = {
"scene_id": scene_id,
"tracks_data": tracks_data,
"scene_data": scene_data,
"tracks_metadata": tracks_metadata,
}

def _get_location_from_scene_info(self, scene_info: Dict) -> str:
return str(scene_info["scene_id"])

def _get_matching_scenes_from_obj(
self,
scene_tag: SceneTag,
scene_desc_contains: Optional[List[str]],
env_cache: EnvCache,
) -> List[SceneMetadata]:
all_scenes_list: List[HighDRecord] = list()
scenes_list: List[SceneMetadata] = list()
for raw_data_idx, scene_info in self.dataset_obj.items():
scene_id = raw_data_idx + 1
scene_location = self._get_location_from_scene_info(scene_info)
scene_length: int = scene_info["tracks_data"]["scene_ts"].max().item() + 1
all_scenes_list.append(
HighDRecord(raw_data_idx, scene_length, scene_location)
)
scene_metadata = SceneMetadata(
env_name=self.metadata.name,
name=str(scene_id),
dt=self.metadata.dt,
raw_data_idx=raw_data_idx,
)
scenes_list.append(scene_metadata)
self.cache_all_scenes_list(env_cache, all_scenes_list)
return scenes_list

def _get_matching_scenes_from_cache(
self,
scene_tag: SceneTag,
scene_desc_contains: Optional[List[str]],
env_cache: EnvCache,
) -> List[Scene]:
all_scenes_list: List[HighDRecord] = env_cache.load_env_scenes_list(self.name)
scenes_list: List[Scene] = list()
for scene_record in all_scenes_list:
data_idx, scene_length, scene_location = scene_record
scene_id = data_idx + 1
scene_metadata = Scene(
self.metadata,
str(scene_id),
scene_location,
HIGHD_SPLIT_NAME,
scene_length,
data_idx,
None,
)
scenes_list.append(scene_metadata)
return scenes_list

def get_scene(self, scene_info: SceneMetadata) -> Scene:
_, scene_name, _, data_idx = scene_info
scene_data: pd.DataFrame = self.dataset_obj[data_idx]["tracks_data"]
scene_location: str = self._get_location_from_scene_info(
self.dataset_obj[data_idx]
)
scene_split: str = self.metadata.scene_split_map[scene_name]
scene_length: int = scene_data["scene_ts"].max().item() + 1
return Scene(
self.metadata,
scene_name,
scene_location,
scene_split,
scene_length,
data_idx,
None,
)

def get_agent_info(
self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache]
) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]:
scene_data: pd.DataFrame = self.dataset_obj[scene.raw_data_idx][
"tracks_data"
].copy()
agent_list: List[AgentMetadata] = list()
agent_presence: List[List[AgentMetadata]] = [
[] for _ in range(scene.length_timesteps)
]
for agent_id, frames in scene_data.groupby("agent_id")["scene_ts"]:
start_frame: int = frames.iat[0].item()
last_frame: int = frames.iat[-1].item()
agent_metadata = self.dataset_obj[scene.raw_data_idx][
"tracks_metadata"
].loc[agent_id.split("_")[1]]
assert start_frame == agent_metadata["initialFrame"]
assert last_frame == agent_metadata["finalFrame"]
agent_info = AgentMetadata(
name=str(agent_id),
agent_type=AgentType.VEHICLE,
first_timestep=start_frame,
last_timestep=last_frame,
extent=VariableExtent(),
)
agent_list.append(agent_info)
for frame in frames:
agent_presence[frame].append(agent_info)
cache_class.save_agent_data(
scene_data,
cache_path,
scene,
)
return agent_list, agent_presence

def cache_map(
self,
data_idx: int,
cache_path: Path,
map_cache_class: Type[SceneCache],
map_params: Dict[str, Any],
) -> None:
env_name = self.metadata.name
resolution = map_params["px_per_m"]
raster_map = (
            cv2.imread(
                # cv2.imread expects a str path, not a pathlib.Path.
                str(
                    Path(self.metadata.data_dir)
                    / f"{str(data_idx + 1).zfill(2)}_highway.png"
                )
            ).astype(np.float32)
/ 255.0
)
raster_map = cv2.resize(
raster_map,
(
math.ceil(HIGHD_PX_PER_M * resolution * raster_map.shape[1]),
math.ceil(HIGHD_PX_PER_M * resolution * raster_map.shape[0]),
),
interpolation=cv2.INTER_AREA,
).transpose(2, 0, 1)
raster_from_world = np.eye(3)
raster_from_world[:2, :2] *= resolution
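        # Pure scaling (world meters -> map pixels) with no translation, i.e.
        # this assumes highD track coordinates share the background image's origin.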
raster_metadata = RasterizedMapMetadata(
name=f"{data_idx + 1}_map",
shape=raster_map.shape,
layers=["road", "lane", "shoulder"],
layer_rgb_groups=([0], [1], [2]),
resolution=map_params["px_per_m"],
map_from_world=raster_from_world,
)
map_cache_class.cache_raster_map(
env_name, str(data_idx), cache_path, raster_map, raster_metadata, map_params
)

def cache_maps(
self,
cache_path: Path,
map_cache_class: Type[SceneCache],
map_params: Dict[str, Any],
):
for data_idx in range(HIGHD_NUM_SCENES):
self.cache_map(data_idx, cache_path, map_cache_class, map_params)
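With the loader wired up in env_utils.py below, highD scenes flow through trajdata's usual entry point. A minimal usage sketch, assuming the standard UnifiedDataset API and a local copy of the highD CSVs (the path is a placeholder):

from trajdata import UnifiedDataset

# "/path/to/highD/data" is a placeholder for a local highD download.
dataset = UnifiedDataset(
    desired_data=["highD"],
    data_dirs={"highD": "/path/to/highD/data"},
    desired_dt=0.04,  # the native HIGHD_DT, so no resampling occurs
)
print(f"Data samples: {len(dataset):,}")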
6 changes: 6 additions & 0 deletions src/trajdata/dataset_specific/scene_records.py
@@ -1,6 +1,12 @@
from typing import NamedTuple


class HighDRecord(NamedTuple):
data_idx: int
length: int
location: str


class Argoverse2Record(NamedTuple):
name: str
data_idx: int
5 changes: 5 additions & 0 deletions src/trajdata/utils/env_utils.py
@@ -50,6 +50,11 @@ def get_raw_dataset(dataset_name: str, data_dir: str) -> RawDataset:

return Av2Dataset(dataset_name, data_dir, parallelizable=True, has_maps=True)

if "highD" in dataset_name:
from trajdata.dataset_specific.highD import HighDDataset

return HighDDataset(dataset_name, data_dir, parallelizable=True, has_maps=True)

raise ValueError(f"Dataset with name '{dataset_name}' is not supported")


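For reference, any dataset name containing "highD" now resolves to the new loader via the hook above; a quick sketch (the path is a placeholder):

from trajdata.utils.env_utils import get_raw_dataset

highd = get_raw_dataset("highD", "/path/to/highD/data")  # -> HighDDataset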
7 changes: 5 additions & 2 deletions src/trajdata/visualization/vis.py
@@ -103,7 +103,9 @@ def draw_map(
):
patch_size: int = map.shape[-1]
map_array = RasterizedMap.to_img(map.cpu())
brightened_map_array = map_array * 0.2 + 0.8
if alpha > 1.0 or alpha < 0.0:
raise ValueError("alpha must be between 0 and 1")
brightened_map_array = map_array * alpha + (1 - alpha)

im = ax.imshow(
brightened_map_array,
@@ -242,6 +244,7 @@ def plot_agent_batch(
legend: bool = True,
show: bool = True,
close: bool = True,
alpha: float = 0.2,
) -> None:
if ax is None:
_, ax = plt.subplots()
@@ -262,7 +265,7 @@

agent_from_raster_tf: Tensor = agent_from_world_tf @ world_from_raster_tf

draw_map(ax, batch.maps[batch_idx], agent_from_raster_tf, alpha=1.0)
draw_map(ax, batch.maps[batch_idx], agent_from_raster_tf, alpha)

agent_hist = batch.agent_hist[batch_idx].cpu()
agent_fut = batch.agent_fut[batch_idx].cpu()
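With the blend rewritten as map_array * alpha + (1 - alpha), the new alpha parameter controls how far the raster is washed out toward white: the default of 0.2 reproduces the previously hardcoded brightening, while 1.0 shows the raw map colors. A usage sketch, assuming an AgentBatch from a trajdata dataloader is in scope:

from trajdata.visualization.vis import plot_agent_batch

# batch: an AgentBatch produced by a trajdata dataloader (assumed available).
plot_agent_batch(batch, batch_idx=0, alpha=0.2)  # default, heavily brightened map
plot_agent_batch(batch, batch_idx=0, alpha=1.0)  # raw map colors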
