Text Classification with NLTK
In this post, we will expand on our NLP foundation and explore ways to improve text classification with NLTK and Scikit-learn. Specifically, we will build an SMS spam filter.
- Required Packages
- Version Check
- Load the dataset
- Data preprocessing
- Feature extraction
- Scikit-learn Classifier with NLTK
- Summary
import sys
import nltk
import sklearn
import pandas as pd
import numpy as np
print('Python: {}'.format(sys.version))
print('NLTK: {}'.format(nltk.__version__))
print('Scikit-learn: {}'.format(sklearn.__version__))
print('Pandas: {}'.format(pd.__version__))
print('NumPy: {}'.format(np.__version__))
Load the dataset
Now that we have ensured our libraries are installed correctly, let's load the dataset as a Pandas DataFrame and extract some useful information, such as the column info and the class distribution.
The dataset we will be using comes from the UCI Machine Learning Repository. It contains over 5,000 labeled SMS messages that were collected for mobile phone spam research.
sms = pd.read_table('./dataset/SMSSpamCollection', header=None, encoding='utf-8')
sms.head()
sms.info()
sms[0].value_counts()
From the data summary, we can see that spam messages are labeled spam and legitimate messages are labeled ham, and that there are 747 spam messages in the dataset.
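If you also want the class balance as proportions rather than raw counts, value_counts supports normalization (roughly 87% ham vs 13% spam for this dataset):
# Class distribution as proportions instead of raw counts
sms[0].value_counts(normalize=True)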
Data preprocessing
The labels are stored as strings. Before a model can use them, they need to be converted to binary values; this process is called label encoding. There are several ways to apply it:
- use pd.get_dummies (a minimal sketch follows this list)
- use LabelEncoder in sklearn.preprocessing
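For comparison, here is a minimal sketch of the pd.get_dummies route; it is illustrative only and not used below. With drop_first=True, the result collapses to a single binary column (1 = spam):
# One indicator column per class; drop_first=True keeps a single binary column
pd.get_dummies(sms[0], drop_first=True).head()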
This time, we use LabelEncoder:
from sklearn.preprocessing import LabelEncoder
enc = LabelEncoder()
label = enc.fit_transform(sms[0])
print(label[:10])
print(sms[0][:10])
text = sms[1]
text[:10]
Now it is time for text preprocessing. In the previous post, we learned several text preprocessing techniques, but before applying them we need to normalize the text by replacing special patterns such as email addresses, URLs, money symbols, and phone numbers. To do this, we can use regular expressions (regex for short) for pattern matching. Here are some common regex constructs, as described on Wikipedia:
- ^ Matches the starting position within the string. In line-based tools, it matches the starting position of any line.
- . Matches any single character (many applications exclude newlines, and exactly which characters are considered newlines is flavor-, character-encoding-, and platform-specific, but it is safe to assume that the line feed character is included). Within POSIX bracket expressions, the dot character matches a literal dot. For example, a.c matches "abc", etc., but [a.c] matches only "a", ".", or "c".
- [ ] A bracket expression. Matches a single character that is contained within the brackets. For example, [abc] matches "a", "b", or "c". [a-z] specifies a range which matches any lowercase letter from "a" to "z". These forms can be mixed: [abcx-z] matches "a", "b", "c", "x", "y", or "z", as does [a-cx-z]. The - character is treated as a literal character if it is the last or the first (after the ^, if present) character within the brackets: [abc-], [-abc]. Note that backslash escapes are not allowed. The ] character can be included in a bracket expression if it is the first (after the ^) character: []abc].
- [^ ] Matches a single character that is not contained within the brackets. For example, [^abc] matches any character other than "a", "b", or "c". [^a-z] matches any single character that is not a lowercase letter from "a" to "z". Likewise, literal characters and ranges can be mixed.
- $ Matches the ending position of the string or the position just before a string-ending newline. In line-based tools, it matches the ending position of any line.
- ( ) Defines a marked subexpression. The string matched within the parentheses can be recalled later (see the next entry, \n). A marked subexpression is also called a block or capturing group. BRE mode requires \( \).
- \n Matches what the nth marked subexpression matched, where n is a digit from 1 to 9. This construct is vaguely defined in the POSIX.2 standard. Some tools allow referencing more than nine capturing groups.
- * Matches the preceding element zero or more times. For example, ab*c matches "ac", "abc", "abbbc", etc. [xyz]* matches "", "x", "y", "z", "zx", "zyx", "xyzzy", and so on. (ab)* matches "", "ab", "abab", "ababab", and so on.
- {m,n} Matches the preceding element at least m and not more than n times. For example, a{3,5} matches only "aaa", "aaaa", and "aaaaa". This is not found in a few older instances of regexes. BRE mode requires \{m,n\}.
If you want to test your regex patterns, you can test them here.
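As a quick illustration, here is a minimal sanity check of two of these constructs with Python's re module (the sample strings are made up):
import re
# [a-z]{2,} matches runs of two or more lowercase letters
print(re.findall(r'[a-z]{2,}', 'Call me at 555'))  # ['all', 'me', 'at']
# \d+ matches one or more digits
print(re.findall(r'\d+', 'Call me at 555-0100'))   # ['555', '0100']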
# Replace email addresses with 'emailaddress'
processed = text.str.replace(r'^.+@[^\.].*\.[a-z]{2,}$', 'emailaddress', regex=True)
# Replace URLs with 'webaddress'
processed = processed.str.replace(r'^http\://[a-zA-Z0-9\-\.]+\.[a-zA-Z]{2,3}(/\S*)?$', 'webaddress', regex=True)
# Replace money symbols with 'moneysymb' (£ can be typed with ALT key + 156)
processed = processed.str.replace(r'£|\$', 'moneysymb', regex=True)
# Replace 10-digit phone numbers (formats include parentheses, spaces, no spaces, dashes) with 'phonenumbr'
processed = processed.str.replace(r'^\(?[\d]{3}\)?[\s-]?[\d]{3}[\s-]?[\d]{4}$', 'phonenumbr', regex=True)
# Replace numbers with 'numbr'
processed = processed.str.replace(r'\d+(\.\d+)?', 'numbr', regex=True)
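A quick sanity check of the last rule on a made-up string (in the pipeline above, the money-symbol rule has already run before this point):
import re
print(re.sub(r'\d+(\.\d+)?', 'numbr', 'Win 100 now, call 123-456-7890'))
# -> 'Win numbr now, call numbr-numbr-numbr'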
Next, we also need to remove useless characters such as punctuation and extra whitespace.
# Remove punctuation
processed = processed.str.replace(r'[^\w\d\s]', ' ', regex=True)
# Replace whitespace between terms with a single space
processed = processed.str.replace(r'\s+', ' ', regex=True)
# Remove leading and trailing whitespace
processed = processed.str.replace(r'^\s+|\s+?$', '', regex=True)
After that, we convert all the text to lower case.
processed = processed.str.lower()
processed
In the previous post, we also learned about stop word removal for text preprocessing; we can apply it here as well.
from nltk.corpus import stopwords

# nltk.download('stopwords') may be needed the first time this is run
stop_words = set(stopwords.words('english'))
processed = processed.apply(lambda x: ' '.join(term for term in x.split() if term not in stop_words))
Also, using PorterStemmer, we can reduce each word to its stem.
ps = nltk.PorterStemmer()
processed = processed.apply(lambda x: ' '.join(ps.stem(term) for term in x.split()))
processed
You can see that the processed messages are quite different from the originals, since stop word removal, stemming, and the regular expression substitutions have all been applied.
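For a quick side-by-side look, we can print one message before and after preprocessing (the index here is an arbitrary choice):
# Compare a single message before and after preprocessing
print('original :', text[2])
print('processed:', processed[2])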
from nltk.tokenize import word_tokenize

# nltk.download('punkt') may be needed the first time word_tokenize is used
all_words = []
for message in processed:
    words = word_tokenize(message)
    for w in words:
        all_words.append(w)
all_words = nltk.FreqDist(all_words)
# Print the result
print('Number of distinct words: {}'.format(len(all_words)))
print('Most common words: {}'.format(all_words.most_common(15)))
# Use the 1500 most common words as features
word_features = [x[0] for x in all_words.most_common(1500)]
Now that we have created the feature list, we need to determine which of these features appear in each message.
def find_features(message):
    words = word_tokenize(message)
    features = {}
    for word in word_features:
        features[word] = (word in words)
    return features
features = find_features(processed[0])
for key, value in features.items():
    if value:
        print(key)
list(features.items())[:10]
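We can also run find_features on any new message, provided it is preprocessed the same way. Here is a small sketch with a made-up, already-stemmed message (the exact output depends on which words ended up in word_features):
# Made-up message written as if already preprocessed and stemmed
sample = 'free entri win prize call numbr'
print([w for w, present in find_features(sample).items() if present])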
We have now built a feature dictionary for one message that can serve as a single training example. We can apply the same approach to the whole dataset, then split it into a training set and a test set.
messages = list(zip(processed, label))
np.random.seed(1)
np.random.shuffle(messages)
# Call find_features function for each SMS message
feature_set = [(find_features(text), label) for (text, label) in messages]
from sklearn.model_selection import train_test_split
training, test = train_test_split(feature_set, test_size=0.25, random_state=1)
print(len(training))
print(len(test))
Scikit-learn Classifier with NLTK
Now that we have built the training and test sets, we can train machine learning models with scikit-learn. We will use the following algorithms and compare the performance of each:
- K Nearest Neighbors
- Decision Tree
- Random Forest
- Logistic Regression
- Stochastic Gradient Descent (SGD)
- Naive Bayes
- Support Vector Machine
from nltk.classify.scikitlearn import SklearnClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression, SGDClassifier
from sklearn.naive_bayes import MultinomialNB
from sklearn.svm import SVC
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix
names = ['K Nearest Neighbors', 'Decision Tree', 'Random Forest', 'Logistic Regression', 'SGD Classifier',
'Naive Bayes', 'Support Vector Classifier']
classifiers = [
KNeighborsClassifier(),
DecisionTreeClassifier(),
RandomForestClassifier(),
LogisticRegression(),
SGDClassifier(max_iter=100),
MultinomialNB(),
SVC(kernel='linear')
]
models = zip(names, classifiers)
for name, model in models:
    nltk_model = SklearnClassifier(model)
    nltk_model.train(training)
    accuracy = nltk.classify.accuracy(nltk_model, test)
    print("{} model Accuracy: {}".format(name, accuracy))
From the results, most models reach roughly 95~98% accuracy. We can also enhance the result by letting the models vote, an ensemble approach. To do this, we use VotingClassifier from sklearn.ensemble. You can find the details of the Voting Classifier here.
from sklearn.ensemble import VotingClassifier
# VotingClassifier expects a list of (name, estimator) pairs
models = list(zip(names, classifiers))
nltk_ensemble = SklearnClassifier(VotingClassifier(estimators=models, voting='hard', n_jobs=-1))
nltk_ensemble.train(training)
accuracy = nltk.classify.accuracy(nltk_ensemble, test)
print("Voting Classifier model Accuracy: {}".format(accuracy))
We are done. Now we can generate a classification report and a confusion matrix, common metrics for checking classification performance.
text_features, labels = zip(*test)
prediction = nltk_ensemble.classify_many(text_features)
print(classification_report(labels, prediction))
We can also view the confusion matrix as a DataFrame (a bit fancier, I guess):
pd.DataFrame(confusion_matrix(labels, prediction),
             index=[['actual', 'actual'], ['ham', 'spam']],
             columns=[['predicted', 'predicted'], ['ham', 'spam']])
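As a final sanity check, we can classify a single made-up message with the trained ensemble. Recall that LabelEncoder mapped ham to 0 and spam to 1, so an output of 1 means spam:
# Classify one hypothetical message end to end
msg = 'congratul you have won a numbr prize call now'
print(nltk_ensemble.classify(find_features(msg)))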
Summary
In this post, we built an SMS spam filter from the given SMS dataset. To do this, we preprocessed the text (with techniques from the previous post such as tokenization, stemming, and stop word removal) and extracted features to build the dataset. NLTK is a great tool for this, and its SklearnClassifier wrapper makes it easy to train scikit-learn models. In the end, we built an SMS spam filter with the voting method (an ensemble approach) that reaches almost 98% accuracy.