
Commit 32604f2

Merge branch 'main' into main
2 parents 04ecd28 + ea00fdf · commit 32604f2

File tree

10 files changed (+37 −24 lines)

src/ReinforcementLearningCore/src/core/run.jl

Lines changed: 7 additions & 2 deletions
@@ -89,31 +89,36 @@ function _run(policy::AbstractPolicy,
     while !is_stop
         reset!(env)
         push!(policy, PreEpisodeStage(), env)
+        optimise!(policy, PreActStage())
         push!(hook, PreEpisodeStage(), policy, env)

+
         while !reset_condition(policy, env) # one episode
             push!(policy, PreActStage(), env)
+            optimise!(policy, PreActStage())
             push!(hook, PreActStage(), policy, env)

             action = RLBase.plan!(policy, env)
             act!(env, action)

-            optimise!(policy)
-
             push!(policy, PostActStage(), env)
+            optimise!(policy, PostActStage())
             push!(hook, PostActStage(), policy, env)

             if check_stop(stop_condition, policy, env)
                 is_stop = true
                 push!(policy, PreActStage(), env)
+                optimise!(policy, PreActStage())
                 push!(hook, PreActStage(), policy, env)
                 RLBase.plan!(policy, env) # let the policy see the last observation
                 break
             end
         end # end of an episode

         push!(policy, PostEpisodeStage(), env) # let the policy see the last observation
+        optimise!(policy, PostEpisodeStage())
         push!(hook, PostEpisodeStage(), policy, env)
+
     end
     push!(policy, PostExperimentStage(), env)
     push!(hook, PostExperimentStage(), policy, env)
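
With this change the run loop calls `optimise!(policy, <Stage>())` at every stage boundary instead of a single `optimise!(policy)` after each step, so a policy can opt into updates at exactly the points it cares about; everything else hits the stage-typed no-op fallback added in stages.jl below. A minimal sketch, assuming `using ReinforcementLearning` brings `AbstractPolicy`, the stage types and `optimise!` into scope and that `RLBase` is the usual alias for ReinforcementLearningBase; the policy type and its counter are hypothetical, not part of this commit:

    mutable struct CountingPolicy <: AbstractPolicy
        n_updates::Int
    end

    # Only the per-step stage triggers work; every other stage falls through to `nothing`.
    RLBase.optimise!(p::CountingPolicy, ::PostActStage) = (p.n_updates += 1)

    p = CountingPolicy(0)
    optimise!(p, PostActStage())      # runs the method above
    optimise!(p, PreEpisodeStage())   # generic fallback from stages.jl, returns nothing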

src/ReinforcementLearningCore/src/core/stages.jl

Lines changed: 3 additions & 1 deletion
@@ -19,4 +19,6 @@ struct PostActStage <: AbstractStage end
 Base.push!(p::AbstractPolicy, ::AbstractStage, ::AbstractEnv) = nothing
 Base.push!(p::AbstractPolicy, ::AbstractStage, ::AbstractEnv, ::Symbol) = nothing

-RLBase.optimise!(::AbstractPolicy) = nothing
+RLBase.optimise!(policy::P, ::S) where {P<:AbstractPolicy,S<:AbstractStage} = nothing
+
+RLBase.optimise!(policy::P, ::S, batch) where {P<:AbstractPolicy, S<:AbstractStage} = nothing

src/ReinforcementLearningCore/src/policies/agent/base.jl

Lines changed: 6 additions & 6 deletions
@@ -39,16 +39,16 @@ end

 Agent(;policy, trajectory, cache = SRT()) = Agent(policy, trajectory, cache)

-RLBase.optimise!(agent::Agent) = optimise!(TrajectoryStyle(agent.trajectory), agent)
-RLBase.optimise!(::SyncTrajectoryStyle, agent::Agent) =
-    optimise!(agent.policy, agent.trajectory)
+RLBase.optimise!(agent::Agent, stage::S) where {S<:AbstractStage} = optimise!(TrajectoryStyle(agent.trajectory), agent, stage)
+RLBase.optimise!(::SyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} =
+    optimise!(agent.policy, stage, agent.trajectory)

 # already spawn a task to optimise inner policy when initializing the agent
-RLBase.optimise!(::AsyncTrajectoryStyle, agent::Agent) = nothing
+RLBase.optimise!(::AsyncTrajectoryStyle, agent::Agent, stage::S) where {S<:AbstractStage} = nothing

-function RLBase.optimise!(policy::AbstractPolicy, trajectory::Trajectory)
+function RLBase.optimise!(policy::AbstractPolicy, stage::S, trajectory::Trajectory) where {S<:AbstractStage}
     for batch in trajectory
-        optimise!(policy, batch)
+        optimise!(policy, stage, batch)
     end
 end
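
Because the default per-batch loop above now forwards the stage, downstream policies that previously extended `optimise!(policy, batch)` take the stage as the middle argument; this is the same pattern the Zoo files further down adopt. A rough sketch with a hypothetical policy and an elided update step (same scope assumptions as the sketch above):

    struct MyBatchPolicy <: AbstractPolicy end          # hypothetical policy
    update_weights!(::MyBatchPolicy, batch) = nothing   # stand-in for the real gradient step

    # Before: RLBase.optimise!(p::MyBatchPolicy, batch::NamedTuple) = update_weights!(p, batch)
    # After:  the stage sits in the middle, so training can be limited to the stages that matter.
    RLBase.optimise!(p::MyBatchPolicy, ::PostActStage, batch::NamedTuple) = update_weights!(p, batch)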

src/ReinforcementLearningCore/src/policies/agent/multi_agent.jl

Lines changed: 8 additions & 3 deletions
@@ -110,26 +110,30 @@ function Base.run(
     while !is_stop
         reset!(env)
         push!(multiagent_policy, PreEpisodeStage(), env)
+        optimise!(multiagent_policy, PreEpisodeStage())
         push!(multiagent_hook, PreEpisodeStage(), multiagent_policy, env)

         while !(reset_condition(multiagent_policy, env) || is_stop) # one episode
             for player in CurrentPlayerIterator(env)
                 policy = multiagent_policy[player] # Select appropriate policy
                 hook = multiagent_hook[player] # Select appropriate hook
                 push!(policy, PreActStage(), env)
+                optimise!(policy, PreActStage())
                 push!(hook, PreActStage(), policy, env)

                 action = RLBase.plan!(policy, env)
                 act!(env, action)

-                optimise!(policy)
+
                 push!(policy, PostActStage(), env)
+                optimise!(policy, PostActStage())
                 push!(hook, PostActStage(), policy, env)

                 if check_stop(stop_condition, policy, env)
                     is_stop = true
                     push!(multiagent_policy, PreActStage(), env)
+                    optimise!(multiagent_policy, PreActStage())
                     push!(multiagent_hook, PreActStage(), policy, env)
                     RLBase.plan!(multiagent_policy, env) # let the policy see the last observation
                     break
@@ -142,6 +146,7 @@ function Base.run(
         end # end of an episode

         push!(multiagent_policy, PostEpisodeStage(), env) # let the policy see the last observation
+        optimise!(multiagent_policy, PostEpisodeStage())
         push!(multiagent_hook, PostEpisodeStage(), multiagent_policy, env)
     end
     push!(multiagent_policy, PostExperimentStage(), env)
@@ -225,8 +230,8 @@ function RLBase.plan!(multiagent::MultiAgentPolicy, env::E) where {E<:AbstractEnv}
     return (RLBase.plan!(multiagent[player], env, player) for player in players(env))
 end

-function RLBase.optimise!(multiagent::MultiAgentPolicy)
+function RLBase.optimise!(multiagent::MultiAgentPolicy, stage::S) where {S<:AbstractStage}
     for policy in multiagent
-        RLCore.optimise!(policy)
+        RLCore.optimise!(policy, stage)
     end
 end
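
The multi-agent wrapper simply fans the stage out to every sub-policy. A sketch, assuming `MultiAgentPolicy` wraps a NamedTuple of per-player policies and reusing the hypothetical `CountingPolicy` from the run.jl example above:

    multiagent = MultiAgentPolicy((; Cross = CountingPolicy(0), Nought = CountingPolicy(0)))
    optimise!(multiagent, PostActStage())   # forwards optimise!(policy, PostActStage()) to each player's policy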

src/ReinforcementLearningCore/src/policies/q_based_policy.jl

Lines changed: 1 addition & 1 deletion
@@ -37,4 +37,4 @@ end
 RLBase.prob(p::QBasedPolicy{L,Ex}, env::AbstractEnv) where {L<:AbstractLearner,Ex<:AbstractExplorer} =
     prob(p.explorer, forward(p.learner, env), legal_action_space_mask(env))

-RLBase.optimise!(p::QBasedPolicy{L,Ex}, x::NamedTuple) where {L<:AbstractLearner,Ex<:AbstractExplorer} = optimise!(p.learner, x)
+RLBase.optimise!(p::QBasedPolicy{L,Ex}, stage::S, x::NamedTuple) where {L<:AbstractLearner,Ex<:AbstractExplorer, S<:AbstractStage} = optimise!(p.learner, x)
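
Note that the stage is accepted here only for dispatch; it is dropped before the call reaches the learner, so learner-level `optimise!(learner, batch)` methods keep their two-argument form. A sketch with a hypothetical learner (assuming `AbstractLearner` is in scope; the commented lines show intended usage, not verified constructor calls):

    struct MyLearner <: AbstractLearner end                      # hypothetical learner
    RLBase.optimise!(::MyLearner, batch::NamedTuple) = nothing   # unchanged two-argument learner method

    # p = QBasedPolicy(; learner = MyLearner(), explorer = GreedyExplorer())
    # optimise!(p, PostActStage(), batch)   # dispatches on the stage, then calls optimise!(p.learner, batch)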

src/ReinforcementLearningZoo/src/algorithms/dqns/prioritized_dqn.jl

Lines changed: 2 additions & 2 deletions
@@ -71,9 +71,9 @@ function RLBase.optimise!(
     k => p′
 end

-function RLBase.optimise!(policy::QBasedPolicy{<:PrioritizedDQNLearner}, trajectory::Trajectory)
+function RLBase.optimise!(policy::QBasedPolicy{<:PrioritizedDQNLearner}, ::PostActStage, trajectory::Trajectory)
     for batch in trajectory
-        k, p = optimise!(policy, batch) |> send_to_host
+        k, p = optimise!(policy, PostActStage(), batch) |> send_to_host
         trajectory[:priority, k] = p
     end
 end

src/ReinforcementLearningZoo/src/algorithms/dqns/rainbow.jl

Lines changed: 2 additions & 2 deletions
@@ -139,9 +139,9 @@ function project_distribution(supports, weights, target_support, delta_z, vmin,
     reshape(sum(projection, dims=1), n_atoms, batch_size)
 end

-function RLBase.optimise!(policy::QBasedPolicy{<:RainbowLearner}, trajectory::Trajectory)
+function RLBase.optimise!(policy::QBasedPolicy{<:RainbowLearner}, ::PostActStage, trajectory::Trajectory)
     for batch in trajectory
-        res = optimise!(policy, batch) |> send_to_host
+        res = optimise!(policy, PostActStage(), batch) |> send_to_host
         if !isnothing(res)
             k, p = res
             trajectory[:priority, k] = p

src/ReinforcementLearningZoo/src/algorithms/policy_gradient/mpo.jl

Lines changed: 1 addition & 0 deletions
@@ -93,6 +93,7 @@ end

 function RLBase.optimise!(
     p::MPOPolicy,
+    ::PostActStage,
     batches::NamedTuple{
         (:actor, :critic),
         <: Tuple{

src/ReinforcementLearningZoo/src/algorithms/policy_gradient/trpo.jl

Lines changed: 3 additions & 3 deletions
@@ -39,16 +39,16 @@ function Base.push!(p::Agent{<:TRPO}, ::PostEpisodeStage, env::AbstractEnv)
     empty!(p.trajectory.container)
 end

-RLBase.optimise!(::Agent{<:TRPO}) = nothing
+RLBase.optimise!(::Agent{<:TRPO}, ::PostActStage) = nothing

-function RLBase.optimise!(π::TRPO, episode::Episode)
+function RLBase.optimise!(π::TRPO, ::PostActStage, episode::Episode)
     gain = discount_rewards(episode[:reward][:], π.γ)
     for inds in Iterators.partition(shuffle(π.rng, 1:length(episode)), π.batch_size)
         optimise!(π, (state=episode[:state][inds], action=episode[:action][inds], gain=gain[inds]))
     end
 end

-function RLBase.optimise!(p::TRPO, batch::NamedTuple{(:state, :action, :gain)})
+function RLBase.optimise!(p::TRPO, ::PostActStage, batch::NamedTuple{(:state, :action, :gain)})
     A = p.approximator
     B = p.baseline
     s, a, g = map(Array, batch) # !!! FIXME

src/ReinforcementLearningZoo/src/algorithms/policy_gradient/vpg.jl

Lines changed: 4 additions & 4 deletions
@@ -36,16 +36,16 @@ function update!(p::Agent{<:VPG}, ::PostEpisodeStage, env::AbstractEnv)
     empty!(p.trajectory.container)
 end

-RLBase.optimise!(::Agent{<:VPG}) = nothing
+RLBase.optimise!(::Agent{<:VPG}, ::PostActStage) = nothing

-function RLBase.optimise!(π::VPG, episode::Episode)
+function RLBase.optimise!(π::VPG, ::PostActStage, episode::Episode)
     gain = discount_rewards(episode[:reward][:], π.γ)
     for inds in Iterators.partition(shuffle(π.rng, 1:length(episode)), π.batch_size)
-        optimise!(π, (state=episode[:state][inds], action=episode[:action][inds], gain=gain[inds]))
+        optimise!(π, PostActStage(), (state=episode[:state][inds], action=episode[:action][inds], gain=gain[inds]))
     end
 end

-function RLBase.optimise!(p::VPG, batch::NamedTuple{(:state, :action, :gain)})
+function RLBase.optimise!(p::VPG, ::PostActStage, batch::NamedTuple{(:state, :action, :gain)})
     A = p.approximator
     B = p.baseline
     s, a, g = map(Array, batch) # !!! FIXME
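
For context, the gain and minibatch pattern that the VPG and TRPO methods above rely on, shown in isolation. A sketch in plain Julia: `returns_to_go` is a stand-in for roughly what the package's `discount_rewards` computes, and the reward vector, γ and batch size are made-up numbers:

    using Random: shuffle

    # Discounted return-to-go, i.e. roughly what discount_rewards(episode[:reward][:], π.γ) yields.
    function returns_to_go(rewards, γ)
        gain = similar(rewards)
        acc = zero(eltype(rewards))
        for t in length(rewards):-1:1
            acc = rewards[t] + γ * acc
            gain[t] = acc
        end
        return gain
    end

    gain = returns_to_go(rand(100), 0.99)
    for inds in Iterators.partition(shuffle(1:100), 32)
        # optimise!(π, PostActStage(), (state = ..., action = ..., gain = gain[inds]))  # as in vpg.jl
    end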
