文章目录
- 一. Python 调用 SAM 生成掩码教程
- 二. Python 调用 SAM 分割后转换为 labelme 数据集
一. Python 调用 SAM 生成掩码教程
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import sys
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
def show_anns(anns):
    """Paint every SAM mask in ``anns`` onto the current matplotlib axes.

    ``anns`` is the list returned by ``SamAutomaticMaskGenerator.generate``;
    only each entry's ``'area'`` and boolean ``'segmentation'`` are read.
    Masks are drawn largest-first, so smaller masks end up on top. Each mask
    gets a random RGB colour at alpha 0.35. An empty ``anns`` draws nothing.
    """
    if not len(anns):
        return
    by_area = sorted(anns, key=lambda a: a['area'], reverse=True)
    axes = plt.gca()
    axes.set_autoscale_on(False)

    height, width = by_area[0]['segmentation'].shape[:2]
    overlay = np.ones((height, width, 4))
    # Fully transparent canvas; only masked pixels receive a non-zero alpha.
    overlay[:, :, 3] = 0
    for ann in by_area:
        rgba = np.concatenate([np.random.random(3), [0.35]])
        overlay[ann['segmentation']] = rgba
    axes.imshow(overlay)
# Part 1 script: run SAM's automatic mask generator on one image and display the result.
image_path = '4.PNG'
image = cv2.imread(image_path)
if image is None:
    # cv2.imread does not raise on a missing/unreadable file; it returns None,
    # which would otherwise surface later as an opaque cvtColor error.
    raise FileNotFoundError(f"could not read image: {image_path}")
# OpenCV loads images as BGR; convert to RGB for matplotlib and SAM.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(20, 20))
plt.imshow(image)
plt.axis('off')
plt.show()

# NOTE: this runs after segment_anything was imported at the top of the script,
# so it cannot influence that import; kept for parity with the upstream notebook.
sys.path.append("..")
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
# Use the GPU when one is available instead of failing on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)
print(len(masks))
if masks:
    # masks[0] would raise IndexError when SAM finds no masks.
    print(masks[0].keys())

plt.figure(figsize=(20, 20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()
二. Python 调用 SAM 分割后转换为 labelme 数据集
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import json
import sys
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
def segment(imgPath, sam_checkpoint="sam_vit_h_4b8939.pth", model_type="vit_h", device=None):
    """Run SAM automatic mask generation on one image and hand the masks to ``show_anns``.

    In this script ``show_anns(masks, imgPath)`` writes the labelme JSON for
    the image and paints the masks onto the current matplotlib axes.

    Args:
        imgPath: path of the image to segment.
        sam_checkpoint: path to the SAM checkpoint file.
        model_type: key into ``sam_model_registry``; must match the checkpoint.
        device: torch device string. ``None`` selects ``"cuda"`` when CUDA is
            available and ``"cpu"`` otherwise.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``imgPath``.
    """
    image = cv2.imread(imgPath)
    if image is None:
        # cv2.imread returns None instead of raising on unreadable paths.
        raise FileNotFoundError(f"could not read image: {imgPath}")
    # OpenCV loads BGR; SAM expects RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Avoid growing sys.path by one entry on every call.
    if ".." not in sys.path:
        sys.path.append("..")
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)
    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(image)
    print(len(masks))
    if masks:
        # masks[0] would raise IndexError when SAM finds no masks.
        print(masks[0].keys())
    show_anns(masks, imgPath)
def _json_path_for(imgPath):
    """Return the path of the labelme JSON that sits next to ``imgPath``."""
    import os
    # splitext swaps only the real extension; str.replace(".png", ...) left
    # e.g. ".PNG"/".jpg" paths unchanged, so the image itself got overwritten.
    return os.path.splitext(imgPath)[0] + ".json"


def _largest_contour(mask):
    """Return the largest-area external OpenCV contour of a boolean mask, or None."""
    obj_img = np.zeros(mask.shape[:2], np.uint8)
    obj_img[mask] = 255
    contours, _ = cv2.findContours(obj_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if len(contours) == 0:
        return None
    return max(contours, key=cv2.contourArea)


def _subsample_contour(contour, max_points):
    """Thin an OpenCV contour (N x 1 x 2) by stride ``N // max_points`` into [x, y] int pairs."""
    points = list(contour)
    if len(points) >= max_points:
        points = points[::len(points) // max_points]
    return [[int(p[0][0]), int(p[0][1])] for p in points]


def show_anns(anns, imgPath, min_area=2500, label="otherheavy", max_points=30):
    """Export SAM masks as a labelme polygon JSON next to the image and overlay them.

    Annotations are processed largest-area first. Each one with
    ``ann['area'] >= min_area`` becomes a polygon built from the largest
    external contour of its boolean ``ann['segmentation']`` mask. Contours
    with ``max_points`` or more vertices are thinned with a stride of
    ``len // max_points``, leaving between ``max_points`` and
    ``2 * max_points - 1`` vertices. Polygons with fewer than 3 vertices are
    skipped. The JSON (labelme format version "5.2.1") is written to
    ``imgPath`` with its extension replaced by ``.json``. ``imagePath`` is
    stored as the image's base name, since labelme resolves it relative to
    the JSON file's folder. Exported masks are also painted in random opaque
    colours onto the current matplotlib axes. The caller is responsible for
    showing or saving the figure.

    Args:
        anns: list of SAM annotation dicts (``'area'``, ``'segmentation'``).
        imgPath: path of the segmented image; determines the JSON location.
        min_area: minimum mask area (pixels) for a mask to be exported.
        label: labelme label assigned to every exported polygon.
        max_points: target vertex budget used to thin long contours.

    Returns:
        None. Nothing is written when ``anns`` is empty.

    Raises:
        ValueError: if ``max_points`` is less than 1.
    """
    import os

    if max_points < 1:
        raise ValueError("max_points must be >= 1")
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=lambda x: x['area'], reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    height, width = sorted_anns[0]['segmentation'].shape[:2]
    img = np.ones((height, width, 4))
    img[:, :, 3] = 0

    shapes = []
    for ann in sorted_anns:
        if ann['area'] < min_area:
            continue
        m = ann['segmentation']
        contour = _largest_contour(m)
        if contour is None:
            continue
        points = _subsample_contour(contour, max_points)
        if len(points) < 3:
            # A polygon needs at least three vertices to enclose any area.
            continue
        shapes.append({
            "label": label,
            "points": points,
            "group_id": None,
            "shape_type": "polygon",
            "flags": {},
        })
        # Paint only exported masks so the overlay matches the JSON.
        img[m] = np.concatenate([np.random.random(3), [1]])

    jsonPath = _json_path_for(imgPath)
    print(jsonPath)
    jsonData = {
        "version": "5.2.1",
        "flags": {},
        "shapes": shapes,
        "imagePath": os.path.basename(imgPath),
        "imageData": None,
        "imageHeight": height,
        "imageWidth": width,
    }
    # `with` guarantees the file is closed even if serialization fails.
    with open(jsonPath, "w", encoding="utf-8") as file_out:
        file_out.write(json.dumps(jsonData, indent=4))
    ax.imshow(img)
if __name__ == '__main__':
    # Segment every image path given on the command line; with no arguments,
    # fall back to the original hard-coded "4.png".
    for imgPath in sys.argv[1:] or ["4.png"]:
        segment(imgPath)