From e02f778dc9c7858c3e37ec5f5d7fff4ae2f76773 Mon Sep 17 00:00:00 2001 From: LeonStadelmann Date: Wed, 30 Jul 2025 17:18:36 +0200 Subject: [PATCH 1/2] init commit --- src/cellflow/model/_cellflow.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/cellflow/model/_cellflow.py b/src/cellflow/model/_cellflow.py index 1ae3e3b0..3b1cccf6 100644 --- a/src/cellflow/model/_cellflow.py +++ b/src/cellflow/model/_cellflow.py @@ -505,7 +505,8 @@ def prepare_model( else: raise NotImplementedError(f"Solver must be an instance of OTFlowMatching or GENOT, got {type(self.solver)}") - self._trainer = CellFlowTrainer(solver=self.solver, predict_kwargs=self.validation_data["predict_kwargs"]) # type: ignore[arg-type] + predict_kwargs = self.validation_data.get("predict_kwargs", {}) + self._trainer = CellFlowTrainer(solver=self.solver, predict_kwargs=predict_kwargs) # type: ignore[arg-type] def train( self, From 577cccd175ac62a264349fae29a9b7bbbf6ebf1e Mon Sep 17 00:00:00 2001 From: LeonStadelmann Date: Wed, 30 Jul 2025 17:57:01 +0200 Subject: [PATCH 2/2] Add split fix for GENOT --- src/cellflow/solvers/_genot.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/cellflow/solvers/_genot.py b/src/cellflow/solvers/_genot.py index 7270ad7f..f2cc91d1 100644 --- a/src/cellflow/solvers/_genot.py +++ b/src/cellflow/solvers/_genot.py @@ -284,6 +284,12 @@ def predict( pred_targets = batched_predict(src_inputs, batched_conditions) return {k: pred_targets[i] for i, k in enumerate(keys)} + elif isinstance(x, dict): + return jax.tree.map( + functools.partial(self._predict_jit, rng=rng, **kwargs), + x, + condition, # type: ignore[attr-defined] + ) else: x_pred = self._predict_jit(x, condition, rng, rng_genot, **kwargs) return np.array(x_pred)