Skip to content

Datasets

Base Classes

flyvision.datasets.datasets.SequenceDataset

Bases: Dataset

Base class for all sequence datasets.

All sequence datasets can subclass this class. They are expected to implement the following attributes and methods.

Attributes:

Name Type Description
framerate int

Framerate of the original sequences.

dt float

Sampling and integration time constant.

t_pre float

Warmup time.

t_post float

Cooldown time.

arg_df DataFrame

Required. DataFrame containing the dataset parameters.

Source code in flyvision/datasets/datasets.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
class SequenceDataset(torch.utils.data.Dataset):
    """Abstract base for all sequence datasets.

    Concrete sequence datasets subclass this and provide the attributes and
    methods listed below.

    Attributes:
        framerate (int): Framerate of the original sequences.
        dt (float): Sampling and integration time constant.
        t_pre (float): Warmup time.
        t_post (float): Cooldown time.
        arg_df (pd.DataFrame): required DataFrame containing the dataset parameters.
    """

    # Populated by subclasses; `arg_df` defines the dataset length.
    arg_df: pd.DataFrame = None
    dt: float = None
    t_pre: float = None
    t_post: float = None

    def get_item(self, key: int) -> Any:
        """Return a single dataset item.

        Args:
            key: Index of the item to retrieve.

        Returns:
            The dataset item at the specified index.
        """

    def __len__(self) -> int:
        """Number of samples, i.e. the number of rows of `arg_df`."""
        return len(self.arg_df)

    def __getitem__(self, key: Union[slice, Iterable, int, np.int_]) -> Any:
        """Advanced indexing into the dataset.

        Args:
            key: Index, slice, or iterable of indices.

        Returns:
            The dataset item(s) at the specified index/indices.

        Raises:
            IndexError: If the index is out of range.
            TypeError: If the key type is invalid.
        """
        # Delegates slicing/iterable handling to the module-level helper.
        return getitem(self, key)

    def get_temporal_sample_indices(
        self, n_frames: int, total_seq_length: int, augment: bool = None
    ) -> torch.Tensor:
        """Compute temporal indices for sampling from a sequence.

        Args:
            n_frames: Number of sequence frames to sample from.
            total_seq_length: Total sequence length.
            augment: If True, picks the start frame at random. If False, starts at 0.

        Returns:
            Tensor of temporal indices.

        Note:
            Interpolates between start_index and start_index + n_frames and rounds
            the resulting floats to integers, so individual raw frames may be
            sampled an uneven number of times.
        """
        if augment is None:
            augment = getattr(self, "augment", False)
        # Fall back to the rate implied by dt when no original framerate is set.
        framerate = getattr(self, "original_framerate", 1 / self.dt)
        # Resolves to the module-level function of the same name.
        return get_temporal_sample_indices(
            n_frames, total_seq_length, framerate, self.dt, augment
        )

get_item

get_item(key)

Return an item of the dataset.

Parameters:

Name Type Description Default
key int

Index of the item to retrieve.

required

Returns:

Type Description
Any

The dataset item at the specified index.

Source code in flyvision/datasets/datasets.py
36
37
38
39
40
41
42
43
44
def get_item(self, key: int) -> Any:
    """Fetch one dataset sample; subclasses override this.

    Args:
        key: Index of the item to retrieve.

    Returns:
        The dataset item at the specified index.
    """

__len__

__len__()

Size of the dataset.

Source code in flyvision/datasets/datasets.py
46
47
48
def __len__(self) -> int:
    """Size of the dataset."""
    return len(self.arg_df)

__getitem__

__getitem__(key)

Implements advanced indexing.

Parameters:

Name Type Description Default
key Union[slice, Iterable, int, int_]

Index, slice, or iterable of indices.

required

Returns:

Type Description
Any

The dataset item(s) at the specified index/indices.

Raises:

Type Description
IndexError

If the index is out of range.

TypeError

If the key type is invalid.

Source code in flyvision/datasets/datasets.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
def __getitem__(self, key: Union[slice, Iterable, int, np.int_]) -> Any:
    """Index the dataset with ints, slices, or iterables of indices.

    Args:
        key: Index, slice, or iterable of indices.

    Returns:
        The dataset item(s) at the specified index/indices.

    Raises:
        IndexError: If the index is out of range.
        TypeError: If the key type is invalid.
    """
    # The module-level `getitem` helper implements the advanced-indexing logic.
    return getitem(self, key)

get_temporal_sample_indices

get_temporal_sample_indices(
    n_frames, total_seq_length, augment=None
)

Returns temporal indices to sample from a sequence.

Parameters:

Name Type Description Default
n_frames int

Number of sequence frames to sample from.

required
total_seq_length int

Total sequence length.

required
augment bool

If True, picks the start frame at random. If False, starts at 0.

None

Returns:

Type Description
Tensor

Tensor of temporal indices.

Note

Interpolates between start_index and start_index + n_frames and rounds the resulting float values to integer to create indices. This can lead to irregularities in terms of how many times each raw data frame is sampled.

Source code in flyvision/datasets/datasets.py
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
def get_temporal_sample_indices(
    self, n_frames: int, total_seq_length: int, augment: bool = None
) -> torch.Tensor:
    """Compute temporal indices for sampling from a sequence.

    Args:
        n_frames: Number of sequence frames to sample from.
        total_seq_length: Total sequence length.
        augment: If True, picks the start frame at random. If False, starts at 0.

    Returns:
        Tensor of temporal indices.

    Note:
        Interpolates between start_index and start_index + n_frames and rounds
        the resulting floats to integers, so individual raw frames may be
        sampled an uneven number of times.
    """
    if augment is None:
        augment = getattr(self, "augment", False)
    # Fall back to the rate implied by dt when no original framerate is set.
    framerate = getattr(self, "original_framerate", 1 / self.dt)
    # Resolves to the module-level function of the same name.
    return get_temporal_sample_indices(
        n_frames, total_seq_length, framerate, self.dt, augment
    )

flyvision.datasets.datasets.StimulusDataset

Bases: SequenceDataset

Base class for stimulus datasets.

Source code in flyvision/datasets/datasets.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
class StimulusDataset(SequenceDataset):
    """Base class for stimulus datasets."""

    def get_stimulus_index(self, kwargs: Dict[str, Any]) -> int:
        """Look up the sequence id matching a set of arguments.

        Args:
            kwargs: Dictionary containing independent arguments or parameters
                describing the sample of the dataset.

        Returns:
            The sequence id for the given arguments.

        Raises:
            ValueError: If arg_df attribute is not specified.

        Note:
            The child dataset implements the specific method:
            ```python
            def get_stimulus_index(self, arg1, arg2, ...):
                return StimulusDataset.get_stimulus_index(locals())
            ```
            with locals() specifying kwargs in terms of `arg1`, `arg2`, ...
            to index arg_df.
        """
        if getattr(self, "arg_df", None) is None:
            raise ValueError("arg_df attribute not specified.")

        # Drop the bound instance that `locals()` carries along.
        kwargs.pop("self", None)

        return where_dataframe(self.arg_df, **kwargs).item()

get_stimulus_index

get_stimulus_index(kwargs)

Get the sequence id for a set of arguments.

Parameters:

Name Type Description Default
kwargs Dict[str, Any]

Dictionary containing independent arguments or parameters describing the sample of the dataset.

required

Returns:

Type Description
int

The sequence id for the given arguments.

Raises:

Type Description
ValueError

If arg_df attribute is not specified.

Note

The child dataset implements the specific method:

def get_stimulus_index(self, arg1, arg2, ...):
    return StimulusDataset.get_stimulus_index(locals())
with locals() specifying kwargs in terms of arg1, arg2, … to index arg_df.

Source code in flyvision/datasets/datasets.py
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
def get_stimulus_index(self, kwargs: Dict[str, Any]) -> int:
    """Look up the sequence id matching a set of arguments.

    Args:
        kwargs: Dictionary containing independent arguments or parameters
            describing the sample of the dataset.

    Returns:
        The sequence id for the given arguments.

    Raises:
        ValueError: If arg_df attribute is not specified.

    Note:
        The child dataset implements the specific method:
        ```python
        def get_stimulus_index(self, arg1, arg2, ...):
            return StimulusDataset.get_stimulus_index(locals())
        ```
        with locals() specifying kwargs in terms of `arg1`, `arg2`, ...
        to index arg_df.
    """
    if getattr(self, "arg_df", None) is None:
        raise ValueError("arg_df attribute not specified.")

    # Drop the bound instance that `locals()` carries along.
    kwargs.pop("self", None)

    return where_dataframe(self.arg_df, **kwargs).item()

flyvision.datasets.datasets.MultiTaskDataset

Bases: SequenceDataset

Base class for all (multi-)task sequence datasets.

All (multi-)task sequence datasets can subclass this class. They are expected to implement the following additional attributes and methods.

Attributes:

Name Type Description
tasks List[str]

A list of all tasks.

augment bool

Turns augmentation on and off.

Source code in flyvision/datasets/datasets.py
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
class MultiTaskDataset(SequenceDataset):
    """Base class for all (multi-)task sequence datasets.

    Concrete (multi-)task sequence datasets subclass this and provide the
    additional attributes and methods below.

    Attributes:
        tasks (List[str]): A list of all tasks.
        augment (bool): Turns augmentation on and off.
    """

    tasks: List[str] = []
    augment: bool = False

    @contextmanager
    def augmentation(self, abool: bool) -> None:
        """Contextmanager to turn augmentation on or off in a code block.

        Args:
            abool: Boolean value to set augmentation.

        Example:
            ```python
            with dataset.augmentation(True):
                for i, data in enumerate(dataloader):
                    ...  # all data is augmented
            ```
        """
        previous = self.augment
        self.augment = abool
        try:
            yield
        finally:
            # Restore the prior setting even if the block raised.
            self.augment = previous

    def get_random_data_split(
        self, fold: int, n_folds: int, shuffle: bool = True, seed: int = 0
    ) -> np.ndarray:
        """Returns a random data split.

        Args:
            fold: Current fold number.
            n_folds: Total number of folds.
            shuffle: Whether to shuffle the data.
            seed: Random seed for reproducibility.

        Returns:
            Array of indices for the data split.
        """
        # Resolves to the module-level function of the same name.
        return get_random_data_split(
            fold,
            n_samples=len(self),
            n_folds=n_folds,
            shuffle=shuffle,
            seed=seed,
        )

augmentation

augmentation(abool)

Contextmanager to turn augmentation on or off in a code block.

Parameters:

Name Type Description Default
abool bool

Boolean value to set augmentation.

required
Example
with dataset.augmentation(True):
    for i, data in enumerate(dataloader):
        ...  # all data is augmented
Source code in flyvision/datasets/datasets.py
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
@contextmanager
def augmentation(self, abool: bool) -> None:
    """Contextmanager to turn augmentation on or off in a code block.

    Args:
        abool: Boolean value to set augmentation.

    Example:
        ```python
        with dataset.augmentation(True):
            for i, data in enumerate(dataloader):
                ...  # all data is augmented
        ```
    """
    previous = self.augment
    self.augment = abool
    try:
        yield
    finally:
        # Restore the prior setting even if the block raised.
        self.augment = previous

get_random_data_split

get_random_data_split(fold, n_folds, shuffle=True, seed=0)

Returns a random data split.

Parameters:

Name Type Description Default
fold int

Current fold number.

required
n_folds int

Total number of folds.

required
shuffle bool

Whether to shuffle the data.

True
seed int

Random seed for reproducibility.

0

Returns:

Type Description
ndarray

Array of indices for the data split.

Source code in flyvision/datasets/datasets.py
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
def get_random_data_split(
    self, fold: int, n_folds: int, shuffle: bool = True, seed: int = 0
) -> np.ndarray:
    """Returns a random data split.

    Args:
        fold: Current fold number.
        n_folds: Total number of folds.
        shuffle: Whether to shuffle the data.
        seed: Random seed for reproducibility.

    Returns:
        Array of indices for the data split.
    """
    # Resolves to the module-level function of the same name.
    return get_random_data_split(
        fold,
        n_samples=len(self),
        n_folds=n_folds,
        shuffle=shuffle,
        seed=seed,
    )

Flashes

flyvision.datasets.flashes.Flashes

Bases: SequenceDataset

Flashes dataset.

Parameters:

Name Type Description Default
boxfilter Dict[str, int]

Parameters for the BoxEye filter.

dict(extent=15, kernel_size=13)
dynamic_range List[float]

Range of intensities. E.g. [0, 1] renders flashes with decrement 0.5->0 and increment 0.5->1.

[0, 1]
t_stim float

Duration of the stimulus.

1.0
t_pre float

Duration of the grey stimulus.

1.0
dt float

Timesteps.

1 / 200
radius List[int]

Radius of the stimulus.

[-1, 6]
alternations Tuple[int, ...]

Sequence of alternations between lower or upper intensity and baseline of the dynamic range.

(0, 1, 0)

Attributes:

Name Type Description
dt Union[float, None]

Timestep.

t_post float

Post-stimulus time.

flashes_dir

Directory containing rendered flashes.

config

Configuration object.

baseline

Baseline intensity.

arg_df

DataFrame containing flash parameters.

Note

Zero alternation is the prestimulus and baseline. One alternation is the central stimulus. Has to start with zero alternation. t_pre is the duration of the prestimulus and t_stim is the duration of the stimulus.

