Skip to content

Hex Augmentations

Geometric Transformations

flyvis.datasets.augmentation.hex.HexRotate

Bases: Augmentation

Rotate a sequence of regular hex-lattices by a multiple of 60 degrees.

Parameters:

Name Type Description Default
extent int

Extent of the regular hexagonal grid in columns.

required
n_rot int

Number of 60 degree rotations. 0-5.

0
p_rot float

Probability of rotating. If None, no rotation is performed.

0.5

Attributes:

Name Type Description
extent int

Extent of the regular hexagonal grid in columns.

n_rot int

Number of 60 degree rotations. 0-5.

p_rot float

Probability of rotating.

permutation_indices dict

Cached indices for rotation.

rotation_matrices dict

Cached rotation matrices for rotation.

Source code in flyvis/datasets/augmentation/hex.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
class HexRotate(Augmentation):
    """Rotate a sequence of regular hex-lattices by a multiple of 60 degrees.

    Rotation matrices and hexal permutation indices for all six possible
    rotations are precomputed once in the constructor.

    Args:
        extent: Extent of the regular hexagonal grid in columns.
        n_rot: Number of 60 degree rotations. 0-5.
        p_rot: Probability of rotating. If None, no rotation is performed.

    Attributes:
        extent (int): Extent of the regular hexagonal grid in columns.
        n_rot (int): Number of 60 degree rotations, stored modulo 6.
        p_rot (float): Probability of rotating.
        permutation_indices (dict): Cached indices for rotation, keyed by the
            number of rotations (0-5).
        rotation_matrices (dict): Cached rotation matrices for rotation, keyed
            by the number of rotations; each entry is ``[2D matrix, 3D matrix]``.
    """

    def __init__(self, extent: int, n_rot: int = 0, p_rot: float = 0.5) -> None:
        self.extent = extent
        self.rotation_matrices: dict = {}
        self.permutation_indices: dict = {}
        for step in range(6):
            radians = step * 60 * np.pi / 180
            planar = rotation_matrix(radians, three_d=False)
            spatial = rotation_matrix(radians, three_d=True)
            self.rotation_matrices[step] = [planar, spatial]
            self.permutation_indices[step] = rotation_permutation_index(extent, step)
        self.n_rot = n_rot
        self.p_rot = p_rot
        self.set_or_sample(n_rot)

    @property
    def n_rot(self) -> int:
        """Current number of 60 degree rotations (always within 0-5)."""
        return self._n_rot

    @n_rot.setter
    def n_rot(self, n_rot: int) -> None:
        """Store the number of rotations, wrapped into the range 0-5."""
        self._n_rot = n_rot % 6

    def rotate(self, seq: torch.Tensor) -> torch.Tensor:
        """Apply the current rotation to a sequence on a regular hexagonal lattice.

        Args:
            seq: Sequence of shape (frames, dims, hexals).

        Returns:
            Rotated sequence of the same shape as seq.
        """
        n_dims = seq.shape[-2]
        # Move every hexal to its rotated position on the lattice.
        rotated = seq[..., self.permutation_indices[self.n_rot]]
        if n_dims <= 1:
            # A single channel only needs the spatial permutation.
            return rotated
        # Multi-channel values are additionally rotated: entry 0 of the cache
        # is the 2D matrix, entry 1 the 3D matrix.
        return self.rotation_matrices[self.n_rot][n_dims - 2] @ rotated

    def transform(self, seq: torch.Tensor, n_rot: Optional[int] = None) -> torch.Tensor:
        """Rotate a sequence on a regular hexagonal lattice.

        Args:
            seq: Sequence of shape (frames, dims, hexals).
            n_rot: Optional number of rotations to apply; replaces the stored one.

        Returns:
            Rotated sequence of the same shape as seq (the input itself when
            the number of rotations is 0).
        """
        if n_rot is not None:
            self.n_rot = n_rot
        return self.rotate(seq) if self.n_rot > 0 else seq

    def set_or_sample(self, n_rot: Optional[int] = None) -> None:
        """Set the number of rotations, or draw it at random.

        Args:
            n_rot: Number of rotations to set. If None, rotate with probability
                ``p_rot`` by a uniformly drawn 1-5 steps, otherwise use 0.
        """
        if n_rot is not None:
            self.n_rot = n_rot
            return
        # Short-circuit keeps the RNG untouched when p_rot is falsy.
        should_rotate = bool(self.p_rot) and self.p_rot > np.random.rand()
        self.n_rot = np.random.randint(low=1, high=6) if should_rotate else 0

n_rot property writable

n_rot

Get the number of rotations.

rotate

rotate(seq)

Rotate a sequence on a regular hexagonal lattice.

Parameters:

Name Type Description Default
seq Tensor

Sequence of shape (frames, dims, hexals).

required

Returns:

Type Description
Tensor

Rotated sequence of the same shape as seq.

Source code in flyvis/datasets/augmentation/hex.py
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def rotate(self, seq: torch.Tensor) -> torch.Tensor:
    """Apply the current rotation to a sequence on a regular hexagonal lattice.

    Args:
        seq: Sequence of shape (frames, dims, hexals).

    Returns:
        Rotated sequence of the same shape as seq.
    """
    n_dims = seq.shape[-2]
    # Move every hexal to its rotated position on the lattice.
    rotated = seq[..., self.permutation_indices[self.n_rot]]
    if n_dims <= 1:
        # A single channel only needs the spatial permutation.
        return rotated
    # Entry 0 of the cache is the 2D matrix, entry 1 the 3D matrix.
    return self.rotation_matrices[self.n_rot][n_dims - 2] @ rotated

transform

transform(seq, n_rot=None)

Rotate a sequence on a regular hexagonal lattice.

Parameters:

Name Type Description Default
seq Tensor

Sequence of shape (frames, dims, hexals).

required
n_rot Optional[int]

Optional number of rotations to apply.

None

Returns:

Type Description
Tensor

Rotated sequence of the same shape as seq.

Source code in flyvis/datasets/augmentation/hex.py
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
def transform(self, seq: torch.Tensor, n_rot: Optional[int] = None) -> torch.Tensor:
    """Rotate a sequence on a regular hexagonal lattice.

    Args:
        seq: Sequence of shape (frames, dims, hexals).
        n_rot: Optional number of rotations to apply; replaces the stored one.

    Returns:
        Rotated sequence of the same shape as seq (the input itself when the
        number of rotations is 0).
    """
    if n_rot is not None:
        self.n_rot = n_rot
    return self.rotate(seq) if self.n_rot > 0 else seq

