Skip to content

Commit a88e31a

Browse files
Fix bug in handling labels with LabelTensor (#460)
--------- Co-authored-by: Filippo Olivo <[email protected]>
1 parent 12b7787 commit a88e31a

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

pina/label_tensor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,9 @@ def __getitem__(self, index):
448448

449449
# Retrieve selected tensor and labels
450450
selected_tensor = super().__getitem__(index)
451+
if not hasattr(self, "_labels"):
452+
return selected_tensor
453+
451454
original_labels = self._labels
452455
updated_labels = copy(original_labels)
453456

pina/model/block/convolution_2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def _make_grid_transpose(self, X):
265265
266266
"""
267267
# initialize to all zeros
268-
tmp = torch.zeros_like(X)
268+
tmp = torch.zeros_like(X).as_subclass(torch.Tensor)
269269
tmp[..., :-1] = X[..., :-1]
270270

271271
# save on tmp

0 commit comments

Comments
 (0)