- Python Machine Learning Cookbook (Second Edition)
- Giuseppe Ciaburro, Prateek Joshi
- 2021-06-24 15:40:50
How to do it…
Let's see how to extract validation curves:
- Add the following code to the same Python file you used in the previous recipe, "Evaluating cars based on their characteristics":
```python
# Validation curves
import matplotlib.pyplot as plt
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import validation_curve

# X and y are the encoded car features and labels from the previous recipe
classifier = RandomForestClassifier(max_depth=4, random_state=7)
parameter_grid = np.linspace(25, 200, 8).astype(int)
train_scores, validation_scores = validation_curve(
    classifier, X, y, param_name="n_estimators", param_range=parameter_grid, cv=5)
print("##### VALIDATION CURVES #####")
print("\nParam: n_estimators\nTraining scores:\n", train_scores)
print("\nParam: n_estimators\nValidation scores:\n", validation_scores)
```
In this case, we fixed max_depth at 4 and vary only n_estimators, the number of trees in the forest, to estimate the optimal number of estimators. The search space, parameter_grid, holds eight evenly spaced values from 25 to 200 (25, 50, ..., 200). For each value, validation_curve runs 5-fold cross-validation and returns the training and validation accuracy as arrays with one row per parameter value and one column per fold.
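To show what validation_curve computes, the following sketch reproduces it with an explicit loop over parameter_grid. The loop calls scikit-learn's cross_validate once per candidate value and compares the result with validation_curve. The car data's X and y are replaced by a synthetic dataset from make_classification. This substitution is an assumption made only so the snippet runs on its own. With the car data loaded, the same loop applies unchanged.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_validate, validation_curve

# Synthetic stand-in for the car dataset's X and y (assumption: not the book's data)
X, y = make_classification(n_samples=300, n_features=6, n_informative=4, random_state=7)

classifier = RandomForestClassifier(max_depth=4, random_state=7)
parameter_grid = np.linspace(25, 200, 8).astype(int)  # 25, 50, ..., 200

# Manual version: one 5-fold cross-validation per candidate n_estimators
manual_train, manual_valid = [], []
for n_estimators in parameter_grid:
    classifier.set_params(n_estimators=n_estimators)
    cv_results = cross_validate(classifier, X, y, cv=5, return_train_score=True)
    manual_train.append(cv_results["train_score"])  # accuracy on the training part of each split
    manual_valid.append(cv_results["test_score"])   # accuracy on the held-out part of each split
manual_train = np.array(manual_train)  # shape (8, 5): parameter values x folds
manual_valid = np.array(manual_valid)

# The same computation through validation_curve
train_scores, validation_scores = validation_curve(
    RandomForestClassifier(max_depth=4, random_state=7), X, y,
    param_name="n_estimators", param_range=parameter_grid, cv=5)

print(train_scores.shape)  # (8, 5)
print(np.allclose(train_scores, manual_train), np.allclose(validation_scores, manual_valid))
```

Both versions use the same stratified 5-fold splits and a fixed random_state, so the score arrays should match. In other words, validation_curve is this loop with the bookkeeping done for you.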
- If you run it, you will see the following on your Terminal:
- Let's plot it:
```python
# Plot the curve: mean training accuracy across the 5 folds, as a percentage
plt.figure()
plt.plot(parameter_grid, 100*np.average(train_scores, axis=1), color='black')
plt.title('Training curve')
plt.xlabel('Number of estimators')
plt.ylabel('Accuracy (%)')
plt.show()
```
- Here is what you'll get:
- Let's do the same for the max_depth parameter:
```python
classifier = RandomForestClassifier(n_estimators=20, random_state=7)
parameter_grid = np.linspace(2, 10, 5).astype(int)
train_scores, validation_scores = validation_curve(
    classifier, X, y, param_name="max_depth", param_range=parameter_grid, cv=5)
print("\nParam: max_depth\nTraining scores:\n", train_scores)
print("\nParam: max_depth\nValidation scores:\n", validation_scores)
```
We fixed n_estimators at 20 and let max_depth take five values (2, 4, 6, 8, and 10) to see how performance varies with tree depth. Here is the output on the Terminal:
- Let's plot it:
```python
# Plot the curve: mean training accuracy across the 5 folds for each max_depth
plt.figure()
plt.plot(parameter_grid, 100*np.average(train_scores, axis=1), color='black')
plt.title('Training curve')
plt.xlabel('Maximum depth of the tree')
plt.ylabel('Accuracy (%)')
plt.show()
```
- If you run this code, you will get the following:
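Both plots above show only the mean training accuracy. Training accuracy typically keeps rising as the model gets more capacity, so it cannot reveal overfitting by itself. A common way to read a validation curve is to plot the mean training and validation accuracy together, with a band of plus or minus one standard deviation across folds. The sketch below does this for max_depth.

A few choices here are assumptions rather than the book's method:

- make_classification again stands in for the car data.
- The Agg backend and the file name validation_curve_max_depth.png are illustrative.
- Choosing the max_depth with the highest mean validation accuracy is one simple selection rule.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import validation_curve

# Synthetic stand-in for the car dataset's X and y (assumption: not the book's data)
X, y = make_classification(n_samples=300, n_features=6, n_informative=4, random_state=7)

classifier = RandomForestClassifier(n_estimators=20, random_state=7)
parameter_grid = np.linspace(2, 10, 5).astype(int)  # 2, 4, 6, 8, 10
train_scores, validation_scores = validation_curve(
    classifier, X, y, param_name="max_depth", param_range=parameter_grid, cv=5)

# Mean and spread across the 5 folds, in percent (one value per max_depth)
train_mean = 100 * train_scores.mean(axis=1)
train_std = 100 * train_scores.std(axis=1)
valid_mean = 100 * validation_scores.mean(axis=1)
valid_std = 100 * validation_scores.std(axis=1)

fig, ax = plt.subplots()
ax.plot(parameter_grid, train_mean, color="black", label="Training")
ax.fill_between(parameter_grid, train_mean - train_std, train_mean + train_std,
                color="black", alpha=0.15)
ax.plot(parameter_grid, valid_mean, color="black", linestyle="--", label="Validation")
ax.fill_between(parameter_grid, valid_mean - valid_std, valid_mean + valid_std,
                color="black", alpha=0.15)
ax.set_title("Validation curve")
ax.set_xlabel("Maximum depth of the tree")
ax.set_ylabel("Accuracy (%)")
ax.legend()
fig.savefig("validation_curve_max_depth.png")

# Depth with the highest mean validation accuracy (a simple selection rule)
best_depth = parameter_grid[np.argmax(valid_mean)]
print("Best max_depth:", best_depth)
```

When running with a display, you can drop the matplotlib.use call and replace fig.savefig with plt.show(). A training curve that keeps climbing while the validation curve flattens or drops suggests the trees are overfitting.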