# Source code for qinfer.tests.base_test

#!/usr/bin/python
# -*- coding: utf-8 -*-
##
# base_test.py: Base class for derandomized test classes.
##
# © 2017, Chris Ferrie ([email protected]) and
#         Christopher Granade ([email protected]).
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
#     1. Redistributions of source code must retain the above copyright
#        notice, this list of conditions and the following disclaimer.
#
#     2. Redistributions in binary form must reproduce the above copyright
#        notice, this list of conditions and the following disclaimer in the
#        documentation and/or other materials provided with the distribution.
#
#     3. Neither the name of the copyright holder nor the names of its
#        contributors may be used to endorse or promote products derived from
#        this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
##

## FEATURES ###################################################################

from __future__ import absolute_import
from __future__ import division # Ensures that a/b is always a float.
from future.utils import with_metaclass

## IMPORTS ####################################################################

import sys
import warnings
import abc
import numpy as np
from numpy.testing import assert_equal, assert_almost_equal
import unittest
from qinfer import Domain, Model, Simulatable, FiniteOutcomeModel, DifferentiableModel

from contextlib import contextmanager

## FUNCTIONS ##################################################################

def test_model(model, prior, expparams, stream=sys.stderr):
    """
    Tests the given Simulatable instance for errors. Useful for debugging
    new or third party models.

    The most specific generic test mixin that applies to ``model`` is
    chosen (DifferentiableModel, then Model, then Simulatable), combined
    with DerandomizedTestCase, and run with a text test runner.

    :param model: Instance of Simulatable or a subclass thereof.
    :param prior: Instance of Distribution, or any other class which
        implements a function `sample` that returns valid modelparams.
    :param expparams: `np.ndarray` of experimental parameters to test with.
    :param stream: Stream to dump the results into, default is stderr.
    :return: The :class:`unittest.TestResult` produced by the run, so that
        callers can inspect e.g. ``result.wasSuccessful()``.
    :raises ValueError: If ``model`` is not an instance of any of the
        recognized model types.
    """
    # Check the most specific type first so that the richest set of
    # generic tests is selected.
    if isinstance(model, DifferentiableModel):
        test_class = ConcreteDifferentiableModelTest
    elif isinstance(model, Model):
        test_class = ConcreteModelTest
    elif isinstance(model, Simulatable):
        test_class = ConcreteSimulatableTest
    else:
        raise ValueError("Given model has unrecognized type.")

    # Bind the caller's objects into a concrete TestCase; the mixin's
    # abstract instantiate_* hooks simply hand them back.
    class TestGivenModel(test_class, DerandomizedTestCase):
        def instantiate_model(self):
            return model

        def instantiate_prior(self):
            return prior

        def instantiate_expparams(self):
            return expparams

    suite = unittest.TestLoader().loadTestsFromTestCase(TestGivenModel)
    runner = unittest.TextTestRunner(stream=stream)
    return runner.run(suite)


# ``test_model`` is a debugging helper, not a test. Its ``test_`` prefix
# (in a module named ``*_test.py``) would otherwise make pytest/nose try to
# collect it and fail on its required arguments; both honor ``__test__``.
test_model.__test__ = False
@contextmanager
def assert_warns(category):
    """
    Context manager which asserts that its contents raise a particular
    warning.

    :param type category: Category of the warning that should be raised.
        Warnings whose category is a subclass of ``category`` also count.
    :raises AssertionError: If no matching warning was emitted inside the
        ``with`` block.
    """
    with warnings.catch_warnings(record=True) as caught_warnings:
        # Catch everything, even warnings that filters would normally
        # show only once.
        warnings.simplefilter('always')
        yield
        assert any(
            issubclass(warning.category, category)
            for warning in caught_warnings
        ), "No warning of category {} raised.".format(category)

## CLASSES ####################################################################

class MockModel(FiniteOutcomeModel):
    """
    Two-outcome model whose likelihood is always 0.5, irrespective of
    model parameters, outcomes or experiment parameters.

    :param int n_mps: Number of model parameters the model reports.
    """

    def __init__(self, n_mps=2):
        self._n_mps = n_mps
        super(MockModel, self).__init__()

    @property
    def n_modelparams(self):
        return self._n_mps

    @staticmethod
    def are_models_valid(modelparams):
        # Every model is considered valid: one True per row of modelparams.
        return np.ones((modelparams.shape[0], ), dtype=bool)

    @property
    def is_n_outcomes_constant(self):
        return True

    def n_outcomes(self, expparams):
        return 2

    @property
    def expparams_dtype(self):
        return [('a', float), ('b', int)]

    def likelihood(self, outcomes, modelparams, expparams):
        # Defer to the base class first so whatever it does on every
        # likelihood call still happens; its return value is not used.
        super(MockModel, self).likelihood(outcomes, modelparams, expparams)
        pr0 = np.ones((modelparams.shape[0], expparams.shape[0])) / 2
        return FiniteOutcomeModel.pr0_to_likelihood_array(outcomes, pr0)