set_or_sample

set_or_sample(n_rot=None)

Set or sample the number of rotations.

Parameters:

Name Type Description Default
n_rot Optional[int]

Number of rotations to set. If None, sample randomly.

None
Source code in flyvis/datasets/augmentation/hex.py
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
def set_or_sample(self, n_rot: Optional[int] = None) -> None:
    """Set the number of rotations, or draw it at random.

    Args:
        n_rot: Number of rotations to set. If None, rotate with probability
            ``p_rot`` by a uniformly drawn 1-5 steps, otherwise use 0.
    """
    if n_rot is not None:
        self.n_rot = n_rot
        return
    # Short-circuit keeps the RNG untouched when p_rot is falsy.
    should_rotate = bool(self.p_rot) and self.p_rot > np.random.rand()
    self.n_rot = np.random.randint(low=1, high=6) if should_rotate else 0

flyvis.datasets.augmentation.hex.HexFlip

Bases: Augmentation

Flip a sequence of regular hex-lattices across one of three hex-axes.

Parameters:

Name Type Description Default
extent int

Extent of the regular hexagonal grid.

required
axis int

Flipping axis. 0 corresponds to no flipping.

0
p_flip float

Probability of flipping. If None, no flipping is performed.

0.5
flip_axes List[int]

List of valid flipping axes. Can contain 0, 1, 2, 3.

[0, 1, 2, 3]

Attributes:

Name Type Description
extent int

Extent of the regular hexagonal grid.

axis int

Flipping axis.

p_flip float

Probability of flipping.

flip_axes List[int]

List of valid flipping axes.

permutation_indices dict

Cached indices for flipping.

rotation_matrices dict

Cached rotation matrices for flipping.

Note

The set of usable flip axes can be limited (see flip_axes) to avoid redundant transformations when rotation and flipping are combined. For example, flipping across the 1st axis is equivalent to rotating by 240 degrees and flipping across the 2nd axis.

Source code in flyvis/datasets/augmentation/hex.py
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
class HexFlip(Augmentation):
    """Flip a sequence of regular hex-lattices across one of three hex-axes.

    Args:
        extent: Extent of the regular hexagonal grid.
        axis: Flipping axis. 0 corresponds to no flipping.
        p_flip: Probability of flipping. If None, no flipping is performed.
        flip_axes: List of valid flipping axes. Can contain 0, 1, 2, 3.
            Defaults to ``[0, 1, 2, 3]``. Random sampling only draws from the
            non-zero entries of this list.

    Attributes:
        extent (int): Extent of the regular hexagonal grid.
        axis (int): Flipping axis.
        p_flip (float): Probability of flipping.
        flip_axes (List[int]): List of valid flipping axes.
        permutation_indices (dict): Cached indices for flipping, keyed by axis 1-3.
        rotation_matrices (dict): Cached rotation matrices for flipping, keyed by
            axis 1-3; each entry is ``[2D matrix, 3D matrix]``.

    Note:
        The set of usable axes can be limited via ``flip_axes`` to avoid
        redundant transformations when rotation and flipping are combined.
        For example, flipping across the 1st axis is equivalent to rotating by
        240 degrees and flipping across the 2nd axis.
    """

    def __init__(
        self,
        extent: int,
        axis: int = 0,
        p_flip: float = 0.5,
        flip_axes: Optional[List[int]] = None,
    ) -> None:
        self.extent = extent
        self.rotation_matrices: dict = {}
        self.permutation_indices: dict = {}
        for n, angle in enumerate([90, 150, 210], 1):
            R2 = flip_matrix(np.radians(angle), three_d=False)
            R3 = flip_matrix(np.radians(angle), three_d=True)
            self.rotation_matrices[n] = [R2, R3]
            self.permutation_indices[n] = flip_permutation_index(extent, n)
        # A fresh list per instance: a mutable default would be shared by
        # every HexFlip, and copying avoids aliasing the caller's list.
        self.flip_axes = [0, 1, 2, 3] if flip_axes is None else list(flip_axes)
        self.axis = axis
        self.p_flip = p_flip
        self.set_or_sample(axis)

    @property
    def axis(self) -> int:
        """Get the flipping axis."""
        return self._axis

    @axis.setter
    def axis(self, axis: int) -> None:
        """Set the flipping axis; it must be one of ``flip_axes``."""
        assert (
            axis in self.flip_axes
        ), f"{axis} is not a valid axis. Must be in {self.flip_axes}."
        self._axis = axis

    def flip(self, seq: torch.Tensor) -> torch.Tensor:
        """Flip a sequence on a regular hexagonal lattice.

        Args:
            seq: Sequence of shape (frames, dims, hexals).

        Returns:
            Flipped sequence of the same shape as seq.
        """
        dims = seq.shape[-2]
        seq = seq[..., self.permutation_indices[self.axis]]
        if dims > 1:
            # Multi-channel values are also mirrored (index 0: 2D, 1: 3D matrix).
            seq = self.rotation_matrices[self.axis][dims - 2] @ seq
        return seq

    def transform(self, seq: torch.Tensor, axis: Optional[int] = None) -> torch.Tensor:
        """Flip a sequence on a regular hexagonal lattice.

        Args:
            seq: Sequence of shape (frames, dims, hexals).
            axis: Optional flipping axis to apply; replaces the stored one.

        Returns:
            Flipped sequence of the same shape as seq (the input itself when
            the axis is 0).
        """
        if axis is not None:
            self.axis = axis
        if self.axis > 0:
            return self.flip(seq=seq)
        return seq

    def set_or_sample(self, axis: Optional[int] = None) -> None:
        """Set or sample the flipping axis.

        Args:
            axis: Flipping axis to set. If None, with probability ``p_flip``
                draw uniformly among the non-zero entries of ``flip_axes``;
                otherwise (or if there are none) use 0.
        """
        if axis is None:
            axis = 0
            if self.p_flip and self.p_flip > np.random.rand():
                # Only draw axes that are actually allowed. The previous
                # randint(1, max(flip_axes) + 1) could return an excluded axis
                # (e.g. 1 for flip_axes=[0, 2]). For [1, 2, 3] this draws the
                # same values from the RNG stream as that call did.
                candidates = sorted({a for a in self.flip_axes if a != 0})
                if candidates:
                    axis = candidates[np.random.randint(low=0, high=len(candidates))]
        self.axis = axis

axis property writable

axis

Get the flipping axis.

