Iris flower classification with TinyML
An introduction to Machine Learning development for Arduino

This project is an introduction to TinyML classification with a toy dataset.
The Iris dataset is a well-known dataset used to explain machine learning concepts. It describes 3 species of the Iris flower in terms of 4 measurements:
- sepal length
- sepal width
- petal length
- petal width
In this project, I will show you how to classify these flowers using some feature pre-processing and decision trees.
"""
Start by installing the everywhereml package
"""
!pip install "everywhereml>=0.0.3"
"""
Import the Iris dataset and default classifier
Later in this course you will learn how to create your own dataset and classifier,
but for now just use some defaults.
"""
from everywhereml.get_started import iris_dataset, iris_classifier
Familiarize yourself with the dataset
Once you have a dataset, the first thing to do is familiarize yourself with the data. You have to understand the kind of features you're working with, their ranges, any missing values, and any other statistics that may prove useful.
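As a quick first check, you can count the missing values in each column with pandas. A minimal sketch (it relies on the .df accessor introduced in the next cell):

# count missing values per column; on this toy dataset
# you should expect all zeros
iris_dataset.df.isna().sum()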
"""
First things first: print the contents of the Iris dataset.
(If you're familiar with pandas, .df gives you access to the
underlying DataFrame object for the dataset.
You can call any DataFrame method on it.)
"""
iris_dataset.df
| | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | target | target_name |
|---|---|---|---|---|---|---|
| 0 | 5.1 | 3.5 | 1.4 | 0.2 | 0.0 | setosa |
| 1 | 4.9 | 3.0 | 1.4 | 0.2 | 0.0 | setosa |
| 2 | 4.7 | 3.2 | 1.3 | 0.2 | 0.0 | setosa |
| 3 | 4.6 | 3.1 | 1.5 | 0.2 | 0.0 | setosa |
| 4 | 5.0 | 3.6 | 1.4 | 0.2 | 0.0 | setosa |
| ... | ... | ... | ... | ... | ... | ... |
| 145 | 6.7 | 3.0 | 5.2 | 2.3 | 2.0 | virginica |
| 146 | 6.3 | 2.5 | 5.0 | 1.9 | 2.0 | virginica |
| 147 | 6.5 | 3.0 | 5.2 | 2.0 | 2.0 | virginica |
| 148 | 6.2 | 3.4 | 5.4 | 2.3 | 2.0 | virginica |
| 149 | 5.9 | 3.0 | 5.1 | 1.8 | 2.0 | virginica |

150 rows × 6 columns
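Since .df is a regular pandas DataFrame, you can also compute quick aggregate views. For example, a sketch of the per-species feature averages (using the column names shown above):

# mean of each numeric feature, grouped by species
iris_dataset.df.groupby('target_name').mean(numeric_only=True)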
"""
Print some basic statistics about the Iris dataset
"""
iris_dataset.describe()
| | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | target |
|---|---|---|---|---|---|
| count | 150.000000 | 150.000000 | 150.000000 | 150.000000 | 150.000000 |
| mean | 5.843333 | 3.057333 | 3.758000 | 1.199333 | 1.000000 |
| std | 0.828066 | 0.435866 | 1.765298 | 0.762238 | 0.819232 |
| min | 4.300000 | 2.000000 | 1.000000 | 0.100000 | 0.000000 |
| 25% | 5.100000 | 2.800000 | 1.600000 | 0.300000 | 0.000000 |
| 50% | 5.800000 | 3.000000 | 4.350000 | 1.300000 | 1.000000 |
| 75% | 6.400000 | 3.300000 | 5.100000 | 1.800000 | 2.000000 |
| max | 7.900000 | 4.400000 | 6.900000 | 2.500000 | 2.000000 |
Looking at raw numbers is rarely that useful on its own, since you can't get a bird's-eye view of all the data.
It is often said that a picture is worth a thousand words, so let's create a picture.
One such picture is a pairplot, which displays a grid of N features × N features.
Each cell is a scatter plot of 2 of the available features, colored by the class label. If you can somehow separate the classes by eye, there's a good chance a machine learning model will perform well.
If you only see gibberish points, you may be out of luck!
"""
Draw a pair-plot of the features.
The "setosa" class (blue) is clearly isolated from the other two:
any classifier will be able to achieve 100% accuracy on it.
"versicolor" and "virginica", on the other hand, sometimes overlap:
you can expect some classification errors between the two.
"""
iris_dataset.plot.features_pairplot()

Create a Machine Learning model
Once you're satisfied with the pre-processing, it's time to train a model to classify the samples.
We will take a look at the many classifiers available in the next chapters; for the moment, we will use a RandomForestClassifier because it works well out-of-the-box in most cases.
Later, you will learn the pros and cons of each type of classifier, so you can choose the best one for your project.
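If you prefer to configure the model yourself instead of relying on the iris_classifier default, everywhereml wraps the scikit-learn estimators. A minimal sketch (the import path and hyper-parameters below mirror the ones visible in the generated C++ header later in this page):

from everywhereml.sklearn.ensemble import RandomForestClassifier

# a small, shallow forest keeps the exported C++ code compact
custom_classifier = RandomForestClassifier(n_estimators=10, max_depth=10)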
"""
Fit the classifier on the Iris train dataset and test its accuracy on the Iris test dataset.
For such a simple dataset, you can expect a score above 0.9.
"""
iris_train, iris_test = iris_dataset.split(test_size=0.3)
iris_classifier.fit(iris_train)
print('Classification score on test dataset: %.2f' % iris_classifier.score(iris_test))
Classification score on test dataset: 0.93
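Accuracy alone doesn't tell you where the errors happen. If you want to see which species get confused with each other, here's a sketch using scikit-learn's confusion matrix; it assumes the test split exposes its data through the .df accessor shown earlier, and that the classifier accepts a plain feature matrix like its scikit-learn counterpart:

from sklearn.metrics import confusion_matrix

# split the test DataFrame into features and labels
X_test = iris_test.df.drop(columns=['target', 'target_name'])
y_test = iris_test.df['target']

# expect most of the off-diagonal counts between
# versicolor (1) and virginica (2), as the pairplot suggested
print(confusion_matrix(y_test, iris_classifier.predict(X_test)))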
Export to Arduino code
Now that the development phase is complete and we are satisfied with the accuracy of our classifier, it's time to port it to our Arduino sketch.
Using the everywhereml package makes this part a breeze, since each component is able to convert itself to C++ code. In this case, the classifier is the only piece of code we want to convert to C++.
"""
Export the classifier to Arduino C++ code.
"""
# return a string with the C++ code
iris_classifier.to_arduino(instance_name='classifier', class_map=iris_dataset.class_map)
# or save the contents directly to a file
print(iris_classifier.to_arduino_file(
    'sketches/IrisClassification/Classifier.h',
    instance_name='classifier',
    class_map=iris_dataset.class_map
))
#ifndef UUID4431762480
#define UUID4431762480

/**
 * RandomForestClassifier(base_estimator=DecisionTreeClassifier(), bootstrap=True,
 * ccp_alpha=0.0, class_name=RandomForestClassifier, class_weight=None, criterion=gini,
 * estimator_params=('criterion', 'max_depth', 'min_samples_split', 'min_samples_leaf',
 * 'min_weight_fraction_leaf', 'max_features', 'max_leaf_nodes', 'min_impurity_decrease',
 * 'random_state', 'ccp_alpha'), max_depth=10, max_features=auto, max_leaf_nodes=None,
 * max_samples=None, min_impurity_decrease=0.0, min_samples_leaf=1, min_samples_split=2,
 * min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=None, num_outputs=3,
 * oob_score=False, package_name=everywhereml.sklearn.ensemble, random_state=None,
 * template_folder=everywhereml/sklearn/ensemble, verbose=0, warm_start=False)
 */
class RandomForestClassifier {
    public:

        /**
         * Predict class from features
         */
        int predict(float *x) {
            int predictedValue = 0;
            size_t startedAt = micros();
            uint16_t votes[3] = { 0 };
            uint8_t classIdx = 0;
            float classScore = 0;

            tree0(x, &classIdx, &classScore);
            votes[classIdx] += classScore;

            tree1(x, &classIdx, &classScore);
            votes[classIdx] += classScore;

            tree2(x, &classIdx, &classScore);
            votes[classIdx] += classScore;

            tree3(x, &classIdx, &classScore);
            votes[classIdx] += classScore;

            tree4(x, &classIdx, &classScore);
            votes[classIdx] += classScore;

            tree5(x, &classIdx, &classScore);
            votes[classIdx] += classScore;

            tree6(x, &classIdx, &classScore);
            votes[classIdx] += classScore;

            tree7(x, &classIdx, &classScore);
            votes[classIdx] += classScore;

            tree8(x, &classIdx, &classScore);
            votes[classIdx] += classScore;

            tree9(x, &classIdx, &classScore);
            votes[classIdx] += classScore;

            // return argmax of votes
            uint8_t maxClassIdx = 0;
            float maxVote = votes[0];

            for (uint8_t i = 1; i < 3; i++) {
                if (votes[i] > maxVote) {
                    maxClassIdx = i;
                    maxVote = votes[i];
                }
            }

            predictedValue = maxClassIdx;
            latency = micros() - startedAt;

            return (lastPrediction = predictedValue);
        }

        /**
         * Predict class label
         */
        String predictLabel(float *x) {
            return getLabelOf(predict(x));
        }

        /**
         * Get label of last prediction
         */
        String getLabel() {
            return getLabelOf(lastPrediction);
        }

        /**
         * Get label of given class
         */
        String getLabelOf(int8_t idx) {
            switch (idx) {
                case -1: return "ERROR";
                case 0: return "setosa";
                case 1: return "versicolor";
                case 2: return "virginica";
                default: return "UNKNOWN";
            }
        }

        /**
         * Get latency in micros
         */
        uint32_t latencyInMicros() {
            return latency;
        }

        /**
         * Get latency in millis
         */
        uint16_t latencyInMillis() {
            return latency / 1000;
        }

    protected:
        float latency = 0;
        int lastPrediction = 0;

        /**
         * Random forest's tree #0
         */
        void tree0(float *x, uint8_t *classIdx, float *classScore) {
            if (x[3] <= 0.7000000029802322) { *classIdx = 0; *classScore = 27.0; return; }
            else {
                if (x[3] <= 1.699999988079071) {
                    if (x[1] <= 2.649999976158142) {
                        if (x[2] <= 4.75) { *classIdx = 1; *classScore = 41.0; return; }
                        else { *classIdx = 2; *classScore = 37.0; return; }
                    }
                    else { *classIdx = 1; *classScore = 41.0; return; }
                }
                else { *classIdx = 2; *classScore = 37.0; return; }
            }
        }

        /**
         * Random forest's tree #1
         */
        void tree1(float *x, uint8_t *classIdx, float *classScore) {
            if (x[2] <= 4.8500001430511475) {
                if (x[3] <= 0.7000000029802322) { *classIdx = 0; *classScore = 28.0; return; }
                else {
                    if (x[0] <= 6.1499998569488525) { *classIdx = 1; *classScore = 39.0; return; }
                    else {
                        if (x[3] <= 1.699999988079071) { *classIdx = 1; *classScore = 39.0; return; }
                        else { *classIdx = 2; *classScore = 38.0; return; }
                    }
                }
            }
            else {
                if (x[3] <= 1.649999976158142) {
                    if (x[0] <= 6.200000047683716) { *classIdx = 2; *classScore = 38.0; return; }
                    else { *classIdx = 1; *classScore = 39.0; return; }
                }
                else { *classIdx = 2; *classScore = 38.0; return; }
            }
        }

        /**
         * Random forest's tree #2
         */
        void tree2(float *x, uint8_t *classIdx, float *classScore) {
            if (x[3] <= 0.800000011920929) { *classIdx = 0; *classScore = 42.0; return; }
            else {
                if (x[2] <= 4.950000047683716) {
                    if (x[3] <= 1.649999976158142) { *classIdx = 1; *classScore = 31.0; return; }
                    else {
                        if (x[2] <= 4.8500001430511475) {
                            if (x[1] <= 3.0) { *classIdx = 2; *classScore = 32.0; return; }
                            else { *classIdx = 1; *classScore = 31.0; return; }
                        }
                        else { *classIdx = 2; *classScore = 32.0; return; }
                    }
                }
                else { *classIdx = 2; *classScore = 32.0; return; }
            }
        }

        /**
         * Random forest's tree #3
         */
        void tree3(float *x, uint8_t *classIdx, float *classScore) {
            if (x[3] <= 1.699999988079071) {
                if (x[0] <= 5.549999952316284) {
                    if (x[3] <= 0.7000000029802322) { *classIdx = 0; *classScore = 27.0; return; }
                    else { *classIdx = 1; *classScore = 44.0; return; }
                }
                else {
                    if (x[1] <= 3.600000023841858) {
                        if (x[2] <= 4.950000047683716) { *classIdx = 1; *classScore = 44.0; return; }
                        else { *classIdx = 2; *classScore = 34.0; return; }
                    }
                    else { *classIdx = 0; *classScore = 27.0; return; }
                }
            }
            else { *classIdx = 2; *classScore = 34.0; return; }
        }

        /**
         * Random forest's tree #4
         */
        void tree4(float *x, uint8_t *classIdx, float *classScore) {
            if (x[3] <= 0.7000000029802322) { *classIdx = 0; *classScore = 35.0; return; }
            else {
                if (x[3] <= 1.6500000357627869) {
                    if (x[1] <= 2.25) { *classIdx = 2; *classScore = 32.0; return; }
                    else { *classIdx = 1; *classScore = 38.0; return; }
                }
                else { *classIdx = 2; *classScore = 32.0; return; }
            }
        }

        /**
         * Random forest's tree #5
         */
        void tree5(float *x, uint8_t *classIdx, float *classScore) {
            if (x[3] <= 0.800000011920929) { *classIdx = 0; *classScore = 33.0; return; }
            else {
                if (x[2] <= 4.8500001430511475) { *classIdx = 1; *classScore = 45.0; return; }
                else {
                    if (x[3] <= 1.649999976158142) {
                        if (x[2] <= 5.25) { *classIdx = 1; *classScore = 45.0; return; }
                        else { *classIdx = 2; *classScore = 27.0; return; }
                    }
                    else { *classIdx = 2; *classScore = 27.0; return; }
                }
            }
        }

        /**
         * Random forest's tree #6
         */
        void tree6(float *x, uint8_t *classIdx, float *classScore) {
            if (x[0] <= 5.450000047683716) {
                if (x[1] <= 2.850000023841858) { *classIdx = 1; *classScore = 37.0; return; }
                else {
                    if (x[0] <= 5.299999952316284) { *classIdx = 0; *classScore = 29.0; return; }
                    else {
                        if (x[3] <= 0.8500000014901161) { *classIdx = 0; *classScore = 29.0; return; }
                        else { *classIdx = 1; *classScore = 37.0; return; }
                    }
                }
            }
            else {
                if (x[2] <= 4.950000047683716) {
                    if (x[2] <= 2.5) { *classIdx = 0; *classScore = 29.0; return; }
                    else {
                        if (x[3] <= 1.649999976158142) { *classIdx = 1; *classScore = 37.0; return; }
                        else {
                            if (x[1] <= 3.0) { *classIdx = 2; *classScore = 39.0; return; }
                            else { *classIdx = 1; *classScore = 37.0; return; }
                        }
                    }
                }
                else { *classIdx = 2; *classScore = 39.0; return; }
            }
        }

        /**
         * Random forest's tree #7
         */
        void tree7(float *x, uint8_t *classIdx, float *classScore) {
            if (x[3] <= 0.7000000029802322) { *classIdx = 0; *classScore = 35.0; return; }
            else {
                if (x[3] <= 1.699999988079071) {
                    if (x[2] <= 4.950000047683716) { *classIdx = 1; *classScore = 42.0; return; }
                    else { *classIdx = 2; *classScore = 28.0; return; }
                }
                else { *classIdx = 2; *classScore = 28.0; return; }
            }
        }

        /**
         * Random forest's tree #8
         */
        void tree8(float *x, uint8_t *classIdx, float *classScore) {
            if (x[0] <= 5.450000047683716) {
                if (x[2] <= 2.449999988079071) { *classIdx = 0; *classScore = 39.0; return; }
                else {
                    if (x[0] <= 4.950000047683716) { *classIdx = 2; *classScore = 29.0; return; }
                    else { *classIdx = 1; *classScore = 37.0; return; }
                }
            }
            else {
                if (x[2] <= 4.950000047683716) {
                    if (x[2] <= 2.600000023841858) { *classIdx = 0; *classScore = 39.0; return; }
                    else {
                        if (x[3] <= 1.699999988079071) { *classIdx = 1; *classScore = 37.0; return; }
                        else {
                            if (x[0] <= 5.75) { *classIdx = 2; *classScore = 29.0; return; }
                            else {
                                if (x[0] <= 6.049999952316284) { *classIdx = 1; *classScore = 37.0; return; }
                                else { *classIdx = 2; *classScore = 29.0; return; }
                            }
                        }
                    }
                }
                else { *classIdx = 2; *classScore = 29.0; return; }
            }
        }

        /**
         * Random forest's tree #9
         */
        void tree9(float *x, uint8_t *classIdx, float *classScore) {
            if (x[2] <= 2.599999964237213) { *classIdx = 0; *classScore = 45.0; return; }
            else {
                if (x[3] <= 1.699999988079071) {
                    if (x[1] <= 2.25) {
                        if (x[2] <= 4.75) { *classIdx = 1; *classScore = 26.0; return; }
                        else { *classIdx = 2; *classScore = 34.0; return; }
                    }
                    else { *classIdx = 1; *classScore = 26.0; return; }
                }
                else {
                    if (x[0] <= 6.0) {
                        if (x[0] <= 5.8500001430511475) { *classIdx = 2; *classScore = 34.0; return; }
                        else {
                            if (x[1] <= 3.100000023841858) { *classIdx = 2; *classScore = 34.0; return; }
                            else { *classIdx = 1; *classScore = 26.0; return; }
                        }
                    }
                    else { *classIdx = 2; *classScore = 34.0; return; }
                }
            }
        }
};

static RandomForestClassifier classifier;

#endif
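For reference, the class_map passed to the exporter is just a mapping from the numeric target to the label string, and it's what populates the getLabelOf() switch you can see in the generated code. You can inspect it from Python:

print(iris_dataset.class_map)
# should print something like {0: 'setosa', 1: 'versicolor', 2: 'virginica'}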
Import into Arduino sketch
To finish the process, we have to call the exported code from our main sketch.
This sample sketch reads Iris samples from the Serial Monitor and prints the predicted class. A few demo samples are provided for you to try.
// file IrisClassification.ino

/**
 * This is a demo project to classify the Iris dataset using
 * the pre-processing pipeline and classifier fitted in Python
 * with the everywhereml package.
 *
 * It reads the data from the Serial Monitor and performs
 * classification.
 */
#include "Classifier.h"

void setup() {
    Serial.begin(115200);
    Serial.println("This is a demo for the Iris dataset classification");
    Serial.println("Paste a list of features into the Serial Monitor and get the prediction back");
    Serial.println("Here are a few examples for each class: ");
    Serial.println(" > Class 0");
    Serial.println(" > 5.1,3.5,1.4,0.2");
    Serial.println(" > 4.9,3.0,1.4,0.2");
    Serial.println(" > 4.7,3.2,1.3,0.2");
    Serial.println(" > Class 1");
    Serial.println(" > 7.0,3.2,4.7,1.4");
    Serial.println(" > 6.4,3.2,4.5,1.5");
    Serial.println(" > 6.9,3.1,4.9,1.5");
    Serial.println(" > Class 2");
    Serial.println(" > 6.3,3.3,6.0,2.5");
    Serial.println(" > 5.8,2.7,5.1,1.9");
    Serial.println(" > 7.1,3.0,5.9,2.1");
}

void loop() {
    float features[4];

    if (!Serial.available())
        return;

    // parse 4 comma-separated feature values from the Serial buffer
    for (int i = 0; i < 4; i++)
        features[i] = Serial.readStringUntil(',').toFloat();

    // run prediction and print result
    Serial.print("Predicted class: ");
    Serial.println(classifier.predict(features));
    Serial.print("Predicted class label: ");
    Serial.println(classifier.predictLabel(features));
    Serial.print("It took ");
    Serial.print(classifier.latencyInMicros());
    Serial.println(" micros");
}
The main points of interest in the sketch are classifier.predict(features), which returns
the predicted class index as a numeric value, and classifier.predictLabel(features),
which returns the predicted class name as a string.
Well done!
Congratulations, you implemented your first TinyML project for Arduino.
This is a toy example and may not be that useful in the real world, but it showcased a few important constructs of the everywhereml
library and the key steps to follow while developing your own TinyML project.
In the next lessons, we'll dig into each of those steps in more detail.