Skip to content

Flash responses

Rendering

flyvision.datasets.flashes.RenderedFlashes

Bases: Directory

Render a directory with flashes for the Flashes dataset.

Parameters:

Name Type Description Default
boxfilter Dict[str, int]

Parameters for the BoxEye filter.

dict(extent=15, kernel_size=13)
dynamic_range List[float]

Range of intensities. E.g. [0, 1] renders flashes with decrement 0.5->0 and increment 0.5->1.

[0, 1]
t_stim float

Duration of the stimulus.

1.0
t_pre float

Duration of the grey stimulus.

1.0
dt float

Time step.

1 / 200
radius List[int]

Radius of the stimulus.

[-1, 6]
alternations Tuple[int, ...]

Sequence of alternations: 0 selects the baseline (the midpoint of the dynamic range) and 1 selects the lower or upper intensity of the dynamic range.

(0, 1, 0)

Attributes:

Name Type Description
flashes ArrayFile

Array containing rendered flash sequences.

Source code in flyvision/datasets/flashes.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
@root(renderings_dir)
class RenderedFlashes(Directory):
    """Directory of pre-rendered flash stimuli for the Flashes dataset.

    Every combination of a (baseline, intensity) pair with a radius is
    rendered through `render_flash` and stored as one array.

    Args:
        boxfilter: Keyword arguments for the BoxEye filter; the eye is only
            used here to obtain the number of ommatidia.
        dynamic_range: Range of intensities. E.g. [0, 1] renders flashes
            with decrement 0.5->0 and increment 0.5->1.
        t_stim: Duration of the stimulus.
        t_pre: Duration of the grey (baseline) segment.
        dt: Time step.
        radius: Radii of the stimulus; -1 renders a full-field flash.
        alternations: Sequence of 0 (baseline segment) and 1 (flash segment)
            values defining the temporal structure.

    Attributes:
        flashes (ArrayFile): Array containing rendered flash sequences, one
            entry per (baseline, intensity, radius) combination.
    """

    def __init__(
        self,
        boxfilter: Dict[str, int] = dict(extent=15, kernel_size=13),
        dynamic_range: List[float] = [0, 1],
        t_stim: float = 1.0,
        t_pre: float = 1.0,
        dt: float = 1 / 200,
        radius: List[int] = [-1, 6],
        alternations: Tuple[int, ...] = (0, 1, 0),
    ):
        eye = BoxEye(**boxfilter)
        n_ommatidia = len(eye.receptor_centers)

        levels = np.array(dynamic_range)
        midpoint = levels.sum() / 2
        # Pair the midpoint baseline with each end of the dynamic range, which
        # yields one decrement and one increment condition.
        pairs = np.array(list(zip((midpoint, midpoint), levels)))
        conditions = list(product(pairs, radius))

        rendered = []
        for (base, peak), rad in tqdm(conditions, desc="Flashes"):
            flash = render_flash(
                n_ommatidia,
                peak,
                base,
                t_stim,
                t_pre,
                dt,
                alternations,
                rad,
            )
            rendered.append(flash)

        self.flashes = np.array(rendered)

Datasets

flyvision.datasets.flashes.Flashes

Bases: SequenceDataset

Flashes dataset.

Parameters:

Name Type Description Default
boxfilter Dict[str, int]

Parameters for the BoxEye filter.

dict(extent=15, kernel_size=13)
dynamic_range List[float]

Range of intensities. E.g. [0, 1] renders flashes with decrement 0.5->0 and increment 0.5->1.

[0, 1]
t_stim float

Duration of the stimulus.

1.0
t_pre float

Duration of the grey stimulus.

1.0
dt float

Time step.

1 / 200
radius List[int]

Radius of the stimulus.

[-1, 6]
alternations Tuple[int, ...]

Sequence of alternations: 0 selects the baseline (the midpoint of the dynamic range) and 1 selects the lower or upper intensity of the dynamic range.

(0, 1, 0)

Attributes:

Name Type Description
dt Union[float, None]

Timestep.

t_post float

Post-stimulus time.

flashes_dir

Directory containing rendered flashes.

config

Configuration object.

baseline

Baseline intensity.

arg_df

DataFrame containing flash parameters.

Note

An alternation value of 0 denotes the prestimulus baseline; a value of 1 denotes the central stimulus. The sequence must start with 0. t_pre is the duration of the prestimulus and t_stim is the duration of the stimulus.

Source code in flyvision/datasets/flashes.py
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
class Flashes(SequenceDataset):
    """Flashes dataset.

    Args:
        boxfilter: Parameters for the BoxEye filter.
        dynamic_range: Range of intensities. E.g. [0, 1] renders flashes
            with decrement 0.5->0 and increment 0.5->1. May be a list, tuple
            or numpy array.
        t_stim: Duration of the stimulus.
        t_pre: Duration of the grey stimulus.
        dt: Time step.
        radius: Radius of the stimulus; -1 renders a full-field flash.
        alternations: Sequence of alternations: 0 selects the baseline and 1
            the lower or upper intensity of the dynamic range.

    Attributes:
        dt: Timestep.
        t_post: Post-stimulus time.
        flashes_dir: Directory containing rendered flashes.
        config: Configuration object.
        baseline: Baseline intensity.
        arg_df: DataFrame containing flash parameters.

    Note:
        Zero alternation is the prestimulus and baseline. One alternation is the
        central stimulus. Has to start with zero alternation. `t_pre` is the
        duration of the prestimulus and `t_stim` is the duration of the stimulus.
    """

    dt: Union[float, None] = None
    t_post: float = 0.0

    def __init__(
        self,
        boxfilter: Dict[str, int] = dict(extent=15, kernel_size=13),
        dynamic_range: List[float] = [0, 1],
        t_stim: float = 1.0,
        t_pre: float = 1.0,
        dt: float = 1 / 200,
        radius: List[int] = [-1, 6],
        alternations: Tuple[int, ...] = (0, 1, 0),
    ):
        assert alternations[0] == 0, "First alternation must be 0."
        self.flashes_dir = RenderedFlashes(
            boxfilter=boxfilter,
            dynamic_range=dynamic_range,
            t_stim=t_stim,
            t_pre=t_pre,
            dt=dt,
            radius=radius,
            alternations=alternations,
        )
        self.config = self.flashes_dir.config

        # Mirror the parameter grid rendered by RenderedFlashes: the midpoint of
        # the dynamic range is paired with each of its ends, and every pair is
        # combined with every radius in the same (product) order.
        midpoint = sum(dynamic_range) / 2
        baseline = (midpoint, midpoint)
        # list() instead of .copy(): also works for tuples, not only for
        # lists and numpy arrays.
        intensity = list(dynamic_range)

        params = [
            (base, peak, rad)
            for (base, peak), rad in product(zip(baseline, intensity), radius)
        ]
        self.baseline = baseline[0]
        self.arg_df = pd.DataFrame(params, columns=["baseline", "intensity", "radius"])

        self.dt = dt

    @property
    def t_pre(self) -> float:
        """Duration of the prestimulus and zero alternation."""
        return self.config.t_pre

    @property
    def t_stim(self) -> float:
        """Duration of the one alternation."""
        return self.config.t_stim

    def get_item(self, key: int) -> torch.Tensor:
        """Index the dataset.

        Args:
            key: Index of the item to retrieve.

        Returns:
            Flash sequence at the given index.
        """
        return torch.Tensor(self.flashes_dir.flashes[key])

    def __repr__(self) -> str:
        """Return a string representation of the dataset."""
        return f"Flashes dataset. Parametrization: \n{self.arg_df}"

