import bisect
import collections
import random


class RandomGen(object):
    """Draw values from a fixed discrete probability distribution.

    ``next_num()`` returns ``random_nums[i]`` with probability
    ``probabilities[i]``.
    """

    # Absolute tolerance for the "probabilities sum to 1" check.
    _SUM_TOLERANCE = 0.0000001

    def __init__(self, random_nums, probabilities, rng=None):
        """Validate the distribution and precompute its lookup table.

        Args:
            random_nums: Values that ``next_num()`` may return (any iterable;
                it is copied into a list).
            probabilities: Probabilities parallel to ``random_nums``; must be
                non-negative and sum to 1 within ``_SUM_TOLERANCE``. Copied,
                so the caller's sequence is never modified.
            rng: Optional source of uniform floats in [0.0, 1.0) exposing a
                ``random()`` method, e.g. a seeded ``random.Random``. Defaults
                to the ``random`` module's shared generator.

        Raises:
            AssertionError: if the arguments do not describe a valid
                distribution. Raised explicitly rather than via ``assert`` so
                the checks survive ``python -O``; the type is kept because
                callers such as ``test_init`` catch AssertionError.
        """
        random_nums = list(random_nums)
        probabilities = list(probabilities)
        if not random_nums:
            raise AssertionError('Random number list is empty.')
        if len(random_nums) != len(probabilities):
            raise AssertionError(
                'Random numbers and probabilities do not match.')
        # Written as "not (x < tol)" so a NaN sum is rejected as well.
        if not abs(1. - sum(probabilities)) < self._SUM_TOLERANCE:
            raise AssertionError('Probabilities do not sum to 1.')
        if any(p < 0 for p in probabilities):
            raise AssertionError('Probabilities must not be negative.')

        # Values that may be returned by next_num().
        self._random_nums = random_nums

        # Cumulative upper bounds: a uniform draw r selects value i when
        # table[i-1] <= r < table[i] (with table[-1] read as 0).
        table = []
        running = 0
        for p in probabilities:
            running += p
            table.append(running)
        # The final bound is ~1.0. Every draw at or above the previous bound
        # maps to the last value, so that entry is dropped; this also means
        # float round-off in the final running total cannot matter.
        table.pop()
        self._range_table = table

        self._rng = random if rng is None else rng

    def next_num(self):
        """Return one of ``random_nums``, chosen with its probability.

        Over many calls the observed frequencies approach the probabilities
        the generator was built with.
        """
        r = self._rng.random()
        # bisect_right yields half-open intervals [lo, hi), matching the
        # [0.0, 1.0) range of random(): a value with probability 0 is never
        # returned, even when the draw is exactly 0.0.
        return self._random_nums[bisect.bisect_right(self._range_table, r)]


def test_init(random_nums, probabilities):
    """Check that RandomGen rejects the given (invalid) arguments.

    Raises:
        AssertionError: if RandomGen accepts the arguments.
    """
    try:
        RandomGen(random_nums, probabilities)
    except AssertionError:
        return
    raise AssertionError('Had to be assert, but no assert could be detected.')


def test(input, rep_count=1000):
    """Sample a RandomGen and print requested vs. observed frequencies.

    Args:
        input: Mapping from value to its probability. The name shadows the
            ``input`` builtin but is kept for keyword-argument compatibility.
        rep_count: Number of samples to draw.
    """
    rg = RandomGen(list(input), list(input.values()))
    counts = collections.Counter(rg.next_num() for _ in range(rep_count))
    for rnum, orig_prob in input.items():
        # float() avoids integer division on Python 2; a zero rep_count
        # reports 0.0 instead of raising ZeroDivisionError.
        stat_prob = float(counts[rnum]) / rep_count if rep_count else 0.0
        print('{}: orig: {} actual: {}'.format(rnum, orig_prob, stat_prob))
    print('')


# These script helpers take arguments and are not test-runner tests; stop
# pytest from collecting them (and failing on missing fixtures) on import.
test_init.__test__ = False
test.__test__ = False


if __name__ == '__main__':
    test_init([], [])  # Empty list of numbers not allowed
    test_init([1, 2], [1.])  # Probabilities don't match
    test_init([1, 2], [0.1, 0.1])  # Probabilities don't sum to 1.
    test_init([1, 2], [1.5, -0.5])  # Negative probabilities not allowed
    test({0: 0.5, 1: 0.25, 2: 0.25})
    test({-1: 0.01, 0: 0.3, 1: 0.58, 2: 0.1, 3: 0.01}, 100)
    test({-1: 0.01, 0: 0.3, 1: 0.58, 2: 0.1, 3: 0.01}, 100000)