Source code for jitx.run

"""
Design discovery and execution
==============================

This module provides tools for discovering, building, and executing
JITX designs, including communication with the JITX runtime. Normally this
module is not used directly, but rather through the `jitx` command line, or
through the VSCode extension.

.. warning::
    The API in this module is still experimental and may change significantly
    without notice.
"""

from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
from collections.abc import Mapping, Sequence
from logging import getLogger
import sys
from typing import Any, overload, override, cast

import jitx.design
from jitx._formatters import Formatter, text_formatter
from jitx.error import InstantiationException, UserCodeException

from .._websocket import (
    PersistentWebSocketClient,
    Message,
    set_websocket_uri as _set_websocket_uri,
)
from .pyproject import PyProject


# Logger for this module. NOTE(review): the name `logging` shadows the stdlib
# module name within this file; it is a Logger instance, not the module.
logging = getLogger("jitx.run")

# Keys used in the result mappings produced by the builders and the factory
# and handed to the formatter.
STATUS = "status"
MESSAGE = "message"
ERRORS = "errors"
# Failure categories, used as keys inside the mapping stored under ERRORS.
IMPORTFAILED = "import failed"
INSTANTIATIONFAILED = "instantiation failed"
TRANSLATEFAILED = "translation failed"
# Top-level result key holding user-facing hints for a translation failure.
HINTS = "hints"
# NOTE(review): not referenced in the visible part of this module.
LOG = "log"
# STATUS value indicating success; DesignFactory.build treats any other
# STATUS value as a failure.
OK = "ok"


class DesignFactory:
    """Collect designs into a build queue and build them in one pass.

    Designs are located with a :class:`DesignFinder`, built one by one with
    a :class:`BaseDesignBuilder`, and the aggregated results are handed to
    a formatter.

    Args:
        finder: Finder used to locate designs; defaults to a new
            :class:`DesignFinder` rooted at the current directory.
        builder: Builder used for every queued design; defaults to a new
            :class:`DesignBuilder`.
        formatter: Callable used to output results; defaults to
            ``text_formatter``.
        dump: Optional file path forwarded to the builder for dumping the
            packaged design data.
    """

    def __init__(
        self,
        finder: DesignFinder | None = None,
        builder: BaseDesignBuilder | None = None,
        *,
        formatter: Formatter | None = None,
        dump: str | None = None,
    ):
        self.finder = finder or DesignFinder()
        self.builder = builder or DesignBuilder()
        self.formatter: Formatter = formatter or text_formatter
        self.dump = dump
        # Designs waiting to be built, keyed by their (unique) name.
        self.queue: dict[str, type[jitx.design.Design] | jitx.design.Design] = {}
        # Cleared by build() when any design or module import fails.
        self.success = True

    def by_name(self, name: str):
        """Queue the design class with fully qualified name ``name``."""
        self.add(self.finder.find_by_name(name))

    def by_file(self, name: str):
        """Queue every design defined in the Python file at path ``name``."""
        for design in self.finder.find_by_file(name):
            self.add(design)

    def add_all(self):
        """Queue every design the finder discovers.

        Parameterized designs that cannot be constructed without arguments
        are skipped, since there is nothing to supply them with here.
        """
        for design in self.finder.find_all():
            if self._requires_arguments(design):
                continue
            self.add(design)

    @staticmethod
    def _requires_arguments(design) -> bool:
        """Return True if ``design`` has a signature that cannot bind to ()."""
        signature = getattr(design, "__signature__", None)
        if signature is None or not signature.parameters:
            return False
        try:
            signature.bind()
        except TypeError:
            # Signature.bind() raises TypeError for missing required args.
            return True
        return False

    def add(
        self,
        design: type[jitx.design.Design] | jitx.design.Design,
        *,
        name: str | None = None,
    ):
        """Queue ``design`` for building.

        Args:
            design: A Design subclass, a Design instance, or a deferred
                instantiable whose class is a Design subclass.
            name: Queue key; defaults to the fully qualified name of the
                design's class.

        Raises:
            ValueError: If ``design`` is not a recognised design, or a design
                with the same name is already queued.
        """
        if not name:
            import jitx._structural

            if isinstance(design, jitx._structural.Instantiable):
                cls = design._instantiable_()
                if isinstance(cls, type) and issubclass(cls, jitx.design.Design):
                    name = cls.__module__ + "." + cls.__name__
                    # fake the type to avoid type errors, it's a deferred
                    # instantiable of a Design
                    design = cast(jitx.design.Design, design)
                else:
                    raise ValueError(f"Invalid design: {design}")
            elif isinstance(design, jitx.design.Design):
                name = design.__class__.__module__ + "." + design.__class__.__name__
            elif isinstance(design, type) and issubclass(design, jitx.design.Design):
                name = design.__module__ + "." + design.__name__
            else:
                raise ValueError(f"Invalid design: {design}")
        if name in self.queue:
            raise ValueError(f"Duplicate design added to build queue: {name}")
        self.queue[name] = design

    def build(self):
        """Build every queued design and output the aggregated results.

        Results are keyed by design name. Import failures recorded by the
        finder are added under ``ERRORS`` / ``IMPORTFAILED``. ``success`` is
        cleared if any result carries errors or a non-OK status, or if any
        import failed.
        """
        aggregate = {}
        for name, design in self.queue.items():
            result = self.builder.build(
                design, name=name, dump=self.dump, formatter=self.formatter
            )
            aggregate[name] = result
            if ERRORS in result or result.get(STATUS) != OK:
                self.success = False
        if self.finder.exceptions:
            self.success = False
            aggregate[ERRORS] = {
                IMPORTFAILED: {
                    name: repr(e) for name, e in self.finder.exceptions.items()
                }
            }
        self.formatter(aggregate)

    def list(self):
        """Output the names of all discoverable designs and any import errors."""
        df = self.finder
        # Consume find_all() first so df.exceptions is fully populated.
        result = {"designs": [d.__module__ + "." + d.__name__ for d in df.find_all()]}
        if df.exceptions:
            result[ERRORS] = {
                IMPORTFAILED: {name: repr(e) for name, e in df.exceptions.items()}
            }
        self.formatter(result)
