Plotting Kaplan-Meier survival curves - Part 2
Introduction
As mentioned in the previous notebook, when analyzing time-to-event data we are almost always interested in identifying factors (covariates) that significantly modify the probability of the event of interest. In our example, this means identifying which subsets of customers, whether in terms of demographics or purchasing behaviour, are more or less likely to stop buying the company's services at a certain point. One way to do this is to plot and examine KM survival curves stratified by each variable.
If you want to try this yourself, click on "Remix" in the upper right corner to get a copy of the notebook in your own workspace. Please remember to change the Python and R runtimes to survival_Python and survival_R, respectively, from this notebook to ensure that you have all the installed packages and can start right away.
Import data and discretize continuous variable
As Kaplan-Meier estimation of the survival function cannot characterize the probability of event occurrence in relation to a continuous variable, we need to bin the continuous variable MonthlyCharges into discrete levels, so that we can examine the probability of customer churn at different "tiers" of customer spending. Here, we will use the arulesCBA package, which uses a supervised approach to identify bin breaks that are most informative with respect to a class label, which is Churn in our case. We will ignore TotalCharges as it is a product of MonthlyCharges and Tenure, the latter of which is already part of the survival function.
## Import libraries
library(plyr)
library(dplyr)
library(arulesCBA)

## Import data
df <- read.csv("https://github.com/nchelaru/data-prep/raw/master/telco_cleaned_yes_no.csv")

## Encode "Churn" as 0/1
df <- df %>% mutate(Churn = ifelse(Churn == "No", 0, 1))

## Set "Churn" as factor
df$Churn <- as.factor(df$Churn)

## Discretize "MonthlyCharges" with respect to "Churn"/"No Churn" label and assign to new column in dataframe
df$Binned_MonthlyCharges <- discretizeDF.supervised(Churn ~ .,
                                                    df[, c('MonthlyCharges', 'Churn')],
                                                    method = 'mdlp')$MonthlyCharges

## Check the dataframe
head(df)
Let's see what the new bins look like:
## Import library
library(ggplot2)

## Count customers that churned or stayed for each tier of monthly fee
summary_df <- data.frame(t(table(df$Binned_MonthlyCharges, df$Churn)))

## Plot stacked bars: Var2 holds the monthly charge tier, Var1 the churn label
ggplot(summary_df, aes(x = Var2, y = Freq, fill = Var1)) +
  geom_col() +
  xlab("Binned monthly charges") +
  ylab("No. customers") +
  labs(fill = "Churn")
Looks like there are differences in the propensity to churn between tiers of monthly fees. Consistent with what we saw in the conditional probability density plots, customers paying $29.4-56/month or $68.8-107/month appear to be more likely to leave the company.
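Because the stacked bars show raw counts, it can help to also look at the churn rate within each tier. The snippet below is a minimal sketch in plain Python; the function name `churn_rate_by_tier` and the toy records are invented for illustration and are not taken from the Telco data.

```python
from collections import defaultdict

def churn_rate_by_tier(tiers, churned):
    """Return {tier: fraction of customers in that tier who churned}."""
    totals = defaultdict(int)   # customers per tier
    events = defaultdict(int)   # churned customers per tier
    for tier, flag in zip(tiers, churned):
        totals[tier] += 1
        events[tier] += int(flag)
    return {tier: events[tier] / totals[tier] for tier in totals}

# Toy records, invented for illustration only (not the Telco data)
toy_tiers = ["low", "low", "mid", "mid", "mid", "mid", "high", "high"]
toy_churn = [0, 0, 1, 1, 0, 1, 0, 1]

rates = churn_rate_by_tier(toy_tiers, toy_churn)
print(rates)
```

Comparing rates rather than counts avoids mistaking a large tier for a high-churn tier.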
This gives us some inkling that the discretization does result in informative binning. So, we want to rename the levels of the binned MonthlyCharges variable to make them more reader-friendly:
## Check levels of binned variable
unique(df$Binned_MonthlyCharges)
## Rename the levels based on knowledge of min/max monthly charges
df$Binned_MonthlyCharges <- revalue(df$Binned_MonthlyCharges,
                                    c("[-Inf,29.4)" = "$0-29.4",
                                      "[29.4,56)"   = "$29.4-56",
                                      "[56,68.8)"   = "$56-68.8",
                                      "[68.8,107)"  = "$68.8-107",
                                      "[107, Inf]"  = "$107-118.75"))

## Output dataframe to CSV so it can be used by Python runtime
write.csv(df, './results/processed_telco_df.csv', row.names = FALSE)
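For readers working only in Python, the same tiers can be assigned once the cut points are known. This is a rough sketch that reuses the break points and labels from the R output above; it does not rerun the supervised MDLP discretization, and the names `BREAKS`, `LABELS` and `monthly_charge_tier` are hypothetical.

```python
from bisect import bisect_right

# Cut points and labels copied from the R step above
BREAKS = [29.4, 56, 68.8, 107]
LABELS = ["$0-29.4", "$29.4-56", "$56-68.8", "$68.8-107", "$107-118.75"]

def monthly_charge_tier(charge):
    """Map a monthly charge to its tier; intervals are closed on the left,
    e.g. [29.4, 56), matching the level names produced in R."""
    return LABELS[bisect_right(BREAKS, charge)]

for charge in (20.0, 29.4, 70.35, 118.75):
    print(charge, "->", monthly_charge_tier(charge))
```

`bisect_right` counts how many break points are less than or equal to the charge, which gives left-closed intervals without any explicit comparisons.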
Plotting stratified KM survival curves
Workflow in Python
## Import libraries
import pandas as pd
from lifelines import KaplanMeierFitter
import matplotlib.pyplot as plt
from lifelines.statistics import multivariate_logrank_test
from matplotlib.offsetbox import AnchoredText

## Import data (file written by the R runtime above; adjust the path to your workspace)
df = pd.read_csv('processed_telco_df.csv')

## Set colour dictionary for consistent colour coding of KM curves
colours = {'Yes':'g', 'No':'r',
           'Female':'b', 'Male':'y',
           'Month-to-month':'#007f0e', 'Two year':'#c4507c', 'One year':'#feba9e',
           'DSL':'#ad53cd', 'Fiber optic':'#33ccff',
           'Electronic check':'#33cc33', 'Mailed check':'#ff8000',
           'Bank transfer (automatic)':'#9933ff', 'Credit card (automatic)':'#ff66b3',
           '$0-29.4':'#ff3333', '$29.4-56':'#55ff00', '$56-68.8':'#1a8cff',
           '$68.8-107':'#48bc80', '$107-118.75':'#ffff4d'}

## Set up subplot grid
fig, axes = plt.subplots(nrows=6, ncols=3, sharex=True, sharey=True, figsize=(40, 70))

## Plot KM curve for each level of a categorical variable
def km_curves(feature, t='Tenure', event='Churn', df=df, ax=None):
    for cat in sorted(df[feature].unique(), reverse=True):
        idx = df[feature] == cat
        kmf = KaplanMeierFitter()
        kmf.fit(df[idx][t], event_observed=df[idx][event], label=cat)
        kmf.plot(ax=ax, label=cat, ci_show=True, c=colours[cat])

## Loop over each categorical variable to plot stratified survival curves and compare them by the log-rank test
col_list = [col for col in df.columns if df[col].dtype == object]

for cat, ax in zip(col_list, axes.flatten()):
    km_curves(feature=cat, t='Tenure', event='Churn', df=df, ax=ax)
    ax.legend(loc='lower left', prop=dict(size=48))
    ax.set_title(cat, pad=20, fontsize=56)
    p = multivariate_logrank_test(df['Tenure'], df[cat], df['Churn'])
    ## Show the log-rank p-value as text in the upper right corner
    ax.add_artist(AnchoredText(f'p = {p.p_value:.3g}', frameon=False,
                               loc='upper right', prop=dict(size=46)))
    ax.set_xlabel('Tenure (months)', fontsize=40)
    ax.set_ylabel('Survival probability', fontsize=40)

## Format subplots
fig.subplots_adjust(wspace=0.2, hspace=0.2, bottom=0.2)
fig.tight_layout()
plt.gcf()
This type of "shotgun" approach helps us quickly pick out which variables may have a particular impact on the rate of customer churn. For each variable, the curves that decline faster towards 0% survival probability represent population subsets that are more likely to stop buying the company's services. For example, consistent with what we saw with factor analysis of mixed data and correlation analysis, customers with month-to-month contracts, fiber optic internet service and/or the electronic check payment method are more likely to churn than those without. The Binned_MonthlyCharges group of curves also supports findings from the conditional probability density plot, where customers at the two extremes of monthly fees are less likely to churn than those in the middle of the range. In contrast, male and female customers have overlapping KM curves, indicating that these two groups have similar rates of leaving the company.
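To make concrete what each curve represents, here is a minimal from-scratch sketch of the Kaplan-Meier (product-limit) estimate, S(t) = ∏ (1 - d_i / n_i) over event times t_i ≤ t, where d_i is the number of events at t_i and n_i the number still at risk. It is not how lifelines implements the estimator; the function name `kaplan_meier` and the toy data are invented for illustration.

```python
def kaplan_meier(durations, events):
    """Return [(t, S(t))] at each distinct time where an event occurred.

    durations: observed time for each subject
    events:    1 if the event (churn) was observed, 0 if censored
    """
    data = sorted(zip(durations, events))
    n_at_risk = len(data)
    survival = 1.0
    curve = []
    i = 0
    while i < len(data):
        t = data[i][0]
        deaths = 0    # events observed exactly at time t
        removed = 0   # subjects leaving the risk set at time t
        while i < len(data) and data[i][0] == t:
            deaths += data[i][1]
            removed += 1
            i += 1
        if deaths > 0:
            survival *= 1 - deaths / n_at_risk
            curve.append((t, survival))
        n_at_risk -= removed
    return curve

# Toy data, invented for illustration: tenure in months and churn flag
toy_durations = [1, 2, 2, 3, 5]
toy_events = [1, 0, 1, 1, 0]
toy_curve = kaplan_meier(toy_durations, toy_events)
print(toy_curve)
```

Censored customers lower the number at risk without producing a drop in the curve, which is why a group with many censored observations can still show a high survival probability.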
Visual comparison of the KM curves must be supplemented with statistical testing, here the log-rank test. Analogous to the t-test, the log-rank test compares the survival functions of two or more groups under the null hypothesis that there is no difference between them at any time point. The number in the upper right corner of each subplot is the p-value of the log-rank test for that group of comparisons. We see that all variables but Gender and PhoneService have significant differences in customer churn. While the log-rank test is non-parametric, and thus does not make assumptions regarding the distribution of time-to-event, it is most powerful when the proportional hazards assumption holds. Results of the log-rank test can therefore be misleading for variables showing time-dependent effects. This is most obviously seen in cases where the KM curves intersect, such as for StreamingTV, StreamingMovies and MultipleLines. We will cover this in more detail in future notebooks.
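As a rough illustration of what the log-rank test computes, the sketch below handles the two-group case in plain Python. At each event time it compares the observed number of events in group A with the number expected if both groups shared the same hazard, then refers the resulting statistic to a chi-square distribution with one degree of freedom. This is a simplified stand-in for illustration, not the lifelines implementation; the function name `logrank_two_groups` and the toy data are invented.

```python
import math

def logrank_two_groups(durations_a, events_a, durations_b, events_b):
    """Two-group log-rank test. Returns (chi-square statistic, p-value)."""
    records = [(t, e, 0) for t, e in zip(durations_a, events_a)]
    records += [(t, e, 1) for t, e in zip(durations_b, events_b)]
    event_times = sorted({t for t, e, _ in records if e})

    observed_a = expected_a = variance = 0.0
    for t in event_times:
        at_risk = [r for r in records if r[0] >= t]
        n = len(at_risk)                                   # at risk overall
        n_a = sum(1 for r in at_risk if r[2] == 0)         # at risk in group A
        d = sum(1 for r in at_risk if r[0] == t and r[1])  # events at t overall
        d_a = sum(1 for r in at_risk
                  if r[0] == t and r[1] and r[2] == 0)     # events at t in group A
        observed_a += d_a
        expected_a += d * n_a / n
        if n > 1:
            # Hypergeometric variance of the event count in group A
            variance += d * (n_a / n) * (1 - n_a / n) * (n - d) / (n - 1)

    if variance == 0:
        return 0.0, 1.0
    statistic = (observed_a - expected_a) ** 2 / variance
    # Upper tail of a chi-square distribution with 1 degree of freedom
    p_value = math.erfc(math.sqrt(statistic / 2))
    return statistic, p_value

# Toy data, invented for illustration: group A churns early, group B late
stat, p = logrank_two_groups([1, 2, 3, 4, 5], [1, 1, 1, 1, 1],
                             [10, 11, 12, 13, 14], [1, 1, 1, 1, 1])
print(stat, p)
```

Every event time carries the same weight in this statistic, so early and late differences in opposite directions can cancel out, which is one way to see why crossing curves are a warning sign.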
Workflow in R
The R package survival offers many of the same functionalities, which we will quickly demonstrate here. The code below generates the same stratified KM curves as in Python, so we will just take a closer look at the KM curve for each binned level of MonthlyCharges:
## Import libraries
library(survival)
library(survminer)

## Import data (file written earlier in this notebook; adjust the path to your workspace)
df <- read.csv("processed_telco_df.csv")

## Drop columns containing continuous variables and the event variable "Churn"
subset_df <- within(df, rm("Churn", "Tenure", "MonthlyCharges", "TotalCharges"))

## For loop to create list of KM curve graph objects, one for each categorical variable
plot_list <- list()

for (i in colnames(subset_df)) {
  x <- survfit(Surv(Tenure, Churn) ~ get(i), data = df)
  y <- ggsurvplot(x, conf.int = TRUE, pval = TRUE, legend = 'right',
                  ggtheme = theme_bw(), font.main = c(16), font.x = c(12),
                  font.y = c(12), font.legend = c(10)) + ggtitle(i)
  plot_list[[i]] <- y
}

## Closer look at binned "MonthlyCharges"
plot_list$Binned_MonthlyCharges
Parting notes
Congratulations on reaching this point, where you have learned the very fundamentals of performing survival analysis! By now, you can already derive some very interesting insights from time-to-event datasets.
You can continue on to the "What You Should Know" section to learn about more advanced topics in survival regression.