Skip to content

Commit

Permalink
Fix fp32->fp16 cast for scalars (fixes llama2 fp16 for MPS) (pytorch#1752)
Browse files Browse the repository at this point in the history

Summary:
Summary of changes:
- Fix FP32 -> FP16 scalar conversion
- Disable MPS partitioner by default in mps_example script

```bash
# AOT:
python3 -m examples.apple.mps.scripts.mps_example --model_name="llama2" --bundled

# Runtime:
./cmake-out/examples/apple/mps/mps_executor_runner --model_path llama2_mps_bundled_fp16.pte --bundled_program
```
cc cccclai

Pull Request resolved: pytorch#1752

Reviewed By: cccclai

Differential Revision: D53236633

Pulled By: shoumikhin

fbshipit-source-id: ab04587bfd1493b81fa03a3145c0df6911749fec
  • Loading branch information
DenisVieriu97 authored and facebook-github-bot committed Jan 30, 2024
1 parent 7b5419a commit c99e5a5
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 25 deletions.
42 changes: 20 additions & 22 deletions backends/apple/mps/operators/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,22 +161,10 @@ def define_scalar(
return self.tensor_to_id[val]

id = self.get_serialized_id(val, mps_graph)

if (
self.convert_model_to_fp16
and mps_data_type == MPSDataType.mps_data_type_float32
):
mps_data_type = MPSDataType.mps_data_type_float16

if isinstance(val, int):
array = bytes(ctypes.c_int32(val))
elif isinstance(val, float):
array = bytes(ctypes.c_float(val))
else:
raise RuntimeError("Unknown data type!")

constant_buffer = Buffer(storage=array)
constant_buffer_size = len(array)
tensor = torch.tensor(val)
constant_buffer_size, constant_buffer, mps_data_type = self.get_serialized_data(
tensor, mps_graph, mps_data_type, id
)

mps_tensor = MPSTensor(
datatype=mps_data_type,
Expand All @@ -186,9 +174,6 @@ def define_scalar(
constant_buffer=constant_buffer,
)

if id not in mps_graph.constant_ids:
mps_graph.constant_ids.append(id)

mps_graph.mps_values.append(mps_tensor)
return id

Expand Down Expand Up @@ -218,12 +203,25 @@ def get_serialized_buffer(
tensor = get_param_tensor(self.exported_program, node)
assert tensor is not None and isinstance(tensor, torch.Tensor)
tensor = tensor.contiguous()
if self.convert_model_to_fp16 and tensor.dtype == torch.float32:

return self.get_serialized_data(tensor, mps_graph, mps_data_type, node_id)

def get_serialized_data(
self,
tensor: torch.tensor,
mps_graph: MPSGraph,
mps_data_type: MPSDataType,
id: int,
) -> Tuple[int, Buffer, MPSDataType]:
if (
self.convert_model_to_fp16
and mps_data_type == MPSDataType.mps_data_type_float32
):
tensor = tensor.half()
mps_data_type = MPSDataType.mps_data_type_float16

if node_id not in mps_graph.constant_ids:
mps_graph.constant_ids.append(node_id)
if id not in mps_graph.constant_ids:
mps_graph.constant_ids.append(id)

array_type = ctypes.c_char * tensor.untyped_storage().nbytes()
array = ctypes.cast(
Expand Down
5 changes: 2 additions & 3 deletions examples/apple/mps/scripts/mps_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,8 @@

parser.add_argument(
"--use_partitioner",
action="store_true",
required=False,
default=True,
default=False,
action=argparse.BooleanOptionalAction,
help="Use MPS partitioner to run the model instead of using whole graph lowering.",
)

Expand Down

0 comments on commit c99e5a5

Please sign in to comment.