From b3fd5741eb5346acde1c1c0e3a18fa8938ed4600 Mon Sep 17 00:00:00 2001 From: jayyoung0802 Date: Wed, 30 Aug 2023 16:30:39 +0800 Subject: [PATCH 1/3] polish(yzj): polish render and add practice demo --- gobigger/render/realtime_render.py | 14 +- practice/README.md | 25 + practice/battle.py | 546 ++++++++++++++++++ practice/cooperative_agent/agent.py | 91 +++ .../default_model_config.yaml | 198 +++++++ practice/cooperative_agent/model.py | 29 + practice/solo_agent/agent.py | 120 ++++ practice/solo_agent/default_model_config.yaml | 187 ++++++ practice/solo_agent/model.py | 120 ++++ practice/solo_agent/util.py | 184 ++++++ practice/test_ai.py | 19 + practice/tools/encoder.py | 294 ++++++++++ practice/tools/features.py | 390 +++++++++++++ practice/tools/head.py | 71 +++ practice/tools/network/__init__.py | 8 + practice/tools/network/activation.py | 96 +++ practice/tools/network/encoder.py | 136 +++++ practice/tools/network/nn_module.py | 235 ++++++++ practice/tools/network/normalization.py | 36 ++ practice/tools/network/res_block.py | 231 ++++++++ practice/tools/network/rnn.py | 276 +++++++++ practice/tools/network/scatter_connection.py | 107 ++++ practice/tools/network/soft_argmax.py | 60 ++ practice/tools/network/transformer.py | 391 +++++++++++++ practice/tools/util.py | 184 ++++++ 25 files changed, 4041 insertions(+), 7 deletions(-) create mode 100644 practice/README.md create mode 100644 practice/battle.py create mode 100644 practice/cooperative_agent/agent.py create mode 100644 practice/cooperative_agent/default_model_config.yaml create mode 100644 practice/cooperative_agent/model.py create mode 100644 practice/solo_agent/agent.py create mode 100644 practice/solo_agent/default_model_config.yaml create mode 100644 practice/solo_agent/model.py create mode 100644 practice/solo_agent/util.py create mode 100644 practice/test_ai.py create mode 100644 practice/tools/encoder.py create mode 100644 practice/tools/features.py create mode 100644 practice/tools/head.py create mode 100644 practice/tools/network/__init__.py create mode 100644 practice/tools/network/activation.py create mode 100644 practice/tools/network/encoder.py create mode 100644 practice/tools/network/nn_module.py create mode 100644 practice/tools/network/normalization.py create mode 100644 practice/tools/network/res_block.py create mode 100644 practice/tools/network/rnn.py create mode 100644 practice/tools/network/scatter_connection.py create mode 100644 practice/tools/network/soft_argmax.py create mode 100644 practice/tools/network/transformer.py create mode 100644 practice/tools/util.py diff --git a/gobigger/render/realtime_render.py b/gobigger/render/realtime_render.py index 37a2538..fb60c2d 100644 --- a/gobigger/render/realtime_render.py +++ b/gobigger/render/realtime_render.py @@ -56,7 +56,7 @@ def render_all_balls_colorful(self, food_balls, thorns_balls, spore_balls, playe txt_rect = txt.get_rect(center=(x, y)) self.screen.blit(txt, txt_rect) - def fill(self, food_balls, thorns_balls, spore_balls, players, player_num_per_team=1, fps=20): + def fill(self, food_balls, thorns_balls, spore_balls, players, player_num_per_team=1, fps=20, leaderboard=None): self.screen.fill(BACKGROUND) self.render_all_balls_colorful(food_balls, thorns_balls, spore_balls, players, player_num_per_team) # add line @@ -67,13 +67,13 @@ def fill(self, food_balls, thorns_balls, spore_balls, players, player_num_per_te pygame.draw.line(self.screen, RED, (self.game_screen_width-self.padding, self.padding), (self.game_screen_width-self.padding, 
self.game_screen_width-self.padding), width=1)
         # for debug
-        # font = pygame.font.SysFont('Menlo', 15, True)
+        font = pygame.font.SysFont('Menlo', 15, True)
-        # assert len(leaderboard) > 0, 'leaderboard could not be None'
-        # leaderboard = sorted(leaderboard.items(), key=lambda d: d[1], reverse=True)
-        # for index, (team_id, team_size) in enumerate(leaderboard):
-        #     pos_txt = font.render('{}: {:.5f}'.format(team_id, team_size), 1, RED)
-        #     self.screen.blit(pos_txt, (20, 10+10*(index*2+1)))
+        if leaderboard is not None:
+            leaderboard = sorted(leaderboard.items(), key=lambda d: d[1], reverse=True)
+            for index, (team_id, team_size) in enumerate(leaderboard):
+                pos_txt = font.render('{}: {:.5f}'.format(team_id, team_size), 1, RED)
+                self.screen.blit(pos_txt, (20, 10+10*(index*2+1)))
         # fps_txt = font.render('fps: ' + str(fps), 1, RED)
         # last_frame_txt = font.render('frame_count: {} / {}'.format(frame_count, int(frame_count/20)), 1, RED)
diff --git a/practice/README.md b/practice/README.md
new file mode 100644
index 0000000..b7758ec
--- /dev/null
+++ b/practice/README.md
@@ -0,0 +1,25 @@
+## Practice
+This directory offers several ways to play GoBigger: free development (farm), battles against the built-in Bot or a trained AI, and a spectator mode. Both single-player and two-player matches are supported. Enjoy!
+
+### Download Weights
+```bash
+
+```
+### Quick Start
+
+#### Installation
+```bash
+git clone https://github.com/opendilab/GoBigger.git
+cd GoBigger
+pip install -e .
+```
+
+#### Usage
+```bash
+cd practice
+python battle.py --mode single --map farm    # Single-player development
+python battle.py --mode single --map vsbot   # Single-player vs. Bot
+python battle.py --mode single --map vsai    # Single-player vs. Single-AI
+python battle.py --mode team --map farm      # Two-player development
+python battle.py --mode team --map vsbot     # Two-player vs. Bot
+python battle.py --mode team --map vsai      # Two-player vs. Team-AI
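+# In-game controls (see battle.py): player 1 moves with the arrow keys and uses
+# 1 (eject spores) / 2 (split) / 3 (stop); in team mode, player 2 moves with
+# W/A/S/D and uses J (eject) / K (split) / L (stop). The vsai and watch modes
+# require trained weights (see "Download Weights" above).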
+python battle.py --mode watch                # Spectator mode: Team-AI vs. Team-AI
+```
\ No newline at end of file
diff --git a/practice/battle.py b/practice/battle.py
new file mode 100644
index 0000000..1776db8
--- /dev/null
+++ b/practice/battle.py
@@ -0,0 +1,546 @@
+import logging
+import pygame
+import argparse
+import time
+
+from gobigger.agents import BotAgent
+from gobigger.envs import GoBiggerEnv
+from gobigger.render import RealtimeRender
+
+logging.basicConfig(level=logging.DEBUG)
+
+cfg = dict(
+    team_num=2,
+    player_num_per_team=2,
+    direction_num=12,
+    step_mul=8,
+    map_width=64,
+    map_height=64,
+    frame_limit=3600,
+    action_space_size=27,
+    use_action_mask=False,
+    reward_div_value=0.1,
+    reward_type='log_reward',
+    contain_raw_obs=True,  # False in collect mode, True in eval vsbot mode, because the bot needs raw obs
+    start_spirit_progress=0.2,
+    end_spirit_progress=0.8,
+    manager_settings=dict(
+        food_manager=dict(
+            num_init=260,
+            num_min=260,
+            num_max=300,
+        ),
+        thorns_manager=dict(
+            num_init=3,
+            num_min=3,
+            num_max=4,
+        ),
+        player_manager=dict(ball_settings=dict(score_init=13000, ), ),
+    ),
+    playback_settings=dict(
+        playback_type='by_frame',
+        by_frame=dict(
+            save_frame=False,  # should be False during training
+            save_dir='./',
+            save_name_prefix='gobigger',
+        ),
+    ),
+)
+
+def play_farm_single(step):
+    cfg['player_num_per_team'] = 1
+    cfg['team_num'] = 1
+    cfg['frame_limit'] = step
+    env = GoBiggerEnv(cfg)
+    obs = env.reset()
+    done = False
+    render = RealtimeRender(map_width=64, map_height=64)
+    fps_real = 0
+    t1 = time.time()
+    clock = pygame.time.Clock()
+    fps_set = env.server.fps
+    my_player_id = 0
+    bot_agents = []
+    for player in env.server.player_manager.get_players():
+        if player.player_id != my_player_id:
+            bot_agents.append(BotAgent(player.player_id))
+    for i in range(100000):
+        x1, y1 = None, None
+        action_type = -1  # -1 means no action this frame
+        # ================ control by keyboard ===============
+        for event in pygame.event.get():
+            if event.type == pygame.KEYDOWN:
+                if event.key == pygame.K_UP:
+                    x1, y1 = 0, -1
+                if event.key == pygame.K_DOWN:
+                    x1, y1 = 0, 1
+                if event.key == pygame.K_LEFT:
+                    x1, y1 = -1, 0
+                if event.key == pygame.K_RIGHT:
+                    x1, y1 = 1, 0
+                if event.key == pygame.K_1:  # Eject spores
+                    action_type = 0
+                if event.key == pygame.K_2:  # Split
+                    action_type = 1
+                if event.key == pygame.K_3:  # Stop moving
+                    action_type = 2
+        actions = {my_player_id: [x1, y1, action_type]}
+        actions.update({agent.name: agent.step(obs[1][agent.name]) for agent in bot_agents})
+        if not done:
+            obs, reward, done, info = env.step(actions=actions)
+            print(obs[0]['leaderboard'])
+            render.fill(food_balls=env.server.food_manager.get_balls(),
+                        thorns_balls=env.server.thorns_manager.get_balls(),
+                        spore_balls=env.server.spore_manager.get_balls(),
+                        players=env.server.player_manager.get_players(),
+                        player_num_per_team=env.server.player_num_per_team,
+                        fps=fps_real,
+                        leaderboard=obs[0]['leaderboard'])
+            render.show()
+            if i % fps_set == 0:
+                t2 = time.time()
+                fps_real = fps_set/(t2-t1)
+                t1 = time.time()
+        else:
+            logging.debug('Game Over')
+            break
+        clock.tick(fps_set)
+    render.close()
+
+def play_vsbot_single(step):
+    cfg['player_num_per_team'] = 1
+    cfg['frame_limit'] = step
+    env = GoBiggerEnv(cfg)
+    obs = env.reset()
+    done = False
+    render = RealtimeRender(map_width=64, map_height=64)
+    fps_real = 0
+    t1 = time.time()
+    clock = pygame.time.Clock()
+    fps_set = env.server.fps
+    my_player_id = 0
+    bot_agents = []
+    for player in env.server.player_manager.get_players():
+        if player.player_id != my_player_id:
+            bot_agents.append(BotAgent(player.player_id))
+    for i in range(100000):
+        x1, y1 = None, None
+        action_type = -1  # -1 means no action this frame
+        # ================ control by keyboard ===============
+        for event in pygame.event.get():
+            if event.type == pygame.KEYDOWN:
+                if event.key == pygame.K_UP:
+                    x1, y1 = 0, -1
+                if event.key == pygame.K_DOWN:
+                    x1, y1 = 0, 1
+                if event.key == pygame.K_LEFT:
+                    x1, y1 = -1, 0
+                if event.key == pygame.K_RIGHT:
+                    x1, y1 = 1, 0
+                if event.key == pygame.K_1:  # Eject spores
+                    action_type = 0
+                if event.key == pygame.K_2:  # Split
+                    action_type = 1
+                if event.key == pygame.K_3:  # Stop moving
+                    action_type = 2
+        actions = {my_player_id: [x1, y1, action_type]}
+        actions.update({agent.name: agent.step(obs[1][agent.name]) for agent in bot_agents})
+        if not done:
+            obs, reward, done, info = env.step(actions=actions)
+            print(obs[0]['leaderboard'])
+            render.fill(food_balls=env.server.food_manager.get_balls(),
+                        thorns_balls=env.server.thorns_manager.get_balls(),
+                        spore_balls=env.server.spore_manager.get_balls(),
+                        players=env.server.player_manager.get_players(),
+                        player_num_per_team=env.server.player_num_per_team,
+                        fps=fps_real,
+                        leaderboard=obs[0]['leaderboard'])
+            render.show()
+            if i % fps_set == 0:
+                t2 = time.time()
+                fps_real = fps_set/(t2-t1)
+                t1 = time.time()
+        else:
+            logging.debug('Game Over')
+            break
+        clock.tick(fps_set)
+    render.close()
+
+def play_farm_team(step):
+    cfg['player_num_per_team'] = 2
+    cfg['team_num'] = 1
+    cfg['frame_limit'] = step
+    env = GoBiggerEnv(cfg)
+    obs = env.reset()
+    done = False
+    render = RealtimeRender(map_width=64, map_height=64)
+    fps_real = 0
+    t1 = time.time()
+    clock = pygame.time.Clock()
+    fps_set = env.server.fps
+
+    for i in range(100000):
+        action_type1 = None
+        action_type2 = None
+        x1, y1, x2, y2 = None, None, None, None
+        # ==== control by keyboard: player 0 arrows/1/2/3, player 1 WASD/J/K/L ====
+        for event in pygame.event.get():
+            if event.type == pygame.KEYDOWN:
+                action_type1 = -1
+                action_type2 = -1
+                if event.key == pygame.K_UP:
+                    x1, y1 = 0, -1
+                if event.key == pygame.K_DOWN:
+                    x1, y1 = 0, 1
+                if event.key == pygame.K_LEFT:
+                    x1, y1 = -1, 0
+                if event.key == pygame.K_RIGHT:
+                    x1, y1 = 1, 0
+                if event.key == pygame.K_1:  # Eject spores
+                    action_type1 = 0
+                if event.key == pygame.K_2:  # Split
+                    action_type1 = 1
+                if event.key == pygame.K_3:  # Stop moving
+                    action_type1 = 2
+                if event.key == pygame.K_w:
+                    x2, y2 = 0, -1
+                if event.key == pygame.K_s:
+                    x2, y2 = 0, 1
+                if event.key == pygame.K_a:
+                    x2, y2 = -1, 0
+                if event.key == pygame.K_d:
+                    x2, y2 = 1, 0
+                if event.key == pygame.K_j:  # Eject spores
+                    action_type2 = 0
+                if event.key == pygame.K_k:  # Split
+                    action_type2 = 1
+                if event.key == pygame.K_l:  # Stop moving
+                    action_type2 = 2
+        actions = {
+            0: [x1, y1, action_type1],
+            1: [x2, y2, action_type2],
+        }
+        if not done:
+            obs, reward, done, info = env.step(actions=actions)
+            print(obs[0]['leaderboard'])
+            render.fill(food_balls=env.server.food_manager.get_balls(),
+                        thorns_balls=env.server.thorns_manager.get_balls(),
+                        spore_balls=env.server.spore_manager.get_balls(),
+                        players=env.server.player_manager.get_players(),
+                        player_num_per_team=env.server.player_num_per_team,
+                        fps=fps_real,
+                        leaderboard=obs[0]['leaderboard'])
+            render.show()
+            if i % fps_set == 0:
+                t2 = time.time()
+                fps_real = fps_set/(t2-t1)
+                t1 = time.time()
+        else:
+            logging.debug('Game Over')
+            break
+        clock.tick(fps_set)
+    render.close()
+
+def play_vsbot_team(step):
+    cfg['frame_limit'] = step
+    env = GoBiggerEnv(cfg)
+    obs = env.reset()
+    done = False
+    render = RealtimeRender(map_width=64, map_height=64)
+    fps_real = 0
+    t1 = time.time()
+    clock = pygame.time.Clock()
+    fps_set = env.server.fps
+
+    bot_agents = []
+    for player in env.server.player_manager.get_players():
+        if player.player_id != 0 and player.player_id != 1:
+            bot_agents.append(BotAgent(player.player_id))
+    for i in range(100000):
+        action_type1 = None
+        action_type2 = None
+        x1, y1, x2, y2 = None, None, None, None
+        # ==== control by keyboard: player 0 arrows/1/2/3, player 1 WASD/J/K/L ====
+        for event in pygame.event.get():
+            if event.type == pygame.KEYDOWN:
+                action_type1 = -1
+                action_type2 = -1
+                if event.key == pygame.K_UP:
+                    x1, y1 = 0, -1
+                if event.key == pygame.K_DOWN:
+                    x1, y1 = 0, 1
+                if event.key == pygame.K_LEFT:
+                    x1, y1 = -1, 0
+                if event.key == pygame.K_RIGHT:
+                    x1, y1 = 1, 0
+                if event.key == pygame.K_1:  # Eject spores
+                    action_type1 = 0
+                if event.key == pygame.K_2:  # Split
+                    action_type1 = 1
+                if event.key == pygame.K_3:  # Stop moving
+                    action_type1 = 2
+                if event.key == pygame.K_w:
+                    x2, y2 = 0, -1
+                if event.key == pygame.K_s:
+                    x2, y2 = 0, 1
+                if event.key == pygame.K_a:
+                    x2, y2 = -1, 0
+                if event.key == pygame.K_d:
+                    x2, y2 = 1, 0
+                if event.key == pygame.K_j:  # Eject spores
+                    action_type2 = 0
+                if event.key == pygame.K_k:  # Split
+                    action_type2 = 1
+                if event.key == pygame.K_l:  # Stop moving
+                    action_type2 = 2
+        actions = {
+            0: [x1, y1, action_type1],
+            1: [x2, y2, action_type2],
+        }
+        actions.update({agent.name: agent.step(obs[1][agent.name]) for agent in bot_agents})
+        if not done:
+            obs, reward, done, info = env.step(actions=actions)
+            print(obs[0]['leaderboard'])
+            render.fill(food_balls=env.server.food_manager.get_balls(),
+                        thorns_balls=env.server.thorns_manager.get_balls(),
+                        spore_balls=env.server.spore_manager.get_balls(),
+                        players=env.server.player_manager.get_players(),
+                        player_num_per_team=env.server.player_num_per_team,
+                        fps=fps_real,
+                        leaderboard=obs[0]['leaderboard'])
+            render.show()
+            if i % fps_set == 0:
+                t2 = time.time()
+                fps_real = fps_set/(t2-t1)
+                t1 = time.time()
+        else:
+            logging.debug('Game Over')
+            break
+        clock.tick(fps_set)
+    render.close()
+
+def play_vsai_single(step):
+    cfg['frame_limit'] = step
+    cfg['player_num_per_team'] = 1
+    env = GoBiggerEnv(cfg)
+    obs = env.reset()
+    done = False
+    render = RealtimeRender(map_width=64, map_height=64)
+    fps_real = 0
+    t1 = time.time()
+    clock = pygame.time.Clock()
+    fps_set = env.server.fps
+    my_player_id = 0
+    ai_player_id = 1
+    from solo_agent.agent import AIAgent as AI
+    ai = AI(team_name=1, player_names=[ai_player_id])
+    for i in range(100000):
+        x1, y1 = None, None
+        action_type = -1  # -1 means no action this frame
+        # ================ control by keyboard ===============
+        for event in pygame.event.get():
+            if event.type == pygame.KEYDOWN:
+                if event.key == pygame.K_UP:
+                    x1, y1 = 0, -1
+                if event.key == pygame.K_DOWN:
+                    x1, y1 = 0, 1
+                if event.key == pygame.K_LEFT:
+                    x1, y1 = -1, 0
+                if event.key == pygame.K_RIGHT:
+                    x1, y1 = 1, 0
+                if event.key == pygame.K_1:  # Eject spores
+                    action_type = 0
+                if event.key == pygame.K_2:  # Split
+                    action_type = 1
+                if event.key == pygame.K_3:  # Stop moving
+                    action_type = 2
+        ai_action = ai.get_actions(obs)
+        actions = {my_player_id: [x1, y1, action_type]}
+        actions.update(ai_action)
+        if not done:
+            obs, reward, done, info = env.step(actions=actions)
+            print(obs[0]['leaderboard'])
+            render.fill(food_balls=env.server.food_manager.get_balls(),
+                        thorns_balls=env.server.thorns_manager.get_balls(),
+                        spore_balls=env.server.spore_manager.get_balls(),
+                        players=env.server.player_manager.get_players(),
+                        player_num_per_team=env.server.player_num_per_team,
+                        fps=fps_real,
+                        leaderboard=obs[0]['leaderboard'])
+            render.show()
+            if i % fps_set == 0:
+                t2 = time.time()
+                fps_real = fps_set/(t2-t1)
+                t1 = time.time()
+        else:
+            logging.debug('Game Over')
+            break
+        clock.tick(fps_set)
+    render.close()
+
+def play_vsai_team(step):
+    cfg['frame_limit'] = step
+    env = GoBiggerEnv(cfg)
+    obs = env.reset()
+    done = False
+    render = RealtimeRender(map_width=64, map_height=64)
+    fps_real = 0
+    t1 = time.time()
+    clock = pygame.time.Clock()
+    fps_set = env.server.fps
+    from cooperative_agent.agent import AIAgent as AI
+    ai = AI(team_name=1, player_names=[2, 3])
+    for i in range(100000):
+        action_type1 = None
+        action_type2 = None
+        x1, y1, x2, y2 = None, None, None, None
+        # ==== control by keyboard: player 0 arrows/1/2/3, player 1 WASD/J/K/L ====
+        for event in pygame.event.get():
+            if event.type == pygame.KEYDOWN:
+                action_type1 = -1
+                action_type2 = -1
+                if event.key == pygame.K_UP:
+                    x1, y1 = 0, -1
+                if event.key == pygame.K_DOWN:
+                    x1, y1 = 0, 1
+                if event.key == pygame.K_LEFT:
+                    x1, y1 = -1, 0
+                if event.key == pygame.K_RIGHT:
+                    x1, y1 = 1, 0
+                if event.key == pygame.K_1:  # Eject spores
+                    action_type1 = 0
+                if event.key == pygame.K_2:  # Split
+                    action_type1 = 1
+                if event.key == pygame.K_3:  # Stop moving
+                    action_type1 = 2
+                if event.key == pygame.K_w:
+                    x2, y2 = 0, -1
+                if event.key == pygame.K_s:
+                    x2, y2 = 0, 1
+                if event.key == pygame.K_a:
+                    x2, y2 = -1, 0
+                if event.key == pygame.K_d:
+                    x2, y2 = 1, 0
+                if event.key == pygame.K_j:  # Eject spores
+                    action_type2 = 0
+                if event.key == pygame.K_k:  # Split
+                    action_type2 = 1
+                if event.key == pygame.K_l:  # Stop moving
+                    action_type2 = 2
+        actions = {
+            0: [x1, y1, action_type1],
+            1: [x2, y2, action_type2],
+        }
+        ai_action = ai.get_actions(obs)
+        actions.update(ai_action)
+        if not done:
+            obs, reward, done, info = env.step(actions=actions)
+            print(obs[0]['leaderboard'])
+            render.fill(food_balls=env.server.food_manager.get_balls(),
+                        thorns_balls=env.server.thorns_manager.get_balls(),
+                        spore_balls=env.server.spore_manager.get_balls(),
+                        players=env.server.player_manager.get_players(),
+                        player_num_per_team=env.server.player_num_per_team,
+                        fps=fps_real,
+                        leaderboard=obs[0]['leaderboard'])
+            render.show()
+            if i % fps_set == 0:
+                t2 = time.time()
+                fps_real = fps_set/(t2-t1)
+                t1 = time.time()
+        else:
+            logging.debug('Game Over')
+            break
+        clock.tick(fps_set)
+    render.close()
+
+def watch_vsai_only(step):
+    cfg['frame_limit'] = step
+    env = GoBiggerEnv(cfg)
+    obs = env.reset()
+    done = False
+    render = RealtimeRender(map_width=64, map_height=64)
+    fps_real = 0
+    t1 = time.time()
+    clock = pygame.time.Clock()
+    fps_set = env.server.fps
+    from cooperative_agent.agent import AIAgent as AI
+    ai_0 = AI(team_name=0, player_names=[0, 1])
+    ai_1 = AI(team_name=1, player_names=[2, 3])
+    for i in range(100000):
+        # both teams are driven by the cooperative AI; no keyboard input in this mode
+        actions = {}
+        actions.update(ai_0.get_actions(obs))
+        actions.update(ai_1.get_actions(obs))
+        if not done:
+            obs, reward, done, info = env.step(actions=actions)
+            print(obs[0]['leaderboard'])
+            render.fill(food_balls=env.server.food_manager.get_balls(),
+                        thorns_balls=env.server.thorns_manager.get_balls(),
+                        spore_balls=env.server.spore_manager.get_balls(),
+                        players=env.server.player_manager.get_players(),
+                        player_num_per_team=env.server.player_num_per_team,
+                        fps=fps_real,
+                        leaderboard=obs[0]['leaderboard'])
+            render.show()
+            if i % fps_set == 0:
+                t2 = time.time()
+                fps_real = fps_set/(t2-t1)
+                t1 = time.time()
+        else:
+            logging.debug('Game Over')
+            break
+        clock.tick(fps_set)
+    render.close()
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--mode', type=str, choices=['single', 'team', 'watch'], default='single')
+    parser.add_argument('--map', type=str, choices=['farm', 'vsbot', 'vsai'], default='farm')
+    parser.add_argument('--step', type=int, default=3600)
+
+    args = parser.parse_args()
+
+    if args.mode == 'single':
+        if args.map == 'farm':
+            play_farm_single(args.step)
+        elif args.map == 'vsbot':
+            play_vsbot_single(args.step)
+        elif args.map == 'vsai':
+            play_vsai_single(args.step)
+    elif args.mode == 'team':
+        if args.map == 'farm':
+            play_farm_team(args.step)
+        elif args.map == 'vsbot':
+            play_vsbot_team(args.step)
+        elif args.map == 'vsai':
+            play_vsai_team(args.step)
+    elif args.mode == 'watch':
+        watch_vsai_only(args.step)
+
diff --git a/practice/cooperative_agent/agent.py b/practice/cooperative_agent/agent.py
new file mode 100644
index 0000000..c404757
--- /dev/null
+++ b/practice/cooperative_agent/agent.py
@@ -0,0 +1,91 @@
+import torch
+from practice.tools.util import default_collate_with_dim
+from practice.tools.features import Features
+from practice.cooperative_agent.model import Model
+from copy import deepcopy
+from easydict import EasyDict
+
+class AIAgent:
+
+    def __init__(self, team_name, player_names):
+        cfg = EasyDict({
+            'team_name': team_name,
+            'player_names': player_names,
+            'env': {
+                'name': 'gobigger',
+                'player_num_per_team': 2,
+                'team_num': 2,
+                'step_mul': 8
+            },
+            'agent': {
+                'player_id': None,
+                'game_player_id': None,
+                'features': {}
+            },
+            'checkpoint_path': 'PATH/MODEL_NAME.pth.tar'  # placeholder: point this at the downloaded weights
+        })
+        self.agents = {}
+        for player_name in player_names:
+            cfg_cp = deepcopy(cfg)
+            cfg_cp.agent.player_id = player_name
+            cfg_cp.agent.game_player_id = player_name
+            agent = Agent(cfg_cp)
+            agent.reset()
+            agent.model.load_state_dict(torch.load(cfg.checkpoint_path, map_location='cpu')['model'], strict=False)
+            self.agents[player_name] = agent
+
+    def get_actions(self, obs):
+        global_state, player_states = obs
+        actions = {}
+        for player_name, agent in self.agents.items():
+            action = agent.step([global_state, {player_name: player_states[player_name]}])
+            actions.update(action)
+        return actions
+
+class Agent:
+
+    def __init__(self, cfg):
+        self.whole_cfg = cfg
+        self.player_num = self.whole_cfg.env.player_num_per_team
+        self.team_num = self.whole_cfg.env.team_num
+        self.game_player_id = self.whole_cfg.agent.game_player_id  # start from 0
+        self.game_team_id = self.game_player_id // self.player_num  # start from 0
+        self.player_id = self.whole_cfg.agent.player_id
+        self.features = Features(self.whole_cfg)
+        self.eval_padding = self.whole_cfg.agent.get('eval_padding', False)
+        self.use_action_mask = self.whole_cfg.agent.get('use_action_mask', False)
+        self.model = Model(self.whole_cfg)
+
+    def reset(self):
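+        # `direction_num * 2` is the index of the first no-direction action ([0, 0, 0]) in Features.setup_action
+        self.last_action_type =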
self.features.direction_num * 2 + + def preprocess(self, obs): + self.last_player_score = obs[1][self.game_player_id]['score'] + if self.use_action_mask: + can_eject = obs[1][self.game_player_id]['can_eject'] + can_split = obs[1][self.game_player_id]['can_split'] + action_mask = self.features.generate_action_mask(can_eject=can_eject,can_split=can_split) + else: + action_mask = self.features.generate_action_mask(can_eject=True,can_split=True) + obs = self.features.transform_obs(obs, game_player_id=self.game_player_id, + last_action_type=self.last_action_type,padding=self.eval_padding) + obs = default_collate_with_dim([obs]) + + obs['action_mask'] = action_mask.unsqueeze(0) + return obs + + def step(self, obs): + self.raw_obs = obs + obs = self.preprocess(obs) + self.model_input = obs + with torch.no_grad(): + self.model_output = self.model.compute_action(self.model_input) + actions = self.postprocess(self.model_output['action'].detach().numpy()) + return actions + + def postprocess(self, model_actions): + actions = {} + actions[self.game_player_id] = self.features.transform_action(model_actions[0]) + self.last_action_type = model_actions[0].item() + return actions diff --git a/practice/cooperative_agent/default_model_config.yaml b/practice/cooperative_agent/default_model_config.yaml new file mode 100644 index 0000000..6837134 --- /dev/null +++ b/practice/cooperative_agent/default_model_config.yaml @@ -0,0 +1,198 @@ +var1: &VIEW_BINARY_NUM 8 +var2: &ABS_VIEW_BINARY_NUM 7 +agent: + enable_baselines: [ 'score', 'spore','team_spore', 'clone','team_clone','opponent','team_opponent' ] + features: + max_ball_num: 64 + max_food_num: 256 + max_spore_num: 128 + direction_num: 12 + spatial_x: 64 + spatial_y: 64 +model: + ortho_init: True + value_head_init_gains: + score: 1 + spore: 0.01 + team_spore: 0.01 + clone: 0.01 + team_clone: 0.01 + opponent: 0.01 + team_opponent: 0.01 + min_dist: 0.01 + max_dist: 0.01 + scalar_encoder: + modules: + view_x: + arc: sign_binary + num_embeddings: *ABS_VIEW_BINARY_NUM + embedding_dim: 8 + view_y: + arc: sign_binary + num_embeddings: *ABS_VIEW_BINARY_NUM + embedding_dim: 8 + view_width: + arc: binary + num_embeddings: *ABS_VIEW_BINARY_NUM + embedding_dim: 8 +# view_height: +# arc: binary +# num_embeddings: *ABS_VIEW_BINARY_NUM +# embedding_dim: 8 + score: + arc: one_hot + num_embeddings: 10 + embedding_dim: 8 + team_score: + arc: one_hot + num_embeddings: 10 + embedding_dim: 8 + rank: + arc: one_hot + num_embeddings: 4 + embedding_dim: 8 + time: + arc: time + embedding_dim: 8 + last_action_type: + arc: one_hot + num_embeddings: 27 # direction_num * 2 + 3 + input_dim: 80 + hidden_dim: 64 + layer_num: 2 + norm_type: 'none' + output_dim: 32 + activation: 'relu' + team_encoder: + modules: + alliance: + arc: one_hot + num_embeddings: 2 + view_x: + arc: sign_binary + num_embeddings: *ABS_VIEW_BINARY_NUM + view_y: + arc: sign_binary + num_embeddings: *ABS_VIEW_BINARY_NUM +# view_width: +# arc: binary +# num_embeddings: *ABS_VIEW_BINARY_NUM +# embedding_dim: 12 +# view_height: +# arc: binary +# num_embeddings: *ABS_VIEW_BINARY_NUM +# embedding_dim: 12 +# score: +# arc: one_hot +# num_embeddings: 10 +# embedding_dim: 12 +# team_score: +# arc: one_hot +# num_embeddings: 10 +# embedding_dim: 12 +# team_rank: +# arc: one_hot +# num_embeddings: 10 +# embedding_dim: 12 + embedding_dim: 16 + encoder: + input_dim: 16 + hidden_dim: 32 + layer_num: 2 + activation: 'relu' + norm_type: 'none' + transformer: + head_num: 4 + ffn_size: 32 + layer_num: 2 + activation: 'relu' + variant: 
'postnorm' + output: + output_dim: 16 + activation: 'relu' + norm_type: 'none' + ball_encoder: + modules: + alliance: + arc: one_hot + num_embeddings: 4 + score: + arc: one_hot + num_embeddings: 50 + radius: + arc: unsqueeze +# score_ratio: +# arc: one_hot +# num_embeddings: 50 + rank: + arc: one_hot + num_embeddings: 5 + x: + arc: sign_binary + num_embeddings: *VIEW_BINARY_NUM + embedding_dim: 8 + y: + arc: sign_binary + num_embeddings: *VIEW_BINARY_NUM + embedding_dim: 8 + next_x: + arc: sign_binary + num_embeddings: *VIEW_BINARY_NUM + embedding_dim: 8 + next_y: + arc: sign_binary + num_embeddings: *VIEW_BINARY_NUM + embedding_dim: 8 + embedding_dim: 64 + encoder: + input_dim: 92 + hidden_dim: 128 + layer_num: 2 + activation: 'relu' + norm_type: 'none' + transformer: + head_num: 4 + ffn_size: 64 + layer_num: 3 + activation: 'relu' + variant: 'postnorm' + output: + output_dim: 64 + activation: 'relu' + norm_type: 'none' + spatial_encoder: + scatter: + input_dim: 64 + output_dim: 16 + scatter_type: add + activation: 'relu' + norm_type: 'none' + resnet: + project_dim: 12 + down_channels: [32, 32, 16 ] + activation: 'relu' + norm_type: 'none' + output: + output_dim: 64 + activation: 'relu' + norm_type: 'none' + policy: + embedding_dim: 64 + project: + input_dim: 176 # scalar + team + ball + spatial + activation: 'relu' + norm_type: 'none' + resnet: + activation: 'relu' + norm_type: 'none' + res_num: 3 + value: + embedding_dim: 64 + project: + input_dim: 176 # scalar + team + ball + spatial + activation: 'relu' + norm_type: 'none' + resnet: + activation: 'relu' + norm_type: 'none' + res_num: 3 diff --git a/practice/cooperative_agent/model.py b/practice/cooperative_agent/model.py new file mode 100644 index 0000000..ce63e59 --- /dev/null +++ b/practice/cooperative_agent/model.py @@ -0,0 +1,29 @@ +import os +import torch +import torch.nn as nn +from ..tools.util import read_config, deep_merge_dicts +from ..tools.encoder import Encoder +from ..tools.head import PolicyHead, ValueHead + +default_config = read_config(os.path.join(os.path.dirname(__file__), 'default_model_config.yaml')) + +class Model(nn.Module): + def __init__(self, cfg={}, use_value_network=False): + super(Model, self).__init__() + self.whole_cfg = deep_merge_dicts(default_config, cfg) + self.model_cfg = self.whole_cfg.model + self.use_value_network = use_value_network + self.encoder = Encoder(self.whole_cfg) + self.policy_head = PolicyHead(self.whole_cfg) + self.temperature = self.whole_cfg.agent.get('temperature', 1) + + # used in rl_eval actor + def compute_action(self, obs, ): + action_mask = obs.pop('action_mask',None) + embedding = self.encoder(obs, ) + logit = self.policy_head(embedding, temperature=self.temperature) + if action_mask is not None: + logit.masked_fill_(mask=action_mask,value=-1e9) + dist = torch.distributions.Categorical(logits=logit) + action = dist.sample() + return {'action': action, 'logit': logit} \ No newline at end of file diff --git a/practice/solo_agent/agent.py b/practice/solo_agent/agent.py new file mode 100644 index 0000000..599e0f0 --- /dev/null +++ b/practice/solo_agent/agent.py @@ -0,0 +1,120 @@ +import torch +from practice.tools.util import default_collate_with_dim +from practice.tools.features import Features +from practice.solo_agent.model import Model +from copy import deepcopy +from easydict import EasyDict +import torch + +class AIAgent: + + def __init__(self, team_name, player_names): + cfg = EasyDict({ + 'env': { + 'name': 'gobigger', + 'player_num_per_team': 1, + 'team_num': 2, + }, + 
'agent': { + 'player_id': None, + 'game_player_id': None, + 'features': {} + }, + 'checkpoint_path': 'PATH/MODEL_NAME.pth.tar', + }) + self.agents = {} + for player_name in player_names: + cfg_cp = deepcopy(cfg) + cfg_cp.agent.player_id = player_name + cfg_cp.agent.game_player_id = player_name + agent = Agent(cfg_cp) + agent.reset() + agent.model.load_state_dict(torch.load(cfg.checkpoint_path, map_location='cpu')['model'], strict=False) + self.agents[player_name] = agent + + def get_actions(self, obs): + global_state, player_states = obs + actions = {} + for player_name, agent in self.agents.items(): + action = agent.step([global_state, {player_name: player_states[player_name]}]) + actions.update(action) + return actions + +class Agent: + + def __init__(self, cfg=None, ): + self.whole_cfg = cfg + self.cfg = self.whole_cfg.agent + # setup model + self.use_action_mask = self.whole_cfg.agent.get('use_action_mask', False) + self.player_num = self.whole_cfg.env.player_num_per_team + self.team_num = self.whole_cfg.env.team_num + self.game_player_id = self.whole_cfg.agent.game_player_id + self.game_team_id = self.game_player_id // self.player_num + self.features = Features(self.whole_cfg) + self.device = 'cpu' + self.model = Model(self.whole_cfg) + + def transform_action(self, agent_outputs, env_status, eval_vsbot=False): + env_num = len(env_status) + actions_list = agent_outputs['action'].cpu().numpy().tolist() + actions = {} + for env_id in range(env_num): + actions[env_id] = {} + game_player_num = self.player_num if eval_vsbot else self.player_num * self.team_num + for game_player_id in range(game_player_num): + action_idx = actions_list[env_id * (game_player_num) + game_player_id] + env_status[env_id].last_action_types[game_player_id] = action_idx + actions[env_id][game_player_id] = self.features.transform_action(action_idx) + return actions + + ########## only for submission#################### + def reset(self): + self.last_action_type = {} + for player_id in range(self.player_num*self.game_team_id, self.player_num*(self.game_team_id+1)): + self.last_action_type[player_id] = self.features.direction_num * 2 + + def step(self, obs): + """ + Overview: + Agent.step() in submission + Arguments: + - obs + Returns: + - action + """ + # preprocess obs + env_team_obs = [] + for player_id in range(self.player_num*self.game_team_id, self.player_num*(self.game_team_id+1)): + game_player_obs = self.features.transform_obs(obs, game_player_id=player_id, + last_action_type=self.last_action_type[player_id]) + env_team_obs.append(game_player_obs) + env_team_obs = stack(env_team_obs) + obs = default_collate_with_dim([env_team_obs], device=self.device) + + # policy + self.model_input = obs + with torch.no_grad(): + model_output = self.model(self.model_input)['action'].cpu().detach().numpy() + + actions = [] + for i in range(len(model_output)): + actions.append(self.features.transform_action(model_output[i])) + ret = {} + for player_id, act in zip(range(self.player_num*self.game_team_id, self.player_num*(self.game_team_id+1)), actions): + ret[player_id] = act + for player_id, act in zip(range(self.player_num*self.game_team_id, self.player_num*(self.game_team_id+1)), model_output): + self.last_action_type[player_id] = act.item() # TODO + return ret + #################################################### + +def stack(data): + result = {} + for k1 in data[0].keys(): + result[k1] = {} + if isinstance(data[0][k1], dict): + for k2 in data[0][k1].keys(): + result[k1][k2] = torch.stack([o[k1][k2] for o in data]) + else: 
+ result[k1] = torch.stack([o[k1] for o in data]) + return result \ No newline at end of file diff --git a/practice/solo_agent/default_model_config.yaml b/practice/solo_agent/default_model_config.yaml new file mode 100644 index 0000000..1a532ec --- /dev/null +++ b/practice/solo_agent/default_model_config.yaml @@ -0,0 +1,187 @@ +var1: &VIEW_BINARY_NUM 8 +var2: &ABS_VIEW_BINARY_NUM 7 +agent: + enable_baselines: [ 'score', 'spore','team_spore', 'clone','team_clone','opponent','team_opponent' ] + features: + max_ball_num: 64 + max_food_num: 256 + max_spore_num: 128 + direction_num: 12 + spatial_x: 64 + spatial_y: 64 +model: + scalar_encoder: + modules: + view_x: + arc: sign_binary + num_embeddings: *ABS_VIEW_BINARY_NUM + embedding_dim: 8 + view_y: + arc: sign_binary + num_embeddings: *ABS_VIEW_BINARY_NUM + embedding_dim: 8 + view_width: + arc: binary + num_embeddings: *ABS_VIEW_BINARY_NUM + embedding_dim: 8 +# view_height: +# arc: binary +# num_embeddings: *ABS_VIEW_BINARY_NUM +# embedding_dim: 8 + score: + arc: one_hot + num_embeddings: 10 + embedding_dim: 8 + team_score: + arc: one_hot + num_embeddings: 10 + embedding_dim: 8 + rank: + arc: one_hot + num_embeddings: 4 + embedding_dim: 8 + time: + arc: time + embedding_dim: 8 + last_action_type: + arc: one_hot + num_embeddings: 27 # direction_num * 2 + 3 + input_dim: 80 + hidden_dim: 64 + layer_num: 2 + norm_type: 'none' + output_dim: 32 + activation: 'relu' + team_encoder: + modules: + alliance: + arc: one_hot + num_embeddings: 2 + view_x: + arc: sign_binary + num_embeddings: *ABS_VIEW_BINARY_NUM + view_y: + arc: sign_binary + num_embeddings: *ABS_VIEW_BINARY_NUM +# view_width: +# arc: binary +# num_embeddings: *ABS_VIEW_BINARY_NUM +# embedding_dim: 12 +# view_height: +# arc: binary +# num_embeddings: *ABS_VIEW_BINARY_NUM +# embedding_dim: 12 +# score: +# arc: one_hot +# num_embeddings: 10 +# embedding_dim: 12 +# team_score: +# arc: one_hot +# num_embeddings: 10 +# embedding_dim: 12 +# team_rank: +# arc: one_hot +# num_embeddings: 10 +# embedding_dim: 12 + embedding_dim: 16 + encoder: + input_dim: 16 + hidden_dim: 32 + layer_num: 2 + activation: 'relu' + norm_type: 'LN' + transformer: + head_num: 4 + ffn_size: 32 + layer_num: 2 + activation: 'relu' + variant: 'postnorm' + output: + output_dim: 16 + activation: 'relu' + norm_type: 'LN' + ball_encoder: + modules: + alliance: + arc: one_hot + num_embeddings: 4 + score: + arc: one_hot + num_embeddings: 50 + radius: + arc: unsqueeze +# score_ratio: +# arc: one_hot +# num_embeddings: 50 + rank: + arc: one_hot + num_embeddings: 5 + x: + arc: sign_binary + num_embeddings: *VIEW_BINARY_NUM + embedding_dim: 8 + y: + arc: sign_binary + num_embeddings: *VIEW_BINARY_NUM + embedding_dim: 8 + next_x: + arc: sign_binary + num_embeddings: *VIEW_BINARY_NUM + embedding_dim: 8 + next_y: + arc: sign_binary + num_embeddings: *VIEW_BINARY_NUM + embedding_dim: 8 + embedding_dim: 64 + encoder: + input_dim: 92 + hidden_dim: 128 + layer_num: 2 + activation: 'relu' + norm_type: 'LN' + transformer: + head_num: 4 + ffn_size: 64 + layer_num: 3 + activation: 'relu' + variant: 'postnorm' + output: + output_dim: 64 + activation: 'relu' + norm_type: 'LN' + spatial_encoder: + scatter: + input_dim: 64 + output_dim: 16 + scatter_type: add + activation: 'relu' + norm_type: 'LN' + resnet: + project_dim: 12 + down_channels: [32, 32, 16 ] + activation: 'relu' + norm_type: 'LN' + output: + output_dim: 64 + activation: 'relu' + norm_type: 'LN' + policy: + embedding_dim: 64 + project: + input_dim: 176 # scalar + team + ball + spatial + 
activation: 'relu' + norm_type: 'LN' + resnet: + activation: 'relu' + norm_type: 'LN' + res_num: 3 + value: + embedding_dim: 64 + project: + input_dim: 352 # scalar + team + ball + spatial + activation: 'relu' + norm_type: 'LN' + resnet: + activation: 'relu' + norm_type: 'LN' + res_num: 3 \ No newline at end of file diff --git a/practice/solo_agent/model.py b/practice/solo_agent/model.py new file mode 100644 index 0000000..8d3be88 --- /dev/null +++ b/practice/solo_agent/model.py @@ -0,0 +1,120 @@ +import os +from typing import Dict, Any +import numpy as np +import torch +import torch.nn as nn +from ..tools.util import read_config, deep_merge_dicts +from ..tools.encoder import Encoder +from ..tools.head import PolicyHead, ValueHead + +default_config = read_config(os.path.join(os.path.dirname(__file__), 'default_model_config.yaml')) + +class Model(nn.Module): + def __init__(self, cfg={}, **kwargs): + super(Model, self).__init__() + self.whole_cfg = deep_merge_dicts(default_config, cfg) + self.encoder = Encoder(self.whole_cfg) + self.policy_head = PolicyHead(self.whole_cfg) + self.value_head = ValueHead(self.whole_cfg) + self.only_update_value = False + self.ortho_init = self.whole_cfg.model.get('ortho_init', True) + self.player_num = self.whole_cfg.env.player_num_per_team + self.team_num = self.whole_cfg.env.team_num + + def forward(self, obs, temperature=0): + obs = flatten_data(obs,start_dim=0,end_dim=1) # [env_num*team_num, 2] + embedding = self.encoder(obs) + logit = self.policy_head(embedding) + if temperature == 0: + action = logit.argmax(dim=-1) + else: + logit = logit.div(temperature) + dist = torch.distributions.Categorical(logits=logit) + action = dist.sample() + return {'action': action, 'logit': logit} + + def compute_value(self, obs, ): + obs = flatten_data(obs,start_dim=0,end_dim=1) + embedding = self.encoder(obs) + batch_size = embedding.shape[0] // self.team_num // self.player_num + team_embedding = embedding.reshape(batch_size*self.team_num, self.player_num, -1) + team_embedding = self.transform_ctde(team_embedding,device=team_embedding.device) + value = self.value_head(team_embedding) # [bs, player_num, 1] + return {'value': value.reshape(-1)} + + def compute_logp_action(self, obs, **kwargs, ): + obs = flatten_data(obs,start_dim=0,end_dim=1) + embedding = self.encoder(obs) + batch_size = embedding.shape[0] // self.team_num // self.player_num + logit = self.policy_head(embedding) + dist = torch.distributions.Categorical(logits=logit) + action = dist.sample() + action_log_probs = dist.log_prob(action) + log_action_probs = action_log_probs + team_embedding = embedding.reshape(batch_size*self.team_num, self.player_num, -1) + team_embedding = self.transform_ctde(team_embedding,device=team_embedding.device) + value = self.value_head(team_embedding) + return {'action': action, + 'action_logp': log_action_probs, + 'logit': logit, + 'value': value.reshape(-1), + } + + def rl_train(self, inputs: dict, **kwargs) -> Dict[str, Any]: + r""" + Overview: + Forward and backward function of learn mode. 
+        Arguments:
+            - inputs (:obj:`dict`): Dict type data
+            ArgumentsKeys:
+                - obs: shape :math:`(T+1, B)`, where T is timestep, B is batch size
+                - action_logp: behaviour action log-probabilities, :math:`(T, B)`
+                - action: behaviour actions, :math:`(T, B)`
+                - reward: shape :math:`(T, B)`
+                - done: shape :math:`(T, B)`
+        Returns:
+            - outputs (:obj:`Dict[str, Any]`):
+                Flattened value/logit/action tensors (with old_value, advantage and return
+                passed through) for the downstream loss computation
+        """
+
+        obs = inputs['obs']
+        # flatten the (T, B) leading dims before encoding
+        obs = flatten_data(obs, start_dim=0, end_dim=1)
+        embedding = self.encoder(obs)
+        batch_size = embedding.shape[0] // self.player_num
+        logits = self.policy_head(embedding)
+        critic_input = embedding.reshape(batch_size, self.player_num, -1)
+        critic_input = self.transform_ctde(critic_input, device=critic_input.device)
+        if self.only_update_value:
+            # stop gradients from the value loss flowing back into the encoder
+            critic_input = critic_input.detach()
+        values = self.value_head(critic_input)
+        outputs = {
+            'value': values.squeeze(-1).reshape(-1),
+            'logit': logits,
+            'action': inputs['action'].reshape(-1),
+            'action_logp': inputs['action_logp'].reshape(-1),
+            # 'reward': inputs['reward'],
+            # 'done': inputs['done'],
+            'old_value': inputs['old_value'].reshape(-1),
+            'advantage': inputs['advantage'].reshape(-1),
+            'return': inputs['return'].reshape(-1),
+        }
+        return outputs
+
+    def transform_ctde(self, array, device):
+        # For each player i, reorder the team axis to (self, teammates) and flatten,
+        # e.g. for players A, B the centralized critic sees [A, B] for A and [B, A] for B.
+        ret = []
+        for i in range(self.player_num):
+            index = [j for j in range(self.player_num)]
+            index.pop(i)
+            other_array = torch.index_select(array, dim=1, index=torch.LongTensor(index).to(device))
+            self_array = array[:, i, :].unsqueeze(dim=1)
+            ret.append(torch.cat((self_array, other_array), dim=1).flatten(start_dim=1, end_dim=2).unsqueeze(1))
+        ret = torch.cat(ret, dim=1)
+        return ret
+
+def flatten_data(data, start_dim=0, end_dim=1):
+    if isinstance(data, dict):
+        return {k: flatten_data(v, start_dim=start_dim, end_dim=end_dim) for k, v in data.items()}
+    elif isinstance(data, torch.Tensor):
+        return torch.flatten(data, start_dim=start_dim, end_dim=end_dim)
\ No newline at end of file
diff --git a/practice/solo_agent/util.py b/practice/solo_agent/util.py
new file mode 100644
index 0000000..a512dd7
--- /dev/null
+++ b/practice/solo_agent/util.py
@@ -0,0 +1,184 @@
+import copy
+import json
+import os
+from typing import NoReturn, Optional, List
+
+import yaml
+from easydict import EasyDict
+import re
+from collections.abc import Sequence, Mapping
+from typing import List, Dict, Union, Any
+
+import torch
+import collections.abc as container_abcs
+# torch._six was removed in recent PyTorch releases; fall back to the builtins
+#from torch._six import string_classes
+string_classes = (str, bytes)
+#from torch._six import int_classes as _int_classes
+int_classes = int
+
+np_str_obj_array_pattern = re.compile(r'[SaUO]')
+
+default_collate_err_msg_format = (
+    "default_collate: batch must contain tensors, numpy arrays, numbers, "
+    "dicts or lists; found {}"
+)
+
+
+def read_config(path: str) -> EasyDict:
+    """
+    Overview:
+        read configuration from path
+    Arguments:
+        - path (:obj:`str`): Path of source yaml
+    Returns:
+        - (:obj:`EasyDict`): Config data from this file with dict type
+    """
+    if path:
+        assert os.path.exists(path), path
+        with open(path, "r") as f:
+            config = yaml.safe_load(f)
+    else:
+        config = {}
+    return EasyDict(config)
+
+def deep_merge_dicts(original: dict, new_dict: dict) -> dict:
+    """
+    Overview:
+        Merge two dicts using deep_update.
+    Arguments:
+        - original (:obj:`dict`): Dict 1.
+        - new_dict (:obj:`dict`): Dict 2.
+    Returns:
+        - (:obj:`dict`): A new dict that is original and new_dict deeply merged.
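+    Example:
+        >>> deep_merge_dicts({'a': {'b': 1}}, {'a': {'c': 2}})
+        {'a': {'b': 1, 'c': 2}}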
+ """ + original = original or {} + new_dict = new_dict or {} + merged = copy.deepcopy(original) + if new_dict: # if new_dict is neither empty dict nor None + deep_update(merged, new_dict, True, []) + + return merged + +def deep_update( + original: dict, + new_dict: dict, + new_keys_allowed: bool = False, + whitelist: Optional[List[str]] = None, + override_all_if_type_changes: Optional[List[str]] = None +): + """ + Overview: + Updates original dict with values from new_dict recursively. + + .. note:: + + If new key is introduced in new_dict, then if new_keys_allowed is not + True, an error will be thrown. Further, for sub-dicts, if the key is + in the whitelist, then new subkeys can be introduced. + + Arguments: + - original (:obj:`dict`): Dictionary with default values. + - new_dict (:obj:`dict`): Dictionary with values to be updated + - new_keys_allowed (:obj:`bool`): Whether new keys are allowed. + - whitelist (Optional[List[str]]): List of keys that correspond to dict + values where new subkeys can be introduced. This is only at the top + level. + - override_all_if_type_changes(Optional[List[str]]): List of top level + keys with value=dict, for which we always simply override the + entire value (:obj:`dict`), if the "type" key in that value dict changes. + """ + whitelist = whitelist or [] + override_all_if_type_changes = override_all_if_type_changes or [] + + for k, value in new_dict.items(): + if k not in original and not new_keys_allowed: + raise RuntimeError("Unknown config parameter `{}`. Base config have: {}.".format(k, original.keys())) + + # Both original value and new one are dicts. + if isinstance(original.get(k), dict) and isinstance(value, dict): + # Check old type vs old one. If different, override entire value. + if k in override_all_if_type_changes and \ + "type" in value and "type" in original[k] and \ + value["type"] != original[k]["type"]: + original[k] = value + # Whitelisted key -> ok to add new subkeys. + elif k in whitelist: + deep_update(original[k], value, True) + # Non-whitelisted key. + else: + deep_update(original[k], value, new_keys_allowed) + # Original value not a dict OR new value not a dict: + # Override entire value. 
+ else: + original[k] = value + return original + + +def default_collate_with_dim(batch,device='cpu',dim=0, k=None,cat=False): + r"""Puts each data field into a tensor with outer dimension batch size""" + elem = batch[0] + elem_type = type(elem) + #if k is not None: + # print(k) + + if isinstance(elem, torch.Tensor): + out = None + if torch.utils.data.get_worker_info() is not None: + # If we're in a background process, concatenate directly into a + # shared memory tensor to avoid an extra copy + numel = sum([x.numel() for x in batch]) + storage = elem.storage()._new_shared(numel) + out = elem.new(storage) + try: + if cat == True: + return torch.cat(batch, dim=dim, out=out).to(device=device) + else: + return torch.stack(batch, dim=dim, out=out).to(device=device) + except: + print(batch) + if k is not None: + print(k) + elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ + and elem_type.__name__ != 'string_': + if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': + # array of string classes and object + if np_str_obj_array_pattern.search(elem.dtype.str) is not None: + raise TypeError(default_collate_err_msg_format.format(elem.dtype)) + + return default_collate_with_dim([torch.as_tensor(b,device=device) for b in batch],device=device,dim=dim,cat=cat) + elif elem.shape == (): # scalars + try: + return torch.as_tensor(batch,device=device) + except: + print(batch) + if k is not None: + print(k) + elif isinstance(elem, float): + try: + return torch.tensor(batch,device=device) + except: + print(batch) + if k is not None: + print(k) + elif isinstance(elem, int_classes): + try: + return torch.tensor(batch,device=device) + except: + print(batch) + if k is not None: + print(k) + elif isinstance(elem, string_classes): + return batch + elif isinstance(elem, container_abcs.Mapping): + return {key: default_collate_with_dim([d[key] for d in batch if key in d.keys()],device=device,dim=dim, k=key, cat=cat) for key in elem} + elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple + return elem_type(*(default_collate_with_dim(samples,device=device,dim=dim,cat=cat) for samples in zip(*batch))) + elif isinstance(elem, container_abcs.Sequence): + # check to make sure that the elements in batch have consistent size + it = iter(batch) + elem_size = len(next(it)) + if not all(len(elem) == elem_size for elem in it): + raise RuntimeError('each element in list of batch should be of equal size') + transposed = zip(*batch) + return [default_collate_with_dim(samples,device=device,dim=dim,cat=cat) for samples in transposed] + + raise TypeError(default_collate_err_msg_format.format(elem_type)) \ No newline at end of file diff --git a/practice/test_ai.py b/practice/test_ai.py new file mode 100644 index 0000000..f16f39c --- /dev/null +++ b/practice/test_ai.py @@ -0,0 +1,19 @@ +from gobigger.envs import create_env +from cooperative_agent.agent import AIAgent as AI + +env = create_env('st_t2p2') +obs = env.reset() + +agent1 = AI(team_name=0, player_names=[0,1]) +agent2 = AI(team_name=1, player_names=[2,3]) + +for i in range(1000): + actions1 = agent1.get_actions(obs) + actions2 = agent2.get_actions(obs) + actions1.update(actions2) + obs, rew, done, info = env.step(actions1) + print('[{}] leaderboard={}'.format(i, obs[0]['leaderboard'])) + if done: + print('finish game!') + break +env.close() \ No newline at end of file diff --git a/practice/tools/encoder.py b/practice/tools/encoder.py new file mode 100644 index 0000000..7094b78 --- /dev/null +++ b/practice/tools/encoder.py @@ 
-0,0 +1,294 @@ +from typing import Dict + +import torch +import torch.nn as nn +from torch import Tensor + +from .network import sequence_mask, ScatterConnection +from .network.encoder import SignBinaryEncoder, BinaryEncoder, OnehotEncoder, TimeEncoder, UnsqueezeEncoder +from .network.nn_module import fc_block, conv2d_block, MLP +from .network.res_block import ResBlock +from .network.transformer import Transformer + + +class ScalarEncoder(nn.Module): + def __init__(self, cfg): + super(ScalarEncoder, self).__init__() + self.whole_cfg = cfg + self.cfg = self.whole_cfg.model.scalar_encoder + self.encode_modules = nn.ModuleDict() + for k, item in self.cfg.modules.items(): + if item['arc'] == 'time': + self.encode_modules[k] = TimeEncoder(embedding_dim=item['embedding_dim']) + elif item['arc'] == 'one_hot': + self.encode_modules[k] = OnehotEncoder(num_embeddings=item['num_embeddings'], ) + elif item['arc'] == 'binary': + self.encode_modules[k] = BinaryEncoder(num_embeddings=item['num_embeddings'], ) + elif item['arc'] == 'sign_binary': + self.encode_modules[k] = SignBinaryEncoder(num_embeddings=item['num_embeddings'], ) + else: + print(f'cant implement {k} for arc {item["arc"]}') + raise NotImplementedError + + self.layers = MLP(in_channels=self.cfg.input_dim, hidden_channels=self.cfg.hidden_dim, + out_channels=self.cfg.output_dim, + layer_num=self.cfg.layer_num, + layer_fn=fc_block, + activation=self.cfg.activation, + norm_type=self.cfg.norm_type, + use_dropout=False + ) + + def forward(self, x: Dict[str, Tensor]): + embeddings = [] + for key, item in self.cfg.modules.items(): + assert key in x, key + embeddings.append(self.encode_modules[key](x[key])) + + out = torch.cat(embeddings, dim=-1) + out = self.layers(out) + return out + + +class TeamEncoder(nn.Module): + def __init__(self, cfg): + super(TeamEncoder, self).__init__() + self.whole_cfg = cfg + self.cfg = self.whole_cfg.model.team_encoder + self.encode_modules = nn.ModuleDict() + + for k, item in self.cfg.modules.items(): + if item['arc'] == 'one_hot': + self.encode_modules[k] = OnehotEncoder(num_embeddings=item['num_embeddings'], ) + elif item['arc'] == 'binary': + self.encode_modules[k] = BinaryEncoder(num_embeddings=item['num_embeddings'], ) + elif item['arc'] == 'sign_binary': + self.encode_modules[k] = SignBinaryEncoder(num_embeddings=item['num_embeddings'], ) + else: + print(f'cant implement {k} for arc {item["arc"]}') + raise NotImplementedError + + self.embedding_dim = self.cfg.embedding_dim + self.encoder_cfg = self.cfg.encoder + self.encode_layers = MLP(in_channels=self.encoder_cfg.input_dim, + hidden_channels=self.encoder_cfg.hidden_dim, + out_channels=self.embedding_dim, + layer_num=self.encoder_cfg.layer_num, + layer_fn=fc_block, + activation=self.encoder_cfg.activation, + norm_type=self.encoder_cfg.norm_type, + use_dropout=False) + # self.activation_type = self.cfg.activation + + self.transformer_cfg = self.cfg.transformer + self.transformer = Transformer( + n_heads=self.transformer_cfg.head_num, + embedding_size=self.embedding_dim, + ffn_size=self.transformer_cfg.ffn_size, + n_layers=self.transformer_cfg.layer_num, + attention_dropout=0.0, + relu_dropout=0.0, + dropout=0.0, + activation=self.transformer_cfg.activation, + variant=self.transformer_cfg.variant, + ) + self.output_cfg = self.cfg.output + self.output_fc = fc_block(self.embedding_dim, + self.output_cfg.output_dim, + norm_type=self.output_cfg.norm_type, + activation=self.output_cfg.activation) + + def forward(self, x): + embeddings = [] + player_num = 
x['player_num'] + mask = sequence_mask(player_num, max_len=x['view_x'].shape[1]) + for key, item in self.cfg.modules.items(): + assert key in x, f"{key} not implemented" + x_input = x[key] + embeddings.append(self.encode_modules[key](x_input)) + + x = torch.cat(embeddings, dim=-1) + x = self.encode_layers(x) + x = self.transformer(x, mask=mask) + team_info = self.output_fc(x.sum(dim=1) / player_num.unsqueeze(dim=-1)) + return team_info + + +class BallEncoder(nn.Module): + def __init__(self, cfg): + super(BallEncoder, self).__init__() + self.whole_cfg = cfg + self.cfg = self.whole_cfg.model.ball_encoder + self.encode_modules = nn.ModuleDict() + for k, item in self.cfg.modules.items(): + if item['arc'] == 'one_hot': + self.encode_modules[k] = OnehotEncoder(num_embeddings=item['num_embeddings'], ) + elif item['arc'] == 'binary': + self.encode_modules[k] = BinaryEncoder(num_embeddings=item['num_embeddings'], ) + elif item['arc'] == 'sign_binary': + self.encode_modules[k] = SignBinaryEncoder(num_embeddings=item['num_embeddings'], ) + elif item['arc'] == 'unsqueeze': + self.encode_modules[k] = UnsqueezeEncoder() + else: + print(f'cant implement {k} for arc {item["arc"]}') + raise NotImplementedError + self.embedding_dim = self.cfg.embedding_dim + self.encoder_cfg = self.cfg.encoder + self.encode_layers = MLP(in_channels=self.encoder_cfg.input_dim, + hidden_channels=self.encoder_cfg.hidden_dim, + out_channels=self.embedding_dim, + layer_num=self.encoder_cfg.layer_num, + layer_fn=fc_block, + activation=self.encoder_cfg.activation, + norm_type=self.encoder_cfg.norm_type, + use_dropout=False) + + self.transformer_cfg = self.cfg.transformer + self.transformer = Transformer( + n_heads=self.transformer_cfg.head_num, + embedding_size=self.embedding_dim, + ffn_size=self.transformer_cfg.ffn_size, + n_layers=self.transformer_cfg.layer_num, + attention_dropout=0.0, + relu_dropout=0.0, + dropout=0.0, + activation=self.transformer_cfg.activation, + variant=self.transformer_cfg.variant, + ) + self.output_cfg = self.cfg.output + self.output_fc = fc_block(self.embedding_dim, + self.output_cfg.output_dim, + norm_type=self.output_cfg.norm_type, + activation=self.output_cfg.activation) + + def forward(self, x): + ball_num = x['ball_num'] + embeddings = [] + mask = sequence_mask(ball_num, max_len=x['x'].shape[1]) + for key, item in self.cfg.modules.items(): + assert key in x, key + x_input = x[key] + embeddings.append(self.encode_modules[key](x_input)) + x = torch.cat(embeddings, dim=-1) + x = self.encode_layers(x) + x = self.transformer(x, mask=mask) + + ball_info = x.sum(dim=1) / ball_num.unsqueeze(dim=-1) + ball_info = self.output_fc(ball_info) + return x, ball_info + + +class SpatialEncoder(nn.Module): + def __init__(self, cfg): + super(SpatialEncoder, self).__init__() + self.whole_cfg = cfg + self.cfg = self.whole_cfg.model.spatial_encoder + + # scatter related + self.spatial_x = 64 + self.spatial_y = 64 + self.scatter_cfg = self.cfg.scatter + self.scatter_fc = fc_block(in_channels=self.scatter_cfg.input_dim, out_channels=self.scatter_cfg.output_dim, + activation=self.scatter_cfg.activation, norm_type=self.scatter_cfg.norm_type) + self.scatter_connection = ScatterConnection(self.scatter_cfg.scatter_type) + + # resnet related + self.resnet_cfg = self.cfg.resnet + self.get_resnet_blocks() + + self.output_cfg = self.cfg.output + self.output_fc = fc_block( + in_channels=self.spatial_x // 8 * self.spatial_y // 8 * self.resnet_cfg.down_channels[-1], + out_channels=self.output_cfg.output_dim, + 
norm_type=self.output_cfg.norm_type, + activation=self.output_cfg.activation) + + def get_resnet_blocks(self): + # 2 means food/spore embedding + project = conv2d_block(in_channels=self.scatter_cfg.output_dim + 2, + out_channels=self.resnet_cfg.project_dim, + kernel_size=1, + stride=1, + padding=0, + activation=self.resnet_cfg.activation, + norm_type=self.resnet_cfg.norm_type, + bias=False, + ) + + layers = [project] + dims = [self.resnet_cfg.project_dim] + self.resnet_cfg.down_channels + for i in range(len(dims) - 1): + layer = conv2d_block(in_channels=dims[i], + out_channels=dims[i + 1], + kernel_size=4, + stride=2, + padding=1, + activation=self.resnet_cfg.activation, + norm_type=self.resnet_cfg.norm_type, + bias=False, + ) + layers.append(layer) + layers.append(ResBlock(in_channels=dims[i + 1], + activation=self.resnet_cfg.activation, + norm_type=self.resnet_cfg.norm_type)) + self.resnet = torch.nn.Sequential(*layers) + + + def get_background_embedding(self, coord_x, coord_y, num, ): + + background_ones = torch.ones(size=(coord_x.shape[0], coord_x.shape[1]), device=coord_x.device) + background_mask = sequence_mask(num, max_len=coord_x.shape[1]) + background_ones = (background_ones * background_mask).unsqueeze(-1) + background_embedding = self.scatter_connection.xy_forward(background_ones, + spatial_size=[self.spatial_x, self.spatial_y], + coord_x=coord_x, + coord_y=coord_y) + + return background_embedding + + def forward(self, inputs, ball_embeddings, ): + spatial_info = inputs['spatial_info'] + # food and spore + food_embedding = self.get_background_embedding(coord_x=spatial_info['food_x'], + coord_y=spatial_info['food_y'], + num=spatial_info['food_num'], ) + + spore_embedding = self.get_background_embedding(coord_x=spatial_info['spore_x'], + coord_y=spatial_info['spore_y'], + num=spatial_info['spore_num'], ) + # scatter ball embeddings + ball_info = inputs['ball_info'] + ball_num = ball_info['ball_num'] + ball_mask = sequence_mask(ball_num, max_len=ball_embeddings.shape[1]) + ball_embedding = self.scatter_fc(ball_embeddings) * ball_mask.unsqueeze(dim=2) + + ball_embedding = self.scatter_connection.xy_forward(ball_embedding, + spatial_size=[self.spatial_x, self.spatial_y], + coord_x=spatial_info['ball_x'], + coord_y=spatial_info['ball_y']) + + x = torch.cat([food_embedding, spore_embedding, ball_embedding], dim=1) + + x = self.resnet(x) + + x = torch.flatten(x, start_dim=1, end_dim=-1) + x = self.output_fc(x) + return x + + +class Encoder(nn.Module): + def __init__(self, cfg): + super(Encoder, self).__init__() + self.whole_cfg = cfg + self.scalar_encoder = ScalarEncoder(cfg) + self.team_encoder = TeamEncoder(cfg) + self.ball_encoder = BallEncoder(cfg) + self.spatial_encoder = SpatialEncoder(cfg) + + def forward(self, x): + scalar_info = self.scalar_encoder(x['scalar_info']) + team_info = self.team_encoder(x['team_info']) + ball_embeddings, ball_info = self.ball_encoder(x['ball_info']) + spatial_info = self.spatial_encoder(x, ball_embeddings) + x = torch.cat([scalar_info, team_info, ball_info, spatial_info], dim=1) + return x diff --git a/practice/tools/features.py b/practice/tools/features.py new file mode 100644 index 0000000..6f7d303 --- /dev/null +++ b/practice/tools/features.py @@ -0,0 +1,390 @@ +import math +import torch + + +class Features: + def __init__(self, cfg): + self.cfg = cfg + self.player_num_per_team = self.cfg.env.player_num_per_team + self.team_num = self.cfg.env.team_num + self.max_player_num = self.player_num_per_team + self.max_team_num = self.team_num + 
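+        # fixed upper bounds for entity features; used to size the padded tensors in _init_fake_data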
self.max_ball_num = self.cfg.agent.features.get('max_ball_num', 80) + self.max_food_num = self.cfg.agent.features.get('max_food_num', 256) + self.max_spore_num = self.cfg.agent.features.get('max_spore_num', 64) + self.direction_num = self.cfg.agent.features.get('direction_num', 12) + self.spatial_x = 64 + self.spatial_y = 64 + self.step_mul = self.cfg.env.get('step_mul', 5) + self.second_per_frame = self.cfg.agent.features.get('second_per_frame', 0.05) + self.action_num = self.direction_num * 2 + 3 + self.setup_action() + self._init_fake_data() + + def get_augmentation_map(self): + augmentation_mapping = {} + for aug_type in ['ud', 'lr', 'lrud']: + augmentation_mapping[aug_type] = {action: self.augmentation_action(action, aug_type=aug_type) for action in + range(self.action_num)} + return augmentation_mapping + + def setup_action(self): + theta = math.pi * 2 / self.direction_num + self.x_y_action_List = [[0.3 * math.cos(theta * i), 0.3 * math.sin(theta * i), 0] for i in + range(self.direction_num)] + \ + [[math.cos(theta * i), math.sin(theta * i), 0] for i in + range(self.direction_num)] + \ + [[0, 0, 0], [0, 0, 1], [0, 0, 2]] + + def _init_fake_data(self): + self.SCALAR_INFO = { + 'view_x': (torch.long, ()), + 'view_y': (torch.long, ()), # view center (x, y) from left bottom + 'view_width': (torch.long, ()), + # 'view_height': (torch.long, ()), + 'score': (torch.long, ()), # log + 'team_score': (torch.long, ()), + 'rank': (torch.long, ()), + 'time': (torch.long, ()), + 'last_action_type': (torch.long, ()) + } + self.TEAM_INFO = { + 'alliance': (torch.long, (self.max_player_num,)), + 'view_x': (torch.long, (self.max_player_num,)), + 'view_y': (torch.long, (self.max_player_num,)), # view center (x, y) from left bottom + # 'view_width': (torch.long, (self.max_player_num * self.max_team_num,)), + # 'view_height': (torch.long, (self.max_player_num * self.max_team_num,)), + # 'score': (torch.long, (self.max_player_num * self.max_team_num,)), # log + # 'team_score': (torch.long, (self.max_player_num * self.max_team_num,)), + # 'team_rank': (torch.long, (self.max_player_num * self.max_team_num,)), + 'player_num': (torch.long, ()), + + } + self.BALL_INFO = { + 'alliance': (torch.long, (self.max_ball_num,)), # 0 neutral, 1 self, 2 teammate, 3 enemy, + 'score': (torch.long, (self.max_ball_num,)), # log own score + 1 + 'radius': (torch.float, (self.max_ball_num,)), # log (ratio to own score + 1) + 'rank': (torch.long, (self.max_ball_num,)), # onehot 0 neural, 1-10 for team rank + 'x': (torch.long, (self.max_ball_num,)), # binary relative coordinate to view center + 'y': (torch.long, (self.max_ball_num,)), # binary relative coordinate to view center + 'next_x': (torch.long, (self.max_ball_num,)), # binary relative coordinate to view center + 'next_y': (torch.long, (self.max_ball_num,)), # binary relative coordinate to view center + 'ball_num': (torch.long, ()) + } + + self.SPATIAL_INFO = { + 'food_x': (torch.long, (self.max_food_num,)), # relative coordinate to view center + 'food_y': (torch.long, (self.max_food_num,)), # relative coordinate to view center + 'spore_x': (torch.long, (self.max_spore_num,)), # relative coordinate to view center + 'spore_y': (torch.long, (self.max_spore_num,)), # relative coordinate to view center + 'ball_x': (torch.long, (self.max_ball_num,)), # relative coordinate to view center + 'ball_y': (torch.long, (self.max_ball_num,)), # relative coordinate to view center + 'food_num': (torch.long, ()), + 'spore_num': (torch.long, ()) + } + self.REWARD_INFO = { + 'score': 
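`setup_action` builds `direction_num` slow moves, `direction_num` full-speed moves, then stop/eject/split, so `action_num = direction_num * 2 + 3`. Enumerating the layout with a reduced, hypothetical `direction_num=4` makes the indexing visible:

```python
# Discrete action layout from setup_action (direction_num=4 is illustrative).
import math

direction_num = 4
theta = math.pi * 2 / direction_num
actions = [[0.3 * math.cos(theta * i), 0.3 * math.sin(theta * i), 0] for i in range(direction_num)] \
        + [[math.cos(theta * i), math.sin(theta * i), 0] for i in range(direction_num)] \
        + [[0, 0, 0], [0, 0, 1], [0, 0, 2]]

assert len(actions) == direction_num * 2 + 3   # matches self.action_num
# indices 0..3: slow moves, 4..7: full-speed moves,
# 8: stop, 9: eject (spore), 10: split
```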
(torch.float, ()), + 'spore': (torch.float, ()), + 'mate_spore': (torch.float, ()), + 'team_spore': (torch.float, ()), + 'clone': (torch.float, ()), + 'team_clone': (torch.float, ()), + 'opponent': (torch.float, ()), + 'team_opponent': (torch.float, ()), + 'max_dist': (torch.float, ()), + 'min_dist': (torch.float, ()), + } + + self.ACTION_INFO = { + 'action': (torch.long, ()), + 'logit': (torch.float, (self.action_num,)), # self.direction_num * 2, feed -3, split -2, stop -1, + 'action_logp': (torch.long, ()), + } + + def get_rl_step_data(self, last=False): + data = {} + scalar_info = {k: torch.ones(size=v[1], dtype=v[0]) for k, v in self.SCALAR_INFO.items()} + team_info = {k: torch.ones(size=v[1], dtype=v[0]) for k, v in self.TEAM_INFO.items()} + ball_info = {k: torch.ones(size=v[1], dtype=v[0]) for k, v in self.BALL_INFO.items()} + spatial_info = {k: torch.ones(size=v[1], dtype=v[0]) for k, v in self.SPATIAL_INFO.items()} + action_mask = torch.zeros(size=(self.action_num,), dtype=torch.bool) + + data['obs'] = {'scalar_info': scalar_info, 'team_info': team_info, 'ball_info': ball_info, + 'spatial_info': spatial_info, 'action_mask': action_mask} + if not last: + data['action'] = torch.zeros(size=(), dtype=torch.long) + data['action_logp'] = torch.zeros(size=(), dtype=torch.float) + data['reward'] = {k: torch.zeros(size=v[1], dtype=v[0]) for k, v in self.REWARD_INFO.items()} + data['done'] = torch.zeros(size=(), dtype=torch.bool) + data['model_last_iter'] = torch.zeros(size=(), dtype=torch.float) + return data + + def get_player2team(self, ): + player2team = {} + for player_id in range(self.player_num_per_team * self.team_num): + player2team[player_id] = player_id // self.player_num_per_team + return player2team + + def transform_obs(self, obs, game_player_id=1, padding=True, last_action_type=None, ): + global_state, player_observations = obs + player2team = self.get_player2team() + own_player_id = game_player_id + leaderboard = global_state['leaderboard'] + team2rank = {key: rank for rank, key in enumerate(sorted(leaderboard, key=leaderboard.get, reverse=True), )} + + own_player_obs = player_observations[own_player_id] + own_team_id = player2team[own_player_id] + + # =========== + # scalar info + # =========== + scene_size = global_state['border'][0] + own_left_top_x, own_left_top_y, own_right_bottom_x, own_right_bottom_y = own_player_obs['rectangle'] + own_view_center = [(own_left_top_x + own_right_bottom_x - scene_size) / 2, + (own_left_top_y + own_right_bottom_y - scene_size) / 2] + own_view_width = float(own_right_bottom_x - own_left_top_x) + # own_view_height = float(own_right_bottom_y - own_left_top_y) + + own_score = own_player_obs['score'] / 100 + own_team_score = global_state['leaderboard'][own_team_id] / 100 + own_rank = team2rank[own_team_id] + + scalar_info = { + 'view_x': torch.tensor(own_view_center[0]).round().long(), + 'view_y': torch.tensor(own_view_center[1]).round().long(), + 'view_width': torch.tensor(own_view_width).round().long(), + 'score': torch.log(torch.tensor(own_score) / 10).round().long().clamp_(max=9), + 'team_score': torch.log(torch.tensor(own_team_score / 10)).round().long().clamp_(max=9), + 'time': torch.tensor(global_state['last_time']//20, dtype=torch.long), + 'rank': torch.tensor(own_rank, dtype=torch.long), + 'last_action_type': torch.tensor(last_action_type, dtype=torch.long) + } + # =========== + # team_info + # =========== + + all_players = [] + scene_size = global_state['border'][0] + + for game_player_id in player_observations.keys(): + game_team_id 
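`get_player2team` is plain integer division, and `transform_obs` turns the leaderboard into a rank lookup by sorting team sizes in descending order. With toy numbers:

```python
# How player ids map to teams and how the leaderboard becomes team2rank.
player_num_per_team, team_num = 2, 3
player2team = {pid: pid // player_num_per_team
               for pid in range(player_num_per_team * team_num)}
print(player2team)  # {0: 0, 1: 0, 2: 1, 3: 1, 4: 2, 5: 2}

leaderboard = {0: 120.0, 1: 340.0, 2: 80.0}   # team_id -> team size
team2rank = {k: r for r, k in enumerate(sorted(leaderboard, key=leaderboard.get, reverse=True))}
print(team2rank)    # {1: 0, 0: 1, 2: 2}
```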
= player2team[game_player_id] + game_player_left_top_x, game_player_left_top_y, game_player_right_bottom_x, game_player_right_bottom_y = \ + player_observations[game_player_id]['rectangle'] + if game_player_id == own_player_id: + alliance = 0 + elif game_team_id == own_team_id: + alliance = 1 + else: + alliance = 2 + if alliance != 2: + game_player_view_x = (game_player_right_bottom_x + game_player_left_top_x - scene_size) / 2 + game_player_view_y = (game_player_right_bottom_y + game_player_left_top_y - scene_size) / 2 + + # game_player_view_width = game_player_right_bottom_x - game_player_left_top_x + # game_player_view_height = game_player_right_bottom_y - game_player_left_top_y + # + # game_player_score = math.log((player_observations[game_player_id]['score'] + 1) / 1000) + # game_player_team_score = math.log((global_state['leaderboard'][game_team_id] + 1) / 1000) + # game_player_rank = team2rank[game_team_id] + + all_players.append([alliance, + game_player_view_x, + game_player_view_y, + # game_player_view_width, + # game_player_view_height, + # game_player_score, + # game_player_team_score, + # game_player_rank, + ]) + all_players = torch.as_tensor(all_players) + player_padding_num = self.max_player_num - len(all_players) + player_num = len(all_players) + all_players = torch.nn.functional.pad(all_players, (0, 0, 0, player_padding_num), 'constant', 0) + team_info = { + 'alliance': all_players[:, 0].long(), + 'view_x': all_players[:, 1].round().long(), + 'view_y': all_players[:, 2].round().long(), + # 'view_width': all_players[:,3].round().long(), + # 'view_height': all_players[:,4].round().long(), + # 'score': all_players[:,5].round().long().clamp_(max=10,min=0), + # 'team_score': all_players[:,6].round().long().clamp_(max=10,min=0), + # 'team_rank': all_players[:,7].long(), + 'player_num': torch.tensor(player_num, dtype=torch.long), + } + + # =========== + # ball info + # =========== + ball_type_map = {'clone': 1, 'food': 2, 'thorns': 3, 'spore': 4} + clone = own_player_obs['overlap']['clone'] + thorns = own_player_obs['overlap']['thorns'] + food = own_player_obs['overlap']['food'] + spore = own_player_obs['overlap']['spore'] + + neutral_team_id = self.team_num + neutral_player_id = self.team_num * self.player_num_per_team + neutral_team_rank = self.team_num + + clone = [[ball_type_map['clone'], bl[3], bl[-2], bl[-1], team2rank[bl[-1]], bl[0], bl[1], + *self.next_position(bl[0], bl[1], bl[4], bl[5])] for bl in clone] + thorns = [[ball_type_map['thorns'], bl[3], neutral_player_id, neutral_team_id, neutral_team_rank, bl[0], bl[1], + *self.next_position(bl[0], bl[1], bl[4], bl[5])] for bl in thorns] + food = [ + [ball_type_map['food'], bl[3], neutral_player_id, neutral_team_id, neutral_team_rank, bl[0], bl[1], bl[0], + bl[1]] for bl in food] + + spore = [ + [ball_type_map['spore'], bl[3], bl[-1], player2team[bl[-1]], team2rank[player2team[bl[-1]]], bl[0], + bl[1], + *self.next_position(bl[0], bl[1], bl[4], bl[5])] for bl in spore] + + all_balls = clone + thorns + food + spore + + for b in all_balls: + if b[2] == own_player_id and b[0] == 1: + if b[5] < own_left_top_x or b[5] > own_right_bottom_x or \ + b[6] < own_left_top_y or b[6] > own_right_bottom_y: + b[5] = int((own_left_top_x + own_right_bottom_x) / 2) + b[6] = int((own_left_top_y + own_right_bottom_y) / 2) + b[7], b[8] = b[5], b[6] + all_balls = torch.as_tensor(all_balls, dtype=torch.float) + + origin_x = own_left_top_x + origin_y = own_left_top_y + + all_balls[:, -4] = ((all_balls[:, -4] - origin_x) / own_view_width * 
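Variable player counts are padded up to `max_player_num` with `torch.nn.functional.pad`; note that the pad tuple runs from the last dimension backwards, so `(0, 0, 0, n)` appends `n` zero rows. A toy example of the same call:

```python
# Padding a variable number of players to a fixed max_player_num.
import torch

max_player_num = 4
all_players = torch.tensor([[0.0, 10.0, 12.0],
                            [1.0, -3.0, 7.0]])     # (player_num, features)
pad_n = max_player_num - all_players.shape[0]
padded = torch.nn.functional.pad(all_players, (0, 0, 0, pad_n), 'constant', 0)
print(padded.shape)  # torch.Size([4, 3])
```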
self.spatial_x) + all_balls[:, -3] = ((all_balls[:, -3] - origin_y) / own_view_width * self.spatial_y) + all_balls[:, -2] = ((all_balls[:, -2] - origin_x) / own_view_width * self.spatial_x) + all_balls[:, -1] = ((all_balls[:, -1] - origin_y) / own_view_width * self.spatial_y) + + # ball + ball_indices = torch.logical_and(all_balls[:, 0] != 2, + all_balls[:, 0] != 4) # include player balls and thorn balls + balls = all_balls[ball_indices] + + balls_num = len(balls) + + # consider position of thorns ball + if balls_num > self.max_ball_num: # filter small balls + own_indices = balls[:, 3] == own_player_id + teammate_indices = (balls[:, 4] == own_team_id) & ~own_indices + enemy_indices = balls[:, 4] != own_team_id + + own_balls = balls[own_indices] + teammate_balls = balls[teammate_indices] + enemy_balls = balls[enemy_indices] + + if own_balls.shape[0] + teammate_balls.shape[0] >= self.max_ball_num: + remain_ball_num = self.max_ball_num - own_balls.shape[0] + teammate_ball_score = teammate_balls[:, 1] + teammate_high_score_indices = teammate_ball_score.sort(descending=True)[1][:remain_ball_num] + teammate_remain_balls = teammate_balls[teammate_high_score_indices] + balls = torch.cat([own_balls, teammate_remain_balls]) + else: + remain_ball_num = self.max_ball_num - own_balls.shape[0] - teammate_balls.shape[0] + enemy_ball_score = enemy_balls[:, 1] + enemy_high_score_ball_indices = enemy_ball_score.sort(descending=True)[1][:remain_ball_num] + remain_enemy_balls = enemy_balls[enemy_high_score_ball_indices] + + balls = torch.cat([own_balls, teammate_balls, remain_enemy_balls]) + balls_num = len(balls) + ball_padding_num = self.max_ball_num - len(balls) + if padding or ball_padding_num < 0: + balls = torch.nn.functional.pad(balls, (0, 0, 0, ball_padding_num), 'constant', 0) + alliance = torch.zeros(self.max_ball_num) + balls_num = min(self.max_ball_num, balls_num) + else: + alliance = torch.zeros(balls_num) + alliance[balls[:, 3] == own_team_id] = 2 + alliance[balls[:, 2] == own_player_id] = 1 + alliance[balls[:, 3] != own_team_id] = 3 + alliance[balls[:, 0] == 3] = 0 + + ## score&radius + scale_score = balls[:, 1] / 100 + radius = (torch.sqrt(scale_score * 0.042 + 0.15) / own_view_width).clamp_(max=1) + score = ((torch.sqrt(scale_score * 0.042 + 0.15) / own_view_width).clamp_(max=1) * 50).round().long().clamp_( + max=49) + + ## rank: + ball_rank = balls[:, 4] + + ## coordinate + x = balls[:, -4] - self.spatial_x // 2 + y = balls[:, -3] - self.spatial_y // 2 + next_x = balls[:, -2] - self.spatial_x // 2 + next_y = balls[:, -1] - self.spatial_y // 2 + + ball_info = { + 'alliance': alliance.long(), + 'score': score.long(), + 'radius': radius, + 'rank': ball_rank.long(), + 'x': x.round().long(), + 'y': y.round().long(), + 'next_x': next_x.round().long(), + 'next_y': next_y.round().long(), + 'ball_num': torch.tensor(balls_num, dtype=torch.long) + } + + # ============ + # spatial info + # ============ + # ball coordinate for scatter connection + ball_x = balls[:, -4] + ball_y = balls[:, -3] + + food_indices = all_balls[:, 0] == 2 + food_x = all_balls[food_indices, -4] + food_y = all_balls[food_indices, -3] + food_num = len(food_x) + food_padding_num = self.max_food_num - len(food_x) + if padding or food_padding_num < 0: + food_x = torch.nn.functional.pad(food_x, (0, food_padding_num), 'constant', 0) + food_y = torch.nn.functional.pad(food_y, (0, food_padding_num), 'constant', 0) + food_num = min(food_num, self.max_food_num) + + spore_indices = all_balls[:, 0] == 4 + spore_x = all_balls[spore_indices, 
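When the view holds more than `max_ball_num` balls, the branch above keeps every own ball, then the highest-score teammate balls, then the highest-score enemy balls. The score-based top-k at the heart of that rule, in isolation:

```python
# Keep the k largest balls by score (a compressed restatement of the
# truncation branch above; toy tensors).
import torch

max_ball_num = 4
scores = torch.tensor([9.0, 1.0, 5.0, 7.0, 3.0])
keep = scores.sort(descending=True)[1][:max_ball_num]   # indices of largest balls
print(keep)  # tensor([0, 3, 2, 4])
```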
-4] + spore_y = all_balls[spore_indices, -3] + spore_num = len(spore_x) + spore_padding_num = self.max_spore_num - len(spore_x) + if padding or spore_padding_num < 0: + spore_x = torch.nn.functional.pad(spore_x, (0, spore_padding_num), 'constant', 0) + spore_y = torch.nn.functional.pad(spore_y, (0, spore_padding_num), 'constant', 0) + spore_num = min(spore_num, self.max_spore_num) + + spatial_info = { + 'food_x': food_x.round().clamp_(min=0, max=self.spatial_x - 1).long(), + 'food_y': food_y.round().clamp_(min=0, max=self.spatial_y - 1).long(), + 'spore_x': spore_x.round().clamp_(min=0, max=self.spatial_x - 1).long(), + 'spore_y': spore_y.round().clamp_(min=0, max=self.spatial_y - 1).long(), + 'ball_x': ball_x.round().clamp_(min=0, max=self.spatial_x - 1).long(), + 'ball_y': ball_y.round().clamp_(min=0, max=self.spatial_y - 1).long(), + 'food_num': torch.tensor(food_num, dtype=torch.long), + 'spore_num': torch.tensor(spore_num, dtype=torch.long) + } + + output_obs = { + 'scalar_info': scalar_info, + 'team_info': team_info, + 'ball_info': ball_info, + 'spatial_info': spatial_info, + } + return output_obs + + def generate_action_mask(self, can_eject, can_split, ): + action_mask = torch.zeros(size=(self.action_num,), dtype=torch.bool) + if not can_eject: + action_mask[self.direction_num * 2 + 1] = True + if not can_split: + action_mask[self.direction_num * 2 + 2] = True + return action_mask + + def transform_action(self, action_idx): + return self.x_y_action_List[int(action_idx)] + + def next_position(self, x, y, vel_x, vel_y): + next_x = x + self.second_per_frame * vel_x * self.step_mul + next_y = y + self.second_per_frame * vel_y * self.step_mul + return next_x, next_y diff --git a/practice/tools/head.py b/practice/tools/head.py new file mode 100644 index 0000000..ba4b0dd --- /dev/null +++ b/practice/tools/head.py @@ -0,0 +1,71 @@ +import torch.nn as nn + +from .network.nn_module import fc_block +from .network.res_block import ResFCBlock + + +class PolicyHead(nn.Module): + def __init__(self, cfg): + super(PolicyHead, self).__init__() + self.whole_cfg = cfg + self.cfg = self.whole_cfg.model.policy + + self.embedding_dim = self.cfg.embedding_dim + self.project_cfg = self.cfg.project + self.project = fc_block(in_channels=self.project_cfg.input_dim, + out_channels=self.embedding_dim, + activation= self.project_cfg.activation, + norm_type=self.project_cfg.norm_type) + + self.resnet_cfg = self.cfg.resnet + blocks = [ResFCBlock(in_channels=self.embedding_dim, + activation=self.resnet_cfg.activation, + norm_type=self.resnet_cfg.norm_type) + for _ in range(self.resnet_cfg.res_num)] + self.resnet = nn.Sequential(*blocks) + + self.direction_num = self.whole_cfg.agent.features.get('direction_num', 12) + self.action_num = 2 * self.direction_num + 3 + self.output_layer = fc_block(in_channels=self.embedding_dim, + out_channels=self.action_num, + norm_type=None, + activation=None) + + def forward(self, x, temperature=1): + x = self.project(x) + x = self.resnet(x) + logit = self.output_layer(x) + logit /= temperature + return logit + + +class ValueHead(nn.Module): + def __init__(self, cfg): + super(ValueHead, self).__init__() + self.whole_cfg = cfg + self.cfg = self.whole_cfg.model.value + + self.embedding_dim = self.cfg.embedding_dim + self.project_cfg = self.cfg.project + self.project = fc_block(in_channels=self.project_cfg.input_dim, + out_channels=self.embedding_dim, + activation= self.project_cfg.activation, + norm_type=self.project_cfg.norm_type) + + self.resnet_cfg = self.cfg.resnet + blocks = 
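`generate_action_mask` marks *illegal* actions with `True`; with the default `direction_num` of 12, eject is index 25 (`2*12+1`) and split is index 26 (`2*12+2`). A quick check:

```python
# Which action ids get masked when eject is unavailable.
import torch

direction_num = 12
action_num = direction_num * 2 + 3
mask = torch.zeros(action_num, dtype=torch.bool)
can_eject, can_split = False, True
if not can_eject:
    mask[direction_num * 2 + 1] = True   # forbid eject
if not can_split:
    mask[direction_num * 2 + 2] = True   # forbid split
print(mask.nonzero().flatten())  # tensor([25])
```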
[ResFCBlock(in_channels=self.embedding_dim, + activation=self.resnet_cfg.activation, + norm_type=self.resnet_cfg.norm_type) + for _ in range(self.resnet_cfg.res_num)] + self.resnet = nn.Sequential(*blocks) + + self.output_layer = fc_block(in_channels=self.embedding_dim, + out_channels=1, + norm_type=None, + activation=None) + def forward(self, x): + x = self.project(x) + x = self.resnet(x) + x = self.output_layer(x) + x = x.squeeze(1) + return x \ No newline at end of file diff --git a/practice/tools/network/__init__.py b/practice/tools/network/__init__.py new file mode 100644 index 0000000..50e7db8 --- /dev/null +++ b/practice/tools/network/__init__.py @@ -0,0 +1,8 @@ +from .activation import build_activation +from .res_block import ResBlock, ResFCBlock,ResFCBlock2 +from .nn_module import fc_block, fc_block2, conv2d_block, MLP +from .normalization import build_normalization +from .rnn import get_lstm, sequence_mask +from .soft_argmax import SoftArgmax +from .transformer import Transformer +from .scatter_connection import ScatterConnection diff --git a/practice/tools/network/activation.py b/practice/tools/network/activation.py new file mode 100644 index 0000000..550bee3 --- /dev/null +++ b/practice/tools/network/activation.py @@ -0,0 +1,96 @@ +""" +Copyright 2020 Sensetime X-lab. All Rights Reserved + +Main Function: + 1. build activation: you can use build_activation to build relu or glu +""" +import torch +import torch.nn as nn + + +class GLU(nn.Module): + r""" + Overview: + Gating Linear Unit. + This class does a thing like this: + + .. code:: python + + # Inputs: input, context, output_size + # The gate value is a learnt function of the input. + gate = sigmoid(linear(input.size)(context)) + # Gate the input and return an output of desired size. + gated_input = gate * input + output = linear(output_size)(gated_input) + return output + Interfaces: + forward + + .. tip:: + + This module also supports 2D convolution, in which case, the input and context must have the same shape. + """ + + def __init__(self, input_dim: int, output_dim: int, context_dim: int, input_type: str = 'fc') -> None: + r""" + Overview: + Init GLU + Arguments: + - input_dim (:obj:`int`): the input dimension + - output_dim (:obj:`int`): the output dimension + - context_dim (:obj:`int`): the context dimension + - input_type (:obj:`str`): the type of input, now support ['fc', 'conv2d'] + """ + super(GLU, self).__init__() + assert (input_type in ['fc', 'conv2d']) + if input_type == 'fc': + self.layer1 = nn.Linear(context_dim, input_dim) + self.layer2 = nn.Linear(input_dim, output_dim) + elif input_type == 'conv2d': + self.layer1 = nn.Conv2d(context_dim, input_dim, 1, 1, 0) + self.layer2 = nn.Conv2d(input_dim, output_dim, 1, 1, 0) + + def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor: + r""" + Overview: + Return GLU computed tensor + Arguments: + - x (:obj:`torch.Tensor`) : the input tensor + - context (:obj:`torch.Tensor`) : the context tensor + Returns: + - x (:obj:`torch.Tensor`): the computed tensor + """ + gate = self.layer1(context) + gate = torch.sigmoid(gate) + x = gate * x + x = self.layer2(x) + return x + +class Swish(nn.Module): + + def __init__(self): + super(Swish, self).__init__() + + def forward(self, x): + x = x * torch.sigmoid(x) + return x + +def build_activation(activation: str, inplace: bool = None) -> nn.Module: + r""" + Overview: + Return the activation module according to the given type. 
+ Arguments: + - actvation (:obj:`str`): the type of activation module, now supports ['relu', 'glu', 'prelu'] + - inplace (:obj:`bool`): can optionally do the operation in-place in relu. Default ``None`` + Returns: + - act_func (:obj:`nn.module`): the corresponding activation module + """ + if inplace is not None: + assert activation == 'relu', 'inplace argument is not compatible with {}'.format(activation) + else: + inplace = True + act_func = {'relu': nn.ReLU(inplace=inplace), 'glu': GLU, 'prelu': nn.PReLU(),'swish': Swish()} + if activation in act_func.keys(): + return act_func[activation] + else: + raise KeyError("invalid key for activation: {}".format(activation)) diff --git a/practice/tools/network/encoder.py b/practice/tools/network/encoder.py new file mode 100644 index 0000000..daa014e --- /dev/null +++ b/practice/tools/network/encoder.py @@ -0,0 +1,136 @@ +import numpy as np +import torch +import torch.nn as nn + + +class OnehotEncoder(nn.Module): + def __init__(self, num_embeddings: int): + super(OnehotEncoder, self).__init__() + self.num_embeddings = num_embeddings + self.main = nn.Embedding.from_pretrained(torch.eye(self.num_embeddings), freeze=True, + padding_idx=None) + + def forward(self, x: torch.Tensor): + x = x.long().clamp_(max=self.num_embeddings - 1) + return self.main(x) + + +class OnehotEmbedding(nn.Module): + def __init__(self, num_embeddings: int, embedding_dim: int): + super(OnehotEmbedding, self).__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.main = nn.Embedding(num_embeddings=self.num_embeddings, embedding_dim=self.embedding_dim) + + def forward(self, x: torch.Tensor): + x = x.long().clamp_(max=self.num_embeddings - 1) + return self.main(x) + + +class BinaryEncoder(nn.Module): + def __init__(self, num_embeddings: int): + super(BinaryEncoder, self).__init__() + self.bit_num = num_embeddings + self.main = nn.Embedding.from_pretrained(self.get_binary_embed_matrix(self.bit_num), freeze=True, + padding_idx=None) + + @staticmethod + def get_binary_embed_matrix(bit_num): + embedding_matrix = [] + for n in range(2 ** bit_num): + embedding = [n >> d & 1 for d in range(bit_num)][::-1] + embedding_matrix.append(embedding) + return torch.tensor(embedding_matrix, dtype=torch.float) + + def forward(self, x: torch.Tensor): + x = x.long().clamp_(max=2 ** self.bit_num - 1) + return self.main(x) + + +class SignBinaryEncoder(nn.Module): + def __init__(self, num_embeddings): + super(SignBinaryEncoder, self).__init__() + self.bit_num = num_embeddings + self.main = nn.Embedding.from_pretrained(self.get_sign_binary_matrix(self.bit_num), freeze=True, + padding_idx=None) + self.max_val = 2 ** (self.bit_num - 1) - 1 + + @staticmethod + def get_sign_binary_matrix(bit_num): + neg_embedding_matrix = [] + pos_embedding_matrix = [] + for n in range(1, 2 ** (bit_num - 1)): + embedding = [n >> d & 1 for d in range(bit_num - 1)][::-1] + neg_embedding_matrix.append([1] + embedding) + pos_embedding_matrix.append([0] + embedding) + embedding_matrix = neg_embedding_matrix[::-1] + [[0 for _ in range(bit_num)]] + pos_embedding_matrix + return torch.tensor(embedding_matrix, dtype=torch.float) + + def forward(self, x: torch.Tensor): + x = x.long().clamp_(max=self.max_val, min=- self.max_val) + return self.main(x + self.max_val) + + +class PositionEncoder(nn.Module): + def __init__(self, num_embeddings, embedding_dim=None): + super(PositionEncoder, self).__init__() + self.n_position = num_embeddings + self.embedding_dim = self.n_position if embedding_dim is 
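`BinaryEncoder` looks up fixed binary codes, and `SignBinaryEncoder` prepends a sign bit and shifts the input by `max_val` so negative coordinates index valid rows. The code construction for a small, hypothetical `bit_num`:

```python
# What the binary embedding rows look like for bit_num=4 (illustrative).
bit_num = 4
for n in (0, 5, 11):
    bits = [n >> d & 1 for d in range(bit_num)][::-1]
    print(n, bits)
# 0 [0, 0, 0, 0]
# 5 [0, 1, 0, 1]
# 11 [1, 0, 1, 1]
# SignBinaryEncoder shifts x by max_val = 2**(bit_num-1) - 1, so negative
# inputs land on rows whose leading (sign) bit is 1.
```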
None else embedding_dim + self.position_enc = nn.Embedding.from_pretrained( + self.position_encoding_init(self.n_position, self.embedding_dim), + freeze=True, padding_idx=None) + + @staticmethod + def position_encoding_init(n_position, embedding_dim): + ''' Init the sinusoid position encoding table ''' + + # keep dim 0 for padding token position encoding zero vector + position_enc = np.array([ + [pos / np.power(10000, 2 * (j // 2) / embedding_dim) for j in range(embedding_dim)] + for pos in range(n_position)]) + position_enc[:, 0::2] = np.sin(position_enc[:, 0::2]) # apply sin on 0th,2nd,4th...embedding_dim + position_enc[:, 1::2] = np.cos(position_enc[:, 1::2]) # apply cos on 1st,3rd,5th...embedding_dim + return torch.from_numpy(position_enc).type(torch.FloatTensor) + + def forward(self, x: torch.Tensor): + return self.position_enc(x) + + +class TimeEncoder(nn.Module): + def __init__(self, embedding_dim): + super(TimeEncoder, self).__init__() + self.embedding_dim = embedding_dim + self.position_array = torch.nn.Parameter(self.get_position_array(), requires_grad=False) + + def get_position_array(self): + x = torch.arange(0, self.embedding_dim, dtype=torch.float) + x = x // 2 * 2 + x = torch.div(x, self.embedding_dim) + x = torch.pow(10000., x) + x = torch.div(1., x) + return x + + def forward(self, x: torch.Tensor): + v = torch.zeros(size=(x.shape[0], self.embedding_dim), dtype=torch.float, device=x.device) + assert len(x.shape) == 1 + x = x.unsqueeze(dim=1) + v[:, 0::2] = torch.sin(x * self.position_array[0::2]) # even + v[:, 1::2] = torch.cos(x * self.position_array[1::2]) # odd + return v + + +class UnsqueezeEncoder(nn.Module): + def __init__(self, unsqueeze_dim: int = -1, norm_value: float = 1): + super(UnsqueezeEncoder, self).__init__() + self.unsqueeze_dim = unsqueeze_dim + self.norm_value = norm_value + + def forward(self, x: torch.Tensor): + x = x.float().unsqueeze(dim=self.unsqueeze_dim) + if self.norm_value != 1: + x = x / self.norm_value + return x + + +if __name__ == '__main__': + pass diff --git a/practice/tools/network/nn_module.py b/practice/tools/network/nn_module.py new file mode 100644 index 0000000..9768314 --- /dev/null +++ b/practice/tools/network/nn_module.py @@ -0,0 +1,235 @@ +from typing import Callable + +import torch +import torch.nn as nn + +from .activation import build_activation +from .normalization import build_normalization + + +def fc_block( + in_channels: int, + out_channels: int, + activation: nn.Module = None, + norm_type: str = None, + use_dropout: bool = False, + dropout_probability: float = 0.5 +) -> nn.Sequential: + r""" + Overview: + Create a fully-connected block with activation, normalization and dropout. + Optional normalization can be done to the dim 1 (across the channels) + x -> fc -> norm -> act -> dropout -> out + Arguments: + - in_channels (:obj:`int`): Number of channels in the input tensor + - out_channels (:obj:`int`): Number of channels in the output tensor + - activation (:obj:`nn.Module`): the optional activation function + - norm_type (:obj:`str`): type of the normalization + - use_dropout (:obj:`bool`) : whether to use dropout in the fully-connected block + - dropout_probability (:obj:`float`) : probability of an element to be zeroed in the dropout. Default: 0.5 + Returns: + - block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the fully-connected block + + .. 
note:: + + you can refer to nn.linear (https://pytorch.org/docs/master/generated/torch.nn.Linear.html) + """ + block = [] + block.append(nn.Linear(in_channels, out_channels)) + if norm_type is not None and norm_type != 'none': + block.append(build_normalization(norm_type, dim=1)(out_channels)) + if isinstance(activation, str) and activation != 'none': + block.append(build_activation(activation)) + elif isinstance(activation, torch.nn.Module): + block.append(activation) + if use_dropout: + block.append(nn.Dropout(dropout_probability)) + return nn.Sequential(*block) + + +def fc_block2( + in_channels, + out_channels, + activation=None, + norm_type=None, + use_dropout=False, + dropout_probability=0.5 +): + r""" + Overview: + create a fully-connected block with activation, normalization and dropout + optional normalization can be done to the dim 1 (across the channels) + x -> fc -> norm -> act -> dropout -> out + Arguments: + - in_channels (:obj:`int`): Number of channels in the input tensor + - out_channels (:obj:`int`): Number of channels in the output tensor + - init_type (:obj:`str`): the type of init to implement + - activation (:obj:`nn.Moduel`): the optional activation function + - norm_type (:obj:`str`): type of the normalization + - use_dropout (:obj:`bool`) : whether to use dropout in the fully-connected block + - dropout_probability (:obj:`float`) : probability of an element to be zeroed in the dropout. Default: 0.5 + Returns: + - block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the fully-connected block + + .. note:: + you can refer to nn.linear (https://pytorch.org/docs/master/generated/torch.nn.Linear.html) + """ + block = [] + if norm_type is not None and norm_type != 'none': + block.append(build_normalization(norm_type, dim=1)(in_channels)) + if isinstance(activation, str) and activation != 'none': + block.append(build_activation(activation)) + elif isinstance(activation, torch.nn.Module): + block.append(activation) + block.append(nn.Linear(in_channels, out_channels)) + if use_dropout: + block.append(nn.Dropout(dropout_probability)) + return nn.Sequential(*block) + + +def conv2d_block( + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + activation: str = None, + norm_type: str = None, + bias: bool = True, +) -> nn.Sequential: + r""" + Overview: + Create a 2-dim convlution layer with activation and normalization. + Arguments: + - in_channels (:obj:`int`): Number of channels in the input tensor + - out_channels (:obj:`int`): Number of channels in the output tensor + - kernel_size (:obj:`int`): Size of the convolving kernel + - stride (:obj:`int`): Stride of the convolution + - padding (:obj:`int`): Zero-padding added to both sides of the input + - dilation (:obj:`int`): Spacing between kernel elements + - groups (:obj:`int`): Number of blocked connections from input channels to output channels + - pad_type (:obj:`str`): the way to add padding, include ['zero', 'reflect', 'replicate'], default: None + - activation (:obj:`nn.Module`): the optional activation function + - norm_type (:obj:`str`): type of the normalization, default set to None, now support ['BN', 'IN', 'SyncBN'] + Returns: + - block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the 2 dim convlution layer + + .. 
note:: + + Conv2d (https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d) + """ + block = [] + block.append( + nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, groups=groups,bias=bias) + ) + if norm_type is not None: + block.append(nn.GroupNorm(num_groups=1, num_channels=out_channels)) + if isinstance(activation, str) and activation != 'none': + block.append(build_activation(activation)) + elif isinstance(activation, torch.nn.Module): + block.append(activation) + return nn.Sequential(*block) + + +def conv2d_block2( + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + activation: str = None, + norm_type=None, + bias: bool = True, +): + r""" + Overview: + create a 2-dim convlution layer with activation and normalization. + + Note: + Conv2d (https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d) + + Arguments: + - in_channels (:obj:`int`): Number of channels in the input tensor + - out_channels (:obj:`int`): Number of channels in the output tensor + - kernel_size (:obj:`int`): Size of the convolving kernel + - stride (:obj:`int`): Stride of the convolution + - padding (:obj:`int`): Zero-padding added to both sides of the input + - dilation (:obj:`int`): Spacing between kernel elements + - groups (:obj:`int`): Number of blocked connections from input channels to output channels + - init_type (:obj:`str`): the type of init to implement + - pad_type (:obj:`str`): the way to add padding, include ['zero', 'reflect', 'replicate'], default: None + - activation (:obj:`nn.Moduel`): the optional activation function + - norm_type (:obj:`str`): type of the normalization, default set to None, now support ['BN', 'IN', 'SyncBN'] + + Returns: + - block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the 2 dim convlution layer + """ + + block = [] + if norm_type is not None: + block.append(nn.GroupNorm(num_groups=1, num_channels=out_channels)) + if isinstance(activation, str) and activation != 'none': + block.append(build_activation(activation)) + elif isinstance(activation, torch.nn.Module): + block.append(activation) + block.append( + nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, groups=groups,bias=bias) + ) + return nn.Sequential(*block) + + +def MLP( + in_channels: int, + hidden_channels: int, + out_channels: int, + layer_num: int, + layer_fn: Callable = None, + activation: str = None, + norm_type: str = None, + use_dropout: bool = False, + dropout_probability: float = 0.5 +): + r""" + Overview: + create a multi-layer perceptron using fully-connected blocks with activation, normalization and dropout, + optional normalization can be done to the dim 1 (across the channels) + x -> fc -> norm -> act -> dropout -> out + Arguments: + - in_channels (:obj:`int`): Number of channels in the input tensor + - hidden_channels (:obj:`int`): Number of channels in the hidden tensor + - out_channels (:obj:`int`): Number of channels in the output tensor + - layer_num (:obj:`int`): Number of layers + - layer_fn (:obj:`Callable`): layer function + - activation (:obj:`nn.Module`): the optional activation function + - norm_type (:obj:`str`): type of the normalization + - use_dropout (:obj:`bool`): whether to use dropout in the fully-connected block + - dropout_probability (:obj:`float`): probability of an element to be zeroed 
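`MLP` builds its layer list from the channel schedule `[in] + [hidden] * (layer_num - 1) + [out]`, pairing consecutive entries into `fc_block`s, as the code below this docstring shows. For example:

```python
# The per-layer (in, out) pairs MLP constructs for layer_num=3.
in_channels, hidden_channels, out_channels, layer_num = 64, 128, 32, 3
channels = [in_channels] + [hidden_channels] * (layer_num - 1) + [out_channels]
print(list(zip(channels[:-1], channels[1:])))
# [(64, 128), (128, 128), (128, 32)]
```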
in the dropout. Default: 0.5 + Returns: + - block (:obj:`nn.Sequential`): a sequential list containing the torch layers of the fully-connected block + + .. note:: + + you can refer to nn.linear (https://pytorch.org/docs/master/generated/torch.nn.Linear.html) + """ + assert layer_num >= 0, layer_num + if layer_num == 0: + return nn.Sequential(*[nn.Identity()]) + + channels = [in_channels] + [hidden_channels] * (layer_num - 1) + [out_channels] + if layer_fn is None: + layer_fn = fc_block + block = [] + for i, (in_channels, out_channels) in enumerate(zip(channels[:-1], channels[1:])): + block.append(layer_fn(in_channels=in_channels, + out_channels=out_channels, + activation=activation, + norm_type=norm_type, + use_dropout=use_dropout, + dropout_probability=dropout_probability)) + return nn.Sequential(*block) diff --git a/practice/tools/network/normalization.py b/practice/tools/network/normalization.py new file mode 100644 index 0000000..fd5831c --- /dev/null +++ b/practice/tools/network/normalization.py @@ -0,0 +1,36 @@ +from typing import Optional +import torch.nn as nn + + +def build_normalization(norm_type: str, dim: Optional[int] = None) -> nn.Module: + r""" + Overview: + Build the corresponding normalization module + Arguments: + - norm_type (:obj:`str`): type of the normaliztion, now support ['BN', 'IN', 'SyncBN', 'AdaptiveIN'] + - dim (:obj:`int`): dimension of the normalization, when norm_type is in [BN, IN] + Returns: + - norm_func (:obj:`nn.Module`): the corresponding batch normalization function + + .. note:: + For beginers, you can refer to to learn more about batch normalization. + """ + if dim is None: + key = norm_type + else: + if norm_type in ['BN', 'IN', 'SyncBN']: + key = norm_type + str(dim) + elif norm_type in ['LN']: + key = norm_type + else: + raise NotImplementedError("not support indicated dim when creates {}".format(norm_type)) + norm_func = { + 'BN1': nn.BatchNorm1d, + 'BN2': nn.BatchNorm2d, + 'LN': nn.LayerNorm, + 'IN2': nn.InstanceNorm2d, + } + if key in norm_func.keys(): + return norm_func[key] + else: + raise KeyError("invalid norm type: {}".format(key)) \ No newline at end of file diff --git a/practice/tools/network/res_block.py b/practice/tools/network/res_block.py new file mode 100644 index 0000000..f64fae1 --- /dev/null +++ b/practice/tools/network/res_block.py @@ -0,0 +1,231 @@ +""" +Copyright 2020 Sensetime X-lab. All Rights Reserved + +Main Function: + 1. 
build ResBlock: you can use these classes to build residual blocks +""" +import torch.nn as nn +from .nn_module import conv2d_block, fc_block, conv2d_block2, fc_block2 +from .activation import build_activation +from .normalization import build_normalization + + +class ResBlock(nn.Module): + r''' + Overview: + Residual Block with 2D convolution layers, including 2 types: + basic block: + input channel: C + x -> 3*3*C -> norm -> act -> 3*3*C -> norm -> act -> out + \__________________________________________/+ + bottleneck block: + x -> 1*1*(1/4*C) -> norm -> act -> 3*3*(1/4*C) -> norm -> act -> 1*1*C -> norm -> act -> out + \_____________________________________________________________________________/+ + + Interface: + __init__, forward + ''' + + def __init__(self, in_channels, out_channels=None, stride=1, downsample=None, activation='relu', norm_type='LN'): + r""" + Overview: + Init the Residual Block + + Arguments: + - in_channels (:obj:`int`): Number of channels in the input tensor + - activation (:obj:`nn.Module`): the optional activation function + - norm_type (:obj:`str`): type of the normalization, default set to 'LN', + support ['BN', 'IN', 'SyncBN', None] + """ + super(ResBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = self.in_channels if out_channels is None else out_channels + self.activation_type = activation + self.norm_type = norm_type + self.stride = stride + self.downsample = downsample + self.conv1 = conv2d_block(in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=self.stride, + padding=1, + activation=self.activation_type, + norm_type=self.norm_type) + self.conv2 = conv2d_block(in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=self.stride, + padding=1, + activation=None, + norm_type=self.norm_type) + self.activation = build_activation(self.activation_type) + + def forward(self, x): + r""" + Overview: + return the residual block output + + Arguments: + - x (:obj:`tensor`): the input tensor + + Returns: + - x (:obj:`tensor`): the resblock output tensor + """ + residual = x + out = self.conv1(x) + out = self.conv2(out) + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = self.activation(out) + return out + + +class ResBlock2(nn.Module): + r''' + Overview: + Residual Block with 2D convolution layers (pre-activation variant), including 2 types: + basic block: + input channel: C + x -> 3*3*C -> norm -> act -> 3*3*C -> norm -> act -> out + \__________________________________________/+ + bottleneck block: + x -> 1*1*(1/4*C) -> norm -> act -> 3*3*(1/4*C) -> norm -> act -> 1*1*C -> norm -> act -> out + \_____________________________________________________________________________/+ + + Interface: + __init__, forward + ''' + + def __init__(self, in_channels, out_channels=None, stride=1, downsample=None, activation='relu', norm_type='LN'): + r""" + Overview: + Init the Residual Block + + Arguments: + - in_channels (:obj:`int`): Number of channels in the input tensor + - activation (:obj:`nn.Module`): the optional activation function + - norm_type (:obj:`str`): type of the normalization, default set to 'LN', + support ['BN', 'IN', 'SyncBN', None] + """ + super(ResBlock2, self).__init__() + self.in_channels = in_channels + self.out_channels = self.in_channels if out_channels is None else out_channels + self.activation_type = activation + self.norm_type = norm_type + self.stride = stride + self.downsample = downsample + self.conv1 = conv2d_block2(in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=self.stride, + padding=1, + activation=self.activation_type, + norm_type=self.norm_type) + self.conv2 = conv2d_block2(in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=self.stride, + padding=1, + activation=self.activation_type, + norm_type=self.norm_type) + self.activation = build_activation(self.activation_type) + + def forward(self, x): + r""" + Overview: + return the residual block output + + Arguments: + - x (:obj:`tensor`): the input tensor + + Returns: + - out (:obj:`tensor`): the resblock output tensor + """ + residual = x + out = self.conv1(x) + out = self.conv2(out) + if self.downsample is not None: + residual = self.downsample(x) + out += residual + return out + +class ResFCBlock(nn.Module): + def __init__(self, in_channels, activation='relu', norm_type=None): + r""" + Overview: + Init the Residual Block + + Arguments: + - activation (:obj:`nn.Module`): the optional activation function + - norm_type (:obj:`str`): type of the normalization, default None + """ + super(ResFCBlock, self).__init__() + self.activation_type = activation + self.norm_type = norm_type + self.fc1 = fc_block(in_channels, in_channels, norm_type=self.norm_type, activation=self.activation_type) + self.fc2 = fc_block(in_channels, in_channels, norm_type=self.norm_type, activation=None) + self.activation = build_activation(self.activation_type) + + def forward(self, x): + r""" + Overview: + return output of the residual block with 2 fully-connected blocks + + Arguments: + - x (:obj:`tensor`): the input tensor + + Returns: + - x (:obj:`tensor`): the resblock output tensor + """ + residual = x + x = self.fc1(x) + x = self.fc2(x) + x = self.activation(x + residual) + return x + +class ResFCBlock2(nn.Module): + r''' + Overview: + Residual Block with 2 fully-connected blocks + x -> fc1 -> norm -> act -> fc2 -> norm -> act -> out + \_____________________________________/+ + + Interface: + __init__, forward + ''' + + def __init__(self, in_channels, activation='relu', norm_type='LN'): + r""" + Overview: + Init the Residual Block + + Arguments: + - activation (:obj:`nn.Module`): the optional activation function + - norm_type (:obj:`str`): type of the normalization, default 'LN' + """ + super(ResFCBlock2, self).__init__() + self.activation_type = activation + self.fc1 = fc_block2(in_channels, in_channels, activation=self.activation_type, norm_type=norm_type) + self.fc2 = fc_block2(in_channels, in_channels, activation=self.activation_type, norm_type=norm_type) + + def forward(self, x): + r""" + Overview: + return output of the residual block with 2 fully-connected blocks + + Arguments: + - x (:obj:`tensor`): the input tensor + + Returns: + - x (:obj:`tensor`): the resblock output tensor + """ + residual = x + x = self.fc1(x) + x = self.fc2(x) + x = x + residual + return x \ No newline at end of file diff --git a/practice/tools/network/rnn.py b/practice/tools/network/rnn.py new file mode 100644 index 0000000..3631073 --- /dev/null +++ b/practice/tools/network/rnn.py @@ -0,0 +1,276 @@ +""" +Copyright 2020 Sensetime X-lab. All Rights Reserved + +Main Function: + 1. 
build LSTM: you can use build_LSTM to build the lstm module +""" +import math + +import torch +import torch.nn as nn + +from typing import Optional +from .normalization import build_normalization + + +def is_sequence(data): + return isinstance(data, list) or isinstance(data, tuple) + + +def sequence_mask(lengths: torch.Tensor, max_len: Optional[int] =None): + r""" + Overview: + create a mask for a batch sequences with different lengths + Arguments: + - lengths (:obj:`tensor`): lengths in each different sequences, shape could be (n, 1) or (n) + - max_len (:obj:`int`): the padding size, if max_len is None, the padding size is the + max length of sequences + Returns: + - masks (:obj:`torch.BoolTensor`): mask has the same device as lengths + """ + if len(lengths.shape) == 1: + lengths = lengths.unsqueeze(dim=1) + bz = lengths.numel() + if max_len is None: + max_len = lengths.max() + return torch.arange(0, max_len).type_as(lengths).repeat(bz, 1).lt(lengths).to(lengths.device) + + +class LSTMForwardWrapper(object): + r""" + Overview: + abstract class used to wrap the LSTM forward method + Interface: + _before_forward, _after_forward + """ + + def _before_forward(self, inputs, prev_state): + r""" + Overview: + preprocess the inputs and previous states + Arguments: + - inputs (:obj:`tensor`): input vector of cell, tensor of size [seq_len, batch_size, input_size] + - prev_state (:obj:`tensor` or :obj:`list`): + None or tensor of size [num_directions*num_layers, batch_size, hidden_size], if None then prv_state + will be initialized to all zeros. + Returns: + - prev_state (:obj:`tensor`): batch previous state in lstm + """ + assert hasattr(self, 'num_layers') + assert hasattr(self, 'hidden_size') + seq_len, batch_size = inputs.shape[:2] + if prev_state is None: + num_directions = 1 + zeros = torch.zeros( + num_directions * self.num_layers, + batch_size, + self.hidden_size, + dtype=inputs.dtype, + device=inputs.device + ) + prev_state = (zeros, zeros) + elif is_sequence(prev_state): + if len(prev_state) == 2 and isinstance(prev_state[0], torch.Tensor): + pass + else: + if len(prev_state) != batch_size: + raise RuntimeError( + "prev_state number is not equal to batch_size: {}/{}".format(len(prev_state), batch_size) + ) + num_directions = 1 + zeros = torch.zeros( + num_directions * self.num_layers, 1, self.hidden_size, dtype=inputs.dtype, device=inputs.device + ) + state = [] + for prev in prev_state: + if prev is None: + state.append([zeros, zeros]) + else: + state.append(prev) + state = list(zip(*state)) + prev_state = [torch.cat(t, dim=1) for t in state] + else: + raise TypeError("not support prev_state type: {}".format(type(prev_state))) + return prev_state + + def _after_forward(self, next_state, list_next_state=False): + r""" + Overview: + post process the next_state, return list or tensor type next_states + Arguments: + - next_state (:obj:`list` :obj:`Tuple` of :obj:`tensor`): list of Tuple contains the next (h, c) + - list_next_state (:obj:`bool`): whether return next_state with list format, default set to False + Returns: + - next_state(:obj:`list` of :obj:`tensor` or :obj:`tensor`): the formated next_state + """ + if list_next_state: + h, c = [torch.stack(t, dim=0) for t in zip(*next_state)] + batch_size = h.shape[1] + next_state = [torch.chunk(h, batch_size, dim=1), torch.chunk(c, batch_size, dim=1)] + next_state = list(zip(*next_state)) + else: + next_state = [torch.stack(t, dim=0) for t in zip(*next_state)] + return next_state + + +class LSTM(nn.Module, LSTMForwardWrapper): + r""" + 
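`sequence_mask` turns per-sample lengths into the boolean validity mask that the encoders and attention consume downstream. On a toy batch:

```python
# sequence_mask on a toy batch: each row is True for the first `length` slots.
import torch

lengths = torch.tensor([1, 3])
max_len = 4
mask = torch.arange(0, max_len).type_as(lengths).repeat(lengths.numel(), 1).lt(lengths.unsqueeze(1))
print(mask)
# tensor([[ True, False, False, False],
#         [ True,  True,  True, False]])
```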
Overview: + Implimentation of LSTM cell + + .. note:: + for begainners, you can reference to learn the basics about lstm + + Interface: + __init__, forward + """ + + def __init__(self, input_size, hidden_size, num_layers, norm_type=None, dropout=0.): + r""" + Overview: + initializate the LSTM cell + + Arguments: + - input_size (:obj:`int`): size of the input vector + - hidden_size (:obj:`int`): size of the hidden state vector + - num_layers (:obj:`int`): number of lstm layers + - norm_type (:obj:`str`): type of the normaliztion, (default: None) + - dropout (:obj:float): dropout rate, default set to .0 + """ + super(LSTM, self).__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + + norm_func = build_normalization(norm_type) + self.norm = nn.ModuleList([norm_func(hidden_size * 4) for _ in range(2 * num_layers)]) + self.wx = nn.ParameterList() + self.wh = nn.ParameterList() + dims = [input_size] + [hidden_size] * num_layers + for l in range(num_layers): + self.wx.append(nn.Parameter(torch.zeros(dims[l], dims[l + 1] * 4))) + self.wh.append(nn.Parameter(torch.zeros(hidden_size, hidden_size * 4))) + self.bias = nn.Parameter(torch.zeros(num_layers, hidden_size * 4)) + self.use_dropout = dropout > 0. + if self.use_dropout: + self.dropout = nn.Dropout(dropout) + self._init() + + def _init(self): + gain = math.sqrt(1. / self.hidden_size) + for l in range(self.num_layers): + torch.nn.init.uniform_(self.wx[l], -gain, gain) + torch.nn.init.uniform_(self.wh[l], -gain, gain) + if self.bias is not None: + torch.nn.init.uniform_(self.bias[l], -gain, gain) + + def forward(self, inputs, prev_state, list_next_state=True): + r""" + Overview: + Take the previous state and the input and calculate the output and the nextstate + Arguments: + - inputs (:obj:`tensor`): input vector of cell, tensor of size [seq_len, batch_size, input_size] + - prev_state (:obj:`tensor`): None or tensor of size [num_directions*num_layers, batch_size, hidden_size] + - list_next_state (:obj:`bool`): whether return next_state with list format, default set to False + Returns: + - x (:obj:`tensor`): output from lstm + - next_state (:obj:`tensor` or :obj:`list`): hidden state from lstm + """ + seq_len, batch_size = inputs.shape[:2] + prev_state = self._before_forward(inputs, prev_state) + + H, C = prev_state + x = inputs + next_state = [] + for l in range(self.num_layers): + h, c = H[l], C[l] + new_x = [] + for s in range(seq_len): + gate = self.norm[l * 2](torch.matmul(x[s], self.wx[l]) + ) + self.norm[l * 2 + 1](torch.matmul(h, self.wh[l])) + if self.bias is not None: + gate += self.bias[l] + gate = list(torch.chunk(gate, 4, dim=1)) + i, f, o, u = gate + i = torch.sigmoid(i) + f = torch.sigmoid(f) + o = torch.sigmoid(o) + u = torch.tanh(u) + c = f * c + i * u + h = o * torch.tanh(c) + new_x.append(h) + next_state.append((h, c)) + x = torch.stack(new_x, dim=0) + if self.use_dropout and l != self.num_layers - 1: + x = self.dropout(x) + + next_state = self._after_forward(next_state, list_next_state) + return x, next_state + + +class PytorchLSTM(nn.LSTM, LSTMForwardWrapper): + r""" + Overview: + Wrap the nn.LSTM , format the input and output + Interface: + forward + + .. 
note:: + you can reference the + """ + + def forward(self, inputs, prev_state, list_next_state=True): + r""" + Overview: + wrapped nn.LSTM.forward + Arguments: + - inputs (:obj:`tensor`): input vector of cell, tensor of size [seq_len, batch_size, input_size] + - prev_state (:obj:`tensor`): None or tensor of size [num_directions*num_layers, batch_size, hidden_size] + - list_next_state (:obj:`bool`): whether return next_state with list format, default set to False + Returns: + - output (:obj:`tensor`): output from lstm + - next_state (:obj:`tensor` or :obj:`list`): hidden state from lstm + """ + prev_state = self._before_forward(inputs, prev_state) + output, next_state = nn.LSTM.forward(self, inputs, prev_state) + next_state = self._after_forward(next_state, list_next_state) + return output, next_state + + def _after_forward(self, next_state, list_next_state=False): + r""" + Overview: + process hidden state after lstm, make it list or remains tensor + Arguments: + - nex_state (:obj:`tensor`): hidden state from lstm + - list_nex_state (:obj:`bool`): whether return next_state with list format, default set to False + Returns: + - next_state (:obj:`tensor` or :obj:`list`): hidden state from lstm + """ + if list_next_state: + h, c = next_state + batch_size = h.shape[1] + next_state = [torch.chunk(h, batch_size, dim=1), torch.chunk(c, batch_size, dim=1)] + return list(zip(*next_state)) + else: + return next_state + + +def get_lstm(lstm_type, input_size, hidden_size, num_layers=1, norm_type='LN', dropout=0.): + r""" + Overview: + build and return the corresponding LSTM cell + Arguments: + - lstm_type (:obj:`str`): version of lstm cell, now support ['normal', 'pytorch'] + - input_size (:obj:`int`): size of the input vector + - hidden_size (:obj:`int`): size of the hidden state vector + - num_layers (:obj:`int`): number of lstm layers + - norm_type (:obj:`str`): type of the normaliztion, (default: None) + - dropout (:obj:float): dropout rate, default set to .0 + Returns: + - lstm (:obj:`LSTM` or :obj:`PytorchLSTM`): the corresponding lstm cell + """ + assert lstm_type in ['normal', 'pytorch'] + if lstm_type == 'normal': + return LSTM(input_size, hidden_size, num_layers, norm_type, dropout=dropout) + elif lstm_type == 'pytorch': + return PytorchLSTM(input_size, hidden_size, num_layers, dropout=dropout) diff --git a/practice/tools/network/scatter_connection.py b/practice/tools/network/scatter_connection.py new file mode 100644 index 0000000..dbb6ab7 --- /dev/null +++ b/practice/tools/network/scatter_connection.py @@ -0,0 +1,107 @@ +from typing import Tuple + +import torch +import torch.nn as nn + + +class ScatterConnection(nn.Module): + r""" + Overview: + Scatter feature to its corresponding location + In alphastar, each entity is embedded into a tensor, these tensors are scattered into a feature map + with map size + """ + + def __init__(self, scatter_type='add') -> None: + r""" + Overview: + Init class + Arguments: + - scatter_type (:obj:`str`): add or cover, if two entities have same location, scatter type decides the + first one should be covered or added to second one + """ + super(ScatterConnection, self).__init__() + self.scatter_type = scatter_type + assert self.scatter_type in ['cover', 'add'] + + def xy_forward(self, x: torch.Tensor, spatial_size: Tuple[int, int], coord_x: torch.Tensor,coord_y) -> torch.Tensor: + device = x.device + BatchSize, Num, EmbeddingSize = x.shape + x = x.permute(0, 2, 1) + H, W = spatial_size + indices = (coord_x * W + coord_y).long() + indices = 
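`get_lstm` is the entry point; the `'pytorch'` variant wraps `nn.LSTM` behind the same `[seq_len, batch, input]` interface and accepts `prev_state=None` for an all-zero initial state. A usage sketch (assuming the repository root is on `sys.path` so the `practice` package imports):

```python
# Driving get_lstm's 'pytorch' backend (toy sizes).
import torch
from practice.tools.network.rnn import get_lstm

seq_len, batch, input_size, hidden_size = 5, 2, 8, 16
lstm = get_lstm('pytorch', input_size, hidden_size, num_layers=1)
x = torch.randn(seq_len, batch, input_size)          # [seq_len, batch, input]
out, (h, c) = lstm(x, prev_state=None, list_next_state=False)
print(out.shape, h.shape)  # torch.Size([5, 2, 16]) torch.Size([1, 2, 16])
```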
indices.unsqueeze(dim=1).repeat(1, EmbeddingSize, 1) + output = torch.zeros(size=(BatchSize, EmbeddingSize, H, W), device=device).view(BatchSize, EmbeddingSize, + H * W) + if self.scatter_type == 'cover': + output.scatter_(dim=2, index=indices, src=x) + elif self.scatter_type == 'add': + output.scatter_add_(dim=2, index=indices, src=x) + output = output.view(BatchSize, EmbeddingSize, H, W) + return output + + def forward(self, x: torch.Tensor, spatial_size: Tuple[int, int], location: torch.Tensor) -> torch.Tensor: + """ + Overview: + scatter x into a spatial feature map + Arguments: + - x (:obj:`tensor`): input tensor :math: `(B, M, N)` where `M` means the number of entity, `N` means\ + the dimension of entity attributes + - spatial_size (:obj:`tuple`): Tuple[H, W], the size of spatial feature x will be scattered into + - location (:obj:`tensor`): :math: `(B, M, 2)` torch.LongTensor, each location should be (y, x) + Returns: + - output (:obj:`tensor`): :math: `(B, N, H, W)` where `H` and `W` are spatial_size, return the\ + scattered feature map + Shapes: + - Input: :math: `(B, M, N)` where `M` means the number of entity, `N` means\ + the dimension of entity attributes + - Size: Tuple[H, W] + - Location: :math: `(B, M, 2)` torch.LongTensor, each location should be (y, x) + - Output: :math: `(B, N, H, W)` where `H` and `W` are spatial_size + + .. note:: + when there are some overlapping in locations, ``cover`` mode will result in the loss of information, we + use the addition as temporal substitute. + """ + device = x.device + BatchSize, Num, EmbeddingSize = x.shape + x = x.permute(0, 2, 1) + H, W = spatial_size + indices = location[:, :, 1] + location[:, :, 0] * W + indices = indices.unsqueeze(dim=1).repeat(1, EmbeddingSize, 1) + output = torch.zeros(size=(BatchSize, EmbeddingSize, H, W), device=device).view(BatchSize, EmbeddingSize, + H * W) + if self.scatter_type == 'cover': + output.scatter_(dim=2, index=indices, src=x) + elif self.scatter_type == 'add': + output.scatter_add_(dim=2, index=indices, src=x) + output = output.view(BatchSize, EmbeddingSize, H, W) + + # device = x.device + # B, M, N = x.shape + # H, W = spatial_size + # index = location.view(-1, 2) + # bias = torch.arange(B).mul_(H * W).unsqueeze(1).repeat(1, M).view(-1).to(device) + # index = index[:, 0] * W + index[:, 1] + # index += bias + # index = index.repeat(N, 1) + # x = x.view(-1, N).permute(1, 0) + # output = torch.zeros(N, B * H * W, device=device) + # if self.scatter_type == 'cover': + # output.scatter_(dim=1, index=index, src=x) + # elif self.scatter_type == 'add': + # output.scatter_add_(dim=1, index=index, src=x) + # output = output.reshape(N, B, H, W) + # output = output.permute(1, 0, 2, 3).contiguous() + + return output + + +if __name__ == '__main__': + scatter_conn = ScatterConnection() + BatchSize, Num, EmbeddingSize = 10, 20, 3 + SpatialSize = (13, 17) + for _ in range(10): + x = torch.randn(size=(BatchSize, Num, EmbeddingSize)) + locations = torch.randint(low=0, high=12, size=(BatchSize, Num, 2)) + scatter_conn.forward(x, SpatialSize, location=locations) diff --git a/practice/tools/network/soft_argmax.py b/practice/tools/network/soft_argmax.py new file mode 100644 index 0000000..a963fd1 --- /dev/null +++ b/practice/tools/network/soft_argmax.py @@ -0,0 +1,60 @@ +""" +Copyright 2020 Sensetime X-lab. All Rights Reserved + +Main Function: + 1. 
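The two scatter modes differ only when entities collide on the same cell: `'add'` sums the overlapping embeddings, while `'cover'` keeps a single write (with duplicate indices, which write survives is not guaranteed). A standalone demonstration with the raw scatter ops:

```python
# 'add' vs 'cover' semantics when two entities share one cell (toy sizes).
import torch

H, W, D = 2, 2, 1
x = torch.tensor([[[1.0], [2.0]]])            # (B=1, Num=2, Embedding=1)
coord_x = torch.tensor([[0, 0]])              # both entities on the same cell
coord_y = torch.tensor([[0, 0]])
indices = (coord_x * W + coord_y).long().unsqueeze(1)   # (B, D, Num)
src = x.permute(0, 2, 1)                      # (B, D, Num)

out = torch.zeros(1, D, H * W)
out.scatter_add_(2, indices, src)             # 'add': overlaps sum to 3.0
print(out.view(1, D, H, W))

out = torch.zeros(1, D, H * W)
out.scatter_(2, indices, src)                 # 'cover': one write survives
print(out.view(1, D, H, W))
```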
diff --git a/practice/tools/network/soft_argmax.py b/practice/tools/network/soft_argmax.py
new file mode 100644
index 0000000..a963fd1
--- /dev/null
+++ b/practice/tools/network/soft_argmax.py
@@ -0,0 +1,60 @@
+"""
+Copyright 2020 Sensetime X-lab. All Rights Reserved
+
+Main Function:
+    1. SoftArgmax: an nn.Module that computes a soft-argmax
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class SoftArgmax(nn.Module):
+    r"""
+    Overview:
+        an nn.Module that computes a soft-argmax, a differentiable approximation of the
+        argmax over a heat map
+
+    Interface:
+        __init__, forward
+    """
+
+    def __init__(self):
+        r"""
+        Overview:
+            initialize the SoftArgmax module
+        """
+        super(SoftArgmax, self).__init__()
+
+    def forward(self, x):
+        r"""
+        Overview:
+            soft-argmax for location regression
+        Arguments:
+            - x (:obj:`Tensor`): predicted heat map
+        Returns:
+            - location (:obj:`Tensor`): predicted location
+        Shapes:
+            - x (:obj:`Tensor`): :math:`(B, C, H, W)`, where B is the batch size,
+              C is the number of channels, and H and W stand for height and width
+            - location (:obj:`Tensor`): :math:`(B, 2)`, where B is the batch size
+        """
+        B, C, H, W = x.shape
+        device, dtype = x.device, x.dtype
+        # only single-channel heat maps are supported
+        assert (x.shape[1] == 1)
+        h_kernel = torch.arange(0, H, device=device).to(dtype)
+        h_kernel = h_kernel.view(1, 1, H, 1).repeat(1, 1, 1, W)
+        w_kernel = torch.arange(0, W, device=device).to(dtype)
+        w_kernel = w_kernel.view(1, 1, 1, W).repeat(1, 1, H, 1)
+        x = F.softmax(x.view(B, C, -1), dim=-1).view(B, C, H, W)
+        h = (x * h_kernel).sum(dim=[1, 2, 3])
+        w = (x * w_kernel).sum(dim=[1, 2, 3])
+        return torch.stack([h, w], dim=1)
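A quick numeric sanity check (a sketch only, with an assumed import path; the sharper the peak, the closer the output is to the hard argmax):

```python
import torch
from practice.tools.network.soft_argmax import SoftArgmax  # assumed path

soft_argmax = SoftArgmax()
heat_map = torch.zeros(1, 1, 13, 17)
heat_map[0, 0, 4, 9] = 10.0        # a sharp peak at (h=4, w=9)
print(soft_argmax(heat_map))       # approximately tensor([[4.0, 9.0]])
```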
+ """ + if dtype is torch.float16: + return -NEAR_INF_FP16 + else: + return -NEAR_INF + + +class MultiHeadAttention(nn.Module): + r""" + Overview: + For each entry embedding, compute individual attention across all entries, add them up to get output attention + """ + + def __init__(self, n_heads: int = None, dim: int = None, dropout: float = 0): + r""" + Overview: + Init attention + Arguments: + - input_dim (:obj:`int`): dimension of input + - head_dim (:obj:`int`): dimension of each head + - output_dim (:obj:`int`): dimension of output + - head_num (:obj:`int`): head num for multihead attention + - dropout (:obj:`nn.Module`): dropout layer + """ + super(MultiHeadAttention, self).__init__() + self.n_heads = n_heads + self.dim = dim + + self.attn_dropout = nn.Dropout(p=dropout) + self.q_lin = nn.Linear(dim, dim) + self.k_lin = nn.Linear(dim, dim) + self.v_lin = nn.Linear(dim, dim) + + # TODO: merge for the initialization step + nn.init.xavier_normal_(self.q_lin.weight) + nn.init.xavier_normal_(self.k_lin.weight) + nn.init.xavier_normal_(self.v_lin.weight) + self.out_lin = nn.Linear(dim, dim) + nn.init.xavier_normal_(self.out_lin.weight) + + # self.attention_pre = fc_block(self.dim, self.dim * 3) # query, key, value + # self.project = fc_block(self.dim,self.dim) + + def split(self, x, T=False): + r""" + Overview: + Split input to get multihead queries, keys, values + Arguments: + - x (:obj:`tensor`): query or key or value + - T (:obj:`bool`): whether to transpose output + Returns: + - x (:obj:`list`): list of output tensors for each head + """ + B, N = x.shape[:2] + x = x.view(B, N, self.head_num, self.head_dim) + x = x.permute(0, 2, 1, 3).contiguous() # B, head_num, N, head_dim + if T: + x = x.permute(0, 1, 3, 2).contiguous() + return x + + def forward(self, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + value: Optional[torch.Tensor] = None, + mask: torch.Tensor = None, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + batch_size, query_len, dim = query.size() + assert ( + dim == self.dim + ), 'Dimensions do not match: {} query vs {} configured'.format(dim, self.dim) + assert mask is not None, 'Mask is None, please specify a mask' + n_heads = self.n_heads + dim_per_head = dim // n_heads + scale = math.sqrt(dim_per_head) + + def prepare_head(tensor): + # input is [batch_size, seq_len, n_heads * dim_per_head] + # output is [batch_size * n_heads, seq_len, dim_per_head] + bsz, seq_len, _ = tensor.size() + tensor = tensor.view(batch_size, tensor.size(1), n_heads, dim_per_head) + tensor = ( + tensor.transpose(1, 2) + .contiguous() + .view(batch_size * n_heads, seq_len, dim_per_head) + ) + return tensor + + # q, k, v are the transformed values + if key is None and value is None: + # self attention + key = value = query + _, _key_len, dim = query.size() + elif value is None: + # key and value are the same, but query differs + # self attention + value = key + + assert key is not None # let mypy know we sorted this + _, _key_len, dim = key.size() + + q = prepare_head(self.q_lin(query)) + k = prepare_head(self.k_lin(key)) + v = prepare_head(self.v_lin(value)) + full_key_len = k.size(1) + dot_prod = q.div_(scale).bmm(k.transpose(1, 2)) + # [B * n_heads, query_len, key_len] + attn_mask = ( + (mask == 0) + .view(batch_size, 1, -1, full_key_len) + .repeat(1, n_heads, 1, 1) + .expand(batch_size, n_heads, query_len, full_key_len) + .view(batch_size * n_heads, query_len, full_key_len) + ) + assert attn_mask.shape == dot_prod.shape + dot_prod.masked_fill_(attn_mask, 
+
+
+class TransformerFFN(nn.Module):
+    """
+    Implements the FFN part of the transformer.
+    """
+
+    def __init__(
+        self,
+        dim: int = None,
+        dim_hidden: int = None,
+        dropout: float = 0,
+        activation: str = 'relu',
+        **kwargs,
+    ):
+        super(TransformerFFN, self).__init__(**kwargs)
+        self.dim = dim
+        self.dim_hidden = dim_hidden
+        self.dropout_ratio = dropout
+        self.relu_dropout = nn.Dropout(p=self.dropout_ratio)
+        if activation == 'relu':
+            self.nonlinear = F.relu
+        elif activation == 'gelu':
+            self.nonlinear = F.gelu
+        else:
+            raise ValueError(
+                "Don't know how to handle --activation {}".format(activation)
+            )
+        self.lin1 = nn.Linear(self.dim, self.dim_hidden)
+        self.lin2 = nn.Linear(self.dim_hidden, self.dim)
+        nn.init.xavier_uniform_(self.lin1.weight)
+        nn.init.xavier_uniform_(self.lin2.weight)
+        # TODO: initialize biases to 0
+
+    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
+        """
+        Forward pass: linear -> activation -> dropout -> linear, applied position-wise.
+        """
+        x = self.nonlinear(self.lin1(x))
+        x = self.relu_dropout(x)  # --relu-dropout
+        x = self.lin2(x)
+        return x
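+
+# Usage sketch (illustrative values): the FFN preserves the input shape.
+#   ffn = TransformerFFN(dim=128, dim_hidden=512, activation='gelu')
+#   y = ffn(torch.randn(2, 10, 128))   # y: (2, 10, 128)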
+ """ + x = self.nonlinear(self.lin1(x)) + x = self.relu_dropout(x) # --relu-dropout + x = self.lin2(x) + return x + + +class TransformerLayer(nn.Module): + r""" + Overview: + In transformer layer, first computes entries's attention and applies a feedforward layer + """ + + def __init__(self, + n_heads: int = None, + embedding_size: int = None, + ffn_size: int = None, + attention_dropout: float = 0.0, + relu_dropout: float = 0.0, + dropout: float = 0.0, + activation: str = 'relu', + variant: Optional[str] = None, + ): + r""" + Overview: + Init transformer layer + Arguments: + - input_dim (:obj:`int`): dimension of input + - head_dim (:obj:`int`): dimension of each head + - hidden_dim (:obj:`int`): dimension of hidden layer in mlp + - output_dim (:obj:`int`): dimension of output + - head_num (:obj:`int`): number of heads for multihead attention + - mlp_num (:obj:`int`): number of mlp layers + - dropout (:obj:`nn.Module`): dropout layer + - activation (:obj:`nn.Module`): activation function + """ + super(TransformerLayer, self).__init__() + self.n_heads = n_heads + self.dim = embedding_size + self.ffn_dim = ffn_size + self.activation = activation + self.variant = variant + self.attention = MultiHeadAttention( + n_heads=self.n_heads, + dim=embedding_size, + dropout=attention_dropout) + self.norm1 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS) + self.ffn = TransformerFFN(dim=embedding_size, + dim_hidden=ffn_size, + dropout=relu_dropout, + activation=activation, + ) + self.norm2 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, mask): + """ + Overview: + transformer layer forward + Arguments: + - inputs (:obj:`tuple`): x and mask + Returns: + - output (:obj:`tuple`): x and mask + """ + residual = x + + if self.variant == 'prenorm': + x = self.norm1(x) + attended_tensor = self.attention(x, mask=mask)[0] + x = residual + self.dropout(attended_tensor) + if self.variant == 'postnorm': + x = self.norm1(x) + + residual = x + if self.variant == 'prenorm': + x = self.norm2(x) + x = residual + self.dropout(self.ffn(x)) + if self.variant == 'postnorm': + x = self.norm2(x) + + x *= mask.unsqueeze(-1).type_as(x) + return x + + +class Transformer(nn.Module): + ''' + Overview: + Transformer implementation + + Note: + For details refer to Attention is all you need: http://arxiv.org/abs/1706.03762 + ''' + + def __init__( + self, + n_heads=8, + embedding_size: int = 128, + ffn_size: int = 128, + n_layers: int = 3, + attention_dropout: float = 0.0, + relu_dropout: float = 0.0, + dropout: float = 0.0, + activation: Optional[str] = 'relu', + variant: Optional[str] = 'prenorm', + ): + r""" + Overview: + Init transformer + Arguments: + - input_dim (:obj:`int`): dimension of input + - head_dim (:obj:`int`): dimension of each head + - hidden_dim (:obj:`int`): dimension of hidden layer in mlp + - output_dim (:obj:`int`): dimension of output + - head_num (:obj:`int`): number of heads for multihead attention + - mlp_num (:obj:`int`): number of mlp layers + - layer_num (:obj:`int`): number of transformer layers + - dropout_ratio (:obj:`float`): dropout ratio + - activation (:obj:`nn.Module`): activation function + """ + super(Transformer, self).__init__() + self.n_heads = n_heads + self.dim = embedding_size + self.ffn_size = ffn_size + self.n_layers = n_layers + + self.dropout_ratio = dropout + self.attention_dropout = attention_dropout + self.relu_dropout = relu_dropout + self.activation = activation + self.variant = variant + + # build the 
+
+
+class Transformer(nn.Module):
+    '''
+    Overview:
+        Transformer implementation
+
+    Note:
+        For details refer to Attention is All You Need: http://arxiv.org/abs/1706.03762
+    '''
+
+    def __init__(
+        self,
+        n_heads=8,
+        embedding_size: int = 128,
+        ffn_size: int = 128,
+        n_layers: int = 3,
+        attention_dropout: float = 0.0,
+        relu_dropout: float = 0.0,
+        dropout: float = 0.0,
+        activation: Optional[str] = 'relu',
+        variant: Optional[str] = 'prenorm',
+    ):
+        r"""
+        Overview:
+            Init transformer
+        Arguments:
+            - n_heads (:obj:`int`): number of heads for multi-head attention
+            - embedding_size (:obj:`int`): dimension of the entry embeddings
+            - ffn_size (:obj:`int`): hidden dimension of the feed-forward network
+            - n_layers (:obj:`int`): number of transformer layers
+            - attention_dropout (:obj:`float`): dropout ratio for the attention weights
+            - relu_dropout (:obj:`float`): dropout ratio inside the feed-forward network
+            - dropout (:obj:`float`): dropout ratio applied to each residual branch
+            - activation (:obj:`str`): activation function, 'relu' or 'gelu'
+            - variant (:obj:`str`): 'prenorm' or 'postnorm', where the layer norm is applied
+        """
+        super(Transformer, self).__init__()
+        self.n_heads = n_heads
+        self.dim = embedding_size
+        self.ffn_size = ffn_size
+        self.n_layers = n_layers
+
+        self.dropout_ratio = dropout
+        self.attention_dropout = attention_dropout
+        self.relu_dropout = relu_dropout
+        self.activation = activation
+        self.variant = variant
+
+        # build the model
+        self.layers = self.build_layers()
+        self.norm_embedding = torch.nn.LayerNorm(self.dim, eps=LAYER_NORM_EPS)
+
+    def build_layers(self) -> nn.ModuleList:
+        layers = nn.ModuleList()
+        for _ in range(self.n_layers):
+            layer = TransformerLayer(
+                n_heads=self.n_heads,
+                embedding_size=self.dim,
+                ffn_size=self.ffn_size,
+                attention_dropout=self.attention_dropout,
+                relu_dropout=self.relu_dropout,
+                dropout=self.dropout_ratio,
+                variant=self.variant,
+                activation=self.activation,
+            )
+            layers.append(layer)
+        return layers
+
+    def forward(self, x, mask=None):
+        r"""
+        Overview:
+            Transformer forward
+        Arguments:
+            - x (:obj:`tensor`): input tensor, shape (B, N, C), B is batch size, N is the number of entries,
+              C is the feature dimension
+            - mask (:obj:`tensor` or :obj:`None`): bool tensor of shape (B, N), used to mask out invalid
+              entries in attention
+        Returns:
+            - x (:obj:`tensor`): transformer output
+        """
+        if self.variant == 'postnorm':
+            x = self.norm_embedding(x)
+        if mask is not None:
+            x *= mask.unsqueeze(-1).type_as(x)
+        else:
+            mask = torch.ones(size=x.shape[:2], dtype=torch.bool, device=x.device)
+        for i in range(self.n_layers):
+            x = self.layers[i](x, mask)
+        if self.variant == 'prenorm':
+            x = self.norm_embedding(x)
+        return x
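Before moving on to the utilities, a usage sketch of the full stack defined in this file (a sketch only, with an assumed import path and illustrative sizes):

```python
import torch
from practice.tools.network.transformer import Transformer  # assumed path

model = Transformer(n_heads=8, embedding_size=128, ffn_size=128, n_layers=3, variant='prenorm')
x = torch.randn(4, 10, 128)                 # (B, N, C): 10 entities per sample
mask = torch.ones(4, 10, dtype=torch.bool)
mask[:, 7:] = False                         # the last three entities are padding
out = model(x, mask)
print(out.shape)                            # torch.Size([4, 10, 128])
```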
diff --git a/practice/tools/util.py b/practice/tools/util.py
new file mode 100644
index 0000000..a512dd7
--- /dev/null
+++ b/practice/tools/util.py
@@ -0,0 +1,184 @@
+import copy
+import os
+import re
+from typing import Optional, List
+
+import yaml
+from easydict import EasyDict
+
+import torch
+import collections.abc as container_abcs
+
+try:
+    from torch._six import string_classes
+except ImportError:
+    # torch._six was removed in recent PyTorch releases
+    string_classes = str
+# from torch._six import int_classes as _int_classes
+int_classes = int
+
+np_str_obj_array_pattern = re.compile(r'[SaUO]')
+
+default_collate_err_msg_format = (
+    "default_collate: batch must contain tensors, numpy arrays, numbers, "
+    "dicts or lists; found {}"
+)
+
+
+def read_config(path: str) -> EasyDict:
+    """
+    Overview:
+        read configuration from a path
+    Arguments:
+        - path (:obj:`str`): Path of the source yaml
+    Returns:
+        - (:obj:`EasyDict`): Config data from this file with dict type
+    """
+    if path:
+        assert os.path.exists(path), path
+        with open(path, "r") as f:
+            config = yaml.safe_load(f)
+    else:
+        config = {}
+    return EasyDict(config)
+
+
+def deep_merge_dicts(original: dict, new_dict: dict) -> dict:
+    """
+    Overview:
+        merge two dicts using deep_update
+    Arguments:
+        - original (:obj:`dict`): Dict 1.
+        - new_dict (:obj:`dict`): Dict 2.
+    Returns:
+        - (:obj:`dict`): A new dict that is original and new_dict deeply merged.
+    """
+    original = original or {}
+    new_dict = new_dict or {}
+    merged = copy.deepcopy(original)
+    if new_dict:  # if new_dict is neither an empty dict nor None
+        deep_update(merged, new_dict, True, [])
+
+    return merged
+
+
+def deep_update(
+    original: dict,
+    new_dict: dict,
+    new_keys_allowed: bool = False,
+    whitelist: Optional[List[str]] = None,
+    override_all_if_type_changes: Optional[List[str]] = None
+):
+    """
+    Overview:
+        Updates original dict with values from new_dict recursively.
+
+    .. note::
+
+        If a new key is introduced in new_dict, then if new_keys_allowed is not
+        True, an error will be thrown. Further, for sub-dicts, if the key is
+        in the whitelist, then new subkeys can be introduced.
+
+    Arguments:
+        - original (:obj:`dict`): Dictionary with default values.
+        - new_dict (:obj:`dict`): Dictionary with values to be updated.
+        - new_keys_allowed (:obj:`bool`): Whether new keys are allowed.
+        - whitelist (Optional[List[str]]): List of keys that correspond to dict
+          values where new subkeys can be introduced. This is only at the top
+          level.
+        - override_all_if_type_changes (Optional[List[str]]): List of top-level
+          keys with value=dict, for which we always simply override the
+          entire value (:obj:`dict`), if the "type" key in that value dict changes.
+    """
+    whitelist = whitelist or []
+    override_all_if_type_changes = override_all_if_type_changes or []
+
+    for k, value in new_dict.items():
+        if k not in original and not new_keys_allowed:
+            raise RuntimeError("Unknown config parameter `{}`. Base config has: {}.".format(k, original.keys()))
+
+        # Both the original value and the new one are dicts.
+        if isinstance(original.get(k), dict) and isinstance(value, dict):
+            # Check the old type vs the new one. If different, override the entire value.
+            if k in override_all_if_type_changes and \
+                    "type" in value and "type" in original[k] and \
+                    value["type"] != original[k]["type"]:
+                original[k] = value
+            # Whitelisted key -> ok to add new subkeys.
+            elif k in whitelist:
+                deep_update(original[k], value, True)
+            # Non-whitelisted key.
+            else:
+                deep_update(original[k], value, new_keys_allowed)
+        # Original value not a dict OR new value not a dict:
+        # Override the entire value.
+        else:
+            original[k] = value
+    return original
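+
+# Usage sketch (illustrative values):
+#   defaults = {'model': {'embedding_size': 128, 'n_layers': 3}, 'seed': 0}
+#   merged = deep_merge_dicts(defaults, {'model': {'n_layers': 6}})
+#   # -> {'model': {'embedding_size': 128, 'n_layers': 6}, 'seed': 0}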
+
+
+def default_collate_with_dim(batch, device='cpu', dim=0, k=None, cat=False):
+    r"""Puts each data field into a tensor stacked (or, with ``cat=True``, concatenated) along ``dim``"""
+    elem = batch[0]
+    elem_type = type(elem)
+
+    if isinstance(elem, torch.Tensor):
+        out = None
+        if torch.utils.data.get_worker_info() is not None:
+            # If we're in a background process, concatenate directly into a
+            # shared memory tensor to avoid an extra copy
+            numel = sum([x.numel() for x in batch])
+            storage = elem.storage()._new_shared(numel)
+            out = elem.new(storage)
+        try:
+            if cat:
+                return torch.cat(batch, dim=dim, out=out).to(device=device)
+            else:
+                return torch.stack(batch, dim=dim, out=out).to(device=device)
+        except Exception:
+            # print the offending batch (and key, if given) before re-raising
+            print(batch)
+            if k is not None:
+                print(k)
+            raise
+    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
+            and elem_type.__name__ != 'string_':
+        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
+            # array of string classes and object
+            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
+                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
+
+            return default_collate_with_dim([torch.as_tensor(b, device=device) for b in batch],
+                                            device=device, dim=dim, cat=cat)
+        elif elem.shape == ():  # scalars
+            return torch.as_tensor(batch, device=device)
+    elif isinstance(elem, float):
+        return torch.tensor(batch, device=device)
+    elif isinstance(elem, int_classes):
+        return torch.tensor(batch, device=device)
+    elif isinstance(elem, string_classes):
+        return batch
+    elif isinstance(elem, container_abcs.Mapping):
+        return {key: default_collate_with_dim([d[key] for d in batch if key in d.keys()],
                                               device=device, dim=dim, k=key, cat=cat) for key in elem}
+    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
+        return elem_type(*(default_collate_with_dim(samples, device=device, dim=dim, cat=cat) for samples in zip(*batch)))
+    elif isinstance(elem, container_abcs.Sequence):
+        # check to make sure that the elements in batch have consistent size
+        it = iter(batch)
+        elem_size = len(next(it))
+        if not all(len(elem) == elem_size for elem in it):
+            raise RuntimeError('each element in list of batch should be of equal size')
+        transposed = zip(*batch)
+        return [default_collate_with_dim(samples, device=device, dim=dim, cat=cat) for samples in transposed]
+
+    raise TypeError(default_collate_err_msg_format.format(elem_type))
\ No newline at end of file

From 10c656c935a47e24598c09c2fb1d68bf14f7eaa3 Mon Sep 17 00:00:00 2001
From: Norman <74552232+TuTuHuss@users.noreply.github.com>
Date: Wed, 30 Aug 2023 22:07:02 +0800
Subject: [PATCH 2/3] doc(hus): update download weight online link

---
 practice/README.md | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/practice/README.md b/practice/README.md
index b7758ec..612d5e7 100644
--- a/practice/README.md
+++ b/practice/README.md
@@ -3,7 +3,8 @@ We offer multiple gameplay modes for players to enjoy, including development, ba
 
 ### Download Weight
 ```bash
-
+wget https://opendilab.net/download/GoBigger/solo_agent.pth.tar
+wget https://opendilab.net/download/GoBigger/team.pth.tar
 ```
 ### Quick Start
 
@@ -22,4 +23,4 @@ python battle.py --mode team --map farm # Two-player development
 python battle.py --mode team --map vsbot # Two-player vs. Bot
 python battle.py --mode team --map vsai # Two-player vs. Team-AI
 python battle.py --mode watch # Spectator mode: Team-AI vs. Team-AI
-``` \ No newline at end of file
+```

From 1de39c2fa1161fe8519e3cfc5e827b36bf29dceb Mon Sep 17 00:00:00 2001
From: Norman <74552232+TuTuHuss@users.noreply.github.com>
Date: Wed, 30 Aug 2023 23:29:09 +0800
Subject: [PATCH 3/3] update cooperative agent path

---
 practice/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/practice/README.md b/practice/README.md
index 612d5e7..842d32c 100644
--- a/practice/README.md
+++ b/practice/README.md
@@ -4,7 +4,7 @@ We offer multiple gameplay modes for players to enjoy, including development, ba
 ### Download Weight
 ```bash
 wget https://opendilab.net/download/GoBigger/solo_agent.pth.tar
-wget https://opendilab.net/download/GoBigger/team.pth.tar
+wget https://opendilab.net/download/GoBigger/cooperative_agent.pth.tar
 ```
 ### Quick Start