import numpy as np
import supervision as sv
import cv2
import PIL.Image
from scipy import ndimage
import torch
from torchvision.ops import box_convert

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from grounding_dino.groundingdino.util.inference import load_model, load_image, predict
import grounding_dino.groundingdino.datasets.transforms as T


class ImageSegmentation:
    def __init__(self):
        # select the device for computation
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        print(f"using device: {self.device}")

        if self.device.type == "cuda":
            # NOTE: somehow this didn't work locally inside a docker container
            # use bfloat16 for the whole pipeline
            # original:
            # torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
            # it might also work without autocast, or with float16:
            # torch.autocast("cuda", dtype=torch.float16).__enter__()

            # turn on TF32 for Ampere GPUs
            # (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
            if torch.cuda.get_device_properties(0).major >= 8:
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True

        sam2_checkpoint = "checkpoints/sam2.1_hiera_large.pt"
        model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
        self.sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=self.device)
        self.sam2_predictor = SAM2ImagePredictor(self.sam2_model)

        grounding_dino_config = (
            "grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
        )
        grounding_dino_checkpoint = "gdino_checkpoints/groundingdino_swint_ogc.pth"
        self.box_threshold = 0.35
        self.text_threshold = 0.25
        self.grounding_model = load_model(
            model_config_path=grounding_dino_config,
            model_checkpoint_path=grounding_dino_checkpoint,
            device=self.device,
        )

    def segment_image_from_text(self, pil_image: PIL.Image.Image, text: str):
        """Generate segmentation masks from an image and a text description using Grounding DINO + SAM2.
        Args:
            pil_image: PIL image that should be segmented
            text: text prompt describing the object(s) to segment

        Returns:
            Iterator of per-object tuples
                (mask (HxW), score, mask bbox, input bbox, class name, center of mass)
            Annotated result image
        """
        # image preparation taken from load_image() in
        # Grounded-SAM-2/grounding_dino/groundingdino/util/inference.py
        pil_image = pil_image.convert("RGB")
        transform = T.Compose(
            [
                T.RandomResize([800], max_size=1333),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        image = np.asarray(pil_image)
        image_transformed, _ = transform(pil_image, None)

        # set the image for SAM 2
        self.sam2_predictor.set_image(image)

        # predict the bounding boxes with Grounding DINO
        boxes, confidences, labels = predict(
            model=self.grounding_model,
            image=image_transformed,
            caption=text,
            box_threshold=self.box_threshold,
            text_threshold=self.text_threshold,
            device=self.device,
        )
        if boxes is None or len(boxes) < 1:
            return [], image

        # process the box prompt for SAM 2: scale to pixel coordinates and convert to xyxy
        h, w, _ = image.shape
        boxes = boxes * torch.Tensor([w, h, w, h])
        input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()

        # NOTE: somehow this didn't work locally inside a docker container
        # torch.autocast(device_type=self.device, dtype=torch.bfloat16).__enter__()
        # if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
        #     # turn on TF32 for Ampere GPUs
        #     # (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
        #     torch.backends.cuda.matmul.allow_tf32 = True
        #     torch.backends.cudnn.allow_tf32 = True

        masks, scores, logits = self.sam2_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_boxes,
            multimask_output=False,
        )
        # convert the shape to (n, H, W)
        if masks.ndim == 4:
            masks = masks.squeeze(1)

        confidences = confidences.numpy().tolist()
        class_names = labels
        class_ids = np.array(list(range(len(class_names))))
        labels = [
            f"{class_name} {confidence:.2f}"
            for class_name, confidence in zip(class_names, confidences)
        ]

        return zip(
            masks,
            scores,
            ImageSegmentation._bboxes_from_masks(masks),
            input_boxes,
            class_names,
            ImageSegmentation._centers_of_mass_from_masks(masks),
        ), ImageSegmentation._create_result_image(
            image, masks, bboxes=input_boxes, labels=labels, class_ids=class_ids
        )

    def segment_image_from_bbox(self, pil_image: PIL.Image.Image, bboxes: np.ndarray):
        """Generate segmentation masks from an image and bounding box coordinates using SAM2.

        Args:
            pil_image: PIL image that should be segmented
            bboxes: Nx4 array of bounding boxes of objects to be segmented (x1, y1, x2, y2)

        Returns:
            Iterator of per-object tuples (mask (HxW), score, mask bbox)
            Annotated result image
        """
        image = np.asarray(pil_image.convert("RGB"))
        self.sam2_predictor.set_image(image)

        masks, scores, logits = self.sam2_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=np.array(bboxes),
            multimask_output=False,
        )
        # convert the shape to (n, H, W)
        if masks.ndim == 4:
            masks = masks.squeeze(1)

        return zip(
            masks, scores, ImageSegmentation._bboxes_from_masks(masks)
        ), ImageSegmentation._create_result_image(image, masks, bboxes=bboxes)

    def segment_image_from_points(self, pil_image: PIL.Image.Image, points: np.ndarray):
        """Generate segmentation masks from an image and point coordinates with include/exclude labels using SAM2.
        Args:
            pil_image: PIL image that should be segmented
            points: Nx3 array of points with include/exclude flags of the objects to be segmented (x, y, include)

        Returns:
            Iterator of per-object tuples (mask (HxW), score, mask bbox)
            Annotated result image
        """
        image = np.asarray(pil_image.convert("RGB"))
        self.sam2_predictor.set_image(image)

        # convert points to coordinate and label arrays
        coords = np.array([[point[0], point[1]] for point in points])
        labels = np.array([1 if point[2] else 0 for point in points])

        masks, scores, logits = self.sam2_predictor.predict(
            point_coords=coords,
            point_labels=labels,
            multimask_output=False,
        )
        # convert the shape to (n, H, W)
        if masks.ndim == 4:
            masks = masks.squeeze(1)

        return zip(
            masks, scores, ImageSegmentation._bboxes_from_masks(masks)
        ), ImageSegmentation._create_result_image(image, masks, points=points)

    @staticmethod
    def _create_result_image(
        pil_image: PIL.Image.Image,
        masks: np.ndarray,
        bboxes: np.ndarray = [],
        labels: np.ndarray = [],
        class_ids: np.ndarray = None,
        points: np.ndarray = [],
    ):
        """Create an annotated result image with masks, bounding boxes, labels, and points overlaid.

        Args:
            pil_image: image to annotate
            masks: NxHxW array of object mask(s)
            bboxes: (optional) Nx4 array of object bounding box(es) (x1, y1, x2, y2)
            labels: (optional) Nx1 array of object label(s)
            class_ids: (optional) Nx1 array of object class id(s) used for coloring
            points: (optional) Nx3 array of object point(s) with include/exclude flags (x, y, include)

        Returns:
            Annotated result image (HxWx3 numpy array)
        """
        img = np.array(pil_image)

        # we have to define the bboxes in the detections even though we might not show them
        detection_bboxes = bboxes
        if bboxes is None or len(bboxes) == 0:
            detection_bboxes = ImageSegmentation._bboxes_from_masks(masks)

        detections = sv.Detections(
            xyxy=detection_bboxes,  # (n, 4)
            mask=masks.astype(bool),  # (n, h, w)
            class_id=class_ids,
        )

        annotated_frame = img.copy()
        # if there are no class ids (i.e. when used without Grounding DINO),
        # fall back to index-based color lookup
        colorlookup = sv.ColorLookup.INDEX
        if class_ids is not None and len(class_ids) > 0:
            colorlookup = sv.ColorLookup.CLASS

        # points
        if points is not None:
            for x, y, include in points:
                # green for include (True), red for exclude (False); OpenCV uses BGR
                color = (0, 255, 0) if include else (0, 0, 255)
                center = (int(x), int(y))  # cv2.circle expects integer pixel coordinates
                cv2.circle(annotated_frame, center, 8, (0, 0, 0), -1)  # outer ring
                cv2.circle(annotated_frame, center, 5, color, -1)  # filled circle

        # bboxes
        if len(bboxes) > 0:
            box_annotator = sv.BoxAnnotator(color_lookup=colorlookup)
            annotated_frame = box_annotator.annotate(
                scene=annotated_frame, detections=detections
            )

        # labels
        if labels is not None and len(labels) > 0:
            label_annotator = sv.LabelAnnotator(color_lookup=colorlookup)
            annotated_frame = label_annotator.annotate(
                scene=annotated_frame, detections=detections, labels=labels
            )

        # masks
        mask_annotator = sv.MaskAnnotator(color_lookup=colorlookup)
        annotated_frame = mask_annotator.annotate(
            scene=annotated_frame, detections=detections
        )
        return annotated_frame

    @staticmethod
    def _bboxes_from_masks(masks: np.ndarray):
        """Create bounding boxes for the provided masks.

        Args:
            masks: NxHxW array of object mask(s)

        Returns:
            bboxes: Nx4 array of mask bounding boxes (x1, y1, x2, y2)
        """
        bboxes = []
        for mask in masks:
            # indices of all non-zero (foreground) pixels: (row indices, column indices)
            mask_idx = np.where(mask != 0)
            if len(mask_idx[0]) != 0 and len(mask_idx[1]) != 0:
                bboxes.append(
                    [
                        int(np.min(mask_idx[1])),
                        int(np.min(mask_idx[0])),
                        int(np.max(mask_idx[1])),
                        int(np.max(mask_idx[0])),
                    ]
                )
            else:
                # empty mask: fall back to a degenerate box
                bboxes.append([0, 0, 0, 0])
        return np.array(bboxes)

    @staticmethod
    def _centers_of_mass_from_masks(masks: np.ndarray):
        """Calculate centers of mass for the provided masks.

        Args:
            masks: NxHxW array of object mask(s)

        Returns:
            centers_of_mass: Nx2 array of mask centers of mass (x, y)
        """
        # ndimage.center_of_mass returns (row, col) = (y, x); reorder to (x, y)
        return np.array(
            [[x, y] for mask in masks for y, x in [ndimage.center_of_mass(mask)]]
        )
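

# Minimal usage sketch (not part of the class above). It assumes the SAM2 and
# Grounding DINO checkpoints referenced in __init__ are present, and "example.jpg"
# is a hypothetical input image; adjust both to your setup.
if __name__ == "__main__":
    segmenter = ImageSegmentation()
    input_image = PIL.Image.open("example.jpg")

    # text-prompted segmentation (Grounding DINO boxes -> SAM2 masks);
    # Grounding DINO expects lower-case phrases terminated with "."
    objects, annotated = segmenter.segment_image_from_text(input_image, "car. person.")
    for mask, score, mask_bbox, input_box, class_name, center in objects:
        print(class_name, score, mask_bbox, center)
    PIL.Image.fromarray(annotated).save("annotated_text.png")

    # point-prompted segmentation (SAM2 only); each point is (x, y, include)
    point_objects, point_annotated = segmenter.segment_image_from_points(
        input_image, np.array([[120, 80, 1], [200, 150, 0]])
    )
    PIL.Image.fromarray(point_annotated).save("annotated_points.png")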