From a04add43f62804721544239848cfc582bdcf648f Mon Sep 17 00:00:00 2001 From: amarcelq <88303414+amarcelq@users.noreply.github.com> Date: Fri, 25 Oct 2024 11:32:51 +0200 Subject: [PATCH] Check if decoder is set Without this check things like the sklearn.randomsampler run into an error, since they deepcopy the class before they call the fit method --- ctgan/synthesizers/tvae.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ctgan/synthesizers/tvae.py b/ctgan/synthesizers/tvae.py index ecefbb5f..fb37b8ef 100644 --- a/ctgan/synthesizers/tvae.py +++ b/ctgan/synthesizers/tvae.py @@ -128,6 +128,8 @@ def __init__( self.loss_values = pd.DataFrame(columns=['Epoch', 'Batch', 'Loss']) self.verbose = verbose + self.decoder = None + if not cuda or not torch.cuda.is_available(): device = 'cpu' elif isinstance(cuda, str): @@ -243,4 +245,5 @@ def sample(self, samples): def set_device(self, device): """Set the `device` to be used ('GPU' or 'CPU).""" self._device = device - self.decoder.to(self._device) + if self.decoder: + self.decoder.to(self._device)