diff --git a/reinforcement_learning/common/env_utils.py b/reinforcement_learning/common/env_utils.py new file mode 100644 index 0000000000..516af62115 --- /dev/null +++ b/reinforcement_learning/common/env_utils.py @@ -0,0 +1,211 @@
+import gym
+import numpy as np
+import pandas as pd
+import json
+from pathlib import Path
+
+gym.logger.set_level(40)
+
+class VectoredGymEnvironment():
+    """
+    Environment class to run multiple simulations and collect rollout data
+    """
+    def __init__(self, registered_gym_env, num_of_envs=1):
+        self.envs_initialized = False
+        self.initialized_envs = {}
+        self.env_states = {}
+        self.env_reset_counter = {}
+        self.num_of_envs = num_of_envs
+        self.data_rows = []
+
+        self.initialize_envs(num_of_envs, registered_gym_env)
+
+    def is_initialized(self):
+        return self.envs_initialized
+
+    def initialize_envs(
+            self,
+            num_of_envs,
+            registered_gym_env):
+        """Initialize multiple OpenAI Gym environments.
+        Each environment starts with a different random seed.
+
+        Arguments:
+            num_of_envs {int} -- Number of environments/simulations to initiate
+            registered_gym_env {str} -- Environment name of the registered gym environment
+        """
+        print("Initializing {} environments of {}".format(num_of_envs, registered_gym_env))
+        for i in range(0, num_of_envs):
+            environment_id = "environment_" + str(i)
+            environment = gym.make(registered_gym_env)
+            environment = environment.unwrapped
+            environment.seed(i)
+            self.env_states[environment_id] = environment.reset()
+            self.env_reset_counter[environment_id] = 0
+            self.initialized_envs[environment_id] = environment
+        self.envs_initialized = True
+        self.state_dims = len(self.env_states[environment_id])
+
+    def get_environment_states(self):
+        return self.env_states
+
+    def dump_environment_states(self, dir_path, file_name):
+        """Dump the current states of all the environments into a file
+
+        Arguments:
+            dir_path {str} -- Directory path of the target file
+            file_name {str} -- File name of the target file
+        """
+        data_folder = Path(dir_path)
+        file_path = data_folder / file_name
+
+        with open(file_path, 'w') as outfile:
+            for state in self.env_states.values():
+                json.dump(list(state), outfile)
+                outfile.write('\n')
+
+    def get_environment_ids(self):
+        return list(self.initialized_envs.keys())
+
+    def step(self, environment_id, action):
+        local_env = self.initialized_envs[environment_id]
+        observation, reward, done, info = local_env.step(action)
+
+        self.env_states[environment_id] = observation
+        return observation, reward, done, info
+
+    def reset(self, environment_id):
+        self.env_states[environment_id] = \
+            self.initialized_envs[environment_id].reset()
+        return self.env_states[environment_id]
+
+    def reset_all_envs(self):
+        print("Resetting all the environments...")
+        for i in range(0, self.num_of_envs):
+            environment_id = "environment_" + str(i)
+            self.reset(environment_id)
+
+    def close(self, environment_id):
+        self.initialized_envs[environment_id].close()
+        return
+
+    def render(self, environment_id):
+        self.initialized_envs[environment_id].render()
+        return
+
+    def collect_rollouts_for_single_env_with_given_episodes(self, environment_id, action_prob, num_episodes):
+        """Collect rollouts for a given number of episodes from one environment
+
+        Arguments:
+            environment_id {str} -- Environment id for the environment
+            action_prob {list} -- Action probabilities of the simulated policy
+            num_episodes {int} -- Number of episodes to run rollouts
+        """
+        # normalize if the sum of probabilities is not exactly 1
+        action_prob = np.array(action_prob)
+        if action_prob.sum() != 1:
+            action_prob /= action_prob.sum()
+        action_prob = list(action_prob)
+
+        for _ in range(num_episodes):
+            done = False
+            cumulative_rewards = 0
+            while not done:
+                data_item = []
+                action = np.random.choice(len(action_prob), p=action_prob)
+                cur_state_features = self.env_states[environment_id]
+                _, reward, done, _ = self.step(environment_id, action)
+                cumulative_rewards += reward
+                episode_id = int(environment_id.split('_')[-1]) + \
+                    self.num_of_envs * self.env_reset_counter[environment_id]
+                if not done:
+                    data_item.extend([action, action_prob, episode_id, reward, 0.0])
+                else:
+                    data_item.extend([action, action_prob, episode_id, reward, cumulative_rewards])
+                for j in range(len(cur_state_features)):
+                    data_item.append(cur_state_features[j])
+                self.data_rows.append(data_item)
+
+            self.reset(environment_id)
+            self.env_reset_counter[environment_id] += 1
+
+    def collect_rollouts_for_single_env_with_given_steps(self, environment_id, action_prob, num_steps):
+        """Collect rollouts for a given number of steps from one environment
+
+        Arguments:
+            environment_id {str} -- Environment id for the environment
+            action_prob {list} -- Action probabilities of the simulated policy
+            num_steps {int} -- Number of steps to run rollouts
+        """
+        # normalize if the sum of probabilities is not exactly 1
+        action_prob = np.array(action_prob)
+        if action_prob.sum() != 1:
+            action_prob /= action_prob.sum()
+        action_prob = list(action_prob)
+
+        for _ in range(num_steps):
+            data_item = []
+            action = np.random.choice(len(action_prob), p=action_prob)
+            cur_state_features = self.env_states[environment_id]
+            _, reward, done, _ = self.step(environment_id, action)
+            episode_id = int(environment_id.split('_')[-1]) + \
+                self.num_of_envs * self.env_reset_counter[environment_id]
+            data_item.extend([action, action_prob, episode_id, reward])
+            for j in range(len(cur_state_features)):
+                data_item.append(cur_state_features[j])
+            self.data_rows.append(data_item)
+            if done:
+                self.reset(environment_id)
+                self.env_reset_counter[environment_id] += 1
+
+    def collect_rollouts_with_given_action_probs(self, num_steps=None, num_episodes=None, action_probs=None, file_name=None):
+        """Collect rollouts from all the initiated environments with given action probs
+
+        Keyword Arguments:
+            num_steps {int} -- Number of steps to run rollouts (default: {None})
+            num_episodes {int} -- Number of episodes to run rollouts (default: {None})
+            action_probs {list} -- Action probs for the policy (default: {None})
+            file_name {str} -- Batch transform output that contains predicted action probabilities (default: {None})
+
+        Returns:
+            [Dataframe] -- Dataframe that contains the rollout data from all envs
+        """
+        if file_name is not None:
+            assert action_probs is None
+            json_lines = [json.loads(line.rstrip('\n')) for line in open(file_name) if line != '']
+            action_probs = []
+            for line in json_lines:
+                if line.get('SageMakerOutput') is not None:
+                    action_probs.append(line['SageMakerOutput'].get("predictions")[0])
+                else:
+                    action_probs.append(line.get("predictions")[0])
+
+        assert len(action_probs) == self.num_of_envs
+        for index, environment_id in enumerate(self.get_environment_ids()):
+            if num_steps is not None:
+                assert num_episodes is None
+                self.collect_rollouts_for_single_env_with_given_steps(
+                    environment_id, action_probs[index], num_steps
+                )
+            else:
+                assert num_episodes is not None
+                self.collect_rollouts_for_single_env_with_given_episodes(
+                    environment_id, action_probs[index], num_episodes
+                )
+
+        col_names = self._create_col_names()
+        df = pd.DataFrame(self.data_rows, columns=col_names)
+
+        return df
+
+    def _create_col_names(self):
+        """Create column names of the dataframe that can be consumed by Coach
+
+        Returns:
+            [list] -- List of column names
+        """
+        col_names = ['action', 'all_action_probabilities', 'episode_id', 'reward', 'cumulative_rewards']
+        for i in range(self.state_dims):
+            col_names.append('state_feature_' + str(i))
+
+        return col_names
\ No newline at end of file
diff --git a/reinforcement_learning/rl_cartpole_batch_coach/README.md b/reinforcement_learning/rl_cartpole_batch_coach/README.md new file mode 100644 index 0000000000..4a318f2a32 --- /dev/null +++ b/reinforcement_learning/rl_cartpole_batch_coach/README.md @@ -0,0 +1,14 @@
+# Training Batch Reinforcement Learning Policies with Amazon SageMaker RL
+
+In many real-world problems, the reinforcement learning agent cannot interact with either the real environment or a simulated one. On one hand, creating a simulator that imitates the real environment dynamics can be quite complex; on the other, letting the learning agent attempt sub-optimal actions in the real world is quite risky. In such cases, the learning agent only has access to batches of offline data generated by some deployed policy. The learning agent needs to use this data correctly to learn a better policy for the problem.
+
+This notebook shows an example of how to use batch reinforcement learning techniques to address this type of real-world problem: training a new policy from an offline dataset when there is no way to interact with real environments or simulators. This example is a simple toy demonstrating how one might begin to address this real and challenging problem. We use gym `CartPole-v0` as a stand-in simulated system to generate the offline dataset, and the RL agents are trained using Amazon SageMaker RL.
+
+## Contents
+
+* `rl_cartpole_batch_coach.ipynb`: notebook used for training a policy with batch RL to solve the cart-pole problem.
+* `src/`
+  * `train-coach.py`: launcher for coach training.
+  * `evaluate-coach.py`: launcher for coach evaluation.
+  * `preset-cartpole-ddqnbcq.py`: coach preset for the BCQ algorithm.
+  * `preset-cartpole-ddqnbcq-env.py`: coach preset for the BCQ algorithm with environment setup.
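The offline dataset described above is produced with the `VectoredGymEnvironment` helper added in `common/env_utils.py`. The following is a minimal sketch of that data-generation step, mirroring the notebook cells later in this change (the CSV path is the one the Coach preset expects; adjust it if you run the sketch elsewhere):

```python
# Minimal sketch: generate an offline CartPole dataset with a uniform-random
# behavior policy, using the VectoredGymEnvironment helper from common/env_utils.py.
from env_utils import VectoredGymEnvironment

NUM_ENVS = 100      # parallel simulators standing in for the live system
NUM_EPISODES = 5    # episodes collected per simulator (500 episodes total)

envs = VectoredGymEnvironment('CartPole-v0', NUM_ENVS)

# Equal probability for "push left" and "push right" in every state.
action_probs = [[0.5, 0.5] for _ in range(NUM_ENVS)]
df = envs.collect_rollouts_with_given_action_probs(
    action_probs=action_probs, num_episodes=NUM_EPISODES)

# Coach's CsvDataset reader consumes this file during Batch RL training.
df.to_csv("src/cartpole_dataset.csv", index=False)
```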
diff --git a/reinforcement_learning/rl_cartpole_batch_coach/__init__.py b/reinforcement_learning/rl_cartpole_batch_coach/__init__.py new file mode 100644 index 0000000000..e69de29bb2
diff --git a/reinforcement_learning/rl_cartpole_batch_coach/batch_rl.png b/reinforcement_learning/rl_cartpole_batch_coach/batch_rl.png new file mode 100644 index 0000000000..0d560f64e7 Binary files /dev/null and b/reinforcement_learning/rl_cartpole_batch_coach/batch_rl.png differ
diff --git a/reinforcement_learning/rl_cartpole_batch_coach/common b/reinforcement_learning/rl_cartpole_batch_coach/common new file mode 120000 index 0000000000..60d3b0a6a8 --- /dev/null +++ b/reinforcement_learning/rl_cartpole_batch_coach/common @@ -0,0 +1 @@
+../common
\ No newline at end of file
diff --git a/reinforcement_learning/rl_cartpole_batch_coach/rl_cartpole_batch_coach.ipynb b/reinforcement_learning/rl_cartpole_batch_coach/rl_cartpole_batch_coach.ipynb new file mode 100644 index 0000000000..611b23aa27 --- /dev/null +++ b/reinforcement_learning/rl_cartpole_batch_coach/rl_cartpole_batch_coach.ipynb @@ -0,0 +1,726 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Training Batch Reinforcement Learning Policies with Amazon SageMaker RL\n",
+    "\n",
+    "For many real-world problems, the reinforcement learning (RL) agent needs to learn from historical data that was generated by some deployed policy. For example, we may have historical data of experts playing games, users interacting with a website, or sensor data from a control system. This notebook shows an example of how to use batch RL to train a new policy from an offline dataset [1]. We use gym `CartPole-v0` as a stand-in simulated system to generate the offline dataset, and the RL agents are trained using Amazon SageMaker RL.\n",
+    "\n",
+    "We may want to evaluate the policy learned from historical data before deployment. Since simulators may not be available in all use cases, we need to evaluate how good the learned policy is using held-out historical data. This is called off-policy evaluation or counterfactual evaluation. In this notebook, we evaluate the policy during training using several off-policy evaluation metrics. \n",
+    "\n",
+    "We can deploy the policy using a SageMaker hosting endpoint. However, some use cases may not require a persistent serving endpoint with sub-second latency. Here we demonstrate how to deploy the policy with [SageMaker Batch Transform](https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform.html), where predictions for large volumes of input state features can be generated with high throughput.\n",
+    "\n",
+    "The figure below shows an overview of the entire notebook.\n",
+    "\n",
+    "![Batch RL in Notebook](./batch_rl.png)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Pre-requisites\n",
+    "\n",
+    "### Roles and permissions\n",
+    "\n",
+    "To get started, we'll import the Python libraries we need and set up the environment with a few pre-requisites for permissions and configurations."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sagemaker\n", + "import boto3\n", + "import sys\n", + "import os\n", + "import glob\n", + "import re\n", + "import subprocess\n", + "from IPython.display import HTML\n", + "import time\n", + "from time import gmtime, strftime\n", + "sys.path.append(\"common\")\n", + "from misc import get_execution_role, wait_for_s3_object\n", + "from sagemaker.rl import RLEstimator, RLToolkit, RLFramework\n", + "# install gym environments if needed\n", + "!pip install gym\n", + "from env_utils import VectoredGymEnvironment" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Steup S3 buckets\n", + "\n", + "Setup the linkage and authentication to the S3 bucket that you want to use for checkpoint and the metadata. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# S3 bucket\n", + "sage_session = sagemaker.session.Session()\n", + "s3_bucket = sage_session.default_bucket() \n", + "region_name = sage_session.boto_region_name\n", + "s3_output_path = 's3://{}/'.format(s3_bucket) # SDK appends the job name and output folder\n", + "print(\"S3 bucket path: {}\".format(s3_output_path))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define Variables \n", + "\n", + "We define variables such as the job prefix for the training jobs *and the image path for the container (only when this is BYOC).*" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# create unique job name \n", + "job_name_prefix = 'rl-batch-cartpole'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Configure settings\n", + "\n", + "You can run your RL training jobs on a SageMaker notebook instance or on your own machine. In both of these scenarios, you can run the following in either `local` or `SageMaker` modes. The `local` mode uses the SageMaker Python SDK to run your code in a local container before deploying to SageMaker. This can speed up iterative testing and debugging while using the same familiar Python SDK interface. You just need to set `local_mode = True`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "\n", + "# run in local mode?\n", + "local_mode = False\n", + "\n", + "image = '462105765813.dkr.ecr.{}.amazonaws.com/sagemaker-rl-coach-container:coach-1.0.0-tf-cpu-py3'.format(region_name)\n", + "print(\"Use ECR image: {}\".format(image))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create an IAM role\n", + "Either get the execution role when running from a SageMaker notebook `role = sagemaker.get_execution_role()` or, when running from local machine, use utils method `role = get_execution_role()` to create an execution role." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " role = sagemaker.get_execution_role()\n", + "except:\n", + " role = get_execution_role()\n", + " \n", + "print(\"Using IAM role arn: {}\".format(role))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Install docker for `local` mode\n", + "\n", + "In order to work in `local` mode, you need to have docker installed. 
When running from your local machine, please make sure that you have docker or docker-compose (for local CPU machines) and nvidia-docker (for local GPU machines) installed. Alternatively, when running from a SageMaker notebook instance, you can simply run the following script to install dependencies.\n",
+    "\n",
+    "Note that you can only run a single local notebook at a time."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# only run from SageMaker notebook instance\n",
+    "if local_mode:\n",
+    "    !/bin/bash ./common/setup.sh"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Collect offline data\n",
+    "\n",
+    "In order to do Batch RL training, we first need to prepare the dataset generated by a deployed policy. In real-world scenarios, customers can collect this offline data by interacting with the live environment using the already deployed agent. In this notebook, we use OpenAI Gym `CartPole-v0` to mimic a live environment and a random policy with a uniform action distribution to mimic a deployed agent. By interacting with multiple environments simultaneously, we can gather more trajectories from the environments.\n",
+    "\n",
+    "Here is a short introduction to the cart-pole balancing problem, where a pole is attached by an un-actuated joint to a cart, moving along a frictionless track.\n",
+    "\n",
+    "1. *Objective*: Prevent the pole from falling over\n",
+    "2. *Environment*: The environment used in this example is part of OpenAI Gym, corresponding to the version of the cart-pole problem described by Barto, Sutton, and Anderson [2]\n",
+    "3. *State*: Cart position, cart velocity, pole angle, pole velocity at tip\n",
+    "4. *Action*: Push cart to the left, push cart to the right\n",
+    "5. *Reward*: Reward is 1 for every step taken, including the termination step"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# initiate 100 environments to collect rollout data\n",
+    "NUM_ENVS = 100\n",
+    "NUM_EPISODES = 5\n",
+    "vectored_envs = VectoredGymEnvironment('CartPole-v0', NUM_ENVS)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Now we have 100 environments of `CartPole-v0` ready. We'll collect 5 episodes from each environment, so we'll have 500 episodes of data for training. We start from a random policy that generates the same uniform action probabilities regardless of the state features."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# initiate a random policy by setting action probabilities to a uniform distribution\n",
+    "action_probs = [[1/2, 1/2] for _ in range(NUM_ENVS)]\n",
+    "df = vectored_envs.collect_rollouts_with_given_action_probs(action_probs=action_probs, num_episodes=NUM_EPISODES)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# the rollout dataframes contain attributes: action, action_probs, episode_id, reward, cumulative_rewards, state_features\n",
+    "# cumulative rewards are only shown at the last step of each episode\n",
+    "df.head()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can use the average cumulative reward of the random policy as a baseline for the Batch RL trained policy. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# average cumulative rewards for each episode\n", + "avg_rewards = df['cumulative_rewards'].sum() / (NUM_ENVS * NUM_EPISODES)\n", + "print(\"Average cumulative rewards over {} episodes rollouts was {}.\".format((NUM_ENVS * NUM_EPISODES), avg_rewards))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Save Dataframe as CSV for Batch RL Training\n", + "\n", + "Coach Batch RL support reading off policy data in CSV format. We will dump our collected rollout data in CSV format." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# dump dataframe as csv file\n", + "df.to_csv(\"src/cartpole_dataset.csv\", index=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Configure the presets for RL algorithm \n", + "\n", + "The presets that configure the Batch RL training jobs are defined in the `preset-cartpole-ddqnbcq.py` file which is also uploaded on the `/src` directory. Using the preset file, you can define agent parameters to select the specific agent algorithm. You can also set the environment parameters, define the schedule and visualization parameters, and define the graph manager. The schedule presets will define the number of heat up steps, periodic evaluation steps, training steps between evaluations.\n", + "\n", + "These can be overridden at runtime by specifying the `RLCOACH_PRESET` hyperparameter. Additionally, it can be used to define custom hyperparameters. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "!pygmentize src/preset-cartpole-ddqnbcq.py" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this notebook, we use DDQN[6] to update the policy in an off-policy manner, and combine it with BCQ[5] to address the error induced by inaccurately estimated values for unseen state-action pairs. The training is completely off-line." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Write the Training Code \n", + "\n", + "The training code is written in the file “train-coach.py” which is uploaded in the /src directory. \n", + "First import the environment files and the preset files, and then define the `main()` function. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "!pygmentize src/train-coach.py" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Train the RL model using the Python SDK Script mode\n", + "\n", + "If you are using local mode, the training will run on the notebook instance. When using SageMaker for training, you can select a GPU or CPU instance. The RLEstimator is used for training RL jobs. \n", + "\n", + "1. Specify the source directory where the environment, presets and training code is uploaded.\n", + "2. Specify the entry point as the training code \n", + "3. Define the training parameters such as the instance count, job name, S3 path for output and job name. \n", + "4. Specify the hyperparameters for the RL agent algorithm. The `RLCOACH_PRESET` can be used to specify the RL agent algorithm you want to use. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "\n", + "if local_mode:\n", + " instance_type = 'local'\n", + "else:\n", + " instance_type = \"ml.m4.xlarge\"\n", + " \n", + "estimator = RLEstimator(entry_point=\"train-coach.py\",\n", + " source_dir='src',\n", + " dependencies=[\"common/sagemaker_rl\"],\n", + " image_name=image,\n", + " role=role,\n", + " train_instance_type=instance_type,\n", + " train_instance_count=1,\n", + " output_path=s3_output_path,\n", + " base_job_name=job_name_prefix,\n", + " hyperparameters = {\n", + " \"RLCOACH_PRESET\": \"preset-cartpole-ddqnbcq\",\n", + " \"save_model\": 1\n", + " }\n", + " )\n", + "estimator.fit()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Store intermediate training output and model checkpoints \n", + "\n", + "The output from the training job above is stored on S3. The intermediate folder contains gifs and metadata of the training. We'll need these metadata for metrics visualization and model evaluations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "job_name=estimator._current_job_name\n", + "print(\"Job name: {}\".format(job_name))\n", + "\n", + "s3_url = \"s3://{}/{}\".format(s3_bucket,job_name)\n", + "\n", + "if local_mode:\n", + " output_tar_key = \"{}/output.tar.gz\".format(job_name)\n", + "else:\n", + " output_tar_key = \"{}/output/output.tar.gz\".format(job_name)\n", + "\n", + "intermediate_folder_key = \"{}/output/intermediate/\".format(job_name)\n", + "output_url = \"s3://{}/{}\".format(s3_bucket, output_tar_key)\n", + "intermediate_url = \"s3://{}/{}\".format(s3_bucket, intermediate_folder_key)\n", + "\n", + "print(\"S3 job path: {}\".format(s3_url))\n", + "print(\"Output.tar.gz location: {}\".format(output_url))\n", + "print(\"Intermediate folder path: {}\".format(intermediate_url))\n", + " \n", + "tmp_dir = \"/tmp/{}\".format(job_name)\n", + "os.system(\"mkdir {}\".format(tmp_dir))\n", + "print(\"Create local folder {}\".format(tmp_dir))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualization" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Plot metrics for training job\n", + "We can pull the Off Policy Evaluation(OPE) metric of the training and plot it to see the performance of the model over time." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "\n", + "csv_file_name = \"worker_0.batch_rl_graph.main_level.main_level.agent_0.csv\"\n", + "key = os.path.join(intermediate_folder_key, csv_file_name)\n", + "wait_for_s3_object(s3_bucket, key, tmp_dir, training_job_name=job_name)\n", + "\n", + "csv_file = \"{}/{}\".format(tmp_dir, csv_file_name)\n", + "df = pd.read_csv(csv_file)\n", + "df = df.dropna(subset=['Sequential Doubly Robust'])\n", + "df.dropna(subset=['Weighted Importance Sampling'])\n", + " \n", + "plt.figure(figsize=(12,5))\n", + "plt.xlabel('Number of epochs')\n", + "\n", + "ax1 = df['Weighted Importance Sampling'].plot(color='blue', grid=True, label='WIS')\n", + "ax2 = df['Sequential Doubly Robust'].plot(color='red', grid=True, secondary_y=True, label='SDR')\n", + "\n", + "h1, l1 = ax1.get_legend_handles_labels()\n", + "h2, l2 = ax2.get_legend_handles_labels()\n", + "\n", + "plt.legend(h1+h2, l1+l2, loc=1)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "There is a set of methods used to investigate the performance of the current trained policy without interacting with simulator / live environment. They can be used to estimate the goodness of the policy, based on the dataset collected from other policy. Here we showed two of these OPE metrics: WIS (Weighted Importance Sampling) [3] and SDR (Sequential Doubly Robust) [4]. As we can see in the plot, these metrics are improving as the learning agent is iterating over the given dataset." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluation of RL models\n", + "\n", + "To evaluate the model trained with off policy data, we need to see the accumulative rewards of the agent by interacting with the environment. We use the last checkpointed model to run evaluation of the RL Agent. We use a different preset file here `preset-cartpole-ddqnbcq-env.py` to let the RL agent interact with the environment and collect rewards.\n", + "\n", + "### Load checkpointed model\n", + "\n", + "Checkpoint is passed on for evaluation / inference in the checkpoint channel. In local mode, we can simply use the local directory, whereas in the SageMaker mode, it needs to be moved to S3 first." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "wait_for_s3_object(s3_bucket, output_tar_key, tmp_dir, training_job_name=job_name) \n", + "\n", + "if not os.path.isfile(\"{}/output.tar.gz\".format(tmp_dir)):\n", + " raise FileNotFoundError(\"File output.tar.gz not found\")\n", + "os.system(\"tar -xvzf {}/output.tar.gz -C {}\".format(tmp_dir, tmp_dir))\n", + "\n", + "if local_mode:\n", + " checkpoint_dir = \"{}/data/checkpoint\".format(tmp_dir)\n", + "else:\n", + " checkpoint_dir = \"{}/checkpoint\".format(tmp_dir)\n", + "\n", + "print(\"Checkpoint directory {}\".format(checkpoint_dir))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if local_mode:\n", + " checkpoint_path = 'file://{}'.format(checkpoint_dir)\n", + " print(\"Local checkpoint file path: {}\".format(checkpoint_path))\n", + "else:\n", + " checkpoint_path = \"s3://{}/{}/checkpoint/\".format(s3_bucket, job_name)\n", + " if not os.listdir(checkpoint_dir):\n", + " raise FileNotFoundError(\"Checkpoint files not found under the path\")\n", + " os.system(\"aws s3 cp --recursive {} {}\".format(checkpoint_dir, checkpoint_path))\n", + " print(\"S3 checkpoint file path: {}\".format(checkpoint_path))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "estimator_eval = RLEstimator(entry_point=\"evaluate-coach.py\",\n", + " source_dir='src',\n", + " dependencies=[\"common/sagemaker_rl\"],\n", + " image_name=image,\n", + " role=role,\n", + " train_instance_type=instance_type,\n", + " train_instance_count=1,\n", + " output_path=s3_output_path,\n", + " base_job_name=job_name_prefix,\n", + " hyperparameters = {\n", + " \"RLCOACH_PRESET\": \"preset-cartpole-ddqnbcq-env\",\n", + " \"evaluate_steps\": 1000\n", + " }\n", + " )\n", + "\n", + "\n", + "estimator_eval.fit({'checkpoint': checkpoint_path})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Batch Transform\n", + "\n", + "As we can see from the above evaluation job, the trained agent gets a total reward of around `200` as compared to a total reward around `25` in our offline dataset. Therefore, we can confirm that the agent has learned a better policy from the off-policy data.\n", + "\n", + "After we get the trained model, we can use it to do SageMaker Batch Transform, where customers can provide large volumes of input state features and get predictions with high throughput." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "\n", + "from sagemaker.tensorflow.serving import Model\n", + "if local_mode:\n", + " sage_session = sagemaker.local.LocalSession()\n", + "\n", + "# Create SageMaker model entity by using model data generated by the estimator \n", + "model = Model(model_data=estimator.model_data,\n", + " sagemaker_session=sage_session,\n", + " role=role)\n", + "\n", + "prefix = \"batch_test\"\n", + "\n", + "# setup input data prefix and output data prefix for batch transform\n", + "batch_input = 's3://{}/{}/{}/input/'.format(s3_bucket, job_name, prefix) # The location of the test dataset\n", + "batch_output = 's3://{}/{}/{}/output/'.format(s3_bucket, job_name, prefix) # The location to store the results of the batch transform job\n", + "print(\"Inputpath for batch transform: {}\".format(batch_input))\n", + "print(\"Outputpath for batch transform: {}\".format(batch_output))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this notebook, we use the states of the environments as input for the Batch Transform." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "file_name = 'env_states_{}.json'.format(int(time.time()))\n", + "# resetting the environments\n", + "vectored_envs.reset_all_envs()\n", + "# dump environment states into jsonlines file\n", + "vectored_envs.dump_environment_states(tmp_dir, file_name)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In order to use SageMaker Batch Transform, we'll need to first upload the input data from local to S3 bucket" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "from pathlib import Path\n", + "\n", + "local_input_file_path = Path(tmp_dir) / file_name\n", + "s3_input_file_path = batch_input + file_name # Path library will remove :// from s3 path\n", + "print(\"Copy file from local path '{}' to s3 path '{}'\".format(local_input_file_path, s3_input_file_path))\n", + "assert os.system(\"aws s3 cp {} {}\".format(local_input_file_path, s3_input_file_path)) == 0\n", + "print(\"S3 batch input file path: {}\".format(s3_input_file_path))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Similar to how we launch a training job on SageMaker, we can initiate a batch transform job either in `Local` mode or `SageMaker` mode." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if local_mode:\n", + " instance_type = 'local'\n", + "else:\n", + " instance_type = \"ml.m4.xlarge\"\n", + "\n", + "transformer = model.transformer(instance_count=1, instance_type=instance_type, output_path=batch_output, assemble_with = 'Line', accept = 'application/jsonlines', strategy='SingleRecord')\n", + "\n", + "transformer.transform(data=batch_input, data_type='S3Prefix', content_type='application/jsonlines', split_type='Line', join_source='Input')\n", + "\n", + "transformer.wait()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After we finished the batch transform job, we can download the prediction output from S3 bucket to local machine." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import subprocess\n", + "\n", + "# get the latest generated output file\n", + "cmd = \"aws s3 ls {} --recursive | sort | tail -n 1\".format(batch_output)\n", + "result = subprocess.check_output(cmd, shell=True).decode(\"utf-8\").split(' ')[-1].strip()\n", + "local_output_file_path = Path(tmp_dir) / f\"{file_name}.out\"\n", + "s3_output_file_path = 's3://{}/{}'.format(s3_bucket,result)\n", + "print(\"Copy file from s3 path '{}' to local path '{}'\".format(s3_output_file_path, local_output_file_path))\n", + "os.system(\"aws s3 cp {} {}\".format(s3_output_file_path, local_output_file_path))\n", + "print(\"S3 batch output file local path: {}\".format(local_output_file_path))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import subprocess\n", + "\n", + "batcmd=\"cat {}\".format(local_output_file_path)\n", + "results = subprocess.check_output(batcmd, shell=True).decode(\"utf-8\").split('\\n')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "results[:10]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this notebook, we use simulated environments to collect rollout data of a random policy. Assuming the updated policy is now deployed, we can use Batch Transform to collect rollout data from this policy. \n", + "\n", + "Here are the steps on how to collect rollout data with Batch Transform:\n", + "1. Use Batch Transform to get action predictions, provided observation features from the live environment at timestep *t*\n", + "2. Deployed agent takes suggested actions against the environment (simulator / real) at timestep *t*\n", + "3. Environment returns new observation features at timestep *t+1*\n", + "4. Return back to step 1. Use Batch Transform to get action predictions at timestep *t+1*\n", + "\n", + "This iterative procedure enables us to collect a set of data that can cover the whole episode, similar to what we've shown at the beginning of the notebook. Once the data is sufficient, we can use these data to kick off a BatchRL training again.\n", + "\n", + "Batch Transform works well when there are multiple episodes interacting with the environments concurrently. One of the typical use cases is email campaign, where each email user is an independent episode interacting with the deployed policy. Batch Transform can concurrently collect rollout data from millions of user context with efficiency. The collected rollout data can then be supplied to Batch RL Training to train a better policy to serve the email users." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Reference\n", + "\n", + "1. Batch Reinforcement Learning with Coach: https://github.com/NervanaSystems/coach/blob/master/tutorials/4.%20Batch%20Reinforcement%20Learning.ipynb\n", + "2. AG Barto, RS Sutton and CW Anderson, \"Neuronlike Adaptive Elements That Can Solve Difficult Learning Control Problem\", IEEE Transactions on Systems, Man, and Cybernetics, 1983.\n", + "3. Thomas, Philip, Georgios Theocharous, and Mohammad Ghavamzadeh. \"High confidence policy improvement.\" International Conference on Machine Learning. 2015.\n", + "4. Jiang, Nan, and Lihong Li. \"Doubly robust off-policy value evaluation for reinforcement learning.\" arXiv preprint arXiv:1511.03722 (2015).\n", + "5. Fujimoto, Scott, David Meger, and Doina Precup. 
\"Off-policy deep reinforcement learning without exploration.\" arXiv preprint arXiv:1812.02900 (2018)\n", + "6. Van Hasselt, Hado, Arthur Guez, and David Silver. \"Deep reinforcement learning with double q-learning.\" Thirtieth AAAI conference on artificial intelligence. 2016." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/reinforcement_learning/rl_cartpole_batch_coach/src/__init__.py b/reinforcement_learning/rl_cartpole_batch_coach/src/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/reinforcement_learning/rl_cartpole_batch_coach/src/evaluate-coach.py b/reinforcement_learning/rl_cartpole_batch_coach/src/evaluate-coach.py new file mode 100644 index 0000000000..0d6875a8c8 --- /dev/null +++ b/reinforcement_learning/rl_cartpole_batch_coach/src/evaluate-coach.py @@ -0,0 +1,91 @@ +from sagemaker_rl.coach_launcher import SageMakerCoachPresetLauncher, CoachConfigurationList +import argparse +import os +import rl_coach +from rl_coach.base_parameters import Frameworks, TaskParameters +from rl_coach.core_types import EnvironmentSteps + + +def inplace_replace_in_file(filepath, old, new): + with open(filepath, 'r') as f: + contents = f.read() + with open(filepath, 'w') as f: + contents = contents.replace(old, new) + f.write(contents) + + +class MyLauncher(SageMakerCoachPresetLauncher): + + def default_preset_name(self): + """This points to a .py file that configures everything about the RL job. + It can be overridden at runtime by specifying the RLCOACH_PRESET hyperparameter. + """ + return 'preset-cartpole-dqn' + + def start_single_threaded(self, task_parameters, graph_manager, args): + """Override to use custom evaluate_steps, instead of infinite steps. Just evaluate. + """ + graph_manager.agent_params.visualization.dump_csv = False # issues with CSV export in evaluation only + graph_manager.create_graph(task_parameters) + graph_manager.evaluate(EnvironmentSteps(args.evaluate_steps)) + graph_manager.close() + + def get_config_args(self, parser): + """Overrides the default CLI parsing. + Sets the configuration parameters for what a SageMaker run should do. + Note, this does not support the "play" mode. + """ + ### Parse Arguments + # first, convert the parser to a Namespace object with all default values. + empty_arg_list = [] + args, _ = parser.parse_known_args(args=empty_arg_list) + parser = self.sagemaker_argparser() + sage_args, unknown = parser.parse_known_args() + + ### Set Arguments + args.preset = sage_args.RLCOACH_PRESET + backend = os.getenv('COACH_BACKEND', 'tensorflow') + args.framework = args.framework = Frameworks[backend] + args.checkpoint_save_dir = None + args.checkpoint_restore_dir = "/opt/ml/input/data/checkpoint" + # Correct TensorFlow checkpoint file (https://github.com/tensorflow/tensorflow/issues/9146) + if backend == "tensorflow": + checkpoint_filepath = os.path.join(args.checkpoint_restore_dir, 'checkpoint') + inplace_replace_in_file(checkpoint_filepath, "/opt/ml/output/data/checkpoint", ".") + # Override experiment_path used for outputs (note CSV not stored, see `start_single_threaded`). 
+        args.experiment_path = '/opt/ml/output/intermediate'
+        rl_coach.logger.experiment_path = '/opt/ml/output/intermediate' # for gifs
+        args.evaluate = True # not actually used, but must be set (see `evaluate_steps`)
+        args.evaluate_steps = sage_args.evaluate_steps
+        args.no_summary = True # so process doesn't hang at end
+        # must be set
+        self.hyperparameters = CoachConfigurationList()
+
+        return args
+
+    def sagemaker_argparser(self):
+        """
+        Expose only the CLI arguments that make sense in the SageMaker context.
+        """
+        parser = argparse.ArgumentParser()
+        parser.add_argument('-p', '--RLCOACH_PRESET',
+                            help="(string) Name of the file with the RLCoach preset",
+                            default=self.default_preset_name(),
+                            type=str)
+        parser.add_argument('--evaluate_steps',
+                            help="(int) Number of evaluation steps to take",
+                            default=1000,
+                            type=int)
+        return parser
+
+    @classmethod
+    def evaluate_main(cls):
+        """Entrypoint for evaluation.
+        Parses command-line arguments and starts the evaluation run.
+        """
+        evaluator = cls()
+        evaluator.launch()
+
+
+if __name__ == '__main__':
+    MyLauncher.evaluate_main()
\ No newline at end of file
diff --git a/reinforcement_learning/rl_cartpole_batch_coach/src/preset-cartpole-ddqnbcq-env.py b/reinforcement_learning/rl_cartpole_batch_coach/src/preset-cartpole-ddqnbcq-env.py new file mode 100644 index 0000000000..c27b10b5c0 --- /dev/null +++ b/reinforcement_learning/rl_cartpole_batch_coach/src/preset-cartpole-ddqnbcq-env.py @@ -0,0 +1,78 @@
+from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters
+from rl_coach.agents.dqn_agent import DQNAgentParameters
+from rl_coach.agents.ddqn_bcq_agent import DDQNBCQAgentParameters
+from rl_coach.architectures.layers import Dense
+from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters, DistributedCoachSynchronizationType
+from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase, CsvDataset
+from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
+from rl_coach.exploration_policies.e_greedy import EGreedyParameters
+from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
+from rl_coach.graph_managers.batch_rl_graph_manager import BatchRLGraphManager
+from rl_coach.graph_managers.graph_manager import ScheduleParameters
+from rl_coach.memories.memory import MemoryGranularity
+from rl_coach.schedules import LinearSchedule
+from rl_coach.memories.episodic import EpisodicExperienceReplayParameters
+from rl_coach.agents.ddqn_bcq_agent import KNNParameters
+from rl_coach.spaces import SpacesDefinition, DiscreteActionSpace, VectorObservationSpace, StateSpace, RewardSpace
+
+# ####################
+# # Graph Scheduling #
+# ####################
+
+schedule_params = ScheduleParameters()
+# 50 epochs (we run train over all the dataset, every epoch) of training
+schedule_params.improve_steps = TrainingSteps(50)
+# we evaluate the model every epoch
+schedule_params.steps_between_evaluation_periods = TrainingSteps(1)
+
+
+#########
+# Agent #
+#########
+# note that we have moved to BCQ, which will help the training to converge better and faster
+agent_params = DDQNBCQAgentParameters()
+agent_params.network_wrappers['main'].batch_size = 128
+agent_params.algorithm.num_steps_between_copying_online_weights_to_target = TrainingSteps(50)
+agent_params.algorithm.discount = 0.99
+
+# NN configuration
+agent_params.network_wrappers['main'].learning_rate = 0.0001
+agent_params.network_wrappers['main'].replace_mse_with_huber_loss = False
+
+# ER - we'll be needing an episodic replay buffer for off-policy evaluation +agent_params.memory = EpisodicExperienceReplayParameters() + +# E-Greedy schedule - there is no exploration in Batch RL. Disabling E-Greedy. +agent_params.exploration.epsilon_schedule = LinearSchedule(initial_value=0, final_value=0, decay_steps=1) +agent_params.exploration.evaluation_epsilon = 0 + +# can use either a kNN or a NN based model for predicting which actions not to max over in the bellman equation +agent_params.algorithm.action_drop_method_parameters = KNNParameters() + + +################# +# Visualization # +################# + +vis_params = VisualizationParameters() +vis_params.dump_gifs = True + +################ +# Environment # +################ +env_params = GymVectorEnvironment(level='CartPole-v0') + +######## +# Test # +######## +preset_validation_params = PresetValidationParameters() +preset_validation_params.test = True +preset_validation_params.min_reward_threshold = 150 +preset_validation_params.max_episodes_to_achieve_reward = 250 + + +graph_manager = BatchRLGraphManager(agent_params=agent_params, + env_params=env_params, + schedule_params=schedule_params, + vis_params=vis_params, + preset_validation_params=preset_validation_params) diff --git a/reinforcement_learning/rl_cartpole_batch_coach/src/preset-cartpole-ddqnbcq.py b/reinforcement_learning/rl_cartpole_batch_coach/src/preset-cartpole-ddqnbcq.py new file mode 100644 index 0000000000..e6c1bfcb15 --- /dev/null +++ b/reinforcement_learning/rl_cartpole_batch_coach/src/preset-cartpole-ddqnbcq.py @@ -0,0 +1,79 @@ +from rl_coach.agents.ddqn_bcq_agent import DDQNBCQAgentParameters +from rl_coach.base_parameters import VisualizationParameters, PresetValidationParameters +from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentSteps, RunPhase, CsvDataset +from rl_coach.graph_managers.batch_rl_graph_manager import BatchRLGraphManager +from rl_coach.graph_managers.graph_manager import ScheduleParameters +from rl_coach.schedules import LinearSchedule +from rl_coach.memories.episodic import EpisodicExperienceReplayParameters +from rl_coach.agents.ddqn_bcq_agent import KNNParameters +from rl_coach.spaces import SpacesDefinition, DiscreteActionSpace, VectorObservationSpace, StateSpace, RewardSpace + +# #################### +# # Graph Scheduling # +# #################### + +schedule_params = ScheduleParameters() +# 50 epochs (we run train over all the dataset, every epoch) of training +schedule_params.improve_steps = TrainingSteps(50) +# we evaluate the model every epoch +schedule_params.steps_between_evaluation_periods = TrainingSteps(1) + +######### +# Agent # +######### +# note that we have moved to BCQ, which will help the training to converge better and faster +agent_params = DDQNBCQAgentParameters() +agent_params.network_wrappers['main'].batch_size = 128 +agent_params.algorithm.num_steps_between_copying_online_weights_to_target = TrainingSteps(50) +agent_params.algorithm.discount = 0.99 + +# NN configuration +agent_params.network_wrappers['main'].learning_rate = 0.0001 +agent_params.network_wrappers['main'].replace_mse_with_huber_loss = False + +# ER - we'll be needing an episodic replay buffer for off-policy evaluation +agent_params.memory = EpisodicExperienceReplayParameters() + +# E-Greedy schedule - there is no exploration in Batch RL. Disabling E-Greedy. 
+agent_params.exploration.epsilon_schedule = LinearSchedule(initial_value=0, final_value=0, decay_steps=1) +agent_params.exploration.evaluation_epsilon = 0 + +# can use either a kNN or a NN based model for predicting which actions not to max over in the bellman equation +agent_params.algorithm.action_drop_method_parameters = KNNParameters() + +########### +# Dataset # +########### +DATATSET_PATH = 'cartpole_dataset.csv' +agent_params.memory = EpisodicExperienceReplayParameters() +agent_params.memory.load_memory_from_file_path = CsvDataset(DATATSET_PATH, is_episodic = True) + +spaces = SpacesDefinition(state=StateSpace({'observation': VectorObservationSpace(shape=4)}), + goal=None, + action=DiscreteActionSpace(2), + reward=RewardSpace(1)) + +################# +# Visualization # +################# + +vis_params = VisualizationParameters() +vis_params.dump_gifs = True + +######## +# Test # +######## +preset_validation_params = PresetValidationParameters() +preset_validation_params.test = True +preset_validation_params.min_reward_threshold = 150 +preset_validation_params.max_episodes_to_achieve_reward = 250 + + +graph_manager = BatchRLGraphManager(agent_params=agent_params, + env_params=None, + spaces_definition=spaces, + schedule_params=schedule_params, + vis_params=vis_params, + reward_model_num_epochs=30, + train_to_eval_ratio=0.4, + preset_validation_params=preset_validation_params) diff --git a/reinforcement_learning/rl_cartpole_batch_coach/src/train-coach.py b/reinforcement_learning/rl_cartpole_batch_coach/src/train-coach.py new file mode 100644 index 0000000000..b6cdf532e2 --- /dev/null +++ b/reinforcement_learning/rl_cartpole_batch_coach/src/train-coach.py @@ -0,0 +1,61 @@ +from sagemaker_rl.coach_launcher import SageMakerCoachPresetLauncher +import shutil + +class MyLauncher(SageMakerCoachPresetLauncher): + + def default_preset_name(self): + """This points to a .py file that configures everything about the RL job. + It can be overridden at runtime by specifying the RLCOACH_PRESET hyperparameter. + """ + return 'preset-acrobot-dqn' + + def map_hyperparameter(self, name, value): + """Here we configure some shortcut names for hyperparameters that we expect to use frequently. + Essentially anything in the preset file can be overridden through a hyperparameter with a name + like "rl.agent_params.algorithm.etc". + """ + # maps from alias (key) to fully qualified coach parameter (value) + mapping = { + "discount": "rl.agent_params.algorithm.discount", + "evaluation_episodes": "rl.evaluation_steps:EnvironmentEpisodes", + "improve_steps": "rl.improve_steps:TrainingSteps" + } + if name in mapping: + self.apply_hyperparameter(mapping[name], value) + else: + super().map_hyperparameter(name, value) + + def _save_tf_model(self): + import tensorflow as tf + ckpt_dir = '/opt/ml/output/data/checkpoint' + model_dir = '/opt/ml/model' + + # Re-Initialize from the checkpoint so that you will have the latest models up. + tf.train.init_from_checkpoint(ckpt_dir, + {'main_level/agent/online/network_0/': 'main_level/agent/online/network_0'}) + tf.train.init_from_checkpoint(ckpt_dir, + {'main_level/agent/online/network_1/': 'main_level/agent/online/network_1'}) + + # Create a new session with a new tf graph. + sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) + sess.run(tf.global_variables_initializer()) # initialize the checkpoint. + + # print([n.name for n in tf.get_default_graph().as_graph_def().node]) + # This is the node that will accept the input. 
+ input_nodes = tf.get_default_graph().get_tensor_by_name('main_level/agent/main/online/' + \ + 'network_0/observation/observation:0') + # This is the node that will produce the output. + output_nodes = tf.get_default_graph().get_operation_by_name('main_level/agent/main/online/' + \ + 'network_0/q_values_head_0/softmax') + # Save the model as a servable model. + tf.saved_model.simple_save(session=sess, + export_dir='model', + inputs={"observation": input_nodes}, + outputs={"policy": output_nodes.outputs[0]}) + # Move to the appropriate folder. + shutil.move('model/', model_dir + '/model/tf-model/00000001/') + # SageMaker will pick it up and upload to the right path. + print("Success") + +if __name__ == '__main__': + MyLauncher.train_main()
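
As a closing note on the rollout-collection loop described at the end of the notebook: once a Batch Transform job has produced JSON-lines predictions for the current environment states, those predictions can be fed back into the vectored environments to step them forward under the new policy. Below is a minimal sketch under the assumption that the transform output has already been downloaded locally; the file name is a placeholder, and the file must contain one prediction line per environment in the format parsed by `VectoredGymEnvironment`.

```python
# Sketch of one iteration of the rollout-collection loop. The JSON-lines file is
# assumed to be the downloaded Batch Transform output, one line per environment,
# with action probabilities under "predictions" (optionally nested in
# "SageMakerOutput"), which is what common/env_utils.py expects.
from env_utils import VectoredGymEnvironment

envs = VectoredGymEnvironment('CartPole-v0', num_of_envs=100)

# Step every environment once using the action probabilities predicted for its current state.
df = envs.collect_rollouts_with_given_action_probs(
    num_steps=1, file_name='batch_transform_predictions.jsonl.out')  # placeholder path

# Dump the new states, run the next Batch Transform on them, and repeat; the
# accumulated dataframe can then seed another round of Batch RL training.
envs.dump_environment_states('/tmp', 'env_states.json')
```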