Skip to content
200 changes: 200 additions & 0 deletions examples/trees/mcts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torchrl
from tensordict import TensorDict

# Which board representation keys the MCTS nodes: "fen" or "pgn".
pgn_or_fen = "fen"
# Whether the env exposes a legal-action mask (affects observation keys below).
mask_actions = False

# Stateful chess environment emitting FEN strings, SAN move names, and a hash
# of the board state (the hash is used as a compact node key in the forest).
env = torchrl.envs.ChessEnv(
    include_pgn=False,
    include_fen=True,
    include_hash=True,
    include_hash_inv=True,
    include_san=True,
    stateful=True,
    mask_actions=mask_actions,
)


def transform_reward(td):
    """Remap ChessEnv rewards to white's perspective in place.

    ChessEnv emits 0.5 for a draw and 1 for a win by either side; this maps
    them to: white win = 1, draw = 0, black win = -1.
    """
    if "reward" not in td:
        return td
    current = td["reward"]
    if current == 0.5:
        # A draw is worth nothing to either side.
        td["reward"] = 0
    elif current == 1 and td["turn"]:
        # "turn" truthy means white is to move, i.e. black just won, so the
        # unit reward counts against white.
        td["reward"] = -current
    return td


# ChessEnv sets the reward to 0.5 for a draw and 1 for a win for either player.
# Need to transform the reward to be:
# white win = 1
# draw = 0
# black win = -1
env = env.append_transform(transform_reward)

# Forest that stores every explored transition; nodes are keyed by the
# observation keys assigned below.
forest = torchrl.data.MCTSForest()
forest.reward_keys = env.reward_keys
forest.done_keys = env.done_keys
forest.action_keys = env.action_keys

if mask_actions:
    forest.observation_keys = [f"{pgn_or_fen}_hash", "turn", "action_mask"]
else:
    forest.observation_keys = [f"{pgn_or_fen}_hash", "turn"]

# UCB1 exploration constant; sqrt(2) is the classic choice.
C = 2.0**0.5


def traversal_priority_UCB1(tree, exploration_c=2.0**0.5):
    """Score the children of ``tree`` with the UCB1 rule.

    Args:
        tree: node whose ``subtree`` children are to be ranked.
        exploration_c (float, optional): exploration constant. Defaults to
            ``sqrt(2)``, the classic UCB1 value.

    Returns:
        A 1D tensor with one priority per child; unvisited children get
        ``+inf`` so they are always expanded first.
    """
    subtree = tree.subtree
    visits = subtree.visits
    reward_sum = subtree.wins

    # Rewards are expressed from white's perspective; when black is to move it
    # wants to minimize the reward, so flip the sign before ranking.
    if not subtree.rollout[0, 0]["turn"]:
        reward_sum = -reward_sum

    parent_visits = tree.visits
    reward_sum = reward_sum.squeeze(-1)
    # UCB1: mean value + c * sqrt(ln(parent visits) / child visits).
    # The original code computed (reward_sum + c*sqrt(ln N)) / n, which
    # divides the exploration bonus by n instead of taking sqrt(ln N / n).
    priority = reward_sum / visits + exploration_c * torch.sqrt(
        torch.log(parent_visits) / visits
    )
    priority[visits == 0] = float("inf")
    return priority


