Source code for mlpy.knowledgerep.cbr.engine

from __future__ import division, print_function, absolute_import

import math
import copy

import numpy as np
import matplotlib
from matplotlib import pyplot as plt

from ...auxiliary.misc import remove_key, listify
from ...knowledgerep.cbr.methods import CBRMethodFactory
from .features import FeatureFactory
from .similarity import SimilarityFactory


class CaseMatch(object):
    """Case match information.

    Parameters
    ----------
    case : Case
        The matching case.
    key : str or list[str] or tuple[str]
        The feature name(s) the similarity measure belongs to. A list of
        names is stored as a tuple because lists cannot be dictionary keys.
    similarity :
        A measure for the similarity to the query case.

    Attributes
    ----------
    is_solution : bool
        Whether this case match is a solution to the query case or not.
    error : float
        The error of the prediction.
    predicted : bool
        Whether the query case could be correctly predicted using
        this case match.

    """
    __slots__ = ('_case', '_similarity', 'is_solution', 'error', 'predicted')

    @property
    def case(self):
        """The case that matches the query case.

        Returns
        -------
        Case :
            The case matching the query.

        """
        return self._case

    # noinspection PyShadowingNames
    def __init__(self, case, key, similarity=None):
        self._case = case
        self._similarity = {self._hashable(key): similarity}

        self.is_solution = False
        self.error = np.inf
        self.predicted = False

    @staticmethod
    def _hashable(key):
        """Return `key` in a form usable as a dictionary key (list -> tuple)."""
        return tuple(key) if isinstance(key, list) else key

    def get_similarity(self, key):
        """Retrieve the similarity measure for the feature(s) identified by the key.

        Parameters
        ----------
        key : str or list[str] or tuple[str]
            The feature name(s) the similarity was stored under.

        Returns
        -------
        float :
            The similarity measure.

        Raises
        ------
        KeyError
            If no similarity has been stored for the key.

        """
        return self._similarity[self._hashable(key)]

    def set_similarity(self, key, value):
        """Set the similarity measure for the feature(s) identified by the key.

        Parameters
        ----------
        key : str or list[str] or tuple[str]
            The feature name(s) the similarity belongs to.
        value :
            The similarity measure to store.

        """
        self._similarity[self._hashable(key)] = value
