当前位置: 首页 > article >正文

TensorFlow之sparse tensor

目录

  • 前言
  • 创建sparse tensor
  • sparse tensor的运算:

前言

sparse tensor 稀疏tensor, tensor中大部分元素是0, 少部分元素是非0.

创建sparse tensor

import tensorflow as tf

# indices指示正常值的索引, 即哪些索引位置上是正常值. 
# values表示这些正常值是多少.
# indices和values是一一对应的. [0, 1]表示第0行第1列的值是1, [1,0]表示第一行第0列的值是2, [2, 3]表示第2行第3列的值是3, 以此类推.
# dense_shape表示这个SparseTensor的shape是多少
s = tf.SparseTensor(indices = [[0, 1], [1, 0], [2, 3]],
                    values = [1., 2., 3.],
                    dense_shape = [3, 4])
print(s)
# 把sparse tensor转化为稠密矩阵 
print(tf.sparse.to_dense(s))

结果如下:

SparseTensor(indices=tf.Tensor(
[[0 1]
 [1 0]
 [2 3]], shape=(3, 2), dtype=int64), values=tf.Tensor([1. 2. 3.], shape=(3,), dtype=float32), dense_shape=tf.Tensor([3 4], shape=(2,), dtype=int64))
tf.Tensor(
[[0. 1. 0. 0.]
 [2. 0. 0. 0.]
 [0. 0. 0. 3.]], shape=(3, 4), dtype=float32)

sparse tensor的运算:


import tensorflow as tf

# indices指示正常值的索引, 即哪些索引位置上是正常值. 
# values表示这些正常值是多少.
# indices和values是一一对应的. [0, 1]表示第0行第1列的值是1, [1,0]表示第一行第0列的值是2, [2, 3]表示第2行第3列的值是3, 以此类推.
# dense_shape表示这个SparseTensor的shape是多少
s = tf.SparseTensor(indices = [[0, 1], [1, 0], [2, 3]],
                    values = [1., 2., 3.],
                    dense_shape = [3, 4])
print(s)

# 乘法
s2 = s * 2.0
print(s2)

try:
    # 加法不支持.
    s3 = s + 1
except TypeError as ex:
    print(ex)

s4 = tf.constant([[10., 20.],
                  [30., 40.],
                  [50., 60.],
                  [70., 80.]])
# 得到一个3 * 2 的矩阵
print(tf.sparse.sparse_dense_matmul(s, s4))

结果如下:

SparseTensor(indices=tf.Tensor(
[[0 1]
 [1 0]
 [2 3]], shape=(3, 2), dtype=int64), values=tf.Tensor([1. 2. 3.], shape=(3,), dtype=float32), dense_shape=tf.Tensor([3 4], shape=(2,), dtype=int64))
SparseTensor(indices=tf.Tensor(
[[0 1]
 [1 0]
 [2 3]], shape=(3, 2), dtype=int64), values=tf.Tensor([2. 4. 6.], shape=(3,), dtype=float32), dense_shape=tf.Tensor([3 4], shape=(2,), dtype=int64))
unsupported operand type(s) for +: 'SparseTensor' and 'int'
tf.Tensor(
[[ 30.  40.]
 [ 20.  40.]
 [210. 240.]], shape=(3, 2), dtype=float32)

注意在定义sparse tensor的时候 indices必须是排好序的. 如果不是,定义的时候不会报错, 但是在to_dense的时候会报错:

import tensorflow as tf


s5 = tf.SparseTensor(indices = [[0, 2], [0, 1], [2, 3]],
                    values = [1., 2., 3.],
                    dense_shape = [3, 4])
print(s5)
#print(tf.sparse.to_dense(s5)) #报错:indices[1] = [0,1] is out of order. Many sparse ops require sorted indices.
# 可以通过reorder对排序, 这样to_dense就没问题了.
s6 = tf.sparse.reorder(s5)
print(tf.sparse.to_dense(s6))

结果如下:

SparseTensor(indices=tf.Tensor(
[[0 2]
 [0 1]
 [2 3]], shape=(3, 2), dtype=int64), values=tf.Tensor([1. 2. 3.], shape=(3,), dtype=float32), dense_shape=tf.Tensor([3 4], shape=(2,), dtype=int64))
tf.Tensor(
[[0. 2. 1. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 3.]], shape=(3, 4), dtype=float32)

http://www.kler.cn/a/613550.html

相关文章:

  • STM32 CAN学习(一)
  • [C++开发经验总结]何时用push?/何时用emplace?
  • 鸿蒙学习笔记(3)-像素单位、this指向问题、RelativeContainer布局、@Style装饰器和@Extend装饰器
  • 【NLP 50、损失函数 KL散度】
  • Unity 简单使用Addressables加载SpriteAtlas图集资源
  • java使用aspose添加多个图片到word
  • 3.27-1 pymysql下载及使用
  • Stable Diffusion 基础模型结构超级详解!
  • 用 pytorch 从零开始创建大语言模型(七):根据指示进行微调
  • TextGrad:案例
  • 横扫SQL面试——事件流处理(峰值统计)问题
  • SDL —— 将sdl渲染画面嵌入Qt窗口显示(附:源码)
  • CSS回顾-Flex弹性盒布局
  • Vue $bus被多次触发
  • 【WPF】ListView数据绑定
  • 【AI工具开发】Notepad++插件开发实践:从基础交互到ScintillaCall集成
  • C语言之链表
  • 分布式光伏防逆流如何实现?
  • 每日免费分享之精品wordpress主题系列~DAY16
  • 云原生四重涅槃·破镜篇:混沌工程证道心,九阳真火锻金身