Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implemented label in cluster estimation #231

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 115 additions & 17 deletions modules/cluster_estimation/cluster_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,22 @@ class ClusterEstimation:
METHODS
-------
run()
Take in list of landing pad detections and return list of estimated landing pad locations
Take in list of object detections and return list of estimated object locations
if number of detections is sufficient, or if manually forced to run.

cluster_by_label()
Take in list of detections of the same label and return list of estimated object locations
of the same label.

__decide_to_run()
Decide when to run cluster estimation model.

__sort_by_weights()
Sort input model output list by weights in descending order.

__sort_by_labels()
Sort input detection list by labels in descending order.

__convert_detections_to_point()
Convert DetectionInWorld input object to a [x,y] position to store.

Expand Down Expand Up @@ -131,10 +138,12 @@ def __init__(
self.__logger = local_logger

def run(
self, detections: "list[detection_in_world.DetectionInWorld]", run_override: bool
self,
detections: "list[detection_in_world.DetectionInWorld]",
run_override: bool,
) -> "tuple[bool, list[object_in_world.ObjectInWorld] | None]":
"""
Take in list of landing pad detections and return list of estimated landing pad locations
Take in list of detections and return list of estimated object locations
if number of detections is sufficient, or if manually forced to run.

PARAMETERS
Expand All @@ -159,24 +168,95 @@ def run(
# Store new input data
self.__current_bucket += self.__convert_detections_to_point(detections)

# Decide to run
if not self.__decide_to_run(run_override):
return False, None

# sort bucket by label in descending order
self.__all_points = self.__sort_by_labels(self.__current_bucket)
detections_in_world = []

# init search parameters
ptr = 0

# itterates through all points
while ptr <= len(self.__current_bucket):
# reference label
label = self.__current_bucket[ptr][2]

# creates bucket of points with the same label since bucket is sorted by label
bucket_labelled = []
while ptr < len(self.__current_bucket) and self.__all_points[ptr][2] == label:
bucket_labelled.append([self.__all_points[ptr]])
ptr += 1

# skip if no objects have label=label
if len(bucket_labelled) == 0:
continue

result, labelled_detections_in_world = self.cluster_by_label(
bucket_labelled, run_override, label
)

# checks if cluster_by_label ran succssfully
if not result:
self.__logger.warning(
f"did not add objects of label={label} to total object detections"
)
continue

detections_in_world += labelled_detections_in_world

return True, detections_in_world

def cluster_by_label(
self,
points: "list[tuple[float, float, int]]",
run_override: bool,
label: int,
) -> "tuple[bool, list[object_in_world.ObjectInWorld] | None]":
"""
Take in list of detections of the same label and return list of estimated object locations
of the same label.

PARAMETERS
----------
points: list[tuple[float, float, int]]
List containing tuple objects which holds real-world positioning data to run
clustering on and their labels

run_override: bool
Forces ClusterEstimation to predict if data is available, regardless of any other
requirements.

RETURNS
-------
model_ran: bool
True if ClusterEstimation object successfully ran its estimation model, False otherwise.

objects_in_world: list[ObjectInWorld] or None.
List containing ObjectInWorld objects, containing position and covariance value.
None if conditions not met and model not ran or model failed to converge.
"""

# Decide to run
if not self.__decide_to_run(run_override):
return False, None

# Fit points and get cluster data
self.__vgmm = self.__vgmm.fit(self.__all_points) # type: ignore
__vgmm_label = self.__vgmm.fit(points) # type: ignore

# Check convergence
if not self.__vgmm.converged_:
self.__logger.warning("Model failed to converge")
if not __vgmm_label.converged_:
self.__logger.warning(f"Model for label={label} failed to converge")
return False, None

# Get predictions from cluster model
model_output: "list[tuple[np.ndarray, float, float]]" = list(
zip(
self.__vgmm.means_, # type: ignore
self.__vgmm.weights_, # type: ignore
self.__vgmm.covariances_, # type: ignore
__vgmm_label.means_, # type: ignore
__vgmm_label.weights_, # type: ignore
__vgmm_label.covariances_, # type: ignore
)
)

Expand All @@ -203,9 +283,7 @@ def run(
detections_in_world = []
for cluster in model_output:
result, landing_pad = object_in_world.ObjectInWorld.create(
cluster[0][0],
cluster[0][1],
cluster[2],
cluster[0][0], cluster[0][1], cluster[2], label
)

if result:
Expand Down Expand Up @@ -274,12 +352,32 @@ def __sort_by_weights(
"""
return sorted(model_output, key=lambda x: x[1], reverse=True)

@staticmethod
def __sort_by_labels(
points: "list[tuple[float, float, int]]",
) -> "list[tuple[float, float, int]]":
"""
Sort input detection list by labels in descending order.

PARAMETERS
----------
detections: list[tuple[float, float, int]]
List containing detections, with each element having the format
[x_position, y_position, label].

RETURNS
-------
list[tuple[np.ndarray, float, float]]
List containing detection points sorted in descending order by label
"""
return sorted(points, key=lambda x: x.label, reverse=True)

@staticmethod
def __convert_detections_to_point(
detections: "list[detection_in_world.DetectionInWorld]",
) -> "list[tuple[float, float]]":
) -> "list[tuple[float, float, int]]":
"""
Convert DetectionInWorld input object to a list of points- (x,y) positions, to store.
Convert DetectionInWorld input object to a list of points- (x,y) positions with label, to store.

PARAMETERS
----------
Expand All @@ -289,8 +387,8 @@ def __convert_detections_to_point(

RETURNS
-------
points: list[tuple[float, float]]
List of points (x,y).
points: list[tuple[float, float, int]]
List of points (x,y) and their label
-------
"""
points = []
Expand All @@ -302,7 +400,7 @@ def __convert_detections_to_point(
# Convert DetectionInWorld objects
for detection in detections:
# `centre` attribute holds positioning data
points.append(tuple([detection.centre[0], detection.centre[1]]))
points.append(tuple([detection.centre[0], detection.centre[1], detection.label]))

return points

Expand Down
10 changes: 7 additions & 3 deletions modules/object_in_world.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class ObjectInWorld:

@classmethod
def create(
cls, location_x: float, location_y: float, spherical_variance: float
cls, location_x: float, location_y: float, spherical_variance: float, label: int
) -> "tuple[bool, ObjectInWorld | None]":
"""
location_x, location_y: Location of the object.
Expand All @@ -21,14 +21,17 @@ def create(
if spherical_variance < 0.0:
return False, None

return True, ObjectInWorld(cls.__create_key, location_x, location_y, spherical_variance)
return True, ObjectInWorld(
cls.__create_key, location_x, location_y, spherical_variance, label
)

def __init__(
self,
class_private_create_key: object,
location_x: float,
location_y: float,
spherical_variance: float,
label: int,
) -> None:
"""
Private constructor, use create() method.
Expand All @@ -38,12 +41,13 @@ def __init__(
self.location_x = location_x
self.location_y = location_y
self.spherical_variance = spherical_variance
self.label = label

def __str__(self) -> str:
"""
To string.
"""
return f"{self.__class__}, location_x: {self.location_x}, location_y: {self.location_y}, spherical_variance: {self.spherical_variance}"
return f"{self.__class__}, location_x: {self.location_x}, location_y: {self.location_y}, spherical_variance: {self.spherical_variance}, label: {self.label}"

def __repr__(self) -> str:
"""
Expand Down
14 changes: 9 additions & 5 deletions tests/unit/test_decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def best_pad_within_tolerance() -> object_in_world.ObjectInWorld: # type: ignor
location_x = BEST_PAD_LOCATION_X
location_y = BEST_PAD_LOCATION_Y
spherical_variance = 1.0
result, pad = object_in_world.ObjectInWorld.create(location_x, location_y, spherical_variance)
result, pad = object_in_world.ObjectInWorld.create(
location_x, location_y, spherical_variance, 0
)
assert result
assert pad is not None

Expand All @@ -58,7 +60,9 @@ def best_pad_outside_tolerance() -> object_in_world.ObjectInWorld: # type: igno
location_x = 100.0
location_y = 200.0
spherical_variance = 5.0 # variance outside tolerance
result, pad = object_in_world.ObjectInWorld.create(location_x, location_y, spherical_variance)
result, pad = object_in_world.ObjectInWorld.create(
location_x, location_y, spherical_variance, 0
)
assert result
assert pad is not None

Expand All @@ -70,15 +74,15 @@ def pads() -> "list[object_in_world.ObjectInWorld]": # type: ignore
"""
Create a list of ObjectInWorld instances for the landing pads.
"""
result, pad_1 = object_in_world.ObjectInWorld.create(30.0, 40.0, 2.0)
result, pad_1 = object_in_world.ObjectInWorld.create(30.0, 40.0, 2.0, 0)
assert result
assert pad_1 is not None

result, pad_2 = object_in_world.ObjectInWorld.create(50.0, 60.0, 3.0)
result, pad_2 = object_in_world.ObjectInWorld.create(50.0, 60.0, 3.0, 0)
assert result
assert pad_2 is not None

result, pad_3 = object_in_world.ObjectInWorld.create(70.0, 80.0, 4.0)
result, pad_3 = object_in_world.ObjectInWorld.create(70.0, 80.0, 4.0, 0)
assert result
assert pad_3 is not None

Expand Down
Loading
Loading