display_lib.py 13.2 KB
#! /bin/env python
# coding: utf8

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

class display1d(object):
    """
    Grid of 1D plots with helpers for common plotting options.

    Initialize as ``display1d(3, 2)`` for 3 columns and 2 rows. Select the
    subplot to act on with ``set_actualPlotNumber(i)``. The index starts at 0
    (the constructor selects plot 0) and runs left to right, then top to
    bottom. Call ``set_data`` on the selected subplot before the other
    ``set_*`` plotting methods, since they read its axes and data.

    Available options:
      * custom abscissa (``set_abscisse``)
      * error bars along x and/or y (``set_errorbar``)
      * colour-scaled scatter plot with a colorbar (``set_colorscale``)
      * line plot, secondary line plots and histogram

    NOTE: not made for time series!
    """

    def __init__(self, ncols=1, nrows=1):
        self.fig = plt.figure(figsize=(16, 10), dpi=96)
        self.ncols = int(ncols)
        self.nrows = int(nrows)
        self.set_actualPlotNumber(0)
        # Single-argument print() gives the same output on Python 2 and 3.
        print('initializing  %s independant plots' % (ncols * nrows,))
        # One state dict per subplot; at least one so that plot 0 always exists.
        nplots = max(1, self.nrows * self.ncols)
        self.plots = [{'label': self._empty_labels()} for _ in range(nplots)]

    @staticmethod
    def _empty_labels():
        """Return a fresh label dict with every entry unset."""
        return {'data': None, 'xaxis': None, 'yaxis': None, 'cbar': None, 'title': None}

    @staticmethod
    def _match_shape(values, shape):
        """Return ``values`` as an array, or zeros of ``shape`` if its shape differs."""
        values = np.array(values)
        return values if values.shape == shape else np.zeros(shape)

    @staticmethod
    def _is_auto(minmax):
        """Return True when ``minmax`` is the string 'auto'."""
        # Comparing a numpy array with a string is elementwise, so only
        # scalars are compared; array-like limits are never 'auto'.
        return np.ndim(minmax) == 0 and minmax == 'auto'

    def set_data(self, y):
        """Create the axes of the selected subplot and attach the data ``y``.

        Stored axis labels and title are applied. The abscissa is reset to
        the default ``0..len(y)-1``.
        """
        i = self.get_actualPlotNumber()
        plot = self.plots[i]
        # (nrows, ncols, index) works for any grid size; the 3-digit string
        # form ('231') breaks as soon as there are more than 9 subplots.
        plot['ax'] = self.fig.add_subplot(self.nrows, self.ncols, i + 1)
        plot['data'] = np.array(y)
        labels = plot['label']
        plot['ax'].set_xlabel(labels['xaxis'])
        plot['ax'].set_ylabel(labels['yaxis'])
        if labels['title'] is not None:
            plot['ax'].set_title(labels['title'])
        self.set_abscisse()

    def set_actualPlotNumber(self, i):
        """Select subplot ``i`` (0-based) for the following calls."""
        self._actualPlotNumber = int(i)

    def get_actualPlotNumber(self):
        """Return the index of the currently selected subplot."""
        return self._actualPlotNumber

    def set_abscisse(self, x=(0,)):
        """Set the abscissa of the selected subplot.

        If ``x`` does not have one entry per data point, the default
        ``0..len(data)-1`` is used instead and a message is printed.
        """
        i = self.get_actualPlotNumber()
        data = self.plots[i]['data']
        # atleast_1d lets a scalar fall back to the default instead of
        # raising TypeError on len().
        x = np.atleast_1d(np.array(x))
        if len(x) != len(data):
            print('setting default abscisse for plot  %d' % i)
            self.plots[i]['x'] = np.arange(data.shape[0])
        else:
            self.plots[i]['x'] = x

    def set_label(self, data=None, xaxis=None, yaxis=None, cbar=None, title=None):
        """Store the non-empty labels for the selected subplot.

        ``data`` labels the legend entry and ``cbar`` the colorbar; both are
        read by the plotting methods called afterwards. Axis labels and title
        are also applied right away if the subplot's axes already exist.
        """
        plot = self.plots[self.get_actualPlotNumber()]
        new_labels = {'data': data, 'xaxis': xaxis, 'yaxis': yaxis, 'cbar': cbar, 'title': title}
        for key, value in new_labels.items():
            # Falsy values (None, '') leave the current label untouched.
            if value:
                plot['label'][key] = value
        ax = plot.get('ax')
        if ax is not None:
            if xaxis:
                ax.set_xlabel(xaxis)
            if yaxis:
                ax.set_ylabel(yaxis)
            if title:
                ax.set_title(title)

    def set_limits(self, xlim=(None, None), ylim=(None, None)):
        """Pass ``xlim``/``ylim`` to the selected axes' ``set_xlim``/``set_ylim``."""
        ax = self.plots[self.get_actualPlotNumber()]['ax']
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

    def set_logscale(self, x=False, y=False):
        """Switch the x and/or y axis of the selected subplot to log scale."""
        ax = self.plots[self.get_actualPlotNumber()]['ax']
        if x:
            ax.set_xscale('log')
        if y:
            ax.set_yscale('log')

    def set_errorbar(self, xerr=(0,), yerr=(0,)):
        """Draw error bars on the selected subplot.

        An error array whose shape differs from the data is replaced by zeros.
        """
        plot = self.plots[self.get_actualPlotNumber()]
        shape = plot['data'].shape
        plot['xerr'] = self._match_shape(xerr, shape)
        plot['yerr'] = self._match_shape(yerr, shape)
        plot['ax'].errorbar(plot['x'], plot['data'], xerr=plot['xerr'], yerr=plot['yerr'],
                            color='black', alpha=0.1, fmt='.')

    def set_colorscale(self, col_arr=(0,), minmax='auto', cmap='plasma'):
        """Scatter the selected data coloured by ``col_arr`` and add a colorbar.

        If ``col_arr`` does not match the data shape, the data itself sets the
        colours. ``minmax`` is 'auto' or a ``(vmin, vmax)`` pair.
        """
        plot = self.plots[self.get_actualPlotNumber()]
        col_arr = np.array(col_arr)
        if col_arr.shape != plot['data'].shape:
            col_arr = plot['data']
        plot['col_arr'] = col_arr
        scatter = plot['ax'].scatter(plot['x'], plot['data'])
        plot['scatter'] = scatter
        scatter.set_array(col_arr)
        scatter.set_cmap(cmap)
        scatter.set_label(plot['label']['data'])
        if self._is_auto(minmax):
            scatter.autoscale()
        else:
            scatter.set_clim(vmin=minmax[0], vmax=minmax[1])
        plot['cbar'] = self.fig.colorbar(scatter)
        plot['cbar'].set_label(plot['label']['cbar'])

    def set_plot(self, lstyle='-', color='black'):
        """Draw the selected data as a line plot."""
        plot = self.plots[self.get_actualPlotNumber()]
        plot['plot'], = plot['ax'].plot(plot['x'], plot['data'], lstyle, color=color)
        plot['plot'].set_label(plot['label']['data'])

    def set_secondaryPlot(self, x, y, lstyle='-', color='black', label='secondary plot'):
        """Draw an extra (x, y) line on the selected subplot and remember it."""
        plot = self.plots[self.get_actualPlotNumber()]
        splots = plot.setdefault('splots', {})
        # Keyed 0, 1, 2, ... in call order.
        splots[len(splots)] = {'x': x, 'y': y, 'lstyle': lstyle, 'color': color, 'label': label}
        plot['ax'].plot(x, y, lstyle, color=color, label=label)

    def set_hist(self, nstep=20):
        """Draw a filled histogram of the selected data with ``nstep`` bins."""
        plot = self.plots[self.get_actualPlotNumber()]
        plot['hist'] = plot['ax'].hist(plot['data'], nstep, alpha=0.75, histtype='stepfilled',
                                       label=plot['label']['data'])

    def setAllLegends(self):
        """Add a legend to every subplot that has axes (skips unused ones)."""
        for plot in self.plots:
            if 'ax' in plot:
                plot['ax'].legend()

    def show(self):
        """Add the legends and display the figure."""
        self.setAllLegends()
        plt.show()
	


