diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py
index ed24453..097e74f 100644
--- a/src/refiners/training_utils/config.py
+++ b/src/refiners/training_utils/config.py
@@ -211,6 +211,8 @@ class HuggingfaceDatasetConfig(BaseModel):
     horizontal_flip: bool = False
     random_crop: bool = True
     use_verification: bool = False
+    resize_image_min_size: int = 512
+    resize_image_max_size: int = 576
 
 
 class CheckpointingConfig(BaseModel):
diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py
index eee6d0f..2195e7c 100644
--- a/src/refiners/training_utils/latent_diffusion.py
+++ b/src/refiners/training_utils/latent_diffusion.py
@@ -98,7 +98,11 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
     def __getitem__(self, index: int) -> TextEmbeddingLatentsBatch:
         item = self.dataset[index]
         caption, image = item["caption"], item["image"]
-        resized_image = self.resize_image(image=image)
+        resized_image = self.resize_image(
+            image=image,
+            min_size=self.config.dataset.resize_image_min_size,
+            max_size=self.config.dataset.resize_image_max_size,
+        )
         processed_image = self.process_image(resized_image)
         latents = self.lda.encode_image(image=processed_image).to(device=self.device)
         processed_caption = self.process_caption(caption=caption)