Skip to content

Commit 00ba366

Browse files
authored
Merge faeb446 into 45ec1b8
2 parents 45ec1b8 + faeb446 commit 00ba366

File tree

1 file changed

+21
-15
lines changed

1 file changed

+21
-15
lines changed

examples/usage.jl

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,30 @@
11
# Switch To MKL For Faster Computation
2-
using MKL
2+
# using MKL
33

44
# Enable Logging
55
using Logging, TerminalLoggers
66
global_logger(TerminalLogger())
77

8+
# Data
9+
using Distributions
10+
ndata = 1024
11+
ndimension = 1
12+
data_dist = Beta{Float32}(2.0f0, 4.0f0)
13+
r = rand(data_dist, ndimension, ndata)
14+
r = convert.(Float32, r)
15+
816
# Parameters
9-
nvars = 1
17+
nvars = size(r, 1)
1018
naugs = nvars
11-
# n_in = nvars # without augmentation
12-
n_in = nvars + naugs # with augmentation
13-
n = 1024
19+
n_in = nvars + naugs
1420

1521
# Model
1622
using ContinuousNormalizingFlows,
1723
Lux, OrdinaryDiffEqDefault, SciMLSensitivity, ADTypes, Zygote, MLDataDevices
24+
25+
# To use gpu, add related packages
26+
# using LuxCUDA, CUDA, cuDNN
27+
1828
nn = Chain(Dense(n_in => 3 * n_in, tanh), Dense(3 * n_in => n_in, tanh))
1929
icnf = construct(
2030
RNODE,
@@ -24,6 +34,7 @@ icnf = construct(
2434
compute_mode = LuxVecJacMatrixMode(AutoZygote()), # process data in batches and use Zygote
2535
inplace = false, # not using the inplace version of functions
2636
device = cpu_device(), # process data by CPU
37+
# device = gpu_device(), # process data by GPU
2738
tspan = (0.0f0, 13.0f0), # have bigger time span
2839
steer_rate = 1.0f-1, # add random noise to end of the time span
2940
λ₁ = 1.0f-2, # regulate flow
@@ -36,12 +47,6 @@ icnf = construct(
3647
), # pass to the solver
3748
)
3849

39-
# Data
40-
using Distributions
41-
data_dist = Beta{Float32}(2.0f0, 4.0f0)
42-
r = rand(data_dist, nvars, n)
43-
r = convert.(Float32, r)
44-
4550
# Fit It
4651
using DataFrames, MLJBase, Zygote, ADTypes, OptimizationOptimisers
4752
df = DataFrame(transpose(r), :auto)
@@ -55,25 +60,26 @@ model = ICNFModel(
5560
)
5661
mach = machine(model, df)
5762
fit!(mach)
63+
# CUDA.@allowscalar fit!(mach) # needed for gpu
5864
ps, st = fitted_params(mach)
5965

6066
# Store It
6167
using JLD2, UnPack
62-
jldsave("fitted.jld2"; ps, st) # save
63-
@unpack ps, st = load("fitted.jld2") # load
68+
jldsave("fitted.jld2"; ps, st) # save it
69+
@unpack ps, st = load("fitted.jld2") # load it
6470

6571
# Use It
6672
d = ICNFDist(icnf, TestMode(), ps, st) # direct way
6773
# d = ICNFDist(mach, TestMode()) # alternative way
6874
actual_pdf = pdf.(data_dist, vec(r))
6975
estimated_pdf = pdf(d, r)
70-
new_data = rand(d, n)
76+
new_data = rand(d, ndata)
7177

7278
# Evaluate It
7379
using Distances
7480
mad_ = meanad(estimated_pdf, actual_pdf)
7581
msd_ = msd(estimated_pdf, actual_pdf)
76-
tv_dis = totalvariation(estimated_pdf, actual_pdf) / n
82+
tv_dis = totalvariation(estimated_pdf, actual_pdf) / ndata
7783
res_df = DataFrame(; mad_, msd_, tv_dis)
7884
display(res_df)
7985

0 commit comments

Comments
 (0)