
Optics API Reference

The deeplens.optics module contains the differentiable lens models, optical surfaces, light representations, and image simulation utilities.


Base Classes

Base class for all optical objects. Provides device transfer, dtype conversion, and cloning by introspecting instance tensors.

deeplens.optics.DeepObj

DeepObj(dtype=None)

Base class for all differentiable optical objects in DeepLens.

Provides device management, dtype conversion, and deep-copy support via automatic introspection over instance tensors and nested DeepObj sub-objects. All lens, surface, material, ray, and wave objects inherit from this class.

Attributes:

Name Type Description
dtype dtype

Current floating-point dtype of all owned tensors.

device str | torch.device

Current compute device (set by to()).

Source code in deeplens/optics/base.py
def __init__(self, dtype=None):
    self.dtype = torch.get_default_dtype() if dtype is None else dtype

__str__

__str__()

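Return a multi-line summary of the object's attributes (dict, OrderedDict, and set attributes are skipped). Called by print() and str().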

Source code in deeplens/optics/base.py
def __str__(self):
    """Called when using print() and str()"""
    lines = [self.__class__.__name__ + ":"]
    for key, val in vars(self).items():
        if val.__class__.__name__ in ["list", "tuple"]:
            for i, v in enumerate(val):
                lines += "{}[{}]: {}".format(key, i, v).split("\n")
        elif val.__class__.__name__ in ["dict", "OrderedDict", "set"]:
            pass
        else:
            lines += "{}: {}".format(key, val).split("\n")

    return "\n    ".join(lines)

__call__

__call__(inp)

Call the forward function.

Source code in deeplens/optics/base.py
def __call__(self, inp):
    """Call the forward function."""
    return self.forward(inp)

clone

clone()

Clone a DeepObj object.

Source code in deeplens/optics/base.py
def clone(self):
    """Clone a DeepObj object."""
    return copy.deepcopy(self)

to

to(device)

Move all tensors and nested objects to device.

Recursively walks over every instance attribute and moves tensors, nn.Module sub-objects, and nested DeepObj objects to the requested device.

Parameters:

Name Type Description Default
device

Target device, e.g. "cuda", "cpu", or a torch.device instance.

required

Returns:

Name Type Description
DeepObj

self (for chaining).

Example

>>> lens = GeoLens(filename="lens.json")
>>> lens.to("cuda")  # move all tensors to GPU

Source code in deeplens/optics/base.py
def to(self, device):
    """Move all tensors and nested objects to *device*.

    Recursively walks over every instance attribute and moves tensors,
    ``nn.Module`` sub-objects, and nested ``DeepObj`` objects to the
    requested device.

    Args:
        device: Target device, e.g. ``"cuda"``, ``"cpu"``, or a
            ``torch.device`` instance.

    Returns:
        DeepObj: ``self`` (for chaining).

    Example:
        >>> lens = GeoLens(filename="lens.json")
        >>> lens.to("cuda")  # move all tensors to GPU
    """
    self.device = device

    for key, val in vars(self).items():
        if torch.is_tensor(val):
            setattr(self, key, val.to(device))
        elif isinstance(val, nn.Module):
            val.to(device)
        elif issubclass(type(val), DeepObj):
            val.to(device)
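        elif val.__class__.__name__ in ("list", "tuple"):
            # DeepObj.to() returns self, so nested objects map the same way as tensors
            moved = [
                v.to(device) if torch.is_tensor(v) or issubclass(type(v), DeepObj) else v
                for v in val
            ]
            if isinstance(val, list):
                val[:] = moved  # update lists in place
            else:
                setattr(self, key, tuple(moved))  # tuples are immutable, so rebuild them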
    return self

astype

astype(dtype)

Convert all floating-point tensors to dtype.

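Also calls torch.set_default_dtype(dtype) when the global default differs, so that subsequent tensor creation uses the same precision. Note that this changes PyTorch's default dtype for the whole process.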
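The attribute walk that astype() performs can be sketched without PyTorch. to() uses the same pattern for devices. In the hypothetical example below, NumPy arrays stand in for tensors. ToyObj, ToySurface, and ToyLens are illustrative classes, not DeepLens API. Unlike astype(), the sketch does not change any global default dtype.

```python
import numpy as np

FLOAT_DTYPES = (np.float16, np.float32, np.float64)


class ToyObj:
    """Hypothetical stand-in for DeepObj that stores NumPy arrays instead of tensors."""

    def astype(self, dtype):
        if dtype is None:
            return self  # no-op, as in DeepObj.astype
        assert dtype in FLOAT_DTYPES, f"Data type {dtype} is not supported."
        self.dtype = dtype
        for key, val in vars(self).items():
            if isinstance(val, np.ndarray) and val.dtype in FLOAT_DTYPES:
                setattr(self, key, val.astype(dtype))  # convert floating-point arrays only
            elif isinstance(val, ToyObj):
                val.astype(dtype)  # recurse into nested objects
            elif isinstance(val, list):
                for i, v in enumerate(val):
                    if isinstance(v, np.ndarray) and v.dtype in FLOAT_DTYPES:
                        val[i] = v.astype(dtype)
                    elif isinstance(v, ToyObj):
                        v.astype(dtype)
        return self


class ToySurface(ToyObj):
    def __init__(self, c):
        self.c = np.array([c], dtype=np.float32)  # float array: converted
        self.idx = np.array([0, 1])                # integer array: left untouched


class ToyLens(ToyObj):
    def __init__(self):
        self.d = np.array([1.0, 2.0], dtype=np.float32)
        self.surfaces = [ToySurface(0.01), ToySurface(-0.02)]


lens = ToyLens().astype(np.float64)
```

Integer arrays are skipped, just as astype() only converts tensors whose dtype is already floating point.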

Parameters:

Name Type Description Default
dtype dtype

Target floating-point dtype. Must be one of torch.float16, torch.float32, or torch.float64. Passing None makes the call a no-op.

required

Returns:

Name Type Description
DeepObj

self (for chaining).

Raises:

Type Description
AssertionError

If dtype is not a recognised floating-point dtype.

Example

>>> lens = GeoLens(filename="lens.json")
>>> lens.astype(torch.float64)  # switch to double precision

Source code in deeplens/optics/base.py
def astype(self, dtype):
    """Convert all floating-point tensors to *dtype*.

    Also calls ``torch.set_default_dtype(dtype)`` so that subsequent
    tensor creation uses the same precision.

    Args:
        dtype (torch.dtype): Target floating-point dtype.  Must be one of
            ``torch.float16``, ``torch.float32``, or ``torch.float64``.
            Pass ``None`` to be a no-op.

    Returns:
        DeepObj: ``self`` (for chaining).

    Raises:
        AssertionError: If *dtype* is not a recognised floating-point dtype.

    Example:
        >>> lens = GeoLens(filename="lens.json")
        >>> lens.astype(torch.float64)  # switch to double precision
    """
    if dtype is None:
        return self

    dtype_ls = [torch.float16, torch.float32, torch.float64]
    assert dtype in dtype_ls, f"Data type {dtype} is not supported."

    if torch.get_default_dtype() != dtype:
        torch.set_default_dtype(dtype)
        print(f"Set {dtype} as default torch dtype.")

    self.dtype = dtype
    for key, val in vars(self).items():
        if torch.is_tensor(val) and val.dtype in dtype_ls:
            setattr(self, key, val.to(dtype))
        elif issubclass(type(val), DeepObj):
            val.astype(dtype)
        elif issubclass(type(val), list):
            for i, v in enumerate(val):
                if torch.is_tensor(v) and v.dtype in dtype_ls:
                    val[i] = v.to(dtype)
                elif issubclass(type(v), DeepObj):
                    v.astype(dtype)
    return self

Abstract base class for all lens types. Defines the shared interface: psf(), psf_rgb(), render(), etc.

deeplens.optics.Lens

Lens(dtype=torch.float32, device=None)

Bases: DeepObj

Initialize a lens.

Parameters:

Name Type Description Default
dtype dtype

Data type. Defaults to torch.float32.

float32
device str

Device on which to run the lens. Defaults to None, in which case init_device() chooses the device.

None
Source code in deeplens/optics/lens.py
def __init__(self, dtype=torch.float32, device=None):
    """Initialize a lens class.

    Args:
        dtype (torch.dtype, optional): Data type. Defaults to torch.float32.
        device (str, optional): Device to run the lens. Defaults to None.
    """
    # Lens device
    if device is None:
        self.device = init_device()
    else:
        self.device = torch.device(device)

    # Lens default dtype
    self.dtype = dtype

read_lens_json

read_lens_json(filename)

Read a lens from a JSON file. The base class raises NotImplementedError; subclasses must override it.

Source code in deeplens/optics/lens.py
def read_lens_json(self, filename):
    """Read lens from a json file."""
    raise NotImplementedError

write_lens_json

write_lens_json(filename)

Write the lens to a JSON file. The base class raises NotImplementedError; subclasses must override it.

Source code in deeplens/optics/lens.py
def write_lens_json(self, filename):
    """Write lens to a json file."""
    raise NotImplementedError

set_sensor

set_sensor(sensor_size, sensor_res)

Set sensor size and resolution.

Parameters:

Name Type Description Default
sensor_size tuple

Sensor size (w, h) in [mm].

required
sensor_res tuple

Sensor resolution (W, H) in [pixels].

required
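As a worked example of what set_sensor() derives, consider a hypothetical 8 mm × 6 mm sensor read out at 4000 × 3000 pixels. The function below re-implements the same arithmetic in plain Python. It is an illustration, not the DeepLens method.

The sketch checks the aspect ratio with math.isclose. The source listing instead compares the two products with ==, which may reject sizes whose products differ only by floating-point rounding.

```python
import math


def sensor_geometry(sensor_size, sensor_res):
    """Return (pixel_size, r_sensor) in mm for sensor_size=(w, h) mm and sensor_res=(W, H) px."""
    w, h = sensor_size
    W, H = sensor_res
    # Aspect ratios must agree (w / h == W / H), checked as w * H == h * W with a tolerance
    if not math.isclose(w * H, h * W, rel_tol=1e-9):
        raise ValueError("Sensor resolution aspect ratio does not match sensor size aspect ratio.")
    pixel_size = w / W                # mm per pixel (square pixels)
    r_sensor = math.hypot(w, h) / 2   # half of the sensor diagonal, in mm
    return pixel_size, r_sensor


pixel_size, r_sensor = sensor_geometry((8.0, 6.0), (4000, 3000))
# pixel_size is about 0.002 mm (2 µm); r_sensor is 5.0 mm
```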
Source code in deeplens/optics/lens.py
def set_sensor(self, sensor_size, sensor_res):
    """Set sensor size and resolution.

    Args:
        sensor_size (tuple): Sensor size (w, h) in [mm].
        sensor_res (tuple): Sensor resolution (W, H) in [pixels].
    """
    assert sensor_size[0] * sensor_res[1] == sensor_size[1] * sensor_res[0], (
        "Sensor resolution aspect ratio does not match sensor size aspect ratio."
    )
    self.sensor_size = sensor_size
    self.sensor_res = sensor_res
    self.pixel_size = self.sensor_size[0] / self.sensor_res[0]
    self.r_sensor = float(np.sqrt(sensor_size[0] ** 2 + sensor_size[1] ** 2)) / 2
    self.calc_fov()

set_sensor_res

set_sensor_res(sensor_res)

Set sensor resolution (and aspect ratio) while keeping sensor radius unchanged.

Parameters:

Name Type Description Default
sensor_res tuple

Sensor resolution (W, H) in [pixels].

required
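set_sensor_res() holds the half-diagonal r_sensor fixed and reshapes the sensor to the new aspect ratio. The plain-Python sketch below mirrors that arithmetic for a hypothetical sensor with r_sensor = 5 mm switched to 1920 × 1080 pixels. resize_sensor is an illustrative name, not DeepLens API.

```python
import math


def resize_sensor(r_sensor, sensor_res):
    """Return (sensor_size, pixel_size) in mm when the half-diagonal r_sensor (mm) is held fixed."""
    W, H = sensor_res
    diam_res = math.hypot(W, H)  # sensor diagonal measured in pixels
    # Give each side the same share of the physical diagonal (2 * r_sensor) as it has in pixels
    sensor_size = (2 * r_sensor * W / diam_res, 2 * r_sensor * H / diam_res)
    pixel_size = sensor_size[0] / W
    return sensor_size, pixel_size


sensor_size, pixel_size = resize_sensor(5.0, (1920, 1080))
```

Both sides are scaled by the same factor, so pixels stay square. calc_fov() derives the diagonal FoV from r_sensor, so the diagonal FoV is unchanged.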
Source code in deeplens/optics/lens.py
def set_sensor_res(self, sensor_res):
    """Set sensor resolution (and aspect ratio) while keeping sensor radius unchanged.

    Args:
        sensor_res (tuple): Sensor resolution (W, H) in [pixels].
    """
    # Change sensor resolution
    self.sensor_res = sensor_res

    # Change sensor size (r_sensor is fixed)
    diam_res = float(np.sqrt(self.sensor_res[0] ** 2 + self.sensor_res[1] ** 2))
    self.sensor_size = (
        2 * self.r_sensor * self.sensor_res[0] / diam_res,
        2 * self.r_sensor * self.sensor_res[1] / diam_res,
    )
    self.pixel_size = self.sensor_size[0] / self.sensor_res[0]
    self.calc_fov()

calc_fov

calc_fov()

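Compute the horizontal, vertical, diagonal, and half-diagonal (radius) fields of view of the lens, in radians. Does nothing if the lens has no foclen attribute yet.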

Reference

[1] https://en.wikipedia.org/wiki/Angle_of_view_(photography)
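Here is a worked example of the angle-of-view formula FoV = 2 atan(d / (2 f)), where d is the sensor dimension along the direction of interest. The sketch assumes sensor_size = (w, h), as documented for set_sensor(). It follows the standard convention: horizontal FoV from the width, vertical FoV from the height. The 8 mm × 6 mm sensor and 4 mm focal length are hypothetical.

```python
import math


def fov(sensor_size, foclen):
    """Return (hfov, vfov, dfov, rfov) in radians for sensor_size=(w, h) mm and focal length foclen mm."""
    w, h = sensor_size
    r_sensor = math.hypot(w, h) / 2
    hfov = 2 * math.atan(w / 2 / foclen)     # horizontal: uses sensor width
    vfov = 2 * math.atan(h / 2 / foclen)     # vertical: uses sensor height
    dfov = 2 * math.atan(r_sensor / foclen)  # diagonal
    rfov = dfov / 2                          # half-diagonal (radius) FoV
    return hfov, vfov, dfov, rfov


hfov, vfov, dfov, rfov = fov((8.0, 6.0), 4.0)
print([round(math.degrees(a), 1) for a in (hfov, vfov, dfov)])  # [90.0, 73.7, 102.7]
```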

Source code in deeplens/optics/lens.py
@torch.no_grad()
def calc_fov(self):
    """Compute FoV (radian) of the lens.

    Reference:
        [1] https://en.wikipedia.org/wiki/Angle_of_view_(photography)
    """
    if not hasattr(self, "foclen"):
        return

    self.hfov = 2 * float(np.arctan(self.sensor_size[0] / 2 / self.foclen))  # width
    self.vfov = 2 * float(np.arctan(self.sensor_size[1] / 2 / self.foclen))  # height
    self.dfov = 2 * float(np.arctan(self.r_sensor / self.foclen))
    self.rfov = self.dfov / 2  # radius (half diagonal) FoV

psf

psf(points, wvln=DEFAULT_WAVE, ks=PSF_KS, **kwargs)

Compute the monochromatic PSF for one or more point sources.

Subclasses must override this method with a differentiable implementation. Three computation models are common in practice: geometric ray binning, coherent ray-wave, and Huygens spherical-wave integration.
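Of these models, geometric ray binning is the simplest to picture:

- trace many rays from the point source;
- record where they land on the sensor;
- histogram the hits into a ks × ks kernel.

The NumPy sketch below covers only the last step and takes already-traced hit positions as input. Random Gaussian samples stand in for real ray-tracing output. It is a hypothetical illustration, not the DeepLens implementation.

```python
import numpy as np


def bin_rays_to_psf(hits_xy, pixel_size, ks):
    """Histogram sensor-plane hits (mm, shape [spp, 2]) into a ks x ks PSF centred on their centroid."""
    center = hits_xy.mean(axis=0)  # recenter the PSF on the ray centroid
    half = ks * pixel_size / 2
    edges_x = np.linspace(center[0] - half, center[0] + half, ks + 1)
    edges_y = np.linspace(center[1] - half, center[1] + half, ks + 1)
    # Rows index y and columns index x, so pass y first to histogram2d
    psf, _, _ = np.histogram2d(hits_xy[:, 1], hits_xy[:, 0], bins=[edges_y, edges_x])
    total = psf.sum()
    return psf / total if total > 0 else psf  # normalise to unit energy


rng = np.random.default_rng(0)
hits = rng.normal(scale=0.004, size=(100_000, 2))  # stand-in for traced hits, sigma = 4 µm
psf = bin_rays_to_psf(hits, pixel_size=0.002, ks=64)
```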

Parameters:

Name Type Description Default
points Tensor

Point source coordinates, shape [N, 3] or [3]. x, y are normalised to [-1, 1] (relative to the sensor half-diagonal); z is depth in mm (must be negative, i.e. in front of the lens).

required
wvln float

Wavelength in micrometers. Defaults to DEFAULT_WAVE (0.587 µm, d-line).

DEFAULT_WAVE
ks int

Output PSF kernel size in pixels. Defaults to PSF_KS (64).

PSF_KS
**kwargs

Additional keyword arguments forwarded to the underlying PSF computation (e.g. spp, model, recenter).

{}

Returns:

Type Description

torch.Tensor: PSF intensity map, shape [ks, ks] for a single point or [N, ks, ks] for a batch.

Raises:

Type Description
NotImplementedError

This base implementation must be overridden.

Notes

The method is differentiable with respect to all optimisable lens parameters so it can be used directly inside a training loop.

Example

>>> point = torch.tensor([0.0, 0.0, -10000.0])
>>> psf = lens.psf(points=point, ks=64, model="geometric")
>>> print(psf.shape)  # torch.Size([64, 64])

Source code in deeplens/optics/lens.py
def psf(self, points, wvln=DEFAULT_WAVE, ks=PSF_KS, **kwargs):
    """Compute the monochromatic PSF for one or more point sources.

    Subclasses must override this method with a differentiable
    implementation.  Three computation models are common in practice:
    geometric ray binning, coherent ray-wave, and Huygens spherical-wave
    integration.

    Args:
        points (torch.Tensor): Point source coordinates, shape ``[N, 3]``
            or ``[3]``.  ``x, y`` are normalised to ``[-1, 1]``
            (relative to the sensor half-diagonal); ``z`` is depth in mm
            (must be negative, i.e. in front of the lens).
        wvln (float, optional): Wavelength in micrometers.  Defaults to
            ``DEFAULT_WAVE`` (0.587 µm, d-line).
        ks (int, optional): Output PSF kernel size in pixels.  Defaults
            to ``PSF_KS`` (64).
        **kwargs: Additional keyword arguments forwarded to the underlying
            PSF computation (e.g. ``spp``, ``model``, ``recenter``).

    Returns:
        torch.Tensor: PSF intensity map, shape ``[ks, ks]`` for a single
        point or ``[N, ks, ks]`` for a batch.

    Raises:
        NotImplementedError: This base implementation must be overridden.

    Notes:
        The method is differentiable with respect to all optimisable lens
        parameters so it can be used directly inside a training loop.

    Example:
        >>> point = torch.tensor([0.0, 0.0, -10000.0])
        >>> psf = lens.psf(points=point, ks=64, model="geometric")
        >>> print(psf.shape)  # torch.Size([64, 64])
    """
    raise NotImplementedError

psf_rgb

psf_rgb(points, ks=PSF_KS, **kwargs)

Compute the RGB (tri-chromatic) PSF by stacking three wavelength calls.

Calls psf() three times for the RGB primary wavelengths defined in WAVE_RGB and stacks the results along the channel axis.
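Stacking along dim=-3 lets psf_rgb() handle a single point and a batch with the same code, because the channel axis always lands just before the two kernel axes. The NumPy sketch below uses a hypothetical Gaussian toy_psf in place of a real lens. The wavelengths 0.656, 0.588, and 0.486 µm are an assumed stand-in for WAVE_RGB, whose actual values are defined in DeepLens.

```python
import numpy as np

WAVE_RGB_ASSUMED = (0.656, 0.588, 0.486)  # hypothetical R, G, B wavelengths in µm


def toy_psf(points, wvln, ks):
    """Hypothetical monochromatic PSF: a Gaussian whose width grows with wavelength.

    points has shape [3] or [N, 3]; returns [ks, ks] or [N, ks, ks] accordingly.
    """
    yy, xx = np.mgrid[:ks, :ks] - (ks - 1) / 2
    kernel = np.exp(-(xx**2 + yy**2) / (2 * (4 * wvln) ** 2))
    kernel /= kernel.sum()
    points = np.asarray(points)
    if points.ndim == 1:
        return kernel
    return np.broadcast_to(kernel, (points.shape[0], ks, ks)).copy()


def toy_psf_rgb(points, ks):
    psfs = [toy_psf(points, wvln, ks) for wvln in WAVE_RGB_ASSUMED]
    return np.stack(psfs, axis=-3)  # [3, ks, ks] or [N, 3, ks, ks]


single = toy_psf_rgb([0.0, 0.0, -1e4], ks=15)
batch = toy_psf_rgb([[0.0, 0.0, -1e4], [0.5, 0.5, -1e4]], ks=15)
```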

Parameters:

Name Type Description Default
points Tensor

Point source coordinates, shape [N, 3] or [3]. Same convention as psf().

required
ks int

PSF kernel size. Defaults to PSF_KS.

PSF_KS
**kwargs

Forwarded to psf() (e.g. spp, model).

{}

Returns:

Type Description

torch.Tensor: RGB PSF, shape [3, ks, ks] for a single point or [N, 3, ks, ks] for a batch.

Source code in deeplens/optics/lens.py
def psf_rgb(self, points, ks=PSF_KS, **kwargs):
    """Compute the RGB (tri-chromatic) PSF by stacking three wavelength calls.

    Calls :meth:`psf` three times for the RGB primary wavelengths defined
    in ``WAVE_RGB`` and stacks the results along the channel axis.

    Args:
        points (torch.Tensor): Point source coordinates, shape ``[N, 3]``
            or ``[3]``.  Same convention as :meth:`psf`.
        ks (int, optional): PSF kernel size. Defaults to ``PSF_KS``.
        **kwargs: Forwarded to :meth:`psf` (e.g. ``spp``, ``model``).

    Returns:
        torch.Tensor: RGB PSF, shape ``[3, ks, ks]`` for a single point
        or ``[N, 3, ks, ks]`` for a batch.
    """
    psfs = []
    for wvln in WAVE_RGB:
        psfs.append(self.psf(points=points, ks=ks, wvln=wvln, **kwargs))
    psf_rgb = torch.stack(psfs, dim=-3)  # shape [3, ks, ks] or [N, 3, ks, ks]
    return psf_rgb

point_source_grid

point_source_grid(depth, grid=(9, 9), normalized=True, quater=False, center=True)

Generate point source grid for PSF calculation.
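The grid construction can be reproduced in NumPy to show the coordinate layout. This sketch follows the center=True branch of the source listing, including its half_bin_size = 1 / (2 (grid_w - 1)) offset, which the listing applies to the y axis as well. It omits the quater and de-normalization options. point_grid is an illustrative name.

```python
import numpy as np


def point_grid(depth, grid=(9, 9)):
    """Return normalised point sources of shape [grid_h, grid_w, 3] (x right, y up, z = depth)."""
    grid_w, grid_h = grid
    half_bin = 1 / 2 / (grid_w - 1)
    xs = np.linspace(-1 + half_bin, 1 - half_bin, grid_w)
    ys = np.linspace(1 - half_bin, -1 + half_bin, grid_h)  # top row first, so y decreases downward
    x, y = np.meshgrid(xs, ys, indexing="xy")               # both have shape [grid_h, grid_w]
    z = np.full_like(x, depth)
    return np.stack([x, y, z], axis=-1)


pts = point_grid(-1e4, grid=(5, 3))
```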

Parameters:

Name Type Description Default
depth float

Depth of the point source.

required
grid tuple

Grid size (grid_w, grid_h). Defaults to (9, 9), meaning a 9x9 grid.

(9, 9)
normalized bool

If True, return normalized coordinates with x and y in [-1, 1]; if False, scale x and y to physical object-space positions in mm. Defaults to True.

True
quater bool

If True, return only the points in one quarter of the sensor plane, to save memory. Defaults to False.

False
center bool

If True, place each point at the center of its grid patch; if False, span the sensor from -0.98 to 0.98. Defaults to True.

True

Returns:

Name Type Description
point_source

Object source coordinates of shape [grid_h, grid_w, 3]. With normalized=True, x and y lie in [-1, 1]; z equals depth and lies in (-inf, 0].

Source code in deeplens/optics/lens.py
def point_source_grid(
    self, depth, grid=(9, 9), normalized=True, quater=False, center=True
):
    """Generate point source grid for PSF calculation.

    Args:
        depth (float): Depth of the point source.
        grid (tuple): Grid size (grid_w, grid_h). Defaults to (9, 9), meaning 9x9 grid.
        normalized (bool): Return normalized object source coordinates. Defaults to True, meaning object sources xy coordinates range from [-1, 1].
        quater (bool): Use quater of the sensor plane to save memory. Defaults to False.
        center (bool): Use center of each patch. Defaults to True.

    Returns:
        point_source: Normalized object source coordinates. Shape of [grid_h, grid_w, 3], [-1, 1], [-1, 1], [-Inf, 0].
    """
    # Compute point source grid
    if grid[0] == 1:
        x, y = torch.tensor([[0.0]], device=self.device), torch.tensor([[0.0]], device=self.device)
        assert not quater, "Quater should be False when grid is 1."
    else:
        if center:
            # Use center of each patch
            half_bin_size = 1 / 2 / (grid[0] - 1)
            x, y = torch.meshgrid(
                torch.linspace(-1 + half_bin_size, 1 - half_bin_size, grid[0], device=self.device),
                torch.linspace(1 - half_bin_size, -1 + half_bin_size, grid[1], device=self.device),
                indexing="xy",
            )
        else:
            # Use corner of image sensor
            x, y = torch.meshgrid(
                torch.linspace(-0.98, 0.98, grid[0], device=self.device),
                torch.linspace(0.98, -0.98, grid[1], device=self.device),
                indexing="xy",
            )

    z = torch.full_like(x, depth)
    point_source = torch.stack([x, y, z], dim=-1)

    # Use quater of the sensor plane to save memory
    if quater:
        bound_i = grid[0] // 2 if grid[0] % 2 == 0 else grid[0] // 2 + 1
        bound_j = grid[1] // 2
        point_source = point_source[0:bound_i, bound_j:, :]

    # De-normalize object source coordinates to physical coordinates
    if not normalized:
        scale = self.calc_scale(depth)
        point_source[..., 0] *= scale * self.sensor_size[0] / 2
        point_source[..., 1] *= scale * self.sensor_size[1] / 2

    return point_source

psf_map

psf_map(grid=(5, 5), wvln=DEFAULT_WAVE, depth=DEPTH, ks=PSF_KS, **kwargs)

Compute a monochromatic PSF map by evaluating the PSF at each point of a grid of field positions.

Parameters:

Name Type Description Default
grid tuple

Grid size (grid_w, grid_h). Defaults to (5, 5), meaning a 5x5 grid.

(5, 5)
wvln float

Wavelength in micrometers. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
depth float

Depth of the object. Defaults to DEPTH.

DEPTH
ks int

Kernel size. Defaults to PSF_KS.

PSF_KS

Returns:

Name Type Description
psf_map

Shape of [grid_h, grid_w, 1, ks, ks].

Source code in deeplens/optics/lens.py
def psf_map(self, grid=(5, 5), wvln=DEFAULT_WAVE, depth=DEPTH, ks=PSF_KS, **kwargs):
    """Compute monochrome PSF map.

    Args:
        grid (tuple): Grid size (grid_w, grid_h). Defaults to (5, 5), meaning 5x5 grid.
        wvln (float): Wavelength. Defaults to DEFAULT_WAVE.
        depth (float): Depth of the object. Defaults to DEPTH.
        ks (int): Kernel size. Defaults to PSF_KS.

    Returns:
        psf_map: Shape of [grid_h, grid_w, 1, ks, ks].
    """
    # PSF map grid
    points = self.point_source_grid(depth=depth, grid=grid, center=True)
    points = points.reshape(-1, 3)

    # Compute PSF map
    psfs = []
    for i in range(points.shape[0]):
        point = points[i, ...]
        psf = self.psf(points=point, wvln=wvln, ks=ks, **kwargs)
        psfs.append(psf)
    psf_map = torch.stack(psfs).unsqueeze(1)  # shape [grid_h * grid_w, 1, ks, ks]

    # Reshape PSF map from [grid_h * grid_w, 1, ks, ks] -> [grid_h, grid_w, 1, ks, ks]
    psf_map = psf_map.reshape(grid[1], grid[0], 1, ks, ks)
    return psf_map

psf_map_rgb

psf_map_rgb(grid=(5, 5), ks=PSF_KS, depth=DEPTH, **kwargs)

Compute RGB PSF map.
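psf_map() evaluates one PSF per grid point in row-major order: top row first, left to right, following point_source_grid(). It then folds the flat stack into a [grid_h, grid_w, 1, ks, ks] map. psf_map_rgb() concatenates three such maps along the singleton channel axis.

The NumPy sketch below replays that shape bookkeeping with placeholder kernels whose fill value encodes their index. toy_psf_map is a hypothetical helper, not DeepLens API.

```python
import numpy as np

grid_w, grid_h, ks = 4, 3, 5


def toy_psf_map(fill_offset):
    """Mimic psf_map(): one [ks, ks] PSF per grid point, folded into [grid_h, grid_w, 1, ks, ks]."""
    # Placeholder kernel k is filled with k + fill_offset, standing in for self.psf(points[k])
    psfs = [np.full((ks, ks), k + fill_offset) for k in range(grid_h * grid_w)]
    flat = np.stack(psfs)[:, None]                  # [grid_h * grid_w, 1, ks, ks]
    return flat.reshape(grid_h, grid_w, 1, ks, ks)  # row-major: top row first, left to right


# Mimic psf_map_rgb(): one map per wavelength, concatenated along the channel axis (dim 2)
maps = [toy_psf_map(offset) for offset in (0.0, 100.0, 200.0)]
psf_map_rgb = np.concatenate(maps, axis=2)          # [grid_h, grid_w, 3, ks, ks]
```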

Parameters:

Name Type Description Default
grid tuple

Grid size (grid_w, grid_h). Defaults to (5, 5), meaning a 5x5 grid.

(5, 5)
ks int

Kernel size. Defaults to PSF_KS, meaning PSF_KS x PSF_KS kernel size.

PSF_KS
depth float

Depth of the object. Defaults to DEPTH.

DEPTH
**kwargs

Additional arguments for psf_map().

{}

Returns:

Name Type Description
psf_map

Shape of [grid_h, grid_w, 3, ks, ks].

Source code in deeplens/optics/lens.py
def psf_map_rgb(self, grid=(5, 5), ks=PSF_KS, depth=DEPTH, **kwargs):
    """Compute RGB PSF map.

    Args:
        grid (tuple): Grid size (grid_w, grid_h). Defaults to (5, 5), meaning 5x5 grid.
        ks (int): Kernel size. Defaults to PSF_KS, meaning PSF_KS x PSF_KS kernel size.
        depth (float): Depth of the object. Defaults to DEPTH.
        **kwargs: Additional arguments for psf_map().

    Returns:
        psf_map: Shape of [grid_h, grid_w, 3, ks, ks].
    """
    psfs = []
    for wvln in WAVE_RGB:
        psf_map = self.psf_map(grid=grid, ks=ks, depth=depth, wvln=wvln, **kwargs)
        psfs.append(psf_map)
    psf_map = torch.cat(psfs, dim=2)  # shape [grid_h, grid_w, 3, ks, ks]
    return psf_map

draw_psf_map

draw_psf_map(grid=(7, 7), ks=PSF_KS, depth=DEPTH, log_scale=False, save_name='./psf_map.png', show=False)

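Draw the RGB PSF map of the lens. Each PSF tile is normalized independently, linearly or on a log scale when log_scale=True, and a 100 μm scale bar is added. The figure is saved to save_name, or returned as (fig, ax) when show=True.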
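Two details of draw_psf_map() are easy to miss:

- Each tile is normalized on its own, so tile brightness is not comparable across the field.
- The scale bar length in pixels is 100 µm divided by the pixel pitch in µm.

The NumPy sketch below reproduces both steps on placeholder data. tile_normalize is a hypothetical helper name, and the 2 µm pitch is an assumed value.

```python
import numpy as np


def tile_normalize(psf, log_scale=False):
    """Normalise one [3, ks, ks] PSF tile to [0, 1], mirroring the per-tile step in draw_psf_map."""
    if log_scale:
        psf = np.log(psf + 1e-4)  # compress the dynamic range
        return (psf - psf.min()) / (psf.max() - psf.min() + 1e-8)
    local_max = psf.max()
    return psf / local_max if local_max > 0 else psf


tile = np.zeros((3, 8, 8))
tile[:, 4, 4] = 0.25  # placeholder PSF: a single bright pixel
linear = tile_normalize(tile)
logged = tile_normalize(tile, log_scale=True)

pixel_size_mm = 0.002                           # hypothetical 2 µm pixel pitch
scale_bar_px = 100 / (pixel_size_mm * 1e3)      # a 100 µm bar spans about 50 pixels
```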

Source code in deeplens/optics/lens.py
@torch.no_grad()
def draw_psf_map(
    self,
    grid=(7, 7),
    ks=PSF_KS,
    depth=DEPTH,
    log_scale=False,
    save_name="./psf_map.png",
    show=False,
):
    """Draw RGB PSF map of the lens."""
    # Calculate RGB PSF map, shape [grid_h, grid_w, 3, ks, ks]
    psf_map = self.psf_map_rgb(depth=depth, grid=grid, ks=ks)

    # Create a grid visualization (vis_map: shape [3, grid_h * ks, grid_w * ks])
    grid_w, grid_h = grid if isinstance(grid, tuple) else (grid, grid)
    h, w = grid_h * ks, grid_w * ks
    vis_map = torch.zeros((3, h, w), device=psf_map.device, dtype=psf_map.dtype)

    # Put each PSF into the vis_map
    for i in range(grid_h):
        for j in range(grid_w):
            # Extract the PSF at this grid position
            psf = psf_map[i, j]  # shape [3, ks, ks]

            # Normalize the PSF
            if log_scale:
                # Log scale normalization for better visualization
                psf = torch.log(psf + 1e-4)  # 1e-4 is an empirical value
                psf = (psf - psf.min()) / (psf.max() - psf.min() + 1e-8)
            else:
                # Linear normalization
                local_max = psf.max()
                if local_max > 0:
                    psf = psf / local_max

            # Place the normalized PSF in the visualization map
            y_start, y_end = i * ks, (i + 1) * ks
            x_start, x_end = j * ks, (j + 1) * ks
            vis_map[:, y_start:y_end, x_start:x_end] = psf

    # Create the figure and display
    fig, ax = plt.subplots(figsize=(10, 10))

    # Convert to numpy for plotting
    vis_map = vis_map.permute(1, 2, 0).cpu().numpy()
    ax.imshow(vis_map)

    # Add scale bar near bottom-left
    H, W, _ = vis_map.shape
    scale_bar_length = 100
    arrow_length = scale_bar_length / (self.pixel_size * 1e3)
    y_position = H - 20  # a little above the lower edge
    x_start = 20
    x_end = x_start + arrow_length

    ax.annotate(
        "",
        xy=(x_start, y_position),
        xytext=(x_end, y_position),
        arrowprops=dict(arrowstyle="-", color="white"),
        annotation_clip=False,
    )
    ax.text(
        x_end + 5,
        y_position,
        f"{scale_bar_length} μm",
        color="white",
        fontsize=12,
        ha="left",
        va="center",
        clip_on=False,
    )

    # Clean up axes and save
    ax.axis("off")
    plt.tight_layout(pad=0)

    if show:
        return fig, ax
    else:
        plt.savefig(save_name, dpi=300, bbox_inches="tight", pad_inches=0)
        plt.close(fig)

point_source_radial

point_source_radial(depth, grid=9, center=False, direction='diagonal', normalized=True)

Generate radial point sources from center to edge of the field.

Produces grid (the number of samples) evenly spaced points along a chosen radial direction (diagonal, meridional, or sagittal), in normalized or physical object-space coordinates.
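The mapping from the radial coordinate r to (x, y) can be sketched in NumPy. This mirrors the center=False branch of the source listing, which samples r from 0 to 0.98. It leaves out the physical-unit scaling. radial_points is an illustrative name.

```python
import numpy as np


def radial_points(depth, grid=9, direction="diagonal"):
    """Return normalised radial point sources of shape [grid, 3] (center=False branch)."""
    r = np.array([0.0]) if grid == 1 else np.linspace(0, 0.98, grid)
    if direction == "diagonal":     # x = y
        px, py = r, r
    elif direction == "y":          # meridional: x = 0
        px, py = np.zeros_like(r), r
    elif direction == "x":          # sagittal: y = 0
        px, py = r, np.zeros_like(r)
    else:
        raise ValueError(f"Invalid direction: {direction!r}. Use 'diagonal', 'x', or 'y'.")
    return np.stack([px, py, np.full_like(px, depth)], axis=-1)


diag = radial_points(-1e4, grid=5)
merid = radial_points(-1e4, grid=5, direction="y")
```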

Parameters:

Name Type Description Default
depth float

Object depth (z-coordinate) in mm.

required
grid int

Number of sample points. Defaults to 9.

9
center bool

If True, offset positions to bin centers. Defaults to False.

False
direction str

Sampling direction — "diagonal" (x = y, 45°, default), "y" (meridional, x = 0), "x" (sagittal, y = 0).

'diagonal'
normalized bool

If True, return coordinates in [0, 1]. If False, scale to physical object-space positions (mm). Defaults to True.

True

Returns:

Type Description

torch.Tensor: Point source positions, shape [grid, 3].

Source code in deeplens/optics/lens.py
def point_source_radial(self, depth, grid=9, center=False, direction="diagonal", normalized=True):
    """Generate radial point sources from center to edge of the field.

    Produces ``grid`` evenly-spaced points along a chosen radial direction
    (diagonal, meridional, or sagittal) in normalized or physical object-space
    coordinates.

    Args:
        depth (float): Object depth (z-coordinate) in mm.
        grid (int): Number of sample points. Defaults to 9.
        center (bool): If ``True``, offset positions to bin centers.
            Defaults to ``False``.
        direction (str): Sampling direction —
            ``"diagonal"`` (x = y, 45°, default),
            ``"y"`` (meridional, x = 0),
            ``"x"`` (sagittal, y = 0).
        normalized (bool): If ``True``, return coordinates in [0, 1].
            If ``False``, scale to physical object-space positions (mm).
            Defaults to ``True``.

    Returns:
        torch.Tensor: Point source positions, shape ``[grid, 3]``.
    """
    if grid == 1:
        r = torch.tensor([0.0], device=self.device)
    else:
        # Select center of bin to calculate PSF
        if center:
            half_bin_size = 1 / 2 / (grid - 1)
            r = torch.linspace(0, 1 - half_bin_size, grid, device=self.device)
        else:
            r = torch.linspace(0, 0.98, grid, device=self.device)

    # Map radial coordinate to (x, y) based on direction
    if direction == "diagonal":
        px, py = r, r
    elif direction == "y":
        px, py = torch.zeros_like(r), r
    elif direction == "x":
        px, py = r, torch.zeros_like(r)
    else:
        raise ValueError(f"Invalid direction: {direction!r}. Use 'diagonal', 'x', or 'y'.")

    z = torch.full_like(px, depth)
    point_source = torch.stack([px, py, z], dim=-1)

    if not normalized:
        scale = self.calc_scale(depth)
        point_source[..., 0] = point_source[..., 0] * scale * self.sensor_size[0] / 2
        point_source[..., 1] = point_source[..., 1] * scale * self.sensor_size[1] / 2

    return point_source

draw_psf_radial

draw_psf_radial(M=3, depth=DEPTH, ks=PSF_KS, log_scale=False, save_name='./psf_radial.png')

Draw M RGB PSFs sampled along the 45° diagonal from the field center (0, 0) to the corner (1, 1), each of size ks x ks, and save them side by side to save_name.

Source code in deeplens/optics/lens.py
@torch.no_grad()
def draw_psf_radial(
    self, M=3, depth=DEPTH, ks=PSF_KS, log_scale=False, save_name="./psf_radial.png"
):
    """Draw radial PSF (45 deg). Will draw M PSFs, each of size ks x ks."""
    x = torch.linspace(0, 1, M)
    y = torch.linspace(0, 1, M)
    z = torch.full_like(x, depth)
    points = torch.stack((x, y, z), dim=-1)

    psfs = []
    for i in range(M):
        # Scale PSF for a better visualization
        psf = self.psf_rgb(points=points[i], ks=ks, recenter=True, spp=SPP_PSF)
        psf /= psf.max()

        if log_scale:
            psf = torch.log(psf + EPSILON)
            psf = (psf - psf.min()) / (psf.max() - psf.min())

        psfs.append(psf)

    psf_grid = make_grid(psfs, nrow=M, padding=1, pad_value=0.0)
    save_image(psf_grid, save_name, normalize=True)

render

render(img_obj, depth=DEPTH, method='psf_patch', **kwargs)

Differentiable image simulation for a 2D (flat) scene.

Performs only the optical component of image simulation and is fully differentiable. Sensor noise is handled separately by the deeplens.camera.Camera class.

For incoherent imaging the intensity PSF is convolved with the object-space image. For coherent imaging the complex PSF is convolved with the complex object image before squaring for intensity.
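The difference between the two imaging models can be sketched with FFT-based circular convolution in NumPy. The helper names below are hypothetical. Circular (wrap-around) boundaries are a simplification, and the kernel is assumed to have the same shape as the image. This is not DeepLens's own convolution routine.

```python
import numpy as np


def fft_conv2d(img, kernel):
    """Circular 2D convolution; kernel has the same shape as img, centred at (H // 2, W // 2)."""
    K = np.fft.fft2(np.fft.ifftshift(kernel))  # move the kernel centre to index (0, 0)
    return np.fft.ifft2(np.fft.fft2(img) * K)


def render_incoherent(intensity, psf_intensity):
    # Intensities add: convolve the intensity image with the (real) intensity PSF
    return fft_conv2d(intensity, psf_intensity).real


def render_coherent(field, psf_complex):
    # Complex amplitudes add: convolve the complex field first, then square for intensity
    return np.abs(fft_conv2d(field, psf_complex)) ** 2


H = W = 16
rng = np.random.default_rng(0)
intensity = rng.random((H, W))
field = np.sqrt(intensity) * np.exp(1j * rng.uniform(0, 2 * np.pi, (H, W)))  # random phase
delta = np.zeros((H, W))
delta[H // 2, W // 2] = 1.0  # ideal, aberration-free PSF
```

With a delta PSF both models return the input intensity. With a spread-out PSF, the coherent result also contains interference between overlapping contributions, which the incoherent model ignores.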

Parameters:

Name Type Description Default
img_obj Tensor

Input image in linear (raw) space, shape [B, C, H, W].

required
depth float

Object depth in mm (negative value). Defaults to DEPTH (-20 000 mm, far enough to approximate an object at infinity).

DEPTH
method str

Rendering method. One of:

  • "psf_patch" – convolve a single PSF evaluated at patch_center (default).
  • "psf_map" – spatially-varying PSF block convolution.
'psf_patch'
**kwargs

Method-specific keyword arguments:

  • For "psf_map": psf_grid (tuple, default (10, 10)), psf_ks (int, default PSF_KS).
  • For "psf_patch": patch_center (tuple or Tensor, default (0.0, 0.0)), psf_ks (int).
{}

Returns:

Type Description

torch.Tensor: Rendered image, shape [B, C, H, W].

Raises:

Type Description
AssertionError

If method is "psf_map" and the image resolution does not match the sensor resolution.

NotImplementedError

If method is "psf_pixel" (per-pixel PSF convolution is not implemented yet).

Exception

If method is not recognised.

References

[1] "Optical Aberration Correction in Postprocessing using Imaging Simulation", TOG 2021.
[2] "Efficient depth- and spatially-varying image simulation for defocus deblur", ICCVW 2025.

Example

>>> img_rendered = lens.render(img, depth=-10000.0, method="psf_patch",
...                            patch_center=(0.3, 0.0), psf_ks=64)
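The "psf_map" method approximates a spatially varying blur with block convolution: the image is split into a grid of tiles, and each tile is convolved with the PSF for its grid cell. The NumPy sketch below is a naive single-channel version of this idea. It makes three assumptions:

- a [grid_h, grid_w, ks, ks] PSF map with odd ks;
- image sides divisible by the grid;
- independent zero-padded convolution per tile, so light does not cross tile borders (a simplification).

It is not the DeepLens render_psf_map() implementation, and the helper names are hypothetical.

```python
import numpy as np


def conv2d_same(tile, kernel):
    """Zero-padded 'same' 2D convolution of tile [h, w] with an odd-sized kernel [k, k]."""
    k = kernel.shape[0]
    pad = k // 2
    padded = np.pad(tile, pad)
    flipped = kernel[::-1, ::-1]  # convolution = correlation with the flipped kernel
    out = np.zeros_like(tile, dtype=float)
    for dy in range(k):
        for dx in range(k):
            out += flipped[dy, dx] * padded[dy:dy + tile.shape[0], dx:dx + tile.shape[1]]
    return out


def render_block(img, psf_map):
    """Naive tile-wise rendering: img [H, W], psf_map [grid_h, grid_w, ks, ks] (row i, column j)."""
    grid_h, grid_w = psf_map.shape[:2]
    H, W = img.shape
    assert H % grid_h == 0 and W % grid_w == 0, "image sides must be divisible by the grid"
    th, tw = H // grid_h, W // grid_w
    out = np.empty_like(img, dtype=float)
    for i in range(grid_h):
        for j in range(grid_w):
            rows, cols = slice(i * th, (i + 1) * th), slice(j * tw, (j + 1) * tw)
            out[rows, cols] = conv2d_same(img[rows, cols], psf_map[i, j])
    return out


img = np.random.default_rng(0).random((32, 32))
psf_map = np.zeros((2, 2, 5, 5))
psf_map[..., 2, 2] = 1.0  # identity kernels everywhere...
psf_map[1, 1] = 1.0 / 25  # ...except a 5x5 box blur in the bottom-right tile
out = render_block(img, psf_map)
```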

Source code in deeplens/optics/lens.py
def render(self, img_obj, depth=DEPTH, method="psf_patch", **kwargs):
    """Differentiable image simulation for a 2D (flat) scene.

    Performs only the optical component of image simulation and is fully
    differentiable.  Sensor noise is handled separately by the
    :class:`~deeplens.camera.Camera` class.

    For incoherent imaging the intensity PSF is convolved with the
    object-space image.  For coherent imaging the complex PSF is convolved
    with the complex object image before squaring for intensity.

    Args:
        img_obj (torch.Tensor): Input image in linear (raw) space,
            shape ``[B, C, H, W]``.
        depth (float, optional): Object depth in mm (negative value).
            Defaults to ``DEPTH`` (-20 000 mm, i.e. infinity).
        method (str, optional): Rendering method.  One of:

            * ``"psf_patch"`` – convolve a single PSF evaluated at
              *patch_center* (default).
            * ``"psf_map"`` – spatially-varying PSF block convolution.

        **kwargs: Method-specific keyword arguments:

            * For ``"psf_map"``: ``psf_grid`` (tuple, default ``(10, 10)``),
              ``psf_ks`` (int, default ``PSF_KS``).
            * For ``"psf_patch"``: ``patch_center`` (tuple or Tensor,
              default ``(0.0, 0.0)``), ``psf_ks`` (int).

    Returns:
        torch.Tensor: Rendered image, shape ``[B, C, H, W]``.

    Raises:
        AssertionError: If *method* is ``"psf_map"`` and the image
            resolution does not match the sensor resolution.
        Exception: If *method* is not recognised.

    References:
        [1] "Optical Aberration Correction in Postprocessing using Imaging Simulation", TOG 2021.
        [2] "Efficient depth- and spatially-varying image simulation for defocus deblur", ICCVW 2025.

    Example:
        >>> img_rendered = lens.render(img, depth=-10000.0, method="psf_patch",
        ...                            patch_center=(0.3, 0.0), psf_ks=64)
    """
    # Check sensor resolution
    B, C, Himg, Wimg = img_obj.shape
    Wsensor, Hsensor = self.sensor_res

    # Image simulation (in RAW space)
    if method == "psf_map":
        # Render full resolution image with PSF map convolution
        assert Wimg == Wsensor and Himg == Hsensor, (
            f"Sensor resolution {Wsensor}x{Hsensor} must match input image {Wimg}x{Himg}."
        )
        psf_grid = kwargs.get("psf_grid", (10, 10))
        psf_ks = kwargs.get("psf_ks", PSF_KS)
        img_render = self.render_psf_map(
            img_obj, depth=depth, psf_grid=psf_grid, psf_ks=psf_ks
        )

    elif method == "psf_patch":
        # Render an image patch with its corresponding PSF
        patch_center = kwargs.get("patch_center", (0.0, 0.0))
        psf_ks = kwargs.get("psf_ks", PSF_KS)
        img_render = self.render_psf_patch(
            img_obj, depth=depth, patch_center=patch_center, psf_ks=psf_ks
        )

    elif method == "psf_pixel":
        raise NotImplementedError(
            "Per-pixel PSF convolution has not been implemented."
        )

    else:
        raise Exception(f"Image simulation method {method} is not supported.")

    return img_render

render_psf

render_psf(img_obj, depth=DEPTH, patch_center=(0, 0), psf_ks=PSF_KS)

Alias for render_psf_patch() with the same arguments. Prefer calling render_psf_patch() directly to avoid confusion.

Source code in deeplens/optics/lens.py
def render_psf(self, img_obj, depth=DEPTH, patch_center=(0, 0), psf_ks=PSF_KS):
    """Render image patch using PSF convolution. Better not use this function to avoid confusion."""
    return self.render_psf_patch(
        img_obj, depth=depth, patch_center=patch_center, psf_ks=psf_ks
    )

render_psf_patch

render_psf_patch(img_obj, depth=DEPTH, patch_center=(0, 0), psf_ks=PSF_KS)

Render an image patch by convolving it with the PSF at patch_center.

Parameters:

Name Type Description Default
img_obj tensor

Input image object in raw space. Shape of [B, C, H, W].

required
depth float

Depth of the object.

DEPTH
patch_center tuple, list, or tensor

Center of the image patch. Shape of [2] or [B, 2].

(0, 0)
psf_ks int

PSF kernel size. Defaults to PSF_KS.

PSF_KS

Returns:

Name Type Description
img_render

Rendered image. Shape of [B, C, H, W].

Source code in deeplens/optics/lens.py
def render_psf_patch(self, img_obj, depth=DEPTH, patch_center=(0, 0), psf_ks=PSF_KS):
    """Render an image patch by convolving it with the PSF at patch_center.

    Args:
        img_obj (tensor): Input image object in raw space. Shape of [B, C, H, W].
        depth (float): Depth of the object.
        patch_center (tensor): Center of the image patch. Shape of [2] or [B, 2].
        psf_ks (int): PSF kernel size. Defaults to PSF_KS.

    Returns:
        img_render: Rendered image. Shape of [B, C, H, W].
    """
    # Convert patch_center to tensor
    if isinstance(patch_center, (list, tuple)):
        points = (patch_center[0], patch_center[1], depth)
        points = torch.tensor(points).unsqueeze(0)
    elif isinstance(patch_center, torch.Tensor):
        depth = torch.full_like(patch_center[..., 0], depth)
        points = torch.stack(
            [patch_center[..., 0], patch_center[..., 1], depth], dim=-1
        )
    else:
        raise Exception(
            f"Patch center must be a list or tuple or tensor, but got {type(patch_center)}."
        )

    # Compute PSF and perform PSF convolution
    psf = self.psf_rgb(points=points, ks=psf_ks).squeeze(0)
    img_render = conv_psf(img_obj, psf=psf)
    return img_render
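At its core, render_psf_patch convolves each color channel with its own PSF kernel. The following is a minimal sketch of that step using torch.nn.functional.conv2d. conv_psf_sketch is a hypothetical stand-in for conv_psf, and the library's padding and kernel-centering conventions may differ. The sketch assumes the PSF is stored as [C, ks, ks] with its center at index ks // 2.

```python
import torch
import torch.nn.functional as F


def conv_psf_sketch(img, psf):
    """Convolve img [B, C, H, W] with a per-channel PSF [C, ks, ks].

    Hypothetical stand-in for conv_psf; the output keeps the input size.
    """
    C = img.shape[1]
    ks = psf.shape[-1]
    # conv2d computes cross-correlation, so flip the PSF to get a true convolution.
    kernel = torch.flip(psf, dims=(-2, -1)).unsqueeze(1)  # [C, 1, ks, ks]
    # Asymmetric padding keeps the PSF center (index ks // 2) aligned for odd and even ks.
    pad_lo, pad_hi = ks - 1 - ks // 2, ks // 2
    img_pad = F.pad(img, (pad_lo, pad_hi, pad_lo, pad_hi), mode="replicate")
    # groups=C applies each channel's kernel to that channel only.
    return F.conv2d(img_pad, kernel, groups=C)
```

Flipping the kernel is important for asymmetric, off-axis PSFs. Without the flip, the blur would be mirrored about the PSF center.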

render_psf_map

render_psf_map(img_obj, depth=DEPTH, psf_grid=7, psf_ks=PSF_KS)

Render image using PSF block convolution.

Note

Larger psf_grid and psf_ks values generally give more accurate rendering but are slower.

Parameters:

Name Type Description Default
img_obj tensor

Input image object in raw space. Shape of [B, C, H, W].

required
depth float

Depth of the object.

DEPTH
psf_grid int

PSF grid size.

7
psf_ks int

PSF kernel size. Defaults to PSF_KS.

PSF_KS

Returns:

Name Type Description
img_render

Rendered image. Shape of [B, C, H, W].

Source code in deeplens/optics/lens.py
def render_psf_map(self, img_obj, depth=DEPTH, psf_grid=7, psf_ks=PSF_KS):
    """Render image using PSF block convolution.

    Note:
        Larger psf_grid and psf_ks are typically better for more accurate rendering, but slower.

    Args:
        img_obj (tensor): Input image object in raw space. Shape of [B, C, H, W].
        depth (float): Depth of the object.
        psf_grid (int): PSF grid size.
        psf_ks (int): PSF kernel size. Defaults to PSF_KS.

    Returns:
        img_render: Rendered image. Shape of [B, C, H, W].
    """
    psf_map = self.psf_map_rgb(grid=psf_grid, ks=psf_ks, depth=depth)
    img_render = conv_psf_map(img_obj, psf_map)
    return img_render
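render_psf_map passes a grid of PSFs to conv_psf_map, which applies a different PSF to each region of the image. One simple, though slow, way to express block convolution is to handle each grid cell in turn: convolve the whole image with that cell's PSF, then keep only the pixels inside the cell. Because the full image is convolved, the blur near block edges still picks up light from neighbouring blocks. The helper names are hypothetical, and so is the [grid_h, grid_w, C, ks, ks] layout, which is taken from the shape comment in render_rgbd. The library's conv_psf_map may partition or blend blocks differently.

```python
import torch
import torch.nn.functional as F


def conv_psf_sketch(img, psf):
    """Per-channel convolution of img [B, C, H, W] with psf [C, ks, ks] (same-size output)."""
    ks = psf.shape[-1]
    kernel = torch.flip(psf, dims=(-2, -1)).unsqueeze(1)  # true convolution, not correlation
    pad_lo, pad_hi = ks - 1 - ks // 2, ks // 2
    img_pad = F.pad(img, (pad_lo, pad_hi, pad_lo, pad_hi), mode="replicate")
    return F.conv2d(img_pad, kernel, groups=img.shape[1])


def conv_psf_map_sketch(img, psf_map):
    """Block convolution: psf_map [grid_h, grid_w, C, ks, ks], one PSF per image block."""
    grid_h, grid_w = psf_map.shape[:2]
    _, _, H, W = img.shape
    assert H % grid_h == 0 and W % grid_w == 0, "image size must be divisible by the grid"
    bh, bw = H // grid_h, W // grid_w
    out = torch.empty_like(img)
    for i in range(grid_h):
        for j in range(grid_w):
            rows = slice(i * bh, (i + 1) * bh)
            cols = slice(j * bw, (j + 1) * bw)
            # Blur the full image so that light from neighbouring blocks is included,
            # then keep only the pixels of block (i, j).
            out[..., rows, cols] = conv_psf_sketch(img, psf_map[i, j])[..., rows, cols]
    return out
```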

render_rgbd

render_rgbd(img_obj, depth_map, method='psf_patch', **kwargs)

Render RGBD image.

Obstruction-aware image simulation is not yet supported.

Parameters:

Name Type Description Default
img_obj tensor

Object image. Shape of [B, C, H, W].

required
depth_map tensor

Depth map [mm]. Shape of [B, 1, H, W] or [B, H, W]. Values should be positive.

required
method str

Image simulation method. Defaults to "psf_patch".

'psf_patch'
**kwargs

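Additional method-specific arguments:
  • interp_mode (str): "depth" or "disparity" (for "psf_patch" and "psf_map"). Defaults to "disparity".
  • psf_ks (int): PSF kernel size. Defaults to PSF_KS.
  • depth_min, depth_max (float): Depth range of the PSF layers. Default to the minimum and maximum of depth_map.
  • num_layers (int): Number of depth layers. Defaults to 16.
  • patch_center (tuple): Patch center for "psf_patch". Defaults to (0.0, 0.0).
  • psf_grid (tuple): PSF grid (grid_w, grid_h) for "psf_map". Defaults to (8, 8).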

{}

Returns:

Name Type Description
img_render

Rendered image. Shape of [B, C, H, W].

Reference

[1] "Aberration-Aware Depth-from-Focus", TPAMI 2023.
[2] "Efficient Depth- and Spatially-Varying Image Simulation for Defocus Deblur", ICCVW 2025.

Source code in deeplens/optics/lens.py
def render_rgbd(self, img_obj, depth_map, method="psf_patch", **kwargs):
    """Render RGBD image.

    TODO: add obstruction-aware image simulation.

    Args:
        img_obj (tensor): Object image. Shape of [B, C, H, W].
        depth_map (tensor): Depth map [mm]. Shape of [B, 1, H, W]. Values should be positive.
        method (str, optional): Image simulation method. Defaults to "psf_patch".
        **kwargs: Additional arguments for different methods.
            - interp_mode (str): "depth" or "disparity". Defaults to "disparity".

    Returns:
        img_render: Rendered image. Shape of [B, C, H, W].

    Reference:
        [1] "Aberration-Aware Depth-from-Focus", TPAMI 2023.
        [2] "Efficient Depth- and Spatially-Varying Image Simulation for Defocus Deblur", ICCVW 2025.
    """
    if depth_map.min() < 0:
        raise ValueError("Depth map should be positive.")

    if len(depth_map.shape) == 3:
        # [B, H, W] -> [B, 1, H, W]
        depth_map = depth_map.unsqueeze(1)

    if method == "psf_patch":
        # Render an image patch (same FoV, different depth)
        patch_center = kwargs.get("patch_center", (0.0, 0.0))
        psf_ks = kwargs.get("psf_ks", PSF_KS)
        depth_min = kwargs.get("depth_min", depth_map.min())
        depth_max = kwargs.get("depth_max", depth_map.max())
        num_layers = kwargs.get("num_layers", 16)
        interp_mode = kwargs.get("interp_mode", "disparity")

        # Calculate PSF at different depths, (num_layers, 3, ks, ks)
        disp_ref, depths_ref = self._sample_depth_layers(depth_min, depth_max, num_layers)

        points = torch.stack(
            [
                torch.full_like(depths_ref, patch_center[0]),
                torch.full_like(depths_ref, patch_center[1]),
                depths_ref,
            ],
            dim=-1,
        )
        psfs = self.psf_rgb(points=points, ks=psf_ks) # (num_layers, 3, ks, ks)

        # Image simulation
        img_render = conv_psf_depth_interp(img_obj, -depth_map, psfs, depths_ref, interp_mode=interp_mode)
        return img_render

    elif method == "psf_map":
        # Render full resolution image with PSF map convolution (different FoV, different depth)
        psf_grid = kwargs.get("psf_grid", (8, 8))  # (grid_w, grid_h)
        psf_ks = kwargs.get("psf_ks", PSF_KS)
        depth_min = kwargs.get("depth_min", depth_map.min())
        depth_max = kwargs.get("depth_max", depth_map.max())
        num_layers = kwargs.get("num_layers", 16)
        interp_mode = kwargs.get("interp_mode", "disparity")

        # Calculate PSF map at different depths (convert to negative for PSF calculation)
        disp_ref, depths_ref = self._sample_depth_layers(depth_min, depth_max, num_layers)

        psf_maps = []
        for depth in tqdm(depths_ref):
            psf_map = self.psf_map_rgb(grid=psf_grid, ks=psf_ks, depth=depth)
            psf_maps.append(psf_map)
        psf_map = torch.stack(
            psf_maps, dim=2
        )  # shape [grid_h, grid_w, num_layers, 3, ks, ks]

        # Image simulation
        img_render = conv_psf_map_depth_interp(
            img_obj, -depth_map, psf_map, depths_ref, interp_mode=interp_mode
        )
        return img_render

    elif method == "psf_pixel":
        # Render full resolution image with pixel-wise PSF convolution. This method is computationally expensive.
        psf_ks = kwargs.get("psf_ks", PSF_KS)
        assert img_obj.shape[0] == 1, "Now only support batch size 1"

        # Calculate points in the object space
        points_xy = torch.meshgrid(
            torch.linspace(-1, 1, img_obj.shape[-1], device=self.device),
            torch.linspace(1, -1, img_obj.shape[-2], device=self.device),
            indexing="xy",
        )
        points_xy = torch.stack(points_xy, dim=0).unsqueeze(0)
        points = torch.cat([points_xy, -depth_map], dim=1)  # shape [B, 3, H, W]

        # Calculate PSF at different pixels. This step is the most time-consuming.
        points = points.permute(0, 2, 3, 1).reshape(-1, 3)  # shape [H*W, 3]
        psfs = self.psf_rgb(points=points, ks=psf_ks)  # shape [H*W, 3, ks, ks]
        psfs = psfs.reshape(
            img_obj.shape[-2], img_obj.shape[-1], 3, psf_ks, psf_ks
        )  # shape [H, W, 3, ks, ks]

        # Image simulation
        img_render = conv_psf_pixel(img_obj, psfs)  # shape [1, C, H, W]
        return img_render

    else:
        raise Exception(f"Image simulation method {method} is not supported.")
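Both depth-aware branches of render_rgbd follow the same three steps: sample a few reference depths, compute one PSF (or PSF map) per depth, and blend the per-layer results pixel by pixel according to the depth map. The following is a minimal sketch of the sampling and blending, under these assumptions:
  • Reference depths are spaced evenly in disparity (1 / depth).
  • Depths are positive distances. The method itself passes negated depths to conv_psf_depth_interp.
  • Each pixel is linearly interpolated between its two neighbouring layers.
sample_depth_layers_sketch, interp_weights_sketch, and blend_layers_sketch are hypothetical names, not the library's _sample_depth_layers or conv_psf_depth_interp.

```python
import torch


def sample_depth_layers_sketch(depth_min, depth_max, num_layers):
    """Hypothetical helper: reference depths evenly spaced in disparity (1 / depth)."""
    disp = torch.linspace(1.0 / depth_max, 1.0 / depth_min, num_layers)
    return disp, 1.0 / disp


def interp_weights_sketch(depth_map, depths_ref, interp_mode="disparity"):
    """Per-pixel linear-interpolation weights [B, L, H, W] over L reference depths."""
    if interp_mode == "disparity":
        x, ref = 1.0 / depth_map, 1.0 / depths_ref
    else:
        x, ref = depth_map, depths_ref
    ref, order = torch.sort(ref)
    # Pixels outside the reference range snap to the nearest end layer.
    x = x.clamp(ref[0].item(), ref[-1].item())
    hi = torch.searchsorted(ref, x.contiguous()).clamp(1, len(ref) - 1)
    lo = hi - 1
    w_hi = (x - ref[lo]) / (ref[hi] - ref[lo])
    weights = torch.zeros(x.shape[0], len(ref), *x.shape[2:], dtype=x.dtype)
    weights.scatter_add_(1, lo, 1.0 - w_hi)
    weights.scatter_add_(1, hi, w_hi)
    # Reorder so that weights[:, l] belongs to depths_ref[l].
    return weights[:, torch.argsort(order)]


def blend_layers_sketch(layer_imgs, weights):
    """layer_imgs [B, L, C, H, W]: the image convolved with each layer's PSF."""
    return (weights.unsqueeze(2) * layer_imgs).sum(dim=1)
```

Spacing layers evenly in disparity places more of them at near depths, where defocus blur changes fastest with distance.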

activate_grad

activate_grad(activate=True)

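Enable (activate=True) or disable (activate=False) gradients for each surface. The base Lens class raises NotImplementedError; subclasses provide the implementation.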

Source code in deeplens/optics/lens.py
def activate_grad(self, activate=True):
    """Activate gradient for each surface."""
    raise NotImplementedError

get_optimizer_params

get_optimizer_params(lr=[0.0001, 0.0001, 0.1, 0.001])

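Return optimizer parameter groups for the different kinds of lens parameters, using the learning rates in lr. The base Lens class raises NotImplementedError; subclasses provide the implementation.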

Source code in deeplens/optics/lens.py
def get_optimizer_params(self, lr=[1e-4, 1e-4, 1e-1, 1e-3]):
    """Get optimizer parameters for different lens parameters."""
    raise NotImplementedError

get_optimizer

get_optimizer(lr=[0.0001, 0.0001, 0, 0.001])

Create a torch.optim.Adam optimizer from the parameter groups returned by get_optimizer_params(lr).

Source code in deeplens/optics/lens.py
def get_optimizer(self, lr=[1e-4, 1e-4, 0, 1e-3]):
    """Get optimizer."""
    params = self.get_optimizer_params(lr)
    optimizer = torch.optim.Adam(params)
    return optimizer
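get_optimizer passes whatever get_optimizer_params(lr) returns straight to torch.optim.Adam. With PyTorch's per-parameter options, each group is a dict with its own "params" and "lr", which is how one lr list can drive different kinds of lens variables. The tensors below (curvature, thickness, conic) are hypothetical stand-ins. The groups a concrete lens returns are defined by its own get_optimizer_params.

```python
import torch

# Hypothetical lens variables, one tensor per kind of parameter.
curvature = torch.nn.Parameter(torch.tensor([0.01, -0.02]))
thickness = torch.nn.Parameter(torch.tensor([2.0, 3.0]))
conic = torch.nn.Parameter(torch.tensor([0.0, 0.0]))


def get_optimizer_params_sketch(lr):
    """One parameter group per kind of variable, each with its own learning rate."""
    return [
        {"params": [curvature], "lr": lr[0]},
        {"params": [thickness], "lr": lr[1]},
        {"params": [conic], "lr": lr[2]},
    ]


optimizer = torch.optim.Adam(get_optimizer_params_sketch([1e-4, 1e-2, 0.0]))

# One optimisation step on a toy loss.
loss = (curvature ** 2).sum() + ((thickness - 2.5) ** 2).sum() + ((conic - 1.0) ** 2).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

With a learning rate of 0 (the third entry of get_optimizer's default lr), Adam's update for that group is scaled to zero, so those parameters stay fixed.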

Lens Models

Differentiable multi-element refractive lens via geometric ray tracing. This is the primary lens model in DeepLens.

GeoLens uses a mixin architecture — functionality is split across GeoLensPSF, GeoLensEval, GeoLensSeidel, GeoLensOptim, GeoLensSurfOps, GeoLensVis, GeoLensIO, GeoLensTolerance, and GeoLensVis3D.

deeplens.optics.GeoLens

GeoLens(filename=None, device=None, dtype=torch.float32)

Bases: GeoLensPSF, GeoLensEval, GeoLensSeidel, GeoLensOptim, GeoLensSurfOps, GeoLensVis, GeoLensIO, GeoLensTolerance, GeoLensVis3D, Lens

Differentiable geometric lens using vectorised ray tracing.

The primary lens model in DeepLens. Supports multi-element refractive (and partially reflective) systems loaded from JSON, Zemax .zmx, or Code V .seq files. Accuracy is aligned with Zemax OpticStudio.

Uses a mixin architecture: the mixin classes listed under Bases are composed at class definition time to keep each concern isolated. They include:

  • GeoLensPSF (deeplens.optics.geolens_pkg.psf_compute) – PSF computation (geometric, coherent, Huygens models).
  • GeoLensEval (deeplens.optics.geolens_pkg.eval) – optical performance evaluation (spot, MTF, distortion, vignetting).
  • GeoLensOptim (deeplens.optics.geolens_pkg.optim) – loss functions and gradient-based optimisation.
  • GeoLensSurfOps (deeplens.optics.geolens_pkg.optim_ops) – surface geometry operations (aspheric conversion, pruning, shape correction, material matching).
  • GeoLensVis (deeplens.optics.geolens_pkg.vis) – 2-D layout and ray visualisation.
  • GeoLensIO (deeplens.optics.geolens_pkg.io) – read/write JSON, Zemax .zmx.
  • GeoLensTolerance (deeplens.optics.geolens_pkg.eval_tolerance) – manufacturing tolerance analysis.
  • GeoLensVis3D (deeplens.optics.geolens_pkg.vis3d) – 3-D mesh visualisation.
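Here is a rough illustration of the mixin pattern (not DeepLens code). Each mixin contributes one group of methods and relies on attributes or methods supplied elsewhere in the final class. Listing the mixins before the base class, as GeoLens's Bases do with Lens, places them ahead of the base in the method resolution order. All class names here are hypothetical.

```python
class LensBase:
    """Stand-in for the shared base class: owns state used by every mixin."""

    def __init__(self, foclen):
        self.foclen = foclen

    def trace(self, x):
        raise NotImplementedError


class PSFMixin:
    """One concern: PSF-style methods. Calls trace(), provided by the concrete class."""

    def psf(self, x):
        return ("psf", self.trace(x))


class EvalMixin:
    """Another concern: evaluation. Reads foclen, set by LensBase.__init__."""

    def summary(self):
        return {"foclen": self.foclen}


class ToyGeoLens(PSFMixin, EvalMixin, LensBase):
    """Concrete class composed from the mixins at class definition time."""

    def trace(self, x):
        return 2 * x
```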

Key differentiability trick: ray-surface intersection (Surface.newtons_method in deeplens.optics.geometric_surface.base) runs a non-differentiable Newton loop followed by one differentiable Newton step, so gradients flow through the intersection without backpropagating through every iteration.
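The Newton trick fits in a few lines. This sketch assumes a hypothetical paraboloid surface z = c (x² + y²) / 2 and a single ray. It runs Newton iterations under torch.no_grad(), then repeats one Newton step with autograd enabled. At the converged root the residual is (near) zero, so the gradient of that final step with respect to the curvature c equals the implicit-function-theorem derivative. This is not the library's Surface.newtons_method, which handles general surface profiles and batches of rays.

```python
import torch


def intersect_paraboloid(o, d, c, n_iter=10):
    """Intersect the ray o + t*d with the hypothetical surface z = c*(x^2 + y^2)/2."""

    def residual(t):
        p = o + t * d
        return p[2] - 0.5 * c * (p[0] ** 2 + p[1] ** 2)

    def d_residual(t):
        p = o + t * d
        # d/dt [z - c*(x^2 + y^2)/2] = d_z - c*(x*d_x + y*d_y)
        return d[2] - c * (p[0] * d[0] + p[1] * d[1])

    # 1) Newton iterations without building an autograd graph.
    with torch.no_grad():
        t = -o[2] / d[2]  # start from the vertex plane z = 0
        for _ in range(n_iter):
            t = t - residual(t) / d_residual(t)

    # 2) One differentiable Newton step at the converged t: gradients w.r.t. c
    #    (and o, d if they require grad) flow through this single update only.
    t = t - residual(t) / d_residual(t)
    return o + t * d
```

For a ray parallel to the axis at x = 1, the intersection height is z = c / 2, so dz/dc = 0.5. The single differentiable step reproduces this gradient without storing the graph of the earlier iterations.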

Attributes:

Name Type Description
surfaces list[Surface]

Ordered list of optical surfaces.

materials list[Material]

Optical materials between surfaces.

d_sensor Tensor

Back focal distance [mm].

foclen float

Effective focal length [mm].

fnum float

F-number.

rfov float

Half-diagonal field of view [radians].

sensor_size tuple

Physical sensor size (W, H) [mm].

sensor_res tuple

Sensor resolution (W, H) [pixels].

pixel_size float

Pixel pitch [mm].

References

Xinge Yang et al., "Curriculum learning for ab initio deep learned refractive optics," Nature Communications 2024.

Initialize a refractive lens.

There are two ways to initialize a GeoLens:
  1. Read a lens from .json/.zmx/.seq file
  2. Initialize a lens with no lens file, then manually add surfaces and materials

Parameters:

Name Type Description Default
filename str

Path to lens file (.json, .zmx, or .seq). Defaults to None.

None
device device

Device for tensor computations. Defaults to None.

None
dtype dtype

Data type for computations. Defaults to torch.float32.

float32
Source code in deeplens/optics/geolens.py
def __init__(
    self,
    filename=None,
    device=None,
    dtype=torch.float32,
):
    """Initialize a refractive lens.

    There are two ways to initialize a GeoLens:
        1. Read a lens from .json/.zmx/.seq file
        2. Initialize a lens with no lens file, then manually add surfaces and materials

    Args:
        filename (str, optional): Path to lens file (.json, .zmx, or .seq). Defaults to None.
        device (torch.device, optional): Device for tensor computations. Defaults to None.
        dtype (torch.dtype, optional): Data type for computations. Defaults to torch.float32.
    """
    super().__init__(device=device, dtype=dtype)

    # Load lens file
    if filename is not None:
        self.read_lens(filename)
    else:
        self.surfaces = []
        self.materials = []
        # Set default sensor size and resolution
        self.sensor_size = (8.0, 8.0)
        self.sensor_res = (2000, 2000)
        self.to(self.device)

read_lens

read_lens(filename)

Read a GeoLens from a file.

Supported file formats:
  • .json: DeepLens native JSON format
  • .zmx: Zemax lens file format
  • .seq: CODE V sequence file format

Parameters:

Name Type Description Default
filename str

Path to the lens file.

required
Note

Sensor size and resolution will usually be overwritten by values from the file.

Source code in deeplens/optics/geolens.py
def read_lens(self, filename):
    """Read a GeoLens from a file.

    Supported file formats:
        - .json: DeepLens native JSON format
        - .zmx: Zemax lens file format
        - .seq: CODE V sequence file format

    Args:
        filename (str): Path to the lens file.

    Note:
        Sensor size and resolution will usually be overwritten by values from the file.
    """
    # Load lens file
    if filename[-4:] == ".txt":
        raise ValueError("File format .txt has been deprecated.")
    elif filename[-5:] == ".json":
        self.read_lens_json(filename)
    elif filename[-4:] == ".zmx":
        self.read_lens_zmx(filename)
    elif filename[-4:] == ".seq":
        self.read_lens_seq(filename)
    else:
        raise ValueError(f"File format {filename[-4:]} not supported.")

    # Complete sensor size and resolution if not set from lens file
    if not hasattr(self, "sensor_size"):
        self.sensor_size = (8.0, 8.0)
        print(
            f"Sensor_size not found in lens file. Using default: {self.sensor_size} mm. "
            "Consider specifying sensor_size in the lens file or using set_sensor()."
        )

    if not hasattr(self, "sensor_res"):
        self.sensor_res = (2000, 2000)
        print(
            f"Sensor_res not found in lens file. Using default: {self.sensor_res} pixels. "
            "Consider specifying sensor_res in the lens file or using set_sensor()."
        )
        self.set_sensor_res(self.sensor_res)

    # After loading lens, compute foclen, fov and fnum
    self.to(self.device)
    self.astype(self.dtype)
    self.post_computation()
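read_lens chooses a reader by slicing the last few characters of filename, so .json needs a different slice length than .zmx or .seq. A suffix-based dispatch table is one alternative. The sketch below uses pathlib.Path.suffix and a hypothetical readers mapping in place of read_lens_json, read_lens_zmx, and read_lens_seq. Lower-casing the suffix, so that names like LENS.ZMX are accepted, is an assumption and not the current behaviour.

```python
from pathlib import Path


def read_lens_sketch(filename, readers):
    """Dispatch on the file suffix; readers maps '.json', '.zmx', '.seq' to callables."""
    suffix = Path(filename).suffix.lower()
    if suffix == ".txt":
        raise ValueError("File format .txt has been deprecated.")
    if suffix not in readers:
        raise ValueError(f"File format {suffix!r} not supported.")
    return readers[suffix](filename)
```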

post_computation

post_computation()

Compute derived optical properties after loading or modifying lens.

Calculates and caches:
  • Effective focal length (EFL)
  • Entrance and exit pupil positions and radii
  • Field of view (FoV) in horizontal, vertical, and diagonal directions
  • F-number
Note

This method should be called after any changes to the lens geometry.

Source code in deeplens/optics/geolens.py
def post_computation(self):
    """Compute derived optical properties after loading or modifying lens.

    Calculates and caches:
        - Effective focal length (EFL)
        - Entrance and exit pupil positions and radii
        - Field of view (FoV) in horizontal, vertical, and diagonal directions
        - F-number

    Note:
        This method should be called after any changes to the lens geometry.
    """
    self.calc_foclen()
    self.calc_pupil()
    self.calc_fov()

__call__

__call__(ray)

Trace rays through the lens system.

Makes the GeoLens callable, allowing ray tracing with function call syntax.

Source code in deeplens/optics/geolens.py
def __call__(self, ray):
    """Trace rays through the lens system.

    Makes the GeoLens callable, allowing ray tracing with function call syntax.
    """
    return self.trace(ray)

sample_grid_rays

sample_grid_rays(depth=float('inf'), num_grid=(11, 11), num_rays=SPP_PSF, wvln=DEFAULT_WAVE, uniform_fov=True, sample_more_off_axis=False, scale_pupil=1.0)

Sample grid rays from object space. (1) If depth is infinite, sample parallel rays at different field angles. (2) If depth is finite, sample point source rays from the object plane.

This function is usually used for (1) PSF map, (2) RMS error map, and (3) spot diagram calculation.

Parameters:

Name Type Description Default
depth float

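Object depth [mm]. float("inf") samples parallel rays at different field angles; a finite value samples point-source rays from the object plane. Defaults to float("inf").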

float('inf')
num_grid tuple

Number of grid points (num_x, num_y); an int n is expanded to (n, n). Defaults to (11, 11).

(11, 11)
num_rays int

Number of rays per grid point. Defaults to SPP_PSF.

SPP_PSF
wvln float

Ray wavelength [µm]. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
uniform_fov bool

If True, sample field angles uniformly; if False, sample a uniform grid on the object plane.

True
sample_more_off_axis bool

If True, warp the normalized grid coordinates by sign(x)·|x|^0.5 so that grid points are denser toward the edge of the field.

False
scale_pupil float

Scale factor for pupil radius.

1.0

Returns:

Name Type Description
ray Ray object

Ray object. Shape [num_grid[1], num_grid[0], num_rays, 3]

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def sample_grid_rays(
    self,
    depth=float("inf"),
    num_grid=(11, 11),
    num_rays=SPP_PSF,
    wvln=DEFAULT_WAVE,
    uniform_fov=True,
    sample_more_off_axis=False,
    scale_pupil=1.0,
):
    """Sample grid rays from object space.
        (1) If depth is infinite, sample parallel rays at different field angles.
        (2) If depth is finite, sample point source rays from the object plane.

    This function is usually used for (1) PSF map, (2) RMS error map, and (3) spot diagram calculation.

    Args:
        depth (float, optional): sampling depth. Defaults to float("inf").
        num_grid (tuple, optional): number of grid points. Defaults to [11, 11].
        num_rays (int, optional): number of rays. Defaults to SPP_PSF.
        wvln (float, optional): ray wvln. Defaults to DEFAULT_WAVE.
        uniform_fov (bool, optional): If True, sample uniform FoV angles.
        sample_more_off_axis (bool, optional): If True, sample more off-axis rays.
        scale_pupil (float, optional): Scale factor for pupil radius.

    Returns:
        ray (Ray object): Ray object. Shape [num_grid[1], num_grid[0], num_rays, 3]
    """
    # Normalize num_grid to a tuple if it's an int
    if isinstance(num_grid, int):
        num_grid = (num_grid, num_grid)

    # Calculate field angles for grid source. Top-left field has positive fov_x and negative fov_y
    x_list = [x for x in np.linspace(1, -1, num_grid[0])]
    y_list = [y for y in np.linspace(-1, 1, num_grid[1])]
    if sample_more_off_axis:
        x_list = [np.sign(x) * np.abs(x) ** 0.5 for x in x_list]
        y_list = [np.sign(y) * np.abs(y) ** 0.5 for y in y_list]

    # Calculate FoV_x and FoV_y
    if uniform_fov:
        # Sample uniform FoV angles
        fov_x_list = [x * self.vfov / 2 for x in x_list]
        fov_y_list = [y * self.hfov / 2 for y in y_list]
        fov_x_list = [float(np.rad2deg(fov_x)) for fov_x in fov_x_list]
        fov_y_list = [float(np.rad2deg(fov_y)) for fov_y in fov_y_list]
    else:
        # Sample uniform object grid
        fov_x_list = [np.arctan(x * np.tan(self.vfov / 2)) for x in x_list]
        fov_y_list = [np.arctan(y * np.tan(self.hfov / 2)) for y in y_list]
        fov_x_list = [float(np.rad2deg(fov_x)) for fov_x in fov_x_list]
        fov_y_list = [float(np.rad2deg(fov_y)) for fov_y in fov_y_list]

    # Sample rays (collimated or point source via unified API)
    rays = self.sample_from_fov(
        fov_x=fov_x_list,
        fov_y=fov_y_list,
        depth=depth,
        num_rays=num_rays,
        wvln=wvln,
        scale_pupil=scale_pupil,
    )
    return rays
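The field-angle grid in sample_grid_rays can also be computed on its own. The sketch mirrors the method's steps:
  • Build normalized grid coordinates.
  • Optionally apply the square-root warp.
  • Convert to angles using either equal angular steps or equal steps on the object plane.
Instead of reading the lens's hfov/vfov attributes, it takes generic full field angles fov_x_full and fov_y_full (hypothetical names).

```python
import numpy as np


def field_angle_grid(num_grid, fov_x_full, fov_y_full, uniform_fov=True, sample_more_off_axis=False):
    """Return field angles [deg] along x and y for a grid of num_grid points.

    fov_x_full / fov_y_full are hypothetical full field angles [rad] along x and y.
    """
    if isinstance(num_grid, int):
        num_grid = (num_grid, num_grid)
    x = np.linspace(1, -1, num_grid[0])  # first column: +x field
    y = np.linspace(-1, 1, num_grid[1])  # first row: -y field
    if sample_more_off_axis:
        # The sqrt warp pushes normalized coordinates toward +-1 (denser near the field edge).
        x = np.sign(x) * np.abs(x) ** 0.5
        y = np.sign(y) * np.abs(y) ** 0.5
    if uniform_fov:
        # Equal steps in angle.
        fov_x = x * fov_x_full / 2
        fov_y = y * fov_y_full / 2
    else:
        # Equal steps on the object plane: tan(angle) is linear in the grid coordinate.
        fov_x = np.arctan(x * np.tan(fov_x_full / 2))
        fov_y = np.arctan(y * np.tan(fov_y_full / 2))
    return np.rad2deg(fov_x), np.rad2deg(fov_y)
```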

sample_radial_rays

sample_radial_rays(num_field=5, depth=float('inf'), num_rays=SPP_PSF, wvln=DEFAULT_WAVE, direction='y')

Sample radial rays at evenly-spaced field angles along a chosen direction.

Parameters:

Name Type Description Default
num_field int

Number of field angles from on-axis to full-field. Defaults to 5.

5
depth float

Object distance in mm. Use float('inf') for collimated light. Defaults to float('inf').

float('inf')
num_rays int

Rays per field position. Defaults to SPP_PSF.

SPP_PSF
wvln float

Wavelength in micrometers. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
direction str

Sampling direction: "y" (meridional, default), "x" (sagittal), or "diagonal" (45°, x = y).

'y'

Returns:

Name Type Description
Ray

Ray object with shape [num_field, num_rays, 3].

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def sample_radial_rays(
    self,
    num_field=5,
    depth=float("inf"),
    num_rays=SPP_PSF,
    wvln=DEFAULT_WAVE,
    direction="y",
):
    """Sample radial rays at evenly-spaced field angles along a chosen direction.

    Args:
        num_field (int): Number of field angles from on-axis to full-field.
            Defaults to 5.
        depth (float): Object distance in mm. Use ``float('inf')`` for
            collimated light. Defaults to ``float('inf')``.
        num_rays (int): Rays per field position. Defaults to ``SPP_PSF``.
        wvln (float): Wavelength in micrometers. Defaults to ``DEFAULT_WAVE``.
        direction (str): Sampling direction —
            ``"y"`` (meridional, default),
            ``"x"`` (sagittal),
            ``"diagonal"`` (45°, x = y).

    Returns:
        Ray: Ray object with shape ``[num_field, num_rays, 3]``.
    """
    device = self.device
    fov_deg = float(np.rad2deg(self.rfov))
    fov_list = torch.linspace(0, fov_deg, num_field, device=device)

    if direction == "y":
        ray = self.sample_from_fov(
            fov_x=0.0, fov_y=fov_list, depth=depth, num_rays=num_rays, wvln=wvln
        )
    elif direction == "x":
        ray = self.sample_from_fov(
            fov_x=fov_list, fov_y=0.0, depth=depth, num_rays=num_rays, wvln=wvln
        )
    elif direction == "diagonal":
        # sample_from_fov creates a meshgrid; for pairwise diagonal, loop
        rays = [
            self.sample_from_fov(
                fov_x=f.item(), fov_y=f.item(), depth=depth, num_rays=num_rays, wvln=wvln
            )
            for f in fov_list
        ]
        ray_o = torch.stack([r.o for r in rays], dim=0)
        ray_d = torch.stack([r.d for r in rays], dim=0)
        ray = Ray(ray_o, ray_d, wvln, device=device)
    else:
        raise ValueError(f"Invalid direction: {direction!r}. Use 'x', 'y', or 'diagonal'.")
    return ray

sample_from_points

sample_from_points(points=[[0.0, 0.0, -10000.0]], num_rays=SPP_PSF, wvln=DEFAULT_WAVE, scale_pupil=1.0)

Sample rays from point sources in object space (absolute physical coordinates).

Used for PSF and chief ray calculation.

Parameters:

Name Type Description Default
points list or Tensor

Ray origins in shape [3], [N, 3], or [Nx, Ny, 3].

[[0.0, 0.0, -10000.0]]
num_rays int

Number of rays per point. Default: SPP_PSF.

SPP_PSF
wvln float

Wavelength of rays. Default: DEFAULT_WAVE.

DEFAULT_WAVE
scale_pupil float

Scale factor for pupil radius.

1.0

Returns:

Name Type Description
Ray

Sampled rays with shape (*points.shape[:-1], num_rays, 3).

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def sample_from_points(
    self,
    points=[[0.0, 0.0, -10000.0]],
    num_rays=SPP_PSF,
    wvln=DEFAULT_WAVE,
    scale_pupil=1.0,
):
    """
    Sample rays from point sources in object space (absolute physical coordinates).

    Used for PSF and chief ray calculation.

    Args:
        points (list or Tensor): Ray origins in shape [3], [N, 3], or [Nx, Ny, 3].
        num_rays (int): Number of rays per point. Default: SPP_PSF.
        wvln (float): Wavelength of rays. Default: DEFAULT_WAVE.
        scale_pupil (float): Scale factor for pupil radius.

    Returns:
        Ray: Sampled rays with shape ``(\\*points.shape[:-1], num_rays, 3)``.
    """
    # Ray origin is given
    if not torch.is_tensor(points):
        ray_o = torch.tensor(points, device=self.device)
    else:
        ray_o = points.to(self.device)

    # Sample points on the pupil
    pupilz, pupilr = self.get_entrance_pupil()
    pupilr = pupilr * scale_pupil  # out-of-place, so the returned radius is never modified
    ray_o2 = self.sample_circle(
        r=pupilr, z=pupilz, shape=(*ray_o.shape[:-1], num_rays)
    )

    # Compute ray directions
    if len(ray_o.shape) == 1:
        # Input point shape is [3]
        ray_o = ray_o.unsqueeze(0).repeat(num_rays, 1)  # shape [num_rays, 3]
        ray_d = ray_o2 - ray_o

    elif len(ray_o.shape) == 2:
        # Input point shape is [N, 3]
        ray_o = ray_o.unsqueeze(1).repeat(1, num_rays, 1)  # shape [N, num_rays, 3]
        ray_d = ray_o2 - ray_o

    elif len(ray_o.shape) == 3:
        # Input point shape is [Nx, Ny, 3]
        ray_o = ray_o.unsqueeze(2).repeat(
            1, 1, num_rays, 1
        )  # shape [Nx, Ny, num_rays, 3]
        ray_d = ray_o2 - ray_o

    else:
        raise Exception("The shape of input object positions is not supported.")

    # Calculate rays
    rays = Ray(ray_o, ray_d, wvln, device=self.device)
    return rays
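sample_from_points aims each ray from an object point toward a random point on the entrance pupil. The sketch below shows one common way to do this, assuming the pupil disk is sampled uniformly by area (radius r·sqrt(u), with u uniform in [0, 1)). The library's sample_circle may use a different scheme, and the function names here are hypothetical. The sketch also returns unit direction vectors.

```python
import math
import torch


def sample_disk_sketch(r, z, shape):
    """Uniform-by-area points on a disk of radius r in the plane z; returns [*shape, 3]."""
    rho = r * torch.sqrt(torch.rand(*shape))  # sqrt keeps the density uniform per unit area
    theta = 2 * math.pi * torch.rand(*shape)
    x, y = rho * torch.cos(theta), rho * torch.sin(theta)
    return torch.stack((x, y, torch.full_like(x, z)), dim=-1)


def rays_from_points_sketch(points, pupil_z, pupil_r, num_rays):
    """Origins and unit directions [..., num_rays, 3] from object points [..., 3]."""
    targets = sample_disk_sketch(pupil_r, pupil_z, (*points.shape[:-1], num_rays))
    origins = points.unsqueeze(-2).expand_as(targets)
    d = targets - origins
    return origins, d / d.norm(dim=-1, keepdim=True)
```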

sample_from_points_by_fov

sample_from_points_by_fov(fov_x=[0.0], fov_y=[0.0], depth=DEPTH, num_rays=SPP_PSF, wvln=DEFAULT_WAVE, scale_pupil=1.0)

Sample point-source rays specified by field angles and depth.

Converts field angles to physical object-space coordinates, then delegates to sample_from_points.

Parameters:

Name Type Description Default
fov_x float or list

Field angle(s) in the xz plane (degrees).

[0.0]
fov_y float or list

Field angle(s) in the yz plane (degrees).

[0.0]
depth float

Object distance in mm. Default: DEPTH.

DEPTH
num_rays int

Number of rays per field point. Default: SPP_PSF.

SPP_PSF
wvln float

Wavelength of rays. Default: DEFAULT_WAVE.

DEFAULT_WAVE
scale_pupil float

Scale factor for pupil radius. Default: 1.0.

1.0

Returns:

Name Type Description
Ray

Rays with shape [..., num_rays, 3], where leading dims follow the same scalar-squeeze convention as sample_from_fov.

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def sample_from_points_by_fov(
    self,
    fov_x=[0.0],
    fov_y=[0.0],
    depth=DEPTH,
    num_rays=SPP_PSF,
    wvln=DEFAULT_WAVE,
    scale_pupil=1.0,
):
    """Sample point-source rays specified by field angles and depth.

    Converts field angles to physical object-space coordinates, then
    delegates to :meth:`sample_from_points`.

    Args:
        fov_x (float or list): Field angle(s) in the xz plane (degrees).
        fov_y (float or list): Field angle(s) in the yz plane (degrees).
        depth (float): Object distance in mm. Default: ``DEPTH``.
        num_rays (int): Number of rays per field point. Default: SPP_PSF.
        wvln (float): Wavelength of rays. Default: DEFAULT_WAVE.
        scale_pupil (float): Scale factor for pupil radius. Default: 1.0.

    Returns:
        Ray:
            Rays with shape [..., num_rays, 3], where leading dims follow
            the same scalar-squeeze convention as :meth:`sample_from_fov`.
    """
    x_scalar = isinstance(fov_x, (float, int))
    y_scalar = isinstance(fov_y, (float, int))
    if x_scalar:
        fov_x = [float(fov_x)]
    if y_scalar:
        fov_y = [float(fov_y)]

    fov_x_rad = torch.tensor([fx * torch.pi / 180 for fx in fov_x], device=self.device)
    fov_y_rad = torch.tensor([fy * torch.pi / 180 for fy in fov_y], device=self.device)
    fov_x_grid, fov_y_grid = torch.meshgrid(fov_x_rad, fov_y_rad, indexing="xy")
    x = torch.tan(fov_x_grid) * depth
    y = torch.tan(fov_y_grid) * depth
    z = torch.full_like(x, depth)
    points = torch.stack((x, y, z), dim=-1)  # [len(fov_y), len(fov_x), 3]

    # Squeeze scalar dims before sample_from_points so the Ray is
    # constructed with the right shape (avoids post-hoc attribute edits).
    if x_scalar:
        points = points.squeeze(-2)
    if y_scalar:
        points = points.squeeze(0)

    return self.sample_from_points(
        points=points, num_rays=num_rays, wvln=wvln, scale_pupil=scale_pupil
    )
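sample_from_points_by_fov converts angles to positions using x = tan(fov_x)·depth and y = tan(fov_y)·depth on the plane z = depth. It arranges the results with torch.meshgrid(indexing="xy"), giving [len(fov_y), len(fov_x), 3] before any scalar inputs are squeezed. The standalone sketch below repeats that logic without a lens object; points_from_fov_sketch is a hypothetical name. With a negative depth, as in the default point [0, 0, -10000], a positive field angle maps to a negative coordinate.

```python
import torch


def points_from_fov_sketch(fov_x, fov_y, depth):
    """Object-plane points [..., 3] for field angles in degrees at z = depth [mm]."""
    x_scalar = isinstance(fov_x, (int, float))
    y_scalar = isinstance(fov_y, (int, float))
    fx = torch.deg2rad(torch.tensor([fov_x] if x_scalar else fov_x, dtype=torch.float32))
    fy = torch.deg2rad(torch.tensor([fov_y] if y_scalar else fov_y, dtype=torch.float32))
    # indexing="xy" gives grids of shape [len(fy), len(fx)]: rows follow y, columns follow x.
    gx, gy = torch.meshgrid(fx, fy, indexing="xy")
    x = torch.tan(gx) * depth
    y = torch.tan(gy) * depth
    points = torch.stack((x, y, torch.full_like(x, depth)), dim=-1)
    # Drop the axis of any field angle that was passed as a scalar.
    if x_scalar:
        points = points.squeeze(-2)
    if y_scalar:
        points = points.squeeze(0)
    return points
```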

sample_from_fov

sample_from_fov(fov_x=[0.0], fov_y=[0.0], depth=float('inf'), num_rays=SPP_CALC, wvln=DEFAULT_WAVE, entrance_pupil=True, prop_to=-1.0, scale_pupil=1.0)

Sample rays from object space at given field angles.

Unified entry point for both collimated (infinite-depth) and diverging (finite-depth) ray bundles specified by field-of-view angles.

Parameters:

Name Type Description Default
fov_x float or list

Field angle(s) in the xz plane (degrees). Default: [0.0].

[0.0]
fov_y float or list

Field angle(s) in the yz plane (degrees). Default: [0.0].

[0.0]
depth float

Object distance in mm. float('inf') for collimated rays, finite value for point-source rays. Default: float('inf').

float('inf')
num_rays int

Number of rays per field point. Default: SPP_CALC.

SPP_CALC
wvln float

Wavelength of rays. Default: DEFAULT_WAVE.

DEFAULT_WAVE
entrance_pupil bool

If True, sample origins on entrance pupil; otherwise, on surface 0. Only used for infinite depth. Default: True.

True
prop_to float

z position [mm] to which the sampled rays are propagated (only used for infinite depth). Default: -1.0.

-1.0
scale_pupil float

Scale factor for pupil radius. Default: 1.0.

1.0

Returns:

Name Type Description
Ray

Rays with shape [..., num_rays, 3], where the leading dims are:
  • both fov_x and fov_y scalars: [num_rays, 3]
  • fov_x scalar: [len(fov_y), num_rays, 3]
  • fov_y scalar: [len(fov_x), num_rays, 3]
  • both lists: [len(fov_y), len(fov_x), num_rays, 3]
Ordered as (u, v).

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def sample_from_fov(
    self,
    fov_x=[0.0],
    fov_y=[0.0],
    depth=float("inf"),
    num_rays=SPP_CALC,
    wvln=DEFAULT_WAVE,
    entrance_pupil=True,
    prop_to=-1.0,
    scale_pupil=1.0,
):
    """Sample rays from object space at given field angles.

    Unified entry point for both collimated (infinite-depth) and diverging
    (finite-depth) ray bundles specified by field-of-view angles.

    Args:
        fov_x (float or list): Field angle(s) in the xz plane (degrees). Default: [0.0].
        fov_y (float or list): Field angle(s) in the yz plane (degrees). Default: [0.0].
        depth (float): Object distance in mm. ``float('inf')`` for collimated
            rays, finite value for point-source rays. Default: ``float('inf')``.
        num_rays (int): Number of rays per field point. Default: SPP_CALC.
        wvln (float): Wavelength of rays. Default: DEFAULT_WAVE.
        entrance_pupil (bool): If True, sample origins on entrance pupil;
            otherwise, on surface 0. Only used for infinite depth. Default: True.
        prop_to (float): Propagation depth in z (only for infinite depth). Default: -1.0.
        scale_pupil (float): Scale factor for pupil radius. Default: 1.0.

    Returns:
        Ray:
            Rays with shape [..., num_rays, 3], where leading dims are:
            - both fov_x and fov_y scalars: [num_rays, 3]
            - fov_x scalar: [len(fov_y), num_rays, 3]
            - fov_y scalar: [len(fov_x), num_rays, 3]
            - both lists: [len(fov_y), len(fov_x), num_rays, 3]
            Ordered as (u, v).
    """
    # Finite depth: delegate to sample_from_points_by_fov
    if depth != float("inf"):
        return self.sample_from_points_by_fov(
            fov_x=fov_x, fov_y=fov_y, depth=depth,
            num_rays=num_rays, wvln=wvln, scale_pupil=scale_pupil,
        )

    # --- Infinite depth: collimated parallel rays ---
    # Remember whether inputs were scalar
    x_scalar = isinstance(fov_x, (float, int))
    y_scalar = isinstance(fov_y, (float, int))

    # Normalize to lists for internal processing
    if x_scalar:
        fov_x = [float(fov_x)]
    if y_scalar:
        fov_y = [float(fov_y)]

    fov_x = torch.tensor([fx * torch.pi / 180 for fx in fov_x], device=self.device)
    fov_y = torch.tensor([fy * torch.pi / 180 for fy in fov_y], device=self.device)

    # Sample ray origins on the pupil
    if entrance_pupil:
        pupilz, pupilr = self.get_entrance_pupil()
        pupilr *= scale_pupil
    else:
        pupilz, pupilr = 0.0, self.surfaces[0].r
        pupilr *= scale_pupil

    ray_o = self.sample_circle(
        r=pupilr, z=pupilz, shape=[len(fov_y), len(fov_x), num_rays]
    )

    # Sample ray directions
    fov_x_grid, fov_y_grid = torch.meshgrid(fov_x, fov_y, indexing="xy")
    dx = torch.tan(fov_x_grid).unsqueeze(-1).expand_as(ray_o[..., 0])
    dy = torch.tan(fov_y_grid).unsqueeze(-1).expand_as(ray_o[..., 1])
    dz = torch.ones_like(ray_o[..., 2])
    ray_d = torch.stack((dx, dy, dz), dim=-1)

    # Squeeze singleton FOV dims only if the original input was scalar
    if x_scalar:
        ray_o = ray_o.squeeze(1)
        ray_d = ray_d.squeeze(1)
    if y_scalar:
        ray_o = ray_o.squeeze(0)
        ray_d = ray_d.squeeze(0)

    rays = Ray(ray_o, ray_d, wvln, device=self.device)
    rays.prop_to(prop_to)
    return rays
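For collimated (infinite-depth) bundles, each field-angle pair maps to the unnormalized direction (tan θx, tan θy, 1), and the meshgrid with indexing="xy" places fov_y on the first axis. The pure-Python sketch below reproduces only that bookkeeping; the helper name `collimated_directions` is hypothetical and not part of DeepLens.

```python
import math


def collimated_directions(fov_x_deg, fov_y_deg):
    """Return a nested list indexed [iy][ix] of (dx, dy, dz) direction tuples.

    Mirrors the infinite-depth branch of sample_from_fov: angles in degrees
    are converted to radians and mapped to (tan(fx), tan(fy), 1). Rows follow
    fov_y and columns follow fov_x, like torch.meshgrid(..., indexing="xy").
    """
    fx = [math.radians(a) for a in fov_x_deg]
    fy = [math.radians(a) for a in fov_y_deg]
    return [[(math.tan(x), math.tan(y), 1.0) for x in fx] for y in fy]


# Two horizontal angles and three vertical angles give a 3 x 2 grid
dirs = collimated_directions([0.0, 5.0], [0.0, 10.0, 20.0])
print(len(dirs), len(dirs[0]))  # 3 2
```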

sample_sensor

sample_sensor(spp=64, wvln=DEFAULT_WAVE, sub_pixel=False)

Sample rays from sensor pixels (backward rays). Used for ray tracing rendering.

Parameters:

Name Type Description Default
spp int

Samples per pixel. Defaults to 64.

64
wvln float

Wavelength of the rays. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
sub_pixel bool

Whether to randomly jitter each sample inside its pixel. Defaults to False.

False

Returns:

Name Type Description
ray Ray object

Ray object. Shape [H, W, spp, 3]

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def sample_sensor(self, spp=64, wvln=DEFAULT_WAVE, sub_pixel=False):
    """Sample rays from sensor pixels (backward rays). Used for ray tracing rendering.

    Args:
        spp (int, optional): Samples per pixel. Defaults to 64.
        wvln (float, optional): Wavelength of the rays. Defaults to DEFAULT_WAVE.
        sub_pixel (bool, optional): Whether to randomly jitter each sample inside its pixel. Defaults to False.

    Returns:
        ray (Ray object): Ray object. Shape [H, W, spp, 3]
    """
    w, h = self.sensor_size
    W, H = self.sensor_res
    device = self.device

    # Sample points on sensor plane
    # Use top-left point as reference in rendering, so here we should sample bottom-right point
    x1, y1 = torch.meshgrid(
        torch.linspace(
            -w / 2,
            w / 2,
            W + 1,
            device=device,
        )[1:],
        torch.linspace(
            h / 2,
            -h / 2,
            H + 1,
            device=device,
        )[1:],
        indexing="xy",
    )
    z1 = torch.full_like(x1, self.d_sensor)

    # Sample second points on the pupil
    pupilz, pupilr = self.get_exit_pupil()
    # sensor_res is (W, H); use (H, W, spp) to match the [H, W, spp, 3] layout of ray_o
    ray_o2 = self.sample_circle(r=pupilr, z=pupilz, shape=(H, W, spp))

    # Form rays
    ray_o = torch.stack((x1, y1, z1), 2)
    ray_o = ray_o.unsqueeze(2).repeat(1, 1, spp, 1)  # [H, W, spp, 3]

    # Sub-pixel sampling for more realistic rendering
    if sub_pixel:
        delta_ox = (
            torch.rand(ray_o.shape[:-1], device=device)
            * self.pixel_size
        )
        delta_oy = (
            -torch.rand(ray_o.shape[:-1], device=device)
            * self.pixel_size
        )
        delta_oz = torch.zeros_like(delta_ox)
        delta_o = torch.stack((delta_ox, delta_oy, delta_oz), -1)
        ray_o = ray_o + delta_o

    # Form rays
    ray_d = ray_o2 - ray_o  # shape [H, W, spp, 3]
    ray = Ray(ray_o, ray_d, wvln, device=device)
    return ray
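The meshgrid in sample_sensor keeps the last W (and H) of W + 1 (H + 1) evenly spaced edge positions, i.e. one corner per pixel, which matches the source comment about sampling the bottom-right point. The pure-Python sketch below shows only that coordinate arithmetic; the helper names are hypothetical and `torch.linspace` is replaced by a list comprehension.

```python
def linspace(start, stop, num):
    """Pure-Python stand-in for torch.linspace with num >= 2 points."""
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def sensor_sample_coords(w, h, W, H):
    """x and y coordinates used by sample_sensor for a w x h sensor of W x H pixels.

    Dropping the first of W + 1 (or H + 1) evenly spaced edges keeps one
    corner per pixel: the right edge in x (left to right) and, because y
    runs from +h/2 down to -h/2, the bottom edge in y (top to bottom).
    """
    xs = linspace(-w / 2, w / 2, W + 1)[1:]
    ys = linspace(h / 2, -h / 2, H + 1)[1:]
    return xs, ys


xs, ys = sensor_sample_coords(w=4.0, h=2.0, W=4, H=2)
print(xs)  # [-1.0, 0.0, 1.0, 2.0]
print(ys)  # [0.0, -1.0]
```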

sample_circle

sample_circle(r, z, shape=[16, 16, 512])

Sample points inside a circle.

Parameters:

Name Type Description Default
r float

Radius of the circle.

required
z float

Z-coordinate for all sampled points.

required
shape list

Shape of the output tensor.

[16, 16, 512]

Returns:

Type Description

torch.Tensor: Sampled points, shape (*shape, 3).

Source code in deeplens/optics/geolens.py
def sample_circle(self, r, z, shape=[16, 16, 512]):
    """Sample points inside a circle.

    Args:
        r (float): Radius of the circle.
        z (float): Z-coordinate for all sampled points.
        shape (list): Shape of the output tensor.

    Returns:
        torch.Tensor: Sampled points, shape ``(\\*shape, 3)``.
    """
    device = self.device

    # Generate random angles and radii
    theta = torch.rand(*shape, device=device) * 2 * torch.pi
    r2 = torch.rand(*shape, device=device) * r**2
    radius = torch.sqrt(r2)

    # Stack to form 3D points
    x = radius * torch.cos(theta)
    y = radius * torch.sin(theta)
    z_tensor = torch.full_like(x, z)
    points = torch.stack((x, y, z_tensor), dim=-1)

    # Manually sample chief ray
    # points[..., 0, :2] = 0.0

    return points
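sample_circle draws the radius as sqrt(U · r²) rather than U · r. Taking the square root makes the points uniform in area instead of clustering near the center. The pure-Python sketch below uses the same technique; the function name and the use of `random.Random` are illustrative, not DeepLens API.

```python
import math
import random


def sample_disk(r, n, seed=0):
    """Draw n points uniformly (by area) inside a disk of radius r."""
    rng = random.Random(seed)
    pts = []
    for _ in range(n):
        theta = rng.random() * 2 * math.pi
        radius = math.sqrt(rng.random() * r**2)  # sqrt gives uniform area density
        pts.append((radius * math.cos(theta), radius * math.sin(theta)))
    return pts


pts = sample_disk(r=2.0, n=20000)
# For uniform-area sampling, about 25% of points fall inside half the radius
inner = sum(1 for x, y in pts if math.hypot(x, y) < 1.0) / len(pts)
print(round(inner, 2))
```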

trace

trace(ray, surf_range=None, record=False)

Trace rays through the lens.

The tracing direction is chosen automatically: forward tracing is used if any ray has a positive z-direction component; otherwise backward tracing is used.

Parameters:

Name Type Description Default
ray Ray object

Rays to trace.

required
surf_range list

Surface indices to trace through. Defaults to all surfaces.

None
record bool

Whether to record the ray path.

False

Returns:

Name Type Description
ray_out Ray object

Rays after the optical system.

ray_o_rec list or None

Ray positions at each surface, or None if record is False.

Source code in deeplens/optics/geolens.py
def trace(self, ray, surf_range=None, record=False):
    """Trace rays through the lens.

    The tracing direction is chosen automatically: forward tracing is used if
    any ray has a positive z-direction component; otherwise backward tracing.

    Args:
        ray (Ray object): Rays to trace.
        surf_range (list): Surface indices to trace through. Defaults to all surfaces.
        record (bool): Whether to record the ray path.

    Returns:
        ray_out (Ray object): Rays after the optical system.
        ray_o_rec (list or None): Ray positions at each surface, or None if record is False.
    """
    if surf_range is None:
        surf_range = range(0, len(self.surfaces))

    if (ray.d[..., 2] > 0).any():
        ray_out, ray_o_rec = self.forward_tracing(ray, surf_range, record=record)
    else:
        ray_out, ray_o_rec = self.backward_tracing(ray, surf_range, record=record)

    return ray_out, ray_o_rec

trace2obj

trace2obj(ray)

Traces rays backwards through all lens surfaces from sensor side to object side.

Parameters:

Name Type Description Default
ray Ray

Ray object to trace backwards.

required

Returns:

Name Type Description
Ray

Ray object after backward propagation through the lens.

Source code in deeplens/optics/geolens.py
def trace2obj(self, ray):
    """Traces rays backwards through all lens surfaces from sensor side
    to object side.

    Args:
        ray (Ray): Ray object to trace backwards.

    Returns:
        Ray: Ray object after backward propagation through the lens.
    """
    ray, _ = self.trace(ray)
    return ray

trace2sensor

trace2sensor(ray, record=False)

Trace rays forward through the lens to the sensor plane.

Parameters:

Name Type Description Default
ray Ray object

Rays to trace.

required
record bool

Whether to record the ray path.

False

Returns:

Type Description

Ray or tuple: The rays at the sensor plane if record is False; otherwise a tuple (ray_out, ray_o_record), where ray_o_record lists the ray positions at each surface, ending at the sensor.

Source code in deeplens/optics/geolens.py
def trace2sensor(self, ray, record=False):
    """Trace rays forward through the lens to the sensor plane.

    Args:
        ray (Ray object): Rays to trace.
        record (bool): Whether to record the ray path.

    Returns:
        Ray or tuple: The rays at the sensor plane if record is False; otherwise
            (ray_out, ray_o_record), where ray_o_record lists the ray positions
            at each surface, ending at the sensor.
    """
    # Manually propagate ray to a shallow depth to avoid numerical instability
    if ray.o[..., 2].min() < -100.0:
        ray = ray.prop_to(-10.0)

    # Trace rays
    ray, ray_o_record = self.trace(ray, record=record)
    ray = ray.prop_to(self.d_sensor)

    if record:
        ray_o = ray.o.clone().detach()
        # Set to NaN to be skipped in 2d layout visualization
        ray_o[ray.is_valid == 0] = float("nan")
        ray_o_record.append(ray_o)
        return ray, ray_o_record
    else:
        return ray
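trace2sensor relies on Ray.prop_to to move rays to a given z plane, both before tracing (to z = -10 for distant origins) and after it (to d_sensor). Its implementation is not shown on this page. Assuming it follows the usual straight-line model o + t·d with t = (z - o_z) / d_z, the arithmetic looks like the pure-Python sketch below. `prop_to_plane` is a hypothetical helper.

```python
def prop_to_plane(o, d, z):
    """Move a ray with origin o and direction d along its straight path to plane z.

    Assumes d[2] != 0; a ray parallel to the plane never reaches it.
    """
    t = (z - o[2]) / d[2]
    return tuple(oi + t * di for oi, di in zip(o, d))


# A ray starting far in object space is pulled to z = -10 before tracing
print(prop_to_plane((0.0, 1.0, -1000.0), (0.0, 0.0, 1.0), -10.0))  # (0.0, 1.0, -10.0)
```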

trace2exit_pupil

trace2exit_pupil(ray)

Trace rays forward to the sensor plane, then propagate them along their straight paths back to the exit pupil plane.

Parameters:

Name Type Description Default
ray Ray

Ray object to trace.

required

Returns:

Name Type Description
Ray

Ray object propagated to the exit pupil plane.

Source code in deeplens/optics/geolens.py
def trace2exit_pupil(self, ray):
    """Trace rays forward to the sensor plane, then propagate them along their straight paths back to the exit pupil plane.

    Args:
        ray (Ray): Ray object to trace.

    Returns:
        Ray: Ray object propagated to the exit pupil plane.
    """
    ray = self.trace2sensor(ray)
    pupil_z, _ = self.get_exit_pupil()
    ray = ray.prop_to(pupil_z)
    return ray

forward_tracing

forward_tracing(ray, surf_range, record)

Traces rays forward through each surface in the specified range, from the object side to the image side.

Parameters:

Name Type Description Default
ray Ray

Ray object to trace.

required
surf_range range

Range of surface indices to trace through.

required
record bool

If True, record ray positions at each surface.

required

Returns:

Name Type Description
tuple

(ray_out, ray_o_record) where: - ray_out (Ray): Ray after propagation through all surfaces. - ray_o_record (list or None): List of ray positions at each surface, or None if record is False.

Source code in deeplens/optics/geolens.py
def forward_tracing(self, ray, surf_range, record):
    """Traces rays forward through each surface in the specified range, from the object side to the image side.

    Args:
        ray (Ray): Ray object to trace.
        surf_range (range): Range of surface indices to trace through.
        record (bool): If True, record ray positions at each surface.

    Returns:
        tuple: (ray_out, ray_o_record) where:
            - ray_out (Ray): Ray after propagation through all surfaces.
            - ray_o_record (list or None): List of ray positions at each surface,
                or None if record is False.
    """
    if record:
        ray_o_record = []
        ray_o_record.append(ray.o.clone().detach())
    else:
        ray_o_record = None

    mat1 = Material("air")
    for i in surf_range:
        n1 = mat1.ior(ray.wvln)
        n2 = self.surfaces[i].mat2.ior(ray.wvln)
        ray = self.surfaces[i].ray_reaction(ray, n1, n2)
        mat1 = self.surfaces[i].mat2

        if record:
            ray_out_o = ray.o.clone().detach()
            ray_out_o[ray.is_valid == 0] = float("nan")
            ray_o_record.append(ray_out_o)

    return ray, ray_o_record

backward_tracing

backward_tracing(ray, surf_range, record)

Traces rays backward through each surface in reverse order, from the image side to the object side.

Parameters:

Name Type Description Default
ray Ray

Ray object to trace.

required
surf_range range

Range of surface indices to trace through.

required
record bool

If True, record ray positions at each surface.

required

Returns:

Name Type Description
tuple

(ray_out, ray_o_record) where: - ray_out (Ray): Ray after backward propagation through all surfaces. - ray_o_record (list or None): List of ray positions at each surface, or None if record is False.

Source code in deeplens/optics/geolens.py
def backward_tracing(self, ray, surf_range, record):
    """Traces rays backward through each surface in reverse order, from the image side to the object side.

    Args:
        ray (Ray): Ray object to trace.
        surf_range (range): Range of surface indices to trace through.
        record (bool): If True, record ray positions at each surface.

    Returns:
        tuple: (ray_out, ray_o_record) where:
            - ray_out (Ray): Ray after backward propagation through all surfaces.
            - ray_o_record (list or None): List of ray positions at each surface,
                or None if record is False.
    """
    if record:
        ray_o_record = []
        ray_o_record.append(ray.o.clone().detach())
    else:
        ray_o_record = None

    mat1 = Material("air")
    for i in np.flip(surf_range):
        n1 = mat1.ior(ray.wvln)
        n2 = self.surfaces[i - 1].mat2.ior(ray.wvln)
        ray = self.surfaces[i].ray_reaction(ray, n1, n2)
        mat1 = self.surfaces[i - 1].mat2

        if record:
            ray_out_o = ray.o.clone().detach()
            ray_out_o[ray.is_valid == 0] = float("nan")
            ray_o_record.append(ray_out_o)

    return ray, ray_o_record
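forward_tracing and backward_tracing differ mainly in which material pair (n1, n2) each surface sees. In backward tracing the outgoing medium is surfaces[i - 1].mat2; at i = 0 the index i - 1 = -1 wraps to the last surface, so the code appears to rely on the last surface's mat2 being the same medium as object space (typically air). The pure-Python sketch below reproduces that bookkeeping, with material names standing in for Material objects. The helper names are hypothetical.

```python
def forward_pairs(mat2):
    """(surface index, incoming medium, outgoing medium) in forward order."""
    pairs, mat1 = [], "air"
    for i in range(len(mat2)):
        pairs.append((i, mat1, mat2[i]))
        mat1 = mat2[i]
    return pairs


def backward_pairs(mat2):
    """Same bookkeeping in reverse order, as in backward_tracing."""
    pairs, mat1 = [], "air"
    for i in reversed(range(len(mat2))):
        # At i == 0, mat2[i - 1] is mat2[-1]: the last surface's medium
        pairs.append((i, mat1, mat2[i - 1]))
        mat1 = mat2[i - 1]
    return pairs


# Singlet: surface 0 enters the glass, surface 1 exits into air
mat2 = ["N-BK7", "air"]
print(forward_pairs(mat2))   # [(0, 'air', 'N-BK7'), (1, 'N-BK7', 'air')]
print(backward_pairs(mat2))  # [(1, 'air', 'N-BK7'), (0, 'N-BK7', 'air')]
```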

render

render(img_obj, depth=DEPTH, method='ray_tracing', **kwargs)

Differentiable image simulation.

Image simulation methods:

[1] PSF map block convolution.
[2] PSF patch convolution.
[3] Ray tracing rendering.

Parameters:

Name Type Description Default
img_obj Tensor

Input image object in raw space. Shape of [N, C, H, W].

required
depth float

Depth of the object. Defaults to DEPTH.

DEPTH
method str

Image simulation method. One of 'psf_map', 'psf_patch', or 'ray_tracing'. Defaults to 'ray_tracing'.

'ray_tracing'
**kwargs

Additional arguments for different methods:

- psf_grid (tuple): Grid size for the PSF map method. Defaults to (10, 10).
- psf_ks (int): Kernel size for the PSF methods. Defaults to PSF_KS.
- patch_center (tuple): Center position for the PSF patch method. Defaults to (0.0, 0.0).
- spp (int): Samples per pixel for ray tracing. Defaults to SPP_RENDER.

{}

Returns:

Name Type Description
Tensor

Rendered image tensor. Shape of [N, C, H, W].

Source code in deeplens/optics/geolens.py
def render(self, img_obj, depth=DEPTH, method="ray_tracing", **kwargs):
    """Differentiable image simulation.

    Image simulation methods:
        [1] PSF map block convolution.
        [2] PSF patch convolution.
        [3] Ray tracing rendering.

    Args:
        img_obj (Tensor): Input image object in raw space. Shape of [N, C, H, W].
        depth (float, optional): Depth of the object. Defaults to DEPTH.
        method (str, optional): Image simulation method. One of 'psf_map', 'psf_patch',
            or 'ray_tracing'. Defaults to 'ray_tracing'.
        **kwargs: Additional arguments for different methods:
            - psf_grid (tuple): Grid size for PSF map method. Defaults to (10, 10).
            - psf_ks (int): Kernel size for PSF methods. Defaults to PSF_KS.
            - patch_center (tuple): Center position for PSF patch method.
            - spp (int): Samples per pixel for ray tracing. Defaults to SPP_RENDER.

    Returns:
        Tensor: Rendered image tensor. Shape of [N, C, H, W].
    """
    B, C, Himg, Wimg = img_obj.shape
    Wsensor, Hsensor = self.sensor_res

    # Image simulation
    if method == "psf_map":
        # PSF rendering - uses PSF map to render image
        assert Wimg == Wsensor and Himg == Hsensor, (
            f"Sensor resolution {Wsensor}x{Hsensor} must match input image {Wimg}x{Himg}."
        )
        psf_grid = kwargs.get("psf_grid", (10, 10))
        psf_ks = kwargs.get("psf_ks", PSF_KS)
        img_render = self.render_psf_map(
            img_obj, depth=depth, psf_grid=psf_grid, psf_ks=psf_ks
        )

    elif method == "psf_patch":
        # PSF patch rendering - uses a single PSF to render a patch of the image
        patch_center = kwargs.get("patch_center", (0.0, 0.0))
        psf_ks = kwargs.get("psf_ks", PSF_KS)
        img_render = self.render_psf_patch(
            img_obj, depth=depth, patch_center=patch_center, psf_ks=psf_ks
        )

    elif method == "ray_tracing":
        # Ray tracing rendering
        assert Wimg == Wsensor and Himg == Hsensor, (
            f"Sensor resolution {Wsensor}x{Hsensor} must match input image {Wimg}x{Himg}."
        )
        spp = kwargs.get("spp", SPP_RENDER)
        img_render = self.render_raytracing(img_obj, depth=depth, spp=spp)

    else:
        raise Exception(f"Image simulation method {method} is not supported.")

    return img_render

render_raytracing

render_raytracing(img, depth=DEPTH, spp=SPP_RENDER, vignetting=False)

Render RGB image using ray tracing rendering.

Parameters:

Name Type Description Default
img tensor

RGB image tensor. Shape of [N, 3, H, W].

required
depth float

Depth of the object. Defaults to DEPTH.

DEPTH
spp int

Samples per pixel. Defaults to SPP_RENDER.

SPP_RENDER
vignetting bool

Whether to model vignetting (rays blocked inside the lens darken the image). Defaults to False.

False

Returns:

Name Type Description
img_render tensor

Rendered RGB image tensor. Shape of [N, 3, H, W].

Source code in deeplens/optics/geolens.py
def render_raytracing(self, img, depth=DEPTH, spp=SPP_RENDER, vignetting=False):
    """Render RGB image using ray tracing rendering.

    Args:
        img (tensor): RGB image tensor. Shape of [N, 3, H, W].
        depth (float, optional): Depth of the object. Defaults to DEPTH.
        spp (int, optional): Samples per pixel. Defaults to SPP_RENDER.
        vignetting (bool, optional): Whether to model vignetting (rays blocked inside the lens darken the image). Defaults to False.

    Returns:
        img_render (tensor): Rendered RGB image tensor. Shape of [N, 3, H, W].
    """
    img_render = torch.zeros_like(img)
    for i in range(3):
        img_render[:, i, :, :] = self.render_raytracing_mono(
            img=img[:, i, :, :],
            wvln=WAVE_RGB[i],
            depth=depth,
            spp=spp,
            vignetting=vignetting,
        )
    return img_render

render_raytracing_mono

render_raytracing_mono(img, wvln, depth=DEPTH, spp=64, vignetting=False)

Render monochrome image using ray tracing rendering.

Parameters:

Name Type Description Default
img tensor

Monochrome image tensor. Shape of [N, 1, H, W] or [N, H, W].

required
wvln float

Wavelength of the light.

required
depth float

Depth of the object. Defaults to DEPTH.

DEPTH
spp int

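Samples per pixel. Defaults to 64.

64
vignetting bool

Whether to model vignetting (rays blocked inside the lens darken the image). Defaults to False.

False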

Returns:

Name Type Description
img_mono tensor

Rendered monochrome image tensor. Shape of [N, 1, H, W] or [N, H, W].

Source code in deeplens/optics/geolens.py
def render_raytracing_mono(self, img, wvln, depth=DEPTH, spp=64, vignetting=False):
    """Render monochrome image using ray tracing rendering.

    Args:
        img (tensor): Monochrome image tensor. Shape of [N, 1, H, W] or [N, H, W].
        wvln (float): Wavelength of the light.
        depth (float, optional): Depth of the object. Defaults to DEPTH.
        spp (int, optional): Samples per pixel. Defaults to 64.
        vignetting (bool, optional): Whether to model vignetting. Defaults to False.

    Returns:
        img_mono (tensor): Rendered monochrome image tensor. Shape of [N, 1, H, W] or [N, H, W].
    """
    img = torch.flip(img, [-2, -1])
    scale = self.calc_scale(depth=depth)
    ray = self.sample_sensor(spp=spp, wvln=wvln)
    ray = self.trace2obj(ray)
    img_mono = self.render_compute_image(
        img, depth=depth, scale=scale, ray=ray, vignetting=vignetting
    )
    return img_mono

render_compute_image

render_compute_image(img, depth, scale, ray, vignetting=False)

Computes the intersection points between the rays and the object image plane, then generates the rendered image following the rendering equation.

Back-propagation gradient flow: image -> w_i -> u -> p -> ray -> surface

Parameters:

Name Type Description Default
img tensor

[N, C, H, W] or [N, H, W] shape image tensor.

required
depth float

depth of the object.

required
scale float

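Scale factor mapping the sensor pixel size to the object-plane pixel size (object pixel size = scale * pixel_size).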

required
ray Ray object

Ray object. Shape [H, W, spp, 3].

required
vignetting bool

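Whether to model vignetting. If True, each pixel is normalized by the total number of rays, so blocked rays darken the image; otherwise only valid rays are averaged.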

False

Returns:

Name Type Description
image tensor

[N, C, H, W] or [N, H, W] shape rendered image tensor.

Source code in deeplens/optics/geolens.py
def render_compute_image(self, img, depth, scale, ray, vignetting=False):
    """Computes the intersection points between the rays and the object image plane, then generates the rendered image following the rendering equation.

    Back-propagation gradient flow: image -> w_i -> u -> p -> ray -> surface

    Args:
        img (tensor): [N, C, H, W] or [N, H, W] shape image tensor.
        depth (float): depth of the object.
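        scale (float): Scale factor mapping the sensor pixel size to the object-plane pixel size.
        ray (Ray object): Ray object. Shape [H, W, spp, 3].
        vignetting (bool): Whether to model vignetting (normalize by all rays instead of valid rays only).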

    Returns:
        image (tensor): [N, C, H, W] or [N, H, W] shape rendered image tensor.
    """
    assert torch.is_tensor(img), "Input image should be Tensor."

    # Padding
    H, W = img.shape[-2:]
    if len(img.shape) == 3:
        img = F.pad(img.unsqueeze(1), (1, 1, 1, 1), "replicate").squeeze(1)
    elif len(img.shape) == 4:
        img = F.pad(img, (1, 1, 1, 1), "replicate")
    else:
        raise ValueError("Input image should be [N, C, H, W] or [N, H, W] tensor.")

    # Scale object image physical size to get 1:1 pixel-pixel alignment with sensor image
    ray = ray.prop_to(depth)
    p = ray.o[..., :2]
    pixel_size = scale * self.pixel_size
    ray.is_valid = (
        ray.is_valid
        * (torch.abs(p[..., 0] / pixel_size) < (W / 2 + 1))
        * (torch.abs(p[..., 1] / pixel_size) < (H / 2 + 1))
    )

    # Convert to uv coordinates in object image coordinate
    # (the image is padded by one pixel, so coordinates are shifted by 1)
    u = torch.clamp(W / 2 + p[..., 0] / pixel_size, min=-0.99, max=W - 0.01)
    v = torch.clamp(H / 2 + p[..., 1] / pixel_size, min=0.01, max=H + 0.99)

    # (idx_i, idx_j) denotes left-top pixel (reference pixel). Index does not store gradients
    # (idx + 1 because we did padding)
    idx_i = H - v.ceil().long() + 1
    idx_j = u.floor().long() + 1

    # Gradients are stored in interpolation weight parameters
    w_i = v - v.floor().long()
    w_j = u.ceil().long() - u

    # Bilinear interpolation
    # (img shape [B, N, H', W'], idx_i shape [H, W, spp], w_i shape [H, W, spp], irr_img shape [N, C, H, W, spp])
    irr_img = img[..., idx_i, idx_j] * w_i * w_j
    irr_img += img[..., idx_i + 1, idx_j] * (1 - w_i) * w_j
    irr_img += img[..., idx_i, idx_j + 1] * w_i * (1 - w_j)
    irr_img += img[..., idx_i + 1, idx_j + 1] * (1 - w_i) * (1 - w_j)

    # Computation image
    if not vignetting:
        image = torch.sum(irr_img * ray.is_valid, -1) / (
            torch.sum(ray.is_valid, -1) + EPSILON
        )
    else:
        image = torch.sum(irr_img * ray.is_valid, -1) / torch.numel(ray.is_valid)

    return image
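For each sensor pixel, render_compute_image reduces to two steps. First, it bilinearly samples the (padded) object image at each ray's hit point. Second, it averages over rays. The vignetting flag changes the averaging denominator: the number of valid rays when False, the total number of rays when True, so blocked rays darken the pixel. The pure-Python sketch below shows both steps using the textbook bilinear formula rather than the exact flipped index convention of the source. Helper names are hypothetical, and `eps` stands in for the source's EPSILON, whose value is not shown here.

```python
import math


def bilinear(img, x, y):
    """Bilinearly sample a 2D list img at continuous column x and row y.

    Assumes 0 <= x < width - 1 and 0 <= y < height - 1; the source pads the
    image by one pixel so that border samples stay in range.
    """
    j0, i0 = math.floor(x), math.floor(y)
    wx, wy = x - j0, y - i0  # fractional offsets act as interpolation weights
    top = (1 - wx) * img[i0][j0] + wx * img[i0][j0 + 1]
    bottom = (1 - wx) * img[i0 + 1][j0] + wx * img[i0 + 1][j0 + 1]
    return (1 - wy) * top + wy * bottom


def pixel_value(samples, valid, vignetting=False, eps=1e-9):
    """Combine per-ray samples into one pixel, mirroring the two branches."""
    total = sum(s * v for s, v in zip(samples, valid))
    if vignetting:
        return total / len(samples)  # blocked rays reduce brightness
    return total / (sum(valid) + eps)  # normalize by valid rays only


img = [[0.0, 1.0], [2.0, 3.0]]
print(bilinear(img, 0.5, 0.5))  # 1.5
print(pixel_value([1.0] * 4, [1, 1, 0, 0], vignetting=True))  # 0.5
```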

unwarp

unwarp(img, depth=DEPTH, num_grid=128, crop=True, flip=True)

Unwarp rendered images using distortion map.

Parameters:

Name Type Description Default
img tensor

Rendered image tensor. Shape of [N, C, H, W].

required
depth float

Depth of the object. Defaults to DEPTH.

DEPTH
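num_grid int

Number of grid points per side of the distortion map. Defaults to 128.

128
crop bool

Whether to crop the image. Defaults to True. Not referenced in the implementation shown below.

True
flip bool

Whether to flip the distortion map. Defaults to True. Not referenced in the implementation shown below (the flip line is commented out).

True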

Returns:

Name Type Description
img_unwarped tensor

Unwarped image tensor. Shape of [N, C, H, W].

Source code in deeplens/optics/geolens.py
def unwarp(self, img, depth=DEPTH, num_grid=128, crop=True, flip=True):
    """Unwarp rendered images using distortion map.

    Args:
        img (tensor): Rendered image tensor. Shape of [N, C, H, W].
        depth (float, optional): Depth of the object. Defaults to DEPTH.
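        num_grid (int, optional): Number of grid points per side of the distortion map. Defaults to 128.
        crop (bool, optional): Whether to crop the image. Defaults to True. Currently unused.
        flip (bool, optional): Whether to flip the distortion map. Defaults to True. Currently unused.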

    Returns:
        img_unwarped (tensor): Unwarped image tensor. Shape of [N, C, H, W].
    """
    # Calculate distortion map, shape (num_grid, num_grid, 2)
    distortion_map = self.calc_distortion_map(depth=depth, num_grid=num_grid)

    # Interpolate distortion map to image resolution
    distortion_map = distortion_map.permute(2, 0, 1).unsqueeze(1)
    # distortion_map = torch.flip(distortion_map, [-2]) if flip else distortion_map
    distortion_map = F.interpolate(
        distortion_map, img.shape[-2:], mode="bilinear", align_corners=True
    )  # shape (2, 1, Himg, Wimg)
    distortion_map = distortion_map.permute(1, 2, 3, 0).repeat(
        img.shape[0], 1, 1, 1
    )  # shape (B, Himg, Wimg, 2)

    # Unwarp using grid_sample function
    img_unwarpped = F.grid_sample(
        img, distortion_map, align_corners=True
    )  # shape (B, C, Himg, Wimg)
    return img_unwarpped
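F.grid_sample expects sampling locations in normalized coordinates in [-1, 1]. With align_corners=True, -1 and 1 refer to the centers of the corner pixels. The distortion map is therefore presumably expressed in that normalized frame, although calc_distortion_map is not shown on this page. The sketch below shows the pixel-to-normalized mapping for align_corners=True. The helper names are hypothetical.

```python
def pixel_to_normalized(j, size):
    """Map pixel index j in [0, size - 1] to grid_sample's [-1, 1] range.

    Uses the align_corners=True convention: index 0 -> -1, index size - 1 -> 1.
    """
    return 2.0 * j / (size - 1) - 1.0


def normalized_to_pixel(x, size):
    """Inverse of pixel_to_normalized."""
    return (x + 1.0) * (size - 1) / 2.0


W = 5
print([pixel_to_normalized(j, W) for j in range(W)])  # [-1.0, -0.5, 0.0, 0.5, 1.0]
```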

find_diff_surf

find_diff_surf()

Get differentiable/optimizable surface indices.

Returns a list of surface indices that can be optimized during lens design. Excludes the aperture surface from optimization.

Returns:

Type Description

list or range: Surface indices excluding the aperture.

Source code in deeplens/optics/geolens.py
def find_diff_surf(self):
    """Get differentiable/optimizable surface indices.

    Returns a list of surface indices that can be optimized during lens design.
    Excludes the aperture surface from optimization.

    Returns:
        list or range: Surface indices excluding the aperture.
    """
    if self.aper_idx is None:
        diff_surf_range = range(len(self.surfaces))
    else:
        diff_surf_range = list(range(0, self.aper_idx)) + list(
            range(self.aper_idx + 1, len(self.surfaces))
        )
    return diff_surf_range
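The index arithmetic in find_diff_surf amounts to a single filter that drops the aperture index. The sketch below always returns a list, whereas the original returns a range when there is no aperture. The function name is hypothetical.

```python
def diff_surf_indices(num_surfaces, aper_idx=None):
    """Indices of optimizable surfaces: every surface except the aperture stop."""
    return [i for i in range(num_surfaces) if i != aper_idx]


print(diff_surf_indices(5, aper_idx=2))  # [0, 1, 3, 4]
print(diff_surf_indices(3))  # [0, 1, 2]
```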

calc_foclen

calc_foclen()

Compute effective focal length (EFL).

Traces on-axis collimated rays to locate the paraxial focus, then traces collimated rays at a small field angle (0.01 rad) to that plane and measures their mean image height h. The EFL is h / tan(0.01 rad).

Updates:

self.efl: Effective focal length.
self.foclen: Alias for effective focal length.
self.bfl: Back focal length (distance from last surface to sensor).

Reference:

[1] https://wp.optics.arizona.edu/optomech/wp-content/uploads/sites/53/2016/10/Tutorial_MorelSophie.pdf
[2] https://rafcamera.com/info/imaging-theory/back-focal-length

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def calc_foclen(self):
    """Compute effective focal length (EFL).

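    Traces on-axis collimated rays to locate the paraxial focus, then traces
    collimated rays at a small field angle (0.01 rad) to that plane and
    measures their mean image height h. The EFL is h / tan(0.01 rad).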

    Updates:
        self.efl: Effective focal length.
        self.foclen: Alias for effective focal length.
        self.bfl: Back focal length (distance from last surface to sensor).

    Reference:
        [1] https://wp.optics.arizona.edu/optomech/wp-content/uploads/sites/53/2016/10/Tutorial_MorelSophie.pdf
        [2] https://rafcamera.com/info/imaging-theory/back-focal-length
    """
    # Paraxial field angle: 0.01 rad (about 0.57 degrees)
    paraxial_fov = 0.01
    paraxial_fov_deg = float(np.rad2deg(paraxial_fov))

    # 1. Trace on-axis parallel rays to find paraxial focus z (equivalent to infinite focus)
    ray_axis = self.sample_from_fov(
        fov_x=0.0, fov_y=0.0, entrance_pupil=False, scale_pupil=0.2
    )
    ray_axis, _ = self.trace(ray_axis)
    valid_axis = ray_axis.is_valid > 0
    t = -(ray_axis.d[valid_axis, 0] * ray_axis.o[valid_axis, 0]
          + ray_axis.d[valid_axis, 1] * ray_axis.o[valid_axis, 1]) / (
        ray_axis.d[valid_axis, 0] ** 2 + ray_axis.d[valid_axis, 1] ** 2
    )
    focus_z = ray_axis.o[valid_axis, 2] + t * ray_axis.d[valid_axis, 2]
    focus_z = focus_z[~torch.isnan(focus_z) & (focus_z > 0)]
    paraxial_focus_z = float(torch.mean(focus_z))

    # 2. Trace off-axis paraxial ray to paraxial focus, measure image height
    ray = self.sample_from_fov(
        fov_x=0.0, fov_y=paraxial_fov_deg, entrance_pupil=False, scale_pupil=0.2
    )
    ray, _ = self.trace(ray)
    ray = ray.prop_to(paraxial_focus_z)

    # Compute the effective focal length
    paraxial_imgh = (ray.o[:, 1] * ray.is_valid).sum() / ray.is_valid.sum()
    eff_foclen = paraxial_imgh.item() / float(np.tan(paraxial_fov))
    self.efl = eff_foclen
    self.foclen = eff_foclen

    # Compute the back focal length
    self.bfl = self.d_sensor.item() - self.surfaces[-1].d.item()

    return eff_foclen
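calc_foclen combines two pieces of geometry: the axial crossing point of each on-axis ray, and EFL = h / tan θ at a small field angle. The pure-Python sketch below applies both formulas to single rays. The helper names are hypothetical, and the real method averages over many valid rays.

```python
import math


def axial_crossing_z(o, d):
    """z where a ray o + t*d comes closest to the optical axis (x = y = 0).

    Minimizes x(t)^2 + y(t)^2, giving t = -(dx*ox + dy*oy) / (dx^2 + dy^2),
    the same expression calc_foclen uses to find the paraxial focus.
    """
    t = -(d[0] * o[0] + d[1] * o[1]) / (d[0] ** 2 + d[1] ** 2)
    return o[2] + t * d[2]


def efl_from_image_height(img_height, field_angle_rad):
    """Effective focal length from image height at a small field angle: f = h / tan(theta)."""
    return img_height / math.tan(field_angle_rad)


# A ray leaving height 1 mm with slope -0.02 crosses the axis near z = 50 mm
print(axial_crossing_z((1.0, 0.0, 0.0), (-0.02, 0.0, 1.0)))
# An ideal 50 mm lens images a 0.01 rad field at h = 50 * tan(0.01)
print(efl_from_image_height(50.0 * math.tan(0.01), 0.01))
```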

calc_numerical_aperture

calc_numerical_aperture(n=1.0)

Compute the image-side numerical aperture (NA) from the F-number: NA = n * sin(arctan(1 / (2 * fnum))).

Parameters:

Name Type Description Default
n float

Refractive index. Defaults to 1.0.

1.0

Returns:

Name Type Description
NA float

Numerical aperture.

Reference

[1] https://en.wikipedia.org/wiki/Numerical_aperture

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def calc_numerical_aperture(self, n=1.0):
    """Compute the image-side numerical aperture (NA) from the F-number: NA = n * sin(arctan(1 / (2 * fnum))).

    Args:
        n (float, optional): Refractive index. Defaults to 1.0.

    Returns:
        NA (float): Numerical aperture.

    Reference:
        [1] https://en.wikipedia.org/wiki/Numerical_aperture
    """
    return n * math.sin(math.atan(1 / 2 / self.fnum))
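The formula in calc_numerical_aperture converts the F-number N to the half-angle of the light cone, θ = arctan(1/(2N)), and returns n·sin θ. The pure-Python sketch below applies the same formula and compares it with the common small-angle approximation NA ≈ 1/(2N). The function name is hypothetical.

```python
import math


def numerical_aperture(fnum, n=1.0):
    """NA from the F-number, as in calc_numerical_aperture: n * sin(atan(1 / (2N)))."""
    return n * math.sin(math.atan(1 / 2 / fnum))


for N in (1.4, 2.0, 8.0):
    # Compare with the small-angle approximation NA ~ 1 / (2N)
    print(N, round(numerical_aperture(N), 4), round(1 / (2 * N), 4))
```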

calc_focal_plane

calc_focal_plane(wvln=DEFAULT_WAVE)

Compute the in-focus object distance, i.e. the z-position of the focal plane in object space. Rays start from the sensor center and are traced back to object space.

Parameters:

Name Type Description Default
wvln float

Wavelength. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE

Returns:

Name Type Description
focal_plane float

z-coordinate of the focal plane in object space (negative, in front of the lens).

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def calc_focal_plane(self, wvln=DEFAULT_WAVE):
    """Compute the in-focus object distance, i.e. the z-position of the focal plane in object space. Rays start from the sensor center and are traced back to object space.

    Args:
        wvln (float, optional): Wavelength. Defaults to DEFAULT_WAVE.

    Returns:
        focal_plane (float): z-coordinate of the focal plane in object space (negative, in front of the lens).
    """
    device = self.device

    # Sample point source rays from sensor center
    o1 = torch.zeros(SPP_CALC, 3, device=device)
    o1[:, 2] = self.d_sensor

    # Sample the first surface as pupil
    # o2 = self.sample_circle(self.surfaces[0].r, z=0.0, shape=[SPP_CALC])
    # o2 *= 0.5  # Shrink sample region to improve accuracy
    pupilz, pupilr = self.get_exit_pupil()
    o2 = self.sample_circle(pupilr, pupilz, shape=[SPP_CALC])
    d = o2 - o1
    ray = Ray(o1, d, wvln, device=device)

    # Trace rays to object space
    ray = self.trace2obj(ray)

    # Optical axis intersection
    t = (ray.d[..., 0] * ray.o[..., 0] + ray.d[..., 1] * ray.o[..., 1]) / (
        ray.d[..., 0] ** 2 + ray.d[..., 1] ** 2
    )
    focus_z = (ray.o[..., 2] - ray.d[..., 2] * t)[ray.is_valid > 0].cpu().numpy()
    focus_z = focus_z[~np.isnan(focus_z) & (focus_z < 0)]

    if len(focus_z) > 0:
        focal_plane = float(np.mean(focus_z))
    else:
        raise ValueError(
            "No valid rays found, focal plane in the object space cannot be computed."
        )

    return focal_plane

calc_sensor_plane

calc_sensor_plane(depth=float('inf'))

Calculate in-focus sensor plane.

Parameters:

Name Type Description Default
depth float

Depth of the object plane. Defaults to float("inf").

float('inf')

Returns:

Name Type Description
d_sensor Tensor

z-coordinate of the in-focus sensor plane in image space for an on-axis object at the given depth.

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def calc_sensor_plane(self, depth=float("inf")):
    """Calculate in-focus sensor plane.

    Args:
        depth (float, optional): Depth of the object plane. Defaults to float("inf").

    Returns:
        d_sensor (torch.Tensor): z-coordinate of the in-focus sensor plane in image space for an on-axis object at the given depth.
    """
    # Sample and trace rays, shape [SPP_CALC, 3]
    ray = self.sample_from_fov(
        fov_x=0.0, fov_y=0.0, depth=depth, num_rays=SPP_CALC, wvln=DEFAULT_WAVE
    )
    ray = self.trace2sensor(ray)

    # Calculate in-focus sensor position
    t = (ray.d[:, 0] * ray.o[:, 0] + ray.d[:, 1] * ray.o[:, 1]) / (
        ray.d[:, 0] ** 2 + ray.d[:, 1] ** 2
    )
    focus_z = ray.o[:, 2] - ray.d[:, 2] * t
    focus_z = focus_z[ray.is_valid > 0]
    focus_z = focus_z[~torch.isnan(focus_z) & (focus_z > 0)]
    d_sensor = torch.mean(focus_z)
    return d_sensor

calc_fov

calc_fov()

Compute field of view (FoV) of the lens in radians.

Calculates FoV using two methods:

1. Perspective projection: from focal length and sensor size (effective FoV, ignoring distortion).
2. Ray tracing: traces rays from the sensor edge backwards to determine the real FoV including distortion effects.

Updates:

self.vfov (float): Vertical FoV in radians.
self.hfov (float): Horizontal FoV in radians.
self.dfov (float): Diagonal FoV in radians.
self.rfov (float): Half-diagonal (radius) FoV in radians.
self.real_rfov (float): Real half-diagonal FoV from ray tracing.
self.real_dfov (float): Real diagonal FoV from ray tracing.
self.eqfl (float): 35mm equivalent focal length in mm.

Reference

[1] https://en.wikipedia.org/wiki/Angle_of_view_(photography)
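The perspective-projection branch is plain trigonometry. The pure-Python sketch below recomputes it for a hypothetical full-frame setup. The function name and its explicit (foclen, sensor_h, sensor_w) arguments are illustrative; calc_fov itself takes sensor_size[0] as the vertical extent. The sketch treats r_sensor as the sensor half-diagonal, as the dfov formula implies. The constant 21.63 mm is the half-diagonal of a 36 mm x 24 mm frame, sqrt(18^2 + 12^2).

```python
import math


def perspective_fov(foclen, sensor_h, sensor_w):
    """Pinhole-model FoV (radians) and 35 mm equivalent focal length; lengths in mm."""
    r_sensor = math.hypot(sensor_h, sensor_w) / 2  # sensor half-diagonal
    vfov = 2 * math.atan(sensor_h / 2 / foclen)
    hfov = 2 * math.atan(sensor_w / 2 / foclen)
    dfov = 2 * math.atan(r_sensor / foclen)
    rfov = dfov / 2  # half-diagonal FoV
    # Half-diagonal of a 36 mm x 24 mm frame: sqrt(18**2 + 12**2) ~= 21.63 mm
    eqfl = 21.63 / math.tan(rfov)
    return {"vfov": vfov, "hfov": hfov, "dfov": dfov, "rfov": rfov, "eqfl": eqfl}


fov = perspective_fov(foclen=50.0, sensor_h=24.0, sensor_w=36.0)
print({k: round(v, 3) for k, v in fov.items()})
```

For a 50 mm lens on a full-frame sensor, the 35 mm equivalent focal length comes out at about 50 mm, as expected. A smaller sensor behind the same lens narrows rfov and so gives a longer equivalent focal length.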

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def calc_fov(self):
    """Compute field of view (FoV) of the lens in radians.

    Calculates FoV using two methods:
        1. **Perspective projection** — from focal length and sensor size
           (effective FoV, ignoring distortion).
        2. **Ray tracing** — traces rays from the sensor edge backwards to
           determine the real FoV including distortion effects.

    Updates:
        self.vfov (float): Vertical FoV in radians.
        self.hfov (float): Horizontal FoV in radians.
        self.dfov (float): Diagonal FoV in radians.
        self.rfov (float): Half-diagonal (radius) FoV in radians.
        self.real_rfov (float): Real half-diagonal FoV from ray tracing.
        self.real_dfov (float): Real diagonal FoV from ray tracing.
        self.eqfl (float): 35mm equivalent focal length in mm.

    Reference:
        [1] https://en.wikipedia.org/wiki/Angle_of_view_(photography)
    """
    if not hasattr(self, "foclen"):
        return

    # 1. Perspective projection (effective FoV)
    self.vfov = 2 * math.atan(self.sensor_size[0] / 2 / self.foclen)
    self.hfov = 2 * math.atan(self.sensor_size[1] / 2 / self.foclen)
    self.dfov = 2 * math.atan(self.r_sensor / self.foclen)
    self.rfov = self.dfov / 2  # radius (half diagonal) FoV

    # 2. Ray tracing to calculate real FoV (distortion-affected FoV)
    # Sample rays from edge of sensor, shape [SPP_CALC, 3]
    o1 = torch.tensor([self.r_sensor, 0, self.d_sensor.item()]).repeat(SPP_CALC, 1)

    # Sample second points on exit pupil
    pupilz, pupilx = self.get_exit_pupil()
    x2 = torch.linspace(-pupilx, pupilx, SPP_CALC)
    z2 = torch.full_like(x2, pupilz)
    y2 = torch.full_like(x2, 0)
    o2 = torch.stack((x2, y2, z2), axis=-1)

    # Ray tracing to object space
    ray = Ray(o1, o2 - o1, device=self.device)
    ray = self.trace2obj(ray)

    # Compute output ray angle
    tan_rfov = ray.d[..., 0] / ray.d[..., 2]
    rfov = torch.atan(torch.sum(tan_rfov * ray.is_valid) / torch.sum(ray.is_valid))

    # If calculation failed, use pinhole camera model to compute fov
    if torch.isnan(rfov):
        self.real_rfov = self.rfov
        self.real_dfov = self.dfov
        print(
            f"Failed to calculate distorted FoV by ray tracing, use effective FoV {self.rfov} rad."
        )
    else:
        self.real_rfov = rfov.item()
        self.real_dfov = 2 * rfov.item()

    # 3. Compute 35mm equivalent focal length. 35mm sensor: 36mm * 24mm, half-diagonal ~21.63mm
    self.eqfl = 21.63 / math.tan(self.rfov)

calc_scale

calc_scale(depth)

Calculate the scale factor (object height / image height).

Uses the pinhole camera model; the result is the reciprocal of the lateral magnification.

Parameters:

Name Type Description Default
depth float

Object depth, i.e. the z coordinate of the object plane. It is negative because objects lie in the -z half-space, so the returned scale is positive.

required

Returns:

Name Type Description
float

Scale factor relating object height to image height.

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def calc_scale(self, depth):
    """Calculate the scale factor (object height / image height).

    Uses the pinhole camera model to compute magnification.

    Args:
        depth (float): Object distance from the lens (negative z direction).

    Returns:
        float: Scale factor relating object height to image height.
    """
    return -depth / self.foclen
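psf_geometric (further down this page) uses this scale factor to map normalized point coordinates in [-1, 1] onto object-space millimetres. The pure-Python sketch below reproduces that mapping. Both functions are hypothetical standalone mirrors of the method code, not DeepLens API. sensor_w and sensor_h are passed explicitly rather than unpacked from sensor_size.

```python
def calc_scale(depth, foclen):
    """Pinhole scale factor (object height / image height); depth < 0 in object space."""
    return -depth / foclen


def normalized_to_object(x_norm, y_norm, depth, foclen, sensor_w, sensor_h):
    """Map normalized sensor coordinates to an object-space point at z = depth."""
    scale = calc_scale(depth, foclen)
    x_obj = x_norm * scale * sensor_w / 2
    y_obj = y_norm * scale * sensor_h / 2
    return x_obj, y_obj, depth


# 50 mm lens, 36 x 24 mm sensor, object plane 1 m in front of the lens
print(normalized_to_object(1.0, -0.5, -1000.0, 50.0, 36.0, 24.0))  # (360.0, -120.0, -1000.0)
```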

calc_pupil

calc_pupil()

Compute entrance and exit pupil positions and radii.

The entrance and exit pupils must be recalculated whenever
  • First-order parameters change (e.g., field of view, object height, image height),
  • Lens geometry or materials change (e.g., surface curvatures, refractive indices, thicknesses),
  • Or generally, any time the lens configuration is modified.
Updates

  • self.aper_idx: Index of the aperture surface.
  • self.exit_pupilz, self.exit_pupilr: Exit pupil position and radius.
  • self.entr_pupilz, self.entr_pupilr: Entrance pupil position and radius.
  • self.exit_pupilz_parax, self.exit_pupilr_parax: Paraxial exit pupil.
  • self.entr_pupilz_parax, self.entr_pupilr_parax: Paraxial entrance pupil.
  • self.fnum: F-number calculated from focal length and entrance pupil.

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def calc_pupil(self):
    """Compute entrance and exit pupil positions and radii.

    The entrance and exit pupils must be recalculated whenever:
        - First-order parameters change (e.g., field of view, object height, image height),
        - Lens geometry or materials change (e.g., surface curvatures, refractive indices, thicknesses),
        - Or generally, any time the lens configuration is modified.

    Updates:
        self.aper_idx: Index of the aperture surface.
        self.exit_pupilz, self.exit_pupilr: Exit pupil position and radius.
        self.entr_pupilz, self.entr_pupilr: Entrance pupil position and radius.
        self.exit_pupilz_parax, self.exit_pupilr_parax: Paraxial exit pupil.
        self.entr_pupilz_parax, self.entr_pupilr_parax: Paraxial entrance pupil.
        self.fnum: F-number calculated from focal length and entrance pupil.
    """
    # Find aperture
    self.aper_idx = None
    for i in range(len(self.surfaces)):
        if isinstance(self.surfaces[i], Aperture):
            self.aper_idx = i
            break

    if self.aper_idx is None:
        self.aper_idx = np.argmin([s.r for s in self.surfaces])
        print("No aperture found, use the smallest surface as aperture.")

    # Compute entrance and exit pupil
    self.exit_pupilz, self.exit_pupilr = self.calc_exit_pupil(paraxial=False)
    self.entr_pupilz, self.entr_pupilr = self.calc_entrance_pupil(paraxial=False)
    self.exit_pupilz_parax, self.exit_pupilr_parax = self.calc_exit_pupil(
        paraxial=True
    )
    self.entr_pupilz_parax, self.entr_pupilr_parax = self.calc_entrance_pupil(
        paraxial=True
    )

    # Compute F-number
    self.fnum = self.foclen / (2 * self.entr_pupilr)

get_entrance_pupil

get_entrance_pupil(paraxial=False)

Get entrance pupil location and radius.

Parameters:

Name Type Description Default
paraxial bool

If True, return paraxial approximation values. If False, return real ray-traced values. Defaults to False.

False

Returns:

Name Type Description
tuple

(z_position, radius) of the entrance pupil in [mm].

Source code in deeplens/optics/geolens.py
def get_entrance_pupil(self, paraxial=False):
    """Get entrance pupil location and radius.

    Args:
        paraxial (bool, optional): If True, return paraxial approximation values.
            If False, return real ray-traced values. Defaults to False.

    Returns:
        tuple: (z_position, radius) of the entrance pupil in [mm].
    """
    if paraxial:
        return self.entr_pupilz_parax, self.entr_pupilr_parax
    else:
        return self.entr_pupilz, self.entr_pupilr

get_exit_pupil

get_exit_pupil(paraxial=False)

Get exit pupil location and radius.

Parameters:

Name Type Description Default
paraxial bool

If True, return paraxial approximation values. If False, return real ray-traced values. Defaults to False.

False

Returns:

Name Type Description
tuple

(z_position, radius) of the exit pupil in [mm].

Source code in deeplens/optics/geolens.py
def get_exit_pupil(self, paraxial=False):
    """Get exit pupil location and radius.

    Args:
        paraxial (bool, optional): If True, return paraxial approximation values.
            If False, return real ray-traced values. Defaults to False.

    Returns:
        tuple: (z_position, radius) of the exit pupil in [mm].
    """
    if paraxial:
        return self.exit_pupilz_parax, self.exit_pupilr_parax
    else:
        return self.exit_pupilz, self.exit_pupilr

calc_exit_pupil

calc_exit_pupil(paraxial=False)

Calculate exit pupil location and radius.

Paraxial mode

Rays are emitted from near the center of the aperture stop and are close to the optical axis. This mode estimates the exit pupil position and radius under ideal (first-order) optical assumptions. It is fast and stable.

Non-paraxial mode

Rays are emitted from the edge of the aperture stop in large quantities. The exit pupil position and radius are determined based on the intersection points of these rays. This mode is slower and affected by aperture-related aberrations.

Use paraxial mode unless precise ray aiming is required.

Parameters:

Name Type Description Default
paraxial bool

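Ray sampling mode. If True, rays start near the centre of the aperture stop; if False, they start at the stop edge. Defaults to False.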

False

Returns:

Name Type Description
avg_pupilz float

z coordinate of exit pupil.

avg_pupilr float

radius of exit pupil.

Reference

[1] Exit pupil: how many rays can come from sensor to object space.
[2] https://en.wikipedia.org/wiki/Exit_pupil

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def calc_exit_pupil(self, paraxial=False):
    """Calculate exit pupil location and radius.

    Paraxial mode:
        Rays are emitted from near the center of the aperture stop and are close to the optical axis.
        This mode estimates the exit pupil position and radius under ideal (first-order) optical assumptions.
        It is fast and stable.

    Non-paraxial mode:
        Rays are emitted from the edge of the aperture stop in large quantities.
        The exit pupil position and radius are determined based on the intersection points of these rays.
        This mode is slower and affected by aperture-related aberrations.

    Use paraxial mode unless precise ray aiming is required.

    Args:
        paraxial (bool): center (True) or edge (False).

    Returns:
        avg_pupilz (float): z coordinate of exit pupil.
        avg_pupilr (float): radius of exit pupil.

    Reference:
        [1] Exit pupil: how many rays can come from sensor to object space.
        [2] https://en.wikipedia.org/wiki/Exit_pupil
    """
    if not hasattr(self, "aper_idx") or self.aper_idx is None:
        print("No aperture, use the last surface as exit pupil.")
        return self.surfaces[-1].d.item(), self.surfaces[-1].r

    # Sample rays from aperture (edge or center)
    aper_idx = self.aper_idx
    aper_z = self.surfaces[aper_idx].d.item()
    aper_r = self.surfaces[aper_idx].r

    if paraxial:
        ray_o = torch.tensor([[DELTA_PARAXIAL, 0, aper_z]], device=self.device).repeat(32, 1)
        phi_rad = torch.linspace(-0.01, 0.01, 32, device=self.device)
    else:
        ray_o = torch.tensor([[aper_r, 0, aper_z]], device=self.device).repeat(SPP_CALC, 1)
        rfov = float(np.arctan(self.r_sensor / self.foclen))
        phi_rad = torch.linspace(-rfov / 2, rfov / 2, SPP_CALC, device=self.device)

    d = torch.stack(
        (torch.sin(phi_rad), torch.zeros_like(phi_rad), torch.cos(phi_rad)), axis=-1
    )
    ray = Ray(ray_o, d, device=self.device)

    # Ray tracing from aperture edge to last surface
    surf_range = range(self.aper_idx + 1, len(self.surfaces))
    ray, _ = self.trace(ray, surf_range=surf_range)

    # Compute intersection points, solving the equation: o1+d1*t1 = o2+d2*t2
    ray_o = torch.stack(
        [ray.o[ray.is_valid != 0][:, 0], ray.o[ray.is_valid != 0][:, 2]], dim=-1
    )
    ray_d = torch.stack(
        [ray.d[ray.is_valid != 0][:, 0], ray.d[ray.is_valid != 0][:, 2]], dim=-1
    )
    intersection_points = self.compute_intersection_points_2d(ray_o, ray_d)

    # Handle the case where no intersection points are found or small pupil
    if len(intersection_points) == 0:
        print("No intersection points found, use the last surface as exit pupil.")
        avg_pupilr = self.surfaces[-1].r
        avg_pupilz = self.surfaces[-1].d.item()
    else:
        avg_pupilr = torch.mean(intersection_points[:, 0]).item()
        avg_pupilz = torch.mean(intersection_points[:, 1]).item()

        if paraxial:
            avg_pupilr = abs(avg_pupilr / DELTA_PARAXIAL * aper_r)

        if avg_pupilr < EPSILON:
            print(
                "Zero or negative exit pupil is detected, use the last surface as pupil."
            )
            avg_pupilr = self.surfaces[-1].r
            avg_pupilz = self.surfaces[-1].d.item()

    return avg_pupilz, avg_pupilr
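To see what calc_exit_pupil computes, consider a hypothetical system with the aperture stop at z = 0, followed by a single thin lens of focal length f at z = d. The pure-Python sketch below works through that case:

- It traces two paraxial rays from a point at small height h on the stop through the lens.
- It intersects the outgoing rays as full lines.
- It rescales the intersection height by aper_r / h, mirroring the paraxial branch of the method.

Two assumptions apply. The thin-lens refraction rule u' = u - y/f stands in for DeepLens's surface-by-surface ray tracing. The default h = 1e-3 is a hypothetical stand-in for DELTA_PARAXIAL, whose value is not shown on this page. For d < f, the closed-form result is z = d - d f / (f - d) with magnification f / (f - d); it is used only as a check.

```python
def trace_thin_lens(h, u, d, f):
    """Paraxial ray from height h on the stop (z = 0) with slope u through a thin lens at z = d.

    Returns the ray as (z, height, slope) just after the lens.
    """
    y = h + u * d      # height at the lens
    u_out = u - y / f  # thin-lens refraction
    return d, y, u_out


def intersect_lines(p1, p2):
    """Intersection (z, height) of two lines given as (z0, y0, slope)."""
    z1, y1, s1 = p1
    z2, y2, s2 = p2
    # Solve y1 + s1 (z - z1) = y2 + s2 (z - z2) for z
    z = (y2 - y1 + s1 * z1 - s2 * z2) / (s1 - s2)
    return z, y1 + s1 * (z - z1)


def paraxial_exit_pupil(aper_r, d, f, h=1e-3):
    """Exit pupil (z, radius) of the hypothetical stop + thin-lens system."""
    ray_a = trace_thin_lens(h, 0.0, d, f)
    ray_b = trace_thin_lens(h, -0.01, d, f)
    pupil_z, pupil_h = intersect_lines(ray_a, ray_b)
    # Scale the image of the small test height h up to the full stop radius
    return pupil_z, abs(pupil_h / h * aper_r)


print(paraxial_exit_pupil(aper_r=2.0, d=10.0, f=50.0))  # approximately (-2.5, 2.5)
```

Here the stop sits inside the lens's focal length, so the exit pupil is a magnified virtual image located in front of the stop (z < 0). That is why the outgoing rays must be intersected as full lines, i.e. via their backward extensions.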

calc_entrance_pupil

calc_entrance_pupil(paraxial=False)

Calculate entrance pupil of the lens.

The entrance pupil is the optical image of the physical aperture stop, as seen through the optical elements in front of the stop. We sample backward rays from the aperture stop and trace them to the first surface, then find the intersection points of the reverse extension of the rays. The average of the intersection points defines the entrance pupil position and radius.

Parameters:

Name Type Description Default
paraxial bool

Ray sampling mode. If True, rays are emitted near the centre of the aperture stop (fast, paraxially stable). If False, rays are emitted from the stop edge in larger quantities (slower, accounts for aperture aberrations). Defaults to False.

False

Returns:

Name Type Description
tuple

(z_position, radius) of entrance pupil.

Note

[1] Use paraxial mode unless precise ray aiming is required.
[2] This function only works for objects at a far distance. For microscopes, it usually returns a negative entrance pupil.

References

[1] Entrance pupil: how many rays can come from object space to sensor.
[2] https://en.wikipedia.org/wiki/Entrance_pupil: "In an optical system, the entrance pupil is the optical image of the physical aperture stop, as 'seen' through the optical elements in front of the stop."
[3] Zemax LLC, OpticStudio User Manual, Version 19.4, Document No. 2311, 2019.

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def calc_entrance_pupil(self, paraxial=False):
    """Calculate entrance pupil of the lens.

    The entrance pupil is the optical image of the physical aperture stop, as seen through the optical elements in front of the stop. We sample backward rays from the aperture stop and trace them to the first surface, then find the intersection points of the reverse extension of the rays. The average of the intersection points defines the entrance pupil position and radius.

    Args:
        paraxial (bool): Ray sampling mode.  If ``True``, rays are emitted
            near the centre of the aperture stop (fast, paraxially stable).
            If ``False``, rays are emitted from the stop edge in larger
            quantities (slower, accounts for aperture aberrations).
            Defaults to ``False``.

    Returns:
        tuple: (z_position, radius) of entrance pupil.

    Note:
        [1] Use paraxial mode unless precise ray aiming is required.
        [2] This function only works for object at a far distance. For microscopes, this function usually returns a negative entrance pupil.

    References:
        [1] Entrance pupil: how many rays can come from object space to sensor.
        [2] https://en.wikipedia.org/wiki/Entrance_pupil: "In an optical system, the entrance pupil is the optical image of the physical aperture stop, as 'seen' through the optical elements in front of the stop."
        [3] Zemax LLC, *OpticStudio User Manual*, Version 19.4, Document No. 2311, 2019.
    """
    if not hasattr(self, "aper_idx") or self.aper_idx is None:
        print("No aperture stop, use the first surface as entrance pupil.")
        return self.surfaces[0].d.item(), self.surfaces[0].r

    # Sample rays from edge of aperture stop
    aper_idx = self.aper_idx
    aper_surf = self.surfaces[aper_idx]
    aper_z = aper_surf.d.item()
    if aper_surf.is_square:
        aper_r = float(np.sqrt(2)) * aper_surf.r
    else:
        aper_r = aper_surf.r

    if paraxial:
        ray_o = torch.tensor([[DELTA_PARAXIAL, 0, aper_z]]).repeat(32, 1)
        phi = torch.linspace(-0.01, 0.01, 32)
    else:
        ray_o = torch.tensor([[aper_r, 0, aper_z]]).repeat(SPP_CALC, 1)
        rfov = float(np.arctan(self.r_sensor / self.foclen))
        phi = torch.linspace(-rfov / 2, rfov / 2, SPP_CALC)

    d = torch.stack(
        (torch.sin(phi), torch.zeros_like(phi), -torch.cos(phi)), axis=-1
    )
    ray = Ray(ray_o, d, device=self.device)

    # Ray tracing from aperture edge to first surface
    surf_range = range(0, self.aper_idx)
    ray, _ = self.trace(ray, surf_range=surf_range)

    # Compute intersection points, solving the equation: o1+d1*t1 = o2+d2*t2
    ray_o = torch.stack(
        [ray.o[ray.is_valid > 0][:, 0], ray.o[ray.is_valid > 0][:, 2]], dim=-1
    )
    ray_d = torch.stack(
        [ray.d[ray.is_valid > 0][:, 0], ray.d[ray.is_valid > 0][:, 2]], dim=-1
    )
    intersection_points = self.compute_intersection_points_2d(ray_o, ray_d)

    # Handle the case where no intersection points are found or small entrance pupil
    if len(intersection_points) == 0:
        print(
            "No intersection points found, use the first surface as entrance pupil."
        )
        avg_pupilr = self.surfaces[0].r
        avg_pupilz = self.surfaces[0].d.item()
    else:
        avg_pupilr = torch.mean(intersection_points[:, 0]).item()
        avg_pupilz = torch.mean(intersection_points[:, 1]).item()

        if paraxial:
            avg_pupilr = abs(avg_pupilr / DELTA_PARAXIAL * aper_r)

        if avg_pupilr < EPSILON:
            print(
                "Zero or negative entrance pupil is detected, use the first surface as entrance pupil."
            )
            avg_pupilr = self.surfaces[0].r
            avg_pupilz = self.surfaces[0].d.item()

    return avg_pupilz, avg_pupilr

compute_intersection_points_2d staticmethod

compute_intersection_points_2d(origins, directions)

Compute the pairwise intersection points of N 2D lines.

Parameters:

Name Type Description Default
origins Tensor

Origins of the lines. Shape: [N, 2]

required
directions Tensor

Directions of the lines. Shape: [N, 2]

required

Returns:

Type Description

torch.Tensor: Intersection points. Shape: [N*(N-1)/2, 2]

Source code in deeplens/optics/geolens.py
@staticmethod
def compute_intersection_points_2d(origins, directions):
    """Compute the intersection points of 2D lines.

    Args:
        origins (torch.Tensor): Origins of the lines. Shape: [N, 2]
        directions (torch.Tensor): Directions of the lines. Shape: [N, 2]

    Returns:
        torch.Tensor: Intersection points. Shape: [N*(N-1)/2, 2]
    """
    N = origins.shape[0]

    # Create pairwise combinations of indices
    idx = torch.arange(N)
    idx_i, idx_j = torch.combinations(idx, r=2).unbind(1)

    Oi = origins[idx_i]  # Shape: [N*(N-1)/2, 2]
    Oj = origins[idx_j]  # Shape: [N*(N-1)/2, 2]
    Di = directions[idx_i]  # Shape: [N*(N-1)/2, 2]
    Dj = directions[idx_j]  # Shape: [N*(N-1)/2, 2]

    # Vector from Oi to Oj
    b = Oj - Oi  # Shape: [N*(N-1)/2, 2]

    # Coefficients matrix A
    A = torch.stack([Di, -Dj], dim=-1)  # Shape: [N*(N-1)/2, 2, 2]

    # Solve the linear system Ax = b
    # Using least squares to handle the case of no exact solution
    if A.device.type == "mps":
        # Perform lstsq on CPU for MPS devices and move result back
        x, _ = torch.linalg.lstsq(A.cpu(), b.unsqueeze(-1).cpu())[:2]
        x = x.to(A.device)
    else:
        x, _ = torch.linalg.lstsq(A, b.unsqueeze(-1))[:2]
    x = x.squeeze(-1)  # Shape: [N*(N-1)/2, 2]
    s = x[:, 0]
    t = x[:, 1]

    # Calculate the intersection points using either ray
    P_i = Oi + s.unsqueeze(-1) * Di  # Shape: [N*(N-1)/2, 2]
    P_j = Oj + t.unsqueeze(-1) * Dj  # Shape: [N*(N-1)/2, 2]

    # Take the average to mitigate numerical precision issues
    P = (P_i + P_j) / 2

    return P
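The same pairwise solve can be written in NumPy, which makes it easy to try on a toy configuration. This sketch illustrates the method; it is not the DeepLens code. It enumerates pairs with np.triu_indices, which gives the same i < j ordering as torch.combinations. It uses a batched pseudo-inverse in place of torch.linalg.lstsq, which gives a minimum-norm least-squares solution instead of failing when two lines are parallel.

```python
import numpy as np


def intersection_points_2d(origins, directions):
    """Pairwise intersections of N 2D lines o + s * d. Returns shape [N*(N-1)/2, 2]."""
    idx_i, idx_j = np.triu_indices(origins.shape[0], k=1)
    Oi, Oj = origins[idx_i], origins[idx_j]
    Di, Dj = directions[idx_i], directions[idx_j]

    # Solve Oi + s * Di = Oj + t * Dj, i.e. [Di, -Dj] @ [s, t] = Oj - Oi, for every pair
    A = np.stack([Di, -Dj], axis=-1)      # [M, 2, 2], columns Di and -Dj
    b = (Oj - Oi)[..., None]              # [M, 2, 1]
    st = (np.linalg.pinv(A) @ b)[..., 0]  # [M, 2], least-squares (s, t)

    P_i = Oi + st[:, :1] * Di
    P_j = Oj + st[:, 1:] * Dj
    return (P_i + P_j) / 2  # average the two estimates, as the original does


# Three lines through the point (1, 1)
origins = np.array([[0.0, 0.0], [2.0, 0.0], [1.0, 0.0]])
directions = np.array([[1.0, 1.0], [-1.0, 1.0], [0.0, 1.0]])
print(intersection_points_2d(origins, directions))  # three rows, each approximately [1, 1]
```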

refocus

refocus(foc_dist=float('inf'))

Refocus the lens to a given object distance by moving the sensor plane.

Parameters:

Name Type Description Default
foc_dist float

Object distance to focus on. Defaults to float("inf").

float('inf')
Note

DSLRs commonly use phase-detection autofocus (PDAF), which is fast and efficient. Here the problem is simplified: the sensor is moved to the in-focus position of green light, computed from on-axis rays.

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def refocus(self, foc_dist=float("inf")):
    """Refocus the lens to a depth distance by changing sensor position.

    Args:
        foc_dist (float): focal distance.

    Note:
        In DSLR, phase detection autofocus (PDAF) is a popular and efficient method. But here we simplify the problem by calculating the in-focus position of green light.
    """
    # Calculate in-focus sensor position
    d_sensor_new = self.calc_sensor_plane(depth=foc_dist)

    # Update sensor position
    assert d_sensor_new > 0, "Obtained negative sensor position."
    self.d_sensor = d_sensor_new

    # FoV will be slightly changed
    self.post_computation()
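refocus relies on ray tracing through calc_sensor_plane. A thin-lens estimate is still a handy sanity check on how far the sensor should move. The sketch below uses that approximation only (the Gaussian lens formula 1/f = 1/u + 1/v), not the method's ray-traced computation. The helper name is hypothetical, and obj_dist is taken as a positive distance in mm.

```python
import math


def thin_lens_image_distance(foclen, obj_dist):
    """Image distance behind a thin lens (Gaussian formula 1/f = 1/u + 1/v); obj_dist > 0."""
    if math.isinf(obj_dist):
        return foclen
    if obj_dist <= foclen:
        raise ValueError("Object at or inside the focal length forms no real image.")
    return foclen * obj_dist / (obj_dist - foclen)


v_inf = thin_lens_image_distance(50.0, float("inf"))
v_near = thin_lens_image_distance(50.0, 1050.0)
print(v_inf, v_near, v_near - v_inf)  # 50.0 52.5 2.5
```

By this estimate, refocusing a 50 mm thin lens from infinity to 1.05 m moves the sensor back by 2.5 mm. A real multi-element lens will deviate from this, which is why refocus traces rays instead.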

set_fnum

set_fnum(fnum)

Set the F-number by binary-searching the aperture stop radius until the entrance pupil radius matches foclen / (2 * fnum).

Parameters:

Name Type Description Default
fnum float

target F-number.

required
Source code in deeplens/optics/geolens.py
@torch.no_grad()
def set_fnum(self, fnum):
    """Set F-number and aperture radius using binary search.

    Args:
        fnum (float): target F-number.
    """
    current_fnum = self.fnum
    current_aper_r = self.surfaces[self.aper_idx].r
    target_pupil_r = self.foclen / fnum / 2

    # Binary search to find aperture radius that gives desired entrance pupil radius
    aper_r = current_aper_r * (current_fnum / fnum)
    aper_r_min = 0.5 * aper_r
    aper_r_max = 2.0 * aper_r

    for _ in range(16):
        self.surfaces[self.aper_idx].r = aper_r
        _, pupilr = self.calc_entrance_pupil()

        if abs(pupilr - target_pupil_r) < 0.1:  # Close enough
            break

        if pupilr > target_pupil_r:
            # Current radius is too large, decrease it
            aper_r_max = aper_r
            aper_r = (aper_r_min + aper_r) / 2
        else:
            # Current radius is too small, increase it
            aper_r_min = aper_r
            aper_r = (aper_r_max + aper_r) / 2

    self.surfaces[self.aper_idx].r = aper_r

    # Update pupil after setting aperture radius
    self.calc_pupil()
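The search in set_fnum is a bracketed bisection on the stop radius, starting from a linear guess. The standalone sketch below isolates that loop so it can be tested without a lens. pupil_radius_fn is a hypothetical stand-in for calc_entrance_pupil and is assumed to increase monotonically with the stop radius. The initial bracket [0.5, 2.0] x r_guess and the default 16 iterations mirror the method. If the required radius falls outside that bracket, the search cannot reach it. If the tolerance (0.1 mm in set_fnum) is large relative to the pupil, the result may also be noticeably off.

```python
def bisect_stop_radius(pupil_radius_fn, target_r, r_guess, iters=16, tol=0.1):
    """Find a stop radius whose pupil radius is within tol of target_r by bisection."""
    r_lo, r_hi = 0.5 * r_guess, 2.0 * r_guess
    r = r_guess
    for _ in range(iters):
        pupil_r = pupil_radius_fn(r)
        if abs(pupil_r - target_r) < tol:
            break
        if pupil_r > target_r:
            r_hi = r  # radius too large
        else:
            r_lo = r  # radius too small
        r = (r_lo + r_hi) / 2
    return r


# Hypothetical front group with pupil magnification 1.25; target pupil radius 5 mm
r = bisect_stop_radius(lambda r: 1.25 * r, target_r=5.0, r_guess=3.5, iters=60, tol=1e-9)
print(r)  # close to 4.0
```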

set_target_fov_fnum

set_target_fov_fnum(rfov, fnum)

Set the FoV and F-number design targets; the focal length is derived from the sensor half-diagonal (ImgH). Use this function only to assign design targets.

Parameters:

Name Type Description Default
rfov float

Half-diagonal FoV in radians. Values greater than π are treated as degrees and converted to radians.

required
fnum float

Target F-number.

required
Source code in deeplens/optics/geolens.py
@torch.no_grad()
def set_target_fov_fnum(self, rfov, fnum):
    """Set FoV, ImgH and F number, only use this function to assign design targets.

    Args:
        rfov (float): half diagonal-FoV in radian.
        fnum (float): F number.
    """
    if rfov > math.pi:
        self.rfov = rfov / 180.0 * math.pi
    else:
        self.rfov = rfov

    self.foclen = self.r_sensor / math.tan(self.rfov)
    self.eqfl = 21.63 / math.tan(self.rfov)
    self.fnum = fnum
    aper_r = self.foclen / fnum / 2
    self.surfaces[self.aper_idx].update_r(float(aper_r))

    # Update pupil after setting aperture radius
    self.calc_pupil()
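The arithmetic behind set_target_fov_fnum is short enough to check by hand. The sketch below reproduces it with a hypothetical function name, with lengths in mm. It includes the heuristic that any rfov larger than π is treated as degrees. As a consequence, a half-FoV given in degrees but at or below π (about 3.14°) would be silently read as radians.

```python
import math


def fov_fnum_targets(rfov, fnum, r_sensor):
    """Return (rfov [rad], foclen, eqfl, aper_r) for the given design targets."""
    if rfov > math.pi:  # same degree-detection heuristic as set_target_fov_fnum
        rfov = rfov / 180.0 * math.pi
    foclen = r_sensor / math.tan(rfov)  # sensor half-diagonal fixes the focal length
    eqfl = 21.63 / math.tan(rfov)       # 35 mm equivalent focal length
    aper_r = foclen / fnum / 2          # stop radius set to the target pupil radius
    return rfov, foclen, eqfl, aper_r


print(fov_fnum_targets(rfov=40.0, fnum=2.0, r_sensor=4.0))
```

Setting the stop radius directly to foclen / (2 * fnum) treats the stop as if it were the entrance pupil. The subsequent calc_pupil call recomputes self.fnum from the actual entrance pupil, so it can differ from the requested value when the optics in front of the stop magnify it. set_fnum refines this iteratively.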

set_fov

set_fov(rfov)

Set half-diagonal field of view as a design target.

Unlike calc_fov() which derives FoV from focal length and sensor size, this method directly assigns the target FoV for lens optimisation.

Parameters:

Name Type Description Default
rfov float

Half-diagonal FoV in radians.

required
Source code in deeplens/optics/geolens.py
@torch.no_grad()
def set_fov(self, rfov):
    """Set half-diagonal field of view as a design target.

    Unlike ``calc_fov()`` which derives FoV from focal length and sensor
    size, this method directly assigns the target FoV for lens optimisation.

    Args:
        rfov (float): Half-diagonal FoV in radians.
    """
    self.rfov = rfov
    self.eqfl = 21.63 / math.tan(self.rfov)

deeplens.optics.geolens_pkg.psf_compute.GeoLensPSF

Mixin providing PSF computation for GeoLens.

All three PSF models are exposed through a single psf dispatcher. The geometric and coherent models are differentiable; the Huygens model is not.

This class is not instantiated directly; it is mixed into GeoLens (deeplens.optics.geolens.GeoLens).

psf

psf(points, ks=PSF_KS, wvln=DEFAULT_WAVE, spp=None, recenter=True, model='geometric')

Calculate Point Spread Function (PSF) for given point sources.

Supports three PSF calculation models:
  • geometric: Incoherent intensity ray tracing (fast, differentiable)
  • coherent: Coherent ray tracing with free-space propagation (accurate, differentiable)
  • huygens: Huygens-Fresnel integration (accurate, not differentiable)

Parameters:

Name Type Description Default
points Tensor

Point source positions. Shape [N, 3] with x, y in [-1, 1] and z in [-Inf, 0]. Normalized coordinates.

required
ks int

Output kernel size in pixels. Defaults to PSF_KS.

PSF_KS
wvln float

Wavelength in [um]. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
spp int

Samples per pixel. If None, uses model-specific default.

None
recenter bool

If True, center PSF using chief ray. Defaults to True.

True
model str

PSF model type. One of 'geometric', 'coherent', 'huygens'. Defaults to 'geometric'.

'geometric'

Returns:

Name Type Description
Tensor

PSF normalized to sum to 1. Shape [ks, ks] or [N, ks, ks].

Source code in deeplens/optics/geolens_pkg/psf_compute.py
def psf(
    self,
    points,
    ks=PSF_KS,
    wvln=DEFAULT_WAVE,
    spp=None,
    recenter=True,
    model="geometric",
):
    """Calculate Point Spread Function (PSF) for given point sources.

    Supports multiple PSF calculation models:
        - geometric: Incoherent intensity ray tracing (fast, differentiable)
        - coherent: Coherent ray tracing with free-space propagation (accurate, differentiable)
        - huygens: Huygens-Fresnel integration (accurate, not differentiable)

    Args:
        points (Tensor): Point source positions. Shape [N, 3] with x, y in [-1, 1]
            and z in [-Inf, 0]. Normalized coordinates.
        ks (int, optional): Output kernel size in pixels. Defaults to PSF_KS.
        wvln (float, optional): Wavelength in [um]. Defaults to DEFAULT_WAVE.
        spp (int, optional): Samples per pixel. If None, uses model-specific default.
        recenter (bool, optional): If True, center PSF using chief ray. Defaults to True.
        model (str, optional): PSF model type. One of 'geometric', 'coherent', 'huygens'.
            Defaults to 'geometric'.

    Returns:
        Tensor: PSF normalized to sum to 1. Shape [ks, ks] or [N, ks, ks].
    """
    if model == "geometric":
        spp = SPP_PSF if spp is None else spp
        return self.psf_geometric(points, ks, wvln, spp, recenter)
    elif model == "coherent":
        spp = SPP_COHERENT if spp is None else spp
        return self.psf_coherent(points, ks, wvln, spp, recenter)
    elif model == "huygens":
        spp = SPP_COHERENT if spp is None else spp
        return self.psf_huygens(points, ks, wvln, spp, recenter)
    else:
        raise ValueError(f"Unknown PSF model: {model}")

psf_geometric

psf_geometric(points, ks=PSF_KS, wvln=DEFAULT_WAVE, spp=SPP_PSF, recenter=True)

Single wavelength geometric PSF calculation.

Parameters:

Name Type Description Default
points Tensor

Normalized point source position. Shape of [N, 3], x, y in range [-1, 1], z in range [-Inf, 0].

required
ks int

Output kernel size.

PSF_KS
wvln float

Wavelength in [um].

DEFAULT_WAVE
spp int

Number of rays sampled per point source.

SPP_PSF
recenter bool

Recenter PSF using chief ray.

True

Returns:

Name Type Description
psf

Shape of [ks, ks] or [N, ks, ks].

References

[1] https://optics.ansys.com/hc/en-us/articles/42661723066515-What-is-a-Point-Spread-Function

Source code in deeplens/optics/geolens_pkg/psf_compute.py
def psf_geometric(
    self, points, ks=PSF_KS, wvln=DEFAULT_WAVE, spp=SPP_PSF, recenter=True
):
    """Single wavelength geometric PSF calculation.

    Args:
        points (Tensor): Normalized point source position. Shape of [N, 3], x, y in range [-1, 1], z in range [-Inf, 0].
        ks (int, optional): Output kernel size.
        wvln (float, optional): Wavelength.
        spp (int, optional): Sample per pixel.
        recenter (bool, optional): Recenter PSF using chief ray.

    Returns:
        psf: Shape of [ks, ks] or [N, ks, ks].

    References:
        [1] https://optics.ansys.com/hc/en-us/articles/42661723066515-What-is-a-Point-Spread-Function
    """
    sensor_w, sensor_h = self.sensor_size
    pixel_size = self.pixel_size
    device = self.device

    # Points shape of [N, 3]
    if not torch.is_tensor(points):
        points = torch.tensor(points, device=device)

    if len(points.shape) == 1:
        single_point = True
        points = points.unsqueeze(0)
    else:
        single_point = False

    # Sample rays. Ray position in the object space by perspective projection
    depth = points[:, 2]
    scale = self.calc_scale(depth)
    point_obj_x = points[..., 0] * scale * sensor_w / 2
    point_obj_y = points[..., 1] * scale * sensor_h / 2
    point_obj = torch.stack([point_obj_x, point_obj_y, points[..., 2]], dim=-1)
    ray = self.sample_from_points(points=point_obj, num_rays=spp, wvln=wvln)

    # Trace rays to sensor plane (incoherent)
    ray.coherent = False
    ray = self.trace2sensor(ray)

    # Calculate PSF center, shape [N, 2]
    if recenter:
        pointc = self.psf_center(point_obj, method="chief_ray")
    else:
        pointc = self.psf_center(point_obj, method="pinhole")

    # Monte Carlo integration
    psf = forward_integral(ray.flip_xy(), ps=pixel_size, ks=ks, pointc=pointc)

    # Intensity normalization
    psf = psf / (torch.sum(psf, dim=(-2, -1), keepdim=True) + EPSILON)

    if single_point:
        psf = psf.squeeze(0)

    return diff_float(psf)
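The Monte Carlo step in psf_geometric accumulates ray hits on a ks x ks pixel grid centred on the PSF centre and then normalizes the result. The implementation of forward_integral is not shown on this page. The NumPy sketch below illustrates the idea with hard binning via np.histogram2d, which is not differentiable and is not claimed to match forward_integral. Rays that land outside the patch are dropped before normalization.

```python
import numpy as np


def binned_psf(x, y, center, pixel_size, ks):
    """Histogram ray hits (x, y) into a ks x ks PSF patch centred on `center`."""
    half = ks * pixel_size / 2
    edges_x = np.linspace(center[0] - half, center[0] + half, ks + 1)
    edges_y = np.linspace(center[1] - half, center[1] + half, ks + 1)
    # Rows follow y, columns follow x; hits outside the patch are ignored
    psf, _, _ = np.histogram2d(y, x, bins=[edges_y, edges_x])
    return psf / (psf.sum() + 1e-12)  # intensity normalization, as in psf_geometric


rng = np.random.default_rng(0)
hits = rng.normal(scale=2e-3, size=(10000, 2))  # hypothetical ray hits in mm
psf = binned_psf(hits[:, 0], hits[:, 1], center=(0.0, 0.0), pixel_size=1e-3, ks=11)
print(psf.shape, psf.sum())
```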

psf_coherent

psf_coherent(points, ks=PSF_KS, wvln=DEFAULT_WAVE, spp=SPP_COHERENT, recenter=True)

Alias for psf_pupil_prop. Calculates PSF by coherent ray tracing to exit pupil followed by Angular Spectrum Method (ASM) propagation.

Source code in deeplens/optics/geolens_pkg/psf_compute.py
def psf_coherent(
    self, points, ks=PSF_KS, wvln=DEFAULT_WAVE, spp=SPP_COHERENT, recenter=True
):
    """Alias for psf_pupil_prop. Calculates PSF by coherent ray tracing to exit pupil followed by Angular Spectrum Method (ASM) propagation."""
    return self.psf_pupil_prop(points, ks=ks, wvln=wvln, spp=spp, recenter=recenter)

psf_pupil_prop

psf_pupil_prop(points, ks=PSF_KS, wvln=DEFAULT_WAVE, spp=SPP_COHERENT, recenter=True)

Single point monochromatic PSF using exit-pupil diffraction model. This function is differentiable.

Steps

  1. Calculate the complex wavefield at the exit-pupil plane by coherent ray tracing.
  2. Propagate the field through free space to the sensor plane and compute the intensity PSF.

Parameters:

Name Type Description Default
points Tensor

[x, y, z] coordinates of the point source.

required
ks int

size of the PSF patch. Defaults to PSF_KS.

PSF_KS
wvln float

Wavelength in [um]. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
spp int

number of rays to sample. Defaults to SPP_COHERENT.

SPP_COHERENT
recenter bool

Recenter PSF using chief ray. Defaults to True.

True

Returns:

Name Type Description
psf_out Tensor

PSF patch. Normalized to sum to 1. Shape [ks, ks]

Reference

[1] "End-to-End Hybrid Refractive-Diffractive Lens Design with Differentiable Ray-Wave Model", SIGGRAPH Asia 2024.

Note

[1] This function is similar to ZEMAX FFT_PSF, but it implements free-space propagation with the Angular Spectrum Method (ASM) rather than a single FFT. ASM propagation is more accurate because the FFT approach (as used in ZEMAX) assumes a far-field condition (e.g., the chief ray perpendicular to the image plane).
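The free-space step can be illustrated with a textbook Angular Spectrum Method propagator. The NumPy sketch below multiplies the field's spectrum by the transfer function exp(i k_z z), where k_z = 2*pi/wvln * sqrt(1 - (wvln*fx)^2 - (wvln*fy)^2), and drops evanescent components. It is a generic illustration, not the AngularSpectrumMethod function DeepLens calls, whose signature is not shown here. psf_pupil_prop zero-pads the field before propagating; doing the same here would reduce wrap-around from the circular FFT.

```python
import numpy as np


def asm_propagate(u0, wvln, dx, z):
    """Propagate a complex field u0 (pitch dx) by distance z with the Angular Spectrum Method.

    wvln, dx and z must share one length unit (e.g. micrometres).
    """
    ny, nx = u0.shape
    FX, FY = np.meshgrid(np.fft.fftfreq(nx, d=dx), np.fft.fftfreq(ny, d=dx))
    arg = 1.0 - (wvln * FX) ** 2 - (wvln * FY) ** 2
    propagating = arg > 0  # evanescent components are discarded
    kz = 2 * np.pi / wvln * np.sqrt(np.where(propagating, arg, 0.0))
    H = np.where(propagating, np.exp(1j * kz * z), 0.0)
    return np.fft.ifft2(np.fft.fft2(u0) * H)


# Gaussian beam on a 64 x 64 grid, 1 um pitch, 0.5 um wavelength, propagated 100 um
x = np.arange(64) - 32.0
X, Y = np.meshgrid(x, x)
u0 = np.exp(-(X**2 + Y**2) / 50.0).astype(complex)
u1 = asm_propagate(u0, wvln=0.5, dx=1.0, z=100.0)
psf = np.abs(u1) ** 2
psf /= psf.sum()  # intensity PSF normalized to sum to 1
```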

Source code in deeplens/optics/geolens_pkg/psf_compute.py
def psf_pupil_prop(
    self, points, ks=PSF_KS, wvln=DEFAULT_WAVE, spp=SPP_COHERENT, recenter=True
):
    """Single point monochromatic PSF using exit-pupil diffraction model. This function is differentiable.

    Steps:
        1. Compute the complex wavefield at the exit-pupil plane by coherent ray tracing.
        2. Propagate the field through free space to the sensor plane and compute the intensity PSF.

    Args:
        points (torch.Tensor): [x, y, z] coordinates of a single point source, shape [3] or [1, 3], with normalized x, y in [-1, 1] and z in [-Inf, 0].
        ks (int, optional): Size of the PSF patch in pixels. Defaults to PSF_KS.
        wvln (float, optional): Wavelength in [um]. Defaults to DEFAULT_WAVE.
        spp (int, optional): Number of rays to sample. Must be >= 1,000,000. Defaults to SPP_COHERENT.
        recenter (bool, optional): Recenter PSF using chief ray. Defaults to True.

    Returns:
        psf_out (torch.Tensor): PSF patch. Normalized to sum to 1. Shape [ks, ks]

    Reference:
        [1] "End-to-End Hybrid Refractive-Diffractive Lens Design with Differentiable Ray-Wave Model", SIGGRAPH Asia 2024.

    Note:
        [1] This function is similar to ZEMAX FFT_PSF, but it implements free-space propagation with the Angular Spectrum Method (ASM) rather than a direct FFT. ASM propagation is more accurate than the FFT approach because the FFT method (as used in ZEMAX) assumes the far-field condition (e.g., the chief ray is perpendicular to the image plane).
    """
    # Pupil field by coherent ray tracing
    wavefront, psfc = self.pupil_field(
        points=points, wvln=wvln, spp=spp, recenter=recenter
    )

    # Propagate to sensor plane and get intensity
    pupilz, pupilr = self.get_exit_pupil()
    h, w = wavefront.shape
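    # Manually zero-pad the wave field by half its size on each side.
    # F.pad lists the last dimension (width) first, then height.
    wavefront = F.pad(
        wavefront.unsqueeze(0).unsqueeze(0),
        [w // 2, w // 2, h // 2, h // 2],
        mode="constant",
        value=0,
    )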
    # Free-space propagation using Angular Spectrum Method (ASM)
    sensor_field = AngularSpectrumMethod(
        wavefront,
        z=self.d_sensor - pupilz,
        wvln=wvln,
        ps=self.pixel_size,
        padding=False,
    )
    # Get intensity
    psf_inten = sensor_field.abs() ** 2

    # Calculate PSF center
    h, w = psf_inten.shape[-2:]
    # Consider both interpolation and padding
    psfc_idx_i = ((2 - psfc[1]) * h / 4).round().long()
    psfc_idx_j = ((2 + psfc[0]) * w / 4).round().long()

    # Crop valid PSF region and normalize
    if ks is not None:
        psf_inten_pad = (
            F.pad(
                psf_inten,
                [ks // 2, ks // 2, ks // 2, ks // 2],
                mode="constant",
                value=0,
            )
            .squeeze(0)
            .squeeze(0)
        )
        psf = psf_inten_pad[
            psfc_idx_i : psfc_idx_i + ks, psfc_idx_j : psfc_idx_j + ks
        ]
    else:
        psf = psf_inten

    # Intensity normalization, shape of [ks, ks] or [h, w]
    psf = psf / (torch.sum(psf, dim=(-2, -1), keepdim=True) + EPSILON)

    return diff_float(psf)

pupil_field

pupil_field(points, wvln=DEFAULT_WAVE, spp=SPP_COHERENT, recenter=True)

Compute complex wavefront at exit pupil plane by coherent ray tracing.

The wavefront is flipped for subsequent PSF calculation and has the same size as the image sensor. This function is differentiable.

Parameters:

Name Type Description Default
points Tensor or list

Single point source position. Shape [3] or [1, 3], with x, y in [-1, 1] and z in [-Inf, 0].

required
wvln float

Wavelength in [um]. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
spp int

Number of rays to sample. Must be >= 1,000,000 for accurate coherent simulation. Defaults to SPP_COHERENT.

SPP_COHERENT
recenter bool

If True, center using chief ray. Defaults to True.

True

Returns:

Name Type Description
tuple

(wavefront, psf_center), where:
  - wavefront (Tensor): Complex wavefront at the exit pupil. Shape [H, H].
  - psf_center (list): Normalized PSF center coordinates [x, y] in [-1, 1].

Note

Default dtype must be torch.float64 for accurate phase calculation.
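The float64 requirement follows from how phase is computed: only the fractional part of OPL/λ matters, but with OPL around 100 mm and λ = 0.55 µm that ratio is about 1.8 × 10^5, so float32's roughly seven significant digits leave only about two digits for the fraction. The stand-alone NumPy sketch below is not DeepLens code (wrapped_phase is a hypothetical helper). It computes the wrapped phase in float32 and in float64 and measures the difference.

```python
import numpy as np


def wrapped_phase(opl_mm, wvln_um, dtype):
    """Phase 2*pi*frac(OPL / wavelength), with every step computed in `dtype`."""
    opl = np.asarray(opl_mm, dtype=dtype)
    wvln_mm = dtype(wvln_um * 1e-3)  # um -> mm
    cycles = opl / wvln_mm  # about 1.8e5 cycles for OPL = 100 mm at 0.55 um
    return np.fmod(cycles, dtype(1.0)) * dtype(2 * np.pi)


# Illustrative optical path lengths around 100 mm
opl = np.linspace(100.0, 101.0, 1001)
phase64 = wrapped_phase(opl, 0.55, np.float64)
phase32 = wrapped_phase(opl, 0.55, np.float32).astype(np.float64)
# Wrap the difference into [-pi, pi] before measuring it
err32 = np.abs(np.angle(np.exp(1j * (phase32 - phase64))))
print(f"max float32 phase error: {err32.max():.3f} rad")
```

Longer paths, such as those from distant object points, push even more of float32's precision into the integer part and make the error worse.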

Source code in deeplens/optics/geolens_pkg/psf_compute.py
def pupil_field(self, points, wvln=DEFAULT_WAVE, spp=SPP_COHERENT, recenter=True):
    """Compute complex wavefront at exit pupil plane by coherent ray tracing.

    The wavefront is flipped for subsequent PSF calculation and has the same
    size as the image sensor. This function is differentiable.

    Args:
        points (Tensor or list): Single point source position. Shape [3] or [1, 3],
            with x, y in [-1, 1] and z in [-Inf, 0].
        wvln (float, optional): Wavelength in [um]. Defaults to DEFAULT_WAVE.
        spp (int, optional): Number of rays to sample. Must be >= 1,000,000 for
            accurate coherent simulation. Defaults to SPP_COHERENT.
        recenter (bool, optional): If True, center using chief ray. Defaults to True.

    Returns:
        tuple: (wavefront, psf_center) where:
            - wavefront (Tensor): Complex wavefront at exit pupil. Shape [H, H].
            - psf_center (list): Normalized PSF center coordinates [x, y] in [-1, 1].

    Note:
        Default dtype must be torch.float64 for accurate phase calculation.
    """
    assert spp >= 1_000_000, (
        f"Ray sampling {spp} is too small for coherent ray tracing, which may lead to inaccurate simulation."
    )
    assert torch.get_default_dtype() == torch.float64, (
        "Default dtype must be set to float64 for accurate phase calculation."
    )

    sensor_w, sensor_h = self.sensor_size
    device = self.device

    if isinstance(points, list):
        points = torch.tensor(points, device=device).unsqueeze(0)  # [1, 3]
    elif torch.is_tensor(points) and len(points.shape) == 1:
        points = points.unsqueeze(0).to(device)  # [1, 3]
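    elif torch.is_tensor(points) and len(points.shape) == 2:
        assert points.shape[0] == 1, (
            f"pupil_field only supports single point input, got shape {points.shape}"
        )
        points = points.to(device)  # [1, 3]
    else:
        # type(points) works for any input; .type() exists only on tensors
        raise ValueError(f"Unsupported points input of type {type(points).__name__}.")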

    assert points.shape[0] == 1, (
        "Only one point is supported for pupil field calculation."
    )

    # Ray origin in the object space
    scale = self.calc_scale(points[:, 2].item())
    point_obj_x = points[:, 0] * scale * sensor_w / 2
    point_obj_y = points[:, 1] * scale * sensor_h / 2
    points_obj = torch.stack([point_obj_x, point_obj_y, points[:, 2]], dim=-1)

    # Ray center determined by chief ray
    # Shape of [N, 2], un-normalized physical coordinates
    if recenter:
        pointc = self.psf_center(points_obj, method="chief_ray")
    else:
        pointc = self.psf_center(points_obj, method="pinhole")

    # Ray-tracing to exit_pupil
    ray = self.sample_from_points(points=points_obj, num_rays=spp, wvln=wvln)
    ray.coherent = True
    ray = self.trace2exit_pupil(ray)

    # Calculate complex field (same physical size and resolution as the sensor)
    # Complex field is flipped here for further PSF calculation
    pointc_ref = torch.zeros_like(points[:, :2])  # [N, 2]
    wavefront = forward_integral(
        ray.flip_xy(),
        ps=self.pixel_size,
        ks=self.sensor_res[1],
        pointc=pointc_ref,
    )
    wavefront = wavefront.squeeze(0)  # [H, H]

    # PSF center (on the sensor plane)
    pointc = pointc[0, :]
    psf_center = [
        pointc[0] / sensor_w * 2,
        pointc[1] / sensor_h * 2,
    ]

    return wavefront, psf_center

psf_huygens

psf_huygens(points, ks=PSF_KS, wvln=DEFAULT_WAVE, spp=SPP_COHERENT, recenter=True)

Single-wavelength Huygens PSF calculation.

This function is not intended for differentiable optimization because of its heavy computational cost.

Steps

  1. Trace coherent rays to the exit-pupil plane.
  2. Treat every ray as a secondary point source emitting a spherical wave, and sum these waves at each sensor pixel.

Parameters:

Name Type Description Default
points Tensor

Normalized point source position. Shape [3] or [1, 3] (only a single point is supported), with x, y in [-1, 1] and z in [-Inf, 0].

required
ks int

Output kernel size.

PSF_KS
wvln float

Wavelength in [um].

DEFAULT_WAVE
spp int

Number of rays sampled from the point source.

SPP_COHERENT
recenter bool

Recenter PSF using chief ray.

True

Returns:

Name Type Description
psf

PSF normalized to sum to 1. Shape [ks, ks].

References

[1] "Optical Aberrations Correction in Postprocessing Using Imaging Simulation", TOG 2021

Note

This differs from the ZEMAX Huygens PSF, which traces rays to the image plane and performs plane-wave integration.
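The Huygens step can be sketched in NumPy as a single vectorized sum instead of the batched loop in the psf_huygens source. This sketch rests on several assumptions. The helper name huygens_psf and its arguments are hypothetical. Source positions are taken to be already at the exit pupil. The phase is computed without fmod, which is equivalent inside exp. The final PSF flip done by DeepLens is omitted.

```python
import numpy as np


def huygens_psf(src_pos, src_dir_z, src_opl, pix_x, pix_y, sensor_z, wvln_mm):
    """Coherently sum spherical wavelets from secondary sources onto sensor pixels.

    src_pos: [M, 3] secondary-source positions (mm); src_dir_z: [M] z direction cosines;
    src_opl: [M] optical path length up to each source (mm);
    pix_x, pix_y: [K, K] pixel-center coordinates on the plane z = sensor_z (mm).
    """
    dx = pix_x[..., None] - src_pos[:, 0]  # [K, K, M]
    dy = pix_y[..., None] - src_pos[:, 1]  # [K, K, M]
    dz = sensor_z - src_pos[:, 2]  # [M]
    r = np.sqrt(dx**2 + dy**2 + dz**2)  # source-to-pixel distance
    amp = 0.5 * (1.0 + np.abs(src_dir_z))  # Huygens-Fresnel obliquity factor
    phase = 2 * np.pi * (src_opl + r - src_opl.min()) / wvln_mm
    field = np.sum(amp / r * np.exp(1j * phase), axis=-1)  # [K, K]
    psf = np.abs(field) ** 2
    return psf / psf.sum()  # normalize to unit sum


# Usage: sources on a 2 mm pupil disk at z = 0, phased to focus at (0, 0, f)
f = 10.0
g = np.linspace(-1.0, 1.0, 41)
X, Y = np.meshgrid(g, g)
inside = X**2 + Y**2 <= 1.0
src = np.stack([X[inside], Y[inside], np.zeros(inside.sum())], axis=-1)
dist = np.sqrt(src[:, 0] ** 2 + src[:, 1] ** 2 + f**2)
opl = 20.0 - dist  # every wavelet reaches the focus with total path 20 mm
p = np.linspace(-0.0025, 0.0025, 11)  # 11 x 11 pixels, 0.5 um pitch
PX, PY = np.meshgrid(p, p)
psf = huygens_psf(src, f / dist, opl, PX, PY, f, 0.00055)
```

In this aberration-free example the wavelets arrive exactly in phase only at the central pixel, so the PSF peaks at index (5, 5).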

Source code in deeplens/optics/geolens_pkg/psf_compute.py
def psf_huygens(
    self, points, ks=PSF_KS, wvln=DEFAULT_WAVE, spp=SPP_COHERENT, recenter=True
):
    """Single wavelength Huygens PSF calculation.

    This function is not intended for differentiable optimization because of its heavy computational cost.

    Steps:
        1. Trace coherent rays to the exit-pupil plane.
        2. Treat every ray as a secondary point source emitting a spherical wave, and sum these waves at each sensor pixel.

    Args:
        points (Tensor): Normalized point source position. Shape [3] or [1, 3] (single point only), x, y in range [-1, 1], z in range [-Inf, 0].
        ks (int, optional): Output kernel size.
        wvln (float, optional): Wavelength in [um].
        spp (int, optional): Number of rays sampled from the point source.
        recenter (bool, optional): Recenter PSF using chief ray.

    Returns:
        psf: PSF normalized to sum to 1. Shape [ks, ks].

    References:
        [1] "Optical Aberrations Correction in Postprocessing Using Imaging Simulation", TOG 2021

    Note:
        This differs from the ZEMAX Huygens PSF, which traces rays to the image plane and performs plane-wave integration.
    """
    assert torch.get_default_dtype() == torch.float64, (
        "Default dtype must be set to float64 for accurate phase calculation."
    )

    sensor_w, sensor_h = self.sensor_size
    pixel_size = self.pixel_size
    device = self.device
    wvln_mm = wvln * 1e-3  # Convert wavelength to mm

    # Points shape of [N, 3]
    if not torch.is_tensor(points):
        points = torch.tensor(points, device=device)

    if len(points.shape) == 1:
        single_point = True
        points = points.unsqueeze(0)
    elif len(points.shape) == 2 and points.shape[0] == 1:
        single_point = True
    else:
        raise ValueError(
            f"Points must be of shape [3] or [1, 3], got {points.shape}."
        )

    # Sample rays from object point
    depth = points[:, 2]
    scale = self.calc_scale(depth)
    point_obj_x = points[..., 0] * scale * sensor_w / 2
    point_obj_y = points[..., 1] * scale * sensor_h / 2
    point_obj = torch.stack([point_obj_x, point_obj_y, points[..., 2]], dim=-1)
    ray = self.sample_from_points(points=point_obj, num_rays=spp, wvln=wvln)

    # Trace rays coherently through the lens to exit pupil
    ray.coherent = True
    ray = self.trace2exit_pupil(ray)

    # Calculate PSF center (not flipped here)
    if recenter:
        pointc = -self.psf_center(point_obj, method="chief_ray")
    else:
        pointc = -self.psf_center(point_obj, method="pinhole")

    # Build PSF pixel coordinates (sensor plane at z = d_sensor)
    sensor_z = self.d_sensor.item()
    psf_half_size = (ks / 2) * pixel_size  # Physical half-size of PSF region
    x_coords = torch.linspace(
        -psf_half_size + pixel_size / 2,
        psf_half_size - pixel_size / 2,
        ks,
        device=device,
    )
    y_coords = torch.linspace(
        psf_half_size - pixel_size / 2,
        -psf_half_size + pixel_size / 2,
        ks,
        device=device,
    )
    psf_x, psf_y = torch.meshgrid(
        pointc[0, 0] + x_coords, pointc[0, 1] + y_coords, indexing="xy"
    )  # [ks, ks] each

    # Get valid rays only
    valid_mask = ray.is_valid > 0
    valid_pos = ray.o[valid_mask]  # [num_valid, 3]
    valid_dir = ray.d[valid_mask]  # [num_valid, 3]
    valid_opl = ray.opl[valid_mask]  # [num_valid]
    num_valid = valid_pos.shape[0]

    # Huygens integration: sum spherical waves from each secondary source
    psf_complex = torch.zeros(ks, ks, dtype=torch.complex128, device=device)
    opl_min = valid_opl.min()

    # Compute distance from each secondary source to each pixel
    batch_size = min(num_valid, 10_000)  # Process rays in batches
    for batch_start in range(0, num_valid, batch_size):
        batch_end = min(batch_start + batch_size, num_valid)

        # Batch ray data
        batch_pos = valid_pos[batch_start:batch_end]  # [batch, 3]
        batch_dir = valid_dir[batch_start:batch_end]  # [batch, 3]
        batch_opl = valid_opl[batch_start:batch_end].squeeze(-1)  # [batch]

        # Distance from each secondary source to each pixel
        # batch_pos: [batch, 3], psf_x: [ks, ks]
        dx = psf_x.unsqueeze(-1) - batch_pos[:, 0]  # [ks, ks, batch]
        dy = psf_y.unsqueeze(-1) - batch_pos[:, 1]  # [ks, ks, batch]
        dz = sensor_z - batch_pos[:, 2]  # [batch]

        # Distance r from secondary source to pixel
        r = torch.sqrt(dx**2 + dy**2 + dz**2)  # [ks, ks, batch]

        # Obliquity factor: cos(theta) where theta is angle from normal
        # Using ray direction at exit pupil (dz component)
        obliq = torch.abs(batch_dir[:, 2])  # [batch]
        amp = 0.5 * (1.0 + obliq)  # Huygens–Fresnel obliquity factor

        # Total optical path = OPL through lens + distance to pixel
        total_opl = batch_opl + r  # [ks, ks, batch]

        # Phase relative to reference
        phase = torch.fmod((total_opl - opl_min) / wvln_mm, 1.0) * (
            2 * torch.pi
        )  # [ks, ks, batch]

        # Complex amplitude: A * exp(i * phase) / r (spherical wave decay)
        # We use 1/r for spherical wave amplitude decay
        complex_amp = (amp / r) * torch.exp(1j * phase)  # [ks, ks, batch]

        # Sum contributions from this batch
        psf_complex += complex_amp.sum(dim=-1)  # [ks, ks]

    # Convert complex field to intensity
    psf = psf_complex.abs() ** 2

    # Intensity normalization
    psf = psf / (torch.sum(psf, dim=(-2, -1), keepdim=True) + EPSILON)

    # Flip PSF
    psf = torch.flip(psf, [-2, -1])

    if single_point:
        psf = psf.squeeze(0)

    return diff_float(psf)

psf_map

psf_map(depth=DEPTH, grid=(7, 7), ks=PSF_KS, spp=SPP_PSF, wvln=DEFAULT_WAVE, recenter=True)

Compute the geometric PSF map at given depth.

Overrides the base method in Lens class to improve efficiency by parallel ray tracing over different field points.

Parameters:

Name Type Description Default
depth float

Depth of the object plane. Defaults to DEPTH.

DEPTH
grid (int, tuple)

Grid size (grid_w, grid_h). An int n is treated as (n, n). Defaults to (7, 7).

(7, 7)
ks int

Kernel size. Defaults to PSF_KS.

PSF_KS
spp int

Number of rays sampled per point source. Defaults to SPP_PSF.

SPP_PSF
wvln float

Wavelength in [um]. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
recenter bool

Recenter PSF using chief ray. Defaults to True.

True

Returns:

Name Type Description
psf_map

PSF map. Shape of [grid_h, grid_w, 1, ks, ks].
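psf_map flattens the field grid with reshape(-1, 3) and restores it with reshape(grid[1], grid[0], 1, ks, ks). Assuming point_source_grid returns points of shape [grid_h, grid_w, 3] (as stated for draw_spot_map), psf_map[i, j] is the PSF of flat point index i * grid_w + j. The NumPy stand-in below uses synthetic PSFs, each filled with its own flat index, to show this row-major round trip.

```python
import numpy as np

grid_w, grid_h, ks = 3, 2, 4
# Dummy per-point PSFs in the flattened order produced by points.reshape(-1, 3);
# PSF number n is filled with the value n so it can be identified later.
psfs = np.arange(grid_h * grid_w, dtype=float)[:, None, None, None] * np.ones((1, 1, ks, ks))
psf_map = psfs.reshape(grid_h, grid_w, 1, ks, ks)

# Row i, column j of the map holds the PSF of flat point index i * grid_w + j
i, j = 1, 2
assert psf_map[i, j, 0, 0, 0] == i * grid_w + j
```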

Source code in deeplens/optics/geolens_pkg/psf_compute.py
def psf_map(
    self,
    depth=DEPTH,
    grid=(7, 7),
    ks=PSF_KS,
    spp=SPP_PSF,
    wvln=DEFAULT_WAVE,
    recenter=True,
):
    """Compute the geometric PSF map at given depth.

    Overrides the base method in Lens class to improve efficiency by parallel ray tracing over different field points.

    Args:
        depth (float, optional): Depth of the object plane. Defaults to DEPTH.
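        grid (int, tuple): Grid size (grid_w, grid_h). An int n is treated as (n, n). Defaults to (7, 7).
        ks (int, optional): Kernel size. Defaults to PSF_KS.
        spp (int, optional): Number of rays sampled per point source. Defaults to SPP_PSF.
        wvln (float, optional): Wavelength in [um]. Defaults to DEFAULT_WAVE.
        recenter (bool, optional): Recenter PSF using chief ray. Defaults to True.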

    Returns:
        psf_map: PSF map. Shape of [grid_h, grid_w, 1, ks, ks].
    """
    if isinstance(grid, int):
        grid = (grid, grid)
    points = self.point_source_grid(depth=depth, grid=grid)
    points = points.reshape(-1, 3)
    psfs = self.psf(
        points=points, ks=ks, recenter=recenter, spp=spp, wvln=wvln
    ).unsqueeze(1)  # [grid_h * grid_w, 1, ks, ks]

    psf_map = psfs.reshape(grid[1], grid[0], 1, ks, ks)
    return psf_map

psf_center

psf_center(points_obj, method='chief_ray')

Compute the reference PSF center for a given point source (flipped to match the orientation of the original point).

Parameters:

Name Type Description Default
points_obj

Un-normalized point(s) in object space, shape [..., 3], with x, y in [-Inf, Inf] and z in [-Inf, 0].

required
method

"chief_ray" or "pinhole". Defaults to "chief_ray".

'chief_ray'

Returns:

Name Type Description
psf_center

Un-normalized PSF center on the sensor plane, shape [..., 2].
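The pinhole branch is a plain perspective projection, c = f · (−x/z, −y/z), and ignores distortion. For a real object (z < 0) the result keeps the signs of x and y, which appears to be what "flipped to match the original point" refers to. Below is a NumPy transcription in which the helper name is hypothetical and foclen stands in for self.foclen.

```python
import numpy as np


def pinhole_psf_center(points_obj, foclen):
    """Pinhole-model PSF center, transcribing psf_center(method="pinhole") into NumPy.

    points_obj: [..., 3] object-space points in mm (z < 0 for a real object).
    Returns [..., 2] PSF centers on the sensor plane in mm.
    """
    points_obj = np.asarray(points_obj, dtype=float)
    tan_x = -points_obj[..., 0] / points_obj[..., 2]
    tan_y = -points_obj[..., 1] / points_obj[..., 2]
    return np.stack([foclen * tan_x, foclen * tan_y], axis=-1)


# A point at (10, -20) mm, 1 m in front of a 50 mm lens
center = pinhole_psf_center([10.0, -20.0, -1000.0], foclen=50.0)
```

Here tan_x = 0.01 and tan_y = -0.02, so the center lands at (0.5, -1.0) mm.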

Source code in deeplens/optics/geolens_pkg/psf_compute.py
@torch.no_grad()
def psf_center(self, points_obj, method="chief_ray"):
    """Compute reference PSF center (flipped to match the original point) for given point source.

    Args:
        points_obj: [..., 3] un-normalized point in object plane. [-Inf, Inf] * [-Inf, Inf] * [-Inf, 0]
        method: "chief_ray" or "pinhole". Defaults to "chief_ray".

    Returns:
        psf_center: [..., 2] un-normalized psf center in sensor plane.
    """
    if method == "chief_ray":
        # Shrink the pupil and calculate green light centroid ray as the chief ray
        ray = self.sample_from_points(points_obj, scale_pupil=0.5, num_rays=SPP_CALC)
        ray = self.trace2sensor(ray)
        if not ray.is_valid.any():
            raise RuntimeError(
                "When tracing chief ray for PSF center calculation, no ray arrives at the sensor."
            )
        psf_center = ray.centroid()
        psf_center = -psf_center[..., :2]  # shape [..., 2]

    elif method == "pinhole":
        # Pinhole camera perspective projection, distortion not considered
        if points_obj[..., 2].min().abs() < 100:
            print(
                "Point source is too close, pinhole model may be inaccurate for PSF center calculation."
            )
        tan_point_fov_x = -points_obj[..., 0] / points_obj[..., 2]
        tan_point_fov_y = -points_obj[..., 1] / points_obj[..., 2]
        psf_center_x = self.foclen * tan_point_fov_x
        psf_center_y = self.foclen * tan_point_fov_y
        psf_center = torch.stack([psf_center_x, psf_center_y], dim=-1).to(
            self.device
        )

    else:
        raise ValueError(
            f"Unsupported method for PSF center calculation: {method}."
        )

    return psf_center

deeplens.optics.geolens_pkg.eval.GeoLensEval

Mixin that adds classical optical evaluation methods to GeoLens.

This class is never instantiated on its own. It is mixed into GeoLens via multiple inheritance, so every method can access lens geometry (self.d_sensor, self.rfov, …) and ray-tracing routines (self.trace(), self.trace2sensor(), …) directly through self.

All evaluation functions follow the same pattern:
  1. Sample rays from object space (parallel / grid / radial).
  2. Trace rays through the lens (self.trace or self.trace2sensor).
  3. Analyze ray positions / directions at the sensor plane.
  4. Optionally produce a matplotlib figure saved to disk.

Results agree with Zemax OpticStudio when the same lens prescription and ray-sampling density are used.

Attributes consumed from GeoLens (via self):
  - d_sensor (float): Axial position of the sensor plane (mm).
  - sensor_size (tuple[float, float]): Sensor (width, height) in mm.
  - pixel_size (float): Pixel pitch in mm.
  - sensor_res (tuple[int, int]): Sensor resolution (H, W) in pixels.
  - rfov (float): Half field-of-view in radians.
  - foclen (float): Equivalent focal length in mm.
  - fnum (float): F-number.
  - aper_idx (int): Index of the aperture stop surface.
  - device (torch.device): Compute device (CPU / CUDA).

spot_points

spot_points(points, num_rays=SPP_PSF, wvln=DEFAULT_WAVE)

Trace rays from object points to sensor and return the traced Ray.

Samples rays from each physical object point toward the entrance pupil, traces through all lens surfaces (refraction + clipping), and returns the resulting Ray object on the sensor plane.

This is the shared computational core for spot diagrams (draw_spot_radial, draw_spot_map) and RMS error maps (rms_map, rms_map_rgb).

Algorithm
  1. self.sample_from_points(points, num_rays, wvln) generates a fan of num_rays rays per object point, aimed at the entrance pupil.
  2. self.trace2sensor() propagates through all surfaces and clips vignetted rays.

Parameters:

Name Type Description Default
points Tensor

Physical 3D object-space coordinates with shape [..., 3] (mm). Supported layouts:
  - [3]: single point.
  - [N, 3]: N points (e.g. radial field positions).
  - [H, W, 3]: 2-D field grid.
Generated by self.point_source_grid(normalized=False) for grid sampling, or self.point_source_radial(normalized=False) for radial sampling.

required
num_rays int

Number of rays sampled per object point. Defaults to SPP_PSF.

SPP_PSF
wvln float

Wavelength in micrometers. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE

Returns:

Name Type Description
Ray

Traced ray on the sensor plane, with shape [..., num_rays, 3] for positions and [..., num_rays] for validity mask. Use ray.o[..., :2] for transverse positions and ray.is_valid for the validity mask. ray.centroid() gives the weighted centroid.
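The docstring describes ray.centroid() only as a weighted centroid. Assuming the weights are the validity flags, so that the centroid is a plain mean over rays that reached the sensor, the equivalent computation on the returned arrays is sketched below. valid_centroid is a hypothetical helper, not part of DeepLens.

```python
import numpy as np


def valid_centroid(pos_xy, valid):
    """Mean (x, y) of the rays that reached the sensor.

    pos_xy: [..., num_rays, 2] transverse positions, e.g. ray.o[..., :2] as NumPy;
    valid:  [..., num_rays] validity flags, e.g. ray.is_valid as NumPy (> 0 = valid).
    """
    weight = (np.asarray(valid) > 0).astype(float)[..., None]  # [..., num_rays, 1]
    count = np.maximum(weight.sum(axis=-2), 1e-12)  # avoid 0/0 when no ray is valid
    return (np.asarray(pos_xy, dtype=float) * weight).sum(axis=-2) / count


pos = np.array([[0.0, 0.0], [2.0, 0.0], [100.0, 100.0]])
flags = np.array([1, 1, 0])  # the third ray was vignetted
c = valid_centroid(pos, flags)
```

The vignetted ray at (100, 100) is ignored, so the centroid is (1, 0).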

Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def spot_points(self, points, num_rays=SPP_PSF, wvln=DEFAULT_WAVE):
    """Trace rays from object points to sensor and return the traced Ray.

    Samples rays from each physical object point toward the entrance pupil,
    traces through all lens surfaces (refraction + clipping), and returns
    the resulting Ray object on the sensor plane.

    This is the shared computational core for spot diagrams
    (``draw_spot_radial``, ``draw_spot_map``) and RMS error maps
    (``rms_map``, ``rms_map_rgb``).

    Algorithm:
        1. ``self.sample_from_points(points, num_rays, wvln)`` generates a
           fan of ``num_rays`` rays per object point, aimed at the entrance
           pupil.
        2. ``self.trace2sensor()`` propagates through all surfaces and
           clips vignetted rays.

    Args:
        points (torch.Tensor): Physical 3D object-space coordinates with
            shape ``[..., 3]`` (mm).  Supported layouts:
            - ``[3]`` — single point.
            - ``[N, 3]`` — N points (e.g. radial field positions).
            - ``[H, W, 3]`` — 2-D field grid.
            Generated by ``self.point_source_grid(normalized=False)`` for
            grid sampling, or ``self.point_source_radial(normalized=False)``
            for radial sampling.
        num_rays (int): Number of rays sampled per object point.
            Defaults to ``SPP_PSF``.
        wvln (float): Wavelength in micrometers.
            Defaults to ``DEFAULT_WAVE``.

    Returns:
        Ray: Traced ray on the sensor plane, with shape
            ``[..., num_rays, 3]`` for positions and ``[..., num_rays]``
            for validity mask. Use ``ray.o[..., :2]`` for transverse
            positions and ``ray.is_valid`` for the validity mask.
            ``ray.centroid()`` gives the weighted centroid.
    """
    ray = self.sample_from_points(points=points, num_rays=num_rays, wvln=wvln)
    return self.trace2sensor(ray)

draw_spot_radial

draw_spot_radial(save_name='./lens_spot_radial.png', num_fov=5, depth=DEPTH, num_rays=SPP_PSF, wvln_list=WAVE_RGB, direction='y', show=False)

Draw spot diagrams at evenly-spaced field angles along a chosen direction.

A spot diagram visualizes the transverse ray-intercept distribution on the sensor plane for a point source at a given field angle and depth. It reveals the combined effect of all aberrations (spherical, coma, astigmatism, field curvature, chromatic, …).

Algorithm

For each wavelength in wvln_list:
  1. self.point_source_radial(direction, normalized=False) generates physical object-space points along the chosen direction.
  2. self.spot_points() samples rays and traces them to the sensor.
  3. Valid ray (x, y) positions are scatter-plotted, one subplot per field position.
All wavelengths are overlaid in a single figure with RGB coloring.

Parameters:

Name Type Description Default
save_name str

File path for the output PNG. Defaults to './lens_spot_radial.png'.

'./lens_spot_radial.png'
num_fov int

Number of field positions sampled uniformly from on-axis (0) to full-field. Defaults to 5.

5
depth float

Object distance in mm (negative = real object). Defaults to DEPTH.

DEPTH
num_rays int

Rays per field position per wavelength. Defaults to SPP_PSF.

SPP_PSF
wvln_list list[float]

Wavelengths in micrometers. Defaults to WAVE_RGB (red, green, blue).

WAVE_RGB
direction str

Sampling direction: "y" (meridional, default), "x" (sagittal), or "diagonal" (45°).

'y'
show bool

If True, display the figure interactively instead of saving to disk. Defaults to False.

False
Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def draw_spot_radial(
    self,
    save_name="./lens_spot_radial.png",
    num_fov=5,
    depth=DEPTH,
    num_rays=SPP_PSF,
    wvln_list=WAVE_RGB,
    direction="y",
    show=False,
):
    """Draw spot diagrams at evenly-spaced field angles along a chosen direction.

    A *spot diagram* visualizes the transverse ray-intercept distribution on
    the sensor plane for a point source at a given field angle and depth.
    It reveals the combined effect of all aberrations (spherical, coma,
    astigmatism, field curvature, chromatic, …).

    Algorithm:
        For each wavelength in ``wvln_list``:
            1. ``self.point_source_radial(direction, normalized=False)``
               generates physical object-space points along the chosen
               direction.
            2. ``self.spot_points()`` samples rays and traces to sensor.
            3. Valid ray (x, y) positions are scatter-plotted per subplot.
        All wavelengths are overlaid in a single figure with RGB coloring.

    Args:
        save_name (str): File path for the output PNG.
            Defaults to ``'./lens_spot_radial.png'``.
        num_fov (int): Number of field positions sampled uniformly from
            on-axis (0) to full-field. Defaults to 5.
        depth (float): Object distance in mm (negative = real object).
            Defaults to ``DEPTH``.
        num_rays (int): Rays per field position per wavelength.
            Defaults to ``SPP_PSF``.
        wvln_list (list[float]): Wavelengths in micrometers.
            Defaults to ``WAVE_RGB`` (red, green, blue).
        direction (str): Sampling direction —
            ``"y"`` (meridional, default), ``"x"`` (sagittal),
            ``"diagonal"`` (45°).
        show (bool): If ``True``, display the figure interactively instead
            of saving to disk. Defaults to ``False``.
    """
    assert isinstance(wvln_list, list), "wvln_list must be a list"
    if depth == float("inf"):
        depth = DEPTH

    # Generate physical object-space points along the chosen direction
    points = self.point_source_radial(
        depth=depth, grid=num_fov, direction=direction, normalized=False
    )

    # Prepare figure
    fig, axs = plt.subplots(1, num_fov, figsize=(num_fov * 3.5, 3))
    axs = np.atleast_1d(axs)

    # Trace and draw each wavelength separately, overlaying results
    for wvln_idx, wvln in enumerate(wvln_list):
        ray = self.spot_points(points, num_rays=num_rays, wvln=wvln)
        ray_o = ray.o[..., :2].cpu().numpy()
        ray_valid_np = ray.is_valid.cpu().numpy()

        color = RGB_COLORS[wvln_idx % len(RGB_COLORS)]

        # Plot multiple spot diagrams in one figure
        for i in range(num_fov):
            valid = ray_valid_np[i, :]
            xi, yi = ray_o[i, :, 0], ray_o[i, :, 1]

            # Filter valid rays
            mask = valid > 0
            x_valid, y_valid = xi[mask], yi[mask]

            # Scatter-plot the valid ray intercepts for this wavelength
            axs[i].scatter(x_valid, y_valid, 2, color=color, alpha=0.5)
            axs[i].set_aspect("equal", adjustable="datalim")
            axs[i].tick_params(axis="both", which="major", labelsize=6)

    if show:
        plt.show()
    else:
        assert save_name.endswith(".png"), "save_name must end with .png"
        plt.savefig(save_name, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

draw_spot_map

draw_spot_map(save_name='./lens_spot_map.png', num_grid=5, depth=DEPTH, num_rays=SPP_PSF, wvln_list=WAVE_RGB, show=False)

Draw a 2-D grid of spot diagrams across the full field of view.

Unlike draw_spot_radial (which samples only a radial slice), this method samples a num_grid × num_grid grid of field positions covering both the x (sagittal) and y (meridional) axes, revealing off-axis aberrations that are invisible in a 1-D radial scan.

Algorithm

For each wavelength in wvln_list:
  1. self.point_source_grid(normalized=False) creates physical object-space grid points, shape [grid_h, grid_w, 3].
  2. self.spot_points() samples rays and traces them to the sensor.
  3. Valid (x, y) positions are scatter-plotted in the corresponding subplot of the num_grid × num_grid figure.
All wavelengths are overlaid with RGB coloring.

Parameters:

Name Type Description Default
save_name str

File path for the output PNG. Defaults to './lens_spot_map.png'.

'./lens_spot_map.png'
num_grid int | tuple[int, int]

Number of grid points along each axis. Total subplots = grid_w * grid_h. Defaults to 5.

5
depth float

Object distance in mm. Defaults to DEPTH.

DEPTH
num_rays int

Rays per grid cell per wavelength. Defaults to SPP_PSF.

SPP_PSF
wvln_list list[float]

Wavelengths in micrometers. Defaults to WAVE_RGB.

WAVE_RGB
show bool

If True, display interactively. Defaults to False.

False
Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def draw_spot_map(
    self,
    save_name="./lens_spot_map.png",
    num_grid=5,
    depth=DEPTH,
    num_rays=SPP_PSF,
    wvln_list=WAVE_RGB,
    show=False,
):
    """Draw a 2-D grid of spot diagrams across the full field of view.

    Unlike ``draw_spot_radial`` (which samples only a radial slice),
    this method samples a ``num_grid × num_grid`` grid of field positions
    covering both the x (sagittal) and y (meridional) axes, revealing
    off-axis aberrations that are invisible in a 1-D radial scan.

    Algorithm:
        For each wavelength in ``wvln_list``:
            1. ``self.point_source_grid(normalized=False)`` creates physical
               object-space grid points, shape ``[grid_h, grid_w, 3]``.
            2. ``self.spot_points()`` samples rays and traces to sensor.
            3. Valid (x, y) positions are scatter-plotted in the
               corresponding subplot of the ``num_grid × num_grid`` figure.
        All wavelengths are overlaid with RGB coloring.

    Args:
        save_name (str): File path for the output PNG.
            Defaults to ``'./lens_spot_map.png'``.
        num_grid (int | tuple[int, int]): Number of grid points along each
            axis. Total subplots = ``grid_w * grid_h``. Defaults to 5.
        depth (float): Object distance in mm. Defaults to ``DEPTH``.
        num_rays (int): Rays per grid cell per wavelength.
            Defaults to ``SPP_PSF``.
        wvln_list (list[float]): Wavelengths in micrometers.
            Defaults to ``WAVE_RGB``.
        show (bool): If ``True``, display interactively. Defaults to ``False``.
    """
    assert isinstance(wvln_list, list), "wvln_list must be a list"
    if isinstance(num_grid, int):
        num_grid = (num_grid, num_grid)

    # Generate physical object-space grid points, shape [grid_h, grid_w, 3]
    points = self.point_source_grid(depth=depth, grid=num_grid, normalized=False)

    grid_w, grid_h = num_grid
    fig, axs = plt.subplots(
        grid_h, grid_w, figsize=(grid_w * 3, grid_h * 3)
    )
    axs = np.atleast_2d(axs)

    # Loop wavelengths and overlay scatters
    for wvln_idx, wvln in enumerate(wvln_list):
        ray = self.spot_points(points, num_rays=num_rays, wvln=wvln)

        # Convert to numpy, shape [grid_h, grid_w, num_rays, 2]
        ray_o = -ray.o[..., :2].cpu().numpy()
        ray_valid_np = ray.is_valid.cpu().numpy()

        color = RGB_COLORS[wvln_idx % len(RGB_COLORS)]

        # Draw per grid cell
        for i in range(grid_h):
            for j in range(grid_w):
                valid = ray_valid_np[i, j, :]
                xi, yi = ray_o[i, j, :, 0], ray_o[i, j, :, 1]

                # Filter valid rays
                mask = valid > 0
                x_valid, y_valid = xi[mask], yi[mask]

                # Plot points for this wavelength
                axs[i, j].scatter(x_valid, y_valid, 2, color=color, alpha=0.5)
                axs[i, j].set_aspect("equal", adjustable="datalim")
                axs[i, j].tick_params(axis="both", which="major", labelsize=6)

    if show:
        plt.show()
    else:
        assert save_name.endswith(".png"), "save_name must end with .png"
        plt.savefig(save_name, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

rms_map

rms_map(num_grid=32, depth=DEPTH, wvln=DEFAULT_WAVE, center=None)

Compute per-field-position RMS spot radius for a single wavelength.

Traces SPP_PSF rays per grid cell and computes the root-mean-square distance of valid ray hits from a reference centroid. When center is None, each cell uses its own centroid (monochromatic blur). When an external center is provided (e.g. the green-channel centroid), the RMS includes the chromatic shift from that reference.

Algorithm
  1. self.point_source_grid(normalized=False) generates physical object points on a [num_grid, num_grid] field grid.
  2. self.spot_points() samples SPP_PSF rays per point and traces to sensor.
  3. If center is None, compute per-cell centroid c = mean(valid ray_xy); otherwise use the provided center.
  4. RMS = sqrt( mean( ||ray_xy - c||^2 ) ).
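The masked RMS in steps 3 and 4 above can be sketched in plain PyTorch as follows. This is a standalone illustration, not DeepLens code. The helper name `masked_rms` and the toy tensors are hypothetical. The sketch assumes ray hits of shape `[..., num_rays, 2]` and a 0/1 validity mask of shape `[..., num_rays]`, mirroring the shapes that `rms_map` works with.

```python
import torch


def masked_rms(ray_xy, valid, center=None, eps=1e-9):
    """RMS distance of valid ray hits from a reference point (hypothetical helper).

    ray_xy: [..., num_rays, 2] hit positions on the sensor in mm.
    valid:  [..., num_rays] mask with 1 for valid rays and 0 otherwise.
    center: optional [..., 2] external reference; if None, the own centroid is used.
    """
    valid = valid.to(ray_xy.dtype)
    count = valid.sum(-1) + eps
    # Step 3: centroid of the valid hits only
    centroid = (ray_xy * valid.unsqueeze(-1)).sum(-2) / count.unsqueeze(-1)
    ref = centroid if center is None else center
    # Step 4: squared distance to the reference, with invalid rays zeroed out
    sq_dist = ((ray_xy - ref.unsqueeze(-2)) ** 2).sum(-1) * valid
    rms = torch.sqrt(sq_dist.sum(-1) / count)
    return rms, centroid


# Four valid hits at distance 1 mm around (2, 3), plus one invalid outlier
hits = torch.tensor([[[3.0, 3.0], [1.0, 3.0], [2.0, 4.0], [2.0, 2.0], [100.0, 100.0]]])
mask = torch.tensor([[1, 1, 1, 1, 0]])
rms, centroid = masked_rms(hits, mask)
print(rms, centroid)  # the outlier is ignored: RMS 1.0 about centroid (2, 3)
```

Multiplying by the mask rather than indexing with it keeps the tensor shapes fixed across grid cells, so every cell can be reduced in one batched operation.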

Parameters:

Name Type Description Default
num_grid int | tuple[int, int]

Spatial resolution of the field sampling grid. Defaults to 32.

32
depth float

Object distance in mm. Defaults to DEPTH.

DEPTH
wvln float

Wavelength in micrometers. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
center Tensor | None

External reference centroid with shape [grid_h, grid_w, 2]. If None, each cell's own centroid is used. Defaults to None.

None

Returns:

Type Description

tuple[torch.Tensor, torch.Tensor]:
  - rms: RMS spot error map, shape [grid_h, grid_w], in mm.
  - centroid: Per-cell centroid of this wavelength, shape [grid_h, grid_w, 2]. This is always the cell's own centroid, even when an external center is supplied. Useful for passing as center to subsequent calls (e.g. in rms_map_rgb).

Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def rms_map(self, num_grid=32, depth=DEPTH, wvln=DEFAULT_WAVE, center=None):
    """Compute per-field-position RMS spot radius for a single wavelength.

    Traces ``SPP_PSF`` rays per grid cell and computes the root-mean-square
    distance of valid ray hits from a reference centroid.  When ``center``
    is ``None``, each cell uses its own centroid (monochromatic blur).
    When an external ``center`` is provided (e.g. the green-channel
    centroid), the RMS includes the chromatic shift from that reference.

    Algorithm:
        1. ``self.point_source_grid(normalized=False)`` generates physical
           object points on a ``[num_grid, num_grid]`` field grid.
        2. ``self.spot_points()`` samples ``SPP_PSF`` rays per point and
           traces to sensor.
        3. If ``center`` is ``None``, compute per-cell centroid
           ``c = mean(valid ray_xy)``; otherwise use the provided ``center``.
        4. ``RMS = sqrt( mean( ||ray_xy - c||^2 ) )``.

    Args:
        num_grid (int | tuple[int, int]): Spatial resolution of the field
            sampling grid. Defaults to 32.
        depth (float): Object distance in mm. Defaults to ``DEPTH``.
        wvln (float): Wavelength in micrometers. Defaults to ``DEFAULT_WAVE``.
        center (torch.Tensor | None): External reference centroid with shape
            ``[grid_h, grid_w, 2]``.  If ``None``, each cell's own
            centroid is used. Defaults to ``None``.

    Returns:
        tuple[torch.Tensor, torch.Tensor]:
            - **rms**: RMS spot error map, shape ``[grid_h, grid_w]``,
              in mm.
            - **centroid**: Per-cell centroid of this wavelength, shape
              ``[grid_h, grid_w, 2]``.  Always the cell's own centroid,
              even when ``center`` is given.  Useful for passing as
              ``center`` to subsequent calls (e.g. in ``rms_map_rgb``).
    """
    if isinstance(num_grid, int):
        num_grid = (num_grid, num_grid)

    # Generate physical grid points and trace rays to sensor
    points = self.point_source_grid(depth=depth, grid=num_grid, normalized=False)
    ray = self.spot_points(points, num_rays=SPP_PSF, wvln=wvln)

    # Reuse Ray.centroid() — shape [grid_h, grid_w, 3], slice to [grid_h, grid_w, 2]
    centroid = ray.centroid()[..., :2]

    # Use external center if provided, otherwise own centroid
    ref = center if center is not None else centroid

    # RMS relative to reference, shape [grid_h, grid_w]
    ray_xy = ray.o[..., :2]
    ray_valid = ray.is_valid
    rms = torch.sqrt(
        (((ray_xy - ref.unsqueeze(-2)) ** 2).sum(-1) * ray_valid).sum(-1)
        / (ray_valid.sum(-1) + EPSILON)
    )

    return rms, centroid

rms_map_rgb

rms_map_rgb(num_grid=32, depth=DEPTH)

Compute per-field-position RMS spot radius for R, G, B wavelengths.

The RMS spot radius is a standard measure of geometrical image quality. For each field position in a num_grid × num_grid grid, this method traces SPP_PSF rays per wavelength and computes the root-mean-square distance of valid ray hits from a common reference centroid.

The reference centroid is the green-channel centroid. Using a common reference means the returned RMS values include lateral chromatic aberration (the shift between R/G/B centroids), making the map useful as a polychromatic image-quality metric.

Algorithm
  1. Call rms_map(wvln=green) to get the green RMS map and the green centroid.
  2. Call rms_map(wvln=red, center=green_centroid) and rms_map(wvln=blue, center=green_centroid) to measure R/B blur relative to the green reference.
  3. Stack as [R, G, B].
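A standard identity shows why a common green reference folds lateral color into the red and blue values. For any set of points with centroid c and any reference r, mean ||x - r||² = mean ||x - c||² + ||c - r||². The NumPy check below uses synthetic hit positions (hypothetical data, not traced rays). It confirms that the RMS about the green centroid equals the red channel's own blur combined in quadrature with its centroid shift.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic sensor hits in mm: green spot at the origin, red spot shifted by 5 µm in x
green = rng.normal(0.0, 0.002, size=(1000, 2))
red = rng.normal(0.0, 0.003, size=(1000, 2)) + np.array([0.005, 0.0])


def rms_about(points, ref):
    # Root-mean-square distance of points from a reference position
    return np.sqrt(np.mean(np.sum((points - ref) ** 2, axis=1)))


green_centroid = green.mean(axis=0)
red_centroid = red.mean(axis=0)

rms_red_own = rms_about(red, red_centroid)        # monochromatic blur only
rms_red_common = rms_about(red, green_centroid)   # analogous to the R entry of rms_map_rgb
lateral_shift = np.linalg.norm(red_centroid - green_centroid)

# Blur and lateral color add in quadrature
print(rms_red_own, lateral_shift, rms_red_common)
print(np.isclose(rms_red_common, np.hypot(rms_red_own, lateral_shift)))
```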

Parameters:

Name Type Description Default
num_grid int

Spatial resolution of the field sampling grid. Defaults to 32.

32
depth float

Object distance in mm. Defaults to DEPTH.

DEPTH

Returns:

Type Description

torch.Tensor: RMS spot error map with shape [3, num_grid, num_grid] (channels ordered R, G, B). Units are mm (same as sensor coordinates).

Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def rms_map_rgb(self, num_grid=32, depth=DEPTH):
    """Compute per-field-position RMS spot radius for R, G, B wavelengths.

    The RMS spot radius is a standard measure of geometrical image quality.
    For each field position in a ``num_grid × num_grid`` grid, this method
    traces ``SPP_PSF`` rays per wavelength and computes the root-mean-square
    distance of valid ray hits from a **common** reference centroid.

    The reference centroid is the green-channel centroid.  Using a common
    reference means the returned RMS values include *lateral chromatic
    aberration* (the shift between R/G/B centroids), making the map useful
    as a polychromatic image-quality metric.

    Algorithm:
        1. Call ``rms_map(wvln=green)`` to get the green RMS map **and**
           the green centroid.
        2. Call ``rms_map(wvln=red, center=green_centroid)`` and
           ``rms_map(wvln=blue, center=green_centroid)`` to measure R/B
           blur relative to the green reference.
        3. Stack as ``[R, G, B]``.

    Args:
        num_grid (int): Spatial resolution of the field sampling grid.
            Defaults to 32.
        depth (float): Object distance in mm. Defaults to ``DEPTH``.

    Returns:
        torch.Tensor: RMS spot error map with shape ``[3, num_grid, num_grid]``
            (channels ordered R, G, B). Units are mm (same as sensor
            coordinates).
    """
    # Green first to obtain the shared reference centroid
    rms_g, green_centroid = self.rms_map(
        num_grid=num_grid, depth=depth, wvln=WAVE_RGB[1]
    )

    # Red and blue relative to the green centroid
    rms_r, _ = self.rms_map(
        num_grid=num_grid, depth=depth, wvln=WAVE_RGB[0], center=green_centroid
    )
    rms_b, _ = self.rms_map(
        num_grid=num_grid, depth=depth, wvln=WAVE_RGB[2], center=green_centroid
    )

    return torch.stack([rms_r, rms_g, rms_b], dim=0)

calc_distortion_radial

calc_distortion_radial(num_points=GEO_GRID, wvln=DEFAULT_WAVE, plane='meridional', ray_aiming=True)

Compute fractional distortion at evenly-spaced field angles along the meridional (default) or sagittal direction.

Distortion is defined as (h_actual - h_ideal) / h_ideal, where h_ideal = f * tan(theta) (rectilinear projection) and h_actual is the chief-ray image height on the sensor. A positive value means pincushion distortion; negative means barrel distortion.

This is the computational counterpart to draw_spot_radial: it samples num_points field angles uniformly from 0 to self.rfov and returns both the sampled angles and the corresponding distortion values, making it easy to pair with other radial evaluation functions.

Algorithm
  1. Derive rfov_deg from self.rfov (radians → degrees).
  2. Sample num_points field angles uniformly in [0, rfov_deg]. The on-axis sample (0°) is replaced by a tiny positive angle to avoid 0/0.
  3. Compute h_ideal = foclen * tan(angle) for each sample.
  4. Trace the chief ray (via calc_chief_ray_infinite) through the full lens to the sensor plane.
  5. Extract h_actual from the appropriate transverse coordinate (x for sagittal, y for meridional).
  6. Return (h_actual - h_ideal) / h_ideal.

Parameters:

Name Type Description Default
num_points int

Number of evenly-spaced field-angle samples from on-axis (0°) to full-field (self.rfov). Defaults to GEO_GRID.

GEO_GRID
wvln float

Wavelength in micrometers. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
plane str

'meridional' (y-axis) or 'sagittal' (x-axis). Defaults to 'meridional'.

'meridional'
ray_aiming bool

If True, the chief ray is aimed to pass through the center of the aperture stop (more accurate for wide-angle lenses). Defaults to True.

True

Returns:

Type Description

tuple[np.ndarray, np.ndarray]:
  - rfov_samples: Field angles in degrees, shape [num_points].
  - distortions: Fractional distortion at each angle, shape [num_points]. Dimensionless (multiply by 100 for percent).

Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def calc_distortion_radial(
    self,
    num_points=GEO_GRID,
    wvln=DEFAULT_WAVE,
    plane="meridional",
    ray_aiming=True,
):
    """Compute fractional distortion at evenly-spaced field angles along the meridional direction.

    Distortion is defined as ``(h_actual - h_ideal) / h_ideal``, where
    ``h_ideal = f * tan(theta)`` (rectilinear projection) and ``h_actual``
    is the chief-ray image height on the sensor.  A positive value means
    pincushion distortion; negative means barrel distortion.

    This is the computational counterpart to ``draw_spot_radial``: it
    samples ``num_points`` field angles uniformly from 0 to ``self.rfov``
    and returns both the sampled angles and the corresponding distortion
    values, making it easy to pair with other radial evaluation functions.

    Algorithm:
        1. Derive ``rfov_deg`` from ``self.rfov`` (radians → degrees).
        2. Sample ``num_points`` field angles uniformly in
           ``[0, rfov_deg]``.  The on-axis sample (0°) is replaced by a
           tiny positive angle to avoid 0/0.
        3. Compute ``h_ideal = foclen * tan(angle)`` for each sample.
        4. Trace the chief ray (via ``calc_chief_ray_infinite``) through the
           full lens to the sensor plane.
        5. Extract ``h_actual`` from the appropriate transverse coordinate
           (x for sagittal, y for meridional).
        6. Return ``(h_actual - h_ideal) / h_ideal``.

    Args:
        num_points (int): Number of evenly-spaced field-angle samples from
            on-axis (0°) to full-field (``self.rfov``).
            Defaults to ``GEO_GRID``.
        wvln (float): Wavelength in micrometers. Defaults to ``DEFAULT_WAVE``.
        plane (str): ``'meridional'`` (y-axis) or ``'sagittal'`` (x-axis).
            Defaults to ``'meridional'``.
        ray_aiming (bool): If ``True``, the chief ray is aimed to pass
            through the center of the aperture stop (more accurate for
            wide-angle lenses). Defaults to ``True``.

    Returns:
        tuple[np.ndarray, np.ndarray]:
            - **rfov_samples**: Field angles in degrees, shape ``[num_points]``.
            - **distortions**: Fractional distortion at each angle, shape
              ``[num_points]``.  Dimensionless (multiply by 100 for
              percent).
    """
    rfov_deg = float(self.rfov) * 180.0 / np.pi

    # Sample field angles uniformly from 0 to rfov_deg.
    # For the on-axis point (FOV=0), distortion is 0/0.  We compute it at a
    # tiny positive angle to obtain the correct limit, which may be non-zero
    # when the sensor is not at the paraxial focus.
    rfov_samples = torch.linspace(0, rfov_deg, num_points)
    rfov_compute = rfov_samples.clone()
    if rfov_compute[0] == 0:
        rfov_compute[0] = min(0.01, rfov_samples[1].item() * 0.01)

    # Ideal image height: h_ideal = f * tan(theta)
    eff_foclen = float(self.foclen)
    ideal_imgh = eff_foclen * np.tan(rfov_compute.numpy() * np.pi / 180)

    # Trace chief rays to the sensor plane
    chief_ray_o, chief_ray_d = self.calc_chief_ray_infinite(
        rfov=rfov_compute, wvln=wvln, plane=plane, ray_aiming=ray_aiming
    )
    ray = Ray(chief_ray_o, chief_ray_d, wvln=wvln, device=self.device)
    ray, _ = self.trace(ray)
    t = (self.d_sensor - ray.o[..., 2]) / ray.d[..., 2]

    # Actual image height from the appropriate transverse coordinate
    if plane == "sagittal":
        actual_imgh = (ray.o[..., 0] + ray.d[..., 0] * t).abs()
    elif plane == "meridional":
        actual_imgh = (ray.o[..., 1] + ray.d[..., 1] * t).abs()
    else:
        raise ValueError(f"Invalid plane: {plane}")

    actual_imgh = actual_imgh.cpu().numpy()

    # Fractional distortion, with safe handling of the on-axis singularity
    ideal_imgh = np.asarray(ideal_imgh)
    mask = np.abs(ideal_imgh) < EPSILON
    distortions = np.where(
        mask, 0.0, (actual_imgh - ideal_imgh) / np.where(mask, 1.0, ideal_imgh)
    )

    return rfov_samples.numpy(), distortions

draw_distortion_radial

draw_distortion_radial(save_name=None, num_points=GEO_GRID, wvln=DEFAULT_WAVE, plane='meridional', ray_aiming=True, show=False)

Draw distortion-vs-field-angle curve in Zemax style.

Produces a plot with field angle on the y-axis and percent distortion on the x-axis, matching the layout convention used in Zemax OpticStudio. Useful for quick visual assessment of barrel / pincushion distortion.

Algorithm
  1. Call calc_distortion_radial to obtain field angles and fractional distortion values.
  2. Convert distortion to percent and plot.

Parameters:

Name Type Description Default
save_name str | None

File path for the output PNG. If None, auto-generates './{plane}_distortion_inf.png'.

None
num_points int

Number of field-angle samples. Defaults to GEO_GRID.

GEO_GRID
wvln float

Wavelength in micrometers. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
plane str

'meridional' or 'sagittal'. Defaults to 'meridional'.

'meridional'
ray_aiming bool

Whether to use ray aiming for chief-ray computation. Defaults to True.

True
show bool

If True, display interactively. Defaults to False.

False
Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def draw_distortion_radial(
    self,
    save_name=None,
    num_points=GEO_GRID,
    wvln=DEFAULT_WAVE,
    plane="meridional",
    ray_aiming=True,
    show=False,
):
    """Draw distortion-vs-field-angle curve in Zemax style.

    Produces a plot with field angle on the y-axis and percent distortion
    on the x-axis, matching the layout convention used in Zemax OpticStudio.
    Useful for quick visual assessment of barrel / pincushion distortion.

    Algorithm:
        1. Call ``calc_distortion_radial`` to obtain field angles and
           fractional distortion values.
        2. Convert distortion to percent and plot.

    Args:
        save_name (str | None): File path for the output PNG.  If ``None``,
            auto-generates ``'./{plane}_distortion_inf.png'``.
        num_points (int): Number of field-angle samples.
            Defaults to ``GEO_GRID``.
        wvln (float): Wavelength in micrometers. Defaults to ``DEFAULT_WAVE``.
        plane (str): ``'meridional'`` or ``'sagittal'``.
            Defaults to ``'meridional'``.
        ray_aiming (bool): Whether to use ray aiming for chief-ray
            computation. Defaults to ``True``.
        show (bool): If ``True``, display interactively. Defaults to ``False``.
    """
    rfov_deg = float(self.rfov) * 180.0 / np.pi

    # Calculate distortion at evenly-spaced field angles
    rfov_samples, distortions = self.calc_distortion_radial(
        num_points=num_points, wvln=wvln, plane=plane, ray_aiming=ray_aiming
    )

    # Convert to percentage and handle NaN
    values = np.nan_to_num(distortions * 100, nan=0.0).tolist()

    # Create figure
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_title(f"{plane} Surface Distortion")

    # Draw distortion curve
    ax.plot(values, rfov_samples, linestyle="-", color="g", linewidth=1.5)

    # Draw reference line (vertical line)
    ax.axvline(x=0, color="k", linestyle="-", linewidth=0.8)

    # Set grid
    ax.grid(True, color="gray", linestyle="-", linewidth=0.5, alpha=1)

    # Dynamically adjust x-axis range
    value = max(abs(v) for v in values)
    margin = value * 0.2  # 20% margin
    x_min, x_max = -max(0.2, value + margin), max(0.2, value + margin)

    # Set ticks
    x_ticks = np.linspace(-value, value, 3)
    y_ticks = np.linspace(0, rfov_deg, 3)

    ax.set_xticks(x_ticks)
    ax.set_yticks(y_ticks)

    # Format tick labels
    x_labels = [f"{x:.1f}%" for x in x_ticks]
    y_labels = [f"{y:.1f}" for y in y_ticks]

    ax.set_xticklabels(x_labels)
    ax.set_yticklabels(y_labels)

    # Set axis labels
    ax.set_xlabel("Distortion (%)")
    ax.set_ylabel("Field of View (degrees)")

    # Set axis range
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(0, rfov_deg)

    if show:
        plt.show()
    else:
        if save_name is None:
            save_name = f"./{plane}_distortion_inf.png"
        plt.savefig(save_name, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

calc_distortion_map

calc_distortion_map(num_grid=16, depth=DEPTH, wvln=DEFAULT_WAVE)

Compute a 2-D distortion grid mapping ideal to actual image positions.

For each cell in a num_grid × num_grid field grid, rays are traced to the sensor and their centroid is computed. The centroid is then normalized to [-1, 1] sensor coordinates, producing a map that shows how each ideal image point is displaced by lens distortion.

This map can be used with torch.nn.functional.grid_sample to warp or unwarp rendered images.
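The following is a rough sketch of how a coarse [num_grid, num_grid, 2] map could drive torch.nn.functional.grid_sample. It does not call DeepLens. A synthetic radial map (x, y)·(1 + k·r²) stands in for the output of calc_distortion_map, and the map is bilinearly upsampled to image resolution before resampling. grid_sample expects (x, y) ordering in its last dimension, with x first, matching the (dx, dy) entries documented for calc_distortion_map. The sign and orientation conventions of a real map should still be verified before use.

```python
import torch
import torch.nn.functional as F


def synthetic_distortion_grid(num_grid, k):
    # Ideal normalized positions in [-1, 1]; last dim is (x, y) as grid_sample expects
    lin = torch.linspace(-1.0, 1.0, num_grid)
    yy, xx = torch.meshgrid(lin, lin, indexing="ij")
    ideal = torch.stack((xx, yy), dim=-1)                    # [G, G, 2]
    r2 = (ideal ** 2).sum(-1, keepdim=True)
    return ideal * (1.0 + k * r2)                             # hypothetical radial model


def resample_with_grid(img, coarse_grid):
    """img: [N, C, H, W]; coarse_grid: [G, G, 2] in normalized coordinates."""
    n, _, h, w = img.shape
    # Upsample the coarse map so there is one sample position per output pixel
    grid = coarse_grid.permute(2, 0, 1).unsqueeze(0)          # [1, 2, G, G]
    grid = F.interpolate(grid, size=(h, w), mode="bilinear", align_corners=True)
    grid = grid.permute(0, 2, 3, 1).repeat(n, 1, 1, 1)        # [N, H, W, 2]
    return F.grid_sample(img, grid, mode="bilinear", padding_mode="zeros", align_corners=True)


img = torch.rand(1, 3, 64, 64)
identity = resample_with_grid(img, synthetic_distortion_grid(16, k=0.0))   # reproduces img
warped = resample_with_grid(img, synthetic_distortion_grid(16, k=-0.1))
```

Upsampling the coarse map keeps the expensive ray tracing at 16×16 while still giving grid_sample a dense sampling field. Bilinear upsampling is a simple choice that is exact for the identity map.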

Parameters:

Name Type Description Default
num_grid int

Grid resolution along each axis. Defaults to 16.

16
depth float

Object distance in mm. Defaults to DEPTH.

DEPTH
wvln float

Wavelength in micrometers. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE

Returns:

Type Description

torch.Tensor: Distortion grid with shape [num_grid, num_grid, 2]. Each entry (dx, dy) is in normalized sensor coordinates [-1, 1], representing the actual centroid position for the corresponding ideal grid position.

Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def calc_distortion_map(self, num_grid=16, depth=DEPTH, wvln=DEFAULT_WAVE):
    """Compute a 2-D distortion grid mapping ideal to actual image positions.

    For each cell in a ``num_grid × num_grid`` field grid, rays are traced
    to the sensor and their centroid is computed.  The centroid is then
    normalized to ``[-1, 1]`` sensor coordinates, producing a map that
    shows how each ideal image point is displaced by lens distortion.

    This map can be used with ``torch.nn.functional.grid_sample`` to warp
    or unwarp rendered images.

    Args:
        num_grid (int): Grid resolution along each axis. Defaults to 16.
        depth (float): Object distance in mm. Defaults to ``DEPTH``.
        wvln (float): Wavelength in micrometers. Defaults to ``DEFAULT_WAVE``.

    Returns:
        torch.Tensor: Distortion grid with shape ``[num_grid, num_grid, 2]``.
            Each entry ``(dx, dy)`` is in normalized sensor coordinates
            ``[-1, 1]``, representing the actual centroid position for the
            corresponding ideal grid position.
    """
    # Sample and trace rays, shape (grid_size, grid_size, num_rays, 3)
    ray = self.sample_grid_rays(depth=depth, num_grid=num_grid, wvln=wvln, uniform_fov=False)
    ray = self.trace2sensor(ray)

    # Calculate centroid of the rays, shape (grid_size, grid_size, 2)
    ray_xy = ray.centroid()[..., :2]
    x_dist = -ray_xy[..., 0] / self.sensor_size[1] * 2
    y_dist = ray_xy[..., 1] / self.sensor_size[0] * 2
    distortion_grid = torch.stack((x_dist, y_dist), dim=-1)
    return distortion_grid

distortion_center

distortion_center(points)

Compute the distorted image centroid for arbitrary normalized object points.

Given object points in normalized coordinates, this method converts them to physical object-space positions, traces rays from each point through the lens, and returns the ray centroid on the sensor in normalized [-1, 1] coordinates. This is the inverse mapping needed for distortion correction (unwarping).

Algorithm
  1. Convert normalized (x, y) ∈ [-1, 1] to physical object-space positions using self.calc_scale(depth) and self.sensor_size.
  2. self.sample_from_points() generates rays from each point.
  3. self.trace2sensor() propagates rays.
  4. Compute centroid and normalize back to [-1, 1].
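The coordinate bookkeeping in steps 1 and 4 can be illustrated with an ideal pinhole camera, for which the answer is known. With no distortion, the output should equal the input (x, y). The sketch below makes three assumptions: `pinhole_scale(depth) = |depth| / f` is a hypothetical stand-in for calc_scale, the sensor measures sensor_w × sensor_h mm, and the image is inverted through the pinhole. The inversion is why the centroid is negated before normalizing. All numbers are illustrative.

```python
import numpy as np

foclen = 5.0                     # mm, hypothetical
sensor_w, sensor_h = 6.0, 4.0    # mm, hypothetical


def pinhole_scale(depth):
    # Hypothetical stand-in for calc_scale: object-space mm per image-space mm
    return np.abs(depth) / foclen


def pinhole_distortion_center(points):
    """points: [..., 3] with normalized x, y in [-1, 1] and depth z < 0 in mm."""
    x_n, y_n, depth = points[..., 0], points[..., 1], points[..., 2]
    scale = pinhole_scale(depth)
    # Step 1: normalized field position -> physical object-space position
    x_obj = x_n * scale * sensor_w / 2
    y_obj = y_n * scale * sensor_h / 2
    # Ideal pinhole projection; the image is inverted, so the signs flip
    x_img = -foclen * x_obj / np.abs(depth)
    y_img = -foclen * y_obj / np.abs(depth)
    # Step 4: negate to undo the inversion, then normalize to [-1, 1]
    return np.stack((-x_img / (sensor_w / 2), -y_img / (sensor_h / 2)), axis=-1)


pts = np.array([[0.0, 0.0, -1000.0], [0.5, -0.25, -1000.0], [1.0, 1.0, -500.0]])
print(pinhole_distortion_center(pts))   # equals pts[:, :2] up to rounding
```

With a real lens, any deviation from this identity is the distortion that the traced centroid captures.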

Parameters:

Name Type Description Default
points Tensor

Normalized point source positions with shape [N, 3] or [..., 3]. x, y ∈ [-1, 1] encode the field position; z ∈ (-∞, 0] is the object depth in mm.

required

Returns:

Type Description

torch.Tensor: Normalized distortion centroid positions with shape [N, 2] or [..., 2]. x, y ∈ [-1, 1].

Source code in deeplens/optics/geolens_pkg/eval.py
def distortion_center(self, points):
    """Compute the distorted image centroid for arbitrary normalized object points.

    Given object points in normalized coordinates, this method converts them
    to physical object-space positions, traces rays from each point through
    the lens, and returns the ray centroid on the sensor in normalized
    ``[-1, 1]`` coordinates.  This is the inverse mapping needed for
    distortion correction (unwarping).

    Algorithm:
        1. Convert normalized ``(x, y)`` ∈ [-1, 1] to physical object-space
           positions using ``self.calc_scale(depth)`` and ``self.sensor_size``.
        2. ``self.sample_from_points()`` generates rays from each point.
        3. ``self.trace2sensor()`` propagates rays.
        4. Compute centroid and normalize back to ``[-1, 1]``.

    Args:
        points (torch.Tensor): Normalized point source positions with shape
            ``[N, 3]`` or ``[..., 3]``.  ``x, y`` ∈ [-1, 1] encode the
            field position; ``z`` ∈ (-∞, 0] is the object depth in mm.

    Returns:
        torch.Tensor: Normalized distortion centroid positions with shape
            ``[N, 2]`` or ``[..., 2]``.  ``x, y`` ∈ [-1, 1].
    """
    sensor_w, sensor_h = self.sensor_size

    # Convert normalized points to object space coordinates
    depth = points[..., 2]
    scale = self.calc_scale(depth)
    points_obj_x = points[..., 0] * scale * sensor_w / 2
    points_obj_y = points[..., 1] * scale * sensor_h / 2
    points_obj = torch.stack([points_obj_x, points_obj_y, depth], dim=-1)

    # Sample rays and trace to sensor
    ray = self.sample_from_points(points=points_obj)
    ray = self.trace2sensor(ray)

    # Calculate centroid and normalize to [-1, 1]
    ray_center = -ray.centroid()  # shape [..., 3]
    distortion_center_x = ray_center[..., 0] / (sensor_w / 2)
    distortion_center_y = ray_center[..., 1] / (sensor_h / 2)
    distortion_center = torch.stack((distortion_center_x, distortion_center_y), dim=-1)
    return distortion_center

draw_distortion_map

draw_distortion_map(save_name=None, num_grid=16, depth=DEPTH, wvln=DEFAULT_WAVE, show=False)

Draw a scatter plot of the distortion grid.

Visualizes the output of calc_distortion_map() as a scatter plot on [-1, 1] normalized sensor coordinates. An undistorted lens would show a perfect rectilinear grid; deviations reveal barrel or pincushion distortion.
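A one-parameter radial model, r_d = r·(1 + k·r²), applied to an ideal grid in normalized coordinates previews what the scatter plot looks like for each kind of distortion. This model and its coefficient values are assumptions for illustration only; DeepLens traces real rays instead. With k < 0, corner points move inward (barrel). With k > 0, they move outward (pincushion).

```python
import numpy as np


def radial_distortion_grid(num_grid, k):
    # Ideal grid in normalized sensor coordinates [-1, 1], last dim (x, y)
    lin = np.linspace(-1.0, 1.0, num_grid)
    xx, yy = np.meshgrid(lin, lin)
    r2 = xx**2 + yy**2
    factor = 1.0 + k * r2                      # hypothetical one-term radial model
    return np.stack((xx * factor, yy * factor), axis=-1)


barrel = radial_distortion_grid(16, k=-0.1)
pincushion = radial_distortion_grid(16, k=0.1)

# The corner point (1, 1) has r^2 = 2, so it lands at 0.8 (barrel) or 1.2 (pincushion)
print(barrel[0, 0], barrel[-1, -1], pincushion[-1, -1])
```

Scatter-plotting the x and y components of either array gives the pinched or bulging grid pattern that this plot is meant to reveal.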

Parameters:

Name Type Description Default
save_name str | None

File path for the output PNG. If None, auto-generates './distortion_{depth}.png'.

None
num_grid int

Grid resolution per axis. Defaults to 16.

16
depth float

Object distance in mm. Defaults to DEPTH.

DEPTH
wvln float

Wavelength in micrometers. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
show bool

If True, display interactively. Defaults to False.

False
Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def draw_distortion_map(
    self, save_name=None, num_grid=16, depth=DEPTH, wvln=DEFAULT_WAVE, show=False
):
    """Draw a scatter plot of the distortion grid.

    Visualizes the output of ``calc_distortion_map()`` as a scatter plot on
    ``[-1, 1]`` normalized sensor coordinates.  An undistorted lens would
    show a perfect rectilinear grid; deviations reveal barrel or pincushion
    distortion.

    Args:
        save_name (str | None): File path for the output PNG.  If ``None``,
            auto-generates ``'./distortion_{depth}.png'``.
        num_grid (int): Grid resolution per axis. Defaults to 16.
        depth (float): Object distance in mm. Defaults to ``DEPTH``.
        wvln (float): Wavelength in micrometers. Defaults to ``DEFAULT_WAVE``.
        show (bool): If ``True``, display interactively. Defaults to ``False``.
    """
    # Ray tracing to calculate distortion map
    distortion_grid = self.calc_distortion_map(num_grid=num_grid, depth=depth, wvln=wvln)
    x1 = distortion_grid[..., 0].cpu().numpy()
    y1 = distortion_grid[..., 1].cpu().numpy()

    # Draw image
    fig, ax = plt.subplots()
    ax.set_title("Lens distortion")
    ax.scatter(x1, y1, s=2)
    ax.axis("scaled")
    ax.grid(True)

    # Add grid lines based on grid_size
    ax.set_xticks(np.linspace(-1, 1, num_grid))
    ax.set_yticks(np.linspace(-1, 1, num_grid))

    if show:
        plt.show()
    else:
        depth_str = "inf" if depth == float("inf") else f"{-depth}mm"
        if save_name is None:
            save_name = f"./distortion_{depth_str}.png"
        plt.savefig(save_name, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

mtf

mtf(fov, wvln=DEFAULT_WAVE)

Compute the geometric MTF at a single field position.

The Modulation Transfer Function describes how well the lens preserves contrast as a function of spatial frequency. MTF equals 1 at zero frequency (full contrast) and falls toward 0 at high frequencies, e.g. near the diffraction limit or the Nyquist frequency of the sensor.

This implementation uses the geometric (ray-based) approach:
  1. Compute the PSF at the given field position via self.psf().
  2. Convert PSF → MTF via psf2mtf() (project onto tangential and sagittal axes, then take the magnitude of the 1-D FFT).

Tangential MTF captures resolution in the meridional (radial) direction; sagittal MTF captures resolution perpendicular to it. The difference between the two indicates astigmatism.
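The sketch below shows how a tangential/sagittal gap arises. It builds a synthetic elliptical Gaussian PSF that is wider along y (rows, the tangential direction in the psf2mtf convention) than along x. It then reduces the PSF to two MTF curves with the same sum, rfft, and normalize-by-DC steps. This is a self-contained NumPy illustration with made-up pixel pitch and blur widths, not a call into DeepLens. In a real lens, such a gap is the astigmatism signature described above.

```python
import numpy as np

pixel_size = 0.005            # mm, hypothetical pixel pitch
n = 64
coords = (np.arange(n) - n // 2) * pixel_size
yy, xx = np.meshgrid(coords, coords, indexing="ij")

# Elliptical Gaussian PSF: blur along y (tangential) is twice that along x (sagittal)
sigma_x, sigma_y = 0.006, 0.012   # mm, hypothetical
psf = np.exp(-(xx**2 / (2 * sigma_x**2) + yy**2 / (2 * sigma_y**2)))
psf /= psf.sum()

# Line spread functions, then one-sided |FFT| normalized by the DC term
mtf_tan = np.abs(np.fft.rfft(psf.sum(axis=1)))   # summed over x -> function of y
mtf_sag = np.abs(np.fft.rfft(psf.sum(axis=0)))   # summed over y -> function of x
mtf_tan, mtf_sag = mtf_tan / mtf_tan[0], mtf_sag / mtf_sag[0]
freq = np.fft.rfftfreq(n, d=pixel_size)          # cycles/mm

# At 50 cycles/mm the wider tangential blur loses more contrast
idx = np.argmin(np.abs(freq - 50.0))
print(freq[idx], mtf_tan[idx], mtf_sag[idx])
```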

Parameters:

Name Type Description Default
fov float

Field angle in radians, in the same units as self.rfov (0 = on-axis, self.rfov = full field). Internally normalized to fov / self.rfov and mapped to the point [0, -fov/rfov, DEPTH].

required
wvln float

Wavelength in micrometers. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE

Returns:

Type Description

tuple[np.ndarray, np.ndarray, np.ndarray]:
  - freq: Spatial frequency axis in cycles/mm (positive frequencies only, excluding DC).
  - mtf_tan: Tangential (meridional) MTF values, normalized so that MTF → 1 at low frequency.
  - mtf_sag: Sagittal MTF values, same normalization.

Source code in deeplens/optics/geolens_pkg/eval.py
def mtf(self, fov, wvln=DEFAULT_WAVE):
    """Compute the geometric MTF at a single field position.

    The *Modulation Transfer Function* describes how well the lens preserves
    contrast as a function of spatial frequency.  MTF = 1 at low frequencies
    (perfect contrast) and falls toward 0 near the diffraction limit or the
    Nyquist frequency of the sensor.

    This implementation uses the *geometric* (ray-based) approach:
        1. Compute the PSF at the given field position via ``self.psf()``.
        2. Convert PSF → MTF via ``psf2mtf()`` (project onto tangential and
           sagittal axes, then take the magnitude of the 1-D FFT).

    Tangential MTF captures resolution in the meridional (radial) direction;
    sagittal MTF captures resolution perpendicular to it.  The difference
    between the two indicates astigmatism.

    Args:
        fov (float): Field angle in radians, in the same units as
            ``self.rfov`` (0 = on-axis, ``self.rfov`` = full field).
            Internally normalized to ``fov / self.rfov`` and mapped to the
            point ``[0, -fov/rfov, DEPTH]``.
        wvln (float): Wavelength in micrometers. Defaults to ``DEFAULT_WAVE``.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]:
            - **freq**: Spatial frequency axis in cycles/mm (positive
              frequencies only, excluding DC).
            - **mtf_tan**: Tangential (meridional) MTF values, normalized
              so that MTF → 1 at low frequency.
            - **mtf_sag**: Sagittal MTF values, same normalization.
    """
    point = [0, -fov / self.rfov, DEPTH]
    psf = self.psf(points=point, recenter=True, wvln=wvln)
    freq, mtf_tan, mtf_sag = self.psf2mtf(psf, pixel_size=self.pixel_size)
    return freq, mtf_tan, mtf_sag

psf2mtf staticmethod

psf2mtf(psf, pixel_size)

Convert a 2-D point-spread function to tangential and sagittal MTF curves.

The MTF is the magnitude of the optical transfer function (OTF), which is the Fourier transform of the PSF. For separable 1-D analysis:
  1. Integrate the PSF along the x-axis → tangential line-spread function (LSF_tan).
  2. Integrate the PSF along the y-axis → sagittal LSF_sag.
  3. Take |FFT(LSF)| and normalize by the DC component so that MTF(0) = 1.

Only positive frequencies (excluding DC) are returned, following the convention used in Zemax MTF plots.
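A Gaussian PSF is a quick sanity check for the LSF → FFT → normalize pipeline, because its MTF has the closed form exp(-2π²σ²f²). The NumPy sketch below reimplements the steps listed above; it is not the DeepLens function itself. It compares the result with that formula. The pixel pitch and σ are arbitrary illustrative values, with σ two pixels wide so that sampling effects are negligible.

```python
import numpy as np


def lsf_mtf(psf, pixel_size):
    """Tangential/sagittal MTF of a square 2-D PSF via line spread functions."""
    psf = np.asarray(psf, dtype=float)
    lsf_tan = psf.sum(axis=1)                # integrate over x -> function of y
    lsf_sag = psf.sum(axis=0)                # integrate over y -> function of x
    mtf_tan = np.abs(np.fft.rfft(lsf_tan))
    mtf_sag = np.abs(np.fft.rfft(lsf_sag))
    mtf_tan, mtf_sag = mtf_tan / mtf_tan[0], mtf_sag / mtf_sag[0]
    freq = np.fft.rfftfreq(lsf_sag.size, d=pixel_size)   # cycles/mm
    keep = freq > 0                           # drop DC, following the convention above
    return freq[keep], mtf_tan[keep], mtf_sag[keep]


pixel_size, n, sigma = 0.005, 64, 0.01       # mm, pixels, mm (sigma = 2 px)
c = (np.arange(n) - n // 2) * pixel_size
yy, xx = np.meshgrid(c, c, indexing="ij")
psf = np.exp(-(xx**2 + yy**2) / (2 * sigma**2))

freq, mtf_tan, mtf_sag = lsf_mtf(psf, pixel_size)
analytic = np.exp(-2 * np.pi**2 * sigma**2 * freq**2)
print(np.max(np.abs(mtf_tan - analytic)))    # tiny: matches the closed form
```

With n = 64, the returned frequency axis has 32 entries (n // 2), running up to the Nyquist frequency 0.5 / pixel_size = 100 cycles/mm.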

Parameters:

Name Type Description Default
psf Tensor | ndarray

2-D PSF with shape [H, W]. The array's y-axis (rows) corresponds to the tangential (meridional) direction; x-axis (columns) to the sagittal direction.

required
pixel_size float

Pixel pitch in mm. Determines the frequency axis scaling: Nyquist = 0.5 / pixel_size cycles/mm.

required

Returns:

Type Description

tuple[np.ndarray, np.ndarray, np.ndarray]:
  - freq: Spatial frequency in cycles/mm (positive, excluding DC). The axis is computed from the PSF width W, so a square PSF (H == W) is expected; the length is then H // 2.
  - mtf_tan: Tangential MTF, normalized to 1 at DC.
  - mtf_sag: Sagittal MTF, normalized to 1 at DC.

References
  • https://en.wikipedia.org/wiki/Optical_transfer_function
  • Edmund Optics: Introduction to Modulation Transfer Function.
Source code in deeplens/optics/geolens_pkg/eval.py
@staticmethod
def psf2mtf(psf, pixel_size):
    """Convert a 2-D point-spread function to tangential and sagittal MTF curves.

    The MTF is the magnitude of the optical transfer function (OTF), which
    is the Fourier transform of the PSF.  For separable 1-D analysis:
        1. Integrate the PSF along the x-axis → *tangential* line-spread
           function (LSF_tan).
        2. Integrate the PSF along the y-axis → *sagittal* LSF_sag.
        3. Take ``|FFT(LSF)|`` and normalize by the DC component so that
           MTF(0) = 1.

    Only positive frequencies (excluding DC) are returned, following the
    convention used in Zemax MTF plots.

    Args:
        psf (torch.Tensor | np.ndarray): 2-D PSF with shape ``[H, W]``.
            The array's y-axis (rows) corresponds to the **tangential**
            (meridional) direction; x-axis (columns) to the **sagittal**
            direction.
        pixel_size (float): Pixel pitch in mm.  Determines the frequency
            axis scaling: ``Nyquist = 0.5 / pixel_size`` cycles/mm.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray]:
            - **freq**: Spatial frequency in cycles/mm (positive, excluding
              DC).  Length is roughly ``H // 2``.
            - **mtf_tan**: Tangential MTF, normalized to 1 at DC.
            - **mtf_sag**: Sagittal MTF, normalized to 1 at DC.

    References:
        - https://en.wikipedia.org/wiki/Optical_transfer_function
        - Edmund Optics: Introduction to Modulation Transfer Function.
    """
    # Convert to numpy (supports torch tensors and numpy arrays)
    try:
        psf_np = psf.detach().cpu().numpy()
    except AttributeError:
        try:
            psf_np = psf.cpu().numpy()
        except AttributeError:
            psf_np = np.asarray(psf)

    # Compute line spread functions (integrate PSF over orthogonal axes)
    # y-axis corresponds to tangential; x-axis corresponds to sagittal
    lsf_sagittal = psf_np.sum(axis=0)  # function of x
    lsf_tangential = psf_np.sum(axis=1)  # function of y

    # One-sided spectra (for real inputs)
    mtf_sag = np.abs(np.fft.rfft(lsf_sagittal))
    mtf_tan = np.abs(np.fft.rfft(lsf_tangential))

    # Normalize by DC to ensure MTF(0) == 1
    dc_sag = mtf_sag[0] if mtf_sag.size > 0 else 1.0
    dc_tan = mtf_tan[0] if mtf_tan.size > 0 else 1.0
    if dc_sag != 0:
        mtf_sag = mtf_sag / dc_sag
    if dc_tan != 0:
        mtf_tan = mtf_tan / dc_tan

    # Frequency axis in cycles/mm (one-sided)
    fx = np.fft.rfftfreq(lsf_sagittal.size, d=pixel_size)
    freq = fx
    positive_freq_idx = freq > 0

    return (
        freq[positive_freq_idx],
        mtf_tan[positive_freq_idx],
        mtf_sag[positive_freq_idx],
    )

draw_mtf

draw_mtf(save_name='./lens_mtf.png', relative_fov_list=[0.0, 0.7, 1.0], depth_list=[DEPTH], psf_ks=128, show=False)

Draw a grid of tangential MTF curves for multiple depths and field positions.

Produces a len(depth_list) × len(relative_fov_list) subplot grid. Each subplot shows the tangential MTF for R, G, B wavelengths plus a vertical line at the sensor Nyquist frequency (0.5 / pixel_size cycles/mm).

Algorithm per subplot
  1. Compute the RGB PSF via self.psf_rgb() at the specified (depth, relative_fov) with kernel size psf_ks.
  2. For each wavelength channel, call psf2mtf() to obtain the tangential MTF curve.
  3. Plot frequency vs MTF with RGB coloring.
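Each subplot marks the sensor Nyquist frequency with a dotted line. When the MTF value there is needed as a number rather than read off the plot, the freq and mtf_tan arrays returned by psf2mtf() can be interpolated at 0.5 / pixel_size. A minimal sketch follows; mtf_at_nyquist is a hypothetical helper, not part of DeepLens.

```python
import numpy as np


def mtf_at_nyquist(freq, mtf, pixel_size):
    """Linearly interpolate an MTF curve at the sensor Nyquist frequency (hypothetical helper)."""
    nyquist = 0.5 / pixel_size  # cycles/mm when pixel_size is in mm
    freq = np.asarray(freq, dtype=np.float64)
    mtf = np.asarray(mtf, dtype=np.float64)
    # np.interp needs increasing freq; queries beyond the sampled range are clamped to the end values
    return nyquist, float(np.interp(nyquist, freq, mtf))


# Usage with a synthetic, linearly decaying MTF curve
freq = np.linspace(1.0, 100.0, 100)
mtf = 1.0 - freq / 200.0
nyq, value = mtf_at_nyquist(freq, mtf, pixel_size=0.01)
print(nyq, value)
```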

Parameters:

  • save_name (str, default './lens_mtf.png'): File path for the output PNG. Must end with .png; this is asserted when show is False.
  • relative_fov_list (list[float], default [0.0, 0.7, 1.0]): Relative field positions in [0, 1], where 0 = on-axis and 1 = full field.
  • depth_list (list[float], default [DEPTH]): Object distances in mm. float('inf') is automatically replaced by DEPTH.
  • psf_ks (int, default 128): PSF kernel size in pixels. Controls the frequency resolution of the resulting MTF.
  • show (bool, default False): If True, display the figure interactively instead of saving it.
Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def draw_mtf(
    self,
    save_name="./lens_mtf.png",
    relative_fov_list=[0.0, 0.7, 1.0],
    depth_list=[DEPTH],
    psf_ks=128,
    show=False,
):
    """Draw a grid of tangential MTF curves for multiple depths and field positions.

    Produces a ``len(depth_list) × len(relative_fov_list)`` subplot grid.
    Each subplot shows the tangential MTF for R, G, B wavelengths plus a
    vertical line at the sensor Nyquist frequency
    (``0.5 / pixel_size`` cycles/mm).

    Algorithm per subplot:
        1. Compute the RGB PSF via ``self.psf_rgb()`` at the specified
           ``(depth, relative_fov)`` with kernel size ``psf_ks``.
        2. For each wavelength channel, call ``psf2mtf()`` to obtain the
           tangential MTF curve.
        3. Plot frequency vs MTF with RGB coloring.

    Args:
        save_name (str): File path for the output PNG.
            Defaults to ``'./lens_mtf.png'``.
        relative_fov_list (list[float]): Relative field positions in
            ``[0, 1]``, where 0 = on-axis and 1 = full field.
            Defaults to ``[0.0, 0.7, 1.0]``.
        depth_list (list[float]): Object distances in mm.
            ``float('inf')`` is automatically replaced by ``DEPTH``.
            Defaults to ``[DEPTH]``.
        psf_ks (int): PSF kernel size in pixels (controls frequency
            resolution of the resulting MTF). Defaults to 128.
        show (bool): If ``True``, display interactively. Defaults to ``False``.
    """
    pixel_size = self.pixel_size
    nyquist_freq = 0.5 / pixel_size
    num_fovs = len(relative_fov_list)
    if float("inf") in depth_list:
        depth_list = [DEPTH if x == float("inf") else x for x in depth_list]
    num_depths = len(depth_list)

    # Create figure and subplots (num_depths * num_fovs subplots)
    fig, axs = plt.subplots(
        num_depths, num_fovs, figsize=(num_fovs * 3, num_depths * 3), squeeze=False
    )

    # Iterate over depth and field of view
    for depth_idx, depth in enumerate(depth_list):
        for fov_idx, fov_relative in enumerate(relative_fov_list):
            # Calculate rgb PSF
            point = [0, -fov_relative, depth]
            psf_rgb = self.psf_rgb(points=point, ks=psf_ks, recenter=False)

            # Calculate MTF curves for rgb wavelengths
            for wvln_idx, wvln in enumerate(WAVE_RGB):
                # Calculate MTF curves from PSF
                psf = psf_rgb[wvln_idx]
                freq, mtf_tan, _ = self.psf2mtf(psf, pixel_size)

                # Plot MTF curves
                ax = axs[depth_idx, fov_idx]
                color = RGB_COLORS[wvln_idx % len(RGB_COLORS)]
                wvln_label = RGB_LABELS[wvln_idx % len(RGB_LABELS)]
                wvln_nm = int(wvln * 1000)
                ax.plot(
                    freq,
                    mtf_tan,
                    color=color,
                    label=f"{wvln_label}({wvln_nm}nm)-Tan",
                )

            # Draw Nyquist frequency
            ax.axvline(
                x=nyquist_freq,
                color="k",
                linestyle=":",
                linewidth=1.2,
                label="Nyquist",
            )

            # Set title and label for subplot
            fov_deg = round(fov_relative * self.rfov * 180 / np.pi, 1)
            depth_str = "inf" if depth == float("inf") else f"{depth}"
            ax.set_title(f"FOV: {fov_deg}deg, Depth: {depth_str}mm", fontsize=8)
            ax.set_xlabel("Spatial Frequency [cycles/mm]", fontsize=8)
            ax.set_ylabel("MTF", fontsize=8)
            ax.legend(fontsize=6)
            ax.tick_params(axis="both", which="major", labelsize=7)
            ax.grid(True)
            ax.set_ylim(0, 1.05)

    plt.tight_layout()
    if show:
        plt.show()
    else:
        assert save_name.endswith(".png"), "save_name must end with .png"
        plt.savefig(save_name, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

draw_field_curvature

draw_field_curvature(save_name=None, num_points=64, z_span=1.0, z_steps=201, wvln_list=WAVE_RGB, spp=256, show=False)

Draw field curvature: best-focus defocus (Δz) vs field angle for RGB.

Field curvature (Petzval curvature) causes off-axis image points to focus on a curved surface rather than on the flat sensor. This method finds, at each field angle, the axial position that minimizes the tangential (y) RMS spot size. It then plots that position's deviation from the nominal sensor plane.

Algorithm (fully vectorized per wavelength):
  1. Construct a meridional ray fan at num_points field angles, each with spp rays spanning the entrance pupil.
  2. Trace all rays through the lens in a single batched call.
  3. For each of z_steps defocus planes within ±z_span mm of self.d_sensor, propagate the rays analytically (linear extension) and compute the variance of the y-coordinate over valid rays.
  4. Take the defocus with minimum variance as the best-focus plane, refined to sub-grid precision by parabolic interpolation over the three-point neighborhood. Field angles with fewer than two valid rays are marked NaN and appear as gaps in the plot.
  5. Repeat for each wavelength and overlay the R/G/B curves on a single plot.
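Steps 3 and 4 amount to a variance sweep followed by a three-point parabolic fit. The NumPy sketch below applies them to a single ray fan. It makes several simplifications:

  • best_focus_z is a hypothetical helper.
  • It ignores vignetting: every ray is treated as valid.
  • It assumes the rays have already left the last surface.

For the synthetic fan, the spread is exactly quadratic in z, so the parabolic step recovers the true focus even though that focus falls between grid planes.

```python
import numpy as np


def best_focus_z(oy, dy, oz, dz, z_grid):
    """Best tangential focus of a meridional ray fan (sketch; all rays treated as valid).

    oy, oz, dy, dz: shape [N] y/z components of ray origins and directions.
    z_grid: shape [Z] uniformly spaced candidate focus planes.
    """
    # Linear extension of every ray to every plane: y(z) = oy + dy * (z - oz) / dz
    t = (z_grid[None, :] - oz[:, None]) / dz[:, None]  # [N, Z]
    y = oy[:, None] + dy[:, None] * t  # [N, Z]
    spread = y.var(axis=0)  # [Z] variance of ray heights at each plane

    # Grid minimum, kept away from the ends so both neighbors exist
    i = int(np.clip(np.argmin(spread), 1, len(z_grid) - 2))
    y_l, y_c, y_r = spread[i - 1], spread[i], spread[i + 1]

    # Vertex of the parabola through the three samples, as a fractional index offset
    shift = (y_l - y_r) / (2.0 * (y_l - 2.0 * y_c + y_r))
    shift = float(np.clip(shift, -0.5, 0.5))
    return z_grid[i] + shift * (z_grid[1] - z_grid[0])


# Usage: rays leaving z = 0 that all converge to y = 0 at z = 10.33 mm
z_focus = 10.33
oy = np.linspace(-1.0, 1.0, 51)
oz = np.zeros_like(oy)
length = np.hypot(oy, z_focus)
dy, dz = -oy / length, z_focus / length
z_grid = np.linspace(9.0, 11.0, 21)  # 0.1 mm steps
print(best_focus_z(oy, dy, oz, dz, z_grid))
```

With 0.1 mm grid steps the raw argmin would report 10.3 mm; the parabolic refinement recovers the remaining 0.03 mm.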

Parameters:

  • save_name (str | None, default None): File path for the output PNG. If None, defaults to './field_curvature.png'.
  • num_points (int, default 64): Number of field-angle samples from 0 to self.rfov.
  • z_span (float, default 1.0): Half-range of the defocus sweep in mm. If the best focus of any field angle lands on the sweep boundary, a warning is printed.
  • z_steps (int, default 201): Number of uniformly spaced defocus planes within ±z_span. Higher values give finer axial resolution.
  • wvln_list (list[float], default WAVE_RGB): Wavelengths in µm.
  • spp (int, default 256): Rays per field point, sampled uniformly across the entrance pupil in the meridional plane.
  • show (bool, default False): If True, display the figure interactively instead of saving it.
Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def draw_field_curvature(
    self,
    save_name=None,
    num_points=64,
    z_span=1.0,
    z_steps=201,
    wvln_list=WAVE_RGB,
    spp=256,
    show=False,
):
    """Draw field curvature: best-focus defocus (Δz) vs field angle for RGB.

    *Field curvature* (Petzval curvature) causes off-axis image points to
    focus on a curved surface rather than the flat sensor.  This method
    finds the axial position of minimum RMS spot size at each field angle
    and plots the deviation from the nominal sensor plane.

    Algorithm (fully vectorized per wavelength):
        1. Construct a meridional ray fan at ``num_points`` field angles,
           each with ``spp`` rays spanning the entrance pupil.
        2. Trace all rays through the lens in a single batched call.
        3. For each of ``z_steps`` defocus planes within ``±z_span`` mm of
           ``self.d_sensor``, propagate rays analytically (linear
           extension) and compute the variance of the y-coordinate.
        4. The defocus with minimum variance is the best-focus plane.
           Parabolic interpolation on the three-point neighborhood gives
           sub-grid-step precision.
        5. Repeat for each wavelength; overlay R/G/B curves on a single plot.

    Args:
        save_name (str | None): File path for the output PNG.  If ``None``,
            defaults to ``'./field_curvature.png'``.
        num_points (int): Number of field-angle samples from 0 to
            ``self.rfov``. Defaults to 64.
        z_span (float): Half-range of the defocus sweep in mm.  If the
            best-focus hits the boundary, a warning is printed.
            Defaults to 1.0.
        z_steps (int): Number of uniformly-spaced defocus planes within
            ``±z_span``. Higher values give finer axial resolution.
            Defaults to 201.
        wvln_list (list[float]): Wavelengths in micrometers.
            Defaults to ``WAVE_RGB``.
        spp (int): Rays per field point (sampled uniformly across the
            entrance pupil in the meridional plane). Defaults to 256.
        show (bool): If ``True``, display interactively. Defaults to ``False``.
    """
    device = self.device
    rfov_deg = float(self.rfov) * 180.0 / np.pi

    # Sample field angles [0, rfov_deg], shape [F]
    rfov_samples = torch.linspace(0.0, rfov_deg, num_points, device=device)

    # Entrance pupil (computed once)
    pupilz, pupilr = self.get_entrance_pupil()

    # Defocus sweep grid, shape [Z]
    d_sensor = self.d_sensor
    z_grid = d_sensor + torch.linspace(-z_span, z_span, z_steps, device=device)

    delta_z_tan = []

    for wvln in wvln_list:
        # --- Batch ray construction for all field angles ---
        # Pupil positions: shape [spp]
        pupil_y = torch.linspace(-pupilr, pupilr, spp, device=device) * 0.99

        # Ray origins: shape [F, spp, 3] (meridional plane: x=0)
        ray_o = torch.zeros(num_points, spp, 3, device=device)
        ray_o[..., 1] = pupil_y.unsqueeze(0)  # y = pupil sample
        ray_o[..., 2] = pupilz  # z = entrance pupil z

        # Ray directions: shape [F, spp, 3] (meridional: dx=0)
        fov_rad = rfov_samples * (np.pi / 180.0)  # [F]
        sin_fov = torch.sin(fov_rad)  # [F]
        cos_fov = torch.cos(fov_rad)  # [F]
        ray_d = torch.zeros(num_points, spp, 3, device=device)
        ray_d[..., 1] = sin_fov.unsqueeze(-1)  # [F, 1] -> [F, spp]
        ray_d[..., 2] = cos_fov.unsqueeze(-1)

        # Create batched ray and trace all field angles at once
        ray = Ray(ray_o, ray_d, wvln=wvln, device=device)
        ray, _ = self.trace(ray)

        # --- Vectorized best-focus for all field angles ---
        # ray.o: [F, spp, 3], ray.d: [F, spp, 3]
        oz = ray.o[..., 2:3]  # [F, spp, 1]
        dz = ray.d[..., 2:3]  # [F, spp, 1]
        t = (z_grid.view(1, 1, -1) - oz) / (dz + EPSILON)  # [F, spp, Z]

        oa = ray.o[..., 1:2]  # y-axis (tangential)
        da = ray.d[..., 1:2]
        pos_y = oa + da * t  # [F, spp, Z]

        w = ray.is_valid.unsqueeze(-1).float()  # [F, spp, 1]
        pos_y = pos_y * w  # mask invalid rays
        w_sum = w.sum(dim=1)  # [F, 1]

        centroid = pos_y.sum(dim=1) / (w_sum + EPSILON)  # [F, Z]
        ms = (((pos_y - centroid.unsqueeze(1)) ** 2) * w).sum(dim=1) / (
            w_sum + EPSILON
        )  # [F, Z]

        best_idx = torch.argmin(ms, dim=1)  # [F]

        # Warn if best focus hits z_span boundary
        boundary_hit = (best_idx == 0) | (best_idx == z_steps - 1)
        if boundary_hit.any():
            n_boundary = boundary_hit.sum().item()
            print(
                f"Warning: {n_boundary}/{num_points} field angles hit z_span "
                f"boundary. Consider increasing z_span (currently {z_span} mm)."
            )

        # Parabolic interpolation for sub-grid precision
        idx_c = best_idx.clamp(1, z_steps - 2)  # avoid boundary
        f_range = torch.arange(num_points, device=device)
        y_l = ms[f_range, idx_c - 1]
        y_c = ms[f_range, idx_c]
        y_r = ms[f_range, idx_c + 1]
        denom = 2.0 * (y_l - 2.0 * y_c + y_r)
        shift = (y_l - y_r) / (denom + EPSILON)  # fractional index offset
        shift = shift.clamp(-0.5, 0.5)  # safety clamp

        z_step_size = (2.0 * z_span) / (z_steps - 1)
        best_z = z_grid[idx_c] + shift * z_step_size  # [F]
        dz_tan = (best_z - d_sensor).cpu().numpy()

        # Mark fully-vignetted field angles as NaN (gaps in plot)
        valid_count = w.sum(dim=1).squeeze(-1)  # [F]
        fully_vignetted = (valid_count < 2).cpu().numpy()
        dz_tan[fully_vignetted] = np.nan

        delta_z_tan.append(dz_tan)

    # Plot
    fov_np = rfov_samples.detach().cpu().numpy()
    fig, ax = plt.subplots(figsize=(7, 6))
    ax.set_title("Field Curvature (Δz vs Field Angle)")

    all_vals = np.abs(np.concatenate(delta_z_tan)) if len(delta_z_tan) > 0 else np.array([0.0])
    x_range = float(max(0.2, all_vals.max() * 1.2)) if all_vals.size > 0 else 0.2

    for w_idx in range(len(wvln_list)):
        color = RGB_COLORS[w_idx % len(RGB_COLORS)]
        lbl = RGB_LABELS[w_idx % len(RGB_LABELS)]
        ax.plot(
            delta_z_tan[w_idx],
            fov_np,
            color=color,
            linestyle="-",
            label=f"{lbl}-Tan",
        )

    ax.axvline(x=0, color="k", linestyle="-", linewidth=0.8)
    ax.grid(True, color="gray", linestyle="-", linewidth=0.5, alpha=1.0)
    ax.set_xlabel("Defocus Δz (mm) relative to sensor plane")
    ax.set_ylabel("Field Angle (deg)")
    ax.set_xlim(-x_range, x_range)
    ax.set_ylim(0, rfov_deg)
    ax.legend(fontsize=8)
    plt.tight_layout()

    if show:
        plt.show()
    else:
        if save_name is None:
            save_name = "./field_curvature.png"
        plt.savefig(save_name, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

vignetting

vignetting(depth=DEPTH, num_grid=32, num_rays=512)

Compute the relative-illumination (vignetting) map across the field.

Vignetting measures how much light is lost at each field position due to rays being clipped by lens apertures or barrel edges. It is computed as the fraction of traced rays that remain valid (not vignetted) at each grid cell, normalized by the total number of launched rays.

A value of 1.0 means all rays reach the sensor (no vignetting); 0.0 means complete light blockage. Real lenses typically show 1.0 on-axis and fall off toward the field edges. Because this map only counts surviving rays, it captures geometric clipping by apertures and barrel edges. It does not include the cos⁴ radiometric falloff.

Algorithm
  1. Sample rays with self.sample_grid_rays() using uniform_fov=False (uniform image-space sampling), so that grid cells map correctly onto the sensor plane.
  2. Propagate the rays with self.trace2sensor(), which marks clipped rays as invalid.
  3. Compute the per-cell throughput as count(valid) / num_rays.
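The reduction in step 3 is simply a mean over the ray axis of a boolean validity array. To show how such a map arises, the sketch below builds one for a deliberately simple toy model. It is not how DeepLens traces rays:

  • A fixed set of pupil samples is displaced in proportion to field position.
  • The displaced samples are clipped by a circular stop.

toy_vignetting_map and all of its parameters are hypothetical.

```python
import numpy as np


def toy_vignetting_map(num_grid=9, pupil_samples=64, shift_per_field=0.6, stop_radius=1.0):
    """Ray-count vignetting map for a toy two-aperture model (illustrative only)."""
    # Regular grid of pupil samples inside a disk of radius 0.99 (the "launched" rays)
    u = np.linspace(-0.99, 0.99, pupil_samples)
    px, py = np.meshgrid(u, u)
    inside = px**2 + py**2 <= 0.99**2
    px, py = px[inside], py[inside]  # [R]

    # Normalized field position of each grid cell, in [-1, 1]
    f = np.linspace(-1.0, 1.0, num_grid)
    fy, fx = np.meshgrid(f, f, indexing="ij")  # [G, G]

    # Toy assumption: at a second aperture the bundle is displaced in proportion to field
    hx = px[None, None, :] + shift_per_field * fx[..., None]
    hy = py[None, None, :] + shift_per_field * fy[..., None]
    is_valid = hx**2 + hy**2 <= stop_radius**2  # [G, G, R]

    # Same reduction as vignetting(): surviving rays / launched rays, per cell
    return is_valid.sum(-1) / is_valid.shape[-1]


vmap = toy_vignetting_map()
print(vmap.shape, vmap[4, 4], vmap[0, 0])
```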

Parameters:

  • depth (float, default DEPTH): Object distance in mm.
  • num_grid (int, default 32): Grid resolution per axis.
  • num_rays (int, default 512): Rays launched per grid cell. Higher values reduce Monte Carlo noise.

Returns:

torch.Tensor: Vignetting map with shape [num_grid, num_grid], values in [0, 1].

Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def vignetting(self, depth=DEPTH, num_grid=32, num_rays=512):
    """Compute the relative-illumination (vignetting) map across the field.

    Vignetting measures how much light is lost at each field position due to
    rays being clipped by lens apertures or barrel edges.  It is computed as
    the fraction of traced rays that remain valid (not vignetted) at each
    grid cell, normalized by the total number of launched rays.

    A value of 1.0 means all rays reach the sensor (no vignetting); 0.0
    means complete light blockage.  Real lenses typically show 1.0 on-axis
    and fall off toward the field edges due to mechanical vignetting and the
    cos⁴ illumination law.

    Algorithm:
        1. ``self.sample_grid_rays()`` with ``uniform_fov=False`` (uniform
           image-space sampling) to ensure correct sensor-plane mapping.
        2. ``self.trace2sensor()`` propagates rays and marks clipped ones as
           invalid.
        3. Per-cell throughput = ``count(valid) / num_rays``.

    Args:
        depth (float): Object distance in mm. Defaults to ``DEPTH``.
        num_grid (int): Grid resolution per axis. Defaults to 32.
        num_rays (int): Rays launched per grid cell.  Higher values reduce
            Monte-Carlo noise. Defaults to 512.

    Returns:
        torch.Tensor: Vignetting map with shape ``[num_grid, num_grid]``,
            values in ``[0, 1]``.
    """
    # Sample rays in uniform image space (not FOV angles) for correct sensor mapping
    # shape [num_grid, num_grid, num_rays, 3]
    ray = self.sample_grid_rays(
        depth=depth, num_grid=num_grid, num_rays=num_rays, uniform_fov=False
    )

    # Trace rays to sensor
    ray = self.trace2sensor(ray)

    # Calculate vignetting map
    vignetting = ray.is_valid.sum(-1) / (ray.is_valid.shape[-1])
    return vignetting

draw_vignetting

draw_vignetting(filename=None, depth=DEPTH, resolution=512, show=False)

Draw the vignetting map as a grayscale image with a colorbar.

Computes the vignetting map via self.vignetting(), bilinearly upsamples it to resolution × resolution, and displays it as a grayscale image where white = no vignetting and black = fully vignetted.
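draw_vignetting() upsamples with torch.nn.functional.interpolate(mode="bilinear", align_corners=False). For readers without PyTorch at hand, the NumPy sketch below reproduces that sampling convention as we understand it. Output pixel centers map to source coordinates s = (i + 0.5)·(n_in / n_out) - 0.5, clamped to the valid index range, and the result is then interpolated linearly along each axis. bilinear_upsample is a hypothetical helper.

```python
import numpy as np


def _source_coords(n_in, n_out):
    """Source indices and weights for align_corners=False style resampling (assumed convention)."""
    s = (np.arange(n_out) + 0.5) * (n_in / n_out) - 0.5
    s = np.clip(s, 0.0, n_in - 1)
    i0 = np.floor(s).astype(int)
    i1 = np.minimum(i0 + 1, n_in - 1)
    return i0, i1, s - i0  # lower index, upper index, weight of the upper index


def bilinear_upsample(img, out_h, out_w):
    """Bilinearly resample a 2-D map to (out_h, out_w)."""
    img = np.asarray(img, dtype=np.float64)
    r0, r1, wr = _source_coords(img.shape[0], out_h)
    c0, c1, wc = _source_coords(img.shape[1], out_w)
    # Interpolate along columns for the two bracketing rows, then along rows
    top = img[r0][:, c0] * (1 - wc) + img[r0][:, c1] * wc
    bottom = img[r1][:, c0] * (1 - wc) + img[r1][:, c1] * wc
    return top * (1 - wr)[:, None] + bottom * wr[:, None]


out = bilinear_upsample([[0.0, 1.0], [2.0, 3.0]], 4, 4)
print(out)
```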

Parameters:

  • filename (str | None, default None): File path for the output PNG. If None, auto-generates './vignetting_{depth}.png'.
  • depth (float, default DEPTH): Object distance in mm.
  • resolution (int, default 512): Output image size in pixels (square).
  • show (bool, default False): If True, display the figure interactively instead of saving it.
Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def draw_vignetting(self, filename=None, depth=DEPTH, resolution=512, show=False):
    """Draw the vignetting map as a grayscale image with a colorbar.

    Computes the vignetting map via ``self.vignetting()``, bilinearly
    upsamples it to ``resolution × resolution``, and displays it as a
    grayscale image where white = no vignetting and black = fully vignetted.

    Args:
        filename (str | None): File path for the output PNG.  If ``None``,
            auto-generates ``'./vignetting_{depth}.png'``.
        depth (float): Object distance in mm. Defaults to ``DEPTH``.
        resolution (int): Output image size in pixels (square).
            Defaults to 512.
        show (bool): If ``True``, display interactively. Defaults to ``False``.
    """
    # Calculate vignetting map
    vignetting = self.vignetting(depth=depth)

    # Interpolate vignetting map to desired resolution
    vignetting = F.interpolate(
        vignetting.unsqueeze(0).unsqueeze(0),
        size=(resolution, resolution),
        mode="bilinear",
        align_corners=False,
    ).squeeze()

    fig, ax = plt.subplots()
    ax.set_title("Relative Illumination (Vignetting)")
    im = ax.imshow(vignetting.cpu().numpy(), cmap="gray", vmin=0.0, vmax=1.0)
    fig.colorbar(im, ax=ax, ticks=[0.0, 0.25, 0.5, 0.75, 1.0])

    if show:
        plt.show()
    else:
        if filename is None:
            filename = f"./vignetting_{depth}.png"
        plt.savefig(filename, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

wavefront_error

wavefront_error(relative_fov=0.0, depth=DEPTH, wvln=DEFAULT_WAVE, num_rays=SPP_COHERENT, ks=256)

Compute wavefront error (OPD) at the exit pupil for a given field position.

The wavefront error is the optical path difference between the actual wavefront and the ideal spherical reference wavefront. The reference sphere is centered at the ideal image point (chief ray intersection with the sensor) and passes through the exit pupil center.

By Fermat's principle, a perfect lens has equal total optical path (object → lens → image) for all rays. The deviation from this equal-path condition is the wavefront error:

OPD(x, y) = [OPL(x, y) + r(x, y)] - mean over the pupil of [OPL + r]

where OPL(x,y) is the accumulated optical path from the object through the lens to the exit pupil, and r(x,y) is the geometric distance from the exit pupil point to the ideal image point. Piston (mean) is removed.

Uses the same coherent ray-tracing infrastructure as pupil_field().
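Once rays have reached the exit pupil, the remaining arithmetic in the source below is short, and it can be mirrored in NumPy on per-ray data:

  • add the free-space distance r to the ideal image point;
  • remove piston over valid rays;
  • convert mm to waves;
  • take RMS and PV;
  • apply the exponential Maréchal form for Strehl.

This is a sketch, not the library routine. It assumes the per-ray OPL and exit-pupil coordinates are already available (ray tracing is not reproduced). opd_statistics and the synthetic pupil are hypothetical. The exponential Strehl estimate is only reliable for small aberrations.

```python
import math

import numpy as np


def opd_statistics(opl_mm, pupil_x, pupil_y, pupil_z, image_point, wvln_um, valid):
    """Piston-removed OPD (waves), RMS, PV and Marechal Strehl from per-ray data (sketch)."""
    img_x, img_y, img_z = image_point
    # Geometric distance r from each ray's exit-pupil point to the ideal image point
    r = np.sqrt((pupil_x - img_x) ** 2 + (pupil_y - img_y) ** 2 + (pupil_z - img_z) ** 2)
    total_path = opl_mm + r

    # Remove piston (mean over valid rays), then convert mm -> waves
    opd_waves = (total_path - total_path[valid].mean()) / (wvln_um * 1e-3)
    opd_valid = opd_waves[valid]
    rms = float(np.sqrt(np.mean(opd_valid**2)))
    pv = float(opd_valid.max() - opd_valid.min())
    strehl = math.exp(-((2 * math.pi * rms) ** 2))
    return opd_waves, rms, pv, strehl


# Synthetic exit pupil: points in a 2 mm-radius disk at z = 50 mm, image point at z = 100 mm
g = np.linspace(-2.0, 2.0, 41)
px, py = [a.ravel() for a in np.meshgrid(g, g)]
valid = px**2 + py**2 <= 4.0
image_point = (0.0, 0.0, 100.0)
r = np.sqrt(px**2 + py**2 + 50.0**2)

# Perfect lens: OPL + r is constant over the pupil, so the OPD vanishes
_, rms0, pv0, strehl0 = opd_statistics(120.0 - r, px, py, 50.0, image_point, 0.55, valid)

# Add a tilt-like path error of 1e-4 mm per mm of pupil x
tilt = 1e-4 * px
_, rms1, _, strehl1 = opd_statistics(120.0 - r + tilt, px, py, 50.0, image_point, 0.55, valid)
print(rms0, strehl0, rms1, strehl1)
```

After piston removal, a linear path error yields an RMS equal to its standard deviation over the valid pupil. Adding a constant to every OPL leaves the result unchanged.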

Parameters:

  • relative_fov (float, default 0.0): Relative field of view in [-1, 1] along the meridional (y) direction; 0 = on-axis, 1 = full field.
  • depth (float, default DEPTH): Object distance in mm. Use DEPTH for practical infinity.
  • wvln (float, default DEFAULT_WAVE): Wavelength in µm.
  • num_rays (int, default SPP_COHERENT): Number of rays sampled through the pupil.
  • ks (int, default 256): Grid resolution of the OPD map at the exit pupil.

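Returns:

dict with the following keys:
  • opd_map (Tensor): OPD map on the exit-pupil grid, shape [ks, ks], in waves. Invalid (vignetted) pixels are zero.
  • rms (float): RMS wavefront error in waves (piston removed), computed from per-ray values rather than from the grid.
  • pv (float): Peak-to-valley wavefront error in waves.
  • valid_mask (Tensor): Boolean mask of valid pupil pixels, shape [ks, ks].
  • strehl (float): Strehl ratio from the Maréchal approximation, exp(-(2π·rms)²).

Raises:

RuntimeError: If no traced ray is valid at the requested field position (the field is fully vignetted).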
Note

This function sets the default dtype to torch.float64 (via self.astype(torch.float64)) for phase accuracy, consistent with pupil_field(). The previous dtype is not restored when the call returns.

References

  • [1] V. N. Mahajan, "Optical Imaging and Aberrations, Part II", Ch. 1.
  • [2] Zemax OpticStudio, "Wavefront Error Analysis".

Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def wavefront_error(
    self,
    relative_fov=0.0,
    depth=DEPTH,
    wvln=DEFAULT_WAVE,
    num_rays=SPP_COHERENT,
    ks=256,
):
    """Compute wavefront error (OPD) at the exit pupil for a given field position.

    The wavefront error is the optical path difference between the actual
    wavefront and the ideal spherical reference wavefront. The reference sphere
    is centered at the ideal image point (chief ray intersection with the sensor)
    and passes through the exit pupil center.

    By Fermat's principle, a perfect lens has equal total optical path (object →
    lens → image) for all rays. The deviation from this equal-path condition is
    the wavefront error:

        ``OPD(x,y) = [OPL(x,y) + r(x,y)] - mean_over_pupil``

    where ``OPL(x,y)`` is the accumulated optical path from the object through
    the lens to the exit pupil, and ``r(x,y)`` is the geometric distance from
    the exit pupil point to the ideal image point. Piston (mean) is removed.

    Uses the same coherent ray-tracing infrastructure as :meth:`pupil_field`.

    Args:
        relative_fov (float): Relative field of view in ``[-1, 1]`` along the
            meridional (y) direction. ``0`` = on-axis, ``1`` = full field.
        depth (float): Object distance [mm]. Use ``DEPTH`` for practical infinity.
        wvln (float): Wavelength [µm].
        num_rays (int): Number of rays to sample through the pupil.
        ks (int): Grid resolution for the OPD map at the exit pupil.

    Returns:
        dict:
            - ``opd_map`` (Tensor): OPD map on exit pupil grid, shape ``[ks, ks]``,
              in waves. Invalid (vignetted) regions are zero.
            - ``rms`` (float): RMS wavefront error in waves (piston removed).
            - ``pv`` (float): Peak-to-valley wavefront error in waves.
            - ``valid_mask`` (Tensor): Boolean mask of valid pupil pixels ``[ks, ks]``.
            - ``strehl`` (float): Maréchal approximation Strehl ratio.

    Note:
        This function sets the default dtype to ``torch.float64`` for phase
        accuracy (consistent with :meth:`pupil_field`).

    References:
        [1] V. N. Mahajan, "Optical Imaging and Aberrations, Part II", Ch. 1.
        [2] Zemax OpticStudio, "Wavefront Error Analysis".
    """
    # Float64 required for accurate OPL accumulation
    self.astype(torch.float64)
    device = self.device
    sensor_w, sensor_h = self.sensor_size
    wvln_mm = wvln * 1e-3

    # Build normalized point: positive relative_fov -> negative y (convention)
    point_norm = torch.tensor(
        [0.0, -relative_fov, depth], dtype=torch.float64, device=device
    )
    points = point_norm.unsqueeze(0)  # [1, 3]

    # Convert to physical object coordinates
    scale = self.calc_scale(points[:, 2].item())
    point_obj_x = points[:, 0] * scale * sensor_w / 2
    point_obj_y = points[:, 1] * scale * sensor_h / 2
    point_obj = torch.stack([point_obj_x, point_obj_y, points[:, 2]], dim=-1)

    # Find ideal image point via chief ray
    # psf_center returns negated centroid, so negate back to get actual image position
    chief_pointc = self.psf_center(point_obj, method="chief_ray")  # [1, 2]
    img_x = -chief_pointc[0, 0]
    img_y = -chief_pointc[0, 1]
    img_z = float(self.d_sensor)

    # Sample rays and trace coherently to exit pupil
    ray = self.sample_from_points(
        points=point_obj, num_rays=num_rays, wvln=wvln
    )
    ray.coherent = True
    ray = self.trace2exit_pupil(ray)

    # Get exit pupil parameters
    pupilz, pupilr = self.get_exit_pupil()
    pupilr = float(pupilr)
    pupilz = float(pupilz)

    # Extract valid rays (squeeze batch dim since single point)
    valid = ray.is_valid.squeeze(0) > 0  # [num_rays]
    ray_x = ray.o[0, :, 0]  # [num_rays]
    ray_y = ray.o[0, :, 1]
    opl = ray.opl[0, :, 0]  # [num_rays]

    if valid.sum() == 0:
        raise RuntimeError(
            f"No valid rays at relative_fov={relative_fov}. "
            "The field may be fully vignetted."
        )

    # Distance from each ray's exit pupil position to ideal image point
    dist_to_img = torch.sqrt(
        (ray_x - img_x) ** 2
        + (ray_y - img_y) ** 2
        + (pupilz - img_z) ** 2
    )

    # Total optical path = OPL through lens to exit pupil + free-space to image
    total_path = opl + dist_to_img  # [num_rays]

    # Remove piston (mean over valid rays) to get wavefront error
    total_path_valid = total_path[valid]
    mean_path = total_path_valid.mean()
    opd_mm = total_path - mean_path  # OPD in [mm]
    opd_waves = opd_mm / wvln_mm  # OPD in [waves]

    # Compute RMS and PV from per-ray values (more accurate than from grid)
    opd_valid = opd_waves[valid]
    rms_waves = torch.sqrt(torch.mean(opd_valid**2)).item()
    pv_waves = (opd_valid.max() - opd_valid.min()).item()

    # Maréchal approximation: Strehl ≈ exp(-(2π·σ)²)
    strehl = math.exp(-(2 * math.pi * rms_waves) ** 2)

    # Bin OPD values onto exit pupil grid using assign_points_to_pixels
    # Grid covers [-pupilr, pupilr] x [-pupilr, pupilr]
    pupil_range = [-pupilr, pupilr]
    pupil_points = torch.stack([ray_x[valid], ray_y[valid]], dim=-1)  # [N, 2]
    pupil_mask = torch.ones(pupil_points.shape[0], device=device)

    # Sum of weighted OPD values
    opd_sum = assign_points_to_pixels(
        points=pupil_points,
        mask=pupil_mask,
        ks=ks,
        x_range=pupil_range,
        y_range=pupil_range,
        value=opd_valid,
    )
    # Sum of weights (count)
    count = assign_points_to_pixels(
        points=pupil_points,
        mask=pupil_mask,
        ks=ks,
        x_range=pupil_range,
        y_range=pupil_range,
        value=torch.ones_like(opd_valid),
    )
    valid_mask = count > 0
    opd_map = torch.where(valid_mask, opd_sum / count, torch.zeros_like(opd_sum))

    return {
        "opd_map": opd_map,
        "rms": rms_waves,
        "pv": pv_waves,
        "valid_mask": valid_mask,
        "strehl": strehl,
    }
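The last step of the source above averages per-ray OPD values into a ks × ks exit-pupil grid. One pass sums the values per pixel and another counts rays per pixel. Their ratio, set to zero where the count is zero, becomes opd_map.

assign_points_to_pixels is DeepLens's own helper, and its exact binning rule is not shown here. The NumPy sketch below therefore assumes simple nearest-pixel binning; bin_mean_to_grid is a hypothetical name.

```python
import numpy as np


def bin_mean_to_grid(x, y, values, ks, lo, hi):
    """Average scattered values onto a ks x ks grid over [lo, hi)^2 (nearest-pixel assumption)."""
    x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
    values = np.asarray(values, dtype=float)
    # Pixel index of each point; points outside [lo, hi) are dropped
    ix = np.floor((x - lo) / (hi - lo) * ks).astype(int)
    iy = np.floor((y - lo) / (hi - lo) * ks).astype(int)
    inside = (ix >= 0) & (ix < ks) & (iy >= 0) & (iy < ks)
    flat = iy[inside] * ks + ix[inside]  # row = y, column = x

    # Sum of values and number of points per pixel
    sums = np.bincount(flat, weights=values[inside], minlength=ks * ks)
    counts = np.bincount(flat, minlength=ks * ks)

    # Mean where at least one point landed, zero elsewhere
    mean = np.divide(sums, counts, out=np.zeros_like(sums), where=counts > 0)
    return mean.reshape(ks, ks), (counts > 0).reshape(ks, ks)


grid, mask = bin_mean_to_grid([-0.9, -0.9, 0.9], [-0.9, -0.9, 0.9], [1.0, 3.0, 5.0], ks=2, lo=-1.0, hi=1.0)
print(grid, mask)
```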

rms_wavefront_error

rms_wavefront_error(relative_fov=0.0, depth=DEPTH, wvln=DEFAULT_WAVE, num_rays=SPP_COHERENT)

Compute scalar RMS wavefront error at a given field position.

Convenience wrapper around wavefront_error() that returns only its "rms" entry.
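A single RMS value is often compared against the Maréchal criterion. That criterion treats a system as diffraction-limited when the RMS wavefront error is at most λ/14 (about 0.071 waves). The corresponding Strehl ratio depends on which form of the approximation is used:

  • about 0.82 with the exponential form exp(-(2πσ)²) used by wavefront_error();
  • about 0.80 with the original quadratic form 1 - (2πσ)².

A minimal sketch follows, with hypothetical helper names. It operates on a value such as the one this method returns.

```python
import math


def marechal_strehl(rms_waves):
    """Exponential Marechal estimate of the Strehl ratio from RMS wavefront error in waves."""
    return math.exp(-((2 * math.pi * rms_waves) ** 2))


def is_diffraction_limited(rms_waves, threshold=1 / 14):
    """Marechal criterion: RMS wavefront error at most lambda/14 (hypothetical helper)."""
    return rms_waves <= threshold


for rms in (0.0, 0.05, 1 / 14, 0.1):
    print(f"rms={rms:.3f} waves  strehl~{marechal_strehl(rms):.3f}  diffraction-limited={is_diffraction_limited(rms)}")
```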

Parameters:

  • relative_fov (float, default 0.0): Relative field of view in [-1, 1].
  • depth (float, default DEPTH): Object distance in mm.
  • wvln (float, default DEFAULT_WAVE): Wavelength in µm.
  • num_rays (int, default SPP_COHERENT): Number of rays to sample.

Returns:

float: RMS wavefront error in waves.

Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def rms_wavefront_error(
    self,
    relative_fov=0.0,
    depth=DEPTH,
    wvln=DEFAULT_WAVE,
    num_rays=SPP_COHERENT,
):
    """Compute scalar RMS wavefront error at a given field position.

    Convenience wrapper around :meth:`wavefront_error`.

    Args:
        relative_fov (float): Relative field of view in ``[-1, 1]``.
        depth (float): Object distance [mm].
        wvln (float): Wavelength [µm].
        num_rays (int): Number of rays to sample.

    Returns:
        float: RMS wavefront error in waves.
    """
    result = self.wavefront_error(
        relative_fov=relative_fov,
        depth=depth,
        wvln=wvln,
        num_rays=num_rays,
    )
    return result["rms"]

draw_wavefront_error

draw_wavefront_error(save_name='./wavefront_error.png', num_fov=5, depth=DEPTH, wvln=DEFAULT_WAVE, num_rays=SPP_COHERENT, ks=256, show=False)

Draw wavefront error (OPD) maps at multiple field positions.

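Evaluates the wavefront error at num_fov evenly spaced field positions along the meridional (y) direction, from on-axis (0) to full field (1). Each OPD map is displayed with RMS and PV annotations on a shared color scale. Field positions where wavefront_error() raises a RuntimeError (fully vignetted) are shown as empty panels labeled "(vignetted)".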
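To keep colors comparable across panels, draw_wavefront_error() first finds the largest |OPD| over the valid pixels of every successfully evaluated field position. If nothing is found, it falls back to 1.0. The sketch below mirrors that bookkeeping in NumPy; shared_symmetric_limit is a hypothetical helper.

The truncated source does not show whether exactly ±vmax is passed to imshow. A range symmetric about zero is, however, the usual choice for a diverging colormap such as RdBu_r.

```python
import numpy as np


def shared_symmetric_limit(opd_maps, masks, fallback=1.0):
    """Largest |OPD| over valid pixels of all available maps; fallback if none (sketch)."""
    vmax = 0.0
    for opd, mask in zip(opd_maps, masks):
        if opd is None:
            continue  # fully vignetted field position, nothing to plot
        vals = np.abs(np.asarray(opd)[np.asarray(mask, dtype=bool)])
        if vals.size > 0:
            vmax = max(vmax, float(vals.max()))
    return vmax if vmax > 0 else fallback


maps = [np.array([[0.1, -0.4], [0.0, 0.0]]), None, np.array([[9.0, 0.2]])]
masks = [np.array([[True, True], [False, False]]), None, np.array([[False, True]])]
vmax = shared_symmetric_limit(maps, masks)
print(-vmax, vmax)  # e.g. vmin/vmax for a diverging colormap
```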

Parameters:

  • save_name (str, default './wavefront_error.png'): Filename for the saved figure.
  • num_fov (int, default 5): Number of field positions to evaluate.
  • depth (float, default DEPTH): Object distance in mm.
  • wvln (float, default DEFAULT_WAVE): Wavelength in µm.
  • num_rays (int, default SPP_COHERENT): Number of rays sampled per field position.
  • ks (int, default 256): Grid resolution of each OPD map.
  • show (bool, default False): If True, display the figure interactively.
Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def draw_wavefront_error(
    self,
    save_name="./wavefront_error.png",
    num_fov=5,
    depth=DEPTH,
    wvln=DEFAULT_WAVE,
    num_rays=SPP_COHERENT,
    ks=256,
    show=False,
):
    """Draw wavefront error (OPD) maps at multiple field positions.

    Evaluates the wavefront error along the meridional (y) direction from
    on-axis to full field, and displays each OPD map with RMS and PV
    annotations.

    Args:
        save_name (str): Filename to save the figure.
        num_fov (int): Number of field positions to evaluate.
        depth (float): Object distance [mm].
        wvln (float): Wavelength [µm].
        num_rays (int): Number of rays to sample per field position.
        ks (int): Grid resolution for each OPD map.
        show (bool): If True, display the figure interactively.
    """
    fov_list = torch.linspace(0, 1, num_fov).tolist()

    fig, axs = plt.subplots(1, num_fov, figsize=(num_fov * 3.5, 3.5))
    axs = np.atleast_1d(axs)

    # Collect all OPD ranges to use a shared color scale
    results = []
    vmax = 0.0
    for fov in fov_list:
        try:
            result = self.wavefront_error(
                relative_fov=fov,
                depth=depth,
                wvln=wvln,
                num_rays=num_rays,
                ks=ks,
            )
            results.append(result)
            opd_valid = result["opd_map"][result["valid_mask"]]
            if len(opd_valid) > 0:
                vmax = max(vmax, opd_valid.abs().max().item())
        except RuntimeError:
            results.append(None)

    if vmax == 0:
        vmax = 1.0  # fallback

    for i, (fov, result) in enumerate(zip(fov_list, results)):
        if result is None:
            axs[i].set_title(f"FoV={fov:.2f}\n(vignetted)", fontsize=8)
            axs[i].axis("off")
            continue

        opd = result["opd_map"].cpu().numpy()
        mask = result["valid_mask"].cpu().numpy()
        rms = result["rms"]
        pv = result["pv"]

        # Mask invalid regions with NaN for visualization
        opd_vis = np.where(mask, opd, np.nan)

        im = axs[i].imshow(
            opd_vis,
            cmap="RdBu_r",
            vmin=-vmax,
            vmax=vmax,
            interpolation="bilinear",
        )
        axs[i].set_title(
            f"FoV={fov:.2f}\nRMS={rms:.4f}λ  PV={pv:.3f}λ",
            fontsize=8,
        )
        axs[i].axis("off")
        fig.colorbar(
            im,
            ax=axs[i],
            fraction=0.046,
            pad=0.04,
            label="OPD [waves]",
        )

    fig.suptitle(
        f"Wavefront Error (λ={wvln}µm, depth={depth}mm)", fontsize=10
    )
    plt.tight_layout()

    if show:
        plt.show()
    else:
        assert save_name.endswith(".png"), "save_name must end with .png"
        plt.savefig(save_name, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

field_curvature

field_curvature()

Compute field curvature data (best-focus defocus vs field angle).

Field curvature is the axial shift of the best-focus surface away from the flat sensor plane as a function of field angle. It is caused by the Petzval sum of lens surface curvatures and refractive indices.

Not yet implemented. The method body is only pass, so it currently returns None. See draw_field_curvature() for a plotting version that already performs the underlying computation.

Source code in deeplens/optics/geolens_pkg/eval.py
def field_curvature(self):
    """Compute field curvature data (best-focus defocus vs field angle).

    Field curvature is the axial shift of the best-focus surface away from
    the flat sensor plane as a function of field angle.  It is caused by
    the Petzval sum of lens surface curvatures and refractive indices.

    Not yet implemented.  See ``draw_field_curvature()`` for a plotting
    version that already performs the underlying computation.
    """
    pass
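field_curvature() is still a stub. The Petzval sum named in its docstring, however, is a closed-form paraxial quantity. The sketch below evaluates it from plain lists of surface radii and refractive indices. petzval_sum and petzval_radius are hypothetical helpers, not DeepLens API. The sign convention (radius positive when the center of curvature lies to the right, light traveling left to right) is also an assumption.

```python
import math


def petzval_sum(radii, n_before, n_after):
    """Petzval sum P = sum_j (n'_j - n_j) / (n_j * n'_j * r_j), in 1/mm.

    radii: surface radii of curvature [mm] (math.inf for a flat surface).
    n_before / n_after: refractive index on each side of every surface.
    """
    total = 0.0
    for r, n, n_prime in zip(radii, n_before, n_after):
        if math.isinf(r):
            continue  # a flat surface adds nothing to the Petzval sum
        total += (n_prime - n) / (n * n_prime * r)
    return total


def petzval_radius(radii, n_before, n_after):
    """Radius of the Petzval image surface: R_P = -1 / (n'_image * P).

    Negative means the surface curves back toward the lens under the
    assumed sign convention.
    """
    p = petzval_sum(radii, n_before, n_after)
    n_image = n_after[-1]
    return math.inf if p == 0 else -1.0 / (n_image * p)


# Thin biconvex lens in air: n = 1.5, r1 = 50 mm, r2 = -50 mm (f = 50 mm)
radii = [50.0, -50.0]
n_before = [1.0, 1.5]
n_after = [1.5, 1.0]
print(petzval_sum(radii, n_before, n_after))     # analytically 1 / (n * f) = 1/75
print(petzval_radius(radii, n_before, n_after))  # analytically -n * f = -75 mm
```

The Petzval surface captures only the third-order, astigmatism-free part of field curvature. The real-ray best-focus curve that draw_field_curvature() plots also includes astigmatism and higher-order terms, so the two generally diverge away from the axis.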

calc_chief_ray

calc_chief_ray(fov, plane='sagittal')

Find the chief ray for a given field angle using 2-D ray tracing.

The chief ray (also called the principal ray) is the ray from an off-axis object point that passes through the center of the aperture stop. It defines the image height for distortion calculations and sets the reference axis for coma and lateral color analysis.

Algorithm
  1. Sample a fan of parallel rays at the specified fov in the chosen plane, entering through the entrance pupil.
  2. Trace the fan up to (but not through) the aperture stop.
  3. Select the ray whose transverse position at the stop is closest to the optical axis — this is the chief ray.
  4. Return its incident (object-space) origin and direction.

Parameters:

Name Type Description Default
fov float

Incident half-angle in degrees.

required
plane str

'sagittal' (x-axis) or 'meridional' (y-axis). Defaults to 'sagittal'.

'sagittal'

Returns:

Type Description

tuple[torch.Tensor, torch.Tensor]:
- chief_ray_o: Origin of the chief ray in object space, shape [3].
- chief_ray_d: Unit direction of the chief ray, shape [3].

Note

This is a 2-D (meridional or sagittal plane) search. For a full 3-D chief ray, one would shrink the pupil and trace the centroid ray.

Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def calc_chief_ray(self, fov, plane="sagittal"):
    """Find the chief ray for a given field angle using 2-D ray tracing.

    The *chief ray* (also called the *principal ray*) is the ray from an
    off-axis object point that passes through the center of the aperture
    stop.  It defines the image height for distortion calculations and sets
    the reference axis for coma and lateral color analysis.

    Algorithm:
        1. Sample a fan of parallel rays at the specified ``fov`` in the
           chosen plane, entering through the entrance pupil.
        2. Trace the fan up to (but not through) the aperture stop.
        3. Select the ray whose transverse position at the stop is closest
           to the optical axis — this is the chief ray.
        4. Return its *incident* (object-space) origin and direction.

    Args:
        fov (float): Incident half-angle in **degrees**.
        plane (str): ``'sagittal'`` (x-axis) or ``'meridional'`` (y-axis).
            Defaults to ``'sagittal'``.

    Returns:
        tuple[torch.Tensor, torch.Tensor]:
            - **chief_ray_o**: Origin of the chief ray in object space,
              shape ``[3]``.
            - **chief_ray_d**: Unit direction of the chief ray, shape ``[3]``.

    Note:
        This is a 2-D (meridional or sagittal plane) search.  For a full
        3-D chief ray, one would shrink the pupil and trace the centroid ray.
    """
    # Sample parallel rays from object space
    ray = self.sample_parallel_2D(
        fov=fov, num_rays=SPP_CALC, entrance_pupil=True, plane=plane
    )
    inc_ray = ray.clone()

    # Trace to the aperture
    surf_range = range(0, self.aper_idx)
    ray, _ = self.trace(ray, surf_range=surf_range)

    # Look for the ray that is closest to the optical axis within the fan's
    # plane (x for a sagittal fan, y for a meridional fan)
    axis = 0 if plane == "sagittal" else 1
    center_idx = torch.argmin(torch.abs(ray.o[:, axis])).item()
    chief_ray_o, chief_ray_d = inc_ray.o[center_idx, :], inc_ray.d[center_idx, :]

    return chief_ray_o, chief_ray_d
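The selection step in calc_chief_ray(), which keeps the fan member that lands closest to the axis at the stop, can be illustrated without a lens. The toy model below assumes that there are no refracting surfaces between the launch plane and the stop, so each parallel ray travels in a straight line. That assumption stands in for self.trace(). find_chief_ray_height is a hypothetical helper.

```python
import numpy as np


def find_chief_ray_height(fov_deg, stop_distance, half_width, num_rays=1001):
    """Pick the launch height whose parallel ray crosses the stop closest to the axis.

    Toy model (assumption): free-space propagation over stop_distance [mm],
    with no refracting surfaces before the stop.
    """
    y0 = np.linspace(-half_width, half_width, num_rays)  # fan in one plane
    slope = np.tan(np.deg2rad(fov_deg))                   # all rays share this direction
    y_stop = y0 + stop_distance * slope                   # transverse position at the stop
    idx = int(np.argmin(np.abs(y_stop)))                  # ray closest to the axis
    return y0[idx], y_stop[idx]


y0, y_at_stop = find_chief_ray_height(fov_deg=10.0, stop_distance=10.0, half_width=5.0)
print(f"chief ray launch height {y0:.3f} mm, height at stop {y_at_stop:.4f} mm")
```

The result can be no more accurate than the fan spacing, here 10 mm / 1000 = 0.01 mm. This is one reason calc_chief_ray_infinite() narrows its search window around a paraxial estimate instead of spanning the whole pupil.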

calc_chief_ray_infinite

calc_chief_ray_infinite(rfov, depth=0.0, wvln=DEFAULT_WAVE, plane='meridional', num_rays=SPP_CALC, ray_aiming=True)

Compute chief rays for one or more field angles with optional ray aiming.

This is the batched, production version of calc_chief_ray. It supports vectorized evaluation over multiple field angles and implements ray aiming. Ray aiming is a single-pass search that launches a narrow fan of parallel rays toward the entrance pupil and selects the one that passes closest to the aperture-stop center. It is essential for accurate distortion measurement in wide-angle or fisheye lenses, where the paraxial approximation breaks down.

Algorithm
  1. For on-axis (rfov = 0): chief ray is trivially along the z-axis.
  2. For off-axis angles with ray_aiming=False: the chief ray is aimed at the entrance pupil center (paraxial approximation).
  3. For off-axis angles with ray_aiming=True:
     a. Estimate the object-space y (or x) position from the entrance pupil geometry.
     b. Create a narrow fan of num_rays parallel rays bracketing that estimate (half-width = 5 % of y_distance, with a minimum of 0.05 * pupil_radius).
     c. Trace the fan to the aperture stop.
     d. Pick the ray closest to the optical axis at the stop.

Parameters:

Name Type Description Default
rfov float | Tensor

Field angle(s) in degrees. A scalar is converted to [0, rfov] (two-element tensor). A tensor of shape [N] is used directly.

required
depth float | Tensor

Object depth(s) in mm. Defaults to 0.0 (object at the first surface).

0.0
wvln float

Wavelength in micrometers. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
plane str

'sagittal' or 'meridional'. Defaults to 'meridional'.

'meridional'
num_rays int

Size of the search fan for ray aiming. Defaults to SPP_CALC.

SPP_CALC
ray_aiming bool

If True, perform fan-search ray aiming for accurate chief-ray identification. Defaults to True.

True

Returns:

Type Description

tuple[torch.Tensor, torch.Tensor]:
- chief_ray_o: Origins, shape [N, 3].
- chief_ray_d: Unit directions, shape [N, 3].

Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def calc_chief_ray_infinite(
    self,
    rfov,
    depth=0.0,
    wvln=DEFAULT_WAVE,
    plane="meridional",
    num_rays=SPP_CALC,
    ray_aiming=True,
):
    """Compute chief rays for one or more field angles with optional ray aiming.

    This is the batched, production version of ``calc_chief_ray``.  It
    supports vectorized evaluation over multiple field angles and implements
    *ray aiming*: a single-pass search that launches a narrow fan of
    parallel rays toward the entrance pupil and selects the one that passes
    closest to the aperture-stop center.  Ray aiming is essential for
    accurate distortion measurement in wide-angle or fisheye lenses where
    the paraxial approximation breaks down.

    Algorithm:
        1. For on-axis (``rfov = 0``): chief ray is trivially along the
           z-axis.
        2. For off-axis angles with ``ray_aiming=False``: the chief ray is
           aimed at the entrance pupil center (paraxial approximation).
        3. For off-axis angles with ``ray_aiming=True``:
           a. Estimate the object-space y (or x) position from the entrance
              pupil geometry.
           b. Create a narrow fan of ``num_rays`` parallel rays bracketing
              that estimate (half-width = 5 % of y_distance, with a minimum
              of ``0.05 * pupil_radius``).
           c. Trace the fan to the aperture stop.
           d. Pick the ray closest to the optical axis at the stop.

    Args:
        rfov (float | torch.Tensor): Field angle(s) in **degrees**.
            A scalar is converted to ``[0, rfov]`` (two-element tensor).
            A tensor of shape ``[N]`` is used directly.
        depth (float | torch.Tensor): Object depth(s) in mm.
            Defaults to 0.0 (object at the first surface).
        wvln (float): Wavelength in micrometers. Defaults to ``DEFAULT_WAVE``.
        plane (str): ``'sagittal'`` or ``'meridional'``.
            Defaults to ``'meridional'``.
        num_rays (int): Size of the search fan for ray aiming.
            Defaults to ``SPP_CALC``.
        ray_aiming (bool): If ``True``, perform fan-search ray aiming for
            accurate chief-ray identification. Defaults to ``True``.

    Returns:
        tuple[torch.Tensor, torch.Tensor]:
            - **chief_ray_o**: Origins, shape ``[N, 3]``.
            - **chief_ray_d**: Unit directions, shape ``[N, 3]``.
    """
    if not isinstance(rfov, torch.Tensor):
        # Scalar field angle (int or float): evaluate on-axis and at rfov
        rfov = torch.linspace(0, float(rfov), 2)
    rfov = rfov.to(self.device)

    if not isinstance(depth, torch.Tensor):
        depth = torch.tensor(depth, device=self.device).repeat(len(rfov))

    # set chief ray
    chief_ray_o = torch.zeros([len(rfov), 3]).to(self.device)
    chief_ray_d = torch.zeros([len(rfov), 3]).to(self.device)

    # Convert rfov to radian
    rfov = rfov * torch.pi / 180.0

    if torch.any(rfov == 0):
        chief_ray_o[0, ...] = torch.tensor(
            [0.0, 0.0, depth[0]], device=self.device, dtype=torch.float32
        )
        chief_ray_d[0, ...] = torch.tensor(
            [0.0, 0.0, 1.0], device=self.device, dtype=torch.float32
        )
        if len(rfov) == 1:
            return chief_ray_o, chief_ray_d

    # Extract non-zero rfov entries for processing
    has_zero = torch.any(rfov == 0)
    if has_zero:
        start_idx = 1
        rfovs = rfov[1:]
        depths = depth[1:]
    else:
        start_idx = 0
        rfovs = rfov
        depths = depth

    if self.aper_idx == 0:
        if plane == "sagittal":
            chief_ray_o[start_idx:, ...] = torch.stack(
                [depths * torch.tan(rfovs), torch.zeros_like(rfovs), depths], dim=-1
            )
            chief_ray_d[start_idx:, ...] = torch.stack(
                [torch.sin(rfovs), torch.zeros_like(rfovs), torch.cos(rfovs)],
                dim=-1,
            )
        else:
            chief_ray_o[start_idx:, ...] = torch.stack(
                [torch.zeros_like(rfovs), depths * torch.tan(rfovs), depths], dim=-1
            )
            chief_ray_d[start_idx:, ...] = torch.stack(
                [torch.zeros_like(rfovs), torch.sin(rfovs), torch.cos(rfovs)],
                dim=-1,
            )

        return chief_ray_o, chief_ray_d

    # Scale factor
    pupilz, pupilr = self.calc_entrance_pupil()
    y_distance = torch.tan(rfovs) * (abs(depths) + pupilz)

    if ray_aiming:
        scale = 0.05
        min_delta = 0.05 * pupilr  # minimum search range based on pupil radius
        delta = torch.clamp(scale * y_distance, min=min_delta)

    if not ray_aiming:
        if plane == "sagittal":
            chief_ray_o[start_idx:, ...] = torch.stack(
                [-y_distance, torch.zeros_like(rfovs), depths], dim=-1
            )
            chief_ray_d[start_idx:, ...] = torch.stack(
                [torch.sin(rfovs), torch.zeros_like(rfovs), torch.cos(rfovs)],
                dim=-1,
            )
        else:
            chief_ray_o[start_idx:, ...] = torch.stack(
                [torch.zeros_like(rfovs), -y_distance, depths], dim=-1
            )
            chief_ray_d[start_idx:, ...] = torch.stack(
                [torch.zeros_like(rfovs), torch.sin(rfovs), torch.cos(rfovs)],
                dim=-1,
            )

    else:
        min_y = -y_distance - delta
        max_y = -y_distance + delta
        t = torch.linspace(0, 1, num_rays, device=min_y.device)
        o1_linspace = min_y.unsqueeze(-1) + t * (max_y - min_y).unsqueeze(-1)

        o1 = torch.zeros([len(rfovs), num_rays, 3], device=self.device)
        o1[:, :, 2] = depths.unsqueeze(-1)  # one object depth per field angle

        o2_linspace = -delta.unsqueeze(-1) + t * (2 * delta).unsqueeze(-1)

        o2 = torch.zeros([len(rfovs), num_rays, 3], device=self.device)
        o2[:, :, 2] = pupilz

        if plane == "sagittal":
            o1[:, :, 0] = o1_linspace
            o2[:, :, 0] = o2_linspace
        else:
            o1[:, :, 1] = o1_linspace
            o2[:, :, 1] = o2_linspace

        # Trace until the aperture
        ray = Ray(o1, o2 - o1, wvln=wvln, device=self.device)
        inc_ray = ray.clone()
        surf_range = range(0, self.aper_idx + 1)
        ray, _ = self.trace(ray, surf_range=surf_range)

        # Look for the ray that is closest to the optical axis
        if plane == "sagittal":
            _, center_idx = torch.min(torch.abs(ray.o[..., 0]), dim=1)
            chief_ray_o[start_idx:, ...] = inc_ray.o[
                torch.arange(len(rfovs)), center_idx.long(), ...
            ]
            chief_ray_d[start_idx:, ...] = torch.stack(
                [torch.sin(rfovs), torch.zeros_like(rfovs), torch.cos(rfovs)],
                dim=-1,
            )
        else:
            _, center_idx = torch.min(torch.abs(ray.o[..., 1]), dim=1)
            chief_ray_o[start_idx:, ...] = inc_ray.o[
                torch.arange(len(rfovs)), center_idx.long(), ...
            ]
            chief_ray_d[start_idx:, ...] = torch.stack(
                [torch.zeros_like(rfovs), torch.sin(rfovs), torch.cos(rfovs)],
                dim=-1,
            )

    return chief_ray_o, chief_ray_d
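In the ray_aiming=True branch, start points o1 and pupil points o2 are paired index by index, and their difference has the same transverse component (y_distance) for every pair. When depth ≤ 0 (object in front of the lens), the fan is therefore a bundle of parallel rays at the requested field angle that differ only in lateral offset. This is consistent with returning sin/cos of rfov as the direction. The NumPy sketch below shows that construction for a single meridional field angle. aiming_fan is hypothetical, the trace to the stop is omitted, and the pupil position and radius are made-up numbers.

```python
import numpy as np


def aiming_fan(rfov_deg, depth, pupil_z, pupil_r, num_rays=11, scale=0.05):
    """Return origins [num_rays, 3] and unit directions of a ray-aiming fan.

    Meridional plane only; plain floats stand in for the batched tensors.
    """
    rfov = np.deg2rad(rfov_deg)
    y_distance = np.tan(rfov) * (abs(depth) + pupil_z)  # paraxial aim-point offset
    delta = max(scale * y_distance, 0.05 * pupil_r)       # fan half-width with a floor
    t = np.linspace(0.0, 1.0, num_rays)

    # Start points on the object plane, bracketing -y_distance
    y_start = (-y_distance - delta) + t * (2 * delta)
    # Matching points on the entrance-pupil plane, bracketing the pupil center
    y_pupil = -delta + t * (2 * delta)

    o1 = np.stack([np.zeros_like(t), y_start, np.full_like(t, depth)], axis=-1)
    o2 = np.stack([np.zeros_like(t), y_pupil, np.full_like(t, pupil_z)], axis=-1)
    d = o2 - o1
    d /= np.linalg.norm(d, axis=-1, keepdims=True)
    return o1, d


o1, d = aiming_fan(rfov_deg=20.0, depth=-1000.0, pupil_z=5.0, pupil_r=2.0)
angles = np.rad2deg(np.arctan2(d[:, 1], d[:, 2]))
print(angles.min(), angles.max())  # every ray travels at the requested field angle
```

Searching a narrow window around the paraxial estimate keeps the fan dense where the chief ray actually lies. The floor on delta keeps the window from collapsing near the axis.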

analysis_rendering

analysis_rendering(img_org, save_name=None, depth=DEPTH, spp=SPP_RENDER, unwarp=False, method='ray_tracing', show=False)

Render a test image through the lens and report PSNR / SSIM.

Simulates what the sensor would capture if the given image were placed at the specified object distance. The rendering accounts for all geometric aberrations (blur, distortion, vignetting, chromatic effects). Optionally applies an inverse distortion warp (unwarp) and reports quality metrics for both the raw and unwarped renderings.

Algorithm
  1. Convert img_org to a [1, 3, H, W] float tensor and temporarily set the sensor resolution to match.
  2. Call self.render() with the chosen method (ray tracing or PSF convolution).
  3. Compute PSNR and SSIM between the original and rendered images.
  4. If unwarp=True, apply self.unwarp() to correct geometric distortion and report metrics again.
  5. Restore the original sensor resolution.

Parameters:

Name Type Description Default
img_org ndarray | Tensor

Source image with shape [H, W, 3], either uint8 [0, 255] or float [0, 1].

required
save_name str | None

Path prefix for saved PNGs. If not None, saves '{save_name}.png' and (if unwarped) '{save_name}_unwarped.png'. Defaults to None.

None
depth float

Object distance in mm. Defaults to DEPTH.

DEPTH
spp int

Samples (rays) per pixel for rendering. Defaults to SPP_RENDER.

SPP_RENDER
unwarp bool

If True, apply distortion correction after rendering. Defaults to False.

False
method str

Rendering backend — 'ray_tracing' or 'psf_conv'. Defaults to 'ray_tracing'.

'ray_tracing'
show bool

If True, display the result with matplotlib. Defaults to False.

False

Returns:

Type Description

torch.Tensor: Rendered (and optionally unwarped) image with shape [1, 3, H, W], float values in [0, 1].

Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def analysis_rendering(
    self,
    img_org,
    save_name=None,
    depth=DEPTH,
    spp=SPP_RENDER,
    unwarp=False,
    method="ray_tracing",
    show=False,
):
    """Render a test image through the lens and report PSNR / SSIM.

    Simulates what the sensor would capture if the given image were placed
    at the specified object distance.  The rendering accounts for all
    geometric aberrations (blur, distortion, vignetting, chromatic effects).
    Optionally applies an inverse distortion warp (``unwarp``) and reports
    quality metrics for both the raw and unwarped renderings.

    Algorithm:
        1. Convert ``img_org`` to a ``[1, 3, H, W]`` float tensor and
           temporarily set the sensor resolution to match.
        2. Call ``self.render()`` with the chosen method (ray tracing or PSF
           convolution).
        3. Compute PSNR and SSIM between the original and rendered images.
        4. If ``unwarp=True``, apply ``self.unwarp()`` to correct geometric
           distortion and report metrics again.
        5. Restore the original sensor resolution.

    Args:
        img_org (np.ndarray | torch.Tensor): Source image with shape
            ``[H, W, 3]``, either uint8 ``[0, 255]`` or float ``[0, 1]``.
        save_name (str | None): Path prefix for saved PNGs.  If not
            ``None``, saves ``'{save_name}.png'`` and (if unwarped)
            ``'{save_name}_unwarped.png'``. Defaults to ``None``.
        depth (float): Object distance in mm. Defaults to ``DEPTH``.
        spp (int): Samples (rays) per pixel for rendering.
            Defaults to ``SPP_RENDER``.
        unwarp (bool): If ``True``, apply distortion correction after
            rendering. Defaults to ``False``.
        method (str): Rendering backend — ``'ray_tracing'`` or
            ``'psf_conv'``. Defaults to ``'ray_tracing'``.
        show (bool): If ``True``, display the result with matplotlib.
            Defaults to ``False``.

    Returns:
        torch.Tensor: Rendered (and optionally unwarped) image with shape
            ``[1, 3, H, W]``, float values in ``[0, 1]``.
    """
    # Change sensor resolution to match the image
    sensor_res_original = self.sensor_res
    if isinstance(img_org, np.ndarray):
        img = torch.from_numpy(img_org).permute(2, 0, 1).unsqueeze(0).float()
        if img_org.dtype == np.uint8:
            img = img / 255.0  # uint8 [0, 255] -> float [0, 1]
    elif torch.is_tensor(img_org):
        img = img_org.permute(2, 0, 1).unsqueeze(0).float()
        if img.max() > 1.0:
            img = img / 255.0
    else:
        raise TypeError("img_org must be a numpy array or a torch tensor")
    img = img.to(self.device)
    self.set_sensor_res(sensor_res=img.shape[-2:])

    # Image rendering
    img_render = self.render(img, depth=depth, method=method, spp=spp)

    # Compute PSNR and SSIM
    img_np = img.squeeze(0).permute(1, 2, 0).cpu().numpy()
    render_np = img_render.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().detach().numpy()
    render_psnr = round(peak_signal_noise_ratio(img_np, render_np, data_range=1.0), 3)
    render_ssim = round(structural_similarity(img_np, render_np, channel_axis=2, data_range=1.0), 4)
    print(f"Rendered image: PSNR={render_psnr:.3f}, SSIM={render_ssim:.4f}")

    # Save image
    if save_name is not None:
        save_image(img_render, f"{save_name}.png")

    # Unwarp to correct geometry distortion
    if unwarp:
        img_render = self.unwarp(img_render, depth)

        # Compute PSNR and SSIM
        render_np = img_render.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().detach().numpy()
        render_psnr = round(peak_signal_noise_ratio(img_np, render_np, data_range=1.0), 3)
        render_ssim = round(structural_similarity(img_np, render_np, channel_axis=2, data_range=1.0), 4)
        print(
            f"Rendered image (unwarped): PSNR={render_psnr:.3f}, SSIM={render_ssim:.4f}"
        )

        if save_name is not None:
            save_image(img_render, f"{save_name}_unwarped.png")

    # Change the sensor resolution back
    self.set_sensor_res(sensor_res=sensor_res_original)

    # Show image
    if show:
        plt.imshow(img_render.cpu().squeeze(0).permute(1, 2, 0).numpy())
        plt.title("Rendered image")
        plt.axis("off")
        plt.show()
        plt.close()

    return img_render
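analysis_rendering() scores the rendering with scikit-image's peak_signal_noise_ratio (data_range=1.0) once both images are in [0, 1]. PSNR itself is a short formula. The NumPy sketch below reproduces it together with a dtype-based conversion to [0, 1]. Both helpers are hypothetical, and SSIM is left to structural_similarity, as in the source.

```python
import numpy as np


def to_unit_float(img: np.ndarray) -> np.ndarray:
    """Map an [H, W, 3] image to float64 in [0, 1]; integer images are divided by 255."""
    if np.issubdtype(img.dtype, np.integer):
        return img.astype(np.float64) / 255.0
    return img.astype(np.float64)


def psnr(reference: np.ndarray, test: np.ndarray, data_range: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB: 10 * log10(data_range**2 / MSE)."""
    mse = float(np.mean((reference - test) ** 2))
    if mse == 0.0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(data_range**2 / mse))


reference = to_unit_float(np.full((4, 4, 3), 128, dtype=np.uint8))
rendered = np.clip(reference + 0.1, 0.0, 1.0)  # uniform brightness error of 0.1
print(f"PSNR = {psnr(reference, rendered):.2f} dB")
```

A uniform error of 0.1 gives MSE = 0.01 and therefore 20 dB. The clamp to [0, 1] mirrors the clamp(0, 1) applied to the rendered tensor before scoring.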

analysis_spot

analysis_spot(num_field=3, depth=float('inf'))

Compute RMS and geometric spot radii at multiple field positions for RGB.

Traces rays at num_field evenly-spaced field positions along the meridional direction for three wavelengths (G, R, B), computes per-wavelength RMS and maximum (geometric) spot radii referenced to the green centroid, then averages the three wavelengths.

This provides a quick polychromatic spot-size summary for design comparisons. The results are always printed to stdout, including when the method is called from analysis().

Algorithm
  1. For each wavelength (G first, then R, B):
     a. self.sample_radial_rays() → rays of shape [num_field, SPP_PSF, 3].
     b. self.trace2sensor() → sensor-plane positions.
     c. Green centroid c_G is computed on the first iteration and used as the common reference for all wavelengths.
     d. RMS = sqrt(mean(||xy - c_G||^2)) over valid rays, per field position.
     e. radius = max(||xy - c_G||) per field position.
  2. Average RMS and radius over the three wavelengths.
  3. Convert from mm to μm (× 1000).

Parameters:

Name Type Description Default
num_field int

Number of field positions sampled from on-axis to full-field. Defaults to 3.

3
depth float

Object distance in mm. Use float('inf') for collimated light. Defaults to float('inf').

float('inf')

Returns:

Type Description

dict[str, dict[str, float]]: Spot analysis results keyed by field position string (e.g., 'fov0.0', 'fov0.5', 'fov1.0'). Each value is a dict with:
- 'rms': Polychromatic RMS spot radius in μm.
- 'radius': Polychromatic geometric spot radius in μm.

Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def analysis_spot(self, num_field=3, depth=float("inf")):
    """Compute RMS and geometric spot radii at multiple field positions for RGB.

    Traces rays at ``num_field`` evenly-spaced field positions along the
    meridional direction for three wavelengths (G, R, B), computes
    per-wavelength RMS and maximum (geometric) spot radii referenced to the
    **green centroid**, then averages the three wavelengths.

    This provides a quick polychromatic spot-size summary used for design
    comparisons and printed to stdout during ``analysis()``.

    Algorithm:
        1. For each wavelength (G first, then R, B):
           a. ``self.sample_radial_rays()`` → ``[num_field, SPP_PSF, 3]``.
           b. ``self.trace2sensor()`` → sensor-plane positions.
           c. Green centroid ``c_G`` is computed on the first iteration and
              used as the common reference for all wavelengths.
           d. ``RMS = sqrt(mean(||xy - c_G||^2))`` per field position.
           e. ``radius = max(||xy - c_G||)`` per field position.
        2. Average RMS and radius over the three wavelengths.
        3. Convert from mm to μm (× 1000).

    Args:
        num_field (int): Number of field positions sampled from on-axis
            to full-field. Defaults to 3.
        depth (float): Object distance in mm.  Use ``float('inf')`` for
            collimated light. Defaults to ``float('inf')``.

    Returns:
        dict[str, dict[str, float]]: Spot analysis results keyed by field
            position string (e.g., ``'fov0.0'``, ``'fov0.5'``, ``'fov1.0'``).
            Each value is a dict with:
                - ``'rms'``: Polychromatic RMS spot radius in μm.
                - ``'radius'``: Polychromatic geometric spot radius in μm.
    """
    rms_radius_fields = []
    geo_radius_fields = []
    for i, wvln in enumerate([WAVE_RGB[1], WAVE_RGB[0], WAVE_RGB[2]]):
        # Sample rays along meridional (y) direction, shape [num_field, num_rays, 3]
        ray = self.sample_radial_rays(
            num_field=num_field, depth=depth, num_rays=SPP_PSF, wvln=wvln
        )
        ray = self.trace2sensor(ray)

        # Green light point center for reference, shape [num_field, 1, 2]
        if i == 0:
            ray_xy_center_green = ray.centroid()[..., :2].unsqueeze(-2)

        # Calculate RMS spot size and radius for different FoVs
        ray_xy_norm = (
            ray.o[..., :2] - ray_xy_center_green
        ) * ray.is_valid.unsqueeze(-1)
        spot_rms = (
            ((ray_xy_norm**2).sum(-1) * ray.is_valid).sum(-1)
            / (ray.is_valid.sum(-1) + EPSILON)
        ).sqrt()
        spot_radius = (ray_xy_norm**2).sum(-1).sqrt().max(dim=-1).values

        # Append to list
        rms_radius_fields.append(spot_rms)
        geo_radius_fields.append(spot_radius)

    # Average over wavelengths, shape [num_field]
    avg_rms_radius_um = torch.stack(rms_radius_fields, dim=0).mean(dim=0) * 1000.0
    avg_geo_radius_um = torch.stack(geo_radius_fields, dim=0).mean(dim=0) * 1000.0

    # Print results
    print(f"Ray spot analysis results for depth {depth}:")
    print(
        f"RMS radius: FoV (0.0) {avg_rms_radius_um[0]:.3f} um, FoV (0.5) {avg_rms_radius_um[num_field // 2]:.3f} um, FoV (1.0) {avg_rms_radius_um[-1]:.3f} um"
    )
    print(
        f"Geo radius: FoV (0.0) {avg_geo_radius_um[0]:.3f} um, FoV (0.5) {avg_geo_radius_um[num_field // 2]:.3f} um, FoV (1.0) {avg_geo_radius_um[-1]:.3f} um"
    )

    # Save to dict
    rms_results = {}
    fov_ls = torch.linspace(0, 1, num_field)
    for i in range(num_field):
        fov = round(fov_ls[i].item(), 2)
        rms_results[f"fov{fov}"] = {
            "rms": round(avg_rms_radius_um[i].item(), 4),
            "radius": round(avg_geo_radius_um[i].item(), 4),
        }

    return rms_results
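analysis_spot() masks out vignetted rays, measures every wavelength against the green centroid, and averages the per-wavelength radii. The NumPy sketch below mirrors those steps for one field position and two wavelengths, so lateral color shows up as a larger red spot. spot_metrics is a hypothetical helper. The 1e-9 term plays the role of EPSILON, whose actual value is not shown here.

```python
import numpy as np


def spot_metrics(xy, valid, center):
    """RMS and geometric spot radius [um] per field, referenced to `center`.

    xy: [num_field, num_rays, 2] sensor hits [mm]; valid: [num_field, num_rays]
    bool mask; center: [num_field, 2] reference point [mm].
    """
    w = valid.astype(xy.dtype)
    d = (xy - center[:, None, :]) * w[..., None]  # zero out invalid rays
    r2 = (d**2).sum(-1)
    rms = np.sqrt((r2 * w).sum(-1) / (w.sum(-1) + 1e-9))
    radius = np.sqrt(r2).max(-1)
    return rms * 1000.0, radius * 1000.0  # mm -> um


# One field position; the last ray is vignetted (invalid)
green = np.array([[[0.001, 0.0], [-0.001, 0.0], [0.0, 0.001], [0.0, -0.001], [10.0, 10.0]]])
valid = np.array([[True, True, True, True, False]])
red = green + np.array([0.002, 0.0])  # red spot shifted 2 um in x (lateral color)

w = valid.astype(float)
center_green = (green * w[..., None]).sum(1) / w.sum(1, keepdims=True)

rms_g, rad_g = spot_metrics(green, valid, center_green)
rms_r, rad_r = spot_metrics(red, valid, center_green)  # same green reference
print("avg RMS [um]:", (rms_g + rms_r) / 2, "avg radius [um]:", (rad_g + rad_r) / 2)
```

Referencing every wavelength to the green centroid folds lateral color into the polychromatic radius. Referencing each color to its own centroid would hide that shift.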

analysis

analysis(save_name='./lens', depth=float('inf'), full_eval=False, render=False, render_unwarp=False, lens_title=None, show=False)

Run a comprehensive optical analysis pipeline for the lens.

This is the main entry point for evaluating a lens design. It chains multiple evaluation steps in order, saving all plots with a common save_name prefix.

Execution flow
  1. Always: draw the lens layout (draw_layout) and compute polychromatic spot RMS/radius (analysis_spot).
  2. If full_eval=True: additionally generate:
     - Spot diagram (draw_spot_radial).
     - MTF grid (draw_mtf).
     - Distortion curve (draw_distortion_radial).
     - Field curvature plot (draw_field_curvature).
     - Vignetting map (draw_vignetting).
  3. If render=True: render a test chart image through the lens and report PSNR/SSIM (analysis_rendering).

Parameters:

Name Type Description Default
save_name str

Path prefix for all output files. Each plot appends a suffix (e.g., '_spot.png', '_mtf.png'). Defaults to './lens'.

'./lens'
depth float

Object distance in mm. float('inf') is replaced by DEPTH for the MTF, vignetting, and rendering steps. Defaults to float('inf').

float('inf')
full_eval bool

If True, run all evaluation plots. If False, only layout + spot RMS. Defaults to False.

False
render bool

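If True, render a test image through the lens. The test chart is loaded from ./datasets/charts/NBS_1963_1k.png, relative to the current working directory. Defaults to False.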

False
render_unwarp bool

If True (and render=True), also produce an unwarped rendering. Defaults to False.

False
lens_title str | None

Title string for the layout plot. Defaults to None.

None
show bool

If True, display all plots interactively. Defaults to False.

False
Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def analysis(
    self,
    save_name="./lens",
    depth=float("inf"),
    full_eval=False,
    render=False,
    render_unwarp=False,
    lens_title=None,
    show=False,
):
    """Run a comprehensive optical analysis pipeline for the lens.

    This is the main entry point for evaluating a lens design.  It chains
    multiple evaluation steps in order, saving all plots with a common
    ``save_name`` prefix.

    Execution flow:
        1. **Always**: draw the lens layout (``draw_layout``) and compute
           polychromatic spot RMS/radius (``analysis_spot``).
        2. **If** ``full_eval=True``: additionally generate:
           - Spot diagram (``draw_spot_radial``).
           - MTF grid (``draw_mtf``).
           - Distortion curve (``draw_distortion_radial``).
           - Field curvature plot (``draw_field_curvature``).
           - Vignetting map (``draw_vignetting``).
        3. **If** ``render=True``: render a test chart image through the
           lens and report PSNR/SSIM (``analysis_rendering``).

    Args:
        save_name (str): Path prefix for all output files.  Each plot
            appends a suffix (e.g., ``'_spot.png'``, ``'_mtf.png'``).
            Defaults to ``'./lens'``.
        depth (float): Object distance in mm.  ``float('inf')`` is replaced
            by ``DEPTH`` for the MTF, vignetting, and rendering steps.
            Defaults to ``float('inf')``.
        full_eval (bool): If ``True``, run all evaluation plots.  If
            ``False``, only layout + spot RMS. Defaults to ``False``.
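        render (bool): If ``True``, render a test image through the lens.
            The chart is read from ``./datasets/charts/NBS_1963_1k.png``,
            relative to the current working directory. Defaults to ``False``.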
        render_unwarp (bool): If ``True`` (and ``render=True``), also
            produce an unwarped rendering. Defaults to ``False``.
        lens_title (str | None): Title string for the layout plot.
            Defaults to ``None``.
        show (bool): If ``True``, display all plots interactively.
            Defaults to ``False``.
    """
    # Draw lens layout and ray path
    self.draw_layout(
        filename=f"{save_name}.png",
        lens_title=lens_title,
        depth=depth,
        show=show,
    )

    # Calculate RMS error
    self.analysis_spot(depth=depth)

    # Comprehensive optical evaluation
    if full_eval:
        # Draw spot diagram
        self.draw_spot_radial(
            save_name=f"{save_name}_spot.png",
            depth=depth,
            show=show,
        )

        # Draw MTF (an infinite depth is evaluated at DEPTH)
        mtf_depth = DEPTH if depth == float("inf") else depth
        self.draw_mtf(
            depth_list=[mtf_depth],
            save_name=f"{save_name}_mtf.png",
            show=show,
        )

        # Draw distortion
        self.draw_distortion_radial(
            save_name=f"{save_name}_distortion.png",
            show=show,
        )

        # Draw field curvature
        self.draw_field_curvature(
            save_name=f"{save_name}_field_curvature.png",
            show=show,
        )

        # Draw vignetting
        eval_depth = DEPTH if depth == float("inf") else depth
        self.draw_vignetting(
            filename=f"{save_name}_vignetting.png",
            depth=eval_depth,
            show=show,
        )

    # Render an image, compute PSNR and SSIM
    if render:
        depth = DEPTH if depth == float("inf") else depth
        img_org = Image.open("./datasets/charts/NBS_1963_1k.png").convert("RGB")
        img_org = np.array(img_org)
        self.analysis_rendering(
            img_org,
            depth=depth,
            spp=SPP_RENDER,
            unwarp=render_unwarp,
            save_name=f"{save_name}_render",
            show=show,
        )

deeplens.optics.geolens_pkg.eval_seidel.GeoLensSeidel

Mixin for Seidel (third-order) aberration analysis.

seidel_coefficients

seidel_coefficients(wvln: float = WVLN_d, include_chromatic: bool = True) -> Dict

Compute per-surface Seidel (third-order) aberration coefficients.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| wvln | float | Reference wavelength in µm (default: d-line, 0.5876 µm). | WVLN_d |
| include_chromatic | bool | If True, also compute longitudinal and transverse chromatic aberration (C_L, C_T). | True |

Returns:

| Type | Description |
|---|---|
| Dict | Dict with keys S1..S5 (per-surface lists of Seidel sums, mm), CL and CT (per-surface chromatic aberrations, mm; left at zero when include_chromatic is False), labels (surface labels, e.g. ["S1", "S2", ...]), and sums (dict of system totals for each aberration). |

Source code in deeplens/optics/geolens_pkg/eval_seidel.py
@torch.no_grad()
def seidel_coefficients(
    self,
    wvln: float = WVLN_d,
    include_chromatic: bool = True,
) -> Dict:
    """Compute per-surface Seidel (third-order) aberration coefficients.

    Args:
        wvln: Reference wavelength in µm (default: d-line 0.5876 µm).
        include_chromatic: If True, also compute longitudinal and
            transverse chromatic aberration (C_L, C_T).

    Returns:
        Dict with keys:
            S1..S5 — per-surface lists of Seidel sums [mm]
            CL, CT — per-surface chromatic aberrations [mm]
            labels — surface labels (e.g. ["S1", "S2", ...])
            sums   — dict of system totals for each aberration
    """
    tr = self._paraxial_trace(wvln)
    y = tr["y"]
    u = tr["u"]
    u_aft = tr["u_after"]
    yb = tr["ybar"]
    ub = tr["ubar"]
    ub_aft = tr["ubar_after"]
    n = tr["n"]
    np_ = tr["np"]
    c = tr["c"]
    surf_indices = tr["surf_indices"]
    num = len(y)

    # Lagrange invariant: H = n * (y_bar * u - y * u_bar)
    # Compute at first surface
    H = n[0] * (yb[0] * u[0] - y[0] * ub[0])

    S1 = [0.0] * num  # Spherical
    S2 = [0.0] * num  # Coma
    S3 = [0.0] * num  # Astigmatism
    S4 = [0.0] * num  # Petzval
    S5 = [0.0] * num  # Distortion
    CL = [0.0] * num  # Longitudinal chromatic
    CT = [0.0] * num  # Transverse chromatic

    wvln_t = torch.tensor([wvln])
    wvln_F_t = torch.tensor([WVLN_F])
    wvln_C_t = torch.tensor([WVLN_C])

    mat_before = Material("air")

    for j in range(num):
        si = surf_indices[j]
        surf = self.surfaces[si]

        # Refraction invariant A = n*(u + y*c), Abar = n*(ubar + ybar*c)
        A = n[j] * (u[j] + y[j] * c[j])
        Abar = n[j] * (ub[j] + yb[j] * c[j])

        # Delta(u/n) = u'/n' - u/n
        delta_u_over_n = u_aft[j] / np_[j] - u[j] / n[j]

        # Delta(1/n) = 1/n' - 1/n
        delta_inv_n = 1.0 / np_[j] - 1.0 / n[j]

        # --- Spherical surface contributions ---
        S1[j] = -A * A * y[j] * delta_u_over_n
        S2[j] = -A * Abar * y[j] * delta_u_over_n
        S3[j] = -Abar * Abar * y[j] * delta_u_over_n
        S4[j] = -H * H * c[j] * delta_inv_n
        # S5 = (Abar/A) * (S3 + S4), guarding A ≈ 0
        if abs(A) > 1e-12:
            S5[j] = (Abar / A) * (S3[j] + S4[j])
        else:
            S5[j] = 0.0

        # --- Aspheric correction ---
        if isinstance(surf, Aspheric):
            k_val = float(surf.k)
            c_val = c[j]
            # Fourth-order deformation: b4 = k*c^3/8 + a4
            a4 = 0.0
            if surf.ai is not None and len(surf.ai) > 0:
                a4 = float(surf.ai[0])
            b4 = k_val * c_val**3 / 8.0 + a4

            dn = np_[j] - n[j]
            y4 = y[j] ** 4

            dS1 = -8.0 * dn * y4 * b4
            S1[j] += dS1

            if abs(y[j]) > 1e-12:
                ratio = yb[j] / y[j]
                dS2 = -ratio * dS1
                dS3 = -(ratio**2) * dS1
                dS5 = -(ratio**3) * dS1
                S2[j] += dS2
                S3[j] += dS3
                S5[j] += dS5

        # --- Chromatic aberration ---
        if include_chromatic:
            n_F = float(mat_before.ior(wvln_F_t))
            n_C = float(mat_before.ior(wvln_C_t))
            np_F = float(surf.mat2.ior(wvln_F_t))
            np_C = float(surf.mat2.ior(wvln_C_t))

            delta_n = n_F - n_C
            delta_np = np_F - np_C

            # Δ(δn / n_d) = δn'/n'_d - δn/n_d
            delta_dn_over_nd = delta_np / np_[j] - delta_n / n[j]

            CL[j] = -y[j] * A * delta_dn_over_nd
            CT[j] = -y[j] * Abar * delta_dn_over_nd

        mat_before = surf.mat2

    # Labels
    labels = [f"S{si + 1}" for si in surf_indices]

    # System sums
    sums = {
        "S1": sum(S1),
        "S2": sum(S2),
        "S3": sum(S3),
        "S4": sum(S4),
        "S5": sum(S5),
        "CL": sum(CL),
        "CT": sum(CT),
    }

    result = {
        "S1": S1,
        "S2": S2,
        "S3": S3,
        "S4": S4,
        "S5": S5,
        "CL": CL,
        "CT": CT,
        "labels": labels,
        "sums": sums,
    }

    logger.info(
        "Seidel sums: S1=%.4f S2=%.4f S3=%.4f S4=%.4f S5=%.4f CL=%.4f CT=%.4f",
        sums["S1"], sums["S2"], sums["S3"], sums["S4"], sums["S5"],
        sums["CL"], sums["CT"],
    )

    return result
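To make the per-surface formulas concrete, the sketch below evaluates the spherical-surface terms S1..S5 for a single air-to-glass surface in plain Python, independent of DeepLens. It assumes the paraxial trace conserves the refraction invariant, n'(u' + yc) = n(u + yc), and reuses the sign conventions of the source; the helper names paraxial_refract and seidel_surface and the ray data are hypothetical, and the aspheric and chromatic terms are omitted.

```python
# Standalone sketch (not DeepLens code): spherical-surface Seidel terms for one
# refracting surface, using the formulas and sign conventions of seidel_coefficients.
from typing import Dict


def paraxial_refract(n: float, n_p: float, y: float, u: float, c: float) -> float:
    """Slope after refraction, assuming n'(u' + y*c) = n*(u + y*c) (A is conserved)."""
    return (n * (u + y * c) - n_p * y * c) / n_p


def seidel_surface(c: float, n: float, n_p: float,
                   y: float, u: float, yb: float, ub: float) -> Dict[str, float]:
    """S1..S5 for a spherical surface (no aspheric or chromatic terms)."""
    u_p = paraxial_refract(n, n_p, y, u, c)   # marginal-ray slope after the surface
    H = n * (yb * u - y * ub)                 # Lagrange invariant (source computes it once, at surface 1)
    A = n * (u + y * c)                       # refraction invariant, marginal ray
    Abar = n * (ub + yb * c)                  # refraction invariant, chief ray
    d_u_over_n = u_p / n_p - u / n            # Delta(u/n)
    d_inv_n = 1.0 / n_p - 1.0 / n             # Delta(1/n)

    S1 = -A * A * y * d_u_over_n              # spherical
    S2 = -A * Abar * y * d_u_over_n           # coma
    S3 = -Abar * Abar * y * d_u_over_n        # astigmatism
    S4 = -H * H * c * d_inv_n                 # Petzval
    S5 = (Abar / A) * (S3 + S4) if abs(A) > 1e-12 else 0.0  # distortion, guarded as in the source
    return {"S1": S1, "S2": S2, "S3": S3, "S4": S4, "S5": S5}


# Hypothetical ray data: air-to-glass surface with R = 50 mm, object at infinity
# (u = 0), marginal ray height 5 mm, stop at the surface (chief height 0, slope 0.1).
terms = seidel_surface(c=1 / 50, n=1.0, n_p=1.5, y=5.0, u=0.0, yb=0.0, ub=0.1)
```

These inputs give S1 = S2 = S3 = 1/900, S4 = 1/600, and S5 = 1/360 (mm); A equals Abar here only because the stop sits at the surface. More generally S2² = S1·S3 for the spherical terms, since all three share the factor -y·Δ(u/n).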

aberration_histogram

aberration_histogram(wvln: float = WVLN_d, save_name: Optional[str] = None, show: bool = False, include_chromatic: bool = True) -> Dict

Draw a Zemax-style Seidel aberration bar chart.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| wvln | float | Reference wavelength in µm. | WVLN_d |
| save_name | Optional[str] | Path to save the figure. If None, the figure is saved to "./seidel_aberration.png". Ignored when show is True. | None |
| show | bool | If True, call plt.show() instead of saving the figure. | False |
| include_chromatic | bool | Include C_L and C_T bars. | True |

Returns:

| Type | Description |
|---|---|
| Dict | The Seidel coefficients dict (same as seidel_coefficients). |

Source code in deeplens/optics/geolens_pkg/eval_seidel.py
@torch.no_grad()
def aberration_histogram(
    self,
    wvln: float = WVLN_d,
    save_name: Optional[str] = None,
    show: bool = False,
    include_chromatic: bool = True,
) -> Dict:
    """Draw a Zemax-style Seidel aberration bar chart.

    Args:
        wvln: Reference wavelength in µm.
        save_name: Path to save the figure. Defaults to
            ``"./seidel_aberration.png"``.
        show: If True, call ``plt.show()`` instead of saving.
        include_chromatic: Include C_L and C_T bars.

    Returns:
        The Seidel coefficients dict (same as ``seidel_coefficients``).
    """
    coeffs = self.seidel_coefficients(wvln=wvln, include_chromatic=include_chromatic)

    labels = coeffs["labels"]
    sums = coeffs["sums"]

    # Aberration keys and display config
    if include_chromatic:
        ab_keys = ["S1", "S2", "S3", "S4", "S5", "CL", "CT"]
        ab_names = [
            "S_I (Spherical)",
            "S_II (Coma)",
            "S_III (Astigmatism)",
            "S_IV (Petzval)",
            "S_V (Distortion)",
            "C_L (Axial Color)",
            "C_T (Lateral Color)",
        ]
        colors = ["#1f77b4", "#2ca02c", "#d62728", "#17becf", "#9467bd", "#bcbd22", "#ff7f0e"]
    else:
        ab_keys = ["S1", "S2", "S3", "S4", "S5"]
        ab_names = [
            "S_I (Spherical)",
            "S_II (Coma)",
            "S_III (Astigmatism)",
            "S_IV (Petzval)",
            "S_V (Distortion)",
        ]
        colors = ["#1f77b4", "#2ca02c", "#d62728", "#17becf", "#9467bd"]

    n_ab = len(ab_keys)
    n_surf = len(labels)
    x_labels = labels + ["SUM"]
    n_groups = n_surf + 1  # surfaces + SUM

    x = np.arange(n_groups)
    bar_width = 0.8 / n_ab

    fig, ax = plt.subplots(figsize=(max(8, n_groups * 0.8 + 2), 5))

    for k, (key, name, color) in enumerate(zip(ab_keys, ab_names, colors)):
        vals = coeffs[key] + [sums[key]]
        offset = (k - n_ab / 2.0 + 0.5) * bar_width
        ax.bar(x + offset, vals, bar_width, label=name, color=color, edgecolor="white", linewidth=0.5)

    ax.set_xlabel("Surface")
    ax.set_ylabel("Aberration Coefficient [mm]")
    ax.set_title("Seidel Aberration Diagram")
    ax.set_xticks(x)
    ax.set_xticklabels(x_labels, rotation=45, ha="right")
    ax.legend(fontsize=7, loc="best")
    ax.axhline(y=0, color="black", linewidth=0.5)
    ax.grid(axis="y", alpha=0.3)

    plt.tight_layout()

    if show:
        plt.show()
    else:
        if save_name is None:
            save_name = "./seidel_aberration.png"
        plt.savefig(save_name, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

    return coeffs
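The grouped-bar layout in aberration_histogram places n_ab bars side by side inside a 0.8-wide slot per surface and shifts bar k by (k - n_ab/2 + 0.5) * bar_width, so each group is centred on its tick and the "SUM" group is appended last. A minimal stdlib-only sketch of that offset arithmetic (the helper name bar_offsets and the labels are hypothetical):

```python
from typing import List, Tuple


def bar_offsets(n_series: int, group_width: float = 0.8) -> Tuple[List[float], float]:
    """Offsets that centre n_series side-by-side bars on each x tick."""
    bar_width = group_width / n_series
    offsets = [(k - n_series / 2.0 + 0.5) * bar_width for k in range(n_series)]
    return offsets, bar_width


labels = ["S1", "S2", "S3"]        # hypothetical surface labels
x_labels = labels + ["SUM"]        # the system totals form the last group
offsets, width = bar_offsets(7)    # 7 series when chromatic terms are included
```

With an odd number of series the middle bar sits exactly on the tick; with an even number the two central bars straddle it.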

deeplens.optics.geolens_pkg.optim.GeoLensOptim

Mixin providing differentiable optimisation for GeoLens.

Implements gradient-based lens design using PyTorch autograd:

  • Loss functions – RMS spot error, focus, surface regularity, gap constraints, material validity.
  • Constraint initialisation – edge-thickness and self-intersection guards.
  • Optimizer helpers – parameter groups with per-type learning rates and cosine annealing schedules.
  • High-level optimize() – curriculum-learning training loop.

This class is not instantiated directly; it is mixed into GeoLens (deeplens.optics.geolens.GeoLens).

References

Xinge Yang et al., "Curriculum learning for ab initio deep learned refractive optics," Nature Communications 2024.

init_constraints

init_constraints(constraint_params=None)

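Initialize the design-constraint thresholds read by the regularization losses.

Two hard-coded presets are applied depending on the sensor radius. If r_sensor is below 12.0, the lens is treated as a cellphone lens (is_cellphone = True) and tighter bounds are used; otherwise a looser preset for larger lenses is applied. The thresholds are stored as attributes: lower bounds on air gaps, glass thickness, and BFL (read by loss_intersec), upper bounds on air gaps, glass thickness, BFL, and TTL (loss_thickness), surface-shape limits (loss_surface), and chief-ray-angle and obliquity limits (loss_ray_angle).

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| constraint_params | dict | Constraint parameters. Reserved for future use: the thresholds currently come only from the presets. If None, a notice is printed. | None |
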
Source code in deeplens/optics/geolens_pkg/optim.py
def init_constraints(self, constraint_params=None):
    """Initialize constraints for the lens design.

    Args:
        constraint_params (dict): Constraint parameters.
    """
    # In the future, we want to use constraint_params to set the constraints.
    if constraint_params is None:
        constraint_params = {}
        print("Lens design constraints initialized with default values.")

    if self.r_sensor < 12.0:
        self.is_cellphone = True

        # Self intersection lower bounds
        self.air_min_edge = 0.05
        self.air_min_center = 0.05
        self.thick_min_edge = 0.25
        self.thick_min_center = 0.4
        self.bfl_min = 0.8

        # Air gap and thickness upper bounds
        self.air_max_edge = 3.0
        self.air_max_center = 1.5
        self.thick_max_edge = 2.0
        self.thick_max_center = 3.0
        self.bfl_max = 3.0
        self.ttl_max = 15.0

        # Surface shape constraints
        self.sag2diam_max = 0.1
        self.grad_max = 0.57 # tan(30deg)
        self.diam2thick_max = 15.0
        self.tmax2tmin_max = 5.0

        # Ray angle constraints
        self.chief_ray_angle_max = 30.0 # deg
        self.obliq_min = 0.6

    else:
        self.is_cellphone = False

        # Self-intersection lower bounds
        self.air_min_edge = 0.1
        self.air_min_center = 0.1
        self.thick_min_edge = 1.0
        self.thick_min_center = 2.0
        self.bfl_min = 5.0

        # Air gap and thickness upper bounds
        self.air_max_edge = 100.0  # float("inf")
        self.air_max_center = 100.0  # float("inf")
        self.thick_max_edge = 20.0
        self.thick_max_center = 20.0
        self.bfl_max = 100.0  # float("inf")
        self.ttl_max = 300.0

        # Surface shape constraints
        self.sag2diam_max = 0.2
        self.grad_max = 0.84 # tan(40deg)
        self.diam2thick_max = 20.0
        self.tmax2tmin_max = 10.0

        # Ray angle constraints
        self.chief_ray_angle_max = 40.0 # deg
        self.obliq_min = 0.4
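The preset selection in init_constraints reduces to a single threshold test on the sensor radius. A stdlib-only sketch that mirrors the hard-coded values (the function pick_constraints and the two preset dict names are hypothetical, not DeepLens API):

```python
from typing import Dict, Tuple

# Hypothetical mirror of the two presets hard-coded in init_constraints.
CELLPHONE_PRESET = {
    "air_min_edge": 0.05, "air_min_center": 0.05,
    "thick_min_edge": 0.25, "thick_min_center": 0.4, "bfl_min": 0.8,
    "air_max_edge": 3.0, "air_max_center": 1.5,
    "thick_max_edge": 2.0, "thick_max_center": 3.0,
    "bfl_max": 3.0, "ttl_max": 15.0,
    "sag2diam_max": 0.1, "grad_max": 0.57,
    "diam2thick_max": 15.0, "tmax2tmin_max": 5.0,
    "chief_ray_angle_max": 30.0, "obliq_min": 0.6,
}
LARGE_LENS_PRESET = {
    "air_min_edge": 0.1, "air_min_center": 0.1,
    "thick_min_edge": 1.0, "thick_min_center": 2.0, "bfl_min": 5.0,
    "air_max_edge": 100.0, "air_max_center": 100.0,
    "thick_max_edge": 20.0, "thick_max_center": 20.0,
    "bfl_max": 100.0, "ttl_max": 300.0,
    "sag2diam_max": 0.2, "grad_max": 0.84,
    "diam2thick_max": 20.0, "tmax2tmin_max": 10.0,
    "chief_ray_angle_max": 40.0, "obliq_min": 0.4,
}


def pick_constraints(r_sensor: float) -> Tuple[bool, Dict[str, float]]:
    """Return (is_cellphone, thresholds) using the same r_sensor < 12.0 test."""
    is_cellphone = r_sensor < 12.0
    preset = CELLPHONE_PRESET if is_cellphone else LARGE_LENS_PRESET
    return is_cellphone, dict(preset)  # copy so callers cannot mutate the preset
```

The comparison is strict, so a sensor radius of exactly 12.0 already selects the larger-lens preset.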

loss_reg

loss_reg(w_focus=10.0, w_ray_angle=2.0, w_intersec=1.0, w_thickness=0.1, w_surf=1.0)

Compute combined regularization loss for lens design.

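Sums the weighted self-intersection, thickness, surface-shape, and chief-ray-angle losses to keep the lens physically valid during gradient-based optimisation. The focus and material losses are not currently included.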

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| w_focus | float | Weight for the focus loss. Currently unused: the focus term is commented out in the source. | 10.0 |
| w_ray_angle | float | Weight for the chief-ray-angle loss. | 2.0 |
| w_intersec | float | Weight for the self-intersection loss. | 1.0 |
| w_thickness | float | Weight for the thickness / TTL loss. | 0.1 |
| w_surf | float | Weight for the surface-shape loss. | 1.0 |

Returns:

| Type | Description |
|---|---|
| tuple | (loss_reg, loss_dict), where loss_reg (Tensor) is the scalar combined regularization loss and loss_dict (dict) holds per-component loss values for logging. |

Source code in deeplens/optics/geolens_pkg/optim.py
def loss_reg(self, w_focus=10.0, w_ray_angle=2.0, w_intersec=1.0, w_thickness=0.1, w_surf=1.0):
    """Compute combined regularization loss for lens design.

    Aggregates multiple constraint losses to keep the lens physically valid
    during gradient-based optimisation.

    Args:
        w_focus (float, optional): Weight for focus loss. Defaults to 10.0.
        w_ray_angle (float, optional): Weight for chief ray angle loss. Defaults to 2.0.
        w_intersec (float, optional): Weight for self-intersection loss. Defaults to 1.0.
        w_thickness (float, optional): Weight for thickness / TTL loss. Defaults to 0.1.
        w_surf (float, optional): Weight for surface shape loss. Defaults to 1.0.

    Returns:
        tuple: (loss_reg, loss_dict) where:
            - loss_reg (Tensor): Scalar combined regularization loss.
            - loss_dict (dict): Per-component loss values for logging.
    """
    # Loss functions for regularization
    # loss_focus = self.loss_infocus()
    loss_ray_angle = self.loss_ray_angle()
    loss_intersec = self.loss_intersec()
    loss_thickness = self.loss_thickness()
    loss_surf = self.loss_surface()
    # loss_mat = self.loss_mat()
    loss_reg = (
        # w_focus * loss_focus
        + w_intersec * loss_intersec
        + w_thickness * loss_thickness
        + w_surf * loss_surf
        + w_ray_angle * loss_ray_angle
        # + loss_mat
    )

    # Return loss and loss dictionary
    loss_dict = {
        # "loss_focus": loss_focus.item(),
        "loss_intersec": loss_intersec.item(),
        "loss_thickness": loss_thickness.item(),
        "loss_surf": loss_surf.item(),
        'loss_ray_angle': loss_ray_angle.item(),
        # 'loss_mat': loss_mat.item(),
    }
    return loss_reg, loss_dict
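Because loss_intersec and loss_ray_angle are built by negating sums or means of violating quantities, the combined regularization loss can be negative. A minimal sketch of the weighting with the default weights, where plain floats stand in for tensors (combine_reg_losses and the term values are hypothetical):

```python
from typing import Dict, Optional, Tuple

DEFAULT_WEIGHTS = {"intersec": 1.0, "thickness": 0.1, "surf": 1.0, "ray_angle": 2.0}


def combine_reg_losses(terms: Dict[str, float],
                       weights: Optional[Dict[str, float]] = None) -> Tuple[float, Dict[str, float]]:
    """Weighted sum of the active regularization terms, plus a logging dict."""
    weights = DEFAULT_WEIGHTS if weights is None else weights
    total = sum(weights[name] * value for name, value in terms.items())
    log = {f"loss_{name}": value for name, value in terms.items()}
    return total, log


# Hypothetical term values: a 0.03 mm gap below air_min_center gives -0.03, and the
# ray-angle term is a negated mean cosine/obliquity, so it is negative as well.
terms = {"intersec": -0.03, "thickness": 3.0, "surf": 0.0, "ray_angle": -0.9}
total, log = combine_reg_losses(terms)
```

A negative total is therefore not a sign of a bug; only the gradients of the individual terms matter to the optimizer.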

loss_infocus

loss_infocus(target=0.005)

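Trace on-axis parallel rays at the green wavelength (WAVE_RGB[1]) to the sensor and use their RMS spot error as a focus loss: the loss equals the RMS error when it exceeds target and is zero otherwise.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| target | float | RMS error threshold in mm; no loss is applied at or below it. | 0.005 |
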
Source code in deeplens/optics/geolens_pkg/optim.py
def loss_infocus(self, target=0.005):
    """Sample parallel rays and compute RMS loss on the sensor plane, minimize focus loss.

    Args:
        target (float, optional): target of RMS loss. Defaults to 0.005 [mm].
    """
    loss = torch.tensor(0.0, device=self.device)

    # Ray tracing and calculate RMS error
    ray = self.sample_from_fov(fov_x=0.0, fov_y=0.0, wvln=WAVE_RGB[1], num_rays=SPP_CALC)
    ray = self.trace2sensor(ray)
    rms_error = ray.rms_error()

    # If RMS error is larger than target, add it to loss
    if rms_error > target:
        loss += rms_error

    return loss
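loss_infocus is a hinge on the on-axis RMS spot size. The sketch below reproduces that logic on a list of 2D sensor hits, assuming ray.rms_error() is the RMS distance of the hits from their centroid; that definition is an assumption, and the actual Ray method may weight or mask rays differently. The helper names are hypothetical.

```python
import math
from typing import List, Tuple

Point = Tuple[float, float]


def rms_about_centroid(points: List[Point]) -> float:
    """RMS distance of 2D sensor hits from their centroid (assumed definition)."""
    cx = sum(x for x, _ in points) / len(points)
    cy = sum(y for _, y in points) / len(points)
    return math.sqrt(sum((x - cx) ** 2 + (y - cy) ** 2 for x, y in points) / len(points))


def focus_hinge(points: List[Point], target: float = 0.005) -> float:
    """Return the RMS error if it exceeds target [mm], else 0 (as in loss_infocus)."""
    rms = rms_about_centroid(points)
    return rms if rms > target else 0.0


blurred = [(0.01, 0.0), (-0.01, 0.0), (0.0, 0.01), (0.0, -0.01)]      # RMS 0.01 mm
sharp = [(0.001, 0.0), (-0.001, 0.0), (0.0, 0.001), (0.0, -0.001)]    # RMS 0.001 mm
```

Below the threshold the loss is a constant zero, so it contributes no gradient once the axial spot is small enough.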

loss_surface

loss_surface()

Penalize extreme surface shapes that are difficult to manufacture.

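Checks four constraints for each optimisable surface:
  1. Sag-to-diameter ratio exceeding sag2diam_max.
  2. Maximum surface gradient exceeding grad_max.
  3. Diameter-to-thickness ratio exceeding diam2thick_max.
  4. Maximum-to-minimum thickness ratio exceeding tmax2tmin_max.

Checks 3 and 4 apply only to glass elements (surfaces whose mat2 is not air). Each violated check adds the offending ratio itself, not just the excess, to the loss.

Returns:

| Type | Description |
|---|---|
| Tensor | Scalar surface shape penalty loss. |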

Source code in deeplens/optics/geolens_pkg/optim.py
def loss_surface(self):
    """Penalize extreme surface shapes that are difficult to manufacture.

    Checks four constraints for each optimisable surface:
        1. Sag-to-diameter ratio exceeding ``sag2diam_max``.
        2. Maximum surface gradient exceeding ``grad_max``.
        3. Diameter-to-thickness ratio exceeding ``diam2thick_max``.
        4. Maximum-to-minimum thickness ratio exceeding ``tmax2tmin_max``.

    Returns:
        Tensor: Scalar surface shape penalty loss.
    """
    sag2diam_max = self.sag2diam_max
    grad_max_allowed = self.grad_max
    diam2thick_max = self.diam2thick_max
    tmax2tmin_max = self.tmax2tmin_max

    loss_grad = torch.tensor(0.0, device=self.device)
    loss_diam2thick = torch.tensor(0.0, device=self.device)
    loss_tmax2tmin = torch.tensor(0.0, device=self.device)
    loss_sag2diam = torch.tensor(0.0, device=self.device)
    for i in self.find_diff_surf():
        # Sample points on the surface
        x_ls = torch.linspace(0.0, 1.0, 32, device=self.device) * self.surfaces[i].r
        y_ls = torch.zeros_like(x_ls)

        # Sag
        sag_ls = self.surfaces[i].sag(x_ls, y_ls)
        sag2diam = sag_ls.abs().max() / self.surfaces[i].r / 2
        if sag2diam > sag2diam_max:
            loss_sag2diam += sag2diam

        # 1st-order derivative
        grad_ls = self.surfaces[i].dfdxyz(x_ls, y_ls)[0]
        grad_max = grad_ls.abs().max()
        if grad_max > grad_max_allowed:
            loss_grad += grad_max

        # Diameter to thickness ratio, thick_max to thick_min ratio
        if not self.surfaces[i].mat2.name == "air":
            surf2 = self.surfaces[i + 1]
            surf1 = self.surfaces[i]

            # Penalize diameter to thickness ratio
            diam2thick = 2 * max(surf2.r, surf1.r) / (surf2.d - surf1.d)
            if diam2thick > diam2thick_max:
                loss_diam2thick += diam2thick

            # Penalize thick_max to thick_min ratio.
            # Clamp denominators to avoid Inf from near-zero edge thickness.
            r_edge = min(surf2.r, surf1.r)
            thick_center = surf2.d - surf1.d
            thick_edge = surf2.surface_with_offset(r_edge, 0.0) - surf1.surface_with_offset(r_edge, 0.0)
            if thick_center > thick_edge:
                tmax2tmin = thick_center / thick_edge.clamp(min=0.01)
            else:
                tmax2tmin = thick_edge / max(thick_center, 0.01)

            if tmax2tmin > tmax2tmin_max:
                loss_tmax2tmin += tmax2tmin

    return loss_sag2diam + loss_grad + loss_diam2thick + loss_tmax2tmin
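For a spherical surface with curvature c, the first two quantities checked in loss_surface have closed forms: the sag z(r) = c·r² / (1 + sqrt(1 - c²r²)) and the slope dz/dr = c·r / sqrt(1 - c²r²), both largest at the edge. A worked sketch for R = 10 mm at a 4 mm semi-aperture, checked against the cellphone limits from init_constraints (the helper names are hypothetical; DeepLens samples 32 points and reads the gradient from dfdxyz instead):

```python
import math


def sphere_sag(c: float, r: float) -> float:
    """Sag of a spherical surface with curvature c at radial height r."""
    return c * r ** 2 / (1.0 + math.sqrt(1.0 - (c * r) ** 2))


def sphere_slope(c: float, r: float) -> float:
    """Derivative dz/dr of the spherical sag."""
    return c * r / math.sqrt(1.0 - (c * r) ** 2)


def shape_penalty(c: float, r: float, sag2diam_max: float = 0.1, grad_max: float = 0.57):
    """Hinge terms mirroring loss_surface: add the ratio itself when it exceeds the limit."""
    sag2diam = abs(sphere_sag(c, r)) / r / 2   # max |sag| sits at the edge for a sphere
    slope = abs(sphere_slope(c, r))
    penalty = 0.0
    if sag2diam > sag2diam_max:
        penalty += sag2diam
    if slope > grad_max:
        penalty += slope
    return sag2diam, slope, penalty


ratio, slope, penalty = shape_penalty(c=1 / 10, r=4.0)
```

Here the sag-to-diameter ratio (about 0.104) just exceeds the 0.1 limit while the edge slope (about 0.436) stays below 0.57, so only the first term contributes.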

loss_intersec

loss_intersec()

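Loss function to avoid self-intersection.

Penalizes center and edge air gaps, center and edge glass thicknesses, and the back focal length when they fall below the minimums set by init_constraints, since such small distances could cause self-intersection or manufacturing issues. Edge distances are sampled at 16 radii between 0.5 and 1.0 times the aperture radius of the front surface of each pair. The violating distances are summed and the total is negated, so minimizing the loss pushes them up.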

Source code in deeplens/optics/geolens_pkg/optim.py
def loss_intersec(self):
    """Loss function to avoid self-intersection.

    This function penalizes when surfaces are too close to each other,
    which could cause self-intersection or manufacturing issues.
    """
    # Constraints
    air_min_center = self.air_min_center
    air_min_edge = self.air_min_edge
    thick_min_center = self.thick_min_center
    thick_min_edge = self.thick_min_edge
    bfl_min = self.bfl_min

    # Loss
    loss = torch.tensor(0.0, device=self.device)
    for i in range(len(self.surfaces) - 1):
        # Sample evaluation points on the two surfaces
        current_surf = self.surfaces[i]
        next_surf = self.surfaces[i + 1]

        r_center = torch.tensor(0.0, device=self.device) * current_surf.r
        z_prev_center = current_surf.surface_with_offset(r_center, 0.0, valid_check=False)
        z_next_center = next_surf.surface_with_offset(r_center, 0.0, valid_check=False)

        r_edge = torch.linspace(0.5, 1.0, 16, device=self.device) * current_surf.r
        z_prev_edge = current_surf.surface_with_offset(r_edge, 0.0, valid_check=False)
        z_next_edge = next_surf.surface_with_offset(r_edge, 0.0, valid_check=False)

        # Next surface is air
        if self.surfaces[i].mat2.name == "air":
            # Center air gap
            dist_center = z_next_center - z_prev_center
            if dist_center < air_min_center:
                loss += dist_center

            # Edge air gap
            dist_edge = torch.min(z_next_edge - z_prev_edge)
            if dist_edge < air_min_edge:
                loss += dist_edge

        # Next surface is lens
        else:
            # Center thickness
            dist_center = z_next_center - z_prev_center
            if dist_center < thick_min_center:
                loss += dist_center

            # Edge thickness
            dist_edge = torch.min(z_next_edge - z_prev_edge)
            if dist_edge < thick_min_edge:
                loss += dist_edge

    # Distance to sensor (back focal length)
    last_surf = self.surfaces[-1]
    r = torch.linspace(0.0, 1.0, 32, device=self.device) * last_surf.r
    z_last_surf = self.d_sensor - last_surf.surface_with_offset(r, 0.0)

    bfl = torch.min(z_last_surf)
    if bfl < bfl_min:
        loss += bfl

    # Negate: minimizing -sum(violating distances) pushes those distances up
    return -loss
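The negation at the end of loss_intersec is what lets gradient descent fix a violation: the derivative of -gap with respect to gap is -1, so each step enlarges the offending gap. A plain-Python sketch with hypothetical gap values (min_gap_loss is not DeepLens API):

```python
from typing import Sequence


def min_gap_loss(gaps: Sequence[float], minima: Sequence[float]) -> float:
    """Mirror of loss_intersec's hinge: sum the violating gaps, then negate."""
    violation_sum = 0.0
    for gap, g_min in zip(gaps, minima):
        if gap < g_min:
            violation_sum += gap      # the gap itself is added, not the shortfall
    return -violation_sum             # minimizing -gap pushes the gap upward


# Hypothetical center gaps [mm] against air_min_center = 0.05 (cellphone preset).
loss_close = min_gap_loss([0.03, 0.5], [0.05, 0.05])   # one gap is too small
loss_overlap = min_gap_loss([-0.02], [0.05])           # overlapping surfaces
```

Because the gap itself is added, the penalty jumps at the threshold: a gap just below the minimum contributes about -g_min, while one just above contributes nothing. Overlapping surfaces (negative gaps) produce a positive loss.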

loss_thickness

loss_thickness()

Penalize excessive air gaps, lens thicknesses, and total track length.

Checks three types of upper-bound constraints:
  1. Per-gap air and glass thickness (center and edge).
  2. Back focal length (BFL).
  3. Total track length (TTL) from first surface to sensor.

Returns:

| Type | Description |
|---|---|
| Tensor | Scalar thickness penalty loss. |

Source code in deeplens/optics/geolens_pkg/optim.py
def loss_thickness(self):
    """Penalize excessive air gaps, lens thicknesses, and total track length.

    Checks three types of upper-bound constraints:
        1. Per-gap air and glass thickness (center and edge).
        2. Back focal length (BFL).
        3. Total track length (TTL) from first surface to sensor.

    Returns:
        Tensor: Scalar thickness penalty loss.
    """
    # Constraints
    air_max_center = self.air_max_center
    air_max_edge = self.air_max_edge
    thick_max_center = self.thick_max_center
    thick_max_edge = self.thick_max_edge
    bfl_max = self.bfl_max
    ttl_max = self.ttl_max

    # Loss
    loss = torch.tensor(0.0, device=self.device)

    # Distance between surfaces
    for i in range(len(self.surfaces) - 1):
        # Sample evaluation points on the two surfaces
        current_surf = self.surfaces[i]
        next_surf = self.surfaces[i + 1]

        r_center = torch.tensor(0.0, device=self.device) * current_surf.r
        z_prev_center = current_surf.surface_with_offset(r_center, 0.0, valid_check=False)
        z_next_center = next_surf.surface_with_offset(r_center, 0.0, valid_check=False)

        r_edge = torch.linspace(0.5, 1.0, 16, device=self.device) * current_surf.r
        z_prev_edge = current_surf.surface_with_offset(r_edge, 0.0, valid_check=False)
        z_next_edge = next_surf.surface_with_offset(r_edge, 0.0, valid_check=False)

        # Air gap
        if self.surfaces[i].mat2.name == "air":
            # Center air gap
            dist_center = z_next_center - z_prev_center
            if dist_center > air_max_center:
                loss += dist_center

            # Edge air gap
            dist_edge = torch.max(z_next_edge - z_prev_edge)
            if dist_edge > air_max_edge:
                loss += dist_edge

        # Lens thickness
        else:
            # Center thickness
            dist_center = z_next_center - z_prev_center
            if dist_center > thick_max_center:
                loss += dist_center

            # Edge thickness
            dist_edge = torch.max(z_next_edge - z_prev_edge)
            if dist_edge > thick_max_edge:
                loss += dist_edge

    # Distance to sensor (back focal length)
    last_surf = self.surfaces[-1]
    r = torch.linspace(0.0, 1.0, 32, device=self.device) * last_surf.r
    z_last_surf = self.d_sensor - last_surf.surface_with_offset(r, 0.0)

    bfl = torch.max(z_last_surf)
    if bfl > bfl_max:
        loss += bfl

    # Total track length (first surface to sensor)
    ttl = self.d_sensor - self.surfaces[0].d
    if ttl > ttl_max:
        loss += ttl

    # Loss, minimize loss
    return loss
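loss_thickness is the mirror image of loss_intersec: every quantity above its upper bound is added as-is, without negation, so minimizing the loss shrinks it. A stdlib-only sketch of the BFL and TTL checks with the cellphone limits; the function name and the z positions are hypothetical:

```python
from typing import Iterable, Tuple


def upper_bound_loss(values_and_limits: Iterable[Tuple[float, float]]) -> float:
    """Sum every value that exceeds its upper limit (the value itself, not the excess)."""
    loss = 0.0
    for value, limit in values_and_limits:
        if value > limit:
            loss += value
    return loss


d_first, d_sensor = 0.0, 16.0        # hypothetical z of the first surface and the sensor [mm]
bfl = 1.2                            # hypothetical back focal length [mm]
ttl = d_sensor - d_first             # total track length, computed as in the source
loss = upper_bound_loss([(bfl, 3.0), (ttl, 15.0)])   # bfl_max, ttl_max (cellphone preset)
```

Only the TTL check fires here, so the loss equals the full track length of 16 mm rather than the 1 mm excess.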

loss_ray_angle

loss_ray_angle()

Penalize large chief ray angles and low obliquity factors.

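Ensures that rays arrive at the sensor within acceptable incidence angles, which is critical for sensor coupling and colour cross-talk. The chief ray angle is estimated from ring-arm rays traced through a pupil scaled to 0.2: rays whose direction cosine ray.d[..., 2] falls below cos(chief_ray_angle_max) contribute the negative mean of that cosine. Obliquity is evaluated with the full pupil, and rays with obliquity below obliq_min contribute the negative mean of their obliquity.

Returns:

| Type | Description |
|---|---|
| Tensor | Scalar chief-ray-angle penalty loss (sum of the two terms). |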

Source code in deeplens/optics/geolens_pkg/optim.py
def loss_ray_angle(self):
    """Penalize large chief ray angles and low obliquity factors.

    Ensures that rays arrive at the sensor within acceptable incidence
    angles, which is critical for sensor coupling and colour cross-talk.

    Returns:
        Tensor: Scalar chief-ray-angle penalty loss.
    """
    max_angle_deg = self.chief_ray_angle_max
    obliq_min = self.obliq_min

    # Loss on chief ray angle
    ray = self.sample_ring_arm_rays(num_ring=8, num_arm=8, spp=SPP_CALC, scale_pupil=0.2)
    ray = self.trace2sensor(ray)
    cos_cra = ray.d[..., 2]
    cos_cra_ref = float(np.cos(np.deg2rad(max_angle_deg)))
    mask_cra = (cos_cra < cos_cra_ref).float()
    count_cra = mask_cra.sum()
    loss_cra = -(cos_cra * mask_cra).sum() / (count_cra + EPSILON)

    # Loss on accumulated oblique term
    ray = self.sample_ring_arm_rays(num_ring=8, num_arm=8, spp=SPP_CALC, scale_pupil=1.0)
    ray = self.trace2sensor(ray)
    obliq = ray.obliq.squeeze(-1)
    mask_obliq = (obliq < obliq_min).float()
    count_obliq = mask_obliq.sum()
    loss_obliq = -(obliq * mask_obliq).sum() / (count_obliq + EPSILON)

    return loss_cra + loss_obliq
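The chief-ray-angle term is a masked mean: only rays steeper than the limit are averaged, and the mean cosine is negated so that minimizing it raises the cosine, i.e. lowers the incidence angle. A stdlib-only sketch of the CRA half, where EPSILON is a hypothetical stand-in for the DeepLens constant and the angles are made up; the obliquity half has the same structure:

```python
import math
from typing import Sequence

EPSILON = 1e-9  # hypothetical stand-in for the DeepLens EPSILON constant


def cra_loss(cos_z: Sequence[float], max_angle_deg: float) -> float:
    """Negative mean direction cosine over rays steeper than max_angle_deg."""
    cos_ref = math.cos(math.radians(max_angle_deg))
    violating = [c for c in cos_z if c < cos_ref]
    return -sum(violating) / (len(violating) + EPSILON)


angles_deg = [10.0, 35.0, 45.0]      # hypothetical chief-ray angles at the sensor
cos_z = [math.cos(math.radians(a)) for a in angles_deg]
loss = cra_loss(cos_z, max_angle_deg=30.0)
```

Rays already within the limit contribute nothing, and the EPSILON in the denominator keeps the loss at zero instead of dividing by zero when no ray violates it.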

loss_mat

loss_mat()

Penalize material parameters outside manufacturable ranges.

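Constrains refractive index n to [1.5, 1.9] and Abbe number V to [30, 70] for each non-air surface material. Each out-of-range value adds a linear penalty equal to its distance from the nearest bound divided by the width of the allowed range (0.4 for n, 40 for V).

Returns:

| Type | Description |
|---|---|
| Tensor | Scalar material penalty loss. |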

Source code in deeplens/optics/geolens_pkg/optim.py
def loss_mat(self):
    """Penalize material parameters outside manufacturable ranges.

    Constrains refractive index *n* to [1.5, 1.9] and Abbe number *V* to
    [30, 70] for each non-air surface material.

    Returns:
        Tensor: Scalar material penalty loss.
    """
    n_max = 1.9
    n_min = 1.5
    V_max = 70
    V_min = 30
    loss_mat = torch.tensor(0.0, device=self.device)
    for i in range(len(self.surfaces)):
        if self.surfaces[i].mat2.name != "air":
            if self.surfaces[i].mat2.n > n_max:
                loss_mat += (self.surfaces[i].mat2.n - n_max) / (n_max - n_min)
            if self.surfaces[i].mat2.n < n_min:
                loss_mat += (n_min - self.surfaces[i].mat2.n) / (n_max - n_min)
            if self.surfaces[i].mat2.V > V_max:
                loss_mat += (self.surfaces[i].mat2.V - V_max) / (V_max - V_min)
            if self.surfaces[i].mat2.V < V_min:
                loss_mat += (V_min - self.surfaces[i].mat2.V) / (V_max - V_min)

    return loss_mat
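Normalizing by the range width puts a violation in n and one in V on a comparable scale. A sketch with plain floats instead of material objects (material_penalty and the glass values are hypothetical):

```python
from typing import Tuple


def material_penalty(n: float, V: float,
                     n_range: Tuple[float, float] = (1.5, 1.9),
                     V_range: Tuple[float, float] = (30.0, 70.0)) -> float:
    """Linear out-of-range penalty, normalized by range width as in loss_mat."""
    n_min, n_max = n_range
    V_min, V_max = V_range
    loss = 0.0
    if n > n_max:
        loss += (n - n_max) / (n_max - n_min)
    if n < n_min:
        loss += (n_min - n) / (n_max - n_min)
    if V > V_max:
        loss += (V - V_max) / (V_max - V_min)
    if V < V_min:
        loss += (V_min - V) / (V_max - V_min)
    return loss


# A hypothetical n = 2.0, V = 25 glass: 0.1 / 0.4 + 5 / 40 = 0.25 + 0.125.
penalty = material_penalty(2.0, 25.0)
```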

loss_rms

loss_rms(num_grid=GEO_GRID, depth=DEPTH, num_rays=SPP_PSF, sample_more_off_axis=False)

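Loss function to compute the RGB RMS spot error.

Rays from a grid of field points are traced at the green, red, and blue wavelengths (WAVE_RGB[1], WAVE_RGB[0], WAVE_RGB[2]). All three RMS errors are measured about one reference center, computed from the green pass with psf_center(..., method="pinhole") under torch.no_grad(), so a lateral shift of the red or blue spot relative to that reference also increases the loss.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| num_grid | int | Number of field grid points. | GEO_GRID |
| depth | float | Depth of the object plane. | DEPTH |
| num_rays | int | Number of rays to sample. | SPP_PSF |
| sample_more_off_axis | bool | Whether to sample more off-axis rays. | False |

Returns:

| Name | Type | Description |
|---|---|---|
| avg_rms_error | Tensor | RMS error averaged over wavelengths and grid points. |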

Source code in deeplens/optics/geolens_pkg/optim.py
def loss_rms(
    self,
    num_grid=GEO_GRID,
    depth=DEPTH,
    num_rays=SPP_PSF,
    sample_more_off_axis=False,
):
    """Loss function to compute RGB spot error RMS.

    Args:
        num_grid (int, optional): Number of grid points. Defaults to GEO_GRID.
        depth (float, optional): Depth of the object plane. Defaults to DEPTH.
        num_rays (int, optional): Number of rays. Defaults to SPP_PSF.
        sample_more_off_axis (bool, optional): Whether to sample more off-axis rays. Defaults to False.

    Returns:
        avg_rms_error (torch.Tensor): RMS error averaged over wavelengths and grid points.
    """
    all_rms_errors = []
    for i, wvln in enumerate([WAVE_RGB[1], WAVE_RGB[0], WAVE_RGB[2]]):
        ray = self.sample_grid_rays(
            depth=depth,
            num_grid=num_grid,
            num_rays=num_rays,
            wvln=wvln,
            sample_more_off_axis=sample_more_off_axis,
        )

        # Calculate reference center, shape of (..., 2)
        if i == 0:
            with torch.no_grad():
                ray_center_green = -self.psf_center(points_obj=ray.o[:, :, 0, :], method="pinhole")

        ray = self.trace2sensor(ray)

        # # Green light centroid for reference
        # if i == 0:
        #     with torch.no_grad():
        #         ray_center_green = ray.centroid()

        # Calculate RMS error with reference center
        rms_error = ray.rms_error(center_ref=ray_center_green)
        all_rms_errors.append(rms_error)

    # Calculate average RMS error
    avg_rms_error = torch.stack(all_rms_errors).mean(dim=0)
    return avg_rms_error
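Measuring every wavelength about the single green reference is what makes lateral color count in loss_rms. The stdlib-only sketch below assumes ray.rms_error(center_ref=...) is the RMS distance of the sensor hits from the supplied reference point; that definition, the helper names, and the spot data are assumptions for illustration.

```python
import math
from typing import Dict, List, Tuple

Point = Tuple[float, float]


def rms_about(points: List[Point], center: Point) -> float:
    """RMS distance of sensor hits from a fixed reference point."""
    cx, cy = center
    return math.sqrt(sum((x - cx) ** 2 + (y - cy) ** 2 for x, y in points) / len(points))


def rgb_rms(spots: Dict[str, List[Point]], green_center: Point) -> float:
    """Average the per-wavelength RMS, all measured about the green reference."""
    errors = [rms_about(spots[band], green_center) for band in ("g", "r", "b")]
    return sum(errors) / len(errors)


# Hypothetical spots in arbitrary units: red is shifted sideways by 1 (lateral color).
spots = {
    "g": [(1.0, 0.0), (-1.0, 0.0)],
    "r": [(2.0, 0.0), (0.0, 0.0)],
    "b": [(0.0, 1.0), (0.0, -1.0)],
}
avg = rgb_rms(spots, green_center=(0.0, 0.0))
```

Measured about its own centroid (1, 0), the red spot would have RMS 1, the same as green; against the fixed green reference it has RMS sqrt(2), so the average rises to (2 + sqrt(2)) / 3.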

sample_ring_arm_rays

sample_ring_arm_rays(num_ring=8, num_arm=8, spp=2048, depth=DEPTH, wvln=DEFAULT_WAVE, scale_pupil=1.0, sample_more_off_axis=True)

Sample rays from object space using a ring-arm pattern.

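This method distributes sampling points (origins of ray bundles) on a polar grid in the object plane, defined by field of view. This is useful for capturing lens performance across the full field. The grid has num_ring rings with num_arm points each; the first ring has zero field angle, so the on-axis point is repeated num_arm times, and the outermost ring lies at the full field of view (rfov).

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| num_ring | int | Number of rings to sample in the field of view. | 8 |
| num_arm | int | Number of arms (spokes) to sample for each ring. | 8 |
| spp | int | Total number of rays to be sampled, distributed among field points. | 2048 |
| depth | float | Depth of the object plane. | DEPTH |
| wvln | float | Wavelength of the rays. | DEFAULT_WAVE |
| scale_pupil | float | Scale factor for the pupil size. | 1.0 |
| sample_more_off_axis | bool | If True, ring field angles are spaced as rfov * sqrt(t) for t evenly spaced in [0, 1], concentrating rings toward the field edge; if False, they are spaced linearly. | True |

Returns:

| Type | Description |
|---|---|
| Ray | A Ray object containing the sampled rays. |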

Source code in deeplens/optics/geolens_pkg/optim.py
def sample_ring_arm_rays(self, num_ring=8, num_arm=8, spp=2048, depth=DEPTH, wvln=DEFAULT_WAVE, scale_pupil=1.0, sample_more_off_axis=True):
    """Sample rays from object space using a ring-arm pattern.

    This method distributes sampling points (origins of ray bundles) on a polar grid in the object plane,
    defined by field of view. This is useful for capturing lens performance across the full field.
    The grid has `num_ring` rings with `num_arm` points on each; the first ring has zero field angle,
    so the on-axis point is repeated `num_arm` times.

    Args:
        num_ring (int): Number of rings to sample in the field of view.
        num_arm (int): Number of arms (spokes) to sample for each ring.
        spp (int): Total number of rays to be sampled, distributed among field points.
        depth (float): Depth of the object plane.
        wvln (float): Wavelength of the rays.
        scale_pupil (float): Scale factor for the pupil size.
        sample_more_off_axis (bool): If True, space ring field angles as the square root
            of a linear ramp to concentrate rings near the field edge.

    Returns:
        Ray: A Ray object containing the sampled rays.
    """
    # Create points on rings and arms
    max_fov_rad = self.rfov
    if sample_more_off_axis:
        # Square-root spacing of a uniform grid places more rings near the field edge (close to 1.0)
        beta_values = torch.linspace(0.0, 1.0, num_ring, device=self.device)
        # These are the quantiles of a Beta(2, 1) distribution (density 2x), not Beta(0.5, 1.0)
        beta_transformed = beta_values ** 0.5
        ring_fovs = max_fov_rad * beta_transformed

        # Use square root to sample more points near the edge
        # ring_fovs = max_fov_rad * torch.sqrt(torch.linspace(0.0, 1.0, num_ring, device=self.device))
    else:
        ring_fovs = max_fov_rad * torch.linspace(0.0, 1.0, num_ring, device=self.device)

    arm_angles = torch.linspace(0.0, 2 * torch.pi, num_arm + 1, device=self.device)[:-1]
    ring_grid, arm_grid = torch.meshgrid(ring_fovs, arm_angles, indexing="ij")
    x = depth * torch.tan(ring_grid) * torch.cos(arm_grid)
    y = depth * torch.tan(ring_grid) * torch.sin(arm_grid)        
    z = torch.full_like(x, depth)
    points = torch.stack([x, y, z], dim=-1)  # shape: [num_ring, num_arm, 3]

    # Sample rays
    rays = self.sample_from_points(points=points, num_rays=spp, wvln=wvln, scale_pupil=scale_pupil)
    return rays

optimize

optimize(lrs=[0.001, 0.0001, 0.1, 0.0001], iterations=5000, test_per_iter=100, centroid=False, optim_mat=False, shape_control=True, result_dir=None)

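Optimise the lens by minimising the RGB RMS spot error.

Runs a training loop with the Adam optimiser and a cosine learning-rate schedule with 100 warmup steps. The loss is the RMS spot error averaged over the three RGB wavelengths, plus loss_infocus() and 0.1 × loss_reg(). Every test_per_iter iterations the lens is evaluated and intermediate results are saved; surface shapes are optionally corrected at the same time.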
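The per-wavelength RMS term in the source listing works as follows: each field point gets the RMS distance of its valid rays from the reference centre, and fields are averaged with weights proportional to their mean squared error (in the source, the weights come from the first wavelength and are computed without gradients). The plain-Python sketch below reproduces that arithmetic for a single wavelength. EPS stands in for the library's EPSILON, whose value is assumed here, and the function names are hypothetical.

```python
import math

EPS = 1e-9  # stand-in for the library's EPSILON (value assumed)

def mse_per_field(errors, valid):
    """errors[f]: list of (dx, dy) ray errors for field f; valid[f]: matching bools."""
    out = []
    for errs, mask in zip(errors, valid):
        # Invalid rays contribute 0 and are never squared (they may hold inf)
        sq = sum((dx * dx + dy * dy) if ok else 0.0 for (dx, dy), ok in zip(errs, mask))
        out.append(sq / (sum(mask) + EPS))  # mean squared error over valid rays
    return out

def weighted_rms_loss(errors, valid):
    mse = mse_per_field(errors, valid)
    mean_mse = sum(mse) / len(mse)
    # Fields with larger error get proportionally larger weights (treated as constants)
    weights = [m / (mean_mse + EPS) for m in mse]
    rms = [math.sqrt(m + EPS) for m in mse]  # per-field RMS spot radius
    return sum(r * w for r, w in zip(rms, weights)) / (sum(weights) + EPS)

errors = [[(0.0, 0.0), (0.0, 0.0)], [(3.0, 4.0), (float("inf"), 0.0)]]
valid = [[True, True], [True, False]]
print(round(weighted_rms_loss(errors, valid), 6))  # 5.0
```

The error-proportional weighting pushes the optimiser toward the worst fields, which is why the perfectly focused field above contributes nothing to the result.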

Parameters:

Name Type Description Default
lrs list

Learning rates for the [d, c, k, a] parameter groups (surface distance, curvature, conic constant, aspheric coefficients). Defaults to [1e-3, 1e-4, 1e-1, 1e-4].

[0.001, 0.0001, 0.1, 0.0001]
iterations int

Total training iterations. Defaults to 5000.

5000
test_per_iter int

Evaluate and save every N iterations. Defaults to 100.

100
centroid bool

If True, use chief-ray centroid as PSF centre reference; otherwise use pinhole model. Defaults to False.

False
optim_mat bool

If True, include material parameters (n, V) in optimisation. Defaults to False.

False
shape_control bool

If True, call correct_shape() at every evaluation step after iteration 0. Defaults to True.

True
result_dir str

Directory to save results. If None, auto-generates a timestamped directory. Defaults to None.

None
Note

Debug hints:
  1. Optimise slowly, with small learning rates.
  2. Make sure the field of view and the lens thicknesses are consistent with each other.
  3. Keep parameter ranges reasonable.
  4. Higher aspheric orders give more design freedom but make optimisation more sensitive.
  5. More iterations with more sampled rays improve convergence.
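The advice about small learning rates also interacts with the scheduler: optimize builds it with get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=iterations). Assuming that helper follows the widely used Hugging Face transformers definition (its import is not shown in this section), each group's learning rate is multiplied by the factor below: a linear ramp over the warmup steps, then a half-cosine decay to zero. This sketches the assumed schedule shape and is not the library's code.

```python
import math

def cosine_with_warmup(step, num_warmup_steps, num_training_steps):
    """Assumed LR multiplier: linear warmup, then half-cosine decay to zero."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

base_lrs = [1e-3, 1e-4, 1e-1, 1e-4]  # optimize() defaults for the [d, c, k, a] groups
for step in (0, 50, 100, 2550, 5000):
    scale = cosine_with_warmup(step, 100, 5000)
    print(step, [lr * scale for lr in base_lrs])
```

With the default iterations=5000, the multiplier is 0 at step 0, 1 at step 100, 0.5 at step 2550 and 0 at the final step, so the values in lrs act as peak learning rates.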

Source code in deeplens/optics/geolens_pkg/optim.py
def optimize(
    self,
    lrs=[1e-3, 1e-4, 1e-1, 1e-4],
    iterations=5000,
    test_per_iter=100,
    centroid=False,
    optim_mat=False,
    shape_control=True,
    result_dir=None,
):
    """Optimise the lens by minimising RGB RMS spot errors.

    Runs a curriculum-learning training loop with Adam optimiser and cosine
    annealing. Periodically evaluates the lens, saves intermediate results,
    and optionally corrects surface shapes.

    Args:
        lrs (list, optional): Learning rates for [d, c, k, a] parameter groups.
            Defaults to [1e-3, 1e-4, 1e-1, 1e-4].
        iterations (int, optional): Total training iterations. Defaults to 5000.
        test_per_iter (int, optional): Evaluate and save every N iterations.
            Defaults to 100.
        centroid (bool, optional): If True, use chief-ray centroid as PSF centre
            reference; otherwise use pinhole model. Defaults to False.
        optim_mat (bool, optional): If True, include material parameters (n, V)
            in optimisation. Defaults to False.
        shape_control (bool, optional): If True, call ``correct_shape()`` at each
            evaluation step. Defaults to True.
        result_dir (str, optional): Directory to save results. If None,
            auto-generates a timestamped directory. Defaults to None.

    Note:
        Debug hints:
            1. Slowly optimise with small learning rate.
            2. FoV and thickness should match well.
            3. Keep parameter ranges reasonable.
            4. Higher aspheric order is better but more sensitive.
            5. More iterations with larger ray sampling improves convergence.
    """
    # Experiment settings
    depth = DEPTH
    num_ring = 32
    num_arm = 8
    spp = 2048

    # Result directory and logger
    if result_dir is None:
        result_dir = f"./results/{datetime.now().strftime('%m%d-%H%M%S')}-DesignLens"

    os.makedirs(result_dir, exist_ok=True)
    if not logging.getLogger().hasHandlers():
        logger = logging.getLogger()
        logger.setLevel("DEBUG")
        fmt = logging.Formatter("%(asctime)s:%(levelname)s:%(message)s", "%Y-%m-%d %H:%M:%S")
        sh = logging.StreamHandler()
        sh.setFormatter(fmt)
        sh.setLevel("INFO")
        fh = logging.FileHandler(f"{result_dir}/output.log")
        fh.setFormatter(fmt)
        fh.setLevel("INFO")
        logger.addHandler(sh)
        logger.addHandler(fh)
    logging.info(f"lr:{lrs}, iterations:{iterations}, num_ring:{num_ring}, num_arm:{num_arm}, rays_per_fov:{spp}.")
    logging.info("If Out-of-Memory, try to reduce num_ring, num_arm, and rays_per_fov.")

    # Optimizer and scheduler
    optimizer = self.get_optimizer(lrs, optim_mat=optim_mat)
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=iterations)

    # Training loop
    pbar = tqdm(
        total=iterations + 1,
        desc="Progress",
        postfix={"loss_rms": 0, "loss_focus": 0},
    )
    for i in range(iterations + 1):
        # ===> Evaluate the lens
        if i % test_per_iter == 0:
            with torch.no_grad():
                if shape_control and i > 0:
                    self.correct_shape()
                    # self.refocus()

                self.write_lens_json(f"{result_dir}/iter{i}.json")
                self.analysis(f"{result_dir}/iter{i}")

                # Sample rays
                self.calc_pupil()
                rays_backup = []
                for wv in WAVE_RGB:
                    ray = self.sample_ring_arm_rays(num_ring=num_ring, num_arm=num_arm, spp=spp, depth=depth, wvln=wv, scale_pupil=1.05, sample_more_off_axis=False)
                    rays_backup.append(ray)

                # Calculate ray centers
                if centroid:
                    center_ref = -self.psf_center(points_obj=ray.o[:, :, 0, :], method="chief_ray")
                    center_ref = center_ref.unsqueeze(-2).repeat(1, 1, spp, 1)
                else:
                    center_ref = -self.psf_center(points_obj=ray.o[:, :, 0, :], method="pinhole")
                    center_ref = center_ref.unsqueeze(-2).repeat(1, 1, spp, 1)

        # ===> Optimize lens by minimizing RMS
        loss_rms_ls = []
        for wv_idx, wv in enumerate(WAVE_RGB):
            # Ray tracing to sensor, [num_grid, num_grid, num_rays, 3]
            ray = rays_backup[wv_idx].clone()
            ray = self.trace2sensor(ray)

            # Ray error to center and valid mask.
            # Use torch.where to zero out invalid rays BEFORE squaring,
            # preventing NaN from Inf*0 (IEEE 754: inf * 0 = nan).
            ray_xy = ray.o[..., :2]
            ray_valid = ray.is_valid
            ray_err = ray_xy - center_ref
            ray_err = torch.where(
                ray_valid.bool().unsqueeze(-1), ray_err, torch.zeros_like(ray_err)
            )

            # Weight mask, shape of [num_grid, num_grid]
            if wv_idx == 0:
                with torch.no_grad():
                    weight_mask = (ray_err**2).sum(-1).sum(-1)
                    weight_mask /= ray_valid.sum(-1) + EPSILON
                    weight_mask /= weight_mask.mean() + EPSILON

            # Loss on RMS error
            l_rms = (ray_err**2).sum(-1).sum(-1)
            l_rms /= ray_valid.sum(-1) + EPSILON
            l_rms = (l_rms + EPSILON).sqrt()

            # Weighted loss
            l_rms_weighted = (l_rms * weight_mask).sum()
            l_rms_weighted /= weight_mask.sum() + EPSILON
            loss_rms_ls.append(l_rms_weighted)

        # RMS loss for all wavelengths
        loss_rms = sum(loss_rms_ls) / len(loss_rms_ls)

        # Total loss
        w_focus = 1.0
        loss_focus = self.loss_infocus()

        w_reg = 0.1
        loss_reg, loss_dict = self.loss_reg()

        L_total = loss_rms + w_focus * loss_focus + w_reg * loss_reg

        # Back-propagation
        optimizer.zero_grad()
        L_total.backward()
        optimizer.step()
        scheduler.step()

        pbar.set_postfix(loss_rms=loss_rms.item(), loss_focus=loss_focus.item(), **loss_dict)
        pbar.update(1)

    pbar.close()

get_optimizer_params

get_optimizer_params(lrs=[0.0001, 0.0001, 0.01, 0.0001], optim_mat=False, optim_surf_range=None)

Get optimizer parameter groups for the lens surfaces and the sensor distance.

Recommendation

For cellphone lenses: optimise [d, c, k, a] with lrs = [1e-4, 1e-4, 1e-1, 1e-4]. For camera lenses: optimise only [d, c] with lrs = [1e-3, 1e-4, 0, 0].

Parameters:

Name Type Description Default
lrs list

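Learning rates for the [d, c, k, a] parameter groups. A lens containing a Phase surface needs a fifth entry, lrs[4]. If lrs[0] is itself a list, the per-surface learning rates are handled by get_optimizer_params_manual().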

[0.0001, 0.0001, 0.01, 0.0001]
optim_mat bool

whether to optimize material. Defaults to False.

False
optim_surf_range list

Indices of the surfaces to optimise. If None, all surfaces are included. Defaults to None.

None

Returns:

Name Type Description
list

Parameter groups for a torch.optim optimizer, ending with the sensor-distance group {"params": d_sensor, "lr": lrs[0]}.
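To make the lrs indexing explicit, the sketch below tabulates which entries of lrs each surface type receives in the dispatch in the get_optimizer_params source listing, including the sensor-distance group that is always appended with lrs[0]. Surface types are plain strings here and the helper lrs_per_surface is hypothetical. Unlike the source, which would fail with a bare IndexError, it raises an explicit error when a Phase surface meets a four-entry lrs.

```python
# Which lrs entries each surface type receives, per the dispatch in get_optimizer_params
LR_SLOTS = {
    "Aperture": [0],            # d
    "Aspheric": [0, 1, 2, 3],   # d, c, k, a
    "Phase": [0, 4],            # needs a fifth entry, lrs[4]
    "Plane": [0],               # d
    "Spheric": [0, 1],          # d, c
    "ThinLens": [0, 1],         # first two entries
}

def lrs_per_surface(surface_types, lrs):
    """Hypothetical helper: list of (surface index, type, learning rates)."""
    groups = []
    for idx, kind in enumerate(surface_types):
        if kind not in LR_SLOTS:
            raise ValueError(f"Surface type {kind} is not supported for optimization yet.")
        slots = LR_SLOTS[kind]
        if max(slots) >= len(lrs):
            raise IndexError(f"Surface {idx} ({kind}) needs lrs[{max(slots)}], but lrs has {len(lrs)} entries.")
        groups.append((idx, kind, [lrs[s] for s in slots]))
    # The sensor distance is always optimised with lrs[0]
    groups.append(("sensor", "d_sensor", [lrs[0]]))
    return groups

for group in lrs_per_surface(["Spheric", "Aperture", "Aspheric"], [1e-4, 1e-4, 1e-2, 1e-4]):
    print(group)
```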

Source code in deeplens/optics/geolens_pkg/optim.py
def get_optimizer_params(
    self,
    lrs=[1e-4, 1e-4, 1e-2, 1e-4],
    optim_mat=False,
    optim_surf_range=None,
):
    """Get optimizer parameters for different lens surface.

    Recommendation:
        For cellphone lens: [d, c, k, a], [1e-4, 1e-4, 1e-1, 1e-4]
        For camera lens: [d, c, 0, 0], [1e-3, 1e-4, 0, 0]

    Args:
        lrs (list): learning rate for different parameters.
        optim_mat (bool): whether to optimize material. Defaults to False.
        optim_surf_range (list): surface indices to be optimized. Defaults to None.

    Returns:
        list: optimizer parameters
    """
    # Find surfaces to be optimized
    if optim_surf_range is None:
        # optim_surf_range = self.find_diff_surf()
        optim_surf_range = range(len(self.surfaces))

    # If lr for each surface is a list is given
    if isinstance(lrs[0], list):
        return self.get_optimizer_params_manual(
            lrs=lrs, optim_mat=optim_mat, optim_surf_range=optim_surf_range
        )

    # Optimize lens surface parameters
    params = []
    for surf_idx in optim_surf_range:
        surf = self.surfaces[surf_idx]

        if isinstance(surf, Aperture):
            params += surf.get_optimizer_params(lrs=[lrs[0]])

        elif isinstance(surf, Aspheric):
            params += surf.get_optimizer_params(
                lrs=lrs[:4], optim_mat=optim_mat
            )

        elif isinstance(surf, Phase):
            params += surf.get_optimizer_params(lrs=[lrs[0], lrs[4]])

        # elif isinstance(surf, GaussianRBF):
        #     params += surf.get_optimizer_params(lrs=lr, optim_mat=optim_mat)

        # elif isinstance(surf, NURBS):
        #     params += surf.get_optimizer_params(lrs=lr, optim_mat=optim_mat)

        elif isinstance(surf, Plane):
            params += surf.get_optimizer_params(lrs=[lrs[0]], optim_mat=optim_mat)

        # elif isinstance(surf, PolyEven):
        #     params += surf.get_optimizer_params(lrs=lr, optim_mat=optim_mat)

        elif isinstance(surf, Spheric):
            params += surf.get_optimizer_params(
                lrs=[lrs[0], lrs[1]], optim_mat=optim_mat
            )

        elif isinstance(surf, ThinLens):
            params += surf.get_optimizer_params(
                lrs=[lrs[0], lrs[1]], optim_mat=optim_mat
            )

        else:
            raise Exception(
                f"Surface type {surf.__class__.__name__} is not supported for optimization yet."
            )

    # Optimize sensor place
    self.d_sensor.requires_grad = True
    params += [{"params": self.d_sensor, "lr": lrs[0]}]

    return params

get_optimizer

get_optimizer(lrs=[0.0001, 0.0001, 0.1, 0.0001], optim_surf_range=None, optim_mat=False)

Build an Adam optimizer for the lens parameters. Lens design constraints (edge thickness, etc.) are initialised first via init_constraints().

Parameters:

Name Type Description Default
lrs list

Learning rates for the [d, c, k, a] parameter groups. Defaults to [1e-4, 1e-4, 1e-1, 1e-4].

[0.0001, 0.0001, 0.1, 0.0001]
optim_surf_range list

Indices of the surfaces to optimise. If None, all surfaces are included. Defaults to None.

None
optim_mat bool

whether to optimize material. Defaults to False.

False

Returns:

Name Type Description
torch.optim.Adam

Adam optimizer over the lens parameter groups.

Source code in deeplens/optics/geolens_pkg/optim.py
def get_optimizer(
    self,
    lrs=[1e-4, 1e-4, 1e-1, 1e-4],
    optim_surf_range=None,
    optim_mat=False,
):
    """Build an Adam optimizer for the lens parameters.

    Args:
        lrs (list): learning rates for the [d, c, k, a] parameter groups. Defaults to [1e-4, 1e-4, 1e-1, 1e-4].
        optim_surf_range (list): surface indices to be optimized. Defaults to None (all surfaces).
        optim_mat (bool): whether to optimize material. Defaults to False.

    Returns:
        torch.optim.Adam: optimizer over the lens parameter groups.
    """
    # Initialize lens design constraints (edge thickness, etc.)
    self.init_constraints()

    # Get optimizer
    params = self.get_optimizer_params(
        lrs=lrs, optim_surf_range=optim_surf_range, optim_mat=optim_mat
    )
    optimizer = torch.optim.Adam(params)
    # optimizer = torch.optim.SGD(params)
    return optimizer

deeplens.optics.geolens_pkg.optim_ops.GeoLensSurfOps

Mixin providing surface geometry operations for GeoLens.

Methods:

Name Description
- add_aspheric

Convert a spherical surface to aspheric.

- increase_aspheric_order

Add higher-order polynomial terms.

- prune_surf

Size clear apertures by ray tracing.

- correct_shape

Fix lens geometry during optimisation.

- match_materials

Match lens materials to a glass catalog.

add_aspheric

add_aspheric(surf_idx=None, ai_degree=4)

Convert a spherical surface to aspheric for improved aberration correction.

If surf_idx is given, converts that specific surface. Otherwise, automatically selects the best candidate following established optical design principles:

  1. First asphere: placed near the aperture stop (corrects spherical aberration).
  2. Subsequent aspheres: placed far from the stop (corrects field-dependent aberrations like coma, astigmatism, distortion).
  3. Prefer air-glass interfaces over cemented surfaces.
  4. Among candidates at similar stop distances, prefer a larger semi-diameter (higher marginal ray height → larger spherical-aberration contribution).
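The four principles can be read as a ranking rule. The sketch below is one hypothetical way to encode them and is not the scoring used by the private helper _find_best_asphere_candidate, whose internals are not shown here: candidates are filtered to Spheric surfaces, air-glass interfaces are preferred when any exist, stop distances are bucketed with a tolerance tol (an assumed parameter), and ties within a bucket go to the larger semi-diameter. Each surface is described by a plain dict whose keys are illustrative.

```python
def pick_asphere_candidate(surfaces, num_existing_aspheres, tol=1.0):
    """Hypothetical ranking; each surface is a dict with keys
    idx, spheric, air_glass, stop_dist (mm) and semi_diam (mm)."""
    pool = [s for s in surfaces if s["spheric"]]
    if not pool:
        raise ValueError("No eligible Spheric surface for auto-selection.")
    # Principle 3: keep only air-glass interfaces if there are any
    pool = [s for s in pool if s["air_glass"]] or pool
    # Principles 1 and 2: nearest the stop first, farthest from it afterwards
    sign = 1.0 if num_existing_aspheres == 0 else -1.0

    def rank(s):
        # Bucket stop distances so near-ties fall through to principle 4
        bucket = round(sign * s["stop_dist"] / tol)
        return (bucket, -s["semi_diam"])  # larger semi-diameter wins a tie

    return min(pool, key=rank)["idx"]

surfs = [
    {"idx": 0, "spheric": True, "air_glass": True, "stop_dist": 5.0, "semi_diam": 4.0},
    {"idx": 1, "spheric": True, "air_glass": True, "stop_dist": 0.4, "semi_diam": 3.0},
    {"idx": 2, "spheric": True, "air_glass": True, "stop_dist": 0.2, "semi_diam": 3.5},
    {"idx": 3, "spheric": True, "air_glass": True, "stop_dist": 10.0, "semi_diam": 6.0},
    {"idx": 4, "spheric": True, "air_glass": False, "stop_dist": 20.0, "semi_diam": 9.0},
]
print(pick_asphere_candidate(surfs, 0), pick_asphere_candidate(surfs, 1))  # 2 3
```

Surface 4 is never chosen here because it is cemented while air-glass candidates exist, even though it is farthest from the stop.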

The new surface starts with k=0 and all polynomial coefficients at zero, so it is initially identical to the original spherical surface.
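That a freshly converted surface is optically unchanged follows from the standard even-asphere sag equation, z(r) = c r² / (1 + sqrt(1 − (1 + k) c² r²)) + Σ a_i r^(p + 2i): with k = 0 and all a_i = 0, the conic term reduces to the exact sag of a sphere. The sketch below checks this numerically. The polynomial indexing used by the Aspheric class (whether the first coefficient multiplies r² or r⁴) is an assumption here, exposed as the first_power parameter.

```python
import math

def even_asphere_sag(r, c, k=0.0, coeffs=(), first_power=4):
    """Even-asphere sag; coeffs[i] multiplies r**(first_power + 2*i) (indexing assumed)."""
    conic = c * r**2 / (1.0 + math.sqrt(1.0 - (1.0 + k) * c**2 * r**2))
    poly = sum(a * r ** (first_power + 2 * i) for i, a in enumerate(coeffs))
    return conic + poly

def sphere_sag(r, c):
    """Exact sag of a sphere with curvature c != 0 (radius R = 1/c)."""
    R = 1.0 / c
    return R - math.copysign(math.sqrt(R**2 - r**2), R)

# A freshly converted surface: k = 0 and four zero coefficients
for c in (1 / 20.0, -1 / 35.0):
    for r in (0.5, 2.0, 4.0):
        assert abs(even_asphere_sag(r, c, 0.0, [0.0] * 4) - sphere_sag(r, c)) < 1e-12
print("k = 0 with zero coefficients matches the sphere")
```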

Note

After calling this method, any existing optimizer is stale. Call get_optimizer() again to include the new parameters.

Parameters:

Name Type Description Default
surf_idx int or None

Surface index to convert. If None, auto-selects the best candidate.

None
ai_degree int

Number of even-order aspheric coefficients [a2, a4, a6, ...]. Defaults to 4.

4

Returns:

Name Type Description
int

Index of the converted surface.

Raises:

Type Description
IndexError

If surf_idx is out of range.

ValueError

If surf_idx points to a non-Spheric surface, or no eligible candidate exists for auto-selection.

References

Design principles from research/aspheric_design_principles.md.

Source code in deeplens/optics/geolens_pkg/optim_ops.py
@torch.no_grad()
def add_aspheric(self, surf_idx=None, ai_degree=4):
    """Convert a spherical surface to aspheric for improved aberration correction.

    If ``surf_idx`` is given, converts that specific surface. Otherwise,
    automatically selects the best candidate following established optical
    design principles:

    1. First asphere: placed near the aperture stop (corrects spherical
       aberration).
    2. Subsequent aspheres: placed far from the stop (corrects field-dependent
       aberrations like coma, astigmatism, distortion).
    3. Prefer air-glass interfaces over cemented surfaces.
    4. Among candidates at similar stop-distances, prefer larger semi-diameter
       (higher marginal ray height → more SA contribution).

    The new surface starts with ``k=0`` and all polynomial coefficients at
    zero, so it is initially identical to the original spherical surface.

    Note:
        After calling this method, any existing optimizer is stale.
        Call ``get_optimizer()`` again to include the new parameters.

    Args:
        surf_idx (int or None): Surface index to convert. If ``None``,
            auto-selects the best candidate.
        ai_degree (int): Number of even-order aspheric coefficients
            ``[a2, a4, a6, ...]``. Defaults to 4.

    Returns:
        int: Index of the converted surface.

    Raises:
        IndexError: If ``surf_idx`` is out of range.
        ValueError: If ``surf_idx`` points to a non-Spheric surface, or no
            eligible candidate exists for auto-selection.

    References:
        Design principles from ``research/aspheric_design_principles.md``.
    """
    if surf_idx is not None:
        if surf_idx < 0 or surf_idx >= len(self.surfaces):
            raise IndexError(
                f"surf_idx={surf_idx} out of range [0, {len(self.surfaces) - 1}]."
            )
        if not isinstance(self.surfaces[surf_idx], Spheric):
            raise ValueError(
                f"Surface {surf_idx} is {type(self.surfaces[surf_idx]).__name__}, "
                f"expected Spheric. To add higher-order terms to an existing "
                f"Aspheric surface, use increase_aspheric_order(surf_idx={surf_idx})."
            )
        self._spheric_to_aspheric(surf_idx, ai_degree)
        logging.info(
            f"Converted surface {surf_idx} from Spheric to Aspheric "
            f"(ai_degree={ai_degree})."
        )
        return surf_idx

    # Auto-select best candidate
    surf_idx = self._find_best_asphere_candidate()
    self._spheric_to_aspheric(surf_idx, ai_degree)
    logging.info(
        f"Auto-selected surface {surf_idx} as best asphere candidate. "
        f"Converted to Aspheric (ai_degree={ai_degree})."
    )
    return surf_idx

increase_aspheric_order

increase_aspheric_order(surf_idx=None, increment=1)

Add higher-order polynomial terms to existing Aspheric surfaces.

Appends increment additional even-order coefficients (initialised to zero). For example, degree 4 [a4, a6, a8, a10] becomes degree 5 [a4, a6, a8, a10, a12] after increment=1.

Follows the "start low, add incrementally" principle: increase the order only when residual higher-order aberrations persist after optimisation at the current order.
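Because the appended coefficients start at zero, increasing the order leaves the surface unchanged until the optimiser moves them. A minimal sketch of that bookkeeping on a plain coefficient list follows. The helper names are hypothetical, and starting the polynomial at r⁴ follows the [a4, a6, ...] example in this section, which is itself an assumption about the Aspheric class.

```python
def poly_sag(r, coeffs, first_power=4):
    """Polynomial part of an even asphere: sum of coeffs[i] * r**(first_power + 2*i)."""
    return sum(a * r ** (first_power + 2 * i) for i, a in enumerate(coeffs))

def increase_order(coeffs, increment=1):
    """Append `increment` zero coefficients, mirroring the documented behaviour."""
    if increment < 1:
        raise ValueError(f"increment must be >= 1, got {increment}.")
    return list(coeffs) + [0.0] * increment

a = [1e-4, -2e-6, 3e-8, -4e-10]  # degree 4: a4, a6, a8, a10
b = increase_order(a)            # degree 5: a4, ..., a12 with a12 = 0
print(len(b), poly_sag(3.0, a) == poly_sag(3.0, b))  # 5 True
```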

Note

After calling this method, any existing optimizer is stale. Call get_optimizer() again to include the new parameters.

Parameters:

Name Type Description Default
surf_idx int or None

Surface index. If None, auto-selects the best candidate (see _find_best_order_increase_candidate).

None
increment int

Number of additional coefficients to add. Defaults to 1.

1

Returns:

Name Type Description
int

Index of the surface whose order was increased.

Raises:

Type Description
IndexError

If surf_idx is out of range.

ValueError

If surf_idx is given but is not Aspheric, if no Aspheric surfaces exist when surf_idx is None, or if increment < 1.

Source code in deeplens/optics/geolens_pkg/optim_ops.py
@torch.no_grad()
def increase_aspheric_order(self, surf_idx=None, increment=1):
    """Add higher-order polynomial terms to existing Aspheric surfaces.

    Appends ``increment`` additional even-order coefficients (initialised
    to zero). For example, degree 4 ``[a4, a6, a8, a10]`` becomes degree 5
    ``[a4, a6, a8, a10, a12]`` after ``increment=1``.

    Follows the principle of *start low, add incrementally*: increase
    order only when residual higher-order aberrations persist after
    optimisation at the current order.

    Note:
        After calling this method, any existing optimizer is stale.
        Call ``get_optimizer()`` again to include the new parameters.

    Args:
        surf_idx (int or None): Surface index. If ``None``, auto-selects
            the best candidate (see ``_find_best_order_increase_candidate``).
        increment (int): Number of additional coefficients to add.
            Defaults to 1.

    Returns:
        int: Index of the surface whose order was increased.

    Raises:
        IndexError: If ``surf_idx`` is out of range.
        ValueError: If ``surf_idx`` is given but is not Aspheric, if
            no Aspheric surfaces exist when ``surf_idx`` is ``None``,
            or if ``increment`` < 1.
    """
    if increment < 1:
        raise ValueError(f"increment must be >= 1, got {increment}.")
    if surf_idx is not None:
        if surf_idx < 0 or surf_idx >= len(self.surfaces):
            raise IndexError(
                f"surf_idx={surf_idx} out of range [0, {len(self.surfaces) - 1}]."
            )
    else:
        surf_idx = self._find_best_order_increase_candidate()

    surf = self.surfaces[surf_idx]
    if not isinstance(surf, Aspheric):
        raise ValueError(
            f"Surface {surf_idx} is {type(surf).__name__}, expected Aspheric."
        )
    old_degree = surf.ai_degree
    self._increase_surface_order(surf, increment)
    logging.info(
        f"Surface {surf_idx}: aspheric order {old_degree} -> {surf.ai_degree}."
    )

    return surf_idx

prune_surf

prune_surf(expand_factor=None, mounting_margin=None)

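Resize surface apertures so that all valid rays pass through.

Determines the clear aperture of each surface by tracing rays at 11 field angles from on-axis to the full field, and adds a margin. It then shrinks the radii wherever two surfaces separated by air come closer than a minimum gap (0.05 mm if the sensor radius is below 10 mm, otherwise 0.1 mm), and finally clamps the aperture stop radius to the clear apertures of its neighbouring surfaces.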
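The air-gap step in the source listing checks the axial gap between two surfaces across an air space at 8 radii, from half the clear radius up to the clear radius, and if the gap falls below the minimum, bisects for the largest radius that still clears it. The sketch below reproduces that search for two spherical surfaces in plain Python. Using spheres, and the names sag_sphere, min_gap and safe_radius, are assumptions for illustration; the library evaluates each surface through surface_with_offset. Like the source, the bisection assumes the gap shrinks monotonically as the radius grows.

```python
import math

def sag_sphere(r, c):
    """Sag of a spherical surface with curvature c (c = 0 means flat)."""
    if c == 0:
        return 0.0
    R = 1.0 / c
    return R - math.copysign(math.sqrt(R * R - r * r), R)

def min_gap(r, z1, c1, z2, c2, n=8):
    """Smallest axial gap between two surfaces at n radii from 0.5*r to r."""
    pts = [0.5 * r + 0.5 * r * i / (n - 1) for i in range(n)]
    return min((z2 + sag_sphere(p, c2)) - (z1 + sag_sphere(p, c1)) for p in pts)

def safe_radius(r_check, z1, c1, z2, c2, gap_min=0.05, iters=20):
    """Bisection for the largest radius whose sampled gap is still >= gap_min."""
    lo, hi = 0.0, r_check
    for _ in range(iters):
        mid = (lo + hi) / 2
        if min_gap(mid, z1, c1, z2, c2) >= gap_min:
            lo = mid  # still clear: try a larger radius
        else:
            hi = mid  # too close: shrink
    return lo

# Two surfaces 1 mm apart whose edges curve toward each other
args = (0.0, 1 / 15.0, 1.0, -1 / 15.0)
if min_gap(6.0, *args) < 0.05:
    print(round(safe_radius(6.0, *args), 3))
```

For this pair the surfaces collide well before a 6 mm radius, and the search settles at about 3.74 mm, where the edge gap equals 0.05 mm.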

Parameters:

Name Type Description Default
expand_factor float

Fractional expansion applied to the ray-traced clear aperture radius. If None, 0.10 (10 %) is used. The resulting expansion is clamped to between 0.1 mm and 2.0 mm.

None
mounting_margin float

Absolute margin [mm] added to the clear aperture for mechanical mounting. When given, this replaces the proportional expand_factor expansion.

None
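The margin rule in step 3 of the source listing reduces to the small function below: a proportional expansion clamped to between 0.1 mm and 2.0 mm, or an absolute mounting_margin that replaces it. The helper pruned_radius is hypothetical and ignores the air-gap and stop-radius adjustments that prune_surf applies afterwards.

```python
def pruned_radius(r_clear, expand_factor=None, mounting_margin=None):
    """Hypothetical helper: semi-diameter from a ray-traced clear radius (mm)."""
    if expand_factor is None:
        expand_factor = 0.10  # prune_surf's default
    if mounting_margin is not None:
        return r_clear + mounting_margin  # absolute margin replaces the proportional one
    r_expand = r_clear * expand_factor
    r_expand = max(min(r_expand, 2.0), 0.1)  # clamp the expansion to [0.1, 2.0] mm
    return r_clear + r_expand

for r in (0.5, 5.0, 50.0):
    print(r, pruned_radius(r), pruned_radius(r, mounting_margin=0.3))
```

The clamp means small surfaces always get at least 0.1 mm of margin and large ones never more than 2 mm, regardless of expand_factor.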
Source code in deeplens/optics/geolens_pkg/optim_ops.py
@torch.no_grad()
def prune_surf(self, expand_factor=None, mounting_margin=None):
    """Prune surfaces to allow all valid rays to go through.

    Determines the clear aperture for each surface by ray tracing, then
    applies margins and enforces manufacturability constraints (edge
    thickness and air-gap clearance).

    Args:
        expand_factor (float, optional): Fractional expansion applied to
            the ray-traced clear aperture radius.  Auto-selected if None:
            10 % for all lenses.
        mounting_margin (float, optional): Absolute margin [mm] added to
            the clear aperture for mechanical mounting.  When given, this
            replaces the proportional ``expand_factor`` expansion.
    """
    surface_range = self.find_diff_surf()
    num_surfs = len(self.surfaces)

    # Set expansion factor
    if expand_factor is None:
        expand_factor = 0.10

    # ------------------------------------------------------------------
    # 1. Temporarily remove radius limits so the trace is unclipped
    # ------------------------------------------------------------------
    saved_radii = [self.surfaces[i].r for i in range(num_surfs)]
    for i in surface_range:
        self.surfaces[i].r = self.surfaces[i].max_height()

    # ------------------------------------------------------------------
    # 2. Trace rays at full FoV to find maximum ray height per surface
    # ------------------------------------------------------------------
    if self.rfov is not None:
        fov_deg = self.rfov * 180 / torch.pi
    else:
        fov = np.arctan(self.r_sensor / self.foclen)
        fov_deg = float(fov) * 180 / torch.pi
        print(f"Using fov_deg: {fov_deg} during surface pruning.")

    fov_y = [f * fov_deg / 10 for f in range(0, 11)]
    ray = self.sample_from_fov(
        fov_x=[0.0], fov_y=fov_y, num_rays=SPP_CALC, scale_pupil=1.0
    )
    _, ray_o_record = self.trace2sensor(ray=ray, record=True)

    # Ray record, shape [num_rays, num_surfaces + 2, 3]
    ray_o_record = torch.stack(ray_o_record, dim=-2)
    ray_o_record = torch.nan_to_num(ray_o_record, 0.0)
    ray_o_record = ray_o_record.reshape(-1, ray_o_record.shape[-2], 3)

    # Compute the maximum ray height for each surface
    ray_r_record = (ray_o_record[..., :2] ** 2).sum(-1).sqrt()
    surf_r_max = ray_r_record.max(dim=0)[0][1:-1]

    # Restore original radii before updating
    for i in range(num_surfs):
        self.surfaces[i].r = saved_radii[i]

    # ------------------------------------------------------------------
    # 3. Set new surface radii = ray-traced clear aperture + margin
    # ------------------------------------------------------------------
    for i in surface_range:
        if surf_r_max[i] > 0:
            r_clear = surf_r_max[i].item()
            if mounting_margin is not None:
                r_new = r_clear + mounting_margin
            else:
                r_expand = r_clear * expand_factor
                r_expand = max(min(r_expand, 2.0), 0.1)
                r_new = r_clear + r_expand
            self.surfaces[i].update_r(r_new)
        else:
            print(f"No valid rays for Surf {i}, expand existing radius.")
            if mounting_margin is not None:
                self.surfaces[i].update_r(self.surfaces[i].r + mounting_margin)
            else:
                r_expand = self.surfaces[i].r * expand_factor
                r_expand = max(min(r_expand, 2.0), 0.1)
                self.surfaces[i].update_r(self.surfaces[i].r + r_expand)

    # ------------------------------------------------------------------
    # 4. Air gap clearance check
    #    For each air gap (surface i with mat2 = "air"), ensure that
    #    surfaces do not physically intersect at the clear aperture edge.
    # ------------------------------------------------------------------
    if self.r_sensor < 10.0:
        air_gap_min = 0.05  # mm
    else:
        air_gap_min = 0.1  # mm

    for i in range(num_surfs - 1):
        if self.surfaces[i].mat2.name != "air":
            continue
        if isinstance(self.surfaces[i], Aperture):
            continue

        curr = self.surfaces[i]
        nxt = self.surfaces[i + 1]
        r_check = min(curr.r, nxt.r)

        if r_check <= 0:
            continue

        # Check gap at multiple radial points along the edge
        r_pts = torch.linspace(0.5 * r_check, r_check, 8, device=self.device)
        z_curr = curr.surface_with_offset(r_pts, 0.0, valid_check=False)
        z_nxt = nxt.surface_with_offset(r_pts, 0.0, valid_check=False)
        min_gap = (z_nxt - z_curr).min().item()

        if min_gap < air_gap_min:
            # Shrink radius until air gap is met (binary search)
            r_lo, r_hi = 0.0, r_check
            for _ in range(20):
                r_mid = (r_lo + r_hi) / 2
                r_pts = torch.linspace(0.5 * r_mid, r_mid, 8, device=self.device)
                z_c = curr.surface_with_offset(r_pts, 0.0, valid_check=False)
                z_n = nxt.surface_with_offset(r_pts, 0.0, valid_check=False)
                if (z_n - z_c).min().item() >= air_gap_min:
                    r_lo = r_mid
                else:
                    r_hi = r_mid

            r_safe = r_lo
            if r_safe > 0 and r_safe < r_check:
                print(
                    f"Surf {i}-{i+1}: air gap {min_gap:.3f} mm "
                    f"< {air_gap_min} mm, shrinking radius {r_check:.3f} -> {r_safe:.3f} mm."
                )
                if curr.r > r_safe:
                    curr.update_r(r_safe)
                if nxt.r > r_safe:
                    nxt.update_r(r_safe)

    # ------------------------------------------------------------------
    # 5. Validate aperture radius consistency
    #    The aperture (stop) radius should not exceed the clear aperture
    #    of its neighboring surfaces.
    # ------------------------------------------------------------------
    if self.aper_idx is not None:
        aper = self.surfaces[self.aper_idx]
        # Find neighboring non-aperture surfaces
        neighbor_r = []
        if self.aper_idx > 0:
            neighbor_r.append(self.surfaces[self.aper_idx - 1].r)
        if self.aper_idx < num_surfs - 1:
            neighbor_r.append(self.surfaces[self.aper_idx + 1].r)

        if neighbor_r:
            max_aper_r = min(neighbor_r)
            if aper.r > max_aper_r:
                print(
                    f"Aperture radius {aper.r:.3f} mm exceeds neighbor "
                    f"clear aperture {max_aper_r:.3f} mm, clamping."
                )
                aper.r = max_aper_r

correct_shape

correct_shape(expand_factor=None, mounting_margin=None)

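Correct invalid lens geometry during lens design optimization.

Applies the following correction rules:
  1. Shift all surfaces and the sensor so that the first surface sits at z = 0.0.
  2. If the aperture stop is the first surface, place the next surface 0.05 mm behind it, plus that surface's sag at the stop radius when its edge curves toward the stop.
  3. Prune all surfaces with prune_surf() so that valid rays pass through.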
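Rules 1 and 2 only move axial positions. The sketch below applies them to a plain list of surface z positions d and a sensor position d_sensor, using the 0.05 mm stop gap from the source listing. The helper correct_positions and its sag_next argument (the sag of surface 1 at the stop radius) are hypothetical. It also assumes every surface after the stop is shifted in rule 2, whereas the source shifts the surfaces returned by find_diff_surf().

```python
def correct_positions(d, d_sensor, aper_idx, sag_next, d_aper=0.05):
    """Hypothetical helper applying rules 1 and 2 to surface z positions d (mm).

    sag_next is the sag of surface 1 at the stop radius (negative = curves toward the stop).
    """
    # Rule 1: shift everything so that the first surface sits at z = 0
    shift = d[0]
    d = [z - shift for z in d]
    d_sensor -= shift
    # Rule 2: with the stop in front, put surface 1 a fixed gap behind it
    if aper_idx == 0:
        gap = d_aper + max(0.0, -sag_next)  # add the concave sag so the edge clears the stop
        delta = d[1] - gap
        d = [d[0]] + [z - delta for z in d[1:]]  # assumes all surfaces after the stop move
        d_sensor -= delta
    return d, d_sensor

print(correct_positions([2.0, 3.0, 8.0], 20.0, aper_idx=0, sag_next=-0.3))
```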

Parameters:

Name Type Description Default
expand_factor float

Fractional expansion factor passed to prune_surf(). If None, prune_surf() uses 0.10. Defaults to None.

None
mounting_margin float

Absolute mounting margin [mm] for surface pruning, passed through to prune_surf().

None

Returns:

Name Type Description
bool

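True if any shape corrections were made, False otherwise. In the source shown here, shape_changed is never set to True, so the method currently always returns False.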

Source code in deeplens/optics/geolens_pkg/optim_ops.py
@torch.no_grad()
def correct_shape(self, expand_factor=None, mounting_margin=None):
    """Correct wrong lens shape during lens design optimization.

    Applies correction rules to ensure valid lens geometry:
        1. Move the first surface to z = 0.0
        2. Fix aperture distance if aperture is at the front
        3. Prune all surfaces to allow valid rays through

    Args:
        expand_factor (float, optional): Height expansion factor for surface pruning.
            If None, auto-selects based on lens type. Defaults to None.
        mounting_margin (float, optional): Absolute mounting margin [mm] for
            surface pruning.  Passed through to :meth:`prune_surf`.

    Returns:
        bool: True if any shape corrections were made, False otherwise.
    """
    aper_idx = self.aper_idx
    optim_surf_range = self.find_diff_surf()
    shape_changed = False

    # Rule 1: Move the first surface to z = 0.0
    move_dist = self.surfaces[0].d.item()
    for surf in self.surfaces:
        surf.d -= move_dist
    self.d_sensor -= move_dist

    # Rule 2: Fix aperture distance to the first surface if aperture in the front.
    if aper_idx == 0:
        d_aper = 0.05

        # If the first surface is concave, use the maximum negative sag.
        aper_r = torch.tensor(self.surfaces[aper_idx].r, device=self.device)
        sag1 = -self.surfaces[aper_idx + 1].sag(aper_r, 0).item()

        if sag1 > 0:
            d_aper += sag1

        # Update position of all surfaces.
        delta_aper = self.surfaces[1].d.item() - d_aper
        for i in optim_surf_range:
            self.surfaces[i].d -= delta_aper
        self.d_sensor -= delta_aper

    # Rule 3: Prune all surfaces
    self.prune_surf(expand_factor=expand_factor, mounting_margin=mounting_margin)

    if shape_changed:
        print("Surface shape corrected.")
    return shape_changed

match_materials

match_materials(mat_table='CDGM')

Match lens materials to a glass catalog.

Parameters:

Name Type Description Default
mat_table str

Glass catalog name. Common options include 'CDGM', 'SCHOTT', 'OHARA'. Defaults to 'CDGM'.

'CDGM'
Source code in deeplens/optics/geolens_pkg/optim_ops.py
@torch.no_grad()
def match_materials(self, mat_table="CDGM"):
    """Match lens materials to a glass catalog.

    Args:
        mat_table (str, optional): Glass catalog name. Common options include
            'CDGM', 'SCHOTT', 'OHARA'. Defaults to 'CDGM'.
    """
    for surf in self.surfaces:
        surf.mat2.match_material(mat_table=mat_table)
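match_materials delegates to each material's match_material(), and that method's matching rule is not shown here. For illustration only, a nearest-neighbour lookup on refractive index nd and Abbe number Vd, using a weighted distance, might look like the sketch below. The three-glass catalog (approximate catalog values), the weight, and match_glass are all assumptions. None of them are the DeepLens implementation.

```python
# Illustrative sketch: a tiny catalog of (nd, Vd) pairs, not a real glass table.
CATALOG = {
    "N-BK7": (1.51680, 64.17),
    "F2": (1.62004, 36.37),
    "N-SF11": (1.78472, 25.68),
}


def match_glass(nd, vd, catalog=CATALOG, vd_weight=0.01):
    """Return the catalog glass closest to (nd, vd) under a weighted L2 distance.

    Assumption: Vd differences are scaled by vd_weight so that both terms
    have a comparable magnitude.
    """
    def distance(item):
        cat_nd, cat_vd = item[1]
        return ((nd - cat_nd) ** 2 + (vd_weight * (vd - cat_vd)) ** 2) ** 0.5

    name, _ = min(catalog.items(), key=distance)
    return name
```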

deeplens.optics.geolens_pkg.io.GeoLensIO

Mixin providing file I/O for GeoLens.

Supports reading and writing lens prescriptions in three formats:

  • JSON (primary): human-readable, supports parenthesised optimisable parameters, e.g. "(d)": 5.0.
  • Zemax .zmx: industry-standard sequential lens file.
  • CODE V .seq: CODE V sequential lens file.
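The sketch below shows one way the parenthesised-key convention for JSON could be read back. It assumes that a key such as "(d)" simply marks parameter d as optimisable. split_optimisable is a hypothetical helper and is not part of the GeoLensIO API.

```python
def split_optimisable(params):
    """Split a surface dict into plain values and the set of optimisable names.

    Assumption: a key wrapped in parentheses, e.g. "(d)", marks that
    parameter as optimisable; the value itself is stored unchanged.
    """
    values, optimisable = {}, set()
    for key, val in params.items():
        if key.startswith("(") and key.endswith(")"):
            name = key[1:-1]  # strip the parentheses
            optimisable.add(name)
        else:
            name = key
        values[name] = val
    return values, optimisable
```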

This class is not instantiated directly; it is mixed into GeoLens (deeplens.optics.geolens.GeoLens).

read_lens_zmx

read_lens_zmx(filename='./test.zmx')

Load the lens from a Zemax .zmx sequential lens file.

Parses STANDARD and EVENASPH surface types, glass materials, field definitions (YFLN), and entrance pupil settings (ENPD/FLOA).

Parameters:

  • filename (str): Path to the .zmx file. Supports both UTF-8 and UTF-16 encoded files. Default: './test.zmx'.

Returns:

  • GeoLens: self, for method chaining.

Source code in deeplens/optics/geolens_pkg/io.py
def read_lens_zmx(self, filename="./test.zmx"):
    """Load the lens from a Zemax .zmx sequential lens file.

    Parses STANDARD and EVENASPH surface types, glass materials, field
    definitions (YFLN), and entrance pupil settings (ENPD/FLOA).

    Args:
        filename (str, optional): Path to the .zmx file. Supports both
            UTF-8 and UTF-16 encoded files. Defaults to './test.zmx'.

    Returns:
        GeoLens: ``self``, for method chaining.
    """
    # Read .zmx file
    try:
        with open(filename, "r", encoding="utf-8") as file:
            lines = file.readlines()
    except UnicodeDecodeError:
        with open(filename, "r", encoding="utf-16") as file:
            lines = file.readlines()

    # Iterate through the lines and extract SURF dict
    surfs_dict = {}
    current_surf = None
    for line in lines:
        # Strip leading/trailing whitespace for consistent parsing
        stripped_line = line.strip()

        if stripped_line.startswith("SURF"):
            current_surf = int(stripped_line.split()[1])
            surfs_dict[current_surf] = {}

        elif current_surf is not None and stripped_line != "":
            if len(stripped_line.split(maxsplit=1)) == 1:
                continue
            else:
                key, value = stripped_line.split(maxsplit=1)
                if key == "PARM":
                    new_key = "PARM" + value.split()[0]
                    new_value = value.split()[1]
                    surfs_dict[current_surf][new_key] = new_value
                else:
                    surfs_dict[current_surf][key] = value

        elif stripped_line.startswith("FLOA") or stripped_line.startswith("ENPD"):
            if stripped_line.startswith("FLOA"):
                self.float_enpd = True
                self.enpd = None
            else:
                self.float_enpd = False
                self.enpd = float(stripped_line.split()[1])

        elif stripped_line.startswith("YFLN"):
            # Parse field of view from YFLN line (field coordinates in degrees)
            # YFLN format: YFLN 0.0 <0.707*rfov_deg> <0.99*rfov_deg>
            parts = stripped_line.split()
            if len(parts) > 1:
                field_values = [abs(float(x)) for x in parts[1:] if float(x) != 0.0]
                if field_values:
                    # The largest field value is typically 0.99 * rfov_deg
                    max_field_deg = max(field_values) / 0.99
                    self.rfov = (
                        max_field_deg * math.pi / 180.0
                    )  # Convert to radians

    self.float_foclen = False
    self.float_rfov = False
    # Set default rfov if not parsed from file
    if not hasattr(self, "rfov"):
        self.rfov = None

    # Read the extracted data from each SURF
    self.surfaces = []
    d = 0.0
    mat1_name = "air"
    for surf_idx, surf_dict in surfs_dict.items():
        if surf_idx > 0 and surf_idx < current_surf:
            # Lens surface parameters
            if "GLAS" in surf_dict:
                if surf_dict["GLAS"].split()[0] == "___BLANK":
                    mat2_name = f"{surf_dict['GLAS'].split()[3]}/{surf_dict['GLAS'].split()[4]}"
                else:
                    mat2_name = surf_dict["GLAS"].split()[0].lower()
            else:
                mat2_name = "air"

            surf_r = (
                float(surf_dict["DIAM"].split()[0]) if "DIAM" in surf_dict else 1.0
            )
            surf_c = (
                float(surf_dict["CURV"].split()[0]) if "CURV" in surf_dict else 0.0
            )
            surf_d_next = (
                float(surf_dict["DISZ"].split()[0]) if "DISZ" in surf_dict else 0.0
            )
            surf_conic = float(surf_dict.get("CONI", 0.0))
            surf_param2 = float(surf_dict.get("PARM2", 0.0))
            surf_param3 = float(surf_dict.get("PARM3", 0.0))
            surf_param4 = float(surf_dict.get("PARM4", 0.0))
            surf_param5 = float(surf_dict.get("PARM5", 0.0))
            surf_param6 = float(surf_dict.get("PARM6", 0.0))
            surf_param7 = float(surf_dict.get("PARM7", 0.0))
            surf_param8 = float(surf_dict.get("PARM8", 0.0))

            # Create surface object
            if surf_dict["TYPE"] == "STANDARD":
                if mat2_name == "air" and mat1_name == "air":
                    # Aperture
                    s = Aperture(r=surf_r, d=d)
                else:
                    # Spherical surface
                    s = Spheric(c=surf_c, r=surf_r, d=d, mat2=mat2_name)

            elif surf_dict["TYPE"] == "EVENASPH":
                # Aspherical surface
                s = Aspheric(
                    c=surf_c,
                    r=surf_r,
                    d=d,
                    ai=[
                        surf_param2,
                        surf_param3,
                        surf_param4,
                        surf_param5,
                        surf_param6,
                        surf_param7,
                        surf_param8,
                    ],
                    k=surf_conic,
                    mat2=mat2_name,
                )

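            else:
                # Keep the thickness of unsupported surfaces so later surfaces stay in place.
                print(f"Surface type {surf_dict['TYPE']} not implemented; skipping it.")
                d += surf_d_next
                mat1_name = mat2_name
                continue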

            self.surfaces.append(s)
            d += surf_d_next
            mat1_name = mat2_name

        elif surf_idx == current_surf:
            # Image sensor
            self.r_sensor = float(surf_dict["DIAM"].split()[0])

        else:
            pass

    self.d_sensor = torch.tensor(d)
    return self
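The first pass of read_lens_zmx groups each SURF block's keyword lines into a dictionary. It then turns the DISZ thicknesses into absolute positions. Below is a standalone sketch of that flow on plain text. parse_zmx_surfaces, glass_name, and surface_positions are hypothetical helpers. Unlike the method, they neither build surface objects nor skip unsupported surface types.

```python
def parse_zmx_surfaces(text):
    """Group .zmx lines into {surface_index: {keyword: value}}.

    A "SURF n" line opens a new block; later "KEY value" lines are stored
    under it, and PARM lines are keyed by their number (e.g. "PARM2").
    """
    surfs, current = {}, None
    for raw in text.splitlines():
        line = raw.strip()
        if line.startswith("SURF"):
            current = int(line.split()[1])
            surfs[current] = {}
        elif current is not None and line:
            parts = line.split(maxsplit=1)
            if len(parts) == 1:
                continue  # keyword without a value
            key, value = parts
            if key == "PARM":
                num, val = value.split()[:2]
                surfs[current][f"PARM{num}"] = val
            else:
                surfs[current][key] = value
    return surfs


def glass_name(glas_value):
    """Material name from a GLAS entry; a ___BLANK model glass becomes "nd/Vd"."""
    fields = glas_value.split()
    if fields[0] == "___BLANK":
        # Fields 3 and 4 are the ones read_lens_zmx uses as nd and Vd.
        return f"{fields[3]}/{fields[4]}"
    return fields[0].lower()


def surface_positions(surfs):
    """Absolute z of each lens surface (object and image excluded) from DISZ."""
    last = max(surfs)
    z, positions = 0.0, []
    for idx in sorted(surfs):
        if 0 < idx < last:
            positions.append(z)
            z += float(surfs[idx].get("DISZ", "0").split()[0])
    return positions, z  # the final z is the sensor distance
```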

write_lens_zmx

write_lens_zmx(filename='./test.zmx')

Write the lens to a Zemax .zmx sequential lens file.

Exports surfaces (STANDARD or EVENASPH), materials, field definitions, and entrance pupil settings in Zemax OpticStudio format.

Parameters:

  • filename (str): Output file path. Default: './test.zmx'.

Source code in deeplens/optics/geolens_pkg/io.py
def write_lens_zmx(self, filename="./test.zmx"):
    """Write the lens to a Zemax .zmx sequential lens file.

    Exports surfaces (STANDARD or EVENASPH), materials, field definitions,
    and entrance pupil settings in Zemax OpticStudio format.

    Args:
        filename (str, optional): Output file path. Defaults to './test.zmx'.
    """
    lens_zmx_str = ""
    if self.float_enpd:
        enpd_str = "FLOA"
    else:
        enpd_str = f"ENPD {self.enpd}"
    # Head string
    head_str = f"""VERS 190513 80 123457 L123457
MODE SEQ
NAME 
PFIL 0 0 0
LANG 0
UNIT MM X W X CM MR CPMM
{enpd_str}
ENVD 2.0E+1 1 0
GFAC 0 0
GCAT OSAKAGASCHEMICAL MISC
XFLN 0. 0. 0.
YFLN 0.0 {0.707 * math.degrees(self.rfov)} {0.99 * math.degrees(self.rfov)}
WAVL 0.4861327 0.5875618 0.6562725
RAIM 0 0 1 1 0 0 0 0 0
PUSH 0 0 0 0 0 0
SDMA 0 1 0
FTYP 0 0 3 3 0 0 0
ROPD 2
PICB 1
PWAV 2
POLS 1 0 1 0 0 1 0
GLRS 1 0
GSTD 0 100.000 100.000 100.000 100.000 100.000 100.000 0 1 1 0 0 1 1 1 1 1 1
NSCD 100 500 0 1.0E-3 5 1.0E-6 0 0 0 0 0 0 1000000 0 2
COFN QF "COATING.DAT" "SCATTER_PROFILE.DAT" "ABG_DATA.DAT" "PROFILE.GRD"
COFN COATING.DAT SCATTER_PROFILE.DAT ABG_DATA.DAT PROFILE.GRD
SURF 0
TYPE STANDARD
CURV 0.0
DISZ INFINITY
"""
    lens_zmx_str += head_str

    # Surface string
    for i, s in enumerate(self.surfaces):
        d_next = (
            self.surfaces[i + 1].d - self.surfaces[i].d
            if i < len(self.surfaces) - 1
            else self.d_sensor - self.surfaces[i].d
        )
        surf_str = s.zmx_str(surf_idx=i + 1, d_next=d_next)
        lens_zmx_str += surf_str

    # Sensor string
    sensor_str = f"""SURF {i + 2}
TYPE STANDARD
CURV 0.
DISZ 0.0
DIAM {self.r_sensor}
"""
    lens_zmx_str += sensor_str

    # Write lens zmx string into file
    with open(filename, "w") as f:
        f.write(lens_zmx_str)
    print(f"Lens written to {filename}")
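On the writing side, each surface's DISZ is the gap to the next surface, or to the sensor for the last surface. The YFLN line stores 0, 0.707, and 0.99 of rfov in degrees, and read_lens_zmx inverts this by dividing the largest entry by 0.99. The sketch below uses hypothetical helper names and relies on math.degrees and math.radians for the unit conversion.

```python
import math


def disz_values(z_positions, d_sensor):
    """Thickness after each surface: gap to the next surface, last one to the sensor."""
    nexts = list(z_positions[1:]) + [d_sensor]
    return [b - a for a, b in zip(z_positions, nexts)]


def yfln_fields(rfov):
    """Field angles [deg] for the YFLN line: 0, 0.707 and 0.99 of rfov [rad]."""
    rfov_deg = math.degrees(rfov)
    return [0.0, 0.707 * rfov_deg, 0.99 * rfov_deg]


def rfov_from_yfln(fields_deg):
    """Inverse used when reading: the largest non-zero field is 0.99 * rfov."""
    nonzero = [abs(f) for f in fields_deg if f != 0.0]
    return math.radians(max(nonzero) / 0.99)
```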

read_lens_seq

read_lens_seq(filename='./test.seq')

Load the lens from a CODE V .seq sequential file.

Parses standard and aspheric surfaces (with conic and polynomial coefficients A–I), entrance pupil diameter (EPD), field angles (YAN), aperture stop (STO), and image surface (SI).

Parameters:

  • filename (str): Path to the .seq file. Supports both UTF-8 and Latin-1 encoded files. Default: './test.seq'.

Returns:

  • GeoLens: self, for method chaining.

Source code in deeplens/optics/geolens_pkg/io.py
def read_lens_seq(self, filename="./test.seq"):
    """Load the lens from a CODE V .seq sequential file.

    Parses standard and aspheric surfaces (with conic and polynomial
    coefficients A–I), entrance pupil diameter (EPD), field angles (YAN),
    aperture stop (STO), and image surface (SI).

    Args:
        filename (str, optional): Path to the .seq file. Supports both
            UTF-8 and Latin-1 encoded files. Defaults to './test.seq'.

    Returns:
        GeoLens: ``self``, for method chaining.
    """
    print(f"\n{'=' * 60}")
    print(f"Start reading CODE V file: {filename}")
    print(f"{'=' * 60}\n")

    # Read .seq file
    try:
        with open(filename, "r", encoding="utf-8") as file:
            lines = file.readlines()
        print(f"File read successfully (UTF-8)")
    except UnicodeDecodeError:
        try:
            with open(filename, "r", encoding="latin-1") as file:
                lines = file.readlines()
            print(f"File read successfully (Latin-1)")
        except Exception as e:
            print(f"Failed to read file: {e}")
            return self
    print(f"Total lines: {len(lines)}\n")

    # ============ Step 1: Parse file structure ============
    surfaces = []
    current_surface = {}
    surface_index = 0
    global_diameter = None

    print("Beginning to parse surface data...\n")

    for line_num, line in enumerate(lines, 1):
        line = line.strip()

        # Skip irrelevant lines
        if not line or line.startswith(
            (
                "RDM",
                "TITLE",
                "UID",
                "GO",
                "WL",
                "XAN",
                "REF",
                "WTW",
                "INI",
                "WTF",
                "VUY",
                "VLY",
                "DOR",
                "DIM",
                "THC",
            )
        ):
            continue
        # Read entrance pupil diameter
        if line.startswith("EPD"):
            self.enpd = float(line.split()[1])
            self.float_enpd = False
            global_diameter = self.enpd / 2.0
            print(
                f"[Line {line_num}] EPD={self.enpd} -> default radius={global_diameter}"
            )
            continue
        # Read field of view angle
        if line.startswith("YAN"):
            angles = [abs(float(x)) for x in line.split()[1:] if float(x) != 0.0]
            if angles:
                self.hfov = max(angles)
                # Also set rfov in radians for consistency with write functions
                self.rfov = self.hfov * math.pi / 180.0
                print(f"[Line {line_num}] Max field of view={self.hfov} deg")
            continue
        # Object surface
        if line.startswith("SO"):
            parts = line.split()
            thickness = float(parts[2]) if len(parts) > 2 else 1e10

            current_surface = {
                "type": "OBJECT",
                "thickness": thickness,
                "index": surface_index,
            }
            surfaces.append(current_surface)
            print(f"[Line {line_num}] Object surface: T={thickness}")
            surface_index += 1
            current_surface = {}
            continue
        # Standard surface
        if line.startswith("S "):
            # Save the previous surface
            if current_surface:
                surfaces.append(current_surface)
                surface_index += 1

            parts = line.split()
            radius_value = float(parts[1]) if len(parts) > 1 else 0.0
            thickness = float(parts[2]) if len(parts) > 2 else 0.0
            material = parts[3].upper() if len(parts) > 3 else "AIR"

            # Key: compute curvature C = 1/R
            if abs(radius_value) > 1e-10:
                curvature = 1.0 / radius_value
            else:
                curvature = 0.0

            current_surface = {
                "type": "STANDARD",
                "radius": radius_value,
                "curvature": curvature,
                "thickness": thickness,
                "material": material,
                "index": surface_index,
                "diameter": global_diameter,
                "conic": 0.0,
                "asph_coeffs": {},
                "is_stop": False,
            }

            print(
                f"[Line {line_num}] Surface{surface_index}: R={radius_value:.4f} → C={curvature:.6f}, T={thickness}, Mat={material}"
            )
            continue
        # Image surface - do not append yet, wait for CIR
        if line.startswith("SI"):
            if current_surface:
                surfaces.append(current_surface)
                surface_index += 1

            parts = line.split()
            thickness = float(parts[1]) if len(parts) > 1 else 0.0

            current_surface = {
                "type": "IMAGE",
                "thickness": thickness,
                "diameter": None,  # Set to None first, wait for CIR line to update
                "index": surface_index,
            }
            print(f"[Line {line_num}] Image surface")
            continue
        # Handle surface attributes (CIR, STO, ASP, K, A~J, etc.)
        if current_surface:
            if line.startswith("CIR"):
                current_surface["diameter"] = float(
                    line.split()[1].replace(";", "")
                )
                print(f"[Line {line_num}]   → CIR={current_surface['diameter']}")

            elif line.startswith("STO"):
                current_surface["is_stop"] = True
                print(f"[Line {line_num}]   → Aperture stop flag")

            elif line.startswith("ASP"):
                current_surface["type"] = "ASPHERIC"
                print(f"[Line {line_num}]   → Aspheric surface")

            elif line.startswith("K "):
                current_surface["conic"] = float(line.split()[1].replace(";", ""))
                print(f"[Line {line_num}]   → K={current_surface['conic']}")

            # Only extract single-letter coefficients A-J
            elif any(
                line.startswith(p)
                for p in [
                    "A ",
                    "B ",
                    "C ",
                    "D ",
                    "E ",
                    "F ",
                    "G ",
                    "H ",
                    "I ",
                    "J ",
                ]
            ):
                parts = line.replace(";", "").split()
                i = 0
                while i < len(parts) - 1:
                    try:
                        key = parts[i]
                        # Only accept single letters within the range A-J
                        if len(key) == 1 and key in [
                            "A",
                            "B",
                            "C",
                            "D",
                            "E",
                            "F",
                            "G",
                            "H",
                            "I",
                            "J",
                        ]:
                            value = float(parts[i + 1])
                            current_surface["asph_coeffs"][key] = value
                            print(f"[Line {line_num}]   → {key}={value}")
                        i += 2
                    except ValueError:
                        i += 1

    # Save the last surface
    if current_surface:
        surfaces.append(current_surface)

    print(f"\nParsing complete, total {len(surfaces)} surfaces\n")

    # ============ Step 2: Create surface objects ============
    print(f"{'=' * 60}")
    print("Start creating surface objects:")
    print(f"{'=' * 60}\n")

    self.surfaces = []
    d = 0.0  # Cumulative distance from the first optical surface to the current surface
    previous_material = "air"

    for surf in surfaces:
        surf_idx = surf["index"]
        surf_type = surf["type"]

        print(f"{'=' * 50}")
        print(f"Processing surface{surf_idx} ({surf_type}), current d={d:.4f}")

        # Handle object surface
        if surf_type == "OBJECT":
            obj_thickness = surf["thickness"]
            if obj_thickness < 1e9:  # Finite object distance
                d += obj_thickness
                print(
                    f"   Object surface thickness={obj_thickness} → accumulated d={d:.4f}"
                )
            else:
                print("   Object surface at infinity")
            previous_material = "air"
            continue

        # Handle image surface
        if surf_type == "IMAGE":
            self.d_sensor = torch.tensor(d)
            # Read diameter from surf dictionary (CIR value)
            self.r_sensor = (
                surf.get("diameter") if surf.get("diameter") is not None else 18.0
            )
            print(
                f"   Image plane position: d_sensor={d:.4f}, r_sensor={self.r_sensor:.4f}"
            )
            break

        # Get surface parameters
        current_material = surf.get("material", "AIR")
        if current_material in ["AIR", "0.0", "", None]:
            current_material = "air"
        else:
            current_material = current_material.lower()

        c = surf.get("curvature", 0.0)
        r = surf.get("diameter")
        if r is None:
            r = 10.0  # fallback radius when neither EPD nor CIR was given
        d_next = surf.get("thickness", 0.0)
        is_stop = surf.get("is_stop", False)

        print(f"   C={c:.6f}, R_aperture={r:.4f}, T={d_next:.4f}")
        print(f"   Material: {previous_material} → {current_material}")
        print(f"   is_stop={is_stop}")

        # Create surface object
        try:
            # Case 1: pure aperture (air on both sides + STO flag)
            if is_stop and current_material == "air" and previous_material == "air":
                aperture = Aperture(r=r, d=d)
                self.surfaces.append(aperture)
                print(f"   Created pure aperture: Aperture(r={r:.4f}, d={d:.4f})")

            # Case 2: refractive surface (material change)
            elif current_material != previous_material:
                if surf_type == "STANDARD":
                    s = Spheric(c=c, r=r, d=d, mat2=current_material)
                    self.surfaces.append(s)
                    status = " (stop surface)" if is_stop else ""
                    print(
                        f"   Created spherical surface{status}: Spheric(c={c:.6f}, r={r:.4f}, d={d:.4f}, mat2='{current_material}')"
                    )

                elif surf_type == "ASPHERIC":
                    k = surf.get("conic", 0.0)
                    asph_coeffs = surf.get("asph_coeffs", {})

                    # CODE V aspheric coefficient mapping (shift forward by one position):
                    # A → ai[1] (2nd term, ρ²)
                    # B → ai[2] (4th term, ρ⁴)
                    # C → ai[3] (6th term, ρ⁶)
                    # D → ai[4] (8th term, ρ⁸)
                    # E → ai[5] (10th term, ρ¹⁰)
                    # F → ai[6] (12th term, ρ¹²)
                    # G → ai[7] (14th term, ρ¹⁴)
                    # H → ai[8] (16th term, ρ¹⁶)
                    # I → ai[9] (18th term, ρ¹⁸)

                    # Initialize ai array (10 elements)
                    ai = [0.0] * 10
                    ai[0] = 0.0  # ρ⁰ term (unused)
                    ai[1] = asph_coeffs.get("A", 0.0)  # ρ²
                    ai[2] = asph_coeffs.get("B", 0.0)  # ρ⁴
                    ai[3] = asph_coeffs.get("C", 0.0)  # ρ⁶
                    ai[4] = asph_coeffs.get("D", 0.0)  # ρ⁸
                    ai[5] = asph_coeffs.get("E", 0.0)  # ρ¹⁰
                    ai[6] = asph_coeffs.get("F", 0.0)  # ρ¹²
                    ai[7] = asph_coeffs.get("G", 0.0)  # ρ¹⁴
                    ai[8] = asph_coeffs.get("H", 0.0)  # ρ¹⁶
                    ai[9] = asph_coeffs.get("I", 0.0)  # ρ¹⁸

                    s = Aspheric(c=c, r=r, d=d, ai=ai, k=k, mat2=current_material)
                    self.surfaces.append(s)
                    status = " (stop surface)" if is_stop else ""
                    print(
                        f"   Created aspheric surface{status}: Aspheric(c={c:.6f}, r={r:.4f}, d={d:.4f}, k={k}, mat2='{current_material}')"
                    )
                    if any(
                        ai[1:]
                    ):  # If there are non-zero higher-order terms (starting from ai[1])
                        print(
                            f"      Aspheric coefficients: A={ai[1]:.2e}, B={ai[2]:.2e}, C={ai[3]:.2e}, D={ai[4]:.2e}"
                        )

            else:
                print(f"   Skipped (same material on both sides and no stop flag)")

        except Exception as e:
            print(f"   Failed to create surface: {e}")
            import traceback

            traceback.print_exc()

        # Key: accumulate distance at the end of the loop
        d += d_next
        print(f"   After accumulation: d={d:.4f}")
        previous_material = current_material

    print(f"\n{'=' * 60}")
    print(f"   Done! Created {len(self.surfaces)} objects")
    print(f"   d_sensor={self.d_sensor:.4f}")
    print(f"   r_sensor={self.r_sensor:.4f}")
    print(f"   hfov={self.hfov:.4f}°")
    print(f"{'=' * 60}\n")

    return self
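In read_lens_seq, two line types carry most of the geometry. The first is the surface line "S R T [GLASS]", whose radius is inverted into a curvature. The second is the single-letter aspheric coefficient line. Below is a simplified sketch of both parsers. The function names are hypothetical. Unlike the method's loop, parse_asph_coeffs assumes well-formed "KEY value" pairs.

```python
def parse_seq_surface(line):
    """Parse a CODE V "S R T [GLASS]" line into (curvature, thickness, material).

    A radius of 0 denotes a flat surface (curvature 0); a missing glass
    name means air.
    """
    parts = line.split()
    radius = float(parts[1]) if len(parts) > 1 else 0.0
    thickness = float(parts[2]) if len(parts) > 2 else 0.0
    material = parts[3].lower() if len(parts) > 3 else "air"
    curvature = 1.0 / radius if abs(radius) > 1e-10 else 0.0
    return curvature, thickness, material


def parse_asph_coeffs(line):
    """Collect single-letter A-J coefficients from e.g. "A 1e-5; B -2e-7"."""
    parts = line.replace(";", "").split()
    coeffs = {}
    # Pair each keyword with the token that follows it.
    for key, value in zip(parts[0::2], parts[1::2]):
        if len(key) == 1 and key in "ABCDEFGHIJ":
            coeffs[key] = float(value)
    return coeffs
```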

write_lens_seq

write_lens_seq(filename='./test.seq')

Write the lens to a CODE V .seq sequential file.

Exports surfaces, materials, field definitions, and entrance pupil settings in CODE V format.

Parameters:

  • filename (str): Output file path. Default: './test.seq'.

Returns:

  • GeoLens: self, for method chaining.

Source code in deeplens/optics/geolens_pkg/io.py
def write_lens_seq(self, filename="./test.seq"):
    """Write the lens to a CODE V .seq sequential file.

    Exports surfaces, materials, field definitions, and entrance pupil
    settings in CODE V format.

    Args:
        filename (str, optional): Output file path. Defaults to './test.seq'.

    Returns:
        GeoLens: ``self``, for method chaining.
    """

    import datetime

    current_date = datetime.datetime.now().strftime("%d-%b-%Y")

    head_str = f"""RDM;LEN       "VERSION: 2023.03       LENS VERSION: 89       Creation Date:  {current_date}"
TITLE 'Lens Design'
EPD   {self.enpd}
DIM   M
WL    650.0 550.0 480.0
REF   2
WTW   1 2 1
INI   '   '
XAN   0.0 0.0 0.0
YAN   0.0  {0.707 * math.degrees(self.rfov)} {0.99 * math.degrees(self.rfov)}
WTF   1.0 1.0 1.0
VUY   0.0 0.0 0.0
VLY   0.0 0.0 0.0
DOR   1.15 1.05
SO    0.0 0.1e14
"""

    lens_seq_str = head_str
    previous_material = "air"

    for i, surf in enumerate(self.surfaces):
        if i < len(self.surfaces) - 1:
            d_next = float(self.surfaces[i + 1].d - surf.d)
        else:
            d_next = float(self.d_sensor - surf.d)

        current_material = getattr(surf, "mat2", "air")

        if current_material is None or current_material == "air":
            material_str = ""
            material_name = "air"
        elif isinstance(current_material, str):
            material_str = f" {current_material.upper()}"
            material_name = current_material
        else:
            material_name = getattr(current_material, "name", str(current_material))
            material_str = f" {material_name.upper()}"

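        is_aperture = surf.__class__.__name__ == "Aperture"

        if is_aperture:
            # Write the aperture as a flat air surface flagged as the stop; skipping it
            # would drop its thickness to the next surface from the file.
            lens_seq_str += f"S     0.0 {float(d_next)}\n"
            lens_seq_str += "  CCY 0; THC 0\n"
            lens_seq_str += "  STO\n"
            lens_seq_str += f"  CIR {float(surf.r)}\n"
            previous_material = "air"
            continue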

        is_aspheric = surf.__class__.__name__ == "Aspheric"
        is_stop_surface = getattr(surf, "is_stop", False)

        if is_aspheric:
            # Convert tensors to floats so the f-strings below write plain numbers.
            if abs(float(surf.c)) > 1e-10:
                radius = 1.0 / float(surf.c)
            else:
                radius = 0.0

            k = float(surf.k) if hasattr(surf, "k") else 0.0
            ai = surf.ai if hasattr(surf, "ai") else [0.0] * 10

            surf_str = f"S     {radius} {d_next}{material_str}\n"
            surf_str += f"  CCY 0; THC 0\n"
            surf_str += f"  CIR {surf.r}\n"
            if is_stop_surface:
                surf_str += f"  STO\n"
            surf_str += f"  ASP\n"
            surf_str += f"  K   {k}\n"

            if len(ai) > 4 and any(ai[1:5]):
                # Keep A-D on one line: read_lens_seq parses coefficients line by line
                # and does not follow "&" continuations.
                surf_str += f"  A   {ai[1]:.16e}; B {ai[2]:.16e}; C {ai[3]:.16e}; D {ai[4]:.16e}\n"

            if len(ai) > 8 and any(ai[5:9]):
                surf_str += f"  E   {ai[5]:.16e}; F {ai[6]:.16e}; G {ai[7]:.16e}; H {ai[8]:.16e}\n"

        else:
            if abs(float(surf.c)) > 1e-10:
                radius = 1.0 / float(surf.c)
            else:
                radius = 0.0

            surf_str = f"S     {radius} {d_next}{material_str}\n"
            surf_str += f"  CCY 0; THC 0\n"

            if is_stop_surface:
                surf_str += f"  STO\n"

            surf_str += f"  CIR {surf.r}\n"

        lens_seq_str += surf_str
        previous_material = material_name

    sensor_str = f"SI    0.0 0.0\n"
    sensor_str += f"  CIR {self.r_sensor}\n"
    lens_seq_str += sensor_str
    lens_seq_str += "GO \n"

    with open(filename, "w") as f:
        f.write(lens_seq_str)

    print(f"Lens written to CODE V file: {filename}")
    return self
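write_lens_seq performs the reverse conversion. It turns the curvature back into a CODE V radius (0 for a flat surface) and appends the CCY/THC, STO, and CIR lines. Below is a minimal sketch of that layout for one spherical surface, under the hypothetical name format_seq_surface. It leaves out the aspheric K and A-J lines and works on plain floats rather than tensors.

```python
def format_seq_surface(c, d_next, r, material="air", is_stop=False):
    """Format one spherical CODE V surface block in write_lens_seq's layout."""
    radius = 1.0 / c if abs(c) > 1e-10 else 0.0  # CODE V uses R = 0 for flat
    glass = "" if material == "air" else f" {material.upper()}"
    lines = [f"S     {radius} {d_next}{glass}", "  CCY 0; THC 0"]
    if is_stop:
        lines.append("  STO")
    lines.append(f"  CIR {r}")
    return "\n".join(lines) + "\n"
```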

read_lens_json

read_lens_json(filename='./test.json')

Read the lens from a JSON file.

Loads lens configuration including surfaces, materials, and optical properties from the DeepLens native JSON format.

Parameters:

  • filename (str): Path to the JSON lens file. Default: './test.json'.

Note

After loading, the lens is moved to self.device and post_computation is called to calculate derived properties.

Source code in deeplens/optics/geolens_pkg/io.py
def read_lens_json(self, filename="./test.json"):
    """Read the lens from a JSON file.

    Loads lens configuration including surfaces, materials, and optical properties
    from the DeepLens native JSON format.

    Args:
        filename (str, optional): Path to the JSON lens file. Defaults to './test.json'.

    Note:
        After loading, the lens is moved to self.device and post_computation is called
        to calculate derived properties.
    """
    self.surfaces = []
    self.materials = []
    with open(filename, "r") as f:
        data = json.load(f)
        d = 0.0
        for idx, surf_dict in enumerate(data["surfaces"]):
            surf_dict["d"] = d
            surf_dict["surf_idx"] = idx

            if surf_dict["type"] == "Aperture":
                s = Aperture.init_from_dict(surf_dict)

            elif surf_dict["type"] == "Aspheric":
                s = Aspheric.init_from_dict(surf_dict)

            elif surf_dict["type"] == "Cubic":
                s = Cubic.init_from_dict(surf_dict)

            # elif surf_dict["type"] == "GaussianRBF":
            #     s = GaussianRBF.init_from_dict(surf_dict)

            # elif surf_dict["type"] == "NURBS":
            #     s = NURBS.init_from_dict(surf_dict)

            elif surf_dict["type"] == "Phase":
                s = Phase.init_from_dict(surf_dict)

            elif surf_dict["type"] == "Plane":
                s = Plane.init_from_dict(surf_dict)

            # elif surf_dict["type"] == "PolyEven":
            #     s = PolyEven.init_from_dict(surf_dict)

            elif surf_dict["type"] == "Stop":
                s = Aperture.init_from_dict(surf_dict)

            elif surf_dict["type"] == "Spheric":
                s = Spheric.init_from_dict(surf_dict)

            elif surf_dict["type"] == "ThinLens":
                s = ThinLens.init_from_dict(surf_dict)

            else:
                raise Exception(
                    f"Surface type {surf_dict['type']} is not implemented in GeoLens.read_lens_json()."
                )

            self.surfaces.append(s)
            d += surf_dict["d_next"]

    self.d_sensor = torch.tensor(d)
    self.lens_info = data.get("info", "None")
    self.enpd = data.get("enpd", None)
    self.float_enpd = True if self.enpd is None else False
    self.float_foclen = False
    self.float_rfov = False
    self.r_sensor = data["r_sensor"]

    self.to(self.device)

    # Set sensor size and resolution
    sensor_res = data.get("sensor_res", (2000, 2000))
    self.set_sensor_res(sensor_res=sensor_res)
    self.post_computation()
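read_lens_json does not take absolute positions from the file. Instead, each surface's z is the running sum of the preceding d_next values, and d_sensor is the total. Below is a standalone sketch of that accumulation. positions_from_json is a hypothetical name. It reads only d_next and r_sensor and builds no surface objects.

```python
import json


def positions_from_json(text):
    """Return (surface z positions, d_sensor, r_sensor) from a lens JSON string."""
    data = json.loads(text)
    z, positions = 0.0, []
    for surf in data["surfaces"]:
        positions.append(z)  # this surface starts where the previous gap ended
        z += surf["d_next"]
    return positions, z, data["r_sensor"]
```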

write_lens_json

write_lens_json(filename='./test.json')

Write the lens to a JSON file.

Saves the complete lens configuration including all surfaces, materials, focal length, F-number, and sensor properties to the DeepLens JSON format.

Parameters:

Name Type Description Default
filename str

Path for the output JSON file. Defaults to './test.json'.

'./test.json'
Source code in deeplens/optics/geolens_pkg/io.py
def write_lens_json(self, filename="./test.json"):
    """Write the lens to a JSON file.

    Saves the complete lens configuration including all surfaces, materials,
    focal length, F-number, and sensor properties to the DeepLens JSON format.

    Args:
        filename (str, optional): Path for the output JSON file. Defaults to './test.json'.
    """
    data = {}
    data["info"] = self.lens_info if hasattr(self, "lens_info") else "None"
    data["foclen"] = round(self.foclen, 4)
    data["fnum"] = round(self.fnum, 4)
    if self.float_enpd is False:
        data["enpd"] = round(self.enpd, 4)
    data["r_sensor"] = self.r_sensor
    data["(d_sensor)"] = round(self.d_sensor.item(), 4)
    data["(sensor_size)"] = [round(i, 4) for i in self.sensor_size]
    data["surfaces"] = []
    for i, s in enumerate(self.surfaces):
        surf_dict = {"idx": i}
        surf_dict.update(s.surf_dict())
        if i < len(self.surfaces) - 1:
            surf_dict["d_next"] = round(
                self.surfaces[i + 1].d.item() - self.surfaces[i].d.item(), 4
            )
        else:
            surf_dict["d_next"] = round(
                self.d_sensor.item() - self.surfaces[i].d.item(), 4
            )

        data["surfaces"].append(surf_dict)

    with open(filename, "w") as f:
        json.dump(data, f, indent=4)
    print(f"Lens written to {filename}")
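write_lens_json goes the other way. It stores each surface's gap to its successor as d_next and saves the sensor distance under the parenthesised key "(d_sensor)". Below is a simplified sketch. lens_json and its surf_dicts argument are hypothetical stand-ins for the method and for each surface's surf_dict(). foclen, fnum, and sensor_size are omitted.

```python
import json


def lens_json(z_positions, d_sensor, r_sensor, surf_dicts, ndigits=4):
    """Build a JSON string in a simplified version of write_lens_json's layout."""
    nexts = list(z_positions[1:]) + [d_sensor]  # the last surface's gap is to the sensor
    surfaces = []
    for i, (z, z_next, sd) in enumerate(zip(z_positions, nexts, surf_dicts)):
        surfaces.append({"idx": i, **sd, "d_next": round(z_next - z, ndigits)})
    data = {
        "r_sensor": r_sensor,
        "(d_sensor)": round(d_sensor, ndigits),
        "surfaces": surfaces,
    }
    return json.dumps(data, indent=4)
```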

deeplens.optics.geolens_pkg.vis.GeoLensVis

Mixin providing 2-D lens layout and ray visualisation for GeoLens.

Generates publication-quality cross-section plots showing lens surfaces and traced ray bundles in either the meridional or sagittal plane.

This class is not instantiated directly; it is mixed into GeoLens (deeplens.optics.geolens.GeoLens).

sample_parallel_2D

sample_parallel_2D(fov=0.0, num_rays=7, wvln=DEFAULT_WAVE, plane='meridional', entrance_pupil=True, depth=0.0)

Sample parallel rays (2D) in object space.

Used for (1) drawing the lens layout and (2) 2D geometric-optics calculations, for example, refocusing to infinity.

Parameters:

Name Type Description Default
fov float

Incident angle in degrees. Defaults to 0.0.

0.0
depth float

Depth (z position) to which the sampled rays are propagated. Defaults to 0.0.

0.0
num_rays int

Number of rays. Defaults to 7.

7
wvln float

Ray wavelength. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
plane str

Sampling plane: "meridional" (y-z plane) or "sagittal" (x-z plane). Defaults to "meridional".

'meridional'
entrance_pupil bool

If True, sample across the entrance pupil; otherwise sample across the first surface aperture. Defaults to True.

True

Returns:

Name Type Description
ray Ray object

Ray object with shape [num_rays, 3].

Source code in deeplens/optics/geolens_pkg/vis.py
@torch.no_grad()
def sample_parallel_2D(
    self,
    fov=0.0,
    num_rays=7,
    wvln=DEFAULT_WAVE,
    plane="meridional",
    entrance_pupil=True,
    depth=0.0,
):
    """Sample parallel rays (2D) in object space.

    Used for (1) drawing the lens layout and (2) 2D geometric-optics calculations,
    for example, refocusing to infinity.

    Args:
        fov (float, optional): Incident angle in degrees. Defaults to 0.0.
        depth (float, optional): Depth (z position) to which the sampled rays are propagated. Defaults to 0.0.
        num_rays (int, optional): Number of rays. Defaults to 7.
        wvln (float, optional): Ray wavelength. Defaults to DEFAULT_WAVE.
        plane (str, optional): Sampling plane, "meridional" (y-z plane) or "sagittal" (x-z plane). Defaults to "meridional".
        entrance_pupil (bool, optional): If True, sample across the entrance pupil; otherwise across the first surface aperture. Defaults to True.

    Returns:
        ray (Ray object): Ray object with shape [num_rays, 3].
    """
    # Sample points on the pupil
    if entrance_pupil:
        pupilz, pupilr = self.calc_entrance_pupil()
    else:
        pupilz, pupilr = 0, self.surfaces[0].r

    # Sample ray origins, shape [num_rays, 3]
    if plane == "sagittal":
        ray_o = torch.stack(
            (
                torch.linspace(-pupilr, pupilr, num_rays) * 0.99,
                torch.full((num_rays,), 0),
                torch.full((num_rays,), pupilz),
            ),
            axis=-1,
        )
    elif plane == "meridional":
        ray_o = torch.stack(
            (
                torch.full((num_rays,), 0),
                torch.linspace(-pupilr, pupilr, num_rays) * 0.99,
                torch.full((num_rays,), pupilz),
            ),
            axis=-1,
        )
    else:
        raise ValueError(f"Invalid plane: {plane}")

    # Sample ray directions, shape [num_rays, 3]
    if plane == "sagittal":
        ray_d = torch.stack(
            (
                torch.full((num_rays,), float(np.sin(np.deg2rad(fov)))),
                torch.zeros((num_rays,)),
                torch.full((num_rays,), float(np.cos(np.deg2rad(fov)))),
            ),
            axis=-1,
        )
    elif plane == "meridional":
        ray_d = torch.stack(
            (
                torch.zeros((num_rays,)),
                torch.full((num_rays,), float(np.sin(np.deg2rad(fov)))),
                torch.full((num_rays,), float(np.cos(np.deg2rad(fov)))),
            ),
            axis=-1,
        )
    else:
        raise ValueError(f"Invalid plane: {plane}")

    # Form rays and propagate to the target depth
    rays = Ray(ray_o, ray_d, wvln, device=self.device)
    rays.prop_to(depth)
    return rays
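The fan geometry above can be reproduced outside DeepLens to show what the sampled rays look like. This sketch uses NumPy arrays instead of Ray objects. parallel_rays_2d and prop_to are hypothetical helpers. prop_to assumes straight-line propagation to a z plane, which is taken to match what Ray.prop_to does in this method.

```python
import numpy as np


def parallel_rays_2d(pupil_z, pupil_r, fov_deg=0.0, num_rays=7, plane="meridional"):
    """Origins and unit directions of a collimated 2D ray fan (sketch)."""
    # Spread origins over 99% of the pupil diameter, as in the source
    coord = np.linspace(-pupil_r, pupil_r, num_rays) * 0.99
    zeros = np.zeros(num_rays)
    z0 = np.full(num_rays, float(pupil_z))
    sin_f = np.full(num_rays, np.sin(np.deg2rad(fov_deg)))
    cos_f = np.full(num_rays, np.cos(np.deg2rad(fov_deg)))

    if plane == "meridional":  # fan and tilt in the y-z plane
        o = np.stack([zeros, coord, z0], axis=-1)
        d = np.stack([zeros, sin_f, cos_f], axis=-1)
    elif plane == "sagittal":  # fan and tilt in the x-z plane
        o = np.stack([coord, zeros, z0], axis=-1)
        d = np.stack([sin_f, zeros, cos_f], axis=-1)
    else:
        raise ValueError(f"Invalid plane: {plane}")
    return o, d


def prop_to(o, d, z):
    """Move each ray along its direction to the plane at axial position z."""
    t = (z - o[:, 2]) / d[:, 2]
    return o + t[:, None] * d


# Sagittal fan at 10 degrees, moved back to z = -1 (as draw_layout does)
o, d = parallel_rays_2d(pupil_z=0.0, pupil_r=1.0, fov_deg=10.0, plane="sagittal")
o_obj = prop_to(o, d, -1.0)
```

Because the directions are built from sin and cos of the same angle, they are already unit vectors. Every ray in the fan is therefore parallel.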

sample_point_source_2D

sample_point_source_2D(fov=0.0, depth=DEPTH, num_rays=7, wvln=DEFAULT_WAVE, entrance_pupil=True)

Sample point source rays (2D) in object space.

Used for drawing the lens layout.

Parameters:

Name Type Description Default
fov float

Incident angle in degrees. Defaults to 0.0.

0.0
depth float

Depth (z position) of the point source. Defaults to DEPTH.

DEPTH
num_rays int

Number of rays. Defaults to 7.

7
wvln float

Ray wavelength. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
entrance_pupil bool

If True, aim rays at the entrance pupil; otherwise aim them at the first surface aperture. Defaults to True.

True

Returns:

Name Type Description
ray Ray object

Ray object with shape [num_rays, 3].

Source code in deeplens/optics/geolens_pkg/vis.py
@torch.no_grad()
def sample_point_source_2D(
    self,
    fov=0.0,
    depth=DEPTH,
    num_rays=7,
    wvln=DEFAULT_WAVE,
    entrance_pupil=True,
):
    """Sample point source rays (2D) in object space.

    Used for drawing the lens layout.

    Args:
        fov (float, optional): Incident angle in degrees. Defaults to 0.0.
        depth (float, optional): Depth (z position) of the point source. Defaults to DEPTH.
        num_rays (int, optional): Number of rays. Defaults to 7.
        wvln (float, optional): Ray wavelength. Defaults to DEFAULT_WAVE.
        entrance_pupil (bool, optional): If True, aim rays at the entrance pupil; otherwise at the first surface aperture. Defaults to True.

    Returns:
        ray (Ray object): Ray object with shape [num_rays, 3].
    """
    # Sample point on the object plane
    ray_o = torch.tensor([depth * float(np.tan(np.deg2rad(fov))), 0.0, depth])
    ray_o = ray_o.unsqueeze(0).repeat(num_rays, 1)

    # Sample points (second point) on the pupil
    if entrance_pupil:
        pupilz, pupilr = self.calc_entrance_pupil()
    else:
        pupilz, pupilr = 0, self.surfaces[0].r

    x2 = torch.linspace(-pupilr, pupilr, num_rays) * 0.99
    y2 = torch.zeros_like(x2)
    z2 = torch.full_like(x2, pupilz)
    ray_o2 = torch.stack((x2, y2, z2), axis=1)

    # Form the rays
    ray_d = ray_o2 - ray_o
    ray = Ray(ray_o, ray_d, wvln, device=self.device)

    # Propagate rays to the sampling depth
    ray.prop_to(depth)
    return ray
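The same construction can be sketched in NumPy: one object point in the x-z plane, and one ray from it to each of num_rays points spread across the pupil. point_source_rays_2d is a hypothetical helper. The default depth=-100.0 is an arbitrary stand-in for the DEPTH constant, whose value is not given here.

```python
import numpy as np


def point_source_rays_2d(pupil_z, pupil_r, fov_deg=0.0, depth=-100.0, num_rays=7):
    """Rays from one off-axis object point to a line of pupil points (sketch)."""
    # Object point lies in the x-z plane at height depth * tan(fov)
    src = np.array([depth * np.tan(np.deg2rad(fov_deg)), 0.0, depth])
    o = np.tile(src, (num_rays, 1))

    # Second point of each ray: samples across 99% of the pupil (x axis)
    x2 = np.linspace(-pupil_r, pupil_r, num_rays) * 0.99
    pupil_pts = np.stack([x2, np.zeros(num_rays), np.full(num_rays, pupil_z)], axis=1)

    # Direction = pupil point - source point, normalised to unit length
    d = pupil_pts - o
    d /= np.linalg.norm(d, axis=1, keepdims=True)
    return o, d, pupil_pts


o, d, pupil_pts = point_source_rays_2d(pupil_z=0.0, pupil_r=1.0, fov_deg=5.0)
```

The origins already lie at z = depth, so the final ray.prop_to(depth) in the source should leave them where they are.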

draw_layout

draw_layout(filename, depth=float('inf'), zmx_format=True, multi_plot=False, lens_title=None, show=False)

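Plot the 2D lens layout with ray fans traced at three field angles.

Parameters:

Name Type Description Default
filename

Output PNG filename (used when show is False).

required
depth

Object depth for ray tracing. With float('inf'), parallel (collimated) rays are traced; otherwise rays start from a point source at this depth.

float('inf')
zmx_format

Whether to draw element edges in Zemax (ZMX) layout style.

True
multi_plot

If True, draw one subplot per RGB wavelength; otherwise draw a single plot with one colour per field angle.

False
lens_title

Title for the plot. If None, a title is generated from the focal length, F-number, FoV, and sensor size.

None
show

If True, display the figure instead of saving it.

False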
Source code in deeplens/optics/geolens_pkg/vis.py
def draw_layout(
    self,
    filename,
    depth=float("inf"),
    zmx_format=True,
    multi_plot=False,
    lens_title=None,
    show=False,
):
    """Plot 2D lens layout with ray tracing.

    Args:
        filename: Output PNG filename (used when show is False).
        depth: Object depth for ray tracing; float("inf") traces parallel rays.
        zmx_format: Whether to draw element edges in Zemax (ZMX) layout style.
        multi_plot: If True, draw one subplot per RGB wavelength.
        lens_title: Title for the plot; generated automatically if None.
        show: If True, display the figure instead of saving it.
    """
    num_rays = 11
    num_views = 3

    # Lens title
    if lens_title is None:
        eff_foclen = int(self.foclen)
        eq_foclen = int(self.eqfl)
        fov_deg = round(2 * self.rfov * 180 / torch.pi, 1)
        sensor_r = round(self.r_sensor, 1)
        sensor_w, sensor_h = self.sensor_size
        sensor_w = round(sensor_w, 1)
        sensor_h = round(sensor_h, 1)

        if self.aper_idx is not None:
            _, pupil_r = self.calc_entrance_pupil()
            fnum = round(eff_foclen / pupil_r / 2, 1)
            lens_title = f"FocLen{eff_foclen}mm - F/{fnum} - FoV{fov_deg}(Equivalent {eq_foclen}mm) - Sensor Diagonal {2 * sensor_r}mm"
        else:
            lens_title = f"FocLen{eff_foclen}mm - FoV{fov_deg}(Equivalent {eq_foclen}mm) - Sensor Diagonal {2 * sensor_r}mm"

    # Draw lens layout
    colors_list = ["#CC0000", "#006600", "#0066CC"]
    rfov_deg = float(np.rad2deg(self.rfov))
    fov_ls = np.linspace(0, rfov_deg * 0.99, num=num_views)

    if not multi_plot:
        ax, fig = self.draw_lens_2d(zmx_format=zmx_format)
        fig.suptitle(lens_title, fontsize=10)
        for i, fov in enumerate(fov_ls):
            # Sample rays, shape (num_rays, 3)
            if depth == float("inf"):
                ray = self.sample_parallel_2D(
                    fov=fov,
                    wvln=WAVE_RGB[2 - i],
                    num_rays=num_rays,
                    depth=-1.0,
                    plane="sagittal",
                )
            else:
                ray = self.sample_point_source_2D(
                    fov=fov,
                    depth=depth,
                    num_rays=num_rays,
                    wvln=WAVE_RGB[2 - i],
                )
                ray.prop_to(-1.0)

            # Trace rays to sensor and plot ray paths
            _, ray_o_record = self.trace2sensor(ray=ray, record=True)
            ax, fig = self.draw_ray_2d(
                ray_o_record, ax=ax, fig=fig, color=colors_list[i]
            )

        ax.axis("off")

    else:
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        fig.suptitle(lens_title, fontsize=10)
        for i, wvln in enumerate(WAVE_RGB):
            ax = axs[i]
            ax, fig = self.draw_lens_2d(ax=ax, fig=fig, zmx_format=zmx_format)
            for fov in fov_ls:
                # Sample rays, shape (num_rays, 3)
                if depth == float("inf"):
                    ray = self.sample_parallel_2D(
                        fov=fov,
                        num_rays=num_rays,
                        wvln=wvln,
                        plane="sagittal",
                    )
                else:
                    ray = self.sample_point_source_2D(
                        fov=fov,
                        depth=depth,
                        num_rays=num_rays,
                        wvln=wvln,
                    )

                # Trace rays to sensor and plot ray paths
                ray_out, ray_o_record = self.trace2sensor(ray=ray, record=True)
                ax, fig = self.draw_ray_2d(
                    ray_o_record, ax=ax, fig=fig, color=colors_list[i]
                )
                ax.axis("off")

    if show:
        fig.show()
    else:
        fig.savefig(filename, format="png", dpi=300)
        plt.close()

    # Return handles so callers such as create_barrier() can draw on top
    return ax, fig
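The automatic title reports the F-number as focal length divided by entrance-pupil diameter (twice the pupil radius). The three ray fans are spread evenly from on-axis to 99% of the half field of view. The stdlib-only sketch below shows both computations. layout_title and field_angles are hypothetical helpers, and their inputs are in mm and radians as in the source.

```python
import math


def layout_title(foclen, eqfl, rfov_rad, r_sensor, pupil_r=None):
    """Build a draw_layout-style title string (sketch)."""
    eff_foclen = int(foclen)
    fov_deg = round(2 * math.degrees(rfov_rad), 1)
    sensor_r = round(r_sensor, 1)
    parts = [f"FocLen{eff_foclen}mm"]
    if pupil_r is not None:
        # F-number = focal length / entrance-pupil diameter
        parts.append(f"F/{round(eff_foclen / (2 * pupil_r), 1)}")
    parts.append(f"FoV{fov_deg}(Equivalent {int(eqfl)}mm)")
    parts.append(f"Sensor Diagonal {2 * sensor_r}mm")
    return " - ".join(parts)


def field_angles(rfov_rad, num_views=3):
    """Evenly spaced field angles (deg) from on-axis to 99% of the half-FoV."""
    rfov_deg = math.degrees(rfov_rad)
    step = rfov_deg * 0.99 / (num_views - 1)
    return [i * step for i in range(num_views)]


print(layout_title(50.0, 50.0, math.radians(20.0), 21.6, pupil_r=12.5))
```

Stopping at 99% of the half field keeps the outermost fan just inside the field edge, so it is not clipped.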

draw_lens_2d

draw_lens_2d(ax=None, fig=None, color='k', linestyle='-', zmx_format=False, fix_bound=False)

Draw lens cross-section layout in a 2D plot.

Renders each surface profile, connects lens elements with edge lines, and draws the sensor plane.

Parameters:

Name Type Description Default
ax Axes

Existing axes to draw on. If None, creates a new figure. Defaults to None.

None
fig Figure

Existing figure. Defaults to None.

None
color str

Line colour for lens outlines. Defaults to 'k'.

'k'
linestyle str

Line style. Defaults to '-'.

'-'
zmx_format bool

If True, draw stepped edge connections matching Zemax layout style. Defaults to False.

False
fix_bound bool

If True, use fixed axis limits [-1,7]x[-4,4]. Defaults to False.

False

Returns:

Name Type Description
tuple

(ax, fig) matplotlib axes and figure objects.

Source code in deeplens/optics/geolens_pkg/vis.py
def draw_lens_2d(
    self,
    ax=None,
    fig=None,
    color="k",
    linestyle="-",
    zmx_format=False,
    fix_bound=False,
):
    """Draw lens cross-section layout in a 2D plot.

    Renders each surface profile, connects lens elements with edge lines,
    and draws the sensor plane.

    Args:
        ax (matplotlib.axes.Axes, optional): Existing axes to draw on. If None,
            creates a new figure. Defaults to None.
        fig (matplotlib.figure.Figure, optional): Existing figure. Defaults to None.
        color (str, optional): Line colour for lens outlines. Defaults to 'k'.
        linestyle (str, optional): Line style. Defaults to '-'.
        zmx_format (bool, optional): If True, draw stepped edge connections
            matching Zemax layout style. Defaults to False.
        fix_bound (bool, optional): If True, use fixed axis limits [-1,7]x[-4,4].
            Defaults to False.

    Returns:
        tuple: (ax, fig) matplotlib axes and figure objects.
    """
    # If no ax is given, generate a new one.
    if ax is None and fig is None:
        # fig, ax = plt.subplots(figsize=(6, 6))
        fig, ax = plt.subplots()

    # Draw lens surfaces
    for i, s in enumerate(self.surfaces):
        s.draw_widget(ax)

    # Connect two surfaces
    for i in range(len(self.surfaces)):
        if self.surfaces[i].mat2.n > 1.1:
            s_prev = self.surfaces[i]
            s = self.surfaces[i + 1]

            r_prev = float(s_prev.draw_r())
            r = float(s.draw_r())
            sag_prev = s_prev.surface_with_offset(
                r_prev, 0.0, valid_check=False
            ).item()
            sag = s.surface_with_offset(
                r, 0.0, valid_check=False
            ).item()

            if r_prev >= r:
                # Front surface wider: go axially forward at r_prev, then step radially inward
                z = np.array([sag_prev, sag, sag])
                x = np.array([r_prev, r_prev, r])
            else:
                # Rear surface wider: step radially outward at z_prev, then go axially forward
                z = np.array([sag_prev, sag_prev, sag])
                x = np.array([r_prev, r, r])

            if not zmx_format:
                # In non-zmx mode use a direct diagonal between the two outer edges
                z = np.array([z[0], z[-1]])
                x = np.array([x[0], x[-1]])

            ax.plot(z, -x, color, linewidth=0.75)
            ax.plot(z, x, color, linewidth=0.75)
            s_prev = s

    # Draw sensor
    ax.plot(
        [self.d_sensor.item(), self.d_sensor.item()],
        [-self.r_sensor, self.r_sensor],
        color,
    )

    # Set figure size
    if fix_bound:
        ax.set_aspect("equal")
        ax.set_xlim(-1, 7)
        ax.set_ylim(-4, 4)
    else:
        ax.set_aspect("equal", adjustable="datalim", anchor="C")
        ax.minorticks_on()
        ax.set_xlim(-0.5, 7.5)
        ax.set_ylim(-4, 4)
        ax.autoscale()

    return ax, fig
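The edge-connection rule in draw_lens_2d can be isolated as a pure function. This sketch (edge_connection is a hypothetical name) returns the polyline in (z, r) coordinates. The method then plots that polyline twice, at +r and -r, to close the element symmetrically about the axis.

```python
def edge_connection(sag_prev, r_prev, sag, r, zmx_format=False):
    """Return (z, r) polyline joining two surfaces' outer edges (sketch)."""
    if r_prev >= r:
        # Front surface wider: run axially at r_prev, then step inward to r
        z, x = [sag_prev, sag, sag], [r_prev, r_prev, r]
    else:
        # Rear surface wider: step outward at sag_prev, then run axially
        z, x = [sag_prev, sag_prev, sag], [r_prev, r, r]
    if not zmx_format:
        # Direct diagonal between the two outer edge points
        z, x = [z[0], z[-1]], [x[0], x[-1]]
    return z, x


print(edge_connection(0.5, 5.0, 2.0, 4.0, zmx_format=True))
```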

draw_ray_2d

draw_ray_2d(ray_o_record, ax, fig, color='b')

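Plot ray paths in the z-x plane.

Parameters:

Name Type Description Default
ray_o_record list

Ray positions recorded along the traced path (one tensor per step).

required
ax Axes

Matplotlib axes to draw on.

required
fig Figure

Matplotlib figure.

required
color str

Line colour for the ray paths. Defaults to 'b'.

'b'

Returns:

Name Type Description
tuple

(ax, fig) matplotlib axes and figure objects.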
Source code in deeplens/optics/geolens_pkg/vis.py
def draw_ray_2d(self, ray_o_record, ax, fig, color="b"):
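    """Plot ray paths in the z-x plane.

    Args:
        ray_o_record (list): Ray positions recorded along the traced path (one tensor per step).
        ax (matplotlib.axes.Axes): Matplotlib axes to draw on.
        fig (matplotlib.figure.Figure): Matplotlib figure.
        color (str, optional): Line colour for the ray paths. Defaults to 'b'.

    Returns:
        tuple: (ax, fig) matplotlib axes and figure objects.
    """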
    # shape (num_views, num_rays, num_path, 3)
    ray_o_record = torch.stack(ray_o_record, dim=-2).cpu().numpy()
    if ray_o_record.ndim == 3:
        ray_o_record = ray_o_record[None, ...]

    for idx_view in range(ray_o_record.shape[0]):
        for idx_ray in range(ray_o_record.shape[1]):
            ax.plot(
                ray_o_record[idx_view, idx_ray, :, 2],
                ray_o_record[idx_view, idx_ray, :, 0],
                color,
                linewidth=0.8,
            )

            # ax.scatter(
            #     ray_o_record[idx_view, idx_ray, :, 2],
            #     ray_o_record[idx_view, idx_ray, :, 0],
            #     "b",
            #     marker="x",
            # )

    return ax, fig
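draw_ray_2d stacks the recorded positions so that the path index becomes the second-to-last axis. It then plots z (index 2) against x (index 0). Below is a NumPy sketch of that reshaping. stack_ray_record and ray_polylines are hypothetical helpers, and plain arrays stand in for the recorded tensors.

```python
import numpy as np


def stack_ray_record(ray_o_record):
    """Stack per-step ray positions into shape (views, rays, path, 3)."""
    # Each entry has shape (num_rays, 3) or (num_views, num_rays, 3);
    # stacking on axis -2 makes "path" the second-to-last axis.
    arr = np.stack(ray_o_record, axis=-2)
    if arr.ndim == 3:
        # Single view: add a leading view axis
        arr = arr[None, ...]
    return arr


def ray_polylines(arr):
    """Yield (z, x) coordinate arrays per ray, as draw_ray_2d plots them."""
    for view in arr:
        for ray in view:
            yield ray[:, 2], ray[:, 0]


# Four recorded steps for five rays; step k puts every coordinate at k
record = [np.full((5, 3), float(k)) for k in range(4)]
arr = stack_ray_record(record)
lines = list(ray_polylines(arr))
```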

create_barrier

create_barrier(filename, barrier_thickness=1.0, ring_height=0.5, ring_size=1.0)

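Create barrier (lens-housing) segments for the lens system and draw them on the 2D lens layout.

Parameters:

Name Type Description Default
filename

Path to save the figure (PNG).

required
barrier_thickness

Thickness of the barrier (stored with each segment; not drawn).

1.0
ring_height

Height of the annular ring (currently unused; ring creation is not implemented).

0.5
ring_size

Size of the annular ring (currently unused).

1.0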
Source code in deeplens/optics/geolens_pkg/vis.py
def create_barrier(
    self, filename, barrier_thickness=1.0, ring_height=0.5, ring_size=1.0
):
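    """Create barrier (lens-housing) segments and draw them on the 2D lens layout.

    Args:
        filename: Path to save the figure (PNG).
        barrier_thickness: Thickness of the barrier (stored with each segment; not drawn).
        ring_height: Height of the annular ring (currently unused).
        ring_size: Size of the annular ring (currently unused).
    """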
    barriers = []
    rings = []

    # Create barriers
    barrier_z = 0.0
    barrier_r = 0.0
    barrier_length = 0.0
    for i in range(len(self.surfaces)):
        barrier_r = max(self.surfaces[i].r, barrier_r)

        if self.surfaces[i].mat2.get_name() != "air":
            # Update the barrier radius
            # barrier_r = max(geolens.surfaces[i].r, barrier_r)
            pass
        else:
            # Extend the barrier till middle of the air space to the next surface
            max_curr_surf_d = self.surfaces[i].d.item() + max(
                self.surfaces[i].surface_sag(0.0, self.surfaces[i].r), 0.0
            )
            if i < len(self.surfaces) - 1:
                min_next_surf_d = self.surfaces[i + 1].d.item() + min(
                    self.surfaces[i + 1].surface_sag(0.0, self.surfaces[i + 1].r),
                    0.0,
                )
                extra_space = (min_next_surf_d - max_curr_surf_d) / 2
            else:
                min_next_surf_d = self.d_sensor.item()
                extra_space = min_next_surf_d - max_curr_surf_d

            barrier_length = max_curr_surf_d + extra_space - barrier_z

            # Create a barrier
            barrier = {
                "pos_z": barrier_z,
                "pos_r": barrier_r,
                "length": barrier_length,
                "thickness": barrier_thickness,
            }
            barriers.append(barrier)

            # Reset the barrier parameters
            barrier_z = barrier_length + barrier_z
            barrier_r = 0.0
            barrier_length = 0.0

    # # Create rings
    # for i in range(len(geolens.surfaces)):
    #     if geolens.surfaces[i].mat2.get_name() != "air":
    #         ring = {
    #             "pos_z": geolens.surfaces[i].d.item(),

    # Plot lens layout
    ax, fig = self.draw_layout(filename)

    # Plot barrier
    barrier_z_ls = []
    barrier_r_ls = []
    for b in barriers:
        barrier_z_ls.append(b["pos_z"])
        barrier_z_ls.append(b["pos_z"] + b["length"])
        barrier_r_ls.append(b["pos_r"])
        barrier_r_ls.append(b["pos_r"])
    ax.plot(barrier_z_ls, barrier_r_ls, "green", linewidth=1.0)
    ax.plot(barrier_z_ls, [-i for i in barrier_r_ls], "green", linewidth=1.0)

    # Plot rings

    fig.savefig(filename, format="png", dpi=300)
    plt.close()

    pass
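The barrier bookkeeping in create_barrier can be traced with plain numbers. In this sketch (barrier_segments is hypothetical), each surface is a dict holding four fields: its vertex position d, its aperture radius r, its sag at the edge (sag_edge), and whether air follows it. A segment closes halfway into each air gap, and the last segment runs to the sensor. The dict layout is an assumption standing in for the Surface objects used in the source.

```python
def barrier_segments(surfaces, d_sensor):
    """Compute (pos_z, pos_r, length) barrier segments (sketch)."""
    segments = []
    z0, r_max = 0.0, 0.0
    for i, s in enumerate(surfaces):
        r_max = max(s["r"], r_max)
        if not s["air_after"]:
            continue  # still inside a glass element; keep growing r_max
        # Rear-most axial extent of the current surface at its edge
        z_curr = s["d"] + max(s["sag_edge"], 0.0)
        if i < len(surfaces) - 1:
            nxt = surfaces[i + 1]
            z_next = nxt["d"] + min(nxt["sag_edge"], 0.0)
            extra = (z_next - z_curr) / 2  # stop halfway through the air gap
        else:
            extra = d_sensor - z_curr  # last barrier runs to the sensor
        length = z_curr + extra - z0
        segments.append((z0, r_max, length))
        z0, r_max = z0 + length, 0.0  # next segment starts where this one ends
    return segments


# Hypothetical two-singlet lens (all values in mm)
surfs = [
    {"d": 0.0, "r": 5.0, "sag_edge": 0.5, "air_after": False},
    {"d": 2.0, "r": 4.0, "sag_edge": -0.3, "air_after": True},
    {"d": 4.0, "r": 3.0, "sag_edge": -0.2, "air_after": False},
    {"d": 5.0, "r": 3.5, "sag_edge": 0.1, "air_after": True},
]
segs = barrier_segments(surfs, d_sensor=8.0)
```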

deeplens.optics.geolens_pkg.eval_tolerance.GeoLensTolerance

Mixin providing tolerance analysis for GeoLens.

Implements two complementary approaches:

  • Sensitivity analysis – first-order gradient-based estimation of how each manufacturing error affects optical performance.
  • Monte-Carlo analysis – statistical sampling of random manufacturing errors to predict yield and worst-case performance.

This class is not instantiated directly; it is mixed into deeplens.optics.geolens.GeoLens.

References

Jun Dai et al., "Tolerance-Aware Deep Optics," arXiv:2502.04719, 2025.

init_tolerance

init_tolerance(tolerance_params=None)

Initialize manufacturing tolerance parameters for all surfaces.

Sets up tolerance ranges (e.g., curvature, thickness, decenter, tilt) on each surface. These are used by sample_tolerance() to simulate random manufacturing errors.

Parameters:

Name Type Description Default
tolerance_params dict

Custom tolerance specifications. If None, each surface uses its own defaults. Defaults to None.

None
Source code in deeplens/optics/geolens_pkg/eval_tolerance.py
def init_tolerance(self, tolerance_params=None):
    """Initialize manufacturing tolerance parameters for all surfaces.

    Sets up tolerance ranges (e.g., curvature, thickness, decenter, tilt)
    on each surface. These are used by ``sample_tolerance()`` to simulate
    random manufacturing errors.

    Args:
        tolerance_params (dict, optional): Custom tolerance specifications.
            If None, each surface uses its own defaults. Defaults to None.
    """
    if tolerance_params is None:
        tolerance_params = {}

    for i in range(len(self.surfaces)):
        self.surfaces[i].init_tolerance(tolerance_params=tolerance_params)

sample_tolerance

sample_tolerance()

Apply random manufacturing errors to all surfaces.

Randomly perturbs each surface according to its tolerance ranges and then refocuses the lens to compensate for the focus shift.

Source code in deeplens/optics/geolens_pkg/eval_tolerance.py
@torch.no_grad()
def sample_tolerance(self):
    """Apply random manufacturing errors to all surfaces.

    Randomly perturbs each surface according to its tolerance ranges and
    then refocuses the lens to compensate for the focus shift.
    """
    # Randomly perturb all surfaces
    for i in range(len(self.surfaces)):
        self.surfaces[i].sample_tolerance()

    # Refocus the lens
    self.refocus()

zero_tolerance

zero_tolerance()

Reset all manufacturing errors to zero (nominal lens state).

Clears the perturbations on every surface and refocuses the lens.

Source code in deeplens/optics/geolens_pkg/eval_tolerance.py
@torch.no_grad()
def zero_tolerance(self):
    """Reset all manufacturing errors to zero (nominal lens state).

    Clears the perturbations on every surface and refocuses the lens.
    """
    for i in range(len(self.surfaces)):
        self.surfaces[i].zero_tolerance()

    # Refocus the lens
    self.refocus()
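init_tolerance, sample_tolerance, and zero_tolerance all delegate to per-surface methods. The sketch below shows that delegation pattern with a hypothetical ToySurface that holds a single toleranced thickness. The uniform error distribution is an assumption. The refocus step that GeoLens performs afterwards is omitted.

```python
import random


class ToySurface:
    """Hypothetical surface with one toleranced parameter (sketch only)."""

    def __init__(self, d, d_tol=0.05):
        self.d_nominal = d
        self.d_tol = d_tol  # +/- manufacturing range in mm (assumed)
        self.d_error = 0.0

    @property
    def d(self):
        return self.d_nominal + self.d_error

    def sample_tolerance(self, rng):
        # Uniform draw inside the tolerance band (assumed distribution)
        self.d_error = rng.uniform(-self.d_tol, self.d_tol)

    def zero_tolerance(self):
        self.d_error = 0.0


def perturb_all(surfaces, rng):
    """Apply a random error to every surface, like sample_tolerance()."""
    for s in surfaces:
        s.sample_tolerance(rng)


def reset_all(surfaces):
    """Restore the nominal state, like zero_tolerance()."""
    for s in surfaces:
        s.zero_tolerance()


rng = random.Random(0)
surfs = [ToySurface(0.0), ToySurface(2.0)]
perturb_all(surfs, rng)
```

Storing the error separately from the nominal value makes the reset exact. This is why zero_tolerance can restore the nominal lens without re-reading the design.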

tolerancing_sensitivity

tolerancing_sensitivity(tolerance_params=None)

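Estimate tolerance sensitivity with a first-order (gradient-based) analysis. After the tolerances are initialized, the nominal RMS loss is backpropagated and each surface reports its sensitivity scores. The summed scores are then combined with the nominal loss as a root-sum-square (RSS) estimate. The returned dict contains the per-surface entries plus loss_nominal and loss_rss.

Parameters:

Name Type Description Default
tolerance_params dict

Custom tolerance specifications passed to init_tolerance(). Defaults to None.

None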

References

[1] Page 10 of: https://wp.optics.arizona.edu/optomech/wp-content/uploads/sites/53/2016/08/8-Tolerancing-1.pdf
[2] Fast sensitivity control method with differentiable optics. Optics Express, 2025.
[3] Optical Design Tolerancing. CODE V.

Source code in deeplens/optics/geolens_pkg/eval_tolerance.py
def tolerancing_sensitivity(self, tolerance_params=None):
    """Use sensitivity analysis (1st order gradient) to compute the tolerance score.

    References:
        [1] Page 10 from: https://wp.optics.arizona.edu/optomech/wp-content/uploads/sites/53/2016/08/8-Tolerancing-1.pdf
        [2] Fast sensitivity control method with differentiable optics. Optics Express 2025.
        [3] Optical Design Tolerancing. CODE V.
    """
    # Initialize tolerance
    self.init_tolerance(tolerance_params=tolerance_params)

    # AutoDiff to compute the gradient/sensitivity
    self.get_optimizer_params()
    loss = self.loss_rms()
    loss.backward()

    # Calculate sensitivity results
    sensitivity_results = {}
    for i in range(len(self.surfaces)):
        sensitivity_results.update(self.surfaces[i].sensitivity_score())

    # Toleranced RSS (Root Sum Square) loss
    tolerancing_score = sum(
        v for k, v in sensitivity_results.items() if k.endswith("_score")
    )
    loss_rss = torch.sqrt(loss**2 + tolerancing_score).item()
    sensitivity_results["loss_nominal"] = round(loss.item(), 6)
    sensitivity_results["loss_rss"] = round(loss_rss, 6)
    return sensitivity_results
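The RSS combination at the end of tolerancing_sensitivity can be illustrated on a toy loss. GeoLens obtains gradients with loss.backward(); this sketch substitutes central finite differences so that it runs without PyTorch. It also assumes each per-parameter score is the squared first-order change, (dL/dp × tolerance)². The exact form returned by sensitivity_score() is not shown in this section.

```python
import math


def finite_diff_grad(f, params, h=1e-6):
    """Central-difference gradient of a scalar function f at params."""
    grad = []
    for i in range(len(params)):
        up = list(params)
        down = list(params)
        up[i] += h
        down[i] -= h
        grad.append((f(up) - f(down)) / (2 * h))
    return grad


def rss_loss(f, params, tolerances):
    """First-order RSS estimate: sqrt(L0^2 + sum_i (dL/dp_i * tol_i)^2)."""
    loss0 = f(params)
    grad = finite_diff_grad(f, params)
    scores = [(g * t) ** 2 for g, t in zip(grad, tolerances)]
    return loss0, math.sqrt(loss0 ** 2 + sum(scores))


# Toy linear loss with gradient (2, 3) and tolerances (0.1, 0.2)
loss_nominal, loss_rss = rss_loss(lambda p: 1.0 + 2 * p[0] + 3 * p[1], [0.0, 0.0], [0.1, 0.2])
```

Here the scores are 0.04 and 0.36, so loss_rss = sqrt(1 + 0.4) ≈ 1.183. The RSS form treats the individual errors as independent, which is the usual assumption behind first-order tolerance budgets.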

tolerancing_monte_carlo

tolerancing_monte_carlo(trials=200, spp=SPP_CALC, tolerance_params=None)

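Estimate manufacturing yield by Monte Carlo simulation. Each trial perturbs all surfaces and refocuses the sensor. It then records the MTF merit (mean of tangential and sagittal MTF at quarter-Nyquist), averaged over three field positions.

The default trials=200 is tuned for ~3 min runtime on GPU. For production-quality yield estimates (especially 95th/99th percentile tails), increase to 1000+.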

Parameters:

Name Type Description Default
trials int

Number of Monte Carlo trials. Defaults to 200.

200
spp int

Samples per pixel for PSF calculation. Lower values run faster at the cost of noisier MTF estimates. Defaults to SPP_CALC (1024), which is ~16x faster than the full SPP_PSF.

SPP_CALC
tolerance_params dict

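Custom tolerance specifications passed to init_tolerance(). If None, each surface uses its own defaults. Defaults to None.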

None

Returns:

Name Type Description
dict

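Monte Carlo tolerance analysis results: baseline merit, mean and standard deviation of the merit, and yield/percentile tables. A histogram/CDF plot is also saved to Monte_Carlo_Tolerance.png in the working directory.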

References

[1] https://optics.ansys.com/hc/en-us/articles/43071088477587-How-to-analyze-your-tolerance-results
[2] Optical Design Tolerancing. CODE V.

Source code in deeplens/optics/geolens_pkg/eval_tolerance.py
@torch.no_grad()
def tolerancing_monte_carlo(self, trials=200, spp=SPP_CALC, tolerance_params=None):
    """Use Monte Carlo simulation to compute the tolerance.

    The default ``trials=200`` is tuned for ~3 min runtime on GPU.
    For production-quality yield estimates (especially 95th/99th
    percentile tails), increase to 1000+.

    Args:
        trials (int): Number of Monte Carlo trials. Defaults to 200.
        spp (int): Samples per pixel for PSF calculation. Lower values
            run faster at the cost of noisier MTF estimates. Defaults to
            SPP_CALC (1024), which is ~16x faster than the full SPP_PSF.
        tolerance_params (dict): Tolerance parameters.

    Returns:
        dict: Monte Carlo tolerance analysis results.

    References:
        [1] https://optics.ansys.com/hc/en-us/articles/43071088477587-How-to-analyze-your-tolerance-results
        [2] Optical Design Tolerancing. CODE V.
    """

    def merit_func(lens, fov=0.0, depth=DEPTH):
        """Evaluate MTF merit at a single field point."""
        try:
            point = [0, -fov / lens.rfov, depth]
            psf = lens.psf(points=point, spp=spp, recenter=True)
            freq, mtf_tan, mtf_sag = lens.psf2mtf(psf, pixel_size=lens.pixel_size)

            # Evaluate MTF at quarter-Nyquist frequency
            nyquist_freq = 0.5 / lens.pixel_size
            eval_freq = 0.25 * nyquist_freq
            idx = torch.argmin(torch.abs(torch.tensor(freq) - eval_freq))
            score = (mtf_tan[idx] + mtf_sag[idx]) / 2
            return score.item()
        except RuntimeError:
            # Perturbed lens may block all rays at extreme fields
            return 0.0

    def multi_field_merit(lens, depth=DEPTH):
        """Evaluate average MTF merit across multiple field positions."""
        # On-axis, half field, full field in radians (merit_func divides by rfov)
        fov_points = [0.0, 0.5 * lens.rfov, lens.rfov]
        scores = [merit_func(lens, fov=fov, depth=depth) for fov in fov_points]
        return float(np.mean(scores))

    # Initialize tolerance
    self.init_tolerance(tolerance_params=tolerance_params)

    # Monte Carlo simulation
    merit_ls = []
    with torch.no_grad():
        for i in tqdm(range(trials)):
            # Sample a random perturbation and refocus sensor only
            # (skip full post_computation — focal length, pupil, and FoV
            # don't change meaningfully under small tolerance errors).
            for surf in self.surfaces:
                surf.sample_tolerance()
            self.d_sensor = self.calc_sensor_plane()

            # Evaluate perturbed performance across multiple field positions
            perturbed_merit = multi_field_merit(lens=self, depth=DEPTH)
            merit_ls.append(perturbed_merit)

            # Clear perturbation (no refocus needed — next iteration
            # will set sensor position after sampling).
            for surf in self.surfaces:
                surf.zero_tolerance()

    merit_ls = np.array(merit_ls)

    # Baseline merit (nominal lens)
    self.refocus()
    baseline_merit = multi_field_merit(lens=self, depth=DEPTH)

    # Results plot — histogram + CDF
    fig, ax1 = plt.subplots(figsize=(9, 5))

    # Histogram
    ax1.hist(
        merit_ls,
        bins=30,
        color="#4C72B0",
        alpha=0.6,
        edgecolor="white",
        label="Frequency",
    )
    ax1.set_xlabel("MTF Merit Score (higher is better)", fontsize=12)
    ax1.set_ylabel("Count", fontsize=12, color="#4C72B0")
    ax1.tick_params(axis="y", labelcolor="#4C72B0")

    # CDF on secondary axis
    ax2 = ax1.twinx()
    sorted_merit = np.sort(merit_ls)
    cdf = np.arange(1, len(sorted_merit) + 1) / len(sorted_merit) * 100
    ax2.plot(sorted_merit, cdf, color="#C44E52", linewidth=2, label="CDF")
    ax2.set_ylabel("Cumulative % of Lenses", fontsize=12, color="#C44E52")
    ax2.tick_params(axis="y", labelcolor="#C44E52")
    ax2.set_ylim(0, 105)

    # Baseline reference
    ax1.axvline(
        baseline_merit,
        color="green",
        linestyle="--",
        linewidth=1.5,
        label=f"Nominal = {baseline_merit:.3f}",
    )

    # Yield annotations — 90% and 50% yield lines
    p90 = float(np.percentile(merit_ls, 10))  # 90% of lenses exceed this
    p50 = float(np.percentile(merit_ls, 50))
    ax1.axvline(
        p90, color="orange", linestyle=":", linewidth=1.5,
        label=f"90% yield > {p90:.3f}",
    )
    ax1.axvline(
        p50, color="gray", linestyle=":", linewidth=1.5,
        label=f"50% yield > {p50:.3f}",
    )

    # Title and legend
    ax1.set_title(
        f"Monte Carlo Tolerance Analysis  ({trials} trials)",
        fontsize=13,
        fontweight="bold",
    )
    ax1.legend(loc="upper left", fontsize=9, framealpha=0.9)
    ax1.grid(True, alpha=0.2)
    fig.tight_layout()
    fig.savefig(
        "Monte_Carlo_Tolerance.png", dpi=300, bbox_inches="tight"
    )
    plt.close(fig)

    # Results dict
    results = {
        "method": "monte_carlo",
        "trials": trials,
        "baseline_merit": round(baseline_merit, 6),
        "merit_std": round(float(np.std(merit_ls)), 6),
        "merit_mean": round(float(np.mean(merit_ls)), 6),
        "merit_yield": {
            "99% > ": round(float(np.percentile(merit_ls, 1)), 4),
            "95% > ": round(float(np.percentile(merit_ls, 5)), 4),
            "90% > ": round(float(np.percentile(merit_ls, 10)), 4),
            "80% > ": round(float(np.percentile(merit_ls, 20)), 4),
            "70% > ": round(float(np.percentile(merit_ls, 30)), 4),
            "60% > ": round(float(np.percentile(merit_ls, 40)), 4),
            "50% > ": round(float(np.percentile(merit_ls, 50)), 4),
        },
        "merit_percentile": {
            "99% < ": round(float(np.percentile(merit_ls, 99)), 4),
            "95% < ": round(float(np.percentile(merit_ls, 95)), 4),
            "90% < ": round(float(np.percentile(merit_ls, 90)), 4),
            "80% < ": round(float(np.percentile(merit_ls, 80)), 4),
            "70% < ": round(float(np.percentile(merit_ls, 70)), 4),
            "60% < ": round(float(np.percentile(merit_ls, 60)), 4),
            "50% < ": round(float(np.percentile(merit_ls, 50)), 4),
        },
    }
    return results
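The yield table reads percentiles from the low end because a higher merit is better. The value that Y% of trials exceed is the (100 - Y)th percentile. Below is a minimal NumPy sketch of that mapping, using the key format from merit_yield. yield_table is a hypothetical helper.

```python
import numpy as np


def yield_table(merits, levels=(99, 95, 90, 80, 70, 60, 50)):
    """Map "Y% > " to the merit that Y% of trials exceed (higher is better)."""
    merits = np.asarray(merits, dtype=float)
    # Y% of samples lie above the (100 - Y)th percentile
    return {f"{y}% > ": round(float(np.percentile(merits, 100 - y)), 4) for y in levels}


# Evenly spread merits 0.00, 0.01, ..., 1.00 make the percentiles easy to read
table = yield_table(np.arange(101) / 100)
```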

tolerancing_wavefront

tolerancing_wavefront(tolerance_params=None)

Compute tolerances with the wavefront differential method (not implemented yet).

The wavefront differential method is proposed in [1], but its implementation details are not described. I (Xinge Yang) assume that symbolic differentiation is used to compute the gradient (Jacobian) of the wavefront error. Because AutoDiff can compute the same Jacobian by gradient backpropagation, I leave the implementation of this method as future work. The method body is currently empty and returns None.

Parameters:

Name Type Description Default
tolerance_params dict

Tolerance parameters.

None

Returns:

Name Type Description
dict

Wavefront tolerance analysis results (planned; the stub currently returns None).

References

[1] Optical Design Tolerancing. CODE V.

Source code in deeplens/optics/geolens_pkg/eval_tolerance.py
def tolerancing_wavefront(self, tolerance_params=None):
    """Compute tolerances with the wavefront differential method (not implemented yet).

    The wavefront differential method is proposed in [1], but its implementation details are not described. I (Xinge Yang) assume that symbolic differentiation is used to compute the gradient (Jacobian) of the wavefront error. Because AutoDiff can compute the same Jacobian by gradient backpropagation, I leave the implementation of this method as future work.

    Args:
        tolerance_params (dict): Tolerance parameters.

    Returns:
        dict: Wavefront tolerance analysis results (currently returns None).

    References:
        [1] Optical Design Tolerancing. CODE V.
    """
    pass

deeplens.optics.geolens_pkg.vis3d.GeoLensVis3D

Mixin providing 3-D mesh visualisation for GeoLens.

Creates lens surface, aperture, barrier, sensor, and ray-path meshes as polygon data and optionally renders them with PyVista. All geometry is expressed in millimetres and stored as CrossPoly (vertex/face) objects that can be saved to .obj files for external renderers.

This class is not instantiated directly; it is mixed into deeplens.optics.geolens.GeoLens.

create_mesh

create_mesh(mesh_rings: int = 32, mesh_arms: int = 128, is_wrap: bool = False)

Create all lens/bridge/sensor/aperture meshes.

Parameters:

Name Type Description Default
lens GeoLens

The lens object.

required
mesh_rings int

The number of rings in the mesh.

32
mesh_arms int

The number of arms in the mesh.

128
is_wrap bool

Whether to close each element's edge as a cylinder: the larger rim is projected onto the plane of the smaller rim and bridged to it.

False

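Returns:

Name Type Description
surf_meshes list

Lens surface meshes after conversion with surf_to_face_mesh.

bridge_meshes List[List[FaceMesh]]

Bridge meshes, one (possibly empty) list per lens element.

element_groups List[List[int]]

Surface indices of each lens element. A group ends at each surface whose following medium is air.

sensor_mesh RectangleMesh

Sensor mesh at z = d_sensor, a square of side r_sensor * 1.4142 (only rectangular sensors are supported for now).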

Source code in deeplens/optics/geolens_pkg/vis3d.py
def create_mesh(
    self,
    mesh_rings: int = 32,
    mesh_arms: int = 128,
    is_wrap: bool = False,
):
    """Create all lens/bridge/sensor/aperture meshes.

    Args:
        mesh_rings (int): The number of rings in the mesh.
        mesh_arms (int): The number of arms in the mesh.
        is_wrap (bool): Whether to wrap the lens bridges around each element as a cylinder.

    Returns:
        surf_meshes (list): Lens surface meshes.
        bridge_meshes (List[List[FaceMesh]]): Bridge meshes, one list per lens element.
        element_groups (List[List[int]]): Surface indices of each lens element.
        sensor_mesh (RectangleMesh): Sensor mesh. (only rectangular sensors are supported for now)
    """
    """
    surf_meshes = []
    element_group = []
    element_groups = []
    bridge_meshes = []  # nested list: one list of bridge meshes per element
    sensor_mesh = None

    # Create the surface meshes
    for i, surf in enumerate(self.surfaces):
        # Create the surface mesh (list of Surface objects)
        surf_meshes.append(surf.create_mesh(n_rings=mesh_rings, n_arms=mesh_arms))

        # Add the surface to the element group
        element_group.append(i)
        if surf.mat2.name == "air":
            element_groups.append(element_group)
            element_group = []

    # Create the bridge meshes (list of FaceMesh objects)
    for i, pair in enumerate(element_groups):
        if len(pair) == 1:
            bridge_meshes.append([])
            continue
        elif len(pair) == 2:
            a_idx, b_idx = pair
            a = surf_meshes[a_idx]
            b = surf_meshes[b_idx]
            bridge_mesh_group = []
            if not is_wrap:
                bridge_mesh = bridge(a.rim, b.rim)
                bridge_mesh_group.append(bridge_mesh)
            else:
                # create wrap by creating a new rim
                # from projecting the larger rim onto the smaller rim plane
                # assume the elements are always ordered on z-axis
                r_a = self.surfaces[a_idx].r
                r_b = self.surfaces[b_idx].r
                d_rim_a = np.mean(
                    a.rim.vertices[:, 2], keepdims=False
                )  # calc rim mean z
                d_rim_b = np.mean(b.rim.vertices[:, 2], keepdims=False)

                if r_a > r_b:
                    z = line_translate(a.rim, 0, 0, d_rim_b - d_rim_a)
                    bridge_mesh_wrap = bridge(z, b.rim)
                    bridge_mesh = bridge(a.rim, z)
                    bridge_mesh_group.append(bridge_mesh_wrap)
                elif r_a < r_b:
                    z = line_translate(b.rim, 0, 0, d_rim_a - d_rim_b)
                    bridge_mesh_wrap = bridge(a.rim, z)
                    bridge_mesh = bridge(z, b.rim)
                    bridge_mesh_group.append(bridge_mesh_wrap)
                else:
                    bridge_mesh = bridge(a.rim, b.rim)
                bridge_mesh_group.append(bridge_mesh)
            bridge_meshes.append(bridge_mesh_group)

        elif len(pair) == 3:
            a_idx, b_idx, c_idx = pair
            a = surf_meshes[a_idx]
            b = surf_meshes[b_idx]
            c = surf_meshes[c_idx]
            bridge_mesh_group = []
            if not is_wrap:
                bridge_mesh = bridge(a.rim, b.rim)
                bridge_mesh_group.append(bridge_mesh)
                bridge_mesh = bridge(b.rim, c.rim)
                bridge_mesh_group.append(bridge_mesh)
            else:
                # create wrap by creating a new rim
                # from projecting the larger rim onto the smaller rim plane
                # assume the elements are always ordered on z-axis
                r_a = self.surfaces[a_idx].r
                r_b = self.surfaces[b_idx].r
                r_c = self.surfaces[c_idx].r
                d_rim_a = np.mean(
                    a.rim.vertices[:, 2], keepdims=False
                )  # calc rim mean z
                d_rim_b = np.mean(b.rim.vertices[:, 2], keepdims=False)
                d_rim_c = np.mean(c.rim.vertices[:, 2], keepdims=False)

                rim_list = [a.rim, b.rim, c.rim]
                r_list = [r_a, r_b, r_c]
                d_rim_list = [d_rim_a, d_rim_b, d_rim_c]
                idx_wrap = r_list.index(max(r_list))
                r_wrap = r_list[idx_wrap]
                d_rim_wrap = d_rim_list[idx_wrap]

                for i in range(3):
                    if i != idx_wrap and r_list[i] != r_wrap:
                        # substitute the rim with the wrapped rim
                        d_diff = d_rim_list[i] - d_rim_wrap
                        z = line_translate(rim_list[idx_wrap], 0, 0, d_diff)
                        # add the wrap bridge between older rim and wrapped one
                        wrap_mesh = bridge(rim_list[i], z)
                        # update the rim
                        rim_list[i] = z
                        bridge_mesh_group.append(wrap_mesh)
                bridge_mesh = bridge(rim_list[0], rim_list[1])
                bridge_mesh_group.append(bridge_mesh)
                bridge_mesh = bridge(rim_list[1], rim_list[2])
                bridge_mesh_group.append(bridge_mesh)
            bridge_meshes.append(bridge_mesh_group)

        else:
            raise ValueError(f"Invalid bridge group length: {len(pair)}")

    # Create the sensor mesh (RectangleMesh object)
    sensor_d = self.d_sensor.item()
    sensor_r = self.r_sensor
    h, w = sensor_r * 1.4142, sensor_r * 1.4142
    sensor_mesh = RectangleMesh(
        np.array([0, 0, sensor_d]), np.array([1, 0, 0]), np.array([0, 1, 0]), w, h
    )

    # turn surf_meshes to list of FaceMesh
    surf_meshes_cvt = [surf_to_face_mesh(surf) for surf in surf_meshes]
    return surf_meshes_cvt, bridge_meshes, element_groups, sensor_mesh
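The element grouping at the top of `create_mesh` and the sensor size at the bottom can be reproduced in isolation. This sketch mirrors that logic using hypothetical helper names and made-up material names. The `1.4142` factor in the code is √2, so the sensor mesh is the largest square inscribed in a circle of radius `r_sensor`.

```python
import math


def group_elements(mat2_names):
    """Split consecutive surface indices into lens elements.

    Mirrors the loop in create_mesh(): a group is closed at every surface
    whose following medium (mat2) is air.
    """
    groups, current = [], []
    for i, name in enumerate(mat2_names):
        current.append(i)
        if name == "air":
            groups.append(current)
            current = []
    return groups


def sensor_side(r_sensor):
    """Side of the square inscribed in the image circle (the code uses r * 1.4142)."""
    return r_sensor * math.sqrt(2)


# Hypothetical lens: cemented doublet, aperture stop, singlet
groups = group_elements(["glass_a", "glass_b", "air", "air", "glass_a", "air"])
side = sensor_side(5.0)
```

In `create_mesh`, a group with one surface (such as an aperture) gets an empty bridge list and groups with two or three surfaces get bridges. Any longer group raises `ValueError`.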

draw_lens_3d

draw_lens_3d(plotter=None, save_dir: Optional[str] = None, mesh_rings: int = 32, mesh_arms: int = 128, surface_color: List[float] = [0.06, 0.3, 0.6], draw_rays: bool = True, fovs: List[float] = [0.0], fov_phis: List[float] = [0.0], ray_rings: int = 6, ray_arms: int = 8, is_wrap: bool = False)

Draw lens 3D layout with rays using pyvista.

Note: PyVista is imported lazily only when this method is called.

Parameters:

Name Type Description Default
plotter pv.Plotter

Optional PyVista Plotter instance. If None, a new one is created.

None
save_dir str

Directory in which to save a screenshot named lens_layout3d.png. If None, the plotter is returned without being shown or saved.

None
mesh_rings int

The number of rings in the mesh.

32
mesh_arms int

The number of arms in the mesh.

128
surface_color List[float]

RGB color of the lens surfaces and bridges, with components in [0, 1].

[0.06, 0.3, 0.6]
draw_rays bool

Whether to show the rays.

True
fovs List[float]

The FoV angles to be sampled, unit: degree.

[0.0]
fov_phis List[float]

The FoV azimuthal angles to be sampled, unit: degree.

[0.0]
ray_rings int

The number of pupil rings to be sampled.

6
ray_arms int

The number of pupil arms to be sampled.

8
is_wrap bool

Whether to close each element's edge as a cylinder: the larger rim is projected onto the plane of the smaller rim and bridged to it.

False

Returns:

Name Type Description
plotter pv.Plotter

The PyVista Plotter instance.

Source code in deeplens/optics/geolens_pkg/vis3d.py
def draw_lens_3d(
    self,
    plotter=None,
    save_dir: Optional[str] = None,
    mesh_rings: int = 32,
    mesh_arms: int = 128,
    surface_color: List[float] = [0.06, 0.3, 0.6],
    draw_rays: bool = True,
    fovs: List[float] = [0.0],
    fov_phis: List[float] = [0.0],
    ray_rings: int = 6,
    ray_arms: int = 8,
    is_wrap: bool = False,
):
    """Draw lens 3D layout with rays using pyvista.

    Note: PyVista is imported lazily only when this method is called.

    Args:
        plotter: pv.Plotter. Optional pyvista Plotter instance. If None, a new one is created.
        save_dir (str, optional): Directory for the lens_layout3d.png screenshot. If None, nothing is shown or saved.
        mesh_rings (int): The number of rings in the mesh.
        mesh_arms (int): The number of arms in the mesh.
        surface_color (List[float]): The color of the surfaces.
        draw_rays (bool): Whether to show the rays.
        fovs (List[float]): The FoV angles to be sampled, unit: degree.
        fov_phis (List[float]): The FoV azimuthal angles to be sampled, unit: degree.
        ray_rings (int): The number of pupil rings to be sampled.
        ray_arms (int): The number of pupil arms to be sampled.
        is_wrap (bool): Whether to wrap the lens bridge around the lens as cylinder.

    Returns:
        plotter: pv.Plotter. The pyvista Plotter instance.
    """
    # Lazy import of pyvista
    try:
        import pyvista as pv
    except ImportError as e:
        raise ImportError(
            "PyVista is required for 3D GUI rendering. Install with `pip install pyvista`."
        ) from e

    # Create plotter if not provided
    if plotter is None:
        plotter = pv.Plotter()

    surf_color = surface_color
    sensor_color = [0.5, 0.5, 0.5]

    # Create meshes
    surf_meshes, bridge_meshes, _, sensor_mesh = self.create_mesh(
        mesh_rings, mesh_arms, is_wrap
    )

    # Draw meshes
    for surf in surf_meshes:
        if not isinstance(surf, Aperture):
            _draw_mesh_to_plotter(
                plotter, surf, color=surf_color, opacity=0.5, pv=pv
            )

    for bridge_group in bridge_meshes:
        for bridge_mesh in bridge_group:
            _draw_mesh_to_plotter(
                plotter, bridge_mesh, color=surf_color, opacity=0.5, pv=pv
            )

    _draw_mesh_to_plotter(
        plotter, sensor_mesh, color=sensor_color, opacity=1.0, pv=pv
    )

    # Draw rays
    if draw_rays:
        rays_curve = geolens_ray_poly(
            self, fovs, fov_phis, n_rings=ray_rings, n_arms=ray_arms
        )

        rays_poly_list = [curve_list_to_polydata(r) for r in rays_curve]
        rays_poly_fov = [merge(r) for r in rays_poly_list]
        rays_poly_fov = [_wrap_base_poly_to_pyvista(r, pv) for r in rays_poly_fov]
        for r in rays_poly_fov:
            plotter.add_mesh(r)

    # Save images
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        plotter.show(screenshot=os.path.join(save_dir, "lens_layout3d.png"))

    return plotter
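`ray_rings`/`ray_arms` and `fovs`/`fov_phis` control which rays are drawn. One plausible interpretation is a polar grid over the normalised pupil plus a unit direction per (field angle, azimuth) pair, sketched below. The helper names and the angle convention are assumptions: the field angle is measured from the optical z-axis and the azimuth in the x-y plane from +x. `geolens_ray_poly` may sample differently.

```python
import numpy as np


def polar_pupil_grid(ray_rings=6, ray_arms=8):
    """Hypothetical polar sampling of a normalised pupil: ray_rings radii x ray_arms azimuths."""
    radii = np.arange(1, ray_rings + 1) / ray_rings
    phis = 2 * np.pi * np.arange(ray_arms) / ray_arms
    R, P = np.meshgrid(radii, phis, indexing="ij")
    return np.stack([R * np.cos(P), R * np.sin(P)], axis=-1).reshape(-1, 2)


def field_direction(fov_deg, fov_phi_deg):
    """Unit direction for field angle fov (from the z-axis) and azimuth fov_phi (assumed convention)."""
    theta, phi = np.radians(fov_deg), np.radians(fov_phi_deg)
    return np.array([np.sin(theta) * np.cos(phi), np.sin(theta) * np.sin(phi), np.cos(theta)])


pupil_xy = polar_pupil_grid()        # default draw_lens_3d settings: 6 rings x 8 arms
d_axis = field_direction(0.0, 0.0)   # on-axis field
d_off = field_direction(20.0, 90.0)  # 20 deg field in the y-z plane
```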

save_lens_obj

save_lens_obj(save_dir: str, mesh_rings: int = 64, mesh_arms: int = 128, save_rays: bool = False, fovs: List[float] = [0.0], fov_phis: List[float] = [0.0], ray_rings: int = 6, ray_arms: int = 8, is_wrap: bool = False, save_elements: bool = True)

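Save the lens geometry (and optionally rays) as .obj files using PyVista. The method always writes lens.obj (all non-aperture surfaces merged with the bridges) and sensor.obj. It also writes element_{i}.obj for each lens element when save_elements is True, and lens_rays_fov_{i}.obj when save_rays is True.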

Note: use #F2F7FFFF as the lens color when rendering in Blender.

Parameters:

Name Type Description Default
save_dir str

Output directory for the .obj files; created if it does not exist.

required
mesh_rings int

The number of rings in the mesh.

64
mesh_arms int

The number of arms in the mesh.

128
save_rays bool

Whether to save the rays.

False
fovs List[float]

The FoV angles to be sampled, unit: degree.

[0.0]
fov_phis List[float]

The FoV azimuthal angles to be sampled, unit: degree.

[0.0]
ray_rings int

The number of pupil rings to be sampled.

6
ray_arms int

The number of pupil arms to be sampled.

8
is_wrap bool

Whether to close each element's edge as a cylinder: the larger rim is projected onto the plane of the smaller rim and bridged to it.

False
save_elements bool

Whether to also save each lens element (its surfaces and bridges merged) as element_{i}.obj.

True
Source code in deeplens/optics/geolens_pkg/vis3d.py
def save_lens_obj(
    self,
    save_dir: str,
    mesh_rings: int = 64,
    mesh_arms: int = 128,
    save_rays: bool = False,
    fovs: List[float] = [0.0],
    fov_phis: List[float] = [0.0],
    ray_rings: int = 6,
    ray_arms: int = 8,
    is_wrap: bool = False,
    save_elements: bool = True,
):
    """Save lens geometry and rays as .obj files using pyvista.

    Note: use #F2F7FFFF as the color for lens when rendering in Blender.

    Args:
        save_dir (str): Output directory for the .obj files.
        mesh_rings (int): The number of rings in the mesh. (default: 64)
        mesh_arms (int): The number of arms in the mesh. (default: 128)
        save_rays (bool): Whether to save the rays.
        fovs (List[float]): The FoV angles to be sampled, unit: degree.
        fov_phis (List[float]): The FoV azimuthal angles to be sampled, unit: degree.
        ray_rings (int): The number of pupil rings to be sampled. (default: 6)
        ray_arms (int): The number of pupil arms to be sampled. (default: 8)
        is_wrap (bool): Whether to wrap the lens bridge around the lens as cylinder.
        save_elements (bool): Whether to save the elements.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Create surfaces & bridges meshes
    surf_meshes, bridge_meshes, element_groups, sensor_mesh = self.create_mesh(
        mesh_rings, mesh_arms, is_wrap
    )

    # Save individual lens elements (surfaces + bridges merged)
    if save_elements:
        for i, pair in enumerate(element_groups):
            print(f"Running in pair {i} with pair length {len(pair)}")
            # Collect surface polydata
            surf_polydata_list = [surf_meshes[idx].get_polydata() for idx in pair]

            # Collect bridge polydata if available
            bridge_polydata_list = []
            if i < len(bridge_meshes) and len(bridge_meshes[i]) > 0:
                print(f"Bridge mesh group number: {len(bridge_meshes[i])}")
                bridge_polydata_list = [b.get_polydata() for b in bridge_meshes[i]]

            # Merge surfaces and bridges together
            all_polydata = surf_polydata_list + bridge_polydata_list
            if len(all_polydata) == 1:
                element = all_polydata[0]
            else:
                element = merge(all_polydata)
            element.save(os.path.join(save_dir, f"element_{i}.obj"))

    # Merge all surfaces and bridges, and save as single lens.obj file
    surf_polydata = [
        surf.get_polydata()
        for surf in surf_meshes
        if not isinstance(surf, Aperture)
    ]
    bridge_polydata = [
        b.get_polydata() for group in bridge_meshes for b in group
    ]  # flatten the nested list
    lens_polydata = surf_polydata + bridge_polydata
    lens_polydata = merge(lens_polydata)
    lens_polydata.save(os.path.join(save_dir, "lens.obj"))

    # Save sensor
    sensor_polydata = sensor_mesh.get_polydata()
    sensor_polydata.save(os.path.join(save_dir, "sensor.obj"))

    # Save rays
    if save_rays:
        rays_curve = geolens_ray_poly(
            self, fovs, fov_phis, n_rings=ray_rings, n_arms=ray_arms
        )
        rays_poly_list = [curve_list_to_polydata(r) for r in rays_curve]
        rays_poly_fov = [merge(r) for r in rays_poly_list]
        for i, r in enumerate(rays_poly_fov):
            r.save(os.path.join(save_dir, f"lens_rays_fov_{i}.obj"))
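`save_lens_obj` delegates file writing to PyVista. A minimal Wavefront OBJ writer for a triangle mesh shows what each `.obj` file contains. `write_obj` is a hypothetical illustration, not the writer PyVista uses, and it stores only vertex (`v`) and face (`f`) records.

```python
import os
import tempfile


def write_obj(path, vertices, faces):
    """Write a triangle mesh as a minimal Wavefront OBJ file (illustrative only)."""
    with open(path, "w") as f:
        for x, y, z in vertices:
            f.write(f"v {x:.6f} {y:.6f} {z:.6f}\n")
        for a, b, c in faces:
            # OBJ face indices are 1-based, so shift the 0-based indices
            f.write(f"f {a + 1} {b + 1} {c + 1}\n")


# Usage: a single triangle in the z = 0 plane
tmp_dir = tempfile.mkdtemp()
obj_path = os.path.join(tmp_dir, "triangle.obj")
write_obj(obj_path, [(0, 0, 0), (1, 0, 0), (0, 1, 0)], [(0, 1, 2)])
with open(obj_path) as f:
    obj_lines = f.read().splitlines()
```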

Combines a GeoLens with a diffractive optical element (DOE). Performs coherent ray tracing to the DOE plane, then Angular Spectrum Method (ASM) propagation to the sensor.

deeplens.optics.HybridLens

HybridLens(filename=None, device=None, dtype=torch.float64)

Bases: Lens

Hybrid refractive-diffractive lens using a differentiable ray–wave model.

Combines a :class:~deeplens.optics.geolens.GeoLens (refractive module) with a diffractive optical element (DOE) placed behind it. The pipeline is:

  1. Coherent ray tracing through the embedded GeoLens to obtain a complex wavefront at the DOE plane (including all geometric aberrations).
  2. DOE phase modulation applied to the wavefront.
  3. Angular Spectrum Method (ASM) propagation from the DOE to the sensor plane to produce the final intensity PSF.

This enables end-to-end gradient flow from image quality metrics back to both refractive surface parameters and the DOE phase profile.
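The third step, ASM propagation, can be sketched in NumPy as below. This is a self-contained illustration and not DeepLens' `AngularSpectrumMethod`, whose call signature appears in `psf()`. It assumes square pixels of pitch `ps_mm`, lengths in mm, and the wavelength in µm, and it drops evanescent components.

```python
import numpy as np


def asm_propagate(field, z_mm, wvln_um, ps_mm):
    """Angular spectrum propagation of a sampled complex field over distance z_mm."""
    wvln_mm = wvln_um * 1e-3                 # [um] -> [mm], so all lengths are in mm
    ny, nx = field.shape
    fx = np.fft.fftfreq(nx, d=ps_mm)         # spatial frequencies [1/mm]
    fy = np.fft.fftfreq(ny, d=ps_mm)
    FX, FY = np.meshgrid(fx, fy)             # both of shape (ny, nx)
    arg = 1.0 / wvln_mm**2 - FX**2 - FY**2
    # Transfer function exp(i 2 pi z sqrt(1/lambda^2 - fx^2 - fy^2)); evanescent waves are dropped
    kz = 2 * np.pi * np.sqrt(np.maximum(arg, 0.0))
    H = np.where(arg > 0, np.exp(1j * kz * z_mm), 0.0)
    return np.fft.ifft2(np.fft.fft2(field) * H)


rng = np.random.default_rng(0)
u0 = rng.standard_normal((32, 32)) + 1j * rng.standard_normal((32, 32))
u1 = asm_propagate(u0, z_mm=5.0, wvln_um=0.55, ps_mm=0.01)   # 10 um pixels, 5 mm gap
u_back = asm_propagate(u1, z_mm=-5.0, wvln_um=0.55, ps_mm=0.01)
```

For propagating waves the transfer function has unit modulus. As a result, the field energy is preserved, and propagating back by `-z` recovers the input. The FFT makes the convolution circular, so a field that spreads beyond the grid wraps around. This is the usual reason for zero-padding before propagation.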

Attributes:

Name Type Description
geolens GeoLens

Embedded refractive module.

doe

Diffractive optical element (one of Binary2, Pixel2D, Fresnel, Zernike, Grating).

Notes

Operates in torch.float64 by default for numerical stability of the wave-propagation step.

References

Xinge Yang et al., "End-to-End Hybrid Refractive-Diffractive Lens Design with Differentiable Ray-Wave Model," SIGGRAPH Asia 2024.

Initialize a hybrid refractive-diffractive lens.

Parameters:

Name Type Description Default
filename str

Path to the lens configuration JSON file. Defaults to None.

None
device str

Computation device ('cpu' or 'cuda'). Defaults to None.

None
dtype dtype

Data type for computations. Defaults to torch.float64.

float64
Source code in deeplens/optics/hybridlens.py
def __init__(
    self,
    filename=None,
    device=None,
    dtype=torch.float64,
):
    """Initialize a hybrid refractive-diffractive lens.

    Args:
        filename (str, optional): Path to the lens configuration JSON file. Defaults to None.
        device (str, optional): Computation device ('cpu' or 'cuda'). Defaults to None.
        dtype (torch.dtype, optional): Data type for computations. Defaults to torch.float64.
    """
    super().__init__(device=device, dtype=dtype)

    # Load lens file
    if filename is not None:
        self.read_lens_json(filename)
    else:
        self.geolens = None
        self.doe = None
        # Set default sensor size and resolution if no file provided
        self.sensor_size = (8.0, 8.0)
        self.sensor_res = (2000, 2000)
        print(
            f"No lens file provided. Using default sensor_size: {self.sensor_size} mm, "
            f"sensor_res: {self.sensor_res} pixels. Use set_sensor() to change."
        )

    self.double()

read_lens_json

read_lens_json(filename)

Read the lens configuration from a JSON file.

Loads a :class:GeoLens and associated DOE from the specified file. A Plane surface is appended to the GeoLens surface list as a placeholder for the DOE plane.

Supported DOE types: binary2, pixel2d, fresnel, zernike, grating.

Parameters:

Name Type Description Default
filename str

Path to the JSON configuration file. Must contain a "DOE" key with a "type" field.

required

Raises:

Type Description
ValueError

If the DOE type in the file is not supported.

Source code in deeplens/optics/hybridlens.py
def read_lens_json(self, filename):
    """Read the lens configuration from a JSON file.

    Loads a :class:`GeoLens` and associated DOE from the specified file.
    A ``Plane`` surface is appended to the GeoLens surface list as a
    placeholder for the DOE plane.

    Supported DOE types: ``binary2``, ``pixel2d``, ``fresnel``,
    ``zernike``, ``grating``.

    Args:
        filename (str): Path to the JSON configuration file.  Must
            contain a ``"DOE"`` key with a ``"type"`` field.

    Raises:
        ValueError: If the DOE type in the file is not supported.
    """
    # Load geolens
    geolens = GeoLens(filename=filename, device=self.device)

    # Load DOE (diffractive surface)
    with open(filename, "r") as f:
        data = json.load(f)

        doe_dict = data["DOE"]
        doe_param_model = doe_dict["type"].lower()
        if doe_param_model == "binary2":
            doe = Binary2.init_from_dict(doe_dict)
        elif doe_param_model == "pixel2d":
            doe = Pixel2D.init_from_dict(doe_dict)
        elif doe_param_model == "fresnel":
            doe = Fresnel.init_from_dict(doe_dict)
        elif doe_param_model == "zernike":
            doe = Zernike.init_from_dict(doe_dict)
        elif doe_param_model == "grating":
            doe = Grating.init_from_dict(doe_dict)
        else:
            raise ValueError(f"Unsupported DOE parameter model: {doe_param_model}")
        self.doe = doe

    # Add a Plane/Phase surface to GeoLens (DOE placeholder)
    r_doe = float(np.sqrt(doe.w**2 + doe.h**2) / 2)
    geolens.surfaces.append(Plane(d=doe.d.item(), r=r_doe, mat2="air"))
    # r_doe = float(np.sqrt(doe.w**2 + doe.h**2) / 2)
    # geolens.surfaces.append(Phase(r=r_doe, d=doe.d))
    self.geolens = geolens
    self.foclen = geolens.foclen

    # Update hybrid lens sensor resolution and pixel size
    self.set_sensor(sensor_size=geolens.sensor_size, sensor_res=geolens.sensor_res)
    self.to(self.device)
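The DOE dispatch and the placeholder radius in `read_lens_json` can be exercised without building any optics objects. In this sketch `parse_doe_block` is hypothetical. It assumes the `"DOE"` block stores the aperture size under `"w"` and `"h"` keys; the real method reads `doe.w` and `doe.h` from the constructed DOE, so the JSON keys may differ. The placeholder `Plane` radius is half the DOE diagonal, so its circular aperture covers the corners of the rectangular DOE.

```python
import json
import math

SUPPORTED_DOE_TYPES = ("binary2", "pixel2d", "fresnel", "zernike", "grating")


def parse_doe_block(json_text):
    """Return (doe_type, r_doe) from a lens JSON string (hypothetical helper)."""
    data = json.loads(json_text)
    doe_dict = data["DOE"]
    doe_type = doe_dict["type"].lower()   # type matching is case-insensitive, as above
    if doe_type not in SUPPORTED_DOE_TYPES:
        raise ValueError(f"Unsupported DOE parameter model: {doe_type}")
    # Radius of the placeholder Plane: half the diagonal of the w x h DOE (assumed keys)
    r_doe = math.sqrt(doe_dict["w"] ** 2 + doe_dict["h"] ** 2) / 2
    return doe_type, r_doe


doe_type, r_doe = parse_doe_block('{"DOE": {"type": "Binary2", "w": 3.0, "h": 4.0}}')
```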

write_lens_json

write_lens_json(lens_path)

Write the lens configuration to a JSON file.

Serialises the GeoLens surfaces (excluding the DOE placeholder) and the DOE configuration into a single JSON file that can be reloaded with :meth:read_lens_json.

Parameters:

Name Type Description Default
lens_path str

Output file path.

required
Source code in deeplens/optics/hybridlens.py
def write_lens_json(self, lens_path):
    """Write the lens configuration to a JSON file.

    Serialises the ``GeoLens`` surfaces (excluding the DOE placeholder)
    and the ``DOE`` configuration into a single JSON file that can be
    reloaded with :meth:`read_lens_json`.

    Args:
        lens_path (str): Output file path.
    """
    geolens = self.geolens
    data = {}
    data["info"] = geolens.lens_info if hasattr(geolens, "lens_info") else "None"
    data["foclen"] = round(geolens.foclen, 4)
    data["fnum"] = round(geolens.fnum, 4)
    data["r_sensor"] = round(geolens.r_sensor, 4)
    data["d_sensor"] = round(geolens.d_sensor.item(), 4)
    data["sensor_size"] = [round(i, 4) for i in geolens.sensor_size]
    data["sensor_res"] = geolens.sensor_res

    # Geolens
    data["surfaces"] = []
    for i, s in enumerate(geolens.surfaces[:-1]):
        surf_dict = s.surf_dict()

        # To exclude the last surface (DOE)
        if i < len(geolens.surfaces) - 2:
            surf_dict["d_next"] = round(
                geolens.surfaces[i + 1].d.item() - geolens.surfaces[i].d.item(), 3
            )
        else:
            surf_dict["d_next"] = round(
                geolens.d_sensor.item() - geolens.surfaces[i].d.item(), 3
            )

        data["surfaces"].append(surf_dict)

    # DOE
    data["DOE"] = self.doe.surf_dict()

    with open(lens_path, "w") as f:
        json.dump(data, f, indent=4)
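`write_lens_json` stores each surface's absolute axial position `d` as a gap `d_next` to the following surface. This sketch reproduces that loop on plain floats. `compute_d_next` is a hypothetical name. The last refractive surface's gap is measured to the sensor and skips the DOE placeholder, whose data is written separately under the `"DOE"` key.

```python
def compute_d_next(surface_d, d_sensor):
    """Gaps written as d_next for every surface except the trailing DOE placeholder."""
    n = len(surface_d)
    gaps = []
    for i in range(n - 1):  # skip the last entry (DOE placeholder)
        if i < n - 2:
            gaps.append(round(surface_d[i + 1] - surface_d[i], 3))
        else:
            # Last refractive surface: gap measured to the sensor, as in write_lens_json
            gaps.append(round(d_sensor - surface_d[i], 3))
    return gaps


# Three refractive surfaces at 0.0, 1.5, 4.0 mm, DOE placeholder at 6.0 mm, sensor at 10.0 mm
gaps = compute_d_next([0.0, 1.5, 4.0, 6.0], d_sensor=10.0)
```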

analysis

analysis(save_name='./test.png')

Run a quick visual analysis of the hybrid lens.

Generates two figures: the 2D lens layout (saved to save_name) and the DOE phase map (saved to <save_name>_doe.png).

Parameters:

Name Type Description Default
save_name str

Base file path for the layout image. The DOE phase-map image is saved to <save_name>_doe.png, i.e. the suffix is appended after the full file name (./test.png_doe.png for the default). Defaults to './test.png'.

'./test.png'
Source code in deeplens/optics/hybridlens.py
def analysis(self, save_name="./test.png"):
    """Run a quick visual analysis of the hybrid lens.

    Generates two figures: the 2D lens layout (saved to *save_name*) and
    the DOE phase map (saved to ``<save_name>_doe.png``).

    Args:
        save_name (str, optional): Base file path for the layout image.
            The DOE phase-map image is saved to ``<save_name>_doe.png``,
            i.e. the suffix is appended after the full file name.
            Defaults to ``'./test.png'``.
    """
    self.draw_layout(save_name=save_name)
    self.doe.draw_phase_map(save_name=f"{save_name}_doe.png")

double

double()

Convert the GeoLens and DOE to float64 precision.

Double precision is required for numerically stable phase accumulation during coherent ray tracing and ASM propagation. Called automatically by :meth:__init__.

Source code in deeplens/optics/hybridlens.py
def double(self):
    """Convert the GeoLens and DOE to ``float64`` precision.

    Double precision is required for numerically stable phase
    accumulation during coherent ray tracing and ASM propagation.
    Called automatically by :meth:`__init__`.
    """
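    # Skip missing components (e.g. when __init__ ran without a lens file)
    if self.geolens is not None:
        self.geolens.astype(torch.float64)
    if self.doe is not None:
        self.doe.astype(torch.float64)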

refocus

refocus(foc_dist)

Refocus the hybrid lens to a given object distance.

Only the GeoLens sensor-to-last-surface spacing is adjusted; the DOE remains fixed relative to the refractive group (it is physically cemented to the lens barrel).

Parameters:

Name Type Description Default
foc_dist float

Target focus distance in [mm] (negative, towards the object).

required
Source code in deeplens/optics/hybridlens.py
def refocus(self, foc_dist):
    """Refocus the hybrid lens to a given object distance.

    Only the ``GeoLens`` sensor-to-last-surface spacing is adjusted; the
    DOE remains fixed relative to the refractive group (it is physically
    cemented to the lens barrel).

    Args:
        foc_dist (float): Target focus distance in [mm] (negative,
            towards the object).
    """
    self.geolens.refocus(foc_dist)

calc_scale

calc_scale(depth)

Calculate the object-to-image magnification scale factor.

Delegates to the embedded :class:GeoLens.

Parameters:

Name Type Description Default
depth float

Object distance in [mm] (negative, towards the object).

required

Returns:

Name Type Description
float

Scale factor mapping normalised sensor coordinates [-1, 1] to physical object-space coordinates [mm].

Source code in deeplens/optics/hybridlens.py
def calc_scale(self, depth):
    """Calculate the object-to-image magnification scale factor.

    Delegates to the embedded :class:`GeoLens`.

    Args:
        depth (float): Object distance in [mm] (negative, towards the
            object).

    Returns:
        float: Scale factor mapping normalised sensor coordinates
            ``[-1, 1]`` to physical object-space coordinates [mm].
    """
    return self.geolens.calc_scale(depth)

doe_field

doe_field(point, wvln=DEFAULT_WAVE, spp=SPP_COHERENT)

Compute the complex wave field at the DOE plane via coherent ray tracing.

Similar to GeoLens.pupil_field(), but evaluates the field at the last surface (DOE plane) instead of the exit pupil. The returned wavefront encodes amplitude, phase, and all diffraction-order information needed for subsequent DOE modulation and ASM propagation.

Parameters:

Name Type Description Default
point Tensor

Point source position, shape (3,) or (1, 3) as [x, y, z] in normalised sensor coordinates for x/y and mm for z.

required
wvln float

Wavelength in [um]. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
spp int

Number of rays to sample. Must be >= 1,000,000 for accurate coherent simulation. Defaults to SPP_COHERENT.

SPP_COHERENT

Returns:

Name Type Description
tuple
  • wavefront (torch.Tensor) -- Complex wavefront at the DOE plane, shape [H, W].
  • psf_center (list[float]) -- Estimated PSF centre on the sensor in normalised coordinates [x, y].

Raises:

Type Description
AssertionError

If spp < 1,000,000 or the default dtype is not float64.

Source code in deeplens/optics/hybridlens.py
def doe_field(self, point, wvln=DEFAULT_WAVE, spp=SPP_COHERENT):
    """Compute the complex wave field at the DOE plane via coherent ray tracing.

    Similar to ``GeoLens.pupil_field()``, but evaluates the field at the
    last surface (DOE plane) instead of the exit pupil.  The returned
    wavefront encodes amplitude, phase, and all diffraction-order
    information needed for subsequent DOE modulation and ASM propagation.

    Args:
        point (torch.Tensor): Point source position, shape ``(3,)`` or
            ``(1, 3)`` as ``[x, y, z]`` in normalised sensor coordinates
            for x/y and mm for z.
        wvln (float, optional): Wavelength in [um].  Defaults to
            ``DEFAULT_WAVE``.
        spp (int, optional): Number of rays to sample.  Must be
            >= 1,000,000 for accurate coherent simulation.  Defaults to
            ``SPP_COHERENT``.

    Returns:
        tuple:
            - **wavefront** (*torch.Tensor*) -- Complex wavefront at the
              DOE plane, shape ``[H, W]``.
            - **psf_center** (*list[float]*) -- Estimated PSF centre on
              the sensor in normalised coordinates ``[x, y]``.

    Raises:
        AssertionError: If *spp* < 1,000,000 or the default dtype is not
            ``float64``.
    """
    assert spp >= 1_000_000, (
        "Coherent ray tracing spp is too small, "
        "which may lead to inaccurate simulation."
    )
    assert torch.get_default_dtype() == torch.float64, (
        "Default dtype must be set to float64 for accurate phase tracing."
    )

    geolens, doe = self.geolens, self.doe

    if point.dim() == 1:
        point = point.unsqueeze(0)
    point = point.to(self.device)

    # Calculate ray origin in the object space
    scale = geolens.calc_scale(point[:, 2].item())
    point_obj = point.clone()
    point_obj[:, 0] = point[:, 0] * scale * geolens.sensor_size[1] / 2
    point_obj[:, 1] = point[:, 1] * scale * geolens.sensor_size[0] / 2

    # Determine ray center via chief ray
    pointc_chief_ray = geolens.psf_center(point_obj, method="chief_ray")[
        0
    ]  # shape [2]

    # Ray tracing to the DOE plane
    ray = geolens.sample_from_points(points=point_obj, num_rays=spp, wvln=wvln)
    ray.coherent = True
    ray, _ = geolens.trace(ray)
    ray = ray.prop_to(doe.d)

    # Calculate full-resolution complex field at the DOE plane
    wavefront = forward_integral(
        ray.flip_xy(),
        ps=doe.ps,
        ks=doe.res[0],
        pointc=torch.zeros_like(point[:, :2]),
    ).squeeze(0)  # shape [H, W]

    # Compute PSF center based on chief ray
    psf_center = [
        pointc_chief_ray[0] / geolens.sensor_size[0] * 2,
        pointc_chief_ray[1] / geolens.sensor_size[1] * 2,
    ]

    return wavefront, psf_center
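`doe_field` converts the normalised point-source coordinates to object space before tracing rays. The helper below mirrors the two `point_obj` lines. It is a hypothetical standalone function. It assumes `sensor_size` is ordered `(height, width)` in mm, which is what the index choice in those lines suggests.

```python
def normalized_to_object_xy(x_norm, y_norm, scale, sensor_size):
    """Mirror of the point_obj lines in doe_field: normalised sensor coords -> object space [mm]."""
    h, w = sensor_size  # assumed (height, width) in mm
    return x_norm * scale * w / 2, y_norm * scale * h / 2


# Hypothetical numbers: scale 100 from calc_scale(), 6 mm x 8 mm sensor
x_obj, y_obj = normalized_to_object_xy(0.5, -1.0, scale=100.0, sensor_size=(6.0, 8.0))
```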

psf

psf(points=[0.0, 0.0, -10000.0], ks=PSF_KS, wvln=DEFAULT_WAVE, spp=SPP_COHERENT)

Compute a single-point monochromatic PSF using the ray-wave model.

The returned PSF includes all diffraction orders with physically correct diffraction efficiencies. The pipeline is:

  1. Coherent ray tracing through the GeoLens to obtain the complex wavefront at the DOE plane.
  2. DOE phase modulation applied to the wavefront.
  3. ASM propagation to the sensor, intensity calculation, cropping, and normalisation.
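The intensity, cropping, and normalisation stages of this pipeline can be sketched with a toy example. Here a flat-phase circular pupil stands in for the modulated DOE field. A single FFT (the Fraunhofer approximation at the focus of an ideal lens) replaces the ASM propagation, which makes this a different and much simpler propagation model than the one `psf()` uses. The zero-padding by half the grid size on each side mirrors the padding step in `psf()`.

```python
import numpy as np


def toy_psf(n=64, ks=15):
    """Toy pipeline: pupil field -> zero-pad -> propagate -> intensity -> crop -> normalise."""
    # Unit-amplitude, flat-phase circular pupil (an aberration-free field)
    yy, xx = np.mgrid[-1:1:n * 1j, -1:1:n * 1j]
    field = (xx**2 + yy**2 <= 1.0).astype(complex)
    # Zero-pad by half the grid size on each side, doubling the grid
    field = np.pad(field, ((n // 2, n // 2), (n // 2, n // 2)))
    # Fraunhofer (single FFT) stand-in for propagation to the focal plane
    focal = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(field)))
    inten = np.abs(focal) ** 2
    # Crop the central ks x ks patch and normalise it to unit sum
    H, W = inten.shape
    top, left = H // 2 - ks // 2, W // 2 - ks // 2
    patch = inten[top:top + ks, left:left + ks]
    return patch / patch.sum()


psf_patch = toy_psf()
peak = np.unravel_index(np.argmax(psf_patch), psf_patch.shape)
```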

Parameters:

Name Type Description Default
points list or Tensor

[x, y, z] point source coordinates. x, y are in normalised sensor coordinates [-1, 1]; z is depth in [mm]. Defaults to [0.0, 0.0, -10000.0].

[0.0, 0.0, -10000.0]
ks int or None

Output PSF patch size. If None, returns the central quarter of the full-sensor intensity. Defaults to PSF_KS.

PSF_KS
wvln float

Wavelength in [um]. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
spp int

Number of coherent rays to sample. Defaults to SPP_COHERENT.

SPP_COHERENT

Returns:

Type Description
Tensor

Normalised PSF patch (sums to 1), shape [ks, ks]. Returned in float32 precision.

Raises:

Type Description
ValueError

If the default dtype is not float64 (call :meth:double first).

Source code in deeplens/optics/hybridlens.py
def psf(
    self,
    points=[0.0, 0.0, -10000.0],
    ks=PSF_KS,
    wvln=DEFAULT_WAVE,
    spp=SPP_COHERENT,
):
    """Compute a single-point monochromatic PSF using the ray-wave model.

    The returned PSF includes all diffraction orders with physically
    correct diffraction efficiencies.  The pipeline is:

    1. Coherent ray tracing through the ``GeoLens`` to obtain the complex
       wavefront at the DOE plane.
    2. DOE phase modulation applied to the wavefront.
    3. ASM propagation to the sensor, intensity calculation, cropping, and
       normalisation.

    Args:
        points (list or torch.Tensor, optional): ``[x, y, z]`` point
            source coordinates.  *x, y* are in normalised sensor
            coordinates ``[-1, 1]``; *z* is depth in [mm].  Defaults to
            ``[0.0, 0.0, -10000.0]``.
        ks (int or None, optional): Output PSF patch size.  If ``None``,
            returns the central quarter of the full-sensor intensity.
            Defaults to ``PSF_KS``.
        wvln (float, optional): Wavelength in [um].  Defaults to
            ``DEFAULT_WAVE``.
        spp (int, optional): Number of coherent rays to sample.  Defaults
            to ``SPP_COHERENT``.

    Returns:
        torch.Tensor: Normalised PSF patch (sums to 1), shape
            ``[ks, ks]``.  Returned in ``float32`` precision.

    Raises:
        ValueError: If the default dtype is not ``float64`` (call
            :meth:`double` first).
    """
    # Check double precision
    if not torch.get_default_dtype() == torch.float64:
        raise ValueError(
            "Please call HybridLens.double() to set the default dtype to float64 for accurate phase tracing."
        )

    # Check lens last surface
    assert isinstance(self.geolens.surfaces[-1], Phase) or isinstance(
        self.geolens.surfaces[-1], Plane
    ), "The last lens surface should be a DOE."
    geolens, doe = self.geolens, self.doe

    # Compute pupil field by coherent ray tracing
    if isinstance(points, list):
        point0 = torch.tensor(points)
    elif isinstance(points, torch.Tensor):
        point0 = points
    else:
        raise ValueError("points should be a list or a torch.Tensor.")

    wavefront, psfc = self.doe_field(point=point0, wvln=wvln, spp=spp)
    wavefront = wavefront.squeeze(0)  # shape of [H, W]

    # DOE phase modulation. We have to flip the phase map because the wavefront has been flipped
    phase_map = torch.flip(doe.get_phase_map(wvln), [-1, -2])
    wavefront = wavefront * torch.exp(1j * phase_map)

    # Propagate wave field to sensor plane
    h, w = wavefront.shape
    wavefront = F.pad(
        wavefront.unsqueeze(0).unsqueeze(0),
        [h // 2, h // 2, w // 2, w // 2],
        mode="constant",
        value=0,
    )
    sensor_field = AngularSpectrumMethod(
        wavefront, z=geolens.d_sensor - doe.d, wvln=wvln, ps=doe.ps, padding=False
    )

    # Compute PSF (intensity distribution)
    psf_inten = sensor_field.abs() ** 2
    psf_inten = (
        F.interpolate(
            psf_inten,
            scale_factor=geolens.sensor_res[0] / h,
            mode="bilinear",
            align_corners=False,
        )
        .squeeze(0)
        .squeeze(0)
    )

    # Calculate the PSF centre index and crop the valid PSF region (accounting for both interpolation and padding)
    if ks is not None:
        h, w = psf_inten.shape[-2:]
        psfc_idx_i = ((2 - psfc[1]) * h / 4).round().long()
        psfc_idx_j = ((2 + psfc[0]) * w / 4).round().long()

        # Pad to avoid invalid edge region
        psf_inten_pad = F.pad(
            psf_inten,
            [ks // 2, ks // 2, ks // 2, ks // 2],
            mode="constant",
            value=0,
        )
        psf = psf_inten_pad[
            psfc_idx_i : psfc_idx_i + ks, psfc_idx_j : psfc_idx_j + ks
        ]
    else:
        h, w = psf_inten.shape[-2:]
        psf = psf_inten[
            int(h / 2 - h / 4) : int(h / 2 + h / 4),
            int(w / 2 - w / 4) : int(w / 2 + w / 4),
        ]

    # Normalize and convert to float precision
    psf /= psf.sum()  # shape of [ks, ks] or [h, w]
    return diff_float(psf)
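The centre-index arithmetic in the cropping step can be read as follows. After 2x padding, the intensity map spans normalised coordinates [-2, 2]. A normalised position (x, y) then maps to row (2 - y) * h / 4 and column (2 + x) * w / 4. Padding by `ks // 2` makes the slice `[i : i + ks]` centred on that pixel. `crop_psf_patch` is a hypothetical NumPy helper, and it assumes an odd `ks` and +y pointing up.

```python
import numpy as np


def crop_psf_patch(inten, psfc_xy, ks):
    """Crop a ks x ks patch centred on a normalised position (x, y).

    Assumes `inten` spans normalised coordinates [-2, 2] on both axes
    (the 2x zero-padded field), with +y pointing up (row 0 is y = +2).
    """
    h, w = inten.shape
    x, y = psfc_xy
    i = int(round((2 - y) * h / 4))  # row index of the PSF centre
    j = int(round((2 + x) * w / 4))  # column index of the PSF centre
    # Zero-pad by ks // 2 so the window [i, i + ks) in padded coordinates
    # is centred on original pixel (i, j), even at the image border
    padded = np.pad(inten, ks // 2)
    return padded[i : i + ks, j : j + ks]
```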

draw_layout

draw_layout(save_name='./DOELens.png', depth=-10000.0, ax=None, fig=None)

Draw the hybrid-lens layout with ray paths and wave-propagation arcs.

Renders the refractive elements via GeoLens.draw_lens_2d(), traces rays at three field angles (on-axis, 0.707x, 0.99x full field), and overlays concentric arcs between the DOE and sensor to illustrate the wave-propagation region.

Parameters:

Name Type Description Default
save_name str

File path to save the figure (used only when ax is None). Defaults to './DOELens.png'.

'./DOELens.png'
depth float

Object depth [mm] for the traced rays. Defaults to -10000.0.

-10000.0
ax Axes

Pre-existing axes to draw into. If None, a new figure is created and saved.

None
fig Figure

Pre-existing figure. Required when ax is provided.

None

Returns:

Type Description

tuple or None: (ax, fig) when ax was provided; otherwise the figure is saved to save_name and nothing is returned.

Source code in deeplens/optics/hybridlens.py
@torch.no_grad()
def draw_layout(self, save_name="./DOELens.png", depth=-10000.0, ax=None, fig=None):
    """Draw the hybrid-lens layout with ray paths and wave-propagation arcs.

    Renders the refractive elements via ``GeoLens.draw_lens_2d()``, traces
    rays at three field angles (on-axis, 0.707x, 0.99x full field), and
    overlays concentric arcs between the DOE and sensor to illustrate the
    wave-propagation region.

    Args:
        save_name (str, optional): File path to save the figure (used only
            when *ax* is ``None``).  Defaults to ``'./DOELens.png'``.
        depth (float, optional): Object depth [mm] for the traced rays.
            Defaults to ``-10000.0``.
        ax (matplotlib.axes.Axes, optional): Pre-existing axes to draw
            into.  If ``None``, a new figure is created and saved.
        fig (matplotlib.figure.Figure, optional): Pre-existing figure.
            Required when *ax* is provided.

    Returns:
        tuple or None: ``(ax, fig)`` when *ax* was provided; otherwise
            the figure is saved to *save_name* and nothing is returned.
    """
    geolens = self.geolens

    # Draw lens layout
    if ax is None:
        ax, fig = geolens.draw_lens_2d()
        save_fig = True
    else:
        save_fig = False

    # Draw light path
    color_list = ["#CC0000", "#006600", "#0066CC"]
    views = [
        0.0,
        float(np.rad2deg(geolens.rfov) * 0.707),
        float(np.rad2deg(geolens.rfov) * 0.99),
    ]
    arc_radi_list = [0.1, 0.4, 0.7, 1.0, 1.4, 1.8]
    num_rays = 7
    for i, view in enumerate(views):
        # Draw ray tracing
        ray = geolens.sample_point_source_2D(
            depth=depth,
            fov=view,
            num_rays=num_rays,
            entrance_pupil=True,
            wvln=WAVE_RGB[2 - i],
        )
        ray.prop_to(-1.0)

        ray, ray_o_record = geolens.trace(ray=ray, record=True)
        ax, fig = geolens.draw_ray_2d(
            ray_o_record, ax=ax, fig=fig, color=color_list[i]
        )

        # Draw wave propagation
        # Calculate ray center for wave propagation visualization
        ray_center_doe = (
            ((ray.o * ray.is_valid.unsqueeze(-1)).sum(dim=0) / ray.is_valid.sum())
            .cpu()
            .numpy()
        )  # shape [3]
        ray.prop_to(geolens.d_sensor)  # shape [num_rays, 3]
        ray_center_sensor = (
            ((ray.o * ray.is_valid.unsqueeze(-1)).sum(dim=0) / ray.is_valid.sum())
            .cpu()
            .numpy()
        )  # shape [3]

        arc_radi = ray_center_sensor[2] - ray_center_doe[2]
        chief_theta = np.rad2deg(
            np.arctan2(
                ray_center_sensor[0] - ray_center_doe[0],
                ray_center_sensor[2] - ray_center_doe[2],
            )
        )
        theta1 = chief_theta - 10
        theta2 = chief_theta + 10

        for j in arc_radi_list:
            arc_radi_j = arc_radi * j
            arc = patches.Arc(
                (ray_center_sensor[2], ray_center_sensor[0]),
                arc_radi_j,
                arc_radi_j,
                angle=180.0,
                theta1=theta1,
                theta2=theta2,
                color=color_list[i],
            )
            ax.add_patch(arc)

    if save_fig:
        # Save figure
        ax.axis("off")
        ax.set_title("DOE Lens")
        fig.savefig(save_name, bbox_inches="tight", format="png", dpi=600)
        plt.close()
    else:
        return ax, fig
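The wave-propagation overlay relies on two small pieces of geometry. The bundle centre is a mean over valid rays only, and each arc spans the chief-ray direction plus or minus 10 degrees in the x-z plane. The NumPy sketch below reproduces that arithmetic. `masked_center`, `arc_angles`, and `field_angles` are hypothetical helpers, and no plotting is done.

```python
import numpy as np


def field_angles(rfov_rad):
    """On-axis, 0.707x and 0.99x of the half field of view, in degrees."""
    rfov_deg = np.rad2deg(rfov_rad)
    return [0.0, float(rfov_deg * 0.707), float(rfov_deg * 0.99)]


def masked_center(o, valid):
    """Mean ray origin over valid rays; o has shape (N, 3), valid has shape (N,)."""
    weights = valid.astype(float)
    return (o * weights[:, None]).sum(axis=0) / weights.sum()


def arc_angles(center_doe, center_sensor, half_span_deg=10.0):
    """Start/end angles [deg] of an arc centred on the chief-ray direction (x-z plane)."""
    dz = center_sensor[2] - center_doe[2]
    dx = center_sensor[0] - center_doe[0]
    chief = np.rad2deg(np.arctan2(dx, dz))
    return chief - half_span_deg, chief + half_span_deg
```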

get_optimizer

get_optimizer(doe_lr=0.0001, lens_lr=[0.0001, 0.0001, 0.01, 1e-05])

Build an Adam optimiser for joint lens + DOE design.

Collects trainable parameters from both the GeoLens (surface thicknesses, curvatures, conic constants, aspheric coefficients) and the DOE phase profile into a single optimiser with per-group learning rates.
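Per-group learning rates work because `torch.optim.Adam` accepts a list of parameter-group dicts, each with its own `"lr"`. The minimal PyTorch sketch below shows this. The tensors are hypothetical stand-ins for lens and DOE parameters, and the exact group layout returned by `get_optimizer_params` is an assumption.

```python
import torch

# Hypothetical stand-ins for GeoLens and DOE parameters (not DeepLens objects)
d = torch.zeros(3, requires_grad=True)  # thicknesses
c = torch.zeros(3, requires_grad=True)  # curvatures
k = torch.zeros(3, requires_grad=True)  # conic constants
a = torch.zeros(3, 4, requires_grad=True)  # aspheric coefficients
doe_phase = torch.zeros(64, 64, requires_grad=True)

lens_lr = [1e-4, 1e-4, 1e-2, 1e-5]
doe_lr = 1e-4
params = [
    {"params": [d], "lr": lens_lr[0]},
    {"params": [c], "lr": lens_lr[1]},
    {"params": [k], "lr": lens_lr[2]},
    {"params": [a], "lr": lens_lr[3]},
    {"params": [doe_phase], "lr": doe_lr},
]
optimizer = torch.optim.Adam(params)

# One optimisation step on a toy loss; parameters without gradients are skipped
loss = (d**2).sum() + (doe_phase - 1.0).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

A single optimiser with several groups keeps one `step()` call per iteration. Each group can still be tuned to the very different scales of thicknesses, conic constants, and phase values.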

Parameters:

Name Type Description Default
doe_lr float

Learning rate for DOE phase parameters. Defaults to 1e-4.

0.0001
lens_lr list[float]

Per-parameter-group learning rates for the GeoLens, ordered as [thickness_d, curvature_c, conic_k, aspheric_a]. Defaults to [1e-4, 1e-4, 1e-2, 1e-5].

[0.0001, 0.0001, 0.01, 1e-05]

Returns:

Type Description

torch.optim.Adam: Configured optimiser over all trainable parameters.

Source code in deeplens/optics/hybridlens.py
def get_optimizer(
    self, doe_lr=1e-4, lens_lr=[1e-4, 1e-4, 1e-2, 1e-5]
):
    """Build an Adam optimiser for joint lens + DOE design.

    Collects trainable parameters from both the ``GeoLens`` (surface
    thicknesses, curvatures, conic constants, aspheric coefficients) and
    the DOE phase profile into a single optimiser with per-group learning
    rates.

    Args:
        doe_lr (float, optional): Learning rate for DOE phase parameters.
            Defaults to ``1e-4``.
        lens_lr (list[float], optional): Per-parameter-group learning
            rates for the GeoLens, ordered as
            ``[thickness_d, curvature_c, conic_k, aspheric_a]``.
            Defaults to ``[1e-4, 1e-4, 1e-2, 1e-5]``.

    Returns:
        torch.optim.Adam: Configured optimiser over all trainable
            parameters.
    """
    params = []
    params += self.geolens.get_optimizer_params(lrs=lens_lr)
    params += self.doe.get_optimizer_params(lr=doe_lr)

    optimizer = torch.optim.Adam(params)
    return optimizer

Pure wave-optics lens using diffractive surfaces and scalar diffraction propagation.

deeplens.optics.DiffractiveLens

DiffractiveLens(filename=None, device=None)

Bases: Lens

Paraxial diffractive lens in which each element is modelled as a phase surface.

Every optical element (converging lens, DOE, metasurface, …) is represented by a phase function applied to an incoming complex wavefront. Propagation between surfaces uses the Angular Spectrum Method (ASM). This model is simple and fast, but accurate only in the paraxial regime (it does not account for higher-order geometric aberrations).

Attributes:

Name Type Description
surfaces list

Ordered list of diffractive/phase surfaces.

d_sensor Tensor

Axial position of the sensor plane [mm], measured from the first surface (z = 0).

Notes

Operates in torch.float64 by default for numerical stability of the wave-propagation step.
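In this model, a lens element is just a phase function. The NumPy sketch below illustrates this with the paraxial thin-lens phase -pi r^2 / (lambda f) and the same profile wrapped to [0, 2 pi), which is what an idealised single-order Fresnel DOE imposes. Both give the same complex modulation. `thin_lens_phase` and `fresnel_phase` are hypothetical helpers. The library's `ThinLens` and `Fresnel` surfaces may use other sign conventions or quantisation.

```python
import numpy as np


def thin_lens_phase(r, f, wvln):
    """Paraxial thin-lens phase [rad]; r, f and wvln in mm."""
    return -np.pi * r**2 / (wvln * f)


def fresnel_phase(r, f, wvln):
    """The same profile wrapped to [0, 2*pi), as an ideal Fresnel DOE."""
    return np.mod(thin_lens_phase(r, f, wvln), 2 * np.pi)


r = np.linspace(0.0, 2.0, 1001)  # radius across a 4 mm aperture [mm]
unwrapped = thin_lens_phase(r, 50.0, 5.5e-4)  # f = 50 mm, 550 nm
wrapped = fresnel_phase(r, 50.0, 5.5e-4)
```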

Initialize a diffractive lens.

Parameters:

Name Type Description Default
filename str

Path to the lens configuration JSON file. If provided, loads the lens configuration from file. Defaults to None.

None
device str

Computation device ('cpu' or 'cuda'). Defaults to 'cpu'.

None
Source code in deeplens/optics/diffraclens.py
def __init__(
    self,
    filename=None,
    device=None,
):
    """Initialize a diffractive lens.

    Args:
        filename (str, optional): Path to the lens configuration JSON file. If provided, loads the lens configuration from file. Defaults to None.
        device (str, optional): Computation device ('cpu' or 'cuda'). Defaults to 'cpu'.
    """
    super().__init__(device=device)

    # Load lens file
    if filename is not None:
        self.read_lens_json(filename)
    else:
        self.surfaces = []
        # Set default sensor size and resolution if no file provided
        self.sensor_size = (8.0, 8.0)
        self.sensor_res = (2000, 2000)

    self.astype(torch.float64)

    # Use total track length (first element to sensor) as focal length
    if hasattr(self, "d_sensor"):
        self.foclen = float(self.d_sensor)
        self.calc_fov()

load_example1 classmethod

load_example1()

Create an example diffractive lens with a single Fresnel DOE.

Returns:

Name Type Description
DiffractiveLens

A configured diffractive lens with a Fresnel surface at f=50mm, 4mm size, and 4000 resolution.

Source code in deeplens/optics/diffraclens.py
@classmethod
def load_example1(cls):
    """Create an example diffractive lens with a single Fresnel DOE.

    Returns:
        DiffractiveLens: A configured diffractive lens with a Fresnel surface
            at f=50mm, 4mm size, and 4000 resolution.
    """
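    # __init__ accepts only filename and device, so configure the sensor afterwards
    self = cls()
    self.sensor_size = (4.0, 4.0)
    self.sensor_res = (2000, 2000)

    # Diffractive Fresnel DOE
    self.surfaces = [Fresnel(f0=50, d=0, size=4, res=4000)]

    # Sensor position [mm] (float tensor, not an integer one)
    self.d_sensor = torch.tensor(50.0)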
    self.foclen = float(self.d_sensor)
    self.calc_fov()

    self.to(self.device)
    return self

load_example2 classmethod

load_example2()

Create an example diffractive lens with a thin lens and binary DOE combination.

Returns:

Name Type Description
DiffractiveLens

A configured diffractive lens with a ThinLens (f=50mm) and a Binary2 DOE, both at 4mm size and 4000 resolution.

Source code in deeplens/optics/diffraclens.py
@classmethod
def load_example2(cls):
    """Create an example diffractive lens with a thin lens and binary DOE combination.

    Returns:
        DiffractiveLens: A configured diffractive lens with a ThinLens (f=50mm)
            and a Binary2 DOE, both at 4mm size and 4000 resolution.
    """
    self = cls()

    # Thin lens followed by a binary DOE
    self.surfaces = [
        ThinLens(f0=50, d=0, size=4, res=4000),
        Binary2(d=0, size=4, res=4000),
    ]

    # Sensor
    self.d_sensor = torch.tensor(50.0)
    self.sensor_size = (8.0, 8.0)
    self.sensor_res = (2000, 2000)
    self.foclen = float(self.d_sensor)
    self.calc_fov()

    self.to(self.device)
    return self

read_lens_json

read_lens_json(filename)

Load the lens configuration from a JSON file.

Reads lens parameters including sensor configuration and diffractive surfaces from the specified JSON file. If sensor_size or sensor_res are not provided, defaults of 8mm x 8mm and 2000x2000 pixels will be used.
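The sketch below shows a hypothetical lens file and how surface positions follow from it. The keys `info`, `d_sensor`, `sensor_size`, `sensor_res`, `surfaces`, `type`, and `d_next` are those the loader reads. The per-surface fields `f0`, `size`, and `res` are assumptions about what each `init_from_dict` expects. `surface_positions` is a hypothetical helper that reproduces the cumulative-`d` logic: the first surface sits at z = 0, and each later one is offset by the previous surface's `d_next`.

```python
import json

# Hypothetical lens file contents (per-surface fields are assumptions)
config = {
    "info": "Fresnel + binary DOE example",
    "d_sensor": 50.0,  # sensor z-position [mm], measured from the first surface
    "sensor_size": [8.0, 8.0],
    "sensor_res": [2000, 2000],
    "surfaces": [
        {"type": "Fresnel", "f0": 50, "size": 4, "res": 4000, "d_next": 1.0},
        {"type": "Binary2", "size": 4, "res": 4000},
    ],
}


def surface_positions(data):
    """z-position [mm] of each surface, accumulated from the d_next gaps."""
    z, positions = 0.0, []
    for surf in data["surfaces"]:
        positions.append(z)
        z += surf.get("d_next", 0.0)  # the last surface may omit d_next
    return positions


data = json.loads(json.dumps(config))  # round-trip through JSON text
```

Here the second surface sits 1 mm behind the first, and the sensor sits 49 mm behind the second surface.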

Parameters:

Name Type Description Default
filename str

Path to the JSON configuration file.

required
Source code in deeplens/optics/diffraclens.py
def read_lens_json(self, filename):
    """Load the lens configuration from a JSON file.

    Reads lens parameters including sensor configuration and diffractive surfaces
    from the specified JSON file. If sensor_size or sensor_res are not provided,
    defaults of 8mm x 8mm and 2000x2000 pixels will be used.

    Args:
        filename (str): Path to the JSON configuration file.
    """
    assert filename.endswith(".json"), "File must be a .json file."

    with open(filename, "r") as f:
        # Lens general info
        data = json.load(f)
        self.d_sensor = torch.tensor(data["d_sensor"])
        self.lens_info = data.get("info", "None")

        # Read sensor_size with default
        if "sensor_size" in data:
            self.sensor_size = tuple(data["sensor_size"])
        else:
            self.sensor_size = (8.0, 8.0)
            print(
                f"Sensor_size not found in lens file. Using default: {self.sensor_size} mm. "
                "Consider specifying sensor_size in the lens file or using set_sensor()."
            )

        # Read sensor_res with default
        if "sensor_res" in data:
            self.sensor_res = tuple(data["sensor_res"])
        else:
            self.sensor_res = (2000, 2000)
            print(
                f"Sensor_res not found in lens file. Using default: {self.sensor_res} pixels. "
                "Consider specifying sensor_res in the lens file or using set_sensor()."
            )

        # Load diffractive surfaces/elements
        d = 0.0
        self.surfaces = []
        for surf_dict in data["surfaces"]:
            surf_dict["d"] = d

            if surf_dict["type"].lower() == "binary2":
                s = Binary2.init_from_dict(surf_dict)
            elif surf_dict["type"].lower() == "fresnel":
                s = Fresnel.init_from_dict(surf_dict)
            elif surf_dict["type"].lower() == "pixel2d":
                s = Pixel2D.init_from_dict(surf_dict)
            elif surf_dict["type"].lower() == "thinlens":
                s = ThinLens.init_from_dict(surf_dict)
            elif surf_dict["type"].lower() == "zernike":
                s = Zernike.init_from_dict(surf_dict)
            else:
                raise ValueError(
                    f"Diffractive surface type {surf_dict['type']} not implemented."
                )

            self.surfaces.append(s)
            # The last surface may omit d_next (write_lens_json does not write it)
            d += surf_dict.get("d_next", 0.0)

write_lens_json

write_lens_json(filename)

Write the lens configuration to a JSON file.

Saves all lens parameters including sensor configuration and diffractive surface data to the specified file.

Parameters:

Name Type Description Default
filename str

Output path for the JSON file.

required
Source code in deeplens/optics/diffraclens.py
def write_lens_json(self, filename):
    """Write the lens configuration to a JSON file.

    Saves all lens parameters including sensor configuration and
    diffractive surface data to the specified file.

    Args:
        filename (str): Output path for the JSON file.
    """
    assert filename.endswith(".json"), "File must be a .json file."

    # Save lens to a file
    data = {}
    data["info"] = self.lens_info if hasattr(self, "lens_info") else "None"
    data["surfaces"] = []
    data["d_sensor"] = round(self.d_sensor.item(), 3)
    data["sensor_size"] = list(self.sensor_size)
    data["sensor_res"] = self.sensor_res

    # Save diffractive surfaces
    for i, s in enumerate(self.surfaces):
        surf_dict = {"idx": i + 1}

        if isinstance(s, Pixel2D):
            surf_data = s.surf_dict(filename.replace(".json", "_pixel2d.pth"))
        else:
            surf_data = s.surf_dict()

        surf_dict.update(surf_data)

        if i < len(self.surfaces) - 1:
            surf_dict["d_next"] = (
                self.surfaces[i + 1].d.item() - self.surfaces[i].d.item()
            )

        data["surfaces"].append(surf_dict)

    # Save data to a file
    with open(filename, "w") as f:
        json.dump(data, f, indent=4)

__call__

__call__(wave)

Propagate a wave through the lens system.

Source code in deeplens/optics/diffraclens.py
def __call__(self, wave):
    """Propagate a wave through the lens system."""
    return self.forward(wave)

forward

forward(wave)

Propagate a wave through the diffractive lens system to the sensor.

Sequentially applies phase modulation from each diffractive surface, then propagates the wave to the sensor plane using wave optics.

Parameters:

Name Type Description Default
wave ComplexWave

Input wave field entering the lens system.

required

Returns:

Name Type Description
ComplexWave

Output wave field at the sensor plane.

Source code in deeplens/optics/diffraclens.py
def forward(self, wave):
    """Propagate a wave through the diffractive lens system to the sensor.

    Sequentially applies phase modulation from each diffractive surface, then propagates
    the wave to the sensor plane using wave optics.

    Args:
        wave (ComplexWave): Input wave field entering the lens system.

    Returns:
        ComplexWave: Output wave field at the sensor plane.
    """
    # Apply each diffractive surface in sequence
    for surf in self.surfaces:
        wave = surf(wave)

    # Propagate to sensor
    wave = wave.prop_to(self.d_sensor.item())

    return wave

render_mono

render_mono(img, wvln=DEFAULT_WAVE, ks=PSF_KS)

Simulate monochromatic lens blur by convolving an image with the point spread function.
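Because a single PSF is used for the whole image, the blur is shift-invariant and rendering reduces to one 2D convolution. The direct NumPy sketch below produces a same-size output. `conv_psf_same` is a hypothetical stand-in for `conv_psf`. Its zero padding at the borders is an assumption, since the padding mode of `conv_psf` is not documented here.

```python
import numpy as np


def conv_psf_same(img, psf):
    """Blur a single-channel image (H, W) with a normalised PSF (ks, ks), ks odd.

    Same-size output with zero padding. The PSF is flipped so this is a true
    convolution rather than a correlation.
    """
    ks = psf.shape[0]
    r = ks // 2
    padded = np.pad(img, r)
    kernel = psf[::-1, ::-1]
    out = np.zeros_like(img, dtype=float)
    for di in range(ks):
        for dj in range(ks):
            out += kernel[di, dj] * padded[di : di + img.shape[0], dj : dj + img.shape[1]]
    return out
```

A delta PSF leaves the image unchanged, and an off-centre PSF shifts a point source by the PSF's offset from the kernel centre.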

Parameters:

Name Type Description Default
img Tensor

Input image. Shape: (B, 1, H, W)

required
wvln float

Wavelength in [um]. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
ks int

PSF kernel size. Defaults to PSF_KS.

PSF_KS

Returns:

Type Description

torch.Tensor: Rendered image after applying lens blur with shape (B, 1, H, W).

Source code in deeplens/optics/diffraclens.py
def render_mono(self, img, wvln=DEFAULT_WAVE, ks=PSF_KS):
    """Simulate monochromatic lens blur by convolving an image with the point spread function.

    Args:
        img (torch.Tensor): Input image. Shape: (B, 1, H, W)
        wvln (float, optional): Wavelength in [um]. Defaults to DEFAULT_WAVE.
        ks (int, optional): PSF kernel size. Defaults to PSF_KS.

    Returns:
        torch.Tensor: Rendered image after applying lens blur with shape (B, 1, H, W).
    """
    psf = self.psf(depth=float("inf"), wvln=wvln, ks=ks).unsqueeze(0)  # (1, ks, ks)
    img_render = conv_psf(img, psf)
    return img_render

psf

psf(depth=float('inf'), wvln=DEFAULT_WAVE, ks=PSF_KS, upsample_factor=1)

Calculate monochromatic point PSF by wave propagation approach.

Parameters:

Name Type Description Default
depth float

Depth of the point source [mm]; float('inf') models a plane wave from an infinitely distant on-axis source. Defaults to float('inf').

float('inf')
wvln float

Wavelength in micrometers. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
ks int

PSF kernel size. Defaults to PSF_KS.

PSF_KS
upsample_factor int

Upsampling factor to meet Nyquist sampling constraint. Defaults to 1.

1

Returns:

Name Type Description
psf_out tensor

Normalised PSF (sums to 1), shape [ks, ks].

Note

[1] Only the on-axis PSF is usually considered, because the wave-optics model implicitly applies the paraxial approximation. For the phase-shift issue in off-axis propagation, refer to "Modeling off-axis diffraction with the least-sampling angular spectrum method".
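`upsample_factor` exists because the input field must sample the lens phase finely enough. A common rule of thumb for a paraxial lens phase -pi r^2 / (lambda f) is that the phase must change by at most pi between neighbouring samples at the rim r = D / 2. This gives a maximum pitch ps <= lambda * f / D. Whether DeepLens uses exactly this criterion is an assumption. `max_pixel_pitch` is a hypothetical helper.

```python
def max_pixel_pitch(wvln_um, f_mm, aperture_mm):
    """Largest sampling pitch [mm] that resolves a paraxial lens phase.

    The phase gradient 2*pi*r / (lambda*f) times the pitch must stay below pi
    at the rim r = D / 2, which gives ps <= lambda * f / D.
    """
    wvln_mm = wvln_um * 1e-3
    return wvln_mm * f_mm / aperture_mm


# Fresnel example: f = 50 mm, 4 mm aperture, 550 nm -> about 6.9 um maximum pitch,
# comfortably above the 4 mm / 4000 = 1 um pitch of a 4000-sample surface
ps_max = max_pixel_pitch(0.55, 50.0, 4.0)
```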

Source code in deeplens/optics/diffraclens.py
def psf(self, depth=float("inf"), wvln=DEFAULT_WAVE, ks=PSF_KS, upsample_factor=1):
    """Calculate monochromatic point PSF by wave propagation approach.

    Args:
        depth (float, optional): Depth of the point source [mm]; float('inf') models a plane wave. Defaults to float('inf').
        wvln (float, optional): Wavelength in micrometers. Defaults to DEFAULT_WAVE.
        ks (int, optional): PSF kernel size. Defaults to PSF_KS.
        upsample_factor (int, optional): Upsampling factor to meet Nyquist sampling constraint. Defaults to 1.

    Returns:
        psf_out (tensor): Normalised PSF (sums to 1), shape [ks, ks].

    Note:
        [1] Only the on-axis PSF is usually considered, because the wave-optics model implicitly applies the paraxial approximation. For the phase-shift issue in off-axis propagation, refer to "Modeling off-axis diffraction with the least-sampling angular spectrum method".
    """
    # Sample input wave field (We have to sample high resolution to meet Nyquist sampling constraint)
    field_res = [
        self.surfaces[0].res[0] * upsample_factor,
        self.surfaces[0].res[1] * upsample_factor,
    ]
    field_size = [
        self.surfaces[0].res[0] * self.surfaces[0].ps,
        self.surfaces[0].res[1] * self.surfaces[0].ps,
    ]
    if depth == float("inf"):
        inp_wave = ComplexWave.plane_wave(
            phy_size=field_size,
            res=field_res,
            wvln=wvln,
            z=0.0,
        ).to(self.device)
    else:
        inp_wave = ComplexWave.point_wave(
            point=[0.0, 0.0, depth],
            phy_size=field_size,
            res=field_res,
            wvln=wvln,
            z=0.0,
        ).to(self.device)

    # Calculate intensity on the sensor. Shape [H_sensor, W_sensor]
    output_wave = self.forward(inp_wave)
    intensity = output_wave.u.abs() ** 2

    # Interpolate wave to have the same pixel size as the sensor
    factor = output_wave.ps / self.pixel_size
    intensity = F.interpolate(
        intensity,
        scale_factor=(factor, factor),
        mode="bilinear",
        align_corners=False,
    )[0, 0, :, :]

    # Crop or pad wave to the sensor resolution
    intensity_h, intensity_w = intensity.shape[-2:]
    sensor_h, sensor_w = self.sensor_res
    if sensor_h < intensity_h or sensor_w < intensity_w:
        # crop
        start_h = (intensity_h - sensor_h) // 2
        start_w = (intensity_w - sensor_w) // 2
        intensity = intensity[
            start_h : start_h + sensor_h, start_w : start_w + sensor_w
        ]
    elif sensor_h > intensity_h or sensor_w > intensity_w:
        # pad
        pad_top = (sensor_h - intensity_h) // 2
        pad_bottom = sensor_h - intensity_h - pad_top
        pad_left = (sensor_w - intensity_w) // 2
        pad_right = sensor_w - intensity_w - pad_left
        intensity = F.pad(
            intensity,
            (pad_left, pad_right, pad_top, pad_bottom),
            mode="constant",
            value=0,
        )

    # Crop the valid patch from the full-resolution intensity map as the PSF
    coord_c_i = sensor_h // 2  # row of the sensor centre
    coord_c_j = sensor_w // 2  # column of the sensor centre
    intensity = F.pad(
        intensity,
        [ks // 2, ks // 2, ks // 2, ks // 2],
        mode="constant",
        value=0,
    )
    psf = intensity[coord_c_i : coord_c_i + ks, coord_c_j : coord_c_j + ks]

    # Normalize PSF
    psf /= psf.sum()
    psf = torch.flip(psf, [0, 1])

    return diff_float(psf)
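After the intensity is resampled by `output_wave.ps / self.pixel_size`, it is centre-cropped or zero-padded to the sensor resolution. The if/elif branches in the listing apply one operation to both axes. The NumPy sketch below handles each axis independently, so it also covers the mixed case where one axis needs cropping and the other padding. `center_crop_or_pad` is a hypothetical helper.

```python
import numpy as np


def center_crop_or_pad(arr, out_h, out_w):
    """Centre-crop and/or zero-pad a 2D array to (out_h, out_w), per axis."""
    h, w = arr.shape
    # Crop any axis that is too large, keeping the centre
    top = max((h - out_h) // 2, 0)
    left = max((w - out_w) // 2, 0)
    arr = arr[top : top + min(h, out_h), left : left + min(w, out_w)]
    # Zero-pad any axis that is still too small, keeping the content centred
    h, w = arr.shape
    pad_t = (out_h - h) // 2
    pad_l = (out_w - w) // 2
    return np.pad(arr, ((pad_t, out_h - h - pad_t), (pad_l, out_w - w - pad_l)))
```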

draw_layout

draw_layout(save_name='./doelens.png')

Draw the lens layout diagram.

Visualizes the DOE and sensor positions in a 2D layout.

Parameters:

Name Type Description Default
save_name str

Path to save the figure. Defaults to './doelens.png'.

'./doelens.png'
Source code in deeplens/optics/diffraclens.py
def draw_layout(self, save_name="./doelens.png"):
    """Draw the lens layout diagram.

    Visualizes the DOE and sensor positions in a 2D layout.

    Args:
        save_name (str, optional): Path to save the figure. Defaults to './doelens.png'.
    """
    fig, ax = plt.subplots()

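    # Draw each diffractive surface as a dashed line
    for surf in self.surfaces:
        d = surf.d.item()
        surf_l = surf.res[0] * surf.ps  # physical aperture [mm]
        ax.plot(
            [d, d], [-surf_l / 2, surf_l / 2], "orange", linestyle="--", dashes=[1, 1]
        )

    # Draw sensor
    d = self.d_sensor.item()
    sensor_l = self.sensor_size[0]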
    width = 0.2  # Width of the rectangle
    rect = plt.Rectangle(
        (d - width / 2, -sensor_l / 2),
        width,
        sensor_l,
        facecolor="none",
        edgecolor="black",
        linewidth=1,
    )
    ax.add_patch(rect)

    ax.set_aspect("equal")
    ax.axis("off")
    fig.savefig(save_name, dpi=600, bbox_inches="tight")
    plt.close(fig)

draw_psf

draw_psf(depth=DEPTH, ks=PSF_KS, save_name='./psf_doelens.png', log_scale=True, eps=0.0001)

Draw on-axis RGB PSF.

Computes and saves a visualization of the RGB PSF for a given depth.
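The log-scale option compresses the PSF's dynamic range so that faint side lobes become visible. It applies log10 with a small offset `eps`, then min-max normalises to [0, 1]. `log_display` is a hypothetical NumPy helper that mirrors that mapping. It assumes the PSF is not constant, since a constant PSF would divide by zero.

```python
import numpy as np


def log_display(psf, eps=1e-4):
    """Map a PSF to [0, 1] on a log10 scale; eps avoids log(0)."""
    v = np.log10(psf + eps)
    return (v - v.min()) / (v.max() - v.min())


psf = np.array([[0.0, 1e-3], [1e-2, 0.989]])
out = log_display(psf)
```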

Parameters:

Name Type Description Default
depth float

Depth of the point source [mm]. Defaults to DEPTH.

DEPTH
ks int

Size of the PSF kernel in pixels. Defaults to PSF_KS.

PSF_KS
save_name str

Path to save the PSF image. Defaults to './psf_doelens.png'.

'./psf_doelens.png'
log_scale bool

If True, display PSF in log scale. Defaults to True.

True
eps float

Small value for log scale to avoid log(0). Defaults to 1e-4.

0.0001
Source code in deeplens/optics/diffraclens.py
def draw_psf(
    self,
    depth=DEPTH,
    ks=PSF_KS,
    save_name="./psf_doelens.png",
    log_scale=True,
    eps=1e-4,
):
    """Draw on-axis RGB PSF.

    Computes and saves a visualization of the RGB PSF for a given depth.

    Args:
        depth (float, optional): Depth of the point source [mm]. Defaults to DEPTH.
        ks (int, optional): Size of the PSF kernel in pixels. Defaults to PSF_KS.
        save_name (str, optional): Path to save the PSF image. Defaults to './psf_doelens.png'.
        log_scale (bool, optional): If True, display PSF in log scale. Defaults to True.
        eps (float, optional): Small value for log scale to avoid log(0). Defaults to 1e-4.
    """
    psf_rgb = self.psf_rgb(point=[0, 0, depth], ks=ks)

    if log_scale:
        psf_rgb = torch.log10(psf_rgb + eps)
        psf_rgb = (psf_rgb - psf_rgb.min()) / (psf_rgb.max() - psf_rgb.min())
        save_name = save_name.replace(".png", "_log.png")

    save_image(psf_rgb.unsqueeze(0), save_name, normalize=True)

get_optimizer

get_optimizer(lr)

Get optimizer for the lens parameters.

Parameters:

Name Type Description Default
lr float

Learning rate.

required

Returns:

Name Type Description
Optimizer

Optimizer object for lens parameters.

Source code in deeplens/optics/diffraclens.py
def get_optimizer(self, lr):
    """Get optimizer for the lens parameters.

    Args:
        lr (float): Learning rate.

    Returns:
        Optimizer: Optimizer object for lens parameters.
    """
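    params = []
    for surf in self.surfaces:
        # Skip surfaces that expose no trainable parameters
        if hasattr(surf, "get_optimizer_params"):
            params += surf.get_optimizer_params(lr=lr)
    return torch.optim.Adam(params)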

Thin-lens / circle-of-confusion model for simple depth-of-field and bokeh simulation.

deeplens.optics.ParaxialLens

ParaxialLens(foclen, fnum, sensor_size=None, sensor_res=None, device='cpu')

Bases: Lens

Thin-lens / ABCD-matrix model for fast defocus simulation.

Models the circle of confusion (CoC) caused by defocus but not higher-order optical aberrations. Useful as a fast baseline renderer for depth-of-field effects, as commonly used in Blender and similar tools.
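As a reference point for the CoC model, the sketch below implements the standard thin-lens formula from the Wikipedia article cited under `coc`: c = A * |S2 - S1| / S2 * f / (S1 - f), with aperture diameter A = f / N. It is an assumption that `ParaxialLens.coc` uses exactly this form. The sketch also takes positive distances, whereas `ParaxialLens` stores negative object depths (e.g. `d_far = -20000.0`). `coc_thin_lens` is a hypothetical helper.

```python
def coc_thin_lens(f, fnum, focus_dist, obj_dist):
    """Circle-of-confusion diameter [mm] for an ideal thin lens.

    f: focal length [mm]; fnum: f-number; focus_dist (S1) and obj_dist (S2)
    are positive object distances [mm] measured from the lens.
    """
    aperture = f / fnum  # entrance-pupil diameter A
    return aperture * abs(obj_dist - focus_dist) / obj_dist * f / (focus_dist - f)


# 50 mm f/2 focused at 2 m: an object at 1 m gives a CoC of about 0.64 mm
c = coc_thin_lens(50.0, 2.0, 2000.0, 1000.0)
```

Dividing the result by the pixel pitch gives the CoC diameter in pixels, which is what sets the PSF kernel width.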

Attributes:

Name Type Description
foclen float

Focal length [mm].

fnum float

F-number.

sensor_size tuple

Physical sensor size (W, H) [mm].

sensor_res tuple

Pixel resolution (W, H).

pixel_size float

Pixel pitch [mm].

Initialize a paraxial lens.

Parameters:

Name Type Description Default
foclen float

Focal length in [mm].

required
fnum float

F-number.

required
sensor_size tuple

Physical sensor size as (W, H) in [mm]. Defaults to (8.0, 8.0).

None
sensor_res tuple

Sensor resolution as (W, H) in pixels. Defaults to (2000, 2000).

None
device str

Computation device. Defaults to "cpu".

'cpu'
Source code in deeplens/optics/paraxiallens.py
def __init__(self, foclen, fnum, sensor_size=None, sensor_res=None, device="cpu"):
    """Initialize a paraxial lens.

    Args:
        foclen (float): Focal length in [mm].
        fnum (float): F-number.
        sensor_size (tuple, optional): Physical sensor size as (W, H) in [mm]. Defaults to (8.0, 8.0).
        sensor_res (tuple, optional): Sensor resolution as (W, H) in pixels. Defaults to (2000, 2000).
        device (str, optional): Computation device. Defaults to "cpu".
    """
    super(ParaxialLens, self).__init__(device=device)

    # Lens parameters
    self.foclen = foclen  # Focal length [mm]
    self.fnum = fnum

    # Sensor size and resolution with defaults
    if sensor_size is None:
        sensor_size = (8.0, 8.0)
        print(
            f"Sensor_size not provided. Using default: {sensor_size} mm. "
            "Use set_sensor() to change."
        )
    if sensor_res is None:
        sensor_res = (2000, 2000)
        print(
            f"Sensor_res not provided. Using default: {sensor_res} pixels. "
            "Use set_sensor() to change."
        )

    self.sensor_size = sensor_size
    self.sensor_res = sensor_res
    self.pixel_size = self.sensor_size[0] / self.sensor_res[0]  # Pixel size [mm]

    self.d_far = -20000.0
    self.d_close = -200.0
    self.refocus(foc_dist=-20000)

refocus

refocus(foc_dist)

Refocus the lens to a given object distance.

Parameters:

Name Type Description Default
foc_dist float

Focus distance in [mm], negative in object space (e.g. -20000.0). The object must lie beyond the front focal point, i.e. foc_dist < -foclen.

required

Raises:

Type Description
AssertionError

If foc_dist >= -self.foclen.

Source code in deeplens/optics/paraxiallens.py
def refocus(self, foc_dist):
    """Refocus the lens to a given object distance.

    Args:
        foc_dist (float): Focus distance in [mm], negative in object
            space.  The object must lie beyond the front focal point,
            i.e. ``foc_dist < -self.foclen``.

    Raises:
        AssertionError: If *foc_dist* >= ``-self.foclen``.
    """
    assert foc_dist < -self.foclen, "Focus distance is too close."
    self.foc_dist = foc_dist
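`refocus` only stores the focus distance. The sketch below shows why the object has to lie beyond the focal length. It uses the thin-lens equation 1/f = 1/v - 1/u with u negative in object space, the sign convention implied by the default depths. An object closer than f yields a negative image distance v, i.e. a virtual image that no sensor position can capture. `image_distance` is a hypothetical helper and not part of the class.

```python
def image_distance(f, obj_depth):
    """Thin-lens image distance v [mm] for an object at obj_depth [mm] (u < 0).

    From 1/f = 1/v - 1/u it follows that v = f * u / (u + f).
    """
    return f * obj_depth / (obj_depth + f)


v_far = image_distance(50.0, -20000.0)  # slightly beyond f, about 50.125 mm
v_near = image_distance(50.0, -30.0)  # object inside f: negative (virtual image)
```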

psf

psf(points, ks=PSF_KS, psf_type='gaussian', **kwargs)

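Compute defocus PSFs from the circle of confusion (CoC). With psf_type="pillbox" the PSF is a uniform disk of diameter CoC; with "gaussian" (default) it is a Gaussian whose standard deviation equals the CoC radius, truncated to the same disk.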

Parameters:

Name Type Description Default
points Tensor

Points of the object. Shape [N, 3] or [3].

required
ks int

Kernel size.

PSF_KS
psf_type str

PSF type. "gaussian" or "pillbox".

'gaussian'
**kwargs

Additional arguments for psf(). Currently not used.

{}

Returns:

Name Type Description
psf Tensor

PSF kernels. Shape [ks, ks] or [N, ks, ks].

Source code in deeplens/optics/paraxiallens.py
def psf(self, points, ks=PSF_KS, psf_type="gaussian", **kwargs):
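    """Compute defocus PSFs from the circle of confusion (CoC).

    With ``psf_type="pillbox"`` the PSF is a uniform disk of diameter CoC;
    with ``"gaussian"`` (default) it is a Gaussian whose standard deviation
    equals the CoC radius, truncated to the same disk.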

    Args:
        points (torch.Tensor): Points of the object. Shape [N, 3] or [3].
        ks (int): Kernel size.
        psf_type (str): PSF type. "gaussian" or "pillbox".
        **kwargs: Additional arguments for psf(). Currently not used.

    Returns:
        psf (torch.Tensor): PSF kernels. Shape [ks, ks] or [N, ks, ks].
    """
    points = points.to(self.device)

    # Handle single point vs multiple points
    if len(points.shape) == 1:
        points = points.unsqueeze(0)
        single_point = True
    else:
        single_point = False

    # Calculate circle of confusion for each point
    depths = points[:, 2]  # Shape [N]
    coc_values = self.coc(depths)  # Shape [N]

    # Convert CoC from mm to pixels and add minimum value for numerical stability
    coc_pixel = torch.clamp(
        coc_values / self.pixel_size, min=0.5
    )  # Shape [N], minimum 0.5 pixels
    coc_pixel = coc_pixel.unsqueeze(-1).unsqueeze(-1)  # Shape [N, 1, 1], broadcasts with [ks, ks]
    coc_pixel_radius = coc_pixel / 2

    # Create coordinate meshgrid
    x, y = torch.meshgrid(
        torch.linspace(-ks / 2 + 1 / 2, ks / 2 - 1 / 2, ks, device=self.device),
        torch.linspace(-ks / 2 + 1 / 2, ks / 2 - 1 / 2, ks, device=self.device),
        indexing="xy",
    )
    distance_sq = x**2 + y**2

    # Create PSF
    if psf_type == "gaussian":
        # Gaussian PSF
        psf = torch.exp(-distance_sq / (2 * coc_pixel_radius**2)) / (
            2 * np.pi * coc_pixel_radius**2
        )
    elif psf_type == "pillbox":
        # Pillbox PSF
        psf = torch.ones_like(x)
    else:
        raise ValueError(f"Invalid PSF type: {psf_type}")

    # Apply circular mask
    psf_mask = distance_sq < coc_pixel_radius**2
    psf = psf * psf_mask

    # Normalize PSF to sum to 1
    psf = psf / (psf.sum(dim=(-1, -2), keepdim=True) + EPSILON)

    if single_point:
        psf = psf.squeeze(0)

    return psf
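The kernel construction in psf() can be followed step by step without PyTorch. The plain-Python sketch below handles a single point whose CoC is already known; defocus_kernel is a hypothetical helper, not part of deeplens, and pixel_size_mm is assumed to be in the same [mm] units as coc().

```python
import math


def defocus_kernel(coc_mm, pixel_size_mm, ks, psf_type="gaussian"):
    """Single-point sketch of the paraxial defocus PSF (hypothetical helper).

    The kernel is confined to a disk whose diameter is the circle of
    confusion (CoC), clamped to at least 0.5 pixels as in psf().
    """
    if psf_type not in ("gaussian", "pillbox"):
        raise ValueError(f"Invalid PSF type: {psf_type}")
    radius = max(coc_mm / pixel_size_mm, 0.5) / 2.0  # CoC radius [pixels]
    # Pixel-centre coordinates, the same values as the linspace grid in psf()
    coords = [-ks / 2 + 0.5 + i for i in range(ks)]
    kernel = []
    for y in coords:
        row = []
        for x in coords:
            d2 = x * x + y * y
            if d2 >= radius**2:
                row.append(0.0)  # outside the CoC disk
            elif psf_type == "gaussian":
                row.append(math.exp(-d2 / (2 * radius**2)))  # sigma = CoC radius
            else:
                row.append(1.0)  # pillbox: uniform inside the disk
        kernel.append(row)
    total = sum(map(sum, kernel))
    # Normalise to unit sum; an all-zero kernel is returned unchanged
    return [[v / total for v in row] for row in kernel] if total > 0 else kernel


pillbox = defocus_kernel(coc_mm=0.02, pixel_size_mm=0.004, ks=11, psf_type="pillbox")
print(sum(v > 0 for row in pillbox for v in row))  # 21 pixels inside the 5-pixel-wide disk
```

With an even ks, a point near focus (CoC clamped to 0.5 pixels, i.e. radius 0.25) has no pixel centre inside the disk, because the nearest centres sit at (±0.5, ±0.5). The sketch then returns an all-zero kernel, and psf() above also yields zeros after dividing by EPSILON; an odd ks keeps the central pixel.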

coc

coc(depth)

Calculate circle of confusion (CoC) [mm].

Parameters:

Name Type Description Default
depth Tensor

Object depth [mm], negative on the object side; values are clamped to [d_far, d_close]. Shape [B].

required

Returns:

Name Type Description
coc Tensor

Circle of confusion. Shape [B].

Reference

[1] https://en.wikipedia.org/wiki/Circle_of_confusion

Source code in deeplens/optics/paraxiallens.py
def coc(self, depth):
    """Calculate circle of confusion (CoC) [mm].

    Args:
        depth (torch.Tensor): Depth of the object. Shape [B].

    Returns:
        coc (torch.Tensor): Circle of confusion. Shape [B].

    Reference:
        [1] https://en.wikipedia.org/wiki/Circle_of_confusion
    """
    foc_dist = torch.tensor(
        self.foc_dist, device=self.device, dtype=depth.dtype
    ).abs()
    foclen = self.foclen
    fnum = self.fnum

    depth = torch.clamp(depth, self.d_far, self.d_close)
    depth = torch.abs(depth)

    # Calculate circle of confusion diameter, [mm]
    part1 = torch.abs(depth - foc_dist) / depth
    part2 = foclen**2 / (fnum * (foc_dist - foclen))
    coc = part1 * part2

    return coc
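coc() evaluates the thin-lens relation CoC = |S - S_f| / S * f^2 / (N (S_f - f)), with S the absolute object distance, S_f the absolute focus distance, f the focal length and N the f-number. A scalar sketch with illustrative numbers (a 50 mm f/2 lens focused at 2 m); coc_mm is a hypothetical name, and d_far/d_close default to the values set in ParaxialLens.__init__:

```python
def coc_mm(depth, foc_dist, foclen, fnum, d_far=-20000.0, d_close=-200.0):
    """Circle of confusion diameter [mm] for a paraxial lens (scalar sketch).

    depth and foc_dist are negative object-side distances [mm], following the
    sign convention of ParaxialLens; depth is clamped to [d_far, d_close].
    """
    depth = abs(min(max(depth, d_far), d_close))
    foc = abs(foc_dist)
    return abs(depth - foc) / depth * foclen**2 / (fnum * (foc - foclen))


# 50 mm f/2 lens focused at 2 m, object at 4 m
print(round(coc_mm(-4000.0, -2000.0, 50.0, 2.0), 4))  # 0.3205
```

On a sensor with 4 µm pixels this 0.32 mm blur spans about 80 pixels, so a smaller ks in psf() would truncate the disk.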

dof

dof(depth)

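Calculate depth of field [mm] with the thin-lens approximation DoF = 2 N c (m + 1) / (m^2 - (N c / f)^2), where N is the f-number, f the focal length, m = f / (|depth| - f) the magnification, and c the circle of confusion. Here c is the CoC at the given depth as returned by coc(), not a fixed acceptable-CoC limit, so an object at the focus distance gets a DoF of 0.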

Parameters:

Name Type Description Default
depth Tensor

Object depth [mm], negative on the object side; values are clamped to [d_far, d_close]. Shape [B].

required

Returns:

Name Type Description
dof Tensor

Depth of field. Shape [B].

Reference

[1] https://en.wikipedia.org/wiki/Depth_of_field

Source code in deeplens/optics/paraxiallens.py
def dof(self, depth):
    """Calculate depth of field [mm].

    Args:
        depth (torch.Tensor): Depth of the object. Shape [B].

    Returns:
        dof (torch.Tensor): Depth of field. Shape [B].

    Reference:
        [1] https://en.wikipedia.org/wiki/Depth_of_field
    """
    depth = torch.clamp(depth, self.d_far, self.d_close)
    depth_abs = torch.abs(depth)

    foclen = self.foclen
    fnum = self.fnum

    # Magnification factor
    m = foclen / (depth_abs - foclen)

    # CoC, [mm]
    coc = self.coc(depth)

    # Depth of field, [mm]
    part1 = 2 * fnum * coc * (m + 1)
    part2 = m**2 - (fnum * coc / foclen) ** 2
    dof = part1 / part2

    return dof
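For comparison, the textbook use of the same approximation plugs in a fixed acceptable CoC c (0.03 mm is a common full-frame value) rather than the per-depth CoC from coc(). The scalar sketch below shows that textbook variant, which is not what dof() computes; dof_textbook is a hypothetical name and the lens values are illustrative.

```python
def dof_textbook(subject_dist, foclen, fnum, coc_limit):
    """Total depth of field [mm] from 2 N c (m + 1) / (m**2 - (N c / f)**2),
    using a fixed acceptable CoC c = coc_limit [mm].

    subject_dist is the absolute subject distance [mm].
    """
    m = foclen / (subject_dist - foclen)  # thin-lens magnification
    num = 2 * fnum * coc_limit * (m + 1)
    den = m**2 - (fnum * coc_limit / foclen) ** 2
    if den <= 0:
        return float("inf")  # at or beyond the hyperfocal distance
    return num / den


# 50 mm f/2 lens focused at 2 m with c = 0.03 mm
print(round(dof_textbook(2000.0, 50.0, 2.0, 0.03), 1))  # 187.6
```

The denominator reaches zero exactly at the hyperfocal distance f^2 / (N c) + f, which is why the sketch returns infinity there.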

psf_rgb

psf_rgb(points, ks=PSF_KS, **kwargs)

Compute RGB PSF by replicating the monochrome PSF across three channels.

The paraxial model is achromatic, so all channels share the same PSF.

Parameters:

Name Type Description Default
points Tensor

Point source positions, shape [N, 3].

required
ks int

Kernel size. Defaults to PSF_KS.

PSF_KS
**kwargs

Forwarded to psf(); psf_type is fixed to "gaussian" and must not be passed.

{}

Returns:

Type Description

torch.Tensor: RGB PSFs, shape [N, 3, ks, ks].

Source code in deeplens/optics/paraxiallens.py
def psf_rgb(self, points, ks=PSF_KS, **kwargs):
    """Compute RGB PSF by replicating the monochrome PSF across three channels.

    The paraxial model is achromatic, so all channels share the same PSF.

    Args:
        points (torch.Tensor): Point source positions, shape ``[N, 3]``.
        ks (int, optional): Kernel size. Defaults to ``PSF_KS``.
        **kwargs: Forwarded to :meth:`psf`.

    Returns:
        torch.Tensor: RGB PSFs, shape ``[N, 3, ks, ks]``.
    """
    psf = self.psf(points, ks=ks, psf_type="gaussian", **kwargs)
    return psf.unsqueeze(1).repeat(1, 3, 1, 1)

psf_map

psf_map(grid=(5, 5), ks=PSF_KS, depth=DEPTH, **kwargs)

Compute a spatially uniform monochrome PSF map.

Because the paraxial model has no spatially-varying aberrations, every grid position receives the same on-axis PSF.

Parameters:

Name Type Description Default
grid tuple

Grid dimensions (rows, cols). Defaults to (5, 5).

(5, 5)
ks int

Kernel size. Defaults to PSF_KS.

PSF_KS
depth float

Object depth [mm]. Defaults to DEPTH.

DEPTH
**kwargs

Forwarded to psf(); psf_type is fixed to "gaussian" and must not be passed.

{}

Returns:

Type Description

torch.Tensor: PSF map, shape [rows, cols, 1, ks, ks].

Source code in deeplens/optics/paraxiallens.py
def psf_map(self, grid=(5, 5), ks=PSF_KS, depth=DEPTH, **kwargs):
    """Compute a spatially-uniform monochrome PSF map.

    Because the paraxial model has no spatially-varying aberrations, every
    grid position receives the same on-axis PSF.

    Args:
        grid (tuple, optional): Grid dimensions ``(rows, cols)``.
            Defaults to ``(5, 5)``.
        ks (int, optional): Kernel size. Defaults to ``PSF_KS``.
        depth (float, optional): Object depth [mm]. Defaults to ``DEPTH``.
        **kwargs: Forwarded to :meth:`psf`.

    Returns:
        torch.Tensor: PSF map, shape ``[rows, cols, 1, ks, ks]``.
    """
    points = torch.tensor([[0, 0, depth]], device=self.device)
    psf = self.psf(points=points, ks=ks, psf_type="gaussian", **kwargs)
    psf_map = psf.unsqueeze(0).unsqueeze(0).repeat(grid[0], grid[1], 1, 1, 1)
    return psf_map

psf_dp

psf_dp(points, ks=PSF_KS)

Generate dual-pixel PSF for left and right sub-apertures.

This function generates separate PSFs for left and right sub-apertures of a dual pixel sensor, which enables depth estimation and improved autofocus capabilities.

Parameters:

Name Type Description Default
points Tensor

Input tensor with shape [N, 3], where columns are [x, y, z] coordinates.

required
ks int

Kernel size for PSF generation.

PSF_KS

Returns:

Name Type Description
tuple

(left_psf, right_psf) where each PSF tensor has shape [N, ks, ks].

Source code in deeplens/optics/paraxiallens.py
def psf_dp(self, points, ks=PSF_KS):
    """Generate dual-pixel PSF for left and right sub-apertures.

    This function generates separate PSFs for left and right sub-apertures of a dual pixel sensor,
    which enables depth estimation and improved autofocus capabilities.

    Args:
        points (torch.Tensor): Input tensor with shape [N, 3], where columns are [x, y, z] coordinates.
        ks (int): Kernel size for PSF generation.

    Returns:
        tuple: (left_psf, right_psf) where each PSF tensor has shape [N, ks, ks].
    """
    N = points.shape[0]
    depth = points[:, 2]

    # Get the base PSF
    psf_base = self.psf(points, ks=ks, psf_type="gaussian")
    device = psf_base.device

    # Create left and right masks for dual pixel simulation
    l_mask = torch.ones((ks, ks), device=device)
    r_mask = torch.ones((ks, ks), device=device)

    # Split aperture vertically (left half and right half)
    l_pixel, r_pixel = ks // 2, ks // 2 + 1
    l_mask[:, 0:l_pixel] = 0  # Block right side for left PSF
    r_mask[:, r_pixel:] = 0  # Block left side for right PSF

    # Determine focus positions
    depth = depth.to(device)
    foc_dist = torch.tensor(self.foc_dist, device=device, dtype=depth.dtype)
    near_focus_pos = depth > foc_dist  # Shape [N]

    # Apply masks based on focus position (vectorized)
    # For near focus: left PSF gets left mask, right PSF gets right mask
    # For far focus: masks are swapped to create opposite asymmetry
    nfp = near_focus_pos.unsqueeze(-1).unsqueeze(-1)  # [N, 1, 1]
    mask_l = torch.where(nfp, l_mask, r_mask)  # [N, ks, ks]
    mask_r = torch.where(nfp, r_mask, l_mask)  # [N, ks, ks]
    psf_l = psf_base * mask_l
    psf_r = psf_base * mask_r

    # Normalize PSFs
    psf_l = psf_l / (psf_l.sum(dim=(-1, -2), keepdim=True) + EPSILON)
    psf_r = psf_r / (psf_r.sum(dim=(-1, -2), keepdim=True) + EPSILON)

    return psf_l, psf_r
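The left/right split in psf_dp() only masks kernel columns and renormalises. The plain-Python sketch below reproduces that logic on a list-of-rows kernel; split_dual_pixel is a hypothetical helper, and its near_focus flag plays the role of depth > foc_dist in the code above.

```python
def split_dual_pixel(psf, near_focus):
    """Split a square PSF into (left, right) dual-pixel PSFs (plain-Python sketch)."""
    ks = len(psf)
    half = ks // 2
    keep_a = [c >= half for c in range(ks)]  # like l_mask: columns [0, ks//2) zeroed
    keep_b = [c <= half for c in range(ks)]  # like r_mask: columns (ks//2, ks) zeroed
    # Near the lens: left gets l_mask; beyond the focus plane the masks swap
    keep_l, keep_r = (keep_a, keep_b) if near_focus else (keep_b, keep_a)

    def masked(keep):
        out = [[v if keep[c] else 0.0 for c, v in enumerate(row)] for row in psf]
        total = sum(map(sum, out))
        # Renormalise each sub-aperture PSF to unit sum
        return [[v / total for v in row] for row in out] if total > 0 else out

    return masked(keep_l), masked(keep_r)


uniform = [[1 / 25] * 5 for _ in range(5)]
left, right = split_dual_pixel(uniform, near_focus=True)
print(left[0], right[0])
```

For a symmetric base PSF the two outputs are mirror images of each other, and moving the point to the other side of the focus plane swaps them.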

psf_rgb_dp

psf_rgb_dp(points, ks=PSF_KS)

Compute RGB dual-pixel PSFs for left and right sub-apertures.

Replicates the monochrome dual-pixel PSFs across three colour channels.

Parameters:

Name Type Description Default
points Tensor

Point source positions, shape [N, 3].

required
ks int

Kernel size. Defaults to PSF_KS.

PSF_KS

Returns:

Name Type Description
tuple

(psf_left, psf_right) each of shape [N, 3, ks, ks].

Source code in deeplens/optics/paraxiallens.py
def psf_rgb_dp(self, points, ks=PSF_KS):
    """Compute RGB dual-pixel PSFs for left and right sub-apertures.

    Replicates the monochrome dual-pixel PSFs across three colour channels.

    Args:
        points (torch.Tensor): Point source positions, shape ``[N, 3]``.
        ks (int, optional): Kernel size. Defaults to ``PSF_KS``.

    Returns:
        tuple: ``(psf_left, psf_right)`` each of shape ``[N, 3, ks, ks]``.
    """
    psf_l, psf_r = self.psf_dp(points, ks=ks)
    psf_l = psf_l.unsqueeze(1).repeat(1, 3, 1, 1)
    psf_r = psf_r.unsqueeze(1).repeat(1, 3, 1, 1)
    return psf_l, psf_r

psf_map_dp

psf_map_dp(grid=(5, 5), ks=PSF_KS, depth=DEPTH, **kwargs)

Compute spatially uniform dual-pixel PSF maps. Every grid position receives the same on-axis (left, right) PSF pair.

Parameters:

Name Type Description Default
grid tuple

Grid dimensions (rows, cols). Defaults to (5, 5).

(5, 5)
ks int

Kernel size. Defaults to PSF_KS.

PSF_KS
depth float

Object depth [mm]. Defaults to DEPTH.

DEPTH
**kwargs

Forwarded to psf_dp(), which accepts no extra keyword arguments, so passing any raises a TypeError.

{}

Returns:

Name Type Description
tuple

(psf_map_left, psf_map_right) each of shape [rows, cols, 1, ks, ks].

Source code in deeplens/optics/paraxiallens.py
def psf_map_dp(self, grid=(5, 5), ks=PSF_KS, depth=DEPTH, **kwargs):
    """Compute spatially-uniform dual-pixel PSF maps.

    Args:
        grid (tuple, optional): Grid dimensions ``(rows, cols)``.
            Defaults to ``(5, 5)``.
        ks (int, optional): Kernel size. Defaults to ``PSF_KS``.
        depth (float, optional): Object depth [mm]. Defaults to ``DEPTH``.
        **kwargs: Forwarded to :meth:`psf_dp`.

    Returns:
        tuple: ``(psf_map_left, psf_map_right)`` each of shape
            ``[rows, cols, 1, ks, ks]``.
    """
    points = torch.tensor([[0, 0, depth]], device=self.device)
    psf_l, psf_r = self.psf_dp(points, ks=ks, **kwargs)
    psf_map_l = psf_l.unsqueeze(0).unsqueeze(0).repeat(grid[0], grid[1], 1, 1, 1)
    psf_map_r = psf_r.unsqueeze(0).unsqueeze(0).repeat(grid[0], grid[1], 1, 1, 1)
    return psf_map_l, psf_map_r

render_rgbd

render_rgbd(img_obj, depth_map, method='psf_patch', **kwargs)

Occlusion-aware RGBD rendering for paraxial lens.

Uses back-to-front layered compositing to prevent color bleeding at depth discontinuities. Since paraxial lenses have no spatially varying aberrations, all methods (psf_patch, psf_map, psf_pixel) produce identical results; the method parameter is accepted for API compatibility but ignored.

Parameters:

Name Type Description Default
img_obj tensor

Object image. Shape [B, C, H, W].

required
depth_map tensor

Depth map [mm] with positive values; a ValueError is raised if any value is negative. Shape [B, 1, H, W]; a [B, H, W] map is also accepted.

required
method str

Ignored (no spatial variation). Defaults to "psf_patch".

'psf_patch'
**kwargs

Additional keyword arguments:
- psf_ks (int): PSF kernel size. Defaults to PSF_KS.
- num_layers (int): Number of depth layers. Defaults to 16.
- depth_min (float): Minimum depth. Defaults to depth_map.min().
- depth_max (float): Maximum depth. Defaults to depth_map.max().

{}

Returns:

Name Type Description
img_render tensor

Rendered image. Shape [B, C, H, W].

Reference

[1] "Dr.Bokeh: DiffeRentiable Occlusion-aware Bokeh Rendering", CVPR 2024.

Source code in deeplens/optics/paraxiallens.py
def render_rgbd(self, img_obj, depth_map, method="psf_patch", **kwargs):
    """Occlusion-aware RGBD rendering for paraxial lens.

    Uses back-to-front layered compositing to prevent color bleeding at depth
    discontinuities. Since paraxial lenses have no spatially varying
    aberrations, all methods (psf_patch, psf_map, psf_pixel) produce
    identical results; the `method` parameter is accepted for API
    compatibility but ignored.

    Args:
        img_obj (tensor): Object image. Shape [B, C, H, W].
        depth_map (tensor): Depth map [mm]. Shape [B, 1, H, W]. Values should be positive.
        method (str, optional): Ignored (no spatial variation). Defaults to "psf_patch".
        **kwargs: Additional keyword arguments:
            - psf_ks (int): PSF kernel size. Defaults to PSF_KS.
            - num_layers (int): Number of depth layers. Defaults to 16.
            - depth_min (float): Minimum depth. Defaults to depth_map.min().
            - depth_max (float): Maximum depth. Defaults to depth_map.max().

    Returns:
        img_render (tensor): Rendered image. Shape [B, C, H, W].

    Reference:
        [1] "Dr.Bokeh: DiffeRentiable Occlusion-aware Bokeh Rendering", CVPR 2024.
    """
    if depth_map.min() < 0:
        raise ValueError("Depth map should be positive.")

    if len(depth_map.shape) == 3:
        depth_map = depth_map.unsqueeze(1)  # [B, H, W] -> [B, 1, H, W]

    psf_ks = kwargs.get("psf_ks", PSF_KS)
    num_layers = kwargs.get("num_layers", 16)
    depth_min = kwargs.get("depth_min", depth_map.min())
    depth_max = kwargs.get("depth_max", depth_map.max())

    # Sample depth layers
    disp_ref, depths_ref = self._sample_depth_layers(depth_min, depth_max, num_layers)

    # Compute PSF at each depth layer (spatially invariant, so patch_center=(0,0))
    points = torch.stack(
        [
            torch.zeros_like(depths_ref),
            torch.zeros_like(depths_ref),
            depths_ref,
        ],
        dim=-1,
    )
    psfs = self.psf_rgb(points=points, ks=psf_ks)  # [num_layers, 3, ks, ks]

    # Occlusion-aware rendering
    img_render = conv_psf_occlusion(img_obj, -depth_map, psfs, depths_ref)
    return img_render
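The internals of conv_psf_occlusion are not shown on this page. As a rough, generic illustration of back-to-front layered compositing (not a transcription of that function or of the Dr.Bokeh method), the 1-D plain-Python sketch below assigns pixels to their nearest reference depth, blurs each layer's premultiplied colour and coverage with that layer's kernel, and composites the layers from far to near with the "over" operator. All names are hypothetical, and depths are positive distances.

```python
def blur1d(signal, kernel):
    """Zero-padded 1-D convolution with a symmetric, odd-length kernel."""
    r = len(kernel) // 2
    n = len(signal)
    out = []
    for i in range(n):
        acc = 0.0
        for k, w in enumerate(kernel):
            j = i + k - r
            if 0 <= j < n:
                acc += w * signal[j]
        out.append(acc)
    return out


def composite_layers(img, depth, layer_depths, kernels):
    """Back-to-front 'over' compositing of blurred depth layers (1-D sketch).

    img, depth: per-pixel colour and positive distance [mm].
    layer_depths: reference distances; kernels[i] blurs layer i.
    """
    # Assign every pixel to its nearest reference layer
    labels = [
        min(range(len(layer_depths)), key=lambda idx: abs(layer_depths[idx] - d))
        for d in depth
    ]
    out = [0.0] * len(img)
    # Farthest layer first (larger distance = farther from the lens)
    for i in sorted(range(len(layer_depths)), key=lambda idx: -layer_depths[idx]):
        alpha = [1.0 if lab == i else 0.0 for lab in labels]  # layer coverage
        a_blur = blur1d(alpha, kernels[i])
        c_blur = blur1d([a * c for a, c in zip(alpha, img)], kernels[i])  # premultiplied
        # 'over': the nearer layer is drawn on top and attenuates what lies behind
        out = [c + (1.0 - a) * o for c, a, o in zip(c_blur, a_blur, out)]
    return out


# Sharp foreground (500 mm) next to a blurred background (5000 mm)
img = [1.0, 1.0, 1.0, 0.0, 0.0, 0.0]
depth = [5000.0] * 3 + [500.0] * 3
print(composite_layers(img, depth, [500.0, 5000.0], [[1.0], [1 / 3, 1 / 3, 1 / 3]]))
```

A naive per-layer blur-and-sum would give pixel 3 a value of 1/3 here, i.e. background colour bleeding into the sharp foreground; the "over" step keeps it at 0. The sketch still darkens background pixels next to the occluder because it makes no attempt to reconstruct the hidden background.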

render_rgbd_dp

render_rgbd_dp(rgb_img, depth)

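Render left and right dual-pixel images from an RGB image and a depth map.

PSFs are evaluated on-axis at 10 reference depths spanning the depth range, with kernel size PSF_KS (both fixed), and passed to conv_psf_depth_interp together with the depth map.

Parameters:

Name Type Description Default
rgb_img tensor

Object RGB image. Shape [B, 3, H, W].

required
depth tensor

Depth map [mm]. Shape [B, 1, H, W]. If any value is positive, the whole map is negated, so pass depths with a consistent sign.

required

Returns:

Name Type Description
img_left tensor

Left sub-aperture image. Shape [B, 3, H, W].

img_right tensor

Right sub-aperture image. Shape [B, 3, H, W].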

Source code in deeplens/optics/paraxiallens.py
def render_rgbd_dp(self, rgb_img, depth):
    """Render RGBD image with dual-pixel PSF.

    Args:
        rgb_img (tensor): [B, 3, H, W]
        depth (tensor): [B, 1, H, W]

    Returns:
        img_left (tensor): [B, 3, H, W]
        img_right (tensor): [B, 3, H, W]
    """
    # Convert depth to negative values
    if (depth > 0).any():
        depth = -depth

    depth_min = depth.min()
    depth_max = depth.max()
    num_depth = 10
    patch_center = (0.0, 0.0)
    psf_ks = PSF_KS

    # Calculate dual-pixel PSF at reference depths
    depths_ref = torch.linspace(depth_min, depth_max, num_depth, device=self.device)
    points = torch.stack(
        [
            torch.full_like(depths_ref, patch_center[0]),
            torch.full_like(depths_ref, patch_center[1]),
            depths_ref,
        ],
        dim=-1,
    )
    psfs_left, psfs_right = self.psf_rgb_dp(
        points=points, ks=psf_ks
    )  # shape [num_depth, 3, ks, ks]

    # Render dual-pixel image with PSF convolution and depth interpolation
    img_left = conv_psf_depth_interp(rgb_img, depth, psfs_left, depths_ref)
    img_right = conv_psf_depth_interp(rgb_img, depth, psfs_right, depths_ref)
    return img_left, img_right
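render_rgbd_dp() computes PSFs only at 10 reference depths and delegates the blending to conv_psf_depth_interp, whose implementation is not shown here. One common approach, given as an assumption rather than a description of that function, is to linearly interpolate between the renders of the two reference depths that bracket each pixel's depth. depth_interp_weights is a hypothetical helper:

```python
import bisect


def depth_interp_weights(depth, depths_ref):
    """Linear-interpolation weights {layer_index: weight} for one pixel.

    depths_ref must be sorted ascending (e.g. negative depths from depth_min
    to depth_max); depths outside the range snap to the end layers.
    """
    if depth <= depths_ref[0]:
        return {0: 1.0}
    if depth >= depths_ref[-1]:
        return {len(depths_ref) - 1: 1.0}
    hi = bisect.bisect_right(depths_ref, depth)  # first reference strictly above depth
    lo = hi - 1
    t = (depth - depths_ref[lo]) / (depths_ref[hi] - depths_ref[lo])
    return {lo: 1.0 - t, hi: t}


# 10 reference depths, like torch.linspace(-5000, -500, 10)
refs = [-5000.0 + 500.0 * i for i in range(10)]
print(depth_interp_weights(-1750.0, refs))
```

Each output pixel would then be the weighted sum of the images blurred with the PSFs of the listed layers.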

Neural surrogate that wraps a GeoLens with an MLP to predict PSFs. Useful for fast, differentiable PSF evaluation during end-to-end training.

deeplens.optics.PSFNetLens

PSFNetLens(lens_path, in_chan=3, psf_chan=3, model_name='mlp_conv', kernel_size=64)

Bases: Lens

Neural surrogate lens that predicts PSFs via a small MLP/MLPConv network.

Wraps a GeoLens (deeplens.optics.geolens.GeoLens) with a neural network trained to predict RGB PSFs from (fov, depth, focus_distance) inputs. After training, PSF prediction is ~100× faster than ray tracing, making it suitable for real-time applications and large-scale optimisation.

Attributes:

Name Type Description
lens GeoLens

The underlying refractive lens (used for training data generation and for sensor metadata).

psfnet Module

Neural network for PSF prediction.

pixel_size float

Pixel pitch [mm] (copied from the embedded lens).

rfov float

Half-diagonal field of view [radians].

Notes

Use train_psfnet() to train the surrogate from ray-traced PSF samples. Use load_net() to load pre-trained weights.

Initialize a PSF network lens.

With the default settings, the PSF network takes (fov, depth, foc_dist) as input and outputs the RGB PSF of a point on the positive y-axis at that field angle, depth, and focus distance.

Parameters:

Name Type Description Default
lens_path str

Path to the lens file.

required
in_chan int

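Number of network inputs (3 for (fov, depth, foc_dist)).

3
psf_chan int

Number of PSF colour channels.

3
model_name str

Network architecture passed to init_net(), which recognises "mlp" and "mlpconv".

'mlp_conv'
kernel_size int

Side length, in pixels, of the PSF kernel predicted by the network.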

64
Source code in deeplens/optics/psfnetlens.py
def __init__(
    self,
    lens_path,
    in_chan=3,
    psf_chan=3,
    model_name="mlp_conv",
    kernel_size=64,
):
    """Initialize a PSF network lens.

    In the default settings, the PSF network takes (fov, depth, foc_dist) as input and outputs RGB PSF on y-axis at (fov, depth, foc_dist).

    Args:
        lens_path (str): Path to the lens file.
        in_chan (int): Number of input channels.
        psf_chan (int): Number of output channels.
        model_name (str): Name of the model.
        kernel_size (int): Kernel size.
    """
    super().__init__()

    # Load lens (sensor_size and sensor_res are read from the lens file)
    self.lens_path = lens_path
    self.lens = GeoLens(filename=lens_path, device=self.device)
    self.foclen = self.lens.foclen
    self.rfov = self.lens.rfov

    # Init PSF network
    self.in_chan = in_chan
    self.psf_chan = psf_chan
    self.kernel_size = kernel_size
    self.pixel_size = self.lens.pixel_size

    self.psfnet = self.init_net(
        in_chan=in_chan,
        psf_chan=psf_chan,
        kernel_size=kernel_size,
        model_name=model_name,
    )
    self.psfnet.to(self.device)

    # Object depth range
    self.d_close = -200
    self.d_far = -20000

    # Focus distance range
    # There is a minimum focal distance for each lens. For example, the Canon EF 50mm lens can only focus to 0.5m and further.
    self.foc_d_close = -500
    self.foc_d_far = -20000
    self.refocus(foc_dist=-20000)

set_sensor_res

set_sensor_res(sensor_res)

Set sensor resolution for both PSFNetLens and the embedded GeoLens.

Updates the pixel size accordingly.

Parameters:

Name Type Description Default
sensor_res tuple

New sensor resolution as (W, H) in pixels.

required
Source code in deeplens/optics/psfnetlens.py
def set_sensor_res(self, sensor_res):
    """Set sensor resolution for both PSFNetLens and the embedded GeoLens.

    Updates the pixel size accordingly.

    Args:
        sensor_res (tuple): New sensor resolution as ``(W, H)`` in pixels.
    """
    self.lens.set_sensor_res(sensor_res)
    self.pixel_size = self.lens.pixel_size

init_net

init_net(in_chan=2, psf_chan=3, kernel_size=64, model_name='mlpconv')

Initialize a PSF network.

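PSF network

Input: [B, 3], (fov, depth, foc_dist). fov from [0, pi/2] [radians], depth from [-20000, -100] [mm], foc_dist from [-20000, -500] [mm]; callers divide depth and foc_dist by 1000 before passing them to the network.
Output: PSF kernel [B, 3, ks, ks].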

Parameters:

Name Type Description Default
in_chan int

number of input channels

2
psf_chan int

number of output channels

3
kernel_size int

kernel size

64
model_name str

name of the network architecture, "mlp" or "mlpconv"

'mlpconv'

Returns:

Name Type Description
psfnet Module

network

Source code in deeplens/optics/psfnetlens.py
def init_net(self, in_chan=2, psf_chan=3, kernel_size=64, model_name="mlpconv"):
    """Initialize a PSF network.

    PSF network:
        Input: [B, 3], (fov, depth, foc_dist). fov from [0, pi/2], depth from [-20000, -100], foc_dist from [-20000, -500]
        Output: psf kernel [B, 3, ks, ks]

    Args:
        in_chan (int): number of input channels
        psf_chan (int): number of output channels
        kernel_size (int): kernel size
        model_name (str): name of the network architecture

    Returns:
        psfnet (nn.Module): network
    """
    if model_name == "mlp":
        psfnet = MLP(
            in_features=in_chan,
            out_features=psf_chan * kernel_size**2,
            hidden_features=256,
            hidden_layers=8,
        )
    elif model_name in ("mlpconv", "mlp_conv"):
        # Accept both spellings: the PSFNetLens constructor defaults to "mlp_conv"
        psfnet = PSFNet_MLPConv(
            in_chan=in_chan, kernel_size=kernel_size, out_chan=psf_chan
        )
    else:
        raise Exception(f"Unsupported PSF network architecture: {model_name}.")

    return psfnet

load_net

load_net(net_path)

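Load pretrained network weights.

The checkpoint's pixel size and lens path are printed next to those of the current lens for comparison; mismatches are reported but not rejected.

Parameters:

Name Type Description Default
net_path str

path to a checkpoint saved by save_psfnet()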

required
Source code in deeplens/optics/psfnetlens.py
def load_net(self, net_path):
    """Load pretrained network.

    Args:
        net_path (str): path to load the network
    """
    # Print checkpoint metadata next to the current lens for a manual check (mismatches are not rejected)
    psfnet_dict = torch.load(net_path, map_location="cpu", weights_only=False)
    print(
        f"Pretrained model lens pixel size: {psfnet_dict['pixel_size']*1000.0:.1f} um, "
        f"Current lens pixel size: {self.pixel_size*1000.0:.1f} um"
    )
    print(
        f"Pretrained model lens path: {psfnet_dict['lens_path']}, "
        f"Current lens path: {self.lens_path}"
    )

    # Load the model weights
    self.psfnet.load_state_dict(psfnet_dict["psfnet_model_weights"])
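load_net() prints the checkpoint's pixel size and lens path but loads the weights regardless. If a hard failure is preferred, a check like the one below could run on the loaded dictionary before load_state_dict(); check_psfnet_meta and its tolerance are hypothetical, while the dictionary keys are the ones written by save_psfnet().

```python
def check_psfnet_meta(meta, pixel_size, kernel_size, rel_tol=1e-3):
    """List mismatches between a save_psfnet() checkpoint dict and the current lens.

    pixel_size is in [mm]; an empty list means the metadata agrees.
    """
    problems = []
    saved_px = meta["pixel_size"]
    if abs(saved_px - pixel_size) > rel_tol * abs(pixel_size):
        problems.append(
            f"pixel size {saved_px * 1000.0:.1f} um != {pixel_size * 1000.0:.1f} um"
        )
    if meta["kernel_size"] != kernel_size:
        problems.append(f"kernel size {meta['kernel_size']} != {kernel_size}")
    return problems


meta = {"pixel_size": 0.004, "kernel_size": 64, "lens_path": "lens.json"}
print(check_psfnet_meta(meta, pixel_size=0.002, kernel_size=64))
```

Inside load_net(), a non-empty result could be turned into an error with raise ValueError("; ".join(problems)).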

save_psfnet

save_psfnet(psfnet_path)

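Save the PSF network weights together with metadata (network class name, in_chan, psf_chan, kernel_size, pixel_size, lens_path).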

Parameters:

Name Type Description Default
psfnet_path str

path to save the PSF network

required
Source code in deeplens/optics/psfnetlens.py
def save_psfnet(self, psfnet_path):
    """Save the PSF network.

    Args:
        psfnet_path (str): path to save the PSF network
    """
    psfnet_dict = {
        "model_name": self.psfnet.__class__.__name__,
        "in_chan": self.in_chan,
        "pixel_size": self.pixel_size,
        "kernel_size": self.kernel_size,
        "psf_chan": self.psf_chan,
        "lens_path": self.lens_path,
        "psfnet_model_weights": self.psfnet.state_dict(),
    }
    torch.save(psfnet_dict, psfnet_path)

train_psfnet

train_psfnet(iters=100000, bs=128, lr=5e-05, evaluate_every=500, spp=16384, concentration_factor=2.0, result_dir='./results/psfnet')

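Train the PSF surrogate network by regressing ray-traced PSFs with an L1 loss, using AdamW and a cosine learning-rate schedule with warmup over the first 1% of iterations.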

Parameters:

Name Type Description Default
iters int

number of training iterations

100000
bs int

batch size

128
lr float

learning rate

5e-05
evaluate_every int

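interval, in iterations, between evaluations; each evaluation saves a GT/predicted PSF figure and overwrites PSFNet_last.pth

500
spp int

number of samples per pixel (not referenced in the function body below)

16384
concentration_factor float

concentration factor for training data sampling, forwarded to sample_training_data()

2.0
result_dir str

directory for the PSF figures and the PSFNet_last.pth checkpoint; it is not created automatically and must already exist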

'./results/psfnet'
Source code in deeplens/optics/psfnetlens.py
def train_psfnet(
    self,
    iters=100000,
    bs=128,
    lr=5e-5,
    evaluate_every=500,
    spp=16384,
    concentration_factor=2.0,
    result_dir="./results/psfnet",
):
    """Train the PSF surrogate network.

    Args:
        iters (int): number of training iterations
        bs (int): batch size
        lr (float): learning rate
        evaluate_every (int): evaluate every how many iterations
        spp (int): number of samples per pixel
        concentration_factor (float): concentration factor for training data sampling
        result_dir (str): directory to save the results
    """
    # Init network and prepare for training
    psfnet = self.psfnet
    loss_fn = nn.L1Loss()
    optimizer = torch.optim.AdamW(psfnet.parameters(), lr=lr)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=int(iters) // 100, num_training_steps=iters
    )

    # Train the network
    for i in tqdm(range(iters + 1)):
        # Sample training data
        sample_input, sample_psf = self.sample_training_data(
            num_points=bs, concentration_factor=concentration_factor
        )
        sample_input, sample_psf = (
            sample_input.to(self.device),
            sample_psf.to(self.device),
        )

        # Forward pass, pred_psf: [B, 3, ks, ks]
        pred_psf = psfnet(sample_input)

        # Backward propagation
        optimizer.zero_grad()
        loss = loss_fn(pred_psf, sample_psf)
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Evaluate training
        if (i + 1) % evaluate_every == 0:
            # Visualize PSFs
            n_vis = 16
            fig, axs = plt.subplots(n_vis, 2, figsize=(4, n_vis * 2))
            for j in range(n_vis):
                psf0 = sample_psf[j, ...].detach().clone().cpu()
                axs[j, 0].imshow(psf0.permute(1, 2, 0) * 255.0)
                axs[j, 0].axis("off")

                psf1 = pred_psf[j, ...].detach().clone().cpu()
                axs[j, 1].imshow(psf1.permute(1, 2, 0) * 255.0)
                axs[j, 1].axis("off")

            axs[0, 0].set_title("GT")
            axs[0, 1].set_title("Pred")

            fig.suptitle(f"GT/Pred PSFs at iter {i + 1}")
            plt.tight_layout()
            plt.savefig(f"{result_dir}/iter{i + 1}.png", dpi=300)
            plt.close()

            # Save network
            self.save_psfnet(f"{result_dir}/PSFNet_last.pth")

sample_training_data

sample_training_data(num_points=512, concentration_factor=2.0)

Sample training data for PSF surrogate network.

Parameters:

Name Type Description Default
num_points int

number of training points

512
concentration_factor float

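controls how tightly sampled depths cluster around the focus distance: the depth standard deviation is |foc_dist| / concentration_factor, so larger values concentrate samples more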

2.0

Returns:

Name Type Description
sample_input tensor

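[B, 3] tensor, (fov, depth, foc_dist), with B = num_points:
- fov from [0, rfov] on the y-axis [radians]
- depth / 1000.0, with depth from [d_far, d_close] [mm]
- foc_dist / 1000.0, with foc_dist from [foc_d_far, foc_d_close] [mm]; one focus distance is shared by the whole batch
- fov and depth are physical values, not normalised coordinates

sample_psf tensor

Ray-traced RGB PSFs from the embedded GeoLens, recentered. Shape [B, 3, ks, ks] with ks = kernel_size.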

Source code in deeplens/optics/psfnetlens.py
@torch.no_grad()
def sample_training_data(self, num_points=512, concentration_factor=2.0):
    """Sample training data for PSF surrogate network.

    Args:
        num_points (int): number of training points
        concentration_factor (float): concentration factor for training data sampling

    Returns:
        sample_input (tensor): [B, 3] tensor, (fov, depth, foc_dist).
            - fov from [0, rfov] on y-axis, [radians]
            - depth from [d_far, d_close], [mm]
            - foc_dist from [foc_d_far, foc_d_close], [mm]
            - We use absolute fov and depth.

        sample_psf (tensor): [B, 3, ks, ks] tensor
    """
    d_close = self.d_close
    d_far = self.d_far
    rfov = self.lens.rfov

    # In each iteration, sample one focus distance, [mm], range [foc_d_far, foc_d_close]
    # Example beta distribution: https://share.google/images/Mrb9c39PdddYx3UHj
    beta_sample = float(np.random.beta(1, 4))  # Biased towards 0
    foc_dist = self.foc_d_close + beta_sample * (self.foc_d_far - self.foc_d_close)
    self.lens.refocus(foc_dist)
    foc_dist = torch.full((num_points,), foc_dist)

    # Sample (fov), [radians], range[0, rfov]
    beta_values = np.random.beta(4, 1, num_points)  # Biased towards 1
    beta_values = torch.from_numpy(beta_values).float()
    fov = beta_values * rfov

    # Sample (depth), sample more points near the focus distance, [mm], range [d_far, d_close]
    # A smaller std_dev value samples points more tightly
    std_dev = -foc_dist / concentration_factor
    depth = foc_dist + torch.randn(num_points) * std_dev
    depth = torch.clamp(depth, d_far, d_close)

    # Create input tensor
    sample_input = torch.stack([fov, depth / 1000.0, foc_dist / 1000.0], dim=1)
    sample_input = sample_input.to(self.device)

    # Calculate PSF by ray tracing, shape of [B, 3, ks, ks]
    points_x = torch.zeros_like(depth)
    points_y = self.lens.foclen * torch.tan(fov) / self.lens.r_sensor
    points_z = depth
    points = torch.stack((points_x, points_y, points_z), dim=-1)
    sample_psf = self.lens.psf_rgb(
        points=points, ks=self.kernel_size, recenter=True
    )

    return sample_input, sample_psf
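The sampling strategy of sample_training_data() can be mimicked with the standard library to inspect the distribution of training inputs without ray tracing. The sketch below assumes the default ranges set in __init__ and an illustrative rfov of 0.4 rad; sample_inputs is a hypothetical name, and random.gauss stands in for torch.randn.

```python
import random


def sample_inputs(num_points, rfov, concentration_factor=2.0, d_close=-200.0,
                  d_far=-20000.0, foc_d_close=-500.0, foc_d_far=-20000.0,
                  rng=random):
    """Draw (fov, depth / 1000, foc_dist / 1000) triples the way
    sample_training_data() does, without ray tracing any PSFs."""
    # One focus distance per batch; Beta(1, 4) favours close focus
    foc = foc_d_close + rng.betavariate(1, 4) * (foc_d_far - foc_d_close)
    std_dev = -foc / concentration_factor  # positive because foc < 0
    samples = []
    for _ in range(num_points):
        fov = rng.betavariate(4, 1) * rfov  # Beta(4, 1) favours large field angles
        depth = foc + rng.gauss(0.0, 1.0) * std_dev  # concentrated around focus
        depth = min(max(depth, d_far), d_close)  # clamp to [d_far, d_close]
        samples.append((fov, depth / 1000.0, foc / 1000.0))
    return samples


batch = sample_inputs(256, rfov=0.4, rng=random.Random(0))
print(batch[0])
```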

eval

eval()

Switch the PSF surrogate network to evaluation mode.

Disables dropout and batch normalisation updates in the internal psfnet module. Call this before inference.

Source code in deeplens/optics/psfnetlens.py
def eval(self):
    """Switch the PSF surrogate network to evaluation mode.

    Disables dropout and batch normalisation updates in the internal
    ``psfnet`` module.  Call this before inference.
    """
    self.psfnet.eval()

points2input

points2input(points)

Convert points to input tensor.

Parameters:

Name Type Description Default
points tensor

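[N, 3] tensor of (x, y, depth): x and y are normalised sensor coordinates in [-1, 1], depth is in [depth_min, depth_max] [mm]

required

Returns:

Name Type Description
input tensor

[N, 3] tensor, (fov, depth, foc_dist):
- fov: radial field angle atan(r / foclen) [radians], where r [mm] is the point's distance from the sensor centre
- depth / 1000.0, with depth from [d_far, d_close] [mm]
- foc_dist / 1000.0, with foc_dist the current focus distance [mm]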

Source code in deeplens/optics/psfnetlens.py
def points2input(self, points):
    """Convert points to input tensor.

    Args:
        points (tensor): [N, 3] tensor, [-1, 1] * [-1, 1] * [depth_min, depth_max]

    Returns:
        input (tensor): [N, 3] tensor, (fov, depth, foc_dist).
            - fov from [0, rfov] on y-axis, [radians]
            - depth/1000.0 from [d_far, d_close], [mm]
            - foc_dist/1000.0 from [foc_d_far, foc_d_close], [mm]
    """
    sensor_h, sensor_w = self.lens.sensor_size
    foclen = self.lens.foclen

    points_x = points[:, 0] * sensor_w / 2
    points_y = points[:, 1] * sensor_h / 2
    points_r = torch.sqrt(points_x**2 + points_y**2)
    fov = torch.atan(points_r / foclen)
    depth = points[:, 2]
    foc_dist = torch.full_like(fov, self.foc_dist)
    network_inp = torch.stack((fov, depth / 1000.0, foc_dist / 1000.0), dim=-1)
    return network_inp
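A scalar version of points2input() is handy for checking which field angle a normalised sensor position maps to. In this sketch, point_to_input is a hypothetical name, the sensor width/height and focal length [mm] are passed explicitly, and the example numbers (a 36 x 24 mm sensor, 50 mm lens) are illustrative.

```python
import math


def point_to_input(x, y, depth, foc_dist, sensor_w, sensor_h, foclen):
    """Map a normalised sensor point to the (fov, depth/1000, foc_dist/1000) input.

    x and y are in [-1, 1] across the sensor width and height; depth and
    foc_dist are object-side distances [mm].
    """
    px = x * sensor_w / 2  # [mm] on the sensor
    py = y * sensor_h / 2
    fov = math.atan(math.hypot(px, py) / foclen)  # radial field angle [rad]
    return fov, depth / 1000.0, foc_dist / 1000.0


# Right edge of a 36 x 24 mm sensor behind a 50 mm lens
print(point_to_input(1.0, 0.0, -2000.0, -3000.0, 36.0, 24.0, 50.0))
```

Because only the radius enters the field angle, mirrored points such as (1, 0) and (-1, 0) map to the same network input; the direction is restored afterwards by rotating the predicted PSF.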

refocus

refocus(foc_dist)

Refocus the lens to a given object distance.

Delegates to the embedded GeoLens and stores the focus distance for subsequent PSF predictions.

Parameters:

Name Type Description Default
foc_dist float

Focus distance in [mm] (negative, towards the object).

required
Source code in deeplens/optics/psfnetlens.py
def refocus(self, foc_dist):
    """Refocus the lens to a given object distance.

    Delegates to the embedded :class:`GeoLens` and stores the focus
    distance for subsequent PSF predictions.

    Args:
        foc_dist (float): Focus distance in [mm] (negative, towards the
            object).
    """
    self.lens.refocus(foc_dist)
    self.foc_dist = foc_dist

psf_rgb

psf_rgb(points, ks=64)

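Calculate RGB PSF using the PSF network. The network predicts the PSF for a point on the positive y-axis; the result is rotated by atan2(x, y) to the point's actual position.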

Parameters:

Name Type Description Default
points tensor

[N, 3] tensor, [-1, 1] * [-1, 1] * [depth_min, depth_max]

required
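ks int

Kernel size. The network's kernel_size x kernel_size PSF is center-cropped to ks x ks when ks < kernel_size; otherwise the full kernel is returned.

64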

Returns:

Name Type Description
psf tensor

[N, 3, ks, ks] tensor

Source code in deeplens/optics/psfnetlens.py
def psf_rgb(self, points, ks=64):
    """Calculate RGB PSF using the PSF network.

    Args:
        points (tensor): [N, 3] tensor, [-1, 1] * [-1, 1] * [depth_min, depth_max]
        ks (int): kernel size; the PSF is center-cropped when ks < kernel_size

    Returns:
        psf (tensor): [N, 3, ks, ks] tensor
    """
    # Calculate network input
    network_inp = self.points2input(points)

    # Predict y-axis PSF from network
    psf = self.psfnet(network_inp)

    # Post-process PSF
    # The psfnet is trained with PSFs on the y-axis.
    # We need to rotate the PSF to the correct orientation based on the point's coordinates.
    # The counter-clockwise angle from the positive y-axis to the point (x, y) is atan2(x, y).
    rot_angle = torch.atan2(points[:, 0], points[:, 1])
    psf = rotate_psf(psf, rot_angle)

    # Center-crop PSF to the given kernel size (keeps exactly ks pixels, also for odd ks)
    if ks < self.kernel_size:
        start = self.kernel_size // 2 - ks // 2
        psf = psf[:, :, start : start + ks, start : start + ks]
    return psf
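The crop keeps the kernel centered on index kernel_size // 2. A slice of the form [c - ks//2 : c + ks//2] returns only ks - 1 pixels when ks is odd. The plain-Python sketch below computes bounds that keep exactly ks pixels for either parity, and it matches that slice when ks is even. The helper names center_crop_bounds and crop_2d are hypothetical.

```python
def center_crop_bounds(kernel_size, ks):
    """Return (start, stop) so that [start:stop] keeps exactly ks pixels
    centered on index kernel_size // 2 (hypothetical helper)."""
    if not 0 < ks <= kernel_size:
        raise ValueError("ks must be in (0, kernel_size]")
    start = kernel_size // 2 - ks // 2
    return start, start + ks


def crop_2d(kernel, ks):
    """Center-crop a square 2D kernel given as a list of rows."""
    start, stop = center_crop_bounds(len(kernel), ks)
    return [row[start:stop] for row in kernel[start:stop]]


k = [[10 * i + j for j in range(6)] for i in range(6)]
print(crop_2d(k, 3))  # [[22, 23, 24], [32, 33, 34], [42, 43, 44]]
```

The center pixel k[3][3] = 33 stays in the middle of the cropped kernel.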

psf_map_rgb

psf_map_rgb(grid=(11, 11), depth=DEPTH, ks=PSF_KS, **kwargs)

Compute an RGB PSF map over a grid of field positions.

Parameters:

Name Type Description Default
grid tuple

Grid size. Defaults to (11, 11), meaning 11x11 grid.

(11, 11)
depth float

Depth of the object. Defaults to DEPTH.

DEPTH
ks int

Kernel size. Defaults to PSF_KS, meaning PSF_KS x PSF_KS kernel size.

PSF_KS

Returns:

Name Type Description
psf_map

Shape [grid[0], grid[1], 3, ks, ks].

Source code in deeplens/optics/psfnetlens.py
def psf_map_rgb(self, grid=(11, 11), depth=DEPTH, ks=PSF_KS, **kwargs):
    """Compute an RGB PSF map over a grid of field positions.

    Args:
        grid (tuple, optional): Grid size. Defaults to (11, 11), meaning 11x11 grid.
        depth (float, optional): Depth of the object. Defaults to DEPTH.
        ks (int, optional): Kernel size. Defaults to PSF_KS, meaning PSF_KS x PSF_KS kernel size.

    Returns:
        psf_map: Shape of [grid[0], grid[1], 3, ks, ks].
    """
    """
    # PSF map grid
    points = self.point_source_grid(depth=depth, grid=grid, center=True)
    points = points.reshape(-1, 3).to(self.device)

    # Compute PSF map
    psf = self.psf_rgb(points=points, ks=ks)  # [grid*grid, 3, ks, ks]
    psf_map = psf.reshape(grid[0], grid[1], 3, ks, ks)  # [grid, grid, 3, ks, ks]
    return psf_map

render_rgbd

render_rgbd(img, depth, foc_dist, ks=64, high_res=False)

Render an image from an all-in-focus (AiF) image and a depth map. Takes a [1, C, H, W] image; only batch size 1 is supported.

Parameters:

Name Type Description Default
img tensor

[1, C, H, W]

required
depth tensor

[1, H, W], depth map, unit in mm, range from [-20000, -200]

required
foc_dist tensor

[1], unit in mm, range from [-20000, -200]

required
ks int

kernel size

64
high_res bool

whether to use high resolution rendering

False

Returns:

Name Type Description
render tensor

[1, C, H, W]

Source code in deeplens/optics/psfnetlens.py
@torch.no_grad()
def render_rgbd(self, img, depth, foc_dist, ks=64, high_res=False):
    """Render an image from an all-in-focus (AiF) image and a depth map. Only batch size 1 is supported.

    Args:
        img (tensor): [1, C, H, W]
        depth (tensor): [1, H, W], depth map, unit in mm, range from [-20000, -200]
        foc_dist (tensor): [1], unit in mm, range from [-20000, -200]
        ks (int): kernel size
        high_res (bool): whether to use high resolution rendering

    Returns:
        render (tensor): [1, C, H, W]
    """
    B, C, H, W = img.shape
    assert B == 1, "Only support batch size 1"

    # Refocus the lens to the given focus distance
    self.refocus(foc_dist)

    # Estimate PSF for each pixel
    x, y = torch.meshgrid(
        torch.linspace(-1, 1, W, device=self.device),
        torch.linspace(1, -1, H, device=self.device),
        indexing="xy",
    )
    x, y = x.unsqueeze(0).repeat(B, 1, 1), y.unsqueeze(0).repeat(B, 1, 1)
    depth = depth.squeeze(1) / 1000.0

    points = torch.stack((x, y, depth), -1).float()
    psf = self.psf_rgb(points=points, ks=ks)

    # Render image with per-pixel PSF convolution
    if high_res:
        render = conv_psf_pixel_high_res(img, psf)
    else:
        render = conv_psf_pixel(img, psf)

    return render
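In render_rgbd, torch.meshgrid(..., indexing="xy") is called with x = linspace(-1, 1, W) and y = linspace(1, -1, H). This maps the top-left pixel to (-1, 1) and the bottom-right pixel to (1, -1), so y points up. The sketch below reproduces that per-pixel mapping in plain Python. The helper pixel_to_normalized is hypothetical, not a DeepLens function.

```python
def pixel_to_normalized(row, col, H, W):
    """Normalized (x, y) of pixel (row, col) on an H x W image.

    Hypothetical helper mirroring render_rgbd: x = linspace(-1, 1, W)[col]
    and y = linspace(1, -1, H)[row], so (0, 0) is the top-left corner.
    """
    if H < 2 or W < 2:
        raise ValueError("H and W must be at least 2")
    x = -1.0 + 2.0 * col / (W - 1)
    y = 1.0 - 2.0 * row / (H - 1)
    return x, y


H, W = 5, 5
print(pixel_to_normalized(0, 0, H, W))  # (-1.0, 1.0): top-left
print(pixel_to_normalized(H - 1, W - 1, H, W))  # (1.0, -1.0): bottom-right
```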

Surfaces

Base class for all geometric optical surfaces. Implements surface intersection (Newton's method with one differentiable step) and differentiable vector Snell's law refraction.

deeplens.optics.geometric_surface.Surface

Surface(r, d, mat2, pos_xy=[0.0, 0.0], vec_local=[0.0, 0.0, 1.0], is_square=False, device='cpu')

Bases: DeepObj

Base class for all geometric optical surfaces.

A surface sits at axial position d (mm) in the global coordinate system, has an aperture radius r (mm), and separates two optical media. Subclasses override _sag() and _dfdxy() to define their shape.

Ray–surface interaction is handled in three stages, implemented in ray_reaction():

  1. Coordinate transform: the ray is brought into the local surface frame.
  2. Intersection: solved via Newton's method (newtons_method()), using a non-differentiable iteration loop followed by a single differentiable Newton step to enable gradient flow.
  3. Refraction / reflection: vector Snell's law (refract()) or specular reflection (reflect()). The ray is then transformed back to global coordinates.

Attributes:

Name Type Description
d Tensor

Axial position of the surface vertex [mm].

r float

Aperture radius [mm].

mat2 Material

Optical material on the transmission side.

is_square bool

If True the aperture is square; otherwise circular.

tolerancing bool

When True, manufacturing error offsets are applied during ray tracing.

Initialize a generic optical surface.

Parameters:

Name Type Description Default
r float

Aperture radius [mm].

required
d float

Axial position of the surface vertex [mm].

required
mat2 str or Material

Material on the transmission side (e.g. "N-BK7", "air").

required
pos_xy list[float]

Lateral offset [x, y] [mm]. Defaults to [0.0, 0.0].

[0.0, 0.0]
vec_local list[float]

Local normal direction. Defaults to [0.0, 0.0, 1.0] (on-axis).

[0.0, 0.0, 1.0]
is_square bool

Use a square aperture. Defaults to False.

False
device str

Compute device. Defaults to "cpu".

'cpu'
Source code in deeplens/optics/geometric_surface/base.py
def __init__(
    self,
    r,
    d,
    mat2,
    pos_xy=[0.0, 0.0],
    vec_local=[0.0, 0.0, 1.0],
    is_square=False,
    device="cpu",
):
    """Initialize a generic optical surface.

    Args:
        r (float): Aperture radius [mm].
        d (float): Axial position of the surface vertex [mm].
        mat2 (str or Material): Material on the transmission side
            (e.g. ``"N-BK7"``, ``"air"``).
        pos_xy (list[float], optional): Lateral offset ``[x, y]`` [mm].
            Defaults to ``[0.0, 0.0]``.
        vec_local (list[float], optional): Local normal direction.
            Defaults to ``[0.0, 0.0, 1.0]`` (on-axis).
        is_square (bool, optional): Use a square aperture.
            Defaults to ``False``.
        device (str, optional): Compute device. Defaults to ``"cpu"``.
    """
    super(Surface, self).__init__()

    # Global direction vector, always pointing to the positive z-axis
    self.vec_global = torch.tensor([0.0, 0.0, 1.0])

    # Surface position in global coordinate system
    self.d = torch.tensor(d)
    self.pos_x = torch.tensor(pos_xy[0])
    self.pos_y = torch.tensor(pos_xy[1])

    # Surface direction vector in global coordinate system
    self.vec_local = F.normalize(torch.tensor(vec_local), p=2, dim=-1)

    # Material after the surface
    self.mat2 = Material(mat2)

    # Surface aperture radius (non-differentiable)
    self.r = float(r)
    self.is_square = is_square
    if is_square:
        # r is the incircle radius
        self.h = 2 * self.r
        self.w = 2 * self.r

    # Newton method parameters
    self.newton_maxiter = 10  # [int], maximum number of Newton iterations
    self.newton_convergence = (
        50.0 * 1e-6
    )  # [mm], Newton method convergence threshold
    self.newton_step_bound = self.r / 5  # [mm], maximum step size in each iteration

    self.tolerancing = False
    self._R_tilt = None
    self._R_tilt_inv = None
    self.device = device if device is not None else torch.device("cpu")
    self.to(self.device)

    # Pre-compute rotation matrices (depends only on static vec_local/vec_global)
    self._cache_rotation_matrices()

init_from_dict classmethod

init_from_dict(surf_dict)

Initialize surface from a dict.

Source code in deeplens/optics/geometric_surface/base.py
@classmethod
def init_from_dict(cls, surf_dict):
    """Initialize surface from a dict."""
    raise NotImplementedError(
        f"init_from_dict() is not implemented for {cls.__name__}."
    )

ray_reaction

ray_reaction(ray, n1, n2, refraction=True)

Compute the output ray after intersection and refraction/reflection.

Transforms the ray to the local surface frame, solves the intersection via Newton's method, applies vector Snell's law (or specular reflection), then transforms back to global coordinates.

When tolerancing is active, mat2_n_error is added to n2 to simulate refractive-index manufacturing error.

Parameters:

Name Type Description Default
ray Ray

Incident ray bundle.

required
n1 float

Refractive index of the incident medium.

required
n2 float

Refractive index of the transmission medium.

required
refraction bool

If True (default) refract the ray; if False reflect it.

True

Returns:

Name Type Description
Ray

Updated ray bundle after the surface interaction.

Source code in deeplens/optics/geometric_surface/base.py
def ray_reaction(self, ray, n1, n2, refraction=True):
    """Compute the output ray after intersection and refraction/reflection.

    Transforms the ray to the local surface frame, solves the intersection
    via Newton's method, applies vector Snell's law (or specular reflection),
    then transforms back to global coordinates.

    When tolerancing is active, ``mat2_n_error`` is added to ``n2`` to
    simulate refractive-index manufacturing error.

    Args:
        ray (Ray): Incident ray bundle.
        n1 (float): Refractive index of the incident medium.
        n2 (float): Refractive index of the transmission medium.
        refraction (bool, optional): If ``True`` (default) refract the ray;
            if ``False`` reflect it.

    Returns:
        Ray: Updated ray bundle after the surface interaction.
    """
    # Transform ray to local coordinate system
    ray = self.to_local_coord(ray)

    # Intersection
    ray = self.intersect(ray, n1)

    if refraction:
        # Apply refractive index tolerance error
        if self.tolerancing:
            n2 = n2 + self.mat2_n_error
        ray = self.refract(ray, n1 / n2)
    else:
        # Reflection
        ray = self.reflect(ray)

    # Transform ray to global coordinate system
    ray = self.to_global_coord(ray)

    return ray

intersect

intersect(ray, n=1.0)

Solve ray-surface intersection in local coordinate system.

Parameters:

Name Type Description Default
ray Ray

input ray.

required
n float

Refractive index of the incident medium, used to accumulate optical path length (OPL) for coherent rays. Defaults to 1.0.

1.0
Source code in deeplens/optics/geometric_surface/base.py
def intersect(self, ray, n=1.0):
    """Solve ray-surface intersection in local coordinate system.

    Args:
        ray (Ray): input ray.
        n (float, optional): refractive index of the incident medium, used to
            accumulate optical path length for coherent rays. Defaults to 1.0.
    """
    # Solve ray-surface intersection time by Newton's method
    t, valid = self.newtons_method(ray)

    # Update ray
    new_o = ray.o + ray.d * t.unsqueeze(-1)
    ray.o = torch.where(valid.unsqueeze(-1), new_o, ray.o)
    ray.is_valid = ray.is_valid * valid

    if ray.coherent:
        if t.abs().max() > 100 and torch.get_default_dtype() == torch.float32:
            raise Exception(
                "Using float32 may cause precision problem for OPL calculation."
            )
        new_opl = ray.opl + n * t.unsqueeze(-1)
        ray.opl = torch.where(valid.unsqueeze(-1), new_opl, ray.opl)

    return ray

newtons_method

newtons_method(ray)

Solve intersection by Newton's method in local coordinate system.

Parameters:

Name Type Description Default
ray Ray

input ray.

required

Returns:

Name Type Description
t tensor

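Ray parameter t at the intersection, i.e. the propagation distance along ray.d.

valid tensor

Boolean mask of rays that hit the surface inside its valid region and whose Newton iteration converged.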

Source code in deeplens/optics/geometric_surface/base.py
def newtons_method(self, ray):
    """Solve intersection by Newton's method in local coordinate system.

    Args:
        ray (Ray): input ray.

    Returns:
        t (tensor): ray parameter at the intersection (propagation distance along ray.d).
        valid (tensor): boolean mask of rays with a valid, converged intersection.
    """
    newton_maxiter = self.newton_maxiter
    newton_convergence = self.newton_convergence
    newton_step_bound = self.newton_step_bound

    # Ray direction components (reused across iterations)
    dxdt, dydt, dzdt = ray.d[..., 0], ray.d[..., 1], ray.d[..., 2]

    # Initial guess of t (can also use spherical surface for initial guess)
    t = -ray.o[..., 2] / dzdt

    # 1. Non-differentiable Newton's iterations to find the intersection
    #    Run (maxiter - 1) iterations; the differentiable step below acts as
    #    the final iteration while also enabling gradient flow.
    with torch.no_grad():
        for _ in range(newton_maxiter - 1):
            new_o = ray.o + ray.d * t.unsqueeze(-1)
            new_x, new_y = new_o[..., 0], new_o[..., 1]
            valid = self.is_within_data_range(new_x, new_y) & (ray.is_valid > 0)

            x, y = new_x * valid, new_y * valid
            ft = self._sag(x, y) - new_o[..., 2]
            dfdx, dfdy = self._dfdxy(x, y)
            dfdt = dfdx * dxdt + dfdy * dydt - dzdt
            t = t - torch.clamp(
                ft / (dfdt + EPSILON), -newton_step_bound, newton_step_bound
            )

    # 2. One differentiable Newton step (final iteration + gradient flow)
    new_o = ray.o + ray.d * t.unsqueeze(-1)
    new_x, new_y = new_o[..., 0], new_o[..., 1]
    valid = self.is_valid(new_x, new_y) & (ray.is_valid > 0)

    x, y = new_x * valid, new_y * valid
    ft = self._sag(x, y) - new_o[..., 2]
    dfdx, dfdy = self._dfdxy(x, y)
    dfdt = dfdx * dxdt + dfdy * dydt - dzdt
    t = t - torch.clamp(
        ft / (dfdt + EPSILON), -newton_step_bound, newton_step_bound
    )

    # 3. Determine valid solutions — reuse ft and valid from the diff step
    with torch.no_grad():
        valid = valid & (ft.abs() < newton_convergence)

    return t, valid
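The Newton update above can be reproduced for a single ray in plain Python. The sketch below solves g(t) = sag(x(t), y(t)) - z(t) = 0 with g'(t) = fx*dx + fy*dy - dz, starting from the z = 0 plane intersection, and clamps each step as newton_step_bound does. It omits the split into non-differentiable and differentiable steps, which only matters for autograd. The function newton_intersect and the paraboloid test surface are hypothetical illustrations, not DeepLens APIs.

```python
def newton_intersect(o, d, sag, dsag, maxiter=10, tol=50e-6, step_bound=1.0, eps=1e-12):
    """Ray parameter t where o + t*d meets z = sag(x, y), plus a convergence flag."""
    ox, oy, oz = o
    dx, dy, dz = d
    t = -oz / dz  # initial guess: intersection with the z = 0 plane
    for _ in range(maxiter):
        x, y, z = ox + t * dx, oy + t * dy, oz + t * dz
        ft = sag(x, y) - z  # residual of f(x, y, z) = sag(x, y) - z
        fx, fy = dsag(x, y)
        dfdt = fx * dx + fy * dy - dz  # chain rule along the ray
        step = ft / (dfdt + eps)
        t -= max(-step_bound, min(step_bound, step))  # bounded Newton step
    x, y, z = ox + t * dx, oy + t * dy, oz + t * dz
    return t, abs(sag(x, y) - z) < tol


def parabola_sag(x, y, c=0.02):
    """Paraboloid z = c/2 * (x^2 + y^2) with curvature c [1/mm]."""
    return 0.5 * c * (x * x + y * y)


def parabola_dsag(x, y, c=0.02):
    """Analytic (dz/dx, dz/dy) of the paraboloid."""
    return c * x, c * y


t, ok = newton_intersect((0.0, 5.0, -1.0), (0.0, 0.0, 1.0), parabola_sag, parabola_dsag)
print(round(t, 6), ok)  # 1.25 True
```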

refract

refract(ray, eta)

Calculate refracted ray according to Snell's law in local coordinate system.

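The normal vector points from the surface toward the side the light is coming from. If both the normal and ray.d are unit vectors, the refracted direction is already unit-length and needs no renormalization. Rays that undergo total internal reflection are marked invalid.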

Parameters:

Name Type Description Default
ray Ray

incident ray.

required
eta float

ratio of indices of refraction, eta = n_i / n_t

required

Returns:

Name Type Description
ray Ray

refracted ray.

References

[1] https://registry.khronos.org/OpenGL-Refpages/gl4/html/refract.xhtml
[2] https://en.wikipedia.org/wiki/Snell%27s_law, "Vector form" section.

Source code in deeplens/optics/geometric_surface/base.py
def refract(self, ray, eta):
    """Calculate refracted ray according to Snell's law in local coordinate system.

    The normal vector points from the surface toward the side the light is coming from. The refracted direction is already unit-length if both the normal and ray.d are unit vectors.

    Args:
        ray (Ray): incident ray.
        eta (float): ratio of indices of refraction, eta = n_i / n_t

    Returns:
        ray (Ray): refracted ray.

    References:
        [1] https://registry.khronos.org/OpenGL-Refpages/gl4/html/refract.xhtml
        [2] https://en.wikipedia.org/wiki/Snell%27s_law, "Vector form" section.
    """
    # Compute normal vectors
    normal_vec = self.normal_vec(ray)

    # Compute refraction according to Snell's law, normal_vec * ray_d
    dot_product = (-normal_vec * ray.d).sum(-1).unsqueeze(-1)
    k = 1 - eta**2 * (1 - dot_product**2)

    # Total internal reflection (k < 0): mark these rays invalid
    valid = (k >= 0).squeeze(-1) & (ray.is_valid > 0)
    k = k * valid.unsqueeze(-1)

    # Update ray direction and obliquity
    new_d = eta * ray.d + (eta * dot_product - torch.sqrt(k + EPSILON)) * normal_vec
    # ==> Update obliq term to penalize steep rays in the later optimization.
    obliq = torch.sum(new_d * ray.d, axis=-1).unsqueeze(-1)
    obliq_update_mask = valid.unsqueeze(-1) & (obliq < 0.5)
    ray.obliq = torch.where(obliq_update_mask, obliq * ray.obliq, ray.obliq)
    # ==>
    ray.d = torch.where(valid.unsqueeze(-1), new_d, ray.d)

    # Update ray valid mask
    ray.is_valid = ray.is_valid * valid

    return ray
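The vector form used above can be written for a single ray in plain Python. The sketch uses the same convention: the normal points toward the incident side, cos_i = -n·d, k = 1 - eta²(1 - cos_i²), and t = eta·d + (eta·cos_i - sqrt(k))·n. It returns None on total internal reflection, where the source instead masks the ray as invalid. The function name refract_dir is hypothetical.

```python
import math


def refract_dir(d, n, eta):
    """Vector Snell's law for unit vectors.

    d: incident direction; n: normal pointing toward the incident side;
    eta = n_i / n_t. Returns None on total internal reflection.
    """
    cos_i = -sum(di * ni for di, ni in zip(d, n))
    k = 1.0 - eta**2 * (1.0 - cos_i**2)
    if k < 0.0:
        return None  # total internal reflection
    coef = eta * cos_i - math.sqrt(k)
    return tuple(eta * di + coef * ni for di, ni in zip(d, n))


# 30 deg incidence from air into glass (n = 1.5); the normal points back toward the source
s, c = math.sin(math.radians(30)), math.cos(math.radians(30))
t = refract_dir((s, 0.0, c), (0.0, 0.0, -1.0), 1.0 / 1.5)
print(math.degrees(math.asin(t[0])))  # refraction angle, about 19.47 deg
```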

reflect

reflect(ray)

Calculate reflected ray in local coordinate system.

Normal vector points from the surface toward the side where the light is coming from.

Parameters:

Name Type Description Default
ray Ray

incident ray.

required

Returns:

Name Type Description
ray Ray

reflected ray.

References

[1] https://registry.khronos.org/OpenGL-Refpages/gl4/html/reflect.xhtml
[2] https://en.wikipedia.org/wiki/Snell%27s_law, "Vector form" section.

Source code in deeplens/optics/geometric_surface/base.py
def reflect(self, ray):
    """Calculate reflected ray in local coordinate system.

    Normal vector points from the surface toward the side where the light is coming from.

    Args:
        ray (Ray): incident ray.

    Returns:
        ray (Ray): reflected ray.

    References:
        [1] https://registry.khronos.org/OpenGL-Refpages/gl4/html/reflect.xhtml
        [2] https://en.wikipedia.org/wiki/Snell%27s_law, "Vector form" section.
    """
    # Compute surface normal vectors
    normal_vec = self.normal_vec(ray)

    # Reflect
    dot_product = (normal_vec * ray.d).sum(-1).unsqueeze(-1)
    new_d = ray.d - 2 * dot_product * normal_vec
    new_d = F.normalize(new_d, p=2, dim=-1)

    # Update valid rays
    valid_mask = ray.is_valid > 0
    ray.d = torch.where(valid_mask.unsqueeze(-1), new_d, ray.d)

    return ray
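The reflection formula r = d - 2(n·d)n does not depend on the sign of the normal. Flipping n flips both factors of the product, so either orientation gives the same reflected ray. The plain-Python sketch below shows this, with the final renormalization that reflect() also applies. The function name reflect_dir is hypothetical.

```python
def reflect_dir(d, n):
    """Specular reflection r = d - 2 (n . d) n for unit vectors d and n."""
    dot = sum(di * ni for di, ni in zip(d, n))
    r = [di - 2.0 * dot * ni for di, ni in zip(d, n)]
    norm = sum(v * v for v in r) ** 0.5
    return tuple(v / norm for v in r)  # renormalize, as reflect() does


print(reflect_dir((0.6, 0.0, 0.8), (0.0, 0.0, -1.0)))  # approximately (0.6, 0.0, -0.8)
```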

normal_vec

normal_vec(ray)

Calculate surface normal vector at the intersection point in local coordinate system.

Normal vector points from the surface toward the side where the light is coming from.

Parameters:

Name Type Description Default
ray Ray

input ray.

required

Returns:

Name Type Description
n_vec tensor

surface normal vector.

Source code in deeplens/optics/geometric_surface/base.py
def normal_vec(self, ray):
    """Calculate surface normal vector at the intersection point in local coordinate system.

    Normal vector points from the surface toward the side where the light is coming from.

    Args:
        ray (Ray): input ray.

    Returns:
        n_vec (tensor): surface normal vector.
    """
    x, y = ray.o[..., 0], ray.o[..., 1]
    nx, ny, nz = self.dfdxyz(x, y)
    n_vec = torch.stack((nx, ny, nz), axis=-1)
    n_vec = F.normalize(n_vec, p=2, dim=-1)

    is_forward = ray.d[..., 2].unsqueeze(-1) > 0
    n_vec = torch.where(is_forward, n_vec, -n_vec)
    return n_vec
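The gradient of f(x, y, z) = sag(x, y) - z is (dsag/dx, dsag/dy, -1). Its z-component is negative, so for a forward ray (d_z > 0) it already points back toward the incoming light. For a backward ray it is flipped. The plain-Python sketch below reproduces that logic for one point. The function surface_normal is hypothetical.

```python
def surface_normal(fx, fy, dz):
    """Unit normal from the gradient (fx, fy, -1) of sag(x, y) - z, oriented
    toward the side the ray comes from (chosen by the sign of dz)."""
    nx, ny, nz = fx, fy, -1.0
    norm = (nx * nx + ny * ny + nz * nz) ** 0.5
    n = (nx / norm, ny / norm, nz / norm)
    return n if dz > 0 else tuple(-v for v in n)


# Paraboloid z = 0.01 * (x^2 + y^2) at (0, 5): dsag/dy = 0.1
print(surface_normal(0.0, 0.1, 1.0))  # forward ray: normal has negative z
```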

to_local_coord

to_local_coord(ray)

Transform ray to local coordinate system.

When tolerancing is active, applies manufacturing error perturbations: d_error (axial shift), decenter_x/y_error (lateral shift), and tilt_error (rotation about the x-axis).

Parameters:

Name Type Description Default
ray Ray

input ray in global coordinate system.

required

Returns:

Name Type Description
ray Ray

transformed ray in local coordinate system.

Source code in deeplens/optics/geometric_surface/base.py
def to_local_coord(self, ray):
    """Transform ray to local coordinate system.

    When tolerancing is active, applies manufacturing error perturbations:
    d_error (axial shift), decenter_x/y_error (lateral shift), and
    tilt_error (rotation about the x-axis).

    Args:
        ray (Ray): input ray in global coordinate system.

    Returns:
        ray (Ray): transformed ray in local coordinate system.
    """
    # Shift ray origin to surface origin (with tolerance perturbations)
    if self.tolerancing:
        ray.o[..., 0] = ray.o[..., 0] - self.pos_x - self.decenter_x_error
        ray.o[..., 1] = ray.o[..., 1] - self.pos_y - self.decenter_y_error
        ray.o[..., 2] = ray.o[..., 2] - self.d - self.d_error
    else:
        ray.o[..., 0] = ray.o[..., 0] - self.pos_x
        ray.o[..., 1] = ray.o[..., 1] - self.pos_y
        ray.o[..., 2] = ray.o[..., 2] - self.d

    # Apply tilt rotation (tolerance-induced, using cached matrix)
    if self._R_tilt is not None:
        ray.o = self._apply_rotation(ray.o, self._R_tilt)
        ray.d = self._apply_rotation(ray.d, self._R_tilt)
        ray.d = F.normalize(ray.d, p=2, dim=-1)

    # Rotate ray origin and direction (using cached matrix for nominal orientation)
    if self._R_to_local is not None:
        ray.o = self._apply_rotation(ray.o, self._R_to_local)
        ray.d = self._apply_rotation(ray.d, self._R_to_local)
        ray.d = F.normalize(ray.d, p=2, dim=-1)

    return ray

to_global_coord

to_global_coord(ray)

Transform ray to global coordinate system.

When tolerancing is active, reverses the manufacturing error perturbations applied in to_local_coord().

Parameters:

Name Type Description Default
ray Ray

input ray in local coordinate system.

required

Returns:

Name Type Description
ray Ray

transformed ray in global coordinate system.

Source code in deeplens/optics/geometric_surface/base.py
def to_global_coord(self, ray):
    """Transform ray to global coordinate system.

    When tolerancing is active, reverses the manufacturing error
    perturbations applied in :meth:`to_local_coord`.

    Args:
        ray (Ray): input ray in local coordinate system.

    Returns:
        ray (Ray): transformed ray in global coordinate system.
    """
    # Rotate ray origin and direction (using cached matrix for nominal orientation)
    if self._R_to_global is not None:
        ray.o = self._apply_rotation(ray.o, self._R_to_global)
        ray.d = self._apply_rotation(ray.d, self._R_to_global)
        ray.d = F.normalize(ray.d, p=2, dim=-1)

    # Reverse tilt rotation (tolerance-induced, using cached inverse matrix)
    if self._R_tilt_inv is not None:
        ray.o = self._apply_rotation(ray.o, self._R_tilt_inv)
        ray.d = self._apply_rotation(ray.d, self._R_tilt_inv)
        ray.d = F.normalize(ray.d, p=2, dim=-1)

    # Shift ray origin back to global coordinates (with tolerance perturbations)
    if self.tolerancing:
        ray.o[..., 0] = ray.o[..., 0] + self.pos_x + self.decenter_x_error
        ray.o[..., 1] = ray.o[..., 1] + self.pos_y + self.decenter_y_error
        ray.o[..., 2] = ray.o[..., 2] + self.d + self.d_error
    else:
        ray.o[..., 0] = ray.o[..., 0] + self.pos_x
        ray.o[..., 1] = ray.o[..., 1] + self.pos_y
        ray.o[..., 2] = ray.o[..., 2] + self.d

    return ray
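to_local_coord and to_global_coord are inverses. The first subtracts the surface position and then rotates. The second rotates back with the inverse (transposed) rotation and then adds the position again. The plain-Python sketch below shows that round trip with a single rotation about the x-axis, standing in for the tilt. The helper names and the rotation sign convention are assumptions, not DeepLens code.

```python
import math


def rot_x(angle):
    """3x3 rotation matrix about the x-axis (hypothetical helper)."""
    c, s = math.cos(angle), math.sin(angle)
    return ((1.0, 0.0, 0.0), (0.0, c, -s), (0.0, s, c))


def apply(R, v):
    """Matrix-vector product R @ v for 3x3 tuples."""
    return tuple(sum(R[i][j] * v[j] for j in range(3)) for i in range(3))


def transpose(R):
    return tuple(tuple(R[j][i] for j in range(3)) for i in range(3))


def to_local(o, d, offset, R):
    """Shift the origin by the surface position, then rotate o and d."""
    o = tuple(oi - pi for oi, pi in zip(o, offset))
    return apply(R, o), apply(R, d)


def to_global(o, d, offset, R):
    """Undo to_local: rotate back with R^T (= R^-1), then shift back."""
    Rt = transpose(R)
    o, d = apply(Rt, o), apply(Rt, d)
    return tuple(oi + pi for oi, pi in zip(o, offset)), d


R = rot_x(math.radians(5.0))
lo, ld = to_local((1.0, 2.0, -5.0), (0.0, 0.6, 0.8), (0.1, -0.2, 10.0), R)
print(to_global(lo, ld, (0.1, -0.2, 10.0), R))  # recovers the original ray
```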

sag

sag(x, y, valid=None)

Calculate sag (z) of the surface: z = f(x, y).

The valid mask zeroes (x, y) outside the valid region before the sag is evaluated. This avoids NaN for points beyond the data range, which happens for spherical and aspherical surfaces.

Computing r = sqrt(x**2 + y**2) can cause NaN during back-propagation, because dr/dx = x / sqrt(x**2 + y**2) is undefined at x = y = 0.

Source code in deeplens/optics/geometric_surface/base.py
def sag(self, x, y, valid=None):
    """Calculate sag (z) of the surface: z = f(x, y).

    Valid term is used to avoid NaN when x, y exceed the data range, which happens in spherical and aspherical surfaces.

    Calculating r = sqrt(x**2 + y**2) may cause a NaN error during back-propagation, because dr/dx = x / sqrt(x**2 + y**2) is undefined when x = y = 0.
    """
    if valid is None:
        valid = self.is_valid(x, y)

    x, y = x * valid, y * valid
    return self._sag(x, y)
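The masking idea can be illustrated with a spherical sag written in terms of r² = x² + y². This form also sidesteps the sqrt(x² + y²) gradient problem at the origin. The plain-Python sketch below is illustrative only: the NaN-gradient behavior itself needs autograd and is not reproduced. sphere_sag and sphere_valid are hypothetical helpers, and masked points simply evaluate to the vertex sag of 0.

```python
import math


def sphere_sag(x, y, c, valid=True):
    """Sag z = c r^2 / (1 + sqrt(1 - c^2 r^2)) of a sphere with curvature c [1/mm].

    Multiplying (x, y) by the valid flag, as sag() does, keeps the square-root
    argument non-negative for points outside the sphere's data range.
    """
    x, y = x * valid, y * valid
    r2 = x * x + y * y  # uses r^2 directly, never sqrt(x^2 + y^2)
    return c * r2 / (1.0 + math.sqrt(1.0 - c * c * r2))


def sphere_valid(x, y, c):
    """True where the spherical sag is defined, i.e. c^2 (x^2 + y^2) < 1."""
    return c * c * (x * x + y * y) < 1.0


# R = 10 mm sphere: r = 20 mm is outside the data range and is masked to 0
print(sphere_sag(20.0, 0.0, 0.1, sphere_valid(20.0, 0.0, 0.1)))
```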

dfdxyz

dfdxyz(x, y, valid=None)

Compute derivatives of the surface function f(x, y, z) = sag(x, y) - z = 0. This function is used in Newton's method and in the normal vector calculation.

There are several methods to compute derivatives of surfaces:

  [1] Analytical derivatives: the current implementation is based on this method. It only works for surfaces that can be written as z = sag(x, y). For implicit surfaces, we need to compute the derivatives (df/dx, df/dy, df/dz).
  [2] Numerical derivatives: use the finite difference method. This can be used for very complex surfaces, for example NURBS, but it may suffer from numerical instability when the surface is very steep.
  [3] Automatic differentiation: use torch.autograd. This works for almost all surfaces and is accurate, but it requires an extra backward pass to compute the derivatives of the surface function.

Source code in deeplens/optics/geometric_surface/base.py
def dfdxyz(self, x, y, valid=None):
    """Compute derivatives of surface function. Surface function: f(x, y, z): sag(x, y) - z = 0. This function is used in Newton's method and normal vector calculation.

    There are several methods to compute derivatives of surfaces:
        [1] Analytical derivatives: The current implementation is based on this method. But the implementation only works for surfaces which can be written as z = sag(x, y). For implicit surfaces, we need to compute derivatives (df/dx, df/dy, df/dz).
        [2] Numerical derivatives: Use finite difference method to compute derivatives. This can be used for those very complex surfaces, for example, NURBS. But it may suffer from numerical instability when the surface is very steep.
        [3] Automatic differentiation: Use torch.autograd to compute derivatives. This can work for almost all the surfaces and is accurate, but it requires an extra backward pass to compute the derivatives of the surface function.
    """
    if valid is None:
        valid = self.is_valid(x, y)

    x, y = x * valid, y * valid
    dx, dy = self._dfdxy(x, y)
    return dx, dy, -torch.ones_like(x)
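Method [2] can be sketched with central differences and compared against the analytical derivatives of method [1]. The plain-Python example below uses a paraboloid, for which central differences are exact up to round-off. The helper names are hypothetical. Like dfdxyz, it returns -1 as the z-derivative.

```python
def parabola(x, y, c=0.02):
    """Paraboloid sag z = c/2 * (x^2 + y^2)."""
    return 0.5 * c * (x * x + y * y)


def dfdxyz_numeric(sag, x, y, h=1e-5):
    """(df/dx, df/dy, df/dz) of f = sag(x, y) - z via central differences."""
    fx = (sag(x + h, y) - sag(x - h, y)) / (2 * h)
    fy = (sag(x, y + h) - sag(x, y - h)) / (2 * h)
    return fx, fy, -1.0


# Analytical derivatives of the paraboloid at (1, 2) are (c*x, c*y, -1) = (0.02, 0.04, -1)
print(dfdxyz_numeric(parabola, 1.0, 2.0))
```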

d2fdxyz2

d2fdxyz2(x, y, valid=None)

Compute second-order partial derivatives of the surface function f(x, y, z) = sag(x, y) - z = 0. This function is currently only used for surface constraints.

Source code in deeplens/optics/geometric_surface/base.py
def d2fdxyz2(self, x, y, valid=None):
    """Compute second-order partial derivatives of the surface function f(x, y, z) = sag(x, y) - z = 0. This function is currently only used for surface constraints."""
    if valid is None:
        valid = self.is_within_data_range(x, y)

    x, y = x * valid, y * valid

    # Compute second-order derivatives of sag(x, y)
    d2f_dx2, d2f_dxdy, d2f_dy2 = self._d2fdxy(x, y)

    # Mixed partial derivatives involving z are zero
    zeros = torch.zeros_like(x)
    d2f_dxdz = zeros  # ∂²f/∂x∂z = 0
    d2f_dydz = zeros  # ∂²f/∂y∂z = 0
    d2f_dz2 = zeros  # ∂²f/∂z² = 0

    return d2f_dx2, d2f_dxdy, d2f_dy2, d2f_dxdz, d2f_dydz, d2f_dz2

is_valid

is_valid(x, y)

Boolean mask of points that are both within the data range and inside the aperture boundary of the surface.

Source code in deeplens/optics/geometric_surface/base.py
def is_valid(self, x, y):
    """Valid points within the data range and boundary of the surface."""
    return self.is_within_data_range(x, y) & self.is_within_boundary(x, y)

is_within_boundary

is_within_boundary(x, y)

Boolean mask of points inside the surface aperture (square or circular). When tolerancing is active, the circular radius includes r_error.

Source code in deeplens/optics/geometric_surface/base.py
def is_within_boundary(self, x, y):
    """Valid points within the boundary of the surface."""
    if self.is_square:
        valid = (torch.abs(x) <= (self.w / 2 + EPSILON)) & (
            torch.abs(y) <= (self.h / 2 + EPSILON)
        )
    else:
        if self.tolerancing:
            r = self.r + self.r_error
        else:
            r = self.r
        valid = (x**2 + y**2) <= (r**2 + EPSILON)

    return valid
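For a square aperture, r is the incircle radius (w = h = 2r), so a square aperture admits corner points that a circular one of the same r rejects. The plain-Python sketch below shows this. The function name within_boundary and the eps value are hypothetical stand-ins for the method and its EPSILON.

```python
def within_boundary(x, y, r, is_square=False, eps=1e-9):
    """Aperture test: square with half-width r, or circle of radius r."""
    if is_square:
        half = r  # r is the incircle radius, so w = h = 2r
        return abs(x) <= half + eps and abs(y) <= half + eps
    return x * x + y * y <= r * r + eps


print(within_boundary(0.9, 0.9, 1.0, is_square=True), within_boundary(0.9, 0.9, 1.0))  # True False
```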

is_within_data_range

is_within_data_range(x, y)

Boolean mask of points where the sag function is defined. The base class treats all points as valid.

Source code in deeplens/optics/geometric_surface/base.py
def is_within_data_range(self, x, y):
    """Valid points inside the data region of the sag function."""
    return torch.ones_like(x, dtype=torch.bool)

max_height

max_height()

Maximum valid radial height [mm]. The base class returns 10e3 (10000 mm).

Source code in deeplens/optics/geometric_surface/base.py
def max_height(self):
    """Maximum valid height."""
    return 10e3

surface_with_offset

surface_with_offset(x, y, valid_check=True)

Calculate the global z coordinate of the surface at (x, y), i.e. the sag plus the axial position d.

This function is used in lens setup plotting and lens self-intersection detection.

Source code in deeplens/optics/geometric_surface/base.py
def surface_with_offset(self, x, y, valid_check=True):
    """Calculate z coordinate of the surface at (x, y).

    This function is used in lens setup plotting and lens self-intersection detection.
    """
    x = x if torch.is_tensor(x) else torch.tensor(x).to(self.device)
    y = y if torch.is_tensor(y) else torch.tensor(y).to(self.device)
    if valid_check:
        return self.sag(x, y) + self.d
    else:
        return self._sag(x, y) + self.d

surface_sag

surface_sag(x, y)

Calculate sag of the surface at (x, y).

This function is currently not used.

Source code in deeplens/optics/geometric_surface/base.py
def surface_sag(self, x, y):
    """Calculate sag of the surface at (x, y).

    This function is currently not used.
    """
    x = x if torch.is_tensor(x) else torch.tensor(x).to(self.device)
    y = y if torch.is_tensor(y) else torch.tensor(y).to(self.device)
    return self.sag(x, y).item()

get_optimizer_params

get_optimizer_params(lrs=[0.0001], optim_mat=False)

Return optimizer parameter groups with per-parameter learning rates. Not implemented in the base class; subclasses must override it.

Parameters:

Name Type Description Default
lrs list

learning rates for different parameters.

[0.0001]
optim_mat bool

whether to optimize material. Defaults to False.

False
Source code in deeplens/optics/geometric_surface/base.py
def get_optimizer_params(self, lrs=[1e-4], optim_mat=False):
    """Get optimizer parameter groups with per-parameter learning rates.

    Args:
        lrs (list): learning rates for different parameters.
        optim_mat (bool): whether to optimize material. Defaults to False.
    """
    raise NotImplementedError(
        "get_optimizer_params() is not implemented for {}".format(
            self.__class__.__name__
        )
    )

get_optimizer

get_optimizer(lrs=[0.0001], optim_mat=False)

Get optimizer for the surface.

Source code in deeplens/optics/geometric_surface/base.py
def get_optimizer(self, lrs=[1e-4], optim_mat=False):
    """Get optimizer for the surface."""
    params = self.get_optimizer_params(lrs, optim_mat=optim_mat)
    return torch.optim.Adam(params)

update_r

update_r(r)

Update the aperture radius, clamped to max_height().

Source code in deeplens/optics/geometric_surface/base.py
def update_r(self, r):
    """Update the aperture radius, clamped to max_height()."""
    r_max = self.max_height()
    self.r = min(r, r_max)

init_tolerance

init_tolerance(tolerance_params=None)

Initialize tolerance parameters for the surface.

Parameters:

Name Type Description Default
tolerance_params dict or None

Tolerance for surface parameters. Supported keys (all optional, default values shown):

{
    "r_tole": 0.05,          # aperture radius [mm]
    "d_tole": 0.05,          # axial position [mm]
    "decenter_tole": 0.1,    # lateral decentre x & y [mm]
    "tilt_tole": 0.1,        # tilt [arcmin]
    "mat2_n_tole": 0.001,    # refractive index
}
None
References

[1] https://www.edmundoptics.com/knowledge-center/application-notes/optics/understanding-optical-specifications/
[2] https://wp.optics.arizona.edu/optomech/wp-content/uploads/sites/53/2016/08/8-Tolerancing-1.pdf
[3] https://wp.optics.arizona.edu/jsasian/wp-content/uploads/sites/33/2016/03/L17_OPTI517_Lens-_Tolerancing.pdf

Source code in deeplens/optics/geometric_surface/base.py
def init_tolerance(self, tolerance_params=None):
    """Initialize tolerance parameters for the surface.

    Args:
        tolerance_params (dict or None): Tolerance for surface parameters.
            Supported keys (all optional, default values shown):

            .. code-block:: python

                {
                    "r_tole": 0.05,          # aperture radius [mm]
                    "d_tole": 0.05,          # axial position [mm]
                    "decenter_tole": 0.1,    # lateral decentre x & y [mm]
                    "tilt_tole": 0.1,        # tilt [arcmin]
                    "mat2_n_tole": 0.001,    # refractive index
                }

    References:
        [1] https://www.edmundoptics.com/knowledge-center/application-notes/optics/understanding-optical-specifications/
        [2] https://wp.optics.arizona.edu/optomech/wp-content/uploads/sites/53/2016/08/8-Tolerancing-1.pdf
        [3] https://wp.optics.arizona.edu/jsasian/wp-content/uploads/sites/33/2016/03/L17_OPTI517_Lens-_Tolerancing.pdf
    """
    if tolerance_params is None:
        tolerance_params = {}

    # Tolerance ranges
    self.r_tole = tolerance_params.get("r_tole", 0.05)
    self.d_tole = tolerance_params.get("d_tole", 0.05)
    self.decenter_tole = tolerance_params.get("decenter_tole", 0.1)
    self.tilt_tole = tolerance_params.get("tilt_tole", 0.1)
    self.mat2_n_tole = tolerance_params.get("mat2_n_tole", 0.001)

    # Initialize error values to zero (set to random values by sample_tolerance)
    self.r_error = 0.0
    self.d_error = 0.0
    self.decenter_x_error = 0.0
    self.decenter_y_error = 0.0
    self.tilt_error = 0.0
    self.mat2_n_error = 0.0
    # Cached tilt rotation matrices (populated by sample_tolerance)
    self._R_tilt = None
    self._R_tilt_inv = None

sample_tolerance

sample_tolerance()

Sample one set of random manufacturing errors for the surface.

Error distributions
  • r_error: Uniform[-r_tole, 0] (aperture only shrinks).
  • d_error: Normal(0, d_tole) axial position shift [mm].
  • decenter_x/y_error: Normal(0, decenter_tole) lateral shift [mm].
  • tilt_error: Normal(0, tilt_tole) tilt about x-axis [arcmin → rad].
  • mat2_n_error: Normal(0, mat2_n_tole) refractive index offset.
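A minimal NumPy sketch of these distributions, written as a standalone function instead of a method. `sample_surface_errors` and its dict keys are hypothetical names for illustration, and it uses a NumPy `Generator` in place of the global `np.random` state the method itself uses:

```python
import numpy as np


def sample_surface_errors(r_tole=0.05, d_tole=0.05, decenter_tole=0.1,
                          tilt_tole=0.1, mat2_n_tole=0.001, rng=None):
    """Draw one set of manufacturing errors (hypothetical helper)."""
    rng = np.random.default_rng() if rng is None else rng
    tilt_arcmin = rng.normal(0.0, tilt_tole)  # tilt in arcminutes
    return {
        "r_error": rng.uniform(-r_tole, 0.0),                # aperture only shrinks [mm]
        "d_error": rng.normal(0.0, d_tole),                  # axial shift [mm]
        "decenter_x_error": rng.normal(0.0, decenter_tole),  # lateral shift [mm]
        "decenter_y_error": rng.normal(0.0, decenter_tole),  # lateral shift [mm]
        # arcmin -> degrees (/60) -> radians
        "tilt_error": np.deg2rad(tilt_arcmin / 60.0),
        "mat2_n_error": rng.normal(0.0, mat2_n_tole),        # refractive-index offset
    }


errors = sample_surface_errors(rng=np.random.default_rng(0))
print(errors)
```

Passing an explicit generator makes a Monte Carlo tolerance run reproducible, which is convenient when comparing designs against the same error draws.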
Source code in deeplens/optics/geometric_surface/base.py
@torch.no_grad()
def sample_tolerance(self):
    """Sample one set of random manufacturing errors for the surface.

    Error distributions:
        - r_error: Uniform[-r_tole, 0] (aperture only shrinks).
        - d_error: Normal(0, d_tole) axial position shift [mm].
        - decenter_x/y_error: Normal(0, decenter_tole) lateral shift [mm].
        - tilt_error: Normal(0, tilt_tole) tilt about x-axis [arcmin → rad].
        - mat2_n_error: Normal(0, mat2_n_tole) refractive index offset.
    """
    self.r_error = float(np.random.uniform(-self.r_tole, 0))  # [mm]
    self.d_error = float(np.random.randn() * self.d_tole)  # [mm]
    self.decenter_x_error = float(np.random.randn() * self.decenter_tole)  # [mm]
    self.decenter_y_error = float(np.random.randn() * self.decenter_tole)  # [mm]
    tilt_arcmin = float(np.random.randn() * self.tilt_tole)  # [arcmin]
    self.tilt_error = tilt_arcmin / 60.0 * np.pi / 180.0  # [rad]
    self.mat2_n_error = float(np.random.randn() * self.mat2_n_tole)

    # Cache tilt rotation matrices to avoid per-call tensor allocation
    if abs(self.tilt_error) > 1e-12:
        self._R_tilt = self._tilt_rotation_matrix(self.tilt_error, self.device)
        self._R_tilt_inv = self._tilt_rotation_matrix(-self.tilt_error, self.device)
    else:
        self._R_tilt = None
        self._R_tilt_inv = None

    self.tolerancing = True

zero_tolerance

zero_tolerance()

Reset all manufacturing errors to zero (nominal state).

Source code in deeplens/optics/geometric_surface/base.py
def zero_tolerance(self):
    """Reset all manufacturing errors to zero (nominal state)."""
    self.r_error = 0.0
    self.d_error = 0.0
    self.decenter_x_error = 0.0
    self.decenter_y_error = 0.0
    self.tilt_error = 0.0
    self.mat2_n_error = 0.0
    self._R_tilt = None
    self._R_tilt_inv = None
    self.tolerancing = False

sensitivity_score

sensitivity_score()

Compute first-order tolerance sensitivity scores via RSS formula.

For each parameter that has a gradient, the score is tolerance_range² × gradient². To first order, this approximates the variance of the loss contribution from that parameter's manufacturing error when the error has standard deviation tolerance_range.

Returns:

Name Type Description
dict

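Sensitivity gradients and scores, keyed as surf{idx}_{param}_grad and surf{idx}_{param}_score, where idx is the surface's surf_idx attribute (or id(self) if it is unset). The base class reports only the axial position d; subclasses add their own parameters.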

Reference

[1] Page 10 from: https://wp.optics.arizona.edu/optomech/wp-content/uploads/sites/53/2016/08/8-Tolerancing-1.pdf
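A small worked example of the per-parameter score and its root-sum-square (RSS) total, assuming the gradients of the loss have already been computed. The gradient values are illustrative inputs, not DeepLens output, and `rss_scores` is a hypothetical helper:

```python
import math


def rss_scores(tolerances, grads):
    """Return per-parameter scores tol**2 * grad**2 and their RSS total."""
    # Only parameters that have a gradient contribute, as in sensitivity_score().
    scores = {name: tolerances[name] ** 2 * grads[name] ** 2 for name in grads}
    # Root-sum-square: square root of the summed variance contributions.
    total = math.sqrt(sum(scores.values()))
    return scores, total


# Illustrative values: tolerances in [mm] and [1/mm], gradients dL/dparam.
tolerances = {"d": 0.05, "c": 1e-4}
grads = {"d": 2.0, "c": 300.0}
scores, total = rss_scores(tolerances, grads)
print(scores, total)
```

Here d contributes 0.05² × 2² = 0.01 and c contributes (1e-4)² × 300² = 9e-4, so the axial position dominates even though the curvature gradient is much larger.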

Source code in deeplens/optics/geometric_surface/base.py
def sensitivity_score(self):
    """Compute first-order tolerance sensitivity scores via RSS formula.

    For each parameter with a gradient, the score is:
    ``tolerance_range² × gradient²``, which approximates the variance of
    the loss contribution from that parameter's manufacturing error.

    Returns:
        dict: Sensitivity gradients and RSS scores keyed by surface index.

    Reference:
        [1] Page 10 from: https://wp.optics.arizona.edu/optomech/wp-content/uploads/sites/53/2016/08/8-Tolerancing-1.pdf
    """
    score_dict = {}
    idx = getattr(self, "surf_idx", id(self))

    if self.d.grad is not None:
        score_dict[f"surf{idx}_d_grad"] = round(self.d.grad.item(), 6)
        score_dict[f"surf{idx}_d_score"] = round(
            (self.d_tole**2 * self.d.grad**2).item(), 6
        )

    return score_dict

draw_r

draw_r()

Effective drawing radius, clamped to the valid data range.

Source code in deeplens/optics/geometric_surface/base.py
def draw_r(self):
    """Effective drawing radius, clamped to the valid data range."""
    return min(self.r, self.max_height())

draw_widget

draw_widget(ax, color='black', linestyle='solid')

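Draw the surface profile on a 2D Matplotlib axis over ±draw_r(), with the axial coordinate z on the horizontal axis and the radial coordinate on the vertical axis.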

Source code in deeplens/optics/geometric_surface/base.py
def draw_widget(self, ax, color="black", linestyle="solid"):
    """Draw widget for the surface on the 2D plot."""
    r_eff = self.draw_r()
    r = torch.linspace(-r_eff, r_eff, 128, device=self.device)
    z = self.surface_with_offset(
        r, torch.zeros(len(r), device=self.device), valid_check=False
    )
    ax.plot(
        z.cpu().detach().numpy(),
        r.cpu().detach().numpy(),
        color=color,
        linestyle=linestyle,
        linewidth=0.75,
    )

create_mesh

create_mesh(n_rings=32, n_arms=128, color=[0.06, 0.3, 0.6])

Create triangulated surface mesh.

Parameters:

Name Type Description Default
n_rings int

Number of concentric rings for sampling.

32
n_arms int

Number of angular divisions.

128
color List[float]

The color of the mesh.

[0.06, 0.3, 0.6]

Returns:

Name Type Description
self

The surface with mesh data.
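The vertex and face layout come from private helpers that are not documented here. The sketch below shows one common way to triangulate a circular aperture on a polar grid with n_rings rings and n_arms angular divisions. It assumes a center vertex plus n_rings × n_arms ring vertices and is not necessarily the layout `_create_vertices` and `_create_faces` produce. `polar_mesh` and the `sag` callable are hypothetical:

```python
import numpy as np


def polar_mesh(r, n_rings, n_arms, sag=lambda x, y: np.zeros_like(x)):
    """Triangulate a disk of radius r on a polar grid (illustrative layout)."""
    # Vertex 0 is the center; ring i (0-based) holds n_arms vertices.
    rho = np.linspace(0.0, r, n_rings + 1)[1:]                 # ring radii, outermost = r
    theta = np.linspace(0.0, 2 * np.pi, n_arms, endpoint=False)
    rr, tt = np.meshgrid(rho, theta, indexing="ij")            # shape (n_rings, n_arms)
    x = np.concatenate([[0.0], (rr * np.cos(tt)).ravel()])
    y = np.concatenate([[0.0], (rr * np.sin(tt)).ravel()])
    vertices = np.stack([x, y, sag(x, y)], axis=1)

    def vid(ring, arm):
        # Vertex index for a ring and arm; the arm index wraps around.
        return 1 + ring * n_arms + (arm % n_arms)

    faces = []
    for j in range(n_arms):                 # triangle fan around the center
        faces.append([0, vid(0, j), vid(0, j + 1)])
    for i in range(n_rings - 1):            # each quad between rings -> 2 triangles
        for j in range(n_arms):
            a, b = vid(i, j), vid(i, j + 1)
            c, d = vid(i + 1, j), vid(i + 1, j + 1)
            faces.append([a, c, b])
            faces.append([b, c, d])
    return vertices, np.array(faces, dtype=np.int64)


verts, faces = polar_mesh(r=5.0, n_rings=32, n_arms=128)
# 1 + 32 * 128 = 4097 vertices; 128 + 2 * 31 * 128 = 8064 triangles
```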

Source code in deeplens/optics/geometric_surface/base.py
def create_mesh(self, n_rings=32, n_arms=128, color=[0.06, 0.3, 0.6]):
    """Create triangulated surface mesh.

    Args:
        n_rings (int): Number of concentric rings for sampling.
        n_arms (int): Number of angular divisions.
        color (List[float]): The color of the mesh.

    Returns:
        self: The surface with mesh data.
    """
    self.vertices = self._create_vertices(n_rings, n_arms)
    self.faces = self._create_faces(n_rings, n_arms)
    self.rim = self._create_rim(n_rings, n_arms)
    self.mesh_color = color
    return self

get_polydata

get_polydata()

Get PyVista PolyData object from previously generated vertices and faces.

The PolyData object is used to render the surface and to export it as an .obj file. Call create_mesh() first, because this method reads the vertices and faces it stores.

Source code in deeplens/optics/geometric_surface/base.py
def get_polydata(self):
    """Get PyVista PolyData object from previously generated vertices and faces.

    PolyData object will be used to draw the surface and export as .obj file.
    """
    from pyvista import PolyData

    face_vertex_n = 3  # vertices per triangle
    formatted_faces = np.hstack(
        [
            face_vertex_n * np.ones((self.faces.shape[0], 1), dtype=np.uint32),
            self.faces,
        ]
    )
    return PolyData(self.vertices, formatted_faces)
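PyVista's `PolyData` takes faces as a padded array in which every face is preceded by its vertex count, so a triangle (i, j, k) is stored as 3, i, j, k. That is what the `np.hstack` call in the source above builds. A NumPy-only sketch of the same step, using two made-up triangles and no PyVista import:

```python
import numpy as np

# Two triangles sharing an edge (indices into a 4-vertex array).
faces = np.array([[0, 1, 2], [0, 2, 3]])

# Prefix each row with its vertex count (3), as get_polydata does.
counts = np.full((faces.shape[0], 1), 3, dtype=faces.dtype)
formatted = np.hstack([counts, faces])

# Flattened, the layout reads: count, indices, count, indices, ...
flat = formatted.ravel()
print(flat)  # [3 0 1 2 3 0 2 3]
```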

zmx_str

zmx_str(surf_idx, d_next)

Return Zemax surface string.

Source code in deeplens/optics/geometric_surface/base.py
def zmx_str(self, surf_idx, d_next):
    """Return Zemax surface string."""
    raise NotImplementedError(
        "zmx_str() is not implemented for {}".format(self.__class__.__name__)
    )

Spherical surface defined by curvature \(c = 1/R\).

deeplens.optics.geometric_surface.Spheric

Spheric(c, r, d, mat2, pos_xy=[0.0, 0.0], vec_local=[0.0, 0.0, 1.0], is_square=False, device='cpu')

Bases: Surface

Spherical refractive surface parameterized by curvature.

The sag function is:

\[
z(x, y) = \frac{c \rho^2}{1 + \sqrt{1 - c^2 \rho^2}}, \quad \rho^2 = x^2 + y^2
\]
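This rationalized form equals the textbook sphere sag R − sign(R)·√(R² − ρ²) with R = 1/c. Because it never divides by c, it also stays finite for a flat surface (c = 0). A short NumPy check, where `spheric_sag` is an illustrative function rather than part of DeepLens:

```python
import numpy as np


def spheric_sag(x, y, c):
    """Sag of a sphere with curvature c = 1/R [1/mm] at lateral position (x, y)."""
    rho2 = x**2 + y**2
    # No division by c, so c = 0 gives z = 0 (a plane).
    return c * rho2 / (1.0 + np.sqrt(1.0 - c**2 * rho2))


R = 50.0                         # radius of curvature [mm]
rho = np.linspace(0.0, 10.0, 5)  # radial positions [mm]
z = spheric_sag(rho, 0.0, 1.0 / R)
z_ref = R - np.sqrt(R**2 - rho**2)  # textbook form for R > 0
```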

Attributes:

Name Type Description
c Tensor

Surface curvature 1/R [1/mm], stored as a differentiable tensor so it can be optimized with gradient descent.

Initialize a spherical surface.

Parameters:

Name Type Description Default
c float

Surface curvature 1/R [1/mm]. Use 0 for a flat surface (equivalent to Plane).

required
r float

Aperture radius [mm].

required
d float

Axial vertex position [mm].

required
mat2 str or Material

Material on the transmission side.

required
pos_xy list[float]

Lateral offset [x, y] [mm]. Defaults to [0.0, 0.0].

[0.0, 0.0]
vec_local list[float]

Local normal direction. Defaults to [0.0, 0.0, 1.0].

[0.0, 0.0, 1.0]
is_square bool

Square aperture flag. Defaults to False.

False
device str

Compute device. Defaults to "cpu".

'cpu'
Source code in deeplens/optics/geometric_surface/spheric.py
def __init__(
    self,
    c,
    r,
    d,
    mat2,
    pos_xy=[0.0, 0.0],
    vec_local=[0.0, 0.0, 1.0],
    is_square=False,
    device="cpu",
):
    """Initialize a spherical surface.

    Args:
        c (float): Surface curvature ``1/R`` [1/mm].  Use ``0`` for a flat
            surface (equivalent to ``Plane``).
        r (float): Aperture radius [mm].
        d (float): Axial vertex position [mm].
        mat2 (str or Material): Material on the transmission side.
        pos_xy (list[float], optional): Lateral offset ``[x, y]`` [mm].
            Defaults to ``[0.0, 0.0]``.
        vec_local (list[float], optional): Local normal direction.
            Defaults to ``[0.0, 0.0, 1.0]``.
        is_square (bool, optional): Square aperture flag. Defaults to
            ``False``.
        device (str, optional): Compute device. Defaults to ``"cpu"``.
    """
    super(Spheric, self).__init__(
        r=r,
        d=d,
        mat2=mat2,
        pos_xy=pos_xy,
        vec_local=vec_local,
        is_square=is_square,
        device=device,
    )
    self.c = torch.tensor(c)

    self.tolerancing = False
    self.to(device)

intersect

intersect(ray, n=1.0)

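Solve the ray-surface intersection analytically in the surface's local coordinate system.

Sphere equation: \(x^2 + y^2 + (z - R)^2 = R^2\), where \(R = 1/c\).
Ray equation: \(\mathbf{p}(t) = \mathbf{o} + t\,\mathbf{d}\).
Substituting the ray into the sphere equation gives a quadratic in the intersection parameter \(t\). Of its two roots, the one whose hit point lies closer to the vertex plane \(z = 0\) is kept. A near-zero curvature is handled as a plane.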
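A NumPy sketch of the same quadratic solve for a batch of rays. It assumes unit direction vectors and a surface vertex at the local origin. `intersect_sphere` is a hypothetical helper, and it leaves out the flat-surface branch, tolerancing and OPL bookkeeping of the method itself:

```python
import numpy as np


def intersect_sphere(o, d, c, r_aperture):
    """Intersect rays o + t*d with a sphere whose vertex is at the origin.

    o, d: arrays of shape (N, 3); d is assumed normalized. Returns (t, valid).
    """
    R = 1.0 / c
    oc = o - np.array([0.0, 0.0, R])     # ray origin relative to the sphere center
    b = 2.0 * np.sum(oc * d, axis=-1)
    cc = np.sum(oc * oc, axis=-1) - R**2
    disc = b**2 - 4.0 * cc               # a = d.d = 1 for unit directions
    sq = np.sqrt(np.clip(disc, 0.0, None))
    t1, t2 = (-b - sq) / 2.0, (-b + sq) / 2.0
    # Keep the root whose hit point is closer to z = 0, i.e. on the cap near
    # the vertex rather than on the far side of the sphere.
    z1 = o[:, 2] + t1 * d[:, 2]
    z2 = o[:, 2] + t2 * d[:, 2]
    t = np.where(np.abs(z1) < np.abs(z2), t1, t2)
    p = o + t[:, None] * d
    valid = (disc >= 0) & (p[:, 0] ** 2 + p[:, 1] ** 2 <= r_aperture**2)
    return t, valid


# Two rays parallel to the axis, starting 10 mm in front of the vertex.
o = np.array([[0.0, 0.0, -10.0], [3.0, 0.0, -10.0]])
d = np.array([[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]])
t, valid = intersect_sphere(o, d, c=1.0 / 50.0, r_aperture=5.0)
```

The on-axis ray hits the vertex at t = 10, and the off-axis ray lands at the sphere's sag for ρ = 3 mm.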

Parameters:

Name Type Description Default
ray Ray

Input ray.

required
n float

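Refractive index of the medium the ray travels through before reaching the surface, used to accumulate the optical path length (OPL). Defaults to 1.0.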

1.0

Returns:

Name Type Description
ray Ray

Ray with updated position and optical path length (OPL). Rays that miss the sphere or land outside the aperture are marked invalid.

Source code in deeplens/optics/geometric_surface/spheric.py
def intersect(self, ray, n=1.0):
    """Solve ray-surface intersection in local coordinate system using analytical method.

    Sphere equation: (x)^2 + (y)^2 + (z - R)^2 = R^2, where R = 1/c
    Ray equation: p(t) = o + t*d
    Solve quadratic equation for intersection parameter t.

    Args:
        ray (Ray): input ray.
        n (float, optional): refractive index. Defaults to 1.0.

    Returns:
        ray (Ray): ray with updated position and opl.
    """
    # Tolerance
    if self.tolerancing:
        c = self.c + self.c_error
    else:
        c = self.c

    if torch.abs(c) < EPSILON:
        # Handle flat surface as a plane
        t = (0.0 - ray.o[..., 2]) / ray.d[..., 2]
        new_o = ray.o + t.unsqueeze(-1) * ray.d
        valid = (new_o[..., 0] ** 2 + new_o[..., 1] ** 2 < self.r**2) & (
            ray.is_valid > 0
        )
    else:
        R = 1.0 / c

        # Vector from ray origin to sphere center at (0, 0, R)
        oc = ray.o.clone()
        oc[..., 2] = oc[..., 2] - R

        # Quadratic equation: a*t^2 + b*t + c = 0
        # a = d·d = 1 (since ray direction is normalized)
        # b = 2*(o-center)·d
        # c = (o-center)·(o-center) - R^2

        a = torch.sum(ray.d * ray.d, dim=-1)  # Should be 1 for normalized rays
        b = 2.0 * torch.sum(oc * ray.d, dim=-1)
        c_coeff = torch.sum(oc * oc, dim=-1) - R * R

        discriminant = b * b - 4 * a * c_coeff
        valid_intersect = discriminant >= 0

        sqrt_discriminant = torch.sqrt(torch.clamp(discriminant, min=EPSILON))
        t1 = (-b - sqrt_discriminant) / (2 * a + EPSILON)
        t2 = (-b + sqrt_discriminant) / (2 * a + EPSILON)

        # Choose intersection closest to z=0 (surface vertex)
        z1 = ray.o[..., 2] + t1 * ray.d[..., 2]
        z2 = ray.o[..., 2] + t2 * ray.d[..., 2]
        use_t1 = torch.abs(z1) < torch.abs(z2)
        t = torch.where(use_t1, t1, t2)

        new_o = ray.o + t.unsqueeze(-1) * ray.d

        # Check aperture
        r_squared = new_o[..., 0] ** 2 + new_o[..., 1] ** 2
        within_aperture = r_squared <= (self.r**2 + EPSILON)

        valid = valid_intersect & within_aperture & (ray.is_valid > 0)

    # Update ray position
    ray.o = torch.where(valid.unsqueeze(-1), new_o, ray.o)
    ray.is_valid = ray.is_valid * valid

    if ray.coherent:
        if t.abs().max() > 100 and torch.get_default_dtype() == torch.float32:
            raise Exception(
                "Using float32 may cause precision problem for OPL calculation."
            )
        new_opl = ray.opl + n * t.unsqueeze(-1)
        ray.opl = torch.where(valid.unsqueeze(-1), new_opl, ray.opl)

    return ray

is_within_data_range

is_within_data_range(x, y)

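Return a boolean mask that is True where the sphere's sag is defined, i.e. x² + y² < 1/c² (using the perturbed curvature when tolerancing is enabled).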

Source code in deeplens/optics/geometric_surface/spheric.py
def is_within_data_range(self, x, y):
    """Invalid when shape is non-defined."""
    if self.tolerancing:
        c = self.c + self.c_error
    else:
        c = self.c

    valid = (x**2 + y**2) < 1 / c**2
    return valid

max_height

max_height()

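Maximum valid radial height [mm]: the sphere radius |1/c| minus a 0.001 mm margin (using the perturbed curvature when tolerancing is enabled).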

Source code in deeplens/optics/geometric_surface/spheric.py
def max_height(self):
    """Maximum valid height."""
    if self.tolerancing:
        c = self.c + self.c_error
    else:
        c = self.c

    max_height = torch.sqrt(1 / c**2).item() - 0.001
    return max_height

init_tolerance

init_tolerance(tolerance_params=None)

Initialize tolerance parameters for the surface.

Parameters:

Name Type Description Default
tolerance_params dict

Tolerance for surface parameters. Accepts the base-class keys plus c_tole, the curvature tolerance [1/mm] (default 0.0001).

None
Source code in deeplens/optics/geometric_surface/spheric.py
def init_tolerance(self, tolerance_params=None):
    """Initialize tolerance parameters for the surface.

    Args:
        tolerance_params (dict): Tolerance for surface parameters.
    """
    super().init_tolerance(tolerance_params)
    if tolerance_params is None:
        tolerance_params = {}
    self.c_tole = tolerance_params.get("c_tole", 0.0001)
    self.c_error = 0.0

sample_tolerance

sample_tolerance()

Randomly perturb surface parameters to simulate manufacturing errors: sample the base-class errors, then draw the curvature error c_error from Normal(0, c_tole).

Source code in deeplens/optics/geometric_surface/spheric.py
def sample_tolerance(self):
    """Randomly perturb surface parameters to simulate manufacturing errors."""
    super().sample_tolerance()
    self.c_error = float(np.random.randn() * self.c_tole)

zero_tolerance

zero_tolerance()

Reset all manufacturing errors, including the curvature error c_error, to zero.

Source code in deeplens/optics/geometric_surface/spheric.py
def zero_tolerance(self):
    """Zero tolerance."""
    super().zero_tolerance()
    self.c_error = 0.0

sensitivity_score

sensitivity_score()

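Extend the base-class sensitivity scores with the curvature terms surf{idx}_c_grad and surf{idx}_c_score = c_tole² × c.grad².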

Source code in deeplens/optics/geometric_surface/spheric.py
def sensitivity_score(self):
    """Tolerance squared sum."""
    score_dict = super().sensitivity_score()
    if self.c.grad is not None:
        idx = getattr(self, "surf_idx", id(self))
        score_dict[f"surf{idx}_c_grad"] = round(self.c.grad.item(), 6)
        score_dict[f"surf{idx}_c_score"] = round(
            (self.c_tole**2 * self.c.grad**2).item(), 6
        )
    return score_dict

get_optimizer_params

get_optimizer_params(lrs=[0.0001, 0.0001], optim_mat=False)

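Enable gradients for d and c and return optimizer parameter groups, using lrs[0] for d and lrs[1] for c. If optim_mat is True and the material is not air, the material's parameters are appended.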

Source code in deeplens/optics/geometric_surface/spheric.py
def get_optimizer_params(self, lrs=[1e-4, 1e-4], optim_mat=False):
    """Activate gradient computation for c and d and return optimizer parameters."""
    self.c.requires_grad_(True)
    self.d.requires_grad_(True)

    params = []
    params.append({"params": [self.d], "lr": lrs[0]})
    params.append({"params": [self.c], "lr": lrs[1]})

    if optim_mat and self.mat2.get_name() != "air":
        params += self.mat2.get_optimizer_params()

    return params

surf_dict

surf_dict()

Return surface parameters.

Source code in deeplens/optics/geometric_surface/spheric.py
def surf_dict(self):
    """Return surface parameters."""
    roc = 1 / self.c.item() if self.c.item() != 0 else 0.0
    surf_dict = {
        "type": "Spheric",
        "r": round(self.r, 4),
        "(c)": round(self.c.item(), 4),
        "roc": round(roc, 4),
        "(d)": round(self.d.item(), 4),
        "mat2": self.mat2.get_name(),
    }

    return surf_dict

zmx_str

zmx_str(surf_idx, d_next)

Return Zemax surface string.

Source code in deeplens/optics/geometric_surface/spheric.py
    def zmx_str(self, surf_idx, d_next):
        """Return Zemax surface string."""
        if self.mat2.get_name() == "air":
            zmx_str = f"""SURF {surf_idx} 
    TYPE STANDARD 
    CURV {self.c.item()} 
    DISZ {d_next.item()} 
    DIAM {self.r} 1 0 0 1 ""
"""
        else:
            zmx_str = f"""SURF {surf_idx} 
    TYPE STANDARD 
    CURV {self.c.item()} 
    DISZ {d_next.item()} 
    GLAS ___BLANK 1 0 {self.mat2.n} {self.mat2.V}
    DIAM {self.r} 1 0 0 1 ""
"""
        return zmx_str

Even-asphere surface: conic base with even-order polynomial corrections.

deeplens.optics.geometric_surface.Aspheric

Aspheric(r, d, c, k, ai, mat2, ai2=None, pos_xy=[0.0, 0.0], vec_local=[0.0, 0.0, 1.0], is_square=False, device='cpu')

Bases: Surface

Even-order aspheric surface.

The sag function is:

\[
z(\rho) = \frac{c\,\rho^2}{1 + \sqrt{1-(1+k)c^2\rho^2}} + \sum_{i=2}^{n} a_{2i}\,\rho^{2i}, \quad \rho^2 = x^2 + y^2
\]

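The polynomial starts at the 4th-order term (a4) because a 2nd-order term a2·ρ² is nearly redundant with the base curvature c: near the axis the conic sag is approximately cρ²/2, so both terms change the surface in almost the same way.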

All coefficients c, k, and ai are differentiable torch tensors so they can be optimised with gradient descent.
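A NumPy sketch of the sag formula, assuming a radial coordinate array and coefficients ordered [a4, a6, ...]. `aspheric_sag` and `conic_domain` are illustrative names, not DeepLens functions. The second helper states the validity condition behind is_within_data_range (the square root must stay real) without dividing by c:

```python
import numpy as np


def aspheric_sag(rho, c, k, ai=()):
    """Even-asphere sag with ai = [a4, a6, a8, ...] (no a2 term)."""
    rho2 = rho**2
    z = c * rho2 / (1.0 + np.sqrt(1.0 - (1.0 + k) * c**2 * rho2))
    # ai[0] multiplies rho^4, ai[1] rho^6, ...: a_{2i} * rho^{2i} for i >= 2
    for i, a in enumerate(ai):
        z = z + a * rho2 ** (i + 2)
    return z


def conic_domain(rho, c, k):
    """True where the conic square root is real."""
    # (1 + k) c^2 rho^2 < 1 matches rho^2 < 1 / ((1 + k) c^2) for k > -1
    # and is always true for k <= -1.
    return (1.0 + k) * c**2 * rho**2 < 1.0


rho = np.linspace(0.0, 4.0, 9)                    # radial positions [mm]
c = 1.0 / 20.0                                    # R = 20 mm
z_sphere = aspheric_sag(rho, c, k=0.0)            # k = 0: plain sphere
z_parab = aspheric_sag(rho, c, k=-1.0)            # k = -1: paraboloid, c * rho^2 / 2
z_asph = aspheric_sag(rho, c, k=-1.0, ai=[1e-4])  # adds 1e-4 * rho^4
```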

Attributes:

Name Type Description
c Tensor

Base curvature [1/mm].

k Tensor

Conic constant.

ai2 Tensor or None

2nd-order aspheric coefficient (legacy).

ai Tensor

Even-order aspheric coefficients [a4, a6, a8, ...].

Initialize an aspheric surface.

Parameters:

Name Type Description Default
r float

Aperture radius [mm].

required
d float

Axial vertex position [mm].

required
c float

Base curvature 1/R [1/mm].

required
k float

Conic constant (0 = sphere, -1 = paraboloid).

required
ai list[float] or None

Even-order aspheric coefficients starting from the 4th-order term: [a4, a6, a8, ...]. Pass None or an empty list for a pure conic.

required
mat2 str or Material

Material on the transmission side.

required
ai2 float or None

2nd-order aspheric coefficient from legacy data. Included in sag but not optimised. Defaults to None.

None
pos_xy list[float]

Lateral offset [x, y] [mm]. Defaults to [0.0, 0.0].

[0.0, 0.0]
vec_local list[float]

Local normal direction. Defaults to [0.0, 0.0, 1.0].

[0.0, 0.0, 1.0]
is_square bool

Square aperture flag. Defaults to False.

False
device str

Compute device. Defaults to "cpu".

'cpu'
Source code in deeplens/optics/geometric_surface/aspheric.py
def __init__(
    self,
    r,
    d,
    c,
    k,
    ai,
    mat2,
    ai2=None,
    pos_xy=[0.0, 0.0],
    vec_local=[0.0, 0.0, 1.0],
    is_square=False,
    device="cpu",
):
    """Initialize an aspheric surface.

    Args:
        r (float): Aperture radius [mm].
        d (float): Axial vertex position [mm].
        c (float): Base curvature ``1/R`` [1/mm].
        k (float): Conic constant (``0`` = sphere, ``-1`` = paraboloid).
        ai (list[float] or None): Even-order aspheric coefficients
            starting from the 4th-order term: ``[a4, a6, a8, ...]``.
            Pass ``None`` or an empty list for a pure conic.
        mat2 (str or Material): Material on the transmission side.
        ai2 (float or None, optional): 2nd-order aspheric coefficient
            from legacy data.  Included in sag but not optimised.
            Defaults to ``None``.
        pos_xy (list[float], optional): Lateral offset ``[x, y]`` [mm].
            Defaults to ``[0.0, 0.0]``.
        vec_local (list[float], optional): Local normal direction.
            Defaults to ``[0.0, 0.0, 1.0]``.
        is_square (bool, optional): Square aperture flag.
            Defaults to ``False``.
        device (str, optional): Compute device. Defaults to ``"cpu"``.
    """
    Surface.__init__(
        self,
        r=r,
        d=d,
        mat2=mat2,
        pos_xy=pos_xy,
        vec_local=vec_local,
        is_square=is_square,
        device=device,
    )

    self.c = torch.tensor(c)
    self.k = torch.tensor(k)

    # 2nd-order coefficient (legacy, not optimised)
    if ai2 is not None:
        self.ai2 = torch.tensor(float(ai2))
    else:
        self.ai2 = None

    if ai is not None and len(ai) > 0:
        self.ai = torch.tensor(ai)
        self.ai_degree = len(ai)
        # ai[0] -> ai4, ai[1] -> ai6, ai[2] -> ai8, ...
        for i, a in enumerate(ai):
            setattr(self, f"ai{2 * (i + 2)}", torch.tensor(a))
    else:
        self.ai = None
        self.ai_degree = 0

    self.tolerancing = False
    self.to(device)

is_within_data_range

is_within_data_range(x, y)

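Return a boolean mask that is True where the conic term is defined, i.e. x² + y² < 1/((1 + k)c²) when k > -1; for k ≤ -1 every point is valid.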

Source code in deeplens/optics/geometric_surface/aspheric.py
def is_within_data_range(self, x, y):
    """Invalid when shape is non-defined."""
    c, k = self._get_curvature_params()
    if k > -1:
        return (x**2 + y**2) < 1 / c**2 / (1 + k)
    return torch.ones_like(x, dtype=torch.bool)

max_height

max_height()

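Maximum valid radial height [mm]: sqrt(1/((1 + k)c²)) minus a 0.001 mm margin when k > -1, otherwise a fixed cap of 10,000 mm (10e3 in the code).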

Source code in deeplens/optics/geometric_surface/aspheric.py
def max_height(self):
    """Maximum valid height."""
    c, k = self._get_curvature_params()
    if k > -1:
        return torch.sqrt(1 / (k + 1) / (c**2)).item() - 0.001
    return 10e3

get_optimizer_params

get_optimizer_params(lrs=[0.0001, 0.0001, 0.01, 0.0001], optim_mat=False)

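Enable gradients for d, c, k and the aspheric coefficients, and return optimizer parameter groups with per-parameter learning rates.

The learning rate for each aspheric coefficient a_{2n} is scaled by 1 / max(r, 1)^{2n}. Adam changes each parameter by roughly lr per step regardless of gradient magnitude, so an unscaled step in a_{2n} would move the sag at the rim by about lr · r^{2n}. With the scaling, the effective sag perturbation per Adam step stays approximately constant (~lr_base mm) regardless of surface semi-diameter. Without this normalisation, gradients scale as O(r^{2n}) and can reach 10^5 for camera-sized surfaces, causing NaN within a few dozen iterations.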
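The scaling rule is easy to check numerically. A minimal sketch, assuming the same [a4, a6, ...] indexing as the code below; `asphere_learning_rates` is a hypothetical helper, not part of DeepLens:

```python
def asphere_learning_rates(r, lr_base=1e-4, ai_degree=3):
    """Per-order learning rates lr_base / max(r, 1)**order, order = 4, 6, 8, ..."""
    r_norm = max(r, 1.0)
    return {2 * (i + 2): lr_base / r_norm ** (2 * (i + 2)) for i in range(ai_degree)}


r = 8.0  # semi-diameter [mm]
lrs = asphere_learning_rates(r)
for order, lr in lrs.items():
    # A step of roughly lr in a_{order} moves the rim sag by about lr * r**order,
    # which comes out to lr_base for every order.
    print(order, lr, lr * r**order)
```

Surfaces smaller than 1 mm are not scaled up, because max(r, 1) keeps the divisor at least 1.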

Parameters:

Name Type Description Default
lrs list

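Learning rates for [d, c, k, ai]. The fourth entry is the base rate for the aspheric coefficients before r-scaling; if lrs has fewer than four entries, 1e-4 is used.

[0.0001, 0.0001, 0.01, 0.0001]
optim_mat bool

Whether to optimize the material. Defaults to False.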

False
Source code in deeplens/optics/geometric_surface/aspheric.py
def get_optimizer_params(self, lrs=[1e-4, 1e-4, 1e-2, 1e-4], optim_mat=False):
    """Get optimizer parameters for different parameters.

    The learning rate for each aspheric coefficient ``a_{2n}`` is scaled
    by ``1 / max(r, 1)^{2n}`` so that the effective sag perturbation per
    Adam step is approximately constant (~lr_base mm) regardless of
    surface semi-diameter.  Without this normalisation, gradients scale
    as ``O(r^{2n})`` and can reach ``10^5`` for camera-sized surfaces,
    causing NaN within a few dozen iterations.

    Args:
        lrs (list, optional): learning rates for ``[d, c, k, ai]``.
        optim_mat (bool, optional): whether to optimize material.
            Defaults to False.
    """
    params = []

    # Optimize distance
    self.d.requires_grad_(True)
    params.append({"params": [self.d], "lr": lrs[0]})

    # Optimize curvature
    self.c.requires_grad_(True)
    params.append({"params": [self.c], "lr": lrs[1]})

    # Optimize conic constant
    self.k.requires_grad_(True)
    params.append({"params": [self.k], "lr": lrs[2]})

    # Optimize aspheric coefficients with r-normalised learning rates.
    # Gradient of sag w.r.t. a_{2n} scales as r^{2n}.  Dividing the lr
    # by r^{2n} keeps the effective sag change per step ≈ lr_base,
    # so every order contributes equally to surface shape evolution.
    if self.ai is not None:
        if self.ai_degree > 0:
            r_norm = max(self.r, 1.0)
            lr_base = lrs[3] if len(lrs) > 3 else 1e-4
            for i in range(self.ai_degree):
                p_name = f"ai{2 * (i + 2)}"
                p = getattr(self, p_name)
                p.requires_grad_(True)
                order = 2 * (i + 2)  # 4, 6, 8, 10, ...
                lr_ai = lr_base / r_norm**order
                params.append({"params": [p], "lr": lr_ai})

    # Optimize material parameters
    if optim_mat and self.mat2.get_name() != "air":
        params += self.mat2.get_optimizer_params()

    return params

init_tolerance

init_tolerance(tolerance_params=None)

Initialize tolerance parameters for the surface.

Parameters:

Name Type Description Default
tolerance_params dict

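Tolerance for surface parameters. Accepts the base-class keys plus c_tole (curvature, default 0.001) and k_tole (conic constant, default 0.001).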

None
References

[1] https://www.edmundoptics.com/capabilities/precision-optics/capabilities/aspheric-lenses/
[2] https://www.edmundoptics.com/knowledge-center/application-notes/optics/all-about-aspheric-lenses/?srsltid=AfmBOoon8AUXVALojol2s5K20gQk7W1qUisc6cE4WzZp3ATFY5T1pK8q

Source code in deeplens/optics/geometric_surface/aspheric.py
@torch.no_grad()
def init_tolerance(self, tolerance_params=None):
    """Perturb the surface with some tolerance.

    Args:
        tolerance_params (dict): Tolerance for surface parameters.

    References:
        [1] https://www.edmundoptics.com/capabilities/precision-optics/capabilities/aspheric-lenses/
        [2] https://www.edmundoptics.com/knowledge-center/application-notes/optics/all-about-aspheric-lenses/?srsltid=AfmBOoon8AUXVALojol2s5K20gQk7W1qUisc6cE4WzZp3ATFY5T1pK8q
    """
    super().init_tolerance(tolerance_params)
    if tolerance_params is None:
        tolerance_params = {}
    self.c_tole = tolerance_params.get("c_tole", 0.001)
    self.k_tole = tolerance_params.get("k_tole", 0.001)
    self.c_error = 0.0
    self.k_error = 0.0

sample_tolerance

sample_tolerance()

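Randomly perturb surface parameters to simulate manufacturing errors: sample the base-class errors, then draw c_error from Normal(0, c_tole) and k_error from Normal(0, k_tole).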

Source code in deeplens/optics/geometric_surface/aspheric.py
def sample_tolerance(self):
    """Randomly perturb surface parameters to simulate manufacturing errors."""
    super().sample_tolerance()
    self.c_error = float(np.random.randn() * self.c_tole)
    self.k_error = float(np.random.randn() * self.k_tole)

zero_tolerance

zero_tolerance()

Reset all manufacturing errors, including c_error and k_error, to zero.

Source code in deeplens/optics/geometric_surface/aspheric.py
def zero_tolerance(self):
    """Zero tolerance."""
    super().zero_tolerance()
    self.c_error = 0.0
    self.k_error = 0.0

sensitivity_score

sensitivity_score()

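Extend the base-class sensitivity scores with curvature (c) and conic-constant (k) terms, each computed as tolerance² × gradient².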

Source code in deeplens/optics/geometric_surface/aspheric.py
def sensitivity_score(self):
    """Tolerance squared sum."""
    score_dict = super().sensitivity_score()
    idx = getattr(self, "surf_idx", id(self))

    if self.c.grad is not None:
        score_dict[f"surf{idx}_c_grad"] = round(self.c.grad.item(), 6)
        score_dict[f"surf{idx}_c_score"] = round(
            (self.c_tole**2 * self.c.grad**2).item(), 6
        )

    if self.k.grad is not None:
        score_dict[f"surf{idx}_k_grad"] = round(self.k.grad.item(), 6)
        score_dict[f"surf{idx}_k_score"] = round(
            (self.k_tole**2 * self.k.grad**2).item(), 6
        )
    return score_dict

surf_dict

surf_dict()

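Return the surface parameters as a dict. Aspheric coefficients appear both as the list "ai" (with a2 prepended when use_ai2 is True) and as individual "(ai4)", "(ai6)", ... entries.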

Source code in deeplens/optics/geometric_surface/aspheric.py
def surf_dict(self):
    """Return a dict of surface."""
    has_ai2 = self.ai2 is not None
    surf_dict = {
        "type": "Aspheric",
        "r": round(self.r, 4),
        "(c)": round(self.c.item(), 4),
        "roc": round(1 / self.c.item(), 4),
        "d": round(self.d.item(), 4),
        "k": round(self.k.item(), 4),
        "ai": [],
        "use_ai2": has_ai2,
        "mat2": self.mat2.get_name(),
    }

    # Prepend a2 to ai list if present (ai2 key is informational;
    # deserialization reads ai[0] when use_ai2=True)
    if has_ai2:
        surf_dict["ai2"] = float(format(self.ai2.item(), ".6e"))
        surf_dict["ai"].append(float(format(self.ai2.item(), ".6e")))

    for i in range(self.ai_degree):
        order = i + 2
        coeff = getattr(self, f"ai{2 * order}")
        surf_dict[f"(ai{2 * order})"] = float(format(coeff.item(), ".6e"))
        surf_dict["ai"].append(float(format(coeff.item(), ".6e")))

    return surf_dict

zmx_str

zmx_str(surf_idx, d_next)

Return Zemax surface string.

Source code in deeplens/optics/geometric_surface/aspheric.py
    def zmx_str(self, surf_idx, d_next):
        """Return Zemax surface string."""
        assert self.c.item() != 0, (
            "Aperture surface is re-implemented in Aperture class."
        )
        assert self.ai is not None or self.k != 0, (
            "Spheric surface is re-implemented in Spheric class."
        )

        # Collect absolute ai values, PARM 1 = a2, PARM 2+ = a4, a6, ...
        abs_ai = [self.ai2.item() if self.ai2 is not None else 0.0]
        for i in range(self.ai_degree):
            abs_ai.append(getattr(self, f"ai{2 * (i + 2)}").item())

        # Pad with zeros for Zemax PARM format (needs 6 PARMs)
        while len(abs_ai) < 6:
            abs_ai.append(0.0)

        if self.mat2.get_name() == "air":
            zmx_str = f"""SURF {surf_idx}
    TYPE EVENASPH
    CURV {self.c.item()}
    DISZ {d_next.item()}
    DIAM {self.r} 1 0 0 1 ""
    CONI {self.k}
    PARM 1 {abs_ai[0]}
    PARM 2 {abs_ai[1]}
    PARM 3 {abs_ai[2]}
    PARM 4 {abs_ai[3]}
    PARM 5 {abs_ai[4]}
    PARM 6 {abs_ai[5]}
"""
        else:
            zmx_str = f"""SURF {surf_idx}
    TYPE EVENASPH
    CURV {self.c.item()}
    DISZ {d_next.item()}
    GLAS ___BLANK 1 0 {self.mat2.n} {self.mat2.V}
    DIAM {self.r} 1 0 0 1 ""
    CONI {self.k}
    PARM 1 {abs_ai[0]}
    PARM 2 {abs_ai[1]}
    PARM 3 {abs_ai[2]}
    PARM 4 {abs_ai[3]}
    PARM 5 {abs_ai[4]}
    PARM 6 {abs_ai[5]}
"""
        return zmx_str

deeplens.optics.geometric_surface.Aperture

Aperture(r, d, pos_xy=[0.0, 0.0], vec_local=[0.0, 0.0, 1.0], is_square=False, device='cpu')

Bases: Plane

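Aperture (stop) surface: a flat plane with air on both sides that blocks rays outside radius r without refracting them.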

Source code in deeplens/optics/geometric_surface/aperture.py
def __init__(
    self,
    r,
    d,
    pos_xy=[0.0, 0.0],
    vec_local=[0.0, 0.0, 1.0],
    is_square=False,
    device="cpu",
):
    """Aperture surface."""
    Plane.__init__(
        self,
        r=r,
        d=d,
        mat2="air",
        pos_xy=pos_xy,
        vec_local=vec_local,
        is_square=is_square,
        device=device,
    )
    self.tolerancing = False
    self.to(device)

ray_reaction

ray_reaction(ray, n1=1.0, n2=1.0, refraction=False)

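Intersect rays with the aperture plane and mark rays outside the aperture as invalid. No refraction is applied; n1, n2 and refraction are accepted but ignored.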
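A NumPy sketch of what a stop does to a ray bundle, assuming rays given as origin/direction arrays in local coordinates and a circular aperture. `aperture_clip` is a hypothetical helper; the actual method works on DeepLens Ray objects and also converts to and from local coordinates:

```python
import numpy as np


def aperture_clip(o, d, z_stop, r_stop, is_valid):
    """Propagate rays to the plane z = z_stop and mask those outside r_stop."""
    t = (z_stop - o[:, 2]) / d[:, 2]   # distance along each ray to the plane
    p = o + t[:, None] * d             # hit points on the stop plane
    inside = p[:, 0] ** 2 + p[:, 1] ** 2 <= r_stop**2
    # Directions are returned unchanged: the stop blocks light but does not refract it.
    return p, d, is_valid & inside


o = np.array([[0.0, 0.0, 0.0], [0.0, 4.0, 0.0]])
d = np.array([[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]])
p, d_out, valid = aperture_clip(o, d, z_stop=5.0, r_stop=2.0,
                                is_valid=np.array([True, True]))
```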

Source code in deeplens/optics/geometric_surface/aperture.py
def ray_reaction(self, ray, n1=1.0, n2=1.0, refraction=False):
    """Compute output ray after intersection and refraction."""
    ray = self.to_local_coord(ray)
    ray = self.intersect(ray)
    ray = self.to_global_coord(ray)
    return ray

draw_widget

draw_widget(ax, color='orange', linestyle='solid')

Draw the aperture stop on a 2D plot as wedge markers at ±r.

Source code in deeplens/optics/geometric_surface/aperture.py
def draw_widget(self, ax, color="orange", linestyle="solid"):
    """Draw aperture wedge on the figure."""
    d = self.d.item()
    aper_wedge_l = 0.05 * self.r  # [mm]
    aper_wedge_h = 0.15 * self.r  # [mm]

    # Parallel edges
    z = np.linspace(d - aper_wedge_l, d + aper_wedge_l, 3)
    x = -self.r * np.ones(3)
    ax.plot(z, x, color=color, linestyle=linestyle, linewidth=0.8)
    x = self.r * np.ones(3)
    ax.plot(z, x, color=color, linestyle=linestyle, linewidth=0.8)

    # Vertical edges
    z = d * np.ones(3)
    x = np.linspace(self.r, self.r + aper_wedge_h, 3)
    ax.plot(z, x, color=color, linestyle=linestyle, linewidth=0.8)
    x = np.linspace(-self.r - aper_wedge_h, -self.r, 3)
    ax.plot(z, x, color=color, linestyle=linestyle, linewidth=0.8)

draw_widget3D

draw_widget3D(ax, color='black')

Draw the aperture as a circle in a 3D plot.

Source code in deeplens/optics/geometric_surface/aperture.py
def draw_widget3D(self, ax, color="black"):
    """Draw the aperture as a circle in a 3D plot."""
    # Draw the edge circle
    theta = np.linspace(0, 2 * np.pi, 100)
    edge_x = self.r * np.cos(theta)
    edge_y = self.r * np.sin(theta)
    edge_z = np.full_like(edge_x, self.d.item())  # Constant z at aperture position

    # Plot the edge circle
    line = ax.plot(edge_z, edge_x, edge_y, color=color, linewidth=1.5)

    return line

create_mesh

create_mesh(n_rings=32, n_arms=128, color=[0.0, 0.0, 0.0])

Create triangulated surface mesh.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_rings | int | Number of concentric rings for sampling. | 32 |
| n_arms | int | Number of angular divisions. | 128 |
| color | List[float] | The color of the mesh. | [0.0, 0.0, 0.0] |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| self | | The surface with mesh data. |

Source code in deeplens/optics/geometric_surface/aperture.py
def create_mesh(self, n_rings=32, n_arms=128, color=[0.0, 0.0, 0.0]):
    """Create triangulated surface mesh.

    Args:
        n_rings (int): Number of concentric rings for sampling.
        n_arms (int): Number of angular divisions.
        color (List[float]): The color of the mesh.

    Returns:
        self: The surface with mesh data.
    """
    self.vertices = self._create_vertices(n_rings, n_arms)
    self.faces = self._create_faces(n_rings, n_arms)
    self.rim = self._create_rim(n_rings, n_arms)
    self.mesh_color = color
    return self

get_optimizer_params

get_optimizer_params(lrs=[0.0001])

Enable gradient computation for the aperture position d and return it as a single optimizer parameter group with learning rate lrs[0].

Source code in deeplens/optics/geometric_surface/aperture.py
def get_optimizer_params(self, lrs=[1e-4]):
    """Activate gradient computation for d and return optimizer parameters."""
    self.d.requires_grad_(True)

    params = []
    params.append({"params": [self.d], "lr": lrs[0]})

    return params

surf_dict

surf_dict()

Dict of surface parameters.

Source code in deeplens/optics/geometric_surface/aperture.py
def surf_dict(self):
    """Dict of surface parameters."""
    surf_dict = {
        "type": "Aperture",
        "r": round(self.r, 4),
        "(d)": round(self.d.item(), 4),
        "mat2": "air",
        "is_square": self.is_square,
    }
    return surf_dict

zmx_str

zmx_str(surf_idx, d_next)

Zemax surface string.

Source code in deeplens/optics/geometric_surface/aperture.py
def zmx_str(self, surf_idx, d_next):
    """Zemax surface string."""
    zmx_str = f"""SURF {surf_idx}
    STOP
    TYPE STANDARD
    CURV 0.0
    DISZ {d_next.item()}
"""
    return zmx_str
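To see exactly what text the aperture contributes to a Zemax file, the template can be reproduced as a standalone function. This sketch assumes d_next is a plain float (the method itself calls d_next.item() on a tensor), and aperture_zmx_str is a hypothetical name.

```python
def aperture_zmx_str(surf_idx: int, d_next: float) -> str:
    """Reproduce the Zemax STOP-surface block written by Aperture.zmx_str."""
    return (
        f"SURF {surf_idx}\n"
        "    STOP\n"            # marks this surface as the aperture stop
        "    TYPE STANDARD\n"
        "    CURV 0.0\n"        # a flat surface
        f"    DISZ {d_next}\n"  # distance to the next surface [mm]
    )


print(aperture_zmx_str(3, 1.25))
```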

Light Representations

Geometric ray representation carrying origin, direction, wavelength, validity mask, energy, and optical path length (OPL).

deeplens.optics.Ray

Ray(o, d, wvln=DEFAULT_WAVE, coherent=False, device='cpu')

Bases: DeepObj

Batched ray bundle for optical simulation.

Stores ray origins, directions, wavelength, validity mask, energy, obliquity, and (in coherent mode) optical path length. All per-ray tensors share the leading shape (*batch_size, num_rays); wvln is a scalar tensor.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| o | Tensor | Ray origins, shape (*batch, num_rays, 3) [mm]. |
| d | Tensor | Unit ray directions, shape (*batch, num_rays, 3). |
| wvln | Tensor | Wavelength scalar [µm]. |
| is_valid | Tensor | Binary validity mask, shape (*batch, num_rays). |
| en | Tensor | Energy weight, shape (*batch, num_rays, 1). |
| obliq | Tensor | Obliquity factor, shape (*batch, num_rays, 1). |
| opl | Tensor | Optical path length, shape (*batch, num_rays, 1) [mm]. Always allocated, but only accumulated when coherent is True. |
| coherent | bool | Whether OPL tracking is enabled. |

Initialize a ray object.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| o | Tensor | Ray origin, shape (..., num_rays, 3) [mm]. | required |
| d | Tensor | Ray direction, shape (..., num_rays, 3). Normalized to unit length on construction. | required |
| wvln | float | Ray wavelength [µm]; must lie in (0.1, 10). | DEFAULT_WAVE |
| coherent | bool | Enable optical path length tracking for coherent tracing. | False |
| device | str | Compute device. | 'cpu' |
Source code in deeplens/optics/light/ray.py
def __init__(self, o, d, wvln=DEFAULT_WAVE, coherent=False, device="cpu"):
    """Initialize a ray object.

    Args:
        o (torch.Tensor): Ray origin, shape ``(..., num_rays, 3)`` [mm].
        d (torch.Tensor): Ray direction, shape ``(..., num_rays, 3)``.
        wvln (float): Ray wavelength [µm].
        coherent (bool): Enable optical path length tracking for coherent
            tracing. Defaults to ``False``.
        device (str): Compute device. Defaults to ``"cpu"``.
    """
    # Basic ray parameters - move to device
    self.o = (o if torch.is_tensor(o) else torch.tensor(o)).to(device)
    self.d = (d if torch.is_tensor(d) else torch.tensor(d)).to(device)
    self.shape = self.o.shape[:-1]

    # Wavelength
    assert wvln > 0.1 and wvln < 10.0, "Ray wavelength unit should be [um]"
    self.wvln = torch.tensor(wvln, device=device)

    # Auxiliary ray parameters - create directly on device
    self.is_valid = torch.ones(self.shape, device=device)
    self.en = torch.ones((*self.shape, 1), device=device)
    self.obliq = torch.ones((*self.shape, 1), device=device)

    # Coherent ray tracing
    self.coherent = coherent  # bool
    self.opl = torch.zeros((*self.shape, 1), device=device)

    self.device = device
    self.d = F.normalize(self.d, p=2, dim=-1)

prop_to

prop_to(z, n=1.0)

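Propagate rays to the plane at axial position z. Only valid rays (is_valid > 0) are moved; each travels the geometric distance t = (z - o_z) / d_z along its unit direction. In coherent mode the optical path length also grows by n·t. Coherent tracing requires float64: with lower precision, a Warning is raised as an exception after the origins have already been updated, and opl is left unchanged.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| z | float | Axial position of the target plane [mm]. | required |
| n | float | Refractive index of the medium traversed. | 1.0 |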
Source code in deeplens/optics/light/ray.py
def prop_to(self, z, n=1.0):
    """Ray propagates to a given depth plane.

    Args:
        z (float): depth.
        n (float, optional): refractive index. Defaults to 1.
    """
    t = (z - self.o[..., 2]) / self.d[..., 2]
    new_o = self.o + self.d * t.unsqueeze(-1)
    valid_mask = (self.is_valid > 0).unsqueeze(-1)
    self.o = torch.where(valid_mask, new_o, self.o)

    if self.coherent:
        if t.dtype != torch.float64:
            raise Warning("Should use float64 in coherent ray tracing.")
        else:
            new_opl = self.opl + n * t.unsqueeze(-1)
            self.opl = torch.where(valid_mask, new_opl, self.opl)

    return self
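For a single ray, the update in prop_to reduces to a few lines of arithmetic. This plain-Python sketch (prop_to_single is a hypothetical name) mirrors the masking and the OPL accumulation. Because Ray normalizes d, t is the geometric distance travelled, so n·t is the optical path added.

```python
def prop_to_single(o, d, z, n=1.0, opl=0.0, valid=True):
    """Propagate one ray to the plane z (sketch of Ray.prop_to).

    Returns the new origin and the accumulated optical path length (OPL).
    Invalid rays are left untouched, mirroring the torch.where masking.
    """
    if not valid:
        return o, opl
    t = (z - o[2]) / d[2]          # geometric length along the unit direction
    new_o = tuple(oi + t * di for oi, di in zip(o, d))
    return new_o, opl + n * t      # OPL grows by n times the geometric length


o2, opl = prop_to_single((0.0, 0.0, 0.0), (0.6, 0.0, 0.8), z=8.0, n=1.5)
```

Here the ray travels t = 10 mm to reach z = 8 mm, so it lands at x = 6 mm and accumulates an optical path of 15 mm in a medium with n = 1.5.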

centroid

centroid()

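Compute the centroid of the valid ray origins. Origins have shape (..., num_rays, 3); rays with is_valid = 0 are excluded, and a small EPSILON in the denominator avoids division by zero when no ray is valid.

Returns:

| Type | Description |
| --- | --- |
| torch.Tensor | Centroid of the valid rays, shape (..., 3). |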

Source code in deeplens/optics/light/ray.py
def centroid(self):
    """Calculate the centroid of the ray, shape (..., num_rays, 3)

    Returns:
        torch.Tensor: Centroid of the ray, shape (..., 3)
    """
    return (self.o * self.is_valid.unsqueeze(-1)).sum(-2) / self.is_valid.sum(
        -1
    ).add(EPSILON).unsqueeze(-1)
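The masked average above can be written out for a single batch element in plain Python. The EPSILON value here is an assumption, since the library constant is not shown on this page; masked_centroid is a hypothetical name.

```python
EPSILON = 1e-9  # assumed small constant guarding against division by zero


def masked_centroid(points, valid):
    """Masked centroid of ray origins, following Ray.centroid.

    points: list of (x, y, z); valid: list of 0/1 flags of the same length.
    """
    w = sum(valid) + EPSILON
    return tuple(
        sum(p[k] * v for p, v in zip(points, valid)) / w for k in range(3)
    )


c = masked_centroid([(0.0, 0.0, 5.0), (2.0, 2.0, 5.0), (100.0, 100.0, 5.0)], [1, 1, 0])
```

The third ray is flagged invalid, so its outlying position does not pull the centroid away from (1, 1, 5).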

rms_error

rms_error(center_ref=None)

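Compute the RMS spot radius of the valid rays in the x-y plane, measured from a reference center and averaged over all batch elements.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| center_ref | Tensor | Reference center, shape (..., 3). If None, the centroid of the valid rays is used (computed without gradients). | None |

Returns:

| Type | Description |
| --- | --- |
| torch.Tensor | Average RMS error of the rays. |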

Source code in deeplens/optics/light/ray.py
def rms_error(self, center_ref=None):
    """Calculate the RMS error of the ray.

    Args:
        center_ref (torch.Tensor): Reference center of the ray, shape (..., 3). If None, use the centroid of the ray as reference.

    Returns:
        torch.Tensor: average RMS error of the ray
    """
    # Calculate the centroid of the ray as reference
    if center_ref is None:
        with torch.no_grad():
            center_ref = self.centroid()

    center_ref = center_ref.unsqueeze(-2)

    # Calculate RMS error for each region
    rms_error = ((self.o[..., :2] - center_ref[..., :2]) ** 2).sum(-1)
    rms_error = (rms_error * self.is_valid).sum(-1) / (
        self.is_valid.sum(-1) + EPSILON
    )
    rms_error = rms_error.sqrt()

    # Average RMS error
    return rms_error.mean()
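For one batch element, rms_error is the square root of the mean squared transverse distance of valid rays from the reference center. The plain-Python sketch below (rms_spot_radius is a hypothetical name; eps stands in for EPSILON) makes the per-element computation explicit. The library then averages this value over the batch.

```python
import math


def rms_spot_radius(points, valid, center=None, eps=1e-9):
    """RMS distance in the x-y plane of valid rays from a reference center.

    Mirrors one batch element of Ray.rms_error.
    """
    if center is None:  # default reference: centroid of the valid rays
        w = sum(valid) + eps
        center = (
            sum(p[0] * v for p, v in zip(points, valid)) / w,
            sum(p[1] * v for p, v in zip(points, valid)) / w,
        )
    sq = [((p[0] - center[0]) ** 2 + (p[1] - center[1]) ** 2) * v
          for p, v in zip(points, valid)]
    return math.sqrt(sum(sq) / (sum(valid) + eps))


pts = [(1.0, 0.0, 0.0), (-1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, -1.0, 0.0), (50.0, 50.0, 0.0)]
flags = [1, 1, 1, 1, 0]
```

Four valid rays on a unit circle give an RMS radius of 1 about their centroid. Passing an explicit center, as the library allows through center_ref, measures the spread about a chief-ray position instead.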

flip_xy

flip_xy()

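Flip the x and y coordinates of the ray by negating the x and y components of both origins and directions. This is equivalent to a 180° rotation about the optical axis.

This function is used when calculating point spread functions and wavefront distributions.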

Source code in deeplens/optics/light/ray.py
def flip_xy(self):
    """Flip the x and y coordinates of the ray.

    This function is used when calculating point spread function and wavefront distribution.
    """
    self.o = torch.cat([-self.o[..., :2], self.o[..., 2:]], dim=-1)
    self.d = torch.cat([-self.d[..., :2], self.d[..., 2:]], dim=-1)
    return self
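Negating both transverse components is the same as rotating by π about the z axis, which a quick plain-Python check confirms. The helper names are illustrative.

```python
import math


def flip_xy(v):
    """Negate the x and y components, as Ray.flip_xy does for o and d."""
    return (-v[0], -v[1], v[2])


def rotate_z(v, angle):
    """Rotate a 3-vector about the z (optical) axis by `angle` radians."""
    c, s = math.cos(angle), math.sin(angle)
    return (c * v[0] - s * v[1], s * v[0] + c * v[1], v[2])


v = (0.3, -0.2, 1.0)
flipped = flip_xy(v)
rotated = rotate_z(v, math.pi)
```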

clone

clone(device=None)

Clone the ray.

Optionally specify the target device; for example, rays can be stored on the CPU and moved to the GPU when they are used.

Source code in deeplens/optics/light/ray.py
def clone(self, device=None):
    """Clone the ray.

    Optionally specify the target device; for example, rays can be stored on the CPU and moved to the GPU when they are used.
    """
    if device is None:
        return copy.deepcopy(self).to(self.device)
    else:
        return copy.deepcopy(self).to(device)

squeeze

squeeze(dim=None)

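Squeeze dimension dim of every per-ray tensor (o, d, is_valid, en, opl, obliq). The scalar wvln is left unchanged.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dim | int | Dimension to squeeze. | None |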
Source code in deeplens/optics/light/ray.py
def squeeze(self, dim=None):
    """Squeeze the ray.

    Args:
        dim (int, optional): dimension to squeeze. Defaults to None.
    """
    self.o = self.o.squeeze(dim)
    self.d = self.d.squeeze(dim)
    # wvln is a single element tensor, no squeeze needed
    self.is_valid = self.is_valid.squeeze(dim)
    self.en = self.en.squeeze(dim)
    self.opl = self.opl.squeeze(dim)
    self.obliq = self.obliq.squeeze(dim)
    return self

unsqueeze

unsqueeze(dim=None)

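Insert a size-1 dimension at position dim in every per-ray tensor (o, d, is_valid, en, opl, obliq). The scalar wvln is left unchanged.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dim | int | Position of the new dimension. Although the default is None, torch.Tensor.unsqueeze requires an integer, so dim must be passed explicitly. | None |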
Source code in deeplens/optics/light/ray.py
def unsqueeze(self, dim=None):
    """Unsqueeze the ray.

    Args:
        dim (int, optional): dimension to unsqueeze. Defaults to None.
    """
    self.o = self.o.unsqueeze(dim)
    self.d = self.d.unsqueeze(dim)
    # wvln is a single element tensor, no unsqueeze needed
    self.is_valid = self.is_valid.unsqueeze(dim)
    self.en = self.en.unsqueeze(dim)
    self.opl = self.opl.unsqueeze(dim)
    self.obliq = self.obliq.unsqueeze(dim)
    return self

Complex scalar wave field with Angular Spectrum Method (ASM), Fresnel, and Fraunhofer propagation via torch.fft.

deeplens.optics.ComplexWave

ComplexWave(u=None, wvln=0.55, z=0.0, phy_size=(4.0, 4.0), res=(2000, 2000))

Bases: DeepObj

Complex scalar wave field for diffraction simulation.

Represents a monochromatic, coherent complex amplitude on a uniform rectangular grid. Propagation methods (ASM, Fresnel, Fraunhofer) are implemented as member functions and use torch.fft for efficiency.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| u | Tensor | Complex amplitude, shape [1, 1, H, W]. |
| wvln | float | Wavelength [µm]. |
| k | float | Wave number 2π / (λ × 10⁻³) [mm⁻¹]. |
| phy_size | tuple | Physical aperture size (W, H) [mm]. |
| ps | float | Pixel pitch [mm] (must be square). |
| res | tuple | Grid resolution (H, W) in pixels. |
| z | Tensor | Current axial position [mm], stored as a grid with the same shape as x (all entries equal). |

Initialize a complex wave field.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| u | Tensor or None | Initial complex amplitude. Accepted shapes: [H, W], [1, H, W], or [1, 1, H, W]. If None, a zero field is created with the given res. | None |
| wvln | float | Wavelength [µm]. | 0.55 |
| z | float | Initial axial position [mm]. | 0.0 |
| phy_size | tuple | Physical aperture (W, H) [mm]. | (4.0, 4.0) |
| res | tuple | Grid resolution (H, W) [pixels]. Only used when u is None. | (2000, 2000) |

Raises:

| Type | Description |
| --- | --- |
| AssertionError | If the pixel pitch is not square or the wavelength is outside the range (0.1, 10) µm. |

Source code in deeplens/optics/light/wave.py
def __init__(
    self,
    u=None,
    wvln=0.55,
    z=0.0,
    phy_size=(4.0, 4.0),
    res=(2000, 2000),
):
    """Initialize a complex wave field.

    Args:
        u (torch.Tensor or None, optional): Initial complex amplitude.
            Accepted shapes: ``[H, W]``, ``[1, H, W]``, or
            ``[1, 1, H, W]``.  If ``None`` a zero field is created with
            the given *res*.
        wvln (float, optional): Wavelength [µm].  Defaults to ``0.55``.
        z (float, optional): Initial axial position [mm].  Defaults to
            ``0.0``.
        phy_size (tuple, optional): Physical aperture (W, H) [mm].
            Defaults to ``(4.0, 4.0)``.
        res (tuple, optional): Grid resolution (H, W) [pixels].  Only
            used when *u* is ``None``.  Defaults to ``(2000, 2000)``.

    Raises:
        AssertionError: If the pixel pitch is not square or the
            wavelength is outside the range ``(0.1, 10)`` µm.
    """
    if u is not None:
        if not u.dtype == torch.complex128:
            print(
                "A complex wave field is created with single precision. In the future, we want to always use double precision."
            )

        self.u = u if torch.is_tensor(u) else torch.from_numpy(u)
        if not self.u.is_complex():
            self.u = self.u.to(torch.complex64)

        # [H, W] or [1, H, W] to [1, 1, H, W]
        if len(self.u.shape) == 2:
            self.u = self.u.unsqueeze(0).unsqueeze(0)
        elif len(self.u.shape) == 3:
            self.u = self.u.unsqueeze(0)

        self.res = self.u.shape[-2:]

    else:
        # Initialize a zero complex wave field
        amp = torch.zeros(res).unsqueeze(0).unsqueeze(0)
        phi = torch.zeros(res).unsqueeze(0).unsqueeze(0)
        self.u = amp + 1j * phi
        self.res = res

    # Wave field parameters
    assert wvln > 0.1 and wvln < 10.0, "Wavelength should be in [um]."
    self.wvln = wvln  # [um], wavelength
    self.k = 2 * torch.pi / (self.wvln * 1e-3)  # [mm^-1], wave number
    self.phy_size = phy_size  # [mm], physical size
    assert phy_size[0] / self.res[0] == phy_size[1] / self.res[1], (
        "Pixel size is not square."
    )
    self.ps = phy_size[0] / self.res[0]  # [mm], pixel size

    # Wave field grid
    self.x, self.y = self.gen_xy_grid()  # x, y grid
    self.z = torch.full_like(self.x, z)  # z grid

    # Cache propagation method boundaries (depend only on wvln, ps, phy_size)
    self._asm_zmax = Nyquist_ASM_zmax(wvln=self.wvln, ps=self.ps, side_length=self.phy_size[0])
    self._fresnel_zmin = Fresnel_zmin(wvln=self.wvln, ps=self.ps, side_length=self.phy_size[0])
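For the default arguments, the derived quantities work out as shown below. This is plain-Python arithmetic mirroring the constructor. The propagation limits from Nyquist_ASM_zmax and Fresnel_zmin are not reproduced because their formulas are not shown on this page.

```python
import math

wvln_um = 0.55          # wavelength [µm]
phy_size = (4.0, 4.0)   # physical size [mm]
res = (2000, 2000)      # grid resolution [pixels]

k = 2 * math.pi / (wvln_um * 1e-3)  # wave number [mm^-1], about 11424
ps = phy_size[0] / res[0]           # pixel pitch [mm], i.e. 2 µm
assert phy_size[0] / res[0] == phy_size[1] / res[1], "Pixel size is not square."
```

A 2 µm pitch is already several times smaller than typical sensor pixels, which is why the default grid needs 2000 × 2000 samples to cover only 4 mm.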

point_wave classmethod

point_wave(point=(0, 0, -1000.0), wvln=0.55, z=0.0, phy_size=(4.0, 4.0), res=(2000, 2000), valid_r=None)

Create a spherical wave field on the x-y plane at z, originating from a point source. The amplitude falls off as r_min / r, where r is the distance to the source. The phase is +k·r when the source lies before the plane (diverging wave) and -k·r otherwise (converging wave).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| point | tuple | Point source position in object space [mm]. | (0, 0, -1000.0) |
| wvln | float | Wavelength [µm]. | 0.55 |
| z | float | Axial position of the field plane [mm]. | 0.0 |
| phy_size | tuple | Physical size of the sampled plane [mm]. | (4.0, 4.0) |
| res | tuple | Grid resolution [pixels]. | (2000, 2000) |
| valid_r | float | Radius of a circular mask centred on the source's (x, y) position [mm]; None disables masking. | None |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| field | ComplexWave | Complex field on the x-y plane. |

Source code in deeplens/optics/light/wave.py
@classmethod
def point_wave(
    cls,
    point=(0, 0, -1000.0),
    wvln=0.55,
    z=0.0,
    phy_size=(4.0, 4.0),
    res=(2000, 2000),
    valid_r=None,
):
    """Create a spherical wave field on x0y plane originating from a point source.

    Args:
        point (tuple): Point source position in object space. [mm]. Defaults to (0, 0, -1000.0).
        wvln (float): Wavelength. [um]. Defaults to 0.55.
        z (float): Field z position. [mm]. Defaults to 0.0.
        phy_size (tuple): Physical size of the sampled plane. [mm]. Defaults to (4.0, 4.0).
        res (tuple): Grid resolution. Defaults to (2000, 2000).
        valid_r (float): Valid circle radius. [mm]. Defaults to None.

    Returns:
        field (ComplexWave): Complex field on x0y plane.
    """
    assert wvln > 0.1 and wvln < 10.0, "Wavelength should be in [um]."
    k = 2 * torch.pi / (wvln * 1e-3)  # [mm^-1], wave number

    # Create meshgrid on target plane
    x, y = torch.meshgrid(
        torch.linspace(
            -0.5 * phy_size[0], 0.5 * phy_size[0], res[0], dtype=torch.float64
        ),
        torch.linspace(
            0.5 * phy_size[1], -0.5 * phy_size[1], res[1], dtype=torch.float64
        ),
        indexing="xy",
    )

    # Calculate distance to point source, and calculate spherical wave phase
    r = torch.sqrt((x - point[0]) ** 2 + (y - point[1]) ** 2 + (z - point[2]) ** 2)
    if point[2] < z:
        phi = k * r
    else:
        phi = -k * r
    u = (r.min() / r) * torch.exp(1j * phi)

    # Apply valid circle if provided, e.g., the aperture of a lens
    if valid_r is not None:
        mask = (x - point[0]) ** 2 + (y - point[1]) ** 2 < valid_r**2
        u = u * mask

    # Create wave field
    return cls(u=u, wvln=wvln, phy_size=phy_size, res=res, z=z)
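The field construction in point_wave can be traced on a tiny grid with the standard library. This sketch omits the valid_r mask and returns nested lists instead of a ComplexWave. point_wave_grid and the inner lin helper are hypothetical names; lin reproduces the endpoint-inclusive behaviour of torch.linspace.

```python
import cmath
import math


def point_wave_grid(point, wvln_um, z, phy_size, res):
    """Pure-Python sketch of ComplexWave.point_wave (no valid_r mask).

    Returns a res[1] x res[0] nested list of complex samples, row 0 at +y.
    """
    k = 2 * math.pi / (wvln_um * 1e-3)  # wave number [mm^-1]

    def lin(a, b, n):  # endpoint-inclusive samples, like torch.linspace
        return [a + (b - a) * i / (n - 1) for i in range(n)]

    xs = lin(-0.5 * phy_size[0], 0.5 * phy_size[0], res[0])
    ys = lin(0.5 * phy_size[1], -0.5 * phy_size[1], res[1])
    r = [[math.sqrt((x - point[0]) ** 2 + (y - point[1]) ** 2 + (z - point[2]) ** 2)
          for x in xs] for y in ys]
    r_min = min(min(row) for row in r)
    sign = 1.0 if point[2] < z else -1.0  # diverging vs converging wave
    return [[(r_min / rij) * cmath.exp(1j * sign * k * rij) for rij in row]
            for row in r]


u = point_wave_grid((0.0, 0.0, -1000.0), 0.55, 0.0, (4.0, 4.0), (5, 5))
```

The on-axis sample is closest to the source, so it has unit amplitude and every other sample is slightly weaker.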

plane_wave classmethod

plane_wave(wvln=0.55, z=0.0, phy_size=(4.0, 4.0), res=(2000, 2000), valid_r=None)

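Create a plane wave field of unit amplitude and zero phase on the x-y plane, stored in double precision (complex128).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| wvln | float | Wavelength [µm]. | 0.55 |
| z | float | Field z position [mm]. | 0.0 |
| phy_size | tuple | Physical size of the field [mm]. | (4.0, 4.0) |
| res | tuple | Grid resolution [pixels]. | (2000, 2000) |
| valid_r | float | Radius of a circular mask centred on the optical axis [mm]; None disables masking. | None |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| field | ComplexWave | Complex field. |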

Source code in deeplens/optics/light/wave.py
@classmethod
def plane_wave(
    cls,
    wvln=0.55,
    z=0.0,
    phy_size=(4.0, 4.0),
    res=(2000, 2000),
    valid_r=None,
):
    """Create a planar wave field on x0y plane.

    Args:
        wvln (float): Wavelength. [um].
        z (float): Field z position. [mm].
        phy_size (tuple): Physical size of the field. [mm].
        res (tuple): Resolution.
        valid_r (float): Valid circle radius. [mm].

    Returns:
        field (ComplexWave): Complex field.
    """
    assert wvln > 0.1 and wvln < 10.0, "Wavelength should be in [um]."

    # Create a plane wave field
    u = torch.ones(res, dtype=torch.float64) + 0j

    # Apply valid circle if provided
    if valid_r is not None:
        x, y = torch.meshgrid(
            torch.linspace(-0.5 * phy_size[0], 0.5 * phy_size[0], res[0]),
            torch.linspace(-0.5 * phy_size[1], 0.5 * phy_size[1], res[1]),
            indexing="xy",
        )
        mask = (x**2 + y**2) < valid_r**2
        u = u * mask

    # Create wave field
    return cls(u=u, phy_size=phy_size, wvln=wvln, res=res, z=z)
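The optional valid_r argument simply zeroes samples outside a centred circle. On a 5 × 5 grid the effect is easy to count by hand. plane_wave_grid is a hypothetical plain-Python helper that mirrors the masking logic.

```python
def plane_wave_grid(phy_size, res, valid_r=None):
    """Amplitude of ComplexWave.plane_wave: ones inside an optional circle."""
    def lin(a, b, n):  # endpoint-inclusive samples, like torch.linspace
        return [a + (b - a) * i / (n - 1) for i in range(n)]

    xs = lin(-0.5 * phy_size[0], 0.5 * phy_size[0], res[0])
    ys = lin(-0.5 * phy_size[1], 0.5 * phy_size[1], res[1])
    if valid_r is None:
        return [[1.0 + 0j for _ in xs] for _ in ys]
    return [[(1.0 + 0j) if x * x + y * y < valid_r ** 2 else 0j for x in xs]
            for y in ys]


full = plane_wave_grid((4.0, 4.0), (5, 5))
masked = plane_wave_grid((4.0, 4.0), (5, 5), valid_r=1.5)
```

With samples at -2, -1, 0, 1, 2 mm, only the nine points with |x|, |y| ≤ 1 mm fall inside a 1.5 mm radius.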

image_wave classmethod

image_wave(img, wvln=0.55, z=0.0, phy_size=(4.0, 4.0))

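Initialize a complex wave field from an image. The image is treated as intensity: the amplitude is its square root and the phase is zero, so |u|² reproduces the input. The image must be float32, and phy_size must have the same aspect ratio as the image because the constructor requires square pixels.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| img | Tensor | Input image with shape [H, W] or [B, C, H, W]. Data range is [0, 1]. | required |
| wvln | float | Wavelength [µm]. | 0.55 |
| z | float | Field z position [mm]. | 0.0 |
| phy_size | tuple | Physical size of the field [mm]. | (4.0, 4.0) |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| field | ComplexWave | Complex field. |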

Source code in deeplens/optics/light/wave.py
@classmethod
def image_wave(cls, img, wvln=0.55, z=0.0, phy_size=(4.0, 4.0)):
    """Initialize a complex wave field from an image.

    Args:
        img (torch.Tensor): Input image with shape [H, W] or [B, C, H, W]. Data range is [0, 1].
        wvln (float): Wavelength. [um].
        z (float): Field z position. [mm].
        phy_size (tuple): Physical size of the field. [mm].

    Returns:
        field (ComplexWave): Complex field.
    """
    assert img.dtype == torch.float32, "Image must be float32."

    amp = torch.sqrt(img)
    phi = torch.zeros_like(amp)
    u = amp + 1j * phi

    return cls(u=u, wvln=wvln, phy_size=phy_size, res=u.shape[-2:], z=z)
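Taking the square root of the image before building the field is what makes the intensity |u|² match the input pixel values. A small plain-Python check with a hypothetical image_to_field helper:

```python
import math


def image_to_field(img):
    """Map intensities in [0, 1] to a zero-phase complex amplitude."""
    return [[complex(math.sqrt(v), 0.0) for v in row] for row in img]


img = [[0.0, 0.25], [0.5, 1.0]]
u = image_to_field(img)
intensity = [[abs(c) ** 2 for c in row] for row in u]  # |u|^2 recovers img
```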

prop

prop(prop_dist, n=1.0)

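Propagate the field by prop_dist [mm]. Only planar wave fields can be propagated. The method is chosen from cached limits:

- distances below DELTA are a no-op;
- sub-wavelength distances raise an exception (a full-wave method such as FDTD would be needed);
- distances below the ASM limit use the Angular Spectrum Method;
- distances beyond the Fresnel limit use Fresnel diffraction;
- any other distance raises an exception.

Because prop_dist is compared directly rather than by magnitude, a negative distance takes the no-op branch: the field is unchanged, but z is still shifted.

Reference

[1] Modeling and propagation of near-field diffraction patterns: A more complete approach. Table 1.
[2] https://github.com/kaanaksit/odak/blob/master/odak/wave/classical.py
[3] https://spie.org/samples/PM103.pdf
[4] "Non-approximated Rayleigh Sommerfeld diffraction integral: advantages and disadvantages in the propagation of complex wave fields"

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| prop_dist | float | Propagation distance [mm]. | required |
| n | float | Refractive index. | 1.0 |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| self | | Propagated complex wave field (modified in place). |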

Source code in deeplens/optics/light/wave.py
def prop(self, prop_dist, n=1.0):
    """Propagate the field by distance z. Can only propagate planar wave.

    Reference:
        [1] Modeling and propagation of near-field diffraction patterns: A more complete approach. Table 1.
        [2] https://github.com/kaanaksit/odak/blob/master/odak/wave/classical.py
        [3] https://spie.org/samples/PM103.pdf
        [4] "Non-approximated Rayleigh Sommerfeld diffraction integral: advantages and disadvantages in the propagation of complex wave fields"

    Args:
        prop_dist (float): propagation distance, unit [mm].
        n (float): refractive index.

    Returns:
        self: propagated complex wave field.
    """
    # Determine propagation method using cached boundaries
    wvln_mm = self.wvln * 1e-3  # [um] to [mm]

    # Wave propagation methods
    if prop_dist < DELTA:
        # Zero distance: do nothing
        pass

    elif prop_dist < wvln_mm:
        # Sub-wavelength distance: full wave method (e.g., FDTD)
        raise Exception(
            "The propagation distance in sub-wavelength range is not implemented yet. Have to use full wave method (e.g., FDTD)."
        )

    elif prop_dist < self._asm_zmax:
        # Angular Spectrum Method (ASM)
        self.u = AngularSpectrumMethod(self.u, z=prop_dist, wvln=self.wvln, ps=self.ps, n=n)

    elif prop_dist > self._fresnel_zmin:
        # Fresnel diffraction
        self.u = FresnelDiffraction(self.u, z=prop_dist, wvln=self.wvln, ps=self.ps, n=n)

    else:
        raise Exception(f"Propagation method not implemented for distance {prop_dist} mm.")

    # Update z grid
    self.z += prop_dist
    return self
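The branch order in prop can be mirrored in a few lines of plain Python, which makes the gap between the ASM and Fresnel regimes and the negative-distance behaviour easy to test. In this sketch, asm_zmax, fresnel_zmin, and delta are placeholder arguments standing in for the cached limits and the DELTA constant, whose values are not shown on this page. choose_method is a hypothetical name.

```python
def choose_method(prop_dist, wvln_um, asm_zmax, fresnel_zmin, delta=1e-9):
    """Mirror the branch order of ComplexWave.prop (illustrative only)."""
    wvln_mm = wvln_um * 1e-3
    if prop_dist < delta:
        return "none"      # zero distance; also reached by any negative distance
    if prop_dist < wvln_mm:
        raise NotImplementedError("sub-wavelength propagation needs a full-wave solver")
    if prop_dist < asm_zmax:
        return "asm"       # Angular Spectrum Method
    if prop_dist > fresnel_zmin:
        return "fresnel"   # Fresnel diffraction
    raise ValueError(f"no propagation method for {prop_dist} mm")


method = choose_method(1.0, 0.55, asm_zmax=10.0, fresnel_zmin=20.0)
```

With these placeholder limits, distances between 10 mm and 20 mm fall into neither regime and raise, just as the library raises when the two cached boundaries do not overlap.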

prop_to

prop_to(z, n=1)

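Propagate the field to the plane at axial position z [mm] by calling prop with the distance from the current position.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| z | float | Destination plane z coordinate [mm]. | required |
| n | float | Refractive index of the medium. | 1 |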
Source code in deeplens/optics/light/wave.py
def prop_to(self, z, n=1):
    """Propagate the field to plane z.

    Args:
        z (float): destination plane z coordinate.
        n (float): refractive index. Defaults to 1.
    """
    # Use float() instead of .item() to avoid GPU-CPU sync on CUDA tensors
    # (self.z is a full grid but all values are identical; [0,0] is representative)
    prop_dist = float(z) - float(self.z[0, 0])
    self.prop(prop_dist, n=n)
    return self

gen_xy_grid

gen_xy_grid()

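Generate the x and y coordinate grids [mm], centred on the optical axis. x increases from left to right and y decreases from top to bottom, so row 0 lies at +y. Because torch.linspace includes both endpoints, the sample spacing is the physical extent divided by (res - 1), slightly larger than ps.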

Source code in deeplens/optics/light/wave.py
def gen_xy_grid(self):
    """Generate the x and y grid."""
    x, y = torch.meshgrid(
        torch.linspace(-0.5 * self.phy_size[1], 0.5 * self.phy_size[1], self.res[0],),
        torch.linspace(0.5 * self.phy_size[0], -0.5 * self.phy_size[0], self.res[1],),
        indexing="xy",
    )
    return x, y

gen_freq_grid

gen_freq_grid()

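Generate the spatial-frequency grids [mm⁻¹] by scaling the coordinate grids by 1 / (ps × phy_size). For a square field, the edge samples reach ±1 / (2·ps), the Nyquist frequency of the sampling.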

Source code in deeplens/optics/light/wave.py
def gen_freq_grid(self):
    """Generate the frequency grid."""
    x, y = self.gen_xy_grid()
    fx = x / (self.ps * self.phy_size[0])
    fy = y / (self.ps * self.phy_size[1])
    return fx, fy
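A one-dimensional plain-Python example makes the two grid properties concrete. Endpoint-inclusive sampling gives a spacing of phy / (res - 1), and the gen_freq_grid scaling puts the edge samples at 1 / (2·ps). The small sizes are arbitrary illustration values.

```python
def linspace(a, b, n):
    """Endpoint-inclusive samples, like torch.linspace."""
    return [a + (b - a) * i / (n - 1) for i in range(n)]


phy, res = 4.0, 8                    # small 1-D example [mm], [pixels]
ps = phy / res                       # pixel pitch as defined by ComplexWave
x = linspace(-0.5 * phy, 0.5 * phy, res)
spacing = x[1] - x[0]                # phy / (res - 1), slightly above ps
fx = [xi / (ps * phy) for xi in x]   # gen_freq_grid scaling [mm^-1]
```

Here ps is 0.5 mm while the actual sample spacing is 4/7 ≈ 0.571 mm. The frequency grid ends exactly at 1/(2·ps) = 1 mm⁻¹.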

load_npz

load_npz(filepath)

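Load u, x, y, wvln, and phy_size from an .npz file written by save_npz, and reset res from the loaded field. Derived attributes (k, ps, z, and the cached propagation limits) are not recomputed.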

Source code in deeplens/optics/light/wave.py
def load_npz(self, filepath):
    """Load data from npz file."""
    data = np.load(filepath)
    self.u = torch.from_numpy(data["u"])
    self.x = torch.from_numpy(data["x"])
    self.y = torch.from_numpy(data["y"])
    self.wvln = data["wvln"].item()
    self.phy_size = data["phy_size"].tolist()
    self.res = self.u.shape[-2:]

save

save(filepath='./wavefield.npz')

Save the complex wave field to a npz file.

Source code in deeplens/optics/light/wave.py
def save(self, filepath="./wavefield.npz"):
    """Save the complex wave field to a npz file."""
    if filepath.endswith(".npz"):
        self.save_npz(filepath)
    else:
        raise Exception("Unimplemented file format.")

save_npz

save_npz(filepath='./wavefield.npz')

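Save u, x, y, wvln, and phy_size to a compressed .npz file, and write normalized intensity, amplitude, and phase PNG images next to it (suffixes _intensity, _amp, and _phase).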

Source code in deeplens/optics/light/wave.py
def save_npz(self, filepath="./wavefield.npz"):
    """Save the complex wave field to a npz file."""
    # Save data
    np.savez_compressed(
        filepath,
        u=self.u.cpu().numpy(),
        x=self.x.cpu().numpy(),
        y=self.y.cpu().numpy(),
        wvln=np.array(self.wvln),
        phy_size=np.array(self.phy_size),
    )

    # Save intensity, amplitude, and phase images
    u = self.u.cpu()
    save_image(u.abs() ** 2, f"{filepath[:-4]}_intensity.png", normalize=True)
    save_image(u.abs(), f"{filepath[:-4]}_amp.png", normalize=True)
    save_image(u.angle(), f"{filepath[:-4]}_phase.png", normalize=True)

show

show(save_name=None, data='irr')

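Display the field with matplotlib, or save it to save_name if one is given. data selects the quantity shown: "irr" (intensity, the default), "amp", "phi" or "phase", "real", or "imag". For batched fields (B > 1), save_name triggers plt.savefig before any subplot is drawn, so the saved figure may not contain the field.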

Source code in deeplens/optics/light/wave.py
def show(self, save_name=None, data="irr"):
    """Save the field as an image."""
    cmap = "gray"
    if data == "irr":
        value = self.u.detach().abs() ** 2
    elif data == "amp":
        value = self.u.detach().abs()
    elif data == "phi" or data == "phase":
        value = torch.angle(self.u).detach()
        cmap = "hsv"
    elif data == "real":
        value = self.u.real.detach()
    elif data == "imag":
        value = self.u.imag.detach()
    else:
        raise Exception(f"Unimplemented visualization: {data}.")

    if len(self.u.shape) == 2:
        raise Exception("Deprecated.")
        if save_name is not None:
            save_image(value, save_name, normalize=True)
        else:
            value = value.cpu().numpy()
            plt.imshow(
                value,
                cmap=cmap,
                extent=[
                    -self.phy_size[0] / 2,
                    self.phy_size[0] / 2,
                    -self.phy_size[1] / 2,
                    self.phy_size[1] / 2,
                ],
            )

    elif len(self.u.shape) == 4:
        B, C, H, W = self.u.shape
        if B == 1:
            if save_name is not None:
                save_image(value, save_name, normalize=True)
            else:
                value = value.cpu().numpy()
                plt.imshow(
                    value[0, 0, :, :],
                    cmap=cmap,
                    extent=[
                        -self.phy_size[0] / 2,
                        self.phy_size[0] / 2,
                        -self.phy_size[1] / 2,
                        self.phy_size[1] / 2,
                    ],
                )
        else:
            if save_name is not None:
                plt.savefig(save_name)
            else:
                value = value.cpu().numpy()
                fig, axs = plt.subplots(1, B)
                for i in range(B):
                    axs[i].imshow(
                        value[i, 0, :, :],
                        cmap=cmap,
                        extent=[
                            -self.phy_size[0] / 2,
                            self.phy_size[0] / 2,
                            -self.phy_size[1] / 2,
                            self.phy_size[1] / 2,
                        ],
                    )
                fig.show()
    else:
        raise Exception("Unsupported complex field shape.")

pad

pad(Hpad, Wpad)

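Zero-pad the field by Hpad rows at the top and bottom and Wpad columns at the left and right. The pixel pitch is unchanged, so the physical size of the field grows in proportion to the resolution, and the x, y, and z grids are regenerated.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| Hpad | int | Number of pixels to pad on the top and bottom. | required |
| Wpad | int | Number of pixels to pad on the left and right. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| self | | Padded complex wave field. |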

Source code in deeplens/optics/light/wave.py
def pad(self, Hpad, Wpad):
    """Pad the input field by (Hpad, Hpad, Wpad, Wpad). This step will also expand physical size of the field.

    Args:
        Hpad (int): Number of pixels to pad on the top and bottom.
        Wpad (int): Number of pixels to pad on the left and right.

    Returns:
        self: Padded complex wave field.
    """
    # F.pad lists padding from the last dimension: (left, right, top, bottom)
    self.u = F.pad(self.u, (Wpad, Wpad, Hpad, Hpad), mode="constant", value=0)

    Horg, Worg = self.res
    self.res = [Horg + 2 * Hpad, Worg + 2 * Wpad]
    self.phy_size = [
        self.phy_size[0] * self.res[0] / Horg,
        self.phy_size[1] * self.res[1] / Worg,
    ]
    self.x, self.y = self.gen_xy_grid()
    self.z = torch.full_like(self.x, float(self.z[0, 0]))
    return self
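torch.nn.functional.pad lists padding amounts starting from the last dimension (width), so padding Hpad rows and Wpad columns requires the tuple (Wpad, Wpad, Hpad, Hpad). The plain-Python sketch below uses named arguments to make that order explicit; pad2d is a hypothetical helper operating on nested lists.

```python
def pad2d(u, top, bottom, left, right, value=0):
    """Constant-pad a nested-list image. The effect matches
    F.pad(u, (left, right, top, bottom)) on the last two dimensions."""
    new_w = len(u[0]) + left + right
    rows = [[value] * new_w for _ in range(top)]
    rows += [[value] * left + list(r) + [value] * right for r in u]
    rows += [[value] * new_w for _ in range(bottom)]
    return rows


u = [[1, 2, 3]]        # H = 1, W = 3
Hpad, Wpad = 1, 2
padded = pad2d(u, top=Hpad, bottom=Hpad, left=Wpad, right=Wpad)
```

The result is 3 rows by 7 columns. Swapping the order, as in (Hpad, Hpad, Wpad, Wpad), would instead give 5 rows by 5 columns, out of step with the updated res.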

flip

flip()

Flip the field horizontally and vertically.

Source code in deeplens/optics/light/wave.py
def flip(self):
    """Flip the field horizontally and vertically."""
    self.u = torch.flip(self.u, [-1, -2])
    self.x = torch.flip(self.x, [-1, -2])
    self.y = torch.flip(self.y, [-1, -2])
    self.z = torch.flip(self.z, [-1, -2])
    return self

PSF Utilities

Functions for convolving images with point spread functions.

deeplens.optics.imgsim.psf.conv_psf

conv_psf(img, psf)

Convolve an image batch with a single spatially-uniform PSF.

Applies a per-channel 2-D convolution using reflect boundary padding so that the output has the same spatial dimensions as the input. The PSF is internally flipped to convert the cross-correlation implemented by F.conv2d into a true convolution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| img | Tensor | Input image batch, shape [B, C, H, W]. | required |
| psf | Tensor | PSF kernel, shape [C, ks, ks]. ks may be odd or even. | required |

Returns:

| Type | Description |
| --- | --- |
| torch.Tensor | Rendered image, shape [B, C, H, W]. |

Example

>>> psf = lens.psf_rgb(points=torch.tensor([0.0, 0.0, -10000.0]))
>>> img_blur = conv_psf(img, psf)

Source code in deeplens/optics/imgsim/psf.py
def conv_psf(img, psf):
    """Convolve an image batch with a single spatially-uniform PSF.

    Applies a per-channel 2-D convolution using ``reflect`` boundary padding
    so that the output has the same spatial dimensions as the input.  The PSF
    is internally flipped to convert the cross-correlation implemented by
    ``F.conv2d`` into a true convolution.

    Args:
        img (torch.Tensor): Input image batch, shape ``[B, C, H, W]``.
        psf (torch.Tensor): PSF kernel, shape ``[C, ks, ks]``.  ``ks`` may be
            odd or even.

    Returns:
        torch.Tensor: Rendered image, shape ``[B, C, H, W]``.

    Example:
        >>> psf = lens.psf_rgb(points=torch.tensor([0.0, 0.0, -10000.0]))
        >>> img_blur = conv_psf(img, psf)
    """
    B, C, H, W = img.shape
    C_psf, ks, _ = psf.shape
    assert C_psf == C, f"psf channels ({C_psf}) must match image channels ({C})."

    # Flip the PSF because F.conv2d use cross-correlation
    psf = torch.flip(psf, [1, 2])
    psf = psf.unsqueeze(1)  # shape [C, 1, ks, ks]

    # Padding
    pad_h_left  = (ks - 1) // 2
    pad_h_right = ks // 2
    pad_w_left  = (ks - 1) // 2
    pad_w_right = ks // 2
    img_pad = F.pad(img, (pad_w_left, pad_w_right, pad_h_left, pad_h_right), mode="reflect")

    # Convolution
    img_render = F.conv2d(img_pad, psf, groups=C)
    return img_render
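For a single channel, the padding and flipping in conv_psf can be reproduced with explicit loops. The sketch below mirrors PyTorch's "reflect" padding (the edge sample is not repeated) and the (ks - 1) // 2 top/left offset, so the output keeps the input size for odd or even ks. conv_psf_1ch and reflect_index are hypothetical helpers, not DeepLens functions.

```python
def reflect_index(i, n):
    """Map an out-of-range index into [0, n) by mirror reflection."""
    if n == 1:
        return 0
    while i < 0 or i >= n:
        i = -i if i < 0 else 2 * (n - 1) - i
    return i


def conv_psf_1ch(img, psf):
    """Single-channel sketch of conv_psf: reflect padding + true convolution.

    img: H x W nested list; psf: ks x ks nested list (ks may be even).
    """
    H, W, ks = len(img), len(img[0]), len(psf)
    pad_lo = (ks - 1) // 2                       # top/left padding
    flipped = [row[::-1] for row in psf[::-1]]   # flip turns correlation into convolution
    out = [[0.0] * W for _ in range(H)]
    for y in range(H):
        for x in range(W):
            acc = 0.0
            for i in range(ks):
                for j in range(ks):
                    yy = reflect_index(y - pad_lo + i, H)
                    xx = reflect_index(x - pad_lo + j, W)
                    acc += flipped[i][j] * img[yy][xx]
            out[y][x] = acc
    return out


img = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
shift = [[0, 0, 0], [0, 0, 1], [0, 0, 0]]   # impulse one pixel right of centre
shifted = conv_psf_1ch(img, shift)
```

Because the kernel is flipped, an impulse to the right of centre moves image content one pixel to the right, as a true convolution should. Row 1 becomes [6, 5, 6, 7], where the leading 6 comes from reflect padding.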

deeplens.optics.imgsim.psf.conv_psf_map

conv_psf_map(img, psf_map)

Convolve an image batch with a spatially-varying PSF map.

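Divides the image into grid_h × grid_w patches and convolves each patch with its corresponding PSF kernel. Each patch is read from the reflect-padded input together with a halo of ks - 1 extra pixels per axis, so pixels near patch borders are convolved with their true neighbors. The rendered patches do not overlap and are written directly into the full-resolution output; no blending is applied across patch boundaries.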
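The patch boundaries use integer division, (i * H) // grid_h. As a result, the patches tile the full image exactly once even when H is not a multiple of grid_h, and patch sizes differ by at most one pixel. Below is a minimal plain-Python check of this behavior. The helper name is hypothetical and not part of DeepLens.

```python
# Hypothetical helper (not part of DeepLens) reproducing the per-patch bounds
# computed in conv_psf_map's nested loops.

def patch_bounds(size, grid):
    """[(low, high), ...] with low = (i * size) // grid, high = ((i + 1) * size) // grid."""
    return [((i * size) // grid, ((i + 1) * size) // grid) for i in range(grid)]


bounds = patch_bounds(10, 3)
print(bounds)  # [(0, 3), (3, 6), (6, 10)]

# Consecutive patches share an edge, the first starts at 0, the last ends at
# size, and patch sizes differ by at most one pixel.
for size, grid in [(10, 3), (512, 7), (1080, 16)]:
    b = patch_bounds(size, grid)
    assert b[0][0] == 0 and b[-1][1] == size
    assert all(b[i][1] == b[i + 1][0] for i in range(grid - 1))
    widths = [high - low for low, high in b]
    assert max(widths) - min(widths) <= 1
```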

Parameters:

Name     Type    Description                                   Default
img      Tensor  Input image batch, shape [B, C, H, W].        required
psf_map  Tensor  PSF map, shape [grid_h, grid_w, C, ks, ks].   required

Returns:

Type          Description
torch.Tensor  Rendered image, shape [B, C, H, W].

Source code in deeplens/optics/imgsim/psf.py
def conv_psf_map(img, psf_map):
    """Convolve an image batch with a spatially-varying PSF map.

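    Divides the image into ``grid_h × grid_w`` patches and convolves each
    patch with its corresponding PSF kernel.  Each patch is read from the
    reflect-padded input together with a halo of ``ks - 1`` extra pixels per
    axis, so pixels near patch borders see their true neighbors.  The rendered
    patches do not overlap and are written directly into the full-resolution
    output; no blending is applied across patch boundaries.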

    Args:
        img (torch.Tensor): Input image batch, shape ``[B, C, H, W]``.
        psf_map (torch.Tensor): PSF map, shape ``[grid_h, grid_w, C, ks, ks]``.

    Returns:
        torch.Tensor: Rendered image, shape ``[B, C, H, W]``.
    """
    B, C, H, W = img.shape
    grid_h, grid_w, C_psf, ks, _ = psf_map.shape
    assert C_psf == C, f"PSF map channels ({C_psf}) must match image channels ({C})."

    # Padding
    pad_h_left  = (ks - 1) // 2
    pad_h_right = ks // 2
    pad_w_left  = (ks - 1) // 2
    pad_w_right = ks // 2
    img_pad = F.pad(img, (pad_w_left, pad_w_right, pad_h_left, pad_h_right), mode="reflect")

    # Pre-flip entire PSF map once (instead of flipping each PSF inside the loop)
    psf_map_flipped = torch.flip(psf_map, dims=(-2, -1))

    # Render image patch by patch
    img_render = torch.zeros_like(img)
    for i in range(grid_h):
        h_low  = (i * H) // grid_h
        h_high = ((i + 1) * H) // grid_h

        for j in range(grid_w):
            w_low  = (j * W) // grid_w
            w_high = ((j + 1) * W) // grid_w

            # PSF, [C, 1, ks, ks]
            psf = psf_map_flipped[i, j].unsqueeze(1)

            # Read the patch plus a halo of (ks - 1) padded pixels per axis so
            # that pixels near the patch border see their true neighbors
            img_pad_patch = img_pad[
                :,
                :,
                h_low : h_high + pad_h_left + pad_h_right,
                w_low : w_high + pad_w_left + pad_w_right,
            ]

            # Convolution, [B, C, h_high-h_low, w_high-w_low]
            render_patch = F.conv2d(img_pad_patch, psf, groups=C)
            img_render[:, :, h_low:h_high, w_low:w_high] = render_patch

    return img_render
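The sketch below illustrates both the halo and the absence of blending. It is a 1-D plain-Python analogue of conv_psf and conv_psf_map, and its function names are hypothetical, not DeepLens API. When every patch uses the same kernel, the patchwise result matches the single-kernel result exactly, because each patch reads ks - 1 extra padded samples. When the kernels differ, each output sample uses only its own patch's kernel, so the response switches abruptly at patch edges.

```python
# 1-D plain-Python analogue (hypothetical, not DeepLens code) of conv_psf and
# conv_psf_map, using reflect padding and a flipped kernel as those functions do.

def _reflect_pad(x, left, right):
    """Mirror about the edge samples without repeating them (mode="reflect")."""
    n = len(x)
    return x[1 : left + 1][::-1] + x + x[n - right - 1 : n - 1][::-1]


def conv_same_1d(x, kernel):
    """Single kernel everywhere; output has the same length as x."""
    ks = len(kernel)
    left, right = (ks - 1) // 2, ks // 2
    padded, flipped = _reflect_pad(x, left, right), kernel[::-1]
    return [sum(padded[i + k] * flipped[k] for k in range(ks)) for i in range(len(x))]


def conv_patchwise_1d(x, kernels):
    """One kernel per patch; each patch reads a halo of left + right samples."""
    ks = len(kernels[0])
    left, right = (ks - 1) // 2, ks // 2
    n, grid = len(x), len(kernels)
    padded = _reflect_pad(x, left, right)
    out = [0.0] * n
    for i, kernel in enumerate(kernels):
        low, high = (i * n) // grid, ((i + 1) * n) // grid
        # Same slice as img_pad[..., h_low : h_high + pad_h_left + pad_h_right].
        segment = padded[low : high + left + right]
        flipped = kernel[::-1]
        for j in range(high - low):
            # Written straight into the output: no blending with other patches.
            out[low + j] = sum(segment[j + k] * flipped[k] for k in range(ks))
    return out


x = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0, 5.0, 3.0]
blur = [0.25, 0.5, 0.25]

# Uniform "PSF map": patchwise rendering equals the single-kernel rendering.
print(conv_patchwise_1d(x, [blur] * 3) == conv_same_1d(x, blur))  # True

# Spatially varying map: the gain jumps from 1 to 2 at the patch boundary.
print(conv_patchwise_1d([1.0, 1.0, 1.0, 1.0], [[1.0], [2.0]]))  # [1.0, 1.0, 2.0, 2.0]
```

Because no blending is applied, a coarse grid combined with PSFs that change strongly between neighboring patches can produce visible seams at patch edges.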