00001
00002
00003
00004 from __future__ import division
00005 from random import random
00006 from copy import copy
00007
00008
# Tolerance for the internal consistency checks in
# NonUniformRandomInt.calc_tables, absorbing floating-point rounding.
FLOAT_ERROR = 0.000001


class NonUniformRandomInt:
    """Draw random integers in ``range(n)`` with arbitrary relative weights.

    The constructor builds the two tables of Walker's alias method once.
    Each call to :meth:`random` then costs O(1): it picks one of ``n``
    equally likely cells and returns either the cell's own index or that
    cell's alias ("overflow"), depending on a per-cell threshold.

    Attributes:
        up_bound: number of outcomes ``n``; results lie in ``range(n)``.
        probs: per-cell threshold (at most ``1 + FLOAT_ERROR``). The cell
            index is returned when the fractional part of the scaled
            uniform draw is below it.
        overflow: per-cell alias, returned otherwise.
    """

    def __init__(self, relative_probs_list):
        """Build the sampling tables.

        Args:
            relative_probs_list: iterable of non-negative weights, one per
                outcome. They are normalised by their sum, so they need not
                add up to 1.

        Raises:
            ValueError: if there are no weights or any weight is negative.
            ZeroDivisionError: if all weights are zero.
        """
        # Materialise once: the weights are traversed several times, which
        # would silently see an empty sequence for a one-shot iterator.
        weights = list(relative_probs_list)
        if len(weights) == 0:
            raise ValueError("relative_probs_list must not be empty")
        if any(weight < 0 for weight in weights):
            raise ValueError("relative_probs_list must not contain negative weights")
        absolute_prob_list = self.make_absolute_probs_list(weights)
        self.up_bound = len(absolute_prob_list)
        self.probs, self.overflow = self.calc_tables(absolute_prob_list)

    def make_absolute_probs_list(self, relative_probs_list):
        """Return the weights divided by their sum, so that they sum to 1."""
        total = sum(relative_probs_list)
        return [relative_prob / total for relative_prob in relative_probs_list]

    def calc_tables(self, probs_list):
        """Build the alias-method tables for normalised probabilities.

        Each pass pairs the least likely remaining outcome (probability at
        most 1/n) with the most likely one (at least 1/n). The smaller one
        fills its own cell up to its probability, and the larger one tops up
        the rest of that cell. The larger outcome's leftover probability is
        then re-inserted in sorted position. Relies on
        ``self.up_bound == len(probs_list)``.

        Returns:
            ``(probs, overflow)``: per-cell thresholds and aliases.
        """
        n = self.up_bound
        # (outcome, remaining probability) pairs, kept sorted by probability.
        # Use key= rather than a cmp function: positional cmp was removed in
        # Python 3. The sort is stable, so ties keep index order.
        distribution = sorted(enumerate(probs_list), key=lambda pair: pair[1])

        probs = [0] * n
        overflow = [0] * n

        for _ in range(n):
            cell_mass = 1 / n

            min_val, min_prob = distribution[0]
            assert min_prob <= cell_mass + FLOAT_ERROR

            max_val, max_prob = distribution[-1]
            assert max_prob >= cell_mass - FLOAT_ERROR

            probs[min_val] = n * min_prob
            assert probs[min_val] <= 1 + FLOAT_ERROR
            overflow[min_val] = max_val

            # Drop both ends. On the final pass they are the same entry,
            # so only one deletion happens.
            del distribution[0]
            if distribution:
                del distribution[-1]

            # What is left of the large outcome after topping up this cell.
            remaining_prob = max_prob + min_prob - cell_mass

            # Linear search for the first entry not smaller than the
            # leftover, keeping the list sorted.
            i = 0
            while i < len(distribution) and distribution[i][1] < remaining_prob:
                i += 1
            distribution.insert(i, (max_val, remaining_prob))

        return probs, overflow

    def random(self):
        """Return an outcome in ``range(self.up_bound)`` drawn by weight."""
        uniform = random() * self.up_bound
        cell = int(uniform)
        # The fractional part is itself uniform in [0, 1). It decides between
        # the cell's own outcome and its alias.
        if uniform - cell < self.probs[cell]:
            return cell
        return self.overflow[cell]
00108
00109
def calc_expected(probs_list):
    """Return the mean outcome index of a weighted distribution.

    ``probs_list[i]`` is the relative weight of outcome ``i``. The weights
    are divided by their sum, so they need not add up to 1. An empty list
    yields 0.
    """
    weight_sum = sum(probs_list)
    return sum(weight * index / weight_sum
               for index, weight in enumerate(probs_list))
00116
00117
def demo_NonUniformRandomInt():
    """Sample NonUniformRandomInt and print empirical vs. exact statistics.

    Prints, on one line: the empirical mean of the draws, the exact mean
    from calc_expected, and the empirical frequency of each outcome.
    """
    probs_list = [3/18, 7/18, 8/18]
    r = NonUniformRandomInt(probs_list)
    L = 100000
    # One counter per outcome, sized from the weights rather than hard-coded.
    counts = [0] * len(probs_list)
    s = 0

    for _ in range(L):
        n = r.random()
        counts[n] += 1
        s += n

    # Divide once at the end; adding 1/L a hundred thousand times
    # accumulates rounding error.
    frequencies = [count / L for count in counts]
    # A single-argument print(...) prints the same text on Python 2 and 3.
    print("%s %s %s" % (s / L, calc_expected(probs_list), frequencies))
00131
# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    demo_NonUniformRandomInt()
00133