当前位置: 首页 > article >正文

第二十六周学习周报

目录

    • 摘要
    • Abstract
    • 1 FOMM(一阶运动模型)
      • 1.1 基本框架
      • 1.2 实验
      • 1.3 代码分析
    • 2 LSTM复习
      • 2.1 LSTM原理
      • 1.2 LSTM反向传播的数学推导
    • 总结

摘要

本周的主要学习内容是FOMM模型,FOMM一种用于图像动画的技术,它能够通过给定的源图像和驱动视频生成逼真的动画序列,是实现虚拟主播的重要技术之一。FOMM模型作为图像动画生成的深度学习技术,在原理、训练和应用方面都展现出强大的实力。它的出现解决了GAN在视频中传输面部表情时依赖于预训练模型的问题。本周,作者将从基本框架、实验以及代码解析三部分对FOMM进行学习。除此本周作者还对LSTM的相关内容进行了复习。

Abstract

The main learning content of this week is the FOMM model, which is a technology used for image animation. FOMM can generate realistic animation sequences through given source images and driving videos, and is one of the important technologies for implementing virtual anchors. The FOMM model, as a deep learning technique for image animation generation, has demonstrated strong capabilities in terms of principles, training, and applications. Its emergence solves the problem of GAN relying on pre trained models when transmitting facial expressions in videos. This week, the author will study FOMM from three parts: basic framework, experiments, and code analysis.In addition, this week the author also reviewed the relevant content of LSTM.

1 FOMM(一阶运动模型)

本次的学习内容基于一篇名为First Order Motion Model for Image Animation的论文。一阶运动模型是一种用于图像动画的技术,它通过深度学习来实现静态图像的动态化。这种模型的核心思想是将动画对象的运动分解为关键点的位移和局部仿射变换的组合。在论文中,作者提出了一种新的模型能够通过自监督的方式学习到图像中的关键点。下图中是人脸、全身运动,动画运动的驱动效果。可以看出,first-order 的效果还是很不错的,且该方法基于运动模型推导而出,在形式上也具有良好的可解释性。
在这里插入图片描述

1.1 基本框架

first-order 的算法框架如下图所示,由运动估计模块和图像生成两个主要模块组成。输入是一个源图像S和一个驱动视频帧D的帧,无监督keyporint detector 提取由稀疏关键点和相对于参考系R的局部仿射变换组成的一阶运动表示。dense motion network使用运动表示来生成从D到S的密集光流TS←D和遮挡映射OS←D。generation module使用源图像和密集运动网络的输出来渲染目标图像。
在这里插入图片描述
运动估计模块
运动估计模块的目的是预测从一个驱动视频中的帧D到源帧S的稠密运动场。稠密运动场随后用于将从计算出的特征图与中的目标姿态对齐。运动场通过一个函数建模,该函数将中每个像素位置映射到中相应的位置。
在这里插入图片描述
运动估计模块由两部分组成:
(1)Keypoint Detector:这个部分负责探测关键点的位置,并计算局部的仿射变换矩阵以及这个矩阵的一阶导数(Jacobi)。它要同时估计源影像和驱动帧的关键点变换参数。
(2)Dense Motion Field:在得到关键点参数后,密集运动场会生成每一个像素的运动,即图像的光流场。这个部分将局部近似运动组合起来,以获得最终的密集运动场。
经运动估计模块处理后的输出有两个:dense motion field和occlusion mask。
dense motion field体现了驱动图像D中的每个关键点到源图像S的映射关系。
occlusion mask表明了在最终生成的图像中,对于驱动图像D而言,哪部分姿态可以通过S扭曲得到,哪部分只能通过inpainting得到。
图像生成模块
图像生成模块的主要任务是用encoder对原帧S进行特征编码,得到中间特征。采用上一步得到的denseoptical flow对中间特征进行warp,并乘上occlusion map,得到变换后的中间特征,最终通过一个decoder重建出图像。

1.2 实验