class DesignFinder:
    """Discover Design subclasses in modules, files, or whole source trees.

    Modules that fail to import during :meth:`find_all`,
    :meth:`find_by_file` or :meth:`find_by_module` do not abort the search;
    the exception is recorded in :attr:`exceptions`, keyed by module name.

    Args:
        roots: Directory or directories to search; defaults to the current
            directory.
    """

    def __init__(self, roots: str | Sequence[str] | None = None):
        if isinstance(roots, str):
            roots = (roots,)
        self.roots = roots or (".",)
        # Import failures, keyed by the module name that failed.
        self.exceptions: dict[str, Exception] = {}

    def find_all(self):
        """Yield every Design subclass found under the configured roots.

        Each root is walked for ``.py`` files. Names starting with ``_`` or
        ``.`` are skipped, as is anything listed in the project's jitx tool
        ``exclude`` setting. An excluded directory contributes neither its
        files nor its subdirectories. Each remaining file is searched with
        :meth:`find_by_file`.
        """
        import os

        # NOTE: pkgutil.walk_packages struggles with namespace packages, so
        # Python files are located by walking the filesystem instead. This
        # does not support things like eggs, which is probably fine.
        for root in self.roots:
            tool = PyProject(root).jitxtool
            for dirpath, dirnames, filenames in os.walk(root):
                path = os.path.relpath(dirpath, root)
                if path in tool.exclude:
                    # Prune the subtree AND skip this directory's own files.
                    dirnames[:] = []
                    continue
                # Prune in place so os.walk does not descend into these.
                dirnames[:] = [
                    d
                    for d in dirnames
                    if not d.startswith(("_", ".")) and d not in tool.exclude
                ]
                for filename in filenames:
                    if (
                        filename.startswith(("_", "."))
                        or filename in tool.exclude
                        or not filename.endswith(".py")
                    ):
                        continue
                    # NOTE(review): the path is relative to `root`, so module
                    # names resolve as if `root` were importable — presumably
                    # the caller arranges sys.path accordingly.
                    yield from self.find_by_file(os.path.join(path, filename))

    def find_by_name(self, name: str):
        """Return the Design subclass with fully qualified name ``name``.

        Args:
            name: Dotted name of the form ``"package.module.ClassName"``.

        Raises:
            ValueError: If ``name`` is not of the form ``module.Class``, the
                class is not found in the module, or the attribute is not a
                Design subclass.
            Exception: Whatever importing the module raises.
        """
        import importlib

        modulename, sep, classname = name.rpartition(".")
        if not sep or not modulename or not classname:
            raise ValueError(f"Invalid design name: {name}")
        m = importlib.import_module(modulename)
        design = getattr(m, classname, None)
        if design is None:
            raise ValueError(f"{classname} not found in {modulename}")
        # Check for a class first: issubclass() raises TypeError on non-types.
        if not (isinstance(design, type) and issubclass(design, jitx.design.Design)):
            raise ValueError(f"{classname} in {modulename} is not a Design")
        return design

    def find_by_module(self, name: str):
        """Yield the Design subclasses defined in module ``name``.

        Only classes whose ``__module__`` is the module itself are yielded.
        The ``Design`` and ``SampleDesign`` base classes are ignored. If the
        import fails, the exception is recorded in :attr:`exceptions` and
        nothing is yielded.
        """
        import importlib
        import jitx.sample

        try:
            m = importlib.import_module(name)
        except Exception as e:
            self.exceptions[name] = e
            return
        ignored = (jitx.design.Design, jitx.sample.SampleDesign)
        for elem in dir(m):
            field = getattr(m, elem, None)
            if (
                isinstance(field, type)
                and issubclass(field, jitx.design.Design)
                and field not in ignored
                and field.__module__ == m.__name__
            ):
                yield field

    def find_by_file(self, path: str):
        """Yield the Design subclasses defined in the Python file at ``path``.

        The file path is turned into a dotted module name. Import is
        attempted starting from the longest name (every directory included).
        The leading component is dropped until an import succeeds, to avoid
        accidentally importing a "shadowed" module with a shorter name.
        Import failures are recorded in :attr:`exceptions`.
        """
        import os.path
        import importlib

        path, filename = os.path.split(os.path.normpath(path))
        module, _ = os.path.splitext(filename)
        steps = []
        while path:
            rem, last = os.path.split(path)
            if rem == path:
                # Reached the filesystem root of an absolute path.
                break
            path = rem
            steps.append(last)
        steps.reverse()
        steps.append(module)

        candidate = module
        for i, head in enumerate(steps):
            if head == "src":
                # Do not accept "src" as a top-level module. It is common in
                # the so-called "src-layout" pattern; while it could
                # technically be a valid top-level package, that is
                # exceedingly unlikely.
                continue
            candidate = ".".join(steps[i:])
            try:
                importlib.import_module(candidate)
                break
            except ModuleNotFoundError as e:
                if e.name not in (candidate, head):
                    # Only keep trying if the error is about the candidate
                    # itself; otherwise it is probably an error _inside_ the
                    # module.
                    self.exceptions[candidate] = e
                    return
            except Exception as e:
                # Some other error while importing, so this is probably the
                # right module.
                self.exceptions[candidate] = e
                return
        yield from self.find_by_module(candidate)
