Skip to content

Commit

Permalink
Update pyproject and fixed yolo DDP error
Browse files Browse the repository at this point in the history
  • Loading branch information
BarzaH committed Sep 12, 2024
1 parent 54353da commit abf50b0
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
./data
wandb/
lightning_logs/
mmdetection3d/
*.pth
# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
7 changes: 7 additions & 0 deletions innofw/core/integrations/ultralytics/trainer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

from omegaconf import DictConfig

from ..base_adapter import BaseAdapter
Expand Down Expand Up @@ -28,6 +30,11 @@ def get_device(trainer_cfg):
devices = ",".join(map(str, range(n_devices)))

result = f"{devices}"
try:
if int(result) < 2:
return ""
except:
pass

return result

Expand Down
4 changes: 2 additions & 2 deletions pckg_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,13 @@ def check_gpu_and_torch_compatibility():
if "NVIDIA A100" in output:
install_and_import(
"torch",
"1.11.0+cu113",
"1.12.0+cu113",
"-f",
"https://download.pytorch.org/whl/torch_stable.html",
)
install_and_import(
"torchvision",
"0.12.0+cu113",
"0.13.0+cu113",
"-f",
"https://download.pytorch.org/whl/torch_stable.html",
)
Expand Down
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ authors = ["Kazybek Askarbek <[email protected]>"]

[[tool.poetry.source]]
name = "openmmlab"
url = "https://download.openmmlab.com/mmcv/dist/cu113/torch1.11/index.html"
url = "https://download.openmmlab.com/mmcv/dist/cu113/torch1.12/index.html"
priority = "supplemental"

[[tool.poetry.source]]
Expand All @@ -20,9 +20,9 @@ python-dotenv = "^0.20.0"



torch = {version = "1.11.0", source = "cudatorch"}
torchvision = {version = "0.12.0", source = "cudatorch"}
torchaudio = {version = "0.11.0", source = "cudatorch"}
torch = {version = "1.12.0", source = "cudatorch"}
torchvision = {version = "0.13.0", source = "cudatorch"}


tqdm = "^4.64.0"
h5py = "^3.7.0"
Expand Down

0 comments on commit abf50b0

Please sign in to comment.