flip

flip(seq)

Flip a sequence on a regular hexagonal lattice.

Parameters:

Name Type Description Default
seq Tensor

Sequence of shape (frames, dims, hexals).

required

Returns:

Type Description
Tensor

Flipped sequence of the same shape as seq.

Source code in flyvis/datasets/augmentation/hex.py
166
167
168
169
170
171
172
173
174
175
176
177
178
179
def flip(self, seq: torch.Tensor) -> torch.Tensor:
    """Mirror a sequence on a regular hexagonal lattice across the current axis.

    Args:
        seq: Sequence of shape (frames, dims, hexals).

    Returns:
        Flipped sequence of the same shape as seq.
    """
    n_dims = seq.shape[-2]
    # Move every hexal to its mirrored position on the lattice.
    mirrored = seq[..., self.permutation_indices[self.axis]]
    if n_dims <= 1:
        return mirrored
    # Entry 0 of the cache is the 2D matrix, entry 1 the 3D matrix.
    return self.rotation_matrices[self.axis][n_dims - 2] @ mirrored

transform

transform(seq, axis=None)

Flip a sequence on a regular hexagonal lattice.

Parameters:

Name Type Description Default
seq Tensor

Sequence of shape (frames, dims, hexals).

required
axis Optional[int]

Optional flipping axis to apply.

None

Returns:

Type Description
Tensor

Flipped sequence of the same shape as seq.

Source code in flyvis/datasets/augmentation/hex.py
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
def transform(self, seq: torch.Tensor, axis: Optional[int] = None) -> torch.Tensor:
    """Flip a sequence on a regular hexagonal lattice.

    Args:
        seq: Sequence of shape (frames, dims, hexals).
        axis: Optional flipping axis to apply; replaces the stored one.

    Returns:
        Flipped sequence of the same shape as seq (the input itself when the
        axis is 0).
    """
    if axis is not None:
        self.axis = axis
    return self.flip(seq=seq) if self.axis > 0 else seq

set_or_sample

set_or_sample(axis=None)

Set or sample the flipping axis.

Parameters:

Name Type Description Default
axis Optional[int]

Flipping axis to set. If None, sample randomly.

None
Source code in flyvis/datasets/augmentation/hex.py
197
198
199
200
201
202
203
204
205
206
207
208
209
def set_or_sample(self, axis: Optional[int] = None) -> None:
    """Set or sample the flipping axis.

    Args:
        axis: Flipping axis to set. If None, with probability ``p_flip`` draw
            uniformly among the non-zero entries of ``flip_axes``; otherwise
            (or if there are none) use 0.
    """
    if axis is None:
        axis = 0
        if self.p_flip and self.p_flip > np.random.rand():
            # Only draw axes that are actually allowed. The previous
            # randint(1, max(flip_axes) + 1) could return an excluded axis
            # (e.g. 1 for flip_axes=[0, 2]). For [1, 2, 3] this draws the same
            # values from the RNG stream as that call did.
            candidates = sorted({a for a in self.flip_axes if a != 0})
            if candidates:
                axis = candidates[np.random.randint(low=0, high=len(candidates))]
    self.axis = axis

Intensity Transformations

flyvis.datasets.augmentation.hex.ContrastBrightness

Bases: Augmentation

Contrast transformation.

The transformation is described as:

pixel = max(0, contrast_factor * (pixel - 0.5) + 0.5
            + contrast_factor * brightness_factor)

Parameters:

Name Type Description Default
contrast_factor Optional[float]

Contrast factor.

None
brightness_factor Optional[float]

Brightness factor.

None
contrast_std float

Standard deviation of the contrast factor.

0.2
brightness_std float

Standard deviation of the brightness factor.

0.1

Attributes:

Name Type Description
contrast_std float

Standard deviation of the contrast factor.

brightness_std float

Standard deviation of the brightness factor.

contrast_factor float

Contrast factor.

brightness_factor float

Brightness factor.

Source code in flyvis/datasets/augmentation/hex.py
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
class ContrastBrightness(Augmentation):
    """Contrast transformation.

    The transformation is described as:
    ```python
    pixel = max(0, contrast_factor * (pixel - 0.5) + 0.5
                + contrast_factor * brightness_factor)
    ```

    When ``contrast_factor`` is None the sequence is returned unchanged and
    the brightness factor is not applied either.

    Args:
        contrast_factor: Contrast factor. If None, sampled from ``contrast_std``.
        brightness_factor: Brightness factor. If None, sampled from
            ``brightness_std``.
        contrast_std: Standard deviation of the contrast factor.
        brightness_std: Standard deviation of the brightness factor.

    Attributes:
        contrast_std (float): Standard deviation of the contrast factor.
        brightness_std (float): Standard deviation of the brightness factor.
        contrast_factor (float): Contrast factor.
        brightness_factor (float): Brightness factor.
    """

    def __init__(
        self,
        contrast_factor: Optional[float] = None,
        brightness_factor: Optional[float] = None,
        contrast_std: float = 0.2,
        brightness_std: float = 0.1,
    ) -> None:
        self.contrast_std = contrast_std
        self.brightness_std = brightness_std
        self.set_or_sample(contrast_factor, brightness_factor)

    def transform(self, seq: torch.Tensor) -> torch.Tensor:
        """Scale contrast around 0.5, shift brightness and clamp at 0.

        Args:
            seq: Input sequence.

        Returns:
            Transformed sequence, or ``seq`` itself if no contrast factor is set.
        """
        contrast = self.contrast_factor
        if contrast is None:
            return seq
        scaled = contrast * (seq - 0.5) + 0.5
        return (scaled + contrast * self.brightness_factor).clamp(0)

    def set_or_sample(
        self,
        contrast_factor: Optional[float] = None,
        brightness_factor: Optional[float] = None,
    ) -> None:
        """Set or sample contrast and brightness factors.

        Args:
            contrast_factor: Contrast factor to set. If None, draw
                ``exp(N(0, contrast_std))``; stays None when the std is falsy.
            brightness_factor: Brightness factor to set. If None, draw
                ``N(0, brightness_std)``; 0.0 when the std is falsy.
        """
        if contrast_factor is None and self.contrast_std:
            # Log-normal: close to N(1, std) for small std, skewed towards
            # high contrast for large std.
            # TODO: implement other sampling schemes
            contrast_factor = np.exp(np.random.normal(0, self.contrast_std))
        if brightness_factor is None:
            brightness_factor = 0.0
            if self.brightness_std:
                brightness_factor = np.random.normal(0, self.brightness_std)
        self.contrast_factor = contrast_factor
        self.brightness_factor = brightness_factor

