# Binary Dense Storage


<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->

<a target="_blank" href="https://colab.research.google.com/github/bhoov/amtutorial/blob/main/tutorial_ipynbs/00_dense_storage.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

<details class="code-fold">
<summary>Notebook Execution Settings</summary>

``` python
CACHE_DIR = "cache/00_dense_storage"
CACHE_RECALL = True # If False, regenerate all saved results even if files exist.
SHOW_FULL_ANIMATIONS = True # If True, render videos instead of gifs. This is slower than gifs and relies on `ffmpeg` to save the animation, but it lets us see the energy descent alongside the frame evolution.
```

</details>

### General Associative Memory

Our goal in this section is to build the smallest abstraction for
Associative Memory, which at its core is just an *energy function*
*E*<sub>*Ξ*</sub>(*σ*) ∈ ℝ, where the *query pattern*
*σ* ∈ {−1, 1}<sup>*D*</sup> is a possibly noisy, *D*-dimensional binary
pattern and the *memory matrix* *Ξ* ∈ {−1, 1}<sup>*K* × *D*</sup> is our
matrix of *K* stored patterns. *E*<sub>*Ξ*</sub>(*σ*) stores patterns at
low energies. To retrieve our stored patterns, we want to minimize
*E*<sub>*Ξ*</sub>(*σ*).

Let’s assume an unimplemented, arbitrary energy function and set up a
basic object for a binary AM. All we need to provide is an `energy`
method, parameterized by *Ξ*, that is a function of the query *σ*.

Historically, the Hopfield Network \[1\] minimizes energy using
*asynchronous* update rules (where we minimize the query’s energy one
randomly selected bit at a time). We’ll follow that precedent in this
notebook since it makes for nicer visualizations, though fully
*synchronous* update rules (where all bits are updated simultaneously
at each step) are also possible. The default
`async_update` is simple: for a randomly sampled bit in the query
pattern, compare the energy of the pattern with that bit flipped and
not flipped. Keep the pattern whose energy is lower.

<span id="eq-async-update">
$$
\sigma_i^{(t+1)} = \underset{b \in \{-1, 1\}}{\mathrm{argmin}}\left[E\left(\sigma_i = b, \sigma_{j \neq i} = \sigma_j^{(t)}\right)\right]
 \qquad(1)$$
</span>

Converting a noisy query pattern into a stored pattern is a matter of
repeatedly applying the `async_update` rule to minimize energy. Because
this process, if run long enough, will “recall” a memory, we call this
the `async_recall` method.

We use `jax` primitives like `jax.lax.scan` and `jax.lax.cond` so we can
JIT-compile our code to run quickly. A `scan` is just a glorified for
loop, and a `cond` is a glorified if-else statement.
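
To make the "glorified for loop" and "glorified if-else" descriptions concrete, here is a minimal sketch that is unrelated to the AM itself. The function `running_sum_of_positives` is a made-up example: it loops over an array with `lax.scan` and uses `lax.cond` to decide whether each element is added to a running total.

``` python
import jax.numpy as jnp
from jax import lax

def running_sum_of_positives(xs):
    "Sum only the positive entries of `xs`, returning every partial sum"
    def step(total, x):
        # `cond` picks one of two branches based on a (possibly traced) boolean
        new_total = lax.cond(x > 0, lambda: total + x, lambda: total)
        # `scan` expects (next carry, per-step output) from each step
        return new_total, new_total
    final_total, partial_sums = lax.scan(step, jnp.float32(0.0), xs)
    return final_total, partial_sums

xs = jnp.array([1.0, -2.0, 3.0, -4.0, 5.0])
final_total, partial_sums = running_sum_of_positives(xs)
```

The carry (`total`) plays the role of the loop variable, and the stacked per-step outputs play the role of a list we would otherwise append to inside a Python loop.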

With all this, we can fully encapsulate a basic, binary AM in the
following object.

``` python
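# Imports assumed by this notebook (hidden in the original export)
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
from jax import lax
from jaxtyping import Array, Float

class BinaryAM(eqx.Module):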
    Xi: Float[Array, "K D"] # matrix of stored patterns 
    def energy(
        self, 
        sigma: Float[Array, "D"] # Possibly noisy query pattern
        ): 
        ... # Left to implement later

    def async_update(
        self,
        sigma: Float[Array, "D"], # Possibly noisy query pattern
        idx:int,              # Index of bit to flip
        ):                    # Return next state and its energy
        "Minimize the energy of `sigma[idx]`"
        sigma_flipped = jnp.array(sigma).at[idx].multiply(-1)
        energy_og = self.energy(sigma)
        energy_flipped = self.energy(sigma_flipped)
        keep_flip = (energy_flipped - energy_og) < 0
        return lax.cond(
            keep_flip, 
            # Keep flipped bit if it has lower energy
            lambda: (sigma_flipped, energy_flipped), 
            # Otherwise keep original bit
            lambda: (sigma, energy_og)
        )

    @eqx.filter_jit
    def async_recall(
        self, 
        sigma0: Float[Array, "D"], # Initial query pattern
        nsteps:int=20000, # Number of bits to flip & check
        key=jr.PRNGKey(0) # Random key for bit-flip choices
        ):
        "Minimize energy of `sigma0` by repeatedly applying `async_update`"
        def update_step(sigma, idx):
            sigma_new, energy_new = self.async_update(sigma, idx)
            return sigma_new, (sigma_new, energy_new)
        D = sigma0.shape[-1]

        # Randomly sample `nsteps` bits to flip
        bitflip_sequence = jr.choice(key, np.arange(D), shape=(nsteps,))

        # Apply `async_update` to each bitflip in seq
        final_x, (frames, energies) = lax.scan(update_step, sigma0, bitflip_sequence)

        # Return final pattern and the trajectory
        return final_x, (frames, energies)
```

### Loading data

<div>

> **Note**
>
> Feel free to skip this section. It’s just loading data and setting up
> some fancy visualization functions.

</div>

Let’s build some helper functions to load and view our data: binarized
pokemon sprites. While other fields like to work with {0, 1} binary
data, Hopfield Networks like to work with bipolar data where each
datapoint *σ* ∈ {−1, 1}<sup>*D*</sup>.
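
If your own data is stored as {0, 1} bits, converting to and from the bipolar convention is a one-line affine map. The helper names `to_bipolar` and `to_binary` below are hypothetical and are not part of `amtutorial`.

``` python
import jax.numpy as jnp

def to_bipolar(bits):
    "Map {0, 1} data to {-1, +1}"
    return 2 * bits - 1

def to_binary(sigma):
    "Map {-1, +1} data back to {0, 1}"
    return (sigma + 1) // 2

bits = jnp.array([0, 1, 1, 0])
sigma = to_bipolar(bits)
```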

``` python
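import matplotlib.pyplot as plt  # assumed import; hidden in the original export
from einops import rearrange     # assumed import; hidden in the original export
from amtutorial.data_utils import get_pokemon_data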
poke_pixels, poke_names = get_pokemon_data()
data = poke_pixels

pxh, pxw = data.shape[-2:]
data = data.reshape(-1, pxh * pxw)

def gridify(images, grid_h=None):
    """Convert list of images to a single grid image"""
    images = np.array(images)  # Shape: (n_images, H, W)
    if grid_h is None: grid_h = int(np.sqrt(len(images)))
    grid_w = int(np.ceil(len(images) / grid_h))

    # Pad if necessary
    n_needed = grid_h * grid_w
    if len(images) < n_needed:
        padding_shape = (n_needed - len(images),) + images.shape[1:]
        padding = np.zeros(padding_shape)
        images = np.concatenate([images, padding], axis=0)
    
    # Reshape individual images and arrange in grid
    grid = rearrange(images[:n_needed], '(gh gw) h w -> (gh h) (gw w)', gh=grid_h, gw=grid_w)
    return grid

def show_im(sigma, ax=None, do_gridify=True, grid_h=None, figsize=None):
    """Vector to figure"""
    sigma = rearrange(sigma, "... (h w) -> ... h w", h=pxh, w=pxw)
    if do_gridify and len(sigma.shape) == 3: sigma = gridify(sigma, grid_h)
    empty_ax = ax is None
    figsize = figsize or (8, 2.67) # Quarto aspect ratio
    if empty_ax: fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(sigma, cmap="gray", vmin=-1, vmax=1)
    ax.axis("off")
    return (fig if empty_ax else None), ax
```

<img src="00_dense_storage_files/figure-commonmark/cell-5-output-1.png"
width="481" height="503" />

## The Classical Hopfield Network

Let’s revisit our task to store *K* binary patterns each of dimension
*D* into an energy function. Let’s keep things simple and fast for the
first part of this notebook and focus on storing and retrieving *K* = 2
patterns: an eevee and a pichu, where each `(48,48)` image is flattened
into a vector of dimension *D* = 2304.

``` python
desired_names = ["eevee", "pichu"]
eevee_pichu_idxs = [poke_names.index(name) for name in desired_names]
Xi = data[eevee_pichu_idxs]

fig, ax = show_im(Xi, figsize=(6,3));
ax.set_title("Stored patterns")
plt.show()

print(f"K={Xi.shape[0]}, D={Xi.shape[1]}")
```

    K=2, D=2304

<img src="00_dense_storage_files/figure-commonmark/cell-6-output-2.png"
width="481" height="272" />

The Classical Hopfield Network (CHN) \[1\] defines an energy function
for this collection of patterns, putting the *μ*-th stored pattern
*ξ*<sup>*μ*</sup> at a *low* value of energy. The CHN energy is a
quadratic function described by dot-product correlations:

<span id="eq-chn-energy">
$$
E_\text{CHN}(\sigma) = -\frac{1}{2} \sum_\mu \left(\sum_{i} \xi^\mu_i \sigma_i\right)^2 = -\frac{1}{2} \sum_{i,j} T_{ij} \sigma_i \sigma_j.
 \qquad(2)$$
</span>

We see the familiar equation for CHN energy on the RHS if we expand the
quadratic function, where
$T_{ij} := \sum_{\mu=1}^K \xi^\mu_i \xi^\mu_j$ is the matrix of
symmetric synapses. Learned patterns *ξ*<sup>*μ*</sup> are stored in *T*
via a simple, Hebbian learning rule.
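
As a quick sanity check of Equation 2, the sketch below builds *T* from random toy patterns and confirms numerically that both sides give the same energy. The sizes `K` and `D` and the random patterns are arbitrary choices for illustration, not the notebook's data.

``` python
import jax.numpy as jnp
import jax.random as jr

key_xi, key_sigma = jr.split(jr.PRNGKey(0))
K, D = 3, 16  # toy sizes, chosen only for illustration
Xi = jr.choice(key_xi, jnp.array([-1.0, 1.0]), shape=(K, D))
sigma = jr.choice(key_sigma, jnp.array([-1.0, 1.0]), shape=(D,))

# Hebbian rule: T_ij = sum_mu xi^mu_i xi^mu_j
T = Xi.T @ Xi

# Left- and right-hand sides of Equation 2
energy_from_patterns = -0.5 * jnp.sum((Xi @ sigma) ** 2)
energy_from_synapses = -0.5 * sigma @ T @ sigma
```

Note that the sum over *i*, *j* includes the diagonal terms *i* = *j*, which only contribute a constant offset to the energy because *σ*<sub>*i*</sub><sup>2</sup> = 1.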

The CHN can be easily implemented in code via

``` python
class CHN(BinaryAM):
    def energy(
        self, 
        sigma: Float[Array, "D"] # Possibly noisy query pattern
        ): 
        "Quadratic energy function for the CHN"
        return -0.5 * jnp.sum((self.Xi @ sigma)**2, axis=0)

chn = CHN(Xi)
```

The asynchronous update rule of
<a href="#eq-async-update" class="quarto-xref">Equation 1</a> uses the
energy difference of a flipped bit to determine whether to keep the flip
or not. That update rule is equivalent to the following, arguably more
familiar update rule, which describes the next state based on the sign
of the *total input current* to the neuron *σ*<sub>*i*</sub>.

$$
\begin{align*}
\sigma_i^{(t+1)} &\leftarrow \text{sgn}\left(\sum_{\mu} \xi^\mu_i \sum_{j \neq i} \left(\xi^\mu_j \sigma_j^{(t)}\right) \right)\\
\text{sgn}(x) &:= \begin{cases}
1 & \text{if } x \geq 0 \\
-1 & \text{if } x < 0
\end{cases}\quad.
\end{align*}
$$

This update rule also ensures the network always moves toward lower
energy states. Because *E*<sub>CHN</sub> is bounded from below, the
network will eventually converge to a local minimum that (ideally)
corresponds to one of the stored patterns.
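
The equivalence between the energy-comparison rule and the sign rule can be checked numerically. The sketch below uses random toy patterns and hypothetical helper names (`update_by_energy`, `input_current`); it compares the two rules on every bit. The one place they can differ is a tie: when the input current is exactly zero both bit values have the same energy, while `sgn` picks +1 by convention, so ties are excluded from the comparison.

``` python
import jax.numpy as jnp
import jax.random as jr

def chn_energy(Xi, sigma):
    return -0.5 * jnp.sum((Xi @ sigma) ** 2)

def update_by_energy(Xi, sigma, i):
    "Equation 1: pick the value of bit `i` with the lower energy"
    e_plus = chn_energy(Xi, sigma.at[i].set(1.0))
    e_minus = chn_energy(Xi, sigma.at[i].set(-1.0))
    return jnp.where(e_plus < e_minus, 1.0, -1.0)

def input_current(Xi, sigma, i):
    "Total input to neuron `i`, excluding its own contribution (j != i)"
    overlaps = Xi @ sigma - Xi[:, i] * sigma[i]
    return jnp.sum(Xi[:, i] * overlaps)

key_xi, key_sigma = jr.split(jr.PRNGKey(0))
K, D = 3, 15  # toy sizes, chosen only for illustration
Xi = jr.choice(key_xi, jnp.array([-1.0, 1.0]), shape=(K, D))
sigma = jr.choice(key_sigma, jnp.array([-1.0, 1.0]), shape=(D,))

by_energy = jnp.array([update_by_energy(Xi, sigma, i) for i in range(D)])
currents = jnp.array([input_current(Xi, sigma, i) for i in range(D)])
by_sign = jnp.where(currents >= 0, 1.0, -1.0)
untied = currents != 0  # the two rules are only compared away from ties
```

The reason this works: setting bit *i* to +1 instead of −1 changes the energy by −2 times the input current, so the lower-energy choice always has the same sign as the current.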

Let’s observe the recall process! We’ll start with a noisy version of
the first pattern and see if we can recover it.

``` python
def flip_some_bits(key, x, p=0.1):
    "Flip each bit of `x` independently with probability `p`"
    prange = np.array([p, 1-p])
    return x * jr.choice(key, np.array([-1, 1]), p=prange, shape=x.shape)

sigma_og = Xi[0] 
sigma_noisy = flip_some_bits(jr.PRNGKey(0), sigma_og, 0.2)

show_im(jnp.stack([sigma_og, sigma_noisy]), figsize=(6, 3));
```

<img src="00_dense_storage_files/figure-commonmark/cell-8-output-1.png"
width="481" height="251" />

For the pedagogical purpose of this notebook, we’ll cache the recall
process and results so we don’t have to run it every time.

``` python
@delegates(BinaryAM.async_recall)
def cached_recall(am, cache_name, sigma_noisy, key=jr.PRNGKey(0), save=True, **kwargs):
    "Cache the recall process using key `cache_name`"
    npz_fname = Path(CACHE_DIR) / (cache_name + '.npz')
    if npz_fname.exists() and CACHE_RECALL: 
        npz_data = np.load(npz_fname)
        sigma_final, frames, energies = npz_data['sigma_final'], npz_data['frames'], npz_data['energies']
        print("Loading cached recall data")
    else: 
        sigma_final, (frames, energies) = am.async_recall(sigma_noisy, key=key, **kwargs)
        if save: jnp.savez(npz_fname, sigma_final=sigma_final, frames=frames, energies=energies)
    return sigma_final, frames, energies

cache_name = 'basic_hopfield_recovery'
sigma_final, frames, energies = cached_recall(chn, cache_name, sigma_noisy, nsteps=12000, key=jr.PRNGKey(5))
```

    Loading cached recall data

<img src="00_dense_storage_files/figure-commonmark/cell-10-output-1.png"
width="640" height="305" />

We can animate the recall process to view the “thinking” process of the
CHN.

![](cache/00_dense_storage/basic_hopfield_recovery.mp4)

### Retrieving “inverted” images

If we initialize a query with *too much* noise, it’s possible to
retrieve the negative of a stored pattern or an “inverted image”.
Because the energy is quadratic, both *σ* and −*σ* have the same
low energy. Whether we retrieve the original *σ* or the
inverted −*σ* depends on whether we initialize our query closer to
the original or inverted pattern.

$$
E_\text{CHN}(-\sigma) = -\frac{1}{2} \sum_{\mu} \left(\sum_{i} \xi^\mu_i (-\sigma_i)\right)^2 = -\frac{1}{2} \sum_{\mu} \left(\sum_{i} \xi^\mu_i \sigma_i\right)^2 = E_\text{CHN}(\sigma)
$$
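
A tiny numeric check of this symmetry, using two hand-written toy patterns (not the notebook's data):

``` python
import jax.numpy as jnp

def chn_energy(Xi, sigma):
    return -0.5 * jnp.sum((Xi @ sigma) ** 2)

# Two made-up 4-bit memories, for illustration only
Xi = jnp.array([[1.0, -1.0, 1.0, 1.0],
                [-1.0, -1.0, 1.0, -1.0]])
sigma = Xi[0]

e = chn_energy(Xi, sigma)       # energy of the stored pattern
e_inv = chn_energy(Xi, -sigma)  # energy of its inverted copy
```

Both energies are identical, so the inverted pattern is just as much a minimum as the stored one.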

    Loading cached recall data
    Accidentally retrieved the inverted pattern!

<img src="00_dense_storage_files/figure-commonmark/cell-12-output-2.png"
width="654" height="305" />

![](cache/00_dense_storage/hopfield_recovery_inverted.mp4)

### Memory retrieval failure

Unfortunately, the CHN is terrible at storing and retrieving multiple
patterns. If we add even four more patterns into the synaptic memory,
our network will fail to retrieve our eevee.

``` python
Xi = data[eevee_pichu_idxs]
Xi = jnp.concatenate([Xi, jr.choice(jr.PRNGKey(10), data, shape=(4,), replace=False)])
fig, ax = show_im(Xi, figsize=(6, 4));
ax.set_title(f"Stored patterns (K={Xi.shape[0]})")
plt.show()
```

<img src="00_dense_storage_files/figure-commonmark/cell-13-output-1.png"
width="481" height="349" />

    Loading cached recall data
    CHN failed to retrieve the correct pattern!

<img src="00_dense_storage_files/figure-commonmark/cell-14-output-2.png"
width="654" height="305" />

![](cache/00_dense_storage/hopfield_recovery_fail.mp4)

## Dense Associative Memory

The CHN has a *quadratic* energy, which is a special case of a more
general class of models called **Dense Associative Memory** (DenseAM)
\[2\]. If we increase the degree of the polynomial used in the energy
function, we strengthen the coupling between neurons and can store more
patterns into the same synaptic matrix.

The new energy function, written in terms of polynomials of degree *n*
and using the same notation for stored patterns
*ξ*<sub>*i*</sub><sup>*μ*</sup>, is

<span id="eq-dam-energy">
$$
\begin{align*}
E_\text{DAM}(\sigma) &= -\sum_{\mu=1}^K F_n\left(\sum_{i=1}^D \xi^\mu_i \sigma_i\right),\\
\text{where}\ F_n(x) &= \begin{cases} \frac{x^n}{n} & \text{if } x \geq 0 \\ 0 & \text{if } x < 0 \end{cases}.
\end{align*}
 \qquad(3)$$
</span>

<div>

> **Note**
>
> We need *F*<sub>*n*</sub> to be convex for all *n*, which is why we
> perform the rectification. We could alternatively limit ourselves to
> only even values of *n*.
>
> Fun fact, rectified polynomials remove the “inverted” retrieval
> phenomenon seen in
> <a href="#sec-inverted-retrievals" class="quarto-xref">Section 1.1</a>.

</div>
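
To see why rectification removes inverted retrievals, the sketch below evaluates a standalone version of the energy in Equation 3 on a stored pattern and on its negation. The function `dam_energy` and the two toy patterns are assumptions made for illustration; they are not the notebook's implementation or data.

``` python
import jax.numpy as jnp

def dam_energy(Xi, sigma, n, rectified=True):
    "Equation 3, with optional rectification of the overlaps"
    sims = Xi @ sigma
    sims = jnp.maximum(sims, 0.0) if rectified else sims
    return -jnp.sum(sims ** n / n)

# Two made-up 4-bit memories, for illustration only
Xi = jnp.array([[1.0, -1.0, 1.0, 1.0],
                [-1.0, -1.0, 1.0, -1.0]])
sigma = Xi[0]

# Even power without rectification: sigma and -sigma share the same energy
e_even, e_even_inv = dam_energy(Xi, sigma, 2, False), dam_energy(Xi, -sigma, 2, False)

# Rectified: the negative overlap of -sigma is clipped to zero
e_rect, e_rect_inv = dam_energy(Xi, sigma, 2), dam_energy(Xi, -sigma, 2)
```

With rectification, the inverted pattern no longer sits in a low-energy well, so the dynamics have no reason to settle there.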

<a href="#eq-dam-energy" class="quarto-xref">Equation 3</a> admits the
following manual update rule for a single neuron *i*:

<span id="eq-dam-update">
$$
\begin{align*}
\sigma_i^{(t+1)} &\leftarrow \text{sgn}\left( \sum_{\mu} \xi^\mu_i f_n\left( \sum_{j \neq i} \xi^\mu_j \sigma_j^{(t)}\right)\right).
\end{align*}
 \qquad(4)$$
</span>

Here we introduced an activation function
*f*<sub>*n*</sub>(⋅) = *F*<sub>*n*</sub>′(⋅) that is the derivative of
the rectified polynomial used to define the energy. This update can be
viewed as the negative gradient of the energy function, ensuring that
the network always moves toward lower energy states. Like before, this
energy is bounded from below and we will eventually converge to a local
minimum that (ideally) corresponds to one of the stored patterns.
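
Here is a minimal sketch of Equation 4 applied one neuron at a time. The helper names `f_n` and `dam_update` are hypothetical, and the memories are three mutually orthogonal rows of a 16 × 16 Hadamard matrix, chosen only so that the toy example behaves predictably. It is separate from the class-based implementation used in this notebook, which compares energies directly.

``` python
import jax.numpy as jnp

def f_n(x, n):
    "Derivative of the rectified polynomial F_n: x^(n-1) for x >= 0, else 0"
    return jnp.maximum(x, 0.0) ** (n - 1)

def dam_update(Xi, sigma, i, n):
    "Equation 4 for a single neuron `i`"
    overlaps = Xi @ sigma - Xi[:, i] * sigma[i]  # sum over j != i
    current = jnp.sum(Xi[:, i] * f_n(overlaps, n))
    return jnp.where(current >= 0, 1.0, -1.0)

# Toy memories: three mutually orthogonal rows of a 16x16 Hadamard matrix
H2 = jnp.array([[1.0, 1.0], [1.0, -1.0]])
H16 = jnp.kron(jnp.kron(H2, H2), jnp.kron(H2, H2))
Xi = H16[1:4]

# Corrupt two bits of the first memory, then sweep once over all neurons
sigma = Xi[0].at[jnp.array([0, 5])].multiply(-1.0)
for i in range(Xi.shape[1]):
    sigma = sigma.at[i].set(dam_update(Xi, sigma, i, n=3))
```

After one sweep the query matches the first memory again: the overlap with the correct memory is much larger than the others, and *f*<sub>*n*</sub> amplifies that gap before the sign is taken.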

Let’s implement the DenseAM model. The primary difference from the CHN
is that now we generalize the quadratic energy to a (possibly rectified)
polynomial energy.

``` python
class PolynomialDenseAM(BinaryAM):
    Xi: jax.Array # (K, D) Memory patterns 
    n: int # Power of polynomial F
    rectified: bool = True # Whether to rectify inputs to F

    def F_n(self, sims): 
        """Rectified polynomial of degree `n` for energy"""
        sims = sims.clip(0) if self.rectified else sims
        return 1 / self.n * sims ** self.n

    def energy(self, sigma): 
        return -jnp.sum(self.F_n(self.Xi @ sigma))
```

A simple change to using a polynomial of degree 6 instead of the CHN’s
quadratic energy function allows us to store and retrieve our desired
eevee even with up to *K* = 100 patterns.

``` python
# Increase the number of stored patterns!
Xi = data[eevee_pichu_idxs]
Xi = jnp.concatenate([Xi, jr.choice(jr.PRNGKey(10), data, shape=(98,), replace=False)])
fig1, ax1 = show_im(Xi, figsize=(7,7));
ax1.set_title("Stored patterns")
dam = PolynomialDenseAM(Xi, n=6, rectified=True)

fname = f'dam_recovery_n_{dam.n}_K_{Xi.shape[0]}'

sigma_og = Xi[0]
sigma_noisy = flip_some_bits(jr.PRNGKey(0), sigma_og, 0.2)
sigma_final, frames, energies = cached_recall(dam, fname, sigma_noisy, nsteps=20000, key=jr.PRNGKey(5))

fig2, axes2 = show_recall_output(sigma_og, sigma_noisy, sigma_final, energies, show_original=False)
fig2.suptitle(f"DenseAM(n={dam.n}, K={Xi.shape[0]})")
plt.subplots_adjust(top=0.75)
plt.show()

video, video_fname = show_cached_recall_animation(fname, steps_per_sample=32)
Markdown(f"![]({video_fname})")
```

    Loading cached recall data

<img src="00_dense_storage_files/figure-commonmark/cell-16-output-2.png"
width="558" height="580" />

<img src="00_dense_storage_files/figure-commonmark/cell-16-output-3.png"
width="650" height="295" />

![](cache/00_dense_storage/dam_recovery_n_6_K_100.mp4)

A higher degree polynomial gives us more **storage capacity**: we can
store more patterns in the network and still retrieve them reliably.
Note that the higher the degree *n*, the narrower the basins of
attraction, which makes it possible to pack more patterns into the
energy landscape.

## Gotta catch ’em all!

Let’s try to store all `1024` pokemon patterns in our network and
retrieve them (though we will only show retrieval for a subset of them
for computational reasons). To do this, we’ll need very large values of
*n*, which invites numeric overflow: raising large overlaps to a high
power quickly produces `inf` energies in floating point.
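
To get a feel for how quickly this happens, the sketch below evaluates a single term of the polynomial energy for a query that perfectly matches a stored pattern, so the overlap equals *D* = 2304. It assumes JAX's default 32-bit floats; the variable names are illustrative.

``` python
import jax.numpy as jnp

D = 2304  # overlap of a query that perfectly matches a stored pattern
overlap = jnp.float32(D)

term_n6 = overlap ** 6 / 6     # about 1.5e20 / 6, still representable in float32
term_n12 = overlap ** 12 / 12  # about 2.2e40 / 12, beyond the float32 maximum (~3.4e38)
```

Already at *n* = 12 the term overflows to `inf`, at which point energy comparisons between flipped and unflipped bits stop being meaningful.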

We’ll implement an exponential version of the DenseAM \[3\].
Specifically, we will use a numerically stable `logsumexp` version
\[4\].

<span id="eq-dam-energy-logsumexp">
$$
\begin{align*}
E_\text{eDAM}(\sigma) &= -\log \sum_{\mu=1}^K \exp \left(\beta \sum_{i=1}^D \xi^\mu_i \sigma_i\right)
\end{align*}
 \qquad(5)$$
</span>

where increasing the inverse temperature *β* has a similar effect to
increasing *n* in the DenseAM polynomial energy function. Because
`log` is a monotonically increasing function, the energy minima of the
exponential energy function (the same expression without the `log`) are
preserved, while the energy itself becomes more numerically stable.
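
The stability comes from how `logsumexp` is computed: the largest exponent is subtracted before exponentiating and added back afterwards, so no intermediate value overflows. The sketch below contrasts a naive evaluation with `jax.nn.logsumexp` on made-up overlaps, assuming JAX's default 32-bit floats.

``` python
import jax
import jax.numpy as jnp

beta = 1.0
sims = jnp.array([2304.0, 100.0, -50.0])  # made-up overlaps with three memories

# Naive evaluation: exp(2304) overflows float32, so the log of the sum is inf
naive = jnp.log(jnp.sum(jnp.exp(beta * sims)))

# Stable evaluation: internally shifts by the maximum before exponentiating
stable = jax.nn.logsumexp(beta * sims)
```

Here the stable result is essentially the largest overlap, which also illustrates that for large *β* the energy is dominated by the best-matching memory.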

``` python
class ExponentialDenseAM(BinaryAM):
    Xi: jax.Array # (K, D) Memory patterns 
    beta: float = 1.0 # Inverse temperature parameter

    def energy(self, sigma):
        return -jax.nn.logsumexp(self.beta * self.Xi @ sigma, axis=-1)
```

``` python
# Show larger batch retrieval
Xi = data[:1024]
Nshow = 255
Xi_show = jnp.concatenate([data[eevee_pichu_idxs], jr.choice(jr.PRNGKey(10), Xi, shape=(Nshow - len(eevee_pichu_idxs),), replace=False)])
fig1, ax1 = show_im(Xi_show, figsize=(8,8));
ax1.set_title(f"Random sample of {Nshow} stored patterns")
print(f"Storing {Xi.shape[0]} patterns")
```

    Storing 1024 patterns

<img src="00_dense_storage_files/figure-commonmark/cell-18-output-2.png"
width="640" height="588" />

<div>

> **Memory usage warning**
>
> **Depending on your RAM availability, the following cell may crash
> your session**. Decrease to e.g., `nh = nw = 5` to avoid this (or
> upgrade your runtime on Colab for more resources).
>
> ![](assets/figs/ResourceExhaustedWarning.png)

</div>

<details class="code-fold">
<summary>Code</summary>

``` python
key1, key2 = jr.split(jr.PRNGKey(3))
nh = nw = 10
N = nh * nw # Sample N patterns to show in grid
sigma_og = jnp.concatenate([
    data[eevee_pichu_idxs], 
    jr.choice(jr.PRNGKey(10), data, shape=(N - len(eevee_pichu_idxs),), replace=False)])
sigma_noisy = flip_some_bits(key2, sigma_og, 0.25)

edam = ExponentialDenseAM(Xi, beta=50.)

cache_name = "logsumexp_batched"
keys = jr.split(key2, sigma_noisy.shape[0])
npz_fname = Path(CACHE_DIR) / (cache_name + ".npz")
if os.path.exists(npz_fname) and CACHE_RECALL:
    npz_data = np.load(npz_fname)
    sigma_final, frames, energies = npz_data['sigma_final'], npz_data['frames'], npz_data['energies']
else:
    sigma_final, frames, energies = jax.vmap(ft.partial(cached_recall, nsteps=16000, save=False), in_axes=(None, None, 0,0))(edam, cache_name, sigma_noisy, keys)
    np.savez(npz_fname, sigma_final=sigma_final, frames=frames, energies=energies)
```

</details>

And of course, what’s the fun if we can’t animate the retrieval process?

![](cache/00_dense_storage/logsumexp_batched.gif)

<div id="refs" class="references csl-bib-body" entry-spacing="0">

<div id="ref-hopfield1982neural" class="csl-entry">

<span class="csl-left-margin">\[1\]
</span><span class="csl-right-inline">J. J. Hopfield, “Neural networks
and physical systems with emergent collective computational abilities,”
*Proceedings of the National Academy of Sciences*, vol. 79, no. 8, pp.
2554–2558, 1982.</span>

</div>

<div id="ref-krotov2016dense" class="csl-entry">

<span class="csl-left-margin">\[2\]
</span><span class="csl-right-inline">D. Krotov and J. J. Hopfield,
“Dense associative memory for pattern recognition,” *Advances in Neural
Information Processing Systems*, vol. 29, 2016.</span>

</div>

<div id="ref-demicirgil2017model" class="csl-entry">

<span class="csl-left-margin">\[3\]
</span><span class="csl-right-inline">M. Demircigil, J. Heusel, M. Löwe,
S. Upgang, and F. Vermet, “On a model of associative memory with huge
storage capacity,” *Journal of Statistical Physics*, vol. 168, no. 2,
pp. 288–299, May 2017, doi:
[10.1007/s10955-017-1806-y](https://doi.org/10.1007/s10955-017-1806-y).</span>

</div>

<div id="ref-ramsauer2021hopfield" class="csl-entry">

<span class="csl-left-margin">\[4\]
</span><span class="csl-right-inline">H. Ramsauer *et al.*, “Hopfield
networks is all you need,” 2021, \[Online\]. Available:
<https://openreview.net/forum?id=tL89RnzIiCd>.</span>

</div>

</div>
