
Running the SAM model using Elixir and Ortex

Elixir
Ortex
Nx
ONNX

I’m currently working at an internship where some of the sales team spend a lot of time removing the background from images using the very famous removebg website. All of that is done manually and takes quite some time. (removebg does offer an API, but it’s expensive.)

So I looked into what’s possible and found out about Meta’s Segment Anything Model (SAM).

SAM - select object example

It’s an image segmentation model that can take a bounding box as input and select the object within that box. It can be used to produce a mask, and once that mask is applied to the original image, boom: we have our own removebg.

Running models on the BEAM

As I had never worked with machine learning models in Elixir (or any other language, actually), I decided to give it a try and run the model on the BEAM!

Elixir’s bumblebee library offers a high-level API similar to Python’s transformers. It includes some image-related APIs but doesn’t support image segmentation yet, so it won’t be that easy!
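
For flavor, here’s roughly what Bumblebee’s high-level API looks like for a task it does support; the model choice and variable names here are just examples:

# Image classification with Bumblebee's serving API
{:ok, model_info} = Bumblebee.load_model({:hf, "microsoft/resnet-50"})
{:ok, featurizer} = Bumblebee.load_featurizer({:hf, "microsoft/resnet-50"})

serving = Bumblebee.Vision.image_classification(model_info, featurizer)

# `image` is any {height, width, 3} u8 tensor
Nx.Serving.run(serving, image)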

Of course, the Nx team thought of that and provided us with a library called Ortex.

Ortex allows for easy loading and fast inference of ONNX models using different backends available to ONNX Runtime such as CUDA, TensorRT, Core ML, and ARM Compute Library.
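
If I’m reading the Ortex docs right, the execution providers are just an extra argument to Ortex.load; the path and provider list below are illustrative:

# Try CUDA first, fall back to the CPU provider
model = Ortex.load("./models/some_model.onnx", [:cuda, :cpu])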

For those unaware, ONNX is a universal format for machine learning models that can be exported from major ML libraries like PyTorch and TensorFlow.

Let’s start coding :)

If you want, you can directly run this livebook that @kip made out of our Elixir Forum discussion.

SAM and other image models typically begin by transforming the image into embeddings, which are then fed into the segmentation model. I’ll refer to these two parts as the encoder and decoder models.

Please note that this is my first time using Nx/Ortex/ONNX, so I might make silly mistakes.

I’m more or less trying to port this Jupyter notebook example to Livebook: https://github.com/facebookresearch/segment-anything

I’m using the ONNX models found here: https://huggingface.co/vietanhdev/segment-anything-onnx-models/tree/main

# Use EXLA as the default backend; the timing note further down shows why
Nx.global_default_backend(EXLA.Backend)
Nx.default_backend()

Loading the models

# SAM is split into two ONNX files: an image encoder and a mask decoder
model =
  Ortex.load("/Users/erisson/Documents/DEV/LEARNING/IA/SamOrtex/files/mobile_sam_encoder.onnx")

decoder =
  Ortex.load("/Users/erisson/Documents/DEV/LEARNING/IA/SamOrtex/files/mobile_decoder.onnx")

Image Encoding

  1. resize to a 1024x1024 image
  2. convert to a tensor
  3. normalize the tensor
  4. reshape to a (1, 3, 1024, 1024) tensor

image_input = Kino.Input.image("Uploaded Image")
%{file_ref: file_ref, format: :rgb, height: height, width: width} = Kino.Input.read(image_input)

content = file_ref |> Kino.Input.file_path() |> File.read!()

# Rebuild a {height, width, 3} u8 tensor from the raw RGB bytes
image_tensor =
  Nx.from_binary(content, :u8)
  |> Nx.reshape({height, width, 3})

# Resize to the 1024x1024 input size the encoder expects
resized_tensor = NxImage.resize(image_tensor, {1024, 1024})

original_image = Kino.Image.new(image_tensor)
original_label = Kino.Markdown.new("**Original image**")

resized_image = Kino.Image.new(resized_tensor)
resized_label = Kino.Markdown.new("**Resized image**")

Kino.Layout.grid([
  Kino.Layout.grid([original_image, original_label], boxed: true),
  Kino.Layout.grid([resized_image, resized_label], boxed: true)
])

# Convert to f32 before normalizing
tensor =
  resized_tensor
  |> Nx.as_type(:f32)

# Mean and std values copied from transformers.js
mean = Nx.tensor([123.675, 116.28, 103.53])
std = Nx.tensor([58.395, 57.12, 57.375])

normalized_tensor =
  tensor
  |> NxImage.normalize(mean, std)

# This took 3+ seconds on my M1 Mac with the default binary backend;
# setting EXLA as the backend brought it under 20ms

# Run the image encoder; the broadcast doubles as a shape check, since
# the resized, normalized input should already be {1024, 1024, 3}
{image_embeddings} = Ortex.run(model, Nx.broadcast(normalized_tensor, {1024, 1024, 3}))
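
Quick sanity check before moving on (the expected shape is my assumption, inferred from what the decoder consumes below):

# Should print {1, 256, 64, 64}, matching the decoder's first input
IO.inspect(Nx.shape(image_embeddings), label: "embeddings shape")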

Prompt encoding & mask generation

# Prepare the decoder inputs.
# x/y coordinates (in our image) of the object we want to cut out
input_point = Nx.tensor([[320, 240]]) |> Nx.as_type(:f32) |> Nx.reshape({1, 1, 2})

# Label 1 marks a foreground point; labels 2 and 3 are for box start/end points
input_label = Nx.tensor([1]) |> Nx.reshape({1, 1}) |> Nx.as_type(:f32)

# Filled with 0, not used here
mask_input = Nx.broadcast(0, {1, 1, 256, 256}) |> Nx.as_type(:f32)

# Flag telling the decoder that mask_input is unused
has_mask = Nx.broadcast(0, {1}) |> Nx.as_type(:f32)

original_image_dim = Nx.tensor([height, width]) |> Nx.as_type(:f32)

# Run the decoder. Only the first output (the masks) is used here; the
# other two are IoU score predictions and low-resolution mask logits.
# The broadcasts expand our single point/label to the expected shapes.
{mask, _scores, _low_res_logits} =
  Ortex.run(decoder, {
    Nx.broadcast(image_embeddings, {1, 256, 64, 64}),
    Nx.broadcast(input_point, {1, 2, 2}),
    Nx.broadcast(input_label, {1, 2}),
    Nx.broadcast(mask_input, {1, 1, 256, 256}),
    Nx.broadcast(has_mask, {1}),
    Nx.broadcast(original_image_dim, {2})
  })
mask =
  mask
  |> Nx.backend_transfer()
  # Threshold the logits: >= 0 becomes 255 (object), < 0 becomes 0 (background)
  |> Nx.greater_equal(0)
  |> Nx.multiply(255)

# Don't change height/width here: the decoder already upscaled the mask
# to the original image size
mask = mask[0][0] |> Nx.as_type(:u8) |> Nx.reshape({height, width, 1})

mask_image = Kino.Image.new(mask)
mask_label = Kino.Markdown.new("**Image mask**")

Kino.Layout.grid([
  Kino.Layout.grid([original_image, original_label], boxed: true),
  Kino.Layout.grid([mask_image, mask_label], boxed: true)
])
StbImage.from_nx(mask)
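
And that closes the removebg loop: with the mask at the original image size, we can use it as an alpha channel to cut the object out. A minimal sketch, assuming the image_tensor and mask from above (the output filename is just an example):

# Use the mask as an alpha channel: background (0) becomes transparent,
# the selected object (255) stays opaque
rgba = Nx.concatenate([image_tensor, mask], axis: 2)

rgba
|> StbImage.from_nx()
|> StbImage.write_file("cutout.png")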