t_pre property

t_pre

Duration of the prestimulus, i.e. of each alternation with value 0.

t_stim property

t_stim

Duration of each alternation with value 1 (the stimulus).

get_item

get_item(key)

Index the dataset.

Parameters:

Name Type Description Default
key int

Index of the item to retrieve.

required

Returns:

Type Description
Tensor

Flash sequence at the given index.

Source code in flyvision/datasets/flashes.py
205
206
207
208
209
210
211
212
213
214
def get_item(self, key: int) -> torch.Tensor:
    """Return the pre-rendered flash sequence stored at position ``key``.

    Args:
        key: Index of the item to retrieve.

    Returns:
        The stored flash sequence converted with ``torch.Tensor``.
    """
    rendered = self.flashes_dir.flashes
    sequence = rendered[key]
    return torch.Tensor(sequence)

__repr__

__repr__()

Return a string representation of the dataset.

Source code in flyvision/datasets/flashes.py
216
217
218
def __repr__(self) -> str:
    """Describe the dataset by printing its parameter table."""
    template = "Flashes dataset. Parametrization: \n{}"
    return template.format(self.arg_df)

flyvision.datasets.flashes.render_flash

render_flash(
    n_ommatidia,
    intensity,
    baseline,
    t_stim,
    t_pre,
    dt,
    alternations,
    radius,
)

Generate a sequence of flashes on a hexagonal lattice.

Parameters:

Name Type Description Default
n_ommatidia int

Number of ommatidia.

required
intensity float

Intensity of the flash.

required
baseline float

Intensity of the baseline.

required
t_stim float

Duration of the stimulus.

required
t_pre float

Duration of the grey stimulus.

required
dt float

Time step.

required
alternations Tuple[int, ...]

Sequence of alternations: 0 selects a baseline segment (duration t_pre) and 1 selects a flash segment (duration t_stim); other values are ignored.

required
radius int

Radius of the stimulus.

required

Returns:

Type Description
ndarray

Generated flash sequence.

Source code in flyvision/datasets/flashes.py
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
def render_flash(
    n_ommatidia: int,
    intensity: float,
    baseline: float,
    t_stim: float,
    t_pre: float,
    dt: float,
    alternations: Tuple[int, ...],
    radius: int,
) -> np.ndarray:
    """Generate a sequence of flashes on a hexagonal lattice.

    Args:
        n_ommatidia: Number of ommatidia.
        intensity: Intensity of the flash.
        baseline: Intensity of the baseline.
        t_stim: Duration of each flash segment (alternation value 1).
        t_pre: Duration of each baseline segment (alternation value 0).
        dt: Time step.
        alternations: Sequence of 0 (baseline) and 1 (flash) values; any
            other value is skipped.
        radius: Radius of the stimulus; -1 lights every ommatidium.

    Returns:
        The resampled segments concatenated along dim 0, as a numpy array.
    """
    frame = torch.ones(n_ommatidia)[None] * baseline

    if radius == -1:
        # Full-field flash: every ommatidium takes the flash intensity.
        lit = np.arange(n_ommatidia)
    else:
        disk = HexLattice.filled_circle(
            radius=radius, center=Hexal(0, 0, 0), as_lattice=True
        )
        lit = disk.where(1)

    frame[:, lit] = intensity

    flash_segment = resample(frame, t_stim, dt)
    grey_segment = resample(torch.ones(n_ommatidia)[None] * baseline, t_pre, dt)

    segments = [
        grey_segment if switch == 0 else flash_segment
        for switch in alternations
        if switch in (0, 1)
    ]
    return torch.cat(segments, dim=0).cpu().numpy()

flyvision.datasets.dots.Dots

Bases: StimulusDataset

Render flashes (also called dots) centered on each ommatidium.

Note

Renders directly in receptor space, does not use BoxEye or HexEye as eye-model.

Parameters:

Name Type Description Default
dot_column_radius int

Radius of the dot column.

0
max_extent int

Maximum extent of the stimulus.

15
bg_intensity float

Background intensity.

0.5
t_stim float

Stimulus duration.

5
dt float

Time step.

1 / 200
t_impulse Optional[float]

Impulse duration.

None
n_ommatidia int

Number of ommatidia.

721
t_pre float

Pre-stimulus duration.

2.0
t_post float

Post-stimulus duration.

0
intensity float

Stimulus intensity.

1
mode Literal['sustained', 'impulse']

Stimulus mode (‘sustained’ or ‘impulse’).

'sustained'
device device

Torch device for computations.

device

Attributes:

Name Type Description
dt Optional[float]

Time step.

arg_df Optional[DataFrame]

DataFrame containing stimulus parameters.

config

Namespace containing configuration parameters.

t_stim

Stimulus duration.

n_ommatidia

Number of ommatidia.

offsets

Array of ommatidia offsets.

u

U-coordinates of the hexagonal grid.

v

V-coordinates of the hexagonal grid.

extent_condition

Boolean mask for the extent condition.

max_extent

Maximum extent of the stimulus.

bg_intensity

Background intensity.

intensities

List of stimulus intensities.

device

Torch device for computations.

mode

Stimulus mode (‘sustained’ or ‘impulse’).

params

List of stimulus parameters.

t_impulse

Impulse duration.

Raises:

Type Description
ValueError

If dot_column_radius is greater than max_extent.

