diff --git a/taming/modules/losses/lpips.py b/taming/modules/losses/lpips.py index 86a00ba4..9ef3ad44 100644 --- a/taming/modules/losses/lpips.py +++ b/taming/modules/losses/lpips.py @@ -1,5 +1,6 @@ """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" +import os import torch import torch.nn as nn from torchvision import models @@ -25,7 +26,7 @@ def __init__(self, use_dropout=True): param.requires_grad = False def load_from_pretrained(self, name="vgg_lpips"): - ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") + ckpt = get_ckpt_path(name, os.path.expanduser("~/.cache/taming/modules/autoencoder/lpips")) self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) print("loaded pretrained LPIPS loss from {}".format(ckpt))