Skip to content

Ensemble Clustering

flyvision.analysis.clustering.compute_umap_and_clustering

compute_umap_and_clustering(
    ensemble,
    cell_type,
    embedding_kwargs=None,
    gm_kwargs=None,
    subdir="umap_and_clustering",
)

Compute UMAP embedding and Gaussian Mixture clustering of responses.

Parameters:

Name Type Description Default
ensemble EnsembleView

EnsembleView object.

required
cell_type str

Type of cell to analyze.

required
embedding_kwargs Optional[Dict]

UMAP embedding parameters.

None
gm_kwargs Optional[Dict]

Gaussian Mixture clustering parameters.

None
subdir str

Subdirectory for storing results.

'umap_and_clustering'

Returns:

Type Description
GaussianMixtureClustering

GaussianMixtureClustering object.

Note

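If a pickled result for the cell type already exists in the results subdirectory, it is loaded from disk instead of being recomputed. The function shown below does not itself write newly computed results to disk.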

Source code in flyvision/analysis/clustering.py
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
def compute_umap_and_clustering(
    ensemble: "flyvision.network.EnsembleView",
    cell_type: str,
    embedding_kwargs: Optional[Dict] = None,
    gm_kwargs: Optional[Dict] = None,
    subdir: str = "umap_and_clustering",
) -> GaussianMixtureClustering:
    """
    Embed the responses of one cell type with UMAP and cluster the embedding.

    If a pickle named after ``cell_type`` already exists in
    ``ensemble.path / subdir``, its content is loaded and returned without
    recomputing anything. Otherwise the naturalistic stimuli responses of the
    ensemble are embedded with UMAP and clustered with a Gaussian mixture.

    Args:
        ensemble: EnsembleView object.
        cell_type: Type of cell to analyze.
        embedding_kwargs: UMAP embedding parameters. Built-in defaults are used
            when None.
        gm_kwargs: Gaussian Mixture clustering parameters. Built-in defaults are
            used when None.
        subdir: Subdirectory of ``ensemble.path`` that is searched for a stored
            result.

    Returns:
        GaussianMixtureClustering object (or, when a stored pickle is found,
        whatever object that pickle contains).

    Note:
        This function only reads stored results; it does not write a freshly
        computed result to disk itself.
    """
    umap_settings = embedding_kwargs
    if umap_settings is None:
        umap_settings = dict(
            min_dist=0.105,
            spread=9.0,
            n_neighbors=5,
            random_state=42,
            n_epochs=1500,
        )

    mixture_settings = gm_kwargs
    if mixture_settings is None:
        mixture_settings = dict(
            # NOTE(review): 3 is listed twice -- looks unintentional, but it is
            # kept as-is because it affects the returned scores/n_clusters.
            range_n_clusters=[2, 3, 3, 4, 5],
            n_init=100,
            max_iter=1000,
            random_state=42,
            tol=0.001,
        )

    destination = ensemble.path / subdir
    stored_result = (destination / cell_type).with_suffix(".pickle")

    if stored_result.exists():
        # SECURITY: unpickling executes arbitrary code; only point this at
        # result directories you trust.
        with open(stored_result, "rb") as stream:
            cached = pickle.load(stream)
        logging.info("Loaded %s embedding and clustering from %s", cell_type, destination)
        return cached

    # Nothing stored: compute the embedding and clustering from the responses.
    responses = naturalistic_stimuli_responses(ensemble)
    central_responses = CentralActivity(
        responses['responses'].values, ensemble[0].connectome, keepref=True
    )
    cell_embedding = EnsembleEmbedding(central_responses).from_cell_type(
        cell_type, embedding_kwargs=umap_settings
    )
    return cell_embedding.cluster.gaussian_mixture(**mixture_settings)

flyvision.analysis.clustering.umap_embedding

umap_embedding(
    X,
    n_neighbors=5,
    min_dist=0.12,
    spread=9.0,
    random_state=42,
    n_components=2,
    metric="correlation",
    n_epochs=1500,
    **kwargs
)

Perform UMAP embedding on input data.

Parameters:

Name Type Description Default
X ndarray

Input data with shape (n_samples, n_features).

required
n_neighbors int

Number of neighbors to consider for each point.

5
min_dist float

Minimum distance between points in the embedding space.

0.12
spread float

Determines how spread out all embedded points are overall.

9.0
random_state int

Random seed for reproducibility.

42
n_components int

Number of dimensions in the embedding space.

2
metric str

Distance metric to use.

'correlation'
n_epochs int

Number of training epochs for embedding optimization.

1500
**kwargs

Additional keyword arguments for UMAP.

{}

Returns:

Type Description
Tuple[ndarray, ndarray, UMAP]

A tuple containing:

  • embedding: The UMAP embedding.
  • mask: Boolean mask for valid samples.
  • reducer: The fitted UMAP object.

Raises:

Type Description
ValueError

If n_components is too large relative to sample size.

Note

This function handles reshaping of input data and removes constant rows.

Source code in flyvision/analysis/clustering.py
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
def umap_embedding(
    X: np.ndarray,
    n_neighbors: int = 5,
    min_dist: float = 0.12,
    spread: float = 9.0,
    random_state: int = 42,
    n_components: int = 2,
    metric: str = "correlation",
    n_epochs: int = 1500,
    **kwargs,
) -> Tuple[np.ndarray, np.ndarray, UMAP]:
    """
    Compute a UMAP embedding of the rows of ``X``.

    Inputs with more than two dimensions are flattened to
    ``(n_samples, n_features)``. Rows with (numerically) zero standard
    deviation are excluded before fitting, and rows that UMAP reports as
    disconnected vertices are excluded afterwards. Excluded rows stay NaN in
    the returned embedding and are False in the returned mask.

    Args:
        X: Input data with shape (n_samples, n_features).
        n_neighbors: Number of neighbors to consider for each point.
        min_dist: Minimum distance between points in the embedding space.
        spread: Determines how spread out all embedded points are overall.
        random_state: Random seed for reproducibility.
        n_components: Number of dimensions in the embedding space.
        metric: Distance metric to use.
        n_epochs: Number of training epochs for embedding optimization.
        **kwargs: Additional keyword arguments for UMAP.

    Returns:
        A tuple containing:
        - embedding: Array of shape (n_samples, n_components); NaN for
          excluded rows.
        - mask: Boolean mask for valid samples.
        - reducer: The fitted UMAP object.

    Raises:
        ValueError: If n_components is too large relative to sample size.
    """
    # Deferred import: importing umap at module level would slow down the
    # import of the whole library.
    from umap import UMAP
    from umap.utils import disconnected_vertices

    n_samples = X.shape[0]
    if n_components > n_samples - 2:
        raise ValueError(
            "number of components must be 2 smaller than sample size. "
            "See: https://github.com/lmcinnes/umap/issues/201"
        )

    if len(X.shape) > 2:
        original_shape = X.shape
        X = X.reshape(n_samples, -1)
        logging.info("reshaped X from %s to %s", original_shape, X.shape)

    # Start from an all-NaN result; only rows that survive both filters below
    # are filled in.
    embedding = np.full((n_samples, n_components), np.nan)

    # umap doesn't like constant rows
    row_is_constant = np.isclose(X.std(axis=1), 0)
    mask = ~row_is_constant
    fitted_rows = np.flatnonzero(mask)  # indices of the rows handed to UMAP
    X = X[mask]

    reducer = UMAP(
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        random_state=random_state,
        n_components=n_components,
        metric=metric,
        spread=spread,
        n_epochs=n_epochs,
        **kwargs,
    )
    fitted = reducer.fit_transform(X)

    # gaussian mixture doesn't like nans through disconnected vertices in umap
    is_connected = ~disconnected_vertices(reducer)
    mask[fitted_rows[~is_connected]] = False
    embedding[fitted_rows[is_connected]] = fitted[is_connected]
    return embedding, mask, reducer

flyvision.analysis.clustering.GaussianMixtureClustering dataclass

Gaussian Mixture Clustering of the embeddings.

Attributes:

Name Type Description
embedding Embedding

The embedding to cluster.

range_n_clusters Iterable[int]

Range of number of clusters to try.

n_init int

Number of initializations for GMM.

max_iter int

Maximum number of iterations for GMM.

random_state int

Random state for reproducibility.

labels NDArray

Cluster labels.

gm object

Fitted GaussianMixture object.

scores list

Scores for each number of clusters.

n_clusters list

Number of clusters tried.

Source code in flyvision/analysis/clustering.py
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
@dataclass
class GaussianMixtureClustering:
    """
    Gaussian mixture clustering of an embedding.

    Calling the instance runs the clustering and stores both the results and
    the settings that were used on the instance.

    Attributes:
        embedding (Embedding): The embedding to cluster.
        range_n_clusters (Iterable[int]): Range of number of clusters to try.
        n_init (int): Number of initializations for GMM.
        max_iter (int): Maximum number of iterations for GMM.
        random_state (int): Random state for reproducibility.
        labels (npt.NDArray): Cluster labels.
        gm (object): Fitted GaussianMixture object.
        scores (list): Scores for each number of clusters.
        n_clusters (list): Number of clusters tried.
    """

    embedding: Embedding = None
    range_n_clusters: Iterable[int] = None
    n_init: int = 1
    max_iter: int = 1000
    random_state: int = 0
    labels: npt.NDArray = None
    gm: object = None
    scores: list = None
    n_clusters: list = None

    def __call__(
        self,
        range_n_clusters: Iterable[int] = None,
        n_init: int = 1,
        max_iter: int = 1000,
        random_state: int = 0,
        **kwargs,
    ) -> "GaussianMixtureClustering":
        """
        Run the Gaussian mixture clustering on the stored embedding.

        Args:
            range_n_clusters: Range of number of clusters to try.
            n_init: Number of initializations for GMM.
            max_iter: Maximum number of iterations for GMM.
            random_state: Random state for reproducibility.
            **kwargs: Additional arguments for gaussian_mixture function.

        Returns:
            This instance, with ``labels``, ``gm``, ``scores`` and
            ``n_clusters`` updated and the settings used recorded on it.
        """
        source = self.embedding
        fit_result = gaussian_mixture(
            source.embedding,
            source.mask,
            range_n_clusters=range_n_clusters,
            n_init=n_init,
            max_iter=max_iter,
            random_state=random_state,
            **kwargs,
        )
        self.labels, self.gm, self.scores, self.n_clusters = fit_result

        # Record the settings only after a successful fit so that `plot` can
        # forward them later.
        self.range_n_clusters, self.n_init = range_n_clusters, n_init
        self.max_iter, self.random_state = max_iter, random_state
        self.kwargs = kwargs
        return self

    def task_error_sort_labels(self, task_error: npt.NDArray, mode: str = "mean") -> None:
        """
        Replace ``labels`` by labels sorted according to task error.

        Delegates to the module-level ``task_error_sort_labels`` function.

        Args:
            task_error: Array of task errors.
            mode: Method to compute task error ('mean', 'min', or 'median').
        """
        current_labels = self.labels
        self.labels = task_error_sort_labels(task_error, current_labels, mode=mode)

    def plot(
        self,
        task_error: npt.NDArray = None,
        colors: npt.NDArray = None,
        annotate: bool = True,
        annotate_scores: bool = False,
        fig: Figure = None,
        ax: Axes = None,
        figsize: tuple = None,
        plot_mode: str = "paper",
        fontsize: int = 5,
        **kwargs,
    ) -> "EmbeddingPlot":
        """
        Plot the embedding together with the clustering results.

        Args:
            task_error: Array of task errors.
            colors: Colors for data points.
            annotate: Whether to annotate clusters.
            annotate_scores: Whether to annotate BIC scores.
            fig: Existing figure to plot on.
            ax: Existing axes to plot on.
            figsize: Size of the figure. Defaults to [0.94, 2.38].
            plot_mode: Mode for plotting ('paper', 'small', or 'large').
            fontsize: Font size for annotations.
            **kwargs: Additional arguments for plot_embedding function.

        Returns:
            An EmbeddingPlot object.

        Raises:
            AssertionError: If the embedding is not 2-dimensional.
        """
        points = self.embedding.embedding
        if points.shape[1] != 2:
            raise AssertionError("Embedding must be 2-dimensional for plotting")

        size = [0.94, 2.38] if figsize is None else figsize

        # NOTE(review): `self.kwargs` is only assigned in `__call__`, so this
        # presumably requires the clustering to have been run first -- confirm.
        fig, ax = plot_embedding(
            points,
            colors=colors,
            task_error=task_error,
            labels=self.labels,
            gm=self.gm,
            mask=self.embedding.mask,
            fit_gaussians=True,
            annotate=annotate,
            title="",
            fig=fig,
            ax=ax,
            mode=plot_mode,
            figsize=size,
            fontsize=fontsize,
            range_n_clusters=self.range_n_clusters,
            n_init_gaussian_mixture=self.n_init,
            gm_kwargs=self.kwargs,
            **kwargs,
        )

        if annotate_scores:
            # Write the smallest stored score into the top-left corner.
            best_score = np.min(self.scores)
            top_left = (ax.get_xlim()[0], ax.get_ylim()[1])
            ax.annotate(
                "BIC: {:.2f}".format(best_score),
                xy=top_left,
                ha="left",
                va="top",
                fontsize=fontsize,
            )
        return EmbeddingPlot(fig, ax, None, None, self.gm.n_components, self)

__call__

__call__(
    range_n_clusters=None,
    n_init=1,
    max_iter=1000,
    random_state=0,
    **kwargs
)

Perform Gaussian Mixture clustering.

Parameters:

Name Type Description Default
range_n_clusters Iterable[int]

Range of number of clusters to try.

None
n_init int

Number of initializations for GMM.

1
max_iter int

Maximum number of iterations for GMM.

1000
random_state int

Random state for reproducibility.

0
**kwargs

Additional arguments for gaussian_mixture function.

{}

Returns:

Type Description
GaussianMixtureClustering

Self with updated clustering results.

Source code in flyvision/analysis/clustering.py
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
def __call__(
    self,
    range_n_clusters: Iterable[int] = None,
    n_init: int = 1,
    max_iter: int = 1000,
    random_state: int = 0,
    **kwargs,
) -> "GaussianMixtureClustering":
    """
    Run the Gaussian mixture clustering on the stored embedding.

    Args:
        range_n_clusters: Range of number of clusters to try.
        n_init: Number of initializations for GMM.
        max_iter: Maximum number of iterations for GMM.
        random_state: Random state for reproducibility.
        **kwargs: Additional arguments for gaussian_mixture function.

    Returns:
        This instance, with ``labels``, ``gm``, ``scores`` and ``n_clusters``
        updated and the settings used recorded on it.
    """
    source = self.embedding
    fit_result = gaussian_mixture(
        source.embedding,
        source.mask,
        range_n_clusters=range_n_clusters,
        n_init=n_init,
        max_iter=max_iter,
        random_state=random_state,
        **kwargs,
    )
    self.labels, self.gm, self.scores, self.n_clusters = fit_result

    # Record the settings only after a successful fit.
    self.range_n_clusters, self.n_init = range_n_clusters, n_init
    self.max_iter, self.random_state = max_iter, random_state
    self.kwargs = kwargs
    return self

task_error_sort_labels

task_error_sort_labels(task_error, mode='mean')

Sort cluster labels based on task error.

Parameters:

Name Type Description Default
task_error NDArray

Array of task errors.

required
mode str

Method to compute task error (‘mean’, ‘min’, or ‘median’).

'mean'
Source code in flyvision/analysis/clustering.py
207
208
209
210
211
212
213
214
215
def task_error_sort_labels(self, task_error: npt.NDArray, mode: str = "mean") -> None:
    """
    Replace ``labels`` by labels sorted according to task error.

    Args:
        task_error: Array of task errors.
        mode: Method to compute task error ('mean', 'min', or 'median').
    """
    current_labels = self.labels
    self.labels = task_error_sort_labels(task_error, current_labels, mode=mode)

plot

plot(
    task_error=None,
    colors=None,
    annotate=True,
    annotate_scores=False,
    fig=None,
    ax=None,
    figsize=None,
    plot_mode="paper",
    fontsize=5,
    **kwargs
)

Plot the clustering results.

Parameters:

Name Type Description Default
task_error NDArray

Array of task errors.

None
colors NDArray

Colors for data points.

None
annotate bool

Whether to annotate clusters.

True
annotate_scores bool

Whether to annotate BIC scores.

False
fig Figure

Existing figure to plot on.

None
ax Axes

Existing axes to plot on.

None
figsize tuple

Size of the figure.

None
plot_mode str

Mode for plotting (‘paper’, ‘small’, or ‘large’).

'paper'
fontsize int

Font size for annotations.

5
**kwargs

Additional arguments for plot_embedding function.

{}

Returns:

Type Description
EmbeddingPlot

An EmbeddingPlot object.

Raises:

Type Description
AssertionError

If the embedding is not 2-dimensional.

Source code in flyvision/analysis/clustering.py
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
def plot(
    self,
    task_error: npt.NDArray = None,
    colors: npt.NDArray = None,
    annotate: bool = True,
    annotate_scores: bool = False,
    fig: Figure = None,
    ax: Axes = None,
    figsize: tuple = None,
    plot_mode: str = "paper",
    fontsize: int = 5,
    **kwargs,
) -> "EmbeddingPlot":
    """
    Plot the embedding together with the clustering results.

    Args:
        task_error: Array of task errors.
        colors: Colors for data points.
        annotate: Whether to annotate clusters.
        annotate_scores: Whether to annotate BIC scores.
        fig: Existing figure to plot on.
        ax: Existing axes to plot on.
        figsize: Size of the figure. Defaults to [0.94, 2.38].
        plot_mode: Mode for plotting ('paper', 'small', or 'large').
        fontsize: Font size for annotations.
        **kwargs: Additional arguments for plot_embedding function.

    Returns:
        An EmbeddingPlot object.

    Raises:
        AssertionError: If the embedding is not 2-dimensional.
    """
    points = self.embedding.embedding
    if points.shape[1] != 2:
        raise AssertionError("Embedding must be 2-dimensional for plotting")

    size = [0.94, 2.38] if figsize is None else figsize

    fig, ax = plot_embedding(
        points,
        colors=colors,
        task_error=task_error,
        labels=self.labels,
        gm=self.gm,
        mask=self.embedding.mask,
        fit_gaussians=True,
        annotate=annotate,
        title="",
        fig=fig,
        ax=ax,
        mode=plot_mode,
        figsize=size,
        fontsize=fontsize,
        range_n_clusters=self.range_n_clusters,
        n_init_gaussian_mixture=self.n_init,
        gm_kwargs=self.kwargs,
        **kwargs,
    )

    if annotate_scores:
        # Write the smallest stored score into the top-left corner.
        best_score = np.min(self.scores)
        top_left = (ax.get_xlim()[0], ax.get_ylim()[1])
        ax.annotate(
            "BIC: {:.2f}".format(best_score),
            xy=top_left,
            ha="left",
            va="top",
            fontsize=fontsize,
        )
    return EmbeddingPlot(fig, ax, None, None, self.gm.n_components, self)

flyvision.analysis.clustering.EnsembleEmbedding

Embedding of the ensemble responses.

Args: responses (CentralActivity): CentralActivity object

Source code in flyvision/analysis/clustering.py
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
class EnsembleEmbedding:
    """Factory for UMAP embeddings of ensemble responses.

    Args:
        responses (CentralActivity): CentralActivity object that is indexed by
            cell type to obtain the data to embed.
    """

    def __init__(self, responses: CentralActivity):
        self.responses = responses

    def from_cell_type(
        self,
        cell_type,
        embedding_kwargs=None,
    ) -> Embedding:
        """Return the UMAP embedding of the responses of a specific cell type.

        Args:
            cell_type: Key used to index ``self.responses``.
            embedding_kwargs: Keyword arguments forwarded to ``umap_embedding``;
                a falsy value means no extra arguments.

        Returns:
            Embedding built from the values returned by ``umap_embedding``.
        """
        if not embedding_kwargs:
            embedding_kwargs = {}
        cell_responses = self.responses[cell_type]
        umap_result = umap_embedding(cell_responses, **embedding_kwargs)
        return Embedding(*umap_result)

    def __call__(
        self,
        arg: Union[str, Iterable],
        embedding_kwargs=None,
    ):
        """Embed the responses selected by ``arg``.

        Only a cell type given as a string is supported.

        Raises:
            ValueError: If ``arg`` is not a string.
        """
        if not isinstance(arg, str):
            raise ValueError("arg")
        return self.from_cell_type(arg, embedding_kwargs)

from_cell_type

from_cell_type(cell_type, embedding_kwargs=None)

Umap Embedding of the responses of a specific cell type.

Source code in flyvision/analysis/clustering.py
411
412
413
414
415
416
417
418
419
def from_cell_type(
    self,
    cell_type,
    embedding_kwargs=None,
) -> Embedding:
    """Return the UMAP embedding of the responses of a specific cell type.

    Args:
        cell_type: Key used to index ``self.responses``.
        embedding_kwargs: Keyword arguments forwarded to ``umap_embedding``;
            a falsy value means no extra arguments.

    Returns:
        Embedding built from the values returned by ``umap_embedding``.
    """
    if not embedding_kwargs:
        embedding_kwargs = {}
    cell_responses = self.responses[cell_type]
    umap_result = umap_embedding(cell_responses, **embedding_kwargs)
    return Embedding(*umap_result)

flyvision.analysis.clustering.Embedding dataclass

Embedding of the ensemble responses.

Attributes:

Name Type Description
embedding NDArray

The embedded data.

mask NDArray

Mask for valid data points.

reducer object

The reduction object used for embedding.

Source code in flyvision/analysis/clustering.py
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
@dataclass
class Embedding:
    """
    Container for an embedding of the ensemble responses.

    Attributes:
        embedding (npt.NDArray): The embedded data. Assignments go through the
            ``embedding`` property setter, which stores the output of
            ``scale_tensor``.
        mask (npt.NDArray): Mask for valid data points.
        reducer (object): The reduction object used for embedding.
    """

    # NOTE: the `embedding` field is shadowed on purpose by the property of the
    # same name further down; the property must stay defined after this field.
    embedding: npt.NDArray = None
    mask: npt.NDArray = None
    reducer: object = None

    @property
    def cluster(self) -> "Clustering":
        """Returns a new Clustering object wrapping this embedding."""
        return Clustering(self)

    @property
    def embedding(self) -> npt.NDArray:  # noqa: F811
        """Returns the stored embedded data, or None if nothing was stored yet."""
        try:
            return self._embedding
        except AttributeError:
            return None

    @embedding.setter
    def embedding(self, value: npt.NDArray) -> None:
        """
        Stores the embedding after passing it through ``scale_tensor``.

        The scaled array is kept as ``_embedding`` and the second value returned
        by ``scale_tensor`` as ``minmaxscaler`` (presumably a min-max scaler
        mapping to the range (0, 1) -- confirm against ``scale_tensor``).

        Args:
            value: The embedding array to set.
        """
        scaled, scaler = scale_tensor(value)
        self._embedding = scaled
        self.minmaxscaler = scaler

    def plot(
        self,
        fig: Figure = None,
        ax: Axes = None,
        figsize: tuple = None,
        plot_mode: str = "paper",
        fontsize: int = 5,
        colors: npt.NDArray = None,
        **kwargs,
    ) -> tuple[Figure, Axes]:
        """
        Plot the embedding without any clustering information.

        Args:
            fig: Existing figure to plot on.
            ax: Existing axes to plot on.
            figsize: Size of the figure. Defaults to [0.94, 2.38].
            plot_mode: Mode for plotting ('paper', 'small', or 'large').
            fontsize: Font size for annotations.
            colors: Colors for data points.
            **kwargs: Additional arguments passed to plot_embedding.

        Returns:
            A tuple containing the figure and axes objects.

        Raises:
            AssertionError: If the embedding is not 2-dimensional.
        """
        points = self.embedding
        if points.shape[1] != 2:
            raise AssertionError("Embedding must be 2-dimensional for plotting")

        size = [0.94, 2.38] if figsize is None else figsize

        # Labels, task errors and Gaussian fits are switched off here.
        return plot_embedding(
            points,
            colors=colors,
            task_error=None,
            labels=None,
            mask=self.mask,
            fit_gaussians=False,
            annotate=False,
            title="",
            fig=fig,
            ax=ax,
            mode=plot_mode,
            figsize=size,
            fontsize=fontsize,
            **kwargs,
        )

cluster property

cluster

Returns a Clustering object for this embedding.

embedding property writable

embedding

Returns the embedded data.

plot

plot(
    fig=None,
    ax=None,
    figsize=None,
    plot_mode="paper",
    fontsize=5,
    colors=None,
    **kwargs
)

Plot the embedding.

Parameters:

Name Type Description Default
fig Figure

Existing figure to plot on.

None
ax Axes

Existing axes to plot on.

None
figsize tuple

Size of the figure.

None
plot_mode str

Mode for plotting (‘paper’, ‘small’, or ‘large’).

'paper'
fontsize int

Font size for annotations.

5
colors NDArray

Colors for data points.

None
**kwargs

Additional arguments passed to plot_embedding.

{}

Returns:

Type Description
tuple[Figure, Axes]

A tuple containing the figure and axes objects.

Raises:

Type Description
AssertionError

If the embedding is not 2-dimensional.

Source code in flyvision/analysis/clustering.py
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
def plot(
    self,
    fig: Figure = None,
    ax: Axes = None,
    figsize: tuple = None,
    plot_mode: str = "paper",
    fontsize: int = 5,
    colors: npt.NDArray = None,
    **kwargs,
) -> tuple[Figure, Axes]:
    """
    Plot the embedding without any clustering information.

    Args:
        fig: Existing figure to plot on.
        ax: Existing axes to plot on.
        figsize: Size of the figure. Defaults to [0.94, 2.38].
        plot_mode: Mode for plotting ('paper', 'small', or 'large').
        fontsize: Font size for annotations.
        colors: Colors for data points.
        **kwargs: Additional arguments passed to plot_embedding.

    Returns:
        A tuple containing the figure and axes objects.

    Raises:
        AssertionError: If the embedding is not 2-dimensional.
    """
    points = self.embedding
    if points.shape[1] != 2:
        raise AssertionError("Embedding must be 2-dimensional for plotting")

    size = [0.94, 2.38] if figsize is None else figsize

    # Labels, task errors and Gaussian fits are switched off here.
    return plot_embedding(
        points,
        colors=colors,
        task_error=None,
        labels=None,
        mask=self.mask,
        fit_gaussians=False,
        annotate=False,
        title="",
        fig=fig,
        ax=ax,
        mode=plot_mode,
        figsize=size,
        fontsize=fontsize,
        **kwargs,
    )

flyvision.analysis.clustering.Clustering dataclass

Clustering of the embedding.

Attributes:

Name Type Description
embedding Embedding

The embedding to be clustered.

Source code in flyvision/analysis/clustering.py
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
@dataclass
class Clustering:
    """Entry point to the clustering methods available for an embedding.

    Attributes:
        embedding (Embedding): The embedding to be clustered.
    """

    embedding: Embedding = None

    @property
    def gaussian_mixture(self) -> GaussianMixtureClustering:
        """Create a GaussianMixtureClustering object for the embedding.

        A new object is created on every access.

        Returns:
            GaussianMixtureClustering: A clustering object for Gaussian mixture models.
        """
        target = self.embedding
        return GaussianMixtureClustering(target)

gaussian_mixture property

gaussian_mixture

Create a GaussianMixtureClustering object for the embedding.

Returns:

Name Type Description
GaussianMixtureClustering GaussianMixtureClustering

A clustering object for Gaussian mixture models.