From 577e39404b1d141a8a83a53bdb39c561b55f9fb4 Mon Sep 17 00:00:00 2001 From: ZeroCool Date: Tue, 16 Aug 2022 23:42:13 -0700 Subject: [PATCH] Allow folders for the train and validation paths. This changes in the `taming/data/custom.py` file allow the use of folders for the training and validation paths while been compatible with how it was used before. This should make it easier to train on custom data as you no longer need to create a `train.txt` and `val.txt` file manually which sometimes can be a time-consuming and tedious task, now you can just create two folders, one for the train and one for the validation containing the images inside and the script will discover the files and create the list of paths by itself, this can also allow you to add more images to those folders later without having to manually recreate the train and validation files, not sure about the use for that but the option is there. --- taming/data/custom.py | 38 ++++++++++++++++++++++++++++++++++---- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/taming/data/custom.py b/taming/data/custom.py index 33f302a4..543966ab 100644 --- a/taming/data/custom.py +++ b/taming/data/custom.py @@ -23,16 +23,46 @@ def __getitem__(self, i): class CustomTrain(CustomBase): def __init__(self, size, training_images_list_file): super().__init__() - with open(training_images_list_file, "r") as f: - paths = f.read().splitlines() + + isFile = os.path.isfile(training_images_list_file) + isDirectory = os.path.isdir(training_images_list_file) + + if isFile: + with open(training_images_list_file, "r") as f: + paths = f.read().splitlines() + + if isDirectory: + paths = [] + for images in os.listdir(training_images_list_file): + + # check if the image ends with png or jpg or jpeg + if (images.endswith(".png") or images.endswith(".jpg")\ + or images.endswith(".jpeg")): + paths.append(os.path.join(training_images_list_file, images)) + self.data = ImagePaths(paths=paths, size=size, random_crop=False) class CustomTest(CustomBase): def __init__(self, size, test_images_list_file): super().__init__() - with open(test_images_list_file, "r") as f: - paths = f.read().splitlines() + + isFile = os.path.isfile(test_images_list_file) + isDirectory = os.path.isdir(test_images_list_file) + + if isFile: + with open(test_images_list_file, "r") as f: + paths = f.read().splitlines() + + if isDirectory: + paths = [] + for images in os.listdir(test_images_list_file): + + # check if the image ends with png or jpg or jpeg + if (images.endswith(".png") or images.endswith(".jpg")\ + or images.endswith(".jpeg")): + paths.append(os.path.join(test_images_list_file, images)) + self.data = ImagePaths(paths=paths, size=size, random_crop=False)