class display3d(object):
    """
    3D plotting object built on x, y, z coordinate arrays.

    Initialize with the 3D coordinates as numpy arrays ``(x, y, z)``.
      * ``set_trisurf`` adds a triangulated surface, optionally coloured by
        one value per triangle.
      * Matplotlib does not keep a 1:1:1 aspect ratio on 3D axes (still not
        fixed as of 07/2017), so an invisible bounding cube is drawn to
        force it.
      * ``set_autoscatter`` draws a default scatter of x, y, z. It is NOT
        drawn by default and must be called explicitly.
    """

    def __init__(self, x, y, z):
        self.x = x
        self.y = y
        self.z = z
        self.planet = False
        # Arguments of the last set_trisurf call, replayed by redraw().
        # None until a surface is drawn; redraw() used to crash on the
        # missing attribute.
        self.trisurf = None
        self._initialize()
        self.plots = {}

    @staticmethod
    def _max_abs(*arrays):
        """Return the largest absolute value over all elements of ``arrays``."""
        flat = np.concatenate([np.ravel(a) for a in arrays])
        return np.abs(flat).max()

    @staticmethod
    def _checker_colors(shape, colortuple):
        """Return an array of ``shape`` alternating the colours of ``colortuple``."""
        rows, cols = np.indices(shape)
        return np.array(colortuple)[(rows + cols) % len(colortuple)]

    def _draw_bounding_cube(self, extent):
        """Plot invisible (white) points at +/-extent on each axis to force a cubic box."""
        for direction in (-1, 1):
            for point in np.diag(direction * extent * np.array([1, 1, 1])):
                self.ax.plot([point[0]], [point[1]], [point[2]], 'w')

    def _initialize(self):
        """Create a new figure with 3D axes sized to the x, y, z coordinates."""
        self.fig = plt.figure(figsize=(10, 10), dpi=96)
        # Figure.gca(projection=...) was removed in matplotlib 3.6; on a fresh
        # figure this creates the same single 3D subplot.
        self.ax = self.fig.add_subplot(111, projection='3d')
        self.ax.view_init(elev=20., azim=30)
        try:
            self.ax.set_aspect('equal')
        except NotImplementedError:
            # Some matplotlib versions refuse 'equal' on 3D axes; the
            # bounding cube below still forces the 1:1:1 ratio.
            pass
        self._draw_bounding_cube(self._max_abs(self.x, self.y, self.z))

    def forceEqualRatio(self):
        """Redraw the bounding cube so it also encloses every saved line plot."""
        datasets = [entry['data'] for entry in self.plots.values()]
        self._draw_bounding_cube(self._max_abs(self.x, self.y, self.z, *datasets))

    def redraw(self):
        """Recreate the figure and replay the surface, planet and saved line plots."""
        self._initialize()
        if self.trisurf:
            self.set_trisurf(tripos=self.trisurf['tripos'], trival=self.trisurf['trival'],
                             minmax=self.trisurf['minmax'])
        if self.planet:
            self.set_planet()
        for key in self.plots:
            entry = self.plots[key]
            self.set_plot(entry['data'], color=entry['color'], save=False)

    def set_trisurf(self, tripos, trival=False, minmax='auto'):
        """Draw a triangulated surface over (x, y, z) and add a colorbar.

        ``tripos`` is passed as ``triangles`` to ``plot_trisurf``. ``trival``
        optionally gives one colour value per triangle; if its length differs
        from ``tripos`` a message is printed and it is ignored. ``minmax`` is
        'auto' or a ``(vmin, vmax)`` pair.
        """
        self.trisurf = {'tripos': tripos, 'trival': trival, 'minmax': minmax}
        self.collec = self.ax.plot_trisurf(self.x, self.y, self.z, triangles=tripos,
                                           alpha=0.9, edgecolor='none')
        # The default trival=False has no .any(); np.any handles it and lists.
        if trival is not None and np.any(trival):
            if len(trival) != len(tripos):
                print('Wrong collection of surface!')
            else:
                self.collec.set_array(np.asarray(trival))
        # Only a scalar can be 'auto'; comparing an array to a string is elementwise.
        if np.ndim(minmax) == 0 and minmax == 'auto':
            self.collec.autoscale()
        else:
            self.collec.set_clim(vmin=minmax[0], vmax=minmax[1])
        self.fig.colorbar(self.collec)

    def set_quiver(self, quiver, color=None):
        """Draw one arrow per (x, y, z) point; ``quiver`` columns 0-2 are the components."""
        self.ax.quiver(self.x, self.y, self.z, quiver[:, 0], quiver[:, 1], quiver[:, 2],
                       pivot='tail', color=color)

    def add_quiver(self, positions, quiver, color=None):
        """Draw arrows starting at ``positions`` (columns 0-2) with components from ``quiver``."""
        self.ax.quiver(positions[:, 0], positions[:, 1], positions[:, 2],
                       quiver[:, 0], quiver[:, 1], quiver[:, 2], pivot='tail', color=color)

    def set_scatter(self, scatter, c=None):
        """Scatter the points whose columns 0-2 are x, y, z, coloured by ``c``."""
        self.scatter = self.ax.scatter(scatter[:, 0], scatter[:, 1], scatter[:, 2], c=c, cmap='seismic')

    def set_autoscatter(self, c=None, vmin=None, vmax=None):
        """Scatter the object's own (x, y, z) coordinates and add a colorbar."""
        self.autoscatter = self.ax.scatter(self.x, self.y, self.z, vmin=vmin, vmax=vmax, c=c, cmap='seismic')
        self.fig.colorbar(self.autoscatter)

    def set_planet(self):
        """Draw a unit sphere at the origin with a yellow/blue checkerboard."""
        self.planet = True
        u, v = np.mgrid[0:2 * np.pi:20j, 0:np.pi:10j]
        xsp = np.cos(u) * np.sin(v)
        ysp = np.sin(u) * np.sin(v)
        zsp = np.cos(v)
        colors = self._checker_colors(xsp.shape, ('y', 'b'))
        self.ax.plot_surface(xsp, ysp, zsp, facecolors=colors, color="yellow", alpha=1., shade=False)

    def set_plot(self, plot, color='black', save=True):
        """Draw a thick 3D line through ``plot`` (columns 0-2 are x, y, z).

        With ``save`` the line is remembered so ``redraw`` and
        ``forceEqualRatio`` take it into account.
        """
        if save:
            self.plots[len(self.plots)] = {'data': plot, 'color': color}
        self.ax.plot(plot[:, 0], plot[:, 1], plot[:, 2], linewidth=3., color=color)

    def show(self):
        """Display the figure."""
        plt.show()
		
		



