类型体系与基本数据类型(第五节)
目录
前言
一、列表
1.1 Batch模板
如何定位子矩阵中的元素
1.2 Array模板
1.3 重复与Duplicate模板
总结
前言
一个深度学习框架的初步实现为例,讨论如何在一个相对较大的项目中深入应用元编程,为系统优化提供更多的可能。
以下内容结合书中原文阅读最佳!!!
一、列表
引入列表的作用是为了方便对批量计算的操作进行处理。在计算机科学中,列表(或数组)提供了一种便捷的方式来存储和管理大量相似类型的数据。对于每个参与计算的操作数,可以看作是一个数据序列,其中包含一组有序的原始数据,这些原始数据的维度和类型通常是相同的。
通过将数据存储在列表中,可以更轻松地对这些数据进行遍历、访问和操作,同时也方便进行批量计算,例如对每个元素进行相同的处理或者进行并行计算。这样能够简化代码逻辑,提高代码的可读性和可维护性。因此,引入列表可以更有效地管理大量数据并进行数据处理。
相似之处
标量列表和矩阵列表相似之处在于它们都是用于存储一组数据的数据结构,每个数据元素都具有相同的结构和类型。
1. 数据存储方式:标量列表和矩阵列表都使用线性结构来存储数据。标量列表是一维的,每个元素都是单个的标量值。而矩阵列表是二维的,每个元素都是一个矩阵。
2. 相同的数据类型:无论是标量列表还是矩阵列表,它们的数据元素具有相同的数据类型。这可以确保在进行数值计算时,数据元素之间可以进行相同的操作。
3. 维度和形状:尽管标量列表和矩阵列表在维度上有所不同,但它们都可以看作是有序的数据序列。标量列表具有一维,可以通过索引访问单个标量值。矩阵列表具有二维,可以通过两个索引来访问矩阵的特定元素。
4. 可迭代性:无论是标量列表还是矩阵列表,都可以通过迭代器或类似的方式进行迭代访问。这允许用户对列表中的每个元素进行遍历或进行特定的操作。
5. 广义的批量计算:标量列表和矩阵列表都可以支持批量计算。标量列表中的每个标量可以视为一个独立的数据样本,而矩阵列表中的每个矩阵可以视为包含多个数据样本的批次。这样可以方便地对整个列表进行相同的数学运算或逻辑操作。
综上所述,尽管标量列表和矩阵列表在维度和形状上有差异,但它们的相似之处在于它们都是用于存储一组具有相同类型的数据元素的数据结构,可以方便地进行迭代访问和批量计算。
1.1 Batch模板
声明
template <typename TElement, typename TDevice, typename TCategory>
calss Batch;
特化
// 标量列表
template <typename TElement, typename TDevice>
class Batch <TElement, TDevice, CategoryTags::Scalar>;
// 矩阵列表
template <typename TElement, typename TDevice>
calss Batch <TElement, TDevice, CategoryTags::Matrix>;
元函数的特化
template <typename TElement, typename TDevice>
constexpr bool IsBatchMatrix<Batch<TElement, TDevice, CategoryTags::Matrix>> = true;
template <typename TElement, typename TDevice>
constexpr bool IsBatchScalar<Batch<TElement, TDevice, CategoryTags::Scalar>> = true;
定义
template <typename TEement, typename TDevice>
class Batch<TElement, TDevice, CategoryTags::Matrix>
{
public:
using ElementType = TElement;
using DeviceType = TDevice;
friend struct LowerAccessImpl<Batch<TElement, TDevice,
CategoryTags::Matrix>>;
public:
Batch(size_t p_batchNum = 0, size_t p_rowNum = 0,
size_t p_colNum = 0);
// 维度相关接口
size_t RowNum() const { return m_rowNum; }
size_t ColNum() const { return m_colNum; }
size_t BatchNum() const { return m_batchNum; }
// 求值相关接口...
// 读写访问接口
bool AvaillableForWrite() const;
void SetValue(size_t p_batchId, size_t p_rowId,
size_t p_colId, ElementType val);
const auto operator [] (size_t p_batchId) const;
// 子矩阵列表接口
auto SubBatchMatrix(size_t p_rowB, size_t p_rowE, size_t p_colB
size_t p_colE) const;
private:
ContinuousMemory<ElementType, DeviceType> m_mem;
size_t m_rowNum;
size_t m_colNum;
size_t m_batchNum;
size_t m_rowLen;
size_t m_rawMatrixSize;
};
如何定位子矩阵中的元素
可以通过调用`SubBatchMatrix`函数来定位子矩阵中的元素。这个函数接受四个参数:`p_rowB`,`p_rowE`,`p_colB`,`p_colE`,分别表示子矩阵的起始行、结束行、起始列和结束列的索引。
示例代码
// 假设有一个 Batch 对象名为 myBatch
size_t rowStart = 0; // 子矩阵起始行
size_t rowEnd = 2; // 子矩阵结束行
size_t colStart = 1; // 子矩阵起始列
size_t colEnd = 3; // 子矩阵结束列
auto subMatrix = myBatch.SubBatchMatrix(rowStart, rowEnd, colStart, colEnd);
// 遍历子矩阵中的元素
for (size_t i = 0; i < subMatrix.RowNum(); i++) {
for (size_t j = 0; j < subMatrix.ColNum(); j++) {
ElementType element = subMatrix[i][j];
// 使用子矩阵中的元素进行相应的操作
// ...
}
}
根据上面的示例,首先通过调用`SubBatchMatrix`函数获取到子矩阵对象`subMatrix`,然后使用双重循环遍历子矩阵中的每个元素,通过`subMatrix[i][j]`来访问特定位置的元素。在循环中,你可以根据需要对子矩阵中的元素进行进一步的处理或操作。
1.2 Array模板
声明
template <typename TData>
class Array;
特化
template <typename TData>
constexpr bool IsBatchMatrix<Array<TData>> = IsMatrix<TData>;
template <typename TData>
constexpr bool IsBatchScalar<Array<TData>> = IsScalar<TData>;
辅助类
template <typename TData, typename TDataCate>
class ArrayImp;
template <typename TData>
class ArrayImp<TData, CategoryTags::Matrix> {
// ...
};
template <typename TData>
class ArrayImp<TData, CategoryTags::Scalar> {
// ...
};
template <typename TData>
class Array : public ArrayImp<TData, DataCategory<TData>>
{
public:
using ElementType = typename TData::ElementType;
using DeviceType = typename TData::DeviceType;
using ArrayImp<TData, DataCategory<TData>>::ArrayImp;
};
定义
template <typename TData>
class ArrayImp<TData, CategoryTags::Matrix>
{
public:
using ElementType = typename TData::ElementType;
using DeviceType = typename TData::DeviceType;
ArrayImp(size_t rowNum = 0, size_t colNum = 0);
template <typename TIterator,
std::enable_if_t<IsIterator<TIterator>>* = nullptr>
ArrayImp(TIterator b, TIterator e);
public:
size_t RowNum() const { return m_rowNum; }
size_t ColNum() const { return m_colNum; }
size_t BatchNum() const { return m_buffer->size(); }
// STL兼容接口
// push_back, size...
// 求值接口
// ...
bool AvailableForWrite() const {
return (!m_evalBuf.IsEvaluated()) &&
(m_buffer.use_count() == 1);
}
protected:
size_t m_rowNum;
size_t m_colNum;
std::shared_ptr<std::vector<TData>> m_buffer;
EvalBuffer<Batch<ElementType, DeviceType,
CategoryTags::Matrix>> m_evalBUf;
};
IsIterator 元函数
template <typename T>
struct IsIterator_
{
template <typename R>
static std::true_type Test(typename std::iterator_traits<R>
::iterator_category*);
template <typename R>
static std::false_type Test(...);
static constexpr bool value = decltype(Test<T>(nullptr))value;
};
template <typename T>
constexpr bool IsIterator = IsIterator_<T>::value;
1.3 重复与Duplicate模板
列表
template <typename TData, typename TDataCate>
calss DuplicateImp;
template <typename TData>
class Duplicate : public DuplicateImp<TData, DataCategory<TData>>
{
public:
using ElementType = typename TData::ElementType;
using DeviceType = typename TData::DeviceType;
using DuplicateImp<TData, DataCategory<TData>>::DuplicateImp;
};
template <typename TData>
constexpr bool IsBatchMatrix<Duplicate<TData>> = IsMatrix<TData>;
template <typename TData>
constexpr bool IsBatchScalar<Duplicate<TData>> = IsScalar<TData>;
定义
template <typename TData, typename TDataCate>
class DuplicateImp;
template <typename TData>
class DuplicateImp <TData, CategoryTags::Scalar> {
// ...
};
template <typename TData>
class DuplicateImp<TData, CategoryTags::Matrix>
{
public:
using ElementType = typename TData::ElementType;
using DeviceType = typename TDta::DeviceType;
DuplicateImp(TData data, size_t batch_num)
: m_data(std::move(data))
, m_batchNum(batch_num)
{
assert(m_batchNum != 0);
}
public:
size_t RowNum(0 const { return m_data.RowNum(); }
size_t ColNum() const { return m_data.ColNum(); }
size_t BatchNum() const { return m_batchNum; }
const TData& Element() const { return m_data; }
// 求值相关接口
// ...
protected:
TData m_data;
size_t m_batchNum;
EvalBuffer<Batch<ElementType, DeviceType,
CategoryTags::Matrix>> m_evalBuf;
};
总结
本小结,基本可以从原文获取解答,所以大部分提供书中代码给大家方便Copy。后面提供本章题目以及答案!!!