diff --git a/pygrinder/missing_at_random/mar_logistic.py b/pygrinder/missing_at_random/mar_logistic.py index 3979d99..199047b 100644 --- a/pygrinder/missing_at_random/mar_logistic.py +++ b/pygrinder/missing_at_random/mar_logistic.py @@ -122,7 +122,7 @@ def mar_logistic( X = np.asarray(X) if isinstance(X, np.ndarray) or isinstance(X, torch.Tensor): - corrupted_X = _mar_logistic_torch(X, missing_rate, obs_rate) + corrupted_X = _mar_logistic_torch(X, obs_rate, missing_rate) else: raise TypeError( f"X must be type of list/numpy.ndarray/torch.Tensor, but got {type(X)}"