mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
black
This commit is contained in:
parent
eb88cde7ac
commit
0e0c39b4b5
|
@ -213,16 +213,13 @@ class Chain(ContextModule):
|
||||||
return Chain(*self, *other)
|
return Chain(*self, *other)
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __getitem__(self, key: int) -> Module:
|
def __getitem__(self, key: int) -> Module: ...
|
||||||
...
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __getitem__(self, key: str) -> Module:
|
def __getitem__(self, key: str) -> Module: ...
|
||||||
...
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def __getitem__(self, key: slice) -> "Chain":
|
def __getitem__(self, key: slice) -> "Chain": ...
|
||||||
...
|
|
||||||
|
|
||||||
def __getitem__(self, key: int | str | slice) -> Module:
|
def __getitem__(self, key: int | str | slice) -> Module:
|
||||||
if isinstance(key, slice):
|
if isinstance(key, slice):
|
||||||
|
@ -270,12 +267,10 @@ class Chain(ContextModule):
|
||||||
@overload
|
@overload
|
||||||
def walk(
|
def walk(
|
||||||
self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
|
self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
|
||||||
) -> Iterator[tuple[Module, "Chain"]]:
|
) -> Iterator[tuple[Module, "Chain"]]: ...
|
||||||
...
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def walk(self, predicate: type[T], recurse: bool = False) -> Iterator[tuple[T, "Chain"]]:
|
def walk(self, predicate: type[T], recurse: bool = False) -> Iterator[tuple[T, "Chain"]]: ...
|
||||||
...
|
|
||||||
|
|
||||||
def walk(
|
def walk(
|
||||||
self, predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
|
self, predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
|
||||||
|
|
|
@ -66,8 +66,7 @@ class LatentDiffusionModel(fl.Module, ABC):
|
||||||
return self.scheduler.steps
|
return self.scheduler.steps
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
|
def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor
|
self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor
|
||||||
|
|
|
@ -42,56 +42,39 @@ T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class Callback(Generic[T]):
|
class Callback(Generic[T]):
|
||||||
def on_train_begin(self, trainer: T) -> None:
|
def on_train_begin(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_train_end(self, trainer: T) -> None:
|
def on_train_end(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_epoch_begin(self, trainer: T) -> None:
|
def on_epoch_begin(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_epoch_end(self, trainer: T) -> None:
|
def on_epoch_end(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_batch_begin(self, trainer: T) -> None:
|
def on_batch_begin(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_batch_end(self, trainer: T) -> None:
|
def on_batch_end(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_backward_begin(self, trainer: T) -> None:
|
def on_backward_begin(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_backward_end(self, trainer: T) -> None:
|
def on_backward_end(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_optimizer_step_begin(self, trainer: T) -> None:
|
def on_optimizer_step_begin(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_optimizer_step_end(self, trainer: T) -> None:
|
def on_optimizer_step_end(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_compute_loss_begin(self, trainer: T) -> None:
|
def on_compute_loss_begin(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_compute_loss_end(self, trainer: T) -> None:
|
def on_compute_loss_end(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_evaluate_begin(self, trainer: T) -> None:
|
def on_evaluate_begin(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_evaluate_end(self, trainer: T) -> None:
|
def on_evaluate_end(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_lr_scheduler_step_begin(self, trainer: T) -> None:
|
def on_lr_scheduler_step_begin(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_lr_scheduler_step_end(self, trainer: T) -> None:
|
def on_lr_scheduler_step_end(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
def on_checkpoint_save(self, trainer: T) -> None:
|
def on_checkpoint_save(self, trainer: T) -> None: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class ClockCallback(Callback["Trainer[BaseConfig, Any]"]):
|
class ClockCallback(Callback["Trainer[BaseConfig, Any]"]):
|
||||||
|
|
|
@ -8,11 +8,9 @@ T = TypeVar("T", covariant=True)
|
||||||
|
|
||||||
|
|
||||||
class HuggingfaceDataset(Generic[T], Protocol):
|
class HuggingfaceDataset(Generic[T], Protocol):
|
||||||
def __getitem__(self, index: int) -> T:
|
def __getitem__(self, index: int) -> T: ...
|
||||||
...
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
def load_hf_dataset(
|
def load_hf_dataset(
|
||||||
|
|
|
@ -18,8 +18,7 @@ class DiffusersSDXL(Protocol):
|
||||||
tokenizer_2: fl.Module
|
tokenizer_2: fl.Module
|
||||||
vae: fl.Module
|
vae: fl.Module
|
||||||
|
|
||||||
def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any:
|
def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any: ...
|
||||||
...
|
|
||||||
|
|
||||||
def encode_prompt(
|
def encode_prompt(
|
||||||
self,
|
self,
|
||||||
|
@ -27,8 +26,7 @@ class DiffusersSDXL(Protocol):
|
||||||
prompt_2: str | None = None,
|
prompt_2: str | None = None,
|
||||||
negative_prompt: str | None = None,
|
negative_prompt: str | None = None,
|
||||||
negative_prompt_2: str | None = None,
|
negative_prompt_2: str | None = None,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
|
Loading…
Reference in a new issue