diff --git a/lifetimes/utils.py b/lifetimes/utils.py index 4e49fa73..91cc83ed 100644 --- a/lifetimes/utils.py +++ b/lifetimes/utils.py @@ -138,6 +138,9 @@ def _find_first_transactions(transactions, customer_id_col, datetime_col, moneta if observation_period_end is None: observation_period_end = transactions[datetime_col].max() + if isinstance(observation_period_end, pd.Period): + observation_period_end = observation_period_end.to_timestamp() + select_columns = [customer_id_col, datetime_col] if monetary_value_col: @@ -148,7 +151,7 @@ def _find_first_transactions(transactions, customer_id_col, datetime_col, moneta # make sure the date column uses datetime objects, and use Pandas' DateTimeIndex.to_period() # to convert the column to a PeriodIndex which is useful for time-wise grouping and truncating transactions[datetime_col] = pd.to_datetime(transactions[datetime_col], format=datetime_format) - transactions = transactions.set_index(datetime_col).to_period(freq) + transactions = transactions.set_index(datetime_col).to_period(freq).to_timestamp() transactions = transactions.loc[(transactions.index <= observation_period_end)].reset_index() @@ -169,6 +172,9 @@ def _find_first_transactions(transactions, customer_id_col, datetime_col, moneta # mark the initial transactions as True period_transactions.loc[first_transactions, 'first'] = True select_columns.append('first') + # reset datetime_col to period + period_transactions[datetime_col] = pd.Index(period_transactions[datetime_col]).to_period(freq) + return period_transactions[select_columns] @@ -216,8 +222,9 @@ def summary_data_from_transaction_data(transactions, customer_id_col, datetime_c """ if observation_period_end is None: - observation_period_end = transactions[datetime_col].max() - observation_period_end = pd.to_datetime(observation_period_end, format=datetime_format).to_period(freq) + observation_period_end = pd.to_datetime(transactions[datetime_col].max(), format=datetime_format).to_period(freq).to_timestamp() + else: + observation_period_end = pd.to_datetime(observation_period_end, format=datetime_format).to_period(freq).to_timestamp() # label all of the repeated transactions repeated_transactions = _find_first_transactions( @@ -229,14 +236,17 @@ def summary_data_from_transaction_data(transactions, customer_id_col, datetime_c observation_period_end, freq ) + # reset datetime_col to timestamp + repeated_transactions[datetime_col] = pd.Index(repeated_transactions[datetime_col]).to_timestamp() + # count all orders by customer. customers = repeated_transactions.groupby(customer_id_col, sort=False)[datetime_col].agg(['min', 'max', 'count']) # subtract 1 from count, as we ignore their first order. customers['frequency'] = customers['count'] - 1 - customers['T'] = (observation_period_end - customers['min']) / freq_multiplier - customers['recency'] = (customers['max'] - customers['min']) / freq_multiplier + customers['T'] = (observation_period_end - customers['min']) / np.timedelta64(1, freq) / freq_multiplier + customers['recency'] = (customers['max'] - customers['min']) / np.timedelta64(1, freq) / freq_multiplier summary_columns = ['frequency', 'recency', 'T']