Make image resize configurable in training scripts

This commit is contained in:
Doryan Kaced 2023-08-30 12:53:43 +02:00
parent 437fa24368
commit 08a5341452
2 changed files with 7 additions and 1 deletions

View file

@ -211,6 +211,8 @@ class HuggingfaceDatasetConfig(BaseModel):
horizontal_flip: bool = False horizontal_flip: bool = False
random_crop: bool = True random_crop: bool = True
use_verification: bool = False use_verification: bool = False
resize_image_min_size: int = 512
resize_image_max_size: int = 576
class CheckpointingConfig(BaseModel): class CheckpointingConfig(BaseModel):

View file

@ -98,7 +98,11 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
def __getitem__(self, index: int) -> TextEmbeddingLatentsBatch: def __getitem__(self, index: int) -> TextEmbeddingLatentsBatch:
item = self.dataset[index] item = self.dataset[index]
caption, image = item["caption"], item["image"] caption, image = item["caption"], item["image"]
resized_image = self.resize_image(image=image) resized_image = self.resize_image(
image=image,
min_size=self.config.dataset.resize_image_min_size,
max_size=self.config.dataset.resize_image_max_size,
)
processed_image = self.process_image(resized_image) processed_image = self.process_image(resized_image)
latents = self.lda.encode_image(image=processed_image).to(device=self.device) latents = self.lda.encode_image(image=processed_image).to(device=self.device)
processed_caption = self.process_caption(caption=caption) processed_caption = self.process_caption(caption=caption)