class display_colormaps(object):
    """
    Show every colormap category listed by ``get_colormap_info``.

    Creating an instance draws one figure per category, each made of one
    gradient strip per colormap, and then calls ``plt.show()``.
    """

    def __init__(self):
        self.get_colormap_info()
        # Every figure gets as many rows as the largest category has colormaps.
        self.nrows = max(len(names) for _, names in self.cmaps)

        ramp = np.linspace(0, 1, 256)
        self.gradient = np.vstack([ramp, ramp])

        for category, names in self.cmaps:
            self.plot_color_gradients(category, names, self.nrows)

        plt.show()

    def plot_color_gradients(self, cmap_category, cmap_list, nrows):
        """Draw one horizontal gradient strip per colormap name in ``cmap_list``."""
        fig, axes = plt.subplots(nrows=nrows)
        fig.subplots_adjust(top=0.95, bottom=0.01, left=0.2, right=0.99)
        axes[0].set_title(cmap_category + ' colormaps', fontsize=14)

        for ax, name in zip(axes, cmap_list):
            ax.imshow(self.gradient, aspect='auto', cmap=plt.get_cmap(name))
            # Write the colormap name just left of its strip, vertically centred.
            left, bottom, _, height = ax.get_position().bounds
            fig.text(left - 0.01, bottom + height / 2., name, va='center', ha='right', fontsize=10)

        # Hide ticks and spines on every axes, including the unused trailing ones.
        for ax in axes:
            ax.set_axis_off()

    def get_colormap_info(self):
        """Store in ``self.cmaps`` a list of (category, [colormap names]) pairs."""
        perceptual = ['viridis', 'plasma', 'inferno', 'magma']
        sequential = ['Greys', 'Purples', 'Blues', 'Greens', 'Oranges', 'Reds',
                      'YlOrBr', 'YlOrRd', 'OrRd', 'PuRd', 'RdPu', 'BuPu',
                      'GnBu', 'PuBu', 'YlGnBu', 'PuBuGn', 'BuGn', 'YlGn']
        sequential2 = ['binary', 'gist_yarg', 'gist_gray', 'gray', 'bone', 'pink',
                       'spring', 'summer', 'autumn', 'winter', 'cool', 'Wistia',
                       'hot', 'afmhot', 'gist_heat', 'copper']
        diverging = ['PiYG', 'PRGn', 'BrBG', 'PuOr', 'RdGy', 'RdBu',
                     'RdYlBu', 'RdYlGn', 'Spectral', 'coolwarm', 'bwr', 'seismic']
        qualitative = ['Pastel1', 'Pastel2', 'Paired', 'Accent',
                       'Dark2', 'Set1', 'Set2', 'Set3']
        miscellaneous = ['flag', 'prism', 'ocean', 'gist_earth', 'terrain', 'gist_stern',
                         'gnuplot', 'gnuplot2', 'CMRmap', 'cubehelix', 'brg', 'hsv',
                         'gist_rainbow', 'rainbow', 'jet', 'nipy_spectral', 'gist_ncar']
        self.cmaps = [
            ('Perceptually Uniform Sequential', perceptual),
            ('Sequential', sequential),
            ('Sequential (2)', sequential2),
            ('Diverging', diverging),
            ('Qualitative', qualitative),
            ('Miscellaneous', miscellaneous),
        ]
		#print self.cmaps
    


        
