Skip to content

Optimal Stimuli

Naturalistic

flyvis.analysis.optimal_stimuli.FindOptimalStimuli

Methods to derive optimal stimuli for cells from stimuli dataset.

Parameters:

Name Type Description Default
network_view NetworkView

Network view.

required
stimuli StimulusDataset | str

Stimuli dataset. “default” uses AugmentedSintel restricted to the luminance task (tasks=["lum"], dt=1/100, temporal_split=True).

'default'

Attributes:

Name Type Description
nv NetworkView

Network view.

network Network

Initialized network.

central_cells_index list

Central cells index.

stimuli StimulusDataset

Stimulus dataset.

Source code in flyvis/analysis/optimal_stimuli.py
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
class FindOptimalStimuli:
    """Methods to derive optimal stimuli for cells from stimuli dataset.

    Args:
        network_view: Network view.
        stimuli: Stimuli dataset. "default" uses AugmentedSintel restricted to
            the "lum" task with dt=1/100 and temporal_split=True.

    Attributes:
        nv (flyvis.NetworkView): Network view.
        network (flyvis.Network): Initialized network. Gradients are disabled
            for all of its parameters so that only stimuli can be optimized.
        central_cells_index (list): Central cells index.
        stimuli (StimulusDataset): Stimulus dataset.
    """

    def __init__(
        self,
        network_view: flyvis.NetworkView,
        stimuli: StimulusDataset | str = "default",
    ):
        self.nv = network_view
        self.network = network_view.init_network()  # type: flyvis.Network
        # The network weights stay fixed; only input stimuli are ever optimized.
        for param in self.network.parameters():
            param.requires_grad = False
        self.central_cells_index = self.network.connectome.central_cells_index[:]
        # Only compare against the sentinel when a string was actually passed,
        # so dataset objects with a custom ``__eq__`` are never compared to it.
        use_default = isinstance(stimuli, str) and stimuli == "default"
        self.stimuli = (
            AugmentedSintel(tasks=["lum"], dt=1 / 100, temporal_split=True)
            if use_default
            else stimuli
        )

    def optimal_stimuli(
        self,
        cell_type: str,
        dt: float = 1 / 100,
        indices: list[int] | None = None,
    ) -> OptimalStimulus:
        """Finds optimal stimuli for a given cell type in stimuli dataset.

        The sample with the maximal response of ``cell_type`` is looked up in
        the responses returned by ``self.nv.naturalistic_stimuli_responses()``;
        the network is then simulated again on that sample.

        Args:
            cell_type: Node type.
            dt: Time step used for the steady state and the simulation.
            indices: Optional mapping from the sample index found in the
                stored responses to the index into ``self.stimuli``.

        Returns:
            OptimalStimulus object containing the stimulus (with a leading
            singleton dimension) and the response of ``cell_type`` to it.
        """
        responses = self.nv.naturalistic_stimuli_responses()
        cell_responses = responses['responses'].custom.where(cell_type=cell_type)

        # Index of the sample that contains the overall maximum response.
        argmax = cell_responses.argmax(dim=("sample", "frame"))['sample'].item()
        if indices is not None:
            argmax = indices[argmax]
        nat_opt_stim = self.stimuli[argmax]["lum"]

        n_frames = nat_opt_stim.shape[0]
        initial_state = self.network.steady_state(1.0, dt, 1)
        stimulus = Stimulus(self.network.connectome, 1, n_frames)
        stimulus.zero()
        stimulus.add_input(nat_opt_stim[None])
        response = self.network(stimulus(), dt, state=initial_state).detach().cpu()
        response = LayerActivity(response, self.network.connectome, keepref=True)[
            cell_type
        ]

        return OptimalStimulus(nat_opt_stim[None, :], response[:, :, None])

    def regularized_optimal_stimuli(
        self,
        cell_type: str,
        l2_act: float = 1,
        lr: float = 1e-2,
        l2_stim: float = 1,
        n_iters: int = 100,
        dt: float = 1 / 100,
        indices: list[int] | None = None,
    ) -> RegularizedOptimalStimulus:
        """Regularizes the optimal stimulus for a given cell type.

        Maintains central node activity while minimizing mean square of input pixels.

        Args:
            cell_type: Node type.
            l2_act: L2 regularization strength for the activity.
            lr: Learning rate.
            l2_stim: L2 regularization strength for the stimulus.
            n_iters: Number of iterations. With 0 iterations the (NaN-frame
                filtered) optimal stimulus is returned unchanged.
            dt: Time step.
            indices: Indices of stimuli.

        Returns:
            RegularizedOptimalStimulus object holding the original optimal
            stimulus, the regularized stimulus, the response of ``cell_type``
            to the regularized stimulus, the central predicted and target
            responses, and the list of per-iteration losses.
        """

        optim_stimuli = self.optimal_stimuli(
            cell_type=cell_type,
            dt=dt,
            indices=indices,
        )
        # Frames (axis 1) whose entry at the middle index of the last axis is
        # not NaN; all other frames are dropped before the optimization.
        non_nan = ~torch.isnan(
            optim_stimuli.stimulus[0, :, 0, optim_stimuli.stimulus.shape[-1] // 2]
        )
        reg_opt_stim = optim_stimuli.stimulus.clone()
        reg_opt_stim = reg_opt_stim[:, non_nan]
        reg_opt_stim.requires_grad = True

        # Target: response at the middle index of the last axis for the kept frames.
        central_target_response = (
            optim_stimuli.response.to(non_nan.device)[
                :, non_nan, :, optim_stimuli.response.shape[-1] // 2
            ]
            .clone()
            .detach()
            .squeeze()
        )

        optim = torch.optim.Adam([reg_opt_stim], lr=lr)

        n_frames = reg_opt_stim.shape[1]

        stim = Stimulus(self.network.connectome, 1, n_frames)

        layer_activity = LayerActivity(None, self.network.connectome, keepref=True)

        initial_state = self.network.steady_state(1.0, dt, 1)

        losses = []
        for _ in range(n_iters):
            optim.zero_grad()
            stim.zero()
            stim.add_input(reg_opt_stim)
            activities = self.network(stim(), dt, state=initial_state)
            layer_activity.update(activities)
            central_predicted_response = layer_activity.central[cell_type].squeeze()

            # Keep the central response close to the target ...
            act_loss = (
                l2_act
                * ((central_predicted_response - central_target_response) ** 2).sum()
            )
            # ... while pulling the stimulus values towards 0.5.
            stim_loss = l2_stim * ((reg_opt_stim - 0.5) ** 2).mean(dim=0).sum()
            loss = act_loss + stim_loss
            # NOTE(review): retain_graph=True is kept as in the original;
            # presumably required because the stimulus buffer is reused across
            # iterations - confirm before removing.
            loss.backward(retain_graph=True)
            optim.step()
            losses.append(loss.detach().cpu().numpy().item())

        # Final forward pass with the optimized stimulus.
        stim.zero()
        reg_opt_stim.requires_grad = False
        stim.add_input(reg_opt_stim)
        activities = self.network(stim(), dt, state=initial_state)
        layer_activity.update(activities)
        # Read the central prediction from this final pass so that it matches
        # the returned stimulus (the value from the loop predates the last
        # optimizer step) and is defined even when n_iters == 0.
        central_predicted_response = layer_activity.central[cell_type].squeeze()

        reg_opt_stim = reg_opt_stim.detach().cpu()
        rnmei_response = layer_activity[cell_type].detach().cpu()
        central_predicted_response = central_predicted_response.detach().cpu()
        central_target_response = central_target_response.detach().cpu()
        return RegularizedOptimalStimulus(
            optim_stimuli,
            reg_opt_stim,
            rnmei_response,
            central_predicted_response,
            central_target_response,
            losses,
        )

optimal_stimuli

optimal_stimuli(cell_type, dt=1 / 100, indices=None)

Finds optimal stimuli for a given cell type in stimuli dataset.

Parameters:

Name Type Description Default
cell_type str

Node type.

required
dt float

Time step.

1 / 100
indices list[int] | None

Indices of stimuli.

None

Returns:

Type Description
OptimalStimulus

OptimalStimulus object containing the stimulus and response.

Source code in flyvis/analysis/optimal_stimuli.py
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
def optimal_stimuli(
    self,
    cell_type: str,
    dt: float = 1 / 100,
    indices: list[int] | None = None,
) -> OptimalStimulus:
    """Look up the most activating dataset stimulus for a cell type.

    The sample holding the maximal stored response of ``cell_type`` is
    selected and the network is simulated on it once more.

    Args:
        cell_type: Node type.
        dt: Time step used for the steady state and the simulation.
        indices: Optional mapping from the sample index found in the stored
            responses to the index into ``self.stimuli``.

    Returns:
        OptimalStimulus object containing the stimulus and response.
    """
    stored = self.nv.naturalistic_stimuli_responses()["responses"]
    per_cell = stored.custom.where(cell_type=cell_type)

    # Sample index of the overall maximum across samples and frames.
    peak = per_cell.argmax(dim=("sample", "frame"))
    best_sample = peak["sample"].item()
    if indices is not None:
        best_sample = indices[best_sample]
    nat_opt_stim = self.stimuli[best_sample]["lum"]

    steady = self.network.steady_state(1.0, dt, 1)
    stim_buffer = Stimulus(self.network.connectome, 1, nat_opt_stim.shape[0])
    stim_buffer.zero()
    stim_buffer.add_input(nat_opt_stim[None])

    raw_activity = self.network(stim_buffer(), dt, state=steady).detach().cpu()
    by_layer = LayerActivity(raw_activity, self.network.connectome, keepref=True)
    response = by_layer[cell_type]

    return OptimalStimulus(nat_opt_stim[None, :], response[:, :, None])

regularized_optimal_stimuli

regularized_optimal_stimuli(cell_type, l2_act=1, lr=0.01, l2_stim=1, n_iters=100, dt=1 / 100, indices=None)

Regularizes the optimal stimulus for a given cell type.

Maintains central node activity while minimizing mean square of input pixels.

Parameters:

Name Type Description Default
cell_type str

Node type.

required
l2_act float

L2 regularization strength for the activity.

1
lr float

Learning rate.

0.01
l2_stim float

L2 regularization strength for the stimulus.

1
n_iters int

Number of iterations.

100
dt float

Time step.

1 / 100
indices list[int] | None

Indices of stimuli.

None

Returns:

Type Description
RegularizedOptimalStimulus

RegularizedOptimalStimulus object.

Source code in flyvis/analysis/optimal_stimuli.py
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
def regularized_optimal_stimuli(
    self,
    cell_type: str,
    l2_act: float = 1,
    lr: float = 1e-2,
    l2_stim: float = 1,
    n_iters: int = 100,
    dt: float = 1 / 100,
    indices: list[int] | None = None,
) -> RegularizedOptimalStimulus:
    """Regularizes the optimal stimulus for a given cell type.

    Maintains central node activity while minimizing mean square of input pixels.

    Args:
        cell_type: Node type.
        l2_act: L2 regularization strength for the activity.
        lr: Learning rate.
        l2_stim: L2 regularization strength for the stimulus.
        n_iters: Number of iterations. With 0 iterations the (NaN-frame
            filtered) optimal stimulus is returned unchanged.
        dt: Time step.
        indices: Indices of stimuli.

    Returns:
        RegularizedOptimalStimulus object holding the original optimal
        stimulus, the regularized stimulus, the response of ``cell_type`` to
        the regularized stimulus, the central predicted and target responses,
        and the list of per-iteration losses.
    """

    optim_stimuli = self.optimal_stimuli(
        cell_type=cell_type,
        dt=dt,
        indices=indices,
    )
    # Frames (axis 1) whose entry at the middle index of the last axis is not
    # NaN; all other frames are dropped before the optimization.
    non_nan = ~torch.isnan(
        optim_stimuli.stimulus[0, :, 0, optim_stimuli.stimulus.shape[-1] // 2]
    )
    reg_opt_stim = optim_stimuli.stimulus.clone()
    reg_opt_stim = reg_opt_stim[:, non_nan]
    reg_opt_stim.requires_grad = True

    # Target: response at the middle index of the last axis for the kept frames.
    central_target_response = (
        optim_stimuli.response.to(non_nan.device)[
            :, non_nan, :, optim_stimuli.response.shape[-1] // 2
        ]
        .clone()
        .detach()
        .squeeze()
    )

    optim = torch.optim.Adam([reg_opt_stim], lr=lr)

    n_frames = reg_opt_stim.shape[1]

    stim = Stimulus(self.network.connectome, 1, n_frames)

    layer_activity = LayerActivity(None, self.network.connectome, keepref=True)

    initial_state = self.network.steady_state(1.0, dt, 1)

    losses = []
    for _ in range(n_iters):
        optim.zero_grad()
        stim.zero()
        stim.add_input(reg_opt_stim)
        activities = self.network(stim(), dt, state=initial_state)
        layer_activity.update(activities)
        central_predicted_response = layer_activity.central[cell_type].squeeze()

        # Keep the central response close to the target ...
        act_loss = (
            l2_act
            * ((central_predicted_response - central_target_response) ** 2).sum()
        )
        # ... while pulling the stimulus values towards 0.5.
        stim_loss = l2_stim * ((reg_opt_stim - 0.5) ** 2).mean(dim=0).sum()
        loss = act_loss + stim_loss
        # NOTE(review): retain_graph=True is kept as in the original;
        # presumably required because the stimulus buffer is reused across
        # iterations - confirm before removing.
        loss.backward(retain_graph=True)
        optim.step()
        losses.append(loss.detach().cpu().numpy().item())

    # Final forward pass with the optimized stimulus.
    stim.zero()
    reg_opt_stim.requires_grad = False
    stim.add_input(reg_opt_stim)
    activities = self.network(stim(), dt, state=initial_state)
    layer_activity.update(activities)
    # Read the central prediction from this final pass so that it matches the
    # returned stimulus (the value from the loop predates the last optimizer
    # step) and is defined even when n_iters == 0.
    central_predicted_response = layer_activity.central[cell_type].squeeze()

    reg_opt_stim = reg_opt_stim.detach().cpu()
    rnmei_response = layer_activity[cell_type].detach().cpu()
    central_predicted_response = central_predicted_response.detach().cpu()
    central_target_response = central_target_response.detach().cpu()
    return RegularizedOptimalStimulus(
        optim_stimuli,
        reg_opt_stim,
        rnmei_response,
        central_predicted_response,
        central_target_response,
        losses,
    )

flyvis.analysis.optimal_stimuli.OptimalStimulus dataclass

Optimal stimulus and response.

Source code in flyvis/analysis/optimal_stimuli.py
289
290
291
292
293
294
@dataclass
class OptimalStimulus:
    """Container pairing an optimal stimulus with the response it evokes.

    Attributes:
        stimulus: The stimulus.
        response: The response to that stimulus.
    """

    stimulus: np.ndarray
    response: np.ndarray

Artificial

flyvis.analysis.optimal_stimuli.GeneratedOptimalStimulus dataclass

Generated optimal stimulus, response, and optimization losses.

Source code in flyvis/analysis/optimal_stimuli.py
309
310
311
312
313
314
315
@dataclass
class GeneratedOptimalStimulus:
    """Container for a generated optimal stimulus and its optimization trace.

    Attributes:
        stimulus: The generated stimulus.
        response: The response to the generated stimulus.
        losses: The losses recorded during the optimization.
    """

    stimulus: np.ndarray
    response: np.ndarray
    losses: np.ndarray

flyvis.analysis.optimal_stimuli.RegularizedOptimalStimulus dataclass

Regularized optimal stimulus and related data.

Source code in flyvis/analysis/optimal_stimuli.py
297
298
299
300
301
302
303
304
305
306
@dataclass
class RegularizedOptimalStimulus:
    """Result bundle of a stimulus regularization.

    Attributes:
        stimulus: The original optimal stimulus and its response.
        regularized_stimulus: The stimulus after regularization.
        response: The response to the regularized stimulus.
        central_predicted_response: The predicted response of the central cell.
        central_target_response: The central response the optimization
            tried to maintain.
        losses: The losses recorded during the optimization.
    """

    stimulus: OptimalStimulus
    regularized_stimulus: np.ndarray
    response: np.ndarray
    central_predicted_response: np.ndarray
    central_target_response: np.ndarray
    losses: np.ndarray

Visualization

flyvis.analysis.optimal_stimuli.StimResponsePlot dataclass

Stimulus-response plot data and methods.

Source code in flyvis/analysis/optimal_stimuli.py
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
@dataclass
class StimResponsePlot:
    """Stimulus-response plot data and methods."""

    stim: np.ndarray
    response: np.ndarray
    dt: float
    u: np.ndarray
    v: np.ndarray
    time: np.ndarray
    t_step: np.ndarray
    t_steps_stim: np.ndarray
    t_steps_response: np.ndarray
    xmin_lattice: float
    xmax_lattice: float
    ymin_lattice: float
    ymax_lattice: float
    subtraced_baseline: bool
    steps: int
    fig: Any
    axes: Any
    time_axis: Any
    trace_axis: Any
    argmax: int
    t_argmax: float

    def __iter__(self):
        """Yield figure and axes."""
        yield from [self.fig, self.axes]

    def add_to_trace_axis(
        self,
        other: "StimResponsePlot",
        color: str | None = None,
        label: str | None = None,
        linewidth: float | None = None,
    ):
        """Add another StimResponsePlot's trace to this plot's trace axis."""
        xticks = self.trace_axis.get_xticks()
        mask = (other.time >= other.t_step.min()) & (other.time <= other.t_step.max())
        time = np.linspace(xticks.min(), xticks.max(), mask.sum())
        self.trace_axis.plot(
            time,
            other.response[mask, other.response.shape[-1] // 2],
            color=color,
            label=label,
            linewidth=linewidth,
        )

__iter__

__iter__()

Yield figure and axes.

Source code in flyvis/analysis/optimal_stimuli.py
344
345
346
def __iter__(self):
    """Yield the figure first and the axes second."""
    yield self.fig
    yield self.axes

add_to_trace_axis

add_to_trace_axis(other, color=None, label=None, linewidth=None)

Add another StimResponsePlot’s trace to this plot’s trace axis.

Source code in flyvis/analysis/optimal_stimuli.py
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
def add_to_trace_axis(
    self,
    other: "StimResponsePlot",
    color: str | None = None,
    label: str | None = None,
    linewidth: float | None = None,
):
    """Add another StimResponsePlot's trace to this plot's trace axis."""
    xticks = self.trace_axis.get_xticks()
    mask = (other.time >= other.t_step.min()) & (other.time <= other.t_step.max())
    time = np.linspace(xticks.min(), xticks.max(), mask.sum())
    self.trace_axis.plot(
        time,
        other.response[mask, other.response.shape[-1] // 2],
        color=color,
        label=label,
        linewidth=linewidth,
    )

flyvis.analysis.optimal_stimuli.plot_stim_response

plot_stim_response(stim, response, dt, u, v, max_extent=6, subtract_baseline=True, seconds=0.2, steps=10, columns=10, suptitle='', plot_resp=True, hlines=True, vlines=True, time_axis=True, peak_central=False, wspace=-0.2, peak_last=True, fontsize=5, ylabel='', ylabelrotation=90, figsize=[5, 1], label_peak_response=False, fig=None, axes=None, crange=None, trace_axis=False, trace_label=None, trace_axis_offset=0.1, trace_color=None)

Plot spatio-temporal stimulus and response on regular hex lattices.

Parameters:

Name Type Description Default
stim ndarray

Stimulus array.

required
response ndarray

Response array.

required
dt float

Time step.

required
u ndarray

Hexagonal u-coordinates.

required
v ndarray

Hexagonal v-coordinates.

required
max_extent int

Maximum extent of the hexagonal grid.

6
subtract_baseline bool

Whether to subtract baseline from response.

True
seconds float

Duration to plot in seconds.

0.2
steps int

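Number of time steps to plot. Note: the function currently overwrites this value with int(seconds / dt), so the passed argument has no effect.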

10
columns int

Number of columns in the plot.

10
suptitle str

Super title for the plot.

''
plot_resp bool

Whether to plot response.

True
hlines bool

Whether to plot horizontal lines.

True
vlines bool

Whether to plot vertical lines.

True
time_axis bool

Whether to add a time axis.

True
peak_central bool

Whether to center the plot around the peak.

False
wspace float

Width space between subplots.

-0.2
peak_last bool

Whether to show the peak in the last frame.

True
fontsize int

Font size for labels and titles.

5
ylabel str

Y-axis label.

''
ylabelrotation int

Rotation angle for y-axis label.

90
figsize list[float]

Figure size.

[5, 1]
label_peak_response bool

Whether to label the peak response.

False
fig Figure | None

Existing figure to plot on.

None
axes ndarray | None

Existing axes to plot on.

None
crange float | None

Color range for the plot.

None
trace_axis bool

Whether to add a trace axis.

False
trace_label str | None

Label for the trace.

None
trace_axis_offset float

Offset for the trace axis.

0.1
trace_color str | None

Color for the trace.

None

Returns:

Type Description
StimResponsePlot

StimResponsePlot object containing plot data and figure.

Source code in flyvis/analysis/optimal_stimuli.py
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
def plot_stim_response(
    stim: np.ndarray,
    response: np.ndarray,
    dt: float,
    u: np.ndarray,
    v: np.ndarray,
    max_extent: int = 6,
    subtract_baseline: bool = True,
    seconds: float = 0.2,
    steps: int = 10,
    columns: int = 10,
    suptitle: str = "",
    plot_resp: bool = True,
    hlines: bool = True,
    vlines: bool = True,
    time_axis: bool = True,
    peak_central: bool = False,
    wspace: float = -0.2,
    peak_last: bool = True,
    fontsize: int = 5,
    ylabel: str = "",
    ylabelrotation: int = 90,
    figsize: list[float] = [5, 1],
    label_peak_response: bool = False,
    fig: plt.Figure | None = None,
    axes: np.ndarray | None = None,
    crange: float | None = None,
    trace_axis: bool = False,
    trace_label: str | None = None,
    trace_axis_offset: float = 0.1,
    trace_color: str | None = None,
) -> StimResponsePlot:
    """Plot spatio-temporal stimulus and response on regular hex lattices.

    Args:
        stim: Stimulus array.
        response: Response array.
        dt: Time step.
        u: Hexagonal u-coordinates.
        v: Hexagonal v-coordinates.
        max_extent: Maximum extent of the hexagonal grid.
        subtract_baseline: Whether to subtract baseline (the first frame)
            from response.
        seconds: Duration to plot in seconds.
        steps: Kept for backward compatibility; the value is overwritten with
            ``int(seconds / dt)``.
        columns: Number of columns in the plot.
        suptitle: Super title for the plot.
        plot_resp: Whether to plot response.
        hlines: Whether to plot horizontal lines.
        vlines: Whether to plot vertical lines.
        time_axis: Whether to add a time axis. Disabled when ``trace_axis``
            is requested.
        peak_central: Whether to center the plot around the peak. Takes
            precedence over ``peak_last``.
        wspace: Width space between subplots.
        peak_last: Whether to show the peak in the last frame. If neither
            ``peak_central`` nor ``peak_last`` is set, the first
            ``int(seconds / dt)`` frames are shown.
        fontsize: Font size for labels and titles.
        ylabel: Y-axis label.
        ylabelrotation: Rotation angle for y-axis label.
        figsize: Figure size.
        label_peak_response: Whether to label the peak response.
        fig: Existing figure to plot on.
        axes: Existing axes to plot on.
        crange: Color range for the plot.
        trace_axis: Whether to add a trace axis.
        trace_label: Label for the trace.
        trace_axis_offset: Offset for the trace axis.
        trace_color: Color for the trace.

    Returns:
        StimResponsePlot object containing plot data and figure. Its ``time``
        field holds the frame times (``np.arange(n_frames) * dt``).
    """
    stim = tensor_utils.to_numpy(stim).squeeze()
    # Drop frames in which the stimulus contains NaNs.
    mask = ~np.isnan(stim).any(axis=-1).squeeze()
    response = tensor_utils.to_numpy(response).squeeze()
    stim = stim[mask]
    response = response[mask]

    if subtract_baseline:
        response -= response[0]

    # Peak frame of the trace at the middle index of the last axis.
    argmax = np.nanargmax(response[:, response.shape[-1] // 2])

    n_frames = response.shape[0]
    time = np.arange(n_frames) * dt
    steps = int(seconds / dt)
    t_argmax = time[argmax]

    if peak_central:
        start = argmax - steps // 2
        end = argmax + steps // 2
        if start < 0:
            start = 0
            end = steps
    elif peak_last:
        start = argmax - steps
        end = argmax
        if start < 0:
            start = 0
            end = steps
    else:
        # Neither alignment requested: previously `start`/`end` were unbound
        # here; fall back to the first `steps` frames.
        start = 0
        end = steps

    _t_steps = time[start:end]

    # resample in time in case seconds, number of columns, dt does not match
    time_index = np.linspace(0, len(_t_steps), 2 * columns, endpoint=False).astype(int)
    _t_steps = _t_steps[time_index]

    # Even entries are used for stimulus panels, odd entries for response panels.
    t_steps_stim = _t_steps[0::2]
    t_steps_resp = _t_steps[1::2]

    _u, _v = hex_utils.get_hex_coords(max_extent)
    x, y = hex_utils.hex_to_pixel(_u, _v)
    xmin, xmax = x.min(), x.max()
    ymin, ymax = y.min(), y.max()
    elev = 0
    azim = 0

    if fig is None or axes is None:
        if plot_resp:
            grid_x, grid_y = hex_utils.hex_rows(2, columns)
            fig, axes, pos = plt_utils.ax_scatter(
                grid_x,
                grid_y,
                figsize=figsize,
                hpad=0,
                wpad=0.07,
                wspace=-0.7,
                hspace=-0.5,
            )
            axes = np.array(axes).reshape(2, columns)

        else:
            # One row with `columns` panels (was hard-coded to 10 panels,
            # which breaks for any other value of `columns`).
            fig, axes = plt_utils.divide_figure_to_grid(
                np.arange(columns).reshape(1, columns),
                wspace=wspace,
                as_matrix=True,
                figsize=figsize,
            )

    crange = crange or np.abs(np.nanmax(response))
    for i, t in enumerate(t_steps_stim):
        # plot stimulus
        frame_mask = np.abs(time - t) <= 1e-15
        _stim = stim[frame_mask].squeeze()
        plots.quick_hex_scatter(
            _stim,
            vmin=0,
            vmax=1,
            cbar=False,
            max_extent=max_extent,
            fig=fig,
            ax=axes[0, i],
        )

        if hlines:
            axes[0, i].hlines(elev, xmin, xmax, color="#006400", linewidth=0.25)
        if vlines:
            axes[0, i].vlines(azim, ymin, ymax, color="#006400", linewidth=0.25)

        if plot_resp:
            # --- plot response
            frame_mask = np.abs(time - t_steps_resp[i]) <= 1e-15
            _resp = response[frame_mask].squeeze()
            plots.hex_scatter(
                u,
                v,
                _resp,
                fill=True,
                cmap=plt.cm.coolwarm,
                vmin=-crange,
                vmax=crange,
                midpoint=0,
                cbar=False,
                max_extent=max_extent,
                fig=fig,
                ax=axes[1, i],
            )
            if t_steps_resp[i] == t_argmax and label_peak_response:
                axes[1, i].set_title("peak", fontsize=fontsize)

            if hlines:
                axes[1, i].hlines(elev, xmin, xmax, color="#006400", linewidth=0.25)
            if vlines:
                axes[1, i].vlines(azim, ymin, ymax, color="#006400", linewidth=0.25)

    if trace_axis:
        # Horizontal span from the data origin of the first panel to that of
        # the last panel, in figure coordinates.
        left = fig.transFigure.inverted().transform(
            axes[0, 0].transData.transform((0, 0))
        )[0]
        right = fig.transFigure.inverted().transform(
            axes[-1, -1].transData.transform((0, 0))
        )[0]

        lefts, bottoms, rights, tops = np.array([
            ax.get_position().extents for ax in axes.flatten()
        ]).T

        # From here on `trace_axis` refers to the created axes, which is also
        # what gets returned in the StimResponsePlot.
        trace_axis = fig.add_axes(
            (
                left,
                bottoms.min() - trace_axis_offset,
                right - left,
                trace_axis_offset - 0.05 * trace_axis_offset,
            ),
            label="trace_axis",
        )
        plt_utils.rm_spines(
            trace_axis, ("top", "right"), rm_yticks=False, rm_xticks=False
        )

        data_centers_in_points = np.array([
            ax.transData.transform((0, 0)) for ax in axes.flatten(order="F")
        ])
        trace_axis.tick_params(axis="both", labelsize=fontsize)
        if plot_resp:
            xticks = trace_axis.transData.inverted().transform(data_centers_in_points)[
                1::2, 0
            ]
            trace_axis.set_xticks(xticks)
            ticklabels = np.round(_t_steps * 1000, 0)
            trace_axis.set_xticklabels((ticklabels - ticklabels.max())[1::2])
        else:
            xticks = trace_axis.transData.inverted().transform(data_centers_in_points)[
                :, 0
            ]
            trace_axis.set_xticks(xticks)
            ticklabels = np.round(t_steps_stim * 1000, 0)
            trace_axis.set_xticklabels((ticklabels - ticklabels.max()))
        trace_axis.set_xlabel("time (ms)", fontsize=fontsize, labelpad=2)
        plt_utils.set_spine_tick_params(
            trace_axis,
            spinewidth=0.25,
            tickwidth=0.25,
            ticklength=3,
            ticklabelpad=2,
            spines=("top", "right", "bottom", "left"),
        )
        xlim = trace_axis.get_xlim()
        window = (time >= _t_steps.min()) & (time <= _t_steps.max())

        # Use a separate name for the trace's x-coordinates so the frame
        # times in `time` are not overwritten before being returned.
        trace_time = np.linspace(xticks.min(), xticks.max(), window.sum())

        trace_axis.plot(
            trace_time,
            response[window, response.shape[-1] // 2],
            label=trace_label,
            color=trace_color,
        )
        trace_axis.set_xlim(*xlim)
        trace_axis.set_ylabel("central\nresponse", fontsize=fontsize)

        # The trace axis already carries the time ticks.
        time_axis = False

    if time_axis:
        left = fig.transFigure.inverted().transform(
            axes[0, 0].transData.transform((0, 0))
        )[0]
        right = fig.transFigure.inverted().transform(
            axes[-1, -1].transData.transform((0, 0))
        )[0]

        lefts, bottoms, rights, tops = np.array([
            ax.get_position().extents for ax in axes.flatten()
        ]).T

        # From here on `time_axis` refers to the created axes.
        time_axis = fig.add_axes((left, bottoms.min(), right - left, 0.01))
        plt_utils.rm_spines(
            time_axis, ("left", "top", "right"), rm_yticks=True, rm_xticks=False
        )

        data_centers_in_points = np.array([
            ax.transData.transform((0, 0)) for ax in axes.flatten(order="F")
        ])
        time_axis.tick_params(axis="both", labelsize=fontsize)
        if plot_resp:
            time_axis.set_xticks(
                time_axis.transData.inverted().transform(data_centers_in_points)[1::2, 0]
            )
            ticklabels = np.round(_t_steps * 1000, 0)
            time_axis.set_xticklabels((ticklabels - ticklabels.max())[1::2])
        else:
            time_axis.set_xticks(
                time_axis.transData.inverted().transform(data_centers_in_points)[:, 0]
            )
            ticklabels = np.round(t_steps_stim * 1000, 0)
            time_axis.set_xticklabels((ticklabels - ticklabels.max()))
        time_axis.set_xlabel("time (ms)", fontsize=fontsize, labelpad=2)
        plt_utils.set_spine_tick_params(
            time_axis,
            spinewidth=0.25,
            tickwidth=0.25,
            ticklength=3,
            ticklabelpad=2,
            spines=("top", "right", "bottom", "left"),
        )

    if ylabel:
        lefts, bottoms, rights, tops = np.array([
            ax.get_position().extents for ax in axes.flatten()
        ]).T
        # Invisible, thin axes spanning all panels that only carries the label.
        ylabel_axis = fig.add_axes((
            lefts.min(),
            bottoms.min(),
            0.01,
            tops.max() - bottoms.min(),
        ))
        plt_utils.rm_spines(
            ylabel_axis,
            ("left", "top", "right", "bottom"),
            rm_yticks=True,
            rm_xticks=True,
        )
        ylabel_axis.set_ylabel(ylabel, fontsize=fontsize, rotation=ylabelrotation)
        ylabel_axis.patch.set_alpha(0)

    # Passing ylabel=None suppresses the row annotations.
    if plot_resp and ylabel is not None:
        axes[0, 0].annotate(
            "stimulus",
            xy=(0, 0.5),
            ha="right",
            va="center",
            fontsize=fontsize,
            rotation=90,
            xycoords="axes fraction",
        )
        axes[1, 0].annotate(
            "response",
            xy=(0, 0.5),
            ha="right",
            va="center",
            fontsize=fontsize,
            rotation=90,
            xycoords="axes fraction",
        )

    if suptitle:
        lefts, bottoms, rights, tops = np.array([
            ax.get_position().extents for ax in axes.flatten()
        ]).T
        fig.suptitle(suptitle, fontsize=fontsize, y=tops.max(), va="bottom")

    plt_utils.set_spine_tick_params(
        fig.axes[-1],
        spinewidth=0.25,
        tickwidth=0.25,
        ticklength=3,
        ticklabelpad=2,
        spines=("top", "right", "bottom", "left"),
    )

    # Expose the color range that was used on the figure object.
    fig.crange = crange

    return StimResponsePlot(
        stim,
        response,
        dt,
        u,
        v,
        time,
        _t_steps,
        t_steps_stim,
        t_steps_resp,
        xmin,
        xmax,
        ymin,
        ymax,
        subtract_baseline,
        steps,
        fig,
        axes,
        time_axis,
        trace_axis,
        argmax,
        t_argmax,
    )