Source code in flyvision/datasets/flashes.py
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
class Flashes(SequenceDataset):
    """Flashes dataset.

    Args:
        boxfilter: Parameters for the BoxEye filter. Defaults to
            `dict(extent=15, kernel_size=13)` when None.
        dynamic_range: Range of intensities. E.g. [0, 1] renders flashes
            with decrement 0.5->0 and increment 0.5->1. Defaults to [0, 1]
            when None.
        t_stim: Duration of the stimulus.
        t_pre: Duration of the grey stimulus.
        dt: Timesteps.
        radius: Radius of the stimulus. Defaults to [-1, 6] when None.
        alternations: Sequence of alternations between lower or upper intensity and
            baseline of the dynamic range.

    Attributes:
        dt: Timestep.
        t_post: Post-stimulus time.
        flashes_dir: Directory containing rendered flashes.
        config: Configuration object.
        baseline: Baseline intensity.
        arg_df: DataFrame containing flash parameters.

    Note:
        Zero alternation is the prestimulus and baseline. One alternation is the
        central stimulus. Has to start with zero alternation. `t_pre` is the
        duration of the prestimulus and `t_stim` is the duration of the stimulus.
    """

    dt: Union[float, None] = None
    t_post: float = 0.0

    def __init__(
        self,
        boxfilter: Dict[str, int] = None,
        dynamic_range: List[float] = None,
        t_stim: float = 1.0,
        t_pre: float = 1.0,
        dt: float = 1 / 200,
        radius: List[int] = None,
        alternations: Tuple[int, ...] = (0, 1, 0),
    ):
        # Resolve None sentinels to fresh per-call defaults. The previous
        # signature used mutable default arguments (dict/list literals), which
        # are shared across calls in Python — a well-known pitfall.
        if boxfilter is None:
            boxfilter = dict(extent=15, kernel_size=13)
        if dynamic_range is None:
            dynamic_range = [0, 1]
        if radius is None:
            radius = [-1, 6]
        assert alternations[0] == 0, "First alternation must be 0."
        self.flashes_dir = RenderedFlashes(
            boxfilter=boxfilter,
            dynamic_range=dynamic_range,
            t_stim=t_stim,
            t_pre=t_pre,
            dt=dt,
            radius=radius,
            alternations=alternations,
        )
        self.config = self.flashes_dir.config
        # Midpoint of the dynamic range, repeated once per intensity bound.
        baseline = 2 * (sum(dynamic_range) / 2,)
        intensity = dynamic_range.copy()

        # One parameter row per (baseline, intensity) pair crossed with radius.
        params = [
            (p[0][0], p[0][1], p[1])
            for p in list(product(zip(baseline, intensity), radius))
        ]
        self.baseline = baseline[0]
        self.arg_df = pd.DataFrame(params, columns=["baseline", "intensity", "radius"])

        self.dt = dt

    @property
    def t_pre(self) -> float:
        """Duration of the prestimulus and zero alternation."""
        return self.config.t_pre

    @property
    def t_stim(self) -> float:
        """Duration of the one alternation."""
        return self.config.t_stim

    def get_item(self, key: int) -> torch.Tensor:
        """Index the dataset.

        Args:
            key: Index of the item to retrieve.

        Returns:
            Flash sequence at the given index.
        """
        return torch.Tensor(self.flashes_dir.flashes[key])

    def __repr__(self) -> str:
        """Return a string representation of the dataset."""
        return f"Flashes dataset. Parametrization: \n{self.arg_df}"

t_pre property

t_pre

Duration of the prestimulus and zero alternation.

t_stim property

t_stim

Duration of the one alternation.

get_item

get_item(key)

Index the dataset.

Parameters:

Name Type Description Default
key int

Index of the item to retrieve.

required

Returns:

Type Description
Tensor

Flash sequence at the given index.

Source code in flyvision/datasets/flashes.py
205
206
207
208
209
210
211
212
213
214
def get_item(self, key: int) -> torch.Tensor:
    """Return the rendered flash sequence stored at index ``key``.

    Args:
        key: Index of the item to retrieve.

    Returns:
        Flash sequence at the given index.
    """
    sequence = self.flashes_dir.flashes[key]
    return torch.Tensor(sequence)

__repr__

__repr__()

Return a string representation of the dataset.

Source code in flyvision/datasets/flashes.py
216
217
218
def __repr__(self) -> str:
    """Return a string representation of the dataset."""
    return f"Flashes dataset. Parametrization: \n{self.arg_df}"

Moving Bar

flyvision.datasets.moving_bar.MovingBar

Bases: StimulusDataset

Moving bar stimulus.

Parameters:

Name Type Description Default
widths list[int]

Width of the bar in half ommatidia.

[1, 2, 4]
offsets tuple[int, int]

First and last offset to the central column in half ommatidia.

(-10, 11)
intensities list[float]

Intensity of the bar.

[0, 1]
speeds list[float]

Speed of the bar in half ommatidia per second.

[2.4, 4.8, 9.7, 13, 19, 25]
height int

Height of the bar in half ommatidia.

9
dt float

Time step in seconds.

1 / 200
device str

Device to store the stimulus.

device
bar_loc_horizontal float

Horizontal location of the bar in radians from left to right of image plane. np.radians(90) is the center.

radians(90)
post_pad_mode Literal['continue', 'value', 'reflect']

Padding mode after the stimulus. One of ‘continue’, ‘value’, ‘reflect’. If ‘value’ the padding is filled with bg_intensity.

'value'
t_pre float

Time before the stimulus in seconds.

1.0
t_post float

Time after the stimulus in seconds.

1.0
build_stim_on_init bool

Build the stimulus on initialization.

True
shuffle_offsets bool

Shuffle the offsets to remove spatio-temporal correlation.

False
seed int

Seed for the random state.

0
angles list[int]

List of angles in degrees.

[0, 30, 60, 90, 120, 150, 180, 210, 240, 270, 300, 330]

Attributes:

Name Type Description
config Namespace

Configuration parameters.

omm_width float

Width of ommatidium in radians.

led_width float

Width of LED in radians.

angles ndarray

Array of angles in degrees.

widths ndarray

Array of widths in half ommatidia.

offsets ndarray

Array of offsets in half ommatidia.

intensities ndarray

Array of intensities.

speeds ndarray

Array of speeds in half ommatidia per second.

bg_intensity float

Background intensity.

n_bars int

Number of bars.

bar_loc_horizontal float

Horizontal location of bar in radians.

t_stim ndarray

Stimulation times for each speed.

t_stim_max float

Maximum stimulation time.

height float

Height of bar in radians.

post_pad_mode str

Padding mode after the stimulus.

arg_df DataFrame

DataFrame of stimulus parameters.

arg_group_df DataFrame

Grouped DataFrame of stimulus parameters.

device str

Device for storing stimuli.

shuffle_offsets bool

Whether to shuffle offsets.

randomstate RandomState

Random state for shuffling.

