fix typing problems

This commit is contained in:
Pierre Chapuis 2024-06-24 10:58:32 +02:00
parent d1fc845bc2
commit 98de9d13d3
7 changed files with 27 additions and 17 deletions

View file

@ -1,5 +1,5 @@
import argparse
from typing import TYPE_CHECKING, cast
from typing import cast
import torch
from torch import nn
@ -15,8 +15,6 @@ except ImportError:
"Please download the model.py file from https://github.com/carolineec/informative-drawings and add it to your"
" PYTHONPATH"
)
if TYPE_CHECKING:
Generator = cast(nn.Module, Generator)
class Args(argparse.Namespace):
@ -27,7 +25,7 @@ class Args(argparse.Namespace):
def setup_converter(args: Args) -> ModelConverter:
source = Generator(3, 1, 3)
source = cast(nn.Module, Generator(3, 1, 3))
source.load_state_dict(state_dict=load_tensors(args.source_path))
source.eval()
target = InformativeDrawings()

View file

@ -3,9 +3,9 @@ import dataclasses
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, sqrt, tensor
from refiners.foundationals.latent_diffusion.solvers.solver import (
BaseSolverParams,
ModelPredictionType,
Solver,
SolverParams,
TimestepSpacing,
)
@ -26,7 +26,7 @@ class DDIM(Solver):
self,
num_inference_steps: int,
first_inference_step: int = 0,
params: SolverParams | None = None,
params: BaseSolverParams | None = None,
device: Device | str = "cpu",
dtype: Dtype = float32,
) -> None:

View file

@ -3,9 +3,9 @@ import dataclasses
from torch import Generator, Tensor, device as Device
from refiners.foundationals.latent_diffusion.solvers.solver import (
BaseSolverParams,
ModelPredictionType,
Solver,
SolverParams,
TimestepSpacing,
)
@ -29,7 +29,7 @@ class DDPM(Solver):
self,
num_inference_steps: int,
first_inference_step: int = 0,
params: SolverParams | None = None,
params: BaseSolverParams | None = None,
device: Device | str = "cpu",
) -> None:
"""Initializes a new DDPM solver.

View file

@ -5,9 +5,9 @@ import numpy as np
from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor
from refiners.foundationals.latent_diffusion.solvers.solver import (
BaseSolverParams,
ModelPredictionType,
Solver,
SolverParams,
TimestepSpacing,
)
@ -34,7 +34,7 @@ class DPMSolver(Solver):
self,
num_inference_steps: int,
first_inference_step: int = 0,
params: SolverParams | None = None,
params: BaseSolverParams | None = None,
last_step_first_order: bool = False,
device: Device | str = "cpu",
dtype: Dtype = float32,

View file

@ -3,10 +3,10 @@ import torch
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
from refiners.foundationals.latent_diffusion.solvers.solver import (
BaseSolverParams,
ModelPredictionType,
NoiseSchedule,
Solver,
SolverParams,
)
@ -21,7 +21,7 @@ class Euler(Solver):
self,
num_inference_steps: int,
first_inference_step: int = 0,
params: SolverParams | None = None,
params: BaseSolverParams | None = None,
device: Device | str = "cpu",
dtype: Dtype = float32,
):

View file

@ -4,6 +4,7 @@ import torch
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
from refiners.foundationals.latent_diffusion.solvers.solver import (
BaseSolverParams,
ModelPredictionType,
Solver,
SolverParams,
@ -29,7 +30,7 @@ class LCMSolver(Solver):
self,
num_inference_steps: int,
first_inference_step: int = 0,
params: SolverParams | None = None,
params: BaseSolverParams | None = None,
num_orig_steps: int = 50,
device: torch.device | str = "cpu",
dtype: torch.dtype = torch.float32,

View file

@ -71,7 +71,18 @@ class ModelPredictionType(str, Enum):
@dataclasses.dataclass(kw_only=True, frozen=True)
class SolverParams:
class BaseSolverParams:
num_train_timesteps: int | None
timesteps_spacing: TimestepSpacing | None
timesteps_offset: int | None
initial_diffusion_rate: float | None
final_diffusion_rate: float | None
noise_schedule: NoiseSchedule | None
model_prediction_type: ModelPredictionType | None
@dataclasses.dataclass(kw_only=True, frozen=True)
class SolverParams(BaseSolverParams):
"""Common parameters for solvers.
Args:
@ -94,7 +105,7 @@ class SolverParams:
@dataclasses.dataclass(kw_only=True, frozen=True)
class ResolvedSolverParams(SolverParams):
class ResolvedSolverParams(BaseSolverParams):
num_train_timesteps: int
timesteps_spacing: TimestepSpacing
timesteps_offset: int
@ -131,7 +142,7 @@ class Solver(fl.Module, ABC):
self,
num_inference_steps: int,
first_inference_step: int = 0,
params: SolverParams | None = None,
params: BaseSolverParams | None = None,
device: Device | str = "cpu",
dtype: DType = float32,
) -> None:
@ -158,7 +169,7 @@ class Solver(fl.Module, ABC):
self.to(device=device, dtype=dtype)
def resolve_params(self, params: SolverParams | None) -> ResolvedSolverParams:
def resolve_params(self, params: BaseSolverParams | None) -> ResolvedSolverParams:
if params is None:
return dataclasses.replace(self.default_params)
return dataclasses.replace(