clamped log_hazard to prevent torch.Inf
tcoroller committed Aug 9, 2024
1 parent 41f5911 commit 015e45d
Showing 1 changed file with 7 additions and 2 deletions.
src/torchsurv/loss/weibull.py (7 additions, 2 deletions)
@@ -226,6 +226,9 @@ def log_hazard(
     >>> for t in torch.tensor([100.0, 150.0]): log_hazard(log_params, time=t) # Subject-specific log hazard at multiple new times
     tensor([ 1.1280, -0.0372, -3.9767,  1.0757])
     tensor([ 1.2330, -0.1062, -4.1680,  1.1999])
+    >>> log_params *= 1e2 # Increase scale
+    >>> log_hazard(log_params, time, all_times = False) # Check for Torch.Inf values
+    tensor([-1.0000e+10, -2.3197e+01, -6.8385e+01, -1.0000e+10])
     """

     log_scale, log_shape = _check_log_shape(log_params).unbind(1)
@@ -247,11 +250,13 @@
             f"Dimension mismatch: 'time' ({len(time)}) does not match the length of 'log_params' ({len(log_params)})."
         )

-    return (
+    return torch.clamp(
         log_shape
         - log_scale
         + torch.expm1(log_shape)
-        * (torch.log(torch.clip(time, 1e-100, torch.inf)) - log_scale)
+        * (torch.log(torch.clip(time, 1e-100, torch.inf)) - log_scale),
+        min=-TORCH_CLAMP_VALUE,
+        max=TORCH_CLAMP_VALUE,
     )


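Why the clamp matters, as a minimal standalone sketch (not part of the commit): once log_params is scaled up, torch.expm1(log_shape) overflows float32 to inf, and the raw log hazard evaluates to -inf; torch.clamp caps it at a finite bound instead. TORCH_CLAMP_VALUE = 1e10 is an assumption inferred from the -1.0000e+10 entries in the new doctest; the actual constant is defined elsewhere in torchsurv.

    import torch

    # Assumed value, inferred from the doctest output; not part of this diff.
    TORCH_CLAMP_VALUE = 1e10

    log_scale = torch.tensor([300.0])  # extreme value, e.g. after log_params *= 1e2
    log_shape = torch.tensor([100.0])
    time = torch.tensor([100.0])

    # Raw Weibull log hazard: expm1(100.0) overflows float32 to inf, and
    # inf * (log(time) - log_scale) drives the whole expression to -inf.
    raw = log_shape - log_scale + torch.expm1(log_shape) * (
        torch.log(torch.clip(time, 1e-100, torch.inf)) - log_scale
    )
    print(raw)  # tensor([-inf])

    # The committed fix: clamp the result to a finite, symmetric range.
    safe = torch.clamp(raw, min=-TORCH_CLAMP_VALUE, max=TORCH_CLAMP_VALUE)
    print(safe)  # tensor([-1.0000e+10]), matching the new doctest output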