数据集
实验使用到的数据集有三个VoxCeleb、UvA Nemo、BAIR机器人推送数据集、太极高清数据集。
1、VoxCeleb数据集是从Y ouTube视频中提取的22496个视频的人脸数据集。
2、UvA Nemo数据集是一个面部分析数据集,由1240个视频组成。应用与VoxCeleb完全相同的预处理。
3、BAIR机器人推送数据集包含Sawyer机械臂在桌子上推送不同物体所收集的视频。它包括42880个培训和128个测试视频。每个视频有30帧长,分辨率为256×256。
4、太极高清数据集,包含了280个太极视频。作者使用252个视频进行训练,28个用于测试。
实验评估标准
评估标准主要有4点:
1、计算生成图片和ground-truth的L1距离
2、平均关键点距离(AKD):对于Tai-Chi-HD, VoxCeleb and Nemo数据集,采用了第三方预训练的关键点检测器来进行评估。AKD是通过计算ground truth检测关键点与生成视频的平均距离得到的。
3、缺少关键点速率(MKR):在实况帧中检测到但在生成的帧中没有检测到的关键点的百分比。该度量评估每个生成的帧的外观质量。
4、平均欧几里得距离(AED):考虑到外部训练的图像表示,实际和生成的帧表示之间的平均欧氏距离,采用了特征嵌入。
实验结果
与其他现有技术的比较发现,X2Face和Monkey Net都无法正确地传递驾驶视频中的身体概念,而是将源图像中的人体扭曲为斑点。相反,论文中的方法能够生成外观更好的视频,其中每个身体部位都是独立动画。
在这里插入图片描述
在这里插入图片描述

1.3 代码分析

FFOM的代码是开源的,源码地址为https://github.com/AliaksandrSiarohin/first-order-model。
接下来我将对其中的重要部分进行解读

(1)数据载入部分
该段代码用于构建一个数据集,以学习如何同步视频帧。可以看到图像输入x=5帧,但是视频输入有mel=5帧和indiv_mels=25帧。x和indiv_mels输入到生成器,得到x’=5帧,x’和mel输入到动作同步。


syncnet_T = 5
syncnet_mel_step_size = 16

class Dataset(object):
    def __init__(self, split):
        self.all_videos = get_image_list(args.data_root, split)

    def get_frame_id(self, frame):
        return int(basename(frame).split('.')[0])

    def get_window(self, start_frame):
        start_id = self.get_frame_id(start_frame)
        vidname = dirname(start_frame)

        window_fnames = []
        for frame_id in range(start_id, start_id + syncnet_T):
            frame = join(vidname, '{}.jpg'.format(frame_id))
            if not isfile(frame):
                return None
            window_fnames.append(frame)
        return window_fnames

    def read_window(self, window_fnames):
        if window_fnames is None: return None
        window = []
        for fname in window_fnames:
            img = cv2.imread(fname)
            if img is None:
                return None
            try:
                img = cv2.resize(img, (hparams.img_size, hparams.img_size))
            except Exception as e:
                return None

            window.append(img)

        return window

    def crop_audio_window(self, spec, start_frame):
        if type(start_frame) == int:
            start_frame_num = start_frame
        else:
            start_frame_num = self.get_frame_id(start_frame)
        start_idx = int(80. * (start_frame_num / float(hparams.fps)))
        
        end_idx = start_idx + syncnet_mel_step_size

        return spec[start_idx : end_idx, :]

    def get_segmented_mels(self, spec, start_frame):
        mels = []
        assert syncnet_T == 5
        start_frame_num = self.get_frame_id(start_frame) + 1 # 0-indexing ---> 1-indexing
        if start_frame_num - 2 < 0: return None
        for i in range(start_frame_num, start_frame_num + syncnet_T):
            m = self.crop_audio_window(spec, i - 2)
            if m.shape[0] != syncnet_mel_step_size:
                return None
            mels.append(m.T)

        mels = np.asarray(mels)

        return mels

    def prepare_window(self, window):
        # 3 x T x H x W
        x = np.asarray(window) / 255.
        x = np.transpose(x, (3, 0, 1, 2))

        return x

    def __len__(self):
        return len(self.all_videos)

    def __getitem__(self, idx):
        while 1:
            idx = random.randint(0, len(self.all_videos) - 1)
            vidname = self.all_videos[idx]
            img_names = list(glob(join(vidname, '*.jpg')))
            if len(img_names) <= 3 * syncnet_T:
                continue
            
            img_name = random.choice(img_names)
            wrong_img_name = random.choice(img_names)
            while wrong_img_name == img_name:
                wrong_img_name = random.choice(img_names)

            window_fnames = self.get_window(img_name)
            wrong_window_fnames = self.get_window(wrong_img_name)
            if window_fnames is None or wrong_window_fnames is None:
                continue

            window = self.read_window(window_fnames)
            if window is None:
                continue

            wrong_window = self.read_window(wrong_window_fnames)
            if wrong_window is None:
                continue

            try:
                wavpath = join(vidname, "audio.wav")
                wav = audio.load_wav(wavpath, hparams.sample_rate)

                orig_mel = audio.melspectrogram(wav).T
            except Exception as e:
                continue

            mel = self.crop_audio_window(orig_mel.copy(), img_name)
            
            if (mel.shape[0] != syncnet_mel_step_size):
                continue

            indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name)
            if indiv_mels is None: continue

            window = self.prepare_window(window)
            y = window.copy()
            window[:, :, window.shape[2]//2:] = 0.

            wrong_window = self.prepare_window(wrong_window)
            x = np.concatenate([window, wrong_window], axis=0)

            x = torch.FloatTensor(x)
            mel = torch.FloatTensor(mel.T).unsqueeze(0)
            indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1)
            y = torch.FloatTensor(y)
            return x, indiv_mels, mel, y

