Skip to content

Visualization

flyvision.analysis.visualization.figsize_utils

Functions

flyvision.analysis.visualization.figsize_utils.figsize_from_n_items

figsize_from_n_items(
    n_panels,
    max_figure_height_cm=22,
    panel_height_cm=3,
    max_figure_width_cm=18,
    panel_width_cm=3.6,
    dw_cm=0.1,
)

Calculate figure size based on the number of panels.

Parameters:

Name Type Description Default
n_panels int

Number of panels in the figure.

required
max_figure_height_cm float

Maximum figure height in centimeters.

22
panel_height_cm float

Height of each panel in centimeters.

3
max_figure_width_cm float

Maximum figure width in centimeters.

18
panel_width_cm float

Width of each panel in centimeters.

3.6
dw_cm float

Step in centimeters by which the panel width is reduced on each attempt to fit the grid into the maximum figure size.

0.1

Returns:

Name Type Description
FigsizeCM FigsizeCM

Calculated figure size.

Source code in flyvision/analysis/visualization/figsize_utils.py
(source lines 9–43)
def figsize_from_n_items(
    n_panels: int,
    max_figure_height_cm: float = 22,
    panel_height_cm: float = 3,
    max_figure_width_cm: float = 18,
    panel_width_cm: float = 3.6,
    dw_cm: float = 0.1,
) -> "FigsizeCM":
    """
    Calculate figure size based on the number of panels.

    The grid starts with as many columns as fit into ``max_figure_width_cm`` at
    the requested panel width (but at least one) and just enough rows to hold
    ``n_panels``. ``fit_panel_size`` then shrinks the panels if the grid does
    not fit the maximum figure size.

    Args:
        n_panels: Number of panels in the figure.
        max_figure_height_cm: Maximum figure height in centimeters.
        panel_height_cm: Height of each panel in centimeters.
        max_figure_width_cm: Maximum figure width in centimeters.
        panel_width_cm: Width of each panel in centimeters.
        dw_cm: Step in centimeters by which the panel width is reduced while
            fitting the grid into the maximum figure size.

    Returns:
        FigsizeCM: Calculated figure size.
    """
    # A panel wider than the whole figure would yield zero columns, and the
    # row search below would then never terminate; always keep one column.
    n_columns = max(1, int(max_figure_width_cm / panel_width_cm))
    # Smallest number of rows (at least one) whose grid holds all panels.
    n_rows = 1
    while n_columns * n_rows < n_panels:
        n_rows += 1
    return fit_panel_size(
        n_rows,
        n_columns,
        max_figure_height_cm,
        panel_height_cm,
        max_figure_width_cm,
        panel_width_cm,
        dw_cm,
    )

flyvision.analysis.visualization.figsize_utils.figure_size_cm

figure_size_cm(
    n_panel_rows,
    n_panel_columns,
    max_figure_height_cm=22,
    panel_height_cm=3,
    max_figure_width_cm=18,
    panel_width_cm=3.6,
    allow_rearranging=True,
)

Calculate figure size in centimeters.

Parameters:

Name Type Description Default
n_panel_rows int

Number of panel rows.

required
n_panel_columns int

Number of panel columns.

required
max_figure_height_cm float

Maximum figure height in centimeters.

22
panel_height_cm float

Height of each panel in centimeters.

3
max_figure_width_cm float

Maximum figure width in centimeters.

18
panel_width_cm float

Width of each panel in centimeters.

3.6
allow_rearranging bool

Whether to allow rearranging panels.

True

Returns:

Name Type Description
FigsizeCM FigsizeCM

Calculated figure size.

Raises:

Type Description
ValueError

If the figure size is not realizable under given constraints.

Source code in flyvision/analysis/visualization/figsize_utils.py
(source lines 120–182)
def figure_size_cm(
    n_panel_rows: int,
    n_panel_columns: int,
    max_figure_height_cm: float = 22,
    panel_height_cm: float = 3,
    max_figure_width_cm: float = 18,
    panel_width_cm: float = 3.6,
    allow_rearranging: bool = True,
) -> "FigsizeCM":
    """
    Calculate figure size in centimeters.

    If the grid exceeds exactly one of the maximum dimensions and rearranging
    is allowed, one column (or row) is removed and rows (or columns) are added
    until the grid holds at least as many panels as before. This repeats until
    the grid fits or is shown not to fit.

    Args:
        n_panel_rows: Number of panel rows.
        n_panel_columns: Number of panel columns.
        max_figure_height_cm: Maximum figure height in centimeters.
        panel_height_cm: Height of each panel in centimeters.
        max_figure_width_cm: Maximum figure width in centimeters.
        panel_width_cm: Width of each panel in centimeters.
        allow_rearranging: Whether to allow rearranging panels.

    Returns:
        FigsizeCM: Calculated figure size.

    Raises:
        ValueError: If the figure size is not realizable under given
            constraints. This includes the cases where rearranging would
            need zero rows or columns, or would only trade an overflow in
            one dimension for an overflow in the other.
    """
    not_realizable = "Not realizable under given size constraints"
    width = n_panel_columns * panel_width_cm
    height = n_panel_rows * panel_height_cm
    n_panels = n_panel_rows * n_panel_columns
    too_wide = width > max_figure_width_cm
    too_tall = height > max_figure_height_cm

    if not (too_wide or too_tall):
        return FigsizeCM(n_panel_rows, n_panel_columns, height, width)
    if (too_wide and too_tall) or not allow_rearranging:
        raise ValueError(not_realizable)

    if too_wide:
        # Trade one column for as many extra rows as needed to keep capacity.
        n_panel_columns -= 1
        if n_panel_columns < 1 and n_panels > 0:
            # Even a single column is too wide; no row count can compensate
            # (the capacity loop below would never terminate).
            raise ValueError(not_realizable)
        while n_panel_columns * n_panel_rows < n_panels:
            n_panel_rows += 1
        if n_panel_rows * panel_height_cm > max_figure_height_cm:
            # The added rows overflow the height. Further rearranging would
            # only swap rows and columns back and forth without ever fitting.
            raise ValueError(not_realizable)
    else:
        # Trade one row for as many extra columns as needed to keep capacity.
        n_panel_rows -= 1
        if n_panel_rows < 1 and n_panels > 0:
            # Even a single row is too tall; no column count can compensate.
            raise ValueError(not_realizable)
        while n_panel_columns * n_panel_rows < n_panels:
            n_panel_columns += 1
        if n_panel_columns * panel_width_cm > max_figure_width_cm:
            # The added columns overflow the width; see the branch above.
            raise ValueError(not_realizable)

    return figure_size_cm(
        n_panel_rows=n_panel_rows,
        n_panel_columns=n_panel_columns,
        max_figure_height_cm=max_figure_height_cm,
        panel_height_cm=panel_height_cm,
        max_figure_width_cm=max_figure_width_cm,
        panel_width_cm=panel_width_cm,
        allow_rearranging=allow_rearranging,
    )

flyvision.analysis.visualization.figsize_utils.fit_panel_size

fit_panel_size(
    n_panel_rows,
    n_panel_columns,
    max_figure_height_cm=22,
    panel_height_cm=3,
    max_figure_width_cm=18,
    panel_width_cm=3.6,
    dw_cm=0.1,
    allow_rearranging=True,
)

Fit panel size to figure constraints.

Parameters:

Name Type Description Default
n_panel_rows int

Number of panel rows.

required
n_panel_columns int

Number of panel columns.

required
max_figure_height_cm float

Maximum figure height in centimeters.

22
panel_height_cm float

Height of each panel in centimeters.

3
max_figure_width_cm float

Maximum figure width in centimeters.

18
panel_width_cm float

Width of each panel in centimeters.

3.6
dw_cm float

Step in centimeters by which the panel width is reduced on each attempt to fit the grid into the maximum figure size.

0.1
allow_rearranging bool

Whether to allow rearranging panels.

True

Returns:

Name Type Description
FigsizeCM FigsizeCM

Fitted figure size.

Source code in flyvision/analysis/visualization/figsize_utils.py
(source lines 185–235)
def fit_panel_size(
    n_panel_rows: int,
    n_panel_columns: int,
    max_figure_height_cm: float = 22,
    panel_height_cm: float = 3,
    max_figure_width_cm: float = 18,
    panel_width_cm: float = 3.6,
    dw_cm: float = 0.1,
    allow_rearranging: bool = True,
) -> "FigsizeCM":
    """
    Fit panel size to figure constraints.

    Tries ``figure_size_cm`` with the given panel size. Whenever that raises
    ``ValueError``, the panel width is reduced by ``dw_cm`` and the panel
    height follows at the original width/height ratio. The attempt is then
    repeated.

    Args:
        n_panel_rows: Number of panel rows.
        n_panel_columns: Number of panel columns.
        max_figure_height_cm: Maximum figure height in centimeters.
        panel_height_cm: Height of each panel in centimeters.
        max_figure_width_cm: Maximum figure width in centimeters.
        panel_width_cm: Width of each panel in centimeters.
        dw_cm: Step in centimeters by which the panel width is reduced per
            attempt.
        allow_rearranging: Whether to allow rearranging panels.

    Returns:
        FigsizeCM: Fitted figure size.

    Raises:
        ValueError: If the layout does not fit and ``dw_cm`` is not positive,
            or if no positive panel width fits the constraints.
    """
    ratio = panel_width_cm / panel_height_cm

    # Iterate instead of recursing so that many shrink steps cannot exceed
    # the interpreter's recursion limit.
    while True:
        try:
            return figure_size_cm(
                n_panel_rows,
                n_panel_columns,
                max_figure_height_cm,
                panel_height_cm,
                max_figure_width_cm,
                panel_width_cm,
                allow_rearranging=allow_rearranging,
            )
        except ValueError as error:
            if dw_cm <= 0:
                # Without a positive step the panels never shrink.
                raise ValueError(
                    "dw_cm must be positive to shrink the panels"
                ) from error
            panel_width_cm -= dw_cm
            if panel_width_cm <= 0:
                raise ValueError(
                    "Not realizable under given size constraints"
                ) from error
            panel_height_cm = panel_width_cm / ratio

flyvision.analysis.visualization.figsize_utils.cm_to_inch

cm_to_inch(*args)

Convert centimeters to inches.

Parameters:

Name Type Description Default
*args Union[Tuple[float, float], float]

Either a tuple of (width, height) or separate width and height values.

()

Returns:

Type Description
Tuple[float, float]

Tuple of width and height in inches.

Source code in flyvision/analysis/visualization/figsize_utils.py
(source lines 238–254)
def cm_to_inch(*args: Union[Tuple[float, float], float]) -> Tuple[float, float]:
    """
    Convert a (width, height) size from centimeters to inches.

    Args:
        *args: Either a single (width, height) pair, or width and height as
            two separate positional values, in centimeters.

    Returns:
        The (width, height) pair in inches.

    Raises:
        ValueError: If neither one nor two positional arguments are given.
    """
    cm_per_inch = 2.54
    if len(args) == 2:
        size_cm = args
    elif len(args) == 1:
        (size_cm,) = args
    else:
        raise ValueError("Invalid number of arguments")
    width_cm, height_cm = size_cm
    return width_cm / cm_per_inch, height_cm / cm_per_inch

Classes

flyvision.analysis.visualization.figsize_utils.FigsizeCM dataclass

Represents figure size in centimeters.

Attributes:

Name Type Description
n_rows int

Number of rows in the figure.

n_columns int

Number of columns in the figure.

height float

Height of the figure in centimeters.

width float

Width of the figure in centimeters.

pad float

Padding in centimeters.

Source code in flyvision/analysis/visualization/figsize_utils.py
(source lines 46–117)
@dataclass
class FigsizeCM:
    """
    Size of a panel-grid figure, measured in centimeters.

    Attributes:
        n_rows: Number of panel rows in the figure.
        n_columns: Number of panel columns in the figure.
        height: Figure height in centimeters, without padding.
        width: Figure width in centimeters, without padding.
        pad: Padding in centimeters, added to both width and height when the
            size is converted to inches.
    """

    n_rows: int
    n_columns: int
    height: float
    width: float
    pad: float = 0.5

    @property
    def inches_wh(self) -> Tuple[float, float]:
        """Padded figure size as (width, height) in inches."""
        padded_width = self.width + self.pad
        padded_height = self.height + self.pad
        return cm_to_inch(padded_width, padded_height)

    @property
    def panel_height_cm(self) -> float:
        """Height of one panel row in centimeters."""
        rows = self.n_rows
        return self.height / rows

    @property
    def panel_width_cm(self) -> float:
        """Width of one panel column in centimeters."""
        columns = self.n_columns
        return self.width / columns

    def axis_grid(
        self,
        projection: Union[str, None] = None,
        as_matrix: bool = False,
        fontsize: int = 5,
        wspace: float = 0.1,
        hspace: float = 0.3,
        alpha: float = 1,
        unmask_n: Union[int, None] = None,
    ) -> Tuple:
        """
        Create a figure holding an n_rows x n_columns grid of axes.

        Args:
            projection: Type of projection for the axes.
            as_matrix: Whether to return axes as a matrix.
            fontsize: Font size for the axes.
            wspace: Width space between subplots.
            hspace: Height space between subplots.
            alpha: Alpha value for the axes.
            unmask_n: Number of axes to unmask.

        Returns:
            Tuple ``(fig, axes)`` from ``plt_utils.get_axis_grid``; its third
            return value is discarded.
        """
        grid_options = dict(
            gridwidth=self.n_columns,
            gridheight=self.n_rows,
            figsize=self.inches_wh,
            projection=projection,
            as_matrix=as_matrix,
            fontsize=fontsize,
            wspace=wspace,
            hspace=hspace,
            alpha=alpha,
            unmask_n=unmask_n,
        )
        figure, axes, _unused = plt_utils.get_axis_grid(**grid_options)
        return figure, axes
inches_wh property
inches_wh

Convert width and height to inches.

panel_height_cm property
panel_height_cm

Calculate panel height in centimeters.

panel_width_cm property
panel_width_cm

Calculate panel width in centimeters.

axis_grid
axis_grid(
    projection=None,
    as_matrix=False,
    fontsize=5,
    wspace=0.1,
    hspace=0.3,
    alpha=1,
    unmask_n=None,
)

Create an axis grid for the figure.

Parameters:

Name Type Description Default
projection Union[str, None]

Type of projection for the axes.

None
as_matrix bool

Whether to return axes as a matrix.

False
fontsize int

Font size for the axes.

5
wspace float

Width space between subplots.

0.1
hspace float

Height space between subplots.

0.3
alpha float

Alpha value for the axes.

1
unmask_n Union[int, None]

Number of axes to unmask.

None

Returns:

Type Description
Tuple

Tuple containing the figure and axes.

Source code in flyvision/analysis/visualization/figsize_utils.py
(source lines 80–117)
def axis_grid(
    self,
    projection: Union[str, None] = None,
    as_matrix: bool = False,
    fontsize: int = 5,
    wspace: float = 0.1,
    hspace: float = 0.3,
    alpha: float = 1,
    unmask_n: Union[int, None] = None,
) -> Tuple:
    """
    Create a figure holding an n_rows x n_columns grid of axes.

    Args:
        projection: Type of projection for the axes.
        as_matrix: Whether to return axes as a matrix.
        fontsize: Font size for the axes.
        wspace: Width space between subplots.
        hspace: Height space between subplots.
        alpha: Alpha value for the axes.
        unmask_n: Number of axes to unmask.

    Returns:
        Tuple ``(fig, axes)`` from ``plt_utils.get_axis_grid``; its third
        return value is discarded.
    """
    grid_options = dict(
        gridwidth=self.n_columns,
        gridheight=self.n_rows,
        figsize=self.inches_wh,
        projection=projection,
        as_matrix=as_matrix,
        fontsize=fontsize,
        wspace=wspace,
        hspace=hspace,
        alpha=alpha,
        unmask_n=unmask_n,
    )
    figure, axes, _unused = plt_utils.get_axis_grid(**grid_options)
    return figure, axes

flyvision.analysis.visualization.network_fig

Classes

flyvision.analysis.visualization.network_fig.WholeNetworkFigure

Class for creating a whole network figure.

Attributes:

Name Type Description
nodes DataFrame

DataFrame containing node information.

edges DataFrame

DataFrame containing edge information.

layout Dict[str, str]

Dictionary mapping node types to layout positions.

cell_types List[str]

List of unique cell types.

video bool

Whether to include video node.

rendering bool

Whether to include rendering node.

motion_decoder bool

Whether to include motion decoder node.

decoded_motion bool

Whether to include decoded motion node.

pixel_accurate_motion bool

Whether to include pixel-accurate motion node.

Source code in flyvision/analysis/visualization/network_fig.py
(source lines 17–462)
class WholeNetworkFigure:
    """
    Class for creating a whole network figure.

    Attributes:
        nodes (pd.DataFrame): DataFrame containing node information.
        edges (pd.DataFrame): DataFrame containing edge information.
        layout (Dict[str, str]): Dictionary mapping node types to layout positions.
        cell_types: Unique cell type names of the connectome, converted to str.
        video (bool): Whether to include video node.
        rendering (bool): Whether to include rendering node.
        motion_decoder (bool): Whether to include motion decoder node.
        decoded_motion (bool): Whether to include decoded motion node.
        pixel_accurate_motion (bool): Whether to include pixel-accurate motion node.

    After ``init_figure``, also ``fig``, ``axes`` and ``axes_centers`` (from
    ``network_layout_axes``), ``ax_dict`` (axis label -> axis, plus the boxes
    added later) and ``edge_ax`` (the transparent axis the edges are drawn on).
    """

    def __init__(
        self,
        connectome,
        video: bool = False,
        rendering: bool = False,
        motion_decoder: bool = False,
        decoded_motion: bool = False,
        pixel_accurate_motion: bool = False,
    ):
        """
        Args:
            connectome: Connectome providing ``nodes``, ``edges``, ``layout``
                and ``unique_cell_types``.
            video: Add a "video" node (cartesian layout) before the network.
            rendering: Add a "rendering" node (hexagonal layout) before the
                network.
            motion_decoder: Add a "motion decoder" node after the network.
            decoded_motion: Add a "decoded motion" node after the network.
            pixel_accurate_motion: Add a "pixel-accurate motion" node after
                the network.
        """
        self.nodes = connectome.nodes.to_df()
        self.edges = connectome.edges.to_df()
        self.cell_types = connectome.unique_cell_types[:].astype(str)

        # Optional input nodes come first, then the connectome's own layout,
        # then the optional decoder/motion nodes (dicts keep insertion order).
        layout = {}
        if video:
            layout["video"] = "cartesian"
        if rendering:
            layout["rendering"] = "hexagonal"
        layout.update(dict(connectome.layout[:].astype(str)))
        if motion_decoder:
            layout["motion decoder"] = "decoder"
        if decoded_motion:
            layout["decoded motion"] = "motion"
        if pixel_accurate_motion:
            layout["pixel-accurate motion"] = "motion"
        self.layout = layout
        self.video = video
        self.rendering = rendering
        self.motion_decoder = motion_decoder
        self.decoded_motion = decoded_motion
        self.pixel_accurate_motion = pixel_accurate_motion

    def _axes_for(self, node_types) -> Dict:
        """Map each node type to the first layout axis labelled with it."""
        first_by_label = {}
        for ax in self.axes:
            first_by_label.setdefault(ax.get_label(), ax)
        return {node_type: first_by_label[node_type] for node_type in node_types}

    def init_figure(
        self,
        figsize: Optional[List[int]] = None,
        fontsize: int = 6,
        decoder_box: bool = True,
        cell_type_labels: bool = True,
        neuropil_labels: bool = True,
        network_layout_axes_kwargs: Optional[Dict] = None,
        add_graph_kwargs: Optional[Dict] = None,
    ) -> None:
        """
        Initialize the figure with various components.

        Args:
            figsize: Size of the figure; defaults to ``[15, 6]``.
            fontsize: Font size for labels.
            decoder_box: Whether to add a decoder box.
            cell_type_labels: Whether to add cell type labels.
            neuropil_labels: Whether to add neuropil labels.
            network_layout_axes_kwargs: Additional kwargs for network_layout_axes.
            add_graph_kwargs: Additional kwargs for add_graph.
        """
        # None defaults avoid sharing one mutable default object across calls.
        if figsize is None:
            figsize = [15, 6]
        if network_layout_axes_kwargs is None:
            network_layout_axes_kwargs = {}
        if add_graph_kwargs is None:
            add_graph_kwargs = {}

        self.fig, self.axes, self.axes_centers = network_layout_axes(
            self.layout, figsize=figsize, **network_layout_axes_kwargs
        )
        self.ax_dict = {ax.get_label(): ax for ax in self.axes}
        self.add_graph(**add_graph_kwargs)

        self.add_retina_box()

        if decoder_box:
            self.add_decoded_box()

        if cell_type_labels:
            self.add_cell_type_labels(fontsize=fontsize)

        if neuropil_labels:
            self.add_neuropil_labels(fontsize=fontsize)

        if self.motion_decoder:
            self.add_decoder_sketch()

        self.add_arrows()

    def add_graph(
        self,
        edge_color_key: Optional[str] = None,
        arrows: bool = True,
        edge_alpha: float = 1.0,
        edge_width: float = 1.0,
        constant_edge_width: Optional[float] = 0.25,
        constant_edge_color: str = "#c5c5c5",
        edge_cmap: Optional[str] = None,
        nx_kwargs: Optional[Dict] = None,
    ) -> None:
        """
        Add the graph to the figure.

        Edges are drawn between the centers of the node axes, on a transparent
        axis spanning all cell-type axes.

        Args:
            edge_color_key: Numeric edge column whose mean per
                (source_type, target_type) pair colors the edges. Only used
                when ``constant_edge_color`` is falsy.
            arrows: Whether to add arrows to edges.
            edge_alpha: Alpha value for edges.
            edge_width: Width scale for edges. Used only when
                ``constant_edge_width`` is None, as
                ``edge_width * (log(n_syn) + 1)`` per type pair.
            constant_edge_width: Constant width for all edges, or None to
                scale widths by the mean synapse count.
            constant_edge_color: Constant color for all edges.
            edge_cmap: Colormap for edges.
            nx_kwargs: Additional kwargs for networkx drawing.
        """
        if nx_kwargs is None:
            nx_kwargs = {}

        def _network_graph(nodes, edges):
            """Transform graph from df to list to create networkx.Graph object."""
            nodes = nodes.groupby(by=["type"], sort=False, as_index=False).first().type
            edges = list(
                map(
                    lambda x: x.split(","),
                    (edges.source_type + "," + edges.target_type).unique(),
                )
            )
            return nodes, edges

        cell_type_axes = self._axes_for(self.cell_types)
        (lefts, bottoms, rights, tops), _ = plt_utils.get_ax_positions(
            list(cell_type_axes.values())
        )
        edge_ax = self.fig.add_axes([
            lefts.min(),
            bottoms.min(),
            rights.max() - lefts.min(),
            tops.max() - bottoms.min(),
        ])
        edge_ax.set_zorder(0)
        edge_ax = plt_utils.rm_spines(edge_ax, rm_xticks=True, rm_yticks=True)
        edge_ax.patch.set_alpha(0.0)
        edge_ax.set_ylim(0, 1)
        edge_ax.set_xlim(0, 1)

        # Axis centers are given in figure coordinates; express them in the
        # data coordinates of the edge axis.
        fig_to_edge_ax = self.fig.transFigure + edge_ax.transData.inverted()
        positions = {
            key: fig_to_edge_ax.transform(value)
            for key, value in self.axes_centers.items()
        }

        nodes, edge_list = _network_graph(self.nodes, self.edges)

        # Mean of every numeric edge column per (source_type, target_type).
        grouped = self.edges.groupby(
            by=["source_type", "target_type"], sort=False, as_index=False
        ).mean(numeric_only=True)

        if edge_color_key is not None and not constant_edge_color:
            edge_color = {
                (row.source_type, row.target_type): row[edge_color_key]
                for _, row in grouped.iterrows()
            }
            color_values = np.array(list(edge_color.values()))
            if np.any(color_values < 0):
                # Symmetric limits around zero that cover the full range, so
                # vmin <= vmax even when all values are negative.
                limit = np.max(np.abs(color_values))
                edge_vmin, edge_vmax = -limit, limit
            else:
                edge_vmin, edge_vmax = 0, np.max(color_values)
        else:
            edge_color = {tuple(edge): constant_edge_color for edge in edge_list}
            edge_vmin = None
            edge_vmax = None

        if constant_edge_width is None:
            edge_widths = {
                (row.source_type, row.target_type): edge_width * (np.log(row.n_syn) + 1)
                for _, row in grouped.iterrows()
            }
        else:
            edge_widths = {
                (row.source_type, row.target_type): constant_edge_width
                for _, row in grouped.iterrows()
            }

        graph = nx.DiGraph()
        graph.add_nodes_from(nodes)
        graph.add_edges_from(edge_list)

        draw_networkx_edges(
            graph,
            pos=positions,
            ax=edge_ax,
            edge_color=np.array([edge_color[tuple(edge)] for edge in edge_list]),
            edge_cmap=edge_cmap,
            edge_vmin=edge_vmin,
            edge_vmax=edge_vmax,
            alpha=edge_alpha,
            arrows=arrows,
            arrowstyle=(
                "-|>, head_length=0.4, head_width=0.075, widthA=1.0, "
                "widthB=1.0, lengthA=0.2, lengthB=0.2"
            ),
            width=np.array([edge_widths[tuple(edge)] for edge in edge_list]),
            **nx_kwargs,
        )
        self.edge_ax = edge_ax

    def add_retina_box(self):
        """Add an invisible axis spanning all 'retina' node axes as 'retina box'."""
        retina_node_types = valfilter(lambda v: v == "retina", self.layout)
        axes = self._axes_for(retina_node_types)
        (lefts, bottoms, rights, tops), _ = plt_utils.get_ax_positions(
            list(axes.values())
        )
        retina_box_ax = self.fig.add_axes(
            [
                lefts.min(),
                bottoms.min(),
                rights.max() - lefts.min(),
                tops.max() - bottoms.min(),
            ],
            label="retina_box",
        )
        retina_box_ax.patch.set_alpha(0)
        plt_utils.rm_spines(retina_box_ax)
        self.ax_dict["retina box"] = retina_box_ax

    def add_decoded_box(self):
        """Add a framed axis around all 'output' node axes as 'decoded box'."""
        output_cell_types = valfilter(lambda v: v == "output", self.layout)
        axes = self._axes_for(output_cell_types)
        (lefts, bottoms, rights, tops), _ = plt_utils.get_ax_positions(
            list(axes.values())
        )
        bottom, top = plt_utils.get_lims((bottoms, tops), 0.02)
        left, right = plt_utils.get_lims((lefts, rights), 0.01)
        decoded_box_ax = self.fig.add_axes(
            [
                left,
                bottom,
                right - left,
                top - bottom,
            ],
            label="decoded_box",
        )
        decoded_box_ax.patch.set_alpha(0)
        for side in ("top", "right", "left", "bottom"):
            decoded_box_ax.spines[side].set_visible(True)
        decoded_box_ax.set_xticks([])
        decoded_box_ax.set_yticks([])
        self.ax_dict["decoded box"] = decoded_box_ax

    def add_decoder_sketch(self):
        """Draw a small fixed feed-forward network sketch on the 'motion decoder' axis."""
        ax = self.ax_dict["motion decoder"]
        nodes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
        edges = [
            (1, 11),
            (2, 10),
            (2, 12),
            (3, 10),
            (3, 11),
            (3, 12),
            (4, 11),
            (4, 12),
            (4, 14),
            (5, 10),
            (5, 12),
            (5, 13),
            (6, 11),
            (6, 13),
            (6, 14),
            (7, 12),
            (7, 14),
            (8, 13),
            (9, 14),
            (10, 15),
            (11, 16),
            (12, 15),
            (13, 15),
            (13, 16),
            (14, 16),
        ]
        graph = nx.Graph()
        graph.add_nodes_from(nodes)
        graph.add_edges_from(edges)
        # Three columns of nodes: 9 inputs, 5 hidden, 2 outputs.
        x = [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2]
        y = [9, 8, 7, 6, 5, 4, 3, 2, 1, 7.3, 6.3, 5.3, 4.3, 3.3, 5.7, 4.7]
        x, y, _, _ = plt_utils.scale(x, y)
        nx.draw_networkx(
            graph,
            pos=dict(zip(nodes, zip(x, y))),
            node_shape="H",
            node_size=50,
            node_color="#5a5b5b",
            edge_color="#C5C4C4",
            width=0.25,
            with_labels=False,
            ax=ax,
            arrows=False,
        )
        plt_utils.rm_spines(ax)

    def add_arrows(self):
        """Connect the optional input panels and output panels to the network boxes with arrows."""

        def arrow_between_axes(axA, axB):
            """Add a figure-level arrow from the right-center of axA to the left-center of axB."""
            display_to_fig = self.fig.transFigure.inverted()
            ptA = display_to_fig.transform(axA.transAxes.transform((1, 0.5)))
            ptB = display_to_fig.transform(axB.transAxes.transform((0, 0.5)))
            arrow = matplotlib.patches.FancyArrowPatch(
                ptA,
                ptB,
                transform=self.fig.transFigure,  # Place arrow in figure coord system
                arrowstyle="simple, head_width=3, head_length=6, tail_width=0.15",
                alpha=1,
                mutation_scale=1.0,
            )
            arrow.set_lw(0.25)
            # Public API for figure-level artists instead of mutating fig.patches.
            self.fig.add_artist(arrow)

        # video -> rendering -> retina box, skipping the inputs not present.
        input_chain = [
            name
            for name, present in (("video", self.video), ("rendering", self.rendering))
            if present
        ] + ["retina box"]
        # decoded box -> motion decoder -> decoded motion, skipping absent ones.
        output_chain = ["decoded box"] + [
            name
            for name, present in (
                ("motion decoder", self.motion_decoder),
                ("decoded motion", self.decoded_motion),
            )
            if present
        ]
        for chain in (input_chain, output_chain):
            for start, end in zip(chain, chain[1:]):
                arrow_between_axes(self.ax_dict[start], self.ax_dict[end])

    def add_cell_type_labels(self, fontsize=5):
        """Annotate each cell-type axis with its label at its upper left."""
        for label, ax in self.ax_dict.items():
            if label in self.cell_types:
                ax.annotate(
                    label,
                    (0, 0.9),
                    xycoords="axes fraction",
                    va="bottom",
                    ha="right",
                    fontsize=fontsize,
                )

    def _add_neuropil_label(self, layout_position, text, fontsize):
        """Write `text` at the figure bottom, centered under all axes laid out at `layout_position`."""
        node_types = valfilter(lambda v: v == layout_position, self.layout)
        axes = self._axes_for(node_types)
        (lefts, _, rights, _), _ = plt_utils.get_ax_positions(list(axes.values()))
        self.fig.text(
            lefts.min() + (rights.max() - lefts.min()) / 2,
            0,
            text,
            fontsize=fontsize,
            va="top",
            ha="center",
        )

    def add_neuropil_labels(self, fontsize=5):
        """Label the retina, intermediate and output node groups at the figure bottom."""
        self._add_neuropil_label("retina", "retina", fontsize)
        self._add_neuropil_label(
            "intermediate", "lamina, medulla intrinsic cells, CT1", fontsize
        )
        self._add_neuropil_label("output", "T-shaped, transmedullary cells", fontsize)
init_figure
init_figure(
    figsize=[15, 6],
    fontsize=6,
    decoder_box=True,
    cell_type_labels=True,
    neuropil_labels=True,
    network_layout_axes_kwargs={},
    add_graph_kwargs={},
)

Initialize the figure with various components.

Parameters:

Name Type Description Default
figsize List[int]

Size of the figure.

[15, 6]
fontsize int

Font size for labels.

6
decoder_box bool

Whether to add a decoder box.

True
cell_type_labels bool

Whether to add cell type labels.

True
neuropil_labels bool

Whether to add neuropil labels.

True
network_layout_axes_kwargs Dict

Additional kwargs for network_layout_axes.

{}
add_graph_kwargs Dict

Additional kwargs for add_graph.

{}
Source code in flyvision/analysis/visualization/network_fig.py
(source lines 66–108)
def init_figure(
    self,
    figsize: List[int] = [15, 6],
    fontsize: int = 6,
    decoder_box: bool = True,
    cell_type_labels: bool = True,
    neuropil_labels: bool = True,
    network_layout_axes_kwargs: Dict = {},
    add_graph_kwargs: Dict = {},
) -> None:
    """
    Build the network figure: layout axes, graph edges, boxes, labels, arrows.

    The layout axes are created first and stored on ``self.fig``, ``self.axes``,
    ``self.axes_centers`` and the label-indexed ``self.ax_dict``. The graph and
    the retina box are always drawn. The decoded box and the labels are toggled
    by the flags below. The decoder sketch is drawn only when
    ``self.motion_decoder`` is truthy. Arrows are always added last.

    Args:
        figsize: Size of the figure, forwarded to ``network_layout_axes``.
        fontsize: Font size for the cell-type and neuropil labels.
        decoder_box: Whether to add the decoded box.
        cell_type_labels: Whether to add cell type labels.
        neuropil_labels: Whether to add neuropil labels.
        network_layout_axes_kwargs: Extra keyword arguments for
            ``network_layout_axes``; only unpacked, never modified.
        add_graph_kwargs: Extra keyword arguments for ``add_graph``; only
            unpacked, never modified.
    """
    layout_result = network_layout_axes(
        self.layout, figsize=figsize, **network_layout_axes_kwargs
    )
    self.fig, self.axes, self.axes_centers = layout_result

    # Index the freshly created axes by their label (the cell type name).
    axes_by_label = {}
    for axis in self.axes:
        axes_by_label[axis.get_label()] = axis
    self.ax_dict = axes_by_label

    self.add_graph(**add_graph_kwargs)
    self.add_retina_box()

    if decoder_box:
        self.add_decoded_box()
    if cell_type_labels:
        self.add_cell_type_labels(fontsize=fontsize)
    if neuropil_labels:
        self.add_neuropil_labels(fontsize=fontsize)
    # The decoder sketch only applies to figures configured with a motion decoder.
    if self.motion_decoder:
        self.add_decoder_sketch()

    # Arrows go on last, after every other component has been placed.
    self.add_arrows()
add_graph
add_graph(
    edge_color_key=None,
    arrows=True,
    edge_alpha=1.0,
    edge_width=1.0,
    constant_edge_width=0.25,
    constant_edge_color="#c5c5c5",
    edge_cmap=None,
    nx_kwargs={},
)

Add the graph to the figure.

Parameters:

Name Type Description Default
edge_color_key Optional[str]

Key for edge color.

None
arrows bool

Whether to add arrows to edges.

True
edge_alpha float

Alpha value for edges.

1.0
edge_width float

Width of edges.

1.0
constant_edge_width Optional[float]

Constant width for all edges. If None, edge widths scale with the logarithm of the synapse count.

0.25
constant_edge_color str

Constant color for all edges.

'#c5c5c5'
edge_cmap Optional[str]

Colormap for edges.

None
nx_kwargs Dict

Additional kwargs for networkx drawing.

