This is the first post in a series about building and validating machine learning pipelines with Apache Spark. Its main concern is to show how to explore data with Spark and Apache Zeppelin notebooks in order to build machine learning prototypes that can be brought into production after working with a sample data set.
The notebooks for this blog post are available via ZeppelinHub:
- Part 1: Normalization
- Part 2: Bag of words
- Part 3: Baseline model
We will also give special emphasis to the use of DataFrames. As explained in another post, RDDs are at the very heart of Spark. But when it comes to working with structured data, such as data coming from CSV files, Hive tables, or a relational database via JDBC, one might wish for the additional comfort a higher-level API provides. For this purpose Spark introduced support for DataFrames, a concept that has already been popular among users of R and Python for a while. Basically, a DataFrame is a tabular data structure that is organized into named and typed columns. The gathered name and type information of a DataFrame is called its schema.
In one of the follow-up posts in this series, we will see how to build and test machine learning pipelines with the DataFrame-centered API provided by the `spark.ml` package. This will include assessing the quality of machine learning models and comparing models with each other. The following quote is well-known in narrow circles:
Essentially, all models are wrong, but some are useful.
— George E. P. Box
So one goal will be to find good models, the ones that are useful.
The problem definition
Throughout the series, we set out to build a spam classifier and base the discussion about model evaluation and working with the DataFrame API around this.
The spam classifier will be able to distinguish spam messages from ham messages. The former are the ones you usually do not want to receive, as they contain identical content sent out to many recipients, while the latter are the good ones.
We leverage some basic machine learning techniques to build and evaluate such a classifier.
Background: Machine learning
If one is asked to give an easy and general-purpose definition of what machine learning is, this quote is usually the go-to solution:
A computer program is said to learn from experience E with respect to some class of tasks T and performance measure P if its performance at tasks in T, as measured by P, improves with experience E.
— Tom M. Mitchell
The great advantage of this quote is that it can give you an idea of what might be behind the term “machine learning” even if you have had no prior exposure to it. But it might be too general to give you a concrete idea. Assuming a little high school mathematics on your part, here is a more concrete description: you are given a fixed number \(m\in\mathbb{N}\) of data points \(x_i\in\mathbb{R}^n\) together with labels \(y_i\in\mathbb{R}\):
| Data (think: x = (age, income)) | Label (think: y = credit score) |
|---|---|
| \(x_1\) | \(y_1\) |
| \(x_2\) | \(y_2\) |
| \(x_3\) | \(y_3\) |
| \(\vdots\) | \(\vdots\) |
| \(x_m\) | \(y_m\) |
In simple terms, the aim is to predict the label (right column) from the data (left column). The task T for a machine learning algorithm (or the corresponding computer program) consists in finding a function \(f\colon\mathbb{R}^n\rightarrow\mathbb{R}\) belonging to some class (e.g. linear or polynomial) such that \(f(x_i)\approx y_i\) for all \(i=1,\ldots,m\). The size of the approximation error is measured by some quality measure P, and it needs to be small even when the program has seen only, say, 60% of the points \((x_i, y_i)\): after the experience E of seeing those 60%, it has learned to perform the task on the remaining 40%.
But how can a computer program learn after all? Under the hood it is all mathematical optimization – solving the right equations, if not exactly then approximately. What these equations are is determined by calculus and the information it provides about the mathematical relations between the kind of function one seeks, the performance measure that is used, and the optimization goal.
A simple and classical example of machine learning is to fit a line through the origin to a two-dimensional point cloud. The parameter to be learned (or approximated) is the angle \(\theta\) in the picture, and the error measure typically would be the mean squared error. The horizontal axis would be the \(x\)-axis and the vertical the \(y\)-axis. The prediction the model makes is \(y=\tan(\theta)\cdot x\).
The actual fitting of the model would be done with a training set, which could consist of 60% of the available data, while a test of the predictive power of the thus learned model could be done on the remaining 40% of previously unseen test data.
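To make this concrete, here is a minimal self-contained sketch in plain Scala (no Spark needed); the data points and helper names like `fitSlope` are made up for illustration. For a line through the origin, minimizing the mean squared error has the closed-form solution \(\tan(\theta)=\sum_i x_i y_i / \sum_i x_i^2\):

```scala
object LineThroughOrigin {

  // Closed-form least squares slope for y = tan(theta) * x:
  // minimizing the mean squared error gives tan(theta) = sum(x*y) / sum(x*x).
  def fitSlope(points: Seq[(Double, Double)]): Double = {
    val numerator   = points.map { case (x, y) => x * y }.sum
    val denominator = points.map { case (x, _) => x * x }.sum
    numerator / denominator
  }

  // Mean squared error of the predictions on a given data set.
  def meanSquaredError(slope: Double, points: Seq[(Double, Double)]): Double = {
    val squaredErrors = points.map { case (x, y) => math.pow(slope * x - y, 2) }
    squaredErrors.sum / points.size
  }

  def main(args: Array[String]): Unit = {
    // Noisy points roughly on the line y = 2 * x (made-up example data).
    val data = Seq((1.0, 2.1), (2.0, 3.9), (3.0, 6.2), (4.0, 7.8), (5.0, 10.1))

    // Train on the first 60% of the points, test on the remaining 40%.
    val (train, test) = data.splitAt((data.size * 0.6).toInt)

    val slope = fitSlope(train)
    val theta = math.atan(slope)

    println(f"Learned slope: $slope%.3f (theta = $theta%.3f rad)")
    println(f"Test MSE: ${meanSquaredError(slope, test)}%.4f")
  }
}
```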
Binary classification
If the labels come from a fixed and finite set of mutually exclusive values, the problem of predicting the label from the data is called classification. If there are only two such labels, one speaks of binary classification. Deciding whether a message is ham or spam is thus a binary classification problem. It is customary to call one of the two labels the positive one and the other the negative one. Which is which is purely a matter of convention, though certain connotations in real-life examples certainly play a role when it comes to assigning the classes. For example, it appears natural to identify the ham class as the positive one.
Convention: Throughout this blog post series and in the Zeppelin notebooks the positive class is ham and will be encoded with `1.0` (see the function `encodeLabel` below).
The dataset
The data set consists of 5564 SMS messages, which have been classified by hand as either ham or spam. Only 13.4% – about 747 – of these messages are spam. This means the data set is skewed and provides only a few examples of spam. This is something to keep in mind as it can introduce bias when training models.
We use the CSV parsing library by Databricks to read the data into a DataFrame:
```scala
import org.apache.spark.sql.DataFrame

val smsSpamCollectionDf: DataFrame = sqlContext
  .read
  .format("com.databricks.spark.csv")
  .option("delimiter", "\t")
  .option("header", "false")
  .option("mode", "PERMISSIVE")
  .schema(schemaSmsSpamCollection)
  .load(smsSpamCollectionPath)
  .cache()
```
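The schema `schemaSmsSpamCollection` used above is defined elsewhere in the notebook. A definition along the following lines would match the two columns we work with in this post (a sketch; the exact column names and nullability are assumptions):

```scala
import org.apache.spark.sql.types.{StructType, StructField, StringType}

// Two string columns: the hand-assigned label and the raw SMS text.
val schemaSmsSpamCollection = StructType(Seq(
  StructField("Label", StringType, nullable = false),
  StructField("SmsText", StringType, nullable = false)
))
```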
So what does this data look like? Here is an extract:
| Label | SmsText |
|---|---|
| ham | HI BABE IM AT HOME NOW WANNA DO SOMETHING? XX |
| spam | URGENT! You have won a 1 week FREE membership in our £100,000 Prize Jackpot! Txt the word: CLAIM to No: 81010 T&C www.dbuk.net LCCLTD POBOX 4403LDNW1A7RW18 |
| ham | K.k:)apo k.good movie. |
| ham | He is a womdarfull actor |
| ham | Ffffffffff. Alright no way I can meet up with you sooner? |
| ham | Great! I hope you like your man well endowed. I am <#> inches… |
As you can see, social media text can really get dirty:
- Misspelled words? Happens.
- Missing whitespace? Happens.
- Slang? Happens.
- Abbreviations such as “u”, “urs”, “yrs”, “y”, …? Happens.
- Violations of grammar rules? Happens.
You are sure to encounter all of these.
Data preparation
Before we can actually feed the data to any machine learning pipeline, we need to go through some preparation steps.
Step 1: Normalization
We lowercase all SMS texts and apply the following normalizations:
| Replace … | by … |
|---|---|
| any number | `" normalizednumber "` |
| any emoticon (like `;-)`) | `" normalizedemoticon "` |
| any currency sign (like `€`, `$`) | `" normalizedcurrencysymbol "` |
| any link | `" normalizedurl "` |
| HTML character entities (like `&gt;`) | `""` |
| punctuation and special characters | `" "` |
Why do we do this? Because as good teachers we want to present good examples to our learning algorithms. Hence we do not want to distract them with cluttered details but just give them the information that there was a number, that there was an emoticon, that there was a URL – without telling them what exactly that number, emoticon or URL was. To drive this point home: think about teaching a child the letters of the alphabet by showing it various examples of letters written on scraps of paper. Neither the color of the paper nor the handwriting of the person who wrote the letter should be part of the concept the child is about to learn. You want to abstract over the various representations a letter can take. Typically (magically?) this is something children seem to accomplish on their own by nature. For our case – teaching a machine – it does not hurt to bring in the focus by hand via the normalization.
Inserting leading and trailing whitespace in the normalization strings helps to separate entities that are to be normalized and that might occur right next to each other, as is the case with `"£1000"`. We want to take mere occurrences into account and not their relation to each other. That is, we only care that `"£"` and `"1000"` occur, but not in which order or relation to each other. Rather, we will treat each SMS text as a bag of words, a technique quite common in document classification.
A little care is needed when making these replacements, as the result depends on the order in which the steps are applied. For example, `;-)` will not result in `" normalizedemoticon "` if punctuation and special characters are removed first.
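To illustrate the pitfall, here is a tiny sketch with simplified stand-in regexes (the actual UDF patterns used below differ):

```scala
val sms = "call me ;-)"

// Wrong order: stripping punctuation first leaves nothing for the emoticon pattern.
val strippedFirst = sms.replaceAll("[^a-z0-9\\s]", " ")
// ==> "call me    " -- the emoticon is gone before it could be normalized.

// Right order: normalize emoticons first, then strip remaining punctuation.
val emoticonFirst = sms.replaceAll("[;:]-?[)(DP]", " normalizedemoticon ")
// ==> "call me  normalizedemoticon "
```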
User-defined functions (UDFs) for the normalization
Spark SQL allows us to define user-defined functions (UDFs) to map columns of a DataFrame. In our case all UDFs follow more or less the same template; their usage is illustrated by the following snippet for the currency symbol normalization:
```scala
import org.apache.spark.sql.UserDefinedFunction
import java.util.regex.Pattern

// SmsText is a type alias for String (defined in the notebook).
val normalizeCurrencySymbol = udf {
  (text: SmsText) =>
    val regex = "[\\$\\€\\£]"
    val pattern = Pattern.compile(regex)
    val matcher = pattern.matcher(text)

    matcher.replaceAll(" normalizedcurrencysymbol ")
}
```
To obtain the normalized SMS in an appended column we chain the methods as follows:
```scala
// Using an obvious toLowerCase UDF to ignore case:
val smsSpamCollectionDfLowerCased = smsSpamCollectionDf
  .select(label,
          toLowerCase(smsText)
            .as(smsText.toString))

// Doing the actual normalization on the lower-cased data:
val smsSpamCollectionDfNormalized = smsSpamCollectionDfLowerCased
  .withColumn("NormalizedSmsText",
    removePunctuationAndSpecialChar(
      normalizeNumber(
        removeHTMLCharacterEntities(
          normalizeCurrencySymbol(
            normalizeEmoticon(
              normalizeURL(
                normalizeEmailAddress(smsText)
              )))))))
```
One benefit of working with DataFrames is that we can append new columns containing derived data. The nice thing about this is that it allows us to keep everything in one central DataFrame and enables us to easily compare data across the rows of that DataFrame. The latter becomes especially handy, for instance, when you try to understand the quality of, say, the normalization step by looking at specific examples. Basically, this allows revisiting the “working history” and helps to spot points where information that could be crucial for further processing might have been lost or corrupted – something that would be cumbersome to do with plain RDDs.
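For instance, putting raw and normalized texts side by side for an eyeball check is only one `select` away (a sketch; `show(n, false)` prints rows without truncating long cells):

```scala
// Compare raw and normalized SMS texts side by side for a quick sanity check.
smsSpamCollectionDfNormalized
  .select($"SmsText", $"NormalizedSmsText")
  .show(5, false)
```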
Step 2: Bagging words and label encoding
We use two other UDFs: one to obtain the desired bag of words from the normalized SMS text and another to encode the label (`"ham"` or `"spam"`) numerically:
```scala
val bagOfWordsFrom = udf {
  (smsText: String) => smsText.split(" ").toList.map(_.trim)
}

val encodeLabel: UserDefinedFunction = udf {
  (label: String) => if (label == "ham") 1.0 else 0.0
}
```
As convention we use `1.0` as the encoding for `"ham"`.
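Applied to the normalized DataFrame, these UDFs give us the `BagOfWords` and `LabelCode` columns used below. A sketch of how this would look, assuming the label column is named `Label`:

```scala
// Append the bag of words and the numeric label code as new columns.
val smsSpamCollectionBagOfWords = smsSpamCollectionDfNormalized
  .withColumn("BagOfWords", bagOfWordsFrom($"NormalizedSmsText"))
  .withColumn("LabelCode", encodeLabel($"Label"))
```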
Word frequencies
Which words appear most frequently in the corpus? Simple question, but how do we answer it? Well, one way is to flatten all the bags of words out into a DataFrame on which we then do some aggregation and counting. That is something we could do as follows:
```scala
val tokenFrequencies = smsSpamCollectionBagOfWords
  .select($"BagOfWords")
  .flatMap(row => row.getAs[Seq[String]](0))
  .toDF("Tokens")
  .groupBy($"Tokens")
  .agg(count("*").as("Frequency"))
  .orderBy($"Frequency".desc)
```
Zeppelin allows us to visualize the top twenty elements of this DataFrame as a pie chart, and this chart reveals that it might be a good idea to filter out words that appear quite frequently in everyday speech anyway:
So what are these words? They are stop words, and it is usually a good idea to get rid of them: they are merely there because they refer to things we frequently have to speak about, like “I” and “you”, or just because good old grammar requires them, like “a” or “the”, to name a few. We use the list of stop words from here. This list is also used by Spark ML, so we work on an equal footing when we use Spark ML in a follow-up to this blog post and do not have two stop word lists floating around. Back to the topic…
So before turning the RDD obtained by flattening back into a DataFrame, we apply a little filter on the stop words (see next snippet). It is also a good idea to broadcast the collection of stop words read from the file:
```scala
import scala.io.Source
import org.apache.spark.broadcast.Broadcast

val stopWords = sc.broadcast { Source.fromFile(stopWordsPath).getLines().to[Seq] }

val tokenFrequencies = smsSpamCollectionBagOfWords
  .select($"BagOfWords")
  .flatMap(row => row.getAs[Seq[String]](0))
  .filter(token => !(stopWords.value contains token))
  .toDF("Tokens")
  .groupBy($"Tokens")
  .agg(count("*").as("Frequency"))
  .orderBy($"Frequency".desc)
```
With this out of the way, the distribution of the twenty most frequent words looks as follows:
This already looks much better: we can see that more normalization strings now appear among the top twenty most frequent words. We could also try to map some common abbreviations like “u” to their corresponding verbose version, but we won’t touch upon this here.
Okay, next question: What words are the most frequent among ham and spam messages?
Why do we care? For one thing, one might hope that there are a few keywords so characteristic of spam messages that they suffice to actually identify spam. For another, when we later use term frequency vectors, we want to base those only on the most frequent terms.
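The per-class counts can be obtained by filtering on the label first and then reusing the frequency pipeline from above; a sketch for the spam case (the ham case is analogous, with the `Label` column name assumed as before):

```scala
// Token frequencies restricted to spam messages only.
val spamTokenFrequencies = smsSpamCollectionBagOfWords
  .filter($"Label" === "spam")
  .select($"BagOfWords")
  .flatMap(row => row.getAs[Seq[String]](0))
  .filter(token => !(stopWords.value contains token))
  .toDF("Tokens")
  .groupBy($"Tokens")
  .agg(count("*").as("Frequency"))
  .orderBy($"Frequency".desc)
```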
As the following image reveals, the top three entities in spam messages tend to be numbers, currency symbols, followed by the word “free”:
This looks quite different from the distribution of entities among ham messages:
So indeed there seem to be words that occur quite frequently within spam messages. We will set up a baseline model that classifies messages based on the three most frequent of these words. This model is easy to build, and any other model we build will have to compete with it.
Building a baseline model
Alright! As said, a first naive approach to classifying the messages is to mark as spam those that contain the three most frequent words among all spam messages. This can be done with the following UDF:
```scala
val doesLookSuspicious = udf {
  (bagOfWords: Seq[String]) =>
    val containsAlarmWords =
      Set("normalizednumber", "normalizedcurrencysymbol", "free") subsetOf bagOfWords.toSet

    // Ham <--> positive class <--> 1.0; Spam <--> negative class <--> 0.0:
    if (containsAlarmWords) 0.0 else 1.0
}
```
This naive approach is certainly overly simplistic, since there are also spam messages that contain none of the three alarm words. But let’s try it and see what we get! For this we collect the predictions and actual labels as the RDD `predictionAndLabels`:
```scala
// `dataset` is the DataFrame with the BagOfWords and LabelCode columns from above.
// We build (prediction, label) pairs -- the order MulticlassMetrics expects below.
val predictionAndLabels = dataset
  .withColumn("Prediction", doesLookSuspicious($"BagOfWords"))
  .select($"Prediction", $"LabelCode")
  .map(row => (row.getDouble(0), row.getDouble(1)))
```
In the next section we will see how we can use this RDD to assess the quality of the baseline model.
Model evaluation
Everything that can be measured can be improved. But how do we measure the usefulness of a model?
The first thing that comes to mind is to count the number of instances that were correctly classified and to hope that this number turns out to be comparatively higher than the number of misclassified instances. Although this is okay for a first test, one typically wants to understand on a per-class level which labels were confused by the classifier – especially when getting the label wrong for one of the classes is absolutely undesirable.
This leads to the following notions: A true positive (TP) is an instance of the positive class that was classified as such. On the other hand, a false positive (FP) is an instance of the negative class that was misclassified as positive. Analogously, one can speak of true negatives (TN) and false negatives (FN). These values are typically summarized in the so-called confusion matrix:
|  | Actual = N | Actual = P |
|---|---|---|
| Predicted = N | TN | FN |
| Predicted = P | FP | TP |
For a perfect binary classifier the off-diagonal terms of this matrix should be close to zero. That is, the true positive rate (TPR)
\(TPR = \frac{TP}{P} = \frac{TP}{TP + FN}\)
and the false positive rate (FPR)
\(FPR = \frac{FP}{N} = \frac{FP}{FP + TN}\)
should be close to one and zero, respectively.
It’s an easy matter to compute and pretty print these summary values:
```scala
val tn = predictionAndLabels.filter { case (predicted, actual) => actual == 0 && predicted == 0 }.count().toFloat
val fp = predictionAndLabels.filter { case (predicted, actual) => actual == 0 && predicted == 1 }.count().toFloat
val fn = predictionAndLabels.filter { case (predicted, actual) => actual == 1 && predicted == 0 }.count().toFloat
val tp = predictionAndLabels.filter { case (predicted, actual) => actual == 1 && predicted == 1 }.count().toFloat

// Each "Predicted" row lists the counts for Actual = 0 and Actual = 1,
// so the argument order is (tn, fn) for row 0 and (fp, tp) for row 1.
printf(s"""|=================== Confusion matrix ==========================
           |#############| %-15s %-15s
           |-------------+-------------------------------------------------
           |Predicted = 0| %-15f %-15f
           |Predicted = 1| %-15f %-15f
           |===============================================================
           """.stripMargin, "Actual = 0", "Actual = 1", tn, fn, fp, tp)
```
One can also compute these summary values using Spark’s MLlib, which provides various utility classes for computing evaluation and performance measures in the `evaluation` package. For the full details see Spark’s documentation on evaluation metrics. Unfortunately, the class `BinaryClassificationMetrics` contained therein does not itself provide the capability to compute a confusion matrix. Instead we can use the class `MulticlassMetrics`: the confusion matrix, and various other measures, can be computed from a metrics object that is created from the RDD of predictions and labels our model makes:
```scala
import org.apache.spark.mllib.evaluation.MulticlassMetrics

// In the resulting confusion matrix, rows correspond to actual classes
// and columns to predicted ones, ordered by class label ascending.
val metrics = new MulticlassMetrics(predictionAndLabels)
val cfm = metrics.confusionMatrix

val tn = cfm(0, 0)
val fp = cfm(0, 1)
val fn = cfm(1, 0)
val tp = cfm(1, 1)
```
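From these counts, the two rates defined above follow directly; as a small addition (not part of the original snippet), we can print them like this:

```scala
// True positive rate and false positive rate as defined above.
val tpr = tp / (tp + fn)
val fpr = fp / (fp + tn)

println(f"TPR = $tpr%.2f, FPR = $fpr%.2f")
```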
For our baseline model, we obtain the following confusion matrix:
|  | Actual = Spam | Actual = Ham |
|---|---|---|
| Predicted = Spam | 37 | 0 |
| Predicted = Ham | 710 | 4817 |
Is this good news or bad news? Well, the good news is that our baseline model got none of the ham messages wrong. The bad news is that way too many spam messages are also classified as ham: 710 of the 747 spam messages slip through. With ham as the positive class, this corresponds to a true positive rate of 100% and a false positive rate of about 95%; the overall accuracy of roughly 87% looks decent only because the data set is so skewed towards ham. So this is certainly something where room for improvement is left for the other models we are going to build! So stay tuned…
Summary and Outlook
We saw how to use Spark’s DataFrames to process labeled input text data from a CSV file and how one can use Spark and Zeppelin to explore the data. Finally, we built our first model and discussed how to evaluate its performance using the confusion matrix and the false and true positive rates.
In the following blog post, we will use logistic regression to build another model. Testing this model against our baseline model will give us an idea of which one to pick.
We will also see how one can use grid search to automatically train various models, varying some of their common parameters. This will allow us to test models with different parameter configurations and subsequently choose the one that performs best.