Skip to content

Commit

Permalink
update tests, fix probs, add matching status
Browse files Browse the repository at this point in the history
  • Loading branch information
bridwell committed Nov 14, 2016
1 parent 330063f commit 2643fcf
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 85 deletions.
146 changes: 67 additions & 79 deletions urbansim/utils/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
def get_probs(data, prob_column=None):
"""
Checks for presence of a probability column and returns the result
as a numpy array. If the probabilities are weights (i.e. they don't)
sum to 1, then this will be recalculated.
as a numpy array. If the probabilities are weights (i.e. they don't
sum to 1), then this will be recalculated.
Parameters
----------
Expand All @@ -24,14 +24,15 @@ def get_probs(data, prob_column=None):
if prob_column is None:
p = None
else:
p = data[prob_column].values
p = data[prob_column].fillna(0).values
if p.sum() == 0:
p = np.ones(len(p))
if round(p.sum(), 0) != 1:
p = p / (1.0 * p.sum())
return p


def accounting_sample_replace(total, data, accounting_column,
prob_column=None, exact=True, max_iterations=50):
def accounting_sample_replace(total, data, accounting_column, prob_column=None, max_iterations=50):
"""
Sample rows with accounting with replacement.
Expand All @@ -46,9 +47,6 @@ def accounting_sample_replace(total, data, accounting_column,
If not provided then row counts will be used for accounting.
prob_column: string, optional, default None
Name of the column in the data to provide probabilities or weights.
exact: bool, optional, default True
If True, will attempt to match the total exactly. Otherwise it will be an
approximation.
max_iterations: int, optional, default 50
When using an accounting attribute, the maximum number of sampling iterations
that will be applied.
Expand All @@ -57,7 +55,8 @@ def accounting_sample_replace(total, data, accounting_column,
-------
sample_rows : pandas.DataFrame
Table containing the sample.
matched: bool
Indicates if the total was matched exactly.
"""

# check for probabilities
Expand All @@ -71,31 +70,26 @@ def accounting_sample_replace(total, data, accounting_column,
sample_rows = pd.DataFrame()
closest = None
closest_remain = total
matched = False

for i in range(0, max_iterations):

# stop if we've hit the control
if remaining == 0:
break

# stop after the 1st iteration if we're approximating w/out probs
if (not exact) and i == 1 and p is None:
break

# stop afer the 2nd iteration if we're approximating w/ probs
if (not exact) and i == 2 and p is not None:
matched = True
break

# if sampling with probabilities, re-calc the # of items per sample
# after the initial sample, this way the sample size reflects the probabilities
if p is not None and i == 1:
per_sample = sample_rows[accounting_column].sum / (1.0 * len(sample_rows))
per_sample = sample_rows[accounting_column].sum() / (1.0 * len(sample_rows))

# update the sample
num_samples = int(math.ceil(math.fabs(remaining) / per_sample))

if remaining > 0:
# we're short, add to the sample
print p
curr_ids = np.random.choice(data.index.values, num_samples, p=p)
sample_rows = pd.concat([sample_rows, data.loc[curr_ids]])
else:
Expand All @@ -110,11 +104,10 @@ def accounting_sample_replace(total, data, accounting_column,
closest_remain = abs(remaining)
closest = sample_rows

return closest
return closest, matched


def accounting_sample_no_replace(total, data, accounting_column,
prob_column=None, exact=True, max_iterations=50):
def accounting_sample_no_replace(total, data, accounting_column, prob_column=None):
"""
Samples rows with accounting without replacement.
Expand All @@ -131,14 +124,13 @@ def accounting_sample_no_replace(total, data, accounting_column,
exact: bool, optional, default True
If True, will attempt to match the total exactly. Otherwise it will be an
approximation.
max_iterations: int, optional, default 50
When using an accounting attribute, the maximum number of sampling iterations
that will be applied.
Returns
-------
sample_rows : pandas.DataFrame
Table containing the sample.
matched: bool
Indicates if the total was matched exactly.
"""

Expand All @@ -148,62 +140,47 @@ def accounting_sample_no_replace(total, data, accounting_column,

# check for probabilities
p = get_probs(data, prob_column)
print p

closest = None
closest_shortage = total

for i in range(0, max_iterations):

print i

# shuffle the rows
if p is None:
# random shuffle
shuff_idx = np.random.permutation(data.index.values)
else:
# weighted shuffle
shuff_idx = np.random.choice(data.index.values, len(data), replace=False, p=p)

# get the initial sample
shuffle = data.loc[shuff_idx]
csum = np.cumsum(shuffle[accounting_column].values)
pos = np.searchsorted(csum, total, 'right')
sample = shuffle.iloc[:pos]

# if we're just approximating we're done
if not exact:
return sample.copy()

# refine the sample
sample_idx = sample.index.values
sample_total = sample[accounting_column].sum()
shortage = total - sample_total

for idx, row in shuffle.iloc[pos:].iterrows():
if shortage == 0:
# we've matached
break

# add the current element if it doesnt exceed the total
cnt = row[accounting_column]
if cnt <= shortage:
sample_idx = np.append(sample_idx, idx)
shortage -= cnt

# we've looped through all the elements, compare with other iterations
# shuffle the rows
if p is None:
# random shuffle
shuff_idx = np.random.permutation(data.index.values)
else:
# weighted shuffle
ran_p = pd.Series(np.power(np.random.rand(len(p)), 1.0 / p), index=data.index)
ran_p.sort(ascending=False)
shuff_idx = ran_p.index.values

# get the initial sample
shuffle = data.loc[shuff_idx]
csum = np.cumsum(shuffle[accounting_column].values)
pos = np.searchsorted(csum, total, 'right')
sample = shuffle.iloc[:pos]

# refine the sample
sample_idx = sample.index.values
sample_total = sample[accounting_column].sum()
shortage = total - sample_total
matched = False

for idx, row in shuffle.iloc[pos:].iterrows():
if shortage == 0:
closest = shuffle.loc[sample_idx].copy()
# we've matched
matched = True
break
else:
if abs(shortage) < closest_shortage:
closest = shuffle.loc[sample_idx].copy()
closest_shortage = shortage

return closest
# add the current element if it doesn't exceed the total
cnt = row[accounting_column]
if cnt <= shortage:
sample_idx = np.append(sample_idx, idx)
shortage -= cnt

return shuffle.loc[sample_idx].copy(), matched


def sample_rows(total, data, replace=True, accounting_column=None,
max_iterations=50, prob_column=None, exact=True):
max_iterations=50, prob_column=None, return_status=False):
"""
Samples and returns rows from a data frame while matching a desired control total. The total may
represent a simple row count or may attempt to match a sum/quantity from an accounting column.
Expand All @@ -221,7 +198,9 @@ def sample_rows(total, data, replace=True, accounting_column=None,
If not provided then row counts will be used for accounting.
max_iterations: int, optional, default 50
When using an accounting attribute, the maximum number of sampling iterations
that will be applied.
that will be applied. Only applicable when sampling with replacement.
return_status: bool, optional, default False
If True, will also return a bool indicating if the total was matched exactly.
Returns
-------
Expand All @@ -233,11 +212,20 @@ def sample_rows(total, data, replace=True, accounting_column=None,
if accounting_column is None:
if replace is False and total > len(data.index.values):
raise ValueError('Control total exceeds the available samples')
return data.loc[np.random.choice(data.index.values, total, replace=replace)].copy()
rows = data.loc[np.random.choice(data.index.values, total, replace=replace)].copy()
matched = True

if replace:
func = accounting_sample_replace
# sample with accounting
else:
func = accounting_sample_no_replace
if replace:
rows, matched = accounting_sample_replace(
total, data, accounting_column, prob_column, max_iterations)
else:
rows, matched = accounting_sample_no_replace(
total, data, accounting_column, prob_column)

return func(total, data, accounting_column, prob_column, exact, max_iterations)
# return the results
if return_status:
return rows, matched
else:
return rows
47 changes: 41 additions & 6 deletions urbansim/utils/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,21 @@
import pandas as pd
import pytest

from urbansim.utils.sampling import sample_rows
from urbansim.utils.sampling import sample_rows, get_probs


def test_get_probs():
    """Check get_probs with no column, zero weights, probabilities and raw weights."""
    frame = pd.DataFrame({
        'a': np.zeros(4),
        'b': [0.25, 0.25, 0.25, 0.25],
        'c': np.ones(4),
    })

    # without a probability column there is nothing to return
    assert get_probs(frame) is None

    # all-zero weights ('a'), ready-made probabilities ('b') and
    # un-normalised weights ('c') are all expected to yield a uniform result
    uniform = [0.25] * 4
    for column in ('a', 'b', 'c'):
        assert (get_probs(frame, column) == uniform).all()


@pytest.fixture(scope='function')
Expand All @@ -18,10 +32,14 @@ def fin():
np.random.set_state(old_state)

request.addfinalizer(fin)
np.random.seed(1)
np.random.seed(123)
return pd.DataFrame(
{'some_count': np.random.randint(1, 8, 20)},
index=range(0, 20))
{
'some_count': np.random.randint(1, 8, 20),
'p': np.arange(20)
},
index=range(0, 20)
)


def test_no_accounting_with_replacment(random_df):
Expand All @@ -45,14 +63,31 @@ def test_no_accounting_no_replacment_raises(random_df):

def test_accounting_with_replacment(random_df):
control = 10
rows = sample_rows(control, random_df, accounting_column='some_count')

rows, matched = sample_rows(
control, random_df, accounting_column='some_count', return_status=True)
assert control == rows['some_count'].sum()
assert matched

# test with probabilities
rows, matched = sample_rows(
control, random_df, accounting_column='some_count', prob_column='p', return_status=True)
assert control == rows['some_count'].sum()
assert matched


def test_accounting_no_replacment(random_df):
control = 10
rows = sample_rows(control, random_df, accounting_column='some_count', replace=False)
rows, matched = sample_rows(
control, random_df, accounting_column='some_count', replace=False, return_status=True)
assert control == rows['some_count'].sum()
assert matched

# test with probabilities
rows, matched = sample_rows(control, random_df, accounting_column='some_count',
replace=False, prob_column='p', return_status=True)
assert control == rows['some_count'].sum()
assert matched


def test_accounting_no_replacment_raises(random_df):
Expand Down

0 comments on commit 2643fcf

Please sign in to comment.