diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py
index 3925f41d2b..939d271286 100644
--- a/src/spikeinterface/core/base.py
+++ b/src/spikeinterface/core/base.py
@@ -14,7 +14,7 @@
 import numpy as np

 from .globals import get_global_tmp_folder, is_set_global_tmp_folder
-from .core_tools import check_json, is_dict_extractor, recursive_path_modifier, SIJsonEncoder
+from .core_tools import check_json, is_dict_extractor, recursive_path_modifier, dict_contains_extractors, SIJsonEncoder
 from .job_tools import _shared_job_kwargs_doc
@@ -310,6 +310,7 @@ def to_dict(
         relative_to: Union[str, Path, None] = None,
         folder_metadata=None,
         recursive: bool = False,
+        skip_recursive_path_modifier_warning: bool = False,
     ) -> dict:
         """
         Make a nested serialized dictionary out of the extractor. The dictionary produced can be used to re-initialize
@@ -329,6 +330,8 @@
             Folder with numpy `npy` files containing additional information (e.g. probe in BaseRecording) and properties.
         recursive: bool
             If True, all dictionaries in the kwargs are expanded with `to_dict` as well, by default False.
+        skip_recursive_path_modifier_warning: bool
+            If True, skip the warning that is raised when `recursive=True` and `relative_to` is not None.

         Returns
         -------
@@ -359,6 +362,7 @@
                     new_kwargs[name] = transform_extractors_to_dict(value)

             kwargs = new_kwargs
+
         class_name = str(type(self)).replace("<class '", "").replace("'>", "")
         module = class_name.split(".")[0]
         imported_module = importlib.import_module(module)
@@ -376,11 +380,6 @@
             "relative_paths": (relative_to is not None),
         }

-        try:
-            dump_dict["version"] = imported_module.__version__
-        except AttributeError:
-            dump_dict["version"] = "unknown"
-
         if include_annotations:
             dump_dict["annotations"] = self._annotations
         else:
@@ -394,9 +393,12 @@
             dump_dict["properties"] = {k: self._properties.get(k, None) for k in self._main_properties}

         if relative_to is not None:
-            relative_to = Path(relative_to).absolute()
+            relative_to = Path(relative_to).resolve().absolute()
             assert relative_to.is_dir(), "'relative_to' must be an existing directory"
-            dump_dict = _make_paths_relative(dump_dict, relative_to)
+            copy = False if dict_contains_extractors(dump_dict) else True
+            dump_dict = _make_paths_relative(
+                dump_dict, relative_to, copy=copy, skip_warning=skip_recursive_path_modifier_warning
+            )

         if folder_metadata is not None:
             if relative_to is not None:
@@ -424,7 +426,8 @@ def from_dict(dictionary: dict, base_folder: Optional[Union[Path, str]] = None)
         """
         if dictionary["relative_paths"]:
             assert base_folder is not None, "When relative_paths=True, need to provide base_folder"
-            dictionary = _make_paths_absolute(dictionary, base_folder)
+            copy = False if dict_contains_extractors(dictionary) else True
+            dictionary = _make_paths_absolute(dictionary, base_folder, copy=copy)
         extractor = _load_extractor_from_dict(dictionary)
         folder_metadata = dictionary.get("folder_metadata", None)
         if folder_metadata is not None:
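# --- editor's example (illustrative sketch, not part of the patch) ---
# Round trip with the new relative-path handling shown above: recursive=True
# expands nested extractors into plain dicts, so to_dict() can safely copy the
# dict before rewriting paths. Folder names here are hypothetical.
from pathlib import Path
from tempfile import mkdtemp

from spikeinterface.core import generate_recording
from spikeinterface.core.base import BaseExtractor

base = Path(mkdtemp())
rec = generate_recording(durations=[1.0]).save(folder=base / "rec")
d = rec.to_dict(recursive=True, relative_to=base)  # paths stored relative to `base`
rec2 = BaseExtractor.from_dict(d, base_folder=base)  # paths resolved to absolute again
# --- end editor's example ---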
""" - d = self.to_dict(include_annotations=True, include_properties=True) - d = deepcopy(d) - clone = BaseExtractor.from_dict(d) + dictionary = self.to_dict(include_annotations=True, include_properties=True) + dictionary = deepcopy(dictionary) + clone = BaseExtractor.from_dict(dictionary) return clone def check_if_dumpable(self): @@ -557,7 +560,9 @@ def dump(self, file_path: Union[str, Path], relative_to=None, folder_metadata=No else: raise ValueError("Dump: file must .json or .pkl") - def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=None, folder_metadata=None) -> None: + def dump_to_json( + self, file_path: Union[str, Path, None] = None, relative_to=None, folder_metadata=None, recursive=False + ) -> None: """ Dump recording extractor to json file. The extractor can be re-loaded with load_extractor_from_json(json_file) @@ -568,10 +573,19 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non Path of the json file relative_to: str, Path, or None If not None, file_paths are serialized relative to this path + folder_metadata: str, Path, or None + Folder with numpy files containing additional information (e.g. probe in BaseRecording) and properties. + recursive: bool + If True, all dicitionaries in the kwargs are expanded with `to_dict` as well, by default False. """ - assert self.check_if_dumpable() + assert self.check_if_json_serializable(), "The extractor is not json serializable" dump_dict = self.to_dict( - include_annotations=True, include_properties=False, relative_to=relative_to, folder_metadata=folder_metadata + include_annotations=True, + include_properties=False, + relative_to=relative_to, + folder_metadata=folder_metadata, + recursive=recursive, + skip_recursive_path_modifier_warning=True, # we skip warning because we will make paths absolute again ) file_path = self._get_file_path(file_path, [".json"]) @@ -579,6 +593,9 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non json.dumps(dump_dict, indent=4, cls=SIJsonEncoder), encoding="utf8", ) + if relative_to: + # Make paths absolute again + dump_dict = _make_paths_absolute(dump_dict, relative_to, copy=False, skip_warning=True) def dump_to_pickle( self, @@ -603,18 +620,23 @@ def dump_to_pickle( recursive: bool If True, all dicitionaries in the kwargs are expanded with `to_dict` as well, by default False. 
""" - assert self.check_if_dumpable() + assert self.check_if_dumpable(), "The extractor is not dumpable to pickle" + if relative_to: + assert recursive, "When relative_to is given, recursive must be True" dump_dict = self.to_dict( include_annotations=True, include_properties=include_properties, relative_to=relative_to, folder_metadata=folder_metadata, recursive=recursive, + skip_recursive_path_modifier_warning=True, # we skip warning because we will make paths absolute again ) file_path = self._get_file_path(file_path, [".pkl", ".pickle"]) file_path.write_bytes(pickle.dumps(dump_dict)) + # we don't need to make paths absolute, because for pickle this is only available for recursive=True + @staticmethod def load(file_path: Union[str, Path], base_folder=None) -> "BaseExtractor": """ @@ -630,16 +652,16 @@ def load(file_path: Union[str, Path], base_folder=None) -> "BaseExtractor": # standard case based on a file (json or pickle) if str(file_path).endswith(".json"): with open(str(file_path), "r") as f: - d = json.load(f) + dictionary = json.load(f) elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"): with open(str(file_path), "rb") as f: - d = pickle.load(f) + dictionary = pickle.load(f) else: raise ValueError(f"Impossible to load {file_path}") - if "warning" in d and "not dumpable" in d["warning"]: + if "warning" in dictionary and "not dumpable" in dictionary["warning"]: print("The extractor was not dumpable") return None - extractor = BaseExtractor.from_dict(d, base_folder=base_folder) + extractor = BaseExtractor.from_dict(dictionary, base_folder=base_folder) return extractor elif file_path.is_dir(): @@ -920,16 +942,20 @@ def save_to_zarr( return cached -def _make_paths_relative(d, relative) -> dict: - relative = str(Path(relative).absolute()) - func = lambda p: os.path.relpath(str(p), start=relative) - return recursive_path_modifier(d, func, target="path", copy=True) +def _make_paths_relative(d, relative, copy=True, skip_warning=False) -> dict: + relative = Path(relative).absolute() + func = lambda p: os.path.relpath(Path(p).resolve().absolute(), start=relative) + return recursive_path_modifier( + d, func, target="path", copy=copy, skip_targets=["relative_paths"], skip_warning=skip_warning + ) -def _make_paths_absolute(d, base): +def _make_paths_absolute(d, base, copy=True, skip_warning=False) -> dict: base = Path(base) func = lambda p: str((base / p).resolve().absolute()) - return recursive_path_modifier(d, func, target="path", copy=True) + return recursive_path_modifier( + d, func, target="path", copy=copy, skip_targets=["relative_paths"], skip_warning=skip_warning + ) def _load_extractor_from_dict(dic) -> BaseExtractor: diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 8c24e4e624..8281f56526 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -385,8 +385,8 @@ def has_time_vector(self, segment_index=None): """ segment_index = self._check_segment_index(segment_index) rs = self._recording_segments[segment_index] - d = rs.get_times_kwargs() - return d["time_vector"] is not None + time_kwargs = rs.get_times_kwargs() + return time_kwargs["time_vector"] is not None def set_times(self, times, segment_index=None, with_warning=True): """Set times for a recording segment. 
diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py
index 8c24e4e624..8281f56526 100644
--- a/src/spikeinterface/core/baserecording.py
+++ b/src/spikeinterface/core/baserecording.py
@@ -385,8 +385,8 @@ def has_time_vector(self, segment_index=None):
         """
         segment_index = self._check_segment_index(segment_index)
         rs = self._recording_segments[segment_index]
-        d = rs.get_times_kwargs()
-        return d["time_vector"] is not None
+        time_kwargs = rs.get_times_kwargs()
+        return time_kwargs["time_vector"] is not None

     def set_times(self, times, segment_index=None, with_warning=True):
         """Set times for a recording segment.
@@ -501,8 +501,8 @@ def _save(self, format="binary", **save_kwargs):
             # save time vector if any
             t_starts = np.zeros(self.get_num_segments(), dtype="float64") * np.nan
             for segment_index, rs in enumerate(self._recording_segments):
-                d = rs.get_times_kwargs()
-                time_vector = d["time_vector"]
+                time_kwargs = rs.get_times_kwargs()
+                time_vector = time_kwargs["time_vector"]
                 if time_vector is not None:
                     _ = zarr_root.create_dataset(
                         name=f"times_seg{segment_index}",
@@ -511,7 +511,7 @@
                         compressor=zarr_kwargs["compressor"],
                     )
-                elif d["t_start"] is not None:
-                    t_starts[segment_index] = d["t_start"]
+                elif time_kwargs["t_start"] is not None:
+                    t_starts[segment_index] = time_kwargs["t_start"]

             if np.any(~np.isnan(t_starts)):
                 zarr_root.create_dataset(name="t_starts", data=t_starts, compressor=None)
@@ -530,8 +530,8 @@
             cached.set_probegroup(probegroup)

         for segment_index, rs in enumerate(self._recording_segments):
-            d = rs.get_times_kwargs()
-            time_vector = d["time_vector"]
+            time_kwargs = rs.get_times_kwargs()
+            time_vector = time_kwargs["time_vector"]
             if time_vector is not None:
                 cached._recording_segments[segment_index].time_vector = time_vector

@@ -559,8 +559,8 @@ def _extra_metadata_to_folder(self, folder):
         # save time vector if any
         for segment_index, rs in enumerate(self._recording_segments):
-            d = rs.get_times_kwargs()
-            time_vector = d["time_vector"]
+            time_kwargs = rs.get_times_kwargs()
+            time_vector = time_kwargs["time_vector"]
             if time_vector is not None:
                 np.save(folder / f"times_cached_seg{segment_index}.npy", time_vector)
diff --git a/src/spikeinterface/core/binaryfolder.py b/src/spikeinterface/core/binaryfolder.py
index d9a4ce0963..12d60c2e2f 100644
--- a/src/spikeinterface/core/binaryfolder.py
+++ b/src/spikeinterface/core/binaryfolder.py
@@ -33,22 +33,23 @@ def __init__(self, folder_path):
         folder_path = Path(folder_path)

         with open(folder_path / "binary.json", "r") as f:
-            d = json.load(f)
+            dictionary = json.load(f)

-        if not d["class"].endswith(".BinaryRecordingExtractor"):
+        if not dictionary["class"].endswith(".BinaryRecordingExtractor"):
             raise ValueError("This folder is not a binary spikeinterface folder")

-        assert d["relative_paths"]
+        assert dictionary["relative_paths"]

-        d = _make_paths_absolute(d, folder_path)
+        kwargs = dictionary["kwargs"]
+        kwargs = _make_paths_absolute(kwargs, folder_path)

-        BinaryRecordingExtractor.__init__(self, **d["kwargs"])
+        BinaryRecordingExtractor.__init__(self, **kwargs)

         folder_metadata = folder_path
         self.load_metadata_from_folder(folder_metadata)

         self._kwargs = dict(folder_path=str(folder_path.absolute()))
-        self._bin_kwargs = d["kwargs"]
+        self._bin_kwargs = kwargs
         if "num_channels" not in self._bin_kwargs:
             assert "num_chan" in self._bin_kwargs, "Cannot find num_channels or num_chan in binary.json"
             self._bin_kwargs["num_channels"] = self._bin_kwargs["num_chan"]
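# --- editor's example (illustrative sketch, not part of the patch) ---
# What the renamed `time_kwargs` dict holds, inferred from the code above
# (the "sampling_frequency" key is assumed from get_times_kwargs): either an
# explicit time_vector, or a t_start/sampling_frequency pair to rebuild times.
import numpy as np

time_kwargs = {"sampling_frequency": 30000.0, "t_start": 1.5, "time_vector": None}
num_samples = 100
if time_kwargs["time_vector"] is not None:
    times = np.asarray(time_kwargs["time_vector"])
else:
    times = np.arange(num_samples) / time_kwargs["sampling_frequency"]
    if time_kwargs["t_start"] is not None:
        times += time_kwargs["t_start"]
# --- end editor's example ---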
diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py
index 106a794f6e..30d104149d 100644
--- a/src/spikeinterface/core/core_tools.py
+++ b/src/spikeinterface/core/core_tools.py
@@ -1,24 +1,24 @@
-from pathlib import Path
-from typing import Union
-import os
-import sys
 import datetime
-import json
-from copy import deepcopy
 import gc
+import json
 import mmap
-import inspect
+import os
+import sys
+import warnings
+from copy import deepcopy
+from pathlib import Path
+from typing import Union

 import numpy as np
 from tqdm import tqdm

 from .job_tools import (
+    ChunkRecordingExecutor,
+    _shared_job_kwargs_doc,
+    divide_segment_into_chunks,
     ensure_chunk_size,
     ensure_n_jobs,
-    divide_segment_into_chunks,
     fix_job_kwargs,
-    ChunkRecordingExecutor,
-    _shared_job_kwargs_doc,
 )
@@ -42,9 +42,10 @@ def read_python(path):
         dictionary containing parsed file

     """
-    from six import exec_
     import re

+    from six import exec_
+
     path = Path(path).absolute()
     assert path.is_file()
     with path.open("r") as f:
@@ -786,7 +787,42 @@ def is_dict_extractor(d):
     return is_extractor


-def recursive_path_modifier(d, func, target="path", copy=True) -> dict:
+def dict_contains_extractors(dictionary):
+    """
+    Checks if a dictionary contains BaseExtractor objects.
+
+    Parameters
+    ----------
+    dictionary : dict
+        Dictionary to check
+
+    Returns
+    -------
+    bool
+        True if the dictionary contains extractors, False otherwise
+    """
+    from .base import BaseExtractor
+
+    contains_extractors = False
+    for name, value in dictionary.items():
+        if isinstance(value, dict):
+            # recurse into nested dicts, but keep scanning the remaining keys
+            if dict_contains_extractors(value):
+                return True
+        elif isinstance(value, list):
+            if any(isinstance(v, BaseExtractor) for v in value):
+                return True
+            if any(dict_contains_extractors(v) for v in value if isinstance(v, dict)):
+                return True
+        else:
+            if isinstance(value, BaseExtractor):
+                contains_extractors = True
+    return contains_extractors
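# --- editor's example (illustrative sketch, not part of the patch) ---
# Expected behaviour of dict_contains_extractors(), mirroring test_to_dict
# further down in this diff: with recursive=False the kwargs of a nested
# extractor still hold a live BaseExtractor; recursive=True expands it fully.
from spikeinterface.core import generate_recording
from spikeinterface.core.core_tools import dict_contains_extractors

rec = generate_recording(durations=[1.0])
nested = rec.frame_slice(start_frame=0, end_frame=100)

assert dict_contains_extractors(nested.to_dict())  # parent kept as a live object
assert not dict_contains_extractors(nested.to_dict(recursive=True))  # fully serialized
# --- end editor's example ---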
"Use with caution" + ) + if copy: - dc = deepcopy(d) + dc = deepcopy(dictionary) else: - dc = d + dc = dictionary - if "kwargs" in dc.keys(): + if "kwargs" in dc: kwargs = dc["kwargs"] - - # change in place (copy=False) recursive_path_modifier(kwargs, func, copy=False) - - # find nested and also change inplace (copy=False) - nested_extractor_dict = None - for k, v in kwargs.items(): - if isinstance(v, dict) and is_dict_extractor(v): - nested_extractor_dict = v - recursive_path_modifier(nested_extractor_dict, func, copy=False) - # deal with list of extractor objects (e.g. concatenate_recordings) - elif isinstance(v, list): - for vl in v: - if isinstance(vl, dict) and is_dict_extractor(vl): - nested_extractor_dict = vl - recursive_path_modifier(nested_extractor_dict, func, copy=False) - - return dc else: - for k, v in d.items(): - if target in k: - # paths can be str or list of str or None - if v is None: - continue - if isinstance(v, (str, Path)): - dc[k] = func(v) - elif isinstance(v, list): - dc[k] = [func(e) for e in v] - else: - raise ValueError(f"{k} key for path must be str or list[str]") + for name, value in dictionary.items(): + # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors + if isinstance(value, BaseExtractor): + kwargs = value._kwargs + recursive_path_modifier(kwargs, func, copy=False) + elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor): + for v in value: + kwargs = v._kwargs + recursive_path_modifier(kwargs, func, copy=False) + elif isinstance(value, dict): + if isinstance(value[list(value.keys())[0]], BaseExtractor): + for v in value.values(): + kwargs = v._kwargs + recursive_path_modifier(kwargs, func, copy=False) + elif is_dict_extractor(value): + kwargs = value["kwargs"] + recursive_path_modifier(kwargs, func, copy=False) + else: + # relative_paths is protected! 
diff --git a/src/spikeinterface/core/npyfoldersnippets.py b/src/spikeinterface/core/npyfoldersnippets.py
index b7c773aad3..4f5491f166 100644
--- a/src/spikeinterface/core/npyfoldersnippets.py
+++ b/src/spikeinterface/core/npyfoldersnippets.py
@@ -34,22 +34,23 @@ def __init__(self, folder_path):
         folder_path = Path(folder_path)

         with open(folder_path / "npy.json", "r") as f:
-            d = json.load(f)
+            dictionary = json.load(f)

-        if not d["class"].endswith(".NpySnippetsExtractor"):
+        if not dictionary["class"].endswith(".NpySnippetsExtractor"):
             raise ValueError("This folder is not an npy snippets spikeinterface folder")

-        assert d["relative_paths"]
+        assert dictionary["relative_paths"]

-        d = _make_paths_absolute(d, folder_path)
+        kwargs = dictionary["kwargs"]
+        kwargs = _make_paths_absolute(kwargs, folder_path)

-        NpySnippetsExtractor.__init__(self, **d["kwargs"])
+        NpySnippetsExtractor.__init__(self, **kwargs)

         folder_metadata = folder_path
         self.load_metadata_from_folder(folder_metadata)

         self._kwargs = dict(folder_path=str(folder_path.absolute()))
-        self._bin_kwargs = d["kwargs"]
+        self._bin_kwargs = kwargs


 read_npy_snippets_folder = define_function_from_class(source_class=NpyFolderSnippets, name="read_npy_snippets_folder")
diff --git a/src/spikeinterface/core/npzfolder.py b/src/spikeinterface/core/npzfolder.py
index 9d2eb43af6..c0947fd4d6 100644
--- a/src/spikeinterface/core/npzfolder.py
+++ b/src/spikeinterface/core/npzfolder.py
@@ -33,22 +33,23 @@ def __init__(self, folder_path):
         folder_path = Path(folder_path)

         with open(folder_path / "npz.json", "r") as f:
-            d = json.load(f)
+            dictionary = json.load(f)

-        if not d["class"].endswith(".NpzSortingExtractor"):
+        if not dictionary["class"].endswith(".NpzSortingExtractor"):
             raise ValueError("This folder is not an npz spikeinterface folder")

-        assert d["relative_paths"]
+        assert dictionary["relative_paths"]

-        d = _make_paths_absolute(d, folder_path)
+        kwargs = dictionary["kwargs"]
+        kwargs = _make_paths_absolute(kwargs, folder_path)

-        NpzSortingExtractor.__init__(self, **d["kwargs"])
+        NpzSortingExtractor.__init__(self, **kwargs)

         folder_metadata = folder_path
         self.load_metadata_from_folder(folder_metadata)

         self._kwargs = dict(folder_path=str(folder_path.absolute()))
-        self._npz_kwargs = d["kwargs"]
+        self._npz_kwargs = kwargs


 read_npz_folder = define_function_from_class(source_class=NpzFolderSorting, name="read_npz_folder")
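# --- editor's example (illustrative sketch, not part of the patch) ---
# The folder loaders above now resolve only dictionary["kwargs"] rather than the
# whole dict, so bookkeeping keys such as "relative_paths" are never mistaken
# for paths. Sketch of that load step; the folder name is hypothetical and the
# private import location of _make_paths_absolute is assumed from base.py.
import json
from pathlib import Path

from spikeinterface.core.base import _make_paths_absolute

folder_path = Path("my_binary_folder")
if (folder_path / "binary.json").exists():
    with open(folder_path / "binary.json", "r") as f:
        dictionary = json.load(f)
    kwargs = _make_paths_absolute(dictionary["kwargs"], folder_path)
# --- end editor's example ---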
""" +import pytest +import numpy as np from typing import Sequence +from pathlib import Path + from spikeinterface.core.base import BaseExtractor from spikeinterface.core import generate_recording, concatenate_recordings +from spikeinterface.core.core_tools import dict_contains_extractors +from spikeinterface.core.testing import check_recordings_equal class DummyDictExtractor(BaseExtractor): @@ -14,6 +20,15 @@ def __init__(self, main_ids: Sequence, base_dicts=None) -> None: self._kwargs = dict(base_dicts=base_dicts) +def generate(): + return generate_recording(seed=0, durations=[2]) + + +@pytest.fixture +def recording(): + return generate() + + def make_nested_extractors(extractor): extractor_wih_parent = extractor.frame_slice(start_frame=0, end_frame=100) extractor_with_parent_list = concatenate_recordings([extractor, extractor]) @@ -31,8 +46,8 @@ def make_nested_extractors(extractor): ) -def test_check_if_dumpable(): - test_extractor = generate_recording(seed=0, durations=[2]) +def test_check_if_dumpable(recording): + test_extractor = recording # make a list of dumpable objects extractors_dumpable = make_nested_extractors(test_extractor) @@ -46,8 +61,8 @@ def test_check_if_dumpable(): assert not extractor.check_if_dumpable() -def test_check_if_json_serializable(): - test_extractor = generate_recording(seed=0, durations=[2]) +def test_check_if_json_serializable(recording): + test_extractor = recording # make a list of dumpable objects test_extractor._is_json_serializable = True @@ -64,6 +79,74 @@ def test_check_if_json_serializable(): assert not extractor.check_if_json_serializable() +def test_to_dict(recording): + d0 = recording.to_dict() + d0_recursive = recording.to_dict(recursive=True) + assert not dict_contains_extractors(d0) + assert not dict_contains_extractors(d0_recursive) + + nested_extractors = make_nested_extractors(recording) + for extractor in nested_extractors: + d1 = extractor.to_dict() + d1_recursive = extractor.to_dict(recursive=True) + + assert dict_contains_extractors(d1) + assert not dict_contains_extractors(d1_recursive) + + +def test_relative_to(recording, tmp_path): + recording_saved = recording.save(folder=tmp_path / "test") + folder_path = Path(recording_saved._kwargs["folder_path"]) + relative_folder = tmp_path.parent + + d1 = recording_saved.to_dict(recursive=True) + d2 = recording_saved.to_dict(recursive=True, relative_to=relative_folder) + + assert d1["kwargs"]["folder_path"] == str(folder_path.absolute()) + assert d2["kwargs"]["folder_path"] != str(folder_path.absolute()) + assert d2["kwargs"]["folder_path"] == str(folder_path.relative_to(relative_folder)) + assert ( + str((relative_folder / Path(d2["kwargs"]["folder_path"])).resolve().absolute()) == d1["kwargs"]["folder_path"] + ) + + recording_loaded = BaseExtractor.from_dict(d2, base_folder=relative_folder) + check_recordings_equal(recording_saved, recording_loaded, return_scaled=False) + + # test double pass in memory + recording_nested = recording_saved.channel_slice(recording_saved.channel_ids) + with pytest.warns(UserWarning): + d3 = recording_nested.to_dict(relative_to=relative_folder) + recording_loaded2 = BaseExtractor.from_dict(d3, base_folder=relative_folder) + check_recordings_equal(recording_nested, recording_loaded2, return_scaled=False) + d4 = recording_nested.to_dict(relative_to=relative_folder) + recording_loaded3 = BaseExtractor.from_dict(d4, base_folder=relative_folder) + check_recordings_equal(recording_nested, recording_loaded3, return_scaled=False) + + # check that dump to json/pickle 
diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py
index 6e471121b6..9214f4b0e4 100644
--- a/src/spikeinterface/core/tests/test_basesorting.py
+++ b/src/spikeinterface/core/tests/test_basesorting.py
@@ -134,7 +134,7 @@ def test_npy_sorting():
     seg_nframes = [9, 5]
     rec = NumpyRecording([np.zeros((nframes, 10)) for nframes in seg_nframes], sampling_frequency=sfreq)
     # assert_raises(Exception, sorting.register_recording, rec)
-    with pytest.warns():
+    with pytest.warns(UserWarning):
         sorting.register_recording(rec)

     # Registering a rec with too many segments
diff --git a/src/spikeinterface/core/tests/test_core_tools.py b/src/spikeinterface/core/tests/test_core_tools.py
index 89a4143e19..36b2e5ca2d 100644
--- a/src/spikeinterface/core/tests/test_core_tools.py
+++ b/src/spikeinterface/core/tests/test_core_tools.py
@@ -148,7 +148,7 @@ def test_write_memory_recording():

 def test_recursive_path_modifier():
     # this test nested depth 2 path modifier
-    d = {
+    d1 = {
         "kwargs": {
             "path": "/yep/path1",
             "recording": {
@@ -161,7 +161,7 @@
         }
     }

-    d2 = recursive_path_modifier(d, lambda p: p.replace("/yep", "/yop"))
+    d2 = recursive_path_modifier(d1, lambda p: p.replace("/yep", "/yop"))
     assert d2["kwargs"]["path"].startswith("/yop")
     assert d2["kwargs"]["recording"]["kwargs"]["path"].startswith("/yop")

@@ -173,5 +173,5 @@
     with tempfile.TemporaryDirectory() as tmpdirname:
         tmp_path = Path(tmpdirname)
         test_write_binary_recording(tmp_path)
-    # test_write_memory_recording()
-    # test_recursive_path_modifier()
+    test_write_memory_recording()
+    test_recursive_path_modifier()
diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py
index 546bee2ec1..c279f57d01 100644
--- a/src/spikeinterface/core/waveform_extractor.py
+++ b/src/spikeinterface/core/waveform_extractor.py
@@ -921,10 +921,12 @@ def save(
             zarr_root.attrs["params"] = check_json(self._params)
             if self.has_recording():
                 if self.recording.check_if_json_serializable():
-                    rec_dict = self.recording.to_dict(relative_to=relative_to)
+                    # use recursive=True to avoid modifying objects in place
+                    rec_dict = self.recording.to_dict(relative_to=relative_to, recursive=True)
                     zarr_root.attrs["recording"] = check_json(rec_dict)
             if self.sorting.check_if_json_serializable():
-                sort_dict = self.sorting.to_dict(relative_to=relative_to)
+                # use recursive=True to avoid modifying objects in place
+                sort_dict = self.sorting.to_dict(relative_to=relative_to, recursive=True)
                 zarr_root.attrs["sorting"] = check_json(sort_dict)
             else:
                 warn(
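# --- editor's example (illustrative sketch, not part of the patch) ---
# Why WaveformExtractor.save() now passes recursive=True: check_json() needs a
# plain, fully expanded dict, and the recursive copy guarantees the original
# recording's kwargs are not rewritten when relative_to is applied.
from spikeinterface.core import generate_recording
from spikeinterface.core.core_tools import dict_contains_extractors

nested = generate_recording(durations=[1.0]).frame_slice(start_frame=0, end_frame=100)
rec_dict = nested.to_dict(recursive=True)  # no live extractor objects left inside
assert not dict_contains_extractors(rec_dict)  # safe to pass to check_json()
# --- end editor's example ---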
diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py
index 957d4f588e..2d7ac0ced2 100644
--- a/src/spikeinterface/preprocessing/motion.py
+++ b/src/spikeinterface/preprocessing/motion.py
@@ -331,7 +331,8 @@ def correct_motion(
     )
     (folder / "parameters.json").write_text(json.dumps(parameters, indent=4, cls=SIJsonEncoder), encoding="utf8")
     (folder / "run_times.json").write_text(json.dumps(run_times, indent=4), encoding="utf8")
-    recording.dump_to_json(folder / "recording.json")
+    if recording.check_if_json_serializable():
+        recording.dump_to_json(folder / "recording.json")
     np.save(folder / "peaks.npy", peaks)
     np.save(folder / "peak_locations.npy", peak_locations)
diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py
index 6e6ccc0358..ca26c2bc8a 100644
--- a/src/spikeinterface/sorters/runsorter.py
+++ b/src/spikeinterface/sorters/runsorter.py
@@ -385,8 +385,9 @@ def run_sorter_container(
     parent_folder = output_folder.parent.absolute().resolve()
     parent_folder.mkdir(parents=True, exist_ok=True)

-    # find input folder of recording for folder bind
+    # here we need recursive=True because we need a copy of the dict to be saved to JSON
     rec_dict = recording.to_dict(recursive=True)
+    # find input folder of recording for folder bind
    recording_input_folders = find_recording_folders(rec_dict)

     if platform.system() == "Windows":
diff --git a/src/spikeinterface/sorters/tests/test_launcher.py b/src/spikeinterface/sorters/tests/test_launcher.py
index cd8bc0fa5d..84d234e06b 100644
--- a/src/spikeinterface/sorters/tests/test_launcher.py
+++ b/src/spikeinterface/sorters/tests/test_launcher.py
@@ -245,9 +245,9 @@ def test_sorter_installation():
     # pass

     # test_run_sorters_with_list()
-    # test_run_sorter_by_property()
+    test_run_sorter_by_property()

-    test_run_sorters_with_dict()
+    # test_run_sorters_with_dict()

     # test_run_sorters_joblib()