当前位置: 首页 > article >正文

Gmtracker_深度学习驱动的图匹配多目标跟踪项目启动与算法流程

Gmtracker深度学习驱动的图匹配多目标跟踪项目启动与算法流程

说明:对于Gmtracker多目标跟踪算法中涉及到的QP或者是QAP等一些有关图匹配的问题,不做过多的说明只提供源代码中通过图网络的具体实现细节。
对于配置环境时产生的报错的具体信息,因为没有报错的缘故只做简单的说明。

在这里插入图片描述

  • 论文题目:Learnable Graph Matching: Incorporating Graph Partitioning with Deep
    Feature Learning for Multiple Object Tracking

  • 官方代码地址: https://github.com/jiaweihe1996/GMTracker.git

在这里插入图片描述

环境配置

  1. 相比较于之前学习过的一些MOT算法,我大多数都没有使用论文官方提供的代码。

    • DeepSort:没有使用官方提供的TensorFlow版本,而是用的github中的Yolo v5 + Deepsort实现的

    • ByteTrack:也是选择使用的Yolo v8官方提供的代码来进行实现。

  2. 之前的环境安装较为简单对版本的要求并不是特别的高。而yolox + Gmtracker官方提供的代码,因为时间原因对版本和环境配置的要求高自己在配置环境的时候也是踩了许多的坑。

计算库配置

下面是官方所提供的前置环境要求。

Python == 3.6.X
PyTorch >= 1.4 with CUDA >=10.0 (tested on PyTorch 1.4.0)
torchvision
torch_geometric

  1. python的版本不能过高,自己在安装过程中发现采用高版本的python不支持官方要求的numpy==1.19.5的版本,即使勉强下载所有的计算环境也会报错

  2. scikit-learn库的版本要求不能过高,否则在代码里面会存在报错原因是0.23版本之后不在对一个util包进行支持了

  3. 如果使用pip进行安装的时候 cvxpy计算库和scs库两个计算库会存在版本冲突的问题。导致整合环境失败opencv或者是scs库(cpu)版下载卡住PE517…

  4. 安装时一定是使用conda进行下载,而且保证要有torch_geometric图计算库的支持。

自己对图计算资源库的下载是采用离线安装的形式,先安装的4个拓展的支持库,在进行安装的。

