Allow computing logprobs of SFT model on input strings (#148)
Signed-off-by: Olivier Delalleau <[email protected]>
odelalleau authored Apr 9, 2024
1 parent 9fcc534 commit 9db62d6
Showing 2 changed files with 20 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
```diff
@@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
 - Fixed crash with float val check interval when checking progress in DPOTrainer
 - Fixed potential crash in SPIN when prompts are longer than encoder_seq_len - generation.max_length
 - Fixed crash when calling the `generate()` method of an SFT model with pipeline parallelism greater than two
+- Fixed crash when calling the `generate()` method of an SFT model with `compute_logprob=True` and string inputs
 
 ## [0.2.0] - 2024-02
 ### New features and optimizations
```
24 changes: 19 additions & 5 deletions nemo_aligner/models/nlp/gpt/gpt_sft_model.py
```diff
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import hydra
 import torch
@@ -144,17 +144,31 @@ def finish_validation_step(self):
 
     def generate(
         self,
-        inputs: Tuple[torch.Tensor, torch.Tensor],  # we do not support non-tensor inputs
+        inputs: Union[List[str], Tuple[torch.Tensor, torch.Tensor]],
         length_params: LengthParam,
         sampling_params: SamplingParam = None,
         *,
         strategy: Optional[TextGenerationStrategy] = None,
     ) -> OutputType:
-        """Same as base model generate, except the following
-        1. Going to apply padding to max length
-        2. Going to append "predictions" key which is the model output without the prompt
+        """
+        Same as base model generate, except the following:
+        1. Apply padding to max length.
+        2. Add a "predictions" key to the output, which is the model output without the prompt.
+        These two additional steps above are only performed for actual generation from the model:
+        if `generate()` is called with `compute_logprob=True` then the base model method is used.
         """
+        if sampling_params is not None and sampling_params.get("compute_logprob", False):
+            return super().generate(
+                inputs=inputs, length_params=length_params, sampling_params=sampling_params, strategy=strategy
+            )
+
+        if isinstance(inputs, (list, tuple)) and isinstance(inputs[0], str):
+            raise NotImplementedError(
+                "`GPTSFTModel.generate()` does not currently support string inputs, please tokenize prompts first"
+            )
+
         prompt_tokens, prompt_lengths = inputs
         max_prompt_length = prompt_lengths.max().item()
         max_response_length = length_params["max_length"]
```
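For reference, a minimal usage sketch of the two input forms the updated signature accepts. This is illustrative only and not part of the commit: `model` is assumed to be an already-loaded `GPTSFTModel`, the token ids are made up, and the `logprob` output key is assumed from the base NeMo text-generation API.

```python
import torch

# Hypothetical usage sketch (not from the commit). Assumes `model` is a loaded GPTSFTModel.

# 1) String inputs are only supported together with compute_logprob=True: the call
#    is forwarded to the base model's generate(), which tokenizes the prompts itself.
prompts = ["The capital of France is", "2 + 2 ="]
out = model.generate(
    inputs=prompts,
    length_params={"max_length": 1, "min_length": 0},  # scoring the prompt, so generate (almost) nothing
    sampling_params={"compute_logprob": True, "use_greedy": True},  # remaining keys assumed to take defaults
)
print(out["logprob"])  # output key assumed from the base NeMo text-generation API

# 2) Pre-tokenized inputs: a (padded token ids, true lengths) tensor pair goes through
#    the SFT path, which pads to max length and adds a "predictions" key.
tokens = torch.tensor([[5, 17, 42, 0], [5, 99, 3, 8]])  # made-up padded token ids
lengths = torch.tensor([3, 4])  # unpadded prompt lengths
out = model.generate(inputs=(tokens, lengths), length_params={"max_length": 32, "min_length": 1})
print(out["predictions"])  # model output without the prompt
```

Note that plain string inputs without `compute_logprob=True` still raise `NotImplementedError`; only the logprob-scoring path gains string support in this commit.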
