Skip to content

Commit

Permalink
add use_xpu config for det_mv3_db.yml
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyk0314 committed Feb 23, 2022
1 parent d6ec303 commit 49ecf9c
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
1 change: 1 addition & 0 deletions configs/det/det_mv3_db.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
Global:
use_gpu: true
use_xpu: false
epoch_num: 1200
log_smooth_window: 20
print_batch_step: 10
Expand Down
31 changes: 30 additions & 1 deletion tools/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,25 @@ def check_gpu(use_gpu):
pass


def check_xpu(use_xpu):
"""
Log error and exit when set use_xpu=true in paddlepaddle
cpu/gpu version.
"""
err = "Config use_xpu cannot be set as true while you are " \
"using paddlepaddle cpu/gpu version ! \nPlease try: \n" \
"\t1. Install paddlepaddle-xpu to run model on XPU \n" \
"\t2. Set use_xpu as false in config file to run " \
"model on CPU/GPU"

try:
if use_xpu and not paddle.is_compiled_with_xpu():
print(err)
sys.exit(1)
except Exception as e:
pass


def train(config,
train_dataloader,
valid_dataloader,
Expand Down Expand Up @@ -512,14 +531,24 @@ def preprocess(is_train=False):
use_gpu = config['Global']['use_gpu']
check_gpu(use_gpu)

# check if set use_xpu=True in paddlepaddle cpu/gpu version
use_xpu = False
if 'use_xpu' in config['Global']:
use_xpu = config['Global']['use_xpu']
check_xpu(use_xpu)

alg = config['Architecture']['algorithm']
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM'
]

device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
device = 'cpu'
if use_gpu:
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id)
if use_xpu:
device = 'xpu'
device = paddle.set_device(device)

config['Global']['distributed'] = dist.get_world_size() != 1
Expand Down

0 comments on commit 49ecf9c

Please sign in to comment.