transform

transform(seq)

Apply the transformation to a sequence.

Parameters:

Name Type Description Default
seq Tensor

Input sequence.

required

Returns:

Type Description
Tensor

Transformed sequence.

Source code in flyvis/datasets/augmentation/hex.py
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
def transform(self, seq: torch.Tensor) -> torch.Tensor:
    """Scale contrast around 0.5, shift brightness and clamp at 0.

    Args:
        seq: Input sequence.

    Returns:
        Transformed sequence, or ``seq`` itself if no contrast factor is set.
    """
    contrast = self.contrast_factor
    if contrast is None:
        return seq
    scaled = contrast * (seq - 0.5) + 0.5
    return (scaled + contrast * self.brightness_factor).clamp(0)

set_or_sample

set_or_sample(contrast_factor=None, brightness_factor=None)

Set or sample contrast and brightness factors.

Parameters:

Name Type Description Default
contrast_factor Optional[float]

Contrast factor to set. If None, sample randomly.

None
brightness_factor Optional[float]

Brightness factor to set. If None, sample randomly.

None
Source code in flyvis/datasets/augmentation/hex.py
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
def set_or_sample(
    self,
    contrast_factor: Optional[float] = None,
    brightness_factor: Optional[float] = None,
) -> None:
    """Set or sample contrast and brightness factors.

    Args:
        contrast_factor: Contrast factor to set. If None, draw
            ``exp(N(0, contrast_std))``; stays None when the std is falsy.
        brightness_factor: Brightness factor to set. If None, draw
            ``N(0, brightness_std)``; 0.0 when the std is falsy.
    """
    if contrast_factor is None and self.contrast_std:
        # Log-normal: close to N(1, std) for small std, skewed towards high
        # contrast for large std.
        # TODO: implement other sampling schemes
        contrast_factor = np.exp(np.random.normal(0, self.contrast_std))
    if brightness_factor is None:
        brightness_factor = 0.0
        if self.brightness_std:
            brightness_factor = np.random.normal(0, self.brightness_std)
    self.contrast_factor = contrast_factor
    self.brightness_factor = brightness_factor

flyvis.datasets.augmentation.hex.PixelNoise

Bases: Augmentation

Pixelwise gaussian noise.

The transformation is described as:

pixel = pixel + N(0, std)

Because the noise has a fixed standard deviation, the signal-to-noise ratio is higher for bright pixels than for dark ones.

Parameters:

Name Type Description Default
std float

Standard deviation of the gaussian noise.

0.08

Attributes:

Name Type Description
std float

Standard deviation of the gaussian noise.

Source code in flyvis/datasets/augmentation/hex.py
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
class PixelNoise(Augmentation):
    """Pixelwise gaussian noise.

    The transformation is described as:
    ```python
    pixel = pixel + N(0, std)
    ```

    Because the noise has a fixed standard deviation, the signal-to-noise
    ratio is higher for bright pixels than for dark ones. Results are
    clamped at 0.

    Args:
        std: Standard deviation of the gaussian noise.

    Attributes:
        std (float): Standard deviation of the gaussian noise.
    """

    def __init__(self, std: float = 0.08) -> None:
        self.std = std

    def transform(self, seq: torch.Tensor) -> torch.Tensor:
        """Add gaussian noise to every element and clamp at 0.

        Args:
            seq: Input sequence.

        Returns:
            Noisy sequence, or ``seq`` itself when ``std`` is falsy.
        """
        if not self.std:
            return seq
        noisy = seq + torch.randn_like(seq) * self.std
        return noisy.clamp(0)

    def set_or_sample(self, std: Optional[float] = None) -> None:
        """Replace the noise standard deviation.

        Args:
            std: Standard deviation of the gaussian noise to set.
                If None, no change is made.
        """
        if std is not None:
            self.std = std

transform

transform(seq)

Apply the transformation to a sequence.

Parameters:

Name Type Description Default
seq Tensor

Input sequence.

required

Returns:

Type Description
Tensor

Transformed sequence.

Source code in flyvis/datasets/augmentation/hex.py
310
311
312
313
314
315
316
317
318
319
320
321
322
def transform(self, seq: torch.Tensor) -> torch.Tensor:
    """Add gaussian noise to every element and clamp at 0.

    Args:
        seq: Input sequence.

    Returns:
        Noisy sequence, or ``seq`` itself when ``std`` is falsy.
    """
    if not self.std:
        return seq
    noisy = seq + torch.randn_like(seq) * self.std
    return noisy.clamp(0)

set_or_sample

set_or_sample(std=None)

Set or sample the standard deviation of the gaussian noise.

Parameters:

Name Type Description Default
std Optional[float]

Standard deviation of the gaussian noise to set. If None, no change is made.

None
Source code in flyvis/datasets/augmentation/hex.py
324
325
326
327
328
329
330
331
332
333
def set_or_sample(self, std: Optional[float] = None) -> None:
    """Replace the noise standard deviation.

    Args:
        std: Standard deviation of the gaussian noise to set.
            If None, no change is made.
    """
    if std is not None:
        self.std = std

flyvis.datasets.augmentation.hex.GammaCorrection

Bases: Augmentation

Gamma correction.

The transformation is described as:

pixel = pixel ** gamma

Gamma > 1 increases the contrast, gamma < 1 decreases the contrast.

Parameters:

Name Type Description Default
gamma float

Gamma value.

1
std Optional[float]

Standard deviation of the gamma value.

None

Attributes:

Name Type Description
gamma float

Gamma value.

std Optional[float]

Standard deviation of the gamma value.

