- Python Machine Learning Cookbook (Second Edition)
- Giuseppe Ciaburro, Prateek Joshi
- 2021-06-24 15:41:00
How to do it...
Let's see how to extract confidence measurements:
- The full code is given in the svm_confidence.py file, already provided to you; we will discuss the recipe's code here. Let's start by loading the input data:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.svm import SVC
import utilities  # helper module provided with the book's code (load_data, plot_classifier)
# Load input data
input_file = 'data_multivar.txt'
X, y = utilities.load_data(input_file)
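The utilities module ships with the book's code and its internals are not shown in this recipe. To give a rough idea of what load_data might do, here is a hypothetical stand-in. It assumes that data_multivar.txt is comma-separated, with the feature values in every column except the last and the class label in the last column. Both the name load_data_sketch and that file layout are assumptions, not the book's actual helper.

```python
import io

import numpy as np


def load_data_sketch(source):
    """Hypothetical stand-in for utilities.load_data.

    Assumes comma-separated rows: all columns but the last are features,
    the last column is the class label.
    """
    data = np.loadtxt(source, delimiter=",", ndmin=2)
    X = data[:, :-1]  # feature columns
    y = data[:, -1]   # label column
    return X, y


# Usage with an in-memory file in place of data_multivar.txt
sample = io.StringIO("2.0,1.5,0\n8.0,9.0,1\n4.8,5.2,1\n")
X, y = load_data_sketch(sample)
print(X.shape, y)
```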
- Next, we split the data into training and test sets and build the classifier:
from sklearn import model_selection
X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.25, random_state=5)
params = {'kernel': 'rbf'}
classifier = SVC(**params, gamma='auto')
classifier.fit(X_train, y_train)
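The 'rbf' kernel measures the similarity of two points as K(x, x') = exp(-gamma * ||x - x'||^2). With gamma='auto', scikit-learn sets gamma = 1 / n_features, so for two-feature data like this gamma = 0.5. Passing gamma='auto' explicitly keeps that older behavior. Since scikit-learn 0.22 the default has been 'scale', which uses 1 / (n_features * X.var()). The NumPy sketch below computes the kernel by hand. The name rbf_kernel_auto is illustrative and is not a scikit-learn function.

```python
import numpy as np


def rbf_kernel_auto(A, B):
    """RBF kernel matrix between the rows of A and the rows of B,
    using gamma = 1 / n_features (what SVC uses when gamma='auto')."""
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    gamma = 1.0 / A.shape[1]
    # Squared Euclidean distance between every row of A and every row of B
    sq_dists = ((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-gamma * sq_dists)


K = rbf_kernel_auto([[0, 0], [1, 1]], [[0, 0], [1, 1]])
print(K)
```

Identical points get a similarity of 1, and the similarity decays toward 0 as points move apart. Gamma controls how quickly that decay happens.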
- Define the input datapoints:
input_datapoints = np.array([[2, 1.5], [8, 9], [4.8, 5.2], [4, 4], [2.5, 7], [7.6, 2], [5.4, 5.9]])
- Let's measure the distance from the boundary:
print("Distance from the boundary:")
for i in input_datapoints:
    print(i, '-->', classifier.decision_function([i])[0])
- You will see the distance of each input datapoint from the boundary printed on your Terminal.
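For a two-class SVC, the values printed by the distance loop come from a kernel expansion over the support vectors: f(x) = sum_i (alpha_i * y_i) * K(sv_i, x) + b. scikit-learn exposes the products alpha_i * y_i as dual_coef_, the support vectors as support_vectors_, and b as intercept_. A positive value places the point on the classes_[1] side of the boundary. The sketch below checks this on synthetic two-blob data, which is an assumption standing in for data_multivar.txt. It uses an explicit gamma=0.5 so that the kernel can be recomputed with sklearn.metrics.pairwise.rbf_kernel.

```python
import numpy as np
from sklearn.metrics.pairwise import rbf_kernel
from sklearn.svm import SVC

# Synthetic stand-in for data_multivar.txt: two Gaussian blobs in 2D
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(3, 1, (50, 2)), rng.normal(7, 1, (50, 2))])
y = np.array([0] * 50 + [1] * 50)

classifier = SVC(kernel='rbf', gamma=0.5)
classifier.fit(X, y)

points = np.array([[2, 1.5], [8, 9], [4.8, 5.2], [4, 4]])

# Kernel between every support vector and every query point: (n_sv, n_points)
K = rbf_kernel(classifier.support_vectors_, points, gamma=0.5)
# dual_coef_ has shape (1, n_sv) in the binary case
manual = (classifier.dual_coef_ @ K + classifier.intercept_).ravel()

# Left column: manual expansion; right column: decision_function
print(np.column_stack([manual, classifier.decision_function(points)]))
```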
- The distance from the boundary gives us some information about the datapoint, but it doesn't directly tell us how confident the classifier is about the output label. For that, we need Platt scaling, a method that converts the distance measure into a probability over the classes by fitting a sigmoid function to the classifier's decision values. Let's go ahead and train an SVM using Platt scaling:
# Confidence measure
params = {'kernel': 'rbf', 'probability': True}
classifier = SVC(**params, gamma='auto')
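The probability parameter tells the SVM to also learn how to compute probabilities during training. It must be set before calling fit, and it slows training down, because scikit-learn fits the Platt-scaling model using an internal 5-fold cross-validation.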
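Platt scaling models the probability of class 1 as a sigmoid of the decision value f: P(y = 1 | f) = 1 / (1 + exp(A * f + B)). The parameters A and B are learned from data, and A comes out negative when larger f means class 1. The sketch below only applies that mapping. The values A = -2.0 and B = 0.0 are made-up illustrations, not parameters fitted to this dataset.

```python
import math


def platt_probability(f, A, B):
    """Map an SVM decision value f to P(y=1 | f) with Platt's sigmoid."""
    return 1.0 / (1.0 + math.exp(A * f + B))


# Hypothetical sigmoid parameters, for illustration only
A, B = -2.0, 0.0
for f in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(f"decision value {f:+.1f} -> P(class 1) = {platt_probability(f, A, B):.3f}")
```

Points far from the boundary get probabilities close to 0 or 1. Points near the boundary get values near 0.5. After fitting with probability=True, scikit-learn keeps its learned sigmoid parameters in the probA_ and probB_ attributes.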
- Let's train the classifier:
classifier.fit(X_train, y_train)
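The sigmoid is fitted inside fit, so probability=True has to be set before training. An SVC trained without it does not offer predict_proba at all. Here is a quick check on synthetic data, which is an assumption used in place of data_multivar.txt. With probability=True, random_state controls the data shuffling used by the internal cross-validation, so fixing it makes the probabilities reproducible.

```python
import numpy as np
from sklearn.svm import SVC

# Synthetic two-class data standing in for data_multivar.txt
rng = np.random.default_rng(1)
X = np.vstack([rng.normal(3, 1, (40, 2)), rng.normal(7, 1, (40, 2))])
y = np.array([0] * 40 + [1] * 40)

plain = SVC(kernel='rbf', gamma='auto').fit(X, y)
calibrated = SVC(kernel='rbf', gamma='auto', probability=True, random_state=5).fit(X, y)

# Accessing predict_proba raises AttributeError when probability=False,
# so hasattr reports whether probability estimates are available
print(hasattr(plain, 'predict_proba'))       # False
print(hasattr(calibrated, 'predict_proba'))  # True
```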
- Let's compute the confidence measurements for these input datapoints:
print("Confidence measure:")
for i in input_datapoints:
    print(i, '-->', classifier.predict_proba([i])[0])
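The predict_proba method returns, for each datapoint, the estimated probability of belonging to each class; we use these probabilities as the confidence measure.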
- You will see the class probabilities of each input datapoint printed on your Terminal.
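Each row returned by predict_proba holds one probability per class, in the order given by classifier.classes_, and each row sums to 1. The probabilities come from a separately fitted sigmoid. For that reason, scikit-learn's documentation notes they can occasionally disagree with predict for points near the boundary. The sketch below uses the same input datapoints with synthetic stand-in training data, which is an assumption. It checks the row sums and compares the most probable class with the predicted one.

```python
import numpy as np
from sklearn.svm import SVC

# Synthetic two-class data standing in for data_multivar.txt
rng = np.random.default_rng(2)
X = np.vstack([rng.normal(3, 1, (60, 2)), rng.normal(7, 1, (60, 2))])
y = np.array([0] * 60 + [1] * 60)

classifier = SVC(kernel='rbf', gamma='auto', probability=True, random_state=5)
classifier.fit(X, y)

points = np.array([[2, 1.5], [8, 9], [4.8, 5.2], [4, 4], [2.5, 7], [7.6, 2], [5.4, 5.9]])
proba = classifier.predict_proba(points)

print(classifier.classes_)  # column order of proba
print(proba.sum(axis=1))    # one sum per row, each equal to 1 up to rounding
# Most probable class per point vs. the class chosen by predict
most_probable = classifier.classes_[proba.argmax(axis=1)]
print(most_probable, classifier.predict(points))
```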
- Let's see where the points are with respect to the boundary:
utilities.plot_classifier(classifier, input_datapoints, [0]*len(input_datapoints), 'Input datapoints', 'True')
- If you run this, you will get a plot showing the input datapoints with respect to the classifier's boundary.
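utilities.plot_classifier is another helper from the book's code, and its internals aren't shown here. A common way to draw this kind of plot is to evaluate the classifier on a dense grid, shade the predicted regions with contourf, and scatter the points on top. The sketch below follows that approach as an assumption, not as the book's implementation. The function name plot_decision_regions and the synthetic training data are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; remove this line to open a window
import matplotlib.pyplot as plt
import numpy as np
from sklearn.svm import SVC


def plot_decision_regions(classifier, X, y, title, step=0.05):
    """Hypothetical stand-in for utilities.plot_classifier."""
    # Grid covering the points with a one-unit margin on each side
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, step), np.arange(y_min, y_max, step))
    # Predicted class for every grid point, reshaped back to the grid
    Z = classifier.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

    fig, ax = plt.subplots()
    ax.contourf(xx, yy, Z, alpha=0.3, cmap="coolwarm")
    ax.scatter(X[:, 0], X[:, 1], c=y, cmap="coolwarm", edgecolors="black")
    ax.set_title(title)
    return fig


# Usage with synthetic training data standing in for data_multivar.txt
rng = np.random.default_rng(3)
X_train = np.vstack([rng.normal(3, 1, (50, 2)), rng.normal(7, 1, (50, 2))])
y_train = np.array([0] * 50 + [1] * 50)
classifier = SVC(kernel='rbf', gamma='auto').fit(X_train, y_train)

input_datapoints = np.array([[2, 1.5], [8, 9], [4.8, 5.2], [4, 4], [2.5, 7], [7.6, 2], [5.4, 5.9]])
fig = plot_decision_regions(classifier, input_datapoints, [0] * len(input_datapoints), 'Input datapoints')
```

As in the recipe, every input datapoint is given the same dummy label (0). Only the points' positions relative to the shaded regions carry information.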