Naive Learning
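
The script below implements a naive form of tabular Q-learning on a small grid world: the agent samples actions in proportion to its current Q-values, updates them after every move, and (optionally) renders its progress to a video with OpenCV and moviepy.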

import itertools
import sys
import time

import numpy as np
import cv2

from moviepy.editor import VideoClip


WORLD_HEIGHT = 4
WORLD_WIDTH = 4
WALL_FRAC = .2
NUM_WINS = 5
NUM_LOSE = 10


class GridWorld:
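    """A small grid world whose agent picks actions in proportion to its current
    Q-values, updates them after every move, and optionally renders each step to a video."""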

    def __init__(self, world_height=3, world_width=4, discount_factor=.5, default_reward=-.5, wall_penalty=-.6,
                 win_reward=5., lose_reward=-10., viz=True, patch_side=120, grid_thickness=2, arrow_thickness=3,
                 wall_locs=[[1, 1], [1, 2]], win_locs=[[0, 3]], lose_locs=[[1, 3]], start_loc=[0, 0],
                 reset_prob=.2):
        self.world = np.ones([world_height, world_width]) * default_reward
        self.reset_prob = reset_prob
        self.world_height = world_height
        self.world_width = world_width
        self.wall_penalty = wall_penalty
        self.win_reward = win_reward
        self.lose_reward = lose_reward
        self.default_reward = default_reward
        self.discount_factor = discount_factor
        self.patch_side = patch_side
        self.grid_thickness = grid_thickness
        self.arrow_thickness = arrow_thickness
        self.wall_locs = np.array(wall_locs)
        self.win_locs = np.array(win_locs)
        self.lose_locs = np.array(lose_locs)
        self.at_terminal_state = False
        self.auto_reset = True
        self.random_respawn = True
        self.step = 0
        self.viz_canvas = None
        self.viz = viz
        self.path_color = (128, 128, 128)
        self.wall_color = (0, 255, 0)
        self.win_color = (0, 0, 255)
        self.lose_color = (255, 0, 0)
        self.world[self.wall_locs[:, 0], self.wall_locs[:, 1]] = self.wall_penalty
        self.world[self.lose_locs[:, 0], self.lose_locs[:, 1]] = self.lose_reward
        self.world[self.win_locs[:, 0], self.win_locs[:, 1]] = self.win_reward
        spawn_condn = lambda loc: self.world[loc[0], loc[1]] == self.default_reward
        self.spawn_locs = np.array([loc for loc in itertools.product(np.arange(self.world_height),
                                                                     np.arange(self.world_width))
                                    if spawn_condn(loc)])
        self.start_state = np.array(start_loc)
        self.bot_rc = None
        self.reset()
        self.actions = [self.up, self.left, self.right, self.down, self.noop]
        self.action_labels = ['UP', 'LEFT', 'RIGHT', 'DOWN', 'NOOP']
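        # One Q-value per (row, col, action), initialized uniformly to 1/len(actions)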
        self.q_values = np.ones([self.world.shape[0], self.world.shape[1], len(self.actions)]) * 1. / len(self.actions)
        if self.viz:
            self.init_grid_canvas()
            self.video_out_fpath = 'shm_dqn_gridsolver-' + str(time.time()) + '.mp4'
            self.clip = VideoClip(self.make_frame, duration=15)

    def make_frame(self, t):
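        """moviepy frame callback: run one learning step, then return the canvas
        with the agent's current cell highlighted."""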
        self.action()
        frame = self.highlight_loc(self.viz_canvas, self.bot_rc[0], self.bot_rc[1])
        return frame

    def check_terminal_state(self):
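        """Flag the episode as over when the agent sits on a win or lose cell,
        and respawn immediately if auto_reset is enabled."""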
        if self.world[self.bot_rc[0], self.bot_rc[1]] == self.lose_reward \
                or self.world[self.bot_rc[0], self.bot_rc[1]] == self.win_reward:
            self.at_terminal_state = True
            # print('------++++---- TERMINAL STATE ------++++----')
            # if self.world[self.bot_rc[0], self.bot_rc[1]] == self.win_reward:
            #     print('GAME WON! :D')
            # elif self.world[self.bot_rc[0], self.bot_rc[1]] == self.lose_reward:
            #     print('GAME LOST! :(')
            if self.auto_reset:
                self.reset()

    def reset(self):
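        """Respawn the agent at the fixed start cell, or at a random free
        (non-wall, non-terminal) cell when random_respawn is set."""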
        # print('Resetting')
        if not self.random_respawn:
            self.bot_rc = self.start_state.copy()
        else:
            self.bot_rc = self.spawn_locs[np.random.choice(np.arange(len(self.spawn_locs)))].copy()
        self.at_terminal_state = False

    def up(self):
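        """Move one row up; if that would leave the grid or hit a wall, stay put
        and return the wall penalty. left/right/down below behave analogously."""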
        action_idx = 0
        # print(self.action_labels[action_idx])
        new_r = self.bot_rc[0] - 1
        if new_r < 0 or self.world[new_r, self.bot_rc[1]] == self.wall_penalty:
            return self.wall_penalty, action_idx
        self.bot_rc[0] = new_r
        reward = self.world[self.bot_rc[0], self.bot_rc[1]]
        self.check_terminal_state()
        return reward, action_idx

    def left(self):
        action_idx = 1
        # print(self.action_labels[action_idx])
        new_c = self.bot_rc[1] - 1
        if new_c < 0 or self.world[self.bot_rc[0], new_c] == self.wall_penalty:
            return self.wall_penalty, action_idx
        self.bot_rc[1] = new_c
        reward = self.world[self.bot_rc[0], self.bot_rc[1]]
        self.check_terminal_state()
        return reward, action_idx

    def right(self):
        action_idx = 2
        # print(self.action_labels[action_idx])
        new_c = self.bot_rc[1] + 1
        if new_c >= self.world.shape[1] or self.world[self.bot_rc[0], new_c] == self.wall_penalty:
            return self.wall_penalty, action_idx
        self.bot_rc[1] = new_c
        reward = self.world[self.bot_rc[0], self.bot_rc[1]]
        self.check_terminal_state()
        return reward, action_idx

    def down(self):
        action_idx = 3
        # print(self.action_labels[action_idx])
        new_r = self.bot_rc[0] + 1
        if new_r >= self.world.shape[0] or self.world[new_r, self.bot_rc[1]] == self.wall_penalty:
            return self.wall_penalty, action_idx
        self.bot_rc[0] = new_r
        reward = self.world[self.bot_rc[0], self.bot_rc[1]]
        self.check_terminal_state()
        return reward, action_idx

    def noop(self):
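        """Stay in place and collect the reward of the current cell."""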
        action_idx = 4
        # print(self.action_labels[action_idx])
        reward = self.world[self.bot_rc[0], self.bot_rc[1]]
        self.check_terminal_state()
        return reward, action_idx

    def qvals2probs(self, q_vals, epsilon=1e-4):
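        """Convert a Q-value vector into action probabilities by shifting it to be
        strictly positive and normalizing; epsilon keeps even the worst action selectable."""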
        action_probs = q_vals - q_vals.min() + epsilon
        action_probs = action_probs / action_probs.sum()
        return action_probs

    def action(self):
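        """One learning step: sample an action with probability proportional to the
        shifted Q-values, execute it, and move Q(s, a) towards
        reward + discount_factor * max Q(s'). The learning rate alpha decays so slowly
        that the update is effectively a full overwrite, and with probability
        reset_prob the agent teleports to a random spawn cell afterwards."""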
        # print('================ ACTION =================')
        if self.at_terminal_state:
            print('At terminal state, please call reset()')
            sys.exit()
        # print('Start position:', self.bot_rc)
        start_bot_rc = self.bot_rc[0], self.bot_rc[1]
        q_vals = self.q_values[self.bot_rc[0], self.bot_rc[1]]
        action_probs = self.qvals2probs(q_vals)
        reward, action_idx = np.random.choice(self.actions, p=action_probs)()
        # print('End position:', self.bot_rc)
        # print('Reward:', reward)
        alpha = np.exp(-self.step / 10e9)
        self.step += 1
        qv = (1 - alpha) * q_vals[action_idx] + alpha * (reward + self.discount_factor
                                                         * self.q_values[self.bot_rc[0], self.bot_rc[1]].max())
        self.q_values[start_bot_rc[0], start_bot_rc[1], action_idx] = qv
        if self.viz:
            self.update_viz(start_bot_rc[0], start_bot_rc[1])
        if np.random.rand() < self.reset_prob:
            # print('-----> Randomly resetting to a random spawn point with probability', self.reset_prob)
            self.reset()

    def highlight_loc(self, viz_in, i, j):
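        """Return a copy of the canvas with a white border drawn around cell (i, j)."""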
        starty = i * (self.patch_side + self.grid_thickness)
        endy = starty + self.patch_side
        startx = j * (self.patch_side + self.grid_thickness)
        endx = startx + self.patch_side
        viz = viz_in.copy()
        cv2.rectangle(viz, (startx, starty), (endx, endy), (255, 255, 255), thickness=self.grid_thickness)
        return viz

    def update_viz(self, i, j):
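        """Redraw cell (i, j): a solid colour patch for walls and terminal cells; for
        free cells, an arrow whose direction and opacity summarize the current
        action probabilities."""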
        starty = i * (self.patch_side + self.grid_thickness)
        endy = starty + self.patch_side
        startx = j * (self.patch_side + self.grid_thickness)
        endx = startx + self.patch_side
        patch = np.zeros([self.patch_side, self.patch_side, 3]).astype(np.uint8)
        if self.world[i, j] == self.default_reward:
            patch[:, :, :] = self.path_color
        elif self.world[i, j] == self.wall_penalty:
            patch[:, :, :] = self.wall_color
        elif self.world[i, j] == self.win_reward:
            patch[:, :, :] = self.win_color
        elif self.world[i, j] == self.lose_reward:
            patch[:, :, :] = self.lose_color
        if self.world[i, j] == self.default_reward:
            action_probs = self.qvals2probs(self.q_values[i, j])
            x_component = action_probs[2] - action_probs[1]
            y_component = action_probs[0] - action_probs[3]
            magnitude = 1. - action_probs[-1]
            s = self.patch_side // 2
            x_patch = int(s * x_component)
            y_patch = int(s * y_component)
            arrow_canvas = np.zeros_like(patch)
            vx = s + x_patch
            vy = s - y_patch
            cv2.arrowedLine(arrow_canvas, (s, s), (vx, vy), (255, 255, 255), thickness=self.arrow_thickness,
                            tipLength=0.5)
            gridbox = (magnitude * arrow_canvas + (1 - magnitude) * patch).astype(np.uint8)
            self.viz_canvas[starty:endy, startx:endx] = gridbox
        else:
            self.viz_canvas[starty:endy, startx:endx] = patch

    def init_grid_canvas(self):
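        """Allocate the full canvas and draw every cell once."""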
        org_h, org_w = self.world_height, self.world_width
        viz_w = (self.patch_side * org_w) + (self.grid_thickness * (org_w - 1))
        viz_h = (self.patch_side * org_h) + (self.grid_thickness * (org_h - 1))
        self.viz_canvas = np.zeros([viz_h, viz_w, 3]).astype(np.uint8)
        for i in range(org_h):
            for j in range(org_w):
                self.update_viz(i, j)

    def solve(self):
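        """Run the agent: loop forever when visualization is off, otherwise render
        the configured clip to a video file."""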
        if not self.viz:
            while True:
                self.action()
        else:
            self.clip.write_videofile(self.video_out_fpath, fps=460)


def gen_world_config(h, w, wall_frac=.5, num_wins=2, num_lose=3):
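    """Sample random wall, win and lose cell coordinates. Locations may repeat or
    overlap; GridWorld assigns walls, then lose cells, then win cells, so the last
    assignment wins."""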
    n = h * w
    num_wall_blocks = int(wall_frac * n)
    wall_locs = (np.random.rand(num_wall_blocks, 2) * [h, w]).astype(int)
    win_locs = (np.random.rand(num_wins, 2) * [h, w]).astype(int)
    lose_locs = (np.random.rand(num_lose, 2) * [h, w]).astype(int)
    return wall_locs, win_locs, lose_locs


if __name__ == '__main__':
    wall_locs, win_locs, lose_locs = gen_world_config(WORLD_HEIGHT, WORLD_WIDTH, WALL_FRAC, NUM_WINS, NUM_LOSE)
    g = GridWorld(world_height=WORLD_HEIGHT, world_width=WORLD_WIDTH,
                  wall_locs=wall_locs, win_locs=win_locs, lose_locs=lose_locs, viz=True)
    g.solve()