{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Store and restore a network simulation\n\nThis example shows how to store user-specified aspects of a network\nto file and how to later restore the network for further simulation.\nThis may be used, e.g., to train weights in a network up to a certain\npoint, store those weights and later perform diverse experiments on\nthe same network using the stored weights.\n\n.. admonition:: Only user-specified aspects are stored\n\n   NEST does not support storing the complete state of a simulation\n   in a way that would allow one to continue a simulation as if one had\n   made a new ``Simulate()`` call on an existing network. Such complete\n   checkpointing would be very difficult to implement.\n\n   NEST's explicit approach to storing and restoring network state makes\n   clear to all which aspects of a network are carried from one simulation\n   to another and thus contributes to good scientific practice.\n\n   Storing and restoring is currently not supported for MPI-parallel simulations.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Import necessary modules.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import nest\nimport pickle"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "These modules are only needed for illustrative plotting.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\nfrom matplotlib import gridspec\nimport numpy as np\nimport pandas as pd\nimport textwrap"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Implement network as class.\n\nImplementing the network as a class makes network properties available to\nthe initial network builder, the storer and the restorer, thus reducing the\namount of data that needs to be stored.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class EINetwork:\n    \"\"\"\n    A simple balanced random network with plastic excitatory synapses.\n\n    This simple Brunel-style balanced random network has an excitatory\n    and inhibitory population, both driven by external excitatory poisson\n    input. Excitatory connections are plastic (STDP). Spike activity of\n    the excitatory population is recorded.\n\n    The model is provided as a non-trivial example for storing and restoring.\n    \"\"\"\n\n    def __init__(self):\n        self.nI = 500\n        self.nE = 4 * self.nI\n        self.n = self.nE + self.nI\n\n        self.JE = 1.0\n        self.JI = -4 * self.JE\n        self.indeg_e = 200\n        self.indeg_i = 50\n\n        self.neuron_model = \"iaf_psc_delta\"\n\n        # Create synapse models so we can extract specific connection information\n        nest.CopyModel(\"stdp_synapse_hom\", \"e_syn\", {\"Wmax\": 2 * self.JE})\n        nest.CopyModel(\"static_synapse\", \"i_syn\")\n\n        self.nrn_params = {\"V_m\": nest.random.normal(-65., 5.)}\n        self.poisson_rate = 800.\n\n    def build(self):\n        \"\"\"\n        Construct network from scratch, including instrumentation.\n        \"\"\"\n\n        self.e_neurons = nest.Create(self.neuron_model, n=self.nE, params=self.nrn_params)\n        self.i_neurons = nest.Create(self.neuron_model, n=self.nI, params=self.nrn_params)\n        self.neurons = self.e_neurons + self.i_neurons\n\n        self.pg = nest.Create(\"poisson_generator\", {\"rate\": self.poisson_rate})\n        self.sr = nest.Create(\"spike_recorder\")\n\n        nest.Connect(self.e_neurons, self.neurons,\n                     {\"rule\": \"fixed_indegree\", \"indegree\": self.indeg_e},\n                     {\"synapse_model\": \"e_syn\", \"weight\": self.JE})\n        nest.Connect(self.i_neurons, self.neurons,\n                     {\"rule\": \"fixed_indegree\", \"indegree\": self.indeg_i},\n                     {\"synapse_model\": \"i_syn\", \"weight\": self.JI})\n        nest.Connect(self.pg, self.neurons, \"all_to_all\", {\"weight\": self.JE})\n        nest.Connect(self.e_neurons, self.sr)\n\n    def store(self, dump_filename):\n        \"\"\"\n        Store neuron membrane potential and synaptic weights to given file.\n        \"\"\"\n\n        assert nest.NumProcesses() == 1, \"Cannot dump MPI parallel\"\n\n        ###############################################################################\n        # Build dictionary with relevant network information:\n        #   - membrane potential for all neurons in each population\n        #   - source, target and weight of all connections\n        # Dictionary entries are Pandas Dataframes.\n        #\n        # Strictly speaking, we would not need to store the weight of the inhibitory\n        # synapses since they are fixed, but we do so out of symmetry and to make it\n        # easier to add plasticity for inhibitory connections later.\n\n        network = {}\n        network[\"n_vp\"] = nest.total_num_virtual_procs\n        network[\"e_nrns\"] = self.neurons.get([\"V_m\"], output=\"pandas\")\n        network[\"i_nrns\"] = self.neurons.get([\"V_m\"], output=\"pandas\")\n\n        network[\"e_syns\"] = nest.GetConnections(synapse_model=\"e_syn\").get(\n            (\"source\", \"target\", \"weight\"), output=\"pandas\")\n        network[\"i_syns\"] = nest.GetConnections(synapse_model=\"i_syn\").get(\n            (\"source\", \"target\", \"weight\"), output=\"pandas\")\n\n        with open(dump_filename, \"wb\") as f:\n            pickle.dump(network, f, pickle.HIGHEST_PROTOCOL)\n\n    def restore(self, dump_filename):\n        \"\"\"\n        Restore network from data in file combined with base information in the class.\n        \"\"\"\n\n        assert nest.NumProcesses() == 1, \"Cannot load MPI parallel\"\n\n        with open(dump_filename, \"rb\") as f:\n            network = pickle.load(f)\n\n        assert network[\"n_vp\"] == nest.total_num_virtual_procs, \"N_VP must match\"\n\n        ###############################################################################\n        # Reconstruct neurons\n        # Since NEST does not understand Pandas Series, we must pass the values as\n        # NumPy arrays\n        self.e_neurons = nest.Create(self.neuron_model, n=self.nE,\n                                     params={\"V_m\": network[\"e_nrns\"].V_m.values})\n        self.i_neurons = nest.Create(self.neuron_model, n=self.nI,\n                                     params={\"V_m\": network[\"i_nrns\"].V_m.values})\n        self.neurons = self.e_neurons + self.i_neurons\n\n        ###############################################################################\n        # Reconstruct instrumentation\n        self.pg = nest.Create(\"poisson_generator\", {\"rate\": self.poisson_rate})\n        self.sr = nest.Create(\"spike_recorder\")\n\n        ###############################################################################\n        # Reconstruct connectivity\n        nest.Connect(network[\"e_syns\"].source.values, network[\"e_syns\"].target.values,\n                     \"one_to_one\",\n                     {\"synapse_model\": \"e_syn\", \"weight\": network[\"e_syns\"].weight.values})\n\n        nest.Connect(network[\"i_syns\"].source.values, network[\"i_syns\"].target.values,\n                     \"one_to_one\",\n                     {\"synapse_model\": \"i_syn\", \"weight\": network[\"i_syns\"].weight.values})\n\n        ###############################################################################\n        # Reconnect instruments\n        nest.Connect(self.pg, self.neurons, \"all_to_all\", {\"weight\": self.JE})\n        nest.Connect(self.e_neurons, self.sr)\n\n\nclass DemoPlot:\n    \"\"\"\n    Create demonstration figure for effect of storing and restoring a network.\n\n    The figure shows raster plots for five different runs, a PSTH for the\n    initial 1 s simulation and PSTHs for all 1 s continuations, and weight\n    histograms.\n    \"\"\"\n\n    def __init__(self):\n        self._colors = [c[\"color\"] for c in plt.rcParams[\"axes.prop_cycle\"]]\n        self._next_line = 0\n\n        plt.rcParams.update({'font.size': 10})\n        self.fig = plt.figure(figsize=(10, 7), constrained_layout=False)\n\n        gs = gridspec.GridSpec(4, 2, bottom=0.08, top=0.9, left=0.07, right=0.98, wspace=0.2, hspace=0.4)\n        self.rasters = ([self.fig.add_subplot(gs[0, 0])] +\n                        [self.fig.add_subplot(gs[n, 1]) for n in range(4)])\n        self.weights = self.fig.add_subplot(gs[1, 0])\n        self.comment = self.fig.add_subplot(gs[2:, 0])\n\n        self.fig.suptitle(\"Storing and reloading a network simulation\")\n        self.comment.set_axis_off()\n        self.comment.text(0, 1, textwrap.dedent(\"\"\"\n            Storing, loading and continuing a simulation of a balanced E-I network\n            with STDP in excitatory synapses.\n\n            Top left: Raster plot of initial simulation for 1000ms (blue). Network state\n            (connections, membrane potential, synaptic weights) is stored at the end of\n            the initial simulation.\n\n            Top right: Immediate continuation of the initial simulation from t=1000ms\n            to t=2000ms (orange) by calling Simulate(1000) again after storing the network.\n            This continues based on the full network state, including spikes in transit.\n\n            Second row, right: Simulating for 1000ms after loading the stored network\n            into a clean kernel (green). Time runs from 0ms and only connectivity, V_m and\n            synaptic weights are restored. Dynamics differ somewhat from continuation.\n\n            Third row, right: Same as in second row with identical random seed (red),\n            resulting in identical spike patterns.\n\n            Fourth row, right: Simulating for 1000ms from same stored network state as\n            above but with different random seed yields different spike patterns (purple).\n\n            Above: Distribution of excitatory synaptic weights at end of each sample\n            simulation. Green and red curves are identical and overlay to form brown curve.\"\"\"),\n                          transform=self.comment.transAxes, fontsize=8,\n                          verticalalignment='top')\n\n    def add_to_plot(self, net, n_max=100, t_min=0, t_max=1000, lbl=\"\"):\n        spks = pd.DataFrame.from_dict(net.sr.get(\"events\"))\n        spks = spks.loc[(spks.senders < n_max) & (t_min < spks.times) & (spks.times < t_max)]\n\n        self.rasters[self._next_line].plot(spks.times, spks.senders, \".\",\n                                           color=self._colors[self._next_line])\n        self.rasters[self._next_line].set_xlim(t_min, t_max)\n        self.rasters[self._next_line].set_title(lbl)\n        if 1 < self._next_line < 4:\n            self.rasters[self._next_line].set_xticklabels([])\n        elif self._next_line == 4:\n            self.rasters[self._next_line].set_xlabel('Time [ms]')\n\n        # To save time while plotting, we extract only a subset of connections.\n        # For simplicity, we just use a prime-number stepping.\n        w = nest.GetConnections(source=net.e_neurons[::41], synapse_model=\"e_syn\").weight\n        wbins = np.arange(0.7, 1.4, 0.01)\n        self.weights.hist(w, bins=wbins,\n                          histtype=\"step\", density=True, label=lbl,\n                          color=self._colors[self._next_line],\n                          alpha=0.7, lw=3)\n\n        if self._next_line == 0:\n            self.rasters[0].set_ylabel(\"neuron id\")\n            self.weights.set_ylabel(\"p(w)\")\n            self.weights.set_xlabel(\"Weight w [mV]\")\n\n        plt.draw()\n        plt.pause(1e-3)  # allow figure window to draw figure\n\n        self._next_line += 1\n\n\nif __name__ == \"__main__\":\n\n    plt.ion()\n\n    T_sim = 1000\n\n    dplot = DemoPlot()\n\n    ###############################################################################\n    # Ensure clean slate and make NEST less chatty\n    nest.set_verbosity(\"M_WARNING\")\n    nest.ResetKernel()\n\n    ###############################################################################\n    # Create network from scratch and simulate 1s.\n    nest.local_num_threads = 4\n    nest.print_time = True\n    ein = EINetwork()\n    print(\"*** Initial simulation ***\")\n    ein.build()\n    nest.Simulate(T_sim)\n    dplot.add_to_plot(ein, lbl=\"Initial simulation\")\n\n    ###############################################################################\n    # Store network state to file with state after 1s.\n    print(\"\\n*** Storing simulation ...\", end=\"\", flush=True)\n    ein.store(\"ein_1000.pkl\")\n    print(\" done ***\\n\")\n\n    ###############################################################################\n    # Continue simulation by another 1s.\n    print(\"\\n*** Continuing simulation ***\")\n    nest.Simulate(T_sim)\n    dplot.add_to_plot(ein, lbl=\"Continued simulation\", t_min=T_sim, t_max=2 * T_sim)\n\n    ###############################################################################\n    # Clear kernel, restore network from file and simulate for 1s.\n    print(\"\\n*** Reloading and resuming simulation ***\")\n    nest.ResetKernel()\n    nest.local_num_threads = 4\n    ein2 = EINetwork()\n    ein2.restore(\"ein_1000.pkl\")\n    nest.Simulate(T_sim)\n    dplot.add_to_plot(ein2, lbl=\"Reloaded simulation\")\n\n    ###############################################################################\n    # Repeat previous step. This shall result in *exactly* the same results as\n    # the previous run because we use the same random seed.\n    print(\"\\n*** Reloading and resuming simulation (same seed) ***\")\n    nest.ResetKernel()\n    nest.local_num_threads = 4\n    ein2 = EINetwork()\n    ein2.restore(\"ein_1000.pkl\")\n    nest.Simulate(T_sim)\n    dplot.add_to_plot(ein2, lbl=\"Reloaded simulation (same seed)\")\n\n    ###############################################################################\n    # Clear, restore and simulate again, but now with different random seed.\n    # Details in results shall differ from previous run.\n    print(\"\\n*** Reloading and resuming simulation (different seed) ***\")\n    nest.ResetKernel()\n    nest.local_num_threads = 4\n    nest.rng_seed = 987654321\n    ein2 = EINetwork()\n    ein2.restore(\"ein_1000.pkl\")\n    nest.Simulate(T_sim)\n    dplot.add_to_plot(ein2, lbl=\"Reloaded simulation (different seed)\")\n\n    dplot.fig.savefig(\"store_restore_network.png\")\n\n    input(\"Press ENTER to close figure!\")"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.8.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}