
Decoder

flyvision.task.decoder.ActivityDecoder

Bases: Module

Base class for decoding DMN activity.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| connectome | ConnectomeFromAvgFilters | Connectome directory with output_cell_types. | required |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| dvs_channels | LayerActivity | Dictionary of DVS channels. |
| num_parameters | NumberOfParams | Number of parameters in the model. |
| u | Tensor | u-coordinates of hexagonal grid. |
| v | Tensor | v-coordinates of hexagonal grid. |
| H | int | Height of the hexagonal grid. |
| W | int | Width of the hexagonal grid. |

Source code in flyvision/task/decoder.py
class ActivityDecoder(nn.Module):
    """
    Base class for decoding DMN activity.

    Args:
        connectome: Connectome directory with output_cell_types.

    Attributes:
        dvs_channels (LayerActivity): Dictionary of DVS channels.
        num_parameters (NumberOfParams): Number of parameters in the model.
        u (torch.Tensor): u-coordinates of hexagonal grid.
        v (torch.Tensor): v-coordinates of hexagonal grid.
        H (int): Height of the hexagonal grid.
        W (int): Width of the hexagonal grid.
    """

    dvs_channels: Union[Dict[str, torch.Tensor], LayerActivity]

    def __init__(self, connectome: ConnectomeFromAvgFilters):
        super().__init__()
        self.dvs_channels = LayerActivity(None, connectome, use_central=False)
        self.num_parameters = n_params(self)
        radius = connectome.config.extent
        self.u, self.v = get_hex_coords(radius)
        self.u -= self.u.min()
        self.v -= self.v.min()
        self.H, self.W = self.u.max() + 1, self.v.max() + 1

    def forward(self, activity: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass of the ActivityDecoder.

        Args:
            activity: Tensor of shape (n_samples, n_frames, n_cells).

        Returns:
            Dictionary of tensors with shape
            (n_samples, n_frames, output_cell_types, n_hexals).
        """
        self.dvs_channels.update(activity)
        return self.dvs_channels

forward

forward(activity)

Forward pass of the ActivityDecoder.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| activity | Tensor | Tensor of shape (n_samples, n_frames, n_cells). | required |

Returns:

| Type | Description |
| --- | --- |
| Dict[str, Tensor] | Dictionary of tensors with shape (n_samples, n_frames, output_cell_types, n_hexals). |

Source code in flyvision/task/decoder.py
def forward(self, activity: torch.Tensor) -> Dict[str, torch.Tensor]:
    """
    Forward pass of the ActivityDecoder.

    Args:
        activity: Tensor of shape (n_samples, n_frames, n_cells).

    Returns:
        Dictionary of tensors with shape
        (n_samples, n_frames, output_cell_types, n_hexals).
    """
    self.dvs_channels.update(activity)
    return self.dvs_channels
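
Example usage, as a minimal sketch (not part of the library source). It assumes `connectome` is a ConnectomeFromAvgFilters instance and `activity` is a tensor of network responses with shape (n_samples, n_frames, n_cells):

# Assumed setup: `connectome` and `activity` are placeholders obtained
# elsewhere, e.g. from a trained network and its stored responses.
from flyvision.task.decoder import ActivityDecoder

decoder = ActivityDecoder(connectome)
channels = decoder(activity)  # LayerActivity, updated in place with `activity`
output = channels.output      # (n_samples, n_frames, n_output_cell_types, n_hexals)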

flyvision.task.decoder.DecoderGAVP

Bases: ActivityDecoder

Fully convolutional decoder with optional global average pooling.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| connectome | ConnectomeFromAvgFilters | Connectome directory. | required |
| shape | List[int] | List of channel sizes for each layer. | required |
| kernel_size | int | Size of the convolutional kernel. | required |
| p_dropout | float | Dropout probability. | 0.5 |
| batch_norm | bool | Whether to use batch normalization. | True |
| n_out_features | Optional[int] | Number of output features. | None |
| const_weight | Optional[float] | Constant value for weight initialization. | None |
| normalize_last | bool | Whether to normalize the last layer. | True |
| activation | str | Activation function to use. | 'Softplus' |

Attributes:

| Name | Description |
| --- | --- |
| _out_channels | Number of output channels before reshaping. |
| out_channels | Total number of output channels. |
| n_out_features | Number of output features. |
| base | Base convolutional layers. |
| decoder | Decoder convolutional layers. |
| head | Head layers for global average pooling. |
| normalize_last | Whether to normalize the last layer. |
| num_parameters | Number of parameters in the model. |

Source code in flyvision/task/decoder.py
class DecoderGAVP(ActivityDecoder):
    """
    Fully convolutional decoder with optional global average pooling.

    Args:
        connectome: Connectome directory.
        shape: List of channel sizes for each layer.
        kernel_size: Size of the convolutional kernel.
        p_dropout: Dropout probability.
        batch_norm: Whether to use batch normalization.
        n_out_features: Number of output features.
        const_weight: Constant value for weight initialization.
        normalize_last: Whether to normalize the last layer.
        activation: Activation function to use.

    Attributes:
        _out_channels: Number of output channels before reshaping.
        out_channels: Total number of output channels.
        n_out_features: Number of output features.
        base: Base convolutional layers.
        decoder: Decoder convolutional layers.
        head: Head layers for global average pooling.
        normalize_last: Whether to normalize the last layer.
        num_parameters: Number of parameters in the model.
    """

    def __init__(
        self,
        connectome: ConnectomeFromAvgFilters,
        shape: List[int],
        kernel_size: int,
        p_dropout: float = 0.5,
        batch_norm: bool = True,
        n_out_features: Optional[int] = None,
        const_weight: Optional[float] = None,
        normalize_last: bool = True,
        activation: str = "Softplus",
    ):
        super().__init__(connectome)
        p = int((kernel_size - 1) / 2)
        in_channels = len(connectome.output_cell_types)
        out_channels = shape[-1]
        self._out_channels = out_channels
        self.out_channels = (
            out_channels * n_out_features if n_out_features is not None else out_channels
        )
        self.n_out_features = n_out_features

        self.base = []
        for c in shape[:-1]:
            if c == 0:
                continue
            self.base.append(
                Conv2dHexSpace(
                    in_channels,
                    c,
                    kernel_size,
                    const_weight=const_weight,
                    padding=p,
                )
            )
            if batch_norm:
                self.base.append(nn.BatchNorm2d(c))
            self.base.append(getattr(nn, activation)())
            if p_dropout:
                self.base.append(nn.Dropout(p_dropout))
            in_channels = c
        self.base = nn.Sequential(*self.base)

        self.decoder = []
        if len(self.base) == 0 and batch_norm:
            self.decoder.append(nn.BatchNorm2d(in_channels))
        self.decoder.append(
            Conv2dHexSpace(
                in_channels,
                self.out_channels + 1 if normalize_last else self.out_channels,
                kernel_size,
                const_weight=const_weight,
                padding=p,
            )
        )
        self.decoder = nn.Sequential(*self.decoder)

        self.n_out_features = n_out_features
        self.head = []
        if n_out_features is not None:
            self.head.append(GlobalAvgPool())
        self.head = nn.Sequential(*self.head)

        self.normalize_last = normalize_last

        self.num_parameters = n_params(self)
        logging.info(f"Initialized decoder with {self.num_parameters} parameters.")
        logging.info(repr(self))

    def forward(self, activity: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the DecoderGAVP.

        Args:
            activity: Input activity tensor.

        Returns:
            Decoded output tensor.
        """
        self.dvs_channels.update(activity)
        # Ensure that the outputs of the dvs-model are rectified potentials.
        x = nnf.relu(self.dvs_channels.output)

        # (n_samples, n_frames, #outputneurons, n_hexals)
        n_samples, n_frames, in_channels, n_hexals = x.shape

        # Store hexals in square map.
        # (n_samples, n_frames, #outputneurons, H, W)
        x_map = torch.zeros([n_samples, n_frames, in_channels, self.H, self.W])
        x_map[:, :, :, self.u, self.v] = x

        # Concatenate actual batch dimension with the frame dimension.
        # torch.flatten(x_map, 0, 1)  # (#samples*n_frames, #outputneurons, H, W)
        x_map = x_map.view(-1, in_channels, self.H, self.W)

        # Run decoder.
        # (n_frames*#samples, out_channels + 1, H, W)
        out = self.decoder(self.base(x_map))

        if self.normalize_last:
            # Do some normalization with the additional channel.
            # (n_frames*#samples, out_channels, H, W)
            out = out[:, : self.out_channels] / (
                nnf.softplus(out[:, self.out_channels :]) + 1
            )

        # Bring back into shape: # (#samples, n_frames, out_channels, n_hexals)
        out = out.view(n_samples, n_frames, self.out_channels, self.H, self.W)[
            :, :, :, self.u, self.v
        ]

        if self.n_out_features is not None:
            out = self.head(out).view(
                n_samples, n_frames, self._out_channels, self.n_out_features
            )

        return out

forward

forward(activity)

Forward pass of the DecoderGAVP.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| activity | Tensor | Input activity tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Decoded output tensor. |

Source code in flyvision/task/decoder.py
def forward(self, activity: torch.Tensor) -> torch.Tensor:
    """
    Forward pass of the DecoderGAVP.

    Args:
        activity: Input activity tensor.

    Returns:
        Decoded output tensor.
    """
    self.dvs_channels.update(activity)
    # Ensure that the outputs of the dvs-model are rectified potentials.
    x = nnf.relu(self.dvs_channels.output)

    # (n_samples, n_frames, #outputneurons, n_hexals)
    n_samples, n_frames, in_channels, n_hexals = x.shape

    # Store hexals in square map.
    # (n_samples, n_frames, #outputneurons, H, W)
    x_map = torch.zeros([n_samples, n_frames, in_channels, self.H, self.W])
    x_map[:, :, :, self.u, self.v] = x

    # Concatenate actual batch dimension with the frame dimension.
    # torch.flatten(x_map, 0, 1)  # (#samples*n_frames, #outputneurons, H, W)
    x_map = x_map.view(-1, in_channels, self.H, self.W)

    # Run decoder.
    # (n_frames*#samples, out_channels + 1, H, W)
    out = self.decoder(self.base(x_map))

    if self.normalize_last:
        # Do some normalization with the additional channel.
        # (n_frames*#samples, out_channels, H, W)
        out = out[:, : self.out_channels] / (
            nnf.softplus(out[:, self.out_channels :]) + 1
        )

    # Bring back into shape: # (#samples, n_frames, out_channels, n_hexals)
    out = out.view(n_samples, n_frames, self.out_channels, self.H, self.W)[
        :, :, :, self.u, self.v
    ]

    if self.n_out_features is not None:
        out = self.head(out).view(
            n_samples, n_frames, self._out_channels, self.n_out_features
        )

    return out
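
For illustration, a hedged sketch of constructing a DecoderGAVP with a two-channel per-hexal output. The `shape`, `kernel_size`, and dropout values below are illustrative assumptions, not recommended settings; `connectome` and `activity` are assumed as in the ActivityDecoder example above:

from flyvision.task.decoder import DecoderGAVP

decoder = DecoderGAVP(
    connectome,
    shape=[8, 2],         # one hidden hex-conv layer with 8 channels, 2 output channels
    kernel_size=5,
    p_dropout=0.5,
    n_out_features=None,  # keep the per-hexal map; set an int to pool globally
)

out = decoder(activity)   # (n_samples, n_frames, 2, n_hexals)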

flyvision.task.decoder.init_decoder

init_decoder(decoder_config, connectome)

Initialize a decoder based on the provided configuration.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| decoder_config | Namespace | Configuration for the decoder. | required |
| connectome | ConnectomeFromAvgFilters | Connectome directory. | required |

Returns:

| Type | Description |
| --- | --- |
| Module | Initialized decoder module. |

Source code in flyvision/task/decoder.py
def init_decoder(
    decoder_config: Namespace, connectome: ConnectomeFromAvgFilters
) -> nn.Module:
    """
    Initialize a decoder based on the provided configuration.

    Args:
        decoder_config: Configuration for the decoder.
        connectome: Connectome directory.

    Returns:
        Initialized decoder module.
    """
    decoder_config = decoder_config.deepcopy()
    _type = decoder_config.pop("type")
    decoder_type = globals()[_type]
    decoder_config.update(dict(connectome=connectome))
    return decoder_type(**decoder_config)
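
A sketch of the expected configuration pattern: `type` names a decoder class defined in this module, and the remaining fields are passed to its constructor together with the connectome. The field values below are illustrative assumptions, and the Namespace import is assumed to come from datamate (which flyvision builds on); adjust if your setup provides it elsewhere.

from datamate import Namespace  # assumed import path for flyvision's Namespace

from flyvision.task.decoder import init_decoder

decoder_config = Namespace(
    type="DecoderGAVP",   # resolved via globals() to the DecoderGAVP class
    shape=[8, 2],
    kernel_size=5,
    p_dropout=0.5,
    n_out_features=None,
)
decoder = init_decoder(decoder_config, connectome)  # connectome assumed as above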

flyvision.task.decoder.Conv2dHexSpace

Bases: Conv2dConstWeight

Convolution with regular, hexagonally shaped filters (in Cartesian map storage).

Reference to map storage: https://www.redblobgames.com/grids/hexagons/#map-storage

Info

kernel_size must be odd!

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| in_channels | int | Number of input channels. | required |
| out_channels | int | Number of output channels. | required |
| kernel_size | int | Size of the convolutional kernel. | required |
| const_weight | Optional[float] | Optional constant value for weight initialization. If None, the standard PyTorch initialization is used. | 0.001 |
| stride | int | Stride of the convolution. | 1 |
| padding | int | Padding added to input. | 0 |
| **kwargs | | Additional keyword arguments for Conv2d. | {} |

Attributes:

| Name | Description |
| --- | --- |
| mask | Mask for hexagonal convolution. |
| _filter_to_hex | Whether to apply hexagonal filter. |

Source code in flyvision/task/decoder.py
class Conv2dHexSpace(Conv2dConstWeight):
    """
    Convolution with regular, hexagonally shaped filters (in Cartesian map storage).

    Reference to map storage:
    https://www.redblobgames.com/grids/hexagons/#map-storage

    Info:
        kernel_size must be odd!

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Size of the convolutional kernel.
        const_weight: Optional constant value for weight initialization.
            If None, the standard PyTorch initialization is used.
        stride: Stride of the convolution.
        padding: Padding added to input.
        **kwargs: Additional keyword arguments for Conv2d.

    Attributes:
        mask: Mask for hexagonal convolution.
        _filter_to_hex: Whether to apply hexagonal filter.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        const_weight: Optional[float] = 1e-3,
        stride: int = 1,
        padding: int = 0,
        **kwargs,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            const_weight,
            stride=stride,
            padding=padding,
            **kwargs,
        )

        if not kernel_size % 2:
            raise ValueError(f"{kernel_size} is even. Must be odd.")
        if kernel_size > 1:
            u, v = get_hex_coords(kernel_size // 2)
            u -= u.min()
            v -= v.min()
            mask = np.zeros(tuple(self.weight.shape))
            mask[:, :, u, v] = 1
            self.mask = torch.tensor(mask, device="cpu")
            self.weight.data.mul_(self.mask.to(device))
            self._filter_to_hex = True
        else:
            self._filter_to_hex = False

    def filter_to_hex(self):
        """Apply hexagonal filter to weights."""
        self.weight.data.mul_(self.mask.to(device))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the Conv2dHexSpace layer.

        Args:
            x: Input tensor.

        Returns:
            Output tensor after hexagonal convolution.
        """
        if self._filter_to_hex:
            self.filter_to_hex()
        return super().forward(x)

filter_to_hex

filter_to_hex()

Apply hexagonal filter to weights.

Source code in flyvision/task/decoder.py
def filter_to_hex(self):
    """Apply hexagonal filter to weights."""
    self.weight.data.mul_(self.mask.to(device))

forward

forward(x)

Forward pass of the Conv2dHexSpace layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input tensor. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Output tensor after hexagonal convolution. |

Source code in flyvision/task/decoder.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Forward pass of the Conv2dHexSpace layer.

    Args:
        x: Input tensor.

    Returns:
        Output tensor after hexagonal convolution.
    """
    if self._filter_to_hex:
        self.filter_to_hex()
    return super().forward(x)
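
A hedged sketch of how the hexagonal masking behaves (illustrative values; assumes the layer can be constructed with the library's default device handling):

import torch

from flyvision.task.decoder import Conv2dHexSpace

# kernel_size=5 covers a hexagon of radius 2 in axial (u, v) map storage;
# entries of the 5x5 kernel outside that hexagon stay zero because the
# weights are re-masked before every forward pass.
conv = Conv2dHexSpace(in_channels=1, out_channels=1, kernel_size=5, padding=2)
print((conv.weight[0, 0] != 0).int())  # hexagonal support within the 5x5 square

x = torch.rand(1, 1, 9, 9)  # e.g. a hex grid of extent 4 stored as a 9x9 map
y = conv(x)                 # spatial size preserved with padding=(kernel_size - 1) // 2
print(y.shape)              # torch.Size([1, 1, 9, 9])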

flyvision.task.decoder.Conv2dConstWeight

Bases: Conv2d

PyTorch’s Conv2d layer with optional constant weight initialization.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| in_channels | int | Number of input channels. | required |
| out_channels | int | Number of output channels. | required |
| kernel_size | int | Size of the convolutional kernel. | required |
| const_weight | Optional[float] | Optional constant value for weight initialization. If None, the standard PyTorch initialization is used. | None |
| stride | int | Stride of the convolution. | 1 |
| padding | int | Padding added to input. | 0 |
| **kwargs | | Additional keyword arguments for Conv2d. | {} |

Source code in flyvision/task/decoder.py
class Conv2dConstWeight(nn.Conv2d):
    """
    PyTorch's Conv2d layer with optional constant weight initialization.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Size of the convolutional kernel.
        const_weight: Optional constant value for weight initialization.
            If None, the standard PyTorch initialization is used.
        stride: Stride of the convolution.
        padding: Padding added to input.
        **kwargs: Additional keyword arguments for Conv2d.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        const_weight: Optional[float] = None,
        stride: int = 1,
        padding: int = 0,
        **kwargs,
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            **kwargs,
        )
        if const_weight is not None and self.weight is not None:
            self.weight.data.fill_(const_weight)
        if const_weight is not None and self.bias is not None:
            self.bias.data.fill_(const_weight)
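
A brief sketch of the constant initialization (illustrative values):

from flyvision.task.decoder import Conv2dConstWeight

# With const_weight set, every weight and bias entry is filled with that value;
# with const_weight=None the layer keeps PyTorch's default initialization.
conv = Conv2dConstWeight(3, 8, kernel_size=3, const_weight=1e-3, padding=1)
print(conv.weight.unique())  # tensor([0.0010])
print(conv.bias.unique())    # tensor([0.0010])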

flyvision.task.decoder.GlobalAvgPool

Bases: Module

Returns the average over the last dimension.

Source code in flyvision/task/decoder.py
class GlobalAvgPool(nn.Module):
    """Returns the average over the last dimension."""

    def forward(self, x):
        return x.mean(dim=-1)
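
A minimal example (illustrative shapes):

import torch

from flyvision.task.decoder import GlobalAvgPool

pool = GlobalAvgPool()
x = torch.rand(2, 3, 721)  # e.g. (n_frames, out_channels, n_hexals)
print(pool(x).shape)       # torch.Size([2, 3]): mean over the last dimension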