Skip to content

Commit

Permalink
Merge pull request #955 from MukuFlash03/log-statements-model-load
Browse files Browse the repository at this point in the history
Cleanup fixes including log statements for PR #944
  • Loading branch information
shankari authored Jan 30, 2024
2 parents 54659fb + 9e87dbb commit 0a8c61d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
11 changes: 8 additions & 3 deletions emission/analysis/classification/inference/labels/inferrers.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,19 @@ def predict_cluster_confidence_discounting(trip_list, max_confidence=None, first
user_id_list = []
for trip in trip_list:
user_id_list.append(trip['user_id'])
assert user_id_list.count(user_id_list[0]) == len(user_id_list), "Multiple user_ids found for trip_list, expected unique user_id for all trips"
error_message = f"""
Multiple user_ids found for trip_list, expected unique user_id for all trips.
Unique user_ids count = {len(set(user_id_list))}
{set(user_id_list)}
"""
assert user_id_list.count(user_id_list[0]) == len(user_id_list), error_message
# Assertion successful, use unique user_id
user_id = user_id_list[0]

# load model
start_model_load_time = time.process_time()
model = eamur._load_stored_trip_model(user_id, model_type, model_storage)
print(f"{arrow.now()} Inside predict_labels_n: Model load time = {time.process_time() - start_model_load_time}")
logging.debug(f"{arrow.now()} Inside predict_cluster_confidence_discounting: Model load time = {time.process_time() - start_model_load_time}")

labels_n_list = eamur.predict_labels_with_n(trip_list, model)
predictions_list = []
Expand All @@ -192,4 +197,4 @@ def predict_cluster_confidence_discounting(trip_list, max_confidence=None, first
labels = copy.deepcopy(labels)
for l in labels: l["p"] *= confidence_coeff
predictions_list.append(labels)
return predictions_list
return predictions_list
5 changes: 3 additions & 2 deletions emission/analysis/modelling/trip_model/run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def predict_labels_with_n(
"""

predictions_list = []
print(f"{arrow.now()} Inside predict_labels_n: Predicting...")
logging.debug(f"{arrow.now()} Inside predict_labels_n: Predicting...")
start_predict_time = time.process_time()
for trip in trip_list:
if model is None:
Expand All @@ -118,7 +118,8 @@ def predict_labels_with_n(
else:
predictions, n = model.predict(trip)
predictions_list.append((predictions, n))
print(f"{arrow.now()} Inside predict_labels_n: Predictions complete for trip_list in time = {time.process_time() - start_predict_time}")
logging.debug(f"{arrow.now()} Inside predict_labels_n: Predictions complete for trip_list in time = {time.process_time() - start_predict_time}")
logging.debug(f"{arrow.now()} No. of trips = {len(trip_list)}; No. of predictions = {len(predictions_list)}")
return predictions_list


Expand Down

0 comments on commit 0a8c61d

Please sign in to comment.