Source code in flyvision/datasets/dots.py
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
class Dots(StimulusDataset):
    """
    Render flashes aka dots per ommatidia.

    Note:
        Renders directly in receptor space, does not use BoxEye or HexEye as eye-model.

    Args:
        dot_column_radius: Radius of the dot column.
        max_extent: Only dot centers within this hexagonal extent around the
            origin are rendered.
        bg_intensity: Background intensity.
        t_stim: Stimulus duration.
        dt: Time step.
        t_impulse: Impulse duration used in 'impulse' mode; falls back to `dt`
            when not given (or falsy).
        n_ommatidia: Number of ommatidia.
        t_pre: Pre-stimulus duration.
        t_post: Post-stimulus duration.
        intensity: Stimulus intensity. Each dot is rendered both at this
            intensity and at its mirror image about `bg_intensity`.
        mode: Stimulus mode ('sustained' or 'impulse').
        device: Torch device for computations.

    Attributes:
        dt: Time step.
        arg_df: DataFrame containing stimulus parameters.
        config: Namespace containing configuration parameters.
        t_stim: Stimulus duration.
        n_ommatidia: Number of ommatidia.
        offsets: Index of every ommatidium, 0 .. n_ommatidia - 1.
        u: U-coordinates of the hexagonal grid (within max_extent).
        v: V-coordinates of the hexagonal grid (within max_extent).
        extent_condition: Boolean mask for the extent condition.
        max_extent: Maximum extent of the stimulus.
        bg_intensity: Background intensity.
        intensities: `[2 * bg_intensity - intensity, intensity]`.
        device: Torch device for computations.
        mode: Stimulus mode ('sustained' or 'impulse').
        params: One `(u, v, offset, coordinate_index, intensity)` tuple per
            stimulus, where `offset` is the index of the central ommatidium.
        t_impulse: Impulse duration.

    Raises:
        ValueError: If dot_column_radius is greater than max_extent.
    """

    dt: Optional[float] = None
    arg_df: Optional[pd.DataFrame] = None

    def __init__(
        self,
        dot_column_radius: int = 0,
        max_extent: int = 15,
        bg_intensity: float = 0.5,
        t_stim: float = 5,
        dt: float = 1 / 200,
        t_impulse: Optional[float] = None,
        n_ommatidia: int = 721,
        t_pre: float = 2.0,
        t_post: float = 0,
        intensity: float = 1,
        mode: Literal["sustained", "impulse"] = "sustained",
        device: torch.device = flyvision.device,
    ):
        if dot_column_radius > max_extent:
            raise ValueError(
                "dot_column_radius must be smaller than or equal to max_extent"
            )
        self.config = Namespace(
            dot_column_radius=dot_column_radius,
            max_extent=max_extent,
            bg_intensity=bg_intensity,
            t_stim=t_stim,
            dt=dt,
            n_ommatidia=n_ommatidia,
            t_pre=t_pre,
            t_post=t_post,
            intensity=intensity,
            mode=mode,
            t_impulse=t_impulse,
        )

        self.t_stim = t_stim
        self.t_pre = t_pre
        self.t_post = t_post

        self.n_ommatidia = n_ommatidia
        # Full index range: ring.where(1) below indexes the full lattice, so
        # self.offsets itself must not be filtered by the extent condition.
        self.offsets = np.arange(self.n_ommatidia)

        u, v = hex_utils.get_hex_coords(hex_utils.get_hextent(n_ommatidia))
        extent_condition = (
            (-max_extent <= u)
            & (u <= max_extent)
            & (-max_extent <= v)
            & (v <= max_extent)
            & (-max_extent <= u + v)
            & (u + v <= max_extent)
        )
        self.u = u[extent_condition]
        self.v = v[extent_condition]
        self.extent_condition = extent_condition
        # Offsets of the central columns that survive the extent condition, so
        # that each (u, v) is paired with its own ommatidium index. Zipping the
        # unfiltered offsets would mislabel every dot once max_extent is
        # smaller than the lattice extent.
        center_offsets = self.offsets[extent_condition]

        # to have multi column dots at every location, construct coordinate_indices
        # for each central column
        coordinate_indices = []
        for u_center, v_center in zip(self.u, self.v):
            ring = HexLattice.filled_circle(
                radius=dot_column_radius,
                center=Hexal(u_center, v_center, 0),
                as_lattice=True,
            )
            coordinate_indices.append(self.offsets[ring.where(1)])

        self.max_extent = max_extent
        self.bg_intensity = bg_intensity

        # Decrement and increment mirrored around the background intensity.
        self.intensities = [2 * bg_intensity - intensity, intensity]
        self.device = device
        self.mode = mode

        self.params = [
            (*location, level)
            for location, level in product(
                zip(self.u, self.v, center_offsets, coordinate_indices),
                self.intensities,
            )
        ]

        self.arg_df = pd.DataFrame(
            self.params,
            columns=["u", "v", "offset", "coordinate_index", "intensity"],
        )

        self.dt = dt
        self.t_impulse = t_impulse or self.dt

    def _params(self, key: int) -> np.ndarray:
        """Get parameters for a specific key.

        Args:
            key: Index of the parameters to retrieve.

        Returns:
            Array of parameters for the given key.
        """
        return self.arg_df.iloc[key].values

    def get_item(self, key: int) -> torch.Tensor:
        """Get a stimulus item for a specific key.

        Args:
            key: Index of the item to retrieve.

        Returns:
            Tensor representing the stimulus sequence: the stimulus padded with
            background before (up to t_pre) and after (up to t_post), squeezed.

        Raises:
            ValueError: If `self.mode` is neither 'sustained' nor 'impulse'.
        """
        # create maps with background value
        _dot = (
            torch.ones(self.n_ommatidia, device=self.device)[None, None]
            * self.bg_intensity
        )
        # fill the ommatidia of the dot column with the stimulus intensity
        _, _, _, coordinate_index, intensity = self._params(key)
        _dot[:, :, coordinate_index] = torch.tensor(intensity, device=self.device).float()

        # repeat for sustained stimulus
        if self.mode == "sustained":
            sequence = resample(_dot, self.t_stim, self.dt, dim=1, device=self.device)

        elif self.mode == "impulse":
            # pad remaining stimulus duration i.e. self.t_stim - self.dt with
            # background intensity
            if self.t_impulse == self.dt:
                sequence = pad(
                    _dot,
                    self.t_stim,
                    self.dt,
                    mode="end",
                    fill=self.bg_intensity,
                )
            # first resample for t_impulse/dt then pad remaining stimulus
            # duration, i.e. t_stim - t_impulse with background intensity
            else:
                sequence = resample(
                    _dot, self.t_impulse, self.dt, dim=1, device=self.device
                )
                sequence = pad(
                    sequence,
                    self.t_stim,
                    self.dt,
                    mode="end",
                    fill=self.bg_intensity,
                )
        else:
            # Without this, an unknown mode surfaced as an opaque
            # UnboundLocalError on `sequence` below.
            raise ValueError(
                f"mode must be 'sustained' or 'impulse', got {self.mode!r}"
            )

        # pad with pre stimulus background
        sequence = pad(
            sequence,
            self.t_stim + self.t_pre,
            self.dt,
            mode="start",
            fill=self.bg_intensity,
        )
        # pad with post stimulus background
        sequence = pad(
            sequence,
            self.t_stim + self.t_pre + self.t_post,
            self.dt,
            mode="end",
            fill=self.bg_intensity,
        )
        return sequence.squeeze()

    def get_stimulus_index(self, u: float, v: float, intensity: float) -> int:
        """Get sequence ID from given arguments.

        Args:
            u: U-coordinate.
            v: V-coordinate.
            intensity: Stimulus intensity.

        Returns:
            Sequence ID.
        """
        # locals() must stay the first expression evaluated here, so that only
        # self, u, v and intensity are forwarded.
        return StimulusDataset.get_stimulus_index(self, locals())

get_item

get_item(key)

Get a stimulus item for a specific key.

Parameters:

Name Type Description Default
key int

Index of the item to retrieve.

required

Returns:

Type Description
Tensor

Tensor representing the stimulus sequence.

