# Source code for composeml.label_maker

from sys import stdout

from tqdm import tqdm

from composeml.data_slice import DataSliceGenerator
from composeml.label_search import ExampleSearch, LabelSearch
from composeml.label_times import LabelTimes


class LabelMaker:
    """Automatically makes labels for prediction problems."""

    def __init__(self, target_entity, time_index, labeling_function=None, window_size=None):
        """Creates an instance of label maker.

        Args:
            target_entity (str): Entity on which to make labels.
            time_index (str): Name of time column in the data frame.
            labeling_function (function or list(function) or dict(str=function)): Function, list of functions,
                or dictionary of functions that transform a data slice. When set as a dictionary,
                the key is used as the name of the labeling function.
            window_size (str or int): Size of the data slices. As a string, the value can be a timedelta or
                a column in the data frame to group by. As an integer, the value can be the number of rows.
                Default value is all future data.
        """
        # Goes through the property setter, which normalizes the value to a dictionary.
        self.labeling_function = labeling_function or {}
        self.target_entity = target_entity
        self.time_index = time_index
        self.window_size = window_size

    def _name_labeling_function(self, function):
        """Gets the name of a labeling function.

        Falls back to the type name for callables without a ``__name__`` attribute.
        """
        if hasattr(function, '__name__'):
            return function.__name__
        return type(function).__name__

    def _check_labeling_function(self, function, name=None):
        """Checks whether the labeling function is callable.

        Args:
            function: Object to validate.
            name (str): Unused. Kept for backward compatibility.

        Returns:
            function: The same object that was passed in.

        Raises:
            AssertionError: If the object is not callable.
        """
        assert callable(function), 'labeling function must be callabe'
        return function

    @property
    def labeling_function(self):
        """Gets the labeling function(s)."""
        return self._labeling_function

    @labeling_function.setter
    def labeling_function(self, value):
        """Sets and formats the initial labeling function(s).

        A single function or a list/tuple of functions is converted to a
        dictionary that maps each function name to the function.

        Args:
            value (function or list(function) or dict(str=function)): Function that transforms a data slice to a label.

        Raises:
            AssertionError: If a function is not callable, a dictionary key is not a string,
                or the value type is not supported.
        """
        if isinstance(value, dict):
            for name, function in value.items():
                self._check_labeling_function(function)
                assert isinstance(name, str), 'labeling function name must be string'

        if callable(value):
            value = [value]

        if isinstance(value, (tuple, list)):
            value = {
                self._name_labeling_function(function): self._check_labeling_function(function)
                for function in value
            }

        assert isinstance(value, dict), 'value type for labeling function not supported'
        self._labeling_function = value

    def slice(self, df, num_examples_per_instance, minimum_data=None, maximum_data=None, gap=None, drop_empty=True):
        """Generates data slices of target entity.

        Args:
            df (DataFrame): Data frame to create slices on.
            num_examples_per_instance (int): Number of examples per unique instance of target entity.
            minimum_data (str): Minimum data before starting the search. Default value is first time of index.
            maximum_data (str): Maximum data before stopping the search. Default value is last time of index.
            gap (str or int): Time between examples. Default value is window size.
                If an integer, search will start on the first event after the minimum data.
            drop_empty (bool): Whether to drop empty slices. Default value is True.

        Returns:
            ds (generator): Returns a generator of data slices.
        """
        self._check_example_count(num_examples_per_instance, gap)
        df = self.set_index(df)
        entity_groups = df.groupby(self.target_entity)
        num_examples_per_instance = ExampleSearch._check_number(num_examples_per_instance)

        generator = DataSliceGenerator(
            window_size=self.window_size,
            min_data=minimum_data,
            max_data=maximum_data,
            drop_empty=drop_empty,
            gap=gap,
        )

        for entity_id, group in entity_groups:
            for ds in generator(group):
                # Record which entity instance the slice belongs to.
                setattr(ds.context, self.target_entity, entity_id)
                yield ds

                if ds.context.slice_number >= num_examples_per_instance:
                    break

    @property
    def _bar_format(self):
        """Template to format the progress bar during a label search."""
        return (
            "Elapsed: {elapsed} | "
            "Remaining: {remaining} | "
            "Progress: {l_bar}{bar}| "
            + self.target_entity + ": {n}/{total} "
        )

    def _run_search(
        self,
        df,
        generator,
        search,
        verbose=True,
        *args,
        **kwargs,
    ):
        """Search implementation to make label records.

        Args:
            df (DataFrame): Data frame to search and extract labels.
            generator (DataSliceGenerator): The generator for data slices.
            search (LabelSearch or ExampleSearch): The type of search to be done.
            verbose (bool): Whether to render progress bar. Default value is True.
            *args: Positional arguments for labeling function.
            **kwargs: Keyword arguments for labeling function.

        Returns:
            records (list(dict)): Label Records
        """
        df = self.set_index(df)
        entity_groups = df.groupby(self.target_entity)
        multiplier = search.expected_count if search.is_finite else 1
        total = entity_groups.ngroups * multiplier

        records = []
        progress_bar = tqdm(
            total=total,
            bar_format=self._bar_format,
            disable=not verbose,
            file=stdout,
        )

        def missing_examples(entity_count):
            return entity_count * search.expected_count - progress_bar.n

        labeling_functions = self.labeling_function.items()

        # The progress bar is closed even when a labeling function raises.
        try:
            for entity_count, (entity_id, group) in enumerate(entity_groups, start=1):
                for ds in generator(group):
                    labels = {name: lf(ds, *args, **kwargs) for name, lf in labeling_functions}
                    if not search.is_valid_labels(labels):
                        continue

                    records.append({
                        self.target_entity: entity_id,
                        'time': ds.context.slice_start,
                        **labels,
                    })

                    search.update_count(labels)

                    # if finite search, progress bar is updated for each example found
                    if search.is_finite:
                        progress_bar.update(n=1)

                    if search.is_complete:
                        break

                # if finite search, progress bar is updated for examples not found
                # otherwise, progress bar is updated for each entity group
                n = missing_examples(entity_count) if search.is_finite else 1
                progress_bar.update(n=n)
                search.reset_count()

            # Fill whatever remains so that the bar ends at the expected total.
            total -= progress_bar.n
            progress_bar.update(n=total)
        finally:
            progress_bar.close()

        return records

    def _check_example_count(self, num_examples_per_instance, gap):
        """Checks whether example count corresponds to data slices.

        The check only applies when neither a window size nor a gap is set.

        Args:
            num_examples_per_instance (int or dict): Expected number of examples per instance.
                For a dictionary, each value is checked as the expected count of one label.
            gap (str or int): Time between examples.

        Raises:
            AssertionError: If more than one example is requested without a window size or gap.
        """
        if self.window_size is None and gap is None:
            if isinstance(num_examples_per_instance, dict):
                # A dictionary cannot be compared to an integer, so check each label count.
                more_than_one = any(count > 1 for count in num_examples_per_instance.values())
            else:
                more_than_one = num_examples_per_instance > 1

            assert not more_than_one, "must specify gap if num_examples > 1 and window size = none"

    def search(self,
               df,
               num_examples_per_instance,
               minimum_data=None,
               maximum_data=None,
               gap=None,
               drop_empty=True,
               verbose=True,
               *args,
               **kwargs):
        """Searches the data to calculates labels.

        Args:
            df (DataFrame): Data frame to search and extract labels.
            num_examples_per_instance (int or dict): The expected number of examples to return from each entity group.
                A dictionary can be used to further specify the expected number of examples to return from each label.
            minimum_data (str): Minimum data before starting the search. Default value is first time of index.
            maximum_data (str): Maximum data before stopping the search. Default value is last time of index.
            gap (str or int): Time between examples. Default value is window size.
                If an integer, search will start on the first event after the minimum data.
            drop_empty (bool): Whether to drop empty slices. Default value is True.
            verbose (bool): Whether to render progress bar. Default value is True.
            *args: Positional arguments for labeling function.
            **kwargs: Keyword arguments for labeling function.

        Returns:
            lt (LabelTimes): Calculated labels with cutoff times.
        """
        assert self.labeling_function, 'missing labeling function(s)'
        self._check_example_count(num_examples_per_instance, gap)

        is_label_search = isinstance(num_examples_per_instance, dict)
        search_type = LabelSearch if is_label_search else ExampleSearch
        label_search = search_type(num_examples_per_instance)

        generator = DataSliceGenerator(
            window_size=self.window_size,
            min_data=minimum_data,
            max_data=maximum_data,
            drop_empty=drop_empty,
            gap=gap,
        )

        # The fixed arguments are passed positionally. Passing them as keywords
        # ahead of *args would bind the first extra positional argument to `df`
        # and raise a TypeError.
        records = self._run_search(
            df,
            generator,
            label_search,
            verbose,
            *args,
            **kwargs,
        )

        lt = LabelTimes(
            data=records,
            target_columns=list(self.labeling_function),
            target_entity=self.target_entity,
            search_settings={
                'num_examples_per_instance': num_examples_per_instance,
                'minimum_data': str(minimum_data),
                'maximum_data': str(maximum_data),
                'window_size': str(self.window_size),
                'gap': str(gap),
            },
        )

        return lt

    def set_index(self, df):
        """Sets the time index in a data frame (if not already set).

        The data frame that is passed in is never modified.

        Args:
            df (DataFrame): Data frame to set time index in.

        Returns:
            df (DataFrame): Data frame with time index set.
        """
        already_indexed = df.index.name == self.time_index
        if not already_indexed:
            # Returns a new data frame.
            df = df.set_index(self.time_index)

        if 'time' not in str(df.index.dtype):
            if already_indexed:
                # Shallow copy so that the new index is not assigned on the caller's data frame.
                df = df.copy(deep=False)

            df.index = df.index.astype('datetime64[ns]')

        return df