Source code for jitx.transform

"""
Coordinate system transforms
============================

This module provides 2D transforms used in the JITX system to position objects
within the design tree.
"""

from __future__ import annotations
from typing import Literal, Self, overload
from collections.abc import Iterator
from math import radians, cos, sin, atan2, sqrt

from jitx.shapes import Shape, ShapeGeometry

# Integer (x, y) pair; this is the point type accepted and returned by
# transform_grid_point() in this module.
type GridPoint = tuple[int, int]
"""Grid point coordinates as (x, y) integer tuple."""

# General (x, y) pair; used for Transform translations and as the tuple
# operand/result of ``Transform * point``.
type Point = tuple[float | int, float | int]
"""2D point coordinates as (x, y) tuple."""

# (x, y, z) triple. NOTE(review): not referenced by the code visible in this
# module -- presumably exported for use elsewhere; confirm before removing.
type Vec3D = tuple[float | int, float | int, float | int]
"""3D vector as (x, y, z) tuple."""


[docs] class Transform: """Transform represents a translate * rotate * scale (TRS) transform in that order, so scale is applied first, then rotate, and finally translate. Note that non-unit scale will typically not carry over into component and circuit placement, and while it can be used to compute a placement before it's applied, the results may be unexpected, and should be avoided. Also note that because the transforms are internally stored as decomposed values, a non-uniform scale cannot be applied to another transform, as this would result in a shear. Constructing a transform directly is typically only needed in niche cases, and more often than not it is better to use member methods of various objects (in JITX typically called :py:meth:`~jitx.shapes.Shape.at`) to transform the object. Args: translate: Translation as (x, y). rotate: Optional rotation angle in degrees. scale: Optional scale factors as (x, y). If a single value is provided, then the same value is used for both x and y. >>> # Place a component at (10, 20) with 90° rotation >>> xform = Transform((10, 20), rotate=90) >>> # Use helper methods instead >>> xform = Transform.translate(10, 20) * Transform.rotate(90) >>> # Apply transform to a point >>> point = (5, 0) >>> new_point = xform * point >>> print(new_point) (10, 25) """ __slots__ = ("_translate", "_rotate", "_scale") _translate: Point """Translation as (x, y).""" _rotate: float """Rotation angle in degrees.""" _scale: tuple[float, float] """Scale factors as (x_scale, y_scale).""" def __init__( self, translate: Point, rotate: float = 0, scale: float | tuple[float, float] = (1, 1), ): self._translate = translate self._rotate = rotate if not isinstance(scale, tuple): scale = (scale, scale) self._scale = scale def __repr__(self): return f"Transform({self._translate}, {self._rotate}, {self._scale})"
[docs] def clone(self): """Create a copy of this transform. Returns: New Transform with identical translation, rotation, and scale. """ return self.__class__(self._translate, self._rotate, self._scale)
@property def trs(self): """Get the transform components as a tuple (translation, rotation, scale).""" return self._translate, self._rotate, self._scale @property def translation(self) -> Point: """The transform's translation as (x, y).""" return self._translate @property def rotation(self) -> float: """The transform's rotation angle in degrees.""" return self._rotate def __eq__(self, other): return ( isinstance(other, Transform) and self._translate == other._translate and self._rotate == other._rotate and self._scale == other._scale ) @overload def __mul__(self, other: Transform) -> Transform: ... @overload def __mul__(self, other: Point) -> Point: ... @overload def __mul__(self, other: Vec2D) -> Vec2D: ... @overload def __mul__[T: ShapeGeometry](self, other: Shape[T]) -> Shape[T]: ... def __mul__( self, other: Transform | Point | Shape | Vec2D ) -> Transform | Point | Shape | Vec2D: """Apply this transform to another transform, point, vector, or shape. Transforms can be composed by multiplying them together. The multiplication order follows matrix convention: `T1 * T2` means apply T2 first, then T1. Args: other: Object to transform (Transform, Point, Vec2D, or Shape). Returns: Transformed object of the same type. Raises: ValueError: If trying to apply non-uniform scale to a rotated transform. 
>>> # Compose transforms >>> t1 = Transform.translate(10, 0) >>> t2 = Transform.rotate(90) >>> combined = t1 * t2 # Rotate first, then translate >>> # Transform a point >>> point = (5, 0) >>> new_point = t1 * point # (15, 0) >>> # Transform a shape >>> circle = Circle(radius=5).at(0, 0) >>> moved_circle = t1 * circle """ # translate * rotate * scale * object if isinstance(other, Transform): return self.__apply_to_transform(other) elif isinstance(other, tuple): if len(other) == 2: return self.__apply_to_point(other) else: return NotImplemented elif isinstance(other, Vec2D): return self.__apply_to_vec2d(other) elif isinstance(other, Shape): if not other.transform: return Shape(other.geometry, self.clone()) else: return Shape(other.geometry, self.__apply_to_transform(other.transform)) else: return NotImplemented def __imul__(self, other: Transform): xform = self.__apply_to_transform(other) self._translate = xform._translate self._rotate = xform._rotate self._scale = xform._scale def __apply_to_transform(self, xf: Transform) -> Transform: sx, sy = self._scale osx, osy = xf._scale r = xf._rotate if sx != sy and abs(sx) != abs(sy) and r != 0: raise ValueError( "Unable to apply a non-uniform scale to an existing transform" ) if sx * sy < 0: r = -r return self.__class__( self.__apply_to_point(xf._translate), self._rotate + r, (sx * osx, sy * osy), )._post_mul(self, xf) def _post_mul(self, left: Transform, right: Transform): return right._post_rmul(self, left) def _post_rmul(self, result: Transform, left: Transform): return result def __apply_to_point(self, pt: Point) -> Point: x, y = pt tx, ty = self._translate r = self._rotate sx, sy = self._scale if r < 0: r += 360 if r == 0: return (tx + sx * x, ty + sy * y) if r == 90: return (tx - sy * y, ty + sx * x) if r == 180: return (tx - sx * x, ty - sy * y) if r == 270: return (tx + sy * y, ty - sx * x) a = radians(r) cosa = cos(a) sina = sin(a) return (tx + sx * x * cosa - sy * y * sina, ty + sx * x * sina + sy * y * cosa) def 
__apply_to_vec2d(self, v: Vec2D) -> Point: # same as point, but without the translation x, y = v.x, v.y r = self._rotate sx, sy = self._scale if r < 0: r += 360 if r == 0: return (sx * x, +sy * y) if r == 90: return (-sy * y, +sx * x) if r == 180: return (-sx * x, -sy * y) if r == 270: return (+sy * y, -sx * x) a = radians(r) cosa = cos(a) sina = sin(a) return (sx * x * cosa - sy * y * sina, sx * x * sina + sy * y * cosa)
[docs] def inverse(self): """Compute the inverse transform. The inverse transform undoes the effect of this transform. Applying a transform followed by its inverse returns the original value. Returns: New Transform that is the inverse of this transform. >>> t = Transform.translate(10, 20) * Transform.rotate(45) >>> t_inv = t.inverse() >>> point = (5, 5) >>> # Applying transform then inverse returns original >>> result = t_inv * (t * point) >>> # result ≈ (5, 5) """ x, y = self._translate r = self._rotate sx, sy = self._scale return ( Transform.scale(1 / sx, 1 / sy) * Transform.rotate(-r) * Transform.translate(-x, -y) )
def __invert__(self): """Compute the inverse transform using ~ operator. Returns: Inverse transform (same as calling :py:meth:`inverse`). >>> t = Transform.translate(10, 20) >>> t_inv = ~t # Equivalent to t.inverse() """ return self.inverse() @overload def matrix2x3( self, *, row_major: Literal[False] = False, flat: Literal[False] = False ) -> tuple[tuple[float, float], tuple[float, float], tuple[float, float]]: ... @overload def matrix2x3( self, *, row_major: Literal[True], flat: Literal[False] = False ) -> tuple[tuple[float, float, float], tuple[float, float, float]]: ... @overload def matrix2x3( self, *, row_major: bool = False, flat: Literal[True] ) -> tuple[float, float, float, float, float, float]: ...
[docs] def matrix2x3(self, *, row_major=False, flat=False): """Convert transform to a 2x3 affine transformation matrix. Returns the matrix representation of this transform, which can be used with graphics libraries or other systems that expect matrix form. Args: row_major: If True, return in row-major order; otherwise column-major (default). flat: If True, return as flat tuple; otherwise as nested tuples (default). Returns: Matrix in the requested format (nested tuples or flat tuple). >>> t = Transform.translate(10, 5) * Transform.rotate(90) >>> # Get column-major nested format (default) >>> m = t.matrix2x3() >>> # m = ((0, 1), (-1, 0), (10, 5)) >>> # Get row-major flat format >>> m = t.matrix2x3(row_major=True, flat=True) >>> # m = (0, -1, 10, 1, 0, 5) """ tx, ty = self._translate alpha = radians(self._rotate) ca = cos(alpha) sa = sin(alpha) sx, sy = self._scale m11 = ca * sx m21 = sa * sx m12 = -sa * sy m22 = ca * sy m13 = tx m23 = ty if not row_major: if flat: # fmt: off return ( m11, m21, m12, m22, m13, m23, ) # fmt: on else: # fmt: off return ( (m11, m21), (m12, m22), (m13, m23), ) # fmt: on else: if flat: # fmt: off return ( m11, m12, m13, m21, m22, m23, ) # fmt: on else: # fmt: off return ( (m11, m12, m13), (m21, m22, m23), )
# fmt: on @overload def matrix3x3( self, *, row_major: bool = False, flat: Literal[False] = False ) -> tuple[ tuple[float, float, float], tuple[float, float, float], tuple[float, float, float], ]: ... @overload def matrix3x3( self, *, row_major: bool = False, flat: Literal[True] ) -> tuple[float, float, float, float, float, float, float, float, float]: ...
[docs] def matrix3x3(self, *, row_major=False, flat=False): """Convert transform to a 3x3 homogeneous transformation matrix. Returns the matrix representation as a 3x3 homogeneous matrix with the bottom row as [0, 0, 1], suitable for use with homogeneous coordinates. Args: row_major: If True, return in row-major order; otherwise column-major (default). flat: If True, return as flat tuple; otherwise as nested tuples (default). Returns: 3x3 matrix in the requested format (nested tuples or flat tuple). >>> t = Transform.translate(10, 5) >>> m = t.matrix3x3() >>> # Column-major: ((1, 0, 0), (0, 1, 0), (10, 5, 1)) """ m11, m21, m12, m22, m13, m23 = self.matrix2x3(flat=True) if not row_major: if flat: # fmt: off return ( m11, m21, 0.0, m12, m22, 0.0, m13, m23, 1.0, ) # fmt: on else: # fmt: off return ( (m11, m21, 0.0), (m12, m22, 0.0), (m13, m23, 1.0), ) # fmt: on else: if flat: # fmt: off return ( m11, m12, m13, m21, m22, m23, 0.0, 0.0, 1.0, ) # fmt: on else: # fmt: off return ( (m11, m12, m13), (m21, m22, m23), (0.0, 0.0, 1.0), )
# fmt: on @overload @classmethod def translate(cls, x: float, y: float, /) -> Self: ... @overload @classmethod def translate(cls, vector: Point, /) -> Self: ...
[docs] @classmethod def translate(cls, x: float | Point, y: float | None = None, /): """Create a translation-only transform. Args: x: X translation or a (x, y) point. y: Y translation (not used if x is a point). Returns: New Transform with only translation (no rotation or scaling). >>> # Translate by (10, 20) >>> t = Transform.translate(10, 20) >>> # Alternative using tuple >>> t = Transform.translate((10, 20)) >>> # Use in component placement >>> component.at(Transform.translate(5, 10)) """ if isinstance(x, tuple): return cls(x, 0, (1, 1)) else: assert isinstance(y, float | int) return cls((x, y), 0, (1, 1))
[docs] @classmethod def rotate(cls, angle: float): """Create a rotation-only transform. Args: angle: Rotation angle in degrees (counter-clockwise). Returns: New Transform with only rotation (no translation or scaling). >>> # Rotate 90 degrees counter-clockwise >>> t = Transform.rotate(90) >>> # Combine with translation >>> t = Transform.translate(10, 0) * Transform.rotate(45) """ return cls((0, 0), angle, (1, 1))
@overload @classmethod def scale(cls, x: float, y: float, /) -> Self: ... @overload @classmethod def scale(cls, uniform: float, /) -> Self: ...
[docs] @classmethod def scale(cls, x: float, y: float | None = None, /): """Create a scale-only transform. Args: x: X scale factor, or uniform scale if y is not provided. y: Y scale factor (optional, defaults to x for uniform scaling). Returns: New Transform with only scaling (no translation or rotation). Note: Non-uniform scaling should be used with caution as it may not carry over properly to component placement. >>> # Uniform scale by 2x >>> t = Transform.scale(2) >>> # Non-uniform scale >>> t = Transform.scale(2, 1.5) # 2x in X, 1.5x in Y """ if y is None: y = x return cls((0, 0), 0, (x, y))
[docs] @classmethod def identity(cls): """Create an identity transform (no transformation). Returns: New Transform that applies no transformation. >>> t = Transform.identity() >>> point = (5, 10) >>> result = t * point >>> # result == (5, 10) """ return cls((0, 0), 0, (1, 1))
class ImmutableTransform(Transform):
    """A :class:`Transform` whose attributes cannot be changed after
    construction.

    Attribute assignment and deletion raise ``ValueError`` once ``__init__``
    has finished. Because in-place composition (``*=``) assigns to the
    stored components, it raises ``ValueError`` as well; use ``a * b`` to
    obtain a new transform instead.

    Args:
        translate: Translation as (x, y).
        rotate: Optional rotation angle in degrees.
        scale: Optional scale factors as (x, y), or a single uniform factor.
    """

    # Class-level default so __setattr__ can consult the flag while the
    # instance is still being initialised; the name is mangled to
    # _ImmutableTransform__frozen.
    __frozen = False

    def __init__(
        self,
        translate: Point,
        rotate: float = 0,
        scale: float | tuple[float, float] = (1, 1),
    ):
        super().__init__(translate, rotate, scale)
        # Last write that is allowed through: flips the per-instance flag.
        self.__frozen = True

    def __setattr__(self, attr, value):
        if self.__frozen:
            raise ValueError("Transform is immutable")
        super().__setattr__(attr, value)

    def __delattr__(self, attr):
        # Deleting a component would mutate the transform just as much as
        # assigning one, so it is rejected the same way.
        if self.__frozen:
            raise ValueError("Transform is immutable")
        super().__delattr__(attr)
