Skip to content

Commit

Permalink
Merge pull request #109 from ibm-granite/clamp_min_issue
Browse files Browse the repository at this point in the history
address issue with clamp_min on MPS
  • Loading branch information
wgifford authored Aug 13, 2024
2 parents 7ca28ee + f4efdd5 commit b188e80
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion tsfm_public/models/tinytimemixer/modeling_tinytimemixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,7 +982,7 @@ def forward(
`(batch_size, 1, num_input_channels)`)
"""
denominator = observed_indicator.sum(self.dim, keepdim=self.keepdim)
denominator = denominator.clamp_min(1.0)
denominator = denominator.clamp_min(torch.tensor(1, device=denominator.device))
loc = (data * observed_indicator).sum(self.dim, keepdim=self.keepdim) / denominator

variance = (((data - loc) * observed_indicator) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator
Expand Down

0 comments on commit b188e80

Please sign in to comment.