# Source code for qinfer.tomography.plotting_tools

# -*- coding: utf-8 -*-
# Functions for plotting tomographic data and estimates.
# © 2015 Chris Ferrie and
#        Christopher E. Granade,
#        except where otherwise noted.
# Based on work with Joshua Combes.
# This file is a part of the Qinfer project.
# Licensed under the AGPL version 3.
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

# TODO: unit tests!

## FEATURES ##################################################################

from __future__ import absolute_import
from __future__ import division

## IMPORTS ###################################################################

from builtins import map

import numpy as np

# matplotlib is optional: if it cannot be imported, warn and bind the plotting
# names to None so that importing this module does not fail outright.
try:
    import matplotlib.pyplot as plt
    from matplotlib.patches import Ellipse, Polygon
except ImportError:
    import warnings
    warnings.warn("Could not import matplotlib.")
    plt = None
    Ellipse = None
    Polygon = None

# Since the rest of QInfer does not require QuTiP,
# we need to import it in a way that we don't propagate exceptions if QuTiP
# is missing or is too early a version.
from qinfer.utils import get_qutip_module
qt = get_qutip_module('3.2')

## EXPORTS ###################################################################

__all__ = [
    'plot_rebit_prior',
    'plot_rebit_posterior',
    'plot_cov_ellipse',
]

## CONSTANTS #################################################################

# Default indices of the model-parameter components that are plotted on the
# horizontal and vertical axes of the rebit plots.
REBIT_AXES = [1, 2]

## FUNCTIONS #################################################################

def plot_rebit_modelparams(modelparams, rebit_axes=REBIT_AXES, **kwargs):
    """
    Given model parameters representing rebits, plots the
    rebit states as a scatter plot. Additional keyword arguments
    are passed to :ref:`plt.scatter`.

    :param np.ndarray modelparams: Model parameters representing
        rebits. Indexed as a two-dimensional array with one row per
        plotted state.
    :param list rebit_axes: List containing indices for the :math:`x`
        and :math:`z` axes.
    """
    # Keep only the two plotted columns and rescale them by sqrt(2)
    # before handing them to matplotlib.
    mps = modelparams[:, rebit_axes] * np.sqrt(2)
    plt.scatter(mps[:, 0], mps[:, 1], **kwargs)

def plot_decorate_rebits(basis=None, rebit_axes=REBIT_AXES):
    """
    Decorates a figure with the boundary of rebit state space
    and basis labels drawn from a :ref:`~qinfer.tomography.TomographyBasis`.

    :param qinfer.tomography.TomographyBasis basis: Basis to use in
        labeling axes. If ``None``, the axis labels are left untouched.
    :param list rebit_axes: List containing indices for the :math:`x`
        and :math:`z` axes.
    """
    ax = plt.gca()

    if basis is not None:
        label_template = r'$\langle\!\langle {} | \rho \rangle\!\rangle$'
        # Pick out the x and z by default.
        labels = [
            label_template.format(basis.labels[rebit_axes[0]]),
            label_template.format(basis.labels[rebit_axes[1]]),
        ]
        # Apply the basis labels to the current axes.
        ax.set_xlabel(labels[0])
        ax.set_ylabel(labels[1])

    # Unit circle marking the boundary of the rebit state space, with a
    # small margin around it.
    ax.add_artist(plt.Circle([0, 0], 1, color='k', fill=False))
    ax.set_xlim(-1.1, 1.1)
    ax.set_ylim(-1.1, 1.1)

def plot_cov_ellipse(cov, pos, nstd=2, ax=None, **kwargs):
    """
    Plots an `nstd` sigma error ellipse based on the specified covariance
    matrix (`cov`). Additional keyword arguments are passed on to the
    ellipse patch artist.

    :param cov: The 2x2 covariance matrix to base the ellipse on.
    :param pos: The location of the center of the ellipse. Expects a
        2-element sequence of ``[x0, y0]``.
    :param nstd: The radius of the ellipse in numbers of standard
        deviations. Defaults to 2 standard deviations.
    :param ax: The axis that the ellipse will be plotted on. Defaults to
        the current axis.

    :return: A matplotlib ellipse artist.
    """
    # Copied from an external source in accordance with its license
    # agreement (the source URL is missing from this copy of the file).

    def eigsorted(cov):
        # Eigen-decomposition with eigenvalues in descending order, so that
        # column 0 of the eigenvectors is the major axis.
        vals, vecs = np.linalg.eigh(cov)
        order = vals.argsort()[::-1]
        return vals[order], vecs[:, order]

    if ax is None:
        ax = plt.gca()

    vals, vecs = eigsorted(cov)
    # Orientation of the major axis, in degrees.
    theta = np.degrees(np.arctan2(*vecs[:, 0][::-1]))

    # Numerical noise can make an eigenvalue of a (near-)singular covariance
    # matrix slightly negative; clamp at zero so the square root is not NaN.
    vals = np.maximum(vals, 0)

    # Width and height are "full" widths, not radius
    width, height = 2 * nstd * np.sqrt(vals)
    ellip = Ellipse(xy=pos, width=width, height=height, angle=theta, **kwargs)

    ax.add_artist(ellip)
    return ellip
