# Source code for gossipcat.lab.Explain

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
author:     Ewen Wang
email:      wolfgangwong2012@gmail.com
license:    Apache License 2.0
"""

import numpy as np
import matplotlib.pyplot as plt

class Explain(object):
    """Explain tree-based models with dtreeviz and SHAP."""

    def __init__(self, model, X, y, target, features, regression=False):
        """Store a fitted model together with the data used to explain it.

        Args:
            model: A tree-based model, like XGBoost.
            X (np.ndarray): X values.
            y (np.ndarray): y values.
            target (str): The target name.
            features (list): The list of feature names.
            regression (bool): Whether the model is a regression model,
                default False.
        """
        super().__init__()
        self.model = model
        self.X = X
        self.y = y
        self.target = target
        self.features = features
        self.regression = regression

    def tree(self, tree_index=0, class_names=None, show_node_labels=True,
             title="Decision Tree", orientation='TD', scale=1.5):
        """Plot one tree of the model with dtreeviz.

        Args:
            tree_index (int): The tree index of the model, default 0.
            class_names (list): [For classifiers] A dictionary or list of
                strings mapping class value to class name. Ignored (forced
                to None) when ``self.regression`` is True; when None for a
                classifier, ``np.unique(self.y)`` is used.
            show_node_labels (bool): Add "Node id" to top of each node in
                graph for educational purposes.
            title (str): The plot title.
            orientation (str): Is the tree top down, "TD", or left to
                right, "LR"?
            scale (float): Scale the width, height of the overall SVG
                preserving aspect ratio, default 1.5.

        Returns:
            viz (dtreeviz): A dtreeviz instance, also stored on ``self.viz``.

        Raises:
            ImportError: If ``dtreeviz.trees.dtreeviz`` cannot be imported.
        """
        try:
            from dtreeviz.trees import dtreeviz
        except ImportError as err:
            # Chain the original error so the precise cause (missing package
            # vs. incompatible dtreeviz version) stays visible.
            raise ImportError(
                "Explain.tree() requires the 'dtreeviz' package providing "
                "dtreeviz.trees.dtreeviz; install it with `pip install dtreeviz`."
            ) from err

        # Class names only apply to classifiers.
        if self.regression:
            class_names = None
        elif class_names is None:
            class_names = np.unique(self.y).tolist()

        self.viz = dtreeviz(tree_model=self.model,
                            x_data=self.X,
                            y_data=self.y,
                            target_name=self.target,
                            feature_names=self.features,
                            class_names=class_names,
                            tree_index=tree_index,
                            show_node_labels=show_node_labels,
                            orientation=orientation,
                            title=title,
                            scale=scale)
        return self.viz

    def feature_importance(self, max_display=20):
        """Plot feature importance and SHAP value impact with SHAP.

        Computes SHAP values with ``shap.TreeExplainer`` and shows two
        figures: a bar summary plot titled 'Feature Importance' and a
        beeswarm plot titled 'Feature Importance and Impact'.

        Args:
            max_display (int): Maximum number of features shown in each
                plot, default 20.

        Side effects:
            Sets ``self.explainer`` and ``self.shap_values``.

        Returns:
            None

        Raises:
            ImportError: If the ``shap`` package is not installed.
        """
        try:
            import shap
        except ImportError as err:
            raise ImportError(
                "Explain.feature_importance() requires the 'shap' package; "
                "install it with `pip install shap`."
            ) from err

        self.explainer = shap.TreeExplainer(self.model)
        self.shap_values = self.explainer(self.X)

        # Pass max_display to both plots so they show the same feature count;
        # previously the bar plot ignored it.
        shap.summary_plot(shap_values=self.shap_values,
                          features=self.X,
                          feature_names=self.features,
                          plot_type='bar',
                          max_display=max_display,
                          show=False)
        plt.title('Feature Importance')
        plt.show()

        shap.plots.beeswarm(shap_values=self.shap_values,
                            max_display=max_display,
                            show=False)
        plt.title('Feature Importance and Impact')
        plt.show()
        return None