class MockAsyncResult(object):
    """
    Stand-in for an ipyparallel asynchronous result whose value is already
    available at construction time.
    """

    def __init__(self, value):
        self._value = value

    def get(self, timeout=None):
        """
        Returns the stored value immediately; ``timeout`` is accepted for
        signature compatibility and ignored.
        """
        return self._value


class MockAsyncMapResult(MockAsyncResult):
    """
    Stand-in for an ipyparallel asynchronous *map* result; iterating over
    it yields the mapped values in order.
    """
    # NOTE: this was previously declared with ``def`` instead of ``class``,
    # which made it a function returning None, so non-blocking
    # MockDirectView.map / map_async calls returned None.

    def __iter__(self):
        return iter(self._value)


class MockDirectView(object):
    """
    Object that mocks up an ipyparallel DirectView using serial execution,
    allowing for testing of classes that make use of ipyparallel without
    needing to install more libraries.

    :param int n_engines: Number of engines the view reports via ``len``.
        All work still runs serially in the current process.
    """

    n_engines = None

    def __init__(self, n_engines=1):
        self.n_engines = n_engines

    def __len__(self):
        return self.n_engines

    def clear(self, targets=None, block=None):
        raise NotImplementedError

    def execute(self, code, silent=True, targets=None, block=None):
        # Runs ``code`` in this process with this method's local namespace;
        # names it assigns do not persist after the call.
        exec(code)

    def gather(self, key, dist='b', targets=None, block=None):
        raise NotImplementedError

    def get(self, key_s):
        raise NotImplementedError

    def map(self, f, *sequences, **kwargs):
        """
        Serially applies ``f`` over ``sequences``. Returns a plain list when
        ``block`` is truthy, otherwise a MockAsyncMapResult wrapping it.
        """
        results = list(map(f, *sequences))
        if kwargs.get('block'):
            return results
        return MockAsyncMapResult(results)

    def map_sync(self, f, *sequences):
        return self.map(f, *sequences, block=True)

    def map_async(self, f, *sequences):
        return self.map(f, *sequences, block=False)


class DerandomizedTestCase(unittest.TestCase):

    ## SETUP AND TEARDOWN ##
    # We want every test method to be setup first by seeding NumPy's random
    # number generator with a predictable seed (namely: zero). This way,
    # all of our tests are *deterministic*, and once first checked, will
    # not deviate from that behavior unless there is a change to the underlying
    # functionality.
    #
    # We do this by using the fact that nosetests and unittest both call
    # the method named "setUp" (note the capitalization!) before each
    # test method.

    def setUp(self):
        np.random.seed(0)


