GithubHelp home page GithubHelp logo

sotetsuk / pgx Goto Github PK

View Code? Open in Web Editor NEW
318.0 6.0 19.0 76.79 MB

🎲 Vectorized RL game environments in JAX

Home Page: http://sotets.uk/pgx/

License: Apache License 2.0

Python 69.40% Makefile 0.05% Jupyter Notebook 30.55%
machine-learning ai game reinforcement-learning backgammon chess go-game shogi bridge-game poker

pgx's Introduction

python pypi license ci codecov arxiv

A collection of GPU/TPU-accelerated parallel game simulators for reinforcement learning (RL)

Why Pgx?

Brax, a JAX-native physics engine, provides extremely high-speed parallel simulation for RL in continuous state space. Then, what about RL in discrete state spaces like Chess, Shogi, and Go? Pgx provides a wide variety of JAX-native game simulators! Highlighted features include:

  • Super fast in parallel execution on accelerators
  • 🎲 Various game support including Backgammon, Chess, Shogi, and Go
  • 🖼️ Beautiful visualization in SVG format

Quick start

Training examples

Usage

Pgx is available on PyPI. Note that your Python environment has jax and jaxlib installed, depending on your hardware specification.

$ pip install pgx

The following code snippet shows a simple example of using Pgx. You can try it out in this Colab. Note that all step functions in Pgx environments are JAX-native., i.e., they are all JIT-able. Please refer to the documentation for more details.

import jax
import pgx

env = pgx.make("go_19x19")
init = jax.jit(jax.vmap(env.init))
step = jax.jit(jax.vmap(env.step))

batch_size = 1024
keys = jax.random.split(jax.random.PRNGKey(42), batch_size)
state = init(keys)  # vectorized states
while not (state.terminated | state.truncated).all():
    action = model(state.current_player, state.observation, state.legal_action_mask)
    # step(state, action, keys) for stochastic envs
    state = step(state, action)  # state.rewards with shape (1024, 2)

Pgx is a library that focuses on faster implementations rather than just the API itself. However, the API itself is also sufficiently general. For example, all environments in Pgx can be converted to the AEC API of PettingZoo, and you can run Pgx environments through the PettingZoo API. You can see the demonstration in this Colab.

📣 API v2 (v2.0.0)

Pgx has been updated from API v1 to v2 as of November 8, 2023 (release v2.0.0). As a result, the signature for Env.step has changed as follows:

  • v1: step(state: State, action: Array)
  • v2: step(state: State, action: Array, key: Optional[PRNGKey] = None)

Also, pgx.experimental.auto_reset are changed to specify key as the third argument.

Purpose of the update: In API v1, even in environments with stochastic state transitions, the state transitions were deterministic, determined by the _rng_key inside the state. This was intentional, with the aim of increasing reproducibility. However, when using planning algorithms in this environment, there is a risk that information about the underlying true randomness could "leak." To make it easier for users to conduct correct experiments, Env.step has been changed to explicitly specify a key.

Impact of the update: Since the key is optional, it is still possible to execute as env.step(state, action) like API v1 in deterministic environments like Go and chess, so there is no impact on these games. As of v2.0.0, only 2048, backgammon, and MinAtar suite are affected by this change.

Supported games

Backgammon Chess Shogi Go

Use pgx.available_envs() -> Tuple[EnvId] to see the list of currently available games. Given an <EnvId>, you can create the environment via

>>> env = pgx.make(<EnvId>)
Game/EnvId Visualization Version Five-word description by ChatGPT
2048
"2048"
v2 Merge tiles to create 2048.
Animal Shogi
"animal_shogi"
v0 Animal-themed child-friendly shogi.
Backgammon
"backgammon"
v2 Luck aids bearing off checkers.
Bridge bidding
"bridge_bidding"
v0 Partners exchange information via bids.
Chess
"chess"
v2 Checkmate opponent's king to win.
Connect Four
"connect_four"
v0 Connect discs, win with four.
Gardner Chess
"gardner_chess"
v0 5x5 chess variant, excluding castling.
Go
"go_9x9" "go_19x19"
v0 Strategically place stones, claim territory.
Hex
"hex"
v0 Connect opposite sides, block opponent.
Kuhn Poker
"kuhn_poker"
v0 Three-card betting and bluffing game.
Leduc hold'em
"leduc_holdem"
v0 Two-suit, limited deck poker.
MinAtar/Asterix
"minatar-asterix"
v1 Avoid enemies, collect treasure, survive.
MinAtar/Breakout
"minatar-breakout"
v1 Paddle, ball, bricks, bounce, clear.
MinAtar/Freeway
"minatar-freeway"
v1 Dodging cars, climbing up freeway.
MinAtar/Seaquest
"minatar-seaquest"
v1 Underwater submarine rescue and combat.
MinAtar/SpaceInvaders
"minatar-space_invaders"
v1 Alien shooter game, dodge bullets.
Othello
"othello"
v0 Flip and conquer opponent's pieces.
Shogi
"shogi"
v0 Japanese chess with captured pieces.
Sparrow Mahjong
"sparrow_mahjong"
v1 A simplified, children-friendly Mahjong.
Tic-tac-toe
"tic_tac_toe"
v0 Three in a row wins.
Versioning policy

