mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-09-16 17:45:29 +00:00
🧪 (Schedulers) add get_velocity tests
This commit is contained in:
parent
b90b50de27
commit
184b9665ec
|
@ -1,8 +1,9 @@
|
|||
import Diffusers: reverse, forward, DDPM, cosine_beta_schedule
|
||||
import Diffusers: reverse, forward, get_velocity, DDPM, cosine_beta_schedule, VELOCITY
|
||||
import Statistics: mean, std
|
||||
using Test
|
||||
|
||||
@testset "Schedulers tests" begin
|
||||
|
||||
@testset "check `reverse` correctness" begin
|
||||
T = 10
|
||||
batch_size = 8
|
||||
|
@ -15,15 +16,36 @@ using Test
|
|||
x₀ = ones(Float32, size, size, batch_size)
|
||||
ϵ = randn(Float32, size, size, batch_size)
|
||||
|
||||
for t in 1:T
|
||||
t = ones(UInt32, batch_size) .* t
|
||||
# corrupt x₀ with noise
|
||||
xₜ = forward(ddpm, x₀, ϵ, t)
|
||||
# suppose a model predicted ϵ perfectly
|
||||
_, x̂₀ = reverse(ddpm, xₜ, ϵ, t)
|
||||
# test that we recover x₀
|
||||
@test x̂₀ ≈ x₀
|
||||
@testset "PredictionType == EPSILON" begin
|
||||
for t in 1:T
|
||||
t = ones(UInt32, batch_size) .* t
|
||||
# get xₜ from forward diffusion process
|
||||
xₜ = forward(ddpm, x₀, ϵ, t)
|
||||
# suppose a model predicted ϵ perfectly
|
||||
ϵᵧ = ϵ
|
||||
# use reverse diffusion process to retreive x̂₀
|
||||
_, x̂₀ = reverse(ddpm, xₜ, ϵᵧ, t)
|
||||
# test that we recover x₀
|
||||
@test x̂₀ ≈ x₀
|
||||
end
|
||||
end
|
||||
|
||||
@testset "PredictionType == VELOCITY" begin
|
||||
for t in 1:T
|
||||
t = ones(UInt32, batch_size) .* t
|
||||
# get xₜ from forward diffusion process
|
||||
xₜ = forward(ddpm, x₀, ϵ, t)
|
||||
# compute vₜ to train model
|
||||
vₜ = get_velocity(ddpm, x₀, ϵ, t)
|
||||
# suppose a model predicted vₜ perfectly
|
||||
vᵧ = vₜ
|
||||
# use reverse diffusion process to retreive x̂₀
|
||||
_, x̂₀ = Diffusers.reverse(ddpm, xₜ, vₜ, t, VELOCITY)
|
||||
# test that we recover x₀
|
||||
@test x̂₀ ≈ x₀
|
||||
end
|
||||
end
|
||||
|
||||
end
|
||||
|
||||
@testset "check `forward` terminal SNR" begin
|
||||
|
@ -45,4 +67,5 @@ using Test
|
|||
@test std(xₜ) ≈ 1.0 atol = 1.0f-3
|
||||
@test mean(xₜ) ≈ 0.0 atol = 1.0f-2
|
||||
end
|
||||
|
||||
end
|
||||
|
|
Loading…
Reference in a new issue