diff --git a/.gitignore b/.gitignore
index 7ee54fc5..db378067 100644
--- a/.gitignore
+++ b/.gitignore
@@ -159,3 +159,4 @@ opponent_pool
 !/examples/selfplay/opponent_templates/tictactoe_opponent/info.json
 wandb_run
 examples/dmc/new.gif
+/examples/snake/submissions/rl/actor_2000.pth
diff --git a/Makefile b/Makefile
index e8c39108..d38da499 100644
--- a/Makefile
+++ b/Makefile
@@ -12,7 +12,7 @@ test:
 lint:
 	$(call check_install, ruff)
 	ruff ${PYTHON_FILES} --select=E9,F63,F7,F82 --show-source
-	ruff ${PYTHON_FILES} --exit-zero | grep -v '501\|405\|401\|402\|403'
+	ruff ${PYTHON_FILES} --exit-zero | grep -v '501\|405\|401\|402\|403\|722'
 
 format:
 	$(call check_install, isort)
diff --git a/examples/snake/README.md b/examples/snake/README.md
index 4adb9cbd..5194b2c2 100644
--- a/examples/snake/README.md
+++ b/examples/snake/README.md
@@ -7,6 +7,11 @@ This is the example for the snake game.
 python train_selfplay.py
 ```
 
+## Evaluate JiDi submissions locally
+
+```bash
+python jidi_eval.py
+```
 
 ## Submit to JiDi
 
diff --git a/examples/snake/jidi_eval.py b/examples/snake/jidi_eval.py
new file mode 100644
index 00000000..8c893b44
--- /dev/null
+++ b/examples/snake/jidi_eval.py
@@ -0,0 +1,54 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+
+from openrl.arena import make_arena
+from openrl.arena.agents.jidi_agent import JiDiAgent
+from openrl.envs.wrappers.pettingzoo_wrappers import RecordWinner
+
+
+def run_arena(
+    render: bool = False,
+    parallel: bool = True,
+    seed=0,
+    total_games: int = 10,
+    max_game_onetime: int = 5,
+):
+    env_wrappers = [RecordWinner]
+
+    player_num = 3
+    arena = make_arena(
+        f"snakes_{player_num}v{player_num}", env_wrappers=env_wrappers, render=render
+    )
+
+    agent1 = JiDiAgent("./submissions/rule_v1", player_num=player_num)
+    agent2 = JiDiAgent("./submissions/rl", player_num=player_num)
+
+    arena.reset(
+        agents={"agent1": agent1, "agent2": agent2},
+        total_games=total_games,
+        max_game_onetime=max_game_onetime,
+        seed=seed,
+    )
+    result = arena.run(parallel=parallel)
+    arena.close()
+    print(result)
+    return result
+
+
+if __name__ == "__main__":
+    run_arena(render=False, parallel=True, seed=0, total_games=100, max_game_onetime=5)
diff --git a/examples/snake/submissions/random_agent/submission.py b/examples/snake/submissions/random_agent/submission.py
index b1f468df..52f748b5 100644
--- a/examples/snake/submissions/random_agent/submission.py
+++ b/examples/snake/submissions/random_agent/submission.py
@@ -26,4 +26,5 @@ def my_controller(observation, action_space, is_act_continuous):
     for i in range(len(action_space)):
         player = sample_single_dim(action_space[i], is_act_continuous)
         joint_action.append(player)
+    return joint_action
 
diff --git a/examples/snake/submissions/rl/README.md b/examples/snake/submissions/rl/README.md
new file mode 100644
index 00000000..90e9de5a
--- /dev/null
+++ b/examples/snake/submissions/rl/README.md
@@ -0,0 +1,3 @@
+# Download actor weights
+
+Please download [actor_2000.pth](https://github.com/CarlossShi/Competition_3v3snakes/tree/master/agent/rl) before using this code.
\ No newline at end of file
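For a quick local check that does not need the downloaded weights, the same arena entry points added in this diff can be pointed at the bundled rule-based and random submissions instead. A minimal sketch (agent labels and game counts here are arbitrary choices, not fixed by the API):

```python
# Hypothetical variation of examples/snake/jidi_eval.py: pit the bundled
# rule-based submission against the random one, so no .pth download is needed.
from openrl.arena import make_arena
from openrl.arena.agents.jidi_agent import JiDiAgent
from openrl.envs.wrappers.pettingzoo_wrappers import RecordWinner

player_num = 3
arena = make_arena(f"snakes_{player_num}v{player_num}", env_wrappers=[RecordWinner])
arena.reset(
    agents={
        "rule": JiDiAgent("./submissions/rule_v1", player_num=player_num),
        "random": JiDiAgent("./submissions/random_agent", player_num=player_num),
    },
    total_games=10,
    max_game_onetime=5,
    seed=0,
)
print(arena.run(parallel=True))  # per-agent results, as tallied via RecordWinner
arena.close()
```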
diff --git a/examples/snake/submissions/rl/submission.py b/examples/snake/submissions/rl/submission.py
new file mode 100644
index 00000000..ae7a674d
--- /dev/null
+++ b/examples/snake/submissions/rl/submission.py
@@ -0,0 +1,196 @@
+import os
+import sys
+from pathlib import Path
+
+import numpy as np
+import torch
+from torch import nn
+from torch.distributions import Categorical
+
+HIDDEN_SIZE = 256
+device = torch.device("cpu")
+
+from typing import Union
+
+Activation = Union[str, nn.Module]
+
+_str_to_activation = {
+    "relu": torch.nn.ReLU(),
+    "tanh": nn.Tanh(),
+    "identity": nn.Identity(),
+    "softmax": nn.Softmax(dim=-1),
+}
+
+
+def mlp(
+    sizes, activation: Activation = "relu", output_activation: Activation = "identity"
+):
+    if isinstance(activation, str):
+        activation = _str_to_activation[activation]
+    if isinstance(output_activation, str):
+        output_activation = _str_to_activation[output_activation]
+
+    layers = []
+    for i in range(len(sizes) - 1):
+        act = activation if i < len(sizes) - 2 else output_activation
+        layers += [nn.Linear(sizes[i], sizes[i + 1]), act]
+    return nn.Sequential(*layers)
+
+
+def get_surrounding(state, width, height, x, y):
+    surrounding = [
+        state[(y - 1) % height][x],  # up
+        state[(y + 1) % height][x],  # down
+        state[y][(x - 1) % width],  # left
+        state[y][(x + 1) % width],  # right
+    ]
+
+    return surrounding
+
+
+def make_grid_map(
+    board_width, board_height, beans_positions: list, snakes_positions: dict
+):
+    snakes_map = [[[0] for _ in range(board_width)] for _ in range(board_height)]
+    for index, pos in snakes_positions.items():
+        for p in pos:
+            snakes_map[p[0]][p[1]][0] = index
+
+    for bean in beans_positions:
+        snakes_map[bean[0]][bean[1]][0] = 1
+
+    return snakes_map
+
+
+# Self position:        0:head_x; 1:head_y
+# Head surroundings:    2:head_up; 3:head_down; 4:head_left; 5:head_right
+# Beans positions:      (6, 7) (8, 9) (10, 11) (12, 13) (14, 15)
+# Other snake positions: (16, 17) (18, 19) (20, 21) (22, 23) (24, 25) -- (other_x - self_x, other_y - self_y)
+def get_observations(state, agents_index, obs_dim, height, width):
+    state_copy = state.copy()
+    board_width = state_copy["board_width"]
+    board_height = state_copy["board_height"]
+    beans_positions = state_copy[1]
+    snakes_positions = {
+        key: state_copy[key] for key in state_copy.keys() & {2, 3, 4, 5, 6, 7}
+    }
+    snakes_positions_list = []
+    for key, value in snakes_positions.items():
+        snakes_positions_list.append(value)
+    snake_map = make_grid_map(
+        board_width, board_height, beans_positions, snakes_positions
+    )
+    state_ = np.array(snake_map)
+    state_ = np.squeeze(state_, axis=2)
+
+    observations = np.zeros((3, obs_dim))
+    snakes_position = np.array(snakes_positions_list, dtype=object)
+    beans_position = np.array(beans_positions, dtype=object).flatten()
+    for i, element in enumerate(agents_index):
+        # self head position
+        observations[i][:2] = snakes_positions_list[element][0][:]
+
+        # head surroundings
+        head_x = snakes_positions_list[element][0][1]
+        head_y = snakes_positions_list[element][0][0]
+
+        head_surrounding = get_surrounding(state_, width, height, head_x, head_y)
+        observations[i][2:6] = head_surrounding[:]
+
+        # beans positions
+        observations[i][6:16] = beans_position[:]
+
+        # other snake positions
+        snake_heads = np.array([snake[0] for snake in snakes_position])
+        snake_heads = np.delete(snake_heads, i, 0)
+        observations[i][16:] = snake_heads.flatten()[:]
+    return observations
+
+
+class Actor(nn.Module):
+    def __init__(
+        self, obs_dim, act_dim, num_agents, args, output_activation="softmax"
+    ):
+        super().__init__()
+
+        self.obs_dim = obs_dim
+        self.act_dim = act_dim
+        self.num_agents = num_agents
+
+        self.args = args
+
+        sizes_prev = [obs_dim, HIDDEN_SIZE]
+        sizes_post = [HIDDEN_SIZE, HIDDEN_SIZE, act_dim]
+
+        self.prev_dense = mlp(sizes_prev)
+        self.post_dense = mlp(sizes_post, output_activation=output_activation)
+
+    def forward(self, obs_batch):
+        out = self.prev_dense(obs_batch)
+        out = self.post_dense(out)
+        return out
+
+
+class RLAgent(object):
+    def __init__(self, obs_dim, act_dim, num_agent):
+        self.obs_dim = obs_dim
+        self.act_dim = act_dim
+        self.num_agent = num_agent
+        self.device = device
+        self.output_activation = "softmax"
+        self.actor = Actor(
+            obs_dim, act_dim, num_agent, None, self.output_activation  # args is unused
+        ).to(self.device)
+
+    def choose_action(self, obs):
+        obs = torch.Tensor([obs]).to(self.device)
+        logits = self.actor(obs).cpu().detach().numpy()[0]
+        return logits
+
+    def select_action_to_env(self, obs, ctrl_index):
+        logits = self.choose_action(obs)
+        actions = logits2action(logits)
+        action_to_env = to_joint_action(actions, ctrl_index)
+        return action_to_env
+
+    def load_model(self, filename):
+        self.actor.load_state_dict(torch.load(filename))
+
+
+def to_joint_action(action, ctrl_index):
+    joint_action_ = []
+    action_a = action[ctrl_index]
+    each = [0] * 4
+    each[action_a] = 1
+    joint_action_.append(each)
+    return joint_action_
+
+
+def logits2action(logits):
+    logits = torch.Tensor(logits).to(device)
+    actions = np.array([Categorical(out).sample().item() for out in logits])
+    return np.array(actions)
+
+
+agent = RLAgent(26, 4, 3)
+actor_net = os.path.dirname(os.path.abspath(__file__)) + "/actor_2000.pth"
+assert Path(actor_net).exists(), (
+    "actor_2000.pth does not exist; please download it from:"
+    " https://github.com/CarlossShi/Competition_3v3snakes/tree/master/agent/rl"
+)
+agent.load_model(actor_net)
+
+
+def my_controller(observation_list, action_space_list, is_act_continuous):
+    obs_dim = 26
+    obs = observation_list.copy()
+    board_width = obs["board_width"]
+    board_height = obs["board_height"]
+    o_index = obs[
+        "controlled_snake_index"
+    ]  # 2, 3, 4, 5, 6, 7 -> indexs = [0, 1, 2, 3, 4, 5]
+    o_indexs_min = 3 if o_index > 4 else 0
+    indexs = [o_indexs_min, o_indexs_min + 1, o_indexs_min + 2]
+    observation = get_observations(
+        obs, indexs, obs_dim, height=board_height, width=board_width
+    )
+    actions = agent.select_action_to_env(observation, indexs.index(o_index - 2))
+    return actions
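To make the JiDi calling convention of this file concrete, here is a hypothetical offline smoke test; every coordinate below is invented for illustration, and importing the module only works once `actor_2000.pth` is in place (the module asserts on import):

```python
# Hypothetical smoke test, run from examples/snake/submissions/rl
# (requires actor_2000.pth next to submission.py).
from submission import my_controller

fake_obs = {
    "board_width": 20,
    "board_height": 10,
    "controlled_snake_index": 2,  # controlling snake 2 -> team slots [0, 1, 2]
    1: [[7, 15], [4, 14], [5, 12], [4, 12], [5, 7]],  # the 5 bean positions
    2: [[2, 9], [2, 8], [2, 7]],  # snakes 2..7, head first
    3: [[5, 3], [5, 4], [5, 5]],
    4: [[8, 1], [8, 2], [8, 3]],
    5: [[1, 18], [1, 17], [1, 16]],
    6: [[6, 10], [6, 11], [6, 12]],
    7: [[9, 14], [9, 15], [9, 16]],
}

# The RL controller never inspects the action-space list, so a placeholder works.
action = my_controller(fake_obs, [None], is_act_continuous=False)
print(action)  # one one-hot move for the controlled snake, e.g. [[0, 1, 0, 0]]
```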
False, "the start and end points do not match" + + +def connected_count(matrix, pos): + height, width = matrix.shape + x, y = pos + sign = matrix[x, y] + unexplored = [[x, y]] + explored = [] + _connected_count = 1 + while unexplored: + x, y = unexplored.pop() + explored.append([x, y]) + for x_, y_ in [ + ((x + 1) % height, y), # down + ((x - 1) % height, y), # up + (x, (y + 1) % width), # right + (x, (y - 1) % width), + ]: # left + if ( + matrix[x_, y_] == sign + and [x_, y_] not in explored + and [x_, y_] not in unexplored + ): + unexplored.append([x_, y_]) + _connected_count += 1 + return _connected_count + + +class Snake: + def __init__(self, snake_positions, board_height, board_width, beans_positions): + self.pos = snake_positions # [[2, 9], [2, 8], [2, 7]] + self.len = len(snake_positions) # >= 3 + self.head = snake_positions[0] + self.beans_positions = beans_positions + self.claimed_count = 0 + + displace = [ + (self.head[0] - snake_positions[1][0]) % board_height, + (self.head[1] - snake_positions[1][1]) % board_width, + ] + print("creat snake, pos: ", self.pos, "displace:", displace) + if displace == [ + board_height - 1, + 0, + ]: # all action are ordered by left, up, right, relative to the body + self.dir = 0 # up + self.legal_action = [2, 0, 3] + elif displace == [1, 0]: + self.dir = 1 # down + self.legal_action = [3, 1, 2] + elif displace == [0, board_width - 1]: + self.dir = 2 # left + self.legal_action = [1, 2, 0] + elif displace == [0, 1]: + self.dir = 3 # right + self.legal_action = [0, 3, 1] + else: + assert False, "snake positions error" + positions = [ + [(self.head[0] - 1) % board_height, self.head[1]], + [(self.head[0] + 1) % board_height, self.head[1]], + [self.head[0], (self.head[1] - 1) % board_width], + [self.head[0], (self.head[1] + 1) % board_width], + ] + self.legal_position = [positions[_] for _ in self.legal_action] + + def get_action(self, position): + if position not in self.legal_position: + assert False, "the start and end points do not match" + idx = self.legal_position.index(position) + return self.legal_action[idx] # 0, 1, 2, 3: up, down, left, right + + def step(self, legal_input): + if legal_input in self.legal_position: + position = legal_input + elif legal_input in self.legal_action: + idx = self.legal_action.index(legal_input) + position = self.legal_position[idx] + else: + assert False, "illegal snake move" + self.head = position + self.pos.insert(0, position) + if position in self.beans_positions: # eat a bean + self.len += 1 + else: # do not eat a bean + self.pos.pop() + + +class Board: + def __init__(self, board_height, board_width, snakes, beans_positions, teams): + print("create board, beans_position: ", beans_positions) + self.height = board_height + self.width = board_width + self.snakes = snakes + self.snakes_count = len(snakes) + self.beans_positions = beans_positions + self.blank_sign = -self.snakes_count + self.bean_sign = -self.snakes_count + 1 + self.board = np.zeros((board_height, board_width), dtype=int) + self.blank_sign + self.open = dict() + for key, snake in self.snakes.items(): + self.open[key] = [snake.head] # state 0 open list, heads, ready to spread + # see [A* Pathfinding (E01: algorithm explanation)](https://www.youtube.com/watch?v=-L-WgKMFuhE) + for x, y in snake.pos: + self.board[x][y] = key # obstacles, e.g. 
0, 1, 2, 3, 4, 5 + # for x, y in beans_positions: + # self.board[x][y] = self.bean_sign # beans + + self.state = 0 + self.controversy = dict() + self.teams = teams + + print("initial board") + print(self.board) + + def step(self): # delay: prevent rear-end collision + new_open = {key: [] for key in self.snakes.keys()} + self.state += 1 # update state + # if self.state > delay: + # for key, snake in self.snakes.items(): # drop tail + # if snake.len >= self.state: + # self.board[snake.pos[-(self.state - delay)][0]][snake.pos[-(self.state - delay)][1]] \ + # = self.blank_sign + for key, snake in self.snakes.items(): + if snake.len >= self.state: + self.board[snake.pos[-self.state][0]][ + snake.pos[-self.state][1] + ] = self.blank_sign # drop tail + for key, value in self.open.items(): # value: e.g. [[8, 3], [6, 3], [7, 4]] + others_tail_pos = [ + ( + self.snakes[_].pos[-self.state] + if self.snakes[_].len >= self.state + else [] + ) + for _ in set(range(self.snakes_count)) - {key} + ] + for x, y in value: + print("start to spread snake {} on grid ({}, {})".format(key, x, y)) + for x_, y_ in [ + ((x + 1) % self.height, y), # down + ((x - 1) % self.height, y), # up + (x, (y + 1) % self.width), # right + (x, (y - 1) % self.width), + ]: # left + sign = self.board[x_][y_] + idx = ( + sign % self.snakes_count + ) # which snake, e.g. 0, 1, 2, 3, 4, 5 / number of claims + state = ( + sign // self.snakes_count + ) # manhattan distance to snake who claim the point or its negative + if sign == self.blank_sign: # grid in initial state + if [x_, y_] in others_tail_pos: + print( + "do not spread other snakes tail, in case of rear-end" + " collision" + ) + continue # do not spread other snakes' tail, in case of rear-end collision + self.board[x_][y_] = self.state * self.snakes_count + key + self.snakes[key].claimed_count += 1 + new_open[key].append([x_, y_]) + + elif key != idx and self.state == state: + # second claim, init controversy, change grid value from + to - + print( + "\tgird ({}, {}) in the same state claimed by different" + " snakes with sign {}, idx {} and state {}".format( + x_, y_, sign, idx, state + ) + ) + if ( + self.snakes[idx].len > self.snakes[key].len + ): # shorter snake claim the controversial grid + print( + "\t\tsnake {} is shorter than snake {}".format(key, idx) + ) + self.snakes[idx].claimed_count -= 1 + new_open[idx].remove([x_, y_]) + self.board[x_][y_] = self.state * self.snakes_count + key + self.snakes[key].claimed_count += 1 + new_open[key].append([x_, y_]) + elif ( + self.snakes[idx].len == self.snakes[key].len + ): # controversial claim + print( + "\t\tcontroversy! first claimed by snake {}, then" + " claimed by snake {}".format(idx, key) + ) + self.controversy[(x_, y_)] = { + "state": self.state, + "length": self.snakes[idx].len, + "indexes": [idx, key], + } + # first claim by snake idx, then claim by snake key + self.board[x_][y_] = -self.state * self.snakes_count + 1 + # if + 2, not enough for all snakes claim one grid!! + self.snakes[ + idx + ].claimed_count -= ( + 1 # controversy, no snake claim this grid!! 
+ ) + new_open[key].append([x_, y_]) + else: # (self.snakes[idx].len < self.snakes[key].len) + pass # longer snake do not claim the controversial grid + + elif ( + (x_, y_) in self.controversy + and key not in self.controversy[(x_, y_)]["indexes"] + and self.state + state == 0 + ): # third claim or more + print( + "snake {} meets third or more claim in grid ({}, {})" + .format(key, x_, y_) + ) + controversy = self.controversy[(x_, y_)] + pprint.pprint(controversy) + if ( + controversy["length"] > self.snakes[key].len + ): # shortest snake claim grid, do 4 things + print("\t\tsnake {} is shortest".format(key)) + indexes_count = len(controversy["indexes"]) + for i in controversy["indexes"]: + self.snakes[i].claimed_count -= ( + 1 / indexes_count + ) # update claimed_count ! + new_open[i].remove([x_, y_]) + del self.controversy[(x_, y_)] + self.board[x_][y_] = self.state * self.snakes_count + key + self.snakes[key].claimed_count += 1 + new_open[key].append([x_, y_]) + elif ( + controversy["length"] == self.snakes[key].len + ): # controversial claim + print( + "\t\tcontroversy! multi claimed by snake {}".format(key) + ) + self.controversy[(x_, y_)]["indexes"].append(key) + self.board[x_][y_] += 1 + new_open[key].append([x_, y_]) + else: # (controversy['length'] < self.snakes[key].len) + pass # longer snake do not claim the controversial grid + else: + pass # do nothing with lower state grids + + self.open = new_open # update open + # update controversial snakes' claimed_count (in fraction) in the end + for _, d in self.controversy.items(): + controversial_snake_count = len( + d["indexes"] + ) # number of controversial snakes + for idx in d["indexes"]: + self.snakes[idx].claimed_count += 1 / controversial_snake_count + + def claim2action(self, claim_position, snake_idx, step_count, output_type): + # claim e.g. [2 ,3 ,4 ,-9] bean 2 is claimed by snake 3 within 4 steps + x, y = claim_position + x_h, y_h = self.snakes[snake_idx].head # head position + + while step_count > 1: + step_count -= 1 + temp = [] + for x_, y_ in [ + ((x + 1) % self.height, y), # down + ((x - 1) % self.height, y), # up + (x, (y + 1) % self.width), # right + (x, (y - 1) % self.width), + ]: # left + sign = self.board[x_][y_] + if sign == self.blank_sign: + continue # snake too long, board not spread completely!! 
see example 20210815 0:41:48 + state = ( + sign // self.snakes_count + if sign > 0 + else -(sign // self.snakes_count) + ) + indexes = ( + [sign % self.snakes_count] + if sign >= 0 + else self.controversy[(x_, y_)]["indexes"] + ) + if step_count == state and snake_idx in indexes: + temp.append([x_, y_]) + x, y = random.choice(temp) + if output_type == "action": + return get_direction(x_h, y_h, x, y, self.height, self.width) + elif output_type == "position": + return [x, y] + else: + assert False, "unknown output_type {}".format(output_type) + + +def state2claims(state_array, max_state, priority): + # state_array: + # array([[ 1, 30, 30, 30, 30, 30], + # [30, 30, 3, 30, 30, 30], + # [30, 1, 30, 5, 30, 30], + # [30, 30, 30, 6, 5, 30], + # [30, 30, 30, 30, 30, 6]]) + beanCount, snakeCount = state_array.shape # (5, 6) + horiz = [ + list(state_array[_]).count(max_state) for _ in range(beanCount) + ] # [5, 5, 4, 4, 5] + vert = [ + list(state_array[:, _]).count(max_state) for _ in range(snakeCount) + ] # [4, 4, 4, 3, 4, 4] + claim_order = [] + for b in range(beanCount): + for s in range(snakeCount): + if state_array[b][s] < max_state: + claim_order.append([b, s, state_array[b][s], -horiz[b] - vert[s]]) + # priority rule: smaller state, larger horizontal or vertical cover number of max_state + if not claim_order: + return [] + if priority == "state": + temp = min(claim_order, key=operator.itemgetter(2, 3)) # [0, 0, 1, -9] + elif priority == "cover": + temp = min(claim_order, key=operator.itemgetter(3, 2)) + else: + assert False, "unknown priority" + # update + for b in range(beanCount): + state_array[b, temp[1]] = max_state + 1 + for s in range(snakeCount): + state_array[temp[0], s] = max_state + 1 + return [temp] + state2claims(state_array, max_state, priority) + + +def my_controller(observation_list, action_space_list, is_act_continuous): + with HiddenPrints(): + # detect 1v1, 3v3, 2p or 5p + # if True: + observation_len = len(observation_list.keys()) + teams = None + if observation_len == 7: + teams = [[0], [1]] # 1v1 + # teams = [[0, 1]] # 2p + elif observation_len == 10: + teams = [[0, 1, 2, 3, 4]] # 5p + elif observation_len == 11: + teams = [[0, 1, 2], [3, 4, 5]] # 3v3 + + assert teams is not None, "unknown game with observation length {}".format( + observation_len + ) + teams_count = len(teams) + snakes_count = sum([len(_) for _ in teams]) + + # read observation + obs = observation_list.copy() + board_height = obs["board_height"] # 10 + board_width = obs["board_width"] # 20 + ctrl_agent_index = obs["controlled_snake_index"] - 2 # 0, 1, 2, 3, 4, 5 + # last_directions = obs['last_direction'] # ['up', 'left', 'down', 'left', 'left', 'left'] + beans_positions = obs[1] # e.g.[[7, 15], [4, 14], [5, 12], [4, 12], [5, 7]] + snakes = { + key - 2: Snake(obs[key], board_height, board_width, beans_positions) + for key in obs.keys() & {_ + 2 for _ in range(snakes_count)} + } # &: intersection + team_indexes = [_ for _ in teams if ctrl_agent_index in _][0] + + init_board = Board(board_height, board_width, snakes, beans_positions, teams) + bd = copy.deepcopy(init_board) + + with HiddenPrints(): + while not all( + _ == [] for _ in bd.open.values() + ): # loop until all values in open are empty list + bd.step() + print(bd.board) + + defense_snakes_indexes = ( + [] + ) # save defensive or claimed snakes, to calculate safe move for ctrl snake + + # define defensive move + defensive_claim_list = [] # [pos, snake_idx, step] + # first check win side + snakes_lens = [snake.len for snake in snakes.values()] + 
snakes_claimed_counts = [snake.len for snake in snakes.values()] + print("snakes_lens: ", snakes_lens) + print("snakes_claimed_counts: ", snakes_claimed_counts) + + # design defense threshold + # defense_threshold = 0.5 * math.pow(board_height * board_width, 1.1) / snakes_count * \ + # math.sqrt(4 / (4 * teams_count + 1)) + defense_threshold = ( + board_height * board_width * teams_count / (teams_count + 1) / snakes_count + ) + + for idx in team_indexes: + if snakes_lens[idx] > defense_threshold: + # 3: player count + 1, 2: player count, 6: snake count + for _ in range(1, min(bd.state, snakes_lens[idx] // 2)): + # range should be designed more carefully!! + x, y = snakes[idx].pos[-_] + if ( + bd.board[x, y] == idx + _ * snakes_count + ): # claim a loop in step _ + defense_snakes_indexes.append(idx) + defensive_claim_list.append([[x, y], idx, _]) + if idx == ctrl_agent_index: + action = [ + bd.claim2action( + claim_position=[x, y], + snake_idx=idx, + step_count=_, + output_type="action", + ) + ] + print( + "the controlled agent {} make a defensive move {}" + " within {} step(s)".format(idx, action[0], _) + ) + print( + "***********************************" + + " defensive move " + + "***************************************" + ) + return action + + # calculate state_array + # e.g. + # array([[ 1, 30, 30, 30, 30, 30], + # [30, 30, 3, 30, 30, 30], + # [30, 1, 30, 5, 30, 30], + # [30, 30, 30, 6, 5, 30], + # [30, 30, 30, 30, 30, 6]]) + max_state = board_height + board_width # 30 + state_array = ( + np.zeros((len(beans_positions), len(snakes)), dtype=int) + max_state + ) + for i, (x, y) in enumerate(beans_positions): + sign = bd.board[x][y] + if sign >= snakes_count: # bean claimed by one snake + idx = sign % snakes_count # 0, 1, 2, 3, 4, 5 + state = sign // snakes_count # 1, 2, ... + state_array[i][idx] = state + elif sign < 0 and sign % snakes_count in [ + _ for _ in range(snakes_count) if _ > 0 + ]: # [2, 3, 4, 5] + state = -(sign // snakes_count) + for idx in bd.controversy[(x, y)]["indexes"]: + state_array[i][idx] = state + elif ( + sign == bd.blank_sign + ): # bean not reachable for any snakes! see example: 20210815, 1:10:23 + pass + else: + assert False, "unknown sign when calculating state_array" + + # calculate claim list + # e.g. 
[[2, 3, 4, -9], [1, 1, 4, -6], [3, 2, 4, -4], [0, 4, 5, -2]] + claim_list_byState = state2claims(state_array.copy(), max_state, "state") + claim_list_byCover = state2claims(state_array.copy(), max_state, "cover") + print("claim_list_byState: ", len(claim_list_byState), claim_list_byState) + print("claim_list_byCover: ", len(claim_list_byCover), claim_list_byCover) + claim_list = ( + claim_list_byState + if len(claim_list_byState) >= len(claim_list_byCover) + else claim_list_byCover + ) + print("claim_list: ", claim_list) + + claim_snakes_indexes = [] + + # for agent claiming a bean safely, simply return its action + for c in claim_list: + if ctrl_agent_index == c[1]: # the controlled agent claim a bean + print( + "the controlled agent {} claim a bean {} within {} step(s)".format( + c[1], c[0], c[2] + ) + ) + action = [ + bd.claim2action( + claim_position=bd.beans_positions[c[0]], + snake_idx=c[1], + step_count=c[2], + output_type="action", + ) + ] + print("and play a action", action[0]) + print( + "*********************************** claim move" + " ******************************************" + ) + return action + claim_snakes_indexes.append(c[1]) + else: + # not claim any beans, + # traverse all possible action combination (at most 27), + # choose one that claim most grids + # calculate free team snakes indexes and safe positions list + free_team_snakes_indexes = [ + _ + for _ in team_indexes + if _ not in claim_snakes_indexes and _ not in defense_snakes_indexes + ] + safe_positions_list = [] + for idx in free_team_snakes_indexes: + safe_positions = [] # may be empty list + x_h, y_h = snakes[idx].head + for x, y in [ + ((x_h + 1) % bd.height, y_h), # down + ((x_h - 1) % bd.height, y_h), # up + (x_h, (y_h + 1) % bd.width), # right + (x_h, (y_h - 1) % bd.width), + ]: # left + if bd.board[x][y] == idx + snakes_count: + safe_positions.append( + [x, y] + ) # should be further tested if exists breath!! 
+ safe_positions_list.append(safe_positions) + + # delete snakes whose safe positions are [], which means they are dying + check_list = [_ != [] for _ in safe_positions_list] + free_team_snakes_indexes = list( + np.array(free_team_snakes_indexes)[check_list] + ) # [idx1, idx2] + safe_positions_list = [ + _ for _ in safe_positions_list if _ + ] # [[pos1, pos2, pos3], [pos1, pos2, pos3] + + print("free_team_snakes_indexes: ", free_team_snakes_indexes) + print("safe_positions_list: ", safe_positions_list) + + # create new snakes + snakes_next = copy.deepcopy(snakes) + for c in claim_list: # claimed snake make one move + idx = c[1] + if idx in defense_snakes_indexes: + continue # defense is prior to claim + position = bd.claim2action( + claim_position=bd.beans_positions[c[0]], + snake_idx=idx, + step_count=c[2], + output_type="position", + ) + snakes_next[idx].step(position) + + for c in defensive_claim_list: # claimed snake make one move + idx = c[1] + position = bd.claim2action( + claim_position=c[0], + snake_idx=idx, + step_count=c[2], + output_type="position", + ) + snakes_next[idx].step(position) + + # traverse and find the action combination with most grids claimed + max_claimed_counts_sum = 0 + best_pos_comb = None + for pos_comb in itertools.product( + *safe_positions_list + ): # calculate cartesian product of safe positions list + # initiate claimed_count + for i, idx in enumerate( + free_team_snakes_indexes + ): # unclaimed and undead snake make one move + snakes_next[idx] = copy.deepcopy(snakes[idx]) + snakes_next[idx].step(pos_comb[i]) + for snake in snakes_next.values(): + snake.claimed_count = 0 # reset after deep copy !! + bd_next = Board( + board_height, board_width, snakes_next, beans_positions, teams + ) + with HiddenPrints(): + while not all( + _ == [] for _ in bd_next.open.values() + ): # loop until all values in open are empty list + bd_next.step() + print(bd.board) + claimed_counts = np.zeros(len(team_indexes)) + for i, idx in enumerate( + team_indexes + ): # not free, consider all team snakes!! 
+ claimed_counts[i] = snakes_next[idx].claimed_count + claimed_counts_sum = sum(claimed_counts) + + if claimed_counts_sum > max_claimed_counts_sum: + max_claimed_counts_sum = claimed_counts_sum + best_pos_comb = pos_comb # one-to-one with free_team_snakes_indexes + print("claimed_counts_sum: ", claimed_counts_sum) + print("pos_comb: ", pos_comb) + + print( + "max_claimed_counts_sum: ", + max_claimed_counts_sum, + "best_pos_comb: ", + best_pos_comb, + ) + + if best_pos_comb: + for i, idx in enumerate(free_team_snakes_indexes): + if ctrl_agent_index == idx: + action = [[0, 0, 0, 0]] + direction = snakes[idx].get_action(best_pos_comb[i]) + action[0][direction] = 1 + print( + "the controlled agent {} make a safe move".format(idx), + action[0], + ) + print( + "*********************************** safe move" + " ******************************************" + ) + return action + + # todo: design attack moves + + # no claim move, no safe move, no attack or die + action = [[0, 0, 0, 0]] + direction = snakes[ctrl_agent_index].legal_action[0] + action[0][direction] = 1 + print( + "the controlled agent {} play a random action and is dying".format( + ctrl_agent_index + ), + action[0], + ) + print( + "*********************************** random move" + " ******************************************" + ) + return action diff --git a/openrl/arena/__init__.py b/openrl/arena/__init__.py index 3d2aadb2..8c82204f 100644 --- a/openrl/arena/__init__.py +++ b/openrl/arena/__init__.py @@ -21,15 +21,23 @@ from openrl.envs import pettingzoo_all_envs -def make_arena(env_id: str, custom_build_env: Optional[Callable] = None, **kwargs): +def make_arena( + env_id: str, + custom_build_env: Optional[Callable] = None, + render: Optional[bool] = False, + **kwargs, +): if custom_build_env is None: if env_id in pettingzoo_all_envs: from openrl.envs.PettingZoo import make_PettingZoo_env - env_fn = make_PettingZoo_env(env_id, **kwargs) + render_mode = None + if render: + render_mode = "human" + env_fn = make_PettingZoo_env(env_id, render_mode=render_mode, **kwargs) else: raise ValueError(f"Unknown env_id: {env_id}") else: - env_fn = custom_build_env(env_id, **kwargs) + env_fn = custom_build_env(env_id, render, **kwargs) return TwoPlayerArena(env_fn) diff --git a/openrl/arena/agents/jidi_agent.py b/openrl/arena/agents/jidi_agent.py new file mode 100644 index 00000000..f6d21db3 --- /dev/null +++ b/openrl/arena/agents/jidi_agent.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
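The new `render` flag simply swaps in `render_mode="human"` when the PettingZoo env is built, so a rendered rematch is one keyword away. A sketch using the `run_arena` helper from `examples/snake/jidi_eval.py` above (pairing rendering with `parallel=False` is a judgment call here, not a documented requirement):

```python
# Watch a single rendered series instead of running a headless batch.
run_arena(render=True, parallel=False, seed=0, total_games=1, max_game_onetime=1)
```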
diff --git a/openrl/arena/agents/jidi_agent.py b/openrl/arena/agents/jidi_agent.py
new file mode 100644
index 00000000..f6d21db3
--- /dev/null
+++ b/openrl/arena/agents/jidi_agent.py
@@ -0,0 +1,32 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2023 The OpenRL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""""""
+from openrl.arena.agents.base_agent import BaseAgent
+from openrl.selfplay.opponents.base_opponent import BaseOpponent
+from openrl.selfplay.opponents.utils import load_opponent_from_jidi_path
+
+
+class JiDiAgent(BaseAgent):
+    def __init__(self, local_agent_path, player_num: int = 1):
+        super().__init__()
+        self.local_agent_path = local_agent_path
+        self.player_num = player_num
+
+    def _new_agent(self) -> BaseOpponent:
+        return load_opponent_from_jidi_path(
+            self.local_agent_path, player_num=self.player_num
+        )
diff --git a/openrl/arena/games/two_player_game.py b/openrl/arena/games/two_player_game.py
index 5fe32fd4..a623a115 100644
--- a/openrl/arena/games/two_player_game.py
+++ b/openrl/arena/games/two_player_game.py
@@ -45,8 +45,9 @@ def _run(self, env_fn: Callable, agents: List[BaseAgent]):
         player2agent, player2agent_name = self.dispatch_agent_to_player(
             env.agents, agents
         )
-        for agent in player2agent.values():
-            agent.reset(env)
+
+        for player, agent in player2agent.items():
+            agent.reset(env, player)
         result = {}
         while True:
             termination = False
diff --git a/openrl/envs/PettingZoo/__init__.py b/openrl/envs/PettingZoo/__init__.py
index e70a07dc..9c253466 100644
--- a/openrl/envs/PettingZoo/__init__.py
+++ b/openrl/envs/PettingZoo/__init__.py
@@ -19,13 +19,18 @@
 from typing import List, Optional, Union
 
 from openrl.envs.common import build_envs
-from openrl.envs.PettingZoo.registration import pettingzoo_env_dict
+from openrl.envs.PettingZoo.registration import pettingzoo_env_dict, register
 from openrl.envs.wrappers.pettingzoo_wrappers import SeedEnv
 
 
 def PettingZoo_make(id, render_mode, disable_env_checker, **kwargs):
+    if id.startswith("snakes_"):
+        from openrl.envs.snake.snake_pettingzoo import SnakeEatBeansAECEnv
+
+        kwargs.__setitem__("id", id)
+        register(id, SnakeEatBeansAECEnv)
     if id in pettingzoo_env_dict.keys():
-        env = pettingzoo_env_dict[id](render_mode=render_mode)
+        env = pettingzoo_env_dict[id](render_mode=render_mode, **kwargs)
     elif id == "tictactoe_v3":
         from pettingzoo.classic import tictactoe_v3
 
diff --git a/openrl/envs/__init__.py b/openrl/envs/__init__.py
index 275a5552..a2eb835f 100644
--- a/openrl/envs/__init__.py
+++ b/openrl/envs/__init__.py
@@ -29,4 +29,4 @@
 
 offline_all_envs = ["OfflineEnv"]
 
-pettingzoo_all_envs = ["tictactoe_v3"]
+pettingzoo_all_envs = ["tictactoe_v3", "snakes_1v1", "snakes_3v3"]
diff --git a/openrl/envs/snake/snake.py b/openrl/envs/snake/snake.py
index 84e09f8b..6829ab14 100644
--- a/openrl/envs/snake/snake.py
+++ b/openrl/envs/snake/snake.py
@@ -26,24 +26,42 @@ def convert_to_onehot(joint_action):
     return new_joint_action
 
 
+conf_dict = {
+    "snakes_1v1": {
+        "class_literal": "SnakeEatBeans",
+        "n_player": 2,
+        "board_width": 8,
+        "board_height": 6,
+        "cell_range": 4,
+        "n_beans": 5,
+        "max_step": 50,
+        "game_name": "snakes",
+        "is_obs_continuous": False,
+        "is_act_continuous": False,
+        "agent_nums": [1, 1],
+        "obs_type": ["dict", "dict"],
+    },
+    "snakes_3v3": {
+        "class_literal": "SnakeEatBeans",
+        "n_player": 6,
+        "board_width": 20,
+        "board_height": 10,
+        "cell_range": 8,
+        "n_beans": 5,
+        "max_step": 200,
+        "game_name": "snakes",
+        "is_obs_continuous": False,
+        "is_act_continuous": False,
+        "agent_nums": [3, 3],
+        "obs_type": ["dict", "dict"],
+    },
+}
+
+
 class SnakeEatBeans(GridGame, GridObservation, DictObservation):
-    def __init__(self, render_mode: Optional[str] = None):
-        conf = {
-            "class_literal": "SnakeEatBeans",
-            "n_player": 2,
-            "board_width": 8,
-            "board_height": 6,
-            "cell_range": 4,
-            "n_beans": 5,
-            "max_step": 50,
-            "game_name": "snakes",
-            "is_obs_continuous": False,
-            "is_act_continuous": False,
-            "agent_nums": [1, 1],
-            "obs_type": ["dict", "dict"],
-            "save_interval": 100,
-            "save_path": "../../replay_winrate_var/replay_{}.gif",
-        }
+    def __init__(self, render_mode: Optional[str] = None, id: Optional[str] = None):
+        assert id in conf_dict.keys(), "id must be in %s" % conf_dict.keys()
+        conf = conf_dict[id]
         self.terminate_flg = False
         colors = conf.get("colors", [(255, 255, 255), (255, 140, 0)])
         super(SnakeEatBeans, self).__init__(conf, colors)
@@ -82,8 +100,7 @@ def __init__(self, render_mode: Optional[str] = None):
         ]
         # self.action_space = [Discrete(4) for _ in range(self.n_player)]
         self.action_space = Discrete(4)
-        self.save_internal = conf["save_interval"]
-        self.save_path = conf["save_path"]
+
         self.episode = 0
         self.fig, self.ax = None, None
         if render_mode in ["human", "rgb_array"]:
@@ -543,11 +560,13 @@ def can_regenerate():
     def is_not_valid_action(self, all_action):
         not_valid = 0
         if len(all_action) != self.n_player:
-            raise Exception("all action 维度不正确!", len(all_action))
+            raise Exception("all action dimension is wrong!", len(all_action))
 
         for i in range(self.n_player):
             if len(all_action[i][0]) != 4:
-                raise Exception("玩家%d joint action维度不正确!" % i, all_action[i])
+                raise Exception(
+                    "Player %d has wrong joint action dimension!" % i, all_action[i]
+                )
         return not_valid
 
     def get_reward(self, all_action):
diff --git a/openrl/envs/snake/snake_pettingzoo.py b/openrl/envs/snake/snake_pettingzoo.py
index a9c18c76..f0598434 100644
--- a/openrl/envs/snake/snake_pettingzoo.py
+++ b/openrl/envs/snake/snake_pettingzoo.py
@@ -32,14 +32,27 @@ class SnakeEatBeansAECEnv(AECEnv):
     metadata = {"render.modes": ["human"], "name": "SnakeEatBeans"}
 
-    def __init__(self, render_mode: Optional[str] = None):
-        self.env = SnakeEatBeans(render_mode)
-
+    def __init__(self, render_mode: Optional[str] = None, id: str = None):
+        self.env = SnakeEatBeans(render_mode, id=id)
+
+        agent_num = len(self.possible_agents)
+        player_each_side = self.env.num_agents
+        self.agent_name_to_slice = dict(
+            zip(
+                self.possible_agents,
+                [
+                    slice(i * player_each_side, (i + 1) * player_each_side)
+                    for i in range(agent_num)
+                ],
+            )
+        )
         self.agent_name_mapping = dict(
             zip(self.possible_agents, list(range(len(self.possible_agents))))
         )
+
         self._action_spaces = {
-            agent: spaces.Discrete(4) for agent in self.possible_agents
+            agent: [spaces.Discrete(4) for _ in range(self.env.num_agents)]
+            for agent in self.possible_agents
         }
         self._observation_spaces = {
             agent: spaces.Box(low=-np.inf, high=np.inf, shape=(288,), dtype=np.float32)
@@ -65,7 +78,7 @@ def action_space(self, agent):
         return deepcopy(self._action_spaces[agent])
 
     def observe(self, agent):
-        return self.raw_obs[self.agent_name_mapping[agent]]
+        return self.raw_obs[self.agent_name_to_slice[agent]]
 
     def reset(
         self,
@@ -93,13 +106,17 @@ def step(self, action):
             self._cumulative_rewards[agent] = 0
         self.state[self.agent_selection] = action
         if self._agent_selector.is_last():
-            joint_action = [self.state[agent] for agent in self.agents]
+            joint_action = []
+            for agent in self.agents:
+                joint_action += self.state[agent]
+
             self.raw_obs, self.raw_reward, self.raw_done, self.raw_info = self.env.step(
                 joint_action
             )
             self.rewards = {
-                agent: self.raw_reward[i] for i, agent in enumerate(self.agents)
+                agent: np.sum(self.raw_reward[self.agent_name_to_slice[agent]])
+                for agent in self.agents
             }
 
             if np.any(self.raw_done):
@@ -122,7 +139,7 @@ def close(self):
 
     @property
     def possible_agents(self):
-        return ["player_" + str(i) for i in range(self.env.n_player)]
+        return ["player_" + str(i) for i in range(2)]
 
     @property
     def num_agents(self):
["player_" + str(i) for i in range(self.env.n_player)] + return ["player_" + str(i) for i in range(2)] @property def num_agents(self): diff --git a/openrl/envs/wrappers/pettingzoo_wrappers.py b/openrl/envs/wrappers/pettingzoo_wrappers.py index 647c13be..0026fb61 100644 --- a/openrl/envs/wrappers/pettingzoo_wrappers.py +++ b/openrl/envs/wrappers/pettingzoo_wrappers.py @@ -15,7 +15,8 @@ # limitations under the License. """""" -from typing import Optional +from collections import defaultdict +from typing import Dict, Optional from pettingzoo.utils.env import ActionType, AECEnv from pettingzoo.utils.wrappers import BaseWrapper @@ -29,13 +30,21 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict] = None): list(self.action_spaces.values()) + list(self.observation_spaces.values()) ): - space.seed(seed + i * 7891) + if isinstance(space, list): + for j in range(len(space)): + space[j].seed(seed + i * 7891 + j) + else: + space.seed(seed + i * 7891) class RecordWinner(BaseWrapper): def __init__(self, env: AECEnv): super().__init__(env) - self.cumulative_rewards = {} + self.total_rewards = defaultdict(float) + + def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None): + self.total_rewards = defaultdict(float) + return super().reset(seed, options) def step(self, action: ActionType) -> None: super().step(action) @@ -50,11 +59,20 @@ def step(self, action: ActionType) -> None: self.infos[agent]["losers"] = losers def get_winners(self): - max_reward = max(self._cumulative_rewards.values()) + max_reward = max(self.total_rewards.values()) winners = [ agent - for agent, reward in self._cumulative_rewards.items() + for agent, reward in self.total_rewards.items() if reward == max_reward ] return winners + + def last(self, observe: bool = True): + """Returns observation, cumulative reward, terminated, truncated, info for the current agent (specified by self.agent_selection).""" + agent = self.agent_selection + # if self._cumulative_rewards[agent]!=0: + # print("agent:",agent,self._cumulative_rewards[agent]) + self.total_rewards[agent] += self._cumulative_rewards[agent] + + return super().last(observe) diff --git a/openrl/selfplay/opponents/jidi_opponent.py b/openrl/selfplay/opponents/jidi_opponent.py new file mode 100644 index 00000000..0db5bb69 --- /dev/null +++ b/openrl/selfplay/opponents/jidi_opponent.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2023 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""""" +from pathlib import Path +from typing import Callable, Dict, Optional, Union + +from openrl.selfplay.opponents.base_opponent import BaseOpponent + + +class JiDiOpponent(BaseOpponent): + def __init__( + self, + opponent_id: Optional[str] = None, + opponent_path: Optional[Union[str, Path]] = None, + opponent_info: Optional[Dict[str, str]] = None, + jidi_controller: Optional[Callable] = None, + player_num: int = 1, + ): + self.player_num = player_num + self.jidi_controller = jidi_controller + super().__init__(opponent_id, opponent_path, opponent_info) + + def act(self, player_name, observation, reward, termination, truncation, info): + # if self.player_num == 1: + # observation = [observation] + # else: + assert len(observation) == self.player_num + + joint_action = [] + for i in range(self.player_num): + action = self.jidi_controller( + observation[i], self.action_space_list[i], self.is_act_continuous + ) + joint_action.append(action[0]) + + return joint_action + + def _load(self, opponent_path: Union[str, Path]): + pass + + def _set_env(self, env, opponent_player: str): + self.action_space_list = env.action_space(opponent_player) + + assert len(self.action_space_list) == self.player_num + + self.is_act_continuous = self.action_space_list[0].__class__.__name__ == "Box" + + for i in range(self.player_num): + self.action_space_list[i] = [self.action_space_list[i]] diff --git a/openrl/selfplay/opponents/utils.py b/openrl/selfplay/opponents/utils.py index e7e23797..913d9369 100644 --- a/openrl/selfplay/opponents/utils.py +++ b/openrl/selfplay/opponents/utils.py @@ -19,10 +19,12 @@ import json import sys import time +import traceback from pathlib import Path from typing import Dict, List, Optional, Tuple, Union from openrl.selfplay.opponents.base_opponent import BaseOpponent +from openrl.selfplay.opponents.jidi_opponent import JiDiOpponent def check_opponent_template(opponent_template: Union[str, Path]): @@ -88,6 +90,37 @@ def load_opponent_from_path( return opponent +def load_opponent_from_jidi_path( + opponent_path: Union[str, Path], + opponent_info: Optional[Dict[str, str]] = None, + player_num: int = 1, +) -> Optional[BaseOpponent]: + opponent = None + if isinstance(opponent_path, str): + opponent_path = Path(opponent_path) + assert opponent_path.exists() + try: + sys.path.append(str(opponent_path.parent)) + submission_module = __import__( + "{}.submission".format(opponent_path.name), fromlist=["submission"] + ) + opponent_id = get_opponent_id(opponent_info) + opponent = JiDiOpponent( + opponent_id=opponent_id, + opponent_path=opponent_path, + opponent_info=opponent_info, + jidi_controller=submission_module.my_controller, + player_num=player_num, + ) + except Exception as e: + print(f"Load JiDi opponent from {opponent_path} failed") + traceback.print_exc() + exit() + + sys.path.remove(str(opponent_path.parent)) + return opponent + + def get_opponent_from_path( opponent_path: Union[str, Path], current_opponent: Optional[BaseOpponent] = None,
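Under the hood, `JiDiAgent._new_agent` therefore boils down to a dynamic import of `<submission_dir>/submission.py`. A rough standalone equivalent of what `load_opponent_from_jidi_path` does, useful when debugging a submission outside the arena (the path below is illustrative):

```python
import sys
from pathlib import Path

opponent_path = Path("./submissions/rule_v1")  # any folder containing submission.py

# Mirror of the import dance in load_opponent_from_jidi_path: make the parent
# directory importable, import <dir>.submission, and grab its entry point.
sys.path.append(str(opponent_path.parent))
submission_module = __import__(
    "{}.submission".format(opponent_path.name), fromlist=["submission"]
)
controller = submission_module.my_controller  # JiDi's required entry point
sys.path.remove(str(opponent_path.parent))
```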