Source code in flyvision/datasets/moving_bar.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
class MovingBar(StimulusDataset):
    """Moving bar stimulus.

    Args:
        widths: Width of the bar in half ommatidia.
        offsets: First and last offset to the central column in half ommatidia.
        intensities: Intensity of the bar.
        speeds: Speed of the bar in half ommatidia per second.
        height: Height of the bar in half ommatidia.
        dt: Time step in seconds.
        device: Device to store the stimulus.
        bar_loc_horizontal: Horizontal location of the bar in radians from left to
            right of image plane. np.radians(90) is the center.
        post_pad_mode: Padding mode after the stimulus. One of 'continue', 'value',
            'reflect'. If 'value' the padding is filled with `bg_intensity`.
        t_pre: Time before the stimulus in seconds.
        t_post: Time after the stimulus in seconds.
        build_stim_on_init: Build the stimulus on initialization.
        shuffle_offsets: Shuffle the offsets to remove spatio-temporal correlation.
        seed: Seed for the random state.
        angles: List of angles in degrees.

    Attributes:
        config (Namespace): Configuration parameters.
        omm_width (float): Width of ommatidium in radians.
        led_width (float): Width of LED in radians.
        angles (np.ndarray): Array of angles in degrees.
        widths (np.ndarray): Array of widths in half ommatidia.
        offsets (np.ndarray): Array of offsets in half ommatidia.
        intensities (np.ndarray): Array of intensities.
        speeds (np.ndarray): Array of speeds in half ommatidia per second.
        bg_intensity (float): Background intensity.
        n_bars (int): Number of bars.
        bar_loc_horizontal (float): Horizontal location of bar in radians.
        t_stim (np.ndarray): Stimulation times for each speed.
        t_stim_max (float): Maximum stimulation time.
        height (float): Height of bar in radians.
        post_pad_mode (str): Padding mode after the stimulus.
        arg_df (pd.DataFrame): DataFrame of stimulus parameters.
        arg_group_df (pd.DataFrame): Grouped DataFrame of stimulus parameters.
        device (str): Device for storing stimuli.
        shuffle_offsets (bool): Whether to shuffle offsets.
        randomstate (np.random.RandomState): Random state for shuffling.
    """

    # One row per stimulus condition; populated in __init__.
    arg_df: pd.DataFrame = None

    def __init__(
        self,
        widths: list[int] = [1, 2, 4],
        offsets: tuple[int, int] = (-10, 11),
        intensities: list[float] = [0, 1],
        speeds: list[float] = [2.4, 4.8, 9.7, 13, 19, 25],
        height: int = 9,
        dt: float = 1 / 200,
        device: str = flyvision.device,
        bar_loc_horizontal: float = np.radians(90),
        post_pad_mode: Literal["continue", "value", "reflect"] = "value",
        t_pre: float = 1.0,
        t_post: float = 1.0,
        build_stim_on_init: bool = True,
        shuffle_offsets: bool = False,
        seed: int = 0,
        angles: list[int] = [0, 30, 60, 90, 120, 150, 180, 210, 240, 270, 300, 330],
    ) -> None:
        # NOTE(review): the list defaults above are mutable but are only read and
        # converted to arrays below, so the shared-default pitfall does not bite.
        super().__init__()
        # HexEye parameter
        self.omm_width = np.radians(5.8)  # radians(5.8 degree)

        # Monitor parameter
        self.led_width = np.radians(2.25)  # Gruntman et al. 2018

        # Snapshot the constructor arguments into a config namespace so the
        # dataset parameters can be stored/reproduced alongside results.
        _locals = locals()
        self.config = Namespace({
            arg: _locals[arg]
            for arg in [
                "widths",
                "offsets",
                "intensities",
                "speeds",
                "height",
                "bar_loc_horizontal",
                "shuffle_offsets",
                "post_pad_mode",
                "t_pre",
                "t_post",
                "dt",
                "angles",
            ]
        })

        # Stim parameter
        self.angles = np.array(angles)
        self.widths = np.array(widths)  # half ommatidia
        # A 2-tuple is interpreted as a (start, stop) range spec with step 1;
        # any longer sequence is taken as explicit offsets but must be evenly
        # spaced by 1 because the t_stim formula below assumes that spacing.
        if len(offsets) == 2:
            self.offsets = np.arange(*offsets)  # half ommatidia
        else:
            assert (
                np.mean(offsets[1:] - offsets[:-1]) == 1
            )  # t_stim assumes spacing of 1 corresponding to 2.25 deg
            self.offsets = offsets
        self.intensities = np.array(intensities)
        self.speeds = np.array(speeds)
        self.bg_intensity = 0.5
        self.n_bars = 1
        self.bar_loc_horizontal = bar_loc_horizontal

        # Stimulation time per speed: traversed distance (n_offsets LED widths,
        # in radians) divided by angular velocity (speed * omm_width).
        # NOTE(review): this treats `speeds` as ommatidia/s while the class
        # docstring says half ommatidia/s -- confirm the intended unit.
        self.t_stim = (len(self.offsets) * self.led_width) / (
            self.speeds * self.omm_width
        )
        self.t_stim_max = np.max(self.t_stim)

        self._speed_to_t_stim = dict(zip(self.speeds, self.t_stim))

        self.height = self.led_width * height

        self.post_pad_mode = post_pad_mode
        self._t_pre = t_pre
        self._t_post = t_post

        # Cartesian product of all conditions. (t_stim, speed) are zipped so
        # each speed stays paired with its stimulation time, then the pair is
        # unpacked into flat rows matching the DataFrame columns below.
        params = [
            (*p[:-1], *p[-1])
            for p in list(
                product(
                    self.angles,
                    self.widths,
                    self.intensities,
                    zip(self.t_stim, self.speeds),
                )
            )
        ]
        self.arg_df = pd.DataFrame(
            params, columns=["angle", "width", "intensity", "t_stim", "speed"]
        )

        # One row per (angle, width, intensity) group: the rendered sequences
        # are shared across speeds and looked up via _group_key in get_item.
        self.arg_group_df = self.arg_df.groupby(
            ["angle", "width", "intensity"], sort=False, as_index=False
        ).all()

        self.device = device
        self.shuffle_offsets = shuffle_offsets
        self.randomstate = None
        if self.shuffle_offsets:
            self.randomstate = np.random.RandomState(seed=seed)

        self._dt = dt

        self._built = False
        if build_stim_on_init:
            self._build()
            self._resample()
            self._built = True

    @property
    def dt(self) -> float:
        """Time step in seconds."""
        # getattr with default so access before __init__ finishes returns None.
        return getattr(self, "_dt", None)

    @dt.setter
    def dt(self, value: float) -> None:
        # dt is effectively immutable once set: re-assigning the *same* value
        # is accepted (and triggers a resample if already built), while any
        # different value is refused with a warning to keep previously
        # computed responses consistent with the stimulus sampling.
        if self._dt == value:
            self._dt = value
            if self._built:
                self._resample()
            return
        logging.warning(
            "Cannot override dt=%s because responses with dt=%s are initialized. "
            "Keeping dt=%s.",
            value,
            self._dt,
            self._dt,
        )

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}\n"
            + "Config:\n"
            + repr(self.config)
            + "Stimulus parameter:\n"
            + repr(self.arg_df)
        )

    @property
    def t_pre(self) -> float:
        """Time before stimulus onset in seconds."""
        return self._t_pre

    @property
    def t_post(self) -> float:
        """Time after stimulus offset in seconds."""
        return self._t_post

    def _build(self) -> None:
        """Build the stimulus."""
        # Render (or load cached) offset frames for every
        # (angle, width, intensity) combination at framerate resolution.
        self.wrap = RenderedOffsets(
            dict(
                angles=self.angles,
                widths=self.widths,
                intensities=self.intensities,
                offsets=self.offsets,
                led_width=self.led_width,
                height=self.height,
                n_bars=self.n_bars,
                bg_intensity=self.bg_intensity,
                bar_loc_horizontal=self.bar_loc_horizontal,
            )
        )

        self._offsets = torch.tensor(self.wrap.offsets[:], device=self.device)
        self._built = True

    def _resample(self) -> None:
        """Resample the stimulus at runtime."""
        # resampling at runtime to allow for changing dt and to save GB of
        # storage.
        self.sequences = {}
        self.indices = {}
        for t, speed in zip(self.t_stim, self.speeds):
            sequence, indices = resample(
                self._offsets,
                t,
                self.dt,
                dim=1,
                device=self.device,
                return_indices=True,
            )
            if self.shuffle_offsets:
                # breakpoint()
                sequence = shuffle(sequence, self.randomstate)
            # Prepend t_pre seconds of background before stimulus onset.
            sequence = pad(
                sequence,
                t + self.t_pre,
                self.dt,
                mode="start",
                fill=self.bg_intensity,
            )
            # Append t_post seconds after stimulus offset using the configured
            # padding mode ('continue', 'value', or 'reflect').
            sequence = pad(
                sequence,
                t + self.t_pre + self.t_post,
                self.dt,
                mode="end",
                pad_mode=self.post_pad_mode,
                fill=self.bg_intensity,
            )
            # Because we fix the distance that the bar moves but vary speeds we
            # have different stimulation times. To make all sequences equal
            # length for storing them in a single tensor, we pad them with nans
            # based on the maximal stimulation time (slowest speed). The nans
            # can later be removed before processing the traces.
            sequence = pad(
                sequence,
                self.t_stim_max + self.t_pre + self.t_post,
                self.dt,
                mode="end",
                fill=np.nan,
            )
            self.sequences[speed] = sequence
            self.indices[speed] = indices

    def _key(self, angle: float, width: float, intensity: float, speed: float) -> int:
        """Get the key for a specific stimulus configuration."""
        # .item() raises ValueError unless the query matches exactly one row,
        # which is re-raised with a readable message (cause suppressed).
        try:
            return self.arg_df.query(
                f"angle=={angle}"
                f" & width=={width}"
                f" & intensity == {intensity}"
                f" & speed == {speed}"
            ).index.values.item()
        except ValueError:
            raise ValueError(
                f"angle: {angle}, width: {width}, intensity: {intensity}, "
                f"speed: {speed} invalid."
            ) from None

    def get_sequence_id_from_arguments(
        self, angle: float, width: float, intensity: float, speed: float
    ) -> int:
        """Get sequence ID from stimulus arguments."""
        # locals() also carries `self`; presumably get_stimulus_index (defined
        # on StimulusDataset) ignores it -- confirm against the base class.
        return self.get_stimulus_index(locals())

    def _params(self, key: int) -> np.ndarray:
        """Get parameters for a given key."""
        # Row order is (angle, width, intensity, t_stim, speed).
        return self.arg_df.iloc[key].values

    def _group_key(self, angle: float, width: float, intensity: float) -> int:
        """Get group key for a specific stimulus configuration."""
        return self.arg_group_df.query(
            f"angle=={angle}" f" & width=={width}" f" & intensity == {intensity}"
        ).index.values.item()

    def _group_params(self, key: int) -> np.ndarray:
        """Get group parameters for a given key."""
        return self.arg_group_df.iloc[key].values

    def get(
        self, angle: float, width: float, intensity: float, speed: float
    ) -> torch.Tensor:
        """Get stimulus for specific parameters."""
        key = self._key(angle, width, intensity, speed)
        return self[key]

    def get_item(self, key: int) -> torch.Tensor:
        """Get stimulus for a specific key."""
        # Sequences are stored per speed and indexed by the
        # (angle, width, intensity) group key.
        angle, width, intensity, _, speed = self._params(key)
        return self.sequences[speed][self._group_key(angle, width, intensity)]

    def mask(
        self,
        angle: Optional[float] = None,
        width: Optional[float] = None,
        intensity: Optional[float] = None,
        speed: Optional[float] = None,
        t_stim: Optional[float] = None,
    ) -> np.ndarray:
        """Create a mask for specific stimulus parameters."""
        # 22x faster than df.query
        # Column axes follow arg_df order:
        # 0=angle, 1=width, 2=intensity, 3=t_stim, 4=speed.
        values = self.arg_df.values

        def iterparam(param, name, axis, and_condition):
            # OR together the rows matching each requested value for one
            # parameter, then AND with the mask accumulated so far.
            condition = np.zeros(len(values)).astype(bool)
            if isinstance(param, Iterable):
                for p in param:
                    _new = values.take(axis, axis=1) == p
                    assert any(_new), f"{name} {p} not in dataset."
                    condition = np.logical_or(condition, _new)
            else:
                _new = values.take(axis, axis=1) == param
                assert any(_new), f"{name} {param} not in dataset."
                condition = np.logical_or(condition, _new)
            return condition & and_condition

        condition = np.ones(len(values)).astype(bool)
        if angle is not None:
            condition = iterparam(angle, "angle", 0, condition)
        if width is not None:
            condition = iterparam(width, "width", 1, condition)
        if intensity is not None:
            condition = iterparam(intensity, "intensity", 2, condition)
        if t_stim is not None:
            condition = iterparam(t_stim, "t_stim", 3, condition)
        if speed is not None:
            condition = iterparam(speed, "speed", 4, condition)
        return condition

    @property
    def time(self) -> np.ndarray:
        """Time array for the stimulus."""
        return np.arange(-self.t_pre, self.t_stim_max + self.t_post - self.dt, self.dt)

    def stimulus(
        self,
        angle: Optional[float] = None,
        width: Optional[float] = None,
        intensity: Optional[float] = None,
        speed: Optional[float] = None,
        pre_stim: bool = True,
        post_stim: bool = True,
    ) -> np.ndarray:
        """Get stimulus for specific parameters.

        Args:
            angle: Angle of the bar.
            width: Width of the bar.
            intensity: Intensity of the bar.
            speed: Speed of the bar.
            pre_stim: Include pre-stimulus period.
            post_stim: Include post-stimulus period.

        Returns:
            Stimulus array.
        """
        key = self._key(angle, width, intensity, speed)
        # Column 360 appears to select the central ommatidium -- presumably a
        # 721-column hex lattice with the center at index 360; TODO confirm.
        stim = self[key][:, 360].cpu().numpy()
        if not post_stim:
            stim = filter_post([stim], self.t_post, self.dt).squeeze()
        if not pre_stim:
            stim = filter_pre(stim[None], self.t_pre, self.dt).squeeze()
        return stim

    def stimulus_parameters(
        self,
        angle: Optional[float] = None,
        width: Optional[float] = None,
        intensity: Optional[float] = None,
        speed: Optional[float] = None,
    ) -> tuple[list, ...]:
        """Get stimulus parameters."""

        def _number_to_list(*args):
            # Wrap scalar arguments into single-element lists; pass sequences
            # (and None) through unchanged.
            returns = tuple()
            for arg in args:
                if isinstance(arg, Number):
                    returns += ([arg],)
                else:
                    returns += (arg,)
            return returns

        # `or` falls back to the full parameter range when the argument is
        # None (scalars were wrapped above, so e.g. angle=0 stays [0]).
        angle, width, speed, intensity = _number_to_list(angle, width, speed, intensity)
        angle = angle or self.angles
        width = width or self.widths
        intensity = intensity or self.intensities
        speed = speed or self.speeds
        return angle, width, intensity, speed

    def sample_shape(
        self,
        angle: Optional[float] = None,
        width: Optional[float] = None,
        intensity: Optional[float] = None,
        speed: Optional[float] = None,
    ) -> tuple[int, ...]:
        """Get shape of stimulus sample for given parameters."""
        if isinstance(angle, Number):
            angle = [angle]
        if isinstance(width, Number):
            width = [width]
        if isinstance(speed, Number):
            speed = [speed]
        if isinstance(intensity, Number):
            intensity = [intensity]
        # None selects the full range of each parameter.
        angle = angle or self.angles
        width = width or self.widths
        intensity = intensity or self.intensities
        speed = speed or self.speeds
        return (
            len(angle),
            len(width),
            len(intensity),
            len(speed),
        )

    def time_to_center(self, speed: float) -> float:
        """Calculate time for bar to reach center at given speed."""
        # Note: time = distance / velocity, i.e.
        #     time = (n_leds * led_width) / (speed * omm_width)
        #     with speed in ommatidia / s.
        return np.abs(self.config.offsets[0]) * self.led_width / (speed * self.omm_width)

    def get_time_with_origin_at_onset(self) -> np.ndarray:
        """Get time array with origin at stimulus onset."""
        return np.linspace(
            -self.t_pre,
            self.t_stim_max - self.t_pre + self.t_post,
            int(self.t_stim_max / self.dt)
            + int(self.t_post / self.dt)
            + int(self.t_pre / self.dt),
        )

    def get_time_with_origin_at_center(self, speed: float) -> np.ndarray:
        """Get time array with origin where bar reaches central column."""
        # Shift the onset-based time axis so t=0 is when the bar crosses the
        # central column for this speed.
        time_to_center = self.time_to_center(speed)
        n_steps = (
            int(self.t_stim_max / self.dt)
            + int(self.t_post / self.dt)
            + int(self.t_pre / self.dt)
        )
        return np.linspace(
            -(self.t_pre + time_to_center),
            n_steps * self.dt - (self.t_pre + time_to_center),
            n_steps,
        )

    def stimulus_cartoon(
        self,
        angle: float,
        width: float,
        intensity: float,
        speed: float,
        time_after_stimulus_onset: float = 0.5,
        fig: Optional[plt.Figure] = None,
        ax: Optional[plt.Axes] = None,
        facecolor: str = "#000000",
        cmap: Colormap = plt.cm.bone,
        alpha: float = 0.5,
        vmin: float = 0,
        vmax: float = 1,
        edgecolor: str = "none",
        central_hex_color: str = "#2f7cb9",
        **kwargs,
    ) -> tuple[plt.Figure, plt.Axes]:
        """Create a cartoon representation of the stimulus."""
        fig, ax = init_plot(fig=fig, ax=ax)

        # Pick the frame closest to `time_after_stimulus_onset` (onset at t=0).
        time = (
            np.arange(
                0,
                self.t_pre + self.t_stim_max + self.t_post - self.dt,
                self.dt,
            )
            - self.t_pre
        )
        index = np.argmin(np.abs(time - time_after_stimulus_onset))

        fig, ax, _ = quick_hex_scatter(
            self.get(angle=angle, width=width, speed=speed, intensity=intensity)
            .cpu()
            .numpy()[index],
            vmin=vmin,
            vmax=vmax,
            cbar=False,
            figsize=[1, 1],
            max_extent=5,
            fig=fig,
            ax=ax,
            cmap=cmap,
            alpha=alpha,
            edgecolor=edgecolor,
            **kwargs,
        )
        # 2D rotation by (angle - 90) degrees; applied to the arrow below to
        # indicate the bar's direction of motion.
        rotation = np.array([
            [
                np.cos(np.radians(angle - 90)),
                -np.sin(np.radians(angle - 90)),
            ],
            [
                np.sin(np.radians(angle - 90)),
                np.cos(np.radians(angle - 90)),
            ],
        ])
        x = rotation @ np.array([0, -5])
        dx = rotation @ np.array([0, 1])
        ax.arrow(
            *x,
            *dx,
            facecolor=facecolor,
            width=0.75,
            head_length=2.5,
            edgecolor="k",
            linewidth=0.25,
        )
        # Highlight the central column with a colored hexagon.
        _hex = RegularPolygon(
            (0, 0),
            numVertices=6,
            radius=1,
            linewidth=1,
            orientation=np.radians(30),
            edgecolor=central_hex_color,
            facecolor=central_hex_color,
            alpha=1,
            ls="-",
        )
        ax.add_patch(_hex)

        return fig, ax

dt property writable

dt

Time step in seconds.

t_pre property

t_pre

Time before stimulus onset in seconds.

t_post property

t_post

Time after stimulus offset in seconds.

time property

time

Time array for the stimulus.

get_sequence_id_from_arguments

get_sequence_id_from_arguments(
    angle, width, intensity, speed
)

Get sequence ID from stimulus arguments.

Source code in flyvision/datasets/moving_bar.py
365
366
367
368
369
def get_sequence_id_from_arguments(
    self, angle: float, width: float, intensity: float, speed: float
) -> int:
    """Get sequence ID from stimulus arguments."""
    return self.get_stimulus_index(locals())

get

get(angle, width, intensity, speed)

Get stimulus for specific parameters.

Source code in flyvision/datasets/moving_bar.py
385
386
387
388
389
390
def get(
    self, angle: float, width: float, intensity: float, speed: float
) -> torch.Tensor:
    """Get stimulus for specific parameters."""
    key = self._key(angle, width, intensity, speed)
    return self[key]

get_item

get_item(key)

Get stimulus for a specific key.

Source code in flyvision/datasets/moving_bar.py
392
393
394
395
def get_item(self, key: int) -> torch.Tensor:
    """Get stimulus for a specific key."""
    angle, width, intensity, _, speed = self._params(key)
    return self.sequences[speed][self._group_key(angle, width, intensity)]

mask

mask(
    angle=None,
    width=None,
    intensity=None,
    speed=None,
    t_stim=None,
)

Create a mask for specific stimulus parameters.

Source code in flyvision/datasets/moving_bar.py
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
def mask(
    self,
    angle: Optional[float] = None,
    width: Optional[float] = None,
    intensity: Optional[float] = None,
    speed: Optional[float] = None,
    t_stim: Optional[float] = None,
) -> np.ndarray:
    """Create a mask for specific stimulus parameters."""
    # 22x faster than df.query
    values = self.arg_df.values

    def iterparam(param, name, axis, and_condition):
        condition = np.zeros(len(values)).astype(bool)
        if isinstance(param, Iterable):
            for p in param:
                _new = values.take(axis, axis=1) == p
                assert any(_new), f"{name} {p} not in dataset."
                condition = np.logical_or(condition, _new)
        else:
            _new = values.take(axis, axis=1) == param
            assert any(_new), f"{name} {param} not in dataset."
            condition = np.logical_or(condition, _new)
        return condition & and_condition

    condition = np.ones(len(values)).astype(bool)
    if angle is not None:
        condition = iterparam(angle, "angle", 0, condition)
    if width is not None:
        condition = iterparam(width, "width", 1, condition)
    if intensity is not None:
        condition = iterparam(intensity, "intensity", 2, condition)
    if t_stim is not None:
        condition = iterparam(t_stim, "t_stim", 3, condition)
    if speed is not None:
        condition = iterparam(speed, "speed", 4, condition)
    return condition

stimulus

stimulus(
    angle=None,
    width=None,
    intensity=None,
    speed=None,
    pre_stim=True,
    post_stim=True,
)

Get stimulus for specific parameters.

Parameters:

Name Type Description Default
angle Optional[float]

Angle of the bar.

None
width Optional[float]

Width of the bar.

None
intensity Optional[float]

Intensity of the bar.

None
speed Optional[float]

Speed of the bar.

None
pre_stim bool

Include pre-stimulus period.

True
post_stim bool

Include post-stimulus period.

True

Returns:

Type Description
ndarray

Stimulus array.

Source code in flyvision/datasets/moving_bar.py
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
def stimulus(
    self,
    angle: Optional[float] = None,
    width: Optional[float] = None,
    intensity: Optional[float] = None,
    speed: Optional[float] = None,
    pre_stim: bool = True,
    post_stim: bool = True,
) -> np.ndarray:
    """Get stimulus for specific parameters.

    Args:
        angle: Angle of the bar.
        width: Width of the bar.
        intensity: Intensity of the bar.
        speed: Speed of the bar.
        pre_stim: Include pre-stimulus period.
        post_stim: Include post-stimulus period.

    Returns:
        Stimulus array.
    """
    key = self._key(angle, width, intensity, speed)
    stim = self[key][:, 360].cpu().numpy()
    if not post_stim:
        stim = filter_post([stim], self.t_post, self.dt).squeeze()
    if not pre_stim:
        stim = filter_pre(stim[None], self.t_pre, self.dt).squeeze()
    return stim

