Classification Model
In this post, we will walk through a classification use case, including logistic regression with StatsModels and scikit-learn.
- Resources & Credits
- Packages
- Credit - Load the dataset and EDA
- Simple model - linear regression
- Logistic Regression
Resources & Credits
The dataset we use comes from the book An Introduction to Statistical Learning
by Gareth James, Daniela Witten, Trevor Hastie, and Rob Tibshirani. See the book for details.
Packages
import numpy as np
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
import matplotlib.pyplot as plt
Credit - Load the dataset and EDA
default = pd.read_csv('./dataset/Default.csv')
default = default[['default', 'balance']]
default.shape
default.head()
Simple model - linear regression
X = default['balance']
X = sm.add_constant(X)
X.iloc[:6, :]
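For readers unfamiliar with `sm.add_constant`: by default it prepends a column of ones (named `const`) so that the model can estimate an intercept. A minimal NumPy sketch of the same idea follows. The balance values are made up for illustration and `X_design` is a hypothetical name, not something from the original code.

```python
import numpy as np

# Hypothetical balance values, for illustration only (not the real dataset)
balance = np.array([729.5, 817.2, 1073.5])

# Prepend a column of ones so the first coefficient acts as the intercept
X_design = np.column_stack([np.ones_like(balance), balance])
print(X_design)
```

The first column multiplies the intercept coefficient and the second multiplies the slope on balance.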
Next, we convert the label (the response variable) from text to a numeric 0/1 encoding.
y = list(map(lambda x: 1 if x == 'Yes' else 0, default['default']))
y[:6]
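The same encoding can also be written in vectorized form. The sketch below compares the two approaches on a small made-up array that stands in for the `default['default']` column. `np.where` should behave the same way when given the pandas column directly, though that is not shown here.

```python
import numpy as np

# Hypothetical labels standing in for default['default']
labels = np.array(['No', 'Yes', 'No', 'No', 'Yes'])

# Element-by-element mapping, as in the lambda version
y_lambda = list(map(lambda x: 1 if x == 'Yes' else 0, labels))

# Vectorized mapping: 1 where the label is 'Yes', 0 otherwise
y_vec = np.where(labels == 'Yes', 1, 0)
print(y_vec.tolist())
```

Both versions produce the same 0/1 values. The vectorized one returns a NumPy array rather than a Python list.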
linear_reg = sm.OLS(y, X).fit()
linear_reg.summary()
After fitting, we can inspect the model visually by plotting its predictions against the data.
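y_pred = linear_reg.predict(X)
# Fitted regression line (matplotlib's default blue)
plt.plot(default['balance'], y_pred)
# Observed labels as 0/1 points; the numeric y keeps both series on one numeric axis
plt.plot(default['balance'], y, linestyle='none', marker='o', markersize=2, color='red')
plt.show()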
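In the graph, the observed values are shown in red and the blue line is the regression line. The problem is that the line is not bounded to the interval [0, 1]: for low balances it can predict values below 0, which cannot be interpreted as probabilities of default.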
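To make the out-of-range issue concrete, here is a small sketch on synthetic data (not the Default dataset) in which only the highest balances default. It uses `np.polyfit` for an ordinary least squares line instead of `sm.OLS` so that it runs with NumPy alone.

```python
import numpy as np

# Synthetic data (not the Default dataset): only high balances default
balance = np.linspace(0, 2000, 201)
y = (balance > 1800).astype(float)

# Ordinary least squares fit of y = b0 + b1 * balance
# np.polyfit returns coefficients from highest degree to lowest
b1, b0 = np.polyfit(balance, y, deg=1)
y_hat = b0 + b1 * balance

print(f"smallest fitted value: {y_hat.min():.3f}")
print(f"largest fitted value:  {y_hat.max():.3f}")
```

On this toy data the smallest fitted value is negative, so the straight line produces "probabilities" that fall outside [0, 1].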
Logistic Regression
logistic_reg = sm.Logit(y, X).fit()
logistic_reg.summary()
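After fitting, we can again inspect the model with a graph. In this case, we need to sort the values before plotting, because `plt.plot` connects points in the order they are given and the balances are not sorted. Sorting the balances and the predictions independently works here only because the predicted probability increases monotonically with balance.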
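y_pred = logistic_reg.predict(X)
# Sorted balances against sorted predictions trace the fitted curve from left to right
plt.plot(np.sort(default['balance']), np.sort(y_pred))
# Observed labels as 0/1 points; the numeric y keeps both series on one numeric axis
plt.plot(default['balance'], y, linestyle='none', marker='o', markersize=2, color='red')
plt.show()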
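The curve in this plot comes from the logistic (sigmoid) function applied to a linear predictor, which keeps every prediction strictly between 0 and 1. The sketch below illustrates this with made-up coefficients and balances (the names `b0`, `b1`, and `sigmoid` are illustrative, not values from the fitted model). It also shows a more general way to sort for plotting: a single shared index from `np.argsort`, which keeps each balance paired with its own prediction even when the curve is not monotonic.

```python
import numpy as np

def sigmoid(z):
    """Logistic function: maps any real number into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))

# Illustrative coefficients, not taken from the fitted model
b0, b1 = -10.0, 0.005

# Unsorted hypothetical balances
balance = np.array([1500.0, 200.0, 2500.0, 900.0])
p = sigmoid(b0 + b1 * balance)

# Sort both arrays with one shared index so the pairs stay aligned
order = np.argsort(balance)
balance_sorted, p_sorted = balance[order], p[order]

# With b1 > 0 the curve is increasing, so sorting p on its own gives the same result
print(np.allclose(p_sorted, np.sort(p)))
```

If the slope were negative, sorting the predictions on their own would reverse the curve, so the shared-index approach is the safer habit.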
For comparison, we plot the two graphs side by side.
fig, ax = plt.subplots(1, 2, figsize=(16, 10))
y_pred_linear = linear_reg.predict(X)
y_pred_logistic = logistic_reg.predict(X)
# Left panel: linear regression line and observed 0/1 labels
ax[0].plot(default['balance'], y_pred_linear)
ax[0].plot(default['balance'], y, linestyle='none', marker='o', markersize=2, color='red')
ax[0].set_title('Linear regression')
# Right panel: logistic regression curve and observed 0/1 labels
ax[1].plot(np.sort(default['balance']), np.sort(y_pred_logistic))
ax[1].plot(default['balance'], y, linestyle='none', marker='o', markersize=2, color='red')
ax[1].set_title('Logistic regression')
plt.show()