"""Module bmtk.simulator.bionet.modules.record_netcons: records synapse/NetCon variables to a report."""

import os
import csv
import h5py
import numpy as np
from neuron import h

from .sim_module import SimulatorMod
from bmtk.simulator.bionet.biocell import BioCell
# from bmtk.simulator.bionet.io_tools import io
# from bmtk.simulator.bionet.pointprocesscell import PointProcessCell
from bmtk.utils.reports import CompartmentReport

# Both MPI-enabled and serial h5py builds map to the same report class.
# The former cell_vars.CellVarRecorderParallel / cell_vars.CellVarRecorder split survives only
# as history, so probing h5py.get_config() and swallowing any error it raised served no purpose.
MembraneRecorder = CompartmentReport

# NEURON parallel context used by this module, plus this process's rank and the total host count.
pc = h.ParallelContext()
MPI_RANK, N_HOSTS = int(pc.id()), int(pc.nhost())


class NetconReport(SimulatorMod):
    """Simulation module that records synapse (point-process) or NetCon variables to a report.

    During ``initialize`` every NetCon of each selected local cell is inspected. Those whose target
    synapse matches ``syn_type`` are registered with a ``MembraneRecorder`` report. On each ``step``
    the requested variables are read from the registered objects and handed to the recorder.
    """

    def __init__(self, tmp_dir, file_name, variable_name, cells, sections='all', syn_type='Exp2Syn',
                 buffer_data=True, transform=None):
        """Module used for saving synapse/NetCon properties at each step of the simulation.

        :param tmp_dir: directory used to resolve ``file_name`` when it is a relative path.
        :param file_name: name of the h5 file to save variables to (absolute, or relative to ``tmp_dir``).
        :param variable_name: a single variable name, or a list of names, read from each recorded object.
        :param cells: node-set selecting which cells to record from.
        :param sections: stored on the instance; not otherwise used by this module.
        :param syn_type: prefix matched against the target synapse's ``hname()``. ``None`` records
            every synapse; ``'netcon'`` records the NetCon objects themselves.
        :param buffer_data: accepted for interface compatibility; not used by this module.
        :param transform: accepted for interface compatibility; not used by this module.
        """
        if isinstance(variable_name, str):
            # list('tau1') would split a single variable name into characters.
            variable_name = [variable_name]
        self._all_variables = list(variable_name)
        self._variables = list(variable_name)
        self._tmp_dir = tmp_dir
        self._file_name = file_name if os.path.isabs(file_name) else os.path.join(tmp_dir, file_name)
        self._all_gids = cells
        self._local_gids = []
        self._sections = sections

        self._var_recorder = None  # created in initialize(), once simulation timing is known

        self._virt_lookup = {}
        self._gid_lookup = {}  # gid -> (population name, node_id) for every non-virtual node
        self._sec_lookup = {}  # gid -> {section: section index} for local BioCells

        self._gid_list = []  # list of all gids that will have their variables saved
        self._data_block = {}  # table of variable data indexed by [gid][variable]
        self._block_step = 0  # incremented on every step(); not reset by this module

        self._object_lookup = {}  # gid -> objects (synapses or NetCons) recorded for that cell
        self._syn_type = syn_type
        self._gid_map = None  # sim.net.gid_pool, set in initialize()

    def _get_gids(self, sim):
        """Store the gids that are both selected by ``cells`` and simulated on this rank."""
        selected_gids = set(sim.net.get_node_set(self._all_gids).gids())
        self._local_gids = list(set(sim.local_gids) & selected_gids)

    def _save_sim_data(self, sim):
        """Copy the simulation's start time, stop time and time step onto the recorder."""
        self._var_recorder.tstart = 0.0
        self._var_recorder.tstop = sim.tstop
        self._var_recorder.dt = sim.dt

    def _get_syn_location(self, nc, cell):
        """Return ``(section index, position along section)`` of the NetCon's target.

        Returns ``(-1, -1)`` for non-BioCell cells and for targets without a location.
        """
        if not isinstance(cell, BioCell):
            return -1, -1
        sec_x = nc.postloc()
        if sec_x < 0:
            # postloc() pushes the target's section only when the target has a location, so
            # there is nothing to read with h.cas() and nothing to pop.
            return -1, -1
        try:
            sec = h.cas()
            sec_id = self._sec_lookup[cell.gid][sec]
        finally:
            # Always undo the push made by postloc(), even if the lookup fails.
            h.pop_section()
        return sec_id, sec_x

    def _build_lookups(self, sim):
        """Fill ``_gid_lookup`` for all non-virtual nodes and ``_sec_lookup`` for local BioCells."""
        for node_pop in sim.net.node_populations:
            pop_name = node_pop.name
            for node in node_pop[0::1]:
                if node.model_type != 'virtual':
                    self._gid_lookup[node.gid] = (pop_name, node.node_id)

        for gid, cell in sim.net.get_local_cells().items():
            if isinstance(cell, BioCell):
                # _get_syn_location() maps h.cas() back to a section index through this table.
                self._sec_lookup[gid] = {sec: sec_id for sec_id, sec in enumerate(cell.get_sections_id())}

    def _collect_syn_objects(self, cell):
        """Return ``(section ids, positions, source gids, objects)`` for the NetCons of ``cell`` to record.

        All four lists are appended together, so the report's element columns always match the
        number of values step() records for the cell.
        """
        sec_list = []
        seg_list = []
        src_list = []
        syn_objects = []
        for nc in cell.netcons:
            synapse = nc.syn()
            if self._syn_type is None or synapse.hname().startswith(self._syn_type):
                recorded_obj = synapse
            elif self._syn_type == 'netcon':
                recorded_obj = nc
            else:
                continue
            sec_id, seg_x = self._get_syn_location(nc, cell)
            sec_list.append(sec_id)
            seg_list.append(seg_x)
            src_list.append(int(nc.srcgid()))
            syn_objects.append(recorded_obj)
        return sec_list, seg_list, src_list, syn_objects

    def initialize(self, sim):
        """Create the report and register every matching synapse/NetCon of the selected local cells."""
        self._gid_map = sim.net.gid_pool
        self._get_gids(sim)
        # NOTE(review): only the first variable is given to the report, yet step() records every
        # entry of self._variables into it; presumably single-variable reports are intended — confirm.
        self._var_recorder = MembraneRecorder(self._file_name, mode='w', variable=self._variables[0],
                                              buffer_size=sim.nsteps_block, tstart=0.0, tstop=sim.tstop,
                                              dt=sim.dt, n_steps=sim.n_steps)
        self._build_lookups(sim)

        for gid in self._local_gids:
            pop_id = self._gid_map.get_pool_id(gid)
            cell = sim.net.get_cell_gid(gid)
            sec_list, seg_list, src_list, syn_objects = self._collect_syn_objects(cell)
            if syn_objects:
                self._var_recorder.add_cell(node_id=pop_id.node_id, population=pop_id.population,
                                            element_ids=sec_list, element_pos=seg_list,
                                            src_ids=src_list, trg_ids=[gid] * len(src_list))
                self._object_lookup[gid] = syn_objects

        self._var_recorder.initialize()

    def step(self, sim, tstep):
        """Read each requested variable from every registered object and pass the values to the recorder."""
        for gid, netcon_objs in self._object_lookup.items():
            pop_id = self._gid_map.get_pool_id(gid)
            for var_name in self._variables:
                syn_values = [getattr(syn, var_name) for syn in netcon_objs]
                if syn_values:
                    self._var_recorder.record_cell(pop_id.node_id, population=pop_id.population,
                                                   vals=syn_values, tstep=tstep)

        self._block_step += 1

    def block(self, sim, block_interval):
        """Flush values buffered by the recorder to the report file."""
        self._var_recorder.flush()

    def finalize(self, sim):
        """Synchronize all ranks with a ParallelContext barrier, then close the report."""
        # TODO: Build in mpi signaling into var_recorder (and merge per-rank output if needed).
        pc.barrier()
        self._var_recorder.close()