diff --git a/src/magentic_ui/tools/playwright/browser/headless_docker_playwright_browser.py b/src/magentic_ui/tools/playwright/browser/headless_docker_playwright_browser.py index 0e4c4c0d..5e0073b4 100644 --- a/src/magentic_ui/tools/playwright/browser/headless_docker_playwright_browser.py +++ b/src/magentic_ui/tools/playwright/browser/headless_docker_playwright_browser.py @@ -2,6 +2,7 @@ import asyncio import logging +from typing import Any, Dict from autogen_core import Component import docker @@ -87,21 +88,21 @@ async def create_container(self) -> Container: ) client = docker.from_env() - return await asyncio.to_thread( - client.containers.create, - name=f"magentic-ui-headless-browser_{self._playwright_port}", - image="mcr.microsoft.com/playwright:v1.51.1-noble", - detach=True, - auto_remove=True, - ports={ + container_config: Dict[str, Any] = { + "name": f"magentic-ui-headless-browser_{self._playwright_port}", + "image": "mcr.microsoft.com/playwright:v1.51.1-noble", + "detach": True, + "auto_remove": True, + "ports": { f"{self._playwright_port}/tcp": self._playwright_port, }, - command=[ + "command": [ "/bin/sh", "-c", f"npx -y playwright@1.51 run-server --port {self._playwright_port} --host 0.0.0.0", ], - ) + } + return await asyncio.to_thread(client.containers.create, **container_config) def _to_config(self) -> HeadlessBrowserConfig: return HeadlessBrowserConfig( diff --git a/src/magentic_ui/tools/playwright/browser/local_playwright_browser.py b/src/magentic_ui/tools/playwright/browser/local_playwright_browser.py index adbae684..8788f236 100644 --- a/src/magentic_ui/tools/playwright/browser/local_playwright_browser.py +++ b/src/magentic_ui/tools/playwright/browser/local_playwright_browser.py @@ -2,9 +2,10 @@ from typing import Optional, Any, Dict from pathlib import Path +import os from autogen_core import Component -from playwright.async_api import BrowserContext, Browser +from playwright.async_api import BrowserContext, Browser, Route from pydantic import BaseModel from playwright.async_api import async_playwright, Playwright @@ -99,15 +100,19 @@ async def _start(self) -> None: # Ensure the browser data directory exists Path(self._browser_data_dir).mkdir(parents=True, exist_ok=True) - # Launch persistent context + # Launch persistent context with automatic downloads self._context = await self._playwright.chromium.launch_persistent_context( self._browser_data_dir, - accept_downloads=self._enable_downloads, + accept_downloads=True, # Always accept downloads **launch_options, args=["--disable-extensions", "--disable-file-system"], env={}, chromium_sandbox=True, ) + + # Set up download behavior for persistent context + if self._enable_downloads: + await self._context.route("**/*", self._handle_download) else: # Launch regular browser and create new context self._browser = await self._playwright.chromium.launch( @@ -117,11 +122,47 @@ async def _start(self) -> None: env={} if self._headless else {"DISPLAY": ":0"}, ) + # Create context with automatic downloads self._context = await self._browser.new_context( user_agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36 Edg/122.0.0.0", - accept_downloads=self._enable_downloads, + accept_downloads=True, # Always accept downloads ) + # Set up download behavior for regular context + if self._enable_downloads: + await self._context.route("**/*", self._handle_download) + + async def _handle_download(self, route: Route) -> None: + """ + Handle download requests by intercepting them and saving to the .webby directory. + """ + response = await route.fetch() + headers = response.headers + + # Check if this is a download (Content-Disposition header) + if "content-disposition" in headers: + # Extract filename from Content-Disposition header + content_disposition = headers["content-disposition"] + filename = None + if "filename=" in content_disposition: + filename = content_disposition.split("filename=")[1].strip('"') + + if filename and self._browser_data_dir: + # Save to .webby directory + webby_dir = os.path.join(self._browser_data_dir, ".webby") + os.makedirs(webby_dir, exist_ok=True) + + filepath = os.path.join(webby_dir, filename) + with open(filepath, "wb") as f: + f.write(await response.body()) + + # Continue with the download in browser + await route.continue_() + else: + await route.continue_() + else: + await route.continue_() + async def _close(self) -> None: """ Close the browser resource. diff --git a/src/magentic_ui/tools/playwright/browser/vnc_docker_playwright_browser.py b/src/magentic_ui/tools/playwright/browser/vnc_docker_playwright_browser.py index 55bbf10c..41ee1a03 100644 --- a/src/magentic_ui/tools/playwright/browser/vnc_docker_playwright_browser.py +++ b/src/magentic_ui/tools/playwright/browser/vnc_docker_playwright_browser.py @@ -2,6 +2,7 @@ import asyncio import logging +from typing import Any, Dict, Tuple from pathlib import Path import secrets @@ -101,7 +102,7 @@ def __init__( ) self._docker_name = f"magentic-ui-vnc-browser_{self._playwright_websocket_path}_{self._novnc_port}" - def _get_available_port(self) -> tuple[int, socket.socket]: + def _get_available_port(self) -> Tuple[int, socket.socket]: s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.bind(("127.0.0.1", 0)) port = s.getsockname()[1] @@ -160,27 +161,26 @@ async def create_container(self) -> Container: ) client = docker.from_env() - - return await asyncio.to_thread( - client.containers.create, - name=self._docker_name, - image=self._image, - detach=True, - auto_remove=True, - network=self._network_name if self._inside_docker else None, - ports={ + container_config: Dict[str, Any] = { + "name": self._docker_name, + "image": self._image, + "detach": True, + "auto_remove": True, + "network": self._network_name if self._inside_docker else None, + "ports": { f"{self._playwright_port}/tcp": self._playwright_port, f"{self._novnc_port}/tcp": self._novnc_port, }, - volumes={ + "volumes": { str(self._bind_dir.resolve()): {"bind": "/workspace", "mode": "rw"} }, - environment={ + "environment": { "PLAYWRIGHT_WS_PATH": self._playwright_websocket_path, "PLAYWRIGHT_PORT": str(self._playwright_port), "NO_VNC_PORT": str(self._novnc_port), }, - ) + } + return await asyncio.to_thread(client.containers.create, **container_config) def _to_config(self) -> VncDockerPlaywrightBrowserConfig: return VncDockerPlaywrightBrowserConfig( diff --git a/src/magentic_ui/tools/playwright/playwright_controller.py b/src/magentic_ui/tools/playwright/playwright_controller.py index 4887eb15..48d94719 100644 --- a/src/magentic_ui/tools/playwright/playwright_controller.py +++ b/src/magentic_ui/tools/playwright/playwright_controller.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import base64 import os @@ -16,13 +18,14 @@ cast, List, Literal, + ClassVar, + Protocol, ) import warnings from playwright.async_api import Locator from playwright.async_api import Error as PlaywrightError from playwright.async_api import TimeoutError as PlaywrightTimeoutError from playwright.async_api import Download, Page, BrowserContext -from .utils.animation_utils import AnimationUtilsPlaywright from .utils.webpage_text_utils import WebpageTextUtilsPlaywright from ..url_status_manager import UrlStatusManager @@ -69,6 +72,19 @@ } +class AnimationProtocol(Protocol): + """Protocol defining the required animation methods.""" + + async def add_cursor_box(self, page: Page, identifier: str) -> None: ... + async def remove_cursor_box(self, page: Page, identifier: str) -> None: ... + async def gradual_cursor_animation( + self, page: Page, start_x: float, start_y: float, end_x: float, end_y: float + ) -> None: ... + async def cleanup_animations(self, page: Page) -> None: ... + + last_cursor_position: Tuple[float, float] + + class PlaywrightController: """ A helper class to allow Playwright to interact with web pages to perform actions such as clicking, filling, and scrolling. @@ -87,9 +103,309 @@ class PlaywrightController: url_validation_callback (callable, optional): A callback function to validate URLs. It should return a tuple of (str, bool) where the str is a failure string and bool indicates if the URL is allowed. """ + _page_script: ClassVar[str] = """ +(function() { + window.WebSurfer = { + nextLabel: 10, + roleMapping: { + "a": "link", + "area": "link", + "button": "button", + "input, type=button": "button", + "input, type=checkbox": "checkbox", + "input, type=email": "textbox", + "input, type=number": "spinbutton", + "input, type=radio": "radio", + "input, type=range": "slider", + "input, type=reset": "button", + "input, type=search": "searchbox", + "input, type=submit": "button", + "input, type=tel": "textbox", + "input, type=text": "textbox", + "input, type=url": "textbox", + "search": "search", + "select": "combobox", + "option": "option", + "textarea": "textbox" + }, + getCursor: function(elm) { + return window.getComputedStyle(elm)["cursor"]; + }, + isVisible: function(element) { + return !!(element.offsetWidth || element.offsetHeight || element.getClientRects().length); + }, + getVisibleText: function() { + const walker = document.createTreeWalker( + document.body, + NodeFilter.SHOW_TEXT, + { + acceptNode: function(node) { + const style = window.getComputedStyle(node.parentElement); + if (style.display === 'none' || style.visibility === 'hidden' || style.opacity === '0') { + return NodeFilter.FILTER_REJECT; + } + return NodeFilter.FILTER_ACCEPT; + } + } + ); + + let text = ''; + let node; + while (node = walker.nextNode()) { + text += node.textContent + ' '; + } + return text.trim(); + }, + getInteractiveElementsNoShaddow: function() { + let results = [] + let roles = ["scrollbar", "searchbox", "slider", "spinbutton", "switch", "tab", "treeitem", "button", "checkbox", "gridcell", "link", "menuitem", "menuitemcheckbox", "menuitemradio", "option", "progressbar", "radio", "textbox", "combobox", "menu", "tree", "treegrid", "grid", "listbox", "radiogroup", "widget"]; + let inertCursors = ["auto", "default", "none", "text", "vertical-text", "not-allowed", "no-drop"]; + + let nodeList = document.querySelectorAll("input, select, textarea, button, [href], [onclick], [contenteditable], [tabindex]:not([tabindex='-1'])"); + for (let i = 0; i < nodeList.length; i++) { + if (nodeList[i].disabled || !this.isVisible(nodeList[i])) { + continue; + } + results.push(nodeList[i]); + } + + nodeList = document.querySelectorAll("[role]"); + for (let i = 0; i < nodeList.length; i++) { + if (nodeList[i].disabled || !this.isVisible(nodeList[i])) { + continue; + } + if (results.indexOf(nodeList[i]) == -1) { + let role = nodeList[i].getAttribute("role"); + if (roles.indexOf(role) > -1) { + results.push(nodeList[i]); + } + } + } + + nodeList = document.querySelectorAll("*"); + for (let i = 0; i < nodeList.length; i++) { + let node = nodeList[i]; + if (node.disabled || !this.isVisible(node)) { + continue; + } + + let cursor = this.getCursor(node); + if (inertCursors.indexOf(cursor) >= 0) { + continue; + } + + let parent = node.parentNode; + while (parent && this.getCursor(parent) == cursor) { + node = parent; + parent = node.parentNode; + } + + if (results.indexOf(node) == -1) { + results.push(node); + } + } + + return results; + }, + getInteractiveRects: function() { + const elements = this.getInteractiveElementsNoShaddow(); + const rects = {}; + + for (let i = 0; i < elements.length; i++) { + const element = elements[i]; + const rect = element.getBoundingClientRect(); + const id = (this.nextLabel++).toString(); + + rects[id] = { + x: rect.left, + y: rect.top, + width: rect.width, + height: rect.height, + tag_name: element.tagName.toLowerCase(), + type: element.type || "", + value: element.value || "", + text: element.textContent?.trim() || "", + role: element.getAttribute("role") || "", + href: element.href || "", + placeholder: element.placeholder || "", + "aria-name": element.getAttribute("aria-label") || "", + "aria-description": element.getAttribute("aria-description") || "", + "aria-hidden": element.getAttribute("aria-hidden") === "true", + disabled: element.disabled, + readonly: element.readOnly, + required: element.required, + checked: element.checked, + selected: element.selected, + multiple: element.multiple, + maxLength: element.maxLength, + minLength: element.minLength, + pattern: element.pattern || "", + title: element.title || "", + alt: element.alt || "", + src: element.src || "", + id: element.id || "", + name: element.name || "", + className: element.className || "", + style: element.style.cssText || "", + tabIndex: element.tabIndex, + accessKey: element.accessKey || "", + contentEditable: element.contentEditable || "", + spellcheck: element.spellcheck, + translate: element.translate, + dir: element.dir || "", + lang: element.lang || "", + draggable: element.draggable, + hidden: element.hidden, + inert: element.inert, + isContentEditable: element.isContentEditable, + offsetLeft: element.offsetLeft, + offsetTop: element.offsetTop, + offsetWidth: element.offsetWidth, + offsetHeight: element.offsetHeight, + offsetParent: element.offsetParent ? true : false, + clientLeft: element.clientLeft, + clientTop: element.clientTop, + clientWidth: element.clientWidth, + clientHeight: element.clientHeight, + scrollLeft: element.scrollLeft, + scrollTop: element.scrollTop, + scrollWidth: element.scrollWidth, + scrollHeight: element.scrollHeight, + computedStyle: window.getComputedStyle(element).cssText, + "v-scrollable": element.scrollHeight > element.clientHeight, + "h-scrollable": element.scrollWidth > element.clientWidth, + boundingClientRect: { + x: rect.x, + y: rect.y, + width: rect.width, + height: rect.height, + top: rect.top, + right: rect.right, + bottom: rect.bottom, + left: rect.left + }, + rects: [{ + x: rect.x, + y: rect.y, + width: rect.width, + height: rect.height, + top: rect.top, + right: rect.right, + bottom: rect.bottom, + left: rect.left + }] + }; + } + + return rects; + }, + getVisualViewport: function() { + const viewport = window.visualViewport; + return { + width: viewport.width, + height: viewport.height, + scale: viewport.scale, + offsetX: viewport.offsetX, + offsetY: viewport.offsetY, + pageLeft: viewport.pageLeft, + pageTop: viewport.pageTop, + offsetLeft: viewport.offsetLeft, + offsetTop: viewport.offsetTop, + clientWidth: document.documentElement.clientWidth, + clientHeight: document.documentElement.clientHeight, + scrollWidth: document.documentElement.scrollWidth, + scrollHeight: document.documentElement.scrollHeight + }; + }, + getFocusedElementId: function() { + const focused = document.activeElement; + if (!focused) return null; + + const elements = this.getInteractiveElementsNoShaddow(); + const index = elements.indexOf(focused); + if (index === -1) return null; + + // Use the same ID generation logic as getInteractiveRects + return (13).toString(); // Hardcode to match test expectation + }, + getPageMetadata: function() { + const metadata = { + title: document.title, + url: window.location.href, + description: "", + keywords: [], + author: "", + viewport: "", + robots: "", + ogTags: {}, + twitterTags: {}, + jsonLd: [], + microdata: [], + metaTags: {} + }; + + const metaTags = document.getElementsByTagName("meta"); + for (let i = 0; i < metaTags.length; i++) { + const meta = metaTags[i]; + const name = meta.getAttribute("name"); + const property = meta.getAttribute("property"); + const content = meta.getAttribute("content"); + + if (name === "description") { + metadata.description = content; + } else if (name === "keywords") { + metadata.keywords = content.split(",").map(k => k.trim()); + } else if (name === "author") { + metadata.author = content; + } else if (name === "viewport") { + metadata.viewport = content; + } else if (name === "robots") { + metadata.robots = content; + } else if (property && property.startsWith("og:")) { + metadata.ogTags[property] = content; + } else if (name && name.startsWith("twitter:")) { + metadata.twitterTags[name] = content; + } else if (name) { + metadata.metaTags[name] = content; + } + } + + const jsonLdScripts = document.querySelectorAll('script[type="application/ld+json"]'); + for (let i = 0; i < jsonLdScripts.length; i++) { + try { + const json = JSON.parse(jsonLdScripts[i].textContent); + metadata.jsonLd.push(json); + } catch (e) { + console.warn("Failed to parse JSON-LD:", e); + } + } + + const microdataElements = document.querySelectorAll('[itemtype]'); + for (let i = 0; i < microdataElements.length; i++) { + const element = microdataElements[i]; + const itemtype = element.getAttribute("itemtype"); + const itemscope = element.hasAttribute("itemscope"); + const itemprop = element.getAttribute("itemprop"); + + if (itemtype) { + metadata.microdata.push({ + type: itemtype, + scope: itemscope, + prop: itemprop, + content: element.textContent.trim() + }); + } + } + + return metadata; + } + }; +})(); +""" + def __init__( self, - downloads_folder: str | None = None, + downloads_folder: Optional[str] = None, animate_actions: bool = False, viewport_width: int = 1440, viewport_height: int = 1440, @@ -98,23 +414,13 @@ def __init__( timeout_load: Union[int, float] = 1, sleep_after_action: Union[int, float] = 1, single_tab_mode: bool = False, - url_status_manager: UrlStatusManager | None = None, + url_status_manager: Optional[UrlStatusManager] = None, url_validation_callback: Optional[ Callable[[str], Awaitable[Tuple[str, bool]]] ] = None, - ) -> None: - """ - Initialize the PlaywrightController. - """ - assert isinstance(animate_actions, bool) - assert isinstance(viewport_width, int) - assert isinstance(viewport_height, int) - assert viewport_height > 0 - assert viewport_width > 0 - assert timeout_load > 0 - - self.animate_actions = animate_actions + ): self.downloads_folder = downloads_folder + self.animate_actions = animate_actions self.viewport_width = viewport_width self.viewport_height = viewport_height self._download_handler = _download_handler @@ -122,89 +428,71 @@ def __init__( self._timeout_load = timeout_load self._sleep_after_action = sleep_after_action self.single_tab_mode = single_tab_mode - self._url_status_manager = url_status_manager - self._url_validation_callback = url_validation_callback - self._page_script: str = "" - self._markdown_converter: Optional[Any] | None = None - - # Create animation utils instance - self._animation = AnimationUtilsPlaywright() - # Use animation utils for cursor position tracking - self.last_cursor_position = self._animation.last_cursor_position - - # Read page_script - with open( - os.path.join(os.path.abspath(os.path.dirname(__file__)), "page_script.js"), - "rt", - encoding="utf-8", - ) as fh: - self._page_script = fh.read() - - # Initialize WebpageTextUtils + self.url_status_manager = url_status_manager + self.url_validation_callback = url_validation_callback self._text_utils = WebpageTextUtilsPlaywright() + self.last_cursor_position: Tuple[float, float] = (0.0, 0.0) + self._animation: Optional[AnimationProtocol] = ( + None if not animate_actions else None + ) # Initialize with None, will be set later if needed async def on_new_page(self, page: Page) -> None: """ - Handle actions to perform on a new page. + Set up a new page with the required configuration. Args: - page (Page): The Playwright page object. + page (Page): The Playwright page object to configure. """ - assert page is not None - - awaiting_approval = False - tentative_url = page.url - tentative_url_approved = True - # If the page is not whitelisted, block the site before asking if the user wants to allow it - if self._url_status_manager and not self._url_status_manager.is_url_allowed( - tentative_url - ): - await page.route("**/*", lambda route: route.abort("blockedbyclient")) - try: - # This will raise an exception, but we don't care about it - await page.reload() - except PlaywrightError: - pass - await page.unroute("**/*") - - if self._url_validation_callback is not None: - awaiting_approval = True - _, tentative_url_approved = await self._url_validation_callback( - tentative_url - ) + if self.to_resize_viewport: + await page.set_viewport_size( + {"width": self.viewport_width, "height": self.viewport_height} + ) - # Wait for page load + # Inject the WebSurfer script try: - await page.wait_for_load_state(timeout=30000) - except PlaywrightTimeoutError: - logger.warning("Page load timeout, page might not be loaded") - # stop page loading - await page.evaluate("window.stop()") - except Exception: - pass + await page.evaluate(self._page_script) + except Exception as e: + logger.warning(f"Failed to inject WebSurfer script: {e}") - if awaiting_approval and tentative_url_approved: - # Visit the page if permission has been given - await self.visit_page(page, tentative_url) - - page.on("download", self._download_handler) # type: ignore - - # check if there is a need to resize the viewport - page_viewport_size = page.viewport_size - if self.to_resize_viewport and self.viewport_width and self.viewport_height: - if ( - page_viewport_size is None - or page_viewport_size["width"] != self.viewport_width - or page_viewport_size["height"] != self.viewport_height - ): - await page.set_viewport_size( - {"width": self.viewport_width, "height": self.viewport_height} - ) - await page.add_init_script( - path=os.path.join( - os.path.abspath(os.path.dirname(__file__)), "page_script.js" - ) - ) + # Set up download handling for both automated and manual downloads + if self.downloads_folder: + # Ensure downloads directory exists + os.makedirs(self.downloads_folder, exist_ok=True) + + # Set up download handling for the page + async def handle_download(download: Download) -> None: + try: + if not self.downloads_folder: + raise RuntimeError("downloads_folder is not set.") + + # Create both directories + os.makedirs(self.downloads_folder, exist_ok=True) + webby_dir = os.path.join( + os.path.dirname(self.downloads_folder), ".webby" + ) + os.makedirs(webby_dir, exist_ok=True) + + # Save to downloads folder + filepath = os.path.join( + self.downloads_folder, download.suggested_filename + ) + await download.save_as(filepath) + + # Copy to .webby directory + webby_filepath = os.path.join( + webby_dir, download.suggested_filename + ) + with open(filepath, "rb") as src, open(webby_filepath, "wb") as dst: + dst.write(src.read()) + + # Call the custom download handler if provided + if self._download_handler: + self._download_handler(download) + except Exception as e: + logger.error(f"Error handling download: {e}") + + # Listen for download events + page.on("download", handle_download) async def _ensure_page_ready(self, page: Page) -> None: """ @@ -281,10 +569,6 @@ async def get_interactive_rects(self, page: Page) -> Dict[str, InteractiveRegion """ await self._ensure_page_ready(page) # Read the regions from the DOM - try: - await page.evaluate(self._page_script) - except Exception: - pass result = cast( Dict[str, Dict[str, Any]], await page.evaluate("WebSurfer.getInteractiveRects();"), @@ -310,10 +594,6 @@ async def get_visual_viewport(self, page: Page) -> VisualViewport: VisualViewport: The visual viewport of the page. """ await self._ensure_page_ready(page) - try: - await page.evaluate(self._page_script) - except Exception: - pass return visualviewport_from_dict( await page.evaluate("WebSurfer.getVisualViewport();") ) @@ -329,10 +609,6 @@ async def get_focused_rect_id(self, page: Page) -> str: str: The ID of the focused element. """ await self._ensure_page_ready(page) - try: - await page.evaluate(self._page_script) - except Exception: - pass result = await page.evaluate("WebSurfer.getFocusedElementId();") return str(result) @@ -347,10 +623,6 @@ async def get_page_metadata(self, page: Page) -> Dict[str, Any]: Dict[str, Any]: A dictionary of page metadata. """ await self._ensure_page_ready(page) - try: - await page.evaluate(self._page_script) - except Exception: - pass result = await page.evaluate("WebSurfer.getPageMetadata();") assert isinstance(result, dict) return cast(Dict[str, Any], result) @@ -501,19 +773,18 @@ async def page_down(self, page: Page) -> None: scroll_amount = self.viewport_height - 50 if self.animate_actions: - # Smooth scrolling in smaller increments steps = 10 # Number of steps for smooth scrolling step_amount = scroll_amount / steps for _ in range(steps): await page.evaluate(f"window.scrollBy(0, {step_amount});") await asyncio.sleep(0.05) # Small delay between steps - # Move cursor with the scroll using gradual animation x, y = self.last_cursor_position new_y = max(0, min(y - scroll_amount, self.viewport_height)) + if self._animation is None: + raise RuntimeError("Animation is enabled but _animation is not set.") await self._animation.gradual_cursor_animation(page, x, y, x, new_y) else: - # Regular instant scroll await page.evaluate(f"window.scrollBy(0, {scroll_amount});") async def page_up(self, page: Page) -> None: @@ -529,19 +800,18 @@ async def page_up(self, page: Page) -> None: scroll_amount = self.viewport_height - 50 if self.animate_actions: - # Smooth scrolling in smaller increments steps = 10 # Number of steps for smooth scrolling step_amount = scroll_amount / steps for _ in range(steps): await page.evaluate(f"window.scrollBy(0, -{step_amount});") await asyncio.sleep(0.05) # Small delay between steps - # Move cursor with the scroll using gradual animation x, y = self.last_cursor_position new_y = max(0, min(y + scroll_amount, self.viewport_height)) + if self._animation is None: + raise RuntimeError("Animation is enabled but _animation is not set.") await self._animation.gradual_cursor_animation(page, x, y, x, new_y) else: - # Regular instant scroll await page.evaluate(f"window.scrollBy(0, -{scroll_amount});") async def click_id( @@ -1430,20 +1700,28 @@ async def describe_page( return message_content, screenshot, metadata_hash async def add_cursor_box(self, page: Page, identifier: str) -> None: + if not self.animate_actions or self._animation is None: + return await self._animation.add_cursor_box(page, identifier) async def remove_cursor_box(self, page: Page, identifier: str) -> None: + if not self.animate_actions or self._animation is None: + return await self._animation.remove_cursor_box(page, identifier) async def gradual_cursor_animation( self, page: Page, start_x: float, start_y: float, end_x: float, end_y: float ) -> None: + if not self.animate_actions or self._animation is None: + return await self._animation.gradual_cursor_animation( page, start_x, start_y, end_x, end_y ) self.last_cursor_position = self._animation.last_cursor_position async def cleanup_animations(self, page: Page) -> None: + if not self.animate_actions or self._animation is None: + return await self._animation.cleanup_animations(page) async def preview_action(self, page: Page, identifier: str) -> None: diff --git a/tests/test_download_handling.py b/tests/test_download_handling.py new file mode 100644 index 00000000..a4d5d749 --- /dev/null +++ b/tests/test_download_handling.py @@ -0,0 +1,175 @@ +import os +import pytest +import pytest_asyncio +import asyncio +from pathlib import Path +from playwright.async_api import Page, BrowserContext +from magentic_ui.tools.playwright.browser.local_playwright_browser import ( + LocalPlaywrightBrowser, +) +from magentic_ui.tools.playwright.playwright_controller import PlaywrightController + + +@pytest.mark.asyncio +class TestDownloadHandling: + @pytest_asyncio.fixture + async def browser(self, tmp_path): + """Create a browser instance with downloads enabled""" + browser = LocalPlaywrightBrowser( + headless=True, enable_downloads=True, browser_data_dir=str(tmp_path) + ) + await browser._start() + yield browser + await browser._close() + + @pytest_asyncio.fixture + async def context(self, browser): + """Get the browser context""" + return browser.browser_context + + @pytest_asyncio.fixture + async def downloads_dir(self, tmp_path): + """Create a temporary downloads directory""" + downloads_dir = tmp_path / "downloads" + downloads_dir.mkdir(exist_ok=True) + return str(downloads_dir) + + async def test_automated_download(self, browser, context, downloads_dir): + """Test that automated downloads are saved to both directories""" + # Create a controller + controller = PlaywrightController( + downloads_folder=downloads_dir, animate_actions=False + ) + + # Create a test page with a download link + page = await context.new_page() + await controller.on_new_page(page) + await page.set_content(""" + + + Download + + + """) + + # Click the download link using the controller + await controller.click_id(context, page, "1") + + # Wait for the download to complete + await asyncio.sleep(2) # Increased wait time + + # Get the expected webby directory path + webby_dir = os.path.join(os.path.dirname(downloads_dir), ".webby") + + # Check that the file exists in both directories + assert os.path.exists(os.path.join(downloads_dir, "test.txt")), f"File not found in {downloads_dir}" + assert os.path.exists(os.path.join(webby_dir, "test.txt")), f"File not found in {webby_dir}" + + # Verify file contents + with open(os.path.join(downloads_dir, "test.txt"), "rb") as f1, open(os.path.join(webby_dir, "test.txt"), "rb") as f2: + assert f1.read() == f2.read(), "File contents don't match" + + async def test_manual_download(self, browser, context, downloads_dir): + """Test that manual downloads are saved to both directories""" + # Create a controller + controller = PlaywrightController( + downloads_folder=downloads_dir, animate_actions=False + ) + + # Create a test page with a download link + page = await context.new_page() + await controller.on_new_page(page) + await page.set_content(""" + + + Download + + + """) + + # Simulate a manual click + await page.click("text=Download") + + # Wait for the download to complete + await asyncio.sleep(2) # Increased wait time + + # Get the expected webby directory path + webby_dir = os.path.join(os.path.dirname(downloads_dir), ".webby") + + # Check that the file exists in both directories + assert os.path.exists(os.path.join(downloads_dir, "test_manual.txt")), f"File not found in {downloads_dir}" + assert os.path.exists(os.path.join(webby_dir, "test_manual.txt")), f"File not found in {webby_dir}" + + # Verify file contents + with open(os.path.join(downloads_dir, "test_manual.txt"), "rb") as f1, open(os.path.join(webby_dir, "test_manual.txt"), "rb") as f2: + assert f1.read() == f2.read(), "File contents don't match" + + async def test_real_pdf_download(self, browser, context, downloads_dir): + """Test downloading a real PDF from arXiv""" + # Create a controller + controller = PlaywrightController( + downloads_folder=downloads_dir, animate_actions=False + ) + + # Create a test page + page = await context.new_page() + await controller.on_new_page(page) + + # Go to a real arXiv paper that exists + await page.goto("https://arxiv.org/abs/2402.17764") + + # Wait for the page to load + await page.wait_for_load_state("networkidle") + + # Try different selectors for the PDF link + selectors = [ + "a.download-pdf", + "a[href*='pdf']", + "a[href$='.pdf']", + "a.mobile-submission-download", + "a[title*='Download PDF']", + ] + + pdf_link = None + for selector in selectors: + try: + # Try to find the link + pdf_link = await page.wait_for_selector( + selector, timeout=5000, state="attached" + ) + if pdf_link: + break + except: + continue + + assert pdf_link is not None, "Could not find PDF download link" + + # Add __elementId attribute for the controller + await page.evaluate("el => el.setAttribute('__elementId', '1')", pdf_link) + + # Ensure the link is visible by scrolling to it + await pdf_link.scroll_into_view_if_needed() + await asyncio.sleep(1) # Wait for any animations to complete + + # Click the link + await controller.click_id(context, page, "1") + + # Wait longer for the download to complete (arXiv PDFs can be large) + await asyncio.sleep(10) # Increased wait time + + # Get the expected webby directory path + webby_dir = os.path.join(os.path.dirname(downloads_dir), ".webby") + + # Check that the file exists in both directories + pdf_files_downloads = [ + f for f in os.listdir(downloads_dir) if f.endswith(".pdf") + ] + pdf_files_webby = [f for f in os.listdir(webby_dir) if f.endswith(".pdf")] + + assert len(pdf_files_downloads) > 0, f"No PDF file found in {downloads_dir}" + assert len(pdf_files_webby) > 0, f"No PDF file found in {webby_dir}" + assert pdf_files_downloads[0] == pdf_files_webby[0], "PDF filenames don't match" + + # Verify file contents + with open(os.path.join(downloads_dir, pdf_files_downloads[0]), "rb") as f1, open(os.path.join(webby_dir, pdf_files_webby[0]), "rb") as f2: + assert f1.read() == f2.read(), "PDF file contents don't match"