def _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps):
    """Run one MCTS iteration: select down the tree via UCB1, expand a leaf,
    roll out to estimate its value, and back up visit/win statistics.

    Args:
        forest: MCTSForest holding the node map (used for node keys).
        tree: root Tree node to start the traversal from.
        env: environment used to enumerate actions, step, and roll out.
        max_rollout_steps: cap on the random-rollout length from a leaf.
    """
    done = False
    trees_visited = [tree]

    while not done:
        if tree.subtree is None:
            # Leaf reached: recover the environment state stored at this node.
            td_tree = tree.rollout[-1]["next"].clone()

            if (tree.visits > 0 or tree.parent is None) and not td_tree["done"]:
                # Expansion: create one child node per legal action.
                actions = env.all_actions(td_tree)
                subtrees = []

                for action in actions:
                    td = env.step(env.reset(td_tree).update(action))
                    new_node = torchrl.data.Tree(
                        rollout=td.unsqueeze(0),
                        node_data=td["next"].select(*forest.node_map.in_keys),
                        count=torch.tensor(0),
                        wins=torch.zeros_like(td["next"]["reward"]),
                    )
                    subtrees.append(new_node)

                # NOTE: This whole script runs about 2x faster with lazy stack
                # versus eager stack.
                tree.subtree = TensorDict.lazy_stack(subtrees)
                # Simulate from one uniformly random fresh child.
                chosen_idx = torch.randint(0, len(subtrees), ()).item()
                rollout_state = subtrees[chosen_idx].rollout[-1]["next"]

            else:
                # Unvisited non-root leaf (or terminal): simulate from here.
                rollout_state = td_tree

            if rollout_state["done"]:
                # Terminal state: its reward is the rollout value directly.
                rollout_reward = rollout_state["reward"]
            else:
                # Random playout; the final transition's reward is the value
                # estimate backed up the path.
                rollout = env.rollout(
                    max_steps=max_rollout_steps,
                    tensordict=rollout_state,
                )
                rollout_reward = rollout[-1]["next", "reward"]
            done = True

        else:
            # Selection: descend into the child with the highest UCB1 score.
            priorities = traversal_priority_UCB1(tree)
            chosen_idx = torch.argmax(priorities).item()
            tree = tree.subtree[chosen_idx]
            trees_visited.append(tree)

    # Backup: every node on the selected path gets the rollout reward.
    # Sign flipping for the side to move is handled in the priority function.
    for tree in trees_visited:
        tree.visits += 1
        tree.wins += rollout_reward


def traverse_MCTS(forest, root, env, num_steps, max_rollout_steps):
    """Performs Monte-Carlo tree search in an environment.

    Args:
        forest (MCTSForest): Forest of the tree to update. If the tree does not
            exist yet, it is added.
        root (TensorDict): The root step of the tree to update.
        env (EnvBase): Environment to perform actions in.
        num_steps (int): Number of iterations to traverse.
        max_rollout_steps (int): Maximum number of steps for each rollout.

    Returns:
        Tree: the (newly created or updated) tree rooted at ``root``.
    """
    if root not in forest:
        # Seed the forest with one transition per legal action from the root.
        for action in env.all_actions(root):
            td = env.step(env.reset(root.clone()).update(action))
            forest.extend(td.unsqueeze(0))

    tree = forest.get_tree(root)

    # (Re)initialize win statistics. The zero tensor's shape is taken from the
    # stored rollouts rather than from the expansion loop variable above: the
    # original code read a loop-local `td` here, which is undefined whenever
    # the root was already present in the forest (NameError).
    for subtree in tree.subtree:
        subtree.wins = torch.zeros_like(subtree.rollout[-1]["next", "reward"])
    tree.wins = torch.zeros_like(tree.subtree[0].wins)

    for _ in range(num_steps):
        _traverse_MCTS_one_step(forest, tree, env, max_rollout_steps)

    return tree


def tree_format_fn(tree):
    """Format one node for ``Tree.to_string``: [SAN, board id, wins, visits]."""
    last_step = tree.rollout[-1]["next"]
    board_repr = last_step[pgn_or_fen].split("\n")[-1]
    return [last_step["san"], board_repr, tree.wins, tree.visits]


def get_best_move(fen, mcts_steps, rollout_steps):
    """Run MCTS from the position ``fen`` and return the best move's SAN string.

    Args:
        fen: FEN string of the position to search from.
        mcts_steps: number of MCTS iterations to run.
        rollout_steps: maximum length of each random rollout.
    """
    root = env.reset(TensorDict({"fen": fen}))
    tree = traverse_MCTS(forest, root, env, mcts_steps, rollout_steps)

    # Average value per child, flipped so that a higher score is always better
    # for the side to move at the root.
    scored_moves = []
    for child in tree.subtree:
        san = child.rollout[0]["next", "san"]
        avg_value = (child.wins / child.visits).item()
        if not child.rollout[0]["turn"]:
            avg_value = -avg_value
        scored_moves.append((avg_value, san))

    scored_moves.sort(key=lambda move: move[0], reverse=True)

    print("------------------")
    for avg_value, san in scored_moves:
        print(f" {avg_value:0.02f} {san}")
    print("------------------")

    return scored_moves[0][1]


# Smoke tests: each position has a known forced mate that MCTS should find.

# White has M1, best move Rd8#. Any other moves lose to M2 or M1.
fen0 = "7k/6pp/7p/7K/8/8/6q1/3R4 w - - 0 1"
assert get_best_move(fen0, 100, 10) == "Rd8#"

