fix typing problems

Pierre Chapuis 2024-06-24 10:58:32 +02:00
parent d1fc845bc2
commit 98de9d13d3
7 changed files with 27 additions and 17 deletions

View file

@@ -1,5 +1,5 @@
 import argparse
-from typing import TYPE_CHECKING, cast
+from typing import cast
 import torch
 from torch import nn
@@ -15,8 +15,6 @@ except ImportError:
         "Please download the model.py file from https://github.com/carolineec/informative-drawings and add it to your"
         " PYTHONPATH"
     )
-if TYPE_CHECKING:
-    Generator = cast(nn.Module, Generator)


 class Args(argparse.Namespace):
@@ -27,7 +25,7 @@ class Args(argparse.Namespace):


 def setup_converter(args: Args) -> ModelConverter:
-    source = Generator(3, 1, 3)
+    source = cast(nn.Module, Generator(3, 1, 3))
     source.load_state_dict(state_dict=load_tensors(args.source_path))
     source.eval()
     target = InformativeDrawings()
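
The change in this first file replaces a TYPE_CHECKING-only re-binding of the imported class name with a cast of the constructed instance. A minimal sketch of the pattern, using a hypothetical stand-in for the untyped third-party Generator (the real class is imported from informative-drawings at runtime):

from typing import cast

from torch import nn


class Generator:
    # Hypothetical stand-in for an untyped third-party class; the constructor
    # signature here is illustrative only.
    def __init__(self, *args: int) -> None: ...


# Casting the instance at the call site tells the type checker that `source`
# is an nn.Module, so later calls such as load_state_dict() and eval()
# type-check, without re-binding the class name under `if TYPE_CHECKING:`.
source = cast(nn.Module, Generator(3, 1, 3))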

View file

@@ -3,9 +3,9 @@ import dataclasses
 from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, sqrt, tensor

 from refiners.foundationals.latent_diffusion.solvers.solver import (
+    BaseSolverParams,
     ModelPredictionType,
     Solver,
-    SolverParams,
     TimestepSpacing,
 )
@@ -26,7 +26,7 @@ class DDIM(Solver):
         self,
         num_inference_steps: int,
         first_inference_step: int = 0,
-        params: SolverParams | None = None,
+        params: BaseSolverParams | None = None,
         device: Device | str = "cpu",
         dtype: Dtype = float32,
     ) -> None:

View file

@@ -3,9 +3,9 @@ import dataclasses
 from torch import Generator, Tensor, device as Device

 from refiners.foundationals.latent_diffusion.solvers.solver import (
+    BaseSolverParams,
     ModelPredictionType,
     Solver,
-    SolverParams,
     TimestepSpacing,
 )
@@ -29,7 +29,7 @@ class DDPM(Solver):
         self,
         num_inference_steps: int,
         first_inference_step: int = 0,
-        params: SolverParams | None = None,
+        params: BaseSolverParams | None = None,
         device: Device | str = "cpu",
     ) -> None:
         """Initializes a new DDPM solver.

View file

@@ -5,9 +5,9 @@ import numpy as np
 from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor

 from refiners.foundationals.latent_diffusion.solvers.solver import (
+    BaseSolverParams,
     ModelPredictionType,
     Solver,
-    SolverParams,
     TimestepSpacing,
 )
@@ -34,7 +34,7 @@ class DPMSolver(Solver):
         self,
         num_inference_steps: int,
         first_inference_step: int = 0,
-        params: SolverParams | None = None,
+        params: BaseSolverParams | None = None,
         last_step_first_order: bool = False,
         device: Device | str = "cpu",
         dtype: Dtype = float32,

View file

@@ -3,10 +3,10 @@ import torch
 from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor

 from refiners.foundationals.latent_diffusion.solvers.solver import (
+    BaseSolverParams,
     ModelPredictionType,
     NoiseSchedule,
     Solver,
-    SolverParams,
 )
@@ -21,7 +21,7 @@ class Euler(Solver):
         self,
         num_inference_steps: int,
         first_inference_step: int = 0,
-        params: SolverParams | None = None,
+        params: BaseSolverParams | None = None,
         device: Device | str = "cpu",
         dtype: Dtype = float32,
     ):

View file

@@ -4,6 +4,7 @@ import torch
 from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
 from refiners.foundationals.latent_diffusion.solvers.solver import (
+    BaseSolverParams,
     ModelPredictionType,
     Solver,
     SolverParams,
@@ -29,7 +30,7 @@ class LCMSolver(Solver):
         self,
         num_inference_steps: int,
         first_inference_step: int = 0,
-        params: SolverParams | None = None,
+        params: BaseSolverParams | None = None,
         num_orig_steps: int = 50,
         device: torch.device | str = "cpu",
         dtype: torch.dtype = torch.float32,
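
In the five solver files above (DDIM, DDPM, DPMSolver, Euler, LCMSolver), only the annotation of the `params` argument changes, from `SolverParams | None` to the new common base `BaseSolverParams | None`; runtime behaviour is untouched. A hedged usage sketch, assuming the solvers package re-exports these names and that `SolverParams` fields default to None (argument values are illustrative, not from the diff):

from refiners.foundationals.latent_diffusion.solvers import DDIM, SolverParams

# A partially specified SolverParams is accepted by the widened annotation;
# fields left unset are filled in from the solver's defaults at construction.
solver = DDIM(num_inference_steps=30, params=SolverParams(num_train_timesteps=1000))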

View file

@@ -71,7 +71,18 @@ class ModelPredictionType(str, Enum):
 @dataclasses.dataclass(kw_only=True, frozen=True)
-class SolverParams:
+class BaseSolverParams:
+    num_train_timesteps: int | None
+    timesteps_spacing: TimestepSpacing | None
+    timesteps_offset: int | None
+    initial_diffusion_rate: float | None
+    final_diffusion_rate: float | None
+    noise_schedule: NoiseSchedule | None
+    model_prediction_type: ModelPredictionType | None
+
+
+@dataclasses.dataclass(kw_only=True, frozen=True)
+class SolverParams(BaseSolverParams):
     """Common parameters for solvers.

     Args:
@@ -94,7 +105,7 @@ class SolverParams:
 @dataclasses.dataclass(kw_only=True, frozen=True)
-class ResolvedSolverParams(SolverParams):
+class ResolvedSolverParams(BaseSolverParams):
     num_train_timesteps: int
     timesteps_spacing: TimestepSpacing
     timesteps_offset: int
@@ -131,7 +142,7 @@ class Solver(fl.Module, ABC):
         self,
         num_inference_steps: int,
         first_inference_step: int = 0,
-        params: SolverParams | None = None,
+        params: BaseSolverParams | None = None,
         device: Device | str = "cpu",
         dtype: DType = float32,
     ) -> None:
@@ -158,7 +169,7 @@ class Solver(fl.Module, ABC):
         self.to(device=device, dtype=dtype)

-    def resolve_params(self, params: SolverParams | None) -> ResolvedSolverParams:
+    def resolve_params(self, params: BaseSolverParams | None) -> ResolvedSolverParams:
         if params is None:
             return dataclasses.replace(self.default_params)
         return dataclasses.replace(
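
This last file carries the substance of the typing fix: previously `ResolvedSolverParams` subclassed `SolverParams` and narrowed its optional fields to required ones, which strict type checkers flag as an incompatible override; making both classes siblings of a shared `BaseSolverParams` removes that conflict. A minimal self-contained sketch of the pattern with a single field; the standalone `resolve` helper is hypothetical, since only the `params is None` shortcut of `resolve_params` is visible in the diff:

import dataclasses


@dataclasses.dataclass(kw_only=True, frozen=True)
class BaseSolverParams:
    # Shared declarations: every field may be None at this level.
    num_train_timesteps: int | None


@dataclasses.dataclass(kw_only=True, frozen=True)
class SolverParams(BaseSolverParams):
    # User-facing variant: unset fields default to None.
    num_train_timesteps: int | None = None


@dataclasses.dataclass(kw_only=True, frozen=True)
class ResolvedSolverParams(BaseSolverParams):
    # Fully resolved variant: fields are required. Declaring this as a sibling
    # of SolverParams (not a subclass) avoids narrowing `int | None` to `int`
    # in an override, which is what type checkers complain about.
    num_train_timesteps: int


def resolve(defaults: ResolvedSolverParams, params: BaseSolverParams | None) -> ResolvedSolverParams:
    # Hypothetical reconstruction of the resolve_params logic: start from the
    # defaults and overlay only the fields the caller actually set.
    if params is None:
        return dataclasses.replace(defaults)
    overrides = {k: v for k, v in dataclasses.asdict(params).items() if v is not None}
    return dataclasses.replace(defaults, **overrides)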