Skip to content

Commit cc72b3c

Browse files
committed
src(comp): Use attrs + cattrs
1 parent 2d35c99 commit cc72b3c

File tree

3 files changed

+93
-53
lines changed

3 files changed

+93
-53
lines changed

cellengine/resources/compensation.py

Lines changed: 61 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,97 @@
11
from __future__ import annotations
2-
from dataclasses import dataclass, field
32
from typing import List, TYPE_CHECKING
43

5-
from dataclasses_json.cfg import config
6-
import numpy
4+
from attr import define, field
5+
from numpy import array, linalg
76
from pandas import DataFrame
87

98
import cellengine as ce
10-
from cellengine.utils.dataclass_mixin import DataClassMixin, ReadOnly
9+
from cellengine.utils import readonly
10+
from cellengine.utils.converter import converter
1111

1212

1313
if TYPE_CHECKING:
1414
from cellengine.resources.fcs_file import FcsFile
1515

1616

17-
@dataclass
18-
class Compensation(DataClassMixin):
17+
@define
18+
class Compensation:
1919
"""A class representing a CellEngine compensation matrix.
2020
2121
Can be applied to FCS files to compensate them.
2222
"""
2323

24+
_id: str = field(on_setattr=readonly)
25+
experiment_id: str = field(on_setattr=readonly)
2426
name: str
2527
channels: List[str]
26-
dataframe: DataFrame = field(
27-
metadata=config(
28-
field_name="spillMatrix",
29-
encoder=lambda x: x.to_numpy().flatten().tolist(),
30-
decoder=lambda x: numpy.array(x),
31-
)
32-
)
33-
_id: str = field(
34-
metadata=config(field_name="_id"), default=ReadOnly()
35-
) # type: ignore
36-
experiment_id: str = field(default=ReadOnly()) # type: ignore
37-
38-
def __post_init__(self):
39-
self.dataframe = DataFrame(
40-
self.dataframe.reshape(self.N, self.N),
28+
spill_matrix: List[float]
29+
30+
@property
31+
def dataframe(self):
32+
return DataFrame(
33+
array(self.spill_matrix).reshape(self.N, self.N), # type: ignore
4134
columns=self.channels,
4235
index=self.channels,
4336
)
4437

38+
@dataframe.setter
39+
def dataframe(self, val: DataFrame):
40+
try:
41+
assert len(val.columns) == len(val.index)
42+
assert all(val.columns == val.index)
43+
self.channels = val.columns.to_list()
44+
self.spill_matrix = val.to_numpy().flatten().tolist()
45+
except Exception:
46+
raise ValueError(
47+
"Dataframe must be a square matrix with equivalent index and columns."
48+
)
49+
4550
def __repr__(self):
4651
return f"Compensation(_id='{self._id}', name='{self.name}')"
4752

4853
@property
49-
def N(self):
50-
return len(self.channels)
54+
def path(self):
55+
return f"experiments/{self.experiment_id}/compensations/{self._id}".rstrip(
56+
"/None"
57+
)
58+
59+
@classmethod
60+
def from_dict(cls, data: dict):
61+
return converter.structure(data, cls)
62+
63+
def to_dict(self):
64+
return converter.unstructure(self)
5165

5266
@classmethod
5367
def get(cls, experiment_id: str, _id: str = None, name: str = None) -> Compensation:
5468
kwargs = {"name": name} if name else {"_id": _id}
5569
return ce.APIClient().get_compensation(experiment_id, **kwargs)
5670

5771
@classmethod
58-
def create(cls, experiment_id: str, compensation: dict) -> Compensation:
59-
"""Creates a compensation
72+
def create(
73+
cls,
74+
experiment_id: str,
75+
name: str,
76+
channels: List[str],
77+
spill_matrix: List[float],
78+
) -> Compensation:
79+
"""Create a new compensation for this experiment
6080
6181
Args:
62-
experiment_id: ID of experiment that this compensation belongs to.
63-
compensation: Dict containing `channels` and `spillMatrix` properties.
82+
experiment_id (str): the ID of the experiment.
83+
name (str): The name of the compensation.
84+
channels (List[str]): The names of the channels to which this
85+
compensation matrix applies.
86+
spill_matrix (List[float]): The row-wise, square spillover matrix. The
87+
length of the array must be the number of channels squared.
6488
"""
65-
return ce.APIClient().post_compensation(experiment_id, compensation)
89+
body = {"name": name, "channels": channels, "spillMatrix": spill_matrix}
90+
return ce.APIClient().post_compensation(experiment_id, body)
91+
92+
@property
93+
def N(self):
94+
return len(self.channels)
6695

6796
@staticmethod
6897
def from_spill_string(spill_string: str) -> Compensation:
@@ -81,14 +110,12 @@ def from_spill_string(spill_string: str) -> Compensation:
81110
"experimentId": "",
82111
"name": "",
83112
}
84-
return Compensation.from_dict(properties)
113+
return converter.structure(properties, Compensation)
85114

86115
def update(self):
87116
"""Save changes to this Compensation to CellEngine."""
88-
res = ce.APIClient().update_entity(
89-
self.experiment_id, self._id, "compensations", body=self.to_dict()
90-
)
91-
self.__dict__.update(Compensation.from_dict(res).__dict__)
117+
res = ce.APIClient().update(self)
118+
self.__setstate__(res.__getstate__()) # type: ignore
92119

93120
def delete(self):
94121
return ce.APIClient().delete_entity(
@@ -132,7 +159,7 @@ def apply(self, file: FcsFile, inplace: bool = True, **kwargs):
132159
if any(ix):
133160
copy = data.copy()
134161
comped = copy[ix]
135-
comped = comped.dot(numpy.linalg.inv(self.dataframe)) # type: ignore
162+
comped = comped.dot(linalg.inv(self.dataframe)) # type: ignore
136163
comped.columns = ix
137164
copy.update(comped.astype(comped.dtypes[0]))
138165
else:

cellengine/utils/api_client/APIClient.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def get_compensations(self, experiment_id, as_dict=False) -> List[Compensation]:
251251
)
252252
if as_dict:
253253
return compensations
254-
return [Compensation.from_dict(comp) for comp in compensations]
254+
return converter.structure(compensations, List[Compensation])
255255

256256
def get_compensation(
257257
self, experiment_id, _id=None, name=None, as_dict=False
@@ -262,14 +262,13 @@ def get_compensation(
262262
)
263263
if as_dict:
264264
return comp
265-
return Compensation.from_dict(comp)
265+
return converter.structure(comp, Compensation)
266266

267-
def post_compensation(self, experiment_id, compensation=None) -> Compensation:
268-
res = self._post(
269-
f"{self.base_url}/experiments/{experiment_id}/compensations",
270-
json=compensation,
267+
def post_compensation(self, experiment_id: str, body: Dict[str, Any]):
268+
comp = self._post(
269+
f"{self.base_url}/experiments/{experiment_id}/compensations", json=body
271270
)
272-
return Compensation.from_dict(res)
271+
return converter.structure(comp, Compensation)
273272

274273
def get_experiments(self, as_dict=False) -> List[Experiment]:
275274
experiments = self._get(f"{self.base_url}/experiments")

tests/unit/resources/test_compensation.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1-
from cellengine.utils.parse_fcs_file import parse_fcs_file
2-
from cellengine.resources.fcs_file import FcsFile
31
import json
4-
import pytest
5-
import responses
2+
3+
from numpy import identity
64
from pandas import DataFrame
75
from pandas.testing import assert_frame_equal
8-
from numpy import identity
6+
import pytest
7+
import responses
8+
99
from cellengine.resources.compensation import Compensation
10+
from cellengine.resources.fcs_file import FcsFile
11+
from cellengine.utils import converter
12+
from cellengine.utils.parse_fcs_file import parse_fcs_file
1013

1114

1215
EXP_ID = "5d38a6f79fae87499999a74b"
@@ -20,10 +23,10 @@ def fcs_file(ENDPOINT_BASE, client, fcs_files):
2023

2124

2225
@pytest.fixture(scope="function")
23-
def compensation(ENDPOINT_BASE, client, fcs_file, compensations):
26+
def compensation(ENDPOINT_BASE, client, compensations):
2427
comp = compensations[0]
2528
comp.update({"experimentId": EXP_ID})
26-
return Compensation.from_dict(comp)
29+
return converter.structure(comp, Compensation)
2730

2831

2932
def properties_tester(comp):
@@ -46,14 +49,25 @@ def test_compensation_properties(ENDPOINT_BASE, compensation):
4649

4750

4851
@responses.activate
49-
def test_should_post_compensation(ENDPOINT_BASE, experiment, compensations):
52+
def test_should_post_compensation(client, ENDPOINT_BASE, compensations):
53+
responses.add(
54+
responses.POST,
55+
ENDPOINT_BASE + f"/experiments/{EXP_ID}/compensations",
56+
json=compensations[0],
57+
)
58+
comp = Compensation(None, EXP_ID, "test_comp", ["a", "b"], [1, 0, 0, 1])
59+
comp = client.create(comp)
60+
properties_tester(comp)
61+
62+
63+
@responses.activate
64+
def test_should_create_compensation(client, ENDPOINT_BASE, compensations):
5065
responses.add(
5166
responses.POST,
5267
ENDPOINT_BASE + f"/experiments/{EXP_ID}/compensations",
5368
json=compensations[0],
5469
)
55-
payload = compensations[0].copy()
56-
comp = Compensation.create(experiment._id, payload)
70+
comp = Compensation.create(EXP_ID, "test-comp", ["a", "b"], [1, 0, 0, 1])
5771
properties_tester(comp)
5872

5973

@@ -74,7 +88,7 @@ def test_should_update_compensation(ENDPOINT_BASE, compensation):
7488
test.
7589
"""
7690
# patch the mocked response with the correct values
77-
response = compensation.to_dict().copy()
91+
response = converter.unstructure(compensation)
7892
response.update({"name": "newname"})
7993
responses.add(
8094
responses.PATCH,
@@ -84,7 +98,7 @@ def test_should_update_compensation(ENDPOINT_BASE, compensation):
8498
compensation.name = "newname"
8599
compensation.update()
86100
properties_tester(compensation)
87-
assert json.loads(responses.calls[0].request.body) == compensation.to_dict()
101+
assert json.loads(responses.calls[0].request.body) == response
88102

89103

90104
def test_create_from_spill_string(spillstring):

0 commit comments

Comments
 (0)