# Black has M1, best move Qg6#. Other moves give rough equality or worse.
fen1 = "6qk/2R4p/7K/8/8/8/8/4R3 b - - 1 1"
assert get_best_move(fen1, 100, 10) == "Qg6#"

# White has M2, best move Rxg8+. Any other move loses.
fen2 = "2R3qk/5p1p/7K/8/8/8/5r2/2R5 w - - 0 1"
assert get_best_move(fen2, 1000, 10) == "Rxg8+"
67 changes: 47 additions & 20 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4157,43 +4157,68 @@ def test_env_reset_with_hash(self, stateful, include_san):
td_check = env.reset(td.select("fen_hash"))
assert (td_check == td).all()

@pytest.mark.parametrize("include_fen", [False, True])
@pytest.mark.parametrize("include_pgn", [False, True])
@pytest.mark.parametrize("include_fen,include_pgn", [[False, True], [True, False]])
@pytest.mark.parametrize("stateful", [False, True])
@pytest.mark.parametrize("mask_actions", [False, True])
def test_all_actions(self, include_fen, include_pgn, stateful, mask_actions):
if not stateful and not include_fen and not include_pgn:
# pytest.skip("fen or pgn must be included if not stateful")
return

@pytest.mark.parametrize("include_hash", [False, True])
@pytest.mark.parametrize("include_san", [False, True])
@pytest.mark.parametrize("append_transform", [False, True])
# @pytest.mark.parametrize("mask_actions", [False, True])
@pytest.mark.parametrize("mask_actions", [False])
def test_all_actions(
self,
include_fen,
include_pgn,
stateful,
include_hash,
include_san,
append_transform,
mask_actions,
):
env = ChessEnv(
include_fen=include_fen,
include_pgn=include_pgn,
include_san=include_san,
include_hash=include_hash,
include_hash_inv=include_hash,
stateful=stateful,
mask_actions=mask_actions,
)
td = env.reset()

if not mask_actions:
with pytest.raises(RuntimeError, match="Cannot generate legal actions"):
env.all_actions()
return
def transform_reward(td):
if "reward" not in td:
return td
reward = td["reward"]
if reward == 0.5:
td["reward"] = 0
elif reward == 1 and td["turn"]:
td["reward"] = -td["reward"]
return td

# ChessEnv sets the reward to 0.5 for a draw and 1 for a win for either player.
# Need to transform the reward to be:
# white win = 1
# draw = 0
# black win = -1
if append_transform:
env = env.append_transform(transform_reward)

check_env_specs(env)

td = env.reset()

# Choose random actions from the output of `all_actions`
for _ in range(100):
if stateful:
all_actions = env.all_actions()
else:
for step_idx in range(100):
if step_idx % 5 == 0:
# Reset the the initial state first, just to make sure
# `all_actions` knows how to get the board state from the input.
env.reset()
all_actions = env.all_actions(td.clone())
all_actions = env.all_actions(td.clone())

# Choose some random actions and make sure they match exactly one of
# the actions from `all_actions`. This part is not tested when
# `mask_actions == False`, because `rand_action` can pick illegal
# actions in that case.
if mask_actions:
if mask_actions and step_idx % 4 == 0:
# TODO: Something is wrong in `ChessEnv.rand_action` which makes
# it fail to work properly for stateless mode. It doesn't know
# how to correctly reset the board state to what is given in the
Expand All @@ -4210,7 +4235,9 @@ def test_all_actions(self, include_fen, include_pgn, stateful, mask_actions):

action_idx = torch.randint(0, all_actions.shape[0], ()).item()
chosen_action = all_actions[action_idx]
td = env.step(td.update(chosen_action))["next"]
td_new = env.step(td.update(chosen_action).clone())
assert (td == td_new.exclude("next")).all()
td = td_new["next"]

if td["done"]:
td = env.reset()
Expand Down
5 changes: 5 additions & 0 deletions torchrl/data/map/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1364,6 +1364,11 @@ def valid_paths(cls, tree: Tree):
def __len__(self):
    # Forest size is the number of entries stored in the data map.
    return len(self.data_map)

def __contains__(self, root: TensorDictBase):
    """Return whether ``root`` is a node already stored in this forest."""
    if self.node_map is None:
        # An empty forest (no node map built yet) contains nothing.
        return False
    node_key = root.select(*self.node_map.in_keys)
    return node_key in self.node_map

def to_string(self, td_root, node_format_fn=lambda tree: tree.node_data.to_dict()):
"""Generates a string representation of a tree in the forest.

Expand Down
Loading
Loading