From 9692f03c3cced56c3856217c9dab8c7b778c3e84 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Mon, 8 May 2023 12:20:55 -0700 Subject: [PATCH] ClassyModelWrapper as nn.Module Summary: Fixes https://fb.workplace.com/groups/1148364849411890/permalink/1278947843020256/ Differential Revision: D45665663 fbshipit-source-id: 9c02f011bf8f613f0b3aaa8d8d17b189b239f953 --- classy_vision/models/classy_model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/classy_vision/models/classy_model.py b/classy_vision/models/classy_model.py index 0f91c6516..11f3c440a 100644 --- a/classy_vision/models/classy_model.py +++ b/classy_vision/models/classy_model.py @@ -59,7 +59,7 @@ def __call__(self, *args, **kwargs): return ret_val -class ClassyModelWrapper: +class ClassyModelWrapper(torch.nn.Module): """Base ClassyModel wrapper class. This class acts as a thin pass through wrapper which lets users modify the behavior @@ -68,9 +68,8 @@ class ClassyModelWrapper: accessed by the `classy_model` attribute. """ - # TODO: Make this torchscriptable by inheriting from nn.Module / ClassyModel - def __init__(self, classy_model): + super().__init__() self.classy_model = classy_model def __getattr__(self, name):