# XXX: kwargs
# XXX: some of the args are not batched
# XXX: Arbitrary dimensions for batching
# XXX: Out-axis
# XXX: Multiple arg return
"""
    batch(f, args::Vector{<:TracedRArray})

Trace `f` on per-sample views of `args` and emit an `enzyme.batch` operation that
applies the traced function over the leading dimension of every argument.

The first dimension of each element of `args` is treated as the batch dimension; all
arguments must share the same batch size `B`. `f` is called once, during tracing, with
fresh `TracedRArray`s whose shapes are `size(arg)[2:end]`, and must return a single
`TracedRArray`.

A private `func.func` named `string(f) * "_batch"` is added to the currently active
MLIR module and referenced by the emitted `enzyme.batch` op.

# Arguments
- `f`: callable accepting `length(args)` traced arrays and returning one `TracedRArray`.
- `args`: non-empty vector of traced arrays, each with at least one dimension.

# Returns
A `TracedRArray` of shape `(B, size(result)...)`, where `result` is the value returned
by `f` while tracing.

# Throws
- `ArgumentError` if `args` is empty, if any argument is zero-dimensional, or if `f`
  does not return a `TracedRArray`.
- `DimensionMismatch` if the leading dimensions of `args` differ.
"""
function batch(f, args::Vector{<:TracedRArray})
    # Validate inputs with explicit throws: `@assert` may be compiled out and is meant
    # for internal invariants only.
    isempty(args) && throw(ArgumentError("batch requires at least one argument"))
    all(x -> ndims(x) >= 1, args) || throw(
        ArgumentError("every batched argument must have at least one dimension")
    )
    batch_sizes = [size(x, 1) for x in args]
    allequal(batch_sizes) ||
        throw(DimensionMismatch("batching dimensions must be equal"))
    B = first(batch_sizes)

    # Per-sample tensor types: drop the leading (batch) dimension of each argument.
    in_tys = [
        MLIR.IR.TensorType(size(arg)[2:end], MLIR.IR.Type(Reactant.unwrapped_eltype(arg)))
        for arg in args
    ]

    sym_visibility = MLIR.IR.Attribute("private")

    mod = MLIR.IR.mmodule()

    # Temporary function: its result types are unknown until `f` has been traced, so it
    # is created with an empty result list and only used to host the body region.
    func = MLIR.IR.block!(MLIR.IR.body(mod)) do
        return MLIR.Dialects.func.func_(;
            sym_name=string(f) * "_batch_tmp",
            function_type=MLIR.IR.FunctionType(in_tys, []),
            body=MLIR.IR.Region(),
            sym_visibility,
        )
    end
    fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for _ in args])
    push!(MLIR.IR.region(func, 1), fnbody)

    # Placeholder traced arrays that get bound to the block arguments below.
    linear_args = [
        TracedRArray{Reactant.unwrapped_eltype(arg),ndims(arg) - 1}(
            (), nothing, size(arg)[2:end]
        ) for arg in args
    ]

    MLIR.IR.activate!(fnbody)
    result = try
        for (i, arg) in enumerate(linear_args)
            Reactant.TracedUtils.set_mlir_data!(arg, MLIR.IR.argument(fnbody, i))
        end
        # XXX: call_with_reactant is not working here?
        # ERROR: type Nothing has no field stmts
        # traced = Reactant.call_with_reactant(f, linear_args...)
        traced = f(linear_args...)
        traced isa TracedRArray || throw(
            ArgumentError(
                "batched function must return a TracedRArray, got $(typeof(traced))"
            ),
        )
        MLIR.Dialects.func.return_([traced.mlir_data])
        traced
    finally
        # Always restore the active-block stack, even if tracing `f` throws.
        MLIR.IR.deactivate!(fnbody)
    end

    # Final function with the now-known result type; it takes over the traced body.
    comp_func = MLIR.IR.block!(MLIR.IR.body(mod)) do
        return MLIR.Dialects.func.func_(;
            sym_name=string(f) * "_batch",
            function_type=MLIR.IR.FunctionType(in_tys, [mlir_type(result)]),
            body=MLIR.IR.Region(),
            sym_visibility,
        )
    end
    MLIR.API.mlirRegionTakeBody(MLIR.IR.region(comp_func, 1), MLIR.IR.region(func, 1))
    # The temporary function no longer owns a body; destroy it and null the handle.
    MLIR.API.mlirOperationDestroy(func.operation)
    func.operation = MLIR.API.MlirOperation(C_NULL)

    fname = Reactant.TracedUtils.get_attribute_by_name(comp_func, "sym_name")
    fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))

    batch_inputs = [x.mlir_data for x in args]
    output_shape = (B, size(result)...)
    out_tys = [
        MLIR.IR.TensorType(output_shape, MLIR.IR.Type(Reactant.unwrapped_eltype(result)))
    ]

    batch_op = MLIR.Dialects.enzyme.batch(
        batch_inputs;
        outputs=out_tys,
        fn=fname,
        batch_shape=MLIR.IR.DenseArrayAttribute(Int64[B]),
    )

    batched_value = MLIR.IR.result(batch_op, 1)
    return TracedRArray{Reactant.unwrapped_eltype(result),ndims(result) + 1}(
        (), batched_value, output_shape
    )
end