from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Tuple

import base64
import io

import numpy as np
from PIL import Image

from imagesegmentation import ImageSegmentation

app = FastAPI(
    title="GSAM2 API",
    description="Grounded SAM 2 Image Segmentation API",
    version="1.0.0",
)

segmentation_model = ImageSegmentation()


# pydantic models for request validation
class Point(BaseModel):
    x: int
    y: int
    include: bool  # True for include, False for exclude


class BoundingBox(BaseModel):
    upper_left: Tuple[int, int]  # (x, y) coordinates
    lower_right: Tuple[int, int]  # (x, y) coordinates


class MaskFromTextRequest(BaseModel):
    image: str  # base64 encoded image
    text: str


class MaskFromBBoxRequest(BaseModel):
    image: str  # base64 encoded image
    bboxes: List[BoundingBox]


class MaskFromPointsRequest(BaseModel):
    image: str  # base64 encoded image
    points: List[Point]


class MaskResult(BaseModel):
    mask: str  # base64 encoded mask image
    score: float
    bbox: BoundingBox  # bounding box generated from the mask
    # fields only populated in responses to MaskFromTextRequest
    class_name: str = ""
    dino_bbox: BoundingBox = BoundingBox(upper_left=(0, 0), lower_right=(0, 0))
    center_of_mass: Tuple[float, float] = (0.0, 0.0)


class MaskResponse(BaseModel):
    masks: List[MaskResult]  # base64 encoded mask images and their respective scores
    image: str  # base64 encoded result image


def decode_base64_image(base64_string: str) -> Image.Image:
    """Helper function to decode a base64 image string to a PIL Image."""
    try:
        # remove data URL prefix if present
        if base64_string.startswith("data:image"):
            base64_string = base64_string.split(",")[1]
        image_data = base64.b64decode(base64_string)
        return Image.open(io.BytesIO(image_data))
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid image data: {str(e)}")


def encode_mask_to_base64(mask: np.ndarray) -> str:
    """Helper function to encode a binary mask array to a base64 string."""
    try:
        # convert mask to a grayscale PIL Image (assuming binary mask)
        mask_image = Image.fromarray((mask * 255).astype(np.uint8), mode="L")
        # save as PNG: lossless, so the binary mask survives the round trip
        # (JPEG compression would blur mask edges)
        buffer = io.BytesIO()
        mask_image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to encode mask: {str(e)}")


def encode_image_to_base64(image: np.ndarray) -> str:
    """Helper function to encode an annotated image array to a base64 string."""
    try:
        pil_image = Image.fromarray(image.astype(np.uint8))
        # convert to base64
        buffer = io.BytesIO()
        pil_image.save(buffer, format="JPEG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to encode image: {str(e)}")


@app.get("/")
async def root():
    return {"message": "GSAM2 API Server", "version": "1.0.0"}
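
# The response masks are base64 encoded PNGs. A minimal sketch of the inverse
# operation, for clients and tests (this helper is an addition, not part of the
# API surface; the > 128 threshold assumes the 0/255 scaling used in
# encode_mask_to_base64 above):
def decode_base64_mask(mask_base64: str) -> np.ndarray:
    """Decode a mask produced by encode_mask_to_base64 back to a bool array."""
    mask_image = Image.open(io.BytesIO(base64.b64decode(mask_base64)))
    return np.array(mask_image) > 128
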
@app.post("/gsam2/image/maskfromtext", response_model=MaskResponse)
async def mask_from_text(request: MaskFromTextRequest):
    """
    Generate segmentation masks from an image and a text description.

    Args:
        request: Contains a base64 encoded image and a text description

    Returns:
        MaskResponse with a list of base64 encoded masks, their scores, and the result image
    """
    try:
        # decode the input image
        pil_image = decode_base64_image(request.image)
        text = request.text

        # segment the image
        masks, annotated_image = segmentation_model.segment_image_from_text(
            pil_image, text
        )

        # encode the results
        enc_masks = [
            MaskResult(
                mask=encode_mask_to_base64(mask),
                score=score,
                bbox=BoundingBox(
                    upper_left=(bbox[0], bbox[1]), lower_right=(bbox[2], bbox[3])
                ),
                class_name=class_name,
                dino_bbox=BoundingBox(
                    upper_left=(round(dino_bbox[0]), round(dino_bbox[1])),
                    lower_right=(round(dino_bbox[2]), round(dino_bbox[3])),
                ),
                center_of_mass=(com[0], com[1]),
            )
            for (mask, score, bbox, dino_bbox, class_name, com) in masks
        ]
        enc_annotated_image = encode_image_to_base64(annotated_image)

        return MaskResponse(
            masks=enc_masks,
            image=enc_annotated_image,
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")


@app.post("/gsam2/image/maskfrombboxes", response_model=MaskResponse)
async def mask_from_bbox(request: MaskFromBBoxRequest):
    """
    Generate segmentation masks from an image and a list of bounding boxes.

    Args:
        request: Contains a base64 encoded image and bounding box coordinates

    Returns:
        MaskResponse with a list of base64 encoded masks, their scores, and the result image
    """
    try:
        pil_image = decode_base64_image(request.image)

        if not request.bboxes:
            raise HTTPException(
                status_code=400, detail="At least one bounding box is required"
            )

        # validate bounding box coordinates
        rows = []
        for bbox in request.bboxes:
            x1, y1 = bbox.upper_left
            x2, y2 = bbox.lower_right
            if x1 >= x2 or y1 >= y2:
                raise HTTPException(
                    status_code=400,
                    detail="Invalid bounding box: upper_left must be above and to the left of lower_right",
                )
            if x1 < 0 or y1 < 0 or x2 > pil_image.width or y2 > pil_image.height:
                raise HTTPException(
                    status_code=400,
                    detail="Bounding box coordinates out of image bounds",
                )
            rows.append([x1, y1, x2, y2])

        # convert to the (N, 4) numpy array format expected by ImageSegmentation
        bboxes = np.array(rows)

        # segment the image
        masks, annotated_image = segmentation_model.segment_image_from_bbox(
            pil_image, bboxes
        )

        # encode the results
        enc_masks = [
            MaskResult(
                mask=encode_mask_to_base64(mask),
                score=score,
                bbox=BoundingBox(
                    upper_left=(bbox[0], bbox[1]), lower_right=(bbox[2], bbox[3])
                ),
            )
            for (mask, score, bbox) in masks
        ]
        enc_annotated_image = encode_image_to_base64(annotated_image)

        return MaskResponse(
            masks=enc_masks,
            image=enc_annotated_image,
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
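
# For reference, a minimal client sketch for the bbox endpoint above. The host,
# port, and file name are assumptions (the port matches the __main__ block at
# the bottom of this file), and `requests` must be installed:
#
#   import requests
#
#   with open("example.jpg", "rb") as f:
#       image_b64 = base64.b64encode(f.read()).decode("utf-8")
#   payload = {
#       "image": image_b64,
#       "bboxes": [{"upper_left": [10, 20], "lower_right": [200, 240]}],
#   }
#   resp = requests.post(
#       "http://localhost:13337/gsam2/image/maskfrombboxes", json=payload
#   )
#   resp.raise_for_status()
#   first_mask = decode_base64_mask(resp.json()["masks"][0]["mask"])
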
@app.post("/gsam2/image/maskfrompoints", response_model=MaskResponse)
async def mask_from_points(request: MaskFromPointsRequest):
    """
    Generate segmentation masks from an image and a list of points with
    include/exclude indicators.

    Args:
        request: Contains a base64 encoded image and a list of points with include/exclude flags

    Returns:
        MaskResponse with a list of base64 encoded masks, their scores, and the result image
    """
    try:
        pil_image = decode_base64_image(request.image)

        if not request.points:
            raise HTTPException(status_code=400, detail="At least one point is required")

        # validate point coordinates
        for i, point in enumerate(request.points):
            if (
                point.x < 0
                or point.x >= pil_image.width
                or point.y < 0
                or point.y >= pil_image.height
            ):
                raise HTTPException(
                    status_code=400, detail=f"Point {i} coordinates out of image bounds"
                )

        # convert points to the (N, 3) numpy array format expected by
        # ImageSegmentation, where the third column is the include/exclude flag
        points = np.array(
            [[point.x, point.y, point.include] for point in request.points]
        )

        # segment the image
        masks, annotated_image = segmentation_model.segment_image_from_points(
            pil_image, points
        )

        # encode the results
        enc_masks = [
            MaskResult(
                mask=encode_mask_to_base64(mask),
                score=score,
                bbox=BoundingBox(
                    upper_left=(bbox[0], bbox[1]), lower_right=(bbox[2], bbox[3])
                ),
            )
            for (mask, score, bbox) in masks
        ]
        enc_annotated_image = encode_image_to_base64(annotated_image)

        return MaskResponse(
            masks=enc_masks,
            image=enc_annotated_image,
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=13337)
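
# Alternatively, the server can be started with the uvicorn CLI (assuming this
# file is saved as main.py; adjust the module path to match):
#
#   uvicorn main:app --host 0.0.0.0 --port 13337
#
# A text-endpoint call follows the same pattern as the bbox sketch above
# (image_b64 prepared the same way):
#
#   resp = requests.post(
#       "http://localhost:13337/gsam2/image/maskfromtext",
#       json={"image": image_b64, "text": "a dog"},
#   )
#   for m in resp.json()["masks"]:
#       print(m["class_name"], m["score"], m["center_of_mass"])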