mirror of
https://github.com/foomo/gsamservice.git
synced 2025-10-16 12:35:37 +00:00
304 lines
9.4 KiB
Python
304 lines
9.4 KiB
Python
from fastapi import FastAPI, HTTPException
|
|
from pydantic import BaseModel
|
|
from typing import List, Tuple, Union, Optional
|
|
import base64
|
|
import io
|
|
from PIL import Image
|
|
import numpy as np
|
|
import cv2
|
|
from imagesegmentation import ImageSegmentation
|
|
|
|
# FastAPI application exposing the Grounded SAM 2 segmentation endpoints.
app = FastAPI(
    title="GSAM2 API",
    description="Grounded SAM 2 Image Segmentation API",
    version="1.0.0",
)

# Single module-level model instance shared by all request handlers.
# NOTE(review): constructed at import time; presumably this loads model
# weights as well -- confirm startup cost is acceptable.
segmentation_model = ImageSegmentation()
|
# pydantic models for request validation
|
|
class Point(BaseModel):
    """A single 2-D prompt point in pixel coordinates."""

    x: int
    y: int
    include: bool  # True for include, False for exclude
|
|
|
|
class BoundingBox(BaseModel):
    """Axis-aligned box given by its two corner points in pixel coordinates."""

    upper_left: Tuple[int, int]  # (x, y) coordinates
    lower_right: Tuple[int, int]  # (x, y) coordinates
|
|
|
class MaskFromTextRequest(BaseModel):
    """Request body for /gsam2/image/maskfromtext."""

    image: str  # base64 encoded image
    text: str  # text description of the objects to segment
|
|
|
class MaskFromBBoxRequest(BaseModel):
    """Request body for /gsam2/image/maskfrombboxes."""

    image: str  # base64 encoded image
    bboxes: List[BoundingBox]  # one or more boxes to segment within
|
|
|
class MaskFromPointsRequest(BaseModel):
    """Request body for /gsam2/image/maskfrompoints."""

    image: str  # base64 encoded image
    points: List[Point]  # include/exclude prompt points
|
|
|
class MaskResult(BaseModel):
    """One segmentation mask with its confidence score and derived geometry."""

    mask: str  # base64 encoded mask image
    score: float  # model confidence for this mask
    bbox: BoundingBox  # bounding box generated from the mask

    # fields, only populated in responses for MaskFromTextRequests
    class_name: str = ""
    dino_bbox: BoundingBox = BoundingBox(upper_left=(0, 0), lower_right=(0, 0))
    center_of_mass: Tuple[float, float] = (0.0, 0.0)
|
|
|
class MaskResponse(BaseModel):
    """Response body shared by all three segmentation endpoints."""

    masks: List[MaskResult]  # list of base64 encoded mask images and respective scores
    image: str  # base64 encoded result image
|
|
|
def decode_base64_image(base64_string: str) -> Image.Image:
    """Decode a base64-encoded image string into a PIL Image.

    Accepts both raw base64 payloads and data URLs of the form
    ``data:image/...;base64,<payload>``.

    Args:
        base64_string: Base64-encoded image, optionally with a data-URL prefix.

    Returns:
        The decoded PIL image.

    Raises:
        HTTPException: 400 if the payload cannot be decoded or opened.
    """
    try:
        # remove data URL prefix if present
        if base64_string.startswith("data:image"):
            base64_string = base64_string.split(",")[1]

        image_data = base64.b64decode(base64_string)
        image = Image.open(io.BytesIO(image_data))
        return image
    except Exception as e:
        # Chain the original error (PEP 3134) so logs show the real decode failure.
        raise HTTPException(status_code=400, detail=f"Invalid image data: {str(e)}") from e
|
|
|
|
|
def encode_mask_to_base64(mask: np.ndarray) -> str:
    """Encode a binary mask array as a base64 PNG string.

    Args:
        mask: Mask array; assumed to hold 0/1 (or boolean) values --
            TODO confirm upstream always produces a binary mask.

    Returns:
        Base64-encoded PNG of the mask scaled to 0/255 grayscale.

    Raises:
        HTTPException: 500 if encoding fails.
    """
    try:
        # convert mask to PIL Image (assuming binary mask)
        mask_image = Image.fromarray((mask * 255).astype(np.uint8), mode="L")

        # Use PNG (lossless): JPEG compression blurs the hard 0/255 edges
        # and leaves intermediate gray values, so the mask would no longer
        # be binary after the client decodes it.
        buffer = io.BytesIO()
        mask_image.save(buffer, format="PNG")
        mask_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return mask_base64
    except Exception as e:
        # Chain the original error so the 500 detail and logs show the cause.
        raise HTTPException(status_code=500, detail=f"Failed to encode mask: {str(e)}") from e
|
|
|
def encode_image_to_base64(image: np.ndarray) -> str:
    """Encode an image array as a base64 JPEG string.

    Args:
        image: Image array; assumed to be in RGB channel order as produced by
            the segmentation model's annotator -- TODO confirm (raw cv2 images
            are BGR and would render with swapped channels).

    Returns:
        Base64-encoded JPEG of the image.

    Raises:
        HTTPException: 500 if encoding fails (e.g. JPEG cannot store an
            alpha channel).
    """
    try:
        pil_image = Image.fromarray(image.astype(np.uint8))

        # convert to base64
        buffer = io.BytesIO()
        pil_image.save(buffer, format="JPEG")
        image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return image_base64
    except Exception as e:
        # Chain the original error so the 500 detail and logs show the cause.
        raise HTTPException(status_code=500, detail=f"Failed to encode image: {str(e)}") from e
