From 11bb6c032f40eea683702e518ffef6795c20217c Mon Sep 17 00:00:00 2001 From: Joe Marks Date: Sat, 19 Aug 2023 22:53:01 +0100 Subject: [PATCH] Add verbose option to TVAE --- ctgan/synthesizers/tvae.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ctgan/synthesizers/tvae.py b/ctgan/synthesizers/tvae.py index 8dafd4e7..4765483d 100644 --- a/ctgan/synthesizers/tvae.py +++ b/ctgan/synthesizers/tvae.py @@ -110,6 +110,7 @@ def __init__( decompress_dims=(128, 128), l2scale=1e-5, batch_size=500, + verbose=False, epochs=300, loss_factor=2, cuda=True @@ -122,6 +123,7 @@ def __init__( self.l2scale = l2scale self.batch_size = batch_size self.loss_factor = loss_factor + self.verbose = verbose self.epochs = epochs if not cuda or not torch.cuda.is_available(): @@ -176,6 +178,12 @@ def fit(self, train_data, discrete_columns=()): optimizerAE.step() self.decoder.sigma.data.clamp_(0.01, 1.0) + if self.verbose: + print(f'Epoch {i+1}, Loss: {loss.detach().cpu(): .4f},', # noqa: T001 + f' Rec loss: {loss_1.detach().cpu(): .4f},', + f' KL loss: {loss_2.detach().cpu(): .4f}', + flush=True) + @random_state def sample(self, samples): """Sample data similar to the training data.