(2)训练代码
下面是训练代码,分别使用了3个损失,重构loss,关键点探测loss和dense motion loss,并且分别有一个权重。在hparams.py我们可以看到,原作者在最开始给定关键点探测loss的权重为0(意味着没有梯度反传,没有训练),关键点探测loss为0.07,当dense motion loss在验证集中下降到0.75的时候,才修改dense motion loss的权重为0.03(此时才开始梯度反传)。


def train(device, model, disc, train_data_loader, test_data_loader, optimizer, disc_optimizer,
          checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
    global global_step, global_epoch
    resumed_step = global_step

    while global_epoch < nepochs:
        print('Starting Epoch: {}'.format(global_epoch))
        running_sync_loss, running_l1_loss, disc_loss, running_perceptual_loss = 0., 0., 0., 0.
        running_disc_real_loss, running_disc_fake_loss = 0., 0.
        prog_bar = tqdm(enumerate(train_data_loader))
        for step, (x, indiv_mels, mel, gt) in prog_bar:
            disc.train()
            model.train()

            x = x.to(device)
            mel = mel.to(device)
            indiv_mels = indiv_mels.to(device)
            gt = gt.to(device)

            ### Train generator now. Remove ALL grads. 
            optimizer.zero_grad()
            disc_optimizer.zero_grad()

            g = model(indiv_mels, x)

            if hparams.syncnet_wt > 0.:
                sync_loss = get_sync_loss(mel, g)
            else:
                sync_loss = 0.

            if hparams.disc_wt > 0.:
                perceptual_loss = disc.perceptual_forward(g)
            else:
                perceptual_loss = 0.

            l1loss = recon_loss(g, gt)

            loss = hparams.syncnet_wt * sync_loss + hparams.disc_wt * perceptual_loss + \
                                    (1. - hparams.syncnet_wt - hparams.disc_wt) * l1loss

            loss.backward()
            optimizer.step()

            ### Remove all gradients before Training disc
            disc_optimizer.zero_grad()

            pred = disc(gt)
            disc_real_loss = F.binary_cross_entropy(pred, torch.ones((len(pred), 1)).to(device))
            disc_real_loss.backward()

            pred = disc(g.detach())
            disc_fake_loss = F.binary_cross_entropy(pred, torch.zeros((len(pred), 1)).to(device))
            disc_fake_loss.backward()

            disc_optimizer.step()

            running_disc_real_loss += disc_real_loss.item()
            running_disc_fake_loss += disc_fake_loss.item()

            if global_step % checkpoint_interval == 0:
                save_sample_images(x, g, gt, global_step, checkpoint_dir)

            # Logs
            global_step += 1
            cur_session_steps = global_step - resumed_step

            running_l1_loss += l1loss.item()
            if hparams.syncnet_wt > 0.:
                running_sync_loss += sync_loss.item()
            else:
                running_sync_loss += 0.

            if hparams.disc_wt > 0.:
                running_perceptual_loss += perceptual_loss.item()
            else:
                running_perceptual_loss += 0.

            if global_step == 1 or global_step % checkpoint_interval == 0:
                save_checkpoint(
                    model, optimizer, global_step, checkpoint_dir, global_epoch)
                save_checkpoint(disc, disc_optimizer, global_step, checkpoint_dir, global_epoch, prefix='disc_')


            if global_step % hparams.eval_interval == 0:
                with torch.no_grad():
                    average_sync_loss = eval_model(test_data_loader, global_step, device, model, disc)

                    if average_sync_loss < 2:
                        hparams.set_hparam('syncnet_wt', 0.01)

            prog_bar.set_description('L1: {}, Sync: {}, Percep: {} | Fake: {}, Real: {}'.format(running_l1_loss / (step + 1),
                                                                                        running_sync_loss / (step + 1),
                                                                                        running_perceptual_loss / (step + 1),
                                                                                        running_disc_fake_loss / (step + 1),
                                                                                        running_disc_real_loss / (step + 1)))

        global_epoch += 1

2 LSTM复习

2.1 LSTM原理

LSTM结构如下图所示:
在这里插入图片描述
LSTM有3个gate:
(1)Input Gate:控制data输入储存单元,外界neural的输出想要写入存储单元时要经过它。
(2)Output Gate:控制data从存储单元中输出,外界neural想要从存储单元中读出值需要经过它。
(3)Forget Gate:决定什么时候要把存储单元中的data清除。
这三个Gate的开关都是神经网络自己学习的,它可以自己学习什么时候开门,什么时候关门。综上,整个LSTM有4个输入,1个输出。
在这里插入图片描述
将上述模型进一步细化,如上图所示。假定
我们要存入单元的输入为z,
控制input Gate的信号为zi,
控制Output Gate的信号为zo,
控制Forget Gate的信号为zf,

把z输入通过激活函数得到g(z),zi输入通过激活函数得到f(zi),这里用到的激活函数通常都为sigmoid函数,因为经过sigmoid函数所得值介于0和1之间,可以以此判断门是打开还是关闭的,把g(z)乘上f(zi)得到g(z)f(zi)。
zf通过激活函数得到f(zf),c’=g(z)f(zi)+cf(zf)。c’为重新存入单元的值,c’经sigmoid函数得到h(c‘)。最后由h(c‘)乘上f(zo)得到最终的输出a。输出门受f(zo) 所操控f(zo)等于 1 的话,就说明 h(c′) 能通过,f(zo) 等于 0 的话,说明记忆元里面存在的值没有办法通过输出门被读取出来。其他gate的处理与此类似。

1.2 LSTM反向传播的数学推导

LSTM前向传播的示意图如下所示:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

总结

在实际应用中,FOMM模型已经成为动画制作领域的重要工具之一,为虚拟现实、增强现实、电影和游戏制作等领域带来了更加自然、逼真和智能的交互体验。随着技术的不断发展和完善,FOMM模型将会在未来发挥更加重要的作用。通过本周的学习发现,我对原来的很多知识有所遗忘,后续将通过相关模型的文献阅读来进行复习。


http://www.kler.cn/a/447808.html

相关文章:

  • JavaScript 中的 `parseInt()` 函数详解
  • 设计模式期末复习
  • 靜態IP與DHCP的區別和用法
  • UIP协议栈 TCP通信客户端 服务端,UDP单播 广播通信 example
  • 【Lua热更新】上篇
  • 基于 Python 解决 X 轴上点距离最小值问题
  • c语言图书信息管理系统源码
  • YOLOv8改进,YOLOv8引入Hyper-YOLO的MANet混合聚合网络+HyperC2Net网络
  • AI图像生成利器:Stable Diffusion 3.5本地运行与远程出图操作流程
  • Nginx - 负载均衡及其配置(Balance)
  • SVM理论推导
  • NLP自然语言学习路径图
  • MAC地址和IP地址的区别
  • 【HarmonyOs学习日志(14)】计算机网络之域名系统DNS
  • 【Pandas】pandas Series size
  • mysql,数据库数据备份
  • [Unity Shader]【游戏开发】【图形渲染】 Shader数学基础5-方阵、单位矩阵和转置矩阵
  • 地址栏输入URL浏览器会发生什么?
  • 有关异步场景的 10 大 Spring Boot 面试问题
  • CentOS 7 安装、测试和部署FastDFS
  • 在 K8S 中创建 Pod 是如何使用到 GPU 的: nvidia device plugin 源码分析
  • 得物Java后端一面,扛住了!
  • 数据结构与算法学习笔记----Kruskal算法
  • Moretl非共享文件夹日志采集
  • 计算世界之安生:C++继承的文水和智慧(上)
  • 数据仓库工具箱—读书笔记02(Kimball维度建模技术概述03、维度表技术基础)