# Page 451 - Ai_V3.0_c11_flipbook
# P. 451
# Model Evaluation
rf_pred = rf_pipeline.predict(X_test)
svm_pred = svm_pipeline.predict(X_test)
# accuracy_score needs y_true and y_pred encoded the same way. Fit ONE
# LabelBinarizer on the ground-truth labels and reuse it for both models'
# predictions, so every array maps each class to the same 0/1 value
# (visualize_performance uses the same fit-on-y_test pattern).
# Fitting a separate binarizer on a prediction array made the encoding depend
# on which classes that model happened to predict: a binarizer fitted on
# predictions containing only 'Y' encodes every 'Y' as 0 (its neg_label).
# NOTE(review): assumes y_test holds the same raw labels the pipelines
# predict (e.g. 'N'/'Y'), since the pipelines were trained on them - confirm.
from sklearn.preprocessing import LabelBinarizer
label_binarizer = LabelBinarizer()
y_test_binary = label_binarizer.fit_transform(y_test)
rf_pred_binary = label_binarizer.transform(rf_pred)
svm_pred_binary = label_binarizer.transform(svm_pred)
# Model Evaluation with modified accuracy calculation
rf_accuracy = accuracy_score(y_test_binary, rf_pred_binary)
svm_accuracy = accuracy_score(y_test_binary, svm_pred_binary)
print("Random Forest Accuracy:", rf_accuracy)
print("SVM Accuracy:", svm_accuracy)
else:
# Fallback branch of a guard whose `if` lies before this excerpt; judging by
# the message below, that guard checks the DataFrame loaded and is non-empty.
print("DataFrame is either not loaded successfully or is empty.")
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, roc_curve, roc_auc_score, f1_score
# Visualize model performance
def visualize_performance(y_test, y_pred_rf, y_pred_svm):
# Convert true labels to binary format
label_binarizer = LabelBinarizer()
y_test_binary = label_binarizer.fit_transform(y_test)
# Convert predicted labels to binary format
y_pred_rf_binary = label_binarizer.transform(y_pred_rf)
y_pred_svm_binary = label_binarizer.transform(y_pred_svm)
# Confusion matrices
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.title('Random Forest Confusion Matrix')
sns.heatmap(confusion_matrix(y_test_binary, y_pred_rf_binary), annot=True,
fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.subplot(1, 2, 2)
# Projects 449

