Skip to content

Commit

Permalink
support openAI gym.make
Browse files Browse the repository at this point in the history
  • Loading branch information
jxx123 committed Jan 8, 2018
1 parent 7cfae30 commit a44ed34
Show file tree
Hide file tree
Showing 16 changed files with 1,349 additions and 1,102 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Compiled python modules.
*.pyc

# results folder
/results/

# Setuptools distribution folder.
/dist/

Expand Down
7 changes: 4 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@


setup(name='simglucose',
version='0.1.4',
description='A Type-1 Diabetes Simulator as a Reinforcement Learning Environment',
version='0.1.5',
description='A Type-1 Diabetes Simulator as a Reinforcement Learning Environment in OpenAI gym or rllab (python implementation of UVa/Padova Simulator)',
url='https://github.com/jxx123/simglucose',
author='Jinyu Xie',
author_email='[email protected]',
Expand All @@ -14,7 +14,8 @@
'numpy',
'scipy',
'matplotlib',
'pathos'
'pathos',
'gym'
],
include_package_data=True,
zip_safe=False)
6 changes: 6 additions & 0 deletions simglucose/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from gym.envs.registration import register

# Register the gym wrapper under the id 'simglucose-v0' at package import
# time, so that gym.make('simglucose-v0') can build the environment from the
# entry point below.
register(
id='simglucose-v0',
entry_point='simglucose.envs:T1DSimEnv',
)
9 changes: 5 additions & 4 deletions simglucose/actuator/pump.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pandas as pd
import pkg_resources
import logging
import numpy as np

INSULIN_PUMP_PARA_FILE = pkg_resources.resource_filename(
'simglucose', 'params/pump_params.csv')
Expand All @@ -21,17 +22,17 @@ def withName(cls, name):

def bolus(self, amount):
bol = amount * self.U2PMOL # convert from U/min to pmol/min
bol = round(bol / self._params['inc_bolus']
) * self._params['inc_bolus']
bol = np.round(bol / self._params['inc_bolus']
) * self._params['inc_bolus']
bol = bol / self.U2PMOL # convert from pmol/min to U/min
bol = min(bol, self._params['max_bolus'])
bol = max(bol, self._params['min_bolus'])
return bol

def basal(self, amount):
bas = amount * self.U2PMOL # convert from U/min to pmol/min
bas = round(bas / self._params['inc_basal']
) * self._params['inc_basal']
bas = np.round(bas / self._params['inc_basal']
) * self._params['inc_basal']
bas = bas / self.U2PMOL # convert from pmol/min to U/min
bas = min(bas, self._params['max_basal'])
bas = max(bas, self._params['min_basal'])
Expand Down
3 changes: 2 additions & 1 deletion simglucose/controller/basal_bolus_ctrller.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ def __init__(self, target=140):
def policy(self, observation, reward, done, **kwargs):
sample_time = kwargs.get('sample_time', 1)
pname = kwargs.get('patient_name')
meal = kwargs.get('meal')

action = self._bb_policy(
pname,
observation.CHO,
meal,
observation.CGM,
sample_time)
return action
Expand Down
1 change: 1 addition & 0 deletions simglucose/envs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from simglucose.envs.simglucose_gym_env import T1DSimEnv
89 changes: 89 additions & 0 deletions simglucose/envs/simglucose_gym_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from simglucose.simulation.env import T1DSimEnv as _T1DSimEnv
from simglucose.patient.t1dpatient import T1DPatient
from simglucose.sensor.cgm import CGMSensor
from simglucose.actuator.pump import InsulinPump
from simglucose.simulation.scenario_gen import RandomScenario
from simglucose.controller.base import Action
import pandas as pd
import numpy as np
import pkg_resources
import gym
from gym import error, spaces, utils
from gym.utils import seeding
from datetime import datetime

PATIENT_PARA_FILE = pkg_resources.resource_filename(
'simglucose', 'params/vpatient_params.csv')


