diff --git a/voicebox_pytorch/trainer.py b/voicebox_pytorch/trainer.py index 4c85790..01a209b 100644 --- a/voicebox_pytorch/trainer.py +++ b/voicebox_pytorch/trainer.py @@ -14,6 +14,7 @@ from voicebox_pytorch.optimizer import get_optimizer from accelerate import Accelerator, DistributedType +from accelerate.utils import DistributedDataParallelKwargs # helpers