First, download the free Iris dataset.
Then rename the downloaded file to iris.csv and load it into Spark to check its contents.
scala> val schema = "sepal_length DOUBLE, sepal_width DOUBLE, petal_length DOUBLE, petal_width DOUBLE, species STRING"
scala> val df = spark.read.schema(schema).csv("iris.csv")
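The same typed load can be mirrored in pandas if you want to check the file outside Spark. This is a minimal sketch, not part of the original workflow: `load_iris_csv`, `IRIS_COLUMNS`, and `IRIS_DTYPES` are hypothetical names, and the file is assumed to have no header row, as the Spark read above (no `header` option) implies.

```python
import pandas as pd

# Column names and types chosen to mirror the Spark DDL schema string above
IRIS_COLUMNS = ["sepal_length", "sepal_width", "petal_length", "petal_width", "species"]
IRIS_DTYPES = {
    "sepal_length": "float64",
    "sepal_width": "float64",
    "petal_length": "float64",
    "petal_width": "float64",
    "species": "string",
}


def load_iris_csv(path_or_buffer):
    """Read the header-less Iris CSV with explicit column names and dtypes.

    Blank lines (e.g. at the end of the file) are skipped by pandas by default.
    """
    return pd.read_csv(path_or_buffer, header=None, names=IRIS_COLUMNS, dtype=IRIS_DTYPES)
```

For example, `df = load_iris_csv("iris.csv")` gives a DataFrame whose `df.dtypes` should show four `float64` columns and one `string` column, analogous to the Spark schema.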
It is a small dataset, as the following queries show.
scala> df.agg(countDistinct("species")).show
+--------------+
|count(species)|
+--------------+
|             3|
+--------------+
scala> df.groupBy("species").count.show
+---------------+-----+
|        species|count|
+---------------+-----+
| Iris-virginica|   50|
|    Iris-setosa|   50|
|Iris-versicolor|   50|
+---------------+-----+
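Both Spark queries can be reproduced in pandas as a quick cross-check. This is a minimal sketch; the helper name `label_summary` is hypothetical. Here `nunique()` plays the role of `countDistinct`, and `value_counts()` plays the role of `groupBy(...).count`.

```python
import pandas as pd


def label_summary(df, label_col="species"):
    """Return the number of distinct labels and the row count per label.

    Roughly equivalent to df.agg(countDistinct("species")) and
    df.groupBy("species").count in Spark.
    """
    n_classes = df[label_col].nunique()
    # Sort by label so the output order is stable (Spark's groupBy order is not guaranteed)
    per_class = df[label_col].value_counts().sort_index()
    return n_classes, per_class
```

For example, after `df = pd.read_csv("iris.csv", names=["sepal_length", "sepal_width", "petal_length", "petal_width", "species"])`, calling `label_summary(df)` on the same file should report 3 distinct labels with 50 rows each, in line with the Spark output above.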
The "species" column is the label, and the other four columns are the features. There are only three species (50 rows each), so this is a small three-class classification problem.
Next, we write a Python script for this job, using scikit-learn (sklearn) as the ML library.
import joblib
import pandas as pd
from sklearn import svm

# Read the training data; the CSV has no header row, so supply the column names
iris_data = pd.read_csv('./iris.csv', sep=',', names=["sepal_length", "sepal_width", "petal_length", "petal_width", "species"])
# Separate the label: pop() removes the "species" column and returns it,
# leaving only the four feature columns in iris_data
iris_label = iris_data.pop('species')
# Use SVC (support vector classifier), the classification form of an SVM (support vector machine)
classifier = svm.SVC(gamma='auto')
# Train the model on the features and labels
classifier.fit(iris_data, iris_label)
# Save the trained model to a local file
model_filename = 'model.joblib'
joblib.dump(classifier, model_filename)
As you can see, it takes only a few lines of code to train a model for this classification job.
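Once model.joblib is saved, it can be loaded back with `joblib.load` and used for prediction. The sketch below assumes the model was trained by the script above on a pandas DataFrame with the four named feature columns; `predict_species` and `FEATURES` are hypothetical names. Passing a DataFrame with the same column names, rather than a bare list, avoids scikit-learn's warning that the input lacks the feature names the model was fitted with.

```python
import joblib
import pandas as pd

# Feature columns in the same order used for training (hypothetical constant)
FEATURES = ["sepal_length", "sepal_width", "petal_length", "petal_width"]


def predict_species(model_path, measurements):
    """Load a saved classifier and predict a species label for each row.

    `measurements` is a list of [sepal_length, sepal_width, petal_length, petal_width]
    rows, in the column order used for training.
    """
    classifier = joblib.load(model_path)
    # Wrap the rows in a DataFrame with the training column names
    X = pd.DataFrame(measurements, columns=FEATURES)
    return list(classifier.predict(X))
```

For example, `predict_species("model.joblib", [[5.1, 3.5, 1.4, 0.2]])` returns a one-element list holding the predicted species string.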
Generated on 09/30/22