diff --git a/scripts/conversion/convert_informative_drawings.py b/scripts/conversion/convert_informative_drawings.py index d163472..df350b6 100644 --- a/scripts/conversion/convert_informative_drawings.py +++ b/scripts/conversion/convert_informative_drawings.py @@ -1,5 +1,5 @@ import argparse -from typing import TYPE_CHECKING, cast +from typing import cast import torch from torch import nn @@ -15,8 +15,6 @@ except ImportError: "Please download the model.py file from https://github.com/carolineec/informative-drawings and add it to your" " PYTHONPATH" ) -if TYPE_CHECKING: - Generator = cast(nn.Module, Generator) class Args(argparse.Namespace): @@ -27,7 +25,7 @@ class Args(argparse.Namespace): def setup_converter(args: Args) -> ModelConverter: - source = Generator(3, 1, 3) + source = cast(nn.Module, Generator(3, 1, 3)) source.load_state_dict(state_dict=load_tensors(args.source_path)) source.eval() target = InformativeDrawings() diff --git a/src/refiners/foundationals/latent_diffusion/solvers/ddim.py b/src/refiners/foundationals/latent_diffusion/solvers/ddim.py index d4d5a81..6b21980 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/ddim.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/ddim.py @@ -3,9 +3,9 @@ import dataclasses from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, sqrt, tensor from refiners.foundationals.latent_diffusion.solvers.solver import ( + BaseSolverParams, ModelPredictionType, Solver, - SolverParams, TimestepSpacing, ) @@ -26,7 +26,7 @@ class DDIM(Solver): self, num_inference_steps: int, first_inference_step: int = 0, - params: SolverParams | None = None, + params: BaseSolverParams | None = None, device: Device | str = "cpu", dtype: Dtype = float32, ) -> None: diff --git a/src/refiners/foundationals/latent_diffusion/solvers/ddpm.py b/src/refiners/foundationals/latent_diffusion/solvers/ddpm.py index 442a63a..efa50fa 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/ddpm.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/ddpm.py @@ -3,9 +3,9 @@ import dataclasses from torch import Generator, Tensor, device as Device from refiners.foundationals.latent_diffusion.solvers.solver import ( + BaseSolverParams, ModelPredictionType, Solver, - SolverParams, TimestepSpacing, ) @@ -29,7 +29,7 @@ class DDPM(Solver): self, num_inference_steps: int, first_inference_step: int = 0, - params: SolverParams | None = None, + params: BaseSolverParams | None = None, device: Device | str = "cpu", ) -> None: """Initializes a new DDPM solver. diff --git a/src/refiners/foundationals/latent_diffusion/solvers/dpm.py b/src/refiners/foundationals/latent_diffusion/solvers/dpm.py index 15a953e..2074715 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/dpm.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/dpm.py @@ -5,9 +5,9 @@ import numpy as np from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor from refiners.foundationals.latent_diffusion.solvers.solver import ( + BaseSolverParams, ModelPredictionType, Solver, - SolverParams, TimestepSpacing, ) @@ -34,7 +34,7 @@ class DPMSolver(Solver): self, num_inference_steps: int, first_inference_step: int = 0, - params: SolverParams | None = None, + params: BaseSolverParams | None = None, last_step_first_order: bool = False, device: Device | str = "cpu", dtype: Dtype = float32, diff --git a/src/refiners/foundationals/latent_diffusion/solvers/euler.py b/src/refiners/foundationals/latent_diffusion/solvers/euler.py index e88c48c..09f8007 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/euler.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/euler.py @@ -3,10 +3,10 @@ import torch from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor from refiners.foundationals.latent_diffusion.solvers.solver import ( + BaseSolverParams, ModelPredictionType, NoiseSchedule, Solver, - SolverParams, ) @@ -21,7 +21,7 @@ class Euler(Solver): self, num_inference_steps: int, first_inference_step: int = 0, - params: SolverParams | None = None, + params: BaseSolverParams | None = None, device: Device | str = "cpu", dtype: Dtype = float32, ): diff --git a/src/refiners/foundationals/latent_diffusion/solvers/lcm.py b/src/refiners/foundationals/latent_diffusion/solvers/lcm.py index c7087d2..06a45c8 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/lcm.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/lcm.py @@ -4,6 +4,7 @@ import torch from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver from refiners.foundationals.latent_diffusion.solvers.solver import ( + BaseSolverParams, ModelPredictionType, Solver, SolverParams, @@ -29,7 +30,7 @@ class LCMSolver(Solver): self, num_inference_steps: int, first_inference_step: int = 0, - params: SolverParams | None = None, + params: BaseSolverParams | None = None, num_orig_steps: int = 50, device: torch.device | str = "cpu", dtype: torch.dtype = torch.float32, diff --git a/src/refiners/foundationals/latent_diffusion/solvers/solver.py b/src/refiners/foundationals/latent_diffusion/solvers/solver.py index 24c570b..1572091 100644 --- a/src/refiners/foundationals/latent_diffusion/solvers/solver.py +++ b/src/refiners/foundationals/latent_diffusion/solvers/solver.py @@ -71,7 +71,18 @@ class ModelPredictionType(str, Enum): @dataclasses.dataclass(kw_only=True, frozen=True) -class SolverParams: +class BaseSolverParams: + num_train_timesteps: int | None + timesteps_spacing: TimestepSpacing | None + timesteps_offset: int | None + initial_diffusion_rate: float | None + final_diffusion_rate: float | None + noise_schedule: NoiseSchedule | None + model_prediction_type: ModelPredictionType | None + + +@dataclasses.dataclass(kw_only=True, frozen=True) +class SolverParams(BaseSolverParams): """Common parameters for solvers. Args: @@ -94,7 +105,7 @@ class SolverParams: @dataclasses.dataclass(kw_only=True, frozen=True) -class ResolvedSolverParams(SolverParams): +class ResolvedSolverParams(BaseSolverParams): num_train_timesteps: int timesteps_spacing: TimestepSpacing timesteps_offset: int @@ -131,7 +142,7 @@ class Solver(fl.Module, ABC): self, num_inference_steps: int, first_inference_step: int = 0, - params: SolverParams | None = None, + params: BaseSolverParams | None = None, device: Device | str = "cpu", dtype: DType = float32, ) -> None: @@ -158,7 +169,7 @@ class Solver(fl.Module, ABC): self.to(device=device, dtype=dtype) - def resolve_params(self, params: SolverParams | None) -> ResolvedSolverParams: + def resolve_params(self, params: BaseSolverParams | None) -> ResolvedSolverParams: if params is None: return dataclasses.replace(self.default_params) return dataclasses.replace(