Pierre Chapuis 2023-09-13 16:27:08 +02:00
parent eb88cde7ac
commit 0e0c39b4b5
5 changed files with 27 additions and 54 deletions

View file

@@ -213,16 +213,13 @@ class Chain(ContextModule):
         return Chain(*self, *other)

     @overload
-    def __getitem__(self, key: int) -> Module:
-        ...
+    def __getitem__(self, key: int) -> Module: ...

     @overload
-    def __getitem__(self, key: str) -> Module:
-        ...
+    def __getitem__(self, key: str) -> Module: ...

     @overload
-    def __getitem__(self, key: slice) -> "Chain":
-        ...
+    def __getitem__(self, key: slice) -> "Chain": ...

     def __getitem__(self, key: int | str | slice) -> Module:
         if isinstance(key, slice):
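These are typing.overload declarations: the decorated signatures exist only for the type checker, which narrows the return type from the key type (an int or str key yields a Module, a slice yields a Chain), while the final undecorated definition is the one that runs. The `...` body is the conventional stub body, and the commit only moves it onto the signature line. A minimal sketch of the same idiom on a toy container (all names here are illustrative, not from the repo):

from typing import overload

class Toy:
    def __init__(self) -> None:
        self.items = ["a", "b", "c"]

    @overload
    def __getitem__(self, key: int) -> str: ...
    @overload
    def __getitem__(self, key: slice) -> list[str]: ...
    def __getitem__(self, key: int | slice) -> str | list[str]:
        # Only this implementation exists at runtime; the overloads
        # above are erased and serve the type checker alone.
        return self.items[key]

toy = Toy()
first = toy[0]   # a type checker infers str here
rest = toy[1:]   # ...and list[str] here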
@@ -270,12 +267,10 @@ class Chain(ContextModule):
     @overload
     def walk(
         self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
-    ) -> Iterator[tuple[Module, "Chain"]]:
-        ...
+    ) -> Iterator[tuple[Module, "Chain"]]: ...

     @overload
-    def walk(self, predicate: type[T], recurse: bool = False) -> Iterator[tuple[T, "Chain"]]:
-        ...
+    def walk(self, predicate: type[T], recurse: bool = False) -> Iterator[tuple[T, "Chain"]]: ...

     def walk(
         self, predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
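The walk overloads follow the same pattern: passing a Module subclass yields pairs typed with that subclass, while passing a predicate callable (or nothing) yields plain (Module, Chain) pairs. A hedged usage sketch, assuming a chain instance and the library's fl.Linear layer type:

# Assumed: `chain` is a Chain instance and fl.Linear is a Module subclass.
for linear, parent in chain.walk(fl.Linear):
    print(linear, parent)  # type[T] overload: `linear` is typed fl.Linear

for module, parent in chain.walk(lambda m, p: "Attention" in m.__class__.__name__):
    print(module, parent)  # predicate overload: `module` is typed Module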

View file

@@ -66,8 +66,7 @@ class LatentDiffusionModel(fl.Module, ABC):
         return self.scheduler.steps

     @abstractmethod
-    def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
-        ...
+    def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None: ...

     def forward(
         self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor
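set_unet_context is abstract, so each concrete latent diffusion model must override it before forward can run a denoising step. A rough sketch of what an override could look like; the unet attribute and context keys below are assumptions for illustration, not the repo's actual names:

class MyLatentDiffusionModel(LatentDiffusionModel):
    def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
        # Hypothetical: stash the conditioning where the UNet reads it.
        self.unet.set_context("diffusion", {"timestep": timestep})
        self.unet.set_context("cross_attention", {"clip_text_embedding": clip_text_embedding})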

View file

@@ -42,56 +42,39 @@ T = TypeVar("T")


 class Callback(Generic[T]):
-    def on_train_begin(self, trainer: T) -> None:
-        ...
+    def on_train_begin(self, trainer: T) -> None: ...

-    def on_train_end(self, trainer: T) -> None:
-        ...
+    def on_train_end(self, trainer: T) -> None: ...

-    def on_epoch_begin(self, trainer: T) -> None:
-        ...
+    def on_epoch_begin(self, trainer: T) -> None: ...

-    def on_epoch_end(self, trainer: T) -> None:
-        ...
+    def on_epoch_end(self, trainer: T) -> None: ...

-    def on_batch_begin(self, trainer: T) -> None:
-        ...
+    def on_batch_begin(self, trainer: T) -> None: ...

-    def on_batch_end(self, trainer: T) -> None:
-        ...
+    def on_batch_end(self, trainer: T) -> None: ...

-    def on_backward_begin(self, trainer: T) -> None:
-        ...
+    def on_backward_begin(self, trainer: T) -> None: ...

-    def on_backward_end(self, trainer: T) -> None:
-        ...
+    def on_backward_end(self, trainer: T) -> None: ...

-    def on_optimizer_step_begin(self, trainer: T) -> None:
-        ...
+    def on_optimizer_step_begin(self, trainer: T) -> None: ...

-    def on_optimizer_step_end(self, trainer: T) -> None:
-        ...
+    def on_optimizer_step_end(self, trainer: T) -> None: ...

-    def on_compute_loss_begin(self, trainer: T) -> None:
-        ...
+    def on_compute_loss_begin(self, trainer: T) -> None: ...

-    def on_compute_loss_end(self, trainer: T) -> None:
-        ...
+    def on_compute_loss_end(self, trainer: T) -> None: ...

-    def on_evaluate_begin(self, trainer: T) -> None:
-        ...
+    def on_evaluate_begin(self, trainer: T) -> None: ...

-    def on_evaluate_end(self, trainer: T) -> None:
-        ...
+    def on_evaluate_end(self, trainer: T) -> None: ...

-    def on_lr_scheduler_step_begin(self, trainer: T) -> None:
-        ...
+    def on_lr_scheduler_step_begin(self, trainer: T) -> None: ...

-    def on_lr_scheduler_step_end(self, trainer: T) -> None:
-        ...
+    def on_lr_scheduler_step_end(self, trainer: T) -> None: ...

-    def on_checkpoint_save(self, trainer: T) -> None:
-        ...
+    def on_checkpoint_save(self, trainer: T) -> None: ...


 class ClockCallback(Callback["Trainer[BaseConfig, Any]"]):
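Every hook on Callback is a no-op by default, so the trainer can invoke each event unconditionally and subclasses override only what they need. A sketch of a user-defined callback (the trainer's loss attribute is an assumption):

class LossPrinter(Callback["Trainer[BaseConfig, Any]"]):
    def on_compute_loss_end(self, trainer: "Trainer[BaseConfig, Any]") -> None:
        # Assumed attribute: a scalar loss tensor on the trainer.
        print(f"loss: {trainer.loss.item():.4f}")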

View file

@@ -8,11 +8,9 @@ T = TypeVar("T", covariant=True)


 class HuggingfaceDataset(Generic[T], Protocol):
-    def __getitem__(self, index: int) -> T:
-        ...
+    def __getitem__(self, index: int) -> T: ...

-    def __len__(self) -> int:
-        ...
+    def __len__(self) -> int: ...


 def load_hf_dataset(
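HuggingfaceDataset is a Protocol, so any object exposing __getitem__ and __len__ with these signatures satisfies it structurally; no inheritance is required. A minimal sketch:

class ToyDataset:
    def __init__(self, rows: list[str]) -> None:
        self.rows = rows

    def __getitem__(self, index: int) -> str:
        return self.rows[index]

    def __len__(self) -> int:
        return len(self.rows)

# Type-checks because ToyDataset matches the protocol's shape.
dataset: HuggingfaceDataset[str] = ToyDataset(["a", "b"])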

View file

@@ -18,8 +18,7 @@ class DiffusersSDXL(Protocol):
     tokenizer_2: fl.Module
     vae: fl.Module

-    def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any:
-        ...
+    def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any: ...

     def encode_prompt(
         self,
@@ -27,8 +26,7 @@ class DiffusersSDXL(Protocol):
         prompt_2: str | None = None,
         negative_prompt: str | None = None,
         negative_prompt_2: str | None = None,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        ...
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ...


 @pytest.fixture(scope="module")
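Across all five files the change is purely cosmetic: a `...` body on its own line and a `...` body on the signature line compile to identical bytecode, so runtime behavior cannot change. A quick self-check:

def f() -> None:
    ...

def g() -> None: ...

# Same compiled body; only the source layout differs.
assert f.__code__.co_code == g.__code__.co_code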