
Network API Reference

The deeplens.network module provides neural networks for PSF prediction (surrogates) and image reconstruction, plus loss functions for training.


Surrogate Networks

Neural networks that learn to predict PSFs from lens parameters, replacing expensive ray tracing during training.
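
The typical workflow is to fit a surrogate to PSFs produced by the lens model's ray tracer, then query the much cheaper surrogate inside the image-simulation loop. The sketch below is illustrative only: sample_ground_truth_psfs is a hypothetical placeholder for whatever produces ray-traced training PSFs in your pipeline, and the shapes are example values.

import torch
import torch.nn.functional as F
from deeplens.network import MLPConv

def sample_ground_truth_psfs(batch_size):
    # Placeholder: in practice these come from the lens model's ray tracing.
    x = torch.rand(batch_size, 3)                         # (field_x, field_y, wavelength), normalized
    psf = torch.rand(batch_size, 3, 64, 64)
    return x, psf / psf.sum(dim=(-2, -1), keepdim=True)   # PSFs sum to 1 per channel

net = MLPConv(in_features=3, ks=64, channels=3)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

for step in range(1000):
    x, psf_gt = sample_ground_truth_psfs(batch_size=32)
    loss = F.l1_loss(net(x), psf_gt)                      # fit the surrogate to the traced PSFs
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()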

deeplens.network.MLP

MLP(in_features, out_features, hidden_features=64, hidden_layers=3)

Bases: Module

Fully-connected network for low-frequency PSF prediction.

Predicts PSFs as flattened vectors using stacked linear layers with ReLU activations and a Sigmoid output. The output is L1-normalized so it sums to 1 (valid as a PSF energy distribution).

Parameters:

    in_features: Number of input features (e.g., field angle + wavelength). Required.
    out_features: Number of output features (flattened PSF size). Required.
    hidden_features: Width of hidden layers. Defaults to 64.
    hidden_layers: Number of hidden layers. Defaults to 3.
Source code in deeplens/network/surrogate/mlp.py
def __init__(self, in_features, out_features, hidden_features=64, hidden_layers=3):
    super(MLP, self).__init__()

    layers = [
        nn.Linear(in_features, hidden_features // 4, bias=True),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_features // 4, hidden_features, bias=True),
        nn.ReLU(inplace=True),
    ]

    for _ in range(hidden_layers):
        layers.extend(
            [
                nn.Linear(hidden_features, hidden_features, bias=True),
                nn.ReLU(inplace=True),
            ]
        )

    layers.extend(
        [nn.Linear(hidden_features, out_features, bias=True), nn.Sigmoid()]
    )

    self.net = nn.Sequential(*layers)

forward

forward(x)

Forward pass.

Parameters:

    x: Input tensor of shape (batch_size, in_features). Required.

Returns:

    L1-normalized output tensor of shape (batch_size, out_features).

Source code in deeplens/network/surrogate/mlp.py
def forward(self, x):
    """Forward pass.

    Args:
        x: Input tensor of shape ``(batch_size, in_features)``.

    Returns:
        L1-normalized output tensor of shape ``(batch_size, out_features)``.
    """
    x = self.net(x)
    x = F.normalize(x, p=1, dim=-1)
    return x
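
A minimal usage sketch with illustrative shapes, assuming the import path shown in the heading above:

import torch
from deeplens.network import MLP

net = MLP(in_features=3, out_features=32 * 32)  # (field_x, field_y, wavelength) -> flattened 32x32 PSF
x = torch.rand(8, 3)                            # batch of 8 input configurations
psf = net(x)                                    # (8, 1024); each row sums to 1
psf = psf.view(8, 32, 32)                       # reshape to 2-D PSFs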

deeplens.network.MLPConv

MLPConv(in_features, ks, channels=3, activation='relu')

Bases: Module

MLP encoder + convolutional decoder for high-resolution PSF prediction.

Uses a linear encoder to produce a low-resolution feature map, then a transposed-convolution decoder to upsample to the target PSF resolution. The output is Sigmoid-activated and L1-normalized over the spatial dimensions, so each channel sums to 1.

Reference: "Differentiable Compound Optics and Processing Pipeline Optimization for End-To-end Camera Design".

Parameters:

    in_features: Number of input features (e.g., field angle + wavelength). Required.
    ks: Spatial size of the output PSF. For sizes above 32, the decoder upsamples from 32 by factors of 2, so ks should be 32 times a power of 2 (e.g., 64, 128). Required.
    channels: Number of output channels. Defaults to 3.
    activation: Activation function, "relu" or "sigmoid". Defaults to "relu".
Source code in deeplens/network/surrogate/mlpconv.py
def __init__(self, in_features, ks, channels=3, activation="relu"):
    super(MLPConv, self).__init__()

    self.ks_mlp = min(ks, 32)
    upsample_times = 0  # no upsampling stages needed when ks <= 32
    if ks > 32:
        assert ks % 32 == 0, "ks must be 32n"
        upsample_times = int(math.log(ks / 32, 2))

    linear_output = channels * self.ks_mlp**2
    self.ks = ks
    self.channels = channels

    # MLP encoder
    self.encoder = nn.Sequential(
        nn.Linear(in_features, 256),
        nn.ReLU(),
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Linear(256, 512),
        nn.ReLU(),
        nn.Linear(512, linear_output),
    )

    # Conv decoder
    conv_layers = []
    conv_layers.append(
        nn.ConvTranspose2d(channels, 64, kernel_size=3, stride=1, padding=1)
    )
    conv_layers.append(nn.ReLU())
    for _ in range(upsample_times):
        conv_layers.append(
            nn.ConvTranspose2d(64, 64, kernel_size=3, stride=1, padding=1)
        )
        conv_layers.append(nn.ReLU())
        conv_layers.append(nn.Upsample(scale_factor=2))

    conv_layers.append(
        nn.ConvTranspose2d(64, 64, kernel_size=3, stride=1, padding=1)
    )
    conv_layers.append(nn.ReLU())
    conv_layers.append(
        nn.ConvTranspose2d(64, channels, kernel_size=3, stride=1, padding=1)
    )
    self.decoder = nn.Sequential(*conv_layers)

    if activation == "relu":
        self.activation = nn.ReLU()
    elif activation == "sigmoid":
        self.activation = nn.Sigmoid()

forward

forward(x)

Forward pass.

Parameters:

    x: Input tensor of shape (batch_size, in_features). Required.

Returns:

    Normalized PSF tensor of shape (batch_size, channels, ks, ks).

Source code in deeplens/network/surrogate/mlpconv.py
def forward(self, x):
    """Forward pass.

    Args:
        x: Input tensor of shape ``(batch_size, in_features)``.

    Returns:
        Normalized PSF tensor of shape ``(batch_size, channels, ks, ks)``.
    """
    # Encode the input using the MLP
    encoded = self.encoder(x)

    # Reshape the output from the MLP to feed to the CNN
    decoded_input = encoded.view(
        -1, self.channels, self.ks_mlp, self.ks_mlp
    )  # reshape to (batch_size, channels, height, width)

    # Decode the output using the CNN
    decoded = self.decoder(decoded_input)

    # This normalization only works for PSF network
    decoded = nn.Sigmoid()(decoded)
    decoded = F.normalize(decoded, p=1, dim=[-1, -2])

    return decoded
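
A minimal usage sketch with illustrative shapes:

import torch
from deeplens.network import MLPConv

net = MLPConv(in_features=3, ks=64, channels=3)  # 64 = 32 * 2, so one 2x upsampling stage
x = torch.rand(4, 3)                             # (field_x, field_y, wavelength) per sample
psf = net(x)                                     # (4, 3, 64, 64); each channel sums to 1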

deeplens.network.surrogate.siren.Siren

Siren(dim_in, dim_out, w0=1.0, c=6.0, is_first=False, use_bias=True, activation=None)

Bases: Module

Single SIREN (Sinusoidal Representation Network) layer.

A linear layer followed by a sine activation. Uses the initialization scheme from "Implicit Neural Representations with Periodic Activation Functions".

Parameters:

    dim_in: Input dimension. Required.
    dim_out: Output dimension. Required.
    w0: Frequency multiplier for the sine activation. Defaults to 1.0.
    c: Constant for weight initialization. Defaults to 6.0.
    is_first: Whether this is the first layer (uses different initialization). Defaults to False.
    use_bias: Whether to include a bias term. Defaults to True.
    activation: Custom activation module. If None, uses Sine(w0). Defaults to None.
Source code in deeplens/network/surrogate/siren.py
def __init__(
    self,
    dim_in,
    dim_out,
    w0=1.0,
    c=6.0,
    is_first=False,
    use_bias=True,
    activation=None,
):
    super().__init__()
    self.dim_in = dim_in
    self.is_first = is_first

    weight = torch.zeros(dim_out, dim_in)
    bias = torch.zeros(dim_out) if use_bias else None
    self.init_(weight, bias, c=c, w0=w0)

    self.weight = nn.Parameter(weight)
    self.bias = nn.Parameter(bias) if use_bias else None
    self.activation = Sine(w0) if activation is None else activation

forward

forward(x)

Forward pass.

Parameters:

    x: Input tensor of shape (..., dim_in). Required.

Returns:

    Output tensor of shape (..., dim_out).

Source code in deeplens/network/surrogate/siren.py
def forward(self, x):
    """Forward pass.

    Args:
        x: Input tensor of shape ``(..., dim_in)``.

    Returns:
        Output tensor of shape ``(..., dim_out)``.
    """
    out = F.linear(x, self.weight, self.bias)
    out = self.activation(out)
    return out
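
Siren is a single layer rather than a full network; a small coordinate-to-value stack might look like the sketch below (illustrative dimensions, import path from the heading above):

import torch
from deeplens.network.surrogate.siren import Siren

first = Siren(dim_in=2, dim_out=64, w0=30.0, is_first=True)  # coordinate input layer
hidden = Siren(dim_in=64, dim_out=1)                          # default w0=1.0
coords = torch.rand(1024, 2) * 2 - 1                          # (x, y) in [-1, 1]
out = hidden(first(coords))                                    # (1024, 1)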

deeplens.network.ModulateSiren

ModulateSiren(dim_in, dim_hidden, dim_out, dim_latent, num_layers, image_width, image_height, w0=1.0, w0_initial=30.0, use_bias=True, final_activation=None, outermost_linear=True)

Bases: Module

Modulated SIREN for latent-conditioned image synthesis.

Combines a SIREN synthesizer network (mapping pixel coordinates to output values) with a modulator network that conditions each layer based on a latent vector. Used to predict spatially-varying PSFs conditioned on lens parameters.

Parameters:

    dim_in: Input coordinate dimension (typically 2 for x, y). Required.
    dim_hidden: Hidden layer width for both synthesizer and modulator. Required.
    dim_out: Output dimension per pixel (e.g., 1 for a grayscale PSF). Required.
    dim_latent: Dimension of the conditioning latent vector. Required.
    num_layers: Number of SIREN + modulator layers. Required.
    image_width: Output image width in pixels. Required.
    image_height: Output image height in pixels. Required.
    w0: Frequency multiplier for hidden sine layers. Defaults to 1.0.
    w0_initial: Frequency multiplier for the first sine layer. Defaults to 30.0.
    use_bias: Whether to use bias in sine layers. Defaults to True.
    final_activation: Activation for the last layer. Defaults to None (linear).
    outermost_linear: If True, the last layer is a plain linear layer. Defaults to True.
Source code in deeplens/network/surrogate/modulate_siren.py
def __init__(
    self,
    dim_in,
    dim_hidden,
    dim_out,
    dim_latent,
    num_layers,
    image_width,
    image_height,
    w0=1.0,
    w0_initial=30.0,
    use_bias=True,
    final_activation=None,
    outermost_linear=True,
):
    super().__init__()
    self.num_layers = num_layers
    self.dim_hidden = dim_hidden
    self.img_width = image_width
    self.img_height = image_height

    # ==> Synthesizer
    synthesizer_layers = nn.ModuleList([])
    for ind in range(num_layers):
        is_first = ind == 0
        layer_w0 = w0_initial if is_first else w0
        layer_dim_in = dim_in if is_first else dim_hidden

        synthesizer_layers.append(
            SineLayer(
                in_features=layer_dim_in,
                out_features=dim_hidden,
                omega_0=layer_w0,
                bias=use_bias,
                is_first=is_first,
            )
        )

    if outermost_linear:
        last_layer = nn.Linear(dim_hidden, dim_out)
        with torch.no_grad():
            # w_std = math.sqrt(6 / dim_hidden) / w0
            # self.last_layer.weight.uniform_(- w_std, w_std)
            nn.init.kaiming_normal_(
                last_layer.weight, a=0.0, nonlinearity="relu", mode="fan_in"
            )
    else:
        final_activation = (
            nn.Identity() if not exists(final_activation) else final_activation
        )
        last_layer = Siren(
            dim_in=dim_hidden,
            dim_out=dim_out,
            w0=w0,
            use_bias=use_bias,
            activation=final_activation,
        )
    synthesizer_layers.append(last_layer)

    self.synthesizer = synthesizer_layers
    # self.synthesizer = nn.Sequential(*synthesizer)

    # ==> Modulator
    modulator_layers = nn.ModuleList([])
    for ind in range(num_layers):
        is_first = ind == 0
        dim = dim_latent if is_first else (dim_hidden + dim_latent)

        modulator_layers.append(
            nn.Sequential(nn.Linear(dim, dim_hidden), nn.ReLU())
        )

        with torch.no_grad():
            # self.layers[-1][0].weight.uniform_(-1 / dim_hidden, 1 / dim_hidden)
            nn.init.kaiming_normal_(
                modulator_layers[-1][0].weight,
                a=0.0,
                nonlinearity="relu",
                mode="fan_in",
            )

    self.modulator = modulator_layers
    # self.modulator = nn.Sequential(*modulator_layers)

    # ==> Positions
    tensors = [
        torch.linspace(-1, 1, steps=image_height),
        torch.linspace(-1, 1, steps=image_width),
    ]
    mgrid = torch.stack(torch.meshgrid(*tensors, indexing="ij"), dim=-1)
    mgrid = rearrange(mgrid, "h w c -> (h w) c")
    self.register_buffer("grid", mgrid)

forward

forward(latent)

Forward pass.

Parameters:

    latent: Conditioning latent vector of shape (batch_size, dim_latent). Required.

Returns:

    Output image tensor of shape (batch_size, dim_out, image_height, image_width).

Source code in deeplens/network/surrogate/modulate_siren.py
def forward(self, latent):
    """Forward pass.

    Args:
        latent: Conditioning latent vector of shape ``(batch_size, dim_latent)``.

    Returns:
        Output image tensor of shape ``(batch_size, dim_out, image_height, image_width)``.
    """
    x = self.grid.clone().detach().requires_grad_()

    for i in range(self.num_layers):
        if i == 0:
            z = self.modulator[i](latent)
        else:
            z = self.modulator[i](torch.cat((latent, z), dim=-1))

        x = self.synthesizer[i](x)
        x = x * z

    x = self.synthesizer[-1](x)  # shape of (h*w, 1)
    x = torch.tanh(x)
    x = x.view(
        -1, self.img_height, self.img_width, 1
    )  # reshape to (batch_size, height, width, channels)
    x = x.permute(0, 3, 1, 2)  # reshape to (batch_size, channels, height, width)
    return x
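
A minimal usage sketch with a single latent vector (illustrative dimensions):

import torch
from deeplens.network import ModulateSiren

net = ModulateSiren(
    dim_in=2, dim_hidden=128, dim_out=1, dim_latent=16,
    num_layers=5, image_width=64, image_height=64,
)
latent = torch.rand(1, 16)  # encoding of lens parameters / field position
psf = net(latent)           # (1, 1, 64, 64)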

Reconstruction Networks

Image restoration networks that recover a clean image from a degraded (aberrated) sensor capture.

deeplens.network.NAFNet

NAFNet(in_chan=3, out_chan=3, width=32, middle_blk_num=1, enc_blk_nums=[1, 1, 1, 28], dec_blk_nums=[1, 1, 1, 1])

Bases: Module

Nonlinear Activation Free Network for image restoration.

A U-Net-style encoder-decoder with NAFBlocks that replace nonlinear activations with SimpleGate (element-wise multiplication of channel-split halves). Includes a global residual connection from input to output.

Reference: "Simple Baselines for Image Restoration" (ECCV 2022).

Parameters:

    in_chan: Number of input channels. Defaults to 3.
    out_chan: Number of output channels. Defaults to 3.
    width: Base channel width. Defaults to 32.
    middle_blk_num: Number of NAFBlocks in the bottleneck. Defaults to 1.
    enc_blk_nums: Number of NAFBlocks per encoder stage. Defaults to [1, 1, 1, 28].
    dec_blk_nums: Number of NAFBlocks per decoder stage. Defaults to [1, 1, 1, 1].
Source code in deeplens/network/reconstruction/nafnet.py
def __init__(
    self,
    in_chan=3,
    out_chan=3,
    width=32,  # 64
    middle_blk_num=1,
    enc_blk_nums=[1, 1, 1, 28],
    dec_blk_nums=[1, 1, 1, 1],
):
    super().__init__()

    self.intro = nn.Conv2d(
        in_channels=in_chan,
        out_channels=width,
        kernel_size=3,
        padding=1,
        stride=1,
        groups=1,
        bias=True,
    )
    self.ending = nn.Conv2d(
        in_channels=width,
        out_channels=out_chan,
        kernel_size=3,
        padding=1,
        stride=1,
        groups=1,
        bias=True,
    )

    self.encoders = nn.ModuleList()
    self.decoders = nn.ModuleList()
    self.middle_blks = nn.ModuleList()
    self.ups = nn.ModuleList()
    self.downs = nn.ModuleList()

    chan = width
    for num in enc_blk_nums:
        self.encoders.append(nn.Sequential(*[NAFBlock(chan) for _ in range(num)]))
        self.downs.append(nn.Conv2d(chan, 2 * chan, 2, 2))
        chan = chan * 2

    self.middle_blks = nn.Sequential(
        *[NAFBlock(chan) for _ in range(middle_blk_num)]
    )

    for num in dec_blk_nums:
        self.ups.append(
            nn.Sequential(
                nn.Conv2d(chan, chan * 2, 1, bias=False), nn.PixelShuffle(2)
            )
        )
        chan = chan // 2
        self.decoders.append(nn.Sequential(*[NAFBlock(chan) for _ in range(num)]))

    self.padder_size = 2 ** len(self.encoders)

    # Initialize weights  
    self.initialize_weights()  

forward

forward(inp)

Forward pass with global residual connection.

Parameters:

    inp: Input image tensor of shape (B, in_chan, H, W). Required.

Returns:

    Restored image tensor of shape (B, out_chan, H, W).

Source code in deeplens/network/reconstruction/nafnet.py
def forward(self, inp):
    """Forward pass with global residual connection.

    Args:
        inp: Input image tensor of shape ``(B, in_chan, H, W)``.

    Returns:
        Restored image tensor of shape ``(B, out_chan, H, W)``.
    """
    B, C, H, W = inp.shape
    inp = self.check_image_size(inp)

    x = self.intro(inp)

    encs = []

    for encoder, down in zip(self.encoders, self.downs):
        x = encoder(x)
        encs.append(x)
        x = down(x)

    x = self.middle_blks(x)

    for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
        x = up(x)
        x = x + enc_skip
        x = decoder(x)

    x = self.ending(x)
    x = x + inp[:, :x.shape[1], :, :]

    return x[:, :, :H, :W]
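
A minimal usage sketch with a random tensor standing in for an aberrated capture; the encoder depth here is reduced for a quick test (the default is [1, 1, 1, 28]):

import torch
from deeplens.network import NAFNet

net = NAFNet(in_chan=3, out_chan=3, width=32,
             middle_blk_num=1, enc_blk_nums=[1, 1, 1, 1], dec_blk_nums=[1, 1, 1, 1])
blurred = torch.rand(1, 3, 256, 256)  # aberrated sensor capture in [0, 1]
with torch.no_grad():
    restored = net(blurred)           # (1, 3, 256, 256); output is cropped back to the input H and W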

deeplens.network.UNet

UNet(in_channels=3, out_channels=3)

Bases: Module

U-Net with residual skip connections for image restoration.

A 3-level encoder-decoder with dense BasicBlocks and PixelShuffle upsampling. Uses additive skip connections between encoder and decoder stages.

Parameters:

    in_channels: Number of input channels. Defaults to 3.
    out_channels: Number of output channels. Defaults to 3.
Source code in deeplens/network/reconstruction/unet.py
def __init__(self, in_channels=3, out_channels=3):
    super().__init__()
    self.pre = nn.Sequential(
        nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1), nn.PReLU(16)
    )
    self.conv00 = BasicBlock(16, 32)
    self.down0 = nn.MaxPool2d((2, 2))
    self.conv10 = BasicBlock(32, 64)
    self.down1 = nn.MaxPool2d((2, 2))
    self.conv20 = BasicBlock(64, 128)
    self.down2 = nn.MaxPool2d((2, 2))
    self.conv30 = BasicBlock(128, 256)
    self.conv31 = BasicBlock(256, 512)
    self.up2 = nn.PixelShuffle(2)
    self.conv21 = BasicBlock(128, 256)
    self.up1 = nn.PixelShuffle(2)
    self.conv11 = BasicBlock(64, 128)
    self.up0 = nn.PixelShuffle(2)
    self.conv01 = BasicBlock(32, 64)

    self.post = nn.Sequential(
        nn.Conv2d(64, 16, kernel_size=3, stride=1, padding=1),
        nn.PReLU(16),
        nn.Conv2d(16, out_channels, kernel_size=3, stride=1, padding=1),
    )

forward

forward(x)

Forward pass.

Parameters:

    x: Input image tensor of shape (B, in_channels, H, W). Required.

Returns:

    Output tensor of shape (B, out_channels, H, W).

Source code in deeplens/network/reconstruction/unet.py
def forward(self, x):
    """Forward pass.

    Args:
        x: Input image tensor of shape ``(B, in_channels, H, W)``.

    Returns:
        Output tensor of shape ``(B, out_channels, H, W)``.
    """
    x0 = self.pre(x)
    x0 = self.conv00(x0)
    x1 = self.down0(x0)
    x1 = self.conv10(x1)
    x2 = self.down1(x1)
    x2 = self.conv20(x2)
    x3 = self.down2(x2)
    x3 = self.conv30(x3)
    x3 = self.conv31(x3)
    x2 = x2 + self.up2(x3)
    x2 = self.conv21(x2)
    x1 = x1 + self.up1(x2)
    x1 = self.conv11(x1)
    x0 = x0 + self.up0(x1)
    x0 = self.conv01(x0)
    x = self.post(x0)
    return x
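
A minimal usage sketch; note that this network does no internal padding, so H and W should be divisible by 8 for the three 2x downsampling stages to align with the additive skips:

import torch
from deeplens.network import UNet

net = UNet(in_channels=3, out_channels=3)
x = torch.rand(1, 3, 128, 128)
with torch.no_grad():
    y = net(x)  # (1, 3, 128, 128)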

deeplens.network.Restormer

Restormer(inp_channels=3, out_channels=3, dim=48, num_blocks=[4, 6, 6, 8], num_refinement_blocks=4, heads=[1, 2, 4, 8], ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', dual_pixel_task=False)

Bases: Module

Restormer: Efficient Transformer for high-resolution image restoration.

A multi-scale encoder-decoder transformer using Multi-DConv Head Transposed Self-Attention (MDTA) and Gated-DConv Feed-Forward Networks (GDFN). Includes a global residual connection from input to output.

Reference: Zamir et al., "Restormer: Efficient Transformer for High-Resolution Image Restoration" (CVPR 2022).

Parameters:

    inp_channels: Number of input channels. Defaults to 3.
    out_channels: Number of output channels. Defaults to 3.
    dim: Base embedding dimension. Defaults to 48.
    num_blocks: Number of transformer blocks per encoder/decoder stage. Defaults to [4, 6, 6, 8].
    num_refinement_blocks: Number of refinement blocks after the decoder. Defaults to 4.
    heads: Number of attention heads per stage. Defaults to [1, 2, 4, 8].
    ffn_expansion_factor: Hidden dimension multiplier in the GDFN. Defaults to 2.66.
    bias: Whether to use bias in convolutions. Defaults to False.
    LayerNorm_type: "WithBias" or "BiasFree". Defaults to "WithBias".
    dual_pixel_task: If True, uses a skip connection for dual-pixel defocus deblurring (also set inp_channels=6). Defaults to False.
Source code in deeplens/network/reconstruction/restormer.py
def __init__(
    self,
    inp_channels=3,
    out_channels=3,
    dim=48,
    num_blocks=[4, 6, 6, 8],
    num_refinement_blocks=4,
    heads=[1, 2, 4, 8],
    ffn_expansion_factor=2.66,
    bias=False,
    LayerNorm_type="WithBias",  ## Other option 'BiasFree'
    dual_pixel_task=False,  ## True for dual-pixel defocus deblurring only. Also set inp_channels=6
):
    super(Restormer, self).__init__()

    self.patch_embed = OverlapPatchEmbed(inp_channels, dim)

    self.encoder_level1 = nn.Sequential(
        *[
            TransformerBlock(
                dim=dim,
                num_heads=heads[0],
                ffn_expansion_factor=ffn_expansion_factor,
                bias=bias,
                LayerNorm_type=LayerNorm_type,
            )
            for i in range(num_blocks[0])
        ]
    )

    self.down1_2 = Downsample(dim)  ## From Level 1 to Level 2
    self.encoder_level2 = nn.Sequential(
        *[
            TransformerBlock(
                dim=int(dim * 2**1),
                num_heads=heads[1],
                ffn_expansion_factor=ffn_expansion_factor,
                bias=bias,
                LayerNorm_type=LayerNorm_type,
            )
            for i in range(num_blocks[1])
        ]
    )

    self.down2_3 = Downsample(int(dim * 2**1))  ## From Level 2 to Level 3
    self.encoder_level3 = nn.Sequential(
        *[
            TransformerBlock(
                dim=int(dim * 2**2),
                num_heads=heads[2],
                ffn_expansion_factor=ffn_expansion_factor,
                bias=bias,
                LayerNorm_type=LayerNorm_type,
            )
            for i in range(num_blocks[2])
        ]
    )

    self.down3_4 = Downsample(int(dim * 2**2))  ## From Level 3 to Level 4
    self.latent = nn.Sequential(
        *[
            TransformerBlock(
                dim=int(dim * 2**3),
                num_heads=heads[3],
                ffn_expansion_factor=ffn_expansion_factor,
                bias=bias,
                LayerNorm_type=LayerNorm_type,
            )
            for i in range(num_blocks[3])
        ]
    )

    self.up4_3 = Upsample(int(dim * 2**3))  ## From Level 4 to Level 3
    self.reduce_chan_level3 = nn.Conv2d(
        int(dim * 2**3), int(dim * 2**2), kernel_size=1, bias=bias
    )
    self.decoder_level3 = nn.Sequential(
        *[
            TransformerBlock(
                dim=int(dim * 2**2),
                num_heads=heads[2],
                ffn_expansion_factor=ffn_expansion_factor,
                bias=bias,
                LayerNorm_type=LayerNorm_type,
            )
            for i in range(num_blocks[2])
        ]
    )

    self.up3_2 = Upsample(int(dim * 2**2))  ## From Level 3 to Level 2
    self.reduce_chan_level2 = nn.Conv2d(
        int(dim * 2**2), int(dim * 2**1), kernel_size=1, bias=bias
    )
    self.decoder_level2 = nn.Sequential(
        *[
            TransformerBlock(
                dim=int(dim * 2**1),
                num_heads=heads[1],
                ffn_expansion_factor=ffn_expansion_factor,
                bias=bias,
                LayerNorm_type=LayerNorm_type,
            )
            for i in range(num_blocks[1])
        ]
    )

    self.up2_1 = Upsample(
        int(dim * 2**1)
    )  ## From Level 2 to Level 1  (NO 1x1 conv to reduce channels)

    self.decoder_level1 = nn.Sequential(
        *[
            TransformerBlock(
                dim=int(dim * 2**1),
                num_heads=heads[0],
                ffn_expansion_factor=ffn_expansion_factor,
                bias=bias,
                LayerNorm_type=LayerNorm_type,
            )
            for i in range(num_blocks[0])
        ]
    )

    self.refinement = nn.Sequential(
        *[
            TransformerBlock(
                dim=int(dim * 2**1),
                num_heads=heads[0],
                ffn_expansion_factor=ffn_expansion_factor,
                bias=bias,
                LayerNorm_type=LayerNorm_type,
            )
            for i in range(num_refinement_blocks)
        ]
    )

    #### For Dual-Pixel Defocus Deblurring Task ####
    self.dual_pixel_task = dual_pixel_task
    if self.dual_pixel_task:
        self.skip_conv = nn.Conv2d(dim, int(dim * 2**1), kernel_size=1, bias=bias)
    ###########################

    self.output = nn.Conv2d(
        int(dim * 2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias
    )

forward

forward(inp_img)

Forward pass with global residual connection.

Parameters:

    inp_img: Input image tensor of shape (B, inp_channels, H, W). Required.

Returns:

    Restored image tensor of shape (B, out_channels, H, W).

Source code in deeplens/network/reconstruction/restormer.py
def forward(self, inp_img):
    """Forward pass with global residual connection.

    Args:
        inp_img: Input image tensor of shape ``(B, inp_channels, H, W)``.

    Returns:
        Restored image tensor of shape ``(B, out_channels, H, W)``.
    """
    inp_enc_level1 = self.patch_embed(inp_img)
    out_enc_level1 = self.encoder_level1(inp_enc_level1)

    inp_enc_level2 = self.down1_2(out_enc_level1)
    out_enc_level2 = self.encoder_level2(inp_enc_level2)

    inp_enc_level3 = self.down2_3(out_enc_level2)
    out_enc_level3 = self.encoder_level3(inp_enc_level3)

    inp_enc_level4 = self.down3_4(out_enc_level3)
    latent = self.latent(inp_enc_level4)

    inp_dec_level3 = self.up4_3(latent)
    inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
    inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
    out_dec_level3 = self.decoder_level3(inp_dec_level3)

    inp_dec_level2 = self.up3_2(out_dec_level3)
    inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
    inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
    out_dec_level2 = self.decoder_level2(inp_dec_level2)

    inp_dec_level1 = self.up2_1(out_dec_level2)
    inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
    out_dec_level1 = self.decoder_level1(inp_dec_level1)

    out_dec_level1 = self.refinement(out_dec_level1)

    #### For Dual-Pixel Defocus Deblurring Task ####
    if self.dual_pixel_task:
        out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1)
        out_dec_level1 = self.output(out_dec_level1)
    ###########################
    else:
        out_dec_level1 = self.output(out_dec_level1) + inp_img

    return out_dec_level1
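
A minimal usage sketch with the default configuration; spatial size should be divisible by 8 so the three downsampling levels align with the skip concatenations:

import torch
from deeplens.network import Restormer

net = Restormer(inp_channels=3, out_channels=3, dim=48)
x = torch.rand(1, 3, 128, 128)
with torch.no_grad():
    y = net(x)  # (1, 3, 128, 128), with the global input-to-output residual applied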

Loss Functions

deeplens.network.PerceptualLoss

PerceptualLoss(device=None, weights=[1.0, 1.0, 1.0, 1.0, 1.0])

Bases: Module

Perceptual loss based on VGG16 features.

Initialize perceptual loss.

Parameters:

    device: Device to put the VGG model on. If None, uses CUDA if available. Defaults to None.
    weights: Weights for the different VGG feature layers. Defaults to [1.0, 1.0, 1.0, 1.0, 1.0].
Source code in deeplens/network/loss/perceptual_loss.py
def __init__(self, device=None, weights=[1.0, 1.0, 1.0, 1.0, 1.0]):
    """Initialize perceptual loss.

    Args:
        device: Device to put the VGG model on. If None, uses cuda if available.
        weights: Weights for different feature layers.
    """
    super(PerceptualLoss, self).__init__()

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    self.vgg = models.vgg16(weights=VGG16_Weights.DEFAULT).features.to(device)
    self.layer_name_mapping = {
        '3': "relu1_2",
        '8': "relu2_2",
        '15': "relu3_3",
        '22': "relu4_3",
        '29': "relu5_3"
    }

    self.weights = weights

    for param in self.vgg.parameters():
        param.requires_grad = False

forward

forward(x, y)

Calculate perceptual loss.

Parameters:

    x: Predicted tensor. Required.
    y: Target tensor. Required.

Returns:

    Perceptual loss value.

Source code in deeplens/network/loss/perceptual_loss.py
def forward(self, x, y):
    """Calculate perceptual loss.

    Args:
        x: Predicted tensor.
        y: Target tensor.

    Returns:
        Perceptual loss.
    """
    x_vgg, y_vgg = self._get_features(x), self._get_features(y)

    content_loss = 0.0
    for i, (key, value) in enumerate(x_vgg.items()):
        content_loss += self.weights[i] * torch.mean((value - y_vgg[key]) ** 2)

    return content_loss
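
A minimal usage sketch; the first instantiation downloads the pretrained VGG16 weights via torchvision, and the input tensors should live on the same device as the VGG model:

import torch
from deeplens.network import PerceptualLoss

criterion = PerceptualLoss(device="cpu")
pred = torch.rand(1, 3, 224, 224, requires_grad=True)
target = torch.rand(1, 3, 224, 224)
loss = criterion(pred, target)  # weighted MSE over VGG16 feature maps
loss.backward()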

deeplens.network.PSNRLoss

PSNRLoss(loss_weight=1.0, reduction='mean', toY=False)

Bases: Module

Peak Signal-to-Noise Ratio (PSNR) loss.

Initialize PSNR loss.

Parameters:

    loss_weight: Weight for the loss. Defaults to 1.0.
    reduction: Reduction method; only "mean" is supported. Defaults to "mean".
    toY: Whether to convert RGB to the Y channel. Defaults to False.
Source code in deeplens/network/loss/psnr_loss.py
def __init__(self, loss_weight=1.0, reduction="mean", toY=False):
    """Initialize PSNR loss.

    Args:
        loss_weight: Weight for the loss.
        reduction: Reduction method, only "mean" is supported.
        toY: Whether to convert RGB to Y channel.
    """
    super(PSNRLoss, self).__init__()
    assert reduction == "mean"
    self.loss_weight = loss_weight
    self.scale = 10 / np.log(10)
    self.toY = toY
    self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1)
    self.first = True

forward

forward(pred, target)

Calculate PSNR loss.

Parameters:

    pred: Predicted tensor. Required.
    target: Target tensor. Required.

Returns:

    PSNR loss value.

Source code in deeplens/network/loss/psnr_loss.py
def forward(self, pred, target):
    """Calculate PSNR loss.

    Args:
        pred: Predicted tensor.
        target: Target tensor.

    Returns:
        PSNR loss.
    """
    assert len(pred.size()) == 4
    if self.toY:
        if self.first:
            self.coef = self.coef.to(pred.device)
            self.first = False

        pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.0
        target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.0

        pred, target = pred / 255.0, target / 255.0

    return (
        self.loss_weight
        * self.scale
        * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()
    ) 
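
A minimal usage sketch; the returned value is proportional to the negative mean PSNR, so minimizing it maximizes PSNR:

import torch
from deeplens.network import PSNRLoss

criterion = PSNRLoss()
pred = torch.rand(2, 3, 64, 64, requires_grad=True)
target = torch.rand(2, 3, 64, 64)
loss = criterion(pred, target)  # negative mean PSNR (inputs assumed in [0, 1])
loss.backward()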

deeplens.network.SSIMLoss

SSIMLoss(window_size=11, size_average=True)

Bases: Module

Structural Similarity Index (SSIM) loss.

Initialize SSIM loss.

Parameters:

    window_size: Size of the window. Defaults to 11.
    size_average: Whether to average the loss. Defaults to True.
Source code in deeplens/network/loss/ssim_loss.py
def __init__(self, window_size=11, size_average=True):
    """Initialize SSIM loss.

    Args:
        window_size: Size of the window.
        size_average: Whether to average the loss.
    """
    super(SSIMLoss, self).__init__()
    self.window_size = window_size
    self.size_average = size_average
    self.channel = 1
    self.window = self._create_window(window_size, self.channel)

forward

forward(pred, target)

Calculate SSIM loss.

Parameters:

    pred: Predicted tensor. Required.
    target: Target tensor. Required.

Returns:

    1 - SSIM value.

Source code in deeplens/network/loss/ssim_loss.py
def forward(self, pred, target):
    """Calculate SSIM loss.

    Args:
        pred: Predicted tensor.
        target: Target tensor.

    Returns:
        1 - SSIM value.
    """
    return 1 - self._ssim(pred, target)
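
A minimal usage sketch with illustrative shapes:

import torch
from deeplens.network import SSIMLoss

criterion = SSIMLoss(window_size=11)
pred = torch.rand(2, 3, 64, 64, requires_grad=True)
target = torch.rand(2, 3, 64, 64)
loss = criterion(pred, target)  # 1 - SSIM; 0 means the images are identical
loss.backward()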