V3 #6
522 changes: 311 additions & 211 deletions .idea/workspace.xml

Large diffs are not rendered by default.

39 changes: 23 additions & 16 deletions README.md
@@ -2,25 +2,32 @@

# SurPyval

SurPyval is a Bayesian survival analysis library
SurPyval is a Bayesian survival analysis library.

## Why use SurPyval?

SurPyval is aimed at people who are using survival analysis libraries like lifelines, but who want more flexibility and access to Bayesian approaches.

Users who are new to Bayesianism can make use of sensible defaults and helper methods, while power users can take extremely detailed control of models.

## Philosophy

* Models should be transparent about their assumptions and workings
* Models should allow tweaks and modifications

Implementing this philosophy has a number of positive effects on the library:
SurPyval is built on the core philosophy that it should be as easy as possible for users to understand and tweak models. Many statistical libraries are easy to use right up until you want a slightly different assumption, which they cannot support.

* The log-likelihood and plate diagrams of models are exposed
* Models are created through composition of simple units
* SurPyval objects thinly wrap and expose well-known libraries (esp. scipy)
* There are no hand-offs to non-Python objects
* Models allow for substitution of any of their composite blocks

The trade-off for these benefits is performance. Models provided in the library are designed to be tweakable, which limits performance optimizations. This manifests itself in a number of ways:
There are two main architectural decisions this has entailed:

#### Graph centric

All models in SurPyval are built around graphical models. Every model can return a plate diagram of its likelihood function. Digging deeper, models are actually made by composing together different variable "Nodes". To change the model, we can simply swap out or add nodes in the model's graph, as the sketch below shows.
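
For illustration, a minimal sketch of this composition, using the node classes that appear elsewhere in this PR (`NodeTree`, `DataNode`, `DataLikihoodNode`, `gamma`); the data and hyperparameters are made up, and the signatures are read off this diff rather than a documented API:

```python
import numpy as np
import scipy.stats

from SurPyval.node import NodeTree, DataNode, DataLikihoodNode, gamma

# made-up lifetimes and right-censoring flags
y = np.array([1.2, 3.4, 0.7])
event = np.array([True, False, True])

# a model is a dictionary of named nodes composed into a NodeTree
node_dict = {
    "alpha": gamma({"alpha": "x"}, {"a": 1.0, "scale": 1.0}),  # prior node
    "y": DataLikihoodNode(scipy.stats.expon, y, {"alpha": "scale", "y": "x"}),  # likelihood node
    "event": DataNode("event", event),  # fixed data node
}
tree = NodeTree(node_dict)

# swapping the prior is an upsert of a single node in the graph
new_tree = tree.update({"alpha": gamma({"alpha": "x"}, {"a": 2.0, "scale": 0.5})})
```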

#### Thin wrapper over common libraries

Allowing modification and tweaking is much less valuable if doing so requires learning a complex new API. To keep the process as simple as possible, most SurPyval classes and objects are relatively thin wrappers over classes from libraries like scipy and emcee. SurPyval objects are eager to expose these common libraries to the user, as the sketch below illustrates.
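
As a small illustration (assuming the `distribution` attribute that `node.py` uses throughout this diff), the scipy object behind a node stays directly accessible:

```python
from SurPyval.node import gamma

# a prior node for a rate parameter; hyperparameters are made up
llambda_prior = gamma({"llambda": "x"}, {"a": 2.0, "scale": 1.0})

# the wrapped scipy distribution is exposed directly on the node
print(llambda_prior.distribution)  # scipy.stats.gamma
# node methods rename parameters and then delegate straight to scipy
print(llambda_prior.logpdf(llambda=1.5))
```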

#### Trade-offs

* Raw crunching speed
* Memory usage
* Models often don't exploit conjugacy where it exists
In general, SurPyval is comfortable paying for composability and modifiability with performance. For a lot of tasks, SurPyval won't run as quickly as (say) a custom-written Stan model. However:

For very large data sets or very complicated models, you might be better off using something like Stan.
* For data sets with row counts in the six figures, the slowdown isn't much of a problem
* Any performance loss is asymmetric: some use cases will still be blazing fast
* There are ways of mitigating this (see the performance section of the docs)
36 changes: 25 additions & 11 deletions SurPyval/model/fitmodel.py
@@ -1,6 +1,7 @@
from typing import Dict, Any
import numpy as np

from SurPyval.node import NodeTree
from SurPyval.node import NodeTree, Node
from SurPyval.samplers import EmceeSampler


@@ -29,16 +30,29 @@ def generate_replicates(self, n_replicates: int):
posterior_samples = self.posterior.sample(n_replicates)[:n_replicates]
return [self.node_tree.generate_replicate(posterior_sample) for posterior_sample in posterior_samples]

def predict(self, data_dict: Dict[str, Any]):
fitted_node_tree = NodeTree(self.node_tree.node_dict, data_dict)
def predict(self, node_dict: Dict[str, Node]):
fitted_node_tree = self.node_tree.update(node_dict)
fitted_model = FitModel(fitted_node_tree, self.posterior)
return fitted_model

def sample_survival(self):
def plot_survival_function(surv_node):
start_point = surv_node.distribution.ppf( 0.005, **parsed_n )[0]
end_point = surv_node.distribution.ppf( 0.995, **parsed_n )[0]
x_s = np.linspace( start_point, end_point, 1000 )
vals = [surv_node.distribution.sf( x, **parsed_n )[0] for x in x_s]
from matplotlib import pyplot as plt
plt.plot( x_s, vals )
def survival_function(self, flat_parameter_array):
param_dict = self.node_tree.unflatten_parameter_array(flat_parameter_array)

def surv_function(**kwargs):
return self.node_tree["y"].sf(**{**param_dict, **kwargs})

return surv_function

def sample_survival_function(self, n_samples):
posterior_samples = self.posterior.sample(n_samples)[:n_samples]
return [self.survival_function(x) for x in posterior_samples]

def plot_survival_function(self):
param_dict = self.node_tree.unflatten_parameter_array(self.posterior.maximum_likihood)
start_point = self.node_tree["y"].ppf(0.005, **param_dict)
end_point = self.node_tree["y"].ppf(0.995, **param_dict)
x_s = np.linspace(start_point, end_point, 1000)
surv_function = self.survival_function(self.posterior.maximum_likihood)
vals = surv_function(y=x_s)
from matplotlib import pyplot as plt
plt.plot(x_s, vals)
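
For context, a hedged sketch of how the methods above fit together; `model` stands for any SurPyval `Model`, and the call signatures are taken from this diff rather than from documentation:

```python
fit_model = model.fit(n_walkers=4, burn=500)

# draw survival functions from the posterior; each element is a callable S(y)
surv_fns = fit_model.sample_survival_function(n_samples=100)
value_at_two = surv_fns[0](y=2.0)

# or plot the survival curve at the maximum likelihood point
fit_model.plot_survival_function()
```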
25 changes: 11 additions & 14 deletions SurPyval/model/model.py
@@ -1,5 +1,4 @@
from scipy.optimize import minimize
import emcee as em
import numpy as np

from SurPyval.node.tree import NodeTree
@@ -40,16 +39,8 @@ def fit(self, n_walkers: int = 4, burn: int = 500):
The function has the side effect of populating self.posterior
"""

def generate_starting_points():
max_lik_point = self.maximum_likihood()
return [max_lik_point + np.random.normal(0, 0.01, int(self.node_tree.length())) for x in range(n_walkers)]

ndim = self.node_tree.length()
sampler = em.EnsembleSampler(n_walkers, ndim, self.node_tree.logpdf)
p0 = generate_starting_points()
pos, prob, state = sampler.run_mcmc(p0, burn)

posterior = EmceeSampler(sampler, pos)
posterior = EmceeSampler(self.node_tree.logpdf, self.maximum_likihood(), n_walkers)
posterior.sample(burn)
return FitModel(self.node_tree, posterior)

def maximum_likihood(self):
@@ -64,7 +55,13 @@ def maximum_likihood(self):
array_like
flat array of parameters
"""
neg_lok_lik = lambda *args: -self.node_tree.logpdf(*args)
result = minimize(neg_lok_lik, np.array([1] * self.node_tree.length()))

def neg_log_lik(*args):
log_lik = self.node_tree.logpdf(*args)
# invalid parameter regions must be penalised with +inf, not -inf, since we are minimising
if log_lik is None or np.any(np.isnan(log_lik)):
return np.inf
return -log_lik

result = minimize(neg_log_lik, np.array([1] * self.node_tree.length()), method='Nelder-Mead')
max_lik_point = result["x"]
return max_lik_point
return max_lik_point
6 changes: 5 additions & 1 deletion SurPyval/node/__init__.py
@@ -1,2 +1,6 @@
from SurPyval.node.node import Node, gamma, exponential, gaussian
from SurPyval.node.tree import NodeTree
from SurPyval.node.tree import NodeTree
from SurPyval.node.datanode import DataNode
from SurPyval.node.datalikihoodnode import DataLikihoodNode
from SurPyval.node.transformation import DeterministicNode
from SurPyval.node.parameter import ParameterNode
9 changes: 7 additions & 2 deletions SurPyval/node/datalikihoodnode.py
@@ -1,4 +1,5 @@
import numpy as np
from typing import Dict, Any

from SurPyval.node.node import Node

@@ -13,8 +14,12 @@ class DataLikihoodNode(Node):
DataLikihoodNode automatically does this routing
"""

def __init__(self, distribution, data: Any, parameter_dict: Dict[str, str], constants_dict: Dict[str, Any]=None):
self.data = data
Node.__init__(self, distribution, parameter_dict, constants_dict)

def logpdf(self, **kwargs):
processed_kwargs = self.parse_unflattened_parameters(**kwargs)
likihood_contribution = self.distribution.logpdf(kwargs["y"], **processed_kwargs)
survival_contribution = self.distribution.logsf(kwargs["y"], **processed_kwargs)
likihood_contribution = self.distribution.logpdf(self.data, **processed_kwargs)
survival_contribution = self.distribution.logsf(self.data, **processed_kwargs)
return np.sum(likihood_contribution[kwargs["event"].astype(bool)]) + np.sum(survival_contribution[~kwargs["event"].astype(bool)])
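
The routing above is the standard right-censored survival likelihood: observed events contribute the log-density, censored observations the log-survival probability. A standalone sketch with made-up numbers, using scipy directly rather than the SurPyval API:

```python
import numpy as np
import scipy.stats

y = np.array([1.2, 3.4, 0.7])
event = np.array([True, False, True])  # False = right-censored
scale = 2.0  # assumed exponential scale parameter

log_pdf = scipy.stats.expon.logpdf(y, scale=scale)  # density for observed deaths
log_sf = scipy.stats.expon.logsf(y, scale=scale)    # survival for censored rows
log_lik = np.sum(log_pdf[event]) + np.sum(log_sf[~event])
```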
34 changes: 34 additions & 0 deletions SurPyval/node/datanode.py
@@ -0,0 +1,34 @@
from SurPyval.node import Node


class DataNode(Node):
"""
A Node for fixed sources of data in the model

These nodes don't contribute to the likelihood function at all
Their only purpose is to store data for use in the model

Parameters
----------
name: str
the name the data source has in the model (e.g. event, x)
data: array-like
the data itself, usually pd.DataFrame or np array

"""

def __init__(self, name, data):
self.name = name
self.data = data

def sample(self, size=1, **kwargs):
return None

def logpdf(self, **kwargs):
return 0.0

def pdf(self, **kwargs):
return 1.0

def __str__(self):
return self.name
9 changes: 6 additions & 3 deletions SurPyval/node/node.py
@@ -48,16 +48,19 @@ def __init__(self, distribution: rv_continuous, parameter_dict: Dict[str, str],
def parse_unflattened_parameters(self, **kwargs):
filtered_kwargs = {x: kwargs[x] for x in kwargs if x in self.parameter_names}
renamed_kwargs = {self.parameter_dict[x]: filtered_kwargs[x] for x in filtered_kwargs}
return {**renamed_kwargs, **self.constants_dict}
return {**self.constants_dict, **renamed_kwargs}

def logpdf(self, **kwargs):
return self.distribution.logpdf(**self.parse_unflattened_parameters(**kwargs))

def pdf(self, **kwargs):
return self.distribution.pdf(**self.parse_unflattened_parameters(**kwargs))

def ppf(self, **kwargs):
return self.distribution.ppf(**self.parse_unflattened_parameters(**kwargs))
def sf(self, **kwargs):
return self.distribution.sf(**self.parse_unflattened_parameters(**kwargs))

def ppf(self, q, **kwargs):
return self.distribution.ppf(q, **self.parse_unflattened_parameters(**kwargs))

def sample(self, size=1, **kwargs):
"""
1 change: 1 addition & 0 deletions SurPyval/node/parameter.py
@@ -23,3 +23,4 @@ def sample(self, **kwargs):

def __str__(self):
return self.name

27 changes: 25 additions & 2 deletions SurPyval/node/tree.py
@@ -5,6 +5,7 @@
from SurPyval.node.node import Node
from SurPyval.node.transformation import DeterministicNode
from SurPyval.node.datalikihoodnode import DataLikihoodNode
from SurPyval.node.datanode import DataNode


class NodeTree:
@@ -19,16 +20,38 @@ class NodeTree:
lookup from name to data (e.g. event -> np.array([True, False, True]))
"""

def __init__(self, node_dict: Dict[str, Node], data_dict: Dict[str, Any]):
def __init__(self, node_dict: Dict[str, Node]):
self.parameters: List[ParameterNode] = [x for x in node_dict.values() if type(x) is ParameterNode]
self.transformations: List[DeterministicNode] = [x for x in node_dict.values() if type(x) is DeterministicNode]
self.likihood_nodes: List[DataLikihoodNode] = [x for x in node_dict.values() if type(x) is DataLikihoodNode]

self.data_dict = data_dict
self.data_dict = {name: node.data for name, node in node_dict.items() if type(node) is DataNode}
self.node_dict = node_dict
self.node_names = sorted(node_dict.keys())
self.flat_split_point = self.flattened_parameter_split_points()

def __getitem__(self, item: str):
return self.node_dict[item]

def update(self, updated_node_dict: Dict[str, Node]) -> 'NodeTree':
"""
Upsert nodes in the tree

Nodes that are already in the tree will be updated
The original node_tree isn't modified in the process

Parameters
----------
updated_node_dict: Dict[str, Node]
nodes to update or insert to the tree

Returns
-------
NodeTree
Updated NodeTree with new nodes
"""
return NodeTree({**self.node_dict, **updated_node_dict})

def append_transformations(self, parameter_dict):
"""
Enrich the parameter dictionary with transformed parameters
67 changes: 19 additions & 48 deletions SurPyval/parametric/fitted/exponential.py
@@ -1,52 +1,28 @@
import numpy as np
import scipy.stats

import seaborn as sns
from SurPyval.node import NodeTree, gamma, DataNode, DeterministicNode, ParameterNode, DataLikihoodNode
from SurPyval.model.model import Model

from SurPyval.distributions.gamma import Gamma
from SurPyval.core.sampling import NumpySampler


def create_prior_from_deaths_and_total_observed_time(deaths, total_observed_time):
def create_prior_from_deaths_and_total_observed_time(parameter_name, deaths, total_observed_time):
alpha = deaths
llambda = total_observed_time
return Gamma(alpha, llambda)
return gamma({parameter_name: "x"}, {"a": alpha, "scale": llambda})


def create_prior_from_deaths_and_average_lifetime(deaths, average_lifetime):
def create_prior_from_deaths_and_average_lifetime(parameter_name, deaths, average_lifetime):
alpha = deaths
llambda = average_lifetime * deaths
return Gamma(alpha, llambda)

class DistributionNode:

def __init__(self, name, prior_distribution, posterior_distribution):
self.name = name
self.prior = prior_distribution
self.posterior = posterior_distribution

def sample_prior(self, n_samples = 10000):
return self.prior.sample(n_samples)

def sample_posterior(self, n_samples = 10000):
return self.posterior.sample(n_samples)
return gamma({parameter_name: "x"}, {"a": alpha, "scale": llambda})

class PosteriorPredictiveDistribution(NumpySampler):

def __init__(self, gamma_distribution):
self.distr = gamma_distribution

def sample(self, n_samples):
samples = self.distr.sample(n_samples)
posterior_predictive_samples = np.array(map(lambda x: np.random.exponential(1.0 / x), samples))
return posterior_predictive_samples

def plot(self, n_samples = 10000):
sns.distplot(self.sample(n_samples))

class FittedExponential:
class FittedExponential(Model):
"""
Fit an exponential distribution to the life lengths

Nodes:
* llambda - the main parameter for the exponential distribution
* alpha - the scale of the exponential distribution
* y - predictive distribution for lifetime

Likihood:
@@ -59,21 +35,16 @@ class FittedExponential:
lambda => Total observed time
"""


def __init__(self, prior_dict, y, event):
self.prior_dict = prior_dict
self.d = np.sum(event)
self.y_sum = np.sum(y)
llambda_posterior = Gamma(prior.alpha + self.d, prior.llambda + self.y_sum)

self.constants = {"alpha_0": prior_dict["llambda"].alpha, "llambda_0": prior_dict["llambda"].llambda}
self.nodes = {
"llambda": DistributionNode("llambda", prior_dict["llambda"], llambda_posterior),
"y": DistributionNode("y", PosteriorPredictiveDistribution(prior_dict["llambda"]), PosteriorPredictiveDistribution(llambda_posterior))
def __init__(self, y, event, alpha_prior):
node_dict = {
"alpha": alpha_prior,
"y": DataLikihoodNode(scipy.stats.expon, y, {"alpha": "scale", "y": "x"}),
"event": DataNode("event", event)
}

def fit(self):
pass
node_tree = NodeTree(node_dict)

Model.__init__(self, node_tree)

@staticmethod
def show_plate():
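To close the loop, a hedged end-to-end sketch of the refactored `FittedExponential`; the data and prior settings are invented, and the constructor and helper signatures are assumed from this diff:

```python
import numpy as np

from SurPyval.parametric.fitted.exponential import (
    FittedExponential,
    create_prior_from_deaths_and_average_lifetime,
)

# made-up lifetimes with right-censoring flags
y = np.array([1.2, 3.4, 0.7, 5.1])
event = np.array([True, False, True, True])

# prior encoded as an equivalent number of deaths and an average lifetime
alpha_prior = create_prior_from_deaths_and_average_lifetime("alpha", 3, 2.0)

model = FittedExponential(y, event, alpha_prior)
fit_model = model.fit()
fit_model.plot_survival_function()
```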