Kaplan-Meier estimator of the survival function
Introduction
Survival analysis is a collection of statistical approaches for characterizing 1) the time-dependent changes in the probability of an event of interest and 2) the factors that modify them. While survival analysis began as a method for examining patient mortality, it is now widely used in many other domains, including business (customer churn) and engineering (equipment failure).
In this series, we will put together a starter kit of essential concepts and tools for performing survival analysis. As in our other posts, we will use the well-known IBM Telco customer churn dataset as an example, to illustrate the need to use a diverse range of analytic approaches to generate balanced and confident insights from a set of data.
If you want to try this yourself, click on "Remix" in the upper right corner to get a copy of the notebook in your own workspace. Please remember to change the Python and R runtimes to survival_Python and survival_R, respectively, from this notebook to ensure that you have all the installed packages and can start right away.
First, we will introduce several fundamental concepts of survival analysis.
Survival function
The survival function S(t) describes the probability that the event of interest has not yet happened by a given time t, i.e. that a subject "survives" beyond t.
Kaplan-Meier estimator
As the true survival function cannot be observed directly, it is approximated by the non-parametric Kaplan-Meier (KM) estimator. At each time t at which an event occurs, the estimator calculates the proportion of at-risk subjects that have not yet "succumbed" (to death, malfunction, etc.) out of all at-risk subjects present at that time, and multiplies these proportions together over all event times up to t.
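The KM estimate of the survival function is expressed as:
$$\hat{S}(t) = \prod_{i:\, t_i \le t} \left(1 - \frac{d_i}{n_i}\right)$$
where t_i are the distinct times at which at least one event occurred, d_i is the number of events at t_i, and n_i is the number of subjects still at risk just before t_i.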
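To make the product-limit calculation concrete, here is a minimal pure-Python sketch of the KM estimator. The function name `kaplan_meier` and the toy tenure data are our own illustrative assumptions, not part of the Telco dataset or of any library.

```python
from collections import Counter


def kaplan_meier(durations, events):
    """Return [(t, S(t))] at each distinct event time (product-limit rule).

    durations: observed time for each subject (event time or censoring time)
    events:    1 if the event was observed at that time, 0 if censored
    """
    deaths = Counter(t for t, e in zip(durations, events) if e == 1)
    leaving = Counter(durations)  # every subject leaves the risk set at their own time

    at_risk = len(durations)
    survival = 1.0
    curve = []
    for t in sorted(leaving):
        d = deaths.get(t, 0)
        if d > 0:
            # Multiply by the fraction of at-risk subjects that survive time t
            survival *= 1 - d / at_risk
            curve.append((t, survival))
        # Subjects with an event or censored at t are no longer at risk afterwards
        at_risk -= leaving[t]
    return curve


# Hypothetical toy data: tenure in months, 1 = churned, 0 = still subscribed
tenure = [1, 2, 2, 3, 5, 5, 8, 10]
churned = [1, 1, 0, 1, 0, 1, 0, 0]

km_curve = kaplan_meier(tenure, churned)
for t, s in km_curve:
    print(f"t={t:>2}  S(t)={s:.3f}")
```

Note that censored subjects still count towards the at-risk total up to the time they are censored, which is what lets the estimator use incomplete observations.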
Hazard function
The hazard function h(t) describes the instantaneous rate at which the event of interest occurs at a given time t, given that it has not happened yet. Strictly speaking it is a rate rather than a probability, so it can take values greater than 1.
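The hazard function and the survival function are two views of the same process: the higher the hazard, the faster the survival function declines. Formally, S(t) = exp(-H(t)), where H(t) is the cumulative hazard, i.e. the integral of h(u) from 0 to t.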
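The link between the two functions can be written as S(t) = exp(-∫ h(u) du), integrating from 0 to t. As a rough numerical sketch, the snippet below accumulates the hazard with a midpoint sum and converts it to a survival probability. The two hazard functions and their rates are hypothetical examples chosen for illustration.

```python
import math


def survival_from_hazard(hazard, t_max, dt=0.001):
    """Approximate S(t_max) = exp(-integral of h from 0 to t_max) with a midpoint sum."""
    steps = int(round(t_max / dt))
    cumulative_hazard = sum(hazard((i + 0.5) * dt) * dt for i in range(steps))
    return math.exp(-cumulative_hazard)


def constant_hazard(t):
    return 0.05  # hypothetical: constant churn rate of 0.05 per month


def rising_hazard(t):
    return 0.01 * t  # hypothetical: churn rate that grows with tenure


s_const = survival_from_hazard(constant_hazard, 12)
s_rising = survival_from_hazard(rising_hazard, 12)
print(f"Constant hazard: S(12) = {s_const:.4f}")
print(f"Rising hazard:   S(12) = {s_rising:.4f}")
```

With a constant hazard the survival function is a simple exponential decay, which is a handy sanity check when reading KM curves.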
Proportional hazards
One of the key uses for survival analysis is to identify factors (covariates) that significantly modify the probability of the event of interest, such as gender for cancer patient survival.
An important assumption underlying certain survival analysis approaches, notably the log-rank test (discussed below) and the Cox proportional hazards regression model (a later post), is that the effect of each variable on the hazard of the event of interest remains the same over time.
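In symbols, this can be sketched for a single binary covariate x. The notation is ours and is introduced here only for illustration: h_0(t) is a baseline hazard shared by everyone and β is the effect of the covariate.

```latex
h(t \mid x) = h_0(t)\, e^{\beta x}
\quad\Longrightarrow\quad
\frac{h(t \mid x = 1)}{h(t \mid x = 0)} = e^{\beta}
```

The ratio of the two hazards does not depend on t, which is exactly what "proportional" means: the hazards themselves may rise and fall over time, but one group's hazard is always the same multiple of the other's.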
Imagine that the schematics below plot the number of patients that have died (y-axis) against the length of time since diagnosis (x-axis), with male and female patients indicated by green and red, respectively. On the left, we see what the plot would look like if the gender variable meets the proportional hazards assumption, as being male or female has the same effect on patient mortality across all time points examined. On the right, we see what it would look like if the assumption is not met, where being female has a protective effect at early stages of the disease, but becomes detrimental later on.

It is very important to test whether each of the variables in your dataset meets the proportional hazards assumption, as the results of the analyses cannot be trusted otherwise. Fortunately, there are methods to handle factors whose effects vary with time, such as adding a time interaction term and/or stratification (discussed in later posts).
Log-rank test
Analogous to the t-test, the log-rank test compares the survival functions of two groups, with the null hypothesis that there is no difference between them at any time point.
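The log-rank test should not be relied upon if the proportional hazards assumption is not met, which is most visibly indicated by the survival curves of the two groups crossing over each other. In that situation, early and late differences between the groups can cancel each other out.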
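For intuition, the two-group log-rank statistic can be computed by hand: at every event time, compare the number of events observed in one group with the number expected if both groups shared the same hazard. The sketch below is a simplified pure-Python version with made-up data. The function name `two_group_logrank` is our own, and in practice a tested library implementation should be preferred.

```python
import math


def two_group_logrank(durations_a, events_a, durations_b, events_b):
    """Two-group log-rank test; returns (chi-square statistic, p-value).

    Assumes at least one event time at which both groups are still at risk.
    """
    event_times = sorted(
        {t for t, e in zip(durations_a, events_a) if e}
        | {t for t, e in zip(durations_b, events_b) if e}
    )
    observed_minus_expected = 0.0
    variance = 0.0
    for t in event_times:
        # Subjects still at risk just before time t in each group
        n_a = sum(1 for d in durations_a if d >= t)
        n_b = sum(1 for d in durations_b if d >= t)
        # Events occurring exactly at time t in each group
        d_a = sum(1 for d, e in zip(durations_a, events_a) if d == t and e)
        d_b = sum(1 for d, e in zip(durations_b, events_b) if d == t and e)
        n, d = n_a + n_b, d_a + d_b
        observed_minus_expected += d_a - d * n_a / n
        if n > 1:
            variance += d * (n_a / n) * (n_b / n) * (n - d) / (n - 1)
    chi_square = observed_minus_expected ** 2 / variance
    # Survival function of a chi-square distribution with 1 degree of freedom
    p_value = math.erfc(math.sqrt(chi_square / 2))
    return chi_square, p_value


# Hypothetical toy data: group A churns early, group B mostly stays
tenure_a, churned_a = [1, 2, 3, 4, 5], [1, 1, 1, 1, 1]
tenure_b, churned_b = [6, 7, 8, 9, 10], [1, 1, 0, 1, 0]

chi_square, p_value = two_group_logrank(tenure_a, churned_a, tenure_b, churned_b)
print(f"chi-square = {chi_square:.2f}, p = {p_value:.4f}")
```

Because the statistic sums observed-minus-expected differences over time, a group that does better early and worse later contributes terms of opposite sign, which is why crossing curves undermine the test.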
Import and preprocess data
After importing the data that has been cleaned and straightened out (see previous post), we will first re-encode the outcome column "Churn" in binary form (0/1) as required by the packages that we will use in this post.
As Kaplan-Meier estimation of the survival function cannot characterize the probability of event occurrence in relation to a continuous variable, we need to bin the continuous variable MonthlyCharges into discrete levels, so that we can examine the probability of customer churn at different "tiers" of customer spending. Here, we will use the arulesCBA package, which uses a "supervised" approach to identify bin breaks that are most informative with respect to a class label, which is Churn in our case. We will ignore TotalCharges as it is a product of MonthlyCharges and Tenure, the latter of which is already part of the survival function.
## Import libraries
library(plyr)
library(dplyr)
library(arulesCBA)
## Import data
df <- read.csv("https://github.com/nchelaru/data-prep/raw/master/telco_cleaned_yes_no.csv")
## Encode "Churn" as 0/1
df <- df %>% mutate(Churn = ifelse(Churn == "No", 0, 1))
## Set "Churn" as factor
df$Churn <- as.factor(df$Churn)
## Discretize "MonthlyCharges" with respect to "Churn"/"No Churn" label and assign to new column in dataframe
df$Binned_MonthlyCharges <- discretizeDF.supervised(Churn ~ ., df[, c('MonthlyCharges', 'Churn')], method='mdlp')$MonthlyCharges
## Check the dataframe
head(df)
We want to rename the levels of the binned MonthlyCharges variable, to make them more reader-friendly:
## Check levels of binned variable
unique(df$Binned_MonthlyCharges)
## Rename the levels based on knowledge of min/max monthly charges
df$Binned_MonthlyCharges <- revalue(df$Binned_MonthlyCharges,
                                    c("[-Inf,29.4)"="$0-29.4",
                                      "[29.4,56)"="$29.4-56",
                                      "[56,68.8)"="$56-68.8",
                                      "[68.8,107)"="$68.8-107",
                                      "[107, Inf]" = "$107-118.75"))
## Check the dataframe
head(df)
## Output dataframe to CSV so it can be used by Python runtime
write.csv(df, 'results/processed_telco_df.csv', row.names=FALSE)
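For readers following along in Python only, the same tier labels can be applied once the break points are known. This is a small sketch that hard-codes the break points and labels from the R step above; it does not re-run the supervised discretization itself, and the function name `monthly_charge_tier` is our own.

```python
from bisect import bisect_right

# Break points and labels taken from the R discretization step above
BREAKS = [29.4, 56, 68.8, 107]
LABELS = ["$0-29.4", "$29.4-56", "$56-68.8", "$68.8-107", "$107-118.75"]


def monthly_charge_tier(charge):
    """Map a monthly charge to its tier; bins are closed on the left, e.g. [29.4, 56)."""
    return LABELS[bisect_right(BREAKS, charge)]


for charge in (20.0, 29.4, 70.0, 110.0):
    print(charge, "->", monthly_charge_tier(charge))
```

Using `bisect_right` means a charge that falls exactly on a break point goes into the higher tier, matching the left-closed intervals shown in the R output.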
Workflow in Python
We will first demonstrate the Python implementation of the Kaplan-Meier method for estimating the probability of customer churn.
The lifelines API is very similar to that of scikit-learn, where the KaplanMeierFitter object is first instantiated and then fitted to the data.
## Import libraries
import pandas as pd
import matplotlib.pyplot as plt
from lifelines import KaplanMeierFitter
## Import data (the CSV written out by the R preprocessing step)
df = pd.read_csv('results/processed_telco_df.csv')
## Set figure size
plt.rcParams['figure.figsize'] = [8, 4]
## Instantiate kmf object
kmf = KaplanMeierFitter()
## Fit kmf object to data
kmf.fit(df['Tenure'], event_observed=df['Churn'])
## Plot KM curve
ax = kmf.plot(xlim=(0, 75), ylim=(0, 1))
ax.set_title("Overall survival function")
ax.set_xlabel("Tenure")
ax.set_ylabel("Survival probability")
plt.gcf()
We first created a KM curve for the entire dataset, where the y-axis represents the probability that a customer is still subscribed to the company's services at a given time (x-axis) since they first signed on. We see that the probability of a given customer leaving decreases (the curve flattens) with time, consistent with what we saw previously showing that the probability of churn decreases as the customer tenure increases. The blue band superimposed on the KM curve is the 95% confidence interval.
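To give a feel for where such a confidence band comes from, here is a pure-Python sketch that attaches a 95% interval to each KM estimate using the plain Greenwood variance formula. This is a simplification on our part: as far as we understand, lifelines uses a log-transformed ("exponential Greenwood") variant, so its band will not match these numbers exactly. The function name and toy data are hypothetical.

```python
import math
from collections import Counter


def km_with_greenwood_ci(durations, events, z=1.96):
    """Return [(t, S, lower, upper)] at each event time using plain Greenwood variance."""
    deaths = Counter(t for t, e in zip(durations, events) if e)
    leaving = Counter(durations)

    at_risk = len(durations)
    survival, greenwood_sum = 1.0, 0.0
    rows = []
    for t in sorted(leaving):
        d = deaths.get(t, 0)
        if d > 0:
            survival *= 1 - d / at_risk
            if at_risk > d:
                # Greenwood: Var[S(t)] ~ S(t)^2 * sum of d_i / (n_i * (n_i - d_i))
                greenwood_sum += d / (at_risk * (at_risk - d))
            se = survival * math.sqrt(greenwood_sum)
            lower = max(0.0, survival - z * se)  # clip to the valid range [0, 1]
            upper = min(1.0, survival + z * se)
            rows.append((t, survival, lower, upper))
        at_risk -= leaving[t]
    return rows


# Hypothetical toy data: tenure in months, 1 = churned, 0 = still subscribed
tenure = [1, 2, 2, 3, 5, 5, 8, 10]
churned = [1, 1, 0, 1, 0, 1, 0, 0]

ci_rows = km_with_greenwood_ci(tenure, churned)
for t, s, lo, hi in ci_rows:
    print(f"t={t:>2}  S={s:.3f}  95% CI=({lo:.3f}, {hi:.3f})")
```

The interval widens as fewer subjects remain at risk, which is why the band around a KM curve typically fans out towards the right-hand side of the plot.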
To prevent (or facilitate) the occurrence of the event of interest, we mostly want to know what factors in the observed population modify the survival probability. In our example, this means identifying which subsets of customers, whether in terms of demographics or purchasing behaviour, are more or less likely to stop buying the company's services. To this end, we can plot and compare the KM survival curves for each level within each categorical variable:
## Import libraries
from lifelines.statistics import multivariate_logrank_test
from matplotlib.offsetbox import AnchoredText
## Set colour dictionary for consistent colour coding of KM curves
colours = {'Yes':'g', 'No':'r',
           'Female':'b', 'Male':'y',
           'Month-to-month':'#007f0e', 'Two year':'#c4507c', 'One year':'#feba9e',
           'DSL':'#ad53cd', 'Fiber optic':'#33ccff',
           'Electronic check':'#33cc33', 'Mailed check':'#ff8000',
           'Bank transfer (automatic)':'#9933ff', 'Credit card (automatic)':'#ff66b3',
           '$0-29.4':'#ff3333', '$29.4-56':'#55ff00', '$56-68.8':'#1a8cff',
           '$68.8-107':'#48bc80', '$107-118.75':'#ffff4d'}
## Set up subplot grid
fig, axes = plt.subplots(nrows=6, ncols=3, sharex=True, sharey=True, figsize=(40, 70))
## Plot KM curve for each level of a categorical variable
def categorical_km_curves(feature, t='Tenure', event='Churn', df=df, ax=None):
    for cat in sorted(df[feature].unique(), reverse=True):
        idx = df[feature] == cat
        kmf = KaplanMeierFitter()
        kmf.fit(df[idx][t], event_observed=df[idx][event], label=cat)
        kmf.plot(ax=ax, label=cat, ci_show=True, c=colours[cat])
## Loop over each categorical variable to plot stratified survival curves
col_list = [col for col in df.columns if df[col].dtype == object]
for cat, ax in zip(col_list, axes.flatten()):
    categorical_km_curves(feature=cat, t='Tenure', event='Churn', df=df, ax=ax)
    ax.legend(loc='lower left', prop=dict(size=48))
    ax.set_title(cat, pad=20, fontsize=56)
    ## Log-rank test across all levels of the variable; show the p-value on the subplot
    p = multivariate_logrank_test(df['Tenure'], df[cat], df['Churn'])
    ax.add_artist(AnchoredText(f"p = {p.p_value:.3g}", frameon=False, loc='upper right', prop=dict(size=46)))
    ax.set_xlabel('Tenure', fontsize=40)
    ax.set_ylabel('Survival probability', fontsize=40)
## Format subplots
fig.subplots_adjust(wspace=0.2, hspace=0.2, bottom=0.2)
fig.tight_layout()
plt.gcf()
Now we see some interesting stuff! For each variable, the curve(s) that decline faster towards 0% survival probability represent population subsets that are more likely to stop buying the company's services. For example, consistent with what we saw with factor analysis of mixed data and correlation analysis, customers with month-to-month contracts, fiber optic internet service and/or electronic check payment method are more likely to churn than those without. The Binned_MonthlyCharges group of curves also supports findings from the conditional probability density plot, where customers at the two extremes of monthly fees are less likely to churn than those in the middle of the range, such as the $29.4-56 tier. Conversely, male and female customers have overlapping KM curves, and these two groups are shown to have similar rates of leaving the company.
It must be noted that visual comparisons of the KM curves must be supplemented with statistical testing, here the log-rank test. The number in the upper right corner of each subplot is the p-value of the log-rank test for that group of comparisons. We see that all variables but Gender and PhoneService show significant differences in customer churn. However, remember that log-rank test results are unreliable for covariates where the proportional hazards assumption is not met, most obviously seen in cases where the KM curves intersect, such as for StreamingTV, StreamingMovies and Binned_MonthlyCharges. We will cover this in more detail in future posts.
Workflow in R
The R package survival offers much of the same functionality, which we will quickly demonstrate here:
## Import libraries
library(survival)
library(survminer)
## Import data (the CSV written out by the preprocessing step)
df <- read.csv('results/processed_telco_df.csv')
## Set figure size
options(repr.plot.width=5, repr.plot.height=5)
## Fit KM estimator to data
fit <- survfit(Surv(Tenure, Churn) ~ 1, data=df)
## Plot
ggsurvplot(fit,
           conf.int = TRUE,
           ggtheme = theme_bw(),
           title = "Overall survival function",
           xlim = c(0, 75))
The code below generates the same stratified KM curves as in Python, so I have output them into a PDF for anyone interested:
## Drop columns containing continuous variables and the event variable "Churn"
subset_df <- within(df, rm("Churn", "Tenure", "MonthlyCharges", "TotalCharges"))
## For loop to create list of KM curve graph objects, one for each categorical variable
plot_list <- list()
for (i in colnames(subset_df)) {
  x <- survfit(Surv(Tenure, Churn) ~ get(i), data=df)
  y <- ggsurvplot(x,
                  conf.int = TRUE,
                  pval = TRUE,
                  legend = 'right',
                  ggtheme = theme_bw(),
                  font.main = c(16),
                  font.x = c(12),
                  font.y = c(12),
                  font.legend = c(10)) + ggtitle(i)
  plot_list[[i]] <- y
}
## Create subplots and output to PDF
res <- arrange_ggsurvplots(plot_list, ncol = 1, nrow = 2, print = FALSE)
ggsave("./results/telco_survival_curves.pdf", res)
We can take a closer look at the KM curve for each binned level of MonthlyCharges, again showing that there are two mid-level "tiers" of monthly fees at which customers tend to churn more quickly:
## Closer look at binned "MonthlyCharges"
plot_list$Binned_MonthlyCharges
Parting notes
This brings us to the end of part 1 of our series on survival analysis, which is already offering us interesting insights about our dataset. In the next few posts, we will delve deeper into other statistical and machine learning survival analysis models.
Til next time! :)