Skip to content

Animations

flyvision.analysis.animations.imshow

Classes

flyvision.analysis.animations.imshow.Imshow

Bases: Animation

Animates an array of images using imshow.

Parameters:

Name Type Description Default
images ndarray

Array of images to animate (n_samples, n_frames, height, width).

required
fig Optional[Figure]

Existing Figure instance or None.

None
ax Optional[Axes]

Existing Axis instance or None.

None
update bool

Whether to update the canvas after an animation step. Must be False if this animation is composed with others.

True
figsize List[int]

Size of the figure.

[1, 1]
sleep float

Time to sleep between frames.

0.01
**kwargs

Additional arguments passed to plt.imshow.

{}

Attributes:

Name Type Description
fig Figure

The figure object.

ax Axes

The axes object.

kwargs dict

Additional arguments for imshow.

update bool

Whether to update the canvas after each step.

n_samples int

Number of samples in the images array.

frames int

Number of frames in each sample.

images ndarray

Array of images to animate.

sleep float

Time to sleep between frames.

img AxesImage

The image object created by imshow.

Note

The images array should have shape (n_samples, n_frames, height, width).

Source code in flyvision/analysis/animations/imshow.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
class Imshow(Animation):
    """Animates an array of images using imshow.

    Args:
        images: Array of images to animate (n_samples, n_frames, height, width).
        fig: Existing Figure instance or None.
        ax: Existing Axis instance or None.
        update: Whether to update the canvas after an animation step.
            Must be False if this animation is composed with others.
        figsize: Size of the figure.
        sleep: Time to sleep between frames. None disables the pause.
        **kwargs: Additional arguments passed to plt.imshow.

    Attributes:
        fig (plt.Figure): The figure object.
        ax (plt.Axes): The axes object.
        kwargs (dict): Additional arguments for imshow.
        update (bool): Whether to update the canvas after each step.
        n_samples (int): Number of samples in the images array.
        frames (int): Number of frames in each sample.
        images (np.ndarray): Array of images to animate.
        sleep (float): Time to sleep between frames.
        img (plt.AxesImage): The image object created by imshow. Only set
            once `init` has been called.

    Note:
        The `images` array should have shape (n_samples, n_frames, height, width).
    """

    def __init__(
        self,
        images: np.ndarray,
        fig: Optional[plt.Figure] = None,
        ax: Optional[plt.Axes] = None,
        update: bool = True,
        figsize: List[int] = [1, 1],
        sleep: float = 0.01,
        **kwargs,
    ) -> None:
        self.fig, self.ax = plt_utils.init_plot(
            # Hand over a copy so the shared default list is never exposed
            # to mutation by downstream code.
            figsize=list(figsize),
            fig=fig,
            ax=ax,
            position=[0, 0, 1, 1],
            set_axis_off=True,
        )
        self.kwargs = kwargs
        self.update = update
        # The two leading axes are interpreted as (sample, frame).
        self.n_samples, self.frames = images.shape[:2]
        self.images = images
        self.sleep = sleep
        # Initialise the base class exactly once, after the figure exists.
        super().__init__(None, self.fig)

    def init(self, frame: int = 0) -> None:
        """Initialize the animation.

        Creates the image artist from the given frame of the current batch
        sample and stores it in `self.img`.

        Args:
            frame: The initial frame to display.
        """
        self.img = self.ax.imshow(self.images[self.batch_sample, frame], **self.kwargs)
        if self.sleep is not None:
            sleep(self.sleep)

    def animate(self, frame: int) -> None:
        """Animate a single frame.

        Swaps the data of the existing image artist instead of creating a
        new one, so `init` must have been called before.

        Args:
            frame: The frame number to animate.
        """
        self.img.set_data(self.images[self.batch_sample, frame])

        if self.update:
            self.update_figure()

        if self.sleep is not None:
            sleep(self.sleep)
init
init(frame=0)

Initialize the animation.

Parameters:

Name Type Description Default
frame int

The initial frame to display.

0
Source code in flyvision/analysis/animations/imshow.py
66
67
68
69
70
71
72
73
74
def init(self, frame: int = 0) -> None:
    """Draw the starting frame and keep a handle to the image artist.

    Args:
        frame: Index of the frame shown first.
    """
    first_image = self.images[self.batch_sample, frame]
    self.img = self.ax.imshow(first_image, **self.kwargs)

    # A sleep of None means "do not pause at all".
    pause = self.sleep
    if pause is None:
        return
    sleep(pause)
animate
animate(frame)

Animate a single frame.

Parameters:

Name Type Description Default
frame int

The frame number to animate.

required
Source code in flyvision/analysis/animations/imshow.py
76
77
78
79
80
81
82
83
84
85
86
87
88
def animate(self, frame: int) -> None:
    """Show one frame by swapping the data of the existing image artist.

    Args:
        frame: Index of the frame to display.
    """
    current = self.images[self.batch_sample, frame]
    self.img.set_data(current)

    # Redraw only when this animation owns the canvas update.
    if self.update:
        self.update_figure()

    # A sleep of None means "do not pause at all".
    delay = self.sleep
    if delay is None:
        return
    sleep(delay)

flyvision.analysis.animations.hexscatter

Classes

flyvision.analysis.animations.hexscatter.HexScatter

Bases: Animation

Regular hex-scatter animation.

For hexals not on a regular hex grid, use the function pad_to_regular_hex.

Parameters:

Name Type Description Default
hexarray ndarray

Shape (n_samples, n_frames, 1, n_input_elements).

required
u Optional[List[float]]

List of u coordinates of elements to plot.

None
v Optional[List[float]]

List of v coordinates of elements to plot.

None
cranges Optional[List[float]]

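Per-sample absolute color limits (length n_samples). When given, the colors of sample i span -cranges[i] to +cranges[i], and vmin and vmax are ignored.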

None
vmin Optional[float]

Color minimal value.

None
vmax Optional[float]

Color maximal value.

None
fig Optional[Figure]

Existing Figure instance or None.

None
ax Optional[Axes]

Existing Axis instance or None.

None
batch_sample int

Batch sample to start from.

0
cmap Union[str, Colormap]

Colormap for the hex-scatter.

get_cmap('binary_r')
edgecolor Optional[str]

Edgecolor for the hexals. None for no edge.

None
update_edge_color bool

Whether to update the edgecolor after an animation step.

True
update bool

Whether to update the canvas after an animation step. Must be False if this animation is composed with others using AnimationCollector.

False
label str

Label of the animation. Formatted with the current sample and frame number per frame.

'Sample: {}\nFrame: {}'
labelxy Tuple[float, float]

Location of the label.

(0.1, 0.95)
fontsize float

Fontsize.

5
cbar bool

Display colorbar.

True
background_color str

Background color.

'none'
midpoint Optional[float]

Midpoint for diverging colormaps.

None

Attributes:

Name Type Description
fig Figure

Matplotlib figure instance.

ax Axes

Matplotlib axes instance.

background_color str

Background color.

hexarray ndarray

Hex array data.

cranges Optional[List[float]]

Color ranges.

vmin Optional[float]

Minimum value for color mapping.

vmax Optional[float]

Maximum value for color mapping.

midpoint Optional[float]

Midpoint for diverging colormaps.

kwargs dict

Additional keyword arguments.

batch_sample int

Batch sample index.

cmap

Colormap for the hex-scatter.

update bool

Whether to update the canvas after an animation step.

label str

Label template for the animation.

labelxy Tuple[float, float]

Label position.

label_text

Text object for the label.

n_samples int

Number of samples.

frames int

Number of frames.

extent int

Hex extent.

edgecolor Optional[str]

Edgecolor for the hexals.

update_edge_color bool

Whether to update the edgecolor.

fontsize float

Font size.

cbar bool

Whether to display colorbar.

u List[float]

U coordinates of elements to plot.

v List[float]

V coordinates of elements to plot.

Source code in flyvision/analysis/animations/hexscatter.py
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
class HexScatter(Animation):
    """Regular hex-scatter animation.

    For hexals not on a regular hex grid, use the function pad_to_regular_hex.

    Args:
        hexarray: Shape (n_samples, n_frames, 1, n_input_elements).
        u: List of u coordinates of elements to plot.
        v: List of v coordinates of elements to plot.
        cranges: Per-sample absolute color limits (n_samples). When given,
            sample i is colored from -cranges[i] to +cranges[i] and
            `vmin`/`vmax` are ignored.
        vmin: Color minimal value. Defaults to the minimum of the current
            sample.
        vmax: Color maximal value. Defaults to the maximum of the current
            sample.
        fig: Existing Figure instance or None.
        ax: Existing Axis instance or None.
        batch_sample: Batch sample to start from.
        cmap: Colormap for the hex-scatter.
        edgecolor: Edgecolor for the hexals. None for no edge.
        update_edge_color: Whether to update the edgecolor after an animation step.
        update: Whether to update the canvas after an animation step.
            Must be False if this animation is composed with others using
            AnimationCollector.
        label: Label of the animation. Formatted with the current sample and
            frame number per frame.
        labelxy: Location of the label.
        fontsize: Fontsize.
        cbar: Display colorbar.
        background_color: Background color.
        midpoint: Midpoint for diverging colormaps.
        **kwargs: Additional keyword arguments forwarded to
            `plots.hex_scatter`.

    Attributes:
        fig (Figure): Matplotlib figure instance.
        ax (Axes): Matplotlib axes instance.
        background_color (str): Background color.
        hexarray (np.ndarray): Hex array data.
        cranges (Optional[List[float]]): Color ranges.
        vmin (Optional[float]): Minimum value for color mapping.
        vmax (Optional[float]): Maximum value for color mapping.
        midpoint (Optional[float]): Midpoint for diverging colormaps.
        kwargs (dict): Additional keyword arguments.
        batch_sample (int): Batch sample index.
        cmap: Colormap for the hex-scatter.
        update (bool): Whether to update the canvas after an animation step.
        label (str): Label template for the animation.
        labelxy (Tuple[float, float]): Label position.
        label_text: Text object for the label.
        n_samples (int): Number of samples.
        frames (int): Number of frames.
        extent (int): Hex extent.
        edgecolor (Optional[str]): Edgecolor for the hexals.
        update_edge_color (bool): Whether to update the edgecolor.
        fontsize (float): Font size.
        cbar (bool): Whether to display colorbar.
        u (List[float]): U coordinates of elements to plot.
        v (List[float]): V coordinates of elements to plot.

    """

    def __init__(
        self,
        hexarray: np.ndarray,
        u: Optional[List[float]] = None,
        v: Optional[List[float]] = None,
        cranges: Optional[List[float]] = None,
        vmin: Optional[float] = None,
        vmax: Optional[float] = None,
        fig: Optional[Figure] = None,
        ax: Optional[Axes] = None,
        batch_sample: int = 0,
        cmap: Union[str, Colormap] = cm.get_cmap("binary_r"),
        edgecolor: Optional[str] = None,
        update_edge_color: bool = True,
        update: bool = False,
        label: str = "Sample: {}\nFrame: {}",
        labelxy: Tuple[float, float] = (0.1, 0.95),
        fontsize: float = 5,
        cbar: bool = True,
        background_color: str = "none",
        midpoint: Optional[float] = None,
        **kwargs,
    ):
        self.fig = fig
        self.ax = ax
        self.background_color = background_color
        self.hexarray = utils.tensor_utils.to_numpy(hexarray)
        self.cranges = cranges
        self.vmin = vmin
        self.vmax = vmax
        self.midpoint = midpoint
        self.kwargs = kwargs
        self.batch_sample = batch_sample
        self.cmap = cmap
        self.update = update
        self.label = label
        self.labelxy = labelxy
        self.label_text = None
        self.n_samples, self.frames = hexarray.shape[0:2]
        self.extent = utils.hex_utils.get_hextent(hexarray.shape[-1])
        self.edgecolor = edgecolor
        self.update_edge_color = update_edge_color
        self.fontsize = fontsize
        self.cbar = cbar
        # Fall back to the regular grid coordinates unless both are given.
        if u is None or v is None:
            u, v = utils.hex_utils.get_hex_coords(self.extent)
        self.u = u
        self.v = v
        super().__init__(None, self.fig)

    def _color_mapping(self):
        """Build the scalar mapper and norm for the current batch sample.

        Precedence of the color limits: `cranges` (symmetric around zero),
        then explicit `vmin`/`vmax`, then the min/max over all frames of the
        current sample.

        Returns:
            The `(scalarmapper, norm)` pair from `plt_utils.get_scalarmapper`.
        """
        if self.cranges is not None:
            crange = self.cranges[self.batch_sample]
            vmin, vmax = -crange, +crange
        else:
            # Only scan the whole sample when a limit is actually missing.
            sample = self.hexarray[self.batch_sample]
            vmin = self.vmin if self.vmin is not None else sample.min()
            vmax = self.vmax if self.vmax is not None else sample.max()
        return plt_utils.get_scalarmapper(
            scalarmapper=None,
            cmap=self.cmap,
            norm=None,
            vmin=vmin,
            vmax=vmax,
            midpoint=self.midpoint,
        )

    def _add_colorbar(self, norm) -> None:
        """Attach a colorbar for `norm` next to the hex-scatter axis."""
        plt_utils.add_colorbar_to_fig(
            self.fig,
            [self.ax],
            label="",
            width=0.01,
            height=0.5,
            x_offset=-2,
            cmap=self.cmap,
            norm=norm,
            fontsize=self.fontsize - 1,
            tick_length=1,
            tick_width=0.25,
            rm_outline=True,
            n_ticks=5,
            n_decimals=0,
        )

    def init(self, frame: int = 0) -> None:
        """Initialize the animation.

        Draws the hex-scatter for the given frame of the current batch
        sample and, if enabled, the colorbar.

        Args:
            frame: Frame number to initialize. Negative values count from
                the end.

        """
        if frame < 0:
            frame += self.frames
        values = self.hexarray[self.batch_sample][frame].squeeze()
        scalarmapper, norm = self._color_mapping()
        self.fig, self.ax, (self.label_text, _) = plots.hex_scatter(
            self.u,
            self.v,
            values,
            fig=self.fig,
            midpoint=None,
            scalarmapper=scalarmapper,
            norm=norm,
            ax=self.ax,
            cmap=self.cmap,
            annotate=False,
            labelxy=self.labelxy,
            label=self.label.format(self.batch_sample, frame),
            edgecolor=self.edgecolor,
            fill=False,
            cbar=False,
            fontsize=self.fontsize,
            **self.kwargs,
        )
        self.fig.patch.set_facecolor(self.background_color)
        self.ax.patch.set_facecolor(self.background_color)
        if self.cbar:
            self._add_colorbar(norm)

    def animate(self, frame: int) -> None:
        """Animate a single frame.

        Recolors the existing hexal patches in place; `init` must have been
        called before.

        Args:
            frame: Frame number to animate. Negative values count from the
                end.

        """
        if frame < 0:
            frame += self.frames
        values = self.hexarray[self.batch_sample][frame].squeeze()
        scalarmapper, norm = self._color_mapping()
        if self.cbar:
            # Replace the previous colorbar so it reflects the current norm.
            for ax in self.fig.axes:
                if ax.get_label() == "cbar":
                    ax.remove()
            self._add_colorbar(norm)
        fcolors = scalarmapper.to_rgba(values)
        # Walk the patches once instead of indexing ax.patches per element.
        for patch, fc in zip(self.ax.patches, fcolors):
            if self.update_edge_color:
                patch.set_color(fc)
            else:
                patch.set_facecolor(fc)

        if self.label:
            self.label_text.set_text(self.label.format(self.batch_sample, frame))

        if self.update:
            self.update_figure()
init
init(frame=0)

Initialize the animation.

Parameters:

Name Type Description Default
frame int

Frame number to initialize.

0
Source code in flyvision/analysis/animations/hexscatter.py
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
def init(self, frame: int = 0) -> None:
    """Initialize the animation.

    Draws the hex-scatter for the given frame of the current batch sample
    and, if enabled, the colorbar.

    Args:
        frame: Frame number to initialize. Negative values count from the
            end.

    """
    if frame < 0:
        frame += self.frames
    sample = self.hexarray[self.batch_sample]
    values = sample[frame].squeeze()
    # Color limits: cranges (symmetric around zero) take precedence over
    # vmin/vmax, which take precedence over the min/max of the sample.
    if self.cranges is not None:
        crange = self.cranges[self.batch_sample]
        vmin, vmax = -crange, +crange
    else:
        vmin = self.vmin if self.vmin is not None else sample.min()
        vmax = self.vmax if self.vmax is not None else sample.max()
    scalarmapper, norm = plt_utils.get_scalarmapper(
        scalarmapper=None,
        cmap=self.cmap,
        norm=None,
        vmin=vmin,
        vmax=vmax,
        midpoint=self.midpoint,
    )
    # The coordinates stored on the instance are plotted.
    self.fig, self.ax, (self.label_text, _) = plots.hex_scatter(
        self.u,
        self.v,
        values,
        fig=self.fig,
        midpoint=None,
        scalarmapper=scalarmapper,
        norm=norm,
        ax=self.ax,
        cmap=self.cmap,
        annotate=False,
        labelxy=self.labelxy,
        label=self.label.format(self.batch_sample, frame),
        edgecolor=self.edgecolor,
        fill=False,
        cbar=False,
        fontsize=self.fontsize,
        **self.kwargs,
    )
    self.fig.patch.set_facecolor(self.background_color)
    self.ax.patch.set_facecolor(self.background_color)
    if self.cbar:
        plt_utils.add_colorbar_to_fig(
            self.fig,
            [self.ax],
            label="",
            width=0.01,
            height=0.5,
            x_offset=-2,
            cmap=self.cmap,
            norm=norm,
            fontsize=self.fontsize - 1,
            tick_length=1,
            tick_width=0.25,
            rm_outline=True,
            n_ticks=5,
            n_decimals=0,
        )
animate
animate(frame)

Animate a single frame.

Parameters:

Name Type Description Default
frame int

Frame number to animate.

required
Source code in flyvision/analysis/animations/hexscatter.py
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
def animate(self, frame: int) -> None:
    """Animate a single frame.

    Recolors the existing hexal patches in place; `init` must have been
    called before.

    Args:
        frame: Frame number to animate. Negative values count from the end.

    """
    if frame < 0:
        frame += self.frames
    sample = self.hexarray[self.batch_sample]
    values = sample[frame].squeeze()
    # Color limits: cranges (symmetric around zero) take precedence over
    # vmin/vmax, which take precedence over the min/max of the sample.
    # The sample is only scanned when a limit is actually missing.
    if self.cranges is not None:
        crange = self.cranges[self.batch_sample]
        vmin, vmax = -crange, +crange
    else:
        vmin = self.vmin if self.vmin is not None else sample.min()
        vmax = self.vmax if self.vmax is not None else sample.max()
    scalarmapper, norm = plt_utils.get_scalarmapper(
        scalarmapper=None,
        cmap=self.cmap,
        norm=None,
        vmin=vmin,
        vmax=vmax,
        midpoint=self.midpoint,
    )
    if self.cbar:
        # Replace the previous colorbar so it reflects the current norm.
        for ax in self.fig.axes:
            if ax.get_label() == "cbar":
                ax.remove()
        plt_utils.add_colorbar_to_fig(
            self.fig,
            [self.ax],
            label="",
            width=0.01,
            height=0.5,
            x_offset=-2,
            cmap=self.cmap,
            norm=norm,
            fontsize=self.fontsize - 1,
            tick_length=1,
            tick_width=0.25,
            rm_outline=True,
            n_ticks=5,
            n_decimals=0,
        )
    fcolors = scalarmapper.to_rgba(values)
    # Walk the patches once instead of indexing ax.patches per element.
    for patch, fc in zip(self.ax.patches, fcolors):
        if self.update_edge_color:
            patch.set_color(fc)
        else:
            patch.set_facecolor(fc)

    if self.label:
        self.label_text.set_text(self.label.format(self.batch_sample, frame))

    if self.update:
        self.update_figure()

flyvision.analysis.animations.hexflow

Classes

flyvision.analysis.animations.hexflow.HexFlow

Bases: Animation

Hexscatter of a color encoded flow field.

Parameters:

Name Type Description Default
flow Union[ndarray, Tensor]

Optic flow of shape (n_samples, n_frames, 2, n_input_elements).

required
fig Optional[Figure]

Existing Figure instance or None.

None
ax Optional[Axes]

Existing Axis instance or None.

None
batch_sample int

Batch sample to start from.

0
cmap Colormap

Colormap for the hex-scatter.

cm_uniform_2d
cwheel bool

Display colorwheel.

False
cwheelxy Tuple[float, float]

Colorwheel offset x and y.

()
label str

Label of the animation.

'Sample: {}\nFrame: {}'
labelxy Tuple[float, float]

Normalized x and y location of the label.

(0, 1)
update bool

Whether to update the canvas after an animation step.

False
path Optional[str]

Path to save the animation to.

None
figsize List[float]

Figure size.

[2, 2]
fontsize float

Font size.

5
background_color Literal['none']

Background color of the figure and axis.

'none'

Attributes:

Name Type Description
fig Figure

Figure instance.

ax Axes

Axis instance.

background_color str

Background color of the figure and axis.

batch_sample int

Batch sample to start from.

kwargs dict

Additional keyword arguments.

update bool

Whether to update the canvas after an animation step. Must be False if this animation is composed with others using AnimationCollector.

cmap

Colormap for the hex-scatter.

cwheel bool

Display colorwheel.

cwheelxy Tuple[float, float]

Colorwheel offset x and y.

label str

Label of the animation.

labelxy Tuple[float, float]

Normalized x and y location of the label.

label_text Text

Text instance for the label.

sm ScalarMappable

ScalarMappable instance for color mapping.

fontsize float

Font size.

figsize List[float]

Figure size.

flow ndarray

Optic flow data.

n_samples int

Number of samples in the flow data.

frames int

Number of frames in the flow data.

extent Tuple[float, float, float, float]

Extent of the hexagonal grid.

Note

All kwargs are passed to flyvision.analysis.visualization.plots.hex_flow.

Source code in flyvision/analysis/animations/hexflow.py
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
class HexFlow(Animation):
    """Hexscatter of a color encoded flow field.

    The flow direction selects the color and the flow magnitude, normalized
    to the largest magnitude in the frame, sets the opacity.

    Args:
        flow: Optic flow of shape (n_samples, n_frames, 2, n_input_elements).
        fig: Existing Figure instance or None.
        ax: Existing Axis instance or None.
        batch_sample: Batch sample to start from.
        cmap: Colormap for the hex-scatter.
        cwheel: Display colorwheel.
        cwheelxy: Colorwheel offset x and y.
        label: Label of the animation.
        labelxy: Normalized x and y location of the label.
        update: Whether to update the canvas after an animation step.
        path: Path to save the animation to.
        figsize: Figure size.
        fontsize: Font size.
        background_color: Background color of the figure and axis.

    Attributes:
        fig (Figure): Figure instance.
        ax (Axes): Axis instance.
        background_color (str): Background color of the figure and axis.
        batch_sample (int): Batch sample to start from.
        kwargs (dict): Additional keyword arguments.
        update (bool): Whether to update the canvas after an animation step.
            Must be False if this animation is composed with others using
            AnimationCollector.
        cmap: Colormap for the hex-scatter.
        cwheel (bool): Display colorwheel.
        cwheelxy (Tuple[float, float]): Colorwheel offset x and y.
        label (str): Label of the animation.
        labelxy (Tuple[float, float]): Normalized x and y location of the label.
        label_text (Text): Text instance for the label.
        sm (ScalarMappable): ScalarMappable instance for color mapping.
        fontsize (float): Font size.
        figsize (List[float]): Figure size.
        flow (np.ndarray): Optic flow data.
        n_samples (int): Number of samples in the flow data.
        frames (int): Number of frames in the flow data.
        extent (Tuple[float, float, float, float]): Extent of the hexagonal grid.

    Note:
        All kwargs are passed to ~flyvision.analysis.visualization.plots.hex_flow.
    """

    def __init__(
        self,
        flow: Union[np.ndarray, "torch.Tensor"],
        fig: Optional[Figure] = None,
        ax: Optional[Axes] = None,
        batch_sample: int = 0,
        cmap: "Colormap" = plt_utils.cm_uniform_2d,
        cwheel: bool = False,
        cwheelxy: Tuple[float, float] = (),
        label: str = "Sample: {}\nFrame: {}",
        labelxy: Tuple[float, float] = (0, 1),
        update: bool = False,
        path: Optional[str] = None,
        figsize: List[float] = [2, 2],
        fontsize: float = 5,
        background_color: Literal["none"] = "none",
        **kwargs,
    ):
        self.fig = fig
        self.ax = ax
        self.background_color = background_color
        self.batch_sample = batch_sample
        self.kwargs = kwargs
        self.update = update
        self.cmap = cmap
        self.cwheel = cwheel
        self.cwheelxy = cwheelxy

        self.label = label
        self.labelxy = labelxy
        # Both are created by `init`, which must run before `animate`.
        self.label_text: Optional[Text] = None
        self.sm: Optional[ScalarMappable] = None
        self.fontsize = fontsize
        # Store a copy so instances never share (and cannot mutate) the
        # default list object.
        self.figsize = list(figsize)

        self.flow = utils.tensor_utils.to_numpy(flow)

        self.n_samples, self.frames = self.flow.shape[0:2]
        self.extent = utils.hex_utils.get_hextent(self.flow.shape[-1])
        super().__init__(path, self.fig)

    def init(self, frame: int = 0) -> None:
        """Initialize the animation.

        Draws the flow field of the given frame and stores the figure, axis,
        label text and scalar mappable returned by `plots.hex_flow`.

        Args:
            frame: Frame number to initialize with.
        """
        u, v = utils.hex_utils.get_hex_coords(self.extent)
        self.fig, self.ax, (self.label_text, self.sm, _, _) = plots.hex_flow(
            u,
            v,
            self.flow[self.batch_sample, frame],
            fig=self.fig,
            ax=self.ax,
            cwheel=self.cwheel,
            cwheelxy=self.cwheelxy,
            cmap=self.cmap,
            annotate=False,
            labelxy=self.labelxy,
            label=self.label.format(self.batch_sample, frame),
            figsize=self.figsize,
            fontsize=self.fontsize,
            **self.kwargs,
        )
        self.fig.patch.set_facecolor(self.background_color)
        self.ax.patch.set_facecolor(self.background_color)

    def animate(self, frame: int) -> None:
        """Animate a single frame.

        Recolors the existing patches in place; `init` must have been called
        before.

        Args:
            frame: Frame number to animate.
        """
        flow = self.flow[self.batch_sample, frame]

        magnitude = np.sqrt(flow[0] ** 2 + flow[1] ** 2)
        peak = magnitude.max()
        # A frame without any motion has peak == 0; dividing would turn every
        # alpha into NaN, so such a frame is drawn fully transparent instead.
        alpha = magnitude / peak if peak > 0 else np.zeros_like(magnitude)
        theta = np.arctan2(flow[1], flow[0])
        color = self.sm.to_rgba(theta)
        color[:, -1] = alpha

        # Walk the patches once instead of indexing ax.patches per element.
        for patch, fc in zip(self.ax.patches, color):
            patch.set_color(fc)

        if self.label:
            self.label_text.set_text(self.label.format(self.batch_sample, frame))

        if self.update:
            self.update_figure()
init
init(frame=0)

Initialize the animation.

Parameters:

Name Type Description Default
frame int

Frame number to initialize with.

0
Source code in flyvision/analysis/animations/hexflow.py
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
def init(self, frame: int = 0) -> None:
    """Draw the first frame of the flow field.

    Stores the figure, axis, label text and scalar mappable returned by
    `plots.hex_flow` on the instance and applies the background color.

    Args:
        frame: Frame number to initialize with.
    """
    sample = self.batch_sample
    hex_u, hex_v = utils.hex_utils.get_hex_coords(self.extent)
    title = self.label.format(sample, frame)

    fig, ax, (label_text, mappable, _, _) = plots.hex_flow(
        hex_u,
        hex_v,
        self.flow[sample, frame],
        fig=self.fig,
        ax=self.ax,
        cwheel=self.cwheel,
        cwheelxy=self.cwheelxy,
        cmap=self.cmap,
        annotate=False,
        labelxy=self.labelxy,
        label=title,
        figsize=self.figsize,
        fontsize=self.fontsize,
        **self.kwargs,
    )
    self.fig, self.ax = fig, ax
    self.label_text, self.sm = label_text, mappable

    # Figure and axis share the same background color.
    for artist in (self.fig.patch, self.ax.patch):
        artist.set_facecolor(self.background_color)
animate
animate(frame)

Animate a single frame.

Parameters:

Name Type Description Default
frame int

Frame number to animate.

required
Source code in flyvision/analysis/animations/hexflow.py
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
def animate(self, frame: int) -> None:
    """Recolor the hexagon patches for a single frame.

    The flow direction is mapped to a color through ``self.sm.to_rgba`` and
    the flow magnitude, normalized by the frame's peak magnitude, is written
    to the last channel of that color array.

    Args:
        frame: Frame number to animate.
    """
    flow = self.flow[self.batch_sample, frame]

    r = np.sqrt(flow[0] ** 2 + flow[1] ** 2)
    # Normalize to [0, 1]. A frame without any motion has a peak of zero;
    # dividing by it would fill the last color channel with NaN, so such
    # frames keep r == 0 instead.
    peak = r.max()
    if peak > 0:
        r /= peak
    theta = np.arctan2(flow[1], flow[0])
    color = self.sm.to_rgba(theta)
    color[:, -1] = r

    # Patches are addressed by index, so color i belongs to patch i.
    for i, fc in enumerate(color):
        self.ax.patches[i].set_color(fc)

    # An empty label string disables the per-frame label update.
    if self.label:
        self.label_text.set_text(self.label.format(self.batch_sample, frame))

    if self.update:
        self.update_figure()

flyvision.analysis.animations.sintel

Classes

flyvision.analysis.animations.sintel.SintelSample

Bases: AnimationCollector

Sintel-specific animation of input, target (ground truth), and, optionally, prediction data.

Parameters:

Name Type Description Default
lum ndarray

Input of shape (n_samples, n_frames, n_hexals).

required
target ndarray

Target of shape (n_samples, n_frames, n_dims, n_features).

required
prediction Optional[ndarray]

Optional prediction of shape (n_samples, n_frames, n_dims, n_features).

None
target_cmap str

Colormap for the target (depth).

colormaps['binary_r']
fontsize float

Font size for labels and titles.

5
labelxy Tuple[float, float]

Normalized x and y location of the label.

(-0.1, 1)
max_figure_height_cm float

Maximum figure height in centimeters.

22
panel_height_cm float

Height of each panel in centimeters.

3
max_figure_width_cm float

Maximum figure width in centimeters.

18
panel_width_cm float

Width of each panel in centimeters.

3.6
title1 str

Title for the input panel.

'input'
title2 str

Title for the target panel.

'target'
title3 str

Title for the prediction panel.

'prediction'

Attributes:

Name Type Description
fig Figure

Matplotlib figure instance.

axes List[Axes]

List of matplotlib axes instances.

lum ndarray

Input data.

target ndarray

Target data.

prediction Optional[ndarray]

Prediction data.

extent Tuple[float, float, float, float]

Extent of the hexagonal grid.

n_samples int

Number of samples.

frames int

Number of frames.

update bool

Whether to update the canvas after an animation step.

labelxy Tuple[float, float]

Normalized x and y location of the label.

animations List

List of animation objects.

batch_sample int

Batch sample to start from.

Source code in flyvision/analysis/animations/sintel.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
class SintelSample(AnimationCollector):
    """Side-by-side animation of a Sintel input, its target and a prediction.

    The first panel shows the luminance input as a hex scatter. The target
    (and the prediction, when one is given) is drawn as a hex flow field if
    its second-to-last axis has length 2, and as a hex scatter otherwise.

    Args:
        lum: Input of shape (n_samples, n_frames, n_hexals).
        target: Target of shape (n_samples, n_frames, n_dims, n_features).
        prediction: Optional prediction of shape
            (n_samples, n_frames, n_dims, n_features).
        target_cmap: Colormap for the target (depth). Only used when the
            target is drawn as a hex scatter.
        fontsize: Font size for labels and titles.
        labelxy: Normalized x and y location of the label.
        max_figure_height_cm: Maximum figure height in centimeters.
        panel_height_cm: Height of each panel in centimeters.
        max_figure_width_cm: Maximum figure width in centimeters.
        panel_width_cm: Width of each panel in centimeters.
        title1: Title for the input panel.
        title2: Title for the target panel.
        title3: Title for the prediction panel.

    Attributes:
        fig (Figure): Matplotlib figure instance.
        axes (List[Axes]): List of matplotlib axes instances.
        lum (np.ndarray): Input data.
        target (np.ndarray): Target data.
        prediction (Optional[np.ndarray]): Prediction data.
        extent (Tuple[float, float, float, float]): Extent of the hexagonal grid.
        n_samples (int): Number of samples.
        frames (int): Number of frames.
        update (bool): Whether to update the canvas after an animation step.
            Always set to False by the constructor.
        labelxy (Tuple[float, float]): Normalized x and y location of the label.
        animations (List): List of animation objects, in panel order.
        batch_sample (int): Batch sample to start from.
    """

    def __init__(
        self,
        lum: np.ndarray,
        target: np.ndarray,
        prediction: Optional[np.ndarray] = None,
        target_cmap: str = colormaps["binary_r"],
        fontsize: float = 5,
        labelxy: Tuple[float, float] = (-0.1, 1),
        max_figure_height_cm: float = 22,
        panel_height_cm: float = 3,
        max_figure_width_cm: float = 18,
        panel_width_cm: float = 3.6,
        title1: str = "input",
        title2: str = "target",
        title3: str = "prediction",
    ) -> None:
        # Input and target always get a panel; the prediction adds a third.
        n_panels = 3 if prediction is not None else 2
        layout = figsize_utils.figsize_from_n_items(
            n_panels,
            max_figure_height_cm=max_figure_height_cm,
            panel_height_cm=panel_height_cm,
            max_figure_width_cm=max_figure_width_cm,
            panel_width_cm=panel_width_cm,
        )
        self.fig, self.axes = layout.axis_grid(
            hspace=0.0,
            wspace=0,
            fontsize=fontsize,
            unmask_n=n_panels,
        )

        self.lum = lum
        self.target = target
        self.prediction = prediction
        self.extent = utils.hex_utils.get_hextent(self.lum.shape[-1])

        self.n_samples, self.frames = self.lum.shape[0:2]
        self.update = False
        self.labelxy = labelxy

        # Panel 0: the luminance input.
        panels = [
            HexScatter(
                self.lum,
                fig=self.fig,
                ax=self.axes[0],
                title=title1,
                edgecolor=None,
                update_edge_color=True,
                fontsize=fontsize,
                cbar=True,
                labelxy=labelxy,
            )
        ]

        # Target and prediction share one panel type, chosen from the target:
        # a second-to-last axis of length 2 is drawn as a flow field.
        if self.target.shape[-2] == 2:

            def build_panel(data, axis, title):
                return HexFlow(
                    flow=data,
                    fig=self.fig,
                    ax=axis,
                    cwheel=True,
                    cwheelxy=(-0.7, 0.7),
                    title=title,
                    label="",
                    fontsize=fontsize,
                )

        else:

            def build_panel(data, axis, title):
                return HexScatter(
                    data,
                    fig=self.fig,
                    ax=axis,
                    cmap=target_cmap,
                    title=title,
                    edgecolor=None,
                    fontsize=fontsize,
                    cbar=True,
                    labelxy=labelxy,
                )

        panels.append(build_panel(self.target, self.axes[1], title2))
        if prediction is not None:
            panels.append(build_panel(self.prediction, self.axes[2], title3))

        self.animations = panels
        self.batch_sample = 0
        super().__init__(None, self.fig)

flyvision.analysis.animations.activations

Classes

flyvision.analysis.animations.activations.StimulusResponse

Bases: AnimationCollector

Hex-scatter animations for input and responses.

Parameters:

Name Type Description Default
stimulus ndarray

Hexagonal input.

required
responses Union[ndarray, List[ndarray]]

Hexagonal activation of particular neuron type.

required
batch_sample int

Batch sample to start from.

0
figsize_scale float

Scale factor for figure size. Currently accepted but not used by the constructor.

1
fontsize int

Font size for the plot.

5
u Optional[List[int]]

List of u coordinates of neurons to plot.

None
v Optional[List[int]]

List of v coordinates of neurons to plot.

None
max_figure_height_cm float

Maximum figure height in centimeters.

22
panel_height_cm float

Height of each panel in centimeters.

3
max_figure_width_cm float

Maximum figure width in centimeters.

18
panel_width_cm float

Width of each panel in centimeters.

3.6

Attributes:

Name Type Description
stimulus ndarray

Numpy array of stimulus data.

responses List[ndarray]

List of numpy arrays of response data.

update bool

Flag to indicate if update is needed.

n_samples int

Number of samples.

frames int

Number of frames.

fig Figure

Matplotlib figure object.

axes List[Axes]

List of matplotlib axes objects.

animations List[HexScatter]

List of HexScatter animation objects.

batch_sample int

Batch sample index.

Note

If u and v are not specified, all neurons are plotted.

Source code in flyvision/analysis/animations/activations.py
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
class StimulusResponse(AnimationCollector):
    """Hex-scatter animations for input and responses.

    One panel shows the stimulus; one further panel is added per response.
    All response panels share the same per-sample color ranges, computed
    from the maximum absolute value over all responses.

    Args:
        stimulus: Hexagonal input.
        responses: Hexagonal activation of particular neuron type. Either a
            single array or a list of arrays (one panel per list entry).
        batch_sample: Batch sample to start from.
        figsize_scale: Scale factor for figure size. Accepted for
            compatibility; the constructor does not read it.
        fontsize: Font size for the plot.
        u: List of u coordinates of neurons to plot.
        v: List of v coordinates of neurons to plot.
        max_figure_height_cm: Maximum figure height in centimeters.
        panel_height_cm: Height of each panel in centimeters.
        max_figure_width_cm: Maximum figure width in centimeters.
        panel_width_cm: Width of each panel in centimeters.

    Attributes:
        stimulus (np.ndarray): Numpy array of stimulus data.
        responses (List[np.ndarray]): List of numpy arrays of response data.
        update (bool): Flag to indicate if update is needed. Always set to
            False by the constructor.
        n_samples (int): Number of samples.
        frames (int): Number of frames.
        fig (matplotlib.figure.Figure): Matplotlib figure object.
        axes (List[matplotlib.axes.Axes]): List of matplotlib axes objects.
        animations (List[HexScatter]): List of HexScatter animation objects.
        batch_sample (int): Batch sample index.

    Note:
        If u and v are not specified, all neurons are plotted.
    """

    def __init__(
        self,
        stimulus: np.ndarray,
        responses: Union[np.ndarray, List[np.ndarray]],
        batch_sample: int = 0,
        figsize_scale: float = 1,
        fontsize: int = 5,
        u: Optional[List[int]] = None,
        v: Optional[List[int]] = None,
        max_figure_height_cm: float = 22,
        panel_height_cm: float = 3,
        max_figure_width_cm: float = 18,
        panel_width_cm: float = 3.6,
    ) -> None:
        self.stimulus = utils.tensor_utils.to_numpy(stimulus)

        # Normalize to a list so that a single response and several responses
        # share one code path below.
        if isinstance(responses, list):
            self.responses = [utils.tensor_utils.to_numpy(r) for r in responses]
        else:
            self.responses = [utils.tensor_utils.to_numpy(responses)]

        self.update = False
        self.n_samples, self.frames = self.responses[0].shape[:2]

        # One panel for the stimulus plus one per response.
        n_panels = 1 + len(self.responses)
        figsize = figsize_utils.figsize_from_n_items(
            n_panels,
            max_figure_height_cm=max_figure_height_cm,
            panel_height_cm=panel_height_cm,
            max_figure_width_cm=max_figure_width_cm,
            panel_width_cm=panel_width_cm,
        )
        self.fig, self.axes = figsize.axis_grid(
            unmask_n=n_panels, hspace=0.0, wspace=0, fontsize=fontsize
        )

        # A single stimulus sample is repeated so it pairs with every
        # response sample.
        stimulus_samples = self.stimulus.shape[0]
        if stimulus_samples != self.n_samples and stimulus_samples == 1:
            self.stimulus = np.repeat(self.stimulus, self.n_samples, axis=0)

        animations = []

        animations.append(
            HexScatter(
                self.stimulus,
                fig=self.fig,
                ax=self.axes[0],
                title="stimulus",
                labelxy=(-0.1, 1),
                update=False,
                title_y=0.9,
            )
        )

        # Stacking the response list adds a leading axis; reducing over every
        # axis except the sample axis leaves one color range per sample.
        cranges = np.max(np.abs(self.responses), axis=(0, 2, 3, 4))

        # `cm.get_cmap` was removed from matplotlib.cm in Matplotlib 3.9; the
        # attribute lookup is available across versions.
        # NOTE(review): assumes `cm` is matplotlib.cm - confirm against the
        # module's imports.
        response_cmap = cm.seismic
        multiple = len(self.responses) > 1

        for i, response in enumerate(self.responses, 1):
            animations.append(
                HexScatter(
                    response,
                    fig=self.fig,
                    ax=self.axes[i],
                    cmap=response_cmap,
                    title=f"response {i}" if multiple else "response",
                    label="",
                    midpoint=0,
                    update=False,
                    u=u,
                    v=v,
                    cranges=cranges,
                    # Only the last response panel carries the colorbar.
                    cbar=i == len(self.responses),
                    title_y=0.9,
                )
            )

        self.animations = animations
        self.batch_sample = batch_sample
        super().__init__(None, self.fig)

flyvision.analysis.animations.network

Classes

flyvision.analysis.animations.network.WholeNetworkAnimation

Bases: Animation

Create an animation of the whole network activity.

This class generates an animation that visualizes the activity of a neural network, including input, rendering, predicted flow, and target flow if provided.

Attributes:

Name Type Description
fig_backbone WholeNetworkFigure

The backbone figure for the animation.

fig Figure

The main figure object.

ax_dict dict

Dictionary of axes for different components of the animation.

batch_sample int

The index of the batch sample to animate.

kwargs dict

Additional keyword arguments.

update bool

Whether to update the figure during animation.

label str

Label format string for the animation.

labelxy tuple[float, float]

Position of the label.

fontsize int

Font size for labels.

cmap Colormap

Colormap for the animation.

n_samples int

Number of samples in the responses.

frames int

Number of frames in the responses.

responses LayerActivity

Layer activity data.

cartesian_input Optional[Any]

Cartesian input data.

rendered_input Optional[Any]

Rendered input data.

predicted_flow Optional[Any]

Predicted flow data.

target_flow Optional[Any]

Target flow data.

color_norm_per str

Color normalization method.

voltage_axes list

List of voltage axes for different cell types.

Source code in flyvision/analysis/animations/network.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
class WholeNetworkAnimation(Animation):
    """
    Create an animation of the whole network activity.

    This class generates an animation that visualizes the activity of a neural network,
    including input, rendering, predicted flow, and target flow if provided.

    Args:
        connectome: Passed to ``WholeNetworkFigure`` and ``LayerActivity``.
        responses: Response data; its first two axes are taken as
            (n_samples, n_frames).
        cartesian_input: Optional data for the "video" panel.
        rendered_input: Optional data for the "rendering" panel.
        predicted_flow: Optional data for the "decoded motion" panel.
        target_flow: Optional data for the "pixel-accurate motion" panel.
        batch_sample: The index of the batch sample to animate.
        update: Whether to update the figure during animation.
        color_norm_per: Color normalization method (stored on the instance).
        label: Label format string (stored on the instance).
        cmap: Colormap (stored on the instance).
        labelxy: Position of the label (stored on the instance).
        titlepad: Written to the global matplotlib rc setting
            ``axes.titlepad``.
        fontsize: Font size for labels (stored on the instance).
        **kwargs: Additional keyword arguments (stored on the instance).

    Attributes:
        fig_backbone (WholeNetworkFigure): The backbone figure for the animation.
        fig (matplotlib.figure.Figure): The main figure object.
        ax_dict (dict): Dictionary of axes for different components of the animation.
        batch_sample (int): The index of the batch sample to animate.
        kwargs (dict): Additional keyword arguments.
        update (bool): Whether to update the figure during animation.
        label (str): Label format string for the animation.
        labelxy (tuple[float, float]): Position of the label.
        fontsize (int): Font size for labels.
        cmap (matplotlib.colors.Colormap): Colormap for the animation.
        n_samples (int): Number of samples in the responses.
        frames (int): Number of frames in the responses.
        responses (LayerActivity): Layer activity data.
        cartesian_input (Optional[Any]): Cartesian input data; replaced by
            the panel's animation object once `init` has run.
        rendered_input (Optional[Any]): Rendered input data; replaced by the
            panel's animation object once `init` has run.
        predicted_flow (Optional[Any]): Predicted flow data; replaced by the
            panel's animation object once `init` has run.
        target_flow (Optional[Any]): Target flow data; replaced by the
            panel's animation object once `init` has run.
        color_norm_per (str): Color normalization method.
        voltage_axes (list): List of voltage axes for different cell types.
    """

    def __init__(
        self,
        connectome: Any,
        responses: Any,
        cartesian_input: Optional[Any] = None,
        rendered_input: Optional[Any] = None,
        predicted_flow: Optional[Any] = None,
        target_flow: Optional[Any] = None,
        batch_sample: int = 0,
        update: bool = False,
        color_norm_per: Literal["batch"] = "batch",
        label: str = "Sample: {}\nFrame: {}",
        cmap: Any = plt.get_cmap("binary_r"),
        labelxy: tuple[float, float] = (0, 0.9),
        titlepad: int = 1,
        fontsize: int = 5,
        **kwargs: Any,
    ) -> None:
        # Each optional panel is enabled exactly when its data was supplied.
        self.fig_backbone = WholeNetworkFigure(
            connectome,
            video=cartesian_input is not None,
            rendering=rendered_input is not None,
            motion_decoder=predicted_flow is not None,
            decoded_motion=predicted_flow is not None,
            pixel_accurate_motion=target_flow is not None,
        )
        self.fig_backbone.init_figure()
        self.fig = self.fig_backbone.fig
        self.ax_dict = self.fig_backbone.ax_dict

        # Note: this changes the global matplotlib rc configuration.
        plt.rc("axes", titlepad=titlepad)
        self.batch_sample = batch_sample
        self.kwargs = kwargs
        self.update = update
        self.label = label
        self.labelxy = labelxy
        self.fontsize = fontsize
        self.cmap = cmap
        self.n_samples, self.frames = responses.shape[:2]

        self.responses = LayerActivity(responses, connectome, keepref=True)
        self.cartesian_input = cartesian_input
        self.rendered_input = rendered_input
        self.predicted_flow = predicted_flow
        self.target_flow = target_flow
        self.color_norm_per = color_norm_per
        path = None
        super().__init__(path, self.fig)

    def init(self, frame: int = 0) -> None:
        """
        Initialize the animation components.

        Builds one animation object per enabled panel and per cell type and
        draws each at `frame`. The input attributes (`cartesian_input`,
        `rendered_input`, `predicted_flow`, `target_flow`) are replaced by
        their animation objects. The method can be called repeatedly: every
        call rebuilds the panels from the original data.

        Args:
            frame: The initial frame number.
        """
        # Maps attribute name -> (raw data, animation built from it). Without
        # this record a second init() would hand the previous animation
        # object to the panel constructor as if it were data.
        sources = getattr(self, "_init_sources", None)
        if sources is None:
            sources = self._init_sources = {}

        def raw_input(name):
            # The attribute still holds the wrapper built by an earlier
            # init(): return the data that wrapper was built from. Otherwise
            # the attribute holds (possibly newly assigned) raw data.
            current = getattr(self, name)
            built = sources.get(name)
            if built is not None and built[1] is current:
                return built[0]
            return current

        def install(name, data, anim):
            sources[name] = (data, anim)
            setattr(self, name, anim)
            anim.init(frame)
            # Sub-animations never refresh the canvas themselves; `animate`
            # of this class refreshes once after all panels are updated.
            anim.update = False

        if self.fig_backbone.video:
            data = raw_input("cartesian_input")
            install(
                "cartesian_input",
                data,
                Imshow(
                    data,
                    vmin=0,
                    vmax=1,
                    cmap=plt.cm.binary_r,
                    fig=self.fig,
                    ax=self.ax_dict["video"],
                ),
            )

        if self.fig_backbone.rendering:
            data = raw_input("rendered_input")
            install(
                "rendered_input",
                data,
                HexScatter(
                    data,
                    vmin=0,
                    vmax=1,
                    cmap=plt.cm.binary_r,
                    fig=self.fig,
                    ax=self.ax_dict["rendering"],
                    edgecolor=None,
                    cbar=False,
                    label="",
                    background_color=self.fig_backbone.facecolor,
                ),
            )

        if self.fig_backbone.decoded_motion:
            data = raw_input("predicted_flow")
            install(
                "predicted_flow",
                data,
                HexFlow(
                    data,
                    fig=self.fig,
                    ax=self.ax_dict["decoded motion"],
                    label="",
                    cwheel=True,
                    cwheelradius=0.5,
                    fontsize=5,
                ),
            )

        if self.fig_backbone.pixel_accurate_motion:
            data = raw_input("target_flow")
            install(
                "target_flow",
                data,
                HexFlow(
                    data,
                    fig=self.fig,
                    ax=self.ax_dict["pixel-accurate motion"],
                    label="",
                ),
            )

        # One hex scatter per cell type, placed at that type's (u, v) nodes.
        self.voltage_axes = []
        for cell_type in self.fig_backbone.cell_types:
            voltage = self.responses[cell_type][:, :, None]
            nodes = self.fig_backbone.nodes
            nodes = nodes[nodes.type == cell_type]
            u, v = nodes[["u", "v"]].values.T
            anim = HexScatter(
                voltage,
                u=u,
                v=v,
                label="",
                cbar=False,
                edgecolor=None,
                ax=self.ax_dict[cell_type],
                fig=self.fig,
                cmap=plt.cm.binary_r,
            )
            anim.init(frame)
            anim.update = False
            self.voltage_axes.append(anim)

    def animate(self, frame: int) -> None:
        """
        Update the animation for a given frame.

        Advances every enabled panel and every cell-type panel to `frame`,
        then refreshes the figure once if `self.update` is set.

        Args:
            frame: The current frame number.
        """
        if self.fig_backbone.video:
            self.cartesian_input.animate(frame)
        if self.fig_backbone.rendering:
            self.rendered_input.animate(frame)
        if self.fig_backbone.decoded_motion:
            self.predicted_flow.animate(frame)
        if self.fig_backbone.pixel_accurate_motion:
            self.target_flow.animate(frame)

        for anim in self.voltage_axes:
            anim.animate(frame)

        if self.update:
            self.update_figure()
init
init(frame=0)

Initialize the animation components.

Parameters:

Name Type Description Default
frame int

The initial frame number.

0
Source code in flyvision/analysis/animations/network.py
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
def init(self, frame: int = 0) -> None:
    """
    Initialize the animation components.

    Builds one animation object per enabled panel and per cell type and
    draws each at `frame`. The input attributes (`cartesian_input`,
    `rendered_input`, `predicted_flow`, `target_flow`) are replaced by their
    animation objects. The method can be called repeatedly: every call
    rebuilds the panels from the original data.

    Args:
        frame: The initial frame number.
    """
    # Maps attribute name -> (raw data, animation built from it). Without
    # this record a second init() would hand the previous animation object
    # to the panel constructor as if it were data.
    sources = getattr(self, "_init_sources", None)
    if sources is None:
        sources = self._init_sources = {}

    def raw_input(name):
        # The attribute still holds the wrapper built by an earlier init():
        # return the data that wrapper was built from. Otherwise the
        # attribute holds (possibly newly assigned) raw data.
        current = getattr(self, name)
        built = sources.get(name)
        if built is not None and built[1] is current:
            return built[0]
        return current

    def install(name, data, anim):
        sources[name] = (data, anim)
        setattr(self, name, anim)
        anim.init(frame)
        # Sub-animations never refresh the canvas themselves; the owning
        # animation refreshes once after all panels are updated.
        anim.update = False

    if self.fig_backbone.video:
        data = raw_input("cartesian_input")
        install(
            "cartesian_input",
            data,
            Imshow(
                data,
                vmin=0,
                vmax=1,
                cmap=plt.cm.binary_r,
                fig=self.fig,
                ax=self.ax_dict["video"],
            ),
        )

    if self.fig_backbone.rendering:
        data = raw_input("rendered_input")
        install(
            "rendered_input",
            data,
            HexScatter(
                data,
                vmin=0,
                vmax=1,
                cmap=plt.cm.binary_r,
                fig=self.fig,
                ax=self.ax_dict["rendering"],
                edgecolor=None,
                cbar=False,
                label="",
                background_color=self.fig_backbone.facecolor,
            ),
        )

    if self.fig_backbone.decoded_motion:
        data = raw_input("predicted_flow")
        install(
            "predicted_flow",
            data,
            HexFlow(
                data,
                fig=self.fig,
                ax=self.ax_dict["decoded motion"],
                label="",
                cwheel=True,
                cwheelradius=0.5,
                fontsize=5,
            ),
        )

    if self.fig_backbone.pixel_accurate_motion:
        data = raw_input("target_flow")
        install(
            "target_flow",
            data,
            HexFlow(
                data,
                fig=self.fig,
                ax=self.ax_dict["pixel-accurate motion"],
                label="",
            ),
        )

    # One hex scatter per cell type, placed at that type's (u, v) nodes.
    self.voltage_axes = []
    for cell_type in self.fig_backbone.cell_types:
        voltage = self.responses[cell_type][:, :, None]
        nodes = self.fig_backbone.nodes
        nodes = nodes[nodes.type == cell_type]
        u, v = nodes[["u", "v"]].values.T
        anim = HexScatter(
            voltage,
            u=u,
            v=v,
            label="",
            cbar=False,
            edgecolor=None,
            ax=self.ax_dict[cell_type],
            fig=self.fig,
            cmap=plt.cm.binary_r,
        )
        anim.init(frame)
        anim.update = False
        self.voltage_axes.append(anim)
animate
animate(frame)

Update the animation for a given frame.

Parameters:

Name Type Description Default
frame int

The current frame number.

required
Source code in flyvision/analysis/animations/network.py
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
def animate(self, frame: int) -> None:
    """
    Advance every active panel to the given frame.

    The optional panels are updated first, in a fixed order and only when
    their flag on the backbone figure is set, followed by the per-cell-type
    panels. The figure is refreshed afterwards if `self.update` is set.

    Args:
        frame: The current frame number.
    """
    # (backbone flag, attribute holding that panel's animation), in draw order.
    optional_panels = (
        ("video", "cartesian_input"),
        ("rendering", "rendered_input"),
        ("decoded_motion", "predicted_flow"),
        ("pixel_accurate_motion", "target_flow"),
    )
    for flag, attribute in optional_panels:
        if getattr(self.fig_backbone, flag):
            getattr(self, attribute).animate(frame)

    for voltage_panel in self.voltage_axes:
        voltage_panel.animate(frame)

    if not self.update:
        return
    self.update_figure()

flyvision.analysis.animations.animations

Classes

flyvision.analysis.animations.animations.Animation

Base class for animations.

Subclasses must implement init and animate methods.

Parameters:

Name Type Description Default
path Optional[Union[str, Path]]

Path to save the animation.

None
fig Optional[Figure]

Existing Figure instance or None.

None
suffix str

Suffix for the animation path.

'{}'

Attributes:

Name Type Description
fig Figure

Figure instance for the animation.

update bool

Whether to update the canvas after each animation step.

batch_sample int

Sample to animate.

frames int

Number of frames in the animation.

n_samples int

Number of samples in the animation.

path Path

Path to save the animation.

Source code in flyvision/analysis/animations/animations.py
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
class Animation:
    """Base class for animations.

    Subclasses must implement `init` and `animate` methods.

    Args:
        path: Path to save the animation.
        fig: Existing Figure instance or None.
        suffix: Suffix for the animation path.

    Attributes:
        fig (matplotlib.figure.Figure): Figure instance for the animation.
        update (bool): Whether to update the canvas after each animation step.
        batch_sample (int): Sample to animate.
        frames (int): Number of frames in the animation.
        n_samples (int): Number of samples in the animation.
        path (Path): Path to save the animation.
    """

    fig: Optional[matplotlib.figure.Figure] = None
    update: bool = True
    batch_sample: int = 0
    frames: int = 0
    n_samples: int = 0

    def __init__(
        self,
        path: Optional[Union[str, Path]] = None,
        fig: Optional[matplotlib.figure.Figure] = None,
        suffix: str = "{}",
    ):
        # The output directory is `<path or ANIMATION_DIR>/<suffix>` where the
        # suffix is formatted with the concrete class name.
        self.path = Path(ANIMATION_DIR if path is None else path) / suffix.format(
            self.__class__.__name__
        )
        self.fig = fig

    def init(self, frame: int = 0) -> None:
        """Initialize the animation.

        Args:
            frame: Initial frame number.

        Raises:
            NotImplementedError: If not implemented by subclass.
        """
        raise NotImplementedError("Subclasses should implement this method.")

    def animate(self, frame: int) -> None:
        """Animate a single frame.

        Args:
            frame: Frame number to animate.

        Raises:
            NotImplementedError: If not implemented by subclass.
        """
        raise NotImplementedError("Subclasses should implement this method.")

    def update_figure(self, clear_output: bool = True) -> None:
        """Update the figure canvas.

        The canvas is always redrawn and its events flushed. Unless the
        matplotlib backend is "nbagg" (and COLAB is falsy), the figure is
        additionally pushed through `display.display`.

        Args:
            clear_output: Whether to clear the previous output.
        """
        self.fig.canvas.draw()
        self.fig.canvas.flush_events()
        if matplotlib.get_backend().lower() != "nbagg" or COLAB:
            display.display(self.fig)
            if clear_output:
                # wait=True defers clearing until new output is available.
                display.clear_output(wait=True)

    def animate_save(self, frame: int, dpi: int = 100) -> None:
        """Update the figure to the given frame and save it.

        The image is written to `self._path` as
        `<batch_sample:04>_<frame:04>.png`.

        Args:
            frame: Frame number to animate and save.
            dpi: Dots per inch for the saved image.
        """
        self.animate(frame)
        identifier = f"{self.batch_sample:04}_{frame:04}"
        self.fig.savefig(
            self._path / f"{identifier}.png",
            dpi=dpi,
            bbox_inches="tight",
            facecolor=self.fig.get_facecolor(),
            edgecolor="none",
        )

    def _get_indices(self, key: str, input: Union[str, Iterable]) -> list[int]:
        """Get sorted list of indices based on input.

        Args:
            key: Attribute name to get total number of indices.
            input: Either the string "all" or an iterable of indices. Indices
                outside `range(getattr(self, key))` are dropped.

        Returns:
            Sorted list of indices.

        Raises:
            ValueError: If input is invalid.
        """
        total = getattr(self, key)
        _indices = set(range(total))
        # Only compare against "all" when the input really is a string:
        # array-like inputs (e.g. NumPy arrays) may compare elementwise, which
        # has no unambiguous truth value.
        if isinstance(input, str) and input == "all":
            indices = _indices
        elif isinstance(input, Iterable):
            indices = _indices.intersection(set(input))
        else:
            raise ValueError(f"Invalid input for {key}: {input}")
        return sorted(indices)

    def animate_in_notebook(
        self,
        frames: Union[str, Iterable] = "all",
        samples: Union[str, Iterable] = "all",
        repeat: int = 1,
    ) -> None:
        """Play animation within a Jupyter notebook.

        Args:
            frames: Frames to animate.
            samples: Samples to animate.
            repeat: Number of times to repeat the animation.
        """
        frames_list, samples_list = self._initialize_animation(frames, samples)

        if TESTING:
            # Keep test runs short: one frame of one sample, played once.
            frames_list = frames_list[:1]
            samples_list = samples_list[:1]
            repeat = 1

        try:
            for _ in range(repeat):
                for sample in samples_list:
                    self.batch_sample = sample
                    for frame in frames_list:
                        self.animate(frame)
                        sleep(0.1)  # Pause between frames
        except KeyboardInterrupt:
            print("Animation interrupted. Displaying last frame.")
            self.update_figure(clear_output=False)
            return

    def _verify_backend(self) -> None:
        """Ensure the notebook backend is set correctly.

        Raises:
            RuntimeError: If matplotlib backend is not set to notebook.
        """
        backend = matplotlib.get_backend().lower()
        if backend != "nbagg" and not COLAB:
            raise RuntimeError(
                "Matplotlib backend is not set to notebook. Use '%matplotlib notebook'."
            )

    def _initialize_animation(
        self, frames: Union[str, Iterable], samples: Union[str, Iterable]
    ) -> tuple[list[int], list[int]]:
        """Initialize the animation state.

        Args:
            frames: Frames to animate.
            samples: Samples to animate.

        Returns:
            Tuple of frames list and samples list.
        """
        self.update = True
        self.init()
        frames_list = self._get_indices("frames", frames)
        samples_list = self._get_indices("n_samples", samples)
        return frames_list, samples_list

    def plot(self, sample: int, frame: int) -> None:
        """Plot a single frame for a specific sample.

        The previously selected `batch_sample` is restored afterwards, even
        when `animate` raises.

        Args:
            sample: Sample number to plot.
            frame: Frame number to plot.
        """
        previous_sample = self.batch_sample
        self.update = True
        self.init()
        self.batch_sample = sample
        try:
            self.animate(frame)
        finally:
            self.batch_sample = previous_sample

    def _create_temp_dir(self, path: Optional[Union[str, Path]] = None) -> None:
        """Create a temporary directory as destination for the images.

        Sets `self._temp_dir` (the TemporaryDirectory handle) and `self._path`
        (its location as a Path).

        Args:
            path: Parent directory in which to create the temporary directory.
                It is created if missing. None uses the system default.
        """
        if path is not None:
            Path(path).mkdir(parents=True, exist_ok=True)
        self._temp_dir = tempfile.TemporaryDirectory(dir=path)
        self._path = Path(self._temp_dir.name)

    def to_vid(
        self,
        fname: str,
        frames: Union[str, Iterable] = "all",
        dpi: int = 100,
        framerate: int = 30,
        samples: Union[str, Iterable] = "all",
        delete_if_exists: bool = False,
        source_path: Optional[Union[str, Path]] = None,
        dest_path: Optional[Union[str, Path]] = None,
        type: Literal["mp4", "webm"] = "webm",
    ) -> None:
        """Animate, save individual frames, and convert to video using ffmpeg.

        Frames are rendered into a temporary directory that is removed again
        when this method returns or raises.

        Args:
            fname: Output filename.
            frames: Frames to animate.
            dpi: Dots per inch for saved images.
            framerate: Frame rate of the output video.
            samples: Samples to animate.
            delete_if_exists: Whether to delete existing output file.
            source_path: Source path for temporary files.
            dest_path: Destination path for the output video.
            type: Output video type.
        """
        self._create_temp_dir(path=source_path)
        try:
            self.update = True
            self.init()
            frames_list = self._get_indices("frames", frames)
            samples_list = self._get_indices("n_samples", samples)

            try:
                for sample in samples_list:
                    self.batch_sample = sample
                    for frame in frames_list:
                        self.animate_save(frame, dpi=dpi)
            except Exception as e:
                logging.error("Error during animation: %s", e)
                raise

            # Convert from the directory the frames were actually written to;
            # `source_path` is only the parent of that temporary directory.
            self.convert(
                fname,
                delete_if_exists=delete_if_exists,
                framerate=framerate,
                source_path=self._path,
                dest_path=dest_path,
                type=type,
            )
        finally:
            # Always remove the rendered frames, also on failure.
            self._temp_dir.cleanup()

    def convert(
        self,
        fname: str,
        delete_if_exists: bool = False,
        framerate: int = 30,
        source_path: Optional[Union[str, Path]] = None,
        dest_path: Optional[Union[str, Path]] = None,
        type: Literal["mp4", "webm"] = "mp4",
    ) -> None:
        """Convert PNG files in the animations directory to video.

        Args:
            fname: Output filename.
            delete_if_exists: Whether to delete existing output file.
            framerate: Frame rate of the output video.
            source_path: Source path for input PNG files. Falls back to
                `self._path` when not given.
            dest_path: Destination path for the output video. Falls back to
                `self.path` when not given; the directory is created if needed.
            type: Output video type.
        """
        dest_path = Path(dest_path or self.path)
        dest_path.mkdir(parents=True, exist_ok=True)
        convert(
            source_path or self._path,
            dest_path / f"{fname}.{type}",
            framerate,
            delete_if_exists,
            type=type,
        )
init
init(frame=0)

Initialize the animation.

Parameters:

Name Type Description Default
frame int

Initial frame number.

0

Raises:

Type Description
NotImplementedError

If not implemented by subclass.

Source code in flyvision/analysis/animations/animations.py
55
56
57
58
59
60
61
62
63
64
def init(self, frame: int = 0) -> None:
    """Set up the animation before the first frame is drawn.

    This base implementation is a placeholder that always raises.

    Args:
        frame: Initial frame number.

    Raises:
        NotImplementedError: Always, unless overridden by a subclass.
    """
    message = "Subclasses should implement this method."
    raise NotImplementedError(message)
animate
animate(frame)

Animate a single frame.

Parameters:

Name Type Description Default
frame int

Frame number to animate.

required

Raises:

Type Description
NotImplementedError

If not implemented by subclass.

Source code in flyvision/analysis/animations/animations.py
66
67
68
69
70
71
72
73
74
75
def animate(self, frame: int) -> None:
    """Draw one frame of the animation.

    This base implementation is a placeholder that always raises.

    Args:
        frame: Frame number to animate.

    Raises:
        NotImplementedError: Always, unless overridden by a subclass.
    """
    message = "Subclasses should implement this method."
    raise NotImplementedError(message)
update_figure
update_figure(clear_output=True)

Update the figure canvas.

Parameters:

Name Type Description Default
clear_output bool

Whether to clear the previous output.

True
Source code in flyvision/analysis/animations/animations.py
77
78
79
80
81
82
83
84
85
86
87
88
def update_figure(self, clear_output: bool = True) -> None:
    """Redraw the figure canvas and, where needed, re-display the figure.

    Args:
        clear_output: Whether to clear the previous output after displaying.
    """
    canvas = self.fig.canvas
    canvas.draw()
    canvas.flush_events()

    # With the "nbagg" backend (and COLAB falsy) redrawing is all that is done.
    if matplotlib.get_backend().lower() == "nbagg" and not COLAB:
        return

    display.display(self.fig)
    if clear_output:
        # wait=True defers clearing until new output is available.
        display.clear_output(wait=True)
animate_save
animate_save(frame, dpi=100)

Update the figure to the given frame and save it.

Parameters:

Name Type Description Default
frame int

Frame number to animate and save.

required
dpi int

Dots per inch for the saved image.

100
Source code in flyvision/analysis/animations/animations.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
def animate_save(self, frame: int, dpi: int = 100) -> None:
    """Render the given frame and write it to disk as a PNG.

    The file is stored in `self._path` and named
    `<batch_sample:04>_<frame:04>.png`.

    Args:
        frame: Frame number to animate and save.
        dpi: Dots per inch for the saved image.
    """
    self.animate(frame)

    figure = self.fig
    png_name = f"{self.batch_sample:04}_{frame:04}.png"
    target = self._path / png_name
    save_options = {
        "dpi": dpi,
        "bbox_inches": "tight",
        # Keep the figure's own background colour in the saved image.
        "facecolor": figure.get_facecolor(),
        "edgecolor": "none",
    }
    figure.savefig(target, **save_options)
animate_in_notebook
animate_in_notebook(frames='all', samples='all', repeat=1)

Play animation within a Jupyter notebook.

Parameters:

Name Type Description Default
frames Union[str, Iterable]

Frames to animate.

'all'
samples Union[str, Iterable]

Samples to animate.

'all'
repeat int

Number of times to repeat the animation.

1
Source code in flyvision/analysis/animations/animations.py
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
def animate_in_notebook(
    self,
    frames: Union[str, Iterable] = "all",
    samples: Union[str, Iterable] = "all",
    repeat: int = 1,
) -> None:
    """Play the animation inline, e.g. in a Jupyter notebook.

    Every selected sample is stepped through all selected frames, and the
    whole sequence is played `repeat` times. A keyboard interrupt stops the
    playback and leaves the last frame displayed.

    Args:
        frames: Frames to animate.
        samples: Samples to animate.
        repeat: Number of times to repeat the animation.
    """
    frame_ids, sample_ids = self._initialize_animation(frames, samples)

    # When TESTING is set, play a single frame of a single sample once.
    if TESTING:
        frame_ids, sample_ids, repeat = frame_ids[:1], sample_ids[:1], 1

    try:
        for _ in range(repeat):
            for sample_id in sample_ids:
                self.batch_sample = sample_id
                for frame_id in frame_ids:
                    self.animate(frame_id)
                    # Pause between frames
                    sleep(0.1)
    except KeyboardInterrupt:
        print("Animation interrupted. Displaying last frame.")
        self.update_figure(clear_output=False)
plot
plot(sample, frame)

Plot a single frame for a specific sample.

Parameters:

Name Type Description Default
sample int

Sample number to plot.

required
frame int

Frame number to plot.

required
Source code in flyvision/analysis/animations/animations.py
192
193
194
195
196
197
198
199
200
201
202
203
204
def plot(self, sample: int, frame: int) -> None:
    """Plot a single frame for a specific sample.

    Forces `update` to True, re-initializes the animation, and draws the
    requested frame. The previously selected `batch_sample` is restored
    afterwards, even when `animate` raises.

    Args:
        sample: Sample number to plot.
        frame: Frame number to plot.
    """
    previous_sample = self.batch_sample
    self.update = True
    self.init()
    self.batch_sample = sample
    try:
        self.animate(frame)
    finally:
        # Do not leave the object pointing at `sample` after a failed draw.
        self.batch_sample = previous_sample
to_vid
to_vid(
    fname,
    frames="all",
    dpi=100,
    framerate=30,
    samples="all",
    delete_if_exists=False,
    source_path=None,
    dest_path=None,
    type="webm",
)

Animate, save individual frames, and convert to video using ffmpeg.

Parameters:

Name Type Description Default
fname str

Output filename.

required
frames Union[str, Iterable]

Frames to animate.

'all'
dpi int

Dots per inch for saved images.

100
framerate int

Frame rate of the output video.

30
samples Union[str, Iterable]

Samples to animate.

'all'
delete_if_exists bool

Whether to delete existing output file.

False
source_path Optional[Union[str, Path]]

Source path for temporary files.

None
dest_path Optional[Union[str, Path]]

Destination path for the output video.

None
type Literal['mp4', 'webm']

Output video type.

'webm'
Source code in flyvision/analysis/animations/animations.py
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
def to_vid(
    self,
    fname: str,
    frames: Union[str, Iterable] = "all",
    dpi: int = 100,
    framerate: int = 30,
    samples: Union[str, Iterable] = "all",
    delete_if_exists: bool = False,
    source_path: Optional[Union[str, Path]] = None,
    dest_path: Optional[Union[str, Path]] = None,
    type: Literal["mp4", "webm"] = "webm",
) -> None:
    """Animate, save individual frames, and convert to video using ffmpeg.

    Frames are rendered into the temporary directory set up by
    `_create_temp_dir`; that directory is removed again when this method
    returns or raises.

    Args:
        fname: Output filename.
        frames: Frames to animate.
        dpi: Dots per inch for saved images.
        framerate: Frame rate of the output video.
        samples: Samples to animate.
        delete_if_exists: Whether to delete existing output file.
        source_path: Source path for temporary files.
        dest_path: Destination path for the output video.
        type: Output video type.
    """
    self._create_temp_dir(path=source_path)
    try:
        self.update = True
        self.init()
        frames_list = self._get_indices("frames", frames)
        samples_list = self._get_indices("n_samples", samples)

        try:
            for sample in samples_list:
                self.batch_sample = sample
                for frame in frames_list:
                    self.animate_save(frame, dpi=dpi)
        except Exception as e:
            logging.error("Error during animation: %s", e)
            raise

        # Convert from the directory the frames were actually written to
        # (`self._path`), not from the caller-supplied `source_path`.
        self.convert(
            fname,
            delete_if_exists=delete_if_exists,
            framerate=framerate,
            source_path=self._path,
            dest_path=dest_path,
            type=type,
        )
    finally:
        # Always remove the rendered frames, also on failure.
        self._temp_dir.cleanup()
convert
convert(
    fname,
    delete_if_exists=False,
    framerate=30,
    source_path=None,
    dest_path=None,
    type="mp4",
)

Convert PNG files in the animations directory to video.

Parameters:

Name Type Description Default
fname str

Output filename.

required
delete_if_exists bool

Whether to delete existing output file.

False
framerate int

Frame rate of the output video.

30
source_path Optional[Union[str, Path]]

Source path for input PNG files.

None
dest_path Optional[Union[str, Path]]

Destination path for the output video.

None
type Literal['mp4', 'webm']

Output video type.

'mp4'
Source code in flyvision/analysis/animations/animations.py
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
def convert(
    self,
    fname: str,
    delete_if_exists: bool = False,
    framerate: int = 30,
    source_path: Optional[Union[str, Path]] = None,
    dest_path: Optional[Union[str, Path]] = None,
    type: Literal["mp4", "webm"] = "mp4",
) -> None:
    """Turn the saved PNG frames into a video file.

    Args:
        fname: Output filename.
        delete_if_exists: Whether to delete existing output file.
        framerate: Frame rate of the output video.
        source_path: Source path for input PNG files; `self._path` is used
            when this is not given.
        dest_path: Destination path for the output video; `self.path` is used
            when this is not given. The directory is created if missing.
        type: Output video type.
    """
    target_dir = Path(dest_path or self.path)
    target_dir.mkdir(parents=True, exist_ok=True)

    frames_dir = source_path or self._path
    video_file = target_dir / f"{fname}.{type}"
    convert(frames_dir, video_file, framerate, delete_if_exists, type=type)

flyvision.analysis.animations.animations.AnimationCollector

Bases: Animation

Collects Animations and updates all axes at once.

Subclasses must populate the animations attribute with Animation objects and adhere to the Animation interface.

Attributes:

Name Type Description
animations list[Animation]

List of Animation objects to collect.

Source code in flyvision/analysis/animations/animations.py
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
class AnimationCollector(Animation):
    """Drives several Animation objects together as one animation.

    Subclasses must populate the `animations` attribute with Animation objects
    and adhere to the Animation interface.

    Attributes:
        animations (list[Animation]): List of Animation objects to collect.
    """

    # Class-level default; this list object is shared until an instance or
    # subclass rebinds `animations`.
    animations: list[Animation] = []

    def init(self, frame: int = 0) -> None:
        """Initialize every collected animation.

        Each child has its `update` flag switched off after it is initialized.

        Args:
            frame: Initial frame number.
        """
        for child in self.animations:
            child.init(frame)
            child.update = False

    def animate(self, frame: int) -> None:
        """Advance every collected animation to the given frame.

        Args:
            frame: Frame number to animate.
        """
        for child in self.animations:
            child.animate(frame)

        if not self.update:
            return
        self.update_figure()

    def __setattr__(self, key: str, val: Any) -> None:
        """Set an attribute, mirroring `batch_sample` onto all children.

        Args:
            key: Attribute name to set.
            val: Value to set for the attribute.
        """
        mirror_to_children = key == "batch_sample" and hasattr(self, "animations")
        if mirror_to_children:
            for child in self.animations:
                setattr(child, key, val)
        super().__setattr__(key, val)
init
init(frame=0)

Initialize all collected animations.

Parameters:

Name Type Description Default
frame int

Initial frame number.

0
Source code in flyvision/analysis/animations/animations.py
380
381
382
383
384
385
386
387
388
def init(self, frame: int = 0) -> None:
    """Initialize every collected animation.

    Each child has its `update` flag switched off after it is initialized.

    Args:
        frame: Initial frame number.
    """
    for child in self.animations:
        child.init(frame)
        child.update = False
animate
animate(frame)

Animate all collected animations for a single frame.

Parameters:

Name Type Description Default
frame int

Frame number to animate.

required
Source code in flyvision/analysis/animations/animations.py
390
391
392
393
394
395
396
397
398
399
def animate(self, frame: int) -> None:
    """Advance every collected animation to the given frame.

    The figure is refreshed once afterwards when `self.update` is truthy.

    Args:
        frame: Frame number to animate.
    """
    for child in self.animations:
        child.animate(frame)

    if not self.update:
        return
    self.update_figure()
__setattr__
__setattr__(key, val)

Set attributes for all Animation objects at once.

Parameters:

Name Type Description Default
key str

Attribute name to set.

required
val Any

Value to set for the attribute.

required
Source code in flyvision/analysis/animations/animations.py
401
402
403
404
405
406
407
408
409
410
411
def __setattr__(self, key: str, val: Any) -> None:
    """Set an attribute, mirroring `batch_sample` onto all children.

    Args:
        key: Attribute name to set.
        val: Value to set for the attribute.
    """
    mirror_to_children = key == "batch_sample" and hasattr(self, "animations")
    if mirror_to_children:
        for child in self.animations:
            setattr(child, key, val)
    super().__setattr__(key, val)

Functions

flyvision.analysis.animations.animations.convert

convert(
    directory, dest, framerate, delete_if_exists, type="mp4"
)

Convert PNG files in directory to MP4 or WebM.

Parameters:

Name Type Description Default
directory Union[str, Path]

Source directory containing PNG files.

required
dest Union[str, Path]

Destination path for the output video.

required
framerate int

Frame rate of the output video.

required
delete_if_exists bool

Whether to delete existing output file.

required
type Literal['mp4', 'webm']

Output video type.

'mp4'

Raises:

Type Description
ValueError

If unsupported video type is specified.

FileExistsError

If output file exists and delete_if_exists is False.

Source code in flyvision/analysis/animations/animations.py
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
def convert(
    directory: Union[str, Path],
    dest: Union[str, Path],
    framerate: int,
    delete_if_exists: bool,
    type: Literal["mp4", "webm"] = "mp4",
) -> None:
    """Encode the PNG frames found in a directory as an MP4 or WebM video.

    Frames are collected with the glob pattern `<directory>/*_*.png`.

    Args:
        directory: Source directory containing PNG files.
        dest: Destination path for the output video.
        framerate: Frame rate of the output video.
        delete_if_exists: Whether to delete existing output file.
        type: Output video type.

    Raises:
        ValueError: If unsupported video type is specified.
        FileExistsError: If output file exists and delete_if_exists is False.
    """
    video = Path(dest)

    # Reject unknown container types before touching the file system.
    if type not in ("mp4", "webm"):
        raise ValueError(f"Unsupported video type: {type}")

    if type == "mp4":
        encoder_options = {
            "vcodec": "libx264",
            "vprofile": "high",
            "vlevel": "4.0",
            "vf": "pad=ceil(iw/2)*2:ceil(ih/2)*2",  # to make sizes even
            "pix_fmt": "yuv420p",
            "crf": 18,
        }
    else:
        encoder_options = {
            "vcodec": "libvpx-vp9",
            "vf": "pad=ceil(iw/2)*2:ceil(ih/2)*2",
            "pix_fmt": "yuva420p",
            "crf": 18,
            "threads": 4,
        }

    if video.exists():
        if not delete_if_exists:
            raise FileExistsError(f"File {video} already exists.")
        video.unlink()

    frame_pattern = f"{directory}/*_*.png"
    try:
        stream = ffmpeg.input(frame_pattern, pattern_type="glob", framerate=framerate)
        stream = stream.output(str(video), **encoder_options)
        stream.run(
            overwrite_output=True,
            quiet=True,
            capture_stdout=True,
            capture_stderr=True,
        )
    except FileNotFoundError as e:
        # A FileNotFoundError mentioning "ffmpeg" is only warned about;
        # any other one is propagated.
        if "ffmpeg" not in str(e):
            raise
        logging.warning("Check ffmpeg installation: %s", e)
        return
    except ffmpeg.Error as e:
        logging.error("ffmpeg error: %s", e.stderr.decode("utf8"))
        raise

    logging.info("Created %s", video)

flyvision.analysis.animations.traces

Classes

flyvision.analysis.animations.traces.Trace

Bases: Animation

Animates a trace.

Parameters:

Name Type Description Default
trace ndarray

Trace of shape (n_samples, n_frames).

required
dt float

Time step in seconds for accurate time axis.

1
fig Optional[Figure]

Existing Figure instance or None.

None
ax Optional[Axes]

Existing Axis instance or None.

None
update bool

Whether to update the canvas after an animation step. Must be False if this animation is composed with others.

False
color Optional[Union[str, ndarray]]

Optional color of the trace.

None
title str

Optional title of the animation.

''
batch_sample int

Batch sample to start from.

0
dynamic_ax_lims bool

Whether the ax limits of the trace are animated.

True
ylims Optional[List[Tuple[float, float]]]

Static y-limits for the trace for each sample.

None
ylabel str

Optional y-label of the trace.

''
contour Optional[ndarray]

Optional background contour for trace in x direction.

None
label str

Label of the animation. Formatted with the current sample and frame number.

'Sample: {}\nFrame: {}'
labelxy Tuple[float, float]

Normalized x and y location of the label.

(0.1, 0.95)
fontsize float

Fontsize.

5
figsize Tuple[float, float]

Figure size.

(2, 2)

Attributes:

Name Type Description
trace ndarray

Trace data.

n_samples int

Number of samples.

frames int

Number of frames.

fig Optional[Figure]

Figure instance.

ax Optional[Axes]

Axes instance.

update bool

Update flag.

color Optional[Union[str, ndarray]]

Color of the trace.

label str

Label format string.

labelxy Tuple[float, float]

Label position.

label_text

Label text object.

batch_sample int

Current batch sample.

fontsize float

Font size.

dynamic_ax_lims bool

Dynamic axis limits flag.

ylabel str

Y-axis label.

ylims Optional[List[Tuple[float, float]]]

Y-axis limits.

title str

Plot title.

contour Optional[ndarray]

Contour data.

contour_lims Optional[ndarray]

Contour limits.

dt float

Time step.

figsize Tuple[float, float]

Figure size.

Source code in flyvision/analysis/animations/traces.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
class Trace(Animation):
    """Animates a trace.

    Args:
        trace: Trace of shape (n_samples, n_frames).
        dt: Time step in seconds for accurate time axis.
        fig: Existing Figure instance or None.
        ax: Existing Axis instance or None.
        update: Whether to update the canvas after an animation step.
            Must be False if this animation is composed with others.
        color: Optional color of the trace.
        title: Optional title of the animation.
        batch_sample: Batch sample to start from.
        dynamic_ax_lims: Whether the ax limits of the trace are animated.
        ylims: Static y-limits for the trace for each sample.
        ylabel: Optional y-label of the trace.
        contour: Optional background contour for trace in x direction.
        label: Label of the animation. Formatted with the current sample and
            frame number; a label with fewer placeholders is also accepted.
        labelxy: Normalized x and y location of the label.
        fontsize: Fontsize.
        figsize: Figure size.

    Attributes:
        trace (np.ndarray): Trace data.
        n_samples (int): Number of samples.
        frames (int): Number of frames.
        fig (Optional[Figure]): Figure instance.
        ax (Optional[Axes]): Axes instance.
        update (bool): Update flag.
        color (Optional[Union[str, np.ndarray]]): Color of the trace.
        label (str): Label format string.
        labelxy (Tuple[float, float]): Label position.
        label_text: Label text object.
        batch_sample (int): Current batch sample.
        fontsize (float): Font size.
        dynamic_ax_lims (bool): Dynamic axis limits flag.
        ylabel (str): Y-axis label.
        ylims (Optional[List[Tuple[float, float]]]): Y-axis limits.
        title (str): Plot title.
        contour (Optional[np.ndarray]): Contour data.
        contour_lims (Optional[np.ndarray]): Per-sample contour limits, or
            None when no contour was given.
        dt (float): Time step.
        figsize (Tuple[float, float]): Figure size.
    """

    def __init__(
        self,
        trace: np.ndarray,
        dt: float = 1,
        fig: Optional[Figure] = None,
        ax: Optional[Axes] = None,
        update: bool = False,
        color: Optional[Union[str, np.ndarray]] = None,
        title: str = "",
        batch_sample: int = 0,
        dynamic_ax_lims: bool = True,
        ylims: Optional[List[Tuple[float, float]]] = None,
        ylabel: str = "",
        contour: Optional[np.ndarray] = None,
        label: str = "Sample: {}\nFrame: {}",
        labelxy: Tuple[float, float] = (0.1, 0.95),
        fontsize: float = 5,
        figsize: Tuple[float, float] = (2, 2),
    ):
        self.trace = utils.tensor_utils.to_numpy(trace)
        self.n_samples, self.frames = self.trace.shape
        self.fig = fig
        self.ax = ax
        self.update = update
        self.color = color
        self.label = label
        self.labelxy = labelxy
        self.label_text = None
        self.batch_sample = batch_sample
        self.fontsize = fontsize
        self._initial_frame = 0
        self.dynamic_ax_lims = dynamic_ax_lims
        self.ylabel = ylabel
        self.ylims = ylims
        self.title = title
        self.contour = contour
        # Always define the documented attribute, even without a contour.
        self.contour_lims = None
        if self.contour is not None:
            self.contour_lims = np.array([
                plt_utils.get_lims(c, 0.01) for c in self.contour
            ])
        self.dt = dt
        self.figsize = figsize
        super().__init__(None, self.fig)

    def init(self, frame: int = 0) -> None:
        """Initialize the animation.

        Args:
            frame: Starting frame number. Negative values count from the end.
        """
        if frame < 0:
            frame += self.frames
        trace = self.trace[self.batch_sample, self._initial_frame : frame + 1]
        # Build the time axis over the same frame range as the trace slice.
        x = np.arange(self._initial_frame, frame + 1) * self.dt
        self.fig, self.ax, _, self.label_text = plots.traces(
            trace,
            x=x,
            contour=None,
            smooth=None,
            fig=self.fig,
            ax=self.ax,
            label=self.label,
            color=self.color,
            labelxy=self.labelxy,
            xlabel="time in s",
            ylabel=self.ylabel,
            fontsize=self.fontsize,
            title=self.title,
            figsize=self.figsize,
        )

        self._plot_contour()

        if self.dynamic_ax_lims:
            self._update_ax_lims(trace, x)

        self._sample = self.batch_sample

    def animate(self, frame: int) -> None:
        """Animate a single frame.

        Args:
            frame: Current frame number. Negative values count from the end.
        """
        if frame < 0:
            frame += self.frames
        trace = self.trace[self.batch_sample, self._initial_frame : frame]
        x = np.arange(self._initial_frame, frame) * self.dt
        self.ax.lines[0].set_data(x, trace)

        # The background contour is per sample; redraw only on sample change.
        if self.batch_sample != self._sample:
            self._plot_contour()

        if self.dynamic_ax_lims:
            self._update_ax_lims(trace, x)

        if self.label:
            # The default label has two placeholders (sample and frame).
            # str.format ignores surplus positional arguments, so labels with
            # a single placeholder keep working.
            self.label_text.set_text(self.label.format(self.batch_sample, frame))

        if self.update:
            self.update_figure()

        self._sample = self.batch_sample

    def _update_ax_lims(self, trace: np.ndarray, x: np.ndarray) -> None:
        """Set the axis limits from ``ylims`` (if given) or the shown data."""
        if self.ylims is not None:
            ymin, ymax = self.ylims[self.batch_sample]
        else:
            ymin, ymax = plt_utils.get_lims(trace, 0.1)
        xmin, xmax = plt_utils.get_lims(x, 0.1)
        self.ax.axis([xmin, xmax, ymin, ymax])

    def _plot_contour(self) -> None:
        """Plot the contour if available."""
        if self.contour is None:
            return

        contour = self.contour[self.batch_sample]

        # Iterate over a snapshot because remove() mutates ax.collections.
        for collection in list(self.ax.collections):
            collection.remove()

        x = np.arange(len(contour)) * self.dt
        _y = np.linspace(-2000, 2000, 100)
        # Repeat the 1D contour along y so it fills the background.
        Z = np.tile(contour, (len(_y), 1))

        self.ax.contourf(
            x,
            _y,
            Z,
            # Pass the colormap by name: cm.get_cmap is removed in recent
            # matplotlib releases, and contourf accepts names directly.
            cmap="binary_r",
            levels=20,
            zorder=-100,
            alpha=0.2,
            vmin=self.contour_lims[self.batch_sample, 0],
            vmax=self.contour_lims[self.batch_sample, 1],
        )
init
init(frame=0)

Initialize the animation.

Parameters:

Name Type Description Default
frame int

Starting frame number.

0
Source code in flyvision/analysis/animations/traces.py
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
def init(self, frame: int = 0) -> None:
    """Draw the initial state of the trace animation.

    Plots the trace of the current batch sample up to and including
    ``frame``, draws the background contour (if any) and stores the
    sample that was drawn.

    Args:
        frame: Starting frame number. Negative values count from the end.
    """
    frame = frame + self.frames if frame < 0 else frame
    stop = frame + 1
    values = self.trace[self.batch_sample, self._initial_frame : stop]
    times = np.arange(stop) * self.dt

    plot_options = dict(
        x=times,
        contour=None,
        smooth=None,
        fig=self.fig,
        ax=self.ax,
        label=self.label,
        color=self.color,
        labelxy=self.labelxy,
        xlabel="time in s",
        ylabel=self.ylabel,
        fontsize=self.fontsize,
        title=self.title,
        figsize=self.figsize,
    )
    self.fig, self.ax, _, self.label_text = plots.traces(values, **plot_options)

    self._plot_contour()

    if self.dynamic_ax_lims:
        # User-provided y-limits win over limits derived from the data.
        if self.ylims is None:
            y_low, y_high = plt_utils.get_lims(values, 0.1)
        else:
            y_low, y_high = self.ylims[self.batch_sample]
        x_low, x_high = plt_utils.get_lims(times, 0.1)
        self.ax.axis([x_low, x_high, y_low, y_high])

    # Remember which sample is drawn so animate() can detect a change.
    self._sample = self.batch_sample
animate
animate(frame)

Animate a single frame.

Parameters:

Name Type Description Default
frame int

Current frame number.

required
Source code in flyvision/analysis/animations/traces.py
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
def animate(self, frame: int) -> None:
    """Animate a single frame.

    Updates the plotted line with the trace up to (excluding) ``frame``,
    redraws the contour when the batch sample changed, and refreshes the
    axis limits and the label.

    Args:
        frame: Current frame number. Negative values count from the end.
    """
    if frame < 0:
        frame += self.frames
    trace = self.trace[self.batch_sample, self._initial_frame : frame]
    x = np.arange(self._initial_frame, frame) * self.dt
    self.ax.lines[0].set_data(x, trace)

    # The background contour is per sample; redraw only on sample change.
    if self.batch_sample != self._sample:
        self._plot_contour()

    if self.dynamic_ax_lims:
        if self.ylims is not None:
            ymin, ymax = self.ylims[self.batch_sample]
        else:
            ymin, ymax = plt_utils.get_lims(trace, 0.1)
        xmin, xmax = plt_utils.get_lims(x, 0.1)
        self.ax.axis([xmin, xmax, ymin, ymax])

    if self.label:
        # The default label has two placeholders (sample and frame); passing
        # only the sample raised IndexError. str.format ignores surplus
        # positional arguments, so single-placeholder labels keep working.
        self.label_text.set_text(self.label.format(self.batch_sample, frame))

    if self.update:
        self.update_figure()

    self._sample = self.batch_sample

flyvision.analysis.animations.traces.MultiTrace

Bases: Animation

Animates multiple traces in single plot.

Parameters:

Name Type Description Default
trace ndarray

Trace of shape (n_samples, n_frames, n_traces).

required
dt float

Time step in seconds.

1
fig Optional[Figure]

Existing Figure instance or None.

None
ax Optional[Axes]

Existing Axis instance or None.

None
update bool

Whether to update the figure after each frame.

False
legend Optional[List[str]]

Legends of the traces.

None
colors Optional[List[Union[str, ndarray]]]

Optional colors of the traces.

None
title str

Optional title of the animation.

''
batch_sample int

Batch sample to start from.

0
dynamic_ax_lims bool

Whether the ax limits of the trace are animated.

True
ylims Optional[List[Tuple[float, float]]]

Static y-limits for the trace for each sample.

None
ylabel str

Optional y-label of the trace.

''
contour Optional[ndarray]

Optional background contour for trace in x direction.

None
label str

Label of the animation. Formatted with the current sample and frame number.

'Sample: {}\nFrame: {}'
labelxy Tuple[float, float]

Normalized x and y location of the label.

(0.1, 0.95)
fontsize float

Fontsize.

5
path Optional[str]

Path object to save animation to.

None

Attributes:

Name Type Description
trace ndarray

Trace data.

n_samples int

Number of samples.

frames int

Number of frames.

n_trace int

Number of traces.

fig Optional[Figure]

Figure instance.

ax Optional[Axes]

Axes instance.

update bool

Update flag.

colors Optional[List[Union[str, ndarray]]]

Colors of the traces.

label str

Label format string.

labelxy Tuple[float, float]

Label position.

label_text

Label text object.

legend Optional[List[str]]

Legend labels.

batch_sample int

Current batch sample.

fontsize float

Font size.

dynamic_ax_lims bool

Dynamic axis limits flag.

ylabel str

Y-axis label.

ylims Optional[List[Tuple[float, float]]]

Y-axis limits.

title str

Plot title.

contour Optional[ndarray]

Contour data.

dt float

Time step.

Source code in flyvision/analysis/animations/traces.py
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
class MultiTrace(Animation):
    """Animates multiple traces in single plot.

    Args:
        trace: Trace of shape (n_samples, n_frames, n_traces).
        dt: Time step in seconds.
        fig: Existing Figure instance or None.
        ax: Existing Axis instance or None.
        update: Whether to update the figure after each frame.
        legend: Legends of the traces.
        colors: Optional colors of the traces.
        title: Optional title of the animation.
        batch_sample: Batch sample to start from.
        dynamic_ax_lims: Whether the ax limits of the trace are animated.
        ylims: Static y-limits for the trace for each sample.
        ylabel: Optional y-label of the trace.
        contour: Optional background contour for trace in x direction.
        label: Label of the animation. Formatted with the current sample and frame number.
        labelxy: Normalized x and y location of the label.
        fontsize: Fontsize.
        path: Path object to save animation to.

    Attributes:
        trace (np.ndarray): Trace data.
        n_samples (int): Number of samples.
        frames (int): Number of frames.
        n_trace (int): Number of traces.
        fig (Optional[Figure]): Figure instance.
        ax (Optional[Axes]): Axes instance.
        update (bool): Update flag.
        colors (Optional[List[Union[str, np.ndarray]]]): Colors of the traces.
        label (str): Label format string.
        labelxy (Tuple[float, float]): Label position.
        label_text: Label text object.
        legend (Optional[List[str]]): Legend labels.
        batch_sample (int): Current batch sample.
        fontsize (float): Font size.
        dynamic_ax_lims (bool): Dynamic axis limits flag.
        ylabel (str): Y-axis label.
        ylims (Optional[List[Tuple[float, float]]]): Y-axis limits.
        title (str): Plot title.
        contour (Optional[np.ndarray]): Contour data.
        dt (float): Time step.
    """

    def __init__(
        self,
        trace: np.ndarray,
        dt: float = 1,
        fig: Optional[Figure] = None,
        ax: Optional[Axes] = None,
        update: bool = False,
        legend: Optional[List[str]] = None,
        colors: Optional[List[Union[str, np.ndarray]]] = None,
        title: str = "",
        batch_sample: int = 0,
        dynamic_ax_lims: bool = True,
        ylims: Optional[List[Tuple[float, float]]] = None,
        ylabel: str = "",
        contour: Optional[np.ndarray] = None,
        label: str = "Sample: {}\nFrame: {}",
        labelxy: Tuple[float, float] = (0.1, 0.95),
        fontsize: float = 5,
        path: Optional[str] = None,
    ):
        self.trace = utils.tensor_utils.to_numpy(trace)
        self.n_samples, self.frames, self.n_trace = self.trace.shape
        self.fig = fig
        self.ax = ax
        self.update = update
        self.colors = colors
        self.label = label
        self.labelxy = labelxy
        self.label_text = None
        self.legend = legend
        self.batch_sample = batch_sample
        self.fontsize = fontsize
        self._initial_frame = 0
        self.dynamic_ax_lims = dynamic_ax_lims
        self.ylabel = ylabel
        self.ylims = ylims
        self.title = title
        self.contour = contour
        self.dt = dt
        super().__init__(path, self.fig)

    def init(self, frame: int = 0) -> None:
        """Initialize the animation.

        Args:
            frame: Starting frame number. Negative values count from the end.
        """
        if frame < 0:
            frame += self.frames
        self._initial_frame = frame
        trace = self.trace[self.batch_sample, frame]
        # One time point per trace: the x vector must hold exactly the time
        # of the initial frame so it matches the single sample in each row.
        # NOTE(review): assumes plots.traces plots x against each row of the
        # trace argument -- confirm against plots.traces.
        x = np.arange(frame, frame + 1) * self.dt
        self.fig, self.ax, _, self.label_text = plots.traces(
            trace[:, None],
            x=x,
            contour=self.contour,
            smooth=None,
            fig=self.fig,
            ax=self.ax,
            label=self.label,
            color=self.colors,
            labelxy=self.labelxy,
            xlabel="time (s)",
            ylabel=self.ylabel,
            fontsize=self.fontsize,
            title=self.title,
            legend=self.legend,
        )
        if not self.dynamic_ax_lims:
            # Static limits: fix the view once so it covers the whole trace.
            if self.ylims is not None:
                ymin, ymax = self.ylims[self.batch_sample]
            else:
                ymin, ymax = plt_utils.get_lims(self.trace, 0.1)
            # Scale by dt: animate() plots the traces on a time axis in
            # seconds, not in frame indices.
            xmin, xmax = plt_utils.get_lims(
                np.arange(self._initial_frame, self.trace.shape[1]) * self.dt, 0.1
            )
            self.ax.axis([xmin, xmax, ymin, ymax])

        self._sample = self.batch_sample

    def animate(self, frame: int) -> None:
        """Animate a single frame.

        Args:
            frame: Current frame number. Negative values count from the end.
        """
        if frame < 0:
            frame += self.frames
        trace = self.trace[self.batch_sample, self._initial_frame : frame]
        x = np.arange(self._initial_frame, frame) * self.dt

        for n in range(self.n_trace):
            self.ax.lines[n].set_data(x, trace[:, n])

        contour = self.contour[self.batch_sample] if self.contour is not None else None

        # The background contour is per sample; redraw only on sample change.
        if self.batch_sample != self._sample and contour is not None:
            # Axes.collections is read-only in current matplotlib, so remove
            # the artists one by one (over a snapshot, since remove() mutates
            # the underlying list) instead of assigning an empty list.
            for collection in list(self.ax.collections):
                collection.remove()
            # Same time axis (seconds) as the traces.
            _x = np.arange(len(contour)) * self.dt
            _y = np.linspace(-2000, 2000, 100)
            Z = np.tile(contour, (len(_y), 1))
            self.ax.contourf(
                _x,
                _y,
                Z,
                # Pass the colormap by name: cm.get_cmap is removed in recent
                # matplotlib releases, and contourf accepts names directly.
                cmap="bone",
                levels=2,
                alpha=0.3,
                vmin=0,
                vmax=1,
            )

        if self.dynamic_ax_lims:
            if self.ylims is not None:
                ymin, ymax = self.ylims[self.batch_sample]
            else:
                ymin, ymax = plt_utils.get_lims(trace, 0.1)
            xmin, xmax = plt_utils.get_lims(x, 0.1)
            self.ax.axis([xmin, xmax, ymin, ymax])

        if self.label:
            self.label_text.set_text(self.label.format(self.batch_sample, frame))

        if self.update:
            self.update_figure()

        self._sample = self.batch_sample
init
init(frame=0)

Initialize the animation.

Parameters:

Name Type Description Default
frame int

Starting frame number.

0
Source code in flyvision/analysis/animations/traces.py
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
def init(self, frame: int = 0) -> None:
    """Initialize the animation.

    Plots the value of every trace at ``frame`` for the current batch
    sample, remembers ``frame`` as the first animated frame and, when the
    axis limits are static, fixes them once for the whole trace.

    Args:
        frame: Starting frame number. Negative values count from the end.
    """
    if frame < 0:
        frame += self.frames
    self._initial_frame = frame
    trace = self.trace[self.batch_sample, frame]
    # One time point per trace: the x vector must hold exactly the time of
    # the initial frame so it matches the single sample in each row.
    # NOTE(review): assumes plots.traces plots x against each row of the
    # trace argument -- confirm against plots.traces.
    x = np.arange(frame, frame + 1) * self.dt
    self.fig, self.ax, _, self.label_text = plots.traces(
        trace[:, None],
        x=x,
        contour=self.contour,
        smooth=None,
        fig=self.fig,
        ax=self.ax,
        label=self.label,
        color=self.colors,
        labelxy=self.labelxy,
        xlabel="time (s)",
        ylabel=self.ylabel,
        fontsize=self.fontsize,
        title=self.title,
        legend=self.legend,
    )
    if not self.dynamic_ax_lims:
        # Static limits: fix the view once so it covers the whole trace.
        if self.ylims is not None:
            ymin, ymax = self.ylims[self.batch_sample]
        else:
            ymin, ymax = plt_utils.get_lims(self.trace, 0.1)
        # Scale by dt: animate() plots the traces on a time axis in seconds,
        # not in frame indices.
        xmin, xmax = plt_utils.get_lims(
            np.arange(self._initial_frame, self.trace.shape[1]) * self.dt, 0.1
        )
        self.ax.axis([xmin, xmax, ymin, ymax])

    self._sample = self.batch_sample
animate
animate(frame)

Animate a single frame.

Parameters:

Name Type Description Default
frame int

Current frame number.

required
Source code in flyvision/analysis/animations/traces.py
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
def animate(self, frame: int) -> None:
    """Animate a single frame.

    Updates every plotted line with its trace from the initial frame up to
    (excluding) ``frame``, redraws the background contour when the batch
    sample changed, and refreshes the axis limits and the label.

    Args:
        frame: Current frame number. Negative values count from the end.
    """
    if frame < 0:
        frame += self.frames
    trace = self.trace[self.batch_sample, self._initial_frame : frame]
    x = np.arange(self._initial_frame, frame) * self.dt

    for n in range(self.n_trace):
        self.ax.lines[n].set_data(x, trace[:, n])

    contour = self.contour[self.batch_sample] if self.contour is not None else None

    # The background contour is per sample; redraw only on sample change.
    if self.batch_sample != self._sample and contour is not None:
        # Axes.collections is read-only in current matplotlib, so remove the
        # artists one by one (over a snapshot, since remove() mutates the
        # underlying list) instead of assigning an empty list.
        for collection in list(self.ax.collections):
            collection.remove()
        # Same time axis (seconds) as the traces.
        _x = np.arange(len(contour)) * self.dt
        _y = np.linspace(-2000, 2000, 100)
        Z = np.tile(contour, (len(_y), 1))
        self.ax.contourf(
            _x,
            _y,
            Z,
            # Pass the colormap by name: cm.get_cmap is removed in recent
            # matplotlib releases, and contourf accepts names directly.
            cmap="bone",
            levels=2,
            alpha=0.3,
            vmin=0,
            vmax=1,
        )

    if self.dynamic_ax_lims:
        if self.ylims is not None:
            ymin, ymax = self.ylims[self.batch_sample]
        else:
            ymin, ymax = plt_utils.get_lims(trace, 0.1)
        xmin, xmax = plt_utils.get_lims(x, 0.1)
        self.ax.axis([xmin, xmax, ymin, ymax])

    if self.label:
        self.label_text.set_text(self.label.format(self.batch_sample, frame))

    if self.update:
        self.update_figure()

    self._sample = self.batch_sample