Source code in flyvis/datasets/augmentation/hex.py
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
class GammaCorrection(Augmentation):
    """Gamma correction.

    The transformation is described as:
    ```python
    pixel = pixel ** gamma
    ```

    Gamma > 1 increases the contrast, gamma < 1 decreases the contrast.
    A gamma of 0 leaves the sequence unchanged.

    Args:
        gamma: Gamma value. Must not be negative.
        std: Standard deviation of the gamma value. If falsy, `set_or_sample`
            resets gamma to 1.0 instead of sampling.

    Attributes:
        gamma (float): Gamma value.
        std (Optional[float]): Standard deviation of the gamma value.
    """

    def __init__(self, gamma: float = 1, std: Optional[float] = None) -> None:
        self.gamma = gamma
        self.std = std

    def transform(self, seq: torch.Tensor) -> torch.Tensor:
        """Raise every element of the sequence to the power ``gamma``.

        Args:
            seq: Input sequence.

        Returns:
            Transformed sequence, or ``seq`` itself when ``gamma`` is falsy.
        """
        return seq**self.gamma if self.gamma else seq

    def __setattr__(self, name: str, value: Any) -> None:
        """Validate ``gamma`` on every assignment, including in ``__init__``.

        Raises:
            ValueError: If ``gamma`` is assigned a negative value.
        """
        rejects_gamma = name == "gamma" and value < 0
        if rejects_gamma:
            raise ValueError("Gamma must be positive.")
        super().__setattr__(name, value)

    def set_or_sample(self, gamma: Optional[float] = None) -> None:
        """Set or sample the gamma value.

        Args:
            gamma: Gamma value to set. If None, draw ``max(0, N(1, std))``,
                or use 1.0 when ``std`` is falsy.
        """
        if gamma is None:
            if self.std:
                gamma = max(0, np.random.normal(1, self.std))
            else:
                gamma = 1.0
        self.gamma = gamma

transform

transform(seq)

Apply the transformation to a sequence.

Parameters:

Name Type Description Default
seq Tensor

Input sequence.

required

Returns:

Type Description
Tensor

Transformed sequence.

Source code in flyvis/datasets/augmentation/hex.py
361
362
363
364
365
366
367
368
369
370
371
372
def transform(self, seq: torch.Tensor) -> torch.Tensor:
    """Raise every element of the sequence to the power ``gamma``.

    Args:
        seq: Input sequence.

    Returns:
        Transformed sequence, or ``seq`` itself when ``gamma`` is falsy.
    """
    return seq**self.gamma if self.gamma else seq

set_or_sample

set_or_sample(gamma=None)

Set or sample the gamma value.

Parameters:

Name Type Description Default
gamma Optional[float]

Gamma value to set. If None, sample randomly.

None
Source code in flyvis/datasets/augmentation/hex.py
379
380
381
382
383
384
385
386
387
def set_or_sample(self, gamma: Optional[float] = None) -> None:
    """Set or sample the gamma value.

    Args:
        gamma: Gamma value to set. If None, draw ``max(0, N(1, std))``, or use
            1.0 when ``std`` is falsy.
    """
    if gamma is None:
        if self.std:
            gamma = max(0, np.random.normal(1, self.std))
        else:
            gamma = 1.0
    self.gamma = gamma

Temporal Augmentations

flyvis.datasets.augmentation.temporal.Interpolate

Bases: Augmentation

Interpolate a sequence to a target framerate.

Attributes:

Name Type Description
original_framerate int

The original framerate of the sequence.

target_framerate float

The target framerate after interpolation.

mode str

The interpolation mode.

align_corners bool | None

Alignment of corners for interpolation.

Source code in flyvis/datasets/augmentation/temporal.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
class Interpolate(Augmentation):
    """Interpolate a sequence to a target framerate.

    Args:
        original_framerate: The original framerate of the sequence.
        target_framerate: The target framerate after interpolation.
        mode: Interpolation mode passed to ``torch.nn.functional.interpolate``.

    Attributes:
        original_framerate (int): The original framerate of the sequence.
        target_framerate (float): The target framerate after interpolation.
        mode (str): The interpolation mode.
        align_corners (bool | None): Alignment of corners for interpolation;
            True for the linear-family modes, None otherwise.
    """

    def __init__(self, original_framerate: int, target_framerate: float, mode: str):
        self.original_framerate = original_framerate
        self.target_framerate = target_framerate
        self.mode = mode
        # align_corners is only set for the linear-family modes; all other
        # modes get None.
        if mode in ("linear", "bilinear", "bicubic", "trilinear"):
            self.align_corners = True
        else:
            self.align_corners = None

    def transform(self, sequence: torch.Tensor, dim: int = 0) -> torch.Tensor:
        """Resample the sequence along the specified dimension.

        The output length along ``dim`` is
        ``ceil(target_framerate / original_framerate * sequence.shape[dim])``.

        Args:
            sequence: Sequence to resample of ndim == 3.
            dim: Dimension along which to resample.

        Returns:
            torch.Tensor: Resampled sequence.

        Raises:
            AssertionError: If the input sequence is not 3D.
        """
        assert sequence.ndim == 3, "only 3D sequences are supported"
        if sequence.dtype == torch.long:
            sequence = sequence.float()
        n_out = math.ceil(
            self.target_framerate / self.original_framerate * sequence.shape[dim]
        )
        # interpolate resamples the last axis of a 3D input, so swap `dim` there
        # and back afterwards.
        resampled = nnf.interpolate(
            sequence.transpose(dim, -1),
            size=n_out,
            mode=self.mode,
            align_corners=self.align_corners,
        )
        return resampled.transpose(dim, -1)

    def piecewise_constant_indices(self, length: int) -> torch.Tensor:
        """Return indices to sample from a sequence with piecewise constant interpolation.

        Args:
            length: Length of the original sequence.

        Returns:
            torch.Tensor: Indices for piecewise constant interpolation.
        """
        n_out = math.ceil(self.target_framerate / self.original_framerate * length)
        # Shape (1, 1, length): the batch/channel layout interpolate expects.
        positions = torch.arange(length, dtype=torch.float).unsqueeze(0).unsqueeze(0)
        picked = nnf.interpolate(
            positions, size=n_out, mode="nearest-exact", align_corners=None
        )
        return picked.flatten().long()

transform

transform(sequence, dim=0)

Resample the sequence along the specified dimension.

Parameters:

Name Type Description Default
sequence Tensor

Sequence to resample of ndim == 3.

required
dim int

Dimension along which to resample.

0

Returns:

Type Description
Tensor

torch.Tensor: Resampled sequence.

Raises:

Type Description
AssertionError

If the input sequence is not 3D.

