Skip to content

Commit d7b709d

Browse files
committed
src(comp): Use attrs + cattrs
1 parent 11bf3c4 commit d7b709d

File tree

3 files changed

+91
-53
lines changed

3 files changed

+91
-53
lines changed

cellengine/resources/compensation.py

Lines changed: 59 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,95 @@
11
from __future__ import annotations
2-
from dataclasses import dataclass, field
32
from typing import List, TYPE_CHECKING
3+
from attr import define, field
44

5-
from dataclasses_json.cfg import config
6-
import numpy
5+
from numpy import array, linalg
76
from pandas import DataFrame
87

98
import cellengine as ce
10-
from cellengine.utils.dataclass_mixin import DataClassMixin, ReadOnly
9+
from cellengine.utils import readonly
10+
from cellengine.utils.converter import converter
1111

1212

1313
if TYPE_CHECKING:
1414
from cellengine.resources.fcs_file import FcsFile
1515

1616

17-
@dataclass
18-
class Compensation(DataClassMixin):
17+
@define
18+
class Compensation:
1919
"""A class representing a CellEngine compensation matrix.
2020
2121
Can be applied to FCS files to compensate them.
2222
"""
2323

24+
_id: str = field(on_setattr=readonly)
25+
experiment_id: str = field(on_setattr=readonly)
2426
name: str
2527
channels: List[str]
26-
dataframe: DataFrame = field(
27-
metadata=config(
28-
field_name="spillMatrix",
29-
encoder=lambda x: x.to_numpy().flatten().tolist(),
30-
decoder=lambda x: numpy.array(x),
31-
)
32-
)
33-
_id: str = field(
34-
metadata=config(field_name="_id"), default=ReadOnly()
35-
) # type: ignore
36-
experiment_id: str = field(default=ReadOnly()) # type: ignore
37-
38-
def __post_init__(self):
39-
self.dataframe = DataFrame(
40-
self.dataframe.reshape(self.N, self.N),
28+
spill_matrix: List[float]
29+
30+
@property
31+
def dataframe(self):
32+
return DataFrame(
33+
array(self.spill_matrix).reshape(self.N, self.N), # type: ignore
4134
columns=self.channels,
4235
index=self.channels,
4336
)
4437

38+
@dataframe.setter
39+
def dataframe(self, val: DataFrame):
40+
try:
41+
assert len(val.columns) == len(val.index)
42+
assert all(val.columns == val.index)
43+
self.channels = val.columns.to_list()
44+
self.spill_matrix = val.to_numpy().flatten().tolist()
45+
except Exception as e:
46+
raise e
47+
4548
def __repr__(self):
4649
return f"Compensation(_id='{self._id}', name='{self.name}')"
4750

4851
@property
49-
def N(self):
50-
return len(self.channels)
52+
def path(self):
53+
return f"experiments/{self.experiment_id}/compensations/{self._id}".rstrip(
54+
"/None"
55+
)
56+
57+
@classmethod
58+
def from_dict(cls, data: dict):
59+
return converter.structure(data, cls)
60+
61+
def to_dict(self):
62+
return converter.unstructure(self)
5163

5264
@classmethod
5365
def get(cls, experiment_id: str, _id: str = None, name: str = None) -> Compensation:
5466
kwargs = {"name": name} if name else {"_id": _id}
5567
return ce.APIClient().get_compensation(experiment_id, **kwargs)
5668

5769
@classmethod
58-
def create(cls, experiment_id: str, compensation: dict) -> Compensation:
59-
"""Creates a compensation
70+
def create(
71+
cls,
72+
experiment_id: str,
73+
name: str,
74+
channels: List[str],
75+
spill_matrix: List[float],
76+
) -> Compensation:
77+
"""Create a new compensation for this experiment
6078
6179
Args:
62-
experiment_id: ID of experiment that this compensation belongs to.
63-
compensation: Dict containing `channels` and `spillMatrix` properties.
80+
experiment_id (str): the ID of the experiment.
81+
name (str): The name of the compensation.
82+
channels (List[str]): The names of the channels to which this
83+
compensation matrix applies.
84+
spill_matrix (List[float]): The row-wise, square spillover matrix. The
85+
length of the array must be the number of channels squared.
6486
"""
65-
return ce.APIClient().post_compensation(experiment_id, compensation)
87+
body = {"name": name, "channels": channels, "spillMatrix": spill_matrix}
88+
return ce.APIClient().post_compensation(experiment_id, body)
89+
90+
@property
91+
def N(self):
92+
return len(self.channels)
6693

6794
@staticmethod
6895
def from_spill_string(spill_string: str) -> Compensation:
@@ -81,14 +108,12 @@ def from_spill_string(spill_string: str) -> Compensation:
81108
"experimentId": "",
82109
"name": "",
83110
}
84-
return Compensation.from_dict(properties)
111+
return converter.structure(properties, Compensation)
85112

86113
def update(self):
87114
"""Save changes to this Compensation to CellEngine."""
88-
res = ce.APIClient().update_entity(
89-
self.experiment_id, self._id, "compensations", body=self.to_dict()
90-
)
91-
self.__dict__.update(Compensation.from_dict(res).__dict__)
115+
res = ce.APIClient().update(self)
116+
self.__setstate__(res.__getstate__()) # type: ignore
92117

93118
def delete(self):
94119
return ce.APIClient().delete_entity(
@@ -132,7 +157,7 @@ def apply(self, file: FcsFile, inplace: bool = True, **kwargs):
132157
if any(ix):
133158
copy = data.copy()
134159
comped = copy[ix]
135-
comped = comped.dot(numpy.linalg.inv(self.dataframe)) # type: ignore
160+
comped = comped.dot(linalg.inv(self.dataframe)) # type: ignore
136161
comped.columns = ix
137162
copy.update(comped.astype(comped.dtypes[0]))
138163
else:

cellengine/utils/api_client/APIClient.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def get_compensations(self, experiment_id, as_dict=False) -> List[Compensation]:
228228
)
229229
if as_dict:
230230
return compensations
231-
return [Compensation.from_dict(comp) for comp in compensations]
231+
return converter.structure(compensations, List[Compensation])
232232

233233
def get_compensation(
234234
self, experiment_id, _id=None, name=None, as_dict=False
@@ -239,14 +239,13 @@ def get_compensation(
239239
)
240240
if as_dict:
241241
return comp
242-
return Compensation.from_dict(comp)
242+
return converter.structure(comp, Compensation)
243243

244-
def post_compensation(self, experiment_id, compensation=None) -> Compensation:
245-
res = self._post(
246-
f"{self.base_url}/experiments/{experiment_id}/compensations",
247-
json=compensation,
244+
def post_compensation(self, experiment_id: str, body: Dict[str, Any]):
245+
comp = self._post(
246+
f"{self.base_url}/experiments/{experiment_id}/compensations", json=body
248247
)
249-
return Compensation.from_dict(res)
248+
return converter.structure(comp, Compensation)
250249

251250
def get_experiments(self, as_dict=False) -> List[Experiment]:
252251
experiments = self._get(f"{self.base_url}/experiments")

tests/unit/resources/test_compensation.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1-
from cellengine.utils.parse_fcs_file import parse_fcs_file
2-
from cellengine.resources.fcs_file import FcsFile
31
import json
4-
import pytest
5-
import responses
2+
3+
from numpy import identity
64
from pandas import DataFrame
75
from pandas.testing import assert_frame_equal
8-
from numpy import identity
6+
import pytest
7+
import responses
8+
99
from cellengine.resources.compensation import Compensation
10+
from cellengine.resources.fcs_file import FcsFile
11+
from cellengine.utils import converter
12+
from cellengine.utils.parse_fcs_file import parse_fcs_file
1013

1114

1215
EXP_ID = "5d38a6f79fae87499999a74b"
@@ -20,10 +23,10 @@ def fcs_file(ENDPOINT_BASE, client, fcs_files):
2023

2124

2225
@pytest.fixture(scope="function")
23-
def compensation(ENDPOINT_BASE, client, fcs_file, compensations):
26+
def compensation(ENDPOINT_BASE, client, compensations):
2427
comp = compensations[0]
2528
comp.update({"experimentId": EXP_ID})
26-
return Compensation.from_dict(comp)
29+
return converter.structure(comp, Compensation)
2730

2831

2932
def properties_tester(comp):
@@ -46,14 +49,25 @@ def test_compensation_properties(ENDPOINT_BASE, compensation):
4649

4750

4851
@responses.activate
49-
def test_should_post_compensation(ENDPOINT_BASE, experiment, compensations):
52+
def test_should_post_compensation(client, ENDPOINT_BASE, compensations):
53+
responses.add(
54+
responses.POST,
55+
ENDPOINT_BASE + f"/experiments/{EXP_ID}/compensations",
56+
json=compensations[0],
57+
)
58+
comp = Compensation(None, EXP_ID, "test_comp", ["a", "b"], [1, 0, 0, 1])
59+
comp = client.create(comp)
60+
properties_tester(comp)
61+
62+
63+
@responses.activate
64+
def test_should_create_compensation(client, ENDPOINT_BASE, compensations):
5065
responses.add(
5166
responses.POST,
5267
ENDPOINT_BASE + f"/experiments/{EXP_ID}/compensations",
5368
json=compensations[0],
5469
)
55-
payload = compensations[0].copy()
56-
comp = Compensation.create(experiment._id, payload)
70+
comp = Compensation.create(EXP_ID, "test-comp", ["a", "b"], [1, 0, 0, 1])
5771
properties_tester(comp)
5872

5973

@@ -74,7 +88,7 @@ def test_should_update_compensation(ENDPOINT_BASE, compensation):
7488
test.
7589
"""
7690
# patch the mocked response with the correct values
77-
response = compensation.to_dict().copy()
91+
response = converter.unstructure(compensation)
7892
response.update({"name": "newname"})
7993
responses.add(
8094
responses.PATCH,
@@ -84,7 +98,7 @@ def test_should_update_compensation(ENDPOINT_BASE, compensation):
8498
compensation.name = "newname"
8599
compensation.update()
86100
properties_tester(compensation)
87-
assert json.loads(responses.calls[0].request.body) == compensation.to_dict()
101+
assert json.loads(responses.calls[0].request.body) == response
88102

89103

90104
def test_create_from_spill_string(spillstring):

0 commit comments

Comments
 (0)