stimulus_parameters

stimulus_parameters(
    angle=None, width=None, intensity=None, speed=None
)

Get stimulus parameters.

Source code in flyvision/datasets/moving_bar.py
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
def stimulus_parameters(
    self,
    angle: Optional[float] = None,
    width: Optional[float] = None,
    intensity: Optional[float] = None,
    speed: Optional[float] = None,
) -> tuple[list, ...]:
    """Get stimulus parameters."""

    def _number_to_list(*args):
        returns = tuple()
        for arg in args:
            if isinstance(arg, Number):
                returns += ([arg],)
            else:
                returns += (arg,)
        return returns

    angle, width, speed, intensity = _number_to_list(angle, width, speed, intensity)
    angle = angle or self.angles
    width = width or self.widths
    intensity = intensity or self.intensities
    speed = speed or self.speeds
    return angle, width, intensity, speed

sample_shape

sample_shape(
    angle=None, width=None, intensity=None, speed=None
)

Get shape of stimulus sample for given parameters.

Source code in flyvision/datasets/moving_bar.py
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
def sample_shape(
    self,
    angle: Optional[float] = None,
    width: Optional[float] = None,
    intensity: Optional[float] = None,
    speed: Optional[float] = None,
) -> tuple[int, ...]:
    """Get shape of stimulus sample for given parameters."""
    if isinstance(angle, Number):
        angle = [angle]
    if isinstance(width, Number):
        width = [width]
    if isinstance(speed, Number):
        speed = [speed]
    if isinstance(intensity, Number):
        intensity = [intensity]
    angle = angle or self.angles
    width = width or self.widths
    intensity = intensity or self.intensities
    speed = speed or self.speeds
    return (
        len(angle),
        len(width),
        len(intensity),
        len(speed),
    )

time_to_center

time_to_center(speed)

Calculate time for bar to reach center at given speed.

Source code in flyvision/datasets/moving_bar.py
522
523
524
525
526
527
def time_to_center(self, speed: float) -> float:
    """Calculate time for bar to reach center at given speed."""
    # Note: time = distance / velocity, i.e.
    #     time = (n_leds * led_width) / (speed * omm_width)
    #     with speed in ommatidia / s.
    return np.abs(self.config.offsets[0]) * self.led_width / (speed * self.omm_width)

get_time_with_origin_at_onset

get_time_with_origin_at_onset()

Get time array with origin at stimulus onset.

Source code in flyvision/datasets/moving_bar.py
529
530
531
532
533
534
535
536
537
def get_time_with_origin_at_onset(self) -> np.ndarray:
    """Get time array with origin at stimulus onset."""
    return np.linspace(
        -self.t_pre,
        self.t_stim_max - self.t_pre + self.t_post,
        int(self.t_stim_max / self.dt)
        + int(self.t_post / self.dt)
        + int(self.t_pre / self.dt),
    )

get_time_with_origin_at_center

get_time_with_origin_at_center(speed)

Get time array with origin where bar reaches central column.

Source code in flyvision/datasets/moving_bar.py
539
540
541
542
543
544
545
546
547
548
549
550
551
def get_time_with_origin_at_center(self, speed: float) -> np.ndarray:
    """Get time array with origin where bar reaches central column."""
    time_to_center = self.time_to_center(speed)
    n_steps = (
        int(self.t_stim_max / self.dt)
        + int(self.t_post / self.dt)
        + int(self.t_pre / self.dt)
    )
    return np.linspace(
        -(self.t_pre + time_to_center),
        n_steps * self.dt - (self.t_pre + time_to_center),
        n_steps,
    )

stimulus_cartoon

stimulus_cartoon(
    angle,
    width,
    intensity,
    speed,
    time_after_stimulus_onset=0.5,
    fig=None,
    ax=None,
    facecolor="#000000",
    cmap=plt.cm.bone,
    alpha=0.5,
    vmin=0,
    vmax=1,
    edgecolor="none",
    central_hex_color="#2f7cb9",
    **kwargs
)

Create a cartoon representation of the stimulus.

Source code in flyvision/datasets/moving_bar.py
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
def stimulus_cartoon(
    self,
    angle: float,
    width: float,
    intensity: float,
    speed: float,
    time_after_stimulus_onset: float = 0.5,
    fig: Optional[plt.Figure] = None,
    ax: Optional[plt.Axes] = None,
    facecolor: str = "#000000",
    cmap: Colormap = plt.cm.bone,
    alpha: float = 0.5,
    vmin: float = 0,
    vmax: float = 1,
    edgecolor: str = "none",
    central_hex_color: str = "#2f7cb9",
    **kwargs,
) -> tuple[plt.Figure, plt.Axes]:
    """Create a cartoon representation of the stimulus."""
    fig, ax = init_plot(fig=fig, ax=ax)

    time = (
        np.arange(
            0,
            self.t_pre + self.t_stim_max + self.t_post - self.dt,
            self.dt,
        )
        - self.t_pre
    )
    index = np.argmin(np.abs(time - time_after_stimulus_onset))

    fig, ax, _ = quick_hex_scatter(
        self.get(angle=angle, width=width, speed=speed, intensity=intensity)
        .cpu()
        .numpy()[index],
        vmin=vmin,
        vmax=vmax,
        cbar=False,
        figsize=[1, 1],
        max_extent=5,
        fig=fig,
        ax=ax,
        cmap=cmap,
        alpha=alpha,
        edgecolor=edgecolor,
        **kwargs,
    )
    rotation = np.array([
        [
            np.cos(np.radians(angle - 90)),
            -np.sin(np.radians(angle - 90)),
        ],
        [
            np.sin(np.radians(angle - 90)),
            np.cos(np.radians(angle - 90)),
        ],
    ])
    x = rotation @ np.array([0, -5])
    dx = rotation @ np.array([0, 1])
    ax.arrow(
        *x,
        *dx,
        facecolor=facecolor,
        width=0.75,
        head_length=2.5,
        edgecolor="k",
        linewidth=0.25,
    )
    _hex = RegularPolygon(
        (0, 0),
        numVertices=6,
        radius=1,
        linewidth=1,
        orientation=np.radians(30),
        edgecolor=central_hex_color,
        facecolor=central_hex_color,
        alpha=1,
        ls="-",
    )
    ax.add_patch(_hex)

    return fig, ax

Impulses

flyvision.datasets.dots.Dots

Bases: StimulusDataset

Render flashes aka dots per ommatidia.

Note

Renders directly in receptor space, does not use BoxEye or HexEye as eye-model.

Parameters:

Name Type Description Default
dot_column_radius int

Radius of the dot column.

0
max_extent int

Maximum extent of the stimulus.

15
bg_intensity float

Background intensity.

0.5
t_stim float

Stimulus duration.

5
dt float

Time step.

1 / 200
t_impulse Optional[float]

Impulse duration.

None
n_ommatidia int

Number of ommatidia.

721
t_pre float

Pre-stimulus duration.

2.0
t_post float

Post-stimulus duration.

0
intensity float

Stimulus intensity.

1
mode Literal['sustained', 'impulse']

Stimulus mode (‘sustained’ or ‘impulse’).

'sustained'
device device

Torch device for computations.

device

Attributes:

Name Type Description
dt Optional[float]

Time step.

arg_df Optional[DataFrame]

DataFrame containing stimulus parameters.

config

Namespace containing configuration parameters.

t_stim

Stimulus duration.

n_ommatidia

Number of ommatidia.

offsets

Array of ommatidia offsets.

u

U-coordinates of the hexagonal grid.

v

V-coordinates of the hexagonal grid.

extent_condition

Boolean mask for the extent condition.

max_extent

Maximum extent of the stimulus.

bg_intensity

Background intensity.

intensities

List of stimulus intensities.

device

Torch device for computations.

mode

Stimulus mode (‘sustained’ or ‘impulse’).

params

List of stimulus parameters.

t_impulse

Impulse duration.

Raises:

Type Description
ValueError

If dot_column_radius is greater than max_extent.

Source code in flyvision/datasets/dots.py
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
class Dots(StimulusDataset):
    """
    Render flashes aka dots per ommatidia.

    Note:
        Renders directly in receptor space, does not use BoxEye or HexEye as eye-model.

    Args:
        dot_column_radius: Radius of the dot column.
        max_extent: Maximum extent of the stimulus.
        bg_intensity: Background intensity.
        t_stim: Stimulus duration.
        dt: Time step.
        t_impulse: Impulse duration. A falsy value (None or 0) falls back to `dt`.
        n_ommatidia: Number of ommatidia.
        t_pre: Pre-stimulus duration.
        t_post: Post-stimulus duration.
        intensity: Stimulus intensity.
        mode: Stimulus mode ('sustained' or 'impulse').
        device: Torch device for computations.

    Attributes:
        dt: Time step.
        arg_df: DataFrame containing stimulus parameters.
        config: Namespace containing configuration parameters.
        t_stim: Stimulus duration.
        n_ommatidia: Number of ommatidia.
        offsets: Array of ommatidia offsets.
        u: U-coordinates of the hexagonal grid.
        v: V-coordinates of the hexagonal grid.
        extent_condition: Boolean mask for the extent condition.
        max_extent: Maximum extent of the stimulus.
        bg_intensity: Background intensity.
        intensities: List of stimulus intensities.
        device: Torch device for computations.
        mode: Stimulus mode ('sustained' or 'impulse').
        params: List of stimulus parameters.
        t_impulse: Impulse duration.

    Raises:
        ValueError: If dot_column_radius is greater than max_extent.
    """

    dt: Optional[float] = None
    arg_df: Optional[pd.DataFrame] = None

    def __init__(
        self,
        dot_column_radius: int = 0,
        max_extent: int = 15,
        bg_intensity: float = 0.5,
        t_stim: float = 5,
        dt: float = 1 / 200,
        t_impulse: Optional[float] = None,
        n_ommatidia: int = 721,
        t_pre: float = 2.0,
        t_post: float = 0,
        intensity: float = 1,
        mode: Literal["sustained", "impulse"] = "sustained",
        device: torch.device = flyvision.device,
    ):
        # The dot (a filled hexagonal circle) must fit inside the stimulated
        # extent; equality is allowed.
        if dot_column_radius > max_extent:
            raise ValueError("dot_column_radius must not be greater than max_extent")
        self.config = Namespace(
            dot_column_radius=dot_column_radius,
            max_extent=max_extent,
            bg_intensity=bg_intensity,
            t_stim=t_stim,
            dt=dt,
            n_ommatidia=n_ommatidia,
            t_pre=t_pre,
            t_post=t_post,
            intensity=intensity,
            mode=mode,
            t_impulse=t_impulse,
        )

        self.t_stim = t_stim
        self.t_pre = t_pre
        self.t_post = t_post

        self.n_ommatidia = n_ommatidia
        self.offsets = np.arange(self.n_ommatidia)

        # Restrict the hexagonal grid to columns within max_extent along all
        # three hex axes (u, v, and u + v).
        u, v = hex_utils.get_hex_coords(hex_utils.get_hextent(n_ommatidia))
        extent_condition = (
            (-max_extent <= u)
            & (u <= max_extent)
            & (-max_extent <= v)
            & (v <= max_extent)
            & (-max_extent <= u + v)
            & (u + v <= max_extent)
        )
        self.u = u[extent_condition]
        self.v = v[extent_condition]
        self.extent_condition = extent_condition

        # To have multi-column dots at every location, collect for each central
        # column the indices of the ommatidia covered by a dot centered there.
        coordinate_indices = []
        for u_center, v_center in zip(self.u, self.v):
            ring = HexLattice.filled_circle(
                radius=dot_column_radius,
                center=Hexal(u_center, v_center, 0),
                as_lattice=True,
            )
            coordinate_indices.append(self.offsets[ring.where(1)])

        self.max_extent = max_extent
        self.bg_intensity = bg_intensity

        # Pair the requested intensity with its mirror around the background
        # intensity (e.g. ON and OFF flashes of equal contrast).
        self.intensities = [2 * bg_intensity - intensity, intensity]
        self.device = device
        self.mode = mode

        # Cross every dot location with every intensity.
        self.params = [
            (*p[0], p[-1])
            for p in product(
                zip(self.u, self.v, self.offsets, coordinate_indices),
                self.intensities,
            )
        ]

        self.arg_df = pd.DataFrame(
            self.params,
            columns=["u", "v", "offset", "coordinate_index", "intensity"],
        )

        self.dt = dt
        # A falsy t_impulse (None or 0) falls back to a single-step impulse.
        self.t_impulse = t_impulse or self.dt

    def _params(self, key: int) -> np.ndarray:
        """Get parameters for a specific key.

        Args:
            key: Index of the parameters to retrieve.

        Returns:
            Array of parameters for the given key.
        """
        return self.arg_df.iloc[key].values

    def get_item(self, key: int) -> torch.Tensor:
        """Get a stimulus item for a specific key.

        Args:
            key: Index of the item to retrieve.

        Returns:
            Tensor representing the stimulus sequence.

        Raises:
            ValueError: If `self.mode` is neither 'sustained' nor 'impulse'.
        """
        # Create a map filled with the background value; [None, None] adds
        # (sample, frame) dimensions in front of the ommatidia dimension.
        _dot = (
            torch.ones(self.n_ommatidia, device=self.device)[None, None]
            * self.bg_intensity
        )
        # Fill the ommatidia covered by the dot with the stimulus intensity.
        _, _, _, coordinate_index, intensity = self._params(key)
        _dot[:, :, coordinate_index] = torch.tensor(intensity, device=self.device).float()

        if self.mode == "sustained":
            # Repeat the single frame for the full stimulus duration.
            sequence = resample(_dot, self.t_stim, self.dt, dim=1, device=self.device)
        elif self.mode == "impulse":
            if self.t_impulse == self.dt:
                # Single-step impulse: pad the remaining stimulus duration,
                # i.e. t_stim - dt, with background intensity.
                sequence = pad(
                    _dot,
                    self.t_stim,
                    self.dt,
                    mode="end",
                    fill=self.bg_intensity,
                )
            else:
                # First resample to t_impulse / dt frames, then pad the
                # remaining duration, i.e. t_stim - t_impulse, with background.
                sequence = resample(
                    _dot, self.t_impulse, self.dt, dim=1, device=self.device
                )
                sequence = pad(
                    sequence,
                    self.t_stim,
                    self.dt,
                    mode="end",
                    fill=self.bg_intensity,
                )
        else:
            # Previously an unknown mode crashed later with an opaque
            # NameError on `sequence`; fail fast with a clear message instead.
            raise ValueError(f"unknown mode {self.mode!r}")

        # Pad with pre-stimulus background.
        sequence = pad(
            sequence,
            self.t_stim + self.t_pre,
            self.dt,
            mode="start",
            fill=self.bg_intensity,
        )
        # Pad with post-stimulus background.
        sequence = pad(
            sequence,
            self.t_stim + self.t_pre + self.t_post,
            self.dt,
            mode="end",
            fill=self.bg_intensity,
        )
        return sequence.squeeze()

    def get_stimulus_index(self, u: float, v: float, intensity: float) -> int:
        """Get sequence ID from given arguments.

        Args:
            u: U-coordinate.
            v: V-coordinate.
            intensity: Stimulus intensity.

        Returns:
            Sequence ID.
        """
        # locals() packages u, v, intensity (and self) for the base-class lookup.
        return StimulusDataset.get_stimulus_index(self, locals())