Source code in flyvision/datasets/dots.py
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
def get_item(self, key: int) -> torch.Tensor:
    """Get a stimulus item for a specific key.

    Args:
        key: Index of the item to retrieve.

    Returns:
        Tensor representing the stimulus sequence: the stimulus padded with
        background before (up to t_pre) and after (up to t_post), squeezed.

    Raises:
        ValueError: If `self.mode` is neither 'sustained' nor 'impulse'.
    """
    # create maps with background value
    _dot = (
        torch.ones(self.n_ommatidia, device=self.device)[None, None]
        * self.bg_intensity
    )
    # fill the ommatidia of the dot column with the stimulus intensity
    _, _, _, coordinate_index, intensity = self._params(key)
    _dot[:, :, coordinate_index] = torch.tensor(intensity, device=self.device).float()

    # repeat for sustained stimulus
    if self.mode == "sustained":
        sequence = resample(_dot, self.t_stim, self.dt, dim=1, device=self.device)

    elif self.mode == "impulse":
        # pad remaining stimulus duration i.e. self.t_stim - self.dt with
        # background intensity
        if self.t_impulse == self.dt:
            sequence = pad(
                _dot,
                self.t_stim,
                self.dt,
                mode="end",
                fill=self.bg_intensity,
            )
        # first resample for t_impulse/dt then pad remaining stimulus
        # duration, i.e. t_stim - t_impulse with background intensity
        else:
            sequence = resample(
                _dot, self.t_impulse, self.dt, dim=1, device=self.device
            )
            sequence = pad(
                sequence,
                self.t_stim,
                self.dt,
                mode="end",
                fill=self.bg_intensity,
            )
    else:
        # Without this, an unknown mode surfaced as an opaque
        # UnboundLocalError on `sequence` below.
        raise ValueError(
            f"mode must be 'sustained' or 'impulse', got {self.mode!r}"
        )

    # pad with pre stimulus background
    sequence = pad(
        sequence,
        self.t_stim + self.t_pre,
        self.dt,
        mode="start",
        fill=self.bg_intensity,
    )
    # pad with post stimulus background
    sequence = pad(
        sequence,
        self.t_stim + self.t_pre + self.t_post,
        self.dt,
        mode="end",
        fill=self.bg_intensity,
    )
    return sequence.squeeze()

get_stimulus_index

get_stimulus_index(u, v, intensity)

Get sequence ID from given arguments.

Parameters:

Name Type Description Default
u float

U-coordinate.

required
v float

V-coordinate.

required
intensity float

Stimulus intensity.

required

Returns:

Type Description
int

Sequence ID.

Source code in flyvision/datasets/dots.py
228
229
230
231
232
233
234
235
236
237
238
239
def get_stimulus_index(self, u: float, v: float, intensity: float) -> int:
    """Get the sequence ID of the dot at (u, v) with the given intensity.

    The lookup is delegated to ``StimulusDataset.get_stimulus_index``, which
    receives ``locals()``: a dict holding ``self``, ``u``, ``v`` and
    ``intensity``. Presumably it matches these values against the columns of
    ``arg_df`` (TODO confirm against StimulusDataset).

    Args:
        u: U-coordinate.
        v: V-coordinate.
        intensity: Stimulus intensity.

    Returns:
        Sequence ID.
    """
    # NOTE: locals() must be evaluated before any other local name is bound,
    # otherwise the extra names would be forwarded to the lookup as well.
    return StimulusDataset.get_stimulus_index(self, locals())

flyvision.datasets.dots.CentralImpulses

Bases: StimulusDataset

Flashes at the center of the visual field for temporal receptive field mapping.

Parameters:

Name Type Description Default
impulse_durations List[float]

List of impulse durations.

[0.005, 0.02, 0.05, 0.1, 0.2, 0.3]
dot_column_radius int

Radius of the dot column.

0
bg_intensity float

Background intensity.

0.5
t_stim float

Stimulus duration.

5
dt float

Time step.

0.005
n_ommatidia int

Number of ommatidia.

721
t_pre float

Pre-stimulus duration.

2.0
t_post float

Post-stimulus duration.

0
intensity float

Stimulus intensity.

1
mode str

Stimulus mode.

'impulse'
device device

Torch device for computations.

device

Attributes:

Name Type Description
arg_df Optional[DataFrame]

DataFrame containing stimulus parameters.

dt Optional[float]

Time step.

dots

Instance of the Dots class.

impulse_durations

List of impulse durations.

config

Configuration namespace.

params

List of stimulus parameters.

Source code in flyvision/datasets/dots.py
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
class CentralImpulses(StimulusDataset):
    """Flashes at the center of the visual field for temporal receptive field mapping.

    Args:
        impulse_durations: List of impulse durations.
        dot_column_radius: Radius of the dot column.
        bg_intensity: Background intensity.
        t_stim: Stimulus duration.
        dt: Time step.
        n_ommatidia: Number of ommatidia.
        t_pre: Pre-stimulus duration.
        t_post: Post-stimulus duration.
        intensity: Stimulus intensity.
        mode: Stimulus mode.
        device: Torch device for computations.

    Attributes:
        arg_df: DataFrame containing stimulus parameters.
        dt: Time step.
        dots: Instance of the Dots class.
        impulse_durations: List of impulse durations.
        config: Configuration namespace.
        params: List of stimulus parameters.
    """

    arg_df: Optional[pd.DataFrame] = None
    dt: Optional[float] = None

    def __init__(
        self,
        impulse_durations: List[float] = [5e-3, 20e-3, 50e-3, 100e-3, 200e-3, 300e-3],
        dot_column_radius: int = 0,
        bg_intensity: float = 0.5,
        t_stim: float = 5,
        dt: float = 0.005,
        n_ommatidia: int = 721,
        t_pre: float = 2.0,
        t_post: float = 0,
        intensity: float = 1,
        mode: str = "impulse",
        device: torch.device = flyvision.device,
    ):
        """Initialize the CentralImpulses dataset.

        Args:
            impulse_durations: List of impulse durations. Lists are copied, so
                later changes to the instance do not affect the caller's list
                or the shared default.
            dot_column_radius: Radius of the dot column.
            bg_intensity: Background intensity.
            t_stim: Stimulus duration.
            dt: Time step.
            n_ommatidia: Number of ommatidia.
            t_pre: Pre-stimulus duration.
            t_post: Post-stimulus duration.
            intensity: Stimulus intensity.
            mode: Stimulus mode.
            device: Torch device for computations.
        """
        if isinstance(impulse_durations, list):
            # The default is a single shared list object; without a copy it
            # would be aliased by self.impulse_durations and self.config, and
            # any mutation would leak into every later default-constructed
            # instance.
            impulse_durations = list(impulse_durations)

        # max_extent == dot_column_radius restricts the dots to the center.
        self.dots = Dots(
            dot_column_radius=dot_column_radius,
            max_extent=dot_column_radius,
            bg_intensity=bg_intensity,
            t_stim=t_stim,
            dt=dt,
            n_ommatidia=n_ommatidia,
            t_pre=t_pre,
            t_post=t_post,
            intensity=intensity,
            mode=mode,
            device=device,
        )
        self.impulse_durations = impulse_durations
        # Shares (and extends) the Dots config object.
        self.config = self.dots.config
        self.config.update(impulse_durations=impulse_durations)
        self.params = [
            (*dot_row, duration)
            for dot_row, duration in product(
                self.dots.arg_df.values.tolist(), impulse_durations
            )
        ]
        self.arg_df = pd.DataFrame(
            self.params,
            columns=[
                "u",
                "v",
                "offset",
                "coordinate_index",
                "intensity",
                "t_impulse",
            ],
        )
        self.dt = dt

    def _params(self, key: int) -> np.ndarray:
        """Get parameters for a specific key.

        Args:
            key: Index of the parameters to retrieve.

        Returns:
            Array of parameters for the given key.
        """
        return self.arg_df.iloc[key].values

    def get_item(self, key: int) -> torch.Tensor:
        """Get a stimulus item for a specific key.

        Args:
            key: Index of the item to retrieve.

        Returns:
            Tensor representing the stimulus sequence.
        """
        u, v, offset, coordinate_index, intensity, t_impulse = self._params(key)
        # Dots reads the impulse duration from this attribute when rendering.
        self.dots.t_impulse = t_impulse
        return self.dots[self.dots.get_stimulus_index(u, v, intensity)]

    @property
    def t_pre(self) -> float:
        """Get pre-stimulus duration."""
        return self.dots.t_pre

    @property
    def t_post(self) -> float:
        """Get post-stimulus duration."""
        return self.dots.t_post

    def __repr__(self) -> str:
        """Get string representation of the dataset."""
        return repr(self.arg_df)

