Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions captum/_utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
TupleOrTensorOrBoolGeneric = TypeVar("TupleOrTensorOrBoolGeneric", Tuple, Tensor, bool)
ModuleOrModuleList = TypeVar("ModuleOrModuleList", Module, List[Module])
ModuleOrModuleListOrModuleTuple = TypeVar(
"ModuleOrModuleListOrModuleTuple", Module, List[Module], Tuple[Module]
)
TargetType = Union[None, int, Tuple[int, ...], Tensor, List[Tuple[int, ...]], List[int]]
BaselineTupleType = Union[None, Tuple[Union[Tensor, int, float], ...]]
BaselineType = Union[None, Tensor, int, float, BaselineTupleType]
Expand Down
5 changes: 3 additions & 2 deletions captum/attr/_core/layer/layer_lrp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from captum._utils.typing import (
ModuleOrModuleList,
ModuleOrModuleListOrModuleTuple,
TargetType,
TensorOrTupleOfTensorsGeneric,
)
Expand Down Expand Up @@ -253,7 +254,7 @@ def attribute(

if return_convergence_delta:
delta: Union[Tensor, List[Tensor]]
if isinstance(self.layer, list):
if isinstance(self.layer, list) or isinstance(self.layer, tuple):
delta = []
for relevance_layer in relevances:
delta.append(
Expand Down Expand Up @@ -294,7 +295,7 @@ def _get_single_output_relevance(self, layer, output):
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def _get_output_relevance(self, output):
if isinstance(self.layer, list):
if isinstance(self.layer, list) or isinstance(self.layer, tuple):
relevances = []
for layer in self.layer:
relevances.append(self._get_single_output_relevance(layer, output))
Expand Down