# Module-level identity transform built through ImmutableTransform, so
# attribute assignment on this shared instance raises ValueError.
IDENTITY = ImmutableTransform.identity()
"""Immutable identity transform constant."""
[docs] class Vec2D: """A basic 2D vector class. Used internally for some calculations, not typically used in public APIs.""" __slots__ = ("_x", "_y") _x: float _y: float def __init__(self, x: float, y: float): self._x = x self._y = y def __repr__(self): return f"Vec2D({self._x}, {self._y})" @overload def __add__(self, other: Vec2D) -> Vec2D: ... @overload def __add__(self, other: Point) -> Point: ... def __add__(self, other) -> Vec2D | Point: # vector + vector = vector # point + vector = point if isinstance(other, tuple): return (self._x + other[0], self._y + other[1]) elif isinstance(other, Vec2D): return Vec2D(self._x + other._x, self._y + other._y) else: return NotImplemented @overload def __radd__(self, other: Vec2D) -> Vec2D: ... @overload def __radd__(self, other: Point) -> Point: ... def __radd__(self, other) -> Vec2D | Point: return self.__add__(other) def __sub__(self, other) -> Vec2D: # vector - vector = vector # vector - point is undefined # point - point = vector, but can't be implemented since point is a tuple if isinstance(other, Vec2D): return Vec2D(self._x - other._x, self._y - other._y) else: return NotImplemented @overload def __rsub__(self, other: Vec2D) -> Vec2D: ... @overload def __rsub__(self, other: Point) -> Point: ... 
def __rsub__(self, other) -> Vec2D | Point: # point - vector = (-vector) + point = point return (-self).__add__(other) def __mul__(self, other: float) -> Vec2D: return Vec2D(self._x * other, self._y * other) def __rmul__(self, other: float) -> Vec2D: return Vec2D(self._x * other, self._y * other) def __truediv__(self, other: float) -> Vec2D: return Vec2D(self._x / other, self._y / other) def __floordiv__(self, other: float) -> Vec2D: return Vec2D(self._x // other, self._y // other) def __neg__(self) -> Vec2D: return Vec2D(-self._x, -self._y) def __pos__(self) -> Vec2D: return Vec2D(+self._x, +self._y) def __abs__(self) -> float: return sqrt(self._x**2 + self._y**2) def __getitem__(self, index: Literal[0, 1]) -> float: if index == 0: return self._x elif index == 1: return self._y else: raise IndexError("Vec2D index out of range") def __iter__(self) -> Iterator[float]: return iter((self._x, self._y)) def __eq__(self, other) -> bool: if not isinstance(other, Vec2D): return False return self._x == other._x and self._y == other._y def __lt__(self, other: Vec2D) -> bool: return self._x < other._x and self._y < other._y def __le__(self, other: Vec2D) -> bool: return self._x <= other._x and self._y <= other._y def __gt__(self, other: Vec2D) -> bool: return self._x > other._x and self._y > other._y def __ge__(self, other: Vec2D) -> bool: return self._x >= other._x and self._y >= other._y def __hash__(self) -> int: return hash((self._x, self._y)) @property def x(self) -> float: return self._x @property def y(self): return self._y @property def xy(self) -> tuple[float, float]: return self._x, self._y @property def length(self) -> float: return abs(self)
[docs] def normalized(self) -> Vec2D: return self / abs(self)
[docs] def dot(self, other: Vec2D) -> float: return self.x * other.x + self.y * other.y
[docs] def cross(self, other: Vec2D) -> float: """Compute the magnitude of the cross product.""" return self._x * other._y - self._y * other._x
[docs] def angle(self) -> float: """Compute the angle in radians counter-clockwise from the +X axis of this vector""" return atan2(self.y, self.x)
def transform_grid_point(xform: Transform, pt: GridPoint) -> GridPoint:
    """Apply ``xform`` to an integer grid point, keeping it on the grid.

    Grid points are the integer coordinates used for schematic symbol
    positioning. The point is transformed like any other point, and the
    result is accepted only if both coordinates are whole numbers.

    Args:
        xform: Transform to apply.
        pt: Grid point as (x, y) integer tuple.

    Returns:
        The transformed point, with both coordinates converted to ``int``.

    Raises:
        ValueError: If either transformed coordinate is not a whole number.

    Examples:
        >>> t = Transform.translate(10, 5)
        >>> grid_pt = (3, 4)
        >>> new_pt = transform_grid_point(t, grid_pt)
        >>> # new_pt = (13, 9)
    """
    moved = xform * (pt[0], pt[1])
    new_x, new_y = moved
    # Only exact whole numbers count as "on grid"; nothing is rounded.
    on_grid = new_x.is_integer() and new_y.is_integer()
    if on_grid:
        return (int(new_x), int(new_y))
    raise ValueError("Transform result is not a grid point")