gsamservice/imagesegmentation.py

import numpy as np
import supervision as sv
import cv2
import PIL.Image
from scipy import ndimage
from typing import List, Tuple, Union, Optional
import torch
from torchvision.ops import box_convert
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from grounding_dino.groundingdino.util.inference import load_model, load_image, predict
import grounding_dino.groundingdino.datasets.transforms as T


class ImageSegmentation:
    def __init__(self):
        # select the device for computation
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        print(f"using device: {self.device}")

        if self.device.type == "cuda":
            # NOTE: somehow this didn't work locally inside a docker container
            # use bfloat16 for the entire notebook
            # original:
            # torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
            # might work without it, or with this:
            # torch.autocast("cuda", dtype=torch.float16).__enter__()
            # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
            if torch.cuda.get_device_properties(0).major >= 8:
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True

        # SAM2 image predictor
        sam2_checkpoint = "checkpoints/sam2.1_hiera_large.pt"
        model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
        self.sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=self.device)
        self.sam2_predictor = SAM2ImagePredictor(self.sam2_model)

        # Grounding DINO model for text-prompted box detection
        grounding_dino_config = (
            "grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
        )
        grounding_dino_checkpoint = "gdino_checkpoints/groundingdino_swint_ogc.pth"
        self.box_threshold = 0.35
        self.text_threshold = 0.25
        self.grounding_model = load_model(
            model_config_path=grounding_dino_config,
            model_checkpoint_path=grounding_dino_checkpoint,
            device=self.device,
        )

    def segment_image_from_text(self, pil_image: PIL.Image.Image, text: str):
        """Generate segmentation masks from image and text description using Grounding DINO + SAM2.

        Args:
            pil_image: PIL image that should be segmented
            text: object description(s) to be segmented
        Returns:
            Iterable of N tuples (mask, score, mask_bbox, input_box, class_name, center_of_mass)
                with mask (HxW), float score, boxes (x1, y1, x2, y2), class name string
                and center of mass (x, y)
            Annotated result image
        """
        # image preparation taken from load_image() in Grounded-SAM-2/grounding_dino/groundingdino/util/inference.py
        pil_image = pil_image.convert("RGB")
        transform = T.Compose(
            [
                T.RandomResize([800], max_size=1333),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        image = np.asarray(pil_image)
        image_transformed, _ = transform(pil_image, None)

        # set the image for sam2
        self.sam2_predictor.set_image(image)

        # predict the bounding boxes
        boxes, confidences, labels = predict(
            model=self.grounding_model,
            image=image_transformed,
            caption=text,
            box_threshold=self.box_threshold,
            text_threshold=self.text_threshold,
            device=self.device,
        )
        if boxes is None or len(boxes) < 1:
            return [], image

        # process the box prompt for SAM 2: scale normalized cxcywh boxes to absolute xyxy
        h, w, _ = image.shape
        boxes = boxes * torch.Tensor([w, h, w, h])
        input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()

        # NOTE: somehow this didn't work locally inside a docker container
        # torch.autocast(device_type=self.device, dtype=torch.bfloat16).__enter__()
        # if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
        #     # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
        #     torch.backends.cuda.matmul.allow_tf32 = True
        #     torch.backends.cudnn.allow_tf32 = True

        masks, scores, logits = self.sam2_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_boxes,
            multimask_output=False,
        )
        # convert the shape to (n, H, W)
        if masks.ndim == 4:
            masks = masks.squeeze(1)

        confidences = confidences.numpy().tolist()
        class_names = labels
        class_ids = np.array(list(range(len(class_names))))
        labels = [
            f"{class_name} {confidence:.2f}"
            for class_name, confidence in zip(class_names, confidences)
        ]
        return zip(
            masks,
            scores,
            ImageSegmentation._bboxes_from_masks(masks),
            input_boxes,
            class_names,
            ImageSegmentation._centers_of_mass_from_masks(masks),
        ), ImageSegmentation._create_result_image(
            image, masks, bboxes=input_boxes, labels=labels, class_ids=class_ids
        )

    def segment_image_from_bbox(self, pil_image: PIL.Image.Image, bboxes: np.ndarray):
        """Generate segmentation masks from image and bounding box coordinates using SAM2.

        Args:
            pil_image: PIL image that should be segmented
            bboxes: Nx4 array of bounding boxes of objects to be segmented (x1, y1, x2, y2)
        Returns:
            Iterable of N tuples (mask, score, mask_bbox) with mask (HxW), float score
                and mask bounding box (x1, y1, x2, y2)
            Annotated result image
        """
        image = np.asarray(pil_image.convert("RGB"))
        self.sam2_predictor.set_image(image)
        masks, scores, logits = self.sam2_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=np.array(bboxes),
            multimask_output=False,
        )
        # convert the shape to (n, H, W)
        if masks.ndim == 4:
            masks = masks.squeeze(1)
        return zip(
            masks, scores, ImageSegmentation._bboxes_from_masks(masks)
        ), ImageSegmentation._create_result_image(image, masks, bboxes=bboxes)

    def segment_image_from_points(self, pil_image: PIL.Image.Image, points: np.ndarray):
        """Generate segmentation masks from image and point coordinates with include/exclude labels using SAM2.

        Args:
            pil_image: PIL image that should be segmented
            points: Nx3 array of prompt points for the objects to be segmented, each (x, y, include)
        Returns:
            Iterable of N tuples (mask, score, mask_bbox) with mask (HxW), float score
                and mask bounding box (x1, y1, x2, y2)
            Annotated result image
        """
        image = np.asarray(pil_image.convert("RGB"))
        self.sam2_predictor.set_image(image)
        # convert points to coordinates and labels arrays
        coords = np.array([[point[0], point[1]] for point in points])
        labels = np.array([1 if point[2] else 0 for point in points])
        masks, scores, logits = self.sam2_predictor.predict(
            point_coords=coords,
            point_labels=labels,
            multimask_output=False,
        )
        # convert the shape to (n, H, W)
        if masks.ndim == 4:
            masks = masks.squeeze(1)
        return zip(
            masks, scores, ImageSegmentation._bboxes_from_masks(masks)
        ), ImageSegmentation._create_result_image(image, masks, points=points)

    @staticmethod
    def _create_result_image(
        image: np.ndarray,
        masks: np.ndarray,
        bboxes: np.ndarray = [],
        labels: list = [],
        class_ids: np.ndarray = None,
        points: np.ndarray = [],
    ):
        """Create annotated result image with masks, bounding boxes, labels, and points overlaid.

        Args:
            image: HxWx3 RGB image array that was segmented
            masks: NxHxW array of object mask(s)
            bboxes: (optional) Nx4 array of object bounding box(es) (x1, y1, x2, y2)
            labels: (optional) list of N object label string(s)
            class_ids: (optional) array of N class id(s) used for the color lookup
            points: (optional) Nx3 array of object point(s) with include/exclude flags (x, y, include)
        Returns:
            Annotated result image as HxWx3 numpy array
        """
        img = np.array(image)

        # we have to define the bboxes in the detections even though we might not show them
        detection_bboxes = bboxes
        if bboxes is None or len(bboxes) == 0:
            detection_bboxes = ImageSegmentation._bboxes_from_masks(masks)
        detections = sv.Detections(
            xyxy=detection_bboxes,  # (n, 4)
            mask=masks.astype(bool),  # (n, h, w)
            class_id=class_ids,
        )
        annotated_frame = img.copy()

        # if there are no class ids (i.e. when using without Grounding DINO) we need to derive the
        # color lookup from the detection index instead of the class
        colorlookup = sv.ColorLookup.INDEX
        if class_ids is not None and len(class_ids) > 0:
            colorlookup = sv.ColorLookup.CLASS

        # points
        if points is not None:
            for x, y, include in points:
                # green for include (True), red for exclude (False); the frame is in RGB order
                color = (0, 255, 0) if include else (255, 0, 0)
                center = (int(x), int(y))
                cv2.circle(annotated_frame, center, 8, (0, 0, 0), -1)  # outer ring
                cv2.circle(annotated_frame, center, 5, color, -1)  # filled circle

        # bboxes
        if len(bboxes) > 0:
            box_annotator = sv.BoxAnnotator(color_lookup=colorlookup)
            annotated_frame = box_annotator.annotate(
                scene=annotated_frame, detections=detections
            )

        # labels
        if labels is not None and len(labels) > 0:
            label_annotator = sv.LabelAnnotator(color_lookup=colorlookup)
            annotated_frame = label_annotator.annotate(
                scene=annotated_frame, detections=detections, labels=labels
            )

        # masks
        mask_annotator = sv.MaskAnnotator(color_lookup=colorlookup)
        annotated_frame = mask_annotator.annotate(
            scene=annotated_frame, detections=detections
        )
        return annotated_frame

    @staticmethod
    def _bboxes_from_masks(masks: np.ndarray):
        """Create bounding boxes for the provided masks

        Args:
            masks: NxHxW array of object mask(s)
        Returns:
            bboxes: Nx4 array of mask bounding boxes (x1, y1, x2, y2)
        """
        bboxes = []
        for mask in masks:
            ys, xs = np.nonzero(mask)
            if xs.size > 0 and ys.size > 0:
                bboxes.append(
                    [
                        int(np.min(xs)),
                        int(np.min(ys)),
                        int(np.max(xs)),
                        int(np.max(ys)),
                    ]
                )
            else:
                # empty mask: fall back to a degenerate box
                bboxes.append([0, 0, 0, 0])
        return np.array(bboxes)

    @staticmethod
    def _centers_of_mass_from_masks(masks: np.ndarray):
        """Calculate centers of mass for the provided masks

        Args:
            masks: NxHxW array of object mask(s)
        Returns:
            centers_of_mass: Nx2 array of mask centers of mass (x, y)
        """
        return np.array(
            [[x, y] for mask in masks for y, x in [ndimage.center_of_mass(mask)]]
        )
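

# The block below is not part of the original module: it is a minimal usage sketch
# showing how the class above might be exercised. The image path "example.jpg", the
# text prompt, and the point coordinates are placeholders / assumptions; adjust them
# to your own data and checkpoints before running.
if __name__ == "__main__":
    from PIL import Image

    segmenter = ImageSegmentation()
    img = Image.open("example.jpg")  # hypothetical input image

    # text-prompted segmentation: Grounding DINO boxes -> SAM2 masks
    results, annotated = segmenter.segment_image_from_text(img, "car. person.")
    for mask, score, bbox, input_box, class_name, center in results:
        print(class_name, score, bbox, center)

    # point-prompted segmentation: one include point at (x=100, y=150)
    point_results, point_annotated = segmenter.segment_image_from_points(
        img, np.array([[100, 150, 1]])
    )

    # box-prompted segmentation: one box in (x1, y1, x2, y2) pixel coordinates
    bbox_results, bbox_annotated = segmenter.segment_image_from_bbox(
        img, np.array([[50, 50, 200, 200]])
    )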