from abc import ABC, abstractmethod
from dataclasses import dataclass
from threading import local
from typing import (
Any,
Callable,
FrozenSet,
Generic,
Iterable,
Iterator,
Mapping,
MutableMapping,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)
import numpy.typing as npt
import pandas as pd
from instancelib.environment.memory import MemoryEnvironment
from instancelib.utils.func import union
from ..typehints import DT, KT, RT, VT
from ..utils.chunks import divide_iterable_in_lists
from .base import Instance, InstanceProvider
from .external import ExternalProvider
from .hdf5 import HDF5VectorInstanceProvider
from .memory import AbstractMemoryProvider
# Type variable for the concrete instance type handled by a provider;
# it is bound to ``Instance`` so only Instance subclasses are accepted.
IT = TypeVar("IT", bound="Instance[Any, Any, Any, Any]")
class ReadOnlyDataset(
    Mapping[KT, DT],
    ABC,
    Generic[KT, DT],
):
    """Abstract read-only mapping from identifiers to raw data points.

    Subclasses have to supply item lookup, the set of known identifiers
    and a membership test (plus ``__len__``, which ``Mapping`` itself
    declares abstract). Iteration and bulk retrieval are derived from
    those primitives.
    """

    @abstractmethod
    def __getitem__(self, __k: KT) -> DT:
        """Return the data point stored under the given identifier."""
        raise NotImplementedError

    def get_bulk(self, keys: Sequence[KT]) -> Sequence[DT]:
        """Fetch several data points at once.

        Parameters
        ----------
        keys
            The identifiers to look up.

        Returns
        -------
        Sequence[DT]
            A list with one data point per key, in the order of ``keys``.
            This default performs one item lookup per key.
        """
        return list(map(self.__getitem__, keys))

    @property
    @abstractmethod
    def identifiers(self) -> FrozenSet[KT]:
        """The identifiers of every data point in this dataset."""
        raise NotImplementedError

    @abstractmethod
    def __contains__(self, __o: object) -> bool:
        """Tell whether the given object is an identifier of this dataset.

        Declared abstract so subclasses must override it; an override
        may still fall back on the ``Mapping`` implementation through
        ``super()``.
        """
        return super().__contains__(__o)

    def __iter__(self) -> Iterator[KT]:
        """Iterate over the identifiers of this dataset."""
        known_ids = self.identifiers
        return iter(known_ids)
class PandasDataset(ReadOnlyDataset[int, Any]):
    """Read-only dataset backed by a single column of a pandas DataFrame.

    The identifiers are the row positions ``0 .. len(df) - 1``; the value
    for an identifier is the cell of ``data_col`` in that row.

    Parameters
    ----------
    df
        The DataFrame that holds the data.
    data_col
        Name of the column whose cells are the data points.
    """

    def __init__(self, df: pd.DataFrame, data_col: str) -> None:
        self.df = df
        self.data_col = data_col
        # The identifier set is a snapshot taken at construction time.
        self.ids = frozenset(range(len(self.df)))

    def __getitem__(self, __k: int) -> Any:
        """Return the cell of ``data_col`` at row position ``__k``.

        Raises
        ------
        KeyError
            If ``__k`` is not one of this dataset's identifiers. Raising
            ``KeyError`` (instead of the ``IndexError`` that ``iloc``
            produces, or silently wrapping negative positions) keeps
            ``Mapping.get`` working and keeps lookups consistent with
            ``__contains__``.
        """
        if __k not in self.ids:
            raise KeyError(__k)
        # Select the column first: selecting the row first builds a Series
        # across all columns, which upcasts values in mixed-dtype frames
        # (e.g. an int cell would come back as a float).
        return self.df[self.data_col].iloc[__k]

    def __len__(self) -> int:
        """Return the number of rows of the underlying DataFrame."""
        return len(self.df)

    @property
    def identifiers(self) -> FrozenSet[int]:
        """The row positions that serve as identifiers."""
        return self.ids

    def __contains__(self, __o: object) -> bool:
        """Tell whether ``__o`` is one of the row-position identifiers."""
        return __o in self.ids

    def get_bulk(self, keys: Sequence[int]) -> Sequence[Any]:
        """Fetch the cells of ``data_col`` for several row positions.

        Parameters
        ----------
        keys
            Row positions to fetch; any sequence of ints is accepted.

        Returns
        -------
        Sequence[Any]
            A list of the values, in the order of ``keys``. A plain list
            is returned rather than a label-indexed Series so positional
            indexing of the result behaves as the annotation promises.

        Raises
        ------
        KeyError
            If any key is not one of this dataset's identifiers.
        """
        # Materialize as a list: ``iloc`` interprets a tuple as a
        # (row, column) indexer instead of a list of row positions.
        positions = list(keys)
        for position in positions:
            if position not in self.ids:
                raise KeyError(position)
        column = self.df[self.data_col]
        return list(column.iloc[positions])