t_pre property

t_pre

Get pre-stimulus duration.

t_post property

t_post

Get post-stimulus duration.

__init__

__init__(
    impulse_durations=[0.005, 0.02, 0.05, 0.1, 0.2, 0.3],
    dot_column_radius=0,
    bg_intensity=0.5,
    t_stim=5,
    dt=0.005,
    n_ommatidia=721,
    t_pre=2.0,
    t_post=0,
    intensity=1,
    mode="impulse",
    device=flyvision.device,
)

Initialize the CentralImpulses dataset.

Parameters:

Name Type Description Default
impulse_durations List[float]

List of impulse durations.

[0.005, 0.02, 0.05, 0.1, 0.2, 0.3]
dot_column_radius int

Radius of the dot column.

0
bg_intensity float

Background intensity.

0.5
t_stim float

Stimulus duration.

5
dt float

Time step.

0.005
n_ommatidia int

Number of ommatidia.

721
t_pre float

Pre-stimulus duration.

2.0
t_post float

Post-stimulus duration.

0
intensity float

Stimulus intensity.

1
mode str

Stimulus mode.

'impulse'
device device

Torch device for computations.

device
Source code in flyvision/datasets/dots.py
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
def __init__(
    self,
    impulse_durations: List[float] = [5e-3, 20e-3, 50e-3, 100e-3, 200e-3, 300e-3],
    dot_column_radius: int = 0,
    bg_intensity: float = 0.5,
    t_stim: float = 5,
    dt: float = 0.005,
    n_ommatidia: int = 721,
    t_pre: float = 2.0,
    t_post: float = 0,
    intensity: float = 1,
    mode: str = "impulse",
    device: torch.device = flyvision.device,
):
    """Initialize the CentralImpulses dataset.

    Every parameter row of the underlying ``Dots`` dataset is paired with
    every impulse duration. The result is stored in ``self.params`` and
    ``self.arg_df``.

    Args:
        impulse_durations: List of impulse durations. The list is copied, so
            neither the shared default nor the caller's list is aliased by
            this instance or its config.
        dot_column_radius: Radius of the dot column.
        bg_intensity: Background intensity.
        t_stim: Stimulus duration.
        dt: Time step.
        n_ommatidia: Number of ommatidia.
        t_pre: Pre-stimulus duration.
        t_post: Post-stimulus duration.
        intensity: Stimulus intensity.
        mode: Stimulus mode.
        device: Torch device for computations.
    """
    # The same list object is stored on self and in the config below. Copying
    # it prevents in-place mutation from leaking into the mutable default
    # argument (and therefore into every later default-constructed instance).
    impulse_durations = list(impulse_durations)
    self.dots = Dots(
        dot_column_radius=dot_column_radius,
        # max_extent equals the column radius (presumably so that only the
        # central dot column is stimulated).
        max_extent=dot_column_radius,
        bg_intensity=bg_intensity,
        t_stim=t_stim,
        dt=dt,
        n_ommatidia=n_ommatidia,
        t_pre=t_pre,
        t_post=t_post,
        intensity=intensity,
        mode=mode,
        device=device,
    )
    self.impulse_durations = impulse_durations
    # NOTE: this is the Dots instance's config object, updated in place.
    self.config = self.dots.config
    self.config.update(impulse_durations=impulse_durations)
    # Cartesian product: each Dots parameter row (u, v, offset,
    # coordinate_index, intensity) extended by one impulse duration.
    self.params = [
        (*row, duration)
        for row, duration in product(
            self.dots.arg_df.values.tolist(), impulse_durations
        )
    ]
    self.arg_df = pd.DataFrame(
        self.params,
        columns=[
            "u",
            "v",
            "offset",
            "coordinate_index",
            "intensity",
            "t_impulse",
        ],
    )
    self.dt = dt

get_item

get_item(key)

Get a stimulus item for a specific key.

Parameters:

Name Type Description Default
key int

Index of the item to retrieve.

required

Returns:

Type Description
Tensor

Tensor representing the stimulus sequence.

Source code in flyvision/datasets/dots.py
343
344
345
346
347
348
349
350
351
352
353
354
def get_item(self, key: int) -> torch.Tensor:
    """Return the stimulus sequence stored at dataset index ``key``.

    As a side effect, ``self.dots.t_impulse`` is set to the impulse duration
    of that row before the stimulus is looked up.

    Args:
        key: Index of the item to retrieve.

    Returns:
        Tensor representing the stimulus sequence.
    """
    row = self._params(key)
    u, v, _offset, _coordinate_index, intensity, t_impulse = row
    dots = self.dots
    # The duration must be set before indexing into the Dots dataset.
    dots.t_impulse = t_impulse
    stimulus_index = dots.get_stimulus_index(u, v, intensity)
    return dots[stimulus_index]

__repr__

__repr__()

Get string representation of the dataset.

Source code in flyvision/datasets/dots.py
366
367
368
def __repr__(self) -> str:
    """Represent the dataset by the repr of its parameter table ``arg_df``."""
    table = self.arg_df
    return f"{table!r}"

flyvision.datasets.dots.SpatialImpulses

Bases: StimulusDataset

Spatial flashes for spatial receptive field mapping.

Parameters:

Name Type Description Default
impulse_durations List[float]

List of impulse durations.

[0.005, 0.02]
max_extent int

Maximum extent of the stimulus.

4
dot_column_radius int

Radius of the dot column.

0
bg_intensity float

Background intensity.

0.5
t_stim float

Stimulus duration.

5
dt float

Time step.

0.005
n_ommatidia int

Number of ommatidia.

721
t_pre float

Pre-stimulus duration.

2.0
t_post float

Post-stimulus duration.

0
intensity float

Stimulus intensity.

1
mode str

Stimulus mode.

'impulse'
device device

Torch device for computations.

device

Attributes:

Name Type Description
arg_df Optional[DataFrame]

DataFrame containing stimulus parameters.

dt Optional[float]

Time step.

dots

Instance of the Dots class.

impulse_durations

List of impulse durations.

config

Configuration namespace.

