# Build Your Own Imagen Text-to-Image Model

*by* Phil Tadros

DALL-E 2 was released earlier this year, taking the world by storm with its impressive text-to-image capabilities. With just an input description of a scene, DALL-E 2 outputs realistic and semantically plausible images of the scene, such as those generated from the input caption **"a bowl of soup that is a portal to another dimension as digital art"**.

Just a month after DALL-E 2's release, Google announced a competing model, **Imagen**, that human raters in Google's evaluations found to be *even better than DALL-E 2*.

The impressive results of both DALL-E 2 and Imagen rely on cutting-edge Deep Learning research. While necessary for achieving State-of-the-Art results, the use of such cutting-edge research in models like Imagen makes them harder for non-specialist researchers to understand, which in turn hinders the widespread adoption of these models and techniques.

Therefore, in the spirit of democratization, in this article we'll learn **how to build Imagen with PyTorch**. In particular, we'll construct a minimal implementation of Imagen, called **MinImagen**, that isolates the salient features of Imagen so that we can focus on understanding Imagen's integral operating principles, disentangling the implementation aspects that are essential from those that are incidental.

Package Note

N.B. If you are not interested in implementation details and only want to use MinImagen, it has been packaged up and can be installed with

`pip install minimagen`

Check out the section below or the corresponding GitHub repository for usage tips. The documentation contains additional details and information about using the package.

## Introduction

Text-to-image models have made great strides in the past few years, as evidenced by models like GLIDE, DALL-E 2, Imagen, and more. These strides are largely due to the recent flourishing wave of research into Diffusion Models, a new paradigm/framework for generative models.

While there are some good resources on the *theoretical* aspects of Diffusion Models and text-to-image models, *practical* information on how to actually build these models is not as abundant. This is especially true for models that incorporate Diffusion Models as just *one* component of a larger system, which is common in text-to-image models, like the encoding-prior-generator chain in **DALL-E 2** or the super-resolution chain in **Imagen**.

**MinImagen** strips off the bells and whistles of current best practices in order to isolate Imagen's salient features for educational purposes. The remainder of this article is structured as follows:

- **Review of Imagen / Diffusion Models:** In order to orient ourselves before we begin to code, we'll briefly review both Imagen itself and Diffusion Models more generally. These reviews are meant to serve only as a refresher, so you should already have a working understanding of both topics when reading them. You can check out our Introduction to Diffusion Models for Machine Learning and our dedicated guide to How Imagen Actually Works to learn more.
- **Building the Diffusion Model:** After our recap, we'll first build the `GaussianDiffusion` class in PyTorch, which defines the Diffusion Models used in Imagen.
- **Building the Denoising U-Net:** We'll then build the denoising U-Net on which the Diffusion Models rely, implemented in the `Unet` class.
- **Building Imagen:** Next, we'll put all of these pieces together using a T5 text encoder and a Diffusion Model chain in order to build our (Min)Imagen class, `Imagen`.
- **Using Imagen:** Finally, we'll learn how to train and sample from Imagen once it is fully defined.

Model Weights

Stay tuned! We'll be training MinImagen over the coming weeks and releasing a checkpoint so you can generate your own images. Make sure to follow our newsletter to stay up to date on our content releases.

Without further ado, it's time to jump into the recaps of both Imagen and Diffusion Models. If you are already familiar with Imagen and Diffusion Models from a theoretical perspective and want to get straight to the PyTorch implementation details, skip ahead to the *Build Your Own Imagen in PyTorch* section.

## What is Imagen?

Imagen is a text-to-image model that was released by Google just a couple of months ago. It takes in a **textual prompt** and outputs an **image** which reflects the semantic information contained within the prompt.

To generate an image, Imagen first uses a **text encoder** to generate a representative encoding of the prompt. Next, an **image generator**, conditioned on the encoding, starts with Gaussian noise ("TV static") and progressively denoises it to generate a small image that reflects the scene described by the caption. Finally, two **super-resolution** models sequentially upscale the image to higher resolutions, again conditioning on the encoding information.

The text encoder is a **pre-trained T5 text encoder** that is frozen during training. Both the base image generation model and the super-resolution models are **Diffusion Models**.

Want a more detailed look at how Imagen works?

Check out our dedicated article for a deep dive into Imagen.

## What is a Diffusion Model?

Diffusion Models are a class of generative models, meaning that they are used to *generate* novel data, often images. Diffusion Models train by corrupting training images with Gaussian noise over a series of timesteps, and then learning to undo this noising process.

In particular, a model is trained to predict the noise component of an image at a given timestep.

Once trained, this denoising model can be iteratively applied to randomly sampled Gaussian noise, "denoising" it in order to generate a novel image.

Diffusion Models constitute a sort of *metamodel* that orchestrates the training of *another* model: **the noise prediction model**. We therefore still have to decide what type of model to actually use for the noise prediction itself. Generally, U-Nets are chosen for this role. The U-Net in Imagen has a structure like this:

The architecture is based on the model from the *Diffusion Models Beat GANs on Image Synthesis* paper. For **MinImagen**, we make some small changes to this architecture, including

- Removing the global attention layer (not pictured),
- Replacing the attention layers with transformer encoders, and
- Placing the transformer encoders at the end of the sequence at each layer rather than in between the residual blocks, in order to allow for a variable number of residual blocks.

Want a more detailed look at how Diffusion Models work?

Check out our dedicated article for a deep dive into Diffusion Models.

## Build Your Own Imagen in PyTorch

With our Imagen/Diffusion Model recap complete, we are finally ready to start building out our Imagen implementation. To get started, first open up a terminal and clone the project repository:

`git clone https://github.com/AssemblyAI-Examples/MinImagen.git`

In this tutorial, we'll **isolate the important parts of the source code** that are relevant to the Imagen implementation itself, omitting details like argument validation, device handling, etc. Even a minimal implementation of Imagen is relatively large, so this approach is necessary in order to isolate instructive information. **MinImagen's source code is thoroughly commented** (with associated documentation), so information regarding any omitted details should be easy to find.

Each big component of the project (the Diffusion Model, the Denoising U-Net, and Imagen) has been placed into its own section below. We'll start by building the `GaussianDiffusion` class.

Attribution Note

This implementation is largely a simplified version of Phil Wang's Imagen implementation, which is available on GitHub.

## Building the Diffusion Model

The Diffusion Model class `GaussianDiffusion` can be found in `minimagen.diffusion_model`. A summary is provided at the end of this section.

### Initialization

The `GaussianDiffusion` initialization function takes only one argument: the number of timesteps in the diffusion process.

```
import torch
import torch.nn as nn
import torch.nn.functional as F

class GaussianDiffusion(nn.Module):
    def __init__(self, *, timesteps: int):
        super().__init__()
```

First, Diffusion Models require a **variance schedule**, which specifies the variance of the Gaussian noise that is added to the image at a given timestep in the diffusion process. The variance schedule should be increasing, but there is some flexibility in how it is defined. For our purposes we implement the variance schedule from the original Denoising Diffusion Probabilistic Models (DDPM) paper, which is a linearly spaced schedule from **0.0001 at t=0** to **0.02 at t=T**. The code below rescales both endpoints by `1000 / timesteps`, so these values apply exactly when T = 1000.

```
class GaussianDiffusion(nn.Module):
    def __init__(self, *, timesteps: int):
        super().__init__()
        self.num_timesteps = timesteps
        # Rescale the endpoints so that timesteps=1000 gives the DDPM range 1e-4 -> 0.02
        scale = 1000 / timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
```

From this schedule, we calculate a few values (again specified in the DDPM paper) that will be used in calculations later:

```
class GaussianDiffusion(nn.Module):
    def __init__(self, *, timesteps: int):
        # ...
        alphas = 1. - betas
        # Cumulative product of the alphas up to each timestep
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # The same cumulative product shifted by one step, with 1 at the first step
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)
```

The remainder of the initialization function registers the above values and a few derived values as buffers, which are like parameters except that they *do not require gradients*. **All of the values are ultimately derived from the variance schedule** and exist to make some calculations easier and cleaner down the line. The specifics of calculating the derived values are not important, but we'll point out below any time one of these derived values is used.
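The derived buffers are left implicit above. As a rough guide, here is a self-contained sketch of how they can be computed from the linear schedule using the standard DDPM formulas. The class name `DiffusionBuffers` is hypothetical, and while the buffer names match the ones used later in this article, the exact set MinImagen registers may differ; `minimagen.diffusion_model` is the reference.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DiffusionBuffers(nn.Module):
    """Hypothetical stand-in for the buffer setup in GaussianDiffusion.__init__."""

    def __init__(self, timesteps: int):
        super().__init__()
        self.num_timesteps = timesteps
        scale = 1000 / timesteps
        betas = torch.linspace(scale * 0.0001, scale * 0.02, timesteps, dtype=torch.float64)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        # Variance of the posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

        # Compute in float64, store as float32; buffers get no gradients
        # but still move with .to(device) and are saved in the state dict
        def register(name, value):
            self.register_buffer(name, value.to(torch.float32))

        register("betas", betas)
        register("alphas_cumprod", alphas_cumprod)
        register("alphas_cumprod_prev", alphas_cumprod_prev)
        register("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        register("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))
        register("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
        register("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1))
        register("posterior_variance", posterior_variance)
        # posterior_variance is exactly 0 at the first step, so clamp before the log
        register("posterior_log_variance_clipped",
                 torch.log(posterior_variance.clamp(min=1e-20)))
        register("posterior_mean_coef1",
                 betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
        register("posterior_mean_coef2",
                 (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod))
```

Computing in `float64` and only then casting to `float32` is a common choice because the cumulative product becomes very small at large timesteps.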

### Forward Diffusion Process

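Now we can move on to define `GaussianDiffusion`'s `q_sample` method, which is responsible for the forward diffusion process. Given an input image *x_0*, we noise it to a given timestep *t* in the diffusion process by sampling from the distribution below, where $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$ (the `alphas_cumprod` values computed in `__init__`):

$$q(x_t \mid x_0) = \mathcal{N}\big(x_t;\ \sqrt{\bar{\alpha}_t}\,x_0,\ (1 - \bar{\alpha}_t)\,\mathbf{I}\big)$$

Sampling from the above distribution is equivalent to the computation below, whose two coefficients are the `sqrt_alphas_cumprod` and `sqrt_one_minus_alphas_cumprod` buffers defined in `__init__`:

$$x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1 - \bar{\alpha}_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$$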

That is, **the noisy version of the image at time t can be sampled by simply adding noise to the image**, where both the original image and the noise are scaled by their respective coefficients as dictated by the timestep. Let's implement this calculation in PyTorch now by adding the method `q_sample` to the `GaussianDiffusion` class:

```
class GaussianDiffusion(nn.Module):
    def __init__(self, *, timesteps: int):
        ...  # as above

    def q_sample(self, x_start, t, noise=None):
        # Sample fresh Gaussian noise unless custom noise is supplied
        noise = default(noise, lambda: torch.randn_like(x_start))
        noised = (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )
        return noised
```

`x_start` is a PyTorch tensor of shape `(b, c, h, w)`, `t` is a PyTorch tensor of shape `(b,)` that gives, for each image, the timestep to which we want to noise that image, and `noise` allows us to optionally supply custom noise rather than sample Gaussian noise.

We simply perform and return the calculation in the equation above, using elements from the aforementioned buffers as coefficients. The `default` function falls back to sampling random Gaussian noise when `noise` is `None`, and `extract` extracts the correct values from the buffers according to `t`.
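`extract` and `default` are only described here. A minimal sketch of helpers with that behavior follows; it assumes `t` is an `int64` tensor of shape `(b,)`, and these versions are illustrative rather than copies of MinImagen's own helpers.

```python
import torch


def default(val, d):
    """Return val unless it is None; otherwise return d (calling it if callable)."""
    if val is not None:
        return val
    return d() if callable(d) else d


def extract(a: torch.Tensor, t: torch.Tensor, x_shape) -> torch.Tensor:
    """Select a[t[i]] for every batch element i and reshape the result to
    (b, 1, 1, ...) so it broadcasts against a tensor of shape x_shape."""
    b = t.shape[0]
    out = a.gather(-1, t)  # shape (b,)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


coeffs = torch.tensor([0.9, 0.5, 0.1])    # stand-in for a per-timestep buffer
x = torch.ones(2, 3, 4, 4)                # batch of two 3x4x4 "images"
t = torch.tensor([0, 2])                  # timestep for each image
scaled = extract(coeffs, t, x.shape) * x  # image 0 scaled by 0.9, image 1 by 0.1
noise = default(None, lambda: torch.randn_like(x))  # falls back to fresh noise
```

Reshaping to `(b, 1, 1, 1)` is what lets a single coefficient per image scale every channel and pixel of that image through broadcasting.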

### Reverse Diffusion Process

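Ultimately, our goal is to sample from this distribution:

$$q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\big(x_{t-1};\ \tilde{\mu}_t(x_t, x_0),\ \tilde{\beta}_t\,\mathbf{I}\big)$$

Given an image and its noised counterpart, **this distribution tells us how to take a step "back in time" in the diffusion process**, slightly denoising the noisy image. Similarly to above, sampling from this distribution is equivalent to calculating

$$x_{t-1} = \tilde{\mu}_t(x_t, x_0) + \sqrt{\tilde{\beta}_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$$

To perform this calculation, we require the distribution's mean and variance. The variance is a **deterministic function of the variance schedule** (with $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t$ their cumulative product):

$$\tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}\,\beta_t$$

On the other hand, the **mean depends on the original and noised images** (although the coefficients are again deterministic functions of the variance schedule). The form of the mean is:

$$\tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}\,x_0 + \frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\,x_t$$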

At inference, **we won't have x_0**, the "original image", because it is what we are **trying to generate** (a novel image). **This is where our trainable U-Net comes into the picture**: we use it to predict *x_0* from *x_t*.

In practice, better results are seen when the U-Net learns to predict the *noise component* of the image, from which we can calculate *x_0*. Once we have *x_0*, we can calculate the distribution mean with the formula above, giving us what we need to sample from the posterior (i.e. denoise the image back one timestep). Visually, the overall process looks like this:

The function to sample from the posterior (green block in the diagram) will be defined in the `Imagen` class, but we'll define the two remaining functions now. First, we implement the function that calculates *x_0* given a noised image and its noise component (purple block in the diagram). From above, we have:

$$x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1 - \bar{\alpha}_t}\,\epsilon$$

Rearranging to isolate *x_0* yields the following, where the two coefficients are the `sqrt_recip_alphas_cumprod` and `sqrt_recipm1_alphas_cumprod` buffers:

$$x_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}\,x_t - \sqrt{\frac{1}{\bar{\alpha}_t} - 1}\,\epsilon$$

That is, to calculate *x_0* we simply subtract the noise (predicted by the U-Net) from *x_t*, where both the noisy image and the noise itself are scaled by their respective coefficients as dictated by the timestep. Let's implement this function, `predict_start_from_noise`, in PyTorch now:

```
class GaussianDiffusion(nn.Module):
    def __init__(self, *, timesteps: int):
        ...  # as above

    def q_sample(self, x_start, t, noise=None):
        ...  # as above

    def predict_start_from_noise(self, x_t, t, noise):
        # Invert q_sample: x_0 = x_t / sqrt(a_bar) - sqrt(1 / a_bar - 1) * noise
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )
```
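Because `predict_start_from_noise` is the algebraic inverse of `q_sample` for a fixed noise sample, a quick sanity check is to noise a batch of images and then recover them using the true noise. The standalone functions below mirror the two methods under the assumption of the linear T = 1000 schedule, and work in `float64` to keep round-off small.

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()
sqrt_recip_alphas_cumprod = (1.0 / alphas_cumprod).sqrt()
sqrt_recipm1_alphas_cumprod = (1.0 / alphas_cumprod - 1).sqrt()


def extract(a, t, x_shape):
    # Gather one coefficient per example and reshape to (b, 1, 1, 1)
    return a.gather(-1, t).reshape(t.shape[0], *((1,) * (len(x_shape) - 1)))


def q_sample(x_start, t, noise):
    return (extract(sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)


def predict_start_from_noise(x_t, t, noise):
    return (extract(sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract(sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise)


torch.manual_seed(0)
x_start = torch.rand(4, 3, 8, 8, dtype=torch.float64) * 2 - 1  # images in [-1, 1]
noise = torch.randn_like(x_start)
t = torch.tensor([0, 10, 500, 999])
x_t = q_sample(x_start, t, noise)
x_rec = predict_start_from_noise(x_t, t, noise)  # should match x_start
```

During sampling the U-Net's *predicted* noise replaces the true noise, so the recovered *x_0* is only an estimate.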

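Now that we have a function for calculating *x_0*, we can go back and calculate the posterior mean and variance (yellow block in the diagram). We repeat their functional definitions from above, labeling the buffers defined in `__init__` that hold each coefficient:

$$\tilde{\mu}_t = \underbrace{\frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}}_{\text{posterior\_mean\_coef1}}\,x_0 + \underbrace{\frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}}_{\text{posterior\_mean\_coef2}}\,x_t, \qquad \tilde{\beta}_t = \underbrace{\frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}\,\beta_t}_{\text{posterior\_variance}}$$

Let's implement a function `q_posterior` to calculate these variables in PyTorch: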

```
class GaussianDiffusion(nn.Module):
    def __init__(self, *, timesteps: int):
        ...  # as above

    def q_sample(self, x_start, t, noise=None):
        ...  # as above

    def predict_start_from_noise(self, x_t, t, noise):
        ...  # as above

    def q_posterior(self, x_start: torch.Tensor, x_t: torch.Tensor, t: torch.Tensor, **kwargs):
        # Posterior mean: weighted combination of x_0 and x_t
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped
```

In practice, we return both the variance and the log of the variance (`posterior_log_variance_clipped`, where "clipped" means that values of 0 are pushed to 1e-20 before taking the log). The log of the variance is used for numerical stability in our calculations, which we'll point out later when relevant.
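The posterior-sampling step itself lives in the `Imagen` class. As a hedged sketch of how the outputs of `q_posterior` are typically consumed (following the common DDPM ancestral-sampling recipe, not necessarily MinImagen's exact code), the hypothetical function `p_sample_step` below takes one step from x_t to x_{t-1}, turning the clipped log variance into a standard deviation with `exp(0.5 * log_var)`.

```python
import torch


def p_sample_step(posterior_mean, posterior_log_variance, t):
    """One ancestral sampling step x_t -> x_{t-1}, given the posterior mean and
    clipped log variance (both broadcastable to the image batch)."""
    noise = torch.randn_like(posterior_mean)
    # Add no noise on the final step (t == 0) so the output is the mean itself
    nonzero_mask = (t != 0).float().reshape(-1, *((1,) * (posterior_mean.ndim - 1)))
    # exp(0.5 * log variance) is the standard deviation
    return posterior_mean + nonzero_mask * (0.5 * posterior_log_variance).exp() * noise


mean = torch.zeros(2, 3, 8, 8)            # posterior means for two images
log_var = torch.full((2, 1, 1, 1), -4.0)  # clipped log variances
t = torch.tensor([0, 10])
x_prev = p_sample_step(mean, log_var, t)  # image 0 equals its mean; image 1 is noised
```

Working from the log variance avoids taking a square root of the (clamped) near-zero variance at the first timestep.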

### Summary

To recap, in this section we defined the `GaussianDiffusion` class, which is responsible for defining the diffusion process operations. We first implemented `q_sample`, which performs the *forward diffusion process*, noising images to a given timestep in the diffusion process. We also implemented `predict_start_from_noise` and `q_posterior`, which are used to calculate parameters used in the *reverse diffusion process*.

## Building the Noise Prediction Model

Now it's time to build our noise prediction model: the U-Net. This model is fairly complicated, so to be concise we'll examine its forward pass, introducing objects from `__init__` as they become relevant. Examining only the forward pass will help us understand how the U-Net works operationally while omitting unnecessary details that are not instructive for learning how to build Imagen.

The U-Net class `Unet` can be found in `minimagen.Unet`. A summary is provided at the end of this section.

### Overview

Recall that the U-Net architecture for Imagen resembles the one shown earlier. We make a few modifications, most notably placing the attention block (which for us is a Transformer encoder) at the *end* of each layer in the U-Net.

### Generating Time Conditioning Vectors

Remember that our U-Net is a *conditional* model, meaning it depends on our input text captions. Without this conditioning, *there would be no way to tell the model what we want to be present* in the generated images. Additionally, since we use the same U-Net for all timesteps, we need to condition on the timestep information so the model knows what magnitude of noise it should be removing at any given time (remember, our variance schedule varies with *t*). Let's take a look at how we generate this time conditioning signal now. A diagram of these calculations can be seen at the end of this section.

As input to the model we receive a **time vector** of shape `(b,)`, which provides the timestep for each image in the batch. We first pass this vector through a module which generates **hidden states** from it:

```
class Unet(nn.Module):
    def __init__(self, *args, **kwargs):
        # ...
        self.to_time_hiddens = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, time_cond_dim),
            nn.SiLU()
        )

    def forward(self, *args, **kwargs):
        time_hiddens = self.to_time_hiddens(time)
```

First, a unique **positional encoding vector** is generated for each timestep (`SinusoidalPosEmb()`), which maps the integer value of the timestep for a given image into a representative vector that we can use for timestep conditioning. Next, these encodings are projected to a higher-dimensional space (`time_cond_dim`) and passed through the `SiLU` nonlinearity. The result is a tensor of size `(b, time_cond_dim)` that constitutes our hidden states for the timesteps.
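For reference, a common way to implement a sinusoidal timestep embedding (in the style of Transformer positional encodings) is sketched below, together with the `to_time_hiddens` stack from the snippet above. The frequency layout (sines in the first half, cosines in the second, base 10000) and the sizes `dim=128` and `time_cond_dim=512` are assumptions for illustration; MinImagen's `SinusoidalPosEmb` may differ in detail.

```python
import math

import torch
import torch.nn as nn


class SinusoidalPosEmb(nn.Module):
    """Map integer timesteps of shape (b,) to embeddings of shape (b, dim)."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half_dim = self.dim // 2
        # Geometric progression of frequencies from 1 down to 1/10000
        freqs = torch.exp(
            torch.arange(half_dim, device=t.device, dtype=torch.float32)
            * -(math.log(10000) / (half_dim - 1))
        )
        args = t.float()[:, None] * freqs[None, :]      # (b, half_dim)
        return torch.cat((args.sin(), args.cos()), dim=-1)  # (b, dim)


dim, time_cond_dim = 128, 512  # assumed sizes
to_time_hiddens = nn.Sequential(
    SinusoidalPosEmb(dim),
    nn.Linear(dim, time_cond_dim),
    nn.SiLU(),
)
time = torch.randint(0, 1000, (8,))   # one timestep per image in the batch
time_hiddens = to_time_hiddens(time)  # shape (8, 512)
```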

These hidden states are then used in **two ways**. First, a time conditioning tensor `t` is generated from `time_hiddens` with a simple linear layer; we'll use it to provide timestep conditioning at each step in the U-Net. Second, time *tokens* `time_tokens` are generated, again from `time_hiddens` with a simple linear layer, and are concatenated to the main text-conditioning tokens we'll generate momentarily. We need both because the **time conditioning is necessarily provided everywhere** in the U-Net (via simple addition), while the main conditioning tokens are used **only in the cross-attention** operation in specific blocks/layers of the U-Net. Let's see how to implement these functions in PyTorch:

```
from einops.layers.torch import Rearrange

class Unet(nn.Module):
    def __init__(self, *args, **kwargs):
        # ...
        self.to_time_cond = nn.Sequential(
            nn.Linear(time_cond_dim, time_cond_dim)
        )
        self.to_time_tokens = nn.Sequential(
            nn.Linear(time_cond_dim, cond_dim * NUM_TIME_TOKENS),
            Rearrange('b (r d) -> b r d', r=NUM_TIME_TOKENS)
        )

    def forward(self, *args, **kwargs):
        # ...
        t = self.to_time_cond(time_hiddens)
        time_tokens = self.to_time_tokens(time_hiddens)
```

The shape of `t` is `(b, time_cond_dim)`, the same as `time_hiddens`. The shape of `time_tokens` is `(b, NUM_TIME_TOKENS, cond_dim)`, where `NUM_TIME_TOKENS` defines how many time tokens are generated to be concatenated onto the main conditioning text tokens. The default value is 2. The einops `Rearrange` layer reshapes the tensor from `(b, NUM_TIME_TOKENS*cond_dim)` to `(b, NUM_TIME_TOKENS, cond_dim)`.
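To make these shapes concrete, the sketch below reproduces the two projections with assumed sizes (`b=8`, `time_cond_dim=512`, `cond_dim=512`, `NUM_TIME_TOKENS=2`). It swaps the einops `Rearrange` layer for a plain `reshape`, which yields the same layout for the pattern `'b (r d) -> b r d'` on a contiguous tensor.

```python
import torch
import torch.nn as nn

# Assumed sizes for illustration
b, time_cond_dim, cond_dim, NUM_TIME_TOKENS = 8, 512, 512, 2

to_time_cond = nn.Linear(time_cond_dim, time_cond_dim)
to_time_tokens = nn.Linear(time_cond_dim, cond_dim * NUM_TIME_TOKENS)

time_hiddens = torch.randn(b, time_cond_dim)
t = to_time_cond(time_hiddens)       # (b, time_cond_dim), added throughout the U-Net
flat = to_time_tokens(time_hiddens)  # (b, NUM_TIME_TOKENS * cond_dim)
# Same result as Rearrange('b (r d) -> b r d', r=NUM_TIME_TOKENS)
time_tokens = flat.reshape(b, NUM_TIME_TOKENS, cond_dim)
```

Token `r` is simply the `r`-th contiguous chunk of `cond_dim` features in the linear layer's output.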

The time encoding process is summarized in this figure:

### Generating Text Conditioning Vectors

Now it's time to generate our **text conditioning** objects. From our text encoder we have two tensors: the **text embeddings** of the batch captions, `text_embeds`, and the **text mask**, `text_mask`, which tells us how many words are in each caption in the batch. These tensors have sizes `(b, max_words, enc_dim)` and `(b, max_words)` respectively, where `max_words` is the number of words in the longest caption in the batch, and `enc_dim` is the encoding dimension of the text encoder.

We also incorporate classifier-free guidance at this point, so, given all the moving parts, let's take a look at a visual example to understand what is happening at a high level. All of the calculations are again summarized in a diagram below.

#### Visual Example

Let's assume that we have three captions: '*a very big red house*', '*a man*', and '*a happy dog*'. Our text encoder provides the following:

We project the embedding vectors to a higher dimension (greater horizontal width), and pad both the mask and embedding tensors (extra entries vertically) to the maximum number of words allowed in a caption, **a value we choose** and which we set to 6 here:

From here, we incorporate **classifier-free guidance** by *randomly deciding which batch instances to drop* with a fixed probability. Let's assume that the last instance is dropped, which is implemented by altering the text mask.

Continuing with classifier-free guidance, we generate NULL vectors to use for the dropped elements.

We replace the encodings with NULL wherever the text mask is purple:

To get the final main conditioning token `c`, we simply concatenate the `time_tokens` generated above to these text conditioning tensors. The concatenation happens along the num_tokens/word dimension, leaving a final main conditioning token of shape `(b, NUM_TIME_TOKENS + max_text_len, cond_dim)`.

Finally, we also mean pool across the word dimension to acquire a tensor of shape `(b, cond_dim)`, and then project to the time conditioning vector dimension to yield a tensor of shape `(b, 4*cond_dim)`. After dropping the necessary instances along the batch dimension according to the classifier-free guidance vector, we **add** this to `t` to get the final timestep conditioning `t`.

#### Corresponding Code

The corresponding code for these operations is a bit cumbersome and simply reimplements the above process, so it is omitted here. Feel free to check out the `Unet._text_condition` method in the source code to explore how this function is implemented. The image below summarizes the entire conditioning generation process; keeping it open while reading the code can help you stay oriented.
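If you would like a concrete reference point without opening the source, the following is a simplified, hypothetical sketch of the steps described above: project and pad the text embeddings, drop whole captions with probability `cond_drop_prob`, swap in learned NULL embeddings, mean-pool and add the result to `t`, and concatenate the time tokens to form `c`. The class name, layer choices, and parameter names are assumptions; `Unet._text_condition` remains the reference.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TextCondition(nn.Module):
    """Hypothetical, simplified version of the text-conditioning step."""

    def __init__(self, enc_dim, cond_dim, time_cond_dim, max_text_len):
        super().__init__()
        self.max_text_len = max_text_len
        # Project encoder outputs (enc_dim) to the conditioning width (cond_dim)
        self.text_to_cond = nn.Linear(enc_dim, cond_dim)
        # Learned NULL embeddings substituted for dropped or padded text
        self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
        self.null_text_hidden = nn.Parameter(torch.randn(1, time_cond_dim))
        # Pooled text -> time-conditioning width, so it can be added to t
        self.to_text_non_attn_cond = nn.Sequential(
            nn.LayerNorm(cond_dim),
            nn.Linear(cond_dim, time_cond_dim),
            nn.SiLU(),
            nn.Linear(time_cond_dim, time_cond_dim),
        )

    def forward(self, text_embeds, text_mask, time_tokens, t, cond_drop_prob=0.1):
        b = text_embeds.shape[0]
        text_tokens = self.text_to_cond(text_embeds)[:, : self.max_text_len]
        text_mask = text_mask[:, : self.max_text_len]

        # Pad the word dimension up to max_text_len (padding is masked out)
        pad = self.max_text_len - text_tokens.shape[1]
        text_tokens = F.pad(text_tokens, (0, 0, 0, pad))
        text_mask = F.pad(text_mask.float(), (0, pad)).bool()

        # Classifier-free guidance: drop whole captions with prob cond_drop_prob
        keep = torch.rand(b, device=text_embeds.device) >= cond_drop_prob  # (b,)
        keep_tokens = text_mask & keep[:, None]                            # (b, max_text_len)
        text_tokens = torch.where(keep_tokens[..., None], text_tokens, self.null_text_embed)

        # Mean-pool over words, project, apply the same dropping, add to t
        text_hiddens = self.to_text_non_attn_cond(text_tokens.mean(dim=-2))
        text_hiddens = torch.where(keep[:, None], text_hiddens, self.null_text_hidden)
        t = t + text_hiddens

        # Main conditioning tokens: (b, NUM_TIME_TOKENS + max_text_len, cond_dim)
        c = torch.cat((time_tokens, text_tokens), dim=-2)
        return t, c
```

Setting `cond_drop_prob=1.0` replaces every caption with the NULL embeddings, which is the "unconditional" branch classifier-free guidance needs at sampling time.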

### Building the U-Net

Now that we have the two conditioning tensors we need (the main conditioning tensor `c`, applied via attention, and the time conditioning tensor `t`, applied via addition), we can move on to defining the U-Net itself. As above, we proceed by examining `Unet`'s `forward` method, introducing objects in `__init__` as needed.

#### Initial Convolution

First, we need to perform an initial convolution to bring our input images to the number of channels the network expects. We use `minimagen.layers.CrossEmbedLayer`, which is essentially an Inception layer.

```
from minimagen.layers import CrossEmbedLayer

class Unet(nn.Module):
    def __init__(self, *args, **kwargs):
        # ...
        self.init_conv = CrossEmbedLayer(channels, dim_out=dim, kernel_sizes=(3, 7, 15), stride=1)

    def forward(self, *args, **kwargs):
        # ...
        x = self.init_conv(x)
```
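As a rough illustration of what an Inception-style cross-embedding layer does, the sketch below runs parallel convolutions with different kernel sizes and concatenates their outputs along the channel axis. The channel split (half of `dim_out` for the smallest kernel, a quarter for the next, the remainder for the largest) and the padding rule are assumptions modeled on common implementations, not a copy of MinImagen's code.

```python
import torch
import torch.nn as nn


class CrossEmbedLayer(nn.Module):
    """Parallel convolutions with different kernel sizes, concatenated on channels."""

    def __init__(self, dim_in, kernel_sizes, dim_out=None, stride=2):
        super().__init__()
        dim_out = dim_out if dim_out is not None else dim_in
        kernel_sizes = sorted(kernel_sizes)
        num_scales = len(kernel_sizes)
        # Smaller kernels get more channels: dim_out/2, dim_out/4, ..., remainder
        dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
        dim_scales = [*dim_scales, dim_out - sum(dim_scales)]
        self.convs = nn.ModuleList([
            # (kernel - stride) // 2 padding keeps the size unchanged when stride=1
            nn.Conv2d(dim_in, dim_scale, kernel, stride=stride, padding=(kernel - stride) // 2)
            for kernel, dim_scale in zip(kernel_sizes, dim_scales)
        ])

    def forward(self, x):
        return torch.cat([conv(x) for conv in self.convs], dim=1)


layer = CrossEmbedLayer(3, dim_out=128, kernel_sizes=(3, 7, 15), stride=1)
y = layer(torch.randn(2, 3, 64, 64))  # (2, 128, 64, 64)
```

Mixing receptive fields of 3, 7 and 15 pixels lets the very first layer see both fine texture and larger structure.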

#### Initial ResNet Block

Next, we pass the images into the initial ResNet block (`minimagen.layers.ResnetBlock`) for this layer of the U-Net, called `init_block`.

```
from minimagen.layers import ResnetBlock

class Unet(nn.Module):
    def __init__(self, *args, **kwargs):
        # ...
        self.init_block = ResnetBlock(current_dim, current_dim, cond_dim=layer_cond_dim,
                                      time_cond_dim=time_cond_dim, groups=groups)

    def forward(self, *args, **kwargs):
        # ...
        x = self.init_block(x, t, c)
```

The ResNet block first passes the images through an initial `block1` (`minimagen.layers.block`), resulting in an output tensor of the same size as the input.

Next, residual cross attention (`minimagen.layers.CrossAttention`) is performed with the main conditioning tokens.

After that, we pass the time encodings through a simple MLP to obtain the correct dimensionality, and then split the result into two tensors of size `(b, c, 1, 1)`.

Finally, we pass the images through another convolution block that is identical to `block1`, except that it incorporates the timestep information via a scale-shift using the timestep embeddings.

The final output of `init_block` has the same shape as the input tensor.
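The scale-shift mechanism is easiest to see in code. Below is a hedged, simplified sketch: `Block` and `TimeConditionedResBlock` are illustrative names, the cross-attention step is omitted, and GroupNorm with 8 groups is an assumed default. The time MLP output of shape `(b, 2c)` is reshaped to `(b, 2c, 1, 1)` and split into `scale` and `shift` tensors of shape `(b, c, 1, 1)`, which modulate the normalized features as `x * (scale + 1) + shift`.

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    """GroupNorm -> optional scale-shift -> SiLU -> 3x3 convolution."""

    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.norm = nn.GroupNorm(groups, dim)
        self.act = nn.SiLU()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)

    def forward(self, x, scale_shift=None):
        x = self.norm(x)
        if scale_shift is not None:
            scale, shift = scale_shift
            x = x * (scale + 1) + shift
        return self.proj(self.act(x))


class TimeConditionedResBlock(nn.Module):
    """Hypothetical, simplified ResnetBlock without the cross-attention step."""

    def __init__(self, dim, time_cond_dim, groups=8):
        super().__init__()
        self.time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_cond_dim, dim * 2))
        self.block1 = Block(dim, dim, groups)
        self.block2 = Block(dim, dim, groups)

    def forward(self, x, t):
        # (b, time_cond_dim) -> (b, 2c) -> (b, 2c, 1, 1) -> two (b, c, 1, 1) tensors
        time_emb = self.time_mlp(t)[:, :, None, None]
        scale_shift = time_emb.chunk(2, dim=1)
        h = self.block1(x)
        h = self.block2(h, scale_shift=scale_shift)
        return h + x  # residual connection keeps the input shape


blk = TimeConditionedResBlock(dim=64, time_cond_dim=256)
x = torch.randn(2, 64, 16, 16)
t = torch.randn(2, 256)
out = blk(x, t)  # same shape as x
```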

#### Remaining ResNet Blocks

Next, we pass the images through a sequence of ResNet blocks that are identical to `init_block`, except that they only condition on the timestep. We save the outputs in `hiddens` for the skip connections later on.

```
class Unet(nn.Module):
    def __init__(self, *args, **kwargs):
        # ...
        self.resnet_blocks = nn.ModuleList(
            [
                ResnetBlock(current_dim, current_dim, time_cond_dim=time_cond_dim, groups=groups)
                for _ in range(layer_num_resnet_blocks)
            ]
        )

    def forward(self, *args, **kwargs):
        # ...
        hiddens = []
        for resnet_block in self.resnet_blocks:
            x = resnet_block(x, t)
            hiddens.append(x)  # saved for the skip connections
```

#### Final Transformer Block

After processing with the ResNet blocks, we optionally pass the images through a Transformer encoder (`minimagen.layers.TransformerBlock`).

```
from minimagen.layers import TransformerBlock

class Unet(nn.Module):
    def __init__(self, *args, **kwargs):
        # ...
        self.transformer_block = TransformerBlock(dim=current_dim, heads=ATTN_HEADS, dim_head=ATTN_DIM_HEAD)

    def forward(self, *args, **kwargs):
        # ...
        x = self.transformer_block(x)
        hiddens.append(x)
```

The transformer block applies multi-headed attention (purple block below), and then passes the output through a `minimagen.layers.ChanFeedForward` layer: a sequence of convolutions with layer norms and a GELU nonlinearity between them.


#### Downsample

As the final step for this layer of the U-Net, the images are downsampled to half their spatial width and height.

```
class Unet(nn.Module):
    def __init__(self, *args, **kwargs):
        # ...
        self.post_downsample = Downsample(current_dim, dim_out)

    def forward(self, *args, **kwargs):
        # ...
        x = self.post_downsample(x)
```

The downsampling operation is a single strided convolution:

```
def Downsample(dim, dim_out=None):
    dim_out = default(dim_out, dim)
    return nn.Conv2d(dim, dim_out, kernel_size=4, stride=2, padding=1)
```
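The choice `kernel_size=4, stride=2, padding=1` halves any even spatial size, because the convolution output size is floor((H + 2 * 1 - 4) / 2) + 1 = H / 2. A quick check with assumed sizes:

```python
import torch
import torch.nn as nn


def Downsample(dim, dim_out=None):
    dim_out = dim_out if dim_out is not None else dim
    # Output size: floor((H + 2*padding - kernel_size) / stride) + 1
    #            = floor((H + 2 - 4) / 2) + 1 = H / 2 for even H
    return nn.Conv2d(dim, dim_out, kernel_size=4, stride=2, padding=1)


x = torch.randn(1, 64, 32, 32)
y = Downsample(64, 128)(x)  # spatial size 32 -> 16, channels 64 -> 128
```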

#### Middle Layers

The above sequence of ResNet blocks, (optional) Transformer encoder, and downsampling is repeated for each layer of the U-Net until we reach the lowest spatial resolution / greatest channel depth. At this point, we pass the images through two more ResNet blocks, which **do** condition on the main conditioning tokens (like the `init_block` of each ResNet layer). Optionally, we pass the images through a residual Attention layer between these blocks.

```
class Unet(nn.Module):
    def __init__(self, *args, **kwargs):
        # ...
        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim=cond_dim, time_cond_dim=time_cond_dim,
                                      groups=resnet_groups[-1])
        self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c',
                                        Residual(Attention(mid_dim, heads=ATTN_HEADS,
                                                           dim_head=ATTN_DIM_HEAD))) if attend_at_middle else None
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim=cond_dim, time_cond_dim=time_cond_dim,
                                      groups=resnet_groups[-1])

    def forward(self, *args, **kwargs):
        # ...
        x = self.mid_block1(x, t, c)
        if exists(self.mid_attn):
            x = self.mid_attn(x)
        x = self.mid_block2(x, t, c)
```

#### Upsampling Trajectory

The upsampling trajectory of the U-Net is largely a mirror image of the downsampling trajectory, except that we (a) concatenate the corresponding skip connections from the downsampling trajectory before each ResNet block at any given layer, and (b) use an upsampling operation rather than a downsampling one. This upsampling operation is a nearest-neighbor upsampling followed by a convolution that preserves spatial size:

```
def Upsample(dim, dim_out=None):
    dim_out = default(dim_out, dim)
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(dim, dim_out, 3, padding=1)
    )
```

For the sake of brevity, the upsampling trajectory code is not repeated here, but it can be found in the source code.
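Although the full upsampling code is omitted, the skip-connection bookkeeping can be sketched in isolation. In the toy example below, plain convolutions stand in for the ResNet blocks (an assumption for brevity); each one consumes the channel-wise concatenation of the current features and the most recently saved hidden state, after which `Upsample` doubles the spatial size.

```python
import torch
import torch.nn as nn


def Upsample(dim, dim_out=None):
    dim_out = dim_out if dim_out is not None else dim
    return nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"),
                         nn.Conv2d(dim, dim_out, 3, padding=1))


# Hypothetical single up-layer: two skip tensors were saved on the way down
dim = 64
hiddens = [torch.randn(1, dim, 16, 16), torch.randn(1, dim, 16, 16)]
x = torch.randn(1, dim, 16, 16)

# Stand-ins for ResNet blocks that map 2*dim channels back to dim channels
blocks = nn.ModuleList([nn.Conv2d(2 * dim, dim, 3, padding=1) for _ in hiddens])
for block in blocks:
    skip = hiddens.pop()                    # last saved, first used: mirrors the down path
    x = block(torch.cat((x, skip), dim=1))  # concatenate along the channel axis

x = Upsample(dim, dim // 2)(x)  # (1, 32, 32, 32)
```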

At the end of the upsampling trajectory, a final convolution layer brings the images to the correct output channel depth (generally 3).

### Summary

To recap, in this section we defined the `Unet` class, which is responsible for defining the denoising U-Net that is trained via Diffusion. We first learned how to generate conditioning tensors for a given timestep and caption, and then how to incorporate this conditioning information into the U-Net's forward pass, which sends images through a series of ResNet blocks and Transformer encoders in order to predict the noise component of a given image.

## Building Imagen

To recap, we have built a `GaussianDiffusion` object which defines and implements the diffusion process "metamodel", which in turn uses our `Unet` class for training. Let's now take a look at how we put these pieces together to build Imagen itself. We'll again look at the two primary functions within Imagen: `forward` for training and `sample` for image generation, again introducing objects in `__init__` as needed.

The `Imagen` class can be found in `minimagen.Imagen`. A summary is provided at the end of this section.

### Imagen Forward Pass

The Imagen forward pass consists of (1) noising the training images, (2) predicting the noise components with the U-Net, and then (3) returning the loss between the predicted noise and the true noise.
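These three steps can be sketched independently of Imagen's text conditioning and super-resolution logic. The function below is a minimal, hypothetical training step that assumes the linear T = 1000 schedule, a mean squared error loss, and a noise-prediction model that takes only the noisy image; the real `Imagen.forward` also passes timesteps and text conditioning to the U-Net.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

T = 1000
betas = torch.linspace(1e-4, 0.02, T)  # assumed linear schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)


def diffusion_loss(model: nn.Module, images: torch.Tensor) -> torch.Tensor:
    """Hypothetical noise-prediction training loss for one batch of images."""
    b = images.shape[0]
    # (1) Noise each image to an independently sampled timestep
    t = torch.randint(0, T, (b,), device=images.device)
    noise = torch.randn_like(images)
    a_bar = alphas_cumprod.to(images.device)[t].reshape(b, 1, 1, 1)
    x_noisy = a_bar.sqrt() * images + (1 - a_bar).sqrt() * noise
    # (2) Predict the noise component (a real U-Net would also take t and text)
    pred_noise = model(x_noisy)
    # (3) Mean squared error between predicted and true noise
    return F.mse_loss(pred_noise, noise)


torch.manual_seed(0)
toy_model = nn.Conv2d(3, 3, kernel_size=3, padding=1)  # stand-in for the U-Net
images = torch.rand(4, 3, 32, 32) * 2 - 1              # images scaled to [-1, 1]
loss = diffusion_loss(toy_model, images)
loss.backward()
```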

To begin, we randomly sample the timesteps to which the training images will be noised, and then encode the conditioning text, placing the embeddings and masks on the same device as the input image tensor:

```
from minimagen.t5 import t5_encode_text

class Imagen(nn.Module):
    def __init__(self, timesteps):
        super().__init__()
        self.noise_scheduler = GaussianDiffusion(timesteps=timesteps)
        self.text_encoder_name = "t5_small"

    def forward(self, images, texts):
        b, device = images.shape[0], images.device
        # One random timestep per training image
        times = self.noise_scheduler.sample_random_times(b, device=device)
        text_embeds, text_masks = t5_encode_text(texts, name=self.text_encoder_name)
        text_embeds, text_masks = map(lambda t: t.to(images.device), (text_embeds, text_masks))
```

Recall that Imagen has a **base** model that generates small images and **super-resolution** models that upscale the images. We therefore need to resize the images to the correct size for the U-Net in use. If the U-Net is a super-resolution model, we additionally need to rescale the training images **first** down to the low-resolution conditioning size, **and then** up to the correct size for the U-Net. This simulates the upsampling of one U-Net’s output to the size of the next U-Net’s input in Imagen’s super-resolution chain (allowing the latter U-Net to condition on the former U-Net’s output).
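The down-then-up resizing can be sketched in a few lines. This is a minimal stand-in that uses PyTorch’s `F.interpolate` with bilinear interpolation in place of MinImagen’s `resize_image_to` (the library helper may interpolate or pad differently), and the function name `simulate_lowres_conditioning` is hypothetical:

```python
import torch
import torch.nn.functional as F

def simulate_lowres_conditioning(images, prev_image_size, target_image_size):
    """Downsample to the previous U-Net's resolution, then upsample back.

    Stand-in for resize_image_to (an assumption): bilinear F.interpolate.
    """
    low = F.interpolate(images, size=(prev_image_size, prev_image_size),
                        mode="bilinear", align_corners=False)
    return F.interpolate(low, size=(target_image_size, target_image_size),
                         mode="bilinear", align_corners=False)

images = torch.rand(4, 3, 64, 64) * 2 - 1  # training images scaled to [-1, 1]
cond = simulate_lowres_conditioning(images, prev_image_size=32, target_image_size=64)
```

The result has the super-resolution U-Net’s input size but only the detail that survives at the base resolution, which mirrors what that U-Net receives from the base model at sampling time.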

We also **add noise to the low-resolution conditioning** images for noise conditioning augmentation, choosing one noise level for the entire batch.

```
#...
from minimagen.helpers import resize_image_to
from einops import repeat

class Imagen(nn.Module):
    def __init__(self, timesteps):
        # ...
        self.lowres_noise_schedule = GaussianDiffusion(timesteps=timesteps)

    def forward(self, images, texts):
        # ...
        lowres_cond_img = lowres_aug_times = None
        if exists(prev_image_size):
            # Down to the previous U-Net's size, then back up to this U-Net's size
            lowres_cond_img = resize_image_to(images, prev_image_size, pad_mode="reflect")
            lowres_cond_img = resize_image_to(lowres_cond_img, target_image_size,
                                              pad_mode="reflect")
            # One noise-augmentation level shared by the whole batch
            lowres_aug_time = self.lowres_noise_schedule.sample_random_times(1, device=device)
            lowres_aug_times = repeat(lowres_aug_time, '1 -> b', b=b)
        images = resize_image_to(images, target_image_size)
```

Finally, we calculate and return the loss:

```
#...
from minimagen.helpers import resize_image_to
from einops import repeat

class Imagen(nn.Module):
    def __init__(self, timesteps):
        ...  # as before

    def forward(self, images, texts, unet):
        # ...
        return self._p_losses(unet, images, times, text_embeds=text_embeds,
                              text_mask=text_masks,
                              lowres_cond_img=lowres_cond_img,
                              lowres_aug_times=lowres_aug_times)
```

Let’s take a look at `_p_losses` to see how we calculate the loss.

First, we use the Diffusion Model forward process to noise both the input images and, if the U-Net is a super-resolution model, the low-resolution conditioning images as well.
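For reference, the forward process used by `q_sample` has the closed form x_t = √ᾱ_t · x_0 + √(1 − ᾱ_t) · ε. The sketch below implements that formula directly, assuming the common linear β schedule from 1e-4 to 0.02 over 1,000 steps; MinImagen’s `GaussianDiffusion` may use a different schedule:

```python
import torch

# Assumed linear beta schedule (a common DDPM default).
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # abar_t

def q_sample(x_start, t, noise):
    """Forward process: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise."""
    abar = alphas_cumprod[t].view(-1, 1, 1, 1)  # one value per batch element
    return abar.sqrt() * x_start + (1.0 - abar).sqrt() * noise

x0 = torch.rand(2, 3, 8, 8) * 2 - 1
noise = torch.randn_like(x0)
t = torch.tensor([0, T - 1])  # almost clean vs. almost pure noise
x_t = q_sample(x0, t, noise)
```

At t = 0 the result is nearly the clean image, while at the last timestep it is nearly pure noise.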

```
#...
class Imagen(nn.Module):
    def __init__(self, timesteps):
        ...  # as before

    def _p_losses(self, unet, x_start, times, *, text_embeds=None, text_mask=None,
                  lowres_cond_img=None, lowres_aug_times=None):
        # ...
        # Noise the training images to the sampled timesteps
        noise = torch.randn_like(x_start)
        x_noisy = self.noise_scheduler.q_sample(x_start=x_start, t=times, noise=noise)
        # For super-resolution U-Nets, also noise the low-resolution conditioning images
        lowres_cond_img_noisy = None
        if exists(lowres_cond_img):
            lowres_aug_times = default(lowres_aug_times, times)
            lowres_cond_img_noisy = self.lowres_noise_schedule.q_sample(
                x_start=lowres_cond_img, t=lowres_aug_times,
                noise=torch.randn_like(lowres_cond_img))
```

Next, we use the U-Net to **predict the noise component** of the noisy images, taking in text embeddings as conditioning information, in addition to the low-resolution images if the U-Net is for super-resolution. `cond_drop_prob` gives the probability of dropout for classifier-free guidance.
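Conditioning dropout for classifier-free guidance can be pictured as randomly swapping a sample’s text embedding for a “null” embedding. In MinImagen this happens inside the U-Net given `cond_drop_prob`; the helper below, `drop_text_conditioning`, is a hypothetical standalone version, and the all-zeros null embedding stands in for what would normally be a learned parameter:

```python
import torch

def drop_text_conditioning(text_embeds, null_embed, cond_drop_prob):
    """Replace each sample's text embedding with null_embed with
    probability cond_drop_prob (classifier-free guidance training)."""
    batch = text_embeds.shape[0]
    # keep_mask[i] is True when sample i keeps its caption conditioning
    keep_mask = torch.rand(batch, device=text_embeds.device) >= cond_drop_prob
    keep_mask = keep_mask.view(batch, 1, 1)  # broadcast over (tokens, dim)
    return torch.where(keep_mask, text_embeds, null_embed.expand_as(text_embeds))

emb = torch.randn(4, 7, 16)   # (batch, tokens, embedding dim)
null = torch.zeros(7, 16)     # hypothetical null embedding
all_kept = drop_text_conditioning(emb, null, cond_drop_prob=0.0)
all_dropped = drop_text_conditioning(emb, null, cond_drop_prob=1.0)
```

Training on a mix of conditioned and unconditioned samples is what later lets a single network produce both predictions needed for guided sampling.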

```
#...
class Imagen(nn.Module):
    def __init__(self, timesteps):
        # ...
        self.cond_drop_prob = 0.1

    def _p_losses(self, unet, x_start, times, *, text_embeds=None, text_mask=None,
                  lowres_cond_img=None, lowres_aug_times=None):
        # ...
        # Predict the noise component, conditioning on text (and the low-res image)
        pred = unet.forward(
            x_noisy,
            times,
            text_embeds=text_embeds,
            text_mask=text_mask,
            lowres_noise_times=lowres_aug_times,
            lowres_cond_img=lowres_cond_img_noisy,
            cond_drop_prob=self.cond_drop_prob,
        )
```

We then calculate the loss between the actual noise that was added and the U-Net’s prediction of the noise according to `self.loss_fn`, which is L2 (mean squared error) loss by default.

```
#...
class Imagen(nn.Module):
    def __init__(self, timesteps):
        # ...
        self.loss_fn = torch.nn.functional.mse_loss

    def _p_losses(self, unet, x_start, times, *, text_embeds=None, text_mask=None,
                  lowres_cond_img=None, lowres_aug_times=None):
        # ...
        # L2 loss between the U-Net's prediction and the noise actually added
        return self.loss_fn(pred, noise)
```

That’s all it takes to get the loss with Imagen! It’s quite a simple process once we have built the Diffusion Model/U-Net backbone.
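To tie the three steps together, here is a self-contained toy training step: sample timesteps, noise the images, predict the noise, and take the L2 loss. A single convolution stands in for the U-Net (it ignores the timestep and text conditioning), and the linear β schedule is an assumption, so this is only a sketch of the shape of the computation:

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

# Assumed linear beta schedule, as in a standard DDPM.
T = 1000
alphas_cumprod = torch.cumprod(1.0 - torch.linspace(1e-4, 0.02, T), dim=0)

# Hypothetical stand-in for a U-Net; it ignores time and text conditioning.
toy_unet = nn.Conv2d(3, 3, kernel_size=3, padding=1)
optimizer = torch.optim.Adam(toy_unet.parameters(), lr=1e-3)

def training_step(images):
    b = images.shape[0]
    times = torch.randint(0, T, (b,))                         # (1) sample timesteps
    noise = torch.randn_like(images)
    abar = alphas_cumprod[times].view(b, 1, 1, 1)
    x_noisy = abar.sqrt() * images + (1 - abar).sqrt() * noise  # noise the images
    pred = toy_unet(x_noisy)                                   # (2) predict the noise
    loss = F.mse_loss(pred, noise)                             # (3) L2 loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

loss = training_step(torch.rand(8, 3, 16, 16) * 2 - 1)
```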

### Sampling with Imagen

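Ultimately, what we want to do is **sample with Imagen**. That is, we want to be able to generate novel images given textual captions. Recall from above that this requires calculating the forward process posterior mean:

$$\tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}\,x_0 + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\,x_t$$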

Now that we have defined our U-Net that predicts the noise component, we have all the pieces we need to compute the posterior mean.

First, we get the noise prediction (blue) using our U-Net’s `forward` (or `forward_with_cond_scale`) method, and then calculate *x_0* from it (purple) using the `predict_start_from_noise` method of the noise scheduler (our `GaussianDiffusion` instance) introduced previously, which performs the calculation below:

$$x_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}\,x_t - \sqrt{\frac{1}{\bar{\alpha}_t} - 1}\;\epsilon$$

Where *x_t* is a noisy image and *ε* is the U-Net’s noise prediction. Next, *x_0* is dynamically thresholded and then passed, along with *x_t*, into the `q_posterior` method of the noise scheduler (yellow) to get the distribution mean.
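Both calculations are short enough to write out. The sketch below assumes a discrete-time DDPM with a linear β schedule and integer timesteps; the function names mirror MinImagen’s methods, but these are standalone re-implementations, not the library’s code. Given the true noise, `predict_start_from_noise` recovers x_0 exactly, which makes a handy sanity check:

```python
import torch

# Assumed linear beta schedule (float64 for a tight round-trip check).
T = 1000
betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
# abar_{t-1}, with abar_{-1} defined as 1
alphas_cumprod_prev = torch.cat([torch.ones(1, dtype=torch.float64), alphas_cumprod[:-1]])

def predict_start_from_noise(x_t, t, noise):
    """x_0 = sqrt(1/abar_t) * x_t - sqrt(1/abar_t - 1) * eps"""
    abar = alphas_cumprod[t]
    return (1.0 / abar).sqrt() * x_t - (1.0 / abar - 1.0).sqrt() * noise

def q_posterior(x_start, x_t, t):
    """Mean and variance of q(x_{t-1} | x_t, x_0)."""
    abar, abar_prev = alphas_cumprod[t], alphas_cumprod_prev[t]
    coef1 = betas[t] * abar_prev.sqrt() / (1.0 - abar)
    coef2 = (1.0 - abar_prev) * alphas[t].sqrt() / (1.0 - abar)
    variance = betas[t] * (1.0 - abar_prev) / (1.0 - abar)
    return coef1 * x_start + coef2 * x_t, variance

x0 = torch.rand(3, 8, 8, dtype=torch.float64) * 2 - 1
eps = torch.randn_like(x0)
t = 500
x_t = alphas_cumprod[t].sqrt() * x0 + (1 - alphas_cumprod[t]).sqrt() * eps
x0_hat = predict_start_from_noise(x_t, t, eps)  # equals x0 given the true noise
mean, var = q_posterior(x0_hat, x_t, t)
```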

This entire process is wrapped up in `Imagen`’s `_p_mean_variance` function.

```
#...
class Imagen(nn.Module):
    def __init__(self, timesteps):
        # ...
        self.dynamic_thresholding_percentile = 0.9

    def _p_mean_variance(self, unet, x, t, *, noise_scheduler,
                         text_embeds=None, text_mask=None):
        # Get the noise prediction from the unet (blue block)
        pred = unet.forward_with_cond_scale(x, t, text_embeds=text_embeds, text_mask=text_mask)
        # Calculate the starting images from the noise (purple block)
        x_start = noise_scheduler.predict_start_from_noise(x, t=t, noise=pred)
        # Dynamically threshold
        s = torch.quantile(
            rearrange(x_start, 'b ... -> b (...)').abs(),
            self.dynamic_thresholding_percentile,
            dim=-1
        )
        s.clamp_(min=1.)
        s = right_pad_dims_to(x_start, s)
        x_start = x_start.clamp(-s, s) / s
        # Return the forward process posterior parameters (yellow block)
        return noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)
```
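Pulled out of the class, the dynamic thresholding lines amount to the small function below (a sketch; `dynamic_threshold` is a hypothetical name, and the 0.9 percentile matches the snippet above). Inputs already inside [-1, 1] pass through unchanged because s is clamped to at least 1, while larger predictions are clipped and rescaled into [-1, 1]:

```python
import torch

def dynamic_threshold(x_start, percentile=0.9):
    """Per sample: s = percentile of |x_0|, clamped to >= 1; clip to [-s, s], divide by s."""
    b = x_start.shape[0]
    s = torch.quantile(x_start.reshape(b, -1).abs(), percentile, dim=-1)
    # Reshape s to (b, 1, 1, ...) so it broadcasts over each sample
    s = s.clamp(min=1.0).view(b, *([1] * (x_start.dim() - 1)))
    return torch.clamp(x_start, min=-s, max=s) / s

in_range = torch.rand(2, 3, 4, 4) * 2 - 1   # already within [-1, 1]
too_large = torch.rand(2, 3, 4, 4) * 6 - 3  # values up to magnitude 3
out_small = dynamic_threshold(in_range)
out_large = dynamic_threshold(too_large)
```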

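From here we have everything we need to sample from the posterior, which is to say “go back one timestep” in the diffusion process. That is, we are seeking to sample from the distribution below, where *x_0* is the thresholded prediction:

$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\big(x_{t-1};\ \tilde{\mu}_t(x_t, x_0),\ \tilde{\beta}_t \mathbf{I}\big), \qquad \tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\,\beta_t$$

We saw above that sampling from this distribution is equivalent to calculating

$$x_{t-1} = \tilde{\mu}_t(x_t, x_0) + \sqrt{\tilde{\beta}_t}\,z, \qquad z \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$$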

Since we calculated the posterior mean and (log) variance with `_p_mean_variance`, we can now implement the above calculation with `_p_sample`, computing the square root of the variance as `exp(0.5 * log_variance)` for numerical stability.

```
class Imagen(nn.Module):
    def __init__(self, timesteps):
        # ...
        self.dynamic_thresholding_percentile = 0.9

    @torch.no_grad()
    def _p_sample(self, unet, x, t, *, noise_scheduler, text_embeds=None, text_mask=None):
        b, *_, device = *x.shape, x.device
        # Get posterior parameters
        model_mean, _, model_log_variance = self._p_mean_variance(
            unet, x=x, t=t, noise_scheduler=noise_scheduler,
            text_embeds=text_embeds, text_mask=text_mask)
        # Get noise which we'll use to calculate the denoised image
        noise = torch.randn_like(x)
        # No more denoising when t == 0
        is_last_sampling_timestep = (t == 0)
        nonzero_mask = (1 - is_last_sampling_timestep.float()).reshape(
            b, *((1,) * (len(x.shape) - 1)))
        # Get the denoised image. Equivalent to mean + sqrt(variance) * noise,
        # but computed this way to be more numerically stable
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
```
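The masking and the log-variance trick in `_p_sample` can be checked in isolation. In the hypothetical `p_sample_step` below, a sample whose timestep is 0 gets exactly the posterior mean, and `exp(0.5 * log_variance)` is the same quantity as the square root of the variance:

```python
import torch

def p_sample_step(model_mean, model_log_variance, t):
    """One reverse step: x_{t-1} = mean + sigma * z, with no noise added where t == 0."""
    noise = torch.randn_like(model_mean)
    b = model_mean.shape[0]
    # 0.0 for samples at the final step, 1.0 otherwise; shaped to broadcast
    nonzero_mask = (t != 0).float().view(b, *([1] * (model_mean.dim() - 1)))
    return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

mean = torch.zeros(2, 3, 4, 4)
log_var = torch.full_like(mean, -4.0)
t = torch.tensor([0, 10])  # first sample is at the final step
x_prev = p_sample_step(mean, log_var, t)
```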

At this point, we have denoised the random noise input into Imagen by **one timestep**. To generate images, we need to do this for *every* timestep, starting with randomly sampled Gaussian noise at *t=T* and going “back in time” until we reach *t=0*. Therefore, we run `_p_sample` in a loop over timesteps with `_p_sample_loop`:

```
class Imagen(nn.Module):
    def __init__(self, *args, **kwargs):
        ...  # as before

    @torch.no_grad()
    def _p_sample_loop(self, unet, shape, *, lowres_cond_img=None, lowres_noise_times=None,
                       noise_scheduler=None, text_embeds=None, text_mask=None):
        device = self.device
        # Get starting noisy images
        img = torch.randn(shape, device=device)
        # Get sampling timesteps (final_t, final_t-1, ..., 2, 1, 0)
        batch = shape[0]
        timesteps = noise_scheduler.get_sampling_timesteps(batch, device=device)
        # For each timestep, denoise the images slightly
        for times in tqdm(timesteps, desc="sampling loop time step", total=len(timesteps)):
            img = self._p_sample(
                unet,
                img,
                times,
                noise_scheduler=noise_scheduler,
                text_embeds=text_embeds,
                text_mask=text_mask)
        # Clamp the values to be in the allowed range and potentially
        # unnormalize back into the range (0., 1.)
        img.clamp_(-1., 1.)
        unnormalize_img = self.unnormalize_img(img)
        return unnormalize_img
```
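A complete, if toy, version of this loop fits in a few lines. The sketch below assumes a 50-step linear β schedule, uses a hypothetical `dummy_unet` that always predicts zero noise, and substitutes simple static clipping for dynamic thresholding to stay short; it walks t from T − 1 down to 0 as described above:

```python
import torch

# Assumed linear schedule with few steps to keep the demo fast.
T = 50
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
abar = torch.cumprod(alphas, dim=0)
abar_prev = torch.cat([torch.ones(1), abar[:-1]])

def dummy_unet(x, t):
    """Hypothetical noise predictor; a trained U-Net would go here."""
    return torch.zeros_like(x)

@torch.no_grad()
def sample_loop(shape):
    img = torch.randn(shape)                 # pure noise at t = T - 1
    for t in reversed(range(T)):             # T-1, T-2, ..., 0
        eps = dummy_unet(img, t)
        x0 = (img - (1 - abar[t]).sqrt() * eps) / abar[t].sqrt()
        x0 = x0.clamp(-1.0, 1.0)             # static clipping instead of dynamic thresholding
        mean = (betas[t] * abar_prev[t].sqrt() * x0
                + (1 - abar_prev[t]) * alphas[t].sqrt() * img) / (1 - abar[t])
        var = betas[t] * (1 - abar_prev[t]) / (1 - abar[t])  # zero at t = 0
        noise = torch.randn_like(img) if t > 0 else torch.zeros_like(img)
        img = mean + var.sqrt() * noise
    return (img.clamp(-1.0, 1.0) + 1) / 2    # unnormalize to [0, 1]

out = sample_loop((2, 3, 8, 8))
```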

`_p_sample_loop` is how we generate images with **one U-Net**. Imagen contains a *chain* of U-Nets, so, finally, the `sample` function iteratively passes the generated images through *each* U-Net in the chain, and handles other sampling requirements like generating text encodings/masks, placing the currently-sampling U-Net on the GPU if available, etc. `eval_decorator` sets the model to evaluation mode, if it is not already, upon calling `sample`.

```
class Imagen(nn.Module):
    def __init__(self, *args, **kwargs):
        # ...
        self.noise_schedulers = nn.ModuleList([])
        for _ in range(num_unets):
            self.noise_schedulers.append(GaussianDiffusion(timesteps=timesteps))

    @torch.no_grad()
    @eval_decorator
    def sample(self, texts=None, batch_size=1, cond_scale=1., lowres_sample_noise_level=None,
               return_pil_images=False, device=None):
        # Put all Unets on the same device as Imagen
        device = default(device, self.device)
        self.reset_unets_all_one_device(device=device)
        # Get the text embeddings/mask from textual captions (`texts`)
        text_embeds, text_masks = t5_encode_text(texts, name=self.text_encoder_name)
        text_embeds, text_masks = map(lambda t: t.to(device), (text_embeds, text_masks))
        batch_size = text_embeds.shape[0]
        outputs = None
        is_cuda = next(self.parameters()).is_cuda
        device = next(self.parameters()).device
        lowres_sample_noise_level = default(lowres_sample_noise_level,
                                            self.lowres_sample_noise_level)
        # Iterate through each Unet
        for unet_number, unet, channel, image_size, noise_scheduler, dynamic_threshold in tqdm(
                zip(range(1, len(self.unets) + 1), self.unets, self.sample_channels,
                    self.image_sizes, self.noise_schedulers, self.dynamic_thresholding)):
            # If GPU is available, place the Unet currently being sampled from on the GPU
            context = self.one_unet_in_gpu(unet=unet) if is_cuda else null_context()
            with context:
                lowres_cond_img = lowres_noise_times = None
                shape = (batch_size, channel, image_size, image_size)
                # If on a super-res model, noise the previous unet's images for conditioning
                if unet.lowres_cond:
                    lowres_noise_times = self.lowres_noise_schedule.get_times(
                        batch_size, lowres_sample_noise_level, device=device)
                    lowres_cond_img = resize_image_to(img, image_size, pad_mode="reflect")
                    lowres_cond_img = self.lowres_noise_schedule.q_sample(
                        x_start=lowres_cond_img,
                        t=lowres_noise_times,
                        noise=torch.randn_like(lowres_cond_img))
                    shape = (batch_size, self.channels, image_size, image_size)
                # Generate an image with the current U-Net
                img = self._p_sample_loop(
                    unet,
                    shape,
                    text_embeds=text_embeds,
                    text_mask=text_masks,
                    cond_scale=cond_scale,
                    lowres_cond_img=lowres_cond_img,
                    lowres_noise_times=lowres_noise_times,
                    noise_scheduler=noise_scheduler,
                )
                # Output the image if at the end of the super-resolution chain
                outputs = img if unet_number == len(self.unets) else None
        # Return tensors or PIL images
        if not return_pil_images:
            return outputs
        pil_images = list(map(T.ToPILImage(), outputs.unbind(dim=0)))
        return pil_images
```
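Stripped of device handling and noise augmentation, the chaining in `sample` reduces to this: run the base model, then for each later stage resize the previous output to that stage’s resolution and pass it in as conditioning. The hypothetical `cascade_sample` below shows only that control flow, with placeholder samplers instead of U-Nets and bilinear `F.interpolate` instead of `resize_image_to`:

```python
import torch
import torch.nn.functional as F

def cascade_sample(stages, batch_size, channels=3):
    """Run a chain of (image_size, sampler) stages; each later stage is
    conditioned on the previous output resized to its own resolution."""
    img = None
    for image_size, sampler in stages:
        lowres_cond = None
        if img is not None:
            lowres_cond = F.interpolate(img, size=(image_size, image_size),
                                        mode="bilinear", align_corners=False)
        img = sampler((batch_size, channels, image_size, image_size), lowres_cond)
    return img  # only the last stage's output is returned

def base_sampler(shape, lowres_cond):
    return torch.rand(shape)  # hypothetical base model output

def super_sampler(shape, lowres_cond):
    return lowres_cond        # hypothetical: echo the upsampled conditioning image

out = cascade_sample([(32, base_sampler), (64, super_sampler)], batch_size=2)
```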

### Summary

To recap, in this section we defined the `Imagen` class, first examining its `forward` pass, which noises training images, predicts their noise components, and then returns the average L2 loss between the predictions and true noise values. Then, we looked at `sample`, which is used to generate images via the successive application of the U-Nets that compose the Imagen instance.

## Training and Sampling from MinImagen

MinImagen can be installed with

`pip install minimagen`

The MinImagen package hides all of the implementation details discussed above, and exposes a high-level API for working with Imagen, which is described in the package documentation. Let’s take a look at how to use the `minimagen` package to train and sample from a MinImagen instance. You can alternatively check out MinImagen’s GitHub repo for information on using the provided scripts for training/generation.

### Training MinImagen

To train Imagen, we first need to perform some imports.

```
import os
from datetime import datetime

import torch.utils.data
from torch import optim

from minimagen.Imagen import Imagen
from minimagen.Unet import Unet, Base, Super, BaseTest, SuperTest
from minimagen.generate import load_minimagen, load_params
from minimagen.t5 import get_encoded_dim
from minimagen.training import (get_minimagen_parser, ConceptualCaptions, get_minimagen_dl_opts,
                                create_directory, get_model_size, save_training_info,
                                get_default_args, MinimagenTrain, load_testing_parameters)
```

Next, we determine the device the training will happen on, using a GPU if one is available, and then instantiate a MinImagen argument parser. The parser will allow us to specify relevant parameters when running the script from the command line.

```
# Get device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Command line argument parser
parser = get_minimagen_parser()
args = parser.parse_args()
```

Now we’ll create a timestamped training directory that will store all of the information from the training. The `create_directory()` function returns a context manager that allows us to temporarily enter the directory to read files, save files, etc.
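We don’t reproduce `create_directory()` itself here, but a context manager that temporarily enters a directory can be written with the standard library as below. `inside_directory` is a hypothetical name, and the way MinImagen’s returned object is invoked may differ:

```python
import os
import tempfile
from contextlib import contextmanager

@contextmanager
def inside_directory(path):
    """Temporarily change the working directory to `path`, restoring it afterwards."""
    previous = os.getcwd()
    os.chdir(path)
    try:
        yield path
    finally:
        os.chdir(previous)  # restored even if the body raises

start = os.getcwd()
with tempfile.TemporaryDirectory() as tmp:
    with inside_directory(tmp):
        moved_in = os.path.samefile(os.getcwd(), tmp)
        with open("notes.txt", "w") as f:  # relative path, so written into tmp
            f.write("hello")
    created = os.path.exists(os.path.join(tmp, "notes.txt"))
restored = os.getcwd() == start
```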

```
# Create training directory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
dir_path = f"./training_{timestamp}"
training_dir = create_directory(dir_path)
```

Since this is an example script, we replace some command-line arguments with alternative values that lower the computational load, so that we can quickly train and see the results to understand how MinImagen trains.

```
# Replace some cmd line args to lower computational load.
args = load_testing_parameters(args)
```

Next, we’ll create our DataLoaders, using a subset of the Conceptual Captions dataset. Check out `MinimagenDataset` if you want to use a different dataset.

```
# Load subset of Conceptual Captions dataset.
train_dataset, valid_dataset = ConceptualCaptions(args, smalldata=True)
# Create dataloaders
dl_opts = {**get_minimagen_dl_opts(device), 'batch_size': args.BATCH_SIZE, 'num_workers': args.NUM_WORKERS}
train_dataloader = torch.utils.data.DataLoader(train_dataset, **dl_opts)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, **dl_opts)
```

It’s now time to create the U-Nets that will be used in MinImagen’s U-Net chain. The base model that generates the image is a `BaseTest` instance, and the super-resolution model that upscales the image is a `SuperTest` instance. These models are intentionally tiny so that we can quickly train them to see how training a MinImagen instance works. See `Base` and `Super` for models closer to the original Imagen implementation.

We load the parameters for these U-Nets, and then instantiate the instances with a list comprehension.

```
# Use small U-Nets to decrease computational load.
unets_params = [get_default_args(BaseTest), get_default_args(SuperTest)]
unets = [Unet(**unet_params).to(device) for unet_params in unets_params]
```
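`get_default_args` returns the default keyword arguments of a class, which is how the `BaseTest`/`SuperTest` configurations become dictionaries here. Its implementation isn’t shown in this article; a plausible version based on `inspect.signature` is sketched below, with `TinyUnetConfig` as a hypothetical stand-in class. The last line previews the `{**defaults, **specified}` merge used later, where keys from the right-hand dictionary win:

```python
import inspect

def get_default_args_sketch(obj):
    """Collect keyword defaults of a callable (for a class, its __init__)."""
    signature = inspect.signature(obj)
    return {
        name: param.default
        for name, param in signature.parameters.items()
        if param.default is not inspect.Parameter.empty
    }

class TinyUnetConfig:  # hypothetical stand-in for a U-Net config class
    def __init__(self, dim=32, dim_mults=(1, 2), lowres_cond=False, channels=3):
        self.dim = dim

defaults = get_default_args_sketch(TinyUnetConfig)
merged = {**defaults, **{"dim": 64}}  # specified values override defaults
```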

Now we can finally instantiate the actual MinImagen instance. We first specify some parameters, and then create the instance.

```
# Specify MinImagen parameters
imagen_params = dict(
    image_sizes=(int(args.IMG_SIDE_LEN / 2), args.IMG_SIDE_LEN),
    timesteps=args.TIMESTEPS,
    cond_drop_prob=0.15,
    text_encoder_name=args.T5_NAME
)
# Create MinImagen from UNets with specified imagen parameters
imagen = Imagen(unets=unets, **imagen_params).to(device)
```

For record keeping, we fill in the default values for unspecified arguments, get the size of the MinImagen instance, and then save all of this information and more.

```
# Fill in unspecified arguments with defaults to record the full config (parameters) file
unets_params = [{**get_default_args(Unet), **i} for i in unets_params]
imagen_params = {**get_default_args(Imagen), **imagen_params}
# Get the size of the Imagen model in megabytes
model_size_MB = get_model_size(imagen)
# Save all training info (config files, model size, etc.)
save_training_info(args, timestamp, unets_params, imagen_params, model_size_MB, training_dir)
```
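`get_model_size` reports the model’s size in megabytes. One common way to compute such a figure, shown here as an assumption rather than MinImagen’s exact code, is to sum the bytes of every parameter and buffer:

```python
import torch
from torch import nn

def model_size_mb(model: nn.Module) -> float:
    """Bytes of all parameters and buffers, in megabytes (1 MB = 1024**2 bytes)."""
    param_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
    buffer_bytes = sum(b.nelement() * b.element_size() for b in model.buffers())
    return (param_bytes + buffer_bytes) / 1024 ** 2

# nn.Linear(10, 10) holds 100 weights + 10 biases = 110 float32 values = 440 bytes
size = model_size_mb(nn.Linear(10, 10))
```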

Finally, we can train the MinImagen instance using `MinimagenTrain`:

```
# Create optimizer
optimizer = optim.Adam(imagen.parameters(), lr=args.OPTIM_LR)
# Train the MinImagen instance
MinimagenTrain(timestamp, args, unets, imagen, train_dataloader, valid_dataloader, training_dir, optimizer, timeout=30)
```

In order to train the instance, save the script as `minimagen_train.py` and then run the following in the terminal:

`python minimagen_train.py`

N.B. You may need to change `python` to `python3`, and/or `minimagen_train.py` to `-m minimagen_train`.

After training is complete, you will see a new Training Directory, which stores all of the information from the training, including model configurations and weights. To see how this Training Directory can be used to generate images, move on to the next section.


### Generating Images with MinImagen

Now that we have a “trained” MinImagen instance, we can use it to generate images. Luckily, this process is much more straightforward. First, we’ll again perform the necessary imports and define an argument parser so that we can specify the location of the Training Directory that contains the trained MinImagen weights.

```
from argparse import ArgumentParser
from minimagen.generate import load_minimagen, sample_and_save
# Command line argument parser
parser = ArgumentParser()
parser.add_argument("-d", "--TRAINING_DIRECTORY", dest="TRAINING_DIRECTORY", help="Training directory to use for inference", type=str)
args = parser.parse_args()
```

Next, we can define a list of captions that we want to generate images for. We just specify one caption for now.

```
# Specify the caption(s) to generate images for
captions = ['a happy dog']
```

Now all we have to do is run `sample_and_save()`, specifying the captions and Training Directory to use, and an image for each caption will be generated and saved.

```
# Use `sample_and_save` to generate and save the images
sample_and_save(captions, training_directory=args.TRAINING_DIRECTORY)
```

Alternatively, you can load a MinImagen instance and pass it (rather than a Training Directory) to `sample_and_save()`, but in this case information about the MinImagen instance used to generate the images will not be saved, so this is not recommended.

```
minimagen = load_minimagen(args.TRAINING_DIRECTORY)
sample_and_save(captions, minimagen=minimagen)
```

That’s it! Once the generation is complete, you will see a new directory called `generated_images_<TIMESTAMP>` that stores the captions used to generate the images, the Training Directory used to generate the images, and the images themselves. The number in each image’s filename corresponds to the index of the caption that was used to generate it.


## Final Words

The impressive results of State-of-the-Art text-to-image models speak for themselves, and MinImagen serves as a solid foundation for understanding the practical workings of such models. For more Machine Learning content, feel free to check out more of our blog or YouTube channel. Alternatively, follow us on Twitter or follow our newsletter to stay in the loop for future content we drop.