fix attention mask and mem usage
plusbang committed Jun 27, 2024
1 parent 7fea5dc commit 0c72348
Showing 4 changed files with 18 additions and 4 deletions.
2 changes: 2 additions & 0 deletions python/llm/example/GPU/Pipeline-Parallel-FastAPI/README.md
@@ -32,6 +32,8 @@ pip install transformers==4.37.0
bash run.sh
```

> Note: INT4 optimization is applied to the model by default. You can specify other low-bit optimizations (such as `fp8` and `fp6`) through `--low-bit`. You can also change `NUM_GPUS` to the number of GPUs available on your machine.
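For illustration only (not part of this change), a hedged sketch of overriding these defaults; the entry-point script name and the model-path flag below are placeholders, while `--low-bit`, `--port`, and `--max-num-seqs` match the arguments shown later in this commit:

```bash
# Hypothetical invocation; in this example folder the server is normally launched
# via `bash run.sh`, which is also where NUM_GPUS is set.
python pipeline_serving.py \
    --repo-id-or-model-path meta-llama/Llama-2-7b-chat-hf \
    --low-bit fp8 \
    --port 8000 \
    --max-num-seqs 8
```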

### 3. Sample Input and Output

@@ -301,11 +301,11 @@ async def main():
                    help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf`, `meta-llama/Llama-2-13b-chat-hf` and `meta-llama/Llama-2-70b-chat-hf`) to be downloaded'
                         ', or the path to the huggingface checkpoint folder')
parser.add_argument('--low-bit', type=str, default='sym_int4',
                    help='The quantization type the model will convert to.')
parser.add_argument('--port', type=int, default=8000,
                    help='The port number on which the server will run.')
parser.add_argument('--max-num-seqs', type=int, default=8,
                    help='Max num sequences in a batch.')

args = parser.parse_args()
model_path = args.repo_id_or_model_path
1 change: 0 additions & 1 deletion python/llm/example/GPU/Pipeline-Parallel-FastAPI/run.sh
@@ -14,7 +14,6 @@
# limitations under the License.
#

source /opt/intel/oneapi/setvars.sh
export no_proxy=localhost
export FI_PROVIDER=tcp
export OMP_NUM_THREADS=32
13 changes: 13 additions & 0 deletions python/llm/src/ipex_llm/transformers/pipeline_parallel.py
@@ -332,6 +332,14 @@ class BatchTask(BaseModel):
    stopped: bool


def make_attention_mask(prompt_lengths):
    max_length = max(prompt_lengths)
    attention_mask = torch.zeros((len(prompt_lengths), max_length), dtype=torch.int64)
    for i, length in enumerate(prompt_lengths):
        attention_mask[i, max_length - length:] = 1
    return attention_mask
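For context (an illustrative sketch, not part of the commit), the helper builds a left-padded mask: shorter prompts in a batch get zeros at the front and ones over their real tokens.

```python
import torch

# Assuming the make_attention_mask helper above is in scope:
# two prompts of lengths 2 and 4 share one batch of width max(2, 4) = 4.
mask = make_attention_mask([2, 4])
# tensor([[0, 0, 1, 1],    <- length-2 prompt, left-padded by 2
#         [1, 1, 1, 1]])   <- length-4 prompt, no padding
```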


class ModelRunner:
    """Implementation for pipeline parallel multi-stage serving."""
    def __init__(self, checkpoint, rank, world_size, low_bit, max_num_seqs,
@@ -386,12 +394,14 @@ def load_model(self, model_path, world_size, low_bit='sym_int4'):
        model = model.eval()
        return model

    @torch.no_grad()
    def model_step(self, input, cur_batch):
        if cur_batch is None or cur_batch.stopped or input is None:
            return None

        cur_id = cur_batch.batch_id
        _past_key_values = self.past_key_values_dict.get(cur_id, None)
        attention_mask = make_attention_mask(cur_batch.prompt_lengths).to(f'xpu:{self.rank}')

        if self.rank == 0:
            input_ids = input
@@ -400,9 +410,11 @@ def model_step(self, input, cur_batch):
            input_ids = None
            inputs_embeds = input

        torch.xpu.empty_cache()
        output = self.model(input_ids=input_ids,
                            inputs_embeds=inputs_embeds,
                            past_key_values=_past_key_values,
                            attention_mask=attention_mask,
                            use_cache=True,)

        if self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0:
@@ -414,6 +426,7 @@ def model_step(self, input, cur_batch):
        else:
            _past_key_values = output.past_key_values
        self.past_key_values_dict[cur_id] = _past_key_values
        torch.xpu.synchronize()
        if not self.pp_config.is_tail:
            return output[0].to(self.dtype)
        else:
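For background (a minimal sketch under the assumption of an XPU-enabled PyTorch/IPEX build, not code from this commit), the two memory-related calls added above follow a simple pattern: free cached allocator blocks before the forward pass, then synchronize so the step's kernels finish before the next batch is scheduled. The added `@torch.no_grad()` decorator further reduces memory by keeping autograd from retaining activations during serving.

```python
import torch

# Illustrative only: a dummy matmul stands in for self.model(...) in model_step.
if torch.xpu.is_available():
    x = torch.randn(1024, 1024, device="xpu:0")
    torch.xpu.empty_cache()      # release cached blocks so peak memory stays lower across batches
    y = x @ x                    # stands in for the no-grad forward pass
    torch.xpu.synchronize()      # wait for queued kernels before handing results to the next stage
```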
