Skip to content

Commit

Permalink
Adding quality of life improvements, map area querying, fixing nuPlan lane ID format, and the ability to cache the data index.
Browse files Browse the repository at this point in the history
  • Loading branch information
BorisIvanovic committed Jan 25, 2024
1 parent 748b8b1 commit c1a9499
Show file tree
Hide file tree
Showing 27 changed files with 801 additions and 159 deletions.
1 change: 1 addition & 0 deletions examples/preprocess_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# @profile
def main():
dataset = UnifiedDataset(
# TODO([email protected]) Remove lyft from default examples
desired_data=["nusc_mini", "lyft_sample", "nuplan_mini"],
rebuild_maps=True,
data_dirs={ # Remember to change this to match your filesystem!
Expand Down
39 changes: 37 additions & 2 deletions examples/simple_map_api_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import numpy as np

from trajdata import MapAPI, VectorMap
from trajdata.maps.vec_map_elements import MapElementType
from trajdata.utils import map_utils


def main():
Expand All @@ -23,12 +25,16 @@ def main():
}

start = time.perf_counter()
vec_map: VectorMap = map_api.get_map(f"{env_name}:{random_location_dict[env_name]}")
vec_map: VectorMap = map_api.get_map(
f"{env_name}:{random_location_dict[env_name]}", incl_road_areas=True
)
end = time.perf_counter()
print(f"Map loading took {(end - start)*1000:.2f} ms")

start = time.perf_counter()
vec_map: VectorMap = map_api.get_map(f"{env_name}:{random_location_dict[env_name]}")
vec_map: VectorMap = map_api.get_map(
f"{env_name}:{random_location_dict[env_name]}", incl_road_areas=True
)
end = time.perf_counter()
print(f"Repeated (cached in memory) map loading took {(end - start)*1000:.2f} ms")

Expand Down Expand Up @@ -64,6 +70,35 @@ def main():
end = time.perf_counter()
print(f"Lane visualization took {(end - start)*1000:.2f} ms")

point = vec_map.lanes[lane_idx].center.xyz[0, :]
point_raster = map_utils.transform_points(
point[None, :], transf_mat=raster_from_world
)
ax.scatter(point_raster[:, 0], point_raster[:, 1])

print("Getting nearest road area...")
start = time.perf_counter()
area = vec_map.get_closest_area(point, elem_type=MapElementType.ROAD_AREA)
end = time.perf_counter()
print(f"Getting nearest area took {(end-start)*1000:.2f} ms")

raster_pts = map_utils.transform_points(area.exterior_polygon.xy, raster_from_world)
ax.fill(raster_pts[:, 0], raster_pts[:, 1], alpha=1.0, color="C0")

print("Getting road areas within 100m...")
start = time.perf_counter()
areas = vec_map.get_areas_within(
point, elem_type=MapElementType.ROAD_AREA, dist=100.0
)
end = time.perf_counter()
print(f"Getting areas within took {(end-start)*1000:.2f} ms")

for area in areas:
raster_pts = map_utils.transform_points(
area.exterior_polygon.xy, raster_from_world
)
ax.fill(raster_pts[:, 0], raster_pts[:, 1], alpha=0.2, color="C1")

ax.axis("equal")
ax.grid(None)

Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ classifiers = [
"License :: OSI Approved :: Apache Software License",
]
name = "trajdata"
version = "1.3.3"
version = "1.4.0"
authors = [{ name = "Boris Ivanovic", email = "[email protected]" }]
description = "A unified interface to many trajectory forecasting datasets."
readme = "README.md"
Expand All @@ -33,7 +33,8 @@ dependencies = [
"geopandas>=0.13.2",
"protobuf==3.19.4",
"scipy>=1.9.0",
"opencv-python>=4.5.0"
"opencv-python>=4.5.0",
"shapely>=2.0.0",
]

[project.optional-dependencies]
Expand Down
63 changes: 60 additions & 3 deletions src/trajdata/caching/df_cache.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from decimal import Decimal
from typing import TYPE_CHECKING

Expand All @@ -10,6 +11,7 @@
VectorMap,
)
from trajdata.maps.map_kdtree import MapElementKDTree
from trajdata.maps.map_strtree import MapElementSTRTree

import pickle
from math import ceil, floor
Expand Down Expand Up @@ -654,7 +656,7 @@ def is_traffic_light_data_cached(self, desired_dt: Optional[float] = None) -> bo

