import numpy as np
import supervision as sv
import cv2
import PIL
from scipy import ndimage
from typing import List, Tuple, Union, Optional

import torch
from torchvision.ops import box_convert

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

from grounding_dino.groundingdino.util.inference import load_model, load_image, predict
import grounding_dino.groundingdino.datasets.transforms as T


class ImageSegmentation:
    def __init__(self):
        # select the device for computation
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        print(f"using device: {self.device}")

        if self.device.type == "cuda":
            # NOTE: somehow this didn't work locally inside a docker container
            # use bfloat16 for the entire notebook
            # original:
            # torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
            # might work without it, or with this:
            # torch.autocast("cuda", dtype=torch.float16).__enter__()
            # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
            if torch.cuda.get_device_properties(0).major >= 8:
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True

        sam2_checkpoint = "checkpoints/sam2.1_hiera_large.pt"
        model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

        self.sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=self.device)
        self.sam2_predictor = SAM2ImagePredictor(self.sam2_model)

        grounding_dino_config = (
            "grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
        )
        grounding_dino_checkpoint = "gdino_checkpoints/groundingdino_swint_ogc.pth"
        self.box_threshold = 0.35
        self.text_threshold = 0.25

        self.grounding_model = load_model(
            model_config_path=grounding_dino_config,
            model_checkpoint_path=grounding_dino_checkpoint,
            device=self.device,
        )

    def segment_image_from_text(self, pil_image: PIL.Image.Image, text: str):
        """Generate segmentation masks from image and text description using Grounding DINO + SAM2.

        Args:
            pil_image: PIL image that should be segmented
            text: object description(s) to be segmented

        Returns:
            Iterable of N tuples (mask, score, mask_bbox, input_bbox, class_name, center_of_mass)
            with mask (HxW), float score, boxes in (x1, y1, x2, y2) format and center of mass (x, y);
            an empty list if nothing was detected
            Annotated result image
        """

        # image preparation taken from load_image() in Grounded-SAM-2/grounding_dino/groundingdino/util/inference.py
        pil_image = pil_image.convert("RGB")
        transform = T.Compose(
            [
                T.RandomResize([800], max_size=1333),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        image = np.asarray(pil_image)
        image_transformed, _ = transform(pil_image, None)

        # set the image for sam2
        self.sam2_predictor.set_image(image)

        # predict the bounding boxes
        boxes, confidences, labels = predict(
            model=self.grounding_model,
            image=image_transformed,
            caption=text,
            box_threshold=self.box_threshold,
            text_threshold=self.text_threshold,
            device=self.device,
        )

        if boxes is None or len(boxes) < 1:
            return [], image

        # process the box prompt for SAM 2: scale the normalized cxcywh boxes to
        # pixel coordinates and convert them to xyxy format
        h, w, _ = image.shape
        boxes = boxes * torch.Tensor([w, h, w, h])
        input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()

        # NOTE: somehow this didn't work locally inside a docker container
        # torch.autocast(device_type=self.device, dtype=torch.bfloat16).__enter__()
        # if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
        #     # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
        #     torch.backends.cuda.matmul.allow_tf32 = True
        #     torch.backends.cudnn.allow_tf32 = True

        masks, scores, logits = self.sam2_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_boxes,
            multimask_output=False,
        )

        # convert the shape to (n, H, W)
        if masks.ndim == 4:
            masks = masks.squeeze(1)

        confidences = confidences.numpy().tolist()
        class_names = labels

        # one class id per detection so the annotators assign distinct colors
        class_ids = np.array(list(range(len(class_names))))

        labels = [
            f"{class_name} {confidence:.2f}"
            for class_name, confidence in zip(class_names, confidences)
        ]

        return zip(
            masks,
            scores,
            ImageSegmentation._bboxes_from_masks(masks),
            input_boxes,
            class_names,
            ImageSegmentation._centers_of_mass_from_masks(masks),
        ), ImageSegmentation._create_result_image(
            image, masks, bboxes=input_boxes, labels=labels, class_ids=class_ids
        )

    def segment_image_from_bbox(self, pil_image: PIL.Image.Image, bboxes: np.ndarray):
        """Generate segmentation masks from image and bounding box coordinates using SAM2.

        Args:
            pil_image: PIL image that should be segmented
            bboxes: Nx4 array of bounding boxes of objects to be segmented (x1, y1, x2, y2)

        Returns:
            Iterable of N tuples (mask, score, mask_bbox) with mask (HxW), float score
            and bounding box (x1, y1, x2, y2)
            Annotated result image
        """
        image = np.asarray(pil_image.convert("RGB"))
        self.sam2_predictor.set_image(image)

        masks, scores, logits = self.sam2_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=np.array(bboxes),
            multimask_output=False,
        )

        # convert the shape to (n, H, W)
        if masks.ndim == 4:
            masks = masks.squeeze(1)

        return zip(
            masks, scores, ImageSegmentation._bboxes_from_masks(masks)
        ), ImageSegmentation._create_result_image(image, masks, bboxes=bboxes)

    def segment_image_from_points(self, pil_image: PIL.Image.Image, points: np.ndarray):
        """Generate segmentation masks from image and point coordinates with include/exclude labels using SAM2.

        Args:
            pil_image: PIL image that should be segmented
            points: Nx3 array of points with include/exclude flags of the objects to be segmented (x, y, include)

        Returns:
            Iterable of N tuples (mask, score, mask_bbox) with mask (HxW), float score
            and bounding box (x1, y1, x2, y2)
            Annotated result image
        """
        image = np.asarray(pil_image.convert("RGB"))
        self.sam2_predictor.set_image(image)

        # convert points to coordinate and label arrays (1 = include, 0 = exclude)
        coords = np.array([[point[0], point[1]] for point in points])
        labels = np.array([1 if point[2] else 0 for point in points])

        masks, scores, logits = self.sam2_predictor.predict(
            point_coords=coords,
            point_labels=labels,
            multimask_output=False,
        )

        # convert the shape to (n, H, W)
        if masks.ndim == 4:
            masks = masks.squeeze(1)

        return zip(
            masks, scores, ImageSegmentation._bboxes_from_masks(masks)
        ), ImageSegmentation._create_result_image(image, masks, points=points)

    @staticmethod
    def _create_result_image(
        image: np.ndarray,
        masks: np.ndarray,
        bboxes: np.ndarray = [],
        labels: np.ndarray = [],
        class_ids: np.ndarray = None,
        points: np.ndarray = [],
    ):
        """Create annotated result image with masks, bounding boxes, labels, and points overlaid.

        Args:
            image: RGB image (HxWx3 numpy array) that was segmented
            masks: NxHxW array of object mask(s)
            bboxes: (optional) Nx4 array of object bounding box(es) (x1, y1, x2, y2)
            labels: (optional) Nx1 array of object label(s)
            class_ids: (optional) Nx1 array of class id(s) used for color lookup
            points: (optional) Nx3 array of object point(s) with include/exclude flags (x, y, include)

        Returns:
            Annotated result image as an RGB numpy array
        """
        img = np.array(image)

        # we have to define the bboxes in the detections even though we might not show them
        detection_bboxes = bboxes
        if bboxes is None or len(bboxes) == 0:
            detection_bboxes = ImageSegmentation._bboxes_from_masks(masks)

        detections = sv.Detections(
            xyxy=detection_bboxes,  # (n, 4)
            mask=masks.astype(bool),  # (n, h, w)
            class_id=class_ids,
        )
        annotated_frame = img.copy()

        # if there are no class ids (i.e. when using without Grounding DINO) we need to
        # derive the color lookup from the detection index instead
        colorlookup = sv.ColorLookup.INDEX
        if class_ids is not None and len(class_ids) > 0:
            colorlookup = sv.ColorLookup.CLASS

        # points
        if points is not None:
            for x, y, include in points:
                # green for include (True), red for exclude (False); the frame is RGB
                color = (0, 255, 0) if include else (255, 0, 0)
                cv2.circle(annotated_frame, (int(x), int(y)), 8, (0, 0, 0), -1)  # outer ring
                cv2.circle(annotated_frame, (int(x), int(y)), 5, color, -1)  # filled circle

        # bboxes
        if bboxes is not None and len(bboxes) > 0:
            box_annotator = sv.BoxAnnotator(color_lookup=colorlookup)
            annotated_frame = box_annotator.annotate(
                scene=annotated_frame, detections=detections
            )

        # labels
        if labels is not None and len(labels) > 0:
            label_annotator = sv.LabelAnnotator(color_lookup=colorlookup)
            annotated_frame = label_annotator.annotate(
                scene=annotated_frame, detections=detections, labels=labels
            )

        # masks
        mask_annotator = sv.MaskAnnotator(color_lookup=colorlookup)
        annotated_frame = mask_annotator.annotate(
            scene=annotated_frame, detections=detections
        )

        return annotated_frame

    @staticmethod
    def _bboxes_from_masks(masks: np.ndarray):
        """Create bounding boxes for the provided masks

        Args:
            masks: NxHxW array of object mask(s)

        Returns:
            bboxes: Nx4 array of mask bounding boxes (x1, y1, x2, y2); [0, 0, 0, 0] for empty masks
        """

        bboxes = []
        for mask in masks:
            # indices of all non-zero mask pixels: (row indices, column indices)
            ys, xs = np.nonzero(mask)
            if ys.size > 0 and xs.size > 0:
                bboxes.append(
                    [
                        int(np.min(xs)),
                        int(np.min(ys)),
                        int(np.max(xs)),
                        int(np.max(ys)),
                    ]
                )
            else:
                bboxes.append([0, 0, 0, 0])

        return np.array(bboxes)

    @staticmethod
    def _centers_of_mass_from_masks(masks: np.ndarray):
        """Calculate centers of mass for the provided masks

        Args:
            masks: NxHxW array of object mask(s)

        Returns:
            centers_of_mass: Nx2 array of mask centers of mass (x, y)
        """

        # ndimage.center_of_mass returns (row, col), i.e. (y, x), so swap to (x, y)
        return np.array(
            [[x, y] for mask in masks for y, x in [ndimage.center_of_mass(mask)]]
        )
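

# --- Usage sketch (illustrative, not part of the original service code) ---
# A minimal example of how the class above could be exercised. It assumes the
# SAM2 / Grounding DINO checkpoints hard-coded in __init__ are present;
# "example.jpg", the text prompt, and the box coordinates are placeholders to
# be replaced with your own data.
if __name__ == "__main__":
    segmenter = ImageSegmentation()
    img = PIL.Image.open("example.jpg")

    # text-prompted segmentation: Grounding DINO boxes -> SAM2 masks
    results, annotated = segmenter.segment_image_from_text(img, "car. person.")
    for mask, score, mask_bbox, input_bbox, class_name, center in results:
        print(class_name, score, mask_bbox, center)

    # the annotated overlay is an RGB numpy array and can be saved via PIL
    PIL.Image.fromarray(annotated).save("annotated_text.jpg")

    # box-prompted segmentation with a single (x1, y1, x2, y2) box
    results, annotated = segmenter.segment_image_from_bbox(
        img, np.array([[10, 10, 200, 200]])
    )
    PIL.Image.fromarray(annotated).save("annotated_bbox.jpg")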