Each environment is versioned, and the version is incremented when there are changes that affect the performance of agents or when there are changes that are not backward compatible with the API. If you want to pursue complete reproducibility, we recommend that you check the version of Pgx and each environment as follows:

>>> pgx.__version__
'1.0.0'
>>> env.version
'v0'

See also

Pgx is intended to complement these JAX-native environments with (classic) board game suits:

Combining Pgx with these JAX-native algorithms/implementations might be an interesting direction:

Citation

If you use Pgx in your work, please cite our paper:

@inproceedings{koyamada2023pgx,
  title={Pgx: Hardware-Accelerated Parallel Game Simulators for Reinforcement Learning},
  author={Koyamada, Sotetsu and Okano, Shinri and Nishimori, Soichiro and Murata, Yu and Habara, Keigo and Kita, Haruka and Ishii, Shin},
  booktitle={Advances in Neural Information Processing Systems},
  pages={45716--45743},
  volume={36},
  year={2023}
}

LICENSE

Apache-2.0

pgx's People

Contributors

akulen avatar bleu48 avatar carlosgmartin avatar clement-bonnet avatar dependabot[bot] avatar egiob avatar habara-k avatar harukaki avatar hongruitang avatar howuhh avatar nissymori avatar okanoshinri avatar rhaps0dy avatar sotetsuk avatar youyou-ku avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar  avatar  avatar  avatar  avatar

pgx's Issues

APIの見直し

  • obs
  • reward
  • terminated
  • truncated
  • legal_actions
  • rng引数
  • stateの構造

braxも参考に

麻雀の状態遷移図

#63 (comment) にも現状の実装についての遷移図がある。
Mjxを参考に書いたがあんまり自信がない...

LE=Last Event

flowchart TD
開局 -- env: ツモ --> ツモ後

ツモ後 --打牌--> 打牌後
LE=立直 --打牌--> 打牌後

打牌後 -- env: 特殊流局 --> 終局
打牌後 -- env: 通常流局 --> 終局

打牌後 -- env: ツモ --> ツモ後

ツモ後 -- 立直 --> LE=立直
ツモ後 -- 門前清自摸 --> 終局
打牌後 -- ロン --> 終局
LE=加槓 -- ロン --> 終局
打牌後 -- 大明槓 --> LE=大明槓
LE=大明槓 -- env: ツモ --> ツモ後
ツモ後 -- 加槓 --> LE=加槓
LE=加槓 -- env: ツモ --> ツモ後

打牌後 -- env: 供託 --> 打牌後

ツモ後 -- env: 槓ドラ/1枚--> ツモ後
LE=大明槓 -- env: 槓ドラ1or2枚--> LE=大明槓
LE=加槓 -- env: 槓ドラ/0or前回分1枚--> LE=加槓 
ツモ後 --九種九牌--> 終局

打牌後 -- チー/ポン --> ツモ後
ツモ後 -- 暗槓 --> ツモ後
打牌後 -- Pass --> 打牌後
LE=加槓 -- Pass --> LE=加槓

小規模な囲碁のシミュレータを実装

#5, #17 に留意。

必要なモノ

State (dataclass)

  • board 5 * 5 * 3 or 5 * 5で0(黒), 1(白), 2(無)
  • turn 0(黒番) or 1(白番)
  • アゲハマ(黒)
  • アゲハマ(白)
  • 前のターンでパスされたか否か
  • passがあるため累計ターンはあまり意味ない?
  • 19 * 19を考えても、精々1083+1+アゲハマ分なので1次元でも良いかも