Source code in flyvis/datasets/augmentation/temporal.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
def transform(self, sequence: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Resample the sequence along the specified dimension.

    The output length along ``dim`` is
    ``ceil(target_framerate / original_framerate * sequence.shape[dim])``.

    Args:
        sequence: Sequence to resample of ndim == 3.
        dim: Dimension along which to resample.

    Returns:
        torch.Tensor: Resampled sequence.

    Raises:
        AssertionError: If the input sequence is not 3D.
    """
    assert sequence.ndim == 3, "only 3D sequences are supported"
    if sequence.dtype == torch.long:
        sequence = sequence.float()
    n_out = math.ceil(
        self.target_framerate / self.original_framerate * sequence.shape[dim]
    )
    # interpolate resamples the last axis of a 3D input, so swap `dim` there
    # and back afterwards.
    resampled = nnf.interpolate(
        sequence.transpose(dim, -1),
        size=n_out,
        mode=self.mode,
        align_corners=self.align_corners,
    )
    return resampled.transpose(dim, -1)

piecewise_constant_indices

piecewise_constant_indices(length)

Return indices to sample from a sequence with piecewise constant interpolation.

Parameters:

Name Type Description Default
length int

Length of the original sequence.

required

Returns:

Type Description
Tensor

torch.Tensor: Indices for piecewise constant interpolation.

Source code in flyvis/datasets/augmentation/temporal.py
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def piecewise_constant_indices(self, length: int) -> torch.Tensor:
    """Return indices to sample from a sequence with piecewise constant interpolation.

    The frame indices ``0 .. length - 1`` are resampled with
    ``mode="nearest-exact"`` to
    ``ceil(target_framerate / original_framerate * length)`` entries, so each
    output position holds the index of the original frame it should repeat.

    Args:
        length: Length of the original sequence.

    Returns:
        torch.Tensor: Indices for piecewise constant interpolation.
    """
    # interpolate expects an (N, C, L) tensor: one batch, one channel.
    frame_ids = torch.arange(length, dtype=torch.float).reshape(1, 1, -1)
    n_out = math.ceil(self.target_framerate / self.original_framerate * length)
    resampled = nnf.interpolate(
        frame_ids,
        size=n_out,
        mode="nearest-exact",
        align_corners=None,
    )
    return resampled.reshape(-1).to(torch.long)

flyvis.datasets.augmentation.temporal.CropFrames

Bases: Augmentation

Crop frames from a sequence.

Attributes:

Name Type Description
n_frames int

Number of frames to crop.

all_frames bool

Whether to return all frames.

start int

Starting frame for cropping.

random bool

Whether to use random cropping.

Source code in flyvis/datasets/augmentation/temporal.py
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
class CropFrames(Augmentation):
    """Crop a contiguous window of frames from a sequence.

    Attributes:
        n_frames (int): Number of frames to crop.
        all_frames (bool): Whether to return all frames.
        start (int): Starting frame for cropping.
        random (bool): Whether to use random cropping.
    """

    def __init__(
        self,
        n_frames: int,
        start: int = 0,
        all_frames: bool = False,
        random: bool = False,
    ):
        self.n_frames = n_frames
        self.all_frames = all_frames
        self.start = start
        self.random = random

    def transform(self, sequence: torch.Tensor, dim: int = 0) -> torch.Tensor:
        """Crop the sequence along the specified dimension.

        The window starts at ``self.start`` when ``self.random`` is True and at
        frame 0 otherwise. If ``all_frames`` is set, the input is returned
        unchanged.

        Args:
            sequence: Sequence to crop of shape (..., n_frames, ...).
            dim: Dimension along which to crop.

        Returns:
            torch.Tensor: Cropped sequence.

        Raises:
            ValueError: If n_frames is greater than the total sequence length.
        """
        if self.all_frames:
            return sequence
        total_seq_length = sequence.shape[dim]
        if self.n_frames > total_seq_length:
            raise ValueError(
                f"cannot crop {self.n_frames} frames from a total"
                f" of {total_seq_length} frames"
            )
        start = self.start if self.random else 0
        indx = [slice(None)] * sequence.ndim
        indx[dim] = slice(start, start + self.n_frames)
        # Multi-dimensional indexing must be done with a tuple; indexing with a
        # list of slices is legacy NumPy-style behaviour.
        return sequence[tuple(indx)]

    def set_or_sample(
        self, start: int | None = None, total_sequence_length: int | None = None
    ):
        """Set or sample the starting frame for cropping.

        If ``start`` is given, it is stored as-is. Otherwise, if a non-zero
        ``total_sequence_length`` is given, a start is drawn uniformly from
        ``0 .. total_sequence_length - n_frames`` (inclusive), so every window
        of ``n_frames`` frames that fits can be chosen. If ``start`` is None and
        no length is given, ``self.start`` is set to None.

        Args:
            start: Starting frame for cropping.
            total_sequence_length: Total length of the sequence.

        Raises:
            ValueError: If n_frames is greater than the total sequence length.
        """
        if total_sequence_length and self.n_frames > total_sequence_length:
            raise ValueError(
                f"cannot crop {self.n_frames} frames from a total"
                f" of {total_sequence_length} frames"
            )
        if start is None and total_sequence_length:
            # `high` is exclusive in np.random.randint; +1 makes the last
            # window (start == total_sequence_length - n_frames) reachable.
            # The guard above ensures high >= 1.
            start = np.random.randint(
                low=0, high=total_sequence_length - self.n_frames + 1
            )
        self.start = start

transform

transform(sequence, dim=0)

Crop the sequence along the specified dimension.

Parameters:

Name Type Description Default
sequence Tensor

Sequence to crop of shape (…, n_frames, …).

required
dim int

Dimension along which to crop.

0

Returns:

Type Description
Tensor

torch.Tensor: Cropped sequence.

Raises:

Type Description
ValueError

If n_frames is greater than the total sequence length.

Source code in flyvis/datasets/augmentation/temporal.py
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
def transform(self, sequence: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Crop the sequence along the specified dimension.

    The window starts at ``self.start`` when ``self.random`` is True and at
    frame 0 otherwise. If ``all_frames`` is set, the input is returned
    unchanged.

    Args:
        sequence: Sequence to crop of shape (..., n_frames, ...).
        dim: Dimension along which to crop.

    Returns:
        torch.Tensor: Cropped sequence.

    Raises:
        ValueError: If n_frames is greater than the total sequence length.
    """
    if self.all_frames:
        return sequence
    total_seq_length = sequence.shape[dim]
    if self.n_frames > total_seq_length:
        raise ValueError(
            f"cannot crop {self.n_frames} frames from a total"
            f" of {total_seq_length} frames"
        )
    start = self.start if self.random else 0
    indx = [slice(None)] * sequence.ndim
    indx[dim] = slice(start, start + self.n_frames)
    # Multi-dimensional indexing must be done with a tuple; indexing with a
    # list of slices is legacy NumPy-style behaviour.
    return sequence[tuple(indx)]

set_or_sample

set_or_sample(start=None, total_sequence_length=None)

Set or sample the starting frame for cropping.

Parameters:

Name Type Description Default
start int | None

Starting frame for cropping.

None
total_sequence_length int | None

Total length of the sequence.

None

Raises:

Type Description
ValueError

If n_frames is greater than the total sequence length.

Source code in flyvis/datasets/augmentation/temporal.py
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
def set_or_sample(
    self, start: int | None = None, total_sequence_length: int | None = None
):
    """Set or sample the starting frame for cropping.

    Args:
        start: Starting frame for cropping.
        total_sequence_length: Total length of the sequence.

    Raises:
        ValueError: If n_frames is greater than the total sequence length.
    """
    if total_sequence_length and self.n_frames > total_sequence_length:
        raise ValueError(
            f"cannot crop {self.n_frames} frames from a total"
            f" of {total_sequence_length} frames"
        )
    if start is None and total_sequence_length:
        start = np.random.randint(
            low=0, high=total_sequence_length - self.n_frames or 1
        )
    self.start = start

Base Class

flyvis.datasets.augmentation.augmentation.Augmentation

Base class for data augmentation operations.

This class provides a framework for implementing various data augmentation techniques. Subclasses should override the transform method to implement specific augmentation logic.

Attributes:

Name Type Description
augment bool

Flag to enable or disable augmentation.

Source code in flyvis/datasets/augmentation/augmentation.py
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
class Augmentation:
    """
    Base class for data augmentation operations.

    Calling an instance forwards to `transform` while `augment` is True and
    returns the input untouched otherwise. Concrete augmentations override
    `transform` (and optionally `set_or_sample`); the defaults here are
    no-ops.

    Attributes:
        augment (bool): Flag to enable or disable augmentation.
    """

    augment: bool = True

    def __call__(self, seq: Sequence[Any], *args: Any, **kwargs: Any) -> Sequence[Any]:
        """
        Apply `transform` to the input when augmentation is enabled.

        Args:
            seq: Input sequence to be augmented.
            *args: Extra positional arguments forwarded to `transform`.
            **kwargs: Extra keyword arguments forwarded to `transform`.

        Returns:
            The result of `transform` when `augment` is truthy, otherwise
                `seq` itself.
        """
        if not self.augment:
            return seq
        return self.transform(seq, *args, **kwargs)

    def transform(self, seq: Sequence[Any], *args: Any, **kwargs: Any) -> Sequence[Any]:
        """
        Identity transformation; subclasses override this with real logic.

        Args:
            seq: Input sequence to be transformed.
            *args: Ignored positional arguments.
            **kwargs: Ignored keyword arguments.

        Returns:
            `seq`, unchanged.
        """
        return seq

    def set_or_sample(self, *args: Any) -> None:
        """
        Hook for fixing or randomly drawing augmentation parameters.

        The base implementation does nothing; subclasses override it.

        Args:
            *args: Arguments for setting or sampling parameters (ignored here).
        """
        return None

__call__

__call__(seq, *args, **kwargs)

Apply augmentation to the input sequence if enabled.

Parameters:

Name Type Description Default
seq Sequence[Any]

Input sequence to be augmented.

required
*args Any

Additional positional arguments.

()
**kwargs Any

Additional keyword arguments.

{}

Returns:

Type Description
Sequence[Any]

Augmented sequence if augmentation is enabled, otherwise the original sequence.

Source code in flyvis/datasets/augmentation/augmentation.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def __call__(self, seq: Sequence[Any], *args: Any, **kwargs: Any) -> Sequence[Any]:
    """
    Apply augmentation to the input sequence if enabled.

    Args:
        seq: Input sequence to be augmented.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Returns:
        Augmented sequence if augmentation is enabled,
            otherwise the original sequence.
    """
    if self.augment:
        return self.transform(seq, *args, **kwargs)
    return seq

transform

transform(seq, *args, **kwargs)

Apply the augmentation transformation to the input sequence.

This method should be overridden by subclasses to implement specific augmentation logic.

Parameters:

Name Type Description Default
seq Sequence[Any]

Input sequence to be transformed.

required
*args Any

Additional positional arguments.

()
**kwargs Any

Additional keyword arguments.

{}

Returns:

Type Description
Sequence[Any]

Transformed sequence.

Source code in flyvis/datasets/augmentation/augmentation.py
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def transform(self, seq: Sequence[Any], *args: Any, **kwargs: Any) -> Sequence[Any]:
    """
    Identity transformation; subclasses override this with real logic.

    Args:
        seq: Input sequence to be transformed.
        *args: Ignored positional arguments.
        **kwargs: Ignored keyword arguments.

    Returns:
        `seq`, unchanged (the very same object).
    """
    unchanged = seq
    return unchanged

set_or_sample

set_or_sample(*args)

Set or sample augmentation parameters.

This method can be used to set fixed parameters or sample random parameters for the augmentation process.

Parameters:

Name Type Description Default
*args Any

Arguments for setting or sampling parameters.

()
Source code in flyvis/datasets/augmentation/augmentation.py
52
53
54
55
56
57
58
59
60
61
62
def set_or_sample(self, *args: Any) -> None:
    """
    Hook for fixing or randomly drawing augmentation parameters.

    The base implementation does nothing; subclasses override it.

    Args:
        *args: Arguments for setting or sampling parameters (ignored here).
    """
    return None

Utils

flyvis.datasets.augmentation.utils

rotation_matrix

rotation_matrix(angle_in_rad, three_d=False)

Generate a rotation matrix.

Parameters:

Name Type Description Default
angle_in_rad float

Rotation angle in radians.

required
three_d bool

If True, generate a 3D rotation matrix.

False

Returns:

Type Description
Tensor

Rotation matrix as a torch.Tensor.

Source code in flyvis/datasets/augmentation/utils.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def rotation_matrix(angle_in_rad: float, three_d: bool = False) -> torch.Tensor:
    """Build a counter-clockwise rotation matrix ``[[cos, -sin], [sin, cos]]``.

    Args:
        angle_in_rad: Rotation angle in radians.
        three_d: If True, embed the 2D rotation in a 3x3 matrix that leaves
            the third coordinate unchanged.

    Returns:
        Rotation matrix as a float32 torch.Tensor (2x2, or 3x3 if ``three_d``).
    """
    cos_a = np.cos(angle_in_rad)
    sin_a = np.sin(angle_in_rad)
    if three_d:
        rows = [
            [cos_a, -sin_a, 0],
            [sin_a, cos_a, 0],
            [0, 0, 1],
        ]
    else:
        rows = [
            [cos_a, -sin_a],
            [sin_a, cos_a],
        ]
    return torch.tensor(np.array(rows), dtype=torch.float)

rotation_permutation_index

rotation_permutation_index(extent, n_rot)

Calculate rotation permutation indices for hex coordinates.

Parameters:

Name Type Description Default
extent int

Extent of the regular hexagonal grid.

required
n_rot int

Number of 60-degree rotations.

required

Returns:

Type Description
Tensor

Permutation indices as a torch.Tensor.

Source code in flyvis/datasets/augmentation/utils.py
37
38
39
40
41
42
43
44
45
46
47
48
49
def rotation_permutation_index(extent: int, n_rot: int) -> torch.Tensor:
    """Compute the index permutation for rotating a hex lattice by ``n_rot`` x 60 degrees.

    The lattice coordinates from ``hex_utils.get_hex_coords(extent)`` are
    rotated with ``rotate_Nx60`` and then passed to
    ``hex_utils.sort_u_then_v_index``.

    Args:
        extent: Extent of the regular hexagonal grid.
        n_rot: Number of 60-degree rotations.

    Returns:
        Permutation indices as a torch.Tensor.
    """
    grid_u, grid_v = hex_utils.get_hex_coords(extent)
    rotated_u, rotated_v = rotate_Nx60(grid_u, grid_v, n_rot)
    return hex_utils.sort_u_then_v_index(rotated_u, rotated_v)

rotate_Nx60

rotate_Nx60(u, v, n)

Rotate hex coordinates by multiples of 60 degrees.

Parameters:

Name Type Description Default
u ndarray

U coordinates of hex grid.

required
v ndarray

V coordinates of hex grid.

required
n int

Number of 60-degree rotations.

required

Returns:

Type Description
Tuple[ndarray, ndarray]

Tuple of rotated U and V coordinates.

Note

Resource: http://devmag.org.za/2013/08/31/geometry-with-hex-coordinates/

Source code in flyvis/datasets/augmentation/utils.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def rotate_Nx60(u: np.ndarray, v: np.ndarray, n: int) -> Tuple[np.ndarray, np.ndarray]:
    """Rotate hex coordinates by multiples of 60 degrees.

    A single 60-degree step maps ``(u, v) -> (-v, u + v)``, i.e. applies the
    matrix ``[[0, -1], [1, 1]]``. The step is applied ``n % 6`` times, so
    negative ``n`` and ``n >= 6`` wrap around; ``n % 6 == 0`` returns the
    inputs unchanged.

    Args:
        u: U coordinates of hex grid.
        v: V coordinates of hex grid.
        n: Number of 60-degree rotations.

    Returns:
        Tuple of rotated U and V coordinates.

    Note:
        Resource: http://devmag.org.za/2013/08/31/geometry-with-hex-coordinates/
    """
    for _step in range(n % 6):
        u, v = -v, u + v
    return u, v

flip_matrix

flip_matrix(angle_in_rad, three_d=False)

Generate a flip matrix for mirroring over a line.

Parameters:

Name Type Description Default
angle_in_rad float

Angle of the flip axis in radians.

required
three_d bool

If True, generate a 3D flip matrix.

False

Returns:

Type Description
Tensor

Flip matrix as a torch.Tensor.

Note

Reference: https://math.stackexchange.com/questions/807031/

Source code in flyvis/datasets/augmentation/utils.py
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
def flip_matrix(angle_in_rad: float, three_d: bool = False) -> torch.Tensor:
    """Build the reflection matrix across a line through the origin.

    The line makes angle ``angle_in_rad`` with the first axis; the matrix is
    ``[[cos 2a, sin 2a], [sin 2a, -cos 2a]]``.

    Args:
        angle_in_rad: Angle of the flip axis in radians.
        three_d: If True, embed the 2D reflection in a 3x3 matrix that leaves
            the third coordinate unchanged.

    Returns:
        Flip matrix as a float32 torch.Tensor (2x2, or 3x3 if ``three_d``).

    Note:
        Reference: https://math.stackexchange.com/questions/807031/
    """
    cos_2a = np.cos(2 * angle_in_rad)
    sin_2a = np.sin(2 * angle_in_rad)
    if three_d:
        rows = [
            [cos_2a, sin_2a, 0],
            [sin_2a, -cos_2a, 0],
            [0, 0, 1],
        ]
    else:
        rows = [
            [cos_2a, sin_2a],
            [sin_2a, -cos_2a],
        ]
    return torch.tensor(np.array(rows), dtype=torch.float)

flip_permutation_index

flip_permutation_index(extent, axis)

Get indices used to flip the sequence.

Parameters:

Name Type Description Default
extent int

Extent of the regular hexagonal grid.

required
axis Literal[1, 2, 3]

Axis to flip across (1, 2, or 3).

required

Returns:

Type Description
Tensor

Permutation indices as a torch.Tensor.

Raises:

Type Description
ValueError

If axis is not in [1, 2, 3].

Source code in flyvis/datasets/augmentation/utils.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
def flip_permutation_index(extent: int, axis: Literal[1, 2, 3]) -> torch.Tensor:
    """Compute the index permutation that mirrors a hex lattice across an axis.

    Axis mappings applied to the ``hex_utils.get_hex_coords(extent)``
    coordinates:

    * 1: ``(u, v) -> (u + v, -v)``  (line v = 0)
    * 2: ``(u, v) -> (-u, u + v)``  (line u = 0)
    * 3: ``(u, v) -> (-v, -u)``     (line u + v = 0)

    Args:
        extent: Extent of the regular hexagonal grid.
        axis: Axis to flip across (1, 2, or 3).

    Returns:
        Permutation indices as a torch.Tensor.

    Raises:
        ValueError: If axis is not in [1, 2, 3].
    """
    grid_u, grid_v = hex_utils.get_hex_coords(extent)
    if axis == 1:
        # flip across v = 0, that is the x axis.
        mirrored = (grid_u + grid_v, -grid_v)
    elif axis == 2:
        # flip across u = 0, that is the y axis.
        mirrored = (-grid_u, grid_u + grid_v)
    elif axis == 3:
        # flip across u + v = 0, that is the 'z' axis of the hex lattice.
        mirrored = (-grid_v, -grid_u)
    else:
        raise ValueError("axis must be in [1, 2, 3].")
    mirrored_u, mirrored_v = mirrored
    return hex_utils.sort_u_then_v_index(mirrored_u, mirrored_v)