Kaplan-Meier estimator of the survival function
Introduction
In many data analysis applications, we are interested in how long it takes before an event occurs, such as patient death in medicine, customer churn in business and equipment failure in engineering. Survival analysis is a well-established method for characterizing the probability of time-dependent events occurring and, most importantly, how various factors modify this probability.
In this series, we will put together a starter kit of essential concepts and tools for performing survival analysis. As survival analysis is commonly used to analyze customer churn patterns, we will use the IBM Telco customer churn dataset. In addition, as we have used this dataset to demonstrate other analyses, such as factor analysis of mixed data, this is a good opportunity to illustrate the need for a diverse range of analytic approaches to generate balanced and confident insights from a dataset.
At its core, survival analysis aims to estimate how much time passes before a particular event happens. The data are collected over a fixed-length observation period for a population of subjects, whether they be patients, customers or machinery, by noting 1) whether the event of interest occurs for each subject, and if so, 2) when it occurs relative to the start of the observation period. These two parameters allow estimation of the survival function, which is one of the two key functions in survival analysis (the other is the hazard function, which will be covered in the next post). The survival function gives, for each time t, the probability that the event of interest has not yet happened by time t. In reality, however, the true survival function of a population cannot be observed. Instead, it is estimated by the non-parametric Kaplan-Meier (KM) estimator, which at each event time calculates the proportion of at-risk subjects that have not yet "succumbed" (to death, malfunction, etc.) and multiplies these proportions together up to time t.
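To make the estimator concrete, here is a minimal sketch of the KM calculation in plain Python. The function name and the toy data are made up for illustration. It assumes the usual convention that a subject who leaves observation without the event (a censored subject) still counts as at risk at their own exit time.

```python
from collections import Counter

def kaplan_meier(durations, events):
    """Return [(t, S(t))] for each distinct time at which an event occurs.

    durations: time to the event, or to the end of observation, per subject
    events:    1 if the event was observed, 0 if the subject was censored
    """
    deaths = Counter(t for t, e in zip(durations, events) if e == 1)
    exits = Counter(durations)      # subjects leaving the risk set at each time
    at_risk = len(durations)
    survival = 1.0
    curve = []
    for t in sorted(exits):
        d = deaths.get(t, 0)
        if d > 0:
            # Proportion of at-risk subjects that survive past time t
            survival *= (at_risk - d) / at_risk
            curve.append((t, survival))
        # Both events and censored subjects leave the risk set after t
        at_risk -= exits[t]
    return curve

# Hypothetical toy data: tenure in months and churn flag for six customers
durations = [2, 3, 3, 5, 8, 8]
events = [1, 1, 0, 1, 0, 0]

for t, s in kaplan_meier(durations, events):
    print(f"t={t}: S(t)={s:.3f}")
```

Note that the censored customers never lower the curve directly; they only shrink the at-risk pool used at later event times.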
The Python package lifelines and the R package survival both provide easy ways to plot the Kaplan-Meier survival curve from a dataset that contains (at least) these two parameters for each observed subject: whether the event of interest has occurred and, if so, when it occurred relative to the start of observation. For the Telco dataset, these correspond to the Churn and Tenure columns, respectively.
If you want to try this yourself, click on "Remix" in the upper right corner to get a copy of the notebook in your own workspace. Please remember to change the Python and R runtimes to survival_Python and survival_R, respectively, from this notebook, so that you have all the required packages installed and can start right away.
Let's get started!
Import and preprocess data
After importing the data that has been cleaned and straightened out (see previous post), we will first re-encode the outcome column "Churn" in binary form (0/1) as required by the packages.
As Kaplan-Meier curves compare discrete groups rather than model event probability against a continuous variable, we need to bin the continuous variable MonthlyCharges into discrete levels, so that we can examine the probability of customer churn at different "tiers" of customer spending. Here, we will use the arulesCBA package, which takes a "supervised" approach to identify the bin breaks that are most informative with respect to a class label, which is Churn in our case. We will ignore TotalCharges, as it is calculated from MonthlyCharges x Tenure, the latter of which is already part of the survival function.
## Import libraries
library(plyr)
library(dplyr)
library(arulesCBA)

## Import data
df <- read.csv("https://github.com/nchelaru/data-prep/raw/master/telco_cleaned_yes_no.csv")

## Encode "Churn" as 0/1
df <- df %>% mutate(Churn = ifelse(Churn == "No", 0, 1))

## Set "Churn" as factor
df$Churn <- as.factor(df$Churn)

## Discretize "MonthlyCharges" with respect to "Churn"/"No Churn" label and assign to new column in dataframe
df$Binned_MonthlyCharges <- discretizeDF.supervised(Churn ~ ., df[, c('MonthlyCharges', 'Churn')], method='mdlp')$MonthlyCharges

## Check the dataframe
head(df)
We want to rename the levels of the binned MonthlyCharges variable to make them more reader-friendly:
## Check levels of binned variable
unique(df$Binned_MonthlyCharges)

## Rename the levels based on knowledge of min/max monthly charges
df$Binned_MonthlyCharges <- revalue(df$Binned_MonthlyCharges,
                                    c("[-Inf,29.4)" = "$0-29.4",
                                      "[29.4,56)" = "$29.4-56",
                                      "[56,68.8)" = "$56-68.8",
                                      "[68.8,107)" = "$68.8-107",
                                      "[107, Inf]" = "$107-118.75"))

## Check the dataframe
head(df)

## Output dataframe to CSV so it can be used by Python runtime
write.csv(df, 'results/processed_telco_df.csv', row.names=FALSE)
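For readers working only in Python, the same tiers can be approximated with pandas once the break points are known. The sketch below is only an illustration. It hard-codes the break points and labels from the renaming step rather than learning them from the data, and the sample charges are made up.

```python
import pandas as pd

# Break points copied from the R discretization output (not learned here)
bins = [float("-inf"), 29.4, 56, 68.8, 107, float("inf")]
labels = ["$0-29.4", "$29.4-56", "$56-68.8", "$68.8-107", "$107-118.75"]

# Hypothetical monthly charges, for illustration only
charges = pd.Series([19.85, 29.4, 55.99, 70.7, 107.0, 118.75])

# right=False gives left-closed intervals [a, b), matching the R level names
tiers = pd.cut(charges, bins=bins, labels=labels, right=False)
print(tiers.tolist())
```

Using left-closed intervals matters at the boundaries: a charge of exactly 29.4 falls into the "$29.4-56" tier, as it does in the R output.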
Workflow in Python
The lifelines API is very similar to that of scikit-learn, where the KaplanMeierFitter object is first instantiated and then fitted to the data.
## Import libraries
import pandas as pd
import matplotlib.pyplot as plt
from lifelines import KaplanMeierFitter

## Import data written out by the R runtime
df = pd.read_csv('results/processed_telco_df.csv')

## Set figure size
plt.rcParams['figure.figsize'] = [8, 4]

## Instantiate kmf object
kmf = KaplanMeierFitter()

## Fit kmf object to data
kmf.fit(df['Tenure'], event_observed=df['Churn'])

## Plot KM curve and set axis limits on the returned axes
ax = kmf.plot()
ax.set_xlim(0, 75)
ax.set_ylim(0, 1)
ax.set_title("Overall survival function")
ax.set_xlabel("Tenure")
ax.set_ylabel("Survival probability")

plt.gcf()
We first created a KM curve for the entire dataset, where the y-axis represents the probability that a customer is still subscribed to the company's services at a given time (x-axis) since they first signed on. We see that the probability of a given customer leaving decreases (the curve flattens) with time, consistent with what we saw previously showing that the probability of churn decreases as the customer tenure increases. The coloured band superimposed on the KM curve is the 95% confidence interval.
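To give a feel for where such a band comes from, here is a minimal sketch of a 95% interval based on Greenwood's variance formula, in plain Python on made-up data. The function name and toy numbers are ours. lifelines may compute its band with a transformed variant of this formula, so these numbers need not match its output exactly.

```python
import math
from collections import Counter

def km_with_greenwood(durations, events, z=1.96):
    """Return [(t, S(t), lower, upper)] at each distinct event time.

    Uses the plain Greenwood standard error, with the interval clipped
    to [0, 1]; z=1.96 corresponds to a 95% interval.
    """
    deaths = Counter(t for t, e in zip(durations, events) if e == 1)
    exits = Counter(durations)
    at_risk = len(durations)
    survival, var_sum = 1.0, 0.0
    rows = []
    for t in sorted(exits):
        d = deaths.get(t, 0)
        if d > 0:
            survival *= (at_risk - d) / at_risk
            if d < at_risk:
                # Greenwood: accumulate d / (n * (n - d)) over event times
                var_sum += d / (at_risk * (at_risk - d))
                se = survival * math.sqrt(var_sum)
            else:
                se = 0.0  # everyone still at risk had the event, so S(t) = 0
            rows.append((t, survival,
                         max(0.0, survival - z * se),
                         min(1.0, survival + z * se)))
        at_risk -= exits[t]
    return rows

# Hypothetical toy data: tenure in months and churn flag for six customers
durations = [2, 3, 3, 5, 8, 8]
events = [1, 1, 0, 1, 0, 0]

for t, s, lo, hi in km_with_greenwood(durations, events):
    print(f"t={t}: S(t)={s:.3f}  95% CI=({lo:.3f}, {hi:.3f})")
```

With only six subjects the interval is very wide, which is why the band on a real KM curve tends to widen towards the right, where few subjects remain at risk.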
To prevent (or facilitate) the occurrence of the event of interest, we mostly want to know which factors in the observed population modify the survival probability. In our example, this means identifying which subsets of customers, whether in terms of demographics or purchasing behaviour, are more or less likely to stop buying the company's services. To this end, we can plot and compare the KM survival curves for each level within each categorical variable:
## Import libraries
from lifelines.statistics import multivariate_logrank_test
from matplotlib.offsetbox import AnchoredText

## Set colour dictionary for consistent colour coding of KM curves
colours = {'Yes':'g', 'No':'r',
           'Female':'b', 'Male':'y',
           'Month-to-month':'#007f0e', 'Two year':'#c4507c', 'One year':'#feba9e',
           'DSL':'#ad53cd', 'Fiber optic':'#33ccff',
           'Electronic check':'#33cc33', 'Mailed check':'#ff8000',
           'Bank transfer (automatic)':'#9933ff', 'Credit card (automatic)':'#ff66b3',
           '$0-29.4':'#ff3333', '$29.4-56':'#55ff00', '$56-68.8':'#1a8cff',
           '$68.8-107':'#48bc80', '$107-118.75':'#ffff4d'}

## Set up subplot grid
fig, axes = plt.subplots(nrows=6, ncols=3, sharex=True, sharey=True, figsize=(40, 70))

## Plot KM curve for each level of a categorical variable
def categorical_km_curves(feature, t='Tenure', event='Churn', df=df, ax=None):
    for cat in sorted(df[feature].unique(), reverse=True):
        idx = df[feature] == cat
        kmf = KaplanMeierFitter()
        kmf.fit(df[idx][t], event_observed=df[idx][event], label=cat)
        kmf.plot(ax=ax, label=cat, ci_show=True, c=colours[cat])

## Loop over each categorical variable to plot stratified survival curves
col_list = [col for col in df.columns if df[col].dtype == object]

for cat, ax in zip(col_list, axes.flatten()):
    categorical_km_curves(feature=cat, t='Tenure', event='Churn', df=df, ax=ax)
    ax.legend(loc='lower left', prop=dict(size=48))
    ax.set_title(cat, pad=20, fontsize=56)
    ## Log-rank test across all levels of this variable
    p = multivariate_logrank_test(df['Tenure'], df[cat], df['Churn'])
    ## Show the p-value as formatted text in the corner of the subplot
    ax.add_artist(AnchoredText(f"p = {p.p_value:.3g}", frameon=False,
                               loc='upper right', prop=dict(size=46)))
    ax.set_xlabel('Tenure', fontsize=40)
    ax.set_ylabel('Survival probability', fontsize=40)

## Format subplots
fig.subplots_adjust(wspace=0.2, hspace=0.2, bottom=0.2)
fig.tight_layout()

plt.gcf()
Now we see some interesting stuff! For each variable, the curve(s) that decline faster towards 0% survival probability represent population subsets that are more likely to stop buying the company's services. For example, consistent with what we saw with factor analysis of mixed data and correlation analysis, customers with month-to-month contracts, fiber optic internet service and/or the electronic check payment method are more likely to churn than those without. The Binned_MonthlyCharges group of curves also supports findings from the conditional probability density plot, where customers at the two extremes of monthly fees are less likely to churn than those in the middle of the range, such as the $29.4-56 tier. Conversely, male and female customers have overlapping KM curves, indicating that these two groups leave the company at similar rates.
Note that visual comparisons of the KM curves must be supplemented with statistical testing, which is the log-rank test in this case. The number at the top right corner of each subplot is the p-value of the log-rank test for that group of comparisons. We see that all variables but Gender and PhoneService show significant differences in customer churn. However, the log-rank test result can be unreliable in cases where KM curves intersect, such as for StreamingTV, StreamingMovies and Binned_MonthlyCharges. We will cover this in more detail in future posts.
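For intuition about what the log-rank test computes, here is a minimal two-group sketch in plain Python. It is not the lifelines implementation, which handles any number of groups. The function name and the toy data are made up for illustration. At each event time it compares the number of events observed in one group with the number expected if both groups shared the same survival curve.

```python
import math

def logrank_test(durations_a, events_a, durations_b, events_b):
    """Two-group log-rank test; returns (chi-square statistic, p-value)."""
    all_durations = list(durations_a) + list(durations_b)
    all_events = list(events_a) + list(events_b)
    event_times = sorted({t for t, e in zip(all_durations, all_events) if e == 1})

    obs_minus_exp = 0.0
    variance = 0.0
    for t in event_times:
        # Subjects still at risk in each group just before time t
        n_a = sum(1 for d in durations_a if d >= t)
        n_b = sum(1 for d in durations_b if d >= t)
        # Events occurring at time t in group A and in both groups combined
        d_a = sum(1 for d, e in zip(durations_a, events_a) if d == t and e == 1)
        d_b = sum(1 for d, e in zip(durations_b, events_b) if d == t and e == 1)
        n, d = n_a + n_b, d_a + d_b
        # Observed minus expected events in group A under the null hypothesis
        obs_minus_exp += d_a - d * n_a / n
        if n > 1:
            variance += d * (n_a / n) * (n_b / n) * (n - d) / (n - 1)

    chi2 = obs_minus_exp ** 2 / variance
    # Survival function of a chi-square variable with 1 degree of freedom
    p_value = math.erfc(math.sqrt(chi2 / 2))
    return chi2, p_value

# Hypothetical toy data: group A churns early, group B churns late
group_a = [1, 2, 3, 4, 5]
group_b = [6, 7, 8, 9, 10]
all_churned = [1, 1, 1, 1, 1]

chi2, p = logrank_test(group_a, all_churned, group_b, all_churned)
print(f"chi2 = {chi2:.2f}, p = {p:.4f}")
```

Because the statistic sums observed-minus-expected differences over time, differences of opposite sign can cancel out when curves cross, which is one reason the test is less informative for intersecting curves.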
Workflow in R
The R package survival offers much of the same functionality, which we will quickly demonstrate here:
## Import libraries
library(survival)
library(survminer)

## Import data
df <- read.csv('results/processed_telco_df.csv')

## Set figure size
options(repr.plot.width=5, repr.plot.height=5)

## Fit KM estimator to data
fit <- survfit(Surv(Tenure, Churn) ~ 1, data=df)

## Plot
ggsurvplot(fit,
           data = df,
           conf.int = TRUE,
           ggtheme = theme_bw(),
           title = "Overall survival function",
           xlim = c(0, 75))
The code below generates the same stratified KM curves as in Python, so I have output them into a PDF for anyone interested:
## Drop columns containing continuous variables and the event variable "Churn"
subset_df <- within(df, rm("Churn", "Tenure", "MonthlyCharges", "TotalCharges"))

## For loop to create list of KM curve graph objects, one for each categorical variable
plot_list <- list()

for (i in colnames(subset_df)) {
  x <- survfit(Surv(Tenure, Churn) ~ get(i), data=df)
  y <- ggsurvplot(x,
                  data = df,
                  conf.int = TRUE,
                  pval = TRUE,
                  legend = 'right',
                  ggtheme = theme_bw(),
                  font.main = c(16),
                  font.x = c(12),
                  font.y = c(12),
                  font.legend = c(10)) + ggtitle(i)
  plot_list[[i]] <- y
}

## Create subplots and output to PDF
res <- arrange_ggsurvplots(plot_list, ncol = 1, nrow = 2, print = FALSE)
ggsave("./results/telco_survival_curves.pdf", res)
We can take a closer look at the KM curve for each binned level of MonthlyCharges, again showing that there are two mid-level "tiers" of monthly fees at which customers tend to churn more quickly:
## Closer look at binned "MonthlyCharges"
plot_list$Binned_MonthlyCharges
Parting notes
This brings us to the end of part 1 of our series on survival analysis, which is already offering us interesting insights about our dataset. In the next few posts, we will delve deeper into other statistical and machine learning survival analysis models.
Til next time! :)