Action

  • x,y,手番 | None(pass)
  • intで表す場合
    • Noneか否か = 2^1
    • x座標, y座標(多くても19*19) = 361 < 2^8
    • 手番 = 2^1
    • 2進数で高々10桁=1024→整数値に変換できそう

参考: MuGo(≒AlphaGo)

A board is a NxN numpy array.
A Coordinate is a tuple index into the board.
A Move is a (Coordinate c | None).
A PlayerMove is a (Color, Move) tuple
(0, 0) is considered to be the upper left corner of the board, and (18, 0) is the lower left.

で定義するのが良いかなーって感じです

めんどくさそうな壁

  • 囲まれた判定
    • クラスを使わず、可変長リストも使わず、再帰関数もwhileも使わず、関数のみで実装しなければならない
  • コウ判定
    • LegalActionにも関係する。Stateに何かのフラグを追加した方が良いかも
  • 地の計算

ContractBridgeBiddingの可視化

@harukaki
他の環境と同じようにContractBridgeBiddingも可視化したいのですが、私があまりブリッジに馴染みがないこともあり、今一つイメージが掴めません。
表示すべき内容やもし参考画像などありましたら教えていただけませんか。

`jax.jit` しやすいシミュレータの実装方針

  • [必須] 可変長のforループを使わない
  • [必須] listを使わない(appendなども使わない)
  • ifのネストやforをなるべくさけ、numpyの操作をなるべく使う
  • 長い関数や深いネストをさけ、10行以内くらいの小さいpure functionになるべく分割する
    • pure functionとは副作用が起こらず、引数が同じなら同じ結果が返ってくるもののこと
  • stateが一つの行列で表現できない場合は dataclass で実装する
    • dataclassの各要素はすべて np.ndarray とし、intの場合も a: np.ndarray = np.zeros(1) といった具合にする
    • observationを生成する関数も用意する
  • actionがintや一つのベクトルで表現できない場合は、 dataclass で実装する Actionは一つのintで表現できる必要がある(legal_actionsが可変長リストになってしまうため)
    • 0~N個の連続した数字に各アクションを割り当てる必要がある。
  • legal_actionsも可変長リストはダメ。
    • ダメな例: legal_actions: List[Action] みたいなのはダメ。

麻雀の聴牌判定・役判定

Mjxでは事前計算した結果をハッシュテーブルで取ってくることで高速に実現しているが、GPUではこれはできない。

解決方法をいくつか考える

案1

ニューラルネットワークで結果を予測する(覚える)
役判定は和了形の手の数はそこまで多くないので覚えられるはず。
聴牌判定はFalseの例が際限なくあるので保証はできないが99.99%くらいまでは予測できそう。

案2

少なくとも聴牌判定に関しては行列演算が高速なことを利用して計算できるかもしれない。
数牌の形が鍵だが、ある色について和了形で可能なパターンはたぶんかなり少ない。
このパターンにマッチするかは簡単に計算できるはず。
たとえば、4x9xNのonehotなパターン行列を用意しておき、現在の手札を4x9に直してandをとるる。
ぴったり一致しているところ以外は元のパターン表現より立っているbitの数が減っているはず。
各崇拝と字牌の形が正しいとして、数牌と字牌のブロック毎に頭の数と面子の数は枚数から簡単にわかるので、雀頭が一つだけになっているかを確認すればよい
(Onehotにしなくても枚数一致してるか見ればよかった)

多分こんな感じで行けそうな気がする(未検証)

>>> num_hans.shape
(3, 9)
>>> win_patterns.shape
(N, 9)
>>> np.abs((num_hand.reshape((3,1,9)) - win_patterns)) # (3, N, 9)
...    .sum(axis=-1)  # (3, N)
...    .sum(axis=0) # (N,)
...    .min()
0.  # if win_pattern 

役判定についても、同様にできるかもしれない。
たとえば一通の形のフィルタを各色かけたりすればよいか。
三色は数牌すべてについてフィルタをかける?