class Case(object):
    """The representation of a case in the case base.

    A case is composed of one or more :class:`Feature`.

    Parameters
    ----------
    cid : int
        The case's unique identifier.
    name : str
        The name for the case, optional.
    description : str
        Text describing the case, optional.
    features : dict
        Mapping from feature name to the feature describing the case,
        optional. The mapping is shallow-copied.

    """
    __slots__ = ('_id', '_name', '_description', '_features', 'ix')

    @property
    def id(self):
        """
        The case's unique identifier.

        :rtype: int
        """
        return self._id

    def __init__(self, cid, name=None, description=None, features=None):
        self._id = cid
        self._name = "" if name is None else name
        self._description = "" if description is None else description
        self._features = {} if features is None else copy.copy(features)
        # Iteration cursor; reset by ``__iter__``.
        self.ix = 0

    def __getitem__(self, key):
        # Indexing a case yields the feature *value*, not the feature object.
        return self._features[key].value

    def __len__(self):
        return len(self._features)

    def __iter__(self):
        self.ix = 0
        return self

    def next(self):
        """Return the next feature of the case.

        Returns
        -------
        Feature :
            The next feature object in the case's feature-storage order.

        Raises
        ------
        StopIteration
            If all features have been visited.

        """
        # The features are stored by name, so they are addressed by their
        # position in the value view rather than by indexing the dict with
        # an integer.
        features = list(self._features.values())
        if self.ix >= len(features):
            raise StopIteration
        item = features[self.ix]
        self.ix += 1
        return item

    # Python 3 iterator protocol.
    __next__ = next

    def add_feature(self, name, _type, value, **kwargs):
        """Add a new feature.

        Parameters
        ----------
        name : str
            The name of the feature (this also serves as the features
            identifying key).
        _type : str
            The type of the feature. Valid feature types are:

                bool
                    The feature values are boolean types (:class:`.BoolFeature`).

                string
                    The feature values are of type string (:class:`.StringFeature`).

                int
                    The feature values are of type integer (:class:`.IntFeature`).

                float
                    The feature values are of type float (:class:`.FloatFeature`).

        value : bool or string or int or float or list
            The feature value.

        Other Parameters
        ----------------
        weight : float or list[float]
            The weights given to each feature value.
        is_index : bool
            Flag indicating whether this feature is an index.
        retrieval_method : str
            The similarity model used for retrieval. Refer to
            :attr:`.Feature.retrieval_method` for valid methods.
        retrieval_method_params : dict
            Parameters relevant to the selected retrieval method.
        retrieval_algorithm : str
            The internal indexing structure of the training data. Refer
            to :attr:`.Feature.retrieval_method` for valid algorithms.
        retrieval_metric : str
            The metric used to compute the distances between pairs of points.
            Refer to :class:`sklearn.neighbors.DistanceMetric` for valid
            identifiers.
        retrieval_metric_params : dict
            Parameters relevant to the specified metric.

        """
        feature = FeatureFactory.create(_type, name, value, **kwargs)
        self._features[name] = feature

    def get_retrieval_method(self, names):
        """Returns the retrieval method for the given features.

        Parameters
        ----------
        names : str or list[str]
            The name(s) of the feature for which to retrieve the
            retrieval method.

        Returns
        -------
        str :
            The retrieval method shared by all the features. Features
            grouped together for retrieval must use the same retrieval
            method.

        Raises
        ------
        UserWarning
            If not all features use the same retrieval method.

        """
        method = None
        for n in listify(names):
            feat = self._features[n]
            if method is None:
                method = feat.retrieval_method
            elif not method == feat.retrieval_method:
                raise UserWarning("All features grouped for retrieval must use the same retrieval method.")
        return method

    def get_retrieval_params(self, names):
        """Return the retrieval parameters for the given features.

        Parameters
        ----------
        names : str or list[str]
            The name(s) of the feature for which to retrieve the
            retrieval parameters.

        Returns
        -------
        dict :
            The retrieval parameters (``method_params``, ``algorithm``,
            ``metric`` and ``metric_params``) for the feature(s). Features
            grouped together for retrieval must use the same retrieval
            parameters.

        Raises
        ------
        UserWarning
            If not all features use the same retrieval parameters.

        """
        params = {}
        for n in listify(names):
            feat = self._features[n]
            if not params:
                # The first feature defines the reference parameters.
                params["method_params"] = feat.retrieval_method_params
                params["algorithm"] = feat.retrieval_algorithm
                params["metric"] = feat.retrieval_metric
                params["metric_params"] = feat.retrieval_metric_params
            elif not (params["method_params"] == feat.retrieval_method_params and
                      params["algorithm"] == feat.retrieval_algorithm and
                      params["metric"] == feat.retrieval_metric and
                      params["metric_params"] == feat.retrieval_metric_params):
                raise UserWarning("All features grouped for retrieval must have the same retrieval parameters")
        return params

    # noinspection PyShadowingNames
    def get_indexed(self):
        """Return the names of all indexed features.

        Returns
        -------
        str or list[str] :
            The names of the indexed features in the case's feature-storage
            order (no additional sorting is applied). If exactly one feature
            is indexed, its name is returned directly instead of a list.

        """
        names = [feat.name for feat in self._features.values() if feat.is_index]
        if len(names) == 1:
            return names[0]
        return names

    # noinspection PyShadowingNames
    def get_features(self, names):
        """Return the values of the features with the specified name(s).

        Parameters
        ----------
        names : str or list[str]
            The name(s) of the features to retrieve.

        Returns
        -------
        list or int or str or bool or float :
            For a list of names, the values of the matching features in the
            case's feature-storage order (not the order of `names`; names
            without a matching feature are ignored). For a single name, the
            value of that feature.

        Raises
        ------
        KeyError
            If a single name is given and no such feature exists.

        """
        if isinstance(names, list):
            return [feat.value for feat in self._features.values() if feat.name in names]
        return self._features[names].value

    def compute_similarity(self, other):
        """Computes how similar two cases are.

        Only features that are flagged as index in both cases contribute.
        Each contribution is the squared feature comparison weighted by
        the product of both feature weights; the result is the square root
        of the sum.

        Parameters
        ----------
        other : Case
            The other case this case is compared to.

        Returns
        -------
        float :
            The similarity measure between the two cases.

        Raises
        ------
        KeyError
            If `other` lacks one of this case's features.

        """
        total_similarity = 0.0
        for key, sfeature in self._features.items():
            # ``other[key]`` would return the feature *value*; the feature
            # object itself is needed for its flags, weight and comparison.
            ofeature = other._features[key]
            if sfeature.is_index and ofeature.is_index:
                weight = sfeature.weight * ofeature.weight
                total_similarity += weight * math.pow(sfeature.compare(ofeature), 2)
        return math.sqrt(total_similarity)