class ReadOnlyProvider(
    InstanceProvider[IT, KT, DT, npt.NDArray[Any], RT], Generic[IT, KT, DT, RT]
):
    """Instance provider layered on top of a read-only dataset.

    Three stores are consulted: ``local_data`` (another instance
    provider), ``instance_cache`` (instances already built from the
    dataset) and ``dataset`` (the read-only source of raw data).

    Parameters
    ----------
    dataset
        The read-only dataset from which instances are built on demand.
    from_data_builder
        Callable that turns an identifier and its raw data into an
        instance.
    local_data
        Provider holding the instances that are stored locally.
    """

    local_data: InstanceProvider[IT, KT, DT, npt.NDArray[Any], RT]
    _stores: Sequence[Mapping[KT, Any]]

    def __init__(
        self,
        dataset: ReadOnlyDataset[KT, DT],
        from_data_builder: Callable[[KT, DT], IT],
        local_data: InstanceProvider[IT, KT, DT, npt.NDArray[Any], RT],
    ) -> None:
        # The argument must be stored before ``_stores`` is built; without
        # this assignment ``self.local_data`` only exists as a class-level
        # annotation and reading it raises AttributeError.
        self.local_data = local_data
        self.instance_cache = {}
        self.dataset = dataset
        self._stores = (self.local_data, self.instance_cache, self.dataset)
        self.from_data_builder = from_data_builder

    def build_from_external(self, k: KT) -> IT:
        """Build an instance for ``k`` from the raw data in the dataset."""
        data = self.dataset[k]
        return self.from_data_builder(k, data)

    def update_external(
        self, ins: Instance[KT, DT, npt.NDArray[Any], RT]
    ) -> None:
        """Delegate unchanged to the parent class implementation."""
        # NOTE(review): relies on a parent class providing
        # ``update_external`` -- confirm against ``InstanceProvider``.
        return super().update_external(ins)

    def __getitem__(self, k: KT) -> IT:
        """Return the instance for ``k``.

        The cache is consulted first, then ``local_data``, then the
        dataset. An instance built from the dataset is stored in the
        cache before it is returned.

        Raises
        ------
        KeyError
            If ``k`` is in none of the three stores.
        """
        if k in self.instance_cache:
            return self.instance_cache[k]
        if k in self.local_data:
            return self.local_data[k]
        if k in self.dataset:
            instance = self.build_from_external(k)
            self.instance_cache[k] = instance
            return instance
        raise KeyError(
            f"Instance with key {k} is not present in this provider"
        )

    def __contains__(self, item: object) -> bool:
        """Tell whether ``item`` is a key of at least one store."""
        return any(item in store for store in self._stores)

    def _get_local_keys(self, keys: Iterable[KT]) -> FrozenSet[KT]:
        """Return the subset of ``keys`` present in ``local_data``."""
        return frozenset(self.local_data).intersection(keys)

    def _get_cached_keys(self, keys: Iterable[KT]) -> FrozenSet[KT]:
        """Return the subset of ``keys`` present in the instance cache."""
        return frozenset(self.instance_cache).intersection(keys)

    def _get_external_keys(self, keys: Iterable[KT]) -> FrozenSet[KT]:
        """Return the subset of ``keys`` present in the dataset."""
        return frozenset(self.dataset).intersection(keys)

    def _cached_data(
        self, keys: Iterable[KT], batch_size: int = 200
    ) -> Iterator[Sequence[Tuple[KT, DT]]]:
        """Yield ``(key, data)`` batches read from cached instances."""
        cache = self.instance_cache
        for chunk in divide_iterable_in_lists(keys, batch_size):
            yield [(k, cache[k].data) for k in chunk]

    def _local_data(
        self, keys: Iterable[KT], batch_size: int = 200
    ) -> Iterator[Sequence[Tuple[KT, DT]]]:
        """Return the ``(key, data)`` batches produced by ``local_data``."""
        return self.local_data.data_chunker_selector(keys, batch_size)

    def _external_data(
        self, keys: Iterable[KT], batch_size: int = 200
    ) -> Iterator[Sequence[Tuple[KT, DT]]]:
        """Yield ``(key, data)`` batches fetched in bulk from the dataset."""
        for chunk in divide_iterable_in_lists(keys, batch_size):
            datas = self.dataset.get_bulk(chunk)
            yield list(zip(chunk, datas))

    @property
    def _all_keys(self) -> FrozenSet[KT]:
        """The keys of all stores, combined with the ``union`` helper."""
        return union(*(frozenset(store.keys()) for store in self._stores))

    @property
    def key_list(self) -> Sequence[KT]:
        """A list of every key known to any of the stores."""
        return list(self._all_keys)

    def __iter__(self) -> Iterator[KT]:
        """Iterate over every key known to any of the stores."""
        return iter(self.key_list)

    def data_chunker(
        self, batch_size: int = 200
    ) -> Iterator[Sequence[Tuple[KT, DT]]]:
        """Yield ``(key, data)`` batches covering every known key.

        Parameters
        ----------
        batch_size
            Batch size handed on to :meth:`data_chunker_selector`.
        """
        # ``batch_size`` is forwarded; it was previously dropped, so the
        # caller's value never took effect.
        yield from self.data_chunker_selector(self.key_list, batch_size)

    def data_chunker_selector(
        self, keys: Iterable[KT], batch_size: int = 200
    ) -> Iterator[Sequence[Tuple[KT, DT]]]:
        """Yield ``(key, data)`` batches for the selected keys.

        Keys are served per store: first those in ``local_data``, then
        the remaining ones found in the cache, then the remaining ones
        found in the dataset. Keys found in no store are skipped.

        Parameters
        ----------
        keys
            The identifiers to fetch data for; duplicates are collapsed.
        batch_size
            Batch size handed on to the chunker of every store.
        """
        keyset = frozenset(keys)
        local_keys = self._get_local_keys(keyset)
        # NOTE(review): the local tier goes through the parent class
        # chunker rather than ``_local_data`` -- confirm this is intended.
        yield from super().data_chunker_selector(local_keys, batch_size)
        remaining_keys = keyset.difference(local_keys)
        cached_keys = self._get_cached_keys(remaining_keys)
        # ``batch_size`` is forwarded to the cached and external tiers as
        # well; previously they always used their own default.
        yield from self._cached_data(cached_keys, batch_size)
        remaining_keys = remaining_keys.difference(cached_keys)
        external_keys = self._get_external_keys(remaining_keys)
        yield from self._external_data(external_keys, batch_size)

    def construct(*args: Any, **kwargs: Any) -> IT:
        """Not supported by this provider; always raises.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError

    def create(*args: Any, **kwargs: Any) -> IT:
        """Not supported by this provider; always raises.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError