Reinforce.lua (forked from Element-Research/dpnn)
------------------------------------------------------------------------
--[[ Reinforce ]]--
-- Ref A. http://incompleteideas.net/sutton/williams-92.pdf
-- Abstract class for modules that use the REINFORCE algorithm (ref A).
-- The reinforce(reward) method is called by a special Reward Criterion.
-- Then, when backward is called, the reward is used to generate
-- gradInputs; the gradOutput is usually ignored.
------------------------------------------------------------------------
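-- A minimal usage sketch (illustrative only, kept inside a block comment so
-- that requiring this file does not execute it). It assumes the concrete
-- subclass nn.ReinforceBernoulli from this package and calls reinforce()
-- directly; in practice a reward criterion such as nn.VRClassReward would
-- compute the reward and call reinforce() itself:
--[[
require 'dpnn'

local net = nn.Sequential()
   :add(nn.Linear(10, 10))
   :add(nn.Sigmoid())
   :add(nn.ReinforceBernoulli())  -- a concrete nn.Reinforce subclass

local input = torch.rand(4, 10)
local output = net:forward(input)

net:reinforce(torch.rand(4))  -- one reward per example; normally done by the reward criterion
net:backward(input, output:clone():zero())  -- gradOutput is ignored by Reinforce modules
--]]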
local Reinforce, parent = torch.class("nn.Reinforce", "nn.Module")

function Reinforce:__init(stochastic)
   parent.__init(self)
   -- true makes it stochastic during evaluation and training
   -- false makes it stochastic only during training
   self.stochastic = stochastic
end

-- a Reward Criterion will call this
function Reinforce:reinforce(reward)
   parent.reinforce(self, reward)
   self.reward = reward
end

function Reinforce:updateOutput(input)
   self.output:set(input)
   return self.output
end

function Reinforce:updateGradInput(input, gradOutput)
   -- the gradOutput is ignored; the gradInput is derived from the reward
   local reward = self:rewardAs(input)
   self.gradInput:resizeAs(reward):copy(reward)
   return self.gradInput
end

-- this can be called by updateGradInput
-- expands the per-example reward to the size of the (batched) input
function Reinforce:rewardAs(input)
   assert(self.reward:dim() == 1)
   if input:isSameSizeAs(self.reward) then
      return self.reward
   else
      if self.reward:size(1) ~= input:size(1) then
         -- assume input is in online-mode (no batch dimension)
         input = self:toBatch(input, input:dim())
         assert(self.reward:size(1) == input:size(1), self.reward:size(1).." ~= "..input:size(1))
      end
      self._reward = self._reward or self.reward.new()
      self.__reward = self.__reward or self.reward.new()
      -- view the 1-D reward as batchSize x 1 x 1 x ... then expand it to the input size
      local size = input:size():fill(1):totable()
      size[1] = self.reward:size(1)
      self._reward:view(self.reward, table.unpack(size))
      self.__reward:expandAs(self._reward, input)
      return self.__reward
   end
end
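-- Sketch of the expansion performed by rewardAs (illustrative only, kept in a
-- block comment so it is not executed when this file is required; the tensor
-- names and sizes are assumptions):
--[[
local reward = torch.Tensor{1, 2}               -- one reward per example
local input = torch.rand(2, 3, 4)               -- batch of 2 examples
local size = input:size():fill(1):totable()     -- {1, 1, 1}
size[1] = reward:size(1)                        -- {2, 1, 1}
local viewed = reward:view(table.unpack(size))  -- 2x1x1
local expanded = viewed:expandAs(input)         -- 2x3x4: each example's slice
                                                -- is filled with its own reward
--]]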