class T1DSimEnv(gym.Env):
    '''
    A wrapper of simglucose.simulation.env.T1DSimEnv to support gym API.

    The wrapped simulator is kept in ``self.env``. Only the basal insulin
    rate is exposed as the action; the bolus is always set to 0.
    '''
    metadata = {'render.modes': ['human']}

    def __init__(self, patient_name=None):
        '''
        patient_name must be 'adolescent#001' to 'adolescent#010',
        or 'adult#001' to 'adult#010', or 'child#001' to 'child#010'.

        When patient_name is None, 'adolescent#001' is used.
        '''
        if patient_name is None:
            # The default is hard coded instead of being chosen
            # interactively with pick_patient(): gym has some interesting
            # error when choosing the patient at construction time.
            patient_name = 'adolescent#001'
        patient = T1DPatient.withName(patient_name)
        sensor = CGMSensor.withName('Dexcom')
        scenario = RandomScenario(start_time=datetime(2018, 1, 1, 0, 0, 0))
        pump = InsulinPump.withName('Insulet')
        self.env = _T1DSimEnv(patient, sensor, pump, scenario)

    @staticmethod
    def pick_patient():
        '''
        Prompt on stdin until a valid entry of the patient table is chosen.

        Returns the selected 1-based row number of PATIENT_PARA_FILE (an
        int), not the patient name.
        '''
        patient_params = pd.read_csv(PATIENT_PARA_FILE)
        n_patients = len(patient_params)
        while True:
            print('Select patient:')
            for number, name in enumerate(patient_params['Name'], start=1):
                print('[{0}] {1}'.format(number, name))
            try:
                select = int(input('>>> '))
            except ValueError:
                print('Please input a number.')
                continue

            if select < 1 or select > n_patients:
                print('Please input 1 to {}'.format(n_patients))
                continue

            return select

    def _step(self, action):
        '''
        Apply `action` as the basal insulin rate and advance the simulator.

        Returns whatever the wrapped environment's step() returns.
        '''
        # This gym only controls basal insulin
        act = Action(basal=action, bolus=0)
        return self.env.step(act)

    def _reset(self):
        '''
        Reset the wrapped simulator and return the initial observation.
        '''
        # The wrapped reset() returns a full Step record (observation,
        # reward, done, info), while gym expects reset to return only the
        # observation.
        return self.env.reset().observation

    def _seed(self, seed=None):
        '''
        Seed the sensor and the scenario of the wrapped environment.

        Returns the list [sensor seed, scenario seed].
        '''
        # Only the derived seed is used; the generator itself is discarded.
        _, seed1 = seeding.np_random(seed=seed)
        # Derive a random seed. This gets passed as a uint, but gets
        # checked as an int elsewhere, so we need to keep it below
        # 2**31.
        seed2 = seeding.hash_seed(seed1 + 1) % 2**31
        self.env.sensor.seed = seed1
        self.env.scenario.seed = seed2
        return [seed1, seed2]

    def _render(self, mode='human', close=False):
        '''
        Forward rendering to the wrapped environment.

        `mode` is accepted for gym compatibility but is not used.
        '''
        self.env.render(close=close)

    @property
    def action_space(self):
        '''One-dimensional box from 0 to the pump's 'max_basal' parameter.'''
        ub = self.env.pump._params['max_basal']
        return spaces.Box(low=0, high=ub, shape=(1,))

    @property
    def observation_space(self):
        '''One-dimensional non-negative box with no upper bound.'''
        return spaces.Box(low=0, high=np.inf, shape=(1,))
11 changes: 10 additions & 1 deletion simglucose/sensor/cgm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self, params, seed=None):
self.name = params.Name
self.sample_time = params.sample_time
self.seed = seed
self.reset()
self._last_CGM = 0

@classmethod
def withName(cls, name, **kwargs):
Expand All @@ -37,6 +37,15 @@ def measure(self, patient):
# Zero-Order Hold
return self._last_CGM

@property
def seed(self):
return self._seed

@seed.setter
def seed(self, seed):
self._seed = seed
self._noise_generator = CGMNoise(self._params, seed=seed)

def reset(self):
logger.debug('Resetting CGM sensor ...')
self._noise_generator = CGMNoise(self._params, seed=self.seed)
Expand Down
107 changes: 36 additions & 71 deletions simglucose/simulation/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from datetime import timedelta
import logging
from collections import namedtuple
from simglucose.simulation.rendering import Viewer

rllab = True
try:
Expand Down Expand Up @@ -36,7 +37,7 @@ def step(self, action):
def reset(self):
raise NotImplementedError

Observation = namedtuple('Observation', ['CHO', 'CGM'])
Observation = namedtuple('Observation', ['CGM'])
logger = logging.getLogger(__name__)


