mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
remove Black preview mode
also fix multiline logs in training
This commit is contained in:
parent
4176868e79
commit
f22f969d65
|
@ -54,7 +54,6 @@ build-backend = "poetry.core.masonry.api"
|
||||||
|
|
||||||
[tool.black]
|
[tool.black]
|
||||||
line-length = 120
|
line-length = 120
|
||||||
preview = true
|
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
ignore = [
|
ignore = [
|
||||||
|
|
|
@ -292,13 +292,16 @@ class Chain(ContextModule):
|
||||||
return Chain(*self, *other)
|
return Chain(*self, *other)
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __getitem__(self, key: int) -> Module: ...
|
def __getitem__(self, key: int) -> Module:
|
||||||
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __getitem__(self, key: str) -> Module: ...
|
def __getitem__(self, key: str) -> Module:
|
||||||
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __getitem__(self, key: slice) -> "Chain": ...
|
def __getitem__(self, key: slice) -> "Chain":
|
||||||
|
...
|
||||||
|
|
||||||
def __getitem__(self, key: int | str | slice) -> Module:
|
def __getitem__(self, key: int | str | slice) -> Module:
|
||||||
if isinstance(key, slice):
|
if isinstance(key, slice):
|
||||||
|
@ -346,10 +349,12 @@ class Chain(ContextModule):
|
||||||
@overload
|
@overload
|
||||||
def walk(
|
def walk(
|
||||||
self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
|
self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
|
||||||
) -> Iterator[tuple[Module, "Chain"]]: ...
|
) -> Iterator[tuple[Module, "Chain"]]:
|
||||||
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def walk(self, predicate: type[T], recurse: bool = False) -> Iterator[tuple[T, "Chain"]]: ...
|
def walk(self, predicate: type[T], recurse: bool = False) -> Iterator[tuple[T, "Chain"]]:
|
||||||
|
...
|
||||||
|
|
||||||
def walk(
|
def walk(
|
||||||
self, predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
|
self, predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
|
||||||
|
|
|
@ -65,18 +65,22 @@ class LatentDiffusionModel(fl.Module, ABC):
|
||||||
return self.scheduler.steps
|
return self.scheduler.steps
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None: ...
|
def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None: ...
|
def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def has_self_attention_guidance(self) -> bool: ...
|
def has_self_attention_guidance(self) -> bool:
|
||||||
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def compute_self_attention_guidance(
|
def compute_self_attention_guidance(
|
||||||
self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor
|
self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor
|
||||||
) -> Tensor: ...
|
) -> Tensor:
|
||||||
|
...
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor
|
self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor
|
||||||
|
|
|
@ -69,7 +69,8 @@ class MultiDiffusion(Generic[T, D], ABC):
|
||||||
return torch.where(condition=num_updates > 0, input=cumulative_values / num_updates, other=x)
|
return torch.where(condition=num_updates > 0, input=cumulative_values / num_updates, other=x)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def diffuse_target(self, x: Tensor, step: int, target: D) -> Tensor: ...
|
def diffuse_target(self, x: Tensor, step: int, target: D) -> Tensor:
|
||||||
|
...
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def steps(self) -> list[int]:
|
def steps(self) -> list[int]:
|
||||||
|
|
|
@ -56,7 +56,9 @@ class DDPM(Scheduler):
|
||||||
else tensor(1, device=self.device)
|
else tensor(1, device=self.device)
|
||||||
)
|
)
|
||||||
current_factor = current_cumulative_factor / previous_cumulative_scale_factor
|
current_factor = current_cumulative_factor / previous_cumulative_scale_factor
|
||||||
estimated_denoised_data = (x - (1 - current_cumulative_factor) ** 0.5 * noise) / current_cumulative_factor**0.5
|
estimated_denoised_data = (
|
||||||
|
x - (1 - current_cumulative_factor) ** 0.5 * noise
|
||||||
|
) / current_cumulative_factor**0.5
|
||||||
estimated_denoised_data = estimated_denoised_data.clamp(-1, 1)
|
estimated_denoised_data = estimated_denoised_data.clamp(-1, 1)
|
||||||
original_data_coeff = (previous_cumulative_scale_factor**0.5 * (1 - current_factor)) / (
|
original_data_coeff = (previous_cumulative_scale_factor**0.5 * (1 - current_factor)) / (
|
||||||
1 - current_cumulative_factor
|
1 - current_cumulative_factor
|
||||||
|
|
|
@ -42,59 +42,82 @@ T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class Callback(Generic[T]):
|
class Callback(Generic[T]):
|
||||||
def on_train_begin(self, trainer: T) -> None: ...
|
def on_train_begin(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
def on_train_end(self, trainer: T) -> None: ...
|
def on_train_end(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
def on_epoch_begin(self, trainer: T) -> None: ...
|
def on_epoch_begin(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
def on_epoch_end(self, trainer: T) -> None: ...
|
def on_epoch_end(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
def on_batch_begin(self, trainer: T) -> None: ...
|
def on_batch_begin(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
def on_batch_end(self, trainer: T) -> None: ...
|
def on_batch_end(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
def on_backward_begin(self, trainer: T) -> None: ...
|
def on_backward_begin(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
def on_backward_end(self, trainer: T) -> None: ...
|
def on_backward_end(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
def on_optimizer_step_begin(self, trainer: T) -> None: ...
|
def on_optimizer_step_begin(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
def on_optimizer_step_end(self, trainer: T) -> None: ...
|
def on_optimizer_step_end(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
def on_compute_loss_begin(self, trainer: T) -> None: ...
|
def on_compute_loss_begin(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
def on_compute_loss_end(self, trainer: T) -> None: ...
|
def on_compute_loss_end(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
def on_evaluate_begin(self, trainer: T) -> None: ...
|
def on_evaluate_begin(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
def on_evaluate_end(self, trainer: T) -> None: ...
|
def on_evaluate_end(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
def on_lr_scheduler_step_begin(self, trainer: T) -> None: ...
|
def on_lr_scheduler_step_begin(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
def on_lr_scheduler_step_end(self, trainer: T) -> None: ...
|
def on_lr_scheduler_step_end(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
def on_checkpoint_save(self, trainer: T) -> None: ...
|
def on_checkpoint_save(self, trainer: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class ClockCallback(Callback["Trainer[BaseConfig, Any]"]):
|
class ClockCallback(Callback["Trainer[BaseConfig, Any]"]):
|
||||||
def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
def on_train_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
trainer.clock.reset()
|
trainer.clock.reset()
|
||||||
logger.info(f"""Starting training for a total of:
|
logger.info(
|
||||||
{trainer.clock.num_steps} steps.
|
(
|
||||||
{trainer.clock.num_epochs} epochs.
|
"Starting training for a total of: "
|
||||||
{trainer.clock.num_iterations} iterations.
|
f"{trainer.clock.num_steps} steps, "
|
||||||
""")
|
f"{trainer.clock.num_epochs} epochs, "
|
||||||
|
f"{trainer.clock.num_iterations} iterations."
|
||||||
|
)
|
||||||
|
)
|
||||||
trainer.clock.start_timer()
|
trainer.clock.start_timer()
|
||||||
|
|
||||||
def on_train_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
def on_train_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
trainer.clock.stop_timer()
|
trainer.clock.stop_timer()
|
||||||
logger.info(f"""Training took:
|
logger.info(
|
||||||
{trainer.clock.time_elapsed} seconds.
|
(
|
||||||
{trainer.clock.iteration} iterations.
|
"Training took: "
|
||||||
{trainer.clock.epoch} epochs.
|
f"{trainer.clock.time_elapsed} seconds, "
|
||||||
{trainer.clock.step} steps.
|
f"{trainer.clock.iteration} iterations, "
|
||||||
""")
|
f"{trainer.clock.epoch} epochs, "
|
||||||
|
f"{trainer.clock.step} steps."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def on_epoch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
def on_epoch_begin(self, trainer: "Trainer[BaseConfig, Any]") -> None:
|
||||||
logger.info(f"Epoch {trainer.clock.epoch} started.")
|
logger.info(f"Epoch {trainer.clock.epoch} started.")
|
||||||
|
|
|
@ -8,9 +8,11 @@ T = TypeVar("T", covariant=True)
|
||||||
|
|
||||||
|
|
||||||
class HuggingfaceDataset(Generic[T], Protocol):
|
class HuggingfaceDataset(Generic[T], Protocol):
|
||||||
def __getitem__(self, index: int) -> T: ...
|
def __getitem__(self, index: int) -> T:
|
||||||
|
...
|
||||||
|
|
||||||
def __len__(self) -> int: ...
|
def __len__(self) -> int:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
def load_hf_dataset(
|
def load_hf_dataset(
|
||||||
|
|
|
@ -147,13 +147,11 @@ class TrainingClock:
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def unit_to_steps(self) -> dict[TimeUnit, int]:
|
def unit_to_steps(self) -> dict[TimeUnit, int]:
|
||||||
|
iteration_factor = self.num_batches_per_epoch if self.gradient_accumulation["unit"] == TimeUnit.EPOCH else 1
|
||||||
return {
|
return {
|
||||||
TimeUnit.STEP: 1,
|
TimeUnit.STEP: 1,
|
||||||
TimeUnit.EPOCH: self.num_batches_per_epoch,
|
TimeUnit.EPOCH: self.num_batches_per_epoch,
|
||||||
TimeUnit.ITERATION: self.gradient_accumulation["number"] * {
|
TimeUnit.ITERATION: self.gradient_accumulation["number"] * iteration_factor,
|
||||||
TimeUnit.STEP: 1,
|
|
||||||
TimeUnit.EPOCH: self.num_batches_per_epoch,
|
|
||||||
}.get(self.gradient_accumulation["unit"], 1),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def convert_time_unit_to_steps(self, number: int, unit: TimeUnit) -> int:
|
def convert_time_unit_to_steps(self, number: int, unit: TimeUnit) -> int:
|
||||||
|
|
|
@ -18,7 +18,8 @@ class DiffusersSDXL(Protocol):
|
||||||
tokenizer_2: fl.Module
|
tokenizer_2: fl.Module
|
||||||
vae: fl.Module
|
vae: fl.Module
|
||||||
|
|
||||||
def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any: ...
|
def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
...
|
||||||
|
|
||||||
def encode_prompt(
|
def encode_prompt(
|
||||||
self,
|
self,
|
||||||
|
@ -26,7 +27,8 @@ class DiffusersSDXL(Protocol):
|
||||||
prompt_2: str | None = None,
|
prompt_2: str | None = None,
|
||||||
negative_prompt: str | None = None,
|
negative_prompt: str | None = None,
|
||||||
negative_prompt_2: str | None = None,
|
negative_prompt_2: str | None = None,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ...
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
@ -67,9 +69,12 @@ def test_double_text_encoder(diffusers_sdxl: DiffusersSDXL, double_text_encoder:
|
||||||
manual_seed(seed=0)
|
manual_seed(seed=0)
|
||||||
prompt = "A photo of a pizza."
|
prompt = "A photo of a pizza."
|
||||||
|
|
||||||
prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = (
|
(
|
||||||
diffusers_sdxl.encode_prompt(prompt=prompt, negative_prompt="")
|
prompt_embeds,
|
||||||
)
|
negative_prompt_embeds,
|
||||||
|
pooled_prompt_embeds,
|
||||||
|
negative_pooled_prompt_embeds,
|
||||||
|
) = diffusers_sdxl.encode_prompt(prompt=prompt, negative_prompt="")
|
||||||
|
|
||||||
double_embedding, pooled_embedding = double_text_encoder(prompt)
|
double_embedding, pooled_embedding = double_text_encoder(prompt)
|
||||||
|
|
||||||
|
|
|
@ -32,16 +32,19 @@ class FacebookSAM(nn.Module):
|
||||||
prompt_encoder: nn.Module
|
prompt_encoder: nn.Module
|
||||||
mask_decoder: nn.Module
|
mask_decoder: nn.Module
|
||||||
|
|
||||||
def __call__(self, batched_input: list[SAMInput], multimask_output: bool) -> list[SAMOutput]: ...
|
def __call__(self, batched_input: list[SAMInput], multimask_output: bool) -> list[SAMOutput]:
|
||||||
|
...
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self) -> Any: ...
|
def device(self) -> Any:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class FacebookSAMPredictor:
|
class FacebookSAMPredictor:
|
||||||
model: FacebookSAM
|
model: FacebookSAM
|
||||||
|
|
||||||
def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None: ...
|
def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None:
|
||||||
|
...
|
||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self,
|
self,
|
||||||
|
@ -51,7 +54,8 @@ class FacebookSAMPredictor:
|
||||||
mask_input: NDArray | None = None,
|
mask_input: NDArray | None = None,
|
||||||
multimask_output: bool = True,
|
multimask_output: bool = True,
|
||||||
return_logits: bool = False,
|
return_logits: bool = False,
|
||||||
) -> tuple[NDArray, NDArray, NDArray]: ...
|
) -> tuple[NDArray, NDArray, NDArray]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
Loading…
Reference in a new issue