def plot_rebit_prior(prior, rebit_axes=REBIT_AXES,
        n_samples=2000, true_state=None, true_size=250,
        force_mean=None,
        legend=True,
        mean_color_index=2
    ):
    """
    Plots rebit states drawn from a given prior.

    :param qinfer.tomography.DensityOperatorDistribution prior: Distribution over
        rebit states to plot.
    :param list rebit_axes: List containing indices for the :math:`x`
        and :math:`z` axes.
    :param int n_samples: Number of samples to draw from the
        prior.
    :param np.ndarray true_state: State to be plotted as a "true"
        state for comparison. Skipped if ``None``.
    :param true_size: Marker size (``s``) used for the "true" state.
    :param force_mean: If not ``None``, plotted as the mean in place of
        ``prior._mean``.
    :param bool legend: If true, a legend is added to the plot.
    :param int mean_color_index: Index into the color palette used for the
        edge color of the mean marker.
    """
    # 'axes.color_cycle' was removed from matplotlib; prefer
    # 'axes.prop_cycle' and fall back only where it does not exist.
    try:
        palette = plt.rcParams['axes.prop_cycle'].by_key()['color']
    except KeyError:
        palette = plt.rcParams['axes.color_cycle']

    plot_rebit_modelparams(prior.sample(n_samples),
        c=palette[0],
        label='Prior',
        rebit_axes=rebit_axes
    )

    if true_state is not None:
        plot_rebit_modelparams(true_state,
            c=palette[1],
            label='True', marker='*', s=true_size,
            rebit_axes=rebit_axes
        )

    # A mean is drawn when the prior carries one or the caller supplies one;
    # the caller's value takes precedence.
    if hasattr(prior, '_mean') or force_mean is not None:
        mean = force_mean if force_mean is not None else prior._mean
        plot_rebit_modelparams(
            prior._basis.state_to_modelparams(mean)[None, :],
            edgecolors=palette[mean_color_index],
            s=250, facecolors='none', linewidth=3,
            label='Mean',
            rebit_axes=rebit_axes
        )

    plot_decorate_rebits(prior.basis,
        rebit_axes=rebit_axes
    )

    if legend:
        plt.legend(loc='lower left', ncol=3, scatterpoints=1)
def plot_rebit_posterior(updater, prior=None, true_state=None, n_std=3,
        rebit_axes=REBIT_AXES, true_size=250,
        legend=True,
        level=0.95,
        region_est_method='cov'
    ):
    """
    Plots posterior distributions over rebits, including covariance ellipsoids

    :param qinfer.smc.SMCUpdater updater: Posterior distribution over rebits.
    :param qinfer.tomography.DensityOperatorDistribution prior: Prior
        distribution over rebit states. If not ``None``, its mean is plotted.
    :param np.ndarray true_state: Model parameters for "true" state to plot
        as comparison. Skipped if ``None``.
    :param float n_std: Number of standard deviations out from the mean
        at which to draw the covariance ellipse. Only used if
        region_est_method is ``'cov'``.
    :param float level: Credibility level to use for computing
        region estimators from convex hulls.
    :param list rebit_axes: List containing indices for the :math:`x`
        and :math:`z` axes.
    :param true_size: Marker size (``s``) used for the "true" state.
    :param bool legend: If true, a legend is added to the plot.
    :param str region_est_method: Method to use to draw region estimation.
        Must be one of None, ``'cov'`` or ``'hull'``.
    """
    # 'axes.color_cycle' was removed from matplotlib; prefer
    # 'axes.prop_cycle' and fall back only where it does not exist.
    try:
        palette = plt.rcParams['axes.prop_cycle'].by_key()['color']
    except KeyError:
        palette = plt.rcParams['axes.color_cycle']

    # Marker sizes are derived from the particle weights.
    plot_rebit_modelparams(updater.particle_locations,
        c=palette[0],
        label='Posterior',
        s=12 * np.sqrt(updater.particle_weights * len(updater.particle_weights)),
        rebit_axes=rebit_axes,
        zorder=-10
    )

    # true_state defaults to None, which cannot be indexed, so only plot it
    # when the caller actually supplied one.
    if true_state is not None:
        plot_rebit_modelparams(true_state,
            c=palette[1],
            label='True', marker='*', s=true_size,
            rebit_axes=rebit_axes
        )

    if prior is not None:
        plot_rebit_modelparams(
            prior._basis.state_to_modelparams(prior._mean)[None, :],
            edgecolors=palette[3],
            s=250, facecolors='none', linewidth=3,
            label='Prior Mean',
            rebit_axes=rebit_axes
        )

    plot_rebit_modelparams(
        updater.est_mean()[None, :],
        edgecolors=palette[2],
        s=250, facecolors='none', linewidth=3,
        label='Posterior Mean',
        rebit_axes=rebit_axes
    )

    if region_est_method == 'cov':
        # Plotted coordinates are scaled by sqrt(2) to rescale to the Bloch
        # ball, so the covariance is scaled by sqrt(2)**2 = 2.
        cov = 2 * updater.est_covariance_mtx()
        # Use fancy indexing to cut out all but the desired submatrix.
        cov = cov[rebit_axes, :][:, rebit_axes]
        plot_cov_ellipse(
            cov, updater.est_mean()[rebit_axes] * np.sqrt(2),
            nstd=n_std,
            edgecolor='k', fill=True, lw=2,
            facecolor=palette[0],
            alpha=0.4,
            zorder=-9,
            label='Posterior Cov Ellipse ($Z = {}$)'.format(n_std)
        )

    elif region_est_method == 'hull':
        # Find the convex hull from the updater, projected
        # on the rebit axes. Only the vertices are needed here.
        _faces, vertices = updater.region_est_hull(level, modelparam_slice=rebit_axes)
        polygon = Polygon(vertices * np.sqrt(2),
            facecolor=palette[0], alpha=0.4, zorder=-9,
            label=r'Credible Region ($\alpha = {}$)'.format(level),
            edgecolor='k', lw=2, fill=True
        )
        # TODO: consolidate add_patch code with that above.
        plt.gca().add_patch(polygon)

    plot_decorate_rebits(updater.model.base_model._basis,
        rebit_axes=rebit_axes
    )

    if legend:
        plt.legend(loc='lower left', ncol=4, scatterpoints=1)