{}
Source code in flyvision/analysis/visualization/network_fig.py
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
def add_graph(
    self,
    edge_color_key: Optional[str] = None,
    arrows: bool = True,
    edge_alpha: float = 1.0,
    edge_width: float = 1.0,
    constant_edge_width: Optional[float] = 0.25,
    constant_edge_color: str = "#c5c5c5",
    edge_cmap: Optional[str] = None,
    nx_kwargs: Dict = {},
) -> None:
    """
    Add the graph to the figure.

    Draws one directed edge per unique (source_type, target_type) pair of
    ``self.edges``. The edges go onto a new transparent axes (stored as
    ``self.edge_ax``) spanning the bounding box of all cell-type axes. Node
    positions come from ``self.axes_centers`` (figure coordinates) and are
    mapped into that axes' data coordinates.

    Args:
        edge_color_key: Column of ``self.edges`` whose per-pair mean colors the
            edges. It is only used when ``constant_edge_color`` is falsy (e.g.
            None or ""); otherwise every edge gets ``constant_edge_color``.
        arrows: Whether to add arrows to edges.
        edge_alpha: Alpha value for edges.
        edge_width: Scale factor used when ``constant_edge_width`` is None. The
            width is then ``edge_width * (log(n_syn) + 1)``, where ``n_syn`` is
            the per-pair mean of the ``n_syn`` column.
        constant_edge_width: Width applied to all edges. If None, widths scale
            with the synapse count as described above.
        constant_edge_color: Color applied to all edges.
        edge_cmap: Colormap for edges, relevant when coloring by
            ``edge_color_key``.
        nx_kwargs: Additional kwargs forwarded to ``draw_networkx_edges``; only
            unpacked, never modified.
    """

    def _network_graph(nodes, edges):
        """Return the unique node types and unique [source, target] type pairs."""
        nodes = nodes.groupby(by=["type"], sort=False, as_index=False).first().type
        edges = list(
            map(
                lambda x: x.split(","),
                (edges.source_type + "," + edges.target_type).unique(),
            )
        )
        return nodes, edges

    axes = {
        cell_type: [ax for ax in self.axes if ax.get_label() == cell_type][0]
        for cell_type in self.cell_types
    }

    (
        (lefts, bottoms, rights, tops),
        (
            centers,
            widths,
            height,
        ),
    ) = plt_utils.get_ax_positions(list(axes.values()))
    # A transparent axes covering all cell-type axes, with unit data limits,
    # on which the edges are drawn underneath the cell-type axes.
    edge_ax = self.fig.add_axes([
        lefts.min(),
        bottoms.min(),
        rights.max() - lefts.min(),
        tops.max() - bottoms.min(),
    ])
    edge_ax.set_zorder(0)
    edge_ax = plt_utils.rm_spines(edge_ax, rm_xticks=True, rm_yticks=True)
    edge_ax.patch.set_alpha(0.0)
    edge_ax.set_ylim(0, 1)
    edge_ax.set_xlim(0, 1)

    # Map node centers from figure coordinates into edge_ax data coordinates.
    fig_to_edge_ax = self.fig.transFigure + edge_ax.transData.inverted()
    positions = {
        key: fig_to_edge_ax.transform(value)
        for key, value in self.axes_centers.items()
    }

    nodes, edge_list = _network_graph(self.nodes, self.edges)

    # Per-(source, target) mean of the numeric edge attributes; shared by the
    # color and width computations below.
    grouped = self.edges.groupby(
        by=["source_type", "target_type"], sort=False, as_index=False
    ).mean(numeric_only=True)

    if edge_color_key is not None and not constant_edge_color:
        edge_color = {
            (row.source_type, row.target_type): row[edge_color_key]
            for _, row in grouped.iterrows()
        }
        _edge_color = np.array(list(edge_color.values()))
        if np.any(_edge_color < 0):
            # Symmetric range around zero for diverging colormaps. Using the
            # largest magnitude keeps vmin <= vmax even if all values are < 0.
            edge_vmax = np.max(np.abs(_edge_color))
            edge_vmin = -edge_vmax
        else:
            edge_vmin = 0
            edge_vmax = np.max(_edge_color)
    else:
        edge_color = {tuple(edge): constant_edge_color for edge in edge_list}
        edge_vmin = None
        edge_vmax = None

    if constant_edge_width is None:
        # Log-scaling keeps very strong connections from dominating the plot.
        width_by_edge = {
            (row.source_type, row.target_type): edge_width * (np.log(row.n_syn) + 1)
            for _, row in grouped.iterrows()
        }
    else:
        width_by_edge = {
            (row.source_type, row.target_type): constant_edge_width
            for _, row in grouped.iterrows()
        }

    graph = nx.DiGraph()
    graph.add_nodes_from(nodes)
    graph.add_edges_from(edge_list)

    draw_networkx_edges(
        graph,
        pos=positions,
        ax=edge_ax,
        edge_color=np.array([edge_color[tuple(edge)] for edge in edge_list]),
        edge_cmap=edge_cmap,
        edge_vmin=edge_vmin,
        edge_vmax=edge_vmax,
        alpha=edge_alpha,
        arrows=arrows,
        arrowstyle=(
            "-|>, head_length=0.4, head_width=0.075, widthA=1.0, "
            "widthB=1.0, lengthA=0.2, lengthB=0.2"
        ),
        width=np.array([width_by_edge[tuple(edge)] for edge in edge_list]),
        **nx_kwargs,
    )
    self.edge_ax = edge_ax

Functions

flyvision.analysis.visualization.network_fig.network_layout_axes

network_layout_axes(
    layout,
    cell_types=None,
    fig=None,
    figsize=[16, 10],
    types_per_column=8,
    region_spacing=2,
    wspace=0,
    hspace=0,
    as_dict=False,
    pos=None,
)

Create axes for network layout.

Parameters:

Name Type Description Default
layout Dict[str, str]

Dictionary mapping node types to layout positions.

required
cell_types Optional[List[str]]

List of cell types to include.

None
fig Optional[Figure]

Existing figure to use.

None
figsize List[int]

Size of the figure.

[16, 10]
types_per_column int

Number of types per column.

8
region_spacing int

Spacing between regions.

2
wspace float

Width space between subplots.

0
hspace float

Height space between subplots.

0
as_dict bool

Whether to return axes as a dictionary.

False
pos Optional[Dict[str, List[float]]]

Pre-computed positions for nodes.

None

Returns:

Type Description
Tuple[Figure, Union[List[Axes], Dict[str, Axes]], Dict[str, List[float]]]

Tuple containing the figure, axes, and node positions.

Source code in flyvision/analysis/visualization/network_fig.py
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
def network_layout_axes(
    layout: Dict[str, str],
    cell_types: Optional[List[str]] = None,
    fig: Optional[plt.Figure] = None,
    figsize: List[int] = [16, 10],
    types_per_column: int = 8,
    region_spacing: int = 2,
    wspace: float = 0,
    hspace: float = 0,
    as_dict: bool = False,
    pos: Optional[Dict[str, List[float]]] = None,
) -> Tuple[
    plt.Figure, Union[List[plt.Axes], Dict[str, plt.Axes]], Dict[str, List[float]]
]:
    """
    Create one (invisible-marker) axes per node of a network layout.

    Node positions are taken from ``pos`` when given (and non-empty); otherwise
    they are computed with ``_network_graph_node_pos``. They are optionally
    restricted to ``cell_types`` and then handed to ``plt_utils.ax_scatter``,
    which places one axes per node.

    Args:
        layout: Dictionary mapping node types to layout positions.
        cell_types: List of cell types to include; None keeps all nodes.
        fig: Existing figure to use; a new one of size ``figsize`` otherwise.
        figsize: Size of the figure.
        types_per_column: Number of types per column.
        region_spacing: Spacing between regions.
        wspace: Width space between subplots.
        hspace: Height space between subplots.
        as_dict: Whether to return axes as a dictionary keyed by node type.
        pos: Pre-computed positions for nodes.

    Returns:
        Tuple of the figure, the axes (list or dict), and the node positions
        as returned (scaled) by ``plt_utils.ax_scatter``.
    """
    if not fig:
        fig = plt.figure(figsize=figsize)

    if not pos:
        pos = _network_graph_node_pos(
            layout, region_spacing=region_spacing, types_per_column=types_per_column
        )

    selected = {
        name: coords
        for name, coords in pos.items()
        if cell_types is None or name in cell_types
    }
    points = np.array(list(selected.values()))
    names = list(selected)

    # No padding around the scatter grid (hpad = wpad = 0).
    fig, axes, xy_scaled = plt_utils.ax_scatter(
        points[:, 0],
        points[:, 1],
        fig=fig,
        wspace=wspace,
        hspace=hspace,
        hpad=0.0,
        wpad=0.0,
        alpha=0,
        labels=names,
    )

    scaled_pos = {}
    for index, name in enumerate(names):
        scaled_pos[name] = xy_scaled[index]

    if not as_dict:
        return fig, axes, scaled_pos
    axes_by_type = {name: axes[index] for index, name in enumerate(scaled_pos)}
    return fig, axes_by_type, scaled_pos

flyvision.analysis.visualization.network_fig._network_graph_node_pos

_network_graph_node_pos(
    layout, region_spacing=2, types_per_column=8
)

Compute (x, y) coordinates for nodes in a network graph.

Parameters:

Name Type Description Default
layout Dict[str, str]

Dictionary mapping node types to layout positions.

required
region_spacing float

Spacing between regions.

2
types_per_column int

Number of types per column.

8

Returns:

Type Description
Dict[str, List[float]]

Dictionary mapping node types to their (x, y) coordinates.

Note

Special nodes like ‘video’, ‘rendering’, etc. are placed at the vertical middle of the column grid, (types_per_column - 1) / 2; ‘pixel-accurate motion’ is placed 1.5 units below that.

Source code in flyvision/analysis/visualization/network_fig.py
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
def _network_graph_node_pos(
    layout: Dict[str, str], region_spacing: float = 2, types_per_column: int = 8
) -> Dict[str, List[float]]:
    """
    Compute (x, y) coordinates for nodes in a network graph.

    Args:
        layout: Dictionary mapping node types to layout positions.
        region_spacing: Spacing between regions.
        types_per_column: Number of types per column.

    Returns:
        Dictionary mapping node types to their (x, y) coordinates.

    Note:
        Special nodes like 'video', 'rendering', etc. are positioned at the middle
        y-coordinate of their respective columns.
    """
    x_coordinate = 0
    region_0 = "retina"
    pos = {}
    j = 0
    special_nodes = [
        "video",
        "rendering",
        "motion decoder",
        "decoded motion",
        "pixel-accurate motion",
    ]

    for typ in layout:
        if typ in special_nodes:
            region_spacing = 1.25
        if layout[typ] != region_0:
            x_coordinate += region_spacing
            j = 0
        elif (j % types_per_column) == 0 and j != 0:
            x_coordinate += 1
        y_coordinate = types_per_column - 1 - j % types_per_column
        pos[typ] = [x_coordinate, y_coordinate]
        region_0 = layout[typ]
        j += 1

    y_mid = (types_per_column - 1) / 2

    for node in special_nodes:
        if node in layout:
            pos[node][1] = y_mid

    if "pixel-accurate motion" in layout:
        pos["pixel-accurate motion"][1] = y_mid - 1.5

    return pos

flyvision.analysis.visualization.plots

Functions

flyvision.analysis.visualization.plots.heatmap

heatmap(
    matrix,
    xlabels,
    ylabels=None,
    size_scale="auto",
    cmap=cm.get_cmap("seismic"),
    origin="upper",
    ax=None,
    fig=None,
    vmin=None,
    vmax=None,
    symlog=None,
    cbar_label="",
    log=None,
    cbar_height=0.5,
    cbar_width=0.01,
    title="",
    figsize=[5, 4],
    fontsize=4,
    midpoint=None,
    cbar=True,
    grid_linewidth=0.5,
    **kwargs
)

Create a heatmap scatter plot of the matrix.

Parameters:

Name Type Description Default
matrix ndarray

2D matrix to be plotted.

required
xlabels List[str]

List of x-axis labels.

required
ylabels Optional[List[str]]

List of y-axis labels. If not provided, xlabels will be used.

None
size_scale Union[str, float]

Size scale of the scatter points. If “auto”, uses 0.005 * prod(figsize).

'auto'
cmap Colormap

Colormap for the heatmap.

get_cmap('seismic')
origin Literal['upper', 'lower']

Origin of the matrix. Either “upper” or “lower”.

'upper'
ax Optional[Axes]

Existing Matplotlib Axes object to plot on.

None
fig Optional[Figure]

Existing Matplotlib Figure object to use.

None
vmin Optional[float]

Minimum value for color scaling.

None
vmax Optional[float]

Maximum value for color scaling.

None
symlog Optional[bool]

Whether to use symmetric log normalization.

None
cbar_label str

Label for the colorbar.

''
log Optional[bool]

Whether to use logarithmic color scaling.

None
cbar_height float

Height of the colorbar.

0.5
cbar_width float

Width of the colorbar.

0.01
title str

Title of the plot.

''
figsize Tuple[float, float]

Size of the figure.

[5, 4]
fontsize int

Font size for labels and ticks.

4
midpoint Optional[float]

Midpoint for diverging colormaps.

None
cbar bool

Whether to show the colorbar.

True
grid_linewidth float

Width of the grid lines.

0.5
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description
Tuple[Figure, Axes, Optional[Colorbar], ndarray]

A tuple containing the Figure, Axes, Colorbar (if shown), and the input matrix.

Note

This function creates a heatmap scatter plot with various customization options. The size of scatter points can be scaled based on the absolute value of the matrix elements.

Source code in flyvision/analysis/visualization/plots.py
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
def heatmap(
    matrix: np.ndarray,
    xlabels: List[str],
    ylabels: Optional[List[str]] = None,
    size_scale: Union[str, float] = "auto",
    cmap: mpl.colors.Colormap = cm.get_cmap("seismic"),
    origin: Literal["upper", "lower"] = "upper",
    ax: Optional[Axes] = None,
    fig: Optional[Figure] = None,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    symlog: Optional[bool] = None,
    cbar_label: str = "",
    log: Optional[bool] = None,
    cbar_height: float = 0.5,
    cbar_width: float = 0.01,
    title: str = "",
    figsize: Tuple[float, float] = [5, 4],
    fontsize: int = 4,
    midpoint: Optional[float] = None,
    cbar: bool = True,
    grid_linewidth: float = 0.5,
    **kwargs,
) -> Tuple[Figure, Axes, Optional[Colorbar], np.ndarray]:
    """
    Create a heatmap scatter plot of the matrix.

    Each nonzero entry ``matrix[i, j]`` is drawn as a square marker in column
    ``j``. Its color encodes the value and its size scales with the absolute
    value. Zero entries are not drawn. Cells are centered on integer
    coordinates, and the axis limits enclose exactly the matrix cells.

    Args:
        matrix: 2D matrix to be plotted.
        xlabels: List of x-axis labels.
        ylabels: List of y-axis labels. If not provided, xlabels will be used.
        size_scale: Size scale of the scatter points. If "auto",
            uses 0.005 * prod(figsize).
        cmap: Colormap for the heatmap.
        origin: Origin of the matrix. Either "upper" (row 0 at the top) or
            "lower" (row 0 at the bottom).
        ax: Existing Matplotlib Axes object to plot on.
        fig: Existing Matplotlib Figure object to use.
        vmin: Minimum value for color scaling; defaults to ``nanmin(matrix)``.
        vmax: Maximum value for color scaling; defaults to ``nanmax(matrix)``.
        symlog: Whether to use symmetric log normalization.
        cbar_label: Label for the colorbar.
        log: Whether to use logarithmic color scaling.
        cbar_height: Height of the colorbar.
        cbar_width: Width of the colorbar.
        title: Title of the plot.
        figsize: Size of the figure.
        fontsize: Font size for labels and ticks.
        midpoint: Midpoint for diverging colormaps.
        cbar: Whether to show the colorbar.
        grid_linewidth: Width of the grid lines.
        **kwargs: Accepted for call-site compatibility; currently ignored.

    Returns:
        A tuple containing the Figure, Axes, Colorbar (None if ``cbar`` is
        False), and the input matrix.
    """
    rows, cols = np.nonzero(matrix)
    value = matrix[rows, cols]
    n_rows, n_cols = matrix.shape

    fig, ax = plt_utils.init_plot(figsize, title, fontsize, ax=ax, fig=fig, offset=0)

    norm = plt_utils.get_norm(
        symlog=symlog,
        vmin=vmin if vmin is not None else np.nanmin(matrix),
        vmax=vmax if vmax is not None else np.nanmax(matrix),
        log=log,
        midpoint=midpoint,
    )

    size = np.abs(value) * (
        size_scale if size_scale != "auto" else 0.005 * np.prod(figsize)
    )

    ax.scatter(
        x=cols,
        # Flip rows for origin="upper" so that row 0 appears at the top.
        y=n_rows - rows - 1 if origin == "upper" else rows,
        s=size,
        c=value,
        cmap=cmap,
        norm=norm,
        marker="s",
        edgecolors="none",
    )

    ax.set_xticks(np.arange(n_cols))
    ax.set_xticklabels(xlabels, rotation=90, fontsize=fontsize)
    ax.set_yticks(np.arange(n_rows))
    ylabels = ylabels if ylabels is not None else xlabels
    ax.set_yticklabels(ylabels[::-1] if origin == "upper" else ylabels, fontsize=fontsize)

    # Grid lines on minor ticks halfway between cells separate the cells.
    ax.grid(False, "major")
    ax.grid(True, "minor", linewidth=grid_linewidth)
    ax.set_xticks([t + 0.5 for t in ax.get_xticks()[:-1]], minor=True)
    ax.set_yticks([t + 0.5 for t in ax.get_yticks()[:-1]], minor=True)

    # Cells are centered on 0..n-1, so their outer edges are at -0.5 and n-0.5.
    # The limits are symmetric to avoid an extra half-cell band on the far side.
    ax.set_xlim([-0.5, n_cols - 0.5])
    ax.set_ylim([-0.5, n_rows - 0.5])
    ax.tick_params(axis="x", which="minor", bottom=False)
    ax.tick_params(axis="y", which="minor", left=False)

    cbar_obj = None
    if cbar:
        cbar_obj = plt_utils.add_colorbar_to_fig(
            fig,
            height=cbar_height,
            width=cbar_width,
            cmap=cmap,
            norm=norm,
            fontsize=fontsize,
            label=cbar_label,
            x_offset=15,
        )

    return fig, ax, cbar_obj, matrix

flyvision.analysis.visualization.plots.hex_scatter

hex_scatter(
    u,
    v,
    values,
    max_extent=None,
    fig=None,
    ax=None,
    figsize=(1, 1),
    title="",
    title_y=None,
    fontsize=5,
    label="",
    labelxy="auto",
    label_color="black",
    edgecolor=None,
    edgewidth=0.5,
    alpha=1,
    fill=False,
    scalarmapper=None,
    norm=None,
    radius=1,
    origin="lower",
    vmin=None,
    vmax=None,
    midpoint=None,
    mode="default",
    orientation=np.radians(30),
    cmap=cm.get_cmap("seismic"),
    cbar=True,
    cbar_label="",
    cbar_height=None,
    cbar_width=None,
    cbar_x_offset=0.05,
    annotate=False,
    annotate_coords=False,
    annotate_indices=False,
    frame=False,
    frame_hex_width=1,
    frame_color=None,
    nan_linestyle="-",
    text_color_hsv_threshold=0.8,
    **kwargs
)

Plot hexagonally arranged data points with coordinates u, v, colored by values.

Parameters:

Name Type Description Default
u NDArray

Array of hex coordinates in u direction.

required
v NDArray

Array of hex coordinates in v direction.

required
values NDArray

Array of pixel values per point (u_i, v_i).

required
fill Union[bool, int]

Whether to fill the hex grid around u, v, values.

False
max_extent Optional[int]

Maximum extent of the hex lattice shown. When fill=True, the hex grid is padded to the maximum extent when above the extent of u, v.

None
fig Optional[Figure]

Matplotlib Figure object.

None
ax Optional[Axes]

Matplotlib Axes object.

None
figsize Tuple[float, float]

Size of the figure.

(1, 1)
title str

Title of the plot.

''
title_y Optional[float]

Y-position of the title.

None
fontsize int

Font size for text elements.

5
label str

Label for the plot.

''
labelxy Union[str, Tuple[float, float]]

Position of the label. Either “auto” or a tuple of (x, y) coordinates.

'auto'
label_color str

Color of the label.

'black'
edgecolor Optional[str]

Color of the hexagon edges.

None
edgewidth float

Width of the hexagon edges.

0.5
alpha float

Alpha value for transparency.

1
scalarmapper Optional[ScalarMappable]

ScalarMappable object for color mapping.

None
norm Optional[Normalize]

Normalization for color mapping.

None
radius float

Radius of the hexagons.

1
origin Literal['lower', 'upper']

Origin of the plot. Either “lower” or “upper”.

'lower'
vmin Optional[float]

Minimum value for color mapping.

None
vmax Optional[float]

Maximum value for color mapping.

None
midpoint Optional[float]

Midpoint for color mapping.

None
mode str

Hex coordinate system mode.

'default'
orientation float

Orientation of the hexagons in radians.

radians(30)
cmap Colormap

Colormap for the plot.

get_cmap('seismic')
cbar bool

Whether to show a colorbar.

True
cbar_label str

Label for the colorbar.

''
cbar_height Optional[float]

Height of the colorbar.

None
cbar_width Optional[float]

Width of the colorbar.

None
cbar_x_offset float

X-offset of the colorbar.

0.05
annotate bool

Whether to annotate hexagons with values.

False
annotate_coords bool

Whether to annotate hexagons with coordinates.

False
annotate_indices bool

Whether to annotate hexagons with indices.

False
frame bool

Whether to add a frame around the plot.

False
frame_hex_width int

Width of the frame in hexagon units.

1
frame_color Optional[Union[str, Tuple[float, float, float, float]]]

Color of the frame.

None
nan_linestyle str

Line style for NaN values.

'-'
text_color_hsv_threshold float

Threshold for text color in HSV space.

0.8
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description
Tuple[Figure, Axes, Tuple[Optional[Line2D], ScalarMappable]]

A tuple containing the Figure, Axes, and a tuple of (label_text, scalarmapper).

Source code in flyvision/analysis/visualization/plots.py
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
def hex_scatter(
    u: NDArray,
    v: NDArray,
    values: NDArray,
    max_extent: Optional[int] = None,
    fig: Optional[Figure] = None,
    ax: Optional[Axes] = None,
    figsize: Tuple[float, float] = (1, 1),
    title: str = "",
    title_y: Optional[float] = None,
    fontsize: int = 5,
    label: str = "",
    labelxy: Union[str, Tuple[float, float]] = "auto",
    label_color: str = "black",
    edgecolor: Optional[str] = None,
    edgewidth: float = 0.5,
    alpha: float = 1,
    fill: Union[bool, int] = False,
    scalarmapper: Optional[mpl.cm.ScalarMappable] = None,
    norm: Optional[mpl.colors.Normalize] = None,
    radius: float = 1,
    origin: Literal["lower", "upper"] = "lower",
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    midpoint: Optional[float] = None,
    mode: str = "default",
    orientation: float = np.radians(30),
    cmap: mpl.colors.Colormap = cm.get_cmap("seismic"),
    cbar: bool = True,
    cbar_label: str = "",
    cbar_height: Optional[float] = None,
    cbar_width: Optional[float] = None,
    cbar_x_offset: float = 0.05,
    annotate: bool = False,
    annotate_coords: bool = False,
    annotate_indices: bool = False,
    frame: bool = False,
    frame_hex_width: int = 1,
    frame_color: Optional[Union[str, Tuple[float, float, float, float]]] = None,
    nan_linestyle: str = "-",
    text_color_hsv_threshold: float = 0.8,
    **kwargs,
) -> Tuple[Figure, Axes, Tuple[Optional[Line2D], mpl.cm.ScalarMappable]]:
    """
    Plot hexagonally arranged data points at coordinates u, v, colored by values.

    Each point is drawn as a RegularPolygon with six vertices. NaN/inf values are
    drawn with a white face and `nan_linestyle` edges.

    Args:
        u: Array of hex coordinates in u direction.
        v: Array of hex coordinates in v direction.
        values: Array of pixel values per point (u_i, v_i). A non-iterable scalar
            is broadcast to the shape of u.
        fill: Whether to fill the hex grid around u, v, values.
        max_extent: Maximum extent of the hex lattice shown. When fill=True, the hex
            grid is padded to the maximum extent when above the extent of u, v.
        fig: Matplotlib Figure object.
        ax: Matplotlib Axes object.
        figsize: Size of the figure.
        title: Title of the plot.
        title_y: Y-position of the title.
        fontsize: Font size for text elements (also used for coordinate and index
            annotations).
        label: Label for the plot.
        labelxy: Position of the label. Either "auto" (placed in data coordinates
            next to the lattice) or a tuple of (x, y) in axes coordinates.
        label_color: Color of the label.
        edgecolor: Color of the hexagon edges. When None, valid hexagons use their
            face color as edge color.
        edgewidth: Width of the hexagon edges.
        alpha: Alpha value for transparency.
        scalarmapper: ScalarMappable object for color mapping.
        norm: Normalization for color mapping.
        radius: Radius of the hexagons.
        origin: Origin of the plot. Either "lower" or "upper".
        vmin: Minimum value for color mapping (defaults to the NaN-ignoring data
            minimum).
        vmax: Maximum value for color mapping (defaults to the NaN-ignoring data
            maximum).
        midpoint: Midpoint for color mapping.
        mode: Hex coordinate system mode.
        orientation: Orientation of the hexagons in radians.
        cmap: Colormap for the plot.
        cbar: Whether to show a colorbar.
        cbar_label: Label for the colorbar.
        cbar_height: Height of the colorbar (0.5 when falsy).
        cbar_width: Width of the colorbar (0.03 when falsy).
        cbar_x_offset: X-offset of the colorbar. A falsy value such as 0 falls
            back to -2.
        annotate: Whether to annotate hexagons with their values.
        annotate_coords: Whether to annotate hexagons with their u, v coordinates.
        annotate_indices: Whether to annotate hexagons with their (sorted) indices.
        frame: Whether to add a frame of hexagons around the plot.
        frame_hex_width: Width of the frame in hexagon units.
        frame_color: Color of the frame; opaque black when not given.
        nan_linestyle: Edge line style for hexagons with NaN/inf values.
        text_color_hsv_threshold: Threshold on the HSV value channel of a
            hexagon's color. Above it, value annotations are drawn in black,
            otherwise in white.
        **kwargs: Additional keyword arguments passed to plt_utils.init_plot.

    Returns:
        A tuple containing the Figure, Axes, and a tuple of (label_text, scalarmapper).

    Raises:
        ValueError: If u, v and values do not have the same shape.
    """

    def init_plot_and_validate_input(fig, ax):
        # Create or reuse fig/ax, broadcast a scalar `values`, validate shapes,
        # and sort u, v, values by u then v (rebinding the outer variables).
        nonlocal values, u, v
        fig, ax = plt_utils.init_plot(
            figsize, title, fontsize, ax=ax, fig=fig, title_y=title_y, **kwargs
        )
        ax.set_aspect("equal")
        values = values * np.ones_like(u) if not isinstance(values, Iterable) else values
        if u.shape != v.shape or u.shape != values.shape:
            raise ValueError("shape mismatch of hexal values and coordinates")
        u, v, values = hex_utils.sort_u_then_v(u, v, values)
        return fig, ax

    def apply_max_extent():
        # Optionally pad to a regular hexagon (fill), then crop to max_extent, or
        # pad up to max_extent when fill is set.
        nonlocal u, v, values
        # `or 1` replaces a falsy (e.g. zero) extent.
        extent = hex_utils.get_extent(u, v) or 1
        if fill:
            u, v, values = hex_utils.pad_to_regular_hex(u, v, values, extent=extent)
        if max_extent is not None and extent > max_extent:
            u, v, values = hex_utils.crop_to_extent(u, v, values, max_extent)
        elif max_extent is not None and extent < max_extent and fill:
            u, v, values = hex_utils.pad_to_regular_hex(u, v, values, extent=max_extent)

    def setup_color_mapping(scalarmapper, norm):
        # Resolve vmin/vmax and return (rgba per value, scalarmapper, norm).
        nonlocal vmin, vmax
        # NOTE(review): np.any(values) is False when all values are zero (or
        # values is empty). In that case the else-branch forces [0, 1] and
        # ignores user-given vmin/vmax. NaN counts as truthy here.
        if np.any(values):
            # Widen the limits by a tiny epsilon (presumably to avoid a degenerate
            # vmin == vmax range for constant input).
            vmin = vmin - 1e-10 if vmin is not None else np.nanmin(values) - 1e-10
            vmax = vmax + 1e-10 if vmax is not None else np.nanmax(values) + 1e-10
        else:
            vmin = 0
            vmax = 1

        # With midpoint 0 and (near-)equal, same-signed limits, mirror the limit
        # around 0 so that the range is symmetric.
        if (
            midpoint == 0
            and np.isclose(vmin, vmax, atol=1e-10)
            and np.sign(vmin) == np.sign(vmax)
        ):
            sign = np.sign(vmax)
            if sign > 0:
                vmin = -vmax
            elif sign < 0:
                vmax = -vmin
            else:
                # Both limits are exactly zero.
                raise ValueError

        # All-NaN input with midpoint 0: use a degenerate [0, 0] range.
        if midpoint == 0 and np.isnan(values).all():
            vmin = 0
            vmax = 0

        scalarmapper, norm = plt_utils.get_scalarmapper(
            scalarmapper=scalarmapper,
            cmap=cmap,
            norm=norm,
            vmin=vmin,
            vmax=vmax,
            midpoint=midpoint,
        )
        return scalarmapper.to_rgba(values), scalarmapper, norm

    def apply_frame():
        # Surround the lattice with `frame_hex_width` rings of frame-colored
        # hexagons (value 0.0), replacing u, v, values and color_rgba.
        nonlocal u, v, values, color_rgba
        if frame:
            extent = hex_utils.get_extent(u, v) or 1
            _u, _v = hex_utils.get_hex_coords(extent + frame_hex_width)
            framed_color = np.zeros([len(_u)])
            framed_color_rgba = np.zeros([len(_u), 4])
            uv = np.stack((_u, _v), 1)
            # Hexagonal distance of each coordinate from the origin:
            # (|u| + |u + v| + |v|) / 2.
            _rings = (
                abs(0 - uv[:, 0]) + abs(0 + 0 - uv[:, 0] - uv[:, 1]) + abs(0 - uv[:, 1])
            ) / 2
            mask = np.where(_rings <= extent, True, False)
            # NOTE(review): assumes get_hex_coords lists the inner rings in the
            # same order as the sorted (u, v); otherwise values are misplaced.
            framed_color[mask] = values
            framed_color[~mask] = 0.0
            framed_color_rgba[mask] = color_rgba
            framed_color_rgba[~mask] = (
                frame_color if frame_color else np.array([0, 0, 0, 1])
            )
            u, v, color_rgba, values = _u, _v, framed_color_rgba, framed_color

    def draw_hexagons():
        # Add one RegularPolygon patch per point; return pixel positions and the
        # NaN/inf mask of values.
        x, y = hex_utils.hex_to_pixel(u, v, mode=mode)
        if origin == "upper":
            # Reverses the order of the y positions (not their sign); presumably
            # this mirrors sorted, regular hex grids vertically — confirm.
            y = y[::-1]
        c_mask = np.ma.masked_invalid(values)
        for i, (_x, _y, fc) in enumerate(zip(x, y, color_rgba)):
            if c_mask.mask[i]:
                # Invalid value: white face and nan_linestyle edges.
                _hex = RegularPolygon(
                    (_x, _y),
                    numVertices=6,
                    radius=radius,
                    linewidth=edgewidth,
                    orientation=orientation,
                    edgecolor=edgecolor,
                    facecolor="white",
                    alpha=alpha,
                    ls=nan_linestyle,
                )
            else:
                # Valid value: edges default to the face color.
                _hex = RegularPolygon(
                    (_x, _y),
                    numVertices=6,
                    radius=radius,
                    linewidth=edgewidth,
                    orientation=orientation,
                    edgecolor=edgecolor or fc,
                    facecolor=fc,
                    alpha=alpha,
                )
            ax.add_patch(_hex)
        return x, y, c_mask

    def add_colorbar():
        # NOTE(review): `or` fallbacks treat 0 as "unset"; in particular
        # cbar_x_offset=0 becomes -2.
        if cbar:
            plt_utils.add_colorbar_to_fig(
                fig,
                label=cbar_label,
                width=cbar_width or 0.03,
                height=cbar_height or 0.5,
                x_offset=cbar_x_offset or -2,
                cmap=cmap,
                norm=norm,
                fontsize=fontsize,
                tick_length=1,
                tick_width=0.25,
                rm_outline=True,
            )

    def set_plot_limits(x, y):
        # NOTE: these limits are overwritten by the dataLim-based limits set at
        # the end of hex_scatter.
        extent = hex_utils.get_extent(u, v) or 1
        if fill:
            u_cs, v_cs = hex_utils.get_hex_coords(extent)
            x_cs, y_cs = hex_utils.hex_to_pixel(u_cs, v_cs, mode=mode)
            if origin == "upper":
                y_cs = y_cs[::-1]
            xmin, xmax = plt_utils.get_lims(x_cs, 1 / extent)
            ymin, ymax = plt_utils.get_lims(y_cs, 1 / extent)
        else:
            xmin, xmax = plt_utils.get_lims(x, 1 / extent)
            ymin, ymax = plt_utils.get_lims(y, 1 / extent)
        if xmin != xmax and ymin != ymax:
            ax.set(xlim=[xmin, xmax], ylim=[ymin, ymax])

    def annotate_hexagons(x, y, c_mask):
        # Optional text overlays: values, u/v coordinates, and indices.
        if annotate:
            for i, (_label, _x, _y) in enumerate(zip(values, x, y)):
                # Masked (NaN/inf) entries are skipped.
                if not c_mask.mask[i] and not np.isnan(_label):
                    # Black text on bright hexagons (HSV value above the
                    # threshold), white text on dark ones.
                    _textcolor = (
                        "black"
                        if mpl.colors.rgb_to_hsv(color_rgba[i][:-1])[-1]
                        > text_color_hsv_threshold
                        else "white"
                    )
                    ax.annotate(
                        f"{_label:.1F}",
                        fontsize=fontsize,
                        xy=(_x, _y),
                        xytext=(0, 0),
                        textcoords="offset points",
                        ha="center",
                        va="center",
                        color=_textcolor,
                    )
        if annotate_coords:
            # u to the upper left and v to the upper right of each center.
            for _x, _y, _u, _v in zip(x, y, u, v):
                ax.text(
                    _x - 0.45,
                    _y + 0.2,
                    _u,
                    ha="center",
                    va="center",
                    fontsize=fontsize,
                )
                ax.text(
                    _x + 0.45,
                    _y + 0.2,
                    _v,
                    ha="center",
                    va="center",
                    fontsize=fontsize,
                )
        if annotate_indices:
            for i, (_x, _y) in enumerate(zip(x, y)):
                ax.text(_x, _y, i, ha="center", va="center", fontsize=fontsize)

    def add_label():
        # Returns the created text artist. ax.annotate returns an Annotation and
        # ax.text a Text, not a Line2D as the return annotation suggests.
        if labelxy == "auto":
            extent = hex_utils.get_extent(u, v) or 1
            u_cs, v_cs = hex_utils.get_hex_coords(extent)
            # Place the label one step beyond the lower end of the u == v
            # diagonal (z == 0).
            z = -u_cs + v_cs
            labelu, labelv = min(u_cs[z == 0]) - 1, min(v_cs[z == 0]) - 1
            # NOTE(review): `mode` is not passed here and `origin` is not applied,
            # unlike in draw_hexagons.
            labelx, labely = hex_utils.hex_to_pixel(labelu, labelv)
            ha = "right" if len(label) < 4 else "center"
            label_text = ax.annotate(
                label,
                (labelx, labely),
                ha=ha,
                va="bottom",
                fontsize=fontsize,
                zorder=1000,
                xycoords="data",
                color=label_color,
            )
        else:
            label_text = ax.text(
                labelxy[0],
                labelxy[1],
                label,
                transform=ax.transAxes,
                ha="left",
                va="center",
                fontsize=fontsize,
                zorder=100,
                color=label_color,
            )
        return label_text

    # Main execution
    fig, ax = init_plot_and_validate_input(fig, ax)
    apply_max_extent()
    color_rgba, scalarmapper, norm = setup_color_mapping(scalarmapper, norm)
    apply_frame()
    x, y, c_mask = draw_hexagons()
    add_colorbar()
    set_plot_limits(x, y)
    ax = plt_utils.rm_spines(ax, rm_xticks=True, rm_yticks=True)
    annotate_hexagons(x, y, c_mask)
    label_text = add_label()

    # Final limits: the data extent of all added artists, padded via
    # plt_utils.get_lims with 0.01 (overrides set_plot_limits).
    (xmin, ymin, xmax, ymax) = ax.dataLim.extents
    ax.set_xlim(plt_utils.get_lims((xmin, xmax), 0.01))
    ax.set_ylim(plt_utils.get_lims((ymin, ymax), 0.01))

    return fig, ax, (label_text, scalarmapper)

