Source code for sfepy.base.ioutils

import numpy as nm
import os
import os.path as op
import fnmatch
import shutil
from base import output, Struct, basestr
try:
    import tables as pt
except:
    pt = None

[docs]class InDir(Struct): """ Store the directory name a file is in, and prepend this name to other files. Examples -------- >>> indir = InDir('output/file1') >>> print indir('file2') """ def __init__(self, filename): self.dir = op.split(op.join(os.getcwd(), filename))[0] def __call__(self, filename): return op.join(self.dir, filename)
[docs]def ensure_path(filename): """ Check if path to `filename` exists and if not, create the necessary intermediate directories. """ dirname = os.path.dirname(filename) if dirname: if not os.path.exists(dirname): os.makedirs(dirname) if not os.path.isdir(dirname): raise IOError('cannot ensure path for "%s"!' % filename)
[docs]def locate_files(pattern, root_dir=os.curdir): """ Locate all files matching fiven filename pattern in and below supplied root directory. """ for dirpath, dirnames, filenames in os.walk(os.path.abspath(root_dir)): for filename in fnmatch.filter(filenames, pattern): yield os.path.join(dirpath, filename)
[docs]def remove_files(root_dir): """ Remove all files and directories in supplied root directory. """ for dirpath, dirnames, filenames in os.walk(os.path.abspath(root_dir)): for filename in filenames: os.remove(os.path.join(root_dir, filename)) for dirname in dirnames: shutil.rmtree(os.path.join(root_dir, dirname))
## # 27.04.2006, c
[docs]def get_trunk(filename): return op.splitext(op.basename(filename))[0]
[docs]def edit_filename(filename, prefix='', suffix='', new_ext=None): """ Edit a file name by add a prefix, inserting a suffix in front of a file name extension or replacing the extension. Parameters ---------- filename : str The file name. prefix : str The prefix to be added. suffix : str The suffix to be inserted. new_ext : str, optional If not None, it replaces the original file name extension. Returns ------- new_filename : str The new file name. """ path, filename = os.path.split(filename) base, ext = os.path.splitext(filename) if new_ext is None: new_filename = prefix + base + suffix + ext else: new_filename = prefix + base + suffix + new_ext return os.path.join(path, new_filename)
[docs]def get_print_info(n_step, fill=None): """ Returns the max. number of digits in range(n_step) and the corresponding format string. Examples: >>> get_print_info(11) (2, '%2d') >>> get_print_info(8) (1, '%1d') >>> get_print_info(100) (2, '%2d') >>> get_print_info(101) (3, '%3d') >>> get_print_info(101, fill='0') (3, '%03d') """ if n_step > 1: n_digit = int(nm.log10(n_step - 1) + 1) if fill is None: format = '%%%dd' % n_digit else: format = '%%%s%dd' % (fill, n_digit) else: n_digit, format = 0, None return n_digit, format
[docs]def skip_read_line(fd, no_eof=False): """ Read the first non-empty line (if any) from the given file object. Return an empty string at EOF, if `no_eof` is False. If it is True, raise the EOFError instead. """ ls = '' while 1: try: line = fd.readline() except EOFError: break if not line: if no_eof: raise EOFError else: break ls = line.strip() if ls and (ls[0] != '#'): break return ls
[docs]def read_token(fd): """ Read a single token (sequence of non-whitespace characters) from the given file object. Notes ----- Consumes the first whitespace character after the token. """ out = '' # Skip initial whitespace. while 1: ch = fd.read(1) if ch.isspace(): continue elif len(ch) == 0: return out else: break while not ch.isspace(): out = out + ch ch = fd.read(1) if len(ch) == 0: break return out
[docs]def read_array(fd, n_row, n_col, dtype): """ Read a NumPy array of shape `(n_row, n_col)` from the given file object and cast it to type `dtype`. If `n_col` is None, determine the number of columns automatically. """ if n_col is None: idx = fd.tell() row = fd.readline().split() fd.seek(idx) n_col = len(row) count = n_row * n_col val = nm.fromfile(fd, sep=' ', count=count) if val.shape[0] < count: raise ValueError('(%d, %d) array reading failed!' % (n_row, n_col)) val = nm.asarray(val, dtype=dtype) val.shape = (n_row, n_col) return val
## # c: 05.02.2008, r: 05.02.2008
[docs]def read_list(fd, n_item, dtype): vals = [] ii = 0 while ii < n_item: line = [dtype(ic) for ic in fd.readline().split()] vals.append(line) ii += len(line) if ii > n_item: output('corrupted row?', line, ii, n_item) raise ValueError return vals
[docs]def write_dict_hdf5(filename, adict, level=0, group=None, fd=None): if level == 0: fd = pt.openFile(filename, mode='w', title='Recursive dict dump') group = '/' for key, val in adict.iteritems(): if isinstance(val, dict): group2 = fd.createGroup(group, '_' + str(key), '%s group' % key) write_dict_hdf5(filename, val, level + 1, group2, fd) else: fd.createArray(group, '_' + str(key), val, '%s data' % key) if level == 0: fd.close()
[docs]def read_dict_hdf5(filename, level=0, group=None, fd=None): out = {} if level == 0: fd = pt.openFile(filename, mode='r') group = fd.root for name, gr in group._v_groups.iteritems(): name = name.replace('_', '', 1) out[name] = read_dict_hdf5(filename, level + 1, gr, fd) for name, data in group._v_leaves.iteritems(): name = name.replace('_', '', 1) out[name] = data.read() if level == 0: fd.close() return out
## # 02.07.2007, c
[docs]def write_sparse_matrix_hdf5(filename, mtx, name='a sparse matrix'): """Assume CSR/CSC.""" fd = pt.openFile(filename, mode='w', title=name) try: info = fd.createGroup('/', 'info') fd.createArray(info, 'dtype', mtx.dtype.str) fd.createArray(info, 'shape', mtx.shape) fd.createArray(info, 'format', mtx.format) data = fd.createGroup('/', 'data') fd.createArray(data, 'data', mtx.data) fd.createArray(data, 'indptr', mtx.indptr) fd.createArray(data, 'indices', mtx.indices) except: print 'matrix must be in SciPy sparse CSR/CSC format!' print mtx.__repr__() raise fd.close()
## # 02.07.2007, c # 08.10.2007
[docs]def read_sparse_matrix_hdf5(filename, output_format=None): import scipy.sparse as sp constructors = {'csr' : sp.csr_matrix, 'csc' : sp.csc_matrix} fd = pt.openFile(filename, mode='r') info = fd.root.info data = fd.root.data format = info.format.read() if not isinstance(format, basestr): format = format[0] dtype = info.dtype.read() if not isinstance(dtype, basestr): dtype = dtype[0] if output_format is None: constructor = constructors[format] else: constructor = constructors[output_format] if format in ['csc', 'csr']: mtx = constructor((data.data.read(), data.indices.read(), data.indptr.read()), shape=info.shape.read(), dtype=dtype) elif format == 'coo': mtx = constructor((data.data.read(), nm.c_[data.rows.read(), data.cols.read()].T), shape=info.shape.read(), dtype=dtype) else: print format raise ValueError fd.close() if output_format in ['csc', 'csr']: mtx.sort_indices() return mtx