From 3a7c740365de186a044751c51d6646a6efcaf116 Mon Sep 17 00:00:00 2001 From: Govind Pimpale Date: Wed, 18 Oct 2023 08:59:46 -0700 Subject: [PATCH] Make TopDownMultiChannel support ScenarioEnv (#498) * works! * add preliminary support for top down view when using a scenario * format * remove extraeous change * draw trajectories * format and fix bug --- .../road_network/edge_road_network.py | 2 +- metadrive/manager/scenario_traffic_manager.py | 5 ++ metadrive/obs/top_down_obs_multi_channel.py | 49 ++++++++++++++----- 3 files changed, 44 insertions(+), 12 deletions(-) diff --git a/metadrive/component/road_network/edge_road_network.py b/metadrive/component/road_network/edge_road_network.py index 49c324a1d..dd8e46a4d 100644 --- a/metadrive/component/road_network/edge_road_network.py +++ b/metadrive/component/road_network/edge_road_network.py @@ -8,7 +8,7 @@ from metadrive.utils.math import get_boxes_bounding_box from metadrive.utils.pg.utils import get_lanes_bounding_box -lane_info = namedtuple("neighbor_lanes", "lane entry_lanes exit_lanes left_lanes right_lanes") +lane_info = namedtuple("edge_lane", ["lane", "entry_lanes", "exit_lanes", "left_lanes", "right_lanes"]) class EdgeRoadNetwork(BaseRoadNetwork): diff --git a/metadrive/manager/scenario_traffic_manager.py b/metadrive/manager/scenario_traffic_manager.py index 5977809e0..75fb8a876 100644 --- a/metadrive/manager/scenario_traffic_manager.py +++ b/metadrive/manager/scenario_traffic_manager.py @@ -6,6 +6,7 @@ from metadrive.component.static_object.traffic_object import TrafficCone, TrafficBarrier from metadrive.component.traffic_participants.cyclist import Cyclist from metadrive.component.traffic_participants.pedestrian import Pedestrian +from metadrive.component.vehicle.base_vehicle import BaseVehicle from metadrive.component.vehicle.vehicle_type import get_vehicle_type, reset_vehicle_type_count from metadrive.constants import DEFAULT_AGENT from metadrive.manager.base_manager import BaseManager @@ -162,6 +163,10 @@ def sdc_object_id(self): def current_scenario_length(self): return self.engine.data_manager.current_scenario_length + @property + def vehicles(self): + return list(self.engine.get_objects(filter=lambda o: isinstance(o, BaseVehicle)).values()) + def spawn_vehicle(self, v_id, track): state = parse_object_state(track, self.episode_step) diff --git a/metadrive/obs/top_down_obs_multi_channel.py b/metadrive/obs/top_down_obs_multi_channel.py index 11d7919c0..2836026cf 100644 --- a/metadrive/obs/top_down_obs_multi_channel.py +++ b/metadrive/obs/top_down_obs_multi_channel.py @@ -5,14 +5,23 @@ import numpy as np from metadrive.component.vehicle.base_vehicle import BaseVehicle +from metadrive.component.traffic_participants.base_traffic_participant import BaseTrafficParticipant +from metadrive.scenario.scenario_description import ScenarioDescription +from metadrive.component.lane.point_lane import PointLane from metadrive.constants import Decoration, DEFAULT_AGENT from metadrive.obs.top_down_obs import TopDownObservation from metadrive.obs.top_down_obs_impl import WorldSurface, COLOR_BLACK, ObjectGraphics, LaneGraphics, \ ObservationWindowMultiChannel from metadrive.utils import import_pygame, clip +from metadrive.component.road_network.node_road_network import NodeRoadNetwork +from metadrive.component.vehicle_navigation_module.node_network_navigation import NodeNetworkNavigation +from metadrive.component.vehicle_navigation_module.edge_network_navigation import EdgeNetworkNavigation +from metadrive.component.vehicle_navigation_module.trajectory_navigation import TrajectoryNavigation + pygame, gfxdraw = import_pygame() COLOR_WHITE = pygame.Color("white") +DEFAULT_TRAJECTORY_LANE_WIDTH = 3 class TopDownMultiChannel(TopDownObservation): @@ -106,16 +115,29 @@ def draw_map(self) -> pygame.Surface: self.canvas_background.move_display_window_to(centering_pos) self.canvas_road_network.move_display_window_to(centering_pos) - # self.draw_navigation(self.canvas_navigation) - self.draw_navigation(self.canvas_background, (64, 64, 64)) + if isinstance(self.target_vehicle.navigation, NodeNetworkNavigation): + self.draw_navigation_node(self.canvas_background, (64, 64, 64)) + elif isinstance(self.target_vehicle.navigation, EdgeNetworkNavigation): + # TODO: draw edge network navigation + pass + elif isinstance(self.target_vehicle.navigation, TrajectoryNavigation): + self.draw_navigation_trajectory(self.canvas_background, (64, 64, 64)) + + if isinstance(self.road_network, NodeRoadNetwork): + for _from in self.road_network.graph.keys(): + decoration = True if _from == Decoration.start else False + for _to in self.road_network.graph[_from].keys(): + for l in self.road_network.graph[_from][_to]: + two_side = True if l is self.road_network.graph[_from][_to][-1] or decoration else False + LaneGraphics.LANE_LINE_WIDTH = 0.5 + LaneGraphics.display(l, self.canvas_background, two_side) + elif hasattr(self.engine, "map_manager"): + for data in self.engine.map_manager.current_map.blocks[-1].map_data.values(): + if ScenarioDescription.POLYLINE in data: + LaneGraphics.display_scenario_line( + data[ScenarioDescription.POLYLINE], data[ScenarioDescription.TYPE], self.canvas_background + ) - for _from in self.road_network.graph.keys(): - decoration = True if _from == Decoration.start else False - for _to in self.road_network.graph[_from].keys(): - for l in self.road_network.graph[_from][_to]: - two_side = True if l is self.road_network.graph[_from][_to][-1] or decoration else False - LaneGraphics.LANE_LINE_WIDTH = 0.5 - LaneGraphics.display(l, self.canvas_background, two_side) self.canvas_road_network.blit(self.canvas_background, (0, 0)) self.obs_window.reset(self.canvas_runtime) self._should_draw_map = False @@ -142,7 +164,8 @@ def draw_scene(self): ego_heading = vehicle.heading_theta ego_heading = ego_heading if abs(ego_heading) > 2 * np.pi / 180 else 0 - for v in self.engine.traffic_manager.vehicles: + for v in self.engine.get_objects(lambda o: isinstance(o, BaseVehicle) or isinstance(o, BaseTrafficParticipant) + ).values(): if v is vehicle: continue h = v.heading_theta @@ -256,13 +279,17 @@ def observe(self, vehicle: BaseVehicle): img = np.clip(img, 0, 255) return np.transpose(img, (1, 0, 2)) - def draw_navigation(self, canvas, color=(128, 128, 128)): + def draw_navigation_node(self, canvas, color=(128, 128, 128)): checkpoints = self.target_vehicle.navigation.checkpoints for i, c in enumerate(checkpoints[:-1]): lanes = self.road_network.graph[c][checkpoints[i + 1]] for lane in lanes: LaneGraphics.draw_drivable_area(lane, canvas, color=color) + def draw_navigation_trajectory(self, canvas, color=(128, 128, 128)): + lane = PointLane(self.target_vehicle.navigation.checkpoints, DEFAULT_TRAJECTORY_LANE_WIDTH) + LaneGraphics.draw_drivable_area(lane, canvas, color=color) + def _get_stack_indices(self, length, frame_skip=None): frame_skip = frame_skip or self.frame_skip num = int(math.ceil(length / frame_skip))