
Commit

Replace cuda with device in folder libai (#545)
* libai/engine/default.py cuda->to(device)

* eval util

* beam search

* tokenization base

* graph base

* npu mlu xpu

* replace cuda to device

* fix format

* have to reformat

* have to reformat
ShawnXuan authored Aug 21, 2024
1 parent a3a467c commit 4a57a74
Showing 6 changed files with 24 additions and 17 deletions.
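The same pattern runs through all six files: anywhere a placement or tensor was pinned to the literal string "cuda", the device type now comes from a parameter that still defaults to "cuda", so the same code paths can also run on npu, mlu, or xpu backends. Below is a minimal sketch of the before/after pattern in plain oneflow; the local `device` variable and the "cpu" value are illustrative choices so the snippet runs anywhere, not part of libai's API.

import oneflow as flow

# Before the change (illustrative): the placement type was hardcoded, CUDA-only.
# x = flow.zeros((2, 4), sbp=flow.sbp.broadcast,
#                placement=flow.placement("cuda", ranks=[0]))

# After the change: the device type is a parameter (default "cuda"), so other
# backends such as "npu", "mlu", "xpu" (or "cpu", used here) pass through.
device = "cpu"
x = flow.zeros(
    (2, 4),
    sbp=flow.sbp.broadcast,
    placement=flow.placement(device, ranks=[0]),
)
print(x.placement)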
4 changes: 2 additions & 2 deletions libai/engine/default.py
@@ -508,8 +508,8 @@ def get_batch(

         if mixup_func is not None:
             images, labels = mixup_func(
-                data.get("images").tensor.cuda(),
-                data.get("labels").tensor.cuda(),
+                data.get("images").tensor.to(input_placement_device),
+                data.get("labels").tensor.to(input_placement_device),
             )
             data.get("images").tensor = images
             data.get("labels").tensor = labels
6 changes: 3 additions & 3 deletions libai/evaluation/utils.py
@@ -21,7 +21,7 @@
 from libai.utils import distributed as dist


-def pad_batch(x_dict, batch_size, last_batch_lack, is_last_batch):
+def pad_batch(x_dict, batch_size, last_batch_lack, is_last_batch, device="cuda"):
     x = list(x_dict.values())[0]
     tensor_batch = x.shape[0]
     assert tensor_batch <= batch_size
@@ -37,9 +37,9 @@ def pad_batch(x_dict, batch_size, last_batch_lack, is_last_batch):
     for key, xi in x_dict.items():
         pad_shape = (batch_size, *xi.shape[1:])
         local_xi = xi.to_global(
-            sbp=flow.sbp.broadcast, placement=flow.env.all_device_placement("cuda")
+            sbp=flow.sbp.broadcast, placement=flow.env.all_device_placement(device)
         ).to_local()
-        padded_xi = flow.zeros(pad_shape, dtype=xi.dtype, device="cuda")
+        padded_xi = flow.zeros(pad_shape, dtype=xi.dtype, device=device)
         padded_xi[:tensor_batch, ...] = padded_xi[:tensor_batch, ...] + local_xi
         for i in range(last_batch_lack - 1):
             start_idx = tensor_micro_batch_size * (data_parallel_size - i - 1) - 1
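A hedged, single-tensor reduction of the padding step above: pad_single is an illustrative name rather than the libai function, and "cpu" stands in for the new "cuda" default so the sketch runs anywhere.

import oneflow as flow

def pad_single(xi, batch_size, device="cpu"):
    # The broadcast placement and the zero-padded buffer both follow the `device`
    # argument, mirroring the change to pad_batch (which defaults to "cuda").
    tensor_batch = xi.shape[0]
    pad_shape = (batch_size, *xi.shape[1:])
    local_xi = xi.to_global(
        sbp=flow.sbp.broadcast, placement=flow.env.all_device_placement(device)
    ).to_local()
    padded_xi = flow.zeros(pad_shape, dtype=xi.dtype, device=device)
    padded_xi[:tensor_batch, ...] = padded_xi[:tensor_batch, ...] + local_xi
    return padded_xi

print(pad_single(flow.ones((3, 4)), batch_size=8).shape)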
13 changes: 8 additions & 5 deletions libai/inference/generator/generation_beam_search.py
@@ -96,6 +96,7 @@ def __init__(
         do_early_stopping: Optional[bool] = False,
         num_beam_hyps_to_keep: Optional[int] = 1,
         num_beam_groups: Optional[int] = 1,
+        device: Optional[str] = "cuda",
         **kwargs,
     ):
         self.num_beams = num_beams
@@ -119,7 +120,7 @@ def __init__(
             [False for _ in range(batch_size)],
             dtype=flow.bool,
             sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
-            placement=flow.placement("cuda", list(range(dist.get_world_size()))),
+            placement=flow.placement(device, list(range(dist.get_world_size()))),
         )

         if not isinstance(num_beams, int) or num_beams <= 1:
@@ -159,6 +160,7 @@ def process(
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
         beam_indices: Optional[flow.Tensor] = None,
+        device: Optional[str] = "cuda",
     ) -> Tuple[flow.Tensor]:
         cur_len = input_ids.shape[-1]
         batch_size = len(self._beam_hyps)
@@ -177,19 +179,19 @@
             (batch_size, self.group_size),
             dtype=next_scores.dtype,
             sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
-            placement=flow.placement("cuda", list(range(dist.get_world_size()))),
+            placement=flow.placement(device, list(range(dist.get_world_size()))),
         )
         next_beam_tokens = flow.zeros(
             (batch_size, self.group_size),
             dtype=next_tokens.dtype,
             sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
-            placement=flow.placement("cuda", list(range(dist.get_world_size()))),
+            placement=flow.placement(device, list(range(dist.get_world_size()))),
         )
         next_beam_indices = flow.zeros(
             (batch_size, self.group_size),
             dtype=next_indices.dtype,
             sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
-            placement=flow.placement("cuda", list(range(dist.get_world_size()))),
+            placement=flow.placement(device, list(range(dist.get_world_size()))),
         )

         for batch_idx, beam_hyp in enumerate(self._beam_hyps):
@@ -274,6 +276,7 @@ def finalize(
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
         beam_indices: Optional[flow.Tensor] = None,
+        device: Optional[str] = "cuda",
     ):
         batch_size = len(self._beam_hyps)
         # finalize all open beam hypotheses and add to generated hypotheses
@@ -303,7 +306,7 @@
             batch_size * self.num_beam_hyps_to_keep,
             dtype=flow.float32,
             sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
-            placement=flow.placement("cuda", list(range(dist.get_world_size()))),
+            placement=flow.placement(device, list(range(dist.get_world_size()))),
         )

         # retrieve best hypotheses
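All three scorer methods now thread a device argument into flow.placement(...). A minimal sketch of that pattern follows; make_done_flags is an illustrative helper, and it uses plain flow.env.get_world_size() and a single broadcast sbp instead of libai's dist helpers so it runs standalone on CPU.

import oneflow as flow

def make_done_flags(batch_size, device="cpu"):
    # The bookkeeping tensor's placement takes its device type from the argument
    # instead of the literal "cuda" (libai's default remains "cuda").
    return flow.tensor(
        [False for _ in range(batch_size)],
        dtype=flow.bool,
        sbp=flow.sbp.broadcast,
        placement=flow.placement(device, ranks=list(range(flow.env.get_world_size()))),
    )

done = make_done_flags(4)
print(done.placement, done.sbp)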
4 changes: 3 additions & 1 deletion libai/models/utils/graph_base.py
@@ -39,12 +39,14 @@ def __init__(
         is_train=True,
         auto_parallel_conf=None,
         global_mode=None,
+        device="cuda",
     ):
         super().__init__()

         self.model = model
         self.is_train = is_train
         self.global_mode = global_mode
+        self.device = device

         if is_train:
             self.add_optimizer(optimizer, lr_sch=lr_scheduler)
@@ -103,7 +105,7 @@ def build(self, **kwargs):
         if self.is_train:
             placement_sbp_dict = (
                 dict(
-                    placement=flow.env.all_device_placement("cuda"),
+                    placement=flow.env.all_device_placement(self.device),
                     sbp=flow.sbp.split(0),
                 )
                 if self.global_mode.enabled
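The graph now stores the device type at construction and derives its placement from it in build(). A sketch of that placement step with a plain tensor; the `device` variable stands in for self.device and uses "cpu" here only so the snippet runs without a GPU.

import oneflow as flow

device = "cpu"  # stands in for self.device; GraphBase defaults this to "cuda"

x = flow.arange(8).reshape(4, 2)
# Placement is built from the stored device type rather than the literal "cuda".
x_global = x.to_global(
    placement=flow.env.all_device_placement(device),
    sbp=flow.sbp.split(0),
)
print(x_global.placement, x_global.sbp)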
6 changes: 4 additions & 2 deletions libai/tokenizer/tokenization_base.py
@@ -774,7 +774,9 @@ def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, Lis
             ids.append(self._convert_token_to_id_with_added_voc(token))
         return ids

-    def convert_to_tensors(self, token_ids, return_tensors=None, is_global=False, **kwargs):
+    def convert_to_tensors(
+        self, token_ids, return_tensors=None, is_global=False, device="cuda", **kwargs
+    ):
         if return_tensors is None:
             return_token_ids = token_ids
         elif return_tensors == "of":
@@ -783,7 +785,7 @@ def convert_to_tensors(self, token_ids, return_tensors=None, is_global=False, **
             elif is_global:
                 sbp = kwargs.get("sbp", dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]))
                 placement = kwargs.get(
-                    "placement", flow.placement("cuda", list(range(dist.get_world_size())))
+                    "placement", flow.placement(device, list(range(dist.get_world_size())))
                 )
                 return_token_ids = flow.tensor(
                     token_ids, sbp=sbp, placement=placement, dtype=flow.long
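An illustrative stand-alone version of the global branch of convert_to_tensors: ids_to_global_tensor is not a libai name, and flow.env.get_world_size() replaces dist.get_world_size() so the sketch runs without libai's distributed setup.

import oneflow as flow

def ids_to_global_tensor(token_ids, device="cpu"):
    # The default placement is built from `device` instead of a hardcoded "cuda"
    # (the method itself keeps "cuda" as its default).
    placement = flow.placement(device, ranks=list(range(flow.env.get_world_size())))
    return flow.tensor(token_ids, sbp=flow.sbp.broadcast, placement=placement, dtype=flow.long)

print(ids_to_global_tensor([101, 2023, 102]))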
8 changes: 4 additions & 4 deletions libai/utils/distributed.py
@@ -228,7 +228,7 @@ def device_type(self):
         return self._device_type

     def set_device_type(self, device_type):
-        assert device_type in ["cpu", "cuda"], f"not supported for {device_type}"
+        # assert device in ["cpu", "cuda"], f"not supported for device:{device}"
         self._device_type = device_type

     def get_layer_ranks(self, layer_idx):
@@ -435,10 +435,10 @@ def convert_to_distributed_default_setting(t):
         return t.to_global(placement=flow.placement(device_type, ranks=t.placement.ranks))


-def ttol(tensor, pure_local=False, ranks=None):
+def ttol(tensor, pure_local=False, device="cuda", ranks=None):
     """Global tensor to local tensor."""
     if tensor.is_global:
-        placement = tensor.placement if not ranks else flow.placement("cuda", ranks)
+        placement = tensor.placement if not ranks else flow.placement(device, ranks)
         if pure_local:
             tensor = tensor.to_global(placement=placement).to_local()
         else:
@@ -459,7 +459,7 @@ def tton(tensor, local_only=False, ranks=None):

 def tensor_to_rank0(tensor, device="cuda", to_local=False):
     """Global tensor to rank0."""
-    assert device in ["cpu", "cuda"], f"not supported for device:{device}"
+    # assert device in ["cpu", "cuda"], f"not supported for device:{device}"
     if tensor.is_global:
         # Consider if it's 2d mesh, ranks should be [[0]] instead of [0]
         placement = flow.placement(device, ranks=[0] if tensor.placement.ranks.ndim == 1 else [[0]])
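With the cpu/cuda assertions relaxed, ttol and tensor_to_rank0 accept whatever device type the caller passes. A sketch of the ttol change covering only the pure_local path; to_local_on is an illustrative name, and "cpu" is used so it runs anywhere.

import oneflow as flow

def to_local_on(tensor, device="cpu", ranks=None):
    # When ranks are given, the placement is rebuilt with the caller-supplied
    # device type instead of the literal "cuda".
    if tensor.is_global:
        placement = tensor.placement if not ranks else flow.placement(device, ranks)
        tensor = tensor.to_global(placement=placement).to_local()
    return tensor

g = flow.ones((2, 2)).to_global(
    placement=flow.env.all_device_placement("cpu"), sbp=flow.sbp.broadcast
)
print(to_local_on(g, device="cpu", ranks=[0]))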
