Skip to content

Commit 511dd9d

Browse files
author
Jeremiah Lewis
committed
update for CxxWrap 0.14 compat
1 parent 1f2364e commit 511dd9d

File tree

5 files changed

+87
-13
lines changed

5 files changed

+87
-13
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
name = "OpenSpiel"
22
uuid = "ceb70bd2-fe3f-44f0-b81f-41608acaf2f2"
33
authors = ["Jun Tian <[email protected]>"]
4-
version = "0.1.5"
4+
version = "0.2.0"
55

66
[deps]
77
CxxWrap = "1f15a43c-97ca-5a2a-ae31-89f07a497df4"
88
OpenSpiel_jll = "bd10a763-4654-5023-a028-c4918c6cd33e"
99

1010
[compat]
11-
CxxWrap = "0.12, 0.13, 0.14"
11+
CxxWrap = "0.14"
1212
OpenSpiel_jll = "1"
1313
julia = "1.6"
1414

src/OpenSpiel.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using OpenSpiel_jll
55

66
import CxxWrap:argument_overloads
77

8-
@wrapmodule(libspieljl)
8+
@wrapmodule(OpenSpiel_jll.get_libspieljl_path)
99

1010
include("patch.jl")
1111

src/patch.jl

Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
1-
Base.show(io::IO, g::CxxWrap.StdLib.SharedPtrAllocated{Game}) = print(io, to_string(g))
2-
Base.show(io::IO, s::CxxWrap.StdLib.UniquePtrAllocated{State}) = print(io, to_string(s))
1+
Base.show(io::IO, g::CxxWrap.StdLib.SharedPtrAllocated{Game}) = print(io, to_string(g[]))
2+
Base.show(io::IO, s::CxxWrap.StdLib.UniquePtrAllocated{State}) = print(io, to_string(s[]))
33
Base.show(io::IO, gp::Union{GameParameterAllocated, GameParameterDereferenced}) = print(io, to_repr_string(gp))
44

55
function Base.hash(s::CxxWrap.CxxWrapCore.SmartPointer{T}, h::UInt) where {T<:Union{Game,State}}
6-
hash(to_string(s), h)
6+
hash(to_string(s[]), h)
77
end
88

99
function Base.:(==)(s::CxxWrap.CxxWrapCore.SmartPointer{T}, ss::CxxWrap.CxxWrapCore.SmartPointer{T}) where {T<:Union{Game, State}}
10-
to_string(s) == to_string(ss)
10+
to_string(s[]) == to_string(ss[])
1111
end
1212

1313
GameParameter(x::Int) = GameParameter(Ref(Int32(x)))
1414

1515
Base.copy(s::CxxWrap.StdLib.UniquePtrAllocated{State}) = deepcopy(s)
16-
Base.deepcopy(s::CxxWrap.StdLib.UniquePtrAllocated{State}) = clone(s)
16+
Base.deepcopy(s::CxxWrap.StdLib.UniquePtrAllocated{State}) = clone(s[])
1717
Base.reshape(s::CxxWrap.StdLib.StdVectorAllocated, dims::Int32...) = reshape(s, Int.(dims))
1818

