from typing import Generic, Iterable, Tuple, Type, Callable
import numpy as np
import pandas as pd
from ..api_types import PeriodObservation
from .._legacy_dataset import TemporalIndexType, FeaturesT
from ..datatypes import Location, add_field, remove_field, TimeSeriesArray, TimeSeriesData
from ..time_period import PeriodRange
from ..time_period.date_util_wrapper import TimeStamp
import dataclasses
class TemporalDataclass(Generic[FeaturesT]):
    '''
    Wraps a dataclass in an object that can be sliced by time period.

    The wrapped object must be a dataclass instance with a ``time_period``
    field; every other field is treated as a per-period value array.
    Call .data() to get the wrapped data back.
    '''

    def __init__(self, data: FeaturesT):
        self._data = data

    def __repr__(self):
        return f'{self.__class__.__name__}({self._data})'

    def _restrict_by_slice(self, period_range: slice):
        '''Slice the wrapped data using ``time_period.searchsorted``.

        ``None`` for start or stop leaves that end of the slice open.
        '''
        assert period_range.step is None
        start, stop = (None, None)
        if period_range.start is not None:
            start = self._data.time_period.searchsorted(period_range.start)
        if period_range.stop is not None:
            # side='right' presumably makes the stop period inclusive,
            # matching the `<=` used in the mask-based fallback -- TODO confirm
            stop = self._data.time_period.searchsorted(period_range.stop, side='right')
        return self._data[start:stop]

    def _padded_fields(self, n_before: int, n_after: int) -> dict:
        '''Return every non-``time_period`` field cast to float and NaN-padded.

        ``n_before`` NaNs are prepended and ``n_after`` NaNs are appended to
        each field.
        '''
        return {
            field.name: np.pad(getattr(self._data, field.name).astype(float),
                               (n_before, n_after),
                               constant_values=np.nan)
            for field in dataclasses.fields(self._data)
            if field.name != 'time_period'
        }

    def fill_to_endpoint(self, end_time_stamp: TimeStamp) -> 'TemporalDataclass[FeaturesT]':
        '''Extend the series with NaN values so that it ends at ``end_time_stamp``.

        Returns ``self`` unchanged when the series already ends there.
        Asserts that ``end_time_stamp`` is not before the current end.
        '''
        if self.end_timestamp == end_time_stamp:
            return self
        old_time_period = self._data.time_period
        n_missing = old_time_period.delta.n_periods(self.end_timestamp, end_time_stamp)
        assert n_missing >= 0, (f'{n_missing} < 0', end_time_stamp, self.end_timestamp)
        new_time_period = PeriodRange(old_time_period.start_timestamp, end_time_stamp, old_time_period.delta)
        return TemporalDataclass(
            self._data.__class__(new_time_period, **self._padded_fields(0, n_missing)))

    def fill_to_range(self, start_timestamp: TimeStamp, end_timestamp: TimeStamp) -> 'TemporalDataclass[FeaturesT]':
        '''Extend the series with NaN values so that it spans the given range.

        Returns ``self`` unchanged when the series already spans exactly
        ``start_timestamp`` to ``end_timestamp``. Asserts that the requested
        range does not start after, or end before, the current series.
        '''
        if self.end_timestamp == end_timestamp and self.start_timestamp == start_timestamp:
            return self
        old_time_period = self._data.time_period
        delta = old_time_period.delta
        # Count periods with delta.n_periods at both ends, the same way
        # fill_to_endpoint does, rather than dividing a timestamp difference.
        n_missing_start = delta.n_periods(start_timestamp, self.start_timestamp)
        n_missing_end = delta.n_periods(self.end_timestamp, end_timestamp)
        assert n_missing_end >= 0, (f'{n_missing_end} < 0', end_timestamp, self.end_timestamp)
        assert n_missing_start >= 0, (f'{n_missing_start} < 0', start_timestamp, self.start_timestamp)
        new_time_period = PeriodRange(start_timestamp, end_timestamp, delta)
        return TemporalDataclass(
            self._data.__class__(new_time_period, **self._padded_fields(n_missing_start, n_missing_end)))

    def restrict_time_period(self, period_range: TemporalIndexType) -> 'TemporalDataclass[FeaturesT]':
        '''Return a new wrapper restricted to the periods in ``period_range``.

        ``period_range`` must be a slice without a step. When the time period
        object supports ``searchsorted`` that is used; otherwise a boolean
        mask with inclusive bounds is built.
        '''
        assert isinstance(period_range, slice)
        assert period_range.step is None
        if hasattr(self._data.time_period, 'searchsorted'):
            return TemporalDataclass(self._restrict_by_slice(period_range))
        mask = np.full(len(self._data.time_period), True)
        if period_range.start is not None:
            mask = mask & (self._data.time_period >= period_range.start)
        if period_range.stop is not None:
            mask = mask & (self._data.time_period <= period_range.stop)
        return TemporalDataclass(self._data[mask])

    def data(self) -> FeaturesT:
        '''Return the wrapped data object.'''
        return self._data

    def to_pandas(self) -> pd.DataFrame:
        '''Return the wrapped data's own ``to_pandas()`` result.'''
        return self._data.to_pandas()

    def join(self, other):
        '''Concatenate this series with ``other`` (``self`` first) into a new wrapper.'''
        return TemporalDataclass(np.concatenate([self._data, other._data]))

    @property
    def start_timestamp(self) -> pd.Timestamp:
        '''Start timestamp of the first period in the series.'''
        return self._data.time_period[0].start_timestamp

    @property
    def end_timestamp(self) -> pd.Timestamp:
        '''End timestamp of the last period in the series.'''
        return self._data.time_period[-1].end_timestamp