params

List of stimulus parameters.

Source code in flyvision/datasets/dots.py
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
class SpatialImpulses(StimulusDataset):
    """Spatial flashes for spatial receptive field mapping.

    Every parameter row of the underlying ``Dots`` dataset is paired with
    every impulse duration, so ``arg_df`` has
    ``len(dots.arg_df) * len(impulse_durations)`` rows.

    Args:
        impulse_durations: List of impulse durations. The list is copied, so
            neither the shared default nor the caller's list is aliased by
            this instance or its config.
        max_extent: Maximum extent of the stimulus.
        dot_column_radius: Radius of the dot column.
        bg_intensity: Background intensity.
        t_stim: Stimulus duration.
        dt: Time step.
        n_ommatidia: Number of ommatidia.
        t_pre: Pre-stimulus duration.
        t_post: Post-stimulus duration.
        intensity: Stimulus intensity.
        mode: Stimulus mode.
        device: Torch device for computations.

    Attributes:
        arg_df: DataFrame containing stimulus parameters.
        dt: Time step.
        dots: Instance of the Dots class.
        impulse_durations: List of impulse durations.
        config: Configuration namespace.
        params: List of stimulus parameters.
    """

    arg_df: Optional[pd.DataFrame] = None
    dt: Optional[float] = None

    def __init__(
        self,
        impulse_durations: List[float] = [5e-3, 20e-3],
        max_extent: int = 4,
        dot_column_radius: int = 0,
        bg_intensity: float = 0.5,
        t_stim: float = 5,
        dt: float = 0.005,
        n_ommatidia: int = 721,
        t_pre: float = 2.0,
        t_post: float = 0,
        intensity: float = 1,
        mode: str = "impulse",
        device: torch.device = flyvision.device,
    ):
        # The same list object is stored on self and in the config below.
        # Copying it prevents in-place mutation from leaking into the mutable
        # default argument shared by all default-constructed instances.
        impulse_durations = list(impulse_durations)
        self.dots = Dots(
            dot_column_radius=dot_column_radius,
            max_extent=max_extent,
            bg_intensity=bg_intensity,
            t_stim=t_stim,
            dt=dt,
            n_ommatidia=n_ommatidia,
            t_pre=t_pre,
            t_post=t_post,
            intensity=intensity,
            mode=mode,
            device=device,
        )
        self.dt = dt
        self.impulse_durations = impulse_durations

        # NOTE: this is the Dots instance's config object, updated in place.
        self.config = self.dots.config
        self.config.update(impulse_durations=impulse_durations)

        # Cartesian product: each Dots parameter row (u, v, offset,
        # coordinate_index, intensity) extended by one impulse duration.
        self.params = [
            (*row, duration)
            for row, duration in product(
                self.dots.arg_df.values.tolist(), impulse_durations
            )
        ]
        self.arg_df = pd.DataFrame(
            self.params,
            columns=[
                "u",
                "v",
                "offset",
                "coordinate_index",
                "intensity",
                "t_impulse",
            ],
        )

    def _params(self, key: int) -> np.ndarray:
        """Get parameters for a specific key.

        Args:
            key: Positional index of the parameters to retrieve.

        Returns:
            Array of the ``arg_df`` row at position ``key``.
        """
        return self.arg_df.iloc[key].values

    def get_item(self, key: int) -> torch.Tensor:
        """Get a stimulus item for a specific key.

        As a side effect, ``self.dots.t_impulse`` is set to the impulse
        duration of the selected row before the stimulus is looked up.

        Args:
            key: Index of the item to retrieve.

        Returns:
            Tensor representing the stimulus sequence.
        """
        u, v, offset, coordinate_index, intensity, t_impulse = self._params(key)
        self.dots.t_impulse = t_impulse
        return self.dots[self.dots.get_stimulus_index(u, v, intensity)]

    def __repr__(self) -> str:
        """Get string representation of the dataset."""
        return repr(self.arg_df)

get_item

get_item(key)

Get a stimulus item for a specific key.

Parameters:

Name Type Description Default
key int

Index of the item to retrieve.

required

Returns:

Type Description
Tensor

Tensor representing the stimulus sequence.

Source code in flyvision/datasets/dots.py
461
462
463
464
465
466
467
468
469
470
471
472
def get_item(self, key: int) -> torch.Tensor:
    """Look up the stimulus sequence for dataset index ``key``.

    The row's impulse duration is written to ``self.dots.t_impulse`` before
    the Dots dataset is indexed.

    Args:
        key: Index of the item to retrieve.

    Returns:
        Tensor representing the stimulus sequence.
    """
    (u, v, _offset, _coord_index,
     intensity, duration) = self._params(key)
    dots = self.dots
    dots.t_impulse = duration
    idx = dots.get_stimulus_index(u, v, intensity)
    return dots[idx]

__repr__

__repr__()

Get string representation of the dataset.

Source code in flyvision/datasets/dots.py
474
475
476
def __repr__(self) -> str:
    """Represent the dataset by the repr of its parameter table ``arg_df``."""
    return "%r" % (self.arg_df,)

Analysis

flyvision.analysis.flash_responses

Analysis of responses to flash stimuli.

Info

Relies on xarray dataset format defined in flyvision.analysis.stimulus_responses.

flash_response_index

flash_response_index(
    self,
    radius,
    on_intensity=1.0,
    off_intensity=0.0,
    nonnegative=True,
)

Compute the Flash Response Index (FRI) using xarray methods.

Parameters:

Name Type Description Default
self Dataset

The input Dataset containing response data: a 'responses' variable and a 'config' entry in its attrs.

required
radius float

The radius value to select data for.

required
on_intensity float

The intensity value for the ‘on’ state.

1.0
off_intensity float

The intensity value for the ‘off’ state.

0.0
nonnegative bool

If True, shifts the responses by the absolute value of their minimum (taken over frames and samples) before the peaks are computed.

True

Returns:

Type Description
DataArray

xr.DataArray: The computed Flash Response Index.

Note

Asserts that the stimulus alternations are (0, 1, 0), the configuration required for FRI computation.