get_item

get_item(key)

Get a stimulus item for a specific key.

Parameters:

Name Type Description Default
key int

Index of the item to retrieve.

required

Returns:

Type Description
Tensor

Tensor representing the stimulus sequence.

Source code in flyvision/datasets/dots.py
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
def get_item(self, key: int) -> torch.Tensor:
    """Get a stimulus item for a specific key.

    Builds a background-filled frame, writes the stimulus intensity into the
    ommatidia selected by the key's coordinate index, expands it in time
    according to `self.mode`, and pads with background before (t_pre) and
    after (t_post) the stimulus window.

    Args:
        key: Index of the item to retrieve.

    Returns:
        Tensor representing the stimulus sequence.
    """
    # create maps with background value; [None, None] adds (sample, frame)
    # dimensions in front of the ommatidia dimension
    _dot = (
        torch.ones(self.n_ommatidia, device=self.device)[None, None]
        * self.bg_intensity
    )
    # fill the ommatidia covered by the dot with the stimulus intensity
    _, _, _, coordinate_index, intensity = self._params(key)
    _dot[:, :, coordinate_index] = torch.tensor(intensity, device=self.device).float()

    # repeat the single frame for the sustained stimulus
    if self.mode == "sustained":
        sequence = resample(_dot, self.t_stim, self.dt, dim=1, device=self.device)

    elif self.mode == "impulse":
        # pad remaining stimulus duration i.e. self.t_stim - self.dt with
        # background intensity
        if self.t_impulse == self.dt:
            sequence = pad(
                _dot,
                self.t_stim,
                self.dt,
                mode="end",
                fill=self.bg_intensity,
            )
        # first resample for t_impulse/dt then pad remaining stimulus
        # duration, i.e. t_stim - t_impulse with background intensity
        else:
            sequence = resample(
                _dot, self.t_impulse, self.dt, dim=1, device=self.device
            )
            sequence = pad(
                sequence,
                self.t_stim,
                self.dt,
                mode="end",
                fill=self.bg_intensity,
            )

    # pad with pre stimulus background
    sequence = pad(
        sequence,
        self.t_stim + self.t_pre,
        self.dt,
        mode="start",
        fill=self.bg_intensity,
    )
    # pad with post stimulus background
    sequence = pad(
        sequence,
        self.t_stim + self.t_pre + self.t_post,
        self.dt,
        mode="end",
        fill=self.bg_intensity,
    )
    return sequence.squeeze()

get_stimulus_index

get_stimulus_index(u, v, intensity)

Get sequence ID from given arguments.

Parameters:

Name Type Description Default
u float

U-coordinate.

required
v float

V-coordinate.

required
intensity float

Stimulus intensity.

required

Returns:

Type Description
int

Sequence ID.

Source code in flyvision/datasets/dots.py
228
229
230
231
232
233
234
235
236
237
238
239
def get_stimulus_index(self, u: float, v: float, intensity: float) -> int:
    """Get sequence ID from given arguments.

    Args:
        u: U-coordinate.
        v: V-coordinate.
        intensity: Stimulus intensity.

    Returns:
        Sequence ID.
    """
    # locals() bundles u, v, intensity (and self) for the base-class lookup.
    # NOTE(review): StimulusDataset.get_stimulus_index is not visible here —
    # presumably it matches these values against rows of arg_df; confirm.
    return StimulusDataset.get_stimulus_index(self, locals())

flyvision.datasets.dots.SpatialImpulses

Bases: StimulusDataset

Spatial flashes for spatial receptive field mapping.

Parameters:

Name Type Description Default
impulse_durations List[float]

List of impulse durations.

[0.005, 0.02]
max_extent int

Maximum extent of the stimulus.

4
dot_column_radius int

Radius of the dot column.

0
bg_intensity float

Background intensity.

0.5
t_stim float

Stimulus duration.

5
dt float

Time step.

0.005
n_ommatidia int

Number of ommatidia.

721
t_pre float

Pre-stimulus duration.

2.0
t_post float

Post-stimulus duration.

0
intensity float

Stimulus intensity.

1
mode str

Stimulus mode.

'impulse'
device device

Torch device for computations.

device

Attributes:

Name Type Description
arg_df Optional[DataFrame]

DataFrame containing stimulus parameters.

dt Optional[float]

Time step.

dots

Instance of the Dots class.

impulse_durations

List of impulse durations.

config

Configuration namespace.

params

List of stimulus parameters.

Source code in flyvision/datasets/dots.py
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
class SpatialImpulses(StimulusDataset):
    """Spatial flashes for spatial receptive field mapping.

    Args:
        impulse_durations: List of impulse durations. Defaults to
            [5e-3, 20e-3] when None.
        max_extent: Maximum extent of the stimulus.
        dot_column_radius: Radius of the dot column.
        bg_intensity: Background intensity.
        t_stim: Stimulus duration.
        dt: Time step.
        n_ommatidia: Number of ommatidia.
        t_pre: Pre-stimulus duration.
        t_post: Post-stimulus duration.
        intensity: Stimulus intensity.
        mode: Stimulus mode.
        device: Torch device for computations.

    Attributes:
        arg_df: DataFrame containing stimulus parameters.
        dt: Time step.
        dots: Instance of the Dots class.
        impulse_durations: List of impulse durations.
        config: Configuration namespace.
        params: List of stimulus parameters.
    """

    arg_df: Optional[pd.DataFrame] = None
    dt: Optional[float] = None

    def __init__(
        self,
        impulse_durations: Optional[List[float]] = None,
        max_extent: int = 4,
        dot_column_radius: int = 0,
        bg_intensity: float = 0.5,
        t_stim: float = 5,
        dt: float = 0.005,
        n_ommatidia: int = 721,
        t_pre: float = 2.0,
        t_post: float = 0,
        intensity: float = 1,
        mode: str = "impulse",
        device: torch.device = flyvision.device,
    ):
        # None stands in for the historical default of [5e-3, 20e-3] to avoid
        # the shared-mutable-default-argument pitfall.
        if impulse_durations is None:
            impulse_durations = [5e-3, 20e-3]
        self.dots = Dots(
            dot_column_radius=dot_column_radius,
            max_extent=max_extent,
            bg_intensity=bg_intensity,
            t_stim=t_stim,
            dt=dt,
            n_ommatidia=n_ommatidia,
            t_pre=t_pre,
            t_post=t_post,
            intensity=intensity,
            mode=mode,
            device=device,
        )
        self.dt = dt
        self.impulse_durations = impulse_durations

        # config is shared with the wrapped Dots instance and updated in place.
        self.config = self.dots.config
        self.config.update(impulse_durations=impulse_durations)

        # Cross every dot stimulus with every impulse duration.
        self.params = [
            (*p[0], p[1])
            for p in product(self.dots.arg_df.values.tolist(), impulse_durations)
        ]
        self.arg_df = pd.DataFrame(
            self.params,
            columns=[
                "u",
                "v",
                "offset",
                "coordinate_index",
                "intensity",
                "t_impulse",
            ],
        )

    def _params(self, key: int) -> np.ndarray:
        """Get parameters for a specific key.

        Args:
            key: Index of the parameters to retrieve.

        Returns:
            Array of parameters for the given key.
        """
        return self.arg_df.iloc[key].values

    def get_item(self, key: int) -> torch.Tensor:
        """Get a stimulus item for a specific key.

        Args:
            key: Index of the item to retrieve.

        Returns:
            Tensor representing the stimulus sequence.
        """
        u, v, offset, coordinate_index, intensity, t_impulse = self._params(key)
        # NOTE: mutates the wrapped Dots instance so it renders this impulse
        # duration; the stimulus is then re-looked-up by (u, v, intensity).
        self.dots.t_impulse = t_impulse
        return self.dots[self.dots.get_stimulus_index(u, v, intensity)]

    def __repr__(self) -> str:
        """Get string representation of the dataset."""
        return repr(self.arg_df)

get_item

get_item(key)

Get a stimulus item for a specific key.

Parameters:

Name Type Description Default
key int

Index of the item to retrieve.

required

Returns:

Type Description
Tensor

Tensor representing the stimulus sequence.

Source code in flyvision/datasets/dots.py
461
462
463
464
465
466
467
468
469
470
471
472
def get_item(self, key: int) -> torch.Tensor:
    """Get a stimulus item for a specific key.

    Args:
        key: Index of the item to retrieve.

    Returns:
        Tensor representing the stimulus sequence.
    """
    # offset and coordinate_index are unpacked but unused here; the stimulus
    # is re-looked-up by (u, v, intensity) on the wrapped Dots dataset.
    u, v, offset, coordinate_index, intensity, t_impulse = self._params(key)
    # NOTE: mutates the wrapped Dots instance to render this impulse duration.
    self.dots.t_impulse = t_impulse
    return self.dots[self.dots.get_stimulus_index(u, v, intensity)]

__repr__

__repr__()

Get string representation of the dataset.

Source code in flyvision/datasets/dots.py
474
475
476
def __repr__(self) -> str:
    """Delegate the representation to the stimulus-parameter DataFrame."""
    frame_text = repr(self.arg_df)
    return frame_text

flyvision.datasets.dots.CentralImpulses

Bases: StimulusDataset

Flashes at the center of the visual field for temporal receptive field mapping.

Parameters:

Name Type Description Default
impulse_durations List[float]

List of impulse durations.

[0.005, 0.02, 0.05, 0.1, 0.2, 0.3]
dot_column_radius int

Radius of the dot column.

0
bg_intensity float

Background intensity.

0.5
t_stim float

Stimulus duration.

5
dt float

Time step.

0.005
n_ommatidia int

Number of ommatidia.

721
t_pre float

Pre-stimulus duration.

2.0
t_post float

Post-stimulus duration.

0
intensity float

Stimulus intensity.

1
mode str

Stimulus mode.

'impulse'
device device

Torch device for computations.

device

Attributes:

Name Type Description
arg_df Optional[DataFrame]

DataFrame containing stimulus parameters.

dt Optional[float]

Time step.

dots

Instance of the Dots class.

impulse_durations

List of impulse durations.

config

Configuration namespace.

params

List of stimulus parameters.

Source code in flyvision/datasets/dots.py
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
class CentralImpulses(StimulusDataset):
    """Flashes at the center of the visual field for temporal receptive field mapping.

    Args:
        impulse_durations: List of impulse durations. Defaults to
            [5e-3, 20e-3, 50e-3, 100e-3, 200e-3, 300e-3] when None.
        dot_column_radius: Radius of the dot column.
        bg_intensity: Background intensity.
        t_stim: Stimulus duration.
        dt: Time step.
        n_ommatidia: Number of ommatidia.
        t_pre: Pre-stimulus duration.
        t_post: Post-stimulus duration.
        intensity: Stimulus intensity.
        mode: Stimulus mode.
        device: Torch device for computations.

    Attributes:
        arg_df: DataFrame containing stimulus parameters.
        dt: Time step.
        dots: Instance of the Dots class.
        impulse_durations: List of impulse durations.
        config: Configuration namespace.
        params: List of stimulus parameters.
    """

    arg_df: Optional[pd.DataFrame] = None
    dt: Optional[float] = None

    def __init__(
        self,
        impulse_durations: Optional[List[float]] = None,
        dot_column_radius: int = 0,
        bg_intensity: float = 0.5,
        t_stim: float = 5,
        dt: float = 0.005,
        n_ommatidia: int = 721,
        t_pre: float = 2.0,
        t_post: float = 0,
        intensity: float = 1,
        mode: str = "impulse",
        device: torch.device = flyvision.device,
    ):
        # None stands in for the historical default durations to avoid the
        # shared-mutable-default-argument pitfall.
        if impulse_durations is None:
            impulse_durations = [5e-3, 20e-3, 50e-3, 100e-3, 200e-3, 300e-3]
        # max_extent == dot_column_radius restricts the grid to the central
        # column(s), so only central flashes are rendered.
        self.dots = Dots(
            dot_column_radius=dot_column_radius,
            max_extent=dot_column_radius,
            bg_intensity=bg_intensity,
            t_stim=t_stim,
            dt=dt,
            n_ommatidia=n_ommatidia,
            t_pre=t_pre,
            t_post=t_post,
            intensity=intensity,
            mode=mode,
            device=device,
        )
        self.impulse_durations = impulse_durations
        # config is shared with the wrapped Dots instance and updated in place.
        self.config = self.dots.config
        self.config.update(impulse_durations=impulse_durations)
        # Cross every dot stimulus with every impulse duration.
        self.params = [
            (*p[0], p[1])
            for p in product(self.dots.arg_df.values.tolist(), impulse_durations)
        ]
        self.arg_df = pd.DataFrame(
            self.params,
            columns=[
                "u",
                "v",
                "offset",
                "coordinate_index",
                "intensity",
                "t_impulse",
            ],
        )
        self.dt = dt

    def _params(self, key: int) -> np.ndarray:
        """Get parameters for a specific key.

        Args:
            key: Index of the parameters to retrieve.

        Returns:
            Array of parameters for the given key.
        """
        return self.arg_df.iloc[key].values

    def get_item(self, key: int) -> torch.Tensor:
        """Get a stimulus item for a specific key.

        Args:
            key: Index of the item to retrieve.

        Returns:
            Tensor representing the stimulus sequence.
        """
        u, v, offset, coordinate_index, intensity, t_impulse = self._params(key)
        # NOTE: mutates the wrapped Dots instance so it renders this impulse
        # duration; the stimulus is then re-looked-up by (u, v, intensity).
        self.dots.t_impulse = t_impulse
        return self.dots[self.dots.get_stimulus_index(u, v, intensity)]

    @property
    def t_pre(self) -> float:
        """Get pre-stimulus duration."""
        return self.dots.t_pre

    @property
    def t_post(self) -> float:
        """Get post-stimulus duration."""
        return self.dots.t_post

    def __repr__(self) -> str:
        """Get string representation of the dataset."""
        return repr(self.arg_df)

