🔗 Live Project: retention-ai.onrender.com
Retention-AI is a machine learning project for early churn detection and customer retention. It processes the Telco Customer Churn dataset, trains an XGBoost model, and serves predictions via a FastAPI interface deployed on Render. The system provides real-time churn analysis with feature explanations, helping businesses retain at-risk customers through data-driven insights.
Key Features
- Early Churn Prediction: Detects at-risk customers before they leave.
- Data-Driven Insights: Uses explainable AI to highlight churn risk factors.
- Scalable Architecture: Built with FastAPI and deployed on Render for real-time predictions.
- End-to-End Workflow: From data preprocessing to model deployment.
The model is deployed and ready for you to test. Visit our Live Demo to see Retention-AI in action!
Watch a quick video showcasing how to use Retention-AI:
The Telco Customer Churn dataset was selected as it contains customer demographics, account details, and service usage information, making it suitable for predicting churn.
- Identified missing or incorrect values in the `TotalCharges` column.
- Converted it to numeric format (`float64`), replacing non-numeric values with `NaN`.
- Imputed missing values with `0`.
- Binary Encoding: Converted categorical variables with two unique values (e.g., `gender`, `Partner`, `Churn`) into binary (`0` and `1`).
- One-Hot Encoding: Applied one-hot encoding for categorical columns with more than two values (`Contract`, `InternetService`, `PaymentMethod`), dropping the first category to avoid multicollinearity.
- Replaced `"No phone service"` and `"No internet service"` with `"No"` in relevant columns (`MultipleLines`, `OnlineSecurity`, `OnlineBackup`, etc.).
- Mapped `"No"` to `0` and `"Yes"` to `1`.
- Ensured one-hot encoded columns were correctly cast to integers (`int`).
- Removed `customerID`, as it is an identifier with no predictive value.
- The cleaned dataset was saved as `telco_cleaned.csv` in the processed data folder for model training.
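The steps above can be sketched with pandas as follows. This is a simplified version: `data-processing.ipynb` is the authoritative source, the helper name `clean_telco` is illustrative, and only a subset of the binary columns is mapped here.

```python
import pandas as pd

def clean_telco(df: pd.DataFrame) -> pd.DataFrame:
    """Apply the cleaning steps described above to a raw Telco frame."""
    df = df.copy()
    # TotalCharges -> float64; non-numeric entries become NaN, then imputed with 0
    df["TotalCharges"] = pd.to_numeric(df["TotalCharges"], errors="coerce").fillna(0)
    # Collapse "No phone service" / "No internet service" into "No", then map Yes/No -> 1/0
    for col in ["MultipleLines", "OnlineSecurity", "OnlineBackup"]:
        df[col] = (
            df[col]
            .replace({"No phone service": "No", "No internet service": "No"})
            .map({"Yes": 1, "No": 0})
        )
    # Binary-encode two-valued columns (the 0/1 assignment for gender is arbitrary)
    mapping = {"Yes": 1, "No": 0, "Male": 1, "Female": 0}
    for col in ["gender", "Partner", "Churn"]:
        df[col] = df[col].map(mapping)
    # One-hot encode multi-valued columns, dropping the first level, cast to int
    df = pd.get_dummies(
        df,
        columns=["Contract", "InternetService", "PaymentMethod"],
        drop_first=True,
        dtype=int,
    )
    # Drop the identifier, which has no predictive value
    return df.drop(columns=["customerID"])
```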
Since this dataset isn’t time-series, the model can’t track gradual disengagement, limiting its ability to predict churn before clear signs appear.
📂 For full preprocessing details, check data-processing.ipynb.
The model is trained using the XGBoost classifier, a powerful gradient boosting algorithm well-suited for imbalanced classification problems like churn prediction. The dataset is split into:
- Training Set (80%) – Used to train the model.
- Testing Set (16%) – Used to evaluate the model.
- Demo Set (4%) – Saved separately for frontend testing.
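The 80/16/4 split can be done in two passes with scikit-learn's `train_test_split` (a sketch on synthetic data; the `random_state` and `stratify` choices are assumptions, not taken from the training script):

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the cleaned Telco features and labels
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 5))
y = rng.integers(0, 2, size=1000)

# First carve off 80% for training; the remaining 20% holdout is then
# split 80/20 again, giving 16% test and 4% demo overall.
X_train, X_hold, y_train, y_hold = train_test_split(
    X, y, test_size=0.20, random_state=42, stratify=y
)
X_test, X_demo, y_test, y_demo = train_test_split(
    X_hold, y_hold, test_size=0.20, random_state=42, stratify=y_hold
)
```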
The model was trained with the following hyperparameters:
- `max_depth=6`: Limits tree depth to prevent overfitting.
- `learning_rate=0.01`: Ensures gradual convergence.
- `n_estimators=300`: Uses 300 boosting rounds.
- `scale_pos_weight=0.8 * (non-churn / churn ratio)`: Adjusts for the imbalance in churned vs. non-churned customers.
The model achieves 79.33% accuracy in predicting customer churn.
```
              precision    recall  f1-score   support

           0       0.91      0.79      0.85       823
           1       0.59      0.79      0.67       304

    accuracy                           0.79      1127
   macro avg       0.75      0.79      0.76      1127
weighted avg       0.82      0.79      0.80      1127
```
This classification report provides insights into the model's performance:
- Precision: Of all the customers predicted as churned, precision tells us the percentage that actually churned. A precision of 0.59 for churn means that when the model predicts a customer will churn, it is correct 59% of the time.
- Recall: Of all the customers who actually churned, recall tells us the percentage that the model correctly predicted. A recall of 0.79 for churn means that the model correctly identifies 79% of customers who will churn.
- F1-Score: The F1-score is the harmonic mean of precision and recall. It provides a single score that balances both concerns.
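For reference, all three metrics reduce to simple ratios over confusion-matrix counts. A minimal sketch (the counts in the test are illustrative, not the model's):

```python
def prf(tp: int, fp: int, fn: int) -> tuple:
    """Precision, recall, and F1 from confusion-matrix counts."""
    precision = tp / (tp + fp)  # of predicted churners, how many actually churned
    recall = tp / (tp + fn)     # of actual churners, how many were caught
    f1 = 2 * precision * recall / (precision + recall)  # harmonic mean of the two
    return precision, recall, f1
```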
In this context:
- High recall for churned customers (0.79) means the model correctly identifies most at-risk customers, which is crucial for proactive retention strategies.
- Precision for churned customers (0.59) indicates some false positives but is acceptable given the focus on minimizing churn loss.
The model assigns importance scores to all input features based on how much they contribute to churn prediction. After training, the most impactful features identified were:
The trained model is saved by `model_training.py` as `models/churn_model.json`.
This FastAPI application serves as a lightweight backend for churn prediction, allowing users to input customer details and receive predictions on whether they are likely to churn. It also provides explanations for the predictions using SHAP values.
- `GET /` – Serves the homepage with a form for user input.
- `POST /predict` – Accepts customer details via a form, processes the input, and returns a churn prediction along with feature importance explanations.
- Model Loading
  - The pre-trained XGBoost model is loaded from `models/churn_model.json`.
  - A SHAP explainer is initialized to provide feature importance for individual predictions.
- Prediction Flow
  - User inputs are collected via an HTML form and passed to the `/predict` endpoint.
  - The input is formatted to match the trained model's expected features.
  - The model predicts the probability of churn and returns a classification (Churn / Not Churn).
  - The top three most influential features for the prediction are extracted using SHAP values.
- Feature Mapping
  - Since FastAPI form inputs use underscores (`_`), a feature mapping is applied to match the model's expected column names (e.g., `Contract_One_year` → `Contract_One year`).
- Result Presentation
  - The prediction result, probability, and top influencing features are displayed on an HTML results page (`result.html`).
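Two pieces of this flow can be sketched in isolation, leaving out the FastAPI routing and SHAP explainer themselves (both helper names are assumptions, and the mapping shown is an illustrative subset of the real one):

```python
import numpy as np

# Form field names use underscores; the model expects the original
# column names produced by one-hot encoding (subset for illustration)
FEATURE_MAP = {
    "Contract_One_year": "Contract_One year",
    "Contract_Two_year": "Contract_Two year",
}

def map_form_fields(form: dict) -> dict:
    """Rename underscore-style form fields to the model's column names."""
    return {FEATURE_MAP.get(k, k): v for k, v in form.items()}

def top_influences(feature_names, shap_values, k=3):
    """Return the k features with the largest absolute SHAP contribution."""
    order = np.argsort(np.abs(shap_values))[::-1][:k]
    return [(feature_names[i], float(shap_values[i])) for i in order]
```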
When the API returns a prediction, it also provides information on the factors influencing that prediction. Each factor has a number and an arrow, and can be interpreted as follows:
- ▲ Positive Number → Increases churn risk
- ▼ Negative Number → Decreases churn risk
- ▲ Large Positive Number → Strong churn factor
- ▼ Large Negative Number → Strong retention factor
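Rendering a factor under this convention is a one-liner; a sketch (the function name and format are illustrative, not the app's exact markup):

```python
def format_factor(name: str, contribution: float) -> str:
    """Render a SHAP contribution using the arrow convention above."""
    arrow = "▲" if contribution > 0 else "▼"  # positive raises churn risk
    return f"{arrow} {name} ({contribution:+.2f})"
```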
This API enables quick and interpretable customer churn predictions, making it easy to integrate into a broader customer retention system.
The API is deployed on Render using a free instance. Deployment was straightforward:
- Connected the GitHub repository to Render.
- Added build and run commands, including installing dependencies from `requirements.txt`.
- Render automatically handles deployment and hosting.
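On Render this typically amounts to a build command and a start command along these lines (the module path `main:app` is an assumption about the repository layout):

```shell
# Build command: install dependencies
pip install -r requirements.txt

# Start command: serve the FastAPI app (Render injects $PORT)
uvicorn main:app --host 0.0.0.0 --port $PORT
```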
Note: Since this is a free instance, it spins down when inactive, causing delays of 50+ seconds for the first request.
Live Demo: https://retention-ai.onrender.com
Maimuna Zaheer – mz2934@nyu.edu
Project Link: https://github.com/itserror404/retention-ai
