diff --git a/ctgan/synthesizers/tvae.py b/ctgan/synthesizers/tvae.py index ecefbb5f..fb37b8ef 100644 --- a/ctgan/synthesizers/tvae.py +++ b/ctgan/synthesizers/tvae.py @@ -128,6 +128,8 @@ def __init__( self.loss_values = pd.DataFrame(columns=['Epoch', 'Batch', 'Loss']) self.verbose = verbose + self.decoder = None + if not cuda or not torch.cuda.is_available(): device = 'cpu' elif isinstance(cuda, str): @@ -243,4 +245,5 @@ def sample(self, samples): def set_device(self, device): """Set the `device` to be used ('GPU' or 'CPU).""" self._device = device - self.decoder.to(self._device) + if self.decoder: + self.decoder.to(self._device)