数据结构 ——— 用堆解决TOP-K问题
目录
何为TOP-K问题
用堆解决TOP-K问题
代码实现
何为TOP-K问题
比如:整个专业的前10名,世界500强,富豪榜,游戏中前100的活跃玩家等
对于 TOP-K 问题,能想到的最简单直接的方式就是排序
但是,如果数据量非常大,排序就不太可取了(可能数据都不能一下子全部加载到内存中)
最佳就是用堆来解决
用堆解决TOP-K问题
1. 用数据集合中前K个元素来建堆
- 找前K个最大的元素时,就建小堆
- 找前K个最小的元素时,就建大堆
2. 用剩余的N-K个元素依次与堆顶元素来比较,不满足则替换堆顶元素
- 将剩余N-K个元素依次与堆顶元素比完之后
- 堆中剩余的K个元素就是所求的前K个最大或者最小的元素
代码实现
生成一个数组,并且随机放入 10 个最大的值:
void TestTopK()
{
// 动态申请 10000 个 int 类型的数组
int n = 10000;
int* a = (int*)malloc(sizeof(int) * n);
// 判断是否申请成功
if (a == NULL)
{
perror("malloc fail");
return;
}
// 随机值生成器
srand((unsigned int)time(NULL));
// 在数组中依次存放小于 100000 的值
for (int i = 0; i < n; i++)
{
a[i] = rand() % 100000;
}
// 将数组中的 10 个元素改为大于或者等于 100000 的值
a[5] = 100000 + 1;
a[123] = 100000 + 9;
a[531] = 100000 + 2;
a[4121] = 100000 + 8;
a[115] = 100000 + 3;
a[2335] = 100000 + 7;
a[9999] = 100000 + 4;
a[76] = 100000 + 6;
a[423] = 100000 + 5;
a[3144] = 100000 + 0;
// 找出数组 a 中前 10 个最大的值,并打印
PrintTopK(a, n, 10);
}
找出前 10 个最大的值:
void PrintTopK(int* a, int size, int k)
{
for (int i = (k - 1 - 1) / 2; i >= 0; i--)
{
// 向下调整建堆(建小堆)
AdjustDown(a, k, i);
}
for (int i = k; i < size; i++)
{
// 当前堆顶元素小于当前数组元素时就交换
if (a[0] < a[i])
{
Swap(&a[0], &a[i]);
// 向下调整堆
AdjustDown(a, k, 0);
}
}
ArrPrint(a, k);
}
代码解析(代码中的函数实现会放在最后,先讲解思路):
想要找到数组 a 中的前 10 个最大的数,那么就先建立大小为 10 的小堆
利用向下调整算法对数组 a 中的前 10 个数进行建堆,注意是建小堆
再把数组中剩余的元素依次与堆顶元素比较,当堆顶元素小于数组当前元素时,就交换
因为小堆的特点是:堆顶的元素是整个堆中元素最小的
那么堆中最小元素和数组当前元素比较时,还要小,那么堆顶元素必然不是前 10 个最大的数
交换后再利用向下调整算法调整堆
遍历完数组 a 中的所有元素后,堆中的元素就是前 10 个最大的数
代码验证:
100000
/ \
100002 100001
/ \ / \
100003 100005 100004 100006
/ \ /
100007 100009 100008
代码中的函数实现:
// 向下调整(默认小堆)
void AdjustDown(HPDataType* a, int size, int parent)
{
int child = parent * 2 + 1;
while (child < size)
{
// 先找到左右孩子中小的那个
if ((child + 1 < size) && (a[child + 1] < a[child]))
child++;
if (a[parent] > a[child])
{
// 交换
Swap(&a[parent], &a[child]);
// 迭代
parent = child;
child = parent * 2 + 1;
}
else
{
break;
}
}
}
// 交换
void Swap(HPDataType* p1, HPDataType* p2)
{
HPDataType tmp = *p1;
*p1 = *p2;
*p2 = tmp;
}
// 打印
void ArrPrint(int* a, int size)
{
for (int i = 0; i < size; i++)
{
printf("%d ", a[i]);
}
printf("\n");
}