最小堆 + 数学思维(重点) + 快速幂
首先是这个题目的问题描述:
给你一个整数数组 nums ,一个整数 k 和一个整数 multiplier 。
你需要对 nums 执行 k 次操作,每次操作中:
找到 nums 中的 最小 值 x ,如果存在多个最小值,选择最 前面 的一个。
将 x 替换为 x * multiplier 。
k 次操作以后,你需要将 nums 中每一个数值对 10^9 + 7 取余。
请你返回执行完 k 次乘运算以及取余运算之后,最终的 nums 数组。
这题整体的思路就是实现最直观的思路就是枚举这个k次操作,然后每次找到数组里面最小的值,然后做修改即可,但是这题的k的范围在1e9,显然直接枚举k是不合理的,显然是一个超时的算法,那么我们如何做到和这个k无关的算法讨论呢?这里其实会用到一些数学的结论推导,这里从灵神那里学来的思路,个人觉得很好,做一下记录,是分析问题的思路,不仅限于这道题,最基本的核心思想就是从简单到复杂,我们从最少的数字开始枚举,然后看一下这之间存在什么样的一个可循关系和规律呢,对于这个题目而言,我们这个算法一定不能从k去暴力枚举,那么入手点其实就是肯定是和k存在一定关系的规律性结论,设计一个和k无关的时间复杂度的算法,那么这个算法的出发点我们从下面考虑:
假如我们只有两个数字[1,3] 然后mul = 2,那么整个过程就是这样的:
首先:1 ->2
然后:2 ->4
接着:3 ->6
再者:4 ->8
继续:6->12
接着:8 -> 16…
其实这里是有一个简单的规律的,我们可以发现当这个1,第一次超过这个3之后,也就是超过这个最大值之后,其实这里的后面的循环就是在 这里假设我们超过这3之后得到的一组数据位[4,3],其实可以发现一个现象就是 这个4和3在轮流的进行相乘,其实这个可以去用理论去证明的,这里我们假设一个x < y,那么我们在x * mul 之后他超过 y了,那么下一次我们一定是用y * mul ,那么在下一次我们一定是使用x进行相乘,因为x*mul也一定小于y * mul,所以按照这个理论我们可以得到,如果是3个数据我们在推导一下:
[1,3,7] mul = 2
还是一样的:
首先:1->2
然后:2->4
接着:3->6
接着:4->8
接着:6->12
然后:7->14
接着:8->16
接着:12->24
接着:14->28
接着:16->32
接着:24->48
…
其实到这里大家也就可以看出来了,对于我们和刚才分析一样的原则只是这里我们的数字变成了三个进行循环,也就是x < y < z,那么依次就对 x y z进行mul操作得到我们的结果,但是可以注意到这样的循环开始的条件一定是我们的所有数据都超过maxx(数组最大值之后才开始循环的)这个其实可以理解,因为当其他数字第一次经过mul操作超过这个maxx是最接近的情况,这样我们在之后的mul操作就可以严格保证按照这样的顺序进行。
所以这题的思路到这里就打开了:
1.第一步我们的处理肯定是将数组中的所有数据经过mul操作,每次去最小值去和mul进行相乘操作,知道我们去到的数据下标为我们的maxx的下标,也就是我们之前数组里面的最大值的下标(这里注意要是最后一个 最大值的下标) (这里可以直接使用最小堆来处理,每次取出对顶元素,记录我们的操作次数,对于k进行相减即可)
2.然后就可以按照我们刚才的思想对于每个元素需要进行多少次mul操作进行计算即可,因为是循环处理嘛,这里假设我们剩下的处理次数为 kk(上面经过第一次处理之后 还需要mul的次数),只需要判断一下当前下标是否小于 kk % n 如果小于 就说明需要 k / n + 1次,否则就是 k / n 次
3.这里的第三个点就是k很大,所以需要快速幂进行指数求和,然后取余,这里不在多说(快速幂大家可以自己学习一下)
4.这里还有一个需要注意的点,就是我们的这个mul如果是1的话需要特殊处理,因为如果是1,就会超时间(假设 k > 1e8),所以直接返回nums就行
具体实现代码(C++版本如下):
class Solution {
public:
vector<int> getFinalState(vector<int>& nums, int k, int multiplier) {
const int mod = 1e9 + 7;
int n = nums.size();
int pos = 0;
if(multiplier == 1){//这里是特殊情况1需要直接输出 否则超时间
return nums;
}
for(int i = 1;i < n; i++){
if(nums[i] >= nums[pos]){
pos = i;//找到最大元素的位置,判断第一次循环的终止位置
}
}
cout << "pos == " << pos << endl;
vector<long long>tmp;
priority_queue<pair<long long,int>, vector<pair<long long, int>> ,greater<pair<long long,int>>>pq;
typedef pair<long long, int> pli;
for(int i = 0;i < n; i++){
pq.push({1ll * nums[i], i});
tmp.push_back(nums[i]);/
}
while (k > 0) {
k--;
pli p = pq.top(); pq.pop();
p.first *= multiplier; pq.push(p);
tmp[p.second] = tmp[p.second] * multiplier % mod; //保存处理之后的数组
if (p.second == pos) break;//到达maxx最大的pos
}
// cout << kk << endl;
auto poww = [&](long long a, long long b) {
long long y = 1;
for (; b; b >>= 1) {
if (b & 1) y = y * a % mod;
a = a * a % mod;
}
return y;
};
cout << k << endl;
for (int i = 0; i < n; i++) {
int idx = pq.top().second; pq.pop();
tmp[idx] = tmp[idx] * poww(multiplier, k / n + (i < (k % n) ? 1 : 0)) % mod;//这里是针对每个循环元素的次数,进行快速幂相乘即可,结果都要记得取mod
}
vector<int> ans;
for (auto x : tmp) ans.push_back(x);
return ans;
}
};