-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathPartialLinear.lua
114 lines (96 loc) · 3.79 KB
/
PartialLinear.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
-- Declare nn.PartialLinear as a subclass of nn.Module. `PartialLinear` is the
-- class table the methods below are defined on; `Module` is the parent class,
-- used to call the parent constructor from __init.
local PartialLinear, Module = torch.class('nn.PartialLinear', 'nn.Module')
--[[
PartialLinear is a Linear layer that allows the user to set a collection of
column (output-unit) indices. When the column indices are set, the layer
behaves like a Linear layer that only has those columns. Meanwhile, all
parameters are preserved, so resetting the partition results in a module that
behaves just like a regular Linear layer.
This module is useful, for instance, when you want to do forward-backward on
only a subset of a Linear layer during training but use the full Linear layer
at test time.
]]--
--- Construct a PartialLinear layer.
-- The weight is held by an nn.LookupTable with `outputsize` rows of length
-- `inputsize`, so restricting the layer to a partition is a row lookup
-- followed by a matrix product with the input.
-- @param inputsize  number of input features
-- @param outputsize total number of output units ("columns")
-- @param bias       whether to create a bias term; nil means true
function PartialLinear:__init(inputsize, outputsize, bias)
   local withBias = bias
   if withBias == nil then
      withBias = true
   end
   Module.__init(self)

   -- Sub-network: {input, indices} -> {input, W[indices]} -> input * W[indices]^T
   local branches = nn.ParallelTable()
   branches:add(nn.Identity())
   branches:add(nn.LookupTable(outputsize, inputsize))
   local net = nn.Sequential()
   net:add(branches)
   net:add(nn.MM(false, true))
   self.network = net

   if withBias then
      -- Stored as a 1 x outputsize row so it can be indexed along dimension 2.
      self.bias = torch.Tensor(1, outputsize):zero()
      self.gradBias = torch.Tensor(1, outputsize):zero()
   end

   self.inputsize = inputsize
   self.outputsize = outputsize
   -- Indices 1..outputsize: the partition that selects every output unit.
   self.allcolumns = torch.range(1, self.outputsize)
   self:resetPartition()
end
--- Restrict the layer to a subset of its output units.
-- @param indices tensor of 1-based output-unit indices; it is converted to
--   the tensor type of self.allcolumns before being stored as the partition.
function PartialLinear:setPartition(indices)
   local wantedType = self.allcolumns:type()
   self.partition = indices:type(wantedType)
end
--- Make the layer use every output unit again (undoes setPartition).
-- The partition aliases self.allcolumns; no copy is made.
function PartialLinear:resetPartition()
   local everyColumn = self.allcolumns
   self.partition = everyColumn
end
--- Return the learnable parameters and their gradients as two tables.
-- The weight entry is the full LookupTable weight, independent of the current
-- partition. Without a bias, the second slot of each table is nil.
-- NOTE(review): should this return only the relevant partition?
function PartialLinear:parameters()
   local lookup = self.network:get(1):get(2)
   local params = {lookup.weight, self.bias}
   local grads = {lookup.gradWeight, self.gradBias}
   return params, grads
end
--- Forward pass: input * W[partition]^T, plus bias[partition] when present.
-- self.output shares storage with the inner network's output (via :set).
-- NOTE(review): input:size(1) is treated as the batch size, so input is
-- presumably (batchSize, inputsize) — confirm against callers.
function PartialLinear:updateOutput(input)
   local product = self.network:forward{input, self.partition}
   self.output:set(product)

   if not self.bias then
      return self.output
   end

   local activeBias = self.bias:index(2, self.partition:long())
   self.output:add(activeBias:expandAs(self.output))

   -- Ones vector of length input:size(1); accGradParameters multiplies
   -- gradOutput^T by it to sum the bias gradient over the batch.
   local batchSize = input:size(1)
   self.addBuffer = self.addBuffer or input.new()
   if self.addBuffer:nElement() ~= batchSize then
      self.addBuffer:resize(batchSize):fill(1)
   end
   return self.output
