Here we can download the free Iris dataset.

Iris dataset

Now we rename the file with a .csv extension and load it into Spark for a quick check.

scala> val schema = "sepal_length DOUBLE, sepal_width DOUBLE, petal_length DOUBLE, petal_width DOUBLE, species STRING"
scala> val df = spark.read.schema(schema).csv("iris.csv")

This is a small dataset, as you can see below.

scala> df.agg(countDistinct("species")).show
+--------------+
|count(species)|
+--------------+
|             3|
+--------------+

scala> df.groupBy("species").count.show
+---------------+-----+
|        species|count|
+---------------+-----+
| Iris-virginica|   50|
|    Iris-setosa|   50|
|Iris-versicolor|   50|
+---------------+-----+

The "species" column is the label and the other four columns are the features. With only three species, this is a simple three-class classification problem.
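As a quick sanity check, a short pandas snippet can show how the feature columns and the label line up. This is just a minimal sketch, assuming the same iris.csv (with no header row) sits in the working directory:

import pandas as pd

# The file has no header row, so name the columns explicitly
cols = ["sepal_length", "sepal_width", "petal_length", "petal_width", "species"]
iris = pd.read_csv("./iris.csv", names=cols)

print(iris[cols[:-1]].head())    # the four numeric feature columns
print(iris["species"].unique())  # the three class labels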

So we write a short script for this job, using scikit-learn as the ML library.

import joblib
import pandas as pd
from sklearn import svm

# Read the training data; the file has no header row, so name the columns explicitly
iris_data = pd.read_csv('./iris.csv', sep=',',
                        names=["sepal_length", "sepal_width", "petal_length", "petal_width", "species"])

# Separate the label (target) column from the feature columns
iris_label = iris_data.pop('species')

# Use SVC (support vector classification) from the SVM family
classifier = svm.SVC(gamma='auto')

# Train the model on the full dataset
classifier.fit(iris_data, iris_label)

# Save the trained model locally
model_filename = 'model.joblib'
joblib.dump(classifier, model_filename)

As you can see, with only a few lines of code we can train a model for this classification job.
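To use the saved model, we can load it back with joblib and classify a new measurement. The sketch below assumes model.joblib was written by the script above; the sample values are made up purely for illustration:

import joblib
import pandas as pd

# Load the trained model back from disk
classifier = joblib.load('model.joblib')

# One made-up measurement: sepal_length, sepal_width, petal_length, petal_width
sample = pd.DataFrame([[5.1, 3.5, 1.4, 0.2]],
                      columns=["sepal_length", "sepal_width", "petal_length", "petal_width"])

# Prints a species label, e.g. 'Iris-setosa' for a measurement like this one
print(classifier.predict(sample))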