t_pre property

t_pre

Get pre-stimulus duration.

t_post property

t_post

Get post-stimulus duration.

__init__

__init__(
    impulse_durations=[0.005, 0.02, 0.05, 0.1, 0.2, 0.3],
    dot_column_radius=0,
    bg_intensity=0.5,
    t_stim=5,
    dt=0.005,
    n_ommatidia=721,
    t_pre=2.0,
    t_post=0,
    intensity=1,
    mode="impulse",
    device=flyvision.device,
)

Initialize the CentralImpulses dataset.

Parameters:

Name Type Description Default
impulse_durations List[float]

List of impulse durations.

[0.005, 0.02, 0.05, 0.1, 0.2, 0.3]
dot_column_radius int

Radius of the dot column.

0
bg_intensity float

Background intensity.

0.5
t_stim float

Stimulus duration.

5
dt float

Time step.

0.005
n_ommatidia int

Number of ommatidia.

721
t_pre float

Pre-stimulus duration.

2.0
t_post float

Post-stimulus duration.

0
intensity float

Stimulus intensity.

1
mode str

Stimulus mode.

'impulse'
device device

Torch device for computations.

device
Source code in flyvision/datasets/dots.py
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
def __init__(
    self,
    impulse_durations: List[float] = [5e-3, 20e-3, 50e-3, 100e-3, 200e-3, 300e-3],
    dot_column_radius: int = 0,
    bg_intensity: float = 0.5,
    t_stim: float = 5,
    dt: float = 0.005,
    n_ommatidia: int = 721,
    t_pre: float = 2.0,
    t_post: float = 0,
    intensity: float = 1,
    mode: str = "impulse",
    device: torch.device = flyvision.device,
):
    """Initialize the CentralImpulses dataset.

    Args:
        impulse_durations: List of impulse durations.
        dot_column_radius: Radius of the dot column.
        bg_intensity: Background intensity.
        t_stim: Stimulus duration.
        dt: Time step.
        n_ommatidia: Number of ommatidia.
        t_pre: Pre-stimulus duration.
        t_post: Post-stimulus duration.
        intensity: Stimulus intensity.
        mode: Stimulus mode.
        device: Torch device for computations.
    """
    # NOTE(review): mutable default list for impulse_durations — safe only
    # as long as it is never mutated.
    # max_extent == dot_column_radius restricts the grid to the central
    # column(s), so only central flashes are rendered.
    self.dots = Dots(
        dot_column_radius=dot_column_radius,
        max_extent=dot_column_radius,
        bg_intensity=bg_intensity,
        t_stim=t_stim,
        dt=dt,
        n_ommatidia=n_ommatidia,
        t_pre=t_pre,
        t_post=t_post,
        intensity=intensity,
        mode=mode,
        device=device,
    )
    self.impulse_durations = impulse_durations
    # config is shared with the wrapped Dots instance and updated in place.
    self.config = self.dots.config
    self.config.update(impulse_durations=impulse_durations)
    # Cross every dot stimulus with every impulse duration.
    self.params = [
        (*p[0], p[1])
        for p in product(self.dots.arg_df.values.tolist(), impulse_durations)
    ]
    self.arg_df = pd.DataFrame(
        self.params,
        columns=[
            "u",
            "v",
            "offset",
            "coordinate_index",
            "intensity",
            "t_impulse",
        ],
    )
    self.dt = dt

get_item

get_item(key)

Get a stimulus item for a specific key.

Parameters:

Name Type Description Default
key int

Index of the item to retrieve.

required

Returns:

Type Description
Tensor

Tensor representing the stimulus sequence.

Source code in flyvision/datasets/dots.py
343
344
345
346
347
348
349
350
351
352
353
354
def get_item(self, key: int) -> torch.Tensor:
    """Get a stimulus item for a specific key.

    Args:
        key: Index of the item to retrieve.

    Returns:
        Tensor representing the stimulus sequence.
    """
    # offset and coordinate_index are unused; the stimulus is retrieved by
    # (u, v, intensity) from the wrapped Dots dataset.
    u, v, offset, coordinate_index, intensity, t_impulse = self._params(key)
    # NOTE: side effect — sets the impulse duration on the wrapped dataset.
    self.dots.t_impulse = t_impulse
    return self.dots[self.dots.get_stimulus_index(u, v, intensity)]

__repr__

__repr__()

Get string representation of the dataset.

Source code in flyvision/datasets/dots.py
366
367
368
def __repr__(self) -> str:
    """Represent the dataset by its stimulus-parameter DataFrame."""
    return self.arg_df.__repr__()

Sintel

flyvision.datasets.sintel.MultiTaskSintel

Bases: MultiTaskDataset

Sintel dataset.

Parameters:

Name Type Description Default
tasks List[str]

List of tasks to include. May include ‘flow’, ‘lum’, or ‘depth’.

['flow']
boxfilter Dict[str, int]

Key word arguments for the BoxEye filter.

dict(extent=15, kernel_size=13)
vertical_splits int

Number of vertical splits of each frame.

3
n_frames int

Number of frames to render for each sequence.

19
center_crop_fraction float

Fraction of the image to keep after cropping.

0.7
dt float

Sampling and integration time constant.

1 / 50
augment bool

Turns augmentation on and off.

True
random_temporal_crop bool

Randomly crops a temporal window of length n_frames from each sequence.

True
all_frames bool

If True, all frames are returned. If False, only n_frames. Takes precedence over random_temporal_crop.

False
resampling bool

If True, piecewise-constant resamples the input sequence to the target framerate (1/dt).

True
interpolate bool

If True, linearly interpolates the target sequence to the target framerate (1/dt).

True
p_flip float

Probability of flipping the sequence across hexagonal axes.

0.5
p_rot float

Probability of rotating the sequence by n*60 degrees.

5 / 6
contrast_std float

Standard deviation of the contrast augmentation.

0.2
brightness_std float

Standard deviation of the brightness augmentation.

0.1
gaussian_white_noise float

Standard deviation of the pixel-wise gaussian white noise.

0.08
gamma_std Optional[float]

Standard deviation of the gamma augmentation.

None
_init_cache bool

If True, caches the dataset in memory.

True
unittest bool

If True, only renders a single sequence.

False
flip_axes List[int]

List of axes to flip over.

[0, 1]

Attributes:

Name Type Description
dt float

Sampling and integration time constant.

t_pre float

Warmup time.

t_post float

Cooldown time.

tasks List[str]

List of all tasks.

valid_tasks List[str]

List of valid task names.

Raises:

Type Description
ValueError

If any element in tasks is invalid.