在这里插入图片描述

  1. 对于torch的安装尽量不要采用离线安装1.4.0的版本,我安装的是1.5.0的cpu版本。(conda安装

在这里插入图片描述

使用用服务器,在云服务器中使用conda环境 python = 3.6的版本环境.

pip install -r requirements.txt

在这里插入图片描述

在网上查找资料:opencv安装失败卡在这里是因为没有使用高版本的python环境

在这里插入图片描述

切换环境继续进行安装 python =3.7

在这里插入图片描述

换高于3.6版本的安装存在问题报错

网上解决方法:

pip install lxml

但还是安装失败。导致GMtracker启动失败。

项目启动

首先在你确保环境安装成功之后要下载MOT17的数据集安装官方提供的格式进行解压。

— data
— MOT17
— train
— MOT17-02
— MOT17-04

— test
— MOT17-01
— MOT17-03

在这里插入图片描述
官方提供给你的npy文件进行解压

在这里插入图片描述

在数据进行预处理加载的过程中会产生一个类型错误需要进行改正之后的就可以启动成功了

gmtracker_app.py的启动参数

--sequence_dir
data/MOT17/test/MOT17-01-DPM
--detection_file
npy/npytest_tracktor/MOT17-01-DPM.npy
--checkpoint_dir
experiments/static/params/0001
--max_age
100
--reid_thr
0.6
--output_file
results/test/MOT17-01-DPM.txt
# 注意修改多了一个参数
    (cc, warp_matrix) = cv2.findTransformECC (src, dst, warp_matrix, warp_mode, criteria, None)

在这里插入图片描述

执行完成之后会显示一个txt文件,用于下一步的显示视频效果的操作

show_result.py文件的启动参数。

--sequence_dir
data/MOT17/test/MOT17-01-DPM
--result_file
results/test/MOT17-01-DPM.txt
--output_file
/show/MOT17-01-DPM.avi

在启动的时候需要自己断点调试,去除注释的部分代码,才能显示视频效果,视频的保存还可能存在部分的问题

在这里插入图片描述

算法执行流程

在自己看完代码之后发现整个gmtracker_app.py跟踪算法的执行过程,于DeepSort有很大的相似之处

  • 两个阶段的匹配。
  • 特征的保存与添加确认态的状态。
  • 将级联匹配转换为深度图网络的二阶相似度匹配的问题

下面是我经过断点调试分析之后绘制的算法流程图,简单的说明算法的执行过程

核心的启动函数:

    def frame_callback(frame_idx): #执行跟踪的核心函数
        print("Processing %s"%seq_info["sequence_name"], "frame %05d" %frame_idx)

        # Load image and generate detections.(创建跟踪对象)
        detections = create_detections(
            seq_info["detections"], frame_idx, w_img=seq_info["image_size"][1],h_img=seq_info["image_size"][0])
        
        # Update tracker.
        tracker.predict(warp_matrix[frame_idx-2]) #执行卡尔曼滤波的预测
        tracker.update(detections, seq_info["sequence_name"], frame_idx, checkpoint_dir) #卡尔曼滤波的更新和匹配

        # Store results.
        for track in tracker.tracks:
            if track.time_since_update >= 1:
                continue
            bbox = track.to_tlwh2()
            results.append([ # 添加4个点的坐标位置信息连同跟踪器的id进行存储
                frame_idx, track.track_id, bbox[0], bbox[1], bbox[2], bbox[3]])

    # Run tracker.
    frame_idx = seq_info["min_frame_idx"]
    while frame_idx <= seq_info["max_frame_idx"]:
        frame_callback(frame_idx)
        frame_idx += 1

在这里插入图片描述
在这里插入图片描述

    def _match(self, detections,video,frame,checkpoint_dir):

        confirmed_tracks = [
            i for i, t in enumerate(self.tracks) if t.is_confirmed()]

        unconfirmed_tracks = [
            i for i, t in enumerate(self.tracks) if not t.is_confirmed()]
 
        matches_a, unmatched_tracks_a, unmatched_detections = \
            assignment.graph_matching( # Gmtracker执行图匹配的操作
                self.max_age, self.tracks, detections, confirmed_tracks,reid_thr=self.reid_thr,seq_name=video,ckp_dir=checkpoint_dir)
   
        iou_track_candidates = unconfirmed_tracks + [
            k for k in unmatched_tracks_a if
            self.tracks[k].time_since_update == 1]
   
        unmatched_tracks_a = [
            k for k in unmatched_tracks_a if
            self.tracks[k].time_since_update != 1]
        # 没匹配到的检测框进行第二次最小cost的级联匹配
        matches_b, unmatched_tracks_b, unmatched_detections = \
            assignment.min_cost_matching(
                iou_matching.iou_cost, self.max_iou_distance, self.tracks,
                detections, iou_track_candidates, unmatched_detections)
        # 返回第二次的匹配结果之后将结果相加
        matches = matches_a + matches_b

        unmatched_tracks = list(set(unmatched_tracks_a + unmatched_tracks_b))

        return matches, unmatched_tracks, unmatched_detections

在这里插入图片描述
涉及论文公式的最难的一个部分,个人不太理解

def quadratic_matching(
        tracks, detections, track_indices=None,
        detection_indices=None,reid_thr=0.8,seq_name=None,ckp_dir=None):

    if track_indices is None:
        track_indices = np.arange(len(tracks))
    if detection_indices is None:
        detection_indices = np.arange(len(detections))

    if len(detection_indices) == 0 or len(track_indices) == 0:
        return [], track_indices, detection_indices  # Nothing to match.
    
    dets = np.array([detections[i].feature for i in detection_indices]) # 检测
    tras = np.array([tracks[i].mov_ave for i in track_indices]) # 跟踪
    tra = torch.Tensor(tras) # 转为tensor张量
    det = torch.Tensor(dets)
    det_geos = np.array([[detections[i].tlwh[0],detections[i].tlwh[1],detections[i].tlwh[2]+detections[i].tlwh[0],detections[i].tlwh[3]+detections[i].tlwh[1]] for i in detection_indices])# 计算检测对象的4个坐标
    det_geo = torch.Tensor(det_geos)

    tra_means = np.array([[to_tlwh(tracks[i].mean[0:4])[0],to_tlwh(tracks[i].mean[0:4])[1],to_tlwh(tracks[i].mean[0:4])[0]+to_tlwh(tracks[i].mean[0:4])[2],to_tlwh(tracks[i].mean[0:4])[1]+to_tlwh(tracks[i].mean[0:4])[3]] for i in track_indices]) #计算跟踪对象的平均边界框(Bounding Box)的坐标
    tra_geo = torch.Tensor(tra_means)
    iou = torchvision.ops.box_iou(tra_geo,det_geo) # 计算两种边界框的Iou
    data1 = tra # 追踪器框的坐标
    data2 = det # 检测器框的坐标
    kf_gate = gate( # 该函数用于计算卡尔曼滤波门限。这个门限用于评估跟踪对象和检测对象之间的匹配可能性
            kalman_filter.KalmanFilter(), tracks, detections, track_indices,
            detection_indices)
    _, _, start_src, end_src = gh(data1.shape[0])
    _, _, start_tgt, end_tgt = gh(data2.shape[0])
    data1 = data1.t().unsqueeze(0) # 转置并拓展维度
    data2 = data2.t().unsqueeze(0)
    start_src = torch.tensor(start_src)
    end_src = torch.tensor(end_src)
    start_tgt = torch.tensor(start_tgt)
    end_tgt = torch.tensor(end_tgt)  # 上面是生成匹配矩阵的过程

    with torch.no_grad():
        graphnet = GraphNet() # 创建一个图网络
        params_path = os.path.join(ckp_dir, f"params.pt") # 读取预训练的图网络模型
        graphnet.load_state_dict(torch.load(params_path), strict=False) # 加载模型参数值
        if iou.shape[0] >= iou.shape[1]: # .forward输入到图网络进行图匹配
            indices, thr_flag = graphnet.forward(data1, data2, kf_gate, reid_thr, iou, start_src, end_src, start_tgt, end_tgt, seq_name, inverse_flag=False)
        if iou.shape[0] < iou.shape[1]:
            indices, thr_flag = graphnet.forward(data2, data1, kf_gate.T, reid_thr, iou.t(), start_tgt, end_tgt, start_src, end_src, seq_name, inverse_flag=True)

    # Gmtracker算法的核心部分(使用深度图网络进行匹配的过程)
    matches, unmatched_tracks, unmatched_detections = [], [], []
    for col, detection_idx in enumerate(detection_indices):
        if col not in indices[:, 1]:
            unmatched_detections.append(detection_idx)
    for row, track_idx in enumerate(track_indices):
        if row not in indices[:, 0]:
            unmatched_tracks.append(track_idx)
    for row, col in indices:
        track_idx = track_indices[row]
        detection_idx = detection_indices[col]
        
        if thr_flag[row, col] == 1:
            unmatched_tracks.append(track_idx)
            unmatched_detections.append(detection_idx)
        else:  
            matches.append((track_idx, detection_idx))
    return matches, unmatched_tracks, unmatched_detections

http://www.kler.cn/news/303843.html

相关文章:

  • ES机制原理
  • linux ubuntu编译 openjdk11
  • 中国科技统计年鉴1991-2020年
  • JDBC客户端连接Starrocks 2.5
  • python-回文数(一)
  • 4G MQTT网关在物联网应用中的优势-天拓四方
  • 组播 2024 9 11
  • 为什么mac打不开rar文件 苹果电脑打不开rar压缩文件怎么办
  • 基于Java-SpringBoot+vue实现的前后端分离信息管理系统设计和实现
  • element实现动态路由+面包屑
  • Vue的学习(三)
  • vue2响应式系统是如何实现的(手写)
  • 代码随想录刷题day32丨动态规划理论基础,509. 斐波那契数, 70. 爬楼梯, 746. 使用最小花费爬楼梯
  • 基于Python实现一个庆祝国庆节的小程序
  • Kubernetes 与 springboot集成
  • 【九盾安防】叉车使用安全新升级!指纹识别锁,验证司机操作权限
  • 关于我的阿里云服务器被入侵 - 分析报告
  • 春日课堂:SpringBoot在线教育解决方案
  • 限流,流量整形算法
  • 安全基础设施如何形成统一生态标准?OASA 硬件安全合作计划启动 | 2024 龙蜥大会
  • 【贪心算法】(二)贪心算法区间问题及进阶习题
  • 重学SpringBoot3-集成RocketMQ(二)
  • Python(TensorFlow和PyTorch)及C++注意力网络导图
  • Docker 安装 Nacos 教程
  • L3级智能网联汽车准入试点详细解析及所需材料
  • oracle 如何查死锁
  • Web大学生网页作业成品——动漫喜羊羊网页设计与实现(HTML+CSS)(4个页面)
  • 趣味SQL | 从围棋收官到秦楚大战的数据库SQL语言实现
  • Flutter自定义Icon的简易使用(两种)
  • 项目——负载均衡OJ