Commit d4d98a6

Support custom reward

jxx123 committed Feb 27, 2018
1 parent fc0d6a0 commit d4d98a6
Showing 7 changed files with 155 additions and 100 deletions.
43 changes: 42 additions & 1 deletion README.md
@@ -13,7 +13,7 @@ This simulator is a python implementation of the FDA-approved [UVa/Padova Simula

## Main Features
- Simulation environment follows [OpenAI gym](https://github.com/openai/gym) and [rllab](https://github.com/rll/rllab) APIs. It returns observation, reward, done, info at each step, which means the simulator is "reinforcement-learning-ready".
- The reward at each step is `risk[t-1] - risk[t]`. Customized reward is not supported for now. `risk[t]` is the risk index at time `t` defined in this [paper](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2903980/pdf/dia.2008.0138.pdf).
- Supports customized reward function. The reward function takes the blood glucose measurements from the last hour as input. By default, the reward at each step is `risk[t-1] - risk[t]`, where `risk[t]` is the risk index at time `t` defined in this [paper](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2903980/pdf/dia.2008.0138.pdf).
- Supports parallel computing. The simulator simulates multiple patients in parallel using the [pathos multiprocessing package](https://github.com/uqfoundation/pathos) (you are free to turn parallelism off by setting `parallel=False`).
- The simulator provides a random scenario generator (`from simglucose.simulation.scenario_gen import RandomScenario`) and a customized scenario generator (`from simglucose.simulation.scenario import CustomScnenario`). The command-line user interface will guide you through the scenario settings.
- The simulator provides the most basic basal-bolus controller for now. It offers a very simple syntax for implementing your own controller, such as Model Predictive Control, PID control, or reinforcement learning control.
@@ -112,6 +112,7 @@ simulate(sim_time=my_sim_time,
         parallel=True)
```
### OpenAI Gym usage
- Using default reward
```python
import gym

@@ -145,6 +146,44 @@ for t in range(100):
print("Episode finished after {} timesteps".format(t + 1))
break
```
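The collapsed lines above hide the body of the default-reward example. A minimal sketch of that usage, assuming the same registration pattern as the customized-reward example that follows (the environment id and patient name are just the ones used elsewhere in this README):

```python
import gym
from gym.envs.registration import register

# Register an environment for one patient; with no reward_fun given,
# the default risk-difference reward is used.
register(
    id='simglucose-adolescent2-v0',
    entry_point='simglucose.envs:T1DSimEnv',
    kwargs={'patient_name': 'adolescent#002'}
)

env = gym.make('simglucose-adolescent2-v0')

observation = env.reset()
for t in range(100):
    env.render(mode='human')
    # The gym action is a scalar basal insulin rate (bolus is fixed to 0).
    action = env.action_space.sample()
    observation, reward, done, info = env.step(action)
    print(observation)
    print("Reward = {}".format(reward))
    if done:
        print("Episode finished after {} timesteps".format(t + 1))
        break
```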
- Customized reward function
```python
import gym
from gym.envs.registration import register


def custom_reward(BG_last_hour):
    if BG_last_hour[-1] > 180:
        return -1
    elif BG_last_hour[-1] < 70:
        return -2
    else:
        return 1


register(
    id='simglucose-adolescent2-v0',
    entry_point='simglucose.envs:T1DSimEnv',
    kwargs={'patient_name': 'adolescent#002',
            'reward_fun': custom_reward}
)

env = gym.make('simglucose-adolescent2-v0')

reward = 1
done = False

observation = env.reset()
for t in range(200):
    env.render(mode='human')
    action = env.action_space.sample()
    observation, reward, done, info = env.step(action)
    print(observation)
    print("Reward = {}".format(reward))
    if done:
        print("Episode finished after {} timesteps".format(t + 1))
        break
```
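The argument handed to the reward function is the list of CGM readings from the last hour: the environment keeps `window_size = 60 / sample_time` samples (see `simglucose/simulation/env.py` further down), which with the 3-minute CGM sample time used in the test file is 20 readings. Any function of that window works; a hypothetical example that scores time in range over the whole hour rather than only the latest sample:

```python
def time_in_range_reward(BG_last_hour):
    # Fraction of the last hour's CGM readings inside 70-180 mg/dL.
    in_range = [70 <= bg <= 180 for bg in BG_last_hour]
    return sum(in_range) / len(in_range)
```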

### rllab usage
```python
@@ -276,6 +315,8 @@ name = [_f[:-4] for _f in filename]  # get the filename without extension
df = pd.concat([pd.read_csv(f, index_col=0) for f in filename], keys=name)
report(df)
```
## Release Notes, 2/26/2018
- Added support for customized reward functions.
## Release Notes, 1/10/2017
- Added a workaround to select the patient when making the gym environment: register the gym environment by passing `patient_name` in kwargs.
## Release Notes, 1/7/2017
35 changes: 35 additions & 0 deletions examples/custom_reward_function.py
@@ -0,0 +1,35 @@
import gym
from gym.envs.registration import register


def custom_reward(BG_last_hour):
    if BG_last_hour[-1] > 180:
        return -1
    elif BG_last_hour[-1] < 70:
        return -2
    else:
        return 1


register(
    id='simglucose-adolescent2-v0',
    entry_point='simglucose.envs:T1DSimEnv',
    kwargs={'patient_name': 'adolescent#002',
            'reward_fun': custom_reward}
)

env = gym.make('simglucose-adolescent2-v0')

reward = 1
done = False

observation = env.reset()
for t in range(200):
    env.render(mode='human')
    action = env.action_space.sample()
    observation, reward, done, info = env.step(action)
    print(observation)
    print("Reward = {}".format(reward))
    if done:
        print("Episode finished after {} timesteps".format(t + 1))
        break
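This example is a stand-alone script; presumably it can be run directly (e.g. `python examples/custom_reward_function.py` from the repository root) once simglucose is installed.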
4 changes: 2 additions & 2 deletions setup.py
@@ -2,7 +2,7 @@


setup(name='simglucose',
      version='0.1.8',
      version='0.1.9',
      description='A Type-1 Diabetes Simulator as a Reinforcement Learning Environment in OpenAI gym or rllab (python implementation of UVa/Padova Simulator)',
      url='https://github.com/jxx123/simglucose',
      author='Jinyu Xie',
@@ -15,7 +15,7 @@
          'scipy',
          'matplotlib',
          'pathos',
          'gym'
          'gym==0.9.4'
      ],
      include_package_data=True,
      zip_safe=False)
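Pinning `gym==0.9.4` keeps the dependency on the old-style gym API in which environments override `_step`/`_reset` (exactly what `T1DSimEnv` does in the next file); later gym releases changed that interface, so the pin presumably guards against breakage.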
8 changes: 6 additions & 2 deletions simglucose/envs/simglucose_gym_env.py
@@ -22,7 +22,7 @@ class T1DSimEnv(gym.Env):
    '''
    metadata = {'render.modes': ['human']}

    def __init__(self, patient_name=None):
    def __init__(self, patient_name=None, reward_fun=None):
        '''
        patient_name must be 'adolescent#001' to 'adolescent#010',
        or 'adult#001' to 'adult#010', or 'child#001' to 'child#010'
@@ -39,6 +39,7 @@ def __init__(self, patient_name=None):
        scenario = RandomScenario(start_time=start_time, seed=seeds[2])
        pump = InsulinPump.withName('Insulet')
        self.env = _T1DSimEnv(patient, sensor, pump, scenario)
        self.reward_fun = reward_fun

    @staticmethod
    def pick_patient():
@@ -64,7 +65,10 @@ def pick_patient():
    def _step(self, action):
        # This gym only controls basal insulin
        act = Action(basal=action, bolus=0)
        return self.env.step(act)
        if self.reward_fun is None:
            return self.env.step(act)
        else:
            return self.env.step(act, reward_fun=self.reward_fun)

    def _reset(self):
        obs, _, _, _ = self.env.reset()
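When `reward_fun` is left as `None`, the call falls through to the plain `self.env.step(act)`, so the default `risk_diff` reward defined in `simglucose/simulation/env.py` (next file) is used.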
22 changes: 15 additions & 7 deletions simglucose/simulation/env.py
@@ -24,6 +24,15 @@ def Step(observation, reward, done, **kwargs):
logger = logging.getLogger(__name__)


def risk_diff(BG_last_hour):
    if len(BG_last_hour) < 2:
        return 0
    else:
        _, _, risk_current = risk_index([BG_last_hour[-1]], 1)
        _, _, risk_prev = risk_index([BG_last_hour[-2]], 1)
        return risk_prev - risk_current


class T1DSimEnv(object):
    def __init__(self,
                 patient,
@@ -58,7 +67,7 @@ def mini_step(self, action):

        return CHO, insulin, BG, CGM

    def step(self, action):
    def step(self, action, reward_fun=risk_diff):
        '''
        action is a namedtuple with keys: basal, bolus
        '''
@@ -76,7 +85,7 @@ def step(self, action):
            CGM += tmp_CGM / self.sample_time

        # Compute risk index
        horizon = 0
        horizon = 1
        LBGI, HBGI, risk = risk_index([BG], horizon)

        # Record current action
@@ -92,10 +101,9 @@
        self.HBGI_hist.append(HBGI)

        # Compute reward, and decide whether game is over
        if len(self.risk_hist) > 1:
            reward = self.risk_hist[-2] - self.risk_hist[-1]
        else:
            reward = - self.risk_hist[-1]
        window_size = int(60 / self.sample_time)
        BG_last_hour = self.CGM_hist[- window_size:]
        reward = reward_fun(BG_last_hour)
        done = BG < 70 or BG > 350
        obs = Observation(CGM=CGM)
        return Step(observation=obs,
@@ -110,7 +118,7 @@ def _reset(self):
        self.viewer = None

        BG = self.patient.observation.Gsub
        horizon = 0
        horizon = 1
        LBGI, HBGI, risk = risk_index([BG], horizon)
        CGM = self.sensor.measure(self.patient)
        self.time_hist = [self.scenario.start_time]
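For reference, `risk_diff` builds on the blood-glucose risk index from the paper cited in the README. A rough sketch of that computation, assuming the standard Kovatchev constants (the actual implementation lives in simglucose's risk module and may differ in detail):

```python
import numpy as np


def risk_index_sketch(BG, horizon):
    # Symmetrize BG (mg/dL) so hypo- and hyperglycemia get comparable weight,
    # then split the squared risk into low (LBGI) and high (HBGI) components.
    bg = np.asarray(BG[-horizon:], dtype=float)
    fbg = 1.509 * (np.log(bg) ** 1.084 - 5.381)
    rl = 10 * fbg[fbg < 0] ** 2   # hypoglycemia branch
    rh = 10 * fbg[fbg > 0] ** 2   # hyperglycemia branch
    LBGI = rl.mean() if rl.size else 0.0
    HBGI = rh.mean() if rh.size else 0.0
    return LBGI, HBGI, LBGI + HBGI
```

`risk_diff` then returns `risk_prev - risk_current`, which is positive when the latest CGM reading carries less risk than the previous one.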
88 changes: 0 additions & 88 deletions simglucose/simulation/sim_engine.py
@@ -1,8 +1,4 @@
from simglucose.controller.base import Action
import matplotlib.pyplot as plt
import logging
import matplotlib.dates as mdates
from datetime import timedelta
import time
import os

@@ -31,8 +27,6 @@ def __init__(self,
        self.path = path

    def simulate(self):
        # basal = self.env.patient._params.u2ss * self.env.patient._params.BW / 6000
        # action = Action(basal=basal, bolus=0)
        obs, reward, done, info = self.env.reset()
        tic = time.time()
        while self.env.time < self.env.scenario.start_time + self.sim_time:
@@ -79,85 +73,3 @@ def batch_sim(sim_instances, parallel=False):
    toc = time.time()
    print('Simulation took {} sec.'.format(toc - tic))
    return results


if __name__ == '__main__':
    from simglucose.simulation.env import T1DSimEnv
    from simglucose.controller.basal_bolus_ctrller import BBController
    from simglucose.sensor.cgm import CGMSensor
    from simglucose.actuator.pump import InsulinPump
    from simglucose.patient.t1dpatient import T1DPatient
    from simglucose.simulation.scenario_gen import RandomScenario

    path = './results'

    patient1 = T1DPatient.withName('adolescent#001')
    sensor1 = CGMSensor.withName('Dexcom', seed=1)
    pump1 = InsulinPump.withName('Insulet')
    scenario1 = RandomScenario(seed=1)
    env1 = T1DSimEnv(patient1, sensor1, pump1, scenario1)
    controller1 = BBController()

    s1 = SimObj(env1, controller1, sim_time=timedelta(days=1), path=path)
    s1.animate = False
    # s1.animate = True
    # s1.set_controller_kwargs(patient_name=patient1.name,
    #                          sample_time=s1.env.sample_time)

    # s1.simulate()

    patient2 = T1DPatient.withName('adolescent#002')
    sensor2 = CGMSensor.withName('Dexcom', seed=1)
    pump2 = InsulinPump.withName('Insulet')
    scenario2 = RandomScenario(seed=1)
    env2 = T1DSimEnv(patient2, sensor2, pump2, scenario2)
    controller2 = BBController()

    s2 = SimObj(env2, controller2, sim_time=timedelta(days=1), path=path)
    # s2.set_controller_kwargs(patient_name=patient2.name,
    #                          sample_time=s2.env.sample_time)
    s2.animate = False

    sim_objects = [s1, s2]

    nodes = 2
    p = Pool(nodes=nodes)
    tic = time.time()
    results = p.map(sim, sim_objects)
    toc = time.time()
    print('{} workers took {} sec.'.format(nodes, toc - tic))
    print(results)

    # with Pool(processes=nodes) as p:
    #     tic = time.time()
    #     p.map(sim, sim_objects)
    #     toc = time.time()
    #     print('{} workers took {} sec.'.format(nodes, toc - tic))

    # tic = time.time()
    # map(simulation, sim_objects)
    # toc = time.time()
    # print('Serial took {} sec.'.format(toc - tic))

    # for s in sim_objects:
    #     s.simulate()
    # processes = [Process(target=simulation, args=(s,)) for s in sim_objects]
    # for p in processes:
    #     p.start()
    #     p.join()

    # # for p in processes:
    # #     p.join()

    # # for s in sim_objects:
    # #     p = Process(target=s.simulate, args=())
    # #     p.start()

    # with Pool(processes=2) as pool:
    #     # pool.map(simulation, sim_objects)
    #     res = [pool.apply_async(simulation, (s,)) for s in sim_objects]
    #     # res = pool.apply_async(simulation, ())
    #     # print(res.get())

    #     for r in res:
    #         print(r.get())
55 changes: 55 additions & 0 deletions tests/test_reward_fun.py
@@ -0,0 +1,55 @@
import gym
import unittest
from simglucose.controller.basal_bolus_ctrller import BBController


def custom_reward(BG_last_hour):
    if BG_last_hour[-1] > 180:
        return -1
    elif BG_last_hour[-1] < 70:
        return -2
    else:
        return 1


class TestCustomReward(unittest.TestCase):
    def test_custom_reward(self):
        from gym.envs.registration import register
        register(
            id='simglucose-adolescent2-v0',
            entry_point='simglucose.envs:T1DSimEnv',
            kwargs={'patient_name': 'adolescent#002',
                    'reward_fun': custom_reward}
        )

        env = gym.make('simglucose-adolescent2-v0')
        ctrller = BBController()

        reward = 1
        done = False
        info = {'sample_time': 3,
                'patient_name': 'adolescent#002',
                'meal': 0}

        observation = env.reset()
        for t in range(200):
            env.render(mode='human')
            print(observation)
            # action = env.action_space.sample()
            ctrl_action = ctrller.policy(observation, reward, done, **info)
            action = ctrl_action.basal + ctrl_action.bolus
            observation, reward, done, info = env.step(action)
            print("Reward = {}".format(reward))
            if observation.CGM > 180:
                self.assertEqual(reward, -1)
            elif observation.CGM < 70:
                self.assertEqual(reward, -2)
            else:
                self.assertEqual(reward, 1)
            if done:
                print("Episode finished after {} timesteps".format(t + 1))
                break


if __name__ == '__main__':
    unittest.main()