Source code in flyvision/datasets/sintel.py
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
class MultiTaskSintel(MultiTaskDataset):
    """Sintel dataset.

    Args:
        tasks: List of tasks to include. May include 'flow', 'lum', or 'depth'.
        boxfilter: Key word arguments for the BoxEye filter.
        vertical_splits: Number of vertical splits of each frame.
        n_frames: Number of frames to render for each sequence.
        center_crop_fraction: Fraction of the image to keep after cropping.
        dt: Sampling and integration time constant.
        augment: Turns augmentation on and off.
        random_temporal_crop: Randomly crops a temporal window of length `n_frames` from
            each sequence.
        all_frames: If True, all frames are returned. If False, only `n_frames`. Takes
            precedence over `random_temporal_crop`.
        resampling: If True, piecewise-constant resamples the input sequence to the
            target framerate (1/dt).
        interpolate: If True, linearly interpolates the target sequence to the target
            framerate (1/dt).
        p_flip: Probability of flipping the sequence across hexagonal axes.
        p_rot: Probability of rotating the sequence by n*60 degrees.
        contrast_std: Standard deviation of the contrast augmentation.
        brightness_std: Standard deviation of the brightness augmentation.
        gaussian_white_noise: Standard deviation of the pixel-wise gaussian white noise.
        gamma_std: Standard deviation of the gamma augmentation.
        _init_cache: If True, caches the dataset in memory.
        unittest: If True, only renders a single sequence.
        flip_axes: List of axes to flip over.

    Attributes:
        dt (float): Sampling and integration time constant.
        t_pre (float): Warmup time.
        t_post (float): Cooldown time.
        tasks (List[str]): List of all tasks.
        valid_tasks (List[str]): List of valid task names.

    Raises:
        ValueError: If any element in tasks is invalid.
    """

    # Class-level defaults; `dt` and `tasks` are shadowed per-instance in __init__.
    original_framerate: int = 24  # framerate of the raw Sintel sequences
    dt: float = 1 / 50
    t_pre: float = 0.0
    t_post: float = 0.0
    tasks: List[str] = []
    valid_tasks: List[str] = ["lum", "flow", "depth"]

    # NOTE(review): `tasks`, `boxfilter`, and `flip_axes` use mutable default
    # arguments, which are shared across calls. They are not mutated in-place
    # here (only read or reassigned), so this is benign, but worth confirming
    # no caller mutates them.
    def __init__(
        self,
        tasks: List[str] = ["flow"],
        boxfilter: Dict[str, int] = dict(extent=15, kernel_size=13),
        vertical_splits: int = 3,
        n_frames: int = 19,
        center_crop_fraction: float = 0.7,
        dt: float = 1 / 50,
        augment: bool = True,
        random_temporal_crop: bool = True,
        all_frames: bool = False,
        resampling: bool = True,
        interpolate: bool = True,
        p_flip: float = 0.5,
        p_rot: float = 5 / 6,
        contrast_std: float = 0.2,
        brightness_std: float = 0.1,
        gaussian_white_noise: float = 0.08,
        gamma_std: Optional[float] = None,
        _init_cache: bool = True,
        unittest: bool = False,
        flip_axes: List[int] = [0, 1],
    ):
        def check_tasks(tasks):
            # Validate the requested tasks against `valid_tasks` and return them
            # in canonical order, plus the data keys needed to load them.
            invalid_tasks = [x for x in tasks if x not in self.valid_tasks]
            if invalid_tasks:
                raise ValueError(f"invalid tasks {invalid_tasks}")

            tasks = [v for v in self.valid_tasks if v in tasks]  # sort
            # because the input 'lum' is always required
            data_keys = tasks if "lum" in tasks else ["lum", *tasks]
            return tasks, data_keys

        # NOTE: assignment order below matters — __setattr__ forwards each
        # assignment to update_augmentation() once
        # _augmentations_are_initialized is True.
        self.tasks, self.data_keys = check_tasks(tasks)
        self.interpolate = interpolate
        self.n_frames = n_frames if not unittest else 3  # shrink for unit tests
        self.dt = dt

        self.all_frames = all_frames
        self.resampling = resampling

        self.boxfilter = boxfilter
        self.extent = boxfilter["extent"]
        assert vertical_splits >= 1
        self.vertical_splits = vertical_splits
        self.center_crop_fraction = center_crop_fraction

        self.p_flip = p_flip
        self.p_rot = p_rot
        self.contrast_std = contrast_std
        self.brightness_std = brightness_std
        self.gaussian_white_noise = gaussian_white_noise
        self.gamma_std = gamma_std
        self.random_temporal_crop = random_temporal_crop
        self.flip_axes = flip_axes
        self.fix_augmentation_params = False

        self.init_augmentation()
        self._augmentations_are_initialized = True
        # note: self.augment is a property with a setter that relies on
        # _augmentations_are_initialized
        self.augment = augment

        self.unittest = unittest

        # Download (if needed) and pre-render the Sintel data; depth data is
        # only fetched when the 'depth' task was requested.
        self.sintel_path = download_sintel(depth="depth" in tasks)
        self.rendered = RenderedSintel(
            tasks=tasks,
            boxfilter=boxfilter,
            vertical_splits=vertical_splits,
            n_frames=n_frames,
            center_crop_fraction=center_crop_fraction,
            unittest=unittest,
        )
        self.meta = sintel_meta(
            self.rendered, self.sintel_path, n_frames, vertical_splits, "depth" in tasks
        )

        self.config = Namespace(
            tasks=tasks,
            interpolate=interpolate,
            n_frames=n_frames,
            dt=dt,
            augment=augment,
            all_frames=all_frames,
            resampling=resampling,
            random_temporal_crop=random_temporal_crop,
            boxfilter=boxfilter,
            vertical_splits=vertical_splits,
            p_flip=p_flip,
            p_rot=p_rot,
            contrast_std=contrast_std,
            brightness_std=brightness_std,
            gaussian_white_noise=gaussian_white_noise,
            gamma_std=gamma_std,
            center_crop_fraction=center_crop_fraction,
            flip_axes=flip_axes,
        )

        # One row per split sequence; `original_index` maps each split back to
        # its source scene.
        self.arg_df = pd.DataFrame(
            dict(
                index=np.arange(len(self.rendered)),
                original_index=self.meta.sequence_indices.repeat(vertical_splits),
                name=sorted(self.rendered.keys()),
                original_n_frames=self.meta.frames_per_scene.repeat(vertical_splits),
            )
        )

        if _init_cache:
            self.init_cache()

    def init_cache(self) -> None:
        """Initialize the cache with preprocessed sequences."""
        # Load every rendered sequence into memory as float32 tensors, keeping
        # only the keys required by the selected tasks.
        self.cached_sequences = [
            {
                key: torch.tensor(val, dtype=torch.float32)
                for key, val in self.rendered(seq_id).items()
                if key in self.data_keys
            }
            for seq_id in range(len(self))
        ]

    def __repr__(self) -> str:
        """Short human-readable summary of the dataset."""
        # NOTE(review): local name `repr` shadows the builtin within this method.
        repr = f"{self.__class__.__name__} with {len(self)} sequences.\n"
        repr += "See docs, arg_df and meta for more details.\n"
        return repr

    @property
    def docs(self) -> None:
        """Print the class docstring (returns None)."""
        # Annotation corrected from `-> str`: this prints rather than returns.
        print(self.__doc__)

    def __setattr__(self, name: str, value: Any) -> None:
        """Custom attribute setter to handle special cases and update augmentation.

        Args:
            name: Name of the attribute to set.
            value: Value to set the attribute to.

        Raises:
            AttributeError: If trying to change framerate or rendered initialization
                attributes.
        """
        # some changes have no effect cause they are fixed, or set by the pre-rendering
        if name == "framerate":
            raise AttributeError("cannot change framerate")
        if hasattr(self, "rendered") and name in self.rendered.config:
            raise AttributeError("cannot change attribute of rendered initialization")
        super().__setattr__(name, value)
        # also update augmentation because it may already be initialized
        if getattr(self, "_augmentations_are_initialized", False):
            self.update_augmentation(name, value)

    def init_augmentation(self) -> None:
        """Initialize augmentation callables."""
        self.temporal_crop = CropFrames(
            self.n_frames, all_frames=self.all_frames, random=self.random_temporal_crop
        )
        self.jitter = ContrastBrightness(
            contrast_std=self.contrast_std, brightness_std=self.brightness_std
        )
        self.rotate = HexRotate(self.extent, p_rot=self.p_rot)
        self.flip = HexFlip(self.extent, p_flip=self.p_flip, flip_axes=self.flip_axes)
        self.noise = PixelNoise(self.gaussian_white_noise)

        # Two resamplers to the target framerate 1/dt: piecewise-constant for
        # inputs, linear for targets (see apply_augmentation).
        self.piecewise_resample = Interpolate(
            self.original_framerate, 1 / self.dt, mode="nearest-exact"
        )
        self.linear_interpolate = Interpolate(
            self.original_framerate,
            1 / self.dt,
            mode="linear",
        )
        self.gamma_correct = GammaCorrection(1, self.gamma_std)

    def update_augmentation(self, name: str, value: Any) -> None:
        """Update augmentation parameters based on attribute changes.

        Args:
            name: Name of the attribute that changed.
            value: New value of the attribute.
        """
        if name == "dt":
            self.piecewise_resample.target_framerate = 1 / value
            self.linear_interpolate.target_framerate = 1 / value
        # NOTE(review): for the paired branches below, setting EITHER attribute
        # writes the same value into BOTH fields of the augmentation callable
        # (e.g. setting `all_frames` also overwrites `temporal_crop.random`).
        # Presumably intentional coupling — confirm against callers.
        if name in ["all_frames", "random_temporal_crop"]:
            self.temporal_crop.all_frames = value
            self.temporal_crop.random = value
        if name in ["contrast_std", "brightness_std"]:
            self.jitter.contrast_std = value
            self.jitter.brightness_std = value
        if name == "p_rot":
            self.rotate.p_rot = value
        if name == "p_flip":
            self.flip.p_flip = value
        if name == "gaussian_white_noise":
            self.noise.std = value
        if name == "gamma_std":
            self.gamma_correct.std = value

    def set_augmentation_params(
        self,
        n_rot: Optional[int] = None,
        flip_axis: Optional[int] = None,
        contrast_factor: Optional[float] = None,
        brightness_factor: Optional[float] = None,
        gaussian_white_noise: Optional[float] = None,
        gamma: Optional[float] = None,
        start_frame: Optional[int] = None,
        total_sequence_length: Optional[int] = None,
    ) -> None:
        """Set augmentation callable parameters.

        Info:
            Called for each call of get_item.

        Args:
            n_rot: Number of rotations to apply.
            flip_axis: Axis to flip over.
            contrast_factor: Contrast factor for jitter augmentation.
            brightness_factor: Brightness factor for jitter augmentation.
            gaussian_white_noise: Standard deviation for noise augmentation.
            gamma: Gamma value for gamma correction.
            start_frame: Starting frame for temporal crop.
            total_sequence_length: Total length of the sequence.
        """
        # `fix_augmentation_params` freezes the currently sampled parameters,
        # e.g. to apply identical augmentation across multiple items.
        if not self.fix_augmentation_params:
            self.rotate.set_or_sample(n_rot)
            self.flip.set_or_sample(flip_axis)
            self.jitter.set_or_sample(contrast_factor, brightness_factor)
            self.noise.set_or_sample(gaussian_white_noise)
            self.gamma_correct.set_or_sample(gamma)
            self.temporal_crop.set_or_sample(
                start=start_frame, total_sequence_length=total_sequence_length
            )

    def get_item(self, key: int) -> Dict[str, torch.Tensor]:
        """Return a dataset sample.

        Args:
            key: Index of the sample to retrieve.

        Returns:
            Dictionary containing the augmented sample data.
        """
        return self.apply_augmentation(self.cached_sequences[key])

    @contextmanager
    def augmentation(self, abool: bool):
        """Context manager to turn augmentation on or off in a code block.

        Args:
            abool: Boolean value to set augmentation state.

        Example:
            ```python
            with dataset.augmentation(True):
                for i, data in enumerate(dataloader):
                    ...  # all data is augmented
            ```
        """
        augmentations = [
            "temporal_crop",
            "jitter",
            "rotate",
            "flip",
            "noise",
            "piecewise_resample",
            "linear_interpolate",
            "gamma_correct",
        ]
        # Snapshot per-callable states because the `augment` setter overwrites
        # them; restore both the flag and the individual states on exit.
        states = {key: getattr(self, key).augment for key in augmentations}
        _augment = self.augment
        try:
            self.augment = abool
            yield
        finally:
            self.augment = _augment
            for key in augmentations:
                getattr(self, key).augment = states[key]

    @property
    def augment(self) -> bool:
        """Get the current augmentation state."""
        return self._augment

    @augment.setter
    def augment(self, value: bool) -> None:
        """Set the augmentation state and update augmentation callables.

        Args:
            value: Boolean value to set augmentation state.
        """
        self._augment = value
        # During __init__ the callables may not exist yet; skip until ready.
        if not self._augmentations_are_initialized:
            return
        # note: random_temporal_crop can override augment=True
        self.temporal_crop.random = self.random_temporal_crop if value else False
        self.jitter.augment = value
        self.rotate.augment = value
        self.flip.augment = value
        self.noise.augment = value
        # note: these two are not affected by augment
        self.piecewise_resample.augment = self.resampling
        self.linear_interpolate.augment = self.interpolate
        self.gamma_correct.augment = value

    def apply_augmentation(
        self,
        data: Dict[str, torch.Tensor],
        n_rot: Optional[int] = None,
        flip_axis: Optional[int] = None,
        contrast_factor: Optional[float] = None,
        brightness_factor: Optional[float] = None,
        gaussian_white_noise: Optional[float] = None,
        gamma: Optional[float] = None,
    ) -> Dict[str, torch.Tensor]:
        """Apply augmentation to a sample from the dataset.

        Args:
            data: Dictionary containing the sample data.
            n_rot: Number of rotations to apply.
            flip_axis: Axis to flip over.
            contrast_factor: Contrast factor for jitter augmentation.
            brightness_factor: Brightness factor for jitter augmentation.
            gaussian_white_noise: Standard deviation for noise augmentation.
            gamma: Gamma value for gamma correction.

        Returns:
            Dictionary containing the augmented sample data.
        """

        # assumes the time axis is dim 0 of data["lum"] — consistent with
        # CropFrames taking total_sequence_length below.
        self.set_augmentation_params(
            n_rot=n_rot,
            flip_axis=flip_axis,
            contrast_factor=contrast_factor,
            brightness_factor=brightness_factor,
            gaussian_white_noise=gaussian_white_noise,
            gamma=gamma,
            start_frame=None,
            total_sequence_length=data["lum"].shape[0],
        )

        # Input pipeline: crop -> noise -> jitter -> flip -> rotate -> resample.
        def transform_lum(lum):
            return self.piecewise_resample(
                self.rotate(
                    self.flip(
                        self.jitter(
                            self.noise(self.temporal_crop(lum)),
                        ),
                    )
                )
            )

        # Target pipeline: same geometric ops, but no photometric noise/jitter;
        # resampled linearly when `interpolate` is set.
        def transform_target(target):
            if self.interpolate:
                return self.linear_interpolate(
                    self.rotate(self.flip(self.temporal_crop(target)))
                )
            return self.piecewise_resample(
                self.rotate(self.flip(self.temporal_crop(target)))
            )

        return {
            **{"lum": transform_lum(data["lum"])},
            **{
                target: transform_target(data[target])
                for target in self.tasks
                if target in ["flow", "depth"]
            },
        }

    def original_sequence_index(self, key: int) -> int:
        """Get the original sequence index from an index of the split.

        Args:
            key: Index of the split.

        Returns:
            Original sequence index.

        Raises:
            ValueError: If the key is not found in splits.
        """
        # Linear scan over the split mapping; fine for the small number of scenes.
        for index, splits in self.meta.sequence_index_to_splits.items():
            if key in splits:
                return index
        raise ValueError(f"key {key} not found in splits")

    def cartesian_sequence(
        self,
        key: int,
        vertical_splits: Optional[int] = None,
        outwidth: int = 716,
        center_crop_fraction: Optional[float] = None,
        sampling: slice = slice(1, None, None),
    ) -> np.ndarray:
        """Return the cartesian sequence of a fly eye rendered sequence.

        Args:
            key: Index of the sequence.
            vertical_splits: Number of vertical splits to apply.
            outwidth: Output width of the sequence.
            center_crop_fraction: Fraction of the image to keep after cropping.
            sampling: Slice object for sampling frames.

        Returns:
            Numpy array containing the cartesian sequence.
        """
        # we want to retrieve the original scene which is possibly split
        # into multiple ones
        key = self.original_sequence_index(key)
        lum_path = self.meta.lum_paths[key]
        images = np.array([
            sample_lum(path) for path in sorted(lum_path.iterdir())[sampling]
        ])
        # NOTE(review): `vertical_splits or ...` / `center_crop_fraction or ...`
        # treat 0 as "use the instance default" — presumably intended, since
        # vertical_splits >= 1 is asserted at construction.
        return split(
            images,
            outwidth,
            vertical_splits or self.vertical_splits,
            center_crop_fraction or self.center_crop_fraction,
        )

    def cartesian_flow(
        self,
        key: int,
        vertical_splits: Optional[int] = None,
        outwidth: int = 417,
        center_crop_fraction: Optional[float] = None,
        sampling: slice = slice(None, None, None),
    ) -> np.ndarray:
        """Return the cartesian flow of a fly eye rendered flow.

        Args:
            key: Index of the sequence.
            vertical_splits: Number of vertical splits to apply.
            outwidth: Output width of the flow.
            center_crop_fraction: Fraction of the image to keep after cropping.
            sampling: Slice object for sampling frames.

        Returns:
            Numpy array containing the cartesian flow.
        """
        key = self.original_sequence_index(key)
        flow_path = self.meta.flow_paths[key]
        flow = np.array([
            sample_flow(path) for path in sorted(flow_path.iterdir())[sampling]
        ])

        return split(
            flow,
            outwidth,
            vertical_splits or self.vertical_splits,
            center_crop_fraction or self.center_crop_fraction,
        )

    def cartesian_depth(
        self,
        key: int,
        vertical_splits: Optional[int] = None,
        outwidth: int = 417,
        center_crop_fraction: Optional[float] = None,
        sampling: slice = slice(1, None, None),
    ) -> np.ndarray:
        """Return the cartesian depth of a fly eye rendered depth.

        Args:
            key: Index of the sequence.
            vertical_splits: Number of vertical splits to apply.
            outwidth: Output width of the depth.
            center_crop_fraction: Fraction of the image to keep after cropping.
            sampling: Slice object for sampling frames.

        Returns:
            Numpy array containing the cartesian depth.
        """
        key = self.original_sequence_index(key)
        # local name mirrors cartesian_flow; it holds the depth path here
        flow_path = self.meta.depth_paths[key]
        depth = np.array([
            sample_depth(path) for path in sorted(flow_path.iterdir())[sampling]
        ])

        return split(
            depth,
            outwidth,
            vertical_splits or self.vertical_splits,
            center_crop_fraction or self.center_crop_fraction,
        )

    def original_train_and_validation_indices(self) -> Tuple[List[int], List[int]]:
        """Get original training and validation indices for the dataloader.

        Returns:
            Tuple containing lists of train and validation indices.
        """
        # Delegates to the module-level helper of the same name.
        return original_train_and_validation_indices(self)

augment property writable

augment

Get the current augmentation state.

init_cache

init_cache()

Initialize the cache with preprocessed sequences.

Source code in flyvision/datasets/sintel.py
337
338
339
340
341
342
343
344
345
346
def init_cache(self) -> None:
    """Initialize the cache with preprocessed sequences."""
    self.cached_sequences = [
        {
            key: torch.tensor(val, dtype=torch.float32)
            for key, val in self.rendered(seq_id).items()
            if key in self.data_keys
        }
        for seq_id in range(len(self))
    ]

__setattr__

__setattr__(name, value)

Custom attribute setter to handle special cases and update augmentation.

Parameters:

Name Type Description Default
name str

Name of the attribute to set.

required
value Any

Value to set the attribute to.

required

Raises:

Type Description
AttributeError

If trying to change framerate or rendered initialization attributes.

Source code in flyvision/datasets/sintel.py
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
def __setattr__(self, name: str, value: Any) -> None:
    """Custom attribute setter to handle special cases and update augmentation.

    Args:
        name: Name of the attribute to set.
        value: Value to set the attribute to.

    Raises:
        AttributeError: If trying to change framerate or rendered initialization
            attributes.
    """
    # some changes have no effect cause they are fixed, or set by the pre-rendering
    if name == "framerate":
        raise AttributeError("cannot change framerate")
    if hasattr(self, "rendered") and name in self.rendered.config:
        raise AttributeError("cannot change attribute of rendered initialization")
    super().__setattr__(name, value)
    # also update augmentation because it may already be initialized
    if getattr(self, "_augmentations_are_initialized", False):
        self.update_augmentation(name, value)

init_augmentation

init_augmentation()

Initialize augmentation callables.

Source code in flyvision/datasets/sintel.py
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
def init_augmentation(self) -> None:
    """Initialize augmentation callables."""
    self.temporal_crop = CropFrames(
        self.n_frames, all_frames=self.all_frames, random=self.random_temporal_crop
    )
    self.jitter = ContrastBrightness(
        contrast_std=self.contrast_std, brightness_std=self.brightness_std
    )
    self.rotate = HexRotate(self.extent, p_rot=self.p_rot)
    self.flip = HexFlip(self.extent, p_flip=self.p_flip, flip_axes=self.flip_axes)
    self.noise = PixelNoise(self.gaussian_white_noise)

    self.piecewise_resample = Interpolate(
        self.original_framerate, 1 / self.dt, mode="nearest-exact"
    )
    self.linear_interpolate = Interpolate(
        self.original_framerate,
        1 / self.dt,
        mode="linear",
    )
    self.gamma_correct = GammaCorrection(1, self.gamma_std)

update_augmentation

update_augmentation(name, value)

Update augmentation parameters based on attribute changes.

Parameters:

Name Type Description Default
name str

Name of the attribute that changed.

required
value Any

