diff --git a/.github/workflows/Invalidations.yml b/.github/workflows/Invalidations.yml index 983c2df..4d0004e 100644 --- a/.github/workflows/Invalidations.yml +++ b/.github/workflows/Invalidations.yml @@ -9,9 +9,6 @@ concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true -env: - PYTHON: "Conda" # use Julia's packaged Conda build for installing packages - jobs: evaluate: # Only run on PRs to the default branch. @@ -24,17 +21,6 @@ jobs: version: '1' - uses: actions/checkout@v3 - uses: julia-actions/julia-buildpkg@v1 - - name: Install Python ArviZ dependencies - run: | - using Pkg - Pkg.instantiate() - using Conda - # https://discourse.julialang.org/t/conda-not-installing-matplotlib-for-pyplot/96813/2 - Conda.add("conda==23.1.0") - using ArviZPythonPlots - ArviZPythonPlots.initialize_arviz() - ArviZPythonPlots.initialize_pandas() - shell: julia --color=yes --project {0} - uses: julia-actions/julia-invalidations@v1 id: invs_pr diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 86202b2..1d430d4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,9 +7,6 @@ on: schedule: - cron: "0 0 * * *" -env: - PYTHON: "Conda" # use Julia's packaged Conda build for installing packages - jobs: test: name: Julia ${{ matrix.julia-version }} - ${{ matrix.os }} @@ -17,45 +14,31 @@ jobs: strategy: fail-fast: false matrix: - julia-version: ["1"] - os: [ubuntu-latest, windows-latest, macOS-latest] + julia-version: ["1", "1.8", "nightly"] + os: [ubuntu-latest] include: - - julia-version: "1.6" - os: ubuntu-latest + - julia-version: "1" + os: windows-latest + - julia-version: "1" + os: macOS-latest steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 with: version: ${{ matrix.julia-version }} arch: x64 - - name: Cache Julia artifacts - uses: actions/cache@v2 + - uses: julia-actions/cache@v1 + - name: Cache CondaPkg + id: cache-condaPkg + uses: actions/cache@v3 env: - cache-name: cache-artifacts + cache-name: cache-condapkg with: - path: ~/.julia/artifacts - key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + path: .CondaPkg + key: ${{ runner.os }}-${{ env.cache-name }}-${{ hashFiles('docs/CondaPkg.toml') }} restore-keys: | - ${{ runner.os }}-test-${{ env.cache-name }}- - ${{ runner.os }}-test- - ${{ runner.os }}- - - name: Install wget for windows - if: matrix.os == 'windows-latest' - uses: crazy-max/ghaction-chocolatey@v2 - with: - args: install wget + ${{ runner.os }}-${{ env.cache-name }}- - uses: julia-actions/julia-buildpkg@latest - - name: Install Python ArviZ dependencies - run: | - using Pkg - Pkg.instantiate() - using Conda - # https://discourse.julialang.org/t/conda-not-installing-matplotlib-for-pyplot/96813/2 - Conda.add("conda==23.1.0") - using ArviZPythonPlots - ArviZPythonPlots.initialize_arviz() - ArviZPythonPlots.initialize_pandas() - shell: julia --color=yes --project {0} - uses: julia-actions/julia-runtest@latest - uses: julia-actions/julia-processcoverage@v1 - uses: codecov/codecov-action@v1 diff --git a/.github/workflows/documenter.yml b/.github/workflows/documenter.yml index dcb9cf2..f218c7d 100644 --- a/.github/workflows/documenter.yml +++ b/.github/workflows/documenter.yml @@ -5,9 +5,6 @@ on: tags: [v*] pull_request: -env: - PYTHON: "Conda" # use Julia's packaged Conda build for installing packages - jobs: docs: name: Documentation @@ -15,15 +12,20 @@ jobs: steps: - uses: actions/checkout@v2 - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/cache@v1 + - name: Cache CondaPkg + id: cache-condaPkg + uses: actions/cache@v3 + env: + cache-name: cache-condapkg + with: + path: | + .CondaPkg + docs/.CondaPkg + key: ${{ runner.os }}-${{ env.cache-name }}-${{ hashFiles('docs/CondaPkg.toml') }} + restore-keys: | + ${{ runner.os }}-${{ env.cache-name }}- - uses: julia-actions/julia-buildpkg@latest - - name: Setup Conda - run: | - using Pkg - Pkg.instantiate() - using Conda - # https://discourse.julialang.org/t/conda-not-installing-matplotlib-for-pyplot/96813/2 - Conda.add("conda==23.1.0") - shell: julia --color=yes --project {0} - uses: julia-actions/julia-docdeploy@v1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/futures.yml b/.github/workflows/futures.yml deleted file mode 100644 index 0204254..0000000 --- a/.github/workflows/futures.yml +++ /dev/null @@ -1,69 +0,0 @@ -name: Futures -on: - push: - branches: [main] - tags: [v*] - pull_request: - schedule: - - cron: "0 0 * * *" - -env: - PYTHON: "Conda" # use Julia's packaged Conda build for installing packages - -jobs: - test: - name: Julia ${{ matrix.julia-version }} - ${{ matrix.arviz_version }} - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - julia-version: ["1", "nightly"] - os: [ubuntu-latest] - include: - - julia-version: "1" - arviz_version: "main" - - julia-version: "nightly" - arviz_version: "release" - steps: - - uses: actions/checkout@v2 - - uses: julia-actions/setup-julia@v1 - with: - version: ${{ matrix.julia-version }} - arch: x64 - - name: Cache Julia artifacts - uses: actions/cache@v2 - env: - cache-name: cache-artifacts - with: - path: ~/.julia/artifacts - key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} - restore-keys: | - ${{ runner.os }}-test-${{ env.cache-name }}- - ${{ runner.os }}-test- - ${{ runner.os }}- - - uses: julia-actions/julia-buildpkg@latest - if: matrix.arviz_version != 'main' || github.event_name == 'create' && startsWith(github.ref, 'refs/tags/v') - - name: "Install arviz#main" - if: matrix.arviz_version == 'main' && github.event_name != 'create' && !startsWith(github.ref, 'refs/tags/v') - run: | - using Pkg - Pkg.add("PyCall") - using PyCall - run( - PyCall.python_cmd( - `-m pip install git+https://github.com/pydata/xarray git+https://github.com/arviz-devs/arviz`, - ), - ) - shell: julia --color=yes {0} - - name: Install Python ArviZ dependencies - run: | - using Pkg - Pkg.instantiate() - using Conda - # https://discourse.julialang.org/t/conda-not-installing-matplotlib-for-pyplot/96813/2 - Conda.add("conda==23.1.0") - using ArviZPythonPlots - ArviZPythonPlots.initialize_arviz() - ArviZPythonPlots.initialize_pandas() - shell: julia --color=yes --project {0} - - uses: julia-actions/julia-runtest@latest diff --git a/.gitignore b/.gitignore index 7dd4e70..b8c6b69 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ /docs/build /test/*.log .DS_Store +.CondaPkg \ No newline at end of file diff --git a/CondaPkg.toml b/CondaPkg.toml new file mode 100644 index 0000000..c3f696b --- /dev/null +++ b/CondaPkg.toml @@ -0,0 +1,5 @@ +[deps] +pandas = "" +matplotlib = "" +xarray = "" +arviz = ">=0.14.0" diff --git a/Project.toml b/Project.toml index e2bfe8b..0d0ffed 100644 --- a/Project.toml +++ b/Project.toml @@ -5,26 +5,27 @@ version = "0.1.0" [deps] ArviZ = "131c737c-5715-5e2e-ad31-c244f01c1dc7" -Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d" +CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" -PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee" +PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" +PythonPlot = "274fc56d-3b97-40fa-a1cd-1b4a50311bf9" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] -ArviZ = "0.9" +ArviZ = "0.10" ArviZExampleData = "0.1.5" -Conda = "1.0" -DataFrames = "0.20, 0.21, 0.22, 1.0" +CondaPkg = "0.2" DimensionalData = "0.23, 0.24" OrderedCollections = "1" -PyCall = "1.91.2" -PyPlot = "2.8.2" +PythonCall = "0.9" +PythonPlot = "1" Reexport = "1" -julia = "1.6" +Tables = "1" +julia = "1.8" [extras] ArviZExampleData = "2f96bb34-afd9-46ae-bcd0-9b2d4372fe3c" diff --git a/README.md b/README.md index 39c7f10..f64bb25 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ [![Powered by NumFOCUS](https://img.shields.io/badge/powered%20by-NumFOCUS-orange.svg?style=flat&colorA=E1523D&colorB=007D8A)](https://numfocus.org) ArviZPythonPlots.jl provides PyPlot-compatible plotting functions for exploratory analysis of Bayesian models using [ArviZ.jl](https://julia.arviz.org/). -It provides an interface to use the plotting functions in [Python ArviZ](https://python.arviz.org/) with Julia types. -It also re-exports all methods exported by both ArviZ.jl and [PyPlot.jl](https://github.com/JuliaPy/PyPlot.jl). +It uses [PythonCall.jl](https://github.com/cjdoris/PythonCall.jl) to provide an interface for using the plotting functions in [Python ArviZ](https://python.arviz.org/) with Julia types. +It also re-exports all methods exported by [PythonPlot.jl](https://github.com/JuliaPy/PythonPlot.jl). See the [documentation](https://julia.arviz.org/ArviZPythonPlots) for details. diff --git a/deps/build.jl b/deps/build.jl deleted file mode 100644 index 3fef876..0000000 --- a/deps/build.jl +++ /dev/null @@ -1,6 +0,0 @@ -using Conda - -# temporary workaround for -# - https://github.com/arviz-devs/ArviZ.jl/issues/188 -# - https://github.com/arviz-devs/arviz/issues/2120 -Conda.add(["scipy<=1.8.0"]) diff --git a/docs/Project.toml b/docs/Project.toml index dc75be7..bd90132 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,12 +4,9 @@ ArviZExampleData = "2f96bb34-afd9-46ae-bcd0-9b2d4372fe3c" ArviZPythonPlots = "4a6e88f0-2c8e-11ee-0601-e94153f0eada" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" -PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee" +PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" [compat] ArviZExampleData = "0.1.5" Distributions = "0.25" Documenter = "0.27" -PyCall = "1.0" -PyPlot = "2.0" diff --git a/docs/lazyhelp.jl b/docs/lazyhelp.jl new file mode 100644 index 0000000..98fd016 --- /dev/null +++ b/docs/lazyhelp.jl @@ -0,0 +1,47 @@ +using Documenter, Markdown, PythonCall + +# adapted from https://github.com/AtelierArith/Kyulacs.jl for PythonCall +# MIT License +# Copyright (c) 2022 Satoshi Terasaki and contributors + +function get_signature(f) + inspect = pyimport("inspect") + try + return pyconvert(String, inspect.signature(f)) + catch e + return "" + end +end + +function gendocstr(h::LazyHelp) + o = h.o + for k in h.keys + o = pygetattr(o, k) + end + fname = pyhasattr(o, "__name__") ? pyconvert(String, o.__name__) : "" + sig = pyhasattr(o, "__call__") ? get_signature(o) : "" + fdoc = pyhasattr(o, "__doc__") ? pyconvert(String, o.__doc__) : "" + + if isnothing(fdoc) + return """ + $(fname)$(sig) + """ + else + return """ + $(fdoc) + """ + end +end + +function Documenter.Writers.HTMLWriter.mdconvert(h::LazyHelp, parent; kwargs...) + s = gendocstr(h) + # quote docstring `s` to prevent changing display result + m = Markdown.parse(""" + ``` + $s + ``` + """) + return Documenter.Writers.HTMLWriter.mdconvert(m, parent; kwargs...) +end + +Documenter.Utilities.MDFlatten.mdflatten(::IOBuffer, ::LazyHelp, ::Markdown.MD) = nothing diff --git a/docs/make.jl b/docs/make.jl index 4bba6e4..6728c44 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,4 +1,7 @@ using ArviZPythonPlots, Documenter +using ArviZPythonPlots: LazyHelp + +include("lazyhelp.jl") makedocs(; modules=[ArviZPythonPlots], @@ -17,7 +20,9 @@ makedocs(; format=Documenter.HTML(; prettyurls=haskey(ENV, "CI"), sidebar_sitename=false, canonical="stable" ), + doctest=false, linkcheck=true, + strict=true, ) deploydocs(; diff --git a/docs/src/examples.md b/docs/src/examples.md index 6f5a724..11b2ef8 100644 --- a/docs/src/examples.md +++ b/docs/src/examples.md @@ -5,11 +5,10 @@ ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") data = load_example_data("centered_eight") -figure() #hide -plot_autocorr(data; var_names=[:tau, :mu]) +plot_autocorr(data; var_names=["tau", "mu"]) gcf() ``` @@ -20,10 +19,9 @@ See [`plot_autocorr`](@ref) ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") data = load_example_data("regression1d") -figure() #hide plot_bpv(data) gcf() ``` @@ -35,11 +33,10 @@ See [`plot_bpv`](@ref) ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") data = load_example_data("regression1d") -figure() #hide -plot_bpv(data; kind=:t_stat, t_stat="0.5") +plot_bpv(data; kind="t_stat", t_stat="0.5") gcf() ``` @@ -48,9 +45,9 @@ See [`plot_bpv`](@ref) ## Compare Plot ```@example -using ArviZPythonPlots, ArviZExampleData +using ArviZ, ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") model_compare = compare( ( @@ -58,7 +55,6 @@ model_compare = compare( var"Non-centered 8 schools" = load_example_data("non_centered_eight"), ), ) -figure() #hide plot_compare(model_compare; figsize=(12, 4)) gcf() ``` @@ -70,15 +66,14 @@ See [`compare`](https://julia.arviz.org/ArviZ/stable/api/stats/#ArviZ.compare), ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") centered_data = load_example_data("centered_eight") non_centered_data = load_example_data("non_centered_eight") -figure() #hide plot_density( [centered_data, non_centered_data]; data_labels=["Centered", "Non Centered"], - var_names=[:theta], + var_names=["theta"], shade=0.1, ) gcf() @@ -89,20 +84,17 @@ See [`plot_density`](@ref) ## Dist Plot ```@example -using Random -using Distributions -using ArviZPythonPlots +using ArviZPythonPlots, Distributions, Random Random.seed!(308) -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") a = rand(Poisson(4), 1000) b = rand(Normal(0, 1), 1000) -figure() #hide -_, ax = plt.subplots(1, 2; figsize=(10, 4)) -plot_dist(a; color="C1", label="Poisson", ax=ax[1]) -plot_dist(b; color="C2", label="Gaussian", ax=ax[2]) +_, ax = subplots(1, 2; figsize=(10, 4)) +plot_dist(a; color="C1", label="Poisson", ax=ax[0]) +plot_dist(b; color="C2", label="Gaussian", ax=ax[1]) gcf() ``` @@ -113,11 +105,10 @@ See [`plot_dist`](@ref) ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") d1 = load_example_data("centered_eight") d2 = load_example_data("non_centered_eight") -figure() #hide plot_elpd(Dict("Centered eight" => d1, "Non centered eight" => d2); xlabels=true) gcf() ``` @@ -129,10 +120,9 @@ See [`plot_elpd`](@ref) ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") data = load_example_data("centered_eight") -figure() #hide plot_energy(data; figsize=(12, 8)) gcf() ``` @@ -144,11 +134,10 @@ See [`plot_energy`](@ref) ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") idata = load_example_data("radon") -figure() #hide -plot_ess(idata; var_names=[:b], kind=:evolution) +plot_ess(idata; var_names=["b"], kind="evolution") gcf() ``` @@ -159,11 +148,10 @@ See [`plot_ess`](@ref) ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") idata = load_example_data("non_centered_eight") -figure() #hide -plot_ess(idata; var_names=[:mu], kind=:local, marker="_", ms=20, mew=2, rug=true) +plot_ess(idata; var_names=["mu"], kind="local", marker="_", ms=20, mew=2, rug=true) gcf() ``` @@ -174,11 +162,10 @@ See [`plot_ess`](@ref) ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") idata = load_example_data("radon") -figure() #hide -plot_ess(idata; var_names=[:sigma], kind=:quantile, color="C4") +plot_ess(idata; var_names=["sigma"], kind="quantile", color="C4") gcf() ``` @@ -189,15 +176,14 @@ See [`plot_ess`](@ref) ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") centered_data = load_example_data("centered_eight") non_centered_data = load_example_data("non_centered_eight") -figure() #hide plot_forest( [centered_data, non_centered_data]; model_names=["Centered", "Non Centered"], - var_names=[:mu], + var_names=["mu"], ) title("Estimated theta for eight schools model") gcf() @@ -210,18 +196,17 @@ See [`plot_forest`](@ref) ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") rugby_data = load_example_data("rugby") -figure() #hide plot_forest( rugby_data; - kind=:ridgeplot, - var_names=[:defs], + kind="ridgeplot", + var_names=["defs"], linewidth=4, combined=true, ridgeplot_overlap=1.5, - colors=:blue, + colors="blue", figsize=(9, 4), ) title("Relative defensive strength\nof Six Nation rugby teams") @@ -238,15 +223,14 @@ using ArviZPythonPlots Random.seed!(308) -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") x_data = randn(100) y_data = 2 .+ x_data .* 0.5 y_data_rep = 0.5 .* randn(200, 100) .+ transpose(y_data) -figure() #hide plot(x_data, y_data; color="C6") -plot_hdi(x_data, y_data_rep; color=:k, plot_kwargs=Dict(:ls => "--")) +plot_hdi(x_data, y_data_rep; color="k", plot_kwargs=Dict("ls" => "--")) gcf() ``` @@ -257,15 +241,14 @@ See [`plot_hdi`](@ref) ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") data = load_example_data("non_centered_eight") -figure() #hide plot_pair( data; - var_names=[:theta], - coords=Dict(:school => ["Choate", "Phillips Andover"]), - kind=:hexbin, + var_names=["theta"], + coords=Dict("school" => ["Choate", "Phillips Andover"]), + kind="hexbin", marginals=true, figsize=(10, 10), ) @@ -279,7 +262,7 @@ See [`plot_pair`](@ref) ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") data = load_example_data("centered_eight") @@ -288,13 +271,12 @@ obs = data.posterior_predictive.obs size_obs = size(obs) y_hat = reshape(obs, prod(size_obs[1:2]), size_obs[3:end]...) -figure() #hide plot_kde( y_hat; label="Estimated Effect\n of SAT Prep", rug=true, - plot_kwargs=Dict(:linewidth => 2, :color => :black), - rug_kwargs=Dict(:color => :black), + plot_kwargs=Dict("linewidth" => 2, "color" => "black"), + rug_kwargs=Dict("color" => "black"), ) gcf() ``` @@ -309,9 +291,8 @@ using ArviZPythonPlots Random.seed!(308) -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") -figure() #hide plot_kde(rand(100), rand(100)) gcf() ``` @@ -327,10 +308,9 @@ using ArviZPythonPlots Random.seed!(308) -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") dist = rand(Beta(rand(Uniform(0.5, 10)), 5), 1000) -figure() #hide plot_kde(dist; quantiles=[0.25, 0.5, 0.75]) gcf() ``` @@ -340,13 +320,12 @@ See [`plot_kde`](@ref) ## Pareto Shape Plot ```@example -using ArviZPythonPlots, ArviZExampleData +using ArviZ, ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") idata = load_example_data("radon") loo_data = loo(idata) -figure() #hide plot_khat(loo_data; show_bins=true) gcf() ``` @@ -358,12 +337,11 @@ See [`loo`](https://julia.arviz.org/ArviZ/stable/api/stats/#ArviZ.loo), [`plot_k ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") idata = load_example_data("radon") -figure() #hide -plot_loo_pit(idata; y=:y, ecdf=true, color=:maroon) +plot_loo_pit(idata; y="y", ecdf=true, color="maroon") gcf() ``` @@ -374,11 +352,10 @@ See [`loo_pit`](https://julia.arviz.org/ArviZ/stable/api/stats/#ArviZ.loo_pit), ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") idata = load_example_data("non_centered_eight") -figure() #hide -plot_loo_pit(; idata, y=:obs, color=:indigo) +plot_loo_pit(; idata, y="obs", color="indigo") gcf() ``` @@ -389,11 +366,10 @@ See [`loo_pit`](https://julia.arviz.org/ArviZ/stable/api/stats/#ArviZ.loo_pit), ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") data = load_example_data("centered_eight") -figure() #hide -plot_mcse(data; var_names=[:tau, :mu], rug=true, extra_methods=true) +plot_mcse(data; var_names=["tau", "mu"], rug=true, extra_methods=true) gcf() ``` @@ -404,11 +380,10 @@ See [`plot_mcse`](@ref) ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") data = load_example_data("radon") -figure() #hide -plot_mcse(data; var_names=[:sigma_a], color="C4", errorbar=true) +plot_mcse(data; var_names=["sigma_a"], color="C4", errorbar=true) gcf() ``` @@ -419,13 +394,12 @@ See [`plot_mcse`](@ref) ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") centered = load_example_data("centered_eight") -coords = Dict(:school => ["Choate", "Deerfield"]) -figure() #hide +coords = Dict("school" => ["Choate", "Deerfield"]) plot_pair( - centered; var_names=[:theta, :mu, :tau], coords, divergences=true, textsize=22 + centered; var_names=["theta", "mu", "tau"], coords, divergences=true, textsize=22 ) gcf() ``` @@ -437,15 +411,14 @@ See [`plot_pair`](@ref) ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") centered = load_example_data("centered_eight") -coords = Dict(:school => ["Choate", "Deerfield"]) -figure() #hide +coords = Dict("school" => ["Choate", "Deerfield"]) plot_pair( centered; - var_names=[:theta, :mu, :tau], - kind=:hexbin, + var_names=["theta", "mu", "tau"], + kind="hexbin", coords, colorbar=true, divergences=true, @@ -460,15 +433,14 @@ See [`plot_pair`](@ref) ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") centered = load_example_data("centered_eight") -coords = Dict(:school => ["Choate", "Deerfield"]) -figure() #hide +coords = Dict("school" => ["Choate", "Deerfield"]) plot_pair( centered; - var_names=[:theta, :mu, :tau], - kind=:kde, + var_names=["theta", "mu", "tau"], + kind="kde", coords, divergences=true, textsize=22, @@ -483,19 +455,18 @@ See [`plot_pair`](@ref) ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") centered = load_example_data("centered_eight") -coords = Dict(:school => ["Choate", "Deerfield"]) -figure() #hide +coords = Dict("school" => ["Choate", "Deerfield"]) plot_pair( centered; - var_names=[:mu, :theta], - kind=[:scatter, :kde], - kde_kwargs=Dict(:fill_last => false), + var_names=["mu", "theta"], + kind=["scatter", "kde"], + kde_kwargs=Dict("fill_last" => false), marginals=true, coords, - point_estimate=:median, + point_estimate="median", figsize=(10, 8), ) gcf() @@ -508,11 +479,10 @@ See [`plot_pair`](@ref) ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") data = load_example_data("centered_eight") -figure() #hide -ax = plot_parallel(data; var_names=[:theta, :tau, :mu]) +ax = plot_parallel(data; var_names=["theta", "tau", "mu"]) ax.set_xticklabels(ax.get_xticklabels(); rotation=70) draw() gcf() @@ -525,12 +495,11 @@ See [`plot_parallel`](@ref) ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") data = load_example_data("centered_eight") -coords = Dict(:school => ["Choate"]) -figure() #hide -plot_posterior(data; var_names=[:mu, :theta], coords, rope=(-1, 1)) +coords = Dict("school" => ["Choate"]) +plot_posterior(data; var_names=["mu", "theta"], coords, rope=(-1, 1)) gcf() ``` @@ -541,11 +510,10 @@ See [`plot_posterior`](@ref) ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") data = load_example_data("non_centered_eight") -figure() #hide -plot_ppc(data; data_pairs=Dict(:obs => :obs), alpha=0.03, figsize=(12, 6), textsize=14) +plot_ppc(data; data_pairs=Dict("obs" => "obs"), alpha=0.03, figsize=(12, 6), textsize=14) gcf() ``` @@ -556,11 +524,10 @@ See [`plot_ppc`](@ref) ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") data = load_example_data("non_centered_eight") -figure() #hide -plot_ppc(data; alpha=0.3, kind=:cumulative, figsize=(12, 6), textsize=14) +plot_ppc(data; alpha=0.3, kind="cumulative", figsize=(12, 6), textsize=14) gcf() ``` @@ -571,11 +538,10 @@ See [`plot_ppc`](@ref) ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") data = load_example_data("centered_eight") -figure() #hide -plot_rank(data; var_names=[:tau, :mu]) +plot_rank(data; var_names=["tau", "mu"]) gcf() ``` @@ -586,11 +552,10 @@ See [`plot_rank`](@ref) ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") data = load_example_data("classification10d") -figure() #hide -plot_separation(data; y=:outcome, y_hat=:outcome, figsize=(8, 1)) +plot_separation(data; y="outcome", y_hat="outcome", figsize=(8, 1)) gcf() ``` @@ -601,11 +566,10 @@ See [`plot_separation`](@ref) ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") data = load_example_data("non_centered_eight") -figure() #hide -plot_trace(data; var_names=[:tau, :mu]) +plot_trace(data; var_names=["tau", "mu"]) gcf() ``` @@ -616,11 +580,10 @@ See [`plot_trace`](@ref) ```@example using ArviZPythonPlots, ArviZExampleData -ArviZPythonPlots.use_style("arviz-darkgrid") +use_style("arviz-darkgrid") data = load_example_data("non_centered_eight") -figure() #hide -plot_violin(data; var_names=[:mu, :tau]) +plot_violin(data; var_names=["mu", "tau"]) gcf() ``` @@ -629,9 +592,7 @@ See [`plot_violin`](@ref) ## Styles ```@example -using PyCall -using Distributions -using ArviZPythonPlots +using ArviZPythonPlots, Distributions, PythonCall x = range(0, 1; length=100) dist = pdf.(Beta(2, 5), x) @@ -646,7 +607,7 @@ style_list = [ fig = figure(; figsize=(12, 12)) for (idx, style) in enumerate(style_list) - @pywith plt.style.context(style) begin + pywith(pyplot.style.context(style)) do _ ax = fig.add_subplot(3, 2, idx; label=idx) for i in 0:9 ax.plot(x, dist .- i, "C$i"; label="C$i") diff --git a/docs/src/index.md b/docs/src/index.md index 40f5578..6c844a2 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,32 +1,16 @@ # ArviZPythonPlots.jl ArviZPythonPlots.jl provides PyPlot-compatible plotting functions for exploratory analysis of Bayesian models using [ArviZ.jl](https://julia.arviz.org/). -It provides an interface to use the plotting functions in [Python ArviZ](https://python.arviz.org/) with Julia types. -It also re-exports all methods exported by both ArviZ.jl and [PyPlot.jl](https://github.com/JuliaPy/PyPlot.jl). +It uses [PythonCall.jl](https://github.com/cjdoris/PythonCall.jl) to provide an interface for using the plotting functions in [Python ArviZ](https://python.arviz.org/) with Julia types. +It also re-exports all methods exported by [PythonPlot.jl](https://github.com/JuliaPy/PythonPlot.jl). For details, see the [Example Gallery](@ref) or the [API](@ref api). ## [Installation](@id installation) To install ArviZPythonPlots.jl, we first need to install Python ArviZ. -To use with the default Python environment, first [install Python ArviZ](https://python.arviz.org/en/latest/getting_started/Installation.html). From the Julia REPL, type `]` to enter the Pkg REPL mode and run ``` pkg> add ArviZPythonPlots ``` - -To install ArviZPythonPlots.jl with its Python dependencies in Julia's private conda environment, in the console run - -```console -PYTHON="" julia -e 'using Pkg; Pkg.add("PyCall"); Pkg.build("PyCall"); Pkg.add("ArviZPythonPlots")' -``` - -For specifying other Python versions, see the [PyCall documentation](https://github.com/JuliaPy/PyCall.jl). - -## [Known Issues](@id knownissues) - -ArviZPythonPlots.jl uses [PyCall.jl](https://github.com/JuliaPy/PyCall.jl) to wrap Python ArviZ. -At the moment, Julia segfaults if Numba is imported, which Python ArviZ does if it is available. -For the moment, the workaround is to [specify a Python version](https://github.com/JuliaPy/PyCall.jl#specifying-the-python-version) that doesn't have Numba installed. -See [this issue](https://github.com/JuliaPy/PyCall.jl/issues/220) for more details. diff --git a/src/ArviZPythonPlots.jl b/src/ArviZPythonPlots.jl index 14ad30d..3249b4d 100644 --- a/src/ArviZPythonPlots.jl +++ b/src/ArviZPythonPlots.jl @@ -1,21 +1,17 @@ module ArviZPythonPlots using Base: @__doc__ -using DataFrames +using ArviZ +using DimensionalData: DimensionalData, Dimensions using OrderedCollections: OrderedDict - +using PythonCall using Reexport -@reexport using ArviZ -@reexport using PyPlot -using PyCall -using Conda -using DimensionalData: DimensionalData, Dimensions - -import Base.Docs: getdoc -import Markdown: @doc_str +using Tables # Exports +@reexport using PythonPlot + ## Plots export plot_autocorr, plot_bpv, @@ -42,28 +38,26 @@ export plot_autocorr, plot_violin ## rcParams -export rcParams, with_rc_context +export rcParams, rc_context ## styles export styles, use_style -const _min_arviz_version = v"0.13.0" -const arviz = PyNULL() -const xarray = PyNULL() -const pandas = PyNULL() -const _rcParams = PyNULL() - -include("setup.jl") - -# Load ArviZ once at precompilation time for docstringS -copy!(arviz, import_arviz()) -check_needs_update(; update=false) -const _precompile_arviz_version = arviz_version() +const arviz = PythonCall.pynew() +const xarray = PythonCall.pynew() +const pandas = PythonCall.pynew() function __init__() - return initialize_arviz() + PythonCall.pycopy!(arviz, pyimport("arviz")) + PythonCall.pycopy!(xarray, pyimport("xarray")) + PythonCall.pycopy!(pandas, pyimport("pandas")) + PythonCall.pycopy!(rcParams, arviz.rcParams) + # use 1-based indexing in plots + rcParams["data.index_origin"] = 1 + return nothing end +include("lazyhelp.jl") include("utils.jl") include("rcparams.jl") include("style.jl") diff --git a/src/conversions.jl b/src/conversions.jl index 7519ea5..65ed192 100644 --- a/src/conversions.jl +++ b/src/conversions.jl @@ -1,9 +1,9 @@ -function topandas(::Val{:ELPDData}, d::PSISLOOResult) +function PythonCall.Py(d::PSISLOOResult) estimates = elpd_estimates(d) pointwise = elpd_estimates(d; pointwise=true) psis_result = d.psis_result ds = convert_to_dataset((loo_i=pointwise.elpd, pareto_shape=pointwise.pareto_shape)) - pyds = PyCall.PyObject(ds) + pyds = PythonCall.Py(ds) entries = ( elpd_loo=estimates.elpd, se=estimates.elpd_mcse, @@ -15,16 +15,16 @@ function topandas(::Val{:ELPDData}, d::PSISLOOResult) pareto_k=pyds.pareto_shape, scale="log", ) - return PyCall.pycall( - arviz.stats.ELPDData, PyCall.PyObject; data=values(entries), index=keys(entries) - ) + data = pylist(values(entries)) + index = pylist(map(pystr, keys(entries))) + return arviz.stats.ELPDData(; data, index) end -function topandas(::Val{:ELPDData}, d::WAICResult) +function PythonCall.Py(d::WAICResult) estimates = elpd_estimates(d) pointwise = elpd_estimates(d; pointwise=true) ds = convert_to_dataset((waic_i=pointwise.elpd,)) - pyds = PyCall.PyObject(ds) + pyds = PythonCall.Py(ds) entries = ( elpd_waic=estimates.elpd, se=estimates.elpd_mcse, @@ -35,7 +35,27 @@ function topandas(::Val{:ELPDData}, d::WAICResult) waic_i=pyds.waic_i, scale="log", ) - return PyCall.pycall( - arviz.stats.ELPDData, PyCall.PyObject; data=values(entries), index=keys(entries) - ) + data = pylist(values(entries)) + index = pylist(map(pystr, keys(entries))) + return arviz.stats.ELPDData(; data, index) +end + +function rekey(nt::NamedTuple, old_new_keys::Pair...) + keys_new = replace(keys(nt), old_new_keys...) + return NamedTuple{keys_new}(values(nt)) +end + +function PythonCall.Py(mc::ModelComparisonResult) + table = Tables.columntable(mc) + se_pairs = (:elpd_mcse => :se, :elpd_diff_mcse => :dse) + est_pairs = if eltype(mc.elpd_result) <: PSISLOOResult + (:elpd => :elpd_loo, :p => :p_loo) + elseif eltype(mc.elpd_result) <: WAICResult + (:elpd => :elpd_waic, :p => :p_waic) + end + nrows = Tables.rowcount(table) + new_cols = (warning=fill(false, nrows), scale=fill("log", nrows)) + table_new = merge(rekey(table, est_pairs..., se_pairs...), new_cols) + pdf = topandas(Val(:DataFrame), table_new; index_name="name") + return pdf end diff --git a/src/lazyhelp.jl b/src/lazyhelp.jl new file mode 100644 index 0000000..707a517 --- /dev/null +++ b/src/lazyhelp.jl @@ -0,0 +1,34 @@ +# adapted from https://github.com/JuliaPy/PythonPlot.jl +# MIT License +# Copyright © 2013 by Steven G. Johnson + +struct LazyHelp + o # a Py or similar object supporting getindex with a __doc__ property + keys::Tuple{Vararg{String}} + LazyHelp(o) = new(o, ()) + LazyHelp(o, k::AbstractString) = new(o, (k,)) + LazyHelp(o, k1::AbstractString, k2::AbstractString) = new(o, (k1, k2)) + LazyHelp(o, k::Tuple{Vararg{AbstractString}}) = new(o, k) +end + +function Base.show(io::IO, ::MIME"text/plain", h::LazyHelp) + o = h.o + for k in h.keys + o = pygetattr(o, k) + end + if pyhasattr(o, "__doc__") + print(io, pyconvert(String, o.__doc__)) + else + print(io, "no Python docstring found for ", o) + end +end + +Base.show(io::IO, h::LazyHelp) = Base.show(io, "text/plain", h) + +function Base.Docs.catdoc(hs::LazyHelp...) + Base.Docs.Text() do io + for h in hs + Base.show(io, MIME"text/plain"(), h) + end + end +end diff --git a/src/plots.jl b/src/plots.jl index f93505d..801d4db 100644 --- a/src/plots.jl +++ b/src/plots.jl @@ -23,42 +23,13 @@ @forwardplotfun plot_trace @forwardplotfun plot_violin -function convert_arguments(::typeof(plot_kde), values, args...; values2=nothing, kwargs...) - if values2 === nothing - kwargs_new = NamedTuple(kwargs) - else - kwargs_new = (; values2=convert(Array, values2), kwargs...) - end - return tuple(convert(Array, values), args...), kwargs_new -end - -function convert_arguments( - ::typeof(plot_compare), mc::ModelComparisonResult, args...; kwargs... -) - df = DataFrame(mc) - rename!(df, :elpd_mcse => :se, :elpd_diff_mcse => :dse) - if eltype(mc.elpd_result) <: PSISLOOResult - rename!(df, :elpd => :elpd_loo, :p => :p_loo) - elseif eltype(mc.elpd_result) <: WAICResult - rename!(df, :elpd => :elpd_waic, :p => :p_waic) - end - df.warning = map(_ -> false, df.name) - df.scale = map(_ -> "log", df.name) - pdf = topandas(Val(:DataFrame), df; index_name=:name) - return tuple(pdf, args...), kwargs -end - function convert_arguments(::typeof(plot_elpd), data, args...; kwargs...) - dict = Dict(k => try - topandas(Val(:ELPDData), v) - catch - convert_to_inference_data(v) - end for (k, v) in pairs(data)) + dict = OrderedDict( + k => v isa AbstractELPDResult ? v : convert_to_inference_data(v) for + (k, v) in pairs(data) + ) return tuple(dict, args...), kwargs end -function convert_arguments(::typeof(plot_khat), df, args...; kwargs...) - return tuple(topandas(Val(:ELPDData), df), args...), kwargs -end for f in ( :plot_autocorr, @@ -67,6 +38,7 @@ for f in ( :plot_pair, :plot_parallel, :plot_posterior, + :plot_trace, :plot_violin, ) @eval begin @@ -77,29 +49,6 @@ for f in ( end end -function convert_arguments(::typeof(plot_trace), data, args...; kwargs...) - idata = convert_to_inference_data(data; group=:posterior) - # temporary workaround for https://github.com/arviz-devs/arviz/issues/2150 - if arviz_version() ≤ v"0.13.0" && - hasgroup(idata, :sample_stats) && - haskey(idata.sample_stats, :diverging) - sample_dims = Dimensions.key2dim((:chain, :draw)) - diverging = permutedims(idata.sample_stats.diverging, sample_dims) - sample_stats = merge(idata.sample_stats, (; diverging)) - idata = merge(idata, InferenceData(; sample_stats)) - end - return tuple(idata, args...), kwargs -end - -for f in (:plot_autocorr, :plot_ess, :plot_mcse, :plot_posterior, :plot_violin) - @eval begin - function convert_arguments(::typeof($(f)), data::AbstractArray, args...; kwargs...) - idata = convert_to_inference_data(data; group=:posterior) - return tuple(idata, args...), kwargs - end - end -end - function convert_arguments(::typeof(plot_energy), data, args...; kwargs...) dataset = convert_to_dataset(data; group=:sample_stats) return tuple(dataset, args...), kwargs diff --git a/src/rcparams.jl b/src/rcparams.jl index 931f879..a177edf 100644 --- a/src/rcparams.jl +++ b/src/rcparams.jl @@ -1,102 +1,5 @@ -@doc doc""" - rcParams +@doc LazyHelp(arviz, "rcParams") const rcParams = PythonCall.pynew() -Dictionary to contain Python ArviZ default parameters, with validation when setting items. - -Note that only Python code will use these parameters, so in general only the ones used by -plotting functions have an effect. -""" -rcParams - -struct RcParams{K,V} <: AbstractDict{K,V} - o::PyObject -end -RcParams(obj) = RcParams{Any,Any}(obj) - -@inline PyObject(r::RcParams) = getfield(r, :o) - -Base.convert(::Type{RcParams{K,V}}, obj::PyObject) where {K,V} = RcParams{K,V}(obj) -Base.convert(::Type{RcParams}, obj::PyObject) = RcParams(obj) - -const rcParams = RcParams{String,Any}(_rcParams) - -@inline Base.length(r::RcParams) = py"len"(PyObject(r)) -function Base.get(r::RcParams, k, default) - haskey(r, k) && return PyObject(r).__getitem__(k) - return default -end -function Base.setindex!(r::RcParams, v, k) - try - PyObject(r).__setitem__(k, v) - catch e - if e isa PyCall.PyError - err = e.val - if pyisinstance(err, py"ValueError") - throw(ErrorException(err.args[1])) - elseif pyisinstance(err, py"KeyError") - throw( - KeyError( - "$(k) is not a valid rc parameter (see keys(rcParams) for a list of valid parameters)", - ), - ) - end - end - throw(e) - end - return r -end - -@inline Base.haskey(r::RcParams, k) = PyObject(r).__contains__(k) - -function Base.iterate(r::RcParams, it) - return try - pair = Pair(py"next"(it)...) - (pair, it) - catch - nothing - end -end -function Base.iterate(r::RcParams) - items = PyObject(r).items() - it = py"iter"(items) - return Base.iterate(r, it) -end - -@doc doc""" - with_rc_context(f; rc = nothing, fname = nothing) - -Execute the thunk `f` within a context controlled by temporary rc params. - -See [`rcParams`](@ref) for supported params or to modify params long-term. - -# Examples - -```julia -using ArviZExampleData -with_rc_context(fname = "pystan.rc") do - idata = load_example_data("radon") - plot_posterior(idata; var_names=["gamma"]) -end -``` - -The plot would have settings from `pystan.rc`. - -A dictionary can also be passed to the context manager: - -```julia -with_rc_context(rc = Dict("plot.max_subplots" => 1), fname = "pystan.rc") do - idata = load_example_data("radon") - plot_posterior(idata, var_names=["gamma"]) -end -``` - -The `rc` dictionary takes precedence over the settings loaded from `fname`. Passing a -dictionary only is also valid. -""" -with_rc_context - -function with_rc_context(f; kwargs...) - return @pywith arviz.rc_context(; kwargs...) as _ begin - return f() - end +@doc LazyHelp(arviz, "rc_context") function rc_context(args...; kwargs...) + return arviz.rc_context(args...; kwargs...) end diff --git a/src/setup.jl b/src/setup.jl deleted file mode 100644 index b129e0e..0000000 --- a/src/setup.jl +++ /dev/null @@ -1,135 +0,0 @@ -import_arviz() = _import_dependency("arviz", "arviz"; channel="conda-forge") - -function arviz_version() - v = arviz.__version__ - try - VersionNumber(v) - catch ArgumentError - v, suff = splitext(v) - return VersionNumber(v * suff[2:end]) - end -end - -function check_needs_update(; update=true) - if arviz_version() < _min_arviz_version - @warn "ArviZ.jl only officially supports arviz version $(_min_arviz_version) or " * - "greater but found version $(arviz_version())." - if update - if update_arviz() - # yay, but we still already imported the old version - msg = """ - Please rebuild ArviZ.jl with `using Pkg; Pkg.build("ArviZ")` and re-launch Julia - to continue. - """ - else - msg = """ - Could not automatically update arviz. Please manually update arviz, rebuild - ArviZ.jl with `using Pkg; Pkg.build("ArviZ")`, and then re-launch Julia to - continue. - """ - end - @warn msg - end - end - return nothing -end - -function check_needs_rebuild() - if arviz_version() != _precompile_arviz_version - msg = """ - ArviZ.jl was built using arviz version $(_precompile_arviz_version) but loaded with - version $(arviz_version()). Please recompile with `using Pkg; Pkg.build("ArviZ")` - and re-launch Julia to continue. - """ - @warn msg - end - return nothing -end - -function initialize_arviz() - ispynull(arviz) || return nothing - copy!(arviz, import_arviz()) - check_needs_update(; update=true) - check_needs_rebuild() - - pytype_mapping(arviz.InferenceData, InferenceData) - - # pytypemap-ing RcParams produces a Dict - copy!(_rcParams, py"$(arviz).rcparams.rcParams"o) - - # use 1-based indexing by default within arviz - rcParams["data.index_origin"] = 1 - - initialize_xarray() - initialize_numpy() - return nothing -end - -function initialize_xarray() - ispynull(xarray) || return nothing - copy!(xarray, _import_dependency("xarray", "xarray"; channel="conda-forge")) - _import_dependency("dask", "dask"; channel="conda-forge") - return nothing -end - -function initialize_numpy() - # Trigger NumPy initialization, see https://github.com/JuliaPy/PyCall.jl/issues/744 - PyObject([true]) - return nothing -end - -function initialize_pandas() - ispynull(pandas) || return nothing - copy!(pandas, _import_dependency("pandas", "pandas"; channel="conda-forge")) - return nothing -end - -function update_arviz() - # updating arviz can change other packages, so we always ask for permission - if _using_conda() && _isyes(Base.prompt("Try updating arviz using conda? [Y/n]")) - # this syntax isn't officially supported, but it works (for now) - try - Conda.add("arviz>=$_min_arviz_version"; channel="conda-forge") - return true - catch e - println(e.msg) - end - end - if _has_pip() && _isyes(Base.prompt("Try updating arviz using pip? [Y/n]")) - # can't specify version lower bound, so update to latest - try - run(PyCall.python_cmd(`-m pip install --upgrade -- arviz`)) - return true - catch e - println(e.msg) - end - end - return false -end - -function _import_dependency(modulename, pkgname=modulename; channel=nothing) - try - return if channel === nothing - pyimport_conda(modulename, pkgname) - else - pyimport_conda(modulename, pkgname, channel) - end - catch e - if _has_pip() && _isyes(Base.prompt("Try installing $pkgname using pip? [Y/n]")) - # installing with pip is riskier, so we ask for permission - run(PyCall.python_cmd(`-m pip install -- $pkgname`)) - return pyimport(modulename) - end - # PyCall has a nice error message - throw(e) - end -end - -_isyes(s) = isempty(s) || lowercase(strip(s)) ∈ ("y", "yes") -_isyes(::Nothing) = true - -_using_conda() = PyCall.conda - -_has_pip() = _has_pymodule("pip") - -_has_pymodule(modulename) = !ispynull(pyimport_e(modulename)) diff --git a/src/style.jl b/src/style.jl index 1dd5967..a29d232 100644 --- a/src/style.jl +++ b/src/style.jl @@ -12,11 +12,14 @@ To see all available style specifications, use [`styles()`](@ref). If a `Vector` of styles is provided, they are applied from first to last. """ -use_style(style) = plt.style.use(style) +function use_style(style) + pyplot.style.use(style) + return nothing +end """ styles() -> Vector{String} Get all available matplotlib styles for use with [`use_style`](@ref) """ -styles() = plt.style.available +styles() = pyconvert(Vector{String}, pyplot.style.available) diff --git a/src/utils.jl b/src/utils.jl index 0db5f18..c57d627 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -9,88 +9,22 @@ to arviz. convert_arguments(::Any, args...; kwargs...) = args, kwargs """ - convert_result(f, result, args...) - -Convert result of the function `f` before returning. - -This function is used primarily for post-processing outputs of arviz before returning. -The `args` are primarily used for dispatch. -""" -convert_result(f, result, args...) = result - -function forwarddoc(f::Symbol) - pydoc = "$(Docs.getdoc(getproperty(arviz, f)))" - pydoc_sections = split(pydoc, '\n'; limit=2) - if length(pydoc_sections) > 1 - summary, body = pydoc_sections - summary *= "\n" - else - summary = "" - body = pydoc - end - return """ - $summary - - !!! note - This function is forwarded to Python's [`arviz.$(f)`](https://python.arviz.org/en/v$(arviz_version())/api/generated/arviz.$(f).html). - The docstring of that function is included below. - ``` - $body - ``` - """ -end - -forwardgetdoc(f::Symbol) = Docs.getdoc(getproperty(arviz, f)) - -""" - @forwardfun f [forward_docs] - @forwardfun(f, forward_docs=true) - -Wrap a function `arviz.f` in `f`, forwarding its docstrings. - -Use [`convert_arguments`](@ref) and [`convert_result`](@ref) to customize what is passed to -and returned from `f`. -""" -macro forwardfun(f, forward_docs=true) - fesc = esc(f) - fdoc = forwarddoc(f) - ex = quote - if $forward_docs - @doc $fdoc $f - end - - function $(fesc)(args...; kwargs...) - args, kwargs = convert_arguments($(fesc), args...; kwargs...) - result = arviz.$(f)(args...; kwargs...) - return convert_result($(fesc), result) - end - end - # make sure line number of methods are place where macro is called, not here - _replace_line_number!(ex, __source__) - return ex -end - -""" - @forwardplotfun f [forward_docs] - @forwardplotfun(f, forward_docs=true) + @forwardplotfun f Wrap a plotting function `arviz.f` in `f`, forwarding its docstrings. -Use [`convert_arguments`](@ref) and [`convert_result`](@ref) to customize what is passed to -and returned from `f`. +Use [`convert_arguments`](@ref) to customize what is passed to `f`. """ -macro forwardplotfun(f, forward_docs=true) +macro forwardplotfun(f) fesc = esc(f) - fdoc = forwarddoc(f) + sf = string(f) ex = quote - if $forward_docs - @doc $fdoc $f - end - - function $(fesc)(args...; kwargs...) + @doc LazyHelp(arviz, $sf) function $(fesc)(args...; kwargs...) args, kwargs = convert_arguments($(fesc), args...; kwargs...) - result = arviz.$(f)(args...; kwargs..., backend="matplotlib") - return convert_result($(fesc), result) + pyargs = Iterators.map(topytype, args) + pykwargs = (k => topytype(v) for (k, v) in pairs(kwargs)) + result = arviz.$(f)(pyargs...; pykwargs..., backend="matplotlib") + return result end end # make sure line number of methods are place where macro is called, not here @@ -108,71 +42,31 @@ function _replace_line_number!(ex, source) end end -# Convert python types to Julia types if possible -@inline frompytype(x) = x -@inline frompytype(x::PyObject) = PyAny(x) -frompytype(x::AbstractArray{PyObject}) = map(frompytype, x) -frompytype(x::AbstractArray{Any}) = map(frompytype, x) -frompytype(x::AbstractArray{<:AbstractArray}) = map(frompytype, x) - -""" - todataframes(df; index_name = nothing) -> DataFrames.DataFrame - -Convert a Python `pandas.DataFrame` or `pandas.Series` into a `DataFrames.DataFrame`. - -If `index_name` is not `nothing`, the index is converted into a column with `index_name`. -Otherwise, it is discarded. -""" -function todataframes(::Val{:DataFrame}, df::PyObject; index_name=nothing) - initialize_pandas() - col_vals = map(df.columns) do name - series = py"$(df)[$(name)]" - vals = series.values - return Symbol(name) => frompytype(vals) - end - if index_name !== nothing - index_vals = frompytype(df.index.values) - col_vals = [Symbol(index_name) => index_vals; col_vals] - end - return DataFrames.DataFrame(col_vals) -end -function todataframes(::Val{:Series}, series::PyObject; kwargs...) - initialize_pandas() - colnames = map(i -> Symbol(frompytype(i)), series.index) - colvals = map(x -> [frompytype(x)], series.values) - return DataFrames.DataFrame(colvals, colnames) -end -function todataframes(df::PyObject; kwargs...) - initialize_pandas() - if pyisinstance(df, pandas.Series) - return todataframes(Val(:Series), df; kwargs...) - end - return todataframes(Val(:DataFrame), df; kwargs...) -end +# Convert Julia types to suitable Python types +topytype(x::AbstractVector) = pylist(map(topytype, x)) +topytype(x::AbstractVector{<:Real}) = Py(x).to_numpy() +topytype(x::AbstractUnitRange{<:Integer}) = topytype(collect(x)) +topytype(x::AbstractArray{<:Real}) = Py(x).to_numpy() +topytype(x::Tuple) = pytuple(map(topytype, x)) +topytype(x::AbstractDict) = pydict(topytype(k) => topytype(v) for (k, v) in pairs(x)) +topytype(x::NamedTuple) = topytype(pairs(x)) +topytype(x::Symbol) = pystr(x) +topytype(::Missing) = Py(NaN) +topytype(x) = Py(x) """ - topandas(::Type{:DataFrame}, df; index_name = nothing) -> PyObject - topandas(::Type{:Series}, df) -> PyObject - topandas(::Val{:ELPDData}, df) -> PyObject + topandas(::Type{:DataFrame}, table; index_name = nothing) -> Py -Convert a `DataFrames.DataFrame` to the specified pandas type. +Convert a Tables-compatible table to the specified pandas type. If `index_name` is not `nothing`, the corresponding column is made the index of the returned dataframe. """ -function topandas(::Val{:DataFrame}, df; index_name=nothing) - initialize_pandas() - df = DataFrames.DataFrame(df) - colnames = names(df) - rowvals = map(Array, eachrow(df)) +function topandas(::Val{:DataFrame}, table; index_name=nothing) + # initialize_pandas() + colnames = topytype(Tables.columnnames(table)) + rowvals = map(topytype ∘ values, Tables.namedtupleiterator(table)) pdf = pandas.DataFrame(rowvals; columns=colnames) - index_name !== nothing && pdf.set_index(index_name; inplace=true) + index_name !== nothing && pdf.set_index(topytype(index_name); inplace=true) return pdf end -function topandas(::Val{:Series}, df) - initialize_pandas() - df = DataFrames.DataFrame(df) - rownames = names(df) - colvals = Array(only(eachrow(df))) - return pandas.Series(colvals, rownames) -end diff --git a/src/xarray.jl b/src/xarray.jl index 569de39..8365d52 100644 --- a/src/xarray.jl +++ b/src/xarray.jl @@ -1,102 +1,33 @@ -PyCall.PyObject(data::Dataset) = _to_xarray(data) +PythonCall.Py(data::Dataset) = _to_xarray(data) -Base.convert(::Type{Dataset}, obj::PyObject) = Dataset(_dimstack_from_xarray(obj)) - -function PyCall.PyObject(data::InferenceData) +function PythonCall.Py(data::InferenceData) groups = NamedTuple(data) - return pycall(arviz.InferenceData, PyObject; map(PyObject, groups)...) -end - -function ArviZ.convert_to_inference_data( - obj::PyObject; dims=nothing, coords=nothing, kwargs... -) - if pyisinstance(obj, arviz.InferenceData) - group_names = obj.groups() - groups = ( - Symbol(name) => convert(Dataset, getindex(obj, name)) for name in group_names - ) - return InferenceData(; groups...) - else - # Python ArviZ requires dims and coords be dicts matching to vectors - pydims = dims === nothing ? dims : Dict(k -> collect(dims[k]) for k in keys(dims)) - pycoords = - dims === nothing ? dims : Dict(k -> collect(coords[k]) for k in keys(coords)) - return arviz.convert_to_inference_data(obj; dims=pydims, coords=pycoords, kwargs...) - end -end - -function _dimstack_from_xarray(o::PyObject) - pyisinstance(o, xarray.Dataset) || - throw(ArgumentError("argument is not an `xarray.Dataset`.")) - var_names = collect(o.data_vars) - data = [_dimarray_from_xarray(getindex(o, name)) for name in var_names] - metadata = OrderedDict{Symbol,Any}(Symbol(k) => v for (k, v) in o.attrs) - return DimensionalData.DimStack(data...; metadata) -end - -function _dimarray_from_xarray(o::PyObject) - pyisinstance(o, xarray.DataArray) || - throw(ArgumentError("argument is not an `xarray.DataArray`.")) - name = Symbol(o.name) - data = _process_pyarray(o.to_numpy()) - coords = PyCall.PyDict(o.coords) - dims = Tuple( - map(d -> _wrap_dims(Symbol(d), _process_pyarray(coords[d].values)), o.dims) - ) - attrs = OrderedDict{Symbol,Any}(Symbol(k) => v for (k, v) in o.attrs) - metadata = isempty(attrs) ? DimensionalData.NoMetadata() : attrs - return DimensionalData.DimArray(data, dims; name, metadata) + return arviz.InferenceData(; map(topytype, groups)...) end -_process_pyarray(x) = x -# NOTE: sometimes strings fail to convert to Julia types, so we try to force them here -function _process_pyarray(x::Union{PyObject,<:AbstractVector{PyObject}}) - return map(z -> z isa PyObject ? PyAny(z)::Any : z, x) -end - -# wrap dims in a `Dim`, converting to an AbstractRange if possible -function _wrap_dims(name::Symbol, dims::AbstractVector{<:Real}) - D = DimensionalData.Dim{name} - start = dims[begin] - stop = dims[end] - n = length(dims) - step = (stop - start) / (n - 1) - isrange = all(Iterators.drop(eachindex(dims), 1)) do i - return (dims[i] - dims[i - 1]) ≈ step - end - return if isrange - if step == 1 - D(UnitRange(start, stop)) - else - D(range(start, stop; length=n)) - end - else - D(dims) - end -end -_wrap_dims(name::Symbol, dims::AbstractVector) = DimensionalData.Dim{name}(dims) - function _to_xarray(data::DimensionalData.AbstractDimStack) - data_vars = Dict(pairs(map(_to_xarray, DimensionalData.layers(data)))) - attrs = Dict(pairs(DimensionalData.metadata(data))) - return PyCall.pycall(xarray.Dataset, PyObject, data_vars; attrs) + data_vars = map(_to_xarray, DimensionalData.layers(data)) + attrs = pairs(DimensionalData.metadata(data)) + return xarray.Dataset(topytype(data_vars); attrs=topytype(attrs)) end function _to_xarray(data::DimensionalData.AbstractDimArray) var_name = DimensionalData.name(data) data_dims = DimensionalData.dims(data) - dims = collect(DimensionalData.name(data_dims)) + dims = DimensionalData.name(data_dims) coords = Dict(zip(dims, DimensionalData.index(data_dims))) - default_dims = String[] + default_dims = () values = parent(data) if Missing <: eltype(values) - # passing `missing` to Python causes the array to have a `PyCall.jlwrap` dtype + # passing `missing` to Python causes the array to have a `PythonCall.jlwrap` dtype values = replace(values, missing => NaN) end - metadata = DimensionalData.metadata(data) - da = arviz.numpy_to_data_array(values; var_name, dims, coords, default_dims) - if !isempty(metadata) - da.attrs = metadata - end - return da + metadata = pairs(DimensionalData.metadata(data)) + kwargs = (; var_name, dims, coords, default_dims) + pykwargs = map(topytype, kwargs) + return da = arviz.numpy_to_data_array(topytype(values); pykwargs...) + # if !isempty(metadata) + # da.attrs = metadata + # end + # return da end diff --git a/test/helpers.jl b/test/helpers.jl index ee5660a..549093f 100644 --- a/test/helpers.jl +++ b/test/helpers.jl @@ -1,6 +1,7 @@ +using ArviZ using ArviZPythonPlots using Random -using PyCall +using Test function random_dim_array(var_name, dims, coords, default_dims=()) _dims = (default_dims..., dims...) diff --git a/test/runtests.jl b/test/runtests.jl index 9eff936..244c26b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,9 @@ using Test @testset "ArviZPythonPlots" begin include("helpers.jl") include("test_rcparams.jl") + include("test_style.jl") include("test_utils.jl") include("test_conversions.jl") + include("test_xarray.jl") include("test_plots.jl") end diff --git a/test/test_conversions.jl b/test/test_conversions.jl index cc4bbfa..584cee9 100644 --- a/test/test_conversions.jl +++ b/test/test_conversions.jl @@ -1,5 +1,7 @@ -using ArviZPythonPlots +using ArviZ using ArviZExampleData +using ArviZPythonPlots +using PythonCall using Test @testset "conversions" begin @@ -7,23 +9,38 @@ using Test idata = load_example_data("centered_eight") loo_result = loo(idata; reff=1) loo_py_result = ArviZPythonPlots.arviz.loo(idata; pointwise=true, reff=1) - py_loo_result = ArviZPythonPlots.topandas(Val(:ELPDData), loo_result) - @test all(py_loo_result.keys() == loo_py_result.keys()) - @test py_loo_result.elpd_loo ≈ loo_py_result.elpd_loo rtol = 1e-3 - @test py_loo_result.se ≈ loo_py_result.se rtol = 1e-1 - @test py_loo_result.p_loo ≈ loo_py_result.p_loo rtol = 1e-3 - @test py_loo_result.loo_i.values ≈ loo_py_result.loo_i.values rtol = 1e-3 - @test py_loo_result.pareto_k.values ≈ loo_py_result.pareto_k.values rtol = 1e-1 + py_loo_result = Py(loo_result) + @test all( + pyconvert(Array{String}, py_loo_result.keys()) == + pyconvert(Array{String}, loo_py_result.keys()), + ) + @test pyconvert(Float64, py_loo_result.elpd_loo) ≈ + pyconvert(Float64, loo_py_result.elpd_loo) rtol = 1e-3 + @test pyconvert(Float64, py_loo_result.se) ≈ pyconvert(Float64, loo_py_result.se) rtol = + 1e-1 + @test pyconvert(Float64, py_loo_result.p_loo) ≈ + pyconvert(Float64, loo_py_result.p_loo) rtol = 1e-3 + @test pyconvert(Array{Float64}, py_loo_result.loo_i.values) ≈ + pyconvert(Array{Float64}, loo_py_result.loo_i.values) rtol = 1e-3 + @test pyconvert(Array{Float64}, py_loo_result.pareto_k.values) ≈ + pyconvert(Array{Float64}, loo_py_result.pareto_k.values) rtol = 1e-1 end @testset "WAICResult" begin idata = load_example_data("centered_eight") waic_result = waic(idata) waic_py_result = ArviZPythonPlots.arviz.waic(idata; pointwise=true) - py_waic_result = ArviZPythonPlots.topandas(Val(:ELPDData), waic_result) - @test all(py_waic_result.keys() == waic_py_result.keys()) - @test py_waic_result.elpd_waic ≈ waic_py_result.elpd_waic rtol = 1e-3 - @test py_waic_result.se ≈ waic_py_result.se rtol = 1e-1 - @test py_waic_result.p_waic ≈ waic_py_result.p_waic rtol = 1e-3 - @test py_waic_result.waic_i.values ≈ waic_py_result.waic_i.values rtol = 1e-3 + py_waic_result = Py(waic_result) + @test all( + pyconvert(Array{String}, py_waic_result.keys()) == + pyconvert(Array{String}, waic_py_result.keys()), + ) + @test pyconvert(Float64, py_waic_result.elpd_waic) ≈ + pyconvert(Float64, waic_py_result.elpd_waic) rtol = 1e-3 + @test pyconvert(Float64, py_waic_result.se) ≈ pyconvert(Float64, waic_py_result.se) rtol = + 1e-1 + @test pyconvert(Float64, py_waic_result.p_waic) ≈ + pyconvert(Float64, waic_py_result.p_waic) rtol = 1e-3 + @test pyconvert(Array{Float64}, py_waic_result.waic_i.values) ≈ + pyconvert(Array{Float64}, waic_py_result.waic_i.values) rtol = 1e-3 end end diff --git a/test/test_plots.jl b/test/test_plots.jl index 54ec4c6..4e67e6e 100644 --- a/test/test_plots.jl +++ b/test/test_plots.jl @@ -1,6 +1,7 @@ -using ArviZPythonPlots +using ArviZ using ArviZExampleData -using PyCall +using ArviZPythonPlots +using PythonCall using Test @testset "plots" begin @@ -14,81 +15,81 @@ using Test @testset "$(f)" for f in (plot_trace, plot_pair) f(data; var_names=["tau", "mu"]) - close(gcf()) + plotclose() f((x=arr1, y=arr2); var_names=["x", "y"]) - close(gcf()) + plotclose() end @testset "$(f)" for f in (plot_autocorr, plot_ess, plot_mcse, plot_posterior, plot_violin) f(data; var_names=["tau", "mu"]) - close(gcf()) + plotclose() f(arr1) - close(gcf()) + plotclose() f((x=arr1, y=arr2); var_names=["x", "y"]) - close(gcf()) + plotclose() end @testset "$(f)" for f in (plot_energy, plot_parallel) f(data) - close(gcf()) + plotclose() end @testset "$(f)" for f in (plot_density, plot_forest) f(data; var_names=["tau", "mu"]) - close(gcf()) + plotclose() f([(x=arr1,), (x=arr2,)]; var_names=["x"]) - close(gcf()) + plotclose() f(arr3) - close(gcf()) + plotclose() f((x=arr1, y=arr2); var_names=["x", "y"]) - close(gcf()) + plotclose() end @testset "plot_bpv" begin plot_bpv(data) - close(gcf()) + plotclose() plot_bpv(data; kind="p_value") - close(gcf()) + plotclose() end @testset "plot_separation" begin data3 = load_example_data("classification10d") plot_separation(data3; y="outcome") - close(gcf()) + plotclose() end @testset "plot_rank" begin plot_rank(data; var_names=["tau", "mu"]) - close(gcf()) + plotclose() plot_rank(arr1) - close(gcf()) + plotclose() plot_rank((x=arr1, y=arr2); var_names=["x", "y"]) - close(gcf()) + plotclose() end @testset "plot_compare" begin mc = compare((a=data, b=data2)) plot_compare(mc) - close(gcf()) + plotclose() end @testset "plot_dist_compare" begin plot_dist_comparison(data; var_names=["mu"]) - close(gcf()) + plotclose() end @testset "$(f)" for f in (plot_dist, ArviZPythonPlots.plot_ecdf) f(arr1) - close(gcf()) + plotclose() end VERSION ≥ v"1.8" && @testset "plot_kde" begin plot_kde(arr1) - close(gcf()) + plotclose() plot_kde(arr1, arr2) - close(gcf()) + plotclose() end @testset "plot_hdi" begin @@ -96,29 +97,29 @@ using Test y_data = 2 .+ x_data .* 0.5 y_data_rep = 0.5 .* randn(rng, 200, 100) .+ transpose(y_data) plot_hdi(x_data, y_data_rep) - close(gcf()) + plotclose() end @testset "plot_elpd" begin plot_elpd(Dict("a" => data, "b" => data2)) - close(gcf()) + plotclose() plot_elpd(Dict("a" => loo(data), "b" => loo(data2))) - close(gcf()) + plotclose() end @testset "plot_khat" begin l = loo(data) plot_khat(l) - close(gcf()) + plotclose() end @testset "plot_loo_pit" begin plot_loo_pit(data; y="obs") - close(gcf()) + plotclose() end @testset "plot_loo_pit" begin plot_ppc(data) - close(gcf()) + plotclose() end end diff --git a/test/test_rcparams.jl b/test/test_rcparams.jl index a78a859..d77e6eb 100644 --- a/test/test_rcparams.jl +++ b/test/test_rcparams.jl @@ -1,44 +1,26 @@ using ArviZPythonPlots +using PythonCall using Test @testset "rcParams" begin @testset "rcParams" begin - @test rcParams isa ArviZPythonPlots.RcParams - @test pyisinstance(PyObject(rcParams), ArviZPythonPlots.arviz.rcparams.RcParams) - pyrcParams = ArviZPythonPlots.arviz.rcParams - @test rcParams == pyrcParams - @test ArviZPythonPlots.RcParams(pyrcParams) isa ArviZPythonPlots.RcParams{Any,Any} - @test isa( - convert(ArviZPythonPlots.RcParams{String,Union{Int64,String}}, pyrcParams), - ArviZPythonPlots.RcParams{String,Union{Int64,String}}, - ) - @test convert(ArviZPythonPlots.RcParams, pyrcParams) isa ArviZPythonPlots.RcParams - @test haskey(rcParams, "plot.backend") - def_backend = rcParams["plot.backend"] - @test ("plot.backend" => def_backend) ∈ rcParams - rcParams["plot.backend"] = "matplotlib" - @test rcParams["plot.backend"] == "matplotlib" - rcParams["plot.backend"] = "bokeh" - @test rcParams["plot.backend"] == "bokeh" - @test_throws KeyError rcParams["blah"] - @test_throws KeyError rcParams["blah"] = 3 - @test_throws ErrorException rcParams["plot.backend"] = "blah" - @test get(rcParams, "blah", "def") == "def" - @test Dict(map(p -> Pair(p...), zip(keys(rcParams), values(rcParams)))) == rcParams - rcParams["plot.backend"] = def_backend + @test rcParams isa Py + @test pyisinstance(rcParams, ArviZPythonPlots.arviz.rcparams.RcParams) + @test pyhasitem(rcParams, "plot.backend") end @testset "defaults" begin - @test rcParams["data.index_origin"] == 1 + @test pyconvert(Int, rcParams["data.index_origin"]) == 1 end - @testset "with_rc_context" begin + @testset "rc_context" begin def_backend = rcParams["plot.backend"] rcParams["plot.backend"] = "matplotlib" - with_rc_context(; rc=Dict("plot.backend" => "bokeh")) do - @test rcParams["plot.backend"] == "bokeh" + pywith(rc_context(; rc=Dict("plot.backend" => "bokeh"))) do _ + @test pyconvert(String, rcParams["plot.backend"]) == "bokeh" return nothing end + @test pyconvert(String, rcParams["plot.backend"]) == "matplotlib" rcParams["plot.backend"] = def_backend end end diff --git a/test/test_style.jl b/test/test_style.jl new file mode 100644 index 0000000..30d65a5 --- /dev/null +++ b/test/test_style.jl @@ -0,0 +1,17 @@ +using ArviZPythonPlots +using PythonCall +using Test + +@testset "style" begin + @testset "styles" begin + @test styles() isa Vector{String} + @test "arviz-darkgrid" ∈ styles() + @test styles() == + map(Base.Fix1(pyconvert, String), ArviZPythonPlots.arviz.style.available) + end + + @testset "use_style" begin + use_style("arviz-darkgrid") + use_style("default") + end +end diff --git a/test/test_utils.jl b/test/test_utils.jl index 5feb14f..3e1c8c5 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -1,76 +1,25 @@ using ArviZPythonPlots using DataFrames: DataFrames -using PyCall +using PythonCall using Test pandas = ArviZPythonPlots.pandas @testset "utils" begin - @testset "frompytype" begin - x = 1.0 - @test ArviZPythonPlots.frompytype(x) === x - x2 = PyObject(x) - @test ArviZPythonPlots.frompytype(x2) == x - @test ArviZPythonPlots.frompytype([x2]) == [x] - @test ArviZPythonPlots.frompytype(Any[x2]) == [x] - @test eltype(ArviZPythonPlots.frompytype(Any[x2])) <: Real - @test ArviZPythonPlots.frompytype([[x2]]) == [[x]] - end - @testset "topandas" begin - @testset "DataFrames.DataFrame -> pandas.DataFrame" begin + @testset "Table -> pandas.DataFrame" begin columns = [:a, :b, :c] index = ["d", "e"] rowvals = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] - df = DataFrames.DataFrame([ - :i => ["d", "e"], :a => [1.0, 4.0], :b => [2.0, 5.0], :c => [3.0, 6.0] - ]) - pdf = ArviZPythonPlots.topandas(Val(:DataFrame), df; index_name=:i) + table = (i=["d", "e"], a=[1.0, 4.0], b=[2.0, 5.0], c=[3.0, 6.0]) + pdf = ArviZPythonPlots.topandas(Val(:DataFrame), table; index_name=:i) @test pyisinstance(pdf, pandas.DataFrame) - pdf_exp = pandas.DataFrame(rowvals; columns, index) - @test py"($(pdf) == $(pdf_exp)).all().all()" - end - - @testset "DataFrames.DataFrame -> pandas.Series" begin - df2 = DataFrames.DataFrame([:a => [1.0], :b => [2.0], :c => [3.0]]) - ps = ArviZPythonPlots.topandas(Val(:Series), df2) - @test pyisinstance(ps, pandas.Series) - ps_exp = pandas.Series([1.0, 2.0, 3.0], [:a, :b, :c]) - @test py"($(ps) == $(ps_exp)).all()" - end - end - - @testset "todataframes" begin - @testset "pandas.DataFrame -> DataFrames.DataFrame" begin - columns = [:a, :b, :c] - index = ["d", "e"] - rowvals = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] - pdf = pandas.DataFrame(rowvals; columns, index) - df = ArviZPythonPlots.todataframes(pdf; index_name=:i) - @test df isa DataFrames.DataFrame - @test df == DataFrames.DataFrame([ - :i => ["d", "e"], :a => [1.0, 4.0], :b => [2.0, 5.0], :c => [3.0, 6.0] - ]) - @test df == ArviZPythonPlots.todataframes(pdf; index_name=:i) - end - - @testset "pandas.Series -> DataFrames.DataFrame" begin - ps = pandas.Series([1.0, 2.0, 3.0], [:a, :b, :c]) - df2 = ArviZPythonPlots.todataframes(ps) - @test df2 isa DataFrames.DataFrame - @test df2 == DataFrames.DataFrame([:a => [1.0], :b => [2.0], :c => [3.0]]) - @test df2 == ArviZPythonPlots.todataframes(ps) + pdf_exp = pandas.DataFrame( + Py(rowvals); + columns=pylist(map(pystr, columns)), + index=pylist(map(pystr, index)), + ) + @test pyconvert(Bool, pyall(pyeq(pdf, pdf_exp))) end end - - @testset "styles" begin - @test ArviZPythonPlots.styles() isa AbstractArray{String} - @test "arviz-darkgrid" ∈ ArviZPythonPlots.styles() - @test ArviZPythonPlots.styles() == ArviZPythonPlots.arviz.style.available - end - - @testset "use_style" begin - ArviZPythonPlots.use_style("arviz-darkgrid") - ArviZPythonPlots.use_style("default") - end end diff --git a/test/test_xarray.jl b/test/test_xarray.jl index b5edb04..5a76d25 100644 --- a/test/test_xarray.jl +++ b/test/test_xarray.jl @@ -1,10 +1,11 @@ +using ArviZ using ArviZPythonPlots using DimensionalData -using PyCall +using PythonCall using Test @testset "xarray interop" begin - @testset "Dataset <-> xarray" begin + @testset "Dataset -> xarray" begin nchains = 4 ndraws = 100 nshared = 3 @@ -14,58 +15,40 @@ using Test y = DimArray(randn(nchains, ndraws, 2, nshared), ydims) metadata = Dict(:prop1 => "val1", :prop2 => "val2") ds = Dataset((; x, y); metadata) - o = PyObject(ds) - @test o isa PyObject + o = Py(ds) + @test o isa Py @test pyisinstance(o, ArviZPythonPlots.xarray.Dataset) @test issetequal(Symbol.(o.coords.keys()), (:chain, :draw, :shared, :ydim1)) for (dim, coord) in o.coords.items() - @test collect(coord.values) == DimensionalData.index(ds, Symbol(dim)) + @test pyeq( + Bool, pylist(coord.values), pylist(DimensionalData.index(ds, Symbol(dim))) + ) end variables = Dict(collect(o.data_vars.variables.items())) - @test "x" ∈ keys(variables) - @test x == variables["x"].values - @test variables["x"].dims == String.(xdims) - - @test "y" ∈ keys(variables) - @test y == variables["y"].values - @test variables["y"].dims == ("chain", "draw", "ydim1", "shared") - - # check that the Python object accesses the underlying Julia array - x[1] = 1 - @test x == variables["x"].values - - ds2 = convert(Dataset, o) - @test ds2 isa Dataset - @test ds2.x ≈ ds.x - @test ds2.y ≈ ds.y - dims1 = sort(collect(DimensionalData.dims(ds)); by=DimensionalData.name) - dims2 = sort(collect(DimensionalData.dims(ds2)); by=DimensionalData.name) - for (dim1, dim2) in zip(dims1, dims2) - @test DimensionalData.name(dim1) === DimensionalData.name(dim2) - @test DimensionalData.index(dim1) == DimensionalData.index(dim2) - if DimensionalData.index(dim1) isa AbstractRange - @test DimensionalData.index(dim2) isa AbstractRange - end - end - @test DimensionalData.metadata(ds2) == DimensionalData.metadata(ds) + @test pystr("x") ∈ keys(variables) + @test Bool(pyeq(Py(x), variables[pystr("x")].values).all()) + @test Bool(variables[pystr("x")].dims == pytuple(pystr.(xdims))) + + @test pystr("y") ∈ keys(variables) + @test Bool(pyeq(Py(y), variables[pystr("y")].values).all()) + @test Bool( + variables[pystr("y")].dims == + pytuple(pystr.(("chain", "draw", "ydim1", "shared"))), + ) end - @testset "InferenceData <-> PyObject" begin - idata1 = random_data() - pyidata1 = PyObject(idata1) - @test pyidata1 isa PyObject - @test pyisinstance(pyidata1, ArviZPythonPlots.arviz.InferenceData) - idata2 = convert(InferenceData, pyidata1) - test_idata_approx_equal(idata2, idata1) - end - - @testset "convert_to_inference_data(obj::PyObject)" begin - data = Dict(:z => randn(4, 100, 10)) - idata1 = convert_to_inference_data(data) - idata2 = convert_to_inference_data(PyObject(data)) - @test idata2 isa InferenceData - @test idata2.posterior.z ≈ collect(idata1.posterior.z) + @testset "InferenceData -> Py" begin + idata = random_data() + pyidata = Py(idata) + @test pyidata isa Py + @test pyisinstance(pyidata, ArviZPythonPlots.arviz.InferenceData) + for group in keys(idata) + pyds = Py(idata[group]) + @test pyds isa Py + @test pyisinstance(pyds, ArviZPythonPlots.xarray.Dataset) + @test pyall(pyidata[pystr(group)] == pyds) + end end end