Skip to content

Commit

Permalink
support pytorch on musa & update deprecated deps
Browse files Browse the repository at this point in the history
  • Loading branch information
mt-wangjiangyuan committed Oct 16, 2024
1 parent e6b18cc commit ef3ec42
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 4 deletions.
13 changes: 10 additions & 3 deletions d2l/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from torch.nn import functional as F
from torch.utils import data
from torchvision import transforms
import torch_musa

def use_svg_display():
"""使用svg格式在Jupyter中显示绘图
Expand Down Expand Up @@ -428,15 +429,21 @@ def try_gpu(i=0):
Defined in :numref:`sec_use_gpu`"""
if torch.cuda.device_count() >= i + 1:
return torch.device(f'cuda:{i}')
if torch.musa.device_count() >= i + 1:
return torch.device(f'musa:{i}')
return torch.device('cpu')

def try_all_gpus():
"""返回所有可用的GPU,如果没有GPU,则返回[cpu(),]
Defined in :numref:`sec_use_gpu`"""
devices = [torch.device(f'cuda:{i}')
for i in range(torch.cuda.device_count())]
return devices if devices else [torch.device('cpu')]
num = torch.cuda.device_count()
if num > 0:
return [torch.device(f'cuda:{i}') for i in range(num)]
num = torch.musa.device_count()
if num > 0:
return [torch.device(f'musa:{i}') for i in range(num)]
return [torch.device('cpu')]

def corr2d(X, K):
"""计算二维互相关运算
Expand Down
4 changes: 4 additions & 0 deletions install.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
rm -rf ./build
rm -rf ./dist
python setup.py bdist_wheel
pip install dist/d2l-2.0.0-py3-none-any.whl
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
'numpy==1.21.5',
'matplotlib==3.5.1',
'requests==2.25.1',
'pandas==1.2.4'
'pandas==2.0.3'
]

setup(
Expand Down

0 comments on commit ef3ec42

Please sign in to comment.