Skip to content

Commit

Permalink
v0.16.2 tracking uri bug
Browse files Browse the repository at this point in the history
  • Loading branch information
AliAl-Gburi committed Oct 31, 2023
1 parent f37d0f8 commit 0231f1f
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 25 deletions.
2 changes: 1 addition & 1 deletion examples/pytorch-garbage-classification-95-accuracy.ipynb

Large diffs are not rendered by default.

45 changes: 22 additions & 23 deletions mlflow_emissions_sdk/experiment_tracking_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,41 +46,40 @@ def read_params(self, experiment_tracking_params: dict) -> None:
self.flavor = experiment_tracking_params["flavor"]
verify_flavor_exists(self.flavor)

try:


# Creates a reference to the mlflow client
client = MlflowClient(
tracking_uri=self.experiment_tracking_params["tracking_uri"]
client = MlflowClient(
tracking_uri=self.tracking_uri
)
# Sets the tracking uri of the mlflow
mlflow.set_tracking_uri(self.experiment_tracking_params["tracking_uri"])
mlflow.set_tracking_uri(self.tracking_uri)
# Looks for the experiment with the given name,
# creates a new one if none are found
exp_id = dict(
exp_id = dict(
mlflow.set_experiment(
self.experiment_tracking_params["experiment_name"]
self.experiment_name
)
)["experiment_id"]
self.exp_id = exp_id
self.exp_id = exp_id
# creates a run with the given name and save the run id
run_id = dict(
run_id = dict(
client.create_run(
exp_id, run_name=self.experiment_tracking_params["run_name"]
exp_id, run_name=self.run_name
)
)
self.run_id = dict(run_id["info"])["run_id"]
self.run_id = dict(run_id["info"])["run_id"]
# starts the mlflow run
mlflow.start_run(self.run_id, exp_id)
mlflow.start_run(self.run_id, exp_id)

# specific for keras, autologs the model params and some metrics
if self.flavor == "keras":
mlflow.keras.autolog()
elif self.flavor == "pytorch":
mlflow.pytorch.autolog()
elif self.flavor == "sklearn":
mlflow.sklearn.autolog()
except Exception:
print("Please refer to a running instance of mlflow ui")
if self.flavor == "keras":
mlflow.keras.autolog()
elif self.flavor == "pytorch":
mlflow.pytorch.autolog()
elif self.flavor == "sklearn":
mlflow.sklearn.autolog()


def start_training_job(self):
verify_emission_tracker_is_instantiated(self.emissions_tracker)
Expand Down Expand Up @@ -112,7 +111,7 @@ def predict_image(self, img, model) -> int:

def end_training_job(self):
client = MlflowClient(
tracking_uri=self.experiment_tracking_params["tracking_uri"]
tracking_uri=self.tracking_uri
)
emissions = self.emissions_tracker.stop()

Expand All @@ -131,7 +130,7 @@ def evaluate_model_accuracy(self, model, *args) -> float:
"""

client = MlflowClient(
tracking_uri=self.experiment_tracking_params["tracking_uri"]
tracking_uri=self.tracking_uri
)
model_acc = 0
#
Expand Down Expand Up @@ -175,7 +174,7 @@ def evaluate_model_accuracy(self, model, *args) -> float:

def accuracy_per_emission(self, model, *args):
client = MlflowClient(
tracking_uri=self.experiment_tracking_params["tracking_uri"]
tracking_uri=self.tracking_uri
)
if len(args) == 2:
x_test = args[0]
Expand All @@ -193,7 +192,7 @@ def accuracy_per_emission(self, model, *args):

def emissions_per_10_inferences(self, model, test_data):
client = MlflowClient(
tracking_uri=self.experiment_tracking_params["tracking_uri"]
tracking_uri=self.tracking_uri
)
# For pytorch
if self.flavor == "pytorch":
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# run the installation of our mlflow_emissions_sdk
setup(
name="mlflow_emissions_sdk",
version="0.16.1",
version="0.16.2",
packages=["mlflow_emissions_sdk"],
description="tracks carbon emissions and logs it to mlfow",
install_requires=requirements,
Expand Down

0 comments on commit 0231f1f

Please sign in to comment.