Source code for instancelib.environment.memory

# Copyright (C) 2021 The InstanceLib Authors. All Rights Reserved.

# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 3 of the License, or (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.

# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

from __future__ import annotations
from abc import ABC

from typing import (
    Generic,
    Iterable,
    Dict,
    Iterator,
    Mapping,
    TypeVar,
    Any,
    Union,
)
from typing_extensions import Self

from ..instances.base import Instance, InstanceProvider
from ..instances.memory import (
    MemoryBucketProvider,
    AbstractMemoryProvider,
)
from ..labels.base import LabelProvider
from ..labels.memory import MemoryLabelProvider

from .base import AbstractEnvironment

from ..typehints import KT, DT, VT, RT, LT

import numpy as np

# Module-level random generator, created once at import time without an
# explicit seed. It is the default ``rng`` argument of
# ``MemoryEnvironment.shuffle``, so every call that relies on the default
# draws from this one shared generator.
DEFAULT_RNG = np.random.default_rng()
# Type variable for the concrete instance class handled by an environment;
# bounded by ``Instance`` (given as a string forward reference).
InstanceType = TypeVar("InstanceType", bound="Instance[Any, Any, Any, Any]")


class AbstractMemoryEnvironment(
    AbstractEnvironment[InstanceType, KT, DT, VT, RT, LT],
    ABC,
    Generic[InstanceType, KT, DT, VT, RT, LT],
):
    """Environments provide an interface that enable you to access all
    data stored in the datasets.
    If there are labels stored in the environment, you can access these
    as well from here.

    There are two important properties in every :class:`Environment`:

    - :meth:`dataset`: Contains all Instances of the original dataset
    - :meth:`labels`: Contains an object that allows you to access labels
      easily

    Besides these properties, this object also provides methods to create
    new :class:`~instancelib.InstanceProvider` objects that contain a
    subset of the set of all instances stored in this environment.

    The environment itself behaves like a mapping from provider names
    (``str``) to the named providers that were registered on it:
    ``name in env``, ``env[name]``, ``env[name] = provider``,
    ``del env[name]``, ``len(env)`` and ``iter(env)`` all operate on the
    named providers.

    Attributes
    ----------
    _public_dataset
        An :class:`InstanceProvider` that contains all original Instances
    _dataset
        An :class:`InstanceProvider` that contains all instances
    _labelprovider
        This object contains all labels
    _named_providers
        All user generated providers that were given a name

    Examples
    --------
    Access the dataset:

    >>> dataset = env.dataset
    >>> instance = next(iter(dataset.values()))

    Access the labels:

    >>> labels = env.labels
    >>> ins_lbls = labels.get_labels(instance)

    Create a train-test split on the dataset (70 % train, 30 % test):

    >>> train, test = env.train_test_split(dataset, 0.70)
    """

    _public_dataset: InstanceProvider[InstanceType, KT, DT, VT, RT]
    """An :class:`~instancelib.InstanceProvider` that contains all original Instances"""
    _dataset: InstanceProvider[InstanceType, KT, DT, VT, RT]
    """An :class:`InstanceProvider` that contains all instances"""
    _labelprovider: LabelProvider[KT, LT]
    """This object contains all labels"""
    # NOTE: this class-level default is kept for backward compatibility
    # (reads on an instance that never registered a provider still work),
    # but it is one dict object shared by every instance that does not
    # rebind the attribute. All writes therefore go through
    # ``_writable_named_providers`` so they never touch the shared object.
    _named_providers: Dict[
        str, InstanceProvider[InstanceType, KT, DT, VT, RT]
    ] = dict()
    """All user generated providers that were given a name"""

    def _writable_named_providers(
        self,
    ) -> Dict[str, InstanceProvider[InstanceType, KT, DT, VT, RT]]:
        """Return a named-provider registry that is private to this instance.

        If the instance has not bound its own ``_named_providers`` yet
        (i.e. attribute lookup would fall through to the class), a copy of
        the currently visible registry is bound on the instance first, so
        that a mutation on one environment cannot leak into another.

        Returns
        -------
        Dict[str, InstanceProvider[InstanceType, KT, DT, VT, RT]]
            The instance-owned registry, safe to mutate
        """
        if "_named_providers" not in vars(self):
            self._named_providers = dict(self._named_providers)
        return self._named_providers

    def __contains__(self, __o: object) -> bool:
        """Return whether a provider is registered under the name ``__o``."""
        return __o in self._named_providers

    def __getitem__(
        self, __k: str
    ) -> InstanceProvider[InstanceType, KT, DT, VT, RT]:
        """Return the provider registered under the name ``__k``.

        Raises
        ------
        KeyError
            If no provider is registered under that name
        """
        return self._named_providers[__k]

    def __setitem__(
        self, __k: str, __v: InstanceProvider[InstanceType, KT, DT, VT, RT]
    ) -> None:
        """Register provider ``__v`` under the name ``__k``.

        Delegates to :meth:`set_named_provider`.
        """
        self.set_named_provider(__k, __v)

    def __len__(self) -> int:
        """Return the number of named providers."""
        return len(self._named_providers)

    def __delitem__(self, __v: str) -> None:
        """Remove the provider registered under the name ``__v``.

        Raises
        ------
        KeyError
            If no provider is registered under that name
        """
        del self._writable_named_providers()[__v]

    def __iter__(self) -> Iterator[str]:
        """Iterate over the names of the registered providers."""
        return iter(self._named_providers)

    @property
    def dataset(self) -> InstanceProvider[InstanceType, KT, DT, VT, RT]:
        """The provider stored in ``_public_dataset``."""
        return self._public_dataset

    @property
    def all_instances(self) -> InstanceProvider[InstanceType, KT, DT, VT, RT]:
        """The provider stored in ``_dataset``."""
        return self._dataset

    @property
    def labels(self) -> LabelProvider[KT, LT]:
        """The label provider stored in ``_labelprovider``."""
        return self._labelprovider

    def create_bucket(
        self, keys: Iterable[KT]
    ) -> InstanceProvider[InstanceType, KT, DT, VT, RT]:
        """Create a provider for the given keys, backed by ``_dataset``.

        Parameters
        ----------
        keys : Iterable[KT]
            The keys of the instances the new provider should contain

        Returns
        -------
        InstanceProvider[InstanceType, KT, DT, VT, RT]
            A :class:`MemoryBucketProvider` constructed from ``_dataset``
            and ``keys``
        """
        return MemoryBucketProvider[InstanceType, KT, DT, VT, RT](
            self._dataset, keys
        )

    def create_empty_provider(
        self,
    ) -> InstanceProvider[InstanceType, KT, DT, VT, RT]:
        """Create a provider without any keys.

        Returns
        -------
        InstanceProvider[InstanceType, KT, DT, VT, RT]
            The result of :meth:`create_bucket` with an empty key list
        """
        return self.create_bucket([])

    def set_named_provider(
        self, name: str, value: InstanceProvider[InstanceType, KT, DT, VT, RT]
    ) -> None:
        """Register ``value`` under ``name``, replacing any previous entry.

        Parameters
        ----------
        name : str
            The name under which the provider is stored
        value : InstanceProvider[InstanceType, KT, DT, VT, RT]
            The provider to store
        """
        self._writable_named_providers()[name] = value

    def create_named_provider(
        self, name: str
    ) -> InstanceProvider[InstanceType, KT, DT, VT, RT]:
        """Create an empty provider, register it under ``name`` and return it.

        An existing provider with the same name is replaced.

        Parameters
        ----------
        name : str
            The name under which the new provider is stored

        Returns
        -------
        InstanceProvider[InstanceType, KT, DT, VT, RT]
            The newly created (empty) provider
        """
        provider = self.create_empty_provider()
        self._writable_named_providers()[name] = provider
        return provider
