Level up your TinyML skills

Iris flower classification with TinyML

An introduction to Machine Learning development for Arduino


This project is an introduction to TinyML classification with a toy dataset.

The Iris dataset is a well-known dataset used to explain machine learning concepts. It describes 3 species of the Iris flower (setosa, versicolor and virginica) in terms of 4 measurements, all in centimeters:

  • sepal length
  • sepal width
  • petal length
  • petal width

Iris flowers from http://www.lac.inpe.br/~rafael.santos/Docs/CAP394/WholeStory-Iris.html

In this project, I will show you how to classify these flowers with a Random Forest, an ensemble of decision trees.

"""
Start by installing the everywhereml package
"""
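# uncomment to install from a notebook; the quotes stop the shell from treating ">" as a redirect
# !pip install "everywhereml>=0.0.3"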
"""
Import the Iris dataset and default classifier
Later in this course you will learn how to create your own dataset and classifier, 
but for now just use some defaults.
"""
from everywhereml.get_started import iris_dataset, iris_classifier

Get familiar with the dataset

Once you have a dataset, the first thing to do is to familiarize yourself with the data. You need to understand what kind of features you're working with, their ranges, whether any values are missing, and any other statistics that may prove useful.

"""
First things first: print the contents of the Iris dataset.
(If you're familiar with pandas, .df gives you access to the 
underlying DataFrame object for the dataset.
You can call any DataFrame method on it.)
"""
iris_dataset.df
     sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)  target  target_name
0                  5.1               3.5                1.4               0.2     0.0       setosa
1                  4.9               3.0                1.4               0.2     0.0       setosa
2                  4.7               3.2                1.3               0.2     0.0       setosa
3                  4.6               3.1                1.5               0.2     0.0       setosa
4                  5.0               3.6                1.4               0.2     0.0       setosa
...                ...               ...                ...               ...     ...          ...
145                6.7               3.0                5.2               2.3     2.0    virginica
146                6.3               2.5                5.0               1.9     2.0    virginica
147                6.5               3.0                5.2               2.0     2.0    virginica
148                6.2               3.4                5.4               2.3     2.0    virginica
149                5.9               3.0                5.1               1.8     2.0    virginica

150 rows × 6 columns

"""
Print some basic statistics about the Iris dataset
"""
iris_dataset.describe()
       sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)      target
count         150.000000        150.000000         150.000000        150.000000  150.000000
mean            5.843333          3.057333           3.758000          1.199333    1.000000
std             0.828066          0.435866           1.765298          0.762238    0.819232
min             4.300000          2.000000           1.000000          0.100000    0.000000
25%             5.100000          2.800000           1.600000          0.300000    0.000000
50%             5.800000          3.000000           4.350000          1.300000    1.000000
75%             6.400000          3.300000           5.100000          1.800000    2.000000
max             7.900000          4.400000           6.900000          2.500000    2.000000
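The min and max rows are worth keeping around for later. A decision tree never answers "I don't know": a sample far outside the ranges seen during training still lands in some leaf and gets a class. Below is a minimal host-side C++ sketch of an input check you could run before calling the classifier. The ranges are copied from the table above; the array and function names are hypothetical, not part of everywhereml.

```cpp
#include <cstddef>

// Ranges copied from the describe() output (min and max rows).
// Feature order follows the dataset columns:
// sepal length, sepal width, petal length, petal width (all in cm).
const float FEATURE_MIN[4] = {4.3f, 2.0f, 1.0f, 0.1f};
const float FEATURE_MAX[4] = {7.9f, 4.4f, 6.9f, 2.5f};

// Returns true when every feature lies inside the range seen in the dataset.
// NaN fails both comparisons, so an unparsable value is rejected as well.
bool isWithinTrainingRange(const float *x) {
    for (std::size_t i = 0; i < 4; i++) {
        if (!(x[i] >= FEATURE_MIN[i] && x[i] <= FEATURE_MAX[i]))
            return false;
    }
    return true;
}
```

How strict to be is a design choice: you could reject the sample, or still classify it and just warn the user.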

Looking at raw numbers is rarely enough, since they don't give you a bird's-eye view of the data.

It is often said that a picture is worth a thousand words, so let's create a picture.

One such picture is a pair plot: an N × N grid of plots, where N is the number of features.

Each cell is a scatter plot of 2 of the available features, with points colored by class label. If you can separate the classes by eye, there's a good chance a machine learning model will perform well.

If all you see is an indistinct cloud of overlapping points, you may be out of luck!

"""
Draw a pair-plot of the features.

The "setosa" class (blue) is clearly isolated from the other two: 
any reasonable classifier should recognize it without errors.

"versicolor" and "virginica", on the other hand, sometimes overlap: 
you can expect some classification errors between the two.
"""
iris_dataset.plot.features_pairplot()
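To see why a clear separation in the pair plot matters, here is a rough C++ sketch of the kind of rule a tree can learn from it: a single threshold on petal width is enough to tell setosa apart. The 0.8 cm threshold is an assumption picked by eye (the setosa rows shown earlier have petal widths around 0.2 cm, the virginica ones around 2 cm); `looksLikeSetosa` and the `Feature` enum are hypothetical names.

```cpp
// Column order of the Iris dataset, as in the DataFrame above.
enum Feature { SEPAL_LENGTH = 0, SEPAL_WIDTH = 1, PETAL_LENGTH = 2, PETAL_WIDTH = 3 };

// One hand-written split: setosa has much narrower petals than the other two species.
// The 0.8 cm threshold is an assumption chosen by eye, not a learned value.
bool looksLikeSetosa(const float *x) {
    return x[PETAL_WIDTH] <= 0.8f;
}
```

No single threshold like this separates versicolor from virginica, which is where the overlap in the plot turns into classification errors.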

Create a Machine Learning model

Once you're familiar with the data, it's time to train a model to classify its samples.

We will take a look at the many available classifiers in the next chapters; for the moment, we will use a RandomForestClassifier because it works well out of the box in most cases.

Later, you will learn the pros and cons of each type of classifier, so you can choose the best one for your project.
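A random forest is a collection of decision trees, each trained on a random sample of the training data. To classify a new sample, every tree makes its own prediction and the forest picks the class with the most votes. A minimal C++ sketch of that voting step, with an optional weight per vote, is shown below; `forestVote` is a hypothetical helper, not an everywhereml function.

```cpp
#include <cstddef>
#include <vector>

// Returns the index of the class with the highest total vote.
// treePredictions[i] is the class predicted by tree i and weights[i] is how much
// that vote counts (use 1.0 everywhere for a plain majority vote).
// Ties are resolved in favor of the lowest class index.
int forestVote(const std::vector<int> &treePredictions,
               const std::vector<float> &weights,
               int numClasses) {
    std::vector<float> votes(numClasses, 0.0f);

    for (std::size_t i = 0; i < treePredictions.size(); i++)
        votes[treePredictions[i]] += weights[i];

    int best = 0;
    for (int c = 1; c < numClasses; c++) {
        if (votes[c] > votes[best])
            best = c;
    }
    return best;
}
```

The exported predict() method follows the same pattern: each treeN() call reports a class index and a classScore, the score is added to that class's vote, and the strict `>` comparison in its argmax loop also sends ties to the lowest index.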

"""
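Split the Iris dataset into a train set (70%) and a test set (30%), fit the classifier
on the train set and measure its accuracy on the test set.
For such a simple dataset, you can expect a score above 0.9; the exact value may vary
from run to run, since the forest is trained on random samples of the data.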
"""
iris_train, iris_test = iris_dataset.split(test_size=0.3)
iris_classifier.fit(iris_train)

print('Classification score on test dataset: %.2f' % iris_classifier.score(iris_test))
Classification score on test dataset: 0.93
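For scikit-learn classifiers, score() returns the mean accuracy: the fraction of test samples classified correctly. The header generated in the next step lists scikit-learn parameters (for example `base_estimator=DecisionTreeClassifier()`), so everywhereml's score() very likely behaves the same, though that is an assumption worth checking. If split() keeps 30% of the 150 rows for testing, that is 45 samples, and 0.93 corresponds to 42 correct predictions (42 / 45 ≈ 0.933). The same metric in C++, for example to evaluate the exported model on-device against labelled samples, could look like this; `accuracy` is a hypothetical helper.

```cpp
#include <cstddef>

// Accuracy = number of correct predictions / number of samples.
// Returns 0 for an empty set instead of dividing by zero.
float accuracy(const int *predicted, const int *actual, std::size_t n) {
    if (n == 0)
        return 0.0f;

    std::size_t correct = 0;
    for (std::size_t i = 0; i < n; i++) {
        if (predicted[i] == actual[i])
            correct++;
    }
    return static_cast<float>(correct) / static_cast<float>(n);
}
```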

Export to Arduino code

Now that the development phase is complete and we are satisfied with the accuracy of our classifier, it's time to port it to our Arduino sketch.

Using the everywhereml package makes this part a breeze, since each component is able to convert itself to C++ code. In this case, the classifier is the only piece of code we want to convert to C++.

"""
Export the classifier to Arduino C++ code.
"""
# return a string with the C++ code
iris_classifier.to_arduino(instance_name='classifier', class_map=iris_dataset.class_map)

# or save the contents directly to a file
print(iris_classifier.to_arduino_file(
    'sketches/IrisClassification/Classifier.h', 
    instance_name='classifier', 
    class_map=iris_dataset.class_map
))
#ifndef UUID4431762480
#define UUID4431762480

/**
  * RandomForestClassifier(base_estimator=DecisionTreeClassifier(), bootstrap=True, ccp_alpha=0.0, class_name=RandomForestClassifier, class_weight=None, criterion=gini, estimator_params=('criterion', 'max_depth', 'min_samples_split', 'min_samples_leaf', 'min_weight_fraction_leaf', 'max_features', 'max_leaf_nodes', 'min_impurity_decrease', 'random_state', 'ccp_alpha'), max_depth=10, max_features=auto, max_leaf_nodes=None, max_samples=None, min_impurity_decrease=0.0, min_samples_leaf=1, min_samples_split=2, min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=None, num_outputs=3, oob_score=False, package_name=everywhereml.sklearn.ensemble, random_state=None, template_folder=everywhereml/sklearn/ensemble, verbose=0, warm_start=False)
 */
class RandomForestClassifier {
    public:

        /**
         * Predict class from features
         */
        int predict(float *x) {
            int predictedValue = 0;
            size_t startedAt = micros();

            uint16_t votes[3] = { 0 };
            uint8_t classIdx = 0;
            float classScore = 0;

            tree0(x, &classIdx, &classScore);
            votes[classIdx] += classScore;
            tree1(x, &classIdx, &classScore);
            votes[classIdx] += classScore;
            tree2(x, &classIdx, &classScore);
            votes[classIdx] += classScore;
            tree3(x, &classIdx, &classScore);
            votes[classIdx] += classScore;
            tree4(x, &classIdx, &classScore);
            votes[classIdx] += classScore;
            tree5(x, &classIdx, &classScore);
            votes[classIdx] += classScore;
            tree6(x, &classIdx, &classScore);
            votes[classIdx] += classScore;
            tree7(x, &classIdx, &classScore);
            votes[classIdx] += classScore;
            tree8(x, &classIdx, &classScore);
            votes[classIdx] += classScore;
            tree9(x, &classIdx, &classScore);
            votes[classIdx] += classScore;

            // return argmax of votes
            uint8_t maxClassIdx = 0;
            float maxVote = votes[0];

            for (uint8_t i = 1; i < 3; i++) {
                if (votes[i] > maxVote) {
                    maxClassIdx = i;
                    maxVote = votes[i];
                }
            }

            predictedValue = maxClassIdx;

            latency = micros() - startedAt;

            return (lastPrediction = predictedValue);
        }

        /**
         * Predict class label
         */
        String predictLabel(float *x) {
            return getLabelOf(predict(x));
        }

        /**
         * Get label of last prediction
         */
        String getLabel() {
            return getLabelOf(lastPrediction);
        }

        /**
         * Get label of given class
         */
        String getLabelOf(int8_t idx) {
            switch (idx) {
                case -1:
                    return "ERROR";
                case 0:
                    return "setosa";
                case 1:
                    return "versicolor";
                case 2:
                    return "virginica";
                default:
                    return "UNKNOWN";
            }
        }

        /**
         * Get latency in micros
         */
        uint32_t latencyInMicros() {
            return latency;
        }

        /**
         * Get latency in millis
         */
        uint16_t latencyInMillis() {
            return latency / 1000;
        }

    protected:
        float latency = 0;
        int lastPrediction = 0;

        /**
         * Random forest's tree #0
         */
        void tree0(float *x, uint8_t *classIdx, float *classScore) {
            if (x[3] <= 0.7000000029802322) {
                *classIdx = 0;
                *classScore = 27.0;
                return;
            } else {
                if (x[3] <= 1.699999988079071) {
                    if (x[1] <= 2.649999976158142) {
                        if (x[2] <= 4.75) {
                            *classIdx = 1;
                            *classScore = 41.0;
                            return;
                        } else {
                            *classIdx = 2;
                            *classScore = 37.0;
                            return;
                        }
                    } else {
                        *classIdx = 1;
                        *classScore = 41.0;
                        return;
                    }
                } else {
                    *classIdx = 2;
                    *classScore = 37.0;
                    return;
                }
            }
        }

        /**
         * Random forest's tree #1
         */
        void tree1(float *x, uint8_t *classIdx, float *classScore) {
            if (x[2] <= 4.8500001430511475) {
                if (x[3] <= 0.7000000029802322) {
                    *classIdx = 0;
                    *classScore = 28.0;
                    return;
                } else {
                    if (x[0] <= 6.1499998569488525) {
                        *classIdx = 1;
                        *classScore = 39.0;
                        return;
                    } else {
                        if (x[3] <= 1.699999988079071) {
                            *classIdx = 1;
                            *classScore = 39.0;
                            return;
                        } else {
                            *classIdx = 2;
                            *classScore = 38.0;
                            return;
                        }
                    }
                }
            } else {
                if (x[3] <= 1.649999976158142) {
                    if (x[0] <= 6.200000047683716) {
                        *classIdx = 2;
                        *classScore = 38.0;
                        return;
                    } else {
                        *classIdx = 1;
                        *classScore = 39.0;
                        return;
                    }
                } else {
                    *classIdx = 2;
                    *classScore = 38.0;
                    return;
                }
            }
        }

        /**
         * Random forest's tree #2
         */
        void tree2(float *x, uint8_t *classIdx, float *classScore) {
            if (x[3] <= 0.800000011920929) {
                *classIdx = 0;
                *classScore = 42.0;
                return;
            } else {
                if (x[2] <= 4.950000047683716) {
                    if (x[3] <= 1.649999976158142) {
                        *classIdx = 1;
                        *classScore = 31.0;
                        return;
                    } else {
                        if (x[2] <= 4.8500001430511475) {
                            if (x[1] <= 3.0) {
                                *classIdx = 2;
                                *classScore = 32.0;
                                return;
                            } else {
                                *classIdx = 1;
                                *classScore = 31.0;
                                return;
                            }
                        } else {
                            *classIdx = 2;
                            *classScore = 32.0;
                            return;
                        }
                    }
                } else {
                    *classIdx = 2;
                    *classScore = 32.0;
                    return;
                }
            }
        }

        /**
         * Random forest's tree #3
         */
        void tree3(float *x, uint8_t *classIdx, float *classScore) {
            if (x[3] <= 1.699999988079071) {
                if (x[0] <= 5.549999952316284) {
                    if (x[3] <= 0.7000000029802322) {
                        *classIdx = 0;
                        *classScore = 27.0;
                        return;
                    } else {
                        *classIdx = 1;
                        *classScore = 44.0;
                        return;
                    }
                } else {
                    if (x[1] <= 3.600000023841858) {
                        if (x[2] <= 4.950000047683716) {
                            *classIdx = 1;
                            *classScore = 44.0;
                            return;
                        } else {
                            *classIdx = 2;
                            *classScore = 34.0;
                            return;
                        }
                    } else {
                        *classIdx = 0;
                        *classScore = 27.0;
                        return;
                    }
                }
            } else {
                *classIdx = 2;
                *classScore = 34.0;
                return;
            }
        }

        /**
         * Random forest's tree #4
         */
        void tree4(float *x, uint8_t *classIdx, float *classScore) {
            if (x[3] <= 0.7000000029802322) {
                *classIdx = 0;
                *classScore = 35.0;
                return;
            } else {
                if (x[3] <= 1.6500000357627869) {
                    if (x[1] <= 2.25) {
                        *classIdx = 2;
                        *classScore = 32.0;
                        return;
                    } else {
                        *classIdx = 1;
                        *classScore = 38.0;
                        return;
                    }
                } else {
                    *classIdx = 2;
                    *classScore = 32.0;
                    return;
                }
            }
        }

        /**
         * Random forest's tree #5
         */
        void tree5(float *x, uint8_t *classIdx, float *classScore) {
            if (x[3] <= 0.800000011920929) {
                *classIdx = 0;
                *classScore = 33.0;
                return;
            } else {
                if (x[2] <= 4.8500001430511475) {
                    *classIdx = 1;
                    *classScore = 45.0;
                    return;
                } else {
                    if (x[3] <= 1.649999976158142) {
                        if (x[2] <= 5.25) {
                            *classIdx = 1;
                            *classScore = 45.0;
                            return;
                        } else {
                            *classIdx = 2;
                            *classScore = 27.0;
                            return;
                        }
                    } else {
                        *classIdx = 2;
                        *classScore = 27.0;
                        return;
                    }
                }
            }
        }

        /**
         * Random forest's tree #6
         */
        void tree6(float *x, uint8_t *classIdx, float *classScore) {
            if (x[0] <= 5.450000047683716) {
                if (x[1] <= 2.850000023841858) {
                    *classIdx = 1;
                    *classScore = 37.0;
                    return;
                } else {
                    if (x[0] <= 5.299999952316284) {
                        *classIdx = 0;
                        *classScore = 29.0;
                        return;
                    } else {
                        if (x[3] <= 0.8500000014901161) {
                            *classIdx = 0;
                            *classScore = 29.0;
                            return;
                        } else {
                            *classIdx = 1;
                            *classScore = 37.0;
                            return;
                        }
                    }
                }
            } else {
                if (x[2] <= 4.950000047683716) {
                    if (x[2] <= 2.5) {
                        *classIdx = 0;
                        *classScore = 29.0;
                        return;
                    } else {
                        if (x[3] <= 1.649999976158142) {
                            *classIdx = 1;
                            *classScore = 37.0;
                            return;
                        } else {
                            if (x[1] <= 3.0) {
                                *classIdx = 2;
                                *classScore = 39.0;
                                return;
                            } else {
                                *classIdx = 1;
                                *classScore = 37.0;
                                return;
                            }
                        }
                    }
                } else {
                    *classIdx = 2;
                    *classScore = 39.0;
                    return;
                }
            }
        }

        /**
         * Random forest's tree #7
         */
        void tree7(float *x, uint8_t *classIdx, float *classScore) {
            if (x[3] <= 0.7000000029802322) {
                *classIdx = 0;
                *classScore = 35.0;
                return;
            } else {
                if (x[3] <= 1.699999988079071) {
                    if (x[2] <= 4.950000047683716) {
                        *classIdx = 1;
                        *classScore = 42.0;
                        return;
                    } else {
                        *classIdx = 2;
                        *classScore = 28.0;
                        return;
                    }
                } else {
                    *classIdx = 2;
                    *classScore = 28.0;
                    return;
                }
            }
        }

        /**
         * Random forest's tree #8
         */
        void tree8(float *x, uint8_t *classIdx, float *classScore) {
            if (x[0] <= 5.450000047683716) {
                if (x[2] <= 2.449999988079071) {
                    *classIdx = 0;
                    *classScore = 39.0;
                    return;
                } else {
                    if (x[0] <= 4.950000047683716) {
                        *classIdx = 2;
                        *classScore = 29.0;
                        return;
                    } else {
                        *classIdx = 1;
                        *classScore = 37.0;
                        return;
                    }
                }
            } else {
                if (x[2] <= 4.950000047683716) {
                    if (x[2] <= 2.600000023841858) {
                        *classIdx = 0;
                        *classScore = 39.0;
                        return;
                    } else {
                        if (x[3] <= 1.699999988079071) {
                            *classIdx = 1;
                            *classScore = 37.0;
                            return;
                        } else {
                            if (x[0] <= 5.75) {
                                *classIdx = 2;
                                *classScore = 29.0;
                                return;
                            } else {
                                if (x[0] <= 6.049999952316284) {
                                    *classIdx = 1;
                                    *classScore = 37.0;
                                    return;
                                } else {
                                    *classIdx = 2;
                                    *classScore = 29.0;
                                    return;
                                }
                            }
                        }
                    }
                } else {
                    *classIdx = 2;
                    *classScore = 29.0;
                    return;
                }
            }
        }

        /**
         * Random forest's tree #9
         */
        void tree9(float *x, uint8_t *classIdx, float *classScore) {
            if (x[2] <= 2.599999964237213) {
                *classIdx = 0;
                *classScore = 45.0;
                return;
            } else {
                if (x[3] <= 1.699999988079071) {
                    if (x[1] <= 2.25) {
                        if (x[2] <= 4.75) {
                            *classIdx = 1;
                            *classScore = 26.0;
                            return;
                        } else {
                            *classIdx = 2;
                            *classScore = 34.0;
                            return;
                        }
                    } else {
                        *classIdx = 1;
                        *classScore = 26.0;
                        return;
                    }
                } else {
                    if (x[0] <= 6.0) {
                        if (x[0] <= 5.8500001430511475) {
                            *classIdx = 2;
                            *classScore = 34.0;
                            return;
                        } else {
                            if (x[1] <= 3.100000023841858) {
                                *classIdx = 2;
                                *classScore = 34.0;
                                return;
                            } else {
                                *classIdx = 1;
                                *classScore = 26.0;
                                return;
                            }
                        }
                    } else {
                        *classIdx = 2;
                        *classScore = 34.0;
                        return;
                    }
                }
            }
        }
};



static RandomForestClassifier classifier;


#endif
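Each treeN() method is just a chain of if/else comparisons against x[0]..x[3]. Judging from the thresholds, the indices appear to follow the DataFrame column order (sepal length, sepal width, petal length, petal width), so a test like `x[3] <= 0.7000000029802322` is the "narrow petals means setosa" rule. Below, tree #7 is rewritten with named features and its nested branches flattened, for reading only; `Feature` and `tree7Readable` are illustrative names, not part of the generated code.

```cpp
// Column order assumed to match the exported code: x[0]..x[3].
enum Feature { SEPAL_LENGTH = 0, SEPAL_WIDTH = 1, PETAL_LENGTH = 2, PETAL_WIDTH = 3 };

// Same decisions as tree7 in the exported header, returning only the class index.
// Thresholds are copied verbatim from the exported code.
int tree7Readable(const float *x) {
    if (x[PETAL_WIDTH] <= 0.7000000029802322)
        return 0;  // setosa
    if (x[PETAL_WIDTH] <= 1.699999988079071 && x[PETAL_LENGTH] <= 4.950000047683716)
        return 1;  // versicolor
    return 2;      // virginica
}
```

The thresholds are kept verbatim on purpose: 1.699999988079071 sits just below the float value of 1.7f, so a sample with a petal width of exactly 1.7 goes to the virginica branch here, while a rounded `<= 1.7f` would send it to versicolor.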

Import into Arduino sketch

To finish the process, we have to call the exported code from our main sketch.

This sample sketch reads Iris samples from the Serial Monitor, one comma-separated line at a time, and prints the predicted class. A few demo samples are provided for you to try.

// file IrisClassification.ino
/**
 * This is a demo project to classify the Iris dataset using
 * the classifier fitted in Python with the everywhereml package
 * and exported to Classifier.h.
 *
 * It reads the data from the Serial Monitor and performs
 * classification.
 */
#include "Classifier.h"


void setup() {
    Serial.begin(115200);
    Serial.println("This is a demo for the Iris dataset classification");
    Serial.println("Paste 4 comma-separated features (one sample per line) into the Serial Monitor and get the prediction back");
    Serial.println("Here are a few examples for each class: ");
    Serial.println(" > Class 0");
    Serial.println("   > 5.1,3.5,1.4,0.2");
    Serial.println("   > 4.9,3.0,1.4,0.2");
    Serial.println("   > 4.7,3.2,1.3,0.2");
    Serial.println(" > Class 1");
    Serial.println("   > 7.0,3.2,4.7,1.4");
    Serial.println("   > 6.4,3.2,4.5,1.5");
    Serial.println("   > 6.9,3.1,4.9,1.5");
    Serial.println(" > Class 2");
    Serial.println("   > 6.3,3.3,6.0,2.5");
    Serial.println("   > 5.8,2.7,5.1,1.9");
    Serial.println("   > 7.1,3.0,5.9,2.1");
}


void loop() {
    float features[4];
    
    if (!Serial.available())
        return;
        
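    // read one comma-separated sample: the first 3 values end with ',',
    // the last one ends with the newline sent by the Serial Monitor
    for (int i = 0; i < 3; i++)
        features[i] = Serial.readStringUntil(',').toFloat();

    features[3] = Serial.readStringUntil('\n').toFloat();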
    
    // run prediction and print result
    Serial.print("Predicted class: ");
    Serial.println(classifier.predict(features));
    Serial.print("Predicted class label: ");
    Serial.println(classifier.predictLabel(features));
    Serial.print("It took ");
    Serial.print(classifier.latencyInMicros());
    Serial.println(" micros");
}

The main points of interest in the sketch are classifier.predict(features), which returns the predicted class index as a number, and classifier.predictLabel(features), which returns the predicted class name as a String.
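Both calls rely on the sketch having parsed four numbers correctly, and String::toFloat() returns 0 when it cannot parse a value, so a malformed line silently becomes a sample with zeros in it. If you want to test the parsing on your PC, or reject incomplete lines before calling predict(), a standard C++ equivalent built on std::strtof could look like this; `parseFeatures` is a hypothetical helper, not part of the Arduino core or everywhereml.

```cpp
#include <cstddef>
#include <cstdlib>

// Parses a line such as "5.1,3.5,1.4,0.2" into `out`.
// Returns how many values were read (never more than maxValues),
// stopping at the first position that does not hold a number followed by ','.
std::size_t parseFeatures(const char *line, float *out, std::size_t maxValues) {
    std::size_t count = 0;
    const char *cursor = line;

    while (count < maxValues && *cursor != '\0') {
        char *end = nullptr;
        float value = std::strtof(cursor, &end);

        if (end == cursor)
            break;  // no number at this position

        out[count++] = value;
        cursor = end;

        if (*cursor != ',')
            break;  // end of line, '\r', '\n' or an unexpected character

        cursor++;  // skip the separator
    }
    return count;
}
```

Taking a plain `const char *` keeps it close to what you could call on a board (Arduino's String offers c_str()), though whether std::strtof is available depends on your core's C library. In the sketch you would only call classifier.predict() when the function returns 4.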

Well done!

Congratulations, you implemented your first TinyML project for Arduino.

This is a toy example and may not be very useful in the real world, but it showcased a few important constructs of the everywhereml library and the key steps to follow when developing your own TinyML project.

In the next lessons, we'll dig into each of these steps in more detail.



© Copyright 2023 Eloquent Arduino. All Rights Reserved.