mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
ruff 3 formatting (Rye 0.28)
This commit is contained in:
parent
a0be5458b9
commit
be2368cf20
|
@ -4,6 +4,7 @@ Download and convert weights for testing
|
||||||
To see what weights will be downloaded and converted, run:
|
To see what weights will be downloaded and converted, run:
|
||||||
DRY_RUN=1 python scripts/prepare_test_weights.py
|
DRY_RUN=1 python scripts/prepare_test_weights.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
|
@ -129,8 +129,7 @@ class Lora(Generic[T], fl.Chain, ABC):
|
||||||
return loras
|
return loras
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def is_compatible(self, layer: fl.WeightedModule, /) -> bool:
|
def is_compatible(self, layer: fl.WeightedModule, /) -> bool: ...
|
||||||
...
|
|
||||||
|
|
||||||
def auto_attach(
|
def auto_attach(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -256,16 +256,13 @@ class Chain(ContextModule):
|
||||||
self._modules = generate_unique_names(tuple(modules)) # type: ignore
|
self._modules = generate_unique_names(tuple(modules)) # type: ignore
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __getitem__(self, key: int) -> Module:
|
def __getitem__(self, key: int) -> Module: ...
|
||||||
...
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __getitem__(self, key: str) -> Module:
|
def __getitem__(self, key: str) -> Module: ...
|
||||||
...
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __getitem__(self, key: slice) -> "Chain":
|
def __getitem__(self, key: slice) -> "Chain": ...
|
||||||
...
|
|
||||||
|
|
||||||
def __getitem__(self, key: int | str | slice) -> Module:
|
def __getitem__(self, key: int | str | slice) -> Module:
|
||||||
if isinstance(key, slice):
|
if isinstance(key, slice):
|
||||||
|
@ -324,16 +321,14 @@ class Chain(ContextModule):
|
||||||
self,
|
self,
|
||||||
predicate: Callable[[Module, "Chain"], bool] | None = None,
|
predicate: Callable[[Module, "Chain"], bool] | None = None,
|
||||||
recurse: bool = False,
|
recurse: bool = False,
|
||||||
) -> Iterator[tuple[Module, "Chain"]]:
|
) -> Iterator[tuple[Module, "Chain"]]: ...
|
||||||
...
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def walk(
|
def walk(
|
||||||
self,
|
self,
|
||||||
predicate: type[T],
|
predicate: type[T],
|
||||||
recurse: bool = False,
|
recurse: bool = False,
|
||||||
) -> Iterator[tuple[T, "Chain"]]:
|
) -> Iterator[tuple[T, "Chain"]]: ...
|
||||||
...
|
|
||||||
|
|
||||||
def walk(
|
def walk(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -440,18 +440,15 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||||
self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})
|
self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def compute_clip_image_embedding(self, image_prompt: Tensor, weights: list[float] | None = None) -> Tensor:
|
def compute_clip_image_embedding(self, image_prompt: Tensor, weights: list[float] | None = None) -> Tensor: ...
|
||||||
...
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def compute_clip_image_embedding(self, image_prompt: Image.Image) -> Tensor:
|
def compute_clip_image_embedding(self, image_prompt: Image.Image) -> Tensor: ...
|
||||||
...
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def compute_clip_image_embedding(
|
def compute_clip_image_embedding(
|
||||||
self, image_prompt: list[Image.Image], weights: list[float] | None = None
|
self, image_prompt: list[Image.Image], weights: list[float] | None = None
|
||||||
) -> Tensor:
|
) -> Tensor: ...
|
||||||
...
|
|
||||||
|
|
||||||
def compute_clip_image_embedding(
|
def compute_clip_image_embedding(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -66,22 +66,18 @@ class LatentDiffusionModel(fl.Module, ABC):
|
||||||
return self.solver.inference_steps
|
return self.solver.inference_steps
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
|
def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:
|
def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def has_self_attention_guidance(self) -> bool:
|
def has_self_attention_guidance(self) -> bool: ...
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def compute_self_attention_guidance(
|
def compute_self_attention_guidance(
|
||||||
self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor
|
self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor
|
||||||
) -> Tensor:
|
) -> Tensor: ...
|
||||||
...
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor
|
self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor
|
||||||
|
|
|
@ -68,8 +68,7 @@ class MultiDiffusion(Generic[T, D], ABC):
|
||||||
return torch.where(condition=num_updates > 0, input=cumulative_values / num_updates, other=x)
|
return torch.where(condition=num_updates > 0, input=cumulative_values / num_updates, other=x)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def diffuse_target(self, x: Tensor, step: int, target: D) -> Tensor:
|
def diffuse_target(self, x: Tensor, step: int, target: D) -> Tensor: ...
|
||||||
...
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def steps(self) -> list[int]:
|
def steps(self) -> list[int]:
|
||||||
|
|
|
@ -20,56 +20,38 @@ class CallbackConfig(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class Callback(Generic[T]):
|
class Callback(Generic[T]):
|
||||||
def on_init_begin(self, trainer: T) -> None:
|
def on_init_begin(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_init_end(self, trainer: T) -> None:
|
def on_init_end(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_train_begin(self, trainer: T) -> None:
|
def on_train_begin(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_train_end(self, trainer: T) -> None:
|
def on_train_end(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_epoch_begin(self, trainer: T) -> None:
|
def on_epoch_begin(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_epoch_end(self, trainer: T) -> None:
|
def on_epoch_end(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_batch_begin(self, trainer: T) -> None:
|
def on_batch_begin(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_batch_end(self, trainer: T) -> None:
|
def on_batch_end(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_backward_begin(self, trainer: T) -> None:
|
def on_backward_begin(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_backward_end(self, trainer: T) -> None:
|
def on_backward_end(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_optimizer_step_begin(self, trainer: T) -> None:
|
def on_optimizer_step_begin(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_optimizer_step_end(self, trainer: T) -> None:
|
def on_optimizer_step_end(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_compute_loss_begin(self, trainer: T) -> None:
|
def on_compute_loss_begin(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_compute_loss_end(self, trainer: T) -> None:
|
def on_compute_loss_end(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_evaluate_begin(self, trainer: T) -> None:
|
def on_evaluate_begin(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_evaluate_end(self, trainer: T) -> None:
|
def on_evaluate_end(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_lr_scheduler_step_begin(self, trainer: T) -> None:
|
def on_lr_scheduler_step_begin(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_lr_scheduler_step_end(self, trainer: T) -> None:
|
def on_lr_scheduler_step_end(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
|
@ -10,11 +10,9 @@ T = TypeVar("T", covariant=True)
|
||||||
|
|
||||||
|
|
||||||
class HuggingfaceDataset(Generic[T], Protocol):
|
class HuggingfaceDataset(Generic[T], Protocol):
|
||||||
def __getitem__(self, index: int) -> T:
|
def __getitem__(self, index: int) -> T: ...
|
||||||
...
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
def load_hf_dataset(
|
def load_hf_dataset(
|
||||||
|
|
|
@ -335,8 +335,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
||||||
)
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def compute_loss(self, batch: Batch) -> Tensor:
|
def compute_loss(self, batch: Batch) -> Tensor: ...
|
||||||
...
|
|
||||||
|
|
||||||
def compute_evaluation(self) -> None:
|
def compute_evaluation(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -10,13 +10,11 @@ from refiners.foundationals.latent_diffusion.image_prompt import ImageCrossAtten
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def new_adapter(target: SD1UNet) -> SD1IPAdapter:
|
def new_adapter(target: SD1UNet) -> SD1IPAdapter: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def new_adapter(target: SDXLUNet) -> SDXLIPAdapter:
|
def new_adapter(target: SDXLUNet) -> SDXLIPAdapter: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
def new_adapter(target: SD1UNet | SDXLUNet) -> SD1IPAdapter | SDXLIPAdapter:
|
def new_adapter(target: SD1UNet | SDXLUNet) -> SD1IPAdapter | SDXLIPAdapter:
|
||||||
|
|
|
@ -9,13 +9,11 @@ from refiners.foundationals.latent_diffusion.t2i_adapter import T2IFeatures
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def new_adapter(target: SD1UNet, name: str) -> SD1T2IAdapter:
|
def new_adapter(target: SD1UNet, name: str) -> SD1T2IAdapter: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def new_adapter(target: SDXLUNet, name: str) -> SDXLT2IAdapter:
|
def new_adapter(target: SDXLUNet, name: str) -> SDXLT2IAdapter: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
def new_adapter(target: SD1UNet | SDXLUNet, name: str) -> SD1T2IAdapter | SDXLT2IAdapter:
|
def new_adapter(target: SD1UNet | SDXLUNet, name: str) -> SD1T2IAdapter | SDXLT2IAdapter:
|
||||||
|
|
|
@ -19,8 +19,7 @@ class DiffusersSDXL(Protocol):
|
||||||
tokenizer_2: fl.Module
|
tokenizer_2: fl.Module
|
||||||
vae: fl.Module
|
vae: fl.Module
|
||||||
|
|
||||||
def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any:
|
def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any: ...
|
||||||
...
|
|
||||||
|
|
||||||
def encode_prompt(
|
def encode_prompt(
|
||||||
self,
|
self,
|
||||||
|
@ -28,8 +27,7 @@ class DiffusersSDXL(Protocol):
|
||||||
prompt_2: str | None = None,
|
prompt_2: str | None = None,
|
||||||
negative_prompt: str | None = None,
|
negative_prompt: str | None = None,
|
||||||
negative_prompt_2: str | None = None,
|
negative_prompt_2: str | None = None,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
|
|
@ -32,19 +32,16 @@ class FacebookSAM(nn.Module):
|
||||||
prompt_encoder: nn.Module
|
prompt_encoder: nn.Module
|
||||||
mask_decoder: nn.Module
|
mask_decoder: nn.Module
|
||||||
|
|
||||||
def __call__(self, batched_input: list[SAMInput], multimask_output: bool) -> list[SAMOutput]:
|
def __call__(self, batched_input: list[SAMInput], multimask_output: bool) -> list[SAMOutput]: ...
|
||||||
...
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self) -> Any:
|
def device(self) -> Any: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class FacebookSAMPredictor:
|
class FacebookSAMPredictor:
|
||||||
model: FacebookSAM
|
model: FacebookSAM
|
||||||
|
|
||||||
def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None:
|
def set_image(self, image: NDArrayUInt8, image_format: str = "RGB") -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self,
|
self,
|
||||||
|
@ -54,8 +51,7 @@ class FacebookSAMPredictor:
|
||||||
mask_input: NDArray | None = None,
|
mask_input: NDArray | None = None,
|
||||||
multimask_output: bool = True,
|
multimask_output: bool = True,
|
||||||
return_logits: bool = False,
|
return_logits: bool = False,
|
||||||
) -> tuple[NDArray, NDArray, NDArray]:
|
) -> tuple[NDArray, NDArray, NDArray]: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
Loading…
Reference in a new issue