From 1e5ceb5140ac487a619ea293306495fa357783c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Laure=CE=B7t?= Date: Mon, 14 Aug 2023 16:30:35 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20(examples/swissroll)=20use=20new?= =?UTF-8?q?=20method=20names?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/swissroll.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/swissroll.jl b/examples/swissroll.jl index 9c1f28d..4fe775e 100644 --- a/examples/swissroll.jl +++ b/examples/swissroll.jl @@ -64,7 +64,7 @@ anim = @animate for t in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 2 xlims!(-3, 3) ylims!(-3, 3) else - noisy_data = Diffusers.Schedulers.add_noise(scheduler, data, noise, [t]) + noisy_data = Diffusers.Schedulers.forward(scheduler, data, noise, [t]) scatter(noise[1, :], noise[2, :], alpha=0.3, aspectratio=:equal, @@ -130,7 +130,7 @@ for epoch = 1:num_epochs for data in dataloader noise = randn(Float32, size(data)) timesteps = rand(1:num_timesteps, size(data, ndims(data))) - noisy_data = Diffusers.Schedulers.add_noise(scheduler, data, noise, timesteps) + noisy_data = Diffusers.Schedulers.forward(scheduler, data, noise, timesteps) grads = Flux.gradient(model) do m model_output = m(noisy_data, timesteps) loss(noise, model_output) @@ -147,7 +147,7 @@ sample_old = sample predictions = [] for timestep in num_timesteps:-1:1 model_output = model(sample, [timestep]) - sample, x0_pred = Diffusers.Schedulers.step(scheduler, sample, model_output, [timestep]) + sample, x0_pred = Diffusers.Schedulers.reverse(scheduler, sample, model_output, [timestep]) push!(predictions, (sample, x0_pred, timestep)) end