Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mouse position into env observation #282

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 40 additions & 12 deletions browsergym/core/src/browsergym/core/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,10 @@
from .action.highlevel import HighLevelActionSet
from .chat import Chat
from .constants import BROWSERGYM_ID_ATTRIBUTE, EXTRACT_OBS_MAX_TRIES
from .observation import (
MarkingError,
_post_extract,
_pre_extract,
extract_dom_extra_properties,
extract_dom_snapshot,
extract_focused_element_bid,
extract_merged_axtree,
extract_screenshot,
)
from .observation import (MarkingError, _post_extract, _pre_extract,
extract_dom_extra_properties, extract_dom_snapshot,
extract_focused_element_bid, extract_merged_axtree,
extract_mouse_position, extract_screenshot)
from .spaces import AnyBox, AnyDict, Float, Unicode
from .task import AbstractBrowserTask

Expand Down Expand Up @@ -157,6 +151,9 @@ def __init__(
shape=(-1, -1, 3),
dtype=np.uint8,
), # swapped axes (height, width, RGB)
"mouse_position": gym.spaces.Tuple(
(Float(), Float())
),
"dom_object": AnyDict(),
"axtree_object": AnyDict(),
"extra_element_properties": AnyDict(),
Expand Down Expand Up @@ -258,20 +255,31 @@ def override_property(task, env, property):
# set default timeout
self.context.set_default_timeout(timeout)

# hack: keep track of the active page with a javascript callback
# hack: keep track of the active page and mouse position with javascript callbacks
# there is no concept of active page in playwright
# https://github.com/microsoft/playwright/issues/2603
self.context.expose_binding(
"browsergym_page_activated", lambda source: self._activate_page_from_js(source["page"])
)
self.context.expose_binding(
"browsergym_mouse_moved", lambda source: self._update_mouse_position_from_js(source)
)
# Initialize mouse position tracking
self.last_mouse_position = None
self.context.add_init_script(
r"""
window.browsergym_page_activated();
window.addEventListener("focus", () => {window.browsergym_page_activated();}, {capture: true});
window.addEventListener("focusin", () => {window.browsergym_page_activated();}, {capture: true});
window.addEventListener("load", () => {window.browsergym_page_activated();}, {capture: true});
window.addEventListener("pageshow", () => {window.browsergym_page_activated();}, {capture: true});
window.addEventListener("mousemove", () => {window.browsergym_page_activated();}, {capture: true});
window.addEventListener("mousemove", (event) => {
window.browsergym_page_activated();
window.browsergym_mouse_moved({
x: event.clientX,
y: event.clientY
});
}, {capture: true});
window.addEventListener("mouseup", () => {window.browsergym_page_activated();}, {capture: true});
window.addEventListener("mousedown", () => {window.browsergym_page_activated();}, {capture: true});
window.addEventListener("wheel", () => {window.browsergym_page_activated();}, {capture: true});
Expand Down Expand Up @@ -485,6 +493,25 @@ def _wait_dom_loaded(self):
except playwright.sync_api.Error:
pass

def _update_mouse_position_from_js(self, source, position=None):
    """Record the latest mouse position reported by the in-page mousemove listener.

    Args:
        source: the binding source handed over by Playwright; its "page"
            entry identifies the page that received the event.
        position: the payload sent from JavaScript, a mapping with "x" and
            "y" entries (the listener sends event.clientX / event.clientY).
            When omitted, "x" and "y" are looked up in `source` instead,
            which keeps the original single-argument call form working.

    Raises:
        RuntimeError: if the page belongs to a different browser context
            than this environment's.
    """
    page = source["page"]

    # Per Playwright's expose_binding contract, the source only describes the
    # caller (context / page / frame); arguments passed from JavaScript arrive
    # as separate positional arguments. Reading x / y from `source` therefore
    # fails for a real binding call, so prefer the explicit payload.
    # NOTE: the callback registered with expose_binding must forward that
    # payload as `position` for this to take effect.
    coords = source if position is None else position
    x = coords["x"]
    y = coords["y"]
    logger.debug(
        "_update_mouse_position_from_js called, page=%s, x=%s, y=%s", page, x, y
    )

    if page.context != self.context:
        raise RuntimeError(
            f"Unexpected: mouse event from a page that belongs to a different browser context ({page})."
        )

    # Store the mouse position along with the page that received the event;
    # the timestamp lets consumers tell how recent the reading is.
    self.last_mouse_position = {
        "page": page,
        "x": x,
        "y": y,
        "timestamp": time.time(),
    }

def _activate_page_from_js(self, page: playwright.sync_api.Page):
logger.debug(f"_activate_page_from_js(page) called, page={str(page)}")
if not page.context == self.context:
Expand Down Expand Up @@ -581,6 +608,7 @@ def _get_obs(self):
"last_action": self.last_action,
"last_action_error": self.last_action_error,
"elapsed_time": np.asarray([time.time() - self.start_time]),
"mouse_position": extract_mouse_position(self.page),
}

return obs
15 changes: 15 additions & 0 deletions browsergym/core/src/browsergym/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,21 @@ def extract_screenshot(page: playwright.sync_api.Page):

return img

def extract_mouse_position(page: playwright.sync_api.Page):
"""
Extracts the mouse location on a Playwright page using a hacky JS code.

Args:
page: the playwright page of which to extract the mouse location.

Returns:
An array of the x and y coordinates of the mouse location.
"""
position = page.evaluate("""() => {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will work for simple pages, but I'm worried about iframes. Here is something that could work:

  • in the JS callback (mousemove), record the position on the window object in JS, and also record in Python which page / frame received this event, using a method similar to _activate_page_from_js().
  • to extract the mouse position in the browser viewport, take the latest mouse position (last iframe that received a mousemove event), and work your way up the frame hierarchy to reconstruct the current mouse position. See how we do that to get the coordinates of all elements in all iframes here:

https://github.com/ServiceNow/BrowserGym/blob/main/browsergym/core/src/browsergym/core/observation.py#L293-L377

return [window.pageX, window.pageY];
}""")
return (position[0], position[1])


# we could handle more data items here if needed
__BID_EXPR = r"([a-zA-Z0-9]+)"
Expand Down
4 changes: 3 additions & 1 deletion tests/core/test_actions_highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
# register openended gym environments
import browsergym.core
from browsergym.core.action.highlevel import HighLevelActionSet
from browsergym.core.action.parsers import NamedArgument, highlevel_action_parser
from browsergym.core.action.parsers import (NamedArgument,
highlevel_action_parser)
from browsergym.core.constants import BROWSERGYM_ID_ATTRIBUTE as BID_ATTR
from browsergym.utils.obs import flatten_dom_to_str

Expand Down Expand Up @@ -1141,6 +1142,7 @@ def get_checkbox_elem(obs):

obs, reward, term, trunc, info = env.step(action)
checkbox = get_checkbox_elem(obs)
assert obs['mouse_position'] == (x, y)

# box not checked
assert not obs["last_action_error"]
Expand Down