# Copyright (C) 2021 The InstanceLib Authors. All Rights Reserved.
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 3 of the License, or (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from __future__ import annotations
from abc import ABC, abstractmethod
from uuid import UUID, uuid4
from ..utils.func import filter_snd_none
from ..utils.to_key import to_key
import itertools
from typing import (
Any,
Dict,
Generic,
Iterable,
Iterator,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
TypeVar,
Union,
)
from .base import AbstractBucketProvider, Instance, InstanceProvider
from ..typehints import KT, DT, VT, RT
from typing_extensions import Self
# Free type variable; used by ``AbstractMemoryProvider.shuffle`` for the key
# type of the source provider, which may differ from the target key type ``KT``.
_T = TypeVar("_T")
# Type variable for the concrete instance class a provider stores; bound to
# ``Instance`` (given as a string, so it is resolved lazily by type checkers).
InstanceType = TypeVar("InstanceType", bound="Instance[Any, Any, Any, Any]")
class DataPoint(Instance[KT, DT, VT, RT], Generic[KT, DT, VT, RT]):
    """Instance that stores its fields as plain attributes.

    Parameters
    ----------
    identifier : KT
        The key that identifies this instance.
    data : DT
        The raw data of the instance.
    vector : Optional[VT], optional
        The vector of the instance, by default None.
    representation : Optional[RT], optional
        The representation of the instance. When it is ``None``, ``data``
        is stored as the representation instead.
    """

    def __init__(
        self,
        identifier: KT,
        data: DT,
        vector: Optional[VT] = None,
        representation: Optional[RT] = None,
    ) -> None:
        if representation is None:
            # No explicit representation given: fall back to the raw data.
            representation = data  # type: ignore
        self._identifier = identifier
        self._data = data
        self._vector = vector
        self._representation = representation

    @property
    def identifier(self) -> KT:
        """The key that identifies this instance (read/write)."""
        return self._identifier

    @identifier.setter
    def identifier(self, value: KT) -> None:
        self._identifier = value

    @property
    def data(self) -> DT:
        """The raw data of this instance (read-only)."""
        return self._data

    @property
    def vector(self) -> Optional[VT]:
        """The vector of this instance, or ``None`` when unset (read/write)."""
        return self._vector

    @vector.setter
    def vector(self, value: Optional[VT]) -> None:  # type: ignore
        self._vector = value

    @property
    def representation(self) -> RT:
        """The representation of this instance (read-only)."""
        return self._representation

    @classmethod
    def from_instance(cls, instance: Instance[KT, DT, VT, RT]):
        """Build a new object of this class from the fields of ``instance``.

        Parameters
        ----------
        instance : Instance[KT, DT, VT, RT]
            The instance whose identifier, data, vector and representation
            are copied (by reference) into the new object.

        Returns
        -------
        DataPoint
            A new object of type ``cls``.
        """
        fields = (
            instance.identifier,
            instance.data,
            instance.vector,
            instance.representation,
        )
        return cls(*fields)
class AbstractMemoryProvider(
    InstanceProvider[InstanceType, KT, DT, VT, RT],
    ABC,
    Generic[InstanceType, KT, DT, VT, RT],
):
    """Provider that keeps all of its instances in a dictionary.

    Next to the instances, the provider records parent/child links between
    identifiers. Subclasses must implement :meth:`construct`, which the
    alternative constructors use to build instances.

    Parameters
    ----------
    instances : Iterable[InstanceType]
        The instances to store; each is registered under its ``identifier``.
    """

    # Maps identifier -> instance.
    dictionary: Dict[KT, InstanceType]
    # Maps parent identifier -> identifiers of its children.
    children: Dict[KT, Set[KT]]
    # Maps child identifier -> identifier of its parent.
    parents: Dict[KT, KT]

    def __init__(self, instances: Iterable[InstanceType]):
        self.dictionary = {
            instance.identifier: instance for instance in instances
        }
        self.children = dict()
        self.parents = dict()

    def __iter__(self) -> Iterator[KT]:
        yield from self.dictionary

    def __getitem__(self, key: KT) -> InstanceType:
        return self.dictionary[key]

    def __setitem__(self, key: KT, value: InstanceType) -> None:
        self.dictionary[key] = value  # type: ignore

    def __delitem__(self, key: KT) -> None:
        # Delete first: a missing key raises KeyError before any link is touched.
        del self.dictionary[key]
        # Drop the links as well, otherwise get_children() on the former
        # parent would fail on the dangling identifier.
        self._forget_relations(key)

    def __len__(self) -> int:
        return len(self.dictionary)

    def __contains__(self, key: object) -> bool:
        return key in self.dictionary

    def _forget_relations(self, key: KT) -> None:
        """Remove every parent/child link that mentions ``key``."""
        if key in self.parents:
            parent_key = self.parents.pop(key)
            siblings = self.children.get(parent_key)
            if siblings is not None:
                siblings.discard(key)
        for orphan in self.children.pop(key, ()):
            # Only clear pointers that still refer to ``key``; a re-parented
            # child keeps the link to its current parent.
            if orphan in self.parents and self.parents[orphan] == key:
                del self.parents[orphan]

    @property
    def empty(self) -> bool:
        """``True`` when the provider holds no instances."""
        return not self.dictionary

    def get_all(self) -> Iterator[InstanceType]:
        """Yield every stored instance.

        The values are copied into a list first, so the provider may be
        modified while the result is being consumed.
        """
        yield from list(self.values())

    def clear(self) -> None:
        """Remove all instances and all parent/child links."""
        self.dictionary = {}
        # The links refer to identifiers that no longer exist; keeping them
        # would make get_children()/get_parent() fail on dangling keys.
        self.children = {}
        self.parents = {}

    def bulk_get_vectors(
        self, keys: Sequence[KT]
    ) -> Tuple[Sequence[KT], Sequence[VT]]:
        """Look up the vectors of several instances at once.

        Parameters
        ----------
        keys : Sequence[KT]
            The identifiers to look up; each must be present in the provider.

        Returns
        -------
        Tuple[Sequence[KT], Sequence[VT]]
            The result of ``filter_snd_none(keys, vectors)``; presumably the
            keys and vectors of the instances whose vector is not ``None``
            (confirm against ``utils.func.filter_snd_none``).

        Raises
        ------
        KeyError
            If one of the keys is not present in the provider.
        """
        vectors = [self[key].vector for key in keys]
        ret_keys, ret_vectors = filter_snd_none(keys, vectors)  # type: ignore
        return ret_keys, ret_vectors

    def bulk_get_all(self) -> List[InstanceType]:
        """Return all stored instances as a list."""
        return list(self.get_all())

    def add_child(
        self,
        parent: Union[KT, Instance[KT, DT, VT, RT]],
        child: Union[KT, Instance[KT, DT, VT, RT]],
    ) -> None:
        """Register ``child`` as a child of ``parent``.

        Note that a child that already has another parent is not removed
        from that parent's child set; only its parent pointer is replaced.

        Parameters
        ----------
        parent : Union[KT, Instance[KT, DT, VT, RT]]
            The parent, as an identifier or an instance.
        child : Union[KT, Instance[KT, DT, VT, RT]]
            The child, as an identifier or an instance.

        Raises
        ------
        AssertionError
            If parent and child resolve to the same identifier.
        KeyError
            If the parent or the child is not present in this provider.
        """
        parent_key: KT = to_key(parent)
        child_key: KT = to_key(child)
        if parent_key == child_key:
            # Raised explicitly (instead of a bare ``assert``) so the check
            # also runs under ``python -O``; the exception type is unchanged.
            raise AssertionError("An instance cannot be its own parent")
        if parent_key not in self or child_key not in self:
            raise KeyError(
                "Either the parent or child does not exist in this Provider"
            )
        self.children.setdefault(parent_key, set()).add(child_key)
        self.parents[child_key] = parent_key

    def get_children(
        self, parent: Union[KT, Instance[KT, DT, VT, RT]]
    ) -> Sequence[InstanceType]:
        """Return the child instances of ``parent``.

        Returns an empty list when the parent has no registered children.
        """
        parent_key: KT = to_key(parent)
        child_keys = self.children.get(parent_key, ())
        return [self.dictionary[child_key] for child_key in child_keys]  # type: ignore

    def get_children_keys(
        self, parent: Union[KT, Instance[KT, DT, VT, RT]]
    ) -> Sequence[KT]:
        """Return the identifiers of the children of ``parent``.

        Returns an empty list when the parent has no registered children.
        """
        parent_key: KT = to_key(parent)
        return list(self.children.get(parent_key, ()))

    def get_parent(
        self, child: Union[KT, Instance[KT, DT, VT, RT]]
    ) -> InstanceType:
        """Return the parent instance of ``child``.

        Raises
        ------
        KeyError
            If no parent is registered for the child.
        """
        child_key: KT = to_key(child)
        if child_key in self.parents:
            parent_key = self.parents[child_key]
            return self.dictionary[parent_key]  # type: ignore
        raise KeyError(f"The instance with key {child_key} has no parent")

    def discard_children(
        self, parent: Union[KT, Instance[KT, DT, VT, RT]]
    ) -> None:
        """Delete all children of ``parent`` from the provider.

        The parent keeps an (empty) child set. Each child is removed from
        the dictionary together with its parent/child links; children that
        were already removed from the dictionary are skipped. Instances that
        were children of a discarded child stay in the provider but no
        longer have a parent.
        """
        parent_key: KT = to_key(parent)
        if parent_key not in self.children:
            return
        discarded = self.children[parent_key]
        # Rebind before iterating so link clean-up never mutates the set
        # that is being iterated.
        self.children[parent_key] = set()
        for child_key in discarded:
            self.dictionary.pop(child_key, None)
            self._forget_relations(child_key)

    @staticmethod
    @abstractmethod
    def construct(*args: Any, **kwargs: Any) -> InstanceType:
        """Build a single instance from the given arguments.

        The alternative constructors of this class call it positionally
        with ``(identifier, data, vector, representation)``.
        """
        raise NotImplementedError

    @classmethod
    def from_data_and_indices(
        cls,
        indices: Sequence[KT],
        raw_data: Sequence[DT],
        vectors: Optional[Sequence[Optional[VT]]] = None,
    ) -> Self:
        """Create a provider from parallel sequences of keys and data.

        The data doubles as the representation of each instance.

        Parameters
        ----------
        indices : Sequence[KT]
            The identifiers of the instances.
        raw_data : Sequence[DT]
            The data of the instances, aligned with ``indices``. Pairing
            stops at the shorter of the two sequences.
        vectors : Optional[Sequence[Optional[VT]]], optional
            The vectors of the instances. When omitted, or when its length
            differs from ``indices``, every instance gets ``None`` as vector.

        Returns
        -------
        Self
            A provider that contains the constructed instances.
        """
        if vectors is None or len(vectors) != len(indices):
            vectors = [None] * len(indices)
        datapoints = itertools.starmap(
            cls.construct, zip(indices, raw_data, vectors, raw_data)
        )
        return cls(datapoints)

    @classmethod
    def from_data(cls, raw_data: Sequence[DT]) -> Self:
        """Create a provider from data only.

        Instances are keyed by their position in ``raw_data`` (0, 1, ...),
        have no vector, and use the data as their representation.

        Parameters
        ----------
        raw_data : Sequence[DT]
            The data of the instances.

        Returns
        -------
        Self
            A provider that contains the constructed instances.
        """
        datapoints = (
            cls.construct(index, data, None, data)
            for index, data in enumerate(raw_data)
        )
        return cls(datapoints)

    @classmethod
    def shuffle(
        cls,
        provider: InstanceProvider[InstanceType, _T, DT, VT, RT],
        mapping: Mapping[_T, KT],
    ) -> Self:
        """Re-key and reorder the provider according to the given mapping

        Parameters
        ----------
        provider : InstanceProvider[InstanceType, _T, DT, VT, RT]
            The provider that needs to be reordered
        mapping : Mapping[_T, KT]
            The mapping that maps old identifiers to new identifiers; it
            must contain every identifier that occurs in ``provider``

        Returns
        -------
        Self
            A new provider with freshly constructed instances that carry
            the new identifiers, inserted in ascending order of those
            identifiers

        Raises
        ------
        KeyError
            If an identifier of ``provider`` is missing from ``mapping``
        """
        instances = itertools.starmap(
            cls.construct,
            sorted(
                (
                    (
                        mapping[ins.identifier],
                        ins.data,
                        ins.vector,
                        ins.representation,
                    )
                    for ins in provider.values()
                ),
                # Sort on the new identifier only; the remaining fields
                # need not be comparable.
                key=lambda x: x[0],  # type: ignore
            ),
        )
        return cls(instances)
class DataPointProvider(
    AbstractMemoryProvider[
        DataPoint[Union[KT, UUID], DT, VT, RT], Union[KT, UUID], DT, VT, RT
    ],
    Generic[KT, DT, VT, RT],
):
    """Memory provider that stores :class:`DataPoint` instances.

    Identifiers are either of the caller-chosen type ``KT`` or a ``UUID``
    generated by :meth:`create`.
    """

    @staticmethod
    def construct(
        *args: Any, **kwargs: Any
    ) -> DataPoint[Union[KT, UUID], DT, VT, RT]:
        """Build a :class:`DataPoint` from the given constructor arguments.

        The arguments are forwarded unchanged to ``DataPoint``; the new
        instance is not added to any provider.
        """
        return DataPoint[Union[KT, UUID], DT, VT, RT](*args, **kwargs)

    def create(
        self, *args: Any, **kwargs: Any
    ) -> DataPoint[Union[KT, UUID], DT, VT, RT]:
        """Build a :class:`DataPoint` with a fresh UUID key and store it.

        Parameters
        ----------
        *args, **kwargs
            Forwarded to ``DataPoint`` after the generated identifier
            (so: data, vector, representation).

        Returns
        -------
        DataPoint
            The new instance, after it has been added to this provider.
        """
        datapoint_type = DataPoint[Union[KT, UUID], DT, VT, RT]
        instance = datapoint_type(uuid4(), *args, **kwargs)
        self.add(instance)
        return instance
class MemoryBucketProvider(
    AbstractBucketProvider[InstanceType, KT, DT, VT, RT],
    Generic[InstanceType, KT, DT, VT, RT],
):
    """Bucket provider that keeps its member identifiers in a set.

    The underscore-prefixed methods below are presumably the storage hooks
    that ``AbstractBucketProvider`` builds on (confirm against the base class).

    Parameters
    ----------
    dataset : InstanceProvider[InstanceType, KT, DT, VT, RT]
        The provider that holds the actual instances.
    instances : Iterable[KT]
        The identifiers that are initially part of this bucket.
    """

    def __init__(
        self,
        dataset: InstanceProvider[InstanceType, KT, DT, VT, RT],
        instances: Iterable[KT],
    ):
        members: Set[KT] = set()
        members.update(instances)
        self._elements = members
        self.dataset = dataset

    def _add_to_bucket(self, key: KT) -> None:
        """Put ``key`` in the bucket."""
        self._elements.add(key)

    def _remove_from_bucket(self, key: KT) -> None:
        """Take ``key`` out of the bucket; absent keys are ignored."""
        self._elements.discard(key)

    def _clear_bucket(self) -> None:
        """Empty the bucket by replacing the member set with a new one."""
        self._elements = set()

    def _in_bucket(self, key: KT) -> bool:
        """Tell whether ``key`` is in the bucket."""
        return key in self._elements

    def _len_bucket(self) -> int:
        """Return the number of identifiers in the bucket."""
        return len(self._elements)

    @property
    def _bucket(self) -> Iterable[KT]:
        """A fresh iterator over the identifiers in the bucket."""
        return iter(self._elements)

    @property
    def empty(self) -> bool:
        """``True`` when the bucket contains no identifiers."""
        return len(self._elements) == 0