Probabilistic Surrogate Networks for Simulators with Unbounded Randomness

Andreas Munk, Berend Zwartsenberg, Adam Ścibior, Atılım Güneş Baydin,
Andrew Stewart, Goran Fernlund,
Anoush Poursartip, Frank Wood

January 13, 2022

Types of Simulators

different_sims.svg

Probabilistic programs and automatic inference - an example le2016inference

gmm3_program.svg

gmm3_with_posterior.png

PyProb - https://github.com/pyprob/pyprob/tree/master/pyprob

gmm3_program_pyprob.svg

Addresses Uniquely Identify Random Variables van2018introduction

example_program.svg

  • Infer the posterior distribution \(p(\vx_{\mr{lat}}|\vx_{\mr{obs}})\) over latent variables \(\vx_{\mr{lat}}\) (specified by sample statements) conditioned on observed variables \(\vx_{\mr{obs}}\) (specified by observe statements)
  • We consider the joint distribution over the space of “traces” \((x_{a_t}, a_t)\) for \(t=1,\dots,T\) \[p(\vx, \va) = \prod_{t=1}^T p(a_t|x_{a_1:a_{t-1}},a_{1:t-1}) p(x_{a_t}|x_{a_1:a_{t-1}},a_{1:t})\]
  • We denote the set of all variables as \(\vx=\vx_{\mr{lat}}\cup\vx_{\mr{obs}}\)
  • Address transitions are deterministic: \(p(a_t|x_{a_1:a_{t-1}},a_{1:t-1})\) is a Kronecker delta function
  • But this notation becomes useful later
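As a concrete illustration, a trace of a small branching program can be recorded as a sequence of (address, value) pairs and its log-joint accumulated factor by factor. The program, addresses, and helper names below are purely illustrative and not tied to any real framework:

```python
import math
import random

# Toy program: z ~ Bernoulli(0.5); if z then x ~ N(0,1) at address 'A',
# else x ~ N(5,1) at address 'B'. Addresses here are hypothetical labels.
def run(rng):
    trace = []
    z = rng.random() < 0.5
    trace.append(('z', z))
    if z:
        trace.append(('A', rng.gauss(0.0, 1.0)))
    else:
        trace.append(('B', rng.gauss(5.0, 1.0)))
    return trace

def log_norm(x, mu):
    return -0.5 * math.log(2 * math.pi) - 0.5 * (x - mu) ** 2

def log_joint(trace):
    # The address factor p(a_t | history) is a Kronecker delta here:
    # the next address is determined by earlier values, so it
    # contributes log 1 = 0 to the log-joint.
    lp = math.log(0.5)                 # p(z)
    addr, x = trace[1]
    lp += log_norm(x, 0.0 if addr == 'A' else 5.0)
    return lp

rng = random.Random(0)
tr = run(rng)
```

Because control flow decides which address comes next, different runs produce traces over different sets of addresses, which is exactly what the \((x_{a_t}, a_t)\) notation captures.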

Sequential Importance Sampling

  • Draw \(K\) samples \(\vx_{\mr{lat}}^k \overset{i.i.d.}{\sim}q(\vx_{\mr{lat}}|\vx_{\mr{obs}})\) (proposal distribution)
  • Compute the weights, \[w^k = \frac{p(\vx_{\mr{lat}}^k,\vx_{\mr{obs}})}{q(\vx_{\mr{lat}}^k|\vx_{\mr{obs}})}\]
  • Approximate the posterior as \[ p(\vx_{\mr{lat}}|\vx_{\mr{obs}})\approx\frac{\sum_{k=1}^Kw^k\delta(\vx_{\mr{lat}}^k-\vx_{\mr{lat}})}{\sum_{k=1}^Kw^k}\]
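The three steps above can be sketched in NumPy for a toy conjugate model (prior \(x_{\mr{lat}}\sim\mathcal{N}(0,1)\), likelihood \(x_{\mr{obs}}|x_{\mr{lat}}\sim\mathcal{N}(x_{\mr{lat}},1)\)); the proposal and all numbers are illustrative:

```python
import numpy as np

# Self-normalized importance sampling sketch.
# Proposal q = N(0, 2^2); true posterior mean is x_obs / 2.
rng = np.random.default_rng(0)
x_obs = 1.0
K = 100_000

def log_norm(x, mu, sigma):
    return -0.5 * np.log(2 * np.pi * sigma**2) - (x - mu) ** 2 / (2 * sigma**2)

x_lat = rng.normal(0.0, 2.0, size=K)           # K i.i.d. draws from q
log_w = (log_norm(x_lat, 0.0, 1.0)             # prior p(x_lat)
         + log_norm(x_obs, x_lat, 1.0)         # likelihood p(x_obs | x_lat)
         - log_norm(x_lat, 0.0, 2.0))          # proposal q(x_lat | x_obs)
w = np.exp(log_w - log_w.max())                # stabilized unnormalized weights
posterior_mean = np.sum(w * x_lat) / np.sum(w)
```

The self-normalization in the last line is the weighted sum of delta functions from the slide, evaluated for the posterior mean.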

Sequential Importance Sampling and Evaluation Based Inference

  • SIS in evaluation based inference engines involves calculating the weights as \[ w^k=\frac{p(\vx,\va)}{q(\vx_{\mr{lat}}^k,\va|\vx_{\mr{obs}})}=\frac{\prod_{t=1}^T p(a_t|x_{\lt a_{t}},a_{\lt t})p(x_{a_t}|x_{\lt a_{t}},a_{\lt t})} {\prod_{x_{a_t}^{\mr{lat}}\in\vx_{\mr{lat}}}q(x^{\mr{lat}}_{a_t}|x^{\mr{lat}}_{\lt a_t},a_{\lt t},\vx_{\mr{obs}})\prod_{t=1}^{T}q(a_{t}|x_{\lt a_{t}},a_{\lt t})}\]
  • Where \(q(a_{t}|x_{\lt a_{t}},a_{\lt t})=p(a_{t}|x_{\lt a_{t}},a_{\lt t})\):
    • Proposing from \(q\) requires executing the program
    • Drawing \(K\) samples requires \(K\) evaluations of the program

Inference Compilation (IC) le2016inference

  • IC is all about constructing \(q(\vx_{\mr{lat}}, \va|\vx_{\mr{obs}})\) \(\rightarrow\) efficient SIS \[q(\vx_{\mr{lat}},\va|\vx_{\mr{obs}};\phi) = \prod_{x_{a_t}^{\mr{lat}}\in\vx_{\mr{lat}}}q(x^{\mr{lat}}_{a_t}|\eta_{a_t}(x^{\mr{lat}}_{\lt a_t},a_{\lt t},\vx_{\mr{obs}};\phi))\prod_{t=1}^{T}q(a_{t}|x_{\lt a_{t}},a_{\lt t})\]
  • \(\eta_{a_{t}}(\cdot)\) is a recurrent neural network
  • We train \(q(\cdot|\vx_{\mr{obs}};\phi)\) to be close to \(p(\cdot|\vx_{\mr{obs}})\) for all \(\vx_{\mr{obs}}\),
    by minimizing the expected KL divergence under the marginal \(p(\vx_{\mr{obs}})\) \[\begin{align}\mL_{\mr{IC}}(\phi)&=\E_{p(\vx_{\mr{obs}})}\br{\mr{KL}\paren{p(\vx_{\mr{lat}},\va|\vx_{\mr{obs}})||q(\vx_{\mr{lat}},\va|\vx_{\mr{obs}};\phi)}}\nonumber\\ &=-\E_{p(\vx,\va)}\br{\log q(\vx_{\mr{lat}},\va|\vx_{\mr{obs}};\phi)} + \mr{const}\nonumber\end{align}\]
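For a toy model with a Gaussian proposal whose mean is linear in \(\vx_{\mr{obs}}\), minimizing the Monte Carlo estimate of \(-\E_{p(\vx,\va)}\br{\log q}\) reduces to least squares on samples from the joint. The model and parameterization below are illustrative stand-ins for the recurrent network \(\eta_{a_t}\):

```python
import numpy as np

# IC objective sketch: minimize -E_{p(x_lat, x_obs)}[log q(x_lat | x_obs; phi)]
# for x_lat ~ N(0,1), x_obs | x_lat ~ N(x_lat, 1), with proposal mean
# a * x_obs + b. For this Gaussian family the minimizer is ordinary least
# squares of x_lat on x_obs (theoretical optimum: a = 1/2, b = 0).
rng = np.random.default_rng(0)
N = 200_000
x_lat = rng.normal(0.0, 1.0, size=N)           # samples from the prior
x_obs = x_lat + rng.normal(0.0, 1.0, size=N)   # forward simulation

X = np.stack([x_obs, np.ones(N)], axis=1)
a, b = np.linalg.lstsq(X, x_lat, rcond=None)[0]
```

The key point carries over to the general case: the loss only needs samples \((\vx,\va)\sim p\), which are obtained by running the program forward, never by evaluating the posterior.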

Probabilistic Surrogate Network (PSN)

  • Evaluation-based inference (IC) is still limited by the execution speed of the program
  • What if we learn a surrogate model \(s\) that replaces the program \(p\) when \(p\) is slow to execute?
  • We structure \(s\) as a distribution over traces factorized exactly like \(p\) \[ s(\vx,\va;\theta)=\prod_{t=1}^{T}s(x_{a_t}|\xi_{a_t}(x_{\lt a_{t}},a_{\leq t};\theta))s(a_t|\zeta_{a_t}(x_{\lt a_{t}},a_{\lt t};\theta))\]
  • \(\xi_{a_t}(\cdot)\) and \(\zeta_{a_{t}}(\cdot)\) are recurrent neural networks
  • Now we also have to model the address transitions \(s(a_t|\zeta_{a_t}(x_{\lt a_{t}},a_{\lt t};\theta))\), for instance with a categorical distribution
  • We match \(p\) and \(s\) by minimizing the KL-divergence \[\begin{align*} \mL(\theta) &= \mr{KL}\paren{p(\vx,\va)||s(\vx,\va;\theta)} \\ &= -\E_{p(\vx,\va)}\br{\log s(\vx,\va;\theta)} + \mr{const} \end{align*}\]
  • Fully compatible with the probabilistic programming framework
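As a minimal sketch of the address-transition part of \(s\), one can estimate a categorical distribution over next addresses from traces of the toy branching program; the empirical table below is an illustrative tabular stand-in for the recurrent network \(\zeta_{a_t}\):

```python
import random
from collections import Counter

# Toy program emitting only its address sequence:
# z ~ Bernoulli(0.5); branch to address 'A' or 'B'; then terminate.
def run(rng):
    z = rng.random() < 0.5
    return ['z', 'A' if z else 'B', 'end']

rng = random.Random(0)
counts = Counter()   # (prev_address, next_address) -> occurrences
totals = Counter()   # prev_address -> occurrences
for _ in range(10_000):
    tr = run(rng)
    for prev, nxt in zip(tr, tr[1:]):
        counts[(prev, nxt)] += 1
        totals[prev] += 1

# Empirical categorical s(a_t | a_{t-1})
s = {k: v / totals[k[0]] for k, v in counts.items()}
```

In the full PSN the transition distribution is conditioned on the whole history through \(\zeta_{a_t}\), not just the previous address; this sketch only shows why a categorical over a discrete address set is a natural choice.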

The Communication Interface

program-to-surrogate.svg

Simulating the Curing of Composite Materials

wing_problem.svg

Faster Simulation

  • ~90 times faster simulation

heatmap_guess.gif

Faster Inference

  • ~15 times faster inference
  • Infer \(\E_{p(\vx_{\mr{lat}}|\vx_{\mr{obs}})}\br{\mu_w}\)
  • \(\mu_w\) is the empirical mean of the temperature across the time window \(w=\br{155,165}\) min at a fixed depth of \(30\) mm

posterior_box_plot.svg

Program Synthesis and Address Transitions

orig_traces_graph.svg

Control Flow Program

toy_program.svg

Address Transitions

address_transition_orig.svg

Figure 1: Address transition graph generated by the original program

address_transition_surr.svg

Figure 2: Address transition graph generated by the PSN

References