flyvision.analysis.visualization.plots.kernel

kernel(
    u,
    v,
    values,
    fontsize=5,
    cbar=True,
    edgecolor="k",
    fig=None,
    ax=None,
    figsize=(1, 1),
    midpoint=0,
    annotate=True,
    alpha=0.8,
    annotate_coords=False,
    coord_fs=8,
    cbar_height=0.3,
    cbar_x_offset=-1,
    **kwargs
)

Plot receptive fields with hex_scatter.

Parameters:

Name Type Description Default
u NDArray

Array of hex coordinates in u direction.

required
v NDArray

Array of hex coordinates in v direction.

required
values NDArray

Array of pixel values per point (u_i, v_i).

required
fontsize int

Font size for text elements.

5
cbar bool

Whether to show a colorbar.

True
edgecolor str

Color of the hexagon edges.

'k'
fig Optional[Figure]

Matplotlib Figure object.

None
ax Optional[Axes]

Matplotlib Axes object.

None
figsize Tuple[float, float]

Size of the figure.

(1, 1)
midpoint float

Midpoint for color mapping.

0
annotate bool

Whether to annotate hexagons with values.

True
alpha float

Alpha value for transparency.

0.8
annotate_coords bool

Whether to annotate hexagons with coordinates.

False
coord_fs int

Font size for coordinate annotations.

8
cbar_height float

Height of the colorbar.

0.3
cbar_x_offset float

X-offset of the colorbar.

-1
**kwargs

Additional keyword arguments passed to hex_scatter.

{}

Returns:

Type Description
Tuple[Figure, Axes, Tuple[Optional[Line2D], ScalarMappable]]

A tuple containing the Figure, Axes, and a tuple of (label_text, scalarmapper).

Raises:

Type Description
SignError

If signs in the kernel are inconsistent.

Note

Assigns seismic as colormap and checks that signs are consistent. All arguments except cmap can be passed to hex_scatter.

Source code in flyvision/analysis/visualization/plots.py
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
@wraps(hex_scatter)
def kernel(
    u: NDArray,
    v: NDArray,
    values: NDArray,
    fontsize: int = 5,
    cbar: bool = True,
    edgecolor: str = "k",
    fig: Optional[Figure] = None,
    ax: Optional[Axes] = None,
    figsize: Tuple[float, float] = (1, 1),
    midpoint: float = 0,
    annotate: bool = True,
    alpha: float = 0.8,
    annotate_coords: bool = False,
    coord_fs: int = 8,
    cbar_height: float = 0.3,
    cbar_x_offset: float = -1,
    **kwargs,
) -> Tuple[Figure, Axes, Tuple[Optional[Line2D], mpl.cm.ScalarMappable]]:
    """Plot receptive fields with hex_scatter.

    Args:
        u: Array of hex coordinates in u direction.
        v: Array of hex coordinates in v direction.
        values: Array of kernel values per point (u_i, v_i).
        fontsize: Font size for text elements.
        cbar: Whether to show a colorbar.
        edgecolor: Color of the hexagon edges.
        fig: Matplotlib Figure object.
        ax: Matplotlib Axes object.
        figsize: Size of the figure.
        midpoint: Midpoint for color mapping.
        annotate: Whether to annotate hexagons with values.
        alpha: Alpha value for transparency.
        annotate_coords: Whether to annotate hexagons with coordinates.
        coord_fs: Font size for coordinate annotations. hex_scatter has no such
            parameter, so this is forwarded through its ``**kwargs``.
        cbar_height: Height of the colorbar.
        cbar_x_offset: X-offset of the colorbar.
        **kwargs: Additional keyword arguments passed to hex_scatter. They are
            merged last and therefore override the values set here.

    Returns:
        A tuple containing the Figure, Axes, and a tuple of (label_text, scalarmapper).

    Raises:
        SignError: If the non-zero, non-NaN kernel values have mixed signs.

    Note:
        Assigns `seismic` as colormap and checks that signs are consistent.
        All arguments except `cmap` can be passed to hex_scatter.
        Because of ``@wraps(hex_scatter)``, the runtime ``__name__`` and
        ``__doc__`` of this function are those of hex_scatter.
    """

    def check_sign_consistency(values: NDArray) -> None:
        values = np.asarray(values)
        # Ignore zeros and NaNs. NaN has no sign, and because NaN != NaN every
        # NaN entry would otherwise count as a distinct "sign" in the set and
        # raise a spurious SignError.
        informative = values[(values != 0) & ~np.isnan(values)]
        non_zero_signs = set(np.sign(informative))
        if len(non_zero_signs) > 1:
            raise SignError(f"Inconsistent kernel with signs {non_zero_signs}")

    check_sign_consistency(values)

    hex_scatter_kwargs = {
        'u': u,
        'v': v,
        'values': values,
        'fontsize': fontsize,
        'cbar': cbar,
        'edgecolor': edgecolor,
        'fig': fig,
        'ax': ax,
        'figsize': figsize,
        'midpoint': midpoint,
        'annotate': annotate,
        'alpha': alpha,
        'annotate_coords': annotate_coords,
        'coord_fs': coord_fs,
        'cbar_height': cbar_height,
        'cbar_x_offset': cbar_x_offset,
        'cmap': cm.get_cmap("seismic"),
        **kwargs,
    }

    return hex_scatter(**hex_scatter_kwargs)

flyvision.analysis.visualization.plots.hex_cs

hex_cs(
    extent=5,
    mode="default",
    annotate_coords=True,
    edgecolor="black",
    **kwargs
)

Plot a hexagonal coordinate system.

Parameters:

Name Type Description Default
extent int

Extent of the hexagonal grid.

5
mode Literal['default', 'flat']

Hex coordinate system mode.

'default'
annotate_coords bool

Whether to annotate hexagons with coordinates.

True
edgecolor str

Color of the hexagon edges.

'black'
**kwargs

Additional keyword arguments passed to hex_scatter.

{}

Returns:

Type Description
Tuple[Figure, Axes, Tuple[Optional[Line2D], ScalarMappable]]

A tuple containing the Figure, Axes, and a tuple of (label_text, scalarmapper).

Source code in flyvision/analysis/visualization/plots.py
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
def hex_cs(
    extent: int = 5,
    mode: Literal["default", "flat"] = "default",
    annotate_coords: bool = True,
    edgecolor: str = "black",
    **kwargs,
) -> Tuple[Figure, Axes, Tuple[Optional[Line2D], mpl.cm.ScalarMappable]]:
    """Draw a uniformly colored hexagonal lattice to visualize the coordinate system.

    Every hexagon gets the constant value 1, mapped with the ``binary_r``
    colormap on a fixed [0, 1] range. No colorbar is drawn.

    Args:
        extent: Extent of the hexagonal grid.
        mode: Hex coordinate system mode.
        annotate_coords: Whether to write the u, v coordinates into each hexagon.
        edgecolor: Color of the hexagon edges.
        **kwargs: Further keyword arguments forwarded to hex_scatter.

    Returns:
        A tuple containing the Figure, Axes, and a tuple of (label_text, scalarmapper).
    """
    lattice_u, lattice_v = hex_utils.get_hex_coords(extent)
    binary_cmap = cm.get_cmap("binary_r")
    constant_value = 1
    return hex_scatter(
        lattice_u,
        lattice_v,
        constant_value,
        mode=mode,
        edgecolor=edgecolor,
        annotate_coords=annotate_coords,
        cmap=binary_cmap,
        vmin=0,
        vmax=1,
        cbar=False,
        **kwargs,
    )

flyvision.analysis.visualization.plots.quick_hex_scatter

quick_hex_scatter(
    values, cmap=cm.get_cmap("binary_r"), **kwargs
)

Create a hex scatter plot with implicit coordinates.

Parameters:

Name Type Description Default
values NDArray

Array of pixel values.

required
cmap Colormap

Colormap for the plot.

get_cmap('binary_r')
**kwargs

Additional keyword arguments passed to hex_scatter.

{}

Returns:

Type Description
Tuple[Figure, Axes, Tuple[Optional[Line2D], ScalarMappable]]

A tuple containing the Figure, Axes, and a tuple of (label_text, scalarmapper).

Source code in flyvision/analysis/visualization/plots.py
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
def quick_hex_scatter(
    values: NDArray, cmap: mpl.colors.Colormap = cm.get_cmap("binary_r"), **kwargs
) -> Tuple[Figure, Axes, Tuple[Optional[Line2D], mpl.cm.ScalarMappable]]:
    """Scatter values onto a regular hex lattice whose coordinates are inferred.

    The lattice extent is derived from the number of values, so ``values`` must
    hold one entry per hexagon of a regular hex grid after squeezing.

    Args:
        values: Array (or tensor) of pixel values.
        cmap: Colormap for the plot.
        **kwargs: Further keyword arguments forwarded to hex_scatter.

    Returns:
        A tuple containing the Figure, Axes, and a tuple of (label_text, scalarmapper).
    """
    flat_values = utils.tensor_utils.to_numpy(values.squeeze())
    lattice_extent = hex_utils.get_hextent(len(flat_values))
    lattice_u, lattice_v = hex_utils.get_hex_coords(lattice_extent)
    return hex_scatter(lattice_u, lattice_v, flat_values, cmap=cmap, **kwargs)

flyvision.analysis.visualization.plots.hex_flow

hex_flow(
    u,
    v,
    flow,
    fig=None,
    ax=None,
    figsize=(1, 1),
    title="",
    cmap=plt_utils.cm_uniform_2d,
    max_extent=None,
    cwheelradius=0.25,
    mode="default",
    orientation=np.radians(30),
    origin="lower",
    fontsize=5,
    cwheel=True,
    cwheelxy=(),
    cwheelpos="southeast",
    cwheellabelpad=-5,
    annotate_r=False,
    annotate_theta=False,
    annotate_coords=False,
    coord_fs=3,
    label="",
    labelxy=(0, 1),
    vmin=-np.pi,
    vmax=np.pi,
    edgecolor=None,
    **kwargs
)

Plot a hexagonal lattice with coordinates u, v, and flow.

Parameters:

Name Type Description Default
u NDArray

Array of hex coordinates in u direction.

required
v NDArray

Array of hex coordinates in v direction.

required
flow NDArray

Array of flow per point (u_i, v_i), shape [2, len(u)].

required
fig Optional[Figure]

Matplotlib Figure object.

None
ax Optional[Axes]

Matplotlib Axes object.

None
figsize Tuple[float, float]

Size of the figure.

(1, 1)
title str

Title of the plot.

''
cmap Colormap

Colormap for the plot.

cm_uniform_2d
max_extent Optional[int]

Maximum extent of the hex lattice.

None
cwheelradius float

Radius of the colorwheel.

0.25
mode Literal['default', 'flat']

Hex coordinate system mode.

'default'
orientation float

Orientation of hexagons in radians.

radians(30)
origin Literal['lower', 'upper']

Origin of the plot.

'lower'
fontsize int

Font size for text elements.

5
cwheel bool

Whether to show a colorwheel.

True
cwheelxy Tuple[float, float]

Position of the colorwheel.

()
cwheelpos str

Named anchor position of the colorwheel (e.g. 'southeast').

'southeast'
cwheellabelpad float

Padding for colorwheel labels.

-5
annotate_r bool

Whether to annotate hexagons with magnitude.

False
annotate_theta bool

Whether to annotate hexagons with angle.

False
annotate_coords bool

Whether to annotate hexagons with coordinates.

False
coord_fs int

Font size for coordinate annotations.

3
label str

Label for the plot.

''
labelxy Tuple[float, float]

Position of the label.

(0, 1)
vmin float

Minimum value for color mapping.

-pi
vmax float

Maximum value for color mapping.

pi
edgecolor Optional[str]

Color of the hexagon edges.

None
**kwargs

Additional keyword arguments (currently not used by hex_flow).

{}

Returns:

Type Description
Tuple[Figure, Axes, Tuple[Optional[Line2D], ScalarMappable, Optional[Colorbar], Optional[PathCollection]]]

A tuple containing the Figure, Axes, and a tuple of (label_text, scalarmapper, colorbar, scatter).

Note

Works largely like hex_scatter, but with 2d-flow instead of 1d-intensities.

Source code in flyvision/analysis/visualization/plots.py
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
def hex_flow(
    u: NDArray,
    v: NDArray,
    flow: NDArray,
    fig: Optional[Figure] = None,
    ax: Optional[Axes] = None,
    figsize: Tuple[float, float] = (1, 1),
    title: str = "",
    cmap: mpl.colors.Colormap = plt_utils.cm_uniform_2d,
    max_extent: Optional[int] = None,
    cwheelradius: float = 0.25,
    mode: Literal["default", "flat"] = "default",
    orientation: float = np.radians(30),
    origin: Literal["lower", "upper"] = "lower",
    fontsize: int = 5,
    cwheel: bool = True,
    cwheelxy: Tuple[float, float] = (),
    cwheelpos: str = "southeast",
    cwheellabelpad: float = -5,
    annotate_r: bool = False,
    annotate_theta: bool = False,
    annotate_coords: bool = False,
    coord_fs: int = 3,
    label: str = "",
    labelxy: Tuple[float, float] = (0, 1),
    vmin: float = -np.pi,
    vmax: float = np.pi,
    edgecolor: Optional[str] = None,
    **kwargs,
) -> Tuple[
    Figure,
    Axes,
    Tuple[
        Optional[Line2D],
        mpl.cm.ScalarMappable,
        Optional[mpl.colorbar.Colorbar],
        Optional[mpl.collections.PathCollection],
    ],
]:
    """Plot a hexagonal lattice with coordinates u, v, and flow.

    The flow direction (angle) selects the color from `cmap`. The flow magnitude,
    normalized by its maximum, becomes the alpha channel.

    Args:
        u: Array of hex coordinates in u direction.
        v: Array of hex coordinates in v direction.
        flow: Array of flow per point (u_i, v_i), shape [2, len(u)].
        fig: Matplotlib Figure object.
        ax: Matplotlib Axes object.
        figsize: Size of the figure.
        title: Title of the plot.
        cmap: Colormap for the plot.
        max_extent: Maximum extent of the hex lattice.
        cwheelradius: Radius of the colorwheel.
        mode: Hex coordinate system mode.
        orientation: Orientation of hexagons in radians.
        origin: Origin of the plot.
        fontsize: Font size for text elements.
        cwheel: Whether to show a colorwheel.
        cwheelxy: (x, y) offset of the colorwheel; an empty tuple means no offset.
        cwheelpos: Named anchor position of the colorwheel (e.g. "southeast").
        cwheellabelpad: Padding for colorwheel labels.
        annotate_r: Whether to annotate hexagons with (normalized) magnitude.
        annotate_theta: Whether to annotate hexagons with angle in degrees.
        annotate_coords: Whether to annotate hexagons with coordinates.
        coord_fs: Font size for coordinate annotations.
        label: Label for the plot.
        labelxy: Position of the label in axes coordinates.
        vmin: Minimum angle for color mapping; None uses the data minimum.
        vmax: Maximum angle for color mapping; None uses the data maximum.
        edgecolor: Color of the hexagon edges (defaults to each face color).
        **kwargs: Additional keyword arguments (currently unused).

    Returns:
        A tuple containing the Figure, Axes, and a tuple of
            (label_text, scalarmapper, colorbar, scatter). The last two are None
            when cwheel is False.

    Note:
        Works largely like hex_scatter, but with 2d-flow instead of 1d-intensities.
    """
    fig, ax = plt_utils.init_plot(figsize, title, fontsize, ax, fig)
    ax.set_aspect("equal")

    if max_extent:
        max_extent_index = hex_utils.max_extent_index(u, v, max_extent=max_extent)
        flow = flow[:, max_extent_index]
        u = u[max_extent_index]
        v = v[max_extent_index]

    # Magnitude per hexagon, normalized to [0, 1]; used as alpha below.
    r = np.linalg.norm(flow, axis=0)
    r_max = np.nanmax(r)
    if r_max > 0:
        # Guard: an all-zero flow field would otherwise give 0 / 0 = NaN alphas.
        r = r / r_max
    # Direction per hexagon in radians, in [-pi, pi].
    theta = np.arctan2(flow[1], flow[0])

    # Only fall back to the data range when no limit is given; an explicit 0 is
    # a valid limit and must not be discarded as falsy.
    vmin = vmin if vmin is not None else theta.min()
    vmax = vmax if vmax is not None else theta.max()
    scalarmapper, _ = plt_utils.get_scalarmapper(
        cmap=cmap, vmin=vmin, vmax=vmax, midpoint=0.0
    )
    color_rgba = scalarmapper.to_rgba(theta)
    color_rgba[:, -1] = r

    x, y = hex_utils.hex_to_pixel(u, v, mode=mode)
    if origin == "upper":
        y = y[::-1]

    def draw_hexagons():
        for _x, _y, c in zip(x, y, color_rgba):
            _hex = RegularPolygon(
                (_x, _y),
                numVertices=6,
                radius=1,
                linewidth=0.5,
                orientation=orientation,
                edgecolor=edgecolor or c,
                facecolor=c,
            )
            ax.add_patch(_hex)

    draw_hexagons()

    cb = cs = None
    if cwheel:
        x_offset, y_offset = cwheelxy or (0, 0)
        cb, cs = plt_utils.add_colorwheel_2d(
            fig,
            [ax],
            radius=cwheelradius,
            pos=cwheelpos,
            sm=scalarmapper,
            fontsize=fontsize,
            x_offset=x_offset,
            y_offset=y_offset,
            N=1024,
            labelpad=cwheellabelpad,
        )

    # `or 1` (as in hex_scatter) avoids dividing by a zero extent, e.g. for a
    # lattice of a single hexagon.
    extent = hex_utils.get_extent(u, v) or 1
    ax.set_xlim(x.min() + x.min() / extent, x.max() + x.max() / extent)
    ax.set_ylim(y.min() + y.min() / extent, y.max() + y.max() / extent)

    ax = plt_utils.rm_spines(ax, rm_xticks=True, rm_yticks=True)

    if annotate_r:
        for _r, _x, _y in zip(r, x, y):
            ax.annotate(
                f"{_r:.2G}",
                fontsize=fontsize,
                xy=(_x, _y),
                xytext=(0, 0),
                textcoords="offset points",
                ha="center",
                va="center",
            )

    if annotate_theta:
        for _theta, _x, _y in zip(np.degrees(theta), x, y):
            ax.annotate(
                f"{_theta:.2f}",
                fontsize=fontsize,
                xy=(_x, _y),
                xytext=(0, 0),
                textcoords="offset points",
                ha="center",
                va="center",
            )

    if annotate_coords:
        for _x, _y, _u, _v in zip(x, y, u, v):
            ax.annotate(
                _u,
                fontsize=coord_fs,
                xy=(_x, _y),
                xytext=np.array([-0.25, 0.25]),
                textcoords="offset points",
                ha="center",
                va="center",
            )
            ax.annotate(
                _v,
                fontsize=coord_fs,
                xy=(_x, _y),
                xytext=np.array([0.25, 0.25]),
                textcoords="offset points",
                ha="center",
                va="center",
            )

    label_text = None
    if label:
        label_text = ax.text(
            labelxy[0],
            labelxy[1],
            label,
            transform=ax.transAxes,
            ha="left",
            va="center",
            fontsize=fontsize,
        )

    return fig, ax, (label_text, scalarmapper, cb, cs)

flyvision.analysis.visualization.plots.quick_hex_flow

quick_hex_flow(flow, **kwargs)

Plot a flow field on a hexagonal lattice with implicit coordinates.

Parameters:

Name Type Description Default
flow NDArray

Array of flow values.

required
**kwargs

Additional keyword arguments passed to hex_flow.

{}

Returns:

Type Description
Tuple[Figure, Axes, Tuple[Optional[Line2D], ScalarMappable, Optional[Colorbar], Optional[PathCollection]]]

A tuple containing the Figure, Axes, and a tuple of (label_text, scalarmapper, colorbar, scatter).

Source code in flyvision/analysis/visualization/plots.py
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
def quick_hex_flow(
    flow: NDArray, **kwargs
) -> Tuple[
    Figure,
    Axes,
    Tuple[
        Optional[Line2D],
        mpl.cm.ScalarMappable,
        Optional[mpl.colorbar.Colorbar],
        Optional[mpl.collections.PathCollection],
    ],
]:
    """Draw a flow field on a regular hex lattice whose coordinates are inferred.

    The lattice extent is derived from the size of the last axis of the squeezed
    flow, which must hold one entry per hexagon.

    Args:
        flow: Array (or tensor) of flow values.
        **kwargs: Further keyword arguments forwarded to hex_flow.

    Returns:
        A tuple containing the Figure, Axes, and a tuple of
            (label_text, scalarmapper, colorbar, scatter).
    """
    flow_array = utils.tensor_utils.to_numpy(flow.squeeze())
    n_hexals = flow_array.shape[-1]
    lattice_u, lattice_v = hex_utils.get_hex_coords(hex_utils.get_hextent(n_hexals))
    return hex_flow(lattice_u, lattice_v, flow_array, **kwargs)

flyvision.analysis.visualization.plots.flow_to_rgba

flow_to_rgba(flow)

Map cartesian flow to RGBA colors.

Parameters:

Name Type Description Default
flow Union[ndarray, Tensor]

Flow field of shape (2, h, w).

required

Returns:

Type Description
ndarray

RGBA color representation of the flow field.

Note

The flow magnitude is mapped to the alpha channel, while the flow direction is mapped to the color using a uniform 2D colormap.

