Inference in Probabilistic Graphical Models by Graph Neural Networks
Authors: KiJung Yoon, Renjie Liao, Yuwen Xiong, Lisa Zhang, Ethan Fetaya, Raquel Urtasun, Richard Zemel, Xaq Pitkow
Presenters: Shihao Niu, Zhe Qu, Siqi Liu, Jules Ahmar
TL;DR: Use Graph Neural Networks (GNNs) to learn a message-passing algorithm that solves inference tasks in probabilistic graphical models.
Motivation
● Inference is difficult for probabilistic graphical models.
● Message-passing algorithms, such as belief propagation, struggle when the graph contains loops.
○ Loopy belief propagation: convergence is not guaranteed.
Why GNNs
● Essentially an extension of recurrent neural networks (RNNs) to graph inputs.
● The central idea is to update hidden states at each node iteratively by aggregating incoming messages.
● They have a structure similar to that of a message-passing algorithm.
Factor graph and belief propagation
● Recall that the distribution defined by a factor graph is
○ $p(x) = \frac{1}{Z} \prod_{\alpha} f_\alpha(x_\alpha)$
● Recall the message updates of the belief propagation algorithm
○ Variable to factor: $\mu_{i \to \alpha}(x_i) = \prod_{\beta \in N(i) \setminus \alpha} \mu_{\beta \to i}(x_i)$
○ Factor to variable: $\mu_{\alpha \to i}(x_i) = \sum_{x_\alpha \setminus x_i} f_\alpha(x_\alpha) \prod_{j \in N(\alpha) \setminus i} \mu_{j \to \alpha}(x_j)$
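To make the loop problem concrete, here is a minimal sketch of sum-product loopy belief propagation on the pairwise binary MRFs used later in this deck ($x_i \in \{-1,+1\}$, symmetric couplings $J$, biases $b$). The function name, flood schedule, and fixed iteration count are illustrative choices, not taken from the paper.

```python
import numpy as np

def loopy_bp(J, b, n_iters=50):
    """Sum-product loopy BP for a pairwise binary MRF
    p(x) ∝ exp(Σ_i b_i x_i + Σ_(i,j) J_ij x_i x_j), with x_i ∈ {-1, +1}.
    Returns estimated marginals p(x_i = +1). Illustrative sketch only."""
    n = len(b)
    states = np.array([-1.0, +1.0])
    # msgs[i, j, :] is the message from variable i to variable j, over the states of x_j
    msgs = np.ones((n, n, 2))
    edges = [(i, j) for i in range(n) for j in range(n) if i != j and J[i, j] != 0]
    for _ in range(n_iters):
        new = np.ones((n, n, 2))
        for i, j in edges:
            # product of incoming messages to i from all neighbors except j (over states of x_i)
            incoming = np.ones(2)
            for k in range(n):
                if k != i and k != j and J[k, i] != 0:
                    incoming *= msgs[k, i]
            # sum over x_i of the unary/pairwise potentials times the incoming product
            for sj, xj in enumerate(states):
                new[i, j, sj] = sum(
                    np.exp(b[i] * xi + J[i, j] * xi * xj) * incoming[si]
                    for si, xi in enumerate(states)
                )
            new[i, j] /= new[i, j].sum()  # normalize for numerical stability
        msgs = new
    # beliefs: unary potential times all incoming messages
    marginals = np.zeros(n)
    for i in range(n):
        belief = np.exp(b[i] * states)
        for k in range(n):
            if k != i and J[k, i] != 0:
                belief *= msgs[k, i]
        marginals[i] = belief[1] / belief.sum()
    return marginals
```

On a tree this converges to the exact marginals; on loopy graphs it may oscillate or converge to biased estimates, which is exactly the failure mode motivating the learned alternative.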
BP to GNNs: mapping the messages
● BP is recursive and graph-based. Naturally, we could map the BP messages to GNN nodes and use neural networks to describe the nonlinear updates.
BP to GNNs: mapping the variable nodes
BP to GNNs: mapping the variable nodes
Marginal probability of $x_i$ in the MRF: $p(x_i) \propto \prod_{\alpha \in N(i)} \mu_{\alpha \to i}(x_i)$
Marginal joint probability of $x_\alpha$ in the factor graph: $p(x_\alpha) \propto f_\alpha(x_\alpha) \prod_{i \in N(\alpha)} \mu_{i \to \alpha}(x_i)$
● All of these messages depend on only one variable node at a time.
● The nonlinear functions between GNN nodes can account for these dependencies once equilibrium is reached.
Preliminaries for the model
● Binary MRFs, a.k.a. Ising models:
○ $p(x) \propto \exp\!\left(\sum_i b_i x_i + \sum_{(i,j)} J_{ij} x_i x_j\right)$, with $x_i \in \{-1, +1\}$
● The couplings $J_{ij}$ and biases $b_i$ are specified randomly, and are provided as input for GNN inference.
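As a concrete instance of this setup, the sketch below samples random couplings and biases and computes exact marginals by brute-force enumeration for small graphs (the kind of ground truth the training objective later requires). The Gaussian sampling choice and the function names are assumptions for illustration, not the paper's exact protocol.

```python
import numpy as np
from itertools import product

def random_ising(n, rng=None):
    """Sample a random binary MRF (Ising model): symmetric couplings J and biases b.
    The Gaussian distributions here are placeholder choices."""
    rng = np.random.default_rng(0) if rng is None else rng
    J = rng.normal(size=(n, n))
    J = (J + J.T) / 2.0
    np.fill_diagonal(J, 0.0)
    b = rng.normal(size=n)
    return J, b

def exact_marginals(J, b):
    """Brute-force ground-truth marginals p(x_i = +1), feasible only for small n."""
    n = len(b)
    configs = np.array(list(product([-1.0, 1.0], repeat=n)))                   # all 2^n states
    log_w = configs @ b + 0.5 * np.einsum('ci,ij,cj->c', configs, J, configs)  # unnormalized log-probs
    probs = np.exp(log_w - log_w.max())
    probs /= probs.sum()
    return np.array([probs[configs[:, i] == 1.0].sum() for i in range(n)])
```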
GNN Recap
Update the state embedding $h_v$ of node $v$ based on:
● the feature $x_v$ of node $v$
● the features $x_{co[v]}$ of the edges incident to $v$
● the state embeddings $h_{ne[v]}$ of the neighbors of $v$
● the features $x_{ne[v]}$ of the neighbors of $v$
$h_v = f(x_v, x_{co[v]}, h_{ne[v]}, x_{ne[v]})$
Local output function: $o_v = g(h_v, x_v)$
GNN Recap (Cont.)
Decompose the state update function into a sum of per-edge terms:
$h_v = \sum_{u \in ne[v]} f(x_v, x_{(v,u)}, h_u, x_u)$
Source: Scarselli, Franco, et al. "The graph neural network model."
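A minimal PyTorch sketch of this per-edge decomposition follows; the MLP used for $f$, the dimensions, and the class name are illustrative assumptions rather than the paper's exact architecture.

```python
import torch
import torch.nn as nn

class PerEdgeSumGNN(nn.Module):
    """One Scarselli-style update: h_v = Σ_{u ∈ ne[v]} f(x_v, e_vu, h_u, x_u)."""
    def __init__(self, node_dim, edge_dim, hidden_dim):
        super().__init__()
        self.f = nn.Sequential(
            nn.Linear(2 * node_dim + edge_dim + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, x, h, edge_index, edge_attr):
        # edge_index: (2, E) with rows (source u, destination v)
        src, dst = edge_index
        per_edge = self.f(torch.cat([x[dst], edge_attr, h[src], x[src]], dim=-1))
        h_new = torch.zeros_like(h)
        h_new.index_add_(0, dst, per_edge)  # sum the per-edge terms at each destination node
        return h_new
```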
Message Passing Neural Networks
An abstraction of several GNN variants.
Phase 1: Message Passing
Define the message from $i$ to $j$ at time $t+1$ as $M_t(h_j^t, h_i^t, e_{ij})$.
Step 1: Aggregate all incoming messages into a single message at the destination node: $m_j^{t+1} = \sum_{i \in N(j)} M_t(h_j^t, h_i^t, e_{ij})$
Step 2: Update the hidden state based on the current hidden state and the aggregated message: $h_j^{t+1} = U_t(h_j^t, m_j^{t+1})$
Message Passing Neural Networks (Cont.)
Phase 2: Readout Phase
$\hat{y} = R(\{h_v^T \mid v \in G\})$
The message function, node update function, and readout function can each be chosen differently, which is how MPNN generalizes several different models.
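The two phases can be sketched in a few lines of PyTorch; the concrete choices of $M$ (a linear layer plus ReLU), $U$ (a GRU cell), $R$ (a linear layer), and the number of rounds $T$ are placeholder assumptions to illustrate the abstraction, not any particular published model.

```python
import torch
import torch.nn as nn

class MPNN(nn.Module):
    """Two-phase MPNN sketch: T rounds of message passing, then a readout."""
    def __init__(self, hidden_dim, edge_dim, out_dim, T=5):
        super().__init__()
        self.T = T
        self.M = nn.Linear(2 * hidden_dim + edge_dim, hidden_dim)  # message function M_t
        self.U = nn.GRUCell(hidden_dim, hidden_dim)                # node update function U_t
        self.R = nn.Linear(hidden_dim, out_dim)                    # readout function R

    def forward(self, h, edge_index, edge_attr):
        src, dst = edge_index  # directed edges i -> j as (source, destination) index tensors
        for _ in range(self.T):
            # Phase 1, step 1: compute messages and aggregate them at each destination node
            msg = torch.relu(self.M(torch.cat([h[dst], h[src], edge_attr], dim=-1)))
            m = torch.zeros_like(h).index_add_(0, dst, msg)
            # Phase 1, step 2: update hidden states from the aggregated message
            h = self.U(m, h)
        # Phase 2: readout (here per node; a graph-level readout would pool over nodes)
        return self.R(h)
```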
GG-NN (Gated Graph Neural Network)
Gated Recurrent Units (GRU)
Source: Zhou, Jie, et al. "Graph neural networks: A review of methods and applications."
GG-NN (Cont.)
Readout Phase: a graph-level output is computed from the final node hidden states (and node features) via a gated, attention-like aggregation over nodes.
GG-NN (Cont.)
Gated Recurrent Units (GRU) drive the node update (Li et al., 2015):
$z_v^t = \sigma(W^z a_v^t + U^z h_v^{t-1})$
$r_v^t = \sigma(W^r a_v^t + U^r h_v^{t-1})$
$\tilde{h}_v^t = \tanh\!\left(W a_v^t + U (r_v^t \odot h_v^{t-1})\right)$
$h_v^t = (1 - z_v^t) \odot h_v^{t-1} + z_v^t \odot \tilde{h}_v^t$
where $a_v^t$ aggregates the incoming messages to node $v$.
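For reference, the gated update above can be written out explicitly; the sketch below implements the standard GRU gates as a node-update module, with dimensions and the class name chosen for illustration.

```python
import torch
import torch.nn as nn

class GRUUpdate(nn.Module):
    """GG-NN style node update written out as explicit GRU gates (z: update, r: reset).
    a is the aggregated incoming message a_v^t; h_prev is h_v^{t-1}."""
    def __init__(self, dim):
        super().__init__()
        self.Wz, self.Uz = nn.Linear(dim, dim), nn.Linear(dim, dim, bias=False)
        self.Wr, self.Ur = nn.Linear(dim, dim), nn.Linear(dim, dim, bias=False)
        self.Wh, self.Uh = nn.Linear(dim, dim), nn.Linear(dim, dim, bias=False)

    def forward(self, a, h_prev):
        z = torch.sigmoid(self.Wz(a) + self.Uz(h_prev))          # update gate
        r = torch.sigmoid(self.Wr(a) + self.Ur(h_prev))          # reset gate
        h_tilde = torch.tanh(self.Wh(a) + self.Uh(r * h_prev))   # candidate state
        return (1 - z) * h_prev + z * h_tilde                    # new hidden state
```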
Two mappings between factor graph and GNN
● message-GNN
● node-GNN
Message-GNN and node-GNN perform similarly, and much better than belief propagation.
Mapping I: Message-GNN
Motivation: conforms closely to the structure of conventional belief propagation, and reflects how messages depend on each other.
Mapping: (graphical model) the message $\mu_{ij}$ between nodes $i$ and $j$ ↔ (GNN) a node. Message nodes such as $ij$ and $jk$ are connected in the GNN whenever the corresponding messages depend on each other (i.e., they share the variable node $j$).
Mapping I: Message-GNN (updates and readout)
1. If connected, the message from one GNN node to another is computed by a multi-layer perceptron (MLP) with ReLU activations.
2. Each GNN node then updates its hidden state with a recurrent unit (GRU) applied to the aggregated incoming messages.
3. A readout function extracts the marginals or MAP state:
a. First aggregate all GNN nodes with the same target variable by summation
b. Then apply a shared readout function, another MLP with a sigmoid activation
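Steps (a) and (b) of the readout can be sketched directly; here `target_idx[k]` is assumed to give the index of the target variable of message node $k$, and the MLP shape in the usage comment is illustrative.

```python
import torch
import torch.nn as nn

def message_gnn_readout(h_msg, target_idx, n_vars, readout_mlp):
    """Readout sketch for the message-GNN mapping: (a) sum the hidden states of all
    GNN message nodes that share the same target variable, then (b) apply a shared
    MLP with a sigmoid output to estimate each variable's marginal."""
    agg = torch.zeros(n_vars, h_msg.shape[1])
    agg.index_add_(0, target_idx, h_msg)                 # (a) sum by target variable
    return torch.sigmoid(readout_mlp(agg)).squeeze(-1)   # (b) shared readout MLP

# Hypothetical usage with a two-layer readout MLP over 64-dimensional hidden states:
# readout_mlp = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 1))
# marginals = message_gnn_readout(h_msg, target_idx, n_vars, readout_mlp)
```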
Mapping II: Node-GNN
● Mapping: (graphical model) variable node $i$ ↔ (GNN) node $v_i$
1. Message function: $m_{i \to j}^{t+1} = M(h_i^t, h_j^t, e_{ij})$
2. Aggregate messages: $m_j^{t+1} = \sum_{i \in N(j)} m_{i \to j}^{t+1}$
3. Node update function: $h_j^{t+1} = U(h_j^t, m_j^{t+1})$
4. Readout is generated directly from the hidden states: $\hat{p}_j = r(h_j^T)$
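Instantiated on the binary MRFs from the preliminaries, the node-GNN uses one GNN node per variable and one directed GNN edge per nonzero coupling, with the couplings and biases supplied as edge features. The sketch below shows this input construction; the feature layout and hidden dimension are assumptions, not the paper's exact encoding.

```python
import torch

def node_gnn_inputs(J, b, hidden_dim=64):
    """Build node-GNN inputs for an Ising MRF: initial hidden states, directed edges
    for each nonzero coupling, and edge features carrying J_ij and the biases."""
    J = torch.as_tensor(J, dtype=torch.float32)
    b = torch.as_tensor(b, dtype=torch.float32)
    src, dst = torch.nonzero(J, as_tuple=True)                   # directed edges i -> j
    edge_attr = torch.stack([J[src, dst], b[src], b[dst]], -1)   # couplings and biases per edge
    h0 = torch.zeros(len(b), hidden_dim)                         # initial hidden states
    return h0, torch.stack([src, dst]), edge_attr

# After T message-passing rounds (e.g. with the MPNN sketch above), estimated marginals
# come directly from node hidden states, e.g. p_hat = torch.sigmoid(readout_mlp(h)).squeeze(-1)
```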
Message-GNN and Node-GNN (training)
● Objective: backpropagation to minimize the total cross-entropy loss
○ $L = -\sum_i \left[ p_i \log \hat{p}_i + (1 - p_i) \log(1 - \hat{p}_i) \right]$, where $p_i$ is the ground truth and $\hat{p}_i$ is the estimated result
Message passing function (general):
● Receives external inputs about the couplings on the edges
● Depends on the hidden states of the source and destination nodes at the previous time step
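A minimal sketch of this objective, assuming the estimated marginals and brute-force ground-truth marginals are available as tensors (names are illustrative):

```python
import torch
import torch.nn.functional as F

def marginal_loss(p_hat, p_true):
    """Total cross-entropy between ground-truth marginals p_true and estimates p_hat,
    minimized by backpropagation through the GNN."""
    return F.binary_cross_entropy(p_hat, p_true, reduction='sum')

# Hypothetical training step:
# loss = marginal_loss(p_hat, p_true)
# loss.backward(); optimizer.step(); optimizer.zero_grad()
```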
Experiments
● In each experiment, two types of GNNs are tested:
○ Variable nodes (node-GNN)
○ Message nodes (msg-GNN)
● Examine generalization of the model when...
○ Testing on unseen graphs of the same structure
○ Testing on completely random graphs
○ Testing on graphs of the same size
○ Testing on graphs of larger size
● Analyze performance in estimating both marginal probabilities and the MAP state
Training Graphs
Larger, Novel Test Graphs
Marginal Inference Accuracy
Random Graphs
Generalization Performance on Random Graphs
Convergence of Inference Dynamics
MAP Estimation
Conclusion
● Experiments showed that GNNs provide a flexible learning method for inference in probabilistic graphical models
● Showed empirically that the learned representations and nonlinear transformations on edges generalize to larger graphs with different structures
● Examined two possible representations of graphical models within GNNs: variable nodes and message nodes
● Experimental results support GNNs as a promising framework for solving hard inference problems
● Future work: train and test on larger and more diverse graphs, as well as broader classes of graphical models
References
1. Zhou, Jie, et al. "Graph neural networks: A review of methods and applications." arXiv preprint arXiv:1812.08434 (2018).
2. Gilmer, Justin, et al. "Neural message passing for quantum chemistry." Proceedings of the 34th International Conference on Machine Learning, Volume 70. JMLR.org, 2017.
3. Scarselli, Franco, et al. "The graph neural network model." IEEE Transactions on Neural Networks 20.1 (2008): 61-80.
4. Li, Yujia, et al. "Gated graph sequence neural networks." arXiv preprint arXiv:1511.05493 (2015).
5. Wu, Zonghan, et al. "A comprehensive survey on graph neural networks." arXiv preprint arXiv:1901.00596 (2019).
Homework 1. Where do GNNs outperform belief propagation? Where does belief propagation outperform GNNs? 2. Given the following factor graph, draw the GNN using Message-GNN mapping: