Skip to content

Surrogate Networks API Reference

Neural networks that learn to predict PSFs from lens parameters, replacing expensive ray tracing during training.


deeplens.surrogate.MLP

MLP(in_features, out_features, hidden_features=64, hidden_layers=3)

Bases: Module

Fully-connected network for low-frequency PSF prediction.

Predicts PSFs as flattened vectors using stacked linear layers with ReLU activations and a Sigmoid output. The output is L1-normalized so it sums to 1 (valid as a PSF energy distribution).

Parameters:

Name Type Description Default
in_features

Number of input features (e.g., field angle + wavelength).

required
out_features

Number of output features (flattened PSF size).

required
hidden_features

Width of hidden layers. Defaults to 64.

64
hidden_layers

Number of hidden layers. Defaults to 3.

3
Source code in deeplens/surrogate/mlp.py
def __init__(self, in_features, out_features, hidden_features=64, hidden_layers=3):
    """Build the MLP: ramp-up stem, constant-width hidden stack, Sigmoid head.

    Args:
        in_features: Number of input features (e.g., field angle + wavelength).
        out_features: Number of output features (flattened PSF size).
        hidden_features: Width of hidden layers. Defaults to 64.
        hidden_layers: Number of hidden layers. Defaults to 3.
    """
    super(MLP, self).__init__()

    # Stem: widen from the input through a narrow bottleneck to full width.
    stem = [
        nn.Linear(in_features, hidden_features // 4, bias=True),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_features // 4, hidden_features, bias=True),
        nn.ReLU(inplace=True),
    ]

    # Hidden stack: constant-width Linear + ReLU pairs.
    body = []
    for _ in range(hidden_layers):
        body += [
            nn.Linear(hidden_features, hidden_features, bias=True),
            nn.ReLU(inplace=True),
        ]

    # Head: project to the flattened PSF and squash each entry into (0, 1).
    head = [nn.Linear(hidden_features, out_features, bias=True), nn.Sigmoid()]

    self.net = nn.Sequential(*(stem + body + head))

forward

forward(x)

Forward pass.

Parameters:

Name Type Description Default
x

Input tensor of shape (batch_size, in_features).

required

Returns:

Type Description

L1-normalized output tensor of shape (batch_size, out_features).

Source code in deeplens/surrogate/mlp.py
def forward(self, x):
    """Forward pass.

    Args:
        x: Input tensor of shape ``(batch_size, in_features)``.

    Returns:
        L1-normalized output tensor of shape ``(batch_size, out_features)``.
    """
    raw = self.net(x)
    # Rescale so each sample's entries sum to 1 — a valid PSF energy
    # distribution (network output is Sigmoid-activated, hence non-negative).
    return F.normalize(raw, p=1, dim=-1)

deeplens.surrogate.MLPConv

MLPConv(in_features, ks, channels=3, activation='relu')

Bases: Module

MLP encoder + convolutional decoder for high-resolution PSF prediction.

Uses a linear encoder to produce a low-resolution feature map, then a transposed-convolution decoder to upsample to the target PSF resolution. The output is Sigmoid-activated and L1-normalized jointly over the two spatial dimensions, so each channel sums to 1.

Reference: "Differentiable Compound Optics and Processing Pipeline Optimization for End-To-end Camera Design".

Parameters:

Name Type Description Default
in_features

Number of input features (e.g., field angle + wavelength).

required
ks

Spatial size of the output PSF. Must be a multiple of 32 if > 32.

required
channels

Number of output channels. Defaults to 3.

3
activation

Activation function, "relu" or "sigmoid". Defaults to "relu".

'relu'
Source code in deeplens/surrogate/mlpconv.py
def __init__(self, in_features, ks, channels=3, activation="relu"):
    """Build the MLP encoder and transposed-convolution decoder.

    Args:
        in_features: Number of input features (e.g., field angle + wavelength).
        ks: Spatial size of the output PSF. Must be a multiple of 32 if > 32.
        channels: Number of output channels. Defaults to 3.
        activation: Activation function, "relu" or "sigmoid". Defaults to
            "relu".

    Raises:
        ValueError: If ``activation`` is neither "relu" nor "sigmoid".
    """
    super(MLPConv, self).__init__()

    # The MLP emits at most a 32x32 feature map; larger targets are reached
    # with 2x upsampling stages in the conv decoder.
    self.ks_mlp = min(ks, 32)
    if ks > 32:
        assert ks % 32 == 0, "ks must be 32n"
        # Exact for powers of two since ks is a multiple of 32.
        upsample_times = int(math.log2(ks // 32))
    else:
        # BUG FIX: upsample_times was previously undefined when ks <= 32,
        # raising a NameError in the decoder-construction loop below.
        upsample_times = 0

    linear_output = channels * self.ks_mlp**2
    self.ks = ks
    self.channels = channels

    # MLP encoder: lens parameters -> flattened low-resolution feature map.
    self.encoder = nn.Sequential(
        nn.Linear(in_features, 256),
        nn.ReLU(),
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Linear(256, 512),
        nn.ReLU(),
        nn.Linear(512, linear_output),
    )

    # Conv decoder: refine and (optionally) upsample to the target size.
    conv_layers = []
    conv_layers.append(
        nn.ConvTranspose2d(channels, 64, kernel_size=3, stride=1, padding=1)
    )
    conv_layers.append(nn.ReLU())
    for _ in range(upsample_times):
        conv_layers.append(
            nn.ConvTranspose2d(64, 64, kernel_size=3, stride=1, padding=1)
        )
        conv_layers.append(nn.ReLU())
        conv_layers.append(nn.Upsample(scale_factor=2))

    conv_layers.append(
        nn.ConvTranspose2d(64, 64, kernel_size=3, stride=1, padding=1)
    )
    conv_layers.append(nn.ReLU())
    conv_layers.append(
        nn.ConvTranspose2d(64, channels, kernel_size=3, stride=1, padding=1)
    )
    self.decoder = nn.Sequential(*conv_layers)

    if activation == "relu":
        self.activation = nn.ReLU()
    elif activation == "sigmoid":
        self.activation = nn.Sigmoid()
    else:
        # BUG FIX: an unrecognized activation previously left
        # self.activation unset, deferring the failure to first use.
        raise ValueError(
            f"activation must be 'relu' or 'sigmoid', got {activation!r}"
        )
forward

forward(x)

Forward pass.

Parameters:

Name Type Description Default
x

Input tensor of shape (batch_size, in_features).

required

Returns:

Type Description

Normalized PSF tensor of shape (batch_size, channels, ks, ks).

Source code in deeplens/surrogate/mlpconv.py
def forward(self, x):
    """Forward pass.

    Args:
        x: Input tensor of shape ``(batch_size, in_features)``.

    Returns:
        Normalized PSF tensor of shape ``(batch_size, channels, ks, ks)``.
    """
    # MLP encoder: lens parameters -> flattened low-resolution feature map.
    features = self.encoder(x)

    # Unflatten to (batch_size, channels, height, width) for the conv decoder.
    features = features.view(-1, self.channels, self.ks_mlp, self.ks_mlp)

    # Conv decoder refines (and upsamples) to the target resolution.
    psf = self.decoder(features)

    # Squash into (0, 1), then L1-normalize each channel over its spatial
    # extent so it sums to 1 — valid only for PSF-style outputs.
    psf = nn.Sigmoid()(psf)
    psf = F.normalize(psf, p=1, dim=[-1, -2])

    return psf

deeplens.surrogate.siren.Siren

Siren(dim_in, dim_out, w0=1.0, c=6.0, is_first=False, use_bias=True, activation=None)

Bases: Module

Single SIREN (Sinusoidal Representation Network) layer.

A linear layer followed by a sine activation. Uses the initialization scheme from "Implicit Neural Representations with Periodic Activation Functions".

Parameters:

Name Type Description Default
dim_in

Input dimension.

required
dim_out

Output dimension.

required
w0

Frequency multiplier for the sine activation. Defaults to 1.0.

1.0
c

Constant for weight initialization. Defaults to 6.0.

6.0
is_first

Whether this is the first layer (uses different init). Defaults to False.

False
use_bias

Whether to include a bias term. Defaults to True.

True
activation

Custom activation module. If None, uses Sine(w0).

None
Source code in deeplens/surrogate/siren.py
def __init__(
    self,
    dim_in,
    dim_out,
    w0=1.0,
    c=6.0,
    is_first=False,
    use_bias=True,
    activation=None,
):
    """Create one SIREN layer: a linear map followed by a sine activation.

    Args:
        dim_in: Input dimension.
        dim_out: Output dimension.
        w0: Frequency multiplier for the sine activation. Defaults to 1.0.
        c: Constant for weight initialization. Defaults to 6.0.
        is_first: Whether this is the first layer (uses different init).
            Defaults to False.
        use_bias: Whether to include a bias term. Defaults to True.
        activation: Custom activation module. If None, uses ``Sine(w0)``.
    """
    super().__init__()
    self.dim_in = dim_in
    self.is_first = is_first

    # Allocate raw tensors first so the SIREN init scheme can fill them
    # in place, then wrap them as trainable parameters.
    w = torch.zeros(dim_out, dim_in)
    b = torch.zeros(dim_out) if use_bias else None
    self.init_(w, b, c=c, w0=w0)

    self.weight = nn.Parameter(w)
    self.bias = None if b is None else nn.Parameter(b)
    self.activation = activation if activation is not None else Sine(w0)

forward

forward(x)

Forward pass.

Parameters:

Name Type Description Default
x

Input tensor of shape (..., dim_in).

required

Returns:

Type Description

Output tensor of shape (..., dim_out).

Source code in deeplens/surrogate/siren.py
def forward(self, x):
    """Forward pass.

    Args:
        x: Input tensor of shape ``(..., dim_in)``.

    Returns:
        Output tensor of shape ``(..., dim_out)``.
    """
    # Affine transform, then the (sine) activation.
    pre_activation = F.linear(x, self.weight, self.bias)
    return self.activation(pre_activation)

deeplens.surrogate.ModulateSiren

ModulateSiren(dim_in, dim_hidden, dim_out, dim_latent, num_layers, image_width, image_height, w0=1.0, w0_initial=30.0, use_bias=True, final_activation=None, outermost_linear=True)

Bases: Module

Modulated SIREN for latent-conditioned image synthesis.

Combines a SIREN synthesizer network (mapping pixel coordinates to output values) with a modulator network that conditions each layer based on a latent vector. Used to predict spatially-varying PSFs conditioned on lens parameters.

Parameters:

Name Type Description Default
dim_in

Input coordinate dimension (typically 2 for x, y).

required
dim_hidden

Hidden layer width for both synthesizer and modulator.

required
dim_out

Output dimension per pixel (e.g., 1 for grayscale PSF).

required
dim_latent

Dimension of the conditioning latent vector.

required
num_layers

Number of SIREN + modulator layers.

required
image_width

Output image width in pixels.

required
image_height

Output image height in pixels.

required
w0

Frequency multiplier for hidden sine layers. Defaults to 1.0.

1.0
w0_initial

Frequency multiplier for the first sine layer. Defaults to 30.0.

30.0
use_bias

Whether to use bias in sine layers. Defaults to True.

True
final_activation

Activation for the last layer. Defaults to None (linear).

None
outermost_linear

If True, the last layer is a plain linear layer. Defaults to True.

True
Source code in deeplens/surrogate/modulate_siren.py
def __init__(
    self,
    dim_in,
    dim_hidden,
    dim_out,
    dim_latent,
    num_layers,
    image_width,
    image_height,
    w0=1.0,
    w0_initial=30.0,
    use_bias=True,
    final_activation=None,
    outermost_linear=True,
):
    """Build the synthesizer stack, the modulator stack, and the pixel grid.

    Args:
        dim_in: Input coordinate dimension (typically 2 for x, y).
        dim_hidden: Hidden layer width for both synthesizer and modulator.
        dim_out: Output dimension per pixel (e.g., 1 for a grayscale PSF).
        dim_latent: Dimension of the conditioning latent vector.
        num_layers: Number of paired SIREN + modulator layers.
        image_width: Output image width in pixels.
        image_height: Output image height in pixels.
        w0: Frequency multiplier for hidden sine layers. Defaults to 1.0.
        w0_initial: Frequency multiplier for the first sine layer.
            Defaults to 30.0.
        use_bias: Whether to use bias in sine layers. Defaults to True.
        final_activation: Activation for the last layer, used only when
            ``outermost_linear`` is False. Defaults to None (identity).
        outermost_linear: If True, the last layer is a plain linear layer.
            Defaults to True.
    """
    super().__init__()
    self.num_layers = num_layers
    self.dim_hidden = dim_hidden
    self.img_width = image_width
    self.img_height = image_height

    # ==> Synthesizer
    # Kept as a ModuleList (not Sequential) so forward() can interleave
    # per-layer modulation between the sine layers.
    synthesizer_layers = nn.ModuleList([])
    for ind in range(num_layers):
        is_first = ind == 0
        # The first layer uses the larger w0_initial frequency, per the
        # SIREN initialization scheme.
        layer_w0 = w0_initial if is_first else w0
        layer_dim_in = dim_in if is_first else dim_hidden

        synthesizer_layers.append(
            SineLayer(
                in_features=layer_dim_in,
                out_features=dim_hidden,
                omega_0=layer_w0,
                bias=use_bias,
                is_first=is_first,
            )
        )

    if outermost_linear:
        # Plain linear head; Kaiming init replaces the SIREN uniform init
        # (the original variant is kept below for reference).
        last_layer = nn.Linear(dim_hidden, dim_out)
        with torch.no_grad():
            # w_std = math.sqrt(6 / dim_hidden) / w0
            # self.last_layer.weight.uniform_(- w_std, w_std)
            nn.init.kaiming_normal_(
                last_layer.weight, a=0.0, nonlinearity="relu", mode="fan_in"
            )
    else:
        # Sine head; fall back to identity when no activation is given.
        final_activation = (
            nn.Identity() if not exists(final_activation) else final_activation
        )
        last_layer = Siren(
            dim_in=dim_hidden,
            dim_out=dim_out,
            w0=w0,
            use_bias=use_bias,
            activation=final_activation,
        )
    synthesizer_layers.append(last_layer)

    self.synthesizer = synthesizer_layers
    # self.synthesizer = nn.Sequential(*synthesizer)

    # ==> Modulator
    # After the first layer the previous modulation is concatenated with
    # the latent (a skip connection), hence the wider input dimension.
    modulator_layers = nn.ModuleList([])
    for ind in range(num_layers):
        is_first = ind == 0
        dim = dim_latent if is_first else (dim_hidden + dim_latent)

        modulator_layers.append(
            nn.Sequential(nn.Linear(dim, dim_hidden), nn.ReLU())
        )

        with torch.no_grad():
            # self.layers[-1][0].weight.uniform_(-1 / dim_hidden, 1 / dim_hidden)
            nn.init.kaiming_normal_(
                modulator_layers[-1][0].weight,
                a=0.0,
                nonlinearity="relu",
                mode="fan_in",
            )

    self.modulator = modulator_layers
    # self.modulator = nn.Sequential(*modulator_layers)

    # ==> Positions
    # Fixed pixel-coordinate grid spanning [-1, 1] on each axis, flattened
    # to (image_height * image_width, 2) and registered as a buffer so it
    # follows the module across devices without being trained.
    tensors = [
        torch.linspace(-1, 1, steps=image_height),
        torch.linspace(-1, 1, steps=image_width),
    ]
    mgrid = torch.stack(torch.meshgrid(*tensors, indexing="ij"), dim=-1)
    mgrid = rearrange(mgrid, "h w c -> (h w) c")
    self.register_buffer("grid", mgrid)

forward

forward(latent)

Forward pass.

Parameters:

Name Type Description Default
latent

Conditioning latent vector of shape (batch_size, dim_latent).

required

Returns:

Type Description

Output image tensor of shape (batch_size, dim_out, image_height, image_width).

Source code in deeplens/surrogate/modulate_siren.py
def forward(self, latent):
    """Forward pass.

    Args:
        latent: Conditioning latent vector of shape ``(batch_size, dim_latent)``.

    Returns:
        Output image tensor of shape ``(batch_size, dim_out, image_height, image_width)``.
    """
    # Fresh leaf copy of the coordinate grid so gradients w.r.t. positions
    # can be taken without touching the registered buffer.
    coords = self.grid.clone().detach().requires_grad_()

    mod = None
    for idx in range(self.num_layers):
        # First modulator sees only the latent; later ones also receive
        # the previous modulation (latent skip connection).
        mod_input = latent if mod is None else torch.cat((latent, mod), dim=-1)
        mod = self.modulator[idx](mod_input)

        # Sine layer followed by elementwise (FiLM-style) modulation.
        coords = self.synthesizer[idx](coords) * mod

    out = self.synthesizer[-1](coords)  # shape of (h*w, 1)
    out = torch.tanh(out)
    out = out.view(
        -1, self.img_height, self.img_width, 1
    )  # reshape to (batch_size, height, width, channels)
    return out.permute(0, 3, 1, 2)  # (batch_size, channels, height, width)