Jax化手順

  • numpy実装との一致を確かめるテストを書く
    • numpy実装をそのままコピーしてくる
      • numpy実装は _anmal_shogi.py のように隠す
      • jax実装はひとまずnumpy実装をそのままコピーしてくる
      • numpy実装のテストはそのままnumpy実装をテストするようにする(_animal_shogi.py 以下をテストする)
    • numpy実装とjax実装がランダムアクションで一致していることを確かめるテストを新しく書く
      • 初期状態から終了までnumpy実装でランダムアクションで実行できるようにする
      • numpy実装のstateからjax実装のstateへ変換する関数を用意する(とりあえずそのまま返す)
      • 状態sからs'へnumpy実装で遷移したとき、sをjax実装のstateへ変換してjax実装で遷移したs''が、s'をjax実装へ変換したものと同じであることを確認するコードを書く
  • Jax実装をnumpyから少しずつ移植(各ステップごとにテストが通るように確認)
    • Stateのメンバを jax.numpy の ndarray にする(intやfloatもすべて)
    • Stateをdataclassからflax.structのdataclassにする
    • (Optional)Actionも同様にする
    • 一番最後に呼ばれる関数から順番に jax.jit をつけていく
    • 関数に jax.jit をつける手順
      • 1 まず、@jax.jit をつけてみてエラーがでることを確認して、@jax.jit をコメントアウト
      • 2 エラーがでた箇所をJax化する(e.g., if => jax.lax.cond, jax.lax.switch)
      • 3 @jax.jit をつけてみてエラーが出た箇所が下の行へ移ったことを確認
      • 2,3をエラーがでなくなるまで繰り返す
    • step (と init) に @jax.jit がついたら完成

`make check` のバージョンを固定

結構な頻度でバージョンが上がっていくので、新しい機能のPR内で関係ない型チェックが落ちたりすることがよくある。

可視化で実現したいことをまとめる

  • >>> state でのNotebookでの可視化(及び保存)
  • >>> states でのbatchのNotebookでの可視化(及び保存)
  • 一つの時系列についてGifの生成
    • Notebookでの可視化
    • 保存
  • (複数の時系列についてGifの生成)

他のゲーム候補

Goal: 20 games?

  • Connect four
  • Othello
  • Leduc Holdem
  • Kuhn poker
  • 2048
  • Hex
  • Hanabi
  • Checker
  • Tetris
  • Poker
  • rubik cube

Backgammon survey

  • Backgammonのルールの確認
  • Backgammonの状態, 特徴量設計のため関連資料をサーベイ
  • 目ぼしい状態特徴量設計を洗い出す.

Go envのreadme

State

@struct.dataclass
class GoState:
    # 横幅, マスの数ではない
    size: jnp.ndarray = jnp.full(1, 19, dtype=int)

    # 連
    ren_id_board: jnp.ndarray = jnp.full((2, 19 * 19), -1, dtype=int)

    # 連idが使えるか
    available_ren_id: jnp.ndarray = jnp.ones((2, 19 * 19), dtype=bool)

    # 連周りの情報 0:None 1:呼吸点 2:石
    liberty: jnp.ndarray = jnp.zeros((2, 19 * 19, 19 * 19), dtype=int)

    # 隣接している敵の連id
    adj_ren_id: jnp.ndarray = jnp.zeros((2, 19 * 19, 19 * 19), dtype=bool)

    # 経過ターン, 0始まり
    turn: jnp.ndarray = jnp.zeros(1, dtype=int)

    # [0]: 黒の得たアゲハマ, [1]: 白の得たアゲハマ
    agehama: jnp.ndarray = jnp.zeros(2, dtype=int)

    # 直前のactionがパスだとTrue
    passed: jnp.ndarray = jnp.zeros(1, dtype=bool)

    # コウによる着手禁止点(xy), 無ければ(-1)
    kou: jnp.ndarray = jnp.full(1, -1, dtype=int)

関数

  • init(size: int)
    • GoStateを返す
    • sizeで横幅を指定する
  • step(state: GoState, action: int, size: int)
    • actionは盤面を左上から右に数えていったマス目の番号(整数値)
    • passの時はaction-1を指定する
    • jaxの都合で盤のsizeを指定する必要あり
    • 返り値はTuple[GoState, jnp.ndarray, bool]で、それぞれ(state, reward, done)
  • legal_actions(state: GoState, size: int)
    • 長さsize*sizendarrayを返す。石を置けるマスはTrue、そうでないマスはFalseが入っている
  • show(state: GoState)
    • 可視化する

[囲碁] 地の判定、報酬計算

環境を作る以上正しく報酬を返さなければならない。

現実の囲碁は

  • 終了条件が「両者のパス」であり、将棋やチェスのようにこうしたら勝ちというものではない
  • 駄目の存在により、盤の空点全てが地になっていなくても終局する(終局の判定が人間任せ)
  • すなわち互いに合意するまで対局が続く

