diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..b67c908
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+example/out
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..fa1123e
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,57 @@
+FROM docker.io/pytorch/pytorch:2.3.1-cuda12.1-cudnn8-devel
+
+# arguments to build Docker Image using CUDA
+ARG USE_CUDA=0
+ARG TORCH_ARCH="7.0;7.5;8.0;8.6"
+
+ENV AM_I_DOCKER=True
+ENV BUILD_WITH_CUDA="${USE_CUDA}"
+ENV TORCH_CUDA_ARCH_LIST="${TORCH_ARCH}"
+ENV CUDA_HOME=/usr/local/cuda-12.1/
+# ensure CUDA is correctly set up
+ENV PATH=/usr/local/cuda-12.1/bin:${PATH}
+ENV LD_LIBRARY_PATH=/usr/local/cuda-12.1/lib64:${LD_LIBRARY_PATH}
+
+# install required packages and specific gcc/g++
+RUN apt-get update && apt-get install --no-install-recommends wget ffmpeg=7:* \
+    libsm6=2:* libxext6=2:* git=1:* nano vim=2:* ninja-build gcc-10 g++-10 -y \
+    && apt-get clean && apt-get autoremove && rm -rf /var/lib/apt/lists/*
+
+ENV CC=gcc-10
+ENV CXX=g++-10
+
+# clone grounded sam2 repo
+WORKDIR /home/appuser
+RUN git clone https://github.com/IDEA-Research/Grounded-SAM-2
+
+# download sam2 checkpoints
+WORKDIR /home/appuser/Grounded-SAM-2/checkpoints
+RUN bash download_ckpts.sh
+
+# download grounding dino checkpoints
+WORKDIR /home/appuser/Grounded-SAM-2/gdino_checkpoints
+RUN bash download_ckpts.sh
+
+WORKDIR /home/appuser/Grounded-SAM-2
+
+# install essential Python packages
+RUN python -m pip install --upgrade pip "setuptools>=62.3.0,<75.9" wheel numpy \
+    opencv-python transformers supervision pycocotools addict yapf timm
+
+# install segment_anything package in editable mode
+RUN python -m pip install -e .
+
+# install grounding dino
+RUN python -m pip install --no-build-isolation -e grounding_dino
+
+# install the server dependencies
+COPY requirements.txt requirements.txt
+RUN python -m pip install -r requirements.txt
+
+COPY app.py app.py
+COPY imagesegmentation.py imagesegmentation.py
+
+# RUN mkdir ../host
+
+# start the server
+ENTRYPOINT ["python", "app.py", "--log-level", "debug"]
\ No newline at end of file
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..c444f2e
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,43 @@
+# Get version of CUDA and enable it for compilation if CUDA > 11.0
+# This solves https://github.com/IDEA-Research/Grounded-Segment-Anything/issues/53
+# and https://github.com/IDEA-Research/Grounded-Segment-Anything/issues/84
+# when running in Docker
+# Check if nvcc is installed
+NVCC := $(shell which nvcc)
+ifeq ($(NVCC),)
+    # NVCC not found
+    USE_CUDA := 0
+    NVCC_VERSION := "not installed"
+else
+    NVCC_VERSION := $(shell nvcc --version | grep -oP 'release \K[0-9.]+')
+    USE_CUDA := $(shell echo "$(NVCC_VERSION) > 11" | bc -l)
+endif
+
+# Add the list of supported ARCHs
+ifeq ($(USE_CUDA), 1)
+    TORCH_CUDA_ARCH_LIST := "7.0;7.5;8.0;8.6+PTX"
+    BUILD_MESSAGE := "Trying to build the image with CUDA support"
+else
+    TORCH_CUDA_ARCH_LIST :=
+    BUILD_MESSAGE := "CUDA $(NVCC_VERSION) is not supported"
+endif
+
+build:
+	docker build --build-arg USE_CUDA=$(USE_CUDA) \
+	--build-arg TORCH_ARCH=$(TORCH_CUDA_ARCH_LIST) \
+	--progress=plain -t gsam2 .
+
+run:
+	docker run -d --gpus all \
+	--restart unless-stopped \
+	--name=gsam2 \
+	--ipc=host -p 13337:13337 gsam2
+
+run-bash:
+	docker run -it --rm --gpus all \
+	-v "${PWD}":/home/appuser/host \
+	--entrypoint bash \
+	--name=gsam2 \
+	--network=host \
+	--ipc=host gsam2
+
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..0197e51
--- /dev/null
+++ b/README.md
@@ -0,0 +1,14 @@
+# GSAM Service
+
+Simple server providing [Grounded SAM2](https://github.com/IDEA-Research/Grounded-SAM-2) through a REST API
+
+## Usage
+
+Build and run the container:
+
+```
+make build
+make run
+```
+
+You can then connect to the server on port 13337. Have a look at `example/main.go` for examples of the provided endpoints.
diff --git a/app.py b/app.py
new file mode 100644
index 0000000..59aed76
--- /dev/null
+++ b/app.py
@@ -0,0 +1,303 @@
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+from typing import List, Tuple, Union, Optional
+import base64
+import io
+from PIL import Image
+import numpy as np
+import cv2
+from imagesegmentation import ImageSegmentation
+
+app = FastAPI(
+    title="GSAM2 API",
+    description="Grounded SAM 2 Image Segmentation API",
+    version="1.0.0",
+)
+
+segmentation_model = ImageSegmentation()
+
+# pydantic models for request validation
+class Point(BaseModel):
+    x: int
+    y: int
+    include: bool  # True for include, False for exclude
+
+
+class BoundingBox(BaseModel):
+    upper_left: Tuple[int, int]  # (x, y) coordinates
+    lower_right: Tuple[int, int]  # (x, y) coordinates
+
+
+class MaskFromTextRequest(BaseModel):
+    image: str  # base64 encoded image
+    text: str
+
+
+class MaskFromBBoxRequest(BaseModel):
+    image: str  # base64 encoded image
+    bboxes: List[BoundingBox]
+
+
+class MaskFromPointsRequest(BaseModel):
+    image: str  # base64 encoded image
+    points: List[Point]
+
+
+class MaskResult(BaseModel):
+    mask: str
+    score: float
+    bbox: BoundingBox  # bounding box generated from the mask
+
+    # fields, only populated in responses for MaskFromTextRequests
+    class_name: str = ""
+    dino_bbox: BoundingBox = BoundingBox(upper_left=(0, 0), lower_right=(0, 0))
+    center_of_mass: Tuple[float, float] = (0.0, 0.0)
+
+
+class MaskResponse(BaseModel):
+    masks: List[MaskResult]  # list of base64 encoded mask images and respective scores
+    image: str  # base64 encoded result image
+
+
+def decode_base64_image(base64_string: str) -> Image.Image:
+    """Helper function to decode base64 image string to PIL Image"""
+    try:
+        # remove data URL prefix if present
+        if base64_string.startswith("data:image"):
+            base64_string = base64_string.split(",")[1]
+
+        image_data = base64.b64decode(base64_string)
+        image = Image.open(io.BytesIO(image_data))
+        return image
+    except Exception as e:
+        raise HTTPException(status_code=400, detail=f"Invalid image data: {str(e)}")
+
+
+def encode_mask_to_base64(mask: np.ndarray) -> str:
+    """Helper function to encode mask array to base64 string"""
+    try:
+        # convert mask to PIL Image (assuming binary mask)
+        mask_image = Image.fromarray((mask * 255).astype(np.uint8), mode="L")
+
+        # convert to base64
+        buffer = io.BytesIO()
+        mask_image.save(buffer, format="JPEG")
+        mask_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
+        return mask_base64
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Failed to encode mask: {str(e)}")
+
+
+def encode_image_to_base64(image: np.ndarray) -> str:
+    """Helper function to encode cv2 image array to base64 string"""
+    try:
+        pil_image = Image.fromarray(image.astype(np.uint8))
+
+        # convert to base64
+        buffer = io.BytesIO()
+        pil_image.save(buffer, format="JPEG")
+        image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
+        return image_base64
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Failed to encode image: {str(e)}")
+
+
+@app.get("/")
+async def root():
+    return {"message": "GSAM2 API Server", "version": "1.0.0"}
+
+
+@app.post("/gsam2/image/maskfromtext", response_model=MaskResponse)
+async def mask_from_text(request: MaskFromTextRequest):
+    """
+    Generate segmentation masks from an image and text description.
+
+    Args:
+        request: Contains base64 encoded image and text description
+
+    Returns:
+        MaskResponse with list of base64 encoded masks, their scores and result image
+    """
+    try:
+        # decode the input image
+        pil_image = decode_base64_image(request.image)
+        text = request.text
+
+        # segment the image
+        masks, annotated_image = segmentation_model.segment_image_from_text(
+            pil_image, text
+        )
+
+        # encode the results
+        enc_masks = [
+            MaskResult(
+                mask=encode_mask_to_base64(mask),
+                score=score,
+                bbox=BoundingBox(
+                    upper_left=(bbox[0], bbox[1]), lower_right=(bbox[2], bbox[3])
+                ),
+                class_name=class_name,
+                dino_bbox=BoundingBox(
+                    upper_left=(round(dino_bbox[0]), round(dino_bbox[1])),
+                    lower_right=(round(dino_bbox[2]), round(dino_bbox[3])),
+                ),
+                center_of_mass=(com[0], com[1]),
+            )
+            for (mask, score, bbox, dino_bbox, class_name, com) in masks
+        ]
+        enc_annotated_image = encode_image_to_base64(annotated_image)
+
+        return MaskResponse(
+            masks=enc_masks,
+            image=enc_annotated_image,
+        )
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
+
+
+@app.post("/gsam2/image/maskfrombboxes", response_model=MaskResponse)
+async def mask_from_bbox(request: MaskFromBBoxRequest):
+    """
+    Generate segmentation masks from an image and bounding boxes.
+
+    Args:
+        request: Contains base64 encoded image and bounding box coordinates
+
+    Returns:
+        MaskResponse with list of base64 encoded masks, their scores and result image
+    """
+    try:
+        pil_image = decode_base64_image(request.image)
+
+        # validate bounding box coordinates
+        bboxes = None
+        for bbox in request.bboxes:
+            x1, y1 = bbox.upper_left
+            x2, y2 = bbox.lower_right
+
+            if x1 >= x2 or y1 >= y2:
+                raise HTTPException(
+                    status_code=400,
+                    detail="Invalid bounding box: upper_left must be above and left of lower_right",
+                )
+
+            if x1 < 0 or y1 < 0 or x2 > pil_image.width or y2 > pil_image.height:
+                raise HTTPException(
+                    status_code=400,
+                    detail="Bounding box coordinates out of image bounds",
+                )
+
+            # convert to numpy array format expected by ImageSegmentation
+            if bboxes is None:
+                bboxes = np.array([[x1, y1, x2, y2]])
+            else:
+                bboxes = np.vstack((bboxes, [[x1, y1, x2, y2]]))
+
+        if bboxes is None:
+            raise HTTPException(
+                status_code=400, detail="At least one bounding box is required"
+            )
+
+        # segment the image
+        (masks, annotated_image) = segmentation_model.segment_image_from_bbox(
+            pil_image, bboxes
+        )
+
+        # encode the results
+        enc_masks = [
+            MaskResult(
+                mask=encode_mask_to_base64(mask),
+                score=score,
+                bbox=BoundingBox(
+                    upper_left=(bbox[0], bbox[1]), lower_right=(bbox[2], bbox[3])
+                ),
+            )
+            for (mask, score, bbox) in masks
+        ]
+        enc_annotated_image = encode_image_to_base64(annotated_image)
+
+        return MaskResponse(
+            masks=enc_masks,
+            image=enc_annotated_image,
+        )
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
+
+
+@app.post("/gsam2/image/maskfrompoints", response_model=MaskResponse)
+async def mask_from_points(request: MaskFromPointsRequest):
+    """
+    Generate segmentation masks from an image and list of points with include/exclude indicators.
+
+    Args:
+        request: Contains base64 encoded image and list of points with include/exclude flags
+
+    Returns:
+        MaskResponse with list of base64 encoded masks, their scores and result image
+    """
+    try:
+        pil_image = decode_base64_image(request.image)
+
+        # validate point coordinates
+        for i, point in enumerate(request.points):
+            if (
+                point.x < 0
+                or point.x >= pil_image.width
+                or point.y < 0
+                or point.y >= pil_image.height
+            ):
+                raise HTTPException(
+                    status_code=400, detail=f"Point {i} coordinates out of image bounds"
+                )
+
+        # convert points to numpy array format expected by ImageSegmentation
+        points = None
+        if request.points is not None and len(request.points) > 0:
+            points = np.array(
+                [[point.x, point.y, point.include] for point in request.points]
+            )
+
+        if points is None:
+            raise HTTPException(
+                status_code=400, detail="At least one point is required"
+            )
+
+        # segment the image
+        (masks, annotated_image) = segmentation_model.segment_image_from_points(
+            pil_image, points
+        )
+
+        # encode the results
+        enc_masks = [
+            MaskResult(
+                mask=encode_mask_to_base64(mask),
+                score=score,
+                bbox=BoundingBox(
+                    upper_left=(bbox[0], bbox[1]), lower_right=(bbox[2], bbox[3])
+                ),
+            )
+            for (mask, score, bbox) in masks
+        ]
+        enc_annotated_image = encode_image_to_base64(annotated_image)
+
+        return MaskResponse(
+            masks=enc_masks,
+            image=enc_annotated_image,
+        )
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=13337)
diff --git a/example/main.go b/example/main.go
new file mode 100644
index 0000000..537564b
--- /dev/null
+++ b/example/main.go
@@ -0,0 +1,245 @@
+package main
+
+import (
+	"bytes"
+	"encoding/base64"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"net/http/httputil"
+	"os"
+	"time"
+)
+
+var (
+	url        = "http://localhost:13337"
+	fromText   = "/gsam2/image/maskfromtext"
+	fromBboxes = "/gsam2/image/maskfrombboxes"
+	fromPoints = "/gsam2/image/maskfrompoints"
+
+	image = "truck.jpg"
+)
+
+type (
+	Point struct {
+		X       int  `json:"x"`
+		Y       int  `json:"y"`
+		Include bool `json:"include"`
+	}
+
+	BoundingBox struct {
+		UpperLeft  [2]int `json:"upper_left"`
+		LowerRight [2]int `json:"lower_right"`
+	}
+
+	MaskFromTextRequest struct {
+		Image string `json:"image"`
+		Text  string `json:"text"`
+	}
+
+	MaskFromBBoxRequest struct {
+		Image  string        `json:"image"`
+		Bboxes []BoundingBox `json:"bboxes"`
+	}
+
+	MaskFromPointsRequest struct {
+		Image  string  `json:"image"`
+		Points []Point `json:"points"`
+	}
+
+	MaskResult struct {
+		Mask  string      `json:"mask"`
+		Score float64     `json:"score"`
+		BBox  BoundingBox `json:"bbox"`
+		// fields, only populated in responses for MaskFromTextRequests
+		ClassName    string      `json:"class_name"`
+		DinoBBox     BoundingBox `json:"dino_bbox"`
+		CenterOfMass [2]float64  `json:"center_of_mass"`
+	}
+
+	MaskResponse struct {
+		Masks []MaskResult `json:"masks"`
+		Image string       `json:"image"`
+	}
+)
+
+func main() {
+	// ensure the out directory exists
+	os.Mkdir("out", 0755)
+
+	// load the sample image and base64 encode it
+	dat, err := os.ReadFile(image)
+	if err != nil {
+		fmt.Println(err)
+		return
+	}
+
+	// post to the different endpoints
+	c := &http.Client{Timeout: time.Minute}
+	encImage := base64.StdEncoding.EncodeToString(dat)
+
+	// from text
+	err = doFromText(c, "fromtext", encImage, "truck. tire.")
+	if err != nil {
+		fmt.Printf("error %s", err)
+		return
+	}
+
+	// from bboxes
+	err = doFromBboxes(c, "frombboxes", encImage, []BoundingBox{
+		{
+			UpperLeft:  [2]int{75, 275},
+			LowerRight: [2]int{1725, 850},
+		},
+		{
+			UpperLeft:  [2]int{425, 600},
+			LowerRight: [2]int{700, 875},
+		},
+		{
+			UpperLeft:  [2]int{1375, 550},
+			LowerRight: [2]int{1650, 800},
+		},
+		{
+			UpperLeft:  [2]int{1240, 675},
+			LowerRight: [2]int{1400, 750},
+		},
+	})
+	if err != nil {
+		fmt.Printf("error %s", err)
+		return
+	}
+
+	// from points
+	err = doFromPoints(c, "frompoints", encImage, []Point{
+		{X: 500, Y: 375, Include: true},
+		{X: 1125, Y: 625, Include: true},
+		{X: 575, Y: 750, Include: false},
+	})
+	if err != nil {
+		fmt.Printf("error %s", err)
+		return
+	}
+}
+
+func do(c *http.Client, req *http.Request, outname string) error {
+
+	dump, err := httputil.DumpRequest(req, false)
+	if err != nil {
+		return err
+	}
+	fmt.Println("request: ", string(dump))
+
+	resp, err := c.Do(req)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode >= 300 {
+		dump, err := httputil.DumpResponse(resp, true)
+		if err != nil {
+			return err
+		}
+		fmt.Println("response: ", string(dump))
+		defer resp.Body.Close()
+	} else {
+		dump, err := httputil.DumpResponse(resp, false)
+		if err != nil {
+			return err
+		}
+		fmt.Println("response: ", string(dump))
+	}
+
+	bodyBytes, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return err
+	}
+
+	maskResp := MaskResponse{}
+	err = json.Unmarshal(bodyBytes, &maskResp)
+	if err != nil {
+		return err
+	}
+
+	// write the masks to a file
+	for _, mask := range maskResp.Masks {
+		dec, err := base64.StdEncoding.DecodeString(mask.Mask)
+		if err != nil {
+			return err
+		}
+		class := ""
+		if mask.ClassName != "" {
+			class = "-" + mask.ClassName
+		}
+		os.WriteFile(fmt.Sprintf("out/%s%s-%.4f.jpg", outname, class, mask.Score), dec, 0644)
+	}
+
+	dec, err := base64.StdEncoding.DecodeString(maskResp.Image)
+	if err != nil {
+		return err
+	}
+	os.WriteFile(fmt.Sprintf("out/%s.jpg", outname), dec, 0644)
+
+	return nil
+}
+
+func doFromText(c *http.Client, outname string, encImage string, text string) error {
+	req, err := http.NewRequest("POST", url+fromText, nil)
+	if err != nil {
+		return err
+	}
+	req.Header.Add("Accept", `application/json`)
+
+	body := MaskFromTextRequest{
+		Image: encImage,
+		Text:  text,
+	}
+	jsonBody, err := json.Marshal(body)
+	if err != nil {
+		return err
+	}
+	req.Body = io.NopCloser(bytes.NewBuffer(jsonBody))
+
+	return do(c, req, outname)
+}
+
+func doFromBboxes(c *http.Client, outname string, encImage string, bboxes []BoundingBox) error {
+	req, err := http.NewRequest("POST", url+fromBboxes, nil)
+	if err != nil {
+		return err
+	}
+	req.Header.Add("Accept", `application/json`)
+
+	body := MaskFromBBoxRequest{
+		Image:  encImage,
+		Bboxes: bboxes,
+	}
+	jsonBody, err := json.Marshal(body)
+	if err != nil {
+		return err
+	}
+	req.Body = io.NopCloser(bytes.NewBuffer(jsonBody))
+
+	return do(c, req, outname)
+}
+
+func doFromPoints(c *http.Client, outname string, encImage string, points []Point) error {
+	req, err := http.NewRequest("POST", url+fromPoints, nil)
+	if err != nil {
+		return err
+	}
+	req.Header.Add("Accept", `application/json`)
+
+	body := MaskFromPointsRequest{
+		Image:  encImage,
+		Points: points,
+	}
+	jsonBody, err := json.Marshal(body)
+	if err != nil {
+		return err
+	}
+	req.Body = io.NopCloser(bytes.NewBuffer(jsonBody))
+
+	return do(c, req, outname)
+}
+
diff --git a/example/truck.jpg b/example/truck.jpg
new file mode 100644
index 0000000..6b98688
Binary files /dev/null and b/example/truck.jpg differ
diff --git a/imagesegmentation.py b/imagesegmentation.py
new file mode 100644
index 0000000..a1e4bb4
--- /dev/null
+++ b/imagesegmentation.py
@@ -0,0 +1,313 @@
+import numpy as np
+import supervision as sv
+import cv2
+import PIL
+from scipy import ndimage
+from typing import List, Tuple, Union, Optional
+
+import torch
+from torchvision.ops import box_convert
+
+from sam2.build_sam import build_sam2
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+from grounding_dino.groundingdino.util.inference import load_model, load_image, predict
+import grounding_dino.groundingdino.datasets.transforms as T
+
+
+class ImageSegmentation:
+    def __init__(self):
+        # select the device for computation
+        if torch.cuda.is_available():
+            self.device = torch.device("cuda")
+        else:
+            self.device = torch.device("cpu")
+        print(f"using device: {self.device}")
+
+        if self.device.type == "cuda":
+            # NOTE: somehow this didn't work locally inside a docker container
+            # use bfloat16 for the entire notebook
+            # original:
+            # torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
+            # it might work without autocast, or with this instead:
+            # torch.autocast("cuda", dtype=torch.float16).__enter__()
+            # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
+            if torch.cuda.get_device_properties(0).major >= 8:
+                torch.backends.cuda.matmul.allow_tf32 = True
+                torch.backends.cudnn.allow_tf32 = True
+
+        sam2_checkpoint = "checkpoints/sam2.1_hiera_large.pt"
+        model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
+
+        self.sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=self.device)
+        self.sam2_predictor = SAM2ImagePredictor(self.sam2_model)
+
+        grounding_dino_config = (
+            "grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
+        )
+        grounding_dino_checkpoint = "gdino_checkpoints/groundingdino_swint_ogc.pth"
+        self.box_threshold = 0.35
+        self.text_threshold = 0.25
+
+        self.grounding_model = load_model(
+            model_config_path=grounding_dino_config,
+            model_checkpoint_path=grounding_dino_checkpoint,
+            device=self.device,
+        )
+
+    def segment_image_from_text(self, pil_image: PIL.Image.Image, text: str):
+        """Generate segmentation masks from image and text description using Grounding DINO + SAM2.
+
+        Args:
+            pil_image: PIL image that should be segmented
+            text: object description(s) to be segmented
+
+        Returns:
+            Iterable of tuples (mask, score, bbox, dino_bbox, class_name, center_of_mass) per detected object
+            Annotated result image
+        """
+
+        # image preparation taken from load_image() in Grounded-SAM-2/grounding_dino/groundingdino/util/inference.py
+        pil_image = pil_image.convert("RGB")
+        transform = T.Compose(
+            [
+                T.RandomResize([800], max_size=1333),
+                T.ToTensor(),
+                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+            ]
+        )
+        image = np.asarray(pil_image)
+        image_transformed, _ = transform(pil_image, None)
+
+        # set the image for sam2
+        self.sam2_predictor.set_image(image)
+
+        # predict the bounding boxes
+        boxes, confidences, labels = predict(
+            model=self.grounding_model,
+            image=image_transformed,
+            caption=text,
+            box_threshold=self.box_threshold,
+            text_threshold=self.text_threshold,
+            device=self.device,
+        )
+
+        if boxes is None or len(boxes) < 1:
+            return [], image
+
+        # process the box prompt for SAM 2
+        h, w, _ = image.shape
+        boxes = boxes * torch.Tensor([w, h, w, h])
+        input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
+
+        # NOTE: somehow this didn't work locally inside a docker container
+        # torch.autocast(device_type=self.device, dtype=torch.bfloat16).__enter__()
+        # if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
+        #     # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
+        #     torch.backends.cuda.matmul.allow_tf32 = True
+        #     torch.backends.cudnn.allow_tf32 = True
+
+        masks, scores, logits = self.sam2_predictor.predict(
+            point_coords=None,
+            point_labels=None,
+            box=input_boxes,
+            multimask_output=False,
+        )
+
+        # convert the shape to (n, H, W)
+        if masks.ndim == 4:
+            masks = masks.squeeze(1)
+
+        confidences = confidences.numpy().tolist()
+        class_names = labels
+
+        class_ids = np.array(list(range(len(class_names))))
+
+        labels = [
+            f"{class_name} {confidence:.2f}"
+            for class_name, confidence in zip(class_names, confidences)
+        ]
+
+        return zip(
+            masks,
+            scores,
+            ImageSegmentation._bboxes_from_masks(masks),
+            input_boxes,
+            class_names,
+            ImageSegmentation._centers_of_mass_from_masks(masks),
+        ), ImageSegmentation._create_result_image(
+            image, masks, bboxes=input_boxes, labels=labels, class_ids=class_ids
+        )
+
+    def segment_image_from_bbox(self, pil_image: PIL.Image.Image, bboxes: np.array):
+        """Generate segmentation masks from image and bounding box coordinates using SAM2.
+
+        Args:
+            pil_image: PIL image that should be segmented
+            bboxes: Nx4 array of bounding boxes of objects to be segmented (x1, y1, x2, y2)
+
+        Returns:
+            Iterable of tuples (mask, score, bbox) with mask (HxW) and float score
+            Annotated result image
+        """
+        image = np.asarray(pil_image.convert("RGB"))
+        self.sam2_predictor.set_image(image)
+
+        masks, scores, logits = self.sam2_predictor.predict(
+            point_coords=None,
+            point_labels=None,
+            box=np.array(bboxes),
+            multimask_output=False,
+        )
+
+        # convert the shape to (n, H, W)
+        if masks.ndim == 4:
+            masks = masks.squeeze(1)
+
+        return zip(
+            masks, scores, ImageSegmentation._bboxes_from_masks(masks)
+        ), ImageSegmentation._create_result_image(image, masks, bboxes=bboxes)
+
+    def segment_image_from_points(self, pil_image: PIL.Image.Image, points: np.array):
+        """Generate segmentation masks from image and point coordinates with include/exclude labels using SAM2.
+
+        Args:
+            pil_image: PIL image that should be segmented
+            points: Nx3 array of points with include/exclude flags of the objects to be segmented (x, y, include)
+
+        Returns:
+            Iterable of tuples (mask, score, bbox) with mask (HxW) and float score
+            Annotated result image
+        """
+        image = np.asarray(pil_image.convert("RGB"))
+        self.sam2_predictor.set_image(image)
+
+        # convert points to coordinates and labels arrays
+        coords = np.array([[point[0], point[1]] for point in points])
+        labels = np.array([1 if point[2] else 0 for point in points])
+
+        masks, scores, logits = self.sam2_predictor.predict(
+            point_coords=coords,
+            point_labels=labels,
+            multimask_output=False,
+        )
+
+        # convert the shape to (n, H, W)
+        if masks.ndim == 4:
+            masks = masks.squeeze(1)
+
+        return zip(
+            masks, scores, ImageSegmentation._bboxes_from_masks(masks)
+        ), ImageSegmentation._create_result_image(image, masks, points=points)
+
+    def _create_result_image(
+        pil_image: PIL.Image.Image,
+        masks: np.array,
+        bboxes: np.array = [],
+        labels: np.array = [],
+        class_ids: np.array = None,
+        points: np.array = [],
+    ):
+        """Create annotated result image with masks, bounding boxes, labels, and points overlaid.
+
+        Args:
+            pil_image: image to annotate (PIL image or RGB numpy array)
+            masks: NxHxW array of object mask(s)
+            bboxes: (optional) Nx4 array of objects bounding box(es) (x1, y1, x2, y2)
+            labels: (optional) Nx1 array of object label(s)
+            class_ids: (optional) Nx1 array of class id(s) used for annotation colors
+            points: (optional) Nx3 array of object point(s) with include/exclude flags (x, y, include)
+
+        Returns:
+            Annotated result image
+        """
+        img = np.array(pil_image)
+
+        # we have to define the bboxes in the detections even though we might not show them
+        detection_bboxes = bboxes
+        if bboxes is None or len(bboxes) == 0:
+            detection_bboxes = ImageSegmentation._bboxes_from_masks(masks)
+
+        detections = sv.Detections(
+            xyxy=detection_bboxes,  # (n, 4)
+            mask=masks.astype(bool),  # (n, h, w)
+            class_id=class_ids,
+        )
+        annotated_frame = img.copy()
+
+        # if there are no class ids (i.e. when used without Grounding DINO) we need to derive the
+        # color lookup
+        colorlookup = sv.ColorLookup.INDEX
+        if class_ids is not None and len(class_ids) > 0:
+            colorlookup = sv.ColorLookup.CLASS
+
+        # points
+        if points is not None:
+            for x, y, include in points:
+                # Green for include (True), Red for exclude (False)
+                color = (0, 255, 0) if include else (255, 0, 0)  # RGB (the image comes from PIL)
+                cv2.circle(annotated_frame, (int(x), int(y)), 8, (0, 0, 0), -1)  # Outer ring
+                cv2.circle(annotated_frame, (int(x), int(y)), 5, color, -1)  # Filled circle
+
+        # bboxes
+        if len(bboxes) > 0:
+            box_annotator = sv.BoxAnnotator(color_lookup=colorlookup)
+            annotated_frame = box_annotator.annotate(
+                scene=annotated_frame, detections=detections
+            )
+
+        # labels
+        if labels is not None and len(labels) > 0:
+            label_annotator = sv.LabelAnnotator(color_lookup=colorlookup)
+            annotated_frame = label_annotator.annotate(
+                scene=annotated_frame, detections=detections, labels=labels
+            )
+
+        # mask
+        mask_annotator = sv.MaskAnnotator(color_lookup=colorlookup)
+        annotated_frame = mask_annotator.annotate(
+            scene=annotated_frame, detections=detections
+        )
+
+        return annotated_frame
+
+    def _bboxes_from_masks(masks: np.array):
+        """Create bounding boxes for the provided masks
+
+        Args:
+            masks: NxHxW array of object mask(s)
+
+        Returns:
+            bboxes: Nx4 array of mask bounding boxes (x1, y1, x2, y2)
+        """
+
+        bboxes = []
+        for mask in masks:
+            mask_bool = np.where(mask != 0)
+            if len(mask_bool[0]) != 0 and len(mask_bool[1]) != 0:
+                bboxes.append(
+                    [
+                        int(np.min(mask_bool[1])),
+                        int(np.min(mask_bool[0])),
+                        int(np.max(mask_bool[1])),
+                        int(np.max(mask_bool[0])),
+                    ]
+                )
+            else:
+                bboxes.append([0, 0, 0, 0])
+
+        return np.array(bboxes)
+
+    def _centers_of_mass_from_masks(masks: np.array):
+        """Calculate centers of mass for the provided masks
+
+        Args:
+            masks: NxHxW array of object mask(s)
+
+        Returns:
+            centers_of_mass: Nx2 array of mask centers of mass (x, y)
+        """
+
+        return np.array(
+            [[x, y] for mask in masks for y, x in [ndimage.center_of_mass(mask)]]
+        )
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..9dfbb13
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,7 @@
+fastapi==0.104.1
+uvicorn[standard]==0.24.0
+pydantic==2.5.0
+pillow==10.1.0
+numpy==1.24.3
+python-multipart==0.0.6
+opencv-python==4.8.1.78
\ No newline at end of file
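As a quick smoke test of the API described in the README and `app.py`, the sketch below posts the sample image to the `/gsam2/image/maskfromtext` endpoint using only the Python standard library. It is illustrative rather than part of the patch: the script name, the output file names, and the assumption that the server is already running on `localhost:13337` (e.g. via `make run`) are hypothetical; the request and response fields mirror the pydantic models defined in `app.py`.

```python
# smoke_test.py -- hypothetical helper, not included in the patch above
import base64
import json
import urllib.request

URL = "http://localhost:13337/gsam2/image/maskfromtext"

# read and base64-encode the sample image shipped with the Go example
with open("example/truck.jpg", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

payload = json.dumps({"image": encoded, "text": "truck. tire."}).encode("utf-8")
req = urllib.request.Request(
    URL,
    data=payload,
    headers={"Content-Type": "application/json"},
    method="POST",
)

with urllib.request.urlopen(req) as resp:
    result = json.load(resp)

# each mask entry carries a base64 encoded mask image, its score and derived bbox
for i, mask in enumerate(result["masks"]):
    print(mask["class_name"], mask["score"], mask["bbox"])
    with open(f"out_mask_{i}.jpg", "wb") as f:
        f.write(base64.b64decode(mask["mask"]))

# the annotated overview image returned by the server
with open("out_annotated.jpg", "wb") as f:
    f.write(base64.b64decode(result["image"]))
```

The Go client in `example/main.go` exercises the same endpoint plus the bbox and point variants; this snippet is only meant as a minimal sanity check that the container is reachable.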