import threading
import json
import yaml
from collections import UserDict
from typing import List
import numpy as np
from . import registry
class TradingContext(UserDict):
    """A class for objects that put themselves in a `Context` using
    the `with` statement.

    The implementation for this class is heavily borrowed from the pymc3
    library and adapted with the design goals of TensorTrade in mind.

    Parameters
    ----------
    config : dict
        The configuration holding the information for each `Component`.
        Its keys must be strings because they are forwarded as keyword
        arguments to the underlying `UserDict`.

    Methods
    -------
    from_json(path)
        Creates a `TradingContext` from a json file.
    from_yaml(path)
        Creates a `TradingContext` from a yaml file.

    Warnings
    --------
    Using components together that were initialized under different
    contexts can have undesirable effects. Therefore, a warning should be
    made to the user indicating that using components together that have
    conflicting contexts can lead to unwanted behavior.

    References
    ----------
    [1] https://github.com/pymc-devs/pymc3/blob/master/pymc3/model.py
    """

    # Thread-local holder: each thread lazily gets its own `stack` attribute
    # (see `get_contexts`).
    contexts = threading.local()

    def __init__(self, config: dict):
        super().__init__(**config)

        r = registry.registry()
        # Sorted, de-duplicated values of the registry mapping.
        registered_names = list(np.unique([r[i] for i in r.keys()]))

        # Every registered name that is not a major component gets its own
        # attribute holding that name's section of the config (or `{}`).
        for name in registered_names:
            if name not in registry.MAJOR_COMPONENTS:
                setattr(self, name, config.get(name, {}))

        # Top-level entries that do not belong to a registered name are
        # treated as shared values.
        config_items = {
            key: value
            for key, value in config.items()
            if key not in registered_names
        }

        self._config = config
        # Unregistered top-level entries take precedence over the entries of
        # the explicit 'shared' section when keys collide.
        self._shared = {
            **config.get('shared', {}),
            **config_items
        }

    @property
    def shared(self) -> dict:
        """The shared values in common for all components involved with the
        `TradingContext`.

        Returns
        -------
        dict
            Shared values for components under the `TradingContext`.
        """
        return self._shared

    def __enter__(self) -> 'TradingContext':
        """Adds a new `TradingContext` to the context stack.

        This method is used for a `with` statement and adds a `TradingContext`
        to the context stack. The new context on the stack is then used by every
        class that subclasses `Component` during the initialization of its
        instances.

        Returns
        -------
        `TradingContext`
            The context associated with the given with statement.
        """
        type(self).get_contexts().append(self)
        return self

    def __exit__(self, typ, value, traceback) -> None:
        """Pops the most recently entered `TradingContext` off the stack.

        Parameters
        ----------
        typ : type
            The type of `Exception`
        value : `Exception`
            An instance of `typ`.
        traceback : python traceback object
            The traceback object associated with the exception.
        """
        # Returns None, so an exception raised inside the `with` body is
        # propagated rather than suppressed.
        type(self).get_contexts().pop()

    @classmethod
    def get_contexts(cls) -> List['TradingContext']:
        """Gets the stack of trading contexts.

        The stack is created on first access in each thread and is seeded
        with an empty `TradingContext`.

        Returns
        -------
        List['TradingContext']
            The stack of trading contexts.
        """
        if not hasattr(cls.contexts, 'stack'):
            cls.contexts.stack = [TradingContext({})]
        return cls.contexts.stack

    @classmethod
    def get_context(cls) -> 'TradingContext':
        """Gets the context on top of the stack.

        Returns
        -------
        `TradingContext`
            The most recently entered context (the last element of the
            stack).
        """
        return cls.get_contexts()[-1]

    @classmethod
    def from_json(cls, path: str) -> 'TradingContext':
        """Creates a `TradingContext` from a json file.

        Parameters
        ----------
        path : str
            The path to locate the json file.

        Returns
        -------
        `TradingContext`
            A trading context with all the variables provided in the json file.
            The instance is of the class the method was called on.
        """
        with open(path, "rb") as fp:
            config = json.load(fp)

        # A document containing only `null` carries no configuration.
        if config is None:
            config = {}
        return cls(config)

    @classmethod
    def from_yaml(cls, path: str) -> 'TradingContext':
        """Creates a `TradingContext` from a yaml file.

        Parameters
        ----------
        path : str
            The path to locate the yaml file.

        Returns
        -------
        `TradingContext`
            A trading context with all the variables provided in the yaml file.
            The instance is of the class the method was called on.
        """
        # NOTE(review): `FullLoader` resolves more than plain YAML tags. If
        # config files can come from untrusted sources, consider switching to
        # `yaml.safe_load`.
        with open(path, "rb") as fp:
            config = yaml.load(fp, Loader=yaml.FullLoader)

        # An empty YAML document is parsed as `None`.
        if config is None:
            config = {}
        return cls(config)
class Context(UserDict):
    """A context that is injected into every instance of a class that is
    a subclass of `Component`.

    Parameters
    ----------
    **kwargs : keyword arguments
        The entries of the context. Each entry is stored in the mapping and
        is also exposed as an instance attribute of the same name.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Mirror the stored entries as attributes so that `ctx.key` works
        # alongside `ctx['key']`.
        # NOTE(review): an entry named 'data' would replace the `data`
        # attribute that `UserDict` uses as its backing store.
        self.__dict__ = {**self.__dict__, **self.data}

    def __str__(self):
        """Returns a string of the form `<ClassName: key=value, ...>`.

        Returns
        -------
        str
            The class name followed by every stored entry.
        """
        # Iterate the stored entries. This class defines no `__slots__` of its
        # own, so the attribute resolves to the empty tuple inherited from
        # the abstract base classes and would list nothing.
        entries = ['{}={}'.format(k, v) for k, v in self.data.items()]
        return '<{}: {}>'.format(self.__class__.__name__, ', '.join(entries))