mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
Free U
This commit is contained in:
parent
92e8166c83
commit
770879a6df
92
src/refiners/foundationals/latent_diffusion/freeu.py
Normal file
92
src/refiners/foundationals/latent_diffusion/freeu.py
Normal file
|
@ -0,0 +1,92 @@
|
|||
import math
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
import torch
|
||||
from refiners.fluxion.adapters.adapter import Adapter
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ResidualConcatenator, SD1UNet
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
|
||||
from torch import Tensor
|
||||
from torch.fft import fftn, fftshift, ifftn, ifftshift # type: ignore
|
||||
|
||||
T = TypeVar("T", bound="SD1UNet | SDXLUNet")
|
||||
TSDFreeUAdapter = TypeVar("TSDFreeUAdapter", bound="SDFreeUAdapter[Any]") # Self (see PEP 673)
|
||||
|
||||
|
||||
def fourier_filter(x: Tensor, scale: float = 1, threshold: int = 1) -> Tensor:
|
||||
"""Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).
|
||||
|
||||
This version of the method comes from here:
|
||||
https://github.com/ChenyangSi/FreeU/blob/main/demo/free_lunch_utils.py#L23
|
||||
"""
|
||||
batch, channels, height, width = x.shape
|
||||
dtype = x.dtype
|
||||
device = x.device
|
||||
|
||||
if not (math.log2(height).is_integer() and math.log2(width).is_integer()):
|
||||
x = x.to(dtype=torch.float32)
|
||||
|
||||
x_freq = fftn(x, dim=(-2, -1)) # type: ignore
|
||||
x_freq = fftshift(x_freq, dim=(-2, -1)) # type: ignore
|
||||
mask = torch.ones((batch, channels, height, width), device=device) # type: ignore
|
||||
|
||||
center_row, center_col = height // 2, width // 2 # type: ignore
|
||||
mask[..., center_row - threshold : center_row + threshold, center_col - threshold : center_col + threshold] = scale
|
||||
x_freq = x_freq * mask # type: ignore
|
||||
|
||||
x_freq = ifftshift(x_freq, dim=(-2, -1)) # type: ignore
|
||||
x_filtered = ifftn(x_freq, dim=(-2, -1)).real # type: ignore
|
||||
|
||||
return x_filtered.to(dtype=dtype) # type: ignore
|
||||
|
||||
|
||||
class FreeUBackboneFeatures(fl.Module):
|
||||
def __init__(self, backbone_scale: float) -> None:
|
||||
super().__init__()
|
||||
self.backbone_scale = backbone_scale
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
num_half_channels = x.shape[1] // 2
|
||||
x[:, :num_half_channels] = x[:, :num_half_channels] * self.backbone_scale
|
||||
return x
|
||||
|
||||
|
||||
class FreeUSkipFeatures(fl.Chain):
|
||||
def __init__(self, n: int, skip_scale: float) -> None:
|
||||
super().__init__(
|
||||
fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[n]),
|
||||
fl.Lambda(lambda x: fourier_filter(x, scale=skip_scale)),
|
||||
)
|
||||
|
||||
|
||||
class FreeUResidualConcatenator(fl.Concatenate):
|
||||
def __init__(self, n: int, backbone_scale: float, skip_scale: float) -> None:
|
||||
super().__init__(
|
||||
FreeUBackboneFeatures(backbone_scale),
|
||||
FreeUSkipFeatures(n, skip_scale),
|
||||
dim=1,
|
||||
)
|
||||
|
||||
|
||||
class SDFreeUAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||
def __init__(self, target: T, backbone_scales: list[float], skip_scales: list[float]) -> None:
|
||||
assert len(backbone_scales) == len(skip_scales)
|
||||
assert len(backbone_scales) <= len(target.UpBlocks)
|
||||
self.backbone_scales = backbone_scales
|
||||
self.skip_scales = skip_scales
|
||||
with self.setup_adapter(target):
|
||||
super().__init__(target)
|
||||
|
||||
def inject(self: TSDFreeUAdapter, parent: fl.Chain | None = None) -> TSDFreeUAdapter:
|
||||
for n, (backbone_scale, skip_scale) in enumerate(zip(self.backbone_scales, self.skip_scales)):
|
||||
block = self.target.UpBlocks[n]
|
||||
concat = block.ensure_find(ResidualConcatenator)
|
||||
block.replace(concat, FreeUResidualConcatenator(-n - 2, backbone_scale, skip_scale))
|
||||
return super().inject(parent)
|
||||
|
||||
def eject(self) -> None:
|
||||
for n in range(len(self.backbone_scales)):
|
||||
block = self.target.UpBlocks[n]
|
||||
concat = block.ensure_find(FreeUResidualConcatenator)
|
||||
block.replace(concat, ResidualConcatenator(-n - 2))
|
||||
super().eject()
|
Loading…
Reference in a new issue