* Added HRNet18 click based interactive segmentation (#3729) The commit adds serverless container that has a click based interactive segmentation model(HRNet18) that performs better than f-BRS method statistically. * fixup! Added HRNet18 click based interactive segmentation (#3729) Co-authored-by: Nikita Manovich <nikita.manovich@intel.com>main
parent
0473da064d
commit
2d75101c73
@ -0,0 +1,78 @@
|
|||||||
|
metadata:
|
||||||
|
name: pth-saic-vul-hrnet
|
||||||
|
namespace: cvat
|
||||||
|
annotations:
|
||||||
|
name: HRNET
|
||||||
|
type: interactor
|
||||||
|
spec:
|
||||||
|
framework: pytorch
|
||||||
|
min_pos_points: 1
|
||||||
|
min_neg_points: 0
|
||||||
|
help_message: The interactor allows to get a mask for an object using positive points, and negative points
|
||||||
|
|
||||||
|
spec:
|
||||||
|
description: HRNet18 for click based interactive segmentation
|
||||||
|
runtime: 'python:3.8'
|
||||||
|
handler: main:handler
|
||||||
|
eventTimeout: 30s
|
||||||
|
env:
|
||||||
|
- name: PYTHONPATH
|
||||||
|
value: /opt/nuclio/hrnet
|
||||||
|
|
||||||
|
build:
|
||||||
|
image: cvat/pth.saic-vul.hrnet
|
||||||
|
baseImage: ubuntu:20.04
|
||||||
|
|
||||||
|
directives:
|
||||||
|
preCopy:
|
||||||
|
- kind: ENV
|
||||||
|
value: DEBIAN_FRONTEND=noninteractive
|
||||||
|
- kind: RUN
|
||||||
|
value: apt-get update && apt-get install software-properties-common -y
|
||||||
|
- kind: RUN
|
||||||
|
value: add-apt-repository ppa:deadsnakes/ppa
|
||||||
|
- kind: RUN
|
||||||
|
value: apt-get update && apt-get install -y --no-install-recommends build-essential git curl libglib2.0-0 software-properties-common python3 python3.6-dev python3-pip python3-tk
|
||||||
|
- kind: RUN
|
||||||
|
value: pip3 install --upgrade pip
|
||||||
|
- kind: WORKDIR
|
||||||
|
value: /opt/nuclio
|
||||||
|
- kind: RUN
|
||||||
|
value: git clone https://github.com/saic-vul/ritm_interactive_segmentation.git hrnet
|
||||||
|
- kind: WORKDIR
|
||||||
|
value: /opt/nuclio/hrnet
|
||||||
|
- kind: RUN
|
||||||
|
value: apt-get install -y --no-install-recommends wget
|
||||||
|
- kind: RUN
|
||||||
|
value: wget https://github.com/saic-vul/ritm_interactive_segmentation/releases/download/v1.0/coco_lvis_h18_itermask.pth
|
||||||
|
- kind: RUN
|
||||||
|
value: pip3 install setuptools
|
||||||
|
- kind: RUN
|
||||||
|
value: pip3 install -r requirements.txt
|
||||||
|
- kind: RUN
|
||||||
|
value: apt update && apt install -y libgl1-mesa-glx
|
||||||
|
- kind: RUN
|
||||||
|
value: pip3 uninstall torch torch vision -y
|
||||||
|
- kind: RUN
|
||||||
|
value: pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
|
||||||
|
- kind: WORKDIR
|
||||||
|
value: /opt/nuclio
|
||||||
|
|
||||||
|
triggers:
|
||||||
|
myHttpTrigger:
|
||||||
|
maxWorkers: 1
|
||||||
|
kind: 'http'
|
||||||
|
workerAvailabilityTimeoutMilliseconds: 10000
|
||||||
|
attributes:
|
||||||
|
maxRequestBodySize: 33554432 # 32MB
|
||||||
|
|
||||||
|
resources:
|
||||||
|
limits:
|
||||||
|
nvidia.com/gpu: 1
|
||||||
|
|
||||||
|
platform:
|
||||||
|
attributes:
|
||||||
|
restartPolicy:
|
||||||
|
name: always
|
||||||
|
maximumRetryCount: 3
|
||||||
|
mountMode: volume
|
||||||
@ -0,0 +1,33 @@
|
|||||||
|
# Copyright (C) 2021 Intel Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import json
|
||||||
|
import base64
|
||||||
|
from PIL import Image
|
||||||
|
import io
|
||||||
|
from model_handler import ModelHandler
|
||||||
|
|
||||||
|
def init_context(context):
|
||||||
|
context.logger.info("Init context... 0%")
|
||||||
|
|
||||||
|
model = ModelHandler()
|
||||||
|
context.user_data.model = model
|
||||||
|
|
||||||
|
context.logger.info("Init context...100%")
|
||||||
|
|
||||||
|
def handler(context, event):
|
||||||
|
context.logger.info("call handler")
|
||||||
|
data = event.body
|
||||||
|
pos_points = data["pos_points"]
|
||||||
|
neg_points = data["neg_points"]
|
||||||
|
threshold = data.get("threshold", 0.5)
|
||||||
|
buf = io.BytesIO(base64.b64decode(data["image"]))
|
||||||
|
image = Image.open(buf)
|
||||||
|
|
||||||
|
polygon = context.user_data.model.handle(image, pos_points,
|
||||||
|
neg_points, threshold)
|
||||||
|
return context.Response(body=json.dumps(polygon),
|
||||||
|
headers={},
|
||||||
|
content_type='application/json',
|
||||||
|
status_code=200)
|
||||||
@ -0,0 +1,68 @@
|
|||||||
|
# Copyright (C) 2021 Intel Corporation
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import os
|
||||||
|
|
||||||
|
from isegm.inference import utils
|
||||||
|
from isegm.inference.predictors import get_predictor
|
||||||
|
from isegm.inference.clicker import Clicker, Click
|
||||||
|
|
||||||
|
def convert_mask_to_polygon(mask):
|
||||||
|
mask = np.array(mask, dtype=np.uint8)
|
||||||
|
cv2.normalize(mask, mask, 0, 255, cv2.NORM_MINMAX)
|
||||||
|
contours = None
|
||||||
|
if int(cv2.__version__.split('.')[0]) > 3:
|
||||||
|
contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_KCOS)[0]
|
||||||
|
else:
|
||||||
|
contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_KCOS)[1]
|
||||||
|
|
||||||
|
contours = max(contours, key=lambda arr: arr.size)
|
||||||
|
if contours.shape.count(1):
|
||||||
|
contours = np.squeeze(contours)
|
||||||
|
if contours.size < 3 * 2:
|
||||||
|
raise Exception('Less then three point have been detected. Can not build a polygon.')
|
||||||
|
|
||||||
|
polygon = []
|
||||||
|
for point in contours:
|
||||||
|
polygon.append([int(point[0]), int(point[1])])
|
||||||
|
|
||||||
|
return polygon
|
||||||
|
|
||||||
|
class ModelHandler:
|
||||||
|
def __init__(self):
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
base_dir = os.path.abspath(os.environ.get("MODEL_PATH", "/opt/nuclio/hrnet"))
|
||||||
|
model_path = os.path.join(base_dir)
|
||||||
|
|
||||||
|
self.net = None
|
||||||
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
|
||||||
|
checkpoint_path = utils.find_checkpoint(model_path, "coco_lvis_h18_itermask.pth")
|
||||||
|
self.net = utils.load_is_model(checkpoint_path, self.device)
|
||||||
|
|
||||||
|
def handle(self, image, pos_points, neg_points, threshold):
|
||||||
|
image_nd = np.array(image)
|
||||||
|
|
||||||
|
clicker = Clicker()
|
||||||
|
for x, y in pos_points:
|
||||||
|
click = Click(is_positive=True, coords=(y, x))
|
||||||
|
clicker.add_click(click)
|
||||||
|
|
||||||
|
for x, y in neg_points:
|
||||||
|
click = Click(is_positive=False, coords=(y, x))
|
||||||
|
clicker.add_click(click)
|
||||||
|
|
||||||
|
predictor = get_predictor(self.net, 'NoBRS', device=self.device, prob_thresh=0.49)
|
||||||
|
predictor.set_input_image(image_nd)
|
||||||
|
|
||||||
|
object_prob = predictor.get_prediction(clicker)
|
||||||
|
if self.device == 'cuda':
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
object_mask = object_prob > threshold
|
||||||
|
polygon = convert_mask_to_polygon(object_mask)
|
||||||
|
|
||||||
|
return polygon
|
||||||
Loading…
Reference in New Issue