"""
Design discovery and execution
==============================
This module provides tools for discovering, building, and executing
JITX designs, including communication with the JITX runtime. Normally this
module is not used directly, but rather through the `jitx` command line, or
through the VSCode extension.
.. warning::
The API in this module is still experimental and may change significantly
without notice.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
from collections.abc import Mapping, Sequence
from logging import getLogger
import sys
from typing import Any, overload, override, cast
import jitx.design
from jitx._formatters import Formatter, text_formatter
from jitx.error import InstantiationException, UserCodeException
from .._websocket import (
PersistentWebSocketClient,
Message,
set_websocket_uri as _set_websocket_uri,
)
from .pyproject import PyProject
# Logger shared by the design discovery and build machinery in this module.
logging = getLogger("jitx.run")

# Keys used in the result mappings handed to the output formatter.
STATUS = "status"
MESSAGE = "message"
ERRORS = "errors"
HINTS = "hints"
LOG = "log"

# Failure categories, used as sub-keys of the ``ERRORS`` entry.
IMPORTFAILED = "import failed"
INSTANTIATIONFAILED = "instantiation failed"
TRANSLATEFAILED = "translation failed"

# ``STATUS`` value reported for a successful build.
OK = "ok"
class DesignFactory:
    """Collect designs into a build queue, then build or list them.

    Designs are located with a :class:`DesignFinder` and built with a
    :class:`BaseDesignBuilder`; results are reported through ``formatter``.

    Args:
        finder: Used to locate designs. A new :class:`DesignFinder` is
            created when omitted.
        builder: Used to build queued designs. A new :class:`DesignBuilder`
            is created when omitted.
        formatter: Callable that receives the result mappings. Defaults to
            ``text_formatter``.
        dump: Optional file path handed to the builder for every queued
            design.
    """

    def __init__(
        self,
        finder: DesignFinder | None = None,
        builder: BaseDesignBuilder | None = None,
        *,
        formatter: Formatter | None = None,
        dump: str | None = None,
    ):
        self.finder = finder if finder else DesignFinder()
        self.builder = builder if builder else DesignBuilder()
        self.formatter: Formatter = formatter if formatter else text_formatter
        self.dump = dump
        # Qualified "module.Class" name -> design class, deferred
        # instantiable, or design instance, in insertion order.
        self.queue: dict[str, type[jitx.design.Design] | jitx.design.Design] = {}
        # Cleared once any queued build or any import fails.
        self.success = True

    def by_name(self, name: str):
        """Queue the design found under the qualified name ``name``."""
        found = self.finder.find_by_name(name)
        self.add(found)

    def by_file(self, name: str):
        """Queue every design the finder locates in the file at ``name``."""
        for found in self.finder.find_by_file(name):
            self.add(found)

    def add_all(self):
        """Queue every design the finder can locate.

        Designs whose ``__signature__`` has parameters that cannot be bound
        without arguments are skipped, since they cannot be built as-is.
        """
        for found in self.finder.find_all():
            if hasattr(found, "__signature__") and found.__signature__.parameters:
                try:
                    found.__signature__.bind()
                except Exception:
                    # Parameterized design with required arguments: skip it.
                    continue
            self.add(found)

    def add(
        self,
        design: type[jitx.design.Design] | jitx.design.Design,
        *,
        name: str | None = None,
    ):
        """Add ``design`` to the build queue.

        Args:
            design: A Design subclass, a Design instance, or a deferred
                instantiable whose target is a Design subclass.
            name: Queue key. When empty, the design's qualified
                ``module.Class`` name is used and ``design`` is validated.

        Raises:
            ValueError: If ``design`` is not a recognized design (checked only
                when ``name`` is empty), or if ``name`` is already queued.
        """
        if not name:
            import jitx._structural

            if isinstance(design, jitx._structural.Instantiable):
                target = design._instantiable_()
                is_design = isinstance(target, type) and issubclass(
                    target, jitx.design.Design
                )
                if not is_design:
                    raise ValueError(f"Invalid design: {design}")
                # Deferred instantiable of a Design; the cast only placates
                # the type checker.
                design = cast(jitx.design.Design, design)
            elif isinstance(design, jitx.design.Design):
                target = design.__class__
            elif isinstance(design, type) and issubclass(design, jitx.design.Design):
                target = design
            else:
                raise ValueError(f"Invalid design: {design}")
            name = target.__module__ + "." + target.__name__
        if name in self.queue:
            raise ValueError(f"Duplicate design added to build queue: {name}")
        self.queue[name] = design

    def build(self):
        """Build every queued design and report the aggregated results.

        The mapping handed to ``formatter`` has one entry per design name and,
        if the finder recorded import failures, an ``ERRORS`` entry listing
        them. ``success`` is cleared if any build or import failed.
        """
        # NOTE(review): every design receives the same ``dump`` path, so with
        # several queued designs each dump presumably overwrites the previous
        # one -- confirm this is intended.
        aggregate = {}
        for design_name, design in self.queue.items():
            outcome = self.builder.build(
                design, name=design_name, dump=self.dump, formatter=self.formatter
            )
            aggregate[design_name] = outcome
            if ERRORS in outcome or outcome.get(STATUS) != OK:
                self.success = False
        failures = self.finder.exceptions
        if failures:
            self.success = False
            aggregate[ERRORS] = {
                IMPORTFAILED: {module: repr(exc) for module, exc in failures.items()}
            }
        self.formatter(aggregate)

    def list(self):
        """Report the qualified names of all designs the finder can locate.

        Import failures recorded by the finder are reported under ``ERRORS``.
        """
        finder = self.finder
        # Exhaust the search before looking at the failures; DesignFinder
        # records import failures while it searches.
        report = {
            "designs": [
                found.__module__ + "." + found.__name__
                for found in finder.find_all()
            ]
        }
        if finder.exceptions:
            report[ERRORS] = {
                IMPORTFAILED: {
                    module: repr(exc) for module, exc in finder.exceptions.items()
                }
            }
        self.formatter(report)
class DesignFinder:
    """Locate :class:`jitx.design.Design` subclasses in Python source trees.

    Import failures hit while searching are not raised; they are recorded in
    ``exceptions``, keyed by the module name that failed to import.

    Args:
        roots: A directory or a sequence of directories to search. Defaults
            to the current directory.
    """

    def __init__(self, roots: str | Sequence[str] | None = None):
        if isinstance(roots, str):
            roots = (roots,)
        self.roots = roots or (".",)
        # Module name -> exception raised while importing it.
        self.exceptions: dict[str, Exception] = {}

    def find_all(self):
        """Yield every design defined in Python files under each root.

        Each root's ``PyProject(root).jitxtool.exclude`` entries are honoured.
        A directory is skipped when its path relative to the root, or its bare
        name, is listed. A file is skipped when its bare name is listed. Names
        starting with ``_`` or ``.`` are always skipped.
        """
        # NOTE: pkgutil.walk_packages struggles with namespace packages, so
        # Python files are located by walking the file system instead. This
        # does not support things like eggs, which is probably fine.
        import os

        for root in self.roots:
            tool = PyProject(root).jitxtool
            for dirpath, dirnames, filenames in os.walk(root):
                relpath = os.path.relpath(dirpath, root)
                if relpath in tool.exclude:
                    # Excluded directory: do not descend into it, and do not
                    # scan the files it contains directly either.
                    dirnames[:] = []
                    continue
                # Prune in place so os.walk skips private, hidden and
                # excluded subdirectories.
                dirnames[:] = [
                    d
                    for d in dirnames
                    if not d.startswith(("_", ".")) and d not in tool.exclude
                ]
                for filename in filenames:
                    if filename.startswith(("_", ".")) or filename in tool.exclude:
                        continue
                    if not filename.endswith(".py"):
                        continue
                    yield from self.find_by_file(os.path.join(relpath, filename))

    def find_by_name(self, name: str):
        """Return the design class named by a qualified ``module.Class`` name.

        Raises:
            ValueError: If ``name`` has no module part, the module has no such
                attribute, or the attribute is not a Design subclass.

        Exceptions raised while importing the module propagate unchanged.
        """
        import importlib

        modulename, sep, classname = name.rpartition(".")
        if not sep:
            raise ValueError(f"Invalid design name: {name}")
        module = importlib.import_module(modulename)
        design = getattr(module, classname, None)
        if design is None:
            raise ValueError(f"{classname} not found in {modulename}")
        # Check for a class first: issubclass() raises TypeError on non-types.
        if not (isinstance(design, type) and issubclass(design, jitx.design.Design)):
            raise ValueError(f"{classname} in {modulename} is not a Design")
        return design

    def find_by_module(self, name: str):
        """Yield Design subclasses defined (not merely imported) in ``name``.

        ``jitx.design.Design`` and ``jitx.sample.SampleDesign`` themselves are
        never yielded. If the module fails to import, the exception is
        recorded in ``exceptions`` and nothing is yielded.
        """
        import importlib

        try:
            module = importlib.import_module(name)
        except Exception as e:
            self.exceptions[name] = e
            return
        # Only needed once the target module has been imported successfully.
        import jitx.sample

        excluded = (jitx.design.Design, jitx.sample.SampleDesign)
        for attr in dir(module):
            value = getattr(module, attr, None)
            if (
                isinstance(value, type)
                and issubclass(value, jitx.design.Design)
                and value not in excluded
                and value.__module__ == module.__name__
            ):
                yield value

    def find_by_file(self, path: str):
        """Import the Python file at ``path`` and yield the designs it defines.

        The module name is derived from the path components. The longest
        dotted name that imports is preferred, so that a same-named
        ("shadowed") module elsewhere is not imported by accident. Import
        failures are recorded in ``exceptions``.
        """
        import importlib
        import os.path

        head, filename = os.path.split(os.path.normpath(path))
        module, _ = os.path.splitext(filename)
        steps = []
        while head:
            rem, last = os.path.split(head)
            if rem == head:
                break
            head = rem
            steps.append(last)
        steps.reverse()
        steps.append(module)

        candidate = module
        for i, step in enumerate(steps):
            if step == "src":
                # Never treat "src" as a top-level package. It is the
                # conventional source directory of the "src layout". It could
                # technically be a real package, but that is exceedingly
                # unlikely.
                continue
            candidate = ".".join(steps[i:])
            try:
                importlib.import_module(candidate)
            except ModuleNotFoundError as e:
                # Keep trying shorter names only while the missing module is
                # the candidate itself or one of its parent packages; anything
                # else is most likely an error inside the module.
                missing = e.name or ""
                if candidate != missing and not candidate.startswith(missing + "."):
                    self.exceptions[candidate] = e
                    return
            except Exception as e:
                # Some other error during import: this is probably the right
                # module, and it is broken.
                self.exceptions[candidate] = e
                return
            else:
                break
        yield from self.find_by_module(candidate)
class BaseDesignBuilder(ABC):
    """Base class for design builders.

    :meth:`build` instantiates and translates a design, optionally dumps the
    translated data, and hands it to :meth:`_send_design`. Subclasses
    implement :meth:`_send_design` to decide where the design is sent.
    """

    @abstractmethod
    def __init__(self):
        pass

    def build(
        self,
        design: type[jitx.design.Design] | jitx.design.Design,
        *,
        name: str | None = None,
        dump: str | None = None,
        formatter: Formatter,
    ) -> Mapping[str, Any]:
        """Build the design. This is a base class, where the design gets sent
        is determined by the specific subclass implementation.

        Args:
            design: The Design class, a deferred instantiable of one, or an
                already-instantiated design.
            name: Optional name for the design; derived from the design's
                class when omitted.
            dump: Optional file path to dump the design data.
            formatter: Function to format and output results.

        Returns:
            A mapping whose ``"design"`` key holds the design name. On
            instantiation or translation failure it carries ``ERRORS`` (and
            possibly ``HINTS``). Otherwise it is updated with the mapping
            returned by :meth:`_send_design`, or with ``STATUS``/``MESSAGE``
            describing an exception raised while sending.
        """
        import gc

        import jitx._structural
        import jitx._translate.design
        from jitx._instantiation import instantiation

        if not name:
            name = self._design_name(design)
        result: dict[str, Any] = {"design": name}
        with instantiation.activate():
            try:
                if isinstance(design, jitx._structural.Instantiable):
                    cls = design._instantiable_()
                    if not (
                        isinstance(cls, type) and issubclass(cls, jitx.design.Design)
                    ):
                        raise ValueError(f"Not an instantiable design: {cls}")
                    with instantiation.frame():
                        instantiated = design._instantiate_({})
                elif isinstance(design, type):
                    if not issubclass(design, jitx.design.Design):
                        raise ValueError(f"Not a Design subclass: {design}")
                    try:
                        if hasattr(design, "__signature__"):
                            design.__signature__.bind()
                    except TypeError:
                        raise ValueError(
                            "Design is parameterized but no parameters were provided"
                        ) from None
                    with instantiation.frame():
                        instantiated = design()
                else:
                    instantiated = design
                # Explicit check rather than assert: asserts vanish under -O.
                if not isinstance(instantiated, jitx.design.Design):
                    raise ValueError(
                        f"Got non-Design object {jitx._structural.Proxy.type(instantiated)}"
                    )
            except Exception as e:
                logging.exception("Unable to instantiate design %s", design)
                result[ERRORS] = {INSTANTIATIONFAILED: self._format_error_chain(e)}
                return result

            # Try to force detection of lost elements.
            gc.collect()
            # Callbacks during packaging need active instantiation.
            try:
                from jitx.design import DesignContext
                from jitx.substrate import SubstrateContext

                # Some elements introspect the design on access. Looking at
                # general contexts on access is not a good idea in general,
                # since the expected behavior is to use the context that was
                # set at instantiation, not at inspection. Design and substrate
                # seem like safe enough bets though.
                with DesignContext(instantiated), SubstrateContext(
                    instantiated.substrate
                ):
                    packaged = jitx._translate.design.package_design(instantiated)
            except UserCodeException as e:
                result[ERRORS] = {TRANSLATEFAILED: [str(e)]}
                if e.hint:
                    result[HINTS] = [e.hint]
                return result
            except Exception as e:
                logging.exception("Unable to translate design %s", design)
                result[ERRORS] = {TRANSLATEFAILED: [str(e)]}
                return result
            # And again, in case something happened in translation callbacks.
            gc.collect()
            del instantiated

        from google.protobuf.json_format import MessageToDict

        body = MessageToDict(packaged, use_integers_for_enums=True)
        if dump:
            with open(dump, "w", encoding="utf-8") as f:
                formatter(body, file=f)

        def log_message(ob, file=None):
            formatter(ob, sys.stdout if file is None else file)

        try:
            result.update(asyncio.run(self._send_design(name, body, log_message)))
        except Exception as e:
            result[STATUS] = "error"
            result[MESSAGE] = str(e)
        return result

    @staticmethod
    def _design_name(design) -> str:
        """Return the qualified ``module.Class`` name of ``design``.

        Works for Design classes, Design instances and deferred instantiables.
        """
        import jitx._structural

        target = design
        if isinstance(target, jitx._structural.Instantiable):
            target = target._instantiable_()
        if not isinstance(target, type):
            target = target.__class__
        return target.__module__ + "." + target.__name__

    @staticmethod
    def _format_error_chain(exc: BaseException) -> list[str]:
        """Describe ``exc`` and its ``__cause__`` chain, root cause first.

        JITX user-facing exceptions are reported by message only. Other
        exceptions get the location of their innermost traceback frame
        appended.
        """
        chain: list[BaseException] = []
        current: BaseException | None = exc
        while current is not None:
            chain.append(current)
            current = current.__cause__
        chain.reverse()

        def describe(err: BaseException) -> str:
            if isinstance(err, UserCodeException | InstantiationException):
                return str(err)
            tb = err.__traceback__
            if tb is None:
                return str(err)
            while tb.tb_next is not None:
                tb = tb.tb_next
            return f"{err} at {tb.tb_frame.f_code.co_filename}:{tb.tb_lineno}"

        return [describe(err) for err in chain]

    async def _send_design(
        self, name: str, body, formatter: Formatter
    ) -> Mapping[str, Any]:
        """Send the translated design ``body``; implemented by subclasses.

        The returned mapping is merged into the result of :meth:`build`.
        """
        raise NotImplementedError
class DryRunBuilder(BaseDesignBuilder):
    """Builder that instantiates and translates designs but sends nothing.

    The send step always reports success, so a build can be exercised
    without contacting the JITX runtime.
    """

    def __init__(self):
        super().__init__()

    async def _send_design(self, name: str, body, formatter: Formatter):
        # Dry run: nothing is transmitted, so the send step cannot fail.
        report = {STATUS: OK}
        return report
class DesignBuilder(BaseDesignBuilder):
@overload
def __init__(self, *, spec: str | None = None): ...
@overload
def __init__(self, *, uri: str): ...
@overload
def __init__(self, *, port: int, host: str = "localhost"): ...
__client_future: asyncio.Task[PersistentWebSocketClient] | None = None
def __init__(
self,
*,
spec: str | None = None,
uri: str | None = None,
port: int | None = None,
host: str | None = None,
):
super().__init__()
def lazy_setup():
if uri is not None:
_set_websocket_uri(uri=uri)
elif port is not None:
_set_websocket_uri(uri=f"ws://{host or 'localhost'}:{port}")
elif spec is not None:
_set_websocket_uri(file=spec)
else:
_set_websocket_uri()
self.__setup = lambda: None
self.__setup = lazy_setup
async def __client(self):
if self.__client_future is None:
self.__client_future = asyncio.create_task(
PersistentWebSocketClient.create()
)
return await self.__client_future
[docs]
@override
def build(
self,
design: type[jitx.design.Design] | jitx.design.Design,
*,
name: str | None = None,
dump: str | None = None,
formatter: Formatter,
) -> Mapping[str, Any]:
self.__setup()
return super().build(design, name=name, formatter=formatter, dump=dump)
async def _send_design(self, name: str, body, formatter: Formatter):
formatter(f"Running design {name}...")
conversation = None
success = False
client = await self.__client()
await client.establish()
try:
# Create route handler and send message
route = client.root()
message = Message(type="load", body=body, ns="des")
conversation = await route.request(message)
# Iterate over responses
result_message = None
async for envelope in conversation:
message = envelope.message
if envelope.is_terminal():
# Terminal error envelopes are not delivered, they raise an error
match envelope.type:
case "ok":
if "message" in message.body:
result_message = message.body["message"]
success = True
case _:
raise RuntimeError(
f"Unhandled terminal message: {message.type}"
)
else:
match message.type:
case "stdin":
answer = input()
await conversation.send(
message.response({"message": answer})
)
case "stdout":
formatter(message.body["message"])
case _:
raise RuntimeError(
f"Unhandled intermediate response message: {message.type}"
)
finally:
if conversation is not None:
await conversation.close()
if not success:
raise RuntimeError("Did not receive a status report")
if result_message is not None:
return {
STATUS: OK,
MESSAGE: result_message,
}
else:
return {STATUS: OK}