class MemoryEnvironment(
    AbstractMemoryEnvironment[InstanceType, KT, DT, VT, RT, LT],
    Generic[InstanceType, KT, DT, VT, RT, LT],
):
    """This class implements the :class:`~abc.ABC`
    :class:`~instancelib.Environment`.
    In this method, all data is loaded and stored in RAM and
    is not preserved on disk.

    There are two important properties in every :class:`Environment`:

    - :meth:`dataset`: Contains all Instances of the original dataset
    - :meth:`labels`: Contains an object that allows you to access labels
      easily

    Besides these properties, this object also provides methods to create
    new :class:`~instancelib.InstanceProvider` objects that contain a
    subset of the set of all instances stored in this environment.

    Parameters
    ----------
    dataset : InstanceProvider[InstanceType, KT, DT, VT, RT]
        An InstanceProvider that contains all Instances
    labelprovider : LabelProvider[KT, LT]
        The label provider that contains the labels associated with
        the instances from the :data:`dataset` variable

    Attributes
    ----------
    _public_dataset
        An :class:`InstanceProvider` that contains all original Instances
    _dataset
        An :class:`InstanceProvider` that contains all instances
    _labelprovider
        This object contains all labels
    _named_providers
        All user generated providers that were given a name

    Examples
    --------
    Access the dataset:

    >>> dataset = env.dataset
    >>> instance = next(iter(dataset.values()))

    Access the labels:

    >>> labels = env.labels
    >>> ins_lbls = labels.get_labels(instance)

    Create a train-test split on the dataset (70 % train, 30 % test):

    >>> train, test = env.train_test_split(dataset, 0.70)

    Store the environment to disk:

    >>> import pickle
    >>> with open("file.pkl", "wb") as fh:
    ...     pickle.dump(env, fh)
    >>> print("The file is saved to file.pkl")

    Load the environment from disk:

    >>> import pickle
    >>> with open("file.pkl", "rb") as fh:
    ...     env = pickle.load(fh)
    >>> dataset = env.dataset
    """

    def __init__(
        self,
        dataset: InstanceProvider[InstanceType, KT, DT, VT, RT],
        labelprovider: LabelProvider[KT, LT],
    ):
        """Build an in-memory environment around a provider and its labels.

        Parameters
        ----------
        dataset : InstanceProvider[InstanceType, KT, DT, VT, RT]
            The provider that holds all instances; stored as ``_dataset``.
            The public ``dataset`` is a :class:`MemoryBucketProvider`
            constructed from this provider and its ``key_list`` as read
            at construction time.
        labelprovider : LabelProvider[KT, LT]
            The provider that holds the labels; stored as
            ``_labelprovider``
        """
        self._dataset = dataset
        self._public_dataset = MemoryBucketProvider[
            InstanceType, KT, DT, VT, RT
        ](dataset, dataset.key_list)
        self._labelprovider = labelprovider
        # Every instance gets its own registry of named providers.
        self._named_providers = dict()

    @classmethod
    def shuffle(
        cls, env: Self, rng: np.random.Generator = DEFAULT_RNG
    ) -> Self:
        """Create a new environment in which the identifiers are shuffled.

        A random permutation of the keys of ``env.all_instances`` is drawn
        and turned into a ``key -> permuted key`` mapping. That mapping is
        passed to the ``shuffle`` method of the instance provider and to
        ``translate_keys`` of :class:`MemoryLabelProvider`.

        Only ``env.all_instances`` and ``env.labels`` are passed to the
        constructor of the result; the named providers of ``env`` are not.

        Parameters
        ----------
        env : Self
            The environment that needs to be shuffled
        rng : np.random.Generator, optional
            The random generator used to draw the permutation.
            Defaults to the module-level ``DEFAULT_RNG``; pass a seeded
            generator for reproducible results.

        Returns
        -------
        Self
            A new environment. If ``env.all_instances`` is not an
            :class:`AbstractMemoryProvider`, nothing is shuffled and the
            new environment wraps the same provider and labels.
        """
        instances = env.all_instances
        if not isinstance(instances, AbstractMemoryProvider):
            # No shuffle implementation is used for other provider types;
            # fall back to an unshuffled environment, as before.
            return cls(instances, env.labels)

        # Work on explicit copies: ``rng.shuffle`` permutes in place, so
        # the two lists must be distinct objects and must not alias any
        # list owned by the provider.
        keys = list(instances.key_list)
        permutation = list(keys)
        rng.shuffle(permutation)  # type: ignore
        mapping = dict(zip(keys, permutation))

        newprovider = instances.shuffle(instances, mapping)
        if isinstance(env.labels, MemoryLabelProvider):
            newlabels = env.labels.translate_keys(env.labels, mapping)
        else:
            newlabels = MemoryLabelProvider.translate_keys(
                env.labels, mapping
            )
        return cls(newprovider, newlabels)