diff --git a/src/Schedulers/DDPM.jl b/src/Schedulers/DDPM.jl index 2b3f870..902462e 100644 --- a/src/Schedulers/DDPM.jl +++ b/src/Schedulers/DDPM.jl @@ -179,18 +179,18 @@ function get_variance( if variance_type == FIXED_SMALL # arxiv:2006.11239 Eq. 6 # arxiv:2208.11970 Eq. 70 - σₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ + σ²ₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ elseif variance_type == FIXED_SMALL_LOG - σₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ - σₜ = log.(σₜ) + σ²ₜ = β̅ₜ₋₁ ./ β̅ₜ .* βₜ + σ²ₜ = log.(σ²ₜ) elseif variance_type == FIXED_LARGE - σₜ = βₜ + σ²ₜ = βₜ elseif variance_type == FIXED_LARGE_LOG - σₜ = βₜ - σₜ = log.(σₜ) + σ²ₜ = βₜ + σ²ₜ = log.(σ²ₜ) else throw("unimplemented variance type") end - return σₜ + return σ²ₜ end