import itertools
import time
import numpy as np
import cv2
from moviepy.editor import VideoClip
# Layout parameters used by the ``__main__`` script when it generates a random world.
# Grid size in cells (rows, then columns); passed to both gen_world_config and GridWorld.
WORLD_HEIGHT = 4
WORLD_WIDTH = 4
# Number of wall coordinates drawn, as a fraction of the total cell count (h * w).
WALL_FRAC = .2
# Number of win / lose coordinates drawn. All coordinates are drawn independently,
# so draws may coincide and the number of distinct cells of each kind can be smaller.
NUM_WINS = 5
NUM_LOSE = 10
class GridWorld:
    """Grid world explored by a tabular Q-learning agent, with optional video rendering.

    Every cell of ``self.world`` stores the reward for stepping onto it. Cell
    kinds are recognised by comparing that stored value with ``default_reward``,
    ``wall_penalty``, ``win_reward`` and ``lose_reward``, so the four values
    must be pairwise distinct for the layout to be interpreted as intended.

    The agent position is ``self.bot_rc`` (row, column). ``self.q_values`` has
    shape ``(height, width, 5)``, one entry per cell and action in the order
    UP, LEFT, RIGHT, DOWN, NOOP.

    Args:
        world_height, world_width: Grid size in cells.
        discount_factor: Weight of the bootstrapped future value in the Q-update.
        default_reward: Reward stored in ordinary (walkable, non-terminal) cells.
        wall_penalty: Value stored in wall cells; also the reward returned when
            a move is blocked by a wall or by the grid border.
        win_reward, lose_reward: Rewards stored in the terminal cells.
        viz: When true, a canvas is drawn with OpenCV and a moviepy ``VideoClip``
            is created whose frames are produced by :meth:`make_frame`.
        patch_side: Side length of one rendered cell, in pixels.
        grid_thickness: Gap between rendered cells (and highlight thickness), in pixels.
        arrow_thickness: Line thickness of the policy arrows, in pixels.
        wall_locs, win_locs, lose_locs: Sequences of ``(row, col)`` pairs; may be
            empty. On overlap, win cells override lose cells, which override walls.
        start_loc: ``(row, col)`` used by :meth:`reset` when ``random_respawn``
            is switched off.
        reset_prob: Probability of respawning the agent after each action.

    Raises:
        ValueError: If no cell holds ``default_reward``, i.e. there is nowhere
            to spawn the agent.
    """

    def __init__(self, world_height=3, world_width=4, discount_factor=.5, default_reward=-.5, wall_penalty=-.6,
                 win_reward=5., lose_reward=-10., viz=True, patch_side=120, grid_thickness=2, arrow_thickness=3,
                 wall_locs=((1, 1), (1, 2)), win_locs=((0, 3),), lose_locs=((1, 3),), start_loc=(0, 0),
                 reset_prob=.2):
        self.world = np.ones([world_height, world_width]) * default_reward
        self.reset_prob = reset_prob
        self.world_height = world_height
        self.world_width = world_width
        self.wall_penalty = wall_penalty
        self.win_reward = win_reward
        self.lose_reward = lose_reward
        self.default_reward = default_reward
        self.discount_factor = discount_factor
        self.patch_side = patch_side
        self.grid_thickness = grid_thickness
        self.arrow_thickness = arrow_thickness
        self.wall_locs = self._as_loc_array(wall_locs)
        self.win_locs = self._as_loc_array(win_locs)
        self.lose_locs = self._as_loc_array(lose_locs)
        self.at_terminal_state = False
        # Set by check_terminal_state() when a move lands on a win/lose cell;
        # action() clears it before every move and reads it afterwards.
        self._hit_terminal = False
        self.auto_reset = True
        self.random_respawn = True
        self.step = 0
        self.viz_canvas = None
        self.viz = viz
        self.path_color = (128, 128, 128)
        self.wall_color = (0, 255, 0)
        self.win_color = (0, 0, 255)
        self.lose_color = (255, 0, 0)

        # Assignment order defines precedence on overlapping coordinates:
        # walls first, then lose cells, then win cells.
        self.world[self.wall_locs[:, 0], self.wall_locs[:, 1]] = self.wall_penalty
        self.world[self.lose_locs[:, 0], self.lose_locs[:, 1]] = self.lose_reward
        self.world[self.win_locs[:, 0], self.win_locs[:, 1]] = self.win_reward

        # All ordinary cells, in row-major order, are valid spawn points.
        self.spawn_locs = np.argwhere(self.world == self.default_reward)
        self.start_state = np.array(start_loc)
        self.bot_rc = None
        self.reset()

        self.actions = [self.up, self.left, self.right, self.down, self.noop]
        self.action_labels = ['UP', 'LEFT', 'RIGHT', 'DOWN', 'NOOP']
        # Every action starts with the same value, 1 / number_of_actions.
        self.q_values = np.ones([self.world.shape[0], self.world.shape[1], len(self.actions)]) * 1. / len(self.actions)

        if self.viz:
            self.init_grid_canvas()
            self.video_out_fpath = 'shm_dqn_gridsolver-' + str(time.time()) + '.mp4'
            self.clip = VideoClip(self.make_frame, duration=15)

    @staticmethod
    def _as_loc_array(locs):
        """Return ``locs`` as an array that can be column-indexed even when it is empty."""
        arr = np.array(locs)
        if arr.size == 0:
            # np.array([]) has shape (0,), on which arr[:, 0] raises IndexError.
            return np.zeros((0, 2), dtype=int)
        return arr

    def make_frame(self, t):
        """Advance the simulation by one action and return the rendered frame.

        ``t`` (the frame timestamp supplied by the clip) is not used.
        """
        self.action()
        return self.highlight_loc(self.viz_canvas, self.bot_rc[0], self.bot_rc[1])

    def check_terminal_state(self):
        """Flag a win/lose cell under the agent and respawn if ``auto_reset`` is on."""
        cell_value = self.world[self.bot_rc[0], self.bot_rc[1]]
        if cell_value == self.lose_reward or cell_value == self.win_reward:
            self.at_terminal_state = True
            self._hit_terminal = True
            if self.auto_reset:
                self.reset()

    def reset(self):
        """Respawn the agent and clear ``at_terminal_state``.

        The agent goes to a uniformly random spawn cell when ``random_respawn``
        is set, otherwise to ``start_state``.

        Raises:
            ValueError: If a random respawn is requested but there is no spawn cell.
        """
        if self.random_respawn:
            if len(self.spawn_locs) == 0:
                raise ValueError('GridWorld has no free cell to spawn in: '
                                 'every cell is a wall, win or lose cell')
            spawn_idx = np.random.choice(np.arange(len(self.spawn_locs)))
            self.bot_rc = self.spawn_locs[spawn_idx].copy()
        else:
            self.bot_rc = self.start_state.copy()
        self.at_terminal_state = False

    def _move(self, d_row, d_col, action_idx):
        """Try to shift the agent by ``(d_row, d_col)``; return ``(reward, action_idx)``.

        A move that would leave the grid or enter a wall cell leaves the agent
        in place and yields ``wall_penalty``.
        """
        new_r = self.bot_rc[0] + d_row
        new_c = self.bot_rc[1] + d_col
        inside = 0 <= new_r < self.world.shape[0] and 0 <= new_c < self.world.shape[1]
        if not inside or self.world[new_r, new_c] == self.wall_penalty:
            return self.wall_penalty, action_idx
        self.bot_rc[0] = new_r
        self.bot_rc[1] = new_c
        # Read the reward before check_terminal_state(), which may respawn the agent.
        reward = self.world[new_r, new_c]
        self.check_terminal_state()
        return reward, action_idx

    def up(self):
        """Move one row up (action index 0); return ``(reward, action_idx)``."""
        return self._move(-1, 0, 0)

    def left(self):
        """Move one column left (action index 1); return ``(reward, action_idx)``."""
        return self._move(0, -1, 1)

    def right(self):
        """Move one column right (action index 2); return ``(reward, action_idx)``."""
        return self._move(0, 1, 2)

    def down(self):
        """Move one row down (action index 3); return ``(reward, action_idx)``."""
        return self._move(1, 0, 3)

    def noop(self):
        """Stay in place (action index 4) and collect the current cell's reward."""
        action_idx = 4
        reward = self.world[self.bot_rc[0], self.bot_rc[1]]
        self.check_terminal_state()
        return reward, action_idx

    def qvals2probs(self, q_vals, epsilon=1e-4):
        """Turn Q-values into sampling probabilities.

        The values are shifted so the smallest becomes ``epsilon`` (keeping
        every action reachable) and then normalised to sum to one.
        """
        action_probs = q_vals - q_vals.min() + epsilon
        action_probs = action_probs / action_probs.sum()
        return action_probs

    def action(self):
        """Sample one action from the current cell's Q-values, apply it and learn from it.

        After the Q-update the visited cell is redrawn (when ``viz`` is on) and
        the agent is respawned with probability ``reset_prob``.

        Raises:
            SystemExit: If the agent is at a terminal state, which can only
                happen when ``auto_reset`` has been switched off.
        """
        if self.at_terminal_state:
            print('At terminal state, please call reset()')
            # Same effect as the interactive helper exit(), without relying on
            # the site module having installed it.
            raise SystemExit

        start_r, start_c = self.bot_rc[0], self.bot_rc[1]
        q_vals = self.q_values[start_r, start_c]
        action_probs = self.qvals2probs(q_vals)

        self._hit_terminal = False
        chosen = np.random.choice(len(self.actions), p=action_probs)
        reward, action_idx = self.actions[chosen]()

        # 10e9 == 1e10, so the learning rate stays very close to 1 for any
        # realistic number of steps.
        alpha = np.exp(-self.step / 10e9)
        self.step += 1

        if self._hit_terminal:
            # The episode ended on this move. With auto_reset the agent has
            # already been respawned, so bot_rc is an unrelated cell and must
            # not be bootstrapped from; a terminal state has no future value.
            future_value = 0.
        else:
            future_value = self.q_values[self.bot_rc[0], self.bot_rc[1]].max()

        qv = (1 - alpha) * q_vals[action_idx] + alpha * (reward + self.discount_factor * future_value)
        self.q_values[start_r, start_c, action_idx] = qv

        if self.viz:
            self.update_viz(start_r, start_c)
        if np.random.rand() < self.reset_prob:
            self.reset()

    def _patch_bounds(self, i, j):
        """Return ``(starty, endy, startx, endx)`` pixel bounds of cell ``(i, j)``."""
        pitch = self.patch_side + self.grid_thickness
        # Cast to built-in int so slice bounds and OpenCV point tuples never
        # carry NumPy scalar types.
        starty = int(i) * pitch
        startx = int(j) * pitch
        return starty, starty + self.patch_side, startx, startx + self.patch_side

    def highlight_loc(self, viz_in, i, j):
        """Return a copy of ``viz_in`` with a white outline around cell ``(i, j)``."""
        starty, endy, startx, endx = self._patch_bounds(i, j)
        viz = viz_in.copy()
        cv2.rectangle(viz, (startx, starty), (endx, endy), (255, 255, 255), thickness=self.grid_thickness)
        return viz

    def update_viz(self, i, j):
        """Redraw cell ``(i, j)`` on ``self.viz_canvas``.

        Ordinary cells get an arrow summarising the action probabilities;
        wall, win and lose cells are filled with their flat colour.
        """
        starty, endy, startx, endx = self._patch_bounds(i, j)
        cell_value = self.world[i, j]

        patch = np.zeros([self.patch_side, self.patch_side, 3], dtype=np.uint8)
        if cell_value == self.default_reward:
            patch[:, :, :] = self.path_color
        elif cell_value == self.wall_penalty:
            patch[:, :, :] = self.wall_color
        elif cell_value == self.win_reward:
            patch[:, :, :] = self.win_color
        elif cell_value == self.lose_reward:
            patch[:, :, :] = self.lose_color

        if cell_value != self.default_reward:
            self.viz_canvas[starty:endy, startx:endx] = patch
            return

        action_probs = self.qvals2probs(self.q_values[i, j])
        # Net horizontal / vertical preference: RIGHT minus LEFT, UP minus DOWN.
        x_component = action_probs[2] - action_probs[1]
        y_component = action_probs[0] - action_probs[3]
        # The more probable NOOP is, the fainter the arrow is blended in.
        magnitude = 1. - action_probs[-1]
        s = self.patch_side // 2
        vx = s + int(s * x_component)
        # Image rows grow downwards, so "up" subtracts from the y coordinate.
        vy = s - int(s * y_component)
        arrow_canvas = np.zeros_like(patch)
        cv2.arrowedLine(arrow_canvas, (s, s), (vx, vy), (255, 255, 255), thickness=self.arrow_thickness,
                        tipLength=0.5)
        gridbox = (magnitude * arrow_canvas + (1 - magnitude) * patch).astype(np.uint8)
        self.viz_canvas[starty:endy, startx:endx] = gridbox

    def init_grid_canvas(self):
        """Allocate ``self.viz_canvas`` and draw every cell once."""
        org_h, org_w = self.world_height, self.world_width
        viz_w = (self.patch_side * org_w) + (self.grid_thickness * (org_w - 1))
        viz_h = (self.patch_side * org_h) + (self.grid_thickness * (org_h - 1))
        self.viz_canvas = np.zeros([viz_h, viz_w, 3], dtype=np.uint8)
        for i in range(org_h):
            for j in range(org_w):
                self.update_viz(i, j)

    def solve(self, max_steps=None):
        """Run the learning loop.

        With ``viz`` on, the clip is written to ``self.video_out_fpath``; each
        rendered frame performs one action and ``max_steps`` is ignored.
        Without ``viz``, :meth:`action` is called ``max_steps`` times, or
        forever when ``max_steps`` is ``None`` (the original behaviour).
        """
        if self.viz:
            self.clip.write_videofile(self.video_out_fpath, fps=460)
            return
        if max_steps is None:
            while True:
                self.action()
        for _ in range(max_steps):
            self.action()
