
Task

flyvision.task.tasks

Task

Defines a task for a multi-task dataset from configurations.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| dataset | Namespace | Configuration for the dataset. | required |
| decoder | Namespace | Configuration for the decoder. | required |
| loss | Namespace | Configuration for the loss functions, mapping each task name to the name of a function in `flyvision.task.objectives`. | required |
| task_weights | Dict[str, float] | Weight of each task in the combined loss. If None, every task in the dataset gets weight 1. | None |
| batch_size | int | Size of each batch. | 4 |
| n_iters | int | Number of iterations. | 250000 |
| n_folds | int | Number of folds for cross-validation. | 4 |
| fold | int | Current fold number. | 1 |
| seed | int | Random seed for reproducibility. | 0 |
| original_split | bool | Whether to use the dataset's original train/validation split instead of a random split. | False |
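The `fold`, `n_folds`, and `seed` parameters are passed to `MultiTaskDataset.get_random_data_split`, whose implementation is not shown on this page. As a rough mental model only, the sketch below implements a generic seeded k-fold split over sequence indices. The function name `kfold_split`, the assumption that `fold` is 1-based, and the use of `random.Random(seed)` are all assumptions, not flyvision's actual code.

```python
import random
from typing import List, Tuple


def kfold_split(
    n_samples: int, fold: int, n_folds: int, seed: int
) -> Tuple[List[int], List[int]]:
    """Hypothetical seeded k-fold split; `fold` is assumed to be 1-based."""
    if not 1 <= fold <= n_folds:
        raise ValueError(f"fold must be in [1, {n_folds}], got {fold}")
    indices = list(range(n_samples))
    # Shuffle deterministically so the same seed always gives the same folds.
    random.Random(seed).shuffle(indices)
    # Partition the shuffled indices into n_folds contiguous chunks.
    fold_size, remainder = divmod(n_samples, n_folds)
    folds, start = [], 0
    for k in range(n_folds):
        size = fold_size + (1 if k < remainder else 0)
        folds.append(indices[start:start + size])
        start += size
    # The selected fold is held out for validation; the others are for training.
    val_index = sorted(folds[fold - 1])
    train_index = sorted(i for k, f in enumerate(folds) if k != fold - 1 for i in f)
    return train_index, val_index


train_idx, val_idx = kfold_split(n_samples=10, fold=1, n_folds=4, seed=0)
print(len(train_idx), len(val_idx))
```

With a fixed `seed`, iterating `fold` from 1 to `n_folds` in this sketch gives disjoint validation sets that together cover every index exactly once.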

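Attributes:

| Name | Type | Description |
|---|---|---|
| batch_size | | Size of each batch. |
| n_iters | | Number of iterations. |
| n_folds | | Number of folds for cross-validation. |
| fold | | Current fold number. |
| seed | | Random seed for reproducibility. |
| decoder | | Configuration for the decoder. |
| dataset | MultiTaskDataset | The initialized multi-task dataset. |
| task_weights | Dict[str, float] | Weight of each task. |
| task_weights_sum | float | Sum of all task weights. |
| losses | Namespace | Loss functions for each task. |
| train_seq_index | List[int] | Indices of training sequences. |
| val_seq_index | List[int] | Indices of validation sequences. |
| train_data | DataLoader | DataLoader for training data (shuffled training sequences; an incomplete final batch is dropped). |
| train_batch | DataLoader | DataLoader for a single, fixed training batch (the first `batch_size` training sequences). |
| val_data | DataLoader | DataLoader for validation data (one validation sequence per batch). |
| val_batch | DataLoader | DataLoader for a single, fixed validation batch (the first `batch_size` validation sequences). |
| overfit_data | DataLoader | DataLoader for overfitting on a single sample (sequence 0 of the dataset). |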
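To make the differences between the loaders concrete, the sketch below mimics in plain Python the index batches each loader yields, following the constructor code: `train_data` shuffles the training indices (`SubsetRandomSampler`) and drops an incomplete last batch, `train_batch` and `val_batch` always yield the first `batch_size` indices in order, `val_data` yields one validation sequence per batch, and `overfit_data` yields only sequence 0. It assumes `IndexSampler` yields its indices in the given order; the helper `batches` and the index lists are hypothetical, and no torch objects are built.

```python
import random
from typing import Iterable, List


def batches(indices: Iterable[int], batch_size: int, drop_last: bool) -> List[List[int]]:
    """Group a stream of sampled indices into batches, as a DataLoader would."""
    out: List[List[int]] = []
    current: List[int] = []
    for i in indices:
        current.append(i)
        if len(current) == batch_size:
            out.append(current)
            current = []
    # A trailing partial batch is kept unless drop_last is set.
    if current and not drop_last:
        out.append(current)
    return out


batch_size = 4
train_seq_index = [0, 2, 3, 5, 6, 7, 8, 9, 11]  # hypothetical split
val_seq_index = [1, 4, 10]

# train_data: random permutation of the training subset, partial batch dropped.
shuffled = random.Random(0).sample(train_seq_index, k=len(train_seq_index))
train_data = batches(shuffled, batch_size, drop_last=True)
# train_batch / val_batch: the first batch_size indices, in fixed order.
train_batch = batches(train_seq_index[:batch_size], batch_size, drop_last=False)
val_batch = batches(val_seq_index[:batch_size], batch_size, drop_last=False)
# val_data: every validation sequence, one per batch (batch_size=1).
val_data = batches(val_seq_index, 1, drop_last=False)
# overfit_data: only sequence 0, with DataLoader's default batch_size of 1.
overfit_data = batches([0], 1, drop_last=False)

print(train_batch, val_batch, val_data, overfit_data)
```

Because `train_batch` and `val_batch` never change between epochs, they are convenient for monitoring the loss on identical inputs, while `train_data` covers the training split in a new random order each epoch.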

Source code in flyvision/task/tasks.py
class Task:
    """Defines a task for a multi-task dataset from configurations.

    Args:
        dataset: Configuration for the dataset.
        decoder: Configuration for the decoder.
        loss: Configuration for the loss functions, mapping each task name to
            the name of a function in `flyvision.task.objectives`.
        task_weights: Weight of each task in the combined loss. If None, every
            task in the dataset gets weight 1. Defaults to None.
        batch_size: Size of each batch. Defaults to 4.
        n_iters: Number of iterations. Defaults to 250,000.
        n_folds: Number of folds for cross-validation. Defaults to 4.
        fold: Current fold number. Defaults to 1.
        seed: Random seed for reproducibility. Defaults to 0.
        original_split: Whether to use the dataset's original train/validation
            split instead of a random split. Defaults to False.

    Attributes:
        batch_size: Size of each batch.
        n_iters: Number of iterations.
        n_folds: Number of folds for cross-validation.
        fold: Current fold number.
        seed: Random seed for reproducibility.
        decoder: Configuration for the decoder.
        dataset (MultiTaskDataset): The initialized multi-task dataset.
        task_weights (Dict[str, float]): Weight of each task.
        task_weights_sum (float): Sum of all task weights.
        losses (Namespace): Loss functions for each task.
        train_seq_index (List[int]): Indices of training sequences.
        val_seq_index (List[int]): Indices of validation sequences.
        train_data (DataLoader): DataLoader for training data.
        train_batch (DataLoader): DataLoader for a single, fixed training batch.
        val_data (DataLoader): DataLoader for validation data.
        val_batch (DataLoader): DataLoader for a single, fixed validation batch.
        overfit_data (DataLoader): DataLoader for overfitting on a single sample.
    """

    def __init__(
        self,
        dataset: Namespace,
        decoder: Namespace,
        loss: Namespace,
        task_weights: Optional[Dict[str, float]] = None,
        batch_size: int = 4,
        n_iters: int = 250_000,
        n_folds: int = 4,
        fold: int = 1,
        seed: int = 0,
        original_split: bool = False,
    ):
        self.batch_size = batch_size
        self.n_iters = n_iters
        self.n_folds = n_folds
        self.fold = fold
        self.seed = seed
        self.decoder = decoder

        # Initialize dataset.
        self.dataset = forward_subclass(MultiTaskDataset, dataset)
        self.task_weights, self.task_weights_sum = self.init_task_weights(task_weights)

        self.losses = Namespace({
            task: getattr(objectives, config) for task, config in loss.items()
        })

        if original_split:
            self.train_seq_index, self.val_seq_index = (
                self.dataset.original_train_and_validation_indices()
            )
        else:
            self.train_seq_index, self.val_seq_index = self.dataset.get_random_data_split(
                fold, n_folds, seed
            )

        # Initialize dataloaders.
        self.train_data = DataLoader(
            self.dataset,
            batch_size=batch_size,
            sampler=sampler.SubsetRandomSampler(self.train_seq_index),
            drop_last=True,
        )
        self.train_batch = DataLoader(
            self.dataset,
            batch_size=batch_size,
            sampler=IndexSampler(self.train_seq_index[:batch_size]),
            drop_last=False,
        )
        logging.info(
            "Initialized dataloader with training sequence indices \n%s",
            self.train_seq_index,
        )

        self.val_data = DataLoader(
            self.dataset,
            batch_size=1,
            sampler=IndexSampler(self.val_seq_index),
        )
        self.val_batch = DataLoader(
            self.dataset,
            batch_size=batch_size,
            sampler=IndexSampler(self.val_seq_index[:batch_size]),
        )
        logging.info(
            "Initialized dataloader with validation sequence indices \n%s",
            self.val_seq_index,
        )

        # Initialize overfitting loader.
        self.overfit_data = DataLoader(self.dataset, sampler=IndexSampler([0]))

    def init_decoder(
        self, connectome: ConnectomeFromAvgFilters
    ) -> Dict[str, ActivityDecoder]:
        """Initialize the decoder.

        Args:
            connectome: The connectome directory.

        Returns:
            A dictionary of initialized decoders.
        """
        return init_decoder(self.decoder, connectome)

    def loss(
        self, input: torch.Tensor, target: torch.Tensor, task: str, **kwargs
    ) -> torch.Tensor:
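        """Returns the task loss, weighted and normalized by the task weights.

        The loss is multiplied by the weight of `task` and divided by the sum
        of all task weights.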

        Args:
            input: Input tensor.
            target: Target tensor.
            task: Task name.
            **kwargs: Additional keyword arguments for the loss function.

        Returns:
            Weighted task loss.
        """
        return (
            self.task_weights[task]
            * self.losses[task](input, target, **kwargs)
            / self.task_weights_sum
        )

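    def init_task_weights(
        self, task_weights: Optional[Dict[str, float]]
    ) -> Tuple[Dict[str, float], float]:
        """Returns the task weights and their sum.

        Args:
            task_weights: Weight of each task. If None, every task in the
                dataset gets weight 1.

        Returns:
            A tuple of the dictionary of task weights and the sum of the weights.
        """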
        task_weights = (
            task_weights
            if task_weights is not None
            else {task: 1 for task in self.dataset.tasks}
        )

        return task_weights, sum(task_weights.values())
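The `loss` configuration is resolved by name: `Task.__init__` looks up each value with `getattr` on the `flyvision.task.objectives` module, so each value must be the name of a function defined there, such as `l2norm` or `epe`. The sketch below reproduces that lookup with a `SimpleNamespace` standing in for the module and placeholder loss functions; the task names and the dict-based config are illustrative assumptions.

```python
from types import SimpleNamespace


def l2norm(y_est, y_gt, **kwargs):
    # Placeholder standing in for flyvision.task.objectives.l2norm.
    return "l2norm"


def epe(y_est, y_gt, **kwargs):
    # Placeholder standing in for flyvision.task.objectives.epe.
    return "epe"


# Stand-in for the objectives module: attribute access by function name.
objectives = SimpleNamespace(l2norm=l2norm, epe=epe)

# Hypothetical loss configuration mapping task names to function names.
loss_config = {"flow": "epe", "depth": "l2norm"}

# Same pattern as Task.__init__: resolve each name to the function object.
losses = {task: getattr(objectives, name) for task, name in loss_config.items()}

print(losses["flow"](None, None))
# A misspelled name fails early with AttributeError when the lookup runs.
```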
init_decoder
init_decoder(connectome)

Initialize the decoder.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| connectome | ConnectomeFromAvgFilters | The connectome directory. | required |

Returns:

| Type | Description |
|---|---|
| Dict[str, ActivityDecoder] | A dictionary of initialized decoders. |

Source code in flyvision/task/tasks.py
def init_decoder(
    self, connectome: ConnectomeFromAvgFilters
) -> Dict[str, ActivityDecoder]:
    """Initialize the decoder.

    Args:
        connectome: The connectome directory.

    Returns:
        A dictionary of initialized decoders.
    """
    return init_decoder(self.decoder, connectome)
loss
loss(input, target, task, **kwargs)

Returns the task loss multiplied by the task weight and divided by the sum of all task weights.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| input | Tensor | Input tensor. | required |
| target | Tensor | Target tensor. | required |
| task | str | Task name. | required |
| `**kwargs` | | Additional keyword arguments for the loss function. | {} |

Returns:

| Type | Description |
|---|---|
| Tensor | Weighted task loss. |

Source code in flyvision/task/tasks.py
def loss(
    self, input: torch.Tensor, target: torch.Tensor, task: str, **kwargs
) -> torch.Tensor:
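    """Returns the task loss, weighted and normalized by the task weights.

    The loss is multiplied by the weight of `task` and divided by the sum
    of all task weights.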

    Args:
        input: Input tensor.
        target: Target tensor.
        task: Task name.
        **kwargs: Additional keyword arguments for the loss function.

    Returns:
        Weighted task loss.
    """
    return (
        self.task_weights[task]
        * self.losses[task](input, target, **kwargs)
        / self.task_weights_sum
    )
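Because each task's loss is scaled by its weight and divided by the sum of all weights, summing `Task.loss` over all tasks gives a weighted average of the per-task losses. A short worked example with plain floats; the task names, weights, and loss values are made up for illustration.

```python
task_weights = {"flow": 2.0, "depth": 1.0, "lum": 1.0}
task_weights_sum = sum(task_weights.values())  # 4.0

# Hypothetical unweighted per-task losses, as returned by self.losses[task].
raw_losses = {"flow": 0.8, "depth": 0.4, "lum": 0.2}

# Same arithmetic as Task.loss: weight * loss / sum of weights.
weighted = {
    task: task_weights[task] * raw_losses[task] / task_weights_sum
    for task in raw_losses
}
total = sum(weighted.values())
print(weighted["flow"], total)
```

Here `flow` contributes 2 × 0.8 / 4 = 0.4 and the total is 0.55. With every weight equal to 1 (the default), the summed loss reduces to the plain mean of the task losses.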
init_task_weights
init_task_weights(task_weights)

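Returns the task weights and their sum.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| task_weights | Optional[Dict[str, float]] | Weight of each task. If None, every task in the dataset gets weight 1. | required |

Returns:

| Type | Description |
|---|---|
| Tuple[Dict[str, float], float] | The dictionary of task weights and the sum of the weights. |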

Source code in flyvision/task/tasks.py
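def init_task_weights(
    self, task_weights: Optional[Dict[str, float]]
) -> Tuple[Dict[str, float], float]:
    """Returns the task weights and their sum.

    Args:
        task_weights: Weight of each task. If None, every task in the
            dataset gets weight 1.

    Returns:
        A tuple of the dictionary of task weights and the sum of the weights.
    """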
    task_weights = (
        task_weights
        if task_weights is not None
        else {task: 1 for task in self.dataset.tasks}
    )

    return task_weights, sum(task_weights.values())
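The fallback logic can be exercised outside the class. The sketch below is a free-function copy of `init_task_weights` in which a hypothetical `tasks` argument replaces `self.dataset.tasks`; it shows both the default weights and the case of explicitly passed weights.

```python
from typing import Dict, Iterable, Optional, Tuple


def init_task_weights(
    task_weights: Optional[Dict[str, float]], tasks: Iterable[str]
) -> Tuple[Dict[str, float], float]:
    """Free-function copy of Task.init_task_weights; `tasks` replaces self.dataset.tasks."""
    # Fall back to a weight of 1 for every task when no weights are given.
    if task_weights is None:
        task_weights = {task: 1 for task in tasks}
    return task_weights, sum(task_weights.values())


print(init_task_weights(None, ["flow", "depth"]))
print(init_task_weights({"flow": 3.0}, ["flow", "depth"]))
```

Explicit weights are used as given and are not completed with defaults, so a task missing from `task_weights` (here `depth`) has no weight, and `Task.loss` would raise a `KeyError` for it.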

init_decoder

init_decoder(config, connectome)

Initialize decoders.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | Dict | Configuration for the decoders. | required |
| connectome | ConnectomeFromAvgFilters | The connectome directory. | required |

Returns:

| Type | Description |
|---|---|
| Dict[str, ActivityDecoder] | A dictionary of decoders. |

Example
decoder = Namespace(
    flow=Namespace(
        type="DecoderGAVP",
        shape=[8, 2],
        kernel_size=5,
        const_weight=0.001,
        p_dropout=0.5,
    ),
    depth=Namespace(
        type="DecoderGAVP",
        shape=[8, 1],
        kernel_size=5,
        const_weight=0.001,
        p_dropout=0.5,
    ),
    lum=Namespace(
        type="DecoderGAVP",
        shape=[8, 3],
        n_out_features=2,
        kernel_size=5,
        const_weight=0.001,
        p_dropout=0.5,
    ),
)
Source code in flyvision/task/tasks.py
def init_decoder(
    config: Dict, connectome: ConnectomeFromAvgFilters
) -> Dict[str, ActivityDecoder]:
    """Initialize decoders.

    Args:
        config: Configuration for the decoders.
        connectome: The connectome directory.

    Returns:
        A dictionary of decoders.

    Example:
        ```python
        decoder = Namespace(
            flow=Namespace(
                type="DecoderGAVP",
                shape=[8, 2],
                kernel_size=5,
                const_weight=0.001,
                p_dropout=0.5,
            ),
            depth=Namespace(
                type="DecoderGAVP",
                shape=[8, 1],
                kernel_size=5,
                const_weight=0.001,
                p_dropout=0.5,
            ),
            lum=Namespace(
                type="DecoderGAVP",
                shape=[8, 3],
                n_out_features=2,
                kernel_size=5,
                const_weight=0.001,
                p_dropout=0.5,
            ),
        )
        ```
    """
    config = config.deepcopy()

    def init(conf):
        return forward_subclass(ActivityDecoder, {**conf, "connectome": connectome})

    decoder = valmap(init, config)

    return decoder
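`init_decoder` applies the same construction step to every entry of the config: it copies the entry's settings, injects the shared `connectome`, and lets `forward_subclass` instantiate the `ActivityDecoder` subclass named by `type`, using `toolz.valmap` to map that step over the dictionary values. The sketch below imitates this flow without flyvision. `forward_subclass`, the `ActivityDecoder` and `DecoderGAVP` stand-in classes, and the subclass lookup are simplified hypothetical versions, and a dict comprehension replaces `valmap`.

```python
import copy
from typing import Any, Dict


class ActivityDecoder:
    """Hypothetical base class; subclasses are looked up by name."""

    def __init__(self, connectome: Any, **kwargs: Any) -> None:
        self.connectome = connectome
        self.kwargs = kwargs


class DecoderGAVP(ActivityDecoder):
    pass


def forward_subclass(base: type, config: Dict[str, Any]) -> Any:
    """Simplified stand-in: instantiate the direct subclass named by config['type']."""
    config = dict(config)
    name = config.pop("type")
    subclasses = {cls.__name__: cls for cls in base.__subclasses__()}
    return subclasses[name](**config)


def init_decoder(
    config: Dict[str, Dict[str, Any]], connectome: Any
) -> Dict[str, ActivityDecoder]:
    config = copy.deepcopy(config)  # leave the caller's config untouched
    # Equivalent to valmap(init, config): build one decoder per entry.
    return {
        task: forward_subclass(ActivityDecoder, {**conf, "connectome": connectome})
        for task, conf in config.items()
    }


decoders = init_decoder(
    {
        "flow": {"type": "DecoderGAVP", "shape": [8, 2]},
        "depth": {"type": "DecoderGAVP", "shape": [8, 1]},
    },
    connectome="connectome-placeholder",
)
print(type(decoders["flow"]).__name__, decoders["depth"].kwargs)
```

Because every top-level entry is unpacked with `**conf`, every value in the config must itself be a mapping of decoder settings; a scalar entry raises a `TypeError`.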

flyvision.task.objectives

Loss functions compatible with torch loss function API.

l2norm

l2norm(y_est, y_gt, **kwargs)

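Calculate the mean root cumulative squared error: the square root of the squared error summed over dimensions 1, 2, and 3 (the last three dimensions of a 4D tensor), averaged over dimension 0 (samples).

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| y_est | Tensor | The estimated tensor. | required |
| y_gt | Tensor | The ground truth tensor. | required |
| `**kwargs` | Any | Additional keyword arguments (ignored). | {} |

Returns:

| Type | Description |
|---|---|
| Tensor | The mean root cumulative squared error. |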

Source code in flyvision/task/objectives.py
def l2norm(y_est: torch.Tensor, y_gt: torch.Tensor, **kwargs: Any) -> torch.Tensor:
    """
    Calculate the mean root cumulative squared error: the square root of the
    squared error summed over dimensions 1, 2, and 3, averaged over dimension 0.

    Args:
        y_est: The estimated tensor.
        y_gt: The ground truth tensor.
        **kwargs: Additional keyword arguments.

    Returns:
        The mean root cumulative squared error.
    """
    return (((y_est - y_gt) ** 2).sum(dim=(1, 2, 3))).sqrt().mean()
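For a 4D input, `l2norm` is the Euclidean norm of each sample's error, taken jointly over dimensions 1, 2, and 3, averaged over samples. The sketch below re-expresses the torch code with NumPy to check that reading against `np.linalg.norm`; `l2norm_np` is a hypothetical name and the array shapes are arbitrary.

```python
import numpy as np


def l2norm_np(y_est: np.ndarray, y_gt: np.ndarray) -> float:
    # Mirror of the torch code: sum squared error over dims 1-3, sqrt, mean over dim 0.
    return float(np.sqrt(((y_est - y_gt) ** 2).sum(axis=(1, 2, 3))).mean())


rng = np.random.default_rng(0)
y_est = rng.normal(size=(2, 3, 2, 5))  # (samples, frames, ndim, hexals), assumed layout
y_gt = np.zeros_like(y_est)

# Equivalent formulation: per-sample norm of the flattened error, then the mean.
per_sample = np.linalg.norm((y_est - y_gt).reshape(len(y_est), -1), axis=1)
print(np.isclose(l2norm_np(y_est, y_gt), per_sample.mean()))
```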

epe

epe(y_est, y_gt, **kwargs)

Calculate the average endpoint error, conventionally reported in optic flow tasks: the Euclidean norm of the error over dimension 2 (ndim), averaged over all remaining dimensions.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| y_est | Tensor | The estimated tensor with shape (samples, frames, ndim, hexals_or_features). | required |
| y_gt | Tensor | The ground truth tensor with the same shape as y_est. | required |
| `**kwargs` | Any | Additional keyword arguments (ignored). | {} |

Returns:

| Type | Description |
|---|---|
| Tensor | The average endpoint error. |

Source code in flyvision/task/objectives.py
def epe(y_est: torch.Tensor, y_gt: torch.Tensor, **kwargs: Any) -> torch.Tensor:
    """
    Calculate the average endpoint error, conventionally reported in optic flow tasks.

    Args:
        y_est: The estimated tensor with shape
            (samples, frames, ndim, hexals_or_features).
        y_gt: The ground truth tensor with the same shape as y_est.
        **kwargs: Additional keyword arguments.

    Returns:
        The average endpoint error.
    """
    return torch.sqrt(((y_est - y_gt) ** 2).sum(dim=2)).mean()
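`epe` takes the Euclidean norm over dimension 2 (the `ndim` vector components, e.g. the two components of a flow vector) at every sample, frame, and hexal, then averages those endpoint errors. The NumPy re-expression below uses a hand-checkable 3-4-5 case; `epe_np` and the shapes are illustrative assumptions.

```python
import numpy as np


def epe_np(y_est: np.ndarray, y_gt: np.ndarray) -> float:
    # Norm over the vector-component axis (dim 2), then mean over everything else.
    return float(np.sqrt(((y_est - y_gt) ** 2).sum(axis=2)).mean())


# Shape (samples=1, frames=1, ndim=2, hexals=2).
y_gt = np.zeros((1, 1, 2, 2))
y_est = np.array([[[[3.0, 0.0],
                    [4.0, 0.0]]]])
# Hexal 0 has error vector (3, 4), so endpoint error 5; hexal 1 has error 0.
print(epe_np(y_est, y_gt))  # 2.5
```

Unlike `l2norm`, which pools all frames and hexals of a sample into one norm, `epe` computes one norm per vector and averages those, so its value does not grow with the number of frames or hexals.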