- Python Machine Learning Cookbook (Second Edition)
- Giuseppe Ciaburro, Prateek Joshi
- 2021-06-24 15:40:56
How to do it...
In this recipe, we will learn how to build a linear classifier using SVMs:
- We need to split our dataset into training and testing datasets. Add the following lines to the same Python file:
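# Train/test split and SVM training
# (sklearn.cross_validation was removed in scikit-learn 0.20; train_test_split now lives in sklearn.model_selection)
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=5)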
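To make the split concrete, here is a minimal pure-Python sketch of what a shuffled train/test split does. The helper `simple_train_test_split` is hypothetical. It shuffles row indices with a seeded `random.Random` and rounds the test count up, as scikit-learn does for a float `test_size`. It will not reproduce scikit-learn's exact rows for the same `random_state`.

```python
import math
import random


def simple_train_test_split(X, y, test_size=0.25, random_state=None):
    """Hypothetical helper: shuffle row indices, then carve off a test fraction."""
    if len(X) != len(y):
        raise ValueError("X and y must have the same number of rows")
    indices = list(range(len(X)))
    # A dedicated, seeded RNG makes the shuffle reproducible
    random.Random(random_state).shuffle(indices)
    # Round the test count up, as scikit-learn does for a float test_size
    n_test = math.ceil(test_size * len(X))
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    X_train = [X[i] for i in train_idx]
    X_test = [X[i] for i in test_idx]
    y_train = [y[i] for i in train_idx]
    y_test = [y[i] for i in test_idx]
    return X_train, X_test, y_train, y_test


# Toy data: 100 two-feature rows; the label is the parity of the first feature
X = [[float(i), float(i % 3)] for i in range(100)]
y = [i % 2 for i in range(100)]
X_train, X_test, y_train, y_test = simple_train_test_split(X, y, test_size=0.25, random_state=5)
print(len(X_train), len(X_test))  # 75 25
```

With 100 samples and `test_size=0.25`, 25 rows go to the test set and 75 to the training set. Every label stays paired with its row.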
- Let's initialize the SVM object using a linear kernel. Add the following lines to the file:
params = {'kernel': 'linear'}
# gamma only affects the 'rbf', 'poly', and 'sigmoid' kernels, so it is ignored here
classifier = SVC(**params, gamma='auto')
- We are now ready to train the linear SVM classifier:
classifier.fit(X_train, y_train)
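With a linear kernel, the fitted `SVC` is just a hyperplane w·x + b = 0. scikit-learn exposes it through `coef_` and `intercept_`, and `coef_` is only available for `kernel='linear'`. The sketch below uses hypothetical toy blobs rather than the recipe's data file. It checks two things: that `decision_function` equals `X @ coef_.T + intercept_`, and that a positive score maps to `classes_[1]`. It also shows that `SVC(**params, gamma='auto')` simply unpacks the dictionary into keyword arguments.

```python
import numpy as np
from sklearn.svm import SVC

# Hypothetical toy data: two well-separated 2-D blobs (not the recipe's data file)
rng = np.random.RandomState(0)
X = np.vstack([rng.randn(20, 2) - 2, rng.randn(20, 2) + 2])
y = np.array([0] * 20 + [1] * 20)

params = {'kernel': 'linear'}
# **params unpacks to kernel='linear'; gamma is ignored by the linear kernel
classifier = SVC(**params, gamma='auto')
classifier.fit(X, y)

# The learned hyperplane: w . x + b = 0
w = classifier.coef_[0]
b = classifier.intercept_[0]
# Raw scores, proportional to each point's signed distance from the hyperplane
manual_scores = X @ w + b

# The manual scores should match decision_function ...
print(np.allclose(manual_scores, classifier.decision_function(X)))
# ... and the sign picks the class: positive -> classes_[1], otherwise classes_[0]
manual_pred = np.where(manual_scores > 0, classifier.classes_[1], classifier.classes_[0])
print(np.array_equal(manual_pred, classifier.predict(X)))
```

Both checks should print `True`. This is why a linear SVM can only ever draw a single straight boundary in two dimensions.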
- We can now see how the classifier performs:
utilities.plot_classifier(classifier, X_train, y_train, 'Training dataset')
plt.show()
If you run this code, you will get the following:
The plot_classifier function is the same one we discussed in Chapter 1, The Realm of Supervised Learning, with a couple of minor additions.
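The Chapter 1 helper is not reproduced here, so the following is only a hedged sketch of how a plot_classifier-style function typically draws decision regions. It evaluates `predict` on a dense meshgrid, shades the grid by predicted class, and overlays the data points. The names `decision_grid`, `plot_decision_regions`, and `ThresholdClassifier` are hypothetical. The Agg backend is selected only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; remove it to get an interactive window
import matplotlib.pyplot as plt
import numpy as np


def decision_grid(classifier, X, step=0.01, margin=1.0):
    """Predict a class for every point of a grid that covers the 2-D data X."""
    x_min, x_max = X[:, 0].min() - margin, X[:, 0].max() + margin
    y_min, y_max = X[:, 1].min() - margin, X[:, 1].max() + margin
    xx, yy = np.meshgrid(np.arange(x_min, x_max, step),
                         np.arange(y_min, y_max, step))
    # Flatten the grid into (n_points, 2), predict, then restore the grid shape
    labels = classifier.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
    return xx, yy, labels


def plot_decision_regions(classifier, X, y, title=''):
    """Shade each grid cell by its predicted class and overlay the labeled points."""
    xx, yy, labels = decision_grid(classifier, X)
    fig, ax = plt.subplots()
    ax.pcolormesh(xx, yy, labels, cmap=plt.cm.gray, shading='auto')
    ax.scatter(X[:, 0], X[:, 1], c=y, s=60, edgecolors='black', cmap=plt.cm.Paired)
    ax.set_title(title)
    ax.set_xlim(xx.min(), xx.max())
    ax.set_ylim(yy.min(), yy.max())
    return fig, ax


class ThresholdClassifier:
    """Hypothetical stand-in model: class 1 wherever the first feature is positive."""
    def predict(self, X):
        return (np.asarray(X)[:, 0] > 0).astype(int)


X_demo = np.array([[-1.0, 0.5], [-0.5, -1.0], [0.5, 1.0], [1.0, -0.5]])
y_demo = np.array([0, 0, 1, 1])
fig, ax = plot_decision_regions(ThresholdClassifier(), X_demo, y_demo, 'Toy dataset')
xx, yy, labels = decision_grid(ThresholdClassifier(), X_demo)
```

Remove the Agg line and swap in a fitted `SVC` for `ThresholdClassifier`. Then `plot_decision_regions(classifier, X_train, y_train, 'Training dataset')` followed by `plt.show()` would play roughly the same role as the recipe's call.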
- Let's see how this performs on the test dataset. Add the following lines to the svm.py file:
y_test_pred = classifier.predict(X_test)
utilities.plot_classifier(classifier, X_test, y_test, 'Test dataset')
plt.show()
If you run this code, you will see the following output:
As you can see, the classifier boundaries on the input data are clearly identified.
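- Let's evaluate the classifier on the training set with a classification report, which lists per-class precision, recall, and F1 score. Add the following lines to the same file:
from sklearn.metrics import classification_report
# sorted() keeps the names in the same (sorted) label order that classification_report uses
target_names = ['Class-' + str(int(i)) for i in sorted(set(y))]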
print("\n" + "#"*30)
print("\nClassifier performance on training dataset\n")
print(classification_report(y_train, classifier.predict(X_train), target_names=target_names))
print("#"*30 + "\n")
If you run this code, you will see the following on your Terminal:
- Finally, let's see the classification report for the testing dataset:
print("#"*30)
print("\nClassification report on test dataset\n")
print(classification_report(y_test, y_test_pred, target_names=target_names))
print("#"*30 + "\n")
If you run this code, you will see the following on the Terminal:
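Each per-class row of `classification_report` comes from simple counts of true positives (TP), false positives (FP), and false negatives (FN):

- precision = TP / (TP + FP)
- recall = TP / (TP + FN)
- F1 is the harmonic mean of precision and recall
- support is the number of true samples of that class

The pure-Python helper below is a hypothetical sketch of that arithmetic on made-up labels, not scikit-learn's implementation. Undefined ratios (division by zero) are simply set to 0.0 here.

```python
def per_class_report(y_true, y_pred):
    """Hypothetical helper: precision, recall, F1 and support for each label."""
    labels = sorted(set(y_true) | set(y_pred))
    pairs = list(zip(y_true, y_pred))
    report = {}
    for label in labels:
        tp = sum(1 for t, p in pairs if t == label and p == label)
        fp = sum(1 for t, p in pairs if t != label and p == label)
        fn = sum(1 for t, p in pairs if t == label and p != label)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        report[label] = {'precision': precision, 'recall': recall,
                         'f1-score': f1, 'support': tp + fn}
    return report


# Made-up labels: one class-0 sample and one class-1 sample are misclassified
y_true = [0, 0, 0, 1, 1, 1, 1, 1]
y_pred = [0, 0, 1, 1, 1, 1, 0, 1]
report = per_class_report(y_true, y_pred)
for label, row in report.items():
    print(f"Class-{label}: precision={row['precision']:.2f} recall={row['recall']:.2f} "
          f"f1={row['f1-score']:.2f} support={row['support']}")
# Class-0: precision=0.67 recall=0.67 f1=0.67 support=3
# Class-1: precision=0.80 recall=0.80 f1=0.80 support=5
```

Passing the same `y_true` and `y_pred` to `classification_report` should give the same per-class values once rounded to two decimals.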
From the output screenshot where we visualized the data, we can see that the solid squares are completely surrounded by empty squares. This means that the data is not linearly separable: no single straight line can separate the two sets of points! Hence, we need a nonlinear classifier to separate these data points.
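To see why a straight line fails when one class surrounds the other, the sketch below generates two concentric rings with scikit-learn's `make_circles`. These rings are a hypothetical stand-in for the recipe's data. The sketch then compares the linear kernel with an RBF kernel, which is used here only as one example of a nonlinear classifier. On rings like these, the linear SVM stays far from perfect, while the RBF kernel can wrap a boundary around the inner ring.

```python
import numpy as np
from sklearn.datasets import make_circles
from sklearn.svm import SVC

# Synthetic stand-in for "one class surrounded by the other": an inner ring inside an outer ring
X, y = make_circles(n_samples=400, factor=0.3, noise=0.05, random_state=5)

# Same linear setup as the recipe
linear_clf = SVC(kernel='linear', gamma='auto').fit(X, y)
# A nonlinear kernel for contrast (RBF, default gamma='scale')
rbf_clf = SVC(kernel='rbf').fit(X, y)

# Mean accuracy on the training data
linear_acc = linear_clf.score(X, y)
rbf_acc = rbf_clf.score(X, y)
print(f"linear kernel training accuracy: {linear_acc:.2f}")
print(f"rbf kernel training accuracy:    {rbf_acc:.2f}")
```

Any half-plane that keeps the whole inner ring on one side must also keep most of the outer ring with it. That is why the linear score cannot approach 1.0 here, no matter how the line is placed.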