diff --git a/bin/predict.py b/bin/predict.py index 4ff90dc2..3a89660a 100755 --- a/bin/predict.py +++ b/bin/predict.py @@ -41,7 +41,7 @@ def main(predict_config: OmegaConf): if sys.platform != 'win32': register_debug_signal_handlers() # kill -10 will result in traceback dumped into log - device = torch.device("cpu") + device = torch.device(predict_config.get('device', 'cpu')) train_config_path = os.path.join(predict_config.model.path, 'config.yaml') with open(train_config_path, 'r') as f: diff --git a/configs/prediction/default.yaml b/configs/prediction/default.yaml index 80fa69b2..113f4a27 100644 --- a/configs/prediction/default.yaml +++ b/configs/prediction/default.yaml @@ -7,11 +7,12 @@ model: dataset: kind: default - img_suffix: .png + img_suffix: .jpg pad_out_to_modulo: 8 device: cuda out_key: inpainted +out_ext: .jpg refine: False # refiner will only run if this is True refiner: