diff --git a/.gitignore b/.gitignore index 62e66b3..284d517 100644 --- a/.gitignore +++ b/.gitignore @@ -95,3 +95,42 @@ venv.bak/ *.aux *.out *.pdf + + +# Compiled source # +################### +*.com +*.class +*.dll +*.exe +*.o +*.so + +# Packages # +############ +# it's better to unpack these files and commit the raw source +# git has its own built in compression methods +*.7z +*.dmg +*.gz +*.iso +*.jar +*.rar +*.tar +*.zip + +# Logs and databases # +###################### +*.log +*.sql +*.sqlite + +# OS generated files # +###################### +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db diff --git a/notebooks/biochemical_model.ipynb b/notebooks/biochemical_model.ipynb index f4db6c9..4a5cc6d 100644 --- a/notebooks/biochemical_model.ipynb +++ b/notebooks/biochemical_model.ipynb @@ -32,17 +32,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 79, "metadata": {}, "outputs": [], "source": [ "from math import exp\n", "\n", - "from pyro.distributions import LogNormal\n", + "from pyro.distributions import LogNormal,Normal\n", + "from pyro.primitives import iarange\n", "from pyro import condition, do, infer, sample\n", "from pyro.infer import EmpiricalMarginal\n", "from torch import tensor\n", "\n", + "import sys\n", + "sys.path.append('../')\n", "from causal_demon.inference import infer_dist\n", "from causal_demon.transmitters import cancer_signaling\n", "\n", @@ -65,27 +68,58 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 82, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(0.1904)\n" + ] + } + ], "source": [ "noise_vars = ['N_egf', 'N_igf', 'N_sos', 'N_ras', 'N_pi3k', 'N_akt', 'N_raf', 'N_mek', 'N_erk']\n", - "noise_prior = {N: LogNormal(0, 10) for N in noise_vars}" + "noise_prior = {N: LogNormal(0, 10) for N in noise_vars}\n", + "\n", + "print(noise_prior['N_egf'].sample())" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 83, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "{'akt': tensor(405643.),\n", + " 'egf': tensor(44953.8320),\n", + " 'erk': tensor(3115.5735),\n", + " 'igf': tensor(0.1029),\n", + " 'mek': tensor(162.8303),\n", + " 'pi3k': tensor(119770.7422),\n", + " 'raf': tensor(4.5652),\n", + " 'ras': tensor(176.5171),\n", + " 'sos': tensor(28.5478)}" + ] + }, + "execution_count": 83, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "cancer_signaling(noise_prior)" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 84, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "# Experimental use only, please ignore\n", @@ -127,12 +161,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 213, "metadata": {}, "outputs": [], "source": [ - "evidence = {'egf': tensor(800.), 'igf': tensor(0.), 'erk': tensor(800.)}\n", - "cancer_obs = condition(cancer_signaling, data=evidence)" + "evidence = {'egf': tensor([800.]), 'igf': tensor([0.,0.]), 'erk': tensor([800.])}\n", + "\n", + "cancer_obs = condition(cancer_signaling, data= evidence)\n", + "\n", + "#evidence = [{'egf': tensor([800.]), 'igf': tensor([0.]), 'erk': tensor([800.])},{'egf': tensor([800.]), 'igf': tensor([0.]), 'erk': tensor([800.])},{'egf': tensor([600.]), 'igf': tensor([0.]), 'erk': tensor([600.])},{'egf': tensor([800.]), 'igf': tensor([0.]), 'erk': tensor([800.])},{'egf': tensor([600.]), 'igf': tensor([0.]), 'erk': tensor([600.])}]\n", + "\n", + "#with iarange('data',len(evidence)) as ind:\n", + "# print(ind)\n", + "# cancer_obs = condition(cancer_signaling, data= evidence)" ] }, { @@ -144,7 +185,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 180, "metadata": {}, "outputs": [], "source": [ @@ -160,8 +201,10 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 181, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "noise_marginals = {\n", @@ -179,11 +222,11 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 204, "metadata": {}, "outputs": [], "source": [ - "cancer_do = do(cancer_signaling, data={'igf': tensor(800.)})" + "cancer_do = do(cancer_signaling, data={'igf': tensor([800.,600.])})" ] }, { @@ -195,21 +238,43 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 207, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "cancer_do_dist = infer_dist(cancer_do, noise_marginals)\n", + "print(cancer_do_dist)\n", "erk_cf_marginal = EmpiricalMarginal(cancer_do_dist, sites = 'erk')" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 201, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[tensor([90.9036, 90.9029])]\n" + ] + } + ], "source": [ - "hist(erk_cf_marginal, 'Erk')" + "#hist(erk_cf_marginal, 'Erk')\n", + "\n", + "print(erk_cf_marginal())\n", + "#\n", + "#plt.hist([erk_cf_marginal().item() for _ in range(100)])\n", + "#plt.title(\"Erk\")" ] }, { @@ -226,9 +291,9 @@ ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": "Python 3", "language": "python", - "name": "venv" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -240,7 +305,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.4" + "version": "3.6.3" } }, "nbformat": 4, diff --git a/notebooks/experiments/condition_on_iid_data.ipynb b/notebooks/experiments/condition_on_iid_data.ipynb new file mode 100644 index 0000000..598a8b8 --- /dev/null +++ b/notebooks/experiments/condition_on_iid_data.ipynb @@ -0,0 +1,351 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Experiment on passing i.i.d. data to pyro.condition\n", + "\n", + "\n", + "(Building the same model as pyro documentation: http://pyro.ai/examples/intro_part_ii.html#Conditioning-Models-on-Data)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# import some dependencies\n", + "import numpy as np\n", + "import pandas as pd\n", + "import os\n", + "\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "\n", + "import seaborn as sns\n", + "sns.set(style=\"ticks\")\n", + "\n", + "import torch\n", + "import torch.distributions.constraints as constraints\n", + "\n", + "import pyro\n", + "from pyro.optim import Adam\n", + "from pyro.infer import SVI, Trace_ELBO\n", + "import pyro.infer\n", + "import pyro.optim\n", + "import pyro.distributions as dist\n", + "from pyro.infer.mcmc import MCMC\n", + "from pyro.infer.mcmc.nuts import NUTS\n", + "torch.manual_seed(101);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def scale(guess):\n", + " # The prior over weight encodes our uncertainty about our guess\n", + " # This is one dimension\n", + " weight = pyro.sample(\"weight\", dist.Normal(guess, 1.0))\n", + " # This encodes our belief about the noisiness of the scale:\n", + " # the measurement fluctuates around the true weight\n", + " # data point here is also one dimensional\n", + " return pyro.sample(\"measurement\", dist.Normal(weight, 0.75))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conditioning Models on Data" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(12.5057)\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAY0AAAEXCAYAAABRWhj0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvDW2N/gAAGExJREFUeJzt3XuUXWWZ5/FvJiaQRiDtjaABQSAPchEcJAjIZXUDIwhEmlt3cCHd3KRRYQyg3WIHcNRxbCBoizrRCMq1hRXpELAFIooBsbG5CjwiBgYhTNtLo3Iz1/lj73JOTk5VvUnqnF2p+n7WqpVz3vPuvZ/alTq/evc++91jVq1ahSRJJf5L0wVIkjYchoYkqZihIUkqZmhIkooZGpKkYoaGJKmYoSEAImKbiFgREQ+0fD0YEX/T1u+bEbHLEGzvgYiYOEifOyPimH5emx0Re/Tz2lPrW99oEhF7RsSXm65jMBHx6oi4NSImNF3LaGZoqNXLmbl73xdwGHBxRLwNICKOA36bmY+s74bqbSxZj1UcDIxZ3zoEwM7A5KaLGExmvgBcC3yy6VpGs1c1XYCGr8x8NiKeAKYADwEXAsdGxFjgeWDvzPx5RHwMOCMz3wwQEbcBlwILgcuAXYFxwB3AuZm5PCJWAa8HfgN8DjgS+C1wL7BTZh5YlzEtIs4DtgBuB06letN4I3B1RJyYmfeWfD/1COQa4D3Aa4GZwL7AHsAy4MjMfC4i3gT8E7B1Xfd1mfnpeh1/D7wX2BjYBDgnM+dGxI7A1+r2McBXM/PyiLgAeF1mfrBe/o/PI+JO4NfAjsCXgG8MsL9eqffp4cBmwLnAsXXf54AjMvPFiHhrvY7XAmOBz2fmnIg4EPgU8AtgF2Aj4Ezg58BFwOYR8fXM/OsB9t9Y+vlZ1d/LP2XmDXXfPz4foKZXA18HdgBWAj8BTgf+pFN7Zq4E/hn4bER8LjP/b3+1qnscaahfEbE3sD1wb31IakJmPpKZK4B5wLvrru8GxkfElIjYHNid6g3+UuAnmbkH8HbgdcBH2jZzCtWb9i7A3sB2ba9vWre/FTgU2DczP071RnlCaWC02DgzdwNmAP8buKx+/gxwUt3nm8Ccuu6pwEERcVxEvBk4CDggM98GfJzqDReqN/F59TKHAftHRMnv128yc6fM/AID76+NgMWZuStwOfBV4GxgJ2BzqnB9FXAD8LF6HQcA50TEO+t17AVcnJlvpwq4CzLzGeAfgLsGCozaYD+rNQxS01HApvWods96kbcM0E5mvgL8kGofqwGONNRqQkQ8UD9+FfCfVG/Mz0TEXlR/lfaZC3wgIq4EtqT6C/5gqr+cv5OZSyPicGBqRJzct/4O2zwM+Eb9ZkBEfAX4cMvr19ch9VI96nnDen6PN9b/Pgk8n5kPtjx/TURsQvXG9pqI6DsM8mpg98z854h4P3BCRGwPvLN+Dar98Y2ImEoVmB/OzJURMVg9d7U8Hmx/tdb+cGY+CxARi4DXUI0ItwPmtGx3AlUAPQY8nZl9P99/5/+HZKnBfladDFTTd4BP16OS24BZ9ch1Raf2lnU+CQy6Y9UdhoZavVz/ddfJSqpDC31uo/pr9z3AnfXzM4CXgOvrPmOBYzPzMYD6xHf7ZGfLWf3cxIq215e1PF7F+p/H+EM/6+4ztt7GPpn5EkBEvA54JSL+K3AT1Yjgu8D3qQ4rkZk3R8QOVMH558DMiNinQ83j27b3Qtu2B9pfJbUvaf0ZRsQWVIeS3gm83NJ3XfblQD+r/r7PfmvKzFfq8D0Q+DPg9oj4UH1Iq2N7yzqXrmXtGiIenlKpn1EfIoA/Hib4PtV5gb430L2B/aj+ggT4V+C/R8SYiNgI+Bfgg23rnQ+8LyI2qg9lnMSawdLJcqrj/kMqM38H/Ij6sFD9xr0QmAbsD9yXmZdQfb/vpQ7SiLgGOD4zrwP+FvgdsBXwK2CPeh9sAhwywOZL9teA5VOF2/vqmrYCHqE6pDSQ0n050M/qV8A76u1uB7xtsJoi4gyqcxffzcyPUn3/u/TX3lLHW4DHC+pVFxgaKlJ/Yurl+qRmn7lUhx8WZObLwIPAwr7DF1SHLjYBHqY6kf4w8L/aVn0F1QnV+4G7qf6CfKmgpG8D10fEQG/C62o68M6IeLiu7drMvJrqkzuvi4hHqU7OvkB1GGtTqpPzJ0TEg/Uyc6mC5WqqN9QngFuAewbYbsn+6ldmLqUKt1Mi4iGqMP9EZi4cZNF7gB0jYi5ARNwSEUd26HcF/f+s/gdwSEQ8AnwW+EFBTd+gCt1HI+I+qhP8lw3QTh2me1OdU1MDxjg1ukpFxHTgXZn5t0O4zkOAN2TmVfXzy4BX6r8w13WdT2XmNkNU4qgTEacC/5mZc9vah/xntQ61nQTsnJnn9mqbWp0jDRXLzGuA10bErkO42p8C74/qQsKfUn0M99NDuH6tveXAzR3aG/1Z1SO66cAFvdqm1uRIQ5JUzJGGJKmYoSFJKrbBX6dRf5piT2Axa37GX5LU2ViqC3P/LTP/MFjnPht8aFAFxl2D9pIkdbIf1dQsRUZCaCwGuPrqq5k0aVLTtUjSBuH555/nhBNOgPo9tNRICI0VAJMmTWLy5GE/u7MkDTdrdVjfE+GSpGKGhiSpmKEhSSpmaEiSihkakqRihoYkqZihIUkq1tXrNCJiJnBc/XR+Zp4XEXOorkB8sW6/MDPnRsRBwCVU9w++PjPP72ZtGl2WLlvB+HFjB+84QrYrdUvXQqMOgUOobiC/CvhORBxFNe3H/pm5uKXvBGAOcADwDDA/Ig7NzFu7VZ9Gl/HjxnLEjJt6vt15F0/r+TalburmSGMxMKO+3SMR8Riwdf01OyK2prol5oXAVOCJzFxU970KOBZYLTTq+zVPbNuOl4FLUo90LTQy86d9jyNiB+B44F3AgcDpVPdXvhk4uX7cOv/JYjqHwdnAzO5ULEkaTNfnnoqInYH5wDmZmcBRLa99ATgR+FaHRVd2aJtFdXP7VpNxlltJ6olunwjfF7gRODszr6vvLT0lM2+su4wBlgHPAq1T1G4JPNe+vsxcAixp20Y3SpckddDNE+FbAd8Gjs/MBXXzGGBWRCygOiR1GnAlcG+1SGwPLKK6efycbtUmSVo33RxpnANsDFzSMhr4MvAZYCEwDrgxM68FiIiTqEYlGwO3ADd0sTZJ0jro5onws4Cz+nn58g797wB261Y9kqT15xXhkqRihoYkqZihIUkqZmhIkooZGpKkYoaGJKmYoSFJKmZoSJKKGRqSpGKGhiSpmKEhSSpmaEiSihkakqRihoYkqZihIUkqZmhIkooZGpKkYoaGJKmYoSFJKmZoSJKKGRqSpGKGhtRFS5etGJXb1sj1qqYLkEay8ePGcsSMmxrZ9ryLpzWyXY1sjjQkScUMDUlSMUNDklTM0JAkFTM0JEnFDA1JUjFDQ5JUzNCQJBUzNCRJxbp6RXhEzASOq5/Oz8zzIuIg4BJgAnB9Zp5f990dmA1sDvwA+EBmLu9mfZKktdO1kUYdDocAbwd2B/aIiL8C5gDTgLcCe0bEofUiVwEfyswpwBjg1G7VJklaN90caSwGZmTmUoCIeAyYAjyRmYvqtquAYyPiUWBCZv6oXvYK4ELgS60rjIiJwMS27Uzu2ncgSVpN10IjM3/a9zgidgCOBz5PFSZ9FlO96b+xn/Z2ZwMzh7xYSVKRrp8Ij4idgduAc4AnO3RZSXU4qlN7u1nAtm1f+w1NpZKkwXT7RPi+wI3A2Zl5XUQcAExq6bIl8BzwbD/tq8nMJcCStm0MddmSpH5080T4VsC3gemZeV3dfG/1UmwfEWOB6cCtmfk08EodMgAnArd2qzZJ0rrp5kjjHGBj4JKW0cCXgZOoRh8bA7cAN9SvnQDMjohNgfupzn9IkoaRbp4IPws4q5+Xd+vQ/0FgarfqkSStP68IlyQVMzQkScUMDUlSMUNDklTM0JAkFTM0JEnFDA1JUjFDQ5JUzNCQJBUzNCRJxQwNSVIxQ0OSVMzQkCQVMzQkScUMDUlSMUNDPbV02YqmS5C0Hrp6j3Cp3fhxYzlixk093+68i6f1fJvSSORIQ5JUzNCQJBUzNCRJxQwNSVIxQ0OSVMzQkCQVMzQkScUMDUlSMUNDklTM0JAkFTM0JEnFDA1JUjFDQ5JUzNCQJBXr+tToEbEZcDdweGY+FRFzgP2AF+suF2bm3Ig4CLgEmABcn5nnd7s2SdLa6WpoRMRewGxgSkvznsD+mbm4pd8EYA5wAPAMMD8iDs3MW7tZnyRp7XR7pHEqcCbwTYCI2ATYGpgdEVsDc4ELganAE5m5qO53FXAsYGhI0jDS1dDIzFMAIqKvaQtgAXA68AJwM3By/Xhxy6KLgcnt64uIicDEtuY1+kmSuqOnt3vNzF8AR/U9j4gvACcC3+rQfWWHtrOBmd2pTpI0mJ5+eioido2Io1uaxgDLgGeBSS3tWwLPdVjFLGDbtq/9ulOtJKldT0caVCExKyIWUB2SOg24ErgXiIjYHlgETKc6Mb6azFwCLGltazn0JUnqsp6ONDLzIeAzwELgUeCBzLw2M18BTgJurNsfB27oZW2SpMH1ZKSRmdu0PL4cuLxDnzuA3XpRjyRp3XhFuCSpmKEhSSpmaEiSihkakqRihoYkqZihIUkqNmhoRMQZnR5Lkkaffq/TiIikulJ7n4hYCDxGNWvtl3pUmyRpmBlopLEL8DVgM6pJAh8GdoiIWRFx1ADLSZJGqIFCY9vM/D7wbGYenZk7Ak8B3wP27UVxkqThZaBpRC6LiO2AiRHxUeB+YFVm3gTc1JPqJEnDSr8jjcw8FNgJ+D3wW6r7YGwXEY9ExFd6VJ8kaRgZcMLCzFweEY9n5pcB6lu0Hgfs3YviJEnDy6Cz3Gbmezo8vr1rFUmShi0v7pMkFTM0JEnFDA1JUjFDQ5JUzNCQJBUzNKQRaumyFaNqu+qNQT9yK2nDNH7cWI6Y0fvJG+ZdPK3n21TvONKQJBUzNCRJxQwNSVIxQ0OSVMzQkCQVMzQkScUMDUlSMUNDklTM0JAkFTM0JEnFDA1JUrGuzj0VEZsBdwOHZ+ZTEXEQcAkwAbg+M8+v++0OzAY2B34AfCAzl3ezNknS2uvaSCMi9gJ+CEypn08A5gDTgLcCe0bEoXX3q4APZeYUYAxwarfqkiStu24enjoVOBN4rn4+FXgiMxfVo4irgGMj4s3AhMz8Ud3vCuDYLtYlSVpHXTs8lZmnAEREX9MbgcUtXRYDkwdoX0NETAQmtjV37CtJGnq9vJ/GmA5tKwdo7+RsYOaQVSRJWiu9DI1ngUktz7ekOnTVX3sns6gOX7WaDNw1NCVKkgbSy9C4F4iI2B5YBEwH5mTm0xHxSkTsm5kLgROBWzutIDOXAEta21oOf0mSuqxn12lk5ivAScCNwKPA48AN9csnAJdGxGPAJsDne1WXJKlc10camblNy+M7gN069HmQ6tNVkqRhzCvCJUnFDA1JUjFDQ5JUzNCQJBUzNCRJxQwNSVIxQ0OSVMzQkCQVMzQkScUMDUlSMUNDklTM0JAkFTM0JEnFDA1JUjFDQ5JUzNAYhZYuW9F0CZI2UL283auGifHjxnLEjJsa2fa8i6c1sl1JQ8ORhiSpmKEhSSpmaEiSihkakqRihoYkqZihIUkqZmhIkooZGpKkYoaGJKmYoSFJKmZoSBpSTc5t5rxq3efcU5KGlHObjWyONCRJxQwNSVIxQ0OSVKyRcxoRsQDYAlhWN50ObAecD4wHLs3MLzZRmySpfz0PjYgYA+wIbJ2Zy+u2NwHXAXsAfwDujojvZeajva5PktS/JkYaAawCbo2INwCzgd8DCzLz1wARcQNwDHDRagtGTAQmtq1vctcrliQBzYTGnwJ3AGcAE4A7geuBxS19FgNTOyx7NjCzy/VJkvrR89DIzHuAe+qnL0bE14BLgE+1dV3ZYfFZwBVtbZOBu4ayRklSZ02c03gXsFFm3lE3jQGeAia1dNsSeK592cxcAixpW193CpUkraGJw1MTgYsiYh9gHPB+4H3AVRHxeuBF4GjgtAZqkyQNoOfXaWTmzcB84H7gJ8CczFwIfBz4HvAAcE1m/rjXtUmSBtbIdRqZ+QngE21t1wDXNFGPJKmMV4RLkooZGg1yGmdJGxqnRm9QU1NIO320pHXlSEOSVMzQkCQVMzQkScUMDUlSMUNDklTM0JA0YjT1MfbR9PF5P3IracTwY+zd50hDklTM0JAkFTM0JEnFDA1JUjFDQ5JUzNCQJBUzNCRJxQwNSVIxQ0OSVMzQkCQVMzQkScVGfWiMponGJGl9jfoJC5ua4AxG1yRnkkaGUT/SkCSVMzQkScUMDUlSMUNDklTM0JAkFTM0JEnFDA1JUjFDQ5LWU5MXCfd626P+4j5JWl+j6SLhYRUaETEdOB8YD1yamV9suCRJUothc3gqIt4EfAp4F7AbcFpE7NRsVZKkVsNppHEQsCAzfw0QETcAxwAX9XWIiInAxLbl3gzw/PPPr/OGl73063Vedn388pe/bGTbTW23yW2Ptu02uW2/595ve120vGeOXZvlxqxatWqdNjjUIuLvgE0y8/z6+SnA1Mw8raXPBcDMZiqUpBFpv8z8YWnn4TTSGNOhbWXb81nAFW1t44G3AE8A6/sxgsnAXcB+wLrF98jjPlmT+2RN7pPVbQj7YyywJfBva7PQcAqNZ6l2cJ8tgedaO2TmEmBJh2V/NhQFRETfw19m5lNDsc4NnftkTe6TNblPVrcB7Y8n13aB4RQatwMXRMTrgReBo4HTBl5EktRLw+bTU5n5LPBx4HvAA8A1mfnjZquSJLUaTiMNMvMa4Jqm65AkdTZsRhrDxBLgQjqfNxmt3Cdrcp+syX2yuhG7P4bNR24lScOfIw1JUjFDQ5JUbFidCG9CRGwG3A0c3vd56ogYB3wH+GRm3tlcdc1o3ycRcRrwYWAVcB9wemYubbLGXuuwT84APkh1Uep84LzMHFXHejv97tTtZwLHZuaBDZXWmA7/T+ZQXX/2Yt3lwsyc21iBQ2BUjzQiYi/gh8CUlrYA7gT2aaisRrXvk4iYApxLtT/eRvV/5szGCmxAh32yLfARYCqwK9W+ObixAhvQ6Xenbt8J+LtGimpYP/tkT2D/zNy9/tqgAwNGeWgAp1K9AbZeeX4y8Dng3kYqal77PvkDcEZm/q7+S/phYOumimvIavskMxcBO2Xmi1QTaG7OCPyUzCDW+N2JiI2ArwCfaKqohq22TyJiE6rfldkR8VBEXBgRG/x77qg+PJWZp8Bql/yTmefVbWc3VFaj2vdJZj4NPF23vZ7qkMxJDZXXiH7+nyyLiFOBfwR+THVB6qjRaZ8AnwHmAIuaqKlpHfbJFsAC4HTgBeBmqj9KZzdR31DZ4FNPvVHf7+QO4Guj8TxPJ5k5G3gt8DxwQbPVNCsiDga2zsyvN13LcJGZv8jMozLzPzLzJeALwGFN17W+DA0NKiJ2BBYCV2bmJ5uup2kRsVVE7AuQmcuB66jO94xmfwXsHBEPAF8F3hER1zdcU6MiYteIOLqlaQywrKl6hsqoPjylwUXEpsB3gb/PzKuarmeY2By4OiJ2B35LdbOw4vsRjESZ+Td9jyPiQOCCzDy+uYqGhTHArIhYQHV46jTgymZLWn+ONDSYU6iOzZ4TEQ/UXxcNttBIlpmPUB2/vxt4EHgJuLjRojTsZOZDVP9PFgKPAg9k5rXNVrX+nEZEklTMkYYkqZihIUkqZmhIkooZGpKkYoaGJKmYoSF1QUTcUk/eN1CfKyLinH5e+4eImNad6qR158V9Uhdk5vpOF/FnVJ/tl4YVr9OQ+hER9wPnZubtEfGXwBXAn2bmyxExm2rG322BA4CxwP3AhzPzdxHxFHBMZt4XER+jmqju98APgPdm5jYRcQWwGTCJ6gLKR4DpVBNCfhb4FfCRkTCdtkYOD09J/ZsLvLt+/G7gN8B+9fTW76GaFn05sEdm7kY1Jfb/bF1BRPw3qhDYE9gD2LRtG28CDqK6B8Nk4C8y84tUN7s618DQcOPhKal/c6kmIzyH6u5rl1DdbOn3wJPA4VTBcXA9HfZ44D/a1nEY8K3MXAIQEV8E/rzl9W/XM6ASEY8Ab+jWNyMNBUcaUj8y82FgfEQcCfwcmAccAhwJ3Eh1SOqsvruyUd3J75i21Synmriuz4q211tnPV3V1lcadgwNaWBzqc4vfDczH6ea4fYEqtD4V+CDETG+PmQ1m2qCulbzgaMjYvP6+clU4TCY5cC4IahfGlKGhjSwucCOwG3189uAxZn5DPBJ4CmqE+CPUo0SZrQunJkLqMLknoi4jyp0XirY7jzgHyPi/UPwPUhDxk9PSV0UEe8A9snMz9fPPwLs5b0mtKHyRLjUXT8DPhoRp1Edlvo/VDfjkTZIjjQkScU8pyFJKmZoSJKKGRqSpGKGhiSpmKEhSSpmaEiSiv0/GkpcPUJ3+sQAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "guess = 1.0\n", + "measurement = 20.\n", + "\n", + "conditioned_scale = pyro.condition(scale, data={\"measurement\": measurement})\n", + "\n", + "#marginal = pyro.infer.EmpiricalMarginal(\n", + "# pyro.infer.Importance(conditioned_scale, num_samples=1000).run(guess), sites=\"weight\")\n", + "nuts_kernel = NUTS(conditioned_scale, adapt_step_size=True)\n", + "mcmc_run = MCMC(nuts_kernel, num_samples=500, warmup_steps=300).run(guess)\n", + "marginal = pyro.infer.EmpiricalMarginal(mcmc_run, 'weight')\n", + "\n", + "# The marginal distribution concentrates around the data\n", + "print(marginal())\n", + "plt.hist([marginal().item() for _ in range(1000)],)\n", + "plt.title(\"P(weight | measurement, guess)\")\n", + "plt.xlabel(\"weight\")\n", + "plt.ylabel(\"#\");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conditioning Model with i.i.d. Data" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# Number of i.i.d. samples\n", + "\n", + "# sample i.i.d points\n", + "measurement = torch.tensor([21.0, 20.0, 19.0, 19.5, 20.5])\n", + "\n", + "# condition\n", + "with pyro.iarange(\"data\", len(measurement)) as ind:\n", + " x=ind\n", + " measurement_batch = measurement[[ind]]\n", + " conditioned_scale_iid = pyro.condition(scale, data={\"measurement\": measurement_batch})\n", + "\n", + "num_samples = 1000\n", + "#marginal = pyro.infer.EmpiricalMarginal(\n", + "# pyro.infer.Importance(conditioned_scale, num_samples=1000).run(1.0), sites=\"weight\")\n", + "nuts_kernel = NUTS(conditioned_scale_iid, adapt_step_size=True)\n", + "mcmc_run = MCMC(nuts_kernel, num_samples=500, warmup_steps=300).run(1.0)\n", + "marginal = pyro.infer.EmpiricalMarginal(mcmc_run, 'weight')" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(17.7418)\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYYAAAEXCAYAAACpuuMDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvDW2N/gAAF6pJREFUeJzt3Xu4JHV95/H3OMwAQWBUlEEHRBG+XAUWh6sgT0QCCAxEgSzkAVxuIqwQB5BdIAPGaIwBBy/ILkpQuSbwjMhFIzBiEHESDHfwKxJmgzBkSXRkuTkXZv+oOtC/nj6XOedU95xz3q/n4aGruqp+31/3mf50VXX9atKKFSuQJKnPG3pdgCRp9WIwSJIKBoMkqWAwSJIKBoMkqWAwSJIKBsMEExGbRsTyiLi/5b8HIuK/tS33nYjYdhTauz8ipg2yzJ0R8dF+nrssInbq57mFI61vIomImRFxaa/rGExEvDEivh8Ra/e6lonKYJiYXs7MHfr+Aw4ALoyI9wJExOHA7zLz4ZE2VLexeASb+BAwaaR1CIBtgBm9LmIwmfkCcA3wF72uZaJao9cFqPcy8+mIeBzYAngQuAA4LCImA88Cu2XmryLibODkzHwnQETcBnwJuBu4GNgOmALcAZyZmcsiYgXwVuC3wBeBg4HfAQuArTNz77qMWRFxFrAhcDtwAtUHw9uBqyLi6MxcMJT+1HsSVwMfBt4CzAH2AHYClgIHZ+YzEfEO4KvAJnXd12bm5+pt/E/gEGAtYB3gjMycFxFbAt+s508CvpGZl0TE+cAGmXlqvf5r0xFxJ/AbYEvg68C3B3i9Xqlf0wOB9YAzgcPqZZ8BDsrMFyNiq3obbwEmA1/OzMsjYm/gL4F/BbYF1gROAX4FfAZYPyL+NjM/NsDrN5l+3qu6L1/NzOvrZV+bHqCmNwJ/C2wOvAr8HDgJ+INO8zPzVeDvgC9ExBcz89/7q1XNcI9BRMRuwHuABfXho7Uz8+HMXA7cBOxXL7ofMDUitoiI9YEdqD7EvwT8PDN3AnYENgA+1dbM8VQfzNsCuwGbtT2/bj1/K2B/YI/MPIfqw/CooYZCi7Uyc3tgNvC/gYvr6aeAY+tlvgNcXte9M7BPRBweEe8E9gE+kJnvBc6h+lCF6oP6pnqdA4C9ImIo/45+m5lbZ+ZXGPj1WhNYlJnbAZcA3wBOB7YG1qcK0DWA64Gz6218ADgjInatt7ELcGFm7kgVYudn5lPAnwN3DRQKtcHeq5UMUtOhwLr13unMepV3DzCfzHwF+AnVa6wuc49hYlo7Iu6vH68B/AfVh+9TEbEL1bfLPvOAj0fEt4CNqL6Jf4jqG/APMnNJRBwI7BwRx/Vtv0ObBwDfrv/BExH/C/hky/PX1UH0Ur338rYR9vGG+v9PAM9m5gMt02+OiHWoPrzeHBF9hyzeCOyQmX8XEccAR0XEe4Bd6+egej2+HRE7U4XiJzPz1YgYrJ67Wh4P9nq11v5QZj4NEBFPAm+m2rPbDLi8pd21qULmMeD/ZGbf+/svvB6EQzXYe9XJQDX9APhcvXdxGzC33gNd3ml+yzafAAZ9YTX6DIaJ6eX6W1onr1IdBuhzG9W31g8Dd9bTJwMvAdfVy0wGDsvMxwDqk83tg3AtozxXsLzt+aUtj1cw8vMKv+9n230m123snpkvAUTEBsArEfFfgBupvtn/EPgx1SEgMvPmiNicKhw/CMyJiN071Dy1rb0X2toe6PUaSu2LW9/DiNiQ6rDPrsDLLcsO57Uc6L3qr5/91pSZr9QBuzfwh8DtEfHf68NPHee3bHPJKtauUeChJLX7JfXuPLy2S/9jquP0fR+SuwF7Un0TBPgH4M8iYlJErAl8Dzi1bbu3AH8aEWvWhx2OZeXw6GQZ1XH4UZWZzwM/oz6EU3843w3MAvYC7s3Mi6j6ewh1WEbE1cARmXkt8AngeWBj4Dlgp/o1WAfYd4Dmh/J6DVg+VYD9aV3TxsDDVId/BjLU13Kg9+o54H11u5sB7x2spog4mepcwg8z89NU/d+2v/ktdbwb+MUQ6tUoMxhUqH+J9HJ9IrHPPKpDBfMz82XgAeDuvkMNVIcZ1gEeojp5/RDw122bvoLqJOZ9wE+pvgm+NISSvgtcFxEDfdAO15HArhHxUF3bNZl5FdUvYjaIiEepToi+QHXIaV2qE+JHRcQD9TrzqMLjKqoPzceBW4F7Bmh3KK9XvzJzCVWAHR8RD1IF9nmZefcgq94DbBkR8wAi4taIOLjDclfQ/3v1WWDfiHgY+ALwj0Oo6dtUwfpoRNxLdVL94gHmUwfmblTnuNRlkxx2W+0i4kjg/Zn5iVHc5r7A2zLzynr6YuCV+pvicLe5MDM3HaUSJ5yIOAH4j8yc1zZ/1N+rYdR2LLBNZp7ZrTb1OvcYtJLMvBp4S0RsN4qbfQQ4JqqL6R6h+gnr50Zx+1p1y4CbO8zv6XtV75kdCZzfrTZVco9BklRwj0GSVDAYJEmFMXMdQ/0rhZnAIlb+DbwkqbPJVBen/nNm/n6whWEMBQNVKNw16FKSpE72pBpmZFBjKRgWAVx11VVMnz6917VI0pjw7LPPctRRR0H9GToUYykYlgNMnz6dGTNW+5GDJWl1M+RD8J58liQVDAZJUsFgkCQVDAZJUqHRk88RMQc4vJ68JTPPiojLqX429WI9/4L2QbwkSb3TWDBExD5UY9LvSDWW+w8i4lCq6xH2yswh/3RKktQ9Te4xLAJm1+O0ExGPUd10fRPgsojYhGos+wvqm3+/pr5pyrS27fkbVUnqgsaCITMf6Xtc3wrxCOD9VLfxO4nq5ic3A8cBl7WtfjrVHcOkVbJk6XKmTpk8+ILjpF2pCY1f4BYR21DdKvCMzEzg0JbnvgIczcrBMJfqLlKtZuCQGBrE1CmTOWj2jV1v96YLZ3W9TakpTZ983gO4ATg9M6+tb/yyRWbeUC8yiQ43O8/MxcDitm01WaokqdbkyeeNqe7Xe0Rmzq9nTwLmRsR8qkNJJwLfaqoGSdKqa3KP4QxgLeCilm/7lwKfB+4GpgA3ZOY1DdYgSVpFTZ58Pg04rZ+nL2mqXUnSyHjlsySpYDBIkgoGgySpYDBIkgoGgySpYDBIkgoGgySpYDBIkgoGgySpYDBIkgoGgySpYDBIkgoGgySpYDBIkgoGgySpYDBIkgoGgySpYDBIkgoGgySpYDBIkgoGgySpYDBIkgoGgySpYDBIkgoGgySpYDBIkgoGgySpYDBIkgoGgySpYDBIkgoGgySpYDBIkgprNLnxiJgDHF5P3pKZZ0XEPsBFwNrAdZl5bpM1SJJWTWN7DHUA7AvsCOwA7BQR/xW4HJgFbAXMjIj9m6pBkrTqmjyUtAiYnZlLMnMp8BiwBfB4Zj6ZmcuAK4HDGqxBkrSKGjuUlJmP9D2OiM2BI4AvUwVGn0XAjPZ1I2IaMK1t9krLSZJGX6PnGAAiYhvgFuAMYCkQbYu82mG104E5DZcmSeqg0V8lRcQewB3A2Zn5LeBpYHrLIhsBz3RYdS7wrrb/9myyVklSpbE9hojYGPgucERmzq9nL6ieivcATwJHUp2MLmTmYmBx2/aaKlWS1KLJQ0lnAGsBF7V8qF8KHAvcUD93K3B9gzVIklZRkyefTwNO6+fp7ZtqV5I0Ml75LEkqGAySpILBIEkqGAzSKFiydPmEbFvjU+MXuEkTwdQpkzlo9o09afumC2f1pF2NX+4xSJIKBoMkqWAwSJIKBoMkqWAwSJIKBoMkqWAwSJIKBoMkqWAwSJIKBoMkqWAwSJIKBoMkqWAwSJIKBoMkqWAwSJIKBoMkqWAwSJIKBoMkqWAwSJIKBoMkqWAwSJIKBoMkqWAwSJIKBoMkqWAwSJIKBoMkqWAwSJIKazTdQESsB/wUODAzF0bE5cCewIv1Ihdk5rym65AkDU2jwRARuwCXAVu0zJ4J7JWZi5psW5I0PE0fSjoBOAV4BiAi1gE2AS6LiAcj4oKI8HCWJK1GGt1jyMzjASKib9aGwHzgJOAF4GbgOKq9itdExDRgWtvmZjRZqySp0vg5hlaZ+a/AoX3TEfEV4GjaggE4HZjTxdI0ipYsXc7UKZN7XYakYepqMETEdsAWmXlDPWsSsLTDonOBK9rmzQDuaq46jZapUyZz0Owbe9L2TRfO6km70njS1WCgCoK5ETGf6lDSicC32hfKzMXA4tZ5LYejJEkN6uqJ38x8EPg8cDfwKHB/Zl7TzRokSQPryh5DZm7a8vgS4JJutCtJWnX+VFSSVDAYJEkFg0GSVDAYJEkFg0GSVDAYJEkFg0GSVDAYJEmFQYMhIk7u9FiSND71e+VzRCSwANg9Iu4GHqO6v8LXu1SbJKkHBtpj2Bb4JrAe1RDYDwGbR8TciDh0gPUkSWPYQMHwrsz8MfB0Zn4kM7cEFgI/AvboRnGSpO4baBC9iyNiM2BaRHwauA9YkZk3Ar0ZbF+rxBvmSBqOfoMhM/ePiDWABH5Hdee1zSLiYeDuzDypSzVqmHp1wxxvliONbQMOu52ZyyLiF5l5KUBEbAIcDuzWjeIkSd036P0YMvPDHR7f3lhFkqSe8gI3SVLBYJAkFQwGSVLBYJAkFQwGSVLBYJAkFQwGSVLBYJAkFQwGSVLBYJAkFQwGSVLBYJAkFQwGSVLBYJAkFQwGSVLBYJAkFQwGSVJh0Du4jURErAf8FDgwMxdGxD7ARcDawHWZeW6T7UuSVl1jewwRsQvwE2CLenpt4HJgFrAVMDMi9m+qfUnS8DS5x3ACcArwnXp6Z+DxzHwSICKuBA4Dvt++YkRMA6a1zZ7RXKmSpD6NBUNmHg8QEX2z3g4sallkEf1/2J8OzGmqNklS/xo9x9BmUod5r/az7FzgirZ5M4C7RrMgSdLKuhkMTwPTW6Y3Ap7ptGBmLgYWt85r2fOQJDWom8GwAIiIeA/wJHAk1cloSdJqpGvXMWTmK8CxwA3Ao8AvgOu71b4kaWga32PIzE1bHt8BbN90m5Kk4fPKZ0lSwWCQJBUMBklSwWDogiVLl/e6BI1jvfr78u96/Ormz1UnrKlTJnPQ7Bu73u5NF87qepvqPv++NNrcY5AkFQwGSVLBYJAkFQwGSVLBYJAkFQwGSVLBYJAkFQwGSVLBYJAkFQwGSVLBYJAkFQwGSVLBYJAkFQwGSVLBYJAkFQwGSVLBYJAkFQwGSVLBYJAkFQwGSVLBYJAkFQwGSVLBYJAkFQwGSVLBYJAkFQwGSVLBYJAkFdboRaMRMR/YEFhazzopMxf0ohZJUqnrwRARk4AtgU0yc1m325ckDawXh5ICWAF8PyIeiIhTe1CDJKkfvTiU9CbgDuBkYG3gzojIzLytb4GImAZMa1tvRvdKlKSJq+vBkJn3APfUky9GxDeBA4DbWhY7HZjT7dokSb05x/B+YM3MvKOeNYnXT0L3mQtc0TZvBnBXs9VJknpxKGka8JmI2B2YAhwDfLx1gcxcDCxunRcRXStQkiayrp98zsybgVuA+4CfA5fXh5ckSauBnlzHkJnnAef1om1J0sC88lmSVDAYJEkFg0GSVJgwwbBk6fJelyCNK738N+W/52b15ORzL0ydMpmDZt/Yk7ZvunBWT9qVmuS/qfFrwuwxSJKGxmCQJBUMBklSwWCQJBUMBklSwWCQJBUMBklSwWCQJBUMBklSwWCQJBUMBklSwWCQJBUMBklSwWCQJBUMBklSwWCQpCGaKDcnmjA36pGkkZooNydyj0GSVDAYJEkFg0GSVDAYJEkFg0GSVDAYJEkFg0GSVDAYJEkFg0GSVDAYJEkFg0GSVOjJWEkRcSRwLjAV+FJmfq0XdUiSVtb1PYaIeAfwl8D7ge2BEyNi627XIUnqrBd7DPsA8zPzNwARcT3wUeAzfQtExDRgWtt67wR49tlnh93w0pd+M+x1R+LXv/51T9qeaO32sm373P22e2Ws9bnlM3PyUNeZtGLFimE1NlwR8T+AdTLz3Hr6eGDnzDyxZZnzgTldLUySxrc9M/MnQ1mwF3sMkzrMe7Vtei5wRdu8qcC7gceBkdyxYgZwF7An0LuvHb1j/+2//Z9Y/Z8MbAT881BX6EUwPE31pvTZCHimdYHMXAws7rDuL0faeET0Pfx1Zi4c6fbGGvtv/2v2f2L1/4lVWbgXwXA7cH5EvBV4EfgIcOLAq0iSuqXrv0rKzKeBc4AfAfcDV2fmP3W7DklSZz25jiEzrwau7kXbkqSBTcQrnxcDF9D5HMZEYP/tv/2fuP0fkq7/XFWStHqbiHsMkqQBGAySpEJPTj43KSLWA34KHAhsDXyu5el3AAsy88C2dTYBrgTeBiRwVGa+0J2KR9cw+3808AXg3+tZt2TmOV0od9S19j8zF0bEvsAXqS7y+Rfg+Mxc0rbONOAqqgsonwMOz8zhj73SQ8Ps/17APOCpetZ9mfmxLpY9ajr0/1jgLKqLYucDszNzWds64+b9Hy3jao8hInYBfgJsAZCZt2bmDpm5A7Af8DzwZx1WvQS4JDO3BO4FzutSyaNqBP2fCXyqb9kxHApF/2vfBP4kM7cF/gA4usOqnwXuysytgMuAi5uutQkj6P9M4G9a3v+xGgpF/6O6mu2zwAczcztgCvDJDquOi/d/NI2rYABOAE6h7Urq2heBSzPz8daZETEF2Au4vp51BXBYgzU2aZX7X5sJHB0RD0TElRHxpiaLbFCn/k8G1ouIycBawMsd1vsw1TdGgGuA/eu/i7FmuP2fCXwoIu6LiO9FxMbNl9qI9v6/F7gnMxfV0zcDh3RYb7y8/6NmXAVDZh6fmXe1z4+IzYG9gS93WG0D4PmW3ctFVOOpjDnD7D9UfT4f2IHqcMJXGyqxUf30/xPAnVQfFhvw+heAVm+neg2o/w6eB97aXKXNGEH/FwMXZ+aOwK3AtU3W2ZQO/X8A2DUiNq6D8aPA9A6rjov3fzSNq2AYwIlUh4p+3+G5oQzqN9YN1H8y89DMXJCZK4C/Bg7oanUNiYjpwF8B21KNyfUz4KIOi47Lv4Gh9j8zP56ZN9aPLwW2iYj1u1lrEzLzl8DZwPeoBs57EFjSYdFx+f6PxEQJhkPo/1vQc7y+qw0dBvUbB/rtf0SsHxGt5x0mAUu7UlXz9gQezswnMvNVquPHe3dY7mnqb5IRsQawHvCf3SqyQYP2PyLeEBHntPz99xnzfwMRsRbwT5m5Y2buDvwbnQeTG6/v/7CN+2CIiA2AtTPzyU7PZ+ZSqm8TR9Szjga+36XyGjdY/4EXgLPqE3cAp1L9QmU8eBjYOSI2rKdn0Xno4Vt5/aTsEVQnIsf8ByND6H8dGIdSDWbZ9wu1BZn5UjcLbcg6wPyIWDciplKdeL6uw3Lj9f0ftnEfDFQ/QVtp3PWI+EZEHFxPfoLqFqOPUn3LOreL9TVtwP5n5nLgcODrEfEYsBPVz/vGvMx8jOoXZj+KiAeB9wFnAETEZyLi4/Wi51Edi36E6m/hlF7UO9pWof/HAKfX/f8YcHwv6h1tmfmfVOfOfkYVknfW47RNiPd/JBwSQ5JUmAh7DJKkVWAwSJIKBoMkqWAwSJIKBoMkqWAwSCMQEbdGxNaDLHNFRJzRz3N/HhGzmqlOGp5xN+y21E2ZOdLhQ/4QeHQ0apFGi9cxaMKLiPuAMzPz9oj4E6oRdt+UmS9HxGXAQ8C7gA9QjVZ6H/DJzHw+IhYCH83MeyPibOA44P8B/wgckpmbRsQVVMMsTAc2pLrY6kjgWKr7YDxHNez5eLniXGOch5KkagiQ/erH+wG/BfaMiDdQDck8DVgG7JSZ21ONpfVXrRuIiD+i+qCfSXX1+LptbbwD2IfqXgEzgD/OzK9R3f/jTENBqxMPJUlVMFxLNVzEnlQjkH6I6pv/E1R3w5tGdc8CgKnA/23bxgHA32fmYoCI+BrwwZbnv9s3/lBEPEx1t0BpteQegya8zHwImFqPnfUr4CZgX+Bg4Aaqw0entdwNb2eqsf1bLaMcvnl52/Otg7KtoPNQz9JqwWCQKvOojvf/MDN/AawPHEUVDP8AnBoRU+vDS5cBn29b/xbgIy33MTiOKgAGs4zqlpPSasNgkCrzgC2B2+rp24BFmfkU8BfAQqqTzo9Sfduf3bpyZs6nCox7IuJeqmAZytDVNwF/ExHHjEIfpFHhr5KkURAR7wN2z8wv19OfAnbJzCMGXlNa/XjyWRodvwQ+HREnUh1C+jeqW6pKY457DJKkgucYJEkFg0GSVDAYJEkFg0GSVDAYJEkFg0GSVPj/skOENjUTRucAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "print(marginal())\n", + "plt.hist([marginal().item() for _ in range(100)])\n", + "plt.title(\"P(weight | measurement, guess)\")\n", + "plt.xlabel(\"weight\")\n", + "plt.ylabel(\"#\");" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# One dimensional data point, n dimensional data\n", + "z1 = torch.tensor(0.)\n", + "s1 = torch.tensor(10.)\n", + "n = torch.Size([10])\n", + "# Without use of 'independent'\n", + "x1 = dist.Normal(z1, s1).sample(sample_shape = n)\n", + "# With use of 'independent'\n", + "x2 = dist.Normal(z1, s1).independent().sample(sample_shape = n)\n", + "\n", + "# K-dimensional data point, n dimensional data\n", + "z2 = torch.tensor([0., 100.])\n", + "s2 = torch.tensor([[1., 10.], [10., 1.]])\n", + "# Without use of 'independent'\n", + "x3 = dist.Normal(z2, s2).sample(sample_shape = n)\n", + "# With use of 'independent'\n", + "x4 = dist.Normal(z2, s2).independent().sample(sample_shape = n)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Variational Inference" + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def prev_scale(guess):\n", + " # The prior over weight encodes our uncertainty about our guess\n", + " # This is one dimension\n", + " weight = pyro.sample(\"weight\", dist.Normal(guess, 1.0))\n", + " # This encodes our belief about the noisiness of the scale:\n", + " # the measurement fluctuates around the true weight\n", + " # data point here is also one dimensional\n", + " return pyro.sample(\"measurement\", dist.Normal(weight, 0.75))\n", + "\n", + "def scale(data):\n", + " guess = torch.tensor(8.)\n", + " weight = pyro.sample(\"weight\", dist.Normal(guess, 1.0))\n", + " f = pyro.sample(\"measurement\", dist.Normal(weight, 0.75))\n", + " \n", + " # Normal\n", + " #for i in range(len(data)):\n", + " # pyro.sample(\"obs_{}\".format(i), dist.Normal(f, 0.75), obs=data[i])\n", + " \n", + " # Conditionally independent\n", + " with pyro.iarange('observe_data'):\n", + " pyro.sample('obs', dist.Normal(f,0.75), obs=data)\n", + "\n", + "\n", + "def scale_guide(data):\n", + " guess_q = pyro.param(\"guess_q\", torch.tensor(8.), \n", + " constraint=constraints.positive)\n", + " weight = pyro.sample(\"weight\", dist.Normal(guess_q, 1.0))\n", + " \n", + " pyro.sample(\"measurement\", dist.Normal(guess_q, 0.75))" + ] + }, + { + "cell_type": "code", + "execution_count": 73, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# setup the optimizer\n", + "adam_params = {\"lr\": 0.0005, \"betas\": (0.90, 0.999)}\n", + "optimizer = Adam(adam_params)\n", + "\n", + "# setup the inference algorithm\n", + "svi = SVI(scale, scale_guide, optimizer, loss=Trace_ELBO())" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "...................." + ] + } + ], + "source": [ + "# this is for running the notebook in our testing framework\n", + "n_steps = 2000\n", + "\n", + "# create data\n", + "measurement = torch.tensor([18.,19.,13.,17.5,16.,18.2,20.5])\n", + "#measurement = torch.tensor([8.,9.,3.,7.5,6.,8.2,10.5])\n", + "\n", + "# do gradient steps\n", + "for step in range(n_steps):\n", + " svi.step(measurement)\n", + " if step % 100 == 0:\n", + " print('.', end='')" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "16.720230102539062" + ] + }, + "execution_count": 75, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "guess_q = pyro.param(\"guess_q\").item()\n", + "guess_q" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}