How to create dynamic masks with DALL·E and Segment Anything

May 19, 2023

Segment Anything is a model from Meta that can be used to select portions of images. Combined with DALL·E's ability to inpaint specified portions of images, you can use Segment Anything to easily select any part of an image you'd like to alter.
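The two tools meet at the mask format. SAM returns a boolean array marking the selected pixels, while OpenAI's image edit endpoint treats fully transparent pixels (alpha = 0) in the mask image as the area to regenerate. A minimal sketch of that conversion, assuming the image and mask share the same height and width; `mask_to_edit_png` is a hypothetical helper, not part of either library:

```python
import numpy as np
from PIL import Image


def mask_to_edit_png(image_rgb: np.ndarray, mask: np.ndarray, out_path: str) -> Image.Image:
    """Hypothetical helper: turn a boolean SAM mask into an RGBA mask image.

    Pixels where `mask` is True get alpha 0 (fully transparent), marking them
    for regeneration. All other pixels keep alpha 255 so they are preserved.
    """
    if image_rgb.shape[:2] != mask.shape:
        raise ValueError("image and mask must have the same height and width")
    # Transparent where the mask is selected, opaque everywhere else
    alpha = np.where(mask, 0, 255).astype(np.uint8)
    # Stack the (H, W, 3) image with the (H, W) alpha into an (H, W, 4) array
    rgba = np.dstack([image_rgb.astype(np.uint8), alpha])
    out = Image.fromarray(rgba)  # a uint8 (H, W, 4) array is read as RGBA
    out.save(out_path)
    return out
```

Keeping the original RGB values in the mask file is a convenience; only the alpha channel marks the region to edit.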

In this notebook, we'll use these tools to become fashion designers and dynamically replace our digital models' outfits with tailored, original creations. The notebook follows this flow:

  • Setup: Initialise your libraries and the directories where images will be saved.
  • Generate original image: Make an original image that we'll create dynamic masks from.
  • Generate mask: Use Segment Anything to create a dynamic mask.
  • Create new image: Generate a new image with the masked area inpainted with a fresh prompt.

Setup

To get started, we'll follow the instructions for installing the Segment Anything Model (SAM), which Meta has open-sourced. As of May 2023, the key steps are:

  • Install PyTorch (version 1.7+).
  • Install the library using pip install git+https://github.com/facebookresearch/segment-anything.git.
  • Install dependencies using pip install opencv-python pycocotools matplotlib onnxruntime onnx.
  • Download a model checkpoint (the default ViT-H checkpoint, sam_vit_h_4b8939.pth, is about 2.4 GB).
!pip install torch torchvision torchaudio
!pip install git+https://github.com/facebookresearch/segment-anything.git
!pip install opencv-python pycocotools matplotlib onnxruntime onnx
!pip install requests
!pip install openai
!pip install numpy
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
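Once the installs finish, you can confirm the PyTorch 1.7+ requirement without importing torch. A small sketch, assuming version strings look like `2.1.0` or `2.1.0+cu118`; `meets_minimum` and `installed_torch_ok` are hypothetical helpers, not part of PyTorch:

```python
import re
from importlib import metadata
from typing import Tuple


def meets_minimum(version: str, minimum: Tuple[int, ...]) -> bool:
    """Return True if a version string like '2.1.0+cu118' is >= `minimum`."""
    # Keep only the leading numeric release segment, e.g. '2.1.0' from '2.1.0+cu118'
    match = re.match(r"(\d+(?:\.\d+)*)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    release = tuple(int(part) for part in match.group(1).split("."))
    # Pad the shorter tuple with zeros so (1, 7) compares equal to (1, 7, 0)
    width = max(len(release), len(minimum))
    release += (0,) * (width - len(release))
    minimum += (0,) * (width - len(minimum))
    return release >= minimum


def installed_torch_ok(minimum: Tuple[int, ...] = (1, 7)) -> bool:
    """Check the installed 'torch' distribution's version metadata."""
    try:
        return meets_minimum(metadata.version("torch"), minimum)
    except metadata.PackageNotFoundError:
        return False


print(installed_torch_ok())
```

Comparing integer tuples rather than raw strings matters here: as strings, "1.13" sorts before "1.7", which would wrongly reject a newer release.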
import cv2
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib import rcParams
import numpy as np
from openai import OpenAI
import os
from PIL import Image
import requests
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
import torch

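# Set directories for generated images, masks, and edited images
base_image_dir = os.path.join("images", "01_generations")
mask_dir = os.path.join("images", "02_masks")
edit_image_dir = os.path.join("images", "03_edits")

# Create the directories if they don't exist yet (open(..., "wb") won't create them)
for directory in (base_image_dir, mask_dir, edit_image_dir):
    os.makedirs(directory, exist_ok=True)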

# Point to your downloaded SAM model
sam_model_filepath = "./sam_vit_h_4b8939.pth"

# Initialize the SAM model ("default" builds the ViT-H architecture, matching the checkpoint above)
sam = sam_model_registry["default"](checkpoint=sam_model_filepath)

# Initialize the OpenAI client
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", "<your OpenAI API key if not set as env var>"))
def process_dalle_images(response, filename, image_dir):
    """Download each generated image and save it as <filename>_<n>.png in image_dir."""
    urls = [datum.url for datum in response.data]  # extract URLs
    filepaths = []
    for i, url in enumerate(urls, start=1):
        image_response = requests.get(url, timeout=60)  # download the image
        image_response.raise_for_status()  # fail loudly on an expired or bad URL
        filepath = os.path.join(image_dir, f"{filename}_{i}.png")
        with open(filepath, "wb") as image_file:
            image_file.write(image_response.content)  # write the image to disk
        filepaths.append(filepath)
    return filepaths
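Downloading from URLs works, but the API reference notes that returned URLs are only valid for about an hour. An alternative is to request `response_format="b64_json"` and decode the image bytes directly. A sketch of a drop-in variant, assuming each item in `response.data` exposes a `b64_json` attribute (as the OpenAI Python SDK does when that format is requested); `process_dalle_b64_images` is a hypothetical name:

```python
import base64
import os


def process_dalle_b64_images(response, filename, image_dir):
    """Save images returned with response_format="b64_json"; return their paths."""
    os.makedirs(image_dir, exist_ok=True)  # make sure the target folder exists
    filepaths = []
    for i, datum in enumerate(response.data, start=1):
        filepath = os.path.join(image_dir, f"{filename}_{i}.png")
        with open(filepath, "wb") as image_file:
            # Each item carries the PNG bytes as a base64-encoded string
            image_file.write(base64.b64decode(datum.b64_json))
        filepaths.append(filepath)
    return filepaths
```

It keeps the same `<filename>_<n>.png` naming, so later cells that pick `generation_2.png` would work unchanged.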
dalle_prompt = '''
Full length, zoomed out photo of our premium Lederhosen-inspired jumpsuit.
Showcase the intricate hand-stitched details and high-quality leather, while highlighting the perfect blend of Austrian heritage and modern fashion.
This piece appeals to a sophisticated, trendsetting audience who appreciates cultural fusion and innovative design.
'''
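# Generate your images
# dall-e-3 accepts only n=1 per request; dall-e-2 can return several images at once
# and also supports the edit (inpainting) endpoint used later
generation_response = client.images.generate(
    model="dall-e-2",
    prompt=dalle_prompt,
    n=3,
    size="1024x1024",
    response_format="url",
)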
filepaths = process_dalle_images(generation_response, "generation", base_image_dir)
# print the new generations
for filepath in filepaths:
    print(filepath)
    display(Image.open(filepath))
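Before a generation goes to the inpainting step, it helps to confirm it meets the edit endpoint's input constraints; at the time of writing, OpenAI's documentation for DALL·E 2 edits asks for a square PNG under 4 MB. A sketch of that check, with the 4 MB limit hard-coded as an assumption to revisit if the API changes; `check_edit_ready` is a hypothetical helper:

```python
import os

from PIL import Image

MAX_EDIT_BYTES = 4 * 1024 * 1024  # assumed 4 MB limit from the edit endpoint docs


def check_edit_ready(filepath):
    """Return a list of problems that would stop `filepath` being used for an edit."""
    problems = []
    if os.path.getsize(filepath) >= MAX_EDIT_BYTES:
        problems.append("file is 4 MB or larger")
    with Image.open(filepath) as img:
        if img.format != "PNG":
            problems.append(f"format is {img.format}, not PNG")
        if img.width != img.height:
            problems.append(f"image is {img.width}x{img.height}, not square")
    return problems
```

For example, `for filepath in filepaths: print(filepath, check_edit_ready(filepath) or "ok")` lists any generation that would be rejected.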

Generate Mask

Next we'll load up one of our images and generate masks.

For this demonstration, we simulate a UX where the user "clicks" a point on the image and SAM generates masks around it. Meta's example notebooks show other approaches too, such as generating every possible mask for an image or prompting with a bounding box.
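In SAM's `SamPredictor` interface, a click is expressed as two arrays: `point_coords` with shape (N, 2) holding (x, y) pixel positions, and `point_labels` with shape (N,) where 1 marks a foreground point and 0 a background point (the same convention the `show_points` helper in this notebook uses). With `multimask_output=True`, `predict` returns several candidate masks plus a quality score for each. The sketch below builds those arrays and picks the top-scoring mask; the helper names are hypothetical, and the commented predictor calls assume the `sam` and `image` variables from this notebook, with an arbitrary example click:

```python
import numpy as np


def build_point_prompt(clicks):
    """Convert [(x, y, is_foreground), ...] into SAM-style prompt arrays."""
    coords = np.array([[x, y] for x, y, _ in clicks], dtype=np.float32)
    labels = np.array([1 if fg else 0 for _, _, fg in clicks], dtype=np.int32)
    return coords, labels


def best_mask(masks, scores):
    """Pick the candidate mask with the highest predicted quality score."""
    return masks[int(np.argmax(scores))]


# Assumed usage with the segment_anything predictor (not executed here):
# predictor = SamPredictor(sam)
# predictor.set_image(image)  # RGB uint8 array
# input_point, input_label = build_point_prompt([(525, 325, True)])
# masks, scores, logits = predictor.predict(
#     point_coords=input_point, point_labels=input_label, multimask_output=True
# )
# mask = best_mask(masks, scores)
```

Taking the highest score is only a default; for clothing, a lower-scoring candidate sometimes matches the garment boundary better, so it is worth plotting all of them.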

# Pick one of your generated images
chosen_image = "images/01_generations/generation_2.png"
# Function to display mask using matplotlib
def show_mask(mask, ax):
    color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


# Function to display where we've "clicked"
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(
        pos_points[:, 0],
        pos_points[:, 1],
        color="green",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )
    ax.scatter(
        neg_points[:, 0],
        neg_points[:, 1],
        color="red",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )
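
# Load the chosen image with OpenCV, which returns None (rather than raising) for a bad path
image = cv2.imread(chosen_image)
if image is None:
    raise FileNotFoundError(f"Could not read {chosen_image}")
# OpenCV loads images in BGR order; convert to RGB for matplotlib and SAM
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)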

# Display our chosen image
plt.figure(figsize=(10, 10))
plt.imshow(image)
plt.axis("on")
plt.show()
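OpenCV reads images in BGR channel order, which is why the image is converted to RGB before plotting. If you prefer to avoid OpenCV for loading, Pillow (already imported in this notebook) can produce the same RGB array and raises an error for a missing file. A sketch, where `load_rgb` and `bgr_to_rgb` are hypothetical helpers; the second one shows that the BGR to RGB conversion for a 3-channel image amounts to reversing the channel axis:

```python
import numpy as np
from PIL import Image


def load_rgb(path):
    """Load an image as an RGB uint8 array of shape (H, W, 3)."""
    with Image.open(path) as img:  # raises FileNotFoundError for a bad path
        return np.array(img.convert("RGB"))


def bgr_to_rgb(image_bgr):
    """Reverse the channel axis, matching cv2.cvtColor(..., cv2.COLOR_BGR2RGB)."""
    return np.ascontiguousarray(image_bgr[..., ::-1])
```

`np.ascontiguousarray` is there because a reversed-stride view can trip up libraries that expect contiguous memory.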