class BaseDesignBuilder(ABC):
    """Instantiate, translate and deliver designs.

    :meth:`build` turns a design into its packaged form. Subclasses decide
    where the packaged design is sent by implementing :meth:`_send_design`.
    """

    @abstractmethod
    def __init__(self):
        pass

    def build(
        self,
        design: type[jitx.design.Design] | jitx.design.Design,
        *,
        name: str | None = None,
        dump: str | None = None,
        formatter: Formatter,
    ) -> Mapping[str, Any]:
        """Build the design.

        This is a base class; where the design gets sent is determined by
        the specific subclass implementation of :meth:`_send_design`.

        Args:
            design: The Design class, instance, or deferred instantiable to
                build.
            name: Optional name for the design; defaults to the fully
                qualified name of the design.
            dump: Optional file path to dump the design data.
            formatter: Function to format and output results.

        Returns:
            A mapping with ``"design"`` set to the name. On failure it holds
            ``ERRORS`` (and possibly ``HINTS``). Otherwise it holds the
            mapping returned by :meth:`_send_design`. If sending raises, it
            holds ``STATUS: "error"`` and a ``MESSAGE``.
        """
        from jitx._instantiation import instantiation
        import gc

        name = name or self._default_name(design)
        result: dict[str, Any] = {"design": name}

        with instantiation.activate():
            try:
                instantiated = self._instantiate(design)
            except Exception as e:
                logging.exception("Unable to instantiate design %s", design)
                result[ERRORS] = {
                    INSTANTIATIONFAILED: [
                        self._format_exception(exc)
                        for exc in self._exception_chain(e)
                    ]
                }
                return result

            # try to force detection of lost elements.
            gc.collect()

            # callbacks during packaging need active instantiation.
            try:
                packaged = self._package(instantiated)
            except UserCodeException as e:
                result[ERRORS] = {TRANSLATEFAILED: [str(e)]}
                if e.hint:
                    result[HINTS] = [e.hint]
                return result
            except Exception as e:
                logging.exception("Unable to translate design %s", design)
                result[ERRORS] = {TRANSLATEFAILED: [str(e)]}
                return result

            # and again, in case something happened in translation callbacks.
            gc.collect()
            del instantiated

        from google.protobuf.json_format import MessageToDict

        body = MessageToDict(packaged, use_integers_for_enums=True)
        if dump:
            with open(dump, "w") as f:
                formatter(body, file=f)

        def log_message(ob, file=None):
            formatter(ob, sys.stdout if file is None else file)

        try:
            result.update(asyncio.run(self._send_design(name, body, log_message)))
        except Exception as e:
            result[STATUS] = "error"
            result[MESSAGE] = str(e)
        return result

    @staticmethod
    def _default_name(design) -> str:
        """Return a fully qualified default name for ``design``.

        Objects exposing ``__name__`` (classes) are named directly. Plain
        instances, which have no ``__name__``, are named after their class
        instead of raising AttributeError.
        """
        short = getattr(design, "__name__", None)
        if short is None:
            cls = type(design)
            return cls.__module__ + "." + cls.__name__
        return design.__module__ + "." + short

    @staticmethod
    def _instantiate(design) -> jitx.design.Design:
        """Turn ``design`` into a Design instance.

        Called by :meth:`build` while instantiation is active.

        Raises:
            ValueError: If ``design`` is not a Design class or instantiable,
                or is a parameterized Design class lacking parameters.
            TypeError: If the result is not a Design instance.
        """
        import jitx._structural
        from jitx._instantiation import instantiation

        if isinstance(design, jitx._structural.Instantiable):
            cls = design._instantiable_()
            if not (isinstance(cls, type) and issubclass(cls, jitx.design.Design)):
                raise ValueError(f"Not an instantiable design: {cls}")
            with instantiation.frame():
                instantiated = design._instantiate_({})
        elif isinstance(design, type):
            if not issubclass(design, jitx.design.Design):
                raise ValueError(f"Not a Design subclass: {design}")
            try:
                if hasattr(design, "__signature__"):
                    design.__signature__.bind()
            except TypeError:
                raise ValueError(
                    "Design is parameterized but no parameters were provided"
                ) from None
            with instantiation.frame():
                instantiated = design()
        else:
            instantiated = design
        # Explicit check (not assert) so it also runs under `python -O`.
        if not isinstance(instantiated, jitx.design.Design):
            raise TypeError(
                f"Got non-Design object {jitx._structural.Proxy.type(instantiated)}"
            )
        return instantiated

    @staticmethod
    def _package(instantiated):
        """Translate an instantiated design into its packaged message."""
        import jitx._translate.design
        from jitx.substrate import SubstrateContext
        from jitx.design import DesignContext

        # some elements introspect the design on access, looking at
        # general contexts on access is not a good idea in general,
        # given that the expected behavior is to look at the context
        # that was set at instantiation, not inspection - design and
        # substrate seem like safe enough bets though.
        with DesignContext(instantiated), SubstrateContext(instantiated.substrate):
            return jitx._translate.design.package_design(instantiated)

    @staticmethod
    def _exception_chain(exc: BaseException) -> list[BaseException]:
        """Return ``exc`` and its ``__cause__`` chain, root cause first."""
        chain: list[BaseException] = []
        seen: set[int] = set()
        current: BaseException | None = exc
        # Track ids so a cyclic cause chain cannot loop forever.
        while current is not None and id(current) not in seen:
            seen.add(id(current))
            chain.append(current)
            current = current.__cause__
        chain.reverse()
        return chain

    @staticmethod
    def _format_exception(exc: BaseException) -> str:
        """Describe ``exc`` for the user.

        UserCodeException and InstantiationException are rendered with
        ``str`` alone. Any other exception gets the file and line of its
        innermost traceback frame appended.
        """
        if isinstance(exc, UserCodeException | InstantiationException):
            return str(exc)
        loc = ""
        if tb := exc.__traceback__:
            while tb.tb_next:
                tb = tb.tb_next
            loc = f" at {tb.tb_frame.f_code.co_filename}:{tb.tb_lineno}"
        return str(exc) + loc

    async def _send_design(
        self, name: str, body, formatter: Formatter
    ) -> Mapping[str, Any]:
        """Deliver the packaged design ``body``; implemented by subclasses.

        Returns:
            A mapping merged into the result of :meth:`build`.
        """
        raise NotImplementedError