class CaseBaseEntry(object):
    """The case base entry class.

    The entry maintains a similarity model from which similar cases can be
    derived.

    Internally the similarity model maintains an indexing structure dependent
    on the similarity model type for efficient computation of the similarity
    between cases. The case base is responsible for updating the indexing
    structure as cases are added and removed.

    Parameters
    ----------
    model : ISimilarity
        The similarity model.
    validity_check : bool
        This flag controls whether the dirty flag is being checked before
        determining whether to rebuild the model or not.

    Attributes
    ----------
    dirty : bool
        A flag which identifies whether the model needs to be rebuilt. The
        indexing structure of the similarity model is always rebuilt unless
        a validity check is required. If a validity check is required the
        indexing structure is only rebuilt if the entry is considered dirty.

    """
    __slots__ = ('dirty', '_similarity', '_validity_check')

    def __init__(self, model, validity_check=True):
        self.dirty = True

        self._similarity = model
        self._validity_check = validity_check

    def compute_similarity(self, data_point, **kwargs):
        """Computes the similarity.

        Rebuilds the indexing structure if necessary and then computes the
        similarity between the data point and each entry in the similarity
        model's indexing structure.

        Parameters
        ----------
        data_point : list[float]
            The data point each entry in the similarity model is compared to.

        Returns
        -------
        list[Stat] :
            The similarity statistics as returned by the similarity model.

        Other Parameters
        ----------------
        cases : dict[int, Case]
            The complete case base from which to build the indexing
            structure used by the similarity model.
        data : ndarray[ndarray[float]]
            If this keyword is set (together with `id_map`), the cases in
            the case base are ignored and the data entries specified in this
            variable are used to build the indexing structure.
        names : str or list[str]
            The feature name(s) relevant for the similarity computation.
            This field is only required if the indexing structure is built
            from the `cases` field.
        id_map : dict[int, int]
            The mapping from the data stored in the `data` field to their
            case ids. This field is only required if the indexing structure
            is built from the `data` field.

        """
        self._build_indexing_structure(**kwargs)
        return self._similarity.compute_similarity(data_point)

    def _build_indexing_structure(self, **kwargs):
        """Build the indexing structure.

        Build the indexing structure of the similarity model for specific
        feature names or alternatively for the data provided in the `data`
        field. Nothing happens if a validity check is requested and the
        entry is not dirty.

        Parameters
        ----------
        cases : dict[int, Case]
            The complete case base from which to build the indexing
            structure used by the similarity model.
        data : ndarray[ndarray[float]]
            If this keyword is set (together with `id_map`), the cases in
            the case base are ignored and the data entries specified in this
            variable are used to build the indexing structure.
        names : str or list[str]
            The feature name(s) relevant for the similarity computation.
            This field is only required if the indexing structure is built
            from the `cases` field.
        id_map : dict[int, int]
            The mapping from the data stored in the `data` field to their
            case ids. This field is only required if the indexing structure
            is built from the `data` field.

        Raises
        ------
        KeyError
            If neither the `data`/`id_map` pair nor `cases`/`names` is given.

        """
        if self._validity_check and not self.dirty:
            return

        try:
            data = kwargs["data"]
            id_map = kwargs["id_map"]
        except KeyError:
            # No explicit data given: extract it from the case base.
            data = None
            id_map = {}
            rows = []
            for i, c in enumerate(kwargs["cases"].values()):
                rows.append(c.get_features(kwargs["names"]))
                # Remember which case each data row belongs to.
                id_map[i] = c.id

            if rows:
                # Stack all rows in one go; the leading empty float64 array
                # fixes the column count and the resulting dtype.
                seed = np.empty((0, len(rows[0])), dtype=np.float64)
                data = np.vstack([seed] + rows)

        self._similarity.build_indexing_structure(data, id_map)
        self.dirty = False
