🐛 (examples/swissroll) use new method names

This commit is contained in:
Laureηt 2023-08-14 16:30:35 +02:00
parent a6fe13ba5e
commit 1e5ceb5140
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI

View file

@ -64,7 +64,7 @@ anim = @animate for t in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 2
xlims!(-3, 3)
ylims!(-3, 3)
else
noisy_data = Diffusers.Schedulers.add_noise(scheduler, data, noise, [t])
noisy_data = Diffusers.Schedulers.forward(scheduler, data, noise, [t])
scatter(noise[1, :], noise[2, :],
alpha=0.3,
aspectratio=:equal,
@ -130,7 +130,7 @@ for epoch = 1:num_epochs
for data in dataloader
noise = randn(Float32, size(data))
timesteps = rand(1:num_timesteps, size(data, ndims(data)))
noisy_data = Diffusers.Schedulers.add_noise(scheduler, data, noise, timesteps)
noisy_data = Diffusers.Schedulers.forward(scheduler, data, noise, timesteps)
grads = Flux.gradient(model) do m
model_output = m(noisy_data, timesteps)
loss(noise, model_output)
@ -147,7 +147,7 @@ sample_old = sample
predictions = []
for timestep in num_timesteps:-1:1
model_output = model(sample, [timestep])
sample, x0_pred = Diffusers.Schedulers.step(scheduler, sample, model_output, [timestep])
sample, x0_pred = Diffusers.Schedulers.reverse(scheduler, sample, model_output, [timestep])
push!(predictions, (sample, x0_pred, timestep))
end