Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 41 additions & 7 deletions trident/slide_encoder_models/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from abc import abstractmethod
from einops import rearrange
from typing import Optional, Tuple

from trident.IO import get_weights_path
from trident.IO import get_weights_path, has_internet_connection

"""
This file contains 10+ pretrained slide encoders, all loadable via the encoder_factory() function.
Expand Down Expand Up @@ -50,13 +49,16 @@ def encoder_factory(model_name: str, pretrained: bool = True, freeze: bool = Tru


class BaseSlideEncoder(torch.nn.Module):

def __init__(self, freeze: bool = True, **build_kwargs: dict) -> None:

_has_internet = has_internet_connection()

def __init__(self, weights_path: Optional[str] = None, freeze: bool = True, **build_kwargs: dict) -> None:
"""
Parent class for all pretrained slide encoders.
"""
super().__init__()
self.enc_name = None
self.weights_path: Optional[str] = weights_path
self.model, self.precision, self.embedding_dim = self._build(**build_kwargs)

# Set all parameters to be non-trainable
Expand All @@ -71,7 +73,33 @@ def forward(self, batch):
"""
z = self.model(batch)
return z



def ensure_valid_weights_path(self, weights_path):
    """
    Check that a checkpoint location exists on disk.

    A falsy path (``None`` or an empty string) is accepted without any check,
    since there is nothing to validate.

    Args:
        weights_path: Path to a checkpoint file or directory, or a falsy value.

    Raises:
        FileNotFoundError: If a non-empty path is neither an existing file
            nor an existing directory.
    """
    # Nothing supplied: nothing to validate.
    if not weights_path:
        return

    # Checkpoints may be a single file or a whole directory.
    found = os.path.isfile(weights_path) or os.path.isdir(weights_path)
    if not found:
        raise FileNotFoundError(f"Expected checkpoint at '{weights_path}', but the file was not found.")

def _get_weights_path(self):
    """
    Resolve the checkpoint location for this encoder.

    An explicitly provided ``self.weights_path`` takes priority. Otherwise the
    path is looked up with ``get_weights_path('slide', self.enc_name)``. The
    resolved value is passed through ``ensure_valid_weights_path`` before it
    is returned.

    Returns:
        The resolved path. It may be falsy when neither source supplies one;
        callers then fall back to auto-download.

    Raises:
        FileNotFoundError: If the resolved, non-empty path does not exist.
    """
    # `or` only consults the registry when no explicit path was given.
    resolved = self.weights_path or get_weights_path('slide', self.enc_name)
    self.ensure_valid_weights_path(resolved)
    return resolved


def ensure_has_internet(self, enc_name):
    """
    Fail fast when an automatic checkpoint download is impossible.

    Reads the class-level ``BaseSlideEncoder._has_internet`` flag and raises
    if it is falsy.

    Args:
        enc_name: Name of the encoder whose checkpoint would be downloaded;
            used only in the error message.

    Raises:
        FileNotFoundError: If no internet connection was detected.
    """
    if not BaseSlideEncoder._has_internet:
        # Adjacent literals are concatenated as-is, so each sentence needs
        # its own trailing space or newline.
        raise FileNotFoundError(
            "Internet connection does not seem to be available. "
            "Auto checkpoint download is disabled. "
            f"To proceed, please manually download: {enc_name},\n"
            "and place it in the model registry in:\n`trident/slide_encoder_models/local_ckpts.json`"
        )

@abstractmethod
def _build(self, **build_kwargs):
"""
Expand Down Expand Up @@ -410,8 +438,14 @@ def __init__(self, **build_kwargs):
def _build(self, pretrained=True):
    """
    Build the TITAN slide encoder.

    A local checkpoint is used when ``self._get_weights_path()`` returns a
    non-empty path. Otherwise connectivity is checked and the model is loaded
    from the ``'MahmoodLab/TITAN'`` identifier.

    Args:
        pretrained: Must be truthy; no non-pretrained variant exists.

    Returns:
        Tuple ``(model, precision, embedding_dim)`` with precision
        ``torch.float16`` and embedding dimension ``768``.

    Raises:
        AssertionError: If ``pretrained`` is falsy.
        FileNotFoundError: If a configured checkpoint path does not exist, or
            if no local path is configured and no internet was detected.
    """
    self.enc_name = 'titan'
    assert pretrained, "TitanSlideEncoder has no non-pretrained models. Please load with pretrained=True."
    # Imported lazily so transformers is only required when TITAN is used.
    from transformers import AutoModel

    # Resolve the local checkpoint first, so the hub is only contacted when
    # no local path is configured.
    weights_path = self._get_weights_path()
    if weights_path:
        model = AutoModel.from_pretrained(weights_path, trust_remote_code=True)
    else:
        self.ensure_has_internet(self.enc_name)
        model = AutoModel.from_pretrained('MahmoodLab/TITAN', trust_remote_code=True)
    precision = torch.float16
    embedding_dim = 768
    return model, precision, embedding_dim
Expand Down
3 changes: 2 additions & 1 deletion trident/slide_encoder_models/local_ckpts.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{
"chief": "./CHIEF",
"madeleine": "./MADELEINE",
"titan": ""
}