diff --git a/paddle/fluid/operators/sequence_ops/sequence_mask_op.cc b/paddle/fluid/operators/sequence_ops/sequence_mask_op.cc index f50fc2268c4e8..a6cd59e44dff0 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_mask_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_mask_op.cc @@ -12,7 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/binary.h" namespace paddle { namespace operators { @@ -21,21 +24,6 @@ class SequenceMaskOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SequenceMask"); - OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "SequenceMask"); - - int maxlen = ctx->Attrs().Get("maxlen"); - auto dim = phi::vectorize(ctx->GetInputDim("X")); - - if (ctx->HasInputs("MaxLenTensor")) { - dim.push_back(-1); - } else { - dim.push_back(maxlen > 0 ? maxlen : -1); - } - ctx->SetOutputDim("Y", phi::make_ddim(dim)); - } - protected: phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -93,9 +81,14 @@ If maxlen < 0, maxlen = max(X) } // namespace operators } // namespace paddle +DECLARE_INFER_SHAPE_FUNCTOR(sequence_mask, + SequenceMaskInferShapeFunctor, + PD_INFER_META(phi::SequenceMaskInferMeta)); + REGISTER_OPERATOR( sequence_mask, paddle::operators::SequenceMaskOp, paddle::operators::SequenceMaskOpMaker, paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); + paddle::framework::EmptyGradOpMaker, + SequenceMaskInferShapeFunctor); diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index e32b1fc241b48..dde953b5d9db1 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -23,6 +23,7 @@ limitations under the License. */ #include "paddle/phi/common/type_traits.h" #include "paddle/phi/core/ddim.h" #include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/core/utils/data_type.h" #include "paddle/phi/infermeta/unary.h" #include "paddle/phi/kernels/cpu/conv_util.h" #include "paddle/phi/kernels/funcs/axis_utils.h" @@ -2584,6 +2585,24 @@ void SearchsortedInferMeta(const MetaTensor& sorted_sequence, } } +void SequenceMaskInferMeta(const MetaTensor& x, + const MetaTensor& max_len_tensor, + int maxlen, + int out_dtype, + MetaTensor* y) { + auto dim = phi::vectorize(x.dims()); + + if (max_len_tensor) { + dim.push_back(-1); + } else { + dim.push_back(maxlen > 0 ? maxlen : -1); + } + + y->set_dims(phi::make_ddim(dim)); + auto out_phi_dtype = phi::TransToPhiDataType(out_dtype); + y->set_dtype(out_phi_dtype); +} + void SoftmaxMaskFuseInferMeta(const MetaTensor& x, const MetaTensor& mask, MetaTensor* out) { diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 3a38e6b599021..ff03ecb8f8a75 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -399,6 +399,12 @@ void SearchsortedInferMeta(const MetaTensor& sorted_sequence, bool right, MetaTensor* out); +void SequenceMaskInferMeta(const MetaTensor& x, + const MetaTensor& max_len_tensor, + int maxlen, + int out_dtype, + MetaTensor* y); + void SoftmaxMaskFuseInferMeta(const MetaTensor& x, const MetaTensor& mask, MetaTensor* out);