基于权重(别名采样alias sample) 的随机选择算法
转载用于收藏学习,尊重原创
,
算法原理讲解,原创作者写的很通俗易懂,感恩 !
当我们得到一个概率分布,如何根据这个概率分布抽样是一个常见的问题。这篇文章将介绍alias method(别名采样),这种算法的时间复杂度为O(1),当然还需要复杂度为O(n)的预处理。下面我将通过一个例子介绍别名采样算法。
问题背景
假设一共存在A,B,C,D四种情况,它们出现的概率分别为 0.3,0.1,0.1,0.5。如何实现按概率抽样呢?
比较常用的一种方法是生成一个数组:1,1,1,2,3,4,4,4,4,4,其中1对应A,2对应B,以此类推,然后随机在数组中抽取一个即可。这种方法简单易实现,但是这是在仅仅有4种情况的条件下。当情况变多,这种方法就会占用很大的空间了,所以并不适用于大规模的通用情况。
另外,可以根据它们的概率密度分布生成累积分布:0.3,0.4,0.5,1。然后生成一个0-1之间的随机数,看它落在哪个区间。然而,这时需要与临界点进行比较。我们知道,插入有序数列最好的算法的时间复杂度为O(logn),所以这种方法复杂度较大。
我们这篇文章提到的alias method可以实现以运行复杂度为O(1)的方式抽样。当然它需要预处理,预处理的时间复杂度为O(n),但是重复跑的时候,运行时间复杂度低才是重要的。
别名采样算法
首先,介绍在等概率分布和二项分布这两种模型中的抽样方法。
我们知道等概率分布抽样的时间复杂度为O(1),考虑一种情况,如果四种情况A,B,C,D出现的概率均为0.25,我们用1代表A,2代表B,以此类推,那我们只需要在1~4里随机产生一个整数,抽中哪个就是哪个,复杂度自然为O(1),这是等概率分布抽样的情况。
我们知道对符合二项分布的模型进行抽样的时间复杂度也为O(1),因为如果只有两种情况A,B时,设他们出现的概率分别为 0.2,0.8,我们用累积分布的方法,小于0.2就是A, 大于就是B,只需比较一次,所以复杂度也是O(1),这是二项分布抽样的情况。
alias method就是把这两种方法结合起来。
仍然以本文一开始提出的例子为例。本文一开始的例子中,假设一共存在A,B,C,D四种情况,它们出现的概率分别为 0.3,0.1,0.1,0.5,我们称为原始分布。在下图中,我们用绿色代表A,蓝色代表B,紫色代表C,橙色代表D,将原始分布示意如下:

首先我们把原概率分布乘以N(有几种情况,N就等于几),这里是N=4。得到A、B、C、D四种情况的值分别是1.2,0.4,0.4,2.0,如图所示。

然后我们把它们拼成等概率分布和二项分布:

注意拼接的过程中,每一列的值最大为1,且每一列最多有两种情况。这样对于这四列,他们被抽中的概率均为1,服从四种情况的等概率分布;对于每一列,只有最多两种情况,都符合二项分布。
做完以上处理后,我们就可以开始抽样了。首先第一步,我们以等概率分布抽四列中的一列。然后第二步,生成一个0-1之间的随机数,在第一步抽中的列里继续抽样。
举例来说,例如我们首先抽中了第四列(概率为0.25)。然后在第四列中进行二项分布抽样,如果小于0.8,是橙色,代表A,反之,就是绿色,代表D。
这两步操作的复杂度均为O(1),故总时间复杂度也为O(1)。
那么这样抽样是正确的吗,是否服从原来的概率分布呢?换句话说,在原概率分布中抽到A的概率是0.3,那使用alias方法抽到的概率还是0.3吗?
我们以抽取A情况为例。原来抽中A的概率为0.3。运用alias method方法后,抽中a的概率为抽中第一列的概率+抽中第四列且随机数小于0.2的概率,算起来为0.25+(0.25 * 0.2) = 0.3,所以完全一样。
lua 实现算法
-- 别名方法 local function prepare_weighted_random4(values, weights) assert(#values == #weights) local tinsert = table.insert local ipairs = ipairs local count = #weights local sum = 0 -- 计算总和 for _, w in ipairs(weights) do sum = sum + w end local avg = sum / count -- 平均权重 local aliases = {} -- 别名表 for _, _ in ipairs(weights) do tinsert(aliases, {1, false}) end local sidx = 1 -- 找到第1个小于平均值的索引 while sidx <= count and weights[sidx] >= avg do sidx = sidx + 1 end if sidx <= count then -- 如果 small_i > count 表示所有权重值相等,什么也不用处理 local small = {sidx, weights[sidx] / avg} local bidx = 1 -- 找到第1个大于等于平均值的索引 while bidx <= count and weights[bidx] < avg do bidx = bidx + 1 end local big = {bidx, weights[bidx] / avg} while true do aliases[small[1]] = {small[2], big[1]} -- 桶的索引即是小权重的索引,从中去掉小权重的比例,剩余的放大权重 big = {big[1], big[2] - (1-small[2])} -- 大权重去掉已放入的比例 if big[2] < 1 then -- 如果大权重剩余的比例已小于1,表示小于平均权重 small = big -- 大权重变成小权重 bidx = bidx + 1 -- 找下一个大权重的索引 while bidx <= count and weights[bidx] < avg do bidx = bidx + 1 end if bidx > count then break end big = {bidx, weights[bidx] / avg} -- 得到下一个大权重 else -- 大权重的比例大于等于1,表示不比平均权重小,继续找小权重 sidx = sidx + 1 -- 找下一个小权重索引 while sidx <= count and weights[sidx] >= avg do sidx = sidx + 1 end if sidx > count then break end small = {sidx, weights[sidx] / avg} end end end return function() local n = math.random() * count local i = math.floor(n) local odds, alias = aliases[i+1][1], aliases[i+1][2] -- 小权重比例,大权重索引 local idx if n - i > odds then idx = alias else idx = i + 1 end return values[idx], weights[idx] end end 代码有点复杂,不过看那个返回的函数,就知道有多快。算法思路是这样的:
- 将权重总和切成N个桶,N就是weights的数量,桶的大小就是平均权重。
- 从weights中得到一个小于平均权重的列表,和一个大于等于平均权重的列表。
- 取出一个小权重放入桶中,桶还有一点空间用来放一个大权重的一部分。
- 一直重复这个过程,直到桶都填完,最终得到aliases这个数据结构。
- aliases的索引是小权重的索引,aliases的每个元素由两个值组成:第一个是小权重占的比例,第二个是大权重的索引。