New value of the attribute.

required
Source code in flyvision/datasets/sintel.py
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
def update_augmentation(self, name: str, value: Any) -> None:
    """Update augmentation parameters based on attribute changes.

    Args:
        name: Name of the attribute that changed.
        value: New value of the attribute.
    """
    if name == "dt":
        self.piecewise_resample.target_framerate = 1 / value
        self.linear_interpolate.target_framerate = 1 / value
    if name in ["all_frames", "random_temporal_crop"]:
        self.temporal_crop.all_frames = value
        self.temporal_crop.random = value
    if name in ["contrast_std", "brightness_std"]:
        self.jitter.contrast_std = value
        self.jitter.brightness_std = value
    if name == "p_rot":
        self.rotate.p_rot = value
    if name == "p_flip":
        self.flip.p_flip = value
    if name == "gaussian_white_noise":
        self.noise.std = value
    if name == "gamma_std":
        self.gamma_correct.std = value

set_augmentation_params

set_augmentation_params(
    n_rot=None,
    flip_axis=None,
    contrast_factor=None,
    brightness_factor=None,
    gaussian_white_noise=None,
    gamma=None,
    start_frame=None,
    total_sequence_length=None,
)

Set augmentation callable parameters.

Info

Called for each call of get_item.

Parameters:

Name Type Description Default
n_rot Optional[int]

Number of rotations to apply.

None
flip_axis Optional[int]

Axis to flip over.

None
contrast_factor Optional[float]

Contrast factor for jitter augmentation.

None
brightness_factor Optional[float]

Brightness factor for jitter augmentation.

None
gaussian_white_noise Optional[float]

Standard deviation for noise augmentation.

None
gamma Optional[float]

Gamma value for gamma correction.

None
start_frame Optional[int]

Starting frame for temporal crop.

None
total_sequence_length Optional[int]

Total length of the sequence.

None
Source code in flyvision/datasets/sintel.py
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
def set_augmentation_params(
    self,
    n_rot: Optional[int] = None,
    flip_axis: Optional[int] = None,
    contrast_factor: Optional[float] = None,
    brightness_factor: Optional[float] = None,
    gaussian_white_noise: Optional[float] = None,
    gamma: Optional[float] = None,
    start_frame: Optional[int] = None,
    total_sequence_length: Optional[int] = None,
) -> None:
    """Set augmentation callable parameters.

    Info:
        Called for each call of get_item.

    Args:
        n_rot: Number of rotations to apply.
        flip_axis: Axis to flip over.
        contrast_factor: Contrast factor for jitter augmentation.
        brightness_factor: Brightness factor for jitter augmentation.
        gaussian_white_noise: Standard deviation for noise augmentation.
        gamma: Gamma value for gamma correction.
        start_frame: Starting frame for temporal crop.
        total_sequence_length: Total length of the sequence.
    """
    if not self.fix_augmentation_params:
        self.rotate.set_or_sample(n_rot)
        self.flip.set_or_sample(flip_axis)
        self.jitter.set_or_sample(contrast_factor, brightness_factor)
        self.noise.set_or_sample(gaussian_white_noise)
        self.gamma_correct.set_or_sample(gamma)
        self.temporal_crop.set_or_sample(
            start=start_frame, total_sequence_length=total_sequence_length
        )

get_item

get_item(key)

Return a dataset sample.

Parameters:

Name Type Description Default
key int

Index of the sample to retrieve.

required

Returns:

Type Description
Dict[str, Tensor]

Dictionary containing the augmented sample data.

Source code in flyvision/datasets/sintel.py
461
462
463
464
465
466
467
468
469
470
def get_item(self, key: int) -> Dict[str, torch.Tensor]:
    """Return a dataset sample.

    Args:
        key: Index of the sample to retrieve.

    Returns:
        Dictionary containing the augmented sample data.
    """
    return self.apply_augmentation(self.cached_sequences[key])

augmentation

augmentation(abool)

Context manager to turn augmentation on or off in a code block.

Parameters:

Name Type Description Default
abool bool

Boolean value to set augmentation state.

required
Example
with dataset.augmentation(True):
    for i, data in enumerate(dataloader):
        ...  # all data is augmented
Source code in flyvision/datasets/sintel.py
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
@contextmanager
def augmentation(self, abool: bool):
    """Context manager to turn augmentation on or off in a code block.

    Args:
        abool: Boolean value to set augmentation state.

    Example:
        ```python
        with dataset.augmentation(True):
            for i, data in enumerate(dataloader):
                ...  # all data is augmented
        ```
    """
    augmentations = [
        "temporal_crop",
        "jitter",
        "rotate",
        "flip",
        "noise",
        "piecewise_resample",
        "linear_interpolate",
        "gamma_correct",
    ]
    states = {key: getattr(self, key).augment for key in augmentations}
    _augment = self.augment
    try:
        self.augment = abool
        yield
    finally:
        self.augment = _augment
        for key in augmentations:
            getattr(self, key).augment = states[key]

apply_augmentation

apply_augmentation(
    data,
    n_rot=None,
    flip_axis=None,
    contrast_factor=None,
    brightness_factor=None,
    gaussian_white_noise=None,
    gamma=None,
)

Apply augmentation to a sample from the dataset.

Parameters:

Name Type Description Default
data Dict[str, Tensor]

Dictionary containing the sample data.

required
n_rot Optional[int]

Number of rotations to apply.

None
flip_axis Optional[int]

Axis to flip over.

None
contrast_factor Optional[float]

Contrast factor for jitter augmentation.

None
brightness_factor Optional[float]

Brightness factor for jitter augmentation.

None
gaussian_white_noise Optional[float]

Standard deviation for noise augmentation.

None
gamma Optional[float]

Gamma value for gamma correction.

None

Returns:

Type Description
Dict[str, Tensor]

Dictionary containing the augmented sample data.

Source code in flyvision/datasets/sintel.py
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
def apply_augmentation(
    self,
    data: Dict[str, torch.Tensor],
    n_rot: Optional[int] = None,
    flip_axis: Optional[int] = None,
    contrast_factor: Optional[float] = None,
    brightness_factor: Optional[float] = None,
    gaussian_white_noise: Optional[float] = None,
    gamma: Optional[float] = None,
) -> Dict[str, torch.Tensor]:
    """Apply augmentation to a sample from the dataset.

    Args:
        data: Dictionary containing the sample data.
        n_rot: Number of rotations to apply.
        flip_axis: Axis to flip over.
        contrast_factor: Contrast factor for jitter augmentation.
        brightness_factor: Brightness factor for jitter augmentation.
        gaussian_white_noise: Standard deviation for noise augmentation.
        gamma: Gamma value for gamma correction.

    Returns:
        Dictionary containing the augmented sample data.
    """

    self.set_augmentation_params(
        n_rot=n_rot,
        flip_axis=flip_axis,
        contrast_factor=contrast_factor,
        brightness_factor=brightness_factor,
        gaussian_white_noise=gaussian_white_noise,
        gamma=gamma,
        start_frame=None,
        total_sequence_length=data["lum"].shape[0],
    )

    def transform_lum(lum):
        return self.piecewise_resample(
            self.rotate(
                self.flip(
                    self.jitter(
                        self.noise(self.temporal_crop(lum)),
                    ),
                )
            )
        )

    def transform_target(target):
        if self.interpolate:
            return self.linear_interpolate(
                self.rotate(self.flip(self.temporal_crop(target)))
            )
        return self.piecewise_resample(
            self.rotate(self.flip(self.temporal_crop(target)))
        )

    return {
        **{"lum": transform_lum(data["lum"])},
        **{
            target: transform_target(data[target])
            for target in self.tasks
            if target in ["flow", "depth"]
        },
    }

original_sequence_index

original_sequence_index(key)

Get the original sequence index from an index of the split.

Parameters:

Name Type Description Default
key int

Index of the split.

required

Returns:

Type Description
int

Original sequence index.

Raises:

Type Description
ValueError

If the key is not found in splits.

Source code in flyvision/datasets/sintel.py
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
def original_sequence_index(self, key: int) -> int:
    """Map an index of the split back to its original sequence index.

    Args:
        key: Index of the split.

    Returns:
        Original sequence index.

    Raises:
        ValueError: If the key is not found in splits.
    """
    # Scan the index -> splits mapping for the first entry containing `key`.
    matched_index = next(
        (
            sequence_index
            for sequence_index, splits in self.meta.sequence_index_to_splits.items()
            if key in splits
        ),
        None,
    )
    if matched_index is None:
        raise ValueError(f"key {key} not found in splits")
    return matched_index

cartesian_sequence

cartesian_sequence(
    key,
    vertical_splits=None,
    outwidth=716,
    center_crop_fraction=None,
    sampling=slice(1, None, None),
)

Return the cartesian sequence of a fly eye rendered sequence.

Parameters:

Name Type Description Default
key int

Index of the sequence.

required
vertical_splits Optional[int]

Number of vertical splits to apply.

None
outwidth int

Output width of the sequence.

716
center_crop_fraction Optional[float]

Fraction of the image to keep after cropping.

None
sampling slice

Slice object for sampling frames.

slice(1, None, None)

Returns:

Type Description
ndarray

Numpy array containing the cartesian sequence.

Source code in flyvision/datasets/sintel.py
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
def cartesian_sequence(
    self,
    key: int,
    vertical_splits: Optional[int] = None,
    outwidth: int = 716,
    center_crop_fraction: Optional[float] = None,
    sampling: slice = slice(1, None, None),
) -> np.ndarray:
    """Return the cartesian sequence of a fly eye rendered sequence.

    Args:
        key: Index of the sequence.
        vertical_splits: Number of vertical splits to apply.
        outwidth: Output width of the sequence.
        center_crop_fraction: Fraction of the image to keep after cropping.
        sampling: Slice object for sampling frames.

    Returns:
        Numpy array containing the cartesian sequence.
    """
    # The requested index may refer to one of several vertical splits of
    # a single original scene -- recover that scene first.
    original_key = self.original_sequence_index(key)
    frame_paths = sorted(self.meta.lum_paths[original_key].iterdir())[sampling]
    images = np.array([sample_lum(frame_path) for frame_path in frame_paths])
    n_splits = vertical_splits or self.vertical_splits
    crop_fraction = center_crop_fraction or self.center_crop_fraction
    return split(images, outwidth, n_splits, crop_fraction)

cartesian_flow

cartesian_flow(
    key,
    vertical_splits=None,
    outwidth=417,
    center_crop_fraction=None,
    sampling=slice(None, None, None),
)

Return the cartesian flow of a fly eye rendered flow.

Parameters:

Name Type Description Default
key int

Index of the sequence.

required
vertical_splits Optional[int]

Number of vertical splits to apply.

None
outwidth int

Output width of the flow.

417
center_crop_fraction Optional[float]

Fraction of the image to keep after cropping.

None
sampling slice

Slice object for sampling frames.

slice(None, None, None)

Returns:

Type Description
ndarray

Numpy array containing the cartesian flow.

Source code in flyvision/datasets/sintel.py
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
def cartesian_flow(
    self,
    key: int,
    vertical_splits: Optional[int] = None,
    outwidth: int = 417,
    center_crop_fraction: Optional[float] = None,
    sampling: slice = slice(None, None, None),
) -> np.ndarray:
    """Return the cartesian flow of a fly eye rendered flow.

    Args:
        key: Index of the sequence.
        vertical_splits: Number of vertical splits to apply.
        outwidth: Output width of the flow.
        center_crop_fraction: Fraction of the image to keep after cropping.
        sampling: Slice object for sampling frames.

    Returns:
        Numpy array containing the cartesian flow.
    """
    # Recover the original (pre-split) scene index for this key.
    original_key = self.original_sequence_index(key)
    frame_paths = sorted(self.meta.flow_paths[original_key].iterdir())[sampling]
    flow_frames = np.array([sample_flow(frame_path) for frame_path in frame_paths])
    n_splits = vertical_splits or self.vertical_splits
    crop_fraction = center_crop_fraction or self.center_crop_fraction
    return split(flow_frames, outwidth, n_splits, crop_fraction)

cartesian_depth

cartesian_depth(
    key,
    vertical_splits=None,
    outwidth=417,
    center_crop_fraction=None,
    sampling=slice(1, None, None),
)

Return the cartesian depth of a fly eye rendered depth.

Parameters:

Name Type Description Default
key int

Index of the sequence.

required
vertical_splits Optional[int]

Number of vertical splits to apply.

None
outwidth int

Output width of the depth.

417
center_crop_fraction Optional[float]

Fraction of the image to keep after cropping.

None
sampling slice

Slice object for sampling frames.

slice(1, None, None)

Returns:

Type Description
ndarray

Numpy array containing the cartesian depth.

Source code in flyvision/datasets/sintel.py
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
def cartesian_depth(
    self,
    key: int,
    vertical_splits: Optional[int] = None,
    outwidth: int = 417,
    center_crop_fraction: Optional[float] = None,
    sampling: slice = slice(1, None, None),
) -> np.ndarray:
    """Return the cartesian depth of a fly eye rendered depth.

    Args:
        key: Index of the sequence.
        vertical_splits: Number of vertical splits to apply.
        outwidth: Output width of the depth.
        center_crop_fraction: Fraction of the image to keep after cropping.
        sampling: Slice object for sampling frames.

    Returns:
        Numpy array containing the cartesian depth.
    """
    # Recover the original (pre-split) scene index for this key.
    key = self.original_sequence_index(key)
    # Renamed from the misleading `flow_path` (copy-paste from cartesian_flow):
    # this holds depth frame paths.
    depth_path = self.meta.depth_paths[key]
    depth = np.array([
        sample_depth(path) for path in sorted(depth_path.iterdir())[sampling]
    ])

    return split(
        depth,
        outwidth,
        vertical_splits or self.vertical_splits,
        center_crop_fraction or self.center_crop_fraction,
    )

original_train_and_validation_indices

original_train_and_validation_indices()

Get original training and validation indices for the dataloader.

Returns:

| Type | Description |
| --- | --- |
| `Tuple[List[int], List[int]]` | Tuple containing lists of train and validation indices. |

Source code in flyvision/datasets/sintel.py
714
715
716
717
718
719
720
def original_train_and_validation_indices(self) -> Tuple[List[int], List[int]]:
    """Get original training and validation indices for the dataloader.

    Returns:
        Tuple containing lists of train and validation indices.
    """
    # Not recursion: inside this method body the bare name resolves to the
    # module-level function `original_train_and_validation_indices` (the method
    # name only shadows it on instances/the class), so this delegates the work
    # to that module-level helper, passing the dataset instance.
    return original_train_and_validation_indices(self)