# NOTE: a stray "[docs]" source-link marker (documentation-export artifact) was removed here.
class DataSet(Generic[FeaturesT]):
    '''
    Class representing several time series at different locations.

    Internally every location maps to a TemporalDataclass wrapper. Item
    access, ``items()`` and ``values()`` hand out the unwrapped data, while
    ``data()`` and ``get_location()`` hand out the wrappers.
    '''

    def __init__(self, data_dict: dict[str, FeaturesT]):
        # Accept both raw data and already wrapped data for each location.
        self._data_dict = {
            loc: data if isinstance(data, TemporalDataclass) else TemporalDataclass(data)
            for loc, data in data_dict.items()}

    def __repr__(self):
        return f'{self.__class__.__name__}({self._data_dict})'

    def __getitem__(self, location: Location) -> FeaturesT:
        '''Return the unwrapped data for ``location``.'''
        return self._data_dict[location].data()

    def keys(self):
        '''Return the location names.'''
        return self._data_dict.keys()

    def items(self):
        '''Yield ``(location, unwrapped data)`` pairs.'''
        return ((k, d.data()) for k, d in self._data_dict.items())

    def values(self):
        '''Yield the unwrapped data of every location.'''
        return (d.data() for d in self._data_dict.values())

    @property
    def period_range(self) -> PeriodRange:
        '''Time period of the first location.

        NOTE: no check is made that the other locations share this range.
        '''
        first_location = next(iter(self._data_dict))
        return self._data_dict[first_location].data().time_period

    @property
    def start_timestamp(self) -> pd.Timestamp:
        '''Earliest start timestamp over all locations.'''
        return min(data.start_timestamp for data in self._data_dict.values())

    @property
    def end_timestamp(self) -> pd.Timestamp:
        '''Latest end timestamp over all locations.'''
        return max(data.end_timestamp for data in self._data_dict.values())

    def get_locations(self, location: Iterable[Location]) -> 'DataSet[FeaturesT]':
        '''Return a new data set containing only the given locations.'''
        return self.__class__({loc: self._data_dict[loc] for loc in location})

    def get_location(self, location: Location) -> TemporalDataclass[FeaturesT]:
        '''Return the TemporalDataclass wrapper for ``location``.'''
        return self._data_dict[location]

    def restrict_time_period(self, period_range: TemporalIndexType) -> 'DataSet[FeaturesT]':
        '''Return a new data set with every location restricted to ``period_range``.'''
        return self.__class__(
            {loc: data.restrict_time_period(period_range) for loc, data in self._data_dict.items()})

    def locations(self) -> Iterable[Location]:
        '''Return the location names.'''
        return self._data_dict.keys()

    def data(self) -> Iterable[TemporalDataclass[FeaturesT]]:
        '''Return the TemporalDataclass wrappers of all locations.'''
        return self._data_dict.values()

    def _add_location_to_dataframe(self, df, location):
        # Modifies ``df`` in place and returns it for convenient chaining.
        df['location'] = location
        return df

    def to_pandas(self) -> pd.DataFrame:
        ''' Join the pandas frame for all locations with locations as column'''
        tables = [self._add_location_to_dataframe(data.to_pandas(), location) for location, data in
                  self._data_dict.items()]
        return pd.concat(tables)

    def interpolate(self):
        '''Return a new data set where each location's data is replaced by its ``interpolate()`` result.'''
        # items() already yields the unwrapped data, so it must not be unwrapped again.
        return self.__class__(
            {loc: TemporalDataclass(data.interpolate()) for loc, data in self.items()})

    @classmethod
    def _fill_missing(cls, data_dict: dict[str, TemporalDataclass[FeaturesT]]):
        ''' Fill missing values in a dictionary of TemporalDataclasses

        Every entry is padded to the common range spanned by the earliest
        start and the latest end. ``data_dict`` is updated in place and
        returned; an empty dictionary is returned unchanged.
        '''
        if not data_dict:
            return data_dict
        end = max(data.end_timestamp for data in data_dict.values())
        start = min(data.start_timestamp for data in data_dict.values())
        # Only existing keys are reassigned, so mutating while iterating is safe.
        for location, data in data_dict.items():
            data_dict[location] = data.fill_to_range(start, end)
        return data_dict

    @classmethod
    def from_pandas(cls, df: pd.DataFrame, dataclass: Type[FeaturesT], fill_missing=False) -> 'DataSet[FeaturesT]':
        '''
        Create a DataSet from a pandas dataframe.
        The dataframe needs to have a 'location' column, and a 'time_period' column.
        The time_period column needs to have strings that can be parsed into a period.
        All fields in the dataclass need to be present in the dataframe.
        'fill_missing' is forwarded to ``dataclass.from_pandas``: if True, missing values
        will be filled with np.nan, else all the time series need to be consecutive.
        Independently of 'fill_missing', every location is padded with np.nan to the
        common time range spanned by all locations.

        Parameters
        ----------
        df : pd.DataFrame
            The dataframe
        dataclass : Type[FeaturesT]
            The dataclass to use for the time series
        fill_missing : bool, optional
            If missing values should be filled, by default False

        Returns
        -------
        DataSet[FeaturesT]
            The DataSet

        Examples
        --------
        >>> import pandas as pd
        >>> from climate_health.spatio_temporal_data.temporal_dataclass import DataSet
        >>> from climate_health.datatypes import HealthData
        >>> df = pd.DataFrame({'location': ['Oslo', 'Oslo', 'Bergen', 'Bergen'],
        ...                    'time_period': ['2020-01', '2020-02', '2020-01', '2020-02'],
        ...                    'disease_cases': [10, 20, 30, 40]})
        >>> DataSet.from_pandas(df, HealthData)
        '''
        data_dict = {}
        for location, data in df.groupby('location'):
            data_dict[location] = TemporalDataclass(dataclass.from_pandas(data, fill_missing))
        data_dict = cls._fill_missing(data_dict)
        return cls(data_dict)

    def to_csv(self, file_name: str, mode='w'):
        '''Write the ``to_pandas()`` frame to ``file_name`` as CSV.'''
        self.to_pandas().to_csv(file_name, mode=mode)

    @classmethod
    def df_from_pydantic_observations(cls, observations: list[PeriodObservation]) -> TimeSeriesData:
        '''Build a time series object from a list of observations.

        The dataclass is derived from the type of the first observation, so
        ``observations`` must not be empty.
        '''
        df = pd.DataFrame([obs.model_dump() for obs in observations])
        dataclass = TimeSeriesData.create_class_from_basemodel(type(observations[0]))
        return dataclass.from_pandas(df)

    @classmethod
    def from_period_observations(cls, observation_dict: dict[str, list[PeriodObservation]]) -> 'DataSet[TimeSeriesData]':
        '''
        Create a DataSet from a dictionary of PeriodObservations.
        The keys are the location names, and the values are lists of PeriodObservations.

        Parameters
        ----------
        observation_dict : dict[str, list[PeriodObservation]]
            The dictionary of observations

        Returns
        -------
        DataSet[TimeSeriesData]
            The DataSet

        Examples
        --------
        >>> from climate_health.spatio_temporal_data.temporal_dataclass import DataSet
        >>> from climate_health.api_types import PeriodObservation
        >>> class HealthObservation(PeriodObservation):
        ...     disease_cases: int
        >>> observations = {'Oslo': [HealthObservation(time_period='2020-01', disease_cases=10),
        ...                          HealthObservation(time_period='2020-02', disease_cases=20)]}
        >>> dataset = DataSet.from_period_observations(observations)
        >>> dataset.to_pandas()
        '''
        data_dict = {}
        for location, observations in observation_dict.items():
            data_dict[location] = TemporalDataclass(cls.df_from_pydantic_observations(observations))
        return cls(data_dict)

    @classmethod
    def from_csv(cls, file_name: str, dataclass: Type[FeaturesT]) -> 'DataSet[FeaturesT]':
        '''Read ``file_name`` with ``pd.read_csv`` and build a DataSet via ``from_pandas``.'''
        return cls.from_pandas(pd.read_csv(file_name), dataclass)

    def join_on_time(self, other: 'DataSet[FeaturesT]') -> 'DataSet[FeaturesT]':
        ''' Join two DataSets on time. Returns a new DataSet.
        Assumes other is later in time.
        Every location of this data set must also be present in ``other``.
        '''
        return self.__class__({loc: self._data_dict[loc].join(other._data_dict[loc]) for loc in self.locations()})

    def add_fields(self, new_type, **kwargs: dict[str, Callable]):
        '''Return a new data set of ``new_type`` entries with extra fields added.

        Each keyword maps a new field name to a callable; the callable is
        applied to a location's data to produce that field's values.
        '''
        # items() already yields the unwrapped data, so it must not be unwrapped again.
        return self.__class__(
            {loc: add_field(data, new_type, **{key: func(data) for key, func in kwargs.items()})
             for loc, data in self.items()})

    def remove_field(self, field_name, new_class=None):
        '''Return a new data set where ``field_name`` is removed from every location.'''
        # items() already yields the unwrapped data, so it must not be unwrapped again.
        return self.__class__(
            {loc: remove_field(data, field_name, new_class) for loc, data in self.items()})

    @classmethod
    def from_fields(cls, dataclass: type[TimeSeriesData], fields: dict[str, 'DataSet[TimeSeriesArray]']):
        '''Combine one data set per field into a single data set of ``dataclass``.

        The resulting time range runs from the earliest start to the latest
        end over all field data sets; the period delta is taken from the
        first field. Only locations present in every field are kept.
        '''
        start_timestamp = min(data.start_timestamp for data in fields.values())
        end_timestamp = max(data.end_timestamp for data in fields.values())
        period_range = PeriodRange(start_timestamp, end_timestamp, fields[next(iter(fields))].period_range.delta)
        new_dict = {}
        field_names = list(fields.keys())
        common_locations = set.intersection(*[set(field.keys()) for field in fields.values()])
        for location in common_locations:
            # NOTE(review): relies on the per-location object returned by item access
            # offering fill_to_range(...) whose result has a `.value` -- confirm
            # against TimeSeriesArray.
            new_dict[location] = dataclass(period_range, **{field: fields[field][location].fill_to_range(start_timestamp, end_timestamp).value for field in field_names})
        return cls(new_dict)