1919
if Sys.KERNEL == :Linux
@@ -75,3 +75,77 @@ function load_game_as_turn_based(s::Union{String, CxxWrap.StdLib.StdStringAlloca
7575
_load_game_as_turn_based(s, StdMap{StdString, GameParameter}(ps))
7676
end
7777
end
78+
79+
is_chance_node(state::CxxWrap.StdLib.UniquePtrAllocated{State}) = is_chance_node(state[])
80+
81+
new_initial_state(game::CxxWrap.StdLib.SharedPtrAllocated{Game}) = new_initial_state(game[])
82+
83+
legal_actions(state::CxxWrap.StdLib.UniquePtrAllocated{State}) = legal_actions(state[])
84+
85+
child(state::CxxWrap.StdLib.UniquePtrAllocated{State}, i::Int64) = child(state[], i)
86+
87+
is_terminal(state::CxxWrap.StdLib.UniquePtrAllocated{State}) = is_terminal(state[])
88+
89+
information_state_string(state::CxxWrap.StdLib.UniquePtrAllocated{State}) = information_state_string(state[])
90+
91+
get_uniform_policy(game::CxxWrap.StdLib.SharedPtrAllocated{Game}) = get_uniform_policy(game[])
92+
93+
record_batched_trajectories(game::CxxWrap.StdLib.SharedPtrAllocated{Game}, p::CxxWrap.StdLib.StdVectorAllocated{TabularPolicy}, m::StdMapAllocated{StdString, Int32}, i::Int64, b::Bool, i2::Int64, i3::Int64) = record_batched_trajectories(game[], p, m, i, b, i2, i3)
94+
95+
expected_returns(state::CxxWrap.StdLib.UniquePtrAllocated{State}, policy::CxxWrap.StdLib.SharedPtrAllocated{Policy}, i::Int64) = expected_returns(state[], policy[], i)
96+
97+
exploitability(game::CxxWrap.StdLib.SharedPtrAllocated{Game}, policy::CxxWrap.StdLib.SharedPtrAllocated{Policy}) = exploitability(game[], policy[])
98+
99+
current_player(state::CxxWrap.StdLib.UniquePtrAllocated{State}) = current_player(state[])
100+
101+
action_to_string(state::CxxWrap.StdLib.UniquePtrAllocated{State}, i1, i2) = action_to_string(state[], i1, i2)
102+
103+
apply_action(state::CxxWrap.StdLib.UniquePtrAllocated{State}, i::AbstractVector{<:Number}) = apply_action(state[], i)
104+
105+
apply_action(state::CxxWrap.StdLib.UniquePtrAllocated{State}, i::Number) = apply_action(state[], i)
106+
107+
restart_at(b::MCTSBotAllocated, s::CxxWrap.StdLib.UniquePtrAllocated{State}) = restart_at(b, s[])
108+
109+
best_child(root::CxxWrap.StdLib.UniquePtrAllocated{SearchNode}) = best_child(root[])
110+
111+
get_outcome(root::CxxWrap.StdLib.UniquePtrAllocated{SearchNode}) = get_outcome(root[])
112+
113+
get_player(p::CxxWrap.StdLib.UniquePtrAllocated{SearchNode}) = get_player(p[])
114+
115+
get_children(root::CxxWrap.StdLib.UniquePtrAllocated{SearchNode}) = get_children(root[])
116+
117+
is_simultaneous_node(state::CxxWrap.StdLib.UniquePtrAllocated{State}) = is_simultaneous_node(state[])
118+
119+
step(bot, state::CxxWrap.StdLib.UniquePtrAllocated{State}) = step(bot, state[])
120+
121+
chance_outcomes(state::CxxWrap.StdLib.UniquePtrAllocated{State}) = chance_outcomes(state[])
122+
123+
returns(state::CxxWrap.StdLib.UniquePtrAllocated{State}) = returns(state[])
124+
125+
min_utility(game::CxxWrap.StdLib.SharedPtrAllocated{Game}) = min_utility(game[])
126+
127+
max_utility(game::CxxWrap.StdLib.SharedPtrAllocated{Game}) = max_utility(game[])
128+
129+
serialize_game_and_state(game::CxxWrap.StdLib.SharedPtrAllocated{Game}, state::CxxWrap.StdLib.UniquePtrAllocated{State}) = serialize_game_and_state(game[], state[])
130+
131+
is_mean_field_node(state::CxxWrap.StdLib.UniquePtrAllocated{State}) = is_mean_field_node(state[])
132+
133+
legal_actions(state::CxxWrap.StdLib.UniquePtrAllocated{State}, i) = legal_actions(state[], i)
134+
135+
history(state::CxxWrap.StdLib.UniquePtrAllocated{State}) = history(state[])
136+
137+
is_player_node(state::CxxWrap.StdLib.UniquePtrAllocated{State}) = is_player_node(state[])
138+
139+
num_players(game::CxxWrap.StdLib.SharedPtrAllocated{Game}) = num_players(game[])
140+
141+
distribution_support(state::CxxWrap.StdLib.UniquePtrAllocated{State}) = distribution_support(state[])
142+
143+
get_type(game::CxxWrap.StdLib.SharedPtrAllocated{Game}) = get_type(game[])
144+
145+
update_distribution(state::CxxWrap.StdLib.UniquePtrAllocated{State}, dist::CxxWrap.StdLib.StdVectorAllocated{Float64}) = update_distribution(state[], dist)
146+
147+
num_cols(game::CxxWrap.StdLib.SharedPtrAllocated{MatrixGame}) = num_cols(game[])
148+
149+
num_rows(game::CxxWrap.StdLib.SharedPtrAllocated{MatrixGame}) = num_rows(game[])
150+
151+
extensive_to_matrix_game(game::CxxWrap.StdLib.SharedPtrAllocated{Game}) = extensive_to_matrix_game(game[])

test/bots.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
@testset "MCTSBot" begin
44
UCT_C = 2.
55

6-
init_bot(game, max_simulations, evaluator) = MCTSBot(game, evaluator, UCT_C, max_simulations, 5, true, 42, false, UCT, 0., 0.)
6+
init_bot(game, max_simulations, evaluator) = MCTSBot(game[], evaluator, UCT_C, max_simulations, 5, true, 42, false, UCT, 0., 0.)
77

88
@testset "can play tic_tac_toe" begin
99
game = load_game("tic_tac_toe")
@@ -51,8 +51,8 @@
5151
apply_action(state, get_action_by_str(state, action_str))
5252
end
5353
evaluator = random_rollout_evaluator_factory(20, 42)
54-
bot = MCTSBot(game, evaluator, UCT_C, 10000, 10, true, 42, false, UCT, 0., 0.)
55-
mcts_search(bot, state), state
54+
bot = MCTSBot(game[], evaluator, UCT_C, 10000, 10, true, 42, false, UCT, 0., 0.)
55+
mcts_search(bot, state[]), state[]
5656
end
5757

5858
@testset "solve draw" begin

test/cfr.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ test_exploitability_kuhn_poker(game, policy) = @test exploitability(game, policy
1515

1616
@testset "CFRSolver" begin
1717
game = load_game("kuhn_poker")
18-
solver = CFRSolver(game)
18+
solver = CFRSolver(game[])
1919
for _ in 1:300
2020
evaluate_and_update_policy(solver)
2121
end
@@ -26,7 +26,7 @@ end
2626

2727
@testset "CFRPlusSolver" begin
2828
game = load_game("kuhn_poker")
29-
solver = CFRPlusSolver(game)
29+
solver = CFRPlusSolver(game[])
3030
for _ in 1:200
3131
evaluate_and_update_policy(solver)
3232
end

0 commit comments

Comments
 (0)