Expand All @@ -50,21 +51,7 @@ def __init__(self,
self.sensor = sensor
self.pump = pump
self.scenario = scenario
self.sample_time = self.sensor.sample_time

# Initial Recording
BG = self.patient.observation.Gsub
horizon = 0
LBGI, HBGI, risk = risk_index([BG], horizon)
CGM = self.sensor.measure(self.patient)
self.time_hist = [self.scenario.start_time]
self.BG_hist = [BG]
self.CGM_hist = [CGM]
self.risk_hist = [risk]
self.LBGI_hist = [LBGI]
self.HBGI_hist = [HBGI]
self.CHO_hist = []
self.insulin_hist = []
self._reset()

@property
def time(self):
Expand Down Expand Up @@ -122,25 +109,22 @@ def step(self, action):
self.HBGI_hist.append(HBGI)

# Compute reward, and decide whether game is over
# reward = - np.log(risk)
# reward = 10 - risk
if len(self.risk_hist) > 1:
reward = self.risk_hist[-2] - self.risk_hist[-1]
else:
reward = - self.risk_hist[-1]
done = BG < 70 or BG > 350
obs = Observation(CHO=CHO, CGM=CGM)
obs = Observation(CGM=CGM)
return Step(observation=obs,
reward=reward,
done=done,
sample_time=self.sample_time,
patient_name=self.patient.name)
patient_name=self.patient.name,
meal=CHO)

def reset(self):
self.patient.reset()
self.sensor.reset()
self.pump.reset()
self.scenario.reset()
def _reset(self):
self.sample_time = self.sensor.sample_time
self.viewer = None

BG = self.patient.observation.Gsub
horizon = 0
Expand All @@ -155,12 +139,27 @@ def reset(self):
self.CHO_hist = []
self.insulin_hist = []

def reset(self):
self.patient.reset()
self.sensor.reset()
self.pump.reset()
self.scenario.reset()
self._reset()
CGM = self.sensor.measure(self.patient)
obs = Observation(CGM=CGM)
return Step(observation=obs,
reward=0,
done=False,
sample_time=self.sample_time,
patient_name=self.patient.name,
meal=0)

@property
def action_space(self):
if rllab:
ub = self.pump._params['max_basal'] + \
self.pump._params['max_bolus']
return Box(low=0, high=ub, shape=(1,))
ub = np.array([self.pump._params['max_basal'],
self.pump._params['max_bolus']])
return Box(low=np.array([0, 0]), high=ub)
else:
pass

Expand All @@ -171,51 +170,17 @@ def observation_space(self):
else:
pass

def render(self, axes, lines):
logger.info('Rendering ...')

lines[0].set_xdata(self.time_hist)
lines[0].set_ydata(self.BG_hist)

lines[1].set_xdata(self.time_hist)
lines[1].set_ydata(self.CGM_hist)

axes[0].draw_artist(axes[0].patch)
axes[0].draw_artist(lines[0])
axes[0].draw_artist(lines[1])

adjust_ylim(axes[0], min(min(self.BG_hist), min(self.CGM_hist)), max(
max(self.BG_hist), max(self.CGM_hist)))

lines[2].set_xdata(self.time_hist[:-1])
lines[2].set_ydata(self.CHO_hist)

axes[1].draw_artist(axes[1].patch)
axes[1].draw_artist(lines[2])

adjust_ylim(axes[1], min(self.CHO_hist), max(self.CHO_hist))

lines[3].set_xdata(self.time_hist[:-1])
lines[3].set_ydata(self.insulin_hist)

axes[2].draw_artist(axes[2].patch)
axes[2].draw_artist(lines[3])
adjust_ylim(axes[2], min(self.insulin_hist), max(self.insulin_hist))

lines[4].set_xdata(self.time_hist)
lines[4].set_ydata(self.LBGI_hist)

lines[5].set_xdata(self.time_hist)
lines[5].set_ydata(self.HBGI_hist)
def render(self, close=False):
if close:
if self.viewer is not None:
self.viewer.close()
self.viewer = None
return

lines[6].set_xdata(self.time_hist)
lines[6].set_ydata(self.risk_hist)
if self.viewer is None:
self.viewer = Viewer(self.scenario.start_time, self.patient.name)

axes[3].draw_artist(axes[3].patch)
axes[3].draw_artist(lines[4])
axes[3].draw_artist(lines[5])
axes[3].draw_artist(lines[6])
adjust_ylim(axes[3], min(self.risk_hist), max(self.risk_hist))
self.viewer.render(self.show_history())

def show_history(self):
df = pd.DataFrame()
Expand Down
Loading

0 comments on commit a44ed34

Please sign in to comment.