mirror of https://github.com/foomo/gsamservice.git

initial commit

commit 7b063e8357 (parent 27ffdf2668)
1 .gitignore vendored Normal file
@@ -0,0 +1 @@
example/out
57 Dockerfile Normal file
@@ -0,0 +1,57 @@
FROM docker.io/pytorch/pytorch:2.3.1-cuda12.1-cudnn8-devel

# arguments to build the Docker image using CUDA
ARG USE_CUDA=0
ARG TORCH_ARCH="7.0;7.5;8.0;8.6"

ENV AM_I_DOCKER=True
ENV BUILD_WITH_CUDA="${USE_CUDA}"
ENV TORCH_CUDA_ARCH_LIST="${TORCH_ARCH}"
ENV CUDA_HOME=/usr/local/cuda-12.1/
# ensure CUDA is correctly set up
ENV PATH=/usr/local/cuda-12.1/bin:${PATH}
ENV LD_LIBRARY_PATH=/usr/local/cuda-12.1/lib64:${LD_LIBRARY_PATH}

# install required packages and a specific gcc/g++
RUN apt-get update && apt-get install --no-install-recommends wget ffmpeg=7:* \
    libsm6=2:* libxext6=2:* git=1:* nano vim=2:* ninja-build gcc-10 g++-10 -y \
    && apt-get clean && apt-get autoremove && rm -rf /var/lib/apt/lists/*

ENV CC=gcc-10
ENV CXX=g++-10

# clone the Grounded-SAM-2 repo
WORKDIR /home/appuser
RUN git clone https://github.com/IDEA-Research/Grounded-SAM-2

# download the SAM2 checkpoints
WORKDIR /home/appuser/Grounded-SAM-2/checkpoints
RUN bash download_ckpts.sh

# download the Grounding DINO checkpoints
WORKDIR /home/appuser/Grounded-SAM-2/gdino_checkpoints
RUN bash download_ckpts.sh

WORKDIR /home/appuser/Grounded-SAM-2

# install essential Python packages
RUN python -m pip install --upgrade pip "setuptools>=62.3.0,<75.9" wheel numpy \
    opencv-python transformers supervision pycocotools addict yapf timm

# install the segment_anything package in editable mode
RUN python -m pip install -e .

# install Grounding DINO
RUN python -m pip install --no-build-isolation -e grounding_dino

# install the server dependencies
COPY requirements.txt requirements.txt
RUN python -m pip install -r requirements.txt

COPY app.py app.py
COPY imagesegmentation.py imagesegmentation.py

# RUN mkdir ../host

# start the server
ENTRYPOINT ["python", "app.py", "--log-level", "debug"]
43 Makefile Normal file
@@ -0,0 +1,43 @@
# Get the version of CUDA and enable it for compilation if CUDA > 11.0
# This solves https://github.com/IDEA-Research/Grounded-Segment-Anything/issues/53
# and https://github.com/IDEA-Research/Grounded-Segment-Anything/issues/84
# when running in Docker

# Check if nvcc is installed
NVCC := $(shell which nvcc)
ifeq ($(NVCC),)
# NVCC not found
USE_CUDA := 0
NVCC_VERSION := "not installed"
else
NVCC_VERSION := $(shell nvcc --version | grep -oP 'release \K[0-9.]+')
USE_CUDA := $(shell echo "$(NVCC_VERSION) > 11" | bc -l)
endif

# Add the list of supported ARCHs
ifeq ($(USE_CUDA), 1)
TORCH_CUDA_ARCH_LIST := "7.0;7.5;8.0;8.6+PTX"
BUILD_MESSAGE := "Trying to build the image with CUDA support"
else
TORCH_CUDA_ARCH_LIST :=
BUILD_MESSAGE := "CUDA $(NVCC_VERSION) is not supported"
endif

build:
	docker build --build-arg USE_CUDA=$(USE_CUDA) \
		--build-arg TORCH_ARCH=$(TORCH_CUDA_ARCH_LIST) \
		--progress=plain -t gsam2 .

run:
	docker run -d --gpus all \
		--restart unless-stopped \
		--name=gsam2 \
		--ipc=host -p 13337:13337 gsam2

run-bash:
	docker run -it --rm --gpus all \
		-v "${PWD}":/home/appuser/host \
		--entrypoint bash \
		--name=gsam2 \
		--network=host \
		--ipc=host gsam2
14 README.md Normal file
@@ -0,0 +1,14 @@
# GSAM Service

Simple server providing [Grounded SAM2](https://github.com/IDEA-Research/Grounded-SAM-2) through a REST API.

## Usage

Build and run the container:

```
make build
make run
```

You can then connect to the server on port 13337. Have a look at `example/main.go` for examples of the provided endpoints.
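For a quick smoke test without the Go example, a request along these lines should work; this is a minimal sketch using only the Python standard library, and the test image name `truck.jpg` is an assumption:

```
# Sketch: query the maskfromtext endpoint with a base64-encoded image.
# Assumes the container is running locally on port 13337.
import base64
import json
import urllib.request

with open("truck.jpg", "rb") as f:  # any test image
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = json.dumps({"image": image_b64, "text": "truck. tire."}).encode("utf-8")
req = urllib.request.Request(
    "http://localhost:13337/gsam2/image/maskfromtext",
    data=payload,
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    result = json.load(resp)

for m in result["masks"]:
    print(m["class_name"], m["score"], m["bbox"])
```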
303 app.py Normal file
@@ -0,0 +1,303 @@
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Tuple
import base64
import io
from PIL import Image
import numpy as np
from imagesegmentation import ImageSegmentation

app = FastAPI(
    title="GSAM2 API",
    description="Grounded SAM 2 Image Segmentation API",
    version="1.0.0",
)

segmentation_model = ImageSegmentation()


# pydantic models for request validation
class Point(BaseModel):
    x: int
    y: int
    include: bool  # True for include, False for exclude


class BoundingBox(BaseModel):
    upper_left: Tuple[int, int]  # (x, y) coordinates
    lower_right: Tuple[int, int]  # (x, y) coordinates


class MaskFromTextRequest(BaseModel):
    image: str  # base64 encoded image
    text: str


class MaskFromBBoxRequest(BaseModel):
    image: str  # base64 encoded image
    bboxes: List[BoundingBox]


class MaskFromPointsRequest(BaseModel):
    image: str  # base64 encoded image
    points: List[Point]


class MaskResult(BaseModel):
    mask: str
    score: float
    bbox: BoundingBox  # bounding box generated from the mask

    # fields, only populated in responses for MaskFromTextRequests
    class_name: str = ""
    dino_bbox: BoundingBox = BoundingBox(upper_left=(0, 0), lower_right=(0, 0))
    center_of_mass: Tuple[float, float] = (0.0, 0.0)


class MaskResponse(BaseModel):
    masks: List[MaskResult]  # list of base64 encoded mask images and respective scores
    image: str  # base64 encoded result image


def decode_base64_image(base64_string: str) -> Image.Image:
    """Helper function to decode a base64 image string to a PIL Image"""
    try:
        # remove data URL prefix if present
        if base64_string.startswith("data:image"):
            base64_string = base64_string.split(",")[1]

        image_data = base64.b64decode(base64_string)
        image = Image.open(io.BytesIO(image_data))
        return image
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image data: {str(e)}")


def encode_mask_to_base64(mask: np.ndarray) -> str:
    """Helper function to encode a mask array to a base64 string"""
    try:
        # convert the mask to a PIL Image (assuming a binary mask)
        mask_image = Image.fromarray((mask * 255).astype(np.uint8), mode="L")

        # convert to base64
        buffer = io.BytesIO()
        mask_image.save(buffer, format="JPEG")
        mask_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return mask_base64
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to encode mask: {str(e)}")


def encode_image_to_base64(image: np.ndarray) -> str:
    """Helper function to encode an image array to a base64 string"""
    try:
        pil_image = Image.fromarray(image.astype(np.uint8))

        # convert to base64
        buffer = io.BytesIO()
        pil_image.save(buffer, format="JPEG")
        image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return image_base64
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to encode image: {str(e)}")


@app.get("/")
async def root():
    return {"message": "GSAM2 API Server", "version": "1.0.0"}


@app.post("/gsam2/image/maskfromtext", response_model=MaskResponse)
async def mask_from_text(request: MaskFromTextRequest):
    """
    Generate segmentation masks from an image and a text description.

    Args:
        request: Contains base64 encoded image and text description

    Returns:
        MaskResponse with a list of base64 encoded masks, their scores and the result image
    """
    try:
        # decode the input image
        pil_image = decode_base64_image(request.image)
        text = request.text

        # segment the image
        masks, annotated_image = segmentation_model.segment_image_from_text(
            pil_image, text
        )

        # encode the results
        enc_masks = [
            MaskResult(
                mask=encode_mask_to_base64(mask),
                score=score,
                bbox=BoundingBox(
                    upper_left=(bbox[0], bbox[1]), lower_right=(bbox[2], bbox[3])
                ),
                class_name=class_name,
                dino_bbox=BoundingBox(
                    upper_left=(round(dino_bbox[0]), round(dino_bbox[1])),
                    lower_right=(round(dino_bbox[2]), round(dino_bbox[3])),
                ),
                center_of_mass=(com[0], com[1]),
            )
            for (mask, score, bbox, dino_bbox, class_name, com) in masks
        ]
        enc_annotated_image = encode_image_to_base64(annotated_image)

        return MaskResponse(
            masks=enc_masks,
            image=enc_annotated_image,
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")


@app.post("/gsam2/image/maskfrombboxes", response_model=MaskResponse)
async def mask_from_bbox(request: MaskFromBBoxRequest):
    """
    Generate segmentation masks from an image and bounding boxes.

    Args:
        request: Contains base64 encoded image and bounding box coordinates

    Returns:
        MaskResponse with a list of base64 encoded masks, their scores and the result image
    """
    try:
        pil_image = decode_base64_image(request.image)

        # validate the bounding box coordinates and collect the boxes
        bboxes = None
        for bbox in request.bboxes:
            x1, y1 = bbox.upper_left
            x2, y2 = bbox.lower_right

            if x1 >= x2 or y1 >= y2:
                raise HTTPException(
                    status_code=400,
                    detail="Invalid bounding box: upper_left must be above and left of lower_right",
                )

            if x1 < 0 or y1 < 0 or x2 > pil_image.width or y2 > pil_image.height:
                raise HTTPException(
                    status_code=400,
                    detail="Bounding box coordinates out of image bounds",
                )

            # convert to the numpy array format expected by ImageSegmentation
            if bboxes is None:
                bboxes = np.array([[x1, y1, x2, y2]])
            else:
                bboxes = np.vstack((bboxes, [[x1, y1, x2, y2]]))

        if bboxes is None:
            raise HTTPException(
                status_code=400, detail="At least one bounding box is required"
            )

        # segment the image
        (masks, annotated_image) = segmentation_model.segment_image_from_bbox(
            pil_image, bboxes
        )

        # encode the results
        enc_masks = [
            MaskResult(
                mask=encode_mask_to_base64(mask),
                score=score,
                bbox=BoundingBox(
                    upper_left=(bbox[0], bbox[1]), lower_right=(bbox[2], bbox[3])
                ),
            )
            for (mask, score, bbox) in masks
        ]
        enc_annotated_image = encode_image_to_base64(annotated_image)

        return MaskResponse(
            masks=enc_masks,
            image=enc_annotated_image,
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")


@app.post("/gsam2/image/maskfrompoints", response_model=MaskResponse)
async def mask_from_points(request: MaskFromPointsRequest):
    """
    Generate segmentation masks from an image and a list of points with include/exclude indicators.

    Args:
        request: Contains base64 encoded image and list of points with include/exclude flags

    Returns:
        MaskResponse with a list of base64 encoded masks, their scores and the result image
    """
    try:
        pil_image = decode_base64_image(request.image)

        # validate the point coordinates
        for i, point in enumerate(request.points):
            if (
                point.x < 0
                or point.x >= pil_image.width
                or point.y < 0
                or point.y >= pil_image.height
            ):
                raise HTTPException(
                    status_code=400, detail=f"Point {i} coordinates out of image bounds"
                )

        # convert the points to the numpy array format expected by ImageSegmentation
        points = None
        if request.points is not None and len(request.points) > 0:
            points = np.array(
                [[point.x, point.y, point.include] for point in request.points]
            )

        if points is None:
            raise HTTPException(
                status_code=400, detail="At least one point is required"
            )

        # segment the image
        (masks, annotated_image) = segmentation_model.segment_image_from_points(
            pil_image, points
        )

        # encode the results
        enc_masks = [
            MaskResult(
                mask=encode_mask_to_base64(mask),
                score=score,
                bbox=BoundingBox(
                    upper_left=(bbox[0], bbox[1]), lower_right=(bbox[2], bbox[3])
                ),
            )
            for (mask, score, bbox) in masks
        ]
        enc_annotated_image = encode_image_to_base64(annotated_image)

        return MaskResponse(
            masks=enc_masks,
            image=enc_annotated_image,
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=13337)
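Note that the masks in a `MaskResponse` come back as base64-encoded JPEGs, so clients need to decode them (and, since JPEG is lossy, re-threshold them) before further processing. A minimal sketch; the `result` dict is assumed to be a parsed JSON response from one of the endpoints above:

```
# Sketch: turn one base64 JPEG mask from the response back into a numpy array.
import base64
import io

import numpy as np
from PIL import Image

def decode_mask(mask_b64: str) -> np.ndarray:
    """Decode a base64 JPEG mask into a binary HxW uint8 array."""
    raw = base64.b64decode(mask_b64)
    gray = np.asarray(Image.open(io.BytesIO(raw)).convert("L"))
    # JPEG compression is lossy, so threshold to recover a clean 0/1 mask
    return (gray > 127).astype(np.uint8)

# usage, given a parsed MaskResponse dict called `result`:
# mask = decode_mask(result["masks"][0]["mask"])
```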
245 example/main.go Normal file
@@ -0,0 +1,245 @@
package main

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/http/httputil"
	"os"
	"time"
)

var (
	url        = "http://localhost:13337"
	fromText   = "/gsam2/image/maskfromtext"
	fromBboxes = "/gsam2/image/maskfrombboxes"
	fromPoints = "/gsam2/image/maskfrompoints"

	image = "truck.jpg"
)

type (
	Point struct {
		X       int  `json:"x"`
		Y       int  `json:"y"`
		Include bool `json:"include"`
	}

	BoundingBox struct {
		UpperLeft  [2]int `json:"upper_left"`
		LowerRight [2]int `json:"lower_right"`
	}

	MaskFromTextRequest struct {
		Image string `json:"image"`
		Text  string `json:"text"`
	}

	MaskFromBBoxRequest struct {
		Image  string        `json:"image"`
		Bboxes []BoundingBox `json:"bboxes"`
	}

	MaskFromPointsRequest struct {
		Image  string  `json:"image"`
		Points []Point `json:"points"`
	}

	MaskResult struct {
		Mask  string      `json:"mask"`
		Score float64     `json:"score"`
		BBox  BoundingBox `json:"bbox"`
		// fields, only populated in responses for MaskFromTextRequests
		ClassName    string      `json:"class_name"`
		DinoBBox     BoundingBox `json:"dino_bbox"`
		CenterOfMass [2]float64  `json:"center_of_mass"`
	}

	MaskResponse struct {
		Masks []MaskResult `json:"masks"`
		Image string       `json:"image"`
	}
)

func main() {
	// ensure the out directory exists
	os.Mkdir("out", 0755)

	// load the sample image and base64 encode it
	dat, err := os.ReadFile(image)
	if err != nil {
		fmt.Println(err)
		return
	}

	// post to the different endpoints
	c := &http.Client{Timeout: time.Minute}
	encImage := base64.StdEncoding.EncodeToString(dat)

	// from text
	err = doFromText(c, "fromtext", encImage, "truck. tire.")
	if err != nil {
		fmt.Printf("error: %s\n", err)
		return
	}

	// from bboxes
	err = doFromBboxes(c, "frombboxes", encImage, []BoundingBox{
		{
			UpperLeft:  [2]int{75, 275},
			LowerRight: [2]int{1725, 850},
		},
		{
			UpperLeft:  [2]int{425, 600},
			LowerRight: [2]int{700, 875},
		},
		{
			UpperLeft:  [2]int{1375, 550},
			LowerRight: [2]int{1650, 800},
		},
		{
			UpperLeft:  [2]int{1240, 675},
			LowerRight: [2]int{1400, 750},
		},
	})
	if err != nil {
		fmt.Printf("error: %s\n", err)
		return
	}

	// from points
	err = doFromPoints(c, "frompoints", encImage, []Point{
		{X: 500, Y: 375, Include: true},
		{X: 1125, Y: 625, Include: true},
		{X: 575, Y: 750, Include: false},
	})
	if err != nil {
		fmt.Printf("error: %s\n", err)
		return
	}
}

func do(c *http.Client, req *http.Request, outname string) error {
	dump, err := httputil.DumpRequest(req, false)
	if err != nil {
		return err
	}
	fmt.Println("request: ", string(dump))

	resp, err := c.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	// dump the response, including the body for error status codes
	dump, err = httputil.DumpResponse(resp, resp.StatusCode >= 300)
	if err != nil {
		return err
	}
	fmt.Println("response: ", string(dump))

	bodyBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		return err
	}

	maskResp := MaskResponse{}
	err = json.Unmarshal(bodyBytes, &maskResp)
	if err != nil {
		return err
	}

	// write the masks to files
	for _, mask := range maskResp.Masks {
		dec, err := base64.StdEncoding.DecodeString(mask.Mask)
		if err != nil {
			return err
		}
		class := ""
		if mask.ClassName != "" {
			class = "-" + mask.ClassName
		}
		os.WriteFile(fmt.Sprintf("out/%s%s-%.4f.jpg", outname, class, mask.Score), dec, 0644)
	}

	dec, err := base64.StdEncoding.DecodeString(maskResp.Image)
	if err != nil {
		return err
	}
	os.WriteFile(fmt.Sprintf("out/%s.jpg", outname), dec, 0644)

	return nil
}

func doFromText(c *http.Client, outname string, encImage string, text string) error {
	body := MaskFromTextRequest{
		Image: encImage,
		Text:  text,
	}
	jsonBody, err := json.Marshal(body)
	if err != nil {
		return err
	}

	req, err := http.NewRequest("POST", url+fromText, bytes.NewReader(jsonBody))
	if err != nil {
		return err
	}
	req.Header.Add("Accept", `application/json`)
	req.Header.Add("Content-Type", `application/json`)

	return do(c, req, outname)
}

func doFromBboxes(c *http.Client, outname string, encImage string, bboxes []BoundingBox) error {
	body := MaskFromBBoxRequest{
		Image:  encImage,
		Bboxes: bboxes,
	}
	jsonBody, err := json.Marshal(body)
	if err != nil {
		return err
	}

	req, err := http.NewRequest("POST", url+fromBboxes, bytes.NewReader(jsonBody))
	if err != nil {
		return err
	}
	req.Header.Add("Accept", `application/json`)
	req.Header.Add("Content-Type", `application/json`)

	return do(c, req, outname)
}

func doFromPoints(c *http.Client, outname string, encImage string, points []Point) error {
	body := MaskFromPointsRequest{
		Image:  encImage,
		Points: points,
	}
	jsonBody, err := json.Marshal(body)
	if err != nil {
		return err
	}

	req, err := http.NewRequest("POST", url+fromPoints, bytes.NewReader(jsonBody))
	if err != nil {
		return err
	}
	req.Header.Add("Accept", `application/json`)
	req.Header.Add("Content-Type", `application/json`)

	return do(c, req, outname)
}
BIN example/truck.jpg Normal file
Binary file not shown. (265 KiB)
313 imagesegmentation.py Normal file
@@ -0,0 +1,313 @@
import numpy as np
import supervision as sv
import cv2
import PIL
from scipy import ndimage

import torch
from torchvision.ops import box_convert

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

from grounding_dino.groundingdino.util.inference import load_model, predict
import grounding_dino.groundingdino.datasets.transforms as T


class ImageSegmentation:
    def __init__(self):
        # select the device for computation
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        print(f"using device: {self.device}")

        if self.device.type == "cuda":
            # NOTE: somehow this didn't work locally inside a docker container
            # use bfloat16 for the entire run
            # original:
            # torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
            # might work without, or with this:
            # torch.autocast("cuda", dtype=torch.float16).__enter__()
            # turn on tf32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
            if torch.cuda.get_device_properties(0).major >= 8:
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True

        sam2_checkpoint = "checkpoints/sam2.1_hiera_large.pt"
        model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

        self.sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=self.device)
        self.sam2_predictor = SAM2ImagePredictor(self.sam2_model)

        grounding_dino_config = (
            "grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
        )
        grounding_dino_checkpoint = "gdino_checkpoints/groundingdino_swint_ogc.pth"
        self.box_threshold = 0.35
        self.text_threshold = 0.25

        self.grounding_model = load_model(
            model_config_path=grounding_dino_config,
            model_checkpoint_path=grounding_dino_checkpoint,
            device=self.device,
        )

    def segment_image_from_text(self, pil_image: PIL.Image.Image, text: str):
        """Generate segmentation masks from an image and a text description using Grounding DINO + SAM2.

        Args:
            pil_image: PIL image that should be segmented
            text: object description(s) to be segmented

        Returns:
            Iterable of tuples (mask, score, bbox, dino_bbox, class_name, center_of_mass)
            with mask (HxW) and float score
            Result image
        """

        # image preparation taken from load_image() in Grounded-SAM-2/grounding_dino/groundingdino/util/inference.py
        pil_image = pil_image.convert("RGB")
        transform = T.Compose(
            [
                T.RandomResize([800], max_size=1333),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        image = np.asarray(pil_image)
        image_transformed, _ = transform(pil_image, None)

        # set the image for sam2
        self.sam2_predictor.set_image(image)

        # predict the bounding boxes
        boxes, confidences, labels = predict(
            model=self.grounding_model,
            image=image_transformed,
            caption=text,
            box_threshold=self.box_threshold,
            text_threshold=self.text_threshold,
            device=self.device,
        )

        if boxes is None or len(boxes) < 1:
            return [], image

        # process the box prompt for SAM 2
        h, w, _ = image.shape
        boxes = boxes * torch.Tensor([w, h, w, h])
        input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()

        # NOTE: somehow this didn't work locally inside a docker container
        # torch.autocast(device_type=self.device, dtype=torch.bfloat16).__enter__()
        # if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
        #     # turn on tf32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
        #     torch.backends.cuda.matmul.allow_tf32 = True
        #     torch.backends.cudnn.allow_tf32 = True

        masks, scores, logits = self.sam2_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_boxes,
            multimask_output=False,
        )

        # convert the shape to (n, H, W)
        if masks.ndim == 4:
            masks = masks.squeeze(1)

        confidences = confidences.numpy().tolist()
        class_names = labels

        class_ids = np.array(list(range(len(class_names))))

        labels = [
            f"{class_name} {confidence:.2f}"
            for class_name, confidence in zip(class_names, confidences)
        ]

        return zip(
            masks,
            scores,
            ImageSegmentation._bboxes_from_masks(masks),
            input_boxes,
            class_names,
            ImageSegmentation._centers_of_mass_from_masks(masks),
        ), ImageSegmentation._create_result_image(
            image, masks, bboxes=input_boxes, labels=labels, class_ids=class_ids
        )

    def segment_image_from_bbox(self, pil_image: PIL.Image.Image, bboxes: np.ndarray):
        """Generate segmentation masks from an image and bounding box coordinates using SAM2.

        Args:
            pil_image: PIL image that should be segmented
            bboxes: Nx4 array of bounding boxes of objects to be segmented (x1, y1, x2, y2)

        Returns:
            Iterable of tuples (mask, score, bbox) with mask (HxW) and float score
            Result image
        """
        image = np.asarray(pil_image.convert("RGB"))
        self.sam2_predictor.set_image(image)

        masks, scores, logits = self.sam2_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=np.array(bboxes),
            multimask_output=False,
        )

        # convert the shape to (n, H, W)
        if masks.ndim == 4:
            masks = masks.squeeze(1)

        return zip(
            masks, scores, ImageSegmentation._bboxes_from_masks(masks)
        ), ImageSegmentation._create_result_image(image, masks, bboxes=bboxes)

    def segment_image_from_points(self, pil_image: PIL.Image.Image, points: np.ndarray):
        """Generate segmentation masks from an image and point coordinates with include/exclude labels using SAM2.

        Args:
            pil_image: PIL image that should be segmented
            points: Nx3 array of points with include/exclude flags of the objects to be segmented (x, y, include)

        Returns:
            Iterable of tuples (mask, score, bbox) with mask (HxW) and float score
            Result image
        """
        image = np.asarray(pil_image.convert("RGB"))
        self.sam2_predictor.set_image(image)

        # convert the points to coordinate and label arrays
        coords = np.array([[point[0], point[1]] for point in points])
        labels = np.array([1 if point[2] else 0 for point in points])

        masks, scores, logits = self.sam2_predictor.predict(
            point_coords=coords,
            point_labels=labels,
            multimask_output=False,
        )

        # convert the shape to (n, H, W)
        if masks.ndim == 4:
            masks = masks.squeeze(1)

        return zip(
            masks, scores, ImageSegmentation._bboxes_from_masks(masks)
        ), ImageSegmentation._create_result_image(image, masks, points=points)

    @staticmethod
    def _create_result_image(
        image: np.ndarray,
        masks: np.ndarray,
        bboxes: np.ndarray = [],
        labels: np.ndarray = [],
        class_ids: np.ndarray = None,
        points: np.ndarray = [],
    ):
        """Create an annotated result image with masks, bounding boxes, labels, and points overlaid.

        Args:
            image: RGB image (HxWx3 array) that was segmented
            masks: NxHxW array of object mask(s)
            bboxes: (optional) Nx4 array of object bounding box(es) (x1, y1, x2, y2)
            labels: (optional) Nx1 array of object label(s)
            class_ids: (optional) Nx1 array of object class id(s)
            points: (optional) Nx3 array of object point(s) with include/exclude flags (x, y, include)

        Returns:
            Annotated result image
        """
        img = np.array(image)

        # we have to define the bboxes in the detections even though we might not show them
        detection_bboxes = bboxes
        if bboxes is None or len(bboxes) == 0:
            detection_bboxes = ImageSegmentation._bboxes_from_masks(masks)

        detections = sv.Detections(
            xyxy=detection_bboxes,  # (n, 4)
            mask=masks.astype(bool),  # (n, h, w)
            class_id=class_ids,
        )
        annotated_frame = img.copy()

        # if there are no class ids (i.e. when running without Grounding DINO) we need to
        # derive the color lookup from the detection index
        colorlookup = sv.ColorLookup.INDEX
        if class_ids is not None and len(class_ids) > 0:
            colorlookup = sv.ColorLookup.CLASS

        # points
        if points is not None:
            for x, y, include in points:
                # green for include (True), red for exclude (False)
                color = (0, 255, 0) if include else (0, 0, 255)  # BGR format
                cv2.circle(annotated_frame, (int(x), int(y)), 8, (0, 0, 0), -1)  # outer ring
                cv2.circle(annotated_frame, (int(x), int(y)), 5, color, -1)  # filled circle

        # bboxes
        if len(bboxes) > 0:
            box_annotator = sv.BoxAnnotator(color_lookup=colorlookup)
            annotated_frame = box_annotator.annotate(
                scene=annotated_frame, detections=detections
            )

        # labels
        if labels is not None and len(labels) > 0:
            label_annotator = sv.LabelAnnotator(color_lookup=colorlookup)
            annotated_frame = label_annotator.annotate(
                scene=annotated_frame, detections=detections, labels=labels
            )

        # masks
        mask_annotator = sv.MaskAnnotator(color_lookup=colorlookup)
        annotated_frame = mask_annotator.annotate(
            scene=annotated_frame, detections=detections
        )

        return annotated_frame

    @staticmethod
    def _bboxes_from_masks(masks: np.ndarray):
        """Create bounding boxes for the provided masks

        Args:
            masks: NxHxW array of object mask(s)

        Returns:
            Nx4 array of mask bounding boxes (x1, y1, x2, y2)
        """

        bboxes = []
        for mask in masks:
            ys, xs = np.nonzero(mask)
            if ys.size > 0 and xs.size > 0:
                bboxes.append(
                    [
                        int(np.min(xs)),
                        int(np.min(ys)),
                        int(np.max(xs)),
                        int(np.max(ys)),
                    ]
                )
            else:
                # empty mask: fall back to a degenerate box
                bboxes.append([0, 0, 0, 0])

        return np.array(bboxes)

    @staticmethod
    def _centers_of_mass_from_masks(masks: np.ndarray):
        """Calculate the centers of mass for the provided masks

        Args:
            masks: NxHxW array of object mask(s)

        Returns:
            Nx2 array of mask centers of mass (x, y)
        """

        return np.array(
            [[x, y] for mask in masks for y, x in [ndimage.center_of_mass(mask)]]
        )
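For experiments outside the server, the class can also be driven directly. A rough sketch, assuming it runs from the Grounded-SAM-2 working directory (so the checkpoint paths in `__init__` resolve) and that a `truck.jpg` test image exists; the unpacking mirrors what `segment_image_from_bbox` returns:

```
# Sketch: use ImageSegmentation directly, without the FastAPI server.
import numpy as np
from PIL import Image

from imagesegmentation import ImageSegmentation

seg = ImageSegmentation()  # loads the SAM2 + Grounding DINO checkpoints

img = Image.open("truck.jpg")
bboxes = np.array([[75, 275, 1725, 850]])  # x1, y1, x2, y2

results, annotated = seg.segment_image_from_bbox(img, bboxes)
for mask, score, bbox in results:
    print(f"score={float(score):.3f} bbox={bbox.tolist()} mask shape={mask.shape}")

Image.fromarray(annotated).save("annotated.jpg")
```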
7 requirements.txt Normal file
@@ -0,0 +1,7 @@
fastapi==0.104.1
uvicorn[standard]==0.24.0
pydantic==2.5.0
pillow==10.1.0
numpy==1.24.3
python-multipart==0.0.6
opencv-python==4.8.1.78