Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-24 23:28:45 +00:00)
fix typing problems
This commit is contained in:
parent d1fc845bc2
commit 98de9d13d3
@@ -1,5 +1,5 @@
 import argparse
-from typing import TYPE_CHECKING, cast
+from typing import cast

 import torch
 from torch import nn
@@ -15,8 +15,6 @@ except ImportError:
         "Please download the model.py file from https://github.com/carolineec/informative-drawings and add it to your"
         " PYTHONPATH"
     )
-if TYPE_CHECKING:
-    Generator = cast(nn.Module, Generator)


 class Args(argparse.Namespace):
@@ -27,7 +25,7 @@ class Args(argparse.Namespace):


 def setup_converter(args: Args) -> ModelConverter:
-    source = Generator(3, 1, 3)
+    source = cast(nn.Module, Generator(3, 1, 3))
     source.load_state_dict(state_dict=load_tensors(args.source_path))
     source.eval()
     target = InformativeDrawings()
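
The removed TYPE_CHECKING block re-bound the imported Generator name for the type checker only; the new code instead casts the constructed instance at the call site. A minimal sketch of that pattern follows, with a hypothetical stand-in class (its name and constructor arguments are illustrative, not the actual informative-drawings signature):

    from typing import cast

    from torch import nn

    class Generator(nn.Module):  # stand-in for a third-party class that ships no type hints
        def __init__(self, input_nc: int, output_nc: int, n_blocks: int) -> None:
            super().__init__()
            self.conv = nn.Conv2d(input_nc, output_nc, kernel_size=3, padding=1)

    # Casting at the call site gives the checker a typed nn.Module handle,
    # so load_state_dict() and eval() resolve without re-binding the class name.
    source = cast(nn.Module, Generator(3, 1, 3))
    source.eval()
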
@@ -3,9 +3,9 @@ import dataclasses
 from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, sqrt, tensor

 from refiners.foundationals.latent_diffusion.solvers.solver import (
+    BaseSolverParams,
     ModelPredictionType,
     Solver,
-    SolverParams,
     TimestepSpacing,
 )
@@ -26,7 +26,7 @@ class DDIM(Solver):
         self,
         num_inference_steps: int,
         first_inference_step: int = 0,
-        params: SolverParams | None = None,
+        params: BaseSolverParams | None = None,
         device: Device | str = "cpu",
         dtype: Dtype = float32,
     ) -> None:

@@ -3,9 +3,9 @@ import dataclasses
 from torch import Generator, Tensor, device as Device

 from refiners.foundationals.latent_diffusion.solvers.solver import (
+    BaseSolverParams,
     ModelPredictionType,
     Solver,
-    SolverParams,
     TimestepSpacing,
 )
@@ -29,7 +29,7 @@ class DDPM(Solver):
         self,
         num_inference_steps: int,
         first_inference_step: int = 0,
-        params: SolverParams | None = None,
+        params: BaseSolverParams | None = None,
         device: Device | str = "cpu",
     ) -> None:
         """Initializes a new DDPM solver.

@@ -5,9 +5,9 @@ import numpy as np
 from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor

 from refiners.foundationals.latent_diffusion.solvers.solver import (
+    BaseSolverParams,
     ModelPredictionType,
     Solver,
-    SolverParams,
     TimestepSpacing,
 )
@@ -34,7 +34,7 @@ class DPMSolver(Solver):
         self,
         num_inference_steps: int,
         first_inference_step: int = 0,
-        params: SolverParams | None = None,
+        params: BaseSolverParams | None = None,
         last_step_first_order: bool = False,
         device: Device | str = "cpu",
         dtype: Dtype = float32,

@@ -3,10 +3,10 @@ import torch
 from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor

 from refiners.foundationals.latent_diffusion.solvers.solver import (
+    BaseSolverParams,
     ModelPredictionType,
     NoiseSchedule,
     Solver,
-    SolverParams,
 )
@@ -21,7 +21,7 @@ class Euler(Solver):
         self,
         num_inference_steps: int,
         first_inference_step: int = 0,
-        params: SolverParams | None = None,
+        params: BaseSolverParams | None = None,
         device: Device | str = "cpu",
         dtype: Dtype = float32,
     ):

@@ -4,6 +4,7 @@ import torch

 from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
 from refiners.foundationals.latent_diffusion.solvers.solver import (
+    BaseSolverParams,
     ModelPredictionType,
     Solver,
     SolverParams,
@@ -29,7 +30,7 @@ class LCMSolver(Solver):
         self,
         num_inference_steps: int,
         first_inference_step: int = 0,
-        params: SolverParams | None = None,
+        params: BaseSolverParams | None = None,
         num_orig_steps: int = 50,
         device: torch.device | str = "cpu",
         dtype: torch.dtype = torch.float32,

@@ -71,7 +71,18 @@ class ModelPredictionType(str, Enum):


 @dataclasses.dataclass(kw_only=True, frozen=True)
-class SolverParams:
+class BaseSolverParams:
+    num_train_timesteps: int | None
+    timesteps_spacing: TimestepSpacing | None
+    timesteps_offset: int | None
+    initial_diffusion_rate: float | None
+    final_diffusion_rate: float | None
+    noise_schedule: NoiseSchedule | None
+    model_prediction_type: ModelPredictionType | None
+
+
+@dataclasses.dataclass(kw_only=True, frozen=True)
+class SolverParams(BaseSolverParams):
     """Common parameters for solvers.

     Args:
@@ -94,7 +105,7 @@ class SolverParams:


 @dataclasses.dataclass(kw_only=True, frozen=True)
-class ResolvedSolverParams(SolverParams):
+class ResolvedSolverParams(BaseSolverParams):
     num_train_timesteps: int
     timesteps_spacing: TimestepSpacing
     timesteps_offset: int
@@ -131,7 +142,7 @@ class Solver(fl.Module, ABC):
         self,
         num_inference_steps: int,
         first_inference_step: int = 0,
-        params: SolverParams | None = None,
+        params: BaseSolverParams | None = None,
         device: Device | str = "cpu",
         dtype: DType = float32,
     ) -> None:
@@ -158,7 +169,7 @@ class Solver(fl.Module, ABC):

         self.to(device=device, dtype=dtype)

-    def resolve_params(self, params: SolverParams | None) -> ResolvedSolverParams:
+    def resolve_params(self, params: BaseSolverParams | None) -> ResolvedSolverParams:
         if params is None:
             return dataclasses.replace(self.default_params)
         return dataclasses.replace(
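
Taken together, solver.py now exposes three parameter types: BaseSolverParams with every field typed as optional, SolverParams as the user-facing dataclass documented above, and ResolvedSolverParams with every field required. The sketch below illustrates the shape of that hierarchy and of resolve_params with a reduced field set; the None-filtering logic and the default values shown are assumptions for illustration, not necessarily the exact refiners implementation:

    import dataclasses

    @dataclasses.dataclass(kw_only=True, frozen=True)
    class BaseSolverParams:
        # common parent: fields are optional and carry no defaults
        num_train_timesteps: int | None
        timesteps_offset: int | None

    @dataclasses.dataclass(kw_only=True, frozen=True)
    class SolverParams(BaseSolverParams):
        # user-facing: None means "fall back to the solver's default"
        num_train_timesteps: int | None = None
        timesteps_offset: int | None = None

    @dataclasses.dataclass(kw_only=True, frozen=True)
    class ResolvedSolverParams(BaseSolverParams):
        # fully resolved: every field is present and non-None
        num_train_timesteps: int
        timesteps_offset: int

    def resolve_params(params: BaseSolverParams | None, defaults: ResolvedSolverParams) -> ResolvedSolverParams:
        # fill unset fields from the solver defaults via dataclasses.replace
        if params is None:
            return dataclasses.replace(defaults)
        overrides = {k: v for k, v in dataclasses.asdict(params).items() if v is not None}
        return dataclasses.replace(defaults, **overrides)

    defaults = ResolvedSolverParams(num_train_timesteps=1000, timesteps_offset=0)
    print(resolve_params(SolverParams(timesteps_offset=1), defaults))
    # -> ResolvedSolverParams(num_train_timesteps=1000, timesteps_offset=1)

With the solver __init__ signatures widened to BaseSolverParams, either a SolverParams or an already resolved ResolvedSolverParams instance satisfies the annotation.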