This commit is contained in:
Pierre Chapuis 2023-09-13 16:27:08 +02:00
parent eb88cde7ac
commit 0e0c39b4b5
5 changed files with 27 additions and 54 deletions

View file

@ -213,16 +213,13 @@ class Chain(ContextModule):
return Chain(*self, *other) return Chain(*self, *other)
@overload @overload
def __getitem__(self, key: int) -> Module: def __getitem__(self, key: int) -> Module: ...
...
@overload @overload
def __getitem__(self, key: str) -> Module: def __getitem__(self, key: str) -> Module: ...
...
@overload @overload
def __getitem__(self, key: slice) -> "Chain": def __getitem__(self, key: slice) -> "Chain": ...
...
def __getitem__(self, key: int | str | slice) -> Module: def __getitem__(self, key: int | str | slice) -> Module:
if isinstance(key, slice): if isinstance(key, slice):
@ -270,12 +267,10 @@ class Chain(ContextModule):
@overload @overload
def walk( def walk(
self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False self, predicate: Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False
) -> Iterator[tuple[Module, "Chain"]]: ) -> Iterator[tuple[Module, "Chain"]]: ...
...
@overload @overload
def walk(self, predicate: type[T], recurse: bool = False) -> Iterator[tuple[T, "Chain"]]: def walk(self, predicate: type[T], recurse: bool = False) -> Iterator[tuple[T, "Chain"]]: ...
...
def walk( def walk(
self, predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False self, predicate: type[T] | Callable[[Module, "Chain"], bool] | None = None, recurse: bool = False

View file

@ -66,8 +66,7 @@ class LatentDiffusionModel(fl.Module, ABC):
return self.scheduler.steps return self.scheduler.steps
@abstractmethod @abstractmethod
def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None: def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None: ...
...
def forward( def forward(
self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor

View file

@ -42,56 +42,39 @@ T = TypeVar("T")
class Callback(Generic[T]): class Callback(Generic[T]):
def on_train_begin(self, trainer: T) -> None: def on_train_begin(self, trainer: T) -> None: ...
...
def on_train_end(self, trainer: T) -> None: def on_train_end(self, trainer: T) -> None: ...
...
def on_epoch_begin(self, trainer: T) -> None: def on_epoch_begin(self, trainer: T) -> None: ...
...
def on_epoch_end(self, trainer: T) -> None: def on_epoch_end(self, trainer: T) -> None: ...
...
def on_batch_begin(self, trainer: T) -> None: def on_batch_begin(self, trainer: T) -> None: ...
...
def on_batch_end(self, trainer: T) -> None: def on_batch_end(self, trainer: T) -> None: ...
...
def on_backward_begin(self, trainer: T) -> None: def on_backward_begin(self, trainer: T) -> None: ...
...
def on_backward_end(self, trainer: T) -> None: def on_backward_end(self, trainer: T) -> None: ...
...
def on_optimizer_step_begin(self, trainer: T) -> None: def on_optimizer_step_begin(self, trainer: T) -> None: ...
...
def on_optimizer_step_end(self, trainer: T) -> None: def on_optimizer_step_end(self, trainer: T) -> None: ...
...
def on_compute_loss_begin(self, trainer: T) -> None: def on_compute_loss_begin(self, trainer: T) -> None: ...
...
def on_compute_loss_end(self, trainer: T) -> None: def on_compute_loss_end(self, trainer: T) -> None: ...
...
def on_evaluate_begin(self, trainer: T) -> None: def on_evaluate_begin(self, trainer: T) -> None: ...
...
def on_evaluate_end(self, trainer: T) -> None: def on_evaluate_end(self, trainer: T) -> None: ...
...
def on_lr_scheduler_step_begin(self, trainer: T) -> None: def on_lr_scheduler_step_begin(self, trainer: T) -> None: ...
...
def on_lr_scheduler_step_end(self, trainer: T) -> None: def on_lr_scheduler_step_end(self, trainer: T) -> None: ...
...
def on_checkpoint_save(self, trainer: T) -> None: def on_checkpoint_save(self, trainer: T) -> None: ...
...
class ClockCallback(Callback["Trainer[BaseConfig, Any]"]): class ClockCallback(Callback["Trainer[BaseConfig, Any]"]):

View file

@ -8,11 +8,9 @@ T = TypeVar("T", covariant=True)
class HuggingfaceDataset(Generic[T], Protocol): class HuggingfaceDataset(Generic[T], Protocol):
def __getitem__(self, index: int) -> T: def __getitem__(self, index: int) -> T: ...
...
def __len__(self) -> int: def __len__(self) -> int: ...
...
def load_hf_dataset( def load_hf_dataset(

View file

@ -18,8 +18,7 @@ class DiffusersSDXL(Protocol):
tokenizer_2: fl.Module tokenizer_2: fl.Module
vae: fl.Module vae: fl.Module
def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any: def __call__(self, prompt: str, *args: Any, **kwargs: Any) -> Any: ...
...
def encode_prompt( def encode_prompt(
self, self,
@ -27,8 +26,7 @@ class DiffusersSDXL(Protocol):
prompt_2: str | None = None, prompt_2: str | None = None,
negative_prompt: str | None = None, negative_prompt: str | None = None,
negative_prompt_2: str | None = None, negative_prompt_2: str | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ...
...
@pytest.fixture(scope="module") @pytest.fixture(scope="module")