Source code for gridfix.regionset

#!/usr/bin/python3
# -*- coding: utf-8 -*-

import os

import numpy as np
import matplotlib as mp
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

from PIL import Image

from fractions import Fraction
from pandas import DataFrame, read_table

from .model import ImageSet


[docs]class RegionSet(object): """ Base class for sets of image regions of interest. RegionSets can be used to spatially group fixations, create Feature objects for a FixationModel and split an image into parts. Classes inheriting from RegionSet may specify functions to create regions. Attributes: info (DataFrame): table of region metadata (labels, bboxes, number of pixels...) imageids (list): list of all imageids associated with this RegionSet is_global (bool): True if regions are global (non-image-specific) label (str): optional label to distinguish between RegionSets. memory_usage (float): memory usage of all binary masks (kiB) size (tuple): image dimensions, specified as (width, height). """ def __init__(self, size, regions, region_labels=None, label=None): """ Create a new RegionSet from existing region masks. Args: size (tuple): image dimensions, specified as (width, height) regions: 3d ndarray (bool) with global set of masks OR dict of multiple such ndarrays, with imageids as keys region_labels: list of region labels IF _regions_ is a single array OR dict of such lists, with imageids as keys label (str): optional descriptive label for this RegionSet Raises: ValueError if incorrectly formatted regions/region_labels provided """ self._regions = {'*': np.ndarray((0,0,0))} self._labels = {'*': []} self.size = size self.label = label self._msize = (size[1], size[0]) # matrix convention if isinstance(regions, dict): # Dict with image-specific region ndarrays self._regions = regions if region_labels is not None and isinstance(region_labels, dict) and len(regions) == len(region_labels): # Check imageids for consistency for r in regions.keys(): if r not in region_labels.keys(): raise ValueError('Labels not consistent: {:s} not in region_labels'.format(r)) for r in region_labels.keys(): if r not in regions.keys(): raise ValueError('Labels not consistent: {:s} not in regions'.format(r)) self._labels = region_labels else: self._labels = {} for imid in regions: self._labels[imid] = [str(x+1) for x in range(len(regions[imid]))] elif isinstance(regions, np.ndarray): # Single array of regions - assume global region set ('*') if regions.shape[1:] == self._msize: self._regions['*'] = regions.astype(bool) if region_labels is not None and len(region_labels) == regions.shape[0]: self._labels['*'] = region_labels else: self._labels['*'] = [str(x+1) for x in range(regions.shape[0])] else: raise ValueError('First argument for RegionSet creation must be ndarray ' + '(global regions) or dict of ndarrays (image-specific regions)!') self.info = self._region_metadata() def __repr__(self): """ String representation """ r = 'gridfix.RegionSet(label={:s}, size=({:d}, {:d}),\nregions={:s},\nregion_labels={:s})' return r.format(str(self.label), self.size[0], self.size[1], str(self._regions), str(self._labels)) def __str__(self): """ Short string representation for printing """ r = '<{:s}{:s}, size={:s}, {:d} region{:s}{:s}, memory={:.1f} kB>' myclass = str(self.__class__.__name__) if self.label is not None: lab = ' ({:s})'.format(self.label) else: lab = '' num_s = '' num_r = len(self) if num_r > 1: num_s = 's' imid_s = '' if len(self._regions) > 1 and not self.is_global: imid_s = ' in {:d} images'.format(len(self._regions)) return r.format(myclass, lab, str(self.size), num_r, num_s, imid_s, self.memory_usage) def __len__(self): """ Overload len(RegionSet) to report total number of regions. """ if self.is_global: return len(self._regions['*']) else: num_r = 0 for imid in self._regions: num_r += len(self._regions[imid]) return num_r def __getitem__(self, imageid): """ Bracket indexing returns all region masks for a specified imageid. If global regions are set ('*'), always return global region set. """ return self._select_region(imageid) def _region_metadata(self): """ Return DataFrame of region metadata """ info_cols = ['imageid', 'regionid', 'regionno', 'left', 'top', 'right', 'bottom', 'width', 'height', 'area', 'imgfrac'] info = [] if self.is_global: imageids = ['*'] else: imageids = self.imageids for imid in imageids: reg = self._select_region(imid) lab = self._select_labels(imid) for i,l in enumerate(lab): a = np.argwhere(reg[i]) (top, left) = a.min(0)[0:2] (bottom, right) = a.max(0)[0:2] area = reg[i][reg[i] > 0].sum() imgfrac = round(area / (reg[i].shape[0] * reg[i].shape[1]), 4) rmeta = [imid, l, i+1, left, top, right, bottom, right-left+1, bottom-top+1, area, imgfrac] info.append(rmeta) return DataFrame(info, columns=info_cols) def _select_region(self, imageid=None): """ Select region by imageid with consistency check """ if self.is_global: return(self._regions['*']) if imageid is not None and imageid in self._regions.keys(): return(self._regions[imageid]) else: raise ValueError('RegionSet contains image-specific regions, but no valid imageid was specified!') def _select_labels(self, imageid=None): """ Select region labels corresponding to _select_region """ if self.is_global: return(self._labels['*']) if imageid is not None and imageid in self._regions.keys(): return(self._labels[imageid]) else: raise ValueError('RegionSet contains image-specific regions, but no valid imageid was specified!') @property def is_global(self): """ Return True if a global map is defined (key '*') """ if '*' in self._regions.keys(): return True else: return False @property def imageids(self): """ Return list of imageids for which region maps exist """ if self.is_global: return [] imids = [] for imid in self._regions.keys(): imids.append(imid) return imids @property def memory_usage(self): """ Calculate size in memory of all regions combined """ msize = 0.0 for reg in self._regions.keys(): msize += float(self._regions[reg].nbytes) / 1024.0 return msize
[docs] def count_map(self, imageid=None): """ Return the number of regions referencing each pixel. Args: imageid (str): if set, return map for specified image only Returns: 2d ndarray of image size, counting number of regions for each pixel """ cm = np.zeros(self._msize, dtype=int) if self.is_global: for re in self._regions['*'][:, ...]: cm += re.astype(int) return cm elif imageid is None: for imid in self._regions: if imid == '*': continue for re in self._regions[imid][:, ...]: cm += re.astype(int) else: r = self._select_region(imageid) for re in r[:, ...]: cm += re.astype(int) return cm
[docs] def mask(self, imageid=None): """ Return union mask of all regions or regions for specified image. Args: imageid (str): if set, return mask for specified image only Returns: 2d ndarray of image size (bool), True where at least one region references the corresponding pixel. """ return self.count_map(imageid).astype(bool)
[docs] def region_map(self, imageid=None): """ Return map of region numbers, global or image-specifid. Args: imageid (str): if set, return map for specified image only Returns: 2d ndarray (int), containing the number (ascending) of the last region referencing the corresponding pixel. """ apply_regions = self._select_region(imageid) tmpmap = np.zeros(self._msize) for idx, region in enumerate(apply_regions): tmpmap[region] = (idx + 1) return tmpmap
[docs] def coverage(self, imageid=None, normalize=False): """ Calculates coverage of the total image size as a scalar. Args: imageid (str): if set, return coverage for specified image only normalize (bool): if True, divide global result by number of imageids in set. Returns: Total coverage as a floating point number. """ if imageid is not None: counts = self.count_map(imageid) cov = float(counts.sum()) / float(self.size[0] * self.size[1]) return cov else: # Global coverage for all imageids cm = np.zeros(self._msize, dtype=int) for re in self._regions.keys(): if re == '*': cm += self.count_map('*') break cm += self.count_map(re) cov = float(cm.sum()) / float(self.size[0] * self.size[1]) if normalize: cov = cov / len(self) return cov
[docs] def plot(self, imageid=None, values=None, cmap=None, image_only=False, ax=None, alpha=1.0): """ Plot regions as map of shaded areas with/without corresponding feature values Args: imageid (str): if set, plot regions for specified image values (array-like): one feature value per region cmap (str): name of matplotlib colormap to use image_only (boolean): if True, return only image content without axes ax (Axes): axes object to draw to, to include result in other figure alpha (float): opacity of plotted regions (set < 1 to visualize overlap) Returns: matplotlib figure object, or None if passed an axis to draw on """ apply_regions = self._select_region(imageid) tmpmap = np.zeros(self._msize) if ax is not None: ax1 = ax else: fig = plt.figure() ax1 = fig.add_subplot(1,1,1) if cmap is None: if values is None and 'viridis' in plt.colormaps(): cmap = 'viridis' else: cmap = 'gray' if alpha < 1.0: # allow stacking by setting masked values transparent alpha_cmap = plt.get_cmap(cmap) alpha_cmap.set_bad(alpha=0) ax1.imshow(tmpmap, cmap=plt.get_cmap('gray'), interpolation='none') for idx, region in enumerate(apply_regions): rmap = np.zeros(self._msize) if values is not None and len(values) == apply_regions.shape[0]: rmap[region] = values[idx] ax1.imshow(np.ma.masked_equal(rmap, 0), cmap=alpha_cmap, interpolation='none', alpha=alpha, vmin=0, vmax=np.nanmax(values)) else: rmap[region] = idx + 1 ax1.imshow(np.ma.masked_equal(rmap, 0), cmap=alpha_cmap, interpolation='none', alpha=alpha, vmin=0, vmax=apply_regions.shape[0]) else: # If no alpha requested, this is much faster but doesn't show overlap ax1.imshow(tmpmap, cmap=plt.get_cmap('gray'), interpolation='none') if values is not None and len(values) == apply_regions.shape[0]: rmap = np.zeros(self._msize) for idx, region in enumerate(apply_regions): rmap[region] = values[idx] ax1.imshow(np.ma.masked_equal(rmap, 0), cmap=plt.get_cmap(cmap), interpolation='none', vmin=0, vmax=np.nanmax(values)) else: ax1.imshow(np.ma.masked_equal(self.region_map(imageid), 0), cmap=plt.get_cmap(cmap), interpolation='none', vmin=0, vmax=apply_regions.shape[0]) if image_only: ax1.axis('off') else: t = '{:s}'.format(self.__class__.__name__) if self.label is not None: t += ' "{:s}"'.format(self.label) if imageid is not None: t += ': {:s}'.format(imageid) ax1.set_title(t) if ax is None and not plt.isinteractive(): # see ImageSet.plot() return fig
[docs] def plot_regions_on_image(self, imageid=None, imageset=None, cmap=None, fill=False, alpha=0.4, labels=False, image_only=False, ax=None): """ Plot region bounding boxes on corresponding image Args: imageid (str): if set, plot regions for specified image imageset (ImageSet): ImageSet object containing background image/map cmap (str): name of matplotlib colormap to use for boundin boxes fill (boolean): draw shaded filled rectangles instead of boxes alpha (float): rectangle opacity (only when fill=True) labels (boolean): if True, draw text labels next to regions image_only (boolean): if True, return only image content without axes ax (Axes): axes object to draw to, to include result in other figure Returns: matplotlib figure object, or None if passed an axis to draw on """ if imageset is None or imageid not in imageset.imageids: raise ValueError('To plot regions on top of image, specify ImageSet containing corresponding background image!') if ax is not None: ax1 = ax else: fig = plt.figure() ax1 = fig.add_subplot(1,1,1) ax1.imshow(imageset[imageid], cmap=plt.get_cmap('gray'), interpolation='none') if cmap is None: if 'viridis' in plt.colormaps(): cmap = 'viridis' else: cmap = 'hsv' boxcolors = plt.get_cmap(cmap) cstep = 0 if self.is_global: rmeta = self.info[self.info.imageid == '*'] else: rmeta = self.info[self.info.imageid == imageid] for idx, region in rmeta.iterrows(): c = boxcolors(cstep/len(rmeta)) cstep += 1 if not fill: ax1.add_patch(Rectangle((region.left, region.top), region.width, region.height, color=c, fill=False, linewidth=2)) else: ax1.add_patch(Rectangle((region.left, region.top), region.width, region.height, color=c, linewidth=0, alpha=0.7)) if labels: # Draw text labels with sensible default positions if region.right > (self.size[0] * .95): tx = region.right ha = 'right' else: tx = region.left ha = 'left' if region.bottom > (self.size[1] * .95): ty = region.top - 5 else: ty = region.bottom + 20 ax1.text(tx, ty, region.regionid, horizontalalignment=ha) if image_only: ax1.axis('off') else: t = '{:s}'.format(self.__class__.__name__) if self.label is not None: t += ' "{:s}"'.format(self.label) t += ': {:s}'.format(imageid) ax1.set_title(t) if ax is None and not plt.isinteractive(): # see ImageSet.plot() return fig
[docs] def apply(self, image, imageid=None, crop=False): """ Apply this RegionSet to a specified image. Returns a list of the image arrays "cut out" by each region mask, with non-selected image areas in black. If regionset is not global, _imageid_ needs to be specified! Args: image (ndarray): image array to be segmented. imageid (str): valid imageid (to select image-specific regions if not a global regionset) crop (bool): if True, return image cropped to bounding box of selected area Returns: If crop=False, a list of ndarrays of same size as image, with non-selected areas zeroed. Else a list of image patch arrays cropped to bounding box size. """ slices = [] apply_regions = self._select_region(imageid) for region in apply_regions: mask = (region == True) out = np.zeros(image.shape) out[mask] = image[mask] if crop: a = np.argwhere(out) (ul_x, ul_y) = a.min(0)[0:2] (br_x, br_y) = a.max(0)[0:2] out = out[ul_x:br_x+1, ul_y:br_y+1] slices.append(out) return slices
[docs] def export_patches(self, image, imageid=None, crop=True, image_format='png', rescale=False): """ Apply this RegionSet to an image array and save the resulting image patches as files. Saves an image of each image part "cut out" by each region mask, cropped by default. If the RegionSet is not global, imageid needs to be specified! Args: image (ndarray): image array to be segmented. imageid (str): imageid (to select image-specific regions if not a global regionset) crop (bool): if True, return image cropped to bounding box of selected area image_format (str): image format that PIL understands (will also be used for extension) rescale (bool): if True, scale pixel values to full 0..255 range before saving (e.g., for saliency maps) """ apply_regions = self._select_region(imageid) apply_labels = self._select_labels(imageid) imstr = '{:s}_{:s}.{:s}' for idx, region in enumerate(apply_regions): mask = (region == True) out = np.zeros(image.shape) out[mask] = image[mask] if crop: a = np.argwhere(out) (ul_x, ul_y) = a.min(0)[0:2] (br_x, br_y) = a.max(0)[0:2] out = out[ul_x:br_x+1, ul_y:br_y+1] if imageid is None or imageid == '*': imageid = 'image' if rescale: out = (out - out.min()) / out.max() * 255.0 else: out *= 255.0 rimg = Image.fromarray(np.array(out, np.uint8)) rimg.save(imstr.format(str(imageid), str(apply_labels[idx]), image_format), image_format)
[docs] def export_patches_from_set(self, imageset, crop=True, image_format='png', rescale=False): """ Save all sliced image patches from an ImageSet as image files. Saves an image of each image part "cut out" by each region mask, cropped by default. If the RegionSet is not global, only images with valid region masks will be processed. Args: imageset (ImageSet): a valid ImageSet containing images to slice imageid (str): imageid (to select image-specific regions if not a global regionset) crop (bool): if True, return image cropped to bounding box of selected area image_format (str): image format that PIL understands (will also be used for extension) rescale (bool): if True, scale pixel values to full 0..255 range before saving (e.g., for saliency maps) """ if not isinstance(imageset, ImageSet): raise TypeError('First argument must be an ImageSet! To slice a single image, use export_patches().') for cimg in imageset.imageids: if not self.is_global and cimg not in self.imageids: print('Warning: RegionSet contains image-specific regions, but no regions available for {:s}. Skipped.'.format(cimg)) else: self.export_patches(imageset[cimg], imageid=cimg, crop=crop, image_format=image_format, rescale=rescale)
[docs] def fixated(self, fixations, imageid=None, count=False, exclude_first=False): """ Returns visited / fixated regions using data from a Fixations object. Args: fixations (Fixations/DataFrame): fixation data to test against regions imageid (str): imageid (to select image-specific regions if not a global regionset) count (bool): if True, return number of fixations per region instead of boolean values exclude_first (bool): if True, first fixated region will always be returned as NaN Returns: 1D ndarray (float) containing number of fixations per region (if count=True) or the values 0.0 (region was not fixated) or 1.0 (region was fixated) """ apply_regions = self._select_region(imageid) vis = np.zeros(apply_regions.shape[0], dtype=float) # Drop out-of-bounds fixations fix = fixations.data[(fixations.data[fixations._xpx] >= 0) & (fixations.data[fixations._xpx] < self.size[0]) & (fixations.data[fixations._ypx] >= 0) & (fixations.data[fixations._ypx] < self.size[1])] if len(fix) > 0: first_fix = fixations.data[fixations.data[fixations._fixid] == min(fixations.data[fixations._fixid])] if len(first_fix) > 1 and exclude_first: print('Warning: you have requested to drop the first fixated region, but more than one ' + 'location ({:d}) matches the lowest fixation ID! Either your fixation ' .format(len(first_fix)) + 'IDs are not unique or the passed dataset contains data from multiple images or conditions.') for (idx, roi) in enumerate(apply_regions): fv = roi[fix[fixations._ypx], fix[fixations._xpx]] if isinstance(fv, np.ndarray): num_fix = sum(fv) vis[idx] = num_fix if exclude_first: try: is_first = roi[first_fix[fixations._ypx], first_fix[fixations._xpx]] if isinstance(is_first, np.ndarray) and np.any(is_first): vis[idx] = np.nan elif is_first: vis[idx] = np.nan except IndexError: pass # first fixation is out of bounds for image! if not count: vis[vis >= 1.0] = 1.0 vis[vis < 1.0] = 0.0 return vis
[docs]class GridRegionSet(RegionSet): """ RegionSet defining an n-by-m regular grid covering the full image size. Attributes: cells (list): list of bounding box tuples for each cell, each formatted as (left, top, right, bottom) gridsize (tuple): grid dimensions as (width, height). If unspecified, gridfix will try to choose a sensible default. label (string): optional label to distinguish between RegionSets """ def __init__(self, size, gridsize=None, label=None, region_labels=None): """ Create a new grid RegionSet Args: size (tuple): image dimensions, specified as (width, height). gridsize(tuple): grid dimensions, specified as (width, height). region_labels (string): list of optional region labels (default: cell#) """ if gridsize is None: gridsize = self._suggest_grid(size) print('Note: no grid size was specified. Using {:d}x{:d} based on image size.'.format(gridsize[0], gridsize[1])) (regions, cells) = self._grid(size, gridsize) RegionSet.__init__(self, size=size, regions=regions, label=label, region_labels=region_labels) self.gridsize = gridsize # List of region bboxes self.cells = cells def __str__(self): """ Short string representation for printing """ r = '<gridfix.GridRegionSet{:s}, size={:s}, {:d}x{:d} grid, {:d} cell{:s}, memory={:.1f} kB>' if self.label is not None: lab = ' ({:s})'.format(self.label) else: lab = '' num_s = '' num_r = len(self) if num_r > 1: num_s = 's' return r.format(lab, str(self.size), self.gridsize[0], self.gridsize[1], num_r, num_s, self.memory_usage) def _suggest_grid(self, size): """ Suggest grid dimensions based on image size. Args: size (tuple): image dimensions, specified as (width, height). Returns: Suggested grid size tuple as (width, height). """ aspect = Fraction(size[0], size[1]) s_width = aspect.numerator s_height = aspect.denominator if s_width < 6: s_width *= 2 s_height *= 2 return (s_width, s_height) def _grid(self, size, gridsize): """ Build m-by-n (width,height) grid as 3D nparray. Args: size (tuple): image dimensions, specified as (width, height). gridsize(tuple): grid dimensions, specified as (width, height). Returns: tuple containing the grid regions and their bounding box coordinates as (grid, cells): grid (numpy.ndarray): regions for RegionSet creation cells (list): list of bounding box tuples for each cell, each formatted as (left, top, right, bottom) """ (width, height) = size _msize = (size[1], size[0]) cell_x = int(width / gridsize[0]) cell_y = int(height / gridsize[1]) n_cells = int(gridsize[0] * gridsize[1]) grid = np.zeros((n_cells,) + _msize, dtype=bool) cells = [] # Sanity check: do nothing if image dimensions not cleanly divisible by grid if width % gridsize[0] > 0 or height % gridsize[1] > 0: e = 'Error: image dimensions not cleanly divisible by grid! image=({:d}x{:d}), grid=({:d}x{:d})' raise ValueError(e.format(width, height, gridsize[0], gridsize[1])) # Create a mask of 1s/True for each cell cellno = 0 for y_es in range(0, height, cell_y): for x_es in range(0, width, cell_x): mask = np.zeros(_msize, dtype=bool) mask[y_es:y_es + cell_y, x_es:x_es + cell_x] = True grid[cellno,...] = mask cells.append((x_es, y_es, x_es + cell_x, y_es + cell_y)) cellno += 1 return (grid, cells)