今回のシミュレータでは、ひとまず

  • 両者がパスしたら終局(シミュレータ側で終局の判定はしない)
  • 終局したら盤面の地を数える
  • 「地-アゲハマ」が大きい方が勝ち

として進めたいと思います。
報酬はとりあえず勝った方を+100, 負けた方を-100勝った方を+1, 負けた方を-1としておきます。

mahjong の機能拡張

    • ドラ
    • 赤牌
    • リーチ
    • 点数計算
  • ノーテン罰符
  • 局システム(現状は一局だけの対戦)
  • 特殊流局
  • 喰い変えの禁止
  • フリテン
  • 自摸切り、手出しの記録
  • observation(現状は自分の純手牌と、直前に捨てられた牌のみ)

[囲碁] Stateの可視化

svgwrite
matplotlib

  • 今回はmjxと違って独自フォントを埋め込まないので、png変換→matplotlibで表示、とかも使えそう
    • 新しいウィンドウで画像を出したりできる

TODO

  • to_svg(state)
  • show_svg(state)を用意
  • カラーバリエーション(ダークモード等)

Backgammonのテストが落ちることがある

https://github.com/sotetsuk/pgx/actions/runs/3377609227/jobs/5606745367

=========================== short test summary info ============================
FAILED tests/test_backgammon.py::test_step - assert 1 == -1
 +  where 1 = BackgammonState(board=array([  0,   0,   0,   3,   0,   0,   0,   0,   0,   0,   5,   0,   0,\n         0,   0,   0,   0,   0,   0, -10,   1,  -2,   2,   0,  -1,   4,\n        -2,   0], dtype=int8), dice=array([0, 2], dtype=int8), playable_dice=array([ 0,  2, -1, -1], dtype=int8), played_dice_num=0, turn=1, legal_action_mask=array([0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n       0, 0, 0, 0, 0, 0, 0, 0], dtype=int8)).turn
============ 1 failed, 101 passed, 11 warnings in 571.63s (0:09:31) ============

全体共通でテストすべきこと

  • 各関数の型
  • stepを繰り返してゲームが進行できるか
  • terminal stateが帰ってくるときの curr_playerlegal_action_mask
  • terminal stateでstepしたときの挙動
  • illegal action時の負の報酬を決める

どうぶつ将棋・将棋のデータの持ち方

今の構造は
盤面:それぞれの座標に存在する駒の種類に対応した行列(種類iの駒が入っている座標の行列は、p(i)番目の要素が1で他が0、というもの)を入れている
持ち駒:それぞれの駒種について初期値が0枚で、駒を取ったり打ったりと変動が起きた場合に持っている枚数に応じて変化する(持っているのが0枚なら[1, 0, ...], 1枚なら[0, 1, 0, ...]といった感じ)

dlshogiの特徴量と比較すると、盤面の方はほぼ同じ(駒の利きが入っていたくらい)で、持ち駒の方は先手後手のみ分かれていて、駒種ごとに行列のエリアが分かれていてそれぞれの駒を持っている枚数だけ1を立てる(例えば、歩のエリアが0~3、香車のエリアが4~7で歩を2枚、香車を1枚持っている場合は、[1,1,0,0,(ここまで歩のエリア)1, 0, 0, 0,(ここまで香車のエリア)…]といった形式)仕様だった。シミュレーションという点で見ると使いにくそうだったので上のものに変更しているが、特徴量の形式に変換するのは難しくないと思う

[MinAtar] Fix Seaquest tests

https://github.com/sotetsuk/pgx/actions/runs/3343343322/jobs/5536462957

=================================== FAILURES ===================================
________________________________ test_step_det _________________________________

    def test_step_det():
        env = Environment("seaquest", sticky_action_prob=0.0)
        num_actions = env.num_actions()
    
        N = 100
        for _ in range(N):
            env.reset()
            done = False
            while not done:
                s = extract_state(env, state_keys)
                a = random.randrange(num_actions)
                r, done = env.act(a)
                enemy_lr, is_sub, enemy_y, diver_lr, diver_y = env.env.enemy_lr, env.env.is_sub, env.env.enemy_y, env.env.diver_lr, env.env.diver_y
                s_next_pgx, _, _ = seaquest._step_det(
                    minatar2pgx(s, seaquest.MinAtarSeaquestState),
                    a,
                    enemy_lr,
                    is_sub,
                    enemy_y,
                    diver_lr,
                    diver_y
                )
