diff --git a/d2l/torch.py b/d2l/torch.py index 77e5f3577..3c901a743 100644 --- a/d2l/torch.py +++ b/d2l/torch.py @@ -45,6 +45,7 @@ from torch.nn import functional as F from torch.utils import data from torchvision import transforms +import torch_musa def use_svg_display(): """使用svg格式在Jupyter中显示绘图 @@ -428,15 +429,21 @@ def try_gpu(i=0): Defined in :numref:`sec_use_gpu`""" if torch.cuda.device_count() >= i + 1: return torch.device(f'cuda:{i}') + if torch.musa.device_count() >= i + 1: + return torch.device(f'musa:{i}') return torch.device('cpu') def try_all_gpus(): """返回所有可用的GPU,如果没有GPU,则返回[cpu(),] Defined in :numref:`sec_use_gpu`""" - devices = [torch.device(f'cuda:{i}') - for i in range(torch.cuda.device_count())] - return devices if devices else [torch.device('cpu')] + num = torch.cuda.device_count() + if num > 0: + return [torch.device(f'cuda:{i}') for i in range(num)] + num = torch.musa.device_count() + if num > 0: + return [torch.device(f'musa:{i}') for i in range(num)] + return [torch.device('cpu')] def corr2d(X, K): """计算二维互相关运算 diff --git a/install.sh b/install.sh new file mode 100755 index 000000000..8203fd96d --- /dev/null +++ b/install.sh @@ -0,0 +1,4 @@ +rm -rf ./build +rm -rf ./dist +python setup.py bdist_wheel +pip install dist/d2l-2.0.0-py3-none-any.whl \ No newline at end of file diff --git a/setup.py b/setup.py index 1a38f1c75..c9062073c 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ 'numpy==1.21.5', 'matplotlib==3.5.1', 'requests==2.25.1', - 'pandas==1.2.4' + 'pandas==2.0.3' ] setup(