Explaining Deep Reinforcement Learning models with Linear Model U-Trees

Vishnu Sharma
Apr 4, 2021

The highly non-linear structure of deep learning (DL) models makes them effective at learning difficult tasks like scene understanding, text translation, and novel data synthesis. It also makes them very successful at reinforcement learning (RL) tasks like playing video games, which require learning by interacting with an environment. While many approaches have been proposed for explaining (or interpreting) DL models, not all of them can be used directly for RL. RL models often predict value functions or policy parameters, which are not as intuitive as labels in object detection. In this blog, I’ll give an overview of a Linear Model U-Tree based approach for explaining the decisions of a Deep Q-Network (DQN), proposed by Guiliang Liu, Oliver Schulte, Wang Zhu and Qingcan Li in their paper ‘Toward Interpretable Deep Reinforcement Learning with Linear Model U-Trees’ [1]. It is one of several decision-tree-based approaches to RL explainability.

Linear Model U-Trees

A linear model U-Tree (LMUT) is an extension of Continuous U-Trees (CUTs) in which linear models are learned at the leaves (Fig. 1). CUTs are regression trees that discretize the continuous state space using decision splits. Chernova and Veloso [2] proposed using CUTs for RL, regressing over the value function. LMUTs can therefore be considered a marriage of tree-based knowledge distillation and tree-based RL techniques: they distill DRL models into a structure suited to RL. Effectively, an LMUT represents a piecewise linear function over the feature space, and the weights of the linear models at the leaves help in estimating the importance of features.
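To make the “piecewise linear function” idea concrete, here is a minimal sketch of how a trained LMUT could map a state to Q-values. The `Node`/`Leaf` classes, the single-threshold binary splits and the one-linear-model-per-action leaves are my own simplifying assumptions, not the authors’ implementation: each internal node routes the state left or right on one feature, and the leaf it lands in evaluates its linear models.

```python
from dataclasses import dataclass
from typing import List, Union


@dataclass
class Leaf:
    # One linear model per action: Q(s, a) = weights[a] . s + bias[a]
    weights: List[List[float]]
    bias: List[float]


@dataclass
class Node:
    feature: int      # index of the splitting feature
    threshold: float  # go left if s[feature] < threshold, otherwise right
    left: "Tree"
    right: "Tree"


Tree = Union[Leaf, Node]


def find_leaf(tree: Tree, state: List[float]) -> Leaf:
    """Follow the decision splits down to the leaf whose region contains `state`."""
    while isinstance(tree, Node):
        tree = tree.left if state[tree.feature] < tree.threshold else tree.right
    return tree


def predict_q(tree: Tree, state: List[float]) -> List[float]:
    """Evaluate the leaf's linear models, giving one Q-value per action."""
    leaf = find_leaf(tree, state)
    return [sum(w * x for w, x in zip(ws, state)) + b
            for ws, b in zip(leaf.weights, leaf.bias)]


# Toy tree with 2 features and 2 actions (all numbers are made up).
toy = Node(feature=0, threshold=0.0,
           left=Leaf(weights=[[1.0, 0.0], [0.0, 1.0]], bias=[0.0, 0.0]),
           right=Leaf(weights=[[2.0, 0.0], [0.0, -1.0]], bias=[1.0, 0.5]))

print(predict_q(toy, [-1.0, 2.0]))  # [-1.0, 2.0]  (left leaf)
print(predict_q(toy, [3.0, 2.0]))   # [7.0, -1.5]  (right leaf)
```

Within one leaf the prediction is linear in the state; crossing a threshold switches to a different linear model, which is what makes the whole tree piecewise linear.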


Fig. 1: An example LMUT. Leaves learn linear models for the Q-function using the feature ranges defined by the nodes. Source

Training Process

The authors call the process of learning an LMUT from a DRL network mimic training. They present two training settings for this, based on whether we learn the LMUT while running the DRL model (like RL) or after gathering data from running it (like supervised learning):

  • Experience Training

In this setting, we train the LMUT in a supervised fashion. We first record the input observations (I) and actions (a) while training a DRL network. After the DRL model has been learned, we feed the recorded observations into it to obtain Q-values. The resulting dataset is used for batch training the LMUT.

Fig. 2: Experience training setting. Source
  • Active Play

This setting is similar to learning an RL model. We let the DRL model run in the environment and use its observations and predictions (Q-values) to train the LMUT in parallel, eliminating the need to store the data.

Fig. 3: Active play setting. Source
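The difference between the two settings is mainly where the (observation, Q-values) pairs come from and whether they are stored. The sketch below contrasts them; `toy_dqn`, `toy_env_step` and the `update` callback are hypothetical stand-ins (a real setup would use a trained DQN, an actual environment, and an LMUT update step), and the greedy action choice in active play is my assumption.

```python
import random
from typing import Callable, List, Tuple

Obs = List[float]
QValues = List[float]


def toy_dqn(obs: Obs) -> QValues:
    """Hypothetical stand-in for a trained DQN: one Q-value per action."""
    return [obs[0] + obs[1], obs[0] - obs[1]]


def toy_env_step(obs: Obs, action: int, rng: random.Random) -> Obs:
    """Hypothetical environment: a noisy 2-D walk nudged by the chosen action."""
    delta = 0.1 if action == 0 else -0.1
    return [obs[0] + delta + rng.gauss(0, 0.01), obs[1] + rng.gauss(0, 0.01)]


def experience_training_data(recorded: List[Obs],
                             dqn: Callable[[Obs], QValues]) -> List[Tuple[Obs, QValues]]:
    """Experience training: label observations recorded earlier with the trained
    DQN's Q-values, giving a dataset for batch-training the LMUT."""
    return [(obs, dqn(obs)) for obs in recorded]


def active_play(dqn: Callable[[Obs], QValues],
                update: Callable[[Obs, QValues], None],
                steps: int, seed: int = 0) -> None:
    """Active play: let the DQN act greedily and hand each (observation, Q-values)
    pair straight to the LMUT update, without building a stored dataset."""
    rng = random.Random(seed)
    obs = [0.0, 0.0]
    for _ in range(steps):
        q = dqn(obs)
        update(obs, q)  # e.g. one SGD step on the leaf that contains obs
        action = max(range(len(q)), key=q.__getitem__)
        obs = toy_env_step(obs, action, rng)


dataset = experience_training_data([[0.0, 1.0], [1.0, 0.5]], toy_dqn)
print(dataset)  # [([0.0, 1.0], [1.0, -1.0]), ([1.0, 0.5], [1.5, 0.5])]

calls = []
active_play(toy_dqn, lambda obs, q: calls.append(q), steps=5)
print(len(calls))  # 5
```

The callback in the usage example only collects the Q-values so the number of updates can be checked; in active play proper, each pair would be consumed by the LMUT and then discarded.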

LMUT training steps

To train an LMUT, we go through the following two phases:

  1. Data Gathering Phase: In this phase we collect the observations and Q-values for training the linear models at the leaves. Either experience training or active play can be used for this purpose.
  2. Node Splitting Phase: An LMUT is a decision tree, and we need to learn how to split the features at each node. Before splitting a node, we first update the linear model at the leaf node using Stochastic Gradient Descent (SGD). If the prediction error is above a certain threshold, we use a splitting criterion to split the node. The authors tried three splitting criteria: the working response of SGD, the Kolmogorov-Smirnov (KS) test, and a variance test, and found the variance test to be the most efficient.
Fig. 3: LMUT training algorithm. Source
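The two phases can be sketched for a single leaf as follows. This is a simplified batch version under my own assumptions, not the authors’ algorithm: the leaf’s linear model (predicting a single Q-value, e.g. for one action) is fitted with plain SGD, and if its mean squared error exceeds a hypothetical `error_threshold`, the leaf is split with a variance-reduction criterion that picks the feature and midpoint threshold which most reduce the size-weighted variance of the Q-values in the two children. The paper’s variance test may differ in its details, such as how candidate splits are generated or tested.

```python
import statistics
from typing import List, Optional, Tuple

Sample = Tuple[List[float], float]  # (state features, Q-value target from the DQN)


def predict(weights: List[float], bias: float, x: List[float]) -> float:
    return sum(w * xi for w, xi in zip(weights, x)) + bias


def sgd_fit(batch: List[Sample], lr: float = 0.1, epochs: int = 200):
    """Fit a leaf's linear model by plain SGD on squared error."""
    weights, bias = [0.0] * len(batch[0][0]), 0.0
    for _ in range(epochs):
        for x, q in batch:
            err = predict(weights, bias, x) - q
            weights = [w - lr * err * xi for w, xi in zip(weights, x)]
            bias -= lr * err
    return weights, bias


def mse(weights: List[float], bias: float, batch: List[Sample]) -> float:
    return sum((predict(weights, bias, x) - q) ** 2 for x, q in batch) / len(batch)


def best_variance_split(batch: List[Sample]) -> Optional[Tuple[int, float, float]]:
    """Return (feature, threshold, variance reduction) for the split that most
    reduces the size-weighted Q-value variance, or None if no split exists."""
    parent_var = statistics.pvariance([q for _, q in batch])
    best = None
    for f in range(len(batch[0][0])):
        values = sorted({x[f] for x, _ in batch})
        for lo, hi in zip(values, values[1:]):
            t = (lo + hi) / 2
            left = [q for x, q in batch if x[f] < t]
            right = [q for x, q in batch if x[f] >= t]
            child_var = (len(left) * statistics.pvariance(left)
                         + len(right) * statistics.pvariance(right)) / len(batch)
            gain = parent_var - child_var
            if best is None or gain > best[2]:
                best = (f, t, gain)
    return best


def process_leaf(batch: List[Sample], error_threshold: float = 0.05):
    """Update the leaf model first; split only if it still fits the Q-values poorly."""
    weights, bias = sgd_fit(batch)
    if mse(weights, bias, batch) > error_threshold:
        return "split", best_variance_split(batch)
    return "leaf", (weights, bias)


xs0 = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
xs1 = [0.3, 0.9, 0.1, 0.7, 0.5, 0.2]
step = [([a, b], 1.0 if a < 0.5 else 5.0) for a, b in zip(xs0, xs1)]
linear = [([a, b], 2.0 * a + 1.0) for a, b in zip(xs0, xs1)]

print(process_leaf(step))        # ('split', (0, 0.5, 4.0))
print(process_leaf(linear)[0])   # leaf
```

The step-shaped Q-values cannot be fitted by one linear model, so the leaf is split on feature 0 at 0.5; Q-values that are already linear in the features are fitted well and the leaf is kept.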

Interpretability

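To quantify the influence of features, the authors measure the influence of a feature as the total reduction in Q-value variance achieved by its splits:

$$\mathrm{Inf}_f^N = \left(1 + \frac{|w_{Nf}|}{\sum_{j=1}^{J} |w_{Nj}|}\right)\left(\mathrm{Var}_N - \sum_{c=1}^{C} \frac{\mathrm{Num}_c}{\sum_{i=1}^{C} \mathrm{Num}_i}\,\mathrm{Var}_c\right)$$

Here Inf_f^N is the influence of splitting feature f on node N, w_{Nj} is the weight of feature j in the linear model on node N (out of J features), Num_c is the number of Q-values on child node c (out of C children), and Var_N and Var_c are the variances of the Q-values on node N and on child c. The overall influence of feature f is obtained by summing its influence over all nodes.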
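The per-node influence and its sum over nodes can be computed as in the sketch below. It assumes the form I read from [1]: at a node N split on feature f, the influence is the drop in Q-value variance from N to its children (each child weighted by its share of the Q-values), multiplied by 1 + |w_Nf| / Σ_j |w_Nj|, where w_N are the weights of the linear model held at N. The dictionary-based node format and the numbers in the usage example are hypothetical.

```python
import statistics
from collections import defaultdict
from typing import Dict, List


def node_influence(weights: List[float], feature: int,
                   parent_q: List[float], children_q: List[List[float]]) -> float:
    """Influence of the splitting feature at one node: the size-weighted reduction
    in Q-value variance, scaled by 1 + |w_f| / sum_j |w_j| (assumed weighting)."""
    total_w = sum(abs(w) for w in weights)
    scale = 1.0 + (abs(weights[feature]) / total_w if total_w > 0 else 0.0)
    n = sum(len(c) for c in children_q)
    child_var = sum(len(c) / n * statistics.pvariance(c) for c in children_q)
    return scale * (statistics.pvariance(parent_q) - child_var)


def total_influence(split_nodes: List[dict]) -> Dict[int, float]:
    """Overall influence of each feature: the sum of its per-node influences."""
    totals: Dict[int, float] = defaultdict(float)
    for node in split_nodes:
        totals[node["feature"]] += node_influence(
            node["weights"], node["feature"], node["parent_q"], node["children_q"])
    return dict(totals)


nodes = [
    {"feature": 0, "weights": [3.0, 1.0],
     "parent_q": [1.0, 1.0, 5.0, 5.0], "children_q": [[1.0, 1.0], [5.0, 5.0]]},
    {"feature": 1, "weights": [1.0, 1.0],
     "parent_q": [0.0, 2.0], "children_q": [[0.0], [2.0]]},
]
print(total_influence(nodes))  # {0: 7.0, 1: 1.5}
```

Both toy splits separate the Q-values perfectly, so the difference in influence comes from the larger variance removed at the first node and the larger relative weight of feature 0 there.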

Another way to explain decisions in LMUTs is to use rule extraction from the underlying tree. In the paper, a rule is represented as the feature ranges that contain the current observation, together with the vector of Q-values for that observation.
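Rule extraction can be sketched as a walk from the root to the leaf that contains the observation: every split tightens a lower or upper bound on one feature, and the leaf’s linear models give the Q-values. The dict-based tree format, the feature names and the numbers in the usage example are hypothetical (loosely shaped like Mountain Car, with three actions), not taken from the paper.

```python
import math
from typing import Dict, List, Tuple

# Hypothetical encoding: internal nodes are dicts with "feature", "threshold",
# "left" and "right"; leaves are dicts with per-action "weights" and "bias".


def extract_rule(tree: dict, state: List[float],
                 names: List[str]) -> Tuple[Dict[str, Tuple[float, float]], List[float]]:
    """Return (feature ranges containing `state`, Q-values at its leaf)."""
    lower = [-math.inf] * len(state)
    upper = [math.inf] * len(state)
    node = tree
    while "feature" in node:
        f, t = node["feature"], node["threshold"]
        if state[f] < t:   # left branch: feature f lies below t
            upper[f] = min(upper[f], t)
            node = node["left"]
        else:              # right branch: feature f is at least t
            lower[f] = max(lower[f], t)
            node = node["right"]
    q_values = [sum(w * x for w, x in zip(ws, state)) + b
                for ws, b in zip(node["weights"], node["bias"])]
    # Only report features that some split on the path actually constrained.
    ranges = {names[i]: (lower[i], upper[i]) for i in range(len(state))
              if (lower[i], upper[i]) != (-math.inf, math.inf)}
    return ranges, q_values


tree = {"feature": 0, "threshold": -0.5,
        "left": {"weights": [[0.0, 0.0]] * 3, "bias": [-2.0, -2.0, -2.0]},
        "right": {"feature": 1, "threshold": 0.0,
                  "left": {"weights": [[0.0, 0.0]] * 3, "bias": [-1.5, -1.6, -1.7]},
                  "right": {"weights": [[0.0, 0.0], [0.0, 0.0], [1.0, 0.0]],
                            "bias": [-1.0, -1.0, 0.0]}}}

ranges, q = extract_rule(tree, [-0.3, 0.02], ["position", "velocity"])
print(ranges)  # {'position': (-0.5, inf), 'velocity': (0.0, inf)}
print(q)       # [-1.0, -1.0, -0.3]
```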

Evaluation

The authors learn LMUTs on three environments: Mountain Car, Cart Pole, and Flappy Bird. While the first two environments use 2- and 4-dimensional continuous inputs, Flappy Bird uses images as input. A DQN is trained to play each of these games; for Flappy Bird, the images are rescaled to 80x80 and converted to binary features. For Mountain Car, the velocity and position of the car are used as input, and Q-values are predicted for the move left, no push and move right actions. Cart Pole uses the pole angle, the cart’s velocity, the cart’s position and the velocity of the pole at the tip as input, and the model predicts Q-values for the push left and push right actions. The model for Flappy Bird uses 4 consecutive rescaled images as input and generates Q-values for the fly up and fly down actions.
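For Flappy Bird, the input pipeline described above (rescale to 80x80, binarize, feed 4 consecutive frames) might look like the sketch below. The nearest-neighbour resizing, the 0.5 binarization threshold, and padding the stack with the first frame at the start of an episode are my assumptions, not details from the paper.

```python
from collections import deque
from typing import Deque, List

Frame = List[List[float]]   # grayscale image, values in [0, 1]
Binary = List[List[int]]


def resize_nearest(frame: Frame, size: int = 80) -> Frame:
    """Rescale to size x size by nearest-neighbour sampling (assumed method)."""
    h, w = len(frame), len(frame[0])
    return [[frame[r * h // size][c * w // size] for c in range(size)]
            for r in range(size)]


def binarize(frame: Frame, threshold: float = 0.5) -> Binary:
    """Turn pixel intensities into 0/1 features (threshold is an assumption)."""
    return [[1 if v > threshold else 0 for v in row] for row in frame]


class FrameStack:
    """Keeps the k most recent preprocessed frames; the last one is the newest."""

    def __init__(self, k: int = 4, size: int = 80):
        self.size = size
        self.frames: Deque[Binary] = deque(maxlen=k)

    def push(self, raw: Frame) -> List[Binary]:
        frame = binarize(resize_nearest(raw, self.size))
        if not self.frames:
            # At the start of an episode, repeat the first frame to fill the stack.
            self.frames.extend([frame] * self.frames.maxlen)
        else:
            self.frames.append(frame)
        return list(self.frames)


# A 160x120 toy frame whose right half is bright.
raw = [[1.0 if c >= 60 else 0.0 for c in range(120)] for _ in range(160)]
stack = FrameStack()
state = stack.push(raw)
print(len(state), len(state[0]), len(state[0][0]))  # 4 80 80
```

Each pushed frame yields a 4x80x80 binary input, i.e. 25,600 features, which is why the feature influence for this game is easier to read as an image than as a bar chart.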

After distilling the DQN into an LMUT, the authors compare the performance of this tree against other decision tree algorithms such as CART, M5 regression trees, and Fast Incremental Model Trees (FIMT), and find that LMUT compares favorably with them.

Fig. 4: Feature influence in Mountain Car and Cart Pole. Source
Fig. 5: Feature influence in Flappy Bird. Source

Fig. 4 shows the influence of features in Mountain Car and Cart Pole. Because Flappy Bird has a high-dimensional input (80x80 images), its feature influence is instead visualized as a heatmap, as shown in Fig. 5.

Fig. 6: Rule extraction for Mountain Car and Cart Pole. Source

Fig. 6 shows how rule extraction is used for Mountain Car and Cart Pole. What we effectively see here are the feature ranges that contain the current state of the environment, and the corresponding Q-values for all the actions.

Fig. 7: Rule extraction for Flappy Bird. Source

Fig. 7 shows the rules for two examples. The DQN model uses 4 consecutive images, shown in each row. The red stars mark the features with influence values greater than 0.008. These rules show that the most recent image is the most influential in the decision. They also suggest that the most recent image is used to locate the pipes, while the other images are used to find the bird’s position and velocity.
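Marking pixels above the 0.008 threshold and comparing frames can be done directly from per-pixel influence values. In this hypothetical sketch, `influence` maps a flat feature index to its total influence, the index is assumed to run frame, then row, then column over the 4x80x80 input, and frame 3 is assumed to be the most recent; only the 0.008 threshold comes from the figure, and the influence values are made up.

```python
from typing import Dict, List, Tuple

Maps = List[List[List[float]]]


def influence_maps(influence: Dict[int, float], frames: int = 4, size: int = 80) -> Maps:
    """Place each pixel feature's influence into a per-frame size x size grid.
    Assumes a flat feature index ordered frame-major, then row, then column."""
    maps = [[[0.0] * size for _ in range(size)] for _ in range(frames)]
    for index, value in influence.items():
        frame, pixel = divmod(index, size * size)
        row, col = divmod(pixel, size)
        maps[frame][row][col] += value
    return maps


def starred_pixels(maps: Maps, threshold: float = 0.008) -> List[Tuple[int, int, int]]:
    """(frame, row, col) of every pixel whose influence exceeds the threshold."""
    return [(k, r, c) for k, grid in enumerate(maps)
            for r, row in enumerate(grid)
            for c, v in enumerate(row) if v > threshold]


def frame_totals(maps: Maps) -> List[float]:
    """Total influence per frame, e.g. to compare the newest frame with older ones."""
    return [sum(sum(row) for row in grid) for grid in maps]


influence = {3 * 6400 + 5 * 80 + 7: 0.02, 1: 0.001, 3 * 6400 + 10: 0.009}
maps = influence_maps(influence)
print(starred_pixels(maps))  # [(3, 0, 10), (3, 5, 7)]
totals = frame_totals(maps)
print(max(range(4), key=totals.__getitem__))  # 3
```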

My Takeaways

In this post, I have called this approach an explainability technique rather than an interpretability technique (as the authors do), because here the LMUTs are used to extract knowledge from a DQN rather than being trained to perform the task directly. While the LMUTs themselves are interpretable, they only approximate the DQN, so what this paper interprets is the LMUT rather than the DQN itself.

This method is one of many tree-based methods used for DL transparency. These methods have the advantage of rule-based extraction, which is easy to understand and can be applied to different types of input. Given these benefits, it wouldn’t be surprising to see variations of decision trees (even Random Forests) being used to explain deep learning models’ decision-making processes.

I specifically liked two things about this paper: (1) it builds on an existing RL technique (CUTs) for model distillation and shows its efficacy against other tree-based models, making a case for using the approximate model not only for explainability but also as a faster and lighter alternative to the original model in deployment; and (2) it learns a global explainable model, which is very useful when comparing decisions for two observations, a comparison that requires normalization in local explanation approaches.

The implementations for the three games presented in this paper are available in the lead author’s (Guiliang Liu) GitHub repositories.

References

  1. Liu, Guiliang, et al. “Toward interpretable deep reinforcement learning with linear model u-trees.” Joint European Conference on Machine Learning and Knowledge Discovery in Databases. Springer, Cham, 2018.
  2. Chernova, Sonia, and Manuela Veloso. “Tree-based policy learning in continuous domains through teaching by demonstration.” Proceedings of Workshop on Modeling Others from Observations (MOO 2006). 2006.