class ConcreteSimulatableTest(with_metaclass(abc.ABCMeta, object)):
    """
    Mixin of generic tests which can be run to test basic properties of any
    subclass of Simulatable.

    Combine with a :class:`unittest.TestCase` (e.g. DerandomizedTestCase)
    and implement ``instantiate_model``, ``instantiate_prior`` and
    ``instantiate_expparams``.
    """

    ## FORCED PROPERTIES ##

    # We use this abstract instantiate_* paradigm to ensure that the actual
    # property cannot change instances throughout the testing. Although
    # unlikely, this paranoid approach prevents subclasses from having
    # model return something different every time it is called!
    # The hooks are abstract *methods* (they are invoked as
    # ``self.instantiate_model()``), not the deprecated abstractproperty.

    @abc.abstractmethod
    def instantiate_model(self):
        """
        Generates and returns an instance of the concrete Model class being tested.
        """
        pass

    @property
    def model(self):
        """
        Returns (a fixed) instance of the concrete Model class being tested.
        """
        try:
            return self._model
        except AttributeError:
            self._model = self.instantiate_model()
            return self._model

    @abc.abstractmethod
    def instantiate_prior(self):
        """
        Generates and returns a prior Distribution to be used with the model.
        """
        pass

    @property
    def prior(self):
        """
        Returns (a fixed) instance of the prior to be used while testing the model.
        """
        try:
            return self._prior
        except AttributeError:
            self._prior = self.instantiate_prior()
            return self._prior

    @abc.abstractmethod
    def instantiate_expparams(self):
        """
        Generates and returns a set of expparams to be used with the model.
        """
        pass

    @property
    def expparams(self):
        """
        Returns (a fixed) set of expparams to be used while testing the model.
        """
        try:
            return self._expparams
        except AttributeError:
            self._expparams = self.instantiate_expparams()
            return self._expparams

    ## PROPERTIES ##

    @property
    def n_expparams(self):
        """
        Number of experimental parameters to do tests with.
        """
        return self.expparams.shape[0]

    @property
    def n_models(self):
        """
        Number of model parameters to do tests with.
        """
        # Ensure that n_models is not equal to n_expparams, so that shape
        # mix-ups between the two axes are detectable.
        return 21 if self.n_expparams == 20 else 20

    @property
    def n_outcomes(self):
        """
        Number of outcomes to do tests with.
        """
        # Ensure that this is not equal to n_models or n_expparams
        return self.n_models + self.n_expparams

    @property
    def modelparams(self):
        """
        Fixed set of model parameter to do tests with.
        """
        try:
            return self._modelparams
        except AttributeError:
            # get modelparams by sampling the prior
            self._modelparams = self.prior.sample(n=self.n_models)
            return self._modelparams

    @property
    def outcomes(self):
        """
        Fixed set of outcomes to do tests with. If you have a weird model
        with different outcome dtypes, you may want to set this property
        manually.

        :raises ValueError: If the domain of the first experiment has no
            values to draw outcomes from.
        """
        try:
            return self._outcomes
        except AttributeError:
            # Take the values of the first experiment's domain and tile
            # them until there are at least n_outcomes of them.
            values = self.model.domain(self.expparams)[0].values
            if values.shape[0] == 0:
                # Doubling an empty array never grows it; fail loudly
                # instead of looping forever.
                raise ValueError(
                    "Domain of the first experiment has no values; "
                    "set outcomes manually for this model."
                )
            while values.shape[0] < self.n_outcomes:
                values = np.concatenate([values, values])
            self._outcomes = values[:self.n_outcomes]
            return self._outcomes

    ## TESTS ##

    def test_simulate_experiment(self):
        """
        Tests that simulate_experiment does not fail and has the right output format.
        """
        # ensure that repeat is not equal to n_models or n_expparams
        repeat = 2
        while repeat == self.n_expparams or repeat == self.n_models:
            repeat = repeat + 1

        outcomes = self.model.simulate_experiment(
            self.modelparams, self.expparams, repeat=repeat
        )
        assert outcomes.shape == (repeat, self.n_models, self.n_expparams)

        # check that outcomes are in the right domains
        for idx_ep in range(self.n_expparams):
            domain = self.model.domain(self.expparams[idx_ep:idx_ep+1])[0]
            assert domain.in_domain(outcomes[:, :, idx_ep].flatten())

    def test_update_timestep(self):
        """
        Tests that update_timestep does not fail and has the right output format.
        """
        mps = self.model.update_timestep(self.modelparams, self.expparams)
        assert mps.shape == (
            self.n_models, self.model.n_modelparams, self.n_expparams
        )

        # Flatten to one row per (experiment, model) pair for validity check.
        mps = mps.transpose((2, 0, 1)).reshape(
            self.n_models * self.n_expparams, -1
        )
        assert np.all(self.model.are_models_valid(mps))

    def test_domain_with_none(self):
        """
        Tests that the domain property of a Model works with the None input
        whenever is_n_outcomes_constant is True.
        """
        if self.model.is_n_outcomes_constant:
            domain = self.model.domain(None)
            assert isinstance(domain, Domain)

    def test_domain(self):
        """
        Tests that the domain property returns a list of domains of the correct length
        """
        domains = self.model.domain(self.expparams)
        assert len(domains) == self.n_expparams
        for domain in domains:
            assert isinstance(domain, Domain)


class ConcreteModelTest(ConcreteSimulatableTest):
    """
    Mixin of generic tests which can be run to test basic properties of any
    subclass of Model.
    """

    ## TESTS ##

    def test_are_models_valid(self):
        """
        Tests that are_models_valid does not fail.
        """
        # we are more interested in whether this fails than if the models are valid
        self.model.are_models_valid(self.modelparams)

    def test_canonicalize(self):
        """
        Tests that canonicalize does not fail and that it returns valid models
        for the tester's specific modelparams.
        """
        new_mps = self.model.canonicalize(self.modelparams)
        assert np.all(self.model.are_models_valid(new_mps))

    def test_likelihood(self):
        """
        Tests that likelihood does not fail and has the right output format.
        """
        L = self.model.likelihood(self.outcomes, self.modelparams, self.expparams)
        assert L.shape == (self.n_outcomes, self.n_models, self.n_expparams)


