mirror of https://github.com/foomo/gsamservice.git, synced 2025-10-16 12:35:37 +00:00

initial commit

parent 27ffdf2668, commit 7b063e8357
1 .gitignore vendored Normal file
@@ -0,0 +1 @@
example/out
57 Dockerfile Normal file
@@ -0,0 +1,57 @@
FROM docker.io/pytorch/pytorch:2.3.1-cuda12.1-cudnn8-devel

# arguments to build the Docker image using CUDA
ARG USE_CUDA=0
ARG TORCH_ARCH="7.0;7.5;8.0;8.6"

ENV AM_I_DOCKER=True
ENV BUILD_WITH_CUDA="${USE_CUDA}"
ENV TORCH_CUDA_ARCH_LIST="${TORCH_ARCH}"
ENV CUDA_HOME=/usr/local/cuda-12.1/
# ensure CUDA is correctly set up
ENV PATH=/usr/local/cuda-12.1/bin:${PATH}
ENV LD_LIBRARY_PATH=/usr/local/cuda-12.1/lib64:${LD_LIBRARY_PATH}

# install required packages and a specific gcc/g++
RUN apt-get update && apt-get install --no-install-recommends wget ffmpeg=7:* \
    libsm6=2:* libxext6=2:* git=1:* nano vim=2:* ninja-build gcc-10 g++-10 -y \
    && apt-get clean && apt-get autoremove && rm -rf /var/lib/apt/lists/*

ENV CC=gcc-10
ENV CXX=g++-10

# clone the Grounded SAM 2 repo
WORKDIR /home/appuser
RUN git clone https://github.com/IDEA-Research/Grounded-SAM-2

# download the SAM 2 checkpoints
WORKDIR /home/appuser/Grounded-SAM-2/checkpoints
RUN bash download_ckpts.sh

# download the Grounding DINO checkpoints
WORKDIR /home/appuser/Grounded-SAM-2/gdino_checkpoints
RUN bash download_ckpts.sh

WORKDIR /home/appuser/Grounded-SAM-2

# install essential Python packages
RUN python -m pip install --upgrade pip "setuptools>=62.3.0,<75.9" wheel numpy \
    opencv-python transformers supervision pycocotools addict yapf timm

# install the segment_anything package in editable mode
RUN python -m pip install -e .

# install Grounding DINO
RUN python -m pip install --no-build-isolation -e grounding_dino

# install the server dependencies
COPY requirements.txt requirements.txt
RUN python -m pip install -r requirements.txt

COPY app.py app.py
COPY imagesegmentation.py imagesegmentation.py

# RUN mkdir ../host

# start the server
ENTRYPOINT ["python", "app.py", "--log-level", "debug"]
43 Makefile Normal file
@@ -0,0 +1,43 @@
# Get version of CUDA and enable it for compilation if CUDA > 11.0
# This solves https://github.com/IDEA-Research/Grounded-Segment-Anything/issues/53
# and https://github.com/IDEA-Research/Grounded-Segment-Anything/issues/84
# when running in Docker
# Check if nvcc is installed
NVCC := $(shell which nvcc)
ifeq ($(NVCC),)
    # NVCC not found
    USE_CUDA := 0
    NVCC_VERSION := "not installed"
else
    NVCC_VERSION := $(shell nvcc --version | grep -oP 'release \K[0-9.]+')
    USE_CUDA := $(shell echo "$(NVCC_VERSION) > 11" | bc -l)
endif

# Add the list of supported ARCHs
ifeq ($(USE_CUDA), 1)
    TORCH_CUDA_ARCH_LIST := "7.0;7.5;8.0;8.6+PTX"
    BUILD_MESSAGE := "Trying to build the image with CUDA support"
else
    TORCH_CUDA_ARCH_LIST :=
    BUILD_MESSAGE := "CUDA $(NVCC_VERSION) is not supported"
endif

build:
	docker build --build-arg USE_CUDA=$(USE_CUDA) \
		--build-arg TORCH_ARCH=$(TORCH_CUDA_ARCH_LIST) \
		--progress=plain -t gsam2 .

run:
	docker run -d --gpus all \
		--restart unless-stopped \
		--name=gsam2 \
		--ipc=host -p 13337:13337 gsam2

run-bash:
	docker run -it --rm --gpus all \
		-v "${PWD}":/home/appuser/host \
		--entrypoint bash \
		--name=gsam2 \
		--network=host \
		--ipc=host gsam2
14 README.md Normal file
@@ -0,0 +1,14 @@
# GSAM Service

A simple server providing [Grounded SAM2](https://github.com/IDEA-Research/Grounded-SAM-2) through a REST API.

## Usage

Build and run the container:

```
make build
make run
```

You can then connect to the server on port 13337. Have a look at `example/main.go` for examples of the provided endpoints.
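For a quick smoke test without Go, the request/response shapes defined in `app.py` can also be exercised with a minimal Python sketch (standard library only; it assumes the server from `make run` is reachable on localhost:13337 and a local `truck.jpg`):

```
import base64
import json
import urllib.request

# read and base64 encode the input image
with open("truck.jpg", "rb") as f:
    payload = {"image": base64.b64encode(f.read()).decode(), "text": "truck. tire."}

req = urllib.request.Request(
    "http://localhost:13337/gsam2/image/maskfromtext",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    result = json.load(resp)

# the annotated result image and the masks come back base64 encoded
with open("annotated.jpg", "wb") as out:
    out.write(base64.b64decode(result["image"]))
print(f"received {len(result['masks'])} masks")
```

The same pattern works for the bbox and points endpoints; only the payload fields change.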
303 app.py Normal file
@@ -0,0 +1,303 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Tuple, Union, Optional
import base64
import io
from PIL import Image
import numpy as np
import cv2
from imagesegmentation import ImageSegmentation

app = FastAPI(
    title="GSAM2 API",
    description="Grounded SAM 2 Image Segmentation API",
    version="1.0.0",
)

segmentation_model = ImageSegmentation()


# pydantic models for request validation
class Point(BaseModel):
    x: int
    y: int
    include: bool  # True for include, False for exclude


class BoundingBox(BaseModel):
    upper_left: Tuple[int, int]  # (x, y) coordinates
    lower_right: Tuple[int, int]  # (x, y) coordinates


class MaskFromTextRequest(BaseModel):
    image: str  # base64 encoded image
    text: str


class MaskFromBBoxRequest(BaseModel):
    image: str  # base64 encoded image
    bboxes: List[BoundingBox]


class MaskFromPointsRequest(BaseModel):
    image: str  # base64 encoded image
    points: List[Point]


class MaskResult(BaseModel):
    mask: str
    score: float
    bbox: BoundingBox  # bounding box generated from the mask

    # fields only populated in responses to MaskFromTextRequests
    class_name: str = ""
    dino_bbox: BoundingBox = BoundingBox(upper_left=(0, 0), lower_right=(0, 0))
    center_of_mass: Tuple[float, float] = (0.0, 0.0)


class MaskResponse(BaseModel):
    masks: List[MaskResult]  # list of base64 encoded mask images and respective scores
    image: str  # base64 encoded result image


def decode_base64_image(base64_string: str) -> Image.Image:
    """Helper function to decode a base64 image string to a PIL Image"""
    try:
        # remove the data URL prefix if present
        if base64_string.startswith("data:image"):
            base64_string = base64_string.split(",")[1]

        image_data = base64.b64decode(base64_string)
        image = Image.open(io.BytesIO(image_data))
        return image
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image data: {str(e)}")


def encode_mask_to_base64(mask: np.ndarray) -> str:
    """Helper function to encode a mask array to a base64 string"""
    try:
        # convert the mask to a PIL Image (assuming a binary mask)
        mask_image = Image.fromarray((mask * 255).astype(np.uint8), mode="L")

        # convert to base64
        buffer = io.BytesIO()
        mask_image.save(buffer, format="JPEG")
        mask_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return mask_base64
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to encode mask: {str(e)}")


def encode_image_to_base64(image: np.ndarray) -> str:
    """Helper function to encode a cv2 image array to a base64 string"""
    try:
        pil_image = Image.fromarray(image.astype(np.uint8))

        # convert to base64
        buffer = io.BytesIO()
        pil_image.save(buffer, format="JPEG")
        image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return image_base64
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to encode image: {str(e)}")


@app.get("/")
async def root():
    return {"message": "GSAM2 API Server", "version": "1.0.0"}


@app.post("/gsam2/image/maskfromtext", response_model=MaskResponse)
async def mask_from_text(request: MaskFromTextRequest):
    """
    Generate segmentation masks from an image and a text description.

    Args:
        request: Contains a base64 encoded image and a text description

    Returns:
        MaskResponse with a list of base64 encoded masks, their scores, and the result image
    """
    try:
        # decode the input image
        pil_image = decode_base64_image(request.image)
        text = request.text

        # segment the image
        masks, annotated_image = segmentation_model.segment_image_from_text(
            pil_image, text
        )

        # encode the results
        enc_masks = [
            MaskResult(
                mask=encode_mask_to_base64(mask),
                score=score,
                bbox=BoundingBox(
                    upper_left=(bbox[0], bbox[1]), lower_right=(bbox[2], bbox[3])
                ),
                class_name=class_name,
                dino_bbox=BoundingBox(
                    upper_left=(round(dino_bbox[0]), round(dino_bbox[1])),
                    lower_right=(round(dino_bbox[2]), round(dino_bbox[3])),
                ),
                center_of_mass=(com[0], com[1]),
            )
            for (mask, score, bbox, dino_bbox, class_name, com) in masks
        ]
        enc_annotated_image = encode_image_to_base64(annotated_image)

        return MaskResponse(
            masks=enc_masks,
            image=enc_annotated_image,
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")


@app.post("/gsam2/image/maskfrombboxes", response_model=MaskResponse)
async def mask_from_bbox(request: MaskFromBBoxRequest):
    """
    Generate segmentation masks from an image and bounding boxes.

    Args:
        request: Contains a base64 encoded image and bounding box coordinates

    Returns:
        MaskResponse with a list of base64 encoded masks, their scores, and the result image
    """
    try:
        pil_image = decode_base64_image(request.image)

        # validate the bounding box coordinates
        bboxes = None
        for bbox in request.bboxes:
            x1, y1 = bbox.upper_left
            x2, y2 = bbox.lower_right

            if x1 >= x2 or y1 >= y2:
                raise HTTPException(
                    status_code=400,
                    detail="Invalid bounding box: upper_left must be above and left of lower_right",
                )

            if x1 < 0 or y1 < 0 or x2 > pil_image.width or y2 > pil_image.height:
                raise HTTPException(
                    status_code=400,
                    detail="Bounding box coordinates out of image bounds",
                )

            # convert to the numpy array format expected by ImageSegmentation
            if bboxes is None:
                bboxes = np.array([[x1, y1, x2, y2]])
            else:
                bboxes = np.vstack((bboxes, [[x1, y1, x2, y2]]))

        if bboxes is None:
            raise HTTPException(
                status_code=400, detail="At least one bounding box is required"
            )

        # segment the image
        (masks, annotated_image) = segmentation_model.segment_image_from_bbox(
            pil_image, bboxes
        )

        # encode the results
        enc_masks = [
            MaskResult(
                mask=encode_mask_to_base64(mask),
                score=score,
                bbox=BoundingBox(
                    upper_left=(bbox[0], bbox[1]), lower_right=(bbox[2], bbox[3])
                ),
            )
            for (mask, score, bbox) in masks
        ]
        enc_annotated_image = encode_image_to_base64(annotated_image)

        return MaskResponse(
            masks=enc_masks,
            image=enc_annotated_image,
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")


@app.post("/gsam2/image/maskfrompoints", response_model=MaskResponse)
async def mask_from_points(request: MaskFromPointsRequest):
    """
    Generate segmentation masks from an image and a list of points with include/exclude indicators.

    Args:
        request: Contains a base64 encoded image and a list of points with include/exclude flags

    Returns:
        MaskResponse with a list of base64 encoded masks, their scores, and the result image
    """
    try:
        pil_image = decode_base64_image(request.image)

        # validate the point coordinates
        for i, point in enumerate(request.points):
            if (
                point.x < 0
                or point.x >= pil_image.width
                or point.y < 0
                or point.y >= pil_image.height
            ):
                raise HTTPException(
                    status_code=400, detail=f"Point {i} coordinates out of image bounds"
                )

        # convert the points to the numpy array format expected by ImageSegmentation
        points = None
        if request.points is not None and len(request.points) > 0:
            points = np.array(
                [[point.x, point.y, point.include] for point in request.points]
            )

        if points is None:
            raise HTTPException(
                status_code=400, detail="At least one point is required"
            )

        # segment the image
        (masks, annotated_image) = segmentation_model.segment_image_from_points(
            pil_image, points
        )

        # encode the results
        enc_masks = [
            MaskResult(
                mask=encode_mask_to_base64(mask),
                score=score,
                bbox=BoundingBox(
                    upper_left=(bbox[0], bbox[1]), lower_right=(bbox[2], bbox[3])
                ),
            )
            for (mask, score, bbox) in masks
        ]
        enc_annotated_image = encode_image_to_base64(annotated_image)

        return MaskResponse(
            masks=enc_masks,
            image=enc_annotated_image,
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=13337)
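One detail worth noting for client authors: `encode_mask_to_base64` above serializes masks as JPEG, which is lossy, so decoded masks should be re-binarized rather than compared against exact 0/255 values. A small decoding sketch (a hypothetical client-side helper, not part of the service):

```
import base64
import io

import numpy as np
from PIL import Image


def decode_mask(mask_b64: str, threshold: int = 128) -> np.ndarray:
    """Decode a MaskResult.mask string back into a boolean HxW array."""
    gray = np.array(Image.open(io.BytesIO(base64.b64decode(mask_b64))).convert("L"))
    # JPEG compression blurs mask edges, so threshold instead of testing for 255
    return gray >= threshold
```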
245 example/main.go Normal file
@@ -0,0 +1,245 @@
package main

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/http/httputil"
	"os"
	"time"
)

var (
	url        = "http://localhost:13337"
	fromText   = "/gsam2/image/maskfromtext"
	fromBboxes = "/gsam2/image/maskfrombboxes"
	fromPoints = "/gsam2/image/maskfrompoints"

	image = "truck.jpg"
)

type (
	Point struct {
		X       int  `json:"x"`
		Y       int  `json:"y"`
		Include bool `json:"include"`
	}

	BoundingBox struct {
		UpperLeft  [2]int `json:"upper_left"`
		LowerRight [2]int `json:"lower_right"`
	}

	MaskFromTextRequest struct {
		Image string `json:"image"`
		Text  string `json:"text"`
	}

	MaskFromBBoxRequest struct {
		Image  string        `json:"image"`
		Bboxes []BoundingBox `json:"bboxes"`
	}

	MaskFromPointsRequest struct {
		Image  string  `json:"image"`
		Points []Point `json:"points"`
	}

	MaskResult struct {
		Mask  string      `json:"mask"`
		Score float64     `json:"score"`
		BBox  BoundingBox `json:"bbox"`
		// fields only populated in responses to MaskFromTextRequests
		ClassName    string      `json:"class_name"`
		DinoBBox     BoundingBox `json:"dino_bbox"`
		CenterOfMass [2]float64  `json:"center_of_mass"`
	}

	MaskResponse struct {
		Masks []MaskResult `json:"masks"`
		Image string       `json:"image"`
	}
)

func main() {
	// ensure the out directory exists
	os.Mkdir("out", 0755)

	// load the sample image and base64 encode it
	dat, err := os.ReadFile(image)
	if err != nil {
		fmt.Println(err)
		return
	}

	// post to the different endpoints
	c := &http.Client{Timeout: time.Minute}
	encImage := base64.StdEncoding.EncodeToString(dat)

	// from text
	err = doFromText(c, "fromtext", encImage, "truck. tire.")
	if err != nil {
		fmt.Printf("error %s", err)
		return
	}

	// from bboxes
	err = doFromBboxes(c, "frombboxes", encImage, []BoundingBox{
		{
			UpperLeft:  [2]int{75, 275},
			LowerRight: [2]int{1725, 850},
		},
		{
			UpperLeft:  [2]int{425, 600},
			LowerRight: [2]int{700, 875},
		},
		{
			UpperLeft:  [2]int{1375, 550},
			LowerRight: [2]int{1650, 800},
		},
		{
			UpperLeft:  [2]int{1240, 675},
			LowerRight: [2]int{1400, 750},
		},
	})
	if err != nil {
		fmt.Printf("error %s", err)
		return
	}

	// from points
	err = doFromPoints(c, "frompoints", encImage, []Point{
		{X: 500, Y: 375, Include: true},
		{X: 1125, Y: 625, Include: true},
		{X: 575, Y: 750, Include: false},
	})
	if err != nil {
		fmt.Printf("error %s", err)
		return
	}
}

func do(c *http.Client, req *http.Request, outname string) error {
	dump, err := httputil.DumpRequest(req, false)
	if err != nil {
		return err
	}
	fmt.Println("request: ", string(dump))

	resp, err := c.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 300 {
		// dump the full response, including the body, on errors
		dump, err := httputil.DumpResponse(resp, true)
		if err != nil {
			return err
		}
		fmt.Println("response: ", string(dump))
	} else {
		dump, err := httputil.DumpResponse(resp, false)
		if err != nil {
			return err
		}
		fmt.Println("response: ", string(dump))
	}

	bodyBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		return err
	}

	maskResp := MaskResponse{}
	err = json.Unmarshal(bodyBytes, &maskResp)
	if err != nil {
		return err
	}

	// write the masks to files
	for _, mask := range maskResp.Masks {
		dec, err := base64.StdEncoding.DecodeString(mask.Mask)
		if err != nil {
			return err
		}
		class := ""
		if mask.ClassName != "" {
			class = "-" + mask.ClassName
		}
		os.WriteFile(fmt.Sprintf("out/%s%s-%.4f.jpg", outname, class, mask.Score), dec, 0644)
	}

	dec, err := base64.StdEncoding.DecodeString(maskResp.Image)
	if err != nil {
		return err
	}
	os.WriteFile(fmt.Sprintf("out/%s.jpg", outname), dec, 0644)

	return nil
}

func doFromText(c *http.Client, outname string, encImage string, text string) error {
	req, err := http.NewRequest("POST", url+fromText, nil)
	if err != nil {
		return err
	}
	req.Header.Add("Accept", `application/json`)

	body := MaskFromTextRequest{
		Image: encImage,
		Text:  text,
	}
	jsonBody, err := json.Marshal(body)
	if err != nil {
		return err
	}
	req.Body = io.NopCloser(bytes.NewBuffer(jsonBody))

	return do(c, req, outname)
}

func doFromBboxes(c *http.Client, outname string, encImage string, bboxes []BoundingBox) error {
	req, err := http.NewRequest("POST", url+fromBboxes, nil)
	if err != nil {
		return err
	}
	req.Header.Add("Accept", `application/json`)

	body := MaskFromBBoxRequest{
		Image:  encImage,
		Bboxes: bboxes,
	}
	jsonBody, err := json.Marshal(body)
	if err != nil {
		return err
	}
	req.Body = io.NopCloser(bytes.NewBuffer(jsonBody))

	return do(c, req, outname)
}

func doFromPoints(c *http.Client, outname string, encImage string, points []Point) error {
	req, err := http.NewRequest("POST", url+fromPoints, nil)
	if err != nil {
		return err
	}
	req.Header.Add("Accept", `application/json`)

	body := MaskFromPointsRequest{
		Image:  encImage,
		Points: points,
	}
	jsonBody, err := json.Marshal(body)
	if err != nil {
		return err
	}
	req.Body = io.NopCloser(bytes.NewBuffer(jsonBody))

	return do(c, req, outname)
}
BIN example/truck.jpg Normal file
Binary file not shown (size: 265 KiB).
313 imagesegmentation.py Normal file
@@ -0,0 +1,313 @@
import numpy as np
import supervision as sv
import cv2
import PIL
from scipy import ndimage
from typing import List, Tuple, Union, Optional

import torch
from torchvision.ops import box_convert

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

from grounding_dino.groundingdino.util.inference import load_model, load_image, predict
import grounding_dino.groundingdino.datasets.transforms as T


class ImageSegmentation:
    def __init__(self):
        # select the device for computation
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        print(f"using device: {self.device}")

        if self.device.type == "cuda":
            # NOTE: somehow this didn't work locally inside a docker container
            # use bfloat16 for the entire notebook
            # original:
            # torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
            # it might work without it, or with this:
            # torch.autocast("cuda", dtype=torch.float16).__enter__()
            # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
            if torch.cuda.get_device_properties(0).major >= 8:
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True

        sam2_checkpoint = "checkpoints/sam2.1_hiera_large.pt"
        model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

        self.sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=self.device)
        self.sam2_predictor = SAM2ImagePredictor(self.sam2_model)

        grounding_dino_config = (
            "grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
        )
        grounding_dino_checkpoint = "gdino_checkpoints/groundingdino_swint_ogc.pth"
        self.box_threshold = 0.35
        self.text_threshold = 0.25

        self.grounding_model = load_model(
            model_config_path=grounding_dino_config,
            model_checkpoint_path=grounding_dino_checkpoint,
            device=self.device,
        )

    def segment_image_from_text(self, pil_image: PIL.Image.Image, text: str):
        """Generate segmentation masks from an image and a text description using Grounding DINO + SAM2.

        Args:
            pil_image: PIL image that should be segmented
            text: object description(s) to be segmented

        Returns:
            Iterator of N tuples (mask, score, bbox, dino_bbox, class_name, center_of_mass) with mask (HxW) and float score
            Annotated result image
        """

        # image preparation taken from load_image() in Grounded-SAM-2/grounding_dino/groundingdino/util/inference.py
        pil_image = pil_image.convert("RGB")
        transform = T.Compose(
            [
                T.RandomResize([800], max_size=1333),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        image = np.asarray(pil_image)
        image_transformed, _ = transform(pil_image, None)

        # set the image for sam2
        self.sam2_predictor.set_image(image)

        # predict the bounding boxes
        boxes, confidences, labels = predict(
            model=self.grounding_model,
            image=image_transformed,
            caption=text,
            box_threshold=self.box_threshold,
            text_threshold=self.text_threshold,
            device=self.device,
        )

        if boxes is None or len(boxes) < 1:
            return [], image

        # process the box prompt for SAM 2
        h, w, _ = image.shape
        boxes = boxes * torch.Tensor([w, h, w, h])
        input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()

        # NOTE: somehow this didn't work locally inside a docker container
        # torch.autocast(device_type=self.device, dtype=torch.bfloat16).__enter__()
        # if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
        #     # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
        #     torch.backends.cuda.matmul.allow_tf32 = True
        #     torch.backends.cudnn.allow_tf32 = True

        masks, scores, logits = self.sam2_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_boxes,
            multimask_output=False,
        )

        # convert the shape to (n, H, W)
        if masks.ndim == 4:
            masks = masks.squeeze(1)

        confidences = confidences.numpy().tolist()
        class_names = labels

        class_ids = np.array(list(range(len(class_names))))

        labels = [
            f"{class_name} {confidence:.2f}"
            for class_name, confidence in zip(class_names, confidences)
        ]

        return zip(
            masks,
            scores,
            ImageSegmentation._bboxes_from_masks(masks),
            input_boxes,
            class_names,
            ImageSegmentation._centers_of_mass_from_masks(masks),
        ), ImageSegmentation._create_result_image(
            image, masks, bboxes=input_boxes, labels=labels, class_ids=class_ids
        )

    def segment_image_from_bbox(self, pil_image: PIL.Image.Image, bboxes: np.ndarray):
        """Generate segmentation masks from an image and bounding box coordinates using SAM2.

        Args:
            pil_image: PIL image that should be segmented
            bboxes: Nx4 array of bounding boxes of objects to be segmented (x1, y1, x2, y2)

        Returns:
            Iterator of N tuples (mask, score, bbox) with mask (HxW) and float score
            Annotated result image
        """
        image = np.asarray(pil_image.convert("RGB"))
        self.sam2_predictor.set_image(image)

        masks, scores, logits = self.sam2_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=np.array(bboxes),
            multimask_output=False,
        )

        # convert the shape to (n, H, W)
        if masks.ndim == 4:
            masks = masks.squeeze(1)

        return zip(
            masks, scores, ImageSegmentation._bboxes_from_masks(masks)
        ), ImageSegmentation._create_result_image(image, masks, bboxes=bboxes)

    def segment_image_from_points(self, pil_image: PIL.Image.Image, points: np.ndarray):
        """Generate segmentation masks from an image and point coordinates with include/exclude labels using SAM2.

        Args:
            pil_image: PIL image that should be segmented
            points: Nx3 array of points with include/exclude flags of the objects to be segmented (x, y, include)

        Returns:
            Iterator of N tuples (mask, score, bbox) with mask (HxW) and float score
            Annotated result image
        """
        image = np.asarray(pil_image.convert("RGB"))
        self.sam2_predictor.set_image(image)

        # convert the points to coordinate and label arrays
        coords = np.array([[point[0], point[1]] for point in points])
        labels = np.array([1 if point[2] else 0 for point in points])

        masks, scores, logits = self.sam2_predictor.predict(
            point_coords=coords,
            point_labels=labels,
            multimask_output=False,
        )

        # convert the shape to (n, H, W)
        if masks.ndim == 4:
            masks = masks.squeeze(1)

        return zip(
            masks, scores, ImageSegmentation._bboxes_from_masks(masks)
        ), ImageSegmentation._create_result_image(image, masks, points=points)

    @staticmethod
    def _create_result_image(
        image: Union[PIL.Image.Image, np.ndarray],
        masks: np.ndarray,
        bboxes: np.ndarray = [],
        labels: np.ndarray = [],
        class_ids: np.ndarray = None,
        points: np.ndarray = [],
    ):
        """Create an annotated result image with masks, bounding boxes, labels, and points overlaid.

        Args:
            image: image (PIL image or numpy array) to annotate
            masks: NxHxW array of object mask(s)
            bboxes: (optional) Nx4 array of object bounding box(es) (x1, y1, x2, y2)
            labels: (optional) Nx1 array of object label(s)
            class_ids: (optional) Nx1 array of object class id(s)
            points: (optional) Nx3 array of object point(s) with include/exclude flags (x, y, include)

        Returns:
            Annotated result image as a numpy array
        """
        img = np.array(image)

        # we have to define the bboxes in the detections even though we might not show them
        detection_bboxes = bboxes
        if bboxes is None or len(bboxes) == 0:
            detection_bboxes = ImageSegmentation._bboxes_from_masks(masks)

        detections = sv.Detections(
            xyxy=detection_bboxes,  # (n, 4)
            mask=masks.astype(bool),  # (n, h, w)
            class_id=class_ids,
        )
        annotated_frame = img.copy()

        # if there are no class ids (i.e. when running without Grounding DINO) we
        # derive the colors from the detection index instead
        colorlookup = sv.ColorLookup.INDEX
        if class_ids is not None and len(class_ids) > 0:
            colorlookup = sv.ColorLookup.CLASS

        # points
        if points is not None:
            for x, y, include in points:
                # green for include (True), red for exclude (False)
                color = (0, 255, 0) if include else (0, 0, 255)  # BGR format
                cv2.circle(annotated_frame, (int(x), int(y)), 8, (0, 0, 0), -1)  # outer ring
                cv2.circle(annotated_frame, (int(x), int(y)), 5, color, -1)  # filled circle

        # bboxes
        if len(bboxes) > 0:
            box_annotator = sv.BoxAnnotator(color_lookup=colorlookup)
            annotated_frame = box_annotator.annotate(
                scene=annotated_frame, detections=detections
            )

        # labels
        if labels is not None and len(labels) > 0:
            label_annotator = sv.LabelAnnotator(color_lookup=colorlookup)
            annotated_frame = label_annotator.annotate(
                scene=annotated_frame, detections=detections, labels=labels
            )

        # masks
        mask_annotator = sv.MaskAnnotator(color_lookup=colorlookup)
        annotated_frame = mask_annotator.annotate(
            scene=annotated_frame, detections=detections
        )

        return annotated_frame

    @staticmethod
    def _bboxes_from_masks(masks: np.ndarray):
        """Create bounding boxes for the provided masks

        Args:
            masks: NxHxW array of object mask(s)

        Returns:
            bboxes: Nx4 array of mask bounding boxes (x1, y1, x2, y2)
        """
        bboxes = []
        for mask in masks:
            mask_bool = np.where(mask != 0)
            if len(mask_bool) != 0 and len(mask_bool[1]) != 0 and len(mask_bool[0]) != 0:
                bboxes.append(
                    [
                        int(np.min(mask_bool[1])),
                        int(np.min(mask_bool[0])),
                        int(np.max(mask_bool[1])),
                        int(np.max(mask_bool[0])),
                    ]
                )
            else:
                bboxes.append([0, 0, 0, 0])

        return np.array(bboxes)

    @staticmethod
    def _centers_of_mass_from_masks(masks: np.ndarray):
        """Calculate the centers of mass for the provided masks

        Args:
            masks: NxHxW array of object mask(s)

        Returns:
            centers_of_mass: Nx2 array of mask centers of mass (x, y)
        """
        return np.array(
            [[x, y] for mask in masks for y, x in [ndimage.center_of_mass(mask)]]
        )
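The two static helpers at the end are easy to sanity-check in isolation. A toy example (a sketch that assumes the container's Python environment, since importing `imagesegmentation` pulls in torch and the Grounded-SAM-2 packages; no model weights are loaded because `__init__` is never called):

```
import numpy as np

from imagesegmentation import ImageSegmentation

# a single 10x10 mask with a 3x5 block of object pixels
mask = np.zeros((10, 10))
mask[2:5, 3:8] = 1

print(ImageSegmentation._bboxes_from_masks(np.array([mask])))
# -> [[3 2 7 4]], i.e. (x1, y1, x2, y2) of the nonzero region
print(ImageSegmentation._centers_of_mass_from_masks(np.array([mask])))
# -> [[5. 3.]], i.e. the (x, y) center of the block
```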
7 requirements.txt Normal file
@@ -0,0 +1,7 @@
fastapi==0.104.1
uvicorn[standard]==0.24.0
pydantic==2.5.0
pillow==10.1.0
numpy==1.24.3
python-multipart==0.0.6
opencv-python==4.8.1.78