-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmemory.py
99 lines (75 loc) · 2.92 KB
/
memory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import numpy as np
import torch
import jax
import jax.numpy as jnp
from recordclass import recordclass
# Record type bundling the five per-transition fields of a batch, in the order
# (states, actions, next_states, rewards, dones). It is created with
# `recordclass` (a mutable namedtuple-like factory), so fields of an instance
# can be reassigned after construction.
Transitions = recordclass('Transitions', ('states', 'actions', 'next_states', 'rewards', 'dones'))
class ReplayMemory():
    """Fixed-capacity ring buffer of transitions with uniform random sampling.

    Transitions are stored column-wise in pre-allocated NumPy arrays. Once
    ``maxlen`` transitions have been written, each new transition overwrites
    the oldest slot.

    Attributes:
        len: Total number of transitions added since construction or the last
            ``clear()``. This keeps growing past ``maxlen``; use ``len(memory)``
            for the number of transitions currently available for sampling.
        maxlen: Capacity of the buffer.
        states, next_states: Arrays of shape ``(maxlen, state_dim)``.
        actions: Array of shape ``(maxlen, action_dim)``.
        rewards, dones: Arrays of shape ``(maxlen, 1)``.
    """

    def __init__(self, state_dim, action_dim, maxlen=1000):
        """Allocate zero-filled storage for ``maxlen`` transitions.

        Args:
            state_dim: Number of entries in a state vector.
            action_dim: Number of entries in an action vector.
            maxlen: Buffer capacity; must be positive.

        Raises:
            ValueError: If ``maxlen`` is not positive.
        """
        # A zero capacity would otherwise only surface later as a
        # ZeroDivisionError from the modulo in add_transition().
        if maxlen <= 0:
            raise ValueError(f"maxlen must be positive, got {maxlen!r}")

        self.len = 0
        self.maxlen = maxlen
        self.states = np.zeros((maxlen, state_dim))
        self.actions = np.zeros((maxlen, action_dim))
        self.next_states = np.zeros((maxlen, state_dim))
        self.rewards = np.zeros((maxlen, 1))
        self.dones = np.zeros((maxlen, 1))

    def add_transition(self, state, action, next_state, reward, done):
        """Store one transition, overwriting the oldest one when full.

        Args:
            state: State before the action; must be assignable to a row of
                ``self.states``.
            action: Action taken; must be assignable to a row of
                ``self.actions``.
            next_state: State after the action.
            reward: Scalar reward.
            done: Episode-termination flag (stored as a float).
        """
        # self.len counts every insertion, so the modulo yields the write
        # position inside the ring.
        slot = self.len % self.maxlen
        self.states[slot] = state
        self.actions[slot] = action
        self.next_states[slot] = next_state
        self.rewards[slot] = reward
        self.dones[slot] = done
        self.len += 1

    def sample(self, batch_size):
        """Draw a random batch of stored transitions as float32 tensors.

        Indices are drawn with ``np.random.choice`` using its defaults, i.e.
        uniformly and with replacement, from the filled part of the buffer.

        Args:
            batch_size: Number of transitions to draw.

        Returns:
            A ``Transitions`` record whose fields (states, actions,
            next_states, rewards, dones) are ``torch.float`` tensors with
            ``batch_size`` rows each.

        Raises:
            ValueError: Raised by NumPy when the memory is empty and
                ``batch_size`` is non-zero.
        """
        inds = np.random.choice(len(self), batch_size)
        columns = (
            self.states,
            self.actions,
            self.next_states,
            self.rewards,
            self.dones,
        )
        # Build the record directly from converted tensors instead of
        # patching each field by index afterwards.
        return Transitions(
            *(torch.tensor(column[inds], dtype=torch.float) for column in columns)
        )

    def clear(self):
        """Mark the buffer as empty.

        Only the counter is reset; the stored arrays are left as they are and
        get overwritten by subsequent insertions.
        """
        self.len = 0

    def __len__(self):
        """Return the number of transitions currently available for sampling."""
        return min(self.len, self.maxlen)
# class JaxReplayMemory():
# def __init__(self, state_dim, action_dim, maxlen = 1000):
# self.len = 0
# self.maxlen = maxlen
# self.states = jnp.zeros((maxlen, state_dim))
# self.actions = jnp.zeros((maxlen, action_dim))
# self.next_states = jnp.zeros((maxlen, state_dim))
# self.rewards = jnp.zeros((maxlen, 1))
# self.dones = jnp.zeros((maxlen, 1))
# def add_transition(self, state, action, next_state, reward, done):
# index = self.len % self.maxlen
# self.states = self.states.at[index].set(state)
# self.actions = self.actions.at[index].set(action)
# self.next_states = self.next_states.at[index].set(next_state)
# self.rewards = self.rewards.at[index].set(reward)
# self.dones = self.dones.at[index].set(done)
# self.len += 1
# def sample(self, batch_size, key):
# maxind = min(self.len, self.maxlen)
# inds = jax.random.choice(key, maxind, shape=(batch_size,))
# return Transitions(
# self.states[inds],
# self.actions[inds],
# self.next_states[inds],
# self.rewards[inds],
# self.dones[inds]
# )
# def clear(self):
# self.len = 0
# def __len__(self):
# return min(self.len, self.maxlen)
# class SequentialMemory():
# '''
# ToDo
# '''
# def __init__(self, state_dim, action_dim, maxlen = 1000):
# pass