Skip to content

Commit

Permalink
good to go
Browse files Browse the repository at this point in the history
  • Loading branch information
grekiki2 committed Mar 1, 2023
1 parent bc70b22 commit f39fb13
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 26 deletions.
29 changes: 4 additions & 25 deletions LitModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,6 @@ def __init__(

self.save_hyperparameters()

# self.train_custom_metrics = {
# "train_acc": smp.utils.metrics.Accuracy(activation="softmax2d")
# }
# self.validation_custom_metrics = {
# "val_acc": smp.utils.metrics.Accuracy(activation="softmax2d")
# }

self.preprocess_fn = smp.encoders.get_preprocessing_fn(
self.backbone, pretrained="imagenet"
)
Expand All @@ -82,9 +75,7 @@ def __build_model(self):

def forward(self, x):
"""Forward pass. Returns logits."""

x = self.net(x)

return x

def loss(self, logits, labels):
Expand All @@ -96,17 +87,11 @@ def training_step(self, batch, batch_idx):
x, y = batch
y_logits = self.forward(x)

# 2. Compute loss & accuracy:
# 2. Compute loss:
train_loss = self.loss(y_logits, y)

metrics = {}
# for metric_name in self.train_custom_metrics.keys():
# metrics[metric_name] = self.train_custom_metrics[metric_name](y_logits, y)

# 3. Outputs:
output = OrderedDict(
{"loss": train_loss, "log": metrics, "progress_bar": metrics}
)
output = OrderedDict({"loss": train_loss})
self.log("train_loss", train_loss)
self.log("lr", self.trainer.optimizers[0].param_groups[0]["lr"])
return output
Expand All @@ -116,16 +101,10 @@ def validation_step(self, batch, batch_idx):
x, y = batch
y_logits = self.forward(x)

# 2. Compute loss & accuracy:
# 2. Compute loss:
val_loss = self.loss(y_logits, y)

metrics = {"val_loss": val_loss}

# for metric_name in self.validation_custom_metrics.keys():
# metrics[metric_name] = self.validation_custom_metrics[metric_name](
# y_logits, y
# )

return metrics

def validation_epoch_end(self, outputs):
Expand Down Expand Up @@ -252,7 +231,7 @@ def add_model_specific_args(parent_parser):
)
parser.add_argument(
"--data-path",
default="/home/gregor/Desktop/segnet/comma10k", # Change before merge
default="/home/yyousfi1/commaai/comma10k",
type=str,
metavar="dp",
help="data_path",
Expand Down
2 changes: 1 addition & 1 deletion train_lit_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def main(args):
wandb_logger.log_hyperparams(args)

checkpoint_callback = ModelCheckpoint(
dirpath="/home/gregor/logs/segnet/", # TODO change before merge
dirpath="/home/yyousfi1/LogFiles/comma/", # TODO change before merge
filename="{epoch:02d}_{val_loss:.4f}",
save_top_k=10,
monitor="val_loss",
Expand Down

0 comments on commit f39fb13

Please sign in to comment.