当前位置: 首页 > article >正文

含掩膜mask的单通道灰度图转化为COCO数据集格式标签的json文件(python)

输入:单通道的灰度图,灰度图内含掩膜mask
目标:把灰度图中的语义mask转换为COCO数据集格式的json文件
输出:COCO数据集格式的json文件

期间遇到的问题:
发现有的掩膜内部存在其他类别的掩膜,即mask内部还套了mask,这种情况的mask怎么只用一个数组来表示?

以下是查找的可用代码:

from PIL import Image
import numpy as np
from skimage import measure
from shapely.geometry import Polygon, MultiPolygon
import json
import os
from tqdm import tqdm

def create_sub_masks(mask_image):
    """Split a colour-coded mask image into one binary mask per colour.

    Args:
        mask_image: image object exposing ``size`` and ``getpixel``; each
            pixel value must be sliceable (its first three channels are
            used as the colour).

    Returns:
        dict mapping ``str(colour_tuple)`` to a mode ``'1'`` image that is
        one pixel larger than ``mask_image`` on every side. Pixels whose
        colour equals the key are set to 1; black ``(0, 0, 0)`` pixels are
        treated as background and never get a sub-mask.
    """
    img_w, img_h = mask_image.size
    # One pixel of padding on every side: the contour tracing done on
    # these masks does not cope with regions touching the image border.
    padded_size = (img_w + 2, img_h + 2)
    background = (0, 0, 0)

    masks_by_color = {}
    for col in range(img_w):
        for row in range(img_h):
            rgb = mask_image.getpixel((col, row))[:3]
            if rgb == background:
                continue

            key = str(rgb)
            if key not in masks_by_color:
                # First time this colour is seen: start an all-zero mask.
                masks_by_color[key] = Image.new('1', padded_size)

            # Shift by the padding offset when marking the pixel.
            masks_by_color[key].putpixel((col + 1, row + 1), 1)

    return masks_by_color


def create_sub_mask_annotation(sub_mask, image_id, category_id, annotation_id, is_crowd):
    """Build one COCO-style annotation dict from a padded binary sub-mask.

    Args:
        sub_mask: binary mask image with one pixel of padding on each side
            (the padding offset is subtracted from every coordinate).
        image_id: value stored under ``'image_id'``.
        category_id: value stored under ``'category_id'``.
        annotation_id: value stored under ``'id'``.
        is_crowd: value stored under ``'iscrowd'``.

    Returns:
        dict with keys ``segmentation``, ``iscrowd``, ``image_id``,
        ``category_id``, ``id``, ``bbox`` (x, y, width, height) and
        ``area``; or the string ``"skip"`` when the mask yields no usable
        polygon (e.g. the region is so small that simplification
        collapses it).
    """
    # Find contours (boundary lines) around each region of the sub-mask.
    # There can be several contours, e.g. when an object is split in two.
    contours = measure.find_contours(np.array(sub_mask), 0.5, positive_orientation='low')

    segmentations = []
    polygons = []
    for contour in contours:
        # A polygon ring cannot be built from fewer than three points.
        if len(contour) < 3:
            continue

        # Flip from (row, col) to (x, y) and subtract the padding pixel.
        xy = contour[:, ::-1] - 1

        try:
            poly = Polygon(xy)
        except ValueError:
            # shapely rejects rings with too few coordinates; such a
            # contour carries no area, so drop it.
            continue

        # With preserve_topology=False the result may be an empty polygon
        # or a MultiPolygon, so normalise to a sequence of parts.
        simplified = poly.simplify(1.0, preserve_topology=False)
        for part in getattr(simplified, 'geoms', (simplified,)):
            if part.geom_type != 'Polygon' or part.is_empty:
                continue
            # Clamp at 0: subtracting the padding can push border
            # coordinates slightly negative.
            coords = np.maximum(np.array(part.exterior.coords), 0).ravel().tolist()
            if not coords:
                continue
            polygons.append(part)
            segmentations.append(coords)

    # Nothing survived: tell the caller not to store this annotation
    # instead of emitting an empty segmentation with a NaN bbox.
    if not polygons:
        return "skip"

    # Combine the polygons to calculate the bounding box and area.
    # NOTE(review): contours traced around holes appear to be added as
    # ordinary polygons here, so the area of a mask that encloses another
    # mask may be over-counted - confirm before relying on 'area'.
    multi_poly = MultiPolygon(polygons)
    if multi_poly.is_empty:
        return "skip"
    x, y, max_x, max_y = multi_poly.bounds
    width = max_x - x
    height = max_y - y
    bbox = (x, y, width, height)
    area = multi_poly.area

    annotation = {
        'segmentation': segmentations,
        'iscrowd': is_crowd,
        'image_id': image_id,
        'category_id': category_id,
        'id': annotation_id,
        'bbox': bbox,
        'area': area
    }

    return annotation


