Skip to content
This repository was archived by the owner on Jan 27, 2025. It is now read-only.

Commit f363635

Browse files
committed
ENH: Add gaussian process DWI signal representation notebooks
Add gaussian process DWI signal representation notebooks: - One of the notebooks uses a simulated DWI signal. - The second notebook uses a real DWI signal.
1 parent 8c0bf36 commit f363635

File tree

2 files changed

+473
-0
lines changed

2 files changed

+473
-0
lines changed

docs/notebooks/dwi_gp.ipynb

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
{
2+
"cells": [
3+
{
4+
"metadata": {},
5+
"cell_type": "markdown",
6+
"source": "Gaussian process notebook",
7+
"id": "486923b289155658"
8+
},
9+
{
10+
"metadata": {},
11+
"cell_type": "code",
12+
"source": [
13+
"import tempfile\n",
14+
"from pathlib import Path\n",
15+
"\n",
16+
"import numpy as np\n",
17+
"from sklearn.gaussian_process.kernels import DotProduct, WhiteKernel\n",
18+
"\n",
19+
"from eddymotion import model\n",
20+
"from eddymotion.data.dmri import DWI\n",
21+
"from eddymotion.data.splitting import lovo_split\n",
22+
"\n",
23+
"datadir = Path(\"../../test\") # Adapt to your local path or download to a temp location using wget\n",
24+
"\n",
25+
"kernel = DotProduct() + WhiteKernel()\n",
26+
"\n",
27+
"dwi = DWI.from_filename(datadir / \"dwi.h5\")\n",
28+
"\n",
29+
"_dwi_data = dwi.dataobj\n",
30+
"# Use a subset of the data for now to see that something is written to the\n",
31+
"# output\n",
32+
"# bvecs = dwi.gradients[:3, :].T\n",
33+
"bvecs = dwi.gradients[:3, 10:13].T # b0 values have already been masked\n",
34+
"# bvals = dwi.gradients[3:, 10:13].T # Only for inspection purposes: [[1005.], [1000.], [ 995.]]\n",
35+
"dwi_data = _dwi_data[60:63, 60:64, 40:45, 10:13]\n",
36+
"\n",
37+
"# ToDo\n",
38+
"# Provide proper values/estimates for these\n",
39+
"a = 1\n",
40+
"h = 1 # should be a NIfTI image\n",
41+
"\n",
42+
"num_iterations = 5\n",
43+
"gp = model.GaussianProcessModel(\n",
44+
" dwi=dwi, a=a, h=h, kernel=kernel, num_iterations=num_iterations\n",
45+
")\n",
46+
"indices = list(range(bvecs.shape[0]))\n",
47+
"# ToDo\n",
48+
"# This should be done within the GP model class\n",
49+
"# Apply lovo strategy properly\n",
50+
"# Vectorize and parallelize\n",
51+
"result_mean = np.zeros_like(dwi_data)\n",
52+
"result_stddev = np.zeros_like(dwi_data)\n",
53+
"for idx in indices:\n",
54+
" lovo_idx = np.ones(len(indices), dtype=bool)\n",
55+
" lovo_idx[idx] = False\n",
56+
" X = bvecs[lovo_idx]\n",
57+
" for i in range(dwi_data.shape[0]):\n",
58+
" for j in range(dwi_data.shape[1]):\n",
59+
" for k in range(dwi_data.shape[2]):\n",
60+
" # ToDo\n",
61+
" # Use a mask to avoid traversing background data\n",
62+
" y = dwi_data[i, j, k, lovo_idx]\n",
63+
" gp.fit(X, y)\n",
64+
" pred_mean, pred_stddev = gp.predict(\n",
65+
" bvecs[idx, :][np.newaxis]\n",
66+
" ) # Can take multiple values X[:2, :]\n",
67+
" result_mean[i, j, k, idx] = pred_mean.item()\n",
68+
" result_stddev[i, j, k, idx] = pred_stddev.item()"
69+
],
70+
"id": "da2274009534db61",
71+
"outputs": [],
72+
"execution_count": null
73+
},
74+
{
75+
"metadata": {},
76+
"cell_type": "markdown",
77+
"source": "Plot the data",
78+
"id": "77e77cd4c73409d3"
79+
},
80+
{
81+
"metadata": {},
82+
"cell_type": "code",
83+
"source": [
84+
"from matplotlib import pyplot as plt \n",
85+
"%matplotlib inline\n",
86+
"\n",
87+
"s = dwi_data[1, 1, 2, :]\n",
88+
"s_hat_mean = result_mean[1, 1, 2, :]\n",
89+
"s_hat_stddev = result_stddev[1, 1, 2, :]\n",
90+
"x = np.asarray(indices)\n",
91+
"\n",
92+
"fig, ax = plt.subplots()\n",
93+
"ax.plot(x, s_hat_mean, c=\"orange\", label=\"predicted\")\n",
94+
"plt.fill_between(\n",
95+
" x.ravel(),\n",
96+
" s_hat_mean - 1.96 * s_hat_stddev,\n",
97+
" s_hat_mean + 1.96 * s_hat_stddev,\n",
98+
" alpha=0.5,\n",
99+
" color=\"orange\",\n",
100+
" label=r\"95% confidence interval\",\n",
101+
")\n",
102+
"plt.scatter(x, s, c=\"b\", label=\"ground truth\")\n",
103+
"ax.set_xlabel(\"bvec indices\")\n",
104+
"ax.set_ylabel(\"signal\")\n",
105+
"ax.legend()\n",
106+
"plt.title(\"Gaussian process regression on dataset\")\n",
107+
"\n",
108+
"plt.show()"
109+
],
110+
"id": "4e51f22890fb045a",
111+
"outputs": [],
112+
"execution_count": null
113+
},
114+
{
115+
"metadata": {},
116+
"cell_type": "markdown",
117+
"source": [
118+
"Plot the DWI signal for a given voxel\n",
119+
"Compute the DWI signal value wrt the b0 (how much larger/smaller is and add that delta to the unit sphere?) for each bvec direction and plot that?"
120+
],
121+
"id": "694a4c075457425d"
122+
},
123+
{
124+
"metadata": {},
125+
"cell_type": "code",
126+
"source": [
127+
"# from mpl_toolkits.mplot3d import Axes3D\n",
128+
"# fig, ax = plt.subplots()\n",
129+
"# ax = fig.add_subplot(111, projection='3d')\n",
130+
"# plt.scatter(xx, yy, zz)"
131+
],
132+
"id": "bb7d2aef53ac99f0",
133+
"outputs": [],
134+
"execution_count": null
135+
},
136+
{
137+
"metadata": {},
138+
"cell_type": "markdown",
139+
"source": "Plot the DWI signal brain data\n",
140+
"id": "62d7bc609b65c7cf"
141+
},
142+
{
143+
"metadata": {},
144+
"cell_type": "code",
145+
"source": "# plot_dwi(dmri_dataset.dataobj, dmri_dataset.affine, gradient=data_test[1], black_bg=True)",
146+
"id": "edb0e9d255516e38",
147+
"outputs": [],
148+
"execution_count": null
149+
},
150+
{
151+
"metadata": {},
152+
"cell_type": "markdown",
153+
"source": "Plot the predicted DWI signal",
154+
"id": "1a52e2450fc61dc6"
155+
},
156+
{
157+
"metadata": {},
158+
"cell_type": "code",
159+
"source": "# plot_dwi(predicted, dmri_dataset.affine, gradient=data_test[1], black_bg=True);",
160+
"id": "66150cf337b395e0",
161+
"outputs": [],
162+
"execution_count": null
163+
}
164+
],
165+
"metadata": {
166+
"kernelspec": {
167+
"display_name": "Python 3",
168+
"language": "python",
169+
"name": "python3"
170+
},
171+
"language_info": {
172+
"codemirror_mode": {
173+
"name": "ipython",
174+
"version": 2
175+
},
176+
"file_extension": ".py",
177+
"mimetype": "text/x-python",
178+
"name": "python",
179+
"nbconvert_exporter": "python",
180+
"pygments_lexer": "ipython2",
181+
"version": "2.7.6"
182+
}
183+
},
184+
"nbformat": 4,
185+
"nbformat_minor": 5
186+
}

0 commit comments

Comments
 (0)