From 04dffbd863f0feba7f92129cc4939bc06c57e93c Mon Sep 17 00:00:00 2001 From: shibuiwilliam Date: Fri, 29 Sep 2023 06:06:36 +0000 Subject: [PATCH 1/2] add directory args to download LPIPS model --- taming/modules/losses/lpips.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/taming/modules/losses/lpips.py b/taming/modules/losses/lpips.py index a7280447..923d493f 100644 --- a/taming/modules/losses/lpips.py +++ b/taming/modules/losses/lpips.py @@ -10,8 +10,9 @@ class LPIPS(nn.Module): # Learned perceptual metric - def __init__(self, use_dropout=True): + def __init__(self, use_dropout=True, download_directory: str = "/tmp/"): super().__init__() + self.download_directory = download_directory self.scaling_layer = ScalingLayer() self.chns = [64, 128, 256, 512, 512] # vg16 features self.net = vgg16(pretrained=True, requires_grad=False) @@ -25,7 +26,9 @@ def __init__(self, use_dropout=True): param.requires_grad = False def load_from_pretrained(self, name="vgg_lpips"): - ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") + root = os.path.join(self.download_directory, "taming/modules/autoencoder/lpips") + os.makedirs(root, exist_ok=True) + ckpt = get_ckpt_path(name, root) self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) print("loaded pretrained LPIPS loss from {}".format(ckpt)) From 426920c7739771d06f4c62fc793663ca83deebc4 Mon Sep 17 00:00:00 2001 From: shibuiwilliam Date: Fri, 29 Sep 2023 06:08:57 +0000 Subject: [PATCH 2/2] add missing import os --- taming/modules/losses/lpips.py | 1 + 1 file changed, 1 insertion(+) diff --git a/taming/modules/losses/lpips.py b/taming/modules/losses/lpips.py index 923d493f..00a46d27 100644 --- a/taming/modules/losses/lpips.py +++ b/taming/modules/losses/lpips.py @@ -1,4 +1,5 @@ """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" +import os import torch import torch.nn as nn