def gen_world_config(h, w, wall_frac=.5, num_wins=2, num_lose=3):
    """Draw random wall, win and lose coordinates for an ``h`` x ``w`` grid.

    Coordinates are drawn independently and uniformly, so the same cell may be
    drawn more than once, within one group or across groups.

    Args:
        h, w: Grid height and width in cells.
        wall_frac: Number of wall coordinates to draw, as a fraction of ``h * w``
            (truncated to an integer).
        num_wins: Number of win coordinates to draw.
        num_lose: Number of lose coordinates to draw.

    Returns:
        Tuple ``(wall_locs, win_locs, lose_locs)`` of integer arrays of shape
        ``(count, 2)`` holding ``(row, col)`` pairs with ``0 <= row < h`` and
        ``0 <= col < w``.
    """
    num_wall_blocks = int(wall_frac * h * w)

    def random_locs(count):
        # rand() is in [0, 1); scaling by [h, w] and truncating yields valid indices.
        # The built-in int replaces the np.int alias, which NumPy 1.24 removed.
        return (np.random.rand(count, 2) * [h, w]).astype(int)

    wall_locs = random_locs(num_wall_blocks)
    win_locs = random_locs(num_wins)
    lose_locs = random_locs(num_lose)
    return wall_locs, win_locs, lose_locs
if __name__ == '__main__':
    # Sample a random layout sized by the module-level constants.
    wall_locs, win_locs, lose_locs = gen_world_config(
        h=WORLD_HEIGHT,
        w=WORLD_WIDTH,
        wall_frac=WALL_FRAC,
        num_wins=NUM_WINS,
        num_lose=NUM_LOSE,
    )
    # Build the world with rendering enabled, then run the solver.
    g = GridWorld(
        viz=True,
        world_height=WORLD_HEIGHT,
        world_width=WORLD_WIDTH,
        wall_locs=wall_locs,
        win_locs=win_locs,
        lose_locs=lose_locs,
    )
    g.solve()
    # NOTE(review): `k` is never read; presumably a debugger breakpoint anchor -- confirm before removing.
    k = 0