mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-11-09 15:02:02 +00:00
🧪 (Schedulers) add some tests to check step
and add_noise
correctness
This commit is contained in:
parent
77961913ba
commit
8b098c4c09
|
@ -13,7 +13,8 @@ julia = "1"
|
|||
[extras]
|
||||
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
|
||||
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
|
||||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
|
||||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||
|
||||
[targets]
|
||||
test = ["Test", "Aqua"]
|
||||
test = ["Test", "Aqua", "Statistics"]
|
||||
|
|
58
test/Schedulers.jl
Normal file
58
test/Schedulers.jl
Normal file
|
@ -0,0 +1,58 @@
|
|||
import Diffusers: step, add_noise, DDPM, cosine_beta_schedule, rescale_zero_terminal_snr
|
||||
using Statistics
|
||||
using Test
|
||||
|
||||
@testset "Schedulers tests" begin
|
||||
@testset "check `step` correctness" begin
|
||||
T = 10
|
||||
batch_size = 8
|
||||
size = 128
|
||||
|
||||
# create a DDPM with a cosine beta schedule
|
||||
ddpm = Diffusers.DDPM(
|
||||
Vector{Float32},
|
||||
Diffusers.cosine_beta_schedule(T),
|
||||
)
|
||||
|
||||
# create some dummy data
|
||||
x₀ = ones(Float32, size, size, batch_size)
|
||||
ϵ = randn(Float32, size, size, batch_size)
|
||||
|
||||
for t in 1:T
|
||||
t = ones(UInt32, batch_size) .* t
|
||||
# corrupt x₀ with noise
|
||||
xₜ = Diffusers.add_noise(ddpm, x₀, ϵ, t)
|
||||
# suppose a model predicted ϵ perfectly
|
||||
_, x̂₀ = Diffusers.step(ddpm, xₜ, ϵ, t)
|
||||
# test that we recover x₀
|
||||
@test x̂₀ ≈ x₀
|
||||
end
|
||||
end
|
||||
|
||||
@testset "check `add_noise` terminal SNR" begin
|
||||
T = 10
|
||||
batch_size = 1
|
||||
size = 1000
|
||||
|
||||
# create a DDPM with a terminal SNR cosine beta schedule
|
||||
ddpm = Diffusers.DDPM(
|
||||
Vector{Float32},
|
||||
Diffusers.rescale_zero_terminal_snr(
|
||||
Diffusers.cosine_beta_schedule(T),
|
||||
),
|
||||
)
|
||||
|
||||
# create some dummy data
|
||||
x₀ = ones(Float32, size, size, batch_size)
|
||||
ϵ = randn(Float32, size, size, batch_size)
|
||||
|
||||
t = ones(UInt32, batch_size) .* T
|
||||
# corrupt x₀ with noise
|
||||
xₜ = Diffusers.add_noise(ddpm, x₀, ϵ, t)
|
||||
|
||||
@test std(xₜ) ≈ 1.0 atol=1f-3
|
||||
@test mean(xₜ) ≈ 0.0 atol=1f-3
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -3,4 +3,5 @@ using Aqua
|
|||
|
||||
Aqua.test_all(Diffusers)
|
||||
|
||||
include("Schedulers.jl")
|
||||
include("BetaSchedules.jl")
|
||||
|
|
Loading…
Reference in a new issue