
Rendering

BoxEye

flyvision.datasets.rendering.eye.BoxEye

BoxFilter to produce an array of hexals matching the photoreceptor array.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| extent | int | Radius, in number of receptors, of the hexagonal array. | 15 |
| kernel_size | int | Photon collection radius, in pixels. | 13 |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| extent | int | Radius, in number of receptors, of the hexagonal array. |
| kernel_size | int | Photon collection radius, in pixels. |
| receptor_centers | Tensor | Tensor of shape (hexals, 2) containing the y, x coordinates of the hexal centers. |
| hexals | int | Number of hexals in the array. |
| min_frame_size | Tensor | Minimum frame size to contain the hexal array. |
| pad | Tuple[int, int, int, int] | Padding to apply to the frame before convolution. |
| conv | Conv2d | Convolutional box filter to apply to the frame. |
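
Example

A minimal usage sketch with illustrative values (assumes torch and flyvision are installed; smaller frames are resized to min_frame_size automatically):

```python
import torch

from flyvision.datasets.rendering.eye import BoxEye

# 2 samples of 10 grayscale frames; 436x436 is an arbitrary example size.
sequence = torch.rand(2, 10, 436, 436)

box_eye = BoxEye(extent=15, kernel_size=13)
hexals = box_eye(sequence, ftype="mean", hex_sample=True)
print(hexals.shape)  # torch.Size([2, 10, 1, 721]); extent=15 gives 3*15*16 + 1 = 721 hexals
```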

Source code in flyvision/datasets/rendering/eye.py
class BoxEye:
    """BoxFilter to produce an array of hexals matching the photoreceptor array.

    Args:
        extent: Radius, in number of receptors, of the hexagonal array.
        kernel_size: Photon collection radius, in pixels.

    Attributes:
        extent (int): Radius, in number of receptors, of the hexagonal array.
        kernel_size (int): Photon collection radius, in pixels.
        receptor_centers (torch.Tensor): Tensor of shape (hexals, 2) containing the y, x
            coordinates of the hexal centers.
        hexals (int): Number of hexals in the array.
        min_frame_size (torch.Tensor): Minimum frame size to contain the hexal array.
        pad (Tuple[int, int, int, int]): Padding to apply to the frame before convolution.
        conv (nn.Conv2d): Convolutional box filter to apply to the frame.
    """

    def __init__(self, extent: int = 15, kernel_size: int = 13):
        self.extent = extent
        self.kernel_size = kernel_size
        self.receptor_centers = torch.tensor(
            [*self._receptor_centers()], dtype=torch.long
        )
        self.hexals = len(self.receptor_centers)
        # The rest of kernel_size distance from outer centers to the border
        # is taken care of by the padding of the convolution object.
        self.min_frame_size = (
            self.receptor_centers.max(dim=0).values
            - self.receptor_centers.min(dim=0).values
            + 1
        )
        self._set_filter()

        pad = (self.kernel_size - 1) / 2
        self.pad = (
            int(np.ceil(pad)),
            int(np.floor(pad)),
            int(np.ceil(pad)),
            int(np.floor(pad)),
        )

    def _receptor_centers(self) -> Iterator[Tuple[float, float]]:
        """Generate receptor center coordinates.

        Returns:
            Iterator[Tuple[float, float]]: Yields y, x coordinates of receptor centers.
        """
        n = self.extent
        d = self.kernel_size
        for u in range(-n, n + 1):
            v_min = max(-n, -n - u)
            v_max = min(n, n - u)
            for v in range(v_min, v_max + 1):
                # Origin is in the upper left; y runs along u + v / 2 and x
                # along v, scaled by the kernel size.
                y = d * (u + v / 2)
                x = d * v
                yield y, x

    def _set_filter(self) -> None:
        """Set up the convolutional filter for the box kernel."""
        self.conv = nn.Conv2d(1, 1, kernel_size=self.kernel_size, stride=1, padding=0)
        # Box filter: constant weights of one, zero bias, no gradients.
        self.conv.weight.data.fill_(1)
        self.conv.bias.data.fill_(0)
        self.conv.weight.requires_grad = False
        self.conv.bias.requires_grad = False

    def __call__(
        self,
        sequence: torch.Tensor,
        ftype: Literal["mean", "sum", "median"] = "mean",
        hex_sample: bool = True,
    ) -> torch.Tensor:
        """Apply a box kernel to all frames in a sequence.

        Args:
            sequence: Cartesian movie sequences of shape (samples, frames, height, width).
            ftype: Filter type.
            hex_sample: If False, returns filtered cartesian sequences.

        Returns:
            torch.Tensor: Shape (samples, frames, 1, hexals) if hex_sample is True,
                otherwise (samples, frames, height, width).
        """
        samples, frames, height, width = sequence.shape

        if not isinstance(sequence, torch.Tensor):
            # Convert array input to a tensor on the default device for
            # convenience, e.g. when the default device is CUDA but the
            # passed sequence is a numpy array.
            sequence = torch.tensor(
                sequence, dtype=torch.float32, device=flyvision.device
            )

        if (self.min_frame_size > torch.tensor([height, width])).any():
            # to rescale to the minimum frame size
            sequence = ttf.resize(sequence, self.min_frame_size.tolist())
            height, width = sequence.shape[2:]

        def _convolve():
            # convolve each sample sequentially to avoid GPU memory issues
            def conv(x):
                return self.conv(x.unsqueeze(1))

            return torch.cat(
                tuple(map(conv, torch.unbind(F.pad(sequence, self.pad), dim=0))), dim=0
            )

        if ftype == "mean":
            out = _convolve() / self.kernel_size**2
        elif ftype == "sum":
            out = _convolve()
        elif ftype == "median":
            out = median(sequence, self.kernel_size)
        else:
            raise ValueError("ftype must be 'sum', 'mean', or 'median." f"Is {ftype}.")

        if hex_sample is True:
            return self.hex_render(out).reshape(samples, frames, 1, -1)

        return out.reshape(samples, frames, height, width)

    def hex_render(self, sequence: torch.Tensor) -> torch.Tensor:
        """Sample receptor locations from a sequence of cartesian frames.

        Args:
            sequence: Cartesian movie sequences of shape (samples, frames, height, width).

        Returns:
            torch.Tensor: Shape (samples, frames, 1, hexals).

        Note:
            Resizes the sequence to the minimum frame size if necessary.
        """
        h, w = sequence.shape[2:]
        if (self.min_frame_size > torch.tensor([h, w])).any():
            sequence = ttf.resize(sequence, self.min_frame_size.tolist())
            h, w = sequence.shape[2:]
        c = self.receptor_centers + torch.tensor([h // 2, w // 2])
        out = sequence[:, :, c[:, 0], c[:, 1]]
        return out.view(*sequence.shape[:2], 1, -1)

    def illustrate(self) -> plt.Figure:
        """Illustrate the receptive field centers and the hexagonal sampling.

        Returns:
            plt.Figure: Matplotlib figure object.
        """
        figsize = [2, 2]
        fontsize = 5
        y_hc, x_hc = np.array(list(self._receptor_centers())).T

        height, width = self.min_frame_size.cpu().numpy()
        x_img, y_img = np.array(
            list(
                product(
                    np.arange(-width / 2, width / 2),
                    np.arange(-height / 2, height / 2),
                )
            )
        ).T

        r = np.sqrt(2) * self.kernel_size / 2

        vertices = []
        angles = [45, 135, 225, 315, 405]
        for _y_c, _x_c in zip(y_hc, x_hc):
            _vertices = []
            for angle in angles:
                offset = r * np.exp(np.radians(angle) * 1j)
                _vertices.append([_y_c + offset.real, _x_c + offset.imag])
            vertices.append(_vertices)
        vertices = np.transpose(vertices, (1, 2, 0))

        fig, ax = init_plot(figsize=figsize, fontsize=fontsize)
        ax.scatter(x_hc, y_hc, color="#00008B", zorder=1, s=0.5)
        ax.scatter(x_img, y_img, color="#34ebd6", s=0.1, zorder=0)

        for h in range(len(x_hc)):
            for i in range(4):
                y1, x1 = vertices[i, :, h]  # x1, y1: (n_hexagons)
                y2, x2 = vertices[i + 1, :, h]
                ax.plot([x1, x2], [y1, y2], c="black", lw=0.25)

        ax.set_xlim(-width / 2, width / 2)
        ax.set_ylim(-height / 2, height / 2)
        rm_spines(ax)
        # fig.tight_layout()
        return fig

__call__

__call__(sequence, ftype='mean', hex_sample=True)

Apply a box kernel to all frames in a sequence.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| sequence | Tensor | Cartesian movie sequences of shape (samples, frames, height, width). | required |
| ftype | Literal['mean', 'sum', 'median'] | Filter type. | 'mean' |
| hex_sample | bool | If False, returns filtered cartesian sequences. | True |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Shape (samples, frames, 1, hexals) if hex_sample is True, otherwise (samples, frames, height, width). |
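
Example

A sketch of the two output modes:

```python
import torch

from flyvision.datasets.rendering.eye import BoxEye

box_eye = BoxEye()
frames = torch.rand(1, 5, 436, 436)

hex_out = box_eye(frames, ftype="mean")                     # (1, 5, 1, hexals)
cart_out = box_eye(frames, ftype="mean", hex_sample=False)  # (1, 5, 436, 436)
```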

Source code in flyvision/datasets/rendering/eye.py
def __call__(
    self,
    sequence: torch.Tensor,
    ftype: Literal["mean", "sum", "median"] = "mean",
    hex_sample: bool = True,
) -> torch.Tensor:
    """Apply a box kernel to all frames in a sequence.

    Args:
        sequence: Cartesian movie sequences of shape (samples, frames, height, width).
        ftype: Filter type.
        hex_sample: If False, returns filtered cartesian sequences.

    Returns:
        torch.Tensor: Shape (samples, frames, 1, hexals) if hex_sample is True,
            otherwise (samples, frames, height, width).
    """
    samples, frames, height, width = sequence.shape

    if not isinstance(sequence, torch.Tensor):
        # Convert array input to a tensor on the default device for
        # convenience, e.g. when the default device is CUDA but the
        # passed sequence is a numpy array.
        sequence = torch.tensor(
            sequence, dtype=torch.float32, device=flyvision.device
        )

    if (self.min_frame_size > torch.tensor([height, width])).any():
        # to rescale to the minimum frame size
        sequence = ttf.resize(sequence, self.min_frame_size.tolist())
        height, width = sequence.shape[2:]

    def _convolve():
        # convolve each sample sequentially to avoid GPU memory issues
        def conv(x):
            return self.conv(x.unsqueeze(1))

        return torch.cat(
            tuple(map(conv, torch.unbind(F.pad(sequence, self.pad), dim=0))), dim=0
        )

    if ftype == "mean":
        out = _convolve() / self.kernel_size**2
    elif ftype == "sum":
        out = _convolve()
    elif ftype == "median":
        out = median(sequence, self.kernel_size)
    else:
        raise ValueError("ftype must be 'sum', 'mean', or 'median." f"Is {ftype}.")

    if hex_sample is True:
        return self.hex_render(out).reshape(samples, frames, 1, -1)

    return out.reshape(samples, frames, height, width)

hex_render

hex_render(sequence)

Sample receptor locations from a sequence of cartesian frames.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| sequence | Tensor | Cartesian movie sequences of shape (samples, frames, height, width). | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Shape (samples, frames, 1, hexals). |

Note

Resizes the sequence to the minimum frame size if necessary.
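
Example

A sketch of sampling hexals directly from cartesian frames (frames smaller than min_frame_size are resized first):

```python
import torch

from flyvision.datasets.rendering.eye import BoxEye

box_eye = BoxEye()
frames = torch.rand(1, 5, *box_eye.min_frame_size.tolist())
hexals = box_eye.hex_render(frames)  # (1, 5, 1, box_eye.hexals)
```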

Source code in flyvision/datasets/rendering/eye.py
def hex_render(self, sequence: torch.Tensor) -> torch.Tensor:
    """Sample receptor locations from a sequence of cartesian frames.

    Args:
        sequence: Cartesian movie sequences of shape (samples, frames, height, width).

    Returns:
        torch.Tensor: Shape (samples, frames, 1, hexals).

    Note:
        Resizes the sequence to the minimum frame size if necessary.
    """
    h, w = sequence.shape[2:]
    if (self.min_frame_size > torch.tensor([h, w])).any():
        sequence = ttf.resize(sequence, self.min_frame_size.tolist())
        h, w = sequence.shape[2:]
    c = self.receptor_centers + torch.tensor([h // 2, w // 2])
    out = sequence[:, :, c[:, 0], c[:, 1]]
    return out.view(*sequence.shape[:2], 1, -1)

illustrate

illustrate()

Illustrate the receptive field centers and the hexagonal sampling.

Returns:

| Type | Description |
| --- | --- |
| Figure | Matplotlib figure object. |

Source code in flyvision/datasets/rendering/eye.py
def illustrate(self) -> plt.Figure:
    """Illustrate the receptive field centers and the hexagonal sampling.

    Returns:
        plt.Figure: Matplotlib figure object.
    """
    figsize = [2, 2]
    fontsize = 5
    y_hc, x_hc = np.array(list(self._receptor_centers())).T

    height, width = self.min_frame_size.cpu().numpy()
    x_img, y_img = np.array(
        list(
            product(
                np.arange(-width / 2, width / 2),
                np.arange(-height / 2, height / 2),
            )
        )
    ).T

    r = np.sqrt(2) * self.kernel_size / 2

    vertices = []
    angles = [45, 135, 225, 315, 405]
    for _y_c, _x_c in zip(y_hc, x_hc):
        _vertices = []
        for angle in angles:
            offset = r * np.exp(np.radians(angle) * 1j)
            _vertices.append([_y_c + offset.real, _x_c + offset.imag])
        vertices.append(_vertices)
    vertices = np.transpose(vertices, (1, 2, 0))

    fig, ax = init_plot(figsize=figsize, fontsize=fontsize)
    ax.scatter(x_hc, y_hc, color="#00008B", zorder=1, s=0.5)
    ax.scatter(x_img, y_img, color="#34ebd6", s=0.1, zorder=0)

    for h in range(len(x_hc)):
        for i in range(4):
            y1, x1 = vertices[i, :, h]  # x1, y1: (n_hexagons)
            y2, x2 = vertices[i + 1, :, h]
            ax.plot([x1, x2], [y1, y2], c="black", lw=0.25)

    ax.set_xlim(-width / 2, width / 2)
    ax.set_ylim(-height / 2, height / 2)
    rm_spines(ax)
    # fig.tight_layout()
    return fig

HexEye

flyvision.datasets.rendering.eye.HexEye

Hexagonal eye model for more precise rendering.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_ommatidia | int | Number of ommatidia in the eye. Must currently fill a regular hex grid. | 721 |
| ppo | int | Pixels per ommatidium. | 25 |
| monitor_height_px | Optional[int] | Monitor height in pixels. | None |
| monitor_width_px | Optional[int] | Monitor width in pixels. | None |
| device | device | Computation device. | flyvision.device |
| dtype | dtype | Data type for computations. | float16 |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| monitor_width_px | int | Monitor width in pixels. |
| monitor_height_px | int | Monitor height in pixels. |
| is_inside | Tensor | Boolean mask for pixels inside hexagons. |
| n_ommatidia | int | Number of ommatidia in the eye. |
| omm_width_rad | float | Ommatidium width in radians. |
| omm_height_rad | float | Ommatidium height in radians. |
| ppo | int | Pixels per ommatidium. |
| n_hex_circfer | float | Number of hexagons along the circumference. |
| device | device | Computation device. |
| dtype | dtype | Data type for computations. |
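
Example

A minimal sketch; n_ommatidia=91 fills a hex grid of radius 5 (3 * 5 * 6 + 1 = 91), which keeps the precomputed is_inside mask small:

```python
import torch

from flyvision.datasets.rendering.eye import HexEye

hex_eye = HexEye(n_ommatidia=91, ppo=25)

# 4 frames at monitor resolution, on the eye's device.
stim = torch.rand(
    4, hex_eye.monitor_height_px, hex_eye.monitor_width_px, device=hex_eye.device
)
ommatidia = hex_eye(stim, mode="mean")  # (4, 91)
```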

Source code in flyvision/datasets/rendering/eye.py
class HexEye:
    """Hexagonal eye model for more precise rendering.

    Args:
        n_ommatidia: Number of ommatidia in the eye. Must currently fill a regular
            hex grid.
        ppo: Pixels per ommatidium.
        monitor_height_px: Monitor height in pixels.
        monitor_width_px: Monitor width in pixels.
        device: Computation device.
        dtype: Data type for computations.

    Attributes:
        monitor_width_px (int): Monitor width in pixels.
        monitor_height_px (int): Monitor height in pixels.
        is_inside (torch.Tensor): Boolean mask for pixels inside hexagons.
        n_ommatidia (int): Number of ommatidia in the eye.
        omm_width_rad (float): Ommatidium width in radians.
        omm_height_rad (float): Ommatidium height in radians.
        ppo (int): Pixels per ommatidium.
        n_hex_circfer (float): Number of hexagons in the circumference.
        device (torch.device): Computation device.
        dtype (torch.dtype): Data type for computations.
    """

    def __init__(
        self,
        n_ommatidia: int = 721,
        ppo: int = 25,
        monitor_height_px: Optional[int] = None,
        monitor_width_px: Optional[int] = None,
        device: torch.device = flyvision.device,
        dtype: torch.dtype = torch.float16,
    ):
        n_hex_circfer = 2 * (-1 / 2 + np.sqrt(1 / 4 - ((1 - n_ommatidia) / 3))) + 1

        if n_hex_circfer % 1 != 0:
            raise ValueError(f"{n_ommatidia} does not fill a regular hex grid.")

        self.monitor_width_px = monitor_width_px or ppo * int(n_hex_circfer)
        self.monitor_height_px = monitor_height_px or ppo * int(n_hex_circfer)

        x_hc, y_hc, (dist_w, dist_h) = hex_center_coordinates(
            n_ommatidia, self.monitor_width_px, self.monitor_height_px
        )

        x_img, y_img = np.array(
            list(
                product(
                    np.arange(self.monitor_width_px),
                    np.arange(self.monitor_height_px),
                )
            )
        ).T

        dist_to_edge = (dist_w + dist_h) / 4

        _, self.is_inside = is_inside_hex(
            torch.tensor(y_img, dtype=dtype, device=device),
            torch.tensor(x_img, dtype=dtype, device=device),
            torch.tensor(x_hc, dtype=dtype, device=device),
            torch.tensor(y_hc, dtype=dtype, device=device),
            torch.tensor(dist_to_edge, dtype=dtype, device=device),
            torch.tensor(np.radians(0), dtype=dtype, device=device),
            device=device,
            dtype=dtype,
        )
        self.kernel_sum = self.is_inside.sum(dim=0)
        # Clean up excessive memory usage.
        if device != "cpu":
            torch.cuda.empty_cache()
        self.n_ommatidia = n_ommatidia
        self.omm_width_rad = np.radians(5.8)
        self.omm_height_rad = np.radians(5.8)
        self.ppo = ppo or int(
            (self.monitor_width_px + self.monitor_height_px) / (2 * n_hex_circfer)
        )
        self.n_hex_circfer = n_hex_circfer
        self.device = device
        self.dtype = dtype

    def __call__(
        self,
        stim: torch.Tensor,
        mode: Literal["mean", "median", "sum"] = "mean",
        n_chunks: int = 1,
    ) -> torch.Tensor:
        """Process stimulus through the hexagonal eye model.

        Args:
            stim: Input stimulus tensor (n_frames, height, width) or
                (n_frames, height * width). Height and width must correspond to the
                monitor size.
            mode: Processing mode.
            n_chunks: Number of chunks to process the stimulus in.

        Returns:
            torch.Tensor: Processed stimulus.
        """
        shape = stim.shape
        if mode not in ["mean", "median", "sum"]:
            raise ValueError(f"mode must be 'mean', 'median', or 'sum'. Is {mode}.")

        if len(stim.shape) == 3:
            h, w = stim.shape[1:]
            # resize to monitor size if necessary
            if h < self.monitor_height_px or w < self.monitor_width_px:
                stim = ttf.resize(stim, [self.monitor_height_px, self.monitor_width_px])
            stim = stim.reshape(shape[0], -1)

        try:
            if mode == "median":
                # Median-filter in cartesian space, then average within each
                # ommatidium; without this reduction the median branch would
                # fall through and return None.
                n_pixels = stim.shape[1]
                stim = median(
                    stim.view(-1, self.monitor_height_px, self.monitor_width_px)
                    .float()
                    .to(self.device),
                    int(np.sqrt(self.ppo)),
                ).view(-1, n_pixels)
                return (stim[:, :, None] * self.is_inside).sum(
                    dim=1
                ) / self.kernel_sum
            elif mode == "sum":
                return (stim[:, :, None] * self.is_inside).sum(dim=1)
            elif mode == "mean":
                return (stim[:, :, None] * self.is_inside).sum(dim=1) / self.kernel_sum
            else:
                raise ValueError(f"Invalid mode: {mode}")

        except RuntimeError as e:
            if "memory" not in str(e):
                raise e
            if "CUDA" in str(e):
                torch.cuda.empty_cache()
            if n_chunks > shape[0]:
                raise ValueError from e

            chunks = torch.chunk(stim, max(n_chunks, 1), dim=0)

            def map_fn(chunk):
                return self(chunk, mode=mode, n_chunks=n_chunks + 1)

            return torch.cat(tuple(map(map_fn, chunks)), dim=0)

    # Rendering functions
    # TODO: move to separate file and make agnostic of the eye model

    def render_bar(
        self,
        bar_width_rad: float,
        bar_height_rad: float,
        bar_loc_theta: float,
        bar_loc_phi: float,
        n_bars: int,
        bar_intensity: float,
        bg_intensity: float,
        moving_angle: float,
        cartesian: bool = False,
        mode: Literal["mean", "median", "sum"] = "mean",
    ) -> Union[np.ndarray, torch.Tensor]:
        """Render a bar stimulus.

        Args:
            bar_width_rad: Width of bars in radians.
            bar_height_rad: Height of bars in radians.
            bar_loc_theta: Horizontal location of bars in radians.
            bar_loc_phi: Vertical location of bars in radians.
            n_bars: Number of bars.
            bar_intensity: Intensity of the bar.
            bg_intensity: Intensity of the background.
            moving_angle: Rotation angle in degrees.
            cartesian: If True, return cartesian coordinates.
            mode: Processing mode.

        Returns:
            Union[np.ndarray, torch.Tensor]: Generated bar stimulus.
        """
        bar_width_px = int(bar_width_rad / self.omm_width_rad * self.ppo)
        bar_height_px = int(bar_height_rad / self.omm_height_rad * self.ppo)
        bar_loc_horizontal_px = int(
            self.monitor_width_px * bar_loc_theta / np.radians(180)
        )
        bar_loc_vertical_px = int(self.monitor_height_px * bar_loc_phi / np.radians(180))

        bar = render_bars_cartesian(
            self.monitor_height_px,
            self.monitor_width_px,
            bar_width_px,
            bar_height_px,
            bar_loc_horizontal_px,
            bar_loc_vertical_px,
            n_bars,
            bar_intensity,
            bg_intensity,
            moving_angle,
        )
        if cartesian:
            return bar
        return self(torch.tensor(bar.flatten(), device=self.device)[None], mode)

    def render_grating(
        self,
        period_rad: float,
        phase_rad: float,
        intensity: float,
        bg_intensity: float,
        moving_angle: float,
        width_rad: Optional[float] = None,
        height_rad: Optional[float] = None,
        cartesian: bool = False,
        mode: Literal["mean", "median", "sum"] = "mean",
    ) -> Union[np.ndarray, torch.Tensor]:
        """Render a grating stimulus.

        Args:
            period_rad: Period of the grating in radians.
            phase_rad: Phase of the grating in radians.
            intensity: Intensity of the grating.
            bg_intensity: Intensity of the background.
            moving_angle: Rotation angle in degrees.
            width_rad: Width of the grating in radians.
            height_rad: Height of the grating in radians.
            cartesian: If True, return cartesian coordinates.
            mode: Processing mode.

        Returns:
            Union[np.ndarray, torch.Tensor]: Generated grating stimulus.
        """
        period_px = int(period_rad / self.omm_width_rad * self.ppo)
        phase_px = int(phase_rad / self.omm_width_rad * self.ppo)

        height_rad_px = None
        if height_rad:
            height_rad_px = int(height_rad / self.omm_height_rad * self.ppo)

        width_rad_px = None
        if width_rad:
            width_rad_px = int(width_rad / self.omm_width_rad * self.ppo)

        grating = render_gratings_cartesian(
            self.monitor_height_px,
            self.monitor_width_px,
            period_px,
            intensity,
            bg_intensity,
            grating_phase_px=phase_px,
            rotate=moving_angle,
            grating_height_px=height_rad_px,
            grating_width_px=width_rad_px,
        )
        if cartesian:
            return grating
        return self(torch.tensor(grating.flatten(), device=self.device)[None], mode)

    def render_grating_offsets(
        self,
        period_rad: float,
        intensity: float,
        bg_intensity: float,
        moving_angle: float,
        width_rad: Optional[float] = None,
        height_rad: Optional[float] = None,
        cartesian: bool = False,
        mode: Literal["mean", "median", "sum"] = "mean",
    ) -> Union[np.ndarray, torch.Tensor]:
        """Render grating stimuli with a range of offsets.

        Args:
            period_rad: Period of the grating in radians.
            intensity: Intensity of the grating.
            bg_intensity: Intensity of the background.
            moving_angle: Rotation angle in degrees.
            width_rad: Width of the grating in radians.
            height_rad: Height of the grating in radians.
            cartesian: If True, return cartesian coordinates.
            mode: Processing mode.

        Returns:
            Union[np.ndarray, torch.Tensor]: Generated grating stimuli with offsets.
        """
        # Phase step of half an ommatidium width (2.9 degrees, in radians despite
        # the _px suffix); corresponds to an LED width of 2.25 degrees.
        dphase_px = np.radians(5.8 / 2)
        n_offsets = np.ceil(period_rad / dphase_px).astype(int)
        gratings = []
        for offset in range(n_offsets):
            gratings.append(
                self.render_grating(
                    period_rad,
                    offset * dphase_px,
                    intensity,
                    bg_intensity,
                    moving_angle,
                    width_rad=width_rad,
                    height_rad=height_rad,
                    cartesian=cartesian,
                    mode=mode,
                )
            )
        if cartesian:
            return np.array(gratings)
        return torch.cat(gratings, dim=0)

    def render_offset_bars(
        self,
        bar_width_rad: float,
        bar_height_rad: float,
        n_bars: int,
        offsets: List[float],
        bar_intensity: float,
        bg_intensity: float,
        moving_angle: float,
        bar_loc_horizontal: float = np.radians(90),
        bar_loc_vertical: float = np.radians(90),
        mode: Literal["mean", "median", "sum"] = "mean",
    ) -> torch.Tensor:
        """Render bars with a range of offsets.

        Args:
            bar_width_rad: Width of bars in radians.
            bar_height_rad: Height of bars in radians.
            n_bars: Number of bars.
            offsets: Offsets of bars wrt. the center in radians.
            bar_intensity: Intensity of the bar.
            bg_intensity: Intensity of the background.
            moving_angle: Rotation angle in degrees.
            bar_loc_horizontal: Horizontal location of bars in radians.
            bar_loc_vertical: Vertical location of bars in radians.
            mode: Processing mode.

        Returns:
            torch.Tensor: Generated offset bars.
        """
        flashes = []
        for offset in offsets:
            flashes.append(
                self.render_bar(
                    bar_width_rad,
                    bar_height_rad,
                    bar_loc_horizontal + offset,
                    bar_loc_vertical,
                    n_bars,
                    bar_intensity,
                    bg_intensity,
                    moving_angle,
                    mode=mode,
                )
            )
        return torch.cat(flashes, dim=0)

    def render_bar_movie(
        self,
        t_stim: float,
        dt: float,
        bar_width_rad: float,
        bar_height_rad: float,
        n_bars: int,
        offsets: List[float],
        bar_intensity: float,
        bg_intensity: float,
        moving_angle: float,
        t_pre: float = 0.0,
        t_between: float = 0.0,
        t_post: float = 0.0,
        bar_loc_horizontal: float = np.radians(90),
        bar_loc_vertical: float = np.radians(90),
    ) -> torch.Tensor:
        """Render moving bars.

        Args:
            t_stim: Stimulus duration.
            dt: Temporal resolution.
            bar_width_rad: Width of bars in radians.
            bar_height_rad: Height of bars in radians.
            n_bars: Number of bars.
            offsets: Offsets of bars wrt. the center in radians.
            bar_intensity: Intensity of the bar.
            bg_intensity: Intensity of the background.
            moving_angle: Rotation angle in degrees.
            t_pre: Grey pre stimulus duration.
            t_between: Grey between offset stimulus duration.
            t_post: Grey post stimulus duration.
            bar_loc_horizontal: Horizontal location of bars in radians.
            bar_loc_vertical: Vertical location of bars in radians.

        Returns:
            torch.Tensor: Generated moving bars.
        """
        pre_frames = round(t_pre / dt)
        stim_frames = round(t_stim / (len(offsets) * dt))
        if stim_frames == 0:
            raise ValueError(
                f"stimulus time {t_stim}s not sufficient to sample {len(offsets)} "
                f"offsets at {dt}s"
            )
        between_frames = round(t_between / dt)
        post_frames = round(t_post / dt)

        flashes = []
        if pre_frames:
            flashes.append(torch.ones([pre_frames, self.n_ommatidia]) * bg_intensity)

        for i, offset in enumerate(offsets):
            flash = self.render_bar(
                bar_width_rad,
                bar_height_rad,
                bar_loc_horizontal + offset,
                bar_loc_vertical,
                n_bars,
                bar_intensity,
                bg_intensity,
                moving_angle,
            )
            flashes.append(flash.repeat(stim_frames, 1))

            if between_frames and i < len(offsets) - 1:
                flashes.append(
                    torch.ones([between_frames, self.n_ommatidia]) * bg_intensity
                )
        if post_frames:
            flashes.append(torch.ones([post_frames, self.n_ommatidia]) * bg_intensity)
        return torch.cat(flashes, dim=0)

    def illustrate(
        self, figsize: List[int] = [5, 5], fontsize: int = 5
    ) -> Tuple[plt.Figure, plt.Axes]:
        """Illustrate the hexagonal eye model.

        Args:
            figsize: Figure size.
            fontsize: Font size for the plot.

        Returns:
            Tuple[plt.Figure, plt.Axes]: Matplotlib figure and axes objects.
        """
        x_hc, y_hc, (dist_w, dist_h) = hex_center_coordinates(
            self.n_ommatidia, self.monitor_width_px, self.monitor_height_px
        )

        x_img, y_img = np.array(
            list(
                product(
                    np.arange(self.monitor_width_px),
                    np.arange(
                        self.monitor_height_px,
                    ),
                )
            )
        ).T

        dist_to_edge = (dist_w + dist_h) / 4

        vertices, _ = is_inside_hex(
            torch.tensor(y_img, dtype=self.dtype),
            torch.tensor(x_img, dtype=self.dtype),
            torch.tensor(x_hc, dtype=self.dtype),
            torch.tensor(y_hc, dtype=self.dtype),
            torch.tensor(dist_to_edge, dtype=self.dtype),
            torch.tensor(np.radians(0), dtype=self.dtype),
        )
        vertices = vertices.cpu()
        fig, ax = init_plot(figsize=figsize, fontsize=fontsize)
        ax.scatter(x_hc, y_hc, color="#eb4034", zorder=1)
        ax.scatter(x_img, y_img, color="#34ebd6", s=0.5, zorder=0)

        for h in range(self.n_ommatidia):
            for i in range(6):
                x1, y1 = vertices[i, :, h]  # x1, y1: (n_hexagons)
                x2, y2 = vertices[i + 1, :, h]
                ax.plot([x1, x2], [y1, y2], c="black")

        ax.set_xlim(0, self.monitor_width_px)
        ax.set_ylim(0, self.monitor_height_px)
        rm_spines(ax)
        # fig.tight_layout()
        return fig, ax

__call__

__call__(stim, mode='mean', n_chunks=1)

Process stimulus through the hexagonal eye model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| stim | Tensor | Input stimulus tensor of shape (n_frames, height, width) or (n_frames, height * width). Height and width must correspond to the monitor size. | required |
| mode | Literal['mean', 'median', 'sum'] | Processing mode. | 'mean' |
| n_chunks | int | Number of chunks to process the stimulus in. | 1 |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Processed stimulus. |
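
Example

A sketch of the flat-input path (the frames must match the monitor resolution):

```python
import torch

from flyvision.datasets.rendering.eye import HexEye

hex_eye = HexEye(n_ommatidia=91, ppo=25)
n_px = hex_eye.monitor_height_px * hex_eye.monitor_width_px

stim = torch.rand(2, n_px, device=hex_eye.device)  # flat (n_frames, height * width)
summed = hex_eye(stim, mode="sum")     # (2, 91)
averaged = hex_eye(stim, mode="mean")  # (2, 91), per-ommatidium average
```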

Source code in flyvision/datasets/rendering/eye.py
def __call__(
    self,
    stim: torch.Tensor,
    mode: Literal["mean", "median", "sum"] = "mean",
    n_chunks: int = 1,
) -> torch.Tensor:
    """Process stimulus through the hexagonal eye model.

    Args:
        stim: Input stimulus tensor (n_frames, height, width) or
            (n_frames, height * width). Height and width must correspond to the
            monitor size.
        mode: Processing mode.
        n_chunks: Number of chunks to process the stimulus in.

    Returns:
        torch.Tensor: Processed stimulus.
    """
    shape = stim.shape
    if mode not in ["mean", "median", "sum"]:
        raise ValueError(f"mode must be 'mean', 'median', or 'sum'. Is {mode}.")

    if len(stim.shape) == 3:
        h, w = stim.shape[1:]
        # resize to monitor size if necessary
        if h < self.monitor_height_px or w < self.monitor_width_px:
            stim = ttf.resize(stim, [self.monitor_height_px, self.monitor_width_px])
        stim = stim.reshape(shape[0], -1)

    try:
        if mode == "median":
            # Median-filter in cartesian space, then average within each
            # ommatidium; without this reduction the median branch would
            # fall through and return None.
            n_pixels = stim.shape[1]
            stim = median(
                stim.view(-1, self.monitor_height_px, self.monitor_width_px)
                .float()
                .to(self.device),
                int(np.sqrt(self.ppo)),
            ).view(-1, n_pixels)
            return (stim[:, :, None] * self.is_inside).sum(
                dim=1
            ) / self.kernel_sum
        elif mode == "sum":
            return (stim[:, :, None] * self.is_inside).sum(dim=1)
        elif mode == "mean":
            return (stim[:, :, None] * self.is_inside).sum(dim=1) / self.kernel_sum
        else:
            raise ValueError(f"Invalid mode: {mode}")

    except RuntimeError as e:
        if "memory" not in str(e):
            raise e
        if "CUDA" in str(e):
            torch.cuda.empty_cache()
        if n_chunks > shape[0]:
            raise ValueError from e

        chunks = torch.chunk(stim, max(n_chunks, 1), dim=0)

        def map_fn(chunk):
            return self(chunk, mode=mode, n_chunks=n_chunks + 1)

        return torch.cat(tuple(map(map_fn, chunks)), dim=0)

render_bar

render_bar(
    bar_width_rad,
    bar_height_rad,
    bar_loc_theta,
    bar_loc_phi,
    n_bars,
    bar_intensity,
    bg_intensity,
    moving_angle,
    cartesian=False,
    mode="mean",
)

Render a bar stimulus.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| bar_width_rad | float | Width of bars in radians. | required |
| bar_height_rad | float | Height of bars in radians. | required |
| bar_loc_theta | float | Horizontal location of bars in radians. | required |
| bar_loc_phi | float | Vertical location of bars in radians. | required |
| n_bars | int | Number of bars. | required |
| bar_intensity | float | Intensity of the bar. | required |
| bg_intensity | float | Intensity of the background. | required |
| moving_angle | float | Rotation angle in degrees. | required |
| cartesian | bool | If True, return cartesian coordinates. | False |
| mode | Literal['mean', 'median', 'sum'] | Processing mode. | 'mean' |

Returns:

| Type | Description |
| --- | --- |
| Union[ndarray, Tensor] | Generated bar stimulus. |
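
Example

A sketch rendering a single centered bar (illustrative values; one ommatidium subtends 5.8 degrees):

```python
import numpy as np

from flyvision.datasets.rendering.eye import HexEye

hex_eye = HexEye(n_ommatidia=91, ppo=25)
bar = hex_eye.render_bar(
    bar_width_rad=np.radians(5.8),  # one ommatidium wide
    bar_height_rad=np.radians(60),
    bar_loc_theta=np.radians(90),   # horizontally centered
    bar_loc_phi=np.radians(90),     # vertically centered
    n_bars=1,
    bar_intensity=1.0,
    bg_intensity=0.5,
    moving_angle=0.0,
)
print(bar.shape)  # torch.Size([1, 91])
```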

Source code in flyvision/datasets/rendering/eye.py
def render_bar(
    self,
    bar_width_rad: float,
    bar_height_rad: float,
    bar_loc_theta: float,
    bar_loc_phi: float,
    n_bars: int,
    bar_intensity: float,
    bg_intensity: float,
    moving_angle: float,
    cartesian: bool = False,
    mode: Literal["mean", "median", "sum"] = "mean",
) -> Union[np.ndarray, torch.Tensor]:
    """Render a bar stimulus.

    Args:
        bar_width_rad: Width of bars in radians.
        bar_height_rad: Height of bars in radians.
        bar_loc_theta: Horizontal location of bars in radians.
        bar_loc_phi: Vertical location of bars in radians.
        n_bars: Number of bars.
        bar_intensity: Intensity of the bar.
        bg_intensity: Intensity of the background.
        moving_angle: Rotation angle in degrees.
        cartesian: If True, return cartesian coordinates.
        mode: Processing mode.

    Returns:
        Union[np.ndarray, torch.Tensor]: Generated bar stimulus.
    """
    bar_width_px = int(bar_width_rad / self.omm_width_rad * self.ppo)
    bar_height_px = int(bar_height_rad / self.omm_height_rad * self.ppo)
    bar_loc_horizontal_px = int(
        self.monitor_width_px * bar_loc_theta / np.radians(180)
    )
    bar_loc_vertical_px = int(self.monitor_height_px * bar_loc_phi / np.radians(180))

    bar = render_bars_cartesian(
        self.monitor_height_px,
        self.monitor_width_px,
        bar_width_px,
        bar_height_px,
        bar_loc_horizontal_px,
        bar_loc_vertical_px,
        n_bars,
        bar_intensity,
        bg_intensity,
        moving_angle,
    )
    if cartesian:
        return bar
    return self(torch.tensor(bar.flatten(), device=self.device)[None], mode)

render_grating

render_grating(
    period_rad,
    phase_rad,
    intensity,
    bg_intensity,
    moving_angle,
    width_rad=None,
    height_rad=None,
    cartesian=False,
    mode="mean",
)

Render a grating stimulus.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| period_rad | float | Period of the grating in radians. | required |
| phase_rad | float | Phase of the grating in radians. | required |
| intensity | float | Intensity of the grating. | required |
| bg_intensity | float | Intensity of the background. | required |
| moving_angle | float | Rotation angle in degrees. | required |
| width_rad | Optional[float] | Width of the grating in radians. | None |
| height_rad | Optional[float] | Height of the grating in radians. | None |
| cartesian | bool | If True, return cartesian coordinates. | False |
| mode | Literal['mean', 'median', 'sum'] | Processing mode. | 'mean' |

Returns:

| Type | Description |
| --- | --- |
| Union[ndarray, Tensor] | Generated grating stimulus. |
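
Example

A sketch rendering a single grating frame with illustrative values:

```python
import numpy as np

from flyvision.datasets.rendering.eye import HexEye

hex_eye = HexEye(n_ommatidia=91, ppo=25)
grating = hex_eye.render_grating(
    period_rad=np.radians(20),
    phase_rad=0.0,
    intensity=1.0,
    bg_intensity=0.0,
    moving_angle=90.0,  # rotation in degrees
)
print(grating.shape)  # torch.Size([1, 91])
```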

Source code in flyvision/datasets/rendering/eye.py
def render_grating(
    self,
    period_rad: float,
    phase_rad: float,
    intensity: float,
    bg_intensity: float,
    moving_angle: float,
    width_rad: Optional[float] = None,
    height_rad: Optional[float] = None,
    cartesian: bool = False,
    mode: Literal["mean", "median", "sum"] = "mean",
) -> Union[np.ndarray, torch.Tensor]:
    """Render a grating stimulus.

    Args:
        period_rad: Period of the grating in radians.
        phase_rad: Phase of the grating in radians.
        intensity: Intensity of the grating.
        bg_intensity: Intensity of the background.
        moving_angle: Rotation angle in degrees.
        width_rad: Width of the grating in radians.
        height_rad: Height of the grating in radians.
        cartesian: If True, return cartesian coordinates.
        mode: Processing mode.

    Returns:
        Union[np.ndarray, torch.Tensor]: Generated grating stimulus.
    """
    period_px = int(period_rad / self.omm_width_rad * self.ppo)
    phase_px = int(phase_rad / self.omm_width_rad * self.ppo)

    height_rad_px = None
    if height_rad:
        height_rad_px = int(height_rad / self.omm_height_rad * self.ppo)

    width_rad_px = None
    if width_rad:
        width_rad_px = int(width_rad / self.omm_width_rad * self.ppo)

    grating = render_gratings_cartesian(
        self.monitor_height_px,
        self.monitor_width_px,
        period_px,
        intensity,
        bg_intensity,
        grating_phase_px=phase_px,
        rotate=moving_angle,
        grating_height_px=height_rad_px,
        grating_width_px=width_rad_px,
    )
    if cartesian:
        return grating
    return self(torch.tensor(grating.flatten(), device=self.device)[None], mode)

render_grating_offsets

render_grating_offsets(
    period_rad,
    intensity,
    bg_intensity,
    moving_angle,
    width_rad=None,
    height_rad=None,
    cartesian=False,
    mode="mean",
)

Render grating stimuli with a range of offsets.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| period_rad | float | Period of the grating in radians. | required |
| intensity | float | Intensity of the grating. | required |
| bg_intensity | float | Intensity of the background. | required |
| moving_angle | float | Rotation angle in degrees. | required |
| width_rad | Optional[float] | Width of the grating in radians. | None |
| height_rad | Optional[float] | Height of the grating in radians. | None |
| cartesian | bool | If True, return cartesian coordinates. | False |
| mode | Literal['mean', 'median', 'sum'] | Processing mode. | 'mean' |

Returns:

| Type | Description |
| --- | --- |
| Union[ndarray, Tensor] | Generated grating stimuli with offsets. |
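
Example

The phase step is half an ommatidium width (2.9 degrees), so a 20-degree period yields ceil(20 / 2.9) = 7 offsets. A sketch:

```python
import numpy as np

from flyvision.datasets.rendering.eye import HexEye

hex_eye = HexEye(n_ommatidia=91, ppo=25)
gratings = hex_eye.render_grating_offsets(
    period_rad=np.radians(20),
    intensity=1.0,
    bg_intensity=0.0,
    moving_angle=0.0,
)
print(gratings.shape)  # torch.Size([7, 91])
```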

Source code in flyvision/datasets/rendering/eye.py
def render_grating_offsets(
    self,
    period_rad: float,
    intensity: float,
    bg_intensity: float,
    moving_angle: float,
    width_rad: Optional[float] = None,
    height_rad: Optional[float] = None,
    cartesian: bool = False,
    mode: Literal["mean", "median", "sum"] = "mean",
) -> Union[np.ndarray, torch.Tensor]:
    """Render grating stimuli with a range of offsets.

    Args:
        period_rad: Period of the grating in radians.
        intensity: Intensity of the grating.
        bg_intensity: Intensity of the background.
        moving_angle: Rotation angle in degrees.
        width_rad: Width of the grating in radians.
        height_rad: Height of the grating in radians.
        cartesian: If True, return cartesian coordinates.
        mode: Processing mode.

    Returns:
        Union[np.ndarray, torch.Tensor]: Generated grating stimuli with offsets.
    """
    # Phase step of half an ommatidium width (2.9 degrees, in radians despite
    # the _px suffix); corresponds to an LED width of 2.25 degrees.
    dphase_px = np.radians(5.8 / 2)
    n_offsets = np.ceil(period_rad / dphase_px).astype(int)
    gratings = []
    for offset in range(n_offsets):
        gratings.append(
            self.render_grating(
                period_rad,
                offset * dphase_px,
                intensity,
                bg_intensity,
                moving_angle,
                width_rad=width_rad,
                height_rad=height_rad,
                cartesian=cartesian,
                mode=mode,
            )
        )
    if cartesian:
        return np.array(gratings)
    return torch.cat(gratings, dim=0)

render_offset_bars

render_offset_bars(
    bar_width_rad,
    bar_height_rad,
    n_bars,
    offsets,
    bar_intensity,
    bg_intensity,
    moving_angle,
    bar_loc_horizontal=np.radians(90),
    bar_loc_vertical=np.radians(90),
    mode="mean",
)

Render bars with a range of offsets.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| bar_width_rad | float | Width of bars in radians. | required |
| bar_height_rad | float | Height of bars in radians. | required |
| n_bars | int | Number of bars. | required |
| offsets | List[float] | Offsets of bars wrt. the center in radians. | required |
| bar_intensity | float | Intensity of the bar. | required |
| bg_intensity | float | Intensity of the background. | required |
| moving_angle | float | Rotation angle in degrees. | required |
| bar_loc_horizontal | float | Horizontal location of bars in radians. | np.radians(90) |
| bar_loc_vertical | float | Vertical location of bars in radians. | np.radians(90) |
| mode | Literal['mean', 'median', 'sum'] | Processing mode. | 'mean' |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Generated offset bars. |
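
Example

A sketch rendering one flash per offset, where offsets are given relative to the center:

```python
import numpy as np

from flyvision.datasets.rendering.eye import HexEye

hex_eye = HexEye(n_ommatidia=91, ppo=25)
bars = hex_eye.render_offset_bars(
    bar_width_rad=np.radians(5.8),
    bar_height_rad=np.radians(60),
    n_bars=1,
    offsets=np.radians([-10.0, 0.0, 10.0]).tolist(),
    bar_intensity=1.0,
    bg_intensity=0.5,
    moving_angle=0.0,
)
print(bars.shape)  # torch.Size([3, 91]), one row per offset
```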

Source code in flyvision/datasets/rendering/eye.py
def render_offset_bars(
    self,
    bar_width_rad: float,
    bar_height_rad: float,
    n_bars: int,
    offsets: List[float],
    bar_intensity: float,
    bg_intensity: float,
    moving_angle: float,
    bar_loc_horizontal: float = np.radians(90),
    bar_loc_vertical: float = np.radians(90),
    mode: Literal["mean", "median", "sum"] = "mean",
) -> torch.Tensor:
    """Render bars with a range of offsets.

    Args:
        bar_width_rad: Width of bars in radians.
        bar_height_rad: Height of bars in radians.
        n_bars: Number of bars.
        offsets: Offsets of bars wrt. the center in radians.
        bar_intensity: Intensity of the bar.
        bg_intensity: Intensity of the background.
        moving_angle: Rotation angle in degrees.
        bar_loc_horizontal: Horizontal location of bars in radians.
        bar_loc_vertical: Vertical location of bars in radians.
        mode: Processing mode.

    Returns:
        torch.Tensor: Generated offset bars.
    """
    flashes = []
    for offset in offsets:
        flashes.append(
            self.render_bar(
                bar_width_rad,
                bar_height_rad,
                bar_loc_horizontal + offset,
                bar_loc_vertical,
                n_bars,
                bar_intensity,
                bg_intensity,
                moving_angle,
                mode=mode,
            )
        )
    return torch.cat(flashes, dim=0)

render_bar_movie

render_bar_movie(
    t_stim,
    dt,
    bar_width_rad,
    bar_height_rad,
    n_bars,
    offsets,
    bar_intensity,
    bg_intensity,
    moving_angle,
    t_pre=0.0,
    t_between=0.0,
    t_post=0.0,
    bar_loc_horizontal=np.radians(90),
    bar_loc_vertical=np.radians(90),
)

Render moving bars.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| t_stim | float | Stimulus duration. | required |
| dt | float | Temporal resolution. | required |
| bar_width_rad | float | Width of bars in radians. | required |
| bar_height_rad | float | Height of bars in radians. | required |
| n_bars | int | Number of bars. | required |
| offsets | List[float] | Offsets of bars wrt. the center in radians. | required |
| bar_intensity | float | Intensity of the bar. | required |
| bg_intensity | float | Intensity of the background. | required |
| moving_angle | float | Rotation angle in degrees. | required |
| t_pre | float | Grey pre stimulus duration. | 0.0 |
| t_between | float | Grey between offset stimulus duration. | 0.0 |
| t_post | float | Grey post stimulus duration. | 0.0 |
| bar_loc_horizontal | float | Horizontal location of bars in radians. | np.radians(90) |
| bar_loc_vertical | float | Vertical location of bars in radians. | np.radians(90) |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Generated moving bars. |
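
Example

Frame counts follow from the durations: each offset is shown for round(t_stim / (len(offsets) * dt)) frames, plus optional grey pre/between/post frames. A sketch with the default grey periods of zero:

```python
import numpy as np

from flyvision.datasets.rendering.eye import HexEye

hex_eye = HexEye(n_ommatidia=91, ppo=25)
movie = hex_eye.render_bar_movie(
    t_stim=0.6,  # 0.6 s over 3 offsets at dt=0.05 s -> 4 frames per offset
    dt=0.05,
    bar_width_rad=np.radians(5.8),
    bar_height_rad=np.radians(60),
    n_bars=1,
    offsets=np.radians([-10.0, 0.0, 10.0]).tolist(),
    bar_intensity=1.0,
    bg_intensity=0.5,
    moving_angle=0.0,
)
print(movie.shape)  # torch.Size([12, 91]): 3 offsets x 4 frames each
```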

Source code in flyvision/datasets/rendering/eye.py
def render_bar_movie(
    self,
    t_stim: float,
    dt: float,
    bar_width_rad: float,
    bar_height_rad: float,
    n_bars: int,
    offsets: List[float],
    bar_intensity: float,
    bg_intensity: float,
    moving_angle: float,
    t_pre: float = 0.0,
    t_between: float = 0.0,
    t_post: float = 0.0,
    bar_loc_horizontal: float = np.radians(90),
    bar_loc_vertical: float = np.radians(90),
) -> torch.Tensor:
    """Render moving bars.

    Args:
        t_stim: Stimulus duration.
        dt: Temporal resolution.
        bar_width_rad: Width of bars in radians.
        bar_height_rad: Height of bars in radians.
        n_bars: Number of bars.
        offsets: Offsets of bars wrt. the center in radians.
        bar_intensity: Intensity of the bar.
        bg_intensity: Intensity of the background.
        moving_angle: Rotation angle in degrees.
        t_pre: Grey pre stimulus duration.
        t_between: Grey between offset stimulus duration.
        t_post: Grey post stimulus duration.
        bar_loc_horizontal: Horizontal location of bars in radians.
        bar_loc_vertical: Vertical location of bars in radians.

    Returns:
        torch.Tensor: Generated moving bars.
    """
    pre_frames = round(t_pre / dt)
    stim_frames = round(t_stim / (len(offsets) * dt))
    if stim_frames == 0:
        raise ValueError(
            f"stimulus time {t_stim}s not sufficient to sample {len(offsets)} "
            f"offsets at {dt}s"
        )
    between_frames = round(t_between / dt)
    post_frames = round(t_post / dt)

    flashes = []
    if pre_frames:
        flashes.append(torch.ones([pre_frames, self.n_ommatidia]) * bg_intensity)

    for i, offset in enumerate(offsets):
        flash = self.render_bar(
            bar_width_rad,
            bar_height_rad,
            bar_loc_horizontal + offset,
            bar_loc_vertical,
            n_bars,
            bar_intensity,
            bg_intensity,
            moving_angle,
        )
        flashes.append(flash.repeat(stim_frames, 1))

        if between_frames and i < len(offsets) - 1:
            flashes.append(
                torch.ones([between_frames, self.n_ommatidia]) * bg_intensity
            )
    if post_frames:
        flashes.append(torch.ones([post_frames, self.n_ommatidia]) * bg_intensity)
    return torch.cat(flashes, dim=0)

illustrate

illustrate(figsize=[5, 5], fontsize=5)

Illustrate the hexagonal eye model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| figsize | List[int] | Figure size. | [5, 5] |
| fontsize | int | Font size for the plot. | 5 |

Returns:

| Type | Description |
| --- | --- |
| Tuple[Figure, Axes] | Matplotlib figure and axes objects. |

Source code in flyvision/datasets/rendering/eye.py
def illustrate(
    self, figsize: List[int] = [5, 5], fontsize: int = 5
) -> Tuple[plt.Figure, plt.Axes]:
    """Illustrate the hexagonal eye model.

    Args:
        figsize: Figure size.
        fontsize: Font size for the plot.

    Returns:
        Tuple[plt.Figure, plt.Axes]: Matplotlib figure and axes objects.
    """
    x_hc, y_hc, (dist_w, dist_h) = hex_center_coordinates(
        self.n_ommatidia, self.monitor_width_px, self.monitor_height_px
    )

    x_img, y_img = np.array(
        list(
            product(
                np.arange(self.monitor_width_px),
                np.arange(
                    self.monitor_height_px,
                ),
            )
        )
    ).T

    dist_to_edge = (dist_w + dist_h) / 4

    vertices, _ = is_inside_hex(
        torch.tensor(y_img, dtype=self.dtype),
        torch.tensor(x_img, dtype=self.dtype),
        torch.tensor(x_hc, dtype=self.dtype),
        torch.tensor(y_hc, dtype=self.dtype),
        torch.tensor(dist_to_edge, dtype=self.dtype),
        torch.tensor(np.radians(0), dtype=self.dtype),
    )
    vertices = vertices.cpu()
    fig, ax = init_plot(figsize=figsize, fontsize=fontsize)
    ax.scatter(x_hc, y_hc, color="#eb4034", zorder=1)
    ax.scatter(x_img, y_img, color="#34ebd6", s=0.5, zorder=0)

    for h in range(self.n_ommatidia):
        for i in range(6):
            x1, y1 = vertices[i, :, h]  # x1, y1: (n_hexagons)
            x2, y2 = vertices[i + 1, :, h]
            ax.plot([x1, x2], [y1, y2], c="black")

    ax.set_xlim(0, self.monitor_width_px)
    ax.set_ylim(0, self.monitor_height_px)
    rm_spines(ax)
    # fig.tight_layout()
    return fig, ax

Utils

flyvision.datasets.rendering.utils

Rendering utils

median

median(x, kernel_size, stride=1, n_chunks=10)

Apply median image filter with reflected padding.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input array or tensor of shape (n_samples, n_frames, height, width). First and second dimensions are optional. | required |
| kernel_size | int | Size of the filter kernel. | required |
| stride | int | Stride for the filter operation. | 1 |
| n_chunks | int | Number of chunks to process the data in if memory is limited. Recursively adjusts the chunking until the data fits in memory. | 10 |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Filtered array or tensor of the same shape as input. |

Note

On GPU, this creates a tensor of kernel_size ** 2 * prod(x.shape) elements, consuming significant memory (e.g., ~14 GB for 50 frames of 436x1024 with kernel_size 13). In case of a RuntimeError due to memory, the method processes the data in chunks.
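
Example

A minimal sketch; the reflected padding keeps the output shape equal to the input:

```python
import torch

from flyvision.datasets.rendering.utils import median

x = torch.rand(2, 5, 64, 64)  # (n_samples, n_frames, height, width)
filtered = median(x, kernel_size=5)
print(filtered.shape)  # torch.Size([2, 5, 64, 64])
```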

Source code in flyvision/datasets/rendering/utils.py
def median(
    x: torch.Tensor,
    kernel_size: int,
    stride: int = 1,
    n_chunks: int = 10,
) -> torch.Tensor:
    """
    Apply median image filter with reflected padding.

    Args:
        x: Input array or tensor of shape (n_samples, n_frames, height, width).
           First and second dimensions are optional.
        kernel_size: Size of the filter kernel.
        stride: Stride for the filter operation.
        n_chunks: Number of chunks to process the data if memory is limited.
            Recursively increases the chunk size until the data fits in memory.

    Returns:
        Filtered array or tensor of the same shape as input.

    Note:
        On GPU, this creates a tensor of kernel_size ** 2 * prod(x.shape) elements,
        consuming significant memory (e.g., ~14 GB for 50 frames of 436x1024 with
        kernel_size 13). In case of a RuntimeError due to memory, the method
        processes the data in chunks.
    """
    # Get padding so that the resulting tensor is of the same shape.
    p = max(kernel_size - 1, 0)
    p_floor = p // 2
    p_ceil = p - p_floor
    padding = (p_floor, p_ceil, p_floor, p_ceil)

    shape = x.shape

    try:
        with torch.no_grad():
            if len(shape) == 2:
                x.unsqueeze_(0).unsqueeze_(0)
            elif len(shape) == 3:
                x.unsqueeze_(0)
            elif len(shape) == 4:
                pass
            else:
                raise ValueError(f"Invalid shape: {shape}")
            assert len(x.shape) == 4
            _x = F.pad(x, padding, mode="reflect")
            _x = _x.unfold(dimension=2, size=kernel_size, step=stride).unfold(
                dimension=3, size=kernel_size, step=stride
            )
            _x = _x.contiguous().view(shape[:4] + (-1,)).median(dim=-1)[0]
            return _x.view(shape)
    except RuntimeError as e:
        if "memory" not in str(e):
            raise e
        if "CUDA" in str(e):
            torch.cuda.empty_cache()
        _x = x.reshape(-1, *x.shape[-2:])
        chunks = torch.chunk(_x, max(n_chunks, 1), dim=0)

        def map_fn(z):
            return median(z, kernel_size, n_chunks=n_chunks - 1)

        _x = torch.cat(tuple(map(map_fn, chunks)), dim=0)
        return _x.view(shape)
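A minimal usage sketch (shapes chosen for illustration): filtering a short clip with a 5-pixel kernel returns a tensor of the same shape.

import torch

from flyvision.datasets.rendering.utils import median

# One sample of four 64x64 frames; values are arbitrary test data.
clip = torch.rand(1, 4, 64, 64)
filtered = median(clip, kernel_size=5)
assert filtered.shape == clip.shape  # shape is preserved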

split

split(
    array, out_nelements, n_splits, center_crop_fraction=0.7
)

Split an array into overlapping segments along the last dimension.

Parameters:

    array (Union[ndarray, Tensor]): Input array of shape (..., nelements).
        Required.
    out_nelements (int): Number of elements in each output split. Required.
    n_splits (int): Number of splits to create. Required.
    center_crop_fraction (Optional[float]): If not None, the array is centrally
        cropped in the last dimension to this fraction before splitting.
        Default: 0.7.

Returns:

    Union[ndarray, Tensor]: A new array of shape (n_splits, ..., out_nelements)
        containing the splits.

Raises:

    ValueError: If n_splits is less than 0.
    TypeError: If the input array is neither a numpy array nor a torch tensor.

Note
  • If n_splits is 1, the entire array is returned (with an added dimension).
  • If n_splits is None or 0, the original array is returned unchanged.
  • Splits may overlap if out_nelements * n_splits > array.shape[-1].
Source code in flyvision/datasets/rendering/utils.py
def split(
    array: Union[np.ndarray, torch.Tensor],
    out_nelements: int,
    n_splits: int,
    center_crop_fraction: Optional[float] = 0.7,
) -> Union[np.ndarray, torch.Tensor]:
    """
    Split an array into overlapping segments along the last dimension.

    Args:
        array: Input array of shape (..., nelements).
        out_nelements: Number of elements in each output split.
        n_splits: Number of splits to create.
        center_crop_fraction: If not None, the array is centrally cropped in the
            last dimension to this fraction before splitting.

    Returns:
        A new array of shape (n_splits, ..., out_nelements) containing the splits.

    Raises:
        ValueError: If n_splits is less than 0.
        TypeError: If the input array is neither a numpy array nor a torch tensor.

    Note:
        - If n_splits is 1, the entire array is returned (with an added dimension).
        - If n_splits is None or 0, the original array is returned unchanged.
        - Splits may overlap if out_nelements * n_splits > array.shape[-1].
    """
    assert isinstance(array, (np.ndarray, torch.Tensor))
    if center_crop_fraction is not None:
        return split(
            center_crop(array, center_crop_fraction),
            out_nelements,
            n_splits,
            center_crop_fraction=None,
        )

    actual_nelements = array.shape[-1]
    out_nelements = int(out_nelements)

    def take(
        arr: Union[np.ndarray, torch.Tensor], start: int, stop: int
    ) -> Union[np.ndarray, torch.Tensor]:
        if isinstance(arr, np.ndarray):
            return np.take(arr, np.arange(start, stop), axis=-1)[None]
        elif isinstance(arr, torch.Tensor):
            return torch.index_select(arr, dim=-1, index=torch.arange(start, stop))[None]

    # n_splits may be None; handle that (and 0) before the numeric comparisons
    # below, which would otherwise raise a TypeError on None.
    if n_splits is None or n_splits == 0:
        return array
    elif n_splits == 1:
        out = (array[None, :],)
    elif n_splits > 1:
        out = ()
        out_nelements = max(out_nelements, int(actual_nelements / n_splits))
        overlap = np.ceil(
            (out_nelements * n_splits - actual_nelements) / (n_splits - 1)
        ).astype(int)
        for i in range(n_splits):
            start = i * out_nelements - i * overlap
            stop = (i + 1) * out_nelements - i * overlap
            out += (take(array, start, stop),)
    else:
        raise ValueError("n_splits must be a non-negative integer or None")

    if isinstance(array, np.ndarray):
        return np.concatenate(out, axis=0)
    elif isinstance(array, torch.Tensor):
        return torch.cat(out, dim=0)
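A minimal sketch of the overlap logic (center cropping disabled so the arithmetic is easy to follow): 3 splits of 40 elements from a 100-element axis overlap by 10 elements each.

import numpy as np

from flyvision.datasets.rendering.utils import split

sequence = np.arange(100)[None]  # shape (1, 100)
parts = split(sequence, out_nelements=40, n_splits=3, center_crop_fraction=None)
print(parts.shape)  # (3, 1, 40)
print(parts[0, 0, -1], parts[1, 0, 0])  # 39 30: consecutive splits overlap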

center_crop

center_crop(array, out_nelements_ratio)

Centrally crop an array along the last dimension with given ratio.

Parameters:

    array (Union[ndarray, Tensor]): Array of shape (..., nelements). Required.
    out_nelements_ratio (float): Ratio of output elements to input elements.
        Required.

Returns:

    Union[ndarray, Tensor]: Cropped array of shape (..., out_nelements).

Source code in flyvision/datasets/rendering/utils.py
def center_crop(
    array: Union[np.ndarray, torch.Tensor], out_nelements_ratio: float
) -> Union[np.ndarray, torch.Tensor]:
    """
    Centrally crop an array along the last dimension with given ratio.

    Args:
        array: Array of shape (..., nelements).
        out_nelements_ratio: Ratio of output elements to input elements.

    Returns:
        Cropped array of shape (..., out_nelements).
    """

    def take(arr, start, stop):
        if isinstance(arr, np.ndarray):
            return np.take(arr, np.arange(start, stop), axis=-1)
        elif isinstance(arr, torch.Tensor):
            return torch.index_select(arr, dim=-1, index=torch.arange(start, stop))

    nelements = array.shape[-1]
    out_nelements = int(out_nelements_ratio * nelements)
    return take(array, (nelements - out_nelements) // 2, (nelements + out_nelements) // 2)
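A one-line sketch: cropping a 10-element axis to half its length keeps the central five elements.

import numpy as np

from flyvision.datasets.rendering.utils import center_crop

print(center_crop(np.arange(10), 0.5))  # [2 3 4 5 6]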

hex_center_coordinates

hex_center_coordinates(
    n_hex_area, img_width, img_height, center=True
)

Calculate hexagon center coordinates for a given area.

Parameters:

    n_hex_area (int): Number of hexagons in the area. Required.
    img_width (int): Width of the image. Required.
    img_height (int): Height of the image. Required.
    center (bool): If True, center the hexagon grid in the image. Default: True.

Returns:

    Tuple[ndarray, ndarray, Tuple[float, float]]: Tuple containing x
        coordinates, y coordinates, and (dist_w, dist_h).

Source code in flyvision/datasets/rendering/utils.py
def hex_center_coordinates(
    n_hex_area: int, img_width: int, img_height: int, center: bool = True
) -> Tuple[np.ndarray, np.ndarray, Tuple[float, float]]:
    """
    Calculate hexagon center coordinates for a given area.

    Args:
        n_hex_area: Number of hexagons in the area.
        img_width: Width of the image.
        img_height: Height of the image.
        center: If True, center the hexagon grid in the image.

    Returns:
        Tuple containing x coordinates, y coordinates, and (dist_w, dist_h).
    """
    # Horizontal extent of the grid
    n = np.floor(np.sqrt(n_hex_area / 3)).astype("int")

    dist_h = img_height / (2 * n + 1)
    dist_w = img_width / (2 * n + 1)

    xs = []
    ys = []
    for q in range(-n, n + 1):
        for r in range(max(-n, -n - q), min(n, n - q) + 1):
            xs.append(dist_w * r)
            ys.append(
                dist_h * (q + r / 2)
            )  # either must be negative or origin must be upper
    xs, ys = np.array(xs), np.array(ys)
    if center:
        xs += img_width // 2
        ys += img_height // 2
    return xs, ys, (dist_w, dist_h)
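A small sketch: requesting 721 hexagons (the hexal count of a BoxEye with extent 15) on a square image yields one center per hexal plus the grid spacing.

from flyvision.datasets.rendering.utils import hex_center_coordinates

xs, ys, (dist_w, dist_h) = hex_center_coordinates(721, img_width=436, img_height=436)
print(len(xs))  # 721 centers; dist_w and dist_h are derived from the image size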

is_inside_hex

is_inside_hex(
    x,
    y,
    x_centers,
    y_centers,
    dist_to_edge,
    tilt,
    device=flyvision.device,
    dtype=torch.float16,
)

Find whether given points are inside the given hexagons.

Parameters:

    x (Tensor): Cartesian x-coordinates of the points (n_points). Required.
    y (Tensor): Cartesian y-coordinates of the points (n_points). Required.
    x_centers (Tensor): Cartesian x-centers of the hexagons (n_hexagons).
        Required.
    y_centers (Tensor): Cartesian y-centers of the hexagons (n_hexagons).
        Required.
    dist_to_edge (float): Euclidean distance from center to edge of the
        hexagon. Required.
    tilt (Union[float, Tensor]): Angle of hexagon counter-clockwise tilt in
        radians. Required.
    device (device): Torch device to use for computations. Default:
        flyvision.device.
    dtype (dtype): Data type for torch tensors. Default: torch.float16.

Returns:

    Tuple[Tensor, Tensor]: Tuple containing:

        vertices: Cartesian coordinates of the hexagons' vertices
            (7, 2, n_hexagons).
        is_inside: Boolean tensor indicating whether points are inside
            (n_points, n_hexagons).
Credits

Adapted from Roman Vaxenburg’s original implementation.

Source code in flyvision/datasets/rendering/utils.py
def is_inside_hex(
    x: torch.Tensor,
    y: torch.Tensor,
    x_centers: torch.Tensor,
    y_centers: torch.Tensor,
    dist_to_edge: float,
    tilt: Union[float, torch.Tensor],
    device: torch.device = flyvision.device,
    dtype: torch.dtype = torch.float16,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Find whether given points are inside the given hexagons.

    Args:
        x: Cartesian x-coordinates of the points (n_points).
        y: Cartesian y-coordinates of the points (n_points).
        x_centers: Cartesian x-centers of the hexagons (n_hexagons).
        y_centers: Cartesian y-centers of the hexagons (n_hexagons).
        dist_to_edge: Euclidean distance from center to edge of the hexagon.
        tilt: Angle of hexagon counter-clockwise tilt in radians.
        device: Torch device to use for computations.
        dtype: Data type for torch tensors.

    Returns:
        Tuple containing:
        - vertices: Cartesian coordinates of the hexagons' vertices (7, 2, n_hexagons).
        - is_inside: Boolean tensor indicating whether points are inside
            (n_points, n_hexagons).

    Info: Credits
        Adapted from Roman Vaxenburg's original implementation.
    """
    if not isinstance(tilt, torch.Tensor):
        tilt = torch.tensor(tilt, device=device)

    R = torch.tensor(
        [
            [torch.cos(tilt), -torch.sin(tilt)],
            [torch.sin(tilt), torch.cos(tilt)],
        ],
        dtype=dtype,
        device=device,
    )  # rotation matrix
    pi = torch.tensor(np.pi, device=device, dtype=dtype)
    R60 = torch.tensor(
        [
            [torch.cos(pi / 3), -torch.sin(pi / 3)],
            [torch.sin(pi / 3), torch.cos(pi / 3)],
        ],
        dtype=dtype,
        device=device,
    )  # 60-degree rotation matrix

    # Generate hexagon vertices
    dist_to_vertex = 2 / np.sqrt(3) * dist_to_edge
    vertices = torch.zeros(7, 2, dtype=dtype, device=device)
    vertices[0, :] = torch.matmul(
        R, torch.tensor([dist_to_vertex, 0], dtype=dtype, device=device)
    )
    for i in range(1, 7):
        vertices[i] = torch.matmul(R60, vertices[i - 1])
    vertices = vertices[:, :, None]
    vertices = torch.cat(
        (
            vertices[:, 0:1, :] + x_centers[None, None, :],
            vertices[:, 1:2, :] + y_centers[None, None, :],
        ),
        dim=1,
    )  # (7, 2, n_hexagons)

    # Generate is_inside output
    is_inside = torch.ones(len(x), len(x_centers), dtype=torch.bool, device=device)
    for i in range(6):
        x1, y1 = vertices[i, :, :]  # x1, y1: (n_hexagons)
        x2, y2 = vertices[i + 1, :, :]
        slope = (y2 - y1) / (x2 - x1)  # (n_hexagons)
        f_center = y1 + slope * (x_centers - x1) - y_centers  # (n_hexagons)
        f_points = (
            y1[None, :] + slope[None, :] * (x[:, None] - x1[None, :]) - y[:, None]
        )  # (n_points, n_hexagons)
        is_inside = torch.logical_and(is_inside, f_center.sign() == f_points.sign())

    return vertices, is_inside  # (7, 2, n_hexagons), (n_points, n_hexagons)
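A minimal sketch with a single untilted hexagon at the origin (device and dtype passed explicitly so it runs on CPU): the origin lies inside, a distant point does not.

import torch

from flyvision.datasets.rendering.utils import is_inside_hex

vertices, inside = is_inside_hex(
    x=torch.tensor([0.0, 10.0]),
    y=torch.tensor([0.0, 10.0]),
    x_centers=torch.tensor([0.0]),
    y_centers=torch.tensor([0.0]),
    dist_to_edge=1.0,
    tilt=0.0,
    device=torch.device("cpu"),
    dtype=torch.float32,
)
print(vertices.shape)  # torch.Size([7, 2, 1])
print(inside)          # [[True], [False]]: the origin is inside, (10, 10) is not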

render_bars_cartesian

render_bars_cartesian(
    img_height_px,
    img_width_px,
    bar_width_px,
    bar_height_px,
    bar_loc_horizontal_px,
    bar_loc_vertical_px,
    n_bars,
    bar_intensity,
    bg_intensity,
    rotate=0,
)

Render bars in a cartesian coordinate system.

Parameters:

    img_height_px (int): Height of the image in pixels. Required.
    img_width_px (int): Width of the image in pixels. Required.
    bar_width_px (int): Width of each bar in pixels. Required.
    bar_height_px (int): Height of each bar in pixels. Required.
    bar_loc_horizontal_px (int): Horizontal location of the bars in pixels.
        Required.
    bar_loc_vertical_px (int): Vertical location of the bars in pixels.
        Required.
    n_bars (int): Number of bars to generate. Required.
    bar_intensity (float): Intensity of the bars. Required.
    bg_intensity (float): Intensity of the background. Required.
    rotate (float): Rotation angle in degrees. Default: 0.

Returns:

    ndarray: Generated image as a numpy array.

Source code in flyvision/datasets/rendering/utils.py
def render_bars_cartesian(
    img_height_px: int,
    img_width_px: int,
    bar_width_px: int,
    bar_height_px: int,
    bar_loc_horizontal_px: int,
    bar_loc_vertical_px: int,
    n_bars: int,
    bar_intensity: float,
    bg_intensity: float,
    rotate: float = 0,
) -> np.ndarray:
    """
    Render bars in a cartesian coordinate system.

    Args:
        img_height_px: Height of the image in pixels.
        img_width_px: Width of the image in pixels.
        bar_width_px: Width of each bar in pixels.
        bar_height_px: Height of each bar in pixels.
        bar_loc_horizontal_px: Horizontal location of the bars in pixels.
        bar_loc_vertical_px: Vertical location of the bars in pixels.
        n_bars: Number of bars to generate.
        bar_intensity: Intensity of the bars.
        bg_intensity: Intensity of the background.
        rotate: Rotation angle in degrees.

    Returns:
        Generated image as a numpy array.
    """
    bar_spacing = int(img_width_px / n_bars - bar_width_px)

    height_slice = slice(
        int(bar_loc_vertical_px - bar_height_px / 2),
        int(bar_loc_vertical_px + bar_height_px / 2) + 1,
    )

    img = np.ones([img_height_px, img_width_px]) * bg_intensity

    loc_w = int(bar_loc_horizontal_px - bar_width_px / 2)
    for i in range(n_bars):
        #  Fill background with bars.
        start = max(loc_w + i * bar_width_px + i * bar_spacing, 0)
        width_slice = slice(start, loc_w + (i + 1) * bar_width_px + i * bar_spacing + 1)
        img[height_slice, width_slice] = bar_intensity

    if rotate % 360 != 0:
        img = rotate_image(img, angle=rotate)

    return img
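A minimal sketch: one bright vertical bar centered on a dark 64x64 image (values chosen for illustration).

from flyvision.datasets.rendering.utils import render_bars_cartesian

img = render_bars_cartesian(
    img_height_px=64,
    img_width_px=64,
    bar_width_px=8,
    bar_height_px=64,
    bar_loc_horizontal_px=32,
    bar_loc_vertical_px=32,
    n_bars=1,
    bar_intensity=1.0,
    bg_intensity=0.0,
)
print(img.shape, img.min(), img.max())  # (64, 64) 0.0 1.0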

render_gratings_cartesian

render_gratings_cartesian(
    img_height_px,
    img_width_px,
    spatial_period_px,
    grating_intensity,
    bg_intensity,
    grating_height_px=None,
    grating_width_px=None,
    grating_phase_px=0,
    rotate=0,
)

Render gratings in a cartesian coordinate system.

Parameters:

    img_height_px (int): Height of the image in pixels. Required.
    img_width_px (int): Width of the image in pixels. Required.
    spatial_period_px (float): Spatial period of the gratings in pixels.
        Required.
    grating_intensity (float): Intensity of the gratings. Required.
    bg_intensity (float): Intensity of the background. Required.
    grating_height_px (Optional[int]): Height of the grating area in pixels.
        Default: None.
    grating_width_px (Optional[int]): Width of the grating area in pixels.
        Default: None.
    grating_phase_px (float): Phase of the gratings in pixels. Default: 0.
    rotate (float): Rotation angle in degrees. Default: 0.

Returns:

    ndarray: Generated image as a numpy array.

Source code in flyvision/datasets/rendering/utils.py
def render_gratings_cartesian(
    img_height_px: int,
    img_width_px: int,
    spatial_period_px: float,
    grating_intensity: float,
    bg_intensity: float,
    grating_height_px: Optional[int] = None,
    grating_width_px: Optional[int] = None,
    grating_phase_px: float = 0,
    rotate: float = 0,
) -> np.ndarray:
    """
    Render gratings in a cartesian coordinate system.

    Args:
        img_height_px: Height of the image in pixels.
        img_width_px: Width of the image in pixels.
        spatial_period_px: Spatial period of the gratings in pixels.
        grating_intensity: Intensity of the gratings.
        bg_intensity: Intensity of the background.
        grating_height_px: Height of the grating area in pixels.
        grating_width_px: Width of the grating area in pixels.
        grating_phase_px: Phase of the gratings in pixels.
        rotate: Rotation angle in degrees.

    Returns:
        Generated image as a numpy array.
    """
    # to save time at library import
    from scipy.signal import square

    t = (
        2
        * np.pi
        / (spatial_period_px / img_width_px)
        * (
            np.linspace(-1 / 2, 1 / 2, int(img_width_px))
            - grating_phase_px / img_width_px
        )
    )

    gratings = np.tile(square(t), img_height_px).reshape(img_height_px, img_width_px)
    gratings[gratings == -1] = bg_intensity
    gratings[gratings == 1] = grating_intensity

    if grating_height_px:
        mask = np.ones_like(gratings).astype(bool)

        height_slice = slice(
            int(img_height_px // 2 - grating_height_px / 2),
            int(img_height_px // 2 + grating_height_px / 2) + 1,
        )
        mask[height_slice] = False
        gratings[mask] = 0.5

    if grating_width_px:
        mask = np.ones_like(gratings).astype(bool)

        width_slice = slice(
            int(img_width_px // 2 - grating_width_px / 2),
            int(img_width_px // 2 + grating_width_px / 2) + 1,
        )
        mask[:, width_slice] = False
        gratings[mask] = 0.5

    if rotate % 360 != 0:
        gratings = rotate_image(gratings, angle=rotate)

    return gratings
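A minimal sketch: full-field square-wave gratings with a 16-pixel period; pixels take only the background and grating intensities.

from flyvision.datasets.rendering.utils import render_gratings_cartesian

img = render_gratings_cartesian(
    img_height_px=64,
    img_width_px=64,
    spatial_period_px=16,
    grating_intensity=1.0,
    bg_intensity=0.0,
)
print(img.shape, sorted(set(img.ravel())))  # (64, 64) [0.0, 1.0]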

rotate_image

rotate_image(img, angle=0)

Rotate an image by a given angle.

Parameters:

    img (ndarray): Input image as a numpy array. Required.
    angle (float): Rotation angle in degrees. Default: 0.

Returns:

    ndarray: Rotated image as a numpy array.

Source code in flyvision/datasets/rendering/utils.py
def rotate_image(img: np.ndarray, angle: float = 0) -> np.ndarray:
    """
    Rotate an image by a given angle.

    Args:
        img: Input image as a numpy array.
        angle: Rotation angle in degrees.

    Returns:
        Rotated image as a numpy array.
    """
    h, w = img.shape

    diagonal = int(np.sqrt(h**2 + w**2))

    pad_in_height = (diagonal - h) // 2
    pad_in_width = (diagonal - w) // 2

    img = np.pad(
        img,
        ((pad_in_height, pad_in_height), (pad_in_width, pad_in_width)),
        mode="edge",
    )

    img = Image.fromarray((255 * img).astype("uint8")).rotate(
        angle, Image.BILINEAR, False, None
    )
    img = np.array(img, dtype=float) / 255.0

    padded_h, padded_w = img.shape
    return img[
        pad_in_height : padded_h - pad_in_height,
        pad_in_width : padded_w - pad_in_width,
    ]
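A minimal sketch: rotating a vertical bar by 90 degrees produces a horizontal one; the padding and cropping keep the output the same size as the input.

import numpy as np

from flyvision.datasets.rendering.utils import rotate_image

img = np.zeros((64, 64))
img[:, 28:36] = 1.0  # vertical bar
rotated = rotate_image(img, angle=90)
print(rotated.shape)  # (64, 64); the bar now runs horizontally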

resample

resample(
    stims,
    t_stim,
    dt,
    dim=0,
    device=flyvision.device,
    return_indices=False,
)

Resample a set of stimuli for a given stimulus duration and time step.

Parameters:

    stims (Tensor): Stimuli tensor of shape (#conditions, #hexals). Required.
    t_stim (float): Stimulus duration in seconds. Required.
    dt (float): Integration time constant in seconds. Required.
    dim (int): Dimension along which to resample. Default: 0.
    device (device): Torch device to use for computations. Default:
        flyvision.device.
    return_indices (bool): If True, return the indices used for resampling.
        Default: False.

Returns:

    Union[Tensor, Tuple[Tensor, Tensor]]: Resampled stimuli tensor of shape
        (#frames, #hexals), or a tuple of (resampled stimuli, indices) if
        return_indices is True.

Source code in flyvision/datasets/rendering/utils.py
def resample(
    stims: torch.Tensor,
    t_stim: float,
    dt: float,
    dim: int = 0,
    device: torch.device = flyvision.device,
    return_indices: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Resample a set of stimuli for a given stimulus duration and time step.

    Args:
        stims: Stimuli tensor of shape (#conditions, #hexals).
        t_stim: Stimulus duration in seconds.
        dt: Integration time constant in seconds.
        dim: Dimension along which to resample.
        device: Torch device to use for computations.
        return_indices: If True, return the indices used for resampling.

    Returns:
        Resampled stimuli tensor of shape (#frames, #hexals), or a tuple of
        (resampled stimuli, indices) if return_indices is True.
    """
    n_offsets = stims.shape[dim]
    # round to nearest integer
    # this results in unequal counts of each frame usually by +-1
    indices = torch.linspace(0, n_offsets - 1, int(t_stim / dt), device=device).long()
    if not return_indices:
        return torch.index_select(stims, dim, indices)
    return torch.index_select(stims, dim, indices), indices
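A minimal sketch (device passed explicitly so the resampling indices live on the same device as the stimuli): stretching 4 conditions over 0.5 s at dt = 0.05 s yields 10 frames.

import torch

from flyvision.datasets.rendering.utils import resample

stims = torch.rand(4, 721)  # 4 conditions, 721 hexals
frames = resample(stims, t_stim=0.5, dt=0.05, device=torch.device("cpu"))
print(frames.shape)  # torch.Size([10, 721])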

shuffle

shuffle(stims, randomstate=None)

Randomly shuffle stimuli along the frame dimension.

Parameters:

    stims (Tensor): Stimuli tensor of shape (N (optional), #frames, #hexals).
        Required.
    randomstate (Optional[RandomState]): Random state for reproducibility.
        Default: None.

Returns:

    Tensor: Shuffled stimuli tensor.

Source code in flyvision/datasets/rendering/utils.py
def shuffle(
    stims: torch.Tensor, randomstate: Optional[np.random.RandomState] = None
) -> torch.Tensor:
    """
    Randomly shuffle stimuli along the frame dimension.

    Args:
        stims: Stimuli tensor of shape (N (optional), #frames, #hexals).
        randomstate: Random state for reproducibility.

    Returns:
        Shuffled stimuli tensor.
    """
    if len(stims.shape) == 3:
        # assume (samples, frames, hexals)
        def _shuffle(x):
            return shuffle(x, randomstate)

        return torch.stack(list(map(_shuffle, stims)), dim=0)
    perms = (
        randomstate.permutation(stims.shape[0])
        if randomstate is not None
        else np.random.permutation(stims.shape[0])
    )
    return stims[perms]
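A minimal sketch with a seeded RandomState for reproducibility:

import numpy as np
import torch

from flyvision.datasets.rendering.utils import shuffle

stims = torch.arange(12.0).view(4, 3)  # 4 frames, 3 hexals
shuffled = shuffle(stims, randomstate=np.random.RandomState(0))
print(shuffled.shape)  # torch.Size([4, 3]); frame order is permuted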

resample_grating

resample_grating(grating, t_stim, dt, temporal_frequency)

Resample a grating stimulus for a given duration and temporal frequency.

Parameters:

    grating (Tensor): Input grating tensor. Required.
    t_stim (float): Stimulus duration in seconds. Required.
    dt (float): Time step in seconds. Required.
    temporal_frequency (float): Temporal frequency of the grating. Required.

Returns:

    Tensor: Resampled grating tensor.

Source code in flyvision/datasets/rendering/utils.py
def resample_grating(
    grating: torch.Tensor, t_stim: float, dt: float, temporal_frequency: float
) -> torch.Tensor:
    """
    Resample a grating stimulus for a given duration and temporal frequency.

    Args:
        grating: Input grating tensor.
        t_stim: Stimulus duration in seconds.
        dt: Time step in seconds.
        temporal_frequency: Temporal frequency of the grating.

    Returns:
        Resampled grating tensor.
    """
    n_frames = int(t_stim / dt)
    t_period = 1 / temporal_frequency
    _grating = resample(grating, t_period, dt)
    _grating = _grating.repeat(np.ceil(n_frames / _grating.shape[0]).astype(int), 1)
    return _grating[:n_frames]
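A minimal sketch (the grating is placed on flyvision.device because resample defaults to that device internally): one 1 Hz period of 20 phase steps looped for 2 s at dt = 0.01 s gives 200 frames.

import torch

import flyvision
from flyvision.datasets.rendering.utils import resample_grating

grating = torch.rand(20, 721, device=flyvision.device)  # one period, 721 hexals
frames = resample_grating(grating, t_stim=2.0, dt=0.01, temporal_frequency=1.0)
print(frames.shape)  # torch.Size([200, 721])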

pad

pad(stim, t_stim, dt, fill=0, mode='end', pad_mode='value')

Pad the second to last dimension of a stimulus tensor.

Parameters:

    stim (Tensor): Stimulus tensor of shape (..., n_frames, n_hexals). Required.
    t_stim (float): Target stimulus duration in seconds. Required.
    dt (float): Integration time constant in seconds. Required.
    fill (float): Value to fill with if pad_mode is "value". Default: 0.
    mode (Literal["end", "start"]): Padding mode, either "end" or "start".
        Default: "end".
    pad_mode (Literal["value", "continue", "reflect"]): Padding type, either
        "value", "continue", or "reflect". Default: "value".

Returns:

    Tensor: Padded stimulus tensor.

Source code in flyvision/datasets/rendering/utils.py
def pad(
    stim: torch.Tensor,
    t_stim: float,
    dt: float,
    fill: float = 0,
    mode: Literal["end", "start"] = "end",
    pad_mode: Literal["value", "continue", "reflect"] = "value",
) -> torch.Tensor:
    """
    Pad the second to last dimension of a stimulus tensor.

    Args:
        stim: Stimulus tensor of shape (..., n_frames, n_hexals).
        t_stim: Target stimulus duration in seconds.
        dt: Integration time constant in seconds.
        fill: Value to fill with if pad_mode is "value".
        mode: Padding mode, either "end" or "start".
        pad_mode: Padding type, either "value", "continue", or "reflect".

    Returns:
        Padded stimulus tensor.
    """
    diff = int(t_stim / dt) - stim.shape[-2]
    if diff <= 0:
        return stim

    # Pad the second-to-last dimension (n_frames)
    # Format: (pad_last_dim_left, pad_last_dim_right,
    #          pad_second_to_last_dim_before, pad_second_to_last_dim_after)
    if mode == "end":
        pad = (0, 0, 0, diff)  # Pad after the existing frames
    elif mode == "start":
        pad = (0, 0, diff, 0)  # Pad before the existing frames

    if pad_mode == "value":
        return torch.nn.functional.pad(stim, pad=pad, mode="constant", value=fill)
    elif pad_mode == "continue":
        return repeat_last(stim, -2, diff)
    else:
        return torch.nn.functional.pad(stim, pad=pad, mode=pad_mode)
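A minimal sketch: padding a 10-frame stimulus to a 1 s target at dt = 0.05 s appends 10 constant frames at the end.

import torch

from flyvision.datasets.rendering.utils import pad

stim = torch.rand(10, 721)  # 10 frames, 721 hexals
padded = pad(stim, t_stim=1.0, dt=0.05, fill=0.5, mode="end", pad_mode="value")
print(padded.shape)   # torch.Size([20, 721])
print(padded[-1, 0])  # tensor(0.5000): appended frames hold the fill value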

repeat_last

repeat_last(stim, dim, n_repeats)

Repeat the last frame of a stimulus tensor along a specified dimension.

Parameters:

    stim (Tensor): Input stimulus tensor. Required.
    dim (int): Dimension along which to repeat. Required.
    n_repeats (int): Number of times to repeat the last frame. Required.

Returns:

    Tensor: Stimulus tensor with the last frame repeated.

Source code in flyvision/datasets/rendering/utils.py
def repeat_last(stim: torch.Tensor, dim: int, n_repeats: int) -> torch.Tensor:
    """
    Repeat the last frame of a stimulus tensor along a specified dimension.

    Args:
        stim: Input stimulus tensor.
        dim: Dimension along which to repeat.
        n_repeats: Number of times to repeat the last frame.

    Returns:
        Stimulus tensor with the last frame repeated.
    """
    last = stim.index_select(dim, torch.tensor([stim.size(dim) - 1], device=stim.device))
    stim = torch.cat((stim, last.repeat_interleave(n_repeats, dim=dim)), dim=dim)
    return stim
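A minimal sketch: repeating the last frame twice along the frame dimension.

import torch

from flyvision.datasets.rendering.utils import repeat_last

stim = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # 2 frames, 2 hexals
extended = repeat_last(stim, dim=0, n_repeats=2)
print(extended.shape)  # torch.Size([4, 2]); rows 2 and 3 copy the last frame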