Source code in flyvision/analysis/flash_responses.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
def flash_response_index(
    self: xr.Dataset,
    radius: float,
    on_intensity: float = 1.0,
    off_intensity: float = 0.0,
    nonnegative: bool = True,
) -> xr.DataArray:
    """
    Compute the Flash Response Index (FRI) using xarray methods.

    The FRI is ``(on_peak - off_peak) / (on_peak + off_peak)``. Each peak is
    the maximum over frames with ``-dt <= time <= t_stim`` of the responses
    to the flash with the given radius and ON or OFF intensity.

    Args:
        self: Dataset with a ``responses`` variable (dimensions including
            ``frame`` and ``sample``) and an ``attrs['config']`` mapping that
            provides ``alternations``, ``dt`` and ``t_stim``.
        radius: The radius value to select data for.
        on_intensity: The intensity value for the 'on' state.
        off_intensity: The intensity value for the 'off' state.
        nonnegative: If True, shift the responses by the absolute value of
            their minimum over frames and samples before taking peaks.

    Returns:
        xr.DataArray: The computed Flash Response Index. Samples that contain
        NaN are dropped.

    Raises:
        AssertionError: If ``config['alternations']`` is not ``(0, 1, 0)``.
    """
    config = self.attrs['config']

    # The computation relies on the default flash alternations (0, 1, 0).
    assert tuple(config['alternations']) == (0, 1, 0), (
        "flash_response_index requires alternations == (0, 1, 0), got "
        f"{tuple(config['alternations'])}"
    )

    responses = self['responses']

    # Keep frames with -dt <= time <= t_stim (presumably the stimulus period,
    # if time 0 marks stimulus onset).
    time_query = f"{-config['dt']} <= time <= {config['t_stim']}"
    stim_response = responses.query(frame=time_query)

    # Select the data for the given radius
    stim_response = stim_response.query(sample=f'radius=={radius}')

    if nonnegative:
        # Shift by |min| taken over frames and samples (one value per
        # remaining coordinate, e.g. per neuron).
        minimum = stim_response.min(dim=['frame', 'sample'])
        stim_response += np.abs(minimum)

    # Select the response data for on and off intensities
    r_on = stim_response.query(sample=f'intensity=={on_intensity}')
    r_off = stim_response.query(sample=f'intensity=={off_intensity}')

    # Peak responses along the 'frame' dimension
    on_peak = r_on.max(dim='frame')
    off_peak = r_off.max(dim='frame')

    # ON and OFF peaks come from different samples. Dropping the 'sample'
    # coordinate makes xarray combine them positionally instead of aligning
    # on disjoint labels. ``drop_vars`` replaces the deprecated
    # ``DataArray.drop``.
    on_peak = on_peak.drop_vars('sample')
    off_peak = off_peak.drop_vars('sample')

    # Compute the Flash Response Index (FRI); a scalar epsilon guards against
    # division by zero when both peaks are zero.
    fri = on_peak - off_peak
    fri /= on_peak + off_peak + 1e-16

    # Drop samples containing NaN values
    return fri.dropna(dim='sample', how='any')

fri_correlation_to_known

fri_correlation_to_known(fris)

Compute the correlation of the FRI to known cell type tunings.

Parameters:

Name Type Description Default
fris DataArray

DataArray containing Flash Response Index values.

required

Returns:

Type Description
DataArray

xr.DataArray: Correlation of FRIs to known cell type tunings.

Source code in flyvision/analysis/flash_responses.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
def fri_correlation_to_known(fris: xr.DataArray) -> xr.DataArray:
    """
    Compute the correlation of the FRI to known cell type tunings.

    Only cell types with a non-zero entry in ``groundtruth_utils.polarity``
    are used. Their FRIs are correlated with those polarities along the
    ``neuron`` dimension.

    Args:
        fris: DataArray containing Flash Response Index values. It needs a
            ``neuron`` dimension and a ``cell_type`` coordinate (presumably
            along ``neuron``).

    Returns:
        xr.DataArray: Correlation of FRIs to known cell type tunings.

    Raises:
        ValueError: If a known cell type does not occur exactly once in
            ``fris.cell_type``.
    """
    known_preferred_contrasts = {
        k: v for k, v in groundtruth_utils.polarity.items() if v != 0
    }
    known_cell_types = list(known_preferred_contrasts.keys())
    groundtruth = list(known_preferred_contrasts.values())

    # Convert the coordinate to numpy once, rather than comparing against the
    # DataArray for every known cell type.
    cell_types = np.asarray(fris.cell_type)
    # .item() raises ValueError unless exactly one neuron matches.
    index = np.array(
        [np.where(cell_types == nt)[0].item() for nt in known_cell_types]
    )
    fris = fris.isel(neuron=index)
    groundtruth = xr.DataArray(
        data=groundtruth,
        dims=["neuron"],
    )
    return xr.corr(fris, groundtruth, dim="neuron")

plot_fris

plot_fris(
    fris,
    cell_types,
    scatter_best=False,
    scatter_all=True,
    bold_output_type_labels=True,
    output_cell_types=None,
    known_first=True,
    sorted_type_list=None,
    figsize=[10, 1],
    cmap=plt.cm.Greys_r,
    ylim=(-1, 1),
    color_known_types=True,
    fontsize=6,
    colors=None,
    color="b",
    showmeans=False,
    showmedians=True,
    scatter_edge_width=0.5,
    scatter_best_edge_width=0.75,
    scatter_edge_color="none",
    scatter_face_color="k",
    scatter_alpha=0.35,
    scatter_best_alpha=1.0,
    scatter_all_marker="o",
    scatter_best_index=None,
    scatter_best_marker="o",
    scatter_best_color=None,
    mean_median_linewidth=1.5,
    mean_median_bar_length=1.0,
    violin_alpha=0.3,
    **kwargs
)

Plot flash response indices (FRIs) for the given cell types with violins.

Parameters:

Name Type Description Default
fris ndarray

Array of FRI values (n_random_variables, n_groups, n_samples).

required
cell_types ndarray

Array of cell type labels, corresponding to the first axis (n_random_variables) of fris.

required
scatter_best bool

If True, scatter the best points.

False
scatter_all bool

If True, scatter all points.

True
bold_output_type_labels bool

If True, bold the output type labels.

True
output_cell_types Optional[List[str]]

List of cell types to bold in the output.

None
known_first bool

If True, sort known cell types first.

True
sorted_type_list Optional[List[str]]

List of cell types to sort by.

None
figsize List[int]

Figure size as [width, height].

[10, 1]
cmap cm

Colormap for the plot.

Greys_r
ylim Tuple[float, float]

Y-axis limits as (min, max).

(-1, 1)
color_known_types bool

If True, color known cell type labels.

True
fontsize int

Font size for labels.

6
colors Optional[List[str]]

List of colors for the violins.

None
color str

Single color for all violins, used only if both colors and cmap are None.

'b'
showmeans bool

If True, show means on the violins.

False
showmedians bool

If True, show medians on the violins.

True
scatter_edge_width float

Width of scatter point edges.

0.5
scatter_best_edge_width float

Width of best scatter point edges.

0.75
scatter_edge_color str

Color of scatter point edges.

'none'
scatter_face_color str

Color of scatter point faces.

'k'
scatter_alpha float

Alpha value for scatter points.

0.35
scatter_best_alpha float

Alpha value for best scatter points.

1.0
scatter_all_marker str

Marker style for all scatter points.

'o'
scatter_best_index Optional[int]

Index of the best scatter point.

None
scatter_best_marker str

Marker style for the best scatter point.

'o'
scatter_best_color Optional[str]

Color of the best scatter point.

None
mean_median_linewidth float

Line width for mean/median lines.

1.5
mean_median_bar_length float

Length of mean/median bars.

1.0
violin_alpha float

Alpha value for violin plots.

0.3
**kwargs

Additional keyword arguments for violin_groups.

{}

Returns:

Type Description
Tuple[Figure, Axes]

Tuple containing the Figure and Axes objects.

