-
Notifications
You must be signed in to change notification settings - Fork 45
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Operator] Add weight_norm op [MooreThreads] #177
Conversation
54570d7
to
e7fa442
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please simplify the implementation only considering dim is 0 or v.dim() - 1.
g is supposed to be a scalar factor for dim dimensions. For instance if dim is 0, g.shape should be something like [1, N]. |
d5c2125
to
8ff8cad
Compare
8ff8cad
to
30760f4
Compare
v_value = tl.load(v + row_offset * N + col_offset, mask=mask) | ||
v_block += v_value * v_value | ||
|
||
normalized = tl.sqrt(tl.sum(v_block, axis=1) + eps) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should be reducing on the first dimension, ie., axis=0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
v_block is stored in row-major order, so I perform the sum along the rows regardless of whether the reduction dimension is the first or last (xy index will be permuted for last). The test encountered an error because REDUCTION_SHAPES = (200, 40999, 3) and dim = 1 is not supported for weight normalization; this issue has now been resolved.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reducing on dim 1 is only correct provided the inputs are transposed up front. It looks like that's not the case in WeightNorm.forward. Can we further verify that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If reduction occurs in the error dimension, the result will definitely be different compared to the golden reference, but currently, they are consistent. The transpose occurs within the kernel, where threads load the number in the row direction from global, but store it in the column direction of v_block.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
M = v.shape[0]
N = math.prod(v.shape[1:])
grid = lambda META: (triton.cdiv(M, META["BLOCK_ROW_SIZE"]),)
Above is the blocking scheme in the code, where M is the reduction dim size. It's clear the reduction axis is split. I don't see how transpose could be done in the kernel...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in the kernel
// for reduce dim is first
tx = tl.arange(0, BLOCK_COL_SIZE)[None, :]
v_block = tl.zeros([BLOCK_ROW_SIZE, BLOCK_COL_SIZE], dtype=tl.float32)
// for reduce dim is last
ty = tl.arange(0, BLOCK_ROW_SIZE)[None, :]
v_block = tl.zeros([BLOCK_COL_SIZE, BLOCK_ROW_SIZE], dtype=tl.float32)
how about you verify this with a simple instance, for example reduce shape = (2, 2). if reduce dim is wrong in the kernel, the result will not consistent with golden
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My bad... I took for granted that the input dim
is the dimension to be contracted off..
45ab7ec
to
3bea848
Compare
3bea848
to
e8d9c9f
Compare
e8d9c9f
to
162366c
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG
No description provided.