Skip to content

Commit

Permalink
Fix bug in add_ncu (#196)
Browse files Browse the repository at this point in the history
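Moves the duplicate-column (overwrite) check to after the NCU report has been read, so that columns already present in the Thicket dataframe are also detected when chosen_metrics=None (i.e., when all metrics from the report are used). Also adds a TypeError when chosen_metrics is neither a list nor None.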
  • Loading branch information
michaelmckinsey1 authored Oct 28, 2024
1 parent a8387ce commit fba7235
Showing 1 changed file with 29 additions and 19 deletions.
48 changes: 29 additions & 19 deletions thicket/thicket.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def add_ncu(self, ncu_report_mapping, chosen_metrics=None, overwrite=False):
Arguments:
ncu_report_mapping (dict): mapping from NCU report file to profile
chosen_metrics (list): list of metrics to sub-select from NCU report
chosen_metrics (list): list of metrics to sub-select from NCU report. By default, all metrics are used.
overwrite (bool): whether to overwrite existing columns in the Thicket.DataFrame
"""

Expand All @@ -579,24 +579,25 @@ def _rep_agg_func(col):
else:
return col[0]

# Remove duplicate metrics in chosen_metrics if the user provided duplicates
unique_metrics = list(set(chosen_metrics))
if len(unique_metrics) != len(chosen_metrics):
dupe_mets = [
met
for met, count in collections.Counter(chosen_metrics).items()
if count > 1
]
warnings.warn(f"Removing duplicate metrics in chosen_metrics: {dupe_mets}")
chosen_metrics = unique_metrics

# Check if chosen_metrics are in the dataframe
dupe_cols = [col for col in chosen_metrics if col in self.dataframe.columns]
if overwrite:
self.dataframe = self.dataframe.drop(columns=dupe_cols)
elif not overwrite and len(dupe_cols) > 0:
raise ValueError(
f"Columns {dupe_cols} already exist in the performance data table. Set overwrite=True to overwrite."
# If list, check for duplicate metrics
if isinstance(chosen_metrics, list):
# Remove duplicate metrics in chosen_metrics if the user provided duplicates
unique_metrics = list(set(chosen_metrics))
if len(unique_metrics) != len(chosen_metrics):
dupe_mets = [
met
for met, count in collections.Counter(chosen_metrics).items()
if count > 1
]
warnings.warn(
f"Removing duplicate metrics in chosen_metrics: {dupe_mets}"
)
chosen_metrics = unique_metrics
elif chosen_metrics is None:
pass
else:
raise TypeError(
f"If provided, chosen_metrics ({type(chosen_metrics)}) must be a list"
)

# Initialize reader
Expand All @@ -622,6 +623,15 @@ def _rep_agg_func(col):
if chosen_metrics:
ncu_df = ncu_df[chosen_metrics]

# Overwrite check. We can't check earlier, if chosen_metrics is None, as we haven't read the ncu file yet.
dupe_cols = [col for col in ncu_df.columns if col in self.dataframe.columns]
if overwrite:
self.dataframe = self.dataframe.drop(columns=dupe_cols)
elif not overwrite and len(dupe_cols) > 0:
raise ValueError(
f"Columns {dupe_cols} already exist in the performance data table. Set overwrite=True to overwrite."
)

# Join NCU DataFrame into Thicket
self.dataframe = self.dataframe.join(
ncu_df,
Expand Down

0 comments on commit fba7235

Please sign in to comment.