* Added HRNet18 click based interactive segmentation (#3729) The commit adds serverless container that has a click based interactive segmentation model(HRNet18) that performs better than f-BRS method statistically. * fixup! Added HRNet18 click based interactive segmentation (#3729) Co-authored-by: Nikita Manovich <nikita.manovich@intel.com>main
parent
0473da064d
commit
2d75101c73
@ -0,0 +1,78 @@
|
||||
metadata:
|
||||
name: pth-saic-vul-hrnet
|
||||
namespace: cvat
|
||||
annotations:
|
||||
name: HRNET
|
||||
type: interactor
|
||||
spec:
|
||||
framework: pytorch
|
||||
min_pos_points: 1
|
||||
min_neg_points: 0
|
||||
help_message: The interactor allows to get a mask for an object using positive points, and negative points
|
||||
|
||||
spec:
|
||||
description: HRNet18 for click based interactive segmentation
|
||||
runtime: 'python:3.8'
|
||||
handler: main:handler
|
||||
eventTimeout: 30s
|
||||
env:
|
||||
- name: PYTHONPATH
|
||||
value: /opt/nuclio/hrnet
|
||||
|
||||
build:
|
||||
image: cvat/pth.saic-vul.hrnet
|
||||
baseImage: ubuntu:20.04
|
||||
|
||||
directives:
|
||||
preCopy:
|
||||
- kind: ENV
|
||||
value: DEBIAN_FRONTEND=noninteractive
|
||||
- kind: RUN
|
||||
value: apt-get update && apt-get install software-properties-common -y
|
||||
- kind: RUN
|
||||
value: add-apt-repository ppa:deadsnakes/ppa
|
||||
- kind: RUN
|
||||
value: apt-get update && apt-get install -y --no-install-recommends build-essential git curl libglib2.0-0 software-properties-common python3 python3.6-dev python3-pip python3-tk
|
||||
- kind: RUN
|
||||
value: pip3 install --upgrade pip
|
||||
- kind: WORKDIR
|
||||
value: /opt/nuclio
|
||||
- kind: RUN
|
||||
value: git clone https://github.com/saic-vul/ritm_interactive_segmentation.git hrnet
|
||||
- kind: WORKDIR
|
||||
value: /opt/nuclio/hrnet
|
||||
- kind: RUN
|
||||
value: apt-get install -y --no-install-recommends wget
|
||||
- kind: RUN
|
||||
value: wget https://github.com/saic-vul/ritm_interactive_segmentation/releases/download/v1.0/coco_lvis_h18_itermask.pth
|
||||
- kind: RUN
|
||||
value: pip3 install setuptools
|
||||
- kind: RUN
|
||||
value: pip3 install -r requirements.txt
|
||||
- kind: RUN
|
||||
value: apt update && apt install -y libgl1-mesa-glx
|
||||
- kind: RUN
|
||||
value: pip3 uninstall torch torch vision -y
|
||||
- kind: RUN
|
||||
value: pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
|
||||
- kind: WORKDIR
|
||||
value: /opt/nuclio
|
||||
|
||||
triggers:
|
||||
myHttpTrigger:
|
||||
maxWorkers: 1
|
||||
kind: 'http'
|
||||
workerAvailabilityTimeoutMilliseconds: 10000
|
||||
attributes:
|
||||
maxRequestBodySize: 33554432 # 32MB
|
||||
|
||||
resources:
|
||||
limits:
|
||||
nvidia.com/gpu: 1
|
||||
|
||||
platform:
|
||||
attributes:
|
||||
restartPolicy:
|
||||
name: always
|
||||
maximumRetryCount: 3
|
||||
mountMode: volume
|
||||
@ -0,0 +1,33 @@
|
||||
# Copyright (C) 2021 Intel Corporation
|
||||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import json
|
||||
import base64
|
||||
from PIL import Image
|
||||
import io
|
||||
from model_handler import ModelHandler
|
||||
|
||||
def init_context(context):
|
||||
context.logger.info("Init context... 0%")
|
||||
|
||||
model = ModelHandler()
|
||||
context.user_data.model = model
|
||||
|
||||
context.logger.info("Init context...100%")
|
||||
|
||||
def handler(context, event):
|
||||
context.logger.info("call handler")
|
||||
data = event.body
|
||||
pos_points = data["pos_points"]
|
||||
neg_points = data["neg_points"]
|
||||
threshold = data.get("threshold", 0.5)
|
||||
buf = io.BytesIO(base64.b64decode(data["image"]))
|
||||
image = Image.open(buf)
|
||||
|
||||
polygon = context.user_data.model.handle(image, pos_points,
|
||||
neg_points, threshold)
|
||||
return context.Response(body=json.dumps(polygon),
|
||||
headers={},
|
||||
content_type='application/json',
|
||||
status_code=200)
|
||||
@ -0,0 +1,68 @@
|
||||
# Copyright (C) 2021 Intel Corporation
|
||||
#
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import cv2
|
||||
import os
|
||||
|
||||
from isegm.inference import utils
|
||||
from isegm.inference.predictors import get_predictor
|
||||
from isegm.inference.clicker import Clicker, Click
|
||||
|
||||
def convert_mask_to_polygon(mask):
|
||||
mask = np.array(mask, dtype=np.uint8)
|
||||
cv2.normalize(mask, mask, 0, 255, cv2.NORM_MINMAX)
|
||||
contours = None
|
||||
if int(cv2.__version__.split('.')[0]) > 3:
|
||||
contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_KCOS)[0]
|
||||
else:
|
||||
contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_KCOS)[1]
|
||||
|
||||
contours = max(contours, key=lambda arr: arr.size)
|
||||
if contours.shape.count(1):
|
||||
contours = np.squeeze(contours)
|
||||
if contours.size < 3 * 2:
|
||||
raise Exception('Less then three point have been detected. Can not build a polygon.')
|
||||
|
||||
polygon = []
|
||||
for point in contours:
|
||||
polygon.append([int(point[0]), int(point[1])])
|
||||
|
||||
return polygon
|
||||
|
||||
class ModelHandler:
|
||||
def __init__(self):
|
||||
torch.backends.cudnn.deterministic = True
|
||||
base_dir = os.path.abspath(os.environ.get("MODEL_PATH", "/opt/nuclio/hrnet"))
|
||||
model_path = os.path.join(base_dir)
|
||||
|
||||
self.net = None
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
checkpoint_path = utils.find_checkpoint(model_path, "coco_lvis_h18_itermask.pth")
|
||||
self.net = utils.load_is_model(checkpoint_path, self.device)
|
||||
|
||||
def handle(self, image, pos_points, neg_points, threshold):
|
||||
image_nd = np.array(image)
|
||||
|
||||
clicker = Clicker()
|
||||
for x, y in pos_points:
|
||||
click = Click(is_positive=True, coords=(y, x))
|
||||
clicker.add_click(click)
|
||||
|
||||
for x, y in neg_points:
|
||||
click = Click(is_positive=False, coords=(y, x))
|
||||
clicker.add_click(click)
|
||||
|
||||
predictor = get_predictor(self.net, 'NoBRS', device=self.device, prob_thresh=0.49)
|
||||
predictor.set_input_image(image_nd)
|
||||
|
||||
object_prob = predictor.get_prediction(clicker)
|
||||
if self.device == 'cuda':
|
||||
torch.cuda.empty_cache()
|
||||
object_mask = object_prob > threshold
|
||||
polygon = convert_mask_to_polygon(object_mask)
|
||||
|
||||
return polygon
|
||||
Loading…
Reference in New Issue