initial commit

This commit is contained in:
botastic 2025-08-29 09:53:07 +02:00
parent 27ffdf2668
commit 7b063e8357
9 changed files with 983 additions and 0 deletions

.gitignore vendored Normal file

@@ -0,0 +1 @@
example/out

Dockerfile Normal file

@@ -0,0 +1,57 @@
FROM docker.io/pytorch/pytorch:2.3.1-cuda12.1-cudnn8-devel
# arguments to build Docker Image using CUDA
ARG USE_CUDA=0
ARG TORCH_ARCH="7.0;7.5;8.0;8.6"
ENV AM_I_DOCKER=True
ENV BUILD_WITH_CUDA="${USE_CUDA}"
ENV TORCH_CUDA_ARCH_LIST="${TORCH_ARCH}"
ENV CUDA_HOME=/usr/local/cuda-12.1/
# ensure CUDA is correctly set up
ENV PATH=/usr/local/cuda-12.1/bin:${PATH}
ENV LD_LIBRARY_PATH=/usr/local/cuda-12.1/lib64:${LD_LIBRARY_PATH}
# install required packages and specific gcc/g++
RUN apt-get update && apt-get install --no-install-recommends wget ffmpeg=7:* \
libsm6=2:* libxext6=2:* git=1:* nano vim=2:* ninja-build gcc-10 g++-10 -y \
&& apt-get clean && apt-get autoremove && rm -rf /var/lib/apt/lists/*
ENV CC=gcc-10
ENV CXX=g++-10
# clone grounded sam2 repo
WORKDIR /home/appuser
RUN git clone https://github.com/IDEA-Research/Grounded-SAM-2
# download sam2 checkpoints
WORKDIR /home/appuser/Grounded-SAM-2/checkpoints
RUN bash download_ckpts.sh
# download grounding dino checkpoints
WORKDIR /home/appuser/Grounded-SAM-2/gdino_checkpoints
RUN bash download_ckpts.sh
WORKDIR /home/appuser/Grounded-SAM-2
# install essential Python packages
RUN python -m pip install --upgrade pip "setuptools>=62.3.0,<75.9" wheel numpy \
opencv-python transformers supervision pycocotools addict yapf timm
# install the SAM 2 package in editable mode
RUN python -m pip install -e .
# install grounding dino
RUN python -m pip install --no-build-isolation -e grounding_dino
# install the server dependencies
COPY requirements.txt requirements.txt
RUN python -m pip install -r requirements.txt
COPY app.py app.py
COPY imagesegmentation.py imagesegmentation.py
# RUN mkdir ../host
# start the server
ENTRYPOINT ["python", "app.py", "--log-level", "debug"]
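
To verify that the CUDA setup configured by the ENV lines above is actually visible to PyTorch inside the container (e.g. from a shell started with the `run-bash` target of the Makefile below), a quick check along these lines can help. This is a sketch for manual debugging, not part of the image:

```
# sanity check: run inside the container before starting the server
import torch

print("torch", torch.__version__, "built for CUDA", torch.version.cuda)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability(0)
    print(f"device: {torch.cuda.get_device_name(0)} (compute capability {major}.{minor})")
```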

Makefile Normal file

@@ -0,0 +1,43 @@
# Get version of CUDA and enable it for compilation if CUDA > 11.0
# This solves https://github.com/IDEA-Research/Grounded-Segment-Anything/issues/53
# and https://github.com/IDEA-Research/Grounded-Segment-Anything/issues/84
# when running in Docker
# Check if nvcc is installed
NVCC := $(shell which nvcc)
ifeq ($(NVCC),)
# NVCC not found
USE_CUDA := 0
NVCC_VERSION := "not installed"
else
NVCC_VERSION := $(shell nvcc --version | grep -oP 'release \K[0-9.]+')
USE_CUDA := $(shell echo "$(NVCC_VERSION) > 11" | bc -l)
endif
# Add the list of supported ARCHs
ifeq ($(USE_CUDA), 1)
TORCH_CUDA_ARCH_LIST := "7.0;7.5;8.0;8.6+PTX"
BUILD_MESSAGE := "Trying to build the image with CUDA support"
else
TORCH_CUDA_ARCH_LIST :=
BUILD_MESSAGE := "CUDA $(NVCC_VERSION) is not supported"
endif
build:
docker build --build-arg USE_CUDA=$(USE_CUDA) \
--build-arg TORCH_ARCH=$(TORCH_CUDA_ARCH_LIST) \
--progress=plain -t gsam2 .
run:
docker run -d --gpus all \
--restart unless-stopped \
--name=gsam2 \
--ipc=host -p 13337:13337 gsam2
run-bash:
docker run -it --rm --gpus all \
-v "${PWD}":/home/appuser/host \
--entrypoint bash \
--name=gsam2 \
--network=host \
--ipc=host gsam2
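
For reference, the nvcc detection that the Makefile performs with `grep` and `bc` can be sketched in Python as well, e.g. to debug the build arguments on a host without `bc`. This is illustrative only and not used by the Makefile:

```
# hedged sketch: reproduce the Makefile's USE_CUDA detection
import re
import shutil
import subprocess

nvcc = shutil.which("nvcc")
if nvcc is None:
    use_cuda, version = 0, "not installed"
else:
    out = subprocess.run([nvcc, "--version"], capture_output=True, text=True).stdout
    match = re.search(r"release (\d+\.\d+)", out)
    version = match.group(1) if match else "unknown"
    use_cuda = 1 if match and float(version) > 11 else 0
print(f"USE_CUDA={use_cuda} (nvcc {version})")
```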

README.md Normal file

@@ -0,0 +1,14 @@
# GSAM Service
A simple server exposing [Grounded SAM 2](https://github.com/IDEA-Research/Grounded-SAM-2) through a REST API.
## Usage
Build and run the container
```
make build
make run
```
You can then connect to the server on port 13337. Have a look at `example/main.go` for examples of how to call the provided endpoints; a minimal Python client is sketched below.
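
For a quick smoke test without Go, a client along these lines should also work. This is a sketch rather than part of the repo; it assumes the `requests` package on the client side, a local `truck.jpg`, and the `maskfromtext` endpoint defined in `app.py`:

```
import base64
import requests

with open("truck.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

resp = requests.post(
    "http://localhost:13337/gsam2/image/maskfromtext",
    json={"image": image_b64, "text": "truck. tire."},
    timeout=60,
)
resp.raise_for_status()
result = resp.json()
print(f"received {len(result['masks'])} masks")
for m in result["masks"]:
    print(m["class_name"], m["score"], m["bbox"])
```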

app.py Normal file

@@ -0,0 +1,303 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Tuple, Union, Optional
import base64
import io
from PIL import Image
import numpy as np
import cv2
from imagesegmentation import ImageSegmentation
app = FastAPI(
title="GSAM2 API",
description="Grounded SAM 2 Image Segmentation API",
version="1.0.0",
)
segmentation_model = ImageSegmentation()
# pydantic models for request validation
class Point(BaseModel):
x: int
y: int
include: bool # True for include, False for exclude
class BoundingBox(BaseModel):
upper_left: Tuple[int, int] # (x, y) coordinates
lower_right: Tuple[int, int] # (x, y) coordinates
class MaskFromTextRequest(BaseModel):
image: str # base64 encoded image
text: str
class MaskFromBBoxRequest(BaseModel):
image: str # base64 encoded image
bboxes: List[BoundingBox]
class MaskFromPointsRequest(BaseModel):
image: str # base64 encoded image
points: List[Point]
class MaskResult(BaseModel):
mask: str
score: float
bbox: BoundingBox # bounding box generated from the mask
# fields that are only populated in responses to MaskFromTextRequests
class_name: str = ""
dino_bbox: BoundingBox = BoundingBox(upper_left=(0, 0), lower_right=(0, 0))
center_of_mass: Tuple[float, float] = (0.0, 0.0)
class MaskResponse(BaseModel):
masks: List[MaskResult]  # list of base64 encoded mask images and their respective scores
image: str # base64 encoded result image
def decode_base64_image(base64_string: str) -> Image.Image:
"""Helper function to decode base64 image string to PIL Image"""
try:
# remove data URL prefix if present
if base64_string.startswith("data:image"):
base64_string = base64_string.split(",")[1]
image_data = base64.b64decode(base64_string)
image = Image.open(io.BytesIO(image_data))
return image
except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid image data: {str(e)}")
def encode_mask_to_base64(mask: np.ndarray) -> str:
"""Helper function to encode mask array to base64 string"""
try:
# convert mask to PIL Image (assuming binary mask)
mask_image = Image.fromarray((mask * 255).astype(np.uint8), mode="L")
# convert to base64
buffer = io.BytesIO()
mask_image.save(buffer, format="JPEG")
mask_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
return mask_base64
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to encode mask: {str(e)}")
def encode_image_to_base64(image: np.ndarray) -> str:
"""Helper function to encode cv2 image array to base64 string"""
try:
pil_image = Image.fromarray(image.astype(np.uint8))
# convert to base64
buffer = io.BytesIO()
pil_image.save(buffer, format="JPEG")
image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
return image_base64
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to encode image: {str(e)}")
@app.get("/")
async def root():
return {"message": "GSAM2 API Server", "version": "1.0.0"}
@app.post("/gsam2/image/maskfromtext", response_model=MaskResponse)
async def mask_from_text(request: MaskFromTextRequest):
"""
Generate segmentation masks from an image and text description.
Args:
request: Contains base64 encoded image and text description
Returns:
MaskResponse with list of base64 encoded masks, their scores and result image
"""
try:
# decode the input image
pil_image = decode_base64_image(request.image)
text = request.text
# segment the image
masks, annotated_image = segmentation_model.segment_image_from_text(
pil_image, text
)
# encode the results
enc_masks = [
MaskResult(
mask=encode_mask_to_base64(mask),
score=score,
bbox=BoundingBox(
upper_left=(bbox[0], bbox[1]), lower_right=(bbox[2], bbox[3])
),
class_name=class_name,
dino_bbox=BoundingBox(
upper_left=(round(dino_bbox[0]), round(dino_bbox[1])),
lower_right=(round(dino_bbox[2]), round(dino_bbox[3])),
),
center_of_mass=(com[0], com[1]),
)
for (mask, score, bbox, dino_bbox, class_name, com) in masks
]
enc_annotated_image = encode_image_to_base64(annotated_image)
return MaskResponse(
masks=enc_masks,
image=enc_annotated_image,
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
@app.post("/gsam2/image/maskfrombboxes", response_model=MaskResponse)
async def mask_from_bbox(request: MaskFromBBoxRequest):
"""
Generate segmentation masks from an image and one or more bounding boxes.
Args:
request: Contains base64 encoded image and bounding box coordinates
Returns:
MaskResponse with list of base64 encoded masks, their scores and result image
"""
try:
pil_image = decode_base64_image(request.image)
# validate bounding box coordinates
bboxes = None
for bbox in request.bboxes:
x1, y1 = bbox.upper_left
x2, y2 = bbox.lower_right
if x1 >= x2 or y1 >= y2:
raise HTTPException(
status_code=400,
detail="Invalid bounding box: upper_left must be above and left of lower_right",
)
if x1 < 0 or y1 < 0 or x2 > pil_image.width or y2 > pil_image.height:
raise HTTPException(
status_code=400,
detail="Bounding box coordinates out of image bounds",
)
# convert to numpy array format expected by ImageSegmentation
if bboxes is None:
bboxes = np.array([[x1, y1, x2, y2]])
else:
bboxes = np.vstack((bboxes, [[x1, y1, x2, y2]]))
if bboxes is None:
raise HTTPException(
status_code=400, detail="At least one bounding box is required"
)
# segment the image
(masks, annotated_image) = segmentation_model.segment_image_from_bbox(
pil_image, bboxes
)
# encode the results
enc_masks = [
MaskResult(
mask=encode_mask_to_base64(mask),
score=score,
bbox=BoundingBox(
upper_left=(bbox[0], bbox[1]), lower_right=(bbox[2], bbox[3])
),
)
for (mask, score, bbox) in masks
]
enc_annotated_image = encode_image_to_base64(annotated_image)
return MaskResponse(
masks=enc_masks,
image=enc_annotated_image,
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
@app.post("/gsam2/image/maskfrompoints", response_model=MaskResponse)
async def mask_from_points(request: MaskFromPointsRequest):
"""
Generate segmentation masks from an image and list of points with include/exclude indicators.
Args:
request: Contains base64 encoded image and list of points with include/exclude flags
Returns:
MaskResponse with list of base64 encoded masks, their scores and result image
"""
try:
pil_image = decode_base64_image(request.image)
# validate point coordinates
for i, point in enumerate(request.points):
if (
point.x < 0
or point.x >= pil_image.width
or point.y < 0
or point.y >= pil_image.height
):
raise HTTPException(
status_code=400, detail=f"Point {i} coordinates out of image bounds"
)
# convert points to numpy array format expected by ImageSegmentation
points = None
if request.points is not None and len(request.points) > 0:
points = np.array(
[[point.x, point.y, point.include] for point in request.points]
)
if points is None:
raise HTTPException(
status_code=400, detail="At least one point is required"
)
# segment the image
(masks, annotated_image) = segmentation_model.segment_image_from_points(
pil_image, points
)
# encode the results
enc_masks = [
MaskResult(
mask=encode_mask_to_base64(mask),
score=score,
bbox=BoundingBox(
upper_left=(bbox[0], bbox[1]), lower_right=(bbox[2], bbox[3])
),
)
for (mask, score, bbox) in masks
]
enc_annotated_image = encode_image_to_base64(annotated_image)
return MaskResponse(
masks=enc_masks,
image=enc_annotated_image,
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=13337)
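
Note that `encode_mask_to_base64` returns a JPEG-compressed grayscale image, so a client has to threshold the decoded pixels to recover a boolean mask (JPEG is lossy, so mask edges may shift by a pixel or two). A hedged client-side sketch; `decode_mask` and `result` are illustrative names, not part of this repo:

```
import base64
import io

import numpy as np
from PIL import Image


def decode_mask(mask_b64: str) -> np.ndarray:
    """Decode a base64 JPEG mask from a MaskResponse into a boolean HxW array."""
    mask_img = Image.open(io.BytesIO(base64.b64decode(mask_b64))).convert("L")
    return np.asarray(mask_img) > 127


# assuming `result` holds the parsed JSON body of a maskfromtext response:
# mask = decode_mask(result["masks"][0]["mask"])
# print(mask.shape, result["masks"][0]["center_of_mass"])
```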

example/main.go Normal file

@@ -0,0 +1,245 @@
package main
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httputil"
"os"
"time"
)
var (
url = "http://localhost:13337"
fromText = "/gsam2/image/maskfromtext"
fromBboxes = "/gsam2/image/maskfrombboxes"
fromPoints = "/gsam2/image/maskfrompoints"
image = "truck.jpg"
)
type (
Point struct {
X int `json:"x"`
Y int `json:"y"`
Include bool `json:"include"`
}
BoundingBox struct {
UpperLeft [2]int `json:"upper_left"`
LowerRight [2]int `json:"lower_right"`
}
MaskFromTextRequest struct {
Image string `json:"image"`
Text string `json:"text"`
}
MaskFromBBoxRequest struct {
Image string `json:"image"`
Bboxes []BoundingBox `json:"bboxes"`
}
MaskFromPointsRequest struct {
Image string `json:"image"`
Points []Point `json:"points"`
}
MaskResult struct {
Mask string `json:"mask"`
Score float64 `json:"score"`
BBox BoundingBox `json:"bbox"`
// fields that are only populated in responses to MaskFromTextRequests
ClassName string `json:"class_name"`
DinoBBox BoundingBox `json:"dino_bbox"`
CenterOfMass [2]float64 `json:"center_of_mass"`
}
MaskResponse struct {
Masks []MaskResult `json:"masks"`
Image string `json:"image"`
}
)
func main() {
// ensure the out directory exists
os.MkdirAll("out", 0755)
// load the sample image and base64 encode it
dat, err := os.ReadFile(image)
if err != nil {
fmt.Println(err)
return
}
// post to the different endpoints
c := &http.Client{Timeout: time.Minute}
encImage := base64.StdEncoding.EncodeToString(dat)
// from text
err = doFromText(c, "fromtext", encImage, "truck. tire.")
if err != nil {
fmt.Printf("error %s", err)
return
}
// from bboxes
err = doFromBboxes(c, "frombboxes", encImage, []BoundingBox{
{
UpperLeft: [2]int{75, 275},
LowerRight: [2]int{1725, 850},
},
{
UpperLeft: [2]int{425, 600},
LowerRight: [2]int{700, 875},
},
{
UpperLeft: [2]int{1375, 550},
LowerRight: [2]int{1650, 800},
},
{
UpperLeft: [2]int{1240, 675},
LowerRight: [2]int{1400, 750},
},
})
if err != nil {
fmt.Printf("error %s", err)
return
}
// from points
err = doFromPoints(c, "frompoints", encImage, []Point{
{X: 500, Y: 375, Include: true},
{X: 1125, Y: 625, Include: true},
{X: 575, Y: 750, Include: false},
})
if err != nil {
fmt.Printf("error %s", err)
return
}
}
func do(c *http.Client, req *http.Request, outname string) error {
dump, err := httputil.DumpRequest(req, false)
if err != nil {
return err
}
fmt.Println("request: ", string(dump))
resp, err := c.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
dump, err := httputil.DumpResponse(resp, true)
if err != nil {
return err
}
fmt.Println("response: ", string(dump))
} else {
dump, err := httputil.DumpResponse(resp, false)
if err != nil {
return err
}
fmt.Println("response: ", string(dump))
}
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
maskResp := MaskResponse{}
err = json.Unmarshal(bodyBytes, &maskResp)
if err != nil {
return err
}
// write the masks to a file
for _, mask := range maskResp.Masks {
dec, err := base64.StdEncoding.DecodeString(mask.Mask)
if err != nil {
return err
}
class := ""
if mask.ClassName != "" {
class = "-" + mask.ClassName
}
os.WriteFile(fmt.Sprintf("out/%s%s-%.4f.jpg", outname, class, mask.Score), dec, 0644)
}
dec, err := base64.StdEncoding.DecodeString(maskResp.Image)
if err != nil {
return err
}
os.WriteFile(fmt.Sprintf("out/%s.jpg", outname), dec, 0644)
return nil
}
func doFromText(c *http.Client, outname string, encImage string, text string) error {
req, err := http.NewRequest("POST", url+fromText, nil)
if err != nil {
return err
}
req.Header.Add("Accept", `application/json`)
body := MaskFromTextRequest{
Image: encImage,
Text: text,
}
jsonBody, err := json.Marshal(body)
if err != nil {
return err
}
req.Body = io.NopCloser(bytes.NewBuffer(jsonBody))
return do(c, req, outname)
}
func doFromBboxes(c *http.Client, outname string, encImage string, bboxes []BoundingBox) error {
req, err := http.NewRequest("POST", url+fromBboxes, nil)
if err != nil {
return err
}
req.Header.Add("Accept", `application/json`)
body := MaskFromBBoxRequest{
Image: encImage,
Bboxes: bboxes,
}
jsonBody, err := json.Marshal(body)
if err != nil {
return err
}
req.Body = io.NopCloser(bytes.NewBuffer(jsonBody))
return do(c, req, outname)
}
func doFromPoints(c *http.Client, outname string, encImage string, points []Point) error {
req, err := http.NewRequest("POST", url+fromPoints, nil)
if err != nil {
return err
}
req.Header.Add("Accept", `application/json`)
body := MaskFromPointsRequest{
Image: encImage,
Points: points,
}
jsonBody, err := json.Marshal(body)
if err != nil {
return err
}
req.Body = io.NopCloser(bytes.NewBuffer(jsonBody))
return do(c, req, outname)
}
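
For comparison, roughly the same request as `doFromBboxes` above expressed in Python. This is a sketch assuming the `requests` package and the server from `app.py` running locally; it posts the same four boxes and writes the annotated image to `out/frombboxes.jpg`:

```
import base64
import os

import requests

with open("truck.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

bboxes = [
    {"upper_left": [75, 275], "lower_right": [1725, 850]},
    {"upper_left": [425, 600], "lower_right": [700, 875]},
    {"upper_left": [1375, 550], "lower_right": [1650, 800]},
    {"upper_left": [1240, 675], "lower_right": [1400, 750]},
]
resp = requests.post(
    "http://localhost:13337/gsam2/image/maskfrombboxes",
    json={"image": image_b64, "bboxes": bboxes},
    timeout=60,
)
resp.raise_for_status()

os.makedirs("out", exist_ok=True)
with open("out/frombboxes.jpg", "wb") as f:
    f.write(base64.b64decode(resp.json()["image"]))
```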

example/truck.jpg Normal file

Binary file not shown (265 KiB).

imagesegmentation.py Normal file

@@ -0,0 +1,313 @@
import numpy as np
import supervision as sv
import cv2
import PIL
from scipy import ndimage
from typing import List, Tuple, Union, Optional
import torch
from torchvision.ops import box_convert
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from grounding_dino.groundingdino.util.inference import load_model, load_image, predict
import grounding_dino.groundingdino.datasets.transforms as T
class ImageSegmentation:
def __init__(self):
# select the device for computation
if torch.cuda.is_available():
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")
print(f"using device: {self.device}")
if self.device.type == "cuda":
# NOTE: the upstream example enables autocast for the whole run, but somehow
# this didn't work locally inside a Docker container
# original:
# torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
# float16 might work instead, or it might work without autocast entirely:
# torch.autocast("cuda", dtype=torch.float16).__enter__()
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
if torch.cuda.get_device_properties(0).major >= 8:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
sam2_checkpoint = "checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
self.sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=self.device)
self.sam2_predictor = SAM2ImagePredictor(self.sam2_model)
grounding_dino_config = (
"grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
)
grounding_dino_checkpoint = "gdino_checkpoints/groundingdino_swint_ogc.pth"
self.box_threshold = 0.35
self.text_threshold = 0.25
self.grounding_model = load_model(
model_config_path=grounding_dino_config,
model_checkpoint_path=grounding_dino_checkpoint,
device=self.device,
)
def segment_image_from_text(self, pil_image: PIL.Image.Image, text: str):
"""Generate segmentation masks from image and text description using Grounding DINO + SAM2.
Args:
pil_image: PIL image that should be segmented
text: object description(s) to be segmented
Returns:
Iterable of tuples (mask, score, bbox, dino_bbox, class_name, center_of_mass), one per detection, with mask (HxW) and float score
Annotated result image
"""
# image preparation taken from load_image() in Grounded-SAM-2/grounding_dino/groundingdino/util/inference.py
pil_image = pil_image.convert("RGB")
transform = T.Compose(
[
T.RandomResize([800], max_size=1333),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
image = np.asarray(pil_image)
image_transformed, _ = transform(pil_image, None)
# set the image for sam2
self.sam2_predictor.set_image(image)
# predict the bounding boxes
boxes, confidences, labels = predict(
model=self.grounding_model,
image=image_transformed,
caption=text,
box_threshold=self.box_threshold,
text_threshold=self.text_threshold,
device=self.device,
)
if boxes is None or len(boxes) < 1:
return [], image
# process the box prompt for SAM 2
h, w, _ = image.shape
boxes = boxes * torch.Tensor([w, h, w, h])
input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
# NOTE: somehow this didn't work locally inside a docker container
# torch.autocast(device_type=self.device, dtype=torch.bfloat16).__enter__()
# if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
# # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
# torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cudnn.allow_tf32 = True
masks, scores, logits = self.sam2_predictor.predict(
point_coords=None,
point_labels=None,
box=input_boxes,
multimask_output=False,
)
# convert the shape to (n, H, W)
if masks.ndim == 4:
masks = masks.squeeze(1)
confidences = confidences.numpy().tolist()
class_names = labels
class_ids = np.array(list(range(len(class_names))))
labels = [
f"{class_name} {confidence:.2f}"
for class_name, confidence in zip(class_names, confidences)
]
return zip(
masks,
scores,
ImageSegmentation._bboxes_from_masks(masks),
input_boxes,
class_names,
ImageSegmentation._centers_of_mass_from_masks(masks),
), ImageSegmentation._create_result_image(
image, masks, bboxes=input_boxes, labels=labels, class_ids=class_ids
)
def segment_image_from_bbox(self, pil_image: PIL.Image.Image, bboxes: np.array):
"""Generate segmentation masks from image and bounding box coordinates using SAM2.
Args:
pil_image: PIL image that should be segmented
bboxes: Nx4 array of bounding boxes of objects to be segmented (x1, y1, x2, y2)
Returns:
Iterable of tuples (mask, score, bbox), one per box, with mask (HxW) and float score
Annotated result image
"""
image = np.asarray(pil_image.convert("RGB"))
self.sam2_predictor.set_image(image)
masks, scores, logits = self.sam2_predictor.predict(
point_coords=None,
point_labels=None,
box=np.array(bboxes),
multimask_output=False,
)
# convert the shape to (n, H, W)
if masks.ndim == 4:
masks = masks.squeeze(1)
return zip(
masks, scores, ImageSegmentation._bboxes_from_masks(masks)
), ImageSegmentation._create_result_image(image, masks, bboxes=bboxes)
def segment_image_from_points(self, pil_image: PIL.Image.Image, points: np.array):
"""Generate segmentation masks from image and point coordinates with include/exclude labels using SAM2.
Args:
pil_image: PIL image that should be segmented
points: Nx3 array of points with include/exclude flags of the objects to be segmented (x, y, include)
Returns:
Iterable of tuples (mask, score, bbox) with mask (HxW) and float score
Annotated result image
"""
image = np.asarray(pil_image.convert("RGB"))
self.sam2_predictor.set_image(image)
# convert points to coordinates and labels arrays
coords = np.array([[point[0], point[1]] for point in points])
labels = np.array([1 if point[2] else 0 for point in points])
masks, scores, logits = self.sam2_predictor.predict(
point_coords=coords,
point_labels=labels,
multimask_output=False,
)
# convert the shape to (n, H, W)
if masks.ndim == 4:
masks = masks.squeeze(1)
return zip(
masks, scores, ImageSegmentation._bboxes_from_masks(masks)
), ImageSegmentation._create_result_image(image, masks, points=points)
def _create_result_image(
pil_image: PIL.Image.Image,
masks: np.array,
bboxes: np.array = [],
labels: np.array = [],
class_ids: np.array = None,
points: np.array = [],
):
"""Create annotated result image with masks, bounding boxes, labels, and points overlaid.
Args:
pil_image: image to annotate (PIL image or numpy RGB array)
masks: NxHxW array of object mask(s)
bboxes: (optional) Nx4 array of objects bounding box(es) (x1, y1, x2, y2)
labels: (optional) Nx1 array of object label(s)
points: (optional) Nx3 array of object point(s) with include/exclude flags (x, y, include)
Returns:
Annotated result image (numpy array)
"""
img = np.array(pil_image)
# we have to define the bboxes in the detections even though we might not show them
detection_bboxes = bboxes
if bboxes is None or len(bboxes) == 0:
detection_bboxes = ImageSegmentation._bboxes_from_masks(masks)
detections = sv.Detections(
xyxy=detection_bboxes, # (n, 4)
mask=masks.astype(bool), # (n, h, w)
class_id=class_ids,
)
annotated_frame = img.copy()
# if there are no class ids (i.e. when running without Grounding DINO) we color
# by detection index instead of by class
colorlookup = sv.ColorLookup.INDEX
if class_ids is not None and len(class_ids) > 0:
colorlookup = sv.ColorLookup.CLASS
# points
if points is not None:
for x, y, include in points:
# Green for include (True), Red for exclude (False)
color = (0, 255, 0) if include else (0, 0, 255) # BGR format
cv2.circle(annotated_frame, (x, y), 8, (0, 0, 0), -1) # Outer ring
cv2.circle(annotated_frame, (x, y), 5, color, -1) # Filled circle
# bboxes
if len(bboxes) > 0:
box_annotator = sv.BoxAnnotator(color_lookup=colorlookup)
annotated_frame = box_annotator.annotate(
scene=annotated_frame, detections=detections
)
# labels
if labels is not None and len(labels) > 0:
label_annotator = sv.LabelAnnotator(color_lookup=colorlookup)
annotated_frame = label_annotator.annotate(
scene=annotated_frame, detections=detections, labels=labels
)
# mask
mask_annotator = sv.MaskAnnotator(color_lookup=colorlookup)
annotated_frame = mask_annotator.annotate(
scene=annotated_frame, detections=detections
)
return annotated_frame
def _bboxes_from_masks(masks: np.array):
"""Create bounding boxes for the provided masks
Args:
masks: NxHxW array of object mask(s)
Returns:
bboxes: Nx4 array of mask bounding boxes (x1, y1, x2, y2)
"""
bboxes = []
for mask in masks:
mask_bool = np.where(mask != 0)
if len(mask_bool[0]) != 0 and len(mask_bool[1]) != 0:
bboxes.append(
[
int(np.min(mask_bool[1])),
int(np.min(mask_bool[0])),
int(np.max(mask_bool[1])),
int(np.max(mask_bool[0])),
]
)
else:
bboxes.append([0, 0, 0, 0])
return np.array(bboxes)
def _centers_of_mass_from_masks(masks: np.array):
"""Calculate centers of mass for the provided masks
Args:
masks: NxHxW array of object mask(s)
Returns:
centers_of_mass: Nx2 array of mask centers of mass (x, y)
"""
return np.array(
[[x, y] for mask in masks for y, x in [ndimage.center_of_mass(mask)]]
)
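
`ImageSegmentation` can also be used directly, without the FastAPI layer. A minimal sketch, assuming the `checkpoints/` and `gdino_checkpoints/` directories downloaded in the Dockerfile are present in the working directory and that the image path is adjusted as needed:

```
from PIL import Image

from imagesegmentation import ImageSegmentation

seg = ImageSegmentation()
image = Image.open("example/truck.jpg")  # illustrative path

masks, annotated = seg.segment_image_from_text(image, "truck. tire.")
for mask, score, bbox, dino_bbox, class_name, com in masks:
    print(class_name, float(score), bbox.tolist(), com)

Image.fromarray(annotated).save("annotated.jpg")
```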

requirements.txt Normal file

@@ -0,0 +1,7 @@
fastapi==0.104.1
uvicorn[standard]==0.24.0
pydantic==2.5.0
pillow==10.1.0
numpy==1.24.3
python-multipart==0.0.6
opencv-python==4.8.1.78