Source code in flyvision/analysis/visualization/plots.py
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
def flow_to_rgba(flow: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
    """Map cartesian flow to RGBA colors.

    Args:
        flow: Flow field of shape (2, h, w); ``flow[0]`` is the x component
            and ``flow[1]`` the y component. Torch tensors (including ones
            that require grad) are converted to numpy first.

    Returns:
        RGBA color representation of the flow field, shape (h, w, 4).

    Note:
        The flow direction (angle in [-pi, pi]) is mapped to the color using a
        uniform 2D colormap, while the flow magnitude, normalized by its
        maximum, is mapped to the alpha channel. An all-zero flow field yields
        alpha 0 everywhere instead of NaN.
    """
    if isinstance(flow, torch.Tensor):
        # detach() so that tensors participating in autograd can be converted.
        flow = flow.detach().cpu().numpy()

    X, Y = flow[0], flow[1]
    R = np.hypot(X, Y)
    PHI = np.arctan2(Y, X)
    scalarmapper, _ = plt_utils.get_scalarmapper(
        cmap=plt_utils.cm_uniform_2d, vmin=-np.pi, vmax=np.pi
    )
    rgba = scalarmapper.to_rgba(PHI)
    r_max = R.max()
    # Guard against 0/0 (NaN alpha) when the flow is zero everywhere.
    rgba[..., -1] = R / r_max if r_max > 0 else 0.0
    return rgba

flyvision.analysis.visualization.plots.plot_flow

plot_flow(flow)

Plot cartesian flow.

Parameters:

Name Type Description Default
flow Union[ndarray, Tensor]

Flow field of shape (2, h, w).

required
Note

This function displays the flow field using matplotlib’s imshow and immediately shows the plot.

Source code in flyvision/analysis/visualization/plots.py
884
885
886
887
888
889
890
891
892
893
894
895
896
def plot_flow(flow: Union[np.ndarray, torch.Tensor]) -> None:
    """Display a cartesian flow field as an RGBA image.

    The field is converted with ``flow_to_rgba`` (direction -> color,
    magnitude -> alpha) and shown immediately.

    Args:
        flow: Flow field of shape (2, h, w).

    Note:
        This function draws on the current pyplot axes via ``plt.imshow`` and
        then calls ``plt.show()``.
    """
    image = flow_to_rgba(flow)
    plt.imshow(image)
    plt.show()

flyvision.analysis.visualization.plots.traces

traces(
    trace,
    x=None,
    contour=None,
    legend=(),
    smooth=None,
    stim_line=None,
    contour_cmap=cm.get_cmap("bone"),
    color=None,
    label="",
    labelxy=(0, 1),
    linewidth=1,
    ax=None,
    fig=None,
    title="",
    highlight_mean=False,
    figsize=(7, 4),
    fontsize=10,
    ylim=None,
    ylabel="",
    xlabel="",
    legend_frame_alpha=0,
    contour_mode="full",
    contour_y_rel=0.06,
    fancy=False,
    scale_pos=None,
    scale_label="100ms",
    null_line=False,
    zorder_traces=None,
    zorder_mean=None,
    **kwargs
)

Create a line plot with optional contour and smoothing.

Parameters:

Name Type Description Default
trace NDArray

2D array (n_traces, n_points) of trace values.

required
x Optional[NDArray]

X-axis values.

None
contour Optional[NDArray]

Array of contour values.

None
legend Tuple[str, ...]

Legend for each trace.

()
smooth Optional[float]

Size of the smoothing window as a fraction of the number of points (e.g. 0.05 for 5%).

None
stim_line Optional[NDArray]

Stimulus line data.

None
contour_cmap Colormap

Colormap for the contour.

get_cmap('bone')
color Optional[Union[str, List[str]]]

Color(s) for the traces.

None
label str

Label for the plot.

''
labelxy Tuple[float, float]

Position of the label.

(0, 1)
linewidth float

Width of the trace lines.

1
ax Optional[Axes]

Matplotlib Axes object.

None
fig Optional[Figure]

Matplotlib Figure object.

None
title str

Title of the plot.

''
highlight_mean bool

Whether to highlight the mean trace.

False
figsize Tuple[float, float]

Size of the figure.

(7, 4)
fontsize int

Font size for text elements.

10
ylim Optional[Tuple[float, float]]

Y-axis limits.

None
ylabel str

Y-axis label.

''
xlabel str

X-axis label.

''
legend_frame_alpha float

Alpha value for the legend frame.

0
contour_mode Literal['full', 'top', 'bottom']

Mode for contour plotting.

'full'
contour_y_rel float

Relative Y position for contour in “top” or “bottom” mode.

0.06
fancy bool

Whether to use fancy styling.

False
scale_pos Optional[str]

Position of the scale bar.

None
scale_label str

Label for the scale bar.

'100ms'
null_line bool

Whether to draw a null line at y=0.

False
zorder_traces Optional[int]

Z-order for traces.

None
zorder_mean Optional[int]

Z-order for mean trace.

None
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description
Tuple[Figure, Axes, NDArray, Optional[Line2D]]

A tuple containing the Figure, Axes, smoothed trace, and label text.

Note

This function creates a line plot with various options for customization, including contour plotting and trace smoothing.

Source code in flyvision/analysis/visualization/plots.py
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
def traces(
    trace: NDArray,
    x: Optional[NDArray] = None,
    contour: Optional[NDArray] = None,
    legend: Tuple[str, ...] = (),
    smooth: Optional[float] = None,
    stim_line: Optional[NDArray] = None,
    contour_cmap: mpl.colors.Colormap = mpl.colormaps["bone"],
    color: Optional[Union[str, List[str]]] = None,
    label: str = "",
    labelxy: Tuple[float, float] = (0, 1),
    linewidth: float = 1,
    ax: Optional[Axes] = None,
    fig: Optional[Figure] = None,
    title: str = "",
    highlight_mean: bool = False,
    figsize: Tuple[float, float] = (7, 4),
    fontsize: int = 10,
    ylim: Optional[Tuple[float, float]] = None,
    ylabel: str = "",
    xlabel: str = "",
    legend_frame_alpha: float = 0,
    contour_mode: Literal["full", "top", "bottom"] = "full",
    contour_y_rel: float = 0.06,
    fancy: bool = False,
    scale_pos: Optional[str] = None,
    scale_label: str = "100ms",
    null_line: bool = False,
    zorder_traces: Optional[int] = None,
    zorder_mean: Optional[int] = None,
    **kwargs,
) -> Tuple[Figure, Axes, NDArray, Optional[Line2D]]:
    """Create a line plot with optional contour and smoothing.

    Args:
        trace: 2D array (n_traces, n_points) of trace values. 1D input is
            treated as a single trace.
        x: X-axis values.
        contour: Array of contour values, drawn as a filled background band.
        legend: Legend for each trace (ignored unless it has one entry per
            trace).
        smooth: Size of the smoothing window as a fraction of the number of
            points (e.g. 0.05 for 5%). The window is at least one sample.
        stim_line: If not None, the contour values are additionally drawn as
            a dashed black line.
        contour_cmap: Colormap for the contour.
        color: A single color shared by all traces, or a sequence with one
            color per trace.
        label: Label for the plot.
        labelxy: Position of the label in axes coordinates.
        linewidth: Width of the trace lines.
        ax: Matplotlib Axes object.
        fig: Matplotlib Figure object.
        title: Title of the plot.
        highlight_mean: Whether to highlight the mean trace.
        figsize: Size of the figure.
        fontsize: Font size for text elements.
        ylim: Y-axis limits.
        ylabel: Y-axis label.
        xlabel: X-axis label.
        legend_frame_alpha: Alpha value for the legend frame.
        contour_mode: Mode for contour plotting: "full" spans the whole y
            range, "top"/"bottom" draw a band above/below the data.
        contour_y_rel: Relative Y extent of the contour band in "top" or
            "bottom" mode.
        fancy: Whether to use fancy styling (removes left/bottom spines and
            ticks).
        scale_pos: Position of the scale bar.
        scale_label: Label for the scale bar.
        null_line: Whether to draw a null line at y=0.
        zorder_traces: Z-order for traces.
        zorder_mean: Z-order for mean trace.
        **kwargs: Additional keyword arguments (accepted and ignored).

    Returns:
        A tuple containing the Figure, Axes, the (possibly smoothed) trace
        array, and the label text artist (None if ``label`` is empty).

    Raises:
        ValueError: If ``contour`` is given with an unknown ``contour_mode``.

    Note:
        This function creates a line plot with various options for customization,
        including contour plotting and trace smoothing.
    """
    trace = np.atleast_2d(np.array(trace))

    if np.ma.masked_invalid(trace).mask.any():
        logging.debug("Invalid values encountered in trace.")

    # Smooth traces.
    if smooth:
        # `smooth` is a fraction of the number of points. Clamp the window to
        # at least one sample: a window of 0 would make `x[0::0]` raise.
        smooth = max(1, int(smooth * trace.shape[1]))
        ylabel += " (smoothed)"
        trace = plt_utils.avg_pool(trace, smooth)
        if x is not None:
            x = x[0::smooth][: trace.shape[1]]

    n_traces = trace.shape[0]

    fig, ax = plt_utils.init_plot(figsize, title, fontsize, ax=ax, fig=fig)

    legends = legend if len(legend) == n_traces else ("",) * n_traces

    # A single color (name, hex string, RGB/RGBA tuple) is shared by all
    # traces; a sequence with one entry per trace is used as-is. Testing
    # `is_color_like` instead of the array rank lets a list of color names
    # (which is 1D, like an RGB tuple) be used per trace.
    if color is None or mpl.colors.is_color_like(color):
        colors = (color,) * n_traces
    elif hasattr(color, "__len__") and len(color) == n_traces:
        colors = color
    else:
        colors = (None,) * n_traces

    # Plot traces.
    iterations = np.arange(trace.shape[1]) if x is None else x
    for i, _trace in enumerate(trace):
        ax.plot(
            iterations,
            _trace,
            label=legends[i],
            c=colors[i],
            linewidth=linewidth,
            zorder=zorder_traces,
        )

    if highlight_mean:
        ax.plot(
            iterations,
            np.mean(trace, axis=0),
            linewidth=0.5,
            c="k",
            label="average",
            zorder=zorder_mean,
        )

    if contour is not None and contour_mode is not None:
        ylim = ylim or plt_utils.get_lims(
            np.array([
                min(contour.min(), trace.min()),
                max(contour.max(), trace.max()),
            ]),
            0.1,
        )

        # Fall back to sample indices when x is missing or does not match the
        # contour length (e.g. after smoothing shortened x).
        _x = np.arange(len(contour)) if x is None or len(x) != len(contour) else x
        if contour_mode == "full":
            # Effectively unbounded; the visible range is set by ylim.
            contour_y_range = (-20_000, 20_000)
        elif contour_mode == "top":
            yrange = ylim[1] - ylim[0]
            contour_y_range = (ylim[1], ylim[1] + yrange * contour_y_rel)
            ylim = (ylim[0], contour_y_range[1])
        elif contour_mode == "bottom":
            yrange = ylim[1] - ylim[0]
            contour_y_range = (ylim[0] - yrange * contour_y_rel, ylim[0])
            ylim = (contour_y_range[0], ylim[1])
        else:
            raise ValueError(
                "contour_mode must be 'full', 'top' or 'bottom', "
                f"got {contour_mode!r}."
            )

        _y = np.linspace(*contour_y_range, 100)
        Z = np.tile(contour, (len(_y), 1))
        ax.contourf(
            _x,
            _y,
            Z,
            cmap=contour_cmap,
            levels=2,
            alpha=0.3,
            vmin=0,
            vmax=1,
        )

        if stim_line is not None:
            # Use the same x values as the contour; `x` itself may be None.
            ax.plot(_x, contour, color="k", linestyle="--")

    # Cosmetics.
    ax.set_xlabel(xlabel, fontsize=fontsize)
    ax.set_ylabel(ylabel, fontsize=fontsize)
    if null_line:
        ax.hlines(
            0,
            -20_000,
            20_000,
            color="0.5",
            zorder=-1,
            linewidth=0.5,
        )
    ax.set_xlim(*plt_utils.get_lims(iterations, 0.01))
    ax.tick_params(labelsize=fontsize)
    if legend:
        ax.legend(
            fontsize=fontsize,
            edgecolor="white",
            **dict(
                labelspacing=0.0,
                framealpha=legend_frame_alpha,
                borderaxespad=0.1,
                borderpad=0.1,
                handlelength=1,
                handletextpad=0.3,
            ),
        )
    if ylim is not None:
        ax.set_ylim(*ylim)

    label_text = None
    if label != "":
        label_text = ax.text(
            labelxy[0],
            labelxy[1],
            label,
            transform=ax.transAxes,
            ha="left",
            va="center",
            fontsize=fontsize,
        )

    # Add at most one scale bar per axes, even when called repeatedly on it.
    if scale_pos and not any(isinstance(a, AnchoredSizeBar) for a in ax.artists):
        scalebar = AnchoredSizeBar(
            ax.transData,
            size=0.1,
            label=scale_label,
            loc=scale_pos,
            pad=0.4,
            frameon=False,
            size_vertical=0.01 * (ax.get_ylim()[1] - ax.get_ylim()[0]),
            fontproperties=dict(size=fontsize),
        )
        ax.add_artist(scalebar)

    if fancy:
        plt_utils.rm_spines(ax, ("left", "bottom"), rm_yticks=True, rm_xticks=True)

    return fig, ax, trace, label_text

flyvision.analysis.visualization.plots.grouped_traces

grouped_traces(
    trace_groups,
    x=None,
    legend=(),
    color=None,
    linewidth=1,
    ax=None,
    fig=None,
    title="",
    highlight_mean=False,
    figsize=(7, 4),
    fontsize=10,
    ylim=None,
    ylabel="",
    xlabel="",
    legend_frame_alpha=0,
    **kwargs
)

Create a line plot with grouped traces.

Parameters:

Name Type Description Default
trace_groups List[ndarray]

List of 2D arrays, each containing trace values.

required
x Optional[ndarray]

X-axis values.

None
legend Tuple[str, ...]

Legend for each trace group.

()
color Optional[Union[str, List[str]]]

Color(s) for the trace groups.

None
linewidth float

Width of the trace lines.

1
ax Optional[Axes]

Matplotlib Axes object.

None
fig Optional[Figure]

Matplotlib Figure object.

None
title str

Title of the plot.

''
highlight_mean bool

Whether to highlight the mean trace.

False
figsize Tuple[float, float]

Size of the figure.

(7, 4)
fontsize int

Font size for text elements.

10
ylim Optional[Tuple[float, float]]

Y-axis limits.

None
ylabel str

Y-axis label.

''
xlabel str

X-axis label.

''
legend_frame_alpha float

Alpha value for the legend frame.

0
**kwargs

Additional keyword arguments passed to traces().

{}

Returns:

Type Description
Tuple[Figure, Axes]

A tuple containing the Figure and Axes objects.

Source code in flyvision/analysis/visualization/plots.py
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
def grouped_traces(
    trace_groups: List[np.ndarray],
    x: Optional[np.ndarray] = None,
    legend: Tuple[str, ...] = (),
    color: Optional[Union[str, List[str]]] = None,
    linewidth: float = 1,
    ax: Optional[Axes] = None,
    fig: Optional[Figure] = None,
    title: str = "",
    highlight_mean: bool = False,
    figsize: Tuple[float, float] = (7, 4),
    fontsize: int = 10,
    ylim: Optional[Tuple[float, float]] = None,
    ylabel: str = "",
    xlabel: str = "",
    legend_frame_alpha: float = 0,
    **kwargs,
) -> Tuple[Figure, Axes]:
    """Create a line plot with grouped traces.

    Every group is drawn with ``traces`` onto the same Axes; all traces of a
    group share one color.

    Args:
        trace_groups: List of 2D arrays, each containing trace values.
        x: X-axis values.
        legend: Legend for each trace group (ignored unless it has one entry
            per group).
        color: None (use the rcParams color cycle), a single color shared by
            all groups, a sequence with one color per group, or - when there
            is exactly one group - a sequence with one color per trace.
        linewidth: Width of the trace lines.
        ax: Matplotlib Axes object.
        fig: Matplotlib Figure object.
        title: Title of the plot.
        highlight_mean: Whether to highlight the mean trace.
        figsize: Size of the figure.
        fontsize: Font size for text elements.
        ylim: Y-axis limits.
        ylabel: Y-axis label.
        xlabel: X-axis label.
        legend_frame_alpha: Alpha value for the legend frame.
        **kwargs: Additional keyword arguments passed to traces().

    Returns:
        A tuple containing the Figure and Axes objects.

    Raises:
        ValueError: If ``color`` matches none of the accepted forms.
    """
    fig, ax = plt_utils.init_plot(figsize, title, fontsize, ax=ax, fig=fig)

    n_groups = len(trace_groups)
    legends = legend if len(legend) == n_groups else ("",) * n_groups

    # Set when a single group gets one color per trace; the full sequence is
    # then forwarded to `traces` instead of a single color.
    per_trace_colors = False
    if color is None:
        color_cycle = plt.rcParams["axes.prop_cycle"].by_key()["color"]
        colors = [color_cycle[i % len(color_cycle)] for i in range(n_groups)]
    elif mpl.colors.is_color_like(color):
        # A single color (name, hex, RGB/RGBA tuple) shared by all groups.
        colors = (color,) * n_groups
    elif hasattr(color, "__len__") and len(color) == n_groups:
        colors = color
    elif (
        hasattr(color, "__len__")
        and n_groups == 1
        and len(color) == np.shape(trace_groups[0])[0]
    ):
        colors = color
        per_trace_colors = True
    else:
        # Avoid np.shape(trace_groups): it raises on ragged groups.
        raise ValueError(
            "`color` should be a single color, an iterable with one color per "
            f"trace group ({n_groups}), or None. Got {color!r}."
        )

    for i, _trace in enumerate(trace_groups):
        fig, ax, *_ = traces(
            trace=_trace,
            x=x,
            legend=(),
            color=color if per_trace_colors else colors[i],
            linewidth=linewidth,
            ax=ax,
            fig=fig,
            title=title,
            highlight_mean=highlight_mean,
            figsize=figsize,
            fontsize=fontsize,
            ylim=ylim,
            ylabel=ylabel,
            xlabel=xlabel,
            legend_frame_alpha=legend_frame_alpha,
            **kwargs,
        )
    if legend:
        # Proxy artists: one legend entry per group color.
        custom_lines = [Line2D([0], [0], color=c) for c in colors]
        ax.legend(
            custom_lines,
            legends,
            fontsize=fontsize,
            edgecolor="white",
            **dict(
                labelspacing=0.0,
                framealpha=legend_frame_alpha,
                borderaxespad=0.1,
                borderpad=0.1,
                handlelength=1,
                handletextpad=0.3,
            ),
        )
    return fig, ax

flyvision.analysis.visualization.plots.get_violin_x_locations

get_violin_x_locations(
    n_groups, n_random_variables, violin_width
)

Calculate x-axis locations for violin plots.

Parameters:

Name Type Description Default
n_groups int

Number of groups.

required
n_random_variables int

Number of random variables.

required
violin_width float

Width of each violin plot.

required

Returns:

Type Description
Tuple[ndarray, ndarray]

A tuple containing:

  • np.ndarray: 2D array of violin locations.
  • np.ndarray: 1D array of first violin locations.
Source code in flyvision/analysis/visualization/plots.py
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
def get_violin_x_locations(
    n_groups: int, n_random_variables: int, violin_width: float
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute x positions for grouped violin plots.

    Each random variable owns a slot that starts at a multiple of
    ``n_groups``; inside a slot, consecutive groups are shifted right by
    ``violin_width``.

    Args:
        n_groups: Number of groups.
        n_random_variables: Number of random variables.
        violin_width: Width of each violin plot.

    Returns:
        A tuple containing:
        - np.ndarray: float array of shape (n_groups, n_random_variables)
          with the violin locations; row j holds group j.
        - np.ndarray: 1D array of slot starts (0, n_groups, 2 * n_groups, ...),
          i.e. the locations of the first violin of each random variable.
    """
    slot_starts = np.arange(0, n_groups * n_random_variables, n_groups)
    group_offsets = np.arange(n_groups) * violin_width
    locations = np.add.outer(group_offsets, slot_starts).astype(np.float64)
    return locations, slot_starts

flyvision.analysis.visualization.plots.violin_groups

violin_groups(
    values,
    xticklabels=None,
    pvalues=None,
    display_pvalues_kwargs={},
    legend=False,
    legend_kwargs={},
    as_bars=False,
    colors=None,
    cmap=mpl.colormaps["tab10"],
    cstart=0,
    cdist=1,
    figsize=(10, 1),
    title="",
    ylabel=None,
    ylim=None,
    rotation=90,
    width=0.7,
    fontsize=6,
    ax=None,
    fig=None,
    showmeans=False,
    showmedians=True,
    grid=False,
    scatter=True,
    scatter_radius=3,
    scatter_edge_color=None,
    scatter_edge_width=0.5,
    violin_alpha=0.5,
    violin_marker_lw=0.5,
    violin_marker_color="k",
    color_by="groups",
    zorder_mean_median=5,
    zorder_min_max=5,
    mean_median_linewidth=0.5,
    mean_median_color="k",
    mean_median_bar_length=None,
    **kwargs
)

Create violin plots or bar plots for grouped data.

Parameters:

Name Type Description Default
values ndarray

Array of shape (n_random_variables, n_groups, n_samples).

required
xticklabels Optional[List[str]]

Labels for the x-axis ticks (random variables).

None
pvalues Optional[ndarray]

Array of p-values for statistical significance.

None
display_pvalues_kwargs dict

Keyword arguments for displaying p-values.

{}
legend Union[bool, List[str]]

If True or a list, display a legend with group labels.

False
legend_kwargs dict

Keyword arguments for the legend.

{}
as_bars bool

If True, create bar plots instead of violin plots.

False
colors Optional[List[str]]

List of colors for the violins or bars.

None
cmap Colormap

Colormap to use when colors are not provided.

colormaps['tab10']
cstart float

Starting point in the colormap.

0
cdist float

Distance between colors in the colormap.

1
figsize Tuple[float, float]

Size of the figure (width, height).

(10, 1)
title str

Title of the plot.

''
ylabel Optional[str]

Label for the y-axis.

None
ylim Optional[Tuple[float, float]]

Limits for the y-axis (min, max).

None
rotation float

Rotation angle for x-axis labels.

90
width float

Width of the violins or bars.

0.7
fontsize int

Font size for labels and ticks.

6
ax Optional[Axes]

Existing Axes object to plot on.

None
fig Optional[Figure]

Existing Figure object to use.

None
showmeans bool

If True, show mean lines on violins.

False
showmedians bool

If True, show median lines on violins.

True
grid bool

If True, display a grid.

False
scatter bool

If True, scatter individual data points.

True
scatter_radius float

Size of scattered points.

3
scatter_edge_color Optional[str]

Color of scattered point edges.

None
scatter_edge_width float

Width of scattered point edges.

0.5
violin_alpha float

Alpha (transparency) of violin plots.

0.5
violin_marker_lw float

Line width of violin markers.

0.5
violin_marker_color str

Color of violin markers.

'k'
color_by Literal['groups', 'experiments']

Whether to color by “groups” or “experiments”.

'groups'
zorder_mean_median int

Z-order for mean and median lines.

5
zorder_min_max int

Z-order for min and max lines.

5
mean_median_linewidth float

Line width for mean and median lines.

0.5
mean_median_color str

Color for mean and median lines.

'k'
mean_median_bar_length Optional[float]

Length of mean and median bars.

None
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description
Tuple[Figure, Axes, ViolinData]

A tuple containing:

  • Figure: The matplotlib Figure object.
  • Axes: The matplotlib Axes object.
  • ViolinData: A custom object containing plot data.

Raises:

Type Description
ValueError

If color specifications are invalid.

Note

This function creates either violin plots or bar plots for grouped data, with options for customizing colors, scatter plots, and statistical annotations.

Source code in flyvision/analysis/visualization/plots.py
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
def violin_groups(
    values: np.ndarray,
    xticklabels: Optional[List[str]] = None,
    pvalues: Optional[np.ndarray] = None,
    display_pvalues_kwargs: Optional[dict] = None,
    legend: Union[bool, List[str]] = False,
    legend_kwargs: Optional[dict] = None,
    as_bars: bool = False,
    colors: Optional[List[str]] = None,
    cmap: mpl.colors.Colormap = mpl.colormaps["tab10"],
    cstart: float = 0,
    cdist: float = 1,
    figsize: Tuple[float, float] = (10, 1),
    title: str = "",
    ylabel: Optional[str] = None,
    ylim: Optional[Tuple[float, float]] = None,
    rotation: float = 90,
    width: float = 0.7,
    fontsize: int = 6,
    ax: Optional[Axes] = None,
    fig: Optional[Figure] = None,
    showmeans: bool = False,
    showmedians: bool = True,
    grid: bool = False,
    scatter: bool = True,
    scatter_radius: float = 3,
    scatter_edge_color: Optional[str] = None,
    scatter_edge_width: float = 0.5,
    violin_alpha: float = 0.5,
    violin_marker_lw: float = 0.5,
    violin_marker_color: str = "k",
    color_by: Literal["groups", "experiments"] = "groups",
    zorder_mean_median: int = 5,
    zorder_min_max: int = 5,
    mean_median_linewidth: float = 0.5,
    mean_median_color: str = "k",
    mean_median_bar_length: Optional[float] = None,
    **kwargs,
) -> Tuple[Figure, Axes, ViolinData]:
    """
    Create violin plots or bar plots for grouped data.

    Args:
        values: Array of shape (n_random_variables, n_groups, n_samples).
            Masked arrays are supported; masked samples are dropped from the
            violins.
        xticklabels: Labels for the x-axis ticks (random variables).
        pvalues: Array of p-values for statistical significance.
        display_pvalues_kwargs: Keyword arguments for displaying p-values.
            None means no extra arguments.
        legend: If a list, display a legend with these group labels.
        legend_kwargs: Keyword arguments for the legend. None means no extra
            arguments.
        as_bars: If True, create bar plots (height = mean) instead of violin
            plots.
        colors: List of colors for the violins or bars.
        cmap: Colormap to use when colors are not provided.
        cstart: Starting point in the colormap.
        cdist: Distance between colors in the colormap.
        figsize: Size of the figure (width, height).
        title: Title of the plot.
        ylabel: Label for the y-axis.
        ylim: Limits for the y-axis (min, max).
        rotation: Rotation angle for x-axis labels.
        width: Width of the violins or bars.
        fontsize: Font size for labels and ticks.
        ax: Existing Axes object to plot on.
        fig: Existing Figure object to use.
        showmeans: If True, show mean lines on violins.
        showmedians: If True, show median lines on violins.
        grid: If True, display a horizontal grid.
        scatter: If True, scatter individual data points (with random
            horizontal jitter).
        scatter_radius: Size of scattered points.
        scatter_edge_color: Color of scattered point edges.
        scatter_edge_width: Width of scattered point edges.
        violin_alpha: Alpha (transparency) of violin plots.
        violin_marker_lw: Line width of violin markers.
        violin_marker_color: Color of violin markers.
        color_by: Whether to color by "groups" or "experiments".
        zorder_mean_median: Z-order for mean and median lines.
        zorder_min_max: Z-order for min and max lines.
        mean_median_linewidth: Line width for mean and median lines.
        mean_median_color: Color for mean and median lines.
        mean_median_bar_length: Length of mean and median bars, as a fraction
            of ``width``.
        **kwargs: Additional keyword arguments (accepted and ignored).

    Returns:
        A tuple containing:
        - Figure: The matplotlib Figure object.
        - Axes: The matplotlib Axes object.
        - ViolinData: Built from ``values``, the x locations (shape
          (n_random_variables, n_groups)) and the colors actually used.

    Raises:
        ValueError: If color specifications are invalid.

    Note:
        This function creates either violin plots or bar plots for grouped data,
        with options for customizing colors, scatter plots, and statistical annotations.
    """
    # None sentinels instead of shared mutable `{}` defaults.
    display_pvalues_kwargs = display_pvalues_kwargs or {}
    legend_kwargs = legend_kwargs or {}

    fig, ax = plt_utils.init_plot(figsize, title, fontsize, ax, fig)
    if grid:
        ax.yaxis.grid(zorder=-100)

    def plot_bar(X: float, values: np.ndarray, color: str) -> mpl.patches.Rectangle:
        # One bar whose height is the sample mean.
        handle = ax.bar(x=X, width=width, height=np.mean(values), color=color, zorder=1)
        return handle

    def shorten_bar(key: str, parts: dict) -> None:
        # Replace the mean/median segment by one of length
        # `mean_median_bar_length * width`, centered on the vertical bar.
        (_, y0), (_, y1) = parts[key].get_segments()[0]
        (x0_vert, _), _ = parts["cbars"].get_segments()[0]
        half = mean_median_bar_length * width / 2
        parts[key].set_segments([[[x0_vert - half, y0], [x0_vert + half, y1]]])

    def plot_violin(
        X: float, values: np.ndarray, color: str
    ) -> mpl.collections.PolyCollection:
        if isinstance(values, np.ma.core.MaskedArray):
            values = values[~values.mask]

        parts = ax.violinplot(
            values,
            positions=[X],
            widths=width,
            showmedians=showmedians,
            showmeans=showmeans,
        )
        # Color the bodies.
        for pc in parts["bodies"]:
            pc.set_facecolor(color)
            pc.set_alpha(violin_alpha)
            pc.set_zorder(0)
        # Color the extrema lines.
        for key in ("cbars", "cmaxes", "cmins"):
            parts[key].set_color(violin_marker_color)
            parts[key].set_linewidth(violin_marker_lw)
            parts[key].set_zorder(zorder_min_max)
        # Color the mean/median lines (present only when requested).
        for key in ("cmeans", "cmedians"):
            if key not in parts:
                continue
            parts[key].set_color(mean_median_color)
            parts[key].set_linewidth(mean_median_linewidth)
            parts[key].set_zorder(zorder_mean_median)
            if mean_median_bar_length is not None:
                shorten_bar(key, parts)
        return parts["bodies"][0]

    shape = np.array(values).shape
    n_random_variables, n_groups = shape[0], shape[1]

    violin_locations, first_violins_location = get_violin_x_locations(
        n_groups, n_random_variables, violin_width=width
    )
    # X[i, j]: x position of random variable i, group j.
    X = violin_locations.T

    if colors is None:
        if color_by == "groups":
            C = np.asarray([cmap(cstart + i * cdist) for i in range(n_groups)]).reshape(
                n_groups, 4
            )
        elif color_by == "experiments":
            C = np.asarray([
                cmap(cstart + i * cdist) for i in range(n_random_variables)
            ]).reshape(n_random_variables, 4)
        else:
            raise ValueError("Invalid color_by option")
    elif isinstance(colors, Iterable):
        if (
            color_by == "groups"
            and len(colors) == n_groups
            or color_by == "experiments"
            and len(colors) == n_random_variables
        ):
            C = colors
        else:
            raise ValueError("Invalid colors length")
    else:
        raise ValueError("Invalid colors specification")

    handles = []

    for i in range(n_random_variables):
        for j in range(n_groups):
            _color = C[i] if color_by == "experiments" else C[j]
            # `values[i][j]` (not `values[i, j]`) also works for nested lists.
            samples = values[i][j]

            h = (
                plot_bar(X[i, j], samples, _color)
                if as_bars
                else plot_violin(X[i, j], samples, _color)
            )
            handles.append(h)

            if scatter:
                lims = plt_utils.get_lims(
                    (-width / (2 * n_groups), width / (2 * n_groups)), -0.05
                )
                xticks = np.ones_like(samples) * X[i, j]
                ax.scatter(
                    xticks + np.random.uniform(*lims, size=len(xticks)),
                    samples,
                    facecolor="none",
                    edgecolor=scatter_edge_color or _color,
                    s=scatter_radius,
                    linewidth=scatter_edge_width,
                    zorder=2,
                )

    if legend:
        # NOTE(review): `legend=True` passes a bool as the labels argument,
        # which matplotlib cannot iterate; pass a list of group labels.
        ax.legend(handles, legend, **legend_kwargs)

    if ylim is not None:
        ax.set_ylim(*ylim)

    ax.tick_params(axis="both", which="major", labelsize=fontsize)

    if xticklabels is not None:
        # Center each tick under its slot of n_groups violins.
        ax.set_xticks(first_violins_location + (n_groups - 1) / 2 * width)
        ax.set_xticklabels(xticklabels, rotation=rotation)

    # X may be empty, in which case min/max raise ValueError.
    with suppress(ValueError):
        ax.set_xlim(np.min(X - width), np.max(X + width))

    ax.set_ylabel(ylabel or "", fontsize=fontsize)
    ax.set_title(title, fontsize=fontsize)

    if pvalues is not None:
        plt_utils.display_pvalues(
            ax, pvalues, xticklabels, values, **display_pvalues_kwargs
        )

    # Report the colors actually used (computed from cmap when `colors` is None).
    return fig, ax, ViolinData(values, X, C)

flyvision.analysis.visualization.plots.plot_complex

plot_complex(
    z,
    marker="s",
    fig=None,
    ax=None,
    figsize=(1, 1),
    fontsize=5,
)

Plot a complex number on a polar plot.

Parameters:

Name Type Description Default
z complex

Complex number to plot.

required
marker str

Marker style for the point.

's'
fig Optional[Figure]

Existing figure to plot on.

None
ax Optional[Axes]

Existing axes to plot on.

None
figsize Tuple[float, float]

Size of the figure.

(1, 1)
fontsize int

Font size for text elements.

5

Returns:

Type Description
Tuple[Figure, Axes]

A tuple containing the Figure and Axes objects.

Source code in flyvision/analysis/visualization/plots.py
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
def plot_complex(
    z: complex,
    marker: str = "s",
    fig: Optional[Figure] = None,
    ax: Optional[Axes] = None,
    figsize: Tuple[float, float] = (1, 1),
    fontsize: int = 5,
) -> Tuple[Figure, Axes]:
    """
    Draw a complex number as a segment from the origin on a polar axis.

    ``z`` is converted to polar coordinates (angle in radians via
    ``np.angle``, magnitude via ``np.abs``). One line is drawn from the
    origin to that point, with ``marker`` at both ends.

    Args:
        z: Complex number to plot.
        marker: Marker style passed to ``ax.plot``.
        fig: Existing figure to plot on; forwarded to ``plt_utils.init_plot``.
        ax: Existing axes to plot on; forwarded to ``plt_utils.init_plot``.
        figsize: Size of the figure; forwarded to ``plt_utils.init_plot``.
        fontsize: Font size for text elements.

    Returns:
        A tuple containing the Figure and Axes objects.
    """
    fig, ax = plt_utils.init_plot(
        fig=fig, ax=ax, figsize=figsize, fontsize=fontsize, projection="polar"
    )

    # Polar coordinates of z: angle in radians and radius.
    angle, radius = np.angle(z), np.abs(z)

    ax.plot([0, angle], [0, radius], marker=marker)
    return fig, ax

flyvision.analysis.visualization.plots.plot_complex_vector

plot_complex_vector(
    z0,
    z1,
    marker="s",
    fig=None,
    ax=None,
    figsize=(1, 1),
    fontsize=5,
)

Plot a vector between two complex numbers on a polar plot.

Parameters:

Name Type Description Default
z0 complex

Starting complex number.

required
z1 complex

Ending complex number.

required
marker str

Marker style for the points.

's'
fig Optional[Figure]

Existing figure to plot on.

None
ax Optional[Axes]

Existing axes to plot on.

None
figsize Tuple[float, float]

Size of the figure.

(1, 1)
fontsize int

Font size for text elements.

5

Returns:

Type Description
Tuple[Figure, Axes]

A tuple containing the Figure and Axes objects.

Source code in flyvision/analysis/visualization/plots.py
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
def plot_complex_vector(
    z0: complex,
    z1: complex,
    marker: str = "s",
    fig: Optional[Figure] = None,
    ax: Optional[Axes] = None,
    figsize: Tuple[float, float] = (1, 1),
    fontsize: int = 5,
) -> Tuple[Figure, Axes]:
    """
    Draw the segment joining two complex numbers on a polar axis.

    Both endpoints are converted to polar coordinates (angle in radians via
    ``np.angle``, magnitude via ``np.abs``). They are connected by one line
    with ``marker`` drawn at each end.

    Args:
        z0: Starting complex number.
        z1: Ending complex number.
        marker: Marker style passed to ``ax.plot``.
        fig: Existing figure to plot on; forwarded to ``plt_utils.init_plot``.
        ax: Existing axes to plot on; forwarded to ``plt_utils.init_plot``.
        figsize: Size of the figure; forwarded to ``plt_utils.init_plot``.
        fontsize: Font size for text elements.

    Returns:
        A tuple containing the Figure and Axes objects.
    """
    fig, ax = plt_utils.init_plot(
        fig=fig, ax=ax, figsize=figsize, fontsize=fontsize, projection="polar"
    )

    endpoints = (z0, z1)
    angles = [np.angle(point) for point in endpoints]
    radii = [np.abs(point) for point in endpoints]

    ax.plot(angles, radii, marker=marker)
    return fig, ax

flyvision.analysis.visualization.plots.polar

polar(
    theta,
    r,
    ax=None,
    fig=None,
    color="b",
    linestyle="-",
    marker="",
    markersize=None,
    label=None,
    title="",
    figsize=(5, 5),
    fontsize=10,
    xlabel="",
    fontweight="normal",
    anglepad=-2,
    xlabelpad=-3,
    linewidth=2,
    ymin=None,
    ymax=None,
    stroke_kwargs={},
    yticks_off=True,
    zorder=100,
    **kwargs
)

Create a polar tuning plot.

Parameters:

Name Type Description Default
theta NDArray

Array of angles in degrees.

required
r NDArray

Array of radii.

required
ax Optional[Axes]

Matplotlib Axes object.

None
fig Optional[Figure]

Matplotlib Figure object.

None
color Union[str, List[str]]

Color(s) for the plot.

'b'
linestyle str

Line style for the plot.

'-'
marker str

Marker style for data points.

''
markersize Optional[float]

Size of markers.

None
label Optional[str]

Label for the plot.

None
title str

Title of the plot.

''
figsize Tuple[float, float]

Size of the figure.

(5, 5)
fontsize int

Font size for text elements.

10
xlabel str

X-axis label.

''
fontweight Literal['normal', 'bold', 'light', 'ultralight', 'heavy', 'black', 'semibold']

Font weight for labels.

'normal'
anglepad int

Padding for angle labels.

-2
xlabelpad int

Padding for x-axis label.

-3
linewidth float

Width of the plot lines.

2
ymin Optional[float]

Minimum y-axis value.

None
ymax Optional[float]

Maximum y-axis value.

None
stroke_kwargs dict

Keyword arguments for stroke effects.

{}
yticks_off bool

Whether to turn off y-axis ticks.

True
zorder Union[int, List[int]]

Z-order for plot elements.

100
**kwargs

Additional keyword arguments.

{}

Returns:

Type Description
Tuple[Figure, Axes]

A tuple containing the Figure and Axes objects.

Note

This function creates a polar plot with various customization options. It supports multiple traces and custom styling.

Source code in flyvision/analysis/visualization/plots.py
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
def polar(
    theta: NDArray,
    r: NDArray,
    ax: Optional[Axes] = None,
    fig: Optional[Figure] = None,
    color: Union[str, List[str]] = "b",
    linestyle: str = "-",
    marker: str = "",
    markersize: Optional[float] = None,
    label: Optional[str] = None,
    title: str = "",
    figsize: Tuple[float, float] = (5, 5),
    fontsize: int = 10,
    xlabel: str = "",
    fontweight: Literal[
        "normal", "bold", "light", "ultralight", "heavy", "black", "semibold"
    ] = "normal",
    anglepad: int = -2,
    xlabelpad: int = -3,
    linewidth: float = 2,
    ymin: Optional[float] = None,
    ymax: Optional[float] = None,
    stroke_kwargs: dict = {},
    yticks_off: bool = True,
    zorder: Union[int, List[int]] = 100,
    **kwargs,
) -> Tuple[Figure, Axes]:
    """
    Create a polar tuning plot.

    Angles are converted from degrees to radians. The curve is closed by
    appending the first angle and the first radius entry, unless the last
    angle (mod 360) already equals the first.

    Args:
        theta: Array of angles in degrees (array-like; converted with
            ``np.asarray``).
        r: Array of radii. Either 1-D (one value per angle) or 2-D. In the
            2-D case the first axis runs over angles and each column is
            plotted as a separate trace.
        ax: Matplotlib Axes object.
        fig: Matplotlib Figure object.
        color: Color(s) for the plot. For 2-D ``r`` this may be a sequence
            with one entry per column. A value that is itself a single valid
            color (e.g. ``"red"``, ``"#ff0000"``, an RGB tuple) is shared by
            all traces rather than indexed.
        linestyle: Line style for the plot.
        marker: Marker style for data points.
        markersize: Size of markers.
        label: Label for the plot (applied to every trace).
        title: Title of the plot.
        figsize: Size of the figure.
        fontsize: Font size for text elements.
        xlabel: X-axis label.
        fontweight: Font weight for the x-axis label.
        anglepad: Padding for angle tick labels.
        xlabelpad: Padding for x-axis label.
        linewidth: Width of the plot lines.
        ymin: Minimum y-axis value (only applied together with ``ymax``).
        ymax: Maximum y-axis value (only applied together with ``ymin``).
        stroke_kwargs: Keyword arguments for ``path_effects.Stroke``. When
            non-empty, a stroke effect is added to each line. The dict is
            only read, never mutated.
        yticks_off: Whether to remove the radial (y) ticks and tick labels.
        zorder: Z-order for plot elements. Passed through
            ``plt_utils.extend_arg`` and indexed per trace for 2-D ``r``.
        **kwargs: Accepted for call compatibility; currently unused.

    Returns:
        A tuple containing the Figure and Axes objects.

    Note:
        A warning is logged when ``sum(theta) < 100``, since the angles then
        look like radians. Only the 0°, 45° and 90° angle ticks are labelled.
    """
    from matplotlib.colors import is_color_like

    fig, ax = plt_utils.init_plot(
        figsize=figsize,
        title=title,
        fontsize=fontsize,
        ax=ax,
        fig=fig,
        projection="polar",
    )

    # Accept lists as well as arrays; `theta * np.pi` below needs an array.
    theta = np.asarray(theta)

    if sum(theta) < 100:
        logging.warning("Using radians instead of degrees?")

    closed = theta[-1] % 360 == theta[0]
    theta = theta * np.pi / 180
    if not closed:
        theta = np.append(theta, theta[0])

    r = np.asarray(r)
    if not closed:
        # Repeat the first radius entry (row) to close the curve.
        r = np.append(r, np.expand_dims(r[0], 0), axis=0)

    line_effects = None
    if stroke_kwargs:
        line_effects = [
            path_effects.Stroke(**stroke_kwargs),
            path_effects.Normal(),
        ]

    zorder = plt_utils.extend_arg(zorder, int, r, default=0, dim=-1)

    if r.ndim == 2:
        n_traces = r.shape[1]
        # A value that already is one color must not be indexed. Otherwise
        # e.g. "red" with 3 traces would be split into "r", "e", "d".
        # Per-character specs such as "rgb" are not color-like and are still
        # indexed per trace.
        single_color = (
            isinstance(color, str) and color.startswith("#")
        ) or is_color_like(color)
        for i, _r in enumerate(r.T):
            if (
                not single_color
                and isinstance(color, Iterable)
                and len(color) == n_traces
            ):
                _color = color[i]
            else:
                _color = color

            ax.plot(
                theta,
                _r,
                linewidth=linewidth,
                color=_color,
                linestyle=linestyle,
                marker=marker,
                label=label,
                path_effects=line_effects,
                zorder=zorder[i],
                markersize=markersize,
            )
    elif r.ndim == 1:
        ax.plot(
            theta,
            r,
            linewidth=linewidth,
            color=color,
            linestyle=linestyle,
            marker=marker,
            label=label,
            path_effects=line_effects,
            zorder=zorder,
            markersize=markersize,
        )

    ax.tick_params(axis="both", which="major", labelsize=fontsize, pad=anglepad)
    if yticks_off:
        ax.set_yticks([])
        ax.set_yticklabels([])
    ax.set_xticks([
        0,
        np.pi / 4,
        np.pi / 2,
        3 / 4 * np.pi,
        np.pi,
        5 / 4 * np.pi,
        3 / 2 * np.pi,
        7 / 4 * np.pi,
    ])
    ax.set_xticklabels(["0°", "45°", "90°", "", "", "", "", ""])

    ax.set_xlabel(xlabel, fontsize=fontsize, labelpad=xlabelpad, fontweight=fontweight)
    if all((val is not None for val in (ymin, ymax))):
        ax.set_ylim((ymin, ymax))
    plt.setp(ax.spines.values(), color="grey", linewidth=1)
    return fig, ax

flyvision.analysis.visualization.plots.multi_polar

multi_polar(
    theta,
    r,
    ax=None,
    fig=None,
    mean_color="b",
    norm=False,
    std=False,
    color="b",
    mean=False,
    linestyle="-",
    marker="",
    label="",
    legend=False,
    title="",
    figsize=(0.98, 2.38),
    fontsize=5,
    xlabel="",
    fontweight="bold",
    alpha=1,
    anglepad=-6,
    xlabelpad=-3,
    linewidth=0.75,
    ymin=None,
    ymax=None,
    zorder=None,
    legend_kwargs=dict(fontsize=5),
    rm_yticks=True,
    **kwargs
)

Create a polar tuning plot.

Parameters:

Name Type Description Default
theta ndarray

Angles in degrees.

required
r ndarray

Radius values. Shape (n_samples, n_values).

required
ax Optional[Axes]

Existing Axes object to plot on. Defaults to None.

None
fig Optional[Figure]

Existing Figure object to use. Defaults to None.

None
mean_color str

Color for the mean line. Defaults to “b”.

'b'
norm bool

Whether to normalize the radius values. Defaults to False.

False
std bool

Whether to plot standard deviation. Defaults to False.

False
color Union[str, List[str], ndarray]

Color(s) for the lines. Defaults to “b”.

'b'
mean bool

Whether to plot the mean. Defaults to False.

False
linestyle str

Style of the lines. Defaults to “-”.

'-'
marker str

Marker style for data points. Defaults to “”.

''
label Union[str, List[str]]

Label(s) for the lines. Defaults to “”.

''
legend bool

Whether to show a legend. Defaults to False.

False
title str

Title of the plot. Defaults to “”.

''
figsize Tuple[float, float]

Size of the figure. Defaults to (0.98, 2.38).

(0.98, 2.38)
fontsize int

Font size for text elements. Defaults to 5.

5
xlabel str

Label for the x-axis. Defaults to “”.

''
fontweight str

Font weight for labels. Defaults to “bold”.

'bold'
alpha float

Alpha value for line transparency. Defaults to 1.

1
anglepad int

Padding for angle labels. Defaults to -6.

-6
xlabelpad int

Padding for x-axis label. Defaults to -3.

-3
linewidth float

Width of the lines. Defaults to 0.75.

0.75
ymin Optional[float]

Minimum y-axis value. Defaults to None.

None
ymax Optional[float]

Maximum y-axis value. Defaults to None.

None
zorder Optional[Union[int, List[int], ndarray]]

Z-order for drawing. Defaults to None.

None
legend_kwargs Dict[str, Any]

Additional keyword arguments for legend. Defaults to dict(fontsize=5).

dict(fontsize=5)
rm_yticks bool

Whether to remove y-axis ticks. Defaults to True.

True
**kwargs Any

Additional keyword arguments.

{}

Returns:

Type Description
Tuple[Figure, Axes]

A tuple containing the Figure and Axes objects.

Note

This function creates a polar plot with multiple traces, optionally showing mean and standard deviation.

Source code in flyvision/analysis/visualization/plots.py
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
def multi_polar(
    theta: np.ndarray,
    r: np.ndarray,
    ax: Optional[Axes] = None,
    fig: Optional[Figure] = None,
    mean_color: str = "b",
    norm: bool = False,
    std: bool = False,
    color: Union[str, List[str], np.ndarray] = "b",
    mean: bool = False,
    linestyle: str = "-",
    marker: str = "",
    label: Union[str, List[str]] = "",
    legend: bool = False,
    title: str = "",
    figsize: Tuple[float, float] = (0.98, 2.38),
    fontsize: int = 5,
    xlabel: str = "",
    fontweight: str = "bold",
    alpha: float = 1,
    anglepad: int = -6,
    xlabelpad: int = -3,
    linewidth: float = 0.75,
    ymin: Optional[float] = None,
    ymax: Optional[float] = None,
    zorder: Optional[Union[int, List[int], np.ndarray]] = None,
    legend_kwargs: Dict[str, Any] = dict(fontsize=5),
    rm_yticks: bool = True,
    **kwargs: Any,
) -> Tuple[Figure, Axes]:
    """
    Create a polar tuning plot.

    Each row of ``r`` is drawn as one trace. Angles are converted from degrees
    to radians. The curves are closed by appending the first angle and the
    first column of ``r``, unless the last angle (mod 360) already equals the
    first.

    Args:
        theta: Angles in degrees (array-like; converted with ``np.asarray``).
        r: Radius values. Shape (n_samples, n_values); 1-D input is treated
            as a single trace.
        ax: Existing Axes object to plot on. Defaults to None.
        fig: Existing Figure object to use. Defaults to None.
        mean_color: Color for the mean line. Defaults to "b".
        norm: Whether to divide each trace by its maximum. Defaults to False.
        std: Whether to shade mean ± standard deviation across traces.
            Defaults to False.
        color: A single color for all traces, or a list/array with one color
            per trace. Defaults to "b".
        mean: Whether to plot the mean across traces. Defaults to False.
        linestyle: Style of the lines. Defaults to "-".
        marker: Marker style for data points. Defaults to "".
        label: A single label for all traces, or a list/array with one label
            per trace. Defaults to "".
        legend: Whether to show a legend. Defaults to False.
        title: Title of the plot. Defaults to "".
        figsize: Size of the figure. Defaults to (0.98, 2.38).
        fontsize: Font size for text elements. Defaults to 5.
        xlabel: Label for the x-axis. Defaults to "".
        fontweight: Font weight for the x-axis label. Defaults to "bold".
        alpha: Alpha value for line transparency. Defaults to 1.
        anglepad: Padding for angle labels. Defaults to -6.
        xlabelpad: Padding for x-axis label. Defaults to -3.
        linewidth: Width of the lines. Defaults to 0.75.
        ymin: Minimum y-axis value (only applied together with ``ymax``).
            Defaults to None.
        ymax: Maximum y-axis value (only applied together with ``ymin``).
            Defaults to None.
        zorder: Z-order for the traces. A scalar applies to every trace, a
            sequence gives one value per trace, and None uses 100 for every
            trace. Defaults to None.
        legend_kwargs: Additional keyword arguments for legend. Only read,
            never mutated. Defaults to dict(fontsize=5).
        rm_yticks: Whether to remove y-axis ticks. Defaults to True.
        **kwargs: Accepted for call compatibility; currently unused.

    Returns:
        A tuple containing the Figure and Axes objects.

    Note:
        This function creates a polar plot with multiple traces, optionally
        showing mean and standard deviation.
    """
    fig, ax = plt_utils.init_plot(
        figsize=figsize,
        title=title,
        fontsize=fontsize,
        ax=ax,
        fig=fig,
        projection="polar",
    )
    r = np.atleast_2d(r)
    n_traces = r.shape[0]

    if norm:
        # Divide each trace by its maximum; the epsilon avoids division by 0.
        r = r / (r.max(axis=1, keepdims=True) + 1e-15)

    # Accept lists as well as arrays; `theta * np.pi` below needs an array.
    theta = np.asarray(theta)
    closed = theta[-1] % 360 == theta[0]
    theta = theta * np.pi / 180
    if not closed:
        theta = np.append(theta, theta[0])
        # Repeat each trace's first value to close the curves.
        r = np.append(r, np.expand_dims(r[:, 0], 1), axis=1)

    # Broadcast single values to one entry per trace.
    color = [color] * n_traces if not isinstance(color, (list, np.ndarray)) else color
    label = [label] * n_traces if not isinstance(label, (list, np.ndarray)) else label
    if zorder is None:
        zorder = [100] * n_traces
    elif np.ndim(zorder) == 0:
        # An explicit scalar zorder must be honoured for every trace rather
        # than silently replaced by the default.
        zorder = [zorder] * n_traces

    for i, _r in enumerate(r):
        ax.plot(
            theta,
            _r,
            linewidth=linewidth,
            color=color[i],
            linestyle=linestyle,
            marker=marker,
            label=label[i],
            zorder=zorder[i],
            alpha=alpha,
        )

    if mean:
        ax.plot(
            theta,
            r.mean(0),
            linewidth=linewidth,
            color=mean_color,
            linestyle=linestyle,
            marker=marker,
            label="average",
            alpha=alpha,
        )

    if std:
        ax.fill_between(
            theta,
            r.mean(0) - r.std(0),
            r.mean(0) + r.std(0),
            color="0.8",
            alpha=0.5,
            zorder=-1,
        )

    ax.tick_params(axis="both", which="major", labelsize=fontsize, pad=anglepad)
    if rm_yticks:
        ax.set_yticks([])
        ax.set_yticklabels([])
    ax.set_xticks([
        0,
        np.pi / 4,
        np.pi / 2,
        3 / 4 * np.pi,
        np.pi,
        5 / 4 * np.pi,
        3 / 2 * np.pi,
        7 / 4 * np.pi,
    ])
    ax.set_xticklabels(["0°", "45°", "90°", "", "", "", "", ""])

    ax.set_xlabel(xlabel, fontsize=fontsize, labelpad=xlabelpad, fontweight=fontweight)
    if all((val is not None for val in (ymin, ymax))):
        ax.set_ylim((ymin, ymax))
    plt.setp(ax.spines.values(), color="grey", linewidth=0.5)

    if legend:
        ax.legend(**legend_kwargs)

    return fig, ax

flyvision.analysis.visualization.plots.loss_curves

loss_curves(
    losses,
    smooth=0.05,
    subsample=1,
    mean=False,
    grid=True,
    colors=None,
    cbar=False,
    cmap=None,
    norm=None,
    fig=None,
    ax=None,
    xlabel=None,
    ylabel=None,
)

Plot loss traces.

Parameters:

Name Type Description Default
losses List[ndarray]

List of loss arrays, each of shape (n_iters,).

required
smooth float

Smoothing factor for the loss curves.

0.05
subsample int

Subsample factor for the loss curves.

1
mean bool

Whether to plot the mean loss curve.

False
grid bool

Whether to show grid lines.

True
colors Optional[List[str]]

List of colors for the loss curves.

None
cbar bool

Whether to add a colorbar.

False
cmap Optional[Colormap]

Colormap for the loss curves.

None
norm Optional[Normalize]

Normalization for the colormap.

None
fig Optional[Figure]

Existing figure to plot on.

None
ax Optional[Axes]

Existing axes to plot on.

None
xlabel Optional[str]

Label for the x-axis.

None
ylabel Optional[str]

Label for the y-axis.

None

Returns:

Type Description
Tuple[Figure, Axes]

A tuple containing the Figure and Axes objects.

Note

This function plots loss curves for multiple models, with options for smoothing, subsampling, and various visual customizations.

Source code in flyvision/analysis/visualization/plots.py
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
def loss_curves(
    losses: List[np.ndarray],
    smooth: float = 0.05,
    subsample: int = 1,
    mean: bool = False,
    grid: bool = True,
    colors: Optional[List[str]] = None,
    cbar: bool = False,
    cmap: Optional[mpl.colors.Colormap] = None,
    norm: Optional[mpl.colors.Normalize] = None,
    fig: Optional[Figure] = None,
    ax: Optional[Axes] = None,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
) -> Tuple[Figure, Axes]:
    """Plot loss traces.

    Args:
        losses: List of loss arrays, each of shape (n_iters,). Runs may have
            different lengths; shorter runs are padded with NaN at the end.
        smooth: Smoothing factor for the loss curves.
        subsample: Keep every ``subsample``-th iteration. The x-axis is
            scaled back to the original iteration index.
        mean: Whether to plot the mean loss curve.
        grid: Whether to show grid lines.
        colors: List of colors for the loss curves, one per run.
        cbar: Whether to add a colorbar (only if ``cmap`` and ``norm`` are
            both given).
        cmap: Colormap for the loss curves.
        norm: Normalization for the colormap.
        fig: Existing figure to plot on.
        ax: Existing axes to plot on.
        xlabel: Label for the x-axis.
        ylabel: Label for the y-axis.

    Returns:
        A tuple containing the Figure and Axes objects.

    Note:
        This function plots loss curves for multiple models, with options for
        smoothing, subsampling, and various visual customizations.
    """
    # Keep the subsampled runs as a list. Runs may differ in length, and
    # np.array() on ragged sequences raises ValueError in NumPy >= 1.24.
    subsampled = [np.asarray(loss)[::subsample] for loss in losses]

    max_n_iters = max(len(loss) for loss in subsampled)

    # Right-pad shorter runs with NaN so all curves share one x-axis.
    _losses = np.full((len(subsampled), max_n_iters), np.nan)
    for i, loss in enumerate(subsampled):
        _losses[i, : len(loss)] = loss

    # Runs are drawn in reverse order, so the colors are reversed to match.
    fig, ax, _, _ = traces(
        _losses[::-1],
        x=np.arange(max_n_iters) * subsample,
        fontsize=5,
        figsize=[1.2, 1],
        smooth=smooth,
        fig=fig,
        ax=ax,
        color=colors[::-1] if colors is not None else None,
        linewidth=0.5,
        highlight_mean=mean,
    )

    ax.set_ylabel(ylabel, fontsize=5)
    ax.set_xlabel(xlabel, fontsize=5)

    if cbar and cmap is not None and norm is not None:
        plt_utils.add_colorbar_to_fig(
            fig,
            cmap=cmap,
            norm=norm,
            label="min task error",
            fontsize=5,
            tick_length=1,
            tick_width=0.5,
            x_offset=2,
            y_offset=0.25,
        )

    if grid:
        ax.yaxis.set_major_locator(MaxNLocator(nbins=10))
        ax.grid(True, linewidth=0.5)

    return fig, ax

flyvision.analysis.visualization.plots.histogram

histogram(
    array,
    bins=None,
    fill=False,
    histtype="step",
    figsize=(1, 1),
    fontsize=5,
    fig=None,
    ax=None,
    xlabel=None,
    ylabel=None,
)

Create a histogram plot.

Parameters:

Name Type Description Default
array ndarray

Input data to plot.

required
bins Optional[Union[int, Sequence, str]]

Number of bins or bin edges. Defaults to len(array).

None
fill bool

Whether to fill the bars. Defaults to False.

False
histtype Literal['bar', 'barstacked', 'step', 'stepfilled']

Type of histogram to plot. Defaults to “step”.

'step'
figsize Tuple[float, float]

Size of the figure. Defaults to (1, 1).

(1, 1)
fontsize int

Font size for labels. Defaults to 5.

5
fig Optional[Figure]

Existing figure to plot on. Defaults to None.

None
ax Optional[Axes]

Existing axes to plot on. Defaults to None.

None
xlabel Optional[str]

Label for x-axis. Defaults to None.

None
ylabel Optional[str]

Label for y-axis. Defaults to None.

None

Returns:

Type Description
Tuple[Figure, Axes]

A tuple containing the Figure and Axes objects.

Source code in flyvision/analysis/visualization/plots.py
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
def histogram(
    array: np.ndarray,
    bins: Optional[Union[int, Sequence, str]] = None,
    fill: bool = False,
    histtype: Literal["bar", "barstacked", "step", "stepfilled"] = "step",
    figsize: Tuple[float, float] = (1, 1),
    fontsize: int = 5,
    fig: Optional[Figure] = None,
    ax: Optional[Axes] = None,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
) -> Tuple[Figure, Axes]:
    """
    Draw a histogram of ``array`` with thin (0.5) outlines.

    Args:
        array: Input data to plot.
        bins: Number of bins, bin edges, or a binning strategy name. When
            None, ``len(array)`` bins are used.
        fill: Whether to fill the bars. Defaults to False.
        histtype: Matplotlib histogram type. Defaults to "step".
        figsize: Size of the figure. Defaults to (1, 1).
        fontsize: Font size for the axis labels. Defaults to 5.
        fig: Existing figure to plot on. Defaults to None.
        ax: Existing axes to plot on. Defaults to None.
        xlabel: Label for x-axis. Defaults to None.
        ylabel: Label for y-axis. Defaults to None.

    Returns:
        A tuple containing the Figure and Axes objects.
    """
    fig, ax = plt_utils.init_plot(figsize=figsize, fontsize=fontsize, fig=fig, ax=ax)

    # Fall back to one bin per element when no binning was requested.
    n_bins = len(array) if bins is None else bins
    ax.hist(array, bins=n_bins, linewidth=0.5, fill=fill, histtype=histtype)

    for set_label, text in ((ax.set_xlabel, xlabel), (ax.set_ylabel, ylabel)):
        set_label(text, fontsize=fontsize)
    return fig, ax

flyvision.analysis.visualization.plots.violins

violins(
    variable_names,
    variable_values,
    ylabel=None,
    title=None,
    max_per_ax=20,
    colors=None,
    cmap=plt.cm.viridis_r,
    fontsize=5,
    violin_width=0.7,
    legend=None,
    scatter_extent=[-0.35, 0.35],
    figwidth=10,
    fig=None,
    axes=None,
    ylabel_offset=0.2,
    **kwargs
)

Create violin plots for multiple variables across groups.

Parameters:

Name Type Description Default
variable_names List[str]

Names of the variables to plot.

required
variable_values ndarray

Array of values for each variable and group.

required
ylabel Optional[str]

Label for the y-axis.

None
title Optional[str]

Title of the plot.

None
max_per_ax Optional[int]

Maximum number of variables per axis.

20
colors Optional[Union[str, List[str]]]

Colors for the violin plots.

None
cmap cm

Colormap to use if colors are not specified.

viridis_r
fontsize int

Font size for labels and ticks.

5
violin_width float

Width of each violin plot.

0.7
legend Optional[Union[str, List[str]]]

Legend labels for groups.

None
scatter_extent List[float]

Extent of scatter points on violins.

[-0.35, 0.35]
figwidth float

Width of the figure.

10
fig Optional[Figure]

Existing figure to plot on.

None
axes Optional[List[Axes]]

Existing axes to plot on.

None
ylabel_offset float

Offset for y-axis label.

0.2
**kwargs Any

Additional keyword arguments for violin_groups function.

{}

Returns:

Type Description
Tuple[Figure, List[Axes]]

A tuple containing the Figure and list of Axes objects.

Note

This function creates violin plots for multiple variables, potentially across multiple groups, with optional scatter points on each violin.

Source code in flyvision/analysis/visualization/plots.py
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
def violins(
    variable_names: List[str],
    variable_values: np.ndarray,
    ylabel: Optional[str] = None,
    title: Optional[str] = None,
    max_per_ax: Optional[int] = 20,
    colors: Optional[Union[str, List[str]]] = None,
    cmap: plt.cm = plt.cm.viridis_r,
    fontsize: int = 5,
    violin_width: float = 0.7,
    legend: Optional[Union[str, List[str]]] = None,
    scatter_extent: List[float] = [-0.35, 0.35],
    figwidth: float = 10,
    fig: Optional[Figure] = None,
    axes: Optional[List[Axes]] = None,
    ylabel_offset: float = 0.2,
    **kwargs: Any,
) -> Tuple[Figure, List[Axes]]:
    """
    Create violin plots for multiple variables across groups.

    Variables are distributed over one or more vertically stacked axes. Each
    violin is overlaid with its individual samples as open scatter points.

    Args:
        variable_names: Names of the variables to plot.
        variable_values: Array of shape (n_samples, n_groups, n_variables) or
            (n_samples, n_variables); the 2-D form is treated as one group.
        ylabel: Label for the y-axis, drawn once and vertically centered on
            the stack of axes.
        title: Title of the plot, placed on the first axis.
        max_per_ax: Maximum number of variables per axis. ``None`` puts all
            variables on one axis. Variables that would not fill a whole extra
            axis are spread evenly over the existing axes, so an axis may hold
            slightly more than this number.
        colors: Colors for the violin plots.
        cmap: Colormap to use if colors are not specified.
        fontsize: Font size for labels and ticks.
        violin_width: Width of each violin plot.
        legend: Legend labels for groups; drawn on the first axis only.
        scatter_extent: Extent of scatter points on violins.
        figwidth: Width of the figure.
        fig: Existing figure to plot on.
        axes: Existing axes to plot on.
        ylabel_offset: Offset for y-axis label, as a fraction of the left
            edge of the leftmost axis.
        **kwargs: Additional keyword arguments for violin_groups function.

    Returns:
        A tuple containing the Figure and list of Axes objects.

    Raises:
        ValueError: If there are no variables or ``max_per_ax`` is below 1.
    """
    # (samples, [groups,] variables) -> (variables, [groups,] samples)
    variable_values = variable_values.T
    if variable_values.ndim == 2:
        # A single group: insert a singleton group axis.
        variable_values = variable_values[:, None]

    n_variables, n_groups, _ = variable_values.shape
    if max_per_ax is None:
        max_per_ax = n_variables
    max_per_ax = min(max_per_ax, n_variables)
    if max_per_ax < 1:
        # Would otherwise surface as an opaque ZeroDivisionError below.
        raise ValueError(
            "violins requires at least one variable and max_per_ax >= 1."
        )
    n_axes = n_variables // max_per_ax
    # Distribute the leftover variables over the existing axes rather than
    # creating a sparsely filled extra axis.
    max_per_ax += int(np.ceil((n_variables % max_per_ax) / n_axes))

    fig, axes, _ = plt_utils.get_axis_grid(
        gridheight=n_axes,
        gridwidth=1,
        figsize=[figwidth, n_axes * 1.2],
        hspace=1,
        alpha=0,
        fig=fig,
        axes=axes,
    )

    for i in range(n_axes):
        ax_values = variable_values[i * max_per_ax : (i + 1) * max_per_ax]
        ax_names = variable_names[i * max_per_ax : (i + 1) * max_per_ax]

        fig, ax, _ = violin_groups(
            ax_values,
            ax_names,
            rotation=90,
            scatter=False,
            fontsize=fontsize,
            width=violin_width,
            scatter_edge_color="white",
            scatter_radius=5,
            scatter_edge_width=0.25,
            cdist=100,
            colors=colors,
            cmap=cmap,
            showmedians=True,
            showmeans=False,
            violin_marker_lw=0.25,
            # Only the first (top) axis gets the group legend.
            legend=legend if i == 0 else None,
            legend_kwargs=dict(
                fontsize=5,
                markerscale=10,
                loc="lower left",
                bbox_to_anchor=(0.75, 0.75),
            ),
            fig=fig,
            ax=axes[i],
            **kwargs,
        )

        violin_locations, _ = get_violin_x_locations(
            n_groups, len(ax_names), violin_width
        )

        # Overlay the individual samples of each group on its violins.
        for group in range(n_groups):
            plt_utils.scatter_on_violins_or_bars(
                ax_values[:, group].T,
                ax,
                xticks=violin_locations[group],
                facecolor="none",
                edgecolor="k",
                zorder=100,
                alpha=0.35,
                uniform=scatter_extent,
                marker="o",
                linewidth=0.5,
            )

        ax.grid(False)

        plt_utils.trim_axis(ax, yaxis=False)
        plt_utils.set_spine_tick_params(
            ax,
            tickwidth=0.5,
            ticklength=3,
            ticklabelpad=2,
            spinewidth=0.5,
        )

    lefts, bottoms, rights, tops = np.array(
        [a.get_position().extents for a in axes]
    ).T
    fig.text(
        lefts.min() - ylabel_offset * lefts.min(),
        # Vertical midpoint of the whole stack of axes (figure coordinates).
        (tops.max() + bottoms.min()) / 2,
        ylabel,
        rotation=90,
        fontsize=fontsize,
        ha="right",
        va="center",
    )

    axes[0].set_title(title, y=0.91, fontsize=fontsize)

    return fig, axes

flyvision.analysis.visualization.plots.plot_strf

plot_strf(
    time,
    rf,
    hlines=True,
    vlines=True,
    time_axis=True,
    fontsize=6,
    fig=None,
    axes=None,
    figsize=[5, 1],
    wspace=0,
    y_offset_time_axis=0,
)

Plot a Spatio-Temporal Receptive Field (STRF).

Parameters:

Name Type Description Default
time ndarray

Array of time points.

required
rf ndarray

Receptive field array.

required
hlines bool

Whether to draw horizontal lines. Defaults to True.

True
vlines bool

Whether to draw vertical lines. Defaults to True.

True
time_axis bool

Whether to add a time axis. Defaults to True.

True
fontsize int

Font size for labels and ticks.

6
fig Optional[Figure]

Existing figure to plot on.

None
axes Optional[ndarray]

Existing axes to plot on.

None
figsize List[float]

Size of the figure as [width, height].

[5, 1]
wspace float

Width space between subplots.

0
y_offset_time_axis float

Vertical offset for the time axis.

0

Returns:

Type Description
Tuple[Figure, ndarray]

A tuple containing the Figure and Axes objects.

Note

This function creates a series of hexagonal plots representing the STRF at different time points.

Source code in flyvision/analysis/visualization/plots.py
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
def plot_strf(
    time: np.ndarray,
    rf: np.ndarray,
    hlines: bool = True,
    vlines: bool = True,
    time_axis: bool = True,
    fontsize: int = 6,
    fig: Optional[Figure] = None,
    axes: Optional[np.ndarray] = None,
    figsize: List[float] = [5, 1],
    wspace: float = 0,
    y_offset_time_axis: float = 0,
) -> Tuple[Figure, np.ndarray]:
    """
    Plot a Spatio-Temporal Receptive Field (STRF).

    The receptive field is drawn as a row of hexagonal snapshots at
    t = 0, 0.02, ..., 0.18 (in the units of ``time``; the time axis labels them
    as 0-180 ms, so ``time`` is expected in seconds).

    Args:
        time: Array of time points. Entries matching a snapshot time (within
            1e-15) select the corresponding entries of ``rf``.
        rf: Receptive field array, boolean-masked with a mask computed from
            ``time``; its last axis holds the hexagonal field values.
        hlines: Whether to draw horizontal lines. Defaults to True.
        vlines: Whether to draw vertical lines. Defaults to True.
        time_axis: Whether to add a time axis. Defaults to True.
        fontsize: Font size for labels and ticks.
        fig: Existing figure to plot on.
        axes: Existing axes to plot on, indexed as ``axes[0, i]``. Both ``fig``
            and ``axes`` must be given, otherwise a new grid is created.
        figsize: Size of the figure as [width, height].
        wspace: Width space between subplots.
        y_offset_time_axis: Vertical offset for the time axis, as a fraction
            of the lowest axis bottom.

    Returns:
        A tuple containing the Figure and Axes objects.
    """
    max_extent = hex_utils.get_hextent(rf.shape[-1])
    # Ten snapshot times: 0, 0.02, ..., 0.18.
    t_steps = np.arange(0.0, 0.2, 0.01)[::2]
    n_frames = len(t_steps)

    u, v = hex_utils.get_hex_coords(max_extent)
    x, y = hex_utils.hex_to_pixel(u, v)
    xmin, xmax = x.min(), x.max()
    ymin, ymax = y.min(), y.max()
    elev = 0
    azim = 0

    if fig is None or axes is None:
        fig, axes = plt_utils.divide_figure_to_grid(
            np.arange(n_frames).reshape(1, n_frames),
            wspace=wspace,
            as_matrix=True,
            figsize=figsize,
        )

    # Symmetric color range so that zero sits at the colormap midpoint.
    crange = np.abs(rf).max()

    for i, t in enumerate(t_steps):
        # Tolerance-based match rather than exact float equality.
        mask = np.abs(time - t) <= 1e-15
        _rf = rf[mask]
        quick_hex_scatter(
            _rf,
            cmap=plt.cm.coolwarm,
            edgecolor=None,
            vmin=-crange,
            vmax=crange,
            midpoint=0,
            cbar=False,
            max_extent=max_extent,
            fig=fig,
            ax=axes[0, i],
            fill=True,
            fontsize=fontsize,
        )

        if hlines:
            axes[0, i].hlines(elev, xmin, xmax, color="grey", linewidth=0.25)
        if vlines:
            axes[0, i].vlines(azim, ymin, ymax, color="grey", linewidth=0.25)

    if time_axis:
        # Data origin of the first and last snapshot in figure coordinates.
        left = fig.transFigure.inverted().transform(
            axes[0, 0].transData.transform((0, 0))
        )[0]
        right = fig.transFigure.inverted().transform(
            axes[0, -1].transData.transform((0, 0))
        )[0]

        _, bottoms, _, _ = np.array([
            ax.get_position().extents for ax in axes.flatten()
        ]).T
        # Separate name so the boolean ``time_axis`` parameter is not rebound.
        time_ax = fig.add_axes((
            left,
            bottoms.min() + y_offset_time_axis * bottoms.min(),
            right - left,
            0.01,
        ))
        plt_utils.rm_spines(
            time_ax,
            ("left", "top", "right"),
            rm_yticks=True,
            rm_xticks=False,
        )

        # One tick under the data origin of every snapshot axis.
        data_centers_in_points = np.array([
            ax.transData.transform((0, 0)) for ax in axes.flatten()
        ])
        time_ax.tick_params(axis="both", labelsize=fontsize)
        ticks = time_ax.transData.inverted().transform(data_centers_in_points)[:, 0]
        time_ax.set_xticks(ticks)
        # Labels derived from the snapshot times (0, 20, ..., 180 ms).
        time_ax.set_xticklabels(np.round(t_steps * 1000).astype(int))
        time_ax.set_xlabel("time (ms)", fontsize=fontsize, labelpad=2)
        plt_utils.set_spine_tick_params(
            time_ax,
            spinewidth=0.25,
            tickwidth=0.25,
            ticklength=3,
            ticklabelpad=2,
            spines=("top", "right", "bottom", "left"),
        )

    return fig, axes

Classes

flyvision.analysis.visualization.plots.ViolinData dataclass

Container for violin plot data.

Attributes:

Name Type Description
data ndarray

The data used for creating violin plots.

locations ndarray

The x-axis locations of the violin plots.

colors ndarray

The colors used for the violin plots.

Source code in flyvision/analysis/visualization/plots.py
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
@dataclass
class ViolinData:
    """
    Container for violin plot data.

    Attributes:
        data: The data used for creating violin plots.
        locations: The x-axis locations of the violin plots.
        colors: The colors used for the violin plots.
    """

    # Values the violins are drawn from.
    data: np.ndarray
    # x positions of the individual violins.
    locations: np.ndarray
    # Colors of the individual violins.
    colors: np.ndarray

flyvision.analysis.visualization.plt_utils

Functions

flyvision.analysis.visualization.plt_utils.check_markers

check_markers(N)

Check if the number of clusters is larger than the number of markers.

Parameters:

Name Type Description Default
N int

Number of clusters.

required

Returns:

Type Description
List[str]

List of markers.

Source code in flyvision/analysis/visualization/plt_utils.py
21
22
23
24
25
26
27
28
29
30
31
32
33
def check_markers(N: int) -> List[str]:
    """
    Return a marker list that is long enough for ``N`` clusters.

    If the predefined ``MARKERS`` list has at least ``N`` entries it is
    returned as is; otherwise ``N`` numbered mathtext markers
    (``"$0$"``, ``"$1$"``, ...) are generated instead.

    Args:
        N: Number of clusters.

    Returns:
        List of markers.
    """
    enough_predefined = N <= len(MARKERS)
    if enough_predefined:
        return MARKERS
    return [f"${i}$" for i in range(N)]

flyvision.analysis.visualization.plt_utils.get_marker

get_marker(n)

Get marker for n.

Parameters:

Name Type Description Default
n int

Index of the marker.

required

Returns:

Type Description
str

Marker string.

Source code in flyvision/analysis/visualization/plt_utils.py
36
37
38
39
40
41
42
43
44
45
46
def get_marker(n: int) -> str:
    """
    Get marker for n.

    Args:
        n: Zero-based index of the marker.

    Returns:
        Marker string: ``MARKERS[n]`` while the predefined list is long enough,
        otherwise a numbered mathtext marker such as ``"$12$"``.
    """
    # Index n needs n + 1 markers. Requesting only n made every
    # n >= len(MARKERS) raise IndexError.
    return check_markers(n + 1)[n]

flyvision.analysis.visualization.plt_utils.init_plot

init_plot(
    figsize=[1, 1],
    title="",
    fontsize=5,
    ax=None,
    fig=None,
    projection=None,
    set_axis_off=False,
    transparent=False,
    face_alpha=0,
    position=None,
    title_pos="center",
    title_y=None,
    **kwargs
)

Creates fig and axis object with certain default settings.

Parameters:

Name Type Description Default
figsize List[float]

Figure size.

[1, 1]
title str

Title of the plot.

''
fontsize int

Font size for title and labels.

5
ax Axes

Existing axis object.

None
fig Figure

Existing figure object.

None
projection str

Projection type (e.g., ‘polar’).

None
set_axis_off bool

Whether to turn off axis.

False
transparent bool

Whether to make the axis transparent.

False
face_alpha float

Alpha value for the face color.

0
position List[float]

Position for newly created axis.

None
title_pos Literal['center', 'left', 'right']

Position of the title.

'center'
title_y float

Y-coordinate of the title.

None

Returns:

Type Description
Tuple[Figure, Axes]

Tuple containing the figure and axis objects.

Source code in flyvision/analysis/visualization/plt_utils.py
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
def init_plot(
    figsize: List[float] = [1, 1],
    title: str = "",
    fontsize: int = 5,
    ax: Axes = None,
    fig: plt.Figure = None,
    projection: str = None,
    set_axis_off: bool = False,
    transparent: bool = False,
    face_alpha: float = 0,
    position: List[float] = None,
    title_pos: Literal["center", "left", "right"] = "center",
    title_y: float = None,
    **kwargs,
) -> Tuple[plt.Figure, Axes]:
    """
    Create (or reuse) a figure and axis with default styling.

    A new figure with constrained layout is created when ``fig`` is None. When
    ``ax`` is given it is only titled and styled; otherwise a new subplot is
    added to the figure and ``position``, ``face_alpha`` and ``set_axis_off``
    are applied to it.

    Args:
        figsize: Figure size, used only when a new figure is created.
        title: Title of the plot.
        fontsize: Font size for title and tick labels.
        ax: Existing axis object.
        fig: Existing figure object.
        projection: Projection type (e.g., 'polar') for a new axis.
        set_axis_off: Whether to turn off a newly created axis.
        transparent: Whether to make the axis patch fully transparent.
        face_alpha: Alpha value for the face color of a new axis.
        position: Position for a newly created axis.
        title_pos: Horizontal placement of the title.
        title_y: Y-coordinate of the title.
        **kwargs: Accepted but not used.

    Returns:
        Tuple containing the figure and axis objects.
    """
    if fig is None:
        fig = plt.figure(figsize=figsize, layout="constrained")

    if ax is None:
        ax = fig.add_subplot(projection=projection)
        if position is not None:
            ax.set_position(position)
        ax.patch.set_alpha(face_alpha)
        ax.set_title(title, fontsize=fontsize, loc=title_pos, y=title_y)
        if set_axis_off:
            ax.set_axis_off()
    else:
        ax.set_title(title, fontsize=fontsize, loc=title_pos, y=title_y)

    ax.tick_params(axis="both", which="major", labelsize=fontsize)
    if transparent:
        # Overrides face_alpha for new axes.
        ax.patch.set_alpha(0)
    return fig, ax

flyvision.analysis.visualization.plt_utils.truncate_colormap

truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100)

Truncate colormap.

Parameters:

Name Type Description Default
cmap Colormap

Original colormap.

required
minval float

Minimum value for truncation.

0.0
maxval float

Maximum value for truncation.

1.0
n int

Number of colors in the new colormap.

100

Returns:

Type Description
LinearSegmentedColormap

Truncated colormap.

Source code in flyvision/analysis/visualization/plt_utils.py
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
def truncate_colormap(
    cmap: colors.Colormap, minval: float = 0.0, maxval: float = 1.0, n: int = 100
) -> colors.LinearSegmentedColormap:
    """
    Build a colormap restricted to the ``[minval, maxval]`` slice of ``cmap``.

    ``n`` evenly spaced colors are sampled from the slice and interpolated
    linearly in the new colormap.

    Args:
        cmap: Original colormap.
        minval: Lower end of the slice, in colormap units (0-1).
        maxval: Upper end of the slice, in colormap units (0-1).
        n: Number of colors sampled for the new colormap.

    Returns:
        Truncated colormap named ``trunc(<name>,<minval>,<maxval>)``.
    """
    label = f"trunc({cmap.name},{minval:.2f},{maxval:.2f})"
    sampled_colors = cmap(np.linspace(minval, maxval, n))
    return colors.LinearSegmentedColormap.from_list(label, sampled_colors)

flyvision.analysis.visualization.plt_utils.rm_spines

rm_spines(
    ax,
    spines=("top", "right", "bottom", "left"),
    visible=False,
    rm_xticks=True,
    rm_yticks=True,
)

Removes spines and ticks from axis.

Parameters:

Name Type Description Default
ax Axes

Matplotlib axis object.

required
spines Tuple[str, ...]

Tuple of spines to remove.

('top', 'right', 'bottom', 'left')
visible bool

Whether to make spines visible.

False
rm_xticks bool

Whether to remove x-ticks.

True
rm_yticks bool

Whether to remove y-ticks.

True

Returns:

Type Description
Axes

Modified axis object.

Source code in flyvision/analysis/visualization/plt_utils.py
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
def rm_spines(
    ax: Axes,
    spines: Tuple[str, ...] = ("top", "right", "bottom", "left"),
    visible: bool = False,
    rm_xticks: bool = True,
    rm_yticks: bool = True,
) -> Axes:
    """
    Set the visibility of the given spines and optionally strip their ticks.

    The x-axis ticks and tick labels are cleared when ``rm_xticks`` is set and
    ``spines`` names "top" or "bottom". The y-axis is handled the same way for
    "left"/"right" with ``rm_yticks``.

    Args:
        ax: Matplotlib axis object.
        spines: Names of the spines to modify.
        visible: Visibility applied to those spines.
        rm_xticks: Whether to remove x-ticks.
        rm_yticks: Whether to remove y-ticks.

    Returns:
        The same axis object, modified in place.
    """
    for name in spines:
        ax.spines[name].set_visible(visible)

    axis_rules = (
        (ax.xaxis, ("top", "bottom"), rm_xticks),
        (ax.yaxis, ("left", "right"), rm_yticks),
    )
    for axis, sides, remove in axis_rules:
        if remove and any(side in spines for side in sides):
            axis.set_ticklabels([])
            axis.set_ticks_position("none")
    return ax

flyvision.analysis.visualization.plt_utils.get_ax_positions

get_ax_positions(axes)

Returns the positions of the axes in the figure.

Parameters:

Name Type Description Default
axes Iterable[Axes]

Single ax or iterable of axes.

required

Returns:

Type Description
ndarray

Tuple containing arrays of left, bottom, right, and top positions,

ndarray

and arrays of centers, widths, and heights.

Source code in flyvision/analysis/visualization/plt_utils.py
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
def get_ax_positions(axes: Iterable[Axes]) -> Tuple[np.ndarray, np.ndarray]:
    """
    Collect the figure-coordinate extents of one or more axes.

    Args:
        axes: Single ax or iterable of axes.

    Returns:
        A pair of tuples: ``(lefts, bottoms, rights, tops)`` and
        ``(centers, widths, heights)``, where ``centers`` is a 2 x n_axes array
        of (x, y) center coordinates and every other entry has one value per
        axis.
    """
    extents = np.array([a.get_position().extents for a in np.atleast_1d(axes)])
    lefts, bottoms, rights, tops = np.atleast_2d(extents).T
    widths, heights = rights - lefts, tops - bottoms
    x_centers = lefts + widths / 2
    y_centers = bottoms + heights / 2
    return (lefts, bottoms, rights, tops), (
        np.array([x_centers, y_centers]),
        widths,
        heights,
    )

flyvision.analysis.visualization.plt_utils.is_hex

is_hex(color)

Checks if color is hex.

Parameters:

Name Type Description Default
color str

Color string.

required

Returns:

Type Description
bool

True if color is hex, False otherwise.

Source code in flyvision/analysis/visualization/plt_utils.py
177
178
179
180
181
182
183
184
185
186
187
def is_hex(color: str) -> bool:
    """
    Report whether ``color`` looks like a hex color string.

    This is a loose membership test for ``"#"``. A string counts as hex if it
    contains a ``#`` anywhere. Non-string containers such as RGB tuples are
    accepted as well and yield False unless they contain the element ``"#"``.

    Args:
        color: Color string (or color container).

    Returns:
        True if color is hex, False otherwise.
    """
    hash_sign = "#"
    return hash_sign in color

flyvision.analysis.visualization.plt_utils.is_integer_rgb

is_integer_rgb(color)

Checks if color is integer RGB.

Parameters:

Name Type Description Default
color Iterable[int]

Color tuple or list.

required

Returns:

Type Description
bool

True if color is integer RGB, False otherwise.

Source code in flyvision/analysis/visualization/plt_utils.py
190
191
192
193
194
195
196
197
198
199
200
201
202
203
def is_integer_rgb(color: Iterable[int]) -> bool:
    """
    Heuristically decide whether ``color`` uses the 0-255 integer RGB scale.

    The color counts as integer RGB when any component exceeds 1. Values that
    cannot be iterated or compared with a number (e.g. hex strings) yield
    False.

    Args:
        color: Color tuple or list.

    Returns:
        True if color is integer RGB, False otherwise.
    """
    try:
        # Materialize all comparisons first (as a list) so a TypeError from any
        # component is caught, even after a component > 1 was found.
        above_unit = [component > 1 for component in color]
        return any(above_unit)
    except TypeError:
        return False

flyvision.analysis.visualization.plt_utils.get_alpha_colormap

get_alpha_colormap(saturated_color, number_of_shades)

Create a colormap from a color and a number of shades.

Parameters:

Name Type Description Default
saturated_color str

Saturated color string.

required
number_of_shades int

Number of shades in the colormap.

required

Returns:

Type Description
ListedColormap

ListedColormap object.

Source code in flyvision/analysis/visualization/plt_utils.py
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
def get_alpha_colormap(saturated_color: str, number_of_shades: int) -> ListedColormap:
    """
    Create a colormap of alpha shades of a single color.

    All entries share the RGB value of ``saturated_color``. Their alphas
    decrease linearly from 1 (first entry) to ``1 / number_of_shades`` (last
    entry).

    Args:
        saturated_color: Hex string (e.g. "#4F73AE"), integer RGB(A) sequence
            on the 0-255 scale (e.g. (79, 115, 174)), or any other color spec
            accepted by ``hex2color`` (e.g. a named color or float RGB).
        number_of_shades: Number of shades in the colormap.

    Returns:
        ListedColormap object.
    """
    if is_hex(saturated_color):
        rgb = list(hex2color(saturated_color)[:3])
    elif is_integer_rgb(saturated_color):
        # Keep RGB only, so an integer RGBA input does not yield 5-component rows.
        rgb = list(np.array(saturated_color)[:3] / 255.0)
    else:
        # Named colors, float RGB tuples, etc. Previously this path left the
        # color unassigned and raised UnboundLocalError.
        rgb = list(hex2color(saturated_color)[:3])

    alphas = np.linspace(1 / number_of_shades, 1, number_of_shades)[::-1]
    # Local name deliberately avoids shadowing the module-level ``colors``.
    shades = [[*rgb, alpha] for alpha in alphas]

    return ListedColormap(shades)

flyvision.analysis.visualization.plt_utils.polar_to_cmap

polar_to_cmap(
    r,
    theta,
    invert=True,
    cmap=plt.cm.twilight_shifted,
    norm=None,
    sm=None,
)

Maps angle to rgb and amplitude to alpha and returns the resulting array.

Parameters:

Name Type Description Default
r ndarray

Amplitude array.

required
theta ndarray

Angle array.

required
invert bool

Whether to invert the colormap.

True
cmap Colormap

Colormap.

twilight_shifted
norm Normalize

Normalization object.

None
sm ScalarMappable

ScalarMappable object.

None

Returns:

Type Description
ndarray

RGBA array.

Source code in flyvision/analysis/visualization/plt_utils.py
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
def polar_to_cmap(
    r: np.ndarray,
    theta: np.ndarray,
    invert: bool = True,
    cmap: colors.Colormap = plt.cm.twilight_shifted,
    norm: Normalize = None,
    sm: ScalarMappable = None,
) -> np.ndarray:
    """
    Map angle to RGB and amplitude to alpha, returning an RGBA array.

    Args:
        r: Amplitude array, same shape as ``theta``. It is normalized by its
            maximum and used as the alpha channel.
        theta: Angle array of any shape, mapped to RGB via ``sm`` (or a
            ScalarMappable built from ``cmap`` and ``norm``).
        invert: Whether to invert the RGB channels (``1 - rgb``).
        cmap: Colormap, used only when ``sm`` is not given.
        norm: Normalization object, used only when ``sm`` is not given.
        sm: ScalarMappable object.

    Returns:
        RGBA array of shape ``theta.shape + (4,)``.
    """
    sm = sm if sm else plt.cm.ScalarMappable(cmap=cmap, norm=norm)

    r_max = r.max()
    # An all-zero amplitude would otherwise produce NaN alphas (0 / 0).
    r = r / r_max if r_max != 0 else np.zeros_like(r, dtype=float)

    # Ellipsis indexing supports any theta rank (previously only 2-D).
    rgba = np.zeros((*theta.shape, 4))
    rgba[..., :3] = sm.to_rgba(theta)[..., :3]
    if invert:
        rgba = 1 - rgba
    rgba[..., -1] = r  # amplitude
    return rgba

flyvision.analysis.visualization.plt_utils.add_colorwheel_2d

add_colorwheel_2d(
    fig,
    axes=None,
    pos="southeast",
    radius=0.25,
    x_offset=0,
    y_offset=0,
    sm=None,
    cmap="cm_uniform_2d",
    norm=None,
    fontsize=6,
    N=512,
    labelpad=0,
    invert=False,
    mode="2d",
    ticks=[0, 60, 120],
)

Adds a colorwheel to a figure.

Parameters:

Name Type Description Default
fig Figure

Matplotlib figure object.

required
axes Iterable[Axes]

Iterable of axes to which the colorwheel will be added.

None
pos Literal['southeast', 'east', 'northeast', 'north', 'northwest', 'west', 'southwest', 'south', 'origin']

Position of the colorwheel.

'southeast'
radius float

Radius of the colorwheel in percentage of the ax radius.

0.25
x_offset float

X-offset of the colorwheel in percentage of the cbar diameter.

0
y_offset float

Y-offset of the colorwheel in percentage of the cbar diameter.

0
sm ScalarMappable

ScalarMappable object.

None
cmap str

Colormap name.

'cm_uniform_2d'
norm Normalize

Normalization object.

None
fontsize int

Font size for tick labels.

6
N int

Number of samples for the colorwheel.

512
labelpad float

Padding for tick labels.

0
invert bool

Whether to invert the colormap.

False
mode Literal['1d', '2d']

Mode of the colorwheel (“1d” or “2d”).

'2d'
ticks List[int]

Tick positions in degrees.

[0, 60, 120]

Returns:

Type Description
Tuple[Axes, Axes]

Tuple containing the colorwheel axis and the annotation axis.

Source code in flyvision/analysis/visualization/plt_utils.py
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
def add_colorwheel_2d(
    fig: plt.Figure,
    axes: Iterable[Axes] = None,
    pos: Literal[
        "southeast",
        "east",
        "northeast",
        "north",
        "northwest",
        "west",
        "southwest",
        "south",
        "origin",
    ] = "southeast",
    radius: float = 0.25,
    x_offset: float = 0,
    y_offset: float = 0,
    sm: ScalarMappable = None,
    cmap: str = "cm_uniform_2d",
    norm: Normalize = None,
    fontsize: int = 6,
    N: int = 512,
    labelpad: float = 0,
    invert: bool = False,
    mode: Literal["1d", "2d"] = "2d",
    ticks: List[int] = [0, 60, 120],
) -> Tuple[Axes, Axes]:
    """
    Add a colorwheel legend to a figure.

    Two overlapping axes are created at the same position. The first shows the
    wheel image: hue from the angle; alpha from the radius in "2d" mode, or
    constant inside the disk in "1d" mode. The second is a transparent polar
    axis that carries the angular tick labels in degrees.

    Args:
        fig: Matplotlib figure object.
        axes: Iterable of axes to which the colorwheel will be added.
        pos: Position of the colorwheel.
        radius: Radius of the colorwheel in percentage of the ax radius.
        x_offset: X-offset of the colorwheel in percentage of the cbar diameter.
        y_offset: Y-offset of the colorwheel in percentage of the cbar diameter.
        sm: ScalarMappable object.
        cmap: Colormap name.
        norm: Normalization object.
        fontsize: Font size for tick labels.
        N: Number of samples per side of the colorwheel image.
        labelpad: Padding for tick labels.
        invert: Whether to invert the colormap.
        mode: Mode of the colorwheel ("1d" or "2d").
        ticks: Tick positions in degrees.

    Returns:
        Tuple containing the colorwheel axis and the annotation axis.
    """
    colormap = plt.get_cmap(cmap)

    wheel_position = derive_position_for_supplementary_ax_hex(
        fig,
        axes=axes,
        pos=pos,
        radius=radius,
        x_offset=x_offset,
        y_offset=y_offset,
    )
    image_ax = fig.add_axes(wheel_position, alpha=0)
    image_ax.patch.set_alpha(0)

    # Sample the square [-1, 1]^2; everything outside the unit disk gets zero
    # amplitude (and hence zero alpha).
    samples = np.linspace(-1, 1, N)
    grid_x, grid_y = np.meshgrid(samples, samples)
    amplitude = np.sqrt(grid_x * grid_x + grid_y * grid_y)
    inside_disk = amplitude < 1
    amplitude[~inside_disk] = 0
    if mode == "1d":
        amplitude[inside_disk] = 1

    angle = np.arctan2(grid_y, grid_x)
    wheel_rgba = polar_to_cmap(
        amplitude, angle, invert=invert, cmap=colormap, norm=norm, sm=sm
    )
    image_ax.imshow(wheel_rgba, origin="lower")
    image_ax.set_axis_off()

    label_ax = fig.add_axes(wheel_position, polar=True, label="annotation", alpha=0)
    label_ax.set_facecolor("none")
    # No radial ticks or labels.
    label_ax.set_yticks([])
    label_ax.set_yticklabels([])

    label_ax.tick_params(pad=labelpad, labelsize=fontsize)
    label_ax.set_xticks(np.radians(ticks))
    label_ax.set_xticklabels([f"{deg}°" for deg in np.array(ticks).astype(int)])

    plt.setp(label_ax.spines.values(), color="white", linewidth=2)

    return image_ax, label_ax

flyvision.analysis.visualization.plt_utils.add_cluster_marker

add_cluster_marker(
    fig,
    ax,
    marker="o",
    marker_size=15,
    color="#4F73AE",
    x_offset=0,
    y_offset=0,
)

Adds a cluster marker to a figure.

Parameters:

Name Type Description Default
fig Figure

Matplotlib figure object.

required
ax Axes

Matplotlib axis object.

required
marker str

Marker style.

'o'
marker_size int

Marker size.

15
color str

Marker color.

'#4F73AE'
x_offset float

X-offset of the marker in percentage of the ax width.

0
y_offset float

Y-offset of the marker in percentage of the ax height.

0
Source code in flyvision/analysis/visualization/plt_utils.py
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
def add_cluster_marker(
    fig: plt.Figure,
    ax: Axes,
    marker: str = "o",
    marker_size: int = 15,
    color: str = "#4F73AE",
    x_offset: float = 0,
    y_offset: float = 0,
) -> None:
    """
    Draw a cluster marker next to an axis.

    The marker is scattered on a full-figure overlay axis (label "overlay"),
    which is created on first use and reused afterwards. It is placed at the
    lower-left corner of ``ax`` shifted by the given offsets.

    Args:
        fig: Matplotlib figure object.
        ax: Matplotlib axis object.
        marker: Marker style.
        marker_size: Marker size.
        color: Marker color.
        x_offset: X-offset of the marker in percentage of the ax width.
        y_offset: Y-offset of the marker in percentage of the ax height.
    """
    # Make every axis background transparent so the marker stays visible
    # wherever it lands on the figure.
    for existing_ax in fig.axes:
        existing_ax.patch.set_alpha(0)

    # Reuse the overlay axis if one already exists; otherwise create an
    # invisible axis spanning the whole figure.
    overlay_ax = next(
        (candidate for candidate in fig.axes if candidate.get_label() == "overlay"),
        None,
    )
    if overlay_ax is None:
        overlay_ax = fig.add_axes([0, 0, 1, 1], label="overlay", alpha=0)

    # Unit limits make data coordinates coincide with figure coordinates.
    overlay_ax.set_ylim(0, 1)
    overlay_ax.set_xlim(0, 1)
    overlay_ax.patch.set_alpha(0)
    rm_spines(overlay_ax, visible=False, rm_xticks=True, rm_yticks=True)

    ax_left, ax_bottom, ax_width, ax_height = ax.get_position().bounds
    overlay_ax.scatter(
        ax_left + x_offset * ax_width,
        ax_bottom + y_offset * ax_height,
        marker=marker,
        s=marker_size,
        color=color,
    )

flyvision.analysis.visualization.plt_utils.derive_position_for_supplementary_ax

derive_position_for_supplementary_ax(
    fig,
    pos="right",
    width=0.04,
    height=0.5,
    x_offset=0,
    y_offset=0,
    axes=None,
)

Returns a position for a supplementary ax.

Parameters:

Name Type Description Default
fig Figure

Matplotlib figure object.

required
pos Literal['right', 'left', 'top', 'bottom']

Position of the supplementary ax relative to the main axes.

'right'
width float

Width of the supplementary ax in percentage of the main ax width.

0.04
height float

Height of the supplementary ax in percentage of the main ax height.

0.5
x_offset float

X-offset of the supplementary ax in percentage of the main ax width.

0
y_offset float

Y-offset of the supplementary ax in percentage of the main ax height.

0
axes Iterable[Axes]

Iterable of axes to which the supplementary ax will be added.

None

Returns:

Type Description
List[float]

List containing the left, bottom, width, and height of the supplementary ax.

Source code in flyvision/analysis/visualization/plt_utils.py
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
def derive_position_for_supplementary_ax(
    fig: plt.Figure,
    pos: Literal["right", "left", "top", "bottom"] = "right",
    width: float = 0.04,
    height: float = 0.5,
    x_offset: float = 0,
    y_offset: float = 0,
    axes: Iterable[Axes] = None,
) -> List[float]:
    """
    Returns a position for a supplementary ax.

    The position is derived from the joint bounding box (in figure coordinates)
    of ``axes``, or of all axes in ``fig`` if ``axes`` is None.

    Args:
        fig: Matplotlib figure object. Only used when ``axes`` is None.
        pos: Position of the supplementary ax relative to the main axes.
        width: Width of the supplementary ax as a fraction of the bounding-box
            width.
        height: Height of the supplementary ax as a fraction of the bounding-box
            height.
        x_offset: X-offset of the supplementary ax in units of its own width.
        y_offset: Y-offset of the supplementary ax in units of its own height.
        axes: Iterable of axes to which the supplementary ax will be added.

    Returns:
        List containing the left, bottom, width, and height of the supplementary
        ax in figure coordinates.

    Raises:
        KeyError: If ``pos`` is not one of the supported positions.
    """
    axes = axes if axes is not None else fig.get_axes()
    x0, y0, x1, y1 = np.array([ax.get_position().extents for ax in axes]).T
    x0, y0, x1, y1 = x0.min(), y0.min(), x1.max(), y1.max()
    ax_width = x1 - x0
    ax_height = y1 - y0

    # Size of the supplementary ax in figure coordinates.
    sup_width = ax_width * width
    sup_height = ax_height * height
    # Offsets are expressed in units of the supplementary ax size for every
    # position (previously "left" added a raw y_offset).
    dx = sup_width * x_offset
    dy = sup_height * y_offset

    positions = {
        "right": [
            x1 + sup_width / 2 + dx,  # left
            y0 + (1 - height) * ax_height / 2 + dy,  # bottom, vertically centered
            sup_width,  # width
            sup_height,  # height
        ],
        "left": [
            x0 - 3 / 2 * ax_width * width + dx,  # left
            y0 + (1 - height) * ax_height / 2 + dy,  # bottom, vertically centered
            sup_width,  # width
            sup_height,  # height
        ],
        "top": [
            x1 - sup_width + dx,  # left, right-aligned with the bounding box
            y1 + sup_height / 2 + dy,  # bottom
            sup_width,
            sup_height,
        ],
        "bottom": [
            # Horizontally centered under the bounding box.
            x0 + (1 - width) * ax_width / 2 + dx,
            y0 - 3 / 2 * ax_height * height + dy,
            sup_width,
            sup_height,
        ],
    }
    return positions[pos]

flyvision.analysis.visualization.plt_utils.derive_position_for_supplementary_ax_hex

derive_position_for_supplementary_ax_hex(
    fig,
    axes=None,
    pos="southwest",
    radius=0.25,
    x_offset=0,
    y_offset=0,
)

Returns a position for a supplementary ax.

Parameters:

Name Type Description Default
fig Figure

Matplotlib figure object.

required
axes Iterable[Axes]

Iterable of axes to which the supplementary ax will be added.

None
pos Literal['southeast', 'east', 'northeast', 'north', 'northwest', 'west', 'southwest', 'south', 'origin']

Position of the supplementary ax relative to the main axes.

'southwest'
radius float

Radius of the supplementary ax in percentage of the main ax radius.

0.25
x_offset float

X-offset of the supplementary ax in percentage of the main ax width.

0
y_offset float

Y-offset of the supplementary ax in percentage of the main ax height.

0

Returns:

Type Description
List[float]

List containing the left, bottom, width, and height of the supplementary ax.

Source code in flyvision/analysis/visualization/plt_utils.py
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
def derive_position_for_supplementary_ax_hex(
    fig: plt.Figure,
    axes: Iterable[Axes] = None,
    pos: Literal[
        "southeast",
        "east",
        "northeast",
        "north",
        "northwest",
        "west",
        "southwest",
        "south",
        "origin",
    ] = "southwest",
    radius: float = 0.25,
    x_offset: float = 0,
    y_offset: float = 0,
) -> List[float]:
    """
    Returns a position for a supplementary ax.

    The supplementary ax is a square of side ``2 * radius * R``, where ``R`` is a
    quarter of the summed width and height of the joint bounding box (figure
    coordinates) of ``axes``. Compass positions place the square just outside that
    bounding box (e.g. "northeast" puts its lower-left corner on the box's
    upper-right corner). "origin" puts its lower-left corner on the box's
    lower-left corner.

    Args:
        fig: Matplotlib figure object. Only consulted when ``axes`` is None, in
            which case all of its axes are used.
        axes: Iterable of axes to which the supplementary ax will be added.
        pos: Position of the supplementary ax relative to the main axes.
        radius: Radius of the supplementary ax as a fraction of the main axes
            radius ``R``.
        x_offset: X-offset in units of the supplementary ax side length. For
            "origin" it is added directly in figure coordinates.
        y_offset: Y-offset, same convention as ``x_offset``.

    Returns:
        List containing the left, bottom, width, and height of the supplementary ax.

    Raises:
        KeyError: If ``pos`` is not one of the positions listed above.
    """
    if axes is None:
        axes = fig.get_axes()
    x_lo, y_lo, x_hi, y_hi = np.array([ax.get_position().extents for ax in axes]).T
    left, bottom, right, top = x_lo.min(), y_lo.min(), x_hi.max(), y_hi.max()
    span_x = right - left
    span_y = top - bottom

    half_side = (span_x + span_y) / 4 * radius
    side = 2 * half_side
    shift_x = x_offset * side
    shift_y = y_offset * side

    # Candidate lower-left x coordinates: east of, centered on, or west of the box.
    x_east = right + shift_x
    x_center = right - span_x / 2 - half_side + shift_x
    x_west = right - span_x - side + shift_x
    # Candidate lower-left y coordinates: south of, centered on, or north of the box.
    y_south = top - span_y - side + shift_y
    y_middle = top - span_y / 2 - half_side + shift_y
    y_north = top + shift_y

    anchors = {
        "southeast": (x_east, y_south),
        "east": (x_east, y_middle),
        "northeast": (x_east, y_north),
        "north": (x_center, y_north),
        "northwest": (x_west, y_north),
        "west": (x_west, y_middle),
        "southwest": (x_west, y_south),
        "south": (x_center, y_south),
        "origin": (left + x_offset, bottom + y_offset),
    }
    anchor_x, anchor_y = anchors[pos]
    return [anchor_x, anchor_y, side, side]

flyvision.analysis.visualization.plt_utils.add_colorbar_to_fig

add_colorbar_to_fig(
    fig,
    axes=None,
    pos="right",
    width=0.04,
    height=0.5,
    x_offset=0,
    y_offset=0,
    cmap=cm.get_cmap("binary"),
    fontsize=10,
    tick_length=1.5,
    tick_width=0.75,
    rm_outline=True,
    ticks=None,
    norm=None,
    label="",
    plain=False,
    use_math_text=False,
    scilimits=None,
    style="",
    alpha=1,
    n_ticks=9,
    discrete=False,
    n_discrete=None,
    discrete_labels=None,
    n_decimals=2,
)

Adds a colorbar to a figure.

Parameters:

Name Type Description Default
fig Figure

Matplotlib figure object.

required
axes Iterable[Axes]

Iterable of axes to which the colorbar will be added.

None
pos Literal['right', 'left', 'top', 'bottom']

Position of the colorbar.

'right'
width float

Width of the colorbar in percentage of the ax width.

0.04
height float

Height of the colorbar in percentage of the ax height.

0.5
x_offset float

X-offset of the colorbar in percentage of the ax width.

0
y_offset float

Y-offset of the colorbar in percentage of the ax height.

0
cmap Colormap

Colormap.

get_cmap('binary')
fontsize int

Font size for tick labels.

10
tick_length float

Length of the tick marks.

1.5
tick_width float

Width of the tick marks.

0.75
rm_outline bool

Whether to remove the outline of the colorbar.

True
ticks Iterable[float]

Tick positions.

None
norm Normalize

Normalization object.

None
label str

Colorbar label.

''
plain bool

Whether to remove all ticks and tick labels.

False
use_math_text bool

Whether to use math text for tick labels.

False
scilimits Tuple[float, float]

Limits for scientific notation.

None
style str

Style for scientific notation.

''
alpha float

Alpha value for the colorbar.

1
n_ticks int

Number of ticks for TwoSlopeNorm.

9
discrete bool

Whether to use discrete colors.

False
n_discrete int

Number of discrete colors.

None
discrete_labels Iterable[str]

Labels for discrete colors.

None
n_decimals int

Number of decimal places for tick labels.

2

Returns:

Type Description
Colorbar

Matplotlib colorbar object.

Source code in flyvision/analysis/visualization/plt_utils.py
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
def add_colorbar_to_fig(
    fig: plt.Figure,
    axes: Iterable[Axes] = None,
    pos: Literal["right", "left", "top", "bottom"] = "right",
    width: float = 0.04,
    height: float = 0.5,
    x_offset: float = 0,
    y_offset: float = 0,
    # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in
    # Matplotlib 3.9; this default is evaluated once, at import time.
    cmap: colors.Colormap = plt.get_cmap("binary"),
    fontsize: int = 10,
    tick_length: float = 1.5,
    tick_width: float = 0.75,
    rm_outline: bool = True,
    ticks: Iterable[float] = None,
    norm: Normalize = None,
    label: str = "",
    plain: bool = False,
    use_math_text: bool = False,
    scilimits: Tuple[float, float] = None,
    style: str = "",
    alpha: float = 1,
    n_ticks: int = 9,  # only effective if norm is TwoSlopeNorm
    discrete: bool = False,
    n_discrete: int = None,
    discrete_labels: Iterable[str] = None,
    n_decimals: int = 2,
) -> mpl.colorbar.Colorbar:
    """
    Adds a colorbar to a figure.

    The colorbar is drawn in a new axes (label "cbar") placed next to the joint
    bounding box of ``axes`` via ``derive_position_for_supplementary_ax``.

    Args:
        fig: Matplotlib figure object.
        axes: Iterable of axes to which the colorbar will be added.
        pos: Position of the colorbar. "left"/"right" give a vertical colorbar,
            "top"/"bottom" a horizontal one.
        width: Width of the colorbar as a fraction of the axes bounding-box width.
        height: Height of the colorbar as a fraction of the axes bounding-box
            height.
        x_offset: X-offset of the colorbar in units of its own width.
        y_offset: Y-offset of the colorbar in units of its own height.
        cmap: Colormap. Defaults to the "binary" colormap.
        fontsize: Font size for tick labels.
        tick_length: Length of the tick marks.
        tick_width: Width of the tick marks.
        rm_outline: Whether to remove the outline of the colorbar.
        ticks: Tick positions. With a TwoSlopeNorm, positions are generated from
            ``n_ticks`` when this is None or empty.
        norm: Normalization object.
        label: Colorbar label.
        plain: Whether to remove all ticks and tick labels.
        use_math_text: Whether to use math text for tick labels.
        scilimits: Limits for scientific notation.
        style: Style for scientific notation.
        alpha: Alpha value for the colorbar.
        n_ticks: Number of ticks for TwoSlopeNorm (ignored for other norms).
            Ticks are spaced evenly on both sides of ``vcenter``. Odd values give
            exactly ``n_ticks`` ticks, even values ``n_ticks + 1``, and at least
            vmin, vcenter and vmax are always ticked.
        discrete: Whether to put tick labels at the centers of ``n_discrete``
            equally sized color bins. The colormap itself is not discretized here.
        n_discrete: Number of discrete colors. Required if ``discrete``.
        discrete_labels: Labels for discrete colors. Defaults to
            ``0 .. n_discrete - 1`` when None or empty.
        n_decimals: Number of decimal places for tick labels.

    Returns:
        Matplotlib colorbar object.

    Raises:
        ValueError: If ``discrete`` is True but ``n_discrete`` is not set.
    """
    _orientation = "vertical" if pos in ("left", "right") else "horizontal"

    position = derive_position_for_supplementary_ax(
        fig=fig,
        pos=pos,
        width=width,
        height=height,
        x_offset=x_offset,
        y_offset=y_offset,
        axes=axes,
    )
    cbax = fig.add_axes(position, label="cbar")

    cbar = mpl.colorbar.ColorbarBase(
        cbax,
        cmap=cmap,
        norm=norm,
        orientation=_orientation,
        ticks=ticks,
        alpha=alpha,
    )
    cbar.set_label(fontsize=fontsize, label=label)
    cbar.ax.tick_params(labelsize=fontsize, length=tick_length, width=tick_width)
    if pos in ("left", "right"):
        scalarformatter = isinstance(
            cbar.ax.yaxis.get_major_formatter(), mpl.ticker.ScalarFormatter
        )
        cbar.ax.yaxis.set_ticks_position(pos)
        cbar.ax.yaxis.set_label_position(pos)
        cbar.ax.yaxis.get_offset_text().set_fontsize(fontsize)
        cbar.ax.yaxis.get_offset_text().set_horizontalalignment("left")
        cbar.ax.yaxis.get_offset_text().set_verticalalignment("bottom")
    else:
        scalarformatter = isinstance(
            cbar.ax.xaxis.get_major_formatter(), mpl.ticker.ScalarFormatter
        )
        cbar.ax.xaxis.set_ticks_position(pos)
        cbar.ax.xaxis.set_label_position(pos)
        cbar.ax.xaxis.get_offset_text().set_fontsize(fontsize)
        cbar.ax.xaxis.get_offset_text().set_verticalalignment("top")
        cbar.ax.xaxis.get_offset_text().set_horizontalalignment("left")

    # ticklabel_format only applies to ScalarFormatter-based axes.
    if scalarformatter:
        cbar.ax.ticklabel_format(
            style=style, useMathText=use_math_text, scilimits=scilimits
        )

    if rm_outline:
        cbar.outline.set_visible(False)

    if isinstance(norm, TwoSlopeNorm):
        # np.size instead of truthiness: `ticks or ...` raises for numpy arrays.
        if ticks is None or np.size(ticks) == 0:
            # n_side points per half; the shared vcenter is kept only once.
            n_side = max(n_ticks // 2 + 1, 2)
            left_ticks = np.linspace(norm.vmin, norm.vcenter, n_side)
            right_ticks = np.linspace(norm.vcenter, norm.vmax, n_side)
            ticks = [*left_ticks, *right_ticks[1:]]
        cbar.set_ticks(
            ticks,
            labels=[f"{t:.{n_decimals}f}" for t in ticks],
            fontsize=fontsize,
        )

    if plain:
        cbar.set_ticks([])

    if discrete:
        # to put ticklabels for discrete colors in the middle
        if not n_discrete:
            raise ValueError(f"n_discrete {n_discrete}")
        lim = cbar.ax.get_ylim() if pos in ["left", "right"] else cbar.ax.get_xlim()
        color_width = (lim[1] - lim[0]) / n_discrete
        label_offset_to_center = color_width / 2
        labels = np.arange(n_discrete)
        loc = np.linspace(*lim, n_discrete, endpoint=False) + label_offset_to_center
        cbar.set_ticks(loc)
        # np.size instead of truthiness: `discrete_labels or ...` raises for arrays.
        if discrete_labels is None or np.size(discrete_labels) == 0:
            cbar.set_ticklabels(labels)
        else:
            cbar.set_ticklabels(discrete_labels)

    return cbar

flyvision.analysis.visualization.plt_utils.get_norm

get_norm(
    norm=None,
    vmin=None,
    vmax=None,
    midpoint=None,
    log=None,
    symlog=None,
)

Returns a normalization object for color normalization.

Parameters:

Name Type Description Default
norm Normalize

A class which, when called, can normalize data into an interval [vmin, vmax].

None
vmin float

Minimum value for normalization.

None
vmax float

Maximum value for normalization.

None
midpoint float

Midpoint value so that data is normalized around it.

None
log bool

Whether to normalize on a log-scale.

None
symlog float

Normalizes to symlog with linear range around the range (-symlog, symlog).

None

Returns:

Type Description
Normalize

Normalization object.

Source code in flyvision/analysis/visualization/plt_utils.py
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
def get_norm(
    norm: Normalize = None,
    vmin: float = None,
    vmax: float = None,
    midpoint: float = None,
    log: bool = None,
    symlog: float = None,
) -> Normalize:
    """
    Returns a normalization object for color normalization.

    An explicit ``norm`` is returned unchanged. Otherwise, if both ``vmin`` and
    ``vmax`` are given, the first matching option builds the norm:
    ``midpoint`` -> TwoSlopeNorm, truthy ``log`` -> LogNorm,
    ``symlog`` -> SymLogNorm (symmetric around 0), else a plain Normalize.
    Without both ``vmin`` and ``vmax``, None is returned.

    Args:
        norm: A class which, when called, can normalize data into an interval
            [vmin, vmax].
        vmin: Minimum value for normalization.
        vmax: Maximum value for normalization.
        midpoint: Midpoint value so that data is normalized around it. If
            ``vmin`` (``vmax``) is not below (above) the midpoint, it is replaced
            by the mirror image of the other limit about the midpoint so that the
            TwoSlopeNorm stays valid.
        log: Whether to normalize on a log-scale. ``False`` disables it.
        symlog: Normalizes to symlog with linear range around the range (-symlog, symlog).

    Returns:
        Normalization object, or None if no normalization can be derived.
    """
    if norm:
        return norm

    # Every norm built below needs both limits.
    if vmin is None or vmax is None:
        return None

    # Widen the limits by a tiny epsilon (presumably to avoid a degenerate
    # vmin == vmax range).
    vmin -= 1e-15
    vmax += 1e-15

    if midpoint is not None:
        # Mirror the offending limit about the midpoint (2 * m - v). The previous
        # `midpoint - v` was only a mirror for midpoint == 0 and could leave
        # vmin/vcenter/vmax out of order otherwise.
        if vmin > midpoint or np.isclose(vmin, midpoint, atol=1e-9):
            vmin = 2 * midpoint - vmax
        if vmax < midpoint or np.isclose(vmax, midpoint, atol=1e-9):
            vmax = 2 * midpoint - vmin
        return TwoSlopeNorm(vcenter=midpoint, vmin=vmin, vmax=vmax)
    if log:
        return mpl.colors.LogNorm(vmin=vmin, vmax=vmax)
    if symlog is not None:
        v = max(np.abs(vmin), np.abs(vmax))
        return mpl.colors.SymLogNorm(symlog, vmin=-v, vmax=v)
    return mpl.colors.Normalize(vmin=vmin, vmax=vmax)

flyvision.analysis.visualization.plt_utils.get_scalarmapper

get_scalarmapper(
    scalarmapper=None,
    cmap=None,
    norm=None,
    vmin=None,
    vmax=None,
    midpoint=None,
    log=None,
    symlog=None,
)

Returns scalarmappable with norm from get_norm and cmap.

Parameters:

Name Type Description Default
scalarmapper ScalarMappable

Scalarmappable for data to RGBA mapping.

None
cmap Colormap

Colormap.

None
norm Normalize

Normalization object.

None
vmin float

Minimum value for normalization.

None
vmax float

Maximum value for normalization.

None
midpoint float

Midpoint value for normalization.

None
log bool

Whether to normalize on a log-scale.

None
symlog float

Normalizes to symlog with linear range around the range (-symlog, symlog).

None

Returns:

Type Description
Tuple[ScalarMappable, Normalize]

Tuple containing the scalarmappable and the normalization object.

Source code in flyvision/analysis/visualization/plt_utils.py
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
def get_scalarmapper(
    scalarmapper: ScalarMappable = None,
    cmap: colors.Colormap = None,
    norm: Normalize = None,
    vmin: float = None,
    vmax: float = None,
    midpoint: float = None,
    log: bool = None,
    symlog: float = None,
) -> Tuple[ScalarMappable, Normalize]:
    """
    Returns scalarmappable with norm from `get_norm` and cmap.

    When ``scalarmapper`` is given (truthy), it is returned untouched together
    with the ``norm`` argument as passed, without consulting ``get_norm``.

    Args:
        scalarmapper: Scalarmappable for data to RGBA mapping.
        cmap: Colormap used for a newly created scalarmappable.
        norm: Normalization object.
        vmin: Minimum value for normalization.
        vmax: Maximum value for normalization.
        midpoint: Midpoint value for normalization.
        log: Whether to normalize on a log-scale.
        symlog: Normalizes to symlog with linear range around the range (-symlog, symlog).

    Returns:
        Tuple containing the scalarmappable and the normalization object.
    """
    if not scalarmapper:
        resolved_norm = get_norm(
            norm=norm,
            vmin=vmin,
            vmax=vmax,
            midpoint=midpoint,
            log=log,
            symlog=symlog,
        )
        mapper = plt.cm.ScalarMappable(norm=resolved_norm, cmap=cmap)
        return mapper, resolved_norm
    # A ready-made mapper was supplied: pass it through unchanged.
    return scalarmapper, norm

flyvision.analysis.visualization.plt_utils.get_lims

get_lims(z, offset, min=None, max=None)

Get scalar bounds of Ndim-array-like structure with relative offset.

Parameters:

Name Type Description Default
z Union[ndarray, Iterable[ndarray]]

Ndim-array-like structure.

required
offset float

Relative offset for the bounds.

required
min float

Minimum value for the bounds.

None
max float

Maximum value for the bounds.

None

Returns:

Type Description
Tuple[float, float]

Tuple containing the minimum and maximum values.

Source code in flyvision/analysis/visualization/plt_utils.py
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
def get_lims(
    z: Union[np.ndarray, Iterable[np.ndarray]],
    offset: float,
    min: float = None,
    max: float = None,
) -> Tuple[float, float]:
    """
    Get scalar bounds of Ndim-array-like structure with relative offset.

    Infinite values are dropped and NaNs are ignored. The bounds are widened on
    each side by ``offset`` times the data range. Inputs without any non-zero
    entry yield ``(-1, 1)``.

    Args:
        z: Ndim-array-like structure, a torch tensor, or a list/tuple of such
            structures whose joint bounds are computed.
        offset: Relative offset for the bounds, as a fraction of the data range.
        min: If given, the lower bound is at most this value.
        max: If given, the upper bound is at least this value.

    Returns:
        Tuple containing the minimum and maximum values.
    """

    def sub_nan(val: float, sub: float) -> float:
        # Fall back to `sub` when the bound is NaN (e.g. all-NaN input).
        return sub if np.isnan(val) else val

    if isinstance(z, (tuple, list)):
        # Collect the raw bounds of every element. The relative offset is applied
        # once below to the joint range. Applying it per element as well widened
        # list inputs twice, so wrapping an array in a list changed the result.
        z = [get_lims(x, 0) for x in z]
    if isinstance(z, torch.Tensor):
        z = z.detach().cpu().numpy()
    z = np.array(z)[~np.isinf(z)]
    if not z.any():
        return -1, 1
    _min, _max = np.nanmin(z), np.nanmax(z)
    _range = np.abs(_max - _min)
    _min -= _range * offset
    _max += _range * offset
    _min, _max = sub_nan(_min, 0), sub_nan(_max, 1)
    if min is not None:
        _min = np.min((min, _min))
    if max is not None:
        _max = np.max((max, _max))
    return _min, _max

flyvision.analysis.visualization.plt_utils.avg_pool

avg_pool(trace, N)

Smooths (multiple) traces by average pooling over the last (time) dimension; the computation runs with PyTorch on the CPU.

Parameters:

Name Type Description Default
trace ndarray

Array of shape (N, t).

required
N int

Window size for averaging.

required

Returns:

Type Description
ndarray

Smoothed trace array.

Source code in flyvision/analysis/visualization/plt_utils.py
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
def avg_pool(trace: np.ndarray, N: int) -> np.ndarray:
    """
    Smooths (multiple) traces by average pooling over the last dimension.

    Non-overlapping windows of ``N`` samples are averaged (kernel size and stride
    ``N``). Trailing samples that do not fill a whole window are dropped. The
    computation runs through PyTorch on the CPU and the result is float32.

    Args:
        trace: Array whose last dimension is time, e.g. of shape (n_traces, t).
            1-D inputs of shape (t,) are accepted as well.
        N: Window size for averaging.

    Returns:
        Smoothed trace array of shape (t // N,) for 1-D input, otherwise of shape
        (trace.shape[0], -1).
    """
    shape = trace.shape
    # np.prod(()) is the float 1.0 for 1-D input, which reshape rejects.
    n_traces = int(np.prod(shape[:-1]))
    # avg_pool1d expects (batch, channels, length): one channel per trace.
    batched = trace.reshape(n_traces, 1, shape[-1])
    with torch.no_grad():
        pooled = (
            F.avg_pool1d(torch.tensor(batched, dtype=torch.float32), N, N)
            .cpu()
            .numpy()
        )
    if len(shape) == 1:
        return pooled.reshape(-1)
    # NOTE: for inputs with more than two dims, all dims after the first are
    # merged into the last output axis (kept for backward compatibility).
    return pooled.reshape(shape[0], -1)

flyvision.analysis.visualization.plt_utils.width_n_height

width_n_height(
    N, aspect_ratio, max_width=None, max_height=None
)

Integer width and height for a grid of N plots with aspect ratio.

Parameters:

Name Type Description Default
N int

Number of plots.

required
aspect_ratio float

Aspect ratio of the grid.

required
max_width int

Maximum width of the grid.

None
max_height int

Maximum height of the grid.

None

Returns:

Type Description
Tuple[int, int]

Tuple containing the width and height of the grid.

Source code in flyvision/analysis/visualization/plt_utils.py
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
def width_n_height(
    N: int, aspect_ratio: float, max_width: int = None, max_height: int = None
) -> Tuple[int, int]:
    """
    Integer width and height for a grid of N plots with aspect ratio.

    Starts from a square grid of side ``ceil(sqrt(N))`` stretched by
    ``aspect_ratio``, shrinks it to the smallest size that still holds ``N``
    plots, and then enforces ``max_width`` or ``max_height`` if given.

    Args:
        N: Number of plots.
        aspect_ratio: Aspect ratio (width / height) of the grid. Must be positive.
        max_width: Maximum width of the grid. Mutually exclusive with
            ``max_height``.
        max_height: Maximum height of the grid. Mutually exclusive with
            ``max_width``.

    Returns:
        Tuple containing the width and height of the grid.

    Raises:
        ValueError: If both ``max_width`` and ``max_height`` are given, or if
            ``aspect_ratio`` is not positive.
    """
    if max_width is not None and max_height is not None:
        raise ValueError("Only one of max_width and max_height can be specified.")
    if aspect_ratio <= 0:
        raise ValueError(f"aspect_ratio must be positive, got {aspect_ratio}.")

    _sqrt = int(np.ceil(np.sqrt(N)))

    gridwidth = np.ceil(_sqrt * aspect_ratio).astype(int)
    gridheight = np.ceil(_sqrt / aspect_ratio).astype(int)

    # Trim each side to what is actually needed to hold N plots.
    gridwidth = max(1, min(N, gridwidth, np.ceil(N / gridheight)))
    gridheight = max(1, min(N, gridheight, np.ceil(N / gridwidth)))

    if max_width is not None and gridwidth > max_width:
        gridwidth = max_width
        gridheight = np.ceil(N / gridwidth)

    if max_height is not None and gridheight > max_height:
        gridheight = max_height
        gridwidth = np.ceil(N / gridheight)

    assert gridwidth * gridheight >= N

    return int(gridwidth), int(gridheight)

flyvision.analysis.visualization.plt_utils.get_axis_grid

get_axis_grid(
    alist=None,
    gridwidth=None,
    gridheight=None,
    max_width=None,
    max_height=None,
    fig=None,
    ax=None,
    axes=None,
    aspect_ratio=1,
    figsize=None,
    scale=3,
    projection=None,
    as_matrix=False,
    fontsize=5,
    wspace=0.1,
    hspace=0.3,
    alpha=1,
    sharex=None,
    sharey=None,
    unmask_n=None,
)

Create axis grid for a list of elements or integer width and height.

Parameters:

Name Type Description Default
alist Iterable

List of elements to create grid for.

None
gridwidth int

Width of grid.

None
gridheight int

Height of grid.

None
max_width int

Maximum width of grid.

None
max_height int

Maximum height of grid.

None
fig Figure

Existing figure to use.

None
ax Axes

Existing axis to use. This ax will be divided into a grid of axes with the same size as the grid.

None
axes Iterable[Axes]

Existing axes to use.

None
aspect_ratio float

Aspect ratio of grid.

1
figsize List[float]

Figure size.

None
scale Union[int, Iterable[int]]

Scales figure size by this factor(s) times the grid width and height.

3
projection Union[str, Iterable[str]]

Projection of axes.

None
as_matrix bool

Return axes as matrix.

False
fontsize int

Fontsize of axes.

5
wspace float

Width space between axes.

0.1
hspace float

Height space between axes.

0.3
alpha float

Alpha of axes.

1
sharex Axes

Share x axis. Only effective if a new grid of axes is created.

None
sharey Axes

Share y axis. Only effective if a new grid of axes is created.

None
unmask_n int

Number of elements to unmask. If None, all elements are unmasked. If provided, elements at indices >= unmask_n are padded with nans.

None

Returns:

Type Description
Tuple[Figure, Union[List[Axes], ndarray], Tuple[int, int]]

Tuple containing the figure, axes, and the grid width and height.

Source code in flyvision/analysis/visualization/plt_utils.py
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
def get_axis_grid(
    alist: Iterable = None,
    gridwidth: int = None,
    gridheight: int = None,
    max_width: int = None,
    max_height: int = None,
    fig: plt.Figure = None,
    ax: Axes = None,
    axes: Iterable[Axes] = None,
    aspect_ratio: float = 1,
    figsize: List[float] = None,
    scale: Union[int, Iterable[int]] = 3,
    projection: Union[str, Iterable[str]] = None,
    as_matrix: bool = False,
    fontsize: int = 5,
    wspace: float = 0.1,
    hspace: float = 0.3,
    alpha: float = 1,
    sharex: Axes = None,
    sharey: Axes = None,
    unmask_n: int = None,
) -> Tuple[plt.Figure, Union[List[Axes], np.ndarray], Tuple[int, int]]:
    """
    Create axis grid for a list of elements or integer width and height.

    Args:
        alist: List of elements to create grid for.
        gridwidth: Width of grid (number of columns).
        gridheight: Height of grid (number of rows).
        max_width: Maximum width of grid.
        max_height: Maximum height of grid.
        fig: Existing figure to use.
        ax: Existing axis to use. This ax will be divided into a grid of
            gridheight rows and gridwidth columns. If an iterable of axes is
            passed instead, it is returned as-is (a dict is converted to the
            list of its values) without further styling.
        axes: Existing axes to use.
        aspect_ratio: Aspect ratio of grid.
        figsize: Figure size. Takes precedence over `scale`.
        scale: Scales figure size by this factor(s) times the grid width and height.
        projection: Projection of axes, either one value for all axes or one per
            element.
        as_matrix: Return axes as a (gridheight, gridwidth) matrix.
        fontsize: Fontsize of axes tick labels.
        wspace: Width space between axes.
        hspace: Height space between axes.
        alpha: Alpha of axes patches.
        sharex: If not None, each newly created axis shares its x axis with the
            previously created one (the value itself is only used as a flag).
            Only effective if a new grid of axes is created.
        sharey: Same as `sharex`, for the y axis.
        unmask_n: Number of elements to unmask. If None, all elements are unmasked.
            If provided, elements at indices >= unmask_n are padded with nans.

    Returns:
        Tuple containing the figure, axes, and the grid width and height.

    Raises:
        ValueError: If neither `alist` nor both `gridwidth` and `gridheight` are
            given.
    """
    if alist is not None and (
        gridwidth is None or gridheight is None or gridwidth * gridheight != len(alist)
    ):
        # Derive a grid shape large enough for every element of alist.
        gridwidth, gridheight = width_n_height(
            len(alist), aspect_ratio, max_width=max_width, max_height=max_height
        )
    elif gridwidth and gridheight:
        alist = range(gridwidth * gridheight)
    else:
        raise ValueError("Either specify alist or gridwidth and gridheight manually.")

    # Only None means "unmask everything"; an explicit 0 masks every element.
    if unmask_n is None:
        unmask_n = len(alist)

    if figsize is not None:
        pass
    elif isinstance(scale, Number):
        figsize = [scale * gridwidth, scale * gridheight]
    elif isinstance(scale, Iterable) and len(scale) == 2:
        figsize = [scale[0] * gridwidth, scale[1] * gridheight]

    if fig is None:
        fig = figure(figsize=figsize, hspace=hspace, wspace=wspace)

    # Broadcast a single projection (or None) to one entry per element.
    if not isinstance(projection, Iterable) or isinstance(projection, str):
        projection = (projection,) * len(alist)

    if isinstance(ax, Iterable):
        # Caller already supplied one axis per element.
        assert len(ax) == len(alist)
        if isinstance(ax, dict):
            ax = list(ax.values())
        return fig, ax, (gridwidth, gridheight)

    if ax:
        # Divide an existing ax into gridheight rows x gridwidth columns; cells
        # beyond len(alist) stay NaN and therefore get no axis.
        matrix = np.full(gridwidth * gridheight, np.nan)
        matrix[: len(alist)] = np.arange(len(alist))
        axes = divide_axis_to_grid(
            ax,
            matrix=matrix.reshape(gridheight, gridwidth),
            wspace=wspace,
            hspace=hspace,
        )
        axes = list(axes.values())
    elif axes is None:
        # Fill the figure with a fresh grid of axes.
        axes = []
        _sharex, _sharey = None, None

        for i, _ in enumerate(alist):
            if i < unmask_n:
                ax = subplot(
                    "",
                    grid=(gridheight, gridwidth),
                    location=(int(i // gridwidth), int(i % gridwidth)),
                    rowspan=1,
                    colspan=1,
                    sharex=_sharex,
                    sharey=_sharey,
                    projection=projection[i],
                )

                # Chain sharing: the next axis shares with this one.
                if sharex is not None:
                    _sharex = ax
                if sharey is not None:
                    _sharey = ax

                axes.append(ax)
            else:
                # Masked elements are represented by NaN placeholders.
                axes.append(np.nan)

    for ax in axes:
        if isinstance(ax, Axes):
            ax.tick_params(axis="both", which="major", labelsize=fontsize)
            ax.patch.set_alpha(alpha)

    if as_matrix:
        axes = np.array(axes).reshape(gridheight, gridwidth)

    return fig, axes, (gridwidth, gridheight)

flyvision.analysis.visualization.plt_utils.figure

figure(
    figsize,
    hspace=0.3,
    wspace=0.1,
    left=0.125,
    right=0.9,
    top=0.9,
    bottom=0.1,
    frameon=None,
)

Create a figure with the given size and spacing.

Parameters:

Name Type Description Default
figsize List[float]

Figure size.

required
hspace float

Height space between subplots.

0.3
wspace float

Width space between subplots.

0.1
left float

Left margin.

0.125
right float

Right margin.

0.9
top float

Top margin.

0.9
bottom float

Bottom margin.

0.1
frameon bool

Whether to draw the figure frame.

None

Returns:

Type Description
Figure

Matplotlib figure object.

Source code in flyvision/analysis/visualization/plt_utils.py
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
def figure(
    figsize: List[float],
    hspace: float = 0.3,
    wspace: float = 0.1,
    left: float = 0.125,
    right: float = 0.9,
    top: float = 0.9,
    bottom: float = 0.1,
    frameon: bool = None,
) -> plt.Figure:
    """
    Create a new matplotlib figure and configure its subplot spacing.

    Args:
        figsize: Figure size, forwarded to ``plt.figure``.
        hspace: Height space between subplots.
        wspace: Width space between subplots.
        left: Left edge of the subplot area.
        right: Right edge of the subplot area.
        top: Top edge of the subplot area.
        bottom: Bottom edge of the subplot area.
        frameon: Whether to draw the figure frame, forwarded to ``plt.figure``.

    Returns:
        The newly created matplotlib figure.
    """
    new_fig = plt.figure(figsize=figsize, frameon=frameon)
    # Adjust the freshly created figure directly; it is the same figure pyplot
    # would treat as "current" right after plt.figure.
    new_fig.subplots_adjust(
        left=left,
        bottom=bottom,
        right=right,
        top=top,
        wspace=wspace,
        hspace=hspace,
    )
    return new_fig

flyvision.analysis.visualization.plt_utils.subplot

subplot(
    title="",
    grid=(1, 1),
    location=(0, 0),
    colspan=1,
    rowspan=1,
    projection=None,
    sharex=None,
    sharey=None,
    xlabel="",
    ylabel="",
    face_alpha=1.0,
    fontisze=5,
    title_pos="center",
    position=None,
    **kwargs
)

Create a subplot using subplot2grid with some extra options.

Parameters:

Name Type Description Default
title str

Title of the subplot.

''
grid Tuple[int, int]

Grid shape.

(1, 1)
location Tuple[int, int]

Location of the subplot in the grid.

(0, 0)
colspan int

Number of columns the subplot spans.

1
rowspan int

Number of rows the subplot spans.

1
projection str

Projection type (e.g., ‘polar’).

None
sharex Axes

Axis to share x-axis with.

None
sharey Axes

Axis to share y-axis with.

None
xlabel str

X-axis label.

''
ylabel str

Y-axis label.

''
face_alpha float

Alpha value for the face color.

1.0
fontisze int

Font size for title and labels.

5
title_pos Literal['center', 'left', 'right']

Position of the title.

'center'
position List[float]

Position for the subplot.

None

Returns:

Type Description
Axes

Matplotlib axis object.

Source code in flyvision/analysis/visualization/plt_utils.py
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
def subplot(
    title: str = "",
    grid: Tuple[int, int] = (1, 1),
    location: Tuple[int, int] = (0, 0),
    colspan: int = 1,
    rowspan: int = 1,
    projection: str = None,
    sharex: Axes = None,
    sharey: Axes = None,
    xlabel: str = "",
    ylabel: str = "",
    face_alpha: float = 1.0,
    fontisze: int = 5,
    title_pos: Literal["center", "left", "right"] = "center",
    position: List[float] = None,
    **kwargs,
) -> Axes:
    """
    Create a subplot using subplot2grid with some extra options.

    Args:
        title: Title of the subplot.
        grid: Grid shape as (rows, columns).
        location: Location (row, column) of the subplot in the grid.
        colspan: Number of columns the subplot spans.
        rowspan: Number of rows the subplot spans.
        projection: Projection type (e.g., 'polar').
        sharex: Axis to share x-axis with.
        sharey: Axis to share y-axis with.
        xlabel: X-axis label.
        ylabel: Y-axis label.
        face_alpha: Alpha value for the face color.
        fontisze: Font size for title and labels.
        title_pos: Position of the title.
        position: Position for the subplot; overrides the grid placement when
            given.
        **kwargs: Accepted for compatibility; currently ignored.

    Returns:
        Matplotlib axis object.
    """
    # subplot2grid's positional order is (shape, loc, rowspan, colspan); pass
    # the spans by keyword so they cannot be swapped.
    ax = plt.subplot2grid(
        grid,
        location,
        rowspan=rowspan,
        colspan=colspan,
        sharex=sharex,
        sharey=sharey,
        projection=projection,
    )
    if position:
        ax.set_position(position)
    ax.patch.set_alpha(face_alpha)

    ax.set_title(title, fontsize=fontisze, loc=title_pos)

    # Label the axis we just created rather than pyplot's implicit current axes.
    ax.set_xlabel(xlabel, fontsize=fontisze)
    ax.set_ylabel(ylabel, fontsize=fontisze)

    return ax

flyvision.analysis.visualization.plt_utils.divide_axis_to_grid

divide_axis_to_grid(
    ax,
    matrix=((0, 1, 2), (3, 3, 3)),
    wspace=0.1,
    hspace=0.1,
    projection=None,
)

Divides an existing axis inside a figure into a grid specified by the unique elements of a matrix.

Parameters:

Name Type Description Default
ax Axes

Existing Axes object.

required
matrix ndarray

Grid matrix, where each unique element specifies a new axis.

((0, 1, 2), (3, 3, 3))
wspace float

Horizontal space between new axes.

0.1
hspace float

Vertical space between new axes.

0.1
projection str

Projection of new axes.

None

Returns:

Type Description
Dict[Any, Axes]

Dictionary of new axes, where keys are unique elements in the matrix.

Example
fig = plt.figure()
ax = plt.subplot()
plt.tight_layout()
divide_axis_to_grid(ax, matrix=[[0, 1, 1, 1, 2, 2, 2],
                                    [3, 4, 5, 6, 2, 2, 2],
                                    [3, 7, 7, 7, 2, 2, 2],
                                    [3, 8, 8, 12, 2, 2, 2],
                                    [3, 10, 11, 12, 2, 2, 2]],
                        wspace=0.1, hspace=0.1)
Source code in flyvision/analysis/visualization/plt_utils.py
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
def divide_axis_to_grid(
    ax: Axes,
    matrix: np.ndarray = ((0, 1, 2), (3, 3, 3)),
    wspace: float = 0.1,
    hspace: float = 0.1,
    projection: str = None,
) -> Dict[Any, Axes]:
    """
    Divides an existing axis inside a figure into a grid specified by the unique
    elements of a matrix.

    The original axis is hidden and, for every unique non-NaN element of
    ``matrix``, one new axis is added to the same figure. An element repeated
    across several cells yields a single axis spanning them; its placement is
    taken from the element's first (top-left) occurrence.

    Args:
        ax: Existing Axes object.
        matrix: Grid matrix, where each unique element specifies a new axis.
        wspace: Horizontal space between new axes, in the same units as the
            axis position.
        hspace: Vertical space between new axes, in the same units as the
            axis position.
        projection: Projection of new axes.

    Returns:
        Dictionary of new axes, where keys are unique elements in the matrix.

    Example:
        ```python
        fig = plt.figure()
        ax = plt.subplot()
        plt.tight_layout()
        divide_axis_to_grid(ax, matrix=[[0, 1, 1, 1, 2, 2, 2],
                                        [3, 4, 5, 6, 2, 2, 2],
                                        [3, 7, 7, 7, 2, 2, 2],
                                        [3, 8, 8, 12, 2, 2, 2],
                                        [3, 10, 11, 12, 2, 2, 2]],
                            wspace=0.1, hspace=0.1)
        ```
    """
    # Remember where the parent axis sits, then make it invisible.
    left0, bottom0, right0, top0 = ax.get_position().extents
    ax.set_axis_off()
    ax.patch.set_alpha(0)
    fig = ax.figure

    grid = np.array(matrix)
    n_rows, n_cols = grid.shape

    total_width = right0 - left0
    total_height = top0 - bottom0
    cell_height = total_height / n_rows
    cell_width = total_width / n_cols

    placements = {}
    for row_idx, row in enumerate(matrix):
        for col_idx, key in enumerate(row):
            # Rows spanned = occurrences in this column;
            # columns spanned = occurrences in this row.
            rows_spanned = sum(1 for item in grid[:, col_idx] if item == key)
            cols_spanned = sum(1 for item in row if item == key)

            # Only the first cell of each key defines its rectangle; NaN cells
            # stay empty.
            if key in placements or np.isnan(key):
                continue

            span_width = cell_width * cols_spanned
            span_height = cell_height * rows_spanned
            placements[key] = [
                left0 + col_idx * cell_width + wspace / 2,
                bottom0
                + total_height
                - (row_idx + rows_spanned) * cell_height
                + hspace / 2,
                span_width - min(wspace / 2, span_width),
                span_height - min(hspace / 2, span_height),
            ]

    return {
        key: fig.add_axes(rect, projection=projection)
        for key, rect in placements.items()
    }

flyvision.analysis.visualization.plt_utils.divide_figure_to_grid

divide_figure_to_grid(
    matrix=[
        [0, 1, 1, 1, 2, 2, 2],
        [3, 4, 5, 6, 2, 2, 2],
        [3, 7, 7, 7, 2, 2, 2],
        [3, 8, 8, 12, 2, 2, 2],
        [3, 10, 11, 12, 2, 2, 2],
    ],
    as_matrix=False,
    alpha=0,
    constrained_layout=False,
    fig=None,
    figsize=[10, 10],
    projection=None,
    wspace=0.1,
    hspace=0.3,
    no_spines=False,
    keep_nan_axes=False,
    fontsize=5,
    reshape_order="F",
)

Creates a figure grid specified by the arrangement of unique elements in a matrix.

Info

matplotlib now also has matplotlib.pyplot.subplot_mosaic which does the same thing and should be used instead.

Parameters:

Name Type Description Default
matrix List[List[int]]

Grid layout specification.

[[0, 1, 1, 1, 2, 2, 2], [3, 4, 5, 6, 2, 2, 2], [3, 7, 7, 7, 2, 2, 2], [3, 8, 8, 12, 2, 2, 2], [3, 10, 11, 12, 2, 2, 2]]
as_matrix bool

If True, return axes as a numpy array.

False
alpha float

Alpha value for axis patches.

0
constrained_layout bool

Use constrained layout for the figure.

False
fig Optional[Figure]

Existing figure to use. If None, a new figure is created.

None
figsize List[float]

Figure size in inches.

[10, 10]
projection Optional[Union[str, List[str]]]

Projection type for the axes.

None
wspace float

Width space between subplots.

0.1
hspace float

Height space between subplots.

0.3
no_spines bool

If True, remove spines from all axes.

False
keep_nan_axes bool

If True, keep axes for NaN values in the matrix.

False
fontsize int

Font size for tick labels.

5
reshape_order Literal['C', 'F', 'A', 'K', 'special']

Order to use when reshaping the axes array. 'special' instead places each axis at the matrix cells it occupies.

'F'

Returns:

Type Description
Tuple[Figure, Union[Dict[int, Axes], ndarray]]

A tuple containing the figure and a dictionary or numpy array of axes.

Source code in flyvision/analysis/visualization/plt_utils.py
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
def divide_figure_to_grid(
    matrix: List[List[int]] = [
        [0, 1, 1, 1, 2, 2, 2],
        [3, 4, 5, 6, 2, 2, 2],
        [3, 7, 7, 7, 2, 2, 2],
        [3, 8, 8, 12, 2, 2, 2],
        [3, 10, 11, 12, 2, 2, 2],
    ],
    as_matrix: bool = False,
    alpha: float = 0,
    constrained_layout: bool = False,
    fig: Optional[plt.Figure] = None,
    figsize: List[float] = [10, 10],
    projection: Optional[Union[str, Dict[Any, str]]] = None,
    wspace: float = 0.1,
    hspace: float = 0.3,
    no_spines: bool = False,
    keep_nan_axes: bool = False,
    fontsize: int = 5,
    reshape_order: Literal["C", "F", "A", "K", "special"] = "F",
) -> Tuple[plt.Figure, Union[Dict[int, plt.Axes], np.ndarray]]:
    """
    Creates a figure grid specified by the arrangement of unique elements in a matrix.

    Each unique non-NaN value in ``matrix`` becomes one axis spanning the
    bounding box of the cells holding that value.

    Info:
        `matplotlib` now also has `matplotlib.pyplot.subplot_mosaic` which does the
        same thing and should be used instead.

    Args:
        matrix: Grid layout specification. The list defaults are only read,
            never mutated.
        as_matrix: If True, return axes as a numpy array.
        alpha: Alpha value for axis patches.
        constrained_layout: Use constrained layout for the figure.
        fig: Existing figure to use. If None, a new figure is created.
        figsize: Figure size in inches (only used when a new figure is created).
        projection: Projection type for the axes: a single value for all axes,
            or a dict mapping matrix values to per-axis projections.
        wspace: Width space between subplots.
        hspace: Height space between subplots.
        no_spines: If True, remove spines from all axes.
        keep_nan_axes: If True, keep axes for NaN values in the matrix.
        fontsize: Font size for tick labels.
        reshape_order: Order to use when reshaping the axes array. The value
            "special" instead places each axis at the matrix cells it occupies.

    Returns:
        A tuple containing the figure and a dictionary (keyed by matrix value)
        or numpy array of axes.
    """

    def _array_to_slice(array):
        # Smallest contiguous slice covering every index in `array`.
        step = 1
        start = array.min()
        stop = array.max() + 1
        return slice(start, stop, step)

    fig = plt.figure(figsize=figsize) if fig is None else fig
    fig.set_constrained_layout(constrained_layout)
    # Mask NaN/inf cells so they are excluded from the set of axis keys below.
    matrix = np.ma.masked_invalid(matrix)
    rows, columns = matrix.shape

    gs = GridSpec(rows, columns, figure=fig, hspace=hspace, wspace=wspace)

    axes = {}
    # np.unique returns sorted values, so `axes` is filled in ascending key order.
    for val in np.unique(matrix[~matrix.mask]):
        # The new axis covers the bounding box of all cells equal to `val`.
        _row_ind, _col_ind = np.where(matrix == val)
        _row_slc, _col_slc = _array_to_slice(_row_ind), _array_to_slice(_col_ind)

        _projection = projection[val] if isinstance(projection, dict) else projection
        ax = fig.add_subplot(gs[_row_slc, _col_slc], projection=_projection)
        ax.patch.set_alpha(alpha)
        ax.tick_params(axis="both", which="major", labelsize=fontsize)
        # NOTE(review): spines are trimmed only when no projection was given at
        # all; with a dict of projections this branch is skipped for every axis.
        if projection is None:
            ax.spines["top"].set_visible(False)
            ax.spines["right"].set_visible(False)
        axes[val] = ax

    if keep_nan_axes:
        # One single-cell, spine-less placeholder axis per NaN cell; these are not
        # added to the returned `axes`.
        # NOTE(review): `matrix` is a masked array here; confirm np.isnan still
        # reports the masked NaN cells with the numpy version in use.
        for _row_ind, _col_ind in np.array(np.where(np.isnan(matrix))).T:
            _row_slc, _col_slc = _array_to_slice(_row_ind), _array_to_slice(_col_ind)
            ax = fig.add_subplot(gs[_row_slc, _col_slc])
            ax.patch.set_alpha(alpha)
            rm_spines(ax, ("left", "right", "top", "bottom"))

    if no_spines:
        for ax in axes.values():
            rm_spines(ax, rm_xticks=True, rm_yticks=True)

    if as_matrix:
        if reshape_order == "special":
            # reshape based on elements in matrix: each axis is written into every
            # cell holding its key; unused cells stay NaN
            ax_matrix = np.ones(matrix.shape, dtype=object) * np.nan
            for key, value in axes.items():
                _row_ind, _col_ind = np.where(matrix == key)
                ax_matrix[_row_ind, _col_ind] = value
            axes = ax_matrix
        else:
            # reshape without considering specific locations: axes (in key order)
            # fill the first len(axes) slots, the rest are NaN
            _axes = np.ones(matrix.shape, dtype=object).flatten() * np.nan
            _axes[: len(axes)] = np.array(list(axes.values()), dtype=object)
            axes = _axes.reshape(matrix.shape, order=reshape_order)

    return fig, axes

flyvision.analysis.visualization.plt_utils.scale

scale(x, y, wpad=0.1, hpad=0.1, wspace=0, hspace=0)

Scale x and y coordinates to fit within a specified padding and spacing.

Parameters:

Name Type Description Default
x ndarray

Array of x-coordinates.

required
y ndarray

Array of y-coordinates.

required
wpad float

Width padding.

0.1
hpad float

Height padding.

0.1
wspace float

Width space between elements.

0
hspace float

Height space between elements.

0

Returns:

Type Description
Tuple[ndarray, ndarray, float, float]

A tuple containing scaled x, y coordinates, width, and height.

Source code in flyvision/analysis/visualization/plt_utils.py
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
def scale(
    x: np.ndarray,
    y: np.ndarray,
    wpad: float = 0.1,
    hpad: float = 0.1,
    wspace: float = 0,
    hspace: float = 0,
) -> Tuple[np.ndarray, np.ndarray, float, float]:
    """
    Scale x and y coordinates to fit within a specified padding and spacing.

    Each coordinate array is min-max normalised to [0, 1] and then mapped into
    the padded range, leaving room for an element of the returned width/height.

    Args:
        x: Array of x-coordinates.
        y: Array of y-coordinates.
        wpad: Width padding.
        hpad: Height padding.
        wspace: Width space between elements.
        hspace: Height space between elements.

    Returns:
        A tuple containing scaled x, y coordinates, width, and height.
    """
    x, y = np.array(x), np.array(y)
    assert len(x) == len(y)

    def _min_max(values: np.ndarray) -> np.ndarray:
        # Proper min-max normalisation (divide by max - min, not max + min).
        # A constant array maps to zeros instead of dividing by zero.
        lo = np.min(values)
        span = np.max(values) - lo
        if span == 0:
            return np.zeros(values.shape, dtype=float)
        return (values - lo) / span

    # Min-Max Scale x-Positions.
    width = (1 - 2 * wpad) / (2 * np.ceil(np.median(np.unique(x))))
    width = width - wspace * width
    x = _min_max(x) * (1 - width) * (1 - wpad * 2) + wpad

    # Min-Max Scale y-Positions.
    height = (1 - 2 * hpad) / (2 * np.ceil(np.median(np.unique(y))))
    height = height - hspace * height
    y = _min_max(y) * (1 - height) * (1 - hpad * 2) + hpad
    return x, y, width, height

flyvision.analysis.visualization.plt_utils.ax_scatter

ax_scatter(
    x,
    y,
    fig=None,
    figsize=[7, 7],
    hspace=0,
    wspace=0,
    hpad=0.1,
    wpad=0.1,
    alpha=0,
    zorder=10,
    projection=None,
    labels=None,
)

Creates scattered axes in a given or new figure.

Parameters:

Name Type Description Default
x ndarray

Array of x-coordinates.

required
y ndarray

Array of y-coordinates.

required
fig Optional[Figure]

Existing figure to use. If None, a new figure is created.

None
figsize List[float]

Figure size in inches.

[7, 7]
hspace float

Height space between subplots.

0
wspace float

Width space between subplots.

0
hpad float

Height padding.

0.1
wpad float

Width padding.

0.1
alpha float

Alpha value for axis patches.

0
zorder int

Z-order for axis patches.

10
projection Optional[str]

Projection type for the axes.

None
labels Optional[List[str]]

List of labels for each axis.

None

Returns:

Type Description
Tuple[Figure, List[Axes], List[List[float]]]

A tuple containing the figure, a list of axes, and a list of center coordinates.

Source code in flyvision/analysis/visualization/plt_utils.py
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
def ax_scatter(
    x: np.ndarray,
    y: np.ndarray,
    fig: Optional[plt.Figure] = None,
    figsize: List[float] = [7, 7],
    hspace: float = 0,
    wspace: float = 0,
    hpad: float = 0.1,
    wpad: float = 0.1,
    alpha: float = 0,
    zorder: int = 10,
    projection: Optional[str] = None,
    labels: Optional[List[str]] = None,
) -> Tuple[plt.Figure, List[plt.Axes], List[List[float]]]:
    """
    Place one axis at each (x, y) coordinate of a given or new figure.

    The coordinates are first rescaled with ``scale`` so that all axes fit into
    the padded figure area.

    Args:
        x: Array of x-coordinates.
        y: Array of y-coordinates.
        fig: Existing figure to use. If None (or falsy), a new figure is created.
        figsize: Figure size in inches, used only for a newly created figure.
        hspace: Height space between subplots.
        wspace: Width space between subplots.
        hpad: Height padding.
        wpad: Width padding.
        alpha: Alpha value for axis patches.
        zorder: Z-order applied to every new axis.
        projection: Projection type for the axes.
        labels: Optional label for each axis, matched by position.

    Returns:
        A tuple of the figure, the list of new axes, and for each axis its
        center coordinates as reported by ``get_ax_positions``.
    """
    x, y, width, height = scale(x, y, wpad, hpad, wspace, hspace)

    if not fig:
        fig = plt.figure(figsize=figsize)

    axes = []
    for idx, (left, bottom) in enumerate(zip(x, y)):
        label = None if labels is None else labels[idx]
        new_ax = fig.add_axes(
            [left, bottom, width, height],
            projection=projection,
            label=label,
        )
        new_ax.set_zorder(zorder)
        new_ax.patch.set_alpha(alpha)
        axes.append(new_ax)

    # Centers are collected only after every axis has been added.
    centers = []
    for placed_ax in axes:
        _, (ax_center, _, _) = get_ax_positions(placed_ax)
        centers.append(ax_center.flatten().tolist())
    return fig, axes, centers

flyvision.analysis.visualization.plt_utils.color_labels

color_labels(labels, color, ax)
Source code in flyvision/analysis/visualization/plt_utils.py
1426
1427
1428
def color_labels(labels: List[str], color, ax):
    """
    Color every label text from ``labels`` in the given axes.

    Each entry is handled by ``color_label``, so texts, major tick labels and
    axis labels whose text matches are recolored.

    Args:
        labels: Label texts to color.
        color: The color to apply to each matching label.
        ax: The matplotlib axes object.
    """
    for text in labels:
        color_label(text, color, ax)

flyvision.analysis.visualization.plt_utils.color_label

color_label(label, color, ax)

Color a specific label in the given axes.

Parameters:

Name Type Description Default
label str

The label text to color.

required
color Union[str, Tuple[float, float, float]]

The color to apply to the label.

required
ax Axes

The matplotlib axes object.

required
Source code in flyvision/analysis/visualization/plt_utils.py
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
def color_label(
    label: str, color: Union[str, Tuple[float, float, float]], ax: plt.Axes
) -> None:
    """
    Color every occurrence of a label text in the given axes.

    Checks free text artists, the primary label of each major x/y tick, and
    the x/y axis labels; every one whose text equals ``label`` is recolored.

    Args:
        label: The label text to color.
        color: The color to apply to the label.
        ax: The matplotlib axes object.
    """
    candidates = [
        *ax.texts,
        *(tick.label1 for tick in ax.xaxis.get_major_ticks()),
        *(tick.label1 for tick in ax.yaxis.get_major_ticks()),
        ax.xaxis.get_label(),
        ax.yaxis.get_label(),
    ]
    for artist in candidates:
        if artist.get_text() == label:
            artist.set_color(color)

flyvision.analysis.visualization.plt_utils.boldify_labels

boldify_labels(labels, ax)

Make specific labels bold in the given axes.

Parameters:

Name Type Description Default
labels List[str]

List of label texts to make bold.

required
ax Axes

The matplotlib axes object.

required
Source code in flyvision/analysis/visualization/plt_utils.py
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
def boldify_labels(labels: List[str], ax: plt.Axes) -> None:
    """
    Render every occurrence of the given label texts in bold.

    Checks free text artists, the primary label of each major x/y tick, and
    the x/y axis labels; every one whose text is contained in ``labels`` gets
    a bold font weight.

    Args:
        labels: List of label texts to make bold.
        ax: The matplotlib axes object.
    """
    candidates = [
        *ax.texts,
        *(tick.label1 for tick in ax.xaxis.get_major_ticks()),
        *(tick.label1 for tick in ax.yaxis.get_major_ticks()),
        ax.xaxis.get_label(),
        ax.yaxis.get_label(),
    ]
    for artist in candidates:
        if artist.get_text() in labels:
            artist.set_weight("bold")

flyvision.analysis.visualization.plt_utils.scatter_on_violins_or_bars

scatter_on_violins_or_bars(
    data,
    ax,
    xticks=None,
    indices=None,
    s=5,
    zorder=100,
    facecolor="none",
    edgecolor="k",
    linewidth=0.5,
    alpha=0.35,
    uniform=[-0.35, 0.35],
    seed=42,
    marker="o",
    **kwargs
)

Scatter data points on violin or bar plots.

Parameters:

Name Type Description Default
data ndarray

Array of shape (n_samples, n_random_variables).

required
ax Axes

Matplotlib axes object to plot on.

required
xticks Optional[ndarray]

X-axis tick positions.

None
indices Optional[ndarray]

Selection along sample dimension.

None
s float

Marker size.

5
zorder int

Z-order for plotting.

100
facecolor Union[str, Tuple[float, float, float], List[Union[str, Tuple[float, float, float]]]]

Color(s) for marker face.

'none'
edgecolor Union[str, Tuple[float, float, float], List[Union[str, Tuple[float, float, float]]]]

Color(s) for marker edge.

'k'
linewidth float

Width of marker edge.

0.5
alpha float

Transparency of markers.

0.35
uniform List[float]

Range for uniform distribution of x-positions.

[-0.35, 0.35]
seed int

Random seed for reproducibility.

42
marker str

Marker style.

'o'
**kwargs

Additional keyword arguments for plt.scatter.

{}

Returns:

Type Description
None

None

Source code in flyvision/analysis/visualization/plt_utils.py
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
def scatter_on_violins_or_bars(
    data: np.ndarray,
    ax: Axes,
    xticks: Optional[np.ndarray] = None,
    indices: Optional[np.ndarray] = None,
    s: float = 5,
    zorder: int = 100,
    facecolor: Union[
        str, Tuple[float, float, float], List[Union[str, Tuple[float, float, float]]]
    ] = "none",
    edgecolor: Union[
        str, Tuple[float, float, float], List[Union[str, Tuple[float, float, float]]]
    ] = "k",
    linewidth: float = 0.5,
    alpha: float = 0.35,
    uniform: List[float] = [-0.35, 0.35],
    seed: int = 42,
    marker: str = "o",
    **kwargs,
) -> None:
    """
    Scatter data points on violin or bar plots.

    Each selected sample (row of ``data``) is drawn with one ``ax.scatter`` call;
    its x-positions are ``xticks`` jittered by a uniform random offset.

    Args:
        data: Array of shape (n_samples, n_random_variables). 1-D input is
            promoted to a single sample via ``np.atleast_2d``.
        ax: Matplotlib axes object to plot on.
        xticks: X-axis tick positions. Defaults to ``ax.get_xticks()``.
        indices: Selection along sample dimension. Defaults to all samples.
        s: Marker size.
        zorder: Z-order for plotting.
        facecolor: Color for marker face. Either a single color (string or
            numeric RGB(A) tuple) used for every sample, or a sequence with
            exactly one color per sample (row of ``data``). Per-sample colors
            are looked up by sample index, so they stay attached to the right
            sample when ``indices`` selects a subset.
        edgecolor: Color for marker edge; same conventions as ``facecolor``.
        linewidth: Width of marker edge.
        alpha: Transparency of markers.
        uniform: Range (low, high) of the uniform jitter added to x-positions.
            Not modified by this function.
        seed: Random seed for reproducibility of the jitter.
        marker: Marker style.
        **kwargs: Additional keyword arguments for ``ax.scatter``.

    Returns:
        None
    """
    random = np.random.RandomState(seed)

    if xticks is None:
        xticks = ax.get_xticks()
    data = np.atleast_2d(data)
    n_samples = len(data)
    indices = indices if indices is not None else range(n_samples)

    def _is_single_color(color) -> bool:
        # Strings and non-iterables are one color. A sequence is only treated as
        # per-sample when it has one entry per sample AND those entries are not
        # plain numbers (a numeric sequence is an RGB(A) color, not a color list).
        if isinstance(color, str) or not isinstance(color, Iterable):
            return True
        if len(color) != n_samples:
            return True
        return all(isinstance(c, (int, float, np.number)) for c in color)

    face_is_single = _is_single_color(facecolor)
    edge_is_single = _is_single_color(edgecolor)

    for sample_index in indices:
        ax.scatter(
            xticks + random.uniform(*uniform, size=len(xticks)),
            data[sample_index],
            s=s,
            zorder=zorder,
            # Index per-sample colors by the sample, not the loop position.
            facecolor=facecolor if face_is_single else facecolor[sample_index],
            edgecolor=edgecolor if edge_is_single else edgecolor[sample_index],
            linewidth=linewidth,
            alpha=alpha,
            marker=marker,
            **kwargs,
        )

flyvision.analysis.visualization.plt_utils.set_spine_tick_params

set_spine_tick_params(
    ax,
    spinewidth=0.25,
    tickwidth=0.25,
    ticklength=3,
    ticklabelpad=2,
    spines=("top", "right", "bottom", "left"),
)

Set spine and tick widths and lengths.

Parameters:

Name Type Description Default
ax Axes

Matplotlib axes object.

required
spinewidth float

Width of spines.

0.25
tickwidth float

Width of ticks.

0.25
ticklength float

Length of ticks.

3
ticklabelpad float

Padding between ticks and labels.

2
spines Tuple[str, ...]

Tuple of spine names to adjust.

('top', 'right', 'bottom', 'left')

Returns:

Type Description
None

None

Source code in flyvision/analysis/visualization/plt_utils.py
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
def set_spine_tick_params(
    ax: Axes,
    spinewidth: float = 0.25,
    tickwidth: float = 0.25,
    ticklength: float = 3,
    ticklabelpad: float = 2,
    spines: Tuple[str, ...] = ("top", "right", "bottom", "left"),
) -> None:
    """
    Set spine line widths and the width, length and label padding of ticks.

    Args:
        ax: Matplotlib axes object whose spines and ticks are adjusted.
        spinewidth: Line width applied to every spine named in ``spines``.
        tickwidth: Tick width, applied to both x and y axes.
        ticklength: Tick length, applied to both x and y axes.
        ticklabelpad: Padding between ticks and their labels.
        spines: Names of the spines (keys of ``ax.spines``) to adjust.

    Returns:
        None
    """
    for spine_name in spines:
        spine = ax.spines[spine_name]
        spine.set_linewidth(spinewidth)

    tick_settings = dict(width=tickwidth, length=ticklength, pad=ticklabelpad)
    ax.tick_params(axis="both", **tick_settings)

flyvision.analysis.visualization.plt_utils.scatter_on_violins_with_best

scatter_on_violins_with_best(
    data,
    ax,
    scatter_best,
    scatter_all,
    xticks=None,
    facecolor="none",
    edgecolor="k",
    best_scatter_alpha=1.0,
    all_scatter_alpha=0.35,
    best_index=None,
    best_color=None,
    all_marker="o",
    best_marker="o",
    linewidth=0.5,
    best_linewidth=0.75,
    uniform=[-0.35, 0.35],
    **kwargs
)

Scatter data points on violin plots, optionally highlighting the best point.

Parameters:

Name Type Description Default
data ndarray

Array of shape (n_samples, n_variables).

required
ax Axes

Matplotlib axes object to plot on.

required
scatter_best bool

Whether to scatter the best point.

required
scatter_all bool

Whether to scatter all points.

required
xticks Optional[ndarray]

X-axis tick positions.

None
facecolor Union[str, Tuple[float, float, float]]

Color for marker face.

'none'
edgecolor Union[str, Tuple[float, float, float]]

Color for marker edge.

'k'
best_scatter_alpha float

Alpha for best point.

1.0
all_scatter_alpha float

Alpha for all other points.

0.35
best_index Optional[int]

Index of the best point.

None
best_color Optional[Union[str, Tuple[float, float, float]]]

Color for the best point.

None
all_marker str

Marker style for all points.

'o'
best_marker str

Marker style for the best point.

'o'
linewidth float

Width of marker edge for all points.

0.5
best_linewidth float

Width of marker edge for the best point.

0.75
uniform List[float]

Range for uniform distribution of x-positions.

[-0.35, 0.35]

Returns:

Type Description
None

None

Source code in flyvision/analysis/visualization/plt_utils.py
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
def scatter_on_violins_with_best(
    data: np.ndarray,
    ax: Axes,
    scatter_best: bool,
    scatter_all: bool,
    xticks: Optional[np.ndarray] = None,
    facecolor: Union[str, Tuple[float, float, float]] = "none",
    edgecolor: Union[str, Tuple[float, float, float]] = "k",
    best_scatter_alpha: float = 1.0,
    all_scatter_alpha: float = 0.35,
    best_index: Optional[int] = None,
    best_color: Optional[Union[str, Tuple[float, float, float]]] = None,
    all_marker: str = "o",
    best_marker: str = "o",
    linewidth: float = 0.5,
    best_linewidth: float = 0.75,
    uniform: List[float] = [-0.35, 0.35],
    **kwargs,
) -> None:
    """
    Scatter data points on violin plots, optionally highlighting the best point.

    When both ``scatter_all`` and ``scatter_best`` are set, the best sample is
    excluded from the jittered "all" scatter and drawn separately, unjittered,
    with a larger marker on top.

    Args:
        data: Array of shape (n_samples, n_variables).
        ax: Matplotlib axes object to plot on.
        scatter_best: Whether to scatter the best point.
        scatter_all: Whether to scatter all points.
        xticks: X-axis tick positions.
        facecolor: Color for marker face of the non-best points.
        edgecolor: Color for marker edge of the non-best points.
        best_scatter_alpha: Alpha for best point.
        all_scatter_alpha: Alpha for all other points.
        best_index: Index (row of ``data``) of the best point. Negative
            indices are allowed. Required if ``scatter_best``.
        best_color: Face and edge color for the best point. Required if
            ``scatter_best``.
        all_marker: Marker style for all points.
        best_marker: Marker style for the best point.
        linewidth: Width of marker edge for all points.
        best_linewidth: Width of marker edge for the best point.
        uniform: Range for uniform distribution of x-positions.
        **kwargs: Accepted for call-site compatibility; currently ignored.

    Returns:
        None

    Raises:
        AssertionError: If ``scatter_best`` is set but ``best_index`` or
            ``best_color`` is missing. Checked before anything is drawn.
        IndexError: If ``best_index`` is out of range for ``data``.
    """
    if scatter_best:
        # Validate everything up front so a bad call never leaves a half-drawn
        # axes. Explicit raises (not `assert`) survive `python -O`.
        if best_index is None:
            raise AssertionError("`best_index` must be provided if `scatter_best=True`")
        if best_color is None:
            raise AssertionError("`best_color` must be provided if `scatter_best=True`")
        # Normalise negative indices (e.g. -1) so the exclusion below matches;
        # range indexing raises IndexError for out-of-range values.
        best_index = range(len(np.atleast_2d(data)))[best_index]

    if scatter_all:
        if scatter_best:
            n_samples = len(np.atleast_2d(data))
            indices = [i for i in range(n_samples) if i != best_index]
            # Keep the remaining points below the highlighted best point.
            all_zorder = 10
        else:
            indices = None
            all_zorder = 100
        scatter_on_violins_or_bars(
            data,
            ax,
            xticks=xticks,
            indices=indices,
            zorder=all_zorder,
            facecolor=facecolor,
            edgecolor=edgecolor,
            alpha=all_scatter_alpha,
            uniform=uniform,
            marker=all_marker,
            linewidth=linewidth,
        )

    if scatter_best:
        scatter_on_violins_or_bars(
            data,
            ax,
            xticks=xticks,
            indices=[best_index],
            alpha=best_scatter_alpha,
            linewidth=best_linewidth,
            edgecolor=best_color,
            facecolor=best_color,
            uniform=[0, 0],
            s=7.5,
            zorder=11,
            marker=best_marker,
        )

flyvision.analysis.visualization.plt_utils.trim_axis

trim_axis(ax, xaxis=True, yaxis=True)

Trim axis to show only the range of data.

Parameters:

Name Type Description Default
ax Axes

Matplotlib axes object.

required
xaxis bool

Whether to trim x-axis.

True
yaxis bool

Whether to trim y-axis.

True

Returns:

Type Description
None

None

Source code in flyvision/analysis/visualization/plt_utils.py
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
def trim_axis(ax: Axes, xaxis: bool = True, yaxis: bool = True) -> None:
    """
    Trim axis spines and ticks to the outermost ticks inside the view limits.

    For each selected axis, the spines are bounded to the first and last
    (major or minor) tick within the current limits. Major and minor ticks
    outside that span are removed. If no tick lies within the limits, that
    axis is left untouched.

    Args:
        ax: Matplotlib axes object.
        xaxis: Whether to trim x-axis (top and bottom spines).
        yaxis: Whether to trim y-axis (left and right spines).

    Returns:
        None
    """
    if xaxis:
        _trim_ticks_and_spines(
            ax.get_xticks,
            ax.set_xticks,
            ax.get_xlim(),
            (ax.spines["top"], ax.spines["bottom"]),
        )
    if yaxis:
        _trim_ticks_and_spines(
            ax.get_yticks,
            ax.set_yticks,
            ax.get_ylim(),
            (ax.spines["left"], ax.spines["right"]),
        )


def _trim_ticks_and_spines(get_ticks, set_ticks, limits, spines) -> None:
    """Bound ``spines`` and ticks of one axis to the ticks inside ``limits``."""
    major = np.asarray(get_ticks())
    minor = np.asarray(get_ticks(minor=True))
    lower, upper = min(limits), max(limits)
    all_ticks = np.sort(np.concatenate((minor, major)))
    visible = all_ticks[(all_ticks >= lower) & (all_ticks <= upper)]
    if visible.size == 0:
        # No tick inside the view limits: nothing to trim to.
        return
    first, last = visible[0], visible[-1]
    for spine in spines:
        spine.set_bounds(first, last)
    set_ticks(major[(major >= first) & (major <= last)])
    set_ticks(minor[(minor >= first) & (minor <= last)], minor=True)

flyvision.analysis.visualization.plt_utils.display_significance_value

display_significance_value(
    ax,
    pvalue,
    y,
    x0=None,
    x1=None,
    ticklabel=None,
    bar_width=0.7,
    pthresholds={0.01: "***", 0.05: "**", 0.1: "*"},
    fontsize=8,
    annotate_insignificant="",
    append_tick=False,
    show_bar=False,
    other_ax=None,
    bar_height_ylim_ratio=0.01,
    linewidth=0.5,
    annotate_pthresholds=True,
    loc_pthresh_annotation=(0.1, 0.1),
    location="above",
    asterisk_offset=None,
)

Display a significance value annotation along x at height y.

Parameters:

Name Type Description Default
ax Axes

Matplotlib axes object.

required
pvalue float

P-value to display.

required
y float

Height to put text.

required
x0 Optional[float]

Left edge of bar if show_bar is True.

None
x1 Optional[float]

Right edge of bar if show_bar is True.

None
ticklabel Optional[str]

Tick label to annotate.

None
bar_width float

Width of the bar.

0.7
pthresholds Dict[float, str]

Dictionary of p-value thresholds and corresponding annotations.

{0.01: '***', 0.05: '**', 0.1: '*'}
fontsize int

Font size for annotations.

8
annotate_insignificant str

Annotation for insignificant p-values.

''
append_tick bool

Whether to append annotation to tick label.

False
show_bar bool

Whether to draw a bracket bar from x0 to x1 at the annotation (raised for location "above", lowered for "below").

False
other_ax Optional[Axes]

Another axes object to get tick labels from.

None
bar_height_ylim_ratio float

Ratio of bar height to y-axis limits.

0.01
linewidth float

Line width for the bar.

0.5
annotate_pthresholds bool

Whether to annotate p-value thresholds.

True
loc_pthresh_annotation Tuple[float, float]

Location of p-value threshold annotation.

(0.1, 0.1)
location Literal['above', 'below']

Location of annotation (“above” or “below”).

'above'
asterisk_offset Optional[float]

Offset for asterisk annotation.

None

Returns:

Type Description
None

None

Source code in flyvision/analysis/visualization/plt_utils.py
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
def display_significance_value(
    ax: Axes,
    pvalue: float,
    y: float,
    x0: Optional[float] = None,
    x1: Optional[float] = None,
    ticklabel: Optional[str] = None,
    bar_width: float = 0.7,
    pthresholds: Dict[float, str] = {0.01: "***", 0.05: "**", 0.1: "*"},
    fontsize: int = 8,
    annotate_insignificant: str = "",
    append_tick: bool = False,
    show_bar: bool = False,
    other_ax: Optional[Axes] = None,
    bar_height_ylim_ratio: float = 0.01,
    linewidth: float = 0.5,
    annotate_pthresholds: bool = True,
    loc_pthresh_annotation: Tuple[float, float] = (0.1, 0.1),
    location: Literal["above", "below"] = "above",
    asterisk_offset: Optional[float] = None,
) -> None:
    """
    Display a significance value annotation along x at height y.

    The horizontal position comes either from explicit bar edges ``(x0, x1)``
    or from the x tick label whose text equals ``ticklabel`` (a bar of
    ``bar_width`` centred on that tick). The symbol shown is the
    ``pthresholds`` entry of the smallest threshold that ``pvalue`` is below.

    Args:
        ax: Matplotlib axes object.
        pvalue: P-value to display.
        y: Height to put text.
        x0: Left edge of bar. Used together with ``x1``.
        x1: Right edge of bar. Used together with ``x0``.
        ticklabel: Tick label to annotate. Used for positioning when
            ``(x0, x1)`` is not given, and as target when ``append_tick``.
        bar_width: Width of the bar when positioning by ``ticklabel``.
        pthresholds: Dictionary of p-value thresholds and corresponding annotations.
        fontsize: Font size for annotations.
        annotate_insignificant: Annotation for insignificant p-values
            (empty string: no annotation).
        append_tick: Whether to append annotation to tick label instead of
            drawing it on the axes.
        show_bar: Whether to draw a bracket bar from x0 to x1 at the
            annotation, raised for "above" or lowered for "below".
        other_ax: Axes whose x tick labels are used when ``ax`` has none.
        bar_height_ylim_ratio: Ratio of bar height to y-axis limits.
        linewidth: Line width for the bar.
        annotate_pthresholds: Whether to annotate p-value thresholds.
        loc_pthresh_annotation: Location (axes fraction) of p-value threshold annotation.
        location: Location of annotation ("above" or "below").
        asterisk_offset: Vertical offset added to ``y`` for significant
            annotations. Any falsy value (None or 0) selects the default:
            -0.1 for "above", -0.05 for "below".

    Returns:
        None

    Raises:
        ValueError: If neither ``(x0, x1)`` nor ``(ticklabel, bar_width)`` is
            given, if ``location`` is invalid, or if ``ticklabel`` is not
            among the x tick labels.
        AssertionError: If no x tick labels are found when one is needed.
    """
    has_edges = x0 is not None and x1 is not None
    if not has_edges and (ticklabel is None or bar_width is None):
        raise ValueError("specify (x0, x1) or (ticklabel, bar_width)")

    if location == "above":
        va = "bottom"
        default_offset = -0.1
    elif location == "below":
        va = "top"
        default_offset = -0.05
    else:
        raise ValueError(f"location {location}")
    # `or` semantics: an explicit 0 also falls back to the default offset.
    asterisk_offset = asterisk_offset or default_offset

    # Look up the tick label artist only when needed: to derive the position
    # or to append the significance marker to its text.
    tick = None
    if ticklabel is not None and (not has_edges or append_tick):
        tick = _find_xticklabel(ax, other_ax, ticklabel)

    if has_edges:
        x = (x0 + x1) / 2
    else:
        x, _ = tick.get_position()
        x0 = x - bar_width / 2
        x1 = x + bar_width / 2

    passed = [thresh for thresh in pthresholds if pvalue < thresh]

    if (passed or annotate_insignificant) and show_bar:
        ymin, ymax = ax.get_ylim()
        bar_height = (ymax - ymin) * bar_height_ylim_ratio
        # Bracket legs end at y; the crossbar is shifted towards `location`.
        y_bar = y + bar_height if location == "above" else y - bar_height
        ax.plot([x0, x0, x1, x1], [y, y_bar, y_bar, y], c="k", lw=linewidth)
        y = y_bar
        x = (x0 + x1) / 2

    if passed:
        label, label_y = pthresholds[min(passed)], y + asterisk_offset
    elif annotate_insignificant:
        label, label_y = annotate_insignificant, y
    else:
        label = None

    if label is not None:
        if tick is not None and append_tick:
            tick.set_text(f"{tick.get_text()}$^{{{label}}}$")
            ax.xaxis.set_ticklabels(ax.xaxis.get_ticklabels())
        else:
            ax.annotate(
                label,
                (x, label_y),
                fontsize=fontsize,
                ha="center",
                va=va,
            )

    if annotate_pthresholds:
        pthreshold_annotation = "\n".join(
            f"{symbol}p<{thresh:.2f}" for thresh, symbol in pthresholds.items()
        )
        ax.annotate(
            pthreshold_annotation,
            loc_pthresh_annotation,
            xycoords="axes fraction",
            fontsize=fontsize,
            va="bottom",
            ha="left",
        )


def _find_xticklabel(ax, other_ax, ticklabel):
    """Return the x tick label artist of ``ax`` (or ``other_ax``) reading ``ticklabel``."""
    ticklabels = ax.get_xticklabels()
    if not ticklabels and other_ax is not None:
        ticklabels = other_ax.get_xticklabels()
    if not ticklabels:
        raise AssertionError("no ticklables found")
    for tick in ticklabels:
        if tick.get_text() == ticklabel:
            return tick
    raise ValueError(f"ticklabel {ticklabel!r} not found among the x tick labels")

flyvision.analysis.visualization.plt_utils.display_pvalues

display_pvalues(
    ax,
    pvalues,
    ticklabels,
    data,
    location="above",
    bar_width=0.7,
    show_bar=True,
    bar_height_ylim_ratio=0.01,
    fontsize=6,
    annotate_insignificant="ns",
    loc_pthresh_annotation=(0.01, 0.01),
    append_tick=False,
    data_relative_offset=0.05,
    asterisk_offset=0,
    pthresholds={0.01: "***", 0.05: "**", 0.1: "*"},
)

Annotate all p-values from a dictionary of x-tick labels to p-values.

Parameters:

Name Type Description Default
ax Axes

Matplotlib axes object.

required
pvalues Dict[str, float]

Dictionary mapping x-tick labels to p-values.

required
ticklabels List[str]

List of x-tick labels.

required
data ndarray

Array of shape (random variables, …).

required
location Literal['above', 'below']

Location of annotation (“above” or “below”).

'above'
bar_width float

Width of the bar.

0.7
show_bar bool

Whether to draw a bracket bar at each annotation (raised for location "above", lowered for "below").

True
bar_height_ylim_ratio float

Ratio of bar height to y-axis limits.

0.01
fontsize int

Font size for annotations.

6
annotate_insignificant str

Annotation for insignificant p-values.

'ns'
loc_pthresh_annotation Tuple[float, float]

Location of p-value threshold annotation.

(0.01, 0.01)
append_tick bool

Whether to append annotation to tick label.

False
data_relative_offset float

Relative offset for annotation placement.

0.05
asterisk_offset float

Offset for asterisk annotation.

0
pthresholds Dict[float, str]

Dictionary of p-value thresholds and corresponding annotations.

{0.01: '***', 0.05: '**', 0.1: '*'}

Returns:

Type Description
None

None

Source code in flyvision/analysis/visualization/plt_utils.py
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
def display_pvalues(
    ax: Axes,
    pvalues: Dict[str, float],
    ticklabels: List[str],
    data: np.ndarray,
    location: Literal["above", "below"] = "above",
    bar_width: float = 0.7,
    show_bar: bool = True,
    bar_height_ylim_ratio: float = 0.01,
    fontsize: int = 6,
    annotate_insignificant: str = "ns",
    loc_pthresh_annotation: Tuple[float, float] = (0.01, 0.01),
    append_tick: bool = False,
    data_relative_offset: float = 0.05,
    asterisk_offset: float = 0,
    pthresholds: Dict[float, str] = {0.01: "***", 0.05: "**", 0.1: "*"},
) -> None:
    """
    Annotate all p-values from a dictionary of x-tick labels to p-values.

    Each annotation is placed ``data_relative_offset * (data range)`` beyond the
    max ("above") or min ("below") of the corresponding row of ``data``,
    clipped to the current y-limits. The y-limits are then expanded to fit.

    Args:
        ax: Matplotlib axes object.
        pvalues: Dictionary mapping x-tick labels to p-values.
        ticklabels: List of x-tick labels; position i corresponds to ``data[i]``.
        data: Array of shape (random variables, ...).
        location: Location of annotation ("above" or "below").
        bar_width: Width of the bar.
        show_bar: Whether to draw a bracket bar at each annotation.
        bar_height_ylim_ratio: Ratio of bar height to y-axis limits.
        fontsize: Font size for annotations.
        annotate_insignificant: Annotation for insignificant p-values.
        loc_pthresh_annotation: Location of p-value threshold annotation,
            which is drawn once for the whole axes.
        append_tick: Whether to append annotation to tick label.
        data_relative_offset: Relative offset for annotation placement.
        asterisk_offset: Offset for asterisk annotation.
        pthresholds: Dictionary of p-value thresholds and corresponding annotations.

    Returns:
        None

    Raises:
        ValueError: If a key of ``pvalues`` is not in ``ticklabels`` or
            ``location`` is invalid.
    """
    for key in pvalues:
        if key not in ticklabels:
            raise ValueError(f"pvalue key {key} is not a ticklabel")
    if location not in ("above", "below"):
        raise ValueError(f"location {location}")

    offset = data_relative_offset * np.abs(data.max() - data.min())

    ylim = ax.get_ylim()
    ticklabel_list = list(ticklabels)
    annotation_heights = []
    for i, (ticklabel, pvalue) in enumerate(pvalues.items()):
        values = data[ticklabel_list.index(ticklabel)]

        if location == "above":
            y = min(values.max() + offset, ylim[1])
        else:
            y = max(values.min() - offset, ylim[0])

        display_significance_value(
            ax,
            pvalue,
            y=y,
            ticklabel=str(ticklabel),
            bar_width=bar_width,
            show_bar=show_bar,
            bar_height_ylim_ratio=bar_height_ylim_ratio,
            fontsize=fontsize,
            annotate_insignificant=annotate_insignificant,
            loc_pthresh_annotation=loc_pthresh_annotation,
            # The threshold legend is identical for every call: draw it once
            # instead of stacking one copy per p-value.
            annotate_pthresholds=i == 0,
            append_tick=append_tick,
            location=location,
            asterisk_offset=asterisk_offset,
            pthresholds=pthresholds,
        )
        annotation_heights.append(y)

    ax.set_ylim(*get_lims([annotation_heights, ylim], 0.01))

flyvision.analysis.visualization.plt_utils.closest_divisors

closest_divisors(n)

Find the closest divisors of a number.

Parameters:

Name Type Description Default
n int

Number to find divisors for.

required

Returns:

Type Description
Tuple[int, int]

Tuple of closest divisors.

Source code in flyvision/analysis/visualization/plt_utils.py
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
def closest_divisors(n: int) -> Tuple[int, int]:
    """
    Find the pair of divisors of ``n`` whose difference is smallest.

    Args:
        n: Non-negative number to find divisors for.

    Returns:
        Tuple ``(a, b)`` with ``a * b == n`` and ``a <= b`` such that
        ``b - a`` is minimal. ``(1, 1)`` is returned for ``n == 0``.

    Raises:
        ValueError: If ``n`` is negative.
    """
    import math

    # The closest pair is formed by the largest divisor not exceeding sqrt(n),
    # so scan downward from the exact integer square root and stop at the
    # first hit. math.isqrt avoids the float rounding of int(n ** 0.5).
    for divisor in range(math.isqrt(int(n)), 0, -1):
        if n % divisor == 0:
            return divisor, n // divisor
    return 1, 1

flyvision.analysis.visualization.plt_utils.standalone_legend

standalone_legend(
    labels,
    colors,
    legend_elements=None,
    alpha=1,
    fontsize=6,
    fig=None,
    ax=None,
    lw=4,
    labelspacing=0.5,
    handlelength=2.0,
    n_cols=None,
    columnspacing=0.8,
    figsize=None,
    linestyles=None,
)

Create a standalone legend.

Parameters:

Name Type Description Default
labels List[str]

List of labels for legend entries.

required
colors List[Union[str, Tuple[float, float, float]]]

List of colors for legend entries.

required
legend_elements Optional[List]

List of custom legend elements.

None
alpha float

Alpha value for legend entries.

1
fontsize int

Font size for legend text.

6
fig Optional[Figure]

Existing figure to use.

None
ax Optional[Axes]

Existing axes to use.

None
lw float

Line width for legend entries.

4
labelspacing float

Vertical space between legend entries.

0.5
handlelength float

Length of the legend handles.

2.0
n_cols Optional[int]

Number of columns in the legend.

None
columnspacing float

Spacing between legend columns.

0.8
figsize Optional[Tuple[float, float]]

Figure size (width, height) in inches.

None
linestyles Optional[List[str]]

List of line styles for legend entries.

None

Returns:

Type Description
Tuple[Figure, Axes]

Tuple containing the figure and axes objects.

Source code in flyvision/analysis/visualization/plt_utils.py
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
def standalone_legend(
    labels: List[str],
    colors: List[Union[str, Tuple[float, float, float]]],
    legend_elements: Optional[List] = None,
    alpha: float = 1,
    fontsize: int = 6,
    fig: Optional[plt.Figure] = None,
    ax: Optional[Axes] = None,
    lw: float = 4,
    labelspacing: float = 0.5,
    handlelength: float = 2.0,
    n_cols: Optional[int] = None,
    columnspacing: float = 0.8,
    figsize: Optional[Tuple[float, float]] = None,
    linestyles: Optional[List[str]] = None,
) -> Tuple[plt.Figure, Axes]:
    """
    Create a standalone legend.

    Args:
        labels: List of labels for legend entries (used to build handles
            when ``legend_elements`` is None).
        colors: List of colors, one per label.
        legend_elements: List of custom legend handles. When given, the
            layout is computed from these handles.
        alpha: Alpha value for legend entries.
        fontsize: Font size for legend text.
        fig: Existing figure to use together with ``ax``. If only ``ax`` is
            given, ``ax.figure`` is used.
        ax: Existing axes to draw the legend on. If None, a new figure and
            axes are created.
        lw: Line width for legend entries.
        labelspacing: Vertical space between legend entries.
        handlelength: Length of the legend handles.
        n_cols: Number of columns in the legend. If None, derived from
            ``closest_divisors`` of the (even-rounded-down) entry count.
        columnspacing: Spacing between legend columns.
        figsize: Figure size (width, height) in inches, used only when a new
            figure is created. Defaults to 0.1 inch per column and per row.
        linestyles: List of line styles for legend entries.

    Returns:
        Tuple containing the figure and axes objects.
    """
    if legend_elements is None:
        from matplotlib.lines import Line2D

        legend_elements = [
            Line2D(
                [0],
                [0],
                color=colors[i],
                lw=lw,
                label=label,
                alpha=alpha,
                solid_capstyle="round",
                linestyle=linestyles[i] if linestyles is not None else "solid",
            )
            for i, label in enumerate(labels)
        ]

    # Size the layout from the handles actually drawn, not from `labels`,
    # which may be unrelated when custom `legend_elements` are passed.
    n_entries = len(legend_elements)
    if n_cols is None:
        n_rows, n_cols = closest_divisors(n_entries - (n_entries % 2))
    else:
        n_rows = int(np.ceil(n_entries / n_cols))

    if ax is None:
        figsize = figsize or [0.1 * n_cols, 0.1 * n_rows]
        fig, ax = plt.subplots(figsize=figsize)
    elif fig is None:
        # Draw into the caller's axes instead of silently creating a new figure.
        fig = ax.figure

    ax.legend(
        handles=legend_elements,
        loc="center",
        edgecolor="white",
        framealpha=1,
        fontsize=fontsize,
        labelspacing=labelspacing,
        handlelength=handlelength,
        columnspacing=columnspacing,
        ncol=n_cols,
    )
    rm_spines(ax, rm_yticks=True, rm_xticks=True)
    return fig, ax

flyvision.analysis.visualization.plt_utils.extend_arg

extend_arg(arg, argtype, r, default, dim=-1)

Extend an argument to the correct length for a given dimension.

Parameters:

Name Type Description Default
arg Union[Number, List[Number]]

Argument to extend.

required
argtype type

Type of the argument.

required
r ndarray

Reference array for shape.

required
default Any

Value returned when r is 1-D and arg neither matches its length nor is a single value.

required
dim int

Dimension to extend along.

-1

Returns:

Type Description
Union[List[Number], Number]

Extended argument.

Source code in flyvision/analysis/visualization/plt_utils.py
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
def extend_arg(
    arg: Union[Number, List[Number]],
    argtype: type,
    r: np.ndarray,
    default: Any,
    dim: int = -1,
) -> Union[List[Number], Number]:
    """
    Extend an argument to the correct length for a given dimension.

    Rules, in order:
      * ``arg`` is an ``argtype`` instance and ``r`` is >1-D: repeat it
        ``r.shape[dim]`` times.
      * ``arg`` is an iterable of length ``r.shape[dim]``, or ``r`` is 1-D and
        ``arg`` has a single element: return ``arg`` unchanged.
      * ``r`` is 1-D otherwise: return ``default``.
      * Otherwise raise ``ValueError``.

    Args:
        arg: Argument to extend.
        argtype: Type of a scalar argument.
        r: Reference array for shape.
        default: Value returned when ``r`` is 1-D and ``arg`` fits neither rule.
        dim: Dimension to extend along.

    Returns:
        Extended argument.

    Raises:
        ValueError: If ``r`` is >1-D and ``arg`` neither is an ``argtype``
            instance nor has length ``r.shape[dim]``.
    """
    r = np.asarray(r)

    if isinstance(arg, argtype) and r.ndim > 1:
        return [arg] * r.shape[dim]
    # Keep short-circuiting: np.asarray(arg) is only evaluated when needed.
    if (isinstance(arg, Iterable) and len(arg) == r.shape[dim]) or (
        r.ndim == 1 and np.asarray(arg).size == 1
    ):
        return arg
    if r.ndim == 1:
        return default
    type_name = getattr(argtype, "__name__", argtype)
    raise ValueError(
        f"arg must be either of type {type_name} or a list of length {r.shape[dim]}."
    )