# Source code for tensortrade.feed.core.feed



from typing import List

from tensortrade.feed.core.base import Stream, T, Placeholder, IterableStream


class DataFeed(Stream[dict]):
    """A stream that compiles together streams to be run in an organized manner.

    Parameters
    ----------
    streams : `List[Stream]`
        A list of streams to be used in the data feed. When empty (or falsy),
        no inputs are attached at construction time.
    """

    def __init__(self, streams: "List[Stream]") -> None:
        super().__init__()

        # Ordered list of streams to run; populated by `compile()`.
        self.process = None
        self.compiled = False

        if streams:
            # Presumably registers `streams` as this feed's inputs (see
            # `Stream.__call__`); `forward()` reads them back via `self.inputs`.
            self.__call__(*streams)

    def _ensure_compiled(self) -> None:
        """Compile the feed if it has not been compiled yet."""
        if not self.compiled:
            self.compile()

    def compile(self) -> None:
        """Compile all the given streams together.

        Gathers the edges of the stream graph and orders the streams with
        `toposort` so they run in an order that yields valid output. Then it
        marks the feed as compiled and resets every stream.
        """
        edges = self.gather()

        self.process = self.toposort(edges)
        # Set before `reset()`: reset lazily compiles when this flag is False.
        self.compiled = True
        self.reset()

    def run(self) -> None:
        """Run all the streams in processing order, compiling first if needed."""
        self._ensure_compiled()

        for s in self.process:
            s.run()

        super().run()

    def forward(self) -> dict:
        """Return a mapping from each input stream's name to its current value."""
        return {s.name: s.value for s in self.inputs}

    def next(self) -> dict:
        """Advance the feed one step and return its resulting value.

        The value is presumably assigned by `Stream.run` from `forward()`.
        """
        self.run()
        return self.value

    def has_next(self) -> bool:
        """Return True if every stream in the processing order has a next value.

        Compiles the feed first if needed. Before this change the method failed
        with a `TypeError` on an uncompiled feed.
        """
        self._ensure_compiled()
        return all(s.has_next() for s in self.process)

    def reset(self, random_start=0) -> None:
        """Reset every stream in the processing order.

        Parameters
        ----------
        random_start : int, default 0
            Forwarded only to `IterableStream` instances; all other streams
            are reset without arguments.
        """
        # On an uncompiled feed, `compile()` performs a default reset itself;
        # the loop below then applies the requested `random_start`.
        self._ensure_compiled()

        for s in self.process:
            if isinstance(s, IterableStream):
                s.reset(random_start)
            else:
                s.reset()


class PushFeed(DataFeed):
    """A data feed for working with live data in an online manner.

    All sources of data to be used with this feed must be a `Placeholder`.
    This ensures that the user can wait until all of their data has been
    loaded for the next time step.

    Parameters
    ----------
    streams : `List[Stream]`
        A list of streams to be used in the data feed.
    """

    def __init__(self, streams: "List[Stream]"):
        super().__init__(streams)
        self.compile()

        # Starting streams appear as an edge source but never as an edge
        # target. Only the `Placeholder` ones can receive data via `push()`.
        edges = self.gather()
        sources = {src for src, _ in edges}
        targets = {tgt for _, tgt in edges}
        self.start = [s for s in sources - targets if isinstance(s, Placeholder)]

    @property
    def is_loaded(self) -> bool:
        """Whether every starting placeholder currently holds a non-None value."""
        return all(s.value is not None for s in self.start)

    def push(self, data: dict) -> dict:
        """Generate the next output of the feed from the values in `data`.

        Parameters
        ----------
        data : dict
            The data to be pushed to each of the placeholders in the feed,
            keyed by placeholder name.

        Returns
        -------
        dict
            The next data point generated from the feed based on `data`.

        Raises
        ------
        KeyError
            If `data` has no entry for one of the starting placeholders.
        """
        try:
            for s in self.start:
                s.push(data[s.name])
            return self.next()
        finally:
            # Always unload the placeholders, even when a step fails, so that a
            # later bare `next()` cannot silently consume stale pushed values.
            for s in self.start:
                s.value = None

    def next(self) -> dict:
        """Advance the feed one step using the currently pushed values.

        Raises
        ------
        RuntimeError
            If not every starting placeholder has been loaded with data.
        """
        if not self.is_loaded:
            raise RuntimeError("No data has been pushed to the feed.")
        self.run()
        return self.value