end
--- Backward pass with respect to the input.
-- Skipped (returns self.gradInput unchanged) when self.gradInput is nil/false.
function PartialLinear:updateGradInput(input, gradOutput)
   if not self.gradInput then
      return self.gradInput
   end
   local net = self.network
   net:updateGradInput({input, self.partition}, gradOutput)
   -- The network's gradInput mirrors its {input, indices} input table; only
   -- the first entry (the gradient for `input`) is exposed.
   self.gradInput:set(net.gradInput[1])
   return self.gradInput
end
--- Accumulate weight and bias gradients for the current partition.
-- Reads self.addBuffer, which updateOutput prepares, so updateOutput must
-- have run (with the same batch size) before this is called.
-- @param scale multiplier applied to the accumulated gradients (default 1)
function PartialLinear:accGradParameters(input, gradOutput, scale)
   local factor = scale or 1
   self.network:accGradParameters({input, self.partition}, gradOutput, factor)
   if not self.bias then
      return
   end
   -- buffer = factor * gradOutput^T * ones: per-column sums over the batch.
   self.buffer = self.buffer or input.new()
   self.buffer:resize(gradOutput:size(2))
   self.buffer:mv(gradOutput:t(), self.addBuffer):mul(factor)
   -- Scatter-add those sums into the partition's slots of the full gradBias.
   local rowView = self.buffer:view(1, self.buffer:nElement())
   self.gradBias:indexAdd(2, self.partition:long(), rowView)
end
--- Apply the update directly to the parameters (adds -lr times the gradient).
-- The gradient slots temporarily point at the parameter tensors, so
-- accGradParameters (called with scale -lr) writes into the weight and bias.
-- The original gradient tensors are restored afterwards.
function PartialLinear:accUpdateGradParameters(input, gradOutput, lr)
   local lookup = self.network:get(1):get(2)
   local savedGradWeight, savedGradBias = lookup.gradWeight, self.gradBias
   lookup.gradWeight = lookup.weight
   self.gradBias = self.bias
   self:accGradParameters(input, gradOutput, -lr)
   lookup.gradWeight = savedGradWeight
   self.gradBias = savedGradBias
end
--- Reset all accumulated gradients to zero.
-- self.gradBias only exists when the layer was constructed with a bias, so it
-- is guarded; previously a bias-free layer raised "attempt to index a nil
-- value" here.
function PartialLinear:zeroGradParameters()
   self.network:zeroGradParameters()
   if self.gradBias then
      self.gradBias:zero()
   end
end
--- Gradient-descent step: each parameter gets -learningRate times its gradient.
-- The bias update is skipped when the layer was built without a bias (both
-- self.bias and self.gradBias are nil then); previously this crashed.
-- @param learningRate step size
function PartialLinear:updateParameters(learningRate)
   self.network:updateParameters(learningRate)
   if self.bias then
      self.bias:add(-learningRate, self.gradBias)
   end
end
--- Direct parameter update used when parameters are shared.
-- We do not need to accumulate parameters when sharing, so this reuses this
-- class's accUpdateGradParameters, which swaps the LookupTable's weight and
-- gradWeight. The inherited nn.Module:defaultAccUpdateGradParameters operates
-- on self.weight / self.gradWeight instead. This module never sets those
-- fields, because its weight lives inside self.network. With that version,
-- the -lr step would land in the LookupTable's gradWeight and the weight
-- would never change.
function PartialLinear:sharedAccUpdateGradParameters(input, gradOutput, lr)
   self:accUpdateGradParameters(input, gradOutput, lr)
end
--- Human-readable description, e.g. "nn.PartialLinear(10 -> 5)".
-- Appends " without bias" when the layer has no bias term.
function PartialLinear:__tostring__()
   local shape = string.format('(%d -> %d)', self.inputsize, self.outputsize)
   local suffix = ''
   if self.bias == nil then
      suffix = ' without bias'
   end
   return torch.type(self) .. shape .. suffix
end