# Copyright 2019 The TensorTrade Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
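
"""A parallel DQN agent for TensorTrade.

Several trainer processes collect experience concurrently while a single
optimizer process updates a shared model and broadcasts fresh weights.
"""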
import multiprocessing as mp
import time
from typing import Callable

import numpy as np
from deprecated import deprecated

from tensortrade.agents import Agent
from tensortrade.agents.parallel.parallel_dqn_model import ParallelDQNModel
from tensortrade.agents.parallel.parallel_dqn_optimizer import ParallelDQNOptimizer
from tensortrade.agents.parallel.parallel_dqn_trainer import ParallelDQNTrainer
from tensortrade.agents.parallel.parallel_queue import ParallelQueue


@deprecated(version='1.0.4', reason="Builtin agents are being deprecated in favor of external implementations (i.e. Ray)")
class ParallelDQNAgent(Agent):
    """A DQN agent that parallelizes experience collection across multiple
    environment processes, coordinated through shared queues."""

    def __init__(self,
                 create_env: Callable[[], 'TradingEnvironment'],
                 model: ParallelDQNModel = None):
        self.create_env = create_env
        self.model = model or ParallelDQNModel(create_env=self.create_env)

    def restore(self, path: str, **kwargs):
        """Restore the model's network weights from the file at `path`."""
        self.model.restore(path, **kwargs)

    def save(self, path: str, **kwargs):
        """Save the model's network weights to `path`, tagged with this agent's id."""
        self.model.save(path, agent_id=self.id, **kwargs)

    def get_action(self, state: np.ndarray, **kwargs) -> int:
        """Return the action selected by the policy network for `state`."""
        return self.model.get_action(state, **kwargs)

    def update_networks(self, model: 'ParallelDQNModel'):
        """Copy the network weights from `model` into this agent's model."""
        self.model.update_networks(model)

    def update_target_network(self):
        """Synchronize the target network with the policy network."""
        self.model.update_target_network()

    def _start_trainer_process(self,
                               create_env,
                               memory_queue,
                               model_update_queue,
                               done_queue,
                               n_steps,
                               n_episodes,
                               eps_start,
                               eps_end,
                               eps_decay_steps,
                               update_target_every):
        """Spawn one trainer process that runs episodes in its own environment
        and streams transitions into `memory_queue`."""
trainer_process = ParallelDQNTrainer(self,
create_env,
memory_queue,
model_update_queue,
done_queue,
n_steps,
n_episodes,
eps_start,
eps_end,
eps_decay_steps,
update_target_every)
trainer_process.start()
return trainer_process

    def _start_optimizer_process(self,
                                 model,
                                 n_envs,
                                 memory_queue,
                                 model_update_queue,
                                 done_queue,
                                 discount_factor,
                                 batch_size,
                                 learning_rate,
                                 memory_capacity):
        """Spawn the single optimizer process that learns from the pooled
        experience in `memory_queue`."""
optimizer_process = ParallelDQNOptimizer(model,
n_envs,
memory_queue,
model_update_queue,
done_queue,
discount_factor,
batch_size,
learning_rate,
memory_capacity)
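        # Daemonize the optimizer so it is terminated automatically when the
        # main process exits.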
optimizer_process.daemon = True
optimizer_process.start()
return optimizer_process

    def train(self,
              n_steps: int = None,
              n_episodes: int = None,
              save_every: int = None,
              save_path: str = None,
              callback: Callable = None,
              **kwargs) -> float:
        """Train the agent across parallel environment processes and return
        the mean reward reported by the workers.

        Hyperparameters such as `n_envs`, `batch_size`, and `learning_rate`
        may be supplied through `**kwargs`.
        """
n_envs: int = kwargs.get('n_envs', mp.cpu_count())
batch_size: int = kwargs.get('batch_size', 128)
discount_factor: float = kwargs.get('discount_factor', 0.9999)
learning_rate: float = kwargs.get('learning_rate', 0.0001)
eps_start: float = kwargs.get('eps_start', 0.9)
eps_end: float = kwargs.get('eps_end', 0.05)
eps_decay_steps: int = kwargs.get('eps_decay_steps', 2000)
update_target_every: int = kwargs.get('update_target_every', 1000)
memory_capacity: int = kwargs.get('memory_capacity', 10000)
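        # Shared queues: trainers push transitions into `memory_queue`, the
        # optimizer publishes updated weights on `model_update_queue`, and each
        # trainer reports its final reward on `done_queue` when it finishes.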
memory_queue = ParallelQueue()
model_update_queue = ParallelQueue()
done_queue = ParallelQueue()
print('==== AGENT ID: {} ===='.format(self.id))
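        # Launch one trainer process per environment.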
trainers = [self._start_trainer_process(self.create_env,
memory_queue,
model_update_queue,
done_queue,
n_steps,
n_episodes,
eps_start,
eps_end,
eps_decay_steps,
update_target_every) for _ in range(n_envs)]
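        # A single optimizer process consumes the pooled experience and
        # periodically broadcasts updated network weights to the trainers.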
self._start_optimizer_process(self.model,
n_envs,
memory_queue,
model_update_queue,
done_queue,
discount_factor,
batch_size,
learning_rate,
memory_capacity)
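        # Block until every trainer has reported a final episode reward.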
while done_queue.qsize() < n_envs:
time.sleep(5)
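        # Drain the rewards reported by each worker.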
total_reward = 0
while done_queue.qsize() > 0:
total_reward += done_queue.get()
        for queue in [memory_queue, model_update_queue, done_queue]:
            queue.close()
            queue.join_thread()
for trainer in trainers:
trainer.terminate()
for trainer in trainers:
trainer.join()
mean_reward = total_reward / n_envs
return mean_reward
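

# A minimal usage sketch (assumes a hypothetical `create_env` factory that
# builds a `TradingEnvironment`; neither is defined in this module):
#
#     def create_env():
#         return TradingEnvironment(...)  # construct your environment here
#
#     agent = ParallelDQNAgent(create_env)
#     mean_reward = agent.train(n_episodes=10, n_envs=4, batch_size=128)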