Skip to content

src(comp): Use cattrs #137

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 81 additions & 35 deletions cellengine/resources/compensation.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,116 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import List, TYPE_CHECKING
from typing import List, Optional, TYPE_CHECKING, Tuple, cast

from dataclasses_json.cfg import config
import numpy
from attr import define, field
from numpy import array, linalg
from pandas import DataFrame

import cellengine as ce
from cellengine.utils.dataclass_mixin import DataClassMixin, ReadOnly
from cellengine.utils import readonly
from cellengine.utils.converter import converter


if TYPE_CHECKING:
from cellengine.resources.fcs_file import FcsFile


@dataclass
class Compensation(DataClassMixin):
@define
class Compensation:
"""A class representing a CellEngine compensation matrix.

Can be applied to FCS files to compensate them.
"""

_id: str = field(on_setattr=readonly)
experiment_id: str = field(on_setattr=readonly)
name: str
channels: List[str]
dataframe: DataFrame = field(
metadata=config(
field_name="spillMatrix",
encoder=lambda x: x.to_numpy().flatten().tolist(),
decoder=lambda x: numpy.array(x),
)
)
_id: str = field(
metadata=config(field_name="_id"), default=ReadOnly()
) # type: ignore
experiment_id: str = field(default=ReadOnly()) # type: ignore

