-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #20 from mohwald/enhance/pytorch
Feature: add YOLOv8 inference and finetunning
- Loading branch information
Showing
15 changed files
with
506 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from .predict import predict | ||
from .evaluate import evaluate |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Patched from ramp-code.scripts.calculate_accuracy.iou created for ramp project by [email protected] | ||
|
||
from pathlib import Path | ||
import geopandas as gpd | ||
|
||
from ramp.utils.eval_utils import get_iou_accuracy_metrics | ||
|
||
|
||
def evaluate(test_path, truth_path, filter_area_m2=None, iou_threshold=0.5, verbose=False): | ||
""" | ||
Calculate precision/recall/F1-score based on intersection-over-union accuracy evaluation protocol defined by RAMP. | ||
The predicted masks will be georeferenced with EPSG:3857 as CRS | ||
Args: | ||
test_path: Path where the weights of the model can be found. | ||
truth_path: Path of the directory where the images are stored. | ||
filter_area_m2: Minimum area of buildings to analyze in m^2. | ||
iou_threshold: (float, 0<threshold<1) above which value of IoU of a detection is considered to be accurate | ||
verbose: Bool, more statistics are printed when turned on. | ||
Example:: | ||
evaluate( | ||
"data/prediction.geojson", | ||
"data/labels.geojson" | ||
) | ||
""" | ||
|
||
test_path, truth_path = Path(test_path), Path(truth_path) | ||
truth_df, test_df = gpd.read_file(str(truth_path)), gpd.read_file(str(test_path)) | ||
metrics = get_iou_accuracy_metrics(test_df, truth_df, filter_area_m2, iou_threshold) | ||
|
||
n_detections = metrics['n_detections'] | ||
n_truth = metrics["n_truth"] | ||
n_truepos = metrics['true_pos'] | ||
n_falsepos = n_detections - n_truepos | ||
n_falseneg = n_truth - n_truepos | ||
agg_precision = n_truepos / n_detections | ||
agg_recall = n_truepos / n_truth | ||
agg_f1 = 2 * n_truepos / (n_truth + n_detections) | ||
|
||
if verbose: | ||
print(f"Detections: {n_detections}") | ||
print(f"Truth buildings: {n_truth}") | ||
print(f"True positives: {n_truepos}") | ||
print(f"False positives: {n_falsepos}") | ||
print(f"False negatives: {n_falseneg}") | ||
print(f"Precision IoU@p: {agg_precision}") | ||
print(f"Recall IoU@p: {agg_recall}") | ||
print(f"F1 IoU@p: {agg_f1}") | ||
|
||
return { | ||
"precision": agg_precision, | ||
"recall": agg_recall, | ||
"f1": agg_f1, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import torch | ||
import torch.nn as nn | ||
import ultralytics | ||
|
||
from ultralytics.utils import RANK | ||
|
||
|
||
# | ||
# Binary cross entropy with p_c | ||
# | ||
|
||
class YOLOSegWithPosWeight(ultralytics.YOLO): | ||
|
||
def train(self, trainer=None, pc=1.0, **kwargs): | ||
return super().train(trainer, **{**kwargs, "pose": pc}) # Hide pc inside pose (pose est loss weight arg) | ||
|
||
@property | ||
def task_map(self): | ||
map = super().task_map | ||
map['segment']['model'] = SegmentationModelWithPosWeight | ||
map['segment']['trainer'] = SegmentationTrainerWithPosWeight | ||
return map | ||
|
||
|
||
class SegmentationTrainerWithPosWeight(ultralytics.models.yolo.segment.train.SegmentationTrainer): | ||
|
||
def get_model(self, cfg=None, weights=None, verbose=True): | ||
"""Return a YOLO segmentation model.""" | ||
model = SegmentationModelWithPosWeight(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1) | ||
if weights: | ||
model.load(weights) | ||
return model | ||
|
||
|
||
class SegmentationModelWithPosWeight(ultralytics.models.yolo.segment.train.SegmentationModel): | ||
|
||
def init_criterion(self): | ||
return v8SegmentationLossWithPosWeight(model=self) | ||
|
||
|
||
class v8SegmentationLossWithPosWeight(ultralytics.utils.loss.v8SegmentationLoss): | ||
|
||
def __init__(self, model): | ||
super().__init__(model) | ||
pc = model.args.pose # hidden in pose arg (used in different task) | ||
pos_weight = torch.full((model.nc,), pc).to(self.device) | ||
self.bce = nn.BCEWithLogitsLoss(reduction="none", pos_weight=pos_weight) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.