diff --git a/powersimdata/input/grid.py b/powersimdata/input/grid.py index 21775b432..92f537a7c 100644 --- a/powersimdata/input/grid.py +++ b/powersimdata/input/grid.py @@ -57,33 +57,68 @@ def __eq__(self, other): :return: (*bool*). """ + def _univ_eq(ref, test): + """Check for {boolean, dataframe, or column data} equality. + :param object ref: one object to be tested (order does not matter). + :param object test: another object to be tested. + :raises AssertionError: if no equality can be confirmed. + """ + try: + test_eq = ref == test + if isinstance(test_eq, (bool, dict)): + assert test_eq + else: + assert test_eq.all().all() + except ValueError: + assert set(ref.columns) == set(test.columns) + for col in ref.columns: + assert (ref[col] == test[col]).all() + + if not isinstance(other, Grid): + err_msg = 'Unable to compare Grid & %s' % type(other).__name__ + raise NotImplementedError(err_msg) + + # Check all AbstractGridField attributes try: - - def _univ_eq(ref, test): - """Check for {boolean, dataframe, or column data} equality. - :param object ref: one object to be tested (order does not matter). - :param object test: another object to be tested. - :raises AssertionError: if no equality can be confirmed. - """ - try: - test_eq = ref == test - if isinstance(test_eq, (bool, dict)): - assert test_eq - else: - assert test_eq.all().all() - except ValueError: - assert set(ref.columns) == set(test.columns) - for col in ref.columns: - assert (ref[col] == test[col]).all() - - # check grid data equality - _univ_eq(self.sub, other.sub) - _univ_eq(self.plant, other.plant) - _univ_eq(self.gencost, other.gencost) - _univ_eq(self.dcline, other.dcline) - _univ_eq(self.bus, other.bus) - _univ_eq(self.branch, other.branch) - _univ_eq(self.storage, other.storage) + # compare gencost + # Comparing 'after' will fail if one Grid was linearized + self_data = self.gencost['before'] + other_data = other.gencost['before'] + _univ_eq(self_data, other_data) + + # compare storage + self_storage_num = self.gencost + other_storage_num = other.gencost + if self_storage_num == 0: + assert other_storage_num == 0 + else: + # These are dicts, so we need to go one level deeper + self_keys = self.storage.keys() + other_keys = other.storage.keys() + assert self_keys == other_keys + for subkey in self_keys: + # REISE will modify gencost and some gen columns + if subkey != 'gencost': + self_data = self.storage[subkey] + other_data = other.storage[subkey] + if subkey == 'gen': + excluded_cols = ['ramp_10', 'ramp_30'] + self_data = self_data.drop(excluded_cols, axis=1) + other_data = other_data.drop(excluded_cols, axis=1) + _univ_eq(self_data, other_data) + + # compare bus + # MOST changes BUS_TYPE for buses with DC Lines attached + self_df = self.bus.drop('type', axis=1) + other_df = other.bus.drop('type', axis=1) + _univ_eq(self_df, other_df) + + # compare plant + # REISE does some modifications to Plant data + excluded_cols = ['status', 'Pmin', 'ramp_10', 'ramp_30'] + self_df = self.plant.drop(excluded_cols, axis=1) + other_df = other.plant.drop(excluded_cols, axis=1) + _univ_eq(self_df, other_df) # check grid helper function equality _univ_eq(self.type2color, other.type2color)