def get_traffic_light_status_dict(
self, desired_dt: Optional[float] = None
) -> Dict[Tuple[int, int], TrafficLightStatus]:
) -> Dict[Tuple[str, int], TrafficLightStatus]:
"""
Returns dict mapping Lane Id, scene_ts to traffic light status for the
particular scene. If data doesn't exist for the current dt, interpolates and
Expand Down Expand Up @@ -704,18 +706,20 @@ def are_maps_cached(cache_path: Path, env_name: str) -> bool:
@staticmethod
def get_map_paths(
cache_path: Path, env_name: str, map_name: str, resolution: float
) -> Tuple[Path, Path, Path, Path, Path]:
) -> Tuple[Path, Path, Path, Path, Path, Path]:
maps_path: Path = DataFrameCache.get_maps_path(cache_path, env_name)

vector_map_path: Path = maps_path / f"{map_name}.pb"
kdtrees_path: Path = maps_path / f"{map_name}_kdtrees.dill"
rtrees_path: Path = maps_path / f"{map_name}_rtrees.dill"
raster_map_path: Path = maps_path / f"{map_name}_{resolution:.2f}px_m.zarr"
raster_metadata_path: Path = maps_path / f"{map_name}_{resolution:.2f}px_m.dill"

return (
maps_path,
vector_map_path,
kdtrees_path,
rtrees_path,
raster_map_path,
raster_metadata_path,
)
Expand All @@ -728,13 +732,19 @@ def is_map_cached(
maps_path,
vector_map_path,
kdtrees_path,
rtrees_path,
raster_map_path,
raster_metadata_path,
) = DataFrameCache.get_map_paths(cache_path, env_name, map_name, resolution)

# TODO(bivanovic): For now, rtrees are optional to have in the cache.
# In the future, they may be required (likely after we develop an
# incremental caching scheme or similar to handle additions like these).
return (
maps_path.exists()
and vector_map_path.exists()
and kdtrees_path.exists()
# and rtrees_path.exists()
and raster_metadata_path.exists()
and raster_map_path.exists()
)
Expand All @@ -751,6 +761,7 @@ def finalize_and_cache_map(
maps_path,
vector_map_path,
kdtrees_path,
rtrees_path,
raster_map_path,
raster_metadata_path,
) = DataFrameCache.get_map_paths(
Expand All @@ -775,6 +786,10 @@ def finalize_and_cache_map(
with open(kdtrees_path, "wb") as f:
dill.dump(vector_map.search_kdtrees, f)

# Saving precomputed map element rtrees.
with open(rtrees_path, "wb") as f:
dill.dump(vector_map.search_rtrees, f)

# Saving the rasterized map data.
zarr.save(raster_map_path, rasterized_map.data)

Expand Down Expand Up @@ -814,7 +829,7 @@ def pad_map_patch(
return np.pad(patch, [(0, 0), (pad_top, pad_bot), (pad_left, pad_right)])

def load_kdtrees(self) -> Dict[str, MapElementKDTree]:
_, _, kdtrees_path, _, _ = DataFrameCache.get_map_paths(
_, _, kdtrees_path, _, _, _ = DataFrameCache.get_map_paths(
self.path, self.scene.env_name, self.scene.location, 0.0
)

Expand All @@ -840,6 +855,47 @@ def get_kdtrees(self, load_only_once: bool = True):
else:
return self._kdtrees

def load_rtrees(self) -> MapElementSTRTree:
_, _, _, rtrees_path, _, _ = DataFrameCache.get_map_paths(
self.path, self.scene.env_name, self.scene.location, 0.0
)

if not rtrees_path.exists():
warnings.warn(
(
"Trying to load cached RTree encoding 2D Map elements, "
f"but {rtrees_path} does not exist. Earlier versions of "
"trajdata did not build and cache this RTree. If area queries "
"are needed, please rebuild the map cache (see "
"examples/preprocess_maps.py for an example of how to do this). "
"Otherwise, please ignore this warning."
),
UserWarning,
)
return None

with open(rtrees_path, "rb") as f:
rtrees: MapElementSTRTree = dill.load(f)

return rtrees

def get_rtrees(self, load_only_once: bool = True):
"""Loads and returns the rtrees object from the cache file.
Args:
load_only_once (bool): store the kdtree dictionary in self so that we
dont have to load it from the cache file more than once.
"""
if self._rtrees is None:
rtrees = self.load_rtrees()
if load_only_once:
self._rtrees = rtrees

return rtrees

else:
return self._rtrees

def load_map_patch(
self,
world_x: float,
Expand All @@ -856,6 +912,7 @@ def load_map_patch(
maps_path,
_,
_,
_,
raster_map_path,
raster_metadata_path,
) = DataFrameCache.get_map_paths(
Expand Down
2 changes: 1 addition & 1 deletion src/trajdata/caching/scene_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def is_traffic_light_data_cached(self, desired_dt: Optional[float] = None) -> bo

def get_traffic_light_status_dict(
self, desired_dt: Optional[float] = None
) -> Dict[Tuple[int, int], TrafficLightStatus]:
) -> Dict[Tuple[str, int], TrafficLightStatus]:
"""Returns lookup table for traffic light status in the current scene
lane_id, scene_ts -> TrafficLightStatus"""
raise NotImplementedError()
Expand Down
3 changes: 0 additions & 3 deletions src/trajdata/data_structures/batch_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,8 +441,6 @@ def distance_limit(agent_types: np.ndarray, target_type: int) -> np.ndarray:
**vector_map_params if vector_map_params is not None else None,
)

self.scene_id = scene_time.scene.name

### ROBOT DATA ###
self.robot_future_np: Optional[StateArray] = None

Expand Down Expand Up @@ -506,7 +504,6 @@ def get_agents_future(
future_sec: Tuple[Optional[float], Optional[float]],
nearby_agents: List[AgentMetadata],
) -> Tuple[List[StateArray], List[np.ndarray], np.ndarray]:

(
agent_futures,
agent_future_extents,
Expand Down
1 change: 0 additions & 1 deletion src/trajdata/data_structures/collation.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,6 @@ def raster_map_collate_fn_scene(
max_agent_num: Optional[int] = None,
pad_value: Any = np.nan,
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:

if batch_elems[0].map_patches is None:
return None, None, None, None

Expand Down
4 changes: 4 additions & 0 deletions src/trajdata/data_structures/scene_tag.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from typing import Set, Tuple


Expand All @@ -8,6 +9,9 @@ def __init__(self, tag_tuple: Tuple[str, ...]) -> None:
def contains(self, query: Set[str]) -> bool:
return query.issubset(self._tag_tuple)

def matches_any(self, regex: re.Pattern) -> bool:
return any(regex.search(x) is not None for x in self._tag_tuple)

def __contains__(self, item) -> bool:
return item in self._tag_tuple

Expand Down
Loading

0 comments on commit c1a9499

Please sign in to comment.