very basic working example of forward diffusion

This commit is contained in:
Laureηt 2023-07-03 21:12:22 +02:00
parent e7756d4304
commit 03514a68d0
Signed by: Laurent
SSH key fingerprint: SHA256:kZEpW8cMJ54PDeCvOhzreNr4FSh6R13CMGH/POoO8DI
8 changed files with 274 additions and 61 deletions

View file

@ -1,4 +1,5 @@
{
"julia.useRevise": true,
"julia.persistentSession.enabled": true,
"editor.unicodeHighlight.ambiguousCharacters": false
}

View file

@ -2,7 +2,7 @@
julia_version = "1.9.1"
manifest_format = "2.0"
project_hash = "4277c3d2d0e9ba6c27261646589d4ccf517491c4"
project_hash = "d24809eea0828bfe960e6b7544e6415082911917"
[[deps.AbstractFFTs]]
deps = ["LinearAlgebra"]
@ -14,6 +14,24 @@ weakdeps = ["ChainRulesCore"]
[deps.AbstractFFTs.extensions]
AbstractFFTsChainRulesCoreExt = "ChainRulesCore"
[[deps.Accessors]]
deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Requires", "Test"]
git-tree-sha1 = "954634616d5846d8e216df1298be2298d55280b2"
uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
version = "0.1.32"
[deps.Accessors.extensions]
AccessorsAxisKeysExt = "AxisKeys"
AccessorsIntervalSetsExt = "IntervalSets"
AccessorsStaticArraysExt = "StaticArrays"
AccessorsStructArraysExt = "StructArrays"
[deps.Accessors.weakdeps]
AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
[[deps.Adapt]]
deps = ["LinearAlgebra", "Requires"]
git-tree-sha1 = "76289dc51920fdc6e0013c872ba9551d54961c24"
@ -49,24 +67,10 @@ uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
version = "0.4.2"
[[deps.BangBang]]
deps = ["Compat", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables"]
git-tree-sha1 = "e28912ce94077686443433c2800104b061a827ed"
deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"]
git-tree-sha1 = "7fe6d92c4f281cf4ca6f2fba0ce7b299742da7ca"
uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
version = "0.3.39"
[deps.BangBang.extensions]
BangBangChainRulesCoreExt = "ChainRulesCore"
BangBangDataFramesExt = "DataFrames"
BangBangStaticArraysExt = "StaticArrays"
BangBangStructArraysExt = "StructArrays"
BangBangTypedTablesExt = "TypedTables"
[deps.BangBang.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
version = "0.3.37"
[[deps.Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
@ -93,10 +97,10 @@ uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
version = "0.4.2"
[[deps.CUDA]]
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Preferences", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "UnsafeAtomicsLLVM"]
git-tree-sha1 = "442d989978ed3ff4e174c928ee879dc09d1ef693"
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Preferences", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "UnsafeAtomicsLLVM"]
git-tree-sha1 = "35160ef0f03b14768abfd68b830f8e3940e8e0dc"
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
version = "4.3.2"
version = "4.4.0"
[[deps.CUDA_Driver_jll]]
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"]
@ -118,9 +122,9 @@ version = "0.6.0+0"
[[deps.CUDNN_jll]]
deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
git-tree-sha1 = "2918fbffb50e3b7a0b9127617587afa76d4276e8"
git-tree-sha1 = "c30b29597102341a1ea4c2175c4acae9ae522c9d"
uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645"
version = "8.8.1+0"
version = "8.9.2+0"
[[deps.Cairo_jll]]
deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Pkg", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"]
@ -140,6 +144,12 @@ git-tree-sha1 = "e30f2f4e20f7f186dc36529910beaedc60cfa644"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.16.0"
[[deps.CodeTracking]]
deps = ["InteractiveUtils", "UUIDs"]
git-tree-sha1 = "d730914ef30a06732bdd9f763f6cc32e92ffbff1"
uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
version = "1.3.1"
[[deps.CodecZlib]]
deps = ["TranscodingStreams", "Zlib_jll"]
git-tree-sha1 = "9c209fb7536406834aa938fb149964b985de6c83"
@ -177,14 +187,10 @@ uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
version = "0.3.0"
[[deps.Compat]]
deps = ["UUIDs"]
git-tree-sha1 = "4e88377ae7ebeaf29a047aa1ee40826e0b708a5d"
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "6c0100a8cf4ed66f66e2039af7cde3357814bad2"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "4.7.0"
weakdeps = ["Dates", "LinearAlgebra"]
[deps.Compat.extensions]
CompatLinearAlgebraExt = "LinearAlgebra"
version = "3.46.2"
[[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
@ -195,13 +201,11 @@ version = "1.0.2+0"
git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad"
uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b"
version = "0.1.2"
weakdeps = ["InverseFunctions"]
[deps.CompositionsBase.extensions]
CompositionsBaseInverseFunctionsExt = "InverseFunctions"
[deps.CompositionsBase.weakdeps]
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
[[deps.ConcurrentUtilities]]
deps = ["Serialization", "Sockets"]
git-tree-sha1 = "96d823b94ba8d187a6d8f0826e731195a74b90e9"
@ -337,9 +341,9 @@ uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"
[[deps.FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"]
git-tree-sha1 = "0b3b52afd0f87b0a3f5ada0466352d125c9db458"
git-tree-sha1 = "2250347838b28a108d1967663cba57bfb3c02a58"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "1.2.1"
version = "1.3.0"
[[deps.FixedPointNumbers]]
deps = ["Statistics"]
@ -361,6 +365,12 @@ version = "0.13.17"
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
[[deps.FoldsThreads]]
deps = ["Accessors", "FunctionWrappers", "InitialValues", "SplittablesBase", "Transducers"]
git-tree-sha1 = "cdba9b84cad7ddb89a326e10bf48d6dd4ffd0252"
uuid = "9c68100b-dfe1-47cf-94c8-95104e173443"
version = "0.1.2"
[[deps.Fontconfig_jll]]
deps = ["Artifacts", "Bzip2_jll", "Expat_jll", "FreeType2_jll", "JLLWrappers", "Libdl", "Libuuid_jll", "Pkg", "Zlib_jll"]
git-tree-sha1 = "21efd19106a55620a188615da6d3d06cd7f6ee03"
@ -395,6 +405,11 @@ git-tree-sha1 = "aa31987c2ba8704e23c6c8ba8a4f769d5d7e4f91"
uuid = "559328eb-81f9-559d-9380-de523a88c83c"
version = "1.0.10+0"
[[deps.FunctionWrappers]]
git-tree-sha1 = "d62485945ce5ae9c0c48f124a84998d755bae00e"
uuid = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
version = "1.1.3"
[[deps.Functors]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "478f8c3145bb91d82c2cf20433e8c1b30df454cc"
@ -425,9 +440,9 @@ version = "0.1.5"
[[deps.GPUCompiler]]
deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"]
git-tree-sha1 = "cb090aea21c6ca78d59672a7e7d13bd56d09de64"
git-tree-sha1 = "d60b5fe7333b5fa41a0378ead6614f1ab51cf6d0"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "0.20.3"
version = "0.21.3"
[[deps.GR]]
deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Pkg", "Preferences", "Printf", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "UUIDs", "p7zip_jll"]
@ -491,6 +506,12 @@ version = "0.3.1"
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
[[deps.InverseFunctions]]
deps = ["Test"]
git-tree-sha1 = "edd1c1ac227767c75e8518defdf6e48dbfa7c3b0"
uuid = "3587e190-3f89-42d0-90ee-14403ec27112"
version = "0.1.10"
[[deps.IrrationalConstants]]
git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2"
uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
@ -525,6 +546,12 @@ git-tree-sha1 = "6f2675ef130a300a112286de91973805fcc5ffbc"
uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8"
version = "2.1.91+0"
[[deps.JuliaInterpreter]]
deps = ["CodeTracking", "InteractiveUtils", "Random", "UUIDs"]
git-tree-sha1 = "6a125e6a4cb391e0b9adbd1afa9e771c2179f8ef"
uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a"
version = "0.9.23"
[[deps.JuliaVariables]]
deps = ["MLStyle", "NameResolution"]
git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70"
@ -551,9 +578,9 @@ version = "3.0.0+1"
[[deps.LLVM]]
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "5007c1421563108110bbd57f63d8ad4565808818"
git-tree-sha1 = "7d5788011dd273788146d40eb5b1fbdc199d0296"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "5.2.0"
version = "6.0.1"
[[deps.LLVMExtra_jll]]
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
@ -695,16 +722,22 @@ git-tree-sha1 = "cedb76b37bc5a6c702ade66be44f831fa23c681e"
uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36"
version = "1.0.0"
[[deps.LoweredCodeUtils]]
deps = ["JuliaInterpreter"]
git-tree-sha1 = "60168780555f3e663c536500aa790b6368adc02a"
uuid = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b"
version = "2.3.0"
[[deps.MLStyle]]
git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8"
uuid = "d8e11817-5142-5d16-987a-aa16d5891078"
version = "0.4.17"
[[deps.MLUtils]]
deps = ["ChainRulesCore", "Compat", "DataAPI", "DelimitedFiles", "FLoops", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"]
git-tree-sha1 = "3504cdb8c2bc05bde4d4b09a81b01df88fcbbba0"
deps = ["ChainRulesCore", "DataAPI", "DelimitedFiles", "FLoops", "FoldsThreads", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"]
git-tree-sha1 = "82c1104919d664ab1024663ad851701415300c5f"
uuid = "f1d291b0-491e-4a28-83b9-f70985020b54"
version = "0.4.3"
version = "0.3.1"
[[deps.MacroTools]]
deps = ["Markdown", "Random"]
@ -740,9 +773,9 @@ version = "0.1.4"
[[deps.Missings]]
deps = ["DataAPI"]
git-tree-sha1 = "f66bdc5de519e8f8ae43bdc598782d35a25b1272"
git-tree-sha1 = "f8c673ccc215eb50fcadb285f522420e29e69e1c"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "1.1.0"
version = "0.4.5"
[[deps.Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
@ -792,10 +825,10 @@ uuid = "e7412a2a-1a6e-54c0-be00-318e2571c051"
version = "1.3.5+1"
[[deps.OneHotArrays]]
deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"]
git-tree-sha1 = "5e4029759e8699ec12ebdf8721e51a659443403c"
deps = ["Adapt", "ChainRulesCore", "GPUArraysCore", "LinearAlgebra", "MLUtils", "NNlib"]
git-tree-sha1 = "aee0130122fa7c1f3d394231376f07869f1e097c"
uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
version = "0.2.4"
version = "0.2.0"
[[deps.OpenBLAS_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
@ -989,6 +1022,12 @@ git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "1.3.0"
[[deps.Revise]]
deps = ["CodeTracking", "Distributed", "FileWatching", "JuliaInterpreter", "LibGit2", "LoweredCodeUtils", "OrderedCollections", "Pkg", "REPL", "Requires", "UUIDs", "Unicode"]
git-tree-sha1 = "1e597b93700fa4045d7189afa7c004e0584ea548"
uuid = "295af30f-e4ad-537b-8983-00126c2a3abe"
version = "3.5.3"
[[deps.SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
version = "0.7.0"
@ -1008,6 +1047,10 @@ git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac"
uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46"
version = "1.1.1"
[[deps.SharedArrays]]
deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
[[deps.ShowCases]]
git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5"
uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3"
@ -1035,9 +1078,9 @@ uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
[[deps.SortingAlgorithms]]
deps = ["DataStructures"]
git-tree-sha1 = "c60ec5c62180f27efea3ba2908480f8055e17cee"
git-tree-sha1 = "f6cb12bae7c2ecff6c4986f28defff8741747a9b"
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
version = "1.1.1"
version = "0.3.2"
[[deps.SparseArrays]]
deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"]
@ -1060,10 +1103,14 @@ uuid = "171d559e-b47b-412a-8079-5efa626c420e"
version = "0.1.15"
[[deps.StaticArrays]]
deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"]
git-tree-sha1 = "832afbae2a45b4ae7e831f86965469a24d1d8a83"
deps = ["LinearAlgebra", "Random", "StaticArraysCore"]
git-tree-sha1 = "0da7e6b70d1bb40b1ace3b576da9ea2992f76318"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.5.26"
version = "1.6.0"
weakdeps = ["Statistics"]
[deps.StaticArrays.extensions]
StaticArraysStatisticsExt = "Statistics"
[[deps.StaticArraysCore]]
git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a"
@ -1083,9 +1130,9 @@ version = "1.6.0"
[[deps.StatsBase]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
git-tree-sha1 = "75ebe04c5bed70b91614d684259b661c9e6274a4"
git-tree-sha1 = "d1bf48bfcc554a3761a133fe3a9bb01488e06916"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.34.0"
version = "0.33.21"
[[deps.StructArrays]]
deps = ["Adapt", "DataAPI", "GPUArraysCore", "StaticArraysCore", "Tables"]
@ -1185,13 +1232,11 @@ deps = ["ConstructionBase", "Dates", "LinearAlgebra", "Random"]
git-tree-sha1 = "ba4aa36b2d5c98d6ed1f149da916b3ba46527b2b"
uuid = "1986cc42-f94f-5a68-af5c-568840ba703d"
version = "1.14.0"
weakdeps = ["InverseFunctions"]
[deps.Unitful.extensions]
InverseFunctionsUnitfulExt = "InverseFunctions"
[deps.Unitful.weakdeps]
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
[[deps.UnitfulLatexify]]
deps = ["LaTeXStrings", "Latexify", "Unitful"]
git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd"
@ -1205,9 +1250,9 @@ version = "0.2.1"
[[deps.UnsafeAtomicsLLVM]]
deps = ["LLVM", "UnsafeAtomics"]
git-tree-sha1 = "ea37e6066bf194ab78f4e747f5245261f17a7175"
git-tree-sha1 = "323e3d0acf5e78a56dfae7bd8928c989b4f3083e"
uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
version = "0.1.2"
version = "0.1.3"
[[deps.Unzip]]
git-tree-sha1 = "ca0969166a028236229f63514992fc073799bb78"
@ -1399,9 +1444,9 @@ version = "0.2.3"
[[deps.cuDNN]]
deps = ["CEnum", "CUDA", "CUDNN_jll"]
git-tree-sha1 = "f65490d187861d6222cb38bcbbff3fd949a7ec3e"
git-tree-sha1 = "ee79f97d07bf875231559f9b3f2649f34fac140b"
uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
version = "1.0.4"
version = "1.1.0"
[[deps.fzf_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]

View file

@ -1,3 +1,10 @@
name = "Diffusers"
uuid = "90edb7a8-79d7-49b2-b6b1-9322c3fdead8"
authors = ["Laureηt <laurent@fainsin.bzh>"]
version = "0.1.0"
[deps]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"

68
examples/swissroll.jl Normal file
View file

@ -0,0 +1,68 @@
import Diffusers
using Random
using Plots
function make_spiral(rng::AbstractRNG, n_samples::Int=1000)
t_min = 1.5π
t_max = 4.5π
t = rand(rng, n_samples) * (t_max - t_min) .+ t_min
x = t .* cos.(t)
y = t .* sin.(t)
permutedims([x y], (2, 1))
end
make_spiral(n_samples::Int=1000) = make_spiral(Random.GLOBAL_RNG, n_samples)
function normalize_zero_to_one(x)
x_min, x_max = extrema(x)
x_norm = (x .- x_min) ./ (x_max - x_min)
x_norm
end
function normalize_neg_one_to_one(x)
2 * normalize_zero_to_one(x) .- 1
end
n_samples = 1000
data = normalize_neg_one_to_one(make_spiral(n_samples))
scatter(data[1, :], data[2, :],
alpha=0.5,
aspectratio=:equal,
)
num_timesteps = 1000
scheduler = Diffusers.DDPM(
Vector{Float64},
Diffusers.cosine_beta_schedule(num_timesteps, 0.999f0, 0.001f0),
)
noise = randn(size(X))
anim = @animate for i in 1:num_timesteps
noisy_data = Diffusers.add_noise(scheduler, X, noise, [i])
scatter(noise[1, :], noise[2, :],
alpha=0.1,
aspectratio=:equal,
label="noise",
legend=:topright,
)
scatter!(noisy_data[1, :], noisy_data[2, :],
alpha=0.5,
aspectratio=:equal,
label="noisy data",
)
scatter!(data[1, :], data[2, :],
alpha=0.5,
aspectratio=:equal,
label="data",
)
i_str = lpad(i, 3, "0")
title!("t = $(i_str)")
xlims!(-3, 3)
ylims!(-3, 3)
end
gif(anim, "swissroll.gif", fps=50)

7
src/Diffusers.jl Normal file
View file

@ -0,0 +1,7 @@
module Diffusers
include("scheduler.jl")
include("beta_scheduler.jl")
include("ddpm.jl")
end # module Diffusers

27
src/beta_scheduler.jl Normal file
View file

@ -0,0 +1,27 @@
import NNlib: sigmoid
function linear_beta_schedule(num_timesteps::Int, β_start=0.0001f0, β_end=0.02f0)
range(β_start, β_end, length=num_timesteps)
end
function scaled_linear_beta_schedule(num_timesteps::Int, β_start=0.0001f0, β_end=0.02f0)
range(β_start^0.5, β_end^0.5, length=num_timesteps) .^ 2
end
function cosine_beta_schedule(num_timesteps::Int, max_beta=0.999f0, ϵ=1e-3f0)
α_bar(t) = cos((t + ϵ) / (1 + ϵ) * π / 2)^2
βs = Float32[]
for i in 1:num_timesteps
t1 = (i - 1) / num_timesteps
t2 = i / num_timesteps
push!(βs, min(1 - α_bar(t2) / α_bar(t1), max_beta))
end
return βs
end
function sigmoid_beta_schedule(num_timesteps::Int, β_start=0.0001f0, β_end=0.02f0)
x = range(-6, 6, length=num_timesteps)
sigmoid(x) * (β_end - β_start) + β_start
end

44
src/ddpm.jl Normal file
View file

@ -0,0 +1,44 @@
include("scheduler.jl")
"""
Denoising Diffusion Probabilistic Models (DDPM) scheduler.
https://arxiv.org/abs/2006.11239
"""
struct DDPM{V<:AbstractVector} <: Scheduler
# number of diffusion steps used to train the model.
num_train_timesteps::Int
# the betas to use for the diffusion steps
βs::V
αs::V
α_cumprods::V
α_cumprod_prevs::V
sqrt_α_cumprods::V
sqrt_one_minus_α_cumprods::V
end
function DDPM(V::DataType, βs::AbstractVector)
αs = 1 .- βs
α_cumprods = cumprod(αs)
α_cumprod_prevs = [1, (α_cumprods[1:end-1])...]
sqrt_α_cumprods = sqrt.(α_cumprods)
sqrt_one_minus_α_cumprods = sqrt.(1 .- α_cumprods)
DDPM{V}(
length(βs),
βs,
αs,
α_cumprods,
α_cumprod_prevs,
sqrt_α_cumprods,
sqrt_one_minus_α_cumprods,
)
end
function DDPM(V::DataType, beta_scheduler)
DDPM(V, beta_scheduler)
end

14
src/scheduler.jl Normal file
View file

@ -0,0 +1,14 @@
abstract type Scheduler end
function add_noise(
scheduler::Scheduler,
original_samples::AbstractArray,
noise::AbstractArray,
timesteps::AbstractArray,
)
alphas_cumprod = scheduler.α_cumprods[timesteps]
sqrt_alpha_prod = sqrt.(alphas_cumprod)
sqrt_one_minus_alpha_prod = sqrt.(1 .- alphas_cumprod)
sqrt_alpha_prod .* original_samples .+ sqrt_one_minus_alpha_prod .* noise
end