diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
new file mode 100644
index 0000000..729aadc
--- /dev/null
+++ b/.devcontainer/devcontainer.json
@@ -0,0 +1,8 @@
+{
+    "image": "mcr.microsoft.com/devcontainers/universal:2",
+    "features": {
+        "ghcr.io/devcontainers/features/nvidia-cuda:1": {
+            "installCudnn": true
+        }
+    }
+}
diff --git a/bin/run_patchcore.py b/bin/run_patchcore.py
index 8666b2b..7b6113b 100644
--- a/bin/run_patchcore.py
+++ b/bin/run_patchcore.py
@@ -16,7 +16,10 @@
 
 LOGGER = logging.getLogger(__name__)
 
-_DATASETS = {"mvtec": ["patchcore.datasets.mvtec", "MVTecDataset"]}
+_DATASETS = {"mvtec": ["patchcore.datasets.mvtec", "MVTecDataset"],
+             "inspl": ["patchcore.datasets.inspl", "InsplDataset"]}
+
+
 
 
 @click.group(chain=True)
@@ -50,7 +53,8 @@ def run(
 
     list_of_dataloaders = methods["get_dataloaders"](seed)
 
-    device = patchcore.utils.set_torch_device(gpu)
+    # Passing an empty list forces CPU; pass the gpu argument to use GPUs.
+    device = patchcore.utils.set_torch_device([])
     # Device context here is specifically set and used later
     # because there was GPU memory-bleeding which I could only fix with
    # context managers.
diff --git a/sample_training.sh b/sample_training.sh
index 61685e0..149274d 100644
--- a/sample_training.sh
+++ b/sample_training.sh
@@ -1,5 +1,8 @@
-datapath=/path/to/data/from/mvtec
-datasets=('bottle' 'cable' 'capsule' 'carpet' 'grid' 'hazelnut' 'leather' 'metal_nut' 'pill' 'screw' 'tile' 'toothbrush' 'transistor' 'wood' 'zipper')
+datapath=/workspaces/patchcore-inspection/inspl
+datasets=('damper-stockbridge' 'vari-grip' 'yoke' 'spacer' 'lightning-rod-shackle' 'plate' 'damper-preformed' 'polymer-insulator-lower-shackle' 'yoke-suspension'
+'polymer-insulator-upper-shackle' 'polymer-insulator' 'glass-insulator'
+'glass-insulator-big-shackle' 'glass-insulator-tower-shackle' 'polymer-insulator-tower-shackle'
+'lightning-rod-suspension')
 dataset_flags=($(for dataset in "${datasets[@]}"; do echo '-d '"${dataset}"; done))
 
 ############# Detection
diff --git a/src/patchcore/datasets/inspl.py b/src/patchcore/datasets/inspl.py
new file mode 100644
index 0000000..f6b14ee
--- /dev/null
+++ b/src/patchcore/datasets/inspl.py
@@ -0,0 +1,152 @@
+import os
+from enum import Enum
+
+import PIL
+import torch
+from torchvision import transforms
+
+_CLASSNAMES = ['damper-stockbridge',
+               'vari-grip',
+               'yoke',
+               'spacer',
+               'lightning-rod-shackle',
+               'plate',
+               'damper-preformed',
+               'polymer-insulator-lower-shackle',
+               'yoke-suspension',
+               'polymer-insulator-upper-shackle',
+               'polymer-insulator',
+               'glass-insulator',
+               'glass-insulator-big-shackle',
+               'glass-insulator-tower-shackle',
+               'polymer-insulator-tower-shackle',
+               'lightning-rod-suspension',
+               'glass-insulator-small-shackle',
+               ]
+
+
+IMAGENET_MEAN = [0.485, 0.456, 0.406]
+IMAGENET_STD = [0.229, 0.224, 0.225]
+
+
+class DatasetSplit(Enum):
+    TRAIN = "train"
+    VAL = "val"
+    TEST = "test"
+
+
+class InsplDataset(torch.utils.data.Dataset):
+    """
+    PyTorch Dataset for Inspl.
+    """
+
+    def __init__(
+        self,
+        source,
+        classname,
+        resize=256,
+        imagesize=224,
+        split=DatasetSplit.TRAIN,
+        train_val_split=1.0,
+        **kwargs,
+    ):
+        """
+        Args:
+            source: [str]. Path to the Inspl data folder.
+            classname: [str or None]. Name of Inspl class that should be
+                       provided in this dataset. If None, the dataset
+                       iterates over all available images.
+            resize: [int]. (Square) Size the loaded image initially gets
+                    resized to.
+            imagesize: [int]. (Square) Size the resized loaded image gets
+                       (center-)cropped to.
+            split: [enum-option]. Indicates if training or test split of the
+                   data should be used. Has to be an option taken from
+                   DatasetSplit, e.g. inspl.DatasetSplit.TRAIN. Note that
+                   Inspl provides no ground-truth anomaly masks, so a zero
+                   mask is returned for every image.
+            train_val_split: [float]. Fraction of the training images kept
+                             for training; the remainder forms the val split.
+        """
+        super().__init__()
+        self.source = source
+        self.split = split
+        self.classnames_to_use = [classname] if classname is not None else _CLASSNAMES
+        self.train_val_split = train_val_split
+
+        self.imgpaths_per_class, self.data_to_iterate = self.get_image_data()
+
+        self.transform_img = [
+            transforms.Resize(resize),
+            transforms.CenterCrop(imagesize),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
+        ]
+        self.transform_img = transforms.Compose(self.transform_img)
+
+        self.transform_mask = [
+            transforms.Resize(resize),
+            transforms.CenterCrop(imagesize),
+            transforms.ToTensor(),
+        ]
+        self.transform_mask = transforms.Compose(self.transform_mask)
+
+        self.imagesize = (3, imagesize, imagesize)
+
+    def __getitem__(self, idx):
+        classname, anomaly, image_path, mask_path = self.data_to_iterate[idx]
+        image = PIL.Image.open(image_path).convert("RGB")
+        image = self.transform_img(image)
+
+        if self.split == DatasetSplit.TEST and mask_path is not None:
+            mask = PIL.Image.open(mask_path)
+            mask = self.transform_mask(mask)
+        else:
+            mask = torch.zeros([1, *image.size()[1:]])
+
+        return {
+            "image": image,
+            "mask": mask,
+            "classname": classname,
+            "anomaly": anomaly,
+            "is_anomaly": int(anomaly != "good"),
+            "image_name": "/".join(image_path.split("/")[-4:]),
+            "image_path": image_path,
+        }
+
+    def __len__(self):
+        return len(self.data_to_iterate)
+
+    def get_image_data(self):
+        imgpaths_per_class = {}
+
+        for classname in self.classnames_to_use:
+            classpath = os.path.join(self.source, classname, self.split.value)
+            anomaly_types = os.listdir(classpath)
+
+            imgpaths_per_class[classname] = {}
+
+            for anomaly in anomaly_types:
+                anomaly_path = os.path.join(classpath, anomaly)
+                anomaly_files = sorted(os.listdir(anomaly_path))
+                imgpaths_per_class[classname][anomaly] = [
+                    os.path.join(anomaly_path, x) for x in anomaly_files
+                ]
+
+                if self.train_val_split < 1.0:
+                    n_images = len(imgpaths_per_class[classname][anomaly])
+                    train_val_split_idx = int(n_images * self.train_val_split)
+                    if self.split == DatasetSplit.TRAIN:
+                        imgpaths_per_class[classname][anomaly] = imgpaths_per_class[
+                            classname
+                        ][anomaly][:train_val_split_idx]
+                    elif self.split == DatasetSplit.VAL:
+                        imgpaths_per_class[classname][anomaly] = imgpaths_per_class[
+                            classname
+                        ][anomaly][train_val_split_idx:]
+
+        # Unrolls the data dictionary to an easy-to-iterate list.
+        data_to_iterate = []
+        for classname in sorted(imgpaths_per_class.keys()):
+            for anomaly in sorted(imgpaths_per_class[classname].keys()):
+                for image_path in imgpaths_per_class[classname][anomaly]:
+                    # Inspl ships no ground-truth masks, so mask_path is None.
+                    data_tuple = [classname, anomaly, image_path, None]
+                    data_to_iterate.append(data_tuple)
+
+        return imgpaths_per_class, data_to_iterate
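
Since the patch only wires InsplDataset into run_patchcore.py, a standalone smoke test is the quickest way to verify the new loader end to end. The sketch below is not part of the patch: the data root and class name are copied from sample_training.sh above, and it assumes the train split follows the MVTec-style layout (a "good" subfolder per class) that get_image_data expects.

# smoke_test_inspl.py -- hypothetical helper, not included in this patch.
import sys

sys.path.append("src")  # mirrors the path setup used by bin/run_patchcore.py

import torch

from patchcore.datasets.inspl import DatasetSplit, InsplDataset

# Data root and class name taken from sample_training.sh; adjust to your setup.
dataset = InsplDataset(
    source="/workspaces/patchcore-inspection/inspl",
    classname="damper-stockbridge",
    split=DatasetSplit.TRAIN,
)
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False)

batch = next(iter(loader))
print(batch["image"].shape)  # expected: torch.Size([4, 3, 224, 224])
print(batch["is_anomaly"])   # 0 for images in a "good" folder, 1 otherwise

If this runs, the same data should load through the CLI, since the new "inspl" entry in _DATASETS resolves to exactly this class.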