def __post_init__(self):
self.dataframe = DataFrame(
self.dataframe.reshape(self.N, self.N),
spill_matrix: List[float]

@property
def dataframe(self):
return DataFrame(
array(self.spill_matrix).reshape(self.N, self.N), # type: ignore
columns=self.channels,
index=self.channels,
)

@dataframe.setter
def dataframe(self, df: DataFrame):
self.channels, self.spill_matrix = self._convert_dataframe(df)

@staticmethod
def _convert_dataframe(df: DataFrame) -> Tuple[List[str], List[float]]:
try:
assert all(df.columns == df.index)
channels = cast(List[str], df.columns.tolist())
spill_matrix = cast(List[float], df.to_numpy().flatten().tolist())
return channels, spill_matrix
except Exception:
raise ValueError(
"Dataframe must be a square matrix with equivalent index and columns."
)

def __repr__(self):
return f"Compensation(_id='{self._id}', name='{self.name}')"

@property
def N(self):
return len(self.channels)
def path(self):
return f"experiments/{self.experiment_id}/compensations/{self._id}".rstrip(
"/None"
)

@classmethod
def from_dict(cls, data: dict):
return converter.structure(data, cls)

def to_dict(self):
return converter.unstructure(self)

@classmethod
def get(cls, experiment_id: str, _id: str = None, name: str = None) -> Compensation:
kwargs = {"name": name} if name else {"_id": _id}
return ce.APIClient().get_compensation(experiment_id, **kwargs)

@classmethod
def create(cls, experiment_id: str, compensation: dict) -> Compensation:
"""Creates a compensation
def create(
cls,
experiment_id: str,
name: str,
channels: Optional[List[str]] = None,
spill_matrix: Optional[List[float]] = None,
dataframe: Optional[DataFrame] = None,
) -> Compensation:
"""Create a new compensation for this experiment

Args:
experiment_id: ID of experiment that this compensation belongs to.
compensation: Dict containing `channels` and `spillMatrix` properties.
experiment_id (str): the ID of the experiment.
name (str): The name of the compensation.
channels (List[str]): The names of the channels to which this
compensation matrix applies.
spill_matrix (List[float]): The row-wise, square spillover matrix. The
length of the array must be the number of channels squared.
spill_matrix (DataFrame): A square pandas DataFrame with channel
names in [df.index, df.columns].
"""
return ce.APIClient().post_compensation(experiment_id, compensation)
if dataframe is None:
if not (channels and spill_matrix):
raise TypeError("Both 'channels' and 'spill_matrix' are required.")
else:
if spill_matrix or channels:
raise TypeError(
"Only one of 'dataframe' or {'channels', 'spill_matrix'} "
"may be assigned."
)
else:
channels, spill_matrix = cls._convert_dataframe(dataframe)

body = {"name": name, "channels": channels, "spillMatrix": spill_matrix}
return ce.APIClient().post_compensation(experiment_id, body)

@property
def N(self):
return len(self.channels)

@staticmethod
def from_spill_string(spill_string: str) -> Compensation:
Expand All @@ -81,14 +129,12 @@ def from_spill_string(spill_string: str) -> Compensation:
"experimentId": "",
"name": "",
}
return Compensation.from_dict(properties)
return converter.structure(properties, Compensation)

def update(self):
"""Save changes to this Compensation to CellEngine."""
res = ce.APIClient().update_entity(
self.experiment_id, self._id, "compensations", body=self.to_dict()
)
self.__dict__.update(Compensation.from_dict(res).__dict__)
res = ce.APIClient().update(self)
self.__setstate__(res.__getstate__()) # type: ignore

def delete(self):
return ce.APIClient().delete_entity(
Expand Down Expand Up @@ -132,7 +178,7 @@ def apply(self, file: FcsFile, inplace: bool = True, **kwargs):
if any(ix):
copy = data.copy()
comped = copy[ix]
comped = comped.dot(numpy.linalg.inv(self.dataframe)) # type: ignore
comped = comped.dot(linalg.inv(self.dataframe)) # type: ignore
comped.columns = ix
copy.update(comped.astype(comped.dtypes[0]))
else:
Expand Down
13 changes: 6 additions & 7 deletions cellengine/utils/api_client/APIClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def get_compensations(self, experiment_id, as_dict=False) -> List[Compensation]:
)
if as_dict:
return compensations
return [Compensation.from_dict(comp) for comp in compensations]
return converter.structure(compensations, List[Compensation])

def get_compensation(
self, experiment_id, _id=None, name=None, as_dict=False
Expand All @@ -268,14 +268,13 @@ def get_compensation(
)
if as_dict:
return comp
return Compensation.from_dict(comp)
return converter.structure(comp, Compensation)

def post_compensation(self, experiment_id, compensation=None) -> Compensation:
res = self._post(
f"{self.base_url}/experiments/{experiment_id}/compensations",
json=compensation,
def post_compensation(self, experiment_id: str, body: Dict[str, Any]):
comp = self._post(
f"{self.base_url}/experiments/{experiment_id}/compensations", json=body
)
return Compensation.from_dict(res)
return converter.structure(comp, Compensation)

def get_experiments(self, as_dict=False) -> List[Experiment]:
experiments = self._get(f"{self.base_url}/experiments")
Expand Down
76 changes: 64 additions & 12 deletions tests/unit/resources/test_compensation.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from cellengine.utils.parse_fcs_file import parse_fcs_file
from cellengine.resources.fcs_file import FcsFile
import json
import pytest
import responses

from numpy import identity
from pandas import DataFrame
from pandas.testing import assert_frame_equal
from numpy import identity
import pytest
import responses

from cellengine.resources.compensation import Compensation
from cellengine.resources.fcs_file import FcsFile
from cellengine.utils import converter
from cellengine.utils.parse_fcs_file import parse_fcs_file


EXP_ID = "5d38a6f79fae87499999a74b"
Expand All @@ -20,10 +23,10 @@ def fcs_file(ENDPOINT_BASE, client, fcs_files):


@pytest.fixture(scope="function")
def compensation(ENDPOINT_BASE, client, fcs_file, compensations):
def compensation(ENDPOINT_BASE, client, compensations):
comp = compensations[0]
comp.update({"experimentId": EXP_ID})
return Compensation.from_dict(comp)
return converter.structure(comp, Compensation)


def properties_tester(comp):
Expand All @@ -46,17 +49,66 @@ def test_compensation_properties(ENDPOINT_BASE, compensation):


@responses.activate
def test_should_post_compensation(ENDPOINT_BASE, experiment, compensations):
def test_should_post_compensation(client, ENDPOINT_BASE, compensations):
responses.add(
responses.POST,
ENDPOINT_BASE + f"/experiments/{EXP_ID}/compensations",
json=compensations[0],
)
payload = compensations[0].copy()
comp = Compensation.create(experiment._id, payload)
comp = Compensation(None, EXP_ID, "test_comp", ["a", "b"], [1, 0, 0, 1])
comp = client.create(comp)
properties_tester(comp)


@responses.activate
def test_creates_compensation(client, ENDPOINT_BASE, compensations):
responses.add(
responses.POST,
ENDPOINT_BASE + f"/experiments/{EXP_ID}/compensations",
json=compensations[0],
)
comp = Compensation.create(EXP_ID, "test-comp", ["a", "b"], [1, 0, 0, 1])
properties_tester(comp)


@responses.activate
def test_creates_compensation_with_dataframe(client, ENDPOINT_BASE, compensations):
responses.add(
responses.POST,
ENDPOINT_BASE + f"/experiments/{EXP_ID}/compensations",
json=compensations[0],
)
df = DataFrame([[1, 0], [0, 1]], columns=["a", "b"], index=["a", "b"])
comp = Compensation.create(EXP_ID, "test-comp", dataframe=df)
properties_tester(comp)


@responses.activate
def test_raises_TypeError_when_wrong_arg_combo_is_passed(
client, ENDPOINT_BASE, compensations
):
responses.add(
responses.POST,
ENDPOINT_BASE + f"/experiments/{EXP_ID}/compensations",
json=compensations[0],
)
with pytest.raises(TypeError) as err:
Compensation.create(EXP_ID, "test-comp", spill_matrix=[0, 1])
assert err.value.args[0] == "Both 'channels' and 'spill_matrix' are required."

with pytest.raises(TypeError) as err:
Compensation.create(EXP_ID, "test-comp", channels=["a", "b"])
assert err.value.args[0] == "Both 'channels' and 'spill_matrix' are required."

with pytest.raises(TypeError) as err:
Compensation.create(
EXP_ID, "test-comp", channels=["a", "b"], dataframe=DataFrame()
)
assert err.value.args[0] == (
"Only one of 'dataframe' or {'channels', 'spill_matrix'} may be assigned."
)


@responses.activate
def test_should_delete_compensation(ENDPOINT_BASE, compensation):
responses.add(
Expand All @@ -74,7 +126,7 @@ def test_should_update_compensation(ENDPOINT_BASE, compensation):
test.
"""
# patch the mocked response with the correct values
response = compensation.to_dict().copy()
response = converter.unstructure(compensation)
response.update({"name": "newname"})
responses.add(
responses.PATCH,
Expand All @@ -84,7 +136,7 @@ def test_should_update_compensation(ENDPOINT_BASE, compensation):
compensation.name = "newname"
compensation.update()
properties_tester(compensation)
assert json.loads(responses.calls[0].request.body) == compensation.to_dict()
assert json.loads(responses.calls[0].request.body) == response


def test_create_from_spill_string(spillstring):
Expand Down