Source code for vivarium.core.emitter

"""
========
Emitters
========

Emitters log simulation configuration data and time-series data to a
destination such as standard output, RAM, or a MongoDB database.
"""

import os
import json
import uuid
import itertools
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Callable, Union
from urllib.parse import quote_plus
from concurrent.futures import ProcessPoolExecutor

from pymongo import ASCENDING
from pymongo.errors import DocumentTooLarge
from pymongo.mongo_client import MongoClient
from bson import MinKey, MaxKey

from vivarium.library.units import remove_units
from vivarium.library.dict_utils import (
    value_in_embedded_dict,
    make_path_dict,
    deep_merge_check,
)
from vivarium.library.topology import (
    assoc_path,
    get_in,
    paths_to_dict,
)
from vivarium.core.registry import emitter_registry
from vivarium.core.serialize import (
    make_fallback_serializer_function,
    serialize_value,
    deserialize_value)

HISTORY_INDEXES = [
    'data.time',
    [('experiment_id', ASCENDING),
     ('data.time', ASCENDING),
     ('_id', ASCENDING)],
]

CONFIGURATION_INDEXES = [
    'experiment_id',
]

SECRETS_PATH = 'secrets.json'


def breakdown_data(
        limit: float,
        data: Any,
        path: Tuple = (),
        size: float = None,
) -> list:
    size = size or len(str(data))
    if size > limit:
        if isinstance(data, dict):
            output = []
            subsizes = {}
            total = 0
            for key, subdata in data.items():
                subsizes[key] = len(str(subdata))
                total += subsizes[key]

            order = sorted(
                subsizes.items(),
                key=lambda item: item[1],
                reverse=True)

            remaining = total
            index = 0
            large_keys = []
            while remaining > limit and index < len(order):
                key, subsize = order[index]
                large_keys.append(key)
                remaining -= subsize
                index += 1

            for large_key in large_keys:
                subdata = breakdown_data(
                    limit,
                    data[large_key],
                    path=path + (large_key,),
                    size=subsizes[large_key])
                try:
                    output.extend(subdata)
                except ValueError:
                    print(f'data can not be broken down to size '
                          f'{limit}: {data[large_key]}')

            pruned = {
                key: value
                for key, value in data.items()
                if key not in large_keys}
            output.append((path, pruned))
            return output
        print(f'Data at {path} is too large, skipped: {size} > {limit}')
        return []
    return [(path, data)]

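# Example (mirroring ``test_breakdown`` at the bottom of this module): with a
# 20-character limit, the oversized 'b.Y' entry is split out first, then the
# rest of 'b', and the small remainder stays at the root path.
#
# >>> breakdown_data(20, {'a': [1, 2, 3],
# ...                     'b': {'X': [1, 2, 3, 4, 5],
# ...                           'Y': [1, 2, 3, 4, 5, 6],
# ...                           'Z': [5]}})
# [(('b', 'Y'), [1, 2, 3, 4, 5, 6]), (('b',), {'X': [1, 2, 3, 4, 5], 'Z': [5]}), ((), {'a': [1, 2, 3]})]
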
def get_emitter(config: Optional[Dict[str, str]]) -> 'Emitter':
    """Construct an Emitter using the provided config.

    The available Emitter type names and their classes are:

    * ``database``: :py:class:`DatabaseEmitter`
    * ``null``: :py:class:`NullEmitter`
    * ``print``: :py:class:`Emitter`, prints to stdout
    * ``timeseries``: :py:class:`RAMEmitter`

    Arguments:
        config: Must contain the ``type`` key, which specifies the
            emitter type name (e.g. ``database``).

    Returns:
        A new Emitter instance.
    """
    if config is None:
        config = {}
    emitter_type = config.get('type', 'print')
    emitter: Emitter = emitter_registry.access(emitter_type)(config)
    return emitter

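# Example: build an in-memory emitter from a config dict; the 'timeseries'
# type name maps to ``RAMEmitter``, as listed in the docstring above.
#
# >>> emitter = get_emitter({'type': 'timeseries'})
# >>> isinstance(emitter, RAMEmitter)
# True
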
def path_timeseries_from_data(data: dict) -> dict:
    """Convert from :term:`raw data` to a :term:`path timeseries`."""
    embedded_timeseries = timeseries_from_data(data)
    return path_timeseries_from_embedded_timeseries(embedded_timeseries)

def path_timeseries_from_embedded_timeseries(embedded_timeseries: dict) -> dict:
    """Convert an :term:`embedded timeseries` to a :term:`path timeseries`."""
    times_vector = embedded_timeseries['time']
    path_timeseries = make_path_dict(
        {key: val for key, val in embedded_timeseries.items()
         if key != 'time'})
    path_timeseries['time'] = times_vector
    return path_timeseries

def timeseries_from_data(data: dict) -> dict:
    """Convert :term:`raw data` to an :term:`embedded timeseries`."""
    times_vector = list(data.keys())
    embedded_timeseries: dict = {}
    for value in data.values():
        if isinstance(value, dict):
            embedded_timeseries = value_in_embedded_dict(
                value, embedded_timeseries)
    embedded_timeseries['time'] = times_vector
    return embedded_timeseries

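# Shape sketch for the three data layouts handled above (illustrative values):
#
#   raw data:            {0.0: {'cell': {'mass': 1.0}}, 1.0: {'cell': {'mass': 1.1}}}
#   embedded timeseries: {'cell': {'mass': [1.0, 1.1]}, 'time': [0.0, 1.0]}
#   path timeseries:     {('cell', 'mass'): [1.0, 1.1], 'time': [0.0, 1.0]}
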
class Emitter:
    def __init__(self, config: Dict[str, str]) -> None:
        """Base class for emitters.

        This emitter simply emits to STDOUT.

        Args:
            config: Emitter configuration.
        """
        self.config = config

    def emit(self, data: Dict[str, Any]) -> None:
        """Emit data.

        Args:
            data: The data to emit. This gets called by the Vivarium
                engine with a snapshot of the simulation state.
        """
        print(data)

    def get_data(self, query: list = None) -> dict:
        """Get the emitted data.

        Returns:
            The data that has been emitted to the database in the
            :term:`raw data` format. For this particular class, an
            empty dictionary is returned.
        """
        _ = query
        return {}

    def get_data_deserialized(self, query: list = None) -> Any:
        """Get the emitted data with variable values deserialized.

        Returns:
            The data that has been emitted to the database in the
            :term:`raw data` format. Before being returned, serialized
            values in the data are deserialized.
        """
        return deserialize_value(self.get_data(query))

    def get_data_unitless(self, query: list = None) -> Any:
        """Get the emitted data with units stripped from variable values.

        Returns:
            The data that has been emitted to the database in the
            :term:`raw data` format. Before being returned, units are
            stripped from values.
        """
        return remove_units(self.get_data_deserialized(query))

    def get_path_timeseries(self, query: list = None) -> dict:
        """Get the deserialized data as a :term:`path timeseries`.

        Returns:
            The deserialized emitted data, formatted as a
            :term:`path timeseries`.
        """
        return path_timeseries_from_data(self.get_data_deserialized(query))

    def get_timeseries(self, query: list = None) -> dict:
        """Get the deserialized data as an :term:`embedded timeseries`.

        Returns:
            The deserialized emitted data, formatted as an
            :term:`embedded timeseries`.
        """
        return timeseries_from_data(self.get_data_deserialized(query))

class NullEmitter(Emitter):
    """
    Don't emit anything
    """

    def emit(self, data: Dict[str, Any]) -> None:
        pass

class RAMEmitter(Emitter):
    """
    Accumulate the timeseries history portion of the "emitted" data to
    a table in RAM.
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        super().__init__(config)
        self.saved_data: Dict[float, Dict[str, Any]] = {}
        self.fallback_serializer = make_fallback_serializer_function()
        self.embed_path = config.get('embed_path', tuple())

    def emit(self, data: Dict[str, Any]) -> None:
        """
        Emit the timeseries history portion of ``data``, which is
        ``data['data'] if data['table'] == 'history'`` and put it at
        ``data['data']['time']`` in the history.
        """
        if data['table'] == 'history':
            emit_data = data['data'].copy()
            time = emit_data.pop('time', None)
            data_at_time = assoc_path({}, self.embed_path, emit_data)
            self.saved_data.setdefault(time, {})
            data_at_time = serialize_value(
                data_at_time, self.fallback_serializer)
            deep_merge_check(
                self.saved_data[time], data_at_time, check_equality=True)

    def get_data(self, query: list = None) -> dict:
        """
        Return the accumulated timeseries history of "emitted" data.
        """
        if query:
            returned_data = {}
            for t, data in self.saved_data.items():
                paths_data = []
                for path in query:
                    datum = get_in(data, path)
                    if datum:
                        path_data = (path, datum)
                        paths_data.append(path_data)
                returned_data[t] = paths_to_dict(paths_data)
            return returned_data
        return self.saved_data

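# Minimal usage sketch: feed the emitter the same dict shape the engine passes
# to ``emit`` and read the accumulated table back (illustrative values).
#
# >>> ram = RAMEmitter({})
# >>> ram.emit({'table': 'history', 'data': {'time': 0.0, 'cell': {'mass': 1.0}}})
# >>> ram.emit({'table': 'history', 'data': {'time': 1.0, 'cell': {'mass': 1.1}}})
# >>> ram.get_data()
# {0.0: {'cell': {'mass': 1.0}}, 1.0: {'cell': {'mass': 1.1}}}
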
class SharedRamEmitter(RAMEmitter):
    """
    Accumulate the timeseries history portion of the "emitted" data to
    a table in RAM that is shared across all instances of the emitter.
    """

    saved_data: Dict[float, Dict[str, Any]] = {}

    def __init__(self, config: Dict[str, Any]) -> None:
        # pylint: disable=super-init-not-called
        # We intentionally don't call the superclass constructor
        # because we don't want to create a per-instance ``saved_data``
        # attribute.
        self.fallback_serializer = make_fallback_serializer_function()
        self.embed_path = config.get('embed_path', tuple())

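# Because ``saved_data`` lives on the class rather than on each instance, two
# SharedRamEmitter objects see the same history (illustrative sketch):
#
# >>> a = SharedRamEmitter({})
# >>> b = SharedRamEmitter({})
# >>> a.emit({'table': 'history', 'data': {'time': 0.0, 'x': 1}})
# >>> 0.0 in b.get_data()
# True
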
class DatabaseEmitter(Emitter):
    """
    Emit data to a mongoDB database

    Example:

    >>> config = {
    ...     'host': 'localhost:27017',
    ...     'database': 'DB_NAME',
    ... }
    >>> # The line below works only if you have port 27017 open locally
    >>> # emitter = DatabaseEmitter(config)
    """
    default_host = 'localhost:27017'
    client_dict: Dict[int, MongoClient] = {}

    @classmethod
    def create_indexes(cls, table: Any, columns: List[Any]) -> None:
        """Create the listed column indexes for the given DB table."""
        for column in columns:
            table.create_index(column)

    def __init__(self, config: Dict[str, Any]) -> None:
        """config may have 'host' and 'database' items."""
        super().__init__(config)
        self.experiment_id = config.get('experiment_id')
        # In the worst case, `breakdown_data` can underestimate the
        # size of data by a factor of 4: len(str(0)) == 1 but 0 is a
        # 4-byte int. Use 4 MB as the breakdown limit to stay under
        # MongoDB's 16 MB limit.
        self.emit_limit = config.get('emit_limit', 4000000)
        self.embed_path = config.get('embed_path', tuple())

        # create new MongoClient per OS process
        curr_pid = os.getpid()
        if curr_pid not in DatabaseEmitter.client_dict:
            DatabaseEmitter.client_dict[curr_pid] = MongoClient(
                config.get('host', self.default_host))
        self.client = DatabaseEmitter.client_dict[curr_pid]

        self.db = getattr(self.client, config.get('database', 'simulations'))
        self.history = getattr(self.db, 'history')
        self.configuration = getattr(self.db, 'configuration')
        self.phylogeny = getattr(self.db, 'phylogeny')
        self.create_indexes(self.history, HISTORY_INDEXES)
        self.create_indexes(self.configuration, CONFIGURATION_INDEXES)
        self.create_indexes(self.phylogeny, CONFIGURATION_INDEXES)

        self.fallback_serializer = make_fallback_serializer_function()

    def emit(self, data: Dict[str, Any]) -> None:
        table_id = data['table']
        table = self.db.get_collection(table_id)
        time = data['data'].pop('time', None)
        data['data'] = assoc_path({}, self.embed_path, data['data'])
        # Analysis scripts expect the time to be at the top level of
        # the dictionary, but some emits, like configuration emits,
        # lack a time key.
        if time is not None:
            data['data']['time'] = time
        emit_data = data.copy()
        emit_data.pop('table', None)
        emit_data['experiment_id'] = self.experiment_id
        self.write_emit(table, emit_data)

    def write_emit(self, table: Any, emit_data: Dict[str, Any]) -> None:
        """Check that data size is less than emit limit.

        Break up large emits into smaller pieces and emit them
        individually
        """
        assembly_id = str(uuid.uuid4())
        emit_data = serialize_value(emit_data, self.fallback_serializer)
        try:
            emit_data['assembly_id'] = assembly_id
            table.insert_one(emit_data)
        # If document is too large, break up into smaller dictionaries
        # with shared assembly IDs and time keys
        except DocumentTooLarge:
            emit_data.pop('assembly_id')
            experiment_id = emit_data.pop('experiment_id')
            time = emit_data['data'].pop('time', None)
            broken_down_data = breakdown_data(self.emit_limit, emit_data)
            for (path, datum) in broken_down_data:
                d: Dict[str, Any] = {}
                assoc_path(d, path, datum)
                d['assembly_id'] = assembly_id
                d['experiment_id'] = experiment_id
                if time:
                    d.setdefault('data', {})
                    d['data']['time'] = time
                table.insert_one(d)

    def get_data(self, query: list = None) -> dict:
        return get_history_data_db(self.history, self.experiment_id, query)

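# Usage sketch (assumes a MongoDB server on localhost:27017; values are
# illustrative):
#
# >>> emitter = DatabaseEmitter({'experiment_id': 'exp_1'})
# >>> emitter.emit({'table': 'history', 'data': {'time': 0.0, 'cell': {'mass': 1.0}}})
# >>> emitter.get_data()
# {0.0: {'cell': {'mass': 1.0}}}
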
def get_experiment_database(
        port: Any = 27017,
        database_name: str = 'simulations'
) -> Any:
    """Get a database object.

    Args:
        port: Port number of database. This can usually be left as the
            default.
        database_name: Name of the database table. This can usually be
            left as the default.

    Returns:
        The database object.
    """
    config = {
        'host': '{}:{}'.format('localhost', port),
        'database': database_name}
    emitter = DatabaseEmitter(config)
    db = emitter.db
    return db

def delete_experiment(
        host: str = 'localhost',
        port: Any = 27017,
        query: dict = None
) -> None:
    """Helper function to delete experiment data in parallel

    Args:
        host: Host name of database. This can usually be left as the
            default.
        port: Port number of database. This can usually be left as the
            default.
        query: Filter for documents to delete.
    """
    history_collection = get_local_client(host, port, 'simulations').history
    history_collection.delete_many(query, hint=HISTORY_INDEXES[1])

def delete_experiment_from_database(
        experiment_id: str,
        host: str = 'localhost',
        port: Any = 27017,
        cpus: int = 1
) -> None:
    """Delete an experiment's data from a database.

    Args:
        experiment_id: Identifier of experiment.
        host: Host name of database. This can usually be left as the
            default.
        port: Port number of database. This can usually be left as the
            default.
        cpus: Number of chunks to split the delete operation into, to
            be run in parallel. Useful if a single-threaded delete does
            not saturate I/O.
    """
    db = get_local_client(host, port, 'simulations')
    query = {'experiment_id': experiment_id}
    if cpus > 1:
        chunks = get_data_chunks(db.history, experiment_id, cpus=cpus)
        queries = []
        for chunk in chunks:
            queries.append({
                'experiment_id': experiment_id,
                '_id': {'$gte': chunk[0], '$lt': chunk[1]},
                'data.time': {'$gte': MinKey(), '$lte': MaxKey()}
            })
        partial_del_exp = partial(delete_experiment, host, port)
        with ProcessPoolExecutor(cpus) as executor:
            executor.map(partial_del_exp, queries)
    else:
        db.history.delete_many(query, hint=HISTORY_INDEXES[1])
    db.configuration.delete_many(query)

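# Usage sketch (assumes a local MongoDB): delete everything stored for one
# experiment, splitting the history deletion across four worker processes.
#
# >>> delete_experiment_from_database('exp_1', cpus=4)
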
def assemble_data(data: list) -> dict:
    """re-assemble data"""
    assembly: dict = {}
    for datum in data:
        if 'assembly_id' in datum:
            assembly_id = datum['assembly_id']
            if assembly_id not in assembly:
                assembly[assembly_id] = {}
            deep_merge_check(
                assembly[assembly_id],
                datum['data'],
                check_equality=True,
            )
        else:
            assembly_id = str(uuid.uuid4())
            assembly[assembly_id] = datum['data']
    return assembly

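# Example: documents that share an ``assembly_id`` are merged back into a
# single emit; a document without one would get a fresh UUID key.
#
# >>> assemble_data([
# ...     {'assembly_id': 'abc', 'data': {'time': 0.0, 'x': 1}},
# ...     {'assembly_id': 'abc', 'data': {'y': 2}}])
# {'abc': {'time': 0.0, 'x': 1, 'y': 2}}
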
def apply_func(
        document: Any,
        field: Tuple,
        f: Callable[..., Any] = None,
) -> Any:
    if field[0] not in document:
        return document
    if len(field) != 1:
        document[field[0]] = apply_func(document[field[0]], field[1:], f)
    elif f is not None:
        document[field[0]] = f(document[field[0]])
    return document

def get_query(
        projection: dict,
        host: str,
        port: Any,
        query: dict
) -> list:
    """Helper function for parallel queries

    Args:
        projection: a MongoDB projection in dictionary form
        host, port: used to create new MongoClient for each parallel
            process
        query: a MongoDB query in dictionary form

    Returns:
        List of projected documents for given query
    """
    history_collection = get_local_client(host, port, 'simulations').history
    return list(history_collection.find(
        query, projection, hint=HISTORY_INDEXES[1]))

def get_data_chunks(
        history_collection: Any,
        experiment_id: str,
        start_time: Union[int, MinKey] = MinKey(),
        end_time: Union[int, MaxKey] = MaxKey(),
        cpus: int = 8
) -> list:
    """Helper function to get chunks for parallel queries

    Args:
        history_collection: the MongoDB history collection to query
        experiment_id: the experiment id which is being retrieved
        start_time, end_time: first and last simulation time to query
        cpus: number of chunks to create

    Returns:
        List of ObjectId tuples that represent chunk boundaries. For
        each tuple, include ``{'_id': {$gte: tuple[0], $lt: tuple[1]}}``
        in the query to search its corresponding chunk.
    """
    id_cutoffs = list(history_collection.aggregate(
        [{'$match': {
            'experiment_id': experiment_id,
            'data.time': {'$gte': start_time, '$lte': end_time}}},
         {'$project': {'_id': 1}},
         {'$bucketAuto': {'groupBy': '$_id', 'buckets': cpus}},
         {'$group': {'_id': '', 'splitPoints': {'$push': '$_id.min'}}},
         {'$unset': '_id'}],
        hint={'experiment_id': 1, 'data.time': 1, '_id': 1}
    ))[0]['splitPoints']
    id_ranges = []
    for i in range(len(id_cutoffs) - 1):
        id_ranges.append((id_cutoffs[i], id_cutoffs[i + 1]))
    id_ranges.append((id_cutoffs[-1], MaxKey()))
    return id_ranges

def get_history_data_db(
        history_collection: Any,
        experiment_id: Any,
        query: list = None,
        func_dict: dict = None,
        f: Callable[..., Any] = None,
        filters: Optional[dict] = None,
        start_time: Union[int, MinKey] = MinKey(),
        end_time: Union[int, MaxKey] = MaxKey(),
        cpus: int = 1,
        host: str = 'localhost',
        port: Any = '27017'
) -> Dict[float, dict]:
    """Query MongoDB for history data.

    Args:
        history_collection: a MongoDB collection
        experiment_id: the experiment id which is being retrieved
        query: a list of tuples pointing to fields within the
            experiment data. In the format:
            [('path', 'to', 'field1'), ('path', 'to', 'field2')]
        func_dict: a dict which maps the given query paths to a
            function that operates on the retrieved values and returns
            the results. If None then the raw values are returned. In
            the format: {('path', 'to', 'field1'): function}
        f: a function that applies equally to all fields in query.
            func_dict is the recommended approach and takes priority
            over f.
        filters: MongoDB query arguments to further filter results
            beyond matching the experiment ID.
        start_time, end_time: first and last simulation time to query
        cpus: splits query into this many chunks to run in parallel,
            useful if single-threaded query does not saturate I/O
            (e.g. on Google Cloud)
        host: used if cpus>1 to create MongoClient in parallel
            processes
        port: used if cpus>1 to create MongoClient in parallel
            processes

    Returns:
        data (dict)
    """
    experiment_query = {'experiment_id': experiment_id}
    if filters:
        experiment_query.update(filters)

    projection = None
    if query:
        projection = {f"data.{'.'.join(field)}": 1 for field in query}
        projection['data.time'] = 1
        projection['assembly_id'] = 1

    if cpus > 1:
        chunks = get_data_chunks(
            history_collection, experiment_id, cpus=cpus)
        queries = []
        for chunk in chunks:
            queries.append({
                **experiment_query,
                '_id': {'$gte': chunk[0], '$lt': chunk[1]},
                'data.time': {'$gte': start_time, '$lte': end_time}
            })
        partial_get_query = partial(get_query, projection, host, port)
        with ProcessPoolExecutor(cpus) as executor:
            queried_chunks = executor.map(partial_get_query, queries)
        cursor = itertools.chain.from_iterable(queried_chunks)
    else:
        cursor = history_collection.find(experiment_query, projection)

    raw_data = []
    for document in cursor:
        assert document.get('assembly_id'), \
            "all database documents require an assembly_id"
        if (f or func_dict) and query:
            for field in query:
                if func_dict:
                    # func_dict takes priority over f
                    func = func_dict.get(field)
                else:
                    func = f
                document["data"] = apply_func(
                    document["data"], field, func)
        raw_data.append(document)

    # re-assemble data
    assembly = assemble_data(raw_data)

    # restructure by time
    data: Dict[float, Any] = {}
    for datum in assembly.values():
        time = datum['time']
        datum = datum.copy()
        datum.pop('_id', None)
        datum.pop('time', None)
        deep_merge_check(
            data,
            {time: datum},
            check_equality=True,
        )

    return data

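# Query sketch (illustrative paths, assumes a local MongoDB): fetch only the
# ``('agents',)`` branch of each emit and reduce it to an agent count via
# ``func_dict``.
#
# >>> history = get_local_client('localhost', 27017, 'simulations').history
# >>> agent_counts = get_history_data_db(
# ...     history, 'exp_1',
# ...     query=[('agents',)],
# ...     func_dict={('agents',): len})
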
def get_atlas_client(secrets_path: str) -> Any:
    """Open a MongoDB client using the named secrets config JSON file."""
    with open(secrets_path, 'r') as f:
        secrets = json.load(f)
    emitter_config = get_atlas_database_emitter_config(
        **secrets['database'])
    uri = emitter_config['host']
    client = MongoClient(uri)
    return client[emitter_config['database']]

def get_local_client(host: str, port: Any, database_name: str) -> Any:
    """Open a MongoDB client onto the given host, port, and DB."""
    client = MongoClient('{}:{}'.format(host, port))
    return client[database_name]

def data_from_database(
        experiment_id: str,
        client: Any,
        query: list = None,
        func_dict: dict = None,
        f: Callable[..., Any] = None,
        filters: Optional[dict] = None,
        start_time: Union[int, MinKey] = MinKey(),
        end_time: Union[int, MaxKey] = MaxKey(),
        cpus: int = 1
) -> Tuple[dict, Any]:
    """Fetch something from a MongoDB.

    Args:
        experiment_id: the experiment id which is being retrieved
        client: a MongoClient instance connected to the DB
        query: a list of tuples pointing to fields within the
            experiment data. In the format:
            [('path', 'to', 'field1'), ('path', 'to', 'field2')]
        func_dict: a dict which maps the given query paths to a
            function that operates on the retrieved values and returns
            the results. If None then the raw values are returned. In
            the format: {('path', 'to', 'field1'): function}
        f: a function that applies equally to all fields in query.
            func_dict is the recommended approach and takes priority
            over f.
        filters: MongoDB query arguments to further filter results
            beyond matching the experiment ID.
        start_time, end_time: first and last simulation time to query
        cpus: splits query into this many chunks to run in parallel

    Returns:
        data (dict)
    """
    # Retrieve environment config
    config_collection = client.configuration
    experiment_query = {'experiment_id': experiment_id}
    experiment_configs = list(config_collection.find(experiment_query))

    # Re-assemble experiment_config
    experiment_assembly = assemble_data(experiment_configs)
    assert len(experiment_assembly) == 1
    assembly_id = list(experiment_assembly.keys())[0]
    experiment_config = experiment_assembly[assembly_id]

    # Retrieve timepoint data
    history = client.history

    host = client.address[0]
    port = client.address[1]
    data = get_history_data_db(
        history, experiment_id, query, func_dict, f, filters,
        start_time, end_time, cpus, host, port)

    return data, experiment_config

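# Usage sketch (illustrative paths, assumes a local MongoDB): fetch two
# variables for one experiment together with its configuration document.
#
# >>> db = get_experiment_database()
# >>> data, experiment_config = data_from_database(
# ...     'exp_1', db, query=[('cell', 'mass'), ('cell', 'volume')])
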
def data_to_database(
        data: Dict[float, dict], environment_config: Any, client: Any) -> Any:
    """Insert something into a MongoDB."""
    history_collection = client.history
    reshaped_data = []
    for time, timepoint_data in data.items():
        # Since time is the dictionary key, it has to be a string for
        # JSON/BSON compatibility. But now that we're uploading it, we
        # want it to be a float for fast searching.
        reshaped_entry = {'time': float(time)}
        for key, val in timepoint_data.items():
            if key not in ('_id', 'time'):
                reshaped_entry[key] = val
        reshaped_data.append(reshaped_entry)
    history_collection.insert_many(reshaped_data)

    config_collection = client.configuration
    config_collection.insert_one(environment_config)

def get_atlas_database_emitter_config(
        username: str, password: str, cluster_subdomain: Any, database: str
) -> Dict[str, Any]:
    """Construct an Emitter config for a MongoDB on the Atlas service."""
    username = quote_plus(username)
    password = quote_plus(password)
    database = quote_plus(database)

    uri = (
        "mongodb+srv://{}:{}@{}.mongodb.net/"
        "?retryWrites=true&w=majority"
    ).format(username, password, cluster_subdomain)
    return {
        'type': 'database',
        'host': uri,
        'database': database,
    }

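# Usage sketch (placeholder credentials): turn Atlas credentials into an
# emitter config and hand it to ``get_emitter``.
#
# >>> atlas_config = get_atlas_database_emitter_config(
# ...     username='user', password='pass',
# ...     cluster_subdomain='cluster0.abcde', database='simulations')
# >>> emitter = get_emitter(atlas_config)
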
def test_breakdown() -> None:
    data = {
        'a': [1, 2, 3],
        'b': {
            'X': [1, 2, 3, 4, 5],
            'Y': [1, 2, 3, 4, 5, 6],
            'Z': [5]}}

    output = breakdown_data(20, data)
    assert output == [
        (('b', 'Y'), [1, 2, 3, 4, 5, 6]),
        (('b',), {'X': [1, 2, 3, 4, 5], 'Z': [5]}),
        ((), {'a': [1, 2, 3]})]


if __name__ == '__main__':
    test_breakdown()