Optics API Reference

The deeplens.optics module contains the differentiable lens models, optical surfaces, light representations, and image simulation utilities.


Base Classes

Base class for all optical objects. Provides device transfer, dtype conversion, and cloning by introspecting instance tensors.

deeplens.optics.DeepObj

DeepObj(dtype=None)

Base class for all differentiable optical objects in DeepLens.

Provides device management, dtype conversion, and deep-copy support via automatic introspection over instance tensors and nested DeepObj sub-objects. All lens, surface, material, ray, and wave objects inherit from this class.

Attributes:

Name Type Description
dtype dtype

Current floating-point dtype of all owned tensors.

device device

Current compute device (set by to()).

Source code in deeplens/optics/base.py
def __init__(self, dtype=None):
    self.dtype = torch.get_default_dtype() if dtype is None else dtype
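
Example

A minimal sketch of a custom optical element inheriting DeepObj; the ThinPhase class and its phase attribute are hypothetical, but to(), astype(), clone(), and the forward() dispatch via __call__ are the inherited API documented on this page.

import torch
from deeplens.optics import DeepObj

class ThinPhase(DeepObj):  # hypothetical subclass, for illustration only
    def __init__(self, res=256):
        super().__init__()
        # Instance tensors are discovered by introspection, so to()/astype() move them automatically.
        self.phase = torch.zeros(res, res)

    def forward(self, x):
        # __call__ dispatches here (see below).
        return x + self.phase

device = "cuda" if torch.cuda.is_available() else "cpu"
obj = ThinPhase().to(device).astype(torch.float64)  # chained device/dtype transfer
backup = obj.clone()                                # independent deep copy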

__str__

__str__()

Called when using print() and str()

Source code in deeplens/optics/base.py
def __str__(self):
    """Called when using print() and str()"""
    lines = [self.__class__.__name__ + ":"]
    for key, val in vars(self).items():
        if val.__class__.__name__ in ["list", "tuple"]:
            for i, v in enumerate(val):
                lines += "{}[{}]: {}".format(key, i, v).split("\n")
        elif val.__class__.__name__ in ["dict", "OrderedDict", "set"]:
            pass
        else:
            lines += "{}: {}".format(key, val).split("\n")

    return "\n    ".join(lines)

__call__

__call__(inp)

Call the forward function.

Source code in deeplens/optics/base.py
def __call__(self, inp):
    """Call the forward function."""
    return self.forward(inp)

clone

clone()

Clone a DeepObj object.

Source code in deeplens/optics/base.py
def clone(self):
    """Clone a DeepObj object."""
    return copy.deepcopy(self)

to

to(device)

Move all tensors and nested objects to device.

Recursively walks over every instance attribute and moves tensors, nn.Module sub-objects, and nested DeepObj objects to the requested device.

Parameters:

Name Type Description Default
device

Target device, e.g. "cuda", "cpu", or a torch.device instance.

required

Returns:

Name Type Description
DeepObj

self (for chaining).

Example

lens = GeoLens(filename="lens.json")
lens.to("cuda")  # move all tensors to GPU

Source code in deeplens/optics/base.py
def to(self, device):
    """Move all tensors and nested objects to *device*.

    Recursively walks over every instance attribute and moves tensors,
    ``nn.Module`` sub-objects, and nested ``DeepObj`` objects to the
    requested device.

    Args:
        device: Target device, e.g. ``"cuda"``, ``"cpu"``, or a
            ``torch.device`` instance.

    Returns:
        DeepObj: ``self`` (for chaining).

    Example:
        >>> lens = GeoLens(filename="lens.json")
        >>> lens.to("cuda")  # move all tensors to GPU
    """
    self.device = device

    for key, val in vars(self).items():
        if torch.is_tensor(val):
            exec(f"self.{key} = self.{key}.to(device)")
        elif isinstance(val, nn.Module):
            exec(f"self.{key}.to(device)")
        elif issubclass(type(val), DeepObj):
            exec(f"self.{key}.to(device)")
        elif val.__class__.__name__ in ("list", "tuple"):
            for i, v in enumerate(val):
                if torch.is_tensor(v):
                    exec(f"self.{key}[{i}] = self.{key}[{i}].to(device)")
                elif issubclass(type(v), DeepObj):
                    exec(f"self.{key}[{i}].to(device)")
    return self

astype

astype(dtype)

Convert all floating-point tensors to dtype.

Also calls torch.set_default_dtype(dtype) so that subsequent tensor creation uses the same precision.

Parameters:

Name Type Description Default
dtype dtype

Target floating-point dtype. Must be one of torch.float16, torch.float32, or torch.float64. Passing None is a no-op.

required

Returns:

Name Type Description
DeepObj

self (for chaining).

Raises:

Type Description
AssertionError

If dtype is not a recognised floating-point dtype.

Example

lens = GeoLens(filename="lens.json")
lens.astype(torch.float64)  # switch to double precision

Source code in deeplens/optics/base.py
def astype(self, dtype):
    """Convert all floating-point tensors to *dtype*.

    Also calls ``torch.set_default_dtype(dtype)`` so that subsequent
    tensor creation uses the same precision.

    Args:
        dtype (torch.dtype): Target floating-point dtype.  Must be one of
            ``torch.float16``, ``torch.float32``, or ``torch.float64``.
            Pass ``None`` to be a no-op.

    Returns:
        DeepObj: ``self`` (for chaining).

    Raises:
        AssertionError: If *dtype* is not a recognised floating-point dtype.

    Example:
        >>> lens = GeoLens(filename="lens.json")
        >>> lens.astype(torch.float64)  # switch to double precision
    """
    if dtype is None:
        return self

    dtype_ls = [torch.float16, torch.float32, torch.float64]
    assert dtype in dtype_ls, f"Data type {dtype} is not supported."

    if torch.get_default_dtype() != dtype:
        torch.set_default_dtype(dtype)
        print(f"Set {dtype} as default torch dtype.")

    self.dtype = dtype
    for key, val in vars(self).items():
        if torch.is_tensor(val) and val.dtype in dtype_ls:
            exec(f"self.{key} = self.{key}.to(dtype)")
        elif issubclass(type(val), DeepObj):
            exec(f"self.{key}.astype(dtype)")
        elif issubclass(type(val), list):
            for i, v in enumerate(val):
                if torch.is_tensor(v) and v.dtype in dtype_ls:
                    exec(f"self.{key}[{i}] = self.{key}[{i}].to(dtype)")
                elif issubclass(type(v), DeepObj):
                    exec(f"self.{key}[{i}].astype(dtype)")
    return self

Abstract base class for all lens types. Defines the shared interface: psf(), psf_rgb(), render(), etc.

deeplens.optics.Lens

Lens(dtype=torch.float32, device=None)

Bases: DeepObj

Initialize a lens class.

Parameters:

Name Type Description Default
dtype dtype

Data type. Defaults to torch.float32.

float32
device str

Device to run the lens. Defaults to None.

None
Source code in deeplens/optics/lens.py
def __init__(self, dtype=torch.float32, device=None):
    """Initialize a lens class.

    Args:
        dtype (torch.dtype, optional): Data type. Defaults to torch.float32.
        device (str, optional): Device to run the lens. Defaults to None.
    """
    # Lens device
    if device is None:
        self.device = init_device()
    else:
        self.device = torch.device(device)

    # Lens default dtype
    self.dtype = dtype

read_lens_json

read_lens_json(filename)

Read lens from a json file.

Source code in deeplens/optics/lens.py
def read_lens_json(self, filename):
    """Read lens from a json file."""
    raise NotImplementedError

write_lens_json

write_lens_json(filename)

Write lens to a json file.

Source code in deeplens/optics/lens.py
def write_lens_json(self, filename):
    """Write lens to a json file."""
    raise NotImplementedError

set_sensor

set_sensor(sensor_size, sensor_res)

Set sensor size and resolution.

Parameters:

Name Type Description Default
sensor_size tuple

Sensor size (w, h) in [mm].

required
sensor_res tuple

Sensor resolution (W, H) in [pixels].

required
Source code in deeplens/optics/lens.py
def set_sensor(self, sensor_size, sensor_res):
    """Set sensor size and resolution.

    Args:
        sensor_size (tuple): Sensor size (w, h) in [mm].
        sensor_res (tuple): Sensor resolution (W, H) in [pixels].
    """
    assert sensor_size[0] * sensor_res[1] == sensor_size[1] * sensor_res[0], (
        "Sensor resolution aspect ratio does not match sensor size aspect ratio."
    )
    self.sensor_size = sensor_size
    self.sensor_res = sensor_res
    self.pixel_size = self.sensor_size[0] / self.sensor_res[0]
    self.r_sensor = float(np.sqrt(sensor_size[0] ** 2 + sensor_size[1] ** 2)) / 2
    self.calc_fov()
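
Example

An illustrative configuration, assuming lens is a concrete lens instance (e.g. GeoLens): an 8 mm x 6 mm sensor at 4000 x 3000 pixels keeps the required 4:3 aspect ratio.

lens.set_sensor(sensor_size=(8.0, 6.0), sensor_res=(4000, 3000))
print(lens.pixel_size)  # 8.0 / 4000 = 0.002 mm (2 um pixel pitch)
print(lens.r_sensor)    # half sensor diagonal: sqrt(8**2 + 6**2) / 2 = 5.0 mm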

set_sensor_res

set_sensor_res(sensor_res)

Set sensor resolution (and aspect ratio) while keeping sensor radius unchanged.

Parameters:

Name Type Description Default
sensor_res tuple

Sensor resolution (W, H) in [pixels].

required
Source code in deeplens/optics/lens.py
def set_sensor_res(self, sensor_res):
    """Set sensor resolution (and aspect ratio) while keeping sensor radius unchanged.

    Args:
        sensor_res (tuple): Sensor resolution (W, H) in [pixels].
    """
    # Change sensor resolution
    self.sensor_res = sensor_res

    # Change sensor size (r_sensor is fixed)
    diam_res = float(np.sqrt(self.sensor_res[0] ** 2 + self.sensor_res[1] ** 2))
    self.sensor_size = (
        2 * self.r_sensor * self.sensor_res[0] / diam_res,
        2 * self.r_sensor * self.sensor_res[1] / diam_res,
    )
    self.pixel_size = self.sensor_size[0] / self.sensor_res[0]
    self.calc_fov()

calc_fov

calc_fov()

Compute FoV (radian) of the lens.

Reference

[1] https://en.wikipedia.org/wiki/Angle_of_view_(photography)

Source code in deeplens/optics/lens.py
@torch.no_grad()
def calc_fov(self):
    """Compute FoV (radian) of the lens.

    Reference:
        [1] https://en.wikipedia.org/wiki/Angle_of_view_(photography)
    """
    if not hasattr(self, "foclen"):
        return

    self.vfov = 2 * float(np.atan(self.sensor_size[0] / 2 / self.foclen))
    self.hfov = 2 * float(np.atan(self.sensor_size[1] / 2 / self.foclen))
    self.dfov = 2 * float(np.atan(self.r_sensor / self.foclen))
    self.rfov = self.dfov / 2  # radius (half diagonal) FoV
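
Example

A quick numeric check of the formula FoV = 2 * atan(sensor_extent / 2 / foclen); the focal length and sensor size below are illustrative values only, not taken from any bundled lens.

import numpy as np

foclen, r_sensor = 10.0, 5.0                  # [mm], assumed example values
dfov = 2 * float(np.arctan(r_sensor / foclen))
print(np.rad2deg(dfov))                       # ~53.13 deg diagonal FoV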

psf

psf(points, wvln=DEFAULT_WAVE, ks=PSF_KS, **kwargs)

Compute the monochromatic PSF for one or more point sources.

Subclasses must override this method with a differentiable implementation. Three computation models are common in practice: geometric ray binning, coherent ray-wave, and Huygens spherical-wave integration.

Parameters:

Name Type Description Default
points Tensor

Point source coordinates, shape [N, 3] or [3]. x, y are normalised to [-1, 1] (relative to the sensor half-diagonal); z is depth in mm (must be negative, i.e. in front of the lens).

required
wvln float

Wavelength in micrometers. Defaults to DEFAULT_WAVE (0.587 µm, d-line).

DEFAULT_WAVE
ks int

Output PSF kernel size in pixels. Defaults to PSF_KS (64).

PSF_KS
**kwargs

Additional keyword arguments forwarded to the underlying PSF computation (e.g. spp, model, recenter).

{}

Returns:

Type Description

torch.Tensor: PSF intensity map, shape [ks, ks] for a single point or [N, ks, ks] for a batch.

Raises:

Type Description
NotImplementedError

This base implementation must be overridden.

Notes

The method is differentiable with respect to all optimisable lens parameters so it can be used directly inside a training loop.

Example

point = torch.tensor([0.0, 0.0, -10000.0])
psf = lens.psf(points=point, ks=64, model="geometric")
print(psf.shape)  # torch.Size([64, 64])

Source code in deeplens/optics/lens.py
def psf(self, points, wvln=DEFAULT_WAVE, ks=PSF_KS, **kwargs):
    """Compute the monochromatic PSF for one or more point sources.

    Subclasses must override this method with a differentiable
    implementation.  Three computation models are common in practice:
    geometric ray binning, coherent ray-wave, and Huygens spherical-wave
    integration.

    Args:
        points (torch.Tensor): Point source coordinates, shape ``[N, 3]``
            or ``[3]``.  ``x, y`` are normalised to ``[-1, 1]``
            (relative to the sensor half-diagonal); ``z`` is depth in mm
            (must be negative, i.e. in front of the lens).
        wvln (float, optional): Wavelength in micrometers.  Defaults to
            ``DEFAULT_WAVE`` (0.587 µm, d-line).
        ks (int, optional): Output PSF kernel size in pixels.  Defaults
            to ``PSF_KS`` (64).
        **kwargs: Additional keyword arguments forwarded to the underlying
            PSF computation (e.g. ``spp``, ``model``, ``recenter``).

    Returns:
        torch.Tensor: PSF intensity map, shape ``[ks, ks]`` for a single
        point or ``[N, ks, ks]`` for a batch.

    Raises:
        NotImplementedError: This base implementation must be overridden.

    Notes:
        The method is differentiable with respect to all optimisable lens
        parameters so it can be used directly inside a training loop.

    Example:
        >>> point = torch.tensor([0.0, 0.0, -10000.0])
        >>> psf = lens.psf(points=point, ks=64, model="geometric")
        >>> print(psf.shape)  # torch.Size([64, 64])
    """
    raise NotImplementedError

psf_rgb

psf_rgb(points, ks=PSF_KS, **kwargs)

Compute the RGB (tri-chromatic) PSF by stacking three wavelength calls.

Calls psf() three times for the RGB primary wavelengths defined in WAVE_RGB and stacks the results along the channel axis.

Parameters:

Name Type Description Default
points Tensor

Point source coordinates, shape [N, 3] or [3]. Same convention as psf().

required
ks int

PSF kernel size. Defaults to PSF_KS.

PSF_KS
**kwargs

Forwarded to psf() (e.g. spp, model).

{}

Returns:

Type Description

torch.Tensor: RGB PSF, shape [3, ks, ks] for a single point or [N, 3, ks, ks] for a batch.

Source code in deeplens/optics/lens.py
def psf_rgb(self, points, ks=PSF_KS, **kwargs):
    """Compute the RGB (tri-chromatic) PSF by stacking three wavelength calls.

    Calls :meth:`psf` three times for the RGB primary wavelengths defined
    in ``WAVE_RGB`` and stacks the results along the channel axis.

    Args:
        points (torch.Tensor): Point source coordinates, shape ``[N, 3]``
            or ``[3]``.  Same convention as :meth:`psf`.
        ks (int, optional): PSF kernel size. Defaults to ``PSF_KS``.
        **kwargs: Forwarded to :meth:`psf` (e.g. ``spp``, ``model``).

    Returns:
        torch.Tensor: RGB PSF, shape ``[3, ks, ks]`` for a single point
        or ``[N, 3, ks, ks]`` for a batch.
    """
    psfs = []
    for wvln in WAVE_RGB:
        psfs.append(self.psf(points=points, ks=ks, wvln=wvln, **kwargs))
    psf_rgb = torch.stack(psfs, dim=-3)  # shape [3, ks, ks] or [N, 3, ks, ks]
    return psf_rgb

point_source_grid

point_source_grid(depth, grid=(9, 9), normalized=True, quater=False, center=True)

Generate point source grid for PSF calculation.

Parameters:

Name Type Description Default
depth float

Depth of the point source.

required
grid tuple

Grid size (grid_w, grid_h). Defaults to (9, 9), meaning 9x9 grid.

(9, 9)
normalized bool

Return normalized object source coordinates. Defaults to True, meaning the object source x, y coordinates range over [-1, 1].

True
quater bool

Use a quarter of the sensor plane to save memory. Defaults to False.

False
center bool

Use center of each patch. Defaults to True.

True

Returns:

Name Type Description
point_source

Normalized object source coordinates, shape [grid_h, grid_w, 3]; x and y in [-1, 1], z in (-Inf, 0).

Source code in deeplens/optics/lens.py
def point_source_grid(
    self, depth, grid=(9, 9), normalized=True, quater=False, center=True
):
    """Generate point source grid for PSF calculation.

    Args:
        depth (float): Depth of the point source.
        grid (tuple): Grid size (grid_w, grid_h). Defaults to (9, 9), meaning 9x9 grid.
        normalized (bool): Return normalized object source coordinates. Defaults to True, meaning object sources xy coordinates range from [-1, 1].
        quater (bool): Use quater of the sensor plane to save memory. Defaults to False.
        center (bool): Use center of each patch. Defaults to True.

    Returns:
        point_source: Normalized object source coordinates. Shape of [grid_h, grid_w, 3], [-1, 1], [-1, 1], [-Inf, 0].
    """
    # Compute point source grid
    if grid[0] == 1:
        x, y = torch.tensor([[0.0]]), torch.tensor([[0.0]])
        assert not quater, "Quater should be False when grid is 1."
    else:
        if center:
            # Use center of each patch
            half_bin_size = 1 / 2 / (grid[0] - 1)
            x, y = torch.meshgrid(
                torch.linspace(-1 + half_bin_size, 1 - half_bin_size, grid[0]),
                torch.linspace(1 - half_bin_size, -1 + half_bin_size, grid[1]),
                indexing="xy",
            )
        else:
            # Use corner of image sensor
            x, y = torch.meshgrid(
                torch.linspace(-0.98, 0.98, grid[0]),
                torch.linspace(0.98, -0.98, grid[1]),
                indexing="xy",
            )

    z = torch.full_like(x, depth)
    point_source = torch.stack([x, y, z], dim=-1)

    # Use quater of the sensor plane to save memory
    if quater:
        z = torch.full_like(x, depth)
        point_source = torch.stack([x, y, z], dim=-1)
        bound_i = grid[0] // 2 if grid[0] % 2 == 0 else grid[0] // 2 + 1
        bound_j = grid[1] // 2
        point_source = point_source[0:bound_i, bound_j:, :]

    # De-normalize object source coordinates to physical coordinates
    if not normalized:
        scale = self.calc_scale(depth)
        point_source[..., 0] *= scale * self.sensor_size[0] / 2
        point_source[..., 1] *= scale * self.sensor_size[1] / 2

    return point_source
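
Example

A sketch of building a grid of normalized point sources at one depth and evaluating their PSFs, assuming lens is a concrete lens (e.g. GeoLens) that implements psf(); the depth and kernel size are example values.

# 5 x 5 grid of normalized point sources 2 m in front of the lens (negative z).
points = lens.point_source_grid(depth=-2000.0, grid=(5, 5))
print(points.shape)              # torch.Size([5, 5, 3])

points = points.reshape(-1, 3)   # flatten to [25, 3] for a batched PSF call
psfs = lens.psf(points=points, ks=64)  # shape [25, 64, 64]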

psf_map

psf_map(grid=(5, 5), wvln=DEFAULT_WAVE, depth=DEPTH, ks=PSF_KS, **kwargs)

Compute monochrome PSF map.

Parameters:

Name Type Description Default
grid tuple

Grid size (grid_w, grid_h). Defaults to (5, 5), meaning 5x5 grid.

(5, 5)
wvln float

Wavelength. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
depth float

Depth of the object. Defaults to DEPTH.

DEPTH
ks int

Kernel size. Defaults to PSF_KS.

PSF_KS

Returns:

Name Type Description
psf_map

Shape of [grid_h, grid_w, 1, ks, ks].

Source code in deeplens/optics/lens.py
def psf_map(self, grid=(5, 5), wvln=DEFAULT_WAVE, depth=DEPTH, ks=PSF_KS, **kwargs):
    """Compute monochrome PSF map.

    Args:
        grid (tuple): Grid size (grid_w, grid_h). Defaults to (5, 5), meaning 5x5 grid.
        wvln (float): Wavelength. Defaults to DEFAULT_WAVE.
        depth (float): Depth of the object. Defaults to DEPTH.
        ks (int): Kernel size. Defaults to PSF_KS.

    Returns:
        psf_map: Shape of [grid_h, grid_w, 3, ks, ks].
    """
    # PSF map grid
    points = self.point_source_grid(depth=depth, grid=grid, center=True)
    points = points.reshape(-1, 3)

    # Compute PSF map
    psfs = []
    for i in range(points.shape[0]):
        point = points[i, ...]
        psf = self.psf(points=point, wvln=wvln, ks=ks)
        psfs.append(psf)
    psf_map = torch.stack(psfs).unsqueeze(1)  # shape [grid_h * grid_w, 1, ks, ks]

    # Reshape PSF map from [grid_h * grid_w, 1, ks, ks] -> [grid_h, grid_w, 1, ks, ks]
    psf_map = psf_map.reshape(grid[1], grid[0], 1, ks, ks)
    return psf_map

psf_map_rgb

psf_map_rgb(grid=(5, 5), ks=PSF_KS, depth=DEPTH, **kwargs)

Compute RGB PSF map.

Parameters:

Name Type Description Default
grid tuple

Grid size (grid_w, grid_h). Defaults to (5, 5), meaning 5x5 grid.

(5, 5)
ks int

Kernel size. Defaults to PSF_KS, meaning PSF_KS x PSF_KS kernel size.

PSF_KS
depth float

Depth of the object. Defaults to DEPTH.

DEPTH
**kwargs

Additional arguments for psf_map().

{}

Returns:

Name Type Description
psf_map

Shape of [grid_h, grid_w, 3, ks, ks].

Source code in deeplens/optics/lens.py
def psf_map_rgb(self, grid=(5, 5), ks=PSF_KS, depth=DEPTH, **kwargs):
    """Compute RGB PSF map.

    Args:
        grid (tuple): Grid size (grid_w, grid_h). Defaults to (5, 5), meaning 5x5 grid.
        ks (int): Kernel size. Defaults to PSF_KS, meaning PSF_KS x PSF_KS kernel size.
        depth (float): Depth of the object. Defaults to DEPTH.
        **kwargs: Additional arguments for psf_map().

    Returns:
        psf_map: Shape of [grid_h, grid_w, 3, ks, ks].
    """
    psfs = []
    for wvln in WAVE_RGB:
        psf_map = self.psf_map(grid=grid, ks=ks, depth=depth, wvln=wvln, **kwargs)
        psfs.append(psf_map)
    psf_map = torch.cat(psfs, dim=2)  # shape [grid_h, grid_w, 3, ks, ks]
    return psf_map
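
Example

A usage sketch for sampling an RGB PSF map over the field, assuming lens is a concrete lens (e.g. GeoLens); grid size and kernel size are illustrative.

# RGB PSF map on a 5 x 5 field grid at the default depth.
psf_map = lens.psf_map_rgb(grid=(5, 5), ks=64)
print(psf_map.shape)                      # torch.Size([5, 5, 3, 64, 64]) -> [grid_h, grid_w, 3, ks, ks]

kernels = psf_map.reshape(-1, 3, 64, 64)  # flatten the field grid into a batch of RGB kernels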

draw_psf_map

draw_psf_map(grid=(7, 7), ks=PSF_KS, depth=DEPTH, log_scale=False, save_name='./psf_map.png', show=False)

Draw RGB PSF map of the lens.

Source code in deeplens/optics/lens.py
@torch.no_grad()
def draw_psf_map(
    self,
    grid=(7, 7),
    ks=PSF_KS,
    depth=DEPTH,
    log_scale=False,
    save_name="./psf_map.png",
    show=False,
):
    """Draw RGB PSF map of the lens."""
    # Calculate RGB PSF map, shape [grid_h, grid_w, 3, ks, ks]
    psf_map = self.psf_map_rgb(depth=depth, grid=grid, ks=ks)

    # Create a grid visualization (vis_map: shape [3, grid_h * ks, grid_w * ks])
    grid_w, grid_h = grid if isinstance(grid, tuple) else (grid, grid)
    h, w = grid_h * ks, grid_w * ks
    vis_map = torch.zeros((3, h, w), device=psf_map.device, dtype=psf_map.dtype)

    # Put each PSF into the vis_map
    for i in range(grid_h):
        for j in range(grid_w):
            # Extract the PSF at this grid position
            psf = psf_map[i, j]  # shape [3, ks, ks]

            # Normalize the PSF
            if log_scale:
                # Log scale normalization for better visualization
                psf = torch.log(psf + 1e-4)  # 1e-4 is an empirical value
                psf = (psf - psf.min()) / (psf.max() - psf.min() + 1e-8)
            else:
                # Linear normalization
                local_max = psf.max()
                if local_max > 0:
                    psf = psf / local_max

            # Place the normalized PSF in the visualization map
            y_start, y_end = i * ks, (i + 1) * ks
            x_start, x_end = j * ks, (j + 1) * ks
            vis_map[:, y_start:y_end, x_start:x_end] = psf

    # Create the figure and display
    fig, ax = plt.subplots(figsize=(10, 10))

    # Convert to numpy for plotting
    vis_map = vis_map.permute(1, 2, 0).cpu().numpy()
    ax.imshow(vis_map)

    # Add scale bar near bottom-left
    H, W, _ = vis_map.shape
    scale_bar_length = 100
    arrow_length = scale_bar_length / (self.pixel_size * 1e3)
    y_position = H - 20  # a little above the lower edge
    x_start = 20
    x_end = x_start + arrow_length

    ax.annotate(
        "",
        xy=(x_start, y_position),
        xytext=(x_end, y_position),
        arrowprops=dict(arrowstyle="-", color="white"),
        annotation_clip=False,
    )
    ax.text(
        x_end + 5,
        y_position,
        f"{scale_bar_length} μm",
        color="white",
        fontsize=12,
        ha="left",
        va="center",
        clip_on=False,
    )

    # Clean up axes and save
    ax.axis("off")
    plt.tight_layout(pad=0)

    if show:
        return fig, ax
    else:
        plt.savefig(save_name, dpi=300, bbox_inches="tight", pad_inches=0)
        plt.close(fig)

point_source_radial

point_source_radial(depth, grid=9, center=False)

Sample radial point sources (normalized radius in [0, 1]) in object space for computing a radial PSF grid.

Parameters:

Name Type Description Default
grid int

Grid size. Defaults to 9.

9

Returns:

Name Type Description
point_source

Shape of [grid, 3].

Source code in deeplens/optics/lens.py
def point_source_radial(self, depth, grid=9, center=False):
    """Compute point radial [0, 1] in the object space to compute PSF grid.

    Args:
        grid (int, optional): Grid size. Defaults to 9.

    Returns:
        point_source: Shape of [grid, 3].
    """
    if grid == 1:
        x = torch.tensor([0.0])
    else:
        # Select center of bin to calculate PSF
        if center:
            half_bin_size = 1 / 2 / (grid - 1)
            x = torch.linspace(0, 1 - half_bin_size, grid)
        else:
            x = torch.linspace(0, 0.98, grid)

    z = torch.full_like(x, depth)
    point_source = torch.stack([x, x, z], dim=-1)
    return point_source

draw_psf_radial

draw_psf_radial(M=3, depth=DEPTH, ks=PSF_KS, log_scale=False, save_name='./psf_radial.png')

Draw radial PSF (45 deg). Will draw M PSFs, each of size ks x ks.

Source code in deeplens/optics/lens.py
@torch.no_grad()
def draw_psf_radial(
    self, M=3, depth=DEPTH, ks=PSF_KS, log_scale=False, save_name="./psf_radial.png"
):
    """Draw radial PSF (45 deg). Will draw M PSFs, each of size ks x ks."""
    x = torch.linspace(0, 1, M)
    y = torch.linspace(0, 1, M)
    z = torch.full_like(x, depth)
    points = torch.stack((x, y, z), dim=-1)

    psfs = []
    for i in range(M):
        # Scale PSF for a better visualization
        psf = self.psf_rgb(points=points[i], ks=ks, recenter=True, spp=SPP_PSF)
        psf /= psf.max()

        if log_scale:
            psf = torch.log(psf + EPSILON)
            psf = (psf - psf.min()) / (psf.max() - psf.min())

        psfs.append(psf)

    psf_grid = make_grid(psfs, nrow=M, padding=1, pad_value=0.0)
    save_image(psf_grid, save_name, normalize=True)

render

render(img_obj, depth=DEPTH, method='psf_patch', **kwargs)

Differentiable image simulation for a 2D (flat) scene.

Performs only the optical component of image simulation and is fully differentiable. Sensor noise is handled separately by the deeplens.camera.Camera class.

For incoherent imaging the intensity PSF is convolved with the object-space image. For coherent imaging the complex PSF is convolved with the complex object image before squaring for intensity.

Parameters:

Name Type Description Default
img_obj Tensor

Input image in linear (raw) space, shape [B, C, H, W].

required
depth float

Object depth in mm (negative value). Defaults to DEPTH (-20 000 mm, i.e. infinity).

DEPTH
method str

Rendering method. One of:

  • "psf_patch" – convolve a single PSF evaluated at patch_center (default).
  • "psf_map" – spatially-varying PSF block convolution.
'psf_patch'
**kwargs

Method-specific keyword arguments:

  • For "psf_map": psf_grid (tuple, default (10, 10)), psf_ks (int, default PSF_KS).
  • For "psf_patch": patch_center (tuple or Tensor, default (0.0, 0.0)), psf_ks (int).
{}

Returns:

Type Description

torch.Tensor: Rendered image, shape [B, C, H, W].

Raises:

Type Description
AssertionError

If method is "psf_map" and the image resolution does not match the sensor resolution.

Exception

If method is not recognised.

References

[1] "Optical Aberration Correction in Postprocessing using Imaging Simulation", TOG 2021. [2] "Efficient depth- and spatially-varying image simulation for defocus deblur", ICCVW 2025.

Example

img_rendered = lens.render(img, depth=-10000.0, method="psf_patch",
                           patch_center=(0.3, 0.0), psf_ks=64)

Source code in deeplens/optics/lens.py
def render(self, img_obj, depth=DEPTH, method="psf_patch", **kwargs):
    """Differentiable image simulation for a 2D (flat) scene.

    Performs only the optical component of image simulation and is fully
    differentiable.  Sensor noise is handled separately by the
    :class:`~deeplens.camera.Camera` class.

    For incoherent imaging the intensity PSF is convolved with the
    object-space image.  For coherent imaging the complex PSF is convolved
    with the complex object image before squaring for intensity.

    Args:
        img_obj (torch.Tensor): Input image in linear (raw) space,
            shape ``[B, C, H, W]``.
        depth (float, optional): Object depth in mm (negative value).
            Defaults to ``DEPTH`` (-20 000 mm, i.e. infinity).
        method (str, optional): Rendering method.  One of:

            * ``"psf_patch"`` – convolve a single PSF evaluated at
              *patch_center* (default).
            * ``"psf_map"`` – spatially-varying PSF block convolution.

        **kwargs: Method-specific keyword arguments:

            * For ``"psf_map"``: ``psf_grid`` (tuple, default ``(10, 10)``),
              ``psf_ks`` (int, default ``PSF_KS``).
            * For ``"psf_patch"``: ``patch_center`` (tuple or Tensor,
              default ``(0.0, 0.0)``), ``psf_ks`` (int).

    Returns:
        torch.Tensor: Rendered image, shape ``[B, C, H, W]``.

    Raises:
        AssertionError: If *method* is ``"psf_map"`` and the image
            resolution does not match the sensor resolution.
        Exception: If *method* is not recognised.

    References:
        [1] "Optical Aberration Correction in Postprocessing using Imaging Simulation", TOG 2021.
        [2] "Efficient depth- and spatially-varying image simulation for defocus deblur", ICCVW 2025.

    Example:
        >>> img_rendered = lens.render(img, depth=-10000.0, method="psf_patch",
        ...                            patch_center=(0.3, 0.0), psf_ks=64)
    """
    # Check sensor resolution
    B, C, Himg, Wimg = img_obj.shape
    Wsensor, Hsensor = self.sensor_res

    # Image simulation (in RAW space)
    if method == "psf_map":
        # Render full resolution image with PSF map convolution
        assert Wimg == Wsensor and Himg == Hsensor, (
            f"Sensor resolution {Wsensor}x{Hsensor} must match input image {Wimg}x{Himg}."
        )
        psf_grid = kwargs.get("psf_grid", (10, 10))
        psf_ks = kwargs.get("psf_ks", PSF_KS)
        img_render = self.render_psf_map(
            img_obj, depth=depth, psf_grid=psf_grid, psf_ks=psf_ks
        )

    elif method == "psf_patch":
        # Render an image patch with its corresponding PSF
        patch_center = kwargs.get("patch_center", (0.0, 0.0))
        psf_ks = kwargs.get("psf_ks", PSF_KS)
        img_render = self.render_psf_patch(
            img_obj, depth=depth, patch_center=patch_center, psf_ks=psf_ks
        )

    elif method == "psf_pixel":
        raise NotImplementedError(
            "Per-pixel PSF convolution has not been implemented."
        )

    else:
        raise Exception(f"Image simulation method {method} is not supported.")

    return img_render

render_psf

render_psf(img_obj, depth=DEPTH, patch_center=(0, 0), psf_ks=PSF_KS)

Render an image patch using PSF convolution. Prefer calling render_psf_patch() directly; this wrapper only forwards to it and is easy to confuse with the other render methods.

Source code in deeplens/optics/lens.py
def render_psf(self, img_obj, depth=DEPTH, patch_center=(0, 0), psf_ks=PSF_KS):
    """Render image patch using PSF convolution. Better not use this function to avoid confusion."""
    return self.render_psf_patch(
        img_obj, depth=depth, patch_center=patch_center, psf_ks=psf_ks
    )

render_psf_patch

render_psf_patch(img_obj, depth=DEPTH, patch_center=(0, 0), psf_ks=PSF_KS)

Render an image patch using PSF convolution.

Parameters:

Name Type Description Default
img_obj tensor

Input image object in raw space. Shape of [B, C, H, W].

required
depth float

Depth of the object.

DEPTH
patch_center tensor

Center of the image patch. Shape of [2] or [B, 2].

(0, 0)
psf_ks int

PSF kernel size. Defaults to PSF_KS.

PSF_KS

Returns:

Name Type Description
img_render

Rendered image. Shape of [B, C, H, W].

Source code in deeplens/optics/lens.py
def render_psf_patch(self, img_obj, depth=DEPTH, patch_center=(0, 0), psf_ks=PSF_KS):
    """Render an image patch using PSF convolution, and return positional encoding channel.

    Args:
        img_obj (tensor): Input image object in raw space. Shape of [B, C, H, W].
        depth (float): Depth of the object.
        patch_center (tensor): Center of the image patch. Shape of [2] or [B, 2].
        psf_ks (int): PSF kernel size. Defaults to PSF_KS.

    Returns:
        img_render: Rendered image. Shape of [B, C, H, W].
    """
    # Convert patch_center to tensor
    if isinstance(patch_center, (list, tuple)):
        points = (patch_center[0], patch_center[1], depth)
        points = torch.tensor(points).unsqueeze(0)
    elif isinstance(patch_center, torch.Tensor):
        depth = torch.full_like(patch_center[..., 0], depth)
        points = torch.stack(
            [patch_center[..., 0], patch_center[..., 1], depth], dim=-1
        )
    else:
        raise Exception(
            f"Patch center must be a list or tuple or tensor, but got {type(patch_center)}."
        )

    # Compute PSF and perform PSF convolution
    psf = self.psf_rgb(points=points, ks=psf_ks).squeeze(0)
    img_render = conv_psf(img_obj, psf=psf)
    return img_render

render_psf_map

render_psf_map(img_obj, depth=DEPTH, psf_grid=7, psf_ks=PSF_KS)

Render image using PSF block convolution.

Note

Larger psf_grid and psf_ks values typically give a more accurate rendering, but are slower.

Parameters:

Name Type Description Default
img_obj tensor

Input image object in raw space. Shape of [B, C, H, W].

required
depth float

Depth of the object.

DEPTH
psf_grid int

PSF grid size.

7
psf_ks int

PSF kernel size. Defaults to PSF_KS.

PSF_KS

Returns:

Name Type Description
img_render

Rendered image. Shape of [B, C, H, W].

Source code in deeplens/optics/lens.py
def render_psf_map(self, img_obj, depth=DEPTH, psf_grid=7, psf_ks=PSF_KS):
    """Render image using PSF block convolution.

    Note:
        Larger psf_grid and psf_ks are typically better for more accurate rendering, but slower.

    Args:
        img_obj (tensor): Input image object in raw space. Shape of [B, C, H, W].
        depth (float): Depth of the object.
        psf_grid (int): PSF grid size.
        psf_ks (int): PSF kernel size. Defaults to PSF_KS.

    Returns:
        img_render: Rendered image. Shape of [B, C, H, W].
    """
    psf_map = self.psf_map_rgb(grid=psf_grid, ks=psf_ks, depth=depth)
    img_render = conv_psf_map(img_obj, psf_map)
    return img_render
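
Example

A sketch of full-frame rendering with the spatially varying PSF map, assuming lens is a concrete lens (e.g. GeoLens) with sensor_res set; the input image here is synthetic and the grid/kernel sizes are example values.

import torch

W, H = lens.sensor_res                        # sensor_res is (W, H) in pixels
img = torch.rand(1, 3, H, W, device=lens.device)  # synthetic linear-space image [B, C, H, W]

# Spatially varying blur: one RGB PSF per block of an 8 x 8 field grid.
img_blur = lens.render_psf_map(img, depth=-5000.0, psf_grid=(8, 8), psf_ks=64)
print(img_blur.shape)                         # torch.Size([1, 3, H, W])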

render_rgbd

render_rgbd(img_obj, depth_map, method='psf_patch', **kwargs)

Render RGBD image.

TODO: add obstruction-aware image simulation.

Parameters:

Name Type Description Default
img_obj tensor

Object image. Shape of [B, C, H, W].

required
depth_map tensor

Depth map [mm]. Shape of [B, 1, H, W]. Values should be positive.

required
method str

Image simulation method. Defaults to "psf_patch".

'psf_patch'
**kwargs

Additional arguments for different methods, e.g. interp_mode (str): "depth" or "disparity". Defaults to "disparity".

{}

Returns:

Name Type Description
img_render

Rendered image. Shape of [B, C, H, W].

Reference

[1] "Aberration-Aware Depth-from-Focus", TPAMI 2023. [2] "Efficient Depth- and Spatially-Varying Image Simulation for Defocus Deblur", ICCVW 2025.

Source code in deeplens/optics/lens.py
def render_rgbd(self, img_obj, depth_map, method="psf_patch", **kwargs):
    """Render RGBD image.

    TODO: add obstruction-aware image simulation.

    Args:
        img_obj (tensor): Object image. Shape of [B, C, H, W].
        depth_map (tensor): Depth map [mm]. Shape of [B, 1, H, W]. Values should be positive.
        method (str, optional): Image simulation method. Defaults to "psf_patch".
        **kwargs: Additional arguments for different methods.
            - interp_mode (str): "depth" or "disparity". Defaults to "depth".

    Returns:
        img_render: Rendered image. Shape of [B, C, H, W].

    Reference:
        [1] "Aberration-Aware Depth-from-Focus", TPAMI 2023.
        [2] "Efficient Depth- and Spatially-Varying Image Simulation for Defocus Deblur", ICCVW 2025.
    """
    if depth_map.min() < 0:
        raise ValueError("Depth map should be positive.")

    if len(depth_map.shape) == 3:
        # [B, H, W] -> [B, 1, H, W]
        depth_map = depth_map.unsqueeze(1)

    if method == "psf_patch":
        # Render an image patch (same FoV, different depth)
        patch_center = kwargs.get("patch_center", (0.0, 0.0))
        psf_ks = kwargs.get("psf_ks", PSF_KS)
        depth_min = kwargs.get("depth_min", depth_map.min())
        depth_max = kwargs.get("depth_max", depth_map.max())
        num_layers = kwargs.get("num_layers", 16)
        interp_mode = kwargs.get("interp_mode", "disparity")

        # Calculate PSF at different depths, (num_layers, 3, ks, ks)
        disp_ref, depths_ref = self._sample_depth_layers(depth_min, depth_max, num_layers)

        points = torch.stack(
            [
                torch.full_like(depths_ref, patch_center[0]),
                torch.full_like(depths_ref, patch_center[1]),
                depths_ref,
            ],
            dim=-1,
        )
        psfs = self.psf_rgb(points=points, ks=psf_ks) # (num_layers, 3, ks, ks)

        # Image simulation
        img_render = conv_psf_depth_interp(img_obj, -depth_map, psfs, depths_ref, interp_mode=interp_mode)
        return img_render

    elif method == "psf_map":
        # Render full resolution image with PSF map convolution (different FoV, different depth)
        psf_grid = kwargs.get("psf_grid", (8, 8))  # (grid_w, grid_h)
        psf_ks = kwargs.get("psf_ks", PSF_KS)
        depth_min = kwargs.get("depth_min", depth_map.min())
        depth_max = kwargs.get("depth_max", depth_map.max())
        num_layers = kwargs.get("num_layers", 16)
        interp_mode = kwargs.get("interp_mode", "disparity")

        # Calculate PSF map at different depths (convert to negative for PSF calculation)
        disp_ref, depths_ref = self._sample_depth_layers(depth_min, depth_max, num_layers)

        psf_maps = []
        for depth in tqdm(depths_ref):
            psf_map = self.psf_map_rgb(grid=psf_grid, ks=psf_ks, depth=depth)
            psf_maps.append(psf_map)
        psf_map = torch.stack(
            psf_maps, dim=2
        )  # shape [grid_h, grid_w, num_layers, 3, ks, ks]

        # Image simulation
        img_render = conv_psf_map_depth_interp(
            img_obj, -depth_map, psf_map, depths_ref, interp_mode=interp_mode
        )
        return img_render

    elif method == "psf_pixel":
        # Render full resolution image with pixel-wise PSF convolution. This method is computationally expensive.
        psf_ks = kwargs.get("psf_ks", PSF_KS)
        assert img_obj.shape[0] == 1, "Now only support batch size 1"

        # Calculate points in the object space
        points_xy = torch.meshgrid(
            torch.linspace(-1, 1, img_obj.shape[-1], device=self.device),
            torch.linspace(1, -1, img_obj.shape[-2], device=self.device),
            indexing="xy",
        )
        points_xy = torch.stack(points_xy, dim=0).unsqueeze(0)
        points = torch.cat([points_xy, -depth_map], dim=1)  # shape [B, 3, H, W]

        # Calculate PSF at different pixels. This step is the most time-consuming.
        points = points.permute(0, 2, 3, 1).reshape(-1, 3)  # shape [H*W, 3]
        psfs = self.psf_rgb(points=points, ks=psf_ks)  # shape [H*W, 3, ks, ks]
        psfs = psfs.reshape(
            img_obj.shape[-2], img_obj.shape[-1], 3, psf_ks, psf_ks
        )  # shape [H, W, 3, ks, ks]

        # Image simulation
        img_render = conv_psf_pixel(img_obj, psfs)  # shape [1, C, H, W]
        return img_render

    else:
        raise Exception(f"Image simulation method {method} is not supported.")

activate_grad

activate_grad(activate=True)

Activate gradient for each surface.

Source code in deeplens/optics/lens.py
def activate_grad(self, activate=True):
    """Activate gradient for each surface."""
    raise NotImplementedError

get_optimizer_params

get_optimizer_params(lr=[0.0001, 0.0001, 0.1, 0.001])

Get optimizer parameters for different lens parameters.

Source code in deeplens/optics/lens.py
def get_optimizer_params(self, lr=[1e-4, 1e-4, 1e-1, 1e-3]):
    """Get optimizer parameters for different lens parameters."""
    raise NotImplementedError

get_optimizer

get_optimizer(lr=[0.0001, 0.0001, 0, 0.001])

Get optimizer.

Source code in deeplens/optics/lens.py
def get_optimizer(self, lr=[1e-4, 1e-4, 0, 1e-3]):
    """Get optimizer."""
    params = self.get_optimizer_params(lr)
    optimizer = torch.optim.Adam(params)
    return optimizer
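
Example

A hedged sketch of plugging the optimizer into a differentiable PSF objective, assuming lens is a concrete lens (e.g. GeoLens) that implements get_optimizer_params() and psf(); the meaning of each learning-rate entry is defined by that subclass, and the target below is a placeholder.

import torch

optimizer = lens.get_optimizer(lr=[1e-4, 1e-4, 0, 1e-3])

target_psf = torch.zeros(3, 64, 64, device=lens.device)   # placeholder target
point = torch.tensor([0.0, 0.0, -10000.0])                 # on-axis point at 10 m

for step in range(100):
    optimizer.zero_grad()
    psf = lens.psf_rgb(points=point, ks=64)   # differentiable w.r.t. lens parameters
    loss = torch.nn.functional.mse_loss(psf, target_psf)
    loss.backward()
    optimizer.step()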

Lens Models

Differentiable multi-element refractive lens via geometric ray tracing. This is the primary lens model in DeepLens.

GeoLens uses a mixin architecture — functionality is split across GeoLensEval, GeoLensOptim, GeoLensVis, GeoLensIO, GeoLensTolerance, and GeoLensVis3D.

deeplens.optics.GeoLens

GeoLens(filename=None, device=None, dtype=torch.float32)

Bases: Lens, GeoLensEval, GeoLensOptim, GeoLensVis, GeoLensIO, GeoLensTolerance, GeoLensVis3D

Differentiable geometric lens using vectorised ray tracing.

The primary lens model in DeepLens. Supports multi-element refractive (and partially reflective) systems loaded from JSON, Zemax .zmx, or Code V .seq files. Accuracy is aligned with Zemax OpticStudio.

Uses a mixin architecture – six specialised mixin classes are composed at class definition time to keep each concern isolated:

  • GeoLensEval (deeplens.optics.geolens_pkg.eval) – optical performance evaluation (spot, MTF, distortion, vignetting).
  • GeoLensOptim (deeplens.optics.geolens_pkg.optim) – loss functions and gradient-based optimisation.
  • GeoLensVis (deeplens.optics.geolens_pkg.vis) – 2-D layout and ray visualisation.
  • GeoLensIO (deeplens.optics.geolens_pkg.io) – read/write JSON, Zemax .zmx.
  • GeoLensTolerance (deeplens.optics.geolens_pkg.tolerance) – manufacturing tolerance analysis.
  • GeoLensVis3D (deeplens.optics.geolens_pkg.view_3d) – 3-D mesh visualisation.

Key differentiability trick: Ray-surface intersection (Surface.newtons_method) uses a non-differentiable Newton loop followed by one differentiable Newton step to enable gradient flow.

Attributes:

Name Type Description
surfaces list[Surface]

Ordered list of optical surfaces.

materials list[Material]

Optical materials between surfaces.

d_sensor Tensor

Back focal distance [mm].

foclen float

Effective focal length [mm].

fnum float

F-number.

rfov float

Half-diagonal field of view [radians].

sensor_size tuple

Physical sensor size (W, H) [mm].

sensor_res tuple

Sensor resolution (W, H) [pixels].

pixel_size float

Pixel pitch [mm].

References

Xinge Yang et al., "Curriculum learning for ab initio deep learned refractive optics," Nature Communications 2024.

Initialize a refractive lens.

There are two ways to initialize a GeoLens
  1. Read a lens from .json/.zmx/.seq file
  2. Initialize a lens with no lens file, then manually add surfaces and materials

Parameters:

Name Type Description Default
filename str

Path to lens file (.json, .zmx, or .seq). Defaults to None.

None
device device

Device for tensor computations. Defaults to None.

None
dtype dtype

Data type for computations. Defaults to torch.float32.

float32
Source code in deeplens/optics/geolens.py
def __init__(
    self,
    filename=None,
    device=None,
    dtype=torch.float32,
):
    """Initialize a refractive lens.

    There are two ways to initialize a GeoLens:
        1. Read a lens from .json/.zmx/.seq file
        2. Initialize a lens with no lens file, then manually add surfaces and materials

    Args:
        filename (str, optional): Path to lens file (.json, .zmx, or .seq). Defaults to None.
        device (torch.device, optional): Device for tensor computations. Defaults to None.
        dtype (torch.dtype, optional): Data type for computations. Defaults to torch.float32.
    """
    super().__init__(device=device, dtype=dtype)

    # Load lens file
    if filename is not None:
        self.read_lens(filename)
    else:
        self.surfaces = []
        self.materials = []
        # Set default sensor size and resolution
        self.sensor_size = (8.0, 8.0)
        self.sensor_res = (2000, 2000)
        self.to(self.device)
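
Example

Both construction paths described above, as a sketch; the file path is a placeholder and the manual-build calls are commented because they depend on the surfaces and materials you add.

from deeplens.optics import GeoLens   # import path taken from this page's headings

# 1. Load an existing design (.json, .zmx, or .seq).
lens = GeoLens(filename="./lenses/example_lens.json")

# 2. Start from an empty lens and build the system manually.
lens2 = GeoLens()
# lens2.surfaces.append(...)    # add Surface objects in order
# lens2.materials.append(...)   # add Material objects between surfaces
# lens2.post_computation()      # then refresh foclen, pupils, and FoV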

read_lens

read_lens(filename)

Read a GeoLens from a file.

Supported file formats
  • .json: DeepLens native JSON format
  • .zmx: Zemax lens file format
  • .seq: CODE V sequence file format

Parameters:

Name Type Description Default
filename str

Path to the lens file.

required
Note

Sensor size and resolution will usually be overwritten by values from the file.

Source code in deeplens/optics/geolens.py
def read_lens(self, filename):
    """Read a GeoLens from a file.

    Supported file formats:
        - .json: DeepLens native JSON format
        - .zmx: Zemax lens file format
        - .seq: CODE V sequence file format

    Args:
        filename (str): Path to the lens file.

    Note:
        Sensor size and resolution will usually be overwritten by values from the file.
    """
    # Load lens file
    if filename[-4:] == ".txt":
        raise ValueError("File format .txt has been deprecated.")
    elif filename[-5:] == ".json":
        self.read_lens_json(filename)
    elif filename[-4:] == ".zmx":
        self.read_lens_zmx(filename)
    elif filename[-4:] == ".seq":
        self.read_lens_seq(filename)
    else:
        raise ValueError(f"File format {filename[-4:]} not supported.")

    # Complete sensor size and resolution if not set from lens file
    if not hasattr(self, "sensor_size"):
        self.sensor_size = (8.0, 8.0)
        print(
            f"Sensor_size not found in lens file. Using default: {self.sensor_size} mm. "
            "Consider specifying sensor_size in the lens file or using set_sensor()."
        )

    if not hasattr(self, "sensor_res"):
        self.sensor_res = (2000, 2000)
        print(
            f"Sensor_res not found in lens file. Using default: {self.sensor_res} pixels. "
            "Consider specifying sensor_res in the lens file or using set_sensor()."
        )
        self.set_sensor_res(self.sensor_res)

    # After loading lens, compute foclen, fov and fnum
    self.to(self.device)
    self.astype(self.dtype)
    self.post_computation()

post_computation

post_computation()

Compute derived optical properties after loading or modifying lens.

Calculates and caches
  • Effective focal length (EFL)
  • Entrance and exit pupil positions and radii
  • Field of view (FoV) in horizontal, vertical, and diagonal directions
  • F-number
Note

This method should be called after any changes to the lens geometry.

Source code in deeplens/optics/geolens.py
def post_computation(self):
    """Compute derived optical properties after loading or modifying lens.

    Calculates and caches:
        - Effective focal length (EFL)
        - Entrance and exit pupil positions and radii
        - Field of view (FoV) in horizontal, vertical, and diagonal directions
        - F-number

    Note:
        This method should be called after any changes to the lens geometry.
    """
    self.calc_foclen()
    self.calc_pupil()
    self.calc_fov()

__call__

__call__(ray)

Trace rays through the lens system.

Makes the GeoLens callable, allowing ray tracing with function call syntax.

Source code in deeplens/optics/geolens.py
def __call__(self, ray):
    """Trace rays through the lens system.

    Makes the GeoLens callable, allowing ray tracing with function call syntax.
    """
    return self.trace(ray)

sample_grid_rays

sample_grid_rays(depth=float('inf'), num_grid=(11, 11), num_rays=SPP_PSF, wvln=DEFAULT_WAVE, uniform_fov=True, sample_more_off_axis=False, scale_pupil=1.0)

Sample grid rays from object space. (1) If depth is infinite, sample parallel rays at different field angles. (2) If depth is finite, sample point source rays from the object plane.

This function is usually used for (1) PSF map, (2) RMS error map, and (3) spot diagram calculation.

Parameters:

Name Type Description Default
depth float

sampling depth. Defaults to float("inf").

float('inf')
num_grid tuple

number of grid points. Defaults to [11, 11].

(11, 11)
num_rays int

number of rays. Defaults to SPP_PSF.

SPP_PSF
wvln float

ray wvln. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
uniform_fov bool

If True, sample uniform FoV angles.

True
sample_more_off_axis bool

If True, sample more off-axis rays.

False
scale_pupil float

Scale factor for pupil radius.

1.0

Returns:

Name Type Description
ray Ray object

Ray object. Shape [num_grid[1], num_grid[0], num_rays, 3]

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def sample_grid_rays(
    self,
    depth=float("inf"),
    num_grid=(11, 11),
    num_rays=SPP_PSF,
    wvln=DEFAULT_WAVE,
    uniform_fov=True,
    sample_more_off_axis=False,
    scale_pupil=1.0,
):
    """Sample grid rays from object space.
        (1) If depth is infinite, sample parallel rays at different field angles.
        (2) If depth is finite, sample point source rays from the object plane.

    This function is usually used for (1) PSF map, (2) RMS error map, and (3) spot diagram calculation.

    Args:
        depth (float, optional): sampling depth. Defaults to float("inf").
        num_grid (tuple, optional): number of grid points. Defaults to [11, 11].
        num_rays (int, optional): number of rays. Defaults to SPP_PSF.
        wvln (float, optional): ray wvln. Defaults to DEFAULT_WAVE.
        uniform_fov (bool, optional): If True, sample uniform FoV angles.
        sample_more_off_axis (bool, optional): If True, sample more off-axis rays.
        scale_pupil (float, optional): Scale factor for pupil radius.

    Returns:
        ray (Ray object): Ray object. Shape [num_grid[1], num_grid[0], num_rays, 3]
    """
    # Calculate field angles for grid source. Top-left field has positive fov_x and negative fov_y
    x_list = [x for x in np.linspace(1, -1, num_grid[0])]
    y_list = [y for y in np.linspace(-1, 1, num_grid[1])]
    if sample_more_off_axis:
        x_list = [np.sign(x) * np.abs(x) ** 0.5 for x in x_list]
        y_list = [np.sign(y) * np.abs(y) ** 0.5 for y in y_list]

    # Calculate FoV_x and FoV_y
    if uniform_fov:
        # Sample uniform FoV angles
        fov_x_list = [x * self.vfov / 2 for x in x_list]
        fov_y_list = [y * self.hfov / 2 for y in y_list]
        fov_x_list = [float(np.rad2deg(fov_x)) for fov_x in fov_x_list]
        fov_y_list = [float(np.rad2deg(fov_y)) for fov_y in fov_y_list]
    else:
        # Sample uniform object grid
        fov_x_list = [np.arctan(x * np.tan(self.vfov / 2)) for x in x_list]
        fov_y_list = [np.arctan(y * np.tan(self.hfov / 2)) for y in y_list]
        fov_x_list = [float(np.rad2deg(fov_x)) for fov_x in fov_x_list]
        fov_y_list = [float(np.rad2deg(fov_y)) for fov_y in fov_y_list]

    # Sample rays (parallel or point source)
    if depth == float("inf"):
        rays = self.sample_parallel(
            fov_x=fov_x_list,
            fov_y=fov_y_list,
            num_rays=num_rays,
            wvln=wvln,
            scale_pupil=scale_pupil,
        )
    else:
        rays = self.sample_point_source(
            fov_x=fov_x_list,
            fov_y=fov_y_list,
            num_rays=num_rays,
            wvln=wvln,
            depth=depth,
            scale_pupil=scale_pupil,
        )
    return rays
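
Example

A sketch of sampling a grid of parallel ray bundles and tracing them through a GeoLens instance; the grid and ray counts are example values.

# Parallel bundles over an 11 x 11 field grid (object at infinity).
rays = lens.sample_grid_rays(depth=float("inf"), num_grid=(11, 11), num_rays=256)

# GeoLens is callable, so tracing through all surfaces is just:
rays_out = lens(rays)   # equivalent to lens.trace(rays)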

sample_radial_rays

sample_radial_rays(num_field=5, depth=float('inf'), num_rays=SPP_PSF, wvln=DEFAULT_WAVE)

Sample radial (meridional, y direction) rays at different field angles.

This function is usually used for (1) PSF radial map, and (2) RMS error radial map calculation.

Parameters:

Name Type Description Default
num_field int

number of field angles. Defaults to 5.

5
depth float

sampling depth. Defaults to float("inf").

float('inf')
num_rays int

number of rays. Defaults to SPP_PSF.

SPP_PSF
wvln float

ray wvln. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE

Returns:

Name Type Description
ray Ray object

Ray object. Shape [num_field, num_rays, 3]

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def sample_radial_rays(
    self,
    num_field=5,
    depth=float("inf"),
    num_rays=SPP_PSF,
    wvln=DEFAULT_WAVE,
):
    """Sample radial (meridional, y direction) rays at different field angles.

    This function is usually used for (1) PSF radial map, and (2) RMS error radial map calculation.

    Args:
        num_field (int, optional): number of field angles. Defaults to 5.
        depth (float, optional): sampling depth. Defaults to float("inf").
        num_rays (int, optional): number of rays. Defaults to SPP_PSF.
        wvln (float, optional): ray wvln. Defaults to DEFAULT_WAVE.

    Returns:
        ray (Ray object): Ray object. Shape [num_field, num_rays, 3]
    """
    device = self.device
    fov_deg = float(np.rad2deg(self.rfov))
    fov_y_list = torch.linspace(0, fov_deg, num_field, device=device)

    if depth == float("inf"):
        ray = self.sample_parallel(
            fov_x=0.0, fov_y=fov_y_list, num_rays=num_rays, wvln=wvln
        )
    else:
        point_obj_x = torch.zeros(num_field, device=device)
        point_obj_y = depth * torch.tan(fov_y_list * torch.pi / 180.0)
        point_obj = torch.stack(
            [point_obj_x, point_obj_y, torch.full_like(point_obj_x, depth)], dim=-1
        )
        ray = self.sample_from_points(
            points=point_obj, num_rays=num_rays, wvln=wvln
        )
    return ray

sample_from_points

sample_from_points(points=[[0.0, 0.0, -10000.0]], num_rays=SPP_PSF, wvln=DEFAULT_WAVE, scale_pupil=1.0)

Sample rays from point sources in object space (absolute physical coordinates).

Used for PSF and chief ray calculation.

Parameters:

Name Type Description Default
points list or Tensor

Ray origins in shape [3], [N, 3], or [Nx, Ny, 3].

[[0.0, 0.0, -10000.0]]
num_rays int

Number of rays per point. Default: SPP_PSF.

SPP_PSF
wvln float

Wavelength of rays. Default: DEFAULT_WAVE.

DEFAULT_WAVE
scale_pupil float

Scale factor for pupil radius.

1.0

Returns:

Name Type Description
Ray

Sampled rays with shape (*points.shape[:-1], num_rays, 3).

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def sample_from_points(
    self,
    points=[[0.0, 0.0, -10000.0]],
    num_rays=SPP_PSF,
    wvln=DEFAULT_WAVE,
    scale_pupil=1.0,
):
    """
    Sample rays from point sources in object space (absolute physical coordinates).

    Used for PSF and chief ray calculation.

    Args:
        points (list or Tensor): Ray origins in shape [3], [N, 3], or [Nx, Ny, 3].
        num_rays (int): Number of rays per point. Default: SPP_PSF.
        wvln (float): Wavelength of rays. Default: DEFAULT_WAVE.
        scale_pupil (float): Scale factor for pupil radius.

    Returns:
        Ray: Sampled rays with shape ``(\\*points.shape[:-1], num_rays, 3)``.
    """
    # Ray origin is given
    ray_o = torch.tensor(points) if not torch.is_tensor(points) else points
    ray_o = ray_o.to(self.device)

    # Sample points on the pupil
    pupilz, pupilr = self.get_entrance_pupil()
    pupilr *= scale_pupil
    ray_o2 = self.sample_circle(
        r=pupilr, z=pupilz, shape=(*ray_o.shape[:-1], num_rays)
    )

    # Compute ray directions
    if len(ray_o.shape) == 1:
        # Input point shape is [3]
        ray_o = ray_o.unsqueeze(0).repeat(num_rays, 1)  # shape [num_rays, 3]
        ray_d = ray_o2 - ray_o

    elif len(ray_o.shape) == 2:
        # Input point shape is [N, 3]
        ray_o = ray_o.unsqueeze(1).repeat(1, num_rays, 1)  # shape [N, num_rays, 3]
        ray_d = ray_o2 - ray_o

    elif len(ray_o.shape) == 3:
        # Input point shape is [Nx, Ny, 3]
        ray_o = ray_o.unsqueeze(2).repeat(
            1, 1, num_rays, 1
        )  # shape [Nx, Ny, num_rays, 3]
        ray_d = ray_o2 - ray_o

    else:
        raise Exception("The shape of input object positions is not supported.")

    # Calculate rays
    rays = Ray(ray_o, ray_d, wvln, device=self.device)
    return rays

sample_parallel

sample_parallel(fov_x=[0.0], fov_y=[0.0], num_rays=SPP_CALC, wvln=DEFAULT_WAVE, entrance_pupil=True, depth=-1.0, scale_pupil=1.0)

Sample parallel rays in object space for geometric optics calculations.

Parameters:

Name Type Description Default
fov_x float or list

Field angle(s) in the xz plane (degrees). Default: [0.0].

[0.0]
fov_y float or list

Field angle(s) in the yz plane (degrees). Default: [0.0].

[0.0]
num_rays int

Number of rays per field point. Default: SPP_CALC.

SPP_CALC
wvln float

Wavelength of rays. Default: DEFAULT_WAVE.

DEFAULT_WAVE
entrance_pupil bool

If True, sample origins on entrance pupil; otherwise, on surface 0. Default: True.

True
depth float

Propagation depth in z. Default: -1.0.

-1.0
scale_pupil float

Scale factor for pupil radius. Default: 1.0.

1.0

Returns:

Name Type Description
Ray

Rays with shape [..., num_rays, 3], where the leading dims are:

  • both fov_x and fov_y scalars: [num_rays, 3]
  • fov_x scalar: [len(fov_y), num_rays, 3]
  • fov_y scalar: [len(fov_x), num_rays, 3]
  • both lists: [len(fov_y), len(fov_x), num_rays, 3], ordered as (u, v).

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def sample_parallel(
    self,
    fov_x=[0.0],
    fov_y=[0.0],
    num_rays=SPP_CALC,
    wvln=DEFAULT_WAVE,
    entrance_pupil=True,
    depth=-1.0,
    scale_pupil=1.0,
):
    """
    Sample parallel rays in object space for geometric optics calculations.

    Args:
        fov_x (float or list): Field angle(s) in the xz plane (degrees). Default: [0.0].
        fov_y (float or list): Field angle(s) in the yz plane (degrees). Default: [0.0].
        num_rays (int): Number of rays per field point. Default: SPP_CALC.
        wvln (float): Wavelength of rays. Default: DEFAULT_WAVE.
        entrance_pupil (bool): If True, sample origins on entrance pupil; otherwise, on surface 0. Default: True.
        depth (float): Propagation depth in z. Default: -1.0.
        scale_pupil (float): Scale factor for pupil radius. Default: 1.0.

    Returns:
        Ray:
            Rays with shape [..., num_rays, 3], where leading dims are:
            - both fov_x and fov_y scalars: [num_rays, 3]
            - fov_x scalar: [len(fov_y), num_rays, 3]
            - fov_y scalar: [len(fov_x), num_rays, 3]
            - both lists: [len(fov_y), len(fov_x), num_rays, 3]
            Ordered as (u, v).
    """
    # Remember whether inputs were scalar
    x_scalar = isinstance(fov_x, (float, int))
    y_scalar = isinstance(fov_y, (float, int))

    # Normalize to lists for internal processing
    if x_scalar:
        fov_x = [float(fov_x)]
    if y_scalar:
        fov_y = [float(fov_y)]

    fov_x = torch.tensor([fx * torch.pi / 180 for fx in fov_x]).to(self.device)
    fov_y = torch.tensor([fy * torch.pi / 180 for fy in fov_y]).to(self.device)

    # Sample ray origins on the pupil
    if entrance_pupil:
        pupilz, pupilr = self.get_entrance_pupil()
        pupilr *= scale_pupil
    else:
        pupilz, pupilr = 0.0, self.surfaces[0].r
        pupilr *= scale_pupil

    ray_o = self.sample_circle(
        r=pupilr, z=pupilz, shape=[len(fov_y), len(fov_x), num_rays]
    )

    # Sample ray directions
    fov_x_grid, fov_y_grid = torch.meshgrid(fov_x, fov_y, indexing="xy")
    dx = torch.tan(fov_x_grid).unsqueeze(-1).expand_as(ray_o[..., 0])
    dy = torch.tan(fov_y_grid).unsqueeze(-1).expand_as(ray_o[..., 1])
    dz = torch.ones_like(ray_o[..., 2])
    ray_d = torch.stack((dx, dy, dz), dim=-1)

    # Squeeze singleton FOV dims only if the original input was scalar
    if x_scalar:
        ray_o = ray_o.squeeze(1)
        ray_d = ray_d.squeeze(1)
    if y_scalar:
        ray_o = ray_o.squeeze(0)
        ray_d = ray_d.squeeze(0)

    rays = Ray(ray_o, ray_d, wvln, device=self.device)
    rays.prop_to(depth)
    return rays
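
A short sketch (assuming a constructed GeoLens named `lens`) that samples parallel bundles at three y-field angles; because fov_x is a scalar here, the returned rays have shape [len(fov_y), num_rays, 3]:

# Hypothetical usage: parallel rays at 0, 10 and 20 degrees in the yz plane.
rays = lens.sample_parallel(fov_x=0.0, fov_y=[0.0, 10.0, 20.0], num_rays=512)
rays = lens.trace2sensor(rays)      # shape [3, 512, 3] on the sensor plane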

sample_point_source

sample_point_source(fov_x=[0.0], fov_y=[0.0], depth=DEPTH, num_rays=SPP_PSF, wvln=DEFAULT_WAVE, entrance_pupil=True, scale_pupil=1.0)

Sample point source rays from object space with given field angles.

Used for (1) spot/rms/magnification calculation, (2) distortion/sensor sampling.

This function is equivalent to self.point_source_grid() + self.sample_from_points().

Parameters:

Name Type Description Default
fov_x float or list

field angle in x0z plane.

[0.0]
fov_y float or list

field angle in y0z plane.

[0.0]
depth float

sample plane z position. Defaults to DEPTH.

DEPTH
num_rays int

number of rays sampled from each grid point. Defaults to SPP_PSF.

SPP_PSF
entrance_pupil bool

whether to use entrance pupil. Defaults to True.

True
wvln float

ray wvln. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
scale_pupil float

Scale factor for pupil radius. Defaults to 1.0.

1.0

Returns:

Name Type Description
ray Ray object

Ray object. Shape [len(fov_y), len(fov_x), num_rays, 3], arranged in uv order.

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def sample_point_source(
    self,
    fov_x=[0.0],
    fov_y=[0.0],
    depth=DEPTH,
    num_rays=SPP_PSF,
    wvln=DEFAULT_WAVE,
    entrance_pupil=True,
    scale_pupil=1.0,
):
    """Sample point source rays from object space with given field angles.

    Used for (1) spot/rms/magnification calculation, (2) distortion/sensor sampling.

    This function is equivalent to self.point_source_grid() + self.sample_from_points().

    Args:
        fov_x (float or list): field angle in x0z plane.
        fov_y (float or list): field angle in y0z plane.
        depth (float, optional): sample plane z position. Defaults to DEPTH.
        num_rays (int, optional): number of rays sampled from each grid point. Defaults to SPP_PSF.
        entrance_pupil (bool, optional): whether to use entrance pupil. Defaults to True.
        wvln (float, optional): ray wvln. Defaults to DEFAULT_WAVE.
        scale_pupil (float, optional): scale factor for pupil radius. Defaults to 1.0.

    Returns:
        ray (Ray object): Ray object. Shape [len(fov_y), len(fov_x), num_rays, 3], arranged in uv order.
    """
    # Sample second points on the pupil, shape [len(fov_y), len(fov_x), num_rays, 3]
    if entrance_pupil:
        pupilz, pupilr = self.get_entrance_pupil()
        pupilr *= scale_pupil
    else:
        pupilz, pupilr = 0, self.surfaces[0].r

    # Sample grid points with given field angles, shape [len(fov_y), len(fov_x), 3]
    fov_x = torch.tensor([fx * torch.pi / 180 for fx in fov_x]).to(self.device)
    fov_y = torch.tensor([fy * torch.pi / 180 for fy in fov_y]).to(self.device)
    fov_x_grid, fov_y_grid = torch.meshgrid(fov_x, fov_y, indexing="xy")
    x, y = torch.tan(fov_x_grid) * depth, torch.tan(fov_y_grid) * depth

    # Form ray origins, shape [len(fov_y), len(fov_x), num_rays, 3]
    z = torch.full_like(x, depth)
    ray_o = torch.stack((x, y, z), -1)
    ray_o = ray_o.unsqueeze(2).repeat(1, 1, num_rays, 1)

    ray_o2 = self.sample_circle(
        r=pupilr, z=pupilz, shape=(len(fov_y), len(fov_x), num_rays)
    )

    # Compute ray directions
    ray_d = ray_o2 - ray_o

    ray = Ray(ray_o, ray_d, wvln, device=self.device)
    return ray
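
A minimal sketch (`lens` is an assumed GeoLens instance) sampling a 3x3 grid of field angles at a finite object depth; the result has shape [len(fov_y), len(fov_x), num_rays, 3]:

# Hypothetical usage: point sources on a 3x3 field grid, 1 m in front of the lens.
fovs = [-10.0, 0.0, 10.0]
ray = lens.sample_point_source(fov_x=fovs, fov_y=fovs, depth=-1000.0, num_rays=256)
ray = lens.trace2sensor(ray)        # shape [3, 3, 256, 3]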

sample_sensor

sample_sensor(spp=64, wvln=DEFAULT_WAVE, sub_pixel=False)

Sample rays from sensor pixels (backward rays). Used for ray tracing rendering.

Parameters:

Name Type Description Default
spp int

sample per pixel. Defaults to 64.

64
wvln float

ray wvln. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
sub_pixel bool

whether to sample multiple points inside the pixel. Defaults to False.

False

Returns:

Name Type Description
ray Ray object

Ray object. Shape [H, W, spp, 3]

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def sample_sensor(self, spp=64, wvln=DEFAULT_WAVE, sub_pixel=False):
    """Sample rays from sensor pixels (backward rays). Used for ray tracing rendering.

    Args:
        spp (int, optional): sample per pixel. Defaults to 64.
        wvln (float, optional): ray wvln. Defaults to DEFAULT_WAVE.
        sub_pixel (bool, optional): whether to sample multiple points inside the pixel. Defaults to False.

    Returns:
        ray (Ray object): Ray object. Shape [H, W, spp, 3]
    """
    w, h = self.sensor_size
    W, H = self.sensor_res
    device = self.device

    # Sample points on sensor plane
    # Use top-left point as reference in rendering, so here we should sample bottom-right point
    x1, y1 = torch.meshgrid(
        torch.linspace(
            -w / 2,
            w / 2,
            W + 1,
            device=device,
        )[1:],
        torch.linspace(
            h / 2,
            -h / 2,
            H + 1,
            device=device,
        )[1:],
        indexing="xy",
    )
    z1 = torch.full_like(x1, self.d_sensor.item())

    # Sample second points on the pupil
    pupilz, pupilr = self.get_exit_pupil()
    ray_o2 = self.sample_circle(r=pupilr, z=pupilz, shape=(*self.sensor_res, spp))

    # Form rays
    ray_o = torch.stack((x1, y1, z1), 2)
    ray_o = ray_o.unsqueeze(2).repeat(1, 1, spp, 1)  # [H, W, spp, 3]

    # Sub-pixel sampling for more realistic rendering
    if sub_pixel:
        delta_ox = (
            torch.rand((ray_o[:, :, :, 0].clone().shape), device=device)
            * self.pixel_size
        )
        delta_oy = (
            -torch.rand((ray_o[:, :, :, 1].clone().shape), device=device)
            * self.pixel_size
        )
        delta_oz = torch.zeros_like(delta_ox)
        delta_o = torch.stack((delta_ox, delta_oy, delta_oz), -1)
        ray_o = ray_o + delta_o

    # Form rays
    ray_d = ray_o2 - ray_o  # shape [H, W, spp, 3]
    ray = Ray(ray_o, ray_d, wvln, device=device)
    return ray
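
A minimal sketch (`lens` is an assumed GeoLens instance) of backward ray sampling for rendering:

# Hypothetical usage: 32 backward rays per sensor pixel, jittered inside each pixel.
ray = lens.sample_sensor(spp=32, sub_pixel=True)   # shape [H, W, 32, 3]
ray = lens.trace2obj(ray)                          # trace back into object space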

sample_circle

sample_circle(r, z, shape=[16, 16, 512])

Sample points inside a circle.

Parameters:

Name Type Description Default
r float

Radius of the circle.

required
z float

Z-coordinate for all sampled points.

required
shape list

Shape of the output tensor.

[16, 16, 512]

Returns:

Type Description

torch.Tensor: Sampled points, shape (*shape, 3).

Source code in deeplens/optics/geolens.py
def sample_circle(self, r, z, shape=[16, 16, 512]):
    """Sample points inside a circle.

    Args:
        r (float): Radius of the circle.
        z (float): Z-coordinate for all sampled points.
        shape (list): Shape of the output tensor.

    Returns:
        torch.Tensor: Sampled points, shape ``(\\*shape, 3)``.
    """
    device = self.device

    # Generate random angles and radii
    theta = torch.rand(*shape, device=device) * 2 * torch.pi
    r2 = torch.rand(*shape, device=device) * r**2
    radius = torch.sqrt(r2)

    # Stack to form 3D points
    x = radius * torch.cos(theta)
    y = radius * torch.sin(theta)
    z_tensor = torch.full_like(x, z)
    points = torch.stack((x, y, z_tensor), dim=-1)

    # Manually sample chief ray
    # points[..., 0, :2] = 0.0

    return points

trace

trace(ray, surf_range=None, record=False)

Trace rays through the lens.

Forward or backward tracing is automatically determined by the ray direction.

Parameters:

Name Type Description Default
ray Ray object

Ray object.

required
surf_range list

Surface index range.

None
record bool

record ray path or not.

False

Returns:

Name Type Description
ray_final Ray object

ray after optical system.

ray_o_rec list

list of intersection points.

Source code in deeplens/optics/geolens.py
def trace(self, ray, surf_range=None, record=False):
    """Trace rays through the lens.

    Forward or backward tracing is automatically determined by the ray direction.

    Args:
        ray (Ray object): Ray object.
        surf_range (list): Surface index range.
        record (bool): record ray path or not.

    Returns:
        ray_final (Ray object): ray after optical system.
        ray_o_rec (list): list of intersection points.
    """
    if surf_range is None:
        surf_range = range(0, len(self.surfaces))

    if (ray.d[..., 2].unsqueeze(-1) > 0).any():
        ray_out, ray_o_rec = self.forward_tracing(ray, surf_range, record=record)
    else:
        ray_out, ray_o_rec = self.backward_tracing(ray, surf_range, record=record)

    return ray_out, ray_o_rec
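
A minimal sketch (`lens` and `ray` are assumed to be an existing GeoLens and a previously sampled Ray bundle) showing how recording exposes the per-surface intersection points:

# Hypothetical usage: trace with recording enabled to inspect the ray path.
ray_out, ray_o_rec = lens.trace(ray, record=True)
# ray_o_rec is a list with one [..., 3] tensor of intersection points per surface.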

trace2obj

trace2obj(ray)

Traces rays backwards through all lens surfaces from sensor side to object side.

Parameters:

Name Type Description Default
ray Ray

Ray object to trace backwards.

required

Returns:

Name Type Description
Ray

Ray object after backward propagation through the lens.

Source code in deeplens/optics/geolens.py
def trace2obj(self, ray):
    """Traces rays backwards through all lens surfaces from sensor side
    to object side.

    Args:
        ray (Ray): Ray object to trace backwards.

    Returns:
        Ray: Ray object after backward propagation through the lens.
    """
    ray, _ = self.trace(ray)
    return ray

trace2sensor

trace2sensor(ray, record=False)

Forward trace rays through the lens to sensor plane.

Parameters:

Name Type Description Default
ray Ray object

Ray object.

required
record bool

record ray path or not.

False

Returns:

Name Type Description
ray_out Ray object

ray after optical system.

ray_o_record list

list of intersection points.

Source code in deeplens/optics/geolens.py
def trace2sensor(self, ray, record=False):
    """Forward trace rays through the lens to sensor plane.

    Args:
        ray (Ray object): Ray object.
        record (bool): record ray path or not.

    Returns:
        ray_out (Ray object): ray after optical system.
        ray_o_record (list): list of intersection points.
    """
    # Manually propagate ray to a shallow depth to avoid numerical instability
    if (ray.o[..., 2].min() < -100.0).any():
        ray = ray.prop_to(-10.0)

    # Trace rays
    ray, ray_o_record = self.trace(ray, record=record)
    ray = ray.prop_to(self.d_sensor)

    if record:
        ray_o = ray.o.clone().detach()
        # Set to NaN to be skipped in 2d layout visualization
        ray_o[ray.is_valid == 0] = float("nan")
        ray_o_record.append(ray_o)
        return ray, ray_o_record
    else:
        return ray

trace2exit_pupil

trace2exit_pupil(ray)

Forward trace rays through the lens to exit pupil plane.

Parameters:

Name Type Description Default
ray Ray

Ray object to trace.

required

Returns:

Name Type Description
Ray

Ray object propagated to the exit pupil plane.

Source code in deeplens/optics/geolens.py
def trace2exit_pupil(self, ray):
    """Forward trace rays through the lens to exit pupil plane.

    Args:
        ray (Ray): Ray object to trace.

    Returns:
        Ray: Ray object propagated to the exit pupil plane.
    """
    ray = self.trace2sensor(ray)
    pupil_z, _ = self.get_exit_pupil()
    ray = ray.prop_to(pupil_z)
    return ray

forward_tracing

forward_tracing(ray, surf_range, record)

Forward traces rays through each surface in the specified range from object side to image side.

Parameters:

Name Type Description Default
ray Ray

Ray object to trace.

required
surf_range range

Range of surface indices to trace through.

required
record bool

If True, record ray positions at each surface.

required

Returns:

Name Type Description
tuple

(ray_out, ray_o_record) where:
- ray_out (Ray): Ray after propagation through all surfaces.
- ray_o_record (list or None): List of ray positions at each surface, or None if record is False.

Source code in deeplens/optics/geolens.py
def forward_tracing(self, ray, surf_range, record):
    """Forward traces rays through each surface in the specified range from object side to image side.

    Args:
        ray (Ray): Ray object to trace.
        surf_range (range): Range of surface indices to trace through.
        record (bool): If True, record ray positions at each surface.

    Returns:
        tuple: (ray_out, ray_o_record) where:
            - ray_out (Ray): Ray after propagation through all surfaces.
            - ray_o_record (list or None): List of ray positions at each surface,
                or None if record is False.
    """
    if record:
        ray_o_record = []
        ray_o_record.append(ray.o.clone().detach())
    else:
        ray_o_record = None

    mat1 = Material("air")
    for i in surf_range:
        n1 = mat1.ior(ray.wvln)
        n2 = self.surfaces[i].mat2.ior(ray.wvln)
        ray = self.surfaces[i].ray_reaction(ray, n1, n2)
        mat1 = self.surfaces[i].mat2

        if record:
            ray_out_o = ray.o.clone().detach()
            ray_out_o[ray.is_valid == 0] = float("nan")
            ray_o_record.append(ray_out_o)

    return ray, ray_o_record

backward_tracing

backward_tracing(ray, surf_range, record)

Backward traces rays through each surface in reverse order from image side to object side.

Parameters:

Name Type Description Default
ray Ray

Ray object to trace.

required
surf_range range

Range of surface indices to trace through.

required
record bool

If True, record ray positions at each surface.

required

Returns:

Name Type Description
tuple

(ray_out, ray_o_record) where:
- ray_out (Ray): Ray after backward propagation through all surfaces.
- ray_o_record (list or None): List of ray positions at each surface, or None if record is False.

Source code in deeplens/optics/geolens.py
def backward_tracing(self, ray, surf_range, record):
    """Backward traces rays through each surface in reverse order from image side to object side.

    Args:
        ray (Ray): Ray object to trace.
        surf_range (range): Range of surface indices to trace through.
        record (bool): If True, record ray positions at each surface.

    Returns:
        tuple: (ray_out, ray_o_record) where:
            - ray_out (Ray): Ray after backward propagation through all surfaces.
            - ray_o_record (list or None): List of ray positions at each surface,
                or None if record is False.
    """
    if record:
        ray_o_record = []
        ray_o_record.append(ray.o.clone().detach())
    else:
        ray_o_record = None

    mat1 = Material("air")
    for i in np.flip(surf_range):
        n1 = mat1.ior(ray.wvln)
        n2 = self.surfaces[i - 1].mat2.ior(ray.wvln)
        ray = self.surfaces[i].ray_reaction(ray, n1, n2)
        mat1 = self.surfaces[i - 1].mat2

        if record:
            ray_out_o = ray.o.clone().detach()
            ray_out_o[ray.is_valid == 0] = float("nan")
            ray_o_record.append(ray_out_o)

    return ray, ray_o_record

render

render(img_obj, depth=DEPTH, method='ray_tracing', **kwargs)

Differentiable image simulation.

Image simulation methods

[1] PSF map block convolution.
[2] PSF patch convolution.
[3] Ray tracing rendering.

Parameters:

Name Type Description Default
img_obj Tensor

Input image object in raw space. Shape of [N, C, H, W].

required
depth float

Depth of the object. Defaults to DEPTH.

DEPTH
method str

Image simulation method. One of 'psf_map', 'psf_patch', or 'ray_tracing'. Defaults to 'ray_tracing'.

'ray_tracing'
**kwargs

Additional arguments for different methods:
- psf_grid (tuple): Grid size for PSF map method. Defaults to (10, 10).
- psf_ks (int): Kernel size for PSF methods. Defaults to PSF_KS.
- patch_center (tuple): Center position for PSF patch method.
- spp (int): Samples per pixel for ray tracing. Defaults to SPP_RENDER.

{}

Returns:

Name Type Description
Tensor

Rendered image tensor. Shape of [N, C, H, W].

Source code in deeplens/optics/geolens.py
def render(self, img_obj, depth=DEPTH, method="ray_tracing", **kwargs):
    """Differentiable image simulation.

    Image simulation methods:
        [1] PSF map block convolution.
        [2] PSF patch convolution.
        [3] Ray tracing rendering.

    Args:
        img_obj (Tensor): Input image object in raw space. Shape of [N, C, H, W].
        depth (float, optional): Depth of the object. Defaults to DEPTH.
        method (str, optional): Image simulation method. One of 'psf_map', 'psf_patch',
            or 'ray_tracing'. Defaults to 'ray_tracing'.
        **kwargs: Additional arguments for different methods:
            - psf_grid (tuple): Grid size for PSF map method. Defaults to (10, 10).
            - psf_ks (int): Kernel size for PSF methods. Defaults to PSF_KS.
            - patch_center (tuple): Center position for PSF patch method.
            - spp (int): Samples per pixel for ray tracing. Defaults to SPP_RENDER.

    Returns:
        Tensor: Rendered image tensor. Shape of [N, C, H, W].
    """
    B, C, Himg, Wimg = img_obj.shape
    Wsensor, Hsensor = self.sensor_res

    # Image simulation
    if method == "psf_map":
        # PSF rendering - uses PSF map to render image
        assert Wimg == Wsensor and Himg == Hsensor, (
            f"Sensor resolution {Wsensor}x{Hsensor} must match input image {Wimg}x{Himg}."
        )
        psf_grid = kwargs.get("psf_grid", (10, 10))
        psf_ks = kwargs.get("psf_ks", PSF_KS)
        img_render = self.render_psf_map(
            img_obj, depth=depth, psf_grid=psf_grid, psf_ks=psf_ks
        )

    elif method == "psf_patch":
        # PSF patch rendering - uses a single PSF to render a patch of the image
        patch_center = kwargs.get("patch_center", (0.0, 0.0))
        psf_ks = kwargs.get("psf_ks", PSF_KS)
        img_render = self.render_psf_patch(
            img_obj, depth=depth, patch_center=patch_center, psf_ks=psf_ks
        )

    elif method == "ray_tracing":
        # Ray tracing rendering
        assert Wimg == Wsensor and Himg == Hsensor, (
            f"Sensor resolution {Wsensor}x{Hsensor} must match input image {Wimg}x{Himg}."
        )
        spp = kwargs.get("spp", SPP_RENDER)
        img_render = self.render_raytracing(img_obj, depth=depth, spp=spp)

    else:
        raise Exception(f"Image simulation method {method} is not supported.")

    return img_render
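
A minimal sketch (`lens` is an assumed GeoLens instance) rendering a random test image with the PSF-map method; the input resolution must match the sensor resolution, which the render source unpacks as (W, H):

import torch

# Hypothetical usage: render a [N, C, H, W] image with an 8x8 PSF grid.
W, H = lens.sensor_res                  # render() unpacks sensor_res as (W, H)
img_obj = torch.rand(1, 3, H, W, device=lens.device)
img_sim = lens.render(img_obj, depth=-2000.0, method="psf_map", psf_grid=(8, 8), psf_ks=51)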

render_raytracing

render_raytracing(img, depth=DEPTH, spp=SPP_RENDER, vignetting=False)

Render RGB image using ray tracing rendering.

Parameters:

Name Type Description Default
img tensor

RGB image tensor. Shape of [N, 3, H, W].

required
depth float

Depth of the object. Defaults to DEPTH.

DEPTH
spp int

Sample per pixel. Defaults to SPP_RENDER.

SPP_RENDER
vignetting bool

whether to consider vignetting effect. Defaults to False.

False

Returns:

Name Type Description
img_render tensor

Rendered RGB image tensor. Shape of [N, 3, H, W].

Source code in deeplens/optics/geolens.py
def render_raytracing(self, img, depth=DEPTH, spp=SPP_RENDER, vignetting=False):
    """Render RGB image using ray tracing rendering.

    Args:
        img (tensor): RGB image tensor. Shape of [N, 3, H, W].
        depth (float, optional): Depth of the object. Defaults to DEPTH.
        spp (int, optional): Sample per pixel. Defaults to SPP_RENDER.
        vignetting (bool, optional): whether to consider vignetting effect. Defaults to False.

    Returns:
        img_render (tensor): Rendered RGB image tensor. Shape of [N, 3, H, W].
    """
    img_render = torch.zeros_like(img)
    for i in range(3):
        img_render[:, i, :, :] = self.render_raytracing_mono(
            img=img[:, i, :, :],
            wvln=WAVE_RGB[i],
            depth=depth,
            spp=spp,
            vignetting=vignetting,
        )
    return img_render

render_raytracing_mono

render_raytracing_mono(img, wvln, depth=DEPTH, spp=64, vignetting=False)

Render monochrome image using ray tracing rendering.

Parameters:

Name Type Description Default
img tensor

Monochrome image tensor. Shape of [N, 1, H, W] or [N, H, W].

required
wvln float

Wavelength of the light.

required
depth float

Depth of the object. Defaults to DEPTH.

DEPTH
spp int

Sample per pixel. Defaults to 64.

64

Returns:

Name Type Description
img_mono tensor

Rendered monochrome image tensor. Shape of [N, 1, H, W] or [N, H, W].

Source code in deeplens/optics/geolens.py
def render_raytracing_mono(self, img, wvln, depth=DEPTH, spp=64, vignetting=False):
    """Render monochrome image using ray tracing rendering.

    Args:
        img (tensor): Monochrome image tensor. Shape of [N, 1, H, W] or [N, H, W].
        wvln (float): Wavelength of the light.
        depth (float, optional): Depth of the object. Defaults to DEPTH.
        spp (int, optional): Sample per pixel. Defaults to 64.

    Returns:
        img_mono (tensor): Rendered monochrome image tensor. Shape of [N, 1, H, W] or [N, H, W].
    """
    img = torch.flip(img, [-2, -1])
    scale = self.calc_scale(depth=depth)
    ray = self.sample_sensor(spp=spp, wvln=wvln)
    ray = self.trace2obj(ray)
    img_mono = self.render_compute_image(
        img, depth=depth, scale=scale, ray=ray, vignetting=vignetting
    )
    return img_mono

render_compute_image

render_compute_image(img, depth, scale, ray, vignetting=False)

Computes the intersection points between rays and the object image plane, then generates the rendered image following the rendering equation.

Back-propagation gradient flow: image -> w_i -> u -> p -> ray -> surface

Parameters:

Name Type Description Default
img tensor

[N, C, H, W] or [N, H, W] shape image tensor.

required
depth float

depth of the object.

required
scale float

scale factor.

required
ray Ray object

Ray object. Shape [H, W, spp, 3].

required
vignetting bool

whether to consider vignetting effect.

False

Returns:

Name Type Description
image tensor

[N, C, H, W] or [N, H, W] shape rendered image tensor.

Source code in deeplens/optics/geolens.py
def render_compute_image(self, img, depth, scale, ray, vignetting=False):
    """Computes the intersection points between rays and the object image plane, then generates the rendered image following rendering equation.

    Back-propagation gradient flow: image -> w_i -> u -> p -> ray -> surface

    Args:
        img (tensor): [N, C, H, W] or [N, H, W] shape image tensor.
        depth (float): depth of the object.
        scale (float): scale factor.
        ray (Ray object): Ray object. Shape [H, W, spp, 3].
        vignetting (bool): whether to consider vignetting effect.

    Returns:
        image (tensor): [N, C, H, W] or [N, H, W] shape rendered image tensor.
    """
    assert torch.is_tensor(img), "Input image should be Tensor."

    # Padding
    H, W = img.shape[-2:]
    if len(img.shape) == 3:
        img = F.pad(img.unsqueeze(1), (1, 1, 1, 1), "replicate").squeeze(1)
    elif len(img.shape) == 4:
        img = F.pad(img, (1, 1, 1, 1), "replicate")
    else:
        raise ValueError("Input image should be [N, C, H, W] or [N, H, W] tensor.")

    # Scale object image physical size to get 1:1 pixel-pixel alignment with sensor image
    ray = ray.prop_to(depth)
    p = ray.o[..., :2]
    pixel_size = scale * self.pixel_size
    ray.is_valid = (
        ray.is_valid
        * (torch.abs(p[..., 0] / pixel_size) < (W / 2 + 1))
        * (torch.abs(p[..., 1] / pixel_size) < (H / 2 + 1))
    )

    # Convert to uv coordinates in the object image frame
    # (the image was padded, so coordinates are shifted by 1)
    u = torch.clamp(W / 2 + p[..., 0] / pixel_size, min=-0.99, max=W - 0.01)
    v = torch.clamp(H / 2 + p[..., 1] / pixel_size, min=0.01, max=H + 0.99)

    # (idx_i, idx_j) denotes left-top pixel (reference pixel). Index does not store gradients
    # (idx + 1 because we did padding)
    idx_i = H - v.ceil().long() + 1
    idx_j = u.floor().long() + 1

    # Gradients are stored in interpolation weight parameters
    w_i = v - v.floor().long()
    w_j = u.ceil().long() - u

    # Bilinear interpolation
    # (img shape [B, N, H', W'], idx_i shape [H, W, spp], w_i shape [H, W, spp], irr_img shape [N, C, H, W, spp])
    irr_img = img[..., idx_i, idx_j] * w_i * w_j
    irr_img += img[..., idx_i + 1, idx_j] * (1 - w_i) * w_j
    irr_img += img[..., idx_i, idx_j + 1] * w_i * (1 - w_j)
    irr_img += img[..., idx_i + 1, idx_j + 1] * (1 - w_i) * (1 - w_j)

    # Computation image
    if not vignetting:
        image = torch.sum(irr_img * ray.is_valid, -1) / (
            torch.sum(ray.is_valid, -1) + EPSILON
        )
    else:
        image = torch.sum(irr_img * ray.is_valid, -1) / torch.numel(ray.is_valid)

    return image

unwarp

unwarp(img, depth=DEPTH, num_grid=128, crop=True, flip=True)

Unwarp rendered images using distortion map.

Parameters:

Name Type Description Default
img tensor

Rendered image tensor. Shape of [N, C, H, W].

required
depth float

Depth of the object. Defaults to DEPTH.

DEPTH
num_grid int

Number of grid points for the distortion map. Defaults to 128.

128
crop bool

Whether to crop the image. Defaults to True.

True

Returns:

Name Type Description
img_unwarpped tensor

Unwarped image tensor. Shape of [N, C, H, W].

Source code in deeplens/optics/geolens.py
def unwarp(self, img, depth=DEPTH, num_grid=128, crop=True, flip=True):
    """Unwarp rendered images using distortion map.

    Args:
        img (tensor): Rendered image tensor. Shape of [N, C, H, W].
        depth (float, optional): Depth of the object. Defaults to DEPTH.
        num_grid (int, optional): Number of grid points for the distortion map. Defaults to 128.
        crop (bool, optional): Whether to crop the image. Defaults to True.

    Returns:
        img_unwarpped (tensor): Unwarped image tensor. Shape of [N, C, H, W].
    """
    # Calculate distortion map, shape (num_grid, num_grid, 2)
    distortion_map = self.distortion_map(depth=depth, num_grid=num_grid)

    # Interpolate distortion map to image resolution
    distortion_map = distortion_map.permute(2, 0, 1).unsqueeze(1)
    # distortion_map = torch.flip(distortion_map, [-2]) if flip else distortion_map
    distortion_map = F.interpolate(
        distortion_map, img.shape[-2:], mode="bilinear", align_corners=True
    )  # shape (B, 2, Himg, Wimg)
    distortion_map = distortion_map.permute(1, 2, 3, 0).repeat(
        img.shape[0], 1, 1, 1
    )  # shape (B, Himg, Wimg, 2)

    # Unwarp using grid_sample function
    img_unwarpped = F.grid_sample(
        img, distortion_map, align_corners=True
    )  # shape (B, C, Himg, Wimg)
    return img_unwarpped
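
A minimal sketch (`lens` is an assumed GeoLens instance and `img_render` an assumed rendered [N, C, H, W] tensor) correcting geometric distortion after rendering:

# Hypothetical usage: undo lens distortion with a 128x128 distortion map.
img_corrected = lens.unwarp(img_render, depth=-2000.0, num_grid=128)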

analysis_rendering

analysis_rendering(img_org, save_name=None, depth=DEPTH, spp=SPP_RENDER, unwarp=False, noise=0.0, method='ray_tracing', show=False)

Render a single image for visualization and analysis.

Parameters:

Name Type Description Default
img_org Tensor

Original image with shape [H, W, 3].

required
save_name str

Path prefix for saving rendered images. Defaults to None.

None
depth float

Depth of object image. Defaults to DEPTH.

DEPTH
spp int

Sample per pixel. Defaults to SPP_RENDER.

SPP_RENDER
unwarp bool

If True, unwarp the image to correct distortion. Defaults to False.

False
noise float

Gaussian noise standard deviation. Defaults to 0.0.

0.0
method str

Rendering method ('ray_tracing', etc.). Defaults to 'ray_tracing'.

'ray_tracing'
show bool

If True, display the rendered image. Defaults to False.

False

Returns:

Name Type Description
Tensor

Rendered image tensor with shape [1, 3, H, W].

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def analysis_rendering(
    self,
    img_org,
    save_name=None,
    depth=DEPTH,
    spp=SPP_RENDER,
    unwarp=False,
    noise=0.0,
    method="ray_tracing",
    show=False,
):
    """Render a single image for visualization and analysis.

    Args:
        img_org (Tensor): Original image with shape [H, W, 3].
        save_name (str, optional): Path prefix for saving rendered images. Defaults to None.
        depth (float, optional): Depth of object image. Defaults to DEPTH.
        spp (int, optional): Sample per pixel. Defaults to SPP_RENDER.
        unwarp (bool, optional): If True, unwarp the image to correct distortion. Defaults to False.
        noise (float, optional): Gaussian noise standard deviation. Defaults to 0.0.
        method (str, optional): Rendering method ('ray_tracing', etc.). Defaults to 'ray_tracing'.
        show (bool, optional): If True, display the rendered image. Defaults to False.

    Returns:
        Tensor: Rendered image tensor with shape [1, 3, H, W].
    """
    # Change sensor resolution to match the image
    sensor_res_original = self.sensor_res
    if isinstance(img_org, np.ndarray):
        img = torch.from_numpy(img_org).permute(2, 0, 1).unsqueeze(0).float() / 255.0
    elif torch.is_tensor(img_org):
        img = img_org.permute(2, 0, 1).unsqueeze(0).float()
        if img.max() > 1.0:
            img = img / 255.0
    img = img.to(self.device)
    self.set_sensor_res(sensor_res=img.shape[-2:])

    # Image rendering
    img_render = self.render(img, depth=depth, method=method, spp=spp)

    # Add noise (a very simple Gaussian noise model)
    if noise > 0:
        img_render = img_render + torch.randn_like(img_render) * noise
        img_render = torch.clamp(img_render, 0, 1)

    # Compute PSNR and SSIM
    img_np = img.squeeze(0).permute(1, 2, 0).cpu().numpy()
    render_np = img_render.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().detach().numpy()
    render_psnr = round(peak_signal_noise_ratio(img_np, render_np, data_range=1.0), 3)
    render_ssim = round(structural_similarity(img_np, render_np, channel_axis=2, data_range=1.0), 4)
    print(f"Rendered image: PSNR={render_psnr:.3f}, SSIM={render_ssim:.4f}")

    # Save image
    if save_name is not None:
        save_image(img_render, f"{save_name}.png")

    # Unwarp to correct geometry distortion
    if unwarp:
        img_render = self.unwarp(img_render, depth)

        # Compute PSNR and SSIM
        render_np = img_render.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().detach().numpy()
        render_psnr = round(peak_signal_noise_ratio(img_np, render_np, data_range=1.0), 3)
        render_ssim = round(structural_similarity(img_np, render_np, channel_axis=2, data_range=1.0), 4)
        print(
            f"Rendered image (unwarped): PSNR={render_psnr:.3f}, SSIM={render_ssim:.4f}"
        )

        if save_name is not None:
            save_image(img_render, f"{save_name}_unwarped.png")

    # Change the sensor resolution back
    self.set_sensor_res(sensor_res=sensor_res_original)

    # Show image
    if show:
        plt.imshow(img_render.cpu().squeeze(0).permute(1, 2, 0).numpy())
        plt.title("Rendered image")
        plt.axis("off")
        plt.show()
        plt.close()

    return img_render
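
A minimal sketch (`lens` is an assumed GeoLens instance) that renders a random 512x512 test image, reports PSNR/SSIM, and saves both the raw and unwarped results under an assumed path prefix:

import numpy as np

# Hypothetical usage: [H, W, 3] uint8 input, saved as ./render_test.png and ./render_test_unwarped.png.
img_org = (np.random.rand(512, 512, 3) * 255).astype(np.uint8)
img_render = lens.analysis_rendering(img_org, depth=-2000.0, spp=32, unwarp=True, save_name="./render_test")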

psf

psf(points, ks=PSF_KS, wvln=DEFAULT_WAVE, spp=None, recenter=True, model='geometric')

Calculate Point Spread Function (PSF) for given point sources.

Supports multiple PSF calculation models
  • geometric: Incoherent intensity ray tracing (fast, differentiable)
  • coherent: Coherent ray tracing with free-space propagation (accurate, differentiable)
  • huygens: Huygens-Fresnel integration (accurate, not differentiable)

Parameters:

Name Type Description Default
points Tensor

Point source positions. Shape [N, 3] with x, y in [-1, 1] and z in [-Inf, 0]. Normalized coordinates.

required
ks int

Output kernel size in pixels. Defaults to PSF_KS.

PSF_KS
wvln float

Wavelength in [um]. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
spp int

Samples per pixel. If None, uses model-specific default.

None
recenter bool

If True, center PSF using chief ray. Defaults to True.

True
model str

PSF model type. One of 'geometric', 'coherent', 'huygens'. Defaults to 'geometric'.

'geometric'

Returns:

Name Type Description
Tensor

PSF normalized to sum to 1. Shape [ks, ks] or [N, ks, ks].

Source code in deeplens/optics/geolens.py
def psf(
    self,
    points,
    ks=PSF_KS,
    wvln=DEFAULT_WAVE,
    spp=None,
    recenter=True,
    model="geometric",
):
    """Calculate Point Spread Function (PSF) for given point sources.

    Supports multiple PSF calculation models:
        - geometric: Incoherent intensity ray tracing (fast, differentiable)
        - coherent: Coherent ray tracing with free-space propagation (accurate, differentiable)
        - huygens: Huygens-Fresnel integration (accurate, not differentiable)

    Args:
        points (Tensor): Point source positions. Shape [N, 3] with x, y in [-1, 1]
            and z in [-Inf, 0]. Normalized coordinates.
        ks (int, optional): Output kernel size in pixels. Defaults to PSF_KS.
        wvln (float, optional): Wavelength in [um]. Defaults to DEFAULT_WAVE.
        spp (int, optional): Samples per pixel. If None, uses model-specific default.
        recenter (bool, optional): If True, center PSF using chief ray. Defaults to True.
        model (str, optional): PSF model type. One of 'geometric', 'coherent', 'huygens'.
            Defaults to 'geometric'.

    Returns:
        Tensor: PSF normalized to sum to 1. Shape [ks, ks] or [N, ks, ks].
    """
    if model == "geometric":
        spp = SPP_PSF if spp is None else spp
        return self.psf_geometric(points, ks, wvln, spp, recenter)
    elif model == "coherent":
        spp = SPP_COHERENT if spp is None else spp
        return self.psf_coherent(points, ks, wvln, spp, recenter)
    elif model == "huygens":
        spp = SPP_COHERENT if spp is None else spp
        return self.psf_huygens(points, ks, wvln, spp, recenter)
    else:
        raise ValueError(f"Unknown PSF model: {model}")

psf_geometric

psf_geometric(points, ks=PSF_KS, wvln=DEFAULT_WAVE, spp=SPP_PSF, recenter=True)

Single wavelength geometric PSF calculation.

Parameters:

Name Type Description Default
points Tensor

Normalized point source position. Shape of [N, 3], x, y in range [-1, 1], z in range [-Inf, 0].

required
ks int

Output kernel size.

PSF_KS
wvln float

Wavelength.

DEFAULT_WAVE
spp int

Sample per pixel.

SPP_PSF
recenter bool

Recenter PSF using chief ray.

True

Returns:

Name Type Description
psf

Shape of [ks, ks] or [N, ks, ks].

References

[1] https://optics.ansys.com/hc/en-us/articles/42661723066515-What-is-a-Point-Spread-Function

Source code in deeplens/optics/geolens.py
def psf_geometric(
    self, points, ks=PSF_KS, wvln=DEFAULT_WAVE, spp=SPP_PSF, recenter=True
):
    """Single wavelength geometric PSF calculation.

    Args:
        points (Tensor): Normalized point source position. Shape of [N, 3], x, y in range [-1, 1], z in range [-Inf, 0].
        ks (int, optional): Output kernel size.
        wvln (float, optional): Wavelength.
        spp (int, optional): Sample per pixel.
        recenter (bool, optional): Recenter PSF using chief ray.

    Returns:
        psf: Shape of [ks, ks] or [N, ks, ks].

    References:
        [1] https://optics.ansys.com/hc/en-us/articles/42661723066515-What-is-a-Point-Spread-Function
    """
    sensor_w, sensor_h = self.sensor_size
    pixel_size = self.pixel_size
    device = self.device

    # Points shape of [N, 3]
    if not torch.is_tensor(points):
        points = torch.tensor(points, device=device)

    if len(points.shape) == 1:
        single_point = True
        points = points.unsqueeze(0)
    else:
        single_point = False

    # Sample rays. Ray position in the object space by perspective projection
    depth = points[:, 2]
    scale = self.calc_scale(depth)
    point_obj_x = points[..., 0] * scale * sensor_w / 2
    point_obj_y = points[..., 1] * scale * sensor_h / 2
    point_obj = torch.stack([point_obj_x, point_obj_y, points[..., 2]], dim=-1)
    ray = self.sample_from_points(points=point_obj, num_rays=spp, wvln=wvln)

    # Trace rays to sensor plane (incoherent)
    ray.coherent = False
    ray = self.trace2sensor(ray)

    # Calculate PSF center, shape [N, 2]
    if recenter:
        pointc = self.psf_center(point_obj, method="chief_ray")
    else:
        pointc = self.psf_center(point_obj, method="pinhole")

    # Monte Carlo integration
    psf = forward_integral(ray.flip_xy(), ps=pixel_size, ks=ks, pointc=pointc)

    # Intensity normalization
    psf = psf / (torch.sum(psf, dim=(-2, -1), keepdim=True) + EPSILON)

    if single_point:
        psf = psf.squeeze(0)

    return diff_float(psf)

psf_coherent

psf_coherent(points, ks=PSF_KS, wvln=DEFAULT_WAVE, spp=SPP_COHERENT, recenter=True)

Alias for psf_pupil_prop. Calculates PSF by coherent ray tracing to exit pupil followed by Angular Spectrum Method (ASM) propagation.

Source code in deeplens/optics/geolens.py
def psf_coherent(
    self, points, ks=PSF_KS, wvln=DEFAULT_WAVE, spp=SPP_COHERENT, recenter=True
):
    """Alias for psf_pupil_prop. Calculates PSF by coherent ray tracing to exit pupil followed by Angular Spectrum Method (ASM) propagation."""
    return self.psf_pupil_prop(points, ks=ks, wvln=wvln, spp=spp, recenter=recenter)

psf_pupil_prop

psf_pupil_prop(points, ks=PSF_KS, wvln=DEFAULT_WAVE, spp=SPP_COHERENT, recenter=True)

Single point monochromatic PSF using exit-pupil diffraction model. This function is differentiable.

Steps

1. Calculate the complex wavefield at the exit-pupil plane by coherent ray tracing.
2. Propagate through free space to the sensor plane and compute the intensity PSF.

Parameters:

Name Type Description Default
points Tensor

[x, y, z] coordinates of the point source.

required
ks int

size of the PSF patch. Defaults to PSF_KS.

PSF_KS
wvln float

wvln. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
spp int

number of rays to sample. Defaults to SPP_COHERENT.

SPP_COHERENT
recenter bool

Recenter PSF using chief ray. Defaults to True.

True

Returns:

Name Type Description
psf_out Tensor

PSF patch. Normalized to sum to 1. Shape [ks, ks]

Reference

[1] "End-to-End Hybrid Refractive-Diffractive Lens Design with Differentiable Ray-Wave Model", SIGGRAPH Asia 2024.

Note

[1] This function is similar to ZEMAX FFT_PSF but implements free-space propagation with the Angular Spectrum Method (ASM) rather than an FFT. ASM propagation is more accurate because the FFT approach (as used in ZEMAX) assumes a far-field condition (e.g., the chief ray is perpendicular to the image plane).

Source code in deeplens/optics/geolens.py
def psf_pupil_prop(
    self, points, ks=PSF_KS, wvln=DEFAULT_WAVE, spp=SPP_COHERENT, recenter=True
):
    """Single point monochromatic PSF using exit-pupil diffraction model. This function is differentiable.

    Steps:
        1, Calculate complex wavefield at exit-pupil plane by coherent ray tracing.
        2, Free-space propagation to sensor plane and calculate intensity PSF.

    Args:
        points (torch.Tensor): [x, y, z] coordinates of the point source.
        ks (int, optional): size of the PSF patch. Defaults to PSF_KS.
        wvln (float, optional): wvln. Defaults to DEFAULT_WAVE.
        spp (int, optional): number of rays to sample. Defaults to SPP_COHERENT.
        recenter (bool, optional): Recenter PSF using chief ray. Defaults to True.

    Returns:
        psf_out (torch.Tensor): PSF patch. Normalized to sum to 1. Shape [ks, ks]

    Reference:
        [1] "End-to-End Hybrid Refractive-Diffractive Lens Design with Differentiable Ray-Wave Model", SIGGRAPH Asia 2024.

    Note:
        [1] This function is similar to ZEMAX FFT_PSF but implements free-space propagation with the Angular Spectrum Method (ASM) rather than an FFT. ASM propagation is more accurate because the FFT approach (as used in ZEMAX) assumes a far-field condition (e.g., the chief ray is perpendicular to the image plane).
    """
    # Pupil field by coherent ray tracing
    wavefront, psfc = self.pupil_field(
        points=points, wvln=wvln, spp=spp, recenter=recenter
    )

    # Propagate to sensor plane and get intensity
    pupilz, pupilr = self.get_exit_pupil()
    h, w = wavefront.shape
    # Manually pad wave field
    wavefront = F.pad(
        wavefront.unsqueeze(0).unsqueeze(0),
        [h // 2, h // 2, w // 2, w // 2],
        mode="constant",
        value=0,
    )
    # Free-space propagation using Angular Spectrum Method (ASM)
    sensor_field = AngularSpectrumMethod(
        wavefront,
        z=self.d_sensor - pupilz,
        wvln=wvln,
        ps=self.pixel_size,
        padding=False,
    )
    # Get intensity
    psf_inten = sensor_field.abs() ** 2

    # Calculate PSF center
    h, w = psf_inten.shape[-2:]
    # consider both interpolation and padding
    psfc_idx_i = ((2 - psfc[1]) * h / 4).round().long()
    psfc_idx_j = ((2 + psfc[0]) * w / 4).round().long()

    # Crop valid PSF region and normalize
    if ks is not None:
        psf_inten_pad = (
            F.pad(
                psf_inten,
                [ks // 2, ks // 2, ks // 2, ks // 2],
                mode="constant",
                value=0,
            )
            .squeeze(0)
            .squeeze(0)
        )
        psf = psf_inten_pad[
            psfc_idx_i : psfc_idx_i + ks, psfc_idx_j : psfc_idx_j + ks
        ]
    else:
        psf = psf_inten

    # Intensity normalization, shape of [ks, ks] or [h, w]
    psf = psf / (torch.sum(psf, dim=(-2, -1), keepdim=True) + EPSILON)

    return diff_float(psf)

pupil_field

pupil_field(points, wvln=DEFAULT_WAVE, spp=SPP_COHERENT, recenter=True)

Compute complex wavefront at exit pupil plane by coherent ray tracing.

The wavefront is flipped for subsequent PSF calculation and has the same size as the image sensor. This function is differentiable.

Parameters:

Name Type Description Default
points Tensor or list

Single point source position. Shape [3] or [1, 3], with x, y in [-1, 1] and z in [-Inf, 0].

required
wvln float

Wavelength in [um]. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
spp int

Number of rays to sample. Must be >= 1,000,000 for accurate coherent simulation. Defaults to SPP_COHERENT.

SPP_COHERENT
recenter bool

If True, center using chief ray. Defaults to True.

True

Returns:

Name Type Description
tuple

(wavefront, psf_center) where:
- wavefront (Tensor): Complex wavefront at exit pupil. Shape [H, H].
- psf_center (list): Normalized PSF center coordinates [x, y] in [-1, 1].

Note

Default dtype must be torch.float64 for accurate phase calculation.

Source code in deeplens/optics/geolens.py
def pupil_field(self, points, wvln=DEFAULT_WAVE, spp=SPP_COHERENT, recenter=True):
    """Compute complex wavefront at exit pupil plane by coherent ray tracing.

    The wavefront is flipped for subsequent PSF calculation and has the same
    size as the image sensor. This function is differentiable.

    Args:
        points (Tensor or list): Single point source position. Shape [3] or [1, 3],
            with x, y in [-1, 1] and z in [-Inf, 0].
        wvln (float, optional): Wavelength in [um]. Defaults to DEFAULT_WAVE.
        spp (int, optional): Number of rays to sample. Must be >= 1,000,000 for
            accurate coherent simulation. Defaults to SPP_COHERENT.
        recenter (bool, optional): If True, center using chief ray. Defaults to True.

    Returns:
        tuple: (wavefront, psf_center) where:
            - wavefront (Tensor): Complex wavefront at exit pupil. Shape [H, H].
            - psf_center (list): Normalized PSF center coordinates [x, y] in [-1, 1].

    Note:
        Default dtype must be torch.float64 for accurate phase calculation.
    """
    assert spp >= 1_000_000, (
        f"Ray sampling {spp} is too small for coherent ray tracing, which may lead to inaccurate simulation."
    )
    assert torch.get_default_dtype() == torch.float64, (
        "Default dtype must be set to float64 for accurate phase calculation."
    )

    sensor_w, sensor_h = self.sensor_size
    device = self.device

    if isinstance(points, list):
        points = torch.tensor(points, device=device).unsqueeze(0)  # [1, 3]
    elif torch.is_tensor(points) and len(points.shape) == 1:
        points = points.unsqueeze(0).to(device)  # [1, 3]
    elif torch.is_tensor(points) and len(points.shape) == 2:
        assert points.shape[0] == 1, (
            f"pupil_field only supports single point input, got shape {points.shape}"
        )
    else:
        raise ValueError(f"Unsupported point type {points.type()}.")

    assert points.shape[0] == 1, (
        "Only one point is supported for pupil field calculation."
    )

    # Ray origin in the object space
    scale = self.calc_scale(points[:, 2].item())
    points_obj = points.clone()
    points_obj[:, 0] = points[:, 0] * scale * sensor_w / 2  # x coordinate
    points_obj[:, 1] = points[:, 1] * scale * sensor_h / 2  # y coordinate

    # Ray center determined by chief ray
    # Shape of [N, 2], un-normalized physical coordinates
    if recenter:
        pointc = self.psf_center(points_obj, method="chief_ray")
    else:
        pointc = self.psf_center(points_obj, method="pinhole")

    # Ray-tracing to exit_pupil
    ray = self.sample_from_points(points=points_obj, num_rays=spp, wvln=wvln)
    ray.coherent = True
    ray = self.trace2exit_pupil(ray)

    # Calculate complex field (same physical size and resolution as the sensor)
    # Complex field is flipped here for further PSF calculation
    pointc_ref = torch.zeros_like(points[:, :2]).to(device)  # [N, 2]
    wavefront = forward_integral(
        ray.flip_xy(),
        ps=self.pixel_size,
        ks=self.sensor_res[1],
        pointc=pointc_ref,
    )
    wavefront = wavefront.squeeze(0)  # [H, H]

    # PSF center (on the sensor plane)
    pointc = pointc[0, :]
    psf_center = [
        pointc[0] / sensor_w * 2,
        pointc[1] / sensor_h * 2,
    ]

    return wavefront, psf_center
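
A minimal sketch (`lens` is an assumed GeoLens instance whose tensors are held in float64) of the coherent pupil-field computation; the asserts require double precision and at least one million rays:

import torch

# Hypothetical usage: exit-pupil wavefront for an on-axis point at 10 m.
torch.set_default_dtype(torch.float64)   # the lens tensors must also be float64
wavefront, psf_center = lens.pupil_field(points=[0.0, 0.0, -10000.0], spp=1_000_000)
# wavefront: complex field at sensor resolution; psf_center: normalized [x, y] in [-1, 1]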

psf_huygens

psf_huygens(points, ks=PSF_KS, wvln=DEFAULT_WAVE, spp=SPP_COHERENT, recenter=True)

Single wavelength Huygens PSF calculation.

This function is not differentiable due to its heavy computational cost.

Steps

1. Trace coherent rays to the exit-pupil plane.
2. Treat every ray as a secondary point source emitting a spherical wave.

Parameters:

Name Type Description Default
points Tensor

Normalized point source position. Shape of [N, 3], x, y in range [-1, 1], z in range [-Inf, 0].

required
ks int

Output kernel size.

PSF_KS
wvln float

Wavelength.

DEFAULT_WAVE
spp int

Sample per pixel.

SPP_COHERENT
recenter bool

Recenter PSF using chief ray.

True

Returns:

Name Type Description
psf

Shape of [ks, ks] or [N, ks, ks].

References

[1] "Optical Aberrations Correction in Postprocessing Using Imaging Simulation", TOG 2021

Note

This differs from the ZEMAX Huygens PSF, which traces rays to the image plane and performs plane-wave integration.

Source code in deeplens/optics/geolens.py
def psf_huygens(
    self, points, ks=PSF_KS, wvln=DEFAULT_WAVE, spp=SPP_COHERENT, recenter=True
):
    """Single wavelength Huygens PSF calculation.

    This function is not differentiable due to its heavy computational cost.

    Steps:
        1, Trace coherent rays to exit-pupil plane.
        2, Treat every ray as a secondary point source emitting a spherical wave.

    Args:
        points (Tensor): Normalized point source position. Shape of [N, 3], x, y in range [-1, 1], z in range [-Inf, 0].
        ks (int, optional): Output kernel size.
        wvln (float, optional): Wavelength.
        spp (int, optional): Sample per pixel.
        recenter (bool, optional): Recenter PSF using chief ray.

    Returns:
        psf: Shape of [ks, ks] or [N, ks, ks].

    References:
        [1] "Optical Aberrations Correction in Postprocessing Using Imaging Simulation", TOG 2021

    Note:
        This differs from the ZEMAX Huygens PSF, which traces rays to the image plane and performs plane-wave integration.
    """
    assert torch.get_default_dtype() == torch.float64, (
        "Default dtype must be set to float64 for accurate phase calculation."
    )

    sensor_w, sensor_h = self.sensor_size
    pixel_size = self.pixel_size
    device = self.device
    wvln_mm = wvln * 1e-3  # Convert wavelength to mm

    # Points shape of [N, 3]
    if not torch.is_tensor(points):
        points = torch.tensor(points, device=device)

    if len(points.shape) == 1:
        single_point = True
        points = points.unsqueeze(0)
    elif len(points.shape) == 2 and points.shape[0] == 1:
        single_point = True
    else:
        raise ValueError(
            f"Points must be of shape [3] or [1, 3], got {points.shape}."
        )

    # Sample rays from object point
    depth = points[:, 2]
    scale = self.calc_scale(depth)
    point_obj_x = points[..., 0] * scale * sensor_w / 2
    point_obj_y = points[..., 1] * scale * sensor_h / 2
    point_obj = torch.stack([point_obj_x, point_obj_y, points[..., 2]], dim=-1)
    ray = self.sample_from_points(points=point_obj, num_rays=spp, wvln=wvln)

    # Trace rays coherently through the lens to exit pupil
    ray.coherent = True
    ray = self.trace2exit_pupil(ray)

    # Calculate PSF center (not flipped here)
    if recenter:
        pointc = -self.psf_center(point_obj, method="chief_ray")
    else:
        pointc = -self.psf_center(point_obj, method="pinhole")

    # Build PSF pixel coordinates (sensor plane at z = d_sensor)
    sensor_z = self.d_sensor.item()
    psf_half_size = (ks / 2) * pixel_size  # Physical half-size of PSF region
    x_coords = torch.linspace(
        -psf_half_size + pixel_size / 2,
        psf_half_size - pixel_size / 2,
        ks,
        device=device,
    )
    y_coords = torch.linspace(
        psf_half_size - pixel_size / 2,
        -psf_half_size + pixel_size / 2,
        ks,
        device=device,
    )
    psf_x, psf_y = torch.meshgrid(
        pointc[0, 0] + x_coords, pointc[0, 1] + y_coords, indexing="xy"
    )  # [ks, ks] each

    # Get valid rays only
    valid_mask = ray.is_valid > 0
    valid_pos = ray.o[valid_mask]  # [num_valid, 3]
    valid_dir = ray.d[valid_mask]  # [num_valid, 3]
    valid_opl = ray.opl[valid_mask]  # [num_valid]
    num_valid = valid_pos.shape[0]

    # Huygens integration: sum spherical waves from each secondary source
    psf_complex = torch.zeros(ks, ks, dtype=torch.complex128, device=device)
    opl_min = valid_opl.min()

    # Compute distance from each secondary source to each pixel
    batch_size = min(num_valid, 10_000)  # Process rays in batches
    for batch_start in range(0, num_valid, batch_size):
        batch_end = min(batch_start + batch_size, num_valid)

        # Batch ray data
        batch_pos = valid_pos[batch_start:batch_end]  # [batch, 3]
        batch_dir = valid_dir[batch_start:batch_end]  # [batch, 3]
        batch_opl = valid_opl[batch_start:batch_end].squeeze(-1)  # [batch]

        # Distance from each secondary source to each pixel
        # batch_pos: [batch, 3], psf_x: [ks, ks]
        dx = psf_x.unsqueeze(-1) - batch_pos[:, 0]  # [ks, ks, batch]
        dy = psf_y.unsqueeze(-1) - batch_pos[:, 1]  # [ks, ks, batch]
        dz = sensor_z - batch_pos[:, 2]  # [batch]

        # Distance r from secondary source to pixel
        r = torch.sqrt(dx**2 + dy**2 + dz**2)  # [ks, ks, batch]

        # Obliquity factor: cos(theta) where theta is angle from normal
        # Using ray direction at exit pupil (dz component)
        obliq = torch.abs(batch_dir[:, 2])  # [batch]
        amp = 0.5 * (1.0 + obliq)  # Huygens–Fresnel obliquity factor

        # Total optical path = OPL through lens + distance to pixel
        total_opl = batch_opl + r  # [ks, ks, batch]

        # Phase relative to reference
        phase = torch.fmod((total_opl - opl_min) / wvln_mm, 1.0) * (
            2 * torch.pi
        )  # [ks, ks, batch]

        # Complex amplitude: A * exp(i * phase) / r (spherical wave decay)
        # We use 1/r for spherical wave amplitude decay
        complex_amp = (amp / r) * torch.exp(1j * phase)  # [ks, ks, batch]

        # Sum contributions from this batch
        psf_complex += complex_amp.sum(dim=-1)  # [ks, ks]

    # Convert complex field to intensity
    psf = psf_complex.abs() ** 2

    # Intensity normalization
    psf = psf / (torch.sum(psf, dim=(-2, -1), keepdim=True) + EPSILON)

    # Flip PSF
    psf = torch.flip(psf, [-2, -1])

    if single_point:
        psf = psf.squeeze(0)

    return diff_float(psf)

psf_map

psf_map(depth=DEPTH, grid=(7, 7), ks=PSF_KS, spp=SPP_PSF, wvln=DEFAULT_WAVE, recenter=True)

Compute the geometric PSF map at given depth.

Overrides the base method in Lens class to improve efficiency by parallel ray tracing over different field points.

Parameters:

Name Type Description Default
depth float

Depth of the object plane. Defaults to DEPTH.

DEPTH
grid (int, tuple)

Grid size (grid_w, grid_h). Defaults to (7, 7).

(7, 7)
ks int

Kernel size. Defaults to PSF_KS.

PSF_KS
spp int

Sample per pixel. Defaults to SPP_PSF.

SPP_PSF
wvln float

Wavelength. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
recenter bool

Recenter PSF using chief ray. Defaults to True.

True

Returns:

Name Type Description
psf_map

PSF map. Shape of [grid_h, grid_w, 1, ks, ks].

Source code in deeplens/optics/geolens.py
def psf_map(
    self,
    depth=DEPTH,
    grid=(7, 7),
    ks=PSF_KS,
    spp=SPP_PSF,
    wvln=DEFAULT_WAVE,
    recenter=True,
):
    """Compute the geometric PSF map at given depth.

    Overrides the base method in Lens class to improve efficiency by parallel ray tracing over different field points.

    Args:
        depth (float, optional): Depth of the object plane. Defaults to DEPTH.
        grid (int, tuple): Grid size (grid_w, grid_h). Defaults to (7, 7).
        ks (int, optional): Kernel size. Defaults to PSF_KS.
        spp (int, optional): Sample per pixel. Defaults to SPP_PSF.
        wvln (float, optional): Wavelength. Defaults to DEFAULT_WAVE.
        recenter (bool, optional): Recenter PSF using chief ray. Defaults to True.

    Returns:
        psf_map: PSF map. Shape of [grid_h, grid_w, 1, ks, ks].
    """
    if isinstance(grid, int):
        grid = (grid, grid)
    points = self.point_source_grid(depth=depth, grid=grid)
    points = points.reshape(-1, 3)
    psfs = self.psf(
        points=points, ks=ks, recenter=recenter, spp=spp, wvln=wvln
    ).unsqueeze(1)  # [grid_h * grid_w, 1, ks, ks]

    psf_map = psfs.reshape(grid[1], grid[0], 1, ks, ks)
    return psf_map
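
A minimal sketch (`lens` is an assumed GeoLens instance) of a PSF map over a 5x5 field grid:

# Hypothetical usage: geometric PSF map at 3 m depth, shape [5, 5, 1, 51, 51].
psf_map = lens.psf_map(depth=-3000.0, grid=(5, 5), ks=51)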

psf_center

psf_center(points_obj, method='chief_ray')

Compute reference PSF center (flipped to match the original point) for given point source.

Parameters:

Name Type Description Default
points_obj

[..., 3] un-normalized point in object plane. [-Inf, Inf] * [-Inf, Inf] * [-Inf, 0]

required
method

"chief_ray" or "pinhole". Defaults to "chief_ray".

'chief_ray'

Returns:

Name Type Description
psf_center

[..., 2] un-normalized psf center in sensor plane.

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def psf_center(self, points_obj, method="chief_ray"):
    """Compute reference PSF center (flipped to match the original point) for given point source.

    Args:
        points_obj: [..., 3] un-normalized point in object plane. [-Inf, Inf] * [-Inf, Inf] * [-Inf, 0]
        method: "chief_ray" or "pinhole". Defaults to "chief_ray".

    Returns:
        psf_center: [..., 2] un-normalized psf center in sensor plane.
    """
    if method == "chief_ray":
        # Shrink the pupil and calculate green light centroid ray as the chief ray
        ray = self.sample_from_points(points_obj, scale_pupil=0.5, num_rays=SPP_CALC)
        ray = self.trace2sensor(ray)
        if not (ray.is_valid == 1).any():
            raise RuntimeError(
                "When tracing chief ray for PSF center calculation, no ray arrives at the sensor."
            )
        psf_center = ray.centroid()
        psf_center = -psf_center[..., :2]  # shape [..., 2]

    elif method == "pinhole":
        # Pinhole camera perspective projection, distortion not considered
        if points_obj[..., 2].min().abs() < 100:
            print(
                "Point source is too close, pinhole model may be inaccurate for PSF center calculation."
            )
        tan_point_fov_x = -points_obj[..., 0] / points_obj[..., 2]
        tan_point_fov_y = -points_obj[..., 1] / points_obj[..., 2]
        psf_center_x = self.foclen * tan_point_fov_x
        psf_center_y = self.foclen * tan_point_fov_y
        psf_center = torch.stack([psf_center_x, psf_center_y], dim=-1).to(
            self.device
        )

    else:
        raise ValueError(
            f"Unsupported method for PSF center calculation: {method}."
        )

    return psf_center
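
A self-contained sketch of the pinhole projection used by the "pinhole" branch above; the focal length and object point are illustrative values, not from the library.

import torch

foclen = 6.0  # mm, illustrative focal length
points_obj = torch.tensor([[10.0, 5.0, -1000.0]])  # [x, y, z] in mm, z < 0

# Perspective projection: image height = focal length * tan(field angle)
psf_center_x = foclen * (-points_obj[..., 0] / points_obj[..., 2])
psf_center_y = foclen * (-points_obj[..., 1] / points_obj[..., 2])
psf_center = torch.stack([psf_center_x, psf_center_y], dim=-1)
print(psf_center)  # tensor([[0.0600, 0.0300]]), in mm on the sensor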

analysis_spot

analysis_spot(num_field=3, depth=float('inf'))

Compute sensor plane ray spot RMS error and radius.

Analyzes spot sizes across the field of view for multiple wavelengths (red, green, blue) and reports statistics.

Parameters:

Name Type Description Default
num_field int

Number of field positions to analyze along the radial direction. Defaults to 3.

3
depth float

Depth of the point source. Use float('inf') for collimated light. Defaults to float('inf').

float('inf')

Returns:

Name Type Description
dict

Spot analysis results keyed by field position (e.g., 'fov0.0', 'fov0.5'). Each entry contains 'rms' (RMS radius in um) and 'radius' (geometric radius in um).

Source code in deeplens/optics/geolens.py
def analysis_spot(self, num_field=3, depth=float("inf")):
    """Compute sensor plane ray spot RMS error and radius.

    Analyzes spot sizes across the field of view for multiple wavelengths
    (red, green, blue) and reports statistics.

    Args:
        num_field (int, optional): Number of field positions to analyze along the
            radial direction. Defaults to 3.
        depth (float, optional): Depth of the point source. Use float('inf') for
            collimated light. Defaults to float('inf').

    Returns:
        dict: Spot analysis results keyed by field position (e.g., 'fov0.0', 'fov0.5').
            Each entry contains 'rms' (RMS radius in um) and 'radius' (geometric radius in um).
    """
    rms_radius_fields = []
    geo_radius_fields = []
    for i, wvln in enumerate([WAVE_RGB[1], WAVE_RGB[0], WAVE_RGB[2]]):
        # Sample rays along meridional (y) direction, shape [num_field, num_rays, 3]
        ray = self.sample_radial_rays(
            num_field=num_field, depth=depth, num_rays=SPP_PSF, wvln=wvln
        )
        ray = self.trace2sensor(ray)

        # Green light point center for reference, shape [num_field, 1, 2]
        if i == 0:
            ray_xy_center_green = ray.centroid()[..., :2].unsqueeze(-2)

        # Calculate RMS spot size and radius for different FoVs
        ray_xy_norm = (
            ray.o[..., :2] - ray_xy_center_green
        ) * ray.is_valid.unsqueeze(-1)
        spot_rms = (
            ((ray_xy_norm**2).sum(-1) * ray.is_valid).sum(-1)
            / (ray.is_valid.sum(-1) + EPSILON)
        ).sqrt()
        spot_radius = (ray_xy_norm**2).sum(-1).sqrt().max(dim=-1).values

        # Append to list
        rms_radius_fields.append(spot_rms)
        geo_radius_fields.append(spot_radius)

    # Average over wavelengths, shape [num_field]
    avg_rms_radius_um = torch.stack(rms_radius_fields, dim=0).mean(dim=0) * 1000.0
    avg_geo_radius_um = torch.stack(geo_radius_fields, dim=0).mean(dim=0) * 1000.0

    # Print results
    print(f"Ray spot analysis results for depth {depth}:")
    print(
        f"RMS radius: FoV (0.0) {avg_rms_radius_um[0]:.3f} um, FoV (0.5) {avg_rms_radius_um[num_field // 2]:.3f} um, FoV (1.0) {avg_rms_radius_um[-1]:.3f} um"
    )
    print(
        f"Geo radius: FoV (0.0) {avg_geo_radius_um[0]:.3f} um, FoV (0.5) {avg_geo_radius_um[num_field // 2]:.3f} um, FoV (1.0) {avg_geo_radius_um[-1]:.3f} um"
    )

    # Save to dict
    rms_results = {}
    fov_ls = torch.linspace(0, 1, num_field)
    for i in range(num_field):
        fov = round(fov_ls[i].item(), 2)
        rms_results[f"fov{fov}"] = {
            "rms": round(avg_rms_radius_um[i].item(), 4),
            "radius": round(avg_geo_radius_um[i].item(), 4),
        }

    return rms_results
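
A usage sketch, continuing the lens object from the psf_map example above (depths illustrative):

# RMS spot analysis for a collimated object and for an object at 1 m
results_inf = lens.analysis_spot(num_field=3, depth=float("inf"))
results_1m = lens.analysis_spot(num_field=3, depth=-1000.0)
print(results_inf["fov0.0"]["rms"], results_inf["fov1.0"]["rms"])  # RMS radii in um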

find_diff_surf

find_diff_surf()

Get differentiable/optimizable surface indices.

Returns a list of surface indices that can be optimized during lens design. Excludes the aperture surface from optimization.

Returns:

Type Description

list or range: Surface indices excluding the aperture.

Source code in deeplens/optics/geolens.py
def find_diff_surf(self):
    """Get differentiable/optimizable surface indices.

    Returns a list of surface indices that can be optimized during lens design.
    Excludes the aperture surface from optimization.

    Returns:
        list or range: Surface indices excluding the aperture.
    """
    if self.aper_idx is None:
        diff_surf_range = range(len(self.surfaces))
    else:
        diff_surf_range = list(range(0, self.aper_idx)) + list(
            range(self.aper_idx + 1, len(self.surfaces))
        )
    return diff_surf_range

calc_foclen

calc_foclen()

Compute effective focal length (EFL).

Traces on-axis paraxial rays to locate the paraxial focus, then traces a small-field-angle ray bundle to that focus to measure the paraxial image height, from which the EFL is derived.

Updates

self.efl: Effective focal length.
self.foclen: Alias for effective focal length.
self.bfl: Back focal length (distance from last surface to sensor).

Reference

[1] https://wp.optics.arizona.edu/optomech/wp-content/uploads/sites/53/2016/10/Tutorial_MorelSophie.pdf
[2] https://rafcamera.com/info/imaging-theory/back-focal-length

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def calc_foclen(self):
    """Compute effective focal length (EFL).

    Traces on-axis paraxial rays to locate the paraxial focus, then traces a small-field-angle ray bundle to that focus to measure the paraxial image height, from which the EFL is derived.

    Updates:
        self.efl: Effective focal length.
        self.foclen: Alias for effective focal length.
        self.bfl: Back focal length (distance from last surface to sensor).

    Reference:
        [1] https://wp.optics.arizona.edu/optomech/wp-content/uploads/sites/53/2016/10/Tutorial_MorelSophie.pdf
        [2] https://rafcamera.com/info/imaging-theory/back-focal-length
    """
    # Trace a paraxial chief ray, shape [1, 1, num_rays, 3]
    paraxial_fov = 0.01
    paraxial_fov_deg = float(np.rad2deg(paraxial_fov))

    # 1. Trace on-axis parallel rays to find paraxial focus z (equivalent to infinite focus)
    ray_axis = self.sample_parallel(
        fov_x=0.0, fov_y=0.0, entrance_pupil=False, scale_pupil=0.2
    )
    ray_axis, _ = self.trace(ray_axis)
    valid_axis = ray_axis.is_valid > 0
    t = -(ray_axis.d[valid_axis, 0] * ray_axis.o[valid_axis, 0]
          + ray_axis.d[valid_axis, 1] * ray_axis.o[valid_axis, 1]) / (
        ray_axis.d[valid_axis, 0] ** 2 + ray_axis.d[valid_axis, 1] ** 2
    )
    focus_z = ray_axis.o[valid_axis, 2] + t * ray_axis.d[valid_axis, 2]
    focus_z = focus_z[~torch.isnan(focus_z) & (focus_z > 0)]
    paraxial_focus_z = float(torch.mean(focus_z))

    # 2. Trace off-axis paraxial ray to paraxial focus, measure image height
    ray = self.sample_parallel(
        fov_x=0.0, fov_y=paraxial_fov_deg, entrance_pupil=False, scale_pupil=0.2
    )
    ray, _ = self.trace(ray)
    ray = ray.prop_to(paraxial_focus_z)

    # Compute the effective focal length
    paraxial_imgh = (ray.o[:, 1] * ray.is_valid).sum() / ray.is_valid.sum()
    eff_foclen = paraxial_imgh.item() / float(np.tan(paraxial_fov))
    self.efl = eff_foclen
    self.foclen = eff_foclen

    # Compute the back focal length
    self.bfl = self.d_sensor.item() - self.surfaces[-1].d.item()

    return eff_foclen
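
A usage sketch with the lens object from the earlier examples:

efl = lens.calc_foclen()
print(f"EFL: {efl:.3f} mm, BFL: {lens.bfl:.3f} mm")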

calc_numerical_aperture

calc_numerical_aperture(n=1.0)

Compute numerical aperture (NA).

Parameters:

Name Type Description Default
n float

Refractive index. Defaults to 1.0.

1.0

Returns:

Name Type Description
NA float

Numerical aperture.

Reference

[1] https://en.wikipedia.org/wiki/Numerical_aperture

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def calc_numerical_aperture(self, n=1.0):
    """Compute numerical aperture (NA).

    Args:
        n (float, optional): Refractive index. Defaults to 1.0.

    Returns:
        NA (float): Numerical aperture.

    Reference:
        [1] https://en.wikipedia.org/wiki/Numerical_aperture
    """
    return n * math.sin(math.atan(1 / 2 / self.fnum))
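
A self-contained check of the formula above, NA = n * sin(atan(1 / (2 * F#))), with an illustrative F-number:

import math

fnum = 2.0  # illustrative F-number
na = 1.0 * math.sin(math.atan(1 / (2 * fnum)))  # image space in air (n = 1.0)
print(round(na, 4))  # 0.2425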

calc_focal_plane

calc_focal_plane(wvln=DEFAULT_WAVE)

Compute the focus distance in the object space. Rays start from the sensor center and are traced into object space.

Parameters:

Name Type Description Default
wvln float

Wavelength. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE

Returns:

Name Type Description
focal_plane float

Focal plane in the object space.

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def calc_focal_plane(self, wvln=DEFAULT_WAVE):
    """Compute the focus distance in the object space. Ray starts from sensor center and traces to the object space.

    Args:
        wvln (float, optional): Wavelength. Defaults to DEFAULT_WAVE.

    Returns:
        focal_plane (float): Focal plane in the object space.
    """
    device = self.device

    # Sample point source rays from sensor center
    o1 = torch.tensor([0, 0, self.d_sensor.item()]).repeat(SPP_CALC, 1)
    o1 = o1.to(device)

    # Sample the first surface as pupil
    # o2 = self.sample_circle(self.surfaces[0].r, z=0.0, shape=[SPP_CALC])
    # o2 *= 0.5  # Shrink sample region to improve accuracy
    pupilz, pupilr = self.get_exit_pupil()
    o2 = self.sample_circle(pupilr, pupilz, shape=[SPP_CALC])
    d = o2 - o1
    ray = Ray(o1, d, wvln, device=device)

    # Trace rays to object space
    ray = self.trace2obj(ray)

    # Optical axis intersection
    t = (ray.d[..., 0] * ray.o[..., 0] + ray.d[..., 1] * ray.o[..., 1]) / (
        ray.d[..., 0] ** 2 + ray.d[..., 1] ** 2
    )
    focus_z = (ray.o[..., 2] - ray.d[..., 2] * t)[ray.is_valid > 0].cpu().numpy()
    focus_z = focus_z[~np.isnan(focus_z) & (focus_z < 0)]

    if len(focus_z) > 0:
        focal_plane = float(np.mean(focus_z))
    else:
        raise ValueError(
            "No valid rays found, focal plane in the image space cannot be computed."
        )

    return focal_plane

calc_sensor_plane

calc_sensor_plane(depth=float('inf'))

Calculate in-focus sensor plane.

Parameters:

Name Type Description Default
depth float

Depth of the object plane. Defaults to float("inf").

float('inf')

Returns:

Name Type Description
d_sensor Tensor

Sensor plane in the image space.

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def calc_sensor_plane(self, depth=float("inf")):
    """Calculate in-focus sensor plane.

    Args:
        depth (float, optional): Depth of the object plane. Defaults to float("inf").

    Returns:
        d_sensor (torch.Tensor): Sensor plane in the image space.
    """
    # Sample and trace rays, shape [SPP_CALC, 3]
    if depth == float("inf"):
        ray = self.sample_parallel(
            fov_x=0.0, fov_y=0.0, num_rays=SPP_CALC, wvln=DEFAULT_WAVE
        )
    else:
        ray = self.sample_from_points(
            points=torch.tensor([0.0, 0.0, depth]),
            num_rays=SPP_CALC,
            wvln=DEFAULT_WAVE,
        )
    ray = self.trace2sensor(ray)

    # Calculate in-focus sensor position
    t = (ray.d[:, 0] * ray.o[:, 0] + ray.d[:, 1] * ray.o[:, 1]) / (
        ray.d[:, 0] ** 2 + ray.d[:, 1] ** 2
    )
    focus_z = ray.o[:, 2] - ray.d[:, 2] * t
    focus_z = focus_z[ray.is_valid > 0]
    focus_z = focus_z[~torch.isnan(focus_z) & (focus_z > 0)]
    d_sensor = torch.mean(focus_z)
    return d_sensor

calc_fov

calc_fov()

Compute field of view (FoV) of the lens in radians.

Calculates FoV using two methods
  1. Perspective projection — from focal length and sensor size (effective FoV, ignoring distortion).
  2. Ray tracing — traces rays from the sensor edge backwards to determine the real FoV including distortion effects.
Updates

self.vfov (float): Vertical FoV in radians.
self.hfov (float): Horizontal FoV in radians.
self.dfov (float): Diagonal FoV in radians.
self.rfov (float): Half-diagonal (radius) FoV in radians.
self.real_rfov (float): Real half-diagonal FoV from ray tracing.
self.real_dfov (float): Real diagonal FoV from ray tracing.
self.eqfl (float): 35mm equivalent focal length in mm.

Reference

[1] https://en.wikipedia.org/wiki/Angle_of_view_(photography)

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def calc_fov(self):
    """Compute field of view (FoV) of the lens in radians.

    Calculates FoV using two methods:
        1. **Perspective projection** — from focal length and sensor size
           (effective FoV, ignoring distortion).
        2. **Ray tracing** — traces rays from the sensor edge backwards to
           determine the real FoV including distortion effects.

    Updates:
        self.vfov (float): Vertical FoV in radians.
        self.hfov (float): Horizontal FoV in radians.
        self.dfov (float): Diagonal FoV in radians.
        self.rfov (float): Half-diagonal (radius) FoV in radians.
        self.real_rfov (float): Real half-diagonal FoV from ray tracing.
        self.real_dfov (float): Real diagonal FoV from ray tracing.
        self.eqfl (float): 35mm equivalent focal length in mm.

    Reference:
        [1] https://en.wikipedia.org/wiki/Angle_of_view_(photography)
    """
    if not hasattr(self, "foclen"):
        return

    # 1. Perspective projection (effective FoV)
    self.vfov = 2 * math.atan(self.sensor_size[0] / 2 / self.foclen)
    self.hfov = 2 * math.atan(self.sensor_size[1] / 2 / self.foclen)
    self.dfov = 2 * math.atan(self.r_sensor / self.foclen)
    self.rfov = self.dfov / 2  # radius (half diagonal) FoV

    # 2. Ray tracing to calculate real FoV (distortion-affected FoV)
    # Sample rays from edge of sensor, shape [SPP_CALC, 3]
    o1 = torch.zeros([SPP_CALC, 3])
    o1 = torch.tensor([self.r_sensor, 0, self.d_sensor.item()]).repeat(SPP_CALC, 1)

    # Sample second points on exit pupil
    pupilz, pupilx = self.get_exit_pupil()
    x2 = torch.linspace(-pupilx, pupilx, SPP_CALC)
    z2 = torch.full_like(x2, pupilz)
    y2 = torch.full_like(x2, 0)
    o2 = torch.stack((x2, y2, z2), axis=-1)

    # Ray tracing to object space
    ray = Ray(o1, o2 - o1, device=self.device)
    ray = self.trace2obj(ray)

    # Compute output ray angle
    tan_rfov = ray.d[..., 0] / ray.d[..., 2]
    rfov = torch.atan(torch.sum(tan_rfov * ray.is_valid) / torch.sum(ray.is_valid))

    # If calculation failed, use pinhole camera model to compute fov
    if torch.isnan(rfov):
        self.real_rfov = self.rfov
        self.real_dfov = self.dfov
        print(
            f"Failed to calculate distorted FoV by ray tracing, use effective FoV {self.rfov} rad."
        )
    else:
        self.real_rfov = rfov.item()
        self.real_dfov = 2 * rfov.item()

    # 3. Compute 35mm equivalent focal length. 35mm sensor: 36mm * 24mm
    self.eqfl = 21.63 / math.tan(self.rfov)
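
A self-contained check of the 35mm-equivalent focal length used above; 21.63 mm is the half diagonal of a 36 mm x 24 mm frame, and the FoV value is illustrative:

import math

rfov = math.radians(35.0)      # illustrative half-diagonal FoV
eqfl = 21.63 / math.tan(rfov)
print(round(eqfl, 2))          # 30.89 mm equivalent focal length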

calc_scale

calc_scale(depth)

Calculate the scale factor (object height / image height).

Uses the pinhole camera model to compute magnification.

Parameters:

Name Type Description Default
depth float

Object distance from the lens (negative z direction).

required

Returns:

Name Type Description
float

Scale factor relating object height to image height.

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def calc_scale(self, depth):
    """Calculate the scale factor (object height / image height).

    Uses the pinhole camera model to compute magnification.

    Args:
        depth (float): Object distance from the lens (negative z direction).

    Returns:
        float: Scale factor relating object height to image height.
    """
    return -depth / self.foclen
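
A quick numeric check of the pinhole magnification returned above, with illustrative values:

foclen = 6.0     # mm, illustrative focal length
depth = -1000.0  # mm, object 1 m in front of the lens (negative z)
scale = -depth / foclen
print(round(scale, 2))  # 166.67: a 1 mm image height maps to ~167 mm object height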

calc_pupil

calc_pupil()

Compute entrance and exit pupil positions and radii.

The entrance and exit pupils must be recalculated whenever
  • First-order parameters change (e.g., field of view, object height, image height),
  • Lens geometry or materials change (e.g., surface curvatures, refractive indices, thicknesses),
  • Or generally, any time the lens configuration is modified.
Updates

self.aper_idx: Index of the aperture surface.
self.exit_pupilz, self.exit_pupilr: Exit pupil position and radius.
self.entr_pupilz, self.entr_pupilr: Entrance pupil position and radius.
self.exit_pupilz_parax, self.exit_pupilr_parax: Paraxial exit pupil.
self.entr_pupilz_parax, self.entr_pupilr_parax: Paraxial entrance pupil.
self.fnum: F-number calculated from focal length and entrance pupil.

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def calc_pupil(self):
    """Compute entrance and exit pupil positions and radii.

    The entrance and exit pupils must be recalculated whenever:
        - First-order parameters change (e.g., field of view, object height, image height),
        - Lens geometry or materials change (e.g., surface curvatures, refractive indices, thicknesses),
        - Or generally, any time the lens configuration is modified.

    Updates:
        self.aper_idx: Index of the aperture surface.
        self.exit_pupilz, self.exit_pupilr: Exit pupil position and radius.
        self.entr_pupilz, self.entr_pupilr: Entrance pupil position and radius.
        self.exit_pupilz_parax, self.exit_pupilr_parax: Paraxial exit pupil.
        self.entr_pupilz_parax, self.entr_pupilr_parax: Paraxial entrance pupil.
        self.fnum: F-number calculated from focal length and entrance pupil.
    """
    # Find aperture
    self.aper_idx = None
    for i in range(len(self.surfaces)):
        if isinstance(self.surfaces[i], Aperture):
            self.aper_idx = i
            break

    if self.aper_idx is None:
        self.aper_idx = np.argmin([s.r for s in self.surfaces])
        print("No aperture found, use the smallest surface as aperture.")

    # Compute entrance and exit pupil
    self.exit_pupilz, self.exit_pupilr = self.calc_exit_pupil(paraxial=False)
    self.entr_pupilz, self.entr_pupilr = self.calc_entrance_pupil(paraxial=False)
    self.exit_pupilz_parax, self.exit_pupilr_parax = self.calc_exit_pupil(
        paraxial=True
    )
    self.entr_pupilz_parax, self.entr_pupilr_parax = self.calc_entrance_pupil(
        paraxial=True
    )

    # Compute F-number
    self.fnum = self.foclen / (2 * self.entr_pupilr)

get_entrance_pupil

get_entrance_pupil(paraxial=False)

Get entrance pupil location and radius.

Parameters:

Name Type Description Default
paraxial bool

If True, return paraxial approximation values. If False, return real ray-traced values. Defaults to False.

False

Returns:

Name Type Description
tuple

(z_position, radius) of the entrance pupil in [mm].

Source code in deeplens/optics/geolens.py
def get_entrance_pupil(self, paraxial=False):
    """Get entrance pupil location and radius.

    Args:
        paraxial (bool, optional): If True, return paraxial approximation values.
            If False, return real ray-traced values. Defaults to False.

    Returns:
        tuple: (z_position, radius) of the entrance pupil in [mm].
    """
    if paraxial:
        return self.entr_pupilz_parax, self.entr_pupilr_parax
    else:
        return self.entr_pupilz, self.entr_pupilr

get_exit_pupil

get_exit_pupil(paraxial=False)

Get exit pupil location and radius.

Parameters:

Name Type Description Default
paraxial bool

If True, return paraxial approximation values. If False, return real ray-traced values. Defaults to False.

False

Returns:

Name Type Description
tuple

(z_position, radius) of the exit pupil in [mm].

Source code in deeplens/optics/geolens.py
def get_exit_pupil(self, paraxial=False):
    """Get exit pupil location and radius.

    Args:
        paraxial (bool, optional): If True, return paraxial approximation values.
            If False, return real ray-traced values. Defaults to False.

    Returns:
        tuple: (z_position, radius) of the exit pupil in [mm].
    """
    if paraxial:
        return self.exit_pupilz_parax, self.exit_pupilr_parax
    else:
        return self.exit_pupilz, self.exit_pupilr

calc_exit_pupil

calc_exit_pupil(paraxial=False)

Calculate exit pupil location and radius.

Paraxial mode

Rays are emitted from near the center of the aperture stop and are close to the optical axis. This mode estimates the exit pupil position and radius under ideal (first-order) optical assumptions. It is fast and stable.

Non-paraxial mode

Rays are emitted from the edge of the aperture stop in large quantities. The exit pupil position and radius are determined based on the intersection points of these rays. This mode is slower and affected by aperture-related aberrations.

Use paraxial mode unless precise ray aiming is required.

Parameters:

Name Type Description Default
paraxial bool

If True, sample rays near the center of the aperture stop (paraxial mode); if False, sample from the stop edge.

False

Returns:

Name Type Description
avg_pupilz float

z coordinate of exit pupil.

avg_pupilr float

radius of exit pupil.

Reference

[1] Exit pupil: how many rays can come from sensor to object space.
[2] https://en.wikipedia.org/wiki/Exit_pupil

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def calc_exit_pupil(self, paraxial=False):
    """Calculate exit pupil location and radius.

    Paraxial mode:
        Rays are emitted from near the center of the aperture stop and are close to the optical axis.
        This mode estimates the exit pupil position and radius under ideal (first-order) optical assumptions.
        It is fast and stable.

    Non-paraxial mode:
        Rays are emitted from the edge of the aperture stop in large quantities.
        The exit pupil position and radius are determined based on the intersection points of these rays.
        This mode is slower and affected by aperture-related aberrations.

    Use paraxial mode unless precise ray aiming is required.

    Args:
        paraxial (bool): If True, sample rays near the center of the aperture stop (paraxial mode); if False, sample from the stop edge.

    Returns:
        avg_pupilz (float): z coordinate of exit pupil.
        avg_pupilr (float): radius of exit pupil.

    Reference:
        [1] Exit pupil: how many rays can come from sensor to object space.
        [2] https://en.wikipedia.org/wiki/Exit_pupil
    """
    if not hasattr(self, "aper_idx") or self.aper_idx is None:
        print("No aperture, use the last surface as exit pupil.")
        return self.surfaces[-1].d.item(), self.surfaces[-1].r

    # Sample rays from aperture (edge or center)
    aper_idx = self.aper_idx
    aper_z = self.surfaces[aper_idx].d.item()
    aper_r = self.surfaces[aper_idx].r

    if paraxial:
        ray_o = torch.tensor([[DELTA_PARAXIAL, 0, aper_z]]).repeat(32, 1)
        phi_rad = torch.linspace(-0.01, 0.01, 32)
    else:
        ray_o = torch.tensor([[aper_r, 0, aper_z]]).repeat(SPP_CALC, 1)
        rfov = float(np.arctan(self.r_sensor / self.foclen))
        phi_rad = torch.linspace(-rfov / 2, rfov / 2, SPP_CALC)

    d = torch.stack(
        (torch.sin(phi_rad), torch.zeros_like(phi_rad), torch.cos(phi_rad)), axis=-1
    )
    ray = Ray(ray_o, d, device=self.device)

    # Ray tracing from aperture edge to last surface
    surf_range = range(self.aper_idx + 1, len(self.surfaces))
    ray, _ = self.trace(ray, surf_range=surf_range)

    # Compute intersection points, solving the equation: o1+d1*t1 = o2+d2*t2
    ray_o = torch.stack(
        [ray.o[ray.is_valid != 0][:, 0], ray.o[ray.is_valid != 0][:, 2]], dim=-1
    )
    ray_d = torch.stack(
        [ray.d[ray.is_valid != 0][:, 0], ray.d[ray.is_valid != 0][:, 2]], dim=-1
    )
    intersection_points = self.compute_intersection_points_2d(ray_o, ray_d)

    # Handle the case where no intersection points are found or small pupil
    if len(intersection_points) == 0:
        print("No intersection points found, use the last surface as exit pupil.")
        avg_pupilr = self.surfaces[-1].r
        avg_pupilz = self.surfaces[-1].d.item()
    else:
        avg_pupilr = torch.mean(intersection_points[:, 0]).item()
        avg_pupilz = torch.mean(intersection_points[:, 1]).item()

        if paraxial:
            avg_pupilr = abs(avg_pupilr / DELTA_PARAXIAL * aper_r)

        if avg_pupilr < EPSILON:
            print(
                "Zero or negative exit pupil is detected, use the last surface as pupil."
            )
            avg_pupilr = self.surfaces[-1].r
            avg_pupilz = self.surfaces[-1].d.item()

    return avg_pupilz, avg_pupilr

calc_entrance_pupil

calc_entrance_pupil(paraxial=False)

Calculate entrance pupil of the lens.

The entrance pupil is the optical image of the physical aperture stop, as seen through the optical elements in front of the stop. We sample backward rays from the aperture stop and trace them to the first surface, then find the intersection points of the reverse extension of the rays. The average of the intersection points defines the entrance pupil position and radius.

Parameters:

Name Type Description Default
paraxial bool

Ray sampling mode. If True, rays are emitted near the centre of the aperture stop (fast, paraxially stable). If False, rays are emitted from the stop edge in larger quantities (slower, accounts for aperture aberrations). Defaults to False.

False

Returns:

Name Type Description
tuple

(z_position, radius) of entrance pupil.

Note

[1] Use paraxial mode unless precise ray aiming is required.
[2] This function only works for object at a far distance. For microscopes, this function usually returns a negative entrance pupil.

References

[1] Entrance pupil: how many rays can come from object space to sensor.
[2] https://en.wikipedia.org/wiki/Entrance_pupil: "In an optical system, the entrance pupil is the optical image of the physical aperture stop, as 'seen' through the optical elements in front of the stop."
[3] Zemax LLC, OpticStudio User Manual, Version 19.4, Document No. 2311, 2019.

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def calc_entrance_pupil(self, paraxial=False):
    """Calculate entrance pupil of the lens.

    The entrance pupil is the optical image of the physical aperture stop, as seen through the optical elements in front of the stop. We sample backward rays from the aperture stop and trace them to the first surface, then find the intersection points of the reverse extension of the rays. The average of the intersection points defines the entrance pupil position and radius.

    Args:
        paraxial (bool): Ray sampling mode.  If ``True``, rays are emitted
            near the centre of the aperture stop (fast, paraxially stable).
            If ``False``, rays are emitted from the stop edge in larger
            quantities (slower, accounts for aperture aberrations).
            Defaults to ``False``.

    Returns:
        tuple: (z_position, radius) of entrance pupil.

    Note:
        [1] Use paraxial mode unless precise ray aiming is required.
        [2] This function only works for object at a far distance. For microscopes, this function usually returns a negative entrance pupil.

    References:
        [1] Entrance pupil: how many rays can come from object space to sensor.
        [2] https://en.wikipedia.org/wiki/Entrance_pupil: "In an optical system, the entrance pupil is the optical image of the physical aperture stop, as 'seen' through the optical elements in front of the stop."
        [3] Zemax LLC, *OpticStudio User Manual*, Version 19.4, Document No. 2311, 2019.
    """
    if not hasattr(self, "aper_idx") or self.aper_idx is None:
        print("No aperture stop, use the first surface as entrance pupil.")
        return self.surfaces[0].d.item(), self.surfaces[0].r

    # Sample rays from edge of aperture stop
    aper_idx = self.aper_idx
    aper_surf = self.surfaces[aper_idx]
    aper_z = aper_surf.d.item()
    if aper_surf.is_square:
        aper_r = float(np.sqrt(2)) * aper_surf.r
    else:
        aper_r = aper_surf.r

    if paraxial:
        ray_o = torch.tensor([[DELTA_PARAXIAL, 0, aper_z]]).repeat(32, 1)
        phi = torch.linspace(-0.01, 0.01, 32)
    else:
        ray_o = torch.tensor([[aper_r, 0, aper_z]]).repeat(SPP_CALC, 1)
        rfov = float(np.arctan(self.r_sensor / self.foclen))
        phi = torch.linspace(-rfov / 2, rfov / 2, SPP_CALC)

    d = torch.stack(
        (torch.sin(phi), torch.zeros_like(phi), -torch.cos(phi)), axis=-1
    )
    ray = Ray(ray_o, d, device=self.device)

    # Ray tracing from aperture edge to first surface
    surf_range = range(0, self.aper_idx)
    ray, _ = self.trace(ray, surf_range=surf_range)

    # Compute intersection points, solving the equation: o1+d1*t1 = o2+d2*t2
    ray_o = torch.stack(
        [ray.o[ray.is_valid > 0][:, 0], ray.o[ray.is_valid > 0][:, 2]], dim=-1
    )
    ray_d = torch.stack(
        [ray.d[ray.is_valid > 0][:, 0], ray.d[ray.is_valid > 0][:, 2]], dim=-1
    )
    intersection_points = self.compute_intersection_points_2d(ray_o, ray_d)

    # Handle the case where no intersection points are found or small entrance pupil
    if len(intersection_points) == 0:
        print(
            "No intersection points found, use the first surface as entrance pupil."
        )
        avg_pupilr = self.surfaces[0].r
        avg_pupilz = self.surfaces[0].d.item()
    else:
        avg_pupilr = torch.mean(intersection_points[:, 0]).item()
        avg_pupilz = torch.mean(intersection_points[:, 1]).item()

        if paraxial:
            avg_pupilr = abs(avg_pupilr / DELTA_PARAXIAL * aper_r)

        if avg_pupilr < EPSILON:
            print(
                "Zero or negative entrance pupil is detected, use the first surface as entrance pupil."
            )
            avg_pupilr = self.surfaces[0].r
            avg_pupilz = self.surfaces[0].d.item()

    return avg_pupilz, avg_pupilr

compute_intersection_points_2d staticmethod

compute_intersection_points_2d(origins, directions)

Compute the intersection points of 2D lines.

Parameters:

Name Type Description Default
origins Tensor

Origins of the lines. Shape: [N, 2]

required
directions Tensor

Directions of the lines. Shape: [N, 2]

required

Returns:

Type Description

torch.Tensor: Intersection points. Shape: [N*(N-1)/2, 2]

Source code in deeplens/optics/geolens.py
@staticmethod
def compute_intersection_points_2d(origins, directions):
    """Compute the intersection points of 2D lines.

    Args:
        origins (torch.Tensor): Origins of the lines. Shape: [N, 2]
        directions (torch.Tensor): Directions of the lines. Shape: [N, 2]

    Returns:
        torch.Tensor: Intersection points. Shape: [N*(N-1)/2, 2]
    """
    N = origins.shape[0]

    # Create pairwise combinations of indices
    idx = torch.arange(N)
    idx_i, idx_j = torch.combinations(idx, r=2).unbind(1)

    Oi = origins[idx_i]  # Shape: [N*(N-1)/2, 2]
    Oj = origins[idx_j]  # Shape: [N*(N-1)/2, 2]
    Di = directions[idx_i]  # Shape: [N*(N-1)/2, 2]
    Dj = directions[idx_j]  # Shape: [N*(N-1)/2, 2]

    # Vector from Oi to Oj
    b = Oj - Oi  # Shape: [N*(N-1)/2, 2]

    # Coefficients matrix A
    A = torch.stack([Di, -Dj], dim=-1)  # Shape: [N*(N-1)/2, 2, 2]

    # Solve the linear system Ax = b
    # Using least squares to handle the case of no exact solution
    if A.device.type == "mps":
        # Perform lstsq on CPU for MPS devices and move result back
        x, _ = torch.linalg.lstsq(A.cpu(), b.unsqueeze(-1).cpu())[:2]
        x = x.to(A.device)
    else:
        x, _ = torch.linalg.lstsq(A, b.unsqueeze(-1))[:2]
    x = x.squeeze(-1)  # Shape: [N*(N-1)/2, 2]
    s = x[:, 0]
    t = x[:, 1]

    # Calculate the intersection points using either rays
    P_i = Oi + s.unsqueeze(-1) * Di  # Shape: [N*(N-1)/2, 2]
    P_j = Oj + t.unsqueeze(-1) * Dj  # Shape: [N*(N-1)/2, 2]

    # Take the average to mitigate numerical precision issues
    P = (P_i + P_j) / 2

    return P
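
A self-contained numeric check of the pairwise line-intersection logic above, using two 2D lines whose crossing point is known; the direct solve here mirrors the least-squares step for a single pair.

import torch

# Line A runs along +x from (0, 0); line B runs along +y from (1, -1); they cross at (1, 0)
o_a, d_a = torch.tensor([0.0, 0.0]), torch.tensor([1.0, 0.0])
o_b, d_b = torch.tensor([1.0, -1.0]), torch.tensor([0.0, 1.0])

A = torch.stack([d_a, -d_b], dim=-1)  # coefficients of o_a + s*d_a = o_b + t*d_b
b = o_b - o_a
s, t = torch.linalg.solve(A, b)
print(o_a + s * d_a)  # tensor([1., 0.])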

refocus

refocus(foc_dist=float('inf'))

Refocus the lens to a depth distance by changing sensor position.

Parameters:

Name Type Description Default
foc_dist float

focal distance.

float('inf')
Note

In DSLRs, phase-detection autofocus (PDAF) is a popular and efficient method. Here we simplify the problem by calculating the in-focus position for green light.

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def refocus(self, foc_dist=float("inf")):
    """Refocus the lens to a depth distance by changing sensor position.

    Args:
        foc_dist (float): focal distance.

    Note:
        In DSLRs, phase-detection autofocus (PDAF) is a popular and efficient method. Here we simplify the problem by calculating the in-focus position for green light.
    """
    # Calculate in-focus sensor position
    d_sensor_new = self.calc_sensor_plane(depth=foc_dist)

    # Update sensor position
    assert d_sensor_new > 0, "Obtained negative sensor position."
    self.d_sensor = d_sensor_new

    # FoV will be slightly changed
    self.post_computation()
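
A usage sketch with the lens object from the earlier examples (distances illustrative):

lens.refocus(foc_dist=-500.0)  # focus on an object 0.5 m away (object space uses negative z)
print(lens.d_sensor.item())    # sensor moved to the new in-focus position
lens.refocus()                 # back to infinity focus (default foc_dist)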

set_fnum

set_fnum(fnum)

Set F-number and aperture radius using binary search.

Parameters:

Name Type Description Default
fnum float

target F-number.

required
Source code in deeplens/optics/geolens.py
@torch.no_grad()
def set_fnum(self, fnum):
    """Set F-number and aperture radius using binary search.

    Args:
        fnum (float): target F-number.
    """
    current_fnum = self.fnum
    current_aper_r = self.surfaces[self.aper_idx].r
    target_pupil_r = self.foclen / fnum / 2

    # Binary search to find aperture radius that gives desired exit pupil radius
    aper_r = current_aper_r * (current_fnum / fnum)
    aper_r_min = 0.5 * aper_r
    aper_r_max = 2.0 * aper_r

    for _ in range(16):
        self.surfaces[self.aper_idx].r = aper_r
        _, pupilr = self.calc_entrance_pupil()

        if abs(pupilr - target_pupil_r) < 0.1:  # Close enough
            break

        if pupilr > target_pupil_r:
            # Current radius is too large, decrease it
            aper_r_max = aper_r
            aper_r = (aper_r_min + aper_r) / 2
        else:
            # Current radius is too small, increase it
            aper_r_min = aper_r
            aper_r = (aper_r_max + aper_r) / 2

    self.surfaces[self.aper_idx].r = aper_r

    # Update pupil after setting aperture radius
    self.calc_pupil()
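
A usage sketch with the lens object from the earlier examples (target value illustrative):

lens.set_fnum(2.8)          # binary-search the aperture radius for an F/2.8 entrance pupil
print(round(lens.fnum, 2))  # recomputed by calc_pupil(); should be close to the target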

set_target_fov_fnum

set_target_fov_fnum(rfov, fnum)

Set the target FoV and F-number. Only use this function to assign design targets.

Parameters:

Name Type Description Default
rfov float

half-diagonal FoV in radians.

required
fnum float

F number.

required
Source code in deeplens/optics/geolens.py
@torch.no_grad()
def set_target_fov_fnum(self, rfov, fnum):
    """Set FoV, ImgH and F number, only use this function to assign design targets.

    Args:
        rfov (float): half-diagonal FoV in radians.
        fnum (float): F number.
    """
    if rfov > math.pi:
        self.rfov = rfov / 180.0 * math.pi
    else:
        self.rfov = rfov

    self.foclen = self.r_sensor / math.tan(self.rfov)
    self.fnum = fnum
    aper_r = self.foclen / fnum / 2
    self.surfaces[self.aper_idx].update_r(float(aper_r))

    # Update pupil after setting aperture radius
    self.calc_pupil()

set_fov

set_fov(rfov)

Set half-diagonal field of view as a design target.

Unlike calc_fov() which derives FoV from focal length and sensor size, this method directly assigns the target FoV for lens optimisation.

Parameters:

Name Type Description Default
rfov float

Half-diagonal FoV in radians.

required
Source code in deeplens/optics/geolens.py
@torch.no_grad()
def set_fov(self, rfov):
    """Set half-diagonal field of view as a design target.

    Unlike ``calc_fov()`` which derives FoV from focal length and sensor
    size, this method directly assigns the target FoV for lens optimisation.

    Args:
        rfov (float): Half-diagonal FoV in radians.
    """
    self.rfov = rfov

prune_surf

prune_surf(expand_factor=None)

Prune surfaces to allow all valid rays to go through.

Parameters:

Name Type Description Default
expand_factor float

height expansion factor. For cellphone lens, we usually expand by 5%; for camera lens, by 20%.

None
Source code in deeplens/optics/geolens.py
@torch.no_grad()
def prune_surf(self, expand_factor=None):
    """Prune surfaces to allow all valid rays to go through.

    Args:
        expand_factor (float): height expansion factor.
            - For cellphone lens, we usually expand by 5%
            - For camera lens, we usually expand by 20%.
    """
    surface_range = self.find_diff_surf()

    # Set expansion factor
    if self.r_sensor < 10.0:
        expand_factor = 0.05 if expand_factor is None else expand_factor
    else:
        expand_factor = 0.10 if expand_factor is None else expand_factor

    # Expand surface height
    for i in surface_range:
        self.surfaces[i].r = self.surfaces[i].r * (1 + expand_factor)

    # Sample and trace rays to compute the maximum valid region
    if self.rfov is not None:
        fov_deg = self.rfov * 180 / torch.pi
    else:
        fov = np.arctan(self.r_sensor / self.foclen)
        fov_deg = float(fov) * 180 / torch.pi
        print(f"Using fov_deg: {fov_deg} during surface pruning.")

    fov_y = [f * fov_deg / 10 for f in range(0, 11)]
    ray = self.sample_parallel(
        fov_x=[0.0], fov_y=fov_y, num_rays=SPP_CALC, scale_pupil=1.5
    )
    _, ray_o_record = self.trace2sensor(ray=ray, record=True)

    # Ray record, shape [num_rays, num_surfaces + 2, 3]
    ray_o_record = torch.stack(ray_o_record, dim=-2)
    ray_o_record = torch.nan_to_num(ray_o_record, 0.0)
    ray_o_record = ray_o_record.reshape(-1, ray_o_record.shape[-2], 3)

    # Compute the maximum ray height for each surface
    ray_r_record = (ray_o_record[..., :2] ** 2).sum(-1).sqrt()
    surf_r_max = ray_r_record.max(dim=0)[0][1:-1]

    # Update surface height
    for i in surface_range:
        if surf_r_max[i] > 0:
            r_expand = surf_r_max[i].item() * expand_factor
            r_expand = max(min(r_expand, 2.0), 0.1)
            self.surfaces[i].update_r(surf_r_max[i].item() + r_expand)
        else:
            print(f"No valid rays for Surf {i}, expand existing radius.")
            r_expand = self.surfaces[i].r * expand_factor
            r_expand = max(min(r_expand, 2.0), 0.1)
            self.surfaces[i].update_r(self.surfaces[i].r + r_expand)

correct_shape

correct_shape(expand_factor=None)

Correct wrong lens shape during lens design optimization.

Applies correction rules to ensure valid lens geometry
  1. Move the first surface to z = 0.0
  2. Fix aperture distance if aperture is at the front
  3. Prune all surfaces to allow valid rays through

Parameters:

Name Type Description Default
expand_factor float

Height expansion factor for surface pruning. If None, auto-selects based on lens type. Defaults to None.

None

Returns:

Name Type Description
bool

True if any shape corrections were made, False otherwise.

Source code in deeplens/optics/geolens.py
@torch.no_grad()
def correct_shape(self, expand_factor=None):
    """Correct wrong lens shape during lens design optimization.

    Applies correction rules to ensure valid lens geometry:
        1. Move the first surface to z = 0.0
        2. Fix aperture distance if aperture is at the front
        3. Prune all surfaces to allow valid rays through

    Args:
        expand_factor (float, optional): Height expansion factor for surface pruning.
            If None, auto-selects based on lens type. Defaults to None.

    Returns:
        bool: True if any shape corrections were made, False otherwise.
    """
    aper_idx = self.aper_idx
    optim_surf_range = self.find_diff_surf()
    shape_changed = False

    # Rule 1: Move the first surface to z = 0.0
    move_dist = self.surfaces[0].d.item()
    for surf in self.surfaces:
        surf.d -= move_dist
    self.d_sensor -= move_dist

    # Rule 2: Fix aperture distance to the first surface if aperture in the front.
    if aper_idx == 0:
        d_aper = 0.05

        # If the first surface is concave, use the maximum negative sag.
        aper_r = torch.tensor(self.surfaces[aper_idx].r, device=self.device)
        sag1 = -self.surfaces[aper_idx + 1].sag(aper_r, 0).item()

        if sag1 > 0:
            d_aper += sag1

        # Update position of all surfaces.
        delta_aper = self.surfaces[1].d.item() - d_aper
        for i in optim_surf_range:
            self.surfaces[i].d -= delta_aper
        self.d_sensor -= delta_aper

    # Rule 3: Prune all surfaces
    self.prune_surf(expand_factor=expand_factor)

    if shape_changed:
        print("Surface shape corrected.")
    return shape_changed

match_materials

match_materials(mat_table='CDGM')

Match lens materials to a glass catalog.

Parameters:

Name Type Description Default
mat_table str

Glass catalog name. Common options include 'CDGM', 'SCHOTT', 'OHARA'. Defaults to 'CDGM'.

'CDGM'
Source code in deeplens/optics/geolens.py
@torch.no_grad()
def match_materials(self, mat_table="CDGM"):
    """Match lens materials to a glass catalog.

    Args:
        mat_table (str, optional): Glass catalog name. Common options include
            'CDGM', 'SCHOTT', 'OHARA'. Defaults to 'CDGM'.
    """
    for surf in self.surfaces:
        surf.mat2.match_material(mat_table=mat_table)

analysis

analysis(save_name='./lens', depth=float('inf'), render=False, render_unwarp=False, lens_title=None, show=False)

Analyze the optical lens.

Parameters:

Name Type Description Default
save_name str

save name.

'./lens'
depth float

object depth distance.

float('inf')
render bool

whether render an image.

False
render_unwarp bool

whether unwarp the rendered image.

False
lens_title str

lens title

None
show bool

whether to show the rendered image.

False
Source code in deeplens/optics/geolens.py
@torch.no_grad()
def analysis(
    self,
    save_name="./lens",
    depth=float("inf"),
    render=False,
    render_unwarp=False,
    lens_title=None,
    show=False,
):
    """Analyze the optical lens.

    Args:
        save_name (str): save name.
        depth (float): object depth distance.
        render (bool): whether render an image.
        render_unwarp (bool): whether unwarp the rendered image.
        lens_title (str): lens title
        show (bool): whether to show the rendered image.
    """
    # Draw lens layout and ray path
    self.draw_layout(
        filename=f"{save_name}.png",
        lens_title=lens_title,
        depth=depth,
        show=show,
    )

    # Draw spot diagram
    self.draw_spot_radial(
        save_name=f"{save_name}_spot.png",
        depth=depth,
        show=show,
    )

    # Draw MTF
    if depth == float("inf"):
        # This is a hack to draw MTF for infinite depth
        self.draw_mtf(
            depth_list=[DEPTH], save_name=f"{save_name}_mtf.png", show=show
        )
    else:
        self.draw_mtf(
            depth_list=[depth], save_name=f"{save_name}_mtf.png", show=show
        )

    # Calculate RMS error
    self.analysis_spot(depth=depth)

    # Render an image, compute PSNR and SSIM
    if render:
        depth = DEPTH if depth == float("inf") else depth
        img_org = Image.open("./datasets/charts/NBS_1963_1k.png").convert("RGB")
        img_org = np.array(img_org)
        self.analysis_rendering(
            img_org,
            depth=depth,
            spp=SPP_RENDER,
            unwarp=render_unwarp,
            save_name=f"{save_name}_render",
            noise=0.01,
            show=show,
        )
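
A usage sketch with the lens object from the earlier examples; the save name is a hypothetical path prefix.

lens.analysis(save_name="./my_lens", depth=float("inf"), lens_title="My Lens")
# writes ./my_lens.png (layout), ./my_lens_spot.png and ./my_lens_mtf.png, and prints the spot analysis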

get_optimizer_params

get_optimizer_params(lrs=[0.0001, 0.0001, 0.01, 0.0001], decay=0.01, optim_mat=False, optim_surf_range=None)

Get optimizer parameters for different lens surface.

Recommendation

For cellphone lens: [d, c, k, a], [1e-4, 1e-4, 1e-1, 1e-4]
For camera lens: [d, c, 0, 0], [1e-3, 1e-4, 0, 0]

Parameters:

Name Type Description Default
lrs list

learning rate for different parameters.

[0.0001, 0.0001, 0.01, 0.0001]
decay float

decay rate for higher order a. Defaults to 0.01.

0.01
optim_mat bool

whether to optimize material. Defaults to False.

False
optim_surf_range list

surface indices to be optimized. Defaults to None.

None

Returns:

Name Type Description
list

optimizer parameters

Source code in deeplens/optics/geolens.py
def get_optimizer_params(
    self,
    lrs=[1e-4, 1e-4, 1e-2, 1e-4],
    decay=0.01,
    optim_mat=False,
    optim_surf_range=None,
):
    """Get optimizer parameters for different lens surface.

    Recommendation:
        For cellphone lens: [d, c, k, a], [1e-4, 1e-4, 1e-1, 1e-4]
        For camera lens: [d, c, 0, 0], [1e-3, 1e-4, 0, 0]

    Args:
        lrs (list): learning rate for different parameters.
        decay (float): decay rate for higher order a. Defaults to 0.01.
        optim_mat (bool): whether to optimize material. Defaults to False.
        optim_surf_range (list): surface indices to be optimized. Defaults to None.

    Returns:
        list: optimizer parameters
    """
    # Find surfaces to be optimized
    if optim_surf_range is None:
        # optim_surf_range = self.find_diff_surf()
        optim_surf_range = range(len(self.surfaces))

    # If a list of lrs is given for each surface
    if isinstance(lrs[0], list):
        return self.get_optimizer_params_manual(
            lrs=lrs, optim_mat=optim_mat, optim_surf_range=optim_surf_range
        )

    # Optimize lens surface parameters
    params = []
    for surf_idx in optim_surf_range:
        surf = self.surfaces[surf_idx]

        if isinstance(surf, Aperture):
            params += surf.get_optimizer_params(lrs=[lrs[0]])

        elif isinstance(surf, Aspheric):
            params += surf.get_optimizer_params(
                lrs=lrs[:4], decay=decay, optim_mat=optim_mat
            )

        elif isinstance(surf, AsphericNorm):
            params += surf.get_optimizer_params(
                lrs=lrs[:4], decay=decay, optim_mat=optim_mat
            )

        elif isinstance(surf, Phase):
            params += surf.get_optimizer_params(lrs=[lrs[0], lrs[4]])

        # elif isinstance(surf, GaussianRBF):
        #     params += surf.get_optimizer_params(lrs=lr, optim_mat=optim_mat)

        # elif isinstance(surf, NURBS):
        #     params += surf.get_optimizer_params(lrs=lr, optim_mat=optim_mat)

        elif isinstance(surf, Plane):
            params += surf.get_optimizer_params(lrs=[lrs[0]], optim_mat=optim_mat)

        # elif isinstance(surf, PolyEven):
        #     params += surf.get_optimizer_params(lrs=lr, optim_mat=optim_mat)

        elif isinstance(surf, Spheric):
            params += surf.get_optimizer_params(
                lrs=[lrs[0], lrs[1]], optim_mat=optim_mat
            )

        elif isinstance(surf, ThinLens):
            params += surf.get_optimizer_params(
                lrs=[lrs[0], lrs[1]], optim_mat=optim_mat
            )

        else:
            raise Exception(
                f"Surface type {surf.__class__.__name__} is not supported for optimization yet."
            )

    # Optimize sensor place
    self.d_sensor.requires_grad = True
    params += [{"params": self.d_sensor, "lr": lrs[0]}]

    return params

get_optimizer

get_optimizer(lrs=[0.0001, 0.0001, 0.1, 0.0001], decay=0.01, optim_surf_range=None, optim_mat=False)

Get optimizers and schedulers for different lens parameters.

Parameters:

Name Type Description Default
lrs list

learning rate for different parameters [d, c, k, a]. Defaults to [1e-4, 1e-4, 1e-1, 1e-4].

[0.0001, 0.0001, 0.1, 0.0001]
decay float

decay rate for higher order a. Defaults to 0.01.

0.01
optim_surf_range list

surface indices to be optimized. Defaults to None.

None
optim_mat bool

whether to optimize material. Defaults to False.

False

Returns:

Name Type Description
list

optimizer parameters

Source code in deeplens/optics/geolens.py
def get_optimizer(
    self,
    lrs=[1e-4, 1e-4, 1e-1, 1e-4],
    decay=0.01,
    optim_surf_range=None,
    optim_mat=False,
):
    """Get optimizers and schedulers for different lens parameters.

    Args:
        lrs (list): learning rate for different parameters [d, c, k, a]. Defaults to [1e-4, 1e-4, 1e-1, 1e-4].
        decay (float): decay rate for higher order a. Defaults to 0.01.
        optim_surf_range (list): surface indices to be optimized. Defaults to None.
        optim_mat (bool): whether to optimize material. Defaults to False.

    Returns:
        list: optimizer parameters
    """
    # Initialize lens design constraints (edge thickness, etc.)
    self.init_constraints()

    # Get optimizer
    params = self.get_optimizer_params(
        lrs=lrs, decay=decay, optim_surf_range=optim_surf_range, optim_mat=optim_mat
    )
    optimizer = torch.optim.Adam(params)
    # optimizer = torch.optim.SGD(params)
    return optimizer
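
A minimal design-loop sketch built on the optimizer above. The merit function compute_design_loss is a hypothetical placeholder for whatever differentiable loss you use (for example, an RMS spot loss assembled from traced rays); everything else follows methods documented on this page.

optimizer = lens.get_optimizer(lrs=[1e-4, 1e-4, 1e-1, 1e-4], optim_mat=False)

for step in range(100):
    optimizer.zero_grad()
    loss = compute_design_loss(lens)  # placeholder: your differentiable merit function
    loss.backward()
    optimizer.step()
    lens.correct_shape()              # keep the geometry valid during optimization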

read_lens_json

read_lens_json(filename='./test.json')

Read the lens from a JSON file.

Loads lens configuration including surfaces, materials, and optical properties from the DeepLens native JSON format.

Parameters:

Name Type Description Default
filename str

Path to the JSON lens file. Defaults to './test.json'.

'./test.json'
Note

After loading, the lens is moved to self.device and post_computation is called to calculate derived properties.

Source code in deeplens/optics/geolens.py
def read_lens_json(self, filename="./test.json"):
    """Read the lens from a JSON file.

    Loads lens configuration including surfaces, materials, and optical properties
    from the DeepLens native JSON format.

    Args:
        filename (str, optional): Path to the JSON lens file. Defaults to './test.json'.

    Note:
        After loading, the lens is moved to self.device and post_computation is called
        to calculate derived properties.
    """
    self.surfaces = []
    self.materials = []
    with open(filename, "r") as f:
        data = json.load(f)
        d = 0.0
        for idx, surf_dict in enumerate(data["surfaces"]):
            surf_dict["d"] = d
            surf_dict["surf_idx"] = idx

            if surf_dict["type"] == "Aperture":
                s = Aperture.init_from_dict(surf_dict)

            elif surf_dict["type"] == "Aspheric":
                # s = Aspheric.init_from_dict(surf_dict)
                s = AsphericNorm.init_from_dict(surf_dict)

            elif surf_dict["type"] == "Cubic":
                s = Cubic.init_from_dict(surf_dict)

            # elif surf_dict["type"] == "GaussianRBF":
            #     s = GaussianRBF.init_from_dict(surf_dict)

            # elif surf_dict["type"] == "NURBS":
            #     s = NURBS.init_from_dict(surf_dict)

            elif surf_dict["type"] == "Phase":
                s = Phase.init_from_dict(surf_dict)

            elif surf_dict["type"] == "Plane":
                s = Plane.init_from_dict(surf_dict)

            # elif surf_dict["type"] == "PolyEven":
            #     s = PolyEven.init_from_dict(surf_dict)

            elif surf_dict["type"] == "Stop":
                s = Aperture.init_from_dict(surf_dict)

            elif surf_dict["type"] == "Spheric":
                s = Spheric.init_from_dict(surf_dict)

            elif surf_dict["type"] == "ThinLens":
                s = ThinLens.init_from_dict(surf_dict)

            else:
                raise Exception(
                    f"Surface type {surf_dict['type']} is not implemented in GeoLens.read_lens_json()."
                )

            self.surfaces.append(s)
            d += surf_dict["d_next"]

    self.d_sensor = torch.tensor(d)
    self.lens_info = data.get("info", "None")
    self.enpd = data.get("enpd", None)
    self.float_enpd = True if self.enpd is None else False
    self.float_foclen = False
    self.float_rfov = False
    self.r_sensor = data["r_sensor"]

    self.to(self.device)

    # Set sensor size and resolution
    sensor_res = data.get("sensor_res", (2000, 2000))
    self.set_sensor_res(sensor_res=sensor_res)
    self.post_computation()

write_lens_json

write_lens_json(filename='./test.json')

Write the lens to a JSON file.

Saves the complete lens configuration including all surfaces, materials, focal length, F-number, and sensor properties to the DeepLens JSON format.

Parameters:

Name Type Description Default
filename str

Path for the output JSON file. Defaults to './test.json'.

'./test.json'
Source code in deeplens/optics/geolens.py
def write_lens_json(self, filename="./test.json"):
    """Write the lens to a JSON file.

    Saves the complete lens configuration including all surfaces, materials,
    focal length, F-number, and sensor properties to the DeepLens JSON format.

    Args:
        filename (str, optional): Path for the output JSON file. Defaults to './test.json'.
    """
    data = {}
    data["info"] = self.lens_info if hasattr(self, "lens_info") else "None"
    data["foclen"] = round(self.foclen, 4)
    data["fnum"] = round(self.fnum, 4)
    if self.float_enpd is False:
        data["enpd"] = round(self.enpd, 4)
    data["r_sensor"] = self.r_sensor
    data["(d_sensor)"] = round(self.d_sensor.item(), 4)
    data["(sensor_size)"] = [round(i, 4) for i in self.sensor_size]
    data["surfaces"] = []
    for i, s in enumerate(self.surfaces):
        surf_dict = {"idx": i}
        surf_dict.update(s.surf_dict())
        if i < len(self.surfaces) - 1:
            surf_dict["d_next"] = round(
                self.surfaces[i + 1].d.item() - self.surfaces[i].d.item(), 4
            )
        else:
            surf_dict["d_next"] = round(
                self.d_sensor.item() - self.surfaces[i].d.item(), 4
            )

        data["surfaces"].append(surf_dict)

    with open(filename, "w") as f:
        json.dump(data, f, indent=4)
    print(f"Lens written to {filename}")

deeplens.optics.geolens_pkg.eval.GeoLensEval

Mixin providing classical optical performance evaluation for GeoLens.

Provides spot diagrams, RMS error maps, MTF curves, distortion analysis, vignetting, and field curvature; results are validated against Zemax OpticStudio.

This class is not instantiated directly; it is mixed into :class:~deeplens.optics.geolens.GeoLens.

draw_spot_radial

draw_spot_radial(save_name='./lens_spot_radial.png', num_fov=5, depth=float('inf'), num_rays=SPP_PSF, wvln_list=WAVE_RGB, show=False)

Draw spot diagram of the lens at different field angles along meridional (y) direction.

Parameters:

Name Type Description Default
save_name string

filename to save. Defaults to "./lens_spot_radial.png".

'./lens_spot_radial.png'
num_fov int

Number of field positions. Defaults to 5.

5
depth float

depth of the point source. Defaults to float("inf").

float('inf')
num_rays int

number of rays to sample. Defaults to SPP_PSF.

SPP_PSF
wvln_list list

wavelength list to render.

WAVE_RGB
show bool

whether to show the plot. Defaults to False.

False
Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def draw_spot_radial(
    self,
    save_name="./lens_spot_radial.png",
    num_fov=5,
    depth=float("inf"),
    num_rays=SPP_PSF,
    wvln_list=WAVE_RGB,
    show=False,
):
    """Draw spot diagram of the lens at different field angles along meridional (y) direction.

    Args:
        save_name (string, optional): filename to save. Defaults to "./lens_spot_radial.png".
        num_fov (int, optional): number of field positions. Defaults to 5.
        depth (float, optional): depth of the point source. Defaults to float("inf").
        num_rays (int, optional): number of rays to sample. Defaults to SPP_PSF.
        wvln_list (list, optional): wavelength list to render.
        show (bool, optional): whether to show the plot. Defaults to False.
    """
    assert isinstance(wvln_list, list), "wvln_list must be a list"

    # Prepare figure
    fig, axs = plt.subplots(1, num_fov, figsize=(num_fov * 3.5, 3))
    axs = np.atleast_1d(axs)

    # Trace and draw each wavelength separately, overlaying results
    for wvln_idx, wvln in enumerate(wvln_list):
        # Sample rays along meridional (y) direction, shape [num_fov, num_rays, 3]
        ray = self.sample_radial_rays(
            num_field=num_fov, depth=depth, num_rays=num_rays, wvln=wvln
        )

        # Trace rays to sensor plane, shape [num_fov, num_rays, 3]
        ray = self.trace2sensor(ray)
        ray_o = ray.o.clone().cpu().numpy()
        ray_valid = ray.is_valid.clone().cpu().numpy()

        color = RGB_COLORS[wvln_idx % len(RGB_COLORS)]

        # Plot multiple spot diagrams in one figure
        for i in range(num_fov):
            valid = ray_valid[i, :]
            x, y = ray_o[i, :, 0], ray_o[i, :, 1]

            # Filter valid rays
            mask = valid > 0
            x_valid, y_valid = x[mask], y[mask]

            # Plot points and center of mass for this wavelength
            axs[i].scatter(x_valid, y_valid, 2, color=color, alpha=0.5)
            axs[i].set_aspect("equal", adjustable="datalim")
            axs[i].tick_params(axis="both", which="major", labelsize=6)

    if show:
        plt.show()
    else:
        assert save_name.endswith(".png"), "save_name must end with .png"
        plt.savefig(save_name, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)
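
Example (usage sketch; assumes a lens file at the given path and follows the library's negative-z convention for finite object depths):

    >>> lens = GeoLens(filename="lens.json")
    >>> lens.draw_spot_radial(save_name="./lens_spot_radial.png", num_fov=5)
    >>> lens.draw_spot_radial(depth=-1000.0, show=True)  # spots for an object at 1 m, shown interactively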

draw_spot_map

draw_spot_map(save_name='./lens_spot_map.png', num_grid=5, depth=DEPTH, num_rays=SPP_PSF, wvln_list=WAVE_RGB, show=False)

Draw spot diagram of the lens at different field angles.

Parameters:

Name Type Description Default
save_name string

filename to save. Defaults to "./lens_spot_map.png".

'./lens_spot_map.png'
num_grid int

number of grid points. Defaults to 5.

5
depth float

depth of the point source. Defaults to DEPTH.

DEPTH
num_rays int

number of rays to sample. Defaults to SPP_PSF.

SPP_PSF
wvln_list list

wavelength list to render. Defaults to WAVE_RGB.

WAVE_RGB
show bool

whether to show the plot. Defaults to False.

False
Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def draw_spot_map(
    self,
    save_name="./lens_spot_map.png",
    num_grid=5,
    depth=DEPTH,
    num_rays=SPP_PSF,
    wvln_list=WAVE_RGB,
    show=False,
):
    """Draw spot diagram of the lens at different field angles.

    Args:
        save_name (string, optional): filename to save. Defaults to "./lens_spot_map.png".
        num_grid (int, optional): number of grid points. Defaults to 5.
        depth (float, optional): depth of the point source. Defaults to DEPTH.
        num_rays (int, optional): number of rays to sample. Defaults to SPP_PSF.
        wvln_list (list, optional): wavelength list to render. Defaults to WAVE_RGB.
        show (bool, optional): whether to show the plot. Defaults to False.
    """
    assert isinstance(wvln_list, list), "wvln_list must be a list"

    # Plot multiple spot diagrams in one figure
    fig, axs = plt.subplots(
        num_grid, num_grid, figsize=(num_grid * 3, num_grid * 3)
    )
    axs = np.atleast_2d(axs)

    # Loop wavelengths and overlay scatters
    for wvln_idx, wvln in enumerate(wvln_list):
        # Sample rays per wavelength, shape [num_grid, num_grid, num_rays, 3]
        ray = self.sample_grid_rays(
            depth=depth, num_grid=num_grid, num_rays=num_rays, wvln=wvln
        )
        # Trace rays to sensor
        ray = self.trace2sensor(ray)

        # Convert to numpy, shape [num_grid, num_grid, num_rays, 3]
        ray_o = -ray.o.clone().cpu().numpy()
        ray_valid = ray.is_valid.clone().cpu().numpy()

        color = RGB_COLORS[wvln_idx % len(RGB_COLORS)]

        # Draw per grid cell
        for i in range(num_grid):
            for j in range(num_grid):
                valid = ray_valid[i, j, :]
                x, y = ray_o[i, j, :, 0], ray_o[i, j, :, 1]

                # Filter valid rays
                mask = valid > 0
                x_valid, y_valid = x[mask], y[mask]

                # Plot points for this wavelength
                axs[i, j].scatter(x_valid, y_valid, 2, color=color, alpha=0.5)
                axs[i, j].set_aspect("equal", adjustable="datalim")
                axs[i, j].tick_params(axis="both", which="major", labelsize=6)

    if show:
        plt.show()
    else:
        assert save_name.endswith(".png"), "save_name must end with .png"
        plt.savefig(save_name, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

rms_map_rgb

rms_map_rgb(num_grid=32, depth=DEPTH)

Calculate the RMS spot error map across RGB wavelengths, referenced to the centroid of the green rays.

Parameters:

Name Type Description Default
num_grid int

Number of grid points. Defaults to 32.

32
depth float

Depth of the point source. Defaults to DEPTH.

DEPTH

Returns:

Name Type Description
rms_map Tensor

RMS map for RGB channels. Shape [3, num_grid, num_grid].

Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def rms_map_rgb(self, num_grid=32, depth=DEPTH):
    """Calculate the RMS spot error map across RGB wavelengths. Reference to the centroid of green rays.

    Args:
        num_grid (int, optional): Number of grid points. Defaults to 32.
        depth (float, optional): Depth of the point source. Defaults to DEPTH.

    Returns:
        rms_map (torch.Tensor): RMS map for RGB channels. Shape [3, num_grid, num_grid].
    """
    all_rms_maps = []

    # Iterate G, R, B
    for i, wvln in enumerate([WAVE_RGB[1], WAVE_RGB[0], WAVE_RGB[2]]):
        # Sample and trace rays, shape [num_grid, num_grid, spp, 3]
        ray = self.sample_grid_rays(
            depth=depth, num_grid=num_grid, num_rays=SPP_PSF, wvln=wvln
        )

        ray = self.trace2sensor(ray)
        ray_xy = ray.o[..., :2]
        ray_valid = ray.is_valid

        # Calculate green centroid as reference, shape [num_grid, num_grid, 2]
        if i == 0:
            ray_xy_center_green = (ray_xy * ray_valid.unsqueeze(-1)).sum(
                -2
            ) / ray_valid.sum(-1).add(EPSILON).unsqueeze(-1)

        # Calculate RMS relative to green centroid, shape [num_grid, num_grid]
        rms_map = torch.sqrt(
            (
                ((ray_xy - ray_xy_center_green.unsqueeze(-2)) ** 2).sum(-1)
                * ray_valid
            ).sum(-1)
            / (ray_valid.sum(-1) + EPSILON)
        )
        all_rms_maps.append(rms_map)

    # Stack the RMS maps for R, G, B channels, shape [3, num_grid, num_grid]
    rms_map_rgb = torch.stack(
        [all_rms_maps[1], all_rms_maps[0], all_rms_maps[2]], dim=0
    )

    return rms_map_rgb
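
Example (post-processing sketch; units follow the sensor coordinates, typically mm):

    >>> rms_rgb = lens.rms_map_rgb(num_grid=32)        # [3, 32, 32]
    >>> avg_rms = rms_rgb.mean(dim=(1, 2))             # field-averaged RMS spot size per channel
    >>> chroma = (rms_rgb.max(dim=0).values - rms_rgb.min(dim=0).values).max()  # worst-case channel spread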

rms_map

rms_map(num_grid=32, depth=DEPTH, wvln=DEFAULT_WAVE)

Calculate the RMS spot error map for a specific wavelength.

Currently unused by the library; the resulting map can serve as a per-field weight mask during optimization (see the sketch after the source listing).

Parameters:

Name Type Description Default
num_grid int

Resolution of the grid used for sampling fields/points. Defaults to 32.

32
depth float

Depth of the point source. Defaults to DEPTH.

DEPTH
wvln float

Wavelength of the ray. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE

Returns:

Name Type Description
rms_map Tensor

RMS map for the specified wavelength. Shape [num_grid, num_grid].

Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def rms_map(self, num_grid=32, depth=DEPTH, wvln=DEFAULT_WAVE):
    """Calculate the RMS spot error map for a specific wavelength.

    Currently unused by the library; the resulting map can serve as a per-field weight mask during optimization.

    Args:
        num_grid (int, optional): Resolution of the grid used for sampling fields/points. Defaults to 32.
        depth (float, optional): Depth of the point source. Defaults to DEPTH.
        wvln (float, optional): Wavelength of the ray. Defaults to DEFAULT_WAVE.

    Returns:
        rms_map (torch.Tensor): RMS map for the specified wavelength. Shape [num_grid, num_grid].
    """
    # Sample and trace rays, shape [num_grid, num_grid, spp, 3]
    ray = self.sample_grid_rays(
        depth=depth, num_grid=num_grid, num_rays=SPP_PSF, wvln=wvln
    )
    ray = self.trace2sensor(ray)
    ray_xy = ray.o[..., :2]  # Shape [num_grid, num_grid, spp, 2]
    ray_valid = ray.is_valid  # Shape [num_grid, num_grid, spp]

    # Calculate centroid for each field point for this wavelength
    ray_xy_center = (ray_xy * ray_valid.unsqueeze(-1)).sum(-2) / ray_valid.sum(
        -1
    ).add(EPSILON).unsqueeze(-1)
    # Shape [num_grid, num_grid, 2]

    # Calculate RMS error relative to its own centroid, shape [num_grid, num_grid]
    rms_map = torch.sqrt(
        (((ray_xy - ray_xy_center.unsqueeze(-2)) ** 2).sum(-1) * ray_valid).sum(-1)
        / (ray_valid.sum(-1) + EPSILON)
    )

    return rms_map
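
Example (sketch of turning the map into a per-field weight mask for optimization; the normalization and clamping scheme is an assumption, not part of the library):

    >>> rms = lens.rms_map(num_grid=32)                # [32, 32]
    >>> weight = rms / (rms.mean() + 1e-9)             # emphasize fields with larger spots
    >>> weight = weight.clamp(0.5, 2.0)                # keep the weighting moderate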

calc_distortion_2D

calc_distortion_2D(rfov, wvln=DEFAULT_WAVE, plane='meridional', ray_aiming=True)

Calculate distortion at a specific field angle.

Parameters:

Name Type Description Default
rfov float

view angle (degree)

required
wvln float

wavelength

DEFAULT_WAVE
plane str

meridional or sagittal

'meridional'
ray_aiming bool

whether to aim the chief ray through the center of the stop.

True

Returns:

Name Type Description
distortion float

distortion at the specific field angle

Source code in deeplens/optics/geolens_pkg/eval.py
def calc_distortion_2D(
    self, rfov, wvln=DEFAULT_WAVE, plane="meridional", ray_aiming=True
):
    """Calculate distortion at a specific field angle.

    Args:
        rfov (float): view angle (degree)
        wvln (float): wavelength
        plane (str): meridional or sagittal
        ray_aiming (bool): whether to aim the chief ray through the center of the stop.

    Returns:
        distortion (float): distortion at the specific field angle
    """
    # Calculate ideal image height
    eff_foclen = self.foclen
    ideal_imgh = eff_foclen * np.tan(rfov * np.pi / 180)

    # Calculate chief ray
    chief_ray_o, chief_ray_d = self.calc_chief_ray_infinite(
        rfov=rfov, wvln=wvln, plane=plane, ray_aiming=ray_aiming
    )
    ray = Ray(chief_ray_o, chief_ray_d, wvln=wvln, device=self.device)

    ray, _ = self.trace(ray)
    t = (self.d_sensor - ray.o[..., 2]) / ray.d[..., 2]

    # Calculate actual image height
    if plane == "sagittal":
        actual_imgh = (ray.o[..., 0] + ray.d[..., 0] * t).abs()
    elif plane == "meridional":
        actual_imgh = (ray.o[..., 1] + ray.d[..., 1] * t).abs()
    else:
        raise ValueError(f"Invalid plane: {plane}")

    # Calculate distortion
    actual_imgh = actual_imgh.cpu().numpy()
    ideal_imgh = ideal_imgh.cpu().numpy()
    distortion = (actual_imgh - ideal_imgh) / ideal_imgh

    # Handle the case where ideal_imgh is 0 or very close to 0
    mask = abs(ideal_imgh) < EPSILON
    distortion[mask] = 0.0

    return distortion
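
Example (usage sketch, mirroring how draw_distortion_radial calls this method; field angles in degrees, distortion returned as a fraction):

    >>> rfov = torch.linspace(0.0, 30.0, 5)            # 0 to 30 degrees
    >>> dist = lens.calc_distortion_2D(rfov, plane="meridional")
    >>> dist_percent = dist * 100                      # e.g. -2.3 means 2.3% barrel distortion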

draw_distortion_radial

draw_distortion_radial(rfov, save_name=None, num_points=GEO_GRID, wvln=DEFAULT_WAVE, plane='meridional', ray_aiming=True, show=False)

Draw the radial distortion curve. For Zemax-style output, set ray_aiming=False.

Note: this function is provided by a community contributor.

Parameters:

Name Type Description Default
rfov

view angle (degrees)

required
save_name

Save filename. Defaults to None.

None
num_points

Number of points. Defaults to GEO_GRID.

GEO_GRID
plane

Meridional or sagittal. Defaults to meridional.

'meridional'
ray_aiming

Whether to use ray aiming. Defaults to True.

True
Source code in deeplens/optics/geolens_pkg/eval.py
def draw_distortion_radial(
    self,
    rfov,
    save_name=None,
    num_points=GEO_GRID,
    wvln=DEFAULT_WAVE,
    plane="meridional",
    ray_aiming=True,
    show=False,
):
    """Draw distortion. zemax format(default): ray_aiming = False.

    Note: this function is provided by a community contributor.

    Args:
        rfov: view angle (degrees)
        save_name: Save filename. Defaults to None.
        num_points: Number of points. Defaults to GEO_GRID.
        plane: Meridional or sagittal. Defaults to meridional.
        ray_aiming: Whether to use ray aiming. Defaults to True.
    """
    # Sample view angles
    rfov_samples = torch.linspace(0, rfov, num_points)
    distortions = []

    # Calculate distortion
    distortions = self.calc_distortion_2D(
        rfov=rfov_samples,
        wvln=wvln,
        plane=plane,
        ray_aiming=ray_aiming,
    )

    # Handle possible NaN values and convert to percentage
    values = [
        t.item() * 100 if not math.isnan(t.item()) else 0 for t in distortions
    ]

    # Create figure
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_title(f"{plane} Surface Distortion")

    # Draw distortion curve
    ax.plot(values, rfov_samples, linestyle="-", color="g", linewidth=1.5)

    # Draw reference line (vertical line)
    ax.axvline(x=0, color="k", linestyle="-", linewidth=0.8)

    # Set grid
    ax.grid(True, color="gray", linestyle="-", linewidth=0.5, alpha=1)

    # Dynamically adjust x-axis range
    value = max(abs(v) for v in values)
    margin = value * 0.2  # 20% margin
    x_min, x_max = -max(0.2, value + margin), max(0.2, value + margin)

    # Set ticks
    x_ticks = np.linspace(-value, value, 3)
    y_ticks = np.linspace(0, rfov, 3)

    ax.set_xticks(x_ticks)
    ax.set_yticks(y_ticks)

    # Format tick labels
    x_labels = [f"{x:.1f}%" for x in x_ticks]
    y_labels = [f"{y:.1f}" for y in y_ticks]

    ax.set_xticklabels(x_labels)
    ax.set_yticklabels(y_labels)

    # Set axis labels
    ax.set_xlabel("Distortion (%)")
    ax.set_ylabel("Field of View (degrees)")

    # Set axis range
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(0, rfov)

    if show:
        plt.show()
    else:
        if save_name is None:
            save_name = f"./{plane}_distortion_inf.png"
        plt.savefig(save_name, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

distortion_map

distortion_map(num_grid=16, depth=DEPTH, wvln=DEFAULT_WAVE)

Compute distortion map at a given depth.

Parameters:

Name Type Description Default
num_grid int

number of grid points.

16
depth float

depth of the point source.

DEPTH
wvln float

wavelength.

DEFAULT_WAVE

Returns:

Name Type Description
distortion_grid Tensor

distortion map. shape (grid_size, grid_size, 2)

Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def distortion_map(self, num_grid=16, depth=DEPTH, wvln=DEFAULT_WAVE):
    """Compute distortion map at a given depth.

    Args:
        num_grid (int): number of grid points.
        depth (float): depth of the point source.
        wvln (float): wavelength.

    Returns:
        distortion_grid (torch.Tensor): distortion map. shape (grid_size, grid_size, 2)
    """
    # Sample and trace rays, shape (grid_size, grid_size, num_rays, 3)
    ray = self.sample_grid_rays(depth=depth, num_grid=num_grid, wvln=wvln, uniform_fov=False)
    ray = self.trace2sensor(ray)

    # Calculate centroid of the rays, shape (grid_size, grid_size, 2)
    ray_xy = ray.centroid()[..., :2]
    x_dist = -ray_xy[..., 0] / self.sensor_size[1] * 2
    y_dist = ray_xy[..., 1] / self.sensor_size[0] * 2
    distortion_grid = torch.stack((x_dist, y_dist), dim=-1)
    return distortion_grid
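
Example (interpretation sketch; positions are normalized sensor coordinates in [-1, 1], so the offset from a regular grid gives the distortion vector per field point — the meshgrid ordering below is an assumption and should be checked against sample_grid_rays):

    >>> grid = lens.distortion_map(num_grid=16)        # [16, 16, 2]
    >>> u = torch.linspace(-1, 1, 16, device=grid.device)
    >>> ideal = torch.stack(torch.meshgrid(u, u, indexing="xy"), dim=-1)
    >>> offset = grid - ideal                          # per-point distortion vector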

distortion_center

distortion_center(points)

Calculate the distortion center for given normalized points.

Parameters:

Name Type Description Default
points

Normalized point source positions. Shape [N, 3] or [..., 3]. x, y in [-1, 1], z (depth) in [-Inf, 0].

required

Returns:

Name Type Description
distortion_center

Normalized distortion center positions. Shape [N, 2] or [..., 2]. x, y in [-1, 1].

Source code in deeplens/optics/geolens_pkg/eval.py
def distortion_center(self, points):
    """Calculate the distortion center for given normalized points.

    Args:
        points: Normalized point source positions. Shape [N, 3] or [..., 3].
            x, y in [-1, 1], z (depth) in [-Inf, 0].

    Returns:
        distortion_center: Normalized distortion center positions. Shape [N, 2] or [..., 2].
            x, y in [-1, 1].
    """
    sensor_w, sensor_h = self.sensor_size

    # Convert normalized points to object space coordinates
    depth = points[..., 2]
    scale = self.calc_scale(depth)
    points_obj_x = points[..., 0] * scale * sensor_w / 2
    points_obj_y = points[..., 1] * scale * sensor_h / 2
    points_obj = torch.stack([points_obj_x, points_obj_y, depth], dim=-1)

    # Sample rays and trace to sensor
    ray = self.sample_from_points(points=points_obj)
    ray = self.trace2sensor(ray)

    # Calculate centroid and normalize to [-1, 1]
    ray_center = -ray.centroid()  # shape [..., 3]
    distortion_center_x = ray_center[..., 0] / (sensor_w / 2)
    distortion_center_y = ray_center[..., 1] / (sensor_h / 2)
    distortion_center = torch.stack((distortion_center_x, distortion_center_y), dim=-1)
    return distortion_center
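
Example (usage sketch; normalized points with a negative depth for a finite object distance):

    >>> pts = torch.tensor([[0.0, 0.0, -1000.0],
    ...                     [0.0, 0.7, -1000.0],
    ...                     [0.7, 0.7, -1000.0]], device=lens.device)
    >>> centers = lens.distortion_center(pts)          # [3, 2], x and y in [-1, 1]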

draw_distortion

draw_distortion(save_name=None, num_grid=16, depth=DEPTH, wvln=DEFAULT_WAVE, show=False)

Draw distortion map.

Parameters:

Name Type Description Default
save_name str

filename to save. Defaults to None.

None
num_grid int

number of grid points. Defaults to 16.

16
depth float

depth of the point source. Defaults to DEPTH.

DEPTH
wvln float

wavelength. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
show bool

whether to show the plot. Defaults to False.

False
Source code in deeplens/optics/geolens_pkg/eval.py
def draw_distortion(
    self, save_name=None, num_grid=16, depth=DEPTH, wvln=DEFAULT_WAVE, show=False
):
    """Draw distortion map.

    Args:
        save_name (str, optional): filename to save. Defaults to None.
        num_grid (int, optional): number of grid points. Defaults to 16.
        depth (float, optional): depth of the point source. Defaults to DEPTH.
        wvln (float, optional): wavelength. Defaults to DEFAULT_WAVE.
        show (bool, optional): whether to show the plot. Defaults to False.
    """
    # Ray tracing to calculate distortion map
    distortion_grid = self.distortion_map(num_grid=num_grid, depth=depth, wvln=wvln)
    x1 = distortion_grid[..., 0].cpu().numpy()
    y1 = distortion_grid[..., 1].cpu().numpy()

    # Draw image
    fig, ax = plt.subplots()
    ax.set_title("Lens distortion")
    ax.scatter(x1, y1, s=2)
    ax.axis("scaled")
    ax.grid(True)

    # Add grid lines based on grid_size
    ax.set_xticks(np.linspace(-1, 1, num_grid))
    ax.set_yticks(np.linspace(-1, 1, num_grid))

    if show:
        plt.show()
    else:
        depth_str = "inf" if depth == float("inf") else f"{-depth}mm"
        if save_name is None:
            save_name = f"./distortion_{depth_str}.png"
        plt.savefig(save_name, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

mtf

mtf(fov, wvln=DEFAULT_WAVE)

Calculate Modulation Transfer Function at a specific field of view.

Computes the geometric MTF by first generating a PSF at the given field position, then converting it to tangential and sagittal MTF curves via FFT.

Parameters:

Name Type Description Default
fov float

Field of view angle in radians.

required
wvln float

Wavelength in micrometers. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE

Returns:

Name Type Description
tuple

(freq, mtf_tan, mtf_sag) where:
- freq (ndarray): Spatial frequency axis in cycles/mm.
- mtf_tan (ndarray): Tangential (meridional) MTF values.
- mtf_sag (ndarray): Sagittal MTF values.

Source code in deeplens/optics/geolens_pkg/eval.py
def mtf(self, fov, wvln=DEFAULT_WAVE):
    """Calculate Modulation Transfer Function at a specific field of view.

    Computes the geometric MTF by first generating a PSF at the given field
    position, then converting it to tangential and sagittal MTF curves via FFT.

    Args:
        fov (float): Field of view angle in radians.
        wvln (float, optional): Wavelength in micrometers. Defaults to DEFAULT_WAVE.

    Returns:
        tuple: (freq, mtf_tan, mtf_sag) where:
            - freq (ndarray): Spatial frequency axis in cycles/mm.
            - mtf_tan (ndarray): Tangential (meridional) MTF values.
            - mtf_sag (ndarray): Sagittal MTF values.
    """
    point = [0, -fov / self.rfov, DEPTH]
    psf = self.psf(points=point, recenter=True, wvln=wvln)
    freq, mtf_tan, mtf_sag = self.psf2mtf(psf, pixel_size=self.pixel_size)
    return freq, mtf_tan, mtf_sag
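
Example (post-processing sketch; an approximate MTF50 read off the tangential curve as the first frequency where the MTF drops below 0.5):

    >>> import numpy as np
    >>> freq, mtf_tan, mtf_sag = lens.mtf(fov=0.2)     # field angle in radians
    >>> below = mtf_tan < 0.5
    >>> mtf50 = freq[np.argmax(below)] if below.any() else float("nan")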

psf2mtf staticmethod

psf2mtf(psf, pixel_size)

Calculate MTF from PSF.

Parameters:

Name Type Description Default
psf tensor

2D PSF tensor (e.g., ks x ks). Assumes standard orientation where the array's y-axis corresponds to the tangential/meridional direction and the x-axis to the sagittal direction.

required
pixel_size float

Pixel size in mm.

required

Returns:

Name Type Description
freq ndarray

Frequency axis (cycles/mm).

tangential_mtf ndarray

Tangential MTF.

sagittal_mtf ndarray

Sagittal MTF.

Reference

[1] https://en.wikipedia.org/wiki/Optical_transfer_function
[2] https://www.edmundoptics.com/knowledge-center/application-notes/optics/introduction-to-modulation-transfer-function/?srsltid=AfmBOoq09vVDVlh_uuwWnFoMTg18JVgh18lFSw8Ci4Sdlry-AmwGkfDd

Source code in deeplens/optics/geolens_pkg/eval.py
@staticmethod
def psf2mtf(psf, pixel_size):
    """Calculate MTF from PSF.

    Args:
        psf (tensor): 2D PSF tensor (e.g., ks x ks). Assumes standard orientation where the array's y-axis corresponds to the tangential/meridional direction and the x-axis to the sagittal direction.
        pixel_size (float): Pixel size in mm.

    Returns:
        freq (ndarray): Frequency axis (cycles/mm).
        tangential_mtf (ndarray): Tangential MTF.
        sagittal_mtf (ndarray): Sagittal MTF.

    Reference:
        [1] https://en.wikipedia.org/wiki/Optical_transfer_function
        [2] https://www.edmundoptics.com/knowledge-center/application-notes/optics/introduction-to-modulation-transfer-function/?srsltid=AfmBOoq09vVDVlh_uuwWnFoMTg18JVgh18lFSw8Ci4Sdlry-AmwGkfDd
    """
    # Convert to numpy (supports torch tensors and numpy arrays)
    try:
        psf_np = psf.detach().cpu().numpy()
    except AttributeError:
        try:
            psf_np = psf.cpu().numpy()
        except AttributeError:
            psf_np = np.asarray(psf)

    # Compute line spread functions (integrate PSF over orthogonal axes)
    # y-axis corresponds to tangential; x-axis corresponds to sagittal
    lsf_sagittal = psf_np.sum(axis=0)  # function of x
    lsf_tangential = psf_np.sum(axis=1)  # function of y

    # One-sided spectra (for real inputs)
    mtf_sag = np.abs(np.fft.rfft(lsf_sagittal))
    mtf_tan = np.abs(np.fft.rfft(lsf_tangential))

    # Normalize by DC to ensure MTF(0) == 1
    dc_sag = mtf_sag[0] if mtf_sag.size > 0 else 1.0
    dc_tan = mtf_tan[0] if mtf_tan.size > 0 else 1.0
    if dc_sag != 0:
        mtf_sag = mtf_sag / dc_sag
    if dc_tan != 0:
        mtf_tan = mtf_tan / dc_tan

    # Frequency axis in cycles/mm (one-sided)
    fx = np.fft.rfftfreq(lsf_sagittal.size, d=pixel_size)
    freq = fx
    positive_freq_idx = freq > 0

    return (
        freq[positive_freq_idx],
        mtf_tan[positive_freq_idx],
        mtf_sag[positive_freq_idx],
    )

draw_mtf

draw_mtf(save_name='./lens_mtf.png', relative_fov_list=[0.0, 0.7, 1.0], depth_list=[DEPTH], psf_ks=128, show=False)

Draw a grid of MTF curves. Each subplot in the grid corresponds to a specific (depth, FOV) combination. Each subplot displays MTF curves for R, G, B wavelengths.

Parameters:

Name Type Description Default
relative_fov_list list

List of relative field of view values. Defaults to [0.0, 0.7, 1.0].

[0.0, 0.7, 1.0]
depth_list list

List of depth values. Defaults to [DEPTH].

[DEPTH]
save_name str

Filename to save the plot. Defaults to "./lens_mtf.png".

'./lens_mtf.png'
psf_ks int

Kernel size for intermediate PSF calculation. Defaults to 128.

128
show bool

whether to show the plot. Defaults to False.

False
Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def draw_mtf(
    self,
    save_name="./lens_mtf.png",
    relative_fov_list=[0.0, 0.7, 1.0],
    depth_list=[DEPTH],
    psf_ks=128,
    show=False,
):
    """Draw a grid of MTF curves.
    Each subplot in the grid corresponds to a specific (depth, FOV) combination.
    Each subplot displays MTF curves for R, G, B wavelengths.

    Args:
        relative_fov_list (list, optional): List of relative field of view values. Defaults to [0.0, 0.7, 1.0].
        depth_list (list, optional): List of depth values. Defaults to [DEPTH].
        save_name (str, optional): Filename to save the plot. Defaults to "./lens_mtf.png".
        psf_ks (int, optional): Kernel size for intermediate PSF calculation. Defaults to 128.
        show (bool, optional): whether to show the plot. Defaults to False.
    """
    pixel_size = self.pixel_size
    nyquist_freq = 0.5 / pixel_size
    num_fovs = len(relative_fov_list)
    if float("inf") in depth_list:
        depth_list = [DEPTH if x == float("inf") else x for x in depth_list]
    num_depths = len(depth_list)

    # Create figure and subplots (num_depths * num_fovs subplots)
    fig, axs = plt.subplots(
        num_depths, num_fovs, figsize=(num_fovs * 3, num_depths * 3), squeeze=False
    )

    # Iterate over depth and field of view
    for depth_idx, depth in enumerate(depth_list):
        for fov_idx, fov_relative in enumerate(relative_fov_list):
            # Calculate rgb PSF
            point = [0, -fov_relative, depth]
            psf_rgb = self.psf_rgb(points=point, ks=psf_ks, recenter=False)

            # Calculate MTF curves for rgb wavelengths
            for wvln_idx, wvln in enumerate(WAVE_RGB):
                # Calculate MTF curves from PSF
                psf = psf_rgb[wvln_idx]
                freq, mtf_tan, _ = self.psf2mtf(psf, pixel_size)

                # Plot MTF curves
                ax = axs[depth_idx, fov_idx]
                color = RGB_COLORS[wvln_idx % len(RGB_COLORS)]
                wvln_label = RGB_LABELS[wvln_idx % len(RGB_LABELS)]
                wvln_nm = int(wvln * 1000)
                ax.plot(
                    freq,
                    mtf_tan,
                    color=color,
                    label=f"{wvln_label}({wvln_nm}nm)-Tan",
                )

            # Draw Nyquist frequency
            ax.axvline(
                x=nyquist_freq,
                color="k",
                linestyle=":",
                linewidth=1.2,
                label="Nyquist",
            )

            # Set title and label for subplot
            fov_deg = round(fov_relative * self.rfov * 180 / np.pi, 1)
            depth_str = "inf" if depth == float("inf") else f"{depth}"
            ax.set_title(f"FOV: {fov_deg}deg, Depth: {depth_str}mm", fontsize=8)
            ax.set_xlabel("Spatial Frequency [cycles/mm]", fontsize=8)
            ax.set_ylabel("MTF", fontsize=8)
            ax.legend(fontsize=6)
            ax.tick_params(axis="both", which="major", labelsize=7)
            ax.grid(True)
            ax.set_ylim(0, 1.05)

    plt.tight_layout()
    if show:
        plt.show()
    else:
        assert save_name.endswith(".png"), "save_name must end with .png"
        plt.savefig(save_name, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

draw_field_curvature

draw_field_curvature(save_name=None, num_points=32, z_span=1.0, z_steps=1001, wvln_list=WAVE_RGB, spp=SPP_CALC, show=False)

Draw field curvature: best-focus defocus vs field angle, RGB overlaid.

For each wavelength and field angle, sweeps defocus positions around the sensor plane and finds the position that minimizes the tangential ray spread. Plots tangential curves as solid lines.

Parameters:

Name Type Description Default
save_name str

Path to save the figure. Defaults to None, which saves to './field_curvature.png'.

None
num_points int

Number of field angle samples. Defaults to 32.

32
z_span float

Half-range of defocus sweep in mm. Defaults to 1.0.

1.0
z_steps int

Number of defocus steps. Defaults to 1001.

1001
wvln_list list

Wavelengths to evaluate. Defaults to WAVE_RGB.

WAVE_RGB
spp int

Number of rays per field point. Defaults to SPP_CALC.

SPP_CALC
show bool

If True, display plot interactively. Defaults to False.

False
Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def draw_field_curvature(
    self,
    save_name=None,
    num_points=32,
    z_span=1.0,
    z_steps=1001,
    wvln_list=WAVE_RGB,
    spp=SPP_CALC,
    show=False,
):
    """Draw field curvature: best-focus defocus vs field angle, RGB overlaid.

    For each wavelength and field angle, sweeps defocus positions around the
    sensor plane and finds the position that minimizes the tangential ray spread.
    Plots tangential curves as solid lines.

    Args:
        save_name (str, optional): Path to save the figure. Defaults to None,
            which saves to ``'./field_curvature.png'``.
        num_points (int, optional): Number of field angle samples. Defaults to 32.
        z_span (float, optional): Half-range of defocus sweep in mm. Defaults to 1.0.
        z_steps (int, optional): Number of defocus steps. Defaults to 1001.
        wvln_list (list, optional): Wavelengths to evaluate. Defaults to WAVE_RGB.
        spp (int, optional): Number of rays per field point. Defaults to SPP_CALC.
        show (bool, optional): If True, display plot interactively. Defaults to False.
    """
    print("This function is not optimized for the best speed.")
    device = self.device
    # Convert maximum field angle to degrees
    rfov_deg = float(self.rfov) * 180.0 / np.pi

    # Sample field angles [0, rfov_deg]
    rfov_samples = torch.linspace(0.0, rfov_deg, num_points, device=device)

    # Prepare containers
    delta_z_tan = []  # list of numpy arrays per wavelength

    # Defocus sweep grid (around current sensor plane)
    d_sensor = self.d_sensor
    z_grid = d_sensor + torch.linspace(-z_span, z_span, z_steps, device=device)

    # Helper to compute best focus along a given axis (0=x sagittal, 1=y tangential)
    def best_focus_delta_z(ray, axis_idx: int):
        # ray: after lens surfaces (image space)
        # Vectorized intersection with planes z_grid
        oz = ray.o[..., 2:3]
        dz = ray.d[..., 2:3]
        t = (z_grid.unsqueeze(0) - oz) / (dz + 1e-12)  # [N, Z]

        oa = ray.o[..., axis_idx : axis_idx + 1]
        da = ray.d[..., axis_idx : axis_idx + 1]
        pos_axis = (oa + da * t).squeeze(-1)  # [N, Z]

        w = ray.is_valid.unsqueeze(-1).float()  # [N, 1] -> [N, Z] by broadcast
        pos_axis = pos_axis * w
        w_sum = w.sum(0)  # [Z]
        centroid = pos_axis.sum(0) / (w_sum + EPSILON)  # [Z]
        ms = (((pos_axis - centroid.unsqueeze(0)) ** 2) * w).sum(0) / (
            w_sum + EPSILON
        )  # [Z]
        best_idx = torch.argmin(ms)
        return (z_grid[best_idx] - d_sensor).item()

    # Loop wavelengths and field angles
    for w_idx, wvln in enumerate(wvln_list):
        dz_tan = []
        for i in range(len(rfov_samples)):
            fov_deg = rfov_samples[i].item()

            # Tangential (meridional plane: y-z plane -> minimize y spread)
            ray_t = self.sample_parallel_2D(
                fov=fov_deg,
                num_rays=spp,
                wvln=wvln,
                plane="meridional",
                entrance_pupil=True,
            )
            ray_t, _ = self.trace(ray_t)
            dz_tan.append(best_focus_delta_z(ray_t, axis_idx=1))  # y-axis

        delta_z_tan.append(np.asarray(dz_tan))

    # Plot
    fov_np = rfov_samples.detach().cpu().numpy()
    fig, ax = plt.subplots(figsize=(7, 6))
    ax.set_title("Field Curvature (Δz vs Field Angle)")

    # Determine x range (tangential only)
    all_vals = np.abs(np.concatenate(delta_z_tan)) if len(delta_z_tan) > 0 else np.array([0.0])
    x_range = float(max(0.2, all_vals.max() * 1.2)) if all_vals.size > 0 else 0.2

    for w_idx in range(len(wvln_list)):
        color = RGB_COLORS[w_idx % len(RGB_COLORS)]
        lbl = RGB_LABELS[w_idx % len(RGB_LABELS)]
        ax.plot(
            delta_z_tan[w_idx],
            fov_np,
            color=color,
            linestyle="-",
            label=f"{lbl}-Tan",
        )

    ax.axvline(x=0, color="k", linestyle="-", linewidth=0.8)
    ax.grid(True, color="gray", linestyle="-", linewidth=0.5, alpha=1.0)
    ax.set_xlabel("Defocus Δz (mm) relative to sensor plane")
    ax.set_ylabel("Field Angle (deg)")
    ax.set_xlim(-x_range, x_range)
    ax.set_ylim(0, rfov_deg)
    ax.legend(fontsize=8)
    plt.tight_layout()

    if show:
        plt.show()
    else:
        if save_name is None:
            save_name = "./field_curvature.png"
        plt.savefig(save_name, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

vignetting

vignetting(depth=DEPTH, num_grid=64)

Compute relative illumination (vignetting) map.

Measures the fraction of rays that successfully reach the sensor for each field position, indicating light falloff from center to edge.

Parameters:

Name Type Description Default
depth float

Object distance. Defaults to DEPTH.

DEPTH
num_grid int

Grid resolution for field sampling. Defaults to 64.

64

Returns:

Name Type Description
Tensor

Vignetting map with values in [0, 1]. Shape [num_grid, num_grid]. A value of 1.0 means no vignetting; 0.0 means fully vignetted.

Source code in deeplens/optics/geolens_pkg/eval.py
def vignetting(self, depth=DEPTH, num_grid=64):
    """Compute relative illumination (vignetting) map.

    Measures the fraction of rays that successfully reach the sensor for each
    field position, indicating light falloff from center to edge.

    Args:
        depth (float, optional): Object distance. Defaults to DEPTH.
        num_grid (int, optional): Grid resolution for field sampling. Defaults to 64.

    Returns:
        Tensor: Vignetting map with values in [0, 1]. Shape [num_grid, num_grid].
            A value of 1.0 means no vignetting; 0.0 means fully vignetted.
    """
    # Sample rays, shape [num_grid, num_grid, num_rays, 3]
    ray = self.sample_grid_rays(depth=depth, num_grid=num_grid)

    # Trace rays to sensor
    ray = self.trace2sensor(ray)

    # Calculate vignetting map
    vignetting = ray.is_valid.sum(-1) / (ray.is_valid.shape[-1])
    return vignetting
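
Example (usage sketch; comparing corner to center illumination):

    >>> vig = lens.vignetting(num_grid=64)             # [64, 64], values in [0, 1]
    >>> relative_illum = (vig[0, 0] / (vig[32, 32] + 1e-9)).item()  # corner vs. center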

draw_vignetting

draw_vignetting(filename=None, depth=DEPTH, resolution=512, show=False)

Draw vignetting (relative illumination) map as a grayscale image.

Parameters:

Name Type Description Default
filename str

Path to save the figure. If None, auto-generates a name based on depth. Defaults to None.

None
depth float

Object distance. Defaults to DEPTH.

DEPTH
resolution int

Output image resolution in pixels. Defaults to 512.

512
show bool

If True, display the plot interactively. Defaults to False.

False
Source code in deeplens/optics/geolens_pkg/eval.py
def draw_vignetting(self, filename=None, depth=DEPTH, resolution=512, show=False):
    """Draw vignetting (relative illumination) map as a grayscale image.

    Args:
        filename (str, optional): Path to save the figure. If None, auto-generates
            a name based on depth. Defaults to None.
        depth (float, optional): Object distance. Defaults to DEPTH.
        resolution (int, optional): Output image resolution in pixels. Defaults to 512.
        show (bool, optional): If True, display the plot interactively.
            Defaults to False.
    """
    # Calculate vignetting map
    vignetting = self.vignetting(depth=depth)

    # Interpolate vignetting map to desired resolution
    vignetting = F.interpolate(
        vignetting.unsqueeze(0).unsqueeze(0),
        size=(resolution, resolution),
        mode="bilinear",
        align_corners=False,
    ).squeeze()

    # Scale vignetting to [0.5, 1] range
    vignetting = 0.5 + 0.5 * vignetting

    fig, ax = plt.subplots()
    im = ax.imshow(vignetting.cpu().numpy(), cmap="gray", vmin=0.5, vmax=1.0)
    fig.colorbar(im, ax=ax, ticks=[0.5, 0.75, 1.0])

    if show:
        plt.show()
    else:
        if filename is None:
            filename = f"./vignetting_{depth}.png"
        plt.savefig(filename, bbox_inches="tight", format="png", dpi=300)
    plt.close(fig)

wavefront_error

wavefront_error()

Compute wavefront error across the field of view.

Not yet implemented.

Source code in deeplens/optics/geolens_pkg/eval.py
def wavefront_error(self):
    """Compute wavefront error across the field of view.

    Not yet implemented.
    """
    pass

field_curvature

field_curvature()

Compute field curvature (best-focus defocus vs field angle).

Not yet implemented.

Source code in deeplens/optics/geolens_pkg/eval.py
def field_curvature(self):
    """Compute field curvature (best-focus defocus vs field angle).

    Not yet implemented.
    """
    pass

aberration_histogram

aberration_histogram()

Compute aberration histogram (Seidel or Zernike decomposition).

Not yet implemented.

Source code in deeplens/optics/geolens_pkg/eval.py
def aberration_histogram(self):
    """Compute aberration histogram (Seidel or Zernike decomposition).

    Not yet implemented.
    """
    pass

calc_chief_ray

calc_chief_ray(fov, plane='sagittal')

Compute chief ray for an incident angle.

If the chief ray is only used to determine the ideal image height, this function can be wrapped into the image-height calculation function.

Parameters:

Name Type Description Default
fov float

incident angle in degree.

required
plane str

"sagittal" or "meridional".

'sagittal'

Returns:

Name Type Description
chief_ray_o Tensor

origin of chief ray.

chief_ray_d Tensor

direction of chief ray.

Note

This uses 2D ray tracing; for a 3D chief ray, shrink the pupil, trace rays, and take the centroid as the chief ray.

Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def calc_chief_ray(self, fov, plane="sagittal"):
    """Compute chief ray for an incident angle.

    If the chief ray is only used to determine the ideal image height, this function can be wrapped into the image-height calculation function.

    Args:
        fov (float): incident angle in degree.
        plane (str): "sagittal" or "meridional".

    Returns:
        chief_ray_o (torch.Tensor): origin of chief ray.
        chief_ray_d (torch.Tensor): direction of chief ray.

    Note:
        This uses 2D ray tracing; for a 3D chief ray, shrink the pupil, trace rays, and take the centroid as the chief ray.
    """
    # Sample parallel rays from object space
    ray = self.sample_parallel_2D(
        fov=fov, num_rays=SPP_CALC, entrance_pupil=True, plane=plane
    )
    inc_ray = ray.clone()

    # Trace to the aperture
    surf_range = range(0, self.aper_idx)
    ray, _ = self.trace(ray, surf_range=surf_range)

    # Look for the ray that is closest to the optical axis
    center_x = torch.min(torch.abs(ray.o[:, 0]))
    center_idx = torch.where(torch.abs(ray.o[:, 0]) == center_x)[0][0].item()
    chief_ray_o, chief_ray_d = inc_ray.o[center_idx, :], inc_ray.d[center_idx, :]

    return chief_ray_o, chief_ray_d
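
Example (sketch; tracing the returned chief ray to the sensor to read the image height, similar to what calc_distortion_2D does internally — adding a batch dimension via unsqueeze is an assumption about the Ray constructor):

    >>> o, d = lens.calc_chief_ray(fov=20.0, plane="sagittal")
    >>> ray = Ray(o.unsqueeze(0), d.unsqueeze(0), device=lens.device)
    >>> ray, _ = lens.trace(ray)
    >>> t = (lens.d_sensor - ray.o[..., 2]) / ray.d[..., 2]
    >>> imgh = (ray.o[..., 0] + ray.d[..., 0] * t).abs()   # sagittal image height in mm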

calc_chief_ray_infinite

calc_chief_ray_infinite(rfov, depth=0.0, wvln=DEFAULT_WAVE, plane='meridional', num_rays=SPP_CALC, ray_aiming=True)

Compute chief ray for an incident angle.

Parameters:

Name Type Description Default
rfov float

incident angle in degree.

required
depth float

depth of the object.

0.0
wvln float

wavelength of the light.

DEFAULT_WAVE
plane str

"sagittal" or "meridional".

'meridional'
num_rays int

number of rays.

SPP_CALC
ray_aiming bool

whether to aim the chief ray through the center of the stop.

True
Source code in deeplens/optics/geolens_pkg/eval.py
@torch.no_grad()
def calc_chief_ray_infinite(
    self,
    rfov,
    depth=0.0,
    wvln=DEFAULT_WAVE,
    plane="meridional",
    num_rays=SPP_CALC,
    ray_aiming=True,
):
    """Compute chief ray for an incident angle.

    Args:
        rfov (float): incident angle in degree.
        depth (float): depth of the object.
        wvln (float): wavelength of the light.
        plane (str): "sagittal" or "meridional".
        num_rays (int): number of rays.
        ray_aiming (bool): whether to aim the chief ray through the center of the stop.
    """
    if isinstance(rfov, float) and rfov > 0:
        rfov = torch.linspace(0, rfov, 2)
    rfov = rfov.to(self.device)

    if not isinstance(depth, torch.Tensor):
        depth = torch.tensor(depth, device=self.device).repeat(len(rfov))

    # set chief ray
    chief_ray_o = torch.zeros([len(rfov), 3]).to(self.device)
    chief_ray_d = torch.zeros([len(rfov), 3]).to(self.device)

    # Convert rfov to radian
    rfov = rfov * torch.pi / 180.0

    if torch.any(rfov == 0):
        chief_ray_o[0, ...] = torch.tensor(
            [0.0, 0.0, depth[0]], device=self.device, dtype=torch.float32
        )
        chief_ray_d[0, ...] = torch.tensor(
            [0.0, 0.0, 1.0], device=self.device, dtype=torch.float32
        )
        if len(rfov) == 1:
            return chief_ray_o, chief_ray_d

    if len(rfov) > 1:
        rfovs = rfov[1:]
        depths = depth[1:]

    if self.aper_idx == 0:
        if plane == "sagittal":
            chief_ray_o[1:, ...] = torch.stack(
                [depths * torch.tan(rfovs), torch.zeros_like(rfovs), depths], dim=-1
            )
            chief_ray_d[1:, ...] = torch.stack(
                [torch.sin(rfovs), torch.zeros_like(rfovs), torch.cos(rfovs)],
                dim=-1,
            )
        else:
            chief_ray_o[1:, ...] = torch.stack(
                [torch.zeros_like(rfovs), depths * torch.tan(rfovs), depths], dim=-1
            )
            chief_ray_d[1:, ...] = torch.stack(
                [torch.zeros_like(rfovs), torch.sin(rfovs), torch.cos(rfovs)],
                dim=-1,
            )

        return chief_ray_o, chief_ray_d

    # Scale factor
    pupilz, _ = self.calc_entrance_pupil()
    y_distance = torch.tan(rfovs) * (abs(depths) + pupilz)

    if ray_aiming:
        scale = 0.05
        delta = scale * y_distance

    if not ray_aiming:
        if plane == "sagittal":
            chief_ray_o[1:, ...] = torch.stack(
                [-y_distance, torch.zeros_like(rfovs), depths], dim=-1
            )
            chief_ray_d[1:, ...] = torch.stack(
                [torch.sin(rfovs), torch.zeros_like(rfovs), torch.cos(rfovs)],
                dim=-1,
            )
        else:
            chief_ray_o[1:, ...] = torch.stack(
                [torch.zeros_like(rfovs), -y_distance, depths], dim=-1
            )
            chief_ray_d[1:, ...] = torch.stack(
                [torch.zeros_like(rfovs), torch.sin(rfovs), torch.cos(rfovs)],
                dim=-1,
            )

    else:
        min_y = -y_distance - delta
        max_y = -y_distance + delta
        o1_linspace = torch.stack(
            [
                torch.linspace(min_y[i], max_y[i], num_rays)
                for i in range(len(min_y))
            ],
            dim=0,
        )

        o1 = torch.zeros([len(rfovs), num_rays, 3])
        o1[:, :, 2] = depths[0]

        o2_linspace = torch.stack(
            [
                torch.linspace(-delta[i], delta[i], num_rays)
                for i in range(len(min_y))
            ],
            dim=0,
        )

        o2 = torch.zeros([len(rfovs), num_rays, 3])
        o2[:, :, 2] = pupilz

        if plane == "sagittal":
            o1[:, :, 0] = o1_linspace
            o2[:, :, 0] = o2_linspace
        else:
            o1[:, :, 1] = o1_linspace
            o2[:, :, 1] = o2_linspace

        # Trace until the aperture
        ray = Ray(o1, o2 - o1, wvln=wvln, device=self.device)
        inc_ray = ray.clone()
        surf_range = range(0, self.aper_idx + 1)
        ray, _ = self.trace(ray, surf_range=surf_range)

        # Look for the ray that is closest to the optical axis
        if plane == "sagittal":
            _, center_idx = torch.min(torch.abs(ray.o[..., 0]), dim=1)
            chief_ray_o[1:, ...] = inc_ray.o[
                torch.arange(len(rfovs)), center_idx.long(), ...
            ]
            chief_ray_d[1:, ...] = torch.stack(
                [torch.sin(rfovs), torch.zeros_like(rfovs), torch.cos(rfovs)],
                dim=-1,
            )
        else:
            _, center_idx = torch.min(torch.abs(ray.o[..., 1]), dim=1)
            chief_ray_o[1:, ...] = inc_ray.o[
                torch.arange(len(rfovs)), center_idx.long(), ...
            ]
            chief_ray_d[1:, ...] = torch.stack(
                [torch.zeros_like(rfovs), torch.sin(rfovs), torch.cos(rfovs)],
                dim=-1,
            )

    return chief_ray_o, chief_ray_d

deeplens.optics.geolens_pkg.optim.GeoLensOptim

Mixin providing differentiable optimisation for GeoLens.

Implements gradient-based lens design using PyTorch autograd:

  • Loss functions – RMS spot error, focus, surface regularity, gap constraints, material validity.
  • Constraint initialisation – edge-thickness and self-intersection guards.
  • Optimizer helpers – parameter groups with per-type learning rates and cosine annealing schedules.
  • High-level optimize() – curriculum-learning training loop.

This class is not instantiated directly; it is mixed into :class:~deeplens.optics.geolens.GeoLens.

References

Xinge Yang et al., "Curriculum learning for ab initio deep learned refractive optics," Nature Communications 2024.

init_constraints

init_constraints(constraint_params=None)

Initialize constraints for the lens design.

Parameters:

Name Type Description Default
constraint_params dict

Constraint parameters.

None
Source code in deeplens/optics/geolens_pkg/optim.py
def init_constraints(self, constraint_params=None):
    """Initialize constraints for the lens design.

    Args:
        constraint_params (dict): Constraint parameters.
    """
    # In the future, we want to use constraint_params to set the constraints.
    if constraint_params is None:
        constraint_params = {}
        print("Lens design constraints initialized with default values.")

    if self.r_sensor < 12.0:
        self.is_cellphone = True

        # Self intersection lower bounds
        self.air_min_edge = 0.05
        self.air_min_center = 0.05
        self.thick_min_edge = 0.25
        self.thick_min_center = 0.4
        self.flange_min = 0.8

        # Air gap and thickness upper bounds
        self.air_max_edge = 3.0
        self.air_max_center = 0.5
        self.thick_max_edge = 2.0
        self.thick_max_center = 3.0
        self.flange_max = 3.0

        # Surface shape constraints
        self.sag2diam_max = 0.1
        self.grad_max = 0.57 # tan(30deg)
        self.diam2thick_max = 15.0
        self.tmax2tmin_max = 5.0

        # Ray angle constraints
        self.chief_ray_angle_max = 30.0 # deg
        self.obliq_min = 0.6

    else:
        self.is_cellphone = False

        # Self-intersection lower bounds
        self.air_min_edge = 0.1
        self.air_min_center = 0.1
        self.thick_min_edge = 1.0
        self.thick_min_center = 2.0
        self.flange_min = 5.0

        # Air gap and thickness upper bounds
        self.air_max_edge = 100.0  # float("inf")
        self.air_max_center = 100.0  # float("inf")
        self.thick_max_edge = 20.0
        self.thick_max_center = 20.0
        self.flange_max = 100.0  # float("inf")

        # Surface shape constraints
        self.sag2diam_max = 0.2
        self.grad_max = 0.84 # tan(40deg)
        self.diam2thick_max = 20.0
        self.tmax2tmin_max = 10.0

        # Ray angle constraints
        self.chief_ray_angle_max = 40.0 # deg
        self.obliq_min = 0.4
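
Example (sketch; the defaults above are chosen from the sensor radius, and individual bounds can be adjusted afterwards):

    >>> lens.init_constraints()
    >>> lens.flange_min = 1.0      # require at least 1 mm flange distance
    >>> lens.air_max_edge = 2.0    # tighten the maximum edge air gap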

loss_reg

loss_reg(w_focus=10.0, w_ray_angle=2.0, w_intersec=1.0, w_gap=0.1, w_surf=1.0)

Compute combined regularization loss for lens design.

Aggregates multiple constraint losses to keep the lens physically valid during gradient-based optimisation.

Parameters:

Name Type Description Default
w_focus float

Weight for focus loss. Defaults to 10.0.

10.0
w_ray_angle float

Weight for chief ray angle loss. Defaults to 2.0.

2.0
w_intersec float

Weight for self-intersection loss. Defaults to 1.0.

1.0
w_gap float

Weight for air gap / thickness loss. Defaults to 0.1.

0.1
w_surf float

Weight for surface shape loss. Defaults to 1.0.

1.0

Returns:

Name Type Description
tuple

(loss_reg, loss_dict) where:
- loss_reg (Tensor): Scalar combined regularization loss.
- loss_dict (dict): Per-component loss values for logging.

Source code in deeplens/optics/geolens_pkg/optim.py
def loss_reg(self, w_focus=10.0, w_ray_angle=2.0, w_intersec=1.0, w_gap=0.1, w_surf=1.0):
    """Compute combined regularization loss for lens design.

    Aggregates multiple constraint losses to keep the lens physically valid
    during gradient-based optimisation.

    Args:
        w_focus (float, optional): Weight for focus loss. Defaults to 10.0.
        w_ray_angle (float, optional): Weight for chief ray angle loss. Defaults to 2.0.
        w_intersec (float, optional): Weight for self-intersection loss. Defaults to 1.0.
        w_gap (float, optional): Weight for air gap / thickness loss. Defaults to 0.1.
        w_surf (float, optional): Weight for surface shape loss. Defaults to 1.0.

    Returns:
        tuple: (loss_reg, loss_dict) where:
            - loss_reg (Tensor): Scalar combined regularization loss.
            - loss_dict (dict): Per-component loss values for logging.
    """
    # Loss functions for regularization
    # loss_focus = self.loss_infocus()
    loss_ray_angle = self.loss_ray_angle()
    loss_intersec = self.loss_intersec()
    loss_gap = self.loss_gap()
    loss_surf = self.loss_surface()
    # loss_mat = self.loss_mat()
    loss_reg = (
        # w_focus * loss_focus
        + w_intersec * loss_intersec
        + w_gap * loss_gap
        + w_surf * loss_surf
        + w_ray_angle * loss_ray_angle
        # + loss_mat
    )

    # Return loss and loss dictionary
    loss_dict = {
        # "loss_focus": loss_focus.item(),
        "loss_intersec": loss_intersec.item(),
        "loss_gap": loss_gap.item(),
        "loss_surf": loss_surf.item(),
        'loss_ray_angle': loss_ray_angle.item(),
        # 'loss_mat': loss_mat.item(),
    }
    return loss_reg, loss_dict
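
Example (usage sketch; combining the regularization term with a task loss in a design step — loss_spot is an illustrative placeholder for whichever imaging/RMS loss the surrounding pipeline provides):

    >>> reg, loss_dict = lens.loss_reg(w_gap=0.1, w_surf=1.0)
    >>> total_loss = loss_spot + 0.1 * reg
    >>> total_loss.backward()
    >>> print({k: round(v, 4) for k, v in loss_dict.items()})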

loss_infocus

loss_infocus(target=0.005)

Sample parallel rays and compute RMS loss on the sensor plane, minimize focus loss.

Parameters:

Name Type Description Default
target float

target of RMS loss. Defaults to 0.005 [mm].

0.005
Source code in deeplens/optics/geolens_pkg/optim.py
def loss_infocus(self, target=0.005):
    """Sample parallel rays and compute RMS loss on the sensor plane, minimize focus loss.

    Args:
        target (float, optional): target of RMS loss. Defaults to 0.005 [mm].
    """
    loss = torch.tensor(0.0, device=self.device)

    # Ray tracing and calculate RMS error
    ray = self.sample_parallel(fov_x=0.0, fov_y=0.0, wvln=WAVE_RGB[1], num_rays=SPP_CALC)
    ray = self.trace2sensor(ray)
    rms_error = ray.rms_error()

    # If RMS error is larger than target, add it to loss
    if rms_error > target:
        loss += rms_error

    return loss

loss_surface

loss_surface()

Penalize extreme surface shapes that are difficult to manufacture.

Checks four constraints for each optimisable surface
  1. Sag-to-diameter ratio exceeding sag2diam_max.
  2. Maximum surface gradient exceeding grad_max.
  3. Diameter-to-thickness ratio exceeding diam2thick_max.
  4. Maximum-to-minimum thickness ratio exceeding tmax2tmin_max.

Returns:

Name Type Description
Tensor

Scalar surface shape penalty loss.

Source code in deeplens/optics/geolens_pkg/optim.py
def loss_surface(self):
    """Penalize extreme surface shapes that are difficult to manufacture.

    Checks four constraints for each optimisable surface:
        1. Sag-to-diameter ratio exceeding ``sag2diam_max``.
        2. Maximum surface gradient exceeding ``grad_max``.
        3. Diameter-to-thickness ratio exceeding ``diam2thick_max``.
        4. Maximum-to-minimum thickness ratio exceeding ``tmax2tmin_max``.

    Returns:
        Tensor: Scalar surface shape penalty loss.
    """
    sag2diam_max = self.sag2diam_max
    grad_max_allowed = self.grad_max
    diam2thick_max = self.diam2thick_max
    tmax2tmin_max = self.tmax2tmin_max

    loss_grad = torch.tensor(0.0, device=self.device)
    loss_diam2thick = torch.tensor(0.0, device=self.device)
    loss_tmax2tmin = torch.tensor(0.0, device=self.device)
    loss_sag2diam = torch.tensor(0.0, device=self.device)
    for i in self.find_diff_surf():
        # Sample points on the surface
        x_ls = torch.linspace(0.0, 1.0, 32).to(self.device) * self.surfaces[i].r
        y_ls = torch.zeros_like(x_ls)

        # Sag
        sag_ls = self.surfaces[i].sag(x_ls, y_ls)
        sag2diam = sag_ls.abs().max() / self.surfaces[i].r / 2
        if sag2diam > sag2diam_max:
            loss_sag2diam += sag2diam

        # 1st-order derivative
        grad_ls = self.surfaces[i].dfdxyz(x_ls, y_ls)[0]
        grad_max = grad_ls.abs().max()
        if grad_max > grad_max_allowed:
            loss_grad += grad_max

        # Diameter to thickness ratio, thick_max to thick_min ratio
        if not self.surfaces[i].mat2.name == "air":
            surf2 = self.surfaces[i + 1]
            surf1 = self.surfaces[i]

            # Penalize diameter to thickness ratio
            diam2thick = 2 * max(surf2.r, surf1.r) / (surf2.d - surf1.d)
            if diam2thick > diam2thick_max:
                loss_diam2thick += diam2thick

            # Penalize thick_max to thick_min ratio
            r_edge = min(surf2.r, surf1.r)
            thick_center = surf2.d - surf1.d
            thick_edge = surf2.surface_with_offset(r_edge, 0.0) - surf1.surface_with_offset(r_edge, 0.0)
            if thick_center > thick_edge:
                tmax2tmin = thick_center / thick_edge
            else:
                tmax2tmin = thick_edge / thick_center

            if tmax2tmin > tmax2tmin_max:
                loss_tmax2tmin += tmax2tmin

    return loss_sag2diam + loss_grad + loss_diam2thick + loss_tmax2tmin

loss_intersec

loss_intersec()

Loss function to avoid self-intersection.

This function penalizes when surfaces are too close to each other, which could cause self-intersection or manufacturing issues.

Source code in deeplens/optics/geolens_pkg/optim.py
def loss_intersec(self):
    """Loss function to avoid self-intersection.

    This function penalizes when surfaces are too close to each other,
    which could cause self-intersection or manufacturing issues.
    """
    # Constraints
    air_min_center = self.air_min_center
    air_min_edge = self.air_min_edge
    thick_min_center = self.thick_min_center
    thick_min_edge = self.thick_min_edge
    flange_min = self.flange_min

    # Loss
    loss = torch.tensor(0.0, device=self.device)
    for i in range(len(self.surfaces) - 1):
        # Sample evaluation points on the two surfaces
        current_surf = self.surfaces[i]
        next_surf = self.surfaces[i + 1]

        r_center = torch.tensor(0.0).to(self.device) * current_surf.r
        z_prev_center = current_surf.surface_with_offset(r_center, 0.0, valid_check=False)
        z_next_center = next_surf.surface_with_offset(r_center, 0.0, valid_check=False)

        r_edge = torch.linspace(0.5, 1.0, 16).to(self.device) * current_surf.r
        z_prev_edge = current_surf.surface_with_offset(r_edge, 0.0, valid_check=False)
        z_next_edge = next_surf.surface_with_offset(r_edge, 0.0, valid_check=False)

        # Next surface is air
        if self.surfaces[i].mat2.name == "air":
            # Center air gap
            dist_center = z_next_center - z_prev_center
            if dist_center < air_min_center:
                loss += dist_center

            # Edge air gap
            dist_edge = torch.min(z_next_edge - z_prev_edge)
            if dist_edge < air_min_edge:
                loss += dist_edge

        # Next surface is lens
        else:
            # Center thickness
            dist_center = z_next_center - z_prev_center
            if dist_center < thick_min_center:
                loss += dist_center

            # Edge thickness
            dist_edge = torch.min(z_next_edge - z_prev_edge)
            if dist_edge < thick_min_edge:
                loss += dist_edge

    # Distance to sensor (flange)
    last_surf = self.surfaces[-1]
    r = torch.linspace(0.0, 1.0, 32).to(self.device) * last_surf.r
    z_last_surf = self.d_sensor - last_surf.surface_with_offset(r, 0.0)

    flange = torch.min(z_last_surf)
    if flange < flange_min:
        loss += flange

    # Loss, maximize loss
    return -loss

loss_gap

loss_gap()

Loss function to penalize too large air gap and thickness.

This function penalizes when air gaps or lens thicknesses are too large, which could make the lens system impractically large.

Source code in deeplens/optics/geolens_pkg/optim.py
def loss_gap(self):
    """Loss function to penalize too large air gap and thickness.

    This function penalizes when air gaps or lens thicknesses are too large,
    which could make the lens system impractically large.
    """
    # Constraints
    air_max_center = self.air_max_center
    air_max_edge = self.air_max_edge
    thick_max_center = self.thick_max_center
    thick_max_edge = self.thick_max_edge
    flange_max = self.flange_max

    # Loss
    loss = torch.tensor(0.0, device=self.device)

    # Distance between surfaces
    for i in range(len(self.surfaces) - 1):
        # Sample evaluation points on the two surfaces
        current_surf = self.surfaces[i]
        next_surf = self.surfaces[i + 1]

        r_center = torch.tensor(0.0).to(self.device) * current_surf.r
        z_prev_center = current_surf.surface_with_offset(r_center, 0.0, valid_check=False)
        z_next_center = next_surf.surface_with_offset(r_center, 0.0, valid_check=False)

        r_edge = torch.linspace(0.5, 1.0, 16).to(self.device) * current_surf.r
        z_prev_edge = current_surf.surface_with_offset(r_edge, 0.0, valid_check=False)
        z_next_edge = next_surf.surface_with_offset(r_edge, 0.0, valid_check=False)

        # Air gap
        if self.surfaces[i].mat2.name == "air":
            # Center air gap
            dist_center = z_next_center - z_prev_center
            if dist_center > air_max_center:
                loss += dist_center

            # Edge air gap
            dist_edge = torch.max(z_next_edge - z_prev_edge)
            if dist_edge > air_max_edge:
                loss += dist_edge

        # Lens thickness
        else:
            # Center thickness
            dist_center = z_next_center - z_prev_center
            if dist_center > thick_max_center:
                loss += dist_center

            # Edge thickness
            dist_edge = torch.max(z_next_edge - z_prev_edge)
            if dist_edge > thick_max_edge:
                loss += dist_edge

    # Distance to sensor (flange)
    last_surf = self.surfaces[-1]
    r = torch.linspace(0.0, 1.0, 32).to(self.device) * last_surf.r
    z_last_surf = self.d_sensor - last_surf.surface_with_offset(r, 0.0)

    flange = torch.max(z_last_surf)
    if flange > flange_max:
        loss += flange

    # Loss, minimize loss
    return loss

loss_ray_angle

loss_ray_angle()

Penalize large chief ray angles and low obliquity factors.

Ensures that rays arrive at the sensor within acceptable incidence angles, which is critical for sensor coupling and colour cross-talk.

Returns:

Name Type Description
Tensor

Scalar chief-ray-angle penalty loss.
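
For example, with chief_ray_angle_max set to 35 degrees (an illustrative value), any ray arriving at the sensor whose direction cosine d_z falls below cos(35°) ≈ 0.819 contributes to the chief-ray-angle penalty.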

Source code in deeplens/optics/geolens_pkg/optim.py
def loss_ray_angle(self):
    """Penalize large chief ray angles and low obliquity factors.

    Ensures that rays arrive at the sensor within acceptable incidence
    angles, which is critical for sensor coupling and colour cross-talk.

    Returns:
        Tensor: Scalar chief-ray-angle penalty loss.
    """
    max_angle_deg = self.chief_ray_angle_max
    obliq_min = self.obliq_min

    # Loss on chief ray angle
    ray = self.sample_ring_arm_rays(num_ring=8, num_arm=8, spp=SPP_CALC, scale_pupil=0.2)
    ray = self.trace2sensor(ray)
    cos_cra = ray.d[..., 2]
    cos_cra_ref = float(np.cos(np.deg2rad(max_angle_deg)))
    if (cos_cra < cos_cra_ref).any():
        loss_cra = - cos_cra[cos_cra < cos_cra_ref].mean()
    else:
        loss_cra = torch.tensor(0.0, device=self.device)

    # Loss on accumulated oblique term
    ray = self.sample_ring_arm_rays(num_ring=8, num_arm=8, spp=SPP_CALC, scale_pupil=1.0)
    ray = self.trace2sensor(ray)
    obliq = ray.obliq.squeeze(-1)
    if (obliq < obliq_min).any():
        loss_obliq = - obliq[obliq < obliq_min].mean()
    else:
        loss_obliq = torch.tensor(0.0, device=self.device)

    return loss_cra + loss_obliq

loss_mat

loss_mat()

Penalize material parameters outside manufacturable ranges.

Constrains refractive index n to [1.5, 1.9] and Abbe number V to [30, 70] for each non-air surface material.

Returns:

Name Type Description
Tensor

Scalar material penalty loss.
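
As a worked example, a glass with refractive index n = 1.95 exceeds n_max = 1.9 and contributes (1.95 - 1.9) / (1.9 - 1.5) = 0.125 to the penalty, while materials inside both ranges contribute nothing.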

Source code in deeplens/optics/geolens_pkg/optim.py
def loss_mat(self):
    """Penalize material parameters outside manufacturable ranges.

    Constrains refractive index *n* to [1.5, 1.9] and Abbe number *V* to
    [30, 70] for each non-air surface material.

    Returns:
        Tensor: Scalar material penalty loss.
    """
    n_max = 1.9
    n_min = 1.5
    V_max = 70
    V_min = 30
    loss_mat = torch.tensor(0.0, device=self.device)
    for i in range(len(self.surfaces)):
        if self.surfaces[i].mat1.name != "air":
            if self.surfaces[i].mat1.n > n_max:
                loss_mat += (self.surfaces[i].mat1.n - n_max) / (n_max - n_min)
            if self.surfaces[i].mat1.n < n_min:
                loss_mat += (n_min - self.surfaces[i].mat1.n) / (n_max - n_min)
            if self.surfaces[i].mat1.V > V_max:
                loss_mat += (self.surfaces[i].mat1.V - V_max) / (V_max - V_min)
            if self.surfaces[i].mat1.V < V_min:
                loss_mat += (V_min - self.surfaces[i].mat1.V) / (V_max - V_min)

    return loss_mat

loss_rms

loss_rms(num_grid=GEO_GRID, depth=DEPTH, num_rays=SPP_PSF, sample_more_off_axis=False)

Loss function to compute RGB spot error RMS.

Parameters:

Name Type Description Default
num_grid int

Number of grid points. Defaults to GEO_GRID.

GEO_GRID
depth float

Depth of the object plane. Defaults to DEPTH.

DEPTH
num_rays int

Number of rays. Defaults to SPP_PSF.

SPP_PSF
sample_more_off_axis bool

Whether to sample more off-axis rays. Defaults to False.

False

Returns:

Name Type Description
avg_rms_error Tensor

RMS error averaged over wavelengths and grid points.
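
A minimal training-step sketch, assuming lens is an existing GeoLens instance and optimizer is an already-constructed torch optimizer over the lens parameters (e.g. from get_optimizer(), as in optimize() below):

optimizer.zero_grad()
loss = lens.loss_rms(sample_more_off_axis=True).mean()
loss.backward()
optimizer.step()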

Source code in deeplens/optics/geolens_pkg/optim.py
def loss_rms(
    self,
    num_grid=GEO_GRID,
    depth=DEPTH,
    num_rays=SPP_PSF,
    sample_more_off_axis=False,
):
    """Loss function to compute RGB spot error RMS.

    Args:
        num_grid (int, optional): Number of grid points. Defaults to GEO_GRID.
        depth (float, optional): Depth of the object plane. Defaults to DEPTH.
        num_rays (int, optional): Number of rays. Defaults to SPP_PSF.
        sample_more_off_axis (bool, optional): Whether to sample more off-axis rays. Defaults to False.

    Returns:
        avg_rms_error (torch.Tensor): RMS error averaged over wavelengths and grid points.
    """
    all_rms_errors = []
    for i, wvln in enumerate([WAVE_RGB[1], WAVE_RGB[0], WAVE_RGB[2]]):
        ray = self.sample_grid_rays(
            depth=depth,
            num_grid=num_grid,
            num_rays=num_rays,
            wvln=wvln,
            sample_more_off_axis=sample_more_off_axis,
        )

        # Calculate reference center, shape of (..., 2)
        if i == 0:
            with torch.no_grad():
                ray_center_green = -self.psf_center(points_obj=ray.o[:, :, 0, :], method="pinhole")

        ray = self.trace2sensor(ray)

        # # Green light centroid for reference
        # if i == 0:
        #     with torch.no_grad():
        #         ray_center_green = ray.centroid()

        # Calculate RMS error with reference center
        rms_error = ray.rms_error(center_ref=ray_center_green)
        all_rms_errors.append(rms_error)

    # Calculate average RMS error
    avg_rms_error = torch.stack(all_rms_errors).mean(dim=0)
    return avg_rms_error

sample_ring_arm_rays

sample_ring_arm_rays(num_ring=8, num_arm=8, spp=2048, depth=DEPTH, wvln=DEFAULT_WAVE, scale_pupil=1.0, sample_more_off_axis=True)

Sample rays from object space using a ring-arm pattern.

This method distributes sampling points (origins of ray bundles) on a polar grid in the object plane, defined by the field of view. This is useful for capturing lens performance across the full field. The points include the center and num_ring rings with num_arm points on each.

Parameters:

Name Type Description Default
num_ring int

Number of rings to sample in the field of view.

8
num_arm int

Number of arms (spokes) to sample for each ring.

8
spp int

Total number of rays to be sampled, distributed among field points.

2048
depth float

Depth of the object plane.

DEPTH
wvln float

Wavelength of the rays.

DEFAULT_WAVE
scale_pupil float

Scale factor for the pupil size.

1.0
sample_more_off_axis bool

Whether to concentrate field samples toward the edge of the field of view. Defaults to True.

True

Returns:

Name Type Description
Ray

A Ray object containing the sampled rays.
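
A small sketch of sampling and tracing such a ray bundle, assuming lens is an existing GeoLens instance; the sampling numbers are illustrative:

ray = lens.sample_ring_arm_rays(num_ring=4, num_arm=8, spp=256, scale_pupil=1.0)
ray = lens.trace2sensor(ray)
cos_incidence = ray.d[..., 2]  # direction cosine at the sensor, as used by loss_ray_angle()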

Source code in deeplens/optics/geolens_pkg/optim.py
def sample_ring_arm_rays(self, num_ring=8, num_arm=8, spp=2048, depth=DEPTH, wvln=DEFAULT_WAVE, scale_pupil=1.0, sample_more_off_axis=True):
    """Sample rays from object space using a ring-arm pattern.

    This method distributes sampling points (origins of ray bundles) on a polar grid in the object plane,
    defined by field of view. This is useful for capturing lens performance across the full field.
    The points include the center and `num_ring` rings with `num_arm` points on each.

    Args:
        num_ring (int): Number of rings to sample in the field of view.
        num_arm (int): Number of arms (spokes) to sample for each ring.
        spp (int): Total number of rays to be sampled, distributed among field points.
        depth (float): Depth of the object plane.
        wvln (float): Wavelength of the rays.
        scale_pupil (float): Scale factor for the pupil size.
        sample_more_off_axis (bool): Whether to concentrate field samples toward the edge of the field of view. Defaults to True.

    Returns:
        Ray: A Ray object containing the sampled rays.
    """
    # Create points on rings and arms
    max_fov_rad = self.rfov
    if sample_more_off_axis:
        # Use beta distribution to sample more points near the edge (close to 1.0)
        # Beta(0.5, 0.5) gives more samples at 0 and 1, Beta(0.5, 0.3) gives more samples near 1.0
        beta_values = torch.linspace(0.0, 1.0, num_ring, device=self.device)
        # Apply beta transformation to concentrate samples near 1.0
        beta_transformed = beta_values ** 0.5  # Equivalent to Beta(0.5, 1.0) distribution
        ring_fovs = max_fov_rad * beta_transformed

        # Use square root to sample more points near the edge
        # ring_fovs = max_fov_rad * torch.sqrt(torch.linspace(0.0, 1.0, num_ring, device=self.device))
    else:
        ring_fovs = max_fov_rad * torch.linspace(0.0, 1.0, num_ring, device=self.device)

    arm_angles = torch.linspace(0.0, 2 * torch.pi, num_arm + 1, device=self.device)[:-1]
    ring_grid, arm_grid = torch.meshgrid(ring_fovs, arm_angles, indexing="ij")
    x = depth * torch.tan(ring_grid) * torch.cos(arm_grid)
    y = depth * torch.tan(ring_grid) * torch.sin(arm_grid)        
    z = torch.full_like(x, depth)
    points = torch.stack([x, y, z], dim=-1)  # shape: [num_ring, num_arm, 3]

    # Sample rays
    rays = self.sample_from_points(points=points, num_rays=spp, wvln=wvln, scale_pupil=scale_pupil)
    return rays

optimize

optimize(lrs=[0.001, 0.0001, 0.1, 0.0001], decay=0.01, iterations=5000, test_per_iter=100, centroid=False, optim_mat=False, shape_control=True, result_dir=None)

Optimise the lens by minimising RGB RMS spot errors.

Runs a curriculum-learning training loop with Adam optimiser and cosine annealing. Periodically evaluates the lens, saves intermediate results, and optionally corrects surface shapes.

Parameters:

Name Type Description Default
lrs list

Learning rates for [d, c, k, a] parameter groups. Defaults to [1e-3, 1e-4, 1e-1, 1e-4].

[0.001, 0.0001, 0.1, 0.0001]
decay float

Decay factor for higher-order aspheric coefficients. Defaults to 0.01.

0.01
iterations int

Total training iterations. Defaults to 5000.

5000
test_per_iter int

Evaluate and save every N iterations. Defaults to 100.

100
centroid bool

If True, use chief-ray centroid as PSF centre reference; otherwise use pinhole model. Defaults to False.

False
optim_mat bool

If True, include material parameters (n, V) in optimisation. Defaults to False.

False
shape_control bool

If True, call correct_shape() at each evaluation step. Defaults to True.

True
result_dir str

Directory to save results. If None, auto-generates a timestamped directory. Defaults to None.

None
Note

Debug hints:

1. Slowly optimise with a small learning rate.
2. FoV and thickness should match well.
3. Keep parameter ranges reasonable.
4. Higher aspheric order is better but more sensitive.
5. More iterations with larger ray sampling improves convergence.
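
A minimal call sketch; the lens file and result directory are placeholders, and the import path follows the class path referenced on this page:

from deeplens.optics.geolens import GeoLens

lens = GeoLens(filename="./lenses/starting_point.json")
lens.optimize(iterations=1000, test_per_iter=100, optim_mat=False, result_dir="./results/demo")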

Source code in deeplens/optics/geolens_pkg/optim.py
def optimize(
    self,
    lrs=[1e-3, 1e-4, 1e-1, 1e-4],
    decay=0.01,
    iterations=5000,
    test_per_iter=100,
    centroid=False,
    optim_mat=False,
    shape_control=True,
    result_dir=None,
):
    """Optimise the lens by minimising RGB RMS spot errors.

    Runs a curriculum-learning training loop with Adam optimiser and cosine
    annealing. Periodically evaluates the lens, saves intermediate results,
    and optionally corrects surface shapes.

    Args:
        lrs (list, optional): Learning rates for [d, c, k, a] parameter groups.
            Defaults to [1e-3, 1e-4, 1e-1, 1e-4].
        decay (float, optional): Decay factor for higher-order aspheric coefficients.
            Defaults to 0.01.
        iterations (int, optional): Total training iterations. Defaults to 5000.
        test_per_iter (int, optional): Evaluate and save every N iterations.
            Defaults to 100.
        centroid (bool, optional): If True, use chief-ray centroid as PSF centre
            reference; otherwise use pinhole model. Defaults to False.
        optim_mat (bool, optional): If True, include material parameters (n, V)
            in optimisation. Defaults to False.
        shape_control (bool, optional): If True, call ``correct_shape()`` at each
            evaluation step. Defaults to True.
        result_dir (str, optional): Directory to save results. If None,
            auto-generates a timestamped directory. Defaults to None.

    Note:
        Debug hints:
            1. Slowly optimise with small learning rate.
            2. FoV and thickness should match well.
            3. Keep parameter ranges reasonable.
            4. Higher aspheric order is better but more sensitive.
            5. More iterations with larger ray sampling improves convergence.
    """
    # Experiment settings
    depth = DEPTH
    num_ring = 32
    num_arm = 8
    spp = 2048

    # Result directory and logger
    if result_dir is None:
        result_dir = f"./results/{datetime.now().strftime('%m%d-%H%M%S')}-DesignLens"

    os.makedirs(result_dir, exist_ok=True)
    if not logging.getLogger().hasHandlers():
        logger = logging.getLogger()
        logger.setLevel("DEBUG")
        fmt = logging.Formatter("%(asctime)s:%(levelname)s:%(message)s", "%Y-%m-%d %H:%M:%S")
        sh = logging.StreamHandler()
        sh.setFormatter(fmt)
        sh.setLevel("INFO")
        fh = logging.FileHandler(f"{result_dir}/output.log")
        fh.setFormatter(fmt)
        fh.setLevel("INFO")
        logger.addHandler(sh)
        logger.addHandler(fh)
    logging.info(f"lr:{lrs}, iterations:{iterations}, num_ring:{num_ring}, num_arm:{num_arm}, rays_per_fov:{spp}.")
    logging.info("If Out-of-Memory, try to reduce num_ring, num_arm, and rays_per_fov.")

    # Optimizer and scheduler
    optimizer = self.get_optimizer(lrs, decay=decay, optim_mat=optim_mat)
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=iterations)

    # Training loop
    pbar = tqdm(
        total=iterations + 1,
        desc="Progress",
        postfix={"loss_rms": 0, "loss_focus": 0},
    )
    for i in range(iterations + 1):
        # ===> Evaluate the lens
        if i % test_per_iter == 0:
            with torch.no_grad():
                if shape_control and i > 0:
                    self.correct_shape()
                    # self.refocus()

                self.write_lens_json(f"{result_dir}/iter{i}.json")
                self.analysis(f"{result_dir}/iter{i}")

                # Sample rays
                self.calc_pupil()
                rays_backup = []
                for wv in WAVE_RGB:
                    ray = self.sample_ring_arm_rays(num_ring=num_ring, num_arm=num_arm, spp=spp, depth=depth, wvln=wv, scale_pupil=1.05, sample_more_off_axis=False)
                    rays_backup.append(ray)

                # Calculate ray centers
                if centroid:
                    center_ref = -self.psf_center(points_obj=ray.o[:, :, 0, :], method="chief_ray")
                    center_ref = center_ref.unsqueeze(-2).repeat(1, 1, spp, 1)
                else:
                    center_ref = -self.psf_center(points_obj=ray.o[:, :, 0, :], method="pinhole")
                    center_ref = center_ref.unsqueeze(-2).repeat(1, 1, spp, 1)

        # ===> Optimize lens by minimizing RMS
        loss_rms_ls = []
        for wv_idx, wv in enumerate(WAVE_RGB):
            # Ray tracing to sensor, [num_grid, num_grid, num_rays, 3]
            ray = rays_backup[wv_idx].clone()
            ray = self.trace2sensor(ray)

            # Ray error to center and valid mask
            ray_xy = ray.o[..., :2]
            ray_valid = ray.is_valid
            ray_err = ray_xy - center_ref

            # Weight mask, shape of [num_grid, num_grid]
            if wv_idx == 0:
                with torch.no_grad():
                    weight_mask = ((ray_err**2).sum(-1) * ray_valid).sum(-1)
                    weight_mask /= ray_valid.sum(-1) + EPSILON
                    weight_mask /= weight_mask.mean()

            # Loss on RMS error
            l_rms = ((ray_err**2).sum(-1) * ray_valid).sum(-1)
            l_rms /= ray_valid.sum(-1) + EPSILON
            l_rms = (l_rms + EPSILON).sqrt()

            # Weighted loss
            l_rms_weighted = (l_rms * weight_mask).sum()
            l_rms_weighted /= weight_mask.sum() + EPSILON
            loss_rms_ls.append(l_rms_weighted)

        # RMS loss for all wavelengths
        loss_rms = sum(loss_rms_ls) / len(loss_rms_ls)

        # Total loss
        w_focus = 1.0
        loss_focus = self.loss_infocus()

        w_reg = 0.05
        loss_reg, loss_dict = self.loss_reg()

        L_total = loss_rms + w_focus * loss_focus + w_reg * loss_reg

        # Back-propagation
        optimizer.zero_grad()
        L_total.backward()
        optimizer.step()
        scheduler.step()

        pbar.set_postfix(loss_rms=loss_rms.item(), loss_focus=loss_focus.item(), **loss_dict)
        pbar.update(1)

    pbar.close()

deeplens.optics.geolens_pkg.io.GeoLensIO

Mixin providing file I/O for GeoLens.

Supports reading and writing lens prescriptions in three formats:

  • JSON (primary): human-readable, supports parenthesised optimisable parameters, e.g. "(d)": 5.0.
  • Zemax .zmx: industry-standard sequential lens file.
  • CODE V .seq: CODE V sequential format.

This class is not instantiated directly; it is mixed into :class:~deeplens.optics.geolens.GeoLens.
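
A minimal conversion sketch, assuming lens is an existing GeoLens instance; all file names are placeholders, and each reader returns self so calls can be chained:

lens.read_lens_zmx("design.zmx")
lens.write_lens_json("design.json")  # keys such as "(d)" mark optimisable parameters
lens.write_lens_seq("design.seq")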

read_lens_zmx

read_lens_zmx(filename='./test.zmx')

Load the lens from a Zemax .zmx sequential lens file.

Parses STANDARD and EVENASPH surface types, glass materials, field definitions (YFLN), and entrance pupil settings (ENPD/FLOA).

Parameters:

Name Type Description Default
filename str

Path to the .zmx file. Supports both UTF-8 and UTF-16 encoded files. Defaults to './test.zmx'.

'./test.zmx'

Returns:

Name Type Description
GeoLens

self, for method chaining.

Source code in deeplens/optics/geolens_pkg/io.py
def read_lens_zmx(self, filename="./test.zmx"):
    """Load the lens from a Zemax .zmx sequential lens file.

    Parses STANDARD and EVENASPH surface types, glass materials, field
    definitions (YFLN), and entrance pupil settings (ENPD/FLOA).

    Args:
        filename (str, optional): Path to the .zmx file. Supports both
            UTF-8 and UTF-16 encoded files. Defaults to './test.zmx'.

    Returns:
        GeoLens: ``self``, for method chaining.
    """
    # Read .zmx file
    try:
        with open(filename, "r", encoding="utf-8") as file:
            lines = file.readlines()
    except UnicodeDecodeError:
        with open(filename, "r", encoding="utf-16") as file:
            lines = file.readlines()

    # Iterate through the lines and extract SURF dict
    surfs_dict = {}
    current_surf = None
    for line in lines:
        # Strip leading/trailing whitespace for consistent parsing
        stripped_line = line.strip()

        if stripped_line.startswith("SURF"):
            current_surf = int(stripped_line.split()[1])
            surfs_dict[current_surf] = {}

        elif current_surf is not None and stripped_line != "":
            if len(stripped_line.split(maxsplit=1)) == 1:
                continue
            else:
                key, value = stripped_line.split(maxsplit=1)
                if key == "PARM":
                    new_key = "PARM" + value.split()[0]
                    new_value = value.split()[1]
                    surfs_dict[current_surf][new_key] = new_value
                else:
                    surfs_dict[current_surf][key] = value

        elif stripped_line.startswith("FLOA") or stripped_line.startswith("ENPD"):
            if stripped_line.startswith("FLOA"):
                self.float_enpd = True
                self.enpd = None
            else:
                self.float_enpd = False
                self.enpd = float(stripped_line.split()[1])

        elif stripped_line.startswith("YFLN"):
            # Parse field of view from YFLN line (field coordinates in degrees)
            # YFLN format: YFLN 0.0 <0.707*rfov_deg> <0.99*rfov_deg>
            parts = stripped_line.split()
            if len(parts) > 1:
                field_values = [abs(float(x)) for x in parts[1:] if float(x) != 0.0]
                if field_values:
                    # The largest field value is typically 0.99 * rfov_deg
                    max_field_deg = max(field_values) / 0.99
                    self.rfov = (
                        max_field_deg * math.pi / 180.0
                    )  # Convert to radians

    self.float_foclen = False
    self.float_rfov = False
    # Set default rfov if not parsed from file
    if not hasattr(self, "rfov"):
        self.rfov = None

    # Read the extracted data from each SURF
    self.surfaces = []
    d = 0.0
    mat1_name = "air"
    for surf_idx, surf_dict in surfs_dict.items():
        if surf_idx > 0 and surf_idx < current_surf:
            # Lens surface parameters
            if "GLAS" in surf_dict:
                if surf_dict["GLAS"].split()[0] == "___BLANK":
                    mat2_name = f"{surf_dict['GLAS'].split()[3]}/{surf_dict['GLAS'].split()[4]}"
                else:
                    mat2_name = surf_dict["GLAS"].split()[0].lower()
            else:
                mat2_name = "air"

            surf_r = (
                float(surf_dict["DIAM"].split()[0]) if "DIAM" in surf_dict else 1.0
            )
            surf_c = (
                float(surf_dict["CURV"].split()[0]) if "CURV" in surf_dict else 0.0
            )
            surf_d_next = (
                float(surf_dict["DISZ"].split()[0]) if "DISZ" in surf_dict else 0.0
            )
            surf_conic = float(surf_dict.get("CONI", 0.0))
            surf_param2 = float(surf_dict.get("PARM2", 0.0))
            surf_param3 = float(surf_dict.get("PARM3", 0.0))
            surf_param4 = float(surf_dict.get("PARM4", 0.0))
            surf_param5 = float(surf_dict.get("PARM5", 0.0))
            surf_param6 = float(surf_dict.get("PARM6", 0.0))
            surf_param7 = float(surf_dict.get("PARM7", 0.0))
            surf_param8 = float(surf_dict.get("PARM8", 0.0))

            # Create surface object
            if surf_dict["TYPE"] == "STANDARD":
                if mat2_name == "air" and mat1_name == "air":
                    # Aperture
                    s = Aperture(r=surf_r, d=d)
                else:
                    # Spherical surface
                    s = Spheric(c=surf_c, r=surf_r, d=d, mat2=mat2_name)

            elif surf_dict["TYPE"] == "EVENASPH":
                # Aspherical surface
                s = Aspheric(
                    c=surf_c,
                    r=surf_r,
                    d=d,
                    ai=[
                        surf_param2,
                        surf_param3,
                        surf_param4,
                        surf_param5,
                        surf_param6,
                        surf_param7,
                        surf_param8,
                    ],
                    k=surf_conic,
                    mat2=mat2_name,
                )

            else:
                print(f"Surface type {surf_dict['TYPE']} not implemented.")
                continue

            self.surfaces.append(s)
            d += surf_d_next
            mat1_name = mat2_name

        elif surf_idx == current_surf:
            # Image sensor
            self.r_sensor = float(surf_dict["DIAM"].split()[0])

        else:
            pass

    self.d_sensor = torch.tensor(d)
    return self

write_lens_zmx

write_lens_zmx(filename='./test.zmx')

Write the lens to a Zemax .zmx sequential lens file.

Exports surfaces (STANDARD or EVENASPH), materials, field definitions, and entrance pupil settings in Zemax OpticStudio format.

Parameters:

Name Type Description Default
filename str

Output file path. Defaults to './test.zmx'.

'./test.zmx'
Source code in deeplens/optics/geolens_pkg/io.py
def write_lens_zmx(self, filename="./test.zmx"):
    """Write the lens to a Zemax .zmx sequential lens file.

    Exports surfaces (STANDARD or EVENASPH), materials, field definitions,
    and entrance pupil settings in Zemax OpticStudio format.

    Args:
        filename (str, optional): Output file path. Defaults to './test.zmx'.
    """
    lens_zmx_str = ""
    if self.float_enpd:
        enpd_str = "FLOA"
    else:
        enpd_str = f"ENPD {self.enpd}"
    # Head string
    head_str = f"""VERS 190513 80 123457 L123457
MODE SEQ
NAME 
PFIL 0 0 0
LANG 0
UNIT MM X W X CM MR CPMM
{enpd_str}
ENVD 2.0E+1 1 0
GFAC 0 0
GCAT OSAKAGASCHEMICAL MISC
XFLN 0. 0. 0.
YFLN 0.0 {0.707 * self.rfov * 57.3} {0.99 * self.rfov * 57.3}
WAVL 0.4861327 0.5875618 0.6562725
RAIM 0 0 1 1 0 0 0 0 0
PUSH 0 0 0 0 0 0
SDMA 0 1 0
FTYP 0 0 3 3 0 0 0
ROPD 2
PICB 1
PWAV 2
POLS 1 0 1 0 0 1 0
GLRS 1 0
GSTD 0 100.000 100.000 100.000 100.000 100.000 100.000 0 1 1 0 0 1 1 1 1 1 1
NSCD 100 500 0 1.0E-3 5 1.0E-6 0 0 0 0 0 0 1000000 0 2
COFN QF "COATING.DAT" "SCATTER_PROFILE.DAT" "ABG_DATA.DAT" "PROFILE.GRD"
COFN COATING.DAT SCATTER_PROFILE.DAT ABG_DATA.DAT PROFILE.GRD
SURF 0
TYPE STANDARD
CURV 0.0
DISZ INFINITY
"""
    lens_zmx_str += head_str

    # Surface string
    for i, s in enumerate(self.surfaces):
        d_next = (
            self.surfaces[i + 1].d - self.surfaces[i].d
            if i < len(self.surfaces) - 1
            else self.d_sensor - self.surfaces[i].d
        )
        surf_str = s.zmx_str(surf_idx=i + 1, d_next=d_next)
        lens_zmx_str += surf_str

    # Sensor string
    sensor_str = f"""SURF {i + 2}
TYPE STANDARD
CURV 0.
DISZ 0.0
DIAM {self.r_sensor}
"""
    lens_zmx_str += sensor_str

    # Write lens zmx string into file
    with open(filename, "w") as f:
        f.writelines(lens_zmx_str)
        f.close()
        print(f"Lens written to {filename}")

read_lens_seq

read_lens_seq(filename='./test.seq')

Load the lens from a CODE V .seq sequential file.

Parses standard and aspheric surfaces (with conic and polynomial coefficients A–I), entrance pupil diameter (EPD), field angles (YAN), aperture stop (STO), and image surface (SI).

Parameters:

Name Type Description Default
filename str

Path to the .seq file. Supports both UTF-8 and Latin-1 encoded files. Defaults to './test.seq'.

'./test.seq'

Returns:

Name Type Description
GeoLens

self, for method chaining.

Source code in deeplens/optics/geolens_pkg/io.py
def read_lens_seq(self, filename="./test.seq"):
    """Load the lens from a CODE V .seq sequential file.

    Parses standard and aspheric surfaces (with conic and polynomial
    coefficients A–I), entrance pupil diameter (EPD), field angles (YAN),
    aperture stop (STO), and image surface (SI).

    Args:
        filename (str, optional): Path to the .seq file. Supports both
            UTF-8 and Latin-1 encoded files. Defaults to './test.seq'.

    Returns:
        GeoLens: ``self``, for method chaining.
    """
    print(f"\n{'=' * 60}")
    print(f"Start reading CODE V file: {filename}")
    print(f"{'=' * 60}\n")

    # Read .seq file
    try:
        with open(filename, "r", encoding="utf-8") as file:
            lines = file.readlines()
        print(f"File read successfully (UTF-8)")
    except UnicodeDecodeError:
        try:
            with open(filename, "r", encoding="latin-1") as file:
                lines = file.readlines()
            print(f"File read successfully (Latin-1)")
        except Exception as e:
            print(f"Failed to read file: {e}")
            return self
    print(f"Total lines: {len(lines)}\n")

    # ============ Step 1: Parse file structure ============
    surfaces = []
    current_surface = {}
    surface_index = 0
    global_diameter = None

    print("Beginning to parse surface data...\n")

    for line_num, line in enumerate(lines, 1):
        line = line.strip()

        # Skip irrelevant lines
        if not line or line.startswith(
            (
                "RDM",
                "TITLE",
                "UID",
                "GO",
                "WL",
                "XAN",
                "REF",
                "WTW",
                "INI",
                "WTF",
                "VUY",
                "VLY",
                "DOR",
                "DIM",
                "THC",
            )
        ):
            continue
        # Read entrance pupil diameter
        if line.startswith("EPD"):
            self.enpd = float(line.split()[1])
            self.float_enpd = False
            global_diameter = self.enpd / 2.0
            print(
                f"[Line {line_num}] EPD={self.enpd} -> default radius={global_diameter}"
            )
            continue
        # Read field of view angle
        if line.startswith("YAN"):
            angles = [abs(float(x)) for x in line.split()[1:] if float(x) != 0.0]
            if angles:
                self.hfov = max(angles)
                # Also set rfov in radians for consistency with write functions
                self.rfov = self.hfov * math.pi / 180.0
                print(f"[Line {line_num}] Max field of view={self.hfov} deg")
            continue
        # Object surface
        if line.startswith("SO"):
            parts = line.split()
            thickness = float(parts[2]) if len(parts) > 2 else 1e10

            current_surface = {
                "type": "OBJECT",
                "thickness": thickness,
                "index": surface_index,
            }
            surfaces.append(current_surface)
            print(f"[Line {line_num}] Object surface: T={thickness}")
            surface_index += 1
            current_surface = {}
            continue
        # Standard surface
        if line.startswith("S "):
            # Save the previous surface
            if current_surface:
                surfaces.append(current_surface)
                surface_index += 1

            parts = line.split()
            radius_value = float(parts[1]) if len(parts) > 1 else 0.0
            thickness = float(parts[2]) if len(parts) > 2 else 0.0
            material = parts[3].upper() if len(parts) > 3 else "AIR"

            # Key: compute curvature C = 1/R
            if abs(radius_value) > 1e-10:
                curvature = 1.0 / radius_value
            else:
                curvature = 0.0

            current_surface = {
                "type": "STANDARD",
                "radius": radius_value,
                "curvature": curvature,
                "thickness": thickness,
                "material": material,
                "index": surface_index,
                "diameter": global_diameter,
                "conic": 0.0,
                "asph_coeffs": {},
                "is_stop": False,
            }

            print(
                f"[Line {line_num}] Surface{surface_index}: R={radius_value:.4f} → C={curvature:.6f}, T={thickness}, Mat={material}"
            )
            continue
        # Image surface - do not append yet, wait for CIR
        if line.startswith("SI"):
            if current_surface:
                surfaces.append(current_surface)
                surface_index += 1

            parts = line.split()
            thickness = float(parts[1]) if len(parts) > 1 else 0.0

            current_surface = {
                "type": "IMAGE",
                "thickness": thickness,
                "diameter": None,  # Set to None first, wait for CIR line to update
                "index": surface_index,
            }
            print(f"[Line {line_num}] Image surface")
            continue
        # Handle surface attributes (CIR, STO, ASP, K, A~J, etc.)
        if current_surface:
            if line.startswith("CIR"):
                current_surface["diameter"] = float(
                    line.split()[1].replace(";", "")
                )
                print(f"[Line {line_num}]   → CIR={current_surface['diameter']}")

            elif line.startswith("STO"):
                current_surface["is_stop"] = True
                print(f"[Line {line_num}]   → Aperture stop flag")

            elif line.startswith("ASP"):
                current_surface["type"] = "ASPHERIC"
                print(f"[Line {line_num}]   → Aspheric surface")

            elif line.startswith("K "):
                current_surface["conic"] = float(line.split()[1].replace(";", ""))
                print(f"[Line {line_num}]   → K={current_surface['conic']}")

            # Only extract single-letter coefficients A-J
            elif any(
                line.startswith(p)
                for p in [
                    "A ",
                    "B ",
                    "C ",
                    "D ",
                    "E ",
                    "F ",
                    "G ",
                    "H ",
                    "I ",
                    "J ",
                ]
            ):
                parts = line.replace(";", "").split()
                i = 0
                while i < len(parts) - 1:
                    try:
                        key = parts[i]
                        # Only accept single letters within the range A-J
                        if len(key) == 1 and key in [
                            "A",
                            "B",
                            "C",
                            "D",
                            "E",
                            "F",
                            "G",
                            "H",
                            "I",
                            "J",
                        ]:
                            value = float(parts[i + 1])
                            current_surface["asph_coeffs"][key] = value
                            print(f"[Line {line_num}]   → {key}={value}")
                        i += 2
                    except:
                        i += 1

    # Save the last surface
    if current_surface:
        surfaces.append(current_surface)

    print(f"\nParsing complete, total {len(surfaces)} surfaces\n")

    # ============ Step 2: Create surface objects ============
    print(f"{'=' * 60}")
    print("Start creating surface objects:")
    print(f"{'=' * 60}\n")

    self.surfaces = []
    d = 0.0  # Cumulative distance from the first optical surface to the current surface
    previous_material = "air"

    for surf in surfaces:
        surf_idx = surf["index"]
        surf_type = surf["type"]

        print(f"{'=' * 50}")
        print(f"Processing surface{surf_idx} ({surf_type}), current d={d:.4f}")

        # Handle object surface
        if surf_type == "OBJECT":
            obj_thickness = surf["thickness"]
            if obj_thickness < 1e9:  # Finite object distance
                d += obj_thickness
                print(
                    f"   Object surface thickness={obj_thickness} → accumulated d={d:.4f}"
                )
            else:
                print("   Object surface at infinity")
            previous_material = "air"
            continue

        # Handle image surface
        if surf_type == "IMAGE":
            self.d_sensor = torch.tensor(d)
            # Read diameter from surf dictionary (CIR value)
            self.r_sensor = (
                surf.get("diameter") if surf.get("diameter") is not None else 18.0
            )
            print(
                f"   Image plane position: d_sensor={d:.4f}, r_sensor={self.r_sensor:.4f}"
            )
            break

        # Get surface parameters
        current_material = surf.get("material", "AIR")
        if current_material in ["AIR", "0.0", "", None]:
            current_material = "air"
        else:
            current_material = current_material.lower()

        c = surf.get("curvature", 0.0)
        r = surf.get("diameter", 10.0)
        d_next = surf.get("thickness", 0.0)
        is_stop = surf.get("is_stop", False)

        print(f"   C={c:.6f}, R_aperture={r:.4f}, T={d_next:.4f}")
        print(f"   Material: {previous_material}{current_material}")
        print(f"   is_stop={is_stop}")

        # Create surface object
        try:
            # Case 1: pure aperture (air on both sides + STO flag)
            if is_stop and current_material == "air" and previous_material == "air":
                aperture = Aperture(r=r, d=d)
                self.surfaces.append(aperture)
                print(f"   Created pure aperture: Aperture(r={r:.4f}, d={d:.4f})")

            # Case 2: refractive surface (material change)
            elif current_material != previous_material:
                if surf_type == "STANDARD":
                    s = Spheric(c=c, r=r, d=d, mat2=current_material)
                    self.surfaces.append(s)
                    status = " (stop surface)" if is_stop else ""
                    print(
                        f"   Created spherical surface{status}: Spheric(c={c:.6f}, r={r:.4f}, d={d:.4f}, mat2='{current_material}')"
                    )

                elif surf_type == "ASPHERIC":
                    k = surf.get("conic", 0.0)
                    asph_coeffs = surf.get("asph_coeffs", {})

                    # CODE V aspheric coefficient mapping (shift forward by one position):
                    # A → ai[1] (2nd term, ρ²)
                    # B → ai[2] (4th term, ρ⁴)
                    # C → ai[3] (6th term, ρ⁶)
                    # D → ai[4] (8th term, ρ⁸)
                    # E → ai[5] (10th term, ρ¹⁰)
                    # F → ai[6] (12th term, ρ¹²)
                    # G → ai[7] (14th term, ρ¹⁴)
                    # H → ai[8] (16th term, ρ¹⁶)
                    # I → ai[9] (18th term, ρ¹⁸)

                    # Initialize ai array (10 elements)
                    ai = [0.0] * 10
                    ai[0] = 0.0  # ρ⁰ term (unused)
                    ai[1] = asph_coeffs.get("A", 0.0)  # ρ²
                    ai[2] = asph_coeffs.get("B", 0.0)  # ρ⁴
                    ai[3] = asph_coeffs.get("C", 0.0)  # ρ⁶
                    ai[4] = asph_coeffs.get("D", 0.0)  # ρ⁸
                    ai[5] = asph_coeffs.get("E", 0.0)  # ρ¹⁰
                    ai[6] = asph_coeffs.get("F", 0.0)  # ρ¹²
                    ai[7] = asph_coeffs.get("G", 0.0)  # ρ¹⁴
                    ai[8] = asph_coeffs.get("H", 0.0)  # ρ¹⁶
                    ai[9] = asph_coeffs.get("I", 0.0)  # ρ¹⁸

                    s = Aspheric(c=c, r=r, d=d, ai=ai, k=k, mat2=current_material)
                    self.surfaces.append(s)
                    status = " (stop surface)" if is_stop else ""
                    print(
                        f"   Created aspheric surface{status}: Aspheric(c={c:.6f}, r={r:.4f}, d={d:.4f}, k={k}, mat2='{current_material}')"
                    )
                    if any(
                        ai[1:]
                    ):  # If there are non-zero higher-order terms (starting from ai[1])
                        print(
                            f"      Aspheric coefficients: A={ai[1]:.2e}, B={ai[2]:.2e}, C={ai[3]:.2e}, D={ai[4]:.2e}"
                        )

            else:
                print(f"   Skipped (same material on both sides and no stop flag)")

        except Exception as e:
            print(f"   Failed to create surface: {e}")
            import traceback

            traceback.print_exc()

        # Key: accumulate distance at the end of the loop
        d += d_next
        print(f"   After accumulation: d={d:.4f}")
        previous_material = current_material

    print(f"\n{'=' * 60}")
    print(f"   Done! Created {len(self.surfaces)} objects")
    print(f"   d_sensor={self.d_sensor:.4f}")
    print(f"   r_sensor={self.r_sensor:.4f}")
    print(f"   hfov={self.hfov:.4f}°")
    print(f"{'=' * 60}\n")

    return self

write_lens_seq

write_lens_seq(filename='./test.seq')

Write the lens to a CODE V .seq sequential file.

Exports surfaces, materials, field definitions, and entrance pupil settings in CODE V format.

Parameters:

Name Type Description Default
filename str

Output file path. Defaults to './test.seq'.

'./test.seq'

Returns:

Name Type Description
GeoLens

self, for method chaining.

Source code in deeplens/optics/geolens_pkg/io.py
def write_lens_seq(self, filename="./test.seq"):
    """Write the lens to a CODE V .seq sequential file.

    Exports surfaces, materials, field definitions, and entrance pupil
    settings in CODE V format.

    Args:
        filename (str, optional): Output file path. Defaults to './test.seq'.

    Returns:
        GeoLens: ``self``, for method chaining.
    """

    import datetime

    current_date = datetime.datetime.now().strftime("%d-%b-%Y")

    head_str = f"""RDM;LEN       "VERSION: 2023.03       LENS VERSION: 89       Creation Date:  {current_date}"
TITLE 'Lens Design'
EPD   {self.enpd}
DIM   M
WL    650.0 550.0 480.0
REF   2
WTW   1 2 1
INI   '   '
XAN   0.0 0.0 0.0
YAN   0.0  {0.707 * self.rfov * 57.3} {0.99 * self.rfov * 57.3}
WTF   1.0 1.0 1.0
VUY   0.0 0.0 0.0
VLY   0.0 0.0 0.0
DOR   1.15 1.05
SO    0.0 0.1e14
"""

    lens_seq_str = head_str
    previous_material = "air"

    for i, surf in enumerate(self.surfaces):
        if i < len(self.surfaces) - 1:
            d_next = self.surfaces[i + 1].d - surf.d
        else:
            d_next = float(self.d_sensor - surf.d)

        current_material = getattr(surf, "mat2", "air")

        if current_material is None or current_material == "air":
            material_str = ""
            material_name = "air"
        elif isinstance(current_material, str):
            material_str = f" {current_material.upper()}"
            material_name = current_material
        else:
            material_name = getattr(current_material, "name", str(current_material))
            material_str = f" {material_name.upper()}"

        is_aperture = surf.__class__.__name__ == "Aperture"

        if is_aperture:
            continue

        is_aspheric = surf.__class__.__name__ == "Aspheric"
        is_stop_surface = getattr(surf, "is_stop", False)

        if is_aspheric:
            if abs(surf.c) > 1e-10:
                radius = 1.0 / surf.c
            else:
                radius = 0.0

            k = surf.k if hasattr(surf, "k") else 0.0
            ai = surf.ai if hasattr(surf, "ai") else [0.0] * 10

            surf_str = f"S     {radius} {d_next}{material_str}\n"
            surf_str += f"  CCY 0; THC 0\n"
            surf_str += f"  CIR {surf.r}\n"
            if is_stop_surface:
                surf_str += f"  STO\n"
            surf_str += f"  ASP\n"
            surf_str += f"  K   {k}\n"

            if len(ai) > 4 and any(ai[1:5]):
                surf_str += f"  A   {ai[1]:.16e}; B {ai[2]:.16e}; C&\n"
                surf_str += f"   {ai[3]:.16e}; D {ai[4]:.16e}\n"

            if len(ai) > 8 and any(ai[5:9]):
                surf_str += f"  E   {ai[5]:.16e}; F {ai[6]:.16e}; G {ai[7]:.16e}; H {ai[8]:.16e}\n"

        else:
            if abs(surf.c) > 1e-10:
                radius = 1.0 / surf.c
            else:
                radius = 0.0

            surf_str = f"S     {radius} {d_next}{material_str}\n"
            surf_str += f"  CCY 0; THC 0\n"

            if is_stop_surface:
                surf_str += f"  STO\n"

            surf_str += f"  CIR {surf.r}\n"

        lens_seq_str += surf_str
        previous_material = material_name

    sensor_str = f"SI    0.0 0.0\n"
    sensor_str += f"  CIR {self.r_sensor}\n"
    lens_seq_str += sensor_str
    lens_seq_str += "GO \n"

    with open(filename, "w") as f:
        f.write(lens_seq_str)

    print(f"Lens written to CODE V file: {filename}")
    return self

deeplens.optics.geolens_pkg.vis.GeoLensVis

Mixin providing 2-D lens layout and ray visualisation for GeoLens.

Generates publication-quality cross-section plots showing lens surfaces and traced ray bundles in either the meridional or sagittal plane.

This class is not instantiated directly; it is mixed into :class:~deeplens.optics.geolens.GeoLens.
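
A minimal usage sketch, assuming lens is an existing GeoLens instance; the output file name is a placeholder:

lens.draw_layout("layout.png", depth=float("inf"), zmx_format=True, multi_plot=False)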

sample_parallel_2D

sample_parallel_2D(fov=0.0, num_rays=7, wvln=DEFAULT_WAVE, plane='meridional', entrance_pupil=True, depth=0.0)

Sample parallel rays (2D) in object space.

Used for (1) drawing the lens layout and (2) 2D geometric-optics calculations, for example refocusing to infinity.

Parameters:

Name Type Description Default
fov float

incident angle (in degree). Defaults to 0.0.

0.0
depth float

sampling depth. Defaults to 0.0.

0.0
num_rays int

ray number. Defaults to 7.

7
wvln float

ray wvln. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
plane str

sampling plane. Defaults to "meridional" (y-z plane).

'meridional'
entrance_pupil bool

whether to use entrance pupil. Defaults to True.

True

Returns:

Name Type Description
ray Ray object

Ray object. Shape [num_rays, 3]
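
A small sketch, assuming lens is an existing GeoLens instance; the 10-degree field angle is illustrative:

ray = lens.sample_parallel_2D(fov=10.0, num_rays=11, plane="meridional")
ray_out, ray_o_record = lens.trace2sensor(ray=ray, record=True)  # the recorded points feed draw_ray_2d()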

Source code in deeplens/optics/geolens_pkg/vis.py
@torch.no_grad()
def sample_parallel_2D(
    self,
    fov=0.0,
    num_rays=7,
    wvln=DEFAULT_WAVE,
    plane="meridional",
    entrance_pupil=True,
    depth=0.0,
):
    """Sample parallel rays (2D) in object space.

    Used for (1) drawing lens setup, (2) 2D geometric optics calculation, for example, refocusing to infinity

    Args:
        fov (float, optional): incident angle (in degree). Defaults to 0.0.
        depth (float, optional): sampling depth. Defaults to 0.0.
        num_rays (int, optional): ray number. Defaults to 7.
        wvln (float, optional): ray wvln. Defaults to DEFAULT_WAVE.
        plane (str, optional): sampling plane. Defaults to "meridional" (y-z plane).
        entrance_pupil (bool, optional): whether to use entrance pupil. Defaults to True.

    Returns:
        ray (Ray object): Ray object. Shape [num_rays, 3]
    """
    # Sample points on the pupil
    if entrance_pupil:
        pupilz, pupilr = self.get_entrance_pupil()
    else:
        pupilz, pupilr = 0, self.surfaces[0].r

    # Sample ray origins, shape [num_rays, 3]
    if plane == "sagittal":
        ray_o = torch.stack(
            (
                torch.linspace(-pupilr, pupilr, num_rays) * 0.99,
                torch.full((num_rays,), 0),
                torch.full((num_rays,), pupilz),
            ),
            axis=-1,
        )
    elif plane == "meridional":
        ray_o = torch.stack(
            (
                torch.full((num_rays,), 0),
                torch.linspace(-pupilr, pupilr, num_rays) * 0.99,
                torch.full((num_rays,), pupilz),
            ),
            axis=-1,
        )
    else:
        raise ValueError(f"Invalid plane: {plane}")

    # Sample ray directions, shape [num_rays, 3]
    if plane == "sagittal":
        ray_d = torch.stack(
            (
                torch.full((num_rays,), float(np.sin(np.deg2rad(fov)))),
                torch.zeros((num_rays,)),
                torch.full((num_rays,), float(np.cos(np.deg2rad(fov)))),
            ),
            axis=-1,
        )
    elif plane == "meridional":
        ray_d = torch.stack(
            (
                torch.zeros((num_rays,)),
                torch.full((num_rays,), float(np.sin(np.deg2rad(fov)))),
                torch.full((num_rays,), float(np.cos(np.deg2rad(fov)))),
            ),
            axis=-1,
        )
    else:
        raise ValueError(f"Invalid plane: {plane}")

    # Form rays and propagate to the target depth
    rays = Ray(ray_o, ray_d, wvln, device=self.device)
    rays.prop_to(depth)
    return rays

sample_point_source_2D

sample_point_source_2D(fov=0.0, depth=DEPTH, num_rays=7, wvln=DEFAULT_WAVE, entrance_pupil=True)

Sample point source rays (2D) in object space.

Used for drawing the lens layout.

Parameters:

Name Type Description Default
fov float

incident angle (in degree). Defaults to 0.0.

0.0
depth float

sampling depth. Defaults to DEPTH.

DEPTH
num_rays int

ray number. Defaults to 7.

7
wvln float

ray wvln. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
entrance_pupil bool

whether to use entrance pupil. Defaults to True.

True

Returns:

Name Type Description
ray Ray object

Ray object. Shape [num_rays, 3]

Source code in deeplens/optics/geolens_pkg/vis.py
@torch.no_grad()
def sample_point_source_2D(
    self,
    fov=0.0,
    depth=DEPTH,
    num_rays=7,
    wvln=DEFAULT_WAVE,
    entrance_pupil=True,
):
    """Sample point source rays (2D) in object space.

    Used for (1) drawing lens setup.

    Args:
        fov (float, optional): incident angle (in degree). Defaults to 0.0.
        depth (float, optional): sampling depth. Defaults to DEPTH.
        num_rays (int, optional): ray number. Defaults to 7.
        wvln (float, optional): ray wvln. Defaults to DEFAULT_WAVE.
        entrance_pupil (bool, optional): whether to use entrance pupil. Defaults to True.

    Returns:
        ray (Ray object): Ray object. Shape [num_rays, 3]
    """
    # Sample point on the object plane
    ray_o = torch.tensor([depth * float(np.tan(np.deg2rad(fov))), 0.0, depth])
    ray_o = ray_o.unsqueeze(0).repeat(num_rays, 1)

    # Sample points (second point) on the pupil
    if entrance_pupil:
        pupilz, pupilr = self.calc_entrance_pupil()
    else:
        pupilz, pupilr = 0, self.surfaces[0].r

    x2 = torch.linspace(-pupilr, pupilr, num_rays) * 0.99
    y2 = torch.zeros_like(x2)
    z2 = torch.full_like(x2, pupilz)
    ray_o2 = torch.stack((x2, y2, z2), axis=1)

    # Form the rays
    ray_d = ray_o2 - ray_o
    ray = Ray(ray_o, ray_d, wvln, device=self.device)

    # Propagate rays to the sampling depth
    ray.prop_to(depth)
    return ray

draw_layout

draw_layout(filename, depth=float('inf'), zmx_format=True, multi_plot=False, lens_title=None, show=False)

Plot 2D lens layout with ray tracing.

Parameters:

Name Type Description Default
filename

Output filename

required
depth

Depth for ray tracing

float('inf')
zmx_format

Whether to use ZMX format

True
multi_plot

Whether to create multiple plots

False
lens_title

Title for the lens plot

None
show

Whether to show the plot

False
Source code in deeplens/optics/geolens_pkg/vis.py
def draw_layout(
    self,
    filename,
    depth=float("inf"),
    zmx_format=True,
    multi_plot=False,
    lens_title=None,
    show=False,
):
    """Plot 2D lens layout with ray tracing.

    Args:
        filename: Output filename
        depth: Depth for ray tracing
        zmx_format: Whether to use ZMX format
        multi_plot: Whether to create multiple plots
        lens_title: Title for the lens plot
        show: Whether to show the plot
    """
    num_rays = 11
    num_views = 3

    # Lens title
    if lens_title is None:
        eff_foclen = int(self.foclen)
        eq_foclen = int(self.eqfl)
        fov_deg = round(2 * self.rfov * 180 / torch.pi, 1)
        sensor_r = round(self.r_sensor, 1)
        sensor_w, sensor_h = self.sensor_size
        sensor_w = round(sensor_w, 1)
        sensor_h = round(sensor_h, 1)

        if self.aper_idx is not None:
            _, pupil_r = self.calc_entrance_pupil()
            fnum = round(eff_foclen / pupil_r / 2, 1)
            lens_title = f"FocLen{eff_foclen}mm - F/{fnum} - FoV{fov_deg}(Equivalent {eq_foclen}mm) - Sensor Diagonal {2 * sensor_r}mm"
        else:
            lens_title = f"FocLen{eff_foclen}mm - FoV{fov_deg}(Equivalent {eq_foclen}mm) - Sensor Diagonal {2 * sensor_r}mm"

    # Draw lens layout
    colors_list = ["#CC0000", "#006600", "#0066CC"]
    rfov_deg = float(np.rad2deg(self.rfov))
    fov_ls = np.linspace(0, rfov_deg * 0.99, num=num_views)

    if not multi_plot:
        ax, fig = self.draw_lens_2d(zmx_format=zmx_format)
        fig.suptitle(lens_title, fontsize=10)
        for i, fov in enumerate(fov_ls):
            # Sample rays, shape (num_rays, 3)
            if depth == float("inf"):
                ray = self.sample_parallel_2D(
                    fov=fov,
                    wvln=WAVE_RGB[2 - i],
                    num_rays=num_rays,
                    depth=-1.0,
                    plane="sagittal",
                )
            else:
                ray = self.sample_point_source_2D(
                    fov=fov,
                    depth=depth,
                    num_rays=num_rays,
                    wvln=WAVE_RGB[2 - i],
                )
                ray.prop_to(-1.0)

            # Trace rays to sensor and plot ray paths
            _, ray_o_record = self.trace2sensor(ray=ray, record=True)
            ax, fig = self.draw_ray_2d(
                ray_o_record, ax=ax, fig=fig, color=colors_list[i]
            )

        ax.axis("off")

    else:
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        fig.suptitle(lens_title, fontsize=10)
        for i, wvln in enumerate(WAVE_RGB):
            ax = axs[i]
            ax, fig = self.draw_lens_2d(ax=ax, fig=fig, zmx_format=zmx_format)
            for fov in fov_ls:
                # Sample rays, shape (num_rays, 3)
                if depth == float("inf"):
                    ray = self.sample_parallel_2D(
                        fov=fov,
                        num_rays=num_rays,
                        wvln=wvln,
                        plane="sagittal",
                    )
                else:
                    ray = self.sample_point_source_2D(
                        fov=fov,
                        depth=depth,
                        num_rays=num_rays,
                        wvln=wvln,
                    )

                # Trace rays to sensor and plot ray paths
                ray_out, ray_o_record = self.trace2sensor(ray=ray, record=True)
                ax, fig = self.draw_ray_2d(
                    ray_o_record, ax=ax, fig=fig, color=colors_list[i]
                )
                ax.axis("off")

    if show:
        fig.show()
    else:
        fig.savefig(filename, format="png", dpi=300)
        plt.close()

draw_lens_2d

draw_lens_2d(ax=None, fig=None, color='k', linestyle='-', zmx_format=False, fix_bound=False)

Draw lens cross-section layout in a 2D plot.

Renders each surface profile, connects lens elements with edge lines, and draws the sensor plane.

Parameters:

Name Type Description Default
ax Axes

Existing axes to draw on. If None, creates a new figure. Defaults to None.

None
fig Figure

Existing figure. Defaults to None.

None
color str

Line colour for lens outlines. Defaults to 'k'.

'k'
linestyle str

Line style. Defaults to '-'.

'-'
zmx_format bool

If True, draw stepped edge connections matching Zemax layout style. Defaults to False.

False
fix_bound bool

If True, use fixed axis limits [-1,7]x[-4,4]. Defaults to False.

False

Returns:

Name Type Description
tuple

(ax, fig) matplotlib axes and figure objects.
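
A small sketch of composing a custom figure, assuming lens is an existing GeoLens instance; the title and file name are placeholders:

ax, fig = lens.draw_lens_2d(zmx_format=True)
ax.set_title("Lens cross-section")
fig.savefig("cross_section.png", dpi=300)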

Source code in deeplens/optics/geolens_pkg/vis.py
def draw_lens_2d(
    self,
    ax=None,
    fig=None,
    color="k",
    linestyle="-",
    zmx_format=False,
    fix_bound=False,
):
    """Draw lens cross-section layout in a 2D plot.

    Renders each surface profile, connects lens elements with edge lines,
    and draws the sensor plane.

    Args:
        ax (matplotlib.axes.Axes, optional): Existing axes to draw on. If None,
            creates a new figure. Defaults to None.
        fig (matplotlib.figure.Figure, optional): Existing figure. Defaults to None.
        color (str, optional): Line colour for lens outlines. Defaults to 'k'.
        linestyle (str, optional): Line style. Defaults to '-'.
        zmx_format (bool, optional): If True, draw stepped edge connections
            matching Zemax layout style. Defaults to False.
        fix_bound (bool, optional): If True, use fixed axis limits [-1,7]x[-4,4].
            Defaults to False.

    Returns:
        tuple: (ax, fig) matplotlib axes and figure objects.
    """
    # If no ax is given, generate a new one.
    if ax is None and fig is None:
        # fig, ax = plt.subplots(figsize=(6, 6))
        fig, ax = plt.subplots()

    # Draw lens surfaces
    for i, s in enumerate(self.surfaces):
        s.draw_widget(ax)

    # Connect two surfaces
    for i in range(len(self.surfaces)):
        if self.surfaces[i].mat2.n > 1.1:
            s_prev = self.surfaces[i]
            s = self.surfaces[i + 1]

            r_prev = float(s_prev.r)
            r = float(s.r)
            sag_prev = s_prev.surface_with_offset(r_prev, 0.0).item()
            sag = s.surface_with_offset(r, 0.0).item()

            if zmx_format:
                if r > r_prev:
                    z = np.array([sag_prev, sag_prev, sag])
                    x = np.array([r_prev, r, r])
                else:
                    z = np.array([sag_prev, sag, sag])
                    x = np.array([r_prev, r, r])
            else:
                z = np.array([sag_prev, sag])
                x = np.array([r_prev, r])

            ax.plot(z, -x, color, linewidth=0.75)
            ax.plot(z, x, color, linewidth=0.75)
            s_prev = s

    # Draw sensor
    ax.plot(
        [self.d_sensor.item(), self.d_sensor.item()],
        [-self.r_sensor, self.r_sensor],
        color,
    )

    # Set figure size
    if fix_bound:
        ax.set_aspect("equal")
        ax.set_xlim(-1, 7)
        ax.set_ylim(-4, 4)
    else:
        ax.set_aspect("equal", adjustable="datalim", anchor="C")
        ax.minorticks_on()
        ax.set_xlim(-0.5, 7.5)
        ax.set_ylim(-4, 4)
        ax.autoscale()

    return ax, fig
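
A minimal usage sketch (the lens file name is illustrative; any GeoLens JSON description works):

from deeplens.optics.geolens import GeoLens
import matplotlib.pyplot as plt

lens = GeoLens(filename="lens.json")  # assumed example lens file
ax, fig = lens.draw_lens_2d(color="k", zmx_format=True)
fig.savefig("lens_layout_2d.png", dpi=300)
plt.close(fig)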

draw_ray_2d

draw_ray_2d(ray_o_record, ax, fig, color='b')

Plot ray paths.

Parameters:

Name Type Description Default
ray_o_record list

list of intersection points.

required
ax Axes

matplotlib axes.

required
fig Figure

matplotlib figure.

required
Source code in deeplens/optics/geolens_pkg/vis.py
def draw_ray_2d(self, ray_o_record, ax, fig, color="b"):
    """Plot ray paths.

    Args:
        ray_o_record (list): list of intersection points.
        ax (matplotlib.axes.Axes): matplotlib axes.
        fig (matplotlib.figure.Figure): matplotlib figure.
    """
    # shape (num_view, num_rays, num_path, 3); last dim holds (x, y, z)
    ray_o_record = torch.stack(ray_o_record, dim=-2).cpu().numpy()
    if ray_o_record.ndim == 3:
        ray_o_record = ray_o_record[None, ...]

    for idx_view in range(ray_o_record.shape[0]):
        for idx_ray in range(ray_o_record.shape[1]):
            ax.plot(
                ray_o_record[idx_view, idx_ray, :, 2],
                ray_o_record[idx_view, idx_ray, :, 0],
                color,
                linewidth=0.8,
            )

            # ax.scatter(
            #     ray_o_record[idx_view, idx_ray, :, 2],
            #     ray_o_record[idx_view, idx_ray, :, 0],
            #     "b",
            #     marker="x",
            # )

    return ax, fig
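
A short sketch of the intended workflow, combining trace2sensor(record=True) with draw_ray_2d (the lens file name and wavelength value in um are illustrative):

from deeplens.optics.geolens import GeoLens

lens = GeoLens(filename="lens.json")          # assumed example lens file
ax, fig = lens.draw_lens_2d()                 # draw the lens outline first
ray = lens.sample_point_source_2D(fov=0.0, depth=-10000.0, num_rays=9, wvln=0.589)
ray_out, ray_o_record = lens.trace2sensor(ray=ray, record=True)
ax, fig = lens.draw_ray_2d(ray_o_record, ax=ax, fig=fig, color="b")
fig.savefig("ray_paths_2d.png", dpi=300)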

draw_layout_3d

draw_layout_3d(filename=None, view_angle=30, show=False)

Draw 3D layout of the lens system. Deprecated: calling this method raises an exception; use draw_lens_3d from the view_3d module instead.

Parameters:

Name Type Description Default
filename str

Path to save the figure. Defaults to None.

None
view_angle int

Viewing angle for the 3D plot

30
show bool

Whether to display the figure

False

Returns:

Type Description

fig, ax: Matplotlib figure and axis objects

Source code in deeplens/optics/geolens_pkg/vis.py
def draw_layout_3d(self, filename=None, view_angle=30, show=False):
    """Draw 3D layout of the lens system.

    Args:
        filename (str, optional): Path to save the figure. Defaults to None.
        view_angle (int): Viewing angle for the 3D plot
        show (bool): Whether to display the figure

    Returns:
        fig, ax: Matplotlib figure and axis objects
    """
    raise Exception(
        "This function is deprecated. Please use the draw_lens_3d function in the view_3d module instead."
    )
    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111, projection="3d")

    # Enable depth sorting for proper occlusion
    ax.set_proj_type(
        "persp"
    )  # Use perspective projection for better depth perception

    # Draw each surface
    for i, surf in enumerate(self.surfaces):
        surf.draw_widget3D(ax)

        # Connect current surface with previous surface if material is not air
        if i > 0 and self.surfaces[i - 1].mat2.get_name() != "air":
            # Get edge points of current and previous surfaces
            theta = np.linspace(0, 2 * np.pi, 256)

            # Current surface edge
            curr_edge_x = surf.r * np.cos(theta)
            curr_edge_y = surf.r * np.sin(theta)
            curr_edge_z = np.array(
                [
                    surf.surface_with_offset(
                        torch.tensor(curr_edge_x[j], device=surf.device),
                        torch.tensor(curr_edge_y[j], device=surf.device),
                    ).item()
                    for j in range(len(theta))
                ]
            )

            # Previous surface edge
            prev_surf = self.surfaces[i - 1]
            prev_edge_x = prev_surf.r * np.cos(theta)
            prev_edge_y = prev_surf.r * np.sin(theta)
            prev_edge_z = np.array(
                [
                    prev_surf.surface_with_offset(
                        torch.tensor(prev_edge_x[j], device=prev_surf.device),
                        torch.tensor(prev_edge_y[j], device=prev_surf.device),
                    ).item()
                    for j in range(len(theta))
                ]
            )

            # Create a cylindrical surface connecting the two edges
            theta_mesh, t_mesh = np.meshgrid(theta, np.array([0, 1]))

            # Interpolate between previous and current surface edges
            x_mesh = (
                prev_edge_x[None, :] * (1 - t_mesh) + curr_edge_x[None, :] * t_mesh
            )
            y_mesh = (
                prev_edge_y[None, :] * (1 - t_mesh) + curr_edge_y[None, :] * t_mesh
            )
            z_mesh = (
                prev_edge_z[None, :] * (1 - t_mesh) + curr_edge_z[None, :] * t_mesh
            )

            # Plot the connecting surface with sort_zpos for proper occlusion
            surf = ax.plot_surface(
                z_mesh,
                x_mesh,
                y_mesh,
                color="lightblue",
                alpha=0.3,
                edgecolor="lightblue",
                linewidth=0.5,
                antialiased=True,
            )
            # Set the zorder based on the mean z position for better occlusion
            surf._sort_zpos = np.mean(z_mesh)

    # Draw sensor as a rectangle
    if hasattr(self, "sensor_size") and hasattr(self, "d_sensor"):
        # Get sensor dimensions
        sensor_width = self.sensor_size[0]
        sensor_height = self.sensor_size[1]
        sensor_z = self.d_sensor.item()

        # Create sensor vertices
        half_width = sensor_width / 2
        half_height = sensor_height / 2

        # Define the corners of the rectangle
        x = np.array(
            [-half_width, half_width, half_width, -half_width, -half_width]
        )
        y = np.array(
            [-half_height, -half_height, half_height, half_height, -half_height]
        )
        z = np.full_like(x, sensor_z)

        # Plot the sensor rectangle
        ax.plot(z, x, y, color="black", linewidth=1.5)

        # Add a semi-transparent surface for the sensor
        sensor_x, sensor_y = np.meshgrid(
            np.linspace(-half_width, half_width, 2),
            np.linspace(-half_height, half_height, 2),
        )
        sensor_z = np.full_like(sensor_x, sensor_z)
        sensor_surf = ax.plot_surface(
            sensor_z,
            sensor_x,
            sensor_y,
            color="gray",
            alpha=0.3,
            edgecolor="black",
            linewidth=0.5,
        )
        # Set the zorder for the sensor
        sensor_surf._sort_zpos = sensor_z.mean()

    # Set axis properties
    ax.set_xlabel("Z")
    ax.set_ylabel("X")
    ax.set_zlabel("Y")
    ax.view_init(elev=20, azim=-view_angle - 90)

    # Make all axes have the same scale (unit step size)
    ax.set_box_aspect([1, 1, 1])
    ax.set_aspect("equal")

    # Enable depth sorting for proper occlusion
    from matplotlib.collections import PathCollection

    for c in ax.collections:
        if isinstance(c, PathCollection):
            c.set_sort_zpos(c.get_offsets()[:, 2].mean())

    plt.tight_layout()

    if filename:
        fig.savefig(f"{filename}.png", format="png", dpi=300)

    if show:
        plt.show()
    else:
        plt.close()

    return fig, ax

create_barrier

create_barrier(filename, barrier_thickness=1.0, ring_height=0.5, ring_size=1.0)

Create a 3D barrier for the lens system.

Parameters:

Name Type Description Default
filename

Path to save the figure

required
barrier_thickness

Thickness of the barrier

1.0
ring_height

Height of the annular ring

0.5
ring_size

Size of the annular ring

1.0
Source code in deeplens/optics/geolens_pkg/vis.py
def create_barrier(
    self, filename, barrier_thickness=1.0, ring_height=0.5, ring_size=1.0
):
    """Create a 3D barrier for the lens system.

    Args:
        filename: Path to save the figure
        barrier_thickness: Thickness of the barrier
        ring_height: Height of the annular ring
        ring_size: Size of the annular ring
    """
    barriers = []
    rings = []

    # Create barriers
    barrier_z = 0.0
    barrier_r = 0.0
    barrier_length = 0.0
    for i in range(len(self.surfaces)):
        barrier_r = max(self.surfaces[i].r, barrier_r)

        if self.surfaces[i].mat2.get_name() != "air":
            # Update the barrier radius
            # barrier_r = max(geolens.surfaces[i].r, barrier_r)
            pass
        else:
            # Extend the barrier till middle of the air space to the next surface
            max_curr_surf_d = self.surfaces[i].d.item() + max(
                self.surfaces[i].surface_sag(0.0, self.surfaces[i].r), 0.0
            )
            if i < len(self.surfaces) - 1:
                min_next_surf_d = self.surfaces[i + 1].d.item() + min(
                    self.surfaces[i + 1].surface_sag(0.0, self.surfaces[i + 1].r),
                    0.0,
                )
                extra_space = (min_next_surf_d - max_curr_surf_d) / 2
            else:
                min_next_surf_d = self.d_sensor.item()
                extra_space = min_next_surf_d - max_curr_surf_d

            barrier_length = max_curr_surf_d + extra_space - barrier_z

            # Create a barrier
            barrier = {
                "pos_z": barrier_z,
                "pos_r": barrier_r,
                "length": barrier_length,
                "thickness": barrier_thickness,
            }
            barriers.append(barrier)

            # Reset the barrier parameters
            barrier_z = barrier_length + barrier_z
            barrier_r = 0.0
            barrier_length = 0.0

    # # Create rings
    # for i in range(len(geolens.surfaces)):
    #     if geolens.surfaces[i].mat2.get_name() != "air":
    #         ring = {
    #             "pos_z": geolens.surfaces[i].d.item(),

    # Plot lens layout
    ax, fig = self.draw_layout()

    # Plot barrier
    barrier_z_ls = []
    barrier_r_ls = []
    for b in barriers:
        barrier_z_ls.append(b["pos_z"])
        barrier_z_ls.append(b["pos_z"] + b["length"])
        barrier_r_ls.append(b["pos_r"])
        barrier_r_ls.append(b["pos_r"])
    ax.plot(barrier_z_ls, barrier_r_ls, "green", linewidth=1.0)
    ax.plot(barrier_z_ls, [-i for i in barrier_r_ls], "green", linewidth=1.0)

    # Plot rings

    fig.savefig(filename, format="png", dpi=300)
    plt.close()

    pass
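
A minimal sketch (the output path and barrier thickness are illustrative):

from deeplens.optics.geolens import GeoLens

lens = GeoLens(filename="lens.json")  # assumed example lens file
lens.create_barrier("lens_barrier.png", barrier_thickness=1.0)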

deeplens.optics.geolens_pkg.tolerance.GeoLensTolerance

Mixin providing tolerance analysis for GeoLens.

Implements two complementary approaches:

  • Sensitivity analysis – first-order gradient-based estimation of how each manufacturing error affects optical performance.
  • Monte-Carlo analysis – statistical sampling of random manufacturing errors to predict yield and worst-case performance.

This class is not instantiated directly; it is mixed into :class:~deeplens.optics.geolens.GeoLens.

References

Jun Dai et al., "Tolerance-Aware Deep Optics," arXiv:2502.04719, 2025.

init_tolerance

init_tolerance(tolerance_params=None)

Initialize manufacturing tolerance parameters for all surfaces.

Sets up tolerance ranges (e.g., curvature, thickness, decenter, tilt) on each surface. These are used by sample_tolerance() to simulate random manufacturing errors.

Parameters:

Name Type Description Default
tolerance_params dict

Custom tolerance specifications. If None, each surface uses its own defaults. Defaults to None.

None
Source code in deeplens/optics/geolens_pkg/tolerance.py
def init_tolerance(self, tolerance_params=None):
    """Initialize manufacturing tolerance parameters for all surfaces.

    Sets up tolerance ranges (e.g., curvature, thickness, decenter, tilt)
    on each surface. These are used by ``sample_tolerance()`` to simulate
    random manufacturing errors.

    Args:
        tolerance_params (dict, optional): Custom tolerance specifications.
            If None, each surface uses its own defaults. Defaults to None.
    """
    if tolerance_params is None:
        tolerance_params = {}

    for i in range(len(self.surfaces)):
        self.surfaces[i].init_tolerance(tolerance_params=tolerance_params)

sample_tolerance

sample_tolerance()

Apply random manufacturing errors to all surfaces.

Randomly perturbs each surface according to its tolerance ranges and then refocuses the lens to compensate for the focus shift.

Source code in deeplens/optics/geolens_pkg/tolerance.py
@torch.no_grad()
def sample_tolerance(self):
    """Apply random manufacturing errors to all surfaces.

    Randomly perturbs each surface according to its tolerance ranges and
    then refocuses the lens to compensate for the focus shift.
    """
    # Randomly perturb all surfaces
    for i in range(len(self.surfaces)):
        self.surfaces[i].sample_tolerance()

    # Refocus the lens
    self.refocus()

zero_tolerance

zero_tolerance()

Reset all manufacturing errors to zero (nominal lens state).

Clears the perturbations on every surface and refocuses the lens.

Source code in deeplens/optics/geolens_pkg/tolerance.py
@torch.no_grad()
def zero_tolerance(self):
    """Reset all manufacturing errors to zero (nominal lens state).

    Clears the perturbations on every surface and refocuses the lens.
    """
    for i in range(len(self.surfaces)):
        self.surfaces[i].zero_tolerance()

    # Refocus the lens
    self.refocus()
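
A sketch of the perturb-evaluate-reset loop, using the per-surface default tolerances (the evaluation step is left as a placeholder):

from deeplens.optics.geolens import GeoLens

lens = GeoLens(filename="lens.json")  # assumed example lens file
lens.init_tolerance()                 # default tolerance ranges on every surface

for _ in range(10):
    lens.sample_tolerance()           # apply random manufacturing errors and refocus
    # ... evaluate the perturbed lens here (spot size, MTF, ...) ...
    lens.zero_tolerance()             # restore the nominal lens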

tolerancing_sensitivity

tolerancing_sensitivity(tolerance_params=None)

Use sensitivity analysis (1st order gradient) to compute the tolerance score.

References

[1] Page 10 from: https://wp.optics.arizona.edu/optomech/wp-content/uploads/sites/53/2016/08/8-Tolerancing-1.pdf [2] Fast sensitivity control method with differentiable optics. Optics Express 2025. [3] Optical Design Tolerancing. CODE V.

Source code in deeplens/optics/geolens_pkg/tolerance.py
def tolerancing_sensitivity(self, tolerance_params=None):
    """Use sensitivity analysis (1st order gradient) to compute the tolerance score.

    References:
        [1] Page 10 from: https://wp.optics.arizona.edu/optomech/wp-content/uploads/sites/53/2016/08/8-Tolerancing-1.pdf
        [2] Fast sensitivity control method with differentiable optics. Optics Express 2025.
        [3] Optical Design Tolerancing. CODE V.
    """
    # Initialize tolerance
    self.init_tolerance(tolerance_params=tolerance_params)

    # AutoDiff to compute the gradient/sensitivity
    self.get_optimizer_params()
    loss = self.loss_rms()
    loss.backward()

    # Calculate sensitivity results
    sensitivity_results = {}
    for i in range(len(self.surfaces)):
        sensitivity_results.update(self.surfaces[i].sensitivity_score())

    # Toleranced RSS (Root Sum Square) loss
    tolerancing_score = sum(
        v for k, v in sensitivity_results.items() if k.endswith("_score")
    )
    loss_rss = torch.sqrt(loss**2 + tolerancing_score).item()
    sensitivity_results["loss_nominal"] = round(loss.item(), 6)
    sensitivity_results["loss_rss"] = round(loss_rss, 6)
    return sensitivity_results
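
A minimal sketch (the result keys are taken from the source above; the lens file name is illustrative):

from deeplens.optics.geolens import GeoLens

lens = GeoLens(filename="lens.json")      # assumed example lens file
results = lens.tolerancing_sensitivity()  # per-surface default tolerances
print(results["loss_nominal"], results["loss_rss"])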

tolerancing_monte_carlo

tolerancing_monte_carlo(trials=1000, tolerance_params=None)

Use Monte Carlo simulation to compute the tolerance.

Note: we can multiplex sampled rays to improve the speed.

Parameters:

Name Type Description Default
trials int

Number of Monte Carlo trials

1000
tolerance_params dict

Tolerance parameters

None

Returns:

Name Type Description
dict

Monte Carlo tolerance analysis results

References

[1] https://optics.ansys.com/hc/en-us/articles/43071088477587-How-to-analyze-your-tolerance-results [2] Optical Design Tolerancing. CODE V.

Source code in deeplens/optics/geolens_pkg/tolerance.py
@torch.no_grad()
def tolerancing_monte_carlo(self, trials=1000, tolerance_params=None):
    """Use Monte Carlo simulation to compute the tolerance.

    Note: we can multiplex sampled rays to improve the speed.

    Args:
        trials (int): Number of Monte Carlo trials
        tolerance_params (dict): Tolerance parameters

    Returns:
        dict: Monte Carlo tolerance analysis results

    References:
        [1] https://optics.ansys.com/hc/en-us/articles/43071088477587-How-to-analyze-your-tolerance-results
        [2] Optical Design Tolerancing. CODE V.
    """

    def merit_func(lens, fov=0.0, depth=DEPTH):
        # Calculate MTF at a specific field of view
        point = [0, -fov / lens.rfov, depth]
        psf = lens.psf(points=point, recenter=True)
        freq, mtf_tan, mtf_sag = lens.psf2mtf(psf, pixel_size=lens.pixel_size)

        # Evaluate MTF at a specific frequency
        nyquist_freq = 0.5 / lens.pixel_size
        eval_freq = 0.25 * nyquist_freq
        idx = torch.argmin(torch.abs(torch.tensor(freq) - eval_freq))
        score = (mtf_tan[idx] + mtf_sag[idx]) / 2
        return score.item()

    # Initialize tolerance
    self.init_tolerance(tolerance_params=tolerance_params)

    # Monte Carlo simulation
    merit_ls = []
    with torch.no_grad():
        for i in tqdm(range(trials)):
            # Sample a random perturbation
            self.sample_tolerance()

            # Evaluate perturbed performance
            perturbed_merit = merit_func(lens=self, fov=0.0, depth=DEPTH)
            merit_ls.append(perturbed_merit)

            # Clear perturbation
            self.zero_tolerance()

    merit_ls = np.array(merit_ls)

    # Baseline merit
    self.refocus()
    baseline_merit = merit_func(lens=self, fov=0.0, depth=DEPTH)
    # merit_ls /= baseline_merit

    # Results plot
    sorted_merit = np.sort(merit_ls)
    cumulative_prob = (1 - np.arange(len(sorted_merit)) / len(sorted_merit)) * 100
    plt.figure(figsize=(8, 6))
    plt.xlabel("Merit Score", fontsize=12)
    plt.ylabel("Cumulative Probability (%)", fontsize=12)
    plt.title("Cumulative Probability beyond Merit Score", fontsize=14)
    plt.plot(sorted_merit, cumulative_prob, linewidth=2)
    plt.gca().invert_xaxis()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig("Monte_Carlo_Cumulative_Prob.png", dpi=150, bbox_inches="tight")
    plt.close()

    # Results dict
    results = {
        "method": "monte_carlo",
        "trials": trials,
        "baseline_merit": round(baseline_merit, 6),
        "merit_std": round(float(np.std(merit_ls)), 6),
        "merit_mean": round(float(np.mean(merit_ls)), 6),
        "merit_yield": {
            "99% > ": round(float(np.percentile(merit_ls, 1)), 4),
            "95% > ": round(float(np.percentile(merit_ls, 5)), 4),
            "90% > ": round(float(np.percentile(merit_ls, 10)), 4),
            "80% > ": round(float(np.percentile(merit_ls, 20)), 4),
            "70% > ": round(float(np.percentile(merit_ls, 30)), 4),
            "60% > ": round(float(np.percentile(merit_ls, 60)), 4),
            "50% > ": round(float(np.percentile(merit_ls, 50)), 4),
        },
        "merit_percentile": {
            "99% < ": round(float(np.percentile(merit_ls, 99)), 4),
            "95% < ": round(float(np.percentile(merit_ls, 95)), 4),
            "90% < ": round(float(np.percentile(merit_ls, 90)), 4),
            "80% < ": round(float(np.percentile(merit_ls, 80)), 4),
            "70% < ": round(float(np.percentile(merit_ls, 70)), 4),
            "60% < ": round(float(np.percentile(merit_ls, 60)), 4),
            "50% < ": round(float(np.percentile(merit_ls, 50)), 4),
        },
    }
    return results
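
A minimal sketch; note the method also writes Monte_Carlo_Cumulative_Prob.png to the working directory (the trial count here is reduced for a quick check):

from deeplens.optics.geolens import GeoLens

lens = GeoLens(filename="lens.json")                 # assumed example lens file
results = lens.tolerancing_monte_carlo(trials=200)   # fewer trials than the default 1000
print(results["baseline_merit"], results["merit_mean"], results["merit_std"])
print(results["merit_yield"])                        # yield thresholds, e.g. "95% > "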

tolerancing_wavefront

tolerancing_wavefront(tolerance_params=None)

Use the wavefront differential method to compute the tolerance.

The wavefront differential method is proposed in [1], but its detailed implementation is not published. The author (Xinge Yang) assumes symbolic differentiation is used to compute the gradient/Jacobian of the wavefront error; since AutoDiff can obtain this Jacobian through gradient backpropagation, the implementation is left as future work. The method is currently an unimplemented placeholder.

Parameters:

Name Type Description Default
tolerance_params dict

Tolerance parameters

None

Returns:

Name Type Description
dict

Wavefront tolerance analysis results

References

[1] Optical Design Tolerancing. CODE V.

Source code in deeplens/optics/geolens_pkg/tolerance.py
def tolerancing_wavefront(self, tolerance_params=None):
    """Use wavefront differential method to compute the tolerance.

    Wavefront differential method is proposed in [1], while the detailed implementation remains unknown. I (Xinge Yang) assume a symbolic differentiation is used to compute the gradient/Jacobian of the wavefront error. With AutoDiff, we can easily calculate Jacobian with gradient backpropagation, therefore I leave the implementation of this method as future work.

    Args:
        tolerance_params (dict): Tolerance parameters

    Returns:
        dict: Wavefront tolerance analysis results

    References:
        [1] Optical Design Tolerancing. CODE V.
    """
    pass

deeplens.optics.geolens_pkg.view_3d.GeoLensVis3D

Mixin providing 3-D mesh visualisation for GeoLens.

Creates lens surface, aperture, barrier, sensor, and ray-path meshes as polygon data and optionally renders them with PyVista. All geometry is expressed in millimetres and stored as :class:CrossPoly (vertex/face) objects that can be saved to .obj files for external renderers.

This class is not instantiated directly; it is mixed into :class:~deeplens.optics.geolens.GeoLens.

create_mesh

create_mesh(mesh_rings: int = 32, mesh_arms: int = 128, is_wrap: bool = False)

Create all lens/bridge/sensor/aperture meshes.

Parameters:

Name Type Description Default
mesh_rings int

The number of rings in the mesh.

32
mesh_arms int

The number of arms in the mesh.

128
is_wrap bool

Whether to wrap the lens bridge around the lens as cylinder.

False

Returns:

Name Type Description
surf_meshes List[Surface]

Lens surface meshes.

bridge_meshes List[FaceMesh]

Lens bridge meshes (wrap-around is not supported for now).

element_groups List[List[int]]

Surface indices grouped per lens element.

sensor_mesh RectangleMesh

Sensor mesh (only rectangular sensors are supported for now).

Source code in deeplens/optics/geolens_pkg/view_3d.py
def create_mesh(
    self,
    mesh_rings: int = 32,
    mesh_arms: int = 128,
    is_wrap: bool = False,
):
    """Create all lens/bridge/sensor/aperture meshes.

    Args:
        lens (GeoLens): The lens object.
        mesh_rings (int): The number of rings in the mesh.
        mesh_arms (int): The number of arms in the mesh.
        is_wrap (bool): Whether to wrap the lens bridge around the lens as cylinder.
    Returns:
        surf_meshes (List[Surface]): Lens surfaces meshes.
        bridge_meshes (List[FaceMesh]): Lens bridges meshes. (NOT support wrap around for now)
        sensor_mesh (RectangleMesh): Sensor meshes. (only support rectangular sensor for now)
    """
    surf_meshes = []
    element_group = []
    element_groups = []
    bridge_meshes = []  # change to nested list for wrap around
    sensor_mesh = None

    # Create the surface meshes
    for i, surf in enumerate(self.surfaces):
        # Create the surface mesh (list of Surface objects)
        surf_meshes.append(surf.create_mesh(n_rings=mesh_rings, n_arms=mesh_arms))

        # Add the surface to the element group
        element_group.append(i)
        if surf.mat2.name == "air":
            element_groups.append(element_group)
            element_group = []

    # Create the bridge meshes (list of FaceMesh objects)
    for i, pair in enumerate(element_groups):
        if len(pair) == 1:
            bridge_meshes.append([])
            continue
        elif len(pair) == 2:
            a_idx, b_idx = pair
            a = surf_meshes[a_idx]
            b = surf_meshes[b_idx]
            bridge_mesh_group = []
            if not is_wrap:
                bridge_mesh = bridge(a.rim, b.rim)
                bridge_mesh_group.append(bridge_mesh)
            else:
                # create wrap by creating a new rim
                # from projecting the larger rim onto the smaller rim plane
                # assume the elements are always ordered on z-axis
                r_a = self.surfaces[a_idx].r
                r_b = self.surfaces[b_idx].r
                d_rim_a = np.mean(
                    a.rim.vertices[:, 2], keepdims=False
                )  # calc rim mean z
                d_rim_b = np.mean(b.rim.vertices[:, 2], keepdims=False)

                if r_a > r_b:
                    z = line_translate(a.rim, 0, 0, d_rim_b - d_rim_a)
                    bridge_mesh_wrap = bridge(z, b.rim)
                    bridge_mesh = bridge(a.rim, z)
                    bridge_mesh_group.append(bridge_mesh_wrap)
                elif r_a < r_b:
                    z = line_translate(b.rim, 0, 0, d_rim_a - d_rim_b)
                    bridge_mesh_wrap = bridge(a.rim, z)
                    bridge_mesh = bridge(z, b.rim)
                    bridge_mesh_group.append(bridge_mesh_wrap)
                else:
                    bridge_mesh = bridge(a.rim, b.rim)
                bridge_mesh_group.append(bridge_mesh)
            bridge_meshes.append(bridge_mesh_group)

        elif len(pair) == 3:
            a_idx, b_idx, c_idx = pair
            a = surf_meshes[a_idx]
            b = surf_meshes[b_idx]
            c = surf_meshes[c_idx]
            bridge_mesh_group = []
            if not is_wrap:
                bridge_mesh = bridge(a.rim, b.rim)
                bridge_mesh_group.append(bridge_mesh)
                bridge_mesh = bridge(b.rim, c.rim)
                bridge_mesh_group.append(bridge_mesh)
            else:
                # create wrap by creating a new rim
                # from projecting the larger rim onto the smaller rim plane
                # assume the elements are always ordered on z-axis
                r_a = self.surfaces[a_idx].r
                r_b = self.surfaces[b_idx].r
                r_c = self.surfaces[c_idx].r
                d_rim_a = np.mean(
                    a.rim.vertices[:, 2], keepdims=False
                )  # calc rim mean z
                d_rim_b = np.mean(b.rim.vertices[:, 2], keepdims=False)
                d_rim_c = np.mean(c.rim.vertices[:, 2], keepdims=False)

                rim_list = [a.rim, b.rim, c.rim]
                r_list = [r_a, r_b, r_c]
                d_rim_list = [d_rim_a, d_rim_b, d_rim_c]
                idx_wrap = r_list.index(max(r_list))
                r_wrap = r_list[idx_wrap]
                d_rim_wrap = d_rim_list[idx_wrap]

                for i in range(3):
                    if i != idx_wrap and r_list[i] != r_wrap:
                        # substitute the rim with the wrapped rim
                        d_diff = d_rim_list[i] - d_rim_wrap
                        z = line_translate(rim_list[idx_wrap], 0, 0, d_diff)
                        # add the wrap bridge between older rim and wrapped one
                        wrap_mesh = bridge(rim_list[i], z)
                        # update the rim
                        rim_list[i] = z
                        bridge_mesh_group.append(wrap_mesh)
                bridge_mesh = bridge(rim_list[0], rim_list[1])
                bridge_mesh_group.append(bridge_mesh)
                bridge_mesh = bridge(rim_list[1], rim_list[2])
                bridge_mesh_group.append(bridge_mesh)
            bridge_meshes.append(bridge_mesh_group)

        else:
            raise ValueError(f"Invalid bridge group length: {len(pair)}")

    # Create the sensor mesh (RectangleMesh object)
    sensor_d = self.d_sensor.item()
    sensor_r = self.r_sensor
    h, w = sensor_r * 1.4142, sensor_r * 1.4142
    sensor_mesh = RectangleMesh(
        np.array([0, 0, sensor_d]), np.array([1, 0, 0]), np.array([0, 1, 0]), w, h
    )

    # turn surf_meshes to list of FaceMesh
    surf_meshes_cvt = [surf_to_face_mesh(surf) for surf in surf_meshes]
    return surf_meshes_cvt, bridge_meshes, element_groups, sensor_mesh
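
A minimal sketch unpacking the four return values shown in the source above (the lens file name is illustrative):

from deeplens.optics.geolens import GeoLens

lens = GeoLens(filename="lens.json")  # assumed example lens file
surf_meshes, bridge_meshes, element_groups, sensor_mesh = lens.create_mesh(
    mesh_rings=32, mesh_arms=128, is_wrap=False
)
print(len(surf_meshes), "surface meshes,", len(element_groups), "lens elements")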

draw_lens_3d

draw_lens_3d(plotter=None, save_dir: Optional[str] = None, mesh_rings: int = 32, mesh_arms: int = 128, surface_color: List[float] = [0.06, 0.3, 0.6], draw_rays: bool = True, fovs: List[float] = [0.0], fov_phis: List[float] = [0.0], ray_rings: int = 6, ray_arms: int = 8, is_wrap: bool = False)

Draw lens 3D layout with rays using pyvista.

Note: PyVista is imported lazily only when this method is called.

Parameters:

Name Type Description Default
plotter

pv.Plotter. Optional pyvista Plotter instance. If None, a new one is created.

None
save_dir str

The directory to save the image.

None
mesh_rings int

The number of rings in the mesh.

32
mesh_arms int

The number of arms in the mesh.

128
surface_color List[float]

The color of the surfaces.

[0.06, 0.3, 0.6]
draw_rays bool

Whether to show the rays.

True
fovs List[float]

The FoV angles to be sampled, unit: degree.

[0.0]
fov_phis List[float]

The FoV azimuthal angles to be sampled, unit: degree.

[0.0]
ray_rings int

The number of pupil rings to be sampled.

6
ray_arms int

The number of pupil arms to be sampled.

8
is_wrap bool

Whether to wrap the lens bridge around the lens as cylinder.

False

Returns:

Name Type Description
plotter

pv.Plotter. The pyvista Plotter instance.

Source code in deeplens/optics/geolens_pkg/view_3d.py
def draw_lens_3d(
    self,
    plotter=None,
    save_dir: Optional[str] = None,
    mesh_rings: int = 32,
    mesh_arms: int = 128,
    surface_color: List[float] = [0.06, 0.3, 0.6],
    draw_rays: bool = True,
    fovs: List[float] = [0.0],
    fov_phis: List[float] = [0.0],
    ray_rings: int = 6,
    ray_arms: int = 8,
    is_wrap: bool = False,
):
    """Draw lens 3D layout with rays using pyvista.

    Note: PyVista is imported lazily only when this method is called.

    Args:
        plotter: pv.Plotter. Optional pyvista Plotter instance. If None, a new one is created.
        save_dir (str): The directory to save the image.
        mesh_rings (int): The number of rings in the mesh.
        mesh_arms (int): The number of arms in the mesh.
        surface_color (List[float]): The color of the surfaces.
        draw_rays (bool): Whether to show the rays.
        fovs (List[float]): The FoV angles to be sampled, unit: degree.
        fov_phis (List[float]): The FoV azimuthal angles to be sampled, unit: degree.
        ray_rings (int): The number of pupil rings to be sampled.
        ray_arms (int): The number of pupil arms to be sampled.
        is_wrap (bool): Whether to wrap the lens bridge around the lens as cylinder.

    Returns:
        plotter: pv.Plotter. The pyvista Plotter instance.
    """
    # Lazy import of pyvista
    try:
        import pyvista as pv
    except ImportError as e:
        raise ImportError(
            "PyVista is required for 3D GUI rendering. Install with `pip install pyvista`."
        ) from e

    # Create plotter if not provided
    if plotter is None:
        plotter = pv.Plotter()

    surf_color = surface_color
    sensor_color = [0.5, 0.5, 0.5]

    # Create meshes
    surf_meshes, bridge_meshes, _, sensor_mesh = self.create_mesh(
        mesh_rings, mesh_arms, is_wrap
    )

    # Draw meshes
    for surf in surf_meshes:
        if not isinstance(surf, Aperture):
            _draw_mesh_to_plotter(
                plotter, surf, color=surf_color, opacity=0.5, pv=pv
            )

    for bridge_group in bridge_meshes:
        for bridge_mesh in bridge_group:
            _draw_mesh_to_plotter(
                plotter, bridge_mesh, color=surf_color, opacity=0.5, pv=pv
            )

    _draw_mesh_to_plotter(
        plotter, sensor_mesh, color=sensor_color, opacity=1.0, pv=pv
    )

    # Draw rays
    if draw_rays:
        rays_curve = geolens_ray_poly(
            self, fovs, fov_phis, n_rings=ray_rings, n_arms=ray_arms
        )

        rays_poly_list = [curve_list_to_polydata(r) for r in rays_curve]
        rays_poly_fov = [merge(r) for r in rays_poly_list]
        rays_poly_fov = [_wrap_base_poly_to_pyvista(r, pv) for r in rays_poly_fov]
        for r in rays_poly_fov:
            plotter.add_mesh(r)

    # Save images
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        plotter.show(screenshot=os.path.join(save_dir, "lens_layout3d.png"))

    return plotter
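
A minimal sketch (requires pyvista; the save directory and field angles are illustrative):

from deeplens.optics.geolens import GeoLens

lens = GeoLens(filename="lens.json")   # assumed example lens file
plotter = lens.draw_lens_3d(
    save_dir="./render",               # writes ./render/lens_layout3d.png and opens the window
    fovs=[0.0, 10.0],                  # field angles in degrees
    draw_rays=True,
)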

save_lens_obj

save_lens_obj(save_dir: str, mesh_rings: int = 64, mesh_arms: int = 128, save_rays: bool = False, fovs: List[float] = [0.0], fov_phis: List[float] = [0.0], ray_rings: int = 6, ray_arms: int = 8, is_wrap: bool = False, save_elements: bool = True)

Save lens geometry and rays as .obj files using pyvista.

Note: use #F2F7FFFF as the color for lens when rendering in Blender.

Parameters:

Name Type Description Default
save_dir str

The directory to save the image.

required
mesh_rings int

The number of rings in the mesh.

64
mesh_arms int

The number of arms in the mesh.

128
save_rays bool

Whether to save the rays.

False
fovs List[float]

The FoV angles to be sampled, unit: degree.

[0.0]
fov_phis List[float]

The FoV azimuthal angles to be sampled, unit: degree.

[0.0]
ray_rings int

The number of pupil rings to be sampled. (default: 6)

6
ray_arms int

The number of pupil arms to be sampled. (default: 8)

8
is_wrap bool

Whether to wrap the lens bridge around the lens as cylinder.

False
save_elements bool

Whether to save the elements.

True
Source code in deeplens/optics/geolens_pkg/view_3d.py
def save_lens_obj(
    self,
    save_dir: str,
    mesh_rings: int = 64,
    mesh_arms: int = 128,
    save_rays: bool = False,
    fovs: List[float] = [0.0],
    fov_phis: List[float] = [0.0],
    ray_rings: int = 6,
    ray_arms: int = 8,
    is_wrap: bool = False,
    save_elements: bool = True,
):
    """Save lens geometry and rays as .obj files using pyvista.

    Note: use #F2F7FFFF as the color for lens when rendering in Blender.

    Args:
        lens (GeoLens): The lens object.
        save_dir (str): The directory to save the image.
        mesh_rings (int): The number of rings in the mesh. (default: 128)
        mesh_arms (int): The number of arms in the mesh. (default: 256)
        save_rays (bool): Whether to save the rays.
        fovs (List[float]): The FoV angles to be sampled, unit: degree.
        fov_phis (List[float]): The FoV azimuthal angles to be sampled, unit: degree.
        ray_rings (int): The number of pupil rings to be sampled. (default: 6)
        ray_arms (int): The number of pupil arms to be sampled. (default: 8)
        is_wrap (bool): Whether to wrap the lens bridge around the lens as cylinder.
        save_elements (bool): Whether to save the elements.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Create surfaces & bridges meshes
    surf_meshes, bridge_meshes, element_groups, sensor_mesh = self.create_mesh(
        mesh_rings, mesh_arms, is_wrap
    )

    # Save individual lens elements (surfaces + bridges merged)
    if save_elements:
        for i, pair in enumerate(element_groups):
            print(f"Running in pair {i} with pair length {len(pair)}")
            # Collect surface polydata
            surf_polydata_list = [surf_meshes[idx].get_polydata() for idx in pair]

            # Collect bridge polydata if available
            bridge_polydata_list = []
            if i < len(bridge_meshes) and len(bridge_meshes[i]) > 0:
                print(f"Bridge mesh group number: {len(bridge_meshes[i])}")
                bridge_polydata_list = [b.get_polydata() for b in bridge_meshes[i]]

            # Merge surfaces and bridges together
            all_polydata = surf_polydata_list + bridge_polydata_list
            if len(all_polydata) == 1:
                element = all_polydata[0]
            else:
                element = merge(all_polydata)
            element.save(os.path.join(save_dir, f"element_{i}.obj"))

    # Merge all surfaces and bridges, and save as single lens.obj file
    surf_polydata = [
        surf.get_polydata()
        for surf in surf_meshes
        if not isinstance(surf, Aperture)
    ]
    bridge_polydata = [
        b.get_polydata() for group in bridge_meshes for b in group
    ]  # flatten the nested list
    lens_polydata = surf_polydata + bridge_polydata
    lens_polydata = merge(lens_polydata)
    lens_polydata.save(os.path.join(save_dir, "lens.obj"))

    # Save sensor
    sensor_polydata = sensor_mesh.get_polydata()
    sensor_polydata.save(os.path.join(save_dir, "sensor.obj"))

    # Save rays
    if save_rays:
        rays_curve = geolens_ray_poly(
            self, fovs, fov_phis, n_rings=ray_rings, n_arms=ray_arms
        )
        rays_poly_list = [curve_list_to_polydata(r) for r in rays_curve]
        rays_poly_fov = [merge(r) for r in rays_poly_list]
        for i, r in enumerate(rays_poly_fov):
            r.save(os.path.join(save_dir, f"lens_rays_fov_{i}.obj"))
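
A minimal sketch (the output directory and field angles are illustrative):

from deeplens.optics.geolens import GeoLens

lens = GeoLens(filename="lens.json")   # assumed example lens file
lens.save_lens_obj(
    save_dir="./obj_export",           # writes lens.obj, sensor.obj, element_*.obj
    save_rays=True,
    fovs=[0.0, 15.0],                  # field angles in degrees
)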

Combines a GeoLens with a diffractive optical element (DOE). Performs coherent ray tracing to the DOE plane, then Angular Spectrum Method (ASM) propagation to the sensor.

deeplens.optics.HybridLens

HybridLens(filename=None, device=None, dtype=torch.float64)

Bases: Lens

Hybrid refractive-diffractive lens using a differentiable ray–wave model.

Combines a :class:~deeplens.optics.geolens.GeoLens (refractive module) with a diffractive optical element (DOE) placed behind it. The pipeline is:

  1. Coherent ray tracing through the embedded GeoLens to obtain a complex wavefront at the DOE plane (including all geometric aberrations).
  2. DOE phase modulation applied to the wavefront.
  3. Angular Spectrum Method (ASM) propagation from the DOE to the sensor plane to produce the final intensity PSF.

This enables end-to-end gradient flow from image quality metrics back to both refractive surface parameters and the DOE phase profile.

Attributes:

Name Type Description
geolens GeoLens

Embedded refractive module.

doe

Diffractive optical element (one of Binary2, Pixel2D, Fresnel, Zernike, Grating).

Notes

Operates in torch.float64 by default for numerical stability of the wave-propagation step.

References

Xinge Yang et al., "End-to-End Hybrid Refractive-Diffractive Lens Design with Differentiable Ray-Wave Model," SIGGRAPH Asia 2024.

Initialize a hybrid refractive-diffractive lens.

Parameters:

Name Type Description Default
filename str

Path to the lens configuration JSON file. Defaults to None.

None
device str

Computation device ('cpu' or 'cuda'). Defaults to None.

None
dtype dtype

Data type for computations. Defaults to torch.float64.

float64
Source code in deeplens/optics/hybridlens.py
def __init__(
    self,
    filename=None,
    device=None,
    dtype=torch.float64,
):
    """Initialize a hybrid refractive-diffractive lens.

    Args:
        filename (str, optional): Path to the lens configuration JSON file. Defaults to None.
        device (str, optional): Computation device ('cpu' or 'cuda'). Defaults to None.
        dtype (torch.dtype, optional): Data type for computations. Defaults to torch.float64.
    """
    super().__init__(device=device, dtype=dtype)

    # Load lens file
    if filename is not None:
        self.read_lens_json(filename)
    else:
        self.geolens = None
        self.doe = None
        # Set default sensor size and resolution if no file provided
        self.sensor_size = (8.0, 8.0)
        self.sensor_res = (2000, 2000)
        print(
            f"No lens file provided. Using default sensor_size: {self.sensor_size} mm, "
            f"sensor_res: {self.sensor_res} pixels. Use set_sensor() to change."
        )

    self.double()
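
A minimal construction sketch (the JSON file name is illustrative; the file must contain a "DOE" entry, see read_lens_json below):

import torch
from deeplens.optics import HybridLens

torch.set_default_dtype(torch.float64)           # float64 is expected for accurate phase tracing
lens = HybridLens(filename="hybrid_lens.json")   # assumed example config file
lens.analysis(save_name="./hybrid_layout.png")   # layout + DOE phase map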

read_lens_json

read_lens_json(filename)

Read the lens configuration from a JSON file.

Loads a :class:GeoLens and associated DOE from the specified file. A Plane surface is appended to the GeoLens surface list as a placeholder for the DOE plane.

Supported DOE types: binary2, pixel2d, fresnel, zernike, grating.

Parameters:

Name Type Description Default
filename str

Path to the JSON configuration file. Must contain a "DOE" key with a "type" field.

required

Raises:

Type Description
ValueError

If the DOE type in the file is not supported.

Source code in deeplens/optics/hybridlens.py
def read_lens_json(self, filename):
    """Read the lens configuration from a JSON file.

    Loads a :class:`GeoLens` and associated DOE from the specified file.
    A ``Plane`` surface is appended to the GeoLens surface list as a
    placeholder for the DOE plane.

    Supported DOE types: ``binary2``, ``pixel2d``, ``fresnel``,
    ``zernike``, ``grating``.

    Args:
        filename (str): Path to the JSON configuration file.  Must
            contain a ``"DOE"`` key with a ``"type"`` field.

    Raises:
        ValueError: If the DOE type in the file is not supported.
    """
    # Load geolens
    geolens = GeoLens(filename=filename, device=self.device)

    # Load DOE (diffractive surface)
    with open(filename, "r") as f:
        data = json.load(f)

        doe_dict = data["DOE"]
        doe_param_model = doe_dict["type"].lower()
        if doe_param_model == "binary2":
            doe = Binary2.init_from_dict(doe_dict)
        elif doe_param_model == "pixel2d":
            doe = Pixel2D.init_from_dict(doe_dict)
        elif doe_param_model == "fresnel":
            doe = Fresnel.init_from_dict(doe_dict)
        elif doe_param_model == "zernike":
            doe = Zernike.init_from_dict(doe_dict)
        elif doe_param_model == "grating":
            doe = Grating.init_from_dict(doe_dict)
        else:
            raise ValueError(f"Unsupported DOE parameter model: {doe_param_model}")
        self.doe = doe

    # Add a Plane/Phase surface to GeoLens (DOE placeholder)
    r_doe = float(np.sqrt(doe.w**2 + doe.h**2) / 2)
    geolens.surfaces.append(Plane(d=doe.d.item(), r=r_doe, mat2="air"))
    # r_doe = float(np.sqrt(doe.w**2 + doe.h**2) / 2)
    # geolens.surfaces.append(Phase(r=r_doe, d=doe.d))
    self.geolens = geolens
    self.foclen = geolens.foclen

    # Update hybrid lens sensor resolution and pixel size
    self.set_sensor(sensor_size=geolens.sensor_size, sensor_res=geolens.sensor_res)
    self.to(self.device)

write_lens_json

write_lens_json(lens_path)

Write the lens configuration to a JSON file.

Serialises the GeoLens surfaces (excluding the DOE placeholder) and the DOE configuration into a single JSON file that can be reloaded with :meth:read_lens_json.

Parameters:

Name Type Description Default
lens_path str

Output file path.

required
Source code in deeplens/optics/hybridlens.py
def write_lens_json(self, lens_path):
    """Write the lens configuration to a JSON file.

    Serialises the ``GeoLens`` surfaces (excluding the DOE placeholder)
    and the ``DOE`` configuration into a single JSON file that can be
    reloaded with :meth:`read_lens_json`.

    Args:
        lens_path (str): Output file path.
    """
    geolens = self.geolens
    data = {}
    data["info"] = geolens.lens_info if hasattr(geolens, "lens_info") else "None"
    data["foclen"] = round(geolens.foclen, 4)
    data["fnum"] = round(geolens.fnum, 4)
    data["r_sensor"] = round(geolens.r_sensor, 4)
    data["d_sensor"] = round(geolens.d_sensor.item(), 4)
    data["sensor_size"] = [round(i, 4) for i in geolens.sensor_size]
    data["sensor_res"] = geolens.sensor_res

    # Geolens
    data["surfaces"] = []
    for i, s in enumerate(geolens.surfaces[:-1]):
        surf_dict = s.surf_dict()

        # To exclude the last surface (DOE)
        if i < len(geolens.surfaces) - 2:
            surf_dict["d_next"] = round(
                geolens.surfaces[i + 1].d.item() - geolens.surfaces[i].d.item(), 3
            )
        else:
            surf_dict["d_next"] = round(
                geolens.d_sensor.item() - geolens.surfaces[i].d.item(), 3
            )

        data["surfaces"].append(surf_dict)

    # DOE
    data["DOE"] = self.doe.surf_dict()

    with open(lens_path, "w") as f:
        json.dump(data, f, indent=4)

analysis

analysis(save_name='./test.png')

Run a quick visual analysis of the hybrid lens.

Generates two figures: the 2D lens layout (saved to save_name) and the DOE phase map (saved to <save_name>_doe.png).

Parameters:

Name Type Description Default
save_name str

Base file path for the layout image. The DOE phase-map image is saved to <save_name>_doe.png. Defaults to './test.png'.

'./test.png'
Source code in deeplens/optics/hybridlens.py
def analysis(self, save_name="./test.png"):
    """Run a quick visual analysis of the hybrid lens.

    Generates two figures: the 2D lens layout (saved to *save_name*) and
    the DOE phase map (saved to ``<save_name>_doe.png``).

    Args:
        save_name (str, optional): Base file path for the layout image.
            The DOE phase-map image is derived by appending ``_doe``
            before the extension.  Defaults to ``'./test.png'``.
    """
    self.draw_layout(save_name=save_name)
    self.doe.draw_phase_map(save_name=f"{save_name}_doe.png")

double

double()

Convert the GeoLens and DOE to float64 precision.

Double precision is required for numerically stable phase accumulation during coherent ray tracing and ASM propagation. Called automatically by :meth:__init__.

Source code in deeplens/optics/hybridlens.py
def double(self):
    """Convert the GeoLens and DOE to ``float64`` precision.

    Double precision is required for numerically stable phase
    accumulation during coherent ray tracing and ASM propagation.
    Called automatically by :meth:`__init__`.
    """
    self.geolens.astype(torch.float64)
    self.doe.astype(torch.float64)

refocus

refocus(foc_dist)

Refocus the hybrid lens to a given object distance.

Only the GeoLens sensor-to-last-surface spacing is adjusted; the DOE remains fixed relative to the refractive group (it is physically cemented to the lens barrel).

Parameters:

Name Type Description Default
foc_dist float

Target focus distance in [mm] (negative, towards the object).

required
Source code in deeplens/optics/hybridlens.py
def refocus(self, foc_dist):
    """Refocus the hybrid lens to a given object distance.

    Only the ``GeoLens`` sensor-to-last-surface spacing is adjusted; the
    DOE remains fixed relative to the refractive group (it is physically
    cemented to the lens barrel).

    Args:
        foc_dist (float): Target focus distance in [mm] (negative,
            towards the object).
    """
    self.geolens.refocus(foc_dist)

calc_scale

calc_scale(depth)

Calculate the object-to-image magnification scale factor.

Delegates to the embedded :class:GeoLens.

Parameters:

Name Type Description Default
depth float

Object distance in [mm] (negative, towards the object).

required

Returns:

Name Type Description
float

Scale factor mapping normalised sensor coordinates [-1, 1] to physical object-space coordinates [mm].

Source code in deeplens/optics/hybridlens.py
def calc_scale(self, depth):
    """Calculate the object-to-image magnification scale factor.

    Delegates to the embedded :class:`GeoLens`.

    Args:
        depth (float): Object distance in [mm] (negative, towards the
            object).

    Returns:
        float: Scale factor mapping normalised sensor coordinates
            ``[-1, 1]`` to physical object-space coordinates [mm].
    """
    return self.geolens.calc_scale(depth)

doe_field

doe_field(point, wvln=DEFAULT_WAVE, spp=SPP_COHERENT)

Compute the complex wave field at the DOE plane via coherent ray tracing.

Similar to GeoLens.pupil_field(), but evaluates the field at the last surface (DOE plane) instead of the exit pupil. The returned wavefront encodes amplitude, phase, and all diffraction-order information needed for subsequent DOE modulation and ASM propagation.

Parameters:

Name Type Description Default
point Tensor

Point source position, shape (3,) or (1, 3) as [x, y, z] in normalised sensor coordinates for x/y and mm for z.

required
wvln float

Wavelength in [um]. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
spp int

Number of rays to sample. Must be >= 1,000,000 for accurate coherent simulation. Defaults to SPP_COHERENT.

SPP_COHERENT

Returns:

Name Type Description
tuple
  • wavefront (torch.Tensor) -- Complex wavefront at the DOE plane, shape [H, W].
  • psf_center (list[float]) -- Estimated PSF centre on the sensor in normalised coordinates [x, y].

Raises:

Type Description
AssertionError

If spp < 1,000,000 or the default dtype is not float64.

Source code in deeplens/optics/hybridlens.py
def doe_field(self, point, wvln=DEFAULT_WAVE, spp=SPP_COHERENT):
    """Compute the complex wave field at the DOE plane via coherent ray tracing.

    Similar to ``GeoLens.pupil_field()``, but evaluates the field at the
    last surface (DOE plane) instead of the exit pupil.  The returned
    wavefront encodes amplitude, phase, and all diffraction-order
    information needed for subsequent DOE modulation and ASM propagation.

    Args:
        point (torch.Tensor): Point source position, shape ``(3,)`` or
            ``(1, 3)`` as ``[x, y, z]`` in normalised sensor coordinates
            for x/y and mm for z.
        wvln (float, optional): Wavelength in [um].  Defaults to
            ``DEFAULT_WAVE``.
        spp (int, optional): Number of rays to sample.  Must be
            >= 1,000,000 for accurate coherent simulation.  Defaults to
            ``SPP_COHERENT``.

    Returns:
        tuple:
            - **wavefront** (*torch.Tensor*) -- Complex wavefront at the
              DOE plane, shape ``[H, W]``.
            - **psf_center** (*list[float]*) -- Estimated PSF centre on
              the sensor in normalised coordinates ``[x, y]``.

    Raises:
        AssertionError: If *spp* < 1,000,000 or the default dtype is not
            ``float64``.
    """
    assert spp >= 1_000_000, (
        "Coherent ray tracing spp is too small, "
        "which may lead to inaccurate simulation."
    )
    assert torch.get_default_dtype() == torch.float64, (
        "Default dtype must be set to float64 for accurate phase tracing."
    )

    geolens, doe = self.geolens, self.doe

    if point.dim() == 1:
        point = point.unsqueeze(0)
    point = point.to(self.device)

    # Calculate ray origin in the object space
    scale = geolens.calc_scale(point[:, 2].item())
    point_obj = point.clone()
    point_obj[:, 0] = point[:, 0] * scale * geolens.sensor_size[1] / 2
    point_obj[:, 1] = point[:, 1] * scale * geolens.sensor_size[0] / 2

    # Determine ray center via chief ray
    pointc_chief_ray = geolens.psf_center(point_obj, method="chief_ray")[
        0
    ]  # shape [2]

    # Ray tracing to the DOE plane
    ray = geolens.sample_from_points(points=point_obj, num_rays=spp, wvln=wvln)
    ray.coherent = True
    ray, _ = geolens.trace(ray)
    ray = ray.prop_to(doe.d)

    # Calculate full-resolution complex field for exit-pupil diffraction
    wavefront = forward_integral(
        ray.flip_xy(),
        ps=doe.ps,
        ks=doe.res[0],
        pointc=torch.zeros_like(point[:, :2]),
    ).squeeze(0)  # shape [H, W]

    # Compute PSF center based on chief ray
    psf_center = [
        pointc_chief_ray[0] / geolens.sensor_size[0] * 2,
        pointc_chief_ray[1] / geolens.sensor_size[1] * 2,
    ]

    return wavefront, psf_center
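
A minimal sketch (the file name, wavelength, and ray count are illustrative; spp must stay at or above 1,000,000):

import torch
from deeplens.optics import HybridLens

torch.set_default_dtype(torch.float64)            # the method asserts float64
lens = HybridLens(filename="hybrid_lens.json")    # assumed example config file
point = torch.tensor([0.0, 0.0, -10000.0])        # on-axis point source, z in mm
wavefront, psf_center = lens.doe_field(point, wvln=0.589, spp=1_000_000)
print(wavefront.shape, psf_center)                # complex field at the DOE plane, [H, W]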

psf

psf(points=[0.0, 0.0, -10000.0], ks=PSF_KS, wvln=DEFAULT_WAVE, spp=SPP_COHERENT)

Compute a single-point monochromatic PSF using the ray-wave model.

The returned PSF includes all diffraction orders with physically correct diffraction efficiencies. The pipeline is:

  1. Coherent ray tracing through the GeoLens to obtain the complex wavefront at the DOE plane.
  2. DOE phase modulation applied to the wavefront.
  3. ASM propagation to the sensor, intensity calculation, cropping, and normalisation.

Parameters:

Name Type Description Default
points list or Tensor

[x, y, z] point source coordinates. x, y are in normalised sensor coordinates [-1, 1]; z is depth in [mm]. Defaults to [0.0, 0.0, -10000.0].

[0.0, 0.0, -10000.0]
ks int or None

Output PSF patch size. If None, returns the central quarter of the full-sensor intensity. Defaults to PSF_KS.

PSF_KS
wvln float

Wavelength in [um]. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
spp int

Number of coherent rays to sample. Defaults to SPP_COHERENT.

SPP_COHERENT

Returns:

Type Description

torch.Tensor: Normalised PSF patch (sums to 1), shape [ks, ks]. Returned in float32 precision.

Raises:

Type Description
ValueError

If the default dtype is not float64 (call :meth:double first).

Source code in deeplens/optics/hybridlens.py
def psf(
    self,
    points=[0.0, 0.0, -10000.0],
    ks=PSF_KS,
    wvln=DEFAULT_WAVE,
    spp=SPP_COHERENT,
):
    """Compute a single-point monochromatic PSF using the ray-wave model.

    The returned PSF includes all diffraction orders with physically
    correct diffraction efficiencies.  The pipeline is:

    1. Coherent ray tracing through the ``GeoLens`` to obtain the complex
       wavefront at the DOE plane.
    2. DOE phase modulation applied to the wavefront.
    3. ASM propagation to the sensor, intensity calculation, cropping, and
       normalisation.

    Args:
        points (list or torch.Tensor, optional): ``[x, y, z]`` point
            source coordinates.  *x, y* are in normalised sensor
            coordinates ``[-1, 1]``; *z* is depth in [mm].  Defaults to
            ``[0.0, 0.0, -10000.0]``.
        ks (int or None, optional): Output PSF patch size.  If ``None``,
            returns the central quarter of the full-sensor intensity.
            Defaults to ``PSF_KS``.
        wvln (float, optional): Wavelength in [um].  Defaults to
            ``DEFAULT_WAVE``.
        spp (int, optional): Number of coherent rays to sample.  Defaults
            to ``SPP_COHERENT``.

    Returns:
        torch.Tensor: Normalised PSF patch (sums to 1), shape
            ``[ks, ks]``.  Returned in ``float32`` precision.

    Raises:
        ValueError: If the default dtype is not ``float64`` (call
            :meth:`double` first).
    """
    # Check double precision
    if not torch.get_default_dtype() == torch.float64:
        raise ValueError(
            "Please call HybridLens.double() to set the default dtype to float64 for accurate phase tracing."
        )

    # Check lens last surface
    assert isinstance(self.geolens.surfaces[-1], Phase) or isinstance(
        self.geolens.surfaces[-1], Plane
    ), "The last lens surface should be a DOE."
    geolens, doe = self.geolens, self.doe

    # Compute pupil field by coherent ray tracing
    if isinstance(points, list):
        point0 = torch.tensor(points)
    elif isinstance(points, torch.Tensor):
        point0 = points
    else:
        raise ValueError("point should be a list or a torch.Tensor.")

    wavefront, psfc = self.doe_field(point=point0, wvln=wvln, spp=spp)
    wavefront = wavefront.squeeze(0)  # shape of [H, W]

    # DOE phase modulation. We have to flip the phase map because the wavefront has been flipped
    phase_map = torch.flip(doe.get_phase_map(wvln), [-1, -2])
    wavefront = wavefront * torch.exp(1j * phase_map)

    # Propagate wave field to sensor plane
    h, w = wavefront.shape
    wavefront = F.pad(
        wavefront.unsqueeze(0).unsqueeze(0),
        [h // 2, h // 2, w // 2, w // 2],
        mode="constant",
        value=0,
    )
    sensor_field = AngularSpectrumMethod(
        wavefront, z=geolens.d_sensor - doe.d, wvln=wvln, ps=doe.ps, padding=False
    )

    # Compute PSF (intensity distribution)
    psf_inten = sensor_field.abs() ** 2
    psf_inten = (
        F.interpolate(
            psf_inten,
            scale_factor=geolens.sensor_res[0] / h,
            mode="bilinear",
            align_corners=False,
        )
        .squeeze(0)
        .squeeze(0)
    )

    # Calculate PSF center index and crop valid PSF region (consider both interpolation and padding)
    if ks is not None:
        h, w = psf_inten.shape[-2:]
        psfc_idx_i = ((2 - psfc[1]) * h / 4).round().long()
        psfc_idx_j = ((2 + psfc[0]) * w / 4).round().long()

        # Pad to avoid invalid edge region
        psf_inten_pad = F.pad(
            psf_inten,
            [ks // 2, ks // 2, ks // 2, ks // 2],
            mode="constant",
            value=0,
        )
        psf = psf_inten_pad[
            psfc_idx_i : psfc_idx_i + ks, psfc_idx_j : psfc_idx_j + ks
        ]
    else:
        h, w = psf_inten.shape[-2:]
        psf = psf_inten[
            int(h / 2 - h / 4) : int(h / 2 + h / 4),
            int(w / 2 - w / 4) : int(w / 2 + w / 4),
        ]

    # Normalize and convert to float precision
    psf /= psf.sum()  # shape of [ks, ks] or [h, w]
    return diff_float(psf)
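
A minimal usage sketch, assuming the import path and constructor call shown below (adapt them to the actual HybridLens constructor documented earlier in this reference); the file path, patch size, and wavelength are illustrative.

import torch
from deeplens.optics import HybridLens  # import path assumed

lens = HybridLens(filename="./hybridlens.json")  # constructor call is illustrative
lens.double()  # psf() requires float64 for accurate phase tracing
lens.to("cuda" if torch.cuda.is_available() else "cpu")

# On-axis point source 10 m in front of the lens, 0.589 um wavelength, 101 x 101 patch.
psf = lens.psf(points=[0.0, 0.0, -10000.0], ks=101, wvln=0.589)
print(psf.shape, float(psf.sum()))  # torch.Size([101, 101]), ~1.0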

draw_layout

draw_layout(save_name='./DOELens.png', depth=-10000.0, ax=None, fig=None)

Draw the hybrid-lens layout with ray paths and wave-propagation arcs.

Renders the refractive elements via GeoLens.draw_lens_2d(), traces rays at three field angles (on-axis, 0.707x, 0.99x full field), and overlays concentric arcs between the DOE and sensor to illustrate the wave-propagation region.

Parameters:

Name Type Description Default
save_name str

File path to save the figure (used only when ax is None). Defaults to './DOELens.png'.

'./DOELens.png'
depth float

Object depth [mm] for the traced rays. Defaults to -10000.0.

-10000.0
ax Axes

Pre-existing axes to draw into. If None, a new figure is created and saved.

None
fig Figure

Pre-existing figure. Required when ax is provided.

None

Returns:

Type Description

tuple or None: (ax, fig) when ax was provided; otherwise the figure is saved to save_name and nothing is returned.

Source code in deeplens/optics/hybridlens.py
@torch.no_grad()
def draw_layout(self, save_name="./DOELens.png", depth=-10000.0, ax=None, fig=None):
    """Draw the hybrid-lens layout with ray paths and wave-propagation arcs.

    Renders the refractive elements via ``GeoLens.draw_lens_2d()``, traces
    rays at three field angles (on-axis, 0.707x, 0.99x full field), and
    overlays concentric arcs between the DOE and sensor to illustrate the
    wave-propagation region.

    Args:
        save_name (str, optional): File path to save the figure (used only
            when *ax* is ``None``).  Defaults to ``'./DOELens.png'``.
        depth (float, optional): Object depth [mm] for the traced rays.
            Defaults to ``-10000.0``.
        ax (matplotlib.axes.Axes, optional): Pre-existing axes to draw
            into.  If ``None``, a new figure is created and saved.
        fig (matplotlib.figure.Figure, optional): Pre-existing figure.
            Required when *ax* is provided.

    Returns:
        tuple or None: ``(ax, fig)`` when *ax* was provided; otherwise
            the figure is saved to *save_name* and nothing is returned.
    """
    geolens = self.geolens

    # Draw lens layout
    if ax is None:
        ax, fig = geolens.draw_lens_2d()
        save_fig = True
    else:
        save_fig = False

    # Draw light path
    color_list = ["#CC0000", "#006600", "#0066CC"]
    views = [
        0.0,
        float(np.rad2deg(geolens.rfov) * 0.707),
        float(np.rad2deg(geolens.rfov) * 0.99),
    ]
    arc_radi_list = [0.1, 0.4, 0.7, 1.0, 1.4, 1.8]
    num_rays = 7
    for i, view in enumerate(views):
        # Draw ray tracing
        ray = geolens.sample_point_source_2D(
            depth=depth,
            fov=view,
            num_rays=num_rays,
            entrance_pupil=True,
            wvln=WAVE_RGB[2 - i],
        )
        ray.prop_to(-1.0)

        ray, ray_o_record = geolens.trace(ray=ray, record=True)
        ax, fig = geolens.draw_ray_2d(
            ray_o_record, ax=ax, fig=fig, color=color_list[i]
        )

        # Draw wave propagation
        # Calculate ray center for wave propagation visualization
        ray_center_doe = (
            ((ray.o * ray.is_valid.unsqueeze(-1)).sum(dim=0) / ray.is_valid.sum())
            .cpu()
            .numpy()
        )  # shape [3]
        ray.prop_to(geolens.d_sensor)  # shape [num_rays, 3]
        ray_center_sensor = (
            ((ray.o * ray.is_valid.unsqueeze(-1)).sum(dim=0) / ray.is_valid.sum())
            .cpu()
            .numpy()
        )  # shape [3]

        arc_radi = ray_center_sensor[2] - ray_center_doe[2]
        chief_theta = np.rad2deg(
            np.arctan2(
                ray_center_sensor[0] - ray_center_doe[0],
                ray_center_sensor[2] - ray_center_doe[2],
            )
        )
        theta1 = chief_theta - 10
        theta2 = chief_theta + 10

        for j in arc_radi_list:
            arc_radi_j = arc_radi * j
            arc = patches.Arc(
                (ray_center_sensor[2], ray_center_sensor[0]),
                arc_radi_j,
                arc_radi_j,
                angle=180.0,
                theta1=theta1,
                theta2=theta2,
                color=color_list[i],
            )
            ax.add_patch(arc)

    if save_fig:
        # Save figure
        ax.axis("off")
        ax.set_title("DOE Lens")
        fig.savefig(save_name, bbox_inches="tight", format="png", dpi=600)
        plt.close()
    else:
        return ax, fig
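
A short usage sketch (the save path and depth are placeholders; lens is the HybridLens instance from the sketch above):

# Standalone figure saved to disk.
lens.draw_layout(save_name="./hybrid_layout.png", depth=-5000.0)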

get_optimizer

get_optimizer(doe_lr=0.0001, lens_lr=[0.0001, 0.0001, 0.01, 1e-05], lr_decay=0.01)

Build an Adam optimiser for joint lens + DOE design.

Collects trainable parameters from both the GeoLens (surface thicknesses, curvatures, conic constants, aspheric coefficients) and the DOE phase profile into a single optimiser with per-group learning rates.

Parameters:

Name Type Description Default
doe_lr float

Learning rate for DOE phase parameters. Defaults to 1e-4.

0.0001
lens_lr list[float]

Per-parameter-group learning rates for the GeoLens, ordered as [thickness_d, curvature_c, conic_k, aspheric_a]. Defaults to [1e-4, 1e-4, 1e-2, 1e-5].

[0.0001, 0.0001, 0.01, 1e-05]
lr_decay float

Multiplicative decay applied to higher-order aspheric coefficients. Defaults to 0.01.

0.01

Returns:

Type Description

torch.optim.Adam: Configured optimiser over all trainable parameters.

Source code in deeplens/optics/hybridlens.py
def get_optimizer(
    self, doe_lr=1e-4, lens_lr=[1e-4, 1e-4, 1e-2, 1e-5], lr_decay=0.01
):
    """Build an Adam optimiser for joint lens + DOE design.

    Collects trainable parameters from both the ``GeoLens`` (surface
    thicknesses, curvatures, conic constants, aspheric coefficients) and
    the DOE phase profile into a single optimiser with per-group learning
    rates.

    Args:
        doe_lr (float, optional): Learning rate for DOE phase parameters.
            Defaults to ``1e-4``.
        lens_lr (list[float], optional): Per-parameter-group learning
            rates for the GeoLens, ordered as
            ``[thickness_d, curvature_c, conic_k, aspheric_a]``.
            Defaults to ``[1e-4, 1e-4, 1e-2, 1e-5]``.
        lr_decay (float, optional): Multiplicative decay applied to
            higher-order aspheric coefficients.  Defaults to ``0.01``.

    Returns:
        torch.optim.Adam: Configured optimiser over all trainable
            parameters.
    """
    params = []
    params += self.geolens.get_optimizer_params(lrs=lens_lr, decay=lr_decay)
    params += self.doe.get_optimizer_params(lr=doe_lr)

    optimizer = torch.optim.Adam(params)
    return optimizer
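
The optimiser plugs into a standard PyTorch loop. The sketch below shows one gradient step with an illustrative L1 loss against a target PSF; lens is an existing HybridLens in float64 and target_psf is a placeholder [ks, ks] tensor.

import torch.nn.functional as F

def joint_design_step(lens, optimizer, target_psf):
    """One gradient step of joint lens + DOE optimisation (illustrative loss)."""
    psf = lens.psf(points=[0.0, 0.0, -10000.0], ks=target_psf.shape[-1])
    loss = F.l1_loss(psf, target_psf)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# optimizer = lens.get_optimizer(doe_lr=1e-4, lens_lr=[1e-4, 1e-4, 1e-2, 1e-5])
# loss = joint_design_step(lens, optimizer, target_psf)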

Pure wave-optics lens using diffractive surfaces and scalar diffraction propagation.

deeplens.optics.DiffractiveLens

DiffractiveLens(filename=None, device=None)

Bases: Lens

Paraxial diffractive lens in which each element is modelled as a phase surface.

Every optical element (converging lens, DOE, metasurface, …) is represented by a phase function applied to an incoming complex wavefront. Propagation between surfaces uses the Angular Spectrum Method (ASM). This model is simple and fast, but accurate only in the paraxial regime (it does not account for higher-order geometric aberrations).

Attributes:

Name Type Description
surfaces list

Ordered list of diffractive/phase surfaces.

d_sensor Tensor

Distance from the last surface to the sensor plane [mm].

Notes

Operates in torch.float64 by default for numerical stability of the wave-propagation step.

Initialize a diffractive lens.

Parameters:

Name Type Description Default
filename str

Path to the lens configuration JSON file. If provided, loads the lens configuration from file. Defaults to None.

None
device str

Computation device ('cpu' or 'cuda'). Defaults to 'cpu'.

None
Source code in deeplens/optics/diffraclens.py
def __init__(
    self,
    filename=None,
    device=None,
):
    """Initialize a diffractive lens.

    Args:
        filename (str, optional): Path to the lens configuration JSON file. If provided, loads the lens configuration from file. Defaults to None.
        device (str, optional): Computation device ('cpu' or 'cuda'). Defaults to 'cpu'.
    """
    super().__init__(device=device)

    # Load lens file
    if filename is not None:
        self.read_lens_json(filename)
    else:
        self.surfaces = []
        # Set default sensor size and resolution if no file provided
        self.sensor_size = (8.0, 8.0)
        self.sensor_res = (2000, 2000)

    self.double()
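
For example (the import path is assumed from the module name above and the file path is a placeholder):

from deeplens.optics import DiffractiveLens  # import path assumed

# From a JSON lens description ...
lens = DiffractiveLens(filename="./fresnel_doe.json", device="cuda")

# ... or from the built-in single-Fresnel example (f = 50 mm), see load_example1 below.
lens = DiffractiveLens.load_example1()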

load_example1 classmethod

load_example1()

Create an example diffractive lens with a single Fresnel DOE.

Returns:

Name Type Description
DiffractiveLens

A configured diffractive lens with a Fresnel surface at f=50mm, 4mm size, and 4000 resolution.

Source code in deeplens/optics/diffraclens.py
@classmethod
def load_example1(cls):
    """Create an example diffractive lens with a single Fresnel DOE.

    Returns:
        DiffractiveLens: A configured diffractive lens with a Fresnel surface
            at f=50mm, 4mm size, and 4000 resolution.
    """
    self = cls(sensor_size=(4.0, 4.0), sensor_res=(2000, 2000))

    # Diffractive Fresnel DOE
    self.surfaces = [Fresnel(f0=50, d=0, size=4, res=4000)]

    # Sensor
    self.d_sensor = torch.tensor(50)

    self.to(self.device)
    return self

load_example2 classmethod

load_example2()

Create an example diffractive lens with a thin lens and binary DOE combination.

Returns:

Name Type Description
DiffractiveLens

A configured diffractive lens with a ThinLens (f=50mm) and a Binary2 DOE, both at 4mm size and 4000 resolution.

Source code in deeplens/optics/diffraclens.py
@classmethod
def load_example2(cls):
    """Create an example diffractive lens with a thin lens and binary DOE combination.

    Returns:
        DiffractiveLens: A configured diffractive lens with a ThinLens (f=50mm)
            and a Binary2 DOE, both at 4mm size and 4000 resolution.
    """
    self = cls(sensor_size=(8.0, 8.0), sensor_res=(2000, 2000))

    # Diffractive Fresnel DOE
    self.surfaces = [
        ThinLens(f0=50, d=0, size=4, res=4000),
        Binary2(d=0, size=4, res=4000),
    ]

    # Sensor
    self.d_sensor = torch.tensor(50)
    self.sensor_size = (8.0, 8.0)
    self.sensor_res = (2000, 2000)

    self.to(self.device)
    return self

read_lens_json

read_lens_json(filename)

Load the lens configuration from a JSON file.

Reads lens parameters including sensor configuration and diffractive surfaces from the specified JSON file. If sensor_size or sensor_res are not provided, defaults of 8mm x 8mm and 2000x2000 pixels will be used.

Parameters:

Name Type Description Default
filename str

Path to the JSON configuration file.

required
Source code in deeplens/optics/diffraclens.py
def read_lens_json(self, filename):
    """Load the lens configuration from a JSON file.

    Reads lens parameters including sensor configuration and diffractive surfaces
    from the specified JSON file. If sensor_size or sensor_res are not provided,
    defaults of 8mm x 8mm and 2000x2000 pixels will be used.

    Args:
        filename (str): Path to the JSON configuration file.
    """
    assert filename.endswith(".json"), "File must be a .json file."

    with open(filename, "r") as f:
        # Lens general info
        data = json.load(f)
        self.d_sensor = torch.tensor(data["d_sensor"])
        self.lens_info = data.get("info", "None")

        # Read sensor_size with default
        if "sensor_size" in data:
            self.sensor_size = tuple(data["sensor_size"])
        else:
            self.sensor_size = (8.0, 8.0)
            print(
                f"Sensor_size not found in lens file. Using default: {self.sensor_size} mm. "
                "Consider specifying sensor_size in the lens file or using set_sensor()."
            )

        # Read sensor_res with default
        if "sensor_res" in data:
            self.sensor_res = tuple(data["sensor_res"])
        else:
            self.sensor_res = (2000, 2000)
            print(
                f"Sensor_res not found in lens file. Using default: {self.sensor_res} pixels. "
                "Consider specifying sensor_res in the lens file or using set_sensor()."
            )

        # Load diffractive surfaces/elements
        d = 0.0
        self.surfaces = []
        for surf_dict in data["surfaces"]:
            surf_dict["d"] = d

            if surf_dict["type"].lower() == "binary2":
                s = Binary2.init_from_dict(surf_dict)
            elif surf_dict["type"].lower() == "fresnel":
                s = Fresnel.init_from_dict(surf_dict)
            elif surf_dict["type"].lower() == "pixel2d":
                s = Pixel2D.init_from_dict(surf_dict)
            elif surf_dict["type"].lower() == "thinlens":
                s = ThinLens.init_from_dict(surf_dict)
            elif surf_dict["type"].lower() == "zernike":
                s = Zernike.init_from_dict(surf_dict)
            else:
                raise ValueError(
                    f"Diffractive surface type {surf_dict['type']} not implemented."
                )

            self.surfaces.append(s)
            d_next = surf_dict["d_next"]
            d += d_next
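
As a sketch, a minimal configuration compatible with this reader could be written from Python as below. The surface field names follow the Fresnel constructor used in load_example1 and are illustrative; write_lens_json (next) shows the exact schema the library emits.

import json

lens_dict = {
    "info": "single Fresnel DOE example",
    "d_sensor": 50.0,
    "sensor_size": [4.0, 4.0],
    "sensor_res": [2000, 2000],
    "surfaces": [
        # Field names are illustrative; check write_lens_json() output for the exact schema.
        {"type": "fresnel", "f0": 50.0, "size": 4.0, "res": 4000, "d_next": 0.0}
    ],
}
with open("./fresnel_doe.json", "w") as f:
    json.dump(lens_dict, f, indent=4)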

write_lens_json

write_lens_json(filename)

Write the lens configuration to a JSON file.

Saves all lens parameters including sensor configuration and diffractive surface data to the specified file.

Parameters:

Name Type Description Default
filename str

Output path for the JSON file.

required
Source code in deeplens/optics/diffraclens.py
def write_lens_json(self, filename):
    """Write the lens configuration to a JSON file.

    Saves all lens parameters including sensor configuration and
    diffractive surface data to the specified file.

    Args:
        filename (str): Output path for the JSON file.
    """
    assert filename.endswith(".json"), "File must be a .json file."

    # Save lens to a file
    data = {}
    data["info"] = self.lens_info if hasattr(self, "lens_info") else "None"
    data["surfaces"] = []
    data["d_sensor"] = round(self.d_sensor.item(), 3)
    data["l_sensor"] = round(self.l_sensor, 3)
    data["sensor_res"] = self.sensor_res

    # Save diffractive surfaces
    for i, s in enumerate(self.surfaces):
        surf_dict = {"idx": i + 1}

        if isinstance(s, Pixel2D):
            surf_data = s.surf_dict(filename.replace(".json", "_pixel2d.pth"))
        else:
            surf_data = s.surf_dict()

        surf_dict.update(surf_data)

        if i < len(self.surfaces) - 1:
            surf_dict["d_next"] = (
                self.surfaces[i + 1].d.item() - self.surfaces[i].d.item()
            )

        data["surfaces"].append(surf_dict)

    # Save data to a file
    with open(filename, "w") as f:
        json.dump(data, f, indent=4)

__call__

__call__(wave)

Propagate a wave through the lens system.

Source code in deeplens/optics/diffraclens.py
def __call__(self, wave):
    """Propagate a wave through the lens system."""
    return self.forward(wave)

forward

forward(wave)

Propagate a wave through the diffractive lens system to the sensor.

Sequentially applies phase modulation from each diffractive surface, then propagates the wave to the sensor plane using wave optics.

Parameters:

Name Type Description Default
wave ComplexWave

Input wave field entering the lens system.

required

Returns:

Name Type Description
ComplexWave

Output wave field at the sensor plane.

Source code in deeplens/optics/diffraclens.py
def forward(self, wave):
    """Propagate a wave through the diffractive lens system to the sensor.

    Sequentially applies phase modulation from each diffractive surface, then propagates
    the wave to the sensor plane using wave optics.

    Args:
        wave (ComplexWave): Input wave field entering the lens system.

    Returns:
        ComplexWave: Output wave field at the sensor plane.
    """
    # Propagate to DOE
    for surf in self.surfaces:
        wave = surf(wave)

    # Propagate to sensor
    wave = wave.prop_to(self.d_sensor.item())

    return wave
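
A propagation sketch using the plane-wave sampler that also appears in psf() below; the ComplexWave import path is assumed, and lens is a DiffractiveLens instance (e.g. from load_example1).

from deeplens.optics import ComplexWave  # import path assumed

wave_in = ComplexWave.plane_wave(
    phy_size=[4.0, 4.0], res=[4000, 4000], wvln=0.589, z=0.0
).to(lens.device)

wave_out = lens(wave_in)           # equivalent to lens.forward(wave_in)
intensity = wave_out.u.abs() ** 2  # complex field -> sensor-plane intensity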

render_mono

render_mono(img, wvln=DEFAULT_WAVE, ks=PSF_KS)

Simulate monochromatic lens blur by convolving an image with the point spread function.

Parameters:

Name Type Description Default
img Tensor

Input image. Shape: (B, 1, H, W)

required
wvln float

Wavelength. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
ks int

PSF kernel size. Defaults to PSF_KS.

PSF_KS

Returns:

Type Description

torch.Tensor: Rendered image after applying lens blur with shape (B, 1, H, W).

Source code in deeplens/optics/diffraclens.py
def render_mono(self, img, wvln=DEFAULT_WAVE, ks=PSF_KS):
    """Simulate monochromatic lens blur by convolving an image with the point spread function.

    Args:
        img (torch.Tensor): Input image. Shape: (B, 1, H, W)
        wvln (float, optional): Wavelength. Defaults to DEFAULT_WAVE.
        ks (int, optional): PSF kernel size. Defaults to PSF_KS.

    Returns:
        torch.Tensor: Rendered image after applying lens blur with shape (B, 1, H, W).
    """
    psf = self.psf_infinite(wvln=wvln, ks=ks).unsqueeze(0)  # (1, ks, ks)
    img_render = conv_psf(img, psf)
    return img_render
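
For instance (the input image is synthetic and the wavelength and kernel size are illustrative; lens is a DiffractiveLens instance):

import torch

img = torch.rand(1, 1, 512, 512, device=lens.device)  # [B, 1, H, W]
img_blur = lens.render_mono(img, wvln=0.55, ks=101)   # [B, 1, H, W]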

psf

psf(depth=float('inf'), wvln=DEFAULT_WAVE, ks=PSF_KS, upsample_factor=1)

Calculate monochromatic point PSF by wave propagation approach.

Parameters:

Name Type Description Default
depth float

Depth of the point source. Defaults to float('inf').

float('inf')
wvln float

Wavelength in micrometers. Defaults to DEFAULT_WAVE.

DEFAULT_WAVE
ks int

PSF kernel size. Defaults to PSF_KS.

PSF_KS
upsample_factor int

Upsampling factor to meet Nyquist sampling constraint. Defaults to 1.

1

Returns:

Name Type Description
psf_out tensor

PSF. shape [ks, ks]

Note

[1] Usually only the on-axis PSF is considered because the paraxial approximation is implicit in the wave-optics model. For the shifted-phase issue, refer to "Modeling off-axis diffraction with the least-sampling angular spectrum method".

Source code in deeplens/optics/diffraclens.py
def psf(self, depth=float("inf"), wvln=DEFAULT_WAVE, ks=PSF_KS, upsample_factor=1):
    """Calculate monochromatic point PSF by wave propagation approach.

    Args:
        depth (float, optional): Depth of the point source. Defaults to float('inf').
        wvln (float, optional): Wavelength in micrometers. Defaults to DEFAULT_WAVE.
        ks (int, optional): PSF kernel size. Defaults to PSF_KS.
        upsample_factor (int, optional): Upsampling factor to meet Nyquist sampling constraint. Defaults to 1.

    Returns:
        psf_out (tensor): PSF. shape [ks, ks]

    Note:
        [1] Usually only the on-axis PSF is considered because the paraxial approximation is implicit in the wave-optics model. For the shifted-phase issue, refer to "Modeling off-axis diffraction with the least-sampling angular spectrum method".
    """
    # Sample input wave field (We have to sample high resolution to meet Nyquist sampling constraint)
    field_res = [
        self.surfaces[0].res[0] * upsample_factor,
        self.surfaces[0].res[1] * upsample_factor,
    ]
    field_size = [
        self.surfaces[0].res[0] * self.surfaces[0].ps,
        self.surfaces[0].res[1] * self.surfaces[0].ps,
    ]
    if depth == float("inf"):
        inp_wave = ComplexWave.plane_wave(
            phy_size=field_size,
            res=field_res,
            wvln=wvln,
            z=0.0,
        ).to(self.device)
    else:
        inp_wave = ComplexWave.point_wave(
            point=[0.0, 0.0, depth],
            phy_size=field_size,
            res=field_res,
            wvln=wvln,
            z=0.0,
        ).to(self.device)

    # Calculate intensity on the sensor. Shape [H_sensor, W_sensor]
    output_wave = self.forward(inp_wave)
    intensity = output_wave.u.abs() ** 2

    # Interpolate wave to have the same pixel size as the sensor
    factor = output_wave.ps / self.pixel_size
    intensity = F.interpolate(
        intensity,
        scale_factor=(factor, factor),
        mode="bilinear",
        align_corners=False,
    )[0, 0, :, :]

    # Crop or pad wave to the sensor resolution
    intensity_h, intensity_w = intensity.shape[-2:]
    sensor_h, sensor_w = self.sensor_res
    if sensor_h < intensity_h or sensor_w < intensity_w:
        # crop
        start_h = (intensity_h - sensor_h) // 2
        start_w = (intensity_w - sensor_w) // 2
        intensity = intensity[
            start_h : start_h + sensor_h, start_w : start_w + sensor_w
        ]
    elif sensor_h > intensity_h or sensor_w > intensity_w:
        # pad
        pad_top = (sensor_h - intensity_h) // 2
        pad_bottom = sensor_h - intensity_h - pad_top
        pad_left = (sensor_w - intensity_w) // 2
        pad_right = sensor_w - intensity_w - pad_left
        intensity = F.pad(
            intensity,
            (pad_left, pad_right, pad_top, pad_bottom),
            mode="constant",
            value=0,
        )

    # Crop the valid patch from the full-resolution intensity map as the PSF
    coord_c_i = int(self.sensor_res[1] / 2)
    coord_c_j = int(self.sensor_res[0] / 2)
    intensity = F.pad(
        intensity,
        [ks // 2, ks // 2, ks // 2, ks // 2],
        mode="constant",
        value=0,
    )
    psf = intensity[coord_c_i : coord_c_i + ks, coord_c_j : coord_c_j + ks]

    # Normalize PSF
    psf /= psf.sum()
    psf = torch.flip(psf, [0, 1])

    return diff_float(psf)
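
For example, the chromatic behaviour of the single-Fresnel example lens can be probed by sweeping the wavelength (values in micrometers):

from deeplens.optics import DiffractiveLens  # import path assumed

lens = DiffractiveLens.load_example1()
for wvln in (0.656, 0.589, 0.486):
    psf = lens.psf(depth=float("inf"), wvln=wvln, ks=101)
    print(wvln, tuple(psf.shape), float(psf.sum()))  # (101, 101), sums to ~1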

draw_layout

draw_layout(save_name='./doelens.png')

Draw the lens layout diagram.

Visualizes the DOE and sensor positions in a 2D layout.

Parameters:

Name Type Description Default
save_name str

Path to save the figure. Defaults to './doelens.png'.

'./doelens.png'
Source code in deeplens/optics/diffraclens.py
def draw_layout(self, save_name="./doelens.png"):
    """Draw the lens layout diagram.

    Visualizes the DOE and sensor positions in a 2D layout.

    Args:
        save_name (str, optional): Path to save the figure. Defaults to './doelens.png'.
    """
    fig, ax = plt.subplots()

    # Draw DOE
    d = self.doe.d.item()
    doe_l = self.doe.l
    ax.plot(
        [d, d], [-doe_l / 2, doe_l / 2], "orange", linestyle="--", dashes=[1, 1]
    )

    # Draw sensor
    d = self.sensor.d.item()
    sensor_l = self.sensor.l
    width = 0.2  # Width of the rectangle
    rect = plt.Rectangle(
        (d - width / 2, -sensor_l / 2),
        width,
        sensor_l,
        facecolor="none",
        edgecolor="black",
        linewidth=1,
    )
    ax.add_patch(rect)

    ax.set_aspect("equal")
    ax.axis("off")
    fig.savefig(save_name, dpi=600, bbox_inches="tight")
    plt.close(fig)

draw_psf

draw_psf(depth=DEPTH, ks=PSF_KS, save_name='./psf_doelens.png', log_scale=True, eps=0.0001)

Draw on-axis RGB PSF.

Computes and saves a visualization of the RGB PSF for a given depth.

Parameters:

Name Type Description Default
depth float

Depth of the point source. Defaults to DEPTH.

DEPTH
ks int

Size of the PSF kernel in pixels. Defaults to PSF_KS.

PSF_KS
save_name str

Path to save the PSF image. Defaults to './psf_doelens.png'.

'./psf_doelens.png'
log_scale bool

If True, display PSF in log scale. Defaults to True.

True
eps float

Small value for log scale to avoid log(0). Defaults to 1e-4.

0.0001
Source code in deeplens/optics/diffraclens.py
def draw_psf(
    self,
    depth=DEPTH,
    ks=PSF_KS,
    save_name="./psf_doelens.png",
    log_scale=True,
    eps=1e-4,
):
    """Draw on-axis RGB PSF.

    Computes and saves a visualization of the RGB PSF for a given depth.

    Args:
        depth (float, optional): Depth of the point source. Defaults to DEPTH.
        ks (int, optional): Size of the PSF kernel in pixels. Defaults to PSF_KS.
        save_name (str, optional): Path to save the PSF image. Defaults to './psf_doelens.png'.
        log_scale (bool, optional): If True, display PSF in log scale. Defaults to True.
        eps (float, optional): Small value for log scale to avoid log(0). Defaults to 1e-4.
    """
    psf_rgb = self.psf_rgb(point=[0, 0, depth], ks=ks)

    if log_scale:
        psf_rgb = torch.log10(psf_rgb + eps)
        psf_rgb = (psf_rgb - psf_rgb.min()) / (psf_rgb.max() - psf_rgb.min())
        save_name = save_name.replace(".png", "_log.png")

    save_image(psf_rgb.unsqueeze(0), save_name, normalize=True)

get_optimizer

get_optimizer(lr)

Get optimizer for the lens parameters.

Parameters:

Name Type Description Default
lr float

Learning rate.

required

Returns:

Name Type Description
Optimizer

Optimizer object for lens parameters.

Source code in deeplens/optics/diffraclens.py
def get_optimizer(self, lr):
    """Get optimizer for the lens parameters.

    Args:
        lr (float): Learning rate.

    Returns:
        Optimizer: Optimizer object for lens parameters.
    """
    return self.doe.get_optimizer(lr=lr)

Thin-lens / circle-of-confusion model for simple depth-of-field and bokeh simulation.

deeplens.optics.ParaxialLens

ParaxialLens(foclen, fnum, sensor_size=None, sensor_res=None, device='cpu')

Bases: Lens

Thin-lens / ABCD-matrix model for fast defocus simulation.

Models the circle of confusion (CoC) caused by defocus but not higher-order optical aberrations. Useful as a fast baseline renderer for depth-of-field effects, as commonly used in Blender and similar tools.

Attributes:

Name Type Description
foclen float

Focal length [mm].

fnum float

F-number.

sensor_size tuple

Physical sensor size (W, H) [mm].

sensor_res tuple

Pixel resolution (W, H).

pixel_size float

Pixel pitch [mm].

Initialize a paraxial lens.

Parameters:

Name Type Description Default
foclen float

Focal length in [mm].

required
fnum float

F-number.

required
sensor_size tuple

Physical sensor size as (W, H) in [mm]. Defaults to (8.0, 8.0).

None
sensor_res tuple

Sensor resolution as (W, H) in pixels. Defaults to (2000, 2000).

None
device str

Computation device. Defaults to "cpu".

'cpu'
Source code in deeplens/optics/paraxiallens.py
def __init__(self, foclen, fnum, sensor_size=None, sensor_res=None, device="cpu"):
    """Initialize a paraxial lens.

    Args:
        foclen (float): Focal length in [mm].
        fnum (float): F-number.
        sensor_size (tuple, optional): Physical sensor size as (W, H) in [mm]. Defaults to (8.0, 8.0).
        sensor_res (tuple, optional): Sensor resolution as (W, H) in pixels. Defaults to (2000, 2000).
        device (str, optional): Computation device. Defaults to "cpu".
    """
    super(ParaxialLens, self).__init__(device=device)

    # Lens parameters
    self.foclen = foclen  # Focal length [mm]
    self.fnum = fnum

    # Sensor size and resolution with defaults
    if sensor_size is None:
        sensor_size = (8.0, 8.0)
        print(
            f"Sensor_size not provided. Using default: {sensor_size} mm. "
            "Use set_sensor() to change."
        )
    if sensor_res is None:
        sensor_res = (2000, 2000)
        print(
            f"Sensor_res not provided. Using default: {sensor_res} pixels. "
            "Use set_sensor() to change."
        )

    self.sensor_size = sensor_size
    self.sensor_res = sensor_res
    self.pixel_size = self.sensor_size[0] / self.sensor_res[0]  # Pixel size [mm]

    self.d_far = -20000.0
    self.d_close = -200.0
    self.refocus(foc_dist=-20000)
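
A construction sketch (all values illustrative; the import path is assumed from the module name above):

from deeplens.optics import ParaxialLens  # import path assumed

lens = ParaxialLens(foclen=50.0, fnum=2.0, sensor_size=(8.0, 8.0), sensor_res=(2000, 2000))
lens.refocus(foc_dist=-2000.0)  # focus at 2 m in front of the lens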

refocus

refocus(foc_dist)

Refocus the lens to a given object distance.

Parameters:

Name Type Description Default
foc_dist float

Focus distance in [mm]. Must be less than the focal length (i.e. beyond the focal point).

required

Raises:

Type Description
AssertionError

If foc_dist >= self.foclen.

Source code in deeplens/optics/paraxiallens.py
def refocus(self, foc_dist):
    """Refocus the lens to a given object distance.

    Args:
        foc_dist (float): Focus distance in [mm].  Must be less than the
            focal length (i.e. beyond the focal point).

    Raises:
        AssertionError: If *foc_dist* >= ``self.foclen``.
    """
    assert foc_dist < self.foclen, "Focus distance is too close."
    self.foc_dist = foc_dist

psf

psf(points, ks=PSF_KS, psf_type='gaussian', **kwargs)

The PSF is built from the circle of confusion (CoC): either a Gaussian whose standard deviation equals the CoC radius, or a uniform circular disk (pillbox) of diameter CoC.

Parameters:

Name Type Description Default
points Tensor

Points of the object. Shape [N, 3] or [3].

required
ks int

Kernel size.

PSF_KS
psf_type str

PSF type. "gaussian" or "pillbox".

'gaussian'
**kwargs

Additional arguments for psf(). Currently not used.

{}

Returns:

Name Type Description
psf Tensor

PSF kernels. Shape [ks, ks] or [N, ks, ks].

Source code in deeplens/optics/paraxiallens.py
def psf(self, points, ks=PSF_KS, psf_type="gaussian", **kwargs):
    """PSF is modeled as a 2D uniform circular disk with diameter CoC.

    Args:
        points (torch.Tensor): Points of the object. Shape [N, 3] or [3].
        ks (int): Kernel size.
        psf_type (str): PSF type. "gaussian" or "pillbox".
        **kwargs: Additional arguments for psf(). Currently not used.

    Returns:
        psf (torch.Tensor): PSF kernels. Shape [ks, ks] or [N, ks, ks].
    """
    points = points.to(self.device)

    # Handle single point vs multiple points
    if len(points.shape) == 1:
        points = points.unsqueeze(0)
        single_point = True
    else:
        single_point = False

    # Calculate circle of confusion for each point
    depths = points[:, 2]  # Shape [N]
    coc_values = self.coc(depths)  # Shape [N]

    # Convert CoC from mm to pixels and add minimum value for numerical stability
    coc_pixel = torch.clamp(
        coc_values / self.pixel_size, min=0.5
    )  # Shape [N], minimum 0.5 pixels
    coc_pixel = (
        coc_pixel.unsqueeze(-1).unsqueeze(-1).repeat(1, ks, ks)
    )  # Shape [N, ks, ks]
    coc_pixel_radius = coc_pixel / 2

    # Create coordinate meshgrid
    x, y = torch.meshgrid(
        torch.linspace(-ks / 2 + 1 / 2, ks / 2 - 1 / 2, ks),
        torch.linspace(-ks / 2 + 1 / 2, ks / 2 - 1 / 2, ks),
        indexing="xy",
    )
    x, y = x.to(self.device), y.to(self.device)
    distance_sq = x**2 + y**2

    # Create PSF
    if psf_type == "gaussian":
        # Gaussian PSF
        psf = torch.exp(-distance_sq / (2 * coc_pixel_radius**2)) / (
            2 * np.pi * coc_pixel_radius**2
        )
    elif psf_type == "pillbox":
        # Pillbox PSF
        psf = torch.ones_like(x)
    else:
        raise ValueError(f"Invalid PSF type: {psf_type}")

    # Apply circular mask
    psf_mask = distance_sq < coc_pixel_radius**2
    psf = psf * psf_mask

    # Normalize PSF to sum to 1
    psf = psf / (psf.sum(dim=(-1, -2), keepdim=True) + EPSILON)

    if single_point:
        psf = psf.squeeze(0)

    return psf
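
Continuing the sketch above (f = 50 mm, f/2, focused at 2 m), PSFs for an in-focus and a defocused point:

import torch

points = torch.tensor(
    [[0.0, 0.0, -2000.0],   # at the focus distance -> near-delta PSF
     [0.0, 0.0, -8000.0]]   # farther than the focus distance -> large blur
)
psfs = lens.psf(points, ks=51, psf_type="gaussian")  # shape [2, 51, 51]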

coc

coc(depth)

Calculate circle of confusion (CoC) [mm].

Parameters:

Name Type Description Default
depth Tensor

Depth of the object. Shape [B].

required

Returns:

Name Type Description
coc Tensor

Circle of confusion. Shape [B].

Reference

[1] https://en.wikipedia.org/wiki/Circle_of_confusion

Source code in deeplens/optics/paraxiallens.py
def coc(self, depth):
    """Calculate circle of confusion (CoC) [mm].

    Args:
        depth (torch.Tensor): Depth of the object. Shape [B].

    Returns:
        coc (torch.Tensor): Circle of confusion. Shape [B].

    Reference:
        [1] https://en.wikipedia.org/wiki/Circle_of_confusion
    """
    foc_dist = torch.tensor(
        self.foc_dist, device=self.device, dtype=depth.dtype
    ).abs()
    foclen = self.foclen
    fnum = self.fnum

    depth = torch.clamp(depth, self.d_far, self.d_close)
    depth = torch.abs(depth)

    # Calculate circle of confusion diameter, [mm]
    part1 = torch.abs(depth - foc_dist) / depth
    part2 = foclen**2 / (fnum * (foc_dist - foclen))
    coc = part1 * part2

    return coc
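
A worked example with the lens sketched above (f = 50 mm, f/2, focused at 2 m): for an object 5 m away, CoC = |5000 - 2000| / 5000 * 50^2 / (2 * (2000 - 50)) ≈ 0.38 mm.

import torch

coc = lens.coc(torch.tensor([-5000.0, -2000.0]))
print(coc)  # ~[0.385, 0.000] mm: blurred at 5 m, sharp at the focus distance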

dof

dof(depth)

Calculate depth of field [mm].

Parameters:

Name Type Description Default
depth Tensor

Depth of the object. Shape [B].

required

Returns:

Name Type Description
dof Tensor

Depth of field. Shape [B].

Reference

[1] https://en.wikipedia.org/wiki/Depth_of_field

Source code in deeplens/optics/paraxiallens.py
def dof(self, depth):
    """Calculate depth of field [mm].

    Args:
        depth (torch.Tensor): Depth of the object. Shape [B].

    Returns:
        dof (torch.Tensor): Depth of field. Shape [B].

    Reference:
        [1] https://en.wikipedia.org/wiki/Depth_of_field
    """
    depth = torch.clamp(depth, self.d_far, self.d_close)
    depth_abs = torch.abs(depth)

    foclen = self.foclen
    fnum = self.fnum

    # Magnification factor
    m = foclen / (depth_abs - foclen)

    # CoC, [mm]
    coc = self.coc(depth)

    # Depth of field, [mm]
    part1 = 2 * fnum * coc * (m + 1)
    part2 = m**2 - (fnum * coc / foclen) ** 2
    dof = part1 / part2

    return dof

psf_rgb

psf_rgb(points, ks=PSF_KS, **kwargs)

Compute RGB PSF by replicating the monochrome PSF across three channels.

The paraxial model is achromatic, so all channels share the same PSF.

Parameters:

Name Type Description Default
points Tensor

Point source positions, shape [N, 3].

required
ks int

Kernel size. Defaults to PSF_KS.

PSF_KS
**kwargs

Forwarded to :meth:psf.

{}

Returns:

Type Description

torch.Tensor: RGB PSFs, shape [N, 3, ks, ks].

Source code in deeplens/optics/paraxiallens.py
def psf_rgb(self, points, ks=PSF_KS, **kwargs):
    """Compute RGB PSF by replicating the monochrome PSF across three channels.

    The paraxial model is achromatic, so all channels share the same PSF.

    Args:
        points (torch.Tensor): Point source positions, shape ``[N, 3]``.
        ks (int, optional): Kernel size. Defaults to ``PSF_KS``.
        **kwargs: Forwarded to :meth:`psf`.

    Returns:
        torch.Tensor: RGB PSFs, shape ``[N, 3, ks, ks]``.
    """
    psf = self.psf(points, ks=ks, psf_type="gaussian", **kwargs)
    return psf.unsqueeze(1).repeat(1, 3, 1, 1)

psf_map

psf_map(grid=(5, 5), ks=PSF_KS, depth=DEPTH, **kwargs)

Compute a spatially-uniform monochrome PSF map.

Because the paraxial model has no spatially-varying aberrations, every grid position receives the same on-axis PSF.

Parameters:

Name Type Description Default
grid tuple

Grid dimensions (rows, cols). Defaults to (5, 5).

(5, 5)
ks int

Kernel size. Defaults to PSF_KS.

PSF_KS
depth float

Object depth [mm]. Defaults to DEPTH.

DEPTH
**kwargs

Forwarded to :meth:psf.

{}

Returns:

Type Description

torch.Tensor: PSF map, shape [rows, cols, 1, ks, ks].

Source code in deeplens/optics/paraxiallens.py
def psf_map(self, grid=(5, 5), ks=PSF_KS, depth=DEPTH, **kwargs):
    """Compute a spatially-uniform monochrome PSF map.

    Because the paraxial model has no spatially-varying aberrations, every
    grid position receives the same on-axis PSF.

    Args:
        grid (tuple, optional): Grid dimensions ``(rows, cols)``.
            Defaults to ``(5, 5)``.
        ks (int, optional): Kernel size. Defaults to ``PSF_KS``.
        depth (float, optional): Object depth [mm]. Defaults to ``DEPTH``.
        **kwargs: Forwarded to :meth:`psf`.

    Returns:
        torch.Tensor: PSF map, shape ``[rows, cols, 1, ks, ks]``.
    """
    points = torch.tensor([[0, 0, depth]], device=self.device)
    psf = self.psf(points=points, ks=ks, psf_type="gaussian", **kwargs)
    psf_map = psf.unsqueeze(0).unsqueeze(0).repeat(grid[0], grid[1], 1, 1, 1)
    return psf_map

psf_dp

psf_dp(points, ks=PSF_KS)

Generate dual-pixel PSF for left and right sub-apertures.

This function generates separate PSFs for left and right sub-apertures of a dual pixel sensor, which enables depth estimation and improved autofocus capabilities.

Parameters:

Name Type Description Default
points Tensor

Input tensor with shape [N, 3], where columns are [x, y, z] coordinates.

required
ks int

Kernel size for PSF generation.

PSF_KS

Returns:

Name Type Description
tuple

(left_psf, right_psf) where each PSF tensor has shape [N, ks, ks].

Source code in deeplens/optics/paraxiallens.py
def psf_dp(self, points, ks=PSF_KS):
    """Generate dual-pixel PSF for left and right sub-apertures.

    This function generates separate PSFs for left and right sub-apertures of a dual pixel sensor,
    which enables depth estimation and improved autofocus capabilities.

    Args:
        points (torch.Tensor): Input tensor with shape [N, 3], where columns are [x, y, z] coordinates.
        ks (int): Kernel size for PSF generation.

    Returns:
        tuple: (left_psf, right_psf) where each PSF tensor has shape [N, ks, ks].
    """
    N = points.shape[0]
    depth = points[:, 2]

    # Get the base PSF
    psf_base = self.psf(points, ks=ks, psf_type="gaussian")
    device = psf_base.device

    # Create left and right masks for dual pixel simulation
    l_mask = torch.ones((ks, ks), device=device)
    r_mask = torch.ones((ks, ks), device=device)

    # Split aperture vertically (left half and right half)
    l_pixel, r_pixel = ks // 2, ks // 2 + 1
    l_mask[:, 0:l_pixel] = 0  # Block right side for left PSF
    r_mask[:, r_pixel:] = 0  # Block left side for right PSF

    # Expand masks to match batch dimension [N, ks, ks]
    l_mask = l_mask.unsqueeze(0).repeat(N, 1, 1)
    r_mask = r_mask.unsqueeze(0).repeat(N, 1, 1)

    # Determine focus positions
    depth = depth.to(device)
    foc_dist = torch.tensor(self.foc_dist, device=device, dtype=depth.dtype)
    near_focus_pos = depth > foc_dist  # Shape [N]

    # Create left and right PSFs from base PSF
    psf_l = psf_base.clone()
    psf_r = psf_base.clone()

    # Apply masks based on focus position (this creates the depth-dependent asymmetry)
    # For near focus: left PSF gets left mask, right PSF gets right mask
    # For far focus: masks are swapped to create opposite asymmetry
    for i in range(N):
        if near_focus_pos[i]:
            psf_l[i] = psf_l[i] * l_mask[i]
            psf_r[i] = psf_r[i] * r_mask[i]
        else:
            psf_l[i] = psf_l[i] * r_mask[i]  # Swap masks for far focus
            psf_r[i] = psf_r[i] * l_mask[i]

    # Normalize PSFs
    psf_l = psf_l / (psf_l.sum(dim=(-1, -2), keepdim=True) + EPSILON)
    psf_r = psf_r / (psf_r.sum(dim=(-1, -2), keepdim=True) + EPSILON)

    return psf_l, psf_r
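
A short sketch continuing the paraxial example above; psf_rgb_dp() below replicates the same pair across three colour channels.

import torch

points = torch.tensor([[0.0, 0.0, -8000.0]])  # farther than the focus distance
psf_l, psf_r = lens.psf_dp(points, ks=51)     # each of shape [1, 51, 51]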

psf_rgb_dp

psf_rgb_dp(points, ks=PSF_KS)

Compute RGB dual-pixel PSFs for left and right sub-apertures.

Replicates the monochrome dual-pixel PSFs across three colour channels.

Parameters:

Name Type Description Default
points Tensor

Point source positions, shape [N, 3].

required
ks int

Kernel size. Defaults to PSF_KS.

PSF_KS

Returns:

Name Type Description
tuple

(psf_left, psf_right) each of shape [N, 3, ks, ks].

Source code in deeplens/optics/paraxiallens.py
def psf_rgb_dp(self, points, ks=PSF_KS):
    """Compute RGB dual-pixel PSFs for left and right sub-apertures.

    Replicates the monochrome dual-pixel PSFs across three colour channels.

    Args:
        points (torch.Tensor): Point source positions, shape ``[N, 3]``.
        ks (int, optional): Kernel size. Defaults to ``PSF_KS``.

    Returns:
        tuple: ``(psf_left, psf_right)`` each of shape ``[N, 3, ks, ks]``.
    """
    psf_l, psf_r = self.psf_dp(points, ks=ks)
    psf_l = psf_l.unsqueeze(1).repeat(1, 3, 1, 1)
    psf_r = psf_r.unsqueeze(1).repeat(1, 3, 1, 1)
    return psf_l, psf_r

psf_map_dp

psf_map_dp(grid=(5, 5), ks=PSF_KS, depth=DEPTH, **kwargs)

Compute spatially-uniform dual-pixel PSF maps.

Parameters:

Name Type Description Default
grid tuple

Grid dimensions (rows, cols). Defaults to (5, 5).

(5, 5)
ks int

Kernel size. Defaults to PSF_KS.

PSF_KS
depth float

Object depth [mm]. Defaults to DEPTH.

DEPTH
**kwargs

Forwarded to :meth:psf_dp.

{}

Returns:

Name Type Description
tuple

(psf_map_left, psf_map_right) each of shape [rows, cols, 1, ks, ks].

Source code in deeplens/optics/paraxiallens.py
def psf_map_dp(self, grid=(5, 5), ks=PSF_KS, depth=DEPTH, **kwargs):
    """Compute spatially-uniform dual-pixel PSF maps.

    Args:
        grid (tuple, optional): Grid dimensions ``(rows, cols)``.
            Defaults to ``(5, 5)``.
        ks (int, optional): Kernel size. Defaults to ``PSF_KS``.
        depth (float, optional): Object depth [mm]. Defaults to ``DEPTH``.
        **kwargs: Forwarded to :meth:`psf_dp`.

    Returns:
        tuple: ``(psf_map_left, psf_map_right)`` each of shape
            ``[rows, cols, 1, ks, ks]``.
    """
    points = torch.tensor([[0, 0, depth]], device=self.device)
    psf_l, psf_r = self.psf_dp(points, ks=ks, **kwargs)
    psf_map_l = psf_l.unsqueeze(0).unsqueeze(0).repeat(grid[0], grid[1], 1, 1, 1)
    psf_map_r = psf_r.unsqueeze(0).unsqueeze(0).repeat(grid[0], grid[1], 1, 1, 1)
    return psf_map_l, psf_map_r

render_rgbd

render_rgbd(img_obj, depth_map, method='psf_patch', **kwargs)

Occlusion-aware RGBD rendering for paraxial lens.

Uses back-to-front layered compositing to prevent color bleeding at depth discontinuities. Since paraxial lenses have no spatially varying aberrations, all methods (psf_patch, psf_map, psf_pixel) produce identical results; the method parameter is accepted for API compatibility but ignored.

Parameters:

Name Type Description Default
img_obj tensor

Object image. Shape [B, C, H, W].

required
depth_map tensor

Depth map [mm]. Shape [B, 1, H, W]. Values should be positive.

required
method str

Ignored (no spatial variation). Defaults to "psf_patch".

'psf_patch'
**kwargs

Additional keyword arguments:
- psf_ks (int): PSF kernel size. Defaults to PSF_KS.
- num_layers (int): Number of depth layers. Defaults to 16.
- depth_min (float): Minimum depth. Defaults to depth_map.min().
- depth_max (float): Maximum depth. Defaults to depth_map.max().

{}

Returns:

Name Type Description
img_render tensor

Rendered image. Shape [B, C, H, W].

Reference

[1] "Dr.Bokeh: DiffeRentiable Occlusion-aware Bokeh Rendering", CVPR 2024.

Source code in deeplens/optics/paraxiallens.py
def render_rgbd(self, img_obj, depth_map, method="psf_patch", **kwargs):
    """Occlusion-aware RGBD rendering for paraxial lens.

    Uses back-to-front layered compositing to prevent color bleeding at depth
    discontinuities. Since paraxial lenses have no spatially varying
    aberrations, all methods (psf_patch, psf_map, psf_pixel) produce
    identical results; the `method` parameter is accepted for API
    compatibility but ignored.

    Args:
        img_obj (tensor): Object image. Shape [B, C, H, W].
        depth_map (tensor): Depth map [mm]. Shape [B, 1, H, W]. Values should be positive.
        method (str, optional): Ignored (no spatial variation). Defaults to "psf_patch".
        **kwargs: Additional keyword arguments:
            - psf_ks (int): PSF kernel size. Defaults to PSF_KS.
            - num_layers (int): Number of depth layers. Defaults to 16.
            - depth_min (float): Minimum depth. Defaults to depth_map.min().
            - depth_max (float): Maximum depth. Defaults to depth_map.max().

    Returns:
        img_render (tensor): Rendered image. Shape [B, C, H, W].

    Reference:
        [1] "Dr.Bokeh: DiffeRentiable Occlusion-aware Bokeh Rendering", CVPR 2024.
    """
    if depth_map.min() < 0:
        raise ValueError("Depth map should be positive.")

    if len(depth_map.shape) == 3:
        depth_map = depth_map.unsqueeze(1)  # [B, H, W] -> [B, 1, H, W]

    psf_ks = kwargs.get("psf_ks", PSF_KS)
    num_layers = kwargs.get("num_layers", 16)
    depth_min = kwargs.get("depth_min", depth_map.min())
    depth_max = kwargs.get("depth_max", depth_map.max())

    # Sample depth layers
    disp_ref, depths_ref = self._sample_depth_layers(depth_min, depth_max, num_layers)

    # Compute PSF at each depth layer (spatially invariant, so patch_center=(0,0))
    points = torch.stack(
        [
            torch.zeros_like(depths_ref),
            torch.zeros_like(depths_ref),
            depths_ref,
        ],
        dim=-1,
    )
    psfs = self.psf_rgb(points=points, ks=psf_ks)  # [num_layers, 3, ks, ks]

    # Occlusion-aware rendering
    img_render = conv_psf_occlusion(img_obj, -depth_map, psfs, depths_ref)
    return img_render
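
A rendering sketch with a synthetic scene; depth values are positive (as required by the check at the top of the method) and all sizes are illustrative. lens is the ParaxialLens from the construction sketch above.

import torch

img = torch.rand(1, 3, 256, 256)                           # [B, 3, H, W]
depth_map = 1000.0 + 4000.0 * torch.rand(1, 1, 256, 256)   # 1 m .. 5 m, positive [mm]
img_bokeh = lens.render_rgbd(img, depth_map, psf_ks=101, num_layers=8)  # [B, 3, H, W]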

render_rgbd_dp

render_rgbd_dp(rgb_img, depth)

Render RGBD image with dual-pixel PSF.

Parameters:

Name Type Description Default
rgb_img tensor

[B, 3, H, W]

required
depth tensor

[B, 1, H, W]

required

Returns:

Name Type Description
img_left tensor

[B, 3, H, W]

img_right tensor

[B, 3, H, W]

Source code in deeplens/optics/paraxiallens.py
def render_rgbd_dp(self, rgb_img, depth):
    """Render RGBD image with dual-pixel PSF.

    Args:
        rgb_img (tensor): [B, 3, H, W]
        depth (tensor): [B, 1, H, W]

    Returns:
        img_left (tensor): [B, 3, H, W]
        img_right (tensor): [B, 3, H, W]
    """
    # Convert depth to negative values
    if (depth > 0).any():
        depth = -depth

    depth_min = depth.min()
    depth_max = depth.max()
    num_depth = 10
    patch_center = (0.0, 0.0)
    psf_ks = PSF_KS

    # Calculate dual-pixel PSF at reference depths
    depths_ref = torch.linspace(depth_min, depth_max, num_depth).to(self.device)
    points = torch.stack(
        [
            torch.full_like(depths_ref, patch_center[0]),
            torch.full_like(depths_ref, patch_center[1]),
            depths_ref,
        ],
        dim=-1,
    )
    psfs_left, psfs_right = self.psf_rgb_dp(
        points=points, ks=psf_ks
    )  # shape [num_depth, 3, ks, ks]

    # Render dual-pixel image with PSF convolution and depth interpolation
    img_left = conv_psf_depth_interp(rgb_img, depth, psfs_left, depths_ref)
    img_right = conv_psf_depth_interp(rgb_img, depth, psfs_right, depths_ref)
    return img_left, img_right

Neural surrogate that wraps a GeoLens with an MLP to predict PSFs. Useful for fast, differentiable PSF evaluation during end-to-end training.

deeplens.optics.PSFNetLens

PSFNetLens(lens_path, in_chan=3, psf_chan=3, model_name='mlp_conv', kernel_size=64)

Bases: Lens

Neural surrogate lens that predicts PSFs via a small MLP/MLPConv network.

Wraps a :class:~deeplens.optics.geolens.GeoLens with a neural network trained to predict RGB PSFs from (fov, depth, focus_distance) inputs. After training, PSF prediction is ~100× faster than ray tracing, making it suitable for real-time applications and large-scale optimisation.

Attributes:

Name Type Description
lens GeoLens

The underlying refractive lens (used for training data generation and for sensor metadata).

psfnet Module

Neural network for PSF prediction.

pixel_size float

Pixel pitch [mm] (copied from the embedded lens).

rfov float

Half-diagonal field of view [radians].

Notes

Use :meth:train_psfnet to train the surrogate from ray-traced PSF samples. Use :meth:load_net to load pre-trained weights.

Initialize a PSF network lens.

In the default settings, the PSF network takes (fov, depth, foc_dist) as input and outputs the RGB PSF for a field point on the y-axis at that (fov, depth, foc_dist).

Parameters:

Name Type Description Default
lens_path str

Path to the lens file.

required
in_chan int

Number of input channels.

3
psf_chan int

Number of output channels.

3
model_name str

Name of the model.

'mlp_conv'
kernel_size int

Kernel size.

64
Source code in deeplens/optics/psfnetlens.py
def __init__(
    self,
    lens_path,
    in_chan=3,
    psf_chan=3,
    model_name="mlp_conv",
    kernel_size=64,
):
    """Initialize a PSF network lens.

    In the default settings, the PSF network takes (fov, depth, foc_dist) as input and outputs the RGB PSF for a field point on the y-axis at that (fov, depth, foc_dist).

    Args:
        lens_path (str): Path to the lens file.
        in_chan (int): Number of input channels.
        psf_chan (int): Number of output channels.
        model_name (str): Name of the model.
        kernel_size (int): Kernel size.
    """
    super().__init__()

    # Load lens (sensor_size and sensor_res are read from the lens file)
    self.lens_path = lens_path
    self.lens = GeoLens(filename=lens_path, device=self.device)
    self.rfov = self.lens.rfov

    # Init PSF network
    self.in_chan = in_chan
    self.psf_chan = psf_chan
    self.kernel_size = kernel_size
    self.pixel_size = self.lens.pixel_size

    self.psfnet = self.init_net(
        in_chan=in_chan,
        psf_chan=psf_chan,
        kernel_size=kernel_size,
        model_name=model_name,
    )
    self.psfnet.to(self.device)

    # Object depth range
    self.d_close = -200
    self.d_far = -20000

    # Focus distance range
    # There is a minimum focusing distance for each lens. For example, the Canon EF 50mm lens can only focus at 0.5 m and beyond.
    self.foc_d_close = -500
    self.foc_d_far = -20000
    self.refocus(foc_dist=-20000)
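
A construction sketch (file paths are placeholders; the import path is assumed from the module name above). Pre-trained weights can then be loaded with load_net(), or the surrogate can be trained from scratch with train_psfnet() further below.

from deeplens.optics import PSFNetLens  # import path assumed

psfnet_lens = PSFNetLens(
    lens_path="./lens.json",
    kernel_size=64,
    model_name="mlpconv",  # init_net() below accepts "mlp" or "mlpconv"
)
psfnet_lens.load_net("./PSFNet_last.pth")  # optional: load pre-trained weights
psfnet_lens.eval()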

set_sensor_res

set_sensor_res(sensor_res)

Set sensor resolution for both PSFNetLens and the embedded GeoLens.

Updates the pixel size accordingly.

Parameters:

Name Type Description Default
sensor_res tuple

New sensor resolution as (W, H) in pixels.

required
Source code in deeplens/optics/psfnetlens.py
def set_sensor_res(self, sensor_res):
    """Set sensor resolution for both PSFNetLens and the embedded GeoLens.

    Updates the pixel size accordingly.

    Args:
        sensor_res (tuple): New sensor resolution as ``(W, H)`` in pixels.
    """
    self.lens.set_sensor_res(sensor_res)
    self.pixel_size = self.lens.pixel_size

init_net

init_net(in_chan=2, psf_chan=3, kernel_size=64, model_name='mlpconv')

Initialize a PSF network.

PSF network

Input: [B, 3] tensor of (fov, depth, foc_dist); fov in [0, pi/2], depth in [-20000, -100], foc_dist in [-20000, -500].
Output: PSF kernel of shape [B, 3, ks, ks].

Parameters:

Name Type Description Default
in_chan int

number of input channels

2
psf_chan int

number of output channels

3
kernel_size int

kernel size

64
model_name str

name of the network architecture

'mlpconv'

Returns:

Name Type Description
psfnet Module

network

Source code in deeplens/optics/psfnetlens.py
def init_net(self, in_chan=2, psf_chan=3, kernel_size=64, model_name="mlpconv"):
    """Initialize a PSF network.

    PSF network:
        Input: [B, 3], (fov, depth, foc_dist). fov from [0, pi/2], depth from [-20000, -100], foc_dist from [-20000, -500]
        Output: psf kernel [B, 3, ks, ks]

    Args:
        in_chan (int): number of input channels
        psf_chan (int): number of output channels
        kernel_size (int): kernel size
        model_name (str): name of the network architecture

    Returns:
        psfnet (nn.Module): network
    """
    if model_name == "mlp":
        psfnet = MLP(
            in_features=in_chan,
            out_features=psf_chan * kernel_size**2,
            hidden_features=256,
            hidden_layers=8,
        )
    elif model_name == "mlpconv":
        psfnet = PSFNet_MLPConv(
            in_chan=in_chan, kernel_size=kernel_size, out_chan=psf_chan
        )
    else:
        raise Exception(f"Unsupported PSF network architecture: {model_name}.")

    return psfnet

load_net

load_net(net_path)

Load pretrained network.

Parameters:

Name Type Description Default
net_path str

path to load the network

required
Source code in deeplens/optics/psfnetlens.py
def load_net(self, net_path):
    """Load pretrained network.

    Args:
        net_path (str): path to load the network
    """
    # Check the correct model is loaded
    psfnet_dict = torch.load(net_path, map_location="cpu", weights_only=False)
    print(
        f"Pretrained model lens pixel size: {psfnet_dict['pixel_size']*1000.0:.1f} um, "
        f"Current lens pixel size: {self.pixel_size*1000.0:.1f} um"
    )
    print(
        f"Pretrained model lens path: {psfnet_dict['lens_path']}, "
        f"Current lens path: {self.lens_path}"
    )

    # Load the model weights
    self.psfnet.load_state_dict(psfnet_dict["psfnet_model_weights"])

save_psfnet

save_psfnet(psfnet_path)

Save the PSF network.

Parameters:

Name Type Description Default
psfnet_path str

path to save the PSF network

required
Source code in deeplens/optics/psfnetlens.py
def save_psfnet(self, psfnet_path):
    """Save the PSF network.

    Args:
        psfnet_path (str): path to save the PSF network
    """
    psfnet_dict = {
        "model_name": self.psfnet.__class__.__name__,
        "in_chan": self.in_chan,
        "pixel_size": self.pixel_size,
        "kernel_size": self.kernel_size,
        "psf_chan": self.psf_chan,
        "lens_path": self.lens_path,
        "psfnet_model_weights": self.psfnet.state_dict(),
    }
    torch.save(psfnet_dict, psfnet_path)

train_psfnet

train_psfnet(iters=100000, bs=128, lr=5e-05, evaluate_every=500, spp=16384, concentration_factor=2.0, result_dir='./results/psfnet')

Train the PSF surrogate network.

Parameters:

Name Type Description Default
iters int

number of training iterations

100000
bs int

batch size

128
lr float

learning rate

5e-05
evaluate_every int

Interval (in iterations) between evaluations.

500
spp int

number of samples per pixel

16384
concentration_factor float

concentration factor for training data sampling

2.0
result_dir str

directory to save the results

'./results/psfnet'
Source code in deeplens/optics/psfnetlens.py
def train_psfnet(
    self,
    iters=100000,
    bs=128,
    lr=5e-5,
    evaluate_every=500,
    spp=16384,
    concentration_factor=2.0,
    result_dir="./results/psfnet",
):
    """Train the PSF surrogate network.

    Args:
        iters (int): number of training iterations
        bs (int): batch size
        lr (float): learning rate
        evaluate_every (int): interval (in iterations) between evaluations
        spp (int): number of samples per pixel
        concentration_factor (float): concentration factor for training data sampling
        result_dir (str): directory to save the results
    """
    # Init network and prepare for training
    psfnet = self.psfnet
    loss_fn = nn.L1Loss()
    optimizer = torch.optim.AdamW(psfnet.parameters(), lr=lr)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=int(iters) // 100, num_training_steps=iters
    )

    # Train the network
    for i in tqdm(range(iters + 1)):
        # Sample training data
        sample_input, sample_psf = self.sample_training_data(
            num_points=bs, concentration_factor=concentration_factor
        )
        sample_input, sample_psf = (
            sample_input.to(self.device),
            sample_psf.to(self.device),
        )

        # Forward pass, pred_psf: [B, 3, ks, ks]
        pred_psf = psfnet(sample_input)

        # Backward propagation
        optimizer.zero_grad()
        loss = loss_fn(pred_psf, sample_psf)
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Evaluate training
        if (i + 1) % evaluate_every == 0:
            # Visualize PSFs
            n_vis = 16
            fig, axs = plt.subplots(n_vis, 2, figsize=(4, n_vis * 2))
            for j in range(n_vis):
                psf0 = sample_psf[j, ...].detach().clone().cpu()
                axs[j, 0].imshow(psf0.permute(1, 2, 0) * 255.0)
                axs[j, 0].axis("off")

                psf1 = pred_psf[j, ...].detach().clone().cpu()
                axs[j, 1].imshow(psf1.permute(1, 2, 0) * 255.0)
                axs[j, 1].axis("off")

            axs[0, 0].set_title("GT")
            axs[0, 1].set_title("Pred")

            fig.suptitle(f"GT/Pred PSFs at iter {i + 1}")
            plt.tight_layout()
            plt.savefig(f"{result_dir}/iter{i + 1}.png", dpi=300)
            plt.close()

            # Save network
            self.save_psfnet(f"{result_dir}/PSFNet_last.pth")

sample_training_data

sample_training_data(num_points=512, concentration_factor=2.0)

Sample training data for PSF surrogate network.

Parameters:

Name Type Description Default
num_points int

number of training points

512
concentration_factor float

concentration factor for training data sampling

2.0

Returns:

Name Type Description
sample_input tensor

[B, 3] tensor, (fov, depth, foc_dist).
- fov from [0, rfov] on 0y-axis, [radians]
- depth from [d_far, d_close], [mm]
- foc_dist from [foc_d_far, foc_d_close], [mm]
- We use absolute fov and depth.

sample_psf tensor

[B, 3, ks, ks] tensor

Source code in deeplens/optics/psfnetlens.py
@torch.no_grad()
def sample_training_data(self, num_points=512, concentration_factor=2.0):
    """Sample training data for PSF surrogate network.

    Args:
        num_points (int): number of training points
        concentration_factor (float): concentration factor for training data sampling

    Returns:
        sample_input (tensor): [B, 3] tensor, (fov, depth, foc_dist).
            - fov from [0, rfov] on 0y-axis, [radians]
            - depth from [d_far, d_close], [mm]
            - foc_dist from [foc_d_far, foc_d_close], [mm]
            - We use absolute fov and depth.

        sample_psf (tensor): [B, 3, ks, ks] tensor
    """
    d_close = self.d_close
    d_far = self.d_far
    rfov = self.lens.rfov

    # In each iteration, sample one focus distance, [mm], range [foc_d_far, foc_d_close]
    # Example beta distribution: https://share.google/images/Mrb9c39PdddYx3UHj
    beta_sample = float(np.random.beta(1, 4))  # Biased towards 0
    foc_dist = self.foc_d_close + beta_sample * (self.foc_d_far - self.foc_d_close)
    self.lens.refocus(foc_dist)
    foc_dist = torch.full((num_points,), foc_dist)

    # Sample (fov), [radians], range[0, rfov]
    beta_values = np.random.beta(4, 1, num_points)  # Biased towards 1
    beta_values = torch.from_numpy(beta_values).float()
    fov = beta_values * rfov

    # Sample (depth), sample more points near the focus distance, [mm], range [d_far, d_close]
    # A smaller std_dev value samples points more tightly
    std_dev = -foc_dist / concentration_factor
    depth = foc_dist + torch.randn(num_points) * std_dev
    depth = torch.clamp(depth, d_far, d_close)

    # Create input tensor
    sample_input = torch.stack([fov, depth / 1000.0, foc_dist / 1000.0], dim=1)
    sample_input = sample_input.to(self.device)

    # Calculate PSF by ray tracing, shape of [B, 3, ks, ks]
    points_x = torch.zeros_like(depth)
    points_y = self.lens.foclen * torch.tan(fov) / self.lens.r_sensor
    points_z = depth
    points = torch.stack((points_x, points_y, points_z), dim=-1)
    sample_psf = self.lens.psf_rgb(
        points=points, ks=self.kernel_size, recenter=True
    )

    return sample_input, sample_psf
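
The sampling bias used above can be illustrated in isolation: Beta(4, 1) pushes field angles toward the edge of the field, and a Gaussian centred on the focus distance concentrates depth samples near focus. A standalone sketch with illustrative numbers:

import numpy as np

rng = np.random.default_rng(0)
rfov = 0.4                      # illustrative half field of view [rad]
foc_dist = -1500.0              # illustrative focus distance [mm]
concentration_factor = 2.0

fov = rng.beta(4, 1, size=10000) * rfov        # biased towards rfov
std_dev = -foc_dist / concentration_factor     # larger factor -> tighter depth spread
depth = foc_dist + rng.normal(size=10000) * std_dev
depth = np.clip(depth, -20000.0, -200.0)       # illustrative [d_far, d_close] range

print(np.quantile(fov, [0.1, 0.5, 0.9]))       # mass concentrated near rfov
print(np.quantile(depth, [0.1, 0.5, 0.9]))     # mass concentrated near foc_dist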

eval

eval()

Switch the PSF surrogate network to evaluation mode.

Disables dropout and batch normalisation updates in the internal psfnet module. Call this before inference.

Source code in deeplens/optics/psfnetlens.py
def eval(self):
    """Switch the PSF surrogate network to evaluation mode.

    Disables dropout and batch normalisation updates in the internal
    ``psfnet`` module.  Call this before inference.
    """
    self.psfnet.eval()

points2input

points2input(points)

Convert points to input tensor.

Parameters:

Name Type Description Default
points tensor

[N, 3] tensor, [-1, 1] * [-1, 1] * [depth_min, depth_max]

required

Returns:

Name Type Description
input tensor

[N, 3] tensor, (fov, depth, foc_dist).
- fov from [0, rfov] on y-axis, [radians]
- depth/1000.0 from [d_far, d_close], [mm]
- foc_dist/1000.0 from [foc_d_far, foc_d_close], [mm]

Source code in deeplens/optics/psfnetlens.py
def points2input(self, points):
    """Convert points to input tensor.

    Args:
        points (tensor): [N, 3] tensor, [-1, 1] * [-1, 1] * [depth_min, depth_max]

    Returns:
        input (tensor): [N, 3] tensor, (fov, depth, foc_dist).
            - fov from [0, rfov] on y-axis, [radians]
            - depth/1000.0 from [d_far, d_close], [mm]
            - foc_dist/1000.0 from [foc_d_far, foc_d_close], [mm]
    """
    sensor_h, sensor_w = self.lens.sensor_size
    foclen = self.lens.foclen

    points_x = points[:, 0] * sensor_w / 2
    points_y = points[:, 1] * sensor_h / 2
    points_r = torch.sqrt(points_x**2 + points_y**2)
    fov = torch.atan(points_r / foclen)
    depth = points[:, 2]
    foc_dist = torch.full_like(fov, self.foc_dist)
    network_inp = torch.stack((fov, depth / 1000.0, foc_dist / 1000.0), dim=-1)
    return network_inp

refocus

refocus(foc_dist)

Refocus the lens to a given object distance.

Delegates to the embedded :class:GeoLens and stores the focus distance for subsequent PSF predictions.

Parameters:

Name Type Description Default
foc_dist float

Focus distance in [mm] (negative, towards the object).

required
Source code in deeplens/optics/psfnetlens.py
def refocus(self, foc_dist):
    """Refocus the lens to a given object distance.

    Delegates to the embedded :class:`GeoLens` and stores the focus
    distance for subsequent PSF predictions.

    Args:
        foc_dist (float): Focus distance in [mm] (negative, towards the
            object).
    """
    self.lens.refocus(foc_dist)
    self.foc_dist = foc_dist

psf_rgb

psf_rgb(points, ks=64)

Calculate RGB PSF using the PSF network.

Parameters:

Name Type Description Default
points tensor

[N, 3] tensor, [-1, 1] * [-1, 1] * [depth_min, depth_max]

required
ks int

kernel size of the returned PSF

64

Returns:

Name Type Description
psf tensor

[N, 3, ks, ks] tensor

Source code in deeplens/optics/psfnetlens.py
def psf_rgb(self, points, ks=64):
    """Calculate RGB PSF using the PSF network.

    Args:
        points (tensor): [N, 3] tensor, [-1, 1] * [-1, 1] * [depth_min, depth_max]
        ks (int): kernel size of the returned PSF. Defaults to 64.

    Returns:
        psf (tensor): [N, 3, ks, ks] tensor
    """
    # Calculate network input
    network_inp = self.points2input(points)

    # Predict y-axis PSF from network
    psf = self.psfnet(network_inp)

    # Post-process PSF
    # The psfnet is trained with PSFs on the y-axis.
    # We need to rotate the PSF to the correct orientation based on the point's coordinates.
    # The counter-clockwise angle from the positive y-axis to the point (x, y) is atan2(x, y).
    rot_angle = torch.atan2(points[:, 0], points[:, 1])
    psf = rotate_psf(psf, rot_angle)

    # Crop PSF to the given kernel size
    if ks < self.kernel_size:
        psf = psf[
            :,
            :,
            self.kernel_size // 2 - ks // 2 : self.kernel_size // 2 + ks // 2,
            self.kernel_size // 2 - ks // 2 : self.kernel_size // 2 + ks // 2,
        ]
    return psf
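
A hedged usage sketch for psf_rgb(): lens stands for a hypothetical trained PSFNetLens-style object exposing the method above, and only the argument layout (normalized x, y in [-1, 1], depth in mm) is taken from this API:

import torch

points = torch.tensor([
    [0.0, 0.0, -1000.0],   # image centre, 1 m away
    [0.5, 0.5, -2000.0],   # off-axis point, 2 m away
]).to(lens.device)         # lens: hypothetical trained surrogate object

psf = lens.psf_rgb(points=points, ks=32)
print(psf.shape)                 # expected: [2, 3, 32, 32]
print(psf.sum(dim=(-1, -2)))     # per-channel energy of each kernel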

psf_map_rgb

psf_map_rgb(grid=(11, 11), depth=DEPTH, ks=PSF_KS, **kwargs)

Compute RGB PSF map over a grid of field points.

Parameters:

Name Type Description Default
grid tuple

Grid size. Defaults to (11, 11), meaning 11x11 grid.

(11, 11)
wvln float

Wavelength. Defaults to DEFAULT_WAVE.

required
depth float

Depth of the object. Defaults to DEPTH.

DEPTH
ks int

Kernel size. Defaults to PSF_KS, meaning PSF_KS x PSF_KS kernel size.

PSF_KS

Returns:

Name Type Description
psf_map

Shape of [grid, grid, 3, ks, ks].

Source code in deeplens/optics/psfnetlens.py
def psf_map_rgb(self, grid=(11, 11), depth=DEPTH, ks=PSF_KS, **kwargs):
    """Compute monochrome PSF map.

    Args:
        grid (tuple, optional): Grid size. Defaults to (11, 11), meaning 11x11 grid.
        wvln (float, optional): Wavelength. Defaults to DEFAULT_WAVE.
        depth (float, optional): Depth of the object. Defaults to DEPTH.
        ks (int, optional): Kernel size. Defaults to PSF_KS, meaning PSF_KS x PSF_KS kernel size.

    Returns:
        psf_map: Shape of [grid, grid, 3, ks, ks].
    """
    # PSF map grid
    points = self.point_source_grid(depth=depth, grid=grid, center=True)
    points = points.reshape(-1, 3).to(self.device)

    # Compute PSF map
    psf = self.psf_rgb(points=points, ks=ks)  # [grid*grid, 3, ks, ks]
    psf_map = psf.reshape(grid[0], grid[1], 3, ks, ks)  # [grid, grid, 3, ks, ks]
    return psf_map

render_rgbd

render_rgbd(img, depth, foc_dist, ks=64, high_res=False)

Render a defocused image from an all-in-focus (AIF) image and a depth map. Receives an [N, C, H, W] image.

Parameters:

Name Type Description Default
img tensor

[1, C, H, W]

required
depth tensor

[1, H, W], depth map, unit in mm, range from [-20000, -200]

required
foc_dist tensor

[1], unit in mm, range from [-20000, -200]

required
ks int

kernel size

64
high_res bool

whether to use high resolution rendering

False

Returns:

Name Type Description
render tensor

[1, C, H, W]

Source code in deeplens/optics/psfnetlens.py
@torch.no_grad()
def render_rgbd(self, img, depth, foc_dist, ks=64, high_res=False):
    """Render image with aif image and depth map. Receive [N, C, H, W] image.

    Args:
        img (tensor): [1, C, H, W]
        depth (tensor): [1, H, W], depth map, unit in mm, range from [-20000, -200]
        foc_dist (tensor): [1], unit in mm, range from [-20000, -200]
        ks (int): kernel size
        high_res (bool): whether to use high resolution rendering

    Returns:
        render (tensor): [1, C, H, W]
    """
    B, C, H, W = img.shape
    assert B == 1, "Only support batch size 1"

    # Refocus the lens to the given focus distance
    self.refocus(foc_dist)

    # Estimate PSF for each pixel
    x, y = torch.meshgrid(
        torch.linspace(-1, 1, W, device=self.device),
        torch.linspace(1, -1, H, device=self.device),
        indexing="xy",
    )
    x, y = x.unsqueeze(0).repeat(B, 1, 1), y.unsqueeze(0).repeat(B, 1, 1)
    depth = depth.squeeze(1) / 1000.0

    points = torch.stack((x, y, depth), -1).float()
    psf = self.psf_rgb(points=points, ks=ks)

    # Render image with per-pixel PSF convolution
    if high_res:
        render = conv_psf_pixel_high_res(img, psf)
    else:
        render = conv_psf_pixel(img, psf)

    return render
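
A hedged usage sketch for render_rgbd(): defocus-blur an all-in-focus image against a constant depth map. lens is again a hypothetical surrogate object; the tensors are dummy data with the documented shapes and units:

import torch

img = torch.rand(1, 3, 256, 256)             # all-in-focus image in [0, 1]
depth = torch.full((1, 256, 256), -2000.0)   # flat scene 2 m away [mm]
foc_dist = torch.tensor([-1000.0])           # focus at 1 m [mm]

render = lens.render_rgbd(
    img.to(lens.device), depth.to(lens.device), foc_dist, ks=64
)
print(render.shape)                          # [1, 3, 256, 256]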

Surfaces

Base class for all geometric optical surfaces. Implements surface intersection (Newton's method with one differentiable step) and differentiable vector Snell's law refraction.

deeplens.optics.geometric_surface.Surface

Surface(r, d, mat2, pos_xy=[0.0, 0.0], vec_local=[0.0, 0.0, 1.0], is_square=False, device='cpu')

Bases: DeepObj

Base class for all geometric optical surfaces.

A surface sits at axial position d (mm) in the global coordinate system, has an aperture radius r (mm), and separates two optical media. Subclasses override :meth:_sag and :meth:_dfdxy to define their shape.

Ray–surface interaction is handled by three stages, implemented in :meth:ray_reaction:

  1. Coordinate transform – ray is brought into the local surface frame.
  2. Intersection – solved via Newton's method (:meth:newtons_method), using a non-differentiable iteration loop followed by a single differentiable Newton step to enable gradient flow.
  3. Refraction / reflection – vector Snell's law (:meth:refract) or specular reflection (:meth:reflect).

Attributes:

Name Type Description
d Tensor

Axial position of the surface vertex [mm].

r float

Aperture radius [mm].

mat2 Material

Optical material on the transmission side.

is_square bool

If True the aperture is square; otherwise circular.

tolerancing bool

When True, manufacturing error offsets are applied during ray tracing.

Initialize a generic optical surface.

Parameters:

Name Type Description Default
r float

Aperture radius [mm].

required
d float

Axial position of the surface vertex [mm].

required
mat2 str or Material

Material on the transmission side (e.g. "N-BK7", "air").

required
pos_xy list[float]

Lateral offset [x, y] [mm]. Defaults to [0.0, 0.0].

[0.0, 0.0]
vec_local list[float]

Local normal direction. Defaults to [0.0, 0.0, 1.0] (on-axis).

[0.0, 0.0, 1.0]
is_square bool

Use a square aperture. Defaults to False.

False
device str

Compute device. Defaults to "cpu".

'cpu'
Source code in deeplens/optics/geometric_surface/base.py
def __init__(
    self,
    r,
    d,
    mat2,
    pos_xy=[0.0, 0.0],
    vec_local=[0.0, 0.0, 1.0],
    is_square=False,
    device="cpu",
):
    """Initialize a generic optical surface.

    Args:
        r (float): Aperture radius [mm].
        d (float): Axial position of the surface vertex [mm].
        mat2 (str or Material): Material on the transmission side
            (e.g. ``"N-BK7"``, ``"air"``).
        pos_xy (list[float], optional): Lateral offset ``[x, y]`` [mm].
            Defaults to ``[0.0, 0.0]``.
        vec_local (list[float], optional): Local normal direction.
            Defaults to ``[0.0, 0.0, 1.0]`` (on-axis).
        is_square (bool, optional): Use a square aperture.
            Defaults to ``False``.
        device (str, optional): Compute device. Defaults to ``"cpu"``.
    """
    super(Surface, self).__init__()

    # Global direction vector, always pointing to the positive z-axis
    self.vec_global = torch.tensor([0.0, 0.0, 1.0])

    # Surface position in global coordinate system
    self.d = torch.tensor(d)
    self.pos_x = torch.tensor(pos_xy[0])
    self.pos_y = torch.tensor(pos_xy[1])

    # Surface direction vector in global coordinate system
    self.vec_local = F.normalize(torch.tensor(vec_local), p=2, dim=-1)

    # Material after the surface
    self.mat2 = Material(mat2)

    # Surface aperture radius (non-differentiable)
    self.r = float(r)
    self.is_square = is_square
    if is_square:
        # r is the incircle radius
        self.h = 2 * self.r
        self.w = 2 * self.r

    # Newton method parameters
    self.newton_maxiter = 10  # [int], maximum number of Newton iterations
    self.newton_convergence = (
        50.0 * 1e-6
    )  # [mm], Newton method convergence threshold
    self.newton_step_bound = self.r / 5  # [mm], maximum step size in each iteration

    self.tolerancing = False
    self.device = device if device is not None else torch.device("cpu")
    self.to(self.device)

init_from_dict classmethod

init_from_dict(surf_dict)

Initialize surface from a dict.

Source code in deeplens/optics/geometric_surface/base.py
@classmethod
def init_from_dict(cls, surf_dict):
    """Initialize surface from a dict."""
    raise NotImplementedError(
        f"init_from_dict() is not implemented for {cls.__name__}."
    )

ray_reaction

ray_reaction(ray, n1, n2, refraction=True)

Compute the output ray after intersection and refraction/reflection.

Transforms the ray to the local surface frame, solves the intersection via Newton's method, applies vector Snell's law (or specular reflection), then transforms back to global coordinates.

Parameters:

Name Type Description Default
ray Ray

Incident ray bundle.

required
n1 float

Refractive index of the incident medium.

required
n2 float

Refractive index of the transmission medium.

required
refraction bool

If True (default) refract the ray; if False reflect it.

True

Returns:

Name Type Description
Ray

Updated ray bundle after the surface interaction.

Source code in deeplens/optics/geometric_surface/base.py
def ray_reaction(self, ray, n1, n2, refraction=True):
    """Compute the output ray after intersection and refraction/reflection.

    Transforms the ray to the local surface frame, solves the intersection
    via Newton's method, applies vector Snell's law (or specular reflection),
    then transforms back to global coordinates.

    Args:
        ray (Ray): Incident ray bundle.
        n1 (float): Refractive index of the incident medium.
        n2 (float): Refractive index of the transmission medium.
        refraction (bool, optional): If ``True`` (default) refract the ray;
            if ``False`` reflect it.

    Returns:
        Ray: Updated ray bundle after the surface interaction.
    """
    # Transform ray to local coordinate system
    ray = self.to_local_coord(ray)

    # Intersection
    ray = self.intersect(ray, n1)

    if refraction:
        # Refraction
        ray = self.refract(ray, n1 / n2)
    else:
        # Reflection
        ray = self.reflect(ray)

    # Transform ray to global coordinate system
    ray = self.to_global_coord(ray)

    return ray

intersect

intersect(ray, n=1.0)

Solve ray-surface intersection in local coordinate system.

Parameters:

Name Type Description Default
ray Ray

input ray.

required
n float

refractive index. Defaults to 1.0.

1.0
Source code in deeplens/optics/geometric_surface/base.py
def intersect(self, ray, n=1.0):
    """Solve ray-surface intersection in local coordinate system.

    Args:
        ray (Ray): input ray.
        n (float, optional): refractive index. Defaults to 1.0.
    """
    # Solve ray-surface intersection time by Newton's method
    t, valid = self.newtons_method(ray)

    # Update ray
    new_o = ray.o + ray.d * t.unsqueeze(-1)
    ray.o = torch.where(valid.unsqueeze(-1), new_o, ray.o)
    ray.is_valid = ray.is_valid * valid

    if ray.coherent:
        if t.abs().max() > 100 and torch.get_default_dtype() == torch.float32:
            raise Exception(
                "Using float32 may cause precision problem for OPL calculation."
            )
        new_opl = ray.opl + n * t.unsqueeze(-1)
        ray.opl = torch.where(valid.unsqueeze(-1), new_opl, ray.opl)

    return ray

newtons_method

newtons_method(ray)

Solve intersection by Newton's method in local coordinate system.

Parameters:

Name Type Description Default
ray Ray

input ray.

required

Returns:

Name Type Description
t tensor

intersection time.

valid tensor

valid mask.

Source code in deeplens/optics/geometric_surface/base.py
def newtons_method(self, ray):
    """Solve intersection by Newton's method in local coordinate system.

    Args:
        ray (Ray): input ray.

    Returns:
        t (tensor): intersection time.
        valid (tensor): valid mask.
    """
    newton_maxiter = self.newton_maxiter
    newton_convergence = self.newton_convergence
    newton_step_bound = self.newton_step_bound

    # Tolerance
    if self.tolerancing:
        d_surf = self.d_error
    else:
        d_surf = 0.0

    # Initial guess of t (can also use spherical surface for initial guess)
    t = - ray.o[..., 2] / ray.d[..., 2]

    # 1. Non-differentiable Newton's iterations to find the intersection points
    with torch.no_grad():
        it = 0
        ft = 1e6 * torch.ones_like(ray.o[..., 2])
        while it < newton_maxiter:
            # Converged
            if (torch.abs(ft) < newton_convergence).all():
                break

            # One Newton step
            it += 1

            new_o = ray.o + ray.d * t.unsqueeze(-1)
            new_x, new_y = new_o[..., 0], new_o[..., 1]
            valid = self.is_within_data_range(new_x, new_y) & (ray.is_valid > 0)

            ft = self.sag(new_x, new_y, valid) - new_o[..., 2]
            dxdt, dydt, dzdt = ray.d[..., 0], ray.d[..., 1], ray.d[..., 2]
            dfdx, dfdy, dfdz = self.dfdxyz(new_x, new_y, valid)
            dfdt = dfdx * dxdt + dfdy * dydt + dfdz * dzdt
            t = t - torch.clamp(
                ft / (dfdt + EPSILON), -newton_step_bound, newton_step_bound
            )

    # 2. One more (differentiable) Newton step to gain gradients
    new_o = ray.o + ray.d * t.unsqueeze(-1)
    new_x, new_y = new_o[..., 0], new_o[..., 1]
    valid = self.is_valid(new_x, new_y) & (ray.is_valid > 0)

    ft = self.sag(new_x, new_y, valid) - new_o[..., 2]
    dxdt, dydt, dzdt = ray.d[..., 0], ray.d[..., 1], ray.d[..., 2]
    dfdx, dfdy, dfdz = self.dfdxyz(new_x, new_y, valid)
    dfdt = dfdx * dxdt + dfdy * dydt + dfdz * dzdt
    t = t - torch.clamp(
        ft / (dfdt + EPSILON), -newton_step_bound, newton_step_bound
    )

    # 3. Determine valid solutions
    with torch.no_grad():
        # Solution within the surface boundary. Ray is allowed to go back
        new_x, new_y = new_o[..., 0], new_o[..., 1]
        valid = self.is_valid(new_x, new_y) & (ray.is_valid > 0)

        # Solution accurate enough
        ft = self.sag(new_x, new_y, valid) - new_o[..., 2]
        valid = valid & (torch.abs(ft) < newton_convergence)

    return t, valid
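
The detached-loop-plus-one-differentiable-step pattern above can be illustrated in isolation. The sketch below solves t**2 - a = 0 (i.e. t = sqrt(a)) with a non-differentiable Newton loop and then re-attaches the solution to the autograd graph with a single differentiable step; by the implicit function theorem the resulting gradient matches d sqrt(a)/da:

import torch

a = torch.tensor(2.0, requires_grad=True)

# 1. Non-differentiable Newton iterations: no graph is built here
with torch.no_grad():
    t = torch.tensor(1.0)
    for _ in range(10):
        t = t - (t**2 - a) / (2 * t)

# 2. One differentiable Newton step re-attaches t to the graph through f and f'
t = t - (t**2 - a) / (2 * t)

t.backward()
print(t.item())       # ~1.41421 = sqrt(2)
print(a.grad.item())  # ~0.35355 = 1 / (2 * sqrt(2))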

refract

refract(ray, eta)

Calculate refracted ray according to Snell's law in local coordinate system.

Normal vector points from the surface toward the side where the light is coming from. d is already normalized if both n and ray.d are normalized.

Parameters:

Name Type Description Default
ray Ray

incident ray.

required
eta float

ratio of indices of refraction, eta = n_i / n_t

required

Returns:

Name Type Description
ray Ray

refracted ray.

References

[1] https://registry.khronos.org/OpenGL-Refpages/gl4/html/refract.xhtml
[2] https://en.wikipedia.org/wiki/Snell%27s_law, "Vector form" section.

Source code in deeplens/optics/geometric_surface/base.py
def refract(self, ray, eta):
    """Calculate refracted ray according to Snell's law in local coordinate system.

    Normal vector points from the surface toward the side where the light is coming from. d is already normalized if both n and ray.d are normalized.

    Args:
        ray (Ray): incident ray.
        eta (float): ratio of indices of refraction, eta = n_i / n_t

    Returns:
        ray (Ray): refracted ray.

    References:
        [1] https://registry.khronos.org/OpenGL-Refpages/gl4/html/refract.xhtml
        [2] https://en.wikipedia.org/wiki/Snell%27s_law, "Vector form" section.
    """
    # Compute normal vectors
    normal_vec = self.normal_vec(ray)

    # Compute refraction according to Snell's law, normal_vec * ray_d
    dot_product = (-normal_vec * ray.d).sum(-1).unsqueeze(-1)
    k = 1 - eta**2 * (1 - dot_product**2) 

    # Total internal reflection
    valid = (k >= 0).squeeze(-1) & (ray.is_valid > 0)
    k = k * valid.unsqueeze(-1)

    # Update ray direction and obliquity
    new_d = eta * ray.d + (eta * dot_product - torch.sqrt(k + EPSILON)) * normal_vec
    # ==> Update obliq term to penalize steep rays in the later optimization.
    obliq = torch.sum(new_d * ray.d, axis=-1).unsqueeze(-1)
    obliq_update_mask = valid.unsqueeze(-1) & (obliq < 0.5)
    ray.obliq = torch.where(obliq_update_mask, obliq * ray.obliq, ray.obliq)
    # ==> 
    ray.d = torch.where(valid.unsqueeze(-1), new_d, ray.d)

    # Update ray valid mask
    ray.is_valid = ray.is_valid * valid

    return ray
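
A standalone numerical check of the vector form used in refract(), for a single ray at 30° incidence from air into glass; the refracted direction should satisfy sin(theta_t) = eta * sin(theta_i):

import math
import torch
import torch.nn.functional as F

eta = 1.0 / 1.5                                   # n_i / n_t, air -> glass
theta_i = math.radians(30.0)
d = torch.tensor([math.sin(theta_i), 0.0, math.cos(theta_i)])  # incident direction (+z)
n = torch.tensor([0.0, 0.0, -1.0])                # normal pointing toward the incoming side

cos_i = (-n * d).sum()
k = 1 - eta**2 * (1 - cos_i**2)                   # k < 0 would mean total internal reflection
d_t = eta * d + (eta * cos_i - torch.sqrt(k)) * n
d_t = F.normalize(d_t, dim=0)

sin_t = torch.hypot(d_t[0], d_t[1])               # transverse component = sin(theta_t)
print(float(sin_t), eta * math.sin(theta_i))      # both ~0.3333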

reflect

reflect(ray)

Calculate reflected ray in local coordinate system.

Normal vector points from the surface toward the side where the light is coming from.

Parameters:

Name Type Description Default
ray Ray

incident ray.

required

Returns:

Name Type Description
ray Ray

reflected ray.

References

[1] https://registry.khronos.org/OpenGL-Refpages/gl4/html/reflect.xhtml
[2] https://en.wikipedia.org/wiki/Snell%27s_law, "Vector form" section.

Source code in deeplens/optics/geometric_surface/base.py
def reflect(self, ray):
    """Calculate reflected ray in local coordinate system.

    Normal vector points from the surface toward the side where the light is coming from.

    Args:
        ray (Ray): incident ray.

    Returns:
        ray (Ray): reflected ray.

    References:
        [1] https://registry.khronos.org/OpenGL-Refpages/gl4/html/reflect.xhtml
        [2] https://en.wikipedia.org/wiki/Snell%27s_law, "Vector form" section.
    """
    # Compute surface normal vectors
    normal_vec = self.normal_vec(ray)

    # Reflect
    dot_product = (normal_vec * ray.d).sum(-1).unsqueeze(-1)
    new_d = ray.d - 2 * dot_product * normal_vec
    new_d = F.normalize(new_d, p=2, dim=-1)

    # Update valid rays
    valid_mask = ray.is_valid > 0
    ray.d = torch.where(valid_mask.unsqueeze(-1), new_d, ray.d)

    return ray

normal_vec

normal_vec(ray)

Calculate surface normal vector at the intersection point in local coordinate system.

Normal vector points from the surface toward the side where the light is coming from.

Parameters:

Name Type Description Default
ray Ray

input ray.

required

Returns:

Name Type Description
n_vec tensor

surface normal vector.

Source code in deeplens/optics/geometric_surface/base.py
def normal_vec(self, ray):
    """Calculate surface normal vector at the intersection point in local coordinate system.

    Normal vector points from the surface toward the side where the light is coming from.

    Args:
        ray (Ray): input ray.

    Returns:
        n_vec (tensor): surface normal vector.
    """
    x, y = ray.o[..., 0], ray.o[..., 1]
    nx, ny, nz = self.dfdxyz(x, y)
    n_vec = torch.stack((nx, ny, nz), axis=-1)
    n_vec = F.normalize(n_vec, p=2, dim=-1)

    is_forward = ray.d[..., 2].unsqueeze(-1) > 0
    n_vec = torch.where(is_forward, n_vec, -n_vec)
    return n_vec

to_local_coord

to_local_coord(ray)

Transform ray to local coordinate system.

Parameters:

Name Type Description Default
ray Ray

input ray in global coordinate system.

required

Returns:

Name Type Description
ray Ray

transformed ray in local coordinate system.

Source code in deeplens/optics/geometric_surface/base.py
def to_local_coord(self, ray):
    """Transform ray to local coordinate system.

    Args:
        ray (Ray): input ray in global coordinate system.

    Returns:
        ray (Ray): transformed ray in local coordinate system.
    """
    # Shift ray origin to surface origin
    ray.o[..., 0] = ray.o[..., 0] - self.pos_x
    ray.o[..., 1] = ray.o[..., 1] - self.pos_y
    ray.o[..., 2] = ray.o[..., 2] - self.d

    # Rotate ray origin and direction
    if torch.abs(torch.dot(self.vec_local, self.vec_global) - 1.0) > EPSILON:
        R = self._get_rotation_matrix(self.vec_local, self.vec_global)
        ray.o = self._apply_rotation(ray.o, R)
        ray.d = self._apply_rotation(ray.d, R)
        ray.d = F.normalize(ray.d, p=2, dim=-1)

    return ray

to_global_coord

to_global_coord(ray)

Transform ray to global coordinate system.

Parameters:

Name Type Description Default
ray Ray

input ray in local coordinate system.

required

Returns:

Name Type Description
ray Ray

transformed ray in global coordinate system.

Source code in deeplens/optics/geometric_surface/base.py
def to_global_coord(self, ray):
    """Transform ray to global coordinate system.

    Args:
        ray (Ray): input ray in local coordinate system.

    Returns:
        ray (Ray): transformed ray in global coordinate system.
    """
    # Rotate ray origin and direction
    if torch.abs(torch.dot(self.vec_local, self.vec_global) - 1.0) > EPSILON:
        R = self._get_rotation_matrix(self.vec_global, self.vec_local)
        ray.o = self._apply_rotation(ray.o, R)
        ray.d = self._apply_rotation(ray.d, R)
        ray.d = F.normalize(ray.d, p=2, dim=-1)

    # Shift ray origin back to global coordinates
    ray.o[..., 0] = ray.o[..., 0] + self.pos_x
    ray.o[..., 1] = ray.o[..., 1] + self.pos_y
    ray.o[..., 2] = ray.o[..., 2] + self.d

    return ray

sag

sag(x, y, valid=None)

Calculate sag (z) of the surface: z = f(x, y).

Valid term is used to avoid NaN when x, y exceed the data range, which happens in spherical and aspherical surfaces.

Calculating r = sqrt(x^2 + y^2) may cause a NaN error during back-propagation, because dr/dx = x / sqrt(x^2 + y^2) is NaN when x = y = 0.

Source code in deeplens/optics/geometric_surface/base.py
def sag(self, x, y, valid=None):
    """Calculate sag (z) of the surface: z = f(x, y).

    Valid term is used to avoid NaN when x, y exceed the data range, which happens in spherical and aspherical surfaces.

    Calculating r = sqrt(x**2 + y**2) may cause a NaN error during back-propagation, because dr/dx = x / sqrt(x**2 + y**2) is NaN when x = y = 0.
    """
    if valid is None:
        valid = self.is_valid(x, y)

    x, y = x * valid, y * valid
    return self._sag(x, y)

dfdxyz

dfdxyz(x, y, valid=None)

Compute derivatives of surface function. Surface function: f(x, y, z): sag(x, y) - z = 0. This function is used in Newton's method and normal vector calculation.

There are several methods to compute derivatives of surfaces

[1] Analytical derivatives: The current implementation is based on this method. But the implementation only works for surfaces which can be written as z = sag(x, y). For implicit surfaces, we need to compute derivatives (df/dx, df/dy, df/dz).
[2] Numerical derivatives: Use finite difference method to compute derivatives. This can be used for very complex surfaces, for example NURBS. But it may suffer from numerical instability when the surface is very steep.
[3] Automatic differentiation: Use torch.autograd to compute derivatives. This can work for almost all the surfaces and is accurate, but it requires an extra backward pass to compute the derivatives of the surface function.

Source code in deeplens/optics/geometric_surface/base.py
def dfdxyz(self, x, y, valid=None):
    """Compute derivatives of surface function. Surface function: f(x, y, z): sag(x, y) - z = 0. This function is used in Newton's method and normal vector calculation.

    There are several methods to compute derivatives of surfaces:
        [1] Analytical derivatives: The current implementation is based on this method. But the implementation only works for surfaces which can be written as z = sag(x, y). For implicit surfaces, we need to compute derivatives (df/dx, df/dy, df/dz).
        [2] Numerical derivatives: Use finite difference method to compute derivatives. This can be used for those very complex surfaces, for example, NURBS. But it may suffer from numerical instability when the surface is very steep.
        [3] Automatic differentiation: Use torch.autograd to compute derivatives. This can work for almost all the surfaces and is accurate, but it requires an extra backward pass to compute the derivatives of the surface function.
    """
    if valid is None:
        valid = self.is_valid(x, y)

    x, y = x * valid, y * valid
    dx, dy = self._dfdxy(x, y)
    return dx, dy, -torch.ones_like(x)
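
A standalone sketch of option [3] above: use torch.autograd to differentiate a sag function and compare against the closed-form derivative, here for a spherical sag with curvature c (illustrative value):

import torch

c = 0.05  # curvature [1/mm], illustrative

def sag(x, y):
    r2 = x**2 + y**2
    return c * r2 / (1 + torch.sqrt(1 - c**2 * r2))

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(1.5, requires_grad=True)
dzdx_auto, dzdy_auto = torch.autograd.grad(sag(x, y), (x, y))

# Closed form: dz/dx = c * x / sqrt(1 - c^2 * (x^2 + y^2))
denom = torch.sqrt(1 - c**2 * (x**2 + y**2))
print(float(dzdx_auto), float(c * x / denom))  # the two values should agree
print(float(dzdy_auto), float(c * y / denom))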

d2fdxyz2

d2fdxyz2(x, y, valid=None)

Compute second-order partial derivatives of the surface function f(x, y, z): sag(x, y) - z = 0. This function is currently only used for surfaces constraints.

Source code in deeplens/optics/geometric_surface/base.py
def d2fdxyz2(self, x, y, valid=None):
    """Compute second-order partial derivatives of the surface function f(x, y, z): sag(x, y) - z = 0. This function is currently only used for surfaces constraints."""
    if valid is None:
        valid = self.is_within_data_range(x, y)

    x, y = x * valid, y * valid

    # Compute second-order derivatives of sag(x, y)
    d2f_dx2, d2f_dxdy, d2f_dy2 = self._d2fdxy(x, y)

    # Mixed partial derivatives involving z are zero
    zeros = torch.zeros_like(x)
    d2f_dxdz = zeros  # ∂²f/∂x∂z = 0
    d2f_dydz = zeros  # ∂²f/∂y∂z = 0
    d2f_dz2 = zeros  # ∂²f/∂z² = 0

    return d2f_dx2, d2f_dxdy, d2f_dy2, d2f_dxdz, d2f_dydz, d2f_dz2

is_valid

is_valid(x, y)

Valid points within the data range and boundary of the surface.

Source code in deeplens/optics/geometric_surface/base.py
def is_valid(self, x, y):
    """Valid points within the data range and boundary of the surface."""
    return self.is_within_data_range(x, y) & self.is_within_boundary(x, y)

is_within_boundary

is_within_boundary(x, y)

Valid points within the boundary of the surface.

Source code in deeplens/optics/geometric_surface/base.py
def is_within_boundary(self, x, y):
    """Valid points within the boundary of the surface."""
    if self.is_square:
        valid = (torch.abs(x) <= (self.w / 2 + EPSILON)) & (
            torch.abs(y) <= (self.h / 2 + EPSILON)
        )
    else:
        if self.tolerancing:
            r = self.r + self.r_error
        else:
            r = self.r
        valid = (x**2 + y**2) <= r**2

    return valid

is_within_data_range

is_within_data_range(x, y)

Valid points inside the data region of the sag function.

Source code in deeplens/optics/geometric_surface/base.py
def is_within_data_range(self, x, y):
    """Valid points inside the data region of the sag function."""
    return torch.ones_like(x, dtype=torch.bool)

max_height

max_height()

Maximum valid height.

Source code in deeplens/optics/geometric_surface/base.py
def max_height(self):
    """Maximum valid height."""
    return 10e3

surface_with_offset

surface_with_offset(x, y, valid_check=True)

Calculate z coordinate of the surface at (x, y).

This function is used in lens setup plotting and lens self-intersection detection.

Source code in deeplens/optics/geometric_surface/base.py
def surface_with_offset(self, x, y, valid_check=True):
    """Calculate z coordinate of the surface at (x, y).

    This function is used in lens setup plotting and lens self-intersection detection.
    """
    x = x if torch.is_tensor(x) else torch.tensor(x).to(self.device)
    y = y if torch.is_tensor(y) else torch.tensor(y).to(self.device)
    if valid_check:
        return self.sag(x, y) + self.d
    else:
        return self._sag(x, y) + self.d

surface_sag

surface_sag(x, y)

Calculate sag of the surface at (x, y).

This function is currently not used.

Source code in deeplens/optics/geometric_surface/base.py
def surface_sag(self, x, y):
    """Calculate sag of the surface at (x, y).

    This function is currently not used.
    """
    x = x if torch.is_tensor(x) else torch.tensor(x).to(self.device)
    y = y if torch.is_tensor(y) else torch.tensor(y).to(self.device)
    return self.sag(x, y).item()

get_optimizer_params

get_optimizer_params(lrs=[0.0001], optim_mat=False)

Get optimizer parameters for different parameters.

Parameters:

Name Type Description Default
lrs list

learning rates for different parameters.

[0.0001]
optim_mat bool

whether to optimize material. Defaults to False.

False
Source code in deeplens/optics/geometric_surface/base.py
def get_optimizer_params(self, lrs=[1e-4], optim_mat=False):
    """Get optimizer parameters for different parameters.

    Args:
        lrs (list): learning rates for different parameters.
        optim_mat (bool): whether to optimize material. Defaults to False.
    """
    raise NotImplementedError(
        "get_optimizer_params() is not implemented for {}".format(
            self.__class__.__name__
        )
    )

get_optimizer

get_optimizer(lrs=[0.0001], optim_mat=False)

Get optimizer for the surface.

Source code in deeplens/optics/geometric_surface/base.py
def get_optimizer(self, lrs=[1e-4], optim_mat=False):
    """Get optimizer for the surface."""
    params = self.get_optimizer_params(lrs, optim_mat=optim_mat)
    return torch.optim.Adam(params)

update_r

update_r(r)

Update surface radius.

Source code in deeplens/optics/geometric_surface/base.py
def update_r(self, r):
    """Update surface radius."""
    r_max = self.max_height()
    self.r = min(r, r_max)

init_tolerance

init_tolerance(tolerance_params=None)

Initialize tolerance parameters for the surface.

Parameters:

Name Type Description Default
tolerance_params dict or None

Tolerance for surface parameters. Supported keys (all optional, default values shown):

.. code-block:: python

{
    "r_tole": 0.05,               # aperture radius [mm]
    "d_tole": 0.05,               # axial position [mm]
    "center_thickness_tole": 0.1, # centre thickness [mm]
    "decenter_tole": 0.1,         # lateral decentre [mm]
    "tilt_tole": 0.1,             # tilt [arcmin]
    "mat2_n_tole": 0.001,         # refractive index
    "mat2_V_tole": 0.01,          # Abbe number [%]
}
None
References

[1] https://www.edmundoptics.com/knowledge-center/application-notes/optics/understanding-optical-specifications/?srsltid=AfmBOorBa-0zaOcOhdQpUjmytthZc07oFlmPW_2AgaiNHHQwobcAzWII
[2] https://wp.optics.arizona.edu/optomech/wp-content/uploads/sites/53/2016/08/8-Tolerancing-1.pdf
[3] https://wp.optics.arizona.edu/jsasian/wp-content/uploads/sites/33/2016/03/L17_OPTI517_Lens-_Tolerancing.pdf

Source code in deeplens/optics/geometric_surface/base.py
def init_tolerance(self, tolerance_params=None):
    """Initialize tolerance parameters for the surface.

    Args:
        tolerance_params (dict or None): Tolerance for surface parameters.
            Supported keys (all optional, default values shown):

            .. code-block:: python

                {
                    "r_tole": 0.05,               # aperture radius [mm]
                    "d_tole": 0.05,               # axial position [mm]
                    "center_thickness_tole": 0.1, # centre thickness [mm]
                    "decenter_tole": 0.1,         # lateral decentre [mm]
                    "tilt_tole": 0.1,             # tilt [arcmin]
                    "mat2_n_tole": 0.001,         # refractive index
                    "mat2_V_tole": 0.01,          # Abbe number [%]
                }

    References:
        [1] https://www.edmundoptics.com/knowledge-center/application-notes/optics/understanding-optical-specifications/?srsltid=AfmBOorBa-0zaOcOhdQpUjmytthZc07oFlmPW_2AgaiNHHQwobcAzWII
        [2] https://wp.optics.arizona.edu/optomech/wp-content/uploads/sites/53/2016/08/8-Tolerancing-1.pdf
        [3] https://wp.optics.arizona.edu/jsasian/wp-content/uploads/sites/33/2016/03/L17_OPTI517_Lens-_Tolerancing.pdf
    """
    if tolerance_params is None:
        tolerance_params = {}

    self.r_tole = tolerance_params.get("r_tole", 0.05)
    self.d_tole = tolerance_params.get("d_tole", 0.05)
    self.center_thick_tole = tolerance_params.get("center_thick_tole", 0.1)
    self.decenter_tole = tolerance_params.get("decenter_tole", 0.1)
    self.tilt_tole = tolerance_params.get("tilt_tole", 0.1)
    self.mat2_n_tole = tolerance_params.get("mat2_n_tole", 0.001)
    self.mat2_V_tole = tolerance_params.get("mat2_V_tole", 0.01)

sample_tolerance

sample_tolerance()

Sample one example manufacturing error for the surface.

Source code in deeplens/optics/geometric_surface/base.py
@torch.no_grad()
def sample_tolerance(self):
    """Sample one example manufacturing error for the surface."""
    self.r_error = float(np.random.uniform(-self.r_tole, 0))  # [mm]
    self.d_error = float(np.random.randn() * self.d_tole)  # [mm]
    self.center_thick_error = float(np.random.randn() * self.center_thick_tole)
    self.decenter_error = float(np.random.randn() * self.decenter_tole)  # [mm]
    self.tilt_error = float(np.random.randn() * self.tilt_tole)  # [arcmin]
    self.tilt_error = self.tilt_error / 60.0 * np.pi / 180.0  # [rad]
    self.mat2_n_error = float(np.random.randn() * self.mat2_n_tole)
    self.mat2_V_error = float(np.random.randn() * self.mat2_V_tole) * self.mat2.V
    self.tolerancing = True

zero_tolerance

zero_tolerance()

Zero tolerance.

Source code in deeplens/optics/geometric_surface/base.py
def zero_tolerance(self):
    """Zero tolerance."""
    self.r_error = 0.0
    self.d_error = 0.0
    self.center_thick_error = 0.0
    self.decenter_error = 0.0
    self.tilt_error = 0.0
    self.mat2_n_error = 0.0
    self.mat2_V_error = 0.0
    self.tolerancing = False
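
A hedged Monte Carlo tolerancing sketch built from the three methods above. lens and the merit function are hypothetical stand-ins; only init_tolerance, sample_tolerance and zero_tolerance are taken from this API:

n_trials = 100
scores = []

for surf in lens.surfaces:        # hypothetical GeoLens-style list of surfaces
    surf.init_tolerance()         # default tolerances

for _ in range(n_trials):
    for surf in lens.surfaces:
        surf.sample_tolerance()   # draw one set of manufacturing errors
    scores.append(evaluate_rms_spot(lens))  # hypothetical merit function

for surf in lens.surfaces:
    surf.zero_tolerance()         # restore the nominal design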

sensitivity_score

sensitivity_score()

Tolerance squared sum.

Reference

[1] Page 10 from: https://wp.optics.arizona.edu/optomech/wp-content/uploads/sites/53/2016/08/8-Tolerancing-1.pdf

Source code in deeplens/optics/geometric_surface/base.py
def sensitivity_score(self):
    """Tolerance squared sum.

    Reference:
        [1] Page 10 from: https://wp.optics.arizona.edu/optomech/wp-content/uploads/sites/53/2016/08/8-Tolerancing-1.pdf
    """
    score_dict = {}
    score_dict.update(
        {
            f"surf{self.surf_idx}_d_grad": round(self.d.grad.item(), 6),
            f"surf{self.surf_idx}_d_score": round(
                (self.d_tole**2 * self.d.grad**2).item(), 6
            ),
        }
    )
    return score_dict

draw_widget

draw_widget(ax, color='black', linestyle='solid')

Draw widget for the surface on the 2D plot.

Source code in deeplens/optics/geometric_surface/base.py
def draw_widget(self, ax, color="black", linestyle="solid"):
    """Draw widget for the surface on the 2D plot."""
    r = torch.linspace(-self.r, self.r, 128, device=self.device)
    z = self.surface_with_offset(r, torch.zeros(len(r), device=self.device))
    ax.plot(
        z.cpu().detach().numpy(),
        r.cpu().detach().numpy(),
        color=color,
        linestyle=linestyle,
        linewidth=0.75,
    )

create_mesh

create_mesh(n_rings=32, n_arms=128, color=[0.06, 0.3, 0.6])

Create triangulated surface mesh.

Parameters:

Name Type Description Default
n_rings int

Number of concentric rings for sampling.

32
n_arms int

Number of angular divisions.

128
color List[float]

The color of the mesh.

[0.06, 0.3, 0.6]

Returns:

Name Type Description
self

The surface with mesh data.

Source code in deeplens/optics/geometric_surface/base.py
def create_mesh(self, n_rings=32, n_arms=128, color=[0.06, 0.3, 0.6]):
    """Create triangulated surface mesh.

    Args:
        n_rings (int): Number of concentric rings for sampling.
        n_arms (int): Number of angular divisions.
        color (List[float]): The color of the mesh.

    Returns:
        self: The surface with mesh data.
    """
    self.vertices = self._create_vertices(n_rings, n_arms)
    self.faces = self._create_faces(n_rings, n_arms)
    self.rim = self._create_rim(n_rings, n_arms)
    self.mesh_color = color
    return self

get_polydata

get_polydata()

Get PyVista PolyData object from previously generated vertices and faces.

PolyData object will be used to draw the surface and export as .obj file.

Source code in deeplens/optics/geometric_surface/base.py
def get_polydata(self):
    """Get PyVista PolyData object from previously generated vertices and faces.

    PolyData object will be used to draw the surface and export as .obj file.
    """
    from pyvista import PolyData

    face_vertex_n = 3  # vertices per triangle
    formatted_faces = np.hstack(
        [
            face_vertex_n * np.ones((self.faces.shape[0], 1), dtype=np.uint32),
            self.faces,
        ]
    )
    return PolyData(self.vertices, formatted_faces)
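
A hedged sketch of exporting the mesh via PyVista, assuming create_mesh() has already produced vertices and faces for a concrete surface instance:

surface.create_mesh(n_rings=32, n_arms=128)   # surface: any concrete Surface subclass
mesh = surface.get_polydata()
mesh.save("surface.ply")                      # PyVista PolyData also writes .stl / .vtk
# mesh.plot()                                 # optional interactive preview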

zmx_str

zmx_str(surf_idx, d_next)

Return Zemax surface string.

Source code in deeplens/optics/geometric_surface/base.py
def zmx_str(self, surf_idx, d_next):
    """Return Zemax surface string."""
    raise NotImplementedError(
        "zmx_str() is not implemented for {}".format(self.__class__.__name__)
    )

Spherical surface defined by curvature \(c = 1/R\).

deeplens.optics.geometric_surface.Spheric

Spheric(c, r, d, mat2, pos_xy=[0.0, 0.0], vec_local=[0.0, 0.0, 1.0], is_square=False, device='cpu')

Bases: Surface

Spherical refractive surface parameterized by curvature.

The sag function is:

.. math::

z(x, y) = \frac{c \rho^2}{1 + \sqrt{1 - c^2 \rho^2}}, \quad
\rho^2 = x^2 + y^2

Attributes:

Name Type Description
c Tensor

Surface curvature 1/R [1/mm]. Differentiable, so it can be optimized with gradient-based methods.

Initialize a spherical surface.

Parameters:

Name Type Description Default
c float

Surface curvature 1/R [1/mm]. Use 0 for a flat surface (equivalent to Plane).

required
r float

Aperture radius [mm].

required
d float

Axial vertex position [mm].

required
mat2 str or Material

Material on the transmission side.

required
pos_xy list[float]

Lateral offset [x, y] [mm]. Defaults to [0.0, 0.0].

[0.0, 0.0]
vec_local list[float]

Local normal direction. Defaults to [0.0, 0.0, 1.0].

[0.0, 0.0, 1.0]
is_square bool

Square aperture flag. Defaults to False.

False
device str

Compute device. Defaults to "cpu".

'cpu'
Source code in deeplens/optics/geometric_surface/spheric.py
def __init__(
    self,
    c,
    r,
    d,
    mat2,
    pos_xy=[0.0, 0.0],
    vec_local=[0.0, 0.0, 1.0],
    is_square=False,
    device="cpu",
):
    """Initialize a spherical surface.

    Args:
        c (float): Surface curvature ``1/R`` [1/mm].  Use ``0`` for a flat
            surface (equivalent to ``Plane``).
        r (float): Aperture radius [mm].
        d (float): Axial vertex position [mm].
        mat2 (str or Material): Material on the transmission side.
        pos_xy (list[float], optional): Lateral offset ``[x, y]`` [mm].
            Defaults to ``[0.0, 0.0]``.
        vec_local (list[float], optional): Local normal direction.
            Defaults to ``[0.0, 0.0, 1.0]``.
        is_square (bool, optional): Square aperture flag. Defaults to
            ``False``.
        device (str, optional): Compute device. Defaults to ``"cpu"``.
    """
    super(Spheric, self).__init__(
        r=r,
        d=d,
        mat2=mat2,
        pos_xy=pos_xy,
        vec_local=vec_local,
        is_square=is_square,
        device=device,
    )
    self.c = torch.tensor(c)

    self.tolerancing = False
    self.to(device)
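
A hedged construction sketch for Spheric using the documented arguments; the numbers and the material name are illustrative:

import torch
from deeplens.optics.geometric_surface import Spheric

surf = Spheric(c=1 / 50.0, r=10.0, d=5.0, mat2="N-BK7")   # R = 50 mm, aperture 10 mm
x = torch.linspace(0.0, 8.0, 5)
y = torch.zeros_like(x)
print(surf.surface_with_offset(x, y))   # sag plus vertex position d along the optical axis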

intersect

intersect(ray, n=1.0)

Solve ray-surface intersection in local coordinate system using analytical method.

Sphere equation: x^2 + y^2 + (z - R)^2 = R^2, where R = 1/c.
Ray equation: p(t) = o + t*d.
Solve the quadratic equation for the intersection parameter t.

Parameters:

Name Type Description Default
ray Ray

input ray.

required
n float

refractive index. Defaults to 1.0.

1.0

Returns:

Name Type Description
ray Ray

ray with updated position and opl.

Source code in deeplens/optics/geometric_surface/spheric.py
def intersect(self, ray, n=1.0):
    """Solve ray-surface intersection in local coordinate system using analytical method.

    Sphere equation: (x)^2 + (y)^2 + (z - R)^2 = R^2, where R = 1/c
    Ray equation: p(t) = o + t*d
    Solve quadratic equation for intersection parameter t.

    Args:
        ray (Ray): input ray.
        n (float, optional): refractive index. Defaults to 1.0.

    Returns:
        ray (Ray): ray with updated position and opl.
    """
    # Tolerance
    if self.tolerancing:
        c = self.c + self.c_error
    else:
        c = self.c

    if torch.abs(c) < EPSILON:
        # Handle flat surface as a plane
        t = (0.0 - ray.o[..., 2]) / ray.d[..., 2]
        new_o = ray.o + t.unsqueeze(-1) * ray.d
        valid = (new_o[..., 0] ** 2 + new_o[..., 1] ** 2 < self.r**2) & (
            ray.is_valid > 0
        )
    else:
        R = 1.0 / c

        # Vector from ray origin to sphere center at (0, 0, R)
        oc = ray.o.clone()
        oc[..., 2] = oc[..., 2] - R

        # Quadratic equation: a*t^2 + b*t + c = 0
        # a = d·d = 1 (since ray direction is normalized)
        # b = 2*(o-center)·d
        # c = (o-center)·(o-center) - R^2

        a = torch.sum(ray.d * ray.d, dim=-1)  # Should be 1 for normalized rays
        b = 2.0 * torch.sum(oc * ray.d, dim=-1)
        c_coeff = torch.sum(oc * oc, dim=-1) - R * R

        discriminant = b * b - 4 * a * c_coeff
        valid_intersect = discriminant >= 0

        sqrt_discriminant = torch.sqrt(torch.clamp(discriminant, min=EPSILON))
        t1 = (-b - sqrt_discriminant) / (2 * a + EPSILON)
        t2 = (-b + sqrt_discriminant) / (2 * a + EPSILON)

        # Choose intersection closest to z=0 (surface vertex)
        z1 = ray.o[..., 2] + t1 * ray.d[..., 2]
        z2 = ray.o[..., 2] + t2 * ray.d[..., 2]
        use_t1 = torch.abs(z1) < torch.abs(z2)
        t = torch.where(use_t1, t1, t2)

        new_o = ray.o + t.unsqueeze(-1) * ray.d

        # Check aperture
        r_squared = new_o[..., 0] ** 2 + new_o[..., 1] ** 2
        within_aperture = r_squared <= (self.r**2 + EPSILON)

        valid = valid_intersect & within_aperture & (ray.is_valid > 0)

    # Update ray position
    ray.o = torch.where(valid.unsqueeze(-1), new_o, ray.o)
    ray.is_valid = ray.is_valid * valid

    if ray.coherent:
        if t.abs().max() > 100 and torch.get_default_dtype() == torch.float32:
            raise Exception(
                "Using float32 may cause precision problem for OPL calculation."
            )
        new_opl = ray.opl + n * t.unsqueeze(-1)
        ray.opl = torch.where(valid.unsqueeze(-1), new_opl, ray.opl)

    return ray
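
A standalone check of the analytical intersection above: for a ray travelling along +z in the local frame, the intersection depth should equal the spherical sag at the ray's (x, y). All numbers are illustrative:

import math

R, c = 100.0, 0.01                 # radius of curvature [mm] and curvature 1/R
o = (5.0, 0.0, -10.0)              # ray origin, surface vertex at z = 0
d = (0.0, 0.0, 1.0)                # ray direction

oc = (o[0], o[1], o[2] - R)        # origin relative to the sphere centre (0, 0, R)
a = sum(di * di for di in d)
b = 2.0 * sum(oci * di for oci, di in zip(oc, d))
c_coeff = sum(oci * oci for oci in oc) - R * R
disc = b * b - 4 * a * c_coeff
t1 = (-b - math.sqrt(disc)) / (2 * a)
t2 = (-b + math.sqrt(disc)) / (2 * a)
z1, z2 = o[2] + t1 * d[2], o[2] + t2 * d[2]
z_hit = z1 if abs(z1) < abs(z2) else z2    # intersection closest to the vertex

rho2 = o[0] ** 2 + o[1] ** 2
sag = c * rho2 / (1 + math.sqrt(1 - c**2 * rho2))
print(z_hit, sag)                  # both ~0.1251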

is_within_data_range

is_within_data_range(x, y)

Points outside the region where the sag is defined are invalid.

Source code in deeplens/optics/geometric_surface/spheric.py
def is_within_data_range(self, x, y):
    """Invalid when shape is non-defined."""
    if self.tolerancing:
        c = self.c + self.c_error
    else:
        c = self.c

    valid = (x**2 + y**2) < 1 / c**2
    return valid

max_height

max_height()

Maximum valid height.

Source code in deeplens/optics/geometric_surface/spheric.py
def max_height(self):
    """Maximum valid height."""
    if self.tolerancing:
        c = self.c + self.c_error
    else:
        c = self.c

    max_height = torch.sqrt(1 / c**2).item() - 0.001
    return max_height

init_tolerance

init_tolerance(tolerance_params=None)

Initialize tolerance parameters for the surface.

Parameters:

Name Type Description Default
tolerance_params dict

Tolerance for surface parameters.

None
Source code in deeplens/optics/geometric_surface/spheric.py
def init_tolerance(self, tolerance_params=None):
    """Initialize tolerance parameters for the surface.

    Args:
        tolerance_params (dict): Tolerance for surface parameters.
    """
    super().init_tolerance(tolerance_params)
    self.c_tole = tolerance_params.get("c_tole", 0.0001)

sample_tolerance

sample_tolerance()

Randomly perturb surface parameters to simulate manufacturing errors.

Source code in deeplens/optics/geometric_surface/spheric.py
def sample_tolerance(self):
    """Randomly perturb surface parameters to simulate manufacturing errors."""
    super().sample_tolerance()
    self.c_error = float(np.random.randn() * self.c_tole)

zero_tolerance

zero_tolerance()

Zero tolerance.

Source code in deeplens/optics/geometric_surface/spheric.py
def zero_tolerance(self):
    """Zero tolerance."""
    super().zero_tolerance()
    self.c_error = 0.0

sensitivity_score

sensitivity_score()

Tolerance squared sum.

Source code in deeplens/optics/geometric_surface/spheric.py
def sensitivity_score(self):
    """Tolerance squared sum."""
    score_dict = super().sensitivity_score()
    score_dict.update(
        {
            "c_grad": round(self.c.grad.item(), 6),
            "c_score": round(
                (self.c_tole**2 * self.c.grad**2).item(), 6
            ),
        }
    )
    return score_dict

get_optimizer_params

get_optimizer_params(lrs=[0.0001, 0.0001], optim_mat=False)

Activate gradient computation for c and d and return optimizer parameters.

Source code in deeplens/optics/geometric_surface/spheric.py
def get_optimizer_params(self, lrs=[1e-4, 1e-4], optim_mat=False):
    """Activate gradient computation for c and d and return optimizer parameters."""
    self.c.requires_grad_(True)
    self.d.requires_grad_(True)

    params = []
    params.append({"params": [self.d], "lr": lrs[0]})
    params.append({"params": [self.c], "lr": lrs[1]})

    if optim_mat and self.mat2.get_name() != "air":
        params += self.mat2.get_optimizer_params()

    return params
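
A hedged sketch of driving these parameter groups with Adam; surface stands for a Spheric instance, and the merit function is a hypothetical differentiable loss (for example an RMS spot size computed by ray tracing):

import torch

params = surface.get_optimizer_params(lrs=[1e-4, 1e-4])  # d and c with separate learning rates
optimizer = torch.optim.Adam(params)

for step in range(100):
    loss = merit_function(surface)   # hypothetical differentiable merit function
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()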

surf_dict

surf_dict()

Return surface parameters.

Source code in deeplens/optics/geometric_surface/spheric.py
def surf_dict(self):
    """Return surface parameters."""
    roc = 1 / self.c.item() if self.c.item() != 0 else 0.0
    surf_dict = {
        "type": "Spheric",
        "r": round(self.r, 4),
        "(c)": round(self.c.item(), 4),
        "roc": round(roc, 4),
        "(d)": round(self.d.item(), 4),
        "mat2": self.mat2.get_name(),
    }

    return surf_dict

zmx_str

zmx_str(surf_idx, d_next)

Return Zemax surface string.

Source code in deeplens/optics/geometric_surface/spheric.py
    def zmx_str(self, surf_idx, d_next):
        """Return Zemax surface string."""
        if self.mat2.get_name() == "air":
            zmx_str = f"""SURF {surf_idx} 
    TYPE STANDARD 
    CURV {self.c.item()} 
    DISZ {d_next.item()} 
    DIAM {self.r} 1 0 0 1 ""
"""
        else:
            zmx_str = f"""SURF {surf_idx} 
    TYPE STANDARD 
    CURV {self.c.item()} 
    DISZ {d_next.item()} 
    GLAS ___BLANK 1 0 {self.mat2.n} {self.mat2.V}
    DIAM {self.r} 1 0 0 1 ""
"""
        return zmx_str

Even-asphere surface: spherical base with polynomial corrections.

deeplens.optics.geometric_surface.Aspheric

Aspheric(r, d, c, k, ai, mat2, pos_xy=[0.0, 0.0], vec_local=[0.0, 0.0, 1.0], is_square=False, device='cpu')

Bases: Surface

Even-order aspheric surface.

The sag function is:

.. math::

z(\rho) = \frac{c\,\rho^2}{1 + \sqrt{1-(1+k)c^2\rho^2}}
         + \sum_{i=1}^{n} a_{2i}\,\rho^{2i},
\quad \rho^2 = x^2 + y^2

All coefficients c, k, and ai are differentiable torch tensors so they can be optimised with gradient descent.

Attributes:

Name Type Description
c Tensor

Base curvature [1/mm].

k Tensor

Conic constant.

ai Tensor

Even-order aspheric coefficients [a2, a4, a6, ...].

Initialize an aspheric surface.

Parameters:

Name Type Description Default
r float

Aperture radius [mm].

required
d float

Axial vertex position [mm].

required
c float

Base curvature 1/R [1/mm].

required
k float

Conic constant (0 = sphere, -1 = paraboloid).

required
ai list[float] or None

Even-order aspheric coefficients [a2, a4, a6, ...]. Pass None or an empty list for a pure conic.

required
mat2 str or Material

Material on the transmission side.

required
pos_xy list[float]

Lateral offset [x, y] [mm]. Defaults to [0.0, 0.0].

[0.0, 0.0]
vec_local list[float]

Local normal direction. Defaults to [0.0, 0.0, 1.0].

[0.0, 0.0, 1.0]
is_square bool

Square aperture flag. Defaults to False.

False
device str

Compute device. Defaults to "cpu".

'cpu'
Source code in deeplens/optics/geometric_surface/aspheric.py
def __init__(
    self,
    r,
    d,
    c,
    k,
    ai,
    mat2,
    pos_xy=[0.0, 0.0],
    vec_local=[0.0, 0.0, 1.0],
    is_square=False,
    device="cpu",
):
    """Initialize an aspheric surface.

    Args:
        r (float): Aperture radius [mm].
        d (float): Axial vertex position [mm].
        c (float): Base curvature ``1/R`` [1/mm].
        k (float): Conic constant (``0`` = sphere, ``-1`` = paraboloid).
        ai (list[float] or None): Even-order aspheric coefficients
            ``[a2, a4, a6, ...]``.  Pass ``None`` or an empty list for a
            pure conic.
        mat2 (str or Material): Material on the transmission side.
        pos_xy (list[float], optional): Lateral offset ``[x, y]`` [mm].
            Defaults to ``[0.0, 0.0]``.
        vec_local (list[float], optional): Local normal direction.
            Defaults to ``[0.0, 0.0, 1.0]``.
        is_square (bool, optional): Square aperture flag.
            Defaults to ``False``.
        device (str, optional): Compute device. Defaults to ``"cpu"``.
    """
    Surface.__init__(
        self,
        r=r,
        d=d,
        mat2=mat2,
        pos_xy=pos_xy,
        vec_local=vec_local,
        is_square=is_square,
        device=device,
    )

    self.c = torch.tensor(c)
    self.k = torch.tensor(k)
    if ai is not None:
        self.ai = torch.tensor(ai)
        self.ai_degree = len(ai)
        if self.ai_degree == 4:
            self.ai2 = torch.tensor(ai[0])
            self.ai4 = torch.tensor(ai[1])
            self.ai6 = torch.tensor(ai[2])
            self.ai8 = torch.tensor(ai[3])
        elif self.ai_degree == 5:
            self.ai2 = torch.tensor(ai[0])
            self.ai4 = torch.tensor(ai[1])
            self.ai6 = torch.tensor(ai[2])
            self.ai8 = torch.tensor(ai[3])
            self.ai10 = torch.tensor(ai[4])
        elif self.ai_degree == 6:
            self.ai2 = torch.tensor(ai[0])
            self.ai4 = torch.tensor(ai[1])
            self.ai6 = torch.tensor(ai[2])
            self.ai8 = torch.tensor(ai[3])
            self.ai10 = torch.tensor(ai[4])
            self.ai12 = torch.tensor(ai[5])
        else:
            for i, a in enumerate(ai):
                exec(f"self.ai{2 * i + 2} = torch.tensor({a})")
    else:
        self.ai = None
        self.ai_degree = 0

    self.tolerancing = False
    self.to(device)
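
A standalone evaluation of the even-asphere sag formula quoted in the class description, with illustrative coefficients:

import math

c, k = 1 / 20.0, -0.5            # curvature [1/mm] and conic constant (illustrative)
ai = [0.0, 1e-05, -2e-08]        # a2, a4, a6 (illustrative)

def asph_sag(rho):
    rho2 = rho * rho
    z = c * rho2 / (1 + math.sqrt(1 - (1 + k) * c**2 * rho2))
    z += sum(a * rho2 ** (i + 1) for i, a in enumerate(ai))  # a2*rho^2 + a4*rho^4 + ...
    return z

for rho in (0.0, 2.0, 5.0):
    print(rho, asph_sag(rho))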

is_within_data_range

is_within_data_range(x, y)

Points outside the region where the sag is defined are invalid.

Source code in deeplens/optics/geometric_surface/aspheric.py
def is_within_data_range(self, x, y):
    """Invalid when shape is non-defined."""
    c, k = self._get_curvature_params()
    if k > -1:
        return (x**2 + y**2) < 1 / c**2 / (1 + k)
    return torch.ones_like(x, dtype=torch.bool)

max_height

max_height()

Maximum valid height.

Source code in deeplens/optics/geometric_surface/aspheric.py
def max_height(self):
    """Maximum valid height."""
    c, k = self._get_curvature_params()
    if k > -1:
        return torch.sqrt(1 / (k + 1) / (c**2)).item() - 0.001
    return 10e3

get_optimizer_params

get_optimizer_params(lrs=[0.0001, 0.0001, 0.01, 0.0001], decay=0.001, optim_mat=False)

Get optimizer parameters for different parameters.

Parameters:

Name Type Description Default
lrs list

learning rates for d, c, k, ai2, (ai4, ai6, ai8, ai10, ai12).

[0.0001, 0.0001, 0.01, 0.0001]
optim_mat bool

whether to optimize material. Defaults to False.

False
Source code in deeplens/optics/geometric_surface/aspheric.py
def get_optimizer_params(
    self, lrs=[1e-4, 1e-4, 1e-2, 1e-4], decay=0.001, optim_mat=False
):
    """Get optimizer parameters for different parameters.

    Args:
        lrs (list, optional): learning rates for d, c, k, ai2, (ai4, ai6, ai8, ai10, ai12).
        optim_mat (bool, optional): whether to optimize material. Defaults to False.
    """
    # Broadcast learning rates to all aspheric coefficients
    if len(lrs) == 4:
        lrs = lrs + [
            lrs[-1] * decay ** (ai_degree + 1)
            for ai_degree in range(self.ai_degree - 1)
        ]

    params = []
    param_idx = 0

    # Optimize distance
    self.d.requires_grad_(True)
    params.append({"params": [self.d], "lr": lrs[param_idx]})
    param_idx += 1

    # Optimize curvature
    self.c.requires_grad_(True)
    params.append({"params": [self.c], "lr": lrs[param_idx]})
    param_idx += 1

    # Optimize conic constant
    self.k.requires_grad_(True)
    params.append({"params": [self.k], "lr": lrs[param_idx]})
    param_idx += 1

    # Optimize aspheric coefficients
    if self.ai is not None:
        if self.ai_degree > 0:
            for i in range(1, self.ai_degree + 1):
                p_name = f"ai{2 * i}"
                p = getattr(self, p_name)
                p.requires_grad_(True)
                params.append({"params": [p], "lr": lrs[param_idx]})
                param_idx += 1

    # Optimize material parameters
    if optim_mat and self.mat2.get_name() != "air":
        params += self.mat2.get_optimizer_params()

    return params
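
A minimal sketch of feeding these parameter groups into a torch optimizer. The surface values below are illustrative, and the Aspheric import path is assumed from the module layout shown above:

import torch
from deeplens.optics.geometric_surface import Aspheric  # import path assumed

# Illustrative even asphere: 5 mm semi-diameter, 2 mm thickness, weak curvature
surf = Aspheric(r=5.0, d=2.0, c=0.02, k=0.0, ai=[0.0, 0.0, 0.0, 0.0], mat2="air")

# Separate learning rates for d, c, k and the four aspheric coefficients
param_groups = surf.get_optimizer_params(lrs=[1e-4, 1e-4, 1e-2, 1e-4], decay=0.001)
optimizer = torch.optim.Adam(param_groups)

# Typical loop: compute a differentiable merit function `loss`, then
# optimizer.zero_grad(); loss.backward(); optimizer.step()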

init_tolerance

init_tolerance(tolerance_params=None)

Perturb the surface with some tolerance.

Parameters:

Name Type Description Default
tolerance_params dict

Tolerance for surface parameters.

None
References

[1] https://www.edmundoptics.com/capabilities/precision-optics/capabilities/aspheric-lenses/ [2] https://www.edmundoptics.com/knowledge-center/application-notes/optics/all-about-aspheric-lenses/?srsltid=AfmBOoon8AUXVALojol2s5K20gQk7W1qUisc6cE4WzZp3ATFY5T1pK8q

Source code in deeplens/optics/geometric_surface/aspheric.py
@torch.no_grad()
def init_tolerance(self, tolerance_params=None):
    """Perturb the surface with some tolerance.

    Args:
        tolerance_params (dict): Tolerance for surface parameters.

    References:
        [1] https://www.edmundoptics.com/capabilities/precision-optics/capabilities/aspheric-lenses/
        [2] https://www.edmundoptics.com/knowledge-center/application-notes/optics/all-about-aspheric-lenses/?srsltid=AfmBOoon8AUXVALojol2s5K20gQk7W1qUisc6cE4WzZp3ATFY5T1pK8q
    """
    super().init_tolerance(tolerance_params)
    self.c_tole = tolerance_params.get("c_tole", 0.001)
    self.k_tole = tolerance_params.get("k_tole", 0.001)

sample_tolerance

sample_tolerance()

Randomly perturb surface parameters to simulate manufacturing errors.

Source code in deeplens/optics/geometric_surface/aspheric.py
def sample_tolerance(self):
    """Randomly perturb surface parameters to simulate manufacturing errors."""
    super().sample_tolerance()
    self.c_error = float(np.random.randn() * self.c_tole)
    self.k_error = float(np.random.randn() * self.k_tole)

zero_tolerance

zero_tolerance()

Zero tolerance.

Source code in deeplens/optics/geometric_surface/aspheric.py
def zero_tolerance(self):
    """Zero tolerance."""
    super().zero_tolerance()
    self.c_error = 0.0
    self.k_error = 0.0

sensitivity_score

sensitivity_score()

Tolerance squared sum.

Source code in deeplens/optics/geometric_surface/aspheric.py
def sensitivity_score(self):
    """Tolerance squared sum."""
    score_dict = super().sensitivity_score()

    score_dict.update(
        {
            "c_grad": round(self.c.grad.item(), 6),
            "c_score": round(
                (self.c_tole**2 * self.c.grad**2).item(), 6
            ),
        }
    )

    score_dict.update(
        {
            "k_grad": round(self.k.grad.item(), 6),
            "k_score": round(
                (self.k_tole**2 * self.k.grad**2).item(), 6
            ),
        }
    )
    return score_dict
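
Taken together, init_tolerance, sample_tolerance, zero_tolerance, and sensitivity_score support a simple Monte-Carlo tolerancing loop. The sketch below assumes surf is an Aspheric surface and evaluate() is a user-supplied merit function (both hypothetical):

# Hypothetical Monte-Carlo tolerancing sketch for one surface
surf.init_tolerance({"c_tole": 0.001, "k_tole": 0.001})

scores = []
for _ in range(100):
    surf.sample_tolerance()        # draw random curvature/conic (and base-class) errors
    scores.append(evaluate(surf))  # evaluate() is a user-supplied merit function (assumed)

surf.zero_tolerance()              # restore the nominal surface

# sensitivity_score() reads c.grad / k.grad, so call it only after a backward pass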

surf_dict

surf_dict()

Return a dict of surface.

Source code in deeplens/optics/geometric_surface/aspheric.py
def surf_dict(self):
    """Return a dict of surface."""
    surf_dict = {
        "type": "Aspheric",
        "r": round(self.r, 4),
        "(c)": round(self.c.item(), 4),
        "roc": round(1 / self.c.item(), 4),
        "d": round(self.d.item(), 4),
        "k": round(self.k.item(), 4),
        "ai": [],
        "mat2": self.mat2.get_name(),
    }
    for i in range(1, self.ai_degree + 1):
        exec(
            f"surf_dict['(ai{2 * i})'] = float(format(self.ai{2 * i}.item(), '.6e'))"
        )
        surf_dict["ai"].append(float(format(eval(f"self.ai{2 * i}.item()"), ".6e")))

    return surf_dict

zmx_str

zmx_str(surf_idx, d_next)

Return Zemax surface string.

Source code in deeplens/optics/geometric_surface/aspheric.py
    def zmx_str(self, surf_idx, d_next):
        """Return Zemax surface string."""
        assert self.c.item() != 0, (
            "Aperture surface is re-implemented in Aperture class."
        )
        assert self.ai is not None or self.k != 0, (
            "Spheric surface is re-implemented in Spheric class."
        )
        if self.mat2.get_name() == "air":
            zmx_str = f"""SURF {surf_idx} 
    TYPE EVENASPH
    CURV {self.c.item()} 
    DISZ {d_next.item()}
    DIAM {self.r} 1 0 0 1 ""
    CONI {self.k}
    PARM 1 {self.ai2.item()}
    PARM 2 {self.ai4.item()}
    PARM 3 {self.ai6.item()}
    PARM 4 {self.ai8.item()}
    PARM 5 {self.ai10.item()}
    PARM 6 {self.ai12.item()}
"""
        else:
            zmx_str = f"""SURF {surf_idx} 
    TYPE EVENASPH 
    CURV {self.c.item()} 
    DISZ {d_next.item()} 
    GLAS ___BLANK 1 0 {self.mat2.n} {self.mat2.V}
    DIAM {self.r} 1 0 0 1 ""
    CONI {self.k}
    PARM 1 {self.ai2.item()}
    PARM 2 {self.ai4.item()}
    PARM 3 {self.ai6.item()}
    PARM 4 {self.ai8.item()}
    PARM 5 {self.ai10.item()}
    PARM 6 {self.ai12.item()}
"""
        return zmx_str

deeplens.optics.geometric_surface.Aperture

Aperture(r, d, pos_xy=[0.0, 0.0], vec_local=[0.0, 0.0, 1.0], is_square=False, device='cpu')

Bases: Plane

Aperture surface.

Source code in deeplens/optics/geometric_surface/aperture.py
def __init__(
    self,
    r,
    d,
    pos_xy=[0.0, 0.0],
    vec_local=[0.0, 0.0, 1.0],
    is_square=False,
    device="cpu",
):
    """Aperture surface."""
    Plane.__init__(
        self,
        r=r,
        d=d,
        mat2="air",
        pos_xy=pos_xy,
        vec_local=vec_local,
        is_square=is_square,
        device=device,
    )
    self.tolerancing = False
    self.to(device)
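
For example, a circular stop with a 1.5 mm semi-diameter placed at z = 4 mm can be created as follows (values illustrative, import path assumed from the heading above):

from deeplens.optics.geometric_surface import Aperture  # import path assumed

stop = Aperture(r=1.5, d=4.0)   # circular stop, 1.5 mm radius at z = 4 mm
print(stop.surf_dict())         # {'type': 'Aperture', 'r': 1.5, '(d)': 4.0, ...}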

ray_reaction

ray_reaction(ray, n1=1.0, n2=1.0, refraction=False)

Compute output ray after intersection and refraction.

Source code in deeplens/optics/geometric_surface/aperture.py
def ray_reaction(self, ray, n1=1.0, n2=1.0, refraction=False):
    """Compute output ray after intersection and refraction."""
    ray = self.to_local_coord(ray)
    ray = self.intersect(ray)
    ray = self.to_global_coord(ray)
    return ray

draw_widget

draw_widget(ax, color='orange', linestyle='solid')

Draw aperture wedge on the figure.

Source code in deeplens/optics/geometric_surface/aperture.py
def draw_widget(self, ax, color="orange", linestyle="solid"):
    """Draw aperture wedge on the figure."""
    d = self.d.item()
    aper_wedge_l = 0.05 * self.r  # [mm]
    aper_wedge_h = 0.15 * self.r  # [mm]

    # Parallel edges
    z = np.linspace(d - aper_wedge_l, d + aper_wedge_l, 3)
    x = -self.r * np.ones(3)
    ax.plot(z, x, color=color, linestyle=linestyle, linewidth=0.8)
    x = self.r * np.ones(3)
    ax.plot(z, x, color=color, linestyle=linestyle, linewidth=0.8)

    # Vertical edges
    z = d * np.ones(3)
    x = np.linspace(self.r, self.r + aper_wedge_h, 3)
    ax.plot(z, x, color=color, linestyle=linestyle, linewidth=0.8)
    x = np.linspace(-self.r - aper_wedge_h, -self.r, 3)
    ax.plot(z, x, color=color, linestyle=linestyle, linewidth=0.8)

draw_widget3D

draw_widget3D(ax, color='black')

Draw the aperture as a circle in a 3D plot.

Source code in deeplens/optics/geometric_surface/aperture.py
def draw_widget3D(self, ax, color="black"):
    """Draw the aperture as a circle in a 3D plot."""
    # Draw the edge circle
    theta = np.linspace(0, 2 * np.pi, 100)
    edge_x = self.r * np.cos(theta)
    edge_y = self.r * np.sin(theta)
    edge_z = np.full_like(edge_x, self.d.item())  # Constant z at aperture position

    # Plot the edge circle
    line = ax.plot(edge_z, edge_x, edge_y, color=color, linewidth=1.5)

    return line

create_mesh

create_mesh(n_rings=32, n_arms=128, color=[0.0, 0.0, 0.0])

Create triangulated surface mesh.

Parameters:

Name Type Description Default
n_rings int

Number of concentric rings for sampling.

32
n_arms int

Number of angular divisions.

128
color List[float]

The color of the mesh.

[0.0, 0.0, 0.0]

Returns:

Name Type Description
self

The surface with mesh data.

Source code in deeplens/optics/geometric_surface/aperture.py
def create_mesh(self, n_rings=32, n_arms=128, color=[0.0, 0.0, 0.0]):
    """Create triangulated surface mesh.

    Args:
        n_rings (int): Number of concentric rings for sampling.
        n_arms (int): Number of angular divisions.
        color (List[float]): The color of the mesh.

    Returns:
        self: The surface with mesh data.
    """
    self.vertices = self._create_vertices(n_rings, n_arms)
    self.faces = self._create_faces(n_rings, n_arms)
    self.rim = self._create_rim(n_rings, n_arms)
    self.mesh_color = color
    return self

get_optimizer_params

get_optimizer_params(lrs=[0.0001])

Activate gradient computation for d and return optimizer parameters.

Source code in deeplens/optics/geometric_surface/aperture.py
def get_optimizer_params(self, lrs=[1e-4]):
    """Activate gradient computation for d and return optimizer parameters."""
    self.d.requires_grad_(True)

    params = []
    params.append({"params": [self.d], "lr": lrs[0]})

    return params

surf_dict

surf_dict()

Dict of surface parameters.

Source code in deeplens/optics/geometric_surface/aperture.py
def surf_dict(self):
    """Dict of surface parameters."""
    surf_dict = {
        "type": "Aperture",
        "r": round(self.r, 4),
        "(d)": round(self.d.item(), 4),
        "mat2": "air",
        "is_square": self.is_square,
    }
    return surf_dict

zmx_str

zmx_str(surf_idx, d_next)

Zemax surface string.

Source code in deeplens/optics/geometric_surface/aperture.py
    def zmx_str(self, surf_idx, d_next):
        """Zemax surface string."""
        zmx_str = f"""SURF {surf_idx}
    STOP
    TYPE STANDARD
    CURV 0.0
    DISZ {d_next.item()}
"""
        return zmx_str

Light Representations

Geometric ray representation carrying origin, direction, wavelength, validity mask, energy, and optical path length (OPL).

deeplens.optics.Ray

Ray(o, d, wvln=DEFAULT_WAVE, coherent=False, device='cpu')

Bases: DeepObj

Batched ray bundle for optical simulation.

Stores ray origins, directions, wavelength, validity mask, energy, obliquity, and (in coherent mode) optical path length. All tensor attributes share the same batch shape (*batch_size, num_rays).

Attributes:

Name Type Description
o Tensor

Ray origins, shape (*batch, num_rays, 3) [mm].

d Tensor

Unit ray directions, shape (*batch, num_rays, 3).

wvln Tensor

Wavelength scalar [µm].

is_valid Tensor

Binary validity mask, shape (*batch, num_rays).

en Tensor

Energy weight, shape (*batch, num_rays, 1).

obliq Tensor

Obliquity factor, shape (*batch, num_rays, 1).

opl Tensor

Optical path length (coherent mode only), shape (*batch, num_rays, 1) [mm].

coherent bool

Whether OPL tracking is enabled.

Initialize a ray object.

Parameters:

Name Type Description Default
o Tensor

Ray origin, shape (..., num_rays, 3) [mm].

required
d Tensor

Ray direction, shape (..., num_rays, 3).

required
wvln float

Ray wavelength [µm].

DEFAULT_WAVE
coherent bool

Enable optical path length tracking for coherent tracing. Defaults to False.

False
device str

Compute device. Defaults to "cpu".

'cpu'
Source code in deeplens/optics/light/ray.py
def __init__(self, o, d, wvln=DEFAULT_WAVE, coherent=False, device="cpu"):
    """Initialize a ray object.

    Args:
        o (torch.Tensor): Ray origin, shape ``(..., num_rays, 3)`` [mm].
        d (torch.Tensor): Ray direction, shape ``(..., num_rays, 3)``.
        wvln (float): Ray wavelength [µm].
        coherent (bool): Enable optical path length tracking for coherent
            tracing. Defaults to ``False``.
        device (str): Compute device. Defaults to ``"cpu"``.
    """
    # Basic ray parameters - move to device
    self.o = (o if torch.is_tensor(o) else torch.tensor(o)).to(device)
    self.d = (d if torch.is_tensor(d) else torch.tensor(d)).to(device)
    self.shape = self.o.shape[:-1]

    # Wavelength
    assert wvln > 0.1 and wvln < 10.0, "Ray wavelength unit should be [um]"
    self.wvln = torch.tensor(wvln, device=device)

    # Auxiliary ray parameters - create directly on device
    self.is_valid = torch.ones(self.shape, device=device)
    self.en = torch.ones((*self.shape, 1), device=device)
    self.obliq = torch.ones((*self.shape, 1), device=device)

    # Coherent ray tracing
    self.coherent = coherent  # bool
    self.opl = torch.zeros((*self.shape, 1), device=device)

    self.device = device
    self.d = F.normalize(self.d, p=2, dim=-1)

prop_to

prop_to(z, n=1.0)

Ray propagates to a given depth plane.

Parameters:

Name Type Description Default
z float

depth.

required
n float

refractive index. Defaults to 1.

1.0
Source code in deeplens/optics/light/ray.py
def prop_to(self, z, n=1.0):
    """Ray propagates to a given depth plane.

    Args:
        z (float): depth.
        n (float, optional): refractive index. Defaults to 1.
    """
    t = (z - self.o[..., 2]) / self.d[..., 2]
    new_o = self.o + self.d * t.unsqueeze(-1)
    valid_mask = (self.is_valid > 0).unsqueeze(-1)
    self.o = torch.where(valid_mask, new_o, self.o)

    if self.coherent:
        if t.dtype != torch.float64:
            raise Warning("Should use float64 in coherent ray tracing.")
        else:
            new_opl = self.opl + n * t.unsqueeze(-1)
            self.opl = torch.where(valid_mask, new_opl, self.opl)

    return self

centroid

centroid()

Calculate the centroid of the ray, shape (..., num_rays, 3)

Returns:

Type Description

torch.Tensor: Centroid of the ray, shape (..., 3)

Source code in deeplens/optics/light/ray.py
def centroid(self):
    """Calculate the centroid of the ray, shape (..., num_rays, 3)

    Returns:
        torch.Tensor: Centroid of the ray, shape (..., 3)
    """
    return (self.o * self.is_valid.unsqueeze(-1)).sum(-2) / self.is_valid.sum(
        -1
    ).add(EPSILON).unsqueeze(-1)

rms_error

rms_error(center_ref=None)

Calculate the RMS error of the ray.

Parameters:

Name Type Description Default
center_ref Tensor

Reference center of the ray, shape (..., 3). If None, use the centroid of the ray as reference.

None

Returns:

Type Description

torch.Tensor: average RMS error of the ray

Source code in deeplens/optics/light/ray.py
def rms_error(self, center_ref=None):
    """Calculate the RMS error of the ray.

    Args:
        center_ref (torch.Tensor): Reference center of the ray, shape (..., 3). If None, use the centroid of the ray as reference.

    Returns:
        torch.Tensor: average RMS error of the ray
    """
    # Calculate the centroid of the ray as reference
    if center_ref is None:
        with torch.no_grad():
            center_ref = self.centroid()

    center_ref = center_ref.unsqueeze(-2)

    # Calculate RMS error for each region
    rms_error = ((self.o[..., :2] - center_ref[..., :2]) ** 2).sum(-1)
    rms_error = (rms_error * self.is_valid).sum(-1) / (
        self.is_valid.sum(-1) + EPSILON
    )
    rms_error = rms_error.sqrt()

    # Average RMS error
    return rms_error.mean()
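
For example, after tracing a bundle to the image plane, the RMS spot radius about the bundle centroid follows directly (a sketch that assumes ray holds the traced bundle):

# RMS spot radius about the bundle centroid (sketch)
center = ray.centroid()        # centroid of valid rays, shape (..., 3)
rms = ray.rms_error(center)    # scalar: batch-averaged RMS radius [mm]
print(f"RMS spot radius: {rms.item():.4f} mm")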

flip_xy

flip_xy()

Flip the x and y coordinates of the ray.

This function is used when calculating point spread function and wavefront distribution.

Source code in deeplens/optics/light/ray.py
def flip_xy(self):
    """Flip the x and y coordinates of the ray.

    This function is used when calculating point spread function and wavefront distribution.
    """
    self.o = torch.cat([-self.o[..., :2], self.o[..., 2:]], dim=-1)
    self.d = torch.cat([-self.d[..., :2], self.d[..., 2:]], dim=-1)
    return self

clone

clone(device=None)

Clone the ray.

The target device can be specified: for example, rays can be stored on the CPU and moved to the GPU only when they are used.

Source code in deeplens/optics/light/ray.py
def clone(self, device=None):
    """Clone the ray.

    The target device can be specified: for example, rays can be stored on the CPU and moved to the GPU only when they are used.
    """
    if device is None:
        return copy.deepcopy(self).to(self.device)
    else:
        return copy.deepcopy(self).to(device)

squeeze

squeeze(dim=None)

Squeeze the ray.

Parameters:

Name Type Description Default
dim int

dimension to squeeze. Defaults to None.

None
Source code in deeplens/optics/light/ray.py
def squeeze(self, dim=None):
    """Squeeze the ray.

    Args:
        dim (int, optional): dimension to squeeze. Defaults to None.
    """
    self.o = self.o.squeeze(dim)
    self.d = self.d.squeeze(dim)
    # wvln is a single element tensor, no squeeze needed
    self.is_valid = self.is_valid.squeeze(dim)
    self.en = self.en.squeeze(dim)
    self.opl = self.opl.squeeze(dim)
    self.obliq = self.obliq.squeeze(dim)
    return self

unsqueeze

unsqueeze(dim=None)

Unsqueeze the ray.

Parameters:

Name Type Description Default
dim int

dimension to unsqueeze. Defaults to None.

None
Source code in deeplens/optics/light/ray.py
def unsqueeze(self, dim=None):
    """Unsqueeze the ray.

    Args:
        dim (int, optional): dimension to unsqueeze. Defaults to None.
    """
    self.o = self.o.unsqueeze(dim)
    self.d = self.d.unsqueeze(dim)
    # wvln is a single element tensor, no unsqueeze needed
    self.is_valid = self.is_valid.unsqueeze(dim)
    self.en = self.en.unsqueeze(dim)
    self.opl = self.opl.unsqueeze(dim)
    self.obliq = self.obliq.unsqueeze(dim)
    return self

Complex electromagnetic field with Angular Spectrum Method (ASM), Fresnel, and Fraunhofer propagation via torch.fft.

deeplens.optics.ComplexWave

ComplexWave(u=None, wvln=0.55, z=0.0, phy_size=(4.0, 4.0), res=(2000, 2000))

Bases: DeepObj

Complex scalar wave field for diffraction simulation.

Represents a monochromatic, coherent complex amplitude on a uniform rectangular grid. Propagation methods (ASM, Fresnel, Fraunhofer) are implemented as member functions and use torch.fft for efficiency.

Attributes:

Name Type Description
u Tensor

Complex amplitude, shape [1, 1, H, W].

wvln float

Wavelength [µm].

k float

Wave number 2π / (λ × 10⁻³) [mm⁻¹].

phy_size tuple

Physical aperture size (W, H) [mm].

ps float

Pixel pitch [mm] (must be square).

res tuple

Grid resolution (H, W) in pixels.

z float

Current axial position [mm].

Initialize a complex wave field.

Parameters:

Name Type Description Default
u Tensor or None

Initial complex amplitude. Accepted shapes: [H, W], [1, H, W], or [1, 1, H, W]. If None a zero field is created with the given res.

None
wvln float

Wavelength [µm]. Defaults to 0.55.

0.55
z float

Initial axial position [mm]. Defaults to 0.0.

0.0
phy_size tuple

Physical aperture (W, H) [mm]. Defaults to (4.0, 4.0).

(4.0, 4.0)
res tuple

Grid resolution (H, W) [pixels]. Only used when u is None. Defaults to (2000, 2000).

(2000, 2000)

Raises:

Type Description
AssertionError

If the pixel pitch is not square or the wavelength is outside the range (0.1, 10) µm.

Source code in deeplens/optics/light/wave.py
def __init__(
    self,
    u=None,
    wvln=0.55,
    z=0.0,
    phy_size=(4.0, 4.0),
    res=(2000, 2000),
):
    """Initialize a complex wave field.

    Args:
        u (torch.Tensor or None, optional): Initial complex amplitude.
            Accepted shapes: ``[H, W]``, ``[1, H, W]``, or
            ``[1, 1, H, W]``.  If ``None`` a zero field is created with
            the given *res*.
        wvln (float, optional): Wavelength [µm].  Defaults to ``0.55``.
        z (float, optional): Initial axial position [mm].  Defaults to
            ``0.0``.
        phy_size (tuple, optional): Physical aperture (W, H) [mm].
            Defaults to ``(4.0, 4.0)``.
        res (tuple, optional): Grid resolution (H, W) [pixels].  Only
            used when *u* is ``None``.  Defaults to ``(2000, 2000)``.

    Raises:
        AssertionError: If the pixel pitch is not square or the
            wavelength is outside the range ``(0.1, 10)`` µm.
    """
    if u is not None:
        if not u.dtype == torch.complex128:
            print(
                "A complex wave field is created with single precision. In the future, we want to always use double precision."
            )

        self.u = u if torch.is_tensor(u) else torch.from_numpy(u)
        if not self.u.is_complex():
            self.u = self.u.to(torch.complex64)

        # [H, W] or [1, H, W] to [1, 1, H, W]
        if len(self.u.shape) == 2:
            self.u = self.u.unsqueeze(0).unsqueeze(0)
        elif len(self.u.shape) == 3:
            self.u = self.u.unsqueeze(0)

        self.res = self.u.shape[-2:]

    else:
        # Initialize a zero complex wave field
        amp = torch.zeros(res).unsqueeze(0).unsqueeze(0)
        phi = torch.zeros(res).unsqueeze(0).unsqueeze(0)
        self.u = amp + 1j * phi
        self.res = res

    # Wave field parameters
    assert wvln > 0.1 and wvln < 10.0, "Wavelength should be in [um]."
    self.wvln = wvln  # [um], wavelength
    self.k = 2 * torch.pi / (self.wvln * 1e-3)  # [mm^-1], wave number
    self.phy_size = phy_size  # [mm], physical size
    assert phy_size[0] / self.res[0] == phy_size[1] / self.res[1], (
        "Pixel size is not square."
    )
    self.ps = phy_size[0] / self.res[0]  # [mm], pixel size

    # Wave field grid
    self.x, self.y = self.gen_xy_grid()  # x, y grid
    self.z = torch.full_like(self.x, z)  # z grid

point_wave classmethod

point_wave(point=(0, 0, -1000.0), wvln=0.55, z=0.0, phy_size=(4.0, 4.0), res=(2000, 2000), valid_r=None)

Create a spherical wave field on x0y plane originating from a point source.

Parameters:

Name Type Description Default
point tuple

Point source position in object space. [mm]. Defaults to (0, 0, -1000.0).

(0, 0, -1000.0)
wvln float

Wavelength. [um]. Defaults to 0.55.

0.55
z float

Field z position. [mm]. Defaults to 0.0.

0.0
phy_size tuple

Valid plane on x0y plane. [mm]. Defaults to (4.0, 4.0).

(4.0, 4.0)
res tuple

Valid plane resolution. Defaults to (2000, 2000).

(2000, 2000)
valid_r float

Valid circle radius. [mm]. Defaults to None.

None

Returns:

Name Type Description
field ComplexWave

Complex field on x0y plane.

Source code in deeplens/optics/light/wave.py
@classmethod
def point_wave(
    cls,
    point=(0, 0, -1000.0),
    wvln=0.55,
    z=0.0,
    phy_size=(4.0, 4.0),
    res=(2000, 2000),
    valid_r=None,
):
    """Create a spherical wave field on x0y plane originating from a point source.

    Args:
        point (tuple): Point source position in object space. [mm]. Defaults to (0, 0, -1000.0).
        wvln (float): Wavelength. [um]. Defaults to 0.55.
        z (float): Field z position. [mm]. Defaults to 0.0.
        phy_size (tuple): Valid plane on x0y plane. [mm]. Defaults to (4.0, 4.0).
        res (tuple): Valid plane resolution. Defaults to (2000, 2000).
        valid_r (float): Valid circle radius. [mm]. Defaults to None.

    Returns:
        field (ComplexWave): Complex field on x0y plane.
    """
    assert wvln > 0.1 and wvln < 10.0, "Wavelength should be in [um]."
    k = 2 * torch.pi / (wvln * 1e-3)  # [mm^-1], wave number

    # Create meshgrid on target plane
    x, y = torch.meshgrid(
        torch.linspace(
            -0.5 * phy_size[0], 0.5 * phy_size[0], res[0], dtype=torch.float64
        ),
        torch.linspace(
            0.5 * phy_size[1], -0.5 * phy_size[1], res[1], dtype=torch.float64
        ),
        indexing="xy",
    )

    # Calculate distance to point source, and calculate spherical wave phase
    r = torch.sqrt((x - point[0]) ** 2 + (y - point[1]) ** 2 + (z - point[2]) ** 2)
    if point[2] < z:
        phi = k * r
    else:
        phi = -k * r
    u = (r.min() / r) * torch.exp(1j * phi)

    # Apply valid circle if provided, e.g., the aperture of a lens
    if valid_r is not None:
        mask = (x - point[0]) ** 2 + (y - point[1]) ** 2 < valid_r**2
        u = u * mask

    # Create wave field
    return cls(u=u, wvln=wvln, phy_size=phy_size, res=res, z=z)
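
A short sketch using this classmethod to build a diverging field from an on-axis point 100 mm in front of the plane; valid_r mimics a circular pupil and all numbers are illustrative:

from deeplens.optics import ComplexWave

# Spherical wave from an on-axis point 100 mm before the z = 0 plane,
# sampled on a 4 mm x 4 mm grid and clipped to a 2 mm-radius pupil
field = ComplexWave.point_wave(
    point=(0.0, 0.0, -100.0),
    wvln=0.55,
    z=0.0,
    phy_size=(4.0, 4.0),
    res=(1000, 1000),
    valid_r=2.0,
)
print(field.u.shape)   # torch.Size([1, 1, 1000, 1000])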

plane_wave classmethod

plane_wave(wvln=0.55, z=0.0, phy_size=(4.0, 4.0), res=(2000, 2000), valid_r=None)

Create a planar wave field on x0y plane.

Parameters:

Name Type Description Default
wvln float

Wavelength. [um].

0.55
z float

Field z position. [mm].

0.0
phy_size tuple

Physical size of the field. [mm].

(4.0, 4.0)
res tuple

Resolution.

(2000, 2000)
valid_r float

Valid circle radius. [mm].

None

Returns:

Name Type Description
field ComplexWave

Complex field.

Source code in deeplens/optics/light/wave.py
@classmethod
def plane_wave(
    cls,
    wvln=0.55,
    z=0.0,
    phy_size=(4.0, 4.0),
    res=(2000, 2000),
    valid_r=None,
):
    """Create a planar wave field on x0y plane.

    Args:
        wvln (float): Wavelength. [um].
        z (float): Field z position. [mm].
        phy_size (tuple): Physical size of the field. [mm].
        res (tuple): Resolution.
        valid_r (float): Valid circle radius. [mm].

    Returns:
        field (ComplexWave): Complex field.
    """
    assert wvln > 0.1 and wvln < 10.0, "Wavelength should be in [um]."

    # Create a plane wave field
    u = torch.ones(res, dtype=torch.float64) + 0j

    # Apply valid circle if provided
    if valid_r is not None:
        x, y = torch.meshgrid(
            torch.linspace(-0.5 * phy_size[0], 0.5 * phy_size[0], res[0]),
            torch.linspace(-0.5 * phy_size[1], 0.5 * phy_size[1], res[1]),
            indexing="xy",
        )
        mask = (x**2 + y**2) < valid_r**2
        u = u * mask

    # Create wave field
    return cls(u=u, phy_size=phy_size, wvln=wvln, res=res, z=z)

image_wave classmethod

image_wave(img, wvln=0.55, z=0.0, phy_size=(4.0, 4.0))

Initialize a complex wave field from an image.

Parameters:

Name Type Description Default
img Tensor

Input image with shape [H, W] or [B, C, H, W]. Data range is [0, 1].

required
wvln float

Wavelength. [um].

0.55
z float

Field z position. [mm].

0.0
phy_size tuple

Physical size of the field. [mm].

(4.0, 4.0)

Returns:

Name Type Description
field ComplexWave

Complex field.

Source code in deeplens/optics/light/wave.py
@classmethod
def image_wave(cls, img, wvln=0.55, z=0.0, phy_size=(4.0, 4.0)):
    """Initialize a complex wave field from an image.

    Args:
        img (torch.Tensor): Input image with shape [H, W] or [B, C, H, W]. Data range is [0, 1].
        wvln (float): Wavelength. [um].
        z (float): Field z position. [mm].
        phy_size (tuple): Physical size of the field. [mm].

    Returns:
        field (ComplexWave): Complex field.
    """
    assert img.dtype == torch.float32, "Image must be float32."

    amp = torch.sqrt(img)
    phi = torch.zeros_like(amp)
    u = amp + 1j * phi

    return cls(u=u, wvln=wvln, phy_size=phy_size, res=u.shape[-2:], z=z)
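
A short sketch that encodes a [0, 1] grayscale image as the field amplitude; the input tensor here is random stand-in data:

import torch
from deeplens.optics import ComplexWave

img = torch.rand(1, 1, 512, 512, dtype=torch.float32)   # stand-in image, [B, C, H, W]
field = ComplexWave.image_wave(img, wvln=0.55, phy_size=(4.0, 4.0))
print(field.u.shape, field.ps)   # torch.Size([1, 1, 512, 512]) 0.0078125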

prop

prop(prop_dist, n=1.0)

Propagate the field by a distance prop_dist. Only planar wave propagation is supported.

Reference

[1] Modeling and propagation of near-field diffraction patterns: A more complete approach. Table 1. [2] https://github.com/kaanaksit/odak/blob/master/odak/wave/classical.py [3] https://spie.org/samples/PM103.pdf [4] "Non-approximated Rayleigh Sommerfeld diffraction integral: advantages and disadvantages in the propagation of complex wave fields"

Parameters:

Name Type Description Default
prop_dist float

propagation distance, unit [mm].

required
n float

refractive index.

1.0

Returns:

Name Type Description
self

propagated complex wave field.

Source code in deeplens/optics/light/wave.py
def prop(self, prop_dist, n=1.0):
    """Propagate the field by distance z. Can only propagate planar wave.

    Reference:
        [1] Modeling and propagation of near-field diffraction patterns: A more complete approach. Table 1.
        [2] https://github.com/kaanaksit/odak/blob/master/odak/wave/classical.py
        [3] https://spie.org/samples/PM103.pdf
        [4] "Non-approximated Rayleigh Sommerfeld diffraction integral: advantages and disadvantages in the propagation of complex wave fields"

    Args:
        prop_dist (float): propagation distance, unit [mm].
        n (float): refractive index.

    Returns:
        self: propagated complex wave field.
    """
    # Determine propagation method and perform propagation
    wvln_mm = self.wvln * 1e-3  # [um] to [mm]
    asm_zmax = Nyquist_ASM_zmax(wvln=self.wvln, ps=self.ps, side_length=self.phy_size[0])
    fresnel_zmin = Fresnel_zmin(wvln=self.wvln, ps=self.ps, side_length=self.phy_size[0])

    # Wave propagation methods
    if prop_dist < DELTA:
        # Zero distance: do nothing
        pass

    elif prop_dist < wvln_mm:
        # Sub-wavelength distance: full wave method (e.g., FDTD)
        raise Exception(
            "The propagation distance in sub-wavelength range is not implemented yet. Have to use full wave method (e.g., FDTD)."
        )

    elif prop_dist < asm_zmax:
        # Angular Spectrum Method (ASM)
        self.u = AngularSpectrumMethod(self.u, z=prop_dist, wvln=self.wvln, ps=self.ps, n=n)

    elif prop_dist > fresnel_zmin:
        # Fresnel diffraction
        self.u = FresnelDiffraction(self.u, z=prop_dist, wvln=self.wvln, ps=self.ps, n=n)

    else:
        raise Exception(f"Propagation method not implemented for distance {prop_dist} mm.")

    # Update z grid
    self.z += prop_dist
    return self
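
A sketch of the typical call pattern: the propagation method is selected automatically from the sampling criteria, so the caller only supplies a distance (all numbers illustrative):

from deeplens.optics import ComplexWave

# Plane wave through a 2 mm-radius pupil, propagated 5 mm in air (sketch)
field = ComplexWave.plane_wave(wvln=0.55, phy_size=(4.0, 4.0),
                               res=(2000, 2000), valid_r=2.0)
field.prop(5.0)                  # ASM or Fresnel chosen internally
field.show("diffraction.png")    # save the resulting irradiance pattern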

prop_to

prop_to(z, n=1)

Propagate the field to plane z.

Parameters:

Name Type Description Default
z float

destination plane z coordinate.

required
Source code in deeplens/optics/light/wave.py
def prop_to(self, z, n=1):
    """Propagate the field to plane z.

    Args:
        z (float): destination plane z coordinate.
    """
    prop_dist = z - self.z[0, 0].item()
    self.prop(prop_dist, n=n)
    return self

gen_xy_grid

gen_xy_grid()

Generate the x and y grid.

Source code in deeplens/optics/light/wave.py
def gen_xy_grid(self):
    """Generate the x and y grid."""
    x, y = torch.meshgrid(
        torch.linspace(-0.5 * self.phy_size[1], 0.5 * self.phy_size[1], self.res[0],),
        torch.linspace(0.5 * self.phy_size[0], -0.5 * self.phy_size[0], self.res[1],),
        indexing="xy",
    )
    return x, y

gen_freq_grid

gen_freq_grid()

Generate the frequency grid.

Source code in deeplens/optics/light/wave.py
def gen_freq_grid(self):
    """Generate the frequency grid."""
    x, y = self.gen_xy_grid()
    fx = x / (self.ps * self.phy_size[0])
    fy = y / (self.ps * self.phy_size[1])
    return fx, fy

load_npz

load_npz(filepath)

Load data from npz file.

Source code in deeplens/optics/light/wave.py
def load_npz(self, filepath):
    """Load data from npz file."""
    data = np.load(filepath)
    self.u = torch.from_numpy(data["u"])
    self.x = torch.from_numpy(data["x"])
    self.y = torch.from_numpy(data["y"])
    self.wvln = data["wvln"].item()
    self.phy_size = data["phy_size"].tolist()
    self.res = self.u.shape[-2:]

save

save(filepath='./wavefield.npz')

Save the complex wave field to a npz file.

Source code in deeplens/optics/light/wave.py
def save(self, filepath="./wavefield.npz"):
    """Save the complex wave field to a npz file."""
    if filepath.endswith(".npz"):
        self.save_npz(filepath)
    else:
        raise Exception("Unimplemented file format.")

save_npz

save_npz(filepath='./wavefield.npz')

Save the complex wave field to a npz file.

Source code in deeplens/optics/light/wave.py
def save_npz(self, filepath="./wavefield.npz"):
    """Save the complex wave field to a npz file."""
    # Save data
    np.savez_compressed(
        filepath,
        u=self.u.cpu().numpy(),
        x=self.x.cpu().numpy(),
        y=self.y.cpu().numpy(),
        wvln=np.array(self.wvln),
        phy_size=np.array(self.phy_size),
    )

    # Save intensity, amplitude, and phase images
    u = self.u.cpu()
    save_image(u.abs() ** 2, f"{filepath[:-4]}_intensity.png", normalize=True)
    save_image(u.abs(), f"{filepath[:-4]}_amp.png", normalize=True)
    save_image(u.angle(), f"{filepath[:-4]}_phase.png", normalize=True)
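
A save/load round trip might look like the following sketch, assuming field is an existing ComplexWave instance and the file name is illustrative:

field.save("wavefield.npz")   # also writes *_intensity/_amp/_phase PNGs

restored = ComplexWave(phy_size=tuple(field.phy_size), res=tuple(field.res))
restored.load_npz("wavefield.npz")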

show

show(save_name=None, data='irr')

Save the field as an image.

Source code in deeplens/optics/light/wave.py
def show(self, save_name=None, data="irr"):
    """Save the field as an image."""
    cmap = "gray"
    if data == "irr":
        value = self.u.detach().abs() ** 2
    elif data == "amp":
        value = self.u.detach().abs()
    elif data == "phi" or data == "phase":
        value = torch.angle(self.u).detach()
        cmap = "hsv"
    elif data == "real":
        value = self.u.real.detach()
    elif data == "imag":
        value = self.u.imag.detach()
    else:
        raise Exception(f"Unimplemented visualization: {data}.")

    if len(self.u.shape) == 2:
        raise Exception("Deprecated.")
        if save_name is not None:
            save_image(value, save_name, normalize=True)
        else:
            value = value.cpu().numpy()
            plt.imshow(
                value,
                cmap=cmap,
                extent=[
                    -self.phy_size[0] / 2,
                    self.phy_size[0] / 2,
                    -self.phy_size[1] / 2,
                    self.phy_size[1] / 2,
                ],
            )

    elif len(self.u.shape) == 4:
        B, C, H, W = self.u.shape
        if B == 1:
            if save_name is not None:
                save_image(value, save_name, normalize=True)
            else:
                value = value.cpu().numpy()
                plt.imshow(
                    value[0, 0, :, :],
                    cmap=cmap,
                    extent=[
                        -self.phy_size[0] / 2,
                        self.phy_size[0] / 2,
                        -self.phy_size[1] / 2,
                        self.phy_size[1] / 2,
                    ],
                )
        else:
            if save_name is not None:
                plt.savefig(save_name)
            else:
                value = value.cpu().numpy()
                fig, axs = plt.subplots(1, B)
                for i in range(B):
                    axs[i].imshow(
                        value[i, 0, :, :],
                        cmap=cmap,
                        extent=[
                            -self.phy_size[0] / 2,
                            self.phy_size[0] / 2,
                            -self.phy_size[1] / 2,
                            self.phy_size[1] / 2,
                        ],
                    )
                fig.show()
    else:
        raise Exception("Unsupported complex field shape.")

pad

pad(Hpad, Wpad)

Pad the input field by Hpad pixels on the top and bottom and Wpad pixels on the left and right. This step also expands the physical size of the field.

Parameters:

Name Type Description Default
Hpad int

Number of pixels to pad on the top and bottom.

required
Wpad int

Number of pixels to pad on the left and right.

required

Returns:

Name Type Description
self

Padded complex wave field.

Source code in deeplens/optics/light/wave.py
def pad(self, Hpad, Wpad):
    """Pad the input field by (Hpad, Hpad, Wpad, Wpad). This step will also expand physical size of the field.

    Args:
        Hpad (int): Number of pixels to pad on the top and bottom.
        Wpad (int): Number of pixels to pad on the left and right.

    Returns:
        self: Padded complex wave field.
    """
    # F.pad orders padding as (left, right, top, bottom), i.e. the W axis first, then H
    self.u = F.pad(self.u, (Wpad, Wpad, Hpad, Hpad), mode="constant", value=0)

    Horg, Worg = self.res
    self.res = [Horg + 2 * Hpad, Worg + 2 * Wpad]
    self.phy_size = [
        self.phy_size[0] * self.res[0] / Horg,
        self.phy_size[1] * self.res[1] / Worg,
    ]
    self.x, self.y = self.gen_xy_grid()
    self.z = torch.full_like(self.x, self.z[0, 0].item())

flip

flip()

Flip the field horizontally and vertically.

Source code in deeplens/optics/light/wave.py
def flip(self):
    """Flip the field horizontally and vertically."""
    self.u = torch.flip(self.u, [-1, -2])
    self.x = torch.flip(self.x, [-1, -2])
    self.y = torch.flip(self.y, [-1, -2])
    self.z = torch.flip(self.z, [-1, -2])
    return self

PSF Utilities

Functions for convolving images with point spread functions.

deeplens.optics.imgsim.psf.conv_psf

conv_psf(img, psf)

Convolve an image batch with a single spatially-uniform PSF.

Applies a per-channel 2-D convolution using reflect boundary padding so that the output has the same spatial dimensions as the input. The PSF is internally flipped to convert the cross-correlation implemented by F.conv2d into a true convolution.

Parameters:

Name Type Description Default
img Tensor

Input image batch, shape [B, C, H, W].

required
psf Tensor

PSF kernel, shape [C, ks, ks]. ks may be odd or even.

required

Returns:

Type Description

torch.Tensor: Rendered image, shape [B, C, H, W].

Example

psf = lens.psf_rgb(points=torch.tensor([0.0, 0.0, -10000.0])) img_blur = conv_psf(img, psf)

Source code in deeplens/optics/imgsim/psf.py
def conv_psf(img, psf):
    """Convolve an image batch with a single spatially-uniform PSF.

    Applies a per-channel 2-D convolution using ``reflect`` boundary padding
    so that the output has the same spatial dimensions as the input.  The PSF
    is internally flipped to convert the cross-correlation implemented by
    ``F.conv2d`` into a true convolution.

    Args:
        img (torch.Tensor): Input image batch, shape ``[B, C, H, W]``.
        psf (torch.Tensor): PSF kernel, shape ``[C, ks, ks]``.  ``ks`` may be
            odd or even.

    Returns:
        torch.Tensor: Rendered image, shape ``[B, C, H, W]``.

    Example:
        >>> psf = lens.psf_rgb(points=torch.tensor([0.0, 0.0, -10000.0]))
        >>> img_blur = conv_psf(img, psf)
    """
    B, C, H, W = img.shape
    C_psf, ks, _ = psf.shape
    assert C_psf == C, f"psf channels ({C_psf}) must match image channels ({C})."

    # Flip the PSF because F.conv2d use cross-correlation
    psf = torch.flip(psf, [1, 2])
    psf = psf.unsqueeze(1)  # shape [C, 1, ks, ks]

    # Padding
    pad_h_left  = (ks - 1) // 2
    pad_h_right = ks // 2
    pad_w_left  = (ks - 1) // 2
    pad_w_right = ks // 2
    img_pad = F.pad(img, (pad_w_left, pad_w_right, pad_h_left, pad_h_right), mode="reflect")

    # Convolution
    img_render = F.conv2d(img_pad, psf, groups=C)
    return img_render

deeplens.optics.imgsim.psf.conv_psf_map

conv_psf_map(img, psf_map)

Convolve an image batch with a spatially-varying PSF map.

Divides the image into grid_h × grid_w non-overlapping patches and convolves each patch with its corresponding PSF kernel. The results are assembled back into a full-resolution output via a weighted blending step.

Parameters:

Name Type Description Default
img Tensor

Input image batch, shape [B, C, H, W].

required
psf_map Tensor

PSF map, shape [grid_h, grid_w, C, ks, ks].

required

Returns:

Type Description

torch.Tensor: Rendered image, shape [B, C, H, W].

Source code in deeplens/optics/imgsim/psf.py
def conv_psf_map(img, psf_map):
    """Convolve an image batch with a spatially-varying PSF map.

    Divides the image into ``grid_h × grid_w`` non-overlapping patches and
    convolves each patch with its corresponding PSF kernel.  The results are
    assembled back into a full-resolution output via a weighted blending step.

    Args:
        img (torch.Tensor): Input image batch, shape ``[B, C, H, W]``.
        psf_map (torch.Tensor): PSF map, shape ``[grid_h, grid_w, C, ks, ks]``.

    Returns:
        torch.Tensor: Rendered image, shape ``[B, C, H, W]``.
    """
    B, C, H, W = img.shape
    grid_h, grid_w, C_psf, ks, _ = psf_map.shape
    assert C_psf == C, f"PSF map channels ({C_psf}) must match image channels ({C})."

    # Padding
    pad_h_left  = (ks - 1) // 2
    pad_h_right = ks // 2
    pad_w_left  = (ks - 1) // 2
    pad_w_right = ks // 2
    img_pad = F.pad(img, (pad_w_left, pad_w_right, pad_h_left, pad_h_right), mode="reflect")

    # Pre-flip entire PSF map once (instead of flipping each PSF inside the loop)
    psf_map_flipped = torch.flip(psf_map, dims=(-2, -1))

    # Render image patch by patch
    img_render = torch.zeros_like(img)
    for i in range(grid_h):
        h_low  = (i * H) // grid_h
        h_high = ((i + 1) * H) // grid_h

        for j in range(grid_w):
            w_low  = (j * W) // grid_w
            w_high = ((j + 1) * W) // grid_w

            # PSF, [C, 1, ks, ks]
            psf = psf_map_flipped[i, j].unsqueeze(1)

            # Consider overlap to avoid boundary artifacts
            img_pad_patch = img_pad[
                :,
                :,
                h_low : h_high + pad_h_left + pad_h_right,
                w_low : w_high + pad_w_left + pad_w_right,
            ]

            # Convolution, [B, C, h_high-h_low, w_high-w_low]
            render_patch = F.conv2d(img_pad_patch, psf, groups=C)  
            img_render[:, :, h_low:h_high, w_low:w_high] = render_patch

    return img_render
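
A self-contained sketch with a synthetic 3 x 3 PSF map; normalized box kernels stand in for real PSFs:

import torch
from deeplens.optics.imgsim.psf import conv_psf_map

# Synthetic PSF map: 3 x 3 grid of normalized 11 x 11 box kernels, [grid_h, grid_w, C, ks, ks]
psf_map = torch.ones(3, 3, 3, 11, 11) / 11**2
img = torch.rand(1, 3, 256, 256)                 # [B, C, H, W]

img_render = conv_psf_map(img, psf_map)
print(img_render.shape)                          # torch.Size([1, 3, 256, 256])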