def get_name(root, mode_folder=True):
    """Return the sorted names of the immediate children of *root*.

    Args:
        root: directory to inspect (only its top level is read).
        mode_folder: when true, return sub-directory names; otherwise
            return file names.

    Returns:
        Sorted list of names; an empty list when *root* cannot be walked
        (for example because it does not exist).
    """
    # Only the first tuple yielded by os.walk (the top directory itself)
    # is needed. os.walk yields nothing for a missing directory, so the
    # default keeps the return type a list in that case.
    _, dir_names, file_names = next(os.walk(root), (root, [], []))
    return sorted(dir_names if mode_folder else file_names)


def get_annotation(mask_image_root, mask_dir=None, output_path="trainmask.json"):
    """Convert mask images into a COCO-style dataset and write it as JSON.

    Args:
        mask_image_root: sequence of mask file names (relative to
            ``mask_dir``); the position in the sequence becomes the image id.
        mask_dir: directory containing the mask files. When omitted, the
            module-level ``maskdir`` variable is used.
        output_path: path of the JSON file to write.

    Returns:
        The dataset dict that was written to ``output_path``.
    """
    if mask_dir is None:
        # Fall back to the module-level name set by the script entry point.
        mask_dir = maskdir

    dataset = {"info": {"year": 2023, "version": "2023", "description": "", "url": "",
                        },
               "license": {},
               "images": [],
               "annotations": [],
               "categories": []}
    class_index = {0: "background", 1: 'cate1', 2: 'cate2'}
    for class_id, class_name in class_index.items():
        dataset["categories"].append({"id": class_id, "name": class_name, "supercategory": "xxx"})

    is_crowd = 0

    # Incremented once per stored annotation.
    annotation_id = 0

    # Wrapping the sequence itself (rather than enumerate) lets tqdm know
    # the total and show real progress.
    for image_id, file_name in enumerate(tqdm(mask_image_root)):
        # convert() returns a new, fully loaded image, so the file handle
        # can be closed right away.
        with Image.open(os.path.join(mask_dir, file_name)) as raw_image:
            mask_image = raw_image.convert('RGB')
        width, height = mask_image.size

        dataset["images"].append({
                                  "file_name": file_name,
                                  "id": image_id,
                                  "width": width,
                                  "height": height})

        sub_masks = create_sub_masks(mask_image)
        for sub_mask in sub_masks.values():
            # NOTE(review): every colour is assigned category 1 regardless
            # of its value - presumably each colour is a separate instance
            # of one class; confirm against how the masks were produced.
            category_id = 1
            annotation = create_sub_mask_annotation(sub_mask, image_id, category_id, annotation_id, is_crowd)
            if annotation == "skip":
                continue
            dataset["annotations"].append(annotation)
            annotation_id += 1

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(dataset, f)

    return dataset



# rrr = "./InstanceSegmentation/"
# all_root = get_name(rrr, mode_folder=False)
# get_annotation(all_root)
if __name__=='__main__':
    # Directory holding the mask images; get_annotation reads this
    # module-level name.
    maskdir = './mask/trainmask/'
    # os.listdir returns entries in arbitrary order; sort them so that
    # image ids are reproducible from run to run.
    maskimglist = sorted(os.listdir(maskdir))
    get_annotation(maskimglist)

问题:
上述代码仍然存在不足:有的 mask 太小,输出的 segmentation 为空列表 []。我在生成 segmentation 的位置加了判断,为空就不保存,可以避免这个问题;但 bbox 等信息有时仍会出现 NaN,需要自行判断处理。

整体上来说,这个代码还是挺好用的。

还有一点,有些内部mask比较极端的情况,代码执行容易出错,建议把问题图像删除,或者自己查找问题修改代码。


http://www.kler.cn/a/159068.html

相关文章:

  • AI 编程编辑器和工具
  • 从零到一:利用 AI 开发 iOS App 《震感》的编程之旅
  • Python爬虫项目 | 一、网易云音乐热歌榜歌曲
  • PostgreSQL物化视图详解
  • 数据仓库在大数据处理中的作用
  • Javascript_设计模式(二)
  • CUDA简介——Grid和Block内Thread索引
  • 《路由与交换技术》读书笔记
  • 【开源】基于Vue和SpringBoot的开放实验室管理系统
  • 分类预测 | Matlab实现OOA-CNN-SVM鱼鹰算法优化卷积支持向量机分类预测
  • JeecgBoot 框架升级至 Spring Boot3 的实战步骤
  • nodejs+vue+微信小程序+python+PHP在线购票系统的设计与实现-计算机毕业设计推荐
  • 【C++11】线程库/异常
  • 4-Docker命令之docker exec
  • Error: Cannot find module ‘@npmcli/config‘ 最新解决办法
  • javaScript(四):函数和常用对象
  • 第一百十九回 如何Text组件中的文字自动换行
  • 【RabbitMQ】RabbitMQ快速入门 通俗易懂 初学者入门
  • 【1day】蓝凌OA 系统custom.jsp 接口任意文件读取漏洞学习
  • Codeforces Round 913 (Div. 3)
  • 软件测试方法之等价类测试
  • GAN:WGAN-DIV
  • 智慧垃圾分拣站:科技改变城市环境,创造更美好的未来
  • OCP Java17 SE Developers 复习题08
  • MySQL 8.x 自签证书通过keytool和openssl转成JKS文件
  • 剑指 Offer(第2版)面试题 18:删除链表的节点