Saturday, September 29, 2012

Weighted random choice

Weighted random choice lets you select a random value out of a set of values according to a distribution specified through a set of weights. Given a list, we want to pick some of its elements at random, where the chance of picking a specific element is determined by its weight. The following code contains a function that implements the weighted random choice mechanism and an example of how to use it:
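from numpy import arange, cumsum, searchsorted
from numpy.random import rand
from matplotlib.pyplot import hist, show, xticks

def weighted_pick(weights, n_picks):
    """
    Weighted random selection:
    returns n_picks random indexes.
    The chance to pick the index i
    is given by the weight weights[i].
    """
    t = cumsum(weights)  # running totals of the weights
    s = t[-1]            # total weight (the last running total)
    # draw uniform points in [0, s) and find the interval each one falls into;
    # side='right' ensures that an index with zero weight is never returned
    return searchsorted(t, rand(n_picks) * s, side='right')

# weights don't have to sum up to one
w = [0.1, 0.2, 0.5, 0.5, 1.0, 1.1, 2.0]

# picking 10000 times
picked_list = weighted_pick(w, 10000)

# plotting the histogram: one unit-wide bin centred on each index,
# normalised so that the bar heights are the observed frequencies
hist(picked_list, bins=arange(len(w) + 1) - 0.5, density=True)
xticks(range(len(w)))
show()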
The code above plots the distribution of the selected indexes:

We can observe that the chance to pick the element i is proportional to the weight w[i]; more precisely, it is w[i]/sum(w).
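The same cumulative-sum-plus-binary-search idea can also be written with the standard library alone, which makes the mechanism easy to follow: each index i owns the interval between the running totals before and after w[i], so a uniform point in [0, sum(w)) lands in that interval with probability w[i]/sum(w). What follows is a minimal sketch, assuming Python 3.6+; `weighted_pick_stdlib` is a hypothetical name, not part of the original post, and the loop at the end is only a rough empirical check of the claim above.

```python
import random
from bisect import bisect_right
from itertools import accumulate


def weighted_pick_stdlib(weights, n_picks, rng=random):
    """Return n_picks indexes; index i is chosen with probability weights[i] / sum(weights).

    Hypothetical stdlib-only variant of weighted_pick (not from the original post).
    """
    cumulative = list(accumulate(weights))  # running totals, like numpy.cumsum
    total = cumulative[-1]                  # total weight (last running total)
    picks = []
    for _ in range(n_picks):
        r = rng.random() * total            # uniform point in [0, total)
        # bisect_right returns the first position whose running total is
        # strictly greater than r, so index i covers [cumulative[i-1], cumulative[i])
        picks.append(bisect_right(cumulative, r))
    return picks


w = [0.1, 0.2, 0.5, 0.5, 1.0, 1.1, 2.0]
rng = random.Random(0)  # fixed seed so the check is repeatable
picked = weighted_pick_stdlib(w, 100000, rng)

# compare the observed frequencies with the expected w[i] / sum(w)
total = sum(w)
for i, wi in enumerate(w):
    observed = picked.count(i) / len(picked)
    print(f"index {i}: expected {wi / total:.3f}, observed {observed:.3f}")
```

With 100000 picks the observed frequencies should land close to the expected ones. Since Python 3.6 the standard library also provides `random.choices(population, weights=..., k=...)`, which performs weighted selection directly and likewise accepts weights that do not sum to one.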


  1. Thanks for sharing.

  2. Cool stuff, thanks.

    I'm learning Python's list and dict methods and found your article very helpful and concise.

    What if you wanted the weights to add up to 1, how does that change the code?

    That would help me extend your Matplotlib code to produce a boxplot for each index on a single chart.

    Thanks in advance for your help.


    1. Hi :) You can use a list of weights that add up to 1; the function works unchanged, since it divides by the total weight implicitly.

    2. This will convert a list with weights which do not add up to 1, to a list of weights which do.
      It uses a list comprehension, which might look scary at first, but they are very fast and useful!

      weights = [1, 2, 3, 4, 5]
      sumOfWeights = sum(weights)
      probabilities = [(i/float(sumOfWeights)) for i in weights]

      >>> probabilities

      >>> sum(probabilities)

      The list comprehension can be translated to a normal loop, for better understanding:

      [(i/float(sumOfWeights)) for i in weights]

      l = []
      for i in weights: