(leetcode算法题)528. 按权重随机选择
测试阶段将会反复调用多次成员函数,期望的结果是调用多次后满足:
第 i 个下标被返回的概率是 w[i] / Σw[j],
可以利用rand() % Σw[j] 可以等概率的得到[0, Σw[j] - 1] 这个区间内的整数,且取到其中任意一个整数的概率是 1 / Σw[j]。
那么可以很容易的想到,将一个区间分成Σw[j] 段,取到每段的概率都是1 / Σw[j],
对于第 i 个下标来说,将区间中的 w[i] 个段分配给这第 i 个下标,那么取到第 i 个下标对应的那些段的概率就是w[i] / Σw[j],
Example.
假设 权重数组为[2, 5, 3]
那么可以生成如下示意图,总共有10段,用rand() % 10 获得[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]中的一个随机数,获得其中每个数的概率都是1 / 10,
则取到红色的概率是2 / 10,取到绿色的概率是5 / 10,取到黄色的概率是 3 / 10
那么用代码怎么实现这一点?下面分为两部分伪代码分别说明求前缀和 和 查找pos这两步
求前缀和:
思想:使用前缀和数组让[0, 1]对应红色,[2, 6]对应绿色,[7, 9]对应黄色
[2, 5, 3]的前缀和数组为[2, 7, 10] 2就是红色的末尾,7就是绿色的末尾,9就是黄色的末尾
记[2, 7, 10]为presum
伪代码:
for(i 从 1 到 n - 1){ nums[i] = nums[i - 1] + nums[i]; }
好了好了,上代码!求前缀和数组代码如下:
partial_sum(_weight.begin(), _weight.end(), _weight.begin());
对于partial_sum接口的说明
请看链接std::partial_sum - cppreference.com
template< class InputIt, class OutputIt >
OutputIt partial_sum( InputIt first, InputIt last, OutputIt d_first);
好了好了,那怎么通过x找到其对应的颜色?x是通过rand() % 10获得的一个随机数
查找 x 对应的 pos:
思想:if(x小于2){x对应的是红色,即下标0}
else if(x大于等于2且小于7){x对应的是绿色,即下标1}
else{x大于等于7且小于10}{x对应的是黄色,即下标2}
上面的 if else 逻辑本质上是在找第一个pos,这个 pos 满足下列条件
for(i 从 0 到 n - 1)
if((pos == 0 || x >= presum[pos - 1]) && x < presum[pos])
return pos
然而对于上述的代码有进一步的优化,
我们注意到正数数组对应的前缀和数组是严格递增数组
所以可以使用二分查找算法以 O(logn) 的时间复杂度找到 pos
好了好了,上代码!查找 pos 的代码如下
int sum = _weight.back();
int randnum = rand() % sum + 1;
return lower_bound(_weight.begin(), _weight.end(), randnum) - _weight.begin();
对于lower_bound() 这个接口的说明
具体说明请看链接std::lower_bound - cppreference.com
template< class ForwardIt, class T = typename std::iterator_traits<ForwardIt>::value_type >
constexpr ForwardIt lower_bound( ForwardIt first, ForwardIt last, const T& value );对于一个严格升序的数组nums,给定一个目标值 x,从左到右找到第一个pos,
这个pos满足 nums[pos] >= x,返回指向pos的迭代器。如果一直找不到,返回nums.end()
注意:
查找pos的代码中 randnum加上 1 就是为了保证找的nums[pos]不是 >= x,而是大于x
好了好了,整合一下代码
class Solution {
vector<int> _weight;
public:
Solution(vector<int> w) {
_weight = move(w);
partial_sum(_weight.begin(), _weight.end(), _weight.begin());
}
int pickIndex() {
int sum = _weight.back();
int randnum = rand() % sum + 1;
return lower_bound(_weight.begin(), _weight.end(), randnum) - _weight.begin();
}
};
/**
* Your Solution object will be instantiated and called as such:
* Solution* obj = new Solution(w);
* int param_1 = obj->pickIndex();
*/