mirror of
https://github.com/Laurent2916/Diffusers.jl.git
synced 2024-09-16 17:45:29 +00:00
🐛 (examples/swissroll) use new method names
This commit is contained in:
parent
a6fe13ba5e
commit
1e5ceb5140
|
@ -64,7 +64,7 @@ anim = @animate for t in cat(fill(0, 20), 1:num_timesteps, fill(num_timesteps, 2
|
|||
xlims!(-3, 3)
|
||||
ylims!(-3, 3)
|
||||
else
|
||||
noisy_data = Diffusers.Schedulers.add_noise(scheduler, data, noise, [t])
|
||||
noisy_data = Diffusers.Schedulers.forward(scheduler, data, noise, [t])
|
||||
scatter(noise[1, :], noise[2, :],
|
||||
alpha=0.3,
|
||||
aspectratio=:equal,
|
||||
|
@ -130,7 +130,7 @@ for epoch = 1:num_epochs
|
|||
for data in dataloader
|
||||
noise = randn(Float32, size(data))
|
||||
timesteps = rand(1:num_timesteps, size(data, ndims(data)))
|
||||
noisy_data = Diffusers.Schedulers.add_noise(scheduler, data, noise, timesteps)
|
||||
noisy_data = Diffusers.Schedulers.forward(scheduler, data, noise, timesteps)
|
||||
grads = Flux.gradient(model) do m
|
||||
model_output = m(noisy_data, timesteps)
|
||||
loss(noise, model_output)
|
||||
|
@ -147,7 +147,7 @@ sample_old = sample
|
|||
predictions = []
|
||||
for timestep in num_timesteps:-1:1
|
||||
model_output = model(sample, [timestep])
|
||||
sample, x0_pred = Diffusers.Schedulers.step(scheduler, sample, model_output, [timestep])
|
||||
sample, x0_pred = Diffusers.Schedulers.reverse(scheduler, sample, model_output, [timestep])
|
||||
push!(predictions, (sample, x0_pred, timestep))
|
||||
end
|
||||
|
||||
|
|
Loading…
Reference in a new issue