class DryRunBuilder(BaseDesignBuilder):
    """Builder whose delivery step is a no-op.

    Designs still go through the inherited :meth:`build` (instantiation and
    translation), but nothing is sent anywhere. Delivery always reports
    ``STATUS: OK``.
    """

    def __init__(self):
        super().__init__()

    async def _send_design(self, name: str, body, formatter: Formatter):
        # Nothing is transmitted; report success unconditionally.
        outcome = {STATUS: OK}
        return outcome
[docs] class DesignBuilder(BaseDesignBuilder): @overload def __init__(self, *, spec: str | None = None): ... @overload def __init__(self, *, uri: str): ... @overload def __init__(self, *, port: int, host: str = "localhost"): ... __client_future: asyncio.Task[PersistentWebSocketClient] | None = None def __init__( self, *, spec: str | None = None, uri: str | None = None, port: int | None = None, host: str | None = None, ): super().__init__() def lazy_setup(): if uri is not None: _set_websocket_uri(uri=uri) elif port is not None: _set_websocket_uri(uri=f"ws://{host or 'localhost'}:{port}") elif spec is not None: _set_websocket_uri(file=spec) else: _set_websocket_uri() self.__setup = lambda: None self.__setup = lazy_setup async def __client(self): if self.__client_future is None: self.__client_future = asyncio.create_task( PersistentWebSocketClient.create() ) return await self.__client_future
[docs] @override def build( self, design: type[jitx.design.Design] | jitx.design.Design, *, name: str | None = None, dump: str | None = None, formatter: Formatter, ) -> Mapping[str, Any]: self.__setup() return super().build(design, name=name, formatter=formatter, dump=dump)
async def _send_design(self, name: str, body, formatter: Formatter): formatter(f"Running design {name}...") conversation = None success = False client = await self.__client() await client.establish() try: # Create route handler and send message route = client.root() message = Message(type="load", body=body, ns="des") conversation = await route.request(message) # Iterate over responses result_message = None async for envelope in conversation: message = envelope.message if envelope.is_terminal(): # Terminal error envelopes are not delivered, they raise an error match envelope.type: case "ok": if "message" in message.body: result_message = message.body["message"] success = True case _: raise RuntimeError( f"Unhandled terminal message: {message.type}" ) else: match message.type: case "stdin": answer = input() await conversation.send( message.response({"message": answer}) ) case "stdout": formatter(message.body["message"]) case _: raise RuntimeError( f"Unhandled intermediate response message: {message.type}" ) finally: if conversation is not None: await conversation.close() if not success: raise RuntimeError("Did not receive a status report") if result_message is not None: return { STATUS: OK, MESSAGE: result_message, } else: return {STATUS: OK}