Files
2022-11-28 15:27:10 -05:00

210 lines
8.8 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Run a pre-trained model\n",
"\n",
"This notebook loads a pre-trained model and uses it to play games. \n",
"Note that it does not render the image of the game, it just prints out the episodic score. "
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"pygame 2.1.2 (SDL 2.0.18, Python 3.10.8)\n",
"Hello from the pygame community. https://www.pygame.org/contribute.html\n"
]
}
],
"source": [
"# sanity check: can we create breakwall?\n",
"import gym\n",
"e = gym.make('gym_gs:BreakwallNoFrameskip-v1')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting git+https://github.com/openai/baselines.git\n",
" Cloning https://github.com/openai/baselines.git to c:\\users\\gofor\\appdata\\local\\temp\\pip-req-build-s405pyio\n",
" Resolved https://github.com/openai/baselines.git to commit ea25b9e8b234e6ee1bca43083f8f3cf974143998\n",
" Preparing metadata (setup.py): started\n",
" Preparing metadata (setup.py): finished with status 'done'\n",
"Requirement already satisfied: gym<0.16.0,>=0.15.4 in c:\\users\\gofor\\myvenv\\lib\\site-packages (from baselines==0.1.6) (0.15.7)\n",
"Requirement already satisfied: scipy in c:\\users\\gofor\\myvenv\\lib\\site-packages (from baselines==0.1.6) (1.9.3)\n",
"Requirement already satisfied: tqdm in c:\\users\\gofor\\myvenv\\lib\\site-packages (from baselines==0.1.6) (4.64.1)\n",
"Requirement already satisfied: joblib in c:\\users\\gofor\\myvenv\\lib\\site-packages (from baselines==0.1.6) (1.2.0)\n",
"Requirement already satisfied: cloudpickle in c:\\users\\gofor\\myvenv\\lib\\site-packages (from baselines==0.1.6) (1.2.2)\n",
"Requirement already satisfied: click in c:\\users\\gofor\\myvenv\\lib\\site-packages (from baselines==0.1.6) (8.1.3)\n",
"Requirement already satisfied: opencv-python in c:\\users\\gofor\\myvenv\\lib\\site-packages (from baselines==0.1.6) (4.6.0.66)\n",
"Requirement already satisfied: numpy>=1.10.4 in c:\\users\\gofor\\myvenv\\lib\\site-packages (from gym<0.16.0,>=0.15.4->baselines==0.1.6) (1.23.5)\n",
"Requirement already satisfied: pyglet<=1.5.0,>=1.4.0 in c:\\users\\gofor\\myvenv\\lib\\site-packages (from gym<0.16.0,>=0.15.4->baselines==0.1.6) (1.5.0)\n",
"Requirement already satisfied: six in c:\\users\\gofor\\myvenv\\lib\\site-packages (from gym<0.16.0,>=0.15.4->baselines==0.1.6) (1.16.0)\n",
"Requirement already satisfied: colorama in c:\\users\\gofor\\myvenv\\lib\\site-packages (from click->baselines==0.1.6) (0.4.6)\n",
"Requirement already satisfied: future in c:\\users\\gofor\\myvenv\\lib\\site-packages (from pyglet<=1.5.0,>=1.4.0->gym<0.16.0,>=0.15.4->baselines==0.1.6) (0.18.2)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" Running command git clone --filter=blob:none --quiet https://github.com/openai/baselines.git 'C:\\Users\\gofor\\AppData\\Local\\Temp\\pip-req-build-s405pyio'\n",
"\n",
"[notice] A new release of pip available: 22.2.2 -> 22.3.1\n",
"[notice] To update, run: python.exe -m pip install --upgrade pip\n"
]
}
],
"source": [
"# install baselines and other stuff\n",
"!pip install git+https://github.com/openai/baselines.git"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded gym\n",
"Model weights look loadable ./pre-trained/mac_hard_breakwall/gym_gs_BreakwallNoFrameskip-v1_20211018-114642_5424.data-00000-of-00001\n",
"Model loaded weights - starting sim\n",
"Game over at frame 278 rew 2.0 rewards/frame: 0.007194244604316547\n",
"Game over at frame 453 rew 3.0 rewards/frame: 0.006622516556291391\n",
"Game over at frame 631 rew 4.0 rewards/frame: 0.006339144215530904\n",
"Game over at frame 906 rew 6.0 rewards/frame: 0.006622516556291391\n",
"Game over at frame 976 rew 6.0 rewards/frame: 0.006147540983606557\n",
"Sim ended : rew is 6.0\n"
]
}
],
"source": [
"## full check - can we use the full opencv/ openai version \n",
"## of the gym?\n",
"\n",
"# Script to test a pre-trained model\n",
"# Written by Matthew Yee-King\n",
"# MIT license \n",
"# https://mit-license.org/\n",
"\n",
"import sys\n",
"import os\n",
"from baselines.common.atari_wrappers import make_atari, wrap_deepmind\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"import datetime\n",
"import random\n",
"import time \n",
"\n",
"env_name = \"gym_gs:BreakwallNoFrameskip-v1\" \n",
"# for notebook users - make sure you have uploaded your pre-trained\n",
"# models... then adjust this to reflect the file path\n",
"model_file = \"./pre-trained/mac_hard_breakwall/gym_gs_BreakwallNoFrameskip-v1_20211018-114642_5424\"\n",
"\n",
"def create_q_model(num_actions):\n",
" # Network defined by the Deepmind paper\n",
" inputs = layers.Input(shape=(84, 84, 4,))\n",
" # Convolutions on the frames on the screen\n",
" layer1 = layers.Conv2D(32, 8, strides=4, activation=\"relu\")(inputs) \n",
" layer2 = layers.Conv2D(64, 4, strides=2, activation=\"relu\")(layer1)\n",
" layer3 = layers.Conv2D(64, 3, strides=1, activation=\"relu\")(layer2)\n",
" layer4 = layers.Flatten()(layer3)\n",
" layer5 = layers.Dense(512, activation=\"relu\")(layer4) \n",
" action = layers.Dense(num_actions, activation=\"linear\")(layer5) \n",
" return keras.Model(inputs=inputs, outputs=action)\n",
"\n",
"def create_env(env_name, seed=42):\n",
" try:\n",
" # Use the Baseline Atari environment because of Deepmind helper functions\n",
" env = make_atari(env_name)\n",
" # Warp the frames, grey scale, stake four frame and scale to smaller ratio\n",
" env = wrap_deepmind(env, frame_stack=True, scale=True)\n",
" print(\"Loaded gym\")\n",
" env.seed(seed)\n",
" return env\n",
" except:\n",
" print(\"Failed to make gym env\", env_name)\n",
" return None\n",
"\n",
"def run_sim(env, model, frame_count):\n",
" state = np.array(env.reset())\n",
" total_reward = 0\n",
" for i in range(frame_count):\n",
" # in the notebook version we cannot really \n",
" # render in realtime, so you just have\n",
" # to check the score :( \n",
" env.render('human')\n",
" state_tensor = keras.backend.constant(state)\n",
" state_tensor = keras.backend.expand_dims(state_tensor, 0)\n",
" action_values = model(state_tensor, training=False)\n",
" # Take best action\n",
" action = keras.backend.argmax(action_values[0]).numpy()\n",
" state, reward, done, _ = env.step(action)\n",
" state = np.array(state)\n",
" total_reward += reward\n",
" if done:\n",
" print(\"Game over at frame\", i, \"rew\", total_reward, \"rewards/frame: \", total_reward/i)\n",
" env.reset()\n",
" #break\n",
" #time.sleep(0.1)\n",
" print(\"Sim ended : rew is \", total_reward)\n",
"\n",
"def main(env_name, model_file,frame_count=1000, seed=42):\n",
" env = create_env(env_name=env_name)\n",
" assert env is not None, \"Failed to make env \" + env_name\n",
" model = create_q_model(num_actions=env.action_space.n)\n",
" model_testfile = model_file + \".data-00000-of-00001\"\n",
" assert os.path.exists(model_testfile), \"Failed to load model: \" + model_testfile\n",
" print(\"Model weights look loadable\", model_testfile)\n",
" model.load_weights(model_file)\n",
" print(\"Model loaded weights - starting sim\")\n",
" run_sim(env, model, frame_count)\n",
" \n",
"main(env_name, model_file, frame_count=1000)\n",
"\n",
"# LEV"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
"nbformat_minor": 4
}