# Copyright (C) 2021 The InstanceLib Authors. All Rights Reserved.
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 3 of the License, or (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from __future__ import annotations
import collections
from typing import (
Any,
Callable,
Dict,
FrozenSet,
Generic,
Iterable,
Iterator,
Mapping,
Optional,
Sequence,
Set,
Tuple,
TypeVar,
Union,
)
from typing_extensions import Self
from ..instances import Instance
from ..typehints import KT, LT
from ..utils.func import list_unzip, union
from ..utils.to_key import to_key
from .base import LabelProvider
import collections.abc
# Placeholder for the *source* provider's type parameter: the old label type in
# ``rename_labels`` and the old key type in ``translate_keys``.
_T = TypeVar("_T")
class MemoryLabelProvider(LabelProvider[KT, LT], Generic[KT, LT]):
    """A Memory based implementation to test and benchmark AL algorithms.

    Labels are stored in two plain dictionaries that mirror each other: a
    forward index (key -> labels) and an inverse index (label -> keys).
    Iteration, ``len`` and ``in`` operate on the forward index only.
    """

    # All labels this provider was declared with (immutable).
    _labelset: FrozenSet[LT]
    # Forward index: instance key -> labels assigned to that instance.
    _labeldict: Dict[KT, Set[LT]]
    # Inverse index: label -> keys of the instances that carry it.
    _labeldict_inv: Dict[LT, Set[KT]]

    def __init__(
        self,
        labelset: Iterable[LT],
        labeldict: Dict[KT, Set[LT]],
        labeldict_inv: Optional[Dict[LT, Set[KT]]] = None,
    ) -> None:
        """Initialize the provider.

        Parameters
        ----------
        labelset
            The labels of this provider; stored as a ``frozenset``.
        labeldict
            Forward index (key -> labels). Stored by reference, not copied.
        labeldict_inv
            Inverse index (label -> keys). Stored by reference when given;
            when ``None`` it is derived from ``labelset`` and ``labeldict``.
        """
        self._labelset = frozenset(labelset)
        self._labeldict = labeldict
        if labeldict_inv is None:
            labeldict_inv = self._build_inverse(self._labelset, labeldict)
        self._labeldict_inv = labeldict_inv

    @staticmethod
    def _build_inverse(
        labelset: Iterable[LT], labeldict: Mapping[KT, Iterable[LT]]
    ) -> Dict[LT, Set[KT]]:
        """Return the label -> keys index that mirrors ``labeldict``.

        Every label in ``labelset`` gets an entry, even when no key uses it.
        """
        inverse: Dict[LT, Set[KT]] = {label: set() for label in labelset}
        for key, key_labels in labeldict.items():
            for label in key_labels:
                # ``setdefault`` tolerates labels that are missing from
                # ``labelset``, matching what ``set_labels`` accepts.
                inverse.setdefault(label, set()).add(key)
        return inverse

    def __iter__(self) -> Iterator[KT]:
        """Iterate over the keys of the forward index."""
        return iter(self._labeldict)

    def __contains__(self, __o: object) -> bool:
        """Return whether the key of ``__o`` is in the forward index."""
        return to_key(__o) in self._labeldict

    def __len__(self) -> int:
        """Return the number of keys in the forward index."""
        return len(self._labeldict)

    @classmethod
    def from_data(
        cls,
        labelset: Iterable[LT],
        indices: Sequence[KT],
        labels: Sequence[Iterable[LT]],
    ) -> MemoryLabelProvider[KT, LT]:
        """Build a provider from parallel sequences of keys and labelings.

        Parameters
        ----------
        labelset
            The labels of the new provider.
        indices
            The instance keys.
        labels
            The labeling of each key; ``labels[i]`` belongs to ``indices[i]``.
            The sequences are paired with ``zip``, so surplus elements of the
            longer one are ignored.

        Returns
        -------
        MemoryLabelProvider[KT, LT]
            A provider holding the given labelings.
        """
        labelset = frozenset(labelset)
        labeldict = {
            idx: set(labellist) for (idx, labellist) in zip(indices, labels)
        }
        labeldict_inv = cls._build_inverse(labelset, labeldict)
        return cls(labelset, labeldict, labeldict_inv)

    @classmethod
    def from_provider(
        cls, provider: LabelProvider[KT, LT], subset: Iterable[KT] = ()
    ) -> MemoryLabelProvider[KT, LT]:
        """Copy (a subset of) another provider into memory.

        Parameters
        ----------
        provider
            The provider to copy from.
        subset
            Keys to restrict the copy to. When empty (the default), all keys
            of ``provider`` are used.

        Returns
        -------
        MemoryLabelProvider[KT, LT]
            A provider with the labelset of ``provider``. Only keys that
            carry at least one label end up in the forward index, because
            it is rebuilt from the per-label key sets.
        """
        # Materialize before testing for emptiness: iterators and generators
        # are always truthy, even when they yield nothing.
        selected = frozenset(subset)
        instances = selected if selected else frozenset(provider.keys())
        labelset = provider.labelset
        labeldict_inv = {
            label: set(
                provider.get_instances_by_label(label).intersection(instances)
            )
            for label in labelset
        }
        labeldict: Dict[KT, Set[LT]] = {}
        for label, key_list in labeldict_inv.items():
            for key in key_list:
                labeldict.setdefault(key, set()).add(label)
        return cls(labelset, labeldict, labeldict_inv)

    @classmethod
    def from_tuples(
        cls, predictions: Sequence[Tuple[KT, FrozenSet[LT]]]
    ) -> MemoryLabelProvider[KT, LT]:
        """Build a provider from ``(key, labeling)`` tuples.

        The labelset is the union of all labelings. If a key occurs more
        than once, its last labeling wins. An empty sequence yields an empty
        provider.
        """
        # ``frozenset().union()`` is well defined for zero operands, so empty
        # input needs no special case.
        labelset: FrozenSet[LT] = frozenset().union(
            *(labeling for (_, labeling) in predictions)
        )
        labeldict = {key: set(labeling) for (key, labeling) in predictions}
        return cls(labelset, labeldict, None)

    @property
    def labelset(self) -> FrozenSet[LT]:
        """The labels this provider was declared with."""
        return self._labelset

    def remove_labels(
        self, instance: Union[KT, Instance[KT, Any, Any, Any]], *labels: LT
    ) -> None:
        """Remove ``labels`` from an instance.

        Labels the instance does not carry are ignored.

        Raises
        ------
        KeyError
            If the instance's key is not in the forward index.
        """
        key = to_key(instance)
        if key not in self._labeldict:
            raise KeyError("Key {} is not found".format(key))
        key_labels = self._labeldict[key]
        for label in labels:
            key_labels.discard(label)
            # A label that was never registered has no inverse entry;
            # removing it is a no-op rather than an error.
            label_keys = self._labeldict_inv.get(label)
            if label_keys is not None:
                label_keys.discard(key)

    def set_labels(
        self, instance: Union[KT, Instance[KT, Any, Any, Any]], *labels: LT
    ) -> None:
        """Add ``labels`` to an instance; labels it already has are kept.

        Unknown keys and unknown labels are registered on the fly. Calling
        this without labels changes nothing.
        """
        key = to_key(instance)
        if not labels:
            return
        self._labeldict.setdefault(key, set()).update(labels)
        for label in labels:
            self._labeldict_inv.setdefault(label, set()).add(key)

    def get_labels(
        self, instance: Union[KT, Instance[KT, Any, Any, Any]]
    ) -> FrozenSet[LT]:
        """Return the labels of an instance (empty if the key is unknown)."""
        key = to_key(instance)
        return frozenset(self._labeldict.get(key, ()))

    def get_instances_by_label(self, label: LT) -> FrozenSet[KT]:
        """Return the keys that carry ``label`` (empty if it is unknown)."""
        # Plain lookup: a query must not add entries to the inverse index.
        return frozenset(self._labeldict_inv.get(label, ()))

    def document_count(self, label: LT) -> int:
        """Return the number of keys that carry ``label``."""
        return len(self.get_instances_by_label(label))

    @classmethod
    def rename_labels(
        cls,
        provider: LabelProvider[KT, _T],
        mapping: Union[Mapping[_T, LT], Callable[[_T], LT]],
    ) -> MemoryLabelProvider[KT, LT]:
        """Copy ``provider`` with every label translated through ``mapping``.

        Parameters
        ----------
        provider
            The provider whose labels are translated.
        mapping
            A mapping from old to new labels, or a callable that does the
            same. Old labels that map to the same new label are merged.

        Returns
        -------
        MemoryLabelProvider[KT, LT]
            A new provider with the same keys and the translated labels.
        """
        mapper = (
            mapping.__getitem__
            if isinstance(mapping, collections.abc.Mapping)
            else mapping
        )
        labeldict = {
            key: {mapper(old_label) for old_label in old_labels}
            for key, old_labels in provider.items()
        }
        labelset = frozenset(mapper(lbl) for lbl in provider.labelset)
        return cls(labelset, labeldict, None)  # type: ignore

    @classmethod
    def translate_keys(
        cls, provider: LabelProvider[_T, LT], mapping: Mapping[_T, KT]
    ) -> Self:
        """Copy ``provider`` with every key replaced by ``mapping[key]``.

        If several old keys map to the same new key, the labels of the one
        iterated last are kept.

        Raises
        ------
        KeyError
            If a key of ``provider`` is missing from ``mapping``.
        """
        new_dict = {mapping[k]: set(v) for k, v in provider.items()}
        return cls(provider.labelset, new_dict)