210 lines
8.8 KiB
Plaintext
210 lines
8.8 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Run a pre-trained model\n",
|
|
"\n",
|
|
"This notebook loads a pre-trained model and uses it to play games. \n",
|
"Note that it does not render the image of the game; it just prints the cumulative score each time an episode ends."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"pygame 2.1.2 (SDL 2.0.18, Python 3.10.8)\n",
|
|
"Hello from the pygame community. https://www.pygame.org/contribute.html\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
"# Sanity check: confirm that gym can construct the Breakwall environment.\n",
"import gym\n",
"\n",
"e = gym.make(\"gym_gs:BreakwallNoFrameskip-v1\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Collecting git+https://github.com/openai/baselines.git\n",
|
|
" Cloning https://github.com/openai/baselines.git to c:\\users\\gofor\\appdata\\local\\temp\\pip-req-build-s405pyio\n",
|
|
" Resolved https://github.com/openai/baselines.git to commit ea25b9e8b234e6ee1bca43083f8f3cf974143998\n",
|
|
" Preparing metadata (setup.py): started\n",
|
|
" Preparing metadata (setup.py): finished with status 'done'\n",
|
|
"Requirement already satisfied: gym<0.16.0,>=0.15.4 in c:\\users\\gofor\\myvenv\\lib\\site-packages (from baselines==0.1.6) (0.15.7)\n",
|
|
"Requirement already satisfied: scipy in c:\\users\\gofor\\myvenv\\lib\\site-packages (from baselines==0.1.6) (1.9.3)\n",
|
|
"Requirement already satisfied: tqdm in c:\\users\\gofor\\myvenv\\lib\\site-packages (from baselines==0.1.6) (4.64.1)\n",
|
|
"Requirement already satisfied: joblib in c:\\users\\gofor\\myvenv\\lib\\site-packages (from baselines==0.1.6) (1.2.0)\n",
|
|
"Requirement already satisfied: cloudpickle in c:\\users\\gofor\\myvenv\\lib\\site-packages (from baselines==0.1.6) (1.2.2)\n",
|
|
"Requirement already satisfied: click in c:\\users\\gofor\\myvenv\\lib\\site-packages (from baselines==0.1.6) (8.1.3)\n",
|
|
"Requirement already satisfied: opencv-python in c:\\users\\gofor\\myvenv\\lib\\site-packages (from baselines==0.1.6) (4.6.0.66)\n",
|
|
"Requirement already satisfied: numpy>=1.10.4 in c:\\users\\gofor\\myvenv\\lib\\site-packages (from gym<0.16.0,>=0.15.4->baselines==0.1.6) (1.23.5)\n",
|
|
"Requirement already satisfied: pyglet<=1.5.0,>=1.4.0 in c:\\users\\gofor\\myvenv\\lib\\site-packages (from gym<0.16.0,>=0.15.4->baselines==0.1.6) (1.5.0)\n",
|
|
"Requirement already satisfied: six in c:\\users\\gofor\\myvenv\\lib\\site-packages (from gym<0.16.0,>=0.15.4->baselines==0.1.6) (1.16.0)\n",
|
|
"Requirement already satisfied: colorama in c:\\users\\gofor\\myvenv\\lib\\site-packages (from click->baselines==0.1.6) (0.4.6)\n",
|
|
"Requirement already satisfied: future in c:\\users\\gofor\\myvenv\\lib\\site-packages (from pyglet<=1.5.0,>=1.4.0->gym<0.16.0,>=0.15.4->baselines==0.1.6) (0.18.2)\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" Running command git clone --filter=blob:none --quiet https://github.com/openai/baselines.git 'C:\\Users\\gofor\\AppData\\Local\\Temp\\pip-req-build-s405pyio'\n",
|
|
"\n",
|
|
"[notice] A new release of pip available: 22.2.2 -> 22.3.1\n",
|
|
"[notice] To update, run: python.exe -m pip install --upgrade pip\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
"# Install OpenAI baselines into the environment of the running kernel.\n",
"# %pip (rather than !pip) targets the kernel's own interpreter, and the install is\n",
"# pinned to the commit this notebook was last run against so re-runs are reproducible.\n",
"%pip install git+https://github.com/openai/baselines.git@ea25b9e8b234e6ee1bca43083f8f3cf974143998"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Loaded gym\n",
|
|
"Model weights look loadable ./pre-trained/mac_hard_breakwall/gym_gs_BreakwallNoFrameskip-v1_20211018-114642_5424.data-00000-of-00001\n",
|
|
"Model loaded weights - starting sim\n",
|
|
"Game over at frame 278 rew 2.0 rewards/frame: 0.007194244604316547\n",
|
|
"Game over at frame 453 rew 3.0 rewards/frame: 0.006622516556291391\n",
|
|
"Game over at frame 631 rew 4.0 rewards/frame: 0.006339144215530904\n",
|
|
"Game over at frame 906 rew 6.0 rewards/frame: 0.006622516556291391\n",
|
|
"Game over at frame 976 rew 6.0 rewards/frame: 0.006147540983606557\n",
|
|
"Sim ended : rew is 6.0\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"## full check - can we use the full opencv/ openai version \n",
|
|
"## of the gym?\n",
|
|
"\n",
|
|
"# Script to test a pre-trained model\n",
|
|
"# Written by Matthew Yee-King\n",
|
|
"# MIT license \n",
|
|
"# https://mit-license.org/\n",
|
|
"\n",
|
|
"import sys\n",
|
|
"import os\n",
|
|
"from baselines.common.atari_wrappers import make_atari, wrap_deepmind\n",
|
|
"import numpy as np\n",
|
|
"import tensorflow as tf\n",
|
|
"from tensorflow import keras\n",
|
|
"from tensorflow.keras import layers\n",
|
|
"import datetime\n",
|
|
"import random\n",
|
|
"import time \n",
|
|
"\n",
|
"# Id of the environment to create.\n",
"env_name = \"gym_gs:BreakwallNoFrameskip-v1\"\n",
"# Notebook users: upload the pre-trained model files first, then set model_file\n",
"# to the checkpoint path prefix (i.e. without the \".data-00000-of-00001\" suffix).\n",
"model_file = \"./pre-trained/mac_hard_breakwall/gym_gs_BreakwallNoFrameskip-v1_20211018-114642_5424\"\n",
|
|
"\n",
|
"def create_q_model(num_actions):\n",
"    \"\"\"Build the convolutional Q-network (architecture from the Deepmind paper).\n",
"\n",
"    The model takes an (84, 84, 4) input, applies three ReLU convolutions,\n",
"    flattens, applies a 512-unit ReLU dense layer and ends in a linear dense\n",
"    layer with one output per action.\n",
"\n",
"    Args:\n",
"        num_actions: number of units in the final linear layer.\n",
"\n",
"    Returns:\n",
"        A keras.Model mapping the input frames to num_actions values.\n",
"    \"\"\"\n",
"    frames = layers.Input(shape=(84, 84, 4))\n",
"    # (filters, kernel size, stride) for each convolution, applied in this order\n",
"    conv_specs = ((32, 8, 4), (64, 4, 2), (64, 3, 1))\n",
"    features = frames\n",
"    for filters, kernel, stride in conv_specs:\n",
"        features = layers.Conv2D(filters, kernel, strides=stride, activation=\"relu\")(features)\n",
"    flat = layers.Flatten()(features)\n",
"    hidden = layers.Dense(512, activation=\"relu\")(flat)\n",
"    q_values = layers.Dense(num_actions, activation=\"linear\")(hidden)\n",
"    return keras.Model(inputs=frames, outputs=q_values)\n",
|
|
"\n",
|
"def create_env(env_name, seed=42):\n",
"    \"\"\"Create the wrapped and seeded gym environment.\n",
"\n",
"    Args:\n",
"        env_name: environment id passed to baselines' make_atari.\n",
"        seed: value passed to env.seed().\n",
"\n",
"    Returns:\n",
"        The wrapped environment, or None if it could not be created\n",
"        (the failure reason is printed).\n",
"    \"\"\"\n",
"    try:\n",
"        # Use the Baseline Atari environment because of Deepmind helper functions\n",
"        env = make_atari(env_name)\n",
"        # Warp the frames, grey scale, stack four frames and scale to smaller ratio\n",
"        env = wrap_deepmind(env, frame_stack=True, scale=True)\n",
"        print(\"Loaded gym\")\n",
"        env.seed(seed)\n",
"        return env\n",
"    except Exception as err:\n",
"        # Catch Exception instead of using a bare except so that KeyboardInterrupt\n",
"        # and SystemExit still propagate, and report why creation failed.\n",
"        print(\"Failed to make gym env\", env_name, \"-\", repr(err))\n",
"        return None\n",
|
|
"\n",
|
"def run_sim(env, model, frame_count):\n",
"    \"\"\"Step the environment for frame_count frames, always taking the model's best action.\n",
"\n",
"    Each time an episode ends, the frame index, the reward accumulated so far\n",
"    (it is not reset between episodes) and the reward per frame are printed,\n",
"    and the environment is reset. The total reward is printed at the end.\n",
"\n",
"    Args:\n",
"        env: environment providing reset(), step(action) and render(mode).\n",
"        model: callable returning one value per action for a batch of states.\n",
"        frame_count: number of environment steps to run.\n",
"    \"\"\"\n",
"    state = np.array(env.reset())\n",
"    total_reward = 0\n",
"    for i in range(frame_count):\n",
"        # The original note here says the notebook version cannot really render in\n",
"        # realtime, so you just have to check the printed score.\n",
"        # NOTE(review): the render call is kept as-is; confirm it is wanted in notebook runs.\n",
"        env.render('human')\n",
"        state_tensor = keras.backend.constant(state)\n",
"        state_tensor = keras.backend.expand_dims(state_tensor, 0)\n",
"        action_values = model(state_tensor, training=False)\n",
"        # Take best action\n",
"        action = keras.backend.argmax(action_values[0]).numpy()\n",
"        state, reward, done, _ = env.step(action)\n",
"        state = np.array(state)\n",
"        total_reward += reward\n",
"        if done:\n",
"            # i is 0 if the episode ends on the very first step; avoid dividing by zero\n",
"            rewards_per_frame = total_reward / max(i, 1)\n",
"            print(\"Game over at frame\", i, \"rew\", total_reward, \"rewards/frame: \", rewards_per_frame)\n",
"            # Continue from the fresh observation returned by reset(); previously it was\n",
"            # discarded and the next action was chosen from the finished episode's last state.\n",
"            state = np.array(env.reset())\n",
"    print(\"Sim ended : rew is \", total_reward)\n",
|
|
"\n",
|
"def main(env_name, model_file,frame_count=1000, seed=42):\n",
"    \"\"\"Create the environment, load the pre-trained weights and run the simulation.\n",
"\n",
"    Args:\n",
"        env_name: environment id passed to create_env.\n",
"        model_file: checkpoint path prefix passed to model.load_weights; the file\n",
"            model_file + \".data-00000-of-00001\" must exist.\n",
"        frame_count: number of frames passed to run_sim.\n",
"        seed: seed passed to create_env.\n",
"\n",
"    Raises:\n",
"        AssertionError: if the environment cannot be created or the weights file is missing.\n",
"    \"\"\"\n",
"    # Forward the seed; it was previously accepted but never used.\n",
"    env = create_env(env_name=env_name, seed=seed)\n",
"    assert env is not None, \"Failed to make env \" + env_name\n",
"    model = create_q_model(num_actions=env.action_space.n)\n",
"    # Check for the data file up front so a wrong path fails with a clear message\n",
"    model_testfile = model_file + \".data-00000-of-00001\"\n",
"    assert os.path.exists(model_testfile), \"Failed to load model: \" + model_testfile\n",
"    print(\"Model weights look loadable\", model_testfile)\n",
"    model.load_weights(model_file)\n",
"    print(\"Model loaded weights - starting sim\")\n",
"    run_sim(env, model, frame_count)\n",
"\n",
"main(env_name, model_file, frame_count=1000)\n",
|
|
"\n",
|
|
"# LEV"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.8"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|