>               assert jnp.allclose(
                    env.state(),
                    seaquest.observe(s_next_pgx),
                )
E               assert DeviceArray(False, dtype=bool)
E                +  where DeviceArray(False, dtype=bool) = <CompiledFunction of <function allclose at 0x7f8ae1d13820>>(array([[[False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False]],\n\n       [[False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, Fa... False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False]],\n\n       [[False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False]]]), DeviceArray([[[False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False]],\n\n             [[False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n              ... False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False]],\n\n             [[False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False]]], dtype=bool))
E                +    where <CompiledFunction of <function allclose at 0x7f8ae1d13820>> = jnp.allclose
E                +    and   array([[[False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False]],\n\n       [[False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, Fa... False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False]],\n\n       [[False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False]]]) = <bound method Environment.state of <minatar.environment.Environment object at 0x7f8aac46bf70>>()
E                +      where <bound method Environment.state of <minatar.environment.Environment object at 0x7f8aac46bf70>> = <minatar.environment.Environment object at 0x7f8aac46bf70>.state
E                +    and   DeviceArray([[[False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False]],\n\n             [[False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n              ... False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False]],\n\n             [[False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False]]], dtype=bool) = <CompiledFunction of <function observe at 0x7f8a82828dc0>>(MinAtarSeaquestState(oxygen=DeviceArray(157, dtype=int16), diver_count=DeviceArray(0, dtype=int8), sub_x=DeviceArray(0, dtype=int8), sub_y=DeviceArray(3, dtype=int8), sub_or=DeviceArray(False, dtype=bool), f_bullets=DeviceArray([[-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1]], dtype=int8), e_bullets=DeviceArray([[ 5,  4,  0],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1]], dtype=int8), e_fish=DeviceArray([[ 3,  1,  0,  1],\n             [-1, -1, -1, -1],\n             [-1, -1, -1, -1],\n             [-1, -1, -1, -1],\n   ...         [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1]], dtype=int8), divers=DeviceArray([[ 1,  6,  0,  3],\n             [ 3,  8,  1,  3],\n             [-1, -1, -1, -1],\n             [-1, -1, -1, -1],\n             [-1, -1, -1, -1]], dtype=int8), e_spawn_speed=DeviceArray(19, dtype=int8), e_spawn_timer=DeviceArray(17, dtype=int8), d_spawn_timer=DeviceArray(10, dtype=int8), move_speed=DeviceArray(5, dtype=int8), ramp_index=DeviceArray(1, dtype=int8), shot_timer=DeviceArray(3, dtype=int8), surface=DeviceArray(False, dtype=bool), terminal=DeviceArray(False, dtype=bool), last_action=DeviceArray(1, dtype=int32, weak_type=True)))
E                +      where <CompiledFunction of <function observe at 0x7f8a82828dc0>> = seaquest.observe

