Skip to content

Commit

Permalink
Fix tensor type compatibility in dynamic soft label assigner and RTMD…
Browse files Browse the repository at this point in the history
…et head (#4140)

* Fix tensor type compatibility in dynamic soft label assigner and RTMDet head

* Update CHANGELOG
  • Loading branch information
eugene123tw authored Dec 4, 2024
1 parent ec610a9 commit cf035f6
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ All notable changes to this project will be documented in this file.
(<https://github.com/openvinotoolkit/training_extensions/pull/4107>)
- Fix empty annotation in tiling
(<https://github.com/openvinotoolkit/training_extensions/pull/4124>)
- Fix tensor type compatibility in dynamic soft label assigner and RTMDet head
(<https://github.com/openvinotoolkit/training_extensions/pull/4140>)

## \[v2.1.0\]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def assign(
assigned_labels = assigned_gt_inds.new_full((num_bboxes,), -1)
assigned_labels[valid_mask] = gt_labels[matched_gt_inds].long()
max_overlaps = assigned_gt_inds.new_full((num_bboxes,), -INF, dtype=torch.float32)
max_overlaps[valid_mask] = matched_pred_ious
max_overlaps[valid_mask] = matched_pred_ious.to(max_overlaps)
return AssignResult(num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)

def dynamic_k_matching(
Expand Down
2 changes: 1 addition & 1 deletion src/otx/algo/detection/heads/rtmdet_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ def _get_targets_single( # type: ignore[override]
if len(pos_inds) > 0:
# point-based
pos_bbox_targets = sampling_result.pos_gt_bboxes
bbox_targets[pos_inds, :] = pos_bbox_targets
bbox_targets[pos_inds, :] = pos_bbox_targets.to(bbox_targets)

labels[pos_inds] = sampling_result.pos_gt_labels
if self.train_cfg["pos_weight"] <= 0:
Expand Down

0 comments on commit cf035f6

Please sign in to comment.