Skip to content

Commit

Permalink
feat: add Ops.batch
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 15, 2025
1 parent dafa186 commit 57f342a
Showing 1 changed file with 86 additions and 0 deletions.
86 changes: 86 additions & 0 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1557,4 +1557,90 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
)
end

# XXX: kwargs
# XXX: some of the args are not batched
# XXX: Arbitrary dimensions for batching
# XXX: Out-axis
# XXX: Multiple arg return
function batch(f, args::Vector{<:TracedRArray})
batch_sizes = [size(x, 1) for x in args]
@assert allequal(batch_sizes) "batching dimensions must be equal"
B = first(batch_sizes)

in_tys = [
MLIR.IR.TensorType(size(arg)[2:end], MLIR.IR.Type(Reactant.unwrapped_eltype(arg)))
for arg in args
]

sym_visibility = MLIR.IR.Attribute("private")

mod = MLIR.IR.mmodule()
func = MLIR.IR.block!(MLIR.IR.body(mod)) do
return MLIR.Dialects.func.func_(;
sym_name=string(f) * "_batch_tmp",
function_type=MLIR.IR.FunctionType(in_tys, []),
body=MLIR.IR.Region(),
sym_visibility,
)
end
fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in args])
push!(MLIR.IR.region(func, 1), fnbody)

linear_args = [
TracedRArray{Reactant.unwrapped_eltype(arg),ndims(arg) - 1}(
(), nothing, size(arg)[2:end]
) for arg in args
]

MLIR.IR.activate!(fnbody)
result = try
for (i, arg) in enumerate(linear_args)
raw_arg = MLIR.IR.argument(fnbody, i)
Reactant.TracedUtils.set_mlir_data!(arg, raw_arg)
end
# XXX: call_with_reactant is not working here?
# ERROR: type Nothing has no field stmts
# res = Reactant.call_with_reactant(f, linear_args...)
res = f(linear_args...)
@assert res isa TracedRArray
MLIR.Dialects.func.return_([res.mlir_data])
res
finally
MLIR.IR.deactivate!(fnbody)
end

comp_func = MLIR.IR.block!(MLIR.IR.body(mod)) do
return MLIR.Dialects.func.func_(;
sym_name=string(f) * "_batch",
function_type=MLIR.IR.FunctionType(in_tys, [mlir_type(result)]),
body=MLIR.IR.Region(),
sym_visibility,
)
end
MLIR.API.mlirRegionTakeBody(MLIR.IR.region(comp_func, 1), MLIR.IR.region(func, 1))
MLIR.API.mlirOperationDestroy(func.operation)
func.operation = MLIR.API.MlirOperation(C_NULL)

fname = Reactant.TracedUtils.get_attribute_by_name(comp_func, "sym_name")
fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))

batch_inputs = [x.mlir_data for x in args]
output_shape = (B, size(result)...)
out_tys = [
MLIR.IR.TensorType(output_shape, MLIR.IR.Type(Reactant.unwrapped_eltype(result)))
]

res = MLIR.Dialects.enzyme.batch(
batch_inputs;
outputs=out_tys,
fn=fname,
batch_shape=MLIR.IR.DenseArrayAttribute(Int64[B]),
)

res = MLIR.IR.result(res, 1)
return TracedRArray{Reactant.unwrapped_eltype(result),ndims(result) + 1}(
(), res, output_shape
)
end

end # module Ops

0 comments on commit 57f342a

Please sign in to comment.