#def lineplot(y, **kwargs):
	#import matplotlib.dates as mpldates
	#import matplotlib.pyplot as plt
	#from matplotlib.widgets import Slider

	##### Convert to numpy arrays and retrieve the abscissae (move into a separate function?)
	#y = np.array(y)
	#if isinstance(y[0], (np.ndarray, list)): 
		#nplots = len(y)
		#for nplot in range(nplots) : y[nplot] = np.array(y[nplot])
	#else :
		#y = np.array([y])
		#nplots = len(y)
	
	#if "x" in kwargs : abscisse = kwargs["x"]
	#else : 
		#abscisse = []
		#for nplot in range(nplots) : abscisse.append(np.arange(0, len(y[nplot]), 1))
	#abscisse = np.array(abscisse)

	##for nplot in range(nplots) :
		##if isinstance(abscisse[nplot], (np.ndarray, list)) :
			##abscisse[nplot] = np.array(abscisse[nplot])
		##else :
			##print "No abscisse specified, using default"
			
			##np.append(abscisse)
	
	#fig, ax = plt.subplots()
	#plt.subplots_adjust(left = 0.25, bottom = 0.25)
	#plt.grid(True)
	
	#for nplot in range(nplots) : 
		#if isinstance(abscisse[nplot][0], datetime) :
			#dates = mpldates.date2num(abscisse[nplot])
			#plt.plot_date(dates, y[nplot])
		#else : 
			#plt.plot(abscisse[nplot], y[nplot])
	##pdb.set_trace()
	#slide_select_abscisse = [abscisse.min(), abscisse.min()]
	#slide_select_y = [y.min(), y.max()]

	#if isinstance(abscisse[nplot][0], datetime) :
		#abcdate = [0, 0]
		#sliders_plot, = plt.plot(abcdate, slide_select_y)
		#slaxe = plt.axes([0.25, 0.1, 0.65, 0.03], axisbg = 'white')
		#slpos = Slider(slaxe, 'debut', 0, len(abscisse[0]) - 1)
	#else :
		#sliders_plot, = plt.plot(slide_select_abscisse, slide_select_y)
		#slaxe = plt.axes([0.25, 0.1, 0.65, 0.03], axisbg = 'white')
		#slpos = Slider(slaxe, 'debut', abscisse.min(), abscisse.max())
	
	#def update_date(val):
		#selec = mpldates.date2num(abscisse[0][int(slpos.val)])
		#sliders_plot.set_xdata([selec,selec])
		#fig.canvas.draw_idle()
	
	#def update(val):
		#selec = slpos.val
		#sliders_plot.set_xdata([selec,selec])
		#fig.canvas.draw_idle()
		
	#slpos.on_changed(update_date)
	

	#plt.draw()
	#plt.show()
	
	#return slideval