class ConcreteDifferentiableModelTest(ConcreteModelTest):
    """
    Mixin of generic tests which can be run to test basic properties of any
    subclass of DifferentiableModel.
    """

    ## TESTS ##

    def test_fisher_information(self):
        """
        Tests that fisher information does not fail and has the right output format.
        """
        fisher = self.model.fisher_information(self.modelparams, self.expparams)
        assert fisher.shape == (
            self.model.n_modelparams, self.model.n_modelparams,
            self.n_models, self.n_expparams
        )

    def test_score(self):
        """
        Tests that score does not fail and has the right output format.
        """
        score1 = self.model.score(
            self.outcomes, self.modelparams, self.expparams, return_L=False
        )
        L1 = self.model.likelihood(self.outcomes, self.modelparams, self.expparams)
        score, L = self.model.score(
            self.outcomes, self.modelparams, self.expparams, return_L=True
        )

        # Ensure some consistency between the two calling conventions
        assert_almost_equal(score1, score, 3)
        assert_almost_equal(L1, L, 3)

        # Dimensions must be correct
        assert score.shape == (
            self.model.n_modelparams, self.n_outcomes,
            self.n_models, self.n_expparams
        )


class ConcreteDomainTest(with_metaclass(abc.ABCMeta, object)):
    """
    Mixin of generic tests which can be run to test basic properties of any
    subclass of Domain.
    """

    ## FORCED PROPERTIES ##

    # We use this abstract instantiate_* paradigm to ensure that the actual
    # property cannot change instances throughout the testing.

    @abc.abstractmethod
    def instantiate_domain(self):
        """
        Generates and returns an instance of the concrete Domain class being tested.
        """
        pass

    @property
    def domain(self):
        """
        Returns (a fixed) instance of the concrete Domain class being tested.
        """
        try:
            return self._domain
        except AttributeError:
            self._domain = self.instantiate_domain()
            return self._domain

    @abc.abstractmethod
    def instantiate_good_values(self):
        """
        Returns a list of values in the domain.
        """
        pass

    @property
    def good_values(self):
        """
        Returns (a fixed) list of values in the domain.
        """
        try:
            return self._good_values
        except AttributeError:
            self._good_values = self.instantiate_good_values()
            return self._good_values

    @abc.abstractmethod
    def instantiate_bad_values(self):
        """
        Returns a list of values not in the domain.
        """
        pass

    @property
    def bad_values(self):
        """
        Returns (a fixed) list of values not in the domain.
        """
        try:
            return self._bad_values
        except AttributeError:
            self._bad_values = self.instantiate_bad_values()
            return self._bad_values

    ## TESTS ##

    def test_is_cts_or_is_descrete(self):
        """
        Tests that is_continuous is a boolean and is the opposite of is_discrete.
        """
        # ``x or not x`` can never fail; check the type instead.
        assert isinstance(self.domain.is_continuous, (bool, np.bool_))
        assert self.domain.is_continuous is not self.domain.is_discrete

    def test_is_finite(self):
        """
        Tests that is_finite is a boolean and that finite domains are discrete.
        """
        assert isinstance(self.domain.is_finite, (bool, np.bool_))
        if self.domain.is_finite:
            assert self.domain.is_discrete

    def test_example_point(self):
        """
        Tests that the example point is in the domain and has the right dtype
        """
        assert self.domain.in_domain(self.domain.example_point)
        assert_equal(
            self.domain.example_point,
            self.domain.example_point.astype(self.domain.dtype)
        )

    def test_values(self):
        """
        Tests that n_members is consistent with values, and that values
        lie in the domain.
        """
        values = self.domain.values
        if self.domain.n_members < np.inf:
            assert values.size == self.domain.n_members
        assert self.domain.in_domain(values)

    def test_in_domain(self):
        """
        Tests that good_values are in the domain and bad_values are not.
        (self.values is tested elsewhere)
        """
        for v in self.good_values:
            try:
                assert self.domain.in_domain(v)
            except AssertionError as e:
                # Attach the offending value to the failure message.
                e.args += ('Current good value: {}'.format(v),)
                raise
        for v in self.bad_values:
            try:
                assert not self.domain.in_domain(v)
            except AssertionError as e:
                e.args += ('Current bad value: {}'.format(v),)
                raise