Source code in flyvision/analysis/flash_responses.py
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
def plot_fris(
    fris: np.ndarray,
    cell_types: np.ndarray,
    scatter_best: bool = False,
    scatter_all: bool = True,
    bold_output_type_labels: bool = True,
    output_cell_types: Optional[List[str]] = None,
    known_first: bool = True,
    sorted_type_list: Optional[List[str]] = None,
    figsize: List[int] = [10, 1],
    cmap: plt.cm = plt.cm.Greys_r,
    ylim: Tuple[float, float] = (-1, 1),
    color_known_types: bool = True,
    fontsize: int = 6,
    colors: Optional[List[str]] = None,
    color: str = "b",
    showmeans: bool = False,
    showmedians: bool = True,
    scatter_edge_width: float = 0.5,
    scatter_best_edge_width: float = 0.75,
    scatter_edge_color: str = "none",
    scatter_face_color: str = "k",
    scatter_alpha: float = 0.35,
    scatter_best_alpha: float = 1.0,
    scatter_all_marker: str = "o",
    scatter_best_index: Optional[int] = None,
    scatter_best_marker: str = "o",
    scatter_best_color: Optional[str] = None,
    mean_median_linewidth: float = 1.5,
    mean_median_bar_length: float = 1.0,
    violin_alpha: float = 0.3,
    **kwargs,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot flash response indices (FRIs) for the given cell types with violins.

    Args:
        fris: Array of FRI values (n_random_variables, n_groups, n_samples).
            A 2D array gets a singleton group axis inserted at position 1.
            If the first axis does not match ``len(cell_types)``, the axes
            are reversed, so (n_samples, n_groups, n_random_variables) is
            accepted too.
        cell_types: Array of cell type labels, corresponding to the first axis
            (n_random_variables) of `fris`.
        scatter_best: If True, scatter the best points.
        scatter_all: If True, scatter all points.
        bold_output_type_labels: If True, bold the output type labels.
        output_cell_types: List of cell types to bold in the output.
        known_first: If True, sort known cell types first.
        sorted_type_list: List of cell types to sort by.
        figsize: Figure size as [width, height].
        cmap: Colormap for the plot.
        ylim: Y-axis limits as (min, max).
        color_known_types: If True, color known cell type labels.
        fontsize: Font size for labels.
        colors: List of colors for the violins. Takes precedence over `cmap`
            and `color`.
        color: Single color for all violins, used only if both `colors` and
            `cmap` are None.
        showmeans: If True, show means on the violins.
        showmedians: If True, show medians on the violins.
        scatter_edge_width: Width of scatter point edges.
        scatter_best_edge_width: Width of best scatter point edges.
        scatter_edge_color: Color of scatter point edges.
        scatter_face_color: Color of scatter point faces.
        scatter_alpha: Alpha value for scatter points.
        scatter_best_alpha: Alpha value for best scatter points.
        scatter_all_marker: Marker style for all scatter points.
        scatter_best_index: Index of the best scatter point.
        scatter_best_marker: Marker style for the best scatter point.
        scatter_best_color: Color of the best scatter point.
        mean_median_linewidth: Line width for mean/median lines.
        mean_median_bar_length: Length of mean/median bars.
        violin_alpha: Alpha value for violin plots.
        **kwargs: Additional keyword arguments for violin_groups.

    Returns:
        Tuple containing the Figure and Axes objects.
    """
    # Process FRIs data: ensure a 3D array with cell types on the first axis
    if len(fris.shape) != 3:
        fris = fris[:, None]
    if fris.shape[0] != len(cell_types):
        fris = np.transpose(fris, (2, 1, 0))

    # Sort cell types
    if sorted_type_list is not None:
        fris = nodes_edges_utils.sort_by_mapping_lists(
            cell_types, sorted_type_list, fris, axis=0
        )
        cell_types = np.array(sorted_type_list)
    if known_first:
        _cell_types = nodes_edges_utils.nodes_list_sorting_on_off_unknown(cell_types)
        fris = nodes_edges_utils.sort_by_mapping_lists(
            cell_types, _cell_types, fris, axis=0
        )
        cell_types = np.array(_cell_types)

    # Set colors: explicit `colors` win, then `cmap`; the single `color` is
    # only used when neither is given (in that case `cmap` is already None).
    if colors is None and cmap is None and color is not None:
        colors = (color,)

    # Create violin plot
    fig, ax, colors = violin_groups(
        fris,
        cell_types[:],
        rotation=90,
        scatter=False,
        cmap=cmap,
        colors=colors,
        fontsize=fontsize,
        figsize=figsize,
        width=0.7,
        showmeans=showmeans,
        showmedians=showmedians,
        mean_median_linewidth=mean_median_linewidth,
        mean_median_bar_length=mean_median_bar_length,
        violin_alpha=violin_alpha,
        **kwargs,
    )

    # Add scatter points only when there is a single group
    if fris.shape[1] == 1:
        plt_utils.scatter_on_violins_with_best(
            fris.T.squeeze(),
            ax,
            scatter_best,
            scatter_all,
            best_index=scatter_best_index,
            linewidth=scatter_edge_width,
            best_linewidth=scatter_best_edge_width,
            edgecolor=scatter_edge_color,
            facecolor=scatter_face_color,
            all_scatter_alpha=scatter_alpha,
            best_scatter_alpha=scatter_best_alpha,
            all_marker=scatter_all_marker,
            best_marker=scatter_best_marker,
            best_color=scatter_best_color,
        )

    # Customize plot appearance
    ax.grid(False)
    if bold_output_type_labels and output_cell_types is not None:
        plt_utils.boldify_labels(output_cell_types, ax)
    ax.set_ylim(*ylim)
    plt_utils.trim_axis(ax)
    plt_utils.set_spine_tick_params(
        ax,
        tickwidth=0.5,
        ticklength=3,
        ticklabelpad=2,
        spinewidth=0.5,
    )
    if color_known_types:
        ax = flash_response_color_labels(ax)
    # Horizontal reference line at FRI == 0 spanning the x ticks
    ax.hlines(
        0,
        min(ax.get_xticks()),
        max(ax.get_xticks()),
        linewidth=0.25,
        color="k",
        zorder=0,
    )
    ax.set_yticks(np.arange(-1.0, 1.5, 0.5))

    return fig, ax

flash_response_color_labels

flash_response_color_labels(ax)

Color the labels of ON and OFF cells in the plot.

Parameters:

Name Type Description Default
ax Axes

The matplotlib Axes object to modify.

required

Returns:

Type Description
Axes

The modified matplotlib Axes object.

Source code in flyvision/analysis/flash_responses.py
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
def flash_response_color_labels(ax: plt.Axes) -> plt.Axes:
    """
    Color the labels of cell types with a known ON or OFF polarity.

    Cell types whose entry in ``groundtruth_utils.polarity`` is ``1`` are
    colored with ``ON_FR``; those with ``-1`` are colored with ``OFF_FR``.

    Args:
        ax: The matplotlib Axes object to modify.

    Returns:
        The same Axes object that was passed in.
    """
    on_types = []
    off_types = []
    for cell_type, sign in groundtruth_utils.polarity.items():
        if sign == 1:
            on_types.append(cell_type)
        elif sign == -1:
            off_types.append(cell_type)
    plt_utils.color_labels(on_types, ON_FR, ax)
    plt_utils.color_labels(off_types, OFF_FR, ax)
    return ax