mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
fix typing problems
This commit is contained in:
parent
d1fc845bc2
commit
98de9d13d3
|
@ -1,5 +1,5 @@
|
||||||
import argparse
|
import argparse
|
||||||
from typing import TYPE_CHECKING, cast
|
from typing import cast
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
@ -15,8 +15,6 @@ except ImportError:
|
||||||
"Please download the model.py file from https://github.com/carolineec/informative-drawings and add it to your"
|
"Please download the model.py file from https://github.com/carolineec/informative-drawings and add it to your"
|
||||||
" PYTHONPATH"
|
" PYTHONPATH"
|
||||||
)
|
)
|
||||||
if TYPE_CHECKING:
|
|
||||||
Generator = cast(nn.Module, Generator)
|
|
||||||
|
|
||||||
|
|
||||||
class Args(argparse.Namespace):
|
class Args(argparse.Namespace):
|
||||||
|
@ -27,7 +25,7 @@ class Args(argparse.Namespace):
|
||||||
|
|
||||||
|
|
||||||
def setup_converter(args: Args) -> ModelConverter:
|
def setup_converter(args: Args) -> ModelConverter:
|
||||||
source = Generator(3, 1, 3)
|
source = cast(nn.Module, Generator(3, 1, 3))
|
||||||
source.load_state_dict(state_dict=load_tensors(args.source_path))
|
source.load_state_dict(state_dict=load_tensors(args.source_path))
|
||||||
source.eval()
|
source.eval()
|
||||||
target = InformativeDrawings()
|
target = InformativeDrawings()
|
||||||
|
|
|
@ -3,9 +3,9 @@ import dataclasses
|
||||||
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, sqrt, tensor
|
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, sqrt, tensor
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||||
|
BaseSolverParams,
|
||||||
ModelPredictionType,
|
ModelPredictionType,
|
||||||
Solver,
|
Solver,
|
||||||
SolverParams,
|
|
||||||
TimestepSpacing,
|
TimestepSpacing,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -26,7 +26,7 @@ class DDIM(Solver):
|
||||||
self,
|
self,
|
||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
first_inference_step: int = 0,
|
first_inference_step: int = 0,
|
||||||
params: SolverParams | None = None,
|
params: BaseSolverParams | None = None,
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: Dtype = float32,
|
dtype: Dtype = float32,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
|
@ -3,9 +3,9 @@ import dataclasses
|
||||||
from torch import Generator, Tensor, device as Device
|
from torch import Generator, Tensor, device as Device
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||||
|
BaseSolverParams,
|
||||||
ModelPredictionType,
|
ModelPredictionType,
|
||||||
Solver,
|
Solver,
|
||||||
SolverParams,
|
|
||||||
TimestepSpacing,
|
TimestepSpacing,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ class DDPM(Solver):
|
||||||
self,
|
self,
|
||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
first_inference_step: int = 0,
|
first_inference_step: int = 0,
|
||||||
params: SolverParams | None = None,
|
params: BaseSolverParams | None = None,
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initializes a new DDPM solver.
|
"""Initializes a new DDPM solver.
|
||||||
|
|
|
@ -5,9 +5,9 @@ import numpy as np
|
||||||
from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor
|
from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||||
|
BaseSolverParams,
|
||||||
ModelPredictionType,
|
ModelPredictionType,
|
||||||
Solver,
|
Solver,
|
||||||
SolverParams,
|
|
||||||
TimestepSpacing,
|
TimestepSpacing,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ class DPMSolver(Solver):
|
||||||
self,
|
self,
|
||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
first_inference_step: int = 0,
|
first_inference_step: int = 0,
|
||||||
params: SolverParams | None = None,
|
params: BaseSolverParams | None = None,
|
||||||
last_step_first_order: bool = False,
|
last_step_first_order: bool = False,
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: Dtype = float32,
|
dtype: Dtype = float32,
|
||||||
|
|
|
@ -3,10 +3,10 @@ import torch
|
||||||
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
|
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||||
|
BaseSolverParams,
|
||||||
ModelPredictionType,
|
ModelPredictionType,
|
||||||
NoiseSchedule,
|
NoiseSchedule,
|
||||||
Solver,
|
Solver,
|
||||||
SolverParams,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@ class Euler(Solver):
|
||||||
self,
|
self,
|
||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
first_inference_step: int = 0,
|
first_inference_step: int = 0,
|
||||||
params: SolverParams | None = None,
|
params: BaseSolverParams | None = None,
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: Dtype = float32,
|
dtype: Dtype = float32,
|
||||||
):
|
):
|
||||||
|
|
|
@ -4,6 +4,7 @@ import torch
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
|
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
|
||||||
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
from refiners.foundationals.latent_diffusion.solvers.solver import (
|
||||||
|
BaseSolverParams,
|
||||||
ModelPredictionType,
|
ModelPredictionType,
|
||||||
Solver,
|
Solver,
|
||||||
SolverParams,
|
SolverParams,
|
||||||
|
@ -29,7 +30,7 @@ class LCMSolver(Solver):
|
||||||
self,
|
self,
|
||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
first_inference_step: int = 0,
|
first_inference_step: int = 0,
|
||||||
params: SolverParams | None = None,
|
params: BaseSolverParams | None = None,
|
||||||
num_orig_steps: int = 50,
|
num_orig_steps: int = 50,
|
||||||
device: torch.device | str = "cpu",
|
device: torch.device | str = "cpu",
|
||||||
dtype: torch.dtype = torch.float32,
|
dtype: torch.dtype = torch.float32,
|
||||||
|
|
|
@ -71,7 +71,18 @@ class ModelPredictionType(str, Enum):
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(kw_only=True, frozen=True)
|
@dataclasses.dataclass(kw_only=True, frozen=True)
|
||||||
class SolverParams:
|
class BaseSolverParams:
|
||||||
|
num_train_timesteps: int | None
|
||||||
|
timesteps_spacing: TimestepSpacing | None
|
||||||
|
timesteps_offset: int | None
|
||||||
|
initial_diffusion_rate: float | None
|
||||||
|
final_diffusion_rate: float | None
|
||||||
|
noise_schedule: NoiseSchedule | None
|
||||||
|
model_prediction_type: ModelPredictionType | None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(kw_only=True, frozen=True)
|
||||||
|
class SolverParams(BaseSolverParams):
|
||||||
"""Common parameters for solvers.
|
"""Common parameters for solvers.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -94,7 +105,7 @@ class SolverParams:
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(kw_only=True, frozen=True)
|
@dataclasses.dataclass(kw_only=True, frozen=True)
|
||||||
class ResolvedSolverParams(SolverParams):
|
class ResolvedSolverParams(BaseSolverParams):
|
||||||
num_train_timesteps: int
|
num_train_timesteps: int
|
||||||
timesteps_spacing: TimestepSpacing
|
timesteps_spacing: TimestepSpacing
|
||||||
timesteps_offset: int
|
timesteps_offset: int
|
||||||
|
@ -131,7 +142,7 @@ class Solver(fl.Module, ABC):
|
||||||
self,
|
self,
|
||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
first_inference_step: int = 0,
|
first_inference_step: int = 0,
|
||||||
params: SolverParams | None = None,
|
params: BaseSolverParams | None = None,
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: DType = float32,
|
dtype: DType = float32,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -158,7 +169,7 @@ class Solver(fl.Module, ABC):
|
||||||
|
|
||||||
self.to(device=device, dtype=dtype)
|
self.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
def resolve_params(self, params: SolverParams | None) -> ResolvedSolverParams:
|
def resolve_params(self, params: BaseSolverParams | None) -> ResolvedSolverParams:
|
||||||
if params is None:
|
if params is None:
|
||||||
return dataclasses.replace(self.default_params)
|
return dataclasses.replace(self.default_params)
|
||||||
return dataclasses.replace(
|
return dataclasses.replace(
|
||||||
|
|
Loading…
Reference in a new issue