11# Switch To MKL For Faster Computation
2- using MKL
2+ # using MKL
33
44# Enable Logging
55using Logging, TerminalLoggers
66global_logger (TerminalLogger ())
77
8+ # Data
9+ using Distributions
10+ ndata = 1024
11+ ndimension = 1
12+ data_dist = Beta {Float32} (2.0f0 , 4.0f0 )
13+ r = rand (data_dist, ndimension, ndata)
14+ r = convert .(Float32, r)
15+
816# Parameters
9- nvars = 1
17+ nvars = size (r, 1 )
1018naugs = nvars
11- # n_in = nvars # without augmentation
12- n_in = nvars + naugs # with augmentation
13- n = 1024
19+ n_in = nvars + naugs
1420
1521# Model
1622using ContinuousNormalizingFlows,
1723 Lux, OrdinaryDiffEqDefault, SciMLSensitivity, ADTypes, Zygote, MLDataDevices
24+
25+ # To use gpu, add related packages
26+ # using LuxCUDA, CUDA, cuDNN
27+
1828nn = Chain (Dense (n_in => 3 * n_in, tanh), Dense (3 * n_in => n_in, tanh))
1929icnf = construct (
2030 RNODE,
@@ -24,6 +34,7 @@ icnf = construct(
2434 compute_mode = LuxVecJacMatrixMode (AutoZygote ()), # process data in batches and use Zygote
2535 inplace = false , # not using the inplace version of functions
2636 device = cpu_device (), # process data by CPU
37+ # device = gpu_device(), # process data by GPU
2738 tspan = (0.0f0 , 13.0f0 ), # have bigger time span
2839 steer_rate = 1.0f-1 , # add random noise to end of the time span
2940 λ₁ = 1.0f-2 , # regulate flow
@@ -36,12 +47,6 @@ icnf = construct(
3647 ), # pass to the solver
3748)
3849
39- # Data
40- using Distributions
41- data_dist = Beta {Float32} (2.0f0 , 4.0f0 )
42- r = rand (data_dist, nvars, n)
43- r = convert .(Float32, r)
44-
4550# Fit It
4651using DataFrames, MLJBase, Zygote, ADTypes, OptimizationOptimisers
4752df = DataFrame (transpose (r), :auto )
@@ -55,25 +60,26 @@ model = ICNFModel(
5560)
5661mach = machine (model, df)
5762fit! (mach)
63+ # CUDA.@allowscalar fit!(mach) # needed for gpu
5864ps, st = fitted_params (mach)
5965
6066# Store It
6167using JLD2, UnPack
62- jldsave (" fitted.jld2" ; ps, st) # save
63- @unpack ps, st = load (" fitted.jld2" ) # load
68+ jldsave (" fitted.jld2" ; ps, st) # save it
69+ @unpack ps, st = load (" fitted.jld2" ) # load it
6470
6571# Use It
6672d = ICNFDist (icnf, TestMode (), ps, st) # direct way
6773# d = ICNFDist(mach, TestMode()) # alternative way
6874actual_pdf = pdf .(data_dist, vec (r))
6975estimated_pdf = pdf (d, r)
70- new_data = rand (d, n )
76+ new_data = rand (d, ndata )
7177
7278# Evaluate It
7379using Distances
7480mad_ = meanad (estimated_pdf, actual_pdf)
7581msd_ = msd (estimated_pdf, actual_pdf)
76- tv_dis = totalvariation (estimated_pdf, actual_pdf) / n
82+ tv_dis = totalvariation (estimated_pdf, actual_pdf) / ndata
7783res_df = DataFrame (; mad_, msd_, tv_dis)
7884display (res_df)
7985
0 commit comments