# gsamservice/app.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Tuple
import base64
import io
from PIL import Image
import numpy as np
import cv2
from imagesegmentation import ImageSegmentation

app = FastAPI(
    title="GSAM2 API",
    description="Grounded SAM 2 Image Segmentation API",
    version="1.0.0",
)

segmentation_model = ImageSegmentation()


# pydantic models for request validation
class Point(BaseModel):
    x: int
    y: int
    include: bool  # True for include, False for exclude


class BoundingBox(BaseModel):
    upper_left: Tuple[int, int]  # (x, y) coordinates
    lower_right: Tuple[int, int]  # (x, y) coordinates


class MaskFromTextRequest(BaseModel):
    image: str  # base64 encoded image
    text: str


class MaskFromBBoxRequest(BaseModel):
    image: str  # base64 encoded image
    bboxes: List[BoundingBox]


class MaskFromPointsRequest(BaseModel):
    image: str  # base64 encoded image
    points: List[Point]


class MaskResult(BaseModel):
    mask: str
    score: float
    bbox: BoundingBox  # bounding box generated from the mask
    # the following fields are only populated in responses to mask-from-text requests
    class_name: str = ""
    dino_bbox: BoundingBox = BoundingBox(upper_left=(0, 0), lower_right=(0, 0))
    center_of_mass: Tuple[float, float] = (0.0, 0.0)


class MaskResponse(BaseModel):
    masks: List[MaskResult]  # list of base64 encoded mask images and respective scores
    image: str  # base64 encoded result image
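

# A MaskResponse on the wire looks roughly like this (a sketch only; the
# base64 payloads are abbreviated and the field values are illustrative):
#
#   {
#     "masks": [
#       {
#         "mask": "iVBORw0KGgo...",
#         "score": 0.97,
#         "bbox": {"upper_left": [12, 34], "lower_right": [200, 180]},
#         "class_name": "dog",
#         "dino_bbox": {"upper_left": [10, 30], "lower_right": [205, 185]},
#         "center_of_mass": [105.2, 98.7]
#       }
#     ],
#     "image": "/9j/4AAQSkZJRg..."
#   }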


def decode_base64_image(base64_string: str) -> Image.Image:
    """Helper function to decode a base64 image string to a PIL Image"""
    try:
        # strip the data URL prefix if present
        if base64_string.startswith("data:image"):
            base64_string = base64_string.split(",")[1]
        image_data = base64.b64decode(base64_string)
        image = Image.open(io.BytesIO(image_data))
        return image
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image data: {str(e)}")
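

# Note: decode_base64_image accepts both raw base64 strings and data URLs,
# e.g. "iVBORw0KGgo..." as well as "data:image/png;base64,iVBORw0KGgo...".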


def encode_mask_to_base64(mask: np.ndarray) -> str:
    """Helper function to encode a mask array to a base64 string"""
    try:
        # convert the mask to a PIL Image (assuming a binary mask)
        mask_image = Image.fromarray((mask * 255).astype(np.uint8), mode="L")
        # encode as PNG: lossless, so the binary mask survives the round trip
        # (lossy JPEG compression would smear the mask edges into gray values)
        buffer = io.BytesIO()
        mask_image.save(buffer, format="PNG")
        mask_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return mask_base64
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to encode mask: {str(e)}")


def encode_image_to_base64(image: np.ndarray) -> str:
    """Helper function to encode a cv2 image array to a base64 string"""
    try:
        # cv2 arrays are BGR by convention; convert to RGB so the encoded
        # JPEG has the right colors (assumes the input really is BGR)
        rgb_image = cv2.cvtColor(image.astype(np.uint8), cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(rgb_image)
        # convert to base64
        buffer = io.BytesIO()
        pil_image.save(buffer, format="JPEG")
        image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return image_base64
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to encode image: {str(e)}")


@app.get("/")
async def root():
    return {"message": "GSAM2 API Server", "version": "1.0.0"}
@app.post("/gsam2/image/maskfromtext", response_model=MaskResponse)
async def mask_from_text(request: MaskFromTextRequest):
"""
Generate segmentation masks from an image and text description.
Args:
request: Contains base64 encoded image and text description
Returns:
MaskResponse with list of base64 encoded masks, their scores and result image
"""
try:
# decode the input image
pil_image = decode_base64_image(request.image)
text = request.text
# segment the image
masks, annotated_image = segmentation_model.segment_image_from_text(
pil_image, text
)
# encode the results
enc_masks = [
MaskResult(
mask=encode_mask_to_base64(mask),
score=score,
bbox=BoundingBox(
upper_left=(bbox[0], bbox[1]), lower_right=(bbox[2], bbox[3])
),
class_name=class_name,
dino_bbox=BoundingBox(
upper_left=(round(dino_bbox[0]), round(dino_bbox[1])),
lower_right=(round(dino_bbox[2]), round(dino_bbox[3])),
),
center_of_mass=(com[0], com[1]),
)
for (mask, score, bbox, dino_bbox, class_name, com) in masks
]
enc_annotated_image = encode_image_to_base64(annotated_image)
return MaskResponse(
masks=enc_masks,
image=enc_annotated_image,
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
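

# Example client call for this endpoint (a minimal sketch; assumes the server
# runs locally on port 13337 and that the "requests" package is installed --
# neither is part of this module):
#
#   import base64, requests
#
#   with open("example.jpg", "rb") as f:
#       payload = {
#           "image": base64.b64encode(f.read()).decode("utf-8"),
#           "text": "a dog",
#       }
#   resp = requests.post(
#       "http://localhost:13337/gsam2/image/maskfromtext", json=payload
#   )
#   resp.raise_for_status()
#   for result in resp.json()["masks"]:
#       print(result["class_name"], result["score"])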
@app.post("/gsam2/image/maskfrombboxes", response_model=MaskResponse)
async def mask_from_bbox(request: MaskFromBBoxRequest):
"""
Generate segmentation masks from an image and bounding box.
Args:
request: Contains base64 encoded image and bounding box coordinates
Returns:
MaskResponse with list of base64 encoded masks, their scores and result image
"""
try:
pil_image = decode_base64_image(request.image)
# validate bounding box coordinates
bboxes = None
for bbox in request.bboxes:
x1, y1 = bbox.upper_left
x2, y2 = bbox.lower_right
if x1 >= x2 or y1 >= y2:
raise HTTPException(
status_code=400,
detail="Invalid bounding box: upper_left must be above and left of lower_right",
)
if x1 < 0 or y1 < 0 or x2 > pil_image.width or y2 > pil_image.height:
raise HTTPException(
status_code=400,
detail="Bounding box coordinates out of image bounds",
)
# convert to numpy array format expected by ImageSegmentation
if bboxes is None:
bboxes = np.array([[x1, y1, x2, y2]])
else:
bboxes = np.vstack((bboxes, [[x1, y1, x2, y2]]))
if bboxes is None:
raise HTTPException(
status_code=400, detail="At least one bounding box is required"
)
# segment the image
(masks, annotated_image) = segmentation_model.segment_image_from_bbox(
pil_image, bboxes
)
# encode the results
enc_masks = [
MaskResult(
mask=encode_mask_to_base64(mask),
score=score,
bbox=BoundingBox(
upper_left=(bbox[0], bbox[1]), lower_right=(bbox[2], bbox[3])
),
)
for (mask, score, bbox) in masks
]
enc_annotated_image = encode_image_to_base64(annotated_image)
return MaskResponse(
masks=enc_masks,
image=enc_annotated_image,
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
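

# The request payload for this endpoint looks like this (a sketch; the base64
# image payload is abbreviated, coordinates are illustrative):
#
#   {
#     "image": "/9j/4AAQSkZJRg...",
#     "bboxes": [
#       {"upper_left": [12, 34], "lower_right": [200, 180]},
#       {"upper_left": [220, 40], "lower_right": [300, 150]}
#     ]
#   }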
@app.post("/gsam2/image/maskfrompoints", response_model=MaskResponse)
async def mask_from_points(request: MaskFromPointsRequest):
"""
Generate segmentation masks from an image and list of points with include/exclude indicators.
Args:
request: Contains base64 encoded image and list of points with include/exclude flags
Returns:
MaskResponse with list of base64 encoded masks, their scores and result image
"""
try:
pil_image = decode_base64_image(request.image)
# validate point coordinates
for i, point in enumerate(request.points):
if (
point.x < 0
or point.x >= pil_image.width
or point.y < 0
or point.y >= pil_image.height
):
raise HTTPException(
status_code=400, detail=f"Point {i} coordinates out of image bounds"
)
# convert points to numpy array format expected by ImageSegmentation
points = None
if request.points is not None and len(request.points) > 0:
points = np.array(
[[point.x, point.y, point.include] for point in request.points]
)
if points is None:
raise HTTPException(
status_code=400, detail="At least one point is required"
)
# segment the image
(masks, annotated_image) = segmentation_model.segment_image_from_points(
pil_image, points
)
# encode the results
enc_masks = [
MaskResult(
mask=encode_mask_to_base64(mask),
score=score,
bbox=BoundingBox(
upper_left=(bbox[0], bbox[1]), lower_right=(bbox[2], bbox[3])
),
)
for (mask, score, bbox) in masks
]
enc_annotated_image = encode_image_to_base64(annotated_image)
return MaskResponse(
masks=enc_masks,
image=enc_annotated_image,
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
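

# The request payload for this endpoint looks like this (a sketch; the base64
# image payload is abbreviated; "include": true marks a foreground point,
# false an exclusion point):
#
#   {
#     "image": "/9j/4AAQSkZJRg...",
#     "points": [
#       {"x": 120, "y": 80, "include": true},
#       {"x": 40, "y": 200, "include": false}
#     ]
#   }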


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=13337)