diff --git a/appdaemon_testing/hass_driver.py b/appdaemon_testing/hass_driver.py index 75176f8..c975850 100644 --- a/appdaemon_testing/hass_driver.py +++ b/appdaemon_testing/hass_driver.py @@ -2,6 +2,7 @@ import logging import unittest.mock as mock from collections import defaultdict +from copy import copy from dataclasses import dataclass from typing import Dict, Any, List, Callable, Union, Optional @@ -75,7 +76,8 @@ def set_state( domain, _ = entity.split(".") state_entry = self._states[entity] - old_value = previous or state_entry.get(attribute_name) + prev_state = copy(state_entry) + old_value = previous or prev_state.get(attribute_name) new_value = state if old_value == new_value: @@ -93,31 +95,35 @@ def set_state( sat_new = spy.new is None or spy.new == new_value sat_old = spy.old is None or spy.old == old_value + param_old = prev_state if spy.attribute == 'all' else old_value + param_new = copy(state_entry) if spy.attribute == 'all' else new_value + param_attribute = None if spy.attribute == 'all' else attribute_name + if all([sat_old, sat_new, sat_attr]): - spy.callback(entity, attribute_name, old_value, new_value, spy.kwargs) + spy.callback(entity, param_attribute, param_old, param_new, spy.kwargs) def _se_get_state(self, entity_id=None, attribute="state", default=None, **kwargs): _LOGGER.debug("Getting state for entity: %s", entity_id) fully_qualified = "." in entity_id - matched_states = [] + matched_states = {} if fully_qualified: - matched_states.append(self._states[entity_id]) + matched_states[entity_id] = self._states[entity_id] else: for s_eid, state in self._states.items(): domain, entity = s_eid.split(".") if domain == entity_id: - matched_states.append(state) + matched_states[s_eid] = state # With matched states, map the provided attribute (if applicable) if attribute != "all": - matched_states = [state[attribute] for state in matched_states] + matched_states = {eid: state.get(attribute) for eid, state in matched_states.items()} if default is not None: - matched_states = [state or default for state in matched_states] + matched_states = {eid: state or default for eid, state in matched_states.items()} if fully_qualified: - return matched_states[0] + return matched_states[entity_id] else: return matched_states diff --git a/appdaemon_testing_tests/test_hass_driver.py b/appdaemon_testing_tests/test_hass_driver.py new file mode 100644 index 0000000..9a4527c --- /dev/null +++ b/appdaemon_testing_tests/test_hass_driver.py @@ -0,0 +1,171 @@ +from unittest import mock + +import pytest + +from appdaemon_testing import HassDriver + + +def test_get_state(hass_driver): + get_state = hass_driver.get_mock('get_state') + assert get_state('light.1') == 'off' + + +def test_get_state_attribute(hass_driver): + get_state = hass_driver.get_mock('get_state') + assert get_state('light.1', attribute='linkquality') == 60 + + +def test_get_state_attribute_all(hass_driver): + get_state = hass_driver.get_mock('get_state') + assert get_state('light.1', attribute='all') == { + 'state': 'off', + 'linkquality': 60 + } + + +def test_get_state_attribute_domain(hass_driver): + get_state = hass_driver.get_mock('get_state') + assert get_state('light') == { + 'light.1': 'off', + 'light.2': 'on' + } + + +def test_get_state_attribute_domain_with_attribute(hass_driver): + get_state = hass_driver.get_mock('get_state') + assert get_state('light', attribute='linkquality') == { + 'light.1': 60, + 'light.2': 10 + } + + +def test_get_state_attribute_domain_with_attribute_all(hass_driver): + get_state = hass_driver.get_mock('get_state') + assert get_state('light', attribute='all') == { + 'light.1': {'state': 'off', 'linkquality': 60}, + 'light.2': {'state': 'on', 'linkquality': 10, 'brightness': 60} + } + + +def test_get_state_with_default(hass_driver): + get_state = hass_driver.get_mock('get_state') + assert get_state('light.1', attribute='brightness', default=40) == 40 + assert get_state('light.2', attribute='brightness', default=40) == 60 + + +def test_get_state_domain_with_default(hass_driver): + get_state = hass_driver.get_mock('get_state') + assert get_state('light', attribute='brightness', default=40) == { + 'light.1': 40, + 'light.2': 60 + } + + +def test_listen_state(hass_driver): + listen_state = hass_driver.get_mock('listen_state') + handler1 = mock.Mock() + handler2 = mock.Mock() + listen_state(handler1, 'light.1') + listen_state(handler2, 'light.1') + + assert handler1.call_count == 0 + assert handler2.call_count == 0 + + hass_driver.set_state('light.1', 'off') + + assert handler1.call_count == 0 + assert handler2.call_count == 0 + + hass_driver.set_state('light.1', 'on') + + handler1.assert_called_once_with('light.1', 'state', 'off', 'on', {}) + handler2.assert_called_once_with('light.1', 'state', 'off', 'on', {}) + + +def test_listen_state_attribute(hass_driver): + listen_state = hass_driver.get_mock('listen_state') + handler = mock.Mock() + listen_state(handler, 'light.1', attribute='linkquality') + + assert handler.call_count == 0 + hass_driver.set_state('light.1', 'on') + assert handler.call_count == 0 + + hass_driver.set_state('light.1', 50, attribute_name='linkquality') + handler.assert_called_once_with('light.1', 'linkquality', 60, 50, {}) + + +def test_listen_state_attribute_all(hass_driver): + listen_state = hass_driver.get_mock('listen_state') + handler = mock.Mock() + listen_state(handler, 'light.1', attribute='all') + + assert handler.call_count == 0 + hass_driver.set_state('light.1', 75, attribute_name='brightness') + handler.assert_called_once_with( + 'light.1', None, + {'state': 'off', 'linkquality': 60}, + {'state': 'off', 'linkquality': 60, 'brightness': 75}, + {} + ) + + +def test_listen_state_domain(hass_driver): + listen_state = hass_driver.get_mock('listen_state') + handler = mock.Mock() + listen_state(handler, 'light', attribute='brightness') + + assert handler.call_count == 0 + hass_driver.set_state('light.2', 75, attribute_name='brightness') + handler.assert_called_once_with('light.2', 'brightness', 60, 75, {}) + + +def test_listen_state_with_new(hass_driver): + listen_state = hass_driver.get_mock('listen_state') + handler = mock.Mock() + listen_state(handler, 'media_player.smart_tv', attribute='source', new='Spotify') + + hass_driver.set_state('media_player.smart_tv', 'YouTube', attribute_name='source') + assert handler.call_count == 0 + + hass_driver.set_state('media_player.smart_tv', 'Spotify', attribute_name='source') + handler.assert_called_once_with('media_player.smart_tv', 'source', 'YouTube', 'Spotify', {}) + + +def test_listen_state_with_old(hass_driver): + listen_state = hass_driver.get_mock('listen_state') + handler = mock.Mock() + listen_state(handler, 'media_player.smart_tv', attribute='source', old='Spotify') + + hass_driver.set_state('media_player.smart_tv', 'Spotify', attribute_name='source') + assert handler.call_count == 0 + + hass_driver.set_state('media_player.smart_tv', 'TV', attribute_name='source') + handler.assert_called_once_with('media_player.smart_tv', 'source', 'Spotify', 'TV', {}) + + +def test_setup_does_not_trigger_spys(hass_driver): + listen_state = hass_driver.get_mock('listen_state') + handler = mock.Mock() + listen_state(handler, 'light') + listen_state(handler, 'light.1', attribute='brightness') + + with hass_driver.setup(): + hass_driver.set_state('light.1', 'off') + hass_driver.set_state('light.1', 'on') + hass_driver.set_state('light.1', 50, attribute_name='linkquality') + + assert handler.call_count == 0 + hass_driver.set_state('light.1', 'off') + assert handler.call_count == 1 + + +@pytest.fixture +def hass_driver() -> HassDriver: + hass_driver = HassDriver() + hass_driver._states = { + 'light.1': {'state': 'off', 'linkquality': 60}, + 'light.2': {'state': 'on', 'linkquality': 10, 'brightness': 60}, + 'media_player.smart_tv': {'state': 'on', 'source': None}, + } + return hass_driver