tests/test_seaquest.py:55: AssertionError
=============================== warnings summary ===============================
tests/test_asterix.py::test_step_det
tests/test_freeway.py::test_step_det
tests/test_seaquest.py::test_step_det
tests/test_seaquest.py::test_step_det
tests/test_seaquest.py::test_step_det
tests/test_seaquest.py::test_observe
  /opt/hostedtoolcache/Python/3.8.14/x64/lib/python3.8/site-packages/jax-0.3.23-py3.8.egg/jax/_src/ops/scatter.py:87: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=int32 to dtype=int8. In future JAX releases this will result in an error.
    warnings.warn("scatter inputs have incompatible types: cannot safely cast "

tests/test_freeway.py::test_observe
tests/test_seaquest.py::test_step_det
tests/test_seaquest.py::test_observe
tests/test_space_invaders.py::test_step_det
tests/test_space_invaders.py::test_observe
  /opt/hostedtoolcache/Python/3.8.14/x64/lib/python3.8/site-packages/jax-0.3.23-py3.8.egg/jax/_src/ops/scatter.py:87: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=int32 to dtype=bool. In future JAX releases this will result in an error.
    warnings.warn("scatter inputs have incompatible types: cannot safely cast "

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================== short test summary info ============================
FAILED tests/test_seaquest.py::test_step_det - assert DeviceArray(False, dtype=bool)
 +  where DeviceArray(False, dtype=bool) = <CompiledFunction of <function allclose at 0x7f8ae1d13820>>(array([[[False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False]],\n\n       [[False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, Fa... False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False]],\n\n       [[False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False]]]), DeviceArray([[[False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False]],\n\n             [[False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n              ... False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False]],\n\n             [[False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False]]], dtype=bool))
 +    where <CompiledFunction of <function allclose at 0x7f8ae1d13820>> = jnp.allclose
 +    and   array([[[False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False]],\n\n       [[False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, Fa... False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False]],\n\n       [[False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False,  True, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False],\n        [False, False, False, False, False, False, False, False, False,\n         False]]]) = <bound method Environment.state of <minatar.environment.Environment object at 0x7f8aac46bf70>>()
 +      where <bound method Environment.state of <minatar.environment.Environment object at 0x7f8aac46bf70>> = <minatar.environment.Environment object at 0x7f8aac46bf70>.state
 +    and   DeviceArray([[[False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False]],\n\n             [[False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n              ... False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False]],\n\n             [[False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False,  True,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False],\n              [False, False, False, False, False, False, False, False,\n               False, False]]], dtype=bool) = <CompiledFunction of <function observe at 0x7f8a82828dc0>>(MinAtarSeaquestState(oxygen=DeviceArray(157, dtype=int16), diver_count=DeviceArray(0, dtype=int8), sub_x=DeviceArray(0, dtype=int8), sub_y=DeviceArray(3, dtype=int8), sub_or=DeviceArray(False, dtype=bool), f_bullets=DeviceArray([[-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1]], dtype=int8), e_bullets=DeviceArray([[ 5,  4,  0],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1],\n             [-1, -1, -1]], dtype=int8), e_fish=DeviceArray([[ 3,  1,  0,  1],\n             [-1, -1, -1, -1],\n             [-1, -1, -1, -1],\n             [-1, -1, -1, -1],\n   ...         [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1],\n             [-1, -1, -1, -1, -1]], dtype=int8), divers=DeviceArray([[ 1,  6,  0,  3],\n             [ 3,  8,  1,  3],\n             [-1, -1, -1, -1],\n             [-1, -1, -1, -1],\n             [-1, -1, -1, -1]], dtype=int8), e_spawn_speed=DeviceArray(19, dtype=int8), e_spawn_timer=DeviceArray(17, dtype=int8), d_spawn_timer=DeviceArray(10, dtype=int8), move_speed=DeviceArray(5, dtype=int8), ramp_index=DeviceArray(1, dtype=int8), shot_timer=DeviceArray(3, dtype=int8), surface=DeviceArray(False, dtype=bool), terminal=DeviceArray(False, dtype=bool), last_action=DeviceArray(1, dtype=int32, weak_type=True)))
 +      where <CompiledFunction of <function observe at 0x7f8a82828dc0>> = seaquest.observe
============ 1 failed, 94 passed, 11 warnings in 696.61s (0:11:36) =============
make: *** [Makefile:33: test] Error 1

`jax.jit` を使って高速化するときのメモ

jax のnumpyとの基本的な違いとして、配列操作がimmutable(コピーしかできない)

jax.jit を使うためには動的な計算が制限される。

  • 引数を使ったif
  • 引数を長さに使ったfor
  • 引数をindexにつかった配列操作

などが動かないはず。
大体としてこれらを使うとjitでコンパイルできる(なぜかはよくわからない)

TIPS

  • 基本的にはまず、jitなしの実装とテストを用意する
  • 少しずつ関数呼び出しの深い方からテストとjitが通るようにjitに書き換えていく
  • もとの実装の時点で深いネストの実装を避ける
  • 深いネストは細かく純関数に分ける
    • このとき、たとえばネスト内で複数の変数が更新されていたら(e.g., x,y,z)、これらすべてを引数にとって、同様に返すような変数にすると楽

if

基本構文。<true_fn>,<false_fn> の型は同じである必要がある。

y = jax.lax.cond(
   x > 0, # cond
   lambda: x ** 2,  # if true 
   lambda: x, # if false
)

and/or&,| を使う

x = jax.lax.cond(
    (x % 2 == 0) & (x % 3 == 0),
    lambda: x**2,
    lambda: x,
)

switch

3つ以上の条件分岐などに使える

for

n がstaticなら大丈夫

@jax.jit
def f():
    s = 0
    for i in range(8):
        s += 1
    return s

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.