diff --git a/examples/swissroll.jl b/examples/swissroll.jl index 4fe775e..1fdfd43 100644 --- a/examples/swissroll.jl +++ b/examples/swissroll.jl @@ -1,7 +1,7 @@ import Diffusers import Diffusers.Schedulers import Diffusers.Schedulers: DDPM -import Diffusers.BetaSchedules: linear_beta_schedule +import Diffusers.BetaSchedules: cosine_beta_schedule using Flux using Random using Plots @@ -38,7 +38,7 @@ scatter(dataset[1, :], dataset[2, :], num_timesteps = 100 scheduler = DDPM( - linear_beta_schedule(num_timesteps) + cosine_beta_schedule(num_timesteps) ); data = dataset