class CaseBase(object):
    # noinspection PyTypeChecker
    """The case base engine.

    The case base engine maintains a database of all cases entered
    into the case base. It manages retrieval, revision, reuse, and retention
    of cases.

    Parameters
    ----------
    case_template: dict
        The template from which to create a new case.

        :Example:

            An example template for a feature named ``state`` with the
            specified feature parameters. ``data`` is the data from which
            to extract the case from. In this example it is expected that
            ``data`` has a member variable ``state``.

            ::

                {
                    "state": {
                        "type": "float",
                        "value": "data.state",
                        "is_index": True,
                        "retrieval_method": "radius-n",
                        "retrieval_method_params": 0.01
                    },
                    "delta_state": {
                        "type": "float",
                        "value": "data.next_state - data.state",
                        "is_index": False,
                    }
                }

    reuse_method : str
        The reuse method name to be used during the reuse step.
        Default is `defaultreusemethod`.
    reuse_method_params : dict
        Non-positional initialization parameters for the reuse
        method instantiation.
    revision_method : str
        The revision method name to be used during the revision step.
        Default is `defaultrevisionmethod`.
    revision_method_params : dict
        Non-positional initialization parameters for the revision
        method instantiation.
    retention_method : str
        The retention method name to be used during the retention step.
        Default is `defaultretentionmethod`.
    retention_method_params : dict
        Non-positional initialization parameters for the retention
        method instantiation.
    plot_retrieval : bool
        Whether to plot the result or not. Default is False.
    plot_retrieval_names : str or list[str]
        The names of the feature which to plot.

    Raises
    ------
    ValueError
        If the reuse, revision or retention method cannot be created.

    Examples
    --------
    Create a case base:

    >>> from mlpy.auxiliary.io import load_from_file
    >>>
    >>> template = {}
    >>> cb = CaseBase(template)

    Fill case base with data read from file:

    >>> from mlpy.mdp.stateaction import Experience, State, Action
    >>>
    >>> data = load_from_file("data/jointsAndActionsData.pkl")
    >>> first = next(iter(data.values()))
    >>> for i in range(len(first)):
    ...     for j in range(len(first[0][i]) - 1):
    ...         if not j == 10:  # exclude one experience as test case
    ...             experience = Experience(State(data["states"][i][:, j]),
    ...                                     Action(data["actions"][i][:, j]),
    ...                                     State(data["states"][i][:, j + 1]))
    ...             cb.run(cb.case_from_data(experience))

    Loop over all cases in the case base:

    >>> for case in cb:
    ...     pass

    Retrieve case with ``id=0``:

    >>> case = cb[0]

    """
    @property
    def counter(self):
        """The case counter.

        The counter is increased with every case added to the case base.

        Returns
        -------
        int :
            The current count.

        """
        return self._counter

    def __init__(self, case_template, reuse_method=None, reuse_method_params=None, revision_method=None,
                 revision_method_params=None, retention_method=None, retention_method_params=None,
                 plot_retrieval=None, plot_retrieval_names=None):
        # Collection of the unadulterated cases.
        # :type: dict[int, Case]
        self._cases = {}

        # The cases database keeping a similarity model for a
        # collection of cases specified by their feature names.
        # :type: dict[str|tuple[str], CaseBaseEntry]
        self._cb = {}

        # :type: int
        self._counter = 0

        # :type: dict
        self._case_template = case_template

        # :type: IReuseMethod
        self._reuse_method = self._create_method(
            "reuse", reuse_method, reuse_method_params, 'defaultreusemethod')
        # :type: IRevisionMethod
        self._revision_method = self._create_method(
            "revision", revision_method, revision_method_params, 'defaultrevisionmethod')
        # :type: IRetentionMethod
        self._retention_method = self._create_method(
            "retention", retention_method, retention_method_params, 'defaultretentionmethod')

        # :type: bool
        self._plot_retrieval = plot_retrieval if plot_retrieval is not None else False
        # :type: str | list[str]
        self._plot_retrieval_names = plot_retrieval_names

        self._fig = None
        self._ax = None

    @staticmethod
    def _create_method(step, name, params, default):
        """Instantiate the CBR method for one step, falling back to `default`."""
        name = name if name is not None else default
        params = params if params is not None else {}
        try:
            return CBRMethodFactory.create(name, **params)
        except Exception:
            # ``Exception`` rather than a bare ``except`` so that
            # KeyboardInterrupt and SystemExit are not swallowed.
            raise ValueError("%s is not a valid %s method" % (name, step))

    def __getitem__(self, key):
        return self._cases[key]

    def __len__(self):
        return len(self._cases)

    def __iter__(self):
        self.ix = 0
        return self

    def next(self):
        """Return the next case of the case base.

        Cases are looked up by their running index, i.e. the id they were
        stored under by :meth:`add`.

        Returns
        -------
        Case :
            The next case.

        Raises
        ------
        StopIteration
            If all cases have been visited.

        """
        if self.ix == len(self._cases):
            raise StopIteration
        item = self._cases[self.ix]
        self.ix += 1
        return item

    # Python 3 iterator protocol.
    __next__ = next

    def load(self, filename):
        """Load the case base from file (not implemented)."""
        pass

    def save(self, filename):
        """Save the case base to file (not implemented)."""
        pass

    def get_new_id(self):
        """Return an unused case id.

        Returns
        -------
        int :
            Unused case ID.

        """
        return self.counter

    def add(self, case):
        """Add a new case without any checks.

        The case is stored under the current counter value and all
        similarity models are flagged for a rebuild.

        Parameters
        ----------
        case : Case
            The case to add to the case base.

        """
        self._cases[self._counter] = case
        self._counter += 1

        # Every indexing structure is outdated once a case has been added.
        for entry in self._cb.values():
            entry.dirty = True

    def run(self, case):
        """Run the case base.

        Run the case base using the CBR methods retrieve, reuse,
        revision and retention.

        Parameters
        ----------
        case : Case
            The query case

        Returns
        -------
        dict[int, CaseMatch] :
            The solution to the problem-solving experience

        """
        case_matches = self.retrieve(case)
        solution = self.reuse(case, case_matches)
        solution = self.revision(case, solution)
        self.retain(case, solution)
        return solution

    def retrieve(self, case, names=None, validity_check=True, **kwargs):
        """Retrieve cases similar to the query case.

        Parameters
        ----------
        case : Case
            The query case.
        names : str or list[str]
            The name(s) of the features for which to retrieve similar cases.
            Defaults to the indexed features of the query case.
        validity_check : bool
            This flag controls whether the dirty flag is being checked before
            determining whether to rebuild the indexing structure or not. It
            only takes effect when the similarity model for `names` is
            created.

        Returns
        -------
        dict[int, CaseMatch] :
            The matching cases keyed by case id; empty if the case base
            holds no cases. The similarity of each match is stored under
            `names` (as a tuple if `names` is a list).

        Other Parameters
        ----------------
        data : ndarray[ndarray[float]]
            If this keyword is set, the cases in the case base are ignored
            and the data entries specified in this variable are used to
            build the indexing structure.
        id_map : dict[int, int]
            The mapping from the data stored in the `data` field to their
            case ids. This field is only required if the data for building
            the indexing structure comes from the `data` field.

        """
        if not self._cases:
            return {}

        if names is None:
            names = case.get_indexed()

        # Lists are not hashable; use a tuple wherever the names act as a key.
        key = tuple(names) if isinstance(names, list) else names

        # Create the similarity model on first use.
        if key not in self._cb:
            self._cb[key] = CaseBaseEntry(
                SimilarityFactory.create(case.get_retrieval_method(names), **case.get_retrieval_params(names)),
                validity_check)

        if "data" not in kwargs and "id_map" not in kwargs:
            kwargs["cases"] = self._cases
            kwargs["names"] = names
        stats = self._cb[key].compute_similarity(case.get_features(names), **kwargs)

        if self._plot_retrieval and names == self._plot_retrieval_names:
            self.plot_retrieval(case, [s.case_id for s in stats], names)

        return {s.case_id: CaseMatch(self._cases[s.case_id], key, s.similarity) for s in stats}

    def reuse(self, case, case_matches):
        """Performs the reuse step

        Performs new generalizations and specializations as a consequence of
        the solution transformation.

        Parameters
        ----------
        case : Case
            The query case.
        case_matches : dict[int, CaseMatch]
            The solution to the problem-solving experience.

        Returns
        -------
        dict[int, CaseMatch] :
            The revised solution to the problem-solving experience; empty if
            there are no case matches.

        """
        if not case_matches:
            return {}
        return self._reuse_method.execute(case, case_matches, self.retrieve)

    def revision(self, case, case_matches):
        """Evaluate solution provided by problem-solving experience.

        Parameters
        ----------
        case : Case
            The query case.
        case_matches : dict[int, CaseMatch]
            The revised solution to the problem-solving experience.

        Returns
        -------
        dict[int, CaseMatch] :
            The corrected solution; empty if there are no case matches.

        """
        if not case_matches:
            return {}
        return self._revision_method.execute(case, case_matches)

    def retain(self, case, case_matches):
        """Retain new case.

        Retain new case depending on the revise outcomes and the
        CBR policy regarding case retention.

        Parameters
        ----------
        case : Case
            The query case.
        case_matches : dict[int, CaseMatch]
            The corrected solution

        """
        self._retention_method.execute(case, case_matches, self.add)

    def plot_retrieval(self, case, case_id_list, names=None):
        """Plot the retrieved data.

        The query case, the similar cases and all remaining cases are drawn
        as a 3D scatter plot; the selected features must therefore yield
        exactly three values per case.

        Parameters
        ----------
        case : Case
            The query case.
        case_id_list : list[int]
            The ids of the cases identified to be similar.
        names : str or list[str]
            The name(s) of the features for which similar cases were
            retrieved.

        """
        # Older matplotlib releases only register the '3d' projection once
        # this module has been imported.
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

        if self._fig is None or not plt.fignum_exists(self._fig.number):
            self._fig = plt.figure()
            plt.rcParams['legend.fontsize'] = 10
            self._fig.suptitle('Similarity: {0}'.format(names))
            self._ax = self._fig.add_subplot(111, projection='3d')
            self._fig.show()

        self._ax.cla()

        [x, y, z] = case.get_features(names)
        self._ax.scatter(x, y, z, edgecolors='r', c='r', marker='o')

        for c in self._cases.values():
            [xs, ys, zs] = c.get_features(names)
            if c.id in case_id_list:
                self._ax.scatter(xs, ys, zs, edgecolors='g', c='g', marker='^')
            else:
                self._ax.scatter(xs, ys, zs, c='k', marker='o')

        # Proxy artists, used only to populate the legend.
        scatter1_proxy = matplotlib.lines.Line2D([0], [0], linestyle="none", markeredgecolor='r', c='r', marker='o')
        scatter2_proxy = matplotlib.lines.Line2D([0], [0], linestyle="none", markeredgecolor='k', c='k', marker='o')
        scatter3_proxy = matplotlib.lines.Line2D([0], [0], linestyle="none", markeredgecolor='g', c='g', marker='^')
        self._ax.legend([scatter1_proxy, scatter2_proxy, scatter3_proxy], ['query case', 'cases', 'similar'],
                        numpoints=1)

        self._ax.set_xlabel('X position')
        self._ax.set_ylabel('Y position')
        self._ax.set_zlabel('Z position')

        self._fig.canvas.draw()

    def plot_reuse(self, case, case_matches, revised_matches):
        """Plot the reuse result.

        Parameters
        ----------
        case : Case
            The query case.
        case_matches : dict[int, CaseMatch]
            The solution to the problem-solving experience.
        revised_matches : dict[int, CaseMatch]
            The revised solution to the problem-solving experience.

        """
        self._reuse_method.plot_data(case, case_matches, revised_matches)

    def plot_revision(self, case, case_matches):
        """Plot revision results.

        Parameters
        ----------
        case : Case
            The query case.
        case_matches : dict[int, CaseMatch]
            The revised solution to the problem-solving experience.

        """
        self._revision_method.plot_data(case, case_matches)

    def plot_retention(self, case, case_matches):
        """Plot the retention result.

        Parameters
        ----------
        case : Case
            The query case.
        case_matches : dict[int, CaseMatch]
            The corrected solution

        """
        self._retention_method.plot_data(case, case_matches)

    # noinspection PyUnusedLocal
    def case_from_data(self, data):
        """Convert data into a case using the case template.

        Parameters
        ----------
        data :
            The data from which to extract the case. The ``value``
            expressions of the case template refer to it by the name
            ``data``.

        Returns
        -------
        Case :
            The case extracted from the data.

        """
        feature_list = {}

        for key, t in self._case_template.items():
            type_, params = remove_key(t, "type")
            value, params = remove_key(params, "value")
            # NOTE(security): the template's value expression is evaluated
            # with ``eval`` (with ``data`` in scope), so case templates must
            # only ever come from trusted sources.
            feature_list[key] = FeatureFactory.create(type_, key, eval(value), **params)

        return Case(self.get_new_id(), features=feature_list)