Skip to content
70 changes: 69 additions & 1 deletion mne_qt_browser/_pg_figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@
from mne.viz import plot_sensors
from mne.viz._figure import BrowserBase
from mne.viz.backends._utils import _init_mne_qtapp, _qt_raise_window
from mne.viz.ui_events import (
ChannelsSelect,
TimeBrowse,
TimeChange,
disable_ui_events,
publish,
subscribe,
unsubscribe,
)
from mne.viz.utils import _figure_agg, _merge_annotations, _simplify_float
from pyqtgraph import (
AxisItem,
Expand Down Expand Up @@ -4243,6 +4252,43 @@ def __init__(self, **kwargs):
# disable histogram of epoch PTP amplitude
del self.mne.keyboard_shortcuts["h"]

# Subscribe to vertical line change
subscribe(self, "time_change", self._on_time_change_event)

# Subscribe to time browse
subscribe(self, "time_browse", self._on_time_browse_event)

# Subscribe to channel browse
# self.mne.plt.sigYRangeChanged.connect(self._on_channel_browse_event)
subscribe(self, "channels_select", self._on_channel_browse_event)

def _on_time_change_event(self, event):
"""Response to TimeChange event from the event-ui system."""
with disable_ui_events(self):
self._add_vline(event.time)

def _on_time_browse_event(self, event):
"""Response to TimeBrowse event from the event-ui system."""
with disable_ui_events(self):
self.mne.plt.setXRange(event.time_start, event.time_end, padding=0)

def _on_channel_browse_event(self, event):
"""Response to ChannelsSelect event from the event-ui system."""
# Get the indices of the subset in the full set of channels
all_channels = self.mne.ch_names[self.mne.ch_order]
# KRUFT
# ch_indices = [np.where(all_channels == ch)[0][0] for ch in event.channels]
ch_indices = np.where(np.isin(all_channels, event.ch_names))[0]

# Take the start index and set range
with disable_ui_events(self):
start_idx, end_idx = ch_indices.min(), ch_indices.max() + 2
# KRUFT
# start_idx = ch_indices[0]
# n_chans = len(ch_indices)
# end_idx = start_idx+n_chans+1
self.mne.plt.setYRange(start_idx, end_idx, padding=0)

def _hidpi_mkPen(self, *args, **kwargs):
kwargs["width"] = self._pixel_ratio * kwargs.get("width", 1.0)
return mkPen(*args, **kwargs)
Expand Down Expand Up @@ -4493,6 +4539,9 @@ def _vline_slot(self, orig_vline):
vl.setPos(xt)
self.mne.overview_bar.update_vline()

def _vline_drag_slot(self, vline):
publish(self, TimeChange(time=vline.value()))

def _add_vline(self, t):
if self.mne.is_epochs:
ts = self._get_vline_times(t)
Expand All @@ -4510,8 +4559,9 @@ def _add_vline(self, t):
# Avoid off-by-one-error at bmax for VlineLabel
bmax -= 1 / self.mne.info["sfreq"]
vl = VLine(self.mne, xt, bounds=(bmin, bmax))
# Should only be emitted when dragged
# Connect signals for both drag and position change
vl.sigPositionChangeFinished.connect(self._vline_slot)
vl.sigDragged.connect(self._vline_drag_slot)
self.mne.vline.append(vl)
self.mne.plt.addItem(vl)
else:
Expand All @@ -4521,12 +4571,15 @@ def _add_vline(self, t):
if self.mne.vline is None:
self.mne.vline = VLine(self.mne, t, bounds=(0, self.mne.xmax))
self.mne.vline.sigPositionChangeFinished.connect(self._vline_slot)
self.mne.vline.sigDragged.connect(self._vline_drag_slot)
self.mne.plt.addItem(self.mne.vline)

else:
self.mne.vline.setPos(t)

self.mne.vline_visible = True
self.mne.overview_bar.update_vline()
publish(self, TimeChange(time=t))

def _mouse_moved(self, pos):
"""Show Crosshair if enabled at mouse move."""
Expand Down Expand Up @@ -4621,6 +4674,15 @@ def _xrange_changed(self, _, xrange):
# Update annotations
self._update_regions_visible()

# Publish event
publish(
self,
TimeBrowse(
time_start=self.mne.t_start,
time_end=self.mne.t_start + self.mne.duration,
),
)

def _yrange_changed(self, _, yrange):
if not self.mne.butterfly:
if not self.mne.fig_selection:
Expand Down Expand Up @@ -4674,6 +4736,9 @@ def _yrange_changed(self, _, yrange):
trace.update_color()
trace.update_data()

# Publish to event system
publish(self, ChannelsSelect(ch_names=self.mne.ch_names[self.mne.picks]))

# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# DATA HANDLING
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
Expand Down Expand Up @@ -5604,6 +5669,9 @@ def closeEvent(self, event):
self.load_thread.clean()
self.load_thread = None

# Ensure all event handlers are unsubscribed
unsubscribe(self, ["time_change", "time_browse", "channels_select"])

# Remove self from browser_instances in globals
if self in _browser_instances:
_browser_instances.remove(self)
Expand Down
Loading