From de53fce7d7c13e7bdd074297d214b8577c9c484d Mon Sep 17 00:00:00 2001 From: haozech Date: Thu, 9 Dec 2021 12:57:45 +0800 Subject: [PATCH] Add doc cherry-pick (#631) * Add doc (#623) * Refine load_paddle_model doc (#626) * refine doc and close LOG (#632) --- build.sh | 59 +++++++---- cinn/backends/codegen_cuda_dev_test.cc | 14 +-- cinn/backends/compiler.cc | 6 +- cinn/backends/llvm/execution_engine.cc | 10 +- cinn/backends/llvm/simple_jit.cc | 6 +- cinn/backends/llvm/simple_jit.h | 0 cinn/backends/nvrtc_util.cc | 2 +- cinn/common/cas.cc | 4 +- cinn/common/ir_util.cc | 2 +- cinn/frontend/computation.cc | 2 +- cinn/frontend/interpreter.cc | 2 +- cinn/frontend/paddle/model_parser.cc | 6 +- cinn/hlir/pe/nn.cc | 4 +- cinn/optim/replace_var_with_expr.cc | 4 +- cinn/poly/compute_at_transform.cc | 2 +- cinn/poly/stage.cc | 6 +- cinn/pybind/frontend.cc | 4 +- cinn/pybind/lang.cc | 15 +++ cinn/pybind/poly.cc | 33 ++++--- docs/guide.md | 41 -------- docs/source/conf.py | 7 +- docs/source/guide.md | 33 +++++++ docs/source/index.rst | 8 +- python/tests/test_computation.py | 1 - tutorials/README.md | 13 +++ tutorials/README.txt | 11 --- tutorials/jit.py | 3 +- tutorials/load_paddle_model.cc | 92 +++++++++++++++++ tutorials/load_paddle_model.py | 131 ++++++++++++++++++------- tutorials/matmul.py | 3 +- tutorials/paddlepaddle.png | Bin 0 -> 8598 bytes tutorials/schedule_primitives.py | 57 ++++++++++- 32 files changed, 413 insertions(+), 168 deletions(-) mode change 100644 => 100755 cinn/backends/codegen_cuda_dev_test.cc mode change 100644 => 100755 cinn/backends/compiler.cc mode change 100644 => 100755 cinn/backends/llvm/execution_engine.cc mode change 100644 => 100755 cinn/backends/llvm/simple_jit.cc mode change 100644 => 100755 cinn/backends/llvm/simple_jit.h mode change 100644 => 100755 cinn/backends/nvrtc_util.cc mode change 100644 => 100755 cinn/frontend/computation.cc mode change 100644 => 100755 cinn/frontend/interpreter.cc mode change 100644 => 100755 cinn/frontend/paddle/model_parser.cc mode change 100644 => 100755 cinn/hlir/pe/nn.cc mode change 100644 => 100755 cinn/poly/compute_at_transform.cc mode change 100644 => 100755 cinn/pybind/frontend.cc mode change 100644 => 100755 cinn/pybind/lang.cc delete mode 100644 docs/guide.md create mode 100644 docs/source/guide.md mode change 100644 => 100755 docs/source/index.rst mode change 100644 => 100755 python/tests/test_computation.py create mode 100644 tutorials/README.md delete mode 100644 tutorials/README.txt create mode 100644 tutorials/load_paddle_model.cc create mode 100644 tutorials/paddlepaddle.png diff --git a/build.sh b/build.sh index 95f9080f12..a5d057090b 100755 --- a/build.sh +++ b/build.sh @@ -42,6 +42,17 @@ function gpu_on { cudnn_config=ON } +function test_doc { + mkdir -p $build_dir + cd $build_dir + export runtime_include_dir=$workspace/cinn/runtime/cuda + + prepare_ci + cmake_ + build + make_doc +} + function cudnn_off { cudnn_config=OFF } @@ -94,36 +105,46 @@ function prepare_ci { pip install pre-commit pip install clang-format==9.0 pip install wheel - pip install sphinx==3.3.1 sphinx_gallery==0.8.1 recommonmark==0.6.0 exhale scipy breathe==4.24.0 matplotlib + pip install sphinx==3.3.1 sphinx_gallery==0.8.1 recommonmark==0.6.0 exhale scipy breathe==4.24.0 matplotlib sphinx_rtd_theme pip install paddlepaddle-gpu==2.1.2.post101 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html } -function make_doc { +function prepare_doc_model_file { proxy_off - cd $workspace/tutorials - if [[ -f "ResNet18.tar.gz" ]]; then - echo "model 
file for tutorials already downloaded." - elif [[ -f "$build_dir/thirds/ResNet18.tar.gz" ]]; then - rm -rf $workspace/tutorials/ResNet18 - ln -s $build_dir/thirds/ResNet18 $workspace/tutorials/ResNet18 + local tar_file=$1 + if [[ -f "$tar_file.tar.gz" ]]; then + echo "model file $tar_file.tar.gz for tutorials already downloaded." + elif [[ -f "$build_dir/thirds/$tar_file.tar.gz" ]]; then + rm -rf $workspace/tutorials/$tar_file + ln -s $build_dir/thirds/$tar_file $workspace/tutorials/$tar_file else - wget http://paddle-inference-dist.bj.bcebos.com/CINN/ResNet18.tar.gz - tar -zxvf ResNet18.tar.gz + wget https://paddle-inference-dist.bj.bcebos.com/CINN/$tar_file.tar.gz + tar -zxvf $tar_file.tar.gz fi +} + +function make_doc { + proxy_off + cd $workspace/tutorials + prepare_doc_model_file ResNet50 + prepare_doc_model_file MobileNetV2 + prepare_doc_model_file EfficientNet + prepare_doc_model_file FaceDet + if [[ $cuda_config == "ON" && ! -d "./is_cuda" ]]; then mkdir is_cuda fi - + if [[ $cuda_config == "OFF" && -d "./is_cuda" ]]; then + rm -rf ./is_cuda + fi cd $build_dir rm -f $workspace/python/cinn/core_api.so ln -s $build_dir/cinn/pybind/core_api.so $workspace/python/cinn/ cd $workspace/docs mkdir -p docs/source/cpp - cat $workspace/tutorials/matmul.cc | python${py_version} $workspace/tools/gen_c++_tutorial.py > $workspace/docs/source/matmul.md + cat $workspace/tutorials/matmul.cc | python${py_version} $workspace/tools/gen_c++_tutorial.py > $workspace/docs/source/matmul.md + cat $workspace/tutorials/load_paddle_model.cc | python${py_version} $workspace/tools/gen_c++_tutorial.py > $workspace/docs/source/load_paddle_model.md make html - if [[ $cuda_config == "ON" && -d "./is_cuda" ]]; then - rm -rf $workspace/tutorials/is_cuda - fi } function cmake_ { @@ -308,6 +329,10 @@ function main { run_test shift ;; + test_doc) + test_doc + shift + ;; ci) CI shift @@ -320,10 +345,6 @@ function main { prepare_model shift ;; - make_doc) - make_doc - shift - ;; esac done } diff --git a/cinn/backends/codegen_cuda_dev_test.cc b/cinn/backends/codegen_cuda_dev_test.cc old mode 100644 new mode 100755 index b637d459a8..348075268d --- a/cinn/backends/codegen_cuda_dev_test.cc +++ b/cinn/backends/codegen_cuda_dev_test.cc @@ -89,7 +89,7 @@ TEST(CodeGenCUDA, basic) { CodeGenCUDA_Dev codegen(target); - auto func = Lower("elementwise_add", stages, {A, B, C}); + auto func = Lower("elementwise_mul", stages, {A, B, C}); auto compiled = codegen.Compile(func); @@ -115,7 +115,7 @@ TEST(CodeGenCUDA, Module_output) { CodeGenCUDA_Dev codegen(target); - auto func = Lower("elementwise_add", stages, {A, B, C}); + auto func = Lower("elementwise_mul", stages, {A, B, C}); Module::Builder builder("module", target); builder.AddFunction(func); @@ -149,7 +149,7 @@ TEST(CodeGenCUDA2, test_of_cacheread) { stages[B_cache]->ComputeAt(stages[C], 1); CodeGenCUDA_Dev codegen(target); - auto func = Lower("elementwise_add", stages, {A, B, C}); + auto func = Lower("elementwise_mul", stages, {A, B, C}); Module::Builder builder("module", target); builder.AddFunction(func); @@ -181,7 +181,7 @@ TEST(CodeGenCUDA2, test_of_cacheread) { dim3 grid(10, 1, 1); dim3 block(10, 1, 1); - cuda_module.LaunchKernel(0, "elementwise_add", grid, block, args); + cuda_module.LaunchKernel(0, "elementwise_mul", grid, block, args); CUDA_CALL(cudaMemcpy(host_data3.data(), reinterpret_cast(Cd), @@ -221,7 +221,7 @@ TEST(CodeGenCUDA2, test_of_splitcudakernel) { CodeGenCUDA_Dev codegen(target); - auto func = lang::LowerVec("elementwise_add", stages, {A, B, C, D}, {}, {}, 
nullptr, target); + auto func = lang::LowerVec("elementwise_mul_and_add", stages, {A, B, C, D}, {}, {}, nullptr, target); Module::Builder builder("module", target); for (auto& i : func) { @@ -251,7 +251,7 @@ typedef char int8_t; __global__ -void __launch_bounds__(200) elementwise_add(const float* __restrict__ X, const float* __restrict__ Y, float* __restrict__ C) +void __launch_bounds__(200) elementwise_mul_and_add(const float* __restrict__ X, const float* __restrict__ Y, float* __restrict__ C) { if (((int)blockIdx.x < 100)) { if (((int)threadIdx.x < 200)) { @@ -259,7 +259,7 @@ void __launch_bounds__(200) elementwise_add(const float* __restrict__ X, const f }; }; }__global__ -void __launch_bounds__(200) elementwise_add_1(const float* __restrict__ X, const float* __restrict__ Y, const float* __restrict__ C, float* __restrict__ D) +void __launch_bounds__(200) elementwise_mul_and_add_1(const float* __restrict__ X, const float* __restrict__ Y, const float* __restrict__ C, float* __restrict__ D) { if (((int)blockIdx.x < 100)) { if (((int)threadIdx.x < 200)) { diff --git a/cinn/backends/compiler.cc b/cinn/backends/compiler.cc old mode 100644 new mode 100755 index 2fd2428d36..d8ef99fef4 --- a/cinn/backends/compiler.cc +++ b/cinn/backends/compiler.cc @@ -70,14 +70,14 @@ void Compiler::CompileCudaModule(const Module& module, const std::string& code, auto _host_module_device_module_ = SplitCudaAndHostModule(module); // NOLINT auto& host_module = std::get<0>(_host_module_device_module_); auto& device_module = std::get<1>(_host_module_device_module_); - LOG(INFO) << "[CUDA] host module:\n" << host_module; + VLOG(3) << "[CUDA] host module:\n" << host_module; { // compile cuda device - LOG(INFO) << "[CUDA] device module:\n" << device_module; + VLOG(3) << "[CUDA] device module:\n" << device_module; CodeGenCUDA_Dev codegen(target_); auto source_code = codegen.Compile(device_module); if (!code.empty()) source_code = code; - LOG(INFO) << "[CUDA] source code:\n" << source_code; + VLOG(3) << "[CUDA] source code:\n" << source_code; using runtime::cuda::CUDAModule; backends::NVRTC_Compiler compiler; diff --git a/cinn/backends/llvm/execution_engine.cc b/cinn/backends/llvm/execution_engine.cc old mode 100644 new mode 100755 index 43f8d625b8..4b9f29c048 --- a/cinn/backends/llvm/execution_engine.cc +++ b/cinn/backends/llvm/execution_engine.cc @@ -98,7 +98,7 @@ std::unique_ptr NaiveObjectCache::getObject(const llvm::Modu return nullptr; } - LOG(INFO) << "Object for " << m->getModuleIdentifier() << " loaded from cache."; + VLOG(3) << "Object for " << m->getModuleIdentifier() << " loaded from cache."; return llvm::MemoryBuffer::getMemBuffer(it->second->getMemBufferRef()); } @@ -178,25 +178,25 @@ void ExecutionEngine::Link(const ir::Module &module) { decltype(auto) es = jit_->getExecutionSession(); if (false) { - LOG(INFO) << "======= dump jit execution session ======"; + VLOG(3) << "======= dump jit execution session ======"; std::string buffer; llvm::raw_string_ostream os(buffer); es.dump(os); os.flush(); - LOG(INFO) << buffer; + VLOG(3) << buffer; } } bool ExecutionEngine::AddModule(std::unique_ptr module, std::unique_ptr context) { module->setDataLayout(jit_->getDataLayout()); if (false) { - LOG(INFO) << "======= dump jit lib =========="; + VLOG(3) << "======= dump jit lib =========="; std::string buffer; llvm::raw_string_ostream os(buffer); module->print(os, {}); // main_jd_->dump(os); os.flush(); - LOG(INFO) << buffer; + VLOG(3) << buffer; } llvm::orc::ThreadSafeContext tsc(std::move(context)); 
llvm::orc::ThreadSafeModule tsm(std::move(module), std::move(tsc)); diff --git a/cinn/backends/llvm/simple_jit.cc b/cinn/backends/llvm/simple_jit.cc old mode 100644 new mode 100755 index ed5e2062e3..8439ec9be5 --- a/cinn/backends/llvm/simple_jit.cc +++ b/cinn/backends/llvm/simple_jit.cc @@ -71,8 +71,8 @@ void SimpleJIT::AddModule(std::unique_ptr module, bool optimize) { module_pass_manager.run(*module, module_analysis_manager); } - LOG(INFO) << "jit target: " << jit_->getDataLayout().getStringRepresentation(); - LOG(INFO) << "module target: " << module->getDataLayout().getStringRepresentation(); + VLOG(3) << "jit target: " << jit_->getDataLayout().getStringRepresentation(); + VLOG(3) << "module target: " << module->getDataLayout().getStringRepresentation(); llvm::orc::ThreadSafeModule tsm(std::move(module), context_); llvm::cantFail(jit_->addIRModule(std::move(tsm))); @@ -82,7 +82,7 @@ void SimpleJIT::AddModule(std::unique_ptr module, bool optimize) { llvm::raw_string_ostream os(buffer); jit_->getExecutionSession().dump(os); os.flush(); - LOG(INFO) << "compiled jit:\n" << buffer; + VLOG(3) << "compiled jit:\n" << buffer; } } diff --git a/cinn/backends/llvm/simple_jit.h b/cinn/backends/llvm/simple_jit.h old mode 100644 new mode 100755 diff --git a/cinn/backends/nvrtc_util.cc b/cinn/backends/nvrtc_util.cc old mode 100644 new mode 100755 index 012bc1c3c2..e397cbacd4 --- a/cinn/backends/nvrtc_util.cc +++ b/cinn/backends/nvrtc_util.cc @@ -91,7 +91,7 @@ std::string NVRTC_Compiler::CompilePTX(const std::string& code, bool include_hea for (const auto& option : compile_options) { param_cstrings.push_back(option.c_str()); } - LOG(INFO) << "compile options: " << utils::Join(compile_options, " "); + VLOG(3) << "compile options: " << utils::Join(compile_options, " "); NVRTC_CALL(nvrtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, nullptr)); nvrtcResult compile_res = nvrtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data()); diff --git a/cinn/common/cas.cc b/cinn/common/cas.cc index 57766f2400..927ac37008 100644 --- a/cinn/common/cas.cc +++ b/cinn/common/cas.cc @@ -2005,8 +2005,8 @@ Expr CasSimplifyMutator::FurtherSimplifyFracWithInterval( auto it = var_intervals.find(bv->name); auto ai_abs = std::abs(ai->value); if (it != var_intervals.end()) { - LOG(INFO) << "found " << bv->name << " " << it->second << " " - << " ai " << ai_abs; + VLOG(3) << "found " << bv->name << " " << it->second << " " + << " ai " << ai_abs; } if (it != var_intervals.end() && std::abs(it->second.r) > ai_abs && std::abs(it->second.l) > ai_abs) { return make_const(a.type(), 0); diff --git a/cinn/common/ir_util.cc b/cinn/common/ir_util.cc index 022778db5e..48fd22f4f5 100644 --- a/cinn/common/ir_util.cc +++ b/cinn/common/ir_util.cc @@ -125,7 +125,7 @@ Expr RampRelatedMul(Expr a, Expr b) { CHECK_EQ(a_broadcast->lanes, b_broadcast->lanes); return ir::Broadcast::Make(a_broadcast->value * b_broadcast->value, a_broadcast->lanes); } else { - LOG(INFO) << "a,b: " << a << " " << b; + VLOG(3) << "a,b: " << a << " " << b; CINN_NOT_IMPLEMENTED } } diff --git a/cinn/frontend/computation.cc b/cinn/frontend/computation.cc old mode 100644 new mode 100755 index 99a464e949..c62b04aead --- a/cinn/frontend/computation.cc +++ b/cinn/frontend/computation.cc @@ -127,7 +127,7 @@ std::shared_ptr CinnComputation::CompilePaddleModel( } program->SetInputs({input_vars}); program->Validate(); - LOG(INFO) << "program:\n" << *program; + VLOG(3) << "program:\n" << *program; for (auto &name : fetch_names) { 
output_vars.push_back(varmap.at(name)); diff --git a/cinn/frontend/interpreter.cc b/cinn/frontend/interpreter.cc old mode 100644 new mode 100755 index 9b5679dbb8..f98f498107 --- a/cinn/frontend/interpreter.cc +++ b/cinn/frontend/interpreter.cc @@ -103,7 +103,7 @@ void Interpreter::Impl::Build(const std::vector& input_names, program_->SetInputs({input_vars}); program_->Validate(); - LOG(INFO) << "Program:\n" << *program_; + VLOG(3) << "Program:\n" << *program_; auto graph = std::make_shared(*program_, target); graph->attrs["model_name"] = std::make_shared(model_name); diff --git a/cinn/frontend/paddle/model_parser.cc b/cinn/frontend/paddle/model_parser.cc old mode 100644 new mode 100755 index 5c0d83bd6a..8ab48da30e --- a/cinn/frontend/paddle/model_parser.cc +++ b/cinn/frontend/paddle/model_parser.cc @@ -222,9 +222,9 @@ void LoadModelPb(const std::string &model_dir, CHECK(cpp_prog); CHECK(scope); cpp_prog->ClearBlocks(); - LOG(INFO) << "model_dir is: " << model_dir; - LOG(INFO) << "model_file is: " << model_file; - LOG(INFO) << "param_file is: " << param_file; + VLOG(3) << "model_dir is: " << model_dir; + VLOG(3) << "model_file is: " << model_file; + VLOG(3) << "param_file is: " << param_file; // Load model VLOG(4) << "Start load model program..."; std::string prog_path = model_dir + "/__model__"; diff --git a/cinn/hlir/pe/nn.cc b/cinn/hlir/pe/nn.cc old mode 100644 new mode 100755 index 1a124b2f37..8c93189503 --- a/cinn/hlir/pe/nn.cc +++ b/cinn/hlir/pe/nn.cc @@ -253,11 +253,11 @@ std::vector Conv2d_NCHW(const ir::Tensor &input, std::to_string(output_shape_int[1]) + " " + std::to_string(output_shape_int[2]) + " " + std::to_string(output_shape_int[3]); if (res.count(key) > 0) { - LOG(INFO) << "Find saved winograd_conv2d schedule param! key is: " << key; + VLOG(3) << "Find saved winograd_conv2d schedule param! key is: " << key; return Conv2d_winograd_NCHW( input, weights, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, output_name); } - LOG(INFO) << "Didn't find saved winograd_conv2d schedule param! key is: " << key; + VLOG(3) << "Didn't find saved winograd_conv2d schedule param! 
key is: " << key; } ir::Tensor input_pad; if (pad_h == 0 && pad_w == 0) { diff --git a/cinn/optim/replace_var_with_expr.cc b/cinn/optim/replace_var_with_expr.cc index 6611a7278e..eeee96f2cd 100755 --- a/cinn/optim/replace_var_with_expr.cc +++ b/cinn/optim/replace_var_with_expr.cc @@ -227,7 +227,7 @@ struct ReplaceVarIndexOfCacheMutator : public ir::IRMutator<> { if (tensor_shape[index].is_constant() && tensor_shape[index].get_constant() <= 0) { tensor_shape[index] = Expr(1); } else if (!tensor_shape[index].is_constant()) { - LOG(INFO) << "Index is not constant: " << tensor_shape[index] << " and it will be replaced to 1"; + VLOG(3) << "Index is not constant: " << tensor_shape[index] << " and it will be replaced to 1"; tensor_shape[index] = Expr(1); } (*global_tensor_map_).at(tensor_name)->shape = tensor_shape; @@ -239,7 +239,7 @@ struct ReplaceVarIndexOfCacheMutator : public ir::IRMutator<> { VLOG(3) << i; } } else { - LOG(INFO) << "extent not defined"; + VLOG(3) << "extent not defined"; } } diff --git a/cinn/poly/compute_at_transform.cc b/cinn/poly/compute_at_transform.cc old mode 100644 new mode 100755 index ece3b89e9d..1f69706a4c --- a/cinn/poly/compute_at_transform.cc +++ b/cinn/poly/compute_at_transform.cc @@ -128,7 +128,7 @@ void ComputeAtTransform::DisplayC(isl_map* pschedule, isl_map* cschedule) { auto* build = isl_ast_build_from_context(context.release()); auto* node = isl_ast_build_node_from_schedule_map(build, intersect_schedule.release()); - LOG(INFO) << "code:\n\n" << isl_ast_node_to_C_str(node); + VLOG(3) << "code:\n\n" << isl_ast_node_to_C_str(node); isl_ast_node_free(node); } diff --git a/cinn/poly/stage.cc b/cinn/poly/stage.cc index c13de5f772..ae17b26951 100755 --- a/cinn/poly/stage.cc +++ b/cinn/poly/stage.cc @@ -970,12 +970,12 @@ void Stage::Vectorize(int level, int factor) { CHECK_LT(level, n_out_dims()); CHECK_GT(factor, 0); if (factor == 1) { - LOG(INFO) << "Vectorize-factor 1 has no sense, skip it"; + VLOG(3) << "Vectorize-factor 1 has no sense, skip it"; return; } auto transformed_domain = this->transformed_domain(); if (isl_is_removed_axis(transformed_domain.get(), level)) { - LOG(INFO) << "Vectorizing for-1 has no sense, skip it"; + VLOG(3) << "Vectorizing for-1 has no sense, skip it"; return; } int removed_axes_counts = isl_get_precending_removed_axes_counts(transformed_domain.get(), level); @@ -1008,7 +1008,7 @@ void Stage::Parallel(int level) { auto transformed_domain = this->transformed_domain(); VLOG(3) << "transformed_domain" << transformed_domain; if (isl_is_removed_axis(transformed_domain.get(), level)) { - LOG(INFO) << "Paralleling for-1 has no sense, skip it"; + VLOG(3) << "Paralleling for-1 has no sense, skip it"; return; } int removed_axes_counts = isl_get_precending_removed_axes_counts(transformed_domain.get(), level); diff --git a/cinn/pybind/frontend.cc b/cinn/pybind/frontend.cc old mode 100644 new mode 100755 index 693ec3edb8..6e9c9da1cf --- a/cinn/pybind/frontend.cc +++ b/cinn/pybind/frontend.cc @@ -226,7 +226,7 @@ void BindFrontend(pybind11::module *m) { CINN_NOT_IMPLEMENTED } } - LOG(INFO) << info; + VLOG(3) << info; program->ExecuteTest(repeat_); auto out = scope->GetTensor(tensor_out->id); return out; @@ -268,7 +268,7 @@ void BindFrontend(pybind11::module *m) { CINN_NOT_IMPLEMENTED } } - LOG(INFO) << info; + VLOG(3) << info; program->ExecuteTest(repeat_); auto out = scope->GetTensor(tensor_out->id); return out; diff --git a/cinn/pybind/lang.cc b/cinn/pybind/lang.cc old mode 100644 new mode 100755 index f96332fc03..6804e618eb --- 
a/cinn/pybind/lang.cc +++ b/cinn/pybind/lang.cc @@ -41,6 +41,7 @@ using utils::StringFormat; namespace { void BindBuffer(py::module *); void BindLower(py::module *); +void BindLowerVec(py::module *); void BindPlaceholder(py::module *); void BindCompute(py::module *); void BindModule(py::module *); @@ -66,6 +67,19 @@ void BindLower(py::module *m) { arg("target") = common::DefaultHostTarget()); } +void BindLowerVec(py::module *m) { + using py::arg; + m->def("lower_vec", + &lang::LowerVec, + arg("name"), + arg("stages"), + arg("tensor_args"), + arg("scalar_args") = std::vector(), + arg("temp_tensors") = std::vector(), + arg("b") = nullptr, + arg("target") = common::DefaultHostTarget()); +} + void BindCompute(py::module *m) { #define MAKE_COMPUTE_FN(__fn) \ py::overload_cast &, __fn, const std::string &, const std::vector &>( \ @@ -218,6 +232,7 @@ void BindBuiltin(py::module *m) { void BindLang(py::module *m) { BindBuffer(m); BindLower(m); + BindLowerVec(m); BindPlaceholder(m); BindCompute(m); BindModule(m); diff --git a/cinn/pybind/poly.cc b/cinn/pybind/poly.cc index e1bc8a52ab..dae4c51700 100644 --- a/cinn/pybind/poly.cc +++ b/cinn/pybind/poly.cc @@ -47,6 +47,18 @@ void BindMap(py::module *m) { condition.def_readwrite("cond", &Condition::cond).def(py::init()).def("__str__", &Condition::__str__); } +void BindStageMap(py::module *m) { + DefineShared(m, "StageMap"); + py::class_> stage_map(*m, "StageMap"); + stage_map // + .def( + "__getitem__", + [](poly::StageMap self, ir::Tensor &t) -> Stage & { return *self[t]; }, + py::return_value_policy::reference); + + m->def("create_stages", &poly::CreateStages, py::arg("tensors")); +} + void BindStage(py::module *m) { py::class_ stage(*m, "Stage"); // enum Stage::ComputeAtKind @@ -73,6 +85,7 @@ void BindStage(py::module *m) { .def("split", py::overload_cast(&Stage::Split), arg("level"), arg("factor")) .def("split", py::overload_cast(&Stage::Split), arg("level"), arg("factor")) .def("fuse", py::overload_cast(&Stage::Fuse), arg("level0"), arg("level1")) + .def("fuse", py::overload_cast &>(&Stage::Fuse)) .def("reorder", py::overload_cast &>(&Stage::Reorder), "Reorder the axis in the computation") @@ -87,23 +100,17 @@ void BindStage(py::module *m) { .def("unroll", py::overload_cast(&Stage::Unroll)) .def("unroll", py::overload_cast(&Stage::Unroll)) .def("unroll", py::overload_cast(&Stage::Unroll)) + .def("parallel", py::overload_cast(&Stage::Parallel)) + .def("parallel", py::overload_cast(&Stage::Parallel)) + .def("parallel", py::overload_cast(&Stage::Parallel)) .def("compute_at", &Stage::ComputeAtSchedule, arg("other"), arg("level"), arg("kind") = Stage::kComputeAtAuto) .def("skew", &Stage::Skew) .def("ctrl_depend", &Stage::CtrlDepend) .def("cache_read", &Stage::CacheRead) - .def("cache_write", &Stage::CacheWrite); -} - -void BindStageMap(py::module *m) { - DefineShared(m, "StageMap"); - py::class_> stage_map(*m, "StageMap"); - stage_map // - .def( - "__getitem__", - [](poly::StageMap self, ir::Tensor &t) -> Stage & { return *self[t]; }, - py::return_value_policy::reference); - - m->def("create_stages", &poly::CreateStages, py::arg("tensors")); + .def("cache_write", &Stage::CacheWrite) + .def("sync_threads", py::overload_cast(&Stage::SyncThreads)) + .def("sync_threads", + py::overload_cast &, poly::StageMap>(&Stage::SyncThreads)); } } // namespace diff --git a/docs/guide.md b/docs/guide.md deleted file mode 100644 index 4ea81d4616..0000000000 --- a/docs/guide.md +++ /dev/null @@ -1,41 +0,0 @@ -# CINN INSTAllATION GUIDANCE - -### Step 1. 
Clone Source Code - -Clone CINN from github. - -`git clone https://github.com/PaddlePaddle/CINN.git` - -### Step 2. Build Docker Image - -Build docker image based on the given dockerfile in ./tools/docker/Dockerfile. - -`cd ./CINN/tools/docker` - -`sudo docker build -t cinn_image:v1 .` - -### Step 3. Start a docker container - -Start a docker container and mount folder ./CINN into it. - -Go back to the path where you clone CINN. - -`sudo nvidia-docker run -it --net=host -v $PWD/CINN:/WorkSpace/CINN --name=your_docker_name cinn_image:v1` - -### Step 4. Prepare dependencies - -After enter the container, run ./CINN/tools/ci_build.sh - -`./CINN/tools/ci_build.sh` - -### Step 5. Build CINN and do ci test - -Build CINN and do ci test to verify correctness. - -`cd CINN` - -There are 3 kinds of ci test: - -1. Test on CPU(X86) backends: `./build.sh ci` -2. Test on NVGPU(cuda) backends with CUDNN library: `./build.sh gpu_on ci` -3. Test on NVGPU(cuda) backends without CUDNN library: `./build.sh gpu_on cudnn_off ci` diff --git a/docs/source/conf.py b/docs/source/conf.py index 20366d5e8c..00ad0e1855 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -34,11 +34,11 @@ # -- Project information ----------------------------------------------------- project = 'cinn' -copyright = '2020, cinn team' +copyright = '2021, cinn team' author = 'cinn Team' # The full version, including alpha/beta/rc tags -release = '0.1-alpha' +release = 'release/v0.1-rc' # -- General configuration --------------------------------------------------- @@ -47,6 +47,7 @@ # ones. extensions = [ 'sphinx.ext.doctest', + 'sphinx_rtd_theme', 'sphinx.ext.autosummary', 'sphinx.ext.mathjax', 'sphinx_gallery.gen_gallery', @@ -76,7 +77,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'alabaster' +html_theme = 'sphinx_rtd_theme' # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, diff --git a/docs/source/guide.md b/docs/source/guide.md new file mode 100644 index 0000000000..4e076523e5 --- /dev/null +++ b/docs/source/guide.md @@ -0,0 +1,33 @@ +# Install CINN using docker + +### Step 1. Start a docker container + +Start a docker container based on upstream image. + +`nvidia-docker run --name $CONTAINER_NAME -it --net=host registry.baidubce.com/paddlepaddle/paddle:2.2.0-gpu-cuda11.2-cudnn8 /bin/bash` + +If you are using the latest version of docker, try: + +`docker run --gpus all --name $CONTAINER_NAME -it --net=host registry.baidubce.com/paddlepaddle/paddle:2.2.0-gpu-cuda11.2-cudnn8 /bin/bash` + +And notice that if your cuda version is not 11.2, replace the docker image to the corresponding paddle image with identical cuda version [here](https://registry.hub.docker.com/r/paddlepaddle/paddle). + +### Step 2. Clone Source Code + +After entering the container, clone the source code from github. + +`git clone https://github.com/PaddlePaddle/CINN.git` + +### Step 3. Build CINN and do ci test + +Build CINN and do ci test to verify correctness. + +`cd CINN` + +There are 5 kinds of ci test: + +1. Test on CPU(X86) backends: `bash ./build.sh ci` +2. Test on CPU(X86) backends without mklcblas: `bash ./build.sh mklcblas_off ci` +3. Test on CPU(X86) backends without mkldnn: `bash ./build.sh mkldnn_off ci` +4. Test on NVGPU(cuda) backends with CUDNN library: `bash ./build.sh gpu_on ci` +5. 
Test on NVGPU(cuda) backends without CUDNN library: `bash ./build.sh gpu_on cudnn_off ci` diff --git a/docs/source/index.rst b/docs/source/index.rst old mode 100644 new mode 100755 index 12cf9db542..7103cc3fcb --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -21,8 +21,9 @@ Install :maxdepth: 1 ./install.md + ./guide.md -cinn +CINN ------ Get Started @@ -34,15 +35,16 @@ Get Started C++ APIs -~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~ .. toctree:: :maxdepth: 1 matmul.md + load_paddle_model.md cpp/library_root.rst -cinnrt +CINNRT ------- TBD diff --git a/python/tests/test_computation.py b/python/tests/test_computation.py old mode 100644 new mode 100755 index aa2499052e..5a257e5ea8 --- a/python/tests/test_computation.py +++ b/python/tests/test_computation.py @@ -26,7 +26,6 @@ from cinn import lang from cinn.common import * import numpy as np -import paddle.fluid as fluid import sys assert len(sys.argv) == 3 diff --git a/tutorials/README.md b/tutorials/README.md new file mode 100644 index 0000000000..54ed716083 --- /dev/null +++ b/tutorials/README.md @@ -0,0 +1,13 @@ +================= +Tutorials +================= +This page contains the tutorials about CINN. + +--------- +Run demo +--------- +Compile ``demo.cc``:: + + cd build/dist + + bash build_demo.sh diff --git a/tutorials/README.txt b/tutorials/README.txt deleted file mode 100644 index 9159112743..0000000000 --- a/tutorials/README.txt +++ /dev/null @@ -1,11 +0,0 @@ -Tutorials -=========== -This page contains the tutorials about CINN. - -#### Run demo -compile demo.cc - -```bash -cd build/dist -bash build_demo.sh -``` diff --git a/tutorials/jit.py b/tutorials/jit.py index e9e866705e..6812ea42c1 100755 --- a/tutorials/jit.py +++ b/tutorials/jit.py @@ -18,11 +18,10 @@ In this tutorial, we will introduce the JIT module that execute the DSL on X86 and NV GPU. """ -# sphinx_gallery_thumbnail_path = '_static/icon.png' - import cinn import numpy as np from cinn import runtime +# sphinx_gallery_thumbnail_path = './paddlepaddle.png' ################################################################## # declare some variables for latter use diff --git a/tutorials/load_paddle_model.cc b/tutorials/load_paddle_model.cc new file mode 100644 index 0000000000..458329c7a9 --- /dev/null +++ b/tutorials/load_paddle_model.cc @@ -0,0 +1,92 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! @h1 Load and Execute Paddle Model in C++ +//! In this tutorial, we will show you how to load and execute a paddle model in CINN using C++. +//! We will use model ResNet50 as an example. + +#include + +#include "cinn/cinn.h" + +using namespace cinn; // NOLINT + +//! @IGNORE-NEXT +TEST(LOAD_MODEL, basic) { + //! @h2 Prepare to Load Model + //! Declare the params and prepare to load and execute the paddle model. + //! + `input_name` is the name of input tensor in the model. + //! + `target_name` is the name of output tensor we want. + //! + `x_shape` is the input tensor's shape of the model. 
+ + std::string input_name = "inputs"; + std::string target_name = "save_infer_model/scale_0.tmp_1"; + std::vector x_shape = {1, 3, 224, 224}; + + //! @h2 Set the target backend + //! Now CINN only supports two backends: X86 and CUDA. + //! + To choose X86 backends, use : + //! `auto target = common::DefaultHostTarget();` + //! + To choose CUDA backends, use : + //! `auto target = common::DefaultNVGPUTarget();` + + auto target = common::DefaultHostTarget(); + + //! @h2 Load Model to CINN + //! Load the paddle model and compile it into CINN IR. + //! + `target` is the backend to execute model on. + //! + `model_dir` is the path where the paddle model is stored. + //! + `params_combined` implies whether the params of paddle model is stored in one file. + + std::string model_dir = "./ResNet50"; + bool params_combined = true; + auto computation = + frontend::CinnComputation::CompilePaddleModel(target, model_dir, {input_name}, {x_shape}, params_combined); + + //! @h2 Get input tensor and set input data + //! Here we use all-zero data as input. In practical applications, please replace it with real data according to your + //! needs. + + auto input_tensor = computation.GetTensor(input_name); + + std::vector fake_input(input_tensor->shape().numel(), 0.f); + + auto *input_data = input_tensor->mutable_data(target); + if (target.arch == Target::Arch::X86) { + std::copy(fake_input.begin(), fake_input.end(), input_data); + } else if (target.arch == Target::Arch::NVGPU) { + CUDA_CALL(cudaMemcpy( + input_data, fake_input.data(), input_tensor->shape().numel() * sizeof(float), cudaMemcpyHostToDevice)); + } + + //! @h2 Execute Model + //! Execute the model and get output tensor's data. + + computation.Execute(); + + auto target_tensor = computation.GetTensor(target_name); + std::vector output_data(target_tensor->shape().numel(), 0.f); + if (target.arch == Target::Arch::X86) { + std::copy(target_tensor->data(), + target_tensor->data() + target_tensor->shape().numel(), + output_data.data()); + } else if (target.arch == Target::Arch::NVGPU) { + CUDA_CALL(cudaMemcpy(output_data.data(), + reinterpret_cast(target_tensor->mutable_data(target)), + target_tensor->shape().numel() * sizeof(float), + cudaMemcpyDeviceToHost)); + } + //! @IGNORE-NEXT + LOG(INFO) << "Succeed!"; +} diff --git a/tutorials/load_paddle_model.py b/tutorials/load_paddle_model.py index 9d59e49b95..cef609ede3 100755 --- a/tutorials/load_paddle_model.py +++ b/tutorials/load_paddle_model.py @@ -16,10 +16,11 @@ ===================== In this tutorial, we will show you how to load and execute a paddle model in CINN. +We offer you four optional models: ResNet50, MobileNetV2, EfficientNet and FaceDet. """ -# sphinx_gallery_thumbnail_path = '_static/icon.png' - +import paddle +import paddle.fluid as fluid import cinn from cinn import * from cinn.frontend import * @@ -27,71 +28,133 @@ from cinn.common import * import numpy as np import os +import sys +# sphinx_gallery_thumbnail_path = './paddlepaddle.png' ################################################################## -# Prepare to Load Model -# ------------------------- +# **Prepare to Load Model** +# -------------------------- # Declare the params and prepare to load and execute the paddle model. # -# - :code:`enable_gpu` implies whether to run CINN on CUDA backends. -# -# - :code:`mnodel_dir` is the path where the paddle model is stored. +# - :code:`model_dir` is the path where the paddle model is stored. # # - :code:`input_tensor` is the name of input tensor in the model. 
# # - :code:`target_tensor` is the name of output tensor we want. # # - :code:`x_shape` is the input tensor's shape of the model - -model_dir = "./ResNet18" -input_tensor = 'image' -target_tensor = 'save_infer_model/scale_0' +# +# - When choosing model ResNet50, the params should be :: +# +# model_dir = "./ResNet50" +# +# input_tensor = 'inputs' +# +# target_tensor = 'save_infer_model/scale_0.tmp_1' +# +# x_shape = [1, 3, 224, 224] +# +# - When choosing model MobileNetV2, the params should be :: +# +# model_dir = "./MobileNetV2" +# +# input_tensor = 'image' +# +# target_tensor = 'save_infer_model/scale_0' +# +# x_shape = [1, 3, 224, 224] +# +# - When choosing model EfficientNet, the params should be :: +# +# model_dir = "./EfficientNet" +# +# input_tensor = 'image' +# +# target_tensor = 'save_infer_model/scale_0' +# +# x_shape = [1, 3, 224, 224] +# +# - When choosing model FaceDet, the params should be :: +# +# model_dir = "./FaceDet" +# +# input_tensor = 'image' +# +# target_tensor = 'save_infer_model/scale_0' +# +# x_shape = [1, 3, 240, 320] +# +model_dir = "./ResNet50" +input_tensor = 'inputs' +target_tensor = 'save_infer_model/scale_0.tmp_1' x_shape = [1, 3, 224, 224] ################################################################## -# Set the target backend +# **Set the target backend** +# ------------------------------ +# Now CINN only supports two backends: X86 and CUDA. +# +# - For CUDA backends, set ``target = DefaultNVGPUTarget()`` +# +# - For X86 backends, set ``target = DefaultHostTarget()`` +# if os.path.exists("is_cuda"): target = DefaultNVGPUTarget() else: target = DefaultHostTarget() ################################################################## -# Set the input tensor and init interpreter -executor = Interpreter([input_tensor], [x_shape]) - -################################################################## -# Load Model to CINN +# **Load Model to CINN** # ------------------------- -# Load the paddle model and transform it into CINN IR -# -# * :code:`mnodel_dir` is the path where the paddle model is stored. +# Load the paddle model and transform it into CINN IR. # # * :code:`target` is the backend to execute model on. # -# * :code:`params_combined` implies whether the params of paddle -# model is stored in one file. - +# * :code:`model_dir` is the path where the paddle model is stored. +# +# * :code:`params_combined` implies whether the params of paddle model is stored in one file. +# +# +model_name = "resnet50" params_combined = True -executor.load_paddle_model(model_dir, target, params_combined) +computation = Computation.compile_paddle_model( + target, model_dir, [input_tensor], [x_shape], params_combined) ################################################################## -# Get input tensor and set input data -a_t = executor.get_tensor(input_tensor) +# **Get input tensor and set input data** +# ----------------------------------------- +# Here we use random data as input. In practical applications, +# please replace it with real data according to your needs. +# +a_t = computation.get_tensor(input_tensor) x_data = np.random.random(x_shape).astype("float32") a_t.from_numpy(x_data, target) ################################################################## -# Get output tensor and init its data to zero. -out = executor.get_tensor(target_tensor) +# Here we set the output tensor's data to zero before running the model. 
+out = computation.get_tensor(target_tensor) out.from_numpy(np.zeros(out.shape(), dtype='float32'), target) ################################################################## -# Execute Model +# **Execute Model** # ------------------------- # Execute the model and get output tensor's data. -# * :code:`out` is the data of output tensor we want. +# :code:`out` is the data of output tensor we want. +computation.execute() +res_cinn = out.numpy(target) +print("CINN Execution Done!") -executor.run() -out = out.numpy(target) -print("Execution Done!\nResult shape is:\n", out.shape) -print("Result data is:\n", out) +################################################################## +# **Use Paddle to Verify Correctness** +# ------------------------- +# Now we run the model by paddle and check if the 2 results are identical. +config = fluid.core.AnalysisConfig(model_dir + '/__model__', + model_dir + '/params') +config.disable_gpu() +config.switch_ir_optim(False) +paddle_predictor = fluid.core.create_paddle_predictor(config) +data = fluid.core.PaddleTensor(x_data) +paddle_out = paddle_predictor.run([data]) +res_paddle = paddle_out[0].as_ndarray() +print("Paddle Execution Done!\n =============================") +print("Verification result is: ", np.allclose(res_cinn, res_paddle, atol=1e-3)) diff --git a/tutorials/matmul.py b/tutorials/matmul.py index 8cae9ae1cf..1924b80fbb 100755 --- a/tutorials/matmul.py +++ b/tutorials/matmul.py @@ -18,12 +18,11 @@ In this tutorial, we will introduce several ways to optimize the performance of the matrix multiplication on X86 CPU. """ -# sphinx_gallery_thumbnail_path = '_static/icon.png' - import cinn import numpy as np import time from cinn import runtime +# sphinx_gallery_thumbnail_path = './paddlepaddle.png' ################################################################## # Declare the basic computation for a matmul diff --git a/tutorials/paddlepaddle.png b/tutorials/paddlepaddle.png new file mode 100644 index 0000000000000000000000000000000000000000..d877d50303a4b3474708fe95a9ac2d26677199ff GIT binary patch literal 8598 zcmeHso-6|?~W3|@458Rs$jeTj$a zKnvs?Jz5Z2gv&78d;@oa}fWzS^HY#c5lF40iqM z>h2n`=%q``Fg0C&LA?G*_a~r8tGm&WAdU|Cox7i_v*G-Hwp96ctJZCP`Nf?#b7{b{ zuC!$19D?F|fV;MbtQPL0N7k61znYqB1FRYW3k)C`C5#0DnL{s}K*CSCKucOcB`&VV zTos9z(S!};Kv7JFxZy`H<8egOyfl-IwJ-~F&+$28`KJU_ zpHYX=r+!(_gjSUmS=d>Ud zjL7LrOMAMyatFJFb+5GC^Y5f|-(9WPKQ)E^&?GC5 zF^;7x45h;?(ZzBv;NAw8{Fv1@fimRN=?O0Y@tV@pW5RLB-s>q4#~xx}s^GBDD?C-1v@`G$$=HV_J!Ug09Flx{=V{p8 zjr^Pc=(94jXy?B6rrwPpkO_^IkuXr~@9PBGj8{NYx7P3q?dU90WO7E%-6(NF0XTkd zWc&eYs_Jg-N#?2W07fl3w3PjcMf7Gz^wK~q|B4Ia;3EuTx_=m4+|@)SrPH4A4P0vJ z(LhW(;vh_KUFpawMZdExb`J8lKL-cj>bhE?z~;h(N#x<=occD8F*{h+jREIESh_?& z85uRq!ILlt6|A(!7{ue@0L+Z%`~}H$0L%>>y3oj?aA$PWm{$B;7=ZB# zt_m5m&;ECBeEL{?U5kF@YKc*>V9wPPGM0N_yna^C=S@H^tjHGg^PVXsU zi&VtNCsl9R|6PTAUhk(Qk@RsR?U&S+O_WGes}nXfIczdGU<$Y(kFQX(zns-ycL~%y zcL}V-?naC91va>3LLNRL(z1 zuhlq2vVdPiHDUJadm&9!?mo_;w5IGYn#H%3zG(~GTh%mhLHJ)5D->Sd#0$&}XLh$j zBg{%EuF~$r9nHK*@S(E5v}7^8q(rG@@ydMlr>?skJmiE~;5$#q))1%;wB})nH*%xn znNIf+V3+Vw0p+zp5@t40E0$??|1)aUq0S7W`| zAoAm=mAbdSJ&LCcJPbbO5@iU@jU=NI#i?4*r1pjoD^bL*x&F7|SllGw1Z_QUpod%cxD#7q_R`y$E@ z$6{j73g(Oni!o5sEh7vZROv|gQ2eor6-;fA;vR4MHfH|c0`y`?7Ep;n8gfQ5A2@#< zL|?R+s+(6RA2kiN)=kVmUJD$5`6JNtP+Ch7fdBsFyvdAJ!Ofj9Ew_tUoorY2V|1V0 zzW&a-3KM1q#ZS>*{xCmifPSF~mJv(-J+8aAzxgE}LlUuTCpEMsQ;YPQ^gCddh_GrOQy_6r8Kc2(BjfgCH1q#{oImE zbFGK@F41vQh$uTY2A6(7(a_Q1HA;7z`J#ylZ!zK za((AU$eE@lp3BjDn5DH(e6D$vGtY$~-qpa$Cp8Yy* z=7>0^nvsb}$1aOGW4Q^J7A|`}Ah990+623^(RnT-A!Jp7$B)Yqw&Xfm{JZm=9O271 
z{uyr8&o%DiJ;}%NJ@gl>tUJkqeMFYTKZAj8 zKLeKVt%w--y*1dwU$Vg3DO`NK2!q4&d}^d_|T447AFd)jqBH*DnTrnbDy4s zl$G>qwdzMtn@OPF&#_EkHz<%5lCX(EVrK6vKXhDui6f$xl1;OE`JuQ+^!nG9(g<=cmII-WcsX2jPI=mGkIo0q5BY$Ri~|py4^#md zT&Xxx5$NP0k4_#`6`ES;6&zU$ZK(PO5#f;-mqdV5- z53dA#e0$wL6#v|a+L2mX|N6@Bhb(+5uNZE<+@3c`8>}rLWGCO6BYmlxv-^8yIuF5* zT1qF&!!oDqvI87X%X0@TSdYUf03P21kHL5g5zwl!Pc-oA$2DewZ^x?$cfz4nFoOAz zpt8>$fqmSR0brtmiR=V+Hl)f$iamJ-p81(!v&HFhVr==V?ac`6h?1s+7cqI{0W z;V}EfIwAXS))lC8D`5pOJO`LxAk{We9{m<|qA6c0+5Bz0My+i}9tcZ^6J+Ukb5BKc+N=V?m=ZSSsP**e6m%LWx|u2K zQMhD@;`Xon+`(3kV?ur$k1xeuVA6YBnI+Pf+c)9J~s!1DxTaKzO ztUzAH&B&k6Gyg=Jy-WKJTBj81Um5CSKVO00&Nv4U>6+pmVoAyC0s5VoE&rkVB8W*w$nZfGkyQwh*J40qi$C0!}~!#Dp^dw(*-aJ@vfeYQ9u#xQl+C#3=U zo5`{36AJHvH~CU?#UL(dVBZE~;d#xEUBfOjXx6pl?XqP_O(hT{J>WFBEYZi~JZCfZ z_tJ9+Upe7Wuf$J}qEEF$2ds16gtTw8Ny>gcV$w4-^JF=wx3o-dxcAlbleS>a983J3 z?d8!3GDMo&+-Dw}*32!PuM4sG@G|-7;RyT~NQqT$_+yH{h@~)l#iK}2nwP2RPqOE^ z#Vgov#)%5n?zDZ^mXcB1(F02Fb_7OJy{JULzwXN7RD9)(eQ0w(laf$sqF-N>YIA9~ zn%zxTMj7<|TC}~fXo=rWgW`;_r?Nm+6M~>uEq{i$t}`F7gKhU`g(B}x@)Qwczn$Nq zik}NMsi`)xd~;dK0n6g+pNbpwWI?o36OK~$sx6HWJ1~>U>PTmJH?3_l;+H{gUUtEb z`3pr0;WYeYaEk`;UR*$D)>DZo~ubIka5mm?RFL}zD1 zYv|4`G`#sz(HiKZsH7HxHNdl?{p%O?a*+w%*Sx7LMQ{?09nTmG9EZcT$&-1{_oK~{ zel!9({anyrS9m}~TDO>VGjD&hF`?Ez9$@##;q&dJv!L`Ldj-4q-ga5mWY3*F4e*2H zcz5r$>uC^1KG=`sm?9Iz$s)|y%#jp>K&sN9-p>+ZRIw1IgkYns+oWjFUL3&e4OR^J zg2N*ptr-ppj66buKCYJeQnp_x|LMmTov#Zppg3CXJD@^;T*!aUzbaXJ*HWPakuKmv zl}p}J0nhC2L@7-yWrrih)e0&ci|1r`G2oWpb|zA7dgs6va5w53;l0lV zRw!hrIMwqQ66nh4iESeK9t9nRgv!Orp)V;{A~6bYs{&oz-4F|F8$O&JiGtA#;k zQ}_Nw=q7)}%z+OTrEMM&NiDGhdGYsw=)mRboM}1CY_Prl1j9TJ=x;nn@?qZWQ5Qq? zkmPO(@X>l}QX;?(&>G;2)?K-@)q5XYvyn^1@~YL7$mr563UV=h%t?UNwEGdIBoV)* zietf|^7ubMjfh$|Zw9!or-!%7n+R&N@?m;bWr<_Md^#`zl_=8c4*6}wDnawP zQvTY_`ZOS_YW(%hJd&c#QtmF_jd?Z8eP+=98u*)Mg|i9^G<#y?K1+raD>8D6bM4ux zVzGzC1AK&2j7`F~pyrj9+`;Hiyg)F2&1`C>%U$v_q!@WX2=|(fyisY5L0rFC7MM*F-reS-TdcgeJX zNGmN;OzD0!u#E>iMA=|xT26t%72Z3^6Ff^vbCaG0QtrI56geMrsW$ouzLUx^P*IcK zei@cKH^0oWJ)T%b&V{6^i-KMqIV?>s<1DJ<6L&Fa#(+8zIBZU?d9TG`#iXqE`&<$$ zBTRetAGHzI`q5a~($C&B<-k4MuVJY+EGKBD`Zkdk=F#J?NH6Af^sWlr&#ooS5E|M2s1w1f$>sEg zc}vh|IY3trz8?_^;8I77`(U-9ZrBdD?YHik&>(QIZe!|bVU5g zYg={W2XA9Iq00(;pEHszgn!3UtvFbH3ENMlyIs_g8(e7n8mG@etQims4vzf0CmW!@ zpnI zd9I*##zQ%SWn@mbxRA6=G!`-eW%LLbdG~1P>C%2A%v`DWrw9p;)@ry{_bj?+x6B3 zwIN#Xm$IxDQyoJJR+OEe{Ji)(XFHi#vjc0Q9Q-I0IzX|ILSEq?#ZkWrx#I$7HXRo3 zrEsgMM*jH?c0XatNl08x1)&JJUvB%G->kePm6lPqCv7@z3qaf`jfw?Eg}EU0LI{tL0fTeQJ!#xDRqK z0%qMfdH%xYG0iu8G@t5!1xm)eAQ`HqVqO*(%hrSu4Is{RG}u=aY7mPjkVQv7M+V() zqRr@=NtF@y)`GLURC39dC!I$zO z*sZE8-KrK>0CG7+ZdGhtSF`MIi-fP+dNd27HQESW1;fXLj9DV^6#C$a$Rm0#rW_GUD4xRM-ANleSl8Wq_Ack?*p+#xmFaiQhg$a+d;SLAX*x$xFxk^f4=iIs5 z>?GNdiN~35c3(7U3>sAJRL2|zvM&6*d4C(qMJe4<@Re2XVTr*XOqh#ml*?FYhwXd~ z8qQ}tsk`n3NyAl7tu?BI1$_tjUK`S6>#%-NEH@~bV)PX$w2MFT7Drzr; zJ3UH1KjP`qL6^z$$ybM`ofk>1BywO=!-{JB=&|s3fy#oBU#63WG}hvwDzeh#d3%#5xdLhUPip_z-a?JGHzmuC*3NbV{m&J2AeXW^ zujzwFO-koVgyTu~#J@f*_lyddZdc>?7ENq?+kdy+u!e&Dc9h4$IBz5!%Ps$k^Ts(6 za+Jk9j&=056Laogp-s%C>+lycm;Icl6|?D~kS6OOq7n3;fZXdYA|5VG3OV{xt4jWv zpZrzaCg~=d(C_x&Y~f#E2+~T??)zWw=hr@XWm*+i1;#`JSWuHogbM?ZW=?SKmD57= zkGxX&s{3-4G{Q6Ghq-O-NYABbT{}YUQh?lS|1ru-YY-Li7y?93>zul$dOdu*AJsN- z-S8rq!bq`|$Kjjm2T1p1N~=&9ZJ0kCI{HeVYX=);rR9C@OymVPU*uV7BM{}M#@HiA z^BOZu#Cw3925^AJ_|?*Dj6M5YqpDmIw5S9eQ02`;b`%NWD1Xd3*7X>D*iS~xXe_%D zB0OQu3Cxd$L($r+t(8tqatmN<&;Hs6V^VarU8r4_O(UoAknp0kIUzzD!1*N4ilc0V zrW67{Jm$Eb-i_1=U>cxCQ5}Ycw)#*TuzcP2Z8Hv z{JB~L=7XJSu58rlcrk_@JH4!@MHwj30ml=UuD^zqn6*E+P? 
zwJb#_3KDRz3te=45FA6fZE9%7$7n=y^s+)wY= zKqvxY$F}i809cQY)+@qT8Il+Vxtme!%W@!p&}%Isu49k>s=~VScq3C8U3IGF%iRPO zlac0uW7j!-RYXskYbBMFZvS^~ebhh^*3Gk$xY8Gtn-BBX7oC*AaDEse!t;eT@dE^+ z3%Y;pK=Pp0S1IDb+1*9Qhwkw%w~eXDp8YV6(mZ262oM2ok#t9SJV%_zKJBiII(dW! z;oq^>_*+-ht$ESbU@pn2_rnwc_jV~Nr2ki$)1k}%1gkHJk?^lD zqK3;O1${Ksi_y?o3H))v&ife10ckI{B3NtV8v7Qjr3wGvJ+|8pz#!7Z=K_W!!x2Xg;gss>;X z*+Z!MPBYE&(lZ~DLfA0CSZEbdgGF=pdu32_+48ThEZ={ktsVm(tj;ao*a6Y&V%9ni z50FPgI@NH|@d6F0+1#;b%j$o2s@r+~!@L=vMD}4s^Tl6C`R$MzUb%`D^-KfB8y(UE zO{=P`LxnHTkjG=dAR9E^vIOZd}XMNGp2B#M1w##E4nOAeEL9v zl22=~pea{;G&%V^t{(`L>3ZNg%{7%W(FJWTG^uIG%Rq`A&4Vst`^&zxMKhq^&>X3C z_a}8Ud)bC&FL$uKQ_)msmmB~i4vUc=