CLIP Guidance
Similar to Classifier Guidance, but instead of a classifier it uses CLIP to measure the similarity between the text prompt and the partially denoised image.
The denoise step of DDPM with CLIP Guidance

$$\hat{\mu}_\theta(x_t \mid c) = \mu_\theta(x_t) + s \, \Sigma_\theta(x_t) \, \nabla_{x_t} \big( f(x_t) \cdot g(c) \big)$$

where
- $c$ is the guidance condition (text)
- $f$ and $g$ are encoders that embed images and texts into the same space so that we can use the dot product to calculate similarity
- $s$ is the guidance scale (guidance_scale in the pseudo-code below)
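Note that the pseudo-code below works in the noise-prediction parameterization rather than shifting the mean directly. Using the standard classifier-guidance correspondence between the two forms, the same update can be written as

$$\hat{\epsilon}_\theta(x_t) = \epsilon_\theta(x_t) - s \, \sqrt{1 - \bar{\alpha}_t} \, \nabla_{x_t} \big( f(x_t) \cdot g(c) \big)$$

i.e. the gradient that increases image-text similarity is subtracted from the predicted noise. In the simplified pseudo-code the $\sqrt{1 - \bar{\alpha}_t}$ factor is assumed to be absorbed into guidance_scale, and the code adds the gradient of the negative-similarity loss, which amounts to the same thing.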
Pseudo-code
import torch
from tqdm import tqdm

# Load a pre-trained CLIP model (both image and text encoders), its image preprocessor, and its text tokenizer
clip_model, preprocess, tokenize = load_clip_model(...)
# Encode the target text prompt (e.g., "a cat sitting on a tree")
text_prompt = "a cat sitting on a tree"
text_features = clip_model.encode_text(tokenize(text_prompt))
# Controls the strength of the CLIP guidance
guidance_scale = 7.5
# Randomly draw noise with the same shape as the output image from a Gaussian distribution
input = get_noise(...)
# Each step denoises the input
for t in tqdm(scheduler.timesteps):
    # Use U-Net to predict noise (the score function is represented by the predicted noise added from x_0 to x_t)
    with torch.no_grad():
        noise_pred = model(input, t).sample
    # CLIP guidance step: pass the current image through the CLIP image encoder with gradients enabled,
    # so the similarity can be backpropagated to the noisy sample x_t
    input = input.detach().requires_grad_(True)
    current_image = denormalize_image(input)  # Adjust the image to the appropriate range (e.g., [0, 1])
    image_features = clip_model.encode_image(preprocess(current_image))
    # Gradient of the negative image-text similarity with respect to x_t
    clip_guidance = compute_clip_direction(input, image_features, text_features)
    # Apply the gradient for CLIP guidance to strengthen the alignment with the text prompt
    noise_pred = noise_pred + clip_guidance * guidance_scale
    # Calculate x_{t-1} using the updated noise
    input = scheduler.step(noise_pred, t, input.detach()).prev_sample
where compute_clip_direction can be implemented as
def compute_clip_direction(x_t, image_features, text_features):
    # Normalize the features to unit length so the dot product is the cosine similarity
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    # Calculate the cosine similarity between image and text features
    similarity = (image_features * text_features).sum(dim=-1)
    # We want to maximize the similarity, so take its negative as the loss
    similarity_loss = -similarity.sum()
    # Gradient of the loss with respect to the noisy image x_t
    # (backpropagates through the CLIP image encoder to the input image)
    clip_guidance = torch.autograd.grad(similarity_loss, x_t)[0]
    return clip_guidance
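For the gradient above to actually reach x_t, every operation between x_t and the CLIP features must be differentiable. The stock preprocess that ships with CLIP operates on PIL images and breaks the computation graph, so in practice it is replaced by tensor operations. Below is a minimal sketch of such a stand-in; differentiable_preprocess is a hypothetical helper (not part of any library), and the 224 resolution and normalization constants are the ones used by OpenAI's CLIP.

import torch
import torch.nn.functional as F

# Channel statistics used by OpenAI's CLIP image preprocessing
CLIP_MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
CLIP_STD = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)

def differentiable_preprocess(image, resolution=224):
    # image: (B, 3, H, W) tensor in [0, 1]; every op here is differentiable,
    # so gradients can flow from the CLIP features back to the noisy sample x_t
    image = F.interpolate(image, size=(resolution, resolution),
                          mode="bicubic", align_corners=False)
    return (image - CLIP_MEAN.to(image.device)) / CLIP_STD.to(image.device)

Using this helper in place of preprocess inside the denoising loop keeps the whole path from x_t to the similarity score differentiable, which is what compute_clip_direction relies on.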