|
@app.get("/")
async def root():
    """Landing endpoint reporting the service name and version."""
    payload = {"message": "GSAM2 API Server", "version": "1.0.0"}
    return payload
|
@app.post("/gsam2/image/maskfromtext", response_model=MaskResponse)
async def mask_from_text(request: MaskFromTextRequest):
    """
    Generate segmentation masks from an image and text description.

    Args:
        request: Contains base64 encoded image and text description

    Returns:
        MaskResponse with list of base64 encoded masks, their scores and result image
    """
    try:
        # decode the input image
        source_image = decode_base64_image(request.image)

        # segment the image
        masks, annotated_image = segmentation_model.segment_image_from_text(
            source_image, request.text
        )

        # encode each result tuple into its response model
        results = []
        for mask, score, bbox, dino_bbox, class_name, com in masks:
            results.append(
                MaskResult(
                    mask=encode_mask_to_base64(mask),
                    score=score,
                    bbox=BoundingBox(
                        upper_left=(bbox[0], bbox[1]),
                        lower_right=(bbox[2], bbox[3]),
                    ),
                    class_name=class_name,
                    dino_bbox=BoundingBox(
                        upper_left=(round(dino_bbox[0]), round(dino_bbox[1])),
                        lower_right=(round(dino_bbox[2]), round(dino_bbox[3])),
                    ),
                    center_of_mass=(com[0], com[1]),
                )
            )

        return MaskResponse(
            masks=results,
            image=encode_image_to_base64(annotated_image),
        )

    except HTTPException:
        # pass validation errors from the helpers through unchanged
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
|
|
@app.post("/gsam2/image/maskfrombboxes", response_model=MaskResponse)
async def mask_from_bbox(request: MaskFromBBoxRequest):
    """
    Generate segmentation masks from an image and bounding box.

    Args:
        request: Contains base64 encoded image and bounding box coordinates

    Returns:
        MaskResponse with list of base64 encoded masks, their scores and result image

    Raises:
        HTTPException: 400 for an empty box list or invalid/out-of-bounds
            boxes, 500 for unexpected processing failures.
    """
    try:
        pil_image = decode_base64_image(request.image)

        if not request.bboxes:
            raise HTTPException(
                status_code=400, detail="At least one bounding box is required"
            )

        # validate bounding box coordinates
        for bbox in request.bboxes:
            x1, y1 = bbox.upper_left
            x2, y2 = bbox.lower_right

            if x1 >= x2 or y1 >= y2:
                raise HTTPException(
                    status_code=400,
                    detail="Invalid bounding box: upper_left must be above and left of lower_right",
                )

            if x1 < 0 or y1 < 0 or x2 > pil_image.width or y2 > pil_image.height:
                raise HTTPException(
                    status_code=400,
                    detail="Bounding box coordinates out of image bounds",
                )

        # build the (N, 4) array expected by ImageSegmentation in one shot
        # instead of O(n^2) incremental np.vstack accumulation
        bboxes = np.array(
            [[*bbox.upper_left, *bbox.lower_right] for bbox in request.bboxes]
        )

        # segment the image
        (masks, annotated_image) = segmentation_model.segment_image_from_bbox(
            pil_image, bboxes
        )

        # encode the results
        enc_masks = [
            MaskResult(
                mask=encode_mask_to_base64(mask),
                score=score,
                bbox=BoundingBox(
                    upper_left=(bbox[0], bbox[1]), lower_right=(bbox[2], bbox[3])
                ),
            )
            for (mask, score, bbox) in masks
        ]
        enc_annotated_image = encode_image_to_base64(annotated_image)

        return MaskResponse(
            masks=enc_masks,
            image=enc_annotated_image,
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
|
|
|
@app.post("/gsam2/image/maskfrompoints", response_model=MaskResponse)
async def mask_from_points(request: MaskFromPointsRequest):
    """
    Generate segmentation masks from an image and list of points with include/exclude indicators.

    Args:
        request: Contains base64 encoded image and list of points with include/exclude flags

    Returns:
        MaskResponse with list of base64 encoded masks, their scores and result image
    """
    try:
        pil_image = decode_base64_image(request.image)

        # validate point coordinates
        for i, point in enumerate(request.points):
            inside_x = 0 <= point.x < pil_image.width
            inside_y = 0 <= point.y < pil_image.height
            if not (inside_x and inside_y):
                raise HTTPException(
                    status_code=400, detail=f"Point {i} coordinates out of image bounds"
                )

        if not request.points:
            raise HTTPException(
                status_code=400, detail="At least one point is required"
            )

        # convert points to the (N, 3) [x, y, include] array layout expected
        # by ImageSegmentation
        points = np.array(
            [[point.x, point.y, point.include] for point in request.points]
        )

        # segment the image
        (masks, annotated_image) = segmentation_model.segment_image_from_points(
            pil_image, points
        )

        # encode each result tuple into its response model
        encoded = []
        for mask, score, bbox in masks:
            encoded.append(
                MaskResult(
                    mask=encode_mask_to_base64(mask),
                    score=score,
                    bbox=BoundingBox(
                        upper_left=(bbox[0], bbox[1]),
                        lower_right=(bbox[2], bbox[3]),
                    ),
                )
            )

        return MaskResponse(
            masks=encoded,
            image=encode_image_to_base64(annotated_image),
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
|
|
if __name__ == "__main__":
    import uvicorn

    # Run the API directly; 0.0.0.0